diff options
Diffstat (limited to 'src/network')
-rw-r--r-- | src/network/channel.cpp | 216 | ||||
-rw-r--r-- | src/network/channel.h | 78 | ||||
-rw-r--r-- | src/network/protocol.cpp | 948 | ||||
-rw-r--r-- | src/network/protocol.h | 277 |
4 files changed, 1519 insertions, 0 deletions
diff --git a/src/network/channel.cpp b/src/network/channel.cpp new file mode 100644 index 0000000..5fff1fa --- /dev/null +++ b/src/network/channel.cpp @@ -0,0 +1,216 @@ +#include "channel.h" + +#include <algorithm> +#include <cstring> +#include <stdexcept> + +#include <endian.h> + +using std::remove_const; +using std::runtime_error; + +using namespace erebos; + +Ref ChannelRequestData::store(const Storage & st) const +{ + vector<Record::Item> items; + + for (const auto & p : peers) + items.emplace_back("peer", p); + items.emplace_back("key", key); + + return st.storeObject(Record(std::move(items))); +} + +ChannelRequestData ChannelRequestData::load(const Ref & ref) +{ + if (auto rec = ref->asRecord()) { + if (auto key = rec->item("key").as<PublicKexKey>()) + return ChannelRequestData { + .peers = rec->items("peer").as<Signed<IdentityData>>(), + .key = *key, + }; + } + + return ChannelRequestData { + .peers = {}, + .key = Stored<PublicKexKey>::load(ref.storage().zref()), + }; +} + +Ref ChannelAcceptData::store(const Storage & st) const +{ + vector<Record::Item> items; + + items.emplace_back("req", request); + items.emplace_back("key", key); + + return st.storeObject(Record(std::move(items))); +} + +ChannelAcceptData ChannelAcceptData::load(const Ref & ref) +{ + if (auto rec = ref->asRecord()) + return ChannelAcceptData { + .request = *rec->item("req").as<ChannelRequest>(), + .key = *rec->item("key").as<PublicKexKey>(), + }; + + return ChannelAcceptData { + .request = Stored<ChannelRequest>::load(ref.storage().zref()), + .key = Stored<PublicKexKey>::load(ref.storage().zref()), + }; +} + +unique_ptr<Channel> ChannelAcceptData::channel() const +{ + if (auto secret = SecretKexKey::load(key)) + return make_unique<Channel>( + request->data->peers, + secret->dh(*request->data->key), + false + ); + + if (auto secret = SecretKexKey::load(request->data->key)) + return make_unique<Channel>( + request->data->peers, + secret->dh(*key), + true + ); + + throw runtime_error("failed to load secret DH key"); +} + + +Stored<ChannelRequest> Channel::generateRequest(const Storage & st, + const Identity & self, const Identity & peer) +{ + auto signKey = SecretKey::load(self.keyMessage()); + if (!signKey) + throw runtime_error("failed to load own message key"); + + return signKey->sign(st.store(ChannelRequestData { + .peers = self.ref()->digest() < peer.ref()->digest() ? + vector<Stored<Signed<IdentityData>>> { + Stored<Signed<IdentityData>>::load(*self.ref()), + Stored<Signed<IdentityData>>::load(*peer.ref()), + } : + vector<Stored<Signed<IdentityData>>> { + Stored<Signed<IdentityData>>::load(*peer.ref()), + Stored<Signed<IdentityData>>::load(*self.ref()), + }, + .key = SecretKexKey::generate(st).pub(), + })); +} + +optional<Stored<ChannelAccept>> Channel::acceptRequest(const Identity & self, + const Identity & peer, const Stored<ChannelRequest> & request) +{ + if (!request->isSignedBy(peer.keyMessage())) + return nullopt; + + auto & peers = request->data->peers; + if (peers.size() != 2 || + std::none_of(peers.begin(), peers.end(), [&self](const auto & x) + { return x.ref().digest() == self.ref()->digest(); }) || + std::none_of(peers.begin(), peers.end(), [&peer](const auto & x) + { return x.ref().digest() == peer.ref()->digest(); })) + return nullopt; + + auto & st = request.ref().storage(); + + auto signKey = SecretKey::load(self.keyMessage()); + if (!signKey) + throw runtime_error("failed to load own message key"); + + return signKey->sign(st.store(ChannelAcceptData { + .request = request, + .key = SecretKexKey::generate(st).pub(), + })); +} + +uint64_t Channel::encrypt(BufferCIt plainBegin, BufferCIt plainEnd, + Buffer & encBuffer, size_t encOffset) +{ + auto plainSize = plainEnd - plainBegin; + encBuffer.resize(encOffset + plainSize + 1 /* counter */ + 16 /* tag */); + array<uint8_t, 12> iv; + + uint64_t count = counterNextOut.fetch_add(1); + uint64_t beCount = htobe64(count); + encBuffer[encOffset] = count % 0x100; + + constexpr size_t nonceFixedSize = std::tuple_size_v<decltype(nonceFixedOur)>; + static_assert(nonceFixedSize + sizeof beCount == iv.size()); + + std::copy_n(nonceFixedOur.begin(), nonceFixedSize, iv.begin()); + std::memcpy(iv.data() + nonceFixedSize, &beCount, sizeof beCount); + + const unique_ptr<EVP_CIPHER_CTX, void(*)(EVP_CIPHER_CTX*)> + ctx(EVP_CIPHER_CTX_new(), EVP_CIPHER_CTX_free); + EVP_EncryptInit_ex(ctx.get(), EVP_chacha20_poly1305(), + nullptr, key.data(), iv.data()); + + int outl = 0; + uint8_t * cur = encBuffer.data() + encOffset + 1; + + if (EVP_EncryptUpdate(ctx.get(), cur, &outl, &*plainBegin, plainSize) != 1) + throw runtime_error("failed to encrypt data"); + cur += outl; + + if (EVP_EncryptFinal(ctx.get(), cur, &outl) != 1) + throw runtime_error("failed to encrypt data"); + cur += outl; + + EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_AEAD_GET_TAG, 16, cur); + return count; +} + +optional<uint64_t> Channel::decrypt(BufferCIt encBegin, BufferCIt encEnd, + Buffer & decBuffer, const size_t decOffset) +{ + auto encSize = encEnd - encBegin; + decBuffer.resize(decOffset + encSize); + array<uint8_t, 12> iv; + + if (encBegin + 1 /* counter */ + 16 /* tag */ > encEnd) + return nullopt; + + uint64_t expectedCount = counterNextIn.load(); + uint64_t guessedCount = expectedCount - 0x80u + ((0x80u + encBegin[0] - expectedCount) % 0x100u); + uint64_t beCount = htobe64(guessedCount); + + constexpr size_t nonceFixedSize = std::tuple_size_v<decltype(nonceFixedPeer)>; + static_assert(nonceFixedSize + sizeof beCount == iv.size()); + + std::copy_n(nonceFixedPeer.begin(), nonceFixedSize, iv.begin()); + std::memcpy(iv.data() + nonceFixedSize, &beCount, sizeof beCount); + + const unique_ptr<EVP_CIPHER_CTX, void(*)(EVP_CIPHER_CTX*)> + ctx(EVP_CIPHER_CTX_new(), EVP_CIPHER_CTX_free); + EVP_DecryptInit_ex(ctx.get(), EVP_chacha20_poly1305(), + nullptr, key.data(), iv.data()); + + int outl = 0; + uint8_t * cur = decBuffer.data() + decOffset; + + if (EVP_DecryptUpdate(ctx.get(), cur, &outl, + &*encBegin + 1, encSize - 1 - 16) != 1) + return nullopt; + cur += outl; + + if (!EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_AEAD_SET_TAG, 16, + (void *) (&*encEnd - 16))) + return nullopt; + + if (EVP_DecryptFinal_ex(ctx.get(), cur, &outl) != 1) + return nullopt; + cur += outl; + + while (expectedCount < guessedCount + 1 && + not counterNextIn.compare_exchange_weak(expectedCount, guessedCount + 1)) + ; // empty loop body + + decBuffer.resize(cur - decBuffer.data()); + return guessedCount; +} diff --git a/src/network/channel.h b/src/network/channel.h new file mode 100644 index 0000000..bba11b3 --- /dev/null +++ b/src/network/channel.h @@ -0,0 +1,78 @@ +#pragma once + +#include <erebos/storage.h> + +#include "../identity.h" + +#include <atomic> +#include <memory> + +namespace erebos { + +using std::array; +using std::atomic; +using std::unique_ptr; + +struct ChannelRequestData +{ + Ref store(const Storage & st) const; + static ChannelRequestData load(const Ref &); + + const vector<Stored<Signed<IdentityData>>> peers; + const Stored<PublicKexKey> key; +}; + +typedef Signed<ChannelRequestData> ChannelRequest; + +struct ChannelAcceptData +{ + Ref store(const Storage & st) const; + static ChannelAcceptData load(const Ref &); + + unique_ptr<class Channel> channel() const; + + const Stored<ChannelRequest> request; + const Stored<PublicKexKey> key; +}; + +typedef Signed<ChannelAcceptData> ChannelAccept; + +class Channel +{ +public: + Channel(const vector<Stored<Signed<IdentityData>>> & peers, + vector<uint8_t> && key, bool ourRequest): + peers(peers), + key(std::move(key)), + nonceFixedOur({ uint8_t(ourRequest ? 1 : 2), 0, 0, 0 }), + nonceFixedPeer({ uint8_t(ourRequest ? 2 : 1), 0, 0, 0 }) + {} + + Channel(const Channel &) = delete; + Channel(Channel &&) = delete; + Channel & operator=(const Channel &) = delete; + Channel & operator=(Channel &&) = delete; + + static Stored<ChannelRequest> generateRequest(const Storage &, + const Identity & self, const Identity & peer); + static optional<Stored<ChannelAccept>> acceptRequest(const Identity & self, + const Identity & peer, const Stored<ChannelRequest> & request); + + using Buffer = vector<uint8_t>; + using BufferCIt = Buffer::const_iterator; + uint64_t encrypt(BufferCIt plainBegin, BufferCIt plainEnd, + Buffer & encBuffer, size_t encOffset); + optional<uint64_t> decrypt(BufferCIt encBegin, BufferCIt encEnd, + Buffer & decBuffer, size_t decOffset); + +private: + const vector<Stored<Signed<IdentityData>>> peers; + const vector<uint8_t> key; + + const array<uint8_t, 4> nonceFixedOur; + const array<uint8_t, 4> nonceFixedPeer; + atomic<uint64_t> counterNextOut = 0; + atomic<uint64_t> counterNextIn = 0; +}; + +} diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp new file mode 100644 index 0000000..89d6a88 --- /dev/null +++ b/src/network/protocol.cpp @@ -0,0 +1,948 @@ +#include "protocol.h" + +#include <sys/socket.h> +#include <unistd.h> + +#include <algorithm> +#include <cstring> +#include <iostream> +#include <mutex> +#include <system_error> + +using std::get_if; +using std::holds_alternative; +using std::move; +using std::nullopt; +using std::runtime_error; +using std::scoped_lock; +using std::to_string; +using std::unique_lock; +using std::visit; + +namespace erebos { + +static constexpr uint8_t maxStreamNumber = 0x3F; + +struct NetworkProtocol::ConnectionPriv +{ + Connection::Id id() const; + + size_t mtu() const; + bool send(const PartialStorage &, Header, + const vector<Object> &, bool secure); + bool send( const StreamData & chunk ); + + NetworkProtocol * protocol; + const sockaddr_in6 peerAddress; + + mutex cmutex {}; + vector<uint8_t> buffer {}; + + optional<Cookie> receivedCookie = nullopt; + bool confirmedCookie = false; + ChannelState channel = monostate(); + vector<vector<uint8_t>> secureOutQueue {}; + + size_t mtuLower = 1000; // TODO: MTU + + vector<uint64_t> toAcknowledge {}; + + vector< shared_ptr< InStream >> inStreams {}; + vector< shared_ptr< OutStream >> outStreams {}; +}; + + +NetworkProtocol::NetworkProtocol(): + sock(-1) +{} + +NetworkProtocol::NetworkProtocol(int s, Identity id): + sock(s), + self(move(id)) +{} + +NetworkProtocol::NetworkProtocol(NetworkProtocol && other): + sock(other.sock), + self(move(other.self)) +{ + other.sock = -1; +} + +NetworkProtocol & NetworkProtocol::operator=(NetworkProtocol && other) +{ + sock = other.sock; + other.sock = -1; + self = move(other.self); + return *this; +} + +NetworkProtocol::~NetworkProtocol() +{ + if (sock >= 0) + close(sock); + + for (auto & c : connections) + c->protocol = nullptr; +} + +NetworkProtocol::PollResult NetworkProtocol::poll() +{ + { + scoped_lock lock(protocolMutex); + + for (const auto & c : connections) { + vector< StreamData > streamChunks; + bool sendAck = false; + { + scoped_lock clock(c->cmutex); + sendAck = not c->toAcknowledge.empty() && + holds_alternative< unique_ptr< Channel >>( c->channel ); + + for (auto & s : c->outStreams) { + unique_lock slock(s->streamMutex); + while (s->hasDataLocked()) + streamChunks.push_back( s->getNextChunkLocked( c->mtu() )); + if( s->closed ){ + // TODO: wait after ack + streamChunks.push_back( { s->id, (uint8_t) s->nextSequence, {} } ); + slock.unlock(); + s.reset(); + } + } + + while( not c->outStreams.empty() && not c->outStreams.back() ) + c->outStreams.pop_back(); + } + if (sendAck) { + auto pst = self->ref()->storage().deriveEphemeralStorage(); + c->send(pst, Header {{}}, {}, true); + } + for (const auto & chunk : streamChunks) { + c->send( chunk ); + } + } + } + + sockaddr_in6 addr; + if (!recvfrom(buffer, addr)) + return ProtocolClosed {}; + + { + scoped_lock lock(protocolMutex); + + for (const auto & c : connections) { + if (memcmp(&c->peerAddress, &addr, sizeof addr) == 0) { + scoped_lock clock(c->cmutex); + buffer.swap(c->buffer); + return ConnectionReadReady { c->id() }; + } + } + + auto pst = self->ref()->storage().deriveEphemeralStorage(); + optional<uint64_t> secure = false; + auto parsed = Connection::parsePacket(buffer, nullptr, pst, secure); + if (const auto * header = get_if< Header >( &parsed )) { + if (auto conn = verifyNewConnection(*header, addr)) + return NewConnection { move(*conn) }; + + if (auto ann = header->lookupFirst<Header::AnnounceSelf>()) + return ReceivedAnnounce { addr, ann->value }; + } + } + + return poll(); +} + +NetworkProtocol::Connection NetworkProtocol::connect(sockaddr_in6 addr) +{ + auto conn = unique_ptr<ConnectionPriv>(new ConnectionPriv { + .protocol = this, + .peerAddress = addr, + }); + + { + scoped_lock lock(protocolMutex); + connections.push_back(conn.get()); + + vector<Header::Item> header { + Header::Initiation { Digest::of(Object(Record())) }, + Header::AnnounceSelf { self->ref()->digest() }, + Header::Version { defaultVersion }, + }; + conn->send(self->ref()->storage(), move(header), {}, false); + } + + return Connection(move(conn)); +} + +void NetworkProtocol::updateIdentity(Identity id) +{ + scoped_lock lock(protocolMutex); + self = move(id); + + vector<Header::Item> hitems; + for (const auto & r : self->extRefs()) + hitems.push_back(Header::AnnounceUpdate { r.digest() }); + for (const auto & r : self->updates()) + hitems.push_back(Header::AnnounceUpdate { r.digest() }); + + Header header(hitems); + + for (const auto & conn : connections) + conn->send(self->ref()->storage(), header, { **self->ref() }, false); +} + +void NetworkProtocol::announceTo(variant<sockaddr_in, sockaddr_in6> addr) +{ + vector<uint8_t> bytes; + { + scoped_lock lock(protocolMutex); + + if (!self) + throw runtime_error("NetworkProtocol::announceTo without self identity"); + + bytes = Header({ + Header::AnnounceSelf { self->ref()->digest() }, + Header::Version { defaultVersion }, + }).toObject(self->ref()->storage()).encode(); + } + + sendto(bytes, addr); +} + +void NetworkProtocol::shutdown() +{ + ::shutdown(sock, SHUT_RDWR); +} + +bool NetworkProtocol::recvfrom(vector<uint8_t> & buffer, sockaddr_in6 & addr) +{ + socklen_t addrlen = sizeof(addr); + buffer.resize(4096); + ssize_t ret = ::recvfrom(sock, buffer.data(), buffer.size(), 0, + (sockaddr *) &addr, &addrlen); + if (ret < 0) + throw std::system_error(errno, std::generic_category()); + if (ret == 0) + return false; + + buffer.resize(ret); + return true; +} + +void NetworkProtocol::sendto(const vector<uint8_t> & buffer, variant<sockaddr_in, sockaddr_in6> vaddr) +{ + visit([&](auto && addr) { + ::sendto(sock, buffer.data(), buffer.size(), 0, + (sockaddr *) &addr, sizeof(addr)); + }, vaddr); +} + +void NetworkProtocol::sendCookie(variant<sockaddr_in, sockaddr_in6> addr) +{ + auto bytes = Header({ + Header::CookieSet { generateCookie(addr) }, + Header::AnnounceSelf { self->ref()->digest() }, + Header::Version { defaultVersion }, + }).toObject(self->ref()->storage()).encode(); + + sendto(bytes, addr); +} + +optional<NetworkProtocol::Connection> NetworkProtocol::verifyNewConnection(const Header & header, sockaddr_in6 addr) +{ + optional<string> version; + for (const auto & h : header.items) { + if (const auto * ptr = get_if<Header::Version>(&h)) { + if (ptr->value == defaultVersion) { + version = ptr->value; + break; + } + } + } + if (!version) + return nullopt; + + if (header.lookupFirst<Header::Initiation>()) { + sendCookie(addr); + } + + else if (auto cookie = header.lookupFirst<Header::CookieEcho>()) { + if (verifyCookie(addr, cookie->value)) { + auto conn = unique_ptr<ConnectionPriv>(new ConnectionPriv { + .protocol = this, + .peerAddress = addr, + }); + + connections.push_back(conn.get()); + buffer.swap(conn->buffer); + return Connection(move(conn)); + } + } + + return nullopt; +} + +NetworkProtocol::Cookie NetworkProtocol::generateCookie(variant<sockaddr_in, sockaddr_in6> vaddr) const +{ + vector<uint8_t> cookie; + visit([&](auto && addr) { + cookie.resize(sizeof addr); + memcpy(cookie.data(), &addr, sizeof addr); + }, vaddr); + return Cookie { cookie }; +} + +bool NetworkProtocol::verifyCookie(variant<sockaddr_in, sockaddr_in6> vaddr, const NetworkProtocol::Cookie & cookie) const +{ + return visit([&](auto && addr) { + if (cookie.value.size() != sizeof addr) + return false; + return memcmp(cookie.value.data(), &addr, sizeof addr) == 0; + }, vaddr); +} + +/******************************************************************************/ +/* Connection */ +/******************************************************************************/ + +using Connection = NetworkProtocol::Connection; + +NetworkProtocol::Connection::Id NetworkProtocol::ConnectionPriv::id() const +{ + return reinterpret_cast<uintptr_t>(this); +} + +NetworkProtocol::Connection::Connection(unique_ptr<ConnectionPriv> p_): + p(move(p_)) +{ +} + +NetworkProtocol::Connection::Connection(Connection && other): + p(move(other.p)) +{ +} + +NetworkProtocol::Connection & NetworkProtocol::Connection::operator=(Connection && other) +{ + close(); + p = move(other.p); + return *this; +} + +NetworkProtocol::Connection::~Connection() +{ + close(); +} + +NetworkProtocol::Connection::Id NetworkProtocol::Connection::id() const +{ + return p->id(); +} + +const sockaddr_in6 & NetworkProtocol::Connection::peerAddress() const +{ + return p->peerAddress; +} + +size_t Connection::mtu() const +{ + return p->mtu(); +} + +size_t NetworkProtocol::ConnectionPriv::mtu() const +{ + if( get_if< unique_ptr< Channel >>( &channel )) + return mtuLower // space for: + - 1 // "encrypted" tag + - 1 // counter + - 1 // channel number + - 1 // channel sequence + - 16 // tag + ; + return mtuLower - 128; // some space for cookie headers +} + +optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const PartialStorage & partStorage) +{ + vector<uint8_t> buf; + + Channel * channel = nullptr; + unique_ptr<Channel> channelPtr; + + { + scoped_lock lock(p->cmutex); + + if (p->buffer.empty()) + return nullopt; + buf.swap(p->buffer); + + if (holds_alternative<unique_ptr<Channel>>(p->channel)) { + channel = std::get<unique_ptr<Channel>>(p->channel).get(); + } else if (holds_alternative<Stored<ChannelAccept>>(p->channel)) { + channelPtr = std::get<Stored<ChannelAccept>>(p->channel)->data->channel(); + channel = channelPtr.get(); + } + } + + optional<uint64_t> secure = false; + auto parsed = parsePacket(buf, channel, partStorage, secure); + if (const auto * header = get_if< Header >( &parsed )) { + scoped_lock lock(p->cmutex); + + if (secure) { + if (header->isAcknowledged()) + p->toAcknowledge.push_back(*secure); + return *header; + } + + if (const auto * cookieEcho = header->lookupFirst<Header::CookieEcho>()) { + if (!p->protocol->verifyCookie(p->peerAddress, cookieEcho->value)) + return nullopt; + + p->confirmedCookie = true; + + if (const auto * cookieSet = header->lookupFirst<Header::CookieSet>()) + p->receivedCookie = cookieSet->value; + + return *header; + } + + if (holds_alternative<monostate>(p->channel)) { + if (const auto * cookieSet = header->lookupFirst<Header::CookieSet>()) { + p->receivedCookie = cookieSet->value; + return *header; + } + } + + if (header->lookupFirst<Header::Initiation>()) { + p->protocol->sendCookie(p->peerAddress); + return nullopt; + } + } + else if( auto * sdata = get_if< StreamData >( &parsed )){ + scoped_lock lock(p->cmutex); + if (secure) + p->toAcknowledge.push_back(*secure); + + InStream * stream = nullptr; + for (const auto & s : p->inStreams) { + if (s->id == sdata->id) { + stream = s.get(); + break; + } + } + if (not stream) { + std::cerr << "unexpected stream number\n"; + return nullopt; + } + + stream->writeChunk( move(*sdata) ); + if( stream->closed ) + p->inStreams.erase( + std::remove_if( p->inStreams.begin(), p->inStreams.end(), + [&]( auto & sptr ) { return sptr.get() == stream; } ), + p->inStreams.end() ); + return nullopt; + } + return nullopt; +} + +variant< monostate, NetworkProtocol::Header, NetworkProtocol::StreamData > +NetworkProtocol::Connection::parsePacket(vector<uint8_t> & buf, + Channel * channel, const PartialStorage & partStorage, + optional<uint64_t> & secure) +{ + vector<uint8_t> decrypted; + auto plainBegin = buf.cbegin(); + auto plainEnd = buf.cbegin(); + + secure = nullopt; + + if ((buf[0] & 0xE0) == 0x80) { + if (not channel) { + std::cerr << "unexpected encrypted packet\n"; + return monostate(); + } + + if ((secure = channel->decrypt(buf.begin() + 1, buf.end(), decrypted, 0))) { + if (decrypted.empty()) { + std::cerr << "empty decrypted content\n"; + } + else if (decrypted[0] == 0x00) { + plainBegin = decrypted.begin() + 1; + plainEnd = decrypted.end(); + } + else if (decrypted[0] <= maxStreamNumber) { + StreamData sdata; + sdata.id = decrypted[0]; + sdata.sequence = decrypted[1]; + sdata.data.resize( decrypted.size() - 2 ); + std::copy(decrypted.begin() + 2, decrypted.end(), sdata.data.begin()); + return sdata; + } + else { + std::cerr << "unexpected stream header\n"; + return monostate(); + } + } + } + else if ((buf[0] & 0xE0) == 0x60) { + plainBegin = buf.begin(); + plainEnd = buf.end(); + } + + if (auto dec = PartialObject::decodePrefix(partStorage, plainBegin, plainEnd)) { + if (auto header = Header::load(std::get<PartialObject>(*dec))) { + auto pos = std::get<1>(*dec); + while (auto cdec = PartialObject::decodePrefix(partStorage, pos, plainEnd)) { + partStorage.storeObject(std::get<PartialObject>(*cdec)); + pos = std::get<1>(*cdec); + } + + return *header; + } + } + + std::cerr << "invalid packet\n"; + return monostate(); +} + +bool NetworkProtocol::Connection::send(const PartialStorage & partStorage, + Header header, + const vector<Object> & objs, bool secure) +{ + return p->send(partStorage, move(header), objs, secure); +} + +bool NetworkProtocol::ConnectionPriv::send(const PartialStorage & partStorage, + Header header, + const vector<Object> & objs, bool secure) +{ + vector<uint8_t> data, part, out; + + { + scoped_lock clock(cmutex); + + Channel * channel = nullptr; + if (auto uptr = get_if<unique_ptr<Channel>>(&this->channel)) + channel = uptr->get(); + + if (channel || secure) { + data.push_back(0x00); + } else { + if (receivedCookie) + header.items.push_back(Header::CookieEcho { receivedCookie->value }); + if (!confirmedCookie) + header.items.push_back(Header::CookieSet { protocol->generateCookie(peerAddress) }); + } + + if (channel) { + for (auto num : toAcknowledge) + header.items.push_back(Header::AcknowledgedSingle { num }); + toAcknowledge.clear(); + } + + if (header.items.empty()) + return false; + + part = header.toObject(partStorage).encode(); + data.insert(data.end(), part.begin(), part.end()); + for (const auto & obj : objs) { + part = obj.encode(); + data.insert(data.end(), part.begin(), part.end()); + } + + if (channel) { + out.push_back(0x80); + channel->encrypt(data.begin(), data.end(), out, 1); + } else if (secure) { + secureOutQueue.emplace_back(move(data)); + } else { + out = std::move(data); + } + } + + if (not out.empty()) + protocol->sendto(out, peerAddress); + + return true; +} + +bool NetworkProtocol::Connection::send( const StreamData & chunk ) +{ + return p->send( chunk ); +} + +bool NetworkProtocol::ConnectionPriv::send( const StreamData & chunk ) +{ + vector<uint8_t> data, out; + + { + scoped_lock clock( cmutex ); + + Channel * channel = nullptr; + if (auto uptr = get_if< unique_ptr< Channel >>( &this->channel )) + channel = uptr->get(); + if (not channel) + return false; + + data.push_back( chunk.id ); + data.push_back( static_cast< uint8_t >( chunk.sequence )); + data.insert( data.end(), chunk.data.begin(), chunk.data.end() ); + + out.push_back( 0x80 ); + channel->encrypt( data.begin(), data.end(), out, 1 ); + } + + protocol->sendto( out, peerAddress ); + return true; +} + + +void NetworkProtocol::Connection::close() +{ + if (not p) + return; + + if (p->protocol) { + scoped_lock lock(p->protocol->protocolMutex); + for (auto it = p->protocol->connections.begin(); + it != p->protocol->connections.end(); it++) { + if ((*it) == p.get()) { + p->protocol->connections.erase(it); + break; + } + } + } + + p = nullptr; +} + +shared_ptr< NetworkProtocol::InStream > NetworkProtocol::Connection::openInStream( uint8_t sid ) +{ + scoped_lock lock( p->cmutex ); + for (const auto & s : p->inStreams) + if (s->id == sid) + throw runtime_error("inbound stream " + to_string(sid) + " already open"); + + p->inStreams.emplace_back( new InStream( sid )); + return p->inStreams.back(); +} + +shared_ptr< NetworkProtocol::OutStream > NetworkProtocol::Connection::openOutStream() +{ + scoped_lock lock( p->cmutex ); + + uint8_t sid = 1; + if( not p->outStreams.empty() ){ + if( p->outStreams.back()->id < maxStreamNumber ) + sid = p->outStreams.back()->id + 1; + else + throw runtime_error("no free outbound stream"); + } + + p->outStreams.emplace_back( new OutStream( sid )); + return p->outStreams.back(); +} + +NetworkProtocol::ChannelState & NetworkProtocol::Connection::channel() +{ + return p->channel; +} + +void NetworkProtocol::Connection::trySendOutQueue() +{ + decltype(p->secureOutQueue) queue; + { + scoped_lock clock(p->cmutex); + + if (p->secureOutQueue.empty()) + return; + + if (not holds_alternative<unique_ptr<Channel>>(p->channel)) + return; + + queue.swap(p->secureOutQueue); + } + + vector<uint8_t> out { 0x80 }; + for (const auto & data : queue) { + std::get<unique_ptr<Channel>>(p->channel)->encrypt(data.begin(), data.end(), out, 1); + p->protocol->sendto(out, p->peerAddress); + } +} + + +NetworkProtocol::Stream::Stream(uint8_t id_): + id(id_) +{ + readPtr = readBuffer.begin(); +} + +void NetworkProtocol::Stream::close() +{ + scoped_lock lock( streamMutex ); + closed = true; +} + +bool NetworkProtocol::Stream::hasDataLocked() const +{ + return not writeBuffer.empty() || readPtr < readBuffer.end(); +} + +size_t NetworkProtocol::Stream::writeLocked( const uint8_t * buf, size_t size ) +{ + writeBuffer.insert( writeBuffer.end(), buf, buf + size ); + return size; +} + +size_t NetworkProtocol::Stream::readLocked( uint8_t * buf, size_t size ) +{ + size_t res = 0; + if (readPtr < readBuffer.end()) { + res = std::min( size, static_cast< size_t >( readBuffer.end() - readPtr )); + std::copy_n( readPtr, res, buf ); + readPtr += res; + } + if (res < size && not writeBuffer.empty()) { + std::swap( readBuffer, writeBuffer ); + readPtr = readBuffer.begin(); + writeBuffer.clear(); + return res + readLocked( buf + res, size - res ); + } + return res; +} + +bool NetworkProtocol::InStream::isComplete() const +{ + scoped_lock lock( streamMutex ); + return closed && outOfOrderChunks.empty(); +} + +vector< uint8_t > NetworkProtocol::InStream::readAll() +{ + scoped_lock lock( streamMutex ); + if (readBuffer.empty()) { + vector< uint8_t > res; + std::swap( res, writeBuffer ); + return res; + } + + readBuffer.insert( readBuffer.end(), writeBuffer.begin(), writeBuffer.end() ); + writeBuffer.clear(); + + vector< uint8_t > res; + std::swap( res, readBuffer ); + readPtr = readBuffer.begin(); + return res; +} + +size_t NetworkProtocol::InStream::read( uint8_t * buf, size_t size ) +{ + scoped_lock lock( streamMutex ); + return readLocked( buf, size ); +} + +void NetworkProtocol::InStream::writeChunk( StreamData chunk ) +{ + scoped_lock lock( streamMutex ); + if( tryUseChunkLocked( chunk )) { + auto it = outOfOrderChunks.begin(); + while( it != outOfOrderChunks.end() && tryUseChunkLocked( *it )) + it++; + outOfOrderChunks.erase( outOfOrderChunks.begin(), it ); + } else { + auto it = outOfOrderChunks.begin(); + while( it < outOfOrderChunks.end() && + it->sequence - static_cast< uint8_t >( nextSequence ) + < chunk.sequence - static_cast< uint8_t >( nextSequence )) + it++; + outOfOrderChunks.insert( it, move(chunk) ); + } +} + +bool NetworkProtocol::InStream::tryUseChunkLocked( const StreamData & chunk ) +{ + if( chunk.sequence != static_cast< uint8_t >( nextSequence )) + return false; + + if( chunk.data.empty() ) + closed = true; + else + writeLocked( chunk.data.data(), chunk.data.size() ); + nextSequence++; + return true; +} + +size_t NetworkProtocol::OutStream::write( const uint8_t * buf, size_t size ) +{ + scoped_lock lock( streamMutex ); + return writeLocked( buf, size ); +} + +NetworkProtocol::StreamData NetworkProtocol::OutStream::getNextChunkLocked( size_t size ) +{ + StreamData res; + res.id = id; + res.sequence = nextSequence++, + + res.data.resize( size ); + size = readLocked( res.data.data(), size ); + res.data.resize( size ); + + return res; +} + + +/******************************************************************************/ +/* Header */ +/******************************************************************************/ + +bool operator==(const NetworkProtocol::Header::Item & left, + const NetworkProtocol::Header::Item & right) +{ + if (left.index() != right.index()) + return false; + + return visit([&](auto && arg) { + using T = std::decay_t<decltype(arg)>; + return arg.value == std::get<T>(right).value; + }, left); +} + +optional<NetworkProtocol::Header> NetworkProtocol::Header::load(const PartialRef & ref) +{ + return load(*ref); +} + +optional<NetworkProtocol::Header> NetworkProtocol::Header::load(const PartialObject & obj) +{ + auto rec = obj.asRecord(); + if (!rec) + return nullopt; + + vector<Item> items; + for (const auto & item : rec->items()) { + if (item.name == "ACK") { + if (auto ref = item.asRef()) + items.emplace_back(Acknowledged { ref->digest() }); + else if (auto num = item.asInteger()) + items.emplace_back(AcknowledgedSingle { static_cast<uint64_t>(*num) }); + } else if (item.name == "VER") { + if (auto ver = item.asText()) + items.emplace_back(Version { *ver }); + } else if (item.name == "INI") { + if (auto ref = item.asRef()) + items.emplace_back(Initiation { ref->digest() }); + } else if (item.name == "CKS") { + if (auto cookie = item.asBinary()) + items.emplace_back(CookieSet { *cookie }); + } else if (item.name == "CKE") { + if (auto cookie = item.asBinary()) + items.emplace_back(CookieEcho { *cookie }); + } else if (item.name == "REQ") { + if (auto ref = item.asRef()) + items.emplace_back(DataRequest { ref->digest() }); + } else if (item.name == "RSP") { + if (auto ref = item.asRef()) + items.emplace_back(DataResponse { ref->digest() }); + } else if (item.name == "ANN") { + if (auto ref = item.asRef()) + items.emplace_back(AnnounceSelf { ref->digest() }); + } else if (item.name == "ANU") { + if (auto ref = item.asRef()) + items.emplace_back(AnnounceUpdate { ref->digest() }); + } else if (item.name == "CRQ") { + if (auto ref = item.asRef()) + items.emplace_back(ChannelRequest { ref->digest() }); + } else if (item.name == "CAC") { + if (auto ref = item.asRef()) + items.emplace_back(ChannelAccept { ref->digest() }); + } else if (item.name == "SVT") { + if (auto val = item.asUUID()) + items.emplace_back(ServiceType { *val }); + } else if (item.name == "SVR") { + if (auto ref = item.asRef()) + items.emplace_back(ServiceRef { ref->digest() }); + } else if (item.name == "STO") { + if (auto num = item.asInteger()) + items.emplace_back( StreamOpen{ static_cast< uint8_t >( *num )}); + } + } + + return NetworkProtocol::Header(items); +} + +PartialObject NetworkProtocol::Header::toObject(const PartialStorage & st) const +{ + vector<PartialRecord::Item> ritems; + + for (const auto & item : items) { + if (const auto * ptr = get_if<Acknowledged>(&item)) + ritems.emplace_back("ACK", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<AcknowledgedSingle>(&item)) + ritems.emplace_back("ACK", Record::Item::Integer(ptr->value)); + + else if (const auto * ptr = get_if<Version>(&item)) + ritems.emplace_back("VER", ptr->value); + + else if (const auto * ptr = get_if<Initiation>(&item)) + ritems.emplace_back("INI", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<CookieSet>(&item)) + ritems.emplace_back("CKS", ptr->value.value); + + else if (const auto * ptr = get_if<CookieEcho>(&item)) + ritems.emplace_back("CKE", ptr->value.value); + + else if (const auto * ptr = get_if<DataRequest>(&item)) + ritems.emplace_back("REQ", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<DataResponse>(&item)) + ritems.emplace_back("RSP", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<AnnounceSelf>(&item)) + ritems.emplace_back("ANN", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<AnnounceUpdate>(&item)) + ritems.emplace_back("ANU", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<ChannelRequest>(&item)) + ritems.emplace_back("CRQ", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<ChannelAccept>(&item)) + ritems.emplace_back("CAC", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<ServiceType>(&item)) + ritems.emplace_back("SVT", ptr->value); + + else if (const auto * ptr = get_if<ServiceRef>(&item)) + ritems.emplace_back("SVR", st.ref(ptr->value)); + + else if (const auto * ptr = get_if< StreamOpen >( &item )) + ritems.emplace_back("STO", Record::Item::Integer( ptr->value )); + } + + return PartialObject(PartialRecord(std::move(ritems))); +} + +bool NetworkProtocol::Header::isAcknowledged() const +{ + for (const auto & item : items) { + if (holds_alternative<Acknowledged>(item) + || holds_alternative<AcknowledgedSingle>(item) + || holds_alternative<Version>(item) + || holds_alternative<Initiation>(item) + || holds_alternative<CookieSet>(item) + || holds_alternative<CookieEcho>(item) + ) + continue; + + return true; + } + return false; +} + +} diff --git a/src/network/protocol.h b/src/network/protocol.h new file mode 100644 index 0000000..d32b20b --- /dev/null +++ b/src/network/protocol.h @@ -0,0 +1,277 @@ +#pragma once + +#include "channel.h" + +#include <erebos/storage.h> + +#include <netinet/in.h> + +#include <cstdint> +#include <memory> +#include <mutex> +#include <variant> +#include <vector> +#include <optional> + +namespace erebos { + +using std::mutex; +using std::optional; +using std::unique_ptr; +using std::variant; +using std::vector; + +class NetworkProtocol +{ +public: + NetworkProtocol(); + explicit NetworkProtocol(int sock, Identity self); + NetworkProtocol(const NetworkProtocol &) = delete; + NetworkProtocol(NetworkProtocol &&); + NetworkProtocol & operator=(const NetworkProtocol &) = delete; + NetworkProtocol & operator=(NetworkProtocol &&); + ~NetworkProtocol(); + + static constexpr char defaultVersion[] = "0.1"; + + class Connection; + class Stream; + class InStream; + class OutStream; + + struct Header; + struct StreamData; + + struct ReceivedAnnounce; + struct NewConnection; + struct ConnectionReadReady; + struct ProtocolClosed {}; + + using PollResult = variant< + ReceivedAnnounce, + NewConnection, + ConnectionReadReady, + ProtocolClosed>; + + PollResult poll(); + + struct Cookie { vector<uint8_t> value; }; + + using ChannelState = variant<monostate, + Stored<ChannelRequest>, + shared_ptr<struct WaitingRef>, + Stored<ChannelAccept>, + unique_ptr<Channel>>; + + Connection connect(sockaddr_in6 addr); + + void updateIdentity(Identity self); + void announceTo(variant<sockaddr_in, sockaddr_in6> addr); + + void shutdown(); + +private: + bool recvfrom(vector<uint8_t> & buffer, sockaddr_in6 & addr); + void sendto(const vector<uint8_t> & buffer, variant<sockaddr_in, sockaddr_in6> addr); + + void sendCookie(variant<sockaddr_in, sockaddr_in6> addr); + optional<Connection> verifyNewConnection(const Header & header, sockaddr_in6 addr); + + Cookie generateCookie(variant<sockaddr_in, sockaddr_in6> addr) const; + bool verifyCookie(variant<sockaddr_in, sockaddr_in6> addr, const Cookie & cookie) const; + + int sock; + + mutex protocolMutex; + vector<uint8_t> buffer; + + optional<Identity> self; + + struct ConnectionPriv; + vector<ConnectionPriv *> connections; +}; + +class NetworkProtocol::Connection +{ + friend class NetworkProtocol; + Connection(unique_ptr<ConnectionPriv> p); +public: + Connection(const Connection &) = delete; + Connection(Connection &&); + Connection & operator=(const Connection &) = delete; + Connection & operator=(Connection &&); + ~Connection(); + + using Id = uintptr_t; + Id id() const; + + const sockaddr_in6 & peerAddress() const; + size_t mtu() const; + + optional<Header> receive(const PartialStorage &); + bool send(const PartialStorage &, NetworkProtocol::Header, + const vector<Object> &, bool secure); + bool send( const StreamData & chunk ); + + void close(); + + shared_ptr< InStream > openInStream( uint8_t sid ); + shared_ptr< OutStream > openOutStream(); + + // temporary: + ChannelState & channel(); + void trySendOutQueue(); + +private: + static variant< monostate, Header, StreamData > + parsePacket(vector<uint8_t> & buf, + Channel * channel, const PartialStorage & st, + optional<uint64_t> & secure); + + unique_ptr<ConnectionPriv> p; +}; + +class NetworkProtocol::Stream +{ + friend class NetworkProtocol; + friend class NetworkProtocol::Connection; + +protected: + Stream(uint8_t id_); + +public: + void close(); + +protected: + bool hasDataLocked() const; + + size_t writeLocked( const uint8_t * buf, size_t size ); + size_t readLocked( uint8_t * buf, size_t size ); + +public: + const uint8_t id; + +protected: + bool closed { false }; + vector< uint8_t > writeBuffer; + vector< uint8_t > readBuffer; + vector< uint8_t >::const_iterator readPtr; + mutable mutex streamMutex; +}; + +class NetworkProtocol::InStream : public NetworkProtocol::Stream +{ + friend class NetworkProtocol; + friend class NetworkProtocol::Connection; + +protected: + InStream(uint8_t id): Stream( id ) {} + +public: + bool isComplete() const; + vector< uint8_t > readAll(); + size_t read( uint8_t * buf, size_t size ); + +protected: + void writeChunk( StreamData chunk ); + bool tryUseChunkLocked( const StreamData & chunk ); + +private: + uint64_t nextSequence { 0 }; + vector< StreamData > outOfOrderChunks; +}; + +class NetworkProtocol::OutStream : public NetworkProtocol::Stream +{ + friend class NetworkProtocol; + friend class NetworkProtocol::Connection; + +protected: + OutStream(uint8_t id): Stream( id ) {} + +public: + size_t write( const uint8_t * buf, size_t size ); + +private: + StreamData getNextChunkLocked( size_t size ); + + uint64_t nextSequence { 0 }; +}; + +struct NetworkProtocol::ReceivedAnnounce { sockaddr_in6 addr; Digest digest; }; +struct NetworkProtocol::NewConnection { Connection conn; }; +struct NetworkProtocol::ConnectionReadReady { Connection::Id id; }; + +struct NetworkProtocol::Header +{ + struct Acknowledged { Digest value; }; + struct AcknowledgedSingle { uint64_t value; }; + struct Version { string value; }; + struct Initiation { Digest value; }; + struct CookieSet { Cookie value; }; + struct CookieEcho { Cookie value; }; + struct DataRequest { Digest value; }; + struct DataResponse { Digest value; }; + struct AnnounceSelf { Digest value; }; + struct AnnounceUpdate { Digest value; }; + struct ChannelRequest { Digest value; }; + struct ChannelAccept { Digest value; }; + struct ServiceType { UUID value; }; + struct ServiceRef { Digest value; }; + struct StreamOpen { uint8_t value; }; + + using Item = variant< + Acknowledged, + AcknowledgedSingle, + Version, + Initiation, + CookieSet, + CookieEcho, + DataRequest, + DataResponse, + AnnounceSelf, + AnnounceUpdate, + ChannelRequest, + ChannelAccept, + ServiceType, + ServiceRef, + StreamOpen>; + + static constexpr size_t itemSize = 78; /* estimate for size of ref-containing headers */ + + Header(const vector<Item> & items): items(items) {} + static optional<Header> load(const PartialRef &); + static optional<Header> load(const PartialObject &); + PartialObject toObject(const PartialStorage &) const; + + template<class T> const T * lookupFirst() const; + bool isAcknowledged() const; + + vector<Item> items; +}; + +struct NetworkProtocol::StreamData +{ + uint8_t id; + uint8_t sequence; + vector< uint8_t > data; +}; + +template<class T> +const T * NetworkProtocol::Header::lookupFirst() const +{ + for (const auto & h : items) + if (auto ptr = std::get_if<T>(&h)) + return ptr; + return nullptr; +} + +bool operator==(const NetworkProtocol::Header::Item &, const NetworkProtocol::Header::Item &); +inline bool operator!=(const NetworkProtocol::Header::Item & left, + const NetworkProtocol::Header::Item & right) +{ return not (left == right); } + +inline bool operator==(const NetworkProtocol::Cookie & left, const NetworkProtocol::Cookie & right) +{ return left.value == right.value; } + +} |