From 12497ed32f70a23552fd35161138b2e1812fc4f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roman=20Smr=C5=BE?= Date: Sat, 16 Nov 2024 20:25:47 +0100 Subject: Network: use streams to send large objects --- src/network.cpp | 35 ++++++++++++++++++++-- src/network.h | 6 ++++ src/network/protocol.cpp | 76 +++++++++++++++++++++++++++++++++++++++++------- src/network/protocol.h | 19 ++++++++++-- 4 files changed, 121 insertions(+), 15 deletions(-) diff --git a/src/network.cpp b/src/network.cpp index 26a07e3..8455eea 100644 --- a/src/network.cpp +++ b/src/network.cpp @@ -252,7 +252,13 @@ bool Peer::send(UUID uuid, const Ref & ref, const Object & obj) const NetworkProtocol::Header::ServiceType { uuid }, NetworkProtocol::Header::ServiceRef { ref.digest() }, }); - speer->connection.send(speer->partStorage, move(header), { obj }, true); + + vector< Object > body; + if( obj.encode().size() + 2 * NetworkProtocol::Header::itemSize + <= speer->connection.mtu() ) + body.push_back( obj ); + + speer->connection.send( speer->partStorage, move(header), body, true ); return true; } @@ -454,7 +460,15 @@ void Server::Priv::doListen() } peer->connection.send(peer->partStorage, - NetworkProtocol::Header(reply.header()), reply.body(), false); + NetworkProtocol::Header( reply.header() ), + reply.stream() ? vector< Object >{} : reply.body(), false ); + if( reply.stream() ){ + for( const auto & obj : reply.body() ) { + auto part = obj.encode(); + reply.stream()->write( part.data(), part.size() ); + } + reply.stream()->close(); + } } peer->connection.trySendOutQueue(); @@ -597,6 +611,10 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head if (auto ref = peer.tempStorage.ref(dgst)) { reply.header({ NetworkProtocol::Header::DataResponse { ref->digest() } }); reply.body(*ref); + + if( holds_alternative< unique_ptr< Channel >>( peer.connection.channel() ) and + reply.size() > peer.connection.mtu() and not reply.stream() ) + reply.stream( peer.connection.openOutStream() ); } } } @@ -907,9 +925,17 @@ void ReplyBuilder::body(const Ref & ref) for (const auto & x : mbody) if (x.digest() == ref.digest()) return; + + bodySize += ref->encode().size(); mbody.push_back(ref); } +void ReplyBuilder::stream( shared_ptr< NetworkProtocol::OutStream > s ) +{ + mheader.emplace_back( Header::StreamOpen{ s->id }); + mstream = move( s ); +} + vector ReplyBuilder::body() const { vector res; @@ -919,6 +945,11 @@ vector ReplyBuilder::body() const return res; } +size_t ReplyBuilder::size() const +{ + return mheader.size() * Header::itemSize + bodySize; +} + optional WaitingRef::check() { diff --git a/src/network.h b/src/network.h index 8ea8b6c..ed02167 100644 --- a/src/network.h +++ b/src/network.h @@ -143,13 +143,19 @@ public: void header( Header::Item && ); void body( const Ref & ); + void stream( shared_ptr< NetworkProtocol::OutStream >); const vector< Header::Item > & header() const { return mheader; } vector< Object > body() const; + shared_ptr< NetworkProtocol::OutStream > stream() const { return mstream; } + + size_t size() const; private: vector< Header::Item > mheader; vector< Ref > mbody; + size_t bodySize = 0; + shared_ptr< NetworkProtocol::OutStream > mstream; }; struct WaitingRef diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp index dbf1c40..89d6a88 100644 --- a/src/network/protocol.cpp +++ b/src/network/protocol.cpp @@ -16,14 +16,18 @@ using std::nullopt; using std::runtime_error; using std::scoped_lock; using std::to_string; +using std::unique_lock; using std::visit; namespace erebos { +static constexpr uint8_t maxStreamNumber = 0x3F; + struct NetworkProtocol::ConnectionPriv { Connection::Id id() const; + size_t mtu() const; bool send(const PartialStorage &, Header, const vector &, bool secure); bool send( const StreamData & chunk ); @@ -39,7 +43,7 @@ struct NetworkProtocol::ConnectionPriv ChannelState channel = monostate(); vector> secureOutQueue {}; - size_t mtu = 500; // TODO: MTU + size_t mtuLower = 1000; // TODO: MTU vector toAcknowledge {}; @@ -94,11 +98,20 @@ NetworkProtocol::PollResult NetworkProtocol::poll() sendAck = not c->toAcknowledge.empty() && holds_alternative< unique_ptr< Channel >>( c->channel ); - for (const auto & s : c->outStreams) { - scoped_lock slock(s->streamMutex); + for (auto & s : c->outStreams) { + unique_lock slock(s->streamMutex); while (s->hasDataLocked()) - streamChunks.push_back( s->getNextChunkLocked( c->mtu ) ); + streamChunks.push_back( s->getNextChunkLocked( c->mtu() )); + if( s->closed ){ + // TODO: wait after ack + streamChunks.push_back( { s->id, (uint8_t) s->nextSequence, {} } ); + slock.unlock(); + s.reset(); + } } + + while( not c->outStreams.empty() && not c->outStreams.back() ) + c->outStreams.pop_back(); } if (sendAck) { auto pst = self->ref()->storage().deriveEphemeralStorage(); @@ -293,6 +306,8 @@ bool NetworkProtocol::verifyCookie(variant vaddr, con /* Connection */ /******************************************************************************/ +using Connection = NetworkProtocol::Connection; + NetworkProtocol::Connection::Id NetworkProtocol::ConnectionPriv::id() const { return reinterpret_cast(this); @@ -330,6 +345,24 @@ const sockaddr_in6 & NetworkProtocol::Connection::peerAddress() const return p->peerAddress; } +size_t Connection::mtu() const +{ + return p->mtu(); +} + +size_t NetworkProtocol::ConnectionPriv::mtu() const +{ + if( get_if< unique_ptr< Channel >>( &channel )) + return mtuLower // space for: + - 1 // "encrypted" tag + - 1 // counter + - 1 // channel number + - 1 // channel sequence + - 16 // tag + ; + return mtuLower - 128; // some space for cookie headers +} + optional NetworkProtocol::Connection::receive(const PartialStorage & partStorage) { vector buf; @@ -440,7 +473,7 @@ NetworkProtocol::Connection::parsePacket(vector & buf, plainBegin = decrypted.begin() + 1; plainEnd = decrypted.end(); } - else if (decrypted[0] < 0x40) { + else if (decrypted[0] <= maxStreamNumber) { StreamData sdata; sdata.id = decrypted[0]; sdata.sequence = decrypted[1]; @@ -597,12 +630,17 @@ shared_ptr< NetworkProtocol::InStream > NetworkProtocol::Connection::openInStrea return p->inStreams.back(); } -shared_ptr< NetworkProtocol::OutStream > NetworkProtocol::Connection::openOutStream( uint8_t sid ) +shared_ptr< NetworkProtocol::OutStream > NetworkProtocol::Connection::openOutStream() { scoped_lock lock( p->cmutex ); - for (const auto & s : p->outStreams) - if (s->id == sid) - throw runtime_error("outbound stream " + to_string(sid) + " already open"); + + uint8_t sid = 1; + if( not p->outStreams.empty() ){ + if( p->outStreams.back()->id < maxStreamNumber ) + sid = p->outStreams.back()->id + 1; + else + throw runtime_error("no free outbound stream"); + } p->outStreams.emplace_back( new OutStream( sid )); return p->outStreams.back(); @@ -636,9 +674,21 @@ void NetworkProtocol::Connection::trySendOutQueue() } +NetworkProtocol::Stream::Stream(uint8_t id_): + id(id_) +{ + readPtr = readBuffer.begin(); +} + +void NetworkProtocol::Stream::close() +{ + scoped_lock lock( streamMutex ); + closed = true; +} + bool NetworkProtocol::Stream::hasDataLocked() const { - return not writeBuffer.empty() || not readBuffer.empty(); + return not writeBuffer.empty() || readPtr < readBuffer.end(); } size_t NetworkProtocol::Stream::writeLocked( const uint8_t * buf, size_t size ) @@ -725,6 +775,12 @@ bool NetworkProtocol::InStream::tryUseChunkLocked( const StreamData & chunk ) return true; } +size_t NetworkProtocol::OutStream::write( const uint8_t * buf, size_t size ) +{ + scoped_lock lock( streamMutex ); + return writeLocked( buf, size ); +} + NetworkProtocol::StreamData NetworkProtocol::OutStream::getNextChunkLocked( size_t size ) { StreamData res; diff --git a/src/network/protocol.h b/src/network/protocol.h index 2db4e63..d32b20b 100644 --- a/src/network/protocol.h +++ b/src/network/protocol.h @@ -106,6 +106,7 @@ public: Id id() const; const sockaddr_in6 & peerAddress() const; + size_t mtu() const; optional
receive(const PartialStorage &); bool send(const PartialStorage &, NetworkProtocol::Header, @@ -115,7 +116,7 @@ public: void close(); shared_ptr< InStream > openInStream( uint8_t sid ); - shared_ptr< OutStream > openOutStream( uint8_t sid ); + shared_ptr< OutStream > openOutStream(); // temporary: ChannelState & channel(); @@ -136,14 +137,21 @@ class NetworkProtocol::Stream friend class NetworkProtocol::Connection; protected: - Stream(uint8_t id_): id( id_ ) {} + Stream(uint8_t id_); +public: + void close(); + +protected: bool hasDataLocked() const; size_t writeLocked( const uint8_t * buf, size_t size ); size_t readLocked( uint8_t * buf, size_t size ); - uint8_t id; +public: + const uint8_t id; + +protected: bool closed { false }; vector< uint8_t > writeBuffer; vector< uint8_t > readBuffer; @@ -181,6 +189,9 @@ class NetworkProtocol::OutStream : public NetworkProtocol::Stream protected: OutStream(uint8_t id): Stream( id ) {} +public: + size_t write( const uint8_t * buf, size_t size ); + private: StreamData getNextChunkLocked( size_t size ); @@ -226,6 +237,8 @@ struct NetworkProtocol::Header ServiceRef, StreamOpen>; + static constexpr size_t itemSize = 78; /* estimate for size of ref-containing headers */ + Header(const vector & items): items(items) {} static optional
load(const PartialRef &); static optional
load(const PartialObject &); -- cgit v1.2.3