diff options
| author | Roman Smrž <roman.smrz@seznam.cz> | 2023-08-21 22:15:32 +0200 | 
|---|---|---|
| committer | Roman Smrž <roman.smrz@seznam.cz> | 2023-08-27 10:53:18 +0200 | 
| commit | 1d4fa8fafa707642f948da9b033a21d0bcde0bbf (patch) | |
| tree | 0db7ec2673cce166bd322023d39d4003cd1c3d15 | |
| parent | 401f8c1288842b7479c375fba4aed55f6c5d52e9 (diff) | |
Network: headers for encryption and streams
| -rw-r--r-- | src/network/channel.cpp | 36 | ||||
| -rw-r--r-- | src/network/channel.h | 8 | ||||
| -rw-r--r-- | src/network/protocol.cpp | 71 | 
3 files changed, 77 insertions, 38 deletions
| diff --git a/src/network/channel.cpp b/src/network/channel.cpp index b317f3d..b95e0a1 100644 --- a/src/network/channel.cpp +++ b/src/network/channel.cpp @@ -133,15 +133,17 @@ optional<Stored<ChannelAccept>> Channel::acceptRequest(const Identity & self,  	}));  } -vector<uint8_t> Channel::encrypt(const vector<uint8_t> & plain) +uint64_t Channel::encrypt(BufferCIt plainBegin, BufferCIt plainEnd, +		Buffer & encBuffer, size_t encOffset)  { -	vector<uint8_t> res(plain.size() + 8 + 16 + 16); +	auto plainSize = plainEnd - plainBegin; +	encBuffer.resize(encOffset + plainSize + 8 + 16 + 16);  	array<uint8_t, 12> iv;  	uint64_t beCount = htobe64(nonceCounter++); -	std::memcpy(res.data(), &beCount, 8); +	std::memcpy(encBuffer.data() + encOffset, &beCount, 8);  	std::copy_n(nonceFixedOur.begin(), 6, iv.begin()); -	std::copy_n(res.begin() + 2, 6, iv.begin() + 6); +	std::copy_n(encBuffer.begin() + encOffset + 2, 6, iv.begin() + 6);  	const unique_ptr<EVP_CIPHER_CTX, void(*)(EVP_CIPHER_CTX*)>  		ctx(EVP_CIPHER_CTX_new(), EVP_CIPHER_CTX_free); @@ -149,9 +151,9 @@ vector<uint8_t> Channel::encrypt(const vector<uint8_t> & plain)  			nullptr, key.data(), iv.data());  	int outl = 0; -	uint8_t * cur = res.data() + 8; +	uint8_t * cur = encBuffer.data() + encOffset + 8; -	if (EVP_EncryptUpdate(ctx.get(), cur, &outl, plain.data(), plain.size()) != 1) +	if (EVP_EncryptUpdate(ctx.get(), cur, &outl, &*plainBegin, plainSize) != 1)  		throw runtime_error("failed to encrypt data");  	cur += outl; @@ -162,17 +164,19 @@ vector<uint8_t> Channel::encrypt(const vector<uint8_t> & plain)  	EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_GET_TAG, 16, cur);  	cur += 16; -	res.resize(cur - res.data()); -	return res; +	encBuffer.resize(cur - encBuffer.data()); +	return 0;  } -optional<vector<uint8_t>> Channel::decrypt(const vector<uint8_t> & ctext) +optional<uint64_t> Channel::decrypt(BufferCIt encBegin, BufferCIt encEnd, +			Buffer & decBuffer, const size_t decOffset)  { -	vector<uint8_t> res(ctext.size()); +	auto encSize = encEnd - encBegin; +	decBuffer.resize(decOffset + encSize);  	array<uint8_t, 12> iv;  	std::copy_n(nonceFixedPeer.begin(), 6, iv.begin()); -	std::copy_n(ctext.begin() + 2, 6, iv.begin() + 6); +	std::copy_n(encBegin + 2, 6, iv.begin() + 6);  	const unique_ptr<EVP_CIPHER_CTX, void(*)(EVP_CIPHER_CTX*)>  		ctx(EVP_CIPHER_CTX_new(), EVP_CIPHER_CTX_free); @@ -180,21 +184,21 @@ optional<vector<uint8_t>> Channel::decrypt(const vector<uint8_t> & ctext)  			nullptr, key.data(), iv.data());  	int outl = 0; -	uint8_t * cur = res.data(); +	uint8_t * cur = decBuffer.data() + decOffset;  	if (EVP_DecryptUpdate(ctx.get(), cur, &outl, -				ctext.data() + 8, ctext.size() - 8 - 16) != 1) +				&*encBegin + 8, encSize - 8 - 16) != 1)  		return nullopt;  	cur += outl;  	if (!EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_SET_TAG, 16, -				(void *) (ctext.data() + ctext.size() - 16))) +				(void *) (&*encEnd - 16)))  		return nullopt;  	if (EVP_DecryptFinal_ex(ctx.get(), cur, &outl) != 1)  		return nullopt;  	cur += outl; -	res.resize(cur - res.data()); -	return res; +	decBuffer.resize(cur - decBuffer.data()); +	return 0;  } diff --git a/src/network/channel.h b/src/network/channel.h index f932c84..98bfd29 100644 --- a/src/network/channel.h +++ b/src/network/channel.h @@ -58,8 +58,12 @@ public:  	static optional<Stored<ChannelAccept>> acceptRequest(const Identity & self,  			const Identity & peer, const Stored<ChannelRequest> & request); -	vector<uint8_t> encrypt(const vector<uint8_t> &); -	optional<vector<uint8_t>> decrypt(const vector<uint8_t> &); +	using Buffer = vector<uint8_t>; +	using BufferCIt = Buffer::const_iterator; +	uint64_t encrypt(BufferCIt plainBegin, BufferCIt plainEnd, +			Buffer & encBuffer, size_t encOffset); +	optional<uint64_t> decrypt(BufferCIt encBegin, BufferCIt encEnd, +			Buffer & decBuffer, size_t decOffset);  private:  	const vector<Stored<Signed<IdentityData>>> peers; diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp index 5dc831a..f38267f 100644 --- a/src/network/protocol.cpp +++ b/src/network/protocol.cpp @@ -3,6 +3,7 @@  #include <sys/socket.h>  #include <unistd.h> +#include <algorithm>  #include <cstring>  #include <iostream>  #include <mutex> @@ -173,7 +174,8 @@ const sockaddr_in6 & NetworkProtocol::Connection::peerAddress() const  optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const PartialStorage & partStorage)  {  	vector<uint8_t> buf, decrypted; -	vector<uint8_t> * current; +	auto plainBegin = buf.cbegin(); +	auto plainEnd = buf.cbegin();  	{  		scoped_lock lock(p->cmutex); @@ -182,28 +184,47 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const Par  			return nullopt;  		buf.swap(p->buffer); -		current = &buf; -		if (holds_alternative<unique_ptr<Channel>>(p->channel)) { -			if (auto dec = std::get<unique_ptr<Channel>>(p->channel)->decrypt(buf)) { -				decrypted = std::move(*dec); -				current = &decrypted; +		if ((buf[0] & 0xE0) == 0x80) { +			Channel * channel = nullptr; +			unique_ptr<Channel> channelPtr; + +			if (holds_alternative<unique_ptr<Channel>>(p->channel)) { +				channel = std::get<unique_ptr<Channel>>(p->channel).get(); +			} else if (holds_alternative<Stored<ChannelAccept>>(p->channel)) { +				channelPtr = std::get<Stored<ChannelAccept>>(p->channel)->data->channel(); +				channel = channelPtr.get();  			} -		} else if (holds_alternative<Stored<ChannelAccept>>(p->channel)) { -			if (auto dec = std::get<Stored<ChannelAccept>>(p->channel)-> -					data->channel()->decrypt(buf)) { -				decrypted = std::move(*dec); -				current = &decrypted; + +			if (not channel) { +				std::cerr << "unexpected encrypted packet\n"; +				return nullopt;  			} + +			if (auto dec = channel->decrypt(buf.begin() + 1, buf.end(), decrypted, 0)) { +				if (decrypted.empty()) { +					std::cerr << "empty decrypted content\n"; +				} +				else if (decrypted[0] == 0x00) { +					plainBegin = decrypted.begin() + 1; +					plainEnd = decrypted.end(); +				} +				else { +					std::cerr << "streams not implemented\n"; +					return nullopt; +				} +			} +		} +		else if ((buf[0] & 0xE0) == 0x60) { +			plainBegin = buf.begin(); +			plainEnd = buf.end();  		}  	} -	if (auto dec = PartialObject::decodePrefix(partStorage, -			current->begin(), current->end())) { +	if (auto dec = PartialObject::decodePrefix(partStorage, plainBegin, plainEnd)) {  		if (auto header = Header::load(std::get<PartialObject>(*dec))) {  			auto pos = std::get<1>(*dec); -			while (auto cdec = PartialObject::decodePrefix(partStorage, -						pos, current->end())) { +			while (auto cdec = PartialObject::decodePrefix(partStorage, pos, plainEnd)) {  				partStorage.storeObject(std::get<PartialObject>(*cdec));  				pos = std::get<1>(*cdec);  			} @@ -225,6 +246,13 @@ bool NetworkProtocol::Connection::send(const PartialStorage & partStorage,  	{  		scoped_lock clock(p->cmutex); +		Channel * channel = nullptr; +		if (holds_alternative<unique_ptr<Channel>>(p->channel)) +			channel = std::get<unique_ptr<Channel>>(p->channel).get(); + +		if (channel || secure) +			data.push_back(0x00); +  		part = header.toObject(partStorage).encode();  		data.insert(data.end(), part.begin(), part.end());  		for (const auto & obj : objs) { @@ -232,12 +260,14 @@ bool NetworkProtocol::Connection::send(const PartialStorage & partStorage,  			data.insert(data.end(), part.begin(), part.end());  		} -		if (holds_alternative<unique_ptr<Channel>>(p->channel)) -			out = std::get<unique_ptr<Channel>>(p->channel)->encrypt(data); -		else if (secure) +		if (channel) { +			out.push_back(0x80); +			channel->encrypt(data.begin(), data.end(), out, 1); +		} else if (secure) {  			p->secureOutQueue.emplace_back(move(data)); -		else +		} else {  			out = std::move(data); +		}  	}  	if (not out.empty()) @@ -285,8 +315,9 @@ void NetworkProtocol::Connection::trySendOutQueue()  		queue.swap(p->secureOutQueue);  	} +	vector<uint8_t> out { 0x80 };  	for (const auto & data : queue) { -		auto out = std::get<unique_ptr<Channel>>(p->channel)->encrypt(data); +		std::get<unique_ptr<Channel>>(p->channel)->encrypt(data.begin(), data.end(), out, 1);  		p->protocol->sendto(out, p->peerAddress);  	}  } |