From d0c48bf9b90dfbd55908a88a5aba411ca9b8e600 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roman=20Smr=C5=BE?= Date: Sun, 27 Aug 2023 21:52:29 +0200 Subject: Network: connection initiation with cookie --- src/network.cpp | 10 ++- src/network/protocol.cpp | 220 +++++++++++++++++++++++++++++++++++------------ src/network/protocol.h | 40 ++++++++- 3 files changed, 213 insertions(+), 57 deletions(-) diff --git a/src/network.cpp b/src/network.cpp index a3d1130..da480c3 100644 --- a/src/network.cpp +++ b/src/network.cpp @@ -221,7 +221,7 @@ bool Peer::send(UUID uuid, const Ref & ref, const Object & obj) const NetworkProtocol::Header::ServiceType { uuid }, NetworkProtocol::Header::ServiceRef { ref.digest() }, }); - speer->connection.send(speer->partStorage, header, { obj }, true); + speer->connection.send(speer->partStorage, move(header), { obj }, true); return true; } @@ -363,6 +363,11 @@ void Server::Priv::doListen() if (holds_alternative(res)) break; + if (const auto * ann = get_if(&res)) { + if (not isSelfAddress(ann->addr)) + getPeer(ann->addr); + } + if (holds_alternative(res)) { auto & conn = get(res).conn; if (not isSelfAddress(conn.peerAddress())) @@ -693,7 +698,8 @@ void Server::Peer::updateChannel(ReplyBuilder & reply) if (!holds_alternative(identity)) return; - if (holds_alternative(connection.channel())) { + if (holds_alternative(connection.channel()) || + holds_alternative(connection.channel())) { auto req = Channel::generateRequest(tempStorage, server.self, std::get(identity)); connection.channel().emplace>(req); diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp index f001d6c..40aeb47 100644 --- a/src/network/protocol.cpp +++ b/src/network/protocol.cpp @@ -23,7 +23,7 @@ struct NetworkProtocol::ConnectionPriv { Connection::Id id() const; - bool send(const PartialStorage &, const Header &, + bool send(const PartialStorage &, Header, const vector &, bool secure); NetworkProtocol * protocol; @@ -76,23 +76,28 @@ NetworkProtocol::PollResult NetworkProtocol::poll() if (!recvfrom(buffer, addr)) return ProtocolClosed {}; - scoped_lock lock(protocolMutex); - for (const auto & c : connections) { - if (memcmp(&c->peerAddress, &addr, sizeof addr) == 0) { - scoped_lock clock(c->cmutex); - buffer.swap(c->buffer); - return ConnectionReadReady { c->id() }; + { + scoped_lock lock(protocolMutex); + + for (const auto & c : connections) { + if (memcmp(&c->peerAddress, &addr, sizeof addr) == 0) { + scoped_lock clock(c->cmutex); + buffer.swap(c->buffer); + return ConnectionReadReady { c->id() }; + } } - } - auto conn = unique_ptr(new ConnectionPriv { - .protocol = this, - .peerAddress = addr, - }); + auto pst = self->ref()->storage().deriveEphemeralStorage(); + if (auto header = Connection::receive(buffer, nullptr, pst)) { + if (auto conn = verifyNewConnection(*header, addr)) + return NewConnection { move(*conn) }; - connections.push_back(conn.get()); - buffer.swap(conn->buffer); - return NewConnection { Connection(move(conn)) }; + if (auto ann = header->lookupFirst()) + return ReceivedAnnounce { addr, ann->value }; + } + } + + return poll(); } NetworkProtocol::Connection NetworkProtocol::connect(sockaddr_in6 addr) @@ -107,10 +112,10 @@ NetworkProtocol::Connection NetworkProtocol::connect(sockaddr_in6 addr) connections.push_back(conn.get()); vector header { - Header::AnnounceSelf { self->ref()->digest() }, + Header::Initiation { Digest(array {}) }, Header::Version { defaultVersion }, }; - conn->send(self->ref()->storage(), header, {}, false); + conn->send(self->ref()->storage(), move(header), {}, false); } return Connection(move(conn)); @@ -179,6 +184,70 @@ void NetworkProtocol::sendto(const vector & buffer, variant addr) +{ + auto bytes = Header({ + Header::CookieSet { generateCookie(addr) }, + Header::AnnounceSelf { self->ref()->digest() }, + Header::Version { defaultVersion }, + }).toObject(self->ref()->storage()).encode(); + + sendto(bytes, addr); +} + +optional NetworkProtocol::verifyNewConnection(const Header & header, sockaddr_in6 addr) +{ + optional version; + for (const auto & h : header.items) { + if (const auto * ptr = get_if(&h)) { + if (ptr->value == defaultVersion) { + version = ptr->value; + break; + } + } + } + if (!version) + return nullopt; + + if (header.lookupFirst()) { + sendCookie(addr); + } + + else if (auto cookie = header.lookupFirst()) { + if (verifyCookie(addr, cookie->value)) { + auto conn = unique_ptr(new ConnectionPriv { + .protocol = this, + .peerAddress = addr, + }); + + connections.push_back(conn.get()); + buffer.swap(conn->buffer); + return Connection(move(conn)); + } + } + + return nullopt; +} + +NetworkProtocol::Cookie NetworkProtocol::generateCookie(variant vaddr) const +{ + vector cookie; + visit([&](auto && addr) { + cookie.resize(sizeof addr); + memcpy(cookie.data(), &addr, sizeof addr); + }, vaddr); + return Cookie { cookie }; +} + +bool NetworkProtocol::verifyCookie(variant vaddr, const NetworkProtocol::Cookie & cookie) const +{ + return visit([&](auto && addr) { + if (cookie.value.size() != sizeof addr) + return false; + return memcmp(cookie.value.data(), &addr, sizeof addr) == 0; + }, vaddr); +} + /******************************************************************************/ /* Connection */ /******************************************************************************/ @@ -222,53 +291,76 @@ const sockaddr_in6 & NetworkProtocol::Connection::peerAddress() const optional NetworkProtocol::Connection::receive(const PartialStorage & partStorage) { - vector buf, decrypted; - auto plainBegin = buf.cbegin(); - auto plainEnd = buf.cbegin(); + vector buf; + + Channel * channel = nullptr; + unique_ptr channelPtr; { scoped_lock lock(p->cmutex); if (p->buffer.empty()) return nullopt; - buf.swap(p->buffer); - if ((buf[0] & 0xE0) == 0x80) { - Channel * channel = nullptr; - unique_ptr channelPtr; + if (holds_alternative>(p->channel)) { + channel = std::get>(p->channel).get(); + } else if (holds_alternative>(p->channel)) { + channelPtr = std::get>(p->channel)->data->channel(); + channel = channelPtr.get(); + } + } + + if (auto header = receive(buf, channel, partStorage)) { + scoped_lock lock(p->cmutex); - if (holds_alternative>(p->channel)) { - channel = std::get>(p->channel).get(); - } else if (holds_alternative>(p->channel)) { - channelPtr = std::get>(p->channel)->data->channel(); - channel = channelPtr.get(); - } + if (header->lookupFirst()) { + p->protocol->sendCookie(p->peerAddress); + return nullopt; + } - if (not channel) { - std::cerr << "unexpected encrypted packet\n"; - return nullopt; - } + if (holds_alternative(p->channel) || + holds_alternative(p->channel)) + if (const auto * cookie = header->lookupFirst()) + p->channel = cookie->value; - if (auto dec = channel->decrypt(buf.begin() + 1, buf.end(), decrypted, 0)) { - if (decrypted.empty()) { - std::cerr << "empty decrypted content\n"; - } - else if (decrypted[0] == 0x00) { - plainBegin = decrypted.begin() + 1; - plainEnd = decrypted.end(); - } - else { - std::cerr << "streams not implemented\n"; - return nullopt; - } - } + return header; + } + return nullopt; +} + +optional NetworkProtocol::Connection::receive(vector & buf, + Channel * channel, + const PartialStorage & partStorage) +{ + vector decrypted; + auto plainBegin = buf.cbegin(); + auto plainEnd = buf.cbegin(); + + if ((buf[0] & 0xE0) == 0x80) { + if (not channel) { + std::cerr << "unexpected encrypted packet\n"; + return nullopt; } - else if ((buf[0] & 0xE0) == 0x60) { - plainBegin = buf.begin(); - plainEnd = buf.end(); + + if (auto dec = channel->decrypt(buf.begin() + 1, buf.end(), decrypted, 0)) { + if (decrypted.empty()) { + std::cerr << "empty decrypted content\n"; + } + else if (decrypted[0] == 0x00) { + plainBegin = decrypted.begin() + 1; + plainEnd = decrypted.end(); + } + else { + std::cerr << "streams not implemented\n"; + return nullopt; + } } } + else if ((buf[0] & 0xE0) == 0x60) { + plainBegin = buf.begin(); + plainEnd = buf.end(); + } if (auto dec = PartialObject::decodePrefix(partStorage, plainBegin, plainEnd)) { if (auto header = Header::load(std::get(*dec))) { @@ -287,14 +379,14 @@ optional NetworkProtocol::Connection::receive(const Par } bool NetworkProtocol::Connection::send(const PartialStorage & partStorage, - const Header & header, + Header header, const vector & objs, bool secure) { - return p->send(partStorage, header, objs, secure); + return p->send(partStorage, move(header), objs, secure); } bool NetworkProtocol::ConnectionPriv::send(const PartialStorage & partStorage, - const Header & header, + Header header, const vector & objs, bool secure) { vector data, part, out; @@ -308,6 +400,10 @@ bool NetworkProtocol::ConnectionPriv::send(const PartialStorage & partStorage, if (channel || secure) data.push_back(0x00); + else if (const auto * ptr = get_if(&this->channel)) { + header.items.push_back(Header::CookieEcho { ptr->value }); + header.items.push_back(Header::Version { defaultVersion }); + } part = header.toObject(partStorage).encode(); data.insert(data.end(), part.begin(), part.end()); @@ -414,6 +510,15 @@ optional NetworkProtocol::Header::load(const PartialObj } else if (item.name == "VER") { if (auto ver = item.asText()) items.emplace_back(Version { *ver }); + } else if (item.name == "INI") { + if (auto ref = item.asRef()) + items.emplace_back(Initiation { ref->digest() }); + } else if (item.name == "CKS") { + if (auto cookie = item.asBinary()) + items.emplace_back(CookieSet { *cookie }); + } else if (item.name == "CKE") { + if (auto cookie = item.asBinary()) + items.emplace_back(CookieEcho { *cookie }); } else if (item.name == "REQ") { if (auto ref = item.asRef()) items.emplace_back(DataRequest { ref->digest() }); @@ -455,6 +560,15 @@ PartialObject NetworkProtocol::Header::toObject(const PartialStorage & st) const else if (const auto * ptr = get_if(&item)) ritems.emplace_back("VER", ptr->value); + else if (const auto * ptr = get_if(&item)) + ritems.emplace_back("INI", st.ref(ptr->value)); + + else if (const auto * ptr = get_if(&item)) + ritems.emplace_back("CKS", ptr->value.value); + + else if (const auto * ptr = get_if(&item)) + ritems.emplace_back("CKE", ptr->value.value); + else if (const auto * ptr = get_if(&item)) ritems.emplace_back("REQ", st.ref(ptr->value)); diff --git a/src/network/protocol.h b/src/network/protocol.h index 545585e..dda2ffb 100644 --- a/src/network/protocol.h +++ b/src/network/protocol.h @@ -38,18 +38,23 @@ public: struct Header; + struct ReceivedAnnounce; struct NewConnection; struct ConnectionReadReady; struct ProtocolClosed {}; using PollResult = variant< + ReceivedAnnounce, NewConnection, ConnectionReadReady, ProtocolClosed>; PollResult poll(); + struct Cookie { vector value; }; + using ChannelState = variant, shared_ptr, Stored, @@ -66,6 +71,12 @@ private: bool recvfrom(vector & buffer, sockaddr_in6 & addr); void sendto(const vector & buffer, variant addr); + void sendCookie(variant addr); + optional verifyNewConnection(const Header & header, sockaddr_in6 addr); + + Cookie generateCookie(variant addr) const; + bool verifyCookie(variant addr, const Cookie & cookie) const; + int sock; mutex protocolMutex; @@ -94,7 +105,7 @@ public: const sockaddr_in6 & peerAddress() const; optional
receive(const PartialStorage &); - bool send(const PartialStorage &, const NetworkProtocol::Header &, + bool send(const PartialStorage &, NetworkProtocol::Header, const vector &, bool secure); void close(); @@ -104,9 +115,14 @@ public: void trySendOutQueue(); private: + static optional
receive(vector & buf, + Channel * channel, + const PartialStorage & st); + unique_ptr p; }; +struct NetworkProtocol::ReceivedAnnounce { sockaddr_in6 addr; Digest digest; }; struct NetworkProtocol::NewConnection { Connection conn; }; struct NetworkProtocol::ConnectionReadReady { Connection::Id id; }; @@ -114,6 +130,9 @@ struct NetworkProtocol::Header { struct Acknowledged { Digest value; }; struct Version { string value; }; + struct Initiation { Digest value; }; + struct CookieSet { Cookie value; }; + struct CookieEcho { Cookie value; }; struct DataRequest { Digest value; }; struct DataResponse { Digest value; }; struct AnnounceSelf { Digest value; }; @@ -126,6 +145,9 @@ struct NetworkProtocol::Header using Item = variant< Acknowledged, Version, + Initiation, + CookieSet, + CookieEcho, DataRequest, DataResponse, AnnounceSelf, @@ -140,14 +162,28 @@ struct NetworkProtocol::Header static optional
load(const PartialObject &); PartialObject toObject(const PartialStorage &) const; - const vector items; + template const T * lookupFirst() const; + + vector items; }; +template +const T * NetworkProtocol::Header::lookupFirst() const +{ + for (const auto & h : items) + if (auto ptr = std::get_if(&h)) + return ptr; + return nullptr; +} + bool operator==(const NetworkProtocol::Header::Item &, const NetworkProtocol::Header::Item &); inline bool operator!=(const NetworkProtocol::Header::Item & left, const NetworkProtocol::Header::Item & right) { return not (left == right); } +inline bool operator==(const NetworkProtocol::Cookie & left, const NetworkProtocol::Cookie & right) +{ return left.value == right.value; } + class ReplyBuilder { public: -- cgit v1.2.3