diff options
-rw-r--r-- | src/network.cpp | 6 | ||||
-rw-r--r-- | src/network/protocol.cpp | 67 | ||||
-rw-r--r-- | src/network/protocol.h | 5 |
3 files changed, 65 insertions, 13 deletions
diff --git a/src/network.cpp b/src/network.cpp index 6840f43..2b8c46b 100644 --- a/src/network.cpp +++ b/src/network.cpp @@ -520,7 +520,8 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head else if (const auto * rsp = get_if<NetworkProtocol::Header::DataResponse>(&item)) { const auto & dgst = rsp->value; - reply.header({ NetworkProtocol::Header::Acknowledged { dgst } }); + if (not holds_alternative<unique_ptr<Channel>>(peer.connection.channel())) + reply.header({ NetworkProtocol::Header::Acknowledged { dgst } }); for (auto & pwref : waiting) { if (auto wref = pwref.lock()) { if (std::find(wref->missing.begin(), wref->missing.end(), dgst) != @@ -554,7 +555,6 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head else if (const auto * anu = get_if<NetworkProtocol::Header::AnnounceUpdate>(&item)) { if (holds_alternative<Identity>(peer.identity)) { const auto & dgst = anu->value; - reply.header({ NetworkProtocol::Header::Acknowledged { dgst } }); shared_ptr<WaitingRef> wref(new WaitingRef { .storage = peer.tempStorage, @@ -628,8 +628,6 @@ void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Head if (serviceType) { const auto & dgst = sref->value; auto pref = peer.partStorage.ref(dgst); - if (pref) - reply.header({ NetworkProtocol::Header::Acknowledged { dgst } }); shared_ptr<WaitingRef> wref(new WaitingRef { .storage = peer.tempStorage, diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp index 93d171a..8e0de61 100644 --- a/src/network/protocol.cpp +++ b/src/network/protocol.cpp @@ -36,6 +36,8 @@ struct NetworkProtocol::ConnectionPriv bool confirmedCookie = false; ChannelState channel = monostate(); vector<vector<uint8_t>> secureOutQueue {}; + + vector<uint64_t> toAcknowledge {}; }; @@ -74,6 +76,23 @@ NetworkProtocol::~NetworkProtocol() NetworkProtocol::PollResult NetworkProtocol::poll() { + { + scoped_lock lock(protocolMutex); + + for (const auto & c : connections) { + { + scoped_lock clock(c->cmutex); + if (c->toAcknowledge.empty()) + continue; + + if (not holds_alternative<unique_ptr<Channel>>(c->channel)) + continue; + } + auto pst = self->ref()->storage().deriveEphemeralStorage(); + c->send(pst, Header {{}}, {}, true); + } + } + sockaddr_in6 addr; if (!recvfrom(buffer, addr)) return ProtocolClosed {}; @@ -90,7 +109,7 @@ NetworkProtocol::PollResult NetworkProtocol::poll() } auto pst = self->ref()->storage().deriveEphemeralStorage(); - bool secure = false; + optional<uint64_t> secure = false; if (auto header = Connection::parsePacket(buffer, nullptr, pst, secure)) { if (auto conn = verifyNewConnection(*header, addr)) return NewConnection { move(*conn) }; @@ -315,12 +334,15 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const Par } } - bool secure = false; + optional<uint64_t> secure = false; if (auto header = parsePacket(buf, channel, partStorage, secure)) { scoped_lock lock(p->cmutex); - if (secure) + if (secure) { + if (header->isAcknowledged()) + p->toAcknowledge.push_back(*secure); return header; + } if (const auto * cookieEcho = header->lookupFirst<Header::CookieEcho>()) { if (!p->protocol->verifyCookie(p->peerAddress, cookieEcho->value)) @@ -351,13 +373,13 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const Par optional<NetworkProtocol::Header> NetworkProtocol::Connection::parsePacket(vector<uint8_t> & buf, Channel * channel, const PartialStorage & partStorage, - bool & secure) + optional<uint64_t> & secure) { vector<uint8_t> decrypted; auto plainBegin = buf.cbegin(); auto plainEnd = buf.cbegin(); - secure = false; + secure = nullopt; if ((buf[0] & 0xE0) == 0x80) { if (not channel) { @@ -365,7 +387,7 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::parsePacket(vecto return nullopt; } - if (auto dec = channel->decrypt(buf.begin() + 1, buf.end(), decrypted, 0)) { + if ((secure = channel->decrypt(buf.begin() + 1, buf.end(), decrypted, 0))) { if (decrypted.empty()) { std::cerr << "empty decrypted content\n"; } @@ -378,8 +400,6 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::parsePacket(vecto return nullopt; } } - - secure = true; } else if ((buf[0] & 0xE0) == 0x60) { plainBegin = buf.begin(); @@ -431,6 +451,15 @@ bool NetworkProtocol::ConnectionPriv::send(const PartialStorage & partStorage, header.items.push_back(Header::CookieSet { protocol->generateCookie(peerAddress) }); } + if (channel) { + for (auto num : toAcknowledge) + header.items.push_back(Header::AcknowledgedSingle { num }); + toAcknowledge.clear(); + } + + if (header.items.empty()) + return false; + part = header.toObject(partStorage).encode(); data.insert(data.end(), part.begin(), part.end()); for (const auto & obj : objs) { @@ -533,6 +562,8 @@ optional<NetworkProtocol::Header> NetworkProtocol::Header::load(const PartialObj if (item.name == "ACK") { if (auto ref = item.asRef()) items.emplace_back(Acknowledged { ref->digest() }); + else if (auto num = item.asInteger()) + items.emplace_back(AcknowledgedSingle { static_cast<uint64_t>(*num) }); } else if (item.name == "VER") { if (auto ver = item.asText()) items.emplace_back(Version { *ver }); @@ -583,6 +614,9 @@ PartialObject NetworkProtocol::Header::toObject(const PartialStorage & st) const if (const auto * ptr = get_if<Acknowledged>(&item)) ritems.emplace_back("ACK", st.ref(ptr->value)); + else if (const auto * ptr = get_if<AcknowledgedSingle>(&item)) + ritems.emplace_back("ACK", Record::Item::Integer(ptr->value)); + else if (const auto * ptr = get_if<Version>(&item)) ritems.emplace_back("VER", ptr->value); @@ -623,4 +657,21 @@ PartialObject NetworkProtocol::Header::toObject(const PartialStorage & st) const return PartialObject(PartialRecord(std::move(ritems))); } +bool NetworkProtocol::Header::isAcknowledged() const +{ + for (const auto & item : items) { + if (holds_alternative<Acknowledged>(item) + || holds_alternative<AcknowledgedSingle>(item) + || holds_alternative<Version>(item) + || holds_alternative<Initiation>(item) + || holds_alternative<CookieSet>(item) + || holds_alternative<CookieEcho>(item) + ) + continue; + + return true; + } + return false; +} + } diff --git a/src/network/protocol.h b/src/network/protocol.h index 3d7c073..ba40744 100644 --- a/src/network/protocol.h +++ b/src/network/protocol.h @@ -116,7 +116,7 @@ public: private: static optional<Header> parsePacket(vector<uint8_t> & buf, Channel * channel, const PartialStorage & st, - bool & secure); + optional<uint64_t> & secure); unique_ptr<ConnectionPriv> p; }; @@ -128,6 +128,7 @@ struct NetworkProtocol::ConnectionReadReady { Connection::Id id; }; struct NetworkProtocol::Header { struct Acknowledged { Digest value; }; + struct AcknowledgedSingle { uint64_t value; }; struct Version { string value; }; struct Initiation { Digest value; }; struct CookieSet { Cookie value; }; @@ -143,6 +144,7 @@ struct NetworkProtocol::Header using Item = variant< Acknowledged, + AcknowledgedSingle, Version, Initiation, CookieSet, @@ -162,6 +164,7 @@ struct NetworkProtocol::Header PartialObject toObject(const PartialStorage &) const; template<class T> const T * lookupFirst() const; + bool isAcknowledged() const; vector<Item> items; }; |