summaryrefslogtreecommitdiff
path: root/src/network.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/network.cpp')
-rw-r--r--src/network.cpp153
1 files changed, 128 insertions, 25 deletions
diff --git a/src/network.cpp b/src/network.cpp
index bd1ea8e..e9aeb3f 100644
--- a/src/network.cpp
+++ b/src/network.cpp
@@ -4,6 +4,7 @@
#include <algorithm>
#include <cstring>
+#include <iostream>
#include <ifaddrs.h>
#include <net/if.h>
@@ -77,7 +78,7 @@ Server::Priv::~Priv()
void Server::Priv::doListen()
{
- vector<uint8_t> buf(4096);
+ vector<uint8_t> buf, decrypted, *current;
unique_lock<mutex> lock(dataMutex);
while (!finish) {
@@ -85,25 +86,51 @@ void Server::Priv::doListen()
lock.unlock();
socklen_t addrlen = sizeof(paddr);
+ buf.resize(4096);
ssize_t ret = recvfrom(sock, buf.data(), buf.size(), 0,
(sockaddr *) &paddr, &addrlen);
if (ret < 0)
throw std::system_error(errno, std::generic_category());
+ buf.resize(ret);
auto & peer = getPeer(paddr);
+
+ current = &buf;
+ if (holds_alternative<Stored<Channel>>(peer.channel)) {
+ if (auto dec = std::get<Stored<Channel>>(peer.channel)->decrypt(buf)) {
+ decrypted = std::move(*dec);
+ current = &decrypted;
+ }
+ } else if (holds_alternative<Stored<ChannelAccept>>(peer.channel)) {
+ if (auto dec = std::get<Stored<ChannelAccept>>(peer.channel)->
+ data->channel()->decrypt(buf)) {
+ decrypted = std::move(*dec);
+ current = &decrypted;
+ }
+ }
+
if (auto dec = PartialObject::decodePrefix(peer.partStorage,
- buf.begin(), buf.begin() + ret)) {
+ current->begin(), current->end())) {
if (auto header = TransportHeader::load(std::get<PartialObject>(*dec))) {
auto pos = std::get<1>(*dec);
while (auto cdec = PartialObject::decodePrefix(peer.partStorage,
- pos, buf.begin() + ret)) {
+ pos, current->end())) {
peer.partStorage.storeObject(std::get<PartialObject>(*cdec));
pos = std::get<1>(*cdec);
}
+
+ ReplyBuilder reply;
+
scoped_lock<mutex> hlock(dataMutex);
- handlePacket(peer, *header);
- peer.updateIdentity();
+ handlePacket(peer, *header, reply);
+ peer.updateIdentity(reply);
+ peer.updateChannel(reply);
+
+ if (!reply.header.empty())
+ peer.send(TransportHeader(reply.header), reply.body);
}
+ } else {
+ std::cerr << "invalid packet\n";
}
lock.lock();
@@ -140,7 +167,7 @@ void Server::Priv::doAnnounce()
}
}
-Peer & Server::Priv::getPeer(const sockaddr_in & paddr)
+Server::Peer & Server::Priv::getPeer(const sockaddr_in & paddr)
{
for (auto & peer : peers)
if (memcmp(&peer->addr, &paddr, sizeof paddr) == 0)
@@ -148,9 +175,10 @@ Peer & Server::Priv::getPeer(const sockaddr_in & paddr)
auto st = self.ref()->storage().deriveEphemeralStorage();
Peer * peer = new Peer {
- .sock = sock,
+ .server = *this,
.addr = paddr,
.identity = monostate(),
+ .channel = monostate(),
.tempStorage = st,
.partStorage = st.derivePartialStorage(),
};
@@ -158,18 +186,21 @@ Peer & Server::Priv::getPeer(const sockaddr_in & paddr)
return *peer;
}
-void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header)
+void Server::Priv::handlePacket(Server::Peer & peer, const TransportHeader & header, ReplyBuilder & reply)
{
unordered_set<Digest> plaintextRefs;
for (const auto & obj : collectStoredObjects(*Stored<Object>::load(*self.ref())))
plaintextRefs.insert(obj.ref.digest());
- vector<TransportHeader::Item> replyHeaders;
- vector<Object> replyBody;
-
for (auto & item : header.items) {
switch (item.type) {
case TransportHeader::Type::Acknowledged:
+ if (auto pref = std::get<PartialRef>(item.value)) {
+ if (holds_alternative<Stored<ChannelAccept>>(peer.channel) &&
+ std::get<Stored<ChannelAccept>>(peer.channel).ref.digest() == pref.digest())
+ peer.channel.emplace<Stored<Channel>>
+ (std::get<Stored<ChannelAccept>>(peer.channel)->data->channel());
+ }
break;
case TransportHeader::Type::DataRequest: {
@@ -177,8 +208,8 @@ void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header)
if (plaintextRefs.find(pref.digest()) != plaintextRefs.end()) {
if (auto ref = peer.tempStorage.ref(pref.digest())) {
TransportHeader::Item hitem { TransportHeader::Type::DataResponse, *ref };
- replyHeaders.push_back({ TransportHeader::Type::DataResponse, *ref });
- replyBody.push_back(**ref);
+ reply.header.push_back({ TransportHeader::Type::DataResponse, *ref });
+ reply.body.push_back(**ref);
}
}
break;
@@ -186,12 +217,12 @@ void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header)
case TransportHeader::Type::DataResponse:
if (auto pref = std::get<PartialRef>(item.value)) {
- replyHeaders.push_back({ TransportHeader::Type::Acknowledged, pref });
+ reply.header.push_back({ TransportHeader::Type::Acknowledged, pref });
for (auto & pwref : waiting) {
if (auto wref = pwref.lock()) {
if (std::find(wref->missing.begin(), wref->missing.end(), pref.digest()) !=
wref->missing.end()) {
- if (wref->check(&replyHeaders))
+ if (wref->check(&reply.header))
pwref.reset();
}
}
@@ -207,7 +238,7 @@ void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header)
break;
if (holds_alternative<monostate>(peer.identity))
- replyHeaders.push_back({ TransportHeader::Type::AnnounceSelf, *self.ref()});
+ reply.header.push_back({ TransportHeader::Type::AnnounceSelf, *self.ref()});
shared_ptr<WaitingRef> wref(new WaitingRef {
.storage = peer.tempStorage,
@@ -217,7 +248,7 @@ void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header)
});
waiting.push_back(wref);
peer.identity = wref;
- wref->check(&replyHeaders);
+ wref->check(&reply.header);
break;
}
@@ -225,9 +256,42 @@ void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header)
break;
case TransportHeader::Type::ChannelRequest:
+ if (auto pref = std::get<PartialRef>(item.value)) {
+ reply.header.push_back({ TransportHeader::Type::Acknowledged, pref });
+
+ if (holds_alternative<Stored<ChannelRequest>>(peer.channel) &&
+ std::get<Stored<ChannelRequest>>(peer.channel).ref.digest() < pref.digest())
+ break;
+
+ if (holds_alternative<Stored<ChannelAccept>>(peer.channel))
+ break;
+
+ shared_ptr<WaitingRef> wref(new WaitingRef {
+ .storage = peer.tempStorage,
+ .ref = pref,
+ .peer = peer,
+ .missing = {},
+ });
+ waiting.push_back(wref);
+ peer.channel = wref;
+ wref->check(&reply.header);
+ }
break;
case TransportHeader::Type::ChannelAccept:
+ if (auto pref = std::get<PartialRef>(item.value)) {
+ if (holds_alternative<Stored<ChannelAccept>>(peer.channel) &&
+ std::get<Stored<ChannelAccept>>(peer.channel).ref.digest() < pref.digest())
+ break;
+
+ auto cres = peer.tempStorage.copy(pref);
+ if (auto r = std::get_if<Ref>(&cres)) {
+ if (auto acc = ChannelAccept::load(*r)) {
+ reply.header.push_back({ TransportHeader::Type::Acknowledged, pref });
+ peer.channel.emplace<Stored<Channel>>(acc->data->channel());
+ }
+ }
+ }
break;
case TransportHeader::Type::ServiceType:
@@ -238,14 +302,11 @@ void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header)
}
}
-
- if (!replyHeaders.empty())
- peer.send(TransportHeader(replyHeaders), replyBody);
}
-void Peer::send(const TransportHeader & header, const vector<Object> & objs) const
+void Server::Peer::send(const TransportHeader & header, const vector<Object> & objs) const
{
- vector<uint8_t> data, part;
+ vector<uint8_t> data, part, out;
part = header.toObject().encode();
data.insert(data.end(), part.begin(), part.end());
@@ -254,18 +315,60 @@ void Peer::send(const TransportHeader & header, const vector<Object> & objs) con
data.insert(data.end(), part.begin(), part.end());
}
- sendto(sock, data.data(), data.size(), 0,
+ if (holds_alternative<Stored<Channel>>(channel))
+ out = std::get<Stored<Channel>>(channel)->encrypt(data);
+ else
+ out = std::move(data);
+
+ sendto(server.sock, out.data(), out.size(), 0,
(sockaddr *) &addr, sizeof(addr));
}
-void Peer::updateIdentity()
+void Server::Peer::updateIdentity(ReplyBuilder & reply)
{
if (holds_alternative<shared_ptr<WaitingRef>>(identity))
- if (auto ref = std::get<shared_ptr<WaitingRef>>(identity)->check())
+ if (auto ref = std::get<shared_ptr<WaitingRef>>(identity)->check(&reply.header))
if (auto id = Identity::load(*ref))
identity.emplace<Identity>(*id);
}
+void Server::Peer::updateChannel(ReplyBuilder & reply)
+{
+ if (!holds_alternative<Identity>(identity))
+ return;
+
+ if (holds_alternative<monostate>(channel)) {
+ auto req = Channel::generateRequest(tempStorage,
+ server.self, std::get<Identity>(identity));
+ channel.emplace<Stored<ChannelRequest>>(req);
+ reply.header.push_back({ TransportHeader::Type::ChannelRequest, req.ref });
+ reply.body.push_back(*req.ref);
+ reply.body.push_back(*req->data.ref);
+ reply.body.push_back(*req->data->key.ref);
+ for (const auto & sig : req->sigs)
+ reply.body.push_back(*sig.ref);
+ }
+
+ if (holds_alternative<shared_ptr<WaitingRef>>(channel)) {
+ if (auto ref = std::get<shared_ptr<WaitingRef>>(channel)->check(&reply.header)) {
+ if (auto req = Stored<ChannelRequest>::load(*ref)) {
+ if (auto acc = Channel::acceptRequest(server.self, std::get<Identity>(identity), *req)) {
+ channel.emplace<Stored<ChannelAccept>>(*acc);
+ reply.header.push_back({ TransportHeader::Type::ChannelAccept, acc->ref });
+ reply.body.push_back(*acc->ref);
+ reply.body.push_back(*acc.value()->data.ref);
+ reply.body.push_back(*acc.value()->data->key.ref);
+ for (const auto & sig : acc.value()->sigs)
+ reply.body.push_back(*sig.ref);
+ } else {
+ channel = monostate();
+ }
+ } else {
+ channel = monostate();
+ }
+ }
+ }
+}
optional<Ref> WaitingRef::check(vector<TransportHeader::Item> * request)
{