From 512e20fa063e4a4525e47e048f26cc68668e7fac Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Roman=20Smr=C5=BE?= <roman.smrz@seznam.cz>
Date: Sat, 16 Sep 2023 11:33:40 +0200
Subject: Protocol: use cookies during whole plaintext phase

---
 src/network.cpp          |  3 +--
 src/network/protocol.cpp | 58 +++++++++++++++++++++++++++++++++++-------------
 src/network/protocol.h   |  7 +++---
 3 files changed, 46 insertions(+), 22 deletions(-)

diff --git a/src/network.cpp b/src/network.cpp
index da480c3..6840f43 100644
--- a/src/network.cpp
+++ b/src/network.cpp
@@ -698,8 +698,7 @@ void Server::Peer::updateChannel(ReplyBuilder & reply)
 	if (!holds_alternative<Identity>(identity))
 		return;
 
-	if (holds_alternative<monostate>(connection.channel()) ||
-			holds_alternative<NetworkProtocol::Cookie>(connection.channel())) {
+	if (holds_alternative<monostate>(connection.channel())) {
 		auto req = Channel::generateRequest(tempStorage,
 				server.self, std::get<Identity>(identity));
 		connection.channel().emplace<Stored<ChannelRequest>>(req);
diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp
index 79c023d..93d171a 100644
--- a/src/network/protocol.cpp
+++ b/src/network/protocol.cpp
@@ -32,6 +32,8 @@ struct NetworkProtocol::ConnectionPriv
 	mutex cmutex {};
 	vector<uint8_t> buffer {};
 
+	optional<Cookie> receivedCookie = nullopt;
+	bool confirmedCookie = false;
 	ChannelState channel = monostate();
 	vector<vector<uint8_t>> secureOutQueue {};
 };
@@ -88,7 +90,8 @@ NetworkProtocol::PollResult NetworkProtocol::poll()
 		}
 
 		auto pst = self->ref()->storage().deriveEphemeralStorage();
-		if (auto header = Connection::receive(buffer, nullptr, pst)) {
+		bool secure = false;
+		if (auto header = Connection::parsePacket(buffer, nullptr, pst, secure)) {
 			if (auto conn = verifyNewConnection(*header, addr))
 				return NewConnection { move(*conn) };
 
@@ -113,6 +116,7 @@ NetworkProtocol::Connection NetworkProtocol::connect(sockaddr_in6 addr)
 
 		vector<Header::Item> header {
 			Header::Initiation { Digest(array<uint8_t, Digest::size> {}) },
+			Header::AnnounceSelf { self->ref()->digest() },
 			Header::Version { defaultVersion },
 		};
 		conn->send(self->ref()->storage(), move(header), {}, false);
@@ -311,32 +315,50 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const Par
 		}
 	}
 
-	if (auto header = receive(buf, channel, partStorage)) {
+	bool secure = false;
+	if (auto header = parsePacket(buf, channel, partStorage, secure)) {
 		scoped_lock lock(p->cmutex);
 
+		if (secure)
+			return header;
+
+		if (const auto * cookieEcho = header->lookupFirst<Header::CookieEcho>()) {
+			if (!p->protocol->verifyCookie(p->peerAddress, cookieEcho->value))
+				return nullopt;
+
+			p->confirmedCookie = true;
+
+			if (const auto * cookieSet = header->lookupFirst<Header::CookieSet>())
+				p->receivedCookie = cookieSet->value;
+
+			return header;
+		}
+
+		if (holds_alternative<monostate>(p->channel)) {
+			if (const auto * cookieSet = header->lookupFirst<Header::CookieSet>()) {
+				p->receivedCookie = cookieSet->value;
+				return header;
+			}
+		}
+
 		if (header->lookupFirst<Header::Initiation>()) {
 			p->protocol->sendCookie(p->peerAddress);
 			return nullopt;
 		}
-
-		if (holds_alternative<monostate>(p->channel) ||
-				holds_alternative<Cookie>(p->channel))
-			if (const auto * cookie = header->lookupFirst<Header::CookieSet>())
-				p->channel = cookie->value;
-
-		return header;
 	}
 	return nullopt;
 }
 
-optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(vector<uint8_t> & buf,
-		Channel * channel,
-		const PartialStorage & partStorage)
+optional<NetworkProtocol::Header> NetworkProtocol::Connection::parsePacket(vector<uint8_t> & buf,
+		Channel * channel, const PartialStorage & partStorage,
+		bool & secure)
 {
 	vector<uint8_t> decrypted;
 	auto plainBegin = buf.cbegin();
 	auto plainEnd = buf.cbegin();
 
+	secure = false;
+
 	if ((buf[0] & 0xE0) == 0x80) {
 		if (not channel) {
 			std::cerr << "unexpected encrypted packet\n";
@@ -356,6 +378,8 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(vector<ui
 				return nullopt;
 			}
 		}
+
+		secure = true;
 	}
 	else if ((buf[0] & 0xE0) == 0x60) {
 		plainBegin = buf.begin();
@@ -398,11 +422,13 @@ bool NetworkProtocol::ConnectionPriv::send(const PartialStorage & partStorage,
 		if (auto uptr = get_if<unique_ptr<Channel>>(&this->channel))
 			channel = uptr->get();
 
-		if (channel || secure)
+		if (channel || secure) {
 			data.push_back(0x00);
-		else if (const auto * ptr = get_if<Cookie>(&this->channel)) {
-			header.items.push_back(Header::CookieEcho { ptr->value });
-			header.items.push_back(Header::Version { defaultVersion });
+		} else {
+			if (receivedCookie)
+				header.items.push_back(Header::CookieEcho { receivedCookie->value });
+			if (!confirmedCookie)
+				header.items.push_back(Header::CookieSet { protocol->generateCookie(peerAddress) });
 		}
 
 		part = header.toObject(partStorage).encode();
diff --git a/src/network/protocol.h b/src/network/protocol.h
index dda2ffb..3d7c073 100644
--- a/src/network/protocol.h
+++ b/src/network/protocol.h
@@ -54,7 +54,6 @@ public:
 	struct Cookie { vector<uint8_t> value; };
 
 	using ChannelState = variant<monostate,
-		Cookie,
 		Stored<ChannelRequest>,
 		shared_ptr<struct WaitingRef>,
 		Stored<ChannelAccept>,
@@ -115,9 +114,9 @@ public:
 	void trySendOutQueue();
 
 private:
-	static optional<Header> receive(vector<uint8_t> & buf,
-			Channel * channel,
-			const PartialStorage & st);
+	static optional<Header> parsePacket(vector<uint8_t> & buf,
+			Channel * channel, const PartialStorage & st,
+			bool & secure);
 
 	unique_ptr<ConnectionPriv> p;
 };
-- 
cgit v1.2.3