summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorRoman Smrž <roman.smrz@seznam.cz>2023-08-21 22:15:32 +0200
committerRoman Smrž <roman.smrz@seznam.cz>2023-08-27 10:53:18 +0200
commit1d4fa8fafa707642f948da9b033a21d0bcde0bbf (patch)
tree0db7ec2673cce166bd322023d39d4003cd1c3d15 /src
parent401f8c1288842b7479c375fba4aed55f6c5d52e9 (diff)
Network: headers for encryption and streams
Diffstat (limited to 'src')
-rw-r--r--src/network/channel.cpp36
-rw-r--r--src/network/channel.h8
-rw-r--r--src/network/protocol.cpp71
3 files changed, 77 insertions, 38 deletions
diff --git a/src/network/channel.cpp b/src/network/channel.cpp
index b317f3d..b95e0a1 100644
--- a/src/network/channel.cpp
+++ b/src/network/channel.cpp
@@ -133,15 +133,17 @@ optional<Stored<ChannelAccept>> Channel::acceptRequest(const Identity & self,
}));
}
-vector<uint8_t> Channel::encrypt(const vector<uint8_t> & plain)
+uint64_t Channel::encrypt(BufferCIt plainBegin, BufferCIt plainEnd,
+ Buffer & encBuffer, size_t encOffset)
{
- vector<uint8_t> res(plain.size() + 8 + 16 + 16);
+ auto plainSize = plainEnd - plainBegin;
+ encBuffer.resize(encOffset + plainSize + 8 + 16 + 16);
array<uint8_t, 12> iv;
uint64_t beCount = htobe64(nonceCounter++);
- std::memcpy(res.data(), &beCount, 8);
+ std::memcpy(encBuffer.data() + encOffset, &beCount, 8);
std::copy_n(nonceFixedOur.begin(), 6, iv.begin());
- std::copy_n(res.begin() + 2, 6, iv.begin() + 6);
+ std::copy_n(encBuffer.begin() + encOffset + 2, 6, iv.begin() + 6);
const unique_ptr<EVP_CIPHER_CTX, void(*)(EVP_CIPHER_CTX*)>
ctx(EVP_CIPHER_CTX_new(), EVP_CIPHER_CTX_free);
@@ -149,9 +151,9 @@ vector<uint8_t> Channel::encrypt(const vector<uint8_t> & plain)
nullptr, key.data(), iv.data());
int outl = 0;
- uint8_t * cur = res.data() + 8;
+ uint8_t * cur = encBuffer.data() + encOffset + 8;
- if (EVP_EncryptUpdate(ctx.get(), cur, &outl, plain.data(), plain.size()) != 1)
+ if (EVP_EncryptUpdate(ctx.get(), cur, &outl, &*plainBegin, plainSize) != 1)
throw runtime_error("failed to encrypt data");
cur += outl;
@@ -162,17 +164,19 @@ vector<uint8_t> Channel::encrypt(const vector<uint8_t> & plain)
EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_GET_TAG, 16, cur);
cur += 16;
- res.resize(cur - res.data());
- return res;
+ encBuffer.resize(cur - encBuffer.data());
+ return 0;
}
-optional<vector<uint8_t>> Channel::decrypt(const vector<uint8_t> & ctext)
+optional<uint64_t> Channel::decrypt(BufferCIt encBegin, BufferCIt encEnd,
+ Buffer & decBuffer, const size_t decOffset)
{
- vector<uint8_t> res(ctext.size());
+ auto encSize = encEnd - encBegin;
+ decBuffer.resize(decOffset + encSize);
array<uint8_t, 12> iv;
std::copy_n(nonceFixedPeer.begin(), 6, iv.begin());
- std::copy_n(ctext.begin() + 2, 6, iv.begin() + 6);
+ std::copy_n(encBegin + 2, 6, iv.begin() + 6);
const unique_ptr<EVP_CIPHER_CTX, void(*)(EVP_CIPHER_CTX*)>
ctx(EVP_CIPHER_CTX_new(), EVP_CIPHER_CTX_free);
@@ -180,21 +184,21 @@ optional<vector<uint8_t>> Channel::decrypt(const vector<uint8_t> & ctext)
nullptr, key.data(), iv.data());
int outl = 0;
- uint8_t * cur = res.data();
+ uint8_t * cur = decBuffer.data() + decOffset;
if (EVP_DecryptUpdate(ctx.get(), cur, &outl,
- ctext.data() + 8, ctext.size() - 8 - 16) != 1)
+ &*encBegin + 8, encSize - 8 - 16) != 1)
return nullopt;
cur += outl;
if (!EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_SET_TAG, 16,
- (void *) (ctext.data() + ctext.size() - 16)))
+ (void *) (&*encEnd - 16)))
return nullopt;
if (EVP_DecryptFinal_ex(ctx.get(), cur, &outl) != 1)
return nullopt;
cur += outl;
- res.resize(cur - res.data());
- return res;
+ decBuffer.resize(cur - decBuffer.data());
+ return 0;
}
diff --git a/src/network/channel.h b/src/network/channel.h
index f932c84..98bfd29 100644
--- a/src/network/channel.h
+++ b/src/network/channel.h
@@ -58,8 +58,12 @@ public:
static optional<Stored<ChannelAccept>> acceptRequest(const Identity & self,
const Identity & peer, const Stored<ChannelRequest> & request);
- vector<uint8_t> encrypt(const vector<uint8_t> &);
- optional<vector<uint8_t>> decrypt(const vector<uint8_t> &);
+ using Buffer = vector<uint8_t>;
+ using BufferCIt = Buffer::const_iterator;
+ uint64_t encrypt(BufferCIt plainBegin, BufferCIt plainEnd,
+ Buffer & encBuffer, size_t encOffset);
+ optional<uint64_t> decrypt(BufferCIt encBegin, BufferCIt encEnd,
+ Buffer & decBuffer, size_t decOffset);
private:
const vector<Stored<Signed<IdentityData>>> peers;
diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp
index 5dc831a..f38267f 100644
--- a/src/network/protocol.cpp
+++ b/src/network/protocol.cpp
@@ -3,6 +3,7 @@
#include <sys/socket.h>
#include <unistd.h>
+#include <algorithm>
#include <cstring>
#include <iostream>
#include <mutex>
@@ -173,7 +174,8 @@ const sockaddr_in6 & NetworkProtocol::Connection::peerAddress() const
optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const PartialStorage & partStorage)
{
vector<uint8_t> buf, decrypted;
- vector<uint8_t> * current;
+ auto plainBegin = buf.cbegin();
+ auto plainEnd = buf.cbegin();
{
scoped_lock lock(p->cmutex);
@@ -182,28 +184,47 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const Par
return nullopt;
buf.swap(p->buffer);
- current = &buf;
- if (holds_alternative<unique_ptr<Channel>>(p->channel)) {
- if (auto dec = std::get<unique_ptr<Channel>>(p->channel)->decrypt(buf)) {
- decrypted = std::move(*dec);
- current = &decrypted;
+ if ((buf[0] & 0xE0) == 0x80) {
+ Channel * channel = nullptr;
+ unique_ptr<Channel> channelPtr;
+
+ if (holds_alternative<unique_ptr<Channel>>(p->channel)) {
+ channel = std::get<unique_ptr<Channel>>(p->channel).get();
+ } else if (holds_alternative<Stored<ChannelAccept>>(p->channel)) {
+ channelPtr = std::get<Stored<ChannelAccept>>(p->channel)->data->channel();
+ channel = channelPtr.get();
}
- } else if (holds_alternative<Stored<ChannelAccept>>(p->channel)) {
- if (auto dec = std::get<Stored<ChannelAccept>>(p->channel)->
- data->channel()->decrypt(buf)) {
- decrypted = std::move(*dec);
- current = &decrypted;
+
+ if (not channel) {
+ std::cerr << "unexpected encrypted packet\n";
+ return nullopt;
}
+
+ if (auto dec = channel->decrypt(buf.begin() + 1, buf.end(), decrypted, 0)) {
+ if (decrypted.empty()) {
+ std::cerr << "empty decrypted content\n";
+ }
+ else if (decrypted[0] == 0x00) {
+ plainBegin = decrypted.begin() + 1;
+ plainEnd = decrypted.end();
+ }
+ else {
+ std::cerr << "streams not implemented\n";
+ return nullopt;
+ }
+ }
+ }
+ else if ((buf[0] & 0xE0) == 0x60) {
+ plainBegin = buf.begin();
+ plainEnd = buf.end();
}
}
- if (auto dec = PartialObject::decodePrefix(partStorage,
- current->begin(), current->end())) {
+ if (auto dec = PartialObject::decodePrefix(partStorage, plainBegin, plainEnd)) {
if (auto header = Header::load(std::get<PartialObject>(*dec))) {
auto pos = std::get<1>(*dec);
- while (auto cdec = PartialObject::decodePrefix(partStorage,
- pos, current->end())) {
+ while (auto cdec = PartialObject::decodePrefix(partStorage, pos, plainEnd)) {
partStorage.storeObject(std::get<PartialObject>(*cdec));
pos = std::get<1>(*cdec);
}
@@ -225,6 +246,13 @@ bool NetworkProtocol::Connection::send(const PartialStorage & partStorage,
{
scoped_lock clock(p->cmutex);
+ Channel * channel = nullptr;
+ if (holds_alternative<unique_ptr<Channel>>(p->channel))
+ channel = std::get<unique_ptr<Channel>>(p->channel).get();
+
+ if (channel || secure)
+ data.push_back(0x00);
+
part = header.toObject(partStorage).encode();
data.insert(data.end(), part.begin(), part.end());
for (const auto & obj : objs) {
@@ -232,12 +260,14 @@ bool NetworkProtocol::Connection::send(const PartialStorage & partStorage,
data.insert(data.end(), part.begin(), part.end());
}
- if (holds_alternative<unique_ptr<Channel>>(p->channel))
- out = std::get<unique_ptr<Channel>>(p->channel)->encrypt(data);
- else if (secure)
+ if (channel) {
+ out.push_back(0x80);
+ channel->encrypt(data.begin(), data.end(), out, 1);
+ } else if (secure) {
p->secureOutQueue.emplace_back(move(data));
- else
+ } else {
out = std::move(data);
+ }
}
if (not out.empty())
@@ -285,8 +315,9 @@ void NetworkProtocol::Connection::trySendOutQueue()
queue.swap(p->secureOutQueue);
}
+ vector<uint8_t> out { 0x80 };
for (const auto & data : queue) {
- auto out = std::get<unique_ptr<Channel>>(p->channel)->encrypt(data);
+ std::get<unique_ptr<Channel>>(p->channel)->encrypt(data.begin(), data.end(), out, 1);
p->protocol->sendto(out, p->peerAddress);
}
}