diff options
-rw-r--r-- | CMakeLists.txt | 2 | ||||
-rw-r--r-- | include/erebos/frp.h | 2 | ||||
-rw-r--r-- | src/network.cpp | 83 | ||||
-rw-r--r-- | src/network.h | 33 | ||||
-rw-r--r-- | src/network/protocol.cpp | 78 | ||||
-rw-r--r-- | src/network/protocol.h | 43 | ||||
-rw-r--r-- | src/pubkey.cpp | 2 | ||||
-rw-r--r-- | test/network.test (renamed from test/discovery.test) | 29 |
8 files changed, 214 insertions, 58 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 4a3727a..4ff8385 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,7 +12,7 @@ endif() find_package(Threads REQUIRED) find_package(ZLIB REQUIRED) -find_package(OpenSSL REQUIRED) +find_package(OpenSSL 3.0 REQUIRED) find_library(B2_LIBRARY b2 REQUIRED) add_subdirectory(src) diff --git a/include/erebos/frp.h b/include/erebos/frp.h index 72b5cc9..a06519d 100644 --- a/include/erebos/frp.h +++ b/include/erebos/frp.h @@ -165,7 +165,7 @@ template<typename A> using Bhv = BhvFun<monostate, A>; template<typename A> -Watched<A> Bhv<A>::watch(function<void(const A &)> f) +Watched<A> BhvFun<monostate, A>::watch(function<void(const A &)> f) { BhvCurTime ctime; auto & impl = BhvFun<monostate, A>::impl; diff --git a/src/network.cpp b/src/network.cpp index 409b829..22b292a 100644 --- a/src/network.cpp +++ b/src/network.cpp @@ -252,7 +252,13 @@ bool Peer::send(UUID uuid, const Ref & ref, const Object & obj) const NetworkProtocol::Header::ServiceType { uuid }, NetworkProtocol::Header::ServiceRef { ref.digest() }, }); - speer->connection.send(speer->partStorage, move(header), { obj }, true); + + vector< Object > body; + if( obj.encode().size() + 2 * NetworkProtocol::Header::itemSize + <= speer->connection.mtu() ) + body.push_back( obj ); + + speer->connection.send( speer->partStorage, move(header), body, true ); return true; } @@ -442,12 +448,29 @@ void Server::Priv::doListen() peer->updateChannel( reply ); } else { peer->checkDataResponseStreams( reply ); + peer->updateIdentity( reply, notifyPeers ); } peer->updateService( reply, readyServices ); - if (!reply.header().empty()) + if( not reply.header().empty() ) { + for( const auto & item : reply.header() ) { + if( const auto * req = get_if< NetworkProtocol::Header::DataRequest >( &item )) { + const auto & dgst = req->value; + peer->requestedData.push_back( dgst ); + } + } + peer->connection.send(peer->partStorage, - NetworkProtocol::Header(reply.header()), reply.body(), false); + NetworkProtocol::Header( reply.header() ), + reply.stream() ? vector< Object >{} : reply.body(), false ); + if( reply.stream() ){ + for( const auto & obj : reply.body() ) { + auto part = obj.encode(); + reply.stream()->write( part.data(), part.size() ); + } + reply.stream()->close(); + } + } peer->connection.trySendOutQueue(); } @@ -589,6 +612,10 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head if (auto ref = peer.tempStorage.ref(dgst)) { reply.header({ NetworkProtocol::Header::DataResponse { ref->digest() } }); reply.body(*ref); + + if( holds_alternative< unique_ptr< Channel >>( peer.connection.channel() ) and + reply.size() > peer.connection.mtu() and not reply.stream() ) + reply.stream( peer.connection.openOutStream() ); } } } @@ -599,11 +626,14 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head reply.header({ Header::Acknowledged { dgst } }); if (peer.partStorage.loadObject( dgst )) { + peer.requestedData.erase( + std::remove( peer.requestedData.begin(), peer.requestedData.end(), dgst ), + peer.requestedData.end() ); for (auto & pwref : waiting) { if (auto wref = pwref.lock()) { if (std::find(wref->missing.begin(), wref->missing.end(), dgst) != wref->missing.end()) { - if (wref->check(reply)) + if( wref->check( reply, peer.requestedData )) pwref.reset(); } } @@ -633,7 +663,7 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head }); waiting.push_back(wref); peer.identity = wref; - wref->check(reply); + wref->check( reply, peer.requestedData ); } } @@ -648,7 +678,7 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head }); waiting.push_back(wref); peer.identityUpdates.push_back(wref); - wref->check(reply); + wref->check( reply, peer.requestedData ); } } @@ -673,7 +703,7 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head }); waiting.push_back(wref); peer.connection.channel() = wref; - wref->check(reply); + wref->check( reply, peer.requestedData ); } } @@ -721,7 +751,7 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head }); waiting.push_back(wref); peer.serviceQueue.emplace_back(*serviceType, wref); - wref->check(reply); + wref->check( reply, peer.requestedData ); } } } @@ -797,7 +827,7 @@ void Server::Peer::updateChannel(ReplyBuilder & reply) } if (holds_alternative<shared_ptr<WaitingRef>>(connection.channel())) { - if (auto ref = std::get<shared_ptr<WaitingRef>>(connection.channel())->check(reply)) { + if( auto ref = std::get< shared_ptr< WaitingRef >>( connection.channel())->check( reply, requestedData )) { auto req = Stored<ChannelRequest>::load(*ref); if (holds_alternative<Identity>(identity) && req->isSignedBy(std::get<Identity>(identity).keyMessage())) { @@ -834,7 +864,7 @@ void Server::Peer::updateService(ReplyBuilder & reply, vector<tuple<shared_ptr<e { decltype(serviceQueue) next; for (auto & x : serviceQueue) { - if (auto ref = std::get<1>(x)->check(reply)) { + if( auto ref = std::get<1>(x)->check( reply, requestedData )) { if (lpeer) { for (auto & svc : server.services) { if (svc->uuid() == std::get<UUID>(x)) { @@ -857,15 +887,20 @@ void Server::Peer::checkDataResponseStreams( ReplyBuilder & reply ) auto objects = PartialObject::decodeMany( partStorage, s->readAll() ); vector< PartialRef > refs; refs.reserve( objects.size() ); - for( const auto & obj : objects ) - refs.push_back( partStorage.storeObject( obj )); + for( const auto & obj : objects ) { + auto ref = partStorage.storeObject( obj ); + refs.push_back( ref ); + requestedData.erase( + std::remove( requestedData.begin(), requestedData.end(), ref.digest() ), + requestedData.end() ); + } for( auto & pwref : server.waiting ) { if (auto wref = pwref.lock()) { for( const auto & ref : refs ) { if( std::find( wref->missing.begin(), wref->missing.end(), ref.digest() ) != wref->missing.end() ) { - if( wref->check( reply ) ) + if( wref->check( reply, requestedData ) ) pwref.reset(); } } @@ -891,9 +926,17 @@ void ReplyBuilder::body(const Ref & ref) for (const auto & x : mbody) if (x.digest() == ref.digest()) return; + + bodySize += ref->encode().size(); mbody.push_back(ref); } +void ReplyBuilder::stream( shared_ptr< NetworkProtocol::OutStream > s ) +{ + mheader.emplace_back( Header::StreamOpen{ s->id }); + mstream = move( s ); +} + vector<Object> ReplyBuilder::body() const { vector<Object> res; @@ -903,6 +946,11 @@ vector<Object> ReplyBuilder::body() const return res; } +size_t ReplyBuilder::size() const +{ + return mheader.size() * Header::itemSize + bodySize; +} + optional<Ref> WaitingRef::check() { @@ -917,13 +965,16 @@ optional<Ref> WaitingRef::check() return nullopt; } -optional<Ref> WaitingRef::check(ReplyBuilder & reply) +optional<Ref> WaitingRef::check( ReplyBuilder & reply, const vector< Digest > & alreadyRequested) { if (auto r = check()) return r; - for (const auto & d : missing) - reply.header({ NetworkProtocol::Header::DataRequest { d } }); + for( const auto & d : missing ) { + if( std::find( alreadyRequested.begin(), alreadyRequested.end(), d ) == + alreadyRequested.end() ) + reply.header({ NetworkProtocol::Header::DataRequest { d } }); + } return nullopt; } diff --git a/src/network.h b/src/network.h index 12013fe..ed02167 100644 --- a/src/network.h +++ b/src/network.h @@ -55,6 +55,7 @@ struct Server::Peer vector<tuple<UUID, shared_ptr<WaitingRef>>> serviceQueue {}; vector< shared_ptr< NetworkProtocol::InStream >> dataResponseStreams {}; + vector< Digest > requestedData {}; shared_ptr<erebos::Peer::Priv> lpeer = nullptr; @@ -135,4 +136,36 @@ struct Server::Priv vector<unique_ptr<Service>> services; }; +class ReplyBuilder +{ +public: + using Header = NetworkProtocol::Header; + + void header( Header::Item && ); + void body( const Ref & ); + void stream( shared_ptr< NetworkProtocol::OutStream >); + + const vector< Header::Item > & header() const { return mheader; } + vector< Object > body() const; + shared_ptr< NetworkProtocol::OutStream > stream() const { return mstream; } + + size_t size() const; + +private: + vector< Header::Item > mheader; + vector< Ref > mbody; + size_t bodySize = 0; + shared_ptr< NetworkProtocol::OutStream > mstream; +}; + +struct WaitingRef +{ + const Storage storage; + const PartialRef ref; + vector< Digest > missing; + + optional< Ref > check(); + optional< Ref > check( ReplyBuilder &, const vector< Digest > &); +}; + } diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp index 9848d29..89d6a88 100644 --- a/src/network/protocol.cpp +++ b/src/network/protocol.cpp @@ -16,14 +16,18 @@ using std::nullopt; using std::runtime_error; using std::scoped_lock; using std::to_string; +using std::unique_lock; using std::visit; namespace erebos { +static constexpr uint8_t maxStreamNumber = 0x3F; + struct NetworkProtocol::ConnectionPriv { Connection::Id id() const; + size_t mtu() const; bool send(const PartialStorage &, Header, const vector<Object> &, bool secure); bool send( const StreamData & chunk ); @@ -39,7 +43,7 @@ struct NetworkProtocol::ConnectionPriv ChannelState channel = monostate(); vector<vector<uint8_t>> secureOutQueue {}; - size_t mtu = 500; // TODO: MTU + size_t mtuLower = 1000; // TODO: MTU vector<uint64_t> toAcknowledge {}; @@ -94,11 +98,20 @@ NetworkProtocol::PollResult NetworkProtocol::poll() sendAck = not c->toAcknowledge.empty() && holds_alternative< unique_ptr< Channel >>( c->channel ); - for (const auto & s : c->outStreams) { - scoped_lock slock(s->streamMutex); + for (auto & s : c->outStreams) { + unique_lock slock(s->streamMutex); while (s->hasDataLocked()) - streamChunks.push_back( s->getNextChunkLocked( c->mtu ) ); + streamChunks.push_back( s->getNextChunkLocked( c->mtu() )); + if( s->closed ){ + // TODO: wait after ack + streamChunks.push_back( { s->id, (uint8_t) s->nextSequence, {} } ); + slock.unlock(); + s.reset(); + } } + + while( not c->outStreams.empty() && not c->outStreams.back() ) + c->outStreams.pop_back(); } if (sendAck) { auto pst = self->ref()->storage().deriveEphemeralStorage(); @@ -293,6 +306,8 @@ bool NetworkProtocol::verifyCookie(variant<sockaddr_in, sockaddr_in6> vaddr, con /* Connection */ /******************************************************************************/ +using Connection = NetworkProtocol::Connection; + NetworkProtocol::Connection::Id NetworkProtocol::ConnectionPriv::id() const { return reinterpret_cast<uintptr_t>(this); @@ -330,6 +345,24 @@ const sockaddr_in6 & NetworkProtocol::Connection::peerAddress() const return p->peerAddress; } +size_t Connection::mtu() const +{ + return p->mtu(); +} + +size_t NetworkProtocol::ConnectionPriv::mtu() const +{ + if( get_if< unique_ptr< Channel >>( &channel )) + return mtuLower // space for: + - 1 // "encrypted" tag + - 1 // counter + - 1 // channel number + - 1 // channel sequence + - 16 // tag + ; + return mtuLower - 128; // some space for cookie headers +} + optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const PartialStorage & partStorage) { vector<uint8_t> buf; @@ -387,7 +420,7 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const Par return nullopt; } } - else if (const auto * sdata = get_if< StreamData >( &parsed )) { + else if( auto * sdata = get_if< StreamData >( &parsed )){ scoped_lock lock(p->cmutex); if (secure) p->toAcknowledge.push_back(*secure); @@ -440,7 +473,7 @@ NetworkProtocol::Connection::parsePacket(vector<uint8_t> & buf, plainBegin = decrypted.begin() + 1; plainEnd = decrypted.end(); } - else if (decrypted[0] < 0x40) { + else if (decrypted[0] <= maxStreamNumber) { StreamData sdata; sdata.id = decrypted[0]; sdata.sequence = decrypted[1]; @@ -597,12 +630,17 @@ shared_ptr< NetworkProtocol::InStream > NetworkProtocol::Connection::openInStrea return p->inStreams.back(); } -shared_ptr< NetworkProtocol::OutStream > NetworkProtocol::Connection::openOutStream( uint8_t sid ) +shared_ptr< NetworkProtocol::OutStream > NetworkProtocol::Connection::openOutStream() { scoped_lock lock( p->cmutex ); - for (const auto & s : p->outStreams) - if (s->id == sid) - throw runtime_error("outbound stream " + to_string(sid) + " already open"); + + uint8_t sid = 1; + if( not p->outStreams.empty() ){ + if( p->outStreams.back()->id < maxStreamNumber ) + sid = p->outStreams.back()->id + 1; + else + throw runtime_error("no free outbound stream"); + } p->outStreams.emplace_back( new OutStream( sid )); return p->outStreams.back(); @@ -636,9 +674,21 @@ void NetworkProtocol::Connection::trySendOutQueue() } +NetworkProtocol::Stream::Stream(uint8_t id_): + id(id_) +{ + readPtr = readBuffer.begin(); +} + +void NetworkProtocol::Stream::close() +{ + scoped_lock lock( streamMutex ); + closed = true; +} + bool NetworkProtocol::Stream::hasDataLocked() const { - return not writeBuffer.empty() || not readBuffer.empty(); + return not writeBuffer.empty() || readPtr < readBuffer.end(); } size_t NetworkProtocol::Stream::writeLocked( const uint8_t * buf, size_t size ) @@ -725,6 +775,12 @@ bool NetworkProtocol::InStream::tryUseChunkLocked( const StreamData & chunk ) return true; } +size_t NetworkProtocol::OutStream::write( const uint8_t * buf, size_t size ) +{ + scoped_lock lock( streamMutex ); + return writeLocked( buf, size ); +} + NetworkProtocol::StreamData NetworkProtocol::OutStream::getNextChunkLocked( size_t size ) { StreamData res; diff --git a/src/network/protocol.h b/src/network/protocol.h index 2592c9f..d32b20b 100644 --- a/src/network/protocol.h +++ b/src/network/protocol.h @@ -106,6 +106,7 @@ public: Id id() const; const sockaddr_in6 & peerAddress() const; + size_t mtu() const; optional<Header> receive(const PartialStorage &); bool send(const PartialStorage &, NetworkProtocol::Header, @@ -115,7 +116,7 @@ public: void close(); shared_ptr< InStream > openInStream( uint8_t sid ); - shared_ptr< OutStream > openOutStream( uint8_t sid ); + shared_ptr< OutStream > openOutStream(); // temporary: ChannelState & channel(); @@ -136,14 +137,21 @@ class NetworkProtocol::Stream friend class NetworkProtocol::Connection; protected: - Stream(uint8_t id_): id( id_ ) {} + Stream(uint8_t id_); +public: + void close(); + +protected: bool hasDataLocked() const; size_t writeLocked( const uint8_t * buf, size_t size ); size_t readLocked( uint8_t * buf, size_t size ); - uint8_t id; +public: + const uint8_t id; + +protected: bool closed { false }; vector< uint8_t > writeBuffer; vector< uint8_t > readBuffer; @@ -181,6 +189,9 @@ class NetworkProtocol::OutStream : public NetworkProtocol::Stream protected: OutStream(uint8_t id): Stream( id ) {} +public: + size_t write( const uint8_t * buf, size_t size ); + private: StreamData getNextChunkLocked( size_t size ); @@ -226,6 +237,8 @@ struct NetworkProtocol::Header ServiceRef, StreamOpen>; + static constexpr size_t itemSize = 78; /* estimate for size of ref-containing headers */ + Header(const vector<Item> & items): items(items) {} static optional<Header> load(const PartialRef &); static optional<Header> load(const PartialObject &); @@ -261,28 +274,4 @@ inline bool operator!=(const NetworkProtocol::Header::Item & left, inline bool operator==(const NetworkProtocol::Cookie & left, const NetworkProtocol::Cookie & right) { return left.value == right.value; } -class ReplyBuilder -{ -public: - void header(NetworkProtocol::Header::Item &&); - void body(const Ref &); - - const vector<NetworkProtocol::Header::Item> & header() const { return mheader; } - vector<Object> body() const; - -private: - vector<NetworkProtocol::Header::Item> mheader; - vector<Ref> mbody; -}; - -struct WaitingRef -{ - const Storage storage; - const PartialRef ref; - vector<Digest> missing; - - optional<Ref> check(); - optional<Ref> check(ReplyBuilder &); -}; - } diff --git a/src/pubkey.cpp b/src/pubkey.cpp index 59b73f9..9e89cde 100644 --- a/src/pubkey.cpp +++ b/src/pubkey.cpp @@ -108,7 +108,7 @@ optional<SecretKey> SecretKey::fromData(const Stored<PublicKey> & pub, const vec keyData.resize(keyLen); EVP_PKEY_get_raw_public_key(pub->key.get(), keyData.data(), &keyLen); - if (EVP_PKEY_cmp(pkey.get(), pub->key.get()) != 1) + if( EVP_PKEY_eq( pkey.get(), pub->key.get() ) != 1 ) return nullopt; pub.ref().storage().storeKey(pub.ref(), sdata); diff --git a/test/discovery.test b/test/network.test index 2aaaf24..3df7376 100644 --- a/test/discovery.test +++ b/test/network.test @@ -1,4 +1,4 @@ -test: +test Discovery: spawn as p1 spawn as p2 send "create-identity Device1 Owner" to p1 @@ -117,3 +117,30 @@ test: /peer $peer6_4 id Device4/ /peer ([0-9]+) addr ${p5.node.ip} 29665/ capture peer6_5 /peer $peer6_5 id Device5/ + + +test LargeData: + spawn as p1 + spawn as p2 + send "create-identity Device1" to p1 + send "create-identity Device2" to p2 + send "start-server" to p1 + send "start-server" to p2 + expect from p1: + /peer 1 addr ${p2.node.ip} 29665/ + /peer 1 id Device2/ + expect from p2: + /peer 1 addr ${p1.node.ip} 29665/ + /peer 1 id Device1/ + + for i in [0..10]: + with p1: + send "store blob" + for j in [1 .. i * 10]: + send "123456789 123456789 123456789 123456789 123456789 123456789 123456789 123456789 123456789 123456789" + send "" + expect /store-done (blake2#[0-9a-f]*)/ capture ref + + send "test-message-send 1 $ref" + expect /test-message-send done/ + expect /test-message-received blob ${i*1000} $ref/ from p2 |