diff options
61 files changed, 10683 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5681188 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +build/ +.erebos/ +.test/ +.minici/ diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..c7f6b82 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,5 @@ +# Revision history for erebos-cpp + +## 0.1.0 -- 2024-02-18 + +* First version. diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..4a3727a --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,18 @@ +cmake_minimum_required(VERSION 3.10) +project(Erebos) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra") + +if (CMAKE_CXX_COMPILER_ID MATCHES "Clang") + add_compile_options(-Wno-unqualified-std-cast-call) +endif() + +find_package(Threads REQUIRED) +find_package(ZLIB REQUIRED) +find_package(OpenSSL REQUIRED) +find_library(B2_LIBRARY b2 REQUIRED) + +add_subdirectory(src) @@ -0,0 +1,165 @@ + GNU LESSER GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/> + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + + This version of the GNU Lesser General Public License incorporates +the terms and conditions of version 3 of the GNU General Public +License, supplemented by the additional permissions listed below. + + 0. Additional Definitions. + + As used herein, "this License" refers to version 3 of the GNU Lesser +General Public License, and the "GNU GPL" refers to version 3 of the GNU +General Public License. + + "The Library" refers to a covered work governed by this License, +other than an Application or a Combined Work as defined below. + + An "Application" is any work that makes use of an interface provided +by the Library, but which is not otherwise based on the Library. +Defining a subclass of a class defined by the Library is deemed a mode +of using an interface provided by the Library. + + A "Combined Work" is a work produced by combining or linking an +Application with the Library. The particular version of the Library +with which the Combined Work was made is also called the "Linked +Version". + + The "Minimal Corresponding Source" for a Combined Work means the +Corresponding Source for the Combined Work, excluding any source code +for portions of the Combined Work that, considered in isolation, are +based on the Application, and not on the Linked Version. + + The "Corresponding Application Code" for a Combined Work means the +object code and/or source code for the Application, including any data +and utility programs needed for reproducing the Combined Work from the +Application, but excluding the System Libraries of the Combined Work. + + 1. Exception to Section 3 of the GNU GPL. + + You may convey a covered work under sections 3 and 4 of this License +without being bound by section 3 of the GNU GPL. + + 2. Conveying Modified Versions. + + If you modify a copy of the Library, and, in your modifications, a +facility refers to a function or data to be supplied by an Application +that uses the facility (other than as an argument passed when the +facility is invoked), then you may convey a copy of the modified +version: + + a) under this License, provided that you make a good faith effort to + ensure that, in the event an Application does not supply the + function or data, the facility still operates, and performs + whatever part of its purpose remains meaningful, or + + b) under the GNU GPL, with none of the additional permissions of + this License applicable to that copy. + + 3. Object Code Incorporating Material from Library Header Files. + + The object code form of an Application may incorporate material from +a header file that is part of the Library. You may convey such object +code under terms of your choice, provided that, if the incorporated +material is not limited to numerical parameters, data structure +layouts and accessors, or small macros, inline functions and templates +(ten or fewer lines in length), you do both of the following: + + a) Give prominent notice with each copy of the object code that the + Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the object code with a copy of the GNU GPL and this license + document. + + 4. Combined Works. + + You may convey a Combined Work under terms of your choice that, +taken together, effectively do not restrict modification of the +portions of the Library contained in the Combined Work and reverse +engineering for debugging such modifications, if you also do each of +the following: + + a) Give prominent notice with each copy of the Combined Work that + the Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the Combined Work with a copy of the GNU GPL and this license + document. + + c) For a Combined Work that displays copyright notices during + execution, include the copyright notice for the Library among + these notices, as well as a reference directing the user to the + copies of the GNU GPL and this license document. + + d) Do one of the following: + + 0) Convey the Minimal Corresponding Source under the terms of this + License, and the Corresponding Application Code in a form + suitable for, and under terms that permit, the user to + recombine or relink the Application with a modified version of + the Linked Version to produce a modified Combined Work, in the + manner specified by section 6 of the GNU GPL for conveying + Corresponding Source. + + 1) Use a suitable shared library mechanism for linking with the + Library. A suitable mechanism is one that (a) uses at run time + a copy of the Library already present on the user's computer + system, and (b) will operate properly with a modified version + of the Library that is interface-compatible with the Linked + Version. + + e) Provide Installation Information, but only if you would otherwise + be required to provide such information under section 6 of the + GNU GPL, and only to the extent that such information is + necessary to install and execute a modified version of the + Combined Work produced by recombining or relinking the + Application with a modified version of the Linked Version. (If + you use option 4d0, the Installation Information must accompany + the Minimal Corresponding Source and Corresponding Application + Code. If you use option 4d1, you must provide the Installation + Information in the manner specified by section 6 of the GNU GPL + for conveying Corresponding Source.) + + 5. Combined Libraries. + + You may place library facilities that are a work based on the +Library side by side in a single library together with other library +facilities that are not Applications and are not covered by this +License, and convey such a combined library under terms of your +choice, if you do both of the following: + + a) Accompany the combined library with a copy of the same work based + on the Library, uncombined with any other library facilities, + conveyed under the terms of this License. + + b) Give prominent notice with the combined library that part of it + is a work based on the Library, and explaining where to find the + accompanying uncombined form of the same work. + + 6. Revised Versions of the GNU Lesser General Public License. + + The Free Software Foundation may publish revised and/or new versions +of the GNU Lesser General Public License from time to time. Such new +versions will be similar in spirit to the present version, but may +differ in detail to address new problems or concerns. + + Each version is given a distinguishing version number. If the +Library as you received it specifies that a certain numbered version +of the GNU Lesser General Public License "or any later version" +applies to it, you have the option of following the terms and +conditions either of that published version or of any later version +published by the Free Software Foundation. If the Library as you +received it does not specify a version number of the GNU Lesser +General Public License, you may choose any version of the GNU Lesser +General Public License ever published by the Free Software Foundation. + + If the Library as you received it specifies that a proxy can decide +whether future versions of the GNU Lesser General Public License shall +apply, that proxy's public statement of acceptance of any version is +permanent authorization for you to choose that version for the +Library. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..acf2c19 --- /dev/null +++ b/Makefile @@ -0,0 +1,11 @@ +all: build/Makefile + +make -C build +.PHONY: all + +build/Makefile: + mkdir -p build + (cd build; cmake ..) + +clean: + rm -rf build +.PHONY: clean diff --git a/README.md b/README.md new file mode 100644 index 0000000..04ff4f2 --- /dev/null +++ b/README.md @@ -0,0 +1,39 @@ +Erebos/C++ +========== + +Library implementing the Erebos identity management, decentralized messaging +and synchronization protocol, along with local storage. Specification being +created at: + +[http://erebosprotocol.net](http://erebosprotocol.net) + +Erebos identity is based on locally stored cryptographic keys, all +communication is end-to-end encrypted. Multiple devices can be attached to the +same identity, after which they function interchangeably, without any one being +in any way "primary"; messages and other state data are then synchronized +automatically whenever the devices are able to connect with one another. + +Status +------ + +This is experimental implementation of yet unfinished specification, so +changes, especially in the library API, are expected. Storage format and +network protocol should generally remain backward compatible, with their +respective versions to be increased in case of incompatible changes, to allow +for interoperability even in that case. + +Build +----- + +This library uses CMake for building: + +``` +cmake -B build +cmake --build build +``` + +Usage +----- + +The API is currently experimental and without documentation; some example of +usage can be found in the test executable (`src/main.cpp`). diff --git a/erebos-tester.yaml b/erebos-tester.yaml new file mode 100644 index 0000000..ab77de7 --- /dev/null +++ b/erebos-tester.yaml @@ -0,0 +1,2 @@ +tool: build/src/erebos +tests: test/**/*.test diff --git a/include/erebos/attach.h b/include/erebos/attach.h new file mode 100644 index 0000000..6ea3d64 --- /dev/null +++ b/include/erebos/attach.h @@ -0,0 +1,49 @@ +#pragma once + +#include <erebos/pairing.h> + +#include <future> +#include <mutex> +#include <optional> +#include <string> +#include <vector> + +namespace erebos { + +using std::mutex; +using std::optional; +using std::promise; +using std::string; +using std::vector; + +struct AttachIdentity; + +class AttachService : public PairingService<AttachIdentity> +{ +public: + AttachService(Config &&, const Server &); + virtual ~AttachService(); + + UUID uuid() const override; + + void attachTo(const Peer &); + +protected: + virtual Stored<AttachIdentity> handlePairingComplete(const Peer &) override; + virtual void handlePairingResult(Context &, Stored<AttachIdentity>) override; + + mutex handlerLock; +}; + +template<class T> class Signed; + +struct AttachIdentity +{ + Stored<Signed<struct IdentityData>> identity; + vector<vector<uint8_t>> keys; + + static AttachIdentity load(const Ref &); + Ref store(const Storage &) const; +}; + +} diff --git a/include/erebos/contact.h b/include/erebos/contact.h new file mode 100644 index 0000000..9008ce7 --- /dev/null +++ b/include/erebos/contact.h @@ -0,0 +1,98 @@ +#pragma once + +#include <erebos/identity.h> +#include <erebos/list.h> +#include <erebos/pairing.h> +#include <erebos/set.h> +#include <erebos/state.h> +#include <erebos/storage.h> + +#include <memory> +#include <optional> +#include <string> +#include <vector> + +namespace erebos { + +using std::optional; +using std::shared_ptr; +using std::string; +using std::vector; + +struct ContactData; + +class Contact +{ +public: + Contact(vector<Stored<ContactData>> data); + Contact(const Contact &) = default; + Contact(Contact &&) = default; + Contact & operator=(const Contact &) = default; + Contact & operator=(Contact &&) = default; + + optional<Identity> identity() const; + optional<string> customName() const; + Contact customName(const Storage & st, const string & name) const; + string name() const; + + bool operator==(const Contact &) const; + bool operator!=(const Contact &) const; + + vector<Stored<ContactData>> data() const; + Digest leastRoot() const; + +private: + struct Priv; + shared_ptr<Priv> p; + Contact(shared_ptr<Priv> p): p(p) {} + + friend class ContactService; +}; + +DECLARE_SHARED_TYPE(Set<Contact>) + +struct ContactData +{ + static ContactData load(const Ref &); + Ref store(const Storage &) const; + + vector<Stored<ContactData>> prev; + vector<StoredIdentityPart> identity; + optional<string> name; +}; + +template<> struct Mergeable<Contact> +{ + using Component = ContactData; + static vector<Stored<ContactData>> components(const Contact & c) { return c.data(); } + static Contact merge(vector<Stored<ContactData>> x) { return Contact(move(x)); } +}; + +struct ContactAccepted; + +class ContactService : public PairingService<ContactAccepted> +{ +public: + ContactService(Config &&, const Server &); + virtual ~ContactService(); + + UUID uuid() const override; + + void request(const Peer &); + +protected: + virtual Stored<ContactAccepted> handlePairingComplete(const Peer &) override; + virtual void handlePairingResult(Context &, Stored<ContactAccepted>) override; + + const Server & server; +}; + +template<class T> class Signed; + +struct ContactAccepted +{ + static ContactAccepted load(const Ref &); + Ref store(const Storage &) const; +}; + +} diff --git a/include/erebos/frp.h b/include/erebos/frp.h new file mode 100644 index 0000000..72b5cc9 --- /dev/null +++ b/include/erebos/frp.h @@ -0,0 +1,284 @@ +#pragma once + +#include <functional> +#include <memory> +#include <optional> +#include <functional> +#include <tuple> +#include <variant> +#include <vector> + +namespace erebos { + +using std::enable_if_t; +using std::function; +using std::is_same_v; +using std::make_shared; +using std::monostate; +using std::optional; +using std::shared_ptr; +using std::static_pointer_cast; +using std::vector; +using std::weak_ptr; + +class BhvCurTime; + +class BhvTime +{ + BhvTime(uint64_t t): t(t) {} + friend BhvCurTime; +public: + BhvTime(const BhvCurTime &); + + bool operator==(const BhvTime & other) const { return t == other.t; } + bool operator!=(const BhvTime & other) const { return t != other.t; } + bool operator<(const BhvTime & other) const { return t < other.t; } + bool operator<=(const BhvTime & other) const { return t <= other.t; } + bool operator>(const BhvTime & other) const { return t > other.t; } + bool operator>=(const BhvTime & other) const { return t >= other.t; } + +private: + uint64_t t; +}; + +class BhvCurTime +{ +public: + BhvCurTime(); + ~BhvCurTime(); + BhvCurTime(const BhvCurTime &) = delete; + BhvCurTime(BhvCurTime &&); + + BhvCurTime & operator=(const BhvCurTime &) = delete; + BhvCurTime & operator=(BhvCurTime &&); + + BhvTime time() const { return t.value(); } + +private: + optional<BhvTime> t; +}; + +template<typename T> +class Watched +{ +public: + Watched() = default; + Watched(const Watched<T> &) = default; + Watched & operator=(const Watched<T> &) = default; + Watched(Watched<T> &&) = default; + Watched & operator=(Watched<T> &&) = default; + + Watched(shared_ptr<function<void(const BhvCurTime &)>> && cb): + cb(move(cb)) {} + ~Watched(); + +private: + shared_ptr<function<void(const BhvCurTime &)>> cb; +}; + +template<typename T> +Watched<T>::~Watched() +{ + BhvCurTime ctime; + cb.reset(); +} + +class BhvImplBase : public std::enable_shared_from_this<BhvImplBase> +{ +public: + virtual ~BhvImplBase(); + +protected: + void dependsOn(const BhvCurTime &, shared_ptr<BhvImplBase> other); + void updated(const BhvCurTime &); + virtual bool needsUpdate(const BhvCurTime &) const; + virtual void doUpdate(const BhvCurTime &); + + bool isDirty(const BhvCurTime &) const { return dirty; } + + vector<weak_ptr<function<void(const BhvCurTime &)>>> watchers; +private: + void markDirty(const BhvCurTime &, vector<shared_ptr<BhvImplBase>> &); + void updateDirty(const BhvCurTime &); + + bool dirty = false; + vector<shared_ptr<BhvImplBase>> depends; + vector<weak_ptr<BhvImplBase>> rdepends; + + template<typename A, typename B> friend class BhvFun; +}; + +template<typename A, typename B> +class BhvImpl : public BhvImplBase +{ +public: + virtual B get(const BhvCurTime &, const A &) const = 0; +}; + +template<typename A> +using BhvSource = BhvImpl<monostate, A>; + +template<typename A, typename B> +class BhvFun +{ +public: + BhvFun(shared_ptr<BhvImpl<A, B>> impl): + impl(move(impl)) {} + + template<typename T> BhvFun(shared_ptr<T> impl): + BhvFun(static_pointer_cast<BhvImpl<A, B>>(impl)) {} + + B get(const A & x) const + { + BhvCurTime ctime; + return impl->get(ctime, x); + } + + template<typename C> BhvFun<A, C> lens() const; + + const shared_ptr<BhvImpl<A, B>> impl; +}; + +template<typename A> +class BhvFun<monostate, A> +{ +public: + BhvFun(shared_ptr<BhvSource<A>> impl): + impl(move(impl)) {} + + template<typename T> BhvFun(shared_ptr<T> impl): + BhvFun(static_pointer_cast<BhvSource<A>>(impl)) {} + + A get() const + { + BhvCurTime ctime; + return impl->get(ctime, monostate()); + } + Watched<A> watch(function<void(const A &)>); + + template<typename C> BhvFun<monostate, C> lens() const; + + const shared_ptr<BhvSource<A>> impl; +}; + +template<typename A> +using Bhv = BhvFun<monostate, A>; + +template<typename A> +Watched<A> Bhv<A>::watch(function<void(const A &)> f) +{ + BhvCurTime ctime; + auto & impl = BhvFun<monostate, A>::impl; + if (impl->needsUpdate(ctime)) + impl->doUpdate(ctime); + + auto cb = make_shared<function<void(const BhvCurTime &)>>( + [impl = BhvFun<monostate, A>::impl, f] (const BhvCurTime & ctime) { + f(impl->get(ctime, monostate())); + }); + + impl->watchers.push_back(cb); + f(impl->get(ctime, monostate())); + return Watched<A>(move(cb)); +} + + +template<typename A, typename B> +class BhvLambda : public BhvImpl<A, B> +{ +public: + BhvLambda(function<B(const A &)> f): f(f) {} + + B get(const BhvCurTime &, const A & x) const override + { return f(x); } + +private: + function<B(const A &)> f; +}; + +template<typename A, typename B> +BhvFun<A, B> bfun(function<B(const A &)> f) +{ + return make_shared<BhvLambda<A, B>>(f); +} + + +template<typename A, typename B, typename C> class BhvComp; +template<typename A, typename B, typename C> +BhvFun<A, C> operator>>(const BhvFun<A, B> & f, const BhvFun<B, C> & g); + +template<typename A, typename B, typename C> +class BhvComp : public BhvImpl<A, C> +{ +public: + BhvComp(const BhvFun<A, B> & f, const BhvFun<B, C>): + f(f), g(g) {} + + C get(const BhvCurTime & ctime, const A & x) const override + { return g.impl.get(ctime, f.impl.get(ctime, x)); } + +private: + BhvFun<A, B> f; + BhvFun<B, C> g; + + friend BhvFun<A, C> operator>> <A, B, C>(const BhvFun<A, B> &, const BhvFun<B, C> &); +}; + +template<typename B, typename C> +class BhvComp<monostate, B, C> : public BhvSource<C> +{ +public: + BhvComp(const BhvFun<monostate, B> & f, const BhvFun<B, C> & g): + f(f), g(g) {} + + bool needsUpdate(const BhvCurTime & ctime) const override + { return !x || g.impl->get(ctime, f.impl->get(ctime, monostate())) != x.value(); } + + void doUpdate(const BhvCurTime & ctime) override + { x = g.impl->get(ctime, f.impl->get(ctime, monostate())); } + + C get(const BhvCurTime & ctime, const monostate & m) const override + { return x && !BhvImplBase::isDirty(ctime) ? x.value() : g.impl->get(ctime, f.impl->get(ctime, m)); } + +private: + BhvFun<monostate, B> f; + BhvFun<B, C> g; + optional<C> x; + + friend BhvFun<monostate, C> operator>> <monostate, B, C>(const BhvFun<monostate, B> &, const BhvFun<B, C> &); +}; + +template<typename A, typename B, typename C> +BhvFun<A, C> operator>>(const BhvFun<A, B> & f, const BhvFun<B, C> & g) +{ + BhvCurTime ctime; + auto impl = make_shared<BhvComp<A, B, C>>(f, g); + impl->dependsOn(ctime, f.impl); + impl->dependsOn(ctime, g.impl); + return impl; +} + + +template<typename A, typename B> +class BhvLens : public BhvImpl<A, B> +{ +public: + B get(const BhvCurTime &, const A & x) const override + { return A::template lens<B>(x); } +}; + +template<typename A, typename B> +template<typename C> +BhvFun<A, C> BhvFun<A, B>::lens() const +{ + return *this >> BhvFun<B, C>(make_shared<BhvLens<B, C>>()); +} + +template<typename A> +template<typename C> +BhvFun<monostate, C> BhvFun<monostate, A>::lens() const +{ + return *this >> BhvFun<A, C>(make_shared<BhvLens<A, C>>()); +} + +} diff --git a/include/erebos/identity.h b/include/erebos/identity.h new file mode 100644 index 0000000..fa8fde3 --- /dev/null +++ b/include/erebos/identity.h @@ -0,0 +1,109 @@ +#pragma once + +#include <erebos/state.h> +#include <erebos/storage.h> + +namespace erebos { + +using std::optional; +using std::vector; + +template<class T> class Signed; +struct IdentityData; +struct StoredIdentityPart; + +class Identity +{ +public: + Identity(const Identity &) = default; + Identity(Identity &&) = default; + Identity & operator=(const Identity &) = default; + Identity & operator=(Identity &&) = default; + + static std::optional<Identity> load(const Ref &); + static std::optional<Identity> load(const std::vector<Ref> &); + static std::optional<Identity> load(const std::vector<Stored<Signed<IdentityData>>> &); + static std::optional<Identity> load(const std::vector<StoredIdentityPart> &); + std::vector<Ref> store() const; + std::vector<Ref> store(const Storage & st) const; + vector<Stored<Signed<IdentityData>>> data() const; + vector<StoredIdentityPart> extData() const; + + std::optional<std::string> name() const; + std::optional<Identity> owner() const; + const Identity & finalOwner() const; + + Stored<class PublicKey> keyIdentity() const; + Stored<class PublicKey> keyMessage() const; + + bool sameAs(const Identity &) const; + bool operator==(const Identity & other) const; + bool operator!=(const Identity & other) const; + + std::optional<Ref> ref() const; + std::optional<Ref> extRef() const; + std::vector<Ref> refs() const; + std::vector<Ref> extRefs() const; + std::vector<Ref> updates() const; + + class Builder + { + public: + Identity commit() const; + + void name(const std::string &); + void owner(const Identity &); + + private: + friend class Identity; + struct Priv; + const std::shared_ptr<Priv> p; + Builder(Priv * p); + }; + + static Builder create(const Storage &); + Builder modify() const; + Identity update(const vector<Stored<Signed<IdentityData>>> &) const; + Identity update(const vector<StoredIdentityPart> &) const; + +private: + struct Priv; + std::shared_ptr<const Priv> p; + Identity(const Priv * p); + Identity(std::shared_ptr<const Priv> && p); +}; + +struct IdentityData; +struct IdentityExtension; + +struct StoredIdentityPart +{ + using Part = variant< + Stored<Signed<IdentityData>>, + Stored<Signed<IdentityExtension>>>; + + StoredIdentityPart(Part p): part(move(p)) {} + + static StoredIdentityPart load(const Ref &); + Ref store(const Storage & st) const; + + bool operator==(const StoredIdentityPart & other) const + { return part == other.part; } + bool operator<(const StoredIdentityPart & other) const + { return part < other.part; } + + const Ref & ref() const; + const Stored<Signed<IdentityData>> & base() const; + + vector<StoredIdentityPart> previous() const; + vector<Digest> roots() const; + optional<string> name() const; + optional<StoredIdentityPart> owner() const; + bool isSignedBy(const Stored<PublicKey> &) const; + + Part part; +}; + +DECLARE_SHARED_TYPE(optional<Identity>) + +} diff --git a/include/erebos/list.h b/include/erebos/list.h new file mode 100644 index 0000000..f5f2d3f --- /dev/null +++ b/include/erebos/list.h @@ -0,0 +1,116 @@ +#pragma once + +#include <functional> +#include <memory> +#include <mutex> +#include <variant> + +namespace erebos { + +using std::function; +using std::make_shared; +using std::make_unique; +using std::move; +using std::shared_ptr; +using std::unique_ptr; +using std::variant; + +template<typename T> +class List +{ +public: + struct Nil { bool operator==(const Nil &) const { return true; } }; + struct Cons { + T head; List<T> tail; + bool operator==(const Cons & x) const { return head == x.head && tail == x.tail; } + }; + + List(); + List(const T head, List<T> tail); + + const T & front() const; + const List & tail() const; + + bool empty() const; + + bool operator==(const List<T> &) const; + bool operator!=(const List<T> &) const; + + List push_front(T x) const; + +private: + struct Priv; + shared_ptr<Priv> p; +}; + +template<typename T> +struct List<T>::Priv +{ + variant<Nil, Cons> value; + + function<void()> eval = {}; + mutable std::once_flag once = {}; +}; + +template<typename T> +List<T>::List(): + p(shared_ptr<Priv>(new Priv { Nil() })) +{ + std::call_once(p->once, [](){}); +} + +template<typename T> +List<T>::List(T head, List<T> tail): + p(shared_ptr<Priv>(new Priv { + Cons { move(head), move(tail) } + })) +{ + std::call_once(p->once, [](){}); +} + +template<typename T> +const T & List<T>::front() const +{ + std::call_once(p->once, p->eval); + return std::get<Cons>(p->value).head; +} + +template<typename T> +const List<T> & List<T>::tail() const +{ + std::call_once(p->once, p->eval); + return std::get<Cons>(p->value).tail; +} + +template<typename T> +bool List<T>::empty() const +{ + std::call_once(p->once, p->eval); + return std::holds_alternative<Nil>(p->value); +} + +template<typename T> +bool List<T>::operator==(const List<T> & other) const +{ + if (p == other.p) + return true; + + std::call_once(p->once, p->eval); + std::call_once(other.p->once, other.p->eval); + return p->value == other.p->value; + +} + +template<typename T> +bool List<T>::operator!=(const List<T> & other) const +{ + return !(*this == other); +} + +template<typename T> +List<T> List<T>::push_front(T x) const +{ + return List<T>(move(x), *this); +} + +} diff --git a/include/erebos/merge.h b/include/erebos/merge.h new file mode 100644 index 0000000..9705e94 --- /dev/null +++ b/include/erebos/merge.h @@ -0,0 +1,73 @@ +#pragma once + +#include <erebos/storage.h> + +#include <optional> +#include <vector> + +namespace erebos +{ + +using std::nullopt; +using std::optional; +using std::vector; + +template<class T> struct Mergeable +{ +}; + +template<> struct Mergeable<vector<Stored<Object>>> +{ + using Component = Object; + + static vector<Stored<Object>> components(const vector<Stored<Object>> & x) { return x; } + static vector<Stored<Object>> merge(const vector<Stored<Object>> & x) { return x; } +}; + +vector<Stored<Object>> findPropertyObjects(const vector<Stored<Object>> & leaves, const string & prop); + +template<typename T> +optional<Stored<typename Mergeable<T>::Component>> findPropertyComponent(const vector<Stored<typename Mergeable<T>::Component>> & components, const string & prop) +{ + vector<Stored<Object>> leaves; + leaves.reserve(components.size()); + + for (const auto & c : components) + leaves.push_back(Stored<Object>::load(c.ref())); + + auto candidates = findPropertyObjects(leaves, prop); + if (!candidates.empty()) + return Stored<typename Mergeable<T>::Component>::load(candidates[0].ref()); + return nullopt; +} + +template<typename T> +optional<Stored<typename Mergeable<T>::Component>> findPropertyComponent(const T & x, const string & prop) +{ + return findPropertyComponent(x.components(), prop); +} + +template<typename T> +vector<Stored<typename Mergeable<T>::Component>> findPropertyComponents(const vector<Stored<typename Mergeable<T>::Component>> & components, const string & prop) +{ + vector<Stored<Object>> leaves; + leaves.reserve(components.size()); + + for (const auto & c : components) + leaves.push_back(Stored<Object>::load(c.ref())); + + auto candidates = findPropertyObjects(leaves, prop); + vector<Stored<typename Mergeable<T>::Component>> result; + result.reserve(candidates.size()); + for (const auto & obj : candidates) + result.push_back(Stored<typename Mergeable<T>::Component>::load(obj.ref())); + return result; +} + +template<typename T> +vector<Stored<typename Mergeable<T>::Component>> findPropertyComponents(const T & x, const string & prop) +{ + return findPropertyComponents(x.components(), prop); +} + +} diff --git a/include/erebos/message.h b/include/erebos/message.h new file mode 100644 index 0000000..b52b84b --- /dev/null +++ b/include/erebos/message.h @@ -0,0 +1,169 @@ +#pragma once + +#include <erebos/merge.h> +#include <erebos/service.h> + +#include <condition_variable> +#include <deque> +#include <functional> +#include <memory> +#include <mutex> +#include <optional> +#include <string> +#include <tuple> + +namespace erebos { + +using std::condition_variable; +using std::deque; +using std::mutex; +using std::tuple; +using std::unique_ptr; + +class Contact; +class Identity; +struct DirectMessageState; + +class DirectMessage +{ +public: + const std::optional<Identity> & from() const; + const std::optional<struct ZonedTime> & time() const; + std::string text() const; + +private: + friend class DirectMessageThread; + friend class DirectMessageService; + struct Priv; + DirectMessage(Priv *); + std::shared_ptr<Priv> p; +}; + +class DirectMessageThread +{ +public: + class Iterator + { + struct Priv; + Iterator(Priv *); + public: + using iterator_category = std::forward_iterator_tag; + using value_type = DirectMessage; + using difference_type = ssize_t; + using pointer = const DirectMessage *; + using reference = const DirectMessage &; + + Iterator(const Iterator &); + ~Iterator(); + Iterator & operator=(const Iterator &); + Iterator & operator++(); + value_type operator*() const; + bool operator==(const Iterator &) const; + bool operator!=(const Iterator &) const; + + private: + friend DirectMessageThread; + std::unique_ptr<Priv> p; + }; + + Iterator begin() const; + Iterator end() const; + + size_t size() const; + DirectMessage at(size_t) const; + + const Identity & peer() const; + +private: + friend class DirectMessageService; + friend class DirectMessageThreads; + struct Priv; + DirectMessageThread(Priv *); + std::shared_ptr<Priv> p; +}; + +class DirectMessageThreads +{ +public: + DirectMessageThreads(); + DirectMessageThreads(Stored<DirectMessageState>); + DirectMessageThreads(vector<Stored<DirectMessageState>>); + + static DirectMessageThreads load(const vector<Ref> & refs); + vector<Ref> store() const; + vector<Stored<DirectMessageState>> data() const; + + bool operator==(const DirectMessageThreads &) const; + bool operator!=(const DirectMessageThreads &) const; + + DirectMessageThread thread(const Identity &) const; + +private: + vector<Stored<DirectMessageState>> state; + + friend class DirectMessageService; +}; + +DECLARE_SHARED_TYPE(DirectMessageThreads) + +template<> struct Mergeable<DirectMessageThreads> +{ + using Component = DirectMessageState; + static vector<Stored<DirectMessageState>> components(const DirectMessageThreads &); + static Contact merge(vector<Stored<DirectMessageState>>); +}; + +class DirectMessageService : public Service +{ +public: + using ThreadWatcher = std::function<void(const DirectMessageThread &, ssize_t, ssize_t)>; + + class Config + { + public: + Config & onUpdate(ThreadWatcher); + + private: + friend class DirectMessageService; + vector<ThreadWatcher> watchers; + }; + + DirectMessageService(Config &&, const Server &); + virtual ~DirectMessageService(); + + UUID uuid() const override; + void handle(Context &) override; + + DirectMessageThread thread(const Identity &); + + static DirectMessage send(const Head<LocalState> &, const Identity &, const std::string &); + static DirectMessage send(const Head<LocalState> &, const Contact &, const std::string &); + static DirectMessage send(const Head<LocalState> &, const Peer &, const std::string &); + + DirectMessage send(const Identity &, const std::string &); + DirectMessage send(const Contact &, const std::string &); + DirectMessage send(const Peer &, const std::string &); + +private: + void updateHandler(const DirectMessageThreads &); + void peerWatcher(size_t, const class Peer *); + void syncWithPeer(const DirectMessageThread &, const Peer &); + void doSyncWithPeers(); + void doSyncWithPeer(const DirectMessageThread &, const Peer &); + + const Config config; + const Server & server; + + vector<Stored<DirectMessageState>> prevState; + mutex stateMutex; + + mutex peerSyncMutex; + condition_variable peerSyncCond; + bool peerSyncRun; + deque<tuple<DirectMessageThread, Peer>> peerSyncQueue; + std::thread peerSyncThread; + + Watched<DirectMessageThreads> watched; +}; + +} diff --git a/include/erebos/network.h b/include/erebos/network.h new file mode 100644 index 0000000..2761a40 --- /dev/null +++ b/include/erebos/network.h @@ -0,0 +1,135 @@ +#pragma once + +#include <erebos/service.h> +#include <erebos/state.h> + +#include <functional> +#include <typeinfo> + +struct sockaddr_in6; + +namespace erebos { + +using std::vector; +using std::unique_ptr; + +class ServerConfig; +class Peer; + +class Server +{ + struct Priv; +public: + Server(const Head<LocalState> &, ServerConfig &&); + Server(const std::shared_ptr<Priv> &); + ~Server(); + + Server(const Server &) = delete; + Server & operator=(const Server &) = delete; + + const Head<LocalState> & localHead() const; + const Bhv<LocalState> & localState() const; + + Identity identity() const; + template<class S> S & svc(); + + class PeerList & peerList() const; + optional<erebos::Peer> peer(const Identity &) const; + void addPeer(const string & node) const; + void addPeer(const string & node, const string & service) const; + + struct Peer; +private: + Service & svcHelper(const std::type_info &); + + const std::shared_ptr<Priv> p; +}; + +class ServerConfig +{ +public: + ServerConfig() = default; + ServerConfig(const ServerConfig &) = delete; + ServerConfig(ServerConfig &&) = default; + ServerConfig & operator=(const ServerConfig &) = delete; + ServerConfig & operator=(ServerConfig &&) = default; + + template<class S> + typename S::Config & service(); + +private: + friend class Server; + vector<function<unique_ptr<Service>(const Server &)>> services; +}; + +template<class S> +S & Server::svc() +{ + return dynamic_cast<S&>(svcHelper(typeid(S))); +} + +template<class S> +typename S::Config & ServerConfig::service() +{ + auto config = make_shared<typename S::Config>(); + auto & configRef = *config; + + services.push_back([config = move(config)](const Server & server) { + return make_unique<S>(move(*config), server); + }); + + return configRef; +} + +class Peer +{ +public: + struct Priv; + Peer(const std::shared_ptr<Priv> & p); + ~Peer(); + + Server server() const; + + const Storage & tempStorage() const; + const PartialStorage & partialStorage() const; + + std::string name() const; + std::optional<Identity> identity() const; + const struct sockaddr_in6 & address() const; + string addressStr() const; + uint16_t port() const; + + bool send(UUID, const Ref &) const; + bool send(UUID, const Object &) const; + + bool operator==(const Peer & other) const; + bool operator!=(const Peer & other) const; + bool operator<(const Peer & other) const; + bool operator<=(const Peer & other) const; + bool operator>(const Peer & other) const; + bool operator>=(const Peer & other) const; + +private: + bool send(UUID, const Ref &, const Object &) const; + std::shared_ptr<Priv> p; +}; + +class PeerList +{ +public: + struct Priv; + PeerList(); + PeerList(const std::shared_ptr<Priv> & p); + ~PeerList(); + + size_t size() const; + Peer at(size_t n) const; + + void onUpdate(std::function<void(size_t, const Peer *)>); + +private: + friend Server; + const std::shared_ptr<Priv> p; +}; + +} diff --git a/include/erebos/pairing.h b/include/erebos/pairing.h new file mode 100644 index 0000000..71c9288 --- /dev/null +++ b/include/erebos/pairing.h @@ -0,0 +1,133 @@ +#pragma once + +#include <erebos/identity.h> +#include <erebos/network.h> +#include <erebos/service.h> + +#include <future> +#include <map> +#include <mutex> +#include <string> +#include <variant> +#include <vector> + +namespace erebos { + +using std::function; +using std::future; +using std::map; +using std::mutex; +using std::promise; +using std::string; +using std::variant; +using std::vector; + +/** + * Template-less base class for the paring functionality that does not depend + * on the result parameter. + */ +class PairingServiceBase : public Service +{ +public: + enum class Outcome + { + Success, + PeerRejected, + UserRejected, + UnexpectedMessage, + NonceMismatch, + Stale, + }; + + using RequestInitHook = function<void(const Peer &)>; + using ConfirmHook = function<future<bool>(const Peer &, string, future<Outcome> &&)>; + using RequestNonceFailedHook = function<void(const Peer &)>; + + class Config + { + public: + Config & onRequestInit(RequestInitHook); + Config & onResponse(PairingServiceBase::ConfirmHook); + Config & onRequest(PairingServiceBase::ConfirmHook); + Config & onRequestNonceFailed(RequestNonceFailedHook); + + private: + friend class PairingServiceBase; + RequestInitHook requestInitHook; + ConfirmHook responseHook; + ConfirmHook requestHook; + RequestNonceFailedHook requestNonceFailedHook; + }; + + PairingServiceBase(Config &&); + virtual ~PairingServiceBase(); + +protected: + void requestPairing(UUID serviceId, const Peer & peer); + virtual void handle(Context &) override; + virtual Ref handlePairingCompleteRef(const Peer &) = 0; + virtual void handlePairingResult(Context &) = 0; + +private: + static vector<uint8_t> nonceDigest(const Identity & id1, const Identity & id2, + const vector<uint8_t> & nonce1, const vector<uint8_t> & nonce2); + static string confirmationNumber(const vector<uint8_t> &); + + const Config config; + optional<Ref> result; + + enum class StatePhase { + NoPairing, + OurRequest, + OurRequestConfirm, + OurRequestReady, + PeerRequest, + PeerRequestConfirm, + PairingDone, + PairingFailed + }; + + struct State { + mutex lock; + StatePhase phase; + optional<Identity> idReq; + optional<Identity> idRsp; + vector<uint8_t> nonce; + vector<uint8_t> peerCheck; + promise<Outcome> outcome; + }; + + map<Peer, shared_ptr<State>> peerStates; + mutex stateLock; + + void waitForConfirmation(Peer peer, weak_ptr<State> state, string confirm, ConfirmHook hook); +}; + +template<class Result> +class PairingService : public PairingServiceBase +{ +public: + PairingService(Config && config): + PairingServiceBase(move(config)) {} + +protected: + virtual Stored<Result> handlePairingComplete(const Peer &) = 0; + virtual void handlePairingResult(Context &, Stored<Result>) = 0; + + virtual Ref handlePairingCompleteRef(const Peer &) override final; + virtual void handlePairingResult(Context &) override final; +}; + +template<class Result> +Ref PairingService<Result>::handlePairingCompleteRef(const Peer & peer) +{ + return handlePairingComplete(peer).ref(); +} + +template<class Result> +void PairingService<Result>::handlePairingResult(Context & ctx) +{ + handlePairingResult(ctx, Stored<Result>::load(ctx.ref())); +} + +} diff --git a/include/erebos/service.h b/include/erebos/service.h new file mode 100644 index 0000000..7e037f8 --- /dev/null +++ b/include/erebos/service.h @@ -0,0 +1,44 @@ +#pragma once + +#include <erebos/state.h> +#include <erebos/uuid.h> + +#include <memory> + +namespace erebos { + +class Server; + +class Service +{ +public: + Service(); + virtual ~Service(); + + using Config = monostate; + + class Context + { + public: + struct Priv; + Context(Priv *); + Priv & priv(); + + const class Ref & ref() const; + const class Peer & peer() const; + + const Stored<LocalState> & local() const; + void local(const LocalState &); + + void afterCommit(function<void()>); + void runAfterCommitHooks() const; + + private: + std::unique_ptr<Priv> p; + }; + + virtual UUID uuid() const = 0; + virtual void handle(Context &) = 0; +}; + +} diff --git a/include/erebos/set.h b/include/erebos/set.h new file mode 100644 index 0000000..e4a5c91 --- /dev/null +++ b/include/erebos/set.h @@ -0,0 +1,101 @@ +#pragma once + +#include <erebos/merge.h> +#include <erebos/storage.h> + +namespace erebos +{ + +class SetViewBase; +template<class T> class SetView; + +class SetBase +{ +protected: + struct Priv; + + SetBase(); + SetBase(const vector<Ref> &); + SetBase(shared_ptr<const Priv>); + + shared_ptr<const Priv> add(const Storage &, const vector<Ref> &) const; + + vector<vector<Ref>> toList() const; + +public: + bool operator==(const SetBase &) const; + bool operator!=(const SetBase &) const; + + vector<Digest> digests() const; + vector<Ref> store() const; + +protected: + shared_ptr<const Priv> p; +}; + +template<class T> +class Set : public SetBase +{ + Set(shared_ptr<const Priv> p): SetBase(p) {}; +public: + Set() = default; + Set(const vector<Ref> & refs): SetBase(move(refs)) {} + Set(const Set<T> &) = default; + Set(Set<T> &&) = default; + Set & operator=(const Set<T> &) = default; + Set & operator=(Set<T> &&) = default; + + static Set<T> load(const vector<Ref> & refs) { return Set<T>(move(refs)); } + + Set<T> add(const Storage &, const T &) const; + + template<class F> + SetView<T> view(F && cmp) const; +}; + +template<class T> +class SetView +{ +public: + template<class F> + SetView(F && cmp, const vector<vector<Ref>> & refs); + + size_t size() const { return items.size(); } + typename vector<T>::const_iterator begin() const { return items.begin(); } + typename vector<T>::const_iterator end() const { return items.end(); } + +private: + vector<T> items; +}; + +template<class T> +Set<T> Set<T>::add(const Storage & st, const T & x) const +{ + return Set<T>(SetBase::add(st, storedRefs(Mergeable<T>::components(x)))); +} + +template<class T> +template<class F> +SetView<T> Set<T>::view(F && cmp) const +{ + return SetView<T>(std::move(cmp), toList()); +} + +template<class T> +template<class F> +SetView<T>::SetView(F && cmp, const vector<vector<Ref>> & refs) +{ + items.reserve(refs.size()); + for (const auto & crefs : refs) { + vector<Stored<typename Mergeable<T>::Component>> comps; + comps.reserve(crefs.size()); + for (const auto & r : crefs) + comps.push_back(Stored<typename Mergeable<T>::Component>::load(r)); + + filterAncestors(comps); + items.push_back(Mergeable<T>::merge(comps)); + } + std::sort(items.begin(), items.end(), cmp); +} + +} diff --git a/include/erebos/state.h b/include/erebos/state.h new file mode 100644 index 0000000..16be464 --- /dev/null +++ b/include/erebos/state.h @@ -0,0 +1,107 @@ +#pragma once + +#include <erebos/storage.h> +#include <erebos/uuid.h> + +#include <memory> +#include <optional> +#include <vector> + +namespace erebos { + +using std::optional; +using std::shared_ptr; +using std::vector; + +template<typename T> +struct SharedType +{ + static const UUID id; + static T(*const load)(const vector<Ref> &); + static vector<Ref>(*const store)(const T &); +}; + +#define DECLARE_SHARED_TYPE(T) \ + template<> const UUID erebos::SharedType<T>::id; \ + template<> T(*const erebos::SharedType<T>::load)(const std::vector<erebos::Ref> &); \ + template<> std::vector<erebos::Ref>(*const erebos::SharedType<T>::store) (const T &); + +#define DEFINE_SHARED_TYPE(T, id_, load_, store_) \ + template<> const UUID erebos::SharedType<T>::id { id_ }; \ + template<> T(*const erebos::SharedType<T>::load)(const vector<Ref> &) { load_ }; \ + template<> std::vector<erebos::Ref>(*const erebos::SharedType<T>::store) (const T &) { store_ }; + +class Identity; + +class LocalState +{ +public: + LocalState(); + explicit LocalState(const Ref &); + static LocalState load(const Ref & ref) { return LocalState(ref); } + Ref store(const Storage &) const; + + static const UUID headTypeId; + + const optional<Identity> & identity() const; + LocalState identity(const Identity &) const; + + template<class T> T shared() const; + template<class T> LocalState shared(const T & x) const; + + vector<Ref> sharedRefs() const; + LocalState sharedRefAdd(const Ref &) const; + + template<typename T> static T lens(const LocalState &); + +private: + vector<Ref> lookupShared(UUID) const; + LocalState updateShared(UUID, const vector<Ref> &) const; + + struct Priv; + std::shared_ptr<Priv> p; +}; + +class SharedState +{ +public: + template<class T> T get() const; + template<typename T> static T lens(const SharedState &); + + bool operator==(const SharedState &) const; + bool operator!=(const SharedState &) const; + +private: + vector<Ref> lookup(UUID) const; + + struct Priv; + SharedState(shared_ptr<Priv> && p): p(std::move(p)) {} + shared_ptr<Priv> p; + friend class LocalState; +}; + +template<class T> +T LocalState::shared() const +{ + return SharedType<T>::load(lookupShared(SharedType<T>::id)); +} + +template<class T> +LocalState LocalState::shared(const T & x) const +{ + return updateShared(SharedType<T>::id, SharedType<T>::store(x)); +} + +template<class T> +T SharedState::get() const +{ + return SharedType<T>::load(lookup(SharedType<T>::id)); +} + +template<class T> +T SharedState::lens(const SharedState & x) +{ + return x.get<T>(); +} + +} diff --git a/include/erebos/storage.h b/include/erebos/storage.h new file mode 100644 index 0000000..96a27d4 --- /dev/null +++ b/include/erebos/storage.h @@ -0,0 +1,839 @@ +#pragma once + +#include <erebos/frp.h> +#include <erebos/time.h> +#include <erebos/uuid.h> + +#include <algorithm> +#include <array> +#include <cstring> +#include <filesystem> +#include <functional> +#include <memory> +#include <mutex> +#include <optional> +#include <stdexcept> +#include <string> +#include <thread> +#include <variant> +#include <vector> + +namespace erebos { + +class Storage; +class PartialStorage; +class Digest; +class Ref; +class PartialRef; + +template<class S> class RecordT; +typedef RecordT<Storage> Record; +typedef RecordT<PartialStorage> PartialRecord; +template<class S> class ObjectT; +typedef ObjectT<Storage> Object; +typedef ObjectT<PartialStorage> PartialObject; +class Blob; + +template<typename T> class Stored; +template<typename T> class Head; + +using std::bind; +using std::call_once; +using std::make_unique; +using std::monostate; +using std::move; +using std::optional; +using std::shared_ptr; +using std::string; +using std::variant; +using std::vector; + +class PartialStorage +{ +public: + typedef erebos::PartialRef Ref; + + PartialStorage(const PartialStorage &) = default; + PartialStorage & operator=(const PartialStorage &) = delete; + virtual ~PartialStorage() = default; + + bool operator==(const PartialStorage &) const; + bool operator!=(const PartialStorage &) const; + + PartialRef ref(const Digest &) const; + + std::optional<PartialObject> loadObject(const Digest &) const; + PartialRef storeObject(const PartialObject &) const; + PartialRef storeObject(const PartialRecord &) const; + PartialRef storeObject(const Blob &) const; + +protected: + friend class Storage; + friend erebos::Ref; + friend erebos::PartialRef; + struct Priv; + const std::shared_ptr<const Priv> p; + PartialStorage(const std::shared_ptr<const Priv> & p): p(p) {} + +public: + // For test usage + const Priv & priv() const { return *p; } +}; + +class Storage : public PartialStorage +{ +public: + typedef erebos::Ref Ref; + + Storage(const std::filesystem::path &); + Storage(const Storage &) = default; + Storage & operator=(const Storage &) = delete; + + Storage deriveEphemeralStorage() const; + PartialStorage derivePartialStorage() const; + + std::optional<Ref> ref(const Digest &) const; + Ref zref() const; + + std::optional<Object> loadObject(const Digest &) const; + Ref storeObject(const Object &) const; + Ref storeObject(const Record &) const; + Ref storeObject(const Blob &) const; + + std::variant<Ref, std::vector<Digest>> copy(const PartialRef &) const; + std::variant<Ref, std::vector<Digest>> copy(const PartialObject &) const; + Ref copy(const Ref &) const; + Ref copy(const Object &) const; + + template<typename T> Stored<T> store(const T &) const; + + template<typename T> std::optional<Head<T>> head(UUID id) const; + template<typename T> std::vector<Head<T>> heads() const; + template<typename T> Head<T> storeHead(const T &) const; + template<typename T> Head<T> storeHead(const Stored<T> &) const; + + void storeKey(Ref pubref, const std::vector<uint8_t> &) const; + std::optional<std::vector<uint8_t>> loadKey(Ref pubref) const; + +protected: + template<typename T> friend class Head; + template<typename T> friend class WatchedHead; + + Storage(const std::shared_ptr<const Priv> & p): PartialStorage(p) {} + + std::optional<Ref> headRef(UUID type, UUID id) const; + std::vector<std::tuple<UUID, Ref>> headRefs(UUID type) const; + static UUID storeHead(UUID type, const Ref & ref); + static bool replaceHead(UUID type, UUID id, const Ref & old, const Ref & ref); + static std::optional<Ref> updateHead(UUID type, UUID id, const Ref & old, const std::function<Ref(const Ref &)> &); + int watchHead(UUID type, UUID id, const std::function<void(const Ref &)>) const; + void unwatchHead(UUID type, UUID id, int watchId) const; +}; + +class Digest +{ +public: + static constexpr size_t size = 32; + + Digest(const Digest &) = default; + Digest & operator=(const Digest &) = default; + + explicit Digest(std::array<uint8_t, size> value): value(value) {} + explicit Digest(const std::string &); + explicit operator std::string() const; + bool isZero() const; + + static Digest of(const std::vector<uint8_t> & content); + template<class S> static Digest of(const ObjectT<S> &); + + const std::array<uint8_t, size> & arr() const { return value; } + + bool operator==(const Digest & other) const { return value == other.value; } + bool operator!=(const Digest & other) const { return value != other.value; } + bool operator<(const Digest & other) const { return value < other.value; } + bool operator<=(const Digest & other) const { return value <= other.value; } + bool operator>(const Digest & other) const { return value > other.value; } + bool operator>=(const Digest & other) const { return value >= other.value; } + +private: + std::array<uint8_t, size> value; +}; + +template<class S> +Digest Digest::of(const ObjectT<S> & obj) +{ + return Digest::of(obj.encode()); +} + +class PartialRef +{ +public: + PartialRef(const PartialRef &) = default; + PartialRef(PartialRef &&) = default; + PartialRef & operator=(const PartialRef &) = default; + PartialRef & operator=(PartialRef &&) = default; + + static PartialRef create(const PartialStorage &, const Digest &); + static PartialRef zcreate(const PartialStorage &); + + const Digest & digest() const; + + operator bool() const; + const PartialObject operator*() const; + std::unique_ptr<PartialObject> operator->() const; + + const PartialStorage & storage() const; + +protected: + friend class Storage; + struct Priv; + std::shared_ptr<const Priv> p; + PartialRef(const std::shared_ptr<const Priv> p): p(p) {} +}; + +class Ref : public PartialRef +{ +public: + Ref(const Ref &) = default; + Ref(Ref &&) = default; + Ref & operator=(const Ref &) = default; + Ref & operator=(Ref &&) = default; + + bool operator==(const Ref &) const; + bool operator!=(const Ref &) const; + + static std::optional<Ref> create(const Storage &, const Digest &); + static Ref zcreate(const Storage &); + + explicit constexpr operator bool() const { return true; } + const Object operator*() const; + std::unique_ptr<Object> operator->() const; + + const Storage & storage() const; + + vector<Ref> previous() const; + class Generation generation() const; + vector<Digest> roots() const; + +private: + class Generation generationLocked() const; + class vector<Digest> rootsLocked() const; + +protected: + Ref(const std::shared_ptr<const Priv> p): PartialRef(p) {} +}; + +template<class S> +class RecordT +{ +public: + class Item; + class Items; + +private: + RecordT(const std::shared_ptr<std::vector<Item>> & ptr): + ptr(ptr) {} + +public: + RecordT(): RecordT(std::vector<Item> {}) {} + RecordT(const std::vector<Item> &); + RecordT(std::vector<Item> &&); + std::vector<uint8_t> encode() const; + + Items items() const; + Item item(const std::string & name) const; + Item operator[](const std::string & name) const; + Items items(const std::string & name) const; + +private: + friend ObjectT<S>; + std::vector<uint8_t> encodeInner() const; + static std::optional<RecordT<S>> decode(const S &, + std::vector<uint8_t>::const_iterator, + std::vector<uint8_t>::const_iterator); + + const std::shared_ptr<const std::vector<Item>> ptr; +}; + +template<class S> +class RecordT<S>::Item +{ +public: + struct UnknownType + { + string type; + string value; + }; + + struct Empty {}; + + using Integer = int; + using Text = string; + using Binary = vector<uint8_t>; + using Date = ZonedTime; + using UUID = erebos::UUID; + using Ref = typename S::Ref; + + using Variant = variant< + monostate, + Empty, + int, + string, + vector<uint8_t>, + ZonedTime, + UUID, + typename S::Ref, + UnknownType>; + + Item(const string & name): + Item(name, monostate()) {} + Item(const string & name, Variant value): + name(name), value(value) {} + template<typename T> + Item(const string & name, const Stored<T> & value): + Item(name, value.ref()) {} + + Item(const Item &) = default; + Item & operator=(const Item &) = delete; + + operator bool() const; + + optional<Empty> asEmpty() const; + optional<Integer> asInteger() const; + optional<Text> asText() const; + optional<Binary> asBinary() const; + optional<Date> asDate() const; + optional<UUID> asUUID() const; + optional<Ref> asRef() const; + optional<UnknownType> asUnknown() const; + + template<typename T> optional<Stored<T>> as() const; + + const string name; + const Variant value; +}; + +template<class S> +class RecordT<S>::Items +{ +public: + using Empty = typename Item::Empty; + using Integer = typename Item::Integer; + using Text = typename Item::Text; + using Binary = typename Item::Binary; + using Date = typename Item::Date; + using UUID = typename Item::UUID; + using Ref = typename Item::Ref; + using UnknownType = typename Item::UnknownType; + + Items(shared_ptr<const vector<Item>> items); + Items(shared_ptr<const vector<Item>> items, string filter); + + class Iterator + { + Iterator(const Items & source, size_t idx); + friend Items; + public: + using iterator_category = std::forward_iterator_tag; + using value_type = Item; + using difference_type = ssize_t; + using pointer = const Item *; + using reference = const Item &; + + Iterator(const Iterator &) = default; + ~Iterator() = default; + Iterator & operator=(const Iterator &) = default; + Iterator & operator++(); + value_type operator*() const { return (*source.items)[idx]; } + bool operator==(const Iterator & other) const { return idx == other.idx; } + bool operator!=(const Iterator & other) const { return idx != other.idx; } + + private: + const Items & source; + size_t idx; + }; + + Iterator begin() const; + Iterator end() const; + + vector<Empty> asEmpty() const; + vector<Integer> asInteger() const; + vector<Text> asText() const; + vector<Binary> asBinary() const; + vector<Date> asDate() const; + vector<UUID> asUUID() const; + vector<Ref> asRef() const; + vector<UnknownType> asUnknown() const; + + template<typename T> vector<Stored<T>> as() const; + +private: + const shared_ptr<const vector<Item>> items; + const optional<string> filter; +}; + +extern template class RecordT<Storage>; +extern template class RecordT<PartialStorage>; + +class Blob +{ +public: + Blob(const std::vector<uint8_t> &); + + const std::vector<uint8_t> & data() const { return *ptr; } + std::vector<uint8_t> encode() const; + +private: + friend Object; + friend PartialObject; + std::vector<uint8_t> encodeInner() const; + static Blob decode( + std::vector<uint8_t>::const_iterator, + std::vector<uint8_t>::const_iterator); + + Blob(std::shared_ptr<std::vector<uint8_t>> ptr): ptr(ptr) {} + + const std::shared_ptr<const std::vector<uint8_t>> ptr; +}; + +template<class S> +class ObjectT +{ +public: + typedef std::variant< + RecordT<S>, + Blob, + std::monostate> Variants; + + ObjectT(const ObjectT<S> &) = default; + ObjectT(Variants content): content(content) {} + ObjectT<S> & operator=(const ObjectT<S> &) = default; + + static std::optional<std::tuple<ObjectT<S>, std::vector<uint8_t>::const_iterator>> + decodePrefix(const S &, + std::vector<uint8_t>::const_iterator, + std::vector<uint8_t>::const_iterator); + + static std::optional<ObjectT<S>> decode(const S &, const std::vector<uint8_t> &); + static std::optional<ObjectT<S>> decode(const S &, + std::vector<uint8_t>::const_iterator, + std::vector<uint8_t>::const_iterator); + static std::vector<ObjectT<S>> decodeMany(const S &, const std::vector<uint8_t> &); + std::vector<uint8_t> encode() const; + static ObjectT<S> load(const typename S::Ref &); + + operator bool() const; + + std::optional<RecordT<S>> asRecord() const; + std::optional<Blob> asBlob() const; + +private: + friend RecordT<S>; + friend Blob; + + Variants content; +}; + +extern template class ObjectT<Storage>; +extern template class ObjectT<PartialStorage>; + +template<class S> +template<typename T> +std::optional<Stored<T>> RecordT<S>::Item::as() const +{ + if (auto ref = asRef()) + return Stored<T>::load(ref.value()); + return std::nullopt; +} + +template<class S> +template<typename T> +vector<Stored<T>> RecordT<S>::Items::as() const +{ + auto refs = asRef(); + vector<Stored<T>> res; + res.reserve(refs.size()); + for (const auto & ref : refs) + res.push_back(Stored<T>::load(ref)); + return res; +} + +class Generation +{ +public: + Generation(); + static Generation next(const vector<Generation> &); + + explicit operator string() const; + +private: + Generation(size_t); + size_t gen; +}; + +template<typename T> +class Stored +{ + Stored(Ref ref, T x); + friend class Storage; + friend class Head<T>; +public: + Stored() = default; + Stored(const Stored &) = default; + Stored(Stored &&) = default; + Stored & operator=(const Stored &) = default; + Stored & operator=(Stored &&) = default; + + Stored(Ref); + static Stored<T> load(const Ref &); + Ref store(const Storage &) const; + + bool operator==(const Stored<T> & other) const + { return p->ref.digest() == other.p->ref.digest(); } + bool operator!=(const Stored<T> & other) const + { return p->ref.digest() != other.p->ref.digest(); } + bool operator<(const Stored<T> & other) const + { return p->ref.digest() < other.p->ref.digest(); } + bool operator<=(const Stored<T> & other) const + { return p->ref.digest() <= other.p->ref.digest(); } + bool operator>(const Stored<T> & other) const + { return p->ref.digest() > other.p->ref.digest(); } + bool operator>=(const Stored<T> & other) const + { return p->ref.digest() >= other.p->ref.digest(); } + + void init() const; + const T & operator*() const { init(); return *p->val; } + const T * operator->() const { init(); return p->val.get(); } + + Generation generation() const { return p->ref.generation(); } + + std::vector<Stored<T>> previous() const; + bool precedes(const Stored<T> &) const; + + std::vector<Digest> roots() const { return p->ref.roots(); } + + const Ref & ref() const { return p->ref; } + +private: + struct Priv { + const Ref ref; + mutable std::once_flag once {}; + mutable std::unique_ptr<T> val {}; + mutable std::function<T()> init {}; + }; + std::shared_ptr<Priv> p; +}; + +template<typename T> +void Stored<T>::init() const +{ + call_once(p->once, [this]() { + p->val = std::make_unique<T>(p->init()); + p->init = decltype(p->init)(); + }); +} + +template<typename T> +Stored<T> Storage::store(const T & val) const +{ + return Stored(val.store(*this), val); +} + +template<typename T> +Stored<T>::Stored(Ref ref, T x): + p(new Priv { + .ref = move(ref), + .val = make_unique<T>(move(x)), + }) +{ + call_once(p->once, [](){}); +} + +template<typename T> +Stored<T>::Stored(Ref ref): + p(new Priv { + .ref = move(ref), + }) +{ + p->init = [p = p.get()]() { return T::load(p->ref); }; +} + +template<typename T> +Stored<T> Stored<T>::load(const Ref & ref) +{ + return Stored(ref); +} + +template<typename T> +Ref Stored<T>::store(const Storage & st) const +{ + if (st == p->ref.storage()) + return p->ref; + return st.storeObject(*p->ref); +} + +template<typename T> +std::vector<Stored<T>> Stored<T>::previous() const +{ + auto refs = p->ref.previous(); + vector<Stored<T>> res; + res.reserve(refs.size()); + for (const auto & r : refs) + res.push_back(Stored<T>::load(r)); + return res; +} + +template<typename T> +bool precedes(const T & ancestor, const T & descendant) +{ + for (const auto & x : descendant.previous()) { + if (ancestor == x || precedes(ancestor, x)) + return true; + } + return false; +} + +template<typename T> +bool Stored<T>::precedes(const Stored<T> & other) const +{ + return erebos::precedes(*this, other); +} + +template<typename T> +void filterAncestors(std::vector<T> & xs) +{ + if (xs.size() < 2) + return; + + std::sort(xs.begin(), xs.end()); + xs.erase(std::unique(xs.begin(), xs.end()), xs.end()); + + std::vector<T> old; + old.swap(xs); + + for (auto i = old.begin(); i != old.end(); i++) { + bool add = true; + for (const auto & x : xs) + if (precedes(*i, x)) { + add = false; + break; + } + if (add) + for (auto j = i + 1; j != old.end(); j++) + if (precedes(*i, *j)) { + add = false; + break; + } + if (add) + xs.push_back(std::move(*i)); + } +} + +template<class T> class WatchedHead; +template<class T> class HeadBhv; + +template<class T> +class Head +{ + Head(UUID id, Stored<T> stored): + mid(id), mstored(move(stored)) {} + Head(UUID id, Ref ref, T val): + mid(id), mstored(move(ref), move(val)) {} + friend class Storage; +public: + Head(UUID id, Ref ref): mid(id), mstored(ref) {} + + const T & operator*() const { return *mstored; } + const T * operator->() const { return &(*mstored); } + + UUID id() const { return mid; } + const Stored<T> & stored() const { return mstored; } + const Ref & ref() const { return mstored.ref(); } + const Storage & storage() const { return mstored.ref().storage(); } + + optional<Head<T>> reload() const; + std::optional<Head<T>> update(const std::function<Stored<T>(const Stored<T> &)> &) const; + WatchedHead<T> watch(const std::function<void(const Head<T> &)> &) const; + + Bhv<T> behavior() const; + +private: + UUID mid; + Stored<T> mstored; +}; + +/** + * Manages registered watch callbacks to Head<T> object using RAII principle. + */ +template<class T> +class WatchedHead : public Head<T> +{ + friend class Head<T>; + friend class HeadBhv<T>; + + WatchedHead(const Head<T> & h): + Head<T>(h), watcherId(-1) {} + WatchedHead(const Head<T> & h, int watcherId): + Head<T>(h), watcherId(watcherId) {} + + int watcherId; + +public: + WatchedHead(WatchedHead<T> && h): + Head<T>(h), watcherId(h.watcherId) + { h.watcherId = -1; } + + WatchedHead<T> & operator=(WatchedHead<T> && h) + { watcherId = h.watcherId; h.watcherId = -1; return *this; } + + WatchedHead<T> & operator=(const Head<T> & h) { + if (Head<T>::id() != h.id()) + throw std::runtime_error("WatchedHead ID mismatch"); + static_cast<Head<T> &>(*this) = h; + return *this; + } + + /// Destructor stops the watching started with Head<T>::watch call. + /** + * Once the WatchedHead object is destroyed, no further Head<T> changes + * will trigger the associated callback. + * + * The destructor also ensures that any scheduled callback run + * triggered by a previous change to the head is executed and finished + * before the destructor returns. The exception is when the destructor + * is called directly from the callback itself, in which case the + * destructor returns immediately. + */ + ~WatchedHead(); +}; + +template<class T> +class HeadBhv : public BhvSource<T> +{ +public: + HeadBhv(const Head<T> & head): + whead(head) + {} + + T get(const BhvCurTime &, const std::monostate &) const { return *whead; } + +private: + friend class Head<T>; + + void init() + { + whead = whead.watch([wp = weak_ptr<BhvImplBase>(BhvImplBase::shared_from_this()), this] (const Head<T> & cur) { + // make sure this object still exists + if (auto ptr = wp.lock()) { + BhvCurTime ctime; + whead = cur; + BhvImplBase::updated(ctime); + } + }); + } + + WatchedHead<T> whead; +}; + +template<typename T> +std::optional<Head<T>> Storage::head(UUID id) const +{ + if (auto ref = headRef(T::headTypeId, id)) + return Head<T>(id, *ref); + return std::nullopt; +} + +template<typename T> +std::vector<Head<T>> Storage::heads() const +{ + std::vector<Head<T>> res; + for (const auto & x : headRefs(T::headTypeId)) + res.emplace_back(std::get<UUID>(x), std::get<Ref>(x)); + return res; +} + +template<typename T> +Head<T> Storage::storeHead(const T & val) const +{ + auto ref = val.store(*this); + auto id = storeHead(T::headTypeId, ref); + return Head(id, ref, val); +} + +template<typename T> +Head<T> Storage::storeHead(const Stored<T> & val) const +{ + auto id = storeHead(T::headTypeId, val.ref()); + return Head(id, val); +} + +template<typename T> +optional<Head<T>> Head<T>::reload() const +{ + return storage().template head<T>(id()); +} + +template<typename T> +std::optional<Head<T>> Head<T>::update(const std::function<Stored<T>(const Stored<T> &)> & f) const +{ + auto res = Storage::updateHead(T::headTypeId, mid, ref(), [&f, this](const Ref & r) { + return f(r.digest() == ref().digest() ? stored() : Stored<T>::load(r)).ref(); + }); + + if (!res) + return std::nullopt; + if (res->digest() == ref().digest()) + return *this; + return Head<T>(mid, *res); +} + +template<typename T> +WatchedHead<T> Head<T>::watch(const std::function<void(const Head<T> &)> & watcher) const +{ + int wid = storage().watchHead(T::headTypeId, id(), [id = id(), watcher] (const Ref & ref) { + watcher(Head<T>(id, ref)); + }); + return WatchedHead<T>(*this, wid); +} + +template<typename T> +Bhv<T> Head<T>::behavior() const +{ + auto cur = reload(); + auto ret = make_shared<HeadBhv<T>>(cur ? *cur : *this); + ret->init(); + return ret; +} + +template<class T> +WatchedHead<T>::~WatchedHead() +{ + if (watcherId >= 0) + Head<T>::storage().unwatchHead( + T::headTypeId, Head<T>::id(), watcherId); +} + +template<class T> +vector<Ref> storedRefs(const vector<Stored<T>> & v) +{ + vector<Ref> res; + res.reserve(v.size()); + for (const auto & x : v) + res.push_back(x.ref()); + return res; +} + +} + +namespace std +{ + template<> struct hash<erebos::Digest> + { + std::size_t operator()(const erebos::Digest & dgst) const noexcept + { + std::size_t res; + std::memcpy(&res, dgst.arr().data(), sizeof res); + return res; + } + }; +} diff --git a/include/erebos/sync.h b/include/erebos/sync.h new file mode 100644 index 0000000..662a558 --- /dev/null +++ b/include/erebos/sync.h @@ -0,0 +1,32 @@ +#pragma once + +#include <erebos/service.h> +#include <erebos/state.h> +#include <erebos/storage.h> + +#include <optional> +#include <mutex> +#include <vector> + +namespace erebos { + +using std::vector; + +class SyncService : public Service +{ +public: + SyncService(Config &&, const Server &); + virtual ~SyncService(); + + UUID uuid() const override; + void handle(Context &) override; + +private: + void peerWatcher(size_t, const class Peer *); + void localStateWatcher(const vector<Ref> &); + + const Server & server; + Watched<vector<Ref>> watchedLocal; +}; + +} diff --git a/include/erebos/time.h b/include/erebos/time.h new file mode 100644 index 0000000..d8ff5b1 --- /dev/null +++ b/include/erebos/time.h @@ -0,0 +1,20 @@ +#pragma once + +#include <chrono> +#include <string> + +namespace erebos { + +struct ZonedTime +{ + explicit ZonedTime(std::string); + ZonedTime(std::chrono::system_clock::time_point t): time(t), zone(0) {} + explicit operator std::string() const; + + static ZonedTime now(); + + std::chrono::system_clock::time_point time; + std::chrono::minutes zone; // zone offset +}; + +} diff --git a/include/erebos/uuid.h b/include/erebos/uuid.h new file mode 100644 index 0000000..d6ccf50 --- /dev/null +++ b/include/erebos/uuid.h @@ -0,0 +1,41 @@ +#pragma once + +#include <array> +#include <cstdint> +#include <cstring> +#include <optional> +#include <string> + +namespace erebos { + +struct UUID +{ + UUID(): uuid({}) {} + explicit UUID(const std::string &); + explicit operator std::string() const; + + static std::optional<UUID> fromString(const std::string &); + static bool fromString(const std::string &, UUID &); + + static UUID generate(); + + bool operator==(const UUID &) const; + bool operator!=(const UUID &) const; + + std::array<uint8_t, 16> uuid; +}; + +} + +namespace std +{ + template<> struct hash<erebos::UUID> + { + std::size_t operator()(const erebos::UUID & uuid) const noexcept + { + std::size_t res; + std::memcpy(&res, uuid.uuid.data(), sizeof res); + return res; + } + }; +} diff --git a/minici.yaml b/minici.yaml new file mode 100644 index 0000000..4a6d65d --- /dev/null +++ b/minici.yaml @@ -0,0 +1,15 @@ +job build: + shell: + - make CC=gcc CFLAGS="-Werror -Wno-deprecated-declarations" CXXFLAGS="-Werror -Wno-deprecated-declarations" + artifact erebos: + path: build/src/erebos + +job clang: + shell: + - make CC=clang CFLAGS="-Werror -Wno-deprecated-declarations" CXXFLAGS="-Werror -Wno-deprecated-declarations" + +job test: + uses: + - build.erebos + shell: + - erebos-tester -v diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt new file mode 100644 index 0000000..fff6242 --- /dev/null +++ b/src/CMakeLists.txt @@ -0,0 +1,43 @@ +include_directories( + ../include +) + +add_library(erebos + attach.cpp + contact.cpp + frp.cpp + identity.cpp + merge.cpp + message.cpp + network.cpp + network/channel.cpp + network/protocol.cpp + pairing.cpp + pubkey.cpp + service.cpp + set.cpp + state.cpp + storage.cpp + sync.cpp + time.cpp + uuid.cpp +) + +if (NOT CMAKE_SYSTEM_NAME STREQUAL "Android") + add_executable(erebos-bin + main.cpp + ) + + set_target_properties(erebos-bin + PROPERTIES OUTPUT_NAME erebos + ) + + target_link_libraries(erebos-bin + erebos + stdc++fs + Threads::Threads + ${ZLIB_LIBRARIES} + ${OPENSSL_LIBRARIES} + ${B2_LIBRARY} + ) +endif() diff --git a/src/attach.cpp b/src/attach.cpp new file mode 100644 index 0000000..887de38 --- /dev/null +++ b/src/attach.cpp @@ -0,0 +1,130 @@ +#include <erebos/attach.h> + +#include "identity.h" +#include "pubkey.h" + +#include <erebos/network.h> + +#include <stdexcept> + +using namespace erebos; +using std::lock_guard; +using std::nullopt; +using std::runtime_error; + +static const UUID myUUID("4995a5f9-2d4d-48e9-ad3b-0bf1c2a1be7f"); + +AttachService::AttachService(Config && config, const Server &): + PairingService(move(config)) +{ +} + +AttachService::~AttachService() = default; + +UUID AttachService::uuid() const +{ + return myUUID; +} + +void AttachService::attachTo(const Peer & peer) +{ + requestPairing(myUUID, peer); +} + +Stored<AttachIdentity> AttachService::handlePairingComplete(const Peer & peer) +{ + auto owner = peer.server().identity().finalOwner(); + auto pid = peer.identity(); + + auto idata = peer.tempStorage().store(IdentityData { + .prev = pid->data(), + .name = nullopt, + .owner = owner.data()[0], + .keyIdentity = pid->keyIdentity(), + .keyMessage = nullopt, + }); + + auto key = SecretKey::load(owner.keyIdentity()); + if (!key) + throw runtime_error("failed to load secret key"); + + auto mkey = SecretKey::load(owner.keyMessage()); + if (!mkey) + throw runtime_error("failed to load secret key"); + + auto sdata = key->sign(idata); + + return peer.tempStorage().store(AttachIdentity { + .identity = sdata, + .keys = { key->getData(), mkey->getData() }, + }); +} + +void AttachService::handlePairingResult(Context & ctx, Stored<AttachIdentity> att) +{ + if (att->identity->data->prev.size() != 1 || + att->identity->data->prev[0].ref().digest() != + ctx.local()->identity()->ref()->digest()) + return; + + if (att->identity->data->keyIdentity.ref().digest() != + ctx.local()->identity()->keyIdentity().ref().digest()) + return; + + auto key = SecretKey::load(ctx.peer().server().identity().keyIdentity()); + if (!key) + throw runtime_error("failed to load secret key"); + + vector<StoredIdentityPart> parts = ctx.local()->identity()->extData(); + parts.emplace_back(key->signAdd(att->identity)); + filterAncestors(parts); + + auto id = Identity::load(parts); + if (!id) + printf("New identity validation failed\n"); + + optional<Ref> tmpref = id->extRef(); + if (not tmpref) + tmpref = id->modify().commit().extRef(); + + auto rid = ctx.local().ref().storage().copy(*tmpref); + id = Identity::load(rid); + + auto owner = id->owner(); + if (!owner) + printf("New identity without owner\n"); + + // Store the keys + for (const auto & k : att->keys) { + SecretKey::fromData(owner->keyIdentity(), k); + SecretKey::fromData(owner->keyMessage(), k); + } + + ctx.local(ctx.local()->identity(*id)); +} + +AttachIdentity AttachIdentity::load(const Ref & ref) +{ + auto rec = ref->asRecord(); + if (!rec) + return AttachIdentity { + .identity = Stored<Signed<IdentityData>>::load(ref.storage().zref()), + .keys = {}, + }; + + return AttachIdentity { + .identity = *rec->item("identity").as<Signed<IdentityData>>(), + .keys = rec->items("skey").asBinary(), + }; +} + +Ref AttachIdentity::store(const Storage & st) const +{ + vector<Record::Item> items; + + items.emplace_back("identity", identity.ref()); + for (const auto & key : keys) + items.emplace_back("skey", key); + + return st.storeObject(Record(std::move(items))); +} diff --git a/src/contact.cpp b/src/contact.cpp new file mode 100644 index 0000000..9ab5699 --- /dev/null +++ b/src/contact.cpp @@ -0,0 +1,207 @@ +#include "contact.h" + +#include "identity.h" + +#include <array> + +using namespace erebos; + +using std::array; +using std::move; + +DEFINE_SHARED_TYPE(Set<Contact>, + "34fbb61e-6022-405f-b1b3-a5a1abecd25e", + &Set<Contact>::load, + [](const Set<Contact> & set) { + return set.store(); + }) + +static const UUID serviceUUID("d9c37368-0da1-4280-93e9-d9bd9a198084"); + +Contact::Contact(vector<Stored<ContactData>> data): + p(shared_ptr<Priv>(new Priv { + .data = data, + })) +{ +} + +optional<Identity> Contact::identity() const +{ + p->init(); + return p->identity; +} + +optional<string> Contact::customName() const +{ + p->init(); + return p->name; +} + +Contact Contact::customName(const Storage & st, const string & name) const +{ + auto cdata = st.store(ContactData { + .prev = p->data, + .identity = {}, + .name = name, + }); + + return Contact(shared_ptr<Contact::Priv>(new Contact::Priv { + .data = { cdata }, + })); +} + +string Contact::name() const +{ + if (auto cust = customName()) + return *cust; + if (auto id = p->identity) + if (auto idname = id->name()) + return *idname; + return ""; +} + +bool Contact::operator==(const Contact & other) const +{ + return p->data == other.p->data; +} + +bool Contact::operator!=(const Contact & other) const +{ + return p->data != other.p->data; +} + +vector<Stored<ContactData>> Contact::data() const +{ + return p->data; +} + +Digest Contact::leastRoot() const +{ + if (p->data.empty()) + return Digest(array<uint8_t, Digest::size> {}); + + vector<Digest> roots; + for (const auto & d : p->data) + for (const auto & r : d.ref().roots()) + roots.push_back(r); + roots.erase(std::unique(roots.begin(), roots.end()), roots.end()); + return roots[0]; +} + +void Contact::Priv::init() +{ + std::call_once(initFlag, [this]() { + vector<StoredIdentityPart> idata; + for (const auto & c : findPropertyComponents<Contact>(data, "identity")) + for (const auto & i : c->identity) + idata.push_back(i); + + identity = Identity::load(idata); + if (identity) + name = identity->name(); + + if (auto opt = findPropertyComponent<Contact>(data, "name")) + name = (*opt)->name; + }); +} + +ContactData ContactData::load(const Ref & ref) +{ + auto rec = ref->asRecord(); + if (!rec) + return ContactData(); + + vector<StoredIdentityPart> identity; + for (const auto & r : rec->items("identity").asRef()) + identity.push_back(StoredIdentityPart::load(r)); + + return ContactData { + .prev = rec->items("PREV").as<ContactData>(), + .identity = move(identity), + .name = rec->item("name").asText(), + }; +} + +Ref ContactData::store(const Storage & st) const +{ + vector<Record::Item> items; + + for (const auto & prev : prev) + items.emplace_back("PREV", prev.ref()); + for (const auto & idt : identity) + items.emplace_back("identity", idt.ref()); + if (name) + items.emplace_back("name", *name); + + return st.storeObject(Record(std::move(items))); +} + +ContactService::ContactService(Config && config, const Server & s): + PairingService(move(config)), + server(s) +{ +} + +ContactService::~ContactService() = default; + +UUID ContactService::uuid() const +{ + return serviceUUID; +} + +void ContactService::request(const Peer & peer) +{ + requestPairing(serviceUUID, peer); +} + +Stored<ContactAccepted> ContactService::handlePairingComplete(const Peer & peer) +{ + server.localHead().update([&] (const Stored<LocalState> & local) { + auto cdata = local.ref().storage().store(ContactData { + .prev = {}, + .identity = peer.identity()->finalOwner().extData(), + .name = std::nullopt, + }); + + Contact contact(shared_ptr<Contact::Priv>(new Contact::Priv { + .data = { cdata }, + })); + + auto contacts = local->shared<Set<Contact>>(); + + return local.ref().storage().store(local->shared<Set<Contact>>( + contacts.add(local.ref().storage(), contact))); + }); + + return peer.tempStorage().store(ContactAccepted {}); +} + +void ContactService::handlePairingResult(Context & ctx, Stored<ContactAccepted>) +{ + auto cdata = ctx.local().ref().storage().store(ContactData { + .prev = {}, + .identity = ctx.peer().identity()->finalOwner().extData(), + .name = std::nullopt, + }); + + Contact contact(shared_ptr<Contact::Priv>(new Contact::Priv { + .data = { cdata }, + })); + + auto contacts = ctx.local()->shared<Set<Contact>>(); + + ctx.local(ctx.local()->shared<Set<Contact>>( + contacts.add(ctx.local().ref().storage(), contact))); +} + +ContactAccepted ContactAccepted::load(const Ref &) +{ + return ContactAccepted {}; +} + +Ref ContactAccepted::store(const Storage & st) const +{ + vector<Record::Item> items; + items.emplace_back("accept", ""); + return st.storeObject(Record(std::move(items))); +} diff --git a/src/contact.h b/src/contact.h new file mode 100644 index 0000000..6fc0219 --- /dev/null +++ b/src/contact.h @@ -0,0 +1,30 @@ +#pragma once + +#include <erebos/contact.h> + +#include <mutex> +#include <optional> +#include <string> +#include <vector> + +namespace erebos { + +using std::optional; +using std::string; +using std::vector; + +struct ContactData; +struct IdentityData; + +struct Contact::Priv +{ + vector<Stored<ContactData>> data; + + void init(); + std::once_flag initFlag {}; + + optional<Identity> identity {}; + optional<string> name {}; +}; + +} diff --git a/src/frp.cpp b/src/frp.cpp new file mode 100644 index 0000000..eba104d --- /dev/null +++ b/src/frp.cpp @@ -0,0 +1,154 @@ +#include <erebos/frp.h> + +#include <condition_variable> +#include <mutex> +#include <thread> + +using namespace erebos; + +using std::condition_variable; +using std::move; +using std::mutex; +using std::nullopt; +using std::thread; +using std::unique_lock; + +namespace { + +mutex bhvTimeMutex; +condition_variable bhvTimeCond; +optional<thread::id> bhvTimeRunning = nullopt; +uint64_t bhvTimeCount = 0; +uint64_t bhvTimeLast = 0; + +} + +BhvTime::BhvTime(const BhvCurTime & ct): + BhvTime(ct.time()) +{} + +BhvCurTime::BhvCurTime() +{ + auto tid = std::this_thread::get_id(); + unique_lock lock(bhvTimeMutex); + bhvTimeCond.wait(lock, [tid]{ + return !bhvTimeRunning || bhvTimeRunning == tid; + }); + + if (bhvTimeRunning != tid) { + bhvTimeRunning = tid; + bhvTimeLast++; + } + t = BhvTime(bhvTimeLast); + bhvTimeCount++; +} + +BhvCurTime::~BhvCurTime() +{ + if (t) { + unique_lock lock(bhvTimeMutex); + bhvTimeCount--; + + if (bhvTimeCount == 0) { + bhvTimeRunning.reset(); + lock.unlock(); + bhvTimeCond.notify_one(); + } + } +} + +BhvCurTime::BhvCurTime(BhvCurTime && other) +{ + t = other.t; + other.t = nullopt; +} + +BhvCurTime & BhvCurTime::operator=(BhvCurTime && other) +{ + t = other.t; + other.t = nullopt; + return *this; +} + + +BhvImplBase::~BhvImplBase() = default; + +void BhvImplBase::dependsOn(const BhvCurTime &, shared_ptr<BhvImplBase> other) +{ + depends.push_back(other); + other->rdepends.push_back(shared_from_this()); +} + +void BhvImplBase::updated(const BhvCurTime & ctime) +{ + vector<shared_ptr<BhvImplBase>> toUpdate; + markDirty(ctime, toUpdate); + + for (auto & bhv : toUpdate) + bhv->updateDirty(ctime); +} + +void BhvImplBase::markDirty(const BhvCurTime & ctime, vector<shared_ptr<BhvImplBase>> & toUpdate) +{ + if (dirty) + return; + + if (!needsUpdate(ctime)) + return; + + dirty = true; + toUpdate.push_back(shared_from_this()); + + bool prune = false; + for (const auto & w : rdepends) { + if (auto b = w.lock()) + b->markDirty(ctime, toUpdate); + else + prune = true; + } + + if (prune) { + decltype(rdepends) pruned; + for (const auto & w : rdepends) + if (!w.expired()) + pruned.push_back(move(w)); + rdepends = move(pruned); + } +} + +void BhvImplBase::updateDirty(const BhvCurTime & ctime) +{ + if (!dirty) + return; + + for (auto & d : depends) + d->updateDirty(ctime); + + doUpdate(ctime); + dirty = false; + + bool prune = false; + for (const auto & wcb : watchers) { + if (auto cb = wcb.lock()) + (*cb)(ctime); + else + prune = true; + } + + if (prune) { + decltype(watchers) pruned; + for (const auto & w : watchers) + if (!w.expired()) + pruned.push_back(move(w)); + watchers = move(pruned); + } +} + +bool BhvImplBase::needsUpdate(const BhvCurTime &) const +{ + return true; +} + +void BhvImplBase::doUpdate(const BhvCurTime &) +{ +} diff --git a/src/identity.cpp b/src/identity.cpp new file mode 100644 index 0000000..8b6ee2a --- /dev/null +++ b/src/identity.cpp @@ -0,0 +1,616 @@ +#include "identity.h" + +#include <erebos/state.h> + +#include <algorithm> +#include <set> +#include <stdexcept> + +using namespace erebos; + +using std::async; +using std::holds_alternative; +using std::nullopt; +using std::runtime_error; +using std::set; +using std::visit; + +template<class> +inline constexpr bool always_false_v = false; + +DEFINE_SHARED_TYPE(optional<Identity>, + "0c6c1fe0-f2d7-4891-926b-c332449f7871", + &Identity::load, + [](const optional<Identity> & id) { + if (id) + return id->store(); + return vector<Ref>(); + }) + +Identity::Identity(const Priv * p): p(p) {} +Identity::Identity(shared_ptr<const Priv> && p): p(std::move(p)) {} + +optional<Identity> Identity::load(const Ref & ref) +{ + return Identity::load(vector { ref }); +} + +optional<Identity> Identity::load(const vector<Ref> & refs) +{ + vector<StoredIdentityPart> data; + data.reserve(refs.size()); + + for (const auto & ref : refs) + data.push_back(StoredIdentityPart::load(ref)); + + return load(data); +} + +optional<Identity> Identity::load(const vector<Stored<Signed<IdentityData>>> & data) +{ + vector<StoredIdentityPart> parts; + parts.reserve(data.size()); + + for (const auto & d : data) + parts.emplace_back(d); + + return load(parts); +} + +optional<Identity> Identity::load(const vector<StoredIdentityPart> & data) +{ + if (auto ptr = Priv::validate(data)) + return Identity(ptr); + return nullopt; +} + +vector<Ref> Identity::store() const +{ + vector<Ref> res; + res.reserve(p->data.size()); + for (const auto & x : p->data) + res.push_back(x.ref()); + return res; +} + +vector<Ref> Identity::store(const Storage & st) const +{ + vector<Ref> res; + res.reserve(p->data.size()); + for (const auto & x : p->data) + res.push_back(x.store(st)); + return res; +} + +vector<Stored<Signed<IdentityData>>> Identity::data() const +{ + vector<Stored<Signed<IdentityData>>> base; + base.reserve(p->data.size()); + + for (const auto & d : p->data) + base.push_back(d.base()); + filterAncestors(base); + return base; +} + +vector<StoredIdentityPart> Identity::extData() const +{ + return p->data; +} + +optional<string> Identity::name() const +{ + return p->name.get(); +} + +optional<Identity> Identity::owner() const +{ + return p->owner; +} + +const Identity & Identity::finalOwner() const +{ + if (p->owner) + return p->owner->finalOwner(); + return *this; +} + +Stored<PublicKey> Identity::keyIdentity() const +{ + return p->data[0].base()->data->keyIdentity; +} + +Stored<PublicKey> Identity::keyMessage() const +{ + return p->keyMessage; +} + +bool Identity::sameAs(const Identity & other) const +{ + // TODO: proper identity check + return p->data[0].base()->data->keyIdentity == + other.p->data[0].base()->data->keyIdentity; +} + +bool Identity::operator==(const Identity & other) const +{ + return p->data == other.p->data && + p->updates == other.p->updates; +} + +bool Identity::operator!=(const Identity & other) const +{ + return !(*this == other); +} + +optional<Ref> Identity::ref() const +{ + if (p->data.size() == 1) + return p->data[0].base().ref(); + return nullopt; +} + +optional<Ref> Identity::extRef() const +{ + if (p->data.size() == 1) + return p->data[0].ref(); + return nullopt; +} + +vector<Ref> Identity::refs() const +{ + auto base = data(); + vector<Ref> res; + res.reserve(base.size()); + + for (const auto & d : base) + res.push_back(d.ref()); + return res; +} + +vector<Ref> Identity::extRefs() const +{ + vector<Ref> res; + res.reserve(p->data.size()); + for (const auto & idata : p->data) + res.push_back(idata.ref()); + return res; +} + +vector<Ref> Identity::updates() const +{ + vector<Ref> res; + res.reserve(p->updates.size()); + for (const auto & idata : p->updates) + res.push_back(idata.ref()); + return res; +} + +Identity::Builder Identity::create(const Storage & st) +{ + return Builder (new Builder::Priv { + .storage = st, + .keyIdentity = SecretKey::generate(st).pub(), + .keyMessage = SecretKey::generate(st).pub(), + }); +} + +Identity::Builder Identity::modify() const +{ + vector<Stored<Signed<IdentityData>>> prevBase; + vector<StoredIdentityPart> prevExt; + + prevBase.reserve(p->data.size()); + for (const auto & d : p->data) { + prevBase.push_back(d.base()); + if (holds_alternative<Stored<Signed<IdentityExtension>>>(d.part)) + prevExt.push_back(d); + } + + filterAncestors(prevBase); + + return Builder (new Builder::Priv { + .storage = p->data[0].ref().storage(), + .prevBase = move(prevBase), + .prevExt = move(prevExt), + .keyIdentity = p->data[0].base()->data->keyIdentity, + .keyMessage = p->data[0].base()->data->keyMessage, + }); +} + +Identity Identity::update(const vector<Stored<Signed<IdentityData>>> & updates) const +{ + vector<StoredIdentityPart> eupdates; + eupdates.reserve(updates.size()); + for (const auto & u : updates) + eupdates.emplace_back(u); + return update(eupdates); +} + +static bool intersectsRoots(const vector<Digest> & x, const vector<Digest> & y) +{ + for (size_t i = 0, j = 0; + i < x.size() && j < y.size(); ) { + if (x[i] == y[j]) + return true; + if (x[i] < y[j]) + i++; + else + j++; + } + return false; +} + +Identity Identity::update(const vector<StoredIdentityPart> & updates) const +{ + vector<StoredIdentityPart> ndata = p->data; + vector<StoredIdentityPart> ownerUpdates = p->updates; + + for (const auto & u : updates) { + bool isOur = false; + for (const auto & d : p->data) { + if (intersectsRoots(u.roots(), d.roots())) { + isOur = true; + break; + } + } + + if (isOur) + ndata.emplace_back(u); + else + ownerUpdates.emplace_back(u); + } + + filterAncestors(ndata); + filterAncestors(ownerUpdates); + + if (auto p = Priv::validate(ndata)) { + p->updates = move(ownerUpdates); + if (p->owner && !p->updates.empty()) + p->owner = p->owner->update(p->updates); + return Identity(move(p)); + } + + return *this; +} + + +Identity::Builder::Builder(Priv * p): p(p) {} + +Identity Identity::Builder::commit() const +{ + optional<Stored<Signed<IdentityData>>> ownerBaseData; + optional<StoredIdentityPart> ownerExtData; + if (p->owner && p->owner->p->data.size() == 1) { + ownerExtData = p->owner->p->data[0]; + ownerBaseData = ownerExtData->base(); + if (holds_alternative<Stored<Signed<IdentityData>>>(ownerExtData->part)) + ownerExtData.reset(); + } + + auto base = p->storage.store(IdentityData { + .prev = p->prevBase, + .owner = ownerBaseData, + .keyIdentity = p->keyIdentity, + .keyMessage = p->keyMessage, + }); + + auto key = SecretKey::load(p->keyIdentity); + if (!key) + throw runtime_error("failed to load secret key"); + + auto sbase = key->sign(base); + if (base->owner) { + if (auto okey = SecretKey::load((*base->owner)->data->keyIdentity)) + sbase = okey->signAdd(sbase); + else + throw runtime_error("failed to load secret key"); + } + + optional<StoredIdentityPart> spart; + + if (not p->prevExt.empty() || p->name || ownerExtData) { + auto ext = p->storage.store(IdentityExtension { + .base = sbase, + .prev = p->prevExt, + .name = p->name, + .owner = ownerExtData, + }); + + auto sext = key->sign(ext); + if (ext->owner) { + if (auto okey = SecretKey::load(p->owner->keyIdentity())) + sext = okey->signAdd(sext); + else + throw runtime_error("failed to load secret key"); + } + + spart.emplace(sext); + } else { + spart.emplace(sbase); + } + + auto p = Identity::Priv::validate({ *spart }); + if (!p) + throw runtime_error("failed to validate committed identity"); + + return Identity(std::move(p)); +} + +void Identity::Builder::name(const string & val) +{ + p->name = val; +} + +void Identity::Builder::owner(const Identity & val) +{ + p->owner.emplace(val); +} + +IdentityData IdentityData::load(const Ref & ref) +{ + if (auto rec = ref->asRecord()) { + if (auto keyIdentity = rec->item("key-id").as<PublicKey>()) + return IdentityData { + .prev = rec->items("SPREV").as<Signed<IdentityData>>(), + .name = rec->item("name").asText(), + .owner = rec->item("owner").as<Signed<IdentityData>>(), + .keyIdentity = keyIdentity.value(), + .keyMessage = rec->item("key-msg").as<PublicKey>(), + }; + } + + return IdentityData { + .prev = {}, + .name = nullopt, + .owner = nullopt, + .keyIdentity = Stored<PublicKey>::load(ref.storage().zref()), + .keyMessage = nullopt, + }; +} + +Ref IdentityData::store(const Storage & st) const +{ + vector<Record::Item> items; + + for (const auto & p : prev) + items.emplace_back("SPREV", p.ref()); + if (name) + items.emplace_back("name", *name); + if (owner) + items.emplace_back("owner", owner->ref()); + items.emplace_back("key-id", keyIdentity.ref()); + if (keyMessage) + items.emplace_back("key-msg", keyMessage->ref()); + + return st.storeObject(Record(std::move(items))); +} + +IdentityExtension IdentityExtension::load(const Ref & ref) +{ + if (auto rec = ref->asRecord()) { + if (auto base = rec->item("SBASE").as<Signed<IdentityData>>()) { + vector<StoredIdentityPart> prev; + for (const auto & r : rec->items("SPREV").asRef()) + prev.push_back(StoredIdentityPart::load(r)); + + auto ownerRef = rec->item("owner").asRef(); + return IdentityExtension { + .base = *base, + .prev = move(prev), + .name = rec->item("name").asText(), + .owner = ownerRef ? optional(StoredIdentityPart::load(*ownerRef)) : nullopt, + }; + } + } + + return IdentityExtension { + .base = Stored<Signed<IdentityData>>::load(ref.storage().zref()), + .prev = {}, + .name = nullopt, + .owner = nullopt, + }; +} + +Ref IdentityExtension::store(const Storage & st) const +{ + vector<Record::Item> items; + + items.emplace_back("SBASE", base); + for (const auto & p : prev) + items.emplace_back("SPREV", p.ref()); + if (name) + items.emplace_back("name", *name); + if (owner) + items.emplace_back("owner", owner->ref()); + + return st.storeObject(Record(std::move(items))); +} + +StoredIdentityPart StoredIdentityPart::load(const Ref & ref) +{ + if (auto srec = ref->asRecord()) { + if (auto sref = srec->item("SDATA").asRef()) { + if (auto rec = (*sref)->asRecord()) { + if (rec->item("SBASE")) { + return StoredIdentityPart(Stored<Signed<IdentityExtension>>::load(ref)); + } + } + } + } + + return StoredIdentityPart(Stored<Signed<IdentityData>>::load(ref)); +} + +Ref StoredIdentityPart::store(const Storage & st) const +{ + return visit([&](auto && p) { + return p.store(st); + }, part); +} + +const Ref & StoredIdentityPart::ref() const +{ + return visit([&](auto && p) -> auto const & { + return p.ref(); + }, part); +} + +const Stored<Signed<IdentityData>> & StoredIdentityPart::base() const +{ + return visit([&](auto && p) -> auto const & { + using T = std::decay_t<decltype(p)>; + if constexpr (std::is_same_v<T, Stored<Signed<IdentityData>>>) + return p; + else if constexpr (std::is_same_v<T, Stored<Signed<IdentityExtension>>>) + return p->data->base; + else + static_assert(always_false_v<T>, "non-exhaustive visitor!"); + }, part); +} + +vector<StoredIdentityPart> StoredIdentityPart::previous() const +{ + return visit([&](auto && p) { + using T = std::decay_t<decltype(p)>; + if constexpr (std::is_same_v<T, Stored<Signed<IdentityData>>>) { + vector<StoredIdentityPart> res; + res.reserve(p->data->prev.size()); + for (const auto & x : p->data->prev) + res.emplace_back(x); + return res; + + } else if constexpr (std::is_same_v<T, Stored<Signed<IdentityExtension>>>) { + vector<StoredIdentityPart> res; + res.reserve(1 + p->data->prev.size()); + res.emplace_back(p->data->base); + for (const auto & x : p->data->prev) + res.push_back(x); + return res; + + } else { + static_assert(always_false_v<T>, "non-exhaustive visitor!"); + } + }, part); +} + +vector<Digest> StoredIdentityPart::roots() const +{ + return visit([&](auto && p) { + return p.roots(); + }, part); +} + +optional<string> StoredIdentityPart::name() const +{ + return visit([&](auto && p) { + return p->data->name; + }, part); +} + +optional<StoredIdentityPart> StoredIdentityPart::owner() const +{ + return visit([&](auto && p) -> optional<StoredIdentityPart> { + if (p->data->owner) + return StoredIdentityPart(p->data->owner.value()); + return nullopt; + }, part); +} + +bool StoredIdentityPart::isSignedBy(const Stored<PublicKey> & key) const +{ + return visit([&](auto && p) { + return p->isSignedBy(key); + }, part); +} + + +bool Identity::Priv::verifySignatures(const StoredIdentityPart & sdata) +{ + if (!sdata.isSignedBy(sdata.base()->data->keyIdentity)) + return false; + + for (const auto & p : sdata.previous()) + if (!sdata.isSignedBy(p.base()->data->keyIdentity)) + return false; + + if (auto owner = sdata.owner()) + if (!sdata.isSignedBy(owner->base()->data->keyIdentity)) + return false; + + for (const auto & p : sdata.previous()) + if (!verifySignatures(p)) + return false; + + return true; +} + +shared_ptr<Identity::Priv> Identity::Priv::validate(const vector<StoredIdentityPart> & sdata) +{ + for (const auto & d : sdata) + if (!verifySignatures(d)) + return nullptr; + + auto keyMessageItem = lookupProperty(sdata, [] + (const StoredIdentityPart & d) { return d.base()->data->keyMessage.has_value(); }); + if (!keyMessageItem) + return nullptr; + + auto p = new Priv { + .data = sdata, + .updates = {}, + .name = {}, + .owner = nullopt, + .keyMessage = keyMessageItem->base()->data->keyMessage.value(), + }; + shared_ptr<Priv> ret(p); + + auto ownerProp = lookupProperty(sdata, [] + (const StoredIdentityPart & d) { return d.owner().has_value(); }); + if (ownerProp) { + auto owner = validate({ ownerProp->owner().value() }); + if (!owner) + return nullptr; + p->owner.emplace(Identity(owner)); + } + + p->name = async(std::launch::deferred, [p] () -> optional<string> { + if (auto d = lookupProperty(p->data, [] (const StoredIdentityPart & d) { return d.name().has_value(); })) + return d->name(); + return nullopt; + }); + + return ret; +} + +optional<StoredIdentityPart> Identity::Priv::lookupProperty( + const vector<StoredIdentityPart> & data, + function<bool(const StoredIdentityPart &)> sel) +{ + set<StoredIdentityPart> current, prop_heads; + + for (const auto & d : data) + current.insert(d); + + while (!current.empty()) { + StoredIdentityPart sdata = + current.extract(current.begin()).value(); + + if (sel(sdata)) + prop_heads.insert(sdata); + else + for (const auto & p : sdata.previous()) + current.insert(p); + } + + for (auto x = prop_heads.begin(); x != prop_heads.end(); x++) + for (auto y = prop_heads.begin(); y != prop_heads.end();) + if (y != x && precedes(*y, *x)) + y = prop_heads.erase(y); + else + y++; + + if (prop_heads.begin() != prop_heads.end()) + return *prop_heads.begin(); + return nullopt; +} diff --git a/src/identity.h b/src/identity.h new file mode 100644 index 0000000..bfa5932 --- /dev/null +++ b/src/identity.h @@ -0,0 +1,67 @@ +#pragma once + +#include <erebos/identity.h> +#include "pubkey.h" + +#include <future> +#include <variant> + +using std::function; +using std::optional; +using std::shared_future; +using std::string; +using std::variant; +using std::vector; + +namespace erebos { + +struct IdentityData +{ + static IdentityData load(const Ref &); + Ref store(const Storage & st) const; + + const vector<Stored<Signed<IdentityData>>> prev {}; + const optional<string> name {}; + const optional<Stored<Signed<IdentityData>>> owner {}; + const Stored<PublicKey> keyIdentity; + const optional<Stored<PublicKey>> keyMessage; +}; + +struct IdentityExtension +{ + static IdentityExtension load(const Ref &); + Ref store(const Storage & st) const; + + const Stored<Signed<IdentityData>> base; + const vector<StoredIdentityPart> prev {}; + const optional<string> name {}; + const optional<StoredIdentityPart> owner {}; +}; + +struct Identity::Priv +{ + vector<StoredIdentityPart> data; + vector<StoredIdentityPart> updates; + shared_future<optional<string>> name; + optional<Identity> owner; + Stored<PublicKey> keyMessage; + + static bool verifySignatures(const StoredIdentityPart & sdata); + static shared_ptr<Priv> validate(const vector<StoredIdentityPart> & sdata); + static optional<StoredIdentityPart> lookupProperty( + const vector<StoredIdentityPart> & data, + function<bool(const StoredIdentityPart &)> sel); +}; + +struct Identity::Builder::Priv +{ + Storage storage; + vector<Stored<Signed<IdentityData>>> prevBase = {}; + vector<StoredIdentityPart> prevExt = {}; + optional<string> name = nullopt; + optional<Identity> owner = nullopt; + Stored<PublicKey> keyIdentity; + optional<Stored<PublicKey>> keyMessage; +}; + +} diff --git a/src/main.cpp b/src/main.cpp new file mode 100644 index 0000000..a0a9458 --- /dev/null +++ b/src/main.cpp @@ -0,0 +1,693 @@ +#include <erebos/attach.h> +#include <erebos/contact.h> +#include <erebos/identity.h> +#include <erebos/message.h> +#include <erebos/network.h> +#include <erebos/set.h> +#include <erebos/storage.h> +#include <erebos/sync.h> + +#include "storage.h" + +#include <arpa/inet.h> +#include <netinet/in.h> +#include <sys/socket.h> + +#include <filesystem> +#include <functional> +#include <future> +#include <iostream> +#include <map> +#include <memory> +#include <mutex> +#include <optional> +#include <sstream> +#include <stdexcept> +#include <string> +#include <thread> +#include <vector> + +using std::cerr; +using std::cout; +using std::endl; +using std::function; +using std::future; +using std::invalid_argument; +using std::make_unique; +using std::map; +using std::mutex; +using std::optional; +using std::ostringstream; +using std::promise; +using std::scoped_lock; +using std::string; +using std::thread; +using std::to_string; +using std::unique_ptr; +using std::vector; + +namespace fs = std::filesystem; + +using namespace erebos; + +namespace { + +fs::path getErebosDir() +{ + const char * value = getenv("EREBOS_DIR"); + if (value) + return value; + return "./.erebos"; +} + +mutex outputMutex; +void printLine(const string & line) +{ + scoped_lock lock(outputMutex); + cout << line << std::endl; +} + +Storage st(getErebosDir()); +optional<Head<LocalState>> testHead; +optional<Server> server; + +struct TestPeer +{ + Peer peer; + size_t id; + bool deleted = false; + promise<bool> pairingAnswer {}; +}; +vector<TestPeer> testPeers; + +TestPeer & getPeer(const string & name) +{ + try { + return testPeers.at(std::stoi(name) - 1); + } + catch (const std::invalid_argument &) {} + + for (auto & p : testPeers) + if (p.peer.name() == name) + return p; + + ostringstream ss; + ss << "peer '" << name << "' not found"; + throw invalid_argument(ss.str().c_str()); +} + +TestPeer & getPeer(const Peer & peer) +{ + for (auto & p : testPeers) + if (p.peer == peer) + return p; + throw invalid_argument("peer not found"); +} + +Contact getContact(const string & id) +{ + auto cmp = [](const Contact & x, const Contact & y) { + return x.data() < y.data(); + }; + for (const auto & c : testHead->behavior().lens<SharedState>().lens<Set<Contact>>().get().view(cmp)) { + if (string(c.leastRoot()) == id) { + return c; + } + } + + ostringstream ss; + ss << "contact '" << id << "' not found"; + throw invalid_argument(ss.str().c_str()); +} + +struct Command +{ + string name; + function<void(const vector<string> &)> action; +}; + +void store(const vector<string> & args) +{ + auto type = args.at(0); + + vector<uint8_t> inner, data; + + char * line = nullptr; + size_t size = 0; + + while (getline(&line, &size, stdin) > 0 && strlen(line) > 1) + copy(line, line + strlen(line), std::back_inserter(inner)); + + free(line); + + auto inserter = std::back_inserter(data); + copy(type.begin(), type.end(), inserter); + inserter = ' '; + + auto slen = to_string(inner.size()); + copy(slen.begin(), slen.end(), inserter); + inserter = '\n'; + + copy(inner.begin(), inner.end(), inserter); + + auto digest = st.priv().storeBytes(data); + + ostringstream ss; + ss << "store-done " << string(digest); + printLine(ss.str()); +} + +void storedGeneration(const vector<string> & args) +{ + auto ref = st.ref(Digest(args.at(0))); + if (!ref) + throw invalid_argument("ref " + args.at(0) + " not found"); + + ostringstream ss; + ss << "stored-generation " << string(ref->digest()) << " " << string(ref->generation()); + printLine(ss.str()); +} + +void storedRoots(const vector<string> & args) +{ + auto ref = st.ref(Digest(args.at(0))); + if (!ref) + throw invalid_argument("ref " + args.at(0) + " not found"); + + ostringstream ss; + ss << "stored-roots " << string(ref->digest()); + for (const auto & dgst : ref->roots()) + ss << " " << string(dgst); + printLine(ss.str()); +} + +void storedSetAdd(const vector<string> & args) +{ + auto iref = st.ref(Digest(args.at(0))); + if (!iref) + throw invalid_argument("ref " + args.at(0) + " not found"); + + auto set = args.size() > 1 ? + Set<vector<Stored<Object>>>::load({ *st.ref(Digest(args.at(1))) }) : + Set<vector<Stored<Object>>>(); + + ostringstream ss; + ss << "stored-set-add"; + for (const auto & d : set.add(st, { Stored<Object>::load(*iref) }).digests()) + ss << " " << string(d); + printLine(ss.str()); +} + +void storedSetList(const vector<string> & args) +{ + auto ref = st.ref(Digest(args.at(0))); + if (!ref) + throw invalid_argument("ref " + args.at(0) + " not found"); + + for (const auto & vec : Set<vector<Stored<Object>>>::load({ *ref }).view(std::less{})) { + ostringstream ss; + ss << "stored-set-item"; + for (const auto & x : vec) + ss << " " << string(x.ref().digest()); + printLine(ss.str()); + } + printLine("stored-set-done"); +} + +void createIdentity(const vector<string> & args) +{ + optional<Identity> identity; + for (ssize_t i = args.size() - 1; i >= 0; i--) { + const auto & name = args[i]; + auto builder = Identity::create(st); + builder.name(name); + if (identity) + builder.owner(*identity); + identity = builder.commit(); + } + + if (identity) { + auto nh = testHead->update([&identity] (const auto & loc) { + auto ret = loc->identity(*identity); + if (identity->owner()) + ret = ret.template shared<optional<Identity>>(identity->finalOwner()); + return st.store(ret); + }); + if (nh) + *testHead = *nh; + } +} + +void printPairingResult(string prefix, Peer peer, future<PairingServiceBase::Outcome> && future) +{ + auto outcome = future.get(); + ostringstream ss; + ss << prefix << + (outcome == PairingServiceBase::Outcome::Success ? "-done " : "-failed ") << + getPeer(peer).id; + switch (outcome) + { + case PairingServiceBase::Outcome::Success: break; + case PairingServiceBase::Outcome::PeerRejected: ss << " rejected"; break; + case PairingServiceBase::Outcome::UserRejected: ss << " user"; break; + case PairingServiceBase::Outcome::UnexpectedMessage: ss << " unexpected"; break; + case PairingServiceBase::Outcome::NonceMismatch: ss << " nonce"; break; + case PairingServiceBase::Outcome::Stale: ss << " stale"; break; + } + printLine(ss.str()); +} + +future<bool> confirmPairing(string prefix, const Peer & peer, string confirm, future<PairingServiceBase::Outcome> && outcome) +{ + thread(printPairingResult, prefix, peer, move(outcome)).detach(); + + promise<bool> promise; + auto input = promise.get_future(); + getPeer(peer).pairingAnswer = move(promise); + + ostringstream ss; + ss << prefix << " " << getPeer(peer).id << " " << confirm; + printLine(ss.str()); + return input; +} + +void startServer(const vector<string> &) +{ + using namespace std::placeholders; + + ServerConfig config; + + config.service<AttachService>() + .onRequest(bind(confirmPairing, "attach-request", _1, _2, _3)) + .onResponse(bind(confirmPairing, "attach-response", _1, _2, _3)) + ; + + config.service<ContactService>() + .onRequest(bind(confirmPairing, "contact-request", _1, _2, _3)) + .onResponse(bind(confirmPairing, "contact-response", _1, _2, _3)) + ; + + config.service<DirectMessageService>() + .onUpdate([](const DirectMessageThread & thread, ssize_t, ssize_t) { + if (thread.at(0).from()->sameAs(server->identity().finalOwner())) + return; + + ostringstream ss; + + string name = "<unnamed>"; + if (auto from = thread.at(0).from()) + if (auto fname = from->name()) + name = *fname; + + ss << "dm-received" + << " from " << name + << " text " << thread.at(0).text() + ; + printLine(ss.str()); + }) + ; + + config.service<SyncService>(); + + server.emplace(*testHead, move(config)); + + server->peerList().onUpdate([](size_t idx, const Peer * peer) { + size_t i = 0; + while (idx > 0 && i < testPeers.size()) { + if (!testPeers[i].deleted) + idx--; + i++; + } + + string prefix = "peer " + to_string(i + 1); + if (peer) { + if (i >= testPeers.size()) { + testPeers.push_back(TestPeer { .peer = *peer, .id = i + 1 }); + + ostringstream ss; + ss << prefix << " addr " << peer->addressStr() << " " << peer->port(); + printLine(ss.str()); + } + + if (peer->identity()) { + ostringstream ss; + ss << prefix << " id"; + for (auto idt = peer->identity(); idt; idt = idt->owner()) + ss << " " << (idt->name() ? *idt->name() : "<unnamed>"); + printLine(ss.str()); + } + } else { + testPeers[i].deleted = true; + printLine(prefix + " deleted"); + } + }); +} + +void stopServer(const vector<string> &) +{ + server.reset(); + testPeers.clear(); + printLine("stop-server-done"); +} + +void peerAdd(const vector<string> & args) +{ + if (args.size() == 1) + server->addPeer(args.at(0)); + else if (args.size() == 2) + server->addPeer(args.at(0), args.at(1)); + else + throw invalid_argument("usage: peer-add <node> [<port>]"); +} + +void sharedStateGet(const vector<string> &) +{ + ostringstream ss; + ss << "shared-state-get"; + for (const auto & r : testHead->behavior().lens<vector<Ref>>().get()) + ss << " " << string(r.digest()); + printLine(ss.str()); +} + +void sharedStateWait(const vector<string> & args) +{ + struct SharedStateWait + { + mutex lock; + bool done { false }; + optional<Watched<vector<Ref>>> watched; + }; + auto watchedPtr = make_shared<SharedStateWait>(); + + auto watched = testHead->behavior().lens<vector<Ref>>().watch([args, watchedPtr] (const vector<Ref> & refs) { + vector<Stored<Object>> objs; + objs.reserve(refs.size()); + for (const auto & r : refs) + objs.push_back(Stored<Object>::load(r)); + + auto objs2 = objs; + for (const auto & a : args) + if (auto ref = st.ref(Digest(a))) + objs2.push_back(Stored<Object>::load(*ref)); + else + return; + + filterAncestors(objs2); + if (objs2 == objs) { + ostringstream ss; + ss << "shared-state-wait"; + for (const auto & a : args) + ss << " " << a; + printLine(ss.str()); + + scoped_lock lock(watchedPtr->lock); + watchedPtr->done = true; + watchedPtr->watched = std::nullopt; + } + }); + + scoped_lock lock(watchedPtr->lock); + if (!watchedPtr->done) + watchedPtr->watched = move(watched); +} + +void watchLocalIdentity(const vector<string> &) +{ + auto bhv = testHead->behavior().lens<optional<Identity>>(); + static auto watchedLocalIdentity = bhv.watch([] (const optional<Identity> & idt) { + if (idt) { + ostringstream ss; + ss << "local-identity"; + for (optional<Identity> i = idt; i; i = i->owner()) { + if (auto name = i->name()) + ss << " " << i->name().value(); + else + ss << " <unnamed>"; + } + printLine(ss.str()); + } + }); +} + +void watchSharedIdentity(const vector<string> &) +{ + auto bhv = testHead->behavior().lens<SharedState>().lens<optional<Identity>>(); + static auto watchedSharedIdentity = bhv.watch([] (const optional<Identity> & idt) { + if (idt) { + ostringstream ss; + ss << "shared-identity"; + for (optional<Identity> i = idt; i; i = i->owner()) + ss << " " << i->name().value(); + printLine(ss.str()); + } + }); +} + +void updateLocalIdentity(const vector<string> & params) +{ + if (params.size() != 1) { + throw invalid_argument("usage: update-local-identity <name>"); + } + + auto nh = testHead->update([¶ms] (const Stored<LocalState> & loc) { + auto st = loc.ref().storage(); + + auto b = loc->identity()->modify(); + b.name(params[0]); + return st.store(loc->identity(b.commit())); + }); + if (nh) + *testHead = *nh; +} + +void updateSharedIdentity(const vector<string> & params) +{ + if (params.size() != 1) { + throw invalid_argument("usage: update-shared-identity <name>"); + } + + auto nh = testHead->update([¶ms] (const Stored<LocalState> & loc) { + auto st = loc.ref().storage(); + auto mbid = loc->shared<optional<Identity>>(); + if (!mbid) + return loc; + + auto b = mbid->modify(); + b.name(params[0]); + return st.store(loc->shared<optional<Identity>>(optional(b.commit()))); + }); + if (nh) + *testHead = *nh; +} + +void attachTo(const vector<string> & params) +{ + server->svc<AttachService>().attachTo(getPeer(params.at(0)).peer); +} + +void attachAccept(const vector<string> & params) +{ + getPeer(params.at(0)).pairingAnswer.set_value(true); +} + +void attachReject(const vector<string> & params) +{ + getPeer(params.at(0)).pairingAnswer.set_value(false); +} + +void contactRequest(const vector<string> & params) +{ + server->svc<ContactService>().request(getPeer(params.at(0)).peer); +} + +void contactAccept(const vector<string> & params) +{ + getPeer(params.at(0)).pairingAnswer.set_value(true); +} + +void contactReject(const vector<string> & params) +{ + getPeer(params.at(0)).pairingAnswer.set_value(false); +} + +void contactList(const vector<string> &) +{ + auto cmp = [](const Contact & x, const Contact & y) { + return x.data() < y.data(); + }; + for (const auto & c : testHead->behavior().lens<SharedState>().lens<Set<Contact>>().get().view(cmp)) { + ostringstream ss; + ss << "contact-list-item " << string(c.leastRoot()) << " " << c.name(); + if (auto id = c.identity()) + if (auto iname = id->name()) + ss << " " << *iname; + printLine(ss.str()); + } + printLine("contact-list-done"); +} + +void contactSetName(const vector<string> & args) +{ + auto id = args.at(0); + auto name = args.at(1); + + auto c = getContact(id); + auto nh = testHead->update([&] (const Stored<LocalState> & loc) { + auto st = loc.ref().storage(); + auto cc = c.customName(st, name); + auto contacts = loc->shared<Set<Contact>>(); + return st.store(loc->shared<Set<Contact>>(contacts.add(st, cc))); + }); + if (nh) + *testHead = *nh; + + printLine("contact-set-name-done"); +} + +void dmSendPeer(const vector<string> & args) +{ + DirectMessageService::send( + *testHead, + getPeer(args.at(0)).peer, + args.at(1)); +} + +void dmSendContact(const vector<string> & args) +{ + DirectMessageService::send( + *testHead, + getContact(args.at(0)), + args.at(1)); +} + +template<class T> +static void dmList(const T & peer) +{ + if (auto id = peer.identity()) + for (const auto & msg : testHead->behavior().get().shared<DirectMessageThreads>().thread(*id)) { + string name = "<unnamed>"; + if (const auto & from = msg.from()) + if (const auto & opt = from->name()) + name = *opt; + + ostringstream ss; + ss << "dm-list-item" + << " from " << name + << " text " << msg.text() + ; + printLine(ss.str()); + } + printLine("dm-list-done"); +} + +void dmListPeer(const vector<string> & args) +{ + dmList(getPeer(args.at(0)).peer); +} + +void dmListContact(const vector<string> & args) +{ + dmList(getContact(args.at(0))); +} + +vector<Command> commands = { + { "store", store }, + { "stored-generation", storedGeneration }, + { "stored-roots", storedRoots }, + { "stored-set-add", storedSetAdd }, + { "stored-set-list", storedSetList }, + { "create-identity", createIdentity }, + { "start-server", startServer }, + { "stop-server", stopServer }, + { "peer-add", peerAdd }, + { "shared-state-get", sharedStateGet }, + { "shared-state-wait", sharedStateWait }, + { "watch-local-identity", watchLocalIdentity }, + { "watch-shared-identity", watchSharedIdentity }, + { "update-local-identity", updateLocalIdentity }, + { "update-shared-identity", updateSharedIdentity }, + { "attach-to", attachTo }, + { "attach-accept", attachAccept }, + { "attach-reject", attachReject }, + { "contact-request", contactRequest }, + { "contact-accept", contactAccept }, + { "contact-reject", contactReject }, + { "contact-list", contactList }, + { "contact-set-name", contactSetName }, + { "dm-send-peer", dmSendPeer }, + { "dm-send-contact", dmSendContact }, + { "dm-list-peer", dmListPeer }, + { "dm-list-contact", dmListContact }, +}; + +} + +int main(int argc, char * argv[]) +{ + testHead.emplace([] { + auto hs = st.heads<LocalState>(); + if (!hs.empty()) + return hs[0]; + else + return st.storeHead(LocalState()); + }()); + + char * line = nullptr; + size_t size = 0; + + if (argc > 1) { + vector<string> args; + for (int i = 2; i < argc; i++) + args.emplace_back(argv[i]); + + for (const auto & cmd : commands) { + if (cmd.name == argv[1]) { + cmd.action(args); + return 0; + } + } + + cerr << "Unknown command: '" << argv[1] << "'" << endl; + return 1; + } + + while (getline(&line, &size, stdin) > 0) { + optional<string> command; + vector<string> args; + + const char * last = line; + for (const char * cur = line;; cur++) { + if (isspace(*cur) || *cur == '\0') { + if (last < cur) { + if (!command) + command.emplace(last, cur); + else + args.emplace_back(last, cur); + } + last = cur + 1; + + if (*cur == '\0') + break; + } + } + + if (!command) + continue; + + bool found = false; + for (const auto & cmd : commands) { + if (cmd.name == *command) { + found = true; + cmd.action(args); + break; + } + } + + if (!found) + cerr << "Unknown command: '" << *command << "'" << endl; + } + + free(line); + server.reset(); + return 0; +} diff --git a/src/merge.cpp b/src/merge.cpp new file mode 100644 index 0000000..040ebb4 --- /dev/null +++ b/src/merge.cpp @@ -0,0 +1,26 @@ +#include <erebos/merge.h> + +namespace erebos { + +static void findPropertyObjects(vector<Stored<Object>> & candidates, const Stored<Object> & obj, const string & prop) +{ + if (auto rec = obj->asRecord()) { + if (rec->item(prop)) { + candidates.push_back(obj); + } else { + for (const auto & r : obj.ref().previous()) + findPropertyObjects(candidates, Stored<Object>::load(r), prop); + } + } +} + +vector<Stored<Object>> findPropertyObjects(const vector<Stored<Object>> & leaves, const string & prop) +{ + vector<Stored<Object>> candidates; + for (const auto & obj : leaves) + findPropertyObjects(candidates, obj, prop); + filterAncestors(candidates); + return candidates; +} + +} diff --git a/src/message.cpp b/src/message.cpp new file mode 100644 index 0000000..349accb --- /dev/null +++ b/src/message.cpp @@ -0,0 +1,568 @@ +#include "message.h" + +#include <erebos/contact.h> +#include <erebos/network.h> + +#include <iostream> +#include <thread> + +using namespace erebos; +using std::nullopt; +using std::scoped_lock; +using std::unique_lock; + +static const UUID myUUID("c702076c-4928-4415-8b6b-3e839eafcb0d"); + +DEFINE_SHARED_TYPE(DirectMessageThreads, + "ee793681-5976-466a-b0f0-4e1907d3fade", + &DirectMessageThreads::load, + [](const DirectMessageThreads & threads) { + return threads.store(); + }) + + +static void findThreadComponents(vector<Stored<DirectMessageState>> & candidates, + const Stored<DirectMessageState> & cur, + const Identity & peer, + vector<Stored<DirectMessageData>> DirectMessageState::* sel) +{ + if (cur->peer && cur->peer->sameAs(peer) && not ((*cur).*sel).empty()) + candidates.push_back(cur); + else + for (const auto & p : cur->prev) + findThreadComponents(candidates, p, peer, sel); +} + +static vector<Stored<DirectMessageState>> findThreadComponents( + const vector<Stored<DirectMessageState>> & leaves, + const Identity & peer, + vector<Stored<DirectMessageData>> DirectMessageState::* sel) +{ + vector<Stored<DirectMessageState>> candidates; + for (const auto & obj : leaves) + findThreadComponents(candidates, obj, peer, sel); + filterAncestors(candidates); + return candidates; +} + + +DirectMessage::DirectMessage(Priv * p): + p(p) +{} + +DirectMessageData DirectMessageData::load(const Ref & ref) +{ + auto rec = ref->asRecord(); + if (!rec) + return DirectMessageData(); + + auto fref = rec->item("from").asRef(); + + return DirectMessageData { + .prev = rec->items("PREV").as<DirectMessageData>(), + .from = fref ? Identity::load(*fref) : nullopt, + .time = *rec->item("time").asDate(), + .text = rec->item("text").asText().value(), + }; +} + +Ref DirectMessageData::store(const Storage & st) const +{ + vector<Record::Item> items; + + for (const auto & prev : prev) + items.emplace_back("PREV", prev.ref()); + if (from) + items.emplace_back("from", from->extRef().value()); + if (time) + items.emplace_back("time", *time); + if (text) + items.emplace_back("text", *text); + + return st.storeObject(Record(std::move(items))); +} + + +const optional<Identity> & DirectMessage::from() const +{ + return p->data->from; +} + +const optional<ZonedTime> & DirectMessage::time() const +{ + return p->data->time; +} + +string DirectMessage::text() const +{ + if (p->data->text) + return p->data->text.value(); + return ""; +} + + +DirectMessageThread::DirectMessageThread(Priv * p): + p(p) +{} + +DirectMessageThread::Iterator::Iterator(Priv * p): + p(p) +{} + +DirectMessageThread::Iterator::Iterator(const Iterator & other): + Iterator(new Priv(*other.p)) +{} + +DirectMessageThread::Iterator::~Iterator() = default; + +DirectMessageThread::Iterator & DirectMessageThread::Iterator::operator=(const Iterator & other) +{ + p.reset(new Priv(*other.p)); + return *this; +} + +DirectMessageThread::Iterator & DirectMessageThread::Iterator::operator++() +{ + if (p->current) + for (const auto & m : p->current->p->data->prev) + p->next.push_back(m); + + if (p->next.empty()) { + p->current.reset(); + } else { + filterAncestors(p->next); + auto ncur = p->next[0]; + + for (const auto & m : p->next) + if (!ncur->time || (m->time && m->time->time >= ncur->time->time)) + ncur = m; + + p->current.emplace(DirectMessage(new DirectMessage::Priv { + .data = ncur, + })); + + p->next.erase(std::remove(p->next.begin(), p->next.end(), p->current->p->data)); + } + + return *this; +} + +DirectMessage DirectMessageThread::Iterator::operator*() const +{ + return *p->current; +} + +bool DirectMessageThread::Iterator::operator==(const Iterator & other) const +{ + if (p->current && other.p->current) + return p->current->p->data == other.p->current->p->data; + return bool(p->current) == bool(other.p->current); +} + +bool DirectMessageThread::Iterator::operator!=(const Iterator & other) const +{ + return !(*this == other); +} + +DirectMessageThread::Iterator DirectMessageThread::begin() const +{ + return ++Iterator(new Iterator::Priv { + .current = {}, + .next = p->head, + }); +} + +DirectMessageThread::Iterator DirectMessageThread::end() const +{ + return Iterator(new Iterator::Priv { + .current = {}, + .next = {}, + }); +} + +size_t DirectMessageThread::size() const +{ + size_t c = 0; + for (auto it = begin(); it != end(); ++it) + c++; + return c; +} + +DirectMessage DirectMessageThread::at(size_t i) const +{ + return *std::next(begin(), i); +} + +const Identity & DirectMessageThread::peer() const +{ + return p->peer; +} + + +DirectMessageState DirectMessageState::load(const Ref & ref) +{ + if (auto rec = ref->asRecord()) { + return DirectMessageState { + .prev = rec->items("PREV").as<DirectMessageState>(), + .peer = Identity::load(rec->items("peer").asRef()), + + .ready = rec->items("ready").as<DirectMessageData>(), + .sent = rec->items("sent").as<DirectMessageData>(), + .received = rec->items("received").as<DirectMessageData>(), + .seen = rec->items("seen").as<DirectMessageData>(), + }; + } + + return DirectMessageState(); +} + +Ref DirectMessageState::store(const Storage & st) const +{ + vector<Record::Item> items; + + for (const auto & prev : prev) + items.emplace_back("PREV", prev.ref()); + if (peer) + for (const auto & ref : peer->refs()) + items.emplace_back("peer", ref); + + for (const auto & x : ready) + items.emplace_back("ready", x.ref()); + for (const auto & x : sent) + items.emplace_back("sent", x.ref()); + for (const auto & x : received) + items.emplace_back("received", x.ref()); + for (const auto & x : seen) + items.emplace_back("seen", x.ref()); + + return st.storeObject(Record(std::move(items))); +} + + +DirectMessageThreads::DirectMessageThreads() = default; + +DirectMessageThreads::DirectMessageThreads(Stored<DirectMessageState> s): + DirectMessageThreads(vector<Stored<DirectMessageState>> { move(s) }) +{ +} + +DirectMessageThreads::DirectMessageThreads(vector<Stored<DirectMessageState>> s): + state(move(s)) +{ +} + +DirectMessageThreads DirectMessageThreads::load(const vector<Ref> & refs) +{ + DirectMessageThreads res; + res.state.reserve(refs.size()); + for (const auto & ref : refs) + res.state.push_back(Stored<DirectMessageState>::load(ref)); + return res; +} + +vector<Ref> DirectMessageThreads::store() const +{ + vector<Ref> refs; + refs.reserve(state.size()); + for (const auto & x : state) + refs.push_back(x.ref()); + return refs; +} + +vector<Stored<DirectMessageState>> DirectMessageThreads::data() const +{ + return state; +} + +bool DirectMessageThreads::operator==(const DirectMessageThreads & other) const +{ + return state == other.state; +} + +bool DirectMessageThreads::operator!=(const DirectMessageThreads & other) const +{ + return state != other.state; +} + +DirectMessageThread DirectMessageThreads::thread(const Identity & peer) const +{ + vector<Stored<DirectMessageData>> head; + for (const auto & c : findThreadComponents(state, peer, &DirectMessageState::ready)) + for (const auto & m : c->ready) + head.push_back(m); + for (const auto & c : findThreadComponents(state, peer, &DirectMessageState::received)) + for (const auto & m : c->received) + head.push_back(m); + filterAncestors(head); + + return new DirectMessageThread::Priv { + .peer = peer, + .head = move(head), + }; +} + +vector<Stored<DirectMessageState>> Mergeable<DirectMessageThreads>::components(const DirectMessageThreads & threads) +{ + return threads.data(); +} + + +DirectMessageService::Config & DirectMessageService::Config::onUpdate(ThreadWatcher w) +{ + watchers.push_back(w); + return *this; +} + +DirectMessageService::DirectMessageService(Config && c, const Server & s): + config(move(c)), + server(s), + watched(server.localState().lens<SharedState>().lens<DirectMessageThreads>().watch( + std::bind(&DirectMessageService::updateHandler, this, std::placeholders::_1))) +{ + server.peerList().onUpdate(std::bind(&DirectMessageService::peerWatcher, this, + std::placeholders::_1, std::placeholders::_2)); + + peerSyncRun = true; + peerSyncThread = std::thread(&DirectMessageService::doSyncWithPeers, this); +} + +DirectMessageService::~DirectMessageService() +{ + { + scoped_lock lock(peerSyncMutex); + peerSyncRun = false; + } + peerSyncCond.notify_all(); + peerSyncThread.join(); +} + +UUID DirectMessageService::uuid() const +{ + return myUUID; +} + +void DirectMessageService::handle(Context & ctx) +{ + auto pid = ctx.peer().identity(); + if (!pid) + return; + auto powner = pid->finalOwner(); + + auto msg = Stored<DirectMessageData>::load(ctx.ref()); + + server.localHead().update([&](const Stored<LocalState> & loc) { + auto st = loc.ref().storage(); + auto threads = loc->shared<DirectMessageThreads>(); + + vector<Stored<DirectMessageData>> receivedOld; + for (const auto & c : findThreadComponents(threads.state, powner, &DirectMessageState::received)) + for (const auto & m : c->received) + receivedOld.push_back(m); + auto receivedNew = receivedOld; + receivedNew.push_back(msg); + filterAncestors(receivedNew); + + if (receivedNew != receivedOld) { + auto state = st.store(DirectMessageState { + .prev = threads.data(), + .peer = powner, + .received = { msg }, + }); + + auto res = st.store(loc->shared<DirectMessageThreads>(DirectMessageThreads(state))); + return res; + } else { + return loc; + } + }); +} + +DirectMessageThread DirectMessageService::thread(const Identity & peer) +{ + return server.localState().get().shared<DirectMessageThreads>().thread(peer); +} + +DirectMessage DirectMessageService::send(const Head<LocalState> & head, const Identity & to, const string & text) +{ + Stored<DirectMessageData> msg; + + head.update([&](const Stored<LocalState> & loc) { + auto st = loc.ref().storage(); + + auto threads = loc->shared<DirectMessageThreads>(); + msg = st.store(DirectMessageData { + .prev = threads.thread(to).p->head, + .from = loc->identity()->finalOwner(), + .time = ZonedTime::now(), + .text = text, + }); + + auto state = st.store(DirectMessageState { + .prev = threads.data(), + .peer = to, + .ready = { msg }, + }); + + return st.store(loc->shared<DirectMessageThreads>(DirectMessageThreads(state))); + }); + + return DirectMessage(new DirectMessage::Priv { + .data = move(msg), + }); +} + +DirectMessage DirectMessageService::send(const Head<LocalState> & head, const Contact & to, const string & text) +{ + if (auto id = to.identity()) + return send(head, *id, text); + throw std::runtime_error("contact without erebos identity"); +} + +DirectMessage DirectMessageService::send(const Head<LocalState> & head, const Peer & to, const string & text) +{ + if (auto id = to.identity()) + return send(head, id->finalOwner(), text); + throw std::runtime_error("peer without known identity"); +} + +DirectMessage DirectMessageService::send(const Identity & to, const string & text) +{ + return send(server.localHead(), to, text); +} + +DirectMessage DirectMessageService::send(const Contact & to, const string & text) +{ + if (auto id = to.identity()) + return send(*id, text); + throw std::runtime_error("contact without erebos identity"); +} + +DirectMessage DirectMessageService::send(const Peer & to, const string & text) +{ + if (auto id = to.identity()) + return send(id->finalOwner(), text); + throw std::runtime_error("peer without known identity"); +} + +void DirectMessageService::updateHandler(const DirectMessageThreads & threads) +{ + scoped_lock lock(stateMutex); + + auto state = prevState; + for (const auto & s : threads.state) + state.push_back(s); + filterAncestors(state); + + if (state != prevState) { + auto queue = state; + vector<Identity> peers; + + while (not queue.empty()) { + auto cur = move(queue.back()); + queue.pop_back(); + + if (auto peer = cur->peer) { + bool found = false; + for (const auto & p : peers) { + if (p.sameAs(*peer)) { + found = true; + break; + } + } + + if (not found) + peers.push_back(*peer); + + for (const auto & prev : cur->prev) + queue.push_back(prev); + } + } + + for (const auto & peer : peers) { + auto dmt = threads.thread(peer); + for (const auto & w : config.watchers) + w(dmt, -1, -1); + + if (auto netPeer = server.peer(peer)) + syncWithPeer(dmt, *netPeer); + } + + prevState = move(state); + } +} + +void DirectMessageService::peerWatcher(size_t, const class Peer * peer) +{ + if (peer) { + if (auto pid = peer->identity()) { + syncWithPeer(thread(pid->finalOwner()), *peer); + } + } +} + +void DirectMessageService::syncWithPeer(const DirectMessageThread & thread, const Peer & peer) +{ + { + scoped_lock lock(peerSyncMutex); + peerSyncQueue.emplace_back(thread, peer); + } + peerSyncCond.notify_one(); +} + +void DirectMessageService::doSyncWithPeers() +{ + unique_lock lock(peerSyncMutex); + + while (peerSyncRun) + { + if (peerSyncQueue.empty()) { + peerSyncCond.wait(lock); + continue; + } + + auto & [ thread, peer ] = peerSyncQueue.front(); + lock.unlock(); + + doSyncWithPeer(thread, peer); + + lock.lock(); + peerSyncQueue.pop_front(); + } +} + +void DirectMessageService::doSyncWithPeer(const DirectMessageThread & thread, const Peer & peer) +{ + for (const auto & msg : thread.p->head) + if (not peer.send(myUUID, msg.ref())) + return; + + const auto & head = server.localHead(); + head.update([&](const Stored<LocalState> & loc) { + auto st = head.storage(); + + auto threads = loc->shared<DirectMessageThreads>(); + + vector<Stored<DirectMessageData>> oldSent; + for (const auto & c : findThreadComponents(threads.data(), thread.peer(), &DirectMessageState::sent)) + for (const auto & m : c->sent) + oldSent.push_back(m); + filterAncestors(oldSent); + + auto newSent = oldSent; + for (const auto & msg : thread.p->head) + newSent.push_back(msg); + filterAncestors(newSent); + + if (newSent != oldSent) { + auto state = st.store(DirectMessageState { + .prev = threads.data(), + .peer = thread.peer(), + .sent = move(newSent), + }); + + return st.store(loc->shared<DirectMessageThreads>(DirectMessageThreads(state))); + } + + return loc; + }); +} diff --git a/src/message.h b/src/message.h new file mode 100644 index 0000000..22da0fd --- /dev/null +++ b/src/message.h @@ -0,0 +1,63 @@ +#pragma once + +#include <erebos/identity.h> +#include <erebos/message.h> +#include <erebos/storage.h> +#include <erebos/time.h> + +#include <chrono> +#include <mutex> +#include <vector> + +namespace chrono = std::chrono; +using chrono::system_clock; +using std::mutex; +using std::optional; +using std::string; +using std::vector; + +namespace erebos { + +struct DirectMessageData +{ + static DirectMessageData load(const Ref &); + Ref store(const Storage &) const; + + vector<Stored<DirectMessageData>> prev; + optional<Identity> from; + optional<ZonedTime> time; + optional<string> text; +}; + +struct DirectMessage::Priv +{ + Stored<DirectMessageData> data; +}; + +struct DirectMessageThread::Priv +{ + const Identity peer; + const vector<Stored<DirectMessageData>> head; +}; + +struct DirectMessageThread::Iterator::Priv +{ + optional<DirectMessage> current; + vector<Stored<DirectMessageData>> next; +}; + +struct DirectMessageState +{ + static DirectMessageState load(const Ref &); + Ref store(const Storage &) const; + + vector<Stored<DirectMessageState>> prev; + optional<Identity> peer; + + vector<Stored<DirectMessageData>> ready {}; + vector<Stored<DirectMessageData>> sent {}; + vector<Stored<DirectMessageData>> received {}; + vector<Stored<DirectMessageData>> seen {}; +}; + +} diff --git a/src/network.cpp b/src/network.cpp new file mode 100644 index 0000000..7f37cf7 --- /dev/null +++ b/src/network.cpp @@ -0,0 +1,828 @@ +#include "network.h" + +#include "identity.h" +#include "network/protocol.h" +#include "service.h" + +#include <algorithm> +#include <cstring> +#include <iostream> +#include <stdexcept> + +#include <arpa/inet.h> +#include <ifaddrs.h> +#include <net/if.h> +#include <netdb.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <unistd.h> + +using std::get; +using std::get_if; +using std::holds_alternative; +using std::move; +using std::runtime_error; +using std::scoped_lock; +using std::to_string; +using std::unique_lock; + +using namespace erebos; + +Server::Server(const Head<LocalState> & head, ServerConfig && config): + p(new Priv(head, *head->identity())) +{ + p->services.reserve(config.services.size()); + for (const auto & ctor : config.services) + p->services.emplace_back(ctor(*this)); +} + +Server:: Server(const std::shared_ptr<Priv> & ptr): + p(ptr) +{ +} + +Server::~Server() = default; + +const Head<LocalState> & Server::localHead() const +{ + return p->localHead; +} + +const Bhv<LocalState> & Server::localState() const +{ + return p->localState; +} + +Identity Server::identity() const +{ + shared_lock lock(p->selfMutex); + return p->self; +} + +Service & Server::svcHelper(const std::type_info & tinfo) +{ + for (auto & s : p->services) { + auto & sobj = *s; + if (typeid(sobj) == tinfo) + return sobj; + } + throw runtime_error("service not found"); +} + +PeerList & Server::peerList() const +{ + return p->plist; +} + +optional<Peer> Server::peer(const Identity & identity) const +{ + scoped_lock lock(p->dataMutex); + + for (auto & peer : p->peers) { + const auto & pid = peer->identity; + if (holds_alternative<Identity>(pid)) + if (std::get<Identity>(pid).finalOwner().sameAs(identity)) + return peer->lpeer; + } + + return nullopt; +} + +void Server::addPeer(const string & node) const +{ + return addPeer(node, to_string(Priv::discoveryPort)); +} + +void Server::addPeer(const string & node, const string & service) const +{ + addrinfo hints {}; + hints.ai_flags = AI_V4MAPPED | AI_ADDRCONFIG; + hints.ai_family = AF_INET6; + hints.ai_socktype = SOCK_DGRAM; + addrinfo *aptr; + + int r = getaddrinfo(node.c_str(), service.c_str(), &hints, &aptr); + if (r != 0) + throw runtime_error(string("Server::addPeer: getaddrinfo failed: ") + gai_strerror(r)); + + unique_ptr<addrinfo, void(*)(addrinfo*)> result { aptr, &freeaddrinfo }; + + for (addrinfo * rp = result.get(); rp != nullptr; rp = rp->ai_next) { + if (rp->ai_family == AF_INET6) { + p->getPeer(*(sockaddr_in6 *)rp->ai_addr); + return; + } + } + + throw runtime_error("Server::addPeer: no suitable peer address found"); +} + + +Peer::Peer(const shared_ptr<Priv> & p): p(p) {} +Peer::~Peer() = default; + +Server Peer::server() const +{ + if (auto speer = p->speer.lock()) + return Server(speer->server.getptr()); + throw runtime_error("Server no longer running"); +} + +const Storage & Peer::tempStorage() const +{ + if (auto speer = p->speer.lock()) + return speer->tempStorage; + throw runtime_error("Server no longer running"); +} + +const PartialStorage & Peer::partialStorage() const +{ + if (auto speer = p->speer.lock()) + return speer->partStorage; + throw runtime_error("Server no longer running"); +} + +string Peer::name() const +{ + if (auto speer = p->speer.lock()) { + if (holds_alternative<Identity>(speer->identity)) + if (auto name = std::get<Identity>(speer->identity).finalOwner().name()) + return *name; + if (holds_alternative<shared_ptr<WaitingRef>>(speer->identity)) + return string(std::get<shared_ptr<WaitingRef>>(speer->identity)->ref.digest()); + + return addressStr(); + } + return "<server closed>"; +} + +optional<Identity> Peer::identity() const +{ + if (auto speer = p->speer.lock()) + if (holds_alternative<Identity>(speer->identity)) + return std::get<Identity>(speer->identity); + return nullopt; +} + +const sockaddr_in6 & Peer::address() const +{ + if (auto speer = p->speer.lock()) + return speer->connection.peerAddress(); + throw runtime_error("Server no longer running"); +} + +string Peer::addressStr() const +{ + char buf[INET6_ADDRSTRLEN]; + const in6_addr & addr = address().sin6_addr; + + if (inet_ntop(AF_INET6, &addr, buf, sizeof(buf))) { + if (IN6_IS_ADDR_V4MAPPED(&addr) && strncmp(buf, "::ffff:", 7) == 0) + return buf + 7; + return buf; + } + + return "<invalid address>"; +} + +uint16_t Peer::port() const +{ + return ntohs(address().sin6_port); +} + +void Peer::Priv::notifyWatchers() +{ + if (auto slist = list.lock()) { + Peer p(shared_from_this()); + for (const auto & w : slist->watchers) + w(listIndex, &p); + } +} + +bool Peer::send(UUID uuid, const Ref & ref) const +{ + return send(uuid, ref, *ref); +} + +bool Peer::send(UUID uuid, const Object & obj) const +{ + if (auto speer = p->speer.lock()) { + auto ref = speer->tempStorage.storeObject(obj); + return send(uuid, ref, obj); + } + + return false; +} + +bool Peer::send(UUID uuid, const Ref & ref, const Object & obj) const +{ + if (auto speer = p->speer.lock()) { + NetworkProtocol::Header header({ + NetworkProtocol::Header::ServiceType { uuid }, + NetworkProtocol::Header::ServiceRef { ref.digest() }, + }); + speer->connection.send(speer->partStorage, move(header), { obj }, true); + return true; + } + + return false; +} + +bool Peer::operator==(const Peer & other) const { return p == other.p; } +bool Peer::operator!=(const Peer & other) const { return p != other.p; } +bool Peer::operator<(const Peer & other) const { return p < other.p; } +bool Peer::operator<=(const Peer & other) const { return p <= other.p; } +bool Peer::operator>(const Peer & other) const { return p > other.p; } +bool Peer::operator>=(const Peer & other) const { return p >= other.p; } + + +PeerList::PeerList(): p(new Priv) {} +PeerList::PeerList(const shared_ptr<PeerList::Priv> & p): p(p) {} +PeerList::~PeerList() = default; + +void PeerList::Priv::push(const shared_ptr<Server::Peer> & speer) +{ + scoped_lock lock(dataMutex); + size_t s = peers.size(); + + speer->lpeer.reset(new Peer::Priv); + speer->lpeer->speer = speer; + speer->lpeer->list = shared_from_this(); + speer->lpeer->listIndex = s; + + Peer p(speer->lpeer); + + peers.push_back(speer->lpeer); + for (const auto & w : watchers) + w(s, &p); +} + +size_t PeerList::size() const +{ + return p->peers.size(); +} + +Peer PeerList::at(size_t i) const +{ + return Peer(p->peers.at(i)); +} + +void PeerList::onUpdate(function<void(size_t, const Peer *)> w) +{ + scoped_lock lock(p->dataMutex); + for (size_t i = 0; i < p->peers.size(); i++) { + if (auto speer = p->peers[i]->speer.lock()) { + Peer peer(speer->lpeer); + w(i, &peer); + } + } + p->watchers.push_back(w); +} + + +Server::Priv::Priv(const Head<LocalState> & local, const Identity & self): + self(self), + // Watching needs to start after self is initialized + localState(local.behavior()), + localHead(local.watch(std::bind(&Priv::handleLocalHeadChange, this, std::placeholders::_1))) +{ + struct ifaddrs * raddrs; + if (getifaddrs(&raddrs) < 0) + throw std::system_error(errno, std::generic_category()); + unique_ptr<ifaddrs, void(*)(ifaddrs *)> addrs(raddrs, freeifaddrs); + + for (struct ifaddrs * ifa = addrs.get(); ifa; ifa = ifa->ifa_next) { + if (ifa->ifa_addr && ifa->ifa_addr->sa_family == AF_INET && + ifa->ifa_flags & IFF_BROADCAST) { + localAddresses.push_back(((sockaddr_in*)ifa->ifa_addr)->sin_addr); + bcastAddresses.push_back(((sockaddr_in*)ifa->ifa_broadaddr)->sin_addr); + } + } + + int sock = socket(AF_INET6, SOCK_DGRAM, 0); + if (sock < 0) + throw std::system_error(errno, std::generic_category()); + + protocol = NetworkProtocol(sock, self); + + int disable = 0; + // Should be disabled by default, but try to make sure. On platforms + // where the calls fails, IPv4 might not work. + setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, + &disable, sizeof(disable)); + + int enable = 1; + if (setsockopt(sock, SOL_SOCKET, SO_BROADCAST, + &enable, sizeof(enable)) < 0) + throw std::system_error(errno, std::generic_category()); + + if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, + &enable, sizeof(enable)) < 0) + throw std::system_error(errno, std::generic_category()); + + sockaddr_in6 laddr = {}; + laddr.sin6_family = AF_INET6; + laddr.sin6_port = htons(discoveryPort); + if (::bind(sock, (sockaddr *) &laddr, sizeof(laddr)) < 0) + throw std::system_error(errno, std::generic_category()); + + threadListen = thread([this] { doListen(); }); + threadAnnounce = thread([this] { doAnnounce(); }); +} + +Server::Priv::~Priv() +{ + { + scoped_lock lock(dataMutex); + finish = true; + } + + protocol.shutdown(); + + announceCondvar.notify_all(); + threadListen.join(); + threadAnnounce.join(); +} + +shared_ptr<Server::Priv> Server::Priv::getptr() +{ + // Creating temporary object, so just use null deleter + return shared_ptr<Priv>(this, [](Priv *){}); +} + +void Server::Priv::doListen() +{ + unique_lock lock(dataMutex); + + for (; !finish; lock.lock()) { + lock.unlock(); + + Peer * peer = nullptr; + auto res = protocol.poll(); + + if (holds_alternative<NetworkProtocol::ProtocolClosed>(res)) + break; + + if (const auto * ann = get_if<NetworkProtocol::ReceivedAnnounce>(&res)) { + if (not isSelfAddress(ann->addr)) + getPeer(ann->addr); + } + + if (holds_alternative<NetworkProtocol::NewConnection>(res)) { + auto & conn = get<NetworkProtocol::NewConnection>(res).conn; + if (not isSelfAddress(conn.peerAddress())) + peer = &addPeer(move(conn)); + } + + if (holds_alternative<NetworkProtocol::ConnectionReadReady>(res)) { + peer = findPeer(get<NetworkProtocol::ConnectionReadReady>(res).id); + } + + if (!peer) + continue; + + if (auto header = peer->connection.receive(peer->partStorage)) { + ReplyBuilder reply; + + scoped_lock hlock(dataMutex); + shared_lock slock(selfMutex); + + handlePacket(*peer, *header, reply); + peer->updateIdentity(reply); + peer->updateChannel(reply); + peer->updateService(reply); + + if (!reply.header().empty()) + peer->connection.send(peer->partStorage, + NetworkProtocol::Header(reply.header()), reply.body(), false); + + peer->connection.trySendOutQueue(); + } + } +} + +void Server::Priv::doAnnounce() +{ + auto pst = self.ref()->storage().derivePartialStorage(); + + unique_lock lock(dataMutex); + auto lastAnnounce = steady_clock::now() - announceInterval; + + while (!finish) { + auto now = steady_clock::now(); + + if (lastAnnounce + announceInterval < now) { + shared_lock slock(selfMutex); + + for (const auto & in : bcastAddresses) { + sockaddr_in sin = {}; + sin.sin_family = AF_INET; + sin.sin_addr = in; + sin.sin_port = htons(discoveryPort); + protocol.announceTo(sin); + } + + lastAnnounce += announceInterval * ((now - lastAnnounce) / announceInterval); + } + + announceCondvar.wait_until(lock, lastAnnounce + announceInterval); + } +} + +bool Server::Priv::isSelfAddress(const sockaddr_in6 & paddr) +{ + if (IN6_IS_ADDR_V4MAPPED(&paddr.sin6_addr)) + for (const auto & in : localAddresses) + if (in.s_addr == *reinterpret_cast<const in_addr_t*>(paddr.sin6_addr.s6_addr + 12) && + ntohs(paddr.sin6_port) == discoveryPort) + return true; + return false; +} + +Server::Peer * Server::Priv::findPeer(NetworkProtocol::Connection::Id cid) const +{ + scoped_lock lock(dataMutex); + + for (auto & peer : peers) + if (peer->connection.id() == cid) + return peer.get(); + + return nullptr; +} + +Server::Peer & Server::Priv::getPeer(const sockaddr_in6 & paddr) +{ + scoped_lock lock(dataMutex); + + for (auto & peer : peers) + if (memcmp(&peer->connection.peerAddress(), &paddr, sizeof paddr) == 0) + return *peer; + + auto st = self.ref()->storage().deriveEphemeralStorage(); + shared_ptr<Peer> peer(new Peer { + .server = *this, + .connection = protocol.connect(paddr), + .identity = monostate(), + .identityUpdates = {}, + .tempStorage = st, + .partStorage = st.derivePartialStorage(), + }); + peers.push_back(peer); + plist.p->push(peer); + return *peer; +} + +Server::Peer & Server::Priv::addPeer(NetworkProtocol::Connection conn) +{ + scoped_lock lock(dataMutex); + + auto st = self.ref()->storage().deriveEphemeralStorage(); + shared_ptr<Peer> peer(new Peer { + .server = *this, + .connection = move(conn), + .identity = monostate(), + .identityUpdates = {}, + .tempStorage = st, + .partStorage = st.derivePartialStorage(), + }); + peers.push_back(peer); + plist.p->push(peer); + return *peer; +} + +void Server::Priv::handlePacket(Server::Peer & peer, const NetworkProtocol::Header & header, ReplyBuilder & reply) +{ + unordered_set<Digest> plaintextRefs; + for (const auto & obj : collectStoredObjects(Stored<Object>::load(*self.ref()))) + plaintextRefs.insert(obj.ref().digest()); + + optional<UUID> serviceType; + + for (const auto & item : header.items) { + if (const auto * ack = get_if<NetworkProtocol::Header::Acknowledged>(&item)) { + const auto & dgst = ack->value; + if (holds_alternative<Stored<ChannelAccept>>(peer.connection.channel()) && + std::get<Stored<ChannelAccept>>(peer.connection.channel()).ref().digest() == dgst) + peer.finalizeChannel(reply, + std::get<Stored<ChannelAccept>>(peer.connection.channel())->data->channel()); + } + + else if (const auto * req = get_if<NetworkProtocol::Header::DataRequest>(&item)) { + const auto & dgst = req->value; + if (holds_alternative<unique_ptr<Channel>>(peer.connection.channel()) || + plaintextRefs.find(dgst) != plaintextRefs.end()) { + if (auto ref = peer.tempStorage.ref(dgst)) { + reply.header({ NetworkProtocol::Header::DataResponse { ref->digest() } }); + reply.body(*ref); + } + } + } + + else if (const auto * rsp = get_if<NetworkProtocol::Header::DataResponse>(&item)) { + const auto & dgst = rsp->value; + if (not holds_alternative<unique_ptr<Channel>>(peer.connection.channel())) + reply.header({ NetworkProtocol::Header::Acknowledged { dgst } }); + for (auto & pwref : waiting) { + if (auto wref = pwref.lock()) { + if (std::find(wref->missing.begin(), wref->missing.end(), dgst) != + wref->missing.end()) { + if (wref->check(reply)) + pwref.reset(); + } + } + } + waiting.erase(std::remove_if(waiting.begin(), waiting.end(), + [](auto & wref) { return wref.expired(); }), waiting.end()); + } + + else if (const auto * ann = get_if<NetworkProtocol::Header::AnnounceSelf>(&item)) { + const auto & dgst = ann->value; + if (dgst != self.ref()->digest() && + holds_alternative<monostate>(peer.identity)) { + reply.header({ NetworkProtocol::Header::AnnounceSelf { self.ref()->digest() }}); + + shared_ptr<WaitingRef> wref(new WaitingRef { + .storage = peer.tempStorage, + .ref = peer.partStorage.ref(dgst), + .missing = {}, + }); + waiting.push_back(wref); + peer.identity = wref; + wref->check(reply); + } + } + + else if (const auto * anu = get_if<NetworkProtocol::Header::AnnounceUpdate>(&item)) { + if (holds_alternative<Identity>(peer.identity)) { + const auto & dgst = anu->value; + + shared_ptr<WaitingRef> wref(new WaitingRef { + .storage = peer.tempStorage, + .ref = peer.partStorage.ref(dgst), + .missing = {}, + }); + waiting.push_back(wref); + peer.identityUpdates.push_back(wref); + wref->check(reply); + } + } + + else if (const auto * req = get_if<NetworkProtocol::Header::ChannelRequest>(&item)) { + const auto & dgst = req->value; + reply.header({ NetworkProtocol::Header::Acknowledged { dgst } }); + + if (holds_alternative<Stored<ChannelRequest>>(peer.connection.channel()) && + std::get<Stored<ChannelRequest>>(peer.connection.channel()).ref().digest() < dgst) { + // TODO: reject request with lower priority + } + + else if (holds_alternative<Stored<ChannelAccept>>(peer.connection.channel())) { + // TODO: reject when we already sent accept + } + + else { + shared_ptr<WaitingRef> wref(new WaitingRef { + .storage = peer.tempStorage, + .ref = peer.partStorage.ref(dgst), + .missing = {}, + }); + waiting.push_back(wref); + peer.connection.channel() = wref; + wref->check(reply); + } + } + + else if (const auto * acc = get_if<NetworkProtocol::Header::ChannelAccept>(&item)) { + const auto & dgst = acc->value; + if (holds_alternative<Stored<ChannelAccept>>(peer.connection.channel()) && + std::get<Stored<ChannelAccept>>(peer.connection.channel()).ref().digest() < dgst) { + // TODO: reject request with lower priority + } + + else { + auto cres = peer.tempStorage.copy(peer.partStorage.ref(dgst)); + if (auto r = get_if<Ref>(&cres)) { + auto acc = ChannelAccept::load(*r); + if (holds_alternative<Identity>(peer.identity) && + acc.isSignedBy(std::get<Identity>(peer.identity).keyMessage())) { + reply.header({ NetworkProtocol::Header::Acknowledged { dgst } }); + peer.finalizeChannel(reply, acc.data->channel()); + } + } + } + } + + else if (const auto * stype = get_if<NetworkProtocol::Header::ServiceType>(&item)) { + if (!serviceType) + serviceType = stype->value; + } + + else if (const auto * sref = get_if<NetworkProtocol::Header::ServiceRef>(&item)) { + if (!serviceType) + for (auto & item : header.items) + if (const auto * stype = get_if<NetworkProtocol::Header::ServiceType>(&item)) { + serviceType = stype->value; + break; + } + + if (serviceType) { + const auto & dgst = sref->value; + auto pref = peer.partStorage.ref(dgst); + + shared_ptr<WaitingRef> wref(new WaitingRef { + .storage = peer.tempStorage, + .ref = pref, + .missing = {}, + }); + waiting.push_back(wref); + peer.serviceQueue.emplace_back(*serviceType, wref); + wref->check(reply); + } + } + } +} + +void Server::Priv::handleLocalHeadChange(const Head<LocalState> & head) +{ + scoped_lock lock(dataMutex); + scoped_lock slock(selfMutex); + + if (auto id = head->identity()) { + if (*id != self) { + self = *id; + protocol.updateIdentity(*id); + } + } +} + +void Server::Peer::updateIdentity(ReplyBuilder &) +{ + if (holds_alternative<shared_ptr<WaitingRef>>(identity)) { + if (auto ref = std::get<shared_ptr<WaitingRef>>(identity)->check()) + if (auto id = Identity::load(*ref)) { + identity.emplace<Identity>(*id); + if (lpeer) + lpeer->notifyWatchers(); + } + } + else if (holds_alternative<Identity>(identity)) { + if (!identityUpdates.empty()) { + decltype(identityUpdates) keep; + vector<StoredIdentityPart> updates; + + for (auto wref : identityUpdates) { + if (auto ref = wref->check()) + updates.push_back(StoredIdentityPart::load(*ref)); + else + keep.push_back(move(wref)); + } + + identityUpdates = move(keep); + + if (!updates.empty()) { + auto nid = get<Identity>(identity).update(updates); + if (nid != get<Identity>(identity)) { + identity = move(nid); + if (lpeer) + lpeer->notifyWatchers(); + } + } + } + } +} + +void Server::Peer::updateChannel(ReplyBuilder & reply) +{ + if (!holds_alternative<Identity>(identity)) + return; + + if (holds_alternative<monostate>(connection.channel())) { + auto req = Channel::generateRequest(tempStorage, + server.self, std::get<Identity>(identity)); + connection.channel().emplace<Stored<ChannelRequest>>(req); + reply.header({ NetworkProtocol::Header::ChannelRequest { req.ref().digest() } }); + reply.body(req.ref()); + reply.body(req->data.ref()); + reply.body(req->data->key.ref()); + for (const auto & sig : req->sigs) + reply.body(sig.ref()); + } + + if (holds_alternative<shared_ptr<WaitingRef>>(connection.channel())) { + if (auto ref = std::get<shared_ptr<WaitingRef>>(connection.channel())->check(reply)) { + auto req = Stored<ChannelRequest>::load(*ref); + if (holds_alternative<Identity>(identity) && + req->isSignedBy(std::get<Identity>(identity).keyMessage())) { + if (auto acc = Channel::acceptRequest(server.self, std::get<Identity>(identity), req)) { + connection.channel().emplace<Stored<ChannelAccept>>(*acc); + reply.header({ NetworkProtocol::Header::ChannelAccept { acc->ref().digest() } }); + reply.body(acc->ref()); + reply.body(acc.value()->data.ref()); + reply.body(acc.value()->data->key.ref()); + for (const auto & sig : acc.value()->sigs) + reply.body(sig.ref()); + } else { + connection.channel() = monostate(); + } + } else { + connection.channel() = monostate(); + } + } + } +} + +void Server::Peer::finalizeChannel(ReplyBuilder & reply, unique_ptr<Channel> ch) +{ + connection.channel().emplace<unique_ptr<Channel>>(move(ch)); + + vector<NetworkProtocol::Header::Item> hitems; + for (const auto & r : server.self.extRefs()) + reply.header(NetworkProtocol::Header::AnnounceUpdate { r.digest() }); + for (const auto & r : server.self.updates()) + reply.header(NetworkProtocol::Header::AnnounceUpdate { r.digest() }); +} + +void Server::Peer::updateService(ReplyBuilder & reply) +{ + decltype(serviceQueue) next; + for (auto & x : serviceQueue) { + if (auto ref = std::get<1>(x)->check(reply)) { + if (lpeer) { + Service::Context ctx { nullptr }; + + server.localHead.update([&] (const Stored<LocalState> & local) { + ctx = Service::Context(new Service::Context::Priv { + .ref = *ref, + .peer = erebos::Peer(lpeer), + .local = local, + }); + + for (auto & svc : server.services) { + if (svc->uuid() == std::get<UUID>(x)) { + svc->handle(ctx); + break; + } + } + + return ctx.local(); + }); + + ctx.runAfterCommitHooks(); + } + } else { + next.push_back(std::move(x)); + } + } + serviceQueue = std::move(next); +} + + +void ReplyBuilder::header(NetworkProtocol::Header::Item && item) +{ + for (const auto & x : mheader) + if (x == item) + return; + mheader.emplace_back(std::move(item)); +} + +void ReplyBuilder::body(const Ref & ref) +{ + for (const auto & x : mbody) + if (x.digest() == ref.digest()) + return; + mbody.push_back(ref); +} + +vector<Object> ReplyBuilder::body() const +{ + vector<Object> res; + res.reserve(mbody.size()); + for (const Ref & ref : mbody) + res.push_back(*ref); + return res; +} + + +optional<Ref> WaitingRef::check() +{ + if (auto r = storage.ref(ref.digest())) + return *r; + + auto res = storage.copy(ref); + if (auto r = get_if<Ref>(&res)) + return *r; + + missing = std::get<vector<Digest>>(res); + return nullopt; +} + +optional<Ref> WaitingRef::check(ReplyBuilder & reply) +{ + if (auto r = check()) + return r; + + for (const auto & d : missing) + reply.header({ NetworkProtocol::Header::DataRequest { d } }); + + return nullopt; +} diff --git a/src/network.h b/src/network.h new file mode 100644 index 0000000..d1fae15 --- /dev/null +++ b/src/network.h @@ -0,0 +1,133 @@ +#pragma once + +#include <erebos/network.h> + +#include "network/protocol.h" + +#include <condition_variable> +#include <mutex> +#include <shared_mutex> +#include <thread> +#include <vector> + +#include <netinet/in.h> + +using std::condition_variable; +using std::monostate; +using std::mutex; +using std::optional; +using std::shared_lock; +using std::shared_mutex; +using std::shared_ptr; +using std::string; +using std::thread; +using std::unique_ptr; +using std::variant; +using std::vector; +using std::tuple; +using std::weak_ptr; + +using std::enable_shared_from_this; + +namespace chrono = std::chrono; +using chrono::steady_clock; + +namespace erebos { + +class ReplyBuilder; +struct WaitingRef; + +struct Server::Peer +{ + Peer(const Peer &) = delete; + Peer & operator=(const Peer &) = delete; + + Priv & server; + NetworkProtocol::Connection connection; + + variant<monostate, + shared_ptr<struct WaitingRef>, + Identity> identity; + vector<shared_ptr<WaitingRef>> identityUpdates; + + Storage tempStorage; + PartialStorage partStorage; + + vector<tuple<UUID, shared_ptr<WaitingRef>>> serviceQueue {}; + + shared_ptr<erebos::Peer::Priv> lpeer = nullptr; + + void updateIdentity(ReplyBuilder &); + void updateChannel(ReplyBuilder &); + void finalizeChannel(ReplyBuilder &, unique_ptr<Channel>); + void updateService(ReplyBuilder &); +}; + +struct Peer::Priv : enable_shared_from_this<Peer::Priv> +{ + weak_ptr<Server::Peer> speer; + weak_ptr<PeerList::Priv> list; + size_t listIndex; + + void notifyWatchers(); +}; + +struct PeerList::Priv : enable_shared_from_this<PeerList::Priv> +{ + mutex dataMutex; + vector<shared_ptr<Peer::Priv>> peers; + vector<function<void(size_t, const Peer *)>> watchers; + + void push(const shared_ptr<Server::Peer> &); +}; + +struct Server::Priv +{ + Priv(const Head<LocalState> & local, const Identity & self); + ~Priv(); + + shared_ptr<Priv> getptr(); + + void doListen(); + void doAnnounce(); + + bool isSelfAddress(const sockaddr_in6 & paddr); + Peer * findPeer(NetworkProtocol::Connection::Id cid) const; + Peer & getPeer(const sockaddr_in6 & paddr); + Peer & addPeer(NetworkProtocol::Connection conn); + void handlePacket(Peer &, const NetworkProtocol::Header &, ReplyBuilder &); + + void handleLocalHeadChange(const Head<LocalState> &); + + constexpr static uint16_t discoveryPort { 29665 }; + constexpr static chrono::seconds announceInterval { 60 }; + + mutable mutex dataMutex; + condition_variable announceCondvar; + bool finish = false; + + shared_mutex selfMutex; + Identity self; + const Bhv<LocalState> localState; + + thread threadListen; + thread threadAnnounce; + + vector<shared_ptr<Peer>> peers; + PeerList plist; + + vector<struct NetworkProtocol::Header> outgoing; + vector<weak_ptr<WaitingRef>> waiting; + + NetworkProtocol protocol; + vector<in_addr> localAddresses; + vector<in_addr> bcastAddresses; + + // Stop watching before destroying other data + WatchedHead<LocalState> localHead; + + // Start destruction with finalizing services + vector<unique_ptr<Service>> services; +}; + +} diff --git a/src/network/channel.cpp b/src/network/channel.cpp new file mode 100644 index 0000000..5fff1fa --- /dev/null +++ b/src/network/channel.cpp @@ -0,0 +1,216 @@ +#include "channel.h" + +#include <algorithm> +#include <cstring> +#include <stdexcept> + +#include <endian.h> + +using std::remove_const; +using std::runtime_error; + +using namespace erebos; + +Ref ChannelRequestData::store(const Storage & st) const +{ + vector<Record::Item> items; + + for (const auto & p : peers) + items.emplace_back("peer", p); + items.emplace_back("key", key); + + return st.storeObject(Record(std::move(items))); +} + +ChannelRequestData ChannelRequestData::load(const Ref & ref) +{ + if (auto rec = ref->asRecord()) { + if (auto key = rec->item("key").as<PublicKexKey>()) + return ChannelRequestData { + .peers = rec->items("peer").as<Signed<IdentityData>>(), + .key = *key, + }; + } + + return ChannelRequestData { + .peers = {}, + .key = Stored<PublicKexKey>::load(ref.storage().zref()), + }; +} + +Ref ChannelAcceptData::store(const Storage & st) const +{ + vector<Record::Item> items; + + items.emplace_back("req", request); + items.emplace_back("key", key); + + return st.storeObject(Record(std::move(items))); +} + +ChannelAcceptData ChannelAcceptData::load(const Ref & ref) +{ + if (auto rec = ref->asRecord()) + return ChannelAcceptData { + .request = *rec->item("req").as<ChannelRequest>(), + .key = *rec->item("key").as<PublicKexKey>(), + }; + + return ChannelAcceptData { + .request = Stored<ChannelRequest>::load(ref.storage().zref()), + .key = Stored<PublicKexKey>::load(ref.storage().zref()), + }; +} + +unique_ptr<Channel> ChannelAcceptData::channel() const +{ + if (auto secret = SecretKexKey::load(key)) + return make_unique<Channel>( + request->data->peers, + secret->dh(*request->data->key), + false + ); + + if (auto secret = SecretKexKey::load(request->data->key)) + return make_unique<Channel>( + request->data->peers, + secret->dh(*key), + true + ); + + throw runtime_error("failed to load secret DH key"); +} + + +Stored<ChannelRequest> Channel::generateRequest(const Storage & st, + const Identity & self, const Identity & peer) +{ + auto signKey = SecretKey::load(self.keyMessage()); + if (!signKey) + throw runtime_error("failed to load own message key"); + + return signKey->sign(st.store(ChannelRequestData { + .peers = self.ref()->digest() < peer.ref()->digest() ? + vector<Stored<Signed<IdentityData>>> { + Stored<Signed<IdentityData>>::load(*self.ref()), + Stored<Signed<IdentityData>>::load(*peer.ref()), + } : + vector<Stored<Signed<IdentityData>>> { + Stored<Signed<IdentityData>>::load(*peer.ref()), + Stored<Signed<IdentityData>>::load(*self.ref()), + }, + .key = SecretKexKey::generate(st).pub(), + })); +} + +optional<Stored<ChannelAccept>> Channel::acceptRequest(const Identity & self, + const Identity & peer, const Stored<ChannelRequest> & request) +{ + if (!request->isSignedBy(peer.keyMessage())) + return nullopt; + + auto & peers = request->data->peers; + if (peers.size() != 2 || + std::none_of(peers.begin(), peers.end(), [&self](const auto & x) + { return x.ref().digest() == self.ref()->digest(); }) || + std::none_of(peers.begin(), peers.end(), [&peer](const auto & x) + { return x.ref().digest() == peer.ref()->digest(); })) + return nullopt; + + auto & st = request.ref().storage(); + + auto signKey = SecretKey::load(self.keyMessage()); + if (!signKey) + throw runtime_error("failed to load own message key"); + + return signKey->sign(st.store(ChannelAcceptData { + .request = request, + .key = SecretKexKey::generate(st).pub(), + })); +} + +uint64_t Channel::encrypt(BufferCIt plainBegin, BufferCIt plainEnd, + Buffer & encBuffer, size_t encOffset) +{ + auto plainSize = plainEnd - plainBegin; + encBuffer.resize(encOffset + plainSize + 1 /* counter */ + 16 /* tag */); + array<uint8_t, 12> iv; + + uint64_t count = counterNextOut.fetch_add(1); + uint64_t beCount = htobe64(count); + encBuffer[encOffset] = count % 0x100; + + constexpr size_t nonceFixedSize = std::tuple_size_v<decltype(nonceFixedOur)>; + static_assert(nonceFixedSize + sizeof beCount == iv.size()); + + std::copy_n(nonceFixedOur.begin(), nonceFixedSize, iv.begin()); + std::memcpy(iv.data() + nonceFixedSize, &beCount, sizeof beCount); + + const unique_ptr<EVP_CIPHER_CTX, void(*)(EVP_CIPHER_CTX*)> + ctx(EVP_CIPHER_CTX_new(), EVP_CIPHER_CTX_free); + EVP_EncryptInit_ex(ctx.get(), EVP_chacha20_poly1305(), + nullptr, key.data(), iv.data()); + + int outl = 0; + uint8_t * cur = encBuffer.data() + encOffset + 1; + + if (EVP_EncryptUpdate(ctx.get(), cur, &outl, &*plainBegin, plainSize) != 1) + throw runtime_error("failed to encrypt data"); + cur += outl; + + if (EVP_EncryptFinal(ctx.get(), cur, &outl) != 1) + throw runtime_error("failed to encrypt data"); + cur += outl; + + EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_AEAD_GET_TAG, 16, cur); + return count; +} + +optional<uint64_t> Channel::decrypt(BufferCIt encBegin, BufferCIt encEnd, + Buffer & decBuffer, const size_t decOffset) +{ + auto encSize = encEnd - encBegin; + decBuffer.resize(decOffset + encSize); + array<uint8_t, 12> iv; + + if (encBegin + 1 /* counter */ + 16 /* tag */ > encEnd) + return nullopt; + + uint64_t expectedCount = counterNextIn.load(); + uint64_t guessedCount = expectedCount - 0x80u + ((0x80u + encBegin[0] - expectedCount) % 0x100u); + uint64_t beCount = htobe64(guessedCount); + + constexpr size_t nonceFixedSize = std::tuple_size_v<decltype(nonceFixedPeer)>; + static_assert(nonceFixedSize + sizeof beCount == iv.size()); + + std::copy_n(nonceFixedPeer.begin(), nonceFixedSize, iv.begin()); + std::memcpy(iv.data() + nonceFixedSize, &beCount, sizeof beCount); + + const unique_ptr<EVP_CIPHER_CTX, void(*)(EVP_CIPHER_CTX*)> + ctx(EVP_CIPHER_CTX_new(), EVP_CIPHER_CTX_free); + EVP_DecryptInit_ex(ctx.get(), EVP_chacha20_poly1305(), + nullptr, key.data(), iv.data()); + + int outl = 0; + uint8_t * cur = decBuffer.data() + decOffset; + + if (EVP_DecryptUpdate(ctx.get(), cur, &outl, + &*encBegin + 1, encSize - 1 - 16) != 1) + return nullopt; + cur += outl; + + if (!EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_AEAD_SET_TAG, 16, + (void *) (&*encEnd - 16))) + return nullopt; + + if (EVP_DecryptFinal_ex(ctx.get(), cur, &outl) != 1) + return nullopt; + cur += outl; + + while (expectedCount < guessedCount + 1 && + not counterNextIn.compare_exchange_weak(expectedCount, guessedCount + 1)) + ; // empty loop body + + decBuffer.resize(cur - decBuffer.data()); + return guessedCount; +} diff --git a/src/network/channel.h b/src/network/channel.h new file mode 100644 index 0000000..bba11b3 --- /dev/null +++ b/src/network/channel.h @@ -0,0 +1,78 @@ +#pragma once + +#include <erebos/storage.h> + +#include "../identity.h" + +#include <atomic> +#include <memory> + +namespace erebos { + +using std::array; +using std::atomic; +using std::unique_ptr; + +struct ChannelRequestData +{ + Ref store(const Storage & st) const; + static ChannelRequestData load(const Ref &); + + const vector<Stored<Signed<IdentityData>>> peers; + const Stored<PublicKexKey> key; +}; + +typedef Signed<ChannelRequestData> ChannelRequest; + +struct ChannelAcceptData +{ + Ref store(const Storage & st) const; + static ChannelAcceptData load(const Ref &); + + unique_ptr<class Channel> channel() const; + + const Stored<ChannelRequest> request; + const Stored<PublicKexKey> key; +}; + +typedef Signed<ChannelAcceptData> ChannelAccept; + +class Channel +{ +public: + Channel(const vector<Stored<Signed<IdentityData>>> & peers, + vector<uint8_t> && key, bool ourRequest): + peers(peers), + key(std::move(key)), + nonceFixedOur({ uint8_t(ourRequest ? 1 : 2), 0, 0, 0 }), + nonceFixedPeer({ uint8_t(ourRequest ? 2 : 1), 0, 0, 0 }) + {} + + Channel(const Channel &) = delete; + Channel(Channel &&) = delete; + Channel & operator=(const Channel &) = delete; + Channel & operator=(Channel &&) = delete; + + static Stored<ChannelRequest> generateRequest(const Storage &, + const Identity & self, const Identity & peer); + static optional<Stored<ChannelAccept>> acceptRequest(const Identity & self, + const Identity & peer, const Stored<ChannelRequest> & request); + + using Buffer = vector<uint8_t>; + using BufferCIt = Buffer::const_iterator; + uint64_t encrypt(BufferCIt plainBegin, BufferCIt plainEnd, + Buffer & encBuffer, size_t encOffset); + optional<uint64_t> decrypt(BufferCIt encBegin, BufferCIt encEnd, + Buffer & decBuffer, size_t decOffset); + +private: + const vector<Stored<Signed<IdentityData>>> peers; + const vector<uint8_t> key; + + const array<uint8_t, 4> nonceFixedOur; + const array<uint8_t, 4> nonceFixedPeer; + atomic<uint64_t> counterNextOut = 0; + atomic<uint64_t> counterNextIn = 0; +}; + +} diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp new file mode 100644 index 0000000..b781693 --- /dev/null +++ b/src/network/protocol.cpp @@ -0,0 +1,677 @@ +#include "protocol.h" + +#include <sys/socket.h> +#include <unistd.h> + +#include <algorithm> +#include <cstring> +#include <iostream> +#include <mutex> +#include <system_error> + +using std::get_if; +using std::holds_alternative; +using std::move; +using std::nullopt; +using std::runtime_error; +using std::scoped_lock; +using std::visit; + +namespace erebos { + +struct NetworkProtocol::ConnectionPriv +{ + Connection::Id id() const; + + bool send(const PartialStorage &, Header, + const vector<Object> &, bool secure); + + NetworkProtocol * protocol; + const sockaddr_in6 peerAddress; + + mutex cmutex {}; + vector<uint8_t> buffer {}; + + optional<Cookie> receivedCookie = nullopt; + bool confirmedCookie = false; + ChannelState channel = monostate(); + vector<vector<uint8_t>> secureOutQueue {}; + + vector<uint64_t> toAcknowledge {}; +}; + + +NetworkProtocol::NetworkProtocol(): + sock(-1) +{} + +NetworkProtocol::NetworkProtocol(int s, Identity id): + sock(s), + self(move(id)) +{} + +NetworkProtocol::NetworkProtocol(NetworkProtocol && other): + sock(other.sock), + self(move(other.self)) +{ + other.sock = -1; +} + +NetworkProtocol & NetworkProtocol::operator=(NetworkProtocol && other) +{ + sock = other.sock; + other.sock = -1; + self = move(other.self); + return *this; +} + +NetworkProtocol::~NetworkProtocol() +{ + if (sock >= 0) + close(sock); + + for (auto & c : connections) + c->protocol = nullptr; +} + +NetworkProtocol::PollResult NetworkProtocol::poll() +{ + { + scoped_lock lock(protocolMutex); + + for (const auto & c : connections) { + { + scoped_lock clock(c->cmutex); + if (c->toAcknowledge.empty()) + continue; + + if (not holds_alternative<unique_ptr<Channel>>(c->channel)) + continue; + } + auto pst = self->ref()->storage().deriveEphemeralStorage(); + c->send(pst, Header {{}}, {}, true); + } + } + + sockaddr_in6 addr; + if (!recvfrom(buffer, addr)) + return ProtocolClosed {}; + + { + scoped_lock lock(protocolMutex); + + for (const auto & c : connections) { + if (memcmp(&c->peerAddress, &addr, sizeof addr) == 0) { + scoped_lock clock(c->cmutex); + buffer.swap(c->buffer); + return ConnectionReadReady { c->id() }; + } + } + + auto pst = self->ref()->storage().deriveEphemeralStorage(); + optional<uint64_t> secure = false; + if (auto header = Connection::parsePacket(buffer, nullptr, pst, secure)) { + if (auto conn = verifyNewConnection(*header, addr)) + return NewConnection { move(*conn) }; + + if (auto ann = header->lookupFirst<Header::AnnounceSelf>()) + return ReceivedAnnounce { addr, ann->value }; + } + } + + return poll(); +} + +NetworkProtocol::Connection NetworkProtocol::connect(sockaddr_in6 addr) +{ + auto conn = unique_ptr<ConnectionPriv>(new ConnectionPriv { + .protocol = this, + .peerAddress = addr, + }); + + { + scoped_lock lock(protocolMutex); + connections.push_back(conn.get()); + + vector<Header::Item> header { + Header::Initiation { Digest::of(Object(Record())) }, + Header::AnnounceSelf { self->ref()->digest() }, + Header::Version { defaultVersion }, + }; + conn->send(self->ref()->storage(), move(header), {}, false); + } + + return Connection(move(conn)); +} + +void NetworkProtocol::updateIdentity(Identity id) +{ + scoped_lock lock(protocolMutex); + self = move(id); + + vector<Header::Item> hitems; + for (const auto & r : self->extRefs()) + hitems.push_back(Header::AnnounceUpdate { r.digest() }); + for (const auto & r : self->updates()) + hitems.push_back(Header::AnnounceUpdate { r.digest() }); + + Header header(hitems); + + for (const auto & conn : connections) + conn->send(self->ref()->storage(), header, { **self->ref() }, false); +} + +void NetworkProtocol::announceTo(variant<sockaddr_in, sockaddr_in6> addr) +{ + vector<uint8_t> bytes; + { + scoped_lock lock(protocolMutex); + + if (!self) + throw runtime_error("NetworkProtocol::announceTo without self identity"); + + bytes = Header({ + Header::AnnounceSelf { self->ref()->digest() }, + Header::Version { defaultVersion }, + }).toObject(self->ref()->storage()).encode(); + } + + sendto(bytes, addr); +} + +void NetworkProtocol::shutdown() +{ + ::shutdown(sock, SHUT_RDWR); +} + +bool NetworkProtocol::recvfrom(vector<uint8_t> & buffer, sockaddr_in6 & addr) +{ + socklen_t addrlen = sizeof(addr); + buffer.resize(4096); + ssize_t ret = ::recvfrom(sock, buffer.data(), buffer.size(), 0, + (sockaddr *) &addr, &addrlen); + if (ret < 0) + throw std::system_error(errno, std::generic_category()); + if (ret == 0) + return false; + + buffer.resize(ret); + return true; +} + +void NetworkProtocol::sendto(const vector<uint8_t> & buffer, variant<sockaddr_in, sockaddr_in6> vaddr) +{ + visit([&](auto && addr) { + ::sendto(sock, buffer.data(), buffer.size(), 0, + (sockaddr *) &addr, sizeof(addr)); + }, vaddr); +} + +void NetworkProtocol::sendCookie(variant<sockaddr_in, sockaddr_in6> addr) +{ + auto bytes = Header({ + Header::CookieSet { generateCookie(addr) }, + Header::AnnounceSelf { self->ref()->digest() }, + Header::Version { defaultVersion }, + }).toObject(self->ref()->storage()).encode(); + + sendto(bytes, addr); +} + +optional<NetworkProtocol::Connection> NetworkProtocol::verifyNewConnection(const Header & header, sockaddr_in6 addr) +{ + optional<string> version; + for (const auto & h : header.items) { + if (const auto * ptr = get_if<Header::Version>(&h)) { + if (ptr->value == defaultVersion) { + version = ptr->value; + break; + } + } + } + if (!version) + return nullopt; + + if (header.lookupFirst<Header::Initiation>()) { + sendCookie(addr); + } + + else if (auto cookie = header.lookupFirst<Header::CookieEcho>()) { + if (verifyCookie(addr, cookie->value)) { + auto conn = unique_ptr<ConnectionPriv>(new ConnectionPriv { + .protocol = this, + .peerAddress = addr, + }); + + connections.push_back(conn.get()); + buffer.swap(conn->buffer); + return Connection(move(conn)); + } + } + + return nullopt; +} + +NetworkProtocol::Cookie NetworkProtocol::generateCookie(variant<sockaddr_in, sockaddr_in6> vaddr) const +{ + vector<uint8_t> cookie; + visit([&](auto && addr) { + cookie.resize(sizeof addr); + memcpy(cookie.data(), &addr, sizeof addr); + }, vaddr); + return Cookie { cookie }; +} + +bool NetworkProtocol::verifyCookie(variant<sockaddr_in, sockaddr_in6> vaddr, const NetworkProtocol::Cookie & cookie) const +{ + return visit([&](auto && addr) { + if (cookie.value.size() != sizeof addr) + return false; + return memcmp(cookie.value.data(), &addr, sizeof addr) == 0; + }, vaddr); +} + +/******************************************************************************/ +/* Connection */ +/******************************************************************************/ + +NetworkProtocol::Connection::Id NetworkProtocol::ConnectionPriv::id() const +{ + return reinterpret_cast<uintptr_t>(this); +} + +NetworkProtocol::Connection::Connection(unique_ptr<ConnectionPriv> p_): + p(move(p_)) +{ +} + +NetworkProtocol::Connection::Connection(Connection && other): + p(move(other.p)) +{ +} + +NetworkProtocol::Connection & NetworkProtocol::Connection::operator=(Connection && other) +{ + close(); + p = move(other.p); + return *this; +} + +NetworkProtocol::Connection::~Connection() +{ + close(); +} + +NetworkProtocol::Connection::Id NetworkProtocol::Connection::id() const +{ + return p->id(); +} + +const sockaddr_in6 & NetworkProtocol::Connection::peerAddress() const +{ + return p->peerAddress; +} + +optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const PartialStorage & partStorage) +{ + vector<uint8_t> buf; + + Channel * channel = nullptr; + unique_ptr<Channel> channelPtr; + + { + scoped_lock lock(p->cmutex); + + if (p->buffer.empty()) + return nullopt; + buf.swap(p->buffer); + + if (holds_alternative<unique_ptr<Channel>>(p->channel)) { + channel = std::get<unique_ptr<Channel>>(p->channel).get(); + } else if (holds_alternative<Stored<ChannelAccept>>(p->channel)) { + channelPtr = std::get<Stored<ChannelAccept>>(p->channel)->data->channel(); + channel = channelPtr.get(); + } + } + + optional<uint64_t> secure = false; + if (auto header = parsePacket(buf, channel, partStorage, secure)) { + scoped_lock lock(p->cmutex); + + if (secure) { + if (header->isAcknowledged()) + p->toAcknowledge.push_back(*secure); + return header; + } + + if (const auto * cookieEcho = header->lookupFirst<Header::CookieEcho>()) { + if (!p->protocol->verifyCookie(p->peerAddress, cookieEcho->value)) + return nullopt; + + p->confirmedCookie = true; + + if (const auto * cookieSet = header->lookupFirst<Header::CookieSet>()) + p->receivedCookie = cookieSet->value; + + return header; + } + + if (holds_alternative<monostate>(p->channel)) { + if (const auto * cookieSet = header->lookupFirst<Header::CookieSet>()) { + p->receivedCookie = cookieSet->value; + return header; + } + } + + if (header->lookupFirst<Header::Initiation>()) { + p->protocol->sendCookie(p->peerAddress); + return nullopt; + } + } + return nullopt; +} + +optional<NetworkProtocol::Header> NetworkProtocol::Connection::parsePacket(vector<uint8_t> & buf, + Channel * channel, const PartialStorage & partStorage, + optional<uint64_t> & secure) +{ + vector<uint8_t> decrypted; + auto plainBegin = buf.cbegin(); + auto plainEnd = buf.cbegin(); + + secure = nullopt; + + if ((buf[0] & 0xE0) == 0x80) { + if (not channel) { + std::cerr << "unexpected encrypted packet\n"; + return nullopt; + } + + if ((secure = channel->decrypt(buf.begin() + 1, buf.end(), decrypted, 0))) { + if (decrypted.empty()) { + std::cerr << "empty decrypted content\n"; + } + else if (decrypted[0] == 0x00) { + plainBegin = decrypted.begin() + 1; + plainEnd = decrypted.end(); + } + else { + std::cerr << "streams not implemented\n"; + return nullopt; + } + } + } + else if ((buf[0] & 0xE0) == 0x60) { + plainBegin = buf.begin(); + plainEnd = buf.end(); + } + + if (auto dec = PartialObject::decodePrefix(partStorage, plainBegin, plainEnd)) { + if (auto header = Header::load(std::get<PartialObject>(*dec))) { + auto pos = std::get<1>(*dec); + while (auto cdec = PartialObject::decodePrefix(partStorage, pos, plainEnd)) { + partStorage.storeObject(std::get<PartialObject>(*cdec)); + pos = std::get<1>(*cdec); + } + + return header; + } + } + + std::cerr << "invalid packet\n"; + return nullopt; +} + +bool NetworkProtocol::Connection::send(const PartialStorage & partStorage, + Header header, + const vector<Object> & objs, bool secure) +{ + return p->send(partStorage, move(header), objs, secure); +} + +bool NetworkProtocol::ConnectionPriv::send(const PartialStorage & partStorage, + Header header, + const vector<Object> & objs, bool secure) +{ + vector<uint8_t> data, part, out; + + { + scoped_lock clock(cmutex); + + Channel * channel = nullptr; + if (auto uptr = get_if<unique_ptr<Channel>>(&this->channel)) + channel = uptr->get(); + + if (channel || secure) { + data.push_back(0x00); + } else { + if (receivedCookie) + header.items.push_back(Header::CookieEcho { receivedCookie->value }); + if (!confirmedCookie) + header.items.push_back(Header::CookieSet { protocol->generateCookie(peerAddress) }); + } + + if (channel) { + for (auto num : toAcknowledge) + header.items.push_back(Header::AcknowledgedSingle { num }); + toAcknowledge.clear(); + } + + if (header.items.empty()) + return false; + + part = header.toObject(partStorage).encode(); + data.insert(data.end(), part.begin(), part.end()); + for (const auto & obj : objs) { + part = obj.encode(); + data.insert(data.end(), part.begin(), part.end()); + } + + if (channel) { + out.push_back(0x80); + channel->encrypt(data.begin(), data.end(), out, 1); + } else if (secure) { + secureOutQueue.emplace_back(move(data)); + } else { + out = std::move(data); + } + } + + if (not out.empty()) + protocol->sendto(out, peerAddress); + + return true; +} + +void NetworkProtocol::Connection::close() +{ + if (not p) + return; + + if (p->protocol) { + scoped_lock lock(p->protocol->protocolMutex); + for (auto it = p->protocol->connections.begin(); + it != p->protocol->connections.end(); it++) { + if ((*it) == p.get()) { + p->protocol->connections.erase(it); + break; + } + } + } + + p = nullptr; +} + +NetworkProtocol::ChannelState & NetworkProtocol::Connection::channel() +{ + return p->channel; +} + +void NetworkProtocol::Connection::trySendOutQueue() +{ + decltype(p->secureOutQueue) queue; + { + scoped_lock clock(p->cmutex); + + if (p->secureOutQueue.empty()) + return; + + if (not holds_alternative<unique_ptr<Channel>>(p->channel)) + return; + + queue.swap(p->secureOutQueue); + } + + vector<uint8_t> out { 0x80 }; + for (const auto & data : queue) { + std::get<unique_ptr<Channel>>(p->channel)->encrypt(data.begin(), data.end(), out, 1); + p->protocol->sendto(out, p->peerAddress); + } +} + + +/******************************************************************************/ +/* Header */ +/******************************************************************************/ + +bool operator==(const NetworkProtocol::Header::Item & left, + const NetworkProtocol::Header::Item & right) +{ + if (left.index() != right.index()) + return false; + + return visit([&](auto && arg) { + using T = std::decay_t<decltype(arg)>; + return arg.value == std::get<T>(right).value; + }, left); +} + +optional<NetworkProtocol::Header> NetworkProtocol::Header::load(const PartialRef & ref) +{ + return load(*ref); +} + +optional<NetworkProtocol::Header> NetworkProtocol::Header::load(const PartialObject & obj) +{ + auto rec = obj.asRecord(); + if (!rec) + return nullopt; + + vector<Item> items; + for (const auto & item : rec->items()) { + if (item.name == "ACK") { + if (auto ref = item.asRef()) + items.emplace_back(Acknowledged { ref->digest() }); + else if (auto num = item.asInteger()) + items.emplace_back(AcknowledgedSingle { static_cast<uint64_t>(*num) }); + } else if (item.name == "VER") { + if (auto ver = item.asText()) + items.emplace_back(Version { *ver }); + } else if (item.name == "INI") { + if (auto ref = item.asRef()) + items.emplace_back(Initiation { ref->digest() }); + } else if (item.name == "CKS") { + if (auto cookie = item.asBinary()) + items.emplace_back(CookieSet { *cookie }); + } else if (item.name == "CKE") { + if (auto cookie = item.asBinary()) + items.emplace_back(CookieEcho { *cookie }); + } else if (item.name == "REQ") { + if (auto ref = item.asRef()) + items.emplace_back(DataRequest { ref->digest() }); + } else if (item.name == "RSP") { + if (auto ref = item.asRef()) + items.emplace_back(DataResponse { ref->digest() }); + } else if (item.name == "ANN") { + if (auto ref = item.asRef()) + items.emplace_back(AnnounceSelf { ref->digest() }); + } else if (item.name == "ANU") { + if (auto ref = item.asRef()) + items.emplace_back(AnnounceUpdate { ref->digest() }); + } else if (item.name == "CRQ") { + if (auto ref = item.asRef()) + items.emplace_back(ChannelRequest { ref->digest() }); + } else if (item.name == "CAC") { + if (auto ref = item.asRef()) + items.emplace_back(ChannelAccept { ref->digest() }); + } else if (item.name == "SVT") { + if (auto val = item.asUUID()) + items.emplace_back(ServiceType { *val }); + } else if (item.name == "SVR") { + if (auto ref = item.asRef()) + items.emplace_back(ServiceRef { ref->digest() }); + } + } + + return NetworkProtocol::Header(items); +} + +PartialObject NetworkProtocol::Header::toObject(const PartialStorage & st) const +{ + vector<PartialRecord::Item> ritems; + + for (const auto & item : items) { + if (const auto * ptr = get_if<Acknowledged>(&item)) + ritems.emplace_back("ACK", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<AcknowledgedSingle>(&item)) + ritems.emplace_back("ACK", Record::Item::Integer(ptr->value)); + + else if (const auto * ptr = get_if<Version>(&item)) + ritems.emplace_back("VER", ptr->value); + + else if (const auto * ptr = get_if<Initiation>(&item)) + ritems.emplace_back("INI", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<CookieSet>(&item)) + ritems.emplace_back("CKS", ptr->value.value); + + else if (const auto * ptr = get_if<CookieEcho>(&item)) + ritems.emplace_back("CKE", ptr->value.value); + + else if (const auto * ptr = get_if<DataRequest>(&item)) + ritems.emplace_back("REQ", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<DataResponse>(&item)) + ritems.emplace_back("RSP", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<AnnounceSelf>(&item)) + ritems.emplace_back("ANN", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<AnnounceUpdate>(&item)) + ritems.emplace_back("ANU", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<ChannelRequest>(&item)) + ritems.emplace_back("CRQ", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<ChannelAccept>(&item)) + ritems.emplace_back("CAC", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<ServiceType>(&item)) + ritems.emplace_back("SVT", ptr->value); + + else if (const auto * ptr = get_if<ServiceRef>(&item)) + ritems.emplace_back("SVR", st.ref(ptr->value)); + } + + return PartialObject(PartialRecord(std::move(ritems))); +} + +bool NetworkProtocol::Header::isAcknowledged() const +{ + for (const auto & item : items) { + if (holds_alternative<Acknowledged>(item) + || holds_alternative<AcknowledgedSingle>(item) + || holds_alternative<Version>(item) + || holds_alternative<Initiation>(item) + || holds_alternative<CookieSet>(item) + || holds_alternative<CookieEcho>(item) + ) + continue; + + return true; + } + return false; +} + +} diff --git a/src/network/protocol.h b/src/network/protocol.h new file mode 100644 index 0000000..ba40744 --- /dev/null +++ b/src/network/protocol.h @@ -0,0 +1,213 @@ +#pragma once + +#include "channel.h" + +#include <erebos/storage.h> + +#include <netinet/in.h> + +#include <cstdint> +#include <memory> +#include <mutex> +#include <variant> +#include <vector> +#include <optional> + +namespace erebos { + +using std::mutex; +using std::optional; +using std::unique_ptr; +using std::variant; +using std::vector; + +class NetworkProtocol +{ +public: + NetworkProtocol(); + explicit NetworkProtocol(int sock, Identity self); + NetworkProtocol(const NetworkProtocol &) = delete; + NetworkProtocol(NetworkProtocol &&); + NetworkProtocol & operator=(const NetworkProtocol &) = delete; + NetworkProtocol & operator=(NetworkProtocol &&); + ~NetworkProtocol(); + + static constexpr char defaultVersion[] = "0.1"; + + class Connection; + + struct Header; + + struct ReceivedAnnounce; + struct NewConnection; + struct ConnectionReadReady; + struct ProtocolClosed {}; + + using PollResult = variant< + ReceivedAnnounce, + NewConnection, + ConnectionReadReady, + ProtocolClosed>; + + PollResult poll(); + + struct Cookie { vector<uint8_t> value; }; + + using ChannelState = variant<monostate, + Stored<ChannelRequest>, + shared_ptr<struct WaitingRef>, + Stored<ChannelAccept>, + unique_ptr<Channel>>; + + Connection connect(sockaddr_in6 addr); + + void updateIdentity(Identity self); + void announceTo(variant<sockaddr_in, sockaddr_in6> addr); + + void shutdown(); + +private: + bool recvfrom(vector<uint8_t> & buffer, sockaddr_in6 & addr); + void sendto(const vector<uint8_t> & buffer, variant<sockaddr_in, sockaddr_in6> addr); + + void sendCookie(variant<sockaddr_in, sockaddr_in6> addr); + optional<Connection> verifyNewConnection(const Header & header, sockaddr_in6 addr); + + Cookie generateCookie(variant<sockaddr_in, sockaddr_in6> addr) const; + bool verifyCookie(variant<sockaddr_in, sockaddr_in6> addr, const Cookie & cookie) const; + + int sock; + + mutex protocolMutex; + vector<uint8_t> buffer; + + optional<Identity> self; + + struct ConnectionPriv; + vector<ConnectionPriv *> connections; +}; + +class NetworkProtocol::Connection +{ + friend class NetworkProtocol; + Connection(unique_ptr<ConnectionPriv> p); +public: + Connection(const Connection &) = delete; + Connection(Connection &&); + Connection & operator=(const Connection &) = delete; + Connection & operator=(Connection &&); + ~Connection(); + + using Id = uintptr_t; + Id id() const; + + const sockaddr_in6 & peerAddress() const; + + optional<Header> receive(const PartialStorage &); + bool send(const PartialStorage &, NetworkProtocol::Header, + const vector<Object> &, bool secure); + + void close(); + + // temporary: + ChannelState & channel(); + void trySendOutQueue(); + +private: + static optional<Header> parsePacket(vector<uint8_t> & buf, + Channel * channel, const PartialStorage & st, + optional<uint64_t> & secure); + + unique_ptr<ConnectionPriv> p; +}; + +struct NetworkProtocol::ReceivedAnnounce { sockaddr_in6 addr; Digest digest; }; +struct NetworkProtocol::NewConnection { Connection conn; }; +struct NetworkProtocol::ConnectionReadReady { Connection::Id id; }; + +struct NetworkProtocol::Header +{ + struct Acknowledged { Digest value; }; + struct AcknowledgedSingle { uint64_t value; }; + struct Version { string value; }; + struct Initiation { Digest value; }; + struct CookieSet { Cookie value; }; + struct CookieEcho { Cookie value; }; + struct DataRequest { Digest value; }; + struct DataResponse { Digest value; }; + struct AnnounceSelf { Digest value; }; + struct AnnounceUpdate { Digest value; }; + struct ChannelRequest { Digest value; }; + struct ChannelAccept { Digest value; }; + struct ServiceType { UUID value; }; + struct ServiceRef { Digest value; }; + + using Item = variant< + Acknowledged, + AcknowledgedSingle, + Version, + Initiation, + CookieSet, + CookieEcho, + DataRequest, + DataResponse, + AnnounceSelf, + AnnounceUpdate, + ChannelRequest, + ChannelAccept, + ServiceType, + ServiceRef>; + + Header(const vector<Item> & items): items(items) {} + static optional<Header> load(const PartialRef &); + static optional<Header> load(const PartialObject &); + PartialObject toObject(const PartialStorage &) const; + + template<class T> const T * lookupFirst() const; + bool isAcknowledged() const; + + vector<Item> items; +}; + +template<class T> +const T * NetworkProtocol::Header::lookupFirst() const +{ + for (const auto & h : items) + if (auto ptr = std::get_if<T>(&h)) + return ptr; + return nullptr; +} + +bool operator==(const NetworkProtocol::Header::Item &, const NetworkProtocol::Header::Item &); +inline bool operator!=(const NetworkProtocol::Header::Item & left, + const NetworkProtocol::Header::Item & right) +{ return not (left == right); } + +inline bool operator==(const NetworkProtocol::Cookie & left, const NetworkProtocol::Cookie & right) +{ return left.value == right.value; } + +class ReplyBuilder +{ +public: + void header(NetworkProtocol::Header::Item &&); + void body(const Ref &); + + const vector<NetworkProtocol::Header::Item> & header() const { return mheader; } + vector<Object> body() const; + +private: + vector<NetworkProtocol::Header::Item> mheader; + vector<Ref> mbody; +}; + +struct WaitingRef +{ + const Storage storage; + const PartialRef ref; + vector<Digest> missing; + + optional<Ref> check(); + optional<Ref> check(ReplyBuilder &); +}; + +} diff --git a/src/pairing.cpp b/src/pairing.cpp new file mode 100644 index 0000000..dca5b03 --- /dev/null +++ b/src/pairing.cpp @@ -0,0 +1,306 @@ +#include <erebos/pairing.h> + +#include "service.h" + +#include <future> +#include <openssl/rand.h> + +#include <arpa/inet.h> + +#include <algorithm> +#include <stdexcept> +#include <thread> +#include <vector> + +using namespace erebos; + +using std::lock_guard; +using std::make_shared; +using std::runtime_error; +using std::scoped_lock; +using std::thread; +using std::unique_lock; + +PairingServiceBase::PairingServiceBase(Config && c): + config(move(c)) +{ +} + +PairingServiceBase::~PairingServiceBase() +{ + // There may be some threads in waitForConfirmation waiting on client + // promise, so make sure they do not touch the service state anymore: + for (auto & [peer, state] : peerStates) { + scoped_lock lock(state->lock); + if (state->phase != StatePhase::PairingDone && + state->phase != StatePhase::PairingFailed) { + state->outcome.set_value(Outcome::Stale); + state->phase = StatePhase::PairingFailed; + } + } +} + +PairingServiceBase::Config & PairingServiceBase::Config::onRequestInit(RequestInitHook hook) +{ + requestInitHook = hook; + return *this; +} + +PairingServiceBase::Config & PairingServiceBase::Config::onResponse(ConfirmHook hook) +{ + responseHook = hook; + return *this; +} + +PairingServiceBase::Config & PairingServiceBase::Config::onRequest(ConfirmHook hook) +{ + requestHook = hook; + return *this; +} + +PairingServiceBase::Config & PairingServiceBase::Config::onRequestNonceFailed(RequestNonceFailedHook hook) +{ + requestNonceFailedHook = hook; + return *this; +} + +void PairingServiceBase::handle(Context & ctx) +{ + auto rec = ctx.ref()->asRecord(); + if (!rec) + return; + + auto pid = ctx.peer().identity(); + if (!pid) + throw runtime_error("Pairing request for peer without known identity"); + + lock_guard lock(stateLock); + auto & state = peerStates.try_emplace(ctx.peer(), new State()).first->second; + unique_lock lock_state(state->lock); + + if (auto request = rec->item("request").asBinary()) { + auto idReqRef = rec->item("id-req").asRef(); + if (!idReqRef) + return; + auto idReq = Identity::load(*idReqRef); + if (!idReq) + return; + if (!idReq->sameAs(*pid)) + return; + + auto idRspRef = rec->item("id-rsp").asRef(); + if (!idRspRef) + return; + auto idRsp = Identity::load(*idRspRef); + if (!idRsp) + return; + if (!idRsp->sameAs(ctx.peer().server().identity())) + return; + + if (state->phase >= StatePhase::PairingDone) { + auto nstate = make_shared<State>(); + lock_state = unique_lock(nstate->lock); + state = move(nstate); + } else if (state->phase != StatePhase::NoPairing) + return; + + if (config.requestInitHook) + config.requestInitHook(ctx.peer()); + + state->phase = StatePhase::PeerRequest; + state->idReq = idReq; + state->idRsp = idRsp; + state->peerCheck = *request; + state->nonce.resize(32); + RAND_bytes(state->nonce.data(), state->nonce.size()); + + ctx.peer().send(uuid(), Object(Record({ + { "response", state->nonce }, + }))); + } + + else if (auto response = rec->item("response").asBinary()) { + if (state->phase != StatePhase::OurRequest) { + fprintf(stderr, "Unexpected pairing response.\n"); // TODO + return; + } + + if (config.responseHook) { + string confirm = confirmationNumber(nonceDigest( + *state->idReq, *state->idRsp, + state->nonce, *response)); + std::thread(&PairingServiceBase::waitForConfirmation, + this, ctx.peer(), state, confirm, config.responseHook).detach(); + } + + state->phase = StatePhase::OurRequestConfirm; + + ctx.peer().send(uuid(), Object(Record({ + { "reqnonce", state->nonce }, + }))); + } + + else if (auto reqnonce = rec->item("reqnonce").asBinary()) { + if (state->phase != StatePhase::PeerRequest) + return; + + auto check = nonceDigest( + *state->idReq, *state->idRsp, + *reqnonce, vector<uint8_t>()); + if (check != state->peerCheck) { + if (config.requestNonceFailedHook) + config.requestNonceFailedHook(ctx.peer()); + if (state->phase < StatePhase::PairingDone) { + state->phase = StatePhase::PairingFailed; + ctx.afterCommit([&]() { + state->outcome.set_value(Outcome::NonceMismatch); + }); + } + return; + } + + if (config.requestHook) { + string confirm = confirmationNumber(nonceDigest( + *state->idReq, *state->idRsp, + *reqnonce, state->nonce)); + std::thread(&PairingServiceBase::waitForConfirmation, + this, ctx.peer(), state, confirm, config.requestHook).detach(); + } + + state->phase = StatePhase::PeerRequestConfirm; + } + + else if (rec->item("reject")) { + if (state->phase < StatePhase::PairingDone) { + state->phase = StatePhase::PairingFailed; + ctx.afterCommit([&]() { + state->outcome.set_value(Outcome::PeerRejected); + }); + } + } + + else { + if (state->phase == StatePhase::OurRequestReady) { + handlePairingResult(ctx); + state->phase = StatePhase::PairingDone; + ctx.afterCommit([&]() { + state->outcome.set_value(Outcome::Success); + }); + } else { + result = ctx.ref(); + } + } +} + +void PairingServiceBase::requestPairing(UUID serviceId, const Peer & peer) +{ + auto pid = peer.identity(); + if (!pid) + throw runtime_error("Pairing request for peer without known identity"); + + unique_lock lock(stateLock); + auto & state = peerStates.try_emplace(peer, new State()).first->second; + + if (state->phase != StatePhase::NoPairing) { + auto nstate = make_shared<State>(); + lock = unique_lock(nstate->lock); + state = move(nstate); + } + + state->phase = StatePhase::OurRequest; + state->idReq = peer.server().identity(); + state->idRsp = pid; + state->nonce.resize(32); + RAND_bytes(state->nonce.data(), state->nonce.size()); + + vector<Record::Item> items; + items.emplace_back("id-req", state->idReq->ref().value()); + items.emplace_back("id-rsp", state->idRsp->ref().value()); + items.emplace_back("request", nonceDigest( + *state->idReq, *state->idRsp, + state->nonce, vector<uint8_t>())); + + peer.send(serviceId, Object(Record(std::move(items)))); +} + +vector<uint8_t> PairingServiceBase::nonceDigest(const Identity & idReq, const Identity & idRsp, + const vector<uint8_t> & nonceReq, const vector<uint8_t> & nonceRsp) +{ + vector<Record::Item> items; + items.emplace_back("id-req", idReq.ref().value()); + items.emplace_back("id-rsp", idRsp.ref().value()); + items.emplace_back("nonce-req", nonceReq); + items.emplace_back("nonce-rsp", nonceRsp); + + const auto arr = Digest::of(Object(Record(std::move(items)))).arr(); + vector<uint8_t> ret(arr.size()); + std::copy_n(arr.begin(), arr.size(), ret.begin()); + return ret; +} + +string PairingServiceBase::confirmationNumber(const vector<uint8_t> & digest) +{ + uint32_t confirm; + memcpy(&confirm, digest.data(), sizeof(confirm)); + string ret(6, '\0'); + snprintf(ret.data(), ret.size() + 1, "%06d", ntohl(confirm) % 1000000); + return ret; +} + +void PairingServiceBase::waitForConfirmation(Peer peer, weak_ptr<State> wstate, string confirm, ConfirmHook hook) +{ + future<Outcome> outcome; + if (auto state = wstate.lock()) { + outcome = state->outcome.get_future(); + } else { + return; + } + + bool ok; + try { + ok = hook(peer, confirm, std::move(outcome)).get(); + } + catch (const std::future_error & e) { + if (e.code() == std::future_errc::broken_promise) + ok = false; + else + throw; + } + + auto state = wstate.lock(); + if (!state) + return; // Server was closed + + scoped_lock lock(state->lock); + + if (ok) { + if (state->phase == StatePhase::OurRequestConfirm) { + if (result) { + peer.server().localHead().update([&] (const Stored<LocalState> & local) { + Service::Context ctx(new Service::Context::Priv { + .ref = *result, + .peer = peer, + .local = local, + }); + + handlePairingResult(ctx); + return ctx.local(); + }); + state->phase = StatePhase::PairingDone; + state->outcome.set_value(Outcome::Success); + } else { + state->phase = StatePhase::OurRequestReady; + } + } else if (state->phase == StatePhase::PeerRequestConfirm) { + peer.send(uuid(), handlePairingCompleteRef(peer)); + state->phase = StatePhase::PairingDone; + state->outcome.set_value(Outcome::Success); + } + } else { + if (state->phase != StatePhase::PairingFailed) { + peer.send(uuid(), Object(Record({{ "reject", Record::Item::Empty {} }}))); + state->phase = StatePhase::PairingFailed; + state->outcome.set_value(Outcome::UserRejected); + } + } +} diff --git a/src/pubkey.cpp b/src/pubkey.cpp new file mode 100644 index 0000000..59b73f9 --- /dev/null +++ b/src/pubkey.cpp @@ -0,0 +1,296 @@ +#include "pubkey.h" + +#include <stdexcept> + +using std::unique_ptr; +using std::runtime_error; +using std::string; + +using namespace erebos; + +PublicKey PublicKey::load(const Ref & ref) +{ + auto rec = ref->asRecord(); + if (!rec) + return PublicKey(nullptr); + + if (auto ktype = rec->item("type").asText()) + if (ktype.value() != "ed25519") + throw runtime_error("unsupported key type " + ktype.value()); + + if (auto pubkey = rec->item("pubkey").asBinary()) + return PublicKey(EVP_PKEY_new_raw_public_key(EVP_PKEY_ED25519, nullptr, + pubkey.value().data(), pubkey.value().size())); + + return PublicKey(nullptr); +} + +Ref PublicKey::store(const Storage & st) const +{ + vector<Record::Item> items; + + items.emplace_back("type", "ed25519"); + + if (key) { + vector<uint8_t> keyData; + size_t keyLen; + EVP_PKEY_get_raw_public_key(key.get(), nullptr, &keyLen); + keyData.resize(keyLen); + EVP_PKEY_get_raw_public_key(key.get(), keyData.data(), &keyLen); + items.emplace_back("pubkey", keyData); + } + + return st.storeObject(Record(std::move(items))); +} + +SecretKey SecretKey::generate(const Storage & st) +{ + unique_ptr<EVP_PKEY_CTX, void(*)(EVP_PKEY_CTX*)> + pctx(EVP_PKEY_CTX_new_id(EVP_PKEY_ED25519, NULL), &EVP_PKEY_CTX_free); + if (!pctx) + throw runtime_error("failed to generate key"); + + if (EVP_PKEY_keygen_init(pctx.get()) != 1) + throw runtime_error("failed to generate key"); + + EVP_PKEY *pkey = NULL; + if (EVP_PKEY_keygen(pctx.get(), &pkey) != 1) + throw runtime_error("failed to generate key"); + shared_ptr<EVP_PKEY> seckey(pkey, EVP_PKEY_free); + + vector<uint8_t> keyData; + size_t keyLen; + + EVP_PKEY_get_raw_public_key(seckey.get(), nullptr, &keyLen); + keyData.resize(keyLen); + EVP_PKEY_get_raw_public_key(seckey.get(), keyData.data(), &keyLen); + auto pubkey = st.store(PublicKey(EVP_PKEY_new_raw_public_key(EVP_PKEY_ED25519, nullptr, + keyData.data(), keyData.size()))); + + EVP_PKEY_get_raw_private_key(seckey.get(), nullptr, &keyLen); + keyData.resize(keyLen); + EVP_PKEY_get_raw_private_key(seckey.get(), keyData.data(), &keyLen); + st.storeKey(pubkey.ref(), keyData); + + return SecretKey(std::move(seckey), pubkey); +} + +optional<SecretKey> SecretKey::load(const Stored<PublicKey> & pub) +{ + auto keyData = pub.ref().storage().loadKey(pub.ref()); + if (!keyData) + return nullopt; + + EVP_PKEY * key = EVP_PKEY_new_raw_private_key(EVP_PKEY_ED25519, nullptr, + keyData->data(), keyData->size()); + if (!key) + throw runtime_error("falied to parse secret key"); + return SecretKey(key, pub); +} + +optional<SecretKey> SecretKey::fromData(const Stored<PublicKey> & pub, const vector<uint8_t> & sdata) +{ + shared_ptr<EVP_PKEY> pkey( + EVP_PKEY_new_raw_private_key(EVP_PKEY_ED25519, NULL, + sdata.data(), sdata.size()), + EVP_PKEY_free); + if (!pkey) + return nullopt; + + vector<uint8_t> keyData; + size_t keyLen; + + EVP_PKEY_get_raw_public_key(pkey.get(), nullptr, &keyLen); + keyData.resize(keyLen); + EVP_PKEY_get_raw_public_key(pkey.get(), keyData.data(), &keyLen); + + EVP_PKEY_get_raw_public_key(pub->key.get(), nullptr, &keyLen); + keyData.resize(keyLen); + EVP_PKEY_get_raw_public_key(pub->key.get(), keyData.data(), &keyLen); + + if (EVP_PKEY_cmp(pkey.get(), pub->key.get()) != 1) + return nullopt; + + pub.ref().storage().storeKey(pub.ref(), sdata); + return SecretKey(std::move(pkey), pub); +} + +vector<uint8_t> SecretKey::getData() const +{ + vector<uint8_t> keyData; + size_t keyLen; + + EVP_PKEY_get_raw_private_key(key.get(), nullptr, &keyLen); + keyData.resize(keyLen); + EVP_PKEY_get_raw_private_key(key.get(), keyData.data(), &keyLen); + + return keyData; +} + +vector<uint8_t> SecretKey::sign(const Digest & dgst) const +{ + unique_ptr<EVP_MD_CTX, void(*)(EVP_MD_CTX*)> + mdctx(EVP_MD_CTX_create(), &EVP_MD_CTX_free); + if (!mdctx) + throw runtime_error("failed to create EVP_MD_CTX"); + + if (EVP_DigestSignInit(mdctx.get(), nullptr, EVP_md_null(), + nullptr, key.get()) != 1) + throw runtime_error("failed to initialize EVP_MD_CTX"); + + size_t sigLen; + if (EVP_DigestSign(mdctx.get(), nullptr, &sigLen, + dgst.arr().data(), Digest::size) != 1) + throw runtime_error("failed to sign data"); + + vector<uint8_t> sigData(sigLen); + if (EVP_DigestSign(mdctx.get(), sigData.data(), &sigLen, + dgst.arr().data(), Digest::size) != 1) + throw runtime_error("failed to sign data"); + + return sigData; +} + +Signature Signature::load(const Ref & ref) +{ + if (auto rec = ref->asRecord()) + if (auto key = rec->item("key").as<PublicKey>()) + if (auto sig = rec->item("sig").asBinary()) + return Signature(*key, *sig); + + return Signature(Stored<PublicKey>::load(ref.storage().zref()), {}); +} + +Ref Signature::store(const Storage & st) const +{ + vector<Record::Item> items; + + items.emplace_back("key", key); + items.emplace_back("sig", sig); + + return st.storeObject(Record(std::move(items))); +} + +bool Signature::verify(const Ref & ref) const +{ + if (!key->key) + return false; + + unique_ptr<EVP_MD_CTX, void(*)(EVP_MD_CTX*)> + mdctx(EVP_MD_CTX_create(), &EVP_MD_CTX_free); + if (!mdctx) + throw runtime_error("failed to create EVP_MD_CTX"); + + if (EVP_DigestVerifyInit(mdctx.get(), nullptr, EVP_md_null(), + nullptr, key->key.get()) != 1) + throw runtime_error("failed to initialize EVP_MD_CTX"); + + return EVP_DigestVerify(mdctx.get(), sig.data(), sig.size(), + ref.digest().arr().data(), Digest::size) == 1; +} + + +PublicKexKey PublicKexKey::load(const Ref & ref) +{ + auto rec = ref->asRecord(); + if (!rec) + return PublicKexKey(nullptr); + + if (auto ktype = rec->item("type").asText()) + if (ktype.value() != "x25519") + throw runtime_error("unsupported key type " + ktype.value()); + + if (auto pubkey = rec->item("pubkey").asBinary()) + return PublicKexKey(EVP_PKEY_new_raw_public_key(EVP_PKEY_X25519, nullptr, + pubkey.value().data(), pubkey.value().size())); + + return PublicKexKey(nullptr); +} + +Ref PublicKexKey::store(const Storage & st) const +{ + vector<Record::Item> items; + + items.emplace_back("type", "x25519"); + + if (key) { + vector<uint8_t> keyData; + size_t keyLen; + EVP_PKEY_get_raw_public_key(key.get(), nullptr, &keyLen); + keyData.resize(keyLen); + EVP_PKEY_get_raw_public_key(key.get(), keyData.data(), &keyLen); + items.emplace_back("pubkey", keyData); + } + + return st.storeObject(Record(std::move(items))); +} + +SecretKexKey SecretKexKey::generate(const Storage & st) +{ + unique_ptr<EVP_PKEY_CTX, void(*)(EVP_PKEY_CTX*)> + pctx(EVP_PKEY_CTX_new_id(EVP_PKEY_X25519, NULL), &EVP_PKEY_CTX_free); + if (!pctx) + throw runtime_error("failed to generate key"); + + if (EVP_PKEY_keygen_init(pctx.get()) != 1) + throw runtime_error("failed to generate key"); + + EVP_PKEY *pkey = NULL; + if (EVP_PKEY_keygen(pctx.get(), &pkey) != 1) + throw runtime_error("failed to generate key"); + shared_ptr<EVP_PKEY> seckey(pkey, EVP_PKEY_free); + + vector<uint8_t> keyData; + size_t keyLen; + + EVP_PKEY_get_raw_public_key(seckey.get(), nullptr, &keyLen); + keyData.resize(keyLen); + EVP_PKEY_get_raw_public_key(seckey.get(), keyData.data(), &keyLen); + auto pubkey = st.store(PublicKexKey(EVP_PKEY_new_raw_public_key(EVP_PKEY_X25519, nullptr, + keyData.data(), keyData.size()))); + + EVP_PKEY_get_raw_private_key(seckey.get(), nullptr, &keyLen); + keyData.resize(keyLen); + EVP_PKEY_get_raw_private_key(seckey.get(), keyData.data(), &keyLen); + st.storeKey(pubkey.ref(), keyData); + + return SecretKexKey(std::move(seckey), pubkey); +} + +optional<SecretKexKey> SecretKexKey::load(const Stored<PublicKexKey> & pub) +{ + auto keyData = pub.ref().storage().loadKey(pub.ref()); + if (!keyData) + return nullopt; + + EVP_PKEY * key = EVP_PKEY_new_raw_private_key(EVP_PKEY_X25519, nullptr, + keyData->data(), keyData->size()); + if (!key) + throw runtime_error("falied to parse secret key"); + return SecretKexKey(key, pub); +} + +vector<uint8_t> SecretKexKey::dh(const PublicKexKey & pubkey) const +{ + unique_ptr<EVP_PKEY_CTX, void(*)(EVP_PKEY_CTX*)> + pctx(EVP_PKEY_CTX_new(key.get(), nullptr), &EVP_PKEY_CTX_free); + if (!pctx) + throw runtime_error("failed to derive shared secret"); + + if (EVP_PKEY_derive_init(pctx.get()) <= 0) + throw runtime_error("failed to derive shared secret"); + + if (EVP_PKEY_derive_set_peer(pctx.get(), pubkey.key.get()) <= 0) + throw runtime_error("failed to derive shared secret"); + + size_t dhlen; + if (EVP_PKEY_derive(pctx.get(), NULL, &dhlen) <= 0) + throw runtime_error("failed to derive shared secret"); + + vector<uint8_t> dhsecret(dhlen); + + if (EVP_PKEY_derive(pctx.get(), dhsecret.data(), &dhlen) <= 0) + throw runtime_error("failed to derive shared secret"); + + return dhsecret; +} diff --git a/src/pubkey.h b/src/pubkey.h new file mode 100644 index 0000000..ca662ba --- /dev/null +++ b/src/pubkey.h @@ -0,0 +1,175 @@ +#pragma once + +#include "storage.h" + +#include <openssl/evp.h> + +using std::nullopt; +using std::optional; +using std::shared_ptr; + +namespace erebos { + +template<typename T> class Signed; + +class PublicKey +{ + PublicKey(EVP_PKEY * key): + key(key, EVP_PKEY_free) {} + friend class SecretKey; +public: + static PublicKey load(const Ref &); + Ref store(const Storage &) const; + + const shared_ptr<EVP_PKEY> key; +}; + +class SecretKey +{ + SecretKey(EVP_PKEY * key, const Stored<PublicKey> & pub): + key(key, EVP_PKEY_free), pub_(pub) {} + SecretKey(shared_ptr<EVP_PKEY> && key, const Stored<PublicKey> & pub): + key(key), pub_(pub) {} +public: + static SecretKey generate(const Storage & st); + static optional<SecretKey> load(const Stored<PublicKey> & st); + + static optional<SecretKey> fromData(const Stored<PublicKey> &, const vector<uint8_t> &); + vector<uint8_t> getData() const; + + Stored<PublicKey> pub() const { return pub_; } + + template<class T> + Stored<Signed<T>> sign(const Stored<T> &) const; + template<class T> + Stored<Signed<T>> signAdd(const Stored<Signed<T>> &) const; + +private: + vector<uint8_t> sign(const Digest &) const; + + const shared_ptr<EVP_PKEY> key; + Stored<PublicKey> pub_; +}; + +class Signature +{ +public: + static Signature load(const Ref &); + Ref store(const Storage &) const; + + bool verify(const Ref &) const; + + Stored<PublicKey> key; + vector<uint8_t> sig; + +private: + friend class SecretKey; + Signature(const Stored<PublicKey> & key, const vector<uint8_t> & sig): + key(key), sig(sig) {} +}; + +template<typename T> +class Signed +{ +public: + static Signed<T> load(const Ref &); + Ref store(const Storage &) const; + + bool isSignedBy(const Stored<PublicKey> &) const; + + const Stored<T> data; + const vector<Stored<Signature>> sigs; + +private: + friend class SecretKey; + Signed(const Stored<T> & data, const vector<Stored<Signature>> & sigs): + data(data), sigs(sigs) {} +}; + +template<class T> +Stored<Signed<T>> SecretKey::sign(const Stored<T> & val) const +{ + auto st = val.ref().storage(); + auto sig = st.store(Signature(pub(), sign(val.ref().digest()))); + return st.store(Signed(val, { sig })); +} + +template<class T> +Stored<Signed<T>> SecretKey::signAdd(const Stored<Signed<T>> & val) const +{ + auto st = val.ref().storage(); + auto sig = st.store(Signature(pub(), sign(val.ref().digest()))); + auto sigs = val->sigs; + sigs.push_back(st.store(Signature(pub(), sign(val->data.ref().digest())))); + return st.store(Signed(val->data, sigs)); +} + +template<typename T> +Signed<T> Signed<T>::load(const Ref & ref) +{ + if (auto rec = ref->asRecord()) + if (auto data = rec->item("SDATA").as<T>()) { + vector<Stored<Signature>> sigs; + for (const auto & sig : rec->items("sig").as<Signature>()) + if (sig->verify(data.value().ref())) + sigs.push_back(sig); + + return Signed(*data, sigs); + } + + return Signed(Stored<T>::load(ref.storage().zref()), {}); +} + +template<typename T> +Ref Signed<T>::store(const Storage & st) const +{ + vector<Record::Item> items; + + items.emplace_back("SDATA", data); + for (const auto & sig : sigs) + items.emplace_back("sig", sig); + + return st.storeObject(Record(std::move(items))); +} + +template<typename T> +bool Signed<T>::isSignedBy(const Stored<PublicKey> & key) const +{ + for (const auto & sig : sigs) + if (sig->key == key) + return true; + return false; +} + + +class PublicKexKey +{ + PublicKexKey(EVP_PKEY * key): + key(key, EVP_PKEY_free) {} + friend class SecretKexKey; +public: + static PublicKexKey load(const Ref &); + Ref store(const Storage &) const; + + const shared_ptr<EVP_PKEY> key; +}; + +class SecretKexKey +{ + SecretKexKey(EVP_PKEY * key, const Stored<PublicKexKey> & pub): + key(key, EVP_PKEY_free), pub_(pub) {} + SecretKexKey(shared_ptr<EVP_PKEY> && key, const Stored<PublicKexKey> & pub): + key(key), pub_(pub) {} +public: + static SecretKexKey generate(const Storage & st); + static optional<SecretKexKey> load(const Stored<PublicKexKey> & st); + + Stored<PublicKexKey> pub() const { return pub_; } + vector<uint8_t> dh(const PublicKexKey &) const; + +private: + const shared_ptr<EVP_PKEY> key; + Stored<PublicKexKey> pub_; +}; + +} diff --git a/src/service.cpp b/src/service.cpp new file mode 100644 index 0000000..fc1ec5f --- /dev/null +++ b/src/service.cpp @@ -0,0 +1,46 @@ +#include "service.h" + +using namespace erebos; + +Service::Service() = default; +Service::~Service() = default; + +Service::Context::Context(Priv * p): + p(p) +{} + +Service::Context::Priv & Service::Context::priv() +{ + return *p; +} + +const Ref & Service::Context::ref() const +{ + return p->ref; +} + +const Peer & Service::Context::peer() const +{ + return p->peer; +} + +const Stored<LocalState> & Service::Context::local() const +{ + return p->local; +} + +void Service::Context::local(const LocalState & ls) +{ + p->local = p->local.ref().storage().store(ls); +} + +void Service::Context::afterCommit(function<void()> hook) +{ + p->afterCommit.push_back(move(hook)); +} + +void Service::Context::runAfterCommitHooks() const +{ + for (const auto & hook : p->afterCommit) + hook(); +} diff --git a/src/service.h b/src/service.h new file mode 100644 index 0000000..dada1fb --- /dev/null +++ b/src/service.h @@ -0,0 +1,16 @@ +#pragma once + +#include <erebos/network.h> +#include <erebos/service.h> + +namespace erebos { + +struct Service::Context::Priv +{ + Ref ref; + Peer peer; + Stored<LocalState> local; + vector<function<void()>> afterCommit {}; +}; + +} diff --git a/src/set.cpp b/src/set.cpp new file mode 100644 index 0000000..ce343d8 --- /dev/null +++ b/src/set.cpp @@ -0,0 +1,180 @@ +#include "set.h" + +#include <unordered_map> +#include <unordered_set> +#include <utility> + +namespace erebos { + +using std::pair; +using std::unordered_map; +using std::unordered_set; +using std::move; + +SetBase::SetBase(): + p(make_shared<Priv>()) +{ +} + +SetBase::SetBase(const vector<Ref> & refs) +{ + vector<Stored<SetItem>> items; + for (const auto & r : refs) + items.push_back(Stored<SetItem>::load(r)); + + p = shared_ptr<Priv>(new Priv { + .items = move(items), + }); +} + +SetBase::SetBase(shared_ptr<const Priv> p_): + p(move(p_)) +{ +} + +shared_ptr<const SetBase::Priv> SetBase::add(const Storage & st, const vector<Ref> & refs) const +{ + auto item = st.store(SetItem { + .prev = p->items, + .item = refs, + }); + + return shared_ptr<const Priv>(new Priv { + .items = { move(item) }, + }); +} + +static void gatherSetItems(unordered_set<Digest> & seenSet, unordered_set<Digest> & seenElem, + vector<Ref> & res, const Stored<SetItem> & item) +{ + if (!seenElem.insert(item.ref().digest()).second) + return; + + for (const auto & r : item->item) + if (seenSet.insert(r.digest()).second) + res.push_back(r); + + for (const auto & p : item->prev) + gatherSetItems(seenSet, seenElem, res, p); +} + +vector<vector<Ref>> SetBase::toList() const +{ + /* Splits the graph starting from all set item refs into connected + * components (partitions), each such partition makes one set item, + * merged together in the templated SetView constructor. */ + + // Gather all item references + vector<Ref> items; + { + unordered_set<Digest> seenSet, seenElem; + for (const auto & i : p->items) + gatherSetItems(seenSet, seenElem, items, i); + } + + unordered_map<Digest, unsigned> partMap; // maps item ref to partition number + vector<unsigned> partMerge; // maps partitions to resulting one after partition merge + + // Use (cached) root set for assigning partition numbers + for (const auto & item : items) { + const auto roots = item.roots(); + unsigned part = partMerge.size(); + + // If any root has partition number already, pick the smallest one + for (const auto & rdgst : roots) { + auto it = partMap.find(rdgst); + if (it != partMap.end() && it->second < part) + part = it->second; + } + + // Update partition number for the roots and if this item + // merges some partitions, also update the merge info + for (const auto & rdgst : roots) { + auto it = partMap.find(rdgst); + if (it == partMap.end()) { + partMap.emplace(rdgst, part); + } else if (it->second != part) { + partMerge[it->second] = part; + it->second = part; + } + } + + // If no existing partition has been touched, mark a new one + if (part == partMerge.size()) + partMerge.push_back(part); + + // And store resulting partition number + partMap.emplace(item.digest(), part); + } + + // Get all the refs for each partition + vector<vector<Ref>> res(partMerge.size()); + for (const auto & item : items) { + unsigned part = partMap[item.digest()]; + for (unsigned p = partMerge[part]; p != part; p = partMerge[p]) + part = p; + res[part].push_back(item); + } + + // Remove empty elements (merged partitions) from result list + res.erase(std::remove(res.begin(), res.end(), vector<Ref>()), res.end()); + + return res; +} + +bool SetBase::operator==(const SetBase & other) const +{ + return p->items == other.p->items; +} + +bool SetBase::operator!=(const SetBase & other) const +{ + return !(*this == other); +} + +vector<Digest> SetBase::digests() const +{ + vector<Digest> res; + res.reserve(p->items.size()); + for (const auto & i : p->items) + res.push_back(i.ref().digest()); + return res; +} + +vector<Ref> SetBase::store() const +{ + vector<Ref> res; + res.reserve(p->items.size()); + for (const auto & i : p->items) + res.push_back(i.ref()); + return res; +} + +SetItem SetItem::load(const Ref & ref) +{ + if (auto rec = ref->asRecord()) { + return SetItem { + .prev = rec->items("PREV").as<SetItem>(), + .item = rec->items("item").asRef(), + }; + } + + return SetItem { + .prev = {}, + .item = {}, + }; +} + +Ref SetItem::store(const Storage & st) const +{ + vector<Record::Item> items; + + for (const auto & p : prev) + items.emplace_back("PREV", p.ref()); + for (const auto & r : item) + items.emplace_back("item", r); + + return st.storeObject(Record(std::move(items))); +} + +} diff --git a/src/set.h b/src/set.h new file mode 100644 index 0000000..ffbcbd6 --- /dev/null +++ b/src/set.h @@ -0,0 +1,19 @@ +#include <erebos/set.h> + +namespace erebos { + +struct SetItem +{ + static SetItem load(const Ref &); + Ref store(const Storage & st) const; + + const vector<Stored<SetItem>> prev; + const vector<Ref> item; +}; + +struct SetBase::Priv +{ + vector<Stored<SetItem>> items; +}; + +} diff --git a/src/state.cpp b/src/state.cpp new file mode 100644 index 0000000..31171d7 --- /dev/null +++ b/src/state.cpp @@ -0,0 +1,194 @@ +#include "state.h" + +#include "identity.h" + +using namespace erebos; + +using std::make_shared; + +const UUID LocalState::headTypeId { "1d7491a9-7bcb-4eaa-8f13-c8c4c4087e4e" }; + +LocalState::LocalState(): + p(make_shared<Priv>()) +{} + +LocalState::LocalState(const Ref & ref): + LocalState() +{ + auto rec = ref->asRecord(); + if (!rec) + return; + + if (auto x = rec->item("id").asRef()) + p->identity = Identity::load(*x); + + p->shared.tip = rec->items("shared").as<SharedData>(); + + if (p->identity) { + vector<StoredIdentityPart> updates; + for (const auto & r : lookupShared(SharedType<optional<Identity>>::id)) + updates.push_back(StoredIdentityPart::load(r)); + if (!updates.empty()) + p->identity = p->identity->update(updates); + } +} + +Ref LocalState::store(const Storage & st) const +{ + vector<Record::Item> items; + + if (p->identity) + items.emplace_back("id", *p->identity->extRef()); + for (const auto & x : p->shared.tip) + items.emplace_back("shared", x); + + return st.storeObject(Record(std::move(items))); +} + +const optional<Identity> & LocalState::identity() const +{ + return p->identity; +} + +LocalState LocalState::identity(const Identity & id) const +{ + LocalState ret; + ret.p->identity = id; + ret.p->shared = p->shared; + return ret; +} + +vector<Ref> LocalState::lookupShared(UUID type) const +{ + return p->shared.lookup(type); +} + +vector<Ref> SharedState::lookup(UUID type) const +{ + return p->lookup(type); +} + +vector<Ref> SharedState::Priv::lookup(UUID type) const +{ + vector<Stored<SharedData>> found; + vector<Stored<SharedData>> process = tip; + + while (!process.empty()) { + auto cur = std::move(process.back()); + process.pop_back(); + + if (cur->type == type) { + found.push_back(std::move(cur)); + continue; + } + + for (const auto & x : cur->prev) + process.push_back(x); + } + + filterAncestors(found); + vector<Ref> res; + for (const auto & s : found) + for (const auto & v : s->value) + res.push_back(v); + return res; +} + +vector<Ref> LocalState::sharedRefs() const +{ + vector<Ref> refs; + for (const auto & x : p->shared.tip) + refs.push_back(x.ref()); + return refs; +} + +LocalState LocalState::sharedRefAdd(const Ref & ref) const +{ + const Storage * st; + if (p->shared.tip.size() > 0) + st = &p->shared.tip[0].ref().storage(); + else if (p->identity) + st = &p->identity->ref()->storage(); + else + st = &ref.storage(); + + LocalState ret; + ret.p->identity = p->identity; + ret.p->shared = p->shared; + ret.p->shared.tip.push_back(SharedData(ref).store(*st)); + filterAncestors(ret.p->shared.tip); + return ret; +} + +LocalState LocalState::updateShared(UUID type, const vector<Ref> & xs) const +{ + const Storage * st; + if (p->shared.tip.size() > 0) + st = &p->shared.tip[0].ref().storage(); + else if (p->identity) + st = &p->identity->ref()->storage(); + else if (xs.size() > 0) + st = &xs[0].storage(); + else + return *this; + + LocalState ret; + ret.p->identity = p->identity; + ret.p->shared.tip.push_back(SharedData(p->shared.tip, type, xs).store(*st)); + return ret; +} + + +bool SharedState::operator==(const SharedState & other) const +{ + return p->tip == other.p->tip; +} + +bool SharedState::operator!=(const SharedState & other) const +{ + return p->tip != other.p->tip; +} + + +SharedData::SharedData(const Ref & ref) +{ + auto rec = ref->asRecord(); + if (!rec) + return; + + prev = rec->items("PREV").as<SharedData>(); + if (auto x = rec->item("type").asUUID()) + type = *x; + value = rec->items("value").asRef(); +} + +Ref SharedData::store(const Storage & st) const +{ + vector<Record::Item> items; + + for (const auto & x : prev) + items.emplace_back("PREV", x); + items.emplace_back("type", type); + for (const auto & x : value) + items.emplace_back("value", x); + + return st.storeObject(Record(std::move(items))); +} + +template<> +optional<Identity> LocalState::lens<optional<Identity>>(const LocalState & x) +{ + return x.identity(); +} + +template<> +vector<Ref> LocalState::lens<vector<Ref>>(const LocalState & x) +{ + return x.sharedRefs(); +} + +template<> +SharedState LocalState::lens<SharedState>(const LocalState & x) +{ + return SharedState(shared_ptr<SharedState::Priv>(x.p, &x.p->shared)); +} diff --git a/src/state.h b/src/state.h new file mode 100644 index 0000000..397a906 --- /dev/null +++ b/src/state.h @@ -0,0 +1,41 @@ +#pragma once + +#include <erebos/state.h> +#include <erebos/identity.h> + +#include "pubkey.h" + +using std::optional; +using std::shared_ptr; +using std::vector; + +namespace erebos { + +struct SharedState::Priv +{ + vector<Ref> lookup(UUID) const; + + vector<Stored<struct SharedData>> tip; +}; + +struct LocalState::Priv +{ + optional<Identity> identity; + SharedState::Priv shared; +}; + +struct SharedData +{ + explicit SharedData(vector<Stored<SharedData>> prev, + UUID type, vector<Ref> value): + prev(prev), type(type), value(value) {} + explicit SharedData(const Ref &); + static SharedData load(const Ref & ref) { return SharedData(ref); } + Ref store(const Storage &) const; + + vector<Stored<SharedData>> prev; + UUID type; + vector<Ref> value; +}; + +} diff --git a/src/storage.cpp b/src/storage.cpp new file mode 100644 index 0000000..19f35a9 --- /dev/null +++ b/src/storage.cpp @@ -0,0 +1,1682 @@ +#include "storage.h" + +#include <charconv> +#include <chrono> +#include <fstream> +#include <iomanip> +#include <iterator> +#include <stdexcept> +#include <thread> + +#include <poll.h> +#include <stdio.h> +#include <sys/eventfd.h> +#include <sys/inotify.h> + +#include <blake2.h> +#include <zlib.h> + +using namespace erebos; + +using std::array; +using std::copy; +using std::get; +using std::holds_alternative; +using std::ifstream; +using std::invalid_argument; +using std::is_same_v; +using std::make_shared; +using std::make_unique; +using std::monostate; +using std::nullopt; +using std::ofstream; +using std::out_of_range; +using std::runtime_error; +using std::scoped_lock; +using std::shared_ptr; +using std::string; +using std::system_error; +using std::to_string; +using std::weak_ptr; + +void StorageWatchCallback::schedule(UUID uuid, const Digest & dgst) +{ + scoped_lock lock(runMutex); + scheduled.emplace(uuid, dgst); +} + +void StorageWatchCallback::run() +{ + scoped_lock lock(runMutex); + if (scheduled) { + auto [uuid, dgst] = *scheduled; + scheduled.reset(); // avoid running the callback twice + + callback(uuid, dgst); + } +} + +FilesystemStorage::FilesystemStorage(const fs::path & path): + root(path) +{ +} + +FilesystemStorage::~FilesystemStorage() +{ + if (inotifyWakeup >= 0) { + uint64_t x = 1; + write(inotifyWakeup, &x, sizeof(x)); + } + + if (watcherThread.joinable()) + watcherThread.join(); + + if (inotify >= 0) + close(inotify); + + if (inotifyWakeup >= 0) + close(inotifyWakeup); + +} + +bool FilesystemStorage::contains(const Digest & digest) const +{ + return fs::exists(objectPath(digest)); +} + +optional<vector<uint8_t>> FilesystemStorage::loadBytes(const Digest & digest) const +{ + vector<uint8_t> in(CHUNK); + vector<uint8_t> out; + size_t decoded = 0; + + z_stream strm; + strm.zalloc = Z_NULL; + strm.zfree = Z_NULL; + strm.opaque = Z_NULL; + strm.avail_in = 0; + strm.next_in = Z_NULL; + int ret = inflateInit(&strm); + if (ret != Z_OK) + throw runtime_error("zlib initialization failed"); + + ifstream fin(objectPath(digest), std::ios::binary); + if (!fin.is_open()) + return nullopt; + + while (!fin.eof() && ret != Z_STREAM_END) { + fin.read((char*) in.data(), in.size()); + if (fin.bad()) { + inflateEnd(&strm); + throw runtime_error("failed to read stored file"); + } + strm.avail_in = fin.gcount(); + if (strm.avail_in == 0) + break; + strm.next_in = in.data(); + + do { + if (out.size() < decoded + in.size()) + out.resize(decoded + in.size()); + + strm.avail_out = out.size() - decoded; + strm.next_out = out.data() + decoded; + ret = inflate(&strm, Z_NO_FLUSH); + switch (ret) { + case Z_STREAM_ERROR: + case Z_NEED_DICT: + case Z_DATA_ERROR: + case Z_MEM_ERROR: + inflateEnd(&strm); + throw runtime_error("zlib decoding failed"); + } + decoded = out.size() - strm.avail_out; + } while (strm.avail_out == 0); + } + + + inflateEnd(&strm); + if (ret != Z_STREAM_END) + throw runtime_error("zlib decoding failed"); + + out.resize(decoded); + return out; +} + +void FilesystemStorage::storeBytes(const Digest & digest, const vector<uint8_t> & in) +{ + vector<uint8_t> out(CHUNK); + + z_stream strm; + strm.zalloc = Z_NULL; + strm.zfree = Z_NULL; + strm.opaque = Z_NULL; + int ret = deflateInit(&strm, Z_DEFAULT_COMPRESSION); + if (ret != Z_OK) + throw runtime_error("zlib initialization failed"); + + auto path = objectPath(digest); + auto lock = path; + lock += ".lock"; + + FILE * f = openLockFile(lock); + if (fs::exists(path)) { + if (f) { + fclose(f); + fs::remove(lock); + } + return; + } + if (!f) + throw runtime_error("failed to open storage file"); + + strm.avail_in = in.size(); + strm.next_in = const_cast<uint8_t*>(in.data()); + do { + strm.avail_out = out.size(); + strm.next_out = out.data(); + ret = deflate(&strm, Z_FINISH); + if (ret == Z_STREAM_ERROR) + break; + size_t have = out.size() - strm.avail_out; + if (fwrite(out.data(), 1, have, f) != have || ferror(f)) { + ret = Z_ERRNO; + break; + } + } while (strm.avail_out == 0); + + fclose(f); + deflateEnd(&strm); + + if (strm.avail_in != 0 || ret != Z_STREAM_END) { + fs::remove(lock); + throw runtime_error("failed to deflate object"); + } + + fs::rename(lock, path); +} + +optional<Digest> FilesystemStorage::headRef(UUID type, UUID id) const +{ + ifstream fin(headPath(type, id)); + if (!fin) + return nullopt; + + string sdgst; + fin >> sdgst; + return Digest(sdgst); +} + +vector<tuple<UUID, Digest>> FilesystemStorage::headRefs(UUID type) const +{ + vector<tuple<UUID, Digest>> res; + string stype(type); + fs::path ptype(stype.begin(), stype.end()); + try { + for (const auto & p : fs::directory_iterator(root/"heads"/ptype)) + if (auto u = UUID::fromString(p.path().filename())) { + ifstream fin(p.path()); + if (fin) { + string sdgst; + fin >> sdgst; + res.emplace_back(*u, Digest(sdgst)); + } + } + } catch (const fs::filesystem_error & e) { + if (e.code() == std::errc::no_such_file_or_directory) + return {}; + throw e; + } + return res; +} + +UUID FilesystemStorage::storeHead(UUID type, const Digest & dgst) +{ + auto id = UUID::generate(); + auto path = headPath(type, id); + fs::create_directories(path.parent_path()); + ofstream fout(path); + if (!fout) + throw runtime_error("failed to open head file"); + + fout << string(dgst) << '\n'; + return id; +} + +bool FilesystemStorage::replaceHead(UUID type, UUID id, const Digest & old, const Digest & dgst) +{ + auto path = headPath(type, id); + auto lock = path; + lock += ".lock"; + FILE * f = openLockFile(lock); + if (!f) + throw runtime_error(("failed to lock head file " + string(path) + + ": " + string(strerror(errno))).c_str()); + + string scur; + ifstream fin(path); + fin >> scur; + fin.close(); + Digest cur(scur); + + if (cur != old) { + fclose(f); + unlink(lock.c_str()); + return false; + } + + fprintf(f, "%s\n", string(dgst).c_str()); + fclose(f); + fs::rename(lock, path); + return true; +} + +int FilesystemStorage::watchHead(UUID type, const function<void(UUID id, const Digest &)> & watcher) +{ + scoped_lock lock(watcherLock); + int wid = nextWatcherId++; + + if (inotify < 0) { + inotify = inotify_init(); + if (inotify < 0) + throw system_error(errno, std::generic_category()); + + inotifyWakeup = eventfd(0, 0); + if (inotifyWakeup < 0) + throw system_error(errno, std::generic_category()); + + watcherThread = std::thread(&FilesystemStorage::inotifyWatch, this); + } + + if (watchers.find(type) == watchers.end()) { + int wd = inotify_add_watch(inotify, headPath(type).c_str(), IN_MOVED_TO); + if (wd < 0) + throw system_error(errno, std::generic_category()); + + watchMap[wd] = type; + } + watchers.emplace(type, make_shared<StorageWatchCallback>(wid, watcher)); + + return wid; +} + +void FilesystemStorage::unwatchHead(UUID type, int wid) +{ + shared_ptr<StorageWatchCallback> cb; + + { + scoped_lock lock(watcherLock); + + if (inotify < 0) + return; + + auto range = watchers.equal_range(type); + for (auto it = range.first; it != range.second; it++) { + if (it->second->id == wid) { + cb = move(it->second); + watchers.erase(it); + break; + } + } + + if (watchers.find(type) == watchers.end()) { + for (auto it = watchMap.begin(); it != watchMap.end(); it++) { + if (it->second == type) { + if (inotify_rm_watch(inotify, it->first) < 0) + throw system_error(errno, std::generic_category()); + watchMap.erase(it); + break; + } + } + } + } + + // Run the last callback if scheduled and not yet executed + if (cb) + cb->run(); +} + +optional<vector<uint8_t>> FilesystemStorage::loadKey(const Digest & pubref) const +{ + fs::path path = keyPath(pubref); + std::error_code err; + size_t size = fs::file_size(path, err); + if (err) + return nullopt; + + vector<uint8_t> key(size); + ifstream file(keyPath(pubref)); + file.read((char *) key.data(), size); + return key; +} + +void FilesystemStorage::storeKey(const Digest & pubref, const vector<uint8_t> & key) +{ + fs::path path = keyPath(pubref); + fs::create_directories(path.parent_path()); + ofstream file(path); + file.write((const char *) key.data(), key.size()); +} + +void FilesystemStorage::inotifyWatch() +{ + char buf[4096] + __attribute__ ((aligned(__alignof__(struct inotify_event)))); + const struct inotify_event * event; + + array pfds { + pollfd { inotify, POLLIN, 0 }, + pollfd { inotifyWakeup, POLLIN, 0 }, + }; + + while (true) { + int ret = poll(pfds.data(), pfds.size(), -1); + if (ret < 0) + throw system_error(errno, std::generic_category()); + + if (!(pfds[0].revents & POLLIN)) + break; + + ssize_t len = read(inotify, buf, sizeof buf); + if (len < 0) { + if (errno == EAGAIN) + continue; + + throw system_error(errno, std::generic_category()); + } + + if (len == 0) + break; + + for (char * ptr = buf; ptr < buf + len; + ptr += sizeof(struct inotify_event) + event->len) { + event = (const struct inotify_event *) ptr; + + if (event->mask & IN_MOVED_TO) { + vector<shared_ptr<StorageWatchCallback>> callbacks; + + { + // Copy relevant callbacks to temporary array, so they + // can be called without holding the watcherLock. + + scoped_lock lock(watcherLock); + UUID type = watchMap[event->wd]; + if (auto mbid = UUID::fromString(event->name)) { + if (auto mbref = headRef(type, *mbid)) { + auto range = watchers.equal_range(type); + for (auto it = range.first; it != range.second; it++) { + it->second->schedule(*mbid, *mbref); + callbacks.push_back(it->second); + } + } + } + } + + for (const auto & cb : callbacks) + cb->run(); + } + } + } +} + +fs::path FilesystemStorage::objectPath(const Digest & digest) const +{ + string name(digest); + size_t delim = name.find('#'); + + return root/"objects"/ + fs::path(name.begin(), name.begin() + delim)/ + fs::path(name.begin() + delim + 1, name.begin() + delim + 3)/ + fs::path(name.begin() + delim + 3, name.end()); +} + +fs::path FilesystemStorage::headPath(UUID type) const +{ + string stype(type); + return root/"heads"/fs::path(stype.begin(), stype.end()); +} + +fs::path FilesystemStorage::headPath(UUID type, UUID id) const +{ + string sid(id); + return headPath(type) / fs::path(sid.begin(), sid.end()); +} + +fs::path FilesystemStorage::keyPath(const Digest & digest) const +{ + string name(digest); + return root/"keys"/fs::path(name.begin(), name.end()); +} + +FILE * FilesystemStorage::openLockFile(const fs::path & path) const +{ + fs::create_directories(path.parent_path()); + + // No way to use open exclusively in c++ stdlib + FILE *f = nullptr; + for (int i = 0; i < 10; i++) { + f = fopen(path.c_str(), "wbxe"); + if (f || errno != EEXIST) + break; + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + + return f; +} + + +bool MemoryStorage::contains(const Digest & digest) const +{ + return storage.find(digest) != storage.end(); +} + +optional<vector<uint8_t>> MemoryStorage::loadBytes(const Digest & digest) const +{ + auto it = storage.find(digest); + if (it != storage.end()) + return it->second; + return nullopt; +} + +void MemoryStorage::storeBytes(const Digest & digest, const vector<uint8_t> & content) +{ + storage.emplace(digest, content); +} + +optional<Digest> MemoryStorage::headRef(UUID type, UUID id) const +{ + auto it = heads.find(type); + if (it == heads.end()) + return nullopt; + + for (const auto & x : it->second) + if (get<UUID>(x) == id) + return get<Digest>(x); + + return nullopt; +} + +vector<tuple<UUID, Digest>> MemoryStorage::headRefs(UUID type) const +{ + auto it = heads.find(type); + if (it != heads.end()) + return it->second; + return {}; +} + +UUID MemoryStorage::storeHead(UUID type, const Digest & dgst) +{ + auto id = UUID::generate(); + auto it = heads.find(type); + if (it == heads.end()) + heads[type] = { { id, dgst } }; + else + it->second.emplace_back(id, dgst); + return id; +} + +bool MemoryStorage::replaceHead(UUID type, UUID id, const Digest & old, const Digest & dgst) +{ + auto it = heads.find(type); + if (it == heads.end()) + return false; + + for (auto & x : it->second) + if (get<UUID>(x) == id) { + if (get<Digest>(x) == old) { + get<Digest>(x) = dgst; + return true; + } else { + return false; + } + } + + return false; +} + +int MemoryStorage::watchHead(UUID type, const function<void(UUID id, const Digest &)> & watcher) +{ + scoped_lock lock(watcherLock); + int wid = nextWatcherId++; + watchers.emplace(type, make_shared<StorageWatchCallback>(wid, watcher)); + return wid; +} + +void MemoryStorage::unwatchHead(UUID type, int wid) +{ + shared_ptr<StorageWatchCallback> cb; + + { + scoped_lock lock(watcherLock); + + auto range = watchers.equal_range(type); + for (auto it = range.first; it != range.second; it++) { + if (it->second->id == wid) { + cb = move(it->second); + watchers.erase(it); + break; + } + } + } + + // Run the last callback if scheduled and not yet executed + if (cb) + cb->run(); +} + +optional<vector<uint8_t>> MemoryStorage::loadKey(const Digest & digest) const +{ + auto it = keys.find(digest); + if (it != keys.end()) + return it->second; + return nullopt; +} + +void MemoryStorage::storeKey(const Digest & digest, const vector<uint8_t> & content) +{ + keys.emplace(digest, content); +} + +bool ChainStorage::contains(const Digest & digest) const +{ + return storage->contains(digest) || + (parent && parent->contains(digest)); +} + +optional<vector<uint8_t>> ChainStorage::loadBytes(const Digest & digest) const +{ + if (auto res = storage->loadBytes(digest)) + return res; + if (parent) + return parent->loadBytes(digest); + return nullopt; +} + +void ChainStorage::storeBytes(const Digest & digest, const vector<uint8_t> & content) +{ + storage->storeBytes(digest, content); +} + +optional<Digest> ChainStorage::headRef(UUID type, UUID id) const +{ + if (auto res = storage->headRef(type, id)) + return res; + if (parent) + return parent->headRef(type, id); + return nullopt; +} + +vector<tuple<UUID, Digest>> ChainStorage::headRefs(UUID type) const +{ + auto res = storage->headRefs(type); + if (parent) + for (auto x : parent->headRefs(type)) { + bool add = true; + for (const auto & y : res) + if (get<UUID>(y) == get<UUID>(x)) { + add = false; + break; + } + if (add) + res.push_back(x); + } + return res; +} + +UUID ChainStorage::storeHead(UUID type, const Digest & dgst) +{ + return storage->storeHead(type, dgst); +} + +bool ChainStorage::replaceHead(UUID type, UUID id, const Digest & old, const Digest & dgst) +{ + return storage->replaceHead(type, id, old, dgst); +} + +int ChainStorage::watchHead(UUID type, const function<void(UUID id, const Digest &)> & watcher) +{ + scoped_lock lock(watcherLock); + int wid = nextWatcherId++; + + int id1 = parent->watchHead(type, watcher); + int id2 = storage->watchHead(type, watcher); + watchers.emplace(wid, tuple(id1, id2)); + + return wid; +} + +void ChainStorage::unwatchHead(UUID type, int wid) +{ + scoped_lock lock(watcherLock); + + auto [id1, id2] = watchers.extract(wid).mapped(); + parent->unwatchHead(type, id1); + storage->unwatchHead(type, id2); +} + +optional<vector<uint8_t>> ChainStorage::loadKey(const Digest & digest) const +{ + if (auto res = storage->loadKey(digest)) + return res; + if (parent) + return parent->loadKey(digest); + return nullopt; +} + +void ChainStorage::storeKey(const Digest & digest, const vector<uint8_t> & content) +{ + storage->storeKey(digest, content); +} + + +Storage::Storage(const fs::path & path): + PartialStorage(shared_ptr<Priv>(new Priv { .backend = make_shared<FilesystemStorage>(path) })) +{} + +Storage Storage::deriveEphemeralStorage() const +{ + return Storage(shared_ptr<Priv>(new Priv { .backend = + make_shared<ChainStorage>( + make_shared<MemoryStorage>(), + make_unique<ChainStorage>(p->backend) + )})); +} + +PartialStorage Storage::derivePartialStorage() const +{ + return PartialStorage(shared_ptr<Priv>(new Priv { .backend = + make_shared<ChainStorage>( + make_shared<MemoryStorage>(), + make_unique<ChainStorage>(p->backend) + )})); +} + +bool PartialStorage::operator==(const PartialStorage & other) const +{ + return p == other.p; +} + +bool PartialStorage::operator!=(const PartialStorage & other) const +{ + return p != other.p; +} + +PartialRef PartialStorage::ref(const Digest & digest) const +{ + return PartialRef::create(*this, digest); +} + +optional<Ref> Storage::ref(const Digest & digest) const +{ + return Ref::create(*this, digest); +} + +Ref Storage::zref() const +{ + return Ref::zcreate(*this); +} + +Digest PartialStorage::Priv::storeBytes(const vector<uint8_t> & content) const +{ + Digest digest = Digest::of(content); + backend->storeBytes(digest, content); + return digest; +} + +optional<vector<uint8_t>> PartialStorage::Priv::loadBytes(const Digest & digest) const +{ + auto ocontent = backend->loadBytes(digest); + if (!ocontent.has_value()) + return nullopt; + auto content = ocontent.value(); + + if (digest != Digest::of(content)) + throw runtime_error("digest verification failed"); + + return content; +} + +optional<PartialObject> PartialStorage::loadObject(const Digest & digest) const +{ + if (digest.isZero()) + return PartialObject(monostate()); + if (auto content = p->loadBytes(digest)) + return PartialObject::decode(*this, *content); + return nullopt; +} + +PartialRef PartialStorage::storeObject(const PartialObject & obj) const +{ + if (not obj) + return PartialRef::zcreate(*this); + return ref(p->storeBytes(obj.encode())); +} + +PartialRef PartialStorage::storeObject(const PartialRecord & val) const +{ return storeObject(PartialObject(val)); } + +PartialRef PartialStorage::storeObject(const Blob & val) const +{ return storeObject(PartialObject(val)); } + +optional<Object> Storage::loadObject(const Digest & digest) const +{ + if (digest.isZero()) + return Object(monostate()); + if (auto content = p->loadBytes(digest)) + return Object::decode(*this, *content); + return nullopt; +} + +Ref Storage::storeObject(const Object & object) const +{ return copy(object); } + +Ref Storage::storeObject(const Record & val) const +{ return storeObject(Object(val)); } + +Ref Storage::storeObject(const Blob & val) const +{ return storeObject(Object(val)); } + +template<class S> +optional<Digest> Storage::Priv::copy(const typename S::Ref & pref, vector<Digest> * missing) const +{ + if (pref.digest().isZero()) + return pref.digest(); + if (backend->contains(pref.digest())) + return pref.digest(); + if (pref) + return copy<S>(*pref, missing); + if (missing) + missing->push_back(pref.digest()); + return nullopt; +} + +template<class S> +optional<Digest> Storage::Priv::copy(const ObjectT<S> & pobj, vector<Digest> * missing) const +{ + if (not pobj) + return Digest(array<uint8_t, Digest::size> {}); + + bool fail = false; + if (auto rec = pobj.asRecord()) + for (const auto & r : rec->items().asRef()) + if (!copy<S>(r, missing)) + fail = true; + + if (fail) + return nullopt; + + return storeBytes(pobj.encode()); +} + +variant<Ref, vector<Digest>> Storage::copy(const PartialRef & pref) const +{ + vector<Digest> missing; + if (auto digest = p->copy<PartialStorage>(pref, &missing)) + return Ref::create(*this, *digest).value(); + return missing; +} + +variant<Ref, vector<Digest>> Storage::copy(const PartialObject & pobj) const +{ + vector<Digest> missing; + if (auto digest = p->copy<PartialStorage>(pobj, &missing)) + return Ref::create(*this, *digest).value(); + return missing; +} + +Ref Storage::copy(const Ref & ref) const +{ + if (auto digest = p->copy<Storage>(ref, nullptr)) + return Ref::create(*this, *digest).value(); + throw runtime_error("corrupted storage"); +} + +Ref Storage::copy(const Object & obj) const +{ + if (auto digest = p->copy<Storage>(obj, nullptr)) + return Ref::create(*this, *digest).value(); + throw runtime_error("corrupted storage"); +} + +void Storage::storeKey(Ref pubref, const vector<uint8_t> & key) const +{ + p->backend->storeKey(pubref.digest(), key); +} + +optional<vector<uint8_t>> Storage::loadKey(Ref pubref) const +{ + return p->backend->loadKey(pubref.digest()); +} + +optional<Ref> Storage::headRef(UUID type, UUID id) const +{ + if (auto dgst = p->backend->headRef(type, id)) + return ref(*dgst); + return nullopt; +} + +vector<tuple<UUID, Ref>> Storage::headRefs(UUID type) const +{ + vector<tuple<UUID, Ref>> res; + for (auto x : p->backend->headRefs(type)) + if (auto r = ref(get<Digest>(x))) + res.emplace_back(get<UUID>(x), *r); + return res; +} + +UUID Storage::storeHead(UUID type, const Ref & ref) +{ + return ref.storage().p->backend->storeHead(type, ref.digest()); +} + +bool Storage::replaceHead(UUID type, UUID id, const Ref & old, const Ref & ref) +{ + return ref.storage().p->backend->replaceHead(type, id, old.digest(), ref.digest()); +} + +optional<Ref> Storage::updateHead(UUID type, UUID id, const Ref & old, const std::function<Ref(const Ref &)> & f) +{ + auto cur = old.storage().headRef(type, id); + if (!cur) + return nullopt; + + Ref r = f(*cur); + if (r.digest() == cur->digest() || replaceHead(type, id, *cur, r)) + return r; + + return updateHead(type, id, *cur, f); +} + +int Storage::watchHead(UUID type, UUID wid, const std::function<void(const Ref &)> watcher) const +{ + return p->backend->watchHead(type, [wp = weak_ptr<const Priv>(p), wid, watcher] (UUID id, const Digest & dgst) { + if (id == wid) + if (auto p = wp.lock()) + if (auto r = Ref::create(Storage(p), dgst)) + watcher(*r); + }); +} + +void Storage::unwatchHead(UUID type, UUID, int wid) const +{ + p->backend->unwatchHead(type, wid); +} + + +Digest::Digest(const string & str) +{ + if (str.size() != 2 * size + 7) + throw runtime_error("invalid ref digest"); + + if (strncmp(str.data(), "blake2#", 7) != 0) + throw runtime_error("invalid ref digest"); + + for (size_t i = 0; i < size; i++) + std::from_chars(str.data() + 7 + 2 * i, + str.data() + 7 + 2 * i + 2, + value[i], 16); +} + +Digest::operator string() const +{ + string res(size * 2 + 7, '0'); + memcpy(res.data(), "blake2#", 7); + for (size_t i = 0; i < size; i++) + std::to_chars(res.data() + 7 + 2 * i + (value[i] < 0x10), + res.data() + 7 + 2 * i + 2, + value[i], 16); + return res; +} + +bool Digest::isZero() const +{ + for (uint8_t x : value) + if (x) return false; + return true; +} + +Digest Digest::of(const vector<uint8_t> & content) +{ + array<uint8_t, size> arr; + int ret = blake2b(arr.data(), content.data(), nullptr, + size, content.size(), 0); + if (ret != 0) + throw runtime_error("failed to compute digest"); + + return Digest(arr); +} + + +PartialRef PartialRef::create(const PartialStorage & st, const Digest & digest) +{ + auto p = new Priv { + .storage = make_unique<PartialStorage>(st), + .digest = digest, + }; + + return PartialRef(shared_ptr<Priv>(p)); +} + +PartialRef PartialRef::zcreate(const PartialStorage & st) +{ + return create(st, Digest(array<uint8_t, Digest::size> {})); +} + +const Digest & PartialRef::digest() const +{ + return p->digest; +} + +PartialRef::operator bool() const +{ + return storage().p->backend->contains(p->digest); +} + +const PartialObject PartialRef::operator*() const +{ + if (auto res = p->storage->loadObject(p->digest)) + return *res; + throw runtime_error("failed to load object from partial storage"); +} + +unique_ptr<PartialObject> PartialRef::operator->() const +{ + return make_unique<PartialObject>(**this); +} + +const PartialStorage & PartialRef::storage() const +{ + return *p->storage; +} + +bool Ref::operator==(const Ref & other) const +{ + return p->digest == other.p->digest; +} + +bool Ref::operator!=(const Ref & other) const +{ + return p->digest != other.p->digest; +} + +optional<Ref> Ref::create(const Storage & st, const Digest & digest) +{ + if (!st.p->backend->contains(digest)) + return nullopt; + + auto p = new Priv { + .storage = make_unique<PartialStorage>(st), + .digest = digest, + }; + + return Ref(shared_ptr<Priv>(p)); +} + +Ref Ref::zcreate(const Storage & st) +{ + auto p = new Priv { + .storage = make_unique<PartialStorage>(st), + .digest = Digest(array<uint8_t, Digest::size> {}), + }; + + return Ref(shared_ptr<Priv>(p)); +} + +const Object Ref::operator*() const +{ + if (auto res = static_cast<Storage*>(p->storage.get())->loadObject(p->digest)) + return *res; + throw runtime_error("falied to load object - corrupted storage"); +} + +unique_ptr<Object> Ref::operator->() const +{ + return make_unique<Object>(**this); +} + +const Storage & Ref::storage() const +{ + return *static_cast<const Storage*>(p->storage.get()); +} + +vector<Ref> Ref::previous() const +{ + auto rec = (**this).asRecord(); + if (!rec) + return {}; + + if (auto sdata = rec->item("SDATA").asRef()) { + if (auto drec = sdata.value()->asRecord()) { + auto res = drec->items("SPREV").asRef(); + if (auto base = drec->item("SBASE").asRef()) + res.push_back(*base); + return res; + } + return {}; + } + + auto res = rec->items("PREV").asRef(); + if (auto base = rec->item("BASE").asRef()) + res.push_back(*base); + return res; +} + +Generation Ref::generation() const +{ + scoped_lock lock(p->storage->p->generationCacheLock); + return generationLocked(); +} + +Generation Ref::generationLocked() const +{ + auto it = p->storage->p->generationCache.find(p->digest); + if (it != p->storage->p->generationCache.end()) + return it->second; + + auto prev = previous(); + vector<Generation> pgen; + pgen.reserve(prev.size()); + for (const auto & r : prev) + pgen.push_back(r.generationLocked()); + + auto gen = Generation::next(pgen); + + p->storage->p->generationCache.emplace(p->digest, gen); + return gen; +} + +vector<Digest> Ref::roots() const +{ + scoped_lock lock(p->storage->p->rootsCacheLock); + return rootsLocked(); +} + +vector<Digest> Ref::rootsLocked() const +{ + auto it = p->storage->p->rootsCache.find(p->digest); + if (it != p->storage->p->rootsCache.end()) + return it->second; + + vector<Digest> roots; + auto prev = previous(); + + if (prev.empty()) { + roots.push_back(p->digest); + } else { + for (const auto & p : previous()) + for (const auto & r : p.rootsLocked()) + roots.push_back(r); + + std::sort(roots.begin(), roots.end()); + roots.erase(std::unique(roots.begin(), roots.end()), roots.end()); + } + + p->storage->p->rootsCache.emplace(p->digest, roots); + return roots; +} + + +template<class S> +RecordT<S>::Item::operator bool() const +{ + return !holds_alternative<monostate>(value); +} + +template<class S> +optional<typename RecordT<S>::Item::Empty> RecordT<S>::Item::asEmpty() const +{ + if (holds_alternative<RecordT<S>::Item::Empty>(value)) + return std::get<RecordT<S>::Item::Empty>(value); + return nullopt; +} + +template<class S> +optional<int> RecordT<S>::Item::asInteger() const +{ + if (holds_alternative<int>(value)) + return std::get<int>(value); + return nullopt; +} + +template<class S> +optional<string> RecordT<S>::Item::asText() const +{ + if (holds_alternative<string>(value)) + return std::get<string>(value); + return nullopt; +} + +template<class S> +optional<vector<uint8_t>> RecordT<S>::Item::asBinary() const +{ + if (holds_alternative<vector<uint8_t>>(value)) + return std::get<vector<uint8_t>>(value); + return nullopt; +} + +template<class S> +optional<ZonedTime> RecordT<S>::Item::asDate() const +{ + if (holds_alternative<ZonedTime>(value)) + return std::get<ZonedTime>(value); + return nullopt; +} + +template<class S> +optional<UUID> RecordT<S>::Item::asUUID() const +{ + if (holds_alternative<UUID>(value)) + return std::get<UUID>(value); + return nullopt; +} + +template<class S> +optional<typename S::Ref> RecordT<S>::Item::asRef() const +{ + if (holds_alternative<typename S::Ref>(value)) + return std::get<typename S::Ref>(value); + return nullopt; +} + +template<class S> +optional<typename RecordT<S>::Item::UnknownType> RecordT<S>::Item::asUnknown() const +{ + if (holds_alternative<typename Item::UnknownType>(value)) + return std::get<typename Item::UnknownType>(value); + return nullopt; +} + +template<class S> +RecordT<S>::Items::Items(shared_ptr<const vector<Item>> items): + items(move(items)), filter(nullopt) +{} + +template<class S> +RecordT<S>::Items::Items(shared_ptr<const vector<Item>> items, string filter): + items(move(items)), filter(move(filter)) +{} + +template<class S> +RecordT<S>::Items::Iterator::Iterator(const Items & source, size_t idx): + source(source), idx(idx) +{} + +template<class S> +typename RecordT<S>::Items::Iterator & RecordT<S>::Items::Iterator::operator++() +{ + const auto & items = *source.items; + do { + idx++; + } while (idx < items.size() && + source.filter && + items[idx].name != *source.filter); + return *this; +} + +template<class S> +typename RecordT<S>::Items::Iterator RecordT<S>::Items::begin() const +{ + return ++Iterator(*this, -1); +} + +template<class S> +typename RecordT<S>::Items::Iterator RecordT<S>::Items::end() const +{ + return Iterator(*this, items->size()); +} + +template<class S> +vector<typename RecordT<S>::Item::Empty> RecordT<S>::Items::asEmpty() const +{ + vector<Empty> res; + for (const auto & item : *this) + if (holds_alternative<Empty>(item.value)) + res.push_back(std::get<Empty>(item.value)); + return res; +} + +template<class S> +vector<typename RecordT<S>::Item::Integer> RecordT<S>::Items::asInteger() const +{ + vector<Integer> res; + for (const auto & item : *this) + if (holds_alternative<Integer>(item.value)) + res.push_back(std::get<Integer>(item.value)); + return res; +} + +template<class S> +vector<typename RecordT<S>::Item::Text> RecordT<S>::Items::asText() const +{ + vector<Text> res; + for (const auto & item : *this) + if (holds_alternative<Text>(item.value)) + res.push_back(std::get<Text>(item.value)); + return res; +} + +template<class S> +vector<typename RecordT<S>::Item::Binary> RecordT<S>::Items::asBinary() const +{ + vector<Binary> res; + for (const auto & item : *this) + if (holds_alternative<Binary>(item.value)) + res.push_back(std::get<Binary>(item.value)); + return res; +} + +template<class S> +vector<typename RecordT<S>::Item::Date> RecordT<S>::Items::asDate() const +{ + vector<Date> res; + for (const auto & item : *this) + if (holds_alternative<Date>(item.value)) + res.push_back(std::get<Date>(item.value)); + return res; +} + +template<class S> +vector<typename RecordT<S>::Item::UUID> RecordT<S>::Items::asUUID() const +{ + vector<UUID> res; + for (const auto & item : *this) + if (holds_alternative<UUID>(item.value)) + res.push_back(std::get<UUID>(item.value)); + return res; +} + +template<class S> +vector<typename RecordT<S>::Item::Ref> RecordT<S>::Items::asRef() const +{ + vector<Ref> res; + for (const auto & item : *this) + if (holds_alternative<Ref>(item.value)) + res.push_back(std::get<Ref>(item.value)); + return res; +} + +template<class S> +vector<typename RecordT<S>::Item::UnknownType> RecordT<S>::Items::asUnknown() const +{ + vector<UnknownType> res; + for (const auto & item : *this) + if (holds_alternative<UnknownType>(item.value)) + res.push_back(std::get<UnknownType>(item.value)); + return res; +} + + +template<class S> +RecordT<S>::RecordT(const vector<Item> & from): + ptr(new vector<Item>(from)) +{} + +template<class S> +RecordT<S>::RecordT(vector<Item> && from): + ptr(new vector<Item>(std::move(from))) +{} + +template<class S> +optional<RecordT<S>> RecordT<S>::decode(const S & st, + vector<uint8_t>::const_iterator begin, + vector<uint8_t>::const_iterator end) +{ + auto items = make_shared<vector<Item>>(); + + while (begin != end) { + const auto colon = std::find(begin, end, ':'); + if (colon == end) + return nullopt; + const string name(begin, colon); + + const auto space = std::find(colon + 1, end, ' '); + if (space == end) + return nullopt; + const string type(colon + 1, space); + + begin = space + 1; + string value; + for (bool cont = true; cont; ) { + auto newline = std::find(begin, end, '\n'); + if (newline == end) + return nullopt; + + if (newline + 1 != end && *(newline + 1) == '\t') + newline++; + else + cont = false; + + value.append(begin, newline); + begin = newline + 1; + } + + if (type == "e") { + if (value.size() != 0) + return nullopt; + items->emplace_back(name, typename Item::Empty {}); + } else if (type == "i") + try { + items->emplace_back(name, std::stoi(value)); + } catch (invalid_argument &) { + return nullopt; + } catch (out_of_range &) { + return nullopt; // TODO + } + else if (type == "t") + items->emplace_back(name, value); + else if (type == "b") { + if (value.size() % 2) + return nullopt; + vector<uint8_t> binary(value.size() / 2, 0); + + for (size_t i = 0; i < binary.size(); i++) + std::from_chars(value.data() + 2 * i, + value.data() + 2 * i + 2, + binary[i], 16); + items->emplace_back(name, std::move(binary)); + } else if (type == "d") + items->emplace_back(name, ZonedTime(value)); + else if (type == "u") + items->emplace_back(name, UUID(value)); + else if (type == "r") { + if constexpr (is_same_v<S, Storage>) { + if (auto ref = st.ref(Digest(value))) + items->emplace_back(name, ref.value()); + else + return nullopt; + } else if constexpr (std::is_same_v<S, PartialStorage>) { + items->emplace_back(name, st.ref(Digest(value))); + } + } else + items->emplace_back(name, + typename Item::UnknownType { type, value }); + } + + return RecordT<S>(items); +} + +template<class S> +vector<uint8_t> RecordT<S>::encode() const +{ + return ObjectT<S>(*this).encode(); +} + +template<class S> +typename RecordT<S>::Items RecordT<S>::items() const +{ + return Items(ptr); +} + +template<class S> +typename RecordT<S>::Item RecordT<S>::item(const string & name) const +{ + for (auto item : *ptr) { + if (item.name == name) + return item; + } + return Item("", monostate()); +} + +template<class S> +typename RecordT<S>::Item RecordT<S>::operator[](const string & name) const +{ + return item(name); +} + +template<class S> +typename RecordT<S>::Items RecordT<S>::items(const string & name) const +{ + return Items(ptr, name); +} + +template<class S> +vector<uint8_t> RecordT<S>::encodeInner() const +{ + vector<uint8_t> res; + auto inserter = std::back_inserter(res); + for (const auto & item : *ptr) { + string type; + string value; + + if (item.asEmpty()) { + type = "e"; + value = ""; + } else if (auto x = item.asInteger()) { + type = "i"; + value = to_string(*x); + } else if (auto x = item.asText()) { + type = "t"; + value = *x; + } else if (auto x = item.asBinary()) { + type = "b"; + value.resize(x->size() * 2, '0'); + for (size_t i = 0; i < x->size(); i++) + std::to_chars(value.data() + 2 * i + ((*x)[i] < 0x10), + value.data() + 2 * i + 2, + (*x)[i], 16); + } else if (auto x = item.asDate()) { + type = "d"; + value = string(*x); + } else if (auto x = item.asUUID()) { + type = "u"; + value = string(*x); + } else if (auto x = item.asRef()) { + type = "r"; + if (x->digest().isZero()) + continue; + value = string(x->digest()); + } else if (auto x = item.asUnknown()) { + type = x->type; + value = x->value; + } else { + throw runtime_error("unhandeled record item type"); + } + + copy(item.name.begin(), item.name.end(), inserter); + inserter = ':'; + copy(type.begin(), type.end(), inserter); + inserter = ' '; + + auto i = value.begin(); + while (true) { + auto j = std::find(i, value.end(), '\n'); + copy(i, j, inserter); + inserter = '\n'; + if (j == value.end()) + break; + inserter = '\t'; + i = j + 1; + } + } + return res; +} + +template class erebos::RecordT<Storage>; +template class erebos::RecordT<PartialStorage>; + + +Blob::Blob(const vector<uint8_t> & vec): + ptr(make_shared<vector<uint8_t>>(vec)) +{} + +vector<uint8_t> Blob::encode() const +{ + return Object(*this).encode(); +} + +vector<uint8_t> Blob::encodeInner() const +{ + return *ptr; +} + +Blob Blob::decode( + vector<uint8_t>::const_iterator begin, + vector<uint8_t>::const_iterator end) +{ + return Blob(make_shared<vector<uint8_t>>(begin, end)); +} + +template<class S> +optional<tuple<ObjectT<S>, vector<uint8_t>::const_iterator>> +ObjectT<S>::decodePrefix(const S & st, + vector<uint8_t>::const_iterator begin, + vector<uint8_t>::const_iterator end) +{ + auto newline = std::find(begin, end, '\n'); + if (newline == end) + return nullopt; + + auto space = std::find(begin, newline, ' '); + if (space == newline) + return nullopt; + + ssize_t size; + try { + size = std::stol(string(space + 1, newline)); + } catch (invalid_argument &) { + return nullopt; + } catch (out_of_range &) { + // Way too big to handle anyway + return nullopt; + } + if (end - newline - 1 < size) + return nullopt; + auto cend = newline + 1 + size; + + string type(begin, space); + optional<ObjectT<S>> obj; + if (type == "rec") + if (auto rec = RecordT<S>::decode(st, newline + 1, cend)) + obj.emplace(*rec); + else + return nullopt; + else if (type == "blob") + obj.emplace(Blob::decode(newline + 1, cend)); + else + throw runtime_error("unknown object type '" + type + "'"); + + if (obj) + return std::make_tuple(*obj, cend); + return nullopt; +} + +template<class S> +optional<ObjectT<S>> ObjectT<S>::decode(const S & st, const vector<uint8_t> & data) +{ + return decode(st, data.begin(), data.end()); +} + +template<class S> +optional<ObjectT<S>> ObjectT<S>::decode(const S & st, + vector<uint8_t>::const_iterator begin, + vector<uint8_t>::const_iterator end) +{ + if (auto res = decodePrefix(st, begin, end)) { + auto [obj, next] = *res; + if (next == end) + return obj; + } + return nullopt; +} + +template<class S> +vector<uint8_t> ObjectT<S>::encode() const +{ + vector<uint8_t> res, inner; + string type; + + if (auto rec = asRecord()) { + type = "rec"; + inner = rec->encodeInner(); + } else if (auto blob = asBlob()) { + type = "blob"; + inner = blob->encodeInner(); + } else { + throw runtime_error("unhandeled object type"); + } + + auto inserter = std::back_inserter(res); + copy(type.begin(), type.end(), inserter); + inserter = ' '; + + auto slen = to_string(inner.size()); + copy(slen.begin(), slen.end(), inserter); + inserter = '\n'; + + copy(inner.begin(), inner.end(), inserter); + return res; +} + +template<class S> +ObjectT<S> ObjectT<S>::load(const typename S::Ref & ref) +{ + return *ref; +} + +template<class S> +ObjectT<S>::operator bool() const +{ + return not holds_alternative<monostate>(content); +} + +template<class S> +optional<RecordT<S>> ObjectT<S>::asRecord() const +{ + if (holds_alternative<RecordT<S>>(content)) + return std::get<RecordT<S>>(content); + return nullopt; +} + +template<class S> +optional<Blob> ObjectT<S>::asBlob() const +{ + if (holds_alternative<Blob>(content)) + return std::get<Blob>(content); + return nullopt; +} + +template class erebos::ObjectT<Storage>; +template class erebos::ObjectT<PartialStorage>; + + +Generation::Generation(): Generation(0) {} +Generation::Generation(size_t g): gen(g) {} + +Generation Generation::next(const vector<Generation> & prev) +{ + Generation ret; + for (const auto g : prev) + if (ret.gen <= g.gen) + ret.gen = g.gen + 1; + return ret; +} + +Generation::operator string() const +{ + return to_string(gen); +} + + +vector<Stored<Object>> erebos::collectStoredObjects(const Stored<Object> & from) +{ + unordered_set<Digest> seen; + vector<Stored<Object>> queue { from }; + vector<Stored<Object>> res; + + while (!queue.empty()) { + auto cur = queue.back(); + queue.pop_back(); + + auto [it, added] = seen.insert(cur.ref().digest()); + if (!added) + continue; + + res.push_back(cur); + + if (auto rec = cur->asRecord()) + for (const auto & ref : rec->items().asRef()) + queue.push_back(Stored<Object>::load(ref)); + } + + return res; +} diff --git a/src/storage.h b/src/storage.h new file mode 100644 index 0000000..c6b5ed2 --- /dev/null +++ b/src/storage.h @@ -0,0 +1,201 @@ +#pragma once + +#include "erebos/storage.h" + +#include <functional> +#include <mutex> +#include <unordered_map> +#include <unordered_set> + +namespace fs = std::filesystem; + +using std::function; +using std::mutex; +using std::optional; +using std::shared_ptr; +using std::unique_ptr; +using std::unordered_map; +using std::unordered_multimap; +using std::unordered_set; +using std::tuple; +using std::variant; +using std::vector; + +namespace erebos { + +class StorageBackend +{ +public: + StorageBackend() = default; + virtual ~StorageBackend() = default; + + virtual bool contains(const Digest &) const = 0; + + virtual optional<vector<uint8_t>> loadBytes(const Digest &) const = 0; + virtual void storeBytes(const Digest &, const vector<uint8_t> &) = 0; + + virtual optional<Digest> headRef(UUID type, UUID id) const = 0; + virtual vector<tuple<UUID, Digest>> headRefs(UUID type) const = 0; + virtual UUID storeHead(UUID type, const Digest & dgst) = 0; + virtual bool replaceHead(UUID type, UUID id, const Digest & old, const Digest & dgst) = 0; + virtual int watchHead(UUID type, const function<void(UUID id, const Digest &)> &) = 0; + virtual void unwatchHead(UUID type, int watchId) = 0; + + virtual optional<vector<uint8_t>> loadKey(const Digest &) const = 0; + virtual void storeKey(const Digest &, const vector<uint8_t> &) = 0; +}; + +class StorageWatchCallback +{ +public: + StorageWatchCallback(int id, const function<void(UUID, const Digest &)> callback): + id(id), callback(callback) {} + + void schedule(UUID, const Digest &); + void run(); + + const int id; + +private: + const function<void(UUID, const Digest &)> callback; + + std::recursive_mutex runMutex; + optional<tuple<UUID, Digest>> scheduled; +}; + +class FilesystemStorage : public StorageBackend +{ +public: + FilesystemStorage(const fs::path &); + virtual ~FilesystemStorage(); + + virtual bool contains(const Digest &) const override; + + virtual optional<vector<uint8_t>> loadBytes(const Digest &) const override; + virtual void storeBytes(const Digest &, const vector<uint8_t> &) override; + + virtual optional<Digest> headRef(UUID type, UUID id) const override; + virtual vector<tuple<UUID, Digest>> headRefs(UUID type) const override; + virtual UUID storeHead(UUID type, const Digest & dgst) override; + virtual bool replaceHead(UUID type, UUID id, const Digest & old, const Digest & dgst) override; + virtual int watchHead(UUID type, const function<void(UUID id, const Digest &)> &) override; + virtual void unwatchHead(UUID type, int watchId) override; + + virtual optional<vector<uint8_t>> loadKey(const Digest &) const override; + virtual void storeKey(const Digest &, const vector<uint8_t> &) override; + +private: + void inotifyWatch(); + + static constexpr size_t CHUNK = 16384; + + fs::path objectPath(const Digest &) const; + fs::path headPath(UUID id) const; + fs::path headPath(UUID id, UUID type) const; + fs::path keyPath(const Digest &) const; + + FILE * openLockFile(const fs::path & path) const; + + fs::path root; + + mutex watcherLock; + std::thread watcherThread; + int inotify = -1; + int inotifyWakeup = -1; + int nextWatcherId = 1; + unordered_multimap<UUID, shared_ptr<StorageWatchCallback>> watchers; + unordered_map<int, UUID> watchMap; +}; + +class MemoryStorage : public StorageBackend +{ +public: + MemoryStorage() = default; + virtual ~MemoryStorage() = default; + + virtual bool contains(const Digest &) const override; + + virtual optional<vector<uint8_t>> loadBytes(const Digest &) const override; + virtual void storeBytes(const Digest &, const vector<uint8_t> &) override; + + virtual optional<Digest> headRef(UUID type, UUID id) const override; + virtual vector<tuple<UUID, Digest>> headRefs(UUID type) const override; + virtual UUID storeHead(UUID type, const Digest & dgst) override; + virtual bool replaceHead(UUID type, UUID id, const Digest & old, const Digest & dgst) override; + virtual int watchHead(UUID type, const function<void(UUID id, const Digest &)> &) override; + virtual void unwatchHead(UUID type, int watchId) override; + + virtual optional<vector<uint8_t>> loadKey(const Digest &) const override; + virtual void storeKey(const Digest &, const vector<uint8_t> &) override; + +private: + unordered_map<Digest, vector<uint8_t>> storage; + unordered_map<UUID, vector<tuple<UUID, Digest>>> heads; + unordered_map<Digest, vector<uint8_t>> keys; + + mutex watcherLock; + int nextWatcherId = 1; + unordered_multimap<UUID, shared_ptr<StorageWatchCallback>> watchers; +}; + +class ChainStorage : public StorageBackend +{ +public: + ChainStorage(shared_ptr<StorageBackend> storage): + ChainStorage(std::move(storage), nullptr) {} + ChainStorage(shared_ptr<StorageBackend> storage, unique_ptr<ChainStorage> parent): + storage(std::move(storage)), parent(std::move(parent)) {} + virtual ~ChainStorage() = default; + + virtual bool contains(const Digest &) const override; + + virtual optional<vector<uint8_t>> loadBytes(const Digest &) const override; + virtual void storeBytes(const Digest &, const vector<uint8_t> &) override; + + virtual optional<Digest> headRef(UUID type, UUID id) const override; + virtual vector<tuple<UUID, Digest>> headRefs(UUID type) const override; + virtual UUID storeHead(UUID type, const Digest & dgst) override; + virtual bool replaceHead(UUID type, UUID id, const Digest & old, const Digest & dgst) override; + virtual int watchHead(UUID type, const function<void(UUID id, const Digest &)> &) override; + virtual void unwatchHead(UUID type, int watchId) override; + + virtual optional<vector<uint8_t>> loadKey(const Digest &) const override; + virtual void storeKey(const Digest &, const vector<uint8_t> &) override; + +private: + shared_ptr<StorageBackend> storage; + unique_ptr<ChainStorage> parent; + + mutex watcherLock; + int nextWatcherId = 1; + unordered_map<int, tuple<int, int>> watchers; +}; + +struct PartialStorage::Priv +{ + shared_ptr<StorageBackend> backend; + + Digest storeBytes(const vector<uint8_t> &) const; + optional<vector<uint8_t>> loadBytes(const Digest & digest) const; + + template<class S> + optional<Digest> copy(const typename S::Ref &, vector<Digest> *) const; + template<class S> + optional<Digest> copy(const ObjectT<S> &, vector<Digest> *) const; + + mutable mutex generationCacheLock {}; + mutable unordered_map<Digest, Generation> generationCache {}; + + mutable mutex rootsCacheLock {}; + mutable unordered_map<Digest, vector<Digest>> rootsCache {}; +}; + +struct PartialRef::Priv +{ + const unique_ptr<PartialStorage> storage; + const Digest digest; +}; + +vector<Stored<Object>> collectStoredObjects(const Stored<Object> &); + +} diff --git a/src/sync.cpp b/src/sync.cpp new file mode 100644 index 0000000..5680da6 --- /dev/null +++ b/src/sync.cpp @@ -0,0 +1,66 @@ +#include <erebos/sync.h> + +#include <erebos/identity.h> +#include <erebos/network.h> + +using namespace erebos; + +using std::scoped_lock; + +static const UUID myUUID("a4f538d0-4e50-4082-8e10-7e3ec2af175d"); + +SyncService::SyncService(Config &&, const Server & s): + server(s) +{ + server.peerList().onUpdate(std::bind(&SyncService::peerWatcher, this, + std::placeholders::_1, std::placeholders::_2)); + watchedLocal = server.localState().lens<vector<Ref>>().watch(std::bind(&SyncService::localStateWatcher, this, + std::placeholders::_1)); +} + +SyncService::~SyncService() = default; + +UUID SyncService::uuid() const +{ + return myUUID; +} + +void SyncService::handle(Context & ctx) +{ + auto pid = ctx.peer().identity(); + if (!pid) + return; + + const auto & powner = pid->finalOwner(); + const Identity owner = ctx.peer().server().identity().finalOwner(); + + if (!powner.sameAs(owner)) + return; + + ctx.local( + ctx.local()->sharedRefAdd(ctx.ref()) + ); +} + +void SyncService::peerWatcher(size_t, const Peer * peer) +{ + if (peer) { + if (auto id = peer->identity()) { + if (id->finalOwner().sameAs(server.identity().finalOwner())) + for (const auto & r : server.localState().get().sharedRefs()) + peer->send(myUUID, r); + } + } +} + +void SyncService::localStateWatcher(const vector<Ref> & refs) +{ + const auto & plist = server.peerList(); + const Identity owner = server.identity().finalOwner(); + + for (size_t i = 0; i < plist.size(); i++) + if (auto id = plist.at(i).identity()) + if (id->finalOwner().sameAs(owner)) + for (const auto & r : refs) + plist.at(i).send(myUUID, r); +} diff --git a/src/time.cpp b/src/time.cpp new file mode 100644 index 0000000..631e0f8 --- /dev/null +++ b/src/time.cpp @@ -0,0 +1,35 @@ +#include <erebos/time.h> + +#include <stdexcept> + +using namespace erebos; + +using std::runtime_error; +using std::string; + +ZonedTime::ZonedTime(string str) +{ + intmax_t t; + unsigned int h, m; + char sign[2]; + if (sscanf(str.c_str(), "%jd %1[+-]%2u%2u", &t, sign, &h, &m) != 4) + throw runtime_error("invalid zoned time"); + + time = std::chrono::system_clock::time_point(std::chrono::seconds(t)); + zone = std::chrono::minutes((sign[0] == '-' ? -1 : 1) * (60 * h + m)); +} + +ZonedTime::operator string() const +{ + char buf[32]; + unsigned int az = std::chrono::abs(zone).count(); + snprintf(buf, sizeof(buf), "%jd %c%02u%02u", + (intmax_t) std::chrono::duration_cast<std::chrono::seconds>(time.time_since_epoch()).count(), + zone < decltype(zone)::zero() ? '-' : '+', az / 60, az % 60); + return string(buf); +} + +ZonedTime ZonedTime::now() +{ + return ZonedTime(std::chrono::system_clock::now()); +} diff --git a/src/uuid.cpp b/src/uuid.cpp new file mode 100644 index 0000000..a53bf27 --- /dev/null +++ b/src/uuid.cpp @@ -0,0 +1,75 @@ +#include <erebos/uuid.h> + +#include <stdexcept> + +#include <openssl/rand.h> + +using namespace erebos; + +using std::nullopt; +using std::optional; +using std::runtime_error; +using std::string; + +static const size_t UUID_STR_LEN = 36; + +static const char * FORMAT_STRING = "%02hhx%02hhx%02hhx%02hhx-%02hhx%02hhx-" + "%02hhx%02hhx-%02hhx%02hhx-%02hhx%02hhx%02hhx%02hhx%02hhx%02hhx"; + +UUID::UUID(const string & str) +{ + if (!fromString(str, *this)) + throw runtime_error("invalid UUID"); +} + +UUID::operator string() const +{ + string str(UUID_STR_LEN, '\0'); + snprintf(str.data(), UUID_STR_LEN + 1, FORMAT_STRING, + uuid[0], uuid[1], uuid[2], uuid[3], uuid[4], uuid[5], uuid[6], uuid[7], + uuid[8], uuid[9], uuid[10], uuid[11], uuid[12], uuid[13], uuid[14], uuid[15]); + return str; +} + +optional<UUID> UUID::fromString(const string & str) +{ + UUID u; + if (fromString(str, u)) + return u; + return nullopt; +} + +bool UUID::fromString(const string & str, UUID & u) +{ + if (str.size() != UUID_STR_LEN) + return false; + + if (sscanf(str.c_str(), FORMAT_STRING, + &u.uuid[0], &u.uuid[1], &u.uuid[2], &u.uuid[3], &u.uuid[4], &u.uuid[5], &u.uuid[6], &u.uuid[7], + &u.uuid[8], &u.uuid[9], &u.uuid[10], &u.uuid[11], &u.uuid[12], &u.uuid[13], &u.uuid[14], &u.uuid[15]) + != 16) + return false; + + return true; +} + +UUID UUID::generate() +{ + UUID u; + if (RAND_bytes(u.uuid.data(), u.uuid.size()) != 1) + throw runtime_error("failed to generate random UUID"); + + u.uuid[6] = (u.uuid[6] & 0x0f) | 0x40; + u.uuid[8] = (u.uuid[8] & 0x3f) | 0x80; + return u; +} + +bool UUID::operator==(const UUID & other) const +{ + return std::equal(std::begin(uuid), std::end(uuid), std::begin(other.uuid)); +} + +bool UUID::operator!=(const UUID & other) const +{ + return !(*this == other); +} diff --git a/attach.test b/test/attach.test index 33a1483..33a1483 100644 --- a/attach.test +++ b/test/attach.test diff --git a/contact.test b/test/contact.test index 438aa1f..438aa1f 100644 --- a/contact.test +++ b/test/contact.test diff --git a/discovery.test b/test/discovery.test index 2aaaf24..2aaaf24 100644 --- a/discovery.test +++ b/test/discovery.test diff --git a/message.test b/test/message.test index 307f11a..307f11a 100644 --- a/message.test +++ b/test/message.test diff --git a/storage.test b/test/storage.test index 9bf468e..9bf468e 100644 --- a/storage.test +++ b/test/storage.test diff --git a/sync.test b/test/sync.test index ea9595d..ea9595d 100644 --- a/sync.test +++ b/test/sync.test |