From 2ed8103ff1c0fca7372b3c3888f590ba41c525e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roman=20Smr=C5=BE?= Date: Sun, 13 Aug 2023 19:01:48 +0200 Subject: Connection class for network protocol --- src/network.cpp | 87 ++++++++++++++++++++++++-------- src/network.h | 6 ++- src/network/protocol.cpp | 129 +++++++++++++++++++++++++++++++++++++++++++++++ src/network/protocol.h | 55 ++++++++++++++++++++ 4 files changed, 253 insertions(+), 24 deletions(-) diff --git a/src/network.cpp b/src/network.cpp index b5dfd68..786e752 100644 --- a/src/network.cpp +++ b/src/network.cpp @@ -175,7 +175,7 @@ optional Peer::identity() const const sockaddr_in6 & Peer::address() const { if (auto speer = p->speer.lock()) - return speer->addr; + return speer->connection.peerAddress(); throw runtime_error("Server no longer running"); } @@ -373,36 +373,49 @@ void Server::Priv::doListen() for (; !finish; lock.lock()) { lock.unlock(); - sockaddr_in6 paddr; - if (not protocol.recvfrom(buf, paddr)) + Peer * peer = nullptr; + auto res = protocol.poll(); + + if (holds_alternative(res)) break; - if (isSelfAddress(paddr)) + if (holds_alternative(res)) { + auto & conn = get(res).conn; + if (not isSelfAddress(conn.peerAddress())) + peer = &addPeer(move(conn)); + } + + if (holds_alternative(res)) { + peer = findPeer(get(res).id); + } + + if (!peer) continue; - auto & peer = getPeer(paddr); + if (not peer->connection.receive(buf)) + continue; current = &buf; - if (holds_alternative>(peer.channel)) { - if (auto dec = std::get>(peer.channel)->decrypt(buf)) { + if (holds_alternative>(peer->channel)) { + if (auto dec = std::get>(peer->channel)->decrypt(buf)) { decrypted = std::move(*dec); current = &decrypted; } - } else if (holds_alternative>(peer.channel)) { - if (auto dec = std::get>(peer.channel)-> + } else if (holds_alternative>(peer->channel)) { + if (auto dec = std::get>(peer->channel)-> data->channel()->decrypt(buf)) { decrypted = std::move(*dec); current = &decrypted; } } - if (auto dec = PartialObject::decodePrefix(peer.partStorage, + if (auto dec = PartialObject::decodePrefix(peer->partStorage, current->begin(), current->end())) { if (auto header = TransportHeader::load(std::get(*dec))) { auto pos = std::get<1>(*dec); - while (auto cdec = PartialObject::decodePrefix(peer.partStorage, + while (auto cdec = PartialObject::decodePrefix(peer->partStorage, pos, current->end())) { - peer.partStorage.storeObject(std::get(*cdec)); + peer->partStorage.storeObject(std::get(*cdec)); pos = std::get<1>(*cdec); } @@ -411,15 +424,15 @@ void Server::Priv::doListen() scoped_lock hlock(dataMutex); shared_lock slock(selfMutex); - handlePacket(peer, *header, reply); - peer.updateIdentity(reply); - peer.updateChannel(reply); - peer.updateService(reply); + handlePacket(*peer, *header, reply); + peer->updateIdentity(reply); + peer->updateChannel(reply); + peer->updateService(reply); if (!reply.header().empty()) - peer.send(TransportHeader(reply.header()), reply.body(), false); + peer->send(TransportHeader(reply.header()), reply.body(), false); - peer.trySendOutQueue(); + peer->trySendOutQueue(); } } else { std::cerr << "invalid packet\n"; @@ -468,18 +481,48 @@ bool Server::Priv::isSelfAddress(const sockaddr_in6 & paddr) return false; } +Server::Peer * Server::Priv::findPeer(NetworkProtocol::Connection::Id cid) const +{ + scoped_lock lock(dataMutex); + + for (auto & peer : peers) + if (peer->connection.id() == cid) + return peer.get(); + + return nullptr; +} + Server::Peer & Server::Priv::getPeer(const sockaddr_in6 & paddr) { scoped_lock lock(dataMutex); for (auto & peer : peers) - if (memcmp(&peer->addr, &paddr, sizeof paddr) == 0) + if (memcmp(&peer->connection.peerAddress(), &paddr, sizeof paddr) == 0) return *peer; auto st = self.ref()->storage().deriveEphemeralStorage(); shared_ptr peer(new Peer { .server = *this, - .addr = paddr, + .connection = protocol.connect(paddr), + .identity = monostate(), + .identityUpdates = {}, + .channel = monostate(), + .tempStorage = st, + .partStorage = st.derivePartialStorage(), + }); + peers.push_back(peer); + plist.p->push(peer); + return *peer; +} + +Server::Peer & Server::Priv::addPeer(NetworkProtocol::Connection conn) +{ + scoped_lock lock(dataMutex); + + auto st = self.ref()->storage().deriveEphemeralStorage(); + shared_ptr peer(new Peer { + .server = *this, + .connection = move(conn), .identity = monostate(), .identityUpdates = {}, .channel = monostate(), @@ -695,7 +738,7 @@ void Server::Peer::send(const TransportHeader & header, const vector & o out = std::move(data); if (!out.empty()) - server.protocol.sendto(out, addr); + connection.send(out); } void Server::Peer::updateIdentity(ReplyBuilder &) @@ -831,7 +874,7 @@ void Server::Peer::trySendOutQueue() for (const auto & data : secureOutQueue) { auto out = std::get>(channel)->encrypt(data); - server.protocol.sendto(out, addr); + connection.send(out); } secureOutQueue.clear(); diff --git a/src/network.h b/src/network.h index c242ac5..74231bf 100644 --- a/src/network.h +++ b/src/network.h @@ -44,7 +44,7 @@ struct Server::Peer Peer & operator=(const Peer &) = delete; Priv & server; - const sockaddr_in6 addr; + NetworkProtocol::Connection connection; variant, @@ -157,7 +157,9 @@ struct Server::Priv void doAnnounce(); bool isSelfAddress(const sockaddr_in6 & paddr); + Peer * findPeer(NetworkProtocol::Connection::Id cid) const; Peer & getPeer(const sockaddr_in6 & paddr); + Peer & addPeer(NetworkProtocol::Connection conn); void handlePacket(Peer &, const TransportHeader &, ReplyBuilder &); void handleLocalHeadChange(const Head &); @@ -165,7 +167,7 @@ struct Server::Priv constexpr static uint16_t discoveryPort { 29665 }; constexpr static chrono::seconds announceInterval { 60 }; - mutex dataMutex; + mutable mutex dataMutex; condition_variable announceCondvar; bool finish = false; diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp index 63cfde5..c247bf0 100644 --- a/src/network/protocol.cpp +++ b/src/network/protocol.cpp @@ -3,10 +3,27 @@ #include #include +#include +#include #include +using std::move; +using std::scoped_lock; + namespace erebos { +struct NetworkProtocol::ConnectionPriv +{ + Connection::Id id() const; + + NetworkProtocol * protocol; + const sockaddr_in6 peerAddress; + + mutex cmutex {}; + vector buffer {}; +}; + + NetworkProtocol::NetworkProtocol(): sock(-1) {} @@ -32,6 +49,44 @@ NetworkProtocol::~NetworkProtocol() { if (sock >= 0) close(sock); + + for (auto & c : connections) + c->protocol = nullptr; +} + +NetworkProtocol::PollResult NetworkProtocol::poll() +{ + sockaddr_in6 addr; + if (!recvfrom(buffer, addr)) + return ProtocolClosed {}; + + scoped_lock lock(protocolMutex); + for (const auto & c : connections) { + if (memcmp(&c->peerAddress, &addr, sizeof addr) == 0) { + scoped_lock clock(c->cmutex); + buffer.swap(c->buffer); + return ConnectionReadReady { c->id() }; + } + } + + auto conn = unique_ptr(new ConnectionPriv { + .protocol = this, + .peerAddress = addr, + }); + + connections.push_back(conn.get()); + buffer.swap(conn->buffer); + return NewConnection { Connection(move(conn)) }; +} + +NetworkProtocol::Connection NetworkProtocol::connect(sockaddr_in6 addr) +{ + auto conn = unique_ptr(new ConnectionPriv { + .protocol = this, + .peerAddress = addr, + }); + connections.push_back(conn.get()); + return Connection(move(conn)); } bool NetworkProtocol::recvfrom(vector & buffer, sockaddr_in6 & addr) @@ -66,4 +121,78 @@ void NetworkProtocol::shutdown() ::shutdown(sock, SHUT_RDWR); } + +NetworkProtocol::Connection::Id NetworkProtocol::ConnectionPriv::id() const +{ + return reinterpret_cast(this); +} + +NetworkProtocol::Connection::Connection(unique_ptr p_): + p(move(p_)) +{ +} + +NetworkProtocol::Connection::Connection(Connection && other): + p(move(other.p)) +{ +} + +NetworkProtocol::Connection & NetworkProtocol::Connection::operator=(Connection && other) +{ + close(); + p = move(other.p); + return *this; +} + +NetworkProtocol::Connection::~Connection() +{ + close(); +} + +NetworkProtocol::Connection::Id NetworkProtocol::Connection::id() const +{ + return p->id(); +} + +const sockaddr_in6 & NetworkProtocol::Connection::peerAddress() const +{ + return p->peerAddress; +} + +bool NetworkProtocol::Connection::receive(vector & buffer) +{ + scoped_lock lock(p->cmutex); + if (p->buffer.empty()) + return false; + + buffer.swap(p->buffer); + p->buffer.clear(); + return true; +} + +bool NetworkProtocol::Connection::send(const vector & buffer) +{ + p->protocol->sendto(buffer, p->peerAddress); + return true; +} + +void NetworkProtocol::Connection::close() +{ + if (not p) + return; + + if (p->protocol) { + scoped_lock lock(p->protocol->protocolMutex); + for (auto it = p->protocol->connections.begin(); + it != p->protocol->connections.end(); it++) { + if ((*it) == p.get()) { + p->protocol->connections.erase(it); + break; + } + } + } + + p = nullptr; +} + } diff --git a/src/network/protocol.h b/src/network/protocol.h index 6a22f3b..a9bbaff 100644 --- a/src/network/protocol.h +++ b/src/network/protocol.h @@ -3,10 +3,16 @@ #include #include +#include +#include +#include #include namespace erebos { +using std::mutex; +using std::unique_ptr; +using std::variant; using std::vector; class NetworkProtocol @@ -20,6 +26,21 @@ public: NetworkProtocol & operator=(NetworkProtocol &&); ~NetworkProtocol(); + class Connection; + + struct NewConnection; + struct ConnectionReadReady; + struct ProtocolClosed {}; + + using PollResult = variant< + NewConnection, + ConnectionReadReady, + ProtocolClosed>; + + PollResult poll(); + + Connection connect(sockaddr_in6 addr); + bool recvfrom(vector & buffer, sockaddr_in6 & addr); void sendto(const vector & buffer, sockaddr_in addr); void sendto(const vector & buffer, sockaddr_in6 addr); @@ -28,6 +49,40 @@ public: private: int sock; + + mutex protocolMutex; + vector buffer; + + struct ConnectionPriv; + vector connections; +}; + +class NetworkProtocol::Connection +{ + friend class NetworkProtocol; + Connection(unique_ptr p); +public: + Connection(const Connection &) = delete; + Connection(Connection &&); + Connection & operator=(const Connection &) = delete; + Connection & operator=(Connection &&); + ~Connection(); + + using Id = uintptr_t; + Id id() const; + + const sockaddr_in6 & peerAddress() const; + + bool receive(vector & buffer); + bool send(const vector & buffer); + + void close(); + +private: + unique_ptr p; }; +struct NetworkProtocol::NewConnection { Connection conn; }; +struct NetworkProtocol::ConnectionReadReady { Connection::Id id; }; + } -- cgit v1.2.3