From ab86a1f0c3b86050e65fc5b7ac1e88a00f0d228c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roman=20Smr=C5=BE?= Date: Mon, 27 Jan 2020 21:25:39 +0100 Subject: Encrypted channels --- src/network.cpp | 153 +++++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 128 insertions(+), 25 deletions(-) (limited to 'src/network.cpp') diff --git a/src/network.cpp b/src/network.cpp index bd1ea8e..e9aeb3f 100644 --- a/src/network.cpp +++ b/src/network.cpp @@ -4,6 +4,7 @@ #include #include +#include #include #include @@ -77,7 +78,7 @@ Server::Priv::~Priv() void Server::Priv::doListen() { - vector buf(4096); + vector buf, decrypted, *current; unique_lock lock(dataMutex); while (!finish) { @@ -85,25 +86,51 @@ void Server::Priv::doListen() lock.unlock(); socklen_t addrlen = sizeof(paddr); + buf.resize(4096); ssize_t ret = recvfrom(sock, buf.data(), buf.size(), 0, (sockaddr *) &paddr, &addrlen); if (ret < 0) throw std::system_error(errno, std::generic_category()); + buf.resize(ret); auto & peer = getPeer(paddr); + + current = &buf; + if (holds_alternative>(peer.channel)) { + if (auto dec = std::get>(peer.channel)->decrypt(buf)) { + decrypted = std::move(*dec); + current = &decrypted; + } + } else if (holds_alternative>(peer.channel)) { + if (auto dec = std::get>(peer.channel)-> + data->channel()->decrypt(buf)) { + decrypted = std::move(*dec); + current = &decrypted; + } + } + if (auto dec = PartialObject::decodePrefix(peer.partStorage, - buf.begin(), buf.begin() + ret)) { + current->begin(), current->end())) { if (auto header = TransportHeader::load(std::get(*dec))) { auto pos = std::get<1>(*dec); while (auto cdec = PartialObject::decodePrefix(peer.partStorage, - pos, buf.begin() + ret)) { + pos, current->end())) { peer.partStorage.storeObject(std::get(*cdec)); pos = std::get<1>(*cdec); } + + ReplyBuilder reply; + scoped_lock hlock(dataMutex); - handlePacket(peer, *header); - peer.updateIdentity(); + handlePacket(peer, *header, reply); + peer.updateIdentity(reply); + peer.updateChannel(reply); + + if (!reply.header.empty()) + peer.send(TransportHeader(reply.header), reply.body); } + } else { + std::cerr << "invalid packet\n"; } lock.lock(); @@ -140,7 +167,7 @@ void Server::Priv::doAnnounce() } } -Peer & Server::Priv::getPeer(const sockaddr_in & paddr) +Server::Peer & Server::Priv::getPeer(const sockaddr_in & paddr) { for (auto & peer : peers) if (memcmp(&peer->addr, &paddr, sizeof paddr) == 0) @@ -148,9 +175,10 @@ Peer & Server::Priv::getPeer(const sockaddr_in & paddr) auto st = self.ref()->storage().deriveEphemeralStorage(); Peer * peer = new Peer { - .sock = sock, + .server = *this, .addr = paddr, .identity = monostate(), + .channel = monostate(), .tempStorage = st, .partStorage = st.derivePartialStorage(), }; @@ -158,18 +186,21 @@ Peer & Server::Priv::getPeer(const sockaddr_in & paddr) return *peer; } -void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header) +void Server::Priv::handlePacket(Server::Peer & peer, const TransportHeader & header, ReplyBuilder & reply) { unordered_set plaintextRefs; for (const auto & obj : collectStoredObjects(*Stored::load(*self.ref()))) plaintextRefs.insert(obj.ref.digest()); - vector replyHeaders; - vector replyBody; - for (auto & item : header.items) { switch (item.type) { case TransportHeader::Type::Acknowledged: + if (auto pref = std::get(item.value)) { + if (holds_alternative>(peer.channel) && + std::get>(peer.channel).ref.digest() == pref.digest()) + peer.channel.emplace> + (std::get>(peer.channel)->data->channel()); + } break; case TransportHeader::Type::DataRequest: { @@ -177,8 +208,8 @@ void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header) if (plaintextRefs.find(pref.digest()) != plaintextRefs.end()) { if (auto ref = peer.tempStorage.ref(pref.digest())) { TransportHeader::Item hitem { TransportHeader::Type::DataResponse, *ref }; - replyHeaders.push_back({ TransportHeader::Type::DataResponse, *ref }); - replyBody.push_back(**ref); + reply.header.push_back({ TransportHeader::Type::DataResponse, *ref }); + reply.body.push_back(**ref); } } break; @@ -186,12 +217,12 @@ void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header) case TransportHeader::Type::DataResponse: if (auto pref = std::get(item.value)) { - replyHeaders.push_back({ TransportHeader::Type::Acknowledged, pref }); + reply.header.push_back({ TransportHeader::Type::Acknowledged, pref }); for (auto & pwref : waiting) { if (auto wref = pwref.lock()) { if (std::find(wref->missing.begin(), wref->missing.end(), pref.digest()) != wref->missing.end()) { - if (wref->check(&replyHeaders)) + if (wref->check(&reply.header)) pwref.reset(); } } @@ -207,7 +238,7 @@ void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header) break; if (holds_alternative(peer.identity)) - replyHeaders.push_back({ TransportHeader::Type::AnnounceSelf, *self.ref()}); + reply.header.push_back({ TransportHeader::Type::AnnounceSelf, *self.ref()}); shared_ptr wref(new WaitingRef { .storage = peer.tempStorage, @@ -217,7 +248,7 @@ void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header) }); waiting.push_back(wref); peer.identity = wref; - wref->check(&replyHeaders); + wref->check(&reply.header); break; } @@ -225,9 +256,42 @@ void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header) break; case TransportHeader::Type::ChannelRequest: + if (auto pref = std::get(item.value)) { + reply.header.push_back({ TransportHeader::Type::Acknowledged, pref }); + + if (holds_alternative>(peer.channel) && + std::get>(peer.channel).ref.digest() < pref.digest()) + break; + + if (holds_alternative>(peer.channel)) + break; + + shared_ptr wref(new WaitingRef { + .storage = peer.tempStorage, + .ref = pref, + .peer = peer, + .missing = {}, + }); + waiting.push_back(wref); + peer.channel = wref; + wref->check(&reply.header); + } break; case TransportHeader::Type::ChannelAccept: + if (auto pref = std::get(item.value)) { + if (holds_alternative>(peer.channel) && + std::get>(peer.channel).ref.digest() < pref.digest()) + break; + + auto cres = peer.tempStorage.copy(pref); + if (auto r = std::get_if(&cres)) { + if (auto acc = ChannelAccept::load(*r)) { + reply.header.push_back({ TransportHeader::Type::Acknowledged, pref }); + peer.channel.emplace>(acc->data->channel()); + } + } + } break; case TransportHeader::Type::ServiceType: @@ -238,14 +302,11 @@ void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header) } } - - if (!replyHeaders.empty()) - peer.send(TransportHeader(replyHeaders), replyBody); } -void Peer::send(const TransportHeader & header, const vector & objs) const +void Server::Peer::send(const TransportHeader & header, const vector & objs) const { - vector data, part; + vector data, part, out; part = header.toObject().encode(); data.insert(data.end(), part.begin(), part.end()); @@ -254,18 +315,60 @@ void Peer::send(const TransportHeader & header, const vector & objs) con data.insert(data.end(), part.begin(), part.end()); } - sendto(sock, data.data(), data.size(), 0, + if (holds_alternative>(channel)) + out = std::get>(channel)->encrypt(data); + else + out = std::move(data); + + sendto(server.sock, out.data(), out.size(), 0, (sockaddr *) &addr, sizeof(addr)); } -void Peer::updateIdentity() +void Server::Peer::updateIdentity(ReplyBuilder & reply) { if (holds_alternative>(identity)) - if (auto ref = std::get>(identity)->check()) + if (auto ref = std::get>(identity)->check(&reply.header)) if (auto id = Identity::load(*ref)) identity.emplace(*id); } +void Server::Peer::updateChannel(ReplyBuilder & reply) +{ + if (!holds_alternative(identity)) + return; + + if (holds_alternative(channel)) { + auto req = Channel::generateRequest(tempStorage, + server.self, std::get(identity)); + channel.emplace>(req); + reply.header.push_back({ TransportHeader::Type::ChannelRequest, req.ref }); + reply.body.push_back(*req.ref); + reply.body.push_back(*req->data.ref); + reply.body.push_back(*req->data->key.ref); + for (const auto & sig : req->sigs) + reply.body.push_back(*sig.ref); + } + + if (holds_alternative>(channel)) { + if (auto ref = std::get>(channel)->check(&reply.header)) { + if (auto req = Stored::load(*ref)) { + if (auto acc = Channel::acceptRequest(server.self, std::get(identity), *req)) { + channel.emplace>(*acc); + reply.header.push_back({ TransportHeader::Type::ChannelAccept, acc->ref }); + reply.body.push_back(*acc->ref); + reply.body.push_back(*acc.value()->data.ref); + reply.body.push_back(*acc.value()->data->key.ref); + for (const auto & sig : acc.value()->sigs) + reply.body.push_back(*sig.ref); + } else { + channel = monostate(); + } + } else { + channel = monostate(); + } + } + } +} optional WaitingRef::check(vector * request) { -- cgit v1.2.3