From b09e73f0abcc386719a2235cc3ae61fb1cbfc5ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roman=20Smr=C5=BE?= Date: Wed, 16 Aug 2023 20:53:58 +0200 Subject: Move network header definitions to protocol module --- src/network.cpp | 229 ++++++++++------------------------------------- src/network.h | 42 ++------- src/network/protocol.cpp | 150 +++++++++++++++++++++++++++++++ src/network/protocol.h | 36 ++++++++ 4 files changed, 237 insertions(+), 220 deletions(-) diff --git a/src/network.cpp b/src/network.cpp index 786e752..ef06d6e 100644 --- a/src/network.cpp +++ b/src/network.cpp @@ -111,11 +111,11 @@ void Server::addPeer(const string & node, const string & service) const if (rp->ai_family == AF_INET6) { Peer & peer = p->getPeer(*(sockaddr_in6 *)rp->ai_addr); - vector header; + vector header; { shared_lock lock(p->selfMutex); - header.push_back(TransportHeader::Item { - TransportHeader::Type::AnnounceSelf, *p->self.ref() }); + header.push_back(NetworkProtocol::Header::Item { + NetworkProtocol::Header::Type::AnnounceSelf, *p->self.ref() }); } peer.send(header, {}, false); return; @@ -232,9 +232,9 @@ bool Peer::send(UUID uuid, const Object & obj) const bool Peer::send(UUID uuid, const Ref & ref, const Object & obj) const { if (auto speer = p->speer.lock()) { - TransportHeader header({ - { TransportHeader::Type::ServiceType, uuid }, - { TransportHeader::Type::ServiceRef, ref }, + NetworkProtocol::Header header({ + { NetworkProtocol::Header::Type::ServiceType, uuid }, + { NetworkProtocol::Header::Type::ServiceRef, ref }, }); speer->send(header, { obj }, true); return true; @@ -411,7 +411,7 @@ void Server::Priv::doListen() if (auto dec = PartialObject::decodePrefix(peer->partStorage, current->begin(), current->end())) { - if (auto header = TransportHeader::load(std::get(*dec))) { + if (auto header = NetworkProtocol::Header::load(std::get(*dec))) { auto pos = std::get<1>(*dec); while (auto cdec = PartialObject::decodePrefix(peer->partStorage, pos, current->end())) { @@ -430,7 +430,7 @@ void Server::Priv::doListen() peer->updateService(reply); if (!reply.header().empty()) - peer->send(TransportHeader(reply.header()), reply.body(), false); + peer->send(NetworkProtocol::Header(reply.header()), reply.body(), false); peer->trySendOutQueue(); } @@ -450,8 +450,8 @@ void Server::Priv::doAnnounce() if (lastAnnounce + announceInterval < now) { shared_lock slock(selfMutex); - TransportHeader header({ - { TransportHeader::Type::AnnounceSelf, *self.ref() } + NetworkProtocol::Header header({ + { NetworkProtocol::Header::Type::AnnounceSelf, *self.ref() } }); vector bytes = header.toObject().encode(); @@ -534,7 +534,7 @@ Server::Peer & Server::Priv::addPeer(NetworkProtocol::Connection conn) return *peer; } -void Server::Priv::handlePacket(Server::Peer & peer, const TransportHeader & header, ReplyBuilder & reply) +void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Header & header, ReplyBuilder & reply) { unordered_set plaintextRefs; for (const auto & obj : collectStoredObjects(Stored::load(*self.ref()))) @@ -544,7 +544,7 @@ void Server::Priv::handlePacket(Server::Peer & peer, const TransportHeader & hea for (auto & item : header.items) { switch (item.type) { - case TransportHeader::Type::Acknowledged: + case NetworkProtocol::Header::Type::Acknowledged: if (auto pref = std::get(item.value)) { if (holds_alternative>(peer.channel) && std::get>(peer.channel).ref().digest() == pref.digest()) @@ -553,22 +553,22 @@ void Server::Priv::handlePacket(Server::Peer & peer, const TransportHeader & hea } break; - case TransportHeader::Type::DataRequest: { + case NetworkProtocol::Header::Type::DataRequest: { auto pref = std::get(item.value); if (holds_alternative>(peer.channel) || plaintextRefs.find(pref.digest()) != plaintextRefs.end()) { if (auto ref = peer.tempStorage.ref(pref.digest())) { - TransportHeader::Item hitem { TransportHeader::Type::DataResponse, *ref }; - reply.header({ TransportHeader::Type::DataResponse, *ref }); + NetworkProtocol::Header::Item hitem { NetworkProtocol::Header::Type::DataResponse, *ref }; + reply.header({ NetworkProtocol::Header::Type::DataResponse, *ref }); reply.body(*ref); } } break; } - case TransportHeader::Type::DataResponse: + case NetworkProtocol::Header::Type::DataResponse: if (auto pref = std::get(item.value)) { - reply.header({ TransportHeader::Type::Acknowledged, pref }); + reply.header({ NetworkProtocol::Header::Type::Acknowledged, pref }); for (auto & pwref : waiting) { if (auto wref = pwref.lock()) { if (std::find(wref->missing.begin(), wref->missing.end(), pref.digest()) != @@ -583,13 +583,13 @@ void Server::Priv::handlePacket(Server::Peer & peer, const TransportHeader & hea } break; - case TransportHeader::Type::AnnounceSelf: { + case NetworkProtocol::Header::Type::AnnounceSelf: { auto pref = std::get(item.value); if (pref.digest() == self.ref()->digest()) break; if (holds_alternative(peer.identity)) { - reply.header({ TransportHeader::Type::AnnounceSelf, *self.ref()}); + reply.header({ NetworkProtocol::Header::Type::AnnounceSelf, *self.ref()}); shared_ptr wref(new WaitingRef { .storage = peer.tempStorage, @@ -604,10 +604,10 @@ void Server::Priv::handlePacket(Server::Peer & peer, const TransportHeader & hea break; } - case TransportHeader::Type::AnnounceUpdate: + case NetworkProtocol::Header::Type::AnnounceUpdate: if (holds_alternative(peer.identity)) { auto pref = std::get(item.value); - reply.header({ TransportHeader::Type::Acknowledged, pref }); + reply.header({ NetworkProtocol::Header::Type::Acknowledged, pref }); shared_ptr wref(new WaitingRef { .storage = peer.tempStorage, @@ -621,9 +621,9 @@ void Server::Priv::handlePacket(Server::Peer & peer, const TransportHeader & hea } break; - case TransportHeader::Type::ChannelRequest: + case NetworkProtocol::Header::Type::ChannelRequest: if (auto pref = std::get(item.value)) { - reply.header({ TransportHeader::Type::Acknowledged, pref }); + reply.header({ NetworkProtocol::Header::Type::Acknowledged, pref }); if (holds_alternative>(peer.channel) && std::get>(peer.channel).ref().digest() < pref.digest()) @@ -644,7 +644,7 @@ void Server::Priv::handlePacket(Server::Peer & peer, const TransportHeader & hea } break; - case TransportHeader::Type::ChannelAccept: + case NetworkProtocol::Header::Type::ChannelAccept: if (auto pref = std::get(item.value)) { if (holds_alternative>(peer.channel) && std::get>(peer.channel).ref().digest() < pref.digest()) @@ -655,22 +655,22 @@ void Server::Priv::handlePacket(Server::Peer & peer, const TransportHeader & hea auto acc = ChannelAccept::load(*r); if (holds_alternative(peer.identity) && acc.isSignedBy(std::get(peer.identity).keyMessage())) { - reply.header({ TransportHeader::Type::Acknowledged, pref }); + reply.header({ NetworkProtocol::Header::Type::Acknowledged, pref }); peer.finalizeChannel(reply, acc.data->channel()); } } } break; - case TransportHeader::Type::ServiceType: + case NetworkProtocol::Header::Type::ServiceType: if (!serviceType) serviceType = std::get(item.value); break; - case TransportHeader::Type::ServiceRef: + case NetworkProtocol::Header::Type::ServiceRef: if (!serviceType) for (auto & item : header.items) - if (item.type == TransportHeader::Type::ServiceType) { + if (item.type == NetworkProtocol::Header::Type::ServiceType) { serviceType = std::get(item.value); break; } @@ -679,7 +679,7 @@ void Server::Priv::handlePacket(Server::Peer & peer, const TransportHeader & hea auto pref = std::get(item.value); if (pref) - reply.header({ TransportHeader::Type::Acknowledged, pref }); + reply.header({ NetworkProtocol::Header::Type::Acknowledged, pref }); shared_ptr wref(new WaitingRef { .storage = peer.tempStorage, @@ -703,15 +703,15 @@ void Server::Priv::handleLocalHeadChange(const Head & head) if (*id != self) { self = *id; - vector hitems; + vector hitems; for (const auto & r : self.refs()) - hitems.push_back(TransportHeader::Item { - TransportHeader::Type::AnnounceUpdate, r }); + hitems.push_back(NetworkProtocol::Header::Item { + NetworkProtocol::Header::Type::AnnounceUpdate, r }); for (const auto & r : self.updates()) - hitems.push_back(TransportHeader::Item { - TransportHeader::Type::AnnounceUpdate, r }); + hitems.push_back(NetworkProtocol::Header::Item { + NetworkProtocol::Header::Type::AnnounceUpdate, r }); - TransportHeader header(hitems); + NetworkProtocol::Header header(hitems); for (const auto & peer : peers) peer->send(header, { **self.ref() }, false); @@ -719,7 +719,7 @@ void Server::Priv::handleLocalHeadChange(const Head & head) } } -void Server::Peer::send(const TransportHeader & header, const vector & objs, bool secure) +void Server::Peer::send(const NetworkProtocol::Header & header, const vector & objs, bool secure) { vector data, part, out; @@ -786,7 +786,7 @@ void Server::Peer::updateChannel(ReplyBuilder & reply) auto req = Channel::generateRequest(tempStorage, server.self, std::get(identity)); channel.emplace>(req); - reply.header({ TransportHeader::Type::ChannelRequest, req.ref() }); + reply.header({ NetworkProtocol::Header::Type::ChannelRequest, req.ref() }); reply.body(req.ref()); reply.body(req->data.ref()); reply.body(req->data->key.ref()); @@ -801,7 +801,7 @@ void Server::Peer::updateChannel(ReplyBuilder & reply) req->isSignedBy(std::get(identity).keyMessage())) { if (auto acc = Channel::acceptRequest(server.self, std::get(identity), req)) { channel.emplace>(*acc); - reply.header({ TransportHeader::Type::ChannelAccept, acc->ref() }); + reply.header({ NetworkProtocol::Header::Type::ChannelAccept, acc->ref() }); reply.body(acc->ref()); reply.body(acc.value()->data.ref()); reply.body(acc.value()->data->key.ref()); @@ -821,13 +821,13 @@ void Server::Peer::finalizeChannel(ReplyBuilder & reply, unique_ptr ch) { channel.emplace>(move(ch)); - vector hitems; + vector hitems; for (const auto & r : server.self.refs()) - reply.header(TransportHeader::Item { - TransportHeader::Type::AnnounceUpdate, r }); + reply.header(NetworkProtocol::Header::Item { + NetworkProtocol::Header::Type::AnnounceUpdate, r }); for (const auto & r : server.self.updates()) - reply.header(TransportHeader::Item { - TransportHeader::Type::AnnounceUpdate, r }); + reply.header(NetworkProtocol::Header::Item { + NetworkProtocol::Header::Type::AnnounceUpdate, r }); } void Server::Peer::updateService(ReplyBuilder & reply) @@ -881,7 +881,7 @@ void Server::Peer::trySendOutQueue() } -void ReplyBuilder::header(TransportHeader::Item && item) +void ReplyBuilder::header(NetworkProtocol::Header::Item && item) { for (const auto & x : mheader) if (x == item) @@ -926,146 +926,7 @@ optional WaitingRef::check(ReplyBuilder & reply) return r; for (const auto & d : missing) - reply.header({ TransportHeader::Type::DataRequest, peer.partStorage.ref(d) }); + reply.header({ NetworkProtocol::Header::Type::DataRequest, peer.partStorage.ref(d) }); return nullopt; } - - -bool TransportHeader::Item::operator==(const Item & other) const -{ - if (type != other.type) - return false; - - if (value.index() != other.value.index()) - return false; - - if (holds_alternative(value)) - return std::get(value).digest() == - std::get(other.value).digest(); - - if (holds_alternative(value)) - return std::get(value) == std::get(other.value); - - throw runtime_error("unhandled transport header item type"); -} - -optional TransportHeader::load(const PartialRef & ref) -{ - return load(*ref); -} - -optional TransportHeader::load(const PartialObject & obj) -{ - auto rec = obj.asRecord(); - if (!rec) - return nullopt; - - vector items; - for (const auto & item : rec->items()) { - if (item.name == "ACK") { - if (auto ref = item.asRef()) - items.emplace_back(Item { - .type = Type::Acknowledged, - .value = *ref, - }); - } else if (item.name == "REQ") { - if (auto ref = item.asRef()) - items.emplace_back(Item { - .type = Type::DataRequest, - .value = *ref, - }); - } else if (item.name == "RSP") { - if (auto ref = item.asRef()) - items.emplace_back(Item { - .type = Type::DataResponse, - .value = *ref, - }); - } else if (item.name == "ANN") { - if (auto ref = item.asRef()) - items.emplace_back(Item { - .type = Type::AnnounceSelf, - .value = *ref, - }); - } else if (item.name == "ANU") { - if (auto ref = item.asRef()) - items.emplace_back(Item { - .type = Type::AnnounceUpdate, - .value = *ref, - }); - } else if (item.name == "CRQ") { - if (auto ref = item.asRef()) - items.emplace_back(Item { - .type = Type::ChannelRequest, - .value = *ref, - }); - } else if (item.name == "CAC") { - if (auto ref = item.asRef()) - items.emplace_back(Item { - .type = Type::ChannelAccept, - .value = *ref, - }); - } else if (item.name == "STP") { - if (auto val = item.asUUID()) - items.emplace_back(Item { - .type = Type::ServiceType, - .value = *val, - }); - } else if (item.name == "SRF") { - if (auto ref = item.asRef()) - items.emplace_back(Item { - .type = Type::ServiceRef, - .value = *ref, - }); - } - } - - return TransportHeader(items); -} - -PartialObject TransportHeader::toObject() const -{ - vector ritems; - - for (const auto & item : items) { - switch (item.type) { - case Type::Acknowledged: - ritems.emplace_back("ACK", std::get(item.value)); - break; - - case Type::DataRequest: - ritems.emplace_back("REQ", std::get(item.value)); - break; - - case Type::DataResponse: - ritems.emplace_back("RSP", std::get(item.value)); - break; - - case Type::AnnounceSelf: - ritems.emplace_back("ANN", std::get(item.value)); - break; - - case Type::AnnounceUpdate: - ritems.emplace_back("ANU", std::get(item.value)); - break; - - case Type::ChannelRequest: - ritems.emplace_back("CRQ", std::get(item.value)); - break; - - case Type::ChannelAccept: - ritems.emplace_back("CAC", std::get(item.value)); - break; - - case Type::ServiceType: - ritems.emplace_back("STP", std::get(item.value)); - break; - - case Type::ServiceRef: - ritems.emplace_back("SRF", std::get(item.value)); - break; - } - } - - return PartialObject(PartialRecord(std::move(ritems))); -} diff --git a/src/network.h b/src/network.h index 74231bf..c3a2074 100644 --- a/src/network.h +++ b/src/network.h @@ -65,7 +65,7 @@ struct Server::Peer shared_ptr lpeer = nullptr; - void send(const struct TransportHeader &, const vector &, bool secure); + void send(const NetworkProtocol::Header &, const vector &, bool secure); void updateIdentity(ReplyBuilder &); void updateChannel(ReplyBuilder &); void finalizeChannel(ReplyBuilder &, unique_ptr); @@ -91,47 +91,17 @@ struct PeerList::Priv : enable_shared_from_this void push(const shared_ptr &); }; -struct TransportHeader -{ - enum class Type { - Acknowledged, - DataRequest, - DataResponse, - AnnounceSelf, - AnnounceUpdate, - ChannelRequest, - ChannelAccept, - ServiceType, - ServiceRef, - }; - - struct Item { - const Type type; - const variant value; - - bool operator==(const Item &) const; - bool operator!=(const Item & other) const { return !(*this == other); } - }; - - TransportHeader(const vector & items): items(items) {} - static optional load(const PartialRef &); - static optional load(const PartialObject &); - PartialObject toObject() const; - - const vector items; -}; - class ReplyBuilder { public: - void header(TransportHeader::Item &&); + void header(NetworkProtocol::Header::Item &&); void body(const Ref &); - const vector & header() const { return mheader; } + const vector & header() const { return mheader; } vector body() const; private: - vector mheader; + vector mheader; vector mbody; }; @@ -160,7 +130,7 @@ struct Server::Priv Peer * findPeer(NetworkProtocol::Connection::Id cid) const; Peer & getPeer(const sockaddr_in6 & paddr); Peer & addPeer(NetworkProtocol::Connection conn); - void handlePacket(Peer &, const TransportHeader &, ReplyBuilder &); + void handlePacket(Peer &, const NetworkProtocol::Header &, ReplyBuilder &); void handleLocalHeadChange(const Head &); @@ -181,7 +151,7 @@ struct Server::Priv vector> peers; PeerList plist; - vector outgoing; + vector outgoing; vector> waiting; NetworkProtocol protocol; diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp index c247bf0..c2c6c5d 100644 --- a/src/network/protocol.cpp +++ b/src/network/protocol.cpp @@ -7,7 +7,10 @@ #include #include +using std::holds_alternative; using std::move; +using std::nullopt; +using std::runtime_error; using std::scoped_lock; namespace erebos { @@ -122,6 +125,10 @@ void NetworkProtocol::shutdown() } +/******************************************************************************/ +/* Connection */ +/******************************************************************************/ + NetworkProtocol::Connection::Id NetworkProtocol::ConnectionPriv::id() const { return reinterpret_cast(this); @@ -195,4 +202,147 @@ void NetworkProtocol::Connection::close() p = nullptr; } + +/******************************************************************************/ +/* Header */ +/******************************************************************************/ + +bool NetworkProtocol::Header::Item::operator==(const Item & other) const +{ + if (type != other.type) + return false; + + if (value.index() != other.value.index()) + return false; + + if (holds_alternative(value)) + return std::get(value).digest() == + std::get(other.value).digest(); + + if (holds_alternative(value)) + return std::get(value) == std::get(other.value); + + throw runtime_error("unhandled network header item type"); +} + +optional NetworkProtocol::Header::load(const PartialRef & ref) +{ + return load(*ref); +} + +optional NetworkProtocol::Header::load(const PartialObject & obj) +{ + auto rec = obj.asRecord(); + if (!rec) + return nullopt; + + vector items; + for (const auto & item : rec->items()) { + if (item.name == "ACK") { + if (auto ref = item.asRef()) + items.emplace_back(Item { + .type = Type::Acknowledged, + .value = *ref, + }); + } else if (item.name == "REQ") { + if (auto ref = item.asRef()) + items.emplace_back(Item { + .type = Type::DataRequest, + .value = *ref, + }); + } else if (item.name == "RSP") { + if (auto ref = item.asRef()) + items.emplace_back(Item { + .type = Type::DataResponse, + .value = *ref, + }); + } else if (item.name == "ANN") { + if (auto ref = item.asRef()) + items.emplace_back(Item { + .type = Type::AnnounceSelf, + .value = *ref, + }); + } else if (item.name == "ANU") { + if (auto ref = item.asRef()) + items.emplace_back(Item { + .type = Type::AnnounceUpdate, + .value = *ref, + }); + } else if (item.name == "CRQ") { + if (auto ref = item.asRef()) + items.emplace_back(Item { + .type = Type::ChannelRequest, + .value = *ref, + }); + } else if (item.name == "CAC") { + if (auto ref = item.asRef()) + items.emplace_back(Item { + .type = Type::ChannelAccept, + .value = *ref, + }); + } else if (item.name == "STP") { + if (auto val = item.asUUID()) + items.emplace_back(Item { + .type = Type::ServiceType, + .value = *val, + }); + } else if (item.name == "SRF") { + if (auto ref = item.asRef()) + items.emplace_back(Item { + .type = Type::ServiceRef, + .value = *ref, + }); + } + } + + return NetworkProtocol::Header(items); +} + +PartialObject NetworkProtocol::Header::toObject() const +{ + vector ritems; + + for (const auto & item : items) { + switch (item.type) { + case Type::Acknowledged: + ritems.emplace_back("ACK", std::get(item.value)); + break; + + case Type::DataRequest: + ritems.emplace_back("REQ", std::get(item.value)); + break; + + case Type::DataResponse: + ritems.emplace_back("RSP", std::get(item.value)); + break; + + case Type::AnnounceSelf: + ritems.emplace_back("ANN", std::get(item.value)); + break; + + case Type::AnnounceUpdate: + ritems.emplace_back("ANU", std::get(item.value)); + break; + + case Type::ChannelRequest: + ritems.emplace_back("CRQ", std::get(item.value)); + break; + + case Type::ChannelAccept: + ritems.emplace_back("CAC", std::get(item.value)); + break; + + case Type::ServiceType: + ritems.emplace_back("STP", std::get(item.value)); + break; + + case Type::ServiceRef: + ritems.emplace_back("SRF", std::get(item.value)); + break; + } + } + + return PartialObject(PartialRecord(std::move(ritems))); +} + } diff --git a/src/network/protocol.h b/src/network/protocol.h index a9bbaff..8aa22a2 100644 --- a/src/network/protocol.h +++ b/src/network/protocol.h @@ -1,5 +1,7 @@ #pragma once +#include + #include #include @@ -7,10 +9,12 @@ #include #include #include +#include namespace erebos { using std::mutex; +using std::optional; using std::unique_ptr; using std::variant; using std::vector; @@ -28,6 +32,8 @@ public: class Connection; + struct Header; + struct NewConnection; struct ConnectionReadReady; struct ProtocolClosed {}; @@ -85,4 +91,34 @@ private: struct NetworkProtocol::NewConnection { Connection conn; }; struct NetworkProtocol::ConnectionReadReady { Connection::Id id; }; +struct NetworkProtocol::Header +{ + enum class Type { + Acknowledged, + DataRequest, + DataResponse, + AnnounceSelf, + AnnounceUpdate, + ChannelRequest, + ChannelAccept, + ServiceType, + ServiceRef, + }; + + struct Item { + const Type type; + const variant value; + + bool operator==(const Item &) const; + bool operator!=(const Item & other) const { return !(*this == other); } + }; + + Header(const vector & items): items(items) {} + static optional
load(const PartialRef &); + static optional
load(const PartialObject &); + PartialObject toObject() const; + + const vector items; +}; + } -- cgit v1.2.3