diff options
Diffstat (limited to 'src/network.cpp')
-rw-r--r-- | src/network.cpp | 153 |
1 files changed, 128 insertions, 25 deletions
diff --git a/src/network.cpp b/src/network.cpp index bd1ea8e..e9aeb3f 100644 --- a/src/network.cpp +++ b/src/network.cpp @@ -4,6 +4,7 @@ #include <algorithm> #include <cstring> +#include <iostream> #include <ifaddrs.h> #include <net/if.h> @@ -77,7 +78,7 @@ Server::Priv::~Priv() void Server::Priv::doListen() { - vector<uint8_t> buf(4096); + vector<uint8_t> buf, decrypted, *current; unique_lock<mutex> lock(dataMutex); while (!finish) { @@ -85,25 +86,51 @@ void Server::Priv::doListen() lock.unlock(); socklen_t addrlen = sizeof(paddr); + buf.resize(4096); ssize_t ret = recvfrom(sock, buf.data(), buf.size(), 0, (sockaddr *) &paddr, &addrlen); if (ret < 0) throw std::system_error(errno, std::generic_category()); + buf.resize(ret); auto & peer = getPeer(paddr); + + current = &buf; + if (holds_alternative<Stored<Channel>>(peer.channel)) { + if (auto dec = std::get<Stored<Channel>>(peer.channel)->decrypt(buf)) { + decrypted = std::move(*dec); + current = &decrypted; + } + } else if (holds_alternative<Stored<ChannelAccept>>(peer.channel)) { + if (auto dec = std::get<Stored<ChannelAccept>>(peer.channel)-> + data->channel()->decrypt(buf)) { + decrypted = std::move(*dec); + current = &decrypted; + } + } + if (auto dec = PartialObject::decodePrefix(peer.partStorage, - buf.begin(), buf.begin() + ret)) { + current->begin(), current->end())) { if (auto header = TransportHeader::load(std::get<PartialObject>(*dec))) { auto pos = std::get<1>(*dec); while (auto cdec = PartialObject::decodePrefix(peer.partStorage, - pos, buf.begin() + ret)) { + pos, current->end())) { peer.partStorage.storeObject(std::get<PartialObject>(*cdec)); pos = std::get<1>(*cdec); } + + ReplyBuilder reply; + scoped_lock<mutex> hlock(dataMutex); - handlePacket(peer, *header); - peer.updateIdentity(); + handlePacket(peer, *header, reply); + peer.updateIdentity(reply); + peer.updateChannel(reply); + + if (!reply.header.empty()) + peer.send(TransportHeader(reply.header), reply.body); } + } else { + std::cerr << "invalid packet\n"; } lock.lock(); @@ -140,7 +167,7 @@ void Server::Priv::doAnnounce() } } -Peer & Server::Priv::getPeer(const sockaddr_in & paddr) +Server::Peer & Server::Priv::getPeer(const sockaddr_in & paddr) { for (auto & peer : peers) if (memcmp(&peer->addr, &paddr, sizeof paddr) == 0) @@ -148,9 +175,10 @@ Peer & Server::Priv::getPeer(const sockaddr_in & paddr) auto st = self.ref()->storage().deriveEphemeralStorage(); Peer * peer = new Peer { - .sock = sock, + .server = *this, .addr = paddr, .identity = monostate(), + .channel = monostate(), .tempStorage = st, .partStorage = st.derivePartialStorage(), }; @@ -158,18 +186,21 @@ Peer & Server::Priv::getPeer(const sockaddr_in & paddr) return *peer; } -void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header) +void Server::Priv::handlePacket(Server::Peer & peer, const TransportHeader & header, ReplyBuilder & reply) { unordered_set<Digest> plaintextRefs; for (const auto & obj : collectStoredObjects(*Stored<Object>::load(*self.ref()))) plaintextRefs.insert(obj.ref.digest()); - vector<TransportHeader::Item> replyHeaders; - vector<Object> replyBody; - for (auto & item : header.items) { switch (item.type) { case TransportHeader::Type::Acknowledged: + if (auto pref = std::get<PartialRef>(item.value)) { + if (holds_alternative<Stored<ChannelAccept>>(peer.channel) && + std::get<Stored<ChannelAccept>>(peer.channel).ref.digest() == pref.digest()) + peer.channel.emplace<Stored<Channel>> + (std::get<Stored<ChannelAccept>>(peer.channel)->data->channel()); + } break; case TransportHeader::Type::DataRequest: { @@ -177,8 +208,8 @@ void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header) if (plaintextRefs.find(pref.digest()) != plaintextRefs.end()) { if (auto ref = peer.tempStorage.ref(pref.digest())) { TransportHeader::Item hitem { TransportHeader::Type::DataResponse, *ref }; - replyHeaders.push_back({ TransportHeader::Type::DataResponse, *ref }); - replyBody.push_back(**ref); + reply.header.push_back({ TransportHeader::Type::DataResponse, *ref }); + reply.body.push_back(**ref); } } break; @@ -186,12 +217,12 @@ void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header) case TransportHeader::Type::DataResponse: if (auto pref = std::get<PartialRef>(item.value)) { - replyHeaders.push_back({ TransportHeader::Type::Acknowledged, pref }); + reply.header.push_back({ TransportHeader::Type::Acknowledged, pref }); for (auto & pwref : waiting) { if (auto wref = pwref.lock()) { if (std::find(wref->missing.begin(), wref->missing.end(), pref.digest()) != wref->missing.end()) { - if (wref->check(&replyHeaders)) + if (wref->check(&reply.header)) pwref.reset(); } } @@ -207,7 +238,7 @@ void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header) break; if (holds_alternative<monostate>(peer.identity)) - replyHeaders.push_back({ TransportHeader::Type::AnnounceSelf, *self.ref()}); + reply.header.push_back({ TransportHeader::Type::AnnounceSelf, *self.ref()}); shared_ptr<WaitingRef> wref(new WaitingRef { .storage = peer.tempStorage, @@ -217,7 +248,7 @@ void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header) }); waiting.push_back(wref); peer.identity = wref; - wref->check(&replyHeaders); + wref->check(&reply.header); break; } @@ -225,9 +256,42 @@ void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header) break; case TransportHeader::Type::ChannelRequest: + if (auto pref = std::get<PartialRef>(item.value)) { + reply.header.push_back({ TransportHeader::Type::Acknowledged, pref }); + + if (holds_alternative<Stored<ChannelRequest>>(peer.channel) && + std::get<Stored<ChannelRequest>>(peer.channel).ref.digest() < pref.digest()) + break; + + if (holds_alternative<Stored<ChannelAccept>>(peer.channel)) + break; + + shared_ptr<WaitingRef> wref(new WaitingRef { + .storage = peer.tempStorage, + .ref = pref, + .peer = peer, + .missing = {}, + }); + waiting.push_back(wref); + peer.channel = wref; + wref->check(&reply.header); + } break; case TransportHeader::Type::ChannelAccept: + if (auto pref = std::get<PartialRef>(item.value)) { + if (holds_alternative<Stored<ChannelAccept>>(peer.channel) && + std::get<Stored<ChannelAccept>>(peer.channel).ref.digest() < pref.digest()) + break; + + auto cres = peer.tempStorage.copy(pref); + if (auto r = std::get_if<Ref>(&cres)) { + if (auto acc = ChannelAccept::load(*r)) { + reply.header.push_back({ TransportHeader::Type::Acknowledged, pref }); + peer.channel.emplace<Stored<Channel>>(acc->data->channel()); + } + } + } break; case TransportHeader::Type::ServiceType: @@ -238,14 +302,11 @@ void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header) } } - - if (!replyHeaders.empty()) - peer.send(TransportHeader(replyHeaders), replyBody); } -void Peer::send(const TransportHeader & header, const vector<Object> & objs) const +void Server::Peer::send(const TransportHeader & header, const vector<Object> & objs) const { - vector<uint8_t> data, part; + vector<uint8_t> data, part, out; part = header.toObject().encode(); data.insert(data.end(), part.begin(), part.end()); @@ -254,18 +315,60 @@ void Peer::send(const TransportHeader & header, const vector<Object> & objs) con data.insert(data.end(), part.begin(), part.end()); } - sendto(sock, data.data(), data.size(), 0, + if (holds_alternative<Stored<Channel>>(channel)) + out = std::get<Stored<Channel>>(channel)->encrypt(data); + else + out = std::move(data); + + sendto(server.sock, out.data(), out.size(), 0, (sockaddr *) &addr, sizeof(addr)); } -void Peer::updateIdentity() +void Server::Peer::updateIdentity(ReplyBuilder & reply) { if (holds_alternative<shared_ptr<WaitingRef>>(identity)) - if (auto ref = std::get<shared_ptr<WaitingRef>>(identity)->check()) + if (auto ref = std::get<shared_ptr<WaitingRef>>(identity)->check(&reply.header)) if (auto id = Identity::load(*ref)) identity.emplace<Identity>(*id); } +void Server::Peer::updateChannel(ReplyBuilder & reply) +{ + if (!holds_alternative<Identity>(identity)) + return; + + if (holds_alternative<monostate>(channel)) { + auto req = Channel::generateRequest(tempStorage, + server.self, std::get<Identity>(identity)); + channel.emplace<Stored<ChannelRequest>>(req); + reply.header.push_back({ TransportHeader::Type::ChannelRequest, req.ref }); + reply.body.push_back(*req.ref); + reply.body.push_back(*req->data.ref); + reply.body.push_back(*req->data->key.ref); + for (const auto & sig : req->sigs) + reply.body.push_back(*sig.ref); + } + + if (holds_alternative<shared_ptr<WaitingRef>>(channel)) { + if (auto ref = std::get<shared_ptr<WaitingRef>>(channel)->check(&reply.header)) { + if (auto req = Stored<ChannelRequest>::load(*ref)) { + if (auto acc = Channel::acceptRequest(server.self, std::get<Identity>(identity), *req)) { + channel.emplace<Stored<ChannelAccept>>(*acc); + reply.header.push_back({ TransportHeader::Type::ChannelAccept, acc->ref }); + reply.body.push_back(*acc->ref); + reply.body.push_back(*acc.value()->data.ref); + reply.body.push_back(*acc.value()->data->key.ref); + for (const auto & sig : acc.value()->sigs) + reply.body.push_back(*sig.ref); + } else { + channel = monostate(); + } + } else { + channel = monostate(); + } + } + } +} optional<Ref> WaitingRef::check(vector<TransportHeader::Item> * request) { |