diff options
-rw-r--r-- | include/erebos/storage.h | 76 | ||||
-rw-r--r-- | src/channel.cpp | 8 | ||||
-rw-r--r-- | src/identity.cpp | 12 | ||||
-rw-r--r-- | src/network.cpp | 28 | ||||
-rw-r--r-- | src/pubkey.cpp | 8 | ||||
-rw-r--r-- | src/pubkey.h | 6 | ||||
-rw-r--r-- | src/storage.cpp | 3 |
7 files changed, 89 insertions, 52 deletions
diff --git a/include/erebos/storage.h b/include/erebos/storage.h index b6595b6..4d7c691 100644 --- a/include/erebos/storage.h +++ b/include/erebos/storage.h @@ -1,5 +1,6 @@ #pragma once +#include <algorithm> #include <array> #include <chrono> #include <cstring> @@ -120,7 +121,9 @@ class PartialRef { public: PartialRef(const PartialRef &) = default; - PartialRef & operator=(const PartialRef &) = delete; + PartialRef(PartialRef &&) = default; + PartialRef & operator=(const PartialRef &) = default; + PartialRef & operator=(PartialRef &&) = default; static PartialRef create(PartialStorage, const Digest &); @@ -135,7 +138,7 @@ public: protected: friend class Storage; struct Priv; - const std::shared_ptr<const Priv> p; + std::shared_ptr<const Priv> p; PartialRef(const std::shared_ptr<const Priv> p): p(p) {} }; @@ -143,7 +146,9 @@ class Ref : public PartialRef { public: Ref(const Ref &) = default; - Ref & operator=(const Ref &) = delete; + Ref(Ref &&) = default; + Ref & operator=(const Ref &) = default; + Ref & operator=(Ref &&) = default; static std::optional<Ref> create(Storage, const Digest &); @@ -208,7 +213,7 @@ public: name(name), value(value) {} template<typename T> Item(const std::string & name, const Stored<T> & value): - Item(name, value.ref) {} + Item(name, value.ref()) {} Item(const Item &) = default; Item & operator=(const Item &) = delete; @@ -327,33 +332,42 @@ std::optional<Stored<T>> RecordT<S>::Item::as() const template<typename T> class Stored { - Stored(Ref ref, std::shared_ptr<T> val): ref(ref), val(val) {} + Stored(Ref ref, std::shared_ptr<T> val): mref(ref), mval(val) {} friend class Storage; public: + Stored(const Stored &) = default; + Stored(Stored &&) = default; + Stored & operator=(const Stored &) = default; + Stored & operator=(Stored &&) = default; + static std::optional<Stored<T>> load(const Ref &); Ref store(const Storage &) const; bool operator==(const Stored<T> & other) const - { return ref.digest() == other.ref.digest(); } + { return mref.digest() == other.mref.digest(); } bool operator!=(const Stored<T> & other) const - { return ref.digest() != other.ref.digest(); } + { return mref.digest() != other.mref.digest(); } bool operator<(const Stored<T> & other) const - { return ref.digest() < other.ref.digest(); } + { return mref.digest() < other.mref.digest(); } bool operator<=(const Stored<T> & other) const - { return ref.digest() <= other.ref.digest(); } + { return mref.digest() <= other.mref.digest(); } bool operator>(const Stored<T> & other) const - { return ref.digest() > other.ref.digest(); } + { return mref.digest() > other.mref.digest(); } bool operator>=(const Stored<T> & other) const - { return ref.digest() >= other.ref.digest(); } + { return mref.digest() >= other.mref.digest(); } - const T & operator*() const { return *val; } - const T * operator->() const { return val.get(); } + const T & operator*() const { return *mval; } + const T * operator->() const { return mval.get(); } std::vector<Stored<T>> previous() const; bool precedes(const Stored<T> &) const; - const Ref ref; - const std::shared_ptr<T> val; + const Ref & ref() const { return mref; } + const std::shared_ptr<T> & value() const { return mval; } + +private: + Ref mref; + std::shared_ptr<T> mval; }; template<typename T> @@ -373,15 +387,15 @@ std::optional<Stored<T>> Stored<T>::load(const Ref & ref) template<typename T> Ref Stored<T>::store(const Storage & st) const { - if (st == ref.storage()) - return ref; - return st.storeObject(*ref); + if (st == mref.storage()) + return mref; + return st.storeObject(*mref); } template<typename T> std::vector<Stored<T>> Stored<T>::previous() const { - auto rec = ref->asRecord(); + auto rec = mref->asRecord(); if (!rec) return {}; @@ -415,6 +429,30 @@ bool Stored<T>::precedes(const Stored<T> & other) const return false; } +template<typename T> +void filterAncestors(std::vector<Stored<T>> & xs) +{ + if (xs.size() < 2) + return; + + std::sort(xs.begin(), xs.end()); + xs.erase(std::unique(xs.begin(), xs.end()), xs.end()); + + std::vector<Stored<T>> old; + old.swap(xs); + + for (auto i = old.begin(); i != old.end(); i++) { + bool add = true; + for (auto j = i + 1; j != old.end(); j++) + if (i->precedes(*j)) { + add = false; + break; + } + if (add) + xs.push_back(std::move(*i)); + } +} + } namespace std diff --git a/src/channel.cpp b/src/channel.cpp index 38d263e..50c5f97 100644 --- a/src/channel.cpp +++ b/src/channel.cpp @@ -84,7 +84,7 @@ optional<ChannelAcceptData> ChannelAcceptData::load(const Ref & ref) Stored<Channel> ChannelAcceptData::channel() const { - const auto & st = request.ref.storage(); + const auto & st = request.ref().storage(); if (auto secret = SecretKexKey::load(key)) return st.store(Channel( @@ -166,12 +166,12 @@ optional<Stored<ChannelAccept>> Channel::acceptRequest(const Identity & self, auto & peers = request->data->peers; if (peers.size() != 2 || std::none_of(peers.begin(), peers.end(), [&self](const auto & x) - { return x.ref.digest() == self.ref()->digest(); }) || + { return x.ref().digest() == self.ref()->digest(); }) || std::none_of(peers.begin(), peers.end(), [&peer](const auto & x) - { return x.ref.digest() == peer.ref()->digest(); })) + { return x.ref().digest() == peer.ref()->digest(); })) return nullopt; - auto & st = request.ref.storage(); + auto & st = request.ref().storage(); auto signKey = SecretKey::load(self.keyMessage()); if (!signKey) diff --git a/src/identity.cpp b/src/identity.cpp index 61059ab..8f606ae 100644 --- a/src/identity.cpp +++ b/src/identity.cpp @@ -61,7 +61,7 @@ Stored<PublicKey> Identity::keyMessage() const optional<Ref> Identity::ref() const { if (p->data.size() == 1) - return p->data[0].ref; + return p->data[0].ref(); return nullopt; } @@ -77,7 +77,7 @@ Identity::Builder Identity::create(const Storage & st) Identity::Builder Identity::modify() const { return Builder (new Builder::Priv { - .storage = p->data[0].ref.storage(), + .storage = p->data[0].ref().storage(), .prev = p->data, .keyIdentity = p->data[0]->data->keyIdentity, .keyMessage = p->data[0]->data->keyMessage, @@ -146,14 +146,14 @@ Ref IdentityData::store(const Storage & st) const vector<Record::Item> items; for (const auto p : prev) - items.emplace_back("SPREV", p.ref); + items.emplace_back("SPREV", p.ref()); if (name) items.emplace_back("name", *name); if (owner) - items.emplace_back("owner", owner->ref); - items.emplace_back("key-id", keyIdentity.ref); + items.emplace_back("owner", owner->ref()); + items.emplace_back("key-id", keyIdentity.ref()); if (keyMessage) - items.emplace_back("key-msg", keyMessage->ref); + items.emplace_back("key-msg", keyMessage->ref()); return st.storeObject(Record(std::move(items))); } diff --git a/src/network.cpp b/src/network.cpp index ce0dd30..b31d949 100644 --- a/src/network.cpp +++ b/src/network.cpp @@ -299,7 +299,7 @@ void Server::Priv::handlePacket(Server::Peer & peer, const TransportHeader & hea { unordered_set<Digest> plaintextRefs; for (const auto & obj : collectStoredObjects(*Stored<Object>::load(*self.ref()))) - plaintextRefs.insert(obj.ref.digest()); + plaintextRefs.insert(obj.ref().digest()); optional<UUID> serviceType; @@ -308,7 +308,7 @@ void Server::Priv::handlePacket(Server::Peer & peer, const TransportHeader & hea case TransportHeader::Type::Acknowledged: if (auto pref = std::get<PartialRef>(item.value)) { if (holds_alternative<Stored<ChannelAccept>>(peer.channel) && - std::get<Stored<ChannelAccept>>(peer.channel).ref.digest() == pref.digest()) + std::get<Stored<ChannelAccept>>(peer.channel).ref().digest() == pref.digest()) peer.channel.emplace<Stored<Channel>> (std::get<Stored<ChannelAccept>>(peer.channel)->data->channel()); } @@ -371,7 +371,7 @@ void Server::Priv::handlePacket(Server::Peer & peer, const TransportHeader & hea reply.header({ TransportHeader::Type::Acknowledged, pref }); if (holds_alternative<Stored<ChannelRequest>>(peer.channel) && - std::get<Stored<ChannelRequest>>(peer.channel).ref.digest() < pref.digest()) + std::get<Stored<ChannelRequest>>(peer.channel).ref().digest() < pref.digest()) break; if (holds_alternative<Stored<ChannelAccept>>(peer.channel)) @@ -392,7 +392,7 @@ void Server::Priv::handlePacket(Server::Peer & peer, const TransportHeader & hea case TransportHeader::Type::ChannelAccept: if (auto pref = std::get<PartialRef>(item.value)) { if (holds_alternative<Stored<ChannelAccept>>(peer.channel) && - std::get<Stored<ChannelAccept>>(peer.channel).ref.digest() < pref.digest()) + std::get<Stored<ChannelAccept>>(peer.channel).ref().digest() < pref.digest()) break; auto cres = peer.tempStorage.copy(pref); @@ -475,12 +475,12 @@ void Server::Peer::updateChannel(ReplyBuilder & reply) auto req = Channel::generateRequest(tempStorage, server.self, std::get<Identity>(identity)); channel.emplace<Stored<ChannelRequest>>(req); - reply.header({ TransportHeader::Type::ChannelRequest, req.ref }); - reply.body(req.ref); - reply.body(req->data.ref); - reply.body(req->data->key.ref); + reply.header({ TransportHeader::Type::ChannelRequest, req.ref() }); + reply.body(req.ref()); + reply.body(req->data.ref()); + reply.body(req->data->key.ref()); for (const auto & sig : req->sigs) - reply.body(sig.ref); + reply.body(sig.ref()); } if (holds_alternative<shared_ptr<WaitingRef>>(channel)) { @@ -488,12 +488,12 @@ void Server::Peer::updateChannel(ReplyBuilder & reply) if (auto req = Stored<ChannelRequest>::load(*ref)) { if (auto acc = Channel::acceptRequest(server.self, std::get<Identity>(identity), *req)) { channel.emplace<Stored<ChannelAccept>>(*acc); - reply.header({ TransportHeader::Type::ChannelAccept, acc->ref }); - reply.body(acc->ref); - reply.body(acc.value()->data.ref); - reply.body(acc.value()->data->key.ref); + reply.header({ TransportHeader::Type::ChannelAccept, acc->ref() }); + reply.body(acc->ref()); + reply.body(acc.value()->data.ref()); + reply.body(acc.value()->data->key.ref()); for (const auto & sig : acc.value()->sigs) - reply.body(sig.ref); + reply.body(sig.ref()); } else { channel = monostate(); } diff --git a/src/pubkey.cpp b/src/pubkey.cpp index 6f6c1e7..0e83136 100644 --- a/src/pubkey.cpp +++ b/src/pubkey.cpp @@ -68,14 +68,14 @@ SecretKey SecretKey::generate(const Storage & st) EVP_PKEY_get_raw_private_key(seckey.get(), nullptr, &keyLen); keyData.resize(keyLen); EVP_PKEY_get_raw_private_key(seckey.get(), keyData.data(), &keyLen); - st.storeKey(pubkey.ref, keyData); + st.storeKey(pubkey.ref(), keyData); return SecretKey(std::move(seckey), pubkey); } optional<SecretKey> SecretKey::load(const Stored<PublicKey> & pub) { - auto keyData = pub.ref.storage().loadKey(pub.ref); + auto keyData = pub.ref().storage().loadKey(pub.ref()); if (!keyData) return nullopt; @@ -211,14 +211,14 @@ SecretKexKey SecretKexKey::generate(const Storage & st) EVP_PKEY_get_raw_private_key(seckey.get(), nullptr, &keyLen); keyData.resize(keyLen); EVP_PKEY_get_raw_private_key(seckey.get(), keyData.data(), &keyLen); - st.storeKey(pubkey.ref, keyData); + st.storeKey(pubkey.ref(), keyData); return SecretKexKey(std::move(seckey), pubkey); } optional<SecretKexKey> SecretKexKey::load(const Stored<PublicKexKey> & pub) { - auto keyData = pub.ref.storage().loadKey(pub.ref); + auto keyData = pub.ref().storage().loadKey(pub.ref()); if (!keyData) return nullopt; diff --git a/src/pubkey.h b/src/pubkey.h index 607352d..c922dc7 100644 --- a/src/pubkey.h +++ b/src/pubkey.h @@ -84,8 +84,8 @@ private: template<class T> Stored<Signed<T>> SecretKey::sign(const Stored<T> & val) const { - auto st = val.ref.storage(); - auto sig = st.store(Signature(pub(), sign(val.ref.digest()))); + auto st = val.ref().storage(); + auto sig = st.store(Signature(pub(), sign(val.ref().digest()))); return st.store(Signed(val, { sig })); } @@ -103,7 +103,7 @@ optional<Signed<T>> Signed<T>::load(const Ref & ref) vector<Stored<Signature>> sigs; for (auto item : rec->items("sig")) if (auto sig = item.as<Signature>()) - if (sig.value()->verify(data.value().ref)) + if (sig.value()->verify(data.value().ref())) sigs.push_back(sig.value()); return Signed(*data, sigs); diff --git a/src/storage.cpp b/src/storage.cpp index 525d83d..49bac54 100644 --- a/src/storage.cpp +++ b/src/storage.cpp @@ -1,7 +1,6 @@ #include "storage.h" #include "base64.h" -#include <algorithm> #include <charconv> #include <chrono> #include <fstream> @@ -943,7 +942,7 @@ vector<Stored<Object>> erebos::collectStoredObjects(const Stored<Object> & from) auto cur = queue.back(); queue.pop_back(); - auto [it, added] = seen.insert(cur.ref.digest()); + auto [it, added] = seen.insert(cur.ref().digest()); if (!added) continue; |