diff options
-rw-r--r-- | src/network.cpp | 97 | ||||
-rw-r--r-- | src/network.h | 3 | ||||
-rw-r--r-- | src/network/protocol.cpp | 100 | ||||
-rw-r--r-- | src/network/protocol.h | 6 |
4 files changed, 110 insertions, 96 deletions
diff --git a/src/network.cpp b/src/network.cpp index 455496c..8c181cf 100644 --- a/src/network.cpp +++ b/src/network.cpp @@ -117,7 +117,7 @@ void Server::addPeer(const string & node, const string & service) const header.push_back(NetworkProtocol::Header::Item { NetworkProtocol::Header::Type::AnnounceSelf, p->self.ref()->digest() }); } - peer.send(header, {}, false); + peer.connection.send(peer.partStorage, header, {}, false); return; } } @@ -229,7 +229,7 @@ bool Peer::send(UUID uuid, const Ref & ref, const Object & obj) const { NetworkProtocol::Header::Type::ServiceType, uuid }, { NetworkProtocol::Header::Type::ServiceRef, ref.digest() }, }); - speer->send(header, { obj }, true); + speer->connection.send(speer->partStorage, header, { obj }, true); return true; } @@ -360,7 +360,6 @@ shared_ptr<Server::Priv> Server::Priv::getptr() void Server::Priv::doListen() { - vector<uint8_t> buf, decrypted, *current; unique_lock lock(dataMutex); for (; !finish; lock.lock()) { @@ -385,50 +384,22 @@ void Server::Priv::doListen() if (!peer) continue; - if (not peer->connection.receive(buf)) - continue; - - current = &buf; - if (holds_alternative<unique_ptr<Channel>>(peer->connection.channel())) { - if (auto dec = std::get<unique_ptr<Channel>>(peer->connection.channel())->decrypt(buf)) { - decrypted = std::move(*dec); - current = &decrypted; - } - } else if (holds_alternative<Stored<ChannelAccept>>(peer->connection.channel())) { - if (auto dec = std::get<Stored<ChannelAccept>>(peer->connection.channel())-> - data->channel()->decrypt(buf)) { - decrypted = std::move(*dec); - current = &decrypted; - } - } - - if (auto dec = PartialObject::decodePrefix(peer->partStorage, - current->begin(), current->end())) { - if (auto header = NetworkProtocol::Header::load(std::get<PartialObject>(*dec))) { - auto pos = std::get<1>(*dec); - while (auto cdec = PartialObject::decodePrefix(peer->partStorage, - pos, current->end())) { - peer->partStorage.storeObject(std::get<PartialObject>(*cdec)); - pos = std::get<1>(*cdec); - } + if (auto header = peer->connection.receive(peer->partStorage)) { + ReplyBuilder reply; - ReplyBuilder reply; - - scoped_lock hlock(dataMutex); - shared_lock slock(selfMutex); + scoped_lock hlock(dataMutex); + shared_lock slock(selfMutex); - handlePacket(*peer, *header, reply); - peer->updateIdentity(reply); - peer->updateChannel(reply); - peer->updateService(reply); + handlePacket(*peer, *header, reply); + peer->updateIdentity(reply); + peer->updateChannel(reply); + peer->updateService(reply); - if (!reply.header().empty()) - peer->send(NetworkProtocol::Header(reply.header()), reply.body(), false); + if (!reply.header().empty()) + peer->connection.send(peer->partStorage, + NetworkProtocol::Header(reply.header()), reply.body(), false); - peer->trySendOutQueue(); - } - } else { - std::cerr << "invalid packet\n"; + peer->connection.trySendOutQueue(); } } } @@ -703,33 +674,11 @@ void Server::Priv::handleLocalHeadChange(const Head<LocalState> & head) NetworkProtocol::Header header(hitems); for (const auto & peer : peers) - peer->send(header, { **self.ref() }, false); + peer->connection.send(peer->partStorage, header, { **self.ref() }, false); } } } -void Server::Peer::send(const NetworkProtocol::Header & header, const vector<Object> & objs, bool secure) -{ - vector<uint8_t> data, part, out; - - part = header.toObject(partStorage).encode(); - data.insert(data.end(), part.begin(), part.end()); - for (const auto & obj : objs) { - part = obj.encode(); - data.insert(data.end(), part.begin(), part.end()); - } - - if (holds_alternative<unique_ptr<Channel>>(connection.channel())) - out = std::get<unique_ptr<Channel>>(connection.channel())->encrypt(data); - else if (secure) - secureOutQueue.emplace_back(move(data)); - else - out = std::move(data); - - if (!out.empty()) - connection.send(out); -} - void Server::Peer::updateIdentity(ReplyBuilder &) { if (holds_alternative<shared_ptr<WaitingRef>>(identity)) { @@ -853,22 +802,6 @@ void Server::Peer::updateService(ReplyBuilder & reply) serviceQueue = std::move(next); } -void Server::Peer::trySendOutQueue() -{ - if (secureOutQueue.empty()) - return; - - if (!holds_alternative<unique_ptr<Channel>>(connection.channel())) - return; - - for (const auto & data : secureOutQueue) { - auto out = std::get<unique_ptr<Channel>>(connection.channel())->encrypt(data); - connection.send(out); - } - - secureOutQueue.clear(); -} - void ReplyBuilder::header(NetworkProtocol::Header::Item && item) { diff --git a/src/network.h b/src/network.h index 2959adc..d1fae15 100644 --- a/src/network.h +++ b/src/network.h @@ -54,16 +54,13 @@ struct Server::Peer PartialStorage partStorage; vector<tuple<UUID, shared_ptr<WaitingRef>>> serviceQueue {}; - vector<vector<uint8_t>> secureOutQueue {}; shared_ptr<erebos::Peer::Priv> lpeer = nullptr; - void send(const NetworkProtocol::Header &, const vector<Object> &, bool secure); void updateIdentity(ReplyBuilder &); void updateChannel(ReplyBuilder &); void finalizeChannel(ReplyBuilder &, unique_ptr<Channel>); void updateService(ReplyBuilder &); - void trySendOutQueue(); }; struct Peer::Priv : enable_shared_from_this<Peer::Priv> diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp index 4151bf2..5dc831a 100644 --- a/src/network/protocol.cpp +++ b/src/network/protocol.cpp @@ -4,6 +4,7 @@ #include <unistd.h> #include <cstring> +#include <iostream> #include <mutex> #include <system_error> @@ -26,6 +27,7 @@ struct NetworkProtocol::ConnectionPriv vector<uint8_t> buffer {}; ChannelState channel = monostate(); + vector<vector<uint8_t>> secureOutQueue {}; }; @@ -168,20 +170,79 @@ const sockaddr_in6 & NetworkProtocol::Connection::peerAddress() const return p->peerAddress; } -bool NetworkProtocol::Connection::receive(vector<uint8_t> & buffer) +optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const PartialStorage & partStorage) { - scoped_lock lock(p->cmutex); - if (p->buffer.empty()) - return false; + vector<uint8_t> buf, decrypted; + vector<uint8_t> * current; - buffer.swap(p->buffer); - p->buffer.clear(); - return true; + { + scoped_lock lock(p->cmutex); + + if (p->buffer.empty()) + return nullopt; + + buf.swap(p->buffer); + current = &buf; + + if (holds_alternative<unique_ptr<Channel>>(p->channel)) { + if (auto dec = std::get<unique_ptr<Channel>>(p->channel)->decrypt(buf)) { + decrypted = std::move(*dec); + current = &decrypted; + } + } else if (holds_alternative<Stored<ChannelAccept>>(p->channel)) { + if (auto dec = std::get<Stored<ChannelAccept>>(p->channel)-> + data->channel()->decrypt(buf)) { + decrypted = std::move(*dec); + current = &decrypted; + } + } + } + + if (auto dec = PartialObject::decodePrefix(partStorage, + current->begin(), current->end())) { + if (auto header = Header::load(std::get<PartialObject>(*dec))) { + auto pos = std::get<1>(*dec); + while (auto cdec = PartialObject::decodePrefix(partStorage, + pos, current->end())) { + partStorage.storeObject(std::get<PartialObject>(*cdec)); + pos = std::get<1>(*cdec); + } + + return header; + } + } + + std::cerr << "invalid packet\n"; + return nullopt; } -bool NetworkProtocol::Connection::send(const vector<uint8_t> & buffer) +bool NetworkProtocol::Connection::send(const PartialStorage & partStorage, + const Header & header, + const vector<Object> & objs, bool secure) { - p->protocol->sendto(buffer, p->peerAddress); + vector<uint8_t> data, part, out; + + { + scoped_lock clock(p->cmutex); + + part = header.toObject(partStorage).encode(); + data.insert(data.end(), part.begin(), part.end()); + for (const auto & obj : objs) { + part = obj.encode(); + data.insert(data.end(), part.begin(), part.end()); + } + + if (holds_alternative<unique_ptr<Channel>>(p->channel)) + out = std::get<unique_ptr<Channel>>(p->channel)->encrypt(data); + else if (secure) + p->secureOutQueue.emplace_back(move(data)); + else + out = std::move(data); + } + + if (not out.empty()) + p->protocol->sendto(out, p->peerAddress); + return true; } @@ -209,6 +270,27 @@ NetworkProtocol::ChannelState & NetworkProtocol::Connection::channel() return p->channel; } +void NetworkProtocol::Connection::trySendOutQueue() +{ + decltype(p->secureOutQueue) queue; + { + scoped_lock clock(p->cmutex); + + if (p->secureOutQueue.empty()) + return; + + if (not holds_alternative<unique_ptr<Channel>>(p->channel)) + return; + + queue.swap(p->secureOutQueue); + } + + for (const auto & data : queue) { + auto out = std::get<unique_ptr<Channel>>(p->channel)->encrypt(data); + p->protocol->sendto(out, p->peerAddress); + } +} + /******************************************************************************/ /* Header */ diff --git a/src/network/protocol.h b/src/network/protocol.h index 88abf67..c5803ce 100644 --- a/src/network/protocol.h +++ b/src/network/protocol.h @@ -87,13 +87,15 @@ public: const sockaddr_in6 & peerAddress() const; - bool receive(vector<uint8_t> & buffer); - bool send(const vector<uint8_t> & buffer); + optional<Header> receive(const PartialStorage &); + bool send(const PartialStorage &, const NetworkProtocol::Header &, + const vector<Object> &, bool secure); void close(); // temporary: ChannelState & channel(); + void trySendOutQueue(); private: unique_ptr<ConnectionPriv> p; |