diff options
Diffstat (limited to 'src/network.cpp')
-rw-r--r-- | src/network.cpp | 83 |
1 files changed, 65 insertions, 18 deletions
diff --git a/src/network.cpp b/src/network.cpp index 5807381..409b829 100644 --- a/src/network.cpp +++ b/src/network.cpp @@ -430,16 +430,20 @@ void Server::Priv::doListen() vector<shared_ptr<erebos::Peer::Priv>> notifyPeers; vector<tuple<shared_ptr<erebos::Peer::Priv>, Service &, Ref>> readyServices; - if (auto header = peer->connection.receive(peer->partStorage)) { + { ReplyBuilder reply; scoped_lock hlock(dataMutex); shared_lock slock(selfMutex); - handlePacket(*peer, *header, reply); - peer->updateIdentity(reply, notifyPeers); - peer->updateChannel(reply); - peer->updateService(reply, readyServices); + if( auto header = peer->connection.receive( peer->partStorage )) { + handlePacket( *peer, *header, reply ); + peer->updateIdentity( reply, notifyPeers ); + peer->updateChannel( reply ); + } else { + peer->checkDataResponseStreams( reply ); + } + peer->updateService( reply, readyServices ); if (!reply.header().empty()) peer->connection.send(peer->partStorage, @@ -565,9 +569,12 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head plaintextRefs.insert(obj.ref().digest()); optional<UUID> serviceType; + shared_ptr< NetworkProtocol::InStream > newDataResponseStream; + + using Header = NetworkProtocol::Header; for (const auto & item : header.items) { - if (const auto * ack = get_if<NetworkProtocol::Header::Acknowledged>(&item)) { + if (const auto * ack = get_if< Header::Acknowledged >( &item )) { const auto & dgst = ack->value; if (holds_alternative<Stored<ChannelAccept>>(peer.connection.channel()) && std::get<Stored<ChannelAccept>>(peer.connection.channel()).ref().digest() == dgst) @@ -575,7 +582,7 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head std::get<Stored<ChannelAccept>>(peer.connection.channel())->data->channel()); } - else if (const auto * req = get_if<NetworkProtocol::Header::DataRequest>(&item)) { + else if (const auto * req = get_if< Header::DataRequest >( &item )) { const auto & dgst = req->value; if (holds_alternative<unique_ptr<Channel>>(peer.connection.channel()) || plaintextRefs.find(dgst) != plaintextRefs.end()) { @@ -586,21 +593,31 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head } } - else if (const auto * rsp = get_if<NetworkProtocol::Header::DataResponse>(&item)) { + else if (const auto * rsp = get_if< Header::DataResponse >( &item )) { const auto & dgst = rsp->value; - if (not holds_alternative<unique_ptr<Channel>>(peer.connection.channel())) - reply.header({ NetworkProtocol::Header::Acknowledged { dgst } }); - for (auto & pwref : waiting) { - if (auto wref = pwref.lock()) { - if (std::find(wref->missing.begin(), wref->missing.end(), dgst) != - wref->missing.end()) { - if (wref->check(reply)) - pwref.reset(); + if (not holds_alternative< unique_ptr< Channel >>( peer.connection.channel() )) + reply.header({ Header::Acknowledged { dgst } }); + + if (peer.partStorage.loadObject( dgst )) { + for (auto & pwref : waiting) { + if (auto wref = pwref.lock()) { + if (std::find(wref->missing.begin(), wref->missing.end(), dgst) != + wref->missing.end()) { + if (wref->check(reply)) + pwref.reset(); + } + } + } + waiting.erase(std::remove_if(waiting.begin(), waiting.end(), + [](auto & wref) { return wref.expired(); }), waiting.end()); + } else if (not newDataResponseStream) { + for (const auto & item : header.items) { + if (const auto * streamOpen = get_if< Header::StreamOpen >( &item )) { + newDataResponseStream = peer.connection.openInStream( streamOpen->value ); + break; } } } - waiting.erase(std::remove_if(waiting.begin(), waiting.end(), - [](auto & wref) { return wref.expired(); }), waiting.end()); } else if (const auto * ann = get_if<NetworkProtocol::Header::AnnounceSelf>(&item)) { @@ -708,6 +725,9 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head } } } + + if( newDataResponseStream ) + peer.dataResponseStreams.push_back( move( newDataResponseStream )); } void Server::Priv::handleLocalHeadChange(const Head<LocalState> & head) @@ -830,6 +850,33 @@ void Server::Peer::updateService(ReplyBuilder & reply, vector<tuple<shared_ptr<e serviceQueue = std::move(next); } +void Server::Peer::checkDataResponseStreams( ReplyBuilder & reply ) +{ + for( auto & s : dataResponseStreams ) { + if( s->isComplete() ) { + auto objects = PartialObject::decodeMany( partStorage, s->readAll() ); + vector< PartialRef > refs; + refs.reserve( objects.size() ); + for( const auto & obj : objects ) + refs.push_back( partStorage.storeObject( obj )); + + for( auto & pwref : server.waiting ) { + if (auto wref = pwref.lock()) { + for( const auto & ref : refs ) { + if( std::find( wref->missing.begin(), wref->missing.end(), ref.digest() ) != + wref->missing.end() ) { + if( wref->check( reply ) ) + pwref.reset(); + } + } + } + } + server.waiting.erase( std::remove_if( server.waiting.begin(), server.waiting.end(), + [](auto & wref) { return wref.expired(); }), server.waiting.end() ); + } + } +} + void ReplyBuilder::header(NetworkProtocol::Header::Item && item) { |