summaryrefslogtreecommitdiff
path: root/src/network
diff options
context:
space:
mode:
Diffstat (limited to 'src/network')
-rw-r--r--src/network/protocol.cpp67
-rw-r--r--src/network/protocol.h5
2 files changed, 63 insertions, 9 deletions
diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp
index 93d171a..8e0de61 100644
--- a/src/network/protocol.cpp
+++ b/src/network/protocol.cpp
@@ -36,6 +36,8 @@ struct NetworkProtocol::ConnectionPriv
bool confirmedCookie = false;
ChannelState channel = monostate();
vector<vector<uint8_t>> secureOutQueue {};
+
+ vector<uint64_t> toAcknowledge {};
};
@@ -74,6 +76,23 @@ NetworkProtocol::~NetworkProtocol()
NetworkProtocol::PollResult NetworkProtocol::poll()
{
+ {
+ scoped_lock lock(protocolMutex);
+
+ for (const auto & c : connections) {
+ {
+ scoped_lock clock(c->cmutex);
+ if (c->toAcknowledge.empty())
+ continue;
+
+ if (not holds_alternative<unique_ptr<Channel>>(c->channel))
+ continue;
+ }
+ auto pst = self->ref()->storage().deriveEphemeralStorage();
+ c->send(pst, Header {{}}, {}, true);
+ }
+ }
+
sockaddr_in6 addr;
if (!recvfrom(buffer, addr))
return ProtocolClosed {};
@@ -90,7 +109,7 @@ NetworkProtocol::PollResult NetworkProtocol::poll()
}
auto pst = self->ref()->storage().deriveEphemeralStorage();
- bool secure = false;
+ optional<uint64_t> secure = false;
if (auto header = Connection::parsePacket(buffer, nullptr, pst, secure)) {
if (auto conn = verifyNewConnection(*header, addr))
return NewConnection { move(*conn) };
@@ -315,12 +334,15 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const Par
}
}
- bool secure = false;
+ optional<uint64_t> secure = false;
if (auto header = parsePacket(buf, channel, partStorage, secure)) {
scoped_lock lock(p->cmutex);
- if (secure)
+ if (secure) {
+ if (header->isAcknowledged())
+ p->toAcknowledge.push_back(*secure);
return header;
+ }
if (const auto * cookieEcho = header->lookupFirst<Header::CookieEcho>()) {
if (!p->protocol->verifyCookie(p->peerAddress, cookieEcho->value))
@@ -351,13 +373,13 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const Par
optional<NetworkProtocol::Header> NetworkProtocol::Connection::parsePacket(vector<uint8_t> & buf,
Channel * channel, const PartialStorage & partStorage,
- bool & secure)
+ optional<uint64_t> & secure)
{
vector<uint8_t> decrypted;
auto plainBegin = buf.cbegin();
auto plainEnd = buf.cbegin();
- secure = false;
+ secure = nullopt;
if ((buf[0] & 0xE0) == 0x80) {
if (not channel) {
@@ -365,7 +387,7 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::parsePacket(vecto
return nullopt;
}
- if (auto dec = channel->decrypt(buf.begin() + 1, buf.end(), decrypted, 0)) {
+ if ((secure = channel->decrypt(buf.begin() + 1, buf.end(), decrypted, 0))) {
if (decrypted.empty()) {
std::cerr << "empty decrypted content\n";
}
@@ -378,8 +400,6 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::parsePacket(vecto
return nullopt;
}
}
-
- secure = true;
}
else if ((buf[0] & 0xE0) == 0x60) {
plainBegin = buf.begin();
@@ -431,6 +451,15 @@ bool NetworkProtocol::ConnectionPriv::send(const PartialStorage & partStorage,
header.items.push_back(Header::CookieSet { protocol->generateCookie(peerAddress) });
}
+ if (channel) {
+ for (auto num : toAcknowledge)
+ header.items.push_back(Header::AcknowledgedSingle { num });
+ toAcknowledge.clear();
+ }
+
+ if (header.items.empty())
+ return false;
+
part = header.toObject(partStorage).encode();
data.insert(data.end(), part.begin(), part.end());
for (const auto & obj : objs) {
@@ -533,6 +562,8 @@ optional<NetworkProtocol::Header> NetworkProtocol::Header::load(const PartialObj
if (item.name == "ACK") {
if (auto ref = item.asRef())
items.emplace_back(Acknowledged { ref->digest() });
+ else if (auto num = item.asInteger())
+ items.emplace_back(AcknowledgedSingle { static_cast<uint64_t>(*num) });
} else if (item.name == "VER") {
if (auto ver = item.asText())
items.emplace_back(Version { *ver });
@@ -583,6 +614,9 @@ PartialObject NetworkProtocol::Header::toObject(const PartialStorage & st) const
if (const auto * ptr = get_if<Acknowledged>(&item))
ritems.emplace_back("ACK", st.ref(ptr->value));
+ else if (const auto * ptr = get_if<AcknowledgedSingle>(&item))
+ ritems.emplace_back("ACK", Record::Item::Integer(ptr->value));
+
else if (const auto * ptr = get_if<Version>(&item))
ritems.emplace_back("VER", ptr->value);
@@ -623,4 +657,21 @@ PartialObject NetworkProtocol::Header::toObject(const PartialStorage & st) const
return PartialObject(PartialRecord(std::move(ritems)));
}
+bool NetworkProtocol::Header::isAcknowledged() const
+{
+ for (const auto & item : items) {
+ if (holds_alternative<Acknowledged>(item)
+ || holds_alternative<AcknowledgedSingle>(item)
+ || holds_alternative<Version>(item)
+ || holds_alternative<Initiation>(item)
+ || holds_alternative<CookieSet>(item)
+ || holds_alternative<CookieEcho>(item)
+ )
+ continue;
+
+ return true;
+ }
+ return false;
+}
+
}
diff --git a/src/network/protocol.h b/src/network/protocol.h
index 3d7c073..ba40744 100644
--- a/src/network/protocol.h
+++ b/src/network/protocol.h
@@ -116,7 +116,7 @@ public:
private:
static optional<Header> parsePacket(vector<uint8_t> & buf,
Channel * channel, const PartialStorage & st,
- bool & secure);
+ optional<uint64_t> & secure);
unique_ptr<ConnectionPriv> p;
};
@@ -128,6 +128,7 @@ struct NetworkProtocol::ConnectionReadReady { Connection::Id id; };
struct NetworkProtocol::Header
{
struct Acknowledged { Digest value; };
+ struct AcknowledgedSingle { uint64_t value; };
struct Version { string value; };
struct Initiation { Digest value; };
struct CookieSet { Cookie value; };
@@ -143,6 +144,7 @@ struct NetworkProtocol::Header
using Item = variant<
Acknowledged,
+ AcknowledgedSingle,
Version,
Initiation,
CookieSet,
@@ -162,6 +164,7 @@ struct NetworkProtocol::Header
PartialObject toObject(const PartialStorage &) const;
template<class T> const T * lookupFirst() const;
+ bool isAcknowledged() const;
vector<Item> items;
};