From 401f8c1288842b7479c375fba4aed55f6c5d52e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roman=20Smr=C5=BE?= Date: Sat, 19 Aug 2023 11:22:02 +0200 Subject: Network: encrypt and decrypt within connection object --- src/network/protocol.cpp | 100 ++++++++++++++++++++++++++++++++++++++++++----- src/network/protocol.h | 6 ++- 2 files changed, 95 insertions(+), 11 deletions(-) (limited to 'src/network') diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp index 4151bf2..5dc831a 100644 --- a/src/network/protocol.cpp +++ b/src/network/protocol.cpp @@ -4,6 +4,7 @@ #include #include +#include #include #include @@ -26,6 +27,7 @@ struct NetworkProtocol::ConnectionPriv vector buffer {}; ChannelState channel = monostate(); + vector> secureOutQueue {}; }; @@ -168,20 +170,79 @@ const sockaddr_in6 & NetworkProtocol::Connection::peerAddress() const return p->peerAddress; } -bool NetworkProtocol::Connection::receive(vector & buffer) +optional NetworkProtocol::Connection::receive(const PartialStorage & partStorage) { - scoped_lock lock(p->cmutex); - if (p->buffer.empty()) - return false; + vector buf, decrypted; + vector * current; - buffer.swap(p->buffer); - p->buffer.clear(); - return true; + { + scoped_lock lock(p->cmutex); + + if (p->buffer.empty()) + return nullopt; + + buf.swap(p->buffer); + current = &buf; + + if (holds_alternative>(p->channel)) { + if (auto dec = std::get>(p->channel)->decrypt(buf)) { + decrypted = std::move(*dec); + current = &decrypted; + } + } else if (holds_alternative>(p->channel)) { + if (auto dec = std::get>(p->channel)-> + data->channel()->decrypt(buf)) { + decrypted = std::move(*dec); + current = &decrypted; + } + } + } + + if (auto dec = PartialObject::decodePrefix(partStorage, + current->begin(), current->end())) { + if (auto header = Header::load(std::get(*dec))) { + auto pos = std::get<1>(*dec); + while (auto cdec = PartialObject::decodePrefix(partStorage, + pos, current->end())) { + partStorage.storeObject(std::get(*cdec)); + pos = std::get<1>(*cdec); + } + + return header; + } + } + + std::cerr << "invalid packet\n"; + return nullopt; } -bool NetworkProtocol::Connection::send(const vector & buffer) +bool NetworkProtocol::Connection::send(const PartialStorage & partStorage, + const Header & header, + const vector & objs, bool secure) { - p->protocol->sendto(buffer, p->peerAddress); + vector data, part, out; + + { + scoped_lock clock(p->cmutex); + + part = header.toObject(partStorage).encode(); + data.insert(data.end(), part.begin(), part.end()); + for (const auto & obj : objs) { + part = obj.encode(); + data.insert(data.end(), part.begin(), part.end()); + } + + if (holds_alternative>(p->channel)) + out = std::get>(p->channel)->encrypt(data); + else if (secure) + p->secureOutQueue.emplace_back(move(data)); + else + out = std::move(data); + } + + if (not out.empty()) + p->protocol->sendto(out, p->peerAddress); + return true; } @@ -209,6 +270,27 @@ NetworkProtocol::ChannelState & NetworkProtocol::Connection::channel() return p->channel; } +void NetworkProtocol::Connection::trySendOutQueue() +{ + decltype(p->secureOutQueue) queue; + { + scoped_lock clock(p->cmutex); + + if (p->secureOutQueue.empty()) + return; + + if (not holds_alternative>(p->channel)) + return; + + queue.swap(p->secureOutQueue); + } + + for (const auto & data : queue) { + auto out = std::get>(p->channel)->encrypt(data); + p->protocol->sendto(out, p->peerAddress); + } +} + /******************************************************************************/ /* Header */ diff --git a/src/network/protocol.h b/src/network/protocol.h index 88abf67..c5803ce 100644 --- a/src/network/protocol.h +++ b/src/network/protocol.h @@ -87,13 +87,15 @@ public: const sockaddr_in6 & peerAddress() const; - bool receive(vector & buffer); - bool send(const vector & buffer); + optional
receive(const PartialStorage &); + bool send(const PartialStorage &, const NetworkProtocol::Header &, + const vector &, bool secure); void close(); // temporary: ChannelState & channel(); + void trySendOutQueue(); private: unique_ptr p; -- cgit v1.2.3