diff options
Diffstat (limited to 'src/network')
-rw-r--r-- | src/network/protocol.cpp | 76 | ||||
-rw-r--r-- | src/network/protocol.h | 19 |
2 files changed, 82 insertions, 13 deletions
diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp index dbf1c40..89d6a88 100644 --- a/src/network/protocol.cpp +++ b/src/network/protocol.cpp @@ -16,14 +16,18 @@ using std::nullopt; using std::runtime_error; using std::scoped_lock; using std::to_string; +using std::unique_lock; using std::visit; namespace erebos { +static constexpr uint8_t maxStreamNumber = 0x3F; + struct NetworkProtocol::ConnectionPriv { Connection::Id id() const; + size_t mtu() const; bool send(const PartialStorage &, Header, const vector<Object> &, bool secure); bool send( const StreamData & chunk ); @@ -39,7 +43,7 @@ struct NetworkProtocol::ConnectionPriv ChannelState channel = monostate(); vector<vector<uint8_t>> secureOutQueue {}; - size_t mtu = 500; // TODO: MTU + size_t mtuLower = 1000; // TODO: MTU vector<uint64_t> toAcknowledge {}; @@ -94,11 +98,20 @@ NetworkProtocol::PollResult NetworkProtocol::poll() sendAck = not c->toAcknowledge.empty() && holds_alternative< unique_ptr< Channel >>( c->channel ); - for (const auto & s : c->outStreams) { - scoped_lock slock(s->streamMutex); + for (auto & s : c->outStreams) { + unique_lock slock(s->streamMutex); while (s->hasDataLocked()) - streamChunks.push_back( s->getNextChunkLocked( c->mtu ) ); + streamChunks.push_back( s->getNextChunkLocked( c->mtu() )); + if( s->closed ){ + // TODO: wait after ack + streamChunks.push_back( { s->id, (uint8_t) s->nextSequence, {} } ); + slock.unlock(); + s.reset(); + } } + + while( not c->outStreams.empty() && not c->outStreams.back() ) + c->outStreams.pop_back(); } if (sendAck) { auto pst = self->ref()->storage().deriveEphemeralStorage(); @@ -293,6 +306,8 @@ bool NetworkProtocol::verifyCookie(variant<sockaddr_in, sockaddr_in6> vaddr, con /* Connection */ /******************************************************************************/ +using Connection = NetworkProtocol::Connection; + NetworkProtocol::Connection::Id NetworkProtocol::ConnectionPriv::id() const { return reinterpret_cast<uintptr_t>(this); @@ -330,6 +345,24 @@ const sockaddr_in6 & NetworkProtocol::Connection::peerAddress() const return p->peerAddress; } +size_t Connection::mtu() const +{ + return p->mtu(); +} + +size_t NetworkProtocol::ConnectionPriv::mtu() const +{ + if( get_if< unique_ptr< Channel >>( &channel )) + return mtuLower // space for: + - 1 // "encrypted" tag + - 1 // counter + - 1 // channel number + - 1 // channel sequence + - 16 // tag + ; + return mtuLower - 128; // some space for cookie headers +} + optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const PartialStorage & partStorage) { vector<uint8_t> buf; @@ -440,7 +473,7 @@ NetworkProtocol::Connection::parsePacket(vector<uint8_t> & buf, plainBegin = decrypted.begin() + 1; plainEnd = decrypted.end(); } - else if (decrypted[0] < 0x40) { + else if (decrypted[0] <= maxStreamNumber) { StreamData sdata; sdata.id = decrypted[0]; sdata.sequence = decrypted[1]; @@ -597,12 +630,17 @@ shared_ptr< NetworkProtocol::InStream > NetworkProtocol::Connection::openInStrea return p->inStreams.back(); } -shared_ptr< NetworkProtocol::OutStream > NetworkProtocol::Connection::openOutStream( uint8_t sid ) +shared_ptr< NetworkProtocol::OutStream > NetworkProtocol::Connection::openOutStream() { scoped_lock lock( p->cmutex ); - for (const auto & s : p->outStreams) - if (s->id == sid) - throw runtime_error("outbound stream " + to_string(sid) + " already open"); + + uint8_t sid = 1; + if( not p->outStreams.empty() ){ + if( p->outStreams.back()->id < maxStreamNumber ) + sid = p->outStreams.back()->id + 1; + else + throw runtime_error("no free outbound stream"); + } p->outStreams.emplace_back( new OutStream( sid )); return p->outStreams.back(); @@ -636,9 +674,21 @@ void NetworkProtocol::Connection::trySendOutQueue() } +NetworkProtocol::Stream::Stream(uint8_t id_): + id(id_) +{ + readPtr = readBuffer.begin(); +} + +void NetworkProtocol::Stream::close() +{ + scoped_lock lock( streamMutex ); + closed = true; +} + bool NetworkProtocol::Stream::hasDataLocked() const { - return not writeBuffer.empty() || not readBuffer.empty(); + return not writeBuffer.empty() || readPtr < readBuffer.end(); } size_t NetworkProtocol::Stream::writeLocked( const uint8_t * buf, size_t size ) @@ -725,6 +775,12 @@ bool NetworkProtocol::InStream::tryUseChunkLocked( const StreamData & chunk ) return true; } +size_t NetworkProtocol::OutStream::write( const uint8_t * buf, size_t size ) +{ + scoped_lock lock( streamMutex ); + return writeLocked( buf, size ); +} + NetworkProtocol::StreamData NetworkProtocol::OutStream::getNextChunkLocked( size_t size ) { StreamData res; diff --git a/src/network/protocol.h b/src/network/protocol.h index 2db4e63..d32b20b 100644 --- a/src/network/protocol.h +++ b/src/network/protocol.h @@ -106,6 +106,7 @@ public: Id id() const; const sockaddr_in6 & peerAddress() const; + size_t mtu() const; optional<Header> receive(const PartialStorage &); bool send(const PartialStorage &, NetworkProtocol::Header, @@ -115,7 +116,7 @@ public: void close(); shared_ptr< InStream > openInStream( uint8_t sid ); - shared_ptr< OutStream > openOutStream( uint8_t sid ); + shared_ptr< OutStream > openOutStream(); // temporary: ChannelState & channel(); @@ -136,14 +137,21 @@ class NetworkProtocol::Stream friend class NetworkProtocol::Connection; protected: - Stream(uint8_t id_): id( id_ ) {} + Stream(uint8_t id_); +public: + void close(); + +protected: bool hasDataLocked() const; size_t writeLocked( const uint8_t * buf, size_t size ); size_t readLocked( uint8_t * buf, size_t size ); - uint8_t id; +public: + const uint8_t id; + +protected: bool closed { false }; vector< uint8_t > writeBuffer; vector< uint8_t > readBuffer; @@ -181,6 +189,9 @@ class NetworkProtocol::OutStream : public NetworkProtocol::Stream protected: OutStream(uint8_t id): Stream( id ) {} +public: + size_t write( const uint8_t * buf, size_t size ); + private: StreamData getNextChunkLocked( size_t size ); @@ -226,6 +237,8 @@ struct NetworkProtocol::Header ServiceRef, StreamOpen>; + static constexpr size_t itemSize = 78; /* estimate for size of ref-containing headers */ + Header(const vector<Item> & items): items(items) {} static optional<Header> load(const PartialRef &); static optional<Header> load(const PartialObject &); |