summaryrefslogtreecommitdiff
path: root/src/network/protocol.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/network/protocol.cpp')
-rw-r--r--src/network/protocol.cpp100
1 files changed, 91 insertions, 9 deletions
diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp
index 4151bf2..5dc831a 100644
--- a/src/network/protocol.cpp
+++ b/src/network/protocol.cpp
@@ -4,6 +4,7 @@
#include <unistd.h>
#include <cstring>
+#include <iostream>
#include <mutex>
#include <system_error>
@@ -26,6 +27,7 @@ struct NetworkProtocol::ConnectionPriv
vector<uint8_t> buffer {};
ChannelState channel = monostate();
+ vector<vector<uint8_t>> secureOutQueue {};
};
@@ -168,20 +170,79 @@ const sockaddr_in6 & NetworkProtocol::Connection::peerAddress() const
return p->peerAddress;
}
-bool NetworkProtocol::Connection::receive(vector<uint8_t> & buffer)
+optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const PartialStorage & partStorage)
{
- scoped_lock lock(p->cmutex);
- if (p->buffer.empty())
- return false;
+ vector<uint8_t> buf, decrypted;
+ vector<uint8_t> * current;
- buffer.swap(p->buffer);
- p->buffer.clear();
- return true;
+ {
+ scoped_lock lock(p->cmutex);
+
+ if (p->buffer.empty())
+ return nullopt;
+
+ buf.swap(p->buffer);
+ current = &buf;
+
+ if (holds_alternative<unique_ptr<Channel>>(p->channel)) {
+ if (auto dec = std::get<unique_ptr<Channel>>(p->channel)->decrypt(buf)) {
+ decrypted = std::move(*dec);
+ current = &decrypted;
+ }
+ } else if (holds_alternative<Stored<ChannelAccept>>(p->channel)) {
+ if (auto dec = std::get<Stored<ChannelAccept>>(p->channel)->
+ data->channel()->decrypt(buf)) {
+ decrypted = std::move(*dec);
+ current = &decrypted;
+ }
+ }
+ }
+
+ if (auto dec = PartialObject::decodePrefix(partStorage,
+ current->begin(), current->end())) {
+ if (auto header = Header::load(std::get<PartialObject>(*dec))) {
+ auto pos = std::get<1>(*dec);
+ while (auto cdec = PartialObject::decodePrefix(partStorage,
+ pos, current->end())) {
+ partStorage.storeObject(std::get<PartialObject>(*cdec));
+ pos = std::get<1>(*cdec);
+ }
+
+ return header;
+ }
+ }
+
+ std::cerr << "invalid packet\n";
+ return nullopt;
}
-bool NetworkProtocol::Connection::send(const vector<uint8_t> & buffer)
+bool NetworkProtocol::Connection::send(const PartialStorage & partStorage,
+ const Header & header,
+ const vector<Object> & objs, bool secure)
{
- p->protocol->sendto(buffer, p->peerAddress);
+ vector<uint8_t> data, part, out;
+
+ {
+ scoped_lock clock(p->cmutex);
+
+ part = header.toObject(partStorage).encode();
+ data.insert(data.end(), part.begin(), part.end());
+ for (const auto & obj : objs) {
+ part = obj.encode();
+ data.insert(data.end(), part.begin(), part.end());
+ }
+
+ if (holds_alternative<unique_ptr<Channel>>(p->channel))
+ out = std::get<unique_ptr<Channel>>(p->channel)->encrypt(data);
+ else if (secure)
+ p->secureOutQueue.emplace_back(move(data));
+ else
+ out = std::move(data);
+ }
+
+ if (not out.empty())
+ p->protocol->sendto(out, p->peerAddress);
+
return true;
}
@@ -209,6 +270,27 @@ NetworkProtocol::ChannelState & NetworkProtocol::Connection::channel()
return p->channel;
}
+void NetworkProtocol::Connection::trySendOutQueue()
+{
+ decltype(p->secureOutQueue) queue;
+ {
+ scoped_lock clock(p->cmutex);
+
+ if (p->secureOutQueue.empty())
+ return;
+
+ if (not holds_alternative<unique_ptr<Channel>>(p->channel))
+ return;
+
+ queue.swap(p->secureOutQueue);
+ }
+
+ for (const auto & data : queue) {
+ auto out = std::get<unique_ptr<Channel>>(p->channel)->encrypt(data);
+ p->protocol->sendto(out, p->peerAddress);
+ }
+}
+
/******************************************************************************/
/* Header */