diff options
Diffstat (limited to 'src/network/protocol.cpp')
-rw-r--r-- | src/network/protocol.cpp | 100 |
1 files changed, 91 insertions, 9 deletions
diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp index 4151bf2..5dc831a 100644 --- a/src/network/protocol.cpp +++ b/src/network/protocol.cpp @@ -4,6 +4,7 @@ #include <unistd.h> #include <cstring> +#include <iostream> #include <mutex> #include <system_error> @@ -26,6 +27,7 @@ struct NetworkProtocol::ConnectionPriv vector<uint8_t> buffer {}; ChannelState channel = monostate(); + vector<vector<uint8_t>> secureOutQueue {}; }; @@ -168,20 +170,79 @@ const sockaddr_in6 & NetworkProtocol::Connection::peerAddress() const return p->peerAddress; } -bool NetworkProtocol::Connection::receive(vector<uint8_t> & buffer) +optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const PartialStorage & partStorage) { - scoped_lock lock(p->cmutex); - if (p->buffer.empty()) - return false; + vector<uint8_t> buf, decrypted; + vector<uint8_t> * current; - buffer.swap(p->buffer); - p->buffer.clear(); - return true; + { + scoped_lock lock(p->cmutex); + + if (p->buffer.empty()) + return nullopt; + + buf.swap(p->buffer); + current = &buf; + + if (holds_alternative<unique_ptr<Channel>>(p->channel)) { + if (auto dec = std::get<unique_ptr<Channel>>(p->channel)->decrypt(buf)) { + decrypted = std::move(*dec); + current = &decrypted; + } + } else if (holds_alternative<Stored<ChannelAccept>>(p->channel)) { + if (auto dec = std::get<Stored<ChannelAccept>>(p->channel)-> + data->channel()->decrypt(buf)) { + decrypted = std::move(*dec); + current = &decrypted; + } + } + } + + if (auto dec = PartialObject::decodePrefix(partStorage, + current->begin(), current->end())) { + if (auto header = Header::load(std::get<PartialObject>(*dec))) { + auto pos = std::get<1>(*dec); + while (auto cdec = PartialObject::decodePrefix(partStorage, + pos, current->end())) { + partStorage.storeObject(std::get<PartialObject>(*cdec)); + pos = std::get<1>(*cdec); + } + + return header; + } + } + + std::cerr << "invalid packet\n"; + return nullopt; } -bool NetworkProtocol::Connection::send(const vector<uint8_t> & buffer) +bool NetworkProtocol::Connection::send(const PartialStorage & partStorage, + const Header & header, + const vector<Object> & objs, bool secure) { - p->protocol->sendto(buffer, p->peerAddress); + vector<uint8_t> data, part, out; + + { + scoped_lock clock(p->cmutex); + + part = header.toObject(partStorage).encode(); + data.insert(data.end(), part.begin(), part.end()); + for (const auto & obj : objs) { + part = obj.encode(); + data.insert(data.end(), part.begin(), part.end()); + } + + if (holds_alternative<unique_ptr<Channel>>(p->channel)) + out = std::get<unique_ptr<Channel>>(p->channel)->encrypt(data); + else if (secure) + p->secureOutQueue.emplace_back(move(data)); + else + out = std::move(data); + } + + if (not out.empty()) + p->protocol->sendto(out, p->peerAddress); + return true; } @@ -209,6 +270,27 @@ NetworkProtocol::ChannelState & NetworkProtocol::Connection::channel() return p->channel; } +void NetworkProtocol::Connection::trySendOutQueue() +{ + decltype(p->secureOutQueue) queue; + { + scoped_lock clock(p->cmutex); + + if (p->secureOutQueue.empty()) + return; + + if (not holds_alternative<unique_ptr<Channel>>(p->channel)) + return; + + queue.swap(p->secureOutQueue); + } + + for (const auto & data : queue) { + auto out = std::get<unique_ptr<Channel>>(p->channel)->encrypt(data); + p->protocol->sendto(out, p->peerAddress); + } +} + /******************************************************************************/ /* Header */ |