#pragma once /* NOTE(review): this header appears mangled: the #include targets and most angle-bracketed template argument lists are missing, and the original line breaks were collapsed, so as written this entire line is part of the #pragma once directive. It cannot compile in this form; TODO restore the text from the upstream source instead of guessing the missing arguments. */ #include "channel.h" #include #include #include #include #include #include #include #include namespace erebos { using std::mutex; using std::optional; using std::unique_ptr; using std::variant; using std::vector; /* NetworkProtocol: holds a socket descriptor (sock), an optional local identity (self) and the list of connections; poll() reports the next event as a PollResult. Copying is deleted; moves are declared. */ class NetworkProtocol { public: NetworkProtocol(); explicit NetworkProtocol(int sock, Identity self); NetworkProtocol(const NetworkProtocol &) = delete; NetworkProtocol(NetworkProtocol &&); NetworkProtocol & operator=(const NetworkProtocol &) = delete; NetworkProtocol & operator=(NetworkProtocol &&); ~NetworkProtocol(); /* protocol version string; presumably what is carried in Header::Version - TODO confirm */ static constexpr char defaultVersion[] = "0.1"; class Connection; class Stream; class InStream; class OutStream; struct Header; struct StreamData; struct ReceivedAnnounce; struct NewConnection; struct ConnectionReadReady; struct ProtocolClosed {}; /* poll() returns exactly one of these events */ using PollResult = variant< ReceivedAnnounce, NewConnection, ConnectionReadReady, ProtocolClosed>; PollResult poll(); /* element type of value was stripped; presumably raw bytes - TODO confirm */ struct Cookie { vector value; }; /* NOTE(review): the alternatives of this variant were partly stripped; restore from upstream */ using ChannelState = variant, shared_ptr, Stored, unique_ptr>; Connection connect(sockaddr_in6 addr); void updateIdentity(Identity self); void announceTo(variant addr); void shutdown(); private: bool recvfrom(vector & buffer, sockaddr_in6 & addr); void sendto(const vector & buffer, variant addr); /* connection-setup helpers (cookie send / generate / verify and verifyNewConnection); the Header::CookieSet and CookieEcho items suggest a cookie exchange - TODO confirm in the implementation */ void sendCookie(variant addr); optional verifyNewConnection(const Header & header, sockaddr_in6 addr); Cookie generateCookie(variant addr) const; bool verifyCookie(variant addr, const Cookie & cookie) const; /* protocolMutex presumably guards the state below - TODO confirm */ int sock; mutex protocolMutex; vector buffer; optional self; struct ConnectionPriv; vector connections; }; /* Connection: one peer connection. Its constructor is private and NetworkProtocol is a friend, so only NetworkProtocol creates instances. Copying is deleted; moves are declared. State is held behind the unique_ptr member p. */ class NetworkProtocol::Connection { friend class NetworkProtocol; Connection(unique_ptr p); public: Connection(const Connection &) = delete; Connection(Connection &&); Connection & operator=(const Connection &) = delete; Connection & operator=(Connection &&); ~Connection(); /* integer handle; ConnectionReadReady reports connections by this id */ using Id = uintptr_t; Id id() const; const sockaddr_in6 & peerAddress() const; size_t mtu() const; optional
receive(const PartialStorage &); /* two send overloads: a Header plus a vector argument (presumably the payload) with a secure flag, or a single StreamData chunk */ bool send(const PartialStorage &, NetworkProtocol::Header, const vector &, bool secure); bool send( const StreamData & chunk ); void close(); /* an incoming stream for stream id sid, or an outgoing stream */ shared_ptr< InStream > openInStream( uint8_t sid ); shared_ptr< OutStream > openOutStream(); /* NOTE(review): the line comment that follows was presumably meant to cover only the "temporary:" label; with the line breaks collapsed it comments out the rest of this line, i.e. channel(), trySendOutQueue(), parsePacket(), the member p, the Stream / InStream / OutStream classes and the event structs up to the opening of ConnectionReadReady. */ // temporary: ChannelState & channel(); void trySendOutQueue(); private: /* parsePacket: decodes buf into a Header or a StreamData; presumably monostate when nothing usable was decoded - TODO confirm */ static variant< monostate, Header, StreamData > parsePacket(vector & buf, Channel * channel, const PartialStorage & st, optional & secure); unique_ptr p; }; /* Stream: common base of InStream and OutStream, holding the stream id, a closed flag, read/write buffers and streamMutex; the members whose names end in Locked presumably expect streamMutex to be held by the caller - TODO confirm */ class NetworkProtocol::Stream { friend class NetworkProtocol; friend class NetworkProtocol::Connection; protected: Stream(uint8_t id_); public: void close(); protected: bool hasDataLocked() const; size_t writeLocked( const uint8_t * buf, size_t size ); size_t readLocked( uint8_t * buf, size_t size ); public: const uint8_t id; protected: bool closed { false }; vector< uint8_t > writeBuffer; vector< uint8_t > readBuffer; vector< uint8_t >::const_iterator readPtr; mutable mutex streamMutex; }; /* InStream: receiving side of a stream; chunks arriving ahead of nextSequence are presumably parked in outOfOrderChunks - TODO confirm */ class NetworkProtocol::InStream : public NetworkProtocol::Stream { friend class NetworkProtocol; friend class NetworkProtocol::Connection; protected: InStream(uint8_t id): Stream( id ) {} public: bool isComplete() const; vector< uint8_t > readAll(); size_t read( uint8_t * buf, size_t size ); protected: void writeChunk( StreamData chunk ); bool tryUseChunkLocked( const StreamData & chunk ); private: uint64_t nextSequence { 0 }; vector< StreamData > outOfOrderChunks; }; /* OutStream: sending side of a stream; getNextChunkLocked(size) presumably cuts the next chunk of at most size bytes from the buffered data - TODO confirm */ class NetworkProtocol::OutStream : public NetworkProtocol::Stream { friend class NetworkProtocol; friend class NetworkProtocol::Connection; protected: OutStream(uint8_t id): Stream( id ) {} public: size_t write( const uint8_t * buf, size_t size ); private: StreamData getNextChunkLocked( size_t size ); uint64_t nextSequence { 0 }; }; /* payloads of the PollResult events */ struct NetworkProtocol::ReceivedAnnounce { sockaddr_in6 addr; Digest digest; }; struct NetworkProtocol::NewConnection { Connection conn; }; struct NetworkProtocol::ConnectionReadReady { 
Connection::Id id; }; /* Header: a list of typed items (see Item); load() builds one from a PartialRef or a PartialObject and returns an optional (presumably empty on failure - TODO confirm), toObject() converts it back into a PartialObject for the given storage */ struct NetworkProtocol::Header { /* item kinds; each one wraps a single value */ struct Acknowledged { Digest value; }; struct AcknowledgedSingle { uint64_t value; }; struct Version { string value; }; struct Initiation { Digest value; }; struct CookieSet { Cookie value; }; struct CookieEcho { Cookie value; }; struct DataRequest { Digest value; }; struct DataResponse { Digest value; }; struct AnnounceSelf { Digest value; }; struct AnnounceUpdate { Digest value; }; struct ChannelRequest { Digest value; }; struct ChannelAccept { Digest value; }; struct ServiceType { UUID value; }; struct ServiceRef { Digest value; }; struct StreamOpen { uint8_t value; }; using Item = variant< Acknowledged, AcknowledgedSingle, Version, Initiation, CookieSet, CookieEcho, DataRequest, DataResponse, AnnounceSelf, AnnounceUpdate, ChannelRequest, ChannelAccept, ServiceType, ServiceRef, StreamOpen>; static constexpr size_t itemSize = 78; /* estimate for size of ref-containing headers */ /* NOTE(review): this single-argument constructor is not explicit, so an items vector converts implicitly to a Header */ Header(const vector & items): items(items) {} static optional
load(const PartialRef &); static optional
load(const PartialObject &); PartialObject toObject(const PartialStorage &) const; template const T * lookupFirst() const; bool isAcknowledged() const; vector items; }; /* StreamData: one chunk of stream id, with its sequence number and payload bytes. NOTE(review): sequence is uint8_t while the streams' nextSequence counters are uint64_t; confirm how the implementation compares them and handles wrap-around. */ struct NetworkProtocol::StreamData { uint8_t id; uint8_t sequence; vector< uint8_t > data; }; /* lookupFirst: returns a pointer to the first item holding alternative T, or nullptr when no item does */ template const T * NetworkProtocol::Header::lookupFirst() const { for (const auto & h : items) if (auto ptr = std::get_if(&h)) return ptr; return nullptr; } /* equality on header items is defined out of line; != is its negation */ bool operator==(const NetworkProtocol::Header::Item &, const NetworkProtocol::Header::Item &); inline bool operator!=(const NetworkProtocol::Header::Item & left, const NetworkProtocol::Header::Item & right) { return not (left == right); } /* cookies compare equal exactly when their values are equal */ inline bool operator==(const NetworkProtocol::Cookie & left, const NetworkProtocol::Cookie & right) { return left.value == right.value; } } /* namespace erebos */