diff options
| -rw-r--r-- | src/network.cpp | 173 | ||||
| -rw-r--r-- | src/network/protocol.cpp | 130 | ||||
| -rw-r--r-- | src/network/protocol.h | 28 | 
3 files changed, 144 insertions, 187 deletions
| diff --git a/src/network.cpp b/src/network.cpp index 8c181cf..7a5a804 100644 --- a/src/network.cpp +++ b/src/network.cpp @@ -114,8 +114,7 @@ void Server::addPeer(const string & node, const string & service) const  			vector<NetworkProtocol::Header::Item> header;  			{  				shared_lock lock(p->selfMutex); -				header.push_back(NetworkProtocol::Header::Item { -					NetworkProtocol::Header::Type::AnnounceSelf, p->self.ref()->digest() }); +				header.push_back(NetworkProtocol::Header::AnnounceSelf { p->self.ref()->digest() });  			}  			peer.connection.send(peer.partStorage, header, {}, false);  			return; @@ -226,8 +225,8 @@ bool Peer::send(UUID uuid, const Ref & ref, const Object & obj) const  {  	if (auto speer = p->speer.lock()) {  		NetworkProtocol::Header header({ -			{ NetworkProtocol::Header::Type::ServiceType, uuid }, -				{ NetworkProtocol::Header::Type::ServiceRef, ref.digest() }, +			NetworkProtocol::Header::ServiceType { uuid }, +			NetworkProtocol::Header::ServiceRef { ref.digest() },  		});  		speer->connection.send(speer->partStorage, header, { obj }, true);  		return true; @@ -417,7 +416,7 @@ void Server::Priv::doAnnounce()  		if (lastAnnounce + announceInterval < now) {  			shared_lock slock(selfMutex);  			NetworkProtocol::Header header({ -				{ NetworkProtocol::Header::Type::AnnounceSelf, self.ref()->digest() } +				NetworkProtocol::Header::AnnounceSelf { self.ref()->digest() },  			});  			vector<uint8_t> bytes = header.toObject(pst).encode(); @@ -506,32 +505,29 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head  	optional<UUID> serviceType; -	for (auto & item : header.items) { -		switch (item.type) { -		case NetworkProtocol::Header::Type::Acknowledged: { -			auto dgst = std::get<Digest>(item.value); +	for (const auto & item : header.items) { +		if (const auto * ack = get_if<NetworkProtocol::Header::Acknowledged>(&item)) { +			const auto & dgst = ack->value;  			if (holds_alternative<Stored<ChannelAccept>>(peer.connection.channel()) &&  					std::get<Stored<ChannelAccept>>(peer.connection.channel()).ref().digest() == dgst)  				peer.finalizeChannel(reply,  					std::get<Stored<ChannelAccept>>(peer.connection.channel())->data->channel()); -			break;  		} -		case NetworkProtocol::Header::Type::DataRequest: { -			auto dgst = std::get<Digest>(item.value); +		else if (const auto * req = get_if<NetworkProtocol::Header::DataRequest>(&item)) { +			const auto & dgst = req->value;  			if (holds_alternative<unique_ptr<Channel>>(peer.connection.channel()) ||  					plaintextRefs.find(dgst) != plaintextRefs.end()) {  				if (auto ref = peer.tempStorage.ref(dgst)) { -					reply.header({ NetworkProtocol::Header::Type::DataResponse, ref->digest() }); +					reply.header({ NetworkProtocol::Header::DataResponse { ref->digest() } });  					reply.body(*ref);  				}  			} -			break;  		} -		case NetworkProtocol::Header::Type::DataResponse: { -			auto dgst = std::get<Digest>(item.value); -			reply.header({ NetworkProtocol::Header::Type::Acknowledged, dgst }); +		else if (const auto * rsp = get_if<NetworkProtocol::Header::DataResponse>(&item)) { +			const auto & dgst = rsp->value; +			reply.header({ NetworkProtocol::Header::Acknowledged { dgst } });  			for (auto & pwref : waiting) {  				if (auto wref = pwref.lock()) {  					if (std::find(wref->missing.begin(), wref->missing.end(), dgst) != @@ -543,16 +539,13 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head  			}  			waiting.erase(std::remove_if(waiting.begin(), waiting.end(),  						[](auto & wref) { return wref.expired(); }), waiting.end()); -			break;  		} -		case NetworkProtocol::Header::Type::AnnounceSelf: { -			auto dgst = std::get<Digest>(item.value); -			if (dgst == self.ref()->digest()) -				break; - -			if (holds_alternative<monostate>(peer.identity)) { -				reply.header({ NetworkProtocol::Header::Type::AnnounceSelf, self.ref()->digest()}); +		else if (const auto * ann = get_if<NetworkProtocol::Header::AnnounceSelf>(&item)) { +			const auto & dgst = ann->value; +			if (dgst != self.ref()->digest() && +					holds_alternative<monostate>(peer.identity)) { +				reply.header({ NetworkProtocol::Header::AnnounceSelf { self.ref()->digest() }});  				shared_ptr<WaitingRef> wref(new WaitingRef {  					.storage = peer.tempStorage, @@ -563,13 +556,12 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head  				peer.identity = wref;  				wref->check(reply);  			} -			break;  		} -		case NetworkProtocol::Header::Type::AnnounceUpdate: +		else if (const auto * anu = get_if<NetworkProtocol::Header::AnnounceUpdate>(&item)) {  			if (holds_alternative<Identity>(peer.identity)) { -				auto dgst = std::get<Digest>(item.value); -				reply.header({ NetworkProtocol::Header::Type::Acknowledged, dgst }); +				const auto & dgst = anu->value; +				reply.header({ NetworkProtocol::Header::Acknowledged { dgst } });  				shared_ptr<WaitingRef> wref(new WaitingRef {  					.storage = peer.tempStorage, @@ -580,76 +572,81 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head  				peer.identityUpdates.push_back(wref);  				wref->check(reply);  			} -			break; +		} -		case NetworkProtocol::Header::Type::ChannelRequest: { -			auto dgst = std::get<Digest>(item.value); -			reply.header({ NetworkProtocol::Header::Type::Acknowledged, dgst }); +		else if (const auto * req = get_if<NetworkProtocol::Header::ChannelRequest>(&item)) { +			const auto & dgst = req->value; +			reply.header({ NetworkProtocol::Header::Acknowledged { dgst } });  			if (holds_alternative<Stored<ChannelRequest>>(peer.connection.channel()) && -					std::get<Stored<ChannelRequest>>(peer.connection.channel()).ref().digest() < dgst) -				break; +					std::get<Stored<ChannelRequest>>(peer.connection.channel()).ref().digest() < dgst) { +				// TODO: reject request with lower priority +			} -			if (holds_alternative<Stored<ChannelAccept>>(peer.connection.channel())) -				break; +			else if (holds_alternative<Stored<ChannelAccept>>(peer.connection.channel())) { +				// TODO: reject when we already sent accept +			} -			shared_ptr<WaitingRef> wref(new WaitingRef { -				.storage = peer.tempStorage, -				.ref = peer.partStorage.ref(dgst), -				.missing = {}, -			}); -			waiting.push_back(wref); -			peer.connection.channel() = wref; -			wref->check(reply); -			break; +			else { +				shared_ptr<WaitingRef> wref(new WaitingRef { +					.storage = peer.tempStorage, +					.ref = peer.partStorage.ref(dgst), +					.missing = {}, +				}); +				waiting.push_back(wref); +				peer.connection.channel() = wref; +				wref->check(reply); +			}  		} -		case NetworkProtocol::Header::Type::ChannelAccept: { -			auto dgst = std::get<Digest>(item.value); +		else if (const auto * acc = get_if<NetworkProtocol::Header::ChannelAccept>(&item)) { +			const auto & dgst = acc->value;  			if (holds_alternative<Stored<ChannelAccept>>(peer.connection.channel()) && -					std::get<Stored<ChannelAccept>>(peer.connection.channel()).ref().digest() < dgst) -				break; - -			auto cres = peer.tempStorage.copy(peer.partStorage.ref(dgst)); -			if (auto r = std::get_if<Ref>(&cres)) { -				auto acc = ChannelAccept::load(*r); -				if (holds_alternative<Identity>(peer.identity) && -						acc.isSignedBy(std::get<Identity>(peer.identity).keyMessage())) { -					reply.header({ NetworkProtocol::Header::Type::Acknowledged, dgst }); -					peer.finalizeChannel(reply, acc.data->channel()); +					std::get<Stored<ChannelAccept>>(peer.connection.channel()).ref().digest() < dgst) { +				// TODO: reject request with lower priority +			} + +			else { +				auto cres = peer.tempStorage.copy(peer.partStorage.ref(dgst)); +				if (auto r = get_if<Ref>(&cres)) { +					auto acc = ChannelAccept::load(*r); +					if (holds_alternative<Identity>(peer.identity) && +							acc.isSignedBy(std::get<Identity>(peer.identity).keyMessage())) { +						reply.header({ NetworkProtocol::Header::Acknowledged { dgst } }); +						peer.finalizeChannel(reply, acc.data->channel()); +					}  				}  			} -			break;  		} -		case NetworkProtocol::Header::Type::ServiceType: +		else if (const auto * stype = get_if<NetworkProtocol::Header::ServiceType>(&item)) {  			if (!serviceType) -				serviceType = std::get<UUID>(item.value); -			break; +				serviceType = stype->value; +		} -		case NetworkProtocol::Header::Type::ServiceRef: +		else if (const auto * sref = get_if<NetworkProtocol::Header::ServiceRef>(&item)) {  			if (!serviceType)  				for (auto & item : header.items) -					if (item.type == NetworkProtocol::Header::Type::ServiceType) { -						serviceType = std::get<UUID>(item.value); +					if (const auto * stype = get_if<NetworkProtocol::Header::ServiceType>(&item)) { +						serviceType = stype->value;  						break;  					} -			if (!serviceType) -				break; -			auto dgst = std::get<Digest>(item.value); -			auto pref = peer.partStorage.ref(dgst); -			if (pref) -				reply.header({ NetworkProtocol::Header::Type::Acknowledged, dgst }); +			if (serviceType) { +				const auto & dgst = sref->value; +				auto pref = peer.partStorage.ref(dgst); +				if (pref) +					reply.header({ NetworkProtocol::Header::Acknowledged { dgst } }); -			shared_ptr<WaitingRef> wref(new WaitingRef { -				.storage = peer.tempStorage, -				.ref = pref, -				.missing = {}, -			}); -			waiting.push_back(wref); -			peer.serviceQueue.emplace_back(*serviceType, wref); -			wref->check(reply); +				shared_ptr<WaitingRef> wref(new WaitingRef { +					.storage = peer.tempStorage, +					.ref = pref, +					.missing = {}, +				}); +				waiting.push_back(wref); +				peer.serviceQueue.emplace_back(*serviceType, wref); +				wref->check(reply); +			}  		}  	}  } @@ -665,11 +662,9 @@ void Server::Priv::handleLocalHeadChange(const Head<LocalState> & head)  			vector<NetworkProtocol::Header::Item> hitems;  			for (const auto & r : self.refs()) -				hitems.push_back(NetworkProtocol::Header::Item { -					NetworkProtocol::Header::Type::AnnounceUpdate, r.digest() }); +				hitems.push_back(NetworkProtocol::Header::AnnounceUpdate { r.digest() });  			for (const auto & r : self.updates()) -				hitems.push_back(NetworkProtocol::Header::Item { -					NetworkProtocol::Header::Type::AnnounceUpdate, r.digest() }); +				hitems.push_back(NetworkProtocol::Header::AnnounceUpdate { r.digest() });  			NetworkProtocol::Header header(hitems); @@ -724,7 +719,7 @@ void Server::Peer::updateChannel(ReplyBuilder & reply)  		auto req = Channel::generateRequest(tempStorage,  				server.self, std::get<Identity>(identity));  		connection.channel().emplace<Stored<ChannelRequest>>(req); -		reply.header({ NetworkProtocol::Header::Type::ChannelRequest, req.ref().digest() }); +		reply.header({ NetworkProtocol::Header::ChannelRequest { req.ref().digest() } });  		reply.body(req.ref());  		reply.body(req->data.ref());  		reply.body(req->data->key.ref()); @@ -739,7 +734,7 @@ void Server::Peer::updateChannel(ReplyBuilder & reply)  					req->isSignedBy(std::get<Identity>(identity).keyMessage())) {  				if (auto acc = Channel::acceptRequest(server.self, std::get<Identity>(identity), req)) {  					connection.channel().emplace<Stored<ChannelAccept>>(*acc); -					reply.header({ NetworkProtocol::Header::Type::ChannelAccept, acc->ref().digest() }); +					reply.header({ NetworkProtocol::Header::ChannelAccept { acc->ref().digest() } });  					reply.body(acc->ref());  					reply.body(acc.value()->data.ref());  					reply.body(acc.value()->data->key.ref()); @@ -761,11 +756,9 @@ void Server::Peer::finalizeChannel(ReplyBuilder & reply, unique_ptr<Channel> ch)  	vector<NetworkProtocol::Header::Item> hitems;  	for (const auto & r : server.self.refs()) -		reply.header(NetworkProtocol::Header::Item { -			NetworkProtocol::Header::Type::AnnounceUpdate, r.digest() }); +		reply.header(NetworkProtocol::Header::AnnounceUpdate { r.digest() });  	for (const auto & r : server.self.updates()) -		reply.header(NetworkProtocol::Header::Item { -			NetworkProtocol::Header::Type::AnnounceUpdate, r.digest() }); +		reply.header(NetworkProtocol::Header::AnnounceUpdate { r.digest() });  }  void Server::Peer::updateService(ReplyBuilder & reply) @@ -848,7 +841,7 @@ optional<Ref> WaitingRef::check(ReplyBuilder & reply)  		return r;  	for (const auto & d : missing) -		reply.header({ NetworkProtocol::Header::Type::DataRequest, d }); +		reply.header({ NetworkProtocol::Header::DataRequest { d } });  	return nullopt;  } diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp index f38267f..ede7023 100644 --- a/src/network/protocol.cpp +++ b/src/network/protocol.cpp @@ -9,11 +9,12 @@  #include <mutex>  #include <system_error> +using std::get_if;  using std::holds_alternative;  using std::move;  using std::nullopt; -using std::runtime_error;  using std::scoped_lock; +using std::visit;  namespace erebos { @@ -327,21 +328,16 @@ void NetworkProtocol::Connection::trySendOutQueue()  /* Header                                                                     */  /******************************************************************************/ -bool NetworkProtocol::Header::Item::operator==(const Item & other) const +bool operator==(const NetworkProtocol::Header::Item & left, +		const NetworkProtocol::Header::Item & right)  { -	if (type != other.type) +	if (left.index() != right.index())  		return false; -	if (value.index() != other.value.index()) -		return false; - -	if (holds_alternative<Digest>(value)) -		return std::get<Digest>(value) == std::get<Digest>(other.value); - -	if (holds_alternative<UUID>(value)) -		return std::get<UUID>(value) == std::get<UUID>(other.value); - -	throw runtime_error("unhandled network header item type"); +	return visit([&](auto && arg) { +            using T = std::decay_t<decltype(arg)>; +	    return arg.value == std::get<T>(right).value; +	}, left);  }  optional<NetworkProtocol::Header> NetworkProtocol::Header::load(const PartialRef & ref) @@ -359,58 +355,31 @@ optional<NetworkProtocol::Header> NetworkProtocol::Header::load(const PartialObj  	for (const auto & item : rec->items()) {  		if (item.name == "ACK") {  			if (auto ref = item.asRef()) -				items.emplace_back(Item { -					.type = Type::Acknowledged, -					.value = ref->digest(), -				}); +				items.emplace_back(Acknowledged { ref->digest() });  		} else if (item.name == "REQ") {  			if (auto ref = item.asRef()) -				items.emplace_back(Item { -					.type = Type::DataRequest, -					.value = ref->digest(), -				}); +				items.emplace_back(DataRequest { ref->digest() });  		} else if (item.name == "RSP") {  			if (auto ref = item.asRef()) -				items.emplace_back(Item { -					.type = Type::DataResponse, -					.value = ref->digest(), -				}); +				items.emplace_back(DataResponse { ref->digest() });  		} else if (item.name == "ANN") {  			if (auto ref = item.asRef()) -				items.emplace_back(Item { -					.type = Type::AnnounceSelf, -					.value = ref->digest(), -				}); +				items.emplace_back(AnnounceSelf { ref->digest() });  		} else if (item.name == "ANU") {  			if (auto ref = item.asRef()) -				items.emplace_back(Item { -					.type = Type::AnnounceUpdate, -					.value = ref->digest(), -				}); +				items.emplace_back(AnnounceUpdate { ref->digest() });  		} else if (item.name == "CRQ") {  			if (auto ref = item.asRef()) -				items.emplace_back(Item { -					.type = Type::ChannelRequest, -					.value = ref->digest(), -				}); +				items.emplace_back(ChannelRequest { ref->digest() });  		} else if (item.name == "CAC") {  			if (auto ref = item.asRef()) -				items.emplace_back(Item { -					.type = Type::ChannelAccept, -					.value = ref->digest(), -				}); +				items.emplace_back(ChannelAccept { ref->digest() });  		} else if (item.name == "STP") {  			if (auto val = item.asUUID()) -				items.emplace_back(Item { -					.type = Type::ServiceType, -					.value = *val, -				}); +				items.emplace_back(ServiceType { *val });  		} else if (item.name == "SRF") {  			if (auto ref = item.asRef()) -				items.emplace_back(Item { -					.type = Type::ServiceRef, -					.value = ref->digest(), -				}); +				items.emplace_back(ServiceRef { ref->digest() });  		}  	} @@ -422,43 +391,32 @@ PartialObject NetworkProtocol::Header::toObject(const PartialStorage & st) const  	vector<PartialRecord::Item> ritems;  	for (const auto & item : items) { -		switch (item.type) { -		case Type::Acknowledged: -			ritems.emplace_back("ACK", st.ref(std::get<Digest>(item.value))); -			break; - -		case Type::DataRequest: -			ritems.emplace_back("REQ", st.ref(std::get<Digest>(item.value))); -			break; - -		case Type::DataResponse: -			ritems.emplace_back("RSP", st.ref(std::get<Digest>(item.value))); -			break; - -		case Type::AnnounceSelf: -			ritems.emplace_back("ANN", st.ref(std::get<Digest>(item.value))); -			break; - -		case Type::AnnounceUpdate: -			ritems.emplace_back("ANU", st.ref(std::get<Digest>(item.value))); -			break; - -		case Type::ChannelRequest: -			ritems.emplace_back("CRQ", st.ref(std::get<Digest>(item.value))); -			break; - -		case Type::ChannelAccept: -			ritems.emplace_back("CAC", st.ref(std::get<Digest>(item.value))); -			break; - -		case Type::ServiceType: -			ritems.emplace_back("STP", std::get<UUID>(item.value)); -			break; - -		case Type::ServiceRef: -			ritems.emplace_back("SRF", st.ref(std::get<Digest>(item.value))); -			break; -		} +		if (const auto * ptr = get_if<Acknowledged>(&item)) +			ritems.emplace_back("ACK", st.ref(ptr->value)); + +		else if (const auto * ptr = get_if<DataRequest>(&item)) +			ritems.emplace_back("REQ", st.ref(ptr->value)); + +		else if (const auto * ptr = get_if<DataResponse>(&item)) +			ritems.emplace_back("RSP", st.ref(ptr->value)); + +		else if (const auto * ptr = get_if<AnnounceSelf>(&item)) +			ritems.emplace_back("ANN", st.ref(ptr->value)); + +		else if (const auto * ptr = get_if<AnnounceUpdate>(&item)) +			ritems.emplace_back("ANU", st.ref(ptr->value)); + +		else if (const auto * ptr = get_if<ChannelRequest>(&item)) +			ritems.emplace_back("CRQ", st.ref(ptr->value)); + +		else if (const auto * ptr = get_if<ChannelAccept>(&item)) +			ritems.emplace_back("CAC", st.ref(ptr->value)); + +		else if (const auto * ptr = get_if<ServiceType>(&item)) +			ritems.emplace_back("STP", ptr->value); + +		else if (const auto * ptr = get_if<ServiceRef>(&item)) +			ritems.emplace_back("SRF", st.ref(ptr->value));  	}  	return PartialObject(PartialRecord(std::move(ritems))); diff --git a/src/network/protocol.h b/src/network/protocol.h index c5803ce..df29c05 100644 --- a/src/network/protocol.h +++ b/src/network/protocol.h @@ -106,7 +106,17 @@ struct NetworkProtocol::ConnectionReadReady { Connection::Id id; };  struct NetworkProtocol::Header  { -	enum class Type { +	struct Acknowledged { Digest value; }; +	struct DataRequest { Digest value; }; +	struct DataResponse { Digest value; }; +	struct AnnounceSelf { Digest value; }; +	struct AnnounceUpdate { Digest value; }; +	struct ChannelRequest { Digest value; }; +	struct ChannelAccept { Digest value; }; +	struct ServiceType { UUID value; }; +	struct ServiceRef { Digest value; }; + +	using Item = variant<  		Acknowledged,  		DataRequest,  		DataResponse, @@ -115,16 +125,7 @@ struct NetworkProtocol::Header  		ChannelRequest,  		ChannelAccept,  		ServiceType, -		ServiceRef, -	}; - -	struct Item { -		const Type type; -		const variant<Digest, UUID> value; - -		bool operator==(const Item &) const; -		bool operator!=(const Item & other) const { return !(*this == other); } -	}; +		ServiceRef>;  	Header(const vector<Item> & items): items(items) {}  	static optional<Header> load(const PartialRef &); @@ -134,6 +135,11 @@ struct NetworkProtocol::Header  	const vector<Item> items;  }; +bool operator==(const NetworkProtocol::Header::Item &, const NetworkProtocol::Header::Item &); +inline bool operator!=(const NetworkProtocol::Header::Item & left, +		const NetworkProtocol::Header::Item & right) +{ return not (left == right); } +  class ReplyBuilder  {  public: |