diff options
| author | Roman Smrž <roman.smrz@seznam.cz> | 2023-09-17 19:27:34 +0200 | 
|---|---|---|
| committer | Roman Smrž <roman.smrz@seznam.cz> | 2023-09-18 21:48:17 +0200 | 
| commit | 1e374ab639af7afbdffd3be3be22be4ba21858e6 (patch) | |
| tree | b691db1b347087e769ed487f371f17bf91a609a4 /src | |
| parent | 512e20fa063e4a4525e47e048f26cc68668e7fac (diff) | |
Network: acknowledgment using packet counter
Diffstat (limited to 'src')
| -rw-r--r-- | src/network.cpp | 6 | ||||
| -rw-r--r-- | src/network/protocol.cpp | 67 | ||||
| -rw-r--r-- | src/network/protocol.h | 5 | 
3 files changed, 65 insertions, 13 deletions
| diff --git a/src/network.cpp b/src/network.cpp index 6840f43..2b8c46b 100644 --- a/src/network.cpp +++ b/src/network.cpp @@ -520,7 +520,8 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head  		else if (const auto * rsp = get_if<NetworkProtocol::Header::DataResponse>(&item)) {  			const auto & dgst = rsp->value; -			reply.header({ NetworkProtocol::Header::Acknowledged { dgst } }); +			if (not holds_alternative<unique_ptr<Channel>>(peer.connection.channel())) +				reply.header({ NetworkProtocol::Header::Acknowledged { dgst } });  			for (auto & pwref : waiting) {  				if (auto wref = pwref.lock()) {  					if (std::find(wref->missing.begin(), wref->missing.end(), dgst) != @@ -554,7 +555,6 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head  		else if (const auto * anu = get_if<NetworkProtocol::Header::AnnounceUpdate>(&item)) {  			if (holds_alternative<Identity>(peer.identity)) {  				const auto & dgst = anu->value; -				reply.header({ NetworkProtocol::Header::Acknowledged { dgst } });  				shared_ptr<WaitingRef> wref(new WaitingRef {  					.storage = peer.tempStorage, @@ -628,8 +628,6 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head  			if (serviceType) {  				const auto & dgst = sref->value;  				auto pref = peer.partStorage.ref(dgst); -				if (pref) -					reply.header({ NetworkProtocol::Header::Acknowledged { dgst } });  				shared_ptr<WaitingRef> wref(new WaitingRef {  					.storage = peer.tempStorage, diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp index 93d171a..8e0de61 100644 --- a/src/network/protocol.cpp +++ b/src/network/protocol.cpp @@ -36,6 +36,8 @@ struct NetworkProtocol::ConnectionPriv  	bool confirmedCookie = false;  	ChannelState channel = monostate();  	vector<vector<uint8_t>> secureOutQueue {}; + +	vector<uint64_t> toAcknowledge {};  }; @@ -74,6 +76,23 @@ NetworkProtocol::~NetworkProtocol()  NetworkProtocol::PollResult NetworkProtocol::poll()  { +	{ +		scoped_lock lock(protocolMutex); + +		for (const auto & c : connections) { +			{ +				scoped_lock clock(c->cmutex); +				if (c->toAcknowledge.empty()) +					continue; + +				if (not holds_alternative<unique_ptr<Channel>>(c->channel)) +					continue; +			} +			auto pst = self->ref()->storage().deriveEphemeralStorage(); +			c->send(pst, Header {{}}, {}, true); +		} +	} +  	sockaddr_in6 addr;  	if (!recvfrom(buffer, addr))  		return ProtocolClosed {}; @@ -90,7 +109,7 @@ NetworkProtocol::PollResult NetworkProtocol::poll()  		}  		auto pst = self->ref()->storage().deriveEphemeralStorage(); -		bool secure = false; +		optional<uint64_t> secure = false;  		if (auto header = Connection::parsePacket(buffer, nullptr, pst, secure)) {  			if (auto conn = verifyNewConnection(*header, addr))  				return NewConnection { move(*conn) }; @@ -315,12 +334,15 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const Par  		}  	} -	bool secure = false; +	optional<uint64_t> secure = false;  	if (auto header = parsePacket(buf, channel, partStorage, secure)) {  		scoped_lock lock(p->cmutex); -		if (secure) +		if (secure) { +			if (header->isAcknowledged()) +				p->toAcknowledge.push_back(*secure);  			return header; +		}  		if (const auto * cookieEcho = header->lookupFirst<Header::CookieEcho>()) {  			if (!p->protocol->verifyCookie(p->peerAddress, cookieEcho->value)) @@ -351,13 +373,13 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const Par  optional<NetworkProtocol::Header> NetworkProtocol::Connection::parsePacket(vector<uint8_t> & buf,  		Channel * channel, const PartialStorage & partStorage, -		bool & secure) +		optional<uint64_t> & secure)  {  	vector<uint8_t> decrypted;  	auto plainBegin = buf.cbegin();  	auto plainEnd = buf.cbegin(); -	secure = false; +	secure = nullopt;  	if ((buf[0] & 0xE0) == 0x80) {  		if (not channel) { @@ -365,7 +387,7 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::parsePacket(vecto  			return nullopt;  		} -		if (auto dec = channel->decrypt(buf.begin() + 1, buf.end(), decrypted, 0)) { +		if ((secure = channel->decrypt(buf.begin() + 1, buf.end(), decrypted, 0))) {  			if (decrypted.empty()) {  				std::cerr << "empty decrypted content\n";  			} @@ -378,8 +400,6 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::parsePacket(vecto  				return nullopt;  			}  		} - -		secure = true;  	}  	else if ((buf[0] & 0xE0) == 0x60) {  		plainBegin = buf.begin(); @@ -431,6 +451,15 @@ bool NetworkProtocol::ConnectionPriv::send(const PartialStorage & partStorage,  				header.items.push_back(Header::CookieSet { protocol->generateCookie(peerAddress) });  		} +		if (channel) { +			for (auto num : toAcknowledge) +				header.items.push_back(Header::AcknowledgedSingle { num }); +			toAcknowledge.clear(); +		} + +		if (header.items.empty()) +			return false; +  		part = header.toObject(partStorage).encode();  		data.insert(data.end(), part.begin(), part.end());  		for (const auto & obj : objs) { @@ -533,6 +562,8 @@ optional<NetworkProtocol::Header> NetworkProtocol::Header::load(const PartialObj  		if (item.name == "ACK") {  			if (auto ref = item.asRef())  				items.emplace_back(Acknowledged { ref->digest() }); +			else if (auto num = item.asInteger()) +				items.emplace_back(AcknowledgedSingle { static_cast<uint64_t>(*num) });  		} else if (item.name == "VER") {  			if (auto ver = item.asText())  				items.emplace_back(Version { *ver }); @@ -583,6 +614,9 @@ PartialObject NetworkProtocol::Header::toObject(const PartialStorage & st) const  		if (const auto * ptr = get_if<Acknowledged>(&item))  			ritems.emplace_back("ACK", st.ref(ptr->value)); +		else if (const auto * ptr = get_if<AcknowledgedSingle>(&item)) +			ritems.emplace_back("ACK", Record::Item::Integer(ptr->value)); +  		else if (const auto * ptr = get_if<Version>(&item))  			ritems.emplace_back("VER", ptr->value); @@ -623,4 +657,21 @@ PartialObject NetworkProtocol::Header::toObject(const PartialStorage & st) const  	return PartialObject(PartialRecord(std::move(ritems)));  } +bool NetworkProtocol::Header::isAcknowledged() const +{ +	for (const auto & item : items) { +		if (holds_alternative<Acknowledged>(item) +		 || holds_alternative<AcknowledgedSingle>(item) +		 || holds_alternative<Version>(item) +		 || holds_alternative<Initiation>(item) +		 || holds_alternative<CookieSet>(item) +		 || holds_alternative<CookieEcho>(item) +		   ) +			continue; + +		return true; +	} +	return false; +} +  } diff --git a/src/network/protocol.h b/src/network/protocol.h index 3d7c073..ba40744 100644 --- a/src/network/protocol.h +++ b/src/network/protocol.h @@ -116,7 +116,7 @@ public:  private:  	static optional<Header> parsePacket(vector<uint8_t> & buf,  			Channel * channel, const PartialStorage & st, -			bool & secure); +			optional<uint64_t> & secure);  	unique_ptr<ConnectionPriv> p;  }; @@ -128,6 +128,7 @@ struct NetworkProtocol::ConnectionReadReady { Connection::Id id; };  struct NetworkProtocol::Header  {  	struct Acknowledged { Digest value; }; +	struct AcknowledgedSingle { uint64_t value; };  	struct Version { string value; };  	struct Initiation { Digest value; };  	struct CookieSet { Cookie value; }; @@ -143,6 +144,7 @@ struct NetworkProtocol::Header  	using Item = variant<  		Acknowledged, +		AcknowledgedSingle,  		Version,  		Initiation,  		CookieSet, @@ -162,6 +164,7 @@ struct NetworkProtocol::Header  	PartialObject toObject(const PartialStorage &) const;  	template<class T> const T * lookupFirst() const; +	bool isAcknowledged() const;  	vector<Item> items;  }; |