diff options
| -rw-r--r-- | src/network.cpp | 10 | ||||
| -rw-r--r-- | src/network/protocol.cpp | 220 | ||||
| -rw-r--r-- | src/network/protocol.h | 40 | 
3 files changed, 213 insertions, 57 deletions
| diff --git a/src/network.cpp b/src/network.cpp index a3d1130..da480c3 100644 --- a/src/network.cpp +++ b/src/network.cpp @@ -221,7 +221,7 @@ bool Peer::send(UUID uuid, const Ref & ref, const Object & obj) const  			NetworkProtocol::Header::ServiceType { uuid },  			NetworkProtocol::Header::ServiceRef { ref.digest() },  		}); -		speer->connection.send(speer->partStorage, header, { obj }, true); +		speer->connection.send(speer->partStorage, move(header), { obj }, true);  		return true;  	} @@ -363,6 +363,11 @@ void Server::Priv::doListen()  		if (holds_alternative<NetworkProtocol::ProtocolClosed>(res))  			break; +		if (const auto * ann = get_if<NetworkProtocol::ReceivedAnnounce>(&res)) { +			if (not isSelfAddress(ann->addr)) +				getPeer(ann->addr); +		} +  		if (holds_alternative<NetworkProtocol::NewConnection>(res)) {  			auto & conn = get<NetworkProtocol::NewConnection>(res).conn;  			if (not isSelfAddress(conn.peerAddress())) @@ -693,7 +698,8 @@ void Server::Peer::updateChannel(ReplyBuilder & reply)  	if (!holds_alternative<Identity>(identity))  		return; -	if (holds_alternative<monostate>(connection.channel())) { +	if (holds_alternative<monostate>(connection.channel()) || +			holds_alternative<NetworkProtocol::Cookie>(connection.channel())) {  		auto req = Channel::generateRequest(tempStorage,  				server.self, std::get<Identity>(identity));  		connection.channel().emplace<Stored<ChannelRequest>>(req); diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp index f001d6c..40aeb47 100644 --- a/src/network/protocol.cpp +++ b/src/network/protocol.cpp @@ -23,7 +23,7 @@ struct NetworkProtocol::ConnectionPriv  {  	Connection::Id id() const; -	bool send(const PartialStorage &, const Header &, +	bool send(const PartialStorage &, Header,  			const vector<Object> &, bool secure);  	NetworkProtocol * protocol; @@ -76,23 +76,28 @@ NetworkProtocol::PollResult NetworkProtocol::poll()  	if (!recvfrom(buffer, addr))  		return ProtocolClosed {}; -	scoped_lock lock(protocolMutex); -	for (const auto & c : connections) { -		if (memcmp(&c->peerAddress, &addr, sizeof addr) == 0) { -			scoped_lock clock(c->cmutex); -			buffer.swap(c->buffer); -			return ConnectionReadReady { c->id() }; +	{ +		scoped_lock lock(protocolMutex); + +		for (const auto & c : connections) { +			if (memcmp(&c->peerAddress, &addr, sizeof addr) == 0) { +				scoped_lock clock(c->cmutex); +				buffer.swap(c->buffer); +				return ConnectionReadReady { c->id() }; +			}  		} -	} -	auto conn = unique_ptr<ConnectionPriv>(new ConnectionPriv { -		.protocol = this, -		.peerAddress = addr, -	}); +		auto pst = self->ref()->storage().deriveEphemeralStorage(); +		if (auto header = Connection::receive(buffer, nullptr, pst)) { +			if (auto conn = verifyNewConnection(*header, addr)) +				return NewConnection { move(*conn) }; -	connections.push_back(conn.get()); -	buffer.swap(conn->buffer); -	return NewConnection { Connection(move(conn)) }; +			if (auto ann = header->lookupFirst<Header::AnnounceSelf>()) +				return ReceivedAnnounce { addr, ann->value }; +		} +	} + +	return poll();  }  NetworkProtocol::Connection NetworkProtocol::connect(sockaddr_in6 addr) @@ -107,10 +112,10 @@ NetworkProtocol::Connection NetworkProtocol::connect(sockaddr_in6 addr)  		connections.push_back(conn.get());  		vector<Header::Item> header { -			Header::AnnounceSelf { self->ref()->digest() }, +			Header::Initiation { Digest(array<uint8_t, Digest::size> {}) },  			Header::Version { defaultVersion },  		}; -		conn->send(self->ref()->storage(), header, {}, false); +		conn->send(self->ref()->storage(), move(header), {}, false);  	}  	return Connection(move(conn)); @@ -179,6 +184,70 @@ void NetworkProtocol::sendto(const vector<uint8_t> & buffer, variant<sockaddr_in  	}, vaddr);  } +void NetworkProtocol::sendCookie(variant<sockaddr_in, sockaddr_in6> addr) +{ +	auto bytes = Header({ +		Header::CookieSet { generateCookie(addr) }, +		Header::AnnounceSelf { self->ref()->digest() }, +		Header::Version { defaultVersion }, +	}).toObject(self->ref()->storage()).encode(); + +	sendto(bytes, addr); +} + +optional<NetworkProtocol::Connection> NetworkProtocol::verifyNewConnection(const Header & header, sockaddr_in6 addr) +{ +	optional<string> version; +	for (const auto & h : header.items) { +		if (const auto * ptr = get_if<Header::Version>(&h)) { +			if (ptr->value == defaultVersion) { +				version = ptr->value; +				break; +			} +		} +	} +	if (!version) +		return nullopt; + +	if (header.lookupFirst<Header::Initiation>()) { +		sendCookie(addr); +	} + +	else if (auto cookie = header.lookupFirst<Header::CookieEcho>()) { +		if (verifyCookie(addr, cookie->value)) { +			auto conn = unique_ptr<ConnectionPriv>(new ConnectionPriv { +				.protocol = this, +				.peerAddress = addr, +			}); + +			connections.push_back(conn.get()); +			buffer.swap(conn->buffer); +			return Connection(move(conn)); +		} +	} + +	return nullopt; +} + +NetworkProtocol::Cookie NetworkProtocol::generateCookie(variant<sockaddr_in, sockaddr_in6> vaddr) const +{ +	vector<uint8_t> cookie; +	visit([&](auto && addr) { +		cookie.resize(sizeof addr); +		memcpy(cookie.data(), &addr, sizeof addr); +	}, vaddr); +	return Cookie { cookie }; +} + +bool NetworkProtocol::verifyCookie(variant<sockaddr_in, sockaddr_in6> vaddr, const NetworkProtocol::Cookie & cookie) const +{ +	return visit([&](auto && addr) { +		if (cookie.value.size() != sizeof addr) +			return false; +		return memcmp(cookie.value.data(), &addr, sizeof addr) == 0; +	}, vaddr); +} +  /******************************************************************************/  /* Connection                                                                 */  /******************************************************************************/ @@ -222,53 +291,76 @@ const sockaddr_in6 & NetworkProtocol::Connection::peerAddress() const  optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const PartialStorage & partStorage)  { -	vector<uint8_t> buf, decrypted; -	auto plainBegin = buf.cbegin(); -	auto plainEnd = buf.cbegin(); +	vector<uint8_t> buf; + +	Channel * channel = nullptr; +	unique_ptr<Channel> channelPtr;  	{  		scoped_lock lock(p->cmutex);  		if (p->buffer.empty())  			return nullopt; -  		buf.swap(p->buffer); -		if ((buf[0] & 0xE0) == 0x80) { -			Channel * channel = nullptr; -			unique_ptr<Channel> channelPtr; +		if (holds_alternative<unique_ptr<Channel>>(p->channel)) { +			channel = std::get<unique_ptr<Channel>>(p->channel).get(); +		} else if (holds_alternative<Stored<ChannelAccept>>(p->channel)) { +			channelPtr = std::get<Stored<ChannelAccept>>(p->channel)->data->channel(); +			channel = channelPtr.get(); +		} +	} + +	if (auto header = receive(buf, channel, partStorage)) { +		scoped_lock lock(p->cmutex); -			if (holds_alternative<unique_ptr<Channel>>(p->channel)) { -				channel = std::get<unique_ptr<Channel>>(p->channel).get(); -			} else if (holds_alternative<Stored<ChannelAccept>>(p->channel)) { -				channelPtr = std::get<Stored<ChannelAccept>>(p->channel)->data->channel(); -				channel = channelPtr.get(); -			} +		if (header->lookupFirst<Header::Initiation>()) { +			p->protocol->sendCookie(p->peerAddress); +			return nullopt; +		} -			if (not channel) { -				std::cerr << "unexpected encrypted packet\n"; -				return nullopt; -			} +		if (holds_alternative<monostate>(p->channel) || +				holds_alternative<Cookie>(p->channel)) +			if (const auto * cookie = header->lookupFirst<Header::CookieSet>()) +				p->channel = cookie->value; -			if (auto dec = channel->decrypt(buf.begin() + 1, buf.end(), decrypted, 0)) { -				if (decrypted.empty()) { -					std::cerr << "empty decrypted content\n"; -				} -				else if (decrypted[0] == 0x00) { -					plainBegin = decrypted.begin() + 1; -					plainEnd = decrypted.end(); -				} -				else { -					std::cerr << "streams not implemented\n"; -					return nullopt; -				} -			} +		return header; +	} +	return nullopt; +} + +optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(vector<uint8_t> & buf, +		Channel * channel, +		const PartialStorage & partStorage) +{ +	vector<uint8_t> decrypted; +	auto plainBegin = buf.cbegin(); +	auto plainEnd = buf.cbegin(); + +	if ((buf[0] & 0xE0) == 0x80) { +		if (not channel) { +			std::cerr << "unexpected encrypted packet\n"; +			return nullopt;  		} -		else if ((buf[0] & 0xE0) == 0x60) { -			plainBegin = buf.begin(); -			plainEnd = buf.end(); + +		if (auto dec = channel->decrypt(buf.begin() + 1, buf.end(), decrypted, 0)) { +			if (decrypted.empty()) { +				std::cerr << "empty decrypted content\n"; +			} +			else if (decrypted[0] == 0x00) { +				plainBegin = decrypted.begin() + 1; +				plainEnd = decrypted.end(); +			} +			else { +				std::cerr << "streams not implemented\n"; +				return nullopt; +			}  		}  	} +	else if ((buf[0] & 0xE0) == 0x60) { +		plainBegin = buf.begin(); +		plainEnd = buf.end(); +	}  	if (auto dec = PartialObject::decodePrefix(partStorage, plainBegin, plainEnd)) {  		if (auto header = Header::load(std::get<PartialObject>(*dec))) { @@ -287,14 +379,14 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const Par  }  bool NetworkProtocol::Connection::send(const PartialStorage & partStorage, -		const Header & header, +		Header header,  		const vector<Object> & objs, bool secure)  { -	return p->send(partStorage, header, objs, secure); +	return p->send(partStorage, move(header), objs, secure);  }  bool NetworkProtocol::ConnectionPriv::send(const PartialStorage & partStorage, -		const Header & header, +		Header header,  		const vector<Object> & objs, bool secure)  {  	vector<uint8_t> data, part, out; @@ -308,6 +400,10 @@ bool NetworkProtocol::ConnectionPriv::send(const PartialStorage & partStorage,  		if (channel || secure)  			data.push_back(0x00); +		else if (const auto * ptr = get_if<Cookie>(&this->channel)) { +			header.items.push_back(Header::CookieEcho { ptr->value }); +			header.items.push_back(Header::Version { defaultVersion }); +		}  		part = header.toObject(partStorage).encode();  		data.insert(data.end(), part.begin(), part.end()); @@ -414,6 +510,15 @@ optional<NetworkProtocol::Header> NetworkProtocol::Header::load(const PartialObj  		} else if (item.name == "VER") {  			if (auto ver = item.asText())  				items.emplace_back(Version { *ver }); +		} else if (item.name == "INI") { +			if (auto ref = item.asRef()) +				items.emplace_back(Initiation { ref->digest() }); +		} else if (item.name == "CKS") { +			if (auto cookie = item.asBinary()) +				items.emplace_back(CookieSet { *cookie }); +		} else if (item.name == "CKE") { +			if (auto cookie = item.asBinary()) +				items.emplace_back(CookieEcho { *cookie });  		} else if (item.name == "REQ") {  			if (auto ref = item.asRef())  				items.emplace_back(DataRequest { ref->digest() }); @@ -455,6 +560,15 @@ PartialObject NetworkProtocol::Header::toObject(const PartialStorage & st) const  		else if (const auto * ptr = get_if<Version>(&item))  			ritems.emplace_back("VER", ptr->value); +		else if (const auto * ptr = get_if<Initiation>(&item)) +			ritems.emplace_back("INI", st.ref(ptr->value)); + +		else if (const auto * ptr = get_if<CookieSet>(&item)) +			ritems.emplace_back("CKS", ptr->value.value); + +		else if (const auto * ptr = get_if<CookieEcho>(&item)) +			ritems.emplace_back("CKE", ptr->value.value); +  		else if (const auto * ptr = get_if<DataRequest>(&item))  			ritems.emplace_back("REQ", st.ref(ptr->value)); diff --git a/src/network/protocol.h b/src/network/protocol.h index 545585e..dda2ffb 100644 --- a/src/network/protocol.h +++ b/src/network/protocol.h @@ -38,18 +38,23 @@ public:  	struct Header; +	struct ReceivedAnnounce;  	struct NewConnection;  	struct ConnectionReadReady;  	struct ProtocolClosed {};  	using PollResult = variant< +		ReceivedAnnounce,  		NewConnection,  		ConnectionReadReady,  		ProtocolClosed>;  	PollResult poll(); +	struct Cookie { vector<uint8_t> value; }; +  	using ChannelState = variant<monostate, +		Cookie,  		Stored<ChannelRequest>,  		shared_ptr<struct WaitingRef>,  		Stored<ChannelAccept>, @@ -66,6 +71,12 @@ private:  	bool recvfrom(vector<uint8_t> & buffer, sockaddr_in6 & addr);  	void sendto(const vector<uint8_t> & buffer, variant<sockaddr_in, sockaddr_in6> addr); +	void sendCookie(variant<sockaddr_in, sockaddr_in6> addr); +	optional<Connection> verifyNewConnection(const Header & header, sockaddr_in6 addr); + +	Cookie generateCookie(variant<sockaddr_in, sockaddr_in6> addr) const; +	bool verifyCookie(variant<sockaddr_in, sockaddr_in6> addr, const Cookie & cookie) const; +  	int sock;  	mutex protocolMutex; @@ -94,7 +105,7 @@ public:  	const sockaddr_in6 & peerAddress() const;  	optional<Header> receive(const PartialStorage &); -	bool send(const PartialStorage &, const NetworkProtocol::Header &, +	bool send(const PartialStorage &, NetworkProtocol::Header,  			const vector<Object> &, bool secure);  	void close(); @@ -104,9 +115,14 @@ public:  	void trySendOutQueue();  private: +	static optional<Header> receive(vector<uint8_t> & buf, +			Channel * channel, +			const PartialStorage & st); +  	unique_ptr<ConnectionPriv> p;  }; +struct NetworkProtocol::ReceivedAnnounce { sockaddr_in6 addr; Digest digest; };  struct NetworkProtocol::NewConnection { Connection conn; };  struct NetworkProtocol::ConnectionReadReady { Connection::Id id; }; @@ -114,6 +130,9 @@ struct NetworkProtocol::Header  {  	struct Acknowledged { Digest value; };  	struct Version { string value; }; +	struct Initiation { Digest value; }; +	struct CookieSet { Cookie value; }; +	struct CookieEcho { Cookie value; };  	struct DataRequest { Digest value; };  	struct DataResponse { Digest value; };  	struct AnnounceSelf { Digest value; }; @@ -126,6 +145,9 @@ struct NetworkProtocol::Header  	using Item = variant<  		Acknowledged,  		Version, +		Initiation, +		CookieSet, +		CookieEcho,  		DataRequest,  		DataResponse,  		AnnounceSelf, @@ -140,14 +162,28 @@ struct NetworkProtocol::Header  	static optional<Header> load(const PartialObject &);  	PartialObject toObject(const PartialStorage &) const; -	const vector<Item> items; +	template<class T> const T * lookupFirst() const; + +	vector<Item> items;  }; +template<class T> +const T * NetworkProtocol::Header::lookupFirst() const +{ +	for (const auto & h : items) +		if (auto ptr = std::get_if<T>(&h)) +			return ptr; +	return nullptr; +} +  bool operator==(const NetworkProtocol::Header::Item &, const NetworkProtocol::Header::Item &);  inline bool operator!=(const NetworkProtocol::Header::Item & left,  		const NetworkProtocol::Header::Item & right)  { return not (left == right); } +inline bool operator==(const NetworkProtocol::Cookie & left, const NetworkProtocol::Cookie & right) +{ return left.value == right.value; } +  class ReplyBuilder  {  public: |