From 81895699131121a1dab67ce026dcf8490c4de9e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roman=20Smr=C5=BE?= Date: Sat, 31 Aug 2024 22:17:16 +0200 Subject: Network streams, accept for data response Changelog: Implemented streams in network protocol --- src/network.cpp | 83 ++++++++++++---- src/network.h | 2 + src/network/protocol.cpp | 251 +++++++++++++++++++++++++++++++++++++++++++---- src/network/protocol.h | 83 +++++++++++++++- src/storage.cpp | 15 +++ 5 files changed, 394 insertions(+), 40 deletions(-) diff --git a/src/network.cpp b/src/network.cpp index 5807381..409b829 100644 --- a/src/network.cpp +++ b/src/network.cpp @@ -430,16 +430,20 @@ void Server::Priv::doListen() vector> notifyPeers; vector, Service &, Ref>> readyServices; - if (auto header = peer->connection.receive(peer->partStorage)) { + { ReplyBuilder reply; scoped_lock hlock(dataMutex); shared_lock slock(selfMutex); - handlePacket(*peer, *header, reply); - peer->updateIdentity(reply, notifyPeers); - peer->updateChannel(reply); - peer->updateService(reply, readyServices); + if( auto header = peer->connection.receive( peer->partStorage )) { + handlePacket( *peer, *header, reply ); + peer->updateIdentity( reply, notifyPeers ); + peer->updateChannel( reply ); + } else { + peer->checkDataResponseStreams( reply ); + } + peer->updateService( reply, readyServices ); if (!reply.header().empty()) peer->connection.send(peer->partStorage, @@ -565,9 +569,12 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head plaintextRefs.insert(obj.ref().digest()); optional serviceType; + shared_ptr< NetworkProtocol::InStream > newDataResponseStream; + + using Header = NetworkProtocol::Header; for (const auto & item : header.items) { - if (const auto * ack = get_if(&item)) { + if (const auto * ack = get_if< Header::Acknowledged >( &item )) { const auto & dgst = ack->value; if (holds_alternative>(peer.connection.channel()) && std::get>(peer.connection.channel()).ref().digest() == dgst) @@ -575,7 +582,7 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head std::get>(peer.connection.channel())->data->channel()); } - else if (const auto * req = get_if(&item)) { + else if (const auto * req = get_if< Header::DataRequest >( &item )) { const auto & dgst = req->value; if (holds_alternative>(peer.connection.channel()) || plaintextRefs.find(dgst) != plaintextRefs.end()) { @@ -586,21 +593,31 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head } } - else if (const auto * rsp = get_if(&item)) { + else if (const auto * rsp = get_if< Header::DataResponse >( &item )) { const auto & dgst = rsp->value; - if (not holds_alternative>(peer.connection.channel())) - reply.header({ NetworkProtocol::Header::Acknowledged { dgst } }); - for (auto & pwref : waiting) { - if (auto wref = pwref.lock()) { - if (std::find(wref->missing.begin(), wref->missing.end(), dgst) != - wref->missing.end()) { - if (wref->check(reply)) - pwref.reset(); + if (not holds_alternative< unique_ptr< Channel >>( peer.connection.channel() )) + reply.header({ Header::Acknowledged { dgst } }); + + if (peer.partStorage.loadObject( dgst )) { + for (auto & pwref : waiting) { + if (auto wref = pwref.lock()) { + if (std::find(wref->missing.begin(), wref->missing.end(), dgst) != + wref->missing.end()) { + if (wref->check(reply)) + pwref.reset(); + } + } + } + waiting.erase(std::remove_if(waiting.begin(), waiting.end(), + [](auto & wref) { return wref.expired(); }), waiting.end()); + } else if (not newDataResponseStream) { + for (const auto & item : header.items) { + if (const auto * streamOpen = get_if< Header::StreamOpen >( &item )) { + newDataResponseStream = peer.connection.openInStream( streamOpen->value ); + break; } } } - waiting.erase(std::remove_if(waiting.begin(), waiting.end(), - [](auto & wref) { return wref.expired(); }), waiting.end()); } else if (const auto * ann = get_if(&item)) { @@ -708,6 +725,9 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head } } } + + if( newDataResponseStream ) + peer.dataResponseStreams.push_back( move( newDataResponseStream )); } void Server::Priv::handleLocalHeadChange(const Head & head) @@ -830,6 +850,33 @@ void Server::Peer::updateService(ReplyBuilder & reply, vectorisComplete() ) { + auto objects = PartialObject::decodeMany( partStorage, s->readAll() ); + vector< PartialRef > refs; + refs.reserve( objects.size() ); + for( const auto & obj : objects ) + refs.push_back( partStorage.storeObject( obj )); + + for( auto & pwref : server.waiting ) { + if (auto wref = pwref.lock()) { + for( const auto & ref : refs ) { + if( std::find( wref->missing.begin(), wref->missing.end(), ref.digest() ) != + wref->missing.end() ) { + if( wref->check( reply ) ) + pwref.reset(); + } + } + } + } + server.waiting.erase( std::remove_if( server.waiting.begin(), server.waiting.end(), + [](auto & wref) { return wref.expired(); }), server.waiting.end() ); + } + } +} + void ReplyBuilder::header(NetworkProtocol::Header::Item && item) { diff --git a/src/network.h b/src/network.h index 4c23b47..12013fe 100644 --- a/src/network.h +++ b/src/network.h @@ -54,6 +54,7 @@ struct Server::Peer PartialStorage partStorage; vector>> serviceQueue {}; + vector< shared_ptr< NetworkProtocol::InStream >> dataResponseStreams {}; shared_ptr lpeer = nullptr; @@ -61,6 +62,7 @@ struct Server::Peer void updateChannel(ReplyBuilder &); void finalizeChannel(ReplyBuilder &, unique_ptr); void updateService(ReplyBuilder &, vector, Service &, Ref>> & readyServices); + void checkDataResponseStreams( ReplyBuilder & ); }; struct Peer::Priv : enable_shared_from_this diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp index b781693..9848d29 100644 --- a/src/network/protocol.cpp +++ b/src/network/protocol.cpp @@ -15,6 +15,7 @@ using std::move; using std::nullopt; using std::runtime_error; using std::scoped_lock; +using std::to_string; using std::visit; namespace erebos { @@ -25,6 +26,7 @@ struct NetworkProtocol::ConnectionPriv bool send(const PartialStorage &, Header, const vector &, bool secure); + bool send( const StreamData & chunk ); NetworkProtocol * protocol; const sockaddr_in6 peerAddress; @@ -37,7 +39,12 @@ struct NetworkProtocol::ConnectionPriv ChannelState channel = monostate(); vector> secureOutQueue {}; + size_t mtu = 500; // TODO: MTU + vector toAcknowledge {}; + + vector< shared_ptr< InStream >> inStreams {}; + vector< shared_ptr< OutStream >> outStreams {}; }; @@ -80,16 +87,26 @@ NetworkProtocol::PollResult NetworkProtocol::poll() scoped_lock lock(protocolMutex); for (const auto & c : connections) { + vector< StreamData > streamChunks; + bool sendAck = false; { scoped_lock clock(c->cmutex); - if (c->toAcknowledge.empty()) - continue; - - if (not holds_alternative>(c->channel)) - continue; + sendAck = not c->toAcknowledge.empty() && + holds_alternative< unique_ptr< Channel >>( c->channel ); + + for (const auto & s : c->outStreams) { + scoped_lock slock(s->streamMutex); + while (s->hasDataLocked()) + streamChunks.push_back( s->getNextChunkLocked( c->mtu ) ); + } + } + if (sendAck) { + auto pst = self->ref()->storage().deriveEphemeralStorage(); + c->send(pst, Header {{}}, {}, true); + } + for (const auto & chunk : streamChunks) { + c->send( chunk ); } - auto pst = self->ref()->storage().deriveEphemeralStorage(); - c->send(pst, Header {{}}, {}, true); } } @@ -110,7 +127,8 @@ NetworkProtocol::PollResult NetworkProtocol::poll() auto pst = self->ref()->storage().deriveEphemeralStorage(); optional secure = false; - if (auto header = Connection::parsePacket(buffer, nullptr, pst, secure)) { + auto parsed = Connection::parsePacket(buffer, nullptr, pst, secure); + if (const auto * header = get_if< Header >( &parsed )) { if (auto conn = verifyNewConnection(*header, addr)) return NewConnection { move(*conn) }; @@ -335,13 +353,14 @@ optional NetworkProtocol::Connection::receive(const Par } optional secure = false; - if (auto header = parsePacket(buf, channel, partStorage, secure)) { + auto parsed = parsePacket(buf, channel, partStorage, secure); + if (const auto * header = get_if< Header >( &parsed )) { scoped_lock lock(p->cmutex); if (secure) { if (header->isAcknowledged()) p->toAcknowledge.push_back(*secure); - return header; + return *header; } if (const auto * cookieEcho = header->lookupFirst()) { @@ -353,13 +372,13 @@ optional NetworkProtocol::Connection::receive(const Par if (const auto * cookieSet = header->lookupFirst()) p->receivedCookie = cookieSet->value; - return header; + return *header; } if (holds_alternative(p->channel)) { if (const auto * cookieSet = header->lookupFirst()) { p->receivedCookie = cookieSet->value; - return header; + return *header; } } @@ -368,10 +387,36 @@ optional NetworkProtocol::Connection::receive(const Par return nullopt; } } + else if (const auto * sdata = get_if< StreamData >( &parsed )) { + scoped_lock lock(p->cmutex); + if (secure) + p->toAcknowledge.push_back(*secure); + + InStream * stream = nullptr; + for (const auto & s : p->inStreams) { + if (s->id == sdata->id) { + stream = s.get(); + break; + } + } + if (not stream) { + std::cerr << "unexpected stream number\n"; + return nullopt; + } + + stream->writeChunk( move(*sdata) ); + if( stream->closed ) + p->inStreams.erase( + std::remove_if( p->inStreams.begin(), p->inStreams.end(), + [&]( auto & sptr ) { return sptr.get() == stream; } ), + p->inStreams.end() ); + return nullopt; + } return nullopt; } -optional NetworkProtocol::Connection::parsePacket(vector & buf, +variant< monostate, NetworkProtocol::Header, NetworkProtocol::StreamData > +NetworkProtocol::Connection::parsePacket(vector & buf, Channel * channel, const PartialStorage & partStorage, optional & secure) { @@ -384,7 +429,7 @@ optional NetworkProtocol::Connection::parsePacket(vecto if ((buf[0] & 0xE0) == 0x80) { if (not channel) { std::cerr << "unexpected encrypted packet\n"; - return nullopt; + return monostate(); } if ((secure = channel->decrypt(buf.begin() + 1, buf.end(), decrypted, 0))) { @@ -395,9 +440,17 @@ optional NetworkProtocol::Connection::parsePacket(vecto plainBegin = decrypted.begin() + 1; plainEnd = decrypted.end(); } + else if (decrypted[0] < 0x40) { + StreamData sdata; + sdata.id = decrypted[0]; + sdata.sequence = decrypted[1]; + sdata.data.resize( decrypted.size() - 2 ); + std::copy(decrypted.begin() + 2, decrypted.end(), sdata.data.begin()); + return sdata; + } else { - std::cerr << "streams not implemented\n"; - return nullopt; + std::cerr << "unexpected stream header\n"; + return monostate(); } } } @@ -414,12 +467,12 @@ optional NetworkProtocol::Connection::parsePacket(vecto pos = std::get<1>(*cdec); } - return header; + return *header; } } std::cerr << "invalid packet\n"; - return nullopt; + return monostate(); } bool NetworkProtocol::Connection::send(const PartialStorage & partStorage, @@ -483,6 +536,37 @@ bool NetworkProtocol::ConnectionPriv::send(const PartialStorage & partStorage, return true; } +bool NetworkProtocol::Connection::send( const StreamData & chunk ) +{ + return p->send( chunk ); +} + +bool NetworkProtocol::ConnectionPriv::send( const StreamData & chunk ) +{ + vector data, out; + + { + scoped_lock clock( cmutex ); + + Channel * channel = nullptr; + if (auto uptr = get_if< unique_ptr< Channel >>( &this->channel )) + channel = uptr->get(); + if (not channel) + return false; + + data.push_back( chunk.id ); + data.push_back( static_cast< uint8_t >( chunk.sequence )); + data.insert( data.end(), chunk.data.begin(), chunk.data.end() ); + + out.push_back( 0x80 ); + channel->encrypt( data.begin(), data.end(), out, 1 ); + } + + protocol->sendto( out, peerAddress ); + return true; +} + + void NetworkProtocol::Connection::close() { if (not p) @@ -502,6 +586,28 @@ void NetworkProtocol::Connection::close() p = nullptr; } +shared_ptr< NetworkProtocol::InStream > NetworkProtocol::Connection::openInStream( uint8_t sid ) +{ + scoped_lock lock( p->cmutex ); + for (const auto & s : p->inStreams) + if (s->id == sid) + throw runtime_error("inbound stream " + to_string(sid) + " already open"); + + p->inStreams.emplace_back( new InStream( sid )); + return p->inStreams.back(); +} + +shared_ptr< NetworkProtocol::OutStream > NetworkProtocol::Connection::openOutStream( uint8_t sid ) +{ + scoped_lock lock( p->cmutex ); + for (const auto & s : p->outStreams) + if (s->id == sid) + throw runtime_error("outbound stream " + to_string(sid) + " already open"); + + p->outStreams.emplace_back( new OutStream( sid )); + return p->outStreams.back(); +} + NetworkProtocol::ChannelState & NetworkProtocol::Connection::channel() { return p->channel; @@ -530,6 +636,109 @@ void NetworkProtocol::Connection::trySendOutQueue() } +bool NetworkProtocol::Stream::hasDataLocked() const +{ + return not writeBuffer.empty() || not readBuffer.empty(); +} + +size_t NetworkProtocol::Stream::writeLocked( const uint8_t * buf, size_t size ) +{ + writeBuffer.insert( writeBuffer.end(), buf, buf + size ); + return size; +} + +size_t NetworkProtocol::Stream::readLocked( uint8_t * buf, size_t size ) +{ + size_t res = 0; + if (readPtr < readBuffer.end()) { + res = std::min( size, static_cast< size_t >( readBuffer.end() - readPtr )); + std::copy_n( readPtr, res, buf ); + readPtr += res; + } + if (res < size && not writeBuffer.empty()) { + std::swap( readBuffer, writeBuffer ); + readPtr = readBuffer.begin(); + writeBuffer.clear(); + return res + readLocked( buf + res, size - res ); + } + return res; +} + +bool NetworkProtocol::InStream::isComplete() const +{ + scoped_lock lock( streamMutex ); + return closed && outOfOrderChunks.empty(); +} + +vector< uint8_t > NetworkProtocol::InStream::readAll() +{ + scoped_lock lock( streamMutex ); + if (readBuffer.empty()) { + vector< uint8_t > res; + std::swap( res, writeBuffer ); + return res; + } + + readBuffer.insert( readBuffer.end(), writeBuffer.begin(), writeBuffer.end() ); + writeBuffer.clear(); + + vector< uint8_t > res; + std::swap( res, readBuffer ); + readPtr = readBuffer.begin(); + return res; +} + +size_t NetworkProtocol::InStream::read( uint8_t * buf, size_t size ) +{ + scoped_lock lock( streamMutex ); + return readLocked( buf, size ); +} + +void NetworkProtocol::InStream::writeChunk( StreamData chunk ) +{ + scoped_lock lock( streamMutex ); + if( tryUseChunkLocked( chunk )) { + auto it = outOfOrderChunks.begin(); + while( it != outOfOrderChunks.end() && tryUseChunkLocked( *it )) + it++; + outOfOrderChunks.erase( outOfOrderChunks.begin(), it ); + } else { + auto it = outOfOrderChunks.begin(); + while( it < outOfOrderChunks.end() && + it->sequence - static_cast< uint8_t >( nextSequence ) + < chunk.sequence - static_cast< uint8_t >( nextSequence )) + it++; + outOfOrderChunks.insert( it, move(chunk) ); + } +} + +bool NetworkProtocol::InStream::tryUseChunkLocked( const StreamData & chunk ) +{ + if( chunk.sequence != static_cast< uint8_t >( nextSequence )) + return false; + + if( chunk.data.empty() ) + closed = true; + else + writeLocked( chunk.data.data(), chunk.data.size() ); + nextSequence++; + return true; +} + +NetworkProtocol::StreamData NetworkProtocol::OutStream::getNextChunkLocked( size_t size ) +{ + StreamData res; + res.id = id; + res.sequence = nextSequence++, + + res.data.resize( size ); + size = readLocked( res.data.data(), size ); + res.data.resize( size ); + + return res; +} + + /******************************************************************************/ /* Header */ /******************************************************************************/ @@ -600,6 +809,9 @@ optional NetworkProtocol::Header::load(const PartialObj } else if (item.name == "SVR") { if (auto ref = item.asRef()) items.emplace_back(ServiceRef { ref->digest() }); + } else if (item.name == "STO") { + if (auto num = item.asInteger()) + items.emplace_back( StreamOpen{ static_cast< uint8_t >( *num )}); } } @@ -652,6 +864,9 @@ PartialObject NetworkProtocol::Header::toObject(const PartialStorage & st) const else if (const auto * ptr = get_if(&item)) ritems.emplace_back("SVR", st.ref(ptr->value)); + + else if (const auto * ptr = get_if< StreamOpen >( &item )) + ritems.emplace_back("STO", Record::Item::Integer( ptr->value )); } return PartialObject(PartialRecord(std::move(ritems))); diff --git a/src/network/protocol.h b/src/network/protocol.h index ba40744..2592c9f 100644 --- a/src/network/protocol.h +++ b/src/network/protocol.h @@ -35,8 +35,12 @@ public: static constexpr char defaultVersion[] = "0.1"; class Connection; + class Stream; + class InStream; + class OutStream; struct Header; + struct StreamData; struct ReceivedAnnounce; struct NewConnection; @@ -106,21 +110,83 @@ public: optional
receive(const PartialStorage &); bool send(const PartialStorage &, NetworkProtocol::Header, const vector &, bool secure); + bool send( const StreamData & chunk ); void close(); + shared_ptr< InStream > openInStream( uint8_t sid ); + shared_ptr< OutStream > openOutStream( uint8_t sid ); + // temporary: ChannelState & channel(); void trySendOutQueue(); private: - static optional
parsePacket(vector & buf, - Channel * channel, const PartialStorage & st, - optional & secure); + static variant< monostate, Header, StreamData > + parsePacket(vector & buf, + Channel * channel, const PartialStorage & st, + optional & secure); unique_ptr p; }; +class NetworkProtocol::Stream +{ + friend class NetworkProtocol; + friend class NetworkProtocol::Connection; + +protected: + Stream(uint8_t id_): id( id_ ) {} + + bool hasDataLocked() const; + + size_t writeLocked( const uint8_t * buf, size_t size ); + size_t readLocked( uint8_t * buf, size_t size ); + + uint8_t id; + bool closed { false }; + vector< uint8_t > writeBuffer; + vector< uint8_t > readBuffer; + vector< uint8_t >::const_iterator readPtr; + mutable mutex streamMutex; +}; + +class NetworkProtocol::InStream : public NetworkProtocol::Stream +{ + friend class NetworkProtocol; + friend class NetworkProtocol::Connection; + +protected: + InStream(uint8_t id): Stream( id ) {} + +public: + bool isComplete() const; + vector< uint8_t > readAll(); + size_t read( uint8_t * buf, size_t size ); + +protected: + void writeChunk( StreamData chunk ); + bool tryUseChunkLocked( const StreamData & chunk ); + +private: + uint64_t nextSequence { 0 }; + vector< StreamData > outOfOrderChunks; +}; + +class NetworkProtocol::OutStream : public NetworkProtocol::Stream +{ + friend class NetworkProtocol; + friend class NetworkProtocol::Connection; + +protected: + OutStream(uint8_t id): Stream( id ) {} + +private: + StreamData getNextChunkLocked( size_t size ); + + uint64_t nextSequence { 0 }; +}; + struct NetworkProtocol::ReceivedAnnounce { sockaddr_in6 addr; Digest digest; }; struct NetworkProtocol::NewConnection { Connection conn; }; struct NetworkProtocol::ConnectionReadReady { Connection::Id id; }; @@ -141,6 +207,7 @@ struct NetworkProtocol::Header struct ChannelAccept { Digest value; }; struct ServiceType { UUID value; }; struct ServiceRef { Digest value; }; + struct StreamOpen { uint8_t value; }; using Item = variant< Acknowledged, @@ -156,7 +223,8 @@ struct NetworkProtocol::Header ChannelRequest, ChannelAccept, ServiceType, - ServiceRef>; + ServiceRef, + StreamOpen>; Header(const vector & items): items(items) {} static optional
load(const PartialRef &); @@ -169,6 +237,13 @@ struct NetworkProtocol::Header vector items; }; +struct NetworkProtocol::StreamData +{ + uint8_t id; + uint8_t sequence; + vector< uint8_t > data; +}; + template const T * NetworkProtocol::Header::lookupFirst() const { diff --git a/src/storage.cpp b/src/storage.cpp index 19f35a9..fd985c7 100644 --- a/src/storage.cpp +++ b/src/storage.cpp @@ -1579,6 +1579,21 @@ optional> ObjectT::decode(const S & st, return nullopt; } +template< class S > +vector< ObjectT< S >> ObjectT< S >::decodeMany( const S & st, + const std::vector< uint8_t > & data) +{ + vector< ObjectT< S >> objects; + auto cur = data.begin(); + + while( auto pair = decodePrefix( st, cur, data.end() )) { + auto [ obj, next ] = *pair; + objects.push_back( move( obj )); + cur = next; + } + return objects; +} + template vector ObjectT::encode() const { -- cgit v1.2.3