diff options
Diffstat (limited to 'src/network')
-rw-r--r-- | src/network/channel.cpp | 216 | ||||
-rw-r--r-- | src/network/channel.h | 78 | ||||
-rw-r--r-- | src/network/protocol.cpp | 677 | ||||
-rw-r--r-- | src/network/protocol.h | 213 |
4 files changed, 1184 insertions, 0 deletions
diff --git a/src/network/channel.cpp b/src/network/channel.cpp new file mode 100644 index 0000000..5fff1fa --- /dev/null +++ b/src/network/channel.cpp @@ -0,0 +1,216 @@ +#include "channel.h" + +#include <algorithm> +#include <cstring> +#include <stdexcept> + +#include <endian.h> + +using std::remove_const; +using std::runtime_error; + +using namespace erebos; + +Ref ChannelRequestData::store(const Storage & st) const +{ + vector<Record::Item> items; + + for (const auto & p : peers) + items.emplace_back("peer", p); + items.emplace_back("key", key); + + return st.storeObject(Record(std::move(items))); +} + +ChannelRequestData ChannelRequestData::load(const Ref & ref) +{ + if (auto rec = ref->asRecord()) { + if (auto key = rec->item("key").as<PublicKexKey>()) + return ChannelRequestData { + .peers = rec->items("peer").as<Signed<IdentityData>>(), + .key = *key, + }; + } + + return ChannelRequestData { + .peers = {}, + .key = Stored<PublicKexKey>::load(ref.storage().zref()), + }; +} + +Ref ChannelAcceptData::store(const Storage & st) const +{ + vector<Record::Item> items; + + items.emplace_back("req", request); + items.emplace_back("key", key); + + return st.storeObject(Record(std::move(items))); +} + +ChannelAcceptData ChannelAcceptData::load(const Ref & ref) +{ + if (auto rec = ref->asRecord()) + return ChannelAcceptData { + .request = *rec->item("req").as<ChannelRequest>(), + .key = *rec->item("key").as<PublicKexKey>(), + }; + + return ChannelAcceptData { + .request = Stored<ChannelRequest>::load(ref.storage().zref()), + .key = Stored<PublicKexKey>::load(ref.storage().zref()), + }; +} + +unique_ptr<Channel> ChannelAcceptData::channel() const +{ + if (auto secret = SecretKexKey::load(key)) + return make_unique<Channel>( + request->data->peers, + secret->dh(*request->data->key), + false + ); + + if (auto secret = SecretKexKey::load(request->data->key)) + return make_unique<Channel>( + request->data->peers, + secret->dh(*key), + true + ); + + throw runtime_error("failed to load secret DH key"); +} + + +Stored<ChannelRequest> Channel::generateRequest(const Storage & st, + const Identity & self, const Identity & peer) +{ + auto signKey = SecretKey::load(self.keyMessage()); + if (!signKey) + throw runtime_error("failed to load own message key"); + + return signKey->sign(st.store(ChannelRequestData { + .peers = self.ref()->digest() < peer.ref()->digest() ? + vector<Stored<Signed<IdentityData>>> { + Stored<Signed<IdentityData>>::load(*self.ref()), + Stored<Signed<IdentityData>>::load(*peer.ref()), + } : + vector<Stored<Signed<IdentityData>>> { + Stored<Signed<IdentityData>>::load(*peer.ref()), + Stored<Signed<IdentityData>>::load(*self.ref()), + }, + .key = SecretKexKey::generate(st).pub(), + })); +} + +optional<Stored<ChannelAccept>> Channel::acceptRequest(const Identity & self, + const Identity & peer, const Stored<ChannelRequest> & request) +{ + if (!request->isSignedBy(peer.keyMessage())) + return nullopt; + + auto & peers = request->data->peers; + if (peers.size() != 2 || + std::none_of(peers.begin(), peers.end(), [&self](const auto & x) + { return x.ref().digest() == self.ref()->digest(); }) || + std::none_of(peers.begin(), peers.end(), [&peer](const auto & x) + { return x.ref().digest() == peer.ref()->digest(); })) + return nullopt; + + auto & st = request.ref().storage(); + + auto signKey = SecretKey::load(self.keyMessage()); + if (!signKey) + throw runtime_error("failed to load own message key"); + + return signKey->sign(st.store(ChannelAcceptData { + .request = request, + .key = SecretKexKey::generate(st).pub(), + })); +} + +uint64_t Channel::encrypt(BufferCIt plainBegin, BufferCIt plainEnd, + Buffer & encBuffer, size_t encOffset) +{ + auto plainSize = plainEnd - plainBegin; + encBuffer.resize(encOffset + plainSize + 1 /* counter */ + 16 /* tag */); + array<uint8_t, 12> iv; + + uint64_t count = counterNextOut.fetch_add(1); + uint64_t beCount = htobe64(count); + encBuffer[encOffset] = count % 0x100; + + constexpr size_t nonceFixedSize = std::tuple_size_v<decltype(nonceFixedOur)>; + static_assert(nonceFixedSize + sizeof beCount == iv.size()); + + std::copy_n(nonceFixedOur.begin(), nonceFixedSize, iv.begin()); + std::memcpy(iv.data() + nonceFixedSize, &beCount, sizeof beCount); + + const unique_ptr<EVP_CIPHER_CTX, void(*)(EVP_CIPHER_CTX*)> + ctx(EVP_CIPHER_CTX_new(), EVP_CIPHER_CTX_free); + EVP_EncryptInit_ex(ctx.get(), EVP_chacha20_poly1305(), + nullptr, key.data(), iv.data()); + + int outl = 0; + uint8_t * cur = encBuffer.data() + encOffset + 1; + + if (EVP_EncryptUpdate(ctx.get(), cur, &outl, &*plainBegin, plainSize) != 1) + throw runtime_error("failed to encrypt data"); + cur += outl; + + if (EVP_EncryptFinal(ctx.get(), cur, &outl) != 1) + throw runtime_error("failed to encrypt data"); + cur += outl; + + EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_AEAD_GET_TAG, 16, cur); + return count; +} + +optional<uint64_t> Channel::decrypt(BufferCIt encBegin, BufferCIt encEnd, + Buffer & decBuffer, const size_t decOffset) +{ + auto encSize = encEnd - encBegin; + decBuffer.resize(decOffset + encSize); + array<uint8_t, 12> iv; + + if (encBegin + 1 /* counter */ + 16 /* tag */ > encEnd) + return nullopt; + + uint64_t expectedCount = counterNextIn.load(); + uint64_t guessedCount = expectedCount - 0x80u + ((0x80u + encBegin[0] - expectedCount) % 0x100u); + uint64_t beCount = htobe64(guessedCount); + + constexpr size_t nonceFixedSize = std::tuple_size_v<decltype(nonceFixedPeer)>; + static_assert(nonceFixedSize + sizeof beCount == iv.size()); + + std::copy_n(nonceFixedPeer.begin(), nonceFixedSize, iv.begin()); + std::memcpy(iv.data() + nonceFixedSize, &beCount, sizeof beCount); + + const unique_ptr<EVP_CIPHER_CTX, void(*)(EVP_CIPHER_CTX*)> + ctx(EVP_CIPHER_CTX_new(), EVP_CIPHER_CTX_free); + EVP_DecryptInit_ex(ctx.get(), EVP_chacha20_poly1305(), + nullptr, key.data(), iv.data()); + + int outl = 0; + uint8_t * cur = decBuffer.data() + decOffset; + + if (EVP_DecryptUpdate(ctx.get(), cur, &outl, + &*encBegin + 1, encSize - 1 - 16) != 1) + return nullopt; + cur += outl; + + if (!EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_AEAD_SET_TAG, 16, + (void *) (&*encEnd - 16))) + return nullopt; + + if (EVP_DecryptFinal_ex(ctx.get(), cur, &outl) != 1) + return nullopt; + cur += outl; + + while (expectedCount < guessedCount + 1 && + not counterNextIn.compare_exchange_weak(expectedCount, guessedCount + 1)) + ; // empty loop body + + decBuffer.resize(cur - decBuffer.data()); + return guessedCount; +} diff --git a/src/network/channel.h b/src/network/channel.h new file mode 100644 index 0000000..bba11b3 --- /dev/null +++ b/src/network/channel.h @@ -0,0 +1,78 @@ +#pragma once + +#include <erebos/storage.h> + +#include "../identity.h" + +#include <atomic> +#include <memory> + +namespace erebos { + +using std::array; +using std::atomic; +using std::unique_ptr; + +struct ChannelRequestData +{ + Ref store(const Storage & st) const; + static ChannelRequestData load(const Ref &); + + const vector<Stored<Signed<IdentityData>>> peers; + const Stored<PublicKexKey> key; +}; + +typedef Signed<ChannelRequestData> ChannelRequest; + +struct ChannelAcceptData +{ + Ref store(const Storage & st) const; + static ChannelAcceptData load(const Ref &); + + unique_ptr<class Channel> channel() const; + + const Stored<ChannelRequest> request; + const Stored<PublicKexKey> key; +}; + +typedef Signed<ChannelAcceptData> ChannelAccept; + +class Channel +{ +public: + Channel(const vector<Stored<Signed<IdentityData>>> & peers, + vector<uint8_t> && key, bool ourRequest): + peers(peers), + key(std::move(key)), + nonceFixedOur({ uint8_t(ourRequest ? 1 : 2), 0, 0, 0 }), + nonceFixedPeer({ uint8_t(ourRequest ? 2 : 1), 0, 0, 0 }) + {} + + Channel(const Channel &) = delete; + Channel(Channel &&) = delete; + Channel & operator=(const Channel &) = delete; + Channel & operator=(Channel &&) = delete; + + static Stored<ChannelRequest> generateRequest(const Storage &, + const Identity & self, const Identity & peer); + static optional<Stored<ChannelAccept>> acceptRequest(const Identity & self, + const Identity & peer, const Stored<ChannelRequest> & request); + + using Buffer = vector<uint8_t>; + using BufferCIt = Buffer::const_iterator; + uint64_t encrypt(BufferCIt plainBegin, BufferCIt plainEnd, + Buffer & encBuffer, size_t encOffset); + optional<uint64_t> decrypt(BufferCIt encBegin, BufferCIt encEnd, + Buffer & decBuffer, size_t decOffset); + +private: + const vector<Stored<Signed<IdentityData>>> peers; + const vector<uint8_t> key; + + const array<uint8_t, 4> nonceFixedOur; + const array<uint8_t, 4> nonceFixedPeer; + atomic<uint64_t> counterNextOut = 0; + atomic<uint64_t> counterNextIn = 0; +}; + +} diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp new file mode 100644 index 0000000..b781693 --- /dev/null +++ b/src/network/protocol.cpp @@ -0,0 +1,677 @@ +#include "protocol.h" + +#include <sys/socket.h> +#include <unistd.h> + +#include <algorithm> +#include <cstring> +#include <iostream> +#include <mutex> +#include <system_error> + +using std::get_if; +using std::holds_alternative; +using std::move; +using std::nullopt; +using std::runtime_error; +using std::scoped_lock; +using std::visit; + +namespace erebos { + +struct NetworkProtocol::ConnectionPriv +{ + Connection::Id id() const; + + bool send(const PartialStorage &, Header, + const vector<Object> &, bool secure); + + NetworkProtocol * protocol; + const sockaddr_in6 peerAddress; + + mutex cmutex {}; + vector<uint8_t> buffer {}; + + optional<Cookie> receivedCookie = nullopt; + bool confirmedCookie = false; + ChannelState channel = monostate(); + vector<vector<uint8_t>> secureOutQueue {}; + + vector<uint64_t> toAcknowledge {}; +}; + + +NetworkProtocol::NetworkProtocol(): + sock(-1) +{} + +NetworkProtocol::NetworkProtocol(int s, Identity id): + sock(s), + self(move(id)) +{} + +NetworkProtocol::NetworkProtocol(NetworkProtocol && other): + sock(other.sock), + self(move(other.self)) +{ + other.sock = -1; +} + +NetworkProtocol & NetworkProtocol::operator=(NetworkProtocol && other) +{ + sock = other.sock; + other.sock = -1; + self = move(other.self); + return *this; +} + +NetworkProtocol::~NetworkProtocol() +{ + if (sock >= 0) + close(sock); + + for (auto & c : connections) + c->protocol = nullptr; +} + +NetworkProtocol::PollResult NetworkProtocol::poll() +{ + { + scoped_lock lock(protocolMutex); + + for (const auto & c : connections) { + { + scoped_lock clock(c->cmutex); + if (c->toAcknowledge.empty()) + continue; + + if (not holds_alternative<unique_ptr<Channel>>(c->channel)) + continue; + } + auto pst = self->ref()->storage().deriveEphemeralStorage(); + c->send(pst, Header {{}}, {}, true); + } + } + + sockaddr_in6 addr; + if (!recvfrom(buffer, addr)) + return ProtocolClosed {}; + + { + scoped_lock lock(protocolMutex); + + for (const auto & c : connections) { + if (memcmp(&c->peerAddress, &addr, sizeof addr) == 0) { + scoped_lock clock(c->cmutex); + buffer.swap(c->buffer); + return ConnectionReadReady { c->id() }; + } + } + + auto pst = self->ref()->storage().deriveEphemeralStorage(); + optional<uint64_t> secure = false; + if (auto header = Connection::parsePacket(buffer, nullptr, pst, secure)) { + if (auto conn = verifyNewConnection(*header, addr)) + return NewConnection { move(*conn) }; + + if (auto ann = header->lookupFirst<Header::AnnounceSelf>()) + return ReceivedAnnounce { addr, ann->value }; + } + } + + return poll(); +} + +NetworkProtocol::Connection NetworkProtocol::connect(sockaddr_in6 addr) +{ + auto conn = unique_ptr<ConnectionPriv>(new ConnectionPriv { + .protocol = this, + .peerAddress = addr, + }); + + { + scoped_lock lock(protocolMutex); + connections.push_back(conn.get()); + + vector<Header::Item> header { + Header::Initiation { Digest::of(Object(Record())) }, + Header::AnnounceSelf { self->ref()->digest() }, + Header::Version { defaultVersion }, + }; + conn->send(self->ref()->storage(), move(header), {}, false); + } + + return Connection(move(conn)); +} + +void NetworkProtocol::updateIdentity(Identity id) +{ + scoped_lock lock(protocolMutex); + self = move(id); + + vector<Header::Item> hitems; + for (const auto & r : self->extRefs()) + hitems.push_back(Header::AnnounceUpdate { r.digest() }); + for (const auto & r : self->updates()) + hitems.push_back(Header::AnnounceUpdate { r.digest() }); + + Header header(hitems); + + for (const auto & conn : connections) + conn->send(self->ref()->storage(), header, { **self->ref() }, false); +} + +void NetworkProtocol::announceTo(variant<sockaddr_in, sockaddr_in6> addr) +{ + vector<uint8_t> bytes; + { + scoped_lock lock(protocolMutex); + + if (!self) + throw runtime_error("NetworkProtocol::announceTo without self identity"); + + bytes = Header({ + Header::AnnounceSelf { self->ref()->digest() }, + Header::Version { defaultVersion }, + }).toObject(self->ref()->storage()).encode(); + } + + sendto(bytes, addr); +} + +void NetworkProtocol::shutdown() +{ + ::shutdown(sock, SHUT_RDWR); +} + +bool NetworkProtocol::recvfrom(vector<uint8_t> & buffer, sockaddr_in6 & addr) +{ + socklen_t addrlen = sizeof(addr); + buffer.resize(4096); + ssize_t ret = ::recvfrom(sock, buffer.data(), buffer.size(), 0, + (sockaddr *) &addr, &addrlen); + if (ret < 0) + throw std::system_error(errno, std::generic_category()); + if (ret == 0) + return false; + + buffer.resize(ret); + return true; +} + +void NetworkProtocol::sendto(const vector<uint8_t> & buffer, variant<sockaddr_in, sockaddr_in6> vaddr) +{ + visit([&](auto && addr) { + ::sendto(sock, buffer.data(), buffer.size(), 0, + (sockaddr *) &addr, sizeof(addr)); + }, vaddr); +} + +void NetworkProtocol::sendCookie(variant<sockaddr_in, sockaddr_in6> addr) +{ + auto bytes = Header({ + Header::CookieSet { generateCookie(addr) }, + Header::AnnounceSelf { self->ref()->digest() }, + Header::Version { defaultVersion }, + }).toObject(self->ref()->storage()).encode(); + + sendto(bytes, addr); +} + +optional<NetworkProtocol::Connection> NetworkProtocol::verifyNewConnection(const Header & header, sockaddr_in6 addr) +{ + optional<string> version; + for (const auto & h : header.items) { + if (const auto * ptr = get_if<Header::Version>(&h)) { + if (ptr->value == defaultVersion) { + version = ptr->value; + break; + } + } + } + if (!version) + return nullopt; + + if (header.lookupFirst<Header::Initiation>()) { + sendCookie(addr); + } + + else if (auto cookie = header.lookupFirst<Header::CookieEcho>()) { + if (verifyCookie(addr, cookie->value)) { + auto conn = unique_ptr<ConnectionPriv>(new ConnectionPriv { + .protocol = this, + .peerAddress = addr, + }); + + connections.push_back(conn.get()); + buffer.swap(conn->buffer); + return Connection(move(conn)); + } + } + + return nullopt; +} + +NetworkProtocol::Cookie NetworkProtocol::generateCookie(variant<sockaddr_in, sockaddr_in6> vaddr) const +{ + vector<uint8_t> cookie; + visit([&](auto && addr) { + cookie.resize(sizeof addr); + memcpy(cookie.data(), &addr, sizeof addr); + }, vaddr); + return Cookie { cookie }; +} + +bool NetworkProtocol::verifyCookie(variant<sockaddr_in, sockaddr_in6> vaddr, const NetworkProtocol::Cookie & cookie) const +{ + return visit([&](auto && addr) { + if (cookie.value.size() != sizeof addr) + return false; + return memcmp(cookie.value.data(), &addr, sizeof addr) == 0; + }, vaddr); +} + +/******************************************************************************/ +/* Connection */ +/******************************************************************************/ + +NetworkProtocol::Connection::Id NetworkProtocol::ConnectionPriv::id() const +{ + return reinterpret_cast<uintptr_t>(this); +} + +NetworkProtocol::Connection::Connection(unique_ptr<ConnectionPriv> p_): + p(move(p_)) +{ +} + +NetworkProtocol::Connection::Connection(Connection && other): + p(move(other.p)) +{ +} + +NetworkProtocol::Connection & NetworkProtocol::Connection::operator=(Connection && other) +{ + close(); + p = move(other.p); + return *this; +} + +NetworkProtocol::Connection::~Connection() +{ + close(); +} + +NetworkProtocol::Connection::Id NetworkProtocol::Connection::id() const +{ + return p->id(); +} + +const sockaddr_in6 & NetworkProtocol::Connection::peerAddress() const +{ + return p->peerAddress; +} + +optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const PartialStorage & partStorage) +{ + vector<uint8_t> buf; + + Channel * channel = nullptr; + unique_ptr<Channel> channelPtr; + + { + scoped_lock lock(p->cmutex); + + if (p->buffer.empty()) + return nullopt; + buf.swap(p->buffer); + + if (holds_alternative<unique_ptr<Channel>>(p->channel)) { + channel = std::get<unique_ptr<Channel>>(p->channel).get(); + } else if (holds_alternative<Stored<ChannelAccept>>(p->channel)) { + channelPtr = std::get<Stored<ChannelAccept>>(p->channel)->data->channel(); + channel = channelPtr.get(); + } + } + + optional<uint64_t> secure = false; + if (auto header = parsePacket(buf, channel, partStorage, secure)) { + scoped_lock lock(p->cmutex); + + if (secure) { + if (header->isAcknowledged()) + p->toAcknowledge.push_back(*secure); + return header; + } + + if (const auto * cookieEcho = header->lookupFirst<Header::CookieEcho>()) { + if (!p->protocol->verifyCookie(p->peerAddress, cookieEcho->value)) + return nullopt; + + p->confirmedCookie = true; + + if (const auto * cookieSet = header->lookupFirst<Header::CookieSet>()) + p->receivedCookie = cookieSet->value; + + return header; + } + + if (holds_alternative<monostate>(p->channel)) { + if (const auto * cookieSet = header->lookupFirst<Header::CookieSet>()) { + p->receivedCookie = cookieSet->value; + return header; + } + } + + if (header->lookupFirst<Header::Initiation>()) { + p->protocol->sendCookie(p->peerAddress); + return nullopt; + } + } + return nullopt; +} + +optional<NetworkProtocol::Header> NetworkProtocol::Connection::parsePacket(vector<uint8_t> & buf, + Channel * channel, const PartialStorage & partStorage, + optional<uint64_t> & secure) +{ + vector<uint8_t> decrypted; + auto plainBegin = buf.cbegin(); + auto plainEnd = buf.cbegin(); + + secure = nullopt; + + if ((buf[0] & 0xE0) == 0x80) { + if (not channel) { + std::cerr << "unexpected encrypted packet\n"; + return nullopt; + } + + if ((secure = channel->decrypt(buf.begin() + 1, buf.end(), decrypted, 0))) { + if (decrypted.empty()) { + std::cerr << "empty decrypted content\n"; + } + else if (decrypted[0] == 0x00) { + plainBegin = decrypted.begin() + 1; + plainEnd = decrypted.end(); + } + else { + std::cerr << "streams not implemented\n"; + return nullopt; + } + } + } + else if ((buf[0] & 0xE0) == 0x60) { + plainBegin = buf.begin(); + plainEnd = buf.end(); + } + + if (auto dec = PartialObject::decodePrefix(partStorage, plainBegin, plainEnd)) { + if (auto header = Header::load(std::get<PartialObject>(*dec))) { + auto pos = std::get<1>(*dec); + while (auto cdec = PartialObject::decodePrefix(partStorage, pos, plainEnd)) { + partStorage.storeObject(std::get<PartialObject>(*cdec)); + pos = std::get<1>(*cdec); + } + + return header; + } + } + + std::cerr << "invalid packet\n"; + return nullopt; +} + +bool NetworkProtocol::Connection::send(const PartialStorage & partStorage, + Header header, + const vector<Object> & objs, bool secure) +{ + return p->send(partStorage, move(header), objs, secure); +} + +bool NetworkProtocol::ConnectionPriv::send(const PartialStorage & partStorage, + Header header, + const vector<Object> & objs, bool secure) +{ + vector<uint8_t> data, part, out; + + { + scoped_lock clock(cmutex); + + Channel * channel = nullptr; + if (auto uptr = get_if<unique_ptr<Channel>>(&this->channel)) + channel = uptr->get(); + + if (channel || secure) { + data.push_back(0x00); + } else { + if (receivedCookie) + header.items.push_back(Header::CookieEcho { receivedCookie->value }); + if (!confirmedCookie) + header.items.push_back(Header::CookieSet { protocol->generateCookie(peerAddress) }); + } + + if (channel) { + for (auto num : toAcknowledge) + header.items.push_back(Header::AcknowledgedSingle { num }); + toAcknowledge.clear(); + } + + if (header.items.empty()) + return false; + + part = header.toObject(partStorage).encode(); + data.insert(data.end(), part.begin(), part.end()); + for (const auto & obj : objs) { + part = obj.encode(); + data.insert(data.end(), part.begin(), part.end()); + } + + if (channel) { + out.push_back(0x80); + channel->encrypt(data.begin(), data.end(), out, 1); + } else if (secure) { + secureOutQueue.emplace_back(move(data)); + } else { + out = std::move(data); + } + } + + if (not out.empty()) + protocol->sendto(out, peerAddress); + + return true; +} + +void NetworkProtocol::Connection::close() +{ + if (not p) + return; + + if (p->protocol) { + scoped_lock lock(p->protocol->protocolMutex); + for (auto it = p->protocol->connections.begin(); + it != p->protocol->connections.end(); it++) { + if ((*it) == p.get()) { + p->protocol->connections.erase(it); + break; + } + } + } + + p = nullptr; +} + +NetworkProtocol::ChannelState & NetworkProtocol::Connection::channel() +{ + return p->channel; +} + +void NetworkProtocol::Connection::trySendOutQueue() +{ + decltype(p->secureOutQueue) queue; + { + scoped_lock clock(p->cmutex); + + if (p->secureOutQueue.empty()) + return; + + if (not holds_alternative<unique_ptr<Channel>>(p->channel)) + return; + + queue.swap(p->secureOutQueue); + } + + vector<uint8_t> out { 0x80 }; + for (const auto & data : queue) { + std::get<unique_ptr<Channel>>(p->channel)->encrypt(data.begin(), data.end(), out, 1); + p->protocol->sendto(out, p->peerAddress); + } +} + + +/******************************************************************************/ +/* Header */ +/******************************************************************************/ + +bool operator==(const NetworkProtocol::Header::Item & left, + const NetworkProtocol::Header::Item & right) +{ + if (left.index() != right.index()) + return false; + + return visit([&](auto && arg) { + using T = std::decay_t<decltype(arg)>; + return arg.value == std::get<T>(right).value; + }, left); +} + +optional<NetworkProtocol::Header> NetworkProtocol::Header::load(const PartialRef & ref) +{ + return load(*ref); +} + +optional<NetworkProtocol::Header> NetworkProtocol::Header::load(const PartialObject & obj) +{ + auto rec = obj.asRecord(); + if (!rec) + return nullopt; + + vector<Item> items; + for (const auto & item : rec->items()) { + if (item.name == "ACK") { + if (auto ref = item.asRef()) + items.emplace_back(Acknowledged { ref->digest() }); + else if (auto num = item.asInteger()) + items.emplace_back(AcknowledgedSingle { static_cast<uint64_t>(*num) }); + } else if (item.name == "VER") { + if (auto ver = item.asText()) + items.emplace_back(Version { *ver }); + } else if (item.name == "INI") { + if (auto ref = item.asRef()) + items.emplace_back(Initiation { ref->digest() }); + } else if (item.name == "CKS") { + if (auto cookie = item.asBinary()) + items.emplace_back(CookieSet { *cookie }); + } else if (item.name == "CKE") { + if (auto cookie = item.asBinary()) + items.emplace_back(CookieEcho { *cookie }); + } else if (item.name == "REQ") { + if (auto ref = item.asRef()) + items.emplace_back(DataRequest { ref->digest() }); + } else if (item.name == "RSP") { + if (auto ref = item.asRef()) + items.emplace_back(DataResponse { ref->digest() }); + } else if (item.name == "ANN") { + if (auto ref = item.asRef()) + items.emplace_back(AnnounceSelf { ref->digest() }); + } else if (item.name == "ANU") { + if (auto ref = item.asRef()) + items.emplace_back(AnnounceUpdate { ref->digest() }); + } else if (item.name == "CRQ") { + if (auto ref = item.asRef()) + items.emplace_back(ChannelRequest { ref->digest() }); + } else if (item.name == "CAC") { + if (auto ref = item.asRef()) + items.emplace_back(ChannelAccept { ref->digest() }); + } else if (item.name == "SVT") { + if (auto val = item.asUUID()) + items.emplace_back(ServiceType { *val }); + } else if (item.name == "SVR") { + if (auto ref = item.asRef()) + items.emplace_back(ServiceRef { ref->digest() }); + } + } + + return NetworkProtocol::Header(items); +} + +PartialObject NetworkProtocol::Header::toObject(const PartialStorage & st) const +{ + vector<PartialRecord::Item> ritems; + + for (const auto & item : items) { + if (const auto * ptr = get_if<Acknowledged>(&item)) + ritems.emplace_back("ACK", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<AcknowledgedSingle>(&item)) + ritems.emplace_back("ACK", Record::Item::Integer(ptr->value)); + + else if (const auto * ptr = get_if<Version>(&item)) + ritems.emplace_back("VER", ptr->value); + + else if (const auto * ptr = get_if<Initiation>(&item)) + ritems.emplace_back("INI", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<CookieSet>(&item)) + ritems.emplace_back("CKS", ptr->value.value); + + else if (const auto * ptr = get_if<CookieEcho>(&item)) + ritems.emplace_back("CKE", ptr->value.value); + + else if (const auto * ptr = get_if<DataRequest>(&item)) + ritems.emplace_back("REQ", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<DataResponse>(&item)) + ritems.emplace_back("RSP", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<AnnounceSelf>(&item)) + ritems.emplace_back("ANN", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<AnnounceUpdate>(&item)) + ritems.emplace_back("ANU", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<ChannelRequest>(&item)) + ritems.emplace_back("CRQ", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<ChannelAccept>(&item)) + ritems.emplace_back("CAC", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<ServiceType>(&item)) + ritems.emplace_back("SVT", ptr->value); + + else if (const auto * ptr = get_if<ServiceRef>(&item)) + ritems.emplace_back("SVR", st.ref(ptr->value)); + } + + return PartialObject(PartialRecord(std::move(ritems))); +} + +bool NetworkProtocol::Header::isAcknowledged() const +{ + for (const auto & item : items) { + if (holds_alternative<Acknowledged>(item) + || holds_alternative<AcknowledgedSingle>(item) + || holds_alternative<Version>(item) + || holds_alternative<Initiation>(item) + || holds_alternative<CookieSet>(item) + || holds_alternative<CookieEcho>(item) + ) + continue; + + return true; + } + return false; +} + +} diff --git a/src/network/protocol.h b/src/network/protocol.h new file mode 100644 index 0000000..ba40744 --- /dev/null +++ b/src/network/protocol.h @@ -0,0 +1,213 @@ +#pragma once + +#include "channel.h" + +#include <erebos/storage.h> + +#include <netinet/in.h> + +#include <cstdint> +#include <memory> +#include <mutex> +#include <variant> +#include <vector> +#include <optional> + +namespace erebos { + +using std::mutex; +using std::optional; +using std::unique_ptr; +using std::variant; +using std::vector; + +class NetworkProtocol +{ +public: + NetworkProtocol(); + explicit NetworkProtocol(int sock, Identity self); + NetworkProtocol(const NetworkProtocol &) = delete; + NetworkProtocol(NetworkProtocol &&); + NetworkProtocol & operator=(const NetworkProtocol &) = delete; + NetworkProtocol & operator=(NetworkProtocol &&); + ~NetworkProtocol(); + + static constexpr char defaultVersion[] = "0.1"; + + class Connection; + + struct Header; + + struct ReceivedAnnounce; + struct NewConnection; + struct ConnectionReadReady; + struct ProtocolClosed {}; + + using PollResult = variant< + ReceivedAnnounce, + NewConnection, + ConnectionReadReady, + ProtocolClosed>; + + PollResult poll(); + + struct Cookie { vector<uint8_t> value; }; + + using ChannelState = variant<monostate, + Stored<ChannelRequest>, + shared_ptr<struct WaitingRef>, + Stored<ChannelAccept>, + unique_ptr<Channel>>; + + Connection connect(sockaddr_in6 addr); + + void updateIdentity(Identity self); + void announceTo(variant<sockaddr_in, sockaddr_in6> addr); + + void shutdown(); + +private: + bool recvfrom(vector<uint8_t> & buffer, sockaddr_in6 & addr); + void sendto(const vector<uint8_t> & buffer, variant<sockaddr_in, sockaddr_in6> addr); + + void sendCookie(variant<sockaddr_in, sockaddr_in6> addr); + optional<Connection> verifyNewConnection(const Header & header, sockaddr_in6 addr); + + Cookie generateCookie(variant<sockaddr_in, sockaddr_in6> addr) const; + bool verifyCookie(variant<sockaddr_in, sockaddr_in6> addr, const Cookie & cookie) const; + + int sock; + + mutex protocolMutex; + vector<uint8_t> buffer; + + optional<Identity> self; + + struct ConnectionPriv; + vector<ConnectionPriv *> connections; +}; + +class NetworkProtocol::Connection +{ + friend class NetworkProtocol; + Connection(unique_ptr<ConnectionPriv> p); +public: + Connection(const Connection &) = delete; + Connection(Connection &&); + Connection & operator=(const Connection &) = delete; + Connection & operator=(Connection &&); + ~Connection(); + + using Id = uintptr_t; + Id id() const; + + const sockaddr_in6 & peerAddress() const; + + optional<Header> receive(const PartialStorage &); + bool send(const PartialStorage &, NetworkProtocol::Header, + const vector<Object> &, bool secure); + + void close(); + + // temporary: + ChannelState & channel(); + void trySendOutQueue(); + +private: + static optional<Header> parsePacket(vector<uint8_t> & buf, + Channel * channel, const PartialStorage & st, + optional<uint64_t> & secure); + + unique_ptr<ConnectionPriv> p; +}; + +struct NetworkProtocol::ReceivedAnnounce { sockaddr_in6 addr; Digest digest; }; +struct NetworkProtocol::NewConnection { Connection conn; }; +struct NetworkProtocol::ConnectionReadReady { Connection::Id id; }; + +struct NetworkProtocol::Header +{ + struct Acknowledged { Digest value; }; + struct AcknowledgedSingle { uint64_t value; }; + struct Version { string value; }; + struct Initiation { Digest value; }; + struct CookieSet { Cookie value; }; + struct CookieEcho { Cookie value; }; + struct DataRequest { Digest value; }; + struct DataResponse { Digest value; }; + struct AnnounceSelf { Digest value; }; + struct AnnounceUpdate { Digest value; }; + struct ChannelRequest { Digest value; }; + struct ChannelAccept { Digest value; }; + struct ServiceType { UUID value; }; + struct ServiceRef { Digest value; }; + + using Item = variant< + Acknowledged, + AcknowledgedSingle, + Version, + Initiation, + CookieSet, + CookieEcho, + DataRequest, + DataResponse, + AnnounceSelf, + AnnounceUpdate, + ChannelRequest, + ChannelAccept, + ServiceType, + ServiceRef>; + + Header(const vector<Item> & items): items(items) {} + static optional<Header> load(const PartialRef &); + static optional<Header> load(const PartialObject &); + PartialObject toObject(const PartialStorage &) const; + + template<class T> const T * lookupFirst() const; + bool isAcknowledged() const; + + vector<Item> items; +}; + +template<class T> +const T * NetworkProtocol::Header::lookupFirst() const +{ + for (const auto & h : items) + if (auto ptr = std::get_if<T>(&h)) + return ptr; + return nullptr; +} + +bool operator==(const NetworkProtocol::Header::Item &, const NetworkProtocol::Header::Item &); +inline bool operator!=(const NetworkProtocol::Header::Item & left, + const NetworkProtocol::Header::Item & right) +{ return not (left == right); } + +inline bool operator==(const NetworkProtocol::Cookie & left, const NetworkProtocol::Cookie & right) +{ return left.value == right.value; } + +class ReplyBuilder +{ +public: + void header(NetworkProtocol::Header::Item &&); + void body(const Ref &); + + const vector<NetworkProtocol::Header::Item> & header() const { return mheader; } + vector<Object> body() const; + +private: + vector<NetworkProtocol::Header::Item> mheader; + vector<Ref> mbody; +}; + +struct WaitingRef +{ + const Storage storage; + const PartialRef ref; + vector<Digest> missing; + + optional<Ref> check(); + optional<Ref> check(ReplyBuilder &); +}; + +} |