diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/network.cpp | 3 | ||||
-rw-r--r-- | src/network/protocol.cpp | 58 | ||||
-rw-r--r-- | src/network/protocol.h | 7 |
3 files changed, 46 insertions, 22 deletions
diff --git a/src/network.cpp b/src/network.cpp index da480c3..6840f43 100644 --- a/src/network.cpp +++ b/src/network.cpp @@ -698,8 +698,7 @@ void Server::Peer::updateChannel(ReplyBuilder & reply) if (!holds_alternative<Identity>(identity)) return; - if (holds_alternative<monostate>(connection.channel()) || - holds_alternative<NetworkProtocol::Cookie>(connection.channel())) { + if (holds_alternative<monostate>(connection.channel())) { auto req = Channel::generateRequest(tempStorage, server.self, std::get<Identity>(identity)); connection.channel().emplace<Stored<ChannelRequest>>(req); diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp index 79c023d..93d171a 100644 --- a/src/network/protocol.cpp +++ b/src/network/protocol.cpp @@ -32,6 +32,8 @@ struct NetworkProtocol::ConnectionPriv mutex cmutex {}; vector<uint8_t> buffer {}; + optional<Cookie> receivedCookie = nullopt; + bool confirmedCookie = false; ChannelState channel = monostate(); vector<vector<uint8_t>> secureOutQueue {}; }; @@ -88,7 +90,8 @@ NetworkProtocol::PollResult NetworkProtocol::poll() } auto pst = self->ref()->storage().deriveEphemeralStorage(); - if (auto header = Connection::receive(buffer, nullptr, pst)) { + bool secure = false; + if (auto header = Connection::parsePacket(buffer, nullptr, pst, secure)) { if (auto conn = verifyNewConnection(*header, addr)) return NewConnection { move(*conn) }; @@ -113,6 +116,7 @@ NetworkProtocol::Connection NetworkProtocol::connect(sockaddr_in6 addr) vector<Header::Item> header { Header::Initiation { Digest(array<uint8_t, Digest::size> {}) }, + Header::AnnounceSelf { self->ref()->digest() }, Header::Version { defaultVersion }, }; conn->send(self->ref()->storage(), move(header), {}, false); @@ -311,32 +315,50 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const Par } } - if (auto header = receive(buf, channel, partStorage)) { + bool secure = false; + if (auto header = parsePacket(buf, channel, partStorage, secure)) { scoped_lock lock(p->cmutex); + if (secure) + return header; + + if (const auto * cookieEcho = header->lookupFirst<Header::CookieEcho>()) { + if (!p->protocol->verifyCookie(p->peerAddress, cookieEcho->value)) + return nullopt; + + p->confirmedCookie = true; + + if (const auto * cookieSet = header->lookupFirst<Header::CookieSet>()) + p->receivedCookie = cookieSet->value; + + return header; + } + + if (holds_alternative<monostate>(p->channel)) { + if (const auto * cookieSet = header->lookupFirst<Header::CookieSet>()) { + p->receivedCookie = cookieSet->value; + return header; + } + } + if (header->lookupFirst<Header::Initiation>()) { p->protocol->sendCookie(p->peerAddress); return nullopt; } - - if (holds_alternative<monostate>(p->channel) || - holds_alternative<Cookie>(p->channel)) - if (const auto * cookie = header->lookupFirst<Header::CookieSet>()) - p->channel = cookie->value; - - return header; } return nullopt; } -optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(vector<uint8_t> & buf, - Channel * channel, - const PartialStorage & partStorage) +optional<NetworkProtocol::Header> NetworkProtocol::Connection::parsePacket(vector<uint8_t> & buf, + Channel * channel, const PartialStorage & partStorage, + bool & secure) { vector<uint8_t> decrypted; auto plainBegin = buf.cbegin(); auto plainEnd = buf.cbegin(); + secure = false; + if ((buf[0] & 0xE0) == 0x80) { if (not channel) { std::cerr << "unexpected encrypted packet\n"; @@ -356,6 +378,8 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(vector<ui return nullopt; } } + + secure = true; } else if ((buf[0] & 0xE0) == 0x60) { plainBegin = buf.begin(); @@ -398,11 +422,13 @@ bool NetworkProtocol::ConnectionPriv::send(const PartialStorage & partStorage, if (auto uptr = get_if<unique_ptr<Channel>>(&this->channel)) channel = uptr->get(); - if (channel || secure) + if (channel || secure) { data.push_back(0x00); - else if (const auto * ptr = get_if<Cookie>(&this->channel)) { - header.items.push_back(Header::CookieEcho { ptr->value }); - header.items.push_back(Header::Version { defaultVersion }); + } else { + if (receivedCookie) + header.items.push_back(Header::CookieEcho { receivedCookie->value }); + if (!confirmedCookie) + header.items.push_back(Header::CookieSet { protocol->generateCookie(peerAddress) }); } part = header.toObject(partStorage).encode(); diff --git a/src/network/protocol.h b/src/network/protocol.h index dda2ffb..3d7c073 100644 --- a/src/network/protocol.h +++ b/src/network/protocol.h @@ -54,7 +54,6 @@ public: struct Cookie { vector<uint8_t> value; }; using ChannelState = variant<monostate, - Cookie, Stored<ChannelRequest>, shared_ptr<struct WaitingRef>, Stored<ChannelAccept>, @@ -115,9 +114,9 @@ public: void trySendOutQueue(); private: - static optional<Header> receive(vector<uint8_t> & buf, - Channel * channel, - const PartialStorage & st); + static optional<Header> parsePacket(vector<uint8_t> & buf, + Channel * channel, const PartialStorage & st, + bool & secure); unique_ptr<ConnectionPriv> p; }; |