From d0c48bf9b90dfbd55908a88a5aba411ca9b8e600 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roman=20Smr=C5=BE?= Date: Sun, 27 Aug 2023 21:52:29 +0200 Subject: Network: connection initiation with cookie --- src/network/protocol.cpp | 220 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 167 insertions(+), 53 deletions(-) (limited to 'src/network/protocol.cpp') diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp index f001d6c..40aeb47 100644 --- a/src/network/protocol.cpp +++ b/src/network/protocol.cpp @@ -23,7 +23,7 @@ struct NetworkProtocol::ConnectionPriv { Connection::Id id() const; - bool send(const PartialStorage &, const Header &, + bool send(const PartialStorage &, Header, const vector &, bool secure); NetworkProtocol * protocol; @@ -76,23 +76,28 @@ NetworkProtocol::PollResult NetworkProtocol::poll() if (!recvfrom(buffer, addr)) return ProtocolClosed {}; - scoped_lock lock(protocolMutex); - for (const auto & c : connections) { - if (memcmp(&c->peerAddress, &addr, sizeof addr) == 0) { - scoped_lock clock(c->cmutex); - buffer.swap(c->buffer); - return ConnectionReadReady { c->id() }; + { + scoped_lock lock(protocolMutex); + + for (const auto & c : connections) { + if (memcmp(&c->peerAddress, &addr, sizeof addr) == 0) { + scoped_lock clock(c->cmutex); + buffer.swap(c->buffer); + return ConnectionReadReady { c->id() }; + } } - } - auto conn = unique_ptr(new ConnectionPriv { - .protocol = this, - .peerAddress = addr, - }); + auto pst = self->ref()->storage().deriveEphemeralStorage(); + if (auto header = Connection::receive(buffer, nullptr, pst)) { + if (auto conn = verifyNewConnection(*header, addr)) + return NewConnection { move(*conn) }; - connections.push_back(conn.get()); - buffer.swap(conn->buffer); - return NewConnection { Connection(move(conn)) }; + if (auto ann = header->lookupFirst()) + return ReceivedAnnounce { addr, ann->value }; + } + } + + return poll(); } NetworkProtocol::Connection NetworkProtocol::connect(sockaddr_in6 addr) @@ -107,10 +112,10 @@ NetworkProtocol::Connection NetworkProtocol::connect(sockaddr_in6 addr) connections.push_back(conn.get()); vector header { - Header::AnnounceSelf { self->ref()->digest() }, + Header::Initiation { Digest(array {}) }, Header::Version { defaultVersion }, }; - conn->send(self->ref()->storage(), header, {}, false); + conn->send(self->ref()->storage(), move(header), {}, false); } return Connection(move(conn)); @@ -179,6 +184,70 @@ void NetworkProtocol::sendto(const vector & buffer, variant addr) +{ + auto bytes = Header({ + Header::CookieSet { generateCookie(addr) }, + Header::AnnounceSelf { self->ref()->digest() }, + Header::Version { defaultVersion }, + }).toObject(self->ref()->storage()).encode(); + + sendto(bytes, addr); +} + +optional NetworkProtocol::verifyNewConnection(const Header & header, sockaddr_in6 addr) +{ + optional version; + for (const auto & h : header.items) { + if (const auto * ptr = get_if(&h)) { + if (ptr->value == defaultVersion) { + version = ptr->value; + break; + } + } + } + if (!version) + return nullopt; + + if (header.lookupFirst()) { + sendCookie(addr); + } + + else if (auto cookie = header.lookupFirst()) { + if (verifyCookie(addr, cookie->value)) { + auto conn = unique_ptr(new ConnectionPriv { + .protocol = this, + .peerAddress = addr, + }); + + connections.push_back(conn.get()); + buffer.swap(conn->buffer); + return Connection(move(conn)); + } + } + + return nullopt; +} + +NetworkProtocol::Cookie NetworkProtocol::generateCookie(variant vaddr) const +{ + vector cookie; + visit([&](auto && addr) { + cookie.resize(sizeof addr); + memcpy(cookie.data(), &addr, sizeof addr); + }, vaddr); + return Cookie { cookie }; +} + +bool NetworkProtocol::verifyCookie(variant vaddr, const NetworkProtocol::Cookie & cookie) const +{ + return visit([&](auto && addr) { + if (cookie.value.size() != sizeof addr) + return false; + return memcmp(cookie.value.data(), &addr, sizeof addr) == 0; + }, vaddr); +} + /******************************************************************************/ /* Connection */ /******************************************************************************/ @@ -222,53 +291,76 @@ const sockaddr_in6 & NetworkProtocol::Connection::peerAddress() const optional NetworkProtocol::Connection::receive(const PartialStorage & partStorage) { - vector buf, decrypted; - auto plainBegin = buf.cbegin(); - auto plainEnd = buf.cbegin(); + vector buf; + + Channel * channel = nullptr; + unique_ptr channelPtr; { scoped_lock lock(p->cmutex); if (p->buffer.empty()) return nullopt; - buf.swap(p->buffer); - if ((buf[0] & 0xE0) == 0x80) { - Channel * channel = nullptr; - unique_ptr channelPtr; + if (holds_alternative>(p->channel)) { + channel = std::get>(p->channel).get(); + } else if (holds_alternative>(p->channel)) { + channelPtr = std::get>(p->channel)->data->channel(); + channel = channelPtr.get(); + } + } + + if (auto header = receive(buf, channel, partStorage)) { + scoped_lock lock(p->cmutex); - if (holds_alternative>(p->channel)) { - channel = std::get>(p->channel).get(); - } else if (holds_alternative>(p->channel)) { - channelPtr = std::get>(p->channel)->data->channel(); - channel = channelPtr.get(); - } + if (header->lookupFirst()) { + p->protocol->sendCookie(p->peerAddress); + return nullopt; + } - if (not channel) { - std::cerr << "unexpected encrypted packet\n"; - return nullopt; - } + if (holds_alternative(p->channel) || + holds_alternative(p->channel)) + if (const auto * cookie = header->lookupFirst()) + p->channel = cookie->value; - if (auto dec = channel->decrypt(buf.begin() + 1, buf.end(), decrypted, 0)) { - if (decrypted.empty()) { - std::cerr << "empty decrypted content\n"; - } - else if (decrypted[0] == 0x00) { - plainBegin = decrypted.begin() + 1; - plainEnd = decrypted.end(); - } - else { - std::cerr << "streams not implemented\n"; - return nullopt; - } - } + return header; + } + return nullopt; +} + +optional NetworkProtocol::Connection::receive(vector & buf, + Channel * channel, + const PartialStorage & partStorage) +{ + vector decrypted; + auto plainBegin = buf.cbegin(); + auto plainEnd = buf.cbegin(); + + if ((buf[0] & 0xE0) == 0x80) { + if (not channel) { + std::cerr << "unexpected encrypted packet\n"; + return nullopt; } - else if ((buf[0] & 0xE0) == 0x60) { - plainBegin = buf.begin(); - plainEnd = buf.end(); + + if (auto dec = channel->decrypt(buf.begin() + 1, buf.end(), decrypted, 0)) { + if (decrypted.empty()) { + std::cerr << "empty decrypted content\n"; + } + else if (decrypted[0] == 0x00) { + plainBegin = decrypted.begin() + 1; + plainEnd = decrypted.end(); + } + else { + std::cerr << "streams not implemented\n"; + return nullopt; + } } } + else if ((buf[0] & 0xE0) == 0x60) { + plainBegin = buf.begin(); + plainEnd = buf.end(); + } if (auto dec = PartialObject::decodePrefix(partStorage, plainBegin, plainEnd)) { if (auto header = Header::load(std::get(*dec))) { @@ -287,14 +379,14 @@ optional NetworkProtocol::Connection::receive(const Par } bool NetworkProtocol::Connection::send(const PartialStorage & partStorage, - const Header & header, + Header header, const vector & objs, bool secure) { - return p->send(partStorage, header, objs, secure); + return p->send(partStorage, move(header), objs, secure); } bool NetworkProtocol::ConnectionPriv::send(const PartialStorage & partStorage, - const Header & header, + Header header, const vector & objs, bool secure) { vector data, part, out; @@ -308,6 +400,10 @@ bool NetworkProtocol::ConnectionPriv::send(const PartialStorage & partStorage, if (channel || secure) data.push_back(0x00); + else if (const auto * ptr = get_if(&this->channel)) { + header.items.push_back(Header::CookieEcho { ptr->value }); + header.items.push_back(Header::Version { defaultVersion }); + } part = header.toObject(partStorage).encode(); data.insert(data.end(), part.begin(), part.end()); @@ -414,6 +510,15 @@ optional NetworkProtocol::Header::load(const PartialObj } else if (item.name == "VER") { if (auto ver = item.asText()) items.emplace_back(Version { *ver }); + } else if (item.name == "INI") { + if (auto ref = item.asRef()) + items.emplace_back(Initiation { ref->digest() }); + } else if (item.name == "CKS") { + if (auto cookie = item.asBinary()) + items.emplace_back(CookieSet { *cookie }); + } else if (item.name == "CKE") { + if (auto cookie = item.asBinary()) + items.emplace_back(CookieEcho { *cookie }); } else if (item.name == "REQ") { if (auto ref = item.asRef()) items.emplace_back(DataRequest { ref->digest() }); @@ -455,6 +560,15 @@ PartialObject NetworkProtocol::Header::toObject(const PartialStorage & st) const else if (const auto * ptr = get_if(&item)) ritems.emplace_back("VER", ptr->value); + else if (const auto * ptr = get_if(&item)) + ritems.emplace_back("INI", st.ref(ptr->value)); + + else if (const auto * ptr = get_if(&item)) + ritems.emplace_back("CKS", ptr->value.value); + + else if (const auto * ptr = get_if(&item)) + ritems.emplace_back("CKE", ptr->value.value); + else if (const auto * ptr = get_if(&item)) ritems.emplace_back("REQ", st.ref(ptr->value)); -- cgit v1.2.3