diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/CMakeLists.txt | 1 | ||||
-rw-r--r-- | src/channel.cpp | 243 | ||||
-rw-r--r-- | src/channel.h | 58 | ||||
-rw-r--r-- | src/identity.cpp | 18 | ||||
-rw-r--r-- | src/identity.h | 6 | ||||
-rw-r--r-- | src/network.cpp | 153 | ||||
-rw-r--r-- | src/network.h | 25 | ||||
-rw-r--r-- | src/pubkey.cpp | 104 | ||||
-rw-r--r-- | src/pubkey.h | 31 |
9 files changed, 604 insertions, 35 deletions
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 75eff66..a50bf66 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -3,6 +3,7 @@ include_directories( ) add_library(erebos + channel identity network pubkey diff --git a/src/channel.cpp b/src/channel.cpp new file mode 100644 index 0000000..38d263e --- /dev/null +++ b/src/channel.cpp @@ -0,0 +1,243 @@ +#include "channel.h" + +#include <algorithm> +#include <stdexcept> + +#include <openssl/rand.h> + +using std::remove_const; +using std::runtime_error; + +using namespace erebos; + +Ref ChannelRequestData::store(const Storage & st) const +{ + vector<Record::Item> items; + + for (const auto p : peers) + items.emplace_back("peer", p); + items.emplace_back("enc", "aes-128-gcm"); + items.emplace_back("key", key); + + return st.storeObject(Record(std::move(items))); +} + +optional<ChannelRequestData> ChannelRequestData::load(const Ref & ref) +{ + auto rec = ref->asRecord(); + if (!rec) + return nullopt; + + remove_const<decltype(peers)>::type peers; + for (const auto & i : rec->items("peer")) + if (auto p = i.as<Signed<IdentityData>>()) + peers.push_back(*p); + + auto enc = rec->item("enc").asText(); + if (!enc || enc != "aes-128-gcm") + return nullopt; + + auto key = rec->item("key").as<PublicKexKey>(); + if (!key) + return nullopt; + + return ChannelRequestData { + .peers = std::move(peers), + .key = *key, + }; +} + +Ref ChannelAcceptData::store(const Storage & st) const +{ + vector<Record::Item> items; + + items.emplace_back("req", request); + items.emplace_back("enc", "aes-128-gcm"); + items.emplace_back("key", key); + + return st.storeObject(Record(std::move(items))); +} + +optional<ChannelAcceptData> ChannelAcceptData::load(const Ref & ref) +{ + auto rec = ref->asRecord(); + if (!rec) + return nullopt; + + auto request = rec->item("req").as<ChannelRequest>(); + if (!request) + return nullopt; + + auto enc = rec->item("enc").asText(); + if (!enc || enc != "aes-128-gcm") + return nullopt; + + auto key = rec->item("key").as<PublicKexKey>(); + if (!key) + return nullopt; + + return ChannelAcceptData { + .request = *request, + .key = *key, + }; +} + +Stored<Channel> ChannelAcceptData::channel() const +{ + const auto & st = request.ref.storage(); + + if (auto secret = SecretKexKey::load(key)) + return st.store(Channel( + request->data->peers, + secret->dh(*request->data->key) + )); + + if (auto secret = SecretKexKey::load(request->data->key)) + return st.store(Channel( + request->data->peers, + secret->dh(*key) + )); + + throw runtime_error("failed to load secret DH key"); +} + + +Ref Channel::store(const Storage & st) const +{ + vector<Record::Item> items; + + for (const auto p : peers) + items.emplace_back("peer", p); + items.emplace_back("enc", "aes-128-gcm"); + items.emplace_back("key", key); + + return st.storeObject(Record(std::move(items))); +} + +optional<Channel> Channel::load(const Ref & ref) +{ + auto rec = ref->asRecord(); + if (!rec) + return nullopt; + + remove_const<decltype(peers)>::type peers; + for (const auto & i : rec->items("peer")) + if (auto p = i.as<Signed<IdentityData>>()) + peers.push_back(*p); + + auto enc = rec->item("enc").asText(); + if (!enc || enc != "aes-128-gcm") + return nullopt; + + auto key = rec->item("key").asBinary(); + if (!key) + return nullopt; + + return Channel(peers, std::move(*key)); +} + +Stored<ChannelRequest> Channel::generateRequest(const Storage & st, + const Identity & self, const Identity & peer) +{ + auto signKey = SecretKey::load(self.keyMessage()); + if (!signKey) + throw runtime_error("failed to load own message key"); + + return signKey->sign(st.store(ChannelRequestData { + .peers = self.ref()->digest() < peer.ref()->digest() ? + vector<Stored<Signed<IdentityData>>> { + *Stored<Signed<IdentityData>>::load(*self.ref()), + *Stored<Signed<IdentityData>>::load(*peer.ref()), + } : + vector<Stored<Signed<IdentityData>>> { + *Stored<Signed<IdentityData>>::load(*peer.ref()), + *Stored<Signed<IdentityData>>::load(*self.ref()), + }, + .key = SecretKexKey::generate(st).pub(), + })); +} + +optional<Stored<ChannelAccept>> Channel::acceptRequest(const Identity & self, + const Identity & peer, const Stored<ChannelRequest> & request) +{ + if (!request->isSignedBy(peer.keyMessage())) + return nullopt; + + auto & peers = request->data->peers; + if (peers.size() != 2 || + std::none_of(peers.begin(), peers.end(), [&self](const auto & x) + { return x.ref.digest() == self.ref()->digest(); }) || + std::none_of(peers.begin(), peers.end(), [&peer](const auto & x) + { return x.ref.digest() == peer.ref()->digest(); })) + return nullopt; + + auto & st = request.ref.storage(); + + auto signKey = SecretKey::load(self.keyMessage()); + if (!signKey) + throw runtime_error("failed to load own message key"); + + return signKey->sign(st.store(ChannelAcceptData { + .request = request, + .key = SecretKexKey::generate(st).pub(), + })); +} + +vector<uint8_t> Channel::encrypt(const vector<uint8_t> & plain) const +{ + vector<uint8_t> res(plain.size() + 12 + 16 + 16); + + if (RAND_bytes(res.data(), 12) != 1) + throw runtime_error("failed to generate random IV"); + + const unique_ptr<EVP_CIPHER_CTX, void(*)(EVP_CIPHER_CTX*)> + ctx(EVP_CIPHER_CTX_new(), EVP_CIPHER_CTX_free); + EVP_EncryptInit_ex(ctx.get(), EVP_aes_128_gcm(), + nullptr, key.data(), res.data()); + + int outl = 0; + uint8_t * cur = res.data() + 12; + + if (EVP_EncryptUpdate(ctx.get(), cur, &outl, plain.data(), plain.size()) != 1) + throw runtime_error("failed to encrypt data"); + cur += outl; + + if (EVP_EncryptFinal(ctx.get(), cur, &outl) != 1) + throw runtime_error("failed to encrypt data"); + cur += outl; + + EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_GET_TAG, 16, cur); + cur += 16; + + res.resize(cur - res.data()); + return res; +} + +optional<vector<uint8_t>> Channel::decrypt(const vector<uint8_t> & ctext) const +{ + vector<uint8_t> res(ctext.size()); + + const unique_ptr<EVP_CIPHER_CTX, void(*)(EVP_CIPHER_CTX*)> + ctx(EVP_CIPHER_CTX_new(), EVP_CIPHER_CTX_free); + EVP_DecryptInit_ex(ctx.get(), EVP_aes_128_gcm(), + nullptr, key.data(), ctext.data()); + + int outl = 0; + uint8_t * cur = res.data(); + + if (EVP_DecryptUpdate(ctx.get(), cur, &outl, + ctext.data() + 12, ctext.size() - 12 - 16) != 1) + return nullopt; + cur += outl; + + if (!EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_SET_TAG, 16, + (void *) (ctext.data() + ctext.size() - 16))) + return nullopt; + + if (EVP_DecryptFinal_ex(ctx.get(), cur, &outl) != 1) + return nullopt; + cur += outl; + + res.resize(cur - res.data()); + return res; +} diff --git a/src/channel.h b/src/channel.h new file mode 100644 index 0000000..100003c --- /dev/null +++ b/src/channel.h @@ -0,0 +1,58 @@ +#pragma once + +#include <erebos/storage.h> + +#include "identity.h" + +namespace erebos { + +struct ChannelRequestData +{ + Ref store(const Storage & st) const; + static optional<ChannelRequestData> load(const Ref &); + + const vector<Stored<Signed<IdentityData>>> peers; + const Stored<PublicKexKey> key; +}; + +typedef Signed<ChannelRequestData> ChannelRequest; + +struct ChannelAcceptData +{ + Ref store(const Storage & st) const; + static optional<ChannelAcceptData> load(const Ref &); + + Stored<class Channel> channel() const; + + const Stored<ChannelRequest> request; + const Stored<PublicKexKey> key; +}; + +typedef Signed<ChannelAcceptData> ChannelAccept; + +class Channel +{ +public: + Channel(const vector<Stored<Signed<IdentityData>>> & peers, + vector<uint8_t> && key): + peers(peers), + key(std::move(key)) + {} + + Ref store(const Storage & st) const; + static optional<Channel> load(const Ref &); + + static Stored<ChannelRequest> generateRequest(const Storage &, + const Identity & self, const Identity & peer); + static optional<Stored<ChannelAccept>> acceptRequest(const Identity & self, + const Identity & peer, const Stored<ChannelRequest> & request); + + vector<uint8_t> encrypt(const vector<uint8_t> &) const; + optional<vector<uint8_t>> decrypt(const vector<uint8_t> &) const; + +private: + const vector<Stored<Signed<IdentityData>>> peers; + const vector<uint8_t> key; +}; + +} diff --git a/src/identity.cpp b/src/identity.cpp index cee2688..57f25cc 100644 --- a/src/identity.cpp +++ b/src/identity.cpp @@ -43,6 +43,11 @@ optional<Identity> Identity::owner() const return p->owner; } +Stored<PublicKey> Identity::keyMessage() const +{ + return p->keyMessage; +} + optional<Ref> Identity::ref() const { if (p->data.size() == 1) @@ -167,14 +172,20 @@ shared_ptr<Identity::Priv> Identity::Priv::validate(const vector<Stored<Signed<I if (!verifySignatures(d)) return nullptr; + auto keyMessageItem = lookupProperty(sdata, [] + (const IdentityData & d) { return d.keyMessage.has_value(); }); + if (!keyMessageItem) + return nullptr; + auto p = new Priv { .data = sdata, .name = {}, .owner = nullopt, + .keyMessage = keyMessageItem.value()->keyMessage.value(), }; shared_ptr<Priv> ret(p); - auto ownerProp = p->lookupProperty([] + auto ownerProp = lookupProperty(sdata, [] (const IdentityData & d) { return d.owner.has_value(); }); if (ownerProp) { auto owner = validate({ *ownerProp.value()->owner }); @@ -184,7 +195,7 @@ shared_ptr<Identity::Priv> Identity::Priv::validate(const vector<Stored<Signed<I } p->name = async(std::launch::deferred, [p] () -> optional<string> { - if (auto d = p->lookupProperty([] (const IdentityData & d) { return d.name.has_value(); })) + if (auto d = lookupProperty(p->data, [] (const IdentityData & d) { return d.name.has_value(); })) return d.value()->name; return nullopt; }); @@ -193,7 +204,8 @@ shared_ptr<Identity::Priv> Identity::Priv::validate(const vector<Stored<Signed<I } optional<Stored<IdentityData>> Identity::Priv::lookupProperty( - function<bool(const IdentityData &)> sel) const + const vector<Stored<Signed<IdentityData>>> & data, + function<bool(const IdentityData &)> sel) { set<Stored<Signed<IdentityData>>> current, prop_heads; diff --git a/src/identity.h b/src/identity.h index 4335d32..79b335e 100644 --- a/src/identity.h +++ b/src/identity.h @@ -29,11 +29,13 @@ public: vector<Stored<Signed<IdentityData>>> data; shared_future<optional<string>> name; optional<Identity> owner; + Stored<PublicKey> keyMessage; static bool verifySignatures(const Stored<Signed<IdentityData>> & sdata); static shared_ptr<Priv> validate(const vector<Stored<Signed<IdentityData>>> & sdata); - optional<Stored<IdentityData>> lookupProperty( - function<bool(const IdentityData &)> sel) const; + static optional<Stored<IdentityData>> lookupProperty( + const vector<Stored<Signed<IdentityData>>> & data, + function<bool(const IdentityData &)> sel); }; class Identity::Builder::Priv diff --git a/src/network.cpp b/src/network.cpp index bd1ea8e..e9aeb3f 100644 --- a/src/network.cpp +++ b/src/network.cpp @@ -4,6 +4,7 @@ #include <algorithm> #include <cstring> +#include <iostream> #include <ifaddrs.h> #include <net/if.h> @@ -77,7 +78,7 @@ Server::Priv::~Priv() void Server::Priv::doListen() { - vector<uint8_t> buf(4096); + vector<uint8_t> buf, decrypted, *current; unique_lock<mutex> lock(dataMutex); while (!finish) { @@ -85,25 +86,51 @@ void Server::Priv::doListen() lock.unlock(); socklen_t addrlen = sizeof(paddr); + buf.resize(4096); ssize_t ret = recvfrom(sock, buf.data(), buf.size(), 0, (sockaddr *) &paddr, &addrlen); if (ret < 0) throw std::system_error(errno, std::generic_category()); + buf.resize(ret); auto & peer = getPeer(paddr); + + current = &buf; + if (holds_alternative<Stored<Channel>>(peer.channel)) { + if (auto dec = std::get<Stored<Channel>>(peer.channel)->decrypt(buf)) { + decrypted = std::move(*dec); + current = &decrypted; + } + } else if (holds_alternative<Stored<ChannelAccept>>(peer.channel)) { + if (auto dec = std::get<Stored<ChannelAccept>>(peer.channel)-> + data->channel()->decrypt(buf)) { + decrypted = std::move(*dec); + current = &decrypted; + } + } + if (auto dec = PartialObject::decodePrefix(peer.partStorage, - buf.begin(), buf.begin() + ret)) { + current->begin(), current->end())) { if (auto header = TransportHeader::load(std::get<PartialObject>(*dec))) { auto pos = std::get<1>(*dec); while (auto cdec = PartialObject::decodePrefix(peer.partStorage, - pos, buf.begin() + ret)) { + pos, current->end())) { peer.partStorage.storeObject(std::get<PartialObject>(*cdec)); pos = std::get<1>(*cdec); } + + ReplyBuilder reply; + scoped_lock<mutex> hlock(dataMutex); - handlePacket(peer, *header); - peer.updateIdentity(); + handlePacket(peer, *header, reply); + peer.updateIdentity(reply); + peer.updateChannel(reply); + + if (!reply.header.empty()) + peer.send(TransportHeader(reply.header), reply.body); } + } else { + std::cerr << "invalid packet\n"; } lock.lock(); @@ -140,7 +167,7 @@ void Server::Priv::doAnnounce() } } -Peer & Server::Priv::getPeer(const sockaddr_in & paddr) +Server::Peer & Server::Priv::getPeer(const sockaddr_in & paddr) { for (auto & peer : peers) if (memcmp(&peer->addr, &paddr, sizeof paddr) == 0) @@ -148,9 +175,10 @@ Peer & Server::Priv::getPeer(const sockaddr_in & paddr) auto st = self.ref()->storage().deriveEphemeralStorage(); Peer * peer = new Peer { - .sock = sock, + .server = *this, .addr = paddr, .identity = monostate(), + .channel = monostate(), .tempStorage = st, .partStorage = st.derivePartialStorage(), }; @@ -158,18 +186,21 @@ Peer & Server::Priv::getPeer(const sockaddr_in & paddr) return *peer; } -void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header) +void Server::Priv::handlePacket(Server::Peer & peer, const TransportHeader & header, ReplyBuilder & reply) { unordered_set<Digest> plaintextRefs; for (const auto & obj : collectStoredObjects(*Stored<Object>::load(*self.ref()))) plaintextRefs.insert(obj.ref.digest()); - vector<TransportHeader::Item> replyHeaders; - vector<Object> replyBody; - for (auto & item : header.items) { switch (item.type) { case TransportHeader::Type::Acknowledged: + if (auto pref = std::get<PartialRef>(item.value)) { + if (holds_alternative<Stored<ChannelAccept>>(peer.channel) && + std::get<Stored<ChannelAccept>>(peer.channel).ref.digest() == pref.digest()) + peer.channel.emplace<Stored<Channel>> + (std::get<Stored<ChannelAccept>>(peer.channel)->data->channel()); + } break; case TransportHeader::Type::DataRequest: { @@ -177,8 +208,8 @@ void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header) if (plaintextRefs.find(pref.digest()) != plaintextRefs.end()) { if (auto ref = peer.tempStorage.ref(pref.digest())) { TransportHeader::Item hitem { TransportHeader::Type::DataResponse, *ref }; - replyHeaders.push_back({ TransportHeader::Type::DataResponse, *ref }); - replyBody.push_back(**ref); + reply.header.push_back({ TransportHeader::Type::DataResponse, *ref }); + reply.body.push_back(**ref); } } break; @@ -186,12 +217,12 @@ void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header) case TransportHeader::Type::DataResponse: if (auto pref = std::get<PartialRef>(item.value)) { - replyHeaders.push_back({ TransportHeader::Type::Acknowledged, pref }); + reply.header.push_back({ TransportHeader::Type::Acknowledged, pref }); for (auto & pwref : waiting) { if (auto wref = pwref.lock()) { if (std::find(wref->missing.begin(), wref->missing.end(), pref.digest()) != wref->missing.end()) { - if (wref->check(&replyHeaders)) + if (wref->check(&reply.header)) pwref.reset(); } } @@ -207,7 +238,7 @@ void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header) break; if (holds_alternative<monostate>(peer.identity)) - replyHeaders.push_back({ TransportHeader::Type::AnnounceSelf, *self.ref()}); + reply.header.push_back({ TransportHeader::Type::AnnounceSelf, *self.ref()}); shared_ptr<WaitingRef> wref(new WaitingRef { .storage = peer.tempStorage, @@ -217,7 +248,7 @@ void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header) }); waiting.push_back(wref); peer.identity = wref; - wref->check(&replyHeaders); + wref->check(&reply.header); break; } @@ -225,9 +256,42 @@ void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header) break; case TransportHeader::Type::ChannelRequest: + if (auto pref = std::get<PartialRef>(item.value)) { + reply.header.push_back({ TransportHeader::Type::Acknowledged, pref }); + + if (holds_alternative<Stored<ChannelRequest>>(peer.channel) && + std::get<Stored<ChannelRequest>>(peer.channel).ref.digest() < pref.digest()) + break; + + if (holds_alternative<Stored<ChannelAccept>>(peer.channel)) + break; + + shared_ptr<WaitingRef> wref(new WaitingRef { + .storage = peer.tempStorage, + .ref = pref, + .peer = peer, + .missing = {}, + }); + waiting.push_back(wref); + peer.channel = wref; + wref->check(&reply.header); + } break; case TransportHeader::Type::ChannelAccept: + if (auto pref = std::get<PartialRef>(item.value)) { + if (holds_alternative<Stored<ChannelAccept>>(peer.channel) && + std::get<Stored<ChannelAccept>>(peer.channel).ref.digest() < pref.digest()) + break; + + auto cres = peer.tempStorage.copy(pref); + if (auto r = std::get_if<Ref>(&cres)) { + if (auto acc = ChannelAccept::load(*r)) { + reply.header.push_back({ TransportHeader::Type::Acknowledged, pref }); + peer.channel.emplace<Stored<Channel>>(acc->data->channel()); + } + } + } break; case TransportHeader::Type::ServiceType: @@ -238,14 +302,11 @@ void Server::Priv::handlePacket(Peer & peer, const TransportHeader & header) } } - - if (!replyHeaders.empty()) - peer.send(TransportHeader(replyHeaders), replyBody); } -void Peer::send(const TransportHeader & header, const vector<Object> & objs) const +void Server::Peer::send(const TransportHeader & header, const vector<Object> & objs) const { - vector<uint8_t> data, part; + vector<uint8_t> data, part, out; part = header.toObject().encode(); data.insert(data.end(), part.begin(), part.end()); @@ -254,18 +315,60 @@ void Peer::send(const TransportHeader & header, const vector<Object> & objs) con data.insert(data.end(), part.begin(), part.end()); } - sendto(sock, data.data(), data.size(), 0, + if (holds_alternative<Stored<Channel>>(channel)) + out = std::get<Stored<Channel>>(channel)->encrypt(data); + else + out = std::move(data); + + sendto(server.sock, out.data(), out.size(), 0, (sockaddr *) &addr, sizeof(addr)); } -void Peer::updateIdentity() +void Server::Peer::updateIdentity(ReplyBuilder & reply) { if (holds_alternative<shared_ptr<WaitingRef>>(identity)) - if (auto ref = std::get<shared_ptr<WaitingRef>>(identity)->check()) + if (auto ref = std::get<shared_ptr<WaitingRef>>(identity)->check(&reply.header)) if (auto id = Identity::load(*ref)) identity.emplace<Identity>(*id); } +void Server::Peer::updateChannel(ReplyBuilder & reply) +{ + if (!holds_alternative<Identity>(identity)) + return; + + if (holds_alternative<monostate>(channel)) { + auto req = Channel::generateRequest(tempStorage, + server.self, std::get<Identity>(identity)); + channel.emplace<Stored<ChannelRequest>>(req); + reply.header.push_back({ TransportHeader::Type::ChannelRequest, req.ref }); + reply.body.push_back(*req.ref); + reply.body.push_back(*req->data.ref); + reply.body.push_back(*req->data->key.ref); + for (const auto & sig : req->sigs) + reply.body.push_back(*sig.ref); + } + + if (holds_alternative<shared_ptr<WaitingRef>>(channel)) { + if (auto ref = std::get<shared_ptr<WaitingRef>>(channel)->check(&reply.header)) { + if (auto req = Stored<ChannelRequest>::load(*ref)) { + if (auto acc = Channel::acceptRequest(server.self, std::get<Identity>(identity), *req)) { + channel.emplace<Stored<ChannelAccept>>(*acc); + reply.header.push_back({ TransportHeader::Type::ChannelAccept, acc->ref }); + reply.body.push_back(*acc->ref); + reply.body.push_back(*acc.value()->data.ref); + reply.body.push_back(*acc.value()->data->key.ref); + for (const auto & sig : acc.value()->sigs) + reply.body.push_back(*sig.ref); + } else { + channel = monostate(); + } + } else { + channel = monostate(); + } + } + } +} optional<Ref> WaitingRef::check(vector<TransportHeader::Item> * request) { diff --git a/src/network.h b/src/network.h index bb32323..e07e020 100644 --- a/src/network.h +++ b/src/network.h @@ -2,6 +2,8 @@ #include <erebos/network.h> +#include "channel.h" + #include <condition_variable> #include <mutex> #include <thread> @@ -26,23 +28,30 @@ using chrono::steady_clock; namespace erebos { -struct Peer +struct Server::Peer { Peer(const Peer &) = delete; Peer & operator=(const Peer &) = delete; - const int sock; + Priv & server; const sockaddr_in addr; variant<monostate, shared_ptr<struct WaitingRef>, Identity> identity; + variant<monostate, + Stored<ChannelRequest>, + shared_ptr<struct WaitingRef>, + Stored<ChannelAccept>, + Stored<Channel>> channel; + Storage tempStorage; PartialStorage partStorage; void send(const struct TransportHeader &, const vector<Object> &) const; - void updateIdentity(); + void updateIdentity(struct ReplyBuilder &); + void updateChannel(struct ReplyBuilder &); }; struct TransportHeader @@ -76,12 +85,18 @@ struct WaitingRef { const Storage storage; const PartialRef ref; - const Peer & peer; + const Server::Peer & peer; vector<Digest> missing; optional<Ref> check(vector<TransportHeader::Item> * request = nullptr); }; +struct ReplyBuilder +{ + vector<TransportHeader::Item> header; + vector<Object> body; +}; + struct Server::Priv { Priv(const Identity & self); @@ -90,7 +105,7 @@ struct Server::Priv void doAnnounce(); Peer & getPeer(const sockaddr_in & paddr); - void handlePacket(Peer &, const TransportHeader &); + void handlePacket(Peer &, const TransportHeader &, ReplyBuilder &); constexpr static uint16_t discoveryPort { 29665 }; constexpr static chrono::seconds announceInterval { 60 }; diff --git a/src/pubkey.cpp b/src/pubkey.cpp index e26bead..3a08c70 100644 --- a/src/pubkey.cpp +++ b/src/pubkey.cpp @@ -152,3 +152,107 @@ bool Signature::verify(const Ref & ref) const return EVP_DigestVerify(mdctx.get(), sig.data(), sig.size(), ref.digest().arr().data(), Digest::size) == 1; } + + +optional<PublicKexKey> PublicKexKey::load(const Ref & ref) +{ + auto rec = ref->asRecord(); + if (!rec) + return nullopt; + + if (auto ktype = rec->item("type").asText()) + if (ktype.value() != "x25519") + throw runtime_error("unsupported key type " + ktype.value()); + + if (auto pubkey = rec->item("pubkey").asBinary()) + return PublicKexKey(EVP_PKEY_new_raw_public_key(EVP_PKEY_X25519, nullptr, + pubkey.value().data(), pubkey.value().size())); + + return nullopt; +} + +Ref PublicKexKey::store(const Storage & st) const +{ + vector<Record::Item> items; + + items.emplace_back("type", "x25519"); + + vector<uint8_t> keyData; + size_t keyLen; + EVP_PKEY_get_raw_public_key(key.get(), nullptr, &keyLen); + keyData.resize(keyLen); + EVP_PKEY_get_raw_public_key(key.get(), keyData.data(), &keyLen); + items.emplace_back("pubkey", keyData); + + return st.storeObject(Record(std::move(items))); +} + +SecretKexKey SecretKexKey::generate(const Storage & st) +{ + unique_ptr<EVP_PKEY_CTX, void(*)(EVP_PKEY_CTX*)> + pctx(EVP_PKEY_CTX_new_id(EVP_PKEY_X25519, NULL), &EVP_PKEY_CTX_free); + if (!pctx) + throw runtime_error("failed to generate key"); + + if (EVP_PKEY_keygen_init(pctx.get()) != 1) + throw runtime_error("failed to generate key"); + + EVP_PKEY *pkey = NULL; + if (EVP_PKEY_keygen(pctx.get(), &pkey) != 1) + throw runtime_error("failed to generate key"); + shared_ptr<EVP_PKEY> seckey(pkey, EVP_PKEY_free); + + vector<uint8_t> keyData; + size_t keyLen; + + EVP_PKEY_get_raw_public_key(seckey.get(), nullptr, &keyLen); + keyData.resize(keyLen); + EVP_PKEY_get_raw_public_key(seckey.get(), keyData.data(), &keyLen); + auto pubkey = st.store(PublicKexKey(EVP_PKEY_new_raw_public_key(EVP_PKEY_X25519, nullptr, + keyData.data(), keyData.size()))); + + EVP_PKEY_get_raw_private_key(seckey.get(), nullptr, &keyLen); + keyData.resize(keyLen); + EVP_PKEY_get_raw_private_key(seckey.get(), keyData.data(), &keyLen); + st.storeKey(pubkey.ref, keyData); + + return SecretKexKey(std::move(seckey), pubkey); +} + +optional<SecretKexKey> SecretKexKey::load(const Stored<PublicKexKey> & pub) +{ + auto keyData = pub.ref.storage().loadKey(pub.ref); + if (!keyData) + return nullopt; + + EVP_PKEY * key = EVP_PKEY_new_raw_private_key(EVP_PKEY_X25519, nullptr, + keyData->data(), keyData->size()); + if (!key) + throw runtime_error("falied to parse secret key"); + return SecretKexKey(key, pub); +} + +vector<uint8_t> SecretKexKey::dh(const PublicKexKey & pubkey) const +{ + unique_ptr<EVP_PKEY_CTX, void(*)(EVP_PKEY_CTX*)> + pctx(EVP_PKEY_CTX_new(key.get(), nullptr), &EVP_PKEY_CTX_free); + if (!pctx) + throw runtime_error("failed to derive shared secret"); + + if (EVP_PKEY_derive_init(pctx.get()) <= 0) + throw runtime_error("failed to derive shared secret"); + + if (EVP_PKEY_derive_set_peer(pctx.get(), pubkey.key.get()) <= 0) + throw runtime_error("failed to derive shared secret"); + + size_t dhlen; + if (EVP_PKEY_derive(pctx.get(), NULL, &dhlen) <= 0) + throw runtime_error("failed to derive shared secret"); + + vector<uint8_t> dhsecret(dhlen); + + if (EVP_PKEY_derive(pctx.get(), dhsecret.data(), &dhlen) <= 0) + throw runtime_error("failed to derive shared secret"); + + return dhsecret; +} diff --git a/src/pubkey.h b/src/pubkey.h index 80da3fa..b14743d 100644 --- a/src/pubkey.h +++ b/src/pubkey.h @@ -133,4 +133,35 @@ bool Signed<T>::isSignedBy(const Stored<PublicKey> & key) const return false; } + +class PublicKexKey +{ + PublicKexKey(EVP_PKEY * key): + key(key, EVP_PKEY_free) {} + friend class SecretKexKey; +public: + static optional<PublicKexKey> load(const Ref &); + Ref store(const Storage &) const; + + const shared_ptr<EVP_PKEY> key; +}; + +class SecretKexKey +{ + SecretKexKey(EVP_PKEY * key, const Stored<PublicKexKey> & pub): + key(key, EVP_PKEY_free), pub_(pub) {} + SecretKexKey(shared_ptr<EVP_PKEY> && key, const Stored<PublicKexKey> & pub): + key(key), pub_(pub) {} +public: + static SecretKexKey generate(const Storage & st); + static optional<SecretKexKey> load(const Stored<PublicKexKey> & st); + + Stored<PublicKexKey> pub() const { return pub_; } + vector<uint8_t> dh(const PublicKexKey &) const; + +private: + const shared_ptr<EVP_PKEY> key; + Stored<PublicKexKey> pub_; +}; + } |