diff options
Diffstat (limited to 'src/network.cpp')
-rw-r--r-- | src/network.cpp | 122 |
1 files changed, 94 insertions, 28 deletions
diff --git a/src/network.cpp b/src/network.cpp index 5807381..26a07e3 100644 --- a/src/network.cpp +++ b/src/network.cpp @@ -430,20 +430,32 @@ void Server::Priv::doListen() vector<shared_ptr<erebos::Peer::Priv>> notifyPeers; vector<tuple<shared_ptr<erebos::Peer::Priv>, Service &, Ref>> readyServices; - if (auto header = peer->connection.receive(peer->partStorage)) { + { ReplyBuilder reply; scoped_lock hlock(dataMutex); shared_lock slock(selfMutex); - handlePacket(*peer, *header, reply); - peer->updateIdentity(reply, notifyPeers); - peer->updateChannel(reply); - peer->updateService(reply, readyServices); + if( auto header = peer->connection.receive( peer->partStorage )) { + handlePacket( *peer, *header, reply ); + peer->updateIdentity( reply, notifyPeers ); + peer->updateChannel( reply ); + } else { + peer->checkDataResponseStreams( reply ); + } + peer->updateService( reply, readyServices ); + + if( not reply.header().empty() ) { + for( const auto & item : reply.header() ) { + if( const auto * req = get_if< NetworkProtocol::Header::DataRequest >( &item )) { + const auto & dgst = req->value; + peer->requestedData.push_back( dgst ); + } + } - if (!reply.header().empty()) peer->connection.send(peer->partStorage, NetworkProtocol::Header(reply.header()), reply.body(), false); + } peer->connection.trySendOutQueue(); } @@ -565,9 +577,12 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head plaintextRefs.insert(obj.ref().digest()); optional<UUID> serviceType; + shared_ptr< NetworkProtocol::InStream > newDataResponseStream; + + using Header = NetworkProtocol::Header; for (const auto & item : header.items) { - if (const auto * ack = get_if<NetworkProtocol::Header::Acknowledged>(&item)) { + if (const auto * ack = get_if< Header::Acknowledged >( &item )) { const auto & dgst = ack->value; if (holds_alternative<Stored<ChannelAccept>>(peer.connection.channel()) && std::get<Stored<ChannelAccept>>(peer.connection.channel()).ref().digest() == dgst) @@ -575,7 +590,7 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head std::get<Stored<ChannelAccept>>(peer.connection.channel())->data->channel()); } - else if (const auto * req = get_if<NetworkProtocol::Header::DataRequest>(&item)) { + else if (const auto * req = get_if< Header::DataRequest >( &item )) { const auto & dgst = req->value; if (holds_alternative<unique_ptr<Channel>>(peer.connection.channel()) || plaintextRefs.find(dgst) != plaintextRefs.end()) { @@ -586,21 +601,34 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head } } - else if (const auto * rsp = get_if<NetworkProtocol::Header::DataResponse>(&item)) { + else if (const auto * rsp = get_if< Header::DataResponse >( &item )) { const auto & dgst = rsp->value; - if (not holds_alternative<unique_ptr<Channel>>(peer.connection.channel())) - reply.header({ NetworkProtocol::Header::Acknowledged { dgst } }); - for (auto & pwref : waiting) { - if (auto wref = pwref.lock()) { - if (std::find(wref->missing.begin(), wref->missing.end(), dgst) != - wref->missing.end()) { - if (wref->check(reply)) - pwref.reset(); + if (not holds_alternative< unique_ptr< Channel >>( peer.connection.channel() )) + reply.header({ Header::Acknowledged { dgst } }); + + if (peer.partStorage.loadObject( dgst )) { + peer.requestedData.erase( + std::remove( peer.requestedData.begin(), peer.requestedData.end(), dgst ), + peer.requestedData.end() ); + for (auto & pwref : waiting) { + if (auto wref = pwref.lock()) { + if (std::find(wref->missing.begin(), wref->missing.end(), dgst) != + wref->missing.end()) { + if( wref->check( reply, peer.requestedData )) + pwref.reset(); + } + } + } + waiting.erase(std::remove_if(waiting.begin(), waiting.end(), + [](auto & wref) { return wref.expired(); }), waiting.end()); + } else if (not newDataResponseStream) { + for (const auto & item : header.items) { + if (const auto * streamOpen = get_if< Header::StreamOpen >( &item )) { + newDataResponseStream = peer.connection.openInStream( streamOpen->value ); + break; } } } - waiting.erase(std::remove_if(waiting.begin(), waiting.end(), - [](auto & wref) { return wref.expired(); }), waiting.end()); } else if (const auto * ann = get_if<NetworkProtocol::Header::AnnounceSelf>(&item)) { @@ -616,7 +644,7 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head }); waiting.push_back(wref); peer.identity = wref; - wref->check(reply); + wref->check( reply, peer.requestedData ); } } @@ -631,7 +659,7 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head }); waiting.push_back(wref); peer.identityUpdates.push_back(wref); - wref->check(reply); + wref->check( reply, peer.requestedData ); } } @@ -656,7 +684,7 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head }); waiting.push_back(wref); peer.connection.channel() = wref; - wref->check(reply); + wref->check( reply, peer.requestedData ); } } @@ -704,10 +732,13 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head }); waiting.push_back(wref); peer.serviceQueue.emplace_back(*serviceType, wref); - wref->check(reply); + wref->check( reply, peer.requestedData ); } } } + + if( newDataResponseStream ) + peer.dataResponseStreams.push_back( move( newDataResponseStream )); } void Server::Priv::handleLocalHeadChange(const Head<LocalState> & head) @@ -777,7 +808,7 @@ void Server::Peer::updateChannel(ReplyBuilder & reply) } if (holds_alternative<shared_ptr<WaitingRef>>(connection.channel())) { - if (auto ref = std::get<shared_ptr<WaitingRef>>(connection.channel())->check(reply)) { + if( auto ref = std::get< shared_ptr< WaitingRef >>( connection.channel())->check( reply, requestedData )) { auto req = Stored<ChannelRequest>::load(*ref); if (holds_alternative<Identity>(identity) && req->isSignedBy(std::get<Identity>(identity).keyMessage())) { @@ -814,7 +845,7 @@ void Server::Peer::updateService(ReplyBuilder & reply, vector<tuple<shared_ptr<e { decltype(serviceQueue) next; for (auto & x : serviceQueue) { - if (auto ref = std::get<1>(x)->check(reply)) { + if( auto ref = std::get<1>(x)->check( reply, requestedData )) { if (lpeer) { for (auto & svc : server.services) { if (svc->uuid() == std::get<UUID>(x)) { @@ -830,6 +861,38 @@ void Server::Peer::updateService(ReplyBuilder & reply, vector<tuple<shared_ptr<e serviceQueue = std::move(next); } +void Server::Peer::checkDataResponseStreams( ReplyBuilder & reply ) +{ + for( auto & s : dataResponseStreams ) { + if( s->isComplete() ) { + auto objects = PartialObject::decodeMany( partStorage, s->readAll() ); + vector< PartialRef > refs; + refs.reserve( objects.size() ); + for( const auto & obj : objects ) { + auto ref = partStorage.storeObject( obj ); + refs.push_back( ref ); + requestedData.erase( + std::remove( requestedData.begin(), requestedData.end(), ref.digest() ), + requestedData.end() ); + } + + for( auto & pwref : server.waiting ) { + if (auto wref = pwref.lock()) { + for( const auto & ref : refs ) { + if( std::find( wref->missing.begin(), wref->missing.end(), ref.digest() ) != + wref->missing.end() ) { + if( wref->check( reply, requestedData ) ) + pwref.reset(); + } + } + } + } + server.waiting.erase( std::remove_if( server.waiting.begin(), server.waiting.end(), + [](auto & wref) { return wref.expired(); }), server.waiting.end() ); + } + } +} + void ReplyBuilder::header(NetworkProtocol::Header::Item && item) { @@ -870,13 +933,16 @@ optional<Ref> WaitingRef::check() return nullopt; } -optional<Ref> WaitingRef::check(ReplyBuilder & reply) +optional<Ref> WaitingRef::check( ReplyBuilder & reply, const vector< Digest > & alreadyRequested) { if (auto r = check()) return r; - for (const auto & d : missing) - reply.header({ NetworkProtocol::Header::DataRequest { d } }); + for( const auto & d : missing ) { + if( std::find( alreadyRequested.begin(), alreadyRequested.end(), d ) == + alreadyRequested.end() ) + reply.header({ NetworkProtocol::Header::DataRequest { d } }); + } return nullopt; } |