summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/network.cpp173
-rw-r--r--src/network/protocol.cpp130
-rw-r--r--src/network/protocol.h28
3 files changed, 144 insertions, 187 deletions
diff --git a/src/network.cpp b/src/network.cpp
index 8c181cf..7a5a804 100644
--- a/src/network.cpp
+++ b/src/network.cpp
@@ -114,8 +114,7 @@ void Server::addPeer(const string & node, const string & service) const
vector<NetworkProtocol::Header::Item> header;
{
shared_lock lock(p->selfMutex);
- header.push_back(NetworkProtocol::Header::Item {
- NetworkProtocol::Header::Type::AnnounceSelf, p->self.ref()->digest() });
+ header.push_back(NetworkProtocol::Header::AnnounceSelf { p->self.ref()->digest() });
}
peer.connection.send(peer.partStorage, header, {}, false);
return;
@@ -226,8 +225,8 @@ bool Peer::send(UUID uuid, const Ref & ref, const Object & obj) const
{
if (auto speer = p->speer.lock()) {
NetworkProtocol::Header header({
- { NetworkProtocol::Header::Type::ServiceType, uuid },
- { NetworkProtocol::Header::Type::ServiceRef, ref.digest() },
+ NetworkProtocol::Header::ServiceType { uuid },
+ NetworkProtocol::Header::ServiceRef { ref.digest() },
});
speer->connection.send(speer->partStorage, header, { obj }, true);
return true;
@@ -417,7 +416,7 @@ void Server::Priv::doAnnounce()
if (lastAnnounce + announceInterval < now) {
shared_lock slock(selfMutex);
NetworkProtocol::Header header({
- { NetworkProtocol::Header::Type::AnnounceSelf, self.ref()->digest() }
+ NetworkProtocol::Header::AnnounceSelf { self.ref()->digest() },
});
vector<uint8_t> bytes = header.toObject(pst).encode();
@@ -506,32 +505,29 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head
optional<UUID> serviceType;
- for (auto & item : header.items) {
- switch (item.type) {
- case NetworkProtocol::Header::Type::Acknowledged: {
- auto dgst = std::get<Digest>(item.value);
+ for (const auto & item : header.items) {
+ if (const auto * ack = get_if<NetworkProtocol::Header::Acknowledged>(&item)) {
+ const auto & dgst = ack->value;
if (holds_alternative<Stored<ChannelAccept>>(peer.connection.channel()) &&
std::get<Stored<ChannelAccept>>(peer.connection.channel()).ref().digest() == dgst)
peer.finalizeChannel(reply,
std::get<Stored<ChannelAccept>>(peer.connection.channel())->data->channel());
- break;
}
- case NetworkProtocol::Header::Type::DataRequest: {
- auto dgst = std::get<Digest>(item.value);
+ else if (const auto * req = get_if<NetworkProtocol::Header::DataRequest>(&item)) {
+ const auto & dgst = req->value;
if (holds_alternative<unique_ptr<Channel>>(peer.connection.channel()) ||
plaintextRefs.find(dgst) != plaintextRefs.end()) {
if (auto ref = peer.tempStorage.ref(dgst)) {
- reply.header({ NetworkProtocol::Header::Type::DataResponse, ref->digest() });
+ reply.header({ NetworkProtocol::Header::DataResponse { ref->digest() } });
reply.body(*ref);
}
}
- break;
}
- case NetworkProtocol::Header::Type::DataResponse: {
- auto dgst = std::get<Digest>(item.value);
- reply.header({ NetworkProtocol::Header::Type::Acknowledged, dgst });
+ else if (const auto * rsp = get_if<NetworkProtocol::Header::DataResponse>(&item)) {
+ const auto & dgst = rsp->value;
+ reply.header({ NetworkProtocol::Header::Acknowledged { dgst } });
for (auto & pwref : waiting) {
if (auto wref = pwref.lock()) {
if (std::find(wref->missing.begin(), wref->missing.end(), dgst) !=
@@ -543,16 +539,13 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head
}
waiting.erase(std::remove_if(waiting.begin(), waiting.end(),
[](auto & wref) { return wref.expired(); }), waiting.end());
- break;
}
- case NetworkProtocol::Header::Type::AnnounceSelf: {
- auto dgst = std::get<Digest>(item.value);
- if (dgst == self.ref()->digest())
- break;
-
- if (holds_alternative<monostate>(peer.identity)) {
- reply.header({ NetworkProtocol::Header::Type::AnnounceSelf, self.ref()->digest()});
+ else if (const auto * ann = get_if<NetworkProtocol::Header::AnnounceSelf>(&item)) {
+ const auto & dgst = ann->value;
+ if (dgst != self.ref()->digest() &&
+ holds_alternative<monostate>(peer.identity)) {
+ reply.header({ NetworkProtocol::Header::AnnounceSelf { self.ref()->digest() }});
shared_ptr<WaitingRef> wref(new WaitingRef {
.storage = peer.tempStorage,
@@ -563,13 +556,12 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head
peer.identity = wref;
wref->check(reply);
}
- break;
}
- case NetworkProtocol::Header::Type::AnnounceUpdate:
+ else if (const auto * anu = get_if<NetworkProtocol::Header::AnnounceUpdate>(&item)) {
if (holds_alternative<Identity>(peer.identity)) {
- auto dgst = std::get<Digest>(item.value);
- reply.header({ NetworkProtocol::Header::Type::Acknowledged, dgst });
+ const auto & dgst = anu->value;
+ reply.header({ NetworkProtocol::Header::Acknowledged { dgst } });
shared_ptr<WaitingRef> wref(new WaitingRef {
.storage = peer.tempStorage,
@@ -580,76 +572,81 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head
peer.identityUpdates.push_back(wref);
wref->check(reply);
}
- break;
+ }
- case NetworkProtocol::Header::Type::ChannelRequest: {
- auto dgst = std::get<Digest>(item.value);
- reply.header({ NetworkProtocol::Header::Type::Acknowledged, dgst });
+ else if (const auto * req = get_if<NetworkProtocol::Header::ChannelRequest>(&item)) {
+ const auto & dgst = req->value;
+ reply.header({ NetworkProtocol::Header::Acknowledged { dgst } });
if (holds_alternative<Stored<ChannelRequest>>(peer.connection.channel()) &&
- std::get<Stored<ChannelRequest>>(peer.connection.channel()).ref().digest() < dgst)
- break;
+ std::get<Stored<ChannelRequest>>(peer.connection.channel()).ref().digest() < dgst) {
+ // TODO: reject request with lower priority
+ }
- if (holds_alternative<Stored<ChannelAccept>>(peer.connection.channel()))
- break;
+ else if (holds_alternative<Stored<ChannelAccept>>(peer.connection.channel())) {
+ // TODO: reject when we already sent accept
+ }
- shared_ptr<WaitingRef> wref(new WaitingRef {
- .storage = peer.tempStorage,
- .ref = peer.partStorage.ref(dgst),
- .missing = {},
- });
- waiting.push_back(wref);
- peer.connection.channel() = wref;
- wref->check(reply);
- break;
+ else {
+ shared_ptr<WaitingRef> wref(new WaitingRef {
+ .storage = peer.tempStorage,
+ .ref = peer.partStorage.ref(dgst),
+ .missing = {},
+ });
+ waiting.push_back(wref);
+ peer.connection.channel() = wref;
+ wref->check(reply);
+ }
}
- case NetworkProtocol::Header::Type::ChannelAccept: {
- auto dgst = std::get<Digest>(item.value);
+ else if (const auto * acc = get_if<NetworkProtocol::Header::ChannelAccept>(&item)) {
+ const auto & dgst = acc->value;
if (holds_alternative<Stored<ChannelAccept>>(peer.connection.channel()) &&
- std::get<Stored<ChannelAccept>>(peer.connection.channel()).ref().digest() < dgst)
- break;
-
- auto cres = peer.tempStorage.copy(peer.partStorage.ref(dgst));
- if (auto r = std::get_if<Ref>(&cres)) {
- auto acc = ChannelAccept::load(*r);
- if (holds_alternative<Identity>(peer.identity) &&
- acc.isSignedBy(std::get<Identity>(peer.identity).keyMessage())) {
- reply.header({ NetworkProtocol::Header::Type::Acknowledged, dgst });
- peer.finalizeChannel(reply, acc.data->channel());
+ std::get<Stored<ChannelAccept>>(peer.connection.channel()).ref().digest() < dgst) {
+ // TODO: reject request with lower priority
+ }
+
+ else {
+ auto cres = peer.tempStorage.copy(peer.partStorage.ref(dgst));
+ if (auto r = get_if<Ref>(&cres)) {
+ auto acc = ChannelAccept::load(*r);
+ if (holds_alternative<Identity>(peer.identity) &&
+ acc.isSignedBy(std::get<Identity>(peer.identity).keyMessage())) {
+ reply.header({ NetworkProtocol::Header::Acknowledged { dgst } });
+ peer.finalizeChannel(reply, acc.data->channel());
+ }
}
}
- break;
}
- case NetworkProtocol::Header::Type::ServiceType:
+ else if (const auto * stype = get_if<NetworkProtocol::Header::ServiceType>(&item)) {
if (!serviceType)
- serviceType = std::get<UUID>(item.value);
- break;
+ serviceType = stype->value;
+ }
- case NetworkProtocol::Header::Type::ServiceRef:
+ else if (const auto * sref = get_if<NetworkProtocol::Header::ServiceRef>(&item)) {
if (!serviceType)
for (auto & item : header.items)
- if (item.type == NetworkProtocol::Header::Type::ServiceType) {
- serviceType = std::get<UUID>(item.value);
+ if (const auto * stype = get_if<NetworkProtocol::Header::ServiceType>(&item)) {
+ serviceType = stype->value;
break;
}
- if (!serviceType)
- break;
- auto dgst = std::get<Digest>(item.value);
- auto pref = peer.partStorage.ref(dgst);
- if (pref)
- reply.header({ NetworkProtocol::Header::Type::Acknowledged, dgst });
+ if (serviceType) {
+ const auto & dgst = sref->value;
+ auto pref = peer.partStorage.ref(dgst);
+ if (pref)
+ reply.header({ NetworkProtocol::Header::Acknowledged { dgst } });
- shared_ptr<WaitingRef> wref(new WaitingRef {
- .storage = peer.tempStorage,
- .ref = pref,
- .missing = {},
- });
- waiting.push_back(wref);
- peer.serviceQueue.emplace_back(*serviceType, wref);
- wref->check(reply);
+ shared_ptr<WaitingRef> wref(new WaitingRef {
+ .storage = peer.tempStorage,
+ .ref = pref,
+ .missing = {},
+ });
+ waiting.push_back(wref);
+ peer.serviceQueue.emplace_back(*serviceType, wref);
+ wref->check(reply);
+ }
}
}
}
@@ -665,11 +662,9 @@ void Server::Priv::handleLocalHeadChange(const Head<LocalState> & head)
vector<NetworkProtocol::Header::Item> hitems;
for (const auto & r : self.refs())
- hitems.push_back(NetworkProtocol::Header::Item {
- NetworkProtocol::Header::Type::AnnounceUpdate, r.digest() });
+ hitems.push_back(NetworkProtocol::Header::AnnounceUpdate { r.digest() });
for (const auto & r : self.updates())
- hitems.push_back(NetworkProtocol::Header::Item {
- NetworkProtocol::Header::Type::AnnounceUpdate, r.digest() });
+ hitems.push_back(NetworkProtocol::Header::AnnounceUpdate { r.digest() });
NetworkProtocol::Header header(hitems);
@@ -724,7 +719,7 @@ void Server::Peer::updateChannel(ReplyBuilder & reply)
auto req = Channel::generateRequest(tempStorage,
server.self, std::get<Identity>(identity));
connection.channel().emplace<Stored<ChannelRequest>>(req);
- reply.header({ NetworkProtocol::Header::Type::ChannelRequest, req.ref().digest() });
+ reply.header({ NetworkProtocol::Header::ChannelRequest { req.ref().digest() } });
reply.body(req.ref());
reply.body(req->data.ref());
reply.body(req->data->key.ref());
@@ -739,7 +734,7 @@ void Server::Peer::updateChannel(ReplyBuilder & reply)
req->isSignedBy(std::get<Identity>(identity).keyMessage())) {
if (auto acc = Channel::acceptRequest(server.self, std::get<Identity>(identity), req)) {
connection.channel().emplace<Stored<ChannelAccept>>(*acc);
- reply.header({ NetworkProtocol::Header::Type::ChannelAccept, acc->ref().digest() });
+ reply.header({ NetworkProtocol::Header::ChannelAccept { acc->ref().digest() } });
reply.body(acc->ref());
reply.body(acc.value()->data.ref());
reply.body(acc.value()->data->key.ref());
@@ -761,11 +756,9 @@ void Server::Peer::finalizeChannel(ReplyBuilder & reply, unique_ptr<Channel> ch)
vector<NetworkProtocol::Header::Item> hitems;
for (const auto & r : server.self.refs())
- reply.header(NetworkProtocol::Header::Item {
- NetworkProtocol::Header::Type::AnnounceUpdate, r.digest() });
+ reply.header(NetworkProtocol::Header::AnnounceUpdate { r.digest() });
for (const auto & r : server.self.updates())
- reply.header(NetworkProtocol::Header::Item {
- NetworkProtocol::Header::Type::AnnounceUpdate, r.digest() });
+ reply.header(NetworkProtocol::Header::AnnounceUpdate { r.digest() });
}
void Server::Peer::updateService(ReplyBuilder & reply)
@@ -848,7 +841,7 @@ optional<Ref> WaitingRef::check(ReplyBuilder & reply)
return r;
for (const auto & d : missing)
- reply.header({ NetworkProtocol::Header::Type::DataRequest, d });
+ reply.header({ NetworkProtocol::Header::DataRequest { d } });
return nullopt;
}
diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp
index f38267f..ede7023 100644
--- a/src/network/protocol.cpp
+++ b/src/network/protocol.cpp
@@ -9,11 +9,12 @@
#include <mutex>
#include <system_error>
+using std::get_if;
using std::holds_alternative;
using std::move;
using std::nullopt;
-using std::runtime_error;
using std::scoped_lock;
+using std::visit;
namespace erebos {
@@ -327,21 +328,16 @@ void NetworkProtocol::Connection::trySendOutQueue()
/* Header */
/******************************************************************************/
-bool NetworkProtocol::Header::Item::operator==(const Item & other) const
+bool operator==(const NetworkProtocol::Header::Item & left,
+ const NetworkProtocol::Header::Item & right)
{
- if (type != other.type)
+ if (left.index() != right.index())
return false;
- if (value.index() != other.value.index())
- return false;
-
- if (holds_alternative<Digest>(value))
- return std::get<Digest>(value) == std::get<Digest>(other.value);
-
- if (holds_alternative<UUID>(value))
- return std::get<UUID>(value) == std::get<UUID>(other.value);
-
- throw runtime_error("unhandled network header item type");
+ return visit([&](auto && arg) {
+ using T = std::decay_t<decltype(arg)>;
+ return arg.value == std::get<T>(right).value;
+ }, left);
}
optional<NetworkProtocol::Header> NetworkProtocol::Header::load(const PartialRef & ref)
@@ -359,58 +355,31 @@ optional<NetworkProtocol::Header> NetworkProtocol::Header::load(const PartialObj
for (const auto & item : rec->items()) {
if (item.name == "ACK") {
if (auto ref = item.asRef())
- items.emplace_back(Item {
- .type = Type::Acknowledged,
- .value = ref->digest(),
- });
+ items.emplace_back(Acknowledged { ref->digest() });
} else if (item.name == "REQ") {
if (auto ref = item.asRef())
- items.emplace_back(Item {
- .type = Type::DataRequest,
- .value = ref->digest(),
- });
+ items.emplace_back(DataRequest { ref->digest() });
} else if (item.name == "RSP") {
if (auto ref = item.asRef())
- items.emplace_back(Item {
- .type = Type::DataResponse,
- .value = ref->digest(),
- });
+ items.emplace_back(DataResponse { ref->digest() });
} else if (item.name == "ANN") {
if (auto ref = item.asRef())
- items.emplace_back(Item {
- .type = Type::AnnounceSelf,
- .value = ref->digest(),
- });
+ items.emplace_back(AnnounceSelf { ref->digest() });
} else if (item.name == "ANU") {
if (auto ref = item.asRef())
- items.emplace_back(Item {
- .type = Type::AnnounceUpdate,
- .value = ref->digest(),
- });
+ items.emplace_back(AnnounceUpdate { ref->digest() });
} else if (item.name == "CRQ") {
if (auto ref = item.asRef())
- items.emplace_back(Item {
- .type = Type::ChannelRequest,
- .value = ref->digest(),
- });
+ items.emplace_back(ChannelRequest { ref->digest() });
} else if (item.name == "CAC") {
if (auto ref = item.asRef())
- items.emplace_back(Item {
- .type = Type::ChannelAccept,
- .value = ref->digest(),
- });
+ items.emplace_back(ChannelAccept { ref->digest() });
} else if (item.name == "STP") {
if (auto val = item.asUUID())
- items.emplace_back(Item {
- .type = Type::ServiceType,
- .value = *val,
- });
+ items.emplace_back(ServiceType { *val });
} else if (item.name == "SRF") {
if (auto ref = item.asRef())
- items.emplace_back(Item {
- .type = Type::ServiceRef,
- .value = ref->digest(),
- });
+ items.emplace_back(ServiceRef { ref->digest() });
}
}
@@ -422,43 +391,32 @@ PartialObject NetworkProtocol::Header::toObject(const PartialStorage & st) const
vector<PartialRecord::Item> ritems;
for (const auto & item : items) {
- switch (item.type) {
- case Type::Acknowledged:
- ritems.emplace_back("ACK", st.ref(std::get<Digest>(item.value)));
- break;
-
- case Type::DataRequest:
- ritems.emplace_back("REQ", st.ref(std::get<Digest>(item.value)));
- break;
-
- case Type::DataResponse:
- ritems.emplace_back("RSP", st.ref(std::get<Digest>(item.value)));
- break;
-
- case Type::AnnounceSelf:
- ritems.emplace_back("ANN", st.ref(std::get<Digest>(item.value)));
- break;
-
- case Type::AnnounceUpdate:
- ritems.emplace_back("ANU", st.ref(std::get<Digest>(item.value)));
- break;
-
- case Type::ChannelRequest:
- ritems.emplace_back("CRQ", st.ref(std::get<Digest>(item.value)));
- break;
-
- case Type::ChannelAccept:
- ritems.emplace_back("CAC", st.ref(std::get<Digest>(item.value)));
- break;
-
- case Type::ServiceType:
- ritems.emplace_back("STP", std::get<UUID>(item.value));
- break;
-
- case Type::ServiceRef:
- ritems.emplace_back("SRF", st.ref(std::get<Digest>(item.value)));
- break;
- }
+ if (const auto * ptr = get_if<Acknowledged>(&item))
+ ritems.emplace_back("ACK", st.ref(ptr->value));
+
+ else if (const auto * ptr = get_if<DataRequest>(&item))
+ ritems.emplace_back("REQ", st.ref(ptr->value));
+
+ else if (const auto * ptr = get_if<DataResponse>(&item))
+ ritems.emplace_back("RSP", st.ref(ptr->value));
+
+ else if (const auto * ptr = get_if<AnnounceSelf>(&item))
+ ritems.emplace_back("ANN", st.ref(ptr->value));
+
+ else if (const auto * ptr = get_if<AnnounceUpdate>(&item))
+ ritems.emplace_back("ANU", st.ref(ptr->value));
+
+ else if (const auto * ptr = get_if<ChannelRequest>(&item))
+ ritems.emplace_back("CRQ", st.ref(ptr->value));
+
+ else if (const auto * ptr = get_if<ChannelAccept>(&item))
+ ritems.emplace_back("CAC", st.ref(ptr->value));
+
+ else if (const auto * ptr = get_if<ServiceType>(&item))
+ ritems.emplace_back("STP", ptr->value);
+
+ else if (const auto * ptr = get_if<ServiceRef>(&item))
+ ritems.emplace_back("SRF", st.ref(ptr->value));
}
return PartialObject(PartialRecord(std::move(ritems)));
diff --git a/src/network/protocol.h b/src/network/protocol.h
index c5803ce..df29c05 100644
--- a/src/network/protocol.h
+++ b/src/network/protocol.h
@@ -106,7 +106,17 @@ struct NetworkProtocol::ConnectionReadReady { Connection::Id id; };
struct NetworkProtocol::Header
{
- enum class Type {
+ struct Acknowledged { Digest value; };
+ struct DataRequest { Digest value; };
+ struct DataResponse { Digest value; };
+ struct AnnounceSelf { Digest value; };
+ struct AnnounceUpdate { Digest value; };
+ struct ChannelRequest { Digest value; };
+ struct ChannelAccept { Digest value; };
+ struct ServiceType { UUID value; };
+ struct ServiceRef { Digest value; };
+
+ using Item = variant<
Acknowledged,
DataRequest,
DataResponse,
@@ -115,16 +125,7 @@ struct NetworkProtocol::Header
ChannelRequest,
ChannelAccept,
ServiceType,
- ServiceRef,
- };
-
- struct Item {
- const Type type;
- const variant<Digest, UUID> value;
-
- bool operator==(const Item &) const;
- bool operator!=(const Item & other) const { return !(*this == other); }
- };
+ ServiceRef>;
Header(const vector<Item> & items): items(items) {}
static optional<Header> load(const PartialRef &);
@@ -134,6 +135,11 @@ struct NetworkProtocol::Header
const vector<Item> items;
};
+bool operator==(const NetworkProtocol::Header::Item &, const NetworkProtocol::Header::Item &);
+inline bool operator!=(const NetworkProtocol::Header::Item & left,
+ const NetworkProtocol::Header::Item & right)
+{ return not (left == right); }
+
class ReplyBuilder
{
public: