diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/network.cpp | 3 | ||||
| -rw-r--r-- | src/network/protocol.cpp | 58 | ||||
| -rw-r--r-- | src/network/protocol.h | 7 | 
3 files changed, 46 insertions, 22 deletions
| diff --git a/src/network.cpp b/src/network.cpp index da480c3..6840f43 100644 --- a/src/network.cpp +++ b/src/network.cpp @@ -698,8 +698,7 @@ void Server::Peer::updateChannel(ReplyBuilder & reply)  	if (!holds_alternative<Identity>(identity))  		return; -	if (holds_alternative<monostate>(connection.channel()) || -			holds_alternative<NetworkProtocol::Cookie>(connection.channel())) { +	if (holds_alternative<monostate>(connection.channel())) {  		auto req = Channel::generateRequest(tempStorage,  				server.self, std::get<Identity>(identity));  		connection.channel().emplace<Stored<ChannelRequest>>(req); diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp index 79c023d..93d171a 100644 --- a/src/network/protocol.cpp +++ b/src/network/protocol.cpp @@ -32,6 +32,8 @@ struct NetworkProtocol::ConnectionPriv  	mutex cmutex {};  	vector<uint8_t> buffer {}; +	optional<Cookie> receivedCookie = nullopt; +	bool confirmedCookie = false;  	ChannelState channel = monostate();  	vector<vector<uint8_t>> secureOutQueue {};  }; @@ -88,7 +90,8 @@ NetworkProtocol::PollResult NetworkProtocol::poll()  		}  		auto pst = self->ref()->storage().deriveEphemeralStorage(); -		if (auto header = Connection::receive(buffer, nullptr, pst)) { +		bool secure = false; +		if (auto header = Connection::parsePacket(buffer, nullptr, pst, secure)) {  			if (auto conn = verifyNewConnection(*header, addr))  				return NewConnection { move(*conn) }; @@ -113,6 +116,7 @@ NetworkProtocol::Connection NetworkProtocol::connect(sockaddr_in6 addr)  		vector<Header::Item> header {  			Header::Initiation { Digest(array<uint8_t, Digest::size> {}) }, +			Header::AnnounceSelf { self->ref()->digest() },  			Header::Version { defaultVersion },  		};  		conn->send(self->ref()->storage(), move(header), {}, false); @@ -311,32 +315,50 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const Par  		}  	} -	if (auto header = receive(buf, channel, partStorage)) { +	bool secure = false; +	if (auto header = parsePacket(buf, channel, partStorage, secure)) {  		scoped_lock lock(p->cmutex); +		if (secure) +			return header; + +		if (const auto * cookieEcho = header->lookupFirst<Header::CookieEcho>()) { +			if (!p->protocol->verifyCookie(p->peerAddress, cookieEcho->value)) +				return nullopt; + +			p->confirmedCookie = true; + +			if (const auto * cookieSet = header->lookupFirst<Header::CookieSet>()) +				p->receivedCookie = cookieSet->value; + +			return header; +		} + +		if (holds_alternative<monostate>(p->channel)) { +			if (const auto * cookieSet = header->lookupFirst<Header::CookieSet>()) { +				p->receivedCookie = cookieSet->value; +				return header; +			} +		} +  		if (header->lookupFirst<Header::Initiation>()) {  			p->protocol->sendCookie(p->peerAddress);  			return nullopt;  		} - -		if (holds_alternative<monostate>(p->channel) || -				holds_alternative<Cookie>(p->channel)) -			if (const auto * cookie = header->lookupFirst<Header::CookieSet>()) -				p->channel = cookie->value; - -		return header;  	}  	return nullopt;  } -optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(vector<uint8_t> & buf, -		Channel * channel, -		const PartialStorage & partStorage) +optional<NetworkProtocol::Header> NetworkProtocol::Connection::parsePacket(vector<uint8_t> & buf, +		Channel * channel, const PartialStorage & partStorage, +		bool & secure)  {  	vector<uint8_t> decrypted;  	auto plainBegin = buf.cbegin();  	auto plainEnd = buf.cbegin(); +	secure = false; +  	if ((buf[0] & 0xE0) == 0x80) {  		if (not channel) {  			std::cerr << "unexpected encrypted packet\n"; @@ -356,6 +378,8 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(vector<ui  				return nullopt;  			}  		} + +		secure = true;  	}  	else if ((buf[0] & 0xE0) == 0x60) {  		plainBegin = buf.begin(); @@ -398,11 +422,13 @@ bool NetworkProtocol::ConnectionPriv::send(const PartialStorage & partStorage,  		if (auto uptr = get_if<unique_ptr<Channel>>(&this->channel))  			channel = uptr->get(); -		if (channel || secure) +		if (channel || secure) {  			data.push_back(0x00); -		else if (const auto * ptr = get_if<Cookie>(&this->channel)) { -			header.items.push_back(Header::CookieEcho { ptr->value }); -			header.items.push_back(Header::Version { defaultVersion }); +		} else { +			if (receivedCookie) +				header.items.push_back(Header::CookieEcho { receivedCookie->value }); +			if (!confirmedCookie) +				header.items.push_back(Header::CookieSet { protocol->generateCookie(peerAddress) });  		}  		part = header.toObject(partStorage).encode(); diff --git a/src/network/protocol.h b/src/network/protocol.h index dda2ffb..3d7c073 100644 --- a/src/network/protocol.h +++ b/src/network/protocol.h @@ -54,7 +54,6 @@ public:  	struct Cookie { vector<uint8_t> value; };  	using ChannelState = variant<monostate, -		Cookie,  		Stored<ChannelRequest>,  		shared_ptr<struct WaitingRef>,  		Stored<ChannelAccept>, @@ -115,9 +114,9 @@ public:  	void trySendOutQueue();  private: -	static optional<Header> receive(vector<uint8_t> & buf, -			Channel * channel, -			const PartialStorage & st); +	static optional<Header> parsePacket(vector<uint8_t> & buf, +			Channel * channel, const PartialStorage & st, +			bool & secure);  	unique_ptr<ConnectionPriv> p;  }; |