diff options
| author | Roman Smrž <roman.smrz@seznam.cz> | 2023-08-19 11:22:02 +0200 | 
|---|---|---|
| committer | Roman Smrž <roman.smrz@seznam.cz> | 2023-08-27 10:51:20 +0200 | 
| commit | 401f8c1288842b7479c375fba4aed55f6c5d52e9 (patch) | |
| tree | 62ccc1414dc1cfdeffd3bf105ca2df3396b90abf | |
| parent | 4153da3c16d184a1e6ffa15d2c504c6e3f6b0e1f (diff) | |
Network: encrypt and decrypt within connection object
| -rw-r--r-- | src/network.cpp | 97 | ||||
| -rw-r--r-- | src/network.h | 3 | ||||
| -rw-r--r-- | src/network/protocol.cpp | 100 | ||||
| -rw-r--r-- | src/network/protocol.h | 6 | 
4 files changed, 110 insertions, 96 deletions
| diff --git a/src/network.cpp b/src/network.cpp index 455496c..8c181cf 100644 --- a/src/network.cpp +++ b/src/network.cpp @@ -117,7 +117,7 @@ void Server::addPeer(const string & node, const string & service) const  				header.push_back(NetworkProtocol::Header::Item {  					NetworkProtocol::Header::Type::AnnounceSelf, p->self.ref()->digest() });  			} -			peer.send(header, {}, false); +			peer.connection.send(peer.partStorage, header, {}, false);  			return;  		}  	} @@ -229,7 +229,7 @@ bool Peer::send(UUID uuid, const Ref & ref, const Object & obj) const  			{ NetworkProtocol::Header::Type::ServiceType, uuid },  				{ NetworkProtocol::Header::Type::ServiceRef, ref.digest() },  		}); -		speer->send(header, { obj }, true); +		speer->connection.send(speer->partStorage, header, { obj }, true);  		return true;  	} @@ -360,7 +360,6 @@ shared_ptr<Server::Priv> Server::Priv::getptr()  void Server::Priv::doListen()  { -	vector<uint8_t> buf, decrypted, *current;  	unique_lock lock(dataMutex);  	for (; !finish; lock.lock()) { @@ -385,50 +384,22 @@ void Server::Priv::doListen()  		if (!peer)  			continue; -		if (not peer->connection.receive(buf)) -			continue; - -		current = &buf; -		if (holds_alternative<unique_ptr<Channel>>(peer->connection.channel())) { -			if (auto dec = std::get<unique_ptr<Channel>>(peer->connection.channel())->decrypt(buf)) { -				decrypted = std::move(*dec); -				current = &decrypted; -			} -		} else if (holds_alternative<Stored<ChannelAccept>>(peer->connection.channel())) { -			if (auto dec = std::get<Stored<ChannelAccept>>(peer->connection.channel())-> -					data->channel()->decrypt(buf)) { -				decrypted = std::move(*dec); -				current = &decrypted; -			} -		} - -		if (auto dec = PartialObject::decodePrefix(peer->partStorage, -				current->begin(), current->end())) { -			if (auto header = NetworkProtocol::Header::load(std::get<PartialObject>(*dec))) { -				auto pos = std::get<1>(*dec); -				while (auto cdec = PartialObject::decodePrefix(peer->partStorage, -							pos, current->end())) { -					peer->partStorage.storeObject(std::get<PartialObject>(*cdec)); -					pos = std::get<1>(*cdec); -				} +		if (auto header = peer->connection.receive(peer->partStorage)) { +			ReplyBuilder reply; -				ReplyBuilder reply; - -				scoped_lock hlock(dataMutex); -				shared_lock slock(selfMutex); +			scoped_lock hlock(dataMutex); +			shared_lock slock(selfMutex); -				handlePacket(*peer, *header, reply); -				peer->updateIdentity(reply); -				peer->updateChannel(reply); -				peer->updateService(reply); +			handlePacket(*peer, *header, reply); +			peer->updateIdentity(reply); +			peer->updateChannel(reply); +			peer->updateService(reply); -				if (!reply.header().empty()) -					peer->send(NetworkProtocol::Header(reply.header()), reply.body(), false); +			if (!reply.header().empty()) +				peer->connection.send(peer->partStorage, +						NetworkProtocol::Header(reply.header()), reply.body(), false); -				peer->trySendOutQueue(); -			} -		} else { -			std::cerr << "invalid packet\n"; +			peer->connection.trySendOutQueue();  		}  	}  } @@ -703,33 +674,11 @@ void Server::Priv::handleLocalHeadChange(const Head<LocalState> & head)  			NetworkProtocol::Header header(hitems);  			for (const auto & peer : peers) -				peer->send(header, { **self.ref() }, false); +				peer->connection.send(peer->partStorage, header, { **self.ref() }, false);  		}  	}  } -void Server::Peer::send(const NetworkProtocol::Header & header, const vector<Object> & objs, bool secure) -{ -	vector<uint8_t> data, part, out; - -	part = header.toObject(partStorage).encode(); -	data.insert(data.end(), part.begin(), part.end()); -	for (const auto & obj : objs) { -		part = obj.encode(); -		data.insert(data.end(), part.begin(), part.end()); -	} - -	if (holds_alternative<unique_ptr<Channel>>(connection.channel())) -		out = std::get<unique_ptr<Channel>>(connection.channel())->encrypt(data); -	else if (secure) -		secureOutQueue.emplace_back(move(data)); -	else -		out = std::move(data); - -	if (!out.empty()) -		connection.send(out); -} -  void Server::Peer::updateIdentity(ReplyBuilder &)  {  	if (holds_alternative<shared_ptr<WaitingRef>>(identity)) { @@ -853,22 +802,6 @@ void Server::Peer::updateService(ReplyBuilder & reply)  	serviceQueue = std::move(next);  } -void Server::Peer::trySendOutQueue() -{ -	if (secureOutQueue.empty()) -		return; - -	if (!holds_alternative<unique_ptr<Channel>>(connection.channel())) -		return; - -	for (const auto & data : secureOutQueue) { -		auto out = std::get<unique_ptr<Channel>>(connection.channel())->encrypt(data); -		connection.send(out); -	} - -	secureOutQueue.clear(); -} -  void ReplyBuilder::header(NetworkProtocol::Header::Item && item)  { diff --git a/src/network.h b/src/network.h index 2959adc..d1fae15 100644 --- a/src/network.h +++ b/src/network.h @@ -54,16 +54,13 @@ struct Server::Peer  	PartialStorage partStorage;  	vector<tuple<UUID, shared_ptr<WaitingRef>>> serviceQueue {}; -	vector<vector<uint8_t>> secureOutQueue {};  	shared_ptr<erebos::Peer::Priv> lpeer = nullptr; -	void send(const NetworkProtocol::Header &, const vector<Object> &, bool secure);  	void updateIdentity(ReplyBuilder &);  	void updateChannel(ReplyBuilder &);  	void finalizeChannel(ReplyBuilder &, unique_ptr<Channel>);  	void updateService(ReplyBuilder &); -	void trySendOutQueue();  };  struct Peer::Priv : enable_shared_from_this<Peer::Priv> diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp index 4151bf2..5dc831a 100644 --- a/src/network/protocol.cpp +++ b/src/network/protocol.cpp @@ -4,6 +4,7 @@  #include <unistd.h>  #include <cstring> +#include <iostream>  #include <mutex>  #include <system_error> @@ -26,6 +27,7 @@ struct NetworkProtocol::ConnectionPriv  	vector<uint8_t> buffer {};  	ChannelState channel = monostate(); +	vector<vector<uint8_t>> secureOutQueue {};  }; @@ -168,20 +170,79 @@ const sockaddr_in6 & NetworkProtocol::Connection::peerAddress() const  	return p->peerAddress;  } -bool NetworkProtocol::Connection::receive(vector<uint8_t> & buffer) +optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const PartialStorage & partStorage)  { -	scoped_lock lock(p->cmutex); -	if (p->buffer.empty()) -		return false; +	vector<uint8_t> buf, decrypted; +	vector<uint8_t> * current; -	buffer.swap(p->buffer); -	p->buffer.clear(); -	return true; +	{ +		scoped_lock lock(p->cmutex); + +		if (p->buffer.empty()) +			return nullopt; + +		buf.swap(p->buffer); +		current = &buf; + +		if (holds_alternative<unique_ptr<Channel>>(p->channel)) { +			if (auto dec = std::get<unique_ptr<Channel>>(p->channel)->decrypt(buf)) { +				decrypted = std::move(*dec); +				current = &decrypted; +			} +		} else if (holds_alternative<Stored<ChannelAccept>>(p->channel)) { +			if (auto dec = std::get<Stored<ChannelAccept>>(p->channel)-> +					data->channel()->decrypt(buf)) { +				decrypted = std::move(*dec); +				current = &decrypted; +			} +		} +	} + +	if (auto dec = PartialObject::decodePrefix(partStorage, +			current->begin(), current->end())) { +		if (auto header = Header::load(std::get<PartialObject>(*dec))) { +			auto pos = std::get<1>(*dec); +			while (auto cdec = PartialObject::decodePrefix(partStorage, +						pos, current->end())) { +				partStorage.storeObject(std::get<PartialObject>(*cdec)); +				pos = std::get<1>(*cdec); +			} + +			return header; +		} +	} + +	std::cerr << "invalid packet\n"; +	return nullopt;  } -bool NetworkProtocol::Connection::send(const vector<uint8_t> & buffer) +bool NetworkProtocol::Connection::send(const PartialStorage & partStorage, +		const Header & header, +		const vector<Object> & objs, bool secure)  { -	p->protocol->sendto(buffer, p->peerAddress); +	vector<uint8_t> data, part, out; + +	{ +		scoped_lock clock(p->cmutex); + +		part = header.toObject(partStorage).encode(); +		data.insert(data.end(), part.begin(), part.end()); +		for (const auto & obj : objs) { +			part = obj.encode(); +			data.insert(data.end(), part.begin(), part.end()); +		} + +		if (holds_alternative<unique_ptr<Channel>>(p->channel)) +			out = std::get<unique_ptr<Channel>>(p->channel)->encrypt(data); +		else if (secure) +			p->secureOutQueue.emplace_back(move(data)); +		else +			out = std::move(data); +	} + +	if (not out.empty()) +		p->protocol->sendto(out, p->peerAddress); +  	return true;  } @@ -209,6 +270,27 @@ NetworkProtocol::ChannelState & NetworkProtocol::Connection::channel()  	return p->channel;  } +void NetworkProtocol::Connection::trySendOutQueue() +{ +	decltype(p->secureOutQueue) queue; +	{ +		scoped_lock clock(p->cmutex); + +		if (p->secureOutQueue.empty()) +			return; + +		if (not holds_alternative<unique_ptr<Channel>>(p->channel)) +			return; + +		queue.swap(p->secureOutQueue); +	} + +	for (const auto & data : queue) { +		auto out = std::get<unique_ptr<Channel>>(p->channel)->encrypt(data); +		p->protocol->sendto(out, p->peerAddress); +	} +} +  /******************************************************************************/  /* Header                                                                     */ diff --git a/src/network/protocol.h b/src/network/protocol.h index 88abf67..c5803ce 100644 --- a/src/network/protocol.h +++ b/src/network/protocol.h @@ -87,13 +87,15 @@ public:  	const sockaddr_in6 & peerAddress() const; -	bool receive(vector<uint8_t> & buffer); -	bool send(const vector<uint8_t> & buffer); +	optional<Header> receive(const PartialStorage &); +	bool send(const PartialStorage &, const NetworkProtocol::Header &, +			const vector<Object> &, bool secure);  	void close();  	// temporary:  	ChannelState & channel(); +	void trySendOutQueue();  private:  	unique_ptr<ConnectionPriv> p; |