From ab86a1f0c3b86050e65fc5b7ac1e88a00f0d228c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roman=20Smr=C5=BE?= Date: Mon, 27 Jan 2020 21:25:39 +0100 Subject: Encrypted channels --- src/CMakeLists.txt | 1 + src/channel.cpp | 243 +++++++++++++++++++++++++++++++++++++++++++++++++++++ src/channel.h | 58 +++++++++++++ src/identity.cpp | 18 +++- src/identity.h | 6 +- src/network.cpp | 153 +++++++++++++++++++++++++++------ src/network.h | 25 ++++-- src/pubkey.cpp | 104 +++++++++++++++++++++++ src/pubkey.h | 31 +++++++ 9 files changed, 604 insertions(+), 35 deletions(-) create mode 100644 src/channel.cpp create mode 100644 src/channel.h (limited to 'src') diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 75eff66..a50bf66 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -3,6 +3,7 @@ include_directories( ) add_library(erebos + channel identity network pubkey diff --git a/src/channel.cpp b/src/channel.cpp new file mode 100644 index 0000000..38d263e --- /dev/null +++ b/src/channel.cpp @@ -0,0 +1,243 @@ +#include "channel.h" + +#include +#include + +#include + +using std::remove_const; +using std::runtime_error; + +using namespace erebos; + +Ref ChannelRequestData::store(const Storage & st) const +{ + vector items; + + for (const auto p : peers) + items.emplace_back("peer", p); + items.emplace_back("enc", "aes-128-gcm"); + items.emplace_back("key", key); + + return st.storeObject(Record(std::move(items))); +} + +optional ChannelRequestData::load(const Ref & ref) +{ + auto rec = ref->asRecord(); + if (!rec) + return nullopt; + + remove_const::type peers; + for (const auto & i : rec->items("peer")) + if (auto p = i.as>()) + peers.push_back(*p); + + auto enc = rec->item("enc").asText(); + if (!enc || enc != "aes-128-gcm") + return nullopt; + + auto key = rec->item("key").as(); + if (!key) + return nullopt; + + return ChannelRequestData { + .peers = std::move(peers), + .key = *key, + }; +} + +Ref ChannelAcceptData::store(const Storage & st) const +{ + vector items; + + items.emplace_back("req", request); + items.emplace_back("enc", "aes-128-gcm"); + items.emplace_back("key", key); + + return st.storeObject(Record(std::move(items))); +} + +optional ChannelAcceptData::load(const Ref & ref) +{ + auto rec = ref->asRecord(); + if (!rec) + return nullopt; + + auto request = rec->item("req").as(); + if (!request) + return nullopt; + + auto enc = rec->item("enc").asText(); + if (!enc || enc != "aes-128-gcm") + return nullopt; + + auto key = rec->item("key").as(); + if (!key) + return nullopt; + + return ChannelAcceptData { + .request = *request, + .key = *key, + }; +} + +Stored ChannelAcceptData::channel() const +{ + const auto & st = request.ref.storage(); + + if (auto secret = SecretKexKey::load(key)) + return st.store(Channel( + request->data->peers, + secret->dh(*request->data->key) + )); + + if (auto secret = SecretKexKey::load(request->data->key)) + return st.store(Channel( + request->data->peers, + secret->dh(*key) + )); + + throw runtime_error("failed to load secret DH key"); +} + + +Ref Channel::store(const Storage & st) const +{ + vector items; + + for (const auto p : peers) + items.emplace_back("peer", p); + items.emplace_back("enc", "aes-128-gcm"); + items.emplace_back("key", key); + + return st.storeObject(Record(std::move(items))); +} + +optional Channel::load(const Ref & ref) +{ + auto rec = ref->asRecord(); + if (!rec) + return nullopt; + + remove_const::type peers; + for (const auto & i : rec->items("peer")) + if (auto p = i.as>()) + peers.push_back(*p); + + auto enc = rec->item("enc").asText(); + if (!enc || enc != "aes-128-gcm") + return nullopt; + + auto key = rec->item("key").asBinary(); + if (!key) + return nullopt; + + return Channel(peers, std::move(*key)); +} + +Stored Channel::generateRequest(const Storage & st, + const Identity & self, const Identity & peer) +{ + auto signKey = SecretKey::load(self.keyMessage()); + if (!signKey) + throw runtime_error("failed to load own message key"); + + return signKey->sign(st.store(ChannelRequestData { + .peers = self.ref()->digest() < peer.ref()->digest() ? + vector>> { + *Stored>::load(*self.ref()), + *Stored>::load(*peer.ref()), + } : + vector>> { + *Stored>::load(*peer.ref()), + *Stored>::load(*self.ref()), + }, + .key = SecretKexKey::generate(st).pub(), + })); +} + +optional> Channel::acceptRequest(const Identity & self, + const Identity & peer, const Stored & request) +{ + if (!request->isSignedBy(peer.keyMessage())) + return nullopt; + + auto & peers = request->data->peers; + if (peers.size() != 2 || + std::none_of(peers.begin(), peers.end(), [&self](const auto & x) + { return x.ref.digest() == self.ref()->digest(); }) || + std::none_of(peers.begin(), peers.end(), [&peer](const auto & x) + { return x.ref.digest() == peer.ref()->digest(); })) + return nullopt; + + auto & st = request.ref.storage(); + + auto signKey = SecretKey::load(self.keyMessage()); + if (!signKey) + throw runtime_error("failed to load own message key"); + + return signKey->sign(st.store(ChannelAcceptData { + .request = request, + .key = SecretKexKey::generate(st).pub(), + })); +} + +vector Channel::encrypt(const vector & plain) const +{ + vector res(plain.size() + 12 + 16 + 16); + + if (RAND_bytes(res.data(), 12) != 1) + throw runtime_error("failed to generate random IV"); + + const unique_ptr + ctx(EVP_CIPHER_CTX_new(), EVP_CIPHER_CTX_free); + EVP_EncryptInit_ex(ctx.get(), EVP_aes_128_gcm(), + nullptr, key.data(), res.data()); + + int outl = 0; + uint8_t * cur = res.data() + 12; + + if (EVP_EncryptUpdate(ctx.get(), cur, &outl, plain.data(), plain.size()) != 1) + throw runtime_error("failed to encrypt data"); + cur += outl; + + if (EVP_EncryptFinal(ctx.get(), cur, &outl) != 1) + throw runtime_error("failed to encrypt data"); + cur += outl; + + EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_GET_TAG, 16, cur); + cur += 16; + + res.resize(cur - res.data()); + return res; +} + +optional> Channel::decrypt(const vector & ctext) const +{ + vector res(ctext.size()); + + const unique_ptr + ctx(EVP_CIPHER_CTX_new(), EVP_CIPHER_CTX_free); + EVP_DecryptInit_ex(ctx.get(), EVP_aes_128_gcm(), + nullptr, key.data(), ctext.data()); + + int outl = 0; + uint8_t * cur = res.data(); + + if (EVP_DecryptUpdate(ctx.get(), cur, &outl, + ctext.data() + 12, ctext.size() - 12 - 16) != 1) + return nullopt; + cur += outl; + + if (!EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_SET_TAG, 16, + (void *) (ctext.data() + ctext.size() - 16))) + return nullopt; + + if (EVP_DecryptFinal_ex(ctx.get(), cur, &outl) != 1) + return nullopt; + cur += outl; + + res.resize(cur - res.data()); + return res; +} diff --git a/src/channel.h b/src/channel.h new file mode 100644 index 0000000..100003c --- /dev/null +++ b/src/channel.h @@ -0,0 +1,58 @@ +#pragma once + +#include + +#include "identity.h" + +namespace erebos { + +struct ChannelRequestData +{ + Ref store(const Storage & st) const; + static optional load(const Ref &); + + const vector>> peers; + const Stored key; +}; + +typedef Signed ChannelRequest; + +struct ChannelAcceptData +{ + Ref store(const Storage & st) const; + static optional load(const Ref &); + + Stored channel() const; + + const Stored request; + const Stored key; +}; + +typedef Signed ChannelAccept; + +class Channel +{ +public: + Channel(const vector>> & peers, + vector && key): + peers(peers), + key(std::move(key)) + {} + + Ref store(const Storage & st) const; + static optional load(const Ref &); + + static Stored generateRequest(const Storage &, + const Identity & self, const Identity & peer); + static optional> acceptRequest(const Identity & self, + const Identity & peer, const Stored & request); + + vector encrypt(const vector &) const; + optional> decrypt(const vector &) const; + +private: + const vector>> peers; + const vector key; +}; + +} diff --git a/src/identity.cpp b/src/identity.cpp index cee2688..57f25cc 100644 --- a/src/identity.cpp +++ b/src/identity.cpp @@ -43,6 +43,11 @@ optional Identity::owner() const return p->owner; } +Stored Identity::keyMessage() const +{ + return p->keyMessage; +} + optional Identity::ref() const { if (p->data.size() == 1) @@ -167,14 +172,20 @@ shared_ptr Identity::Priv::validate(const vectorkeyMessage.value(), }; shared_ptr ret(p); - auto ownerProp = p->lookupProperty([] + auto ownerProp = lookupProperty(sdata, [] (const IdentityData & d) { return d.owner.has_value(); }); if (ownerProp) { auto owner = validate({ *ownerProp.value()->owner }); @@ -184,7 +195,7 @@ shared_ptr Identity::Priv::validate(const vectorname = async(std::launch::deferred, [p] () -> optional { - if (auto d = p->lookupProperty([] (const IdentityData & d) { return d.name.has_value(); })) + if (auto d = lookupProperty(p->data, [] (const IdentityData & d) { return d.name.has_value(); })) return d.value()->name; return nullopt; }); @@ -193,7 +204,8 @@ shared_ptr Identity::Priv::validate(const vector> Identity::Priv::lookupProperty( - function sel) const + const vector>> & data, + function sel) { set>> current, prop_heads; diff --git a/src/identity.h b/src/identity.h index 4335d32..79b335e 100644 --- a/src/identity.h +++ b/src/identity.h @@ -29,11 +29,13 @@ public: vector>> data; shared_future> name; optional owner; + Stored keyMessage; static bool verifySignatures(const Stored> & sdata); static shared_ptr validate(const vector>> & sdata); - optional> lookupProperty( - function sel) const; + static optional> lookupProperty( + const vector>> & data, + function sel); }; class Identity::Builder::Priv diff --git a/src/network.cpp b/src/network.cpp index bd1ea8e..e9aeb3f 100644 --- a/src/network.cpp +++ b/src/network.cpp @@ -4,6 +4,7 @@ #include #include +#include #include #include @@ -77,7 +78,7 @@ Server::Priv::~Priv() void Server::Priv::doListen() { - vector buf(4096); + vector buf, decrypted, *current; unique_lock lock(dataMutex); while (!finish) { @@ -85,25 +86,51 @@ void Server::Priv::doListen() lock.unlock(); socklen_t addrlen = sizeof(paddr); + buf.resize(4096); ssize_t ret = recvfrom(sock, buf.data(), buf.size(), 0, (sockaddr *) &paddr, &addrlen); if (ret < 0) throw std::system_error(errno, std::generic_category()); + buf.resize(ret); auto & peer = getPeer(paddr); + + current = &buf; + if (holds_alternative>(peer.channel)) { + if (auto dec = std::get>(peer.channel)->decrypt(buf)) { + decrypted = std::move(*dec); + current = &decrypted; + } + } else if (holds_alternative>(peer.channel)) { + if (auto dec = std::get>(peer.channel)-> + data->channel()->decrypt(buf)) { + decrypted = std::move(*dec); + current = &decrypted; + } + } + if (auto dec = PartialObject::decodePrefix(peer.partStorage, - buf.begin(), buf.begin() + ret)) { + current->begin(), current->end())) { if (auto header = TransportHeader::load(std::get(*dec))) { auto pos = std::get<1>(*dec); while (auto cdec = PartialObject::decodePrefix(peer.partStorage, - pos, buf.begin() + ret)) { + pos, current->end())) { peer.partStorage.storeObject(std::get(*cdec)); pos = std::get<1>(*cdec); } + + ReplyBuilder reply; + scoped_lock hlock(dataMutex); - handlePacket(peer, *header); - peer.updateIdentity(); + handlePacket(peer, *header, reply); + peer.updateIdentity(reply); + peer.updateChannel(reply); + + if (!reply.header.empty()) + peer.send(TransportHeader(reply.header), reply.body); } + } else { + std::cerr << "invalid packet\n"; } lock.lock(); @@ -140,7 +167,7 @@ void Server::Priv::doAnnounce() } } -Peer & Server::Priv::getPeer(const sockaddr_in & paddr) +Server::Peer & Server::Priv::getPeer(const sockaddr_in & paddr) { for (auto & peer : peers) if (memcmp(&peer->addr, &paddr, sizeof paddr) == 0) @@ -148,9 +175,10 @@ Peer & Server::Priv::getPeer(const sockaddr_in & paddr) auto st = self.ref()->storage().deriveEphemeralStorage(); Peer * peer = new Peer { - .sock = sock, + .server = *this, .addr = paddr, .identity = monostate(), + .channel = monostate(), .tempStorage = st, .partStorage = st.derivePartialStorage(), }; @@ -158,18 +186,21 @@ Peer & Server::Priv::getPeer(const sockaddr_in & paddr) return *peer; } -void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header) +void Server::Priv::handlePacket(Server::Peer & peer, const TransportHeader & header, ReplyBuilder & reply) { unordered_set plaintextRefs; for (const auto & obj : collectStoredObjects(*Stored::load(*self.ref()))) plaintextRefs.insert(obj.ref.digest()); - vector replyHeaders; - vector replyBody; - for (auto & item : header.items) { switch (item.type) { case TransportHeader::Type::Acknowledged: + if (auto pref = std::get(item.value)) { + if (holds_alternative>(peer.channel) && + std::get>(peer.channel).ref.digest() == pref.digest()) + peer.channel.emplace> + (std::get>(peer.channel)->data->channel()); + } break; case TransportHeader::Type::DataRequest: { @@ -177,8 +208,8 @@ void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header) if (plaintextRefs.find(pref.digest()) != plaintextRefs.end()) { if (auto ref = peer.tempStorage.ref(pref.digest())) { TransportHeader::Item hitem { TransportHeader::Type::DataResponse, *ref }; - replyHeaders.push_back({ TransportHeader::Type::DataResponse, *ref }); - replyBody.push_back(**ref); + reply.header.push_back({ TransportHeader::Type::DataResponse, *ref }); + reply.body.push_back(**ref); } } break; @@ -186,12 +217,12 @@ void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header) case TransportHeader::Type::DataResponse: if (auto pref = std::get(item.value)) { - replyHeaders.push_back({ TransportHeader::Type::Acknowledged, pref }); + reply.header.push_back({ TransportHeader::Type::Acknowledged, pref }); for (auto & pwref : waiting) { if (auto wref = pwref.lock()) { if (std::find(wref->missing.begin(), wref->missing.end(), pref.digest()) != wref->missing.end()) { - if (wref->check(&replyHeaders)) + if (wref->check(&reply.header)) pwref.reset(); } } @@ -207,7 +238,7 @@ void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header) break; if (holds_alternative(peer.identity)) - replyHeaders.push_back({ TransportHeader::Type::AnnounceSelf, *self.ref()}); + reply.header.push_back({ TransportHeader::Type::AnnounceSelf, *self.ref()}); shared_ptr wref(new WaitingRef { .storage = peer.tempStorage, @@ -217,7 +248,7 @@ void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header) }); waiting.push_back(wref); peer.identity = wref; - wref->check(&replyHeaders); + wref->check(&reply.header); break; } @@ -225,9 +256,42 @@ void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header) break; case TransportHeader::Type::ChannelRequest: + if (auto pref = std::get(item.value)) { + reply.header.push_back({ TransportHeader::Type::Acknowledged, pref }); + + if (holds_alternative>(peer.channel) && + std::get>(peer.channel).ref.digest() < pref.digest()) + break; + + if (holds_alternative>(peer.channel)) + break; + + shared_ptr wref(new WaitingRef { + .storage = peer.tempStorage, + .ref = pref, + .peer = peer, + .missing = {}, + }); + waiting.push_back(wref); + peer.channel = wref; + wref->check(&reply.header); + } break; case TransportHeader::Type::ChannelAccept: + if (auto pref = std::get(item.value)) { + if (holds_alternative>(peer.channel) && + std::get>(peer.channel).ref.digest() < pref.digest()) + break; + + auto cres = peer.tempStorage.copy(pref); + if (auto r = std::get_if(&cres)) { + if (auto acc = ChannelAccept::load(*r)) { + reply.header.push_back({ TransportHeader::Type::Acknowledged, pref }); + peer.channel.emplace>(acc->data->channel()); + } + } + } break; case TransportHeader::Type::ServiceType: @@ -238,14 +302,11 @@ void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header) } } - - if (!replyHeaders.empty()) - peer.send(TransportHeader(replyHeaders), replyBody); } -void Peer::send(const TransportHeader & header, const vector & objs) const +void Server::Peer::send(const TransportHeader & header, const vector & objs) const { - vector data, part; + vector data, part, out; part = header.toObject().encode(); data.insert(data.end(), part.begin(), part.end()); @@ -254,18 +315,60 @@ void Peer::send(const TransportHeader & header, const vector & objs) con data.insert(data.end(), part.begin(), part.end()); } - sendto(sock, data.data(), data.size(), 0, + if (holds_alternative>(channel)) + out = std::get>(channel)->encrypt(data); + else + out = std::move(data); + + sendto(server.sock, out.data(), out.size(), 0, (sockaddr *) &addr, sizeof(addr)); } -void Peer::updateIdentity() +void Server::Peer::updateIdentity(ReplyBuilder & reply) { if (holds_alternative>(identity)) - if (auto ref = std::get>(identity)->check()) + if (auto ref = std::get>(identity)->check(&reply.header)) if (auto id = Identity::load(*ref)) identity.emplace(*id); } +void Server::Peer::updateChannel(ReplyBuilder & reply) +{ + if (!holds_alternative(identity)) + return; + + if (holds_alternative(channel)) { + auto req = Channel::generateRequest(tempStorage, + server.self, std::get(identity)); + channel.emplace>(req); + reply.header.push_back({ TransportHeader::Type::ChannelRequest, req.ref }); + reply.body.push_back(*req.ref); + reply.body.push_back(*req->data.ref); + reply.body.push_back(*req->data->key.ref); + for (const auto & sig : req->sigs) + reply.body.push_back(*sig.ref); + } + + if (holds_alternative>(channel)) { + if (auto ref = std::get>(channel)->check(&reply.header)) { + if (auto req = Stored::load(*ref)) { + if (auto acc = Channel::acceptRequest(server.self, std::get(identity), *req)) { + channel.emplace>(*acc); + reply.header.push_back({ TransportHeader::Type::ChannelAccept, acc->ref }); + reply.body.push_back(*acc->ref); + reply.body.push_back(*acc.value()->data.ref); + reply.body.push_back(*acc.value()->data->key.ref); + for (const auto & sig : acc.value()->sigs) + reply.body.push_back(*sig.ref); + } else { + channel = monostate(); + } + } else { + channel = monostate(); + } + } + } +} optional WaitingRef::check(vector * request) { diff --git a/src/network.h b/src/network.h index bb32323..e07e020 100644 --- a/src/network.h +++ b/src/network.h @@ -2,6 +2,8 @@ #include +#include "channel.h" + #include #include #include @@ -26,23 +28,30 @@ using chrono::steady_clock; namespace erebos { -struct Peer +struct Server::Peer { Peer(const Peer &) = delete; Peer & operator=(const Peer &) = delete; - const int sock; + Priv & server; const sockaddr_in addr; variant, Identity> identity; + variant, + shared_ptr, + Stored, + Stored> channel; + Storage tempStorage; PartialStorage partStorage; void send(const struct TransportHeader &, const vector &) const; - void updateIdentity(); + void updateIdentity(struct ReplyBuilder &); + void updateChannel(struct ReplyBuilder &); }; struct TransportHeader @@ -76,12 +85,18 @@ struct WaitingRef { const Storage storage; const PartialRef ref; - const Peer & peer; + const Server::Peer & peer; vector missing; optional check(vector * request = nullptr); }; +struct ReplyBuilder +{ + vector header; + vector body; +}; + struct Server::Priv { Priv(const Identity & self); @@ -90,7 +105,7 @@ struct Server::Priv void doAnnounce(); Peer & getPeer(const sockaddr_in & paddr); - void handlePacket(Peer &, const TransportHeader &); + void handlePacket(Peer &, const TransportHeader &, ReplyBuilder &); constexpr static uint16_t discoveryPort { 29665 }; constexpr static chrono::seconds announceInterval { 60 }; diff --git a/src/pubkey.cpp b/src/pubkey.cpp index e26bead..3a08c70 100644 --- a/src/pubkey.cpp +++ b/src/pubkey.cpp @@ -152,3 +152,107 @@ bool Signature::verify(const Ref & ref) const return EVP_DigestVerify(mdctx.get(), sig.data(), sig.size(), ref.digest().arr().data(), Digest::size) == 1; } + + +optional PublicKexKey::load(const Ref & ref) +{ + auto rec = ref->asRecord(); + if (!rec) + return nullopt; + + if (auto ktype = rec->item("type").asText()) + if (ktype.value() != "x25519") + throw runtime_error("unsupported key type " + ktype.value()); + + if (auto pubkey = rec->item("pubkey").asBinary()) + return PublicKexKey(EVP_PKEY_new_raw_public_key(EVP_PKEY_X25519, nullptr, + pubkey.value().data(), pubkey.value().size())); + + return nullopt; +} + +Ref PublicKexKey::store(const Storage & st) const +{ + vector items; + + items.emplace_back("type", "x25519"); + + vector keyData; + size_t keyLen; + EVP_PKEY_get_raw_public_key(key.get(), nullptr, &keyLen); + keyData.resize(keyLen); + EVP_PKEY_get_raw_public_key(key.get(), keyData.data(), &keyLen); + items.emplace_back("pubkey", keyData); + + return st.storeObject(Record(std::move(items))); +} + +SecretKexKey SecretKexKey::generate(const Storage & st) +{ + unique_ptr + pctx(EVP_PKEY_CTX_new_id(EVP_PKEY_X25519, NULL), &EVP_PKEY_CTX_free); + if (!pctx) + throw runtime_error("failed to generate key"); + + if (EVP_PKEY_keygen_init(pctx.get()) != 1) + throw runtime_error("failed to generate key"); + + EVP_PKEY *pkey = NULL; + if (EVP_PKEY_keygen(pctx.get(), &pkey) != 1) + throw runtime_error("failed to generate key"); + shared_ptr seckey(pkey, EVP_PKEY_free); + + vector keyData; + size_t keyLen; + + EVP_PKEY_get_raw_public_key(seckey.get(), nullptr, &keyLen); + keyData.resize(keyLen); + EVP_PKEY_get_raw_public_key(seckey.get(), keyData.data(), &keyLen); + auto pubkey = st.store(PublicKexKey(EVP_PKEY_new_raw_public_key(EVP_PKEY_X25519, nullptr, + keyData.data(), keyData.size()))); + + EVP_PKEY_get_raw_private_key(seckey.get(), nullptr, &keyLen); + keyData.resize(keyLen); + EVP_PKEY_get_raw_private_key(seckey.get(), keyData.data(), &keyLen); + st.storeKey(pubkey.ref, keyData); + + return SecretKexKey(std::move(seckey), pubkey); +} + +optional SecretKexKey::load(const Stored & pub) +{ + auto keyData = pub.ref.storage().loadKey(pub.ref); + if (!keyData) + return nullopt; + + EVP_PKEY * key = EVP_PKEY_new_raw_private_key(EVP_PKEY_X25519, nullptr, + keyData->data(), keyData->size()); + if (!key) + throw runtime_error("falied to parse secret key"); + return SecretKexKey(key, pub); +} + +vector SecretKexKey::dh(const PublicKexKey & pubkey) const +{ + unique_ptr + pctx(EVP_PKEY_CTX_new(key.get(), nullptr), &EVP_PKEY_CTX_free); + if (!pctx) + throw runtime_error("failed to derive shared secret"); + + if (EVP_PKEY_derive_init(pctx.get()) <= 0) + throw runtime_error("failed to derive shared secret"); + + if (EVP_PKEY_derive_set_peer(pctx.get(), pubkey.key.get()) <= 0) + throw runtime_error("failed to derive shared secret"); + + size_t dhlen; + if (EVP_PKEY_derive(pctx.get(), NULL, &dhlen) <= 0) + throw runtime_error("failed to derive shared secret"); + + vector dhsecret(dhlen); + + if (EVP_PKEY_derive(pctx.get(), dhsecret.data(), &dhlen) <= 0) + throw runtime_error("failed to derive shared secret"); + + return dhsecret; +} diff --git a/src/pubkey.h b/src/pubkey.h index 80da3fa..b14743d 100644 --- a/src/pubkey.h +++ b/src/pubkey.h @@ -133,4 +133,35 @@ bool Signed::isSignedBy(const Stored & key) const return false; } + +class PublicKexKey +{ + PublicKexKey(EVP_PKEY * key): + key(key, EVP_PKEY_free) {} + friend class SecretKexKey; +public: + static optional load(const Ref &); + Ref store(const Storage &) const; + + const shared_ptr key; +}; + +class SecretKexKey +{ + SecretKexKey(EVP_PKEY * key, const Stored & pub): + key(key, EVP_PKEY_free), pub_(pub) {} + SecretKexKey(shared_ptr && key, const Stored & pub): + key(key), pub_(pub) {} +public: + static SecretKexKey generate(const Storage & st); + static optional load(const Stored & st); + + Stored pub() const { return pub_; } + vector dh(const PublicKexKey &) const; + +private: + const shared_ptr key; + Stored pub_; +}; + } -- cgit v1.2.3