diff options
-rw-r--r-- | src/network.cpp | 173 | ||||
-rw-r--r-- | src/network/protocol.cpp | 130 | ||||
-rw-r--r-- | src/network/protocol.h | 28 |
3 files changed, 144 insertions, 187 deletions
diff --git a/src/network.cpp b/src/network.cpp index 8c181cf..7a5a804 100644 --- a/src/network.cpp +++ b/src/network.cpp @@ -114,8 +114,7 @@ void Server::addPeer(const string & node, const string & service) const vector<NetworkProtocol::Header::Item> header; { shared_lock lock(p->selfMutex); - header.push_back(NetworkProtocol::Header::Item { - NetworkProtocol::Header::Type::AnnounceSelf, p->self.ref()->digest() }); + header.push_back(NetworkProtocol::Header::AnnounceSelf { p->self.ref()->digest() }); } peer.connection.send(peer.partStorage, header, {}, false); return; @@ -226,8 +225,8 @@ bool Peer::send(UUID uuid, const Ref & ref, const Object & obj) const { if (auto speer = p->speer.lock()) { NetworkProtocol::Header header({ - { NetworkProtocol::Header::Type::ServiceType, uuid }, - { NetworkProtocol::Header::Type::ServiceRef, ref.digest() }, + NetworkProtocol::Header::ServiceType { uuid }, + NetworkProtocol::Header::ServiceRef { ref.digest() }, }); speer->connection.send(speer->partStorage, header, { obj }, true); return true; @@ -417,7 +416,7 @@ void Server::Priv::doAnnounce() if (lastAnnounce + announceInterval < now) { shared_lock slock(selfMutex); NetworkProtocol::Header header({ - { NetworkProtocol::Header::Type::AnnounceSelf, self.ref()->digest() } + NetworkProtocol::Header::AnnounceSelf { self.ref()->digest() }, }); vector<uint8_t> bytes = header.toObject(pst).encode(); @@ -506,32 +505,29 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head optional<UUID> serviceType; - for (auto & item : header.items) { - switch (item.type) { - case NetworkProtocol::Header::Type::Acknowledged: { - auto dgst = std::get<Digest>(item.value); + for (const auto & item : header.items) { + if (const auto * ack = get_if<NetworkProtocol::Header::Acknowledged>(&item)) { + const auto & dgst = ack->value; if (holds_alternative<Stored<ChannelAccept>>(peer.connection.channel()) && std::get<Stored<ChannelAccept>>(peer.connection.channel()).ref().digest() == dgst) peer.finalizeChannel(reply, std::get<Stored<ChannelAccept>>(peer.connection.channel())->data->channel()); - break; } - case NetworkProtocol::Header::Type::DataRequest: { - auto dgst = std::get<Digest>(item.value); + else if (const auto * req = get_if<NetworkProtocol::Header::DataRequest>(&item)) { + const auto & dgst = req->value; if (holds_alternative<unique_ptr<Channel>>(peer.connection.channel()) || plaintextRefs.find(dgst) != plaintextRefs.end()) { if (auto ref = peer.tempStorage.ref(dgst)) { - reply.header({ NetworkProtocol::Header::Type::DataResponse, ref->digest() }); + reply.header({ NetworkProtocol::Header::DataResponse { ref->digest() } }); reply.body(*ref); } } - break; } - case NetworkProtocol::Header::Type::DataResponse: { - auto dgst = std::get<Digest>(item.value); - reply.header({ NetworkProtocol::Header::Type::Acknowledged, dgst }); + else if (const auto * rsp = get_if<NetworkProtocol::Header::DataResponse>(&item)) { + const auto & dgst = rsp->value; + reply.header({ NetworkProtocol::Header::Acknowledged { dgst } }); for (auto & pwref : waiting) { if (auto wref = pwref.lock()) { if (std::find(wref->missing.begin(), wref->missing.end(), dgst) != @@ -543,16 +539,13 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head } waiting.erase(std::remove_if(waiting.begin(), waiting.end(), [](auto & wref) { return wref.expired(); }), waiting.end()); - break; } - case NetworkProtocol::Header::Type::AnnounceSelf: { - auto dgst = std::get<Digest>(item.value); - if (dgst == self.ref()->digest()) - break; - - if (holds_alternative<monostate>(peer.identity)) { - reply.header({ NetworkProtocol::Header::Type::AnnounceSelf, self.ref()->digest()}); + else if (const auto * ann = get_if<NetworkProtocol::Header::AnnounceSelf>(&item)) { + const auto & dgst = ann->value; + if (dgst != self.ref()->digest() && + holds_alternative<monostate>(peer.identity)) { + reply.header({ NetworkProtocol::Header::AnnounceSelf { self.ref()->digest() }}); shared_ptr<WaitingRef> wref(new WaitingRef { .storage = peer.tempStorage, @@ -563,13 +556,12 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head peer.identity = wref; wref->check(reply); } - break; } - case NetworkProtocol::Header::Type::AnnounceUpdate: + else if (const auto * anu = get_if<NetworkProtocol::Header::AnnounceUpdate>(&item)) { if (holds_alternative<Identity>(peer.identity)) { - auto dgst = std::get<Digest>(item.value); - reply.header({ NetworkProtocol::Header::Type::Acknowledged, dgst }); + const auto & dgst = anu->value; + reply.header({ NetworkProtocol::Header::Acknowledged { dgst } }); shared_ptr<WaitingRef> wref(new WaitingRef { .storage = peer.tempStorage, @@ -580,76 +572,81 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head peer.identityUpdates.push_back(wref); wref->check(reply); } - break; + } - case NetworkProtocol::Header::Type::ChannelRequest: { - auto dgst = std::get<Digest>(item.value); - reply.header({ NetworkProtocol::Header::Type::Acknowledged, dgst }); + else if (const auto * req = get_if<NetworkProtocol::Header::ChannelRequest>(&item)) { + const auto & dgst = req->value; + reply.header({ NetworkProtocol::Header::Acknowledged { dgst } }); if (holds_alternative<Stored<ChannelRequest>>(peer.connection.channel()) && - std::get<Stored<ChannelRequest>>(peer.connection.channel()).ref().digest() < dgst) - break; + std::get<Stored<ChannelRequest>>(peer.connection.channel()).ref().digest() < dgst) { + // TODO: reject request with lower priority + } - if (holds_alternative<Stored<ChannelAccept>>(peer.connection.channel())) - break; + else if (holds_alternative<Stored<ChannelAccept>>(peer.connection.channel())) { + // TODO: reject when we already sent accept + } - shared_ptr<WaitingRef> wref(new WaitingRef { - .storage = peer.tempStorage, - .ref = peer.partStorage.ref(dgst), - .missing = {}, - }); - waiting.push_back(wref); - peer.connection.channel() = wref; - wref->check(reply); - break; + else { + shared_ptr<WaitingRef> wref(new WaitingRef { + .storage = peer.tempStorage, + .ref = peer.partStorage.ref(dgst), + .missing = {}, + }); + waiting.push_back(wref); + peer.connection.channel() = wref; + wref->check(reply); + } } - case NetworkProtocol::Header::Type::ChannelAccept: { - auto dgst = std::get<Digest>(item.value); + else if (const auto * acc = get_if<NetworkProtocol::Header::ChannelAccept>(&item)) { + const auto & dgst = acc->value; if (holds_alternative<Stored<ChannelAccept>>(peer.connection.channel()) && - std::get<Stored<ChannelAccept>>(peer.connection.channel()).ref().digest() < dgst) - break; - - auto cres = peer.tempStorage.copy(peer.partStorage.ref(dgst)); - if (auto r = std::get_if<Ref>(&cres)) { - auto acc = ChannelAccept::load(*r); - if (holds_alternative<Identity>(peer.identity) && - acc.isSignedBy(std::get<Identity>(peer.identity).keyMessage())) { - reply.header({ NetworkProtocol::Header::Type::Acknowledged, dgst }); - peer.finalizeChannel(reply, acc.data->channel()); + std::get<Stored<ChannelAccept>>(peer.connection.channel()).ref().digest() < dgst) { + // TODO: reject request with lower priority + } + + else { + auto cres = peer.tempStorage.copy(peer.partStorage.ref(dgst)); + if (auto r = get_if<Ref>(&cres)) { + auto acc = ChannelAccept::load(*r); + if (holds_alternative<Identity>(peer.identity) && + acc.isSignedBy(std::get<Identity>(peer.identity).keyMessage())) { + reply.header({ NetworkProtocol::Header::Acknowledged { dgst } }); + peer.finalizeChannel(reply, acc.data->channel()); + } } } - break; } - case NetworkProtocol::Header::Type::ServiceType: + else if (const auto * stype = get_if<NetworkProtocol::Header::ServiceType>(&item)) { if (!serviceType) - serviceType = std::get<UUID>(item.value); - break; + serviceType = stype->value; + } - case NetworkProtocol::Header::Type::ServiceRef: + else if (const auto * sref = get_if<NetworkProtocol::Header::ServiceRef>(&item)) { if (!serviceType) for (auto & item : header.items) - if (item.type == NetworkProtocol::Header::Type::ServiceType) { - serviceType = std::get<UUID>(item.value); + if (const auto * stype = get_if<NetworkProtocol::Header::ServiceType>(&item)) { + serviceType = stype->value; break; } - if (!serviceType) - break; - auto dgst = std::get<Digest>(item.value); - auto pref = peer.partStorage.ref(dgst); - if (pref) - reply.header({ NetworkProtocol::Header::Type::Acknowledged, dgst }); + if (serviceType) { + const auto & dgst = sref->value; + auto pref = peer.partStorage.ref(dgst); + if (pref) + reply.header({ NetworkProtocol::Header::Acknowledged { dgst } }); - shared_ptr<WaitingRef> wref(new WaitingRef { - .storage = peer.tempStorage, - .ref = pref, - .missing = {}, - }); - waiting.push_back(wref); - peer.serviceQueue.emplace_back(*serviceType, wref); - wref->check(reply); + shared_ptr<WaitingRef> wref(new WaitingRef { + .storage = peer.tempStorage, + .ref = pref, + .missing = {}, + }); + waiting.push_back(wref); + peer.serviceQueue.emplace_back(*serviceType, wref); + wref->check(reply); + } } } } @@ -665,11 +662,9 @@ void Server::Priv::handleLocalHeadChange(const Head<LocalState> & head) vector<NetworkProtocol::Header::Item> hitems; for (const auto & r : self.refs()) - hitems.push_back(NetworkProtocol::Header::Item { - NetworkProtocol::Header::Type::AnnounceUpdate, r.digest() }); + hitems.push_back(NetworkProtocol::Header::AnnounceUpdate { r.digest() }); for (const auto & r : self.updates()) - hitems.push_back(NetworkProtocol::Header::Item { - NetworkProtocol::Header::Type::AnnounceUpdate, r.digest() }); + hitems.push_back(NetworkProtocol::Header::AnnounceUpdate { r.digest() }); NetworkProtocol::Header header(hitems); @@ -724,7 +719,7 @@ void Server::Peer::updateChannel(ReplyBuilder & reply) auto req = Channel::generateRequest(tempStorage, server.self, std::get<Identity>(identity)); connection.channel().emplace<Stored<ChannelRequest>>(req); - reply.header({ NetworkProtocol::Header::Type::ChannelRequest, req.ref().digest() }); + reply.header({ NetworkProtocol::Header::ChannelRequest { req.ref().digest() } }); reply.body(req.ref()); reply.body(req->data.ref()); reply.body(req->data->key.ref()); @@ -739,7 +734,7 @@ void Server::Peer::updateChannel(ReplyBuilder & reply) req->isSignedBy(std::get<Identity>(identity).keyMessage())) { if (auto acc = Channel::acceptRequest(server.self, std::get<Identity>(identity), req)) { connection.channel().emplace<Stored<ChannelAccept>>(*acc); - reply.header({ NetworkProtocol::Header::Type::ChannelAccept, acc->ref().digest() }); + reply.header({ NetworkProtocol::Header::ChannelAccept { acc->ref().digest() } }); reply.body(acc->ref()); reply.body(acc.value()->data.ref()); reply.body(acc.value()->data->key.ref()); @@ -761,11 +756,9 @@ void Server::Peer::finalizeChannel(ReplyBuilder & reply, unique_ptr<Channel> ch) vector<NetworkProtocol::Header::Item> hitems; for (const auto & r : server.self.refs()) - reply.header(NetworkProtocol::Header::Item { - NetworkProtocol::Header::Type::AnnounceUpdate, r.digest() }); + reply.header(NetworkProtocol::Header::AnnounceUpdate { r.digest() }); for (const auto & r : server.self.updates()) - reply.header(NetworkProtocol::Header::Item { - NetworkProtocol::Header::Type::AnnounceUpdate, r.digest() }); + reply.header(NetworkProtocol::Header::AnnounceUpdate { r.digest() }); } void Server::Peer::updateService(ReplyBuilder & reply) @@ -848,7 +841,7 @@ optional<Ref> WaitingRef::check(ReplyBuilder & reply) return r; for (const auto & d : missing) - reply.header({ NetworkProtocol::Header::Type::DataRequest, d }); + reply.header({ NetworkProtocol::Header::DataRequest { d } }); return nullopt; } diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp index f38267f..ede7023 100644 --- a/src/network/protocol.cpp +++ b/src/network/protocol.cpp @@ -9,11 +9,12 @@ #include <mutex> #include <system_error> +using std::get_if; using std::holds_alternative; using std::move; using std::nullopt; -using std::runtime_error; using std::scoped_lock; +using std::visit; namespace erebos { @@ -327,21 +328,16 @@ void NetworkProtocol::Connection::trySendOutQueue() /* Header */ /******************************************************************************/ -bool NetworkProtocol::Header::Item::operator==(const Item & other) const +bool operator==(const NetworkProtocol::Header::Item & left, + const NetworkProtocol::Header::Item & right) { - if (type != other.type) + if (left.index() != right.index()) return false; - if (value.index() != other.value.index()) - return false; - - if (holds_alternative<Digest>(value)) - return std::get<Digest>(value) == std::get<Digest>(other.value); - - if (holds_alternative<UUID>(value)) - return std::get<UUID>(value) == std::get<UUID>(other.value); - - throw runtime_error("unhandled network header item type"); + return visit([&](auto && arg) { + using T = std::decay_t<decltype(arg)>; + return arg.value == std::get<T>(right).value; + }, left); } optional<NetworkProtocol::Header> NetworkProtocol::Header::load(const PartialRef & ref) @@ -359,58 +355,31 @@ optional<NetworkProtocol::Header> NetworkProtocol::Header::load(const PartialObj for (const auto & item : rec->items()) { if (item.name == "ACK") { if (auto ref = item.asRef()) - items.emplace_back(Item { - .type = Type::Acknowledged, - .value = ref->digest(), - }); + items.emplace_back(Acknowledged { ref->digest() }); } else if (item.name == "REQ") { if (auto ref = item.asRef()) - items.emplace_back(Item { - .type = Type::DataRequest, - .value = ref->digest(), - }); + items.emplace_back(DataRequest { ref->digest() }); } else if (item.name == "RSP") { if (auto ref = item.asRef()) - items.emplace_back(Item { - .type = Type::DataResponse, - .value = ref->digest(), - }); + items.emplace_back(DataResponse { ref->digest() }); } else if (item.name == "ANN") { if (auto ref = item.asRef()) - items.emplace_back(Item { - .type = Type::AnnounceSelf, - .value = ref->digest(), - }); + items.emplace_back(AnnounceSelf { ref->digest() }); } else if (item.name == "ANU") { if (auto ref = item.asRef()) - items.emplace_back(Item { - .type = Type::AnnounceUpdate, - .value = ref->digest(), - }); + items.emplace_back(AnnounceUpdate { ref->digest() }); } else if (item.name == "CRQ") { if (auto ref = item.asRef()) - items.emplace_back(Item { - .type = Type::ChannelRequest, - .value = ref->digest(), - }); + items.emplace_back(ChannelRequest { ref->digest() }); } else if (item.name == "CAC") { if (auto ref = item.asRef()) - items.emplace_back(Item { - .type = Type::ChannelAccept, - .value = ref->digest(), - }); + items.emplace_back(ChannelAccept { ref->digest() }); } else if (item.name == "STP") { if (auto val = item.asUUID()) - items.emplace_back(Item { - .type = Type::ServiceType, - .value = *val, - }); + items.emplace_back(ServiceType { *val }); } else if (item.name == "SRF") { if (auto ref = item.asRef()) - items.emplace_back(Item { - .type = Type::ServiceRef, - .value = ref->digest(), - }); + items.emplace_back(ServiceRef { ref->digest() }); } } @@ -422,43 +391,32 @@ PartialObject NetworkProtocol::Header::toObject(const PartialStorage & st) const vector<PartialRecord::Item> ritems; for (const auto & item : items) { - switch (item.type) { - case Type::Acknowledged: - ritems.emplace_back("ACK", st.ref(std::get<Digest>(item.value))); - break; - - case Type::DataRequest: - ritems.emplace_back("REQ", st.ref(std::get<Digest>(item.value))); - break; - - case Type::DataResponse: - ritems.emplace_back("RSP", st.ref(std::get<Digest>(item.value))); - break; - - case Type::AnnounceSelf: - ritems.emplace_back("ANN", st.ref(std::get<Digest>(item.value))); - break; - - case Type::AnnounceUpdate: - ritems.emplace_back("ANU", st.ref(std::get<Digest>(item.value))); - break; - - case Type::ChannelRequest: - ritems.emplace_back("CRQ", st.ref(std::get<Digest>(item.value))); - break; - - case Type::ChannelAccept: - ritems.emplace_back("CAC", st.ref(std::get<Digest>(item.value))); - break; - - case Type::ServiceType: - ritems.emplace_back("STP", std::get<UUID>(item.value)); - break; - - case Type::ServiceRef: - ritems.emplace_back("SRF", st.ref(std::get<Digest>(item.value))); - break; - } + if (const auto * ptr = get_if<Acknowledged>(&item)) + ritems.emplace_back("ACK", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<DataRequest>(&item)) + ritems.emplace_back("REQ", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<DataResponse>(&item)) + ritems.emplace_back("RSP", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<AnnounceSelf>(&item)) + ritems.emplace_back("ANN", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<AnnounceUpdate>(&item)) + ritems.emplace_back("ANU", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<ChannelRequest>(&item)) + ritems.emplace_back("CRQ", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<ChannelAccept>(&item)) + ritems.emplace_back("CAC", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<ServiceType>(&item)) + ritems.emplace_back("STP", ptr->value); + + else if (const auto * ptr = get_if<ServiceRef>(&item)) + ritems.emplace_back("SRF", st.ref(ptr->value)); } return PartialObject(PartialRecord(std::move(ritems))); diff --git a/src/network/protocol.h b/src/network/protocol.h index c5803ce..df29c05 100644 --- a/src/network/protocol.h +++ b/src/network/protocol.h @@ -106,7 +106,17 @@ struct NetworkProtocol::ConnectionReadReady { Connection::Id id; }; struct NetworkProtocol::Header { - enum class Type { + struct Acknowledged { Digest value; }; + struct DataRequest { Digest value; }; + struct DataResponse { Digest value; }; + struct AnnounceSelf { Digest value; }; + struct AnnounceUpdate { Digest value; }; + struct ChannelRequest { Digest value; }; + struct ChannelAccept { Digest value; }; + struct ServiceType { UUID value; }; + struct ServiceRef { Digest value; }; + + using Item = variant< Acknowledged, DataRequest, DataResponse, @@ -115,16 +125,7 @@ struct NetworkProtocol::Header ChannelRequest, ChannelAccept, ServiceType, - ServiceRef, - }; - - struct Item { - const Type type; - const variant<Digest, UUID> value; - - bool operator==(const Item &) const; - bool operator!=(const Item & other) const { return !(*this == other); } - }; + ServiceRef>; Header(const vector<Item> & items): items(items) {} static optional<Header> load(const PartialRef &); @@ -134,6 +135,11 @@ struct NetworkProtocol::Header const vector<Item> items; }; +bool operator==(const NetworkProtocol::Header::Item &, const NetworkProtocol::Header::Item &); +inline bool operator!=(const NetworkProtocol::Header::Item & left, + const NetworkProtocol::Header::Item & right) +{ return not (left == right); } + class ReplyBuilder { public: |