From 1e374ab639af7afbdffd3be3be22be4ba21858e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roman=20Smr=C5=BE?= Date: Sun, 17 Sep 2023 19:27:34 +0200 Subject: Network: acknowledgment using packet counter --- src/network/protocol.cpp | 67 ++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 59 insertions(+), 8 deletions(-) (limited to 'src/network/protocol.cpp') diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp index 93d171a..8e0de61 100644 --- a/src/network/protocol.cpp +++ b/src/network/protocol.cpp @@ -36,6 +36,8 @@ struct NetworkProtocol::ConnectionPriv bool confirmedCookie = false; ChannelState channel = monostate(); vector> secureOutQueue {}; + + vector toAcknowledge {}; }; @@ -74,6 +76,23 @@ NetworkProtocol::~NetworkProtocol() NetworkProtocol::PollResult NetworkProtocol::poll() { + { + scoped_lock lock(protocolMutex); + + for (const auto & c : connections) { + { + scoped_lock clock(c->cmutex); + if (c->toAcknowledge.empty()) + continue; + + if (not holds_alternative>(c->channel)) + continue; + } + auto pst = self->ref()->storage().deriveEphemeralStorage(); + c->send(pst, Header {{}}, {}, true); + } + } + sockaddr_in6 addr; if (!recvfrom(buffer, addr)) return ProtocolClosed {}; @@ -90,7 +109,7 @@ NetworkProtocol::PollResult NetworkProtocol::poll() } auto pst = self->ref()->storage().deriveEphemeralStorage(); - bool secure = false; + optional secure = false; if (auto header = Connection::parsePacket(buffer, nullptr, pst, secure)) { if (auto conn = verifyNewConnection(*header, addr)) return NewConnection { move(*conn) }; @@ -315,12 +334,15 @@ optional NetworkProtocol::Connection::receive(const Par } } - bool secure = false; + optional secure = false; if (auto header = parsePacket(buf, channel, partStorage, secure)) { scoped_lock lock(p->cmutex); - if (secure) + if (secure) { + if (header->isAcknowledged()) + p->toAcknowledge.push_back(*secure); return header; + } if (const auto * cookieEcho = header->lookupFirst()) { if (!p->protocol->verifyCookie(p->peerAddress, cookieEcho->value)) @@ -351,13 +373,13 @@ optional NetworkProtocol::Connection::receive(const Par optional NetworkProtocol::Connection::parsePacket(vector & buf, Channel * channel, const PartialStorage & partStorage, - bool & secure) + optional & secure) { vector decrypted; auto plainBegin = buf.cbegin(); auto plainEnd = buf.cbegin(); - secure = false; + secure = nullopt; if ((buf[0] & 0xE0) == 0x80) { if (not channel) { @@ -365,7 +387,7 @@ optional NetworkProtocol::Connection::parsePacket(vecto return nullopt; } - if (auto dec = channel->decrypt(buf.begin() + 1, buf.end(), decrypted, 0)) { + if ((secure = channel->decrypt(buf.begin() + 1, buf.end(), decrypted, 0))) { if (decrypted.empty()) { std::cerr << "empty decrypted content\n"; } @@ -378,8 +400,6 @@ optional NetworkProtocol::Connection::parsePacket(vecto return nullopt; } } - - secure = true; } else if ((buf[0] & 0xE0) == 0x60) { plainBegin = buf.begin(); @@ -431,6 +451,15 @@ bool NetworkProtocol::ConnectionPriv::send(const PartialStorage & partStorage, header.items.push_back(Header::CookieSet { protocol->generateCookie(peerAddress) }); } + if (channel) { + for (auto num : toAcknowledge) + header.items.push_back(Header::AcknowledgedSingle { num }); + toAcknowledge.clear(); + } + + if (header.items.empty()) + return false; + part = header.toObject(partStorage).encode(); data.insert(data.end(), part.begin(), part.end()); for (const auto & obj : objs) { @@ -533,6 +562,8 @@ optional NetworkProtocol::Header::load(const PartialObj if (item.name == "ACK") { if (auto ref = item.asRef()) items.emplace_back(Acknowledged { ref->digest() }); + else if (auto num = item.asInteger()) + items.emplace_back(AcknowledgedSingle { static_cast(*num) }); } else if (item.name == "VER") { if (auto ver = item.asText()) items.emplace_back(Version { *ver }); @@ -583,6 +614,9 @@ PartialObject NetworkProtocol::Header::toObject(const PartialStorage & st) const if (const auto * ptr = get_if(&item)) ritems.emplace_back("ACK", st.ref(ptr->value)); + else if (const auto * ptr = get_if(&item)) + ritems.emplace_back("ACK", Record::Item::Integer(ptr->value)); + else if (const auto * ptr = get_if(&item)) ritems.emplace_back("VER", ptr->value); @@ -623,4 +657,21 @@ PartialObject NetworkProtocol::Header::toObject(const PartialStorage & st) const return PartialObject(PartialRecord(std::move(ritems))); } +bool NetworkProtocol::Header::isAcknowledged() const +{ + for (const auto & item : items) { + if (holds_alternative(item) + || holds_alternative(item) + || holds_alternative(item) + || holds_alternative(item) + || holds_alternative(item) + || holds_alternative(item) + ) + continue; + + return true; + } + return false; +} + } -- cgit v1.2.3