From c44b82faaf309c916a1aecf4ec939510e6384ae5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roman=20Smr=C5=BE?= Date: Fri, 25 Aug 2023 21:36:59 +0200 Subject: Switch to ChaCha20-Poly1305 AEAD scheme --- src/network/channel.cpp | 74 ++++++++++++++++++++++++++++--------------------- src/network/channel.h | 11 ++++---- 2 files changed, 49 insertions(+), 36 deletions(-) diff --git a/src/network/channel.cpp b/src/network/channel.cpp index b95e0a1..5fff1fa 100644 --- a/src/network/channel.cpp +++ b/src/network/channel.cpp @@ -17,7 +17,6 @@ Ref ChannelRequestData::store(const Storage & st) const for (const auto & p : peers) items.emplace_back("peer", p); - items.emplace_back("enc", "aes-128-gcm"); items.emplace_back("key", key); return st.storeObject(Record(std::move(items))); @@ -26,12 +25,11 @@ Ref ChannelRequestData::store(const Storage & st) const ChannelRequestData ChannelRequestData::load(const Ref & ref) { if (auto rec = ref->asRecord()) { - if (rec->item("enc").asText() == "aes-128-gcm") - if (auto key = rec->item("key").as()) - return ChannelRequestData { - .peers = rec->items("peer").as>(), - .key = *key, - }; + if (auto key = rec->item("key").as()) + return ChannelRequestData { + .peers = rec->items("peer").as>(), + .key = *key, + }; } return ChannelRequestData { @@ -45,7 +43,6 @@ Ref ChannelAcceptData::store(const Storage & st) const vector items; items.emplace_back("req", request); - items.emplace_back("enc", "aes-128-gcm"); items.emplace_back("key", key); return st.storeObject(Record(std::move(items))); @@ -54,11 +51,10 @@ Ref ChannelAcceptData::store(const Storage & st) const ChannelAcceptData ChannelAcceptData::load(const Ref & ref) { if (auto rec = ref->asRecord()) - if (rec->item("enc").asText() == "aes-128-gcm") - return ChannelAcceptData { - .request = *rec->item("req").as(), - .key = *rec->item("key").as(), - }; + return ChannelAcceptData { + .request = *rec->item("req").as(), + .key = *rec->item("key").as(), + }; return ChannelAcceptData { .request = Stored::load(ref.storage().zref()), @@ -137,21 +133,26 @@ uint64_t Channel::encrypt(BufferCIt plainBegin, BufferCIt plainEnd, Buffer & encBuffer, size_t encOffset) { auto plainSize = plainEnd - plainBegin; - encBuffer.resize(encOffset + plainSize + 8 + 16 + 16); + encBuffer.resize(encOffset + plainSize + 1 /* counter */ + 16 /* tag */); array iv; - uint64_t beCount = htobe64(nonceCounter++); - std::memcpy(encBuffer.data() + encOffset, &beCount, 8); - std::copy_n(nonceFixedOur.begin(), 6, iv.begin()); - std::copy_n(encBuffer.begin() + encOffset + 2, 6, iv.begin() + 6); + uint64_t count = counterNextOut.fetch_add(1); + uint64_t beCount = htobe64(count); + encBuffer[encOffset] = count % 0x100; + + constexpr size_t nonceFixedSize = std::tuple_size_v; + static_assert(nonceFixedSize + sizeof beCount == iv.size()); + + std::copy_n(nonceFixedOur.begin(), nonceFixedSize, iv.begin()); + std::memcpy(iv.data() + nonceFixedSize, &beCount, sizeof beCount); const unique_ptr ctx(EVP_CIPHER_CTX_new(), EVP_CIPHER_CTX_free); - EVP_EncryptInit_ex(ctx.get(), EVP_aes_128_gcm(), + EVP_EncryptInit_ex(ctx.get(), EVP_chacha20_poly1305(), nullptr, key.data(), iv.data()); int outl = 0; - uint8_t * cur = encBuffer.data() + encOffset + 8; + uint8_t * cur = encBuffer.data() + encOffset + 1; if (EVP_EncryptUpdate(ctx.get(), cur, &outl, &*plainBegin, plainSize) != 1) throw runtime_error("failed to encrypt data"); @@ -161,11 +162,8 @@ uint64_t Channel::encrypt(BufferCIt plainBegin, BufferCIt plainEnd, throw runtime_error("failed to encrypt data"); cur += outl; - EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_GET_TAG, 16, cur); - cur += 16; - - encBuffer.resize(cur - encBuffer.data()); - return 0; + EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_AEAD_GET_TAG, 16, cur); + return count; } optional Channel::decrypt(BufferCIt encBegin, BufferCIt encEnd, @@ -175,23 +173,33 @@ optional Channel::decrypt(BufferCIt encBegin, BufferCIt encEnd, decBuffer.resize(decOffset + encSize); array iv; - std::copy_n(nonceFixedPeer.begin(), 6, iv.begin()); - std::copy_n(encBegin + 2, 6, iv.begin() + 6); + if (encBegin + 1 /* counter */ + 16 /* tag */ > encEnd) + return nullopt; + + uint64_t expectedCount = counterNextIn.load(); + uint64_t guessedCount = expectedCount - 0x80u + ((0x80u + encBegin[0] - expectedCount) % 0x100u); + uint64_t beCount = htobe64(guessedCount); + + constexpr size_t nonceFixedSize = std::tuple_size_v; + static_assert(nonceFixedSize + sizeof beCount == iv.size()); + + std::copy_n(nonceFixedPeer.begin(), nonceFixedSize, iv.begin()); + std::memcpy(iv.data() + nonceFixedSize, &beCount, sizeof beCount); const unique_ptr ctx(EVP_CIPHER_CTX_new(), EVP_CIPHER_CTX_free); - EVP_DecryptInit_ex(ctx.get(), EVP_aes_128_gcm(), + EVP_DecryptInit_ex(ctx.get(), EVP_chacha20_poly1305(), nullptr, key.data(), iv.data()); int outl = 0; uint8_t * cur = decBuffer.data() + decOffset; if (EVP_DecryptUpdate(ctx.get(), cur, &outl, - &*encBegin + 8, encSize - 8 - 16) != 1) + &*encBegin + 1, encSize - 1 - 16) != 1) return nullopt; cur += outl; - if (!EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_SET_TAG, 16, + if (!EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_AEAD_SET_TAG, 16, (void *) (&*encEnd - 16))) return nullopt; @@ -199,6 +207,10 @@ optional Channel::decrypt(BufferCIt encBegin, BufferCIt encEnd, return nullopt; cur += outl; + while (expectedCount < guessedCount + 1 && + not counterNextIn.compare_exchange_weak(expectedCount, guessedCount + 1)) + ; // empty loop body + decBuffer.resize(cur - decBuffer.data()); - return 0; + return guessedCount; } diff --git a/src/network/channel.h b/src/network/channel.h index 98bfd29..bba11b3 100644 --- a/src/network/channel.h +++ b/src/network/channel.h @@ -44,8 +44,8 @@ public: vector && key, bool ourRequest): peers(peers), key(std::move(key)), - nonceFixedOur({ uint8_t(ourRequest ? 1 : 2), 0, 0, 0, 0, 0 }), - nonceFixedPeer({ uint8_t(ourRequest ? 2 : 1), 0, 0, 0, 0, 0 }) + nonceFixedOur({ uint8_t(ourRequest ? 1 : 2), 0, 0, 0 }), + nonceFixedPeer({ uint8_t(ourRequest ? 2 : 1), 0, 0, 0 }) {} Channel(const Channel &) = delete; @@ -69,9 +69,10 @@ private: const vector>> peers; const vector key; - const array nonceFixedOur; - const array nonceFixedPeer; - atomic nonceCounter = 0; + const array nonceFixedOur; + const array nonceFixedPeer; + atomic counterNextOut = 0; + atomic counterNextIn = 0; }; } -- cgit v1.2.3