summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/network.cpp83
-rw-r--r--src/network.h33
-rw-r--r--src/network/protocol.cpp78
-rw-r--r--src/network/protocol.h43
-rw-r--r--src/pubkey.cpp2
5 files changed, 184 insertions, 55 deletions
diff --git a/src/network.cpp b/src/network.cpp
index 409b829..22b292a 100644
--- a/src/network.cpp
+++ b/src/network.cpp
@@ -252,7 +252,13 @@ bool Peer::send(UUID uuid, const Ref & ref, const Object & obj) const
NetworkProtocol::Header::ServiceType { uuid },
NetworkProtocol::Header::ServiceRef { ref.digest() },
});
- speer->connection.send(speer->partStorage, move(header), { obj }, true);
+
+ vector< Object > body;
+ if( obj.encode().size() + 2 * NetworkProtocol::Header::itemSize
+ <= speer->connection.mtu() )
+ body.push_back( obj );
+
+ speer->connection.send( speer->partStorage, move(header), body, true );
return true;
}
@@ -442,12 +448,29 @@ void Server::Priv::doListen()
peer->updateChannel( reply );
} else {
peer->checkDataResponseStreams( reply );
+ peer->updateIdentity( reply, notifyPeers );
}
peer->updateService( reply, readyServices );
- if (!reply.header().empty())
+ if( not reply.header().empty() ) {
+ for( const auto & item : reply.header() ) {
+ if( const auto * req = get_if< NetworkProtocol::Header::DataRequest >( &item )) {
+ const auto & dgst = req->value;
+ peer->requestedData.push_back( dgst );
+ }
+ }
+
peer->connection.send(peer->partStorage,
- NetworkProtocol::Header(reply.header()), reply.body(), false);
+ NetworkProtocol::Header( reply.header() ),
+ reply.stream() ? vector< Object >{} : reply.body(), false );
+ if( reply.stream() ){
+ for( const auto & obj : reply.body() ) {
+ auto part = obj.encode();
+ reply.stream()->write( part.data(), part.size() );
+ }
+ reply.stream()->close();
+ }
+ }
peer->connection.trySendOutQueue();
}
@@ -589,6 +612,10 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head
if (auto ref = peer.tempStorage.ref(dgst)) {
reply.header({ NetworkProtocol::Header::DataResponse { ref->digest() } });
reply.body(*ref);
+
+ if( holds_alternative< unique_ptr< Channel >>( peer.connection.channel() ) and
+ reply.size() > peer.connection.mtu() and not reply.stream() )
+ reply.stream( peer.connection.openOutStream() );
}
}
}
@@ -599,11 +626,14 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head
reply.header({ Header::Acknowledged { dgst } });
if (peer.partStorage.loadObject( dgst )) {
+ peer.requestedData.erase(
+ std::remove( peer.requestedData.begin(), peer.requestedData.end(), dgst ),
+ peer.requestedData.end() );
for (auto & pwref : waiting) {
if (auto wref = pwref.lock()) {
if (std::find(wref->missing.begin(), wref->missing.end(), dgst) !=
wref->missing.end()) {
- if (wref->check(reply))
+ if( wref->check( reply, peer.requestedData ))
pwref.reset();
}
}
@@ -633,7 +663,7 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head
});
waiting.push_back(wref);
peer.identity = wref;
- wref->check(reply);
+ wref->check( reply, peer.requestedData );
}
}
@@ -648,7 +678,7 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head
});
waiting.push_back(wref);
peer.identityUpdates.push_back(wref);
- wref->check(reply);
+ wref->check( reply, peer.requestedData );
}
}
@@ -673,7 +703,7 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head
});
waiting.push_back(wref);
peer.connection.channel() = wref;
- wref->check(reply);
+ wref->check( reply, peer.requestedData );
}
}
@@ -721,7 +751,7 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head
});
waiting.push_back(wref);
peer.serviceQueue.emplace_back(*serviceType, wref);
- wref->check(reply);
+ wref->check( reply, peer.requestedData );
}
}
}
@@ -797,7 +827,7 @@ void Server::Peer::updateChannel(ReplyBuilder & reply)
}
if (holds_alternative<shared_ptr<WaitingRef>>(connection.channel())) {
- if (auto ref = std::get<shared_ptr<WaitingRef>>(connection.channel())->check(reply)) {
+ if( auto ref = std::get< shared_ptr< WaitingRef >>( connection.channel())->check( reply, requestedData )) {
auto req = Stored<ChannelRequest>::load(*ref);
if (holds_alternative<Identity>(identity) &&
req->isSignedBy(std::get<Identity>(identity).keyMessage())) {
@@ -834,7 +864,7 @@ void Server::Peer::updateService(ReplyBuilder & reply, vector<tuple<shared_ptr<e
{
decltype(serviceQueue) next;
for (auto & x : serviceQueue) {
- if (auto ref = std::get<1>(x)->check(reply)) {
+ if( auto ref = std::get<1>(x)->check( reply, requestedData )) {
if (lpeer) {
for (auto & svc : server.services) {
if (svc->uuid() == std::get<UUID>(x)) {
@@ -857,15 +887,20 @@ void Server::Peer::checkDataResponseStreams( ReplyBuilder & reply )
auto objects = PartialObject::decodeMany( partStorage, s->readAll() );
vector< PartialRef > refs;
refs.reserve( objects.size() );
- for( const auto & obj : objects )
- refs.push_back( partStorage.storeObject( obj ));
+ for( const auto & obj : objects ) {
+ auto ref = partStorage.storeObject( obj );
+ refs.push_back( ref );
+ requestedData.erase(
+ std::remove( requestedData.begin(), requestedData.end(), ref.digest() ),
+ requestedData.end() );
+ }
for( auto & pwref : server.waiting ) {
if (auto wref = pwref.lock()) {
for( const auto & ref : refs ) {
if( std::find( wref->missing.begin(), wref->missing.end(), ref.digest() ) !=
wref->missing.end() ) {
- if( wref->check( reply ) )
+ if( wref->check( reply, requestedData ) )
pwref.reset();
}
}
@@ -891,9 +926,17 @@ void ReplyBuilder::body(const Ref & ref)
for (const auto & x : mbody)
if (x.digest() == ref.digest())
return;
+
+ bodySize += ref->encode().size();
mbody.push_back(ref);
}
+void ReplyBuilder::stream( shared_ptr< NetworkProtocol::OutStream > s )
+{
+ mheader.emplace_back( Header::StreamOpen{ s->id });
+ mstream = move( s );
+}
+
vector<Object> ReplyBuilder::body() const
{
vector<Object> res;
@@ -903,6 +946,11 @@ vector<Object> ReplyBuilder::body() const
return res;
}
+size_t ReplyBuilder::size() const
+{
+ return mheader.size() * Header::itemSize + bodySize;
+}
+
optional<Ref> WaitingRef::check()
{
@@ -917,13 +965,16 @@ optional<Ref> WaitingRef::check()
return nullopt;
}
-optional<Ref> WaitingRef::check(ReplyBuilder & reply)
+optional<Ref> WaitingRef::check( ReplyBuilder & reply, const vector< Digest > & alreadyRequested)
{
if (auto r = check())
return r;
- for (const auto & d : missing)
- reply.header({ NetworkProtocol::Header::DataRequest { d } });
+ for( const auto & d : missing ) {
+ if( std::find( alreadyRequested.begin(), alreadyRequested.end(), d ) ==
+ alreadyRequested.end() )
+ reply.header({ NetworkProtocol::Header::DataRequest { d } });
+ }
return nullopt;
}
diff --git a/src/network.h b/src/network.h
index 12013fe..ed02167 100644
--- a/src/network.h
+++ b/src/network.h
@@ -55,6 +55,7 @@ struct Server::Peer
vector<tuple<UUID, shared_ptr<WaitingRef>>> serviceQueue {};
vector< shared_ptr< NetworkProtocol::InStream >> dataResponseStreams {};
+ vector< Digest > requestedData {};
shared_ptr<erebos::Peer::Priv> lpeer = nullptr;
@@ -135,4 +136,36 @@ struct Server::Priv
vector<unique_ptr<Service>> services;
};
+class ReplyBuilder
+{
+public:
+ using Header = NetworkProtocol::Header;
+
+ void header( Header::Item && );
+ void body( const Ref & );
+ void stream( shared_ptr< NetworkProtocol::OutStream >);
+
+ const vector< Header::Item > & header() const { return mheader; }
+ vector< Object > body() const;
+ shared_ptr< NetworkProtocol::OutStream > stream() const { return mstream; }
+
+ size_t size() const;
+
+private:
+ vector< Header::Item > mheader;
+ vector< Ref > mbody;
+ size_t bodySize = 0;
+ shared_ptr< NetworkProtocol::OutStream > mstream;
+};
+
+struct WaitingRef
+{
+ const Storage storage;
+ const PartialRef ref;
+ vector< Digest > missing;
+
+ optional< Ref > check();
+ optional< Ref > check( ReplyBuilder &, const vector< Digest > &);
+};
+
}
diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp
index 9848d29..89d6a88 100644
--- a/src/network/protocol.cpp
+++ b/src/network/protocol.cpp
@@ -16,14 +16,18 @@ using std::nullopt;
using std::runtime_error;
using std::scoped_lock;
using std::to_string;
+using std::unique_lock;
using std::visit;
namespace erebos {
+static constexpr uint8_t maxStreamNumber = 0x3F;
+
struct NetworkProtocol::ConnectionPriv
{
Connection::Id id() const;
+ size_t mtu() const;
bool send(const PartialStorage &, Header,
const vector<Object> &, bool secure);
bool send( const StreamData & chunk );
@@ -39,7 +43,7 @@ struct NetworkProtocol::ConnectionPriv
ChannelState channel = monostate();
vector<vector<uint8_t>> secureOutQueue {};
- size_t mtu = 500; // TODO: MTU
+ size_t mtuLower = 1000; // TODO: MTU
vector<uint64_t> toAcknowledge {};
@@ -94,11 +98,20 @@ NetworkProtocol::PollResult NetworkProtocol::poll()
sendAck = not c->toAcknowledge.empty() &&
holds_alternative< unique_ptr< Channel >>( c->channel );
- for (const auto & s : c->outStreams) {
- scoped_lock slock(s->streamMutex);
+ for (auto & s : c->outStreams) {
+ unique_lock slock(s->streamMutex);
while (s->hasDataLocked())
- streamChunks.push_back( s->getNextChunkLocked( c->mtu ) );
+ streamChunks.push_back( s->getNextChunkLocked( c->mtu() ));
+ if( s->closed ){
+ // TODO: wait after ack
+ streamChunks.push_back( { s->id, (uint8_t) s->nextSequence, {} } );
+ slock.unlock();
+ s.reset();
+ }
}
+
+ while( not c->outStreams.empty() && not c->outStreams.back() )
+ c->outStreams.pop_back();
}
if (sendAck) {
auto pst = self->ref()->storage().deriveEphemeralStorage();
@@ -293,6 +306,8 @@ bool NetworkProtocol::verifyCookie(variant<sockaddr_in, sockaddr_in6> vaddr, con
/* Connection */
/******************************************************************************/
+using Connection = NetworkProtocol::Connection;
+
NetworkProtocol::Connection::Id NetworkProtocol::ConnectionPriv::id() const
{
return reinterpret_cast<uintptr_t>(this);
@@ -330,6 +345,24 @@ const sockaddr_in6 & NetworkProtocol::Connection::peerAddress() const
return p->peerAddress;
}
+size_t Connection::mtu() const
+{
+ return p->mtu();
+}
+
+size_t NetworkProtocol::ConnectionPriv::mtu() const
+{
+ if( get_if< unique_ptr< Channel >>( &channel ))
+ return mtuLower // space for:
+ - 1 // "encrypted" tag
+ - 1 // counter
+ - 1 // channel number
+ - 1 // channel sequence
+ - 16 // tag
+ ;
+ return mtuLower - 128; // some space for cookie headers
+}
+
optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const PartialStorage & partStorage)
{
vector<uint8_t> buf;
@@ -387,7 +420,7 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const Par
return nullopt;
}
}
- else if (const auto * sdata = get_if< StreamData >( &parsed )) {
+ else if( auto * sdata = get_if< StreamData >( &parsed )){
scoped_lock lock(p->cmutex);
if (secure)
p->toAcknowledge.push_back(*secure);
@@ -440,7 +473,7 @@ NetworkProtocol::Connection::parsePacket(vector<uint8_t> & buf,
plainBegin = decrypted.begin() + 1;
plainEnd = decrypted.end();
}
- else if (decrypted[0] < 0x40) {
+ else if (decrypted[0] <= maxStreamNumber) {
StreamData sdata;
sdata.id = decrypted[0];
sdata.sequence = decrypted[1];
@@ -597,12 +630,17 @@ shared_ptr< NetworkProtocol::InStream > NetworkProtocol::Connection::openInStrea
return p->inStreams.back();
}
-shared_ptr< NetworkProtocol::OutStream > NetworkProtocol::Connection::openOutStream( uint8_t sid )
+shared_ptr< NetworkProtocol::OutStream > NetworkProtocol::Connection::openOutStream()
{
scoped_lock lock( p->cmutex );
- for (const auto & s : p->outStreams)
- if (s->id == sid)
- throw runtime_error("outbound stream " + to_string(sid) + " already open");
+
+ uint8_t sid = 1;
+ if( not p->outStreams.empty() ){
+ if( p->outStreams.back()->id < maxStreamNumber )
+ sid = p->outStreams.back()->id + 1;
+ else
+ throw runtime_error("no free outbound stream");
+ }
p->outStreams.emplace_back( new OutStream( sid ));
return p->outStreams.back();
@@ -636,9 +674,21 @@ void NetworkProtocol::Connection::trySendOutQueue()
}
+NetworkProtocol::Stream::Stream(uint8_t id_):
+ id(id_)
+{
+ readPtr = readBuffer.begin();
+}
+
+void NetworkProtocol::Stream::close()
+{
+ scoped_lock lock( streamMutex );
+ closed = true;
+}
+
bool NetworkProtocol::Stream::hasDataLocked() const
{
- return not writeBuffer.empty() || not readBuffer.empty();
+ return not writeBuffer.empty() || readPtr < readBuffer.end();
}
size_t NetworkProtocol::Stream::writeLocked( const uint8_t * buf, size_t size )
@@ -725,6 +775,12 @@ bool NetworkProtocol::InStream::tryUseChunkLocked( const StreamData & chunk )
return true;
}
+size_t NetworkProtocol::OutStream::write( const uint8_t * buf, size_t size )
+{
+ scoped_lock lock( streamMutex );
+ return writeLocked( buf, size );
+}
+
NetworkProtocol::StreamData NetworkProtocol::OutStream::getNextChunkLocked( size_t size )
{
StreamData res;
diff --git a/src/network/protocol.h b/src/network/protocol.h
index 2592c9f..d32b20b 100644
--- a/src/network/protocol.h
+++ b/src/network/protocol.h
@@ -106,6 +106,7 @@ public:
Id id() const;
const sockaddr_in6 & peerAddress() const;
+ size_t mtu() const;
optional<Header> receive(const PartialStorage &);
bool send(const PartialStorage &, NetworkProtocol::Header,
@@ -115,7 +116,7 @@ public:
void close();
shared_ptr< InStream > openInStream( uint8_t sid );
- shared_ptr< OutStream > openOutStream( uint8_t sid );
+ shared_ptr< OutStream > openOutStream();
// temporary:
ChannelState & channel();
@@ -136,14 +137,21 @@ class NetworkProtocol::Stream
friend class NetworkProtocol::Connection;
protected:
- Stream(uint8_t id_): id( id_ ) {}
+ Stream(uint8_t id_);
+public:
+ void close();
+
+protected:
bool hasDataLocked() const;
size_t writeLocked( const uint8_t * buf, size_t size );
size_t readLocked( uint8_t * buf, size_t size );
- uint8_t id;
+public:
+ const uint8_t id;
+
+protected:
bool closed { false };
vector< uint8_t > writeBuffer;
vector< uint8_t > readBuffer;
@@ -181,6 +189,9 @@ class NetworkProtocol::OutStream : public NetworkProtocol::Stream
protected:
OutStream(uint8_t id): Stream( id ) {}
+public:
+ size_t write( const uint8_t * buf, size_t size );
+
private:
StreamData getNextChunkLocked( size_t size );
@@ -226,6 +237,8 @@ struct NetworkProtocol::Header
ServiceRef,
StreamOpen>;
+ static constexpr size_t itemSize = 78; /* estimate for size of ref-containing headers */
+
Header(const vector<Item> & items): items(items) {}
static optional<Header> load(const PartialRef &);
static optional<Header> load(const PartialObject &);
@@ -261,28 +274,4 @@ inline bool operator!=(const NetworkProtocol::Header::Item & left,
inline bool operator==(const NetworkProtocol::Cookie & left, const NetworkProtocol::Cookie & right)
{ return left.value == right.value; }
-class ReplyBuilder
-{
-public:
- void header(NetworkProtocol::Header::Item &&);
- void body(const Ref &);
-
- const vector<NetworkProtocol::Header::Item> & header() const { return mheader; }
- vector<Object> body() const;
-
-private:
- vector<NetworkProtocol::Header::Item> mheader;
- vector<Ref> mbody;
-};
-
-struct WaitingRef
-{
- const Storage storage;
- const PartialRef ref;
- vector<Digest> missing;
-
- optional<Ref> check();
- optional<Ref> check(ReplyBuilder &);
-};
-
}
diff --git a/src/pubkey.cpp b/src/pubkey.cpp
index 59b73f9..9e89cde 100644
--- a/src/pubkey.cpp
+++ b/src/pubkey.cpp
@@ -108,7 +108,7 @@ optional<SecretKey> SecretKey::fromData(const Stored<PublicKey> & pub, const vec
keyData.resize(keyLen);
EVP_PKEY_get_raw_public_key(pub->key.get(), keyData.data(), &keyLen);
- if (EVP_PKEY_cmp(pkey.get(), pub->key.get()) != 1)
+ if( EVP_PKEY_eq( pkey.get(), pub->key.get() ) != 1 )
return nullopt;
pub.ref().storage().storeKey(pub.ref(), sdata);