summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRoman Smrž <roman.smrz@seznam.cz>2023-08-13 19:01:48 +0200
committerRoman Smrž <roman.smrz@seznam.cz>2023-08-16 21:49:39 +0200
commit2ed8103ff1c0fca7372b3c3888f590ba41c525e6 (patch)
tree103834746f4b64c7dbaf4a237447108cdf44c8d9
parent7420a170928da75cb860e3fc8804416babdeec8c (diff)
Connection class for network protocol
-rw-r--r--src/network.cpp87
-rw-r--r--src/network.h6
-rw-r--r--src/network/protocol.cpp129
-rw-r--r--src/network/protocol.h55
4 files changed, 253 insertions, 24 deletions
diff --git a/src/network.cpp b/src/network.cpp
index b5dfd68..786e752 100644
--- a/src/network.cpp
+++ b/src/network.cpp
@@ -175,7 +175,7 @@ optional<Identity> Peer::identity() const
const sockaddr_in6 & Peer::address() const
{
if (auto speer = p->speer.lock())
- return speer->addr;
+ return speer->connection.peerAddress();
throw runtime_error("Server no longer running");
}
@@ -373,36 +373,49 @@ void Server::Priv::doListen()
for (; !finish; lock.lock()) {
lock.unlock();
- sockaddr_in6 paddr;
- if (not protocol.recvfrom(buf, paddr))
+ Peer * peer = nullptr;
+ auto res = protocol.poll();
+
+ if (holds_alternative<NetworkProtocol::ProtocolClosed>(res))
break;
- if (isSelfAddress(paddr))
+ if (holds_alternative<NetworkProtocol::NewConnection>(res)) {
+ auto & conn = get<NetworkProtocol::NewConnection>(res).conn;
+ if (not isSelfAddress(conn.peerAddress()))
+ peer = &addPeer(move(conn));
+ }
+
+ if (holds_alternative<NetworkProtocol::ConnectionReadReady>(res)) {
+ peer = findPeer(get<NetworkProtocol::ConnectionReadReady>(res).id);
+ }
+
+ if (!peer)
continue;
- auto & peer = getPeer(paddr);
+ if (not peer->connection.receive(buf))
+ continue;
current = &buf;
- if (holds_alternative<unique_ptr<Channel>>(peer.channel)) {
- if (auto dec = std::get<unique_ptr<Channel>>(peer.channel)->decrypt(buf)) {
+ if (holds_alternative<unique_ptr<Channel>>(peer->channel)) {
+ if (auto dec = std::get<unique_ptr<Channel>>(peer->channel)->decrypt(buf)) {
decrypted = std::move(*dec);
current = &decrypted;
}
- } else if (holds_alternative<Stored<ChannelAccept>>(peer.channel)) {
- if (auto dec = std::get<Stored<ChannelAccept>>(peer.channel)->
+ } else if (holds_alternative<Stored<ChannelAccept>>(peer->channel)) {
+ if (auto dec = std::get<Stored<ChannelAccept>>(peer->channel)->
data->channel()->decrypt(buf)) {
decrypted = std::move(*dec);
current = &decrypted;
}
}
- if (auto dec = PartialObject::decodePrefix(peer.partStorage,
+ if (auto dec = PartialObject::decodePrefix(peer->partStorage,
current->begin(), current->end())) {
if (auto header = TransportHeader::load(std::get<PartialObject>(*dec))) {
auto pos = std::get<1>(*dec);
- while (auto cdec = PartialObject::decodePrefix(peer.partStorage,
+ while (auto cdec = PartialObject::decodePrefix(peer->partStorage,
pos, current->end())) {
- peer.partStorage.storeObject(std::get<PartialObject>(*cdec));
+ peer->partStorage.storeObject(std::get<PartialObject>(*cdec));
pos = std::get<1>(*cdec);
}
@@ -411,15 +424,15 @@ void Server::Priv::doListen()
scoped_lock hlock(dataMutex);
shared_lock slock(selfMutex);
- handlePacket(peer, *header, reply);
- peer.updateIdentity(reply);
- peer.updateChannel(reply);
- peer.updateService(reply);
+ handlePacket(*peer, *header, reply);
+ peer->updateIdentity(reply);
+ peer->updateChannel(reply);
+ peer->updateService(reply);
if (!reply.header().empty())
- peer.send(TransportHeader(reply.header()), reply.body(), false);
+ peer->send(TransportHeader(reply.header()), reply.body(), false);
- peer.trySendOutQueue();
+ peer->trySendOutQueue();
}
} else {
std::cerr << "invalid packet\n";
@@ -468,18 +481,48 @@ bool Server::Priv::isSelfAddress(const sockaddr_in6 & paddr)
return false;
}
+Server::Peer * Server::Priv::findPeer(NetworkProtocol::Connection::Id cid) const
+{
+ scoped_lock lock(dataMutex);
+
+ for (auto & peer : peers)
+ if (peer->connection.id() == cid)
+ return peer.get();
+
+ return nullptr;
+}
+
Server::Peer & Server::Priv::getPeer(const sockaddr_in6 & paddr)
{
scoped_lock lock(dataMutex);
for (auto & peer : peers)
- if (memcmp(&peer->addr, &paddr, sizeof paddr) == 0)
+ if (memcmp(&peer->connection.peerAddress(), &paddr, sizeof paddr) == 0)
return *peer;
auto st = self.ref()->storage().deriveEphemeralStorage();
shared_ptr<Peer> peer(new Peer {
.server = *this,
- .addr = paddr,
+ .connection = protocol.connect(paddr),
+ .identity = monostate(),
+ .identityUpdates = {},
+ .channel = monostate(),
+ .tempStorage = st,
+ .partStorage = st.derivePartialStorage(),
+ });
+ peers.push_back(peer);
+ plist.p->push(peer);
+ return *peer;
+}
+
+Server::Peer & Server::Priv::addPeer(NetworkProtocol::Connection conn)
+{
+ scoped_lock lock(dataMutex);
+
+ auto st = self.ref()->storage().deriveEphemeralStorage();
+ shared_ptr<Peer> peer(new Peer {
+ .server = *this,
+ .connection = move(conn),
.identity = monostate(),
.identityUpdates = {},
.channel = monostate(),
@@ -695,7 +738,7 @@ void Server::Peer::send(const TransportHeader & header, const vector<Object> & o
out = std::move(data);
if (!out.empty())
- server.protocol.sendto(out, addr);
+ connection.send(out);
}
void Server::Peer::updateIdentity(ReplyBuilder &)
@@ -831,7 +874,7 @@ void Server::Peer::trySendOutQueue()
for (const auto & data : secureOutQueue) {
auto out = std::get<unique_ptr<Channel>>(channel)->encrypt(data);
- server.protocol.sendto(out, addr);
+ connection.send(out);
}
secureOutQueue.clear();
diff --git a/src/network.h b/src/network.h
index c242ac5..74231bf 100644
--- a/src/network.h
+++ b/src/network.h
@@ -44,7 +44,7 @@ struct Server::Peer
Peer & operator=(const Peer &) = delete;
Priv & server;
- const sockaddr_in6 addr;
+ NetworkProtocol::Connection connection;
variant<monostate,
shared_ptr<struct WaitingRef>,
@@ -157,7 +157,9 @@ struct Server::Priv
void doAnnounce();
bool isSelfAddress(const sockaddr_in6 & paddr);
+ Peer * findPeer(NetworkProtocol::Connection::Id cid) const;
Peer & getPeer(const sockaddr_in6 & paddr);
+ Peer & addPeer(NetworkProtocol::Connection conn);
void handlePacket(Peer &, const TransportHeader &, ReplyBuilder &);
void handleLocalHeadChange(const Head<LocalState> &);
@@ -165,7 +167,7 @@ struct Server::Priv
constexpr static uint16_t discoveryPort { 29665 };
constexpr static chrono::seconds announceInterval { 60 };
- mutex dataMutex;
+ mutable mutex dataMutex;
condition_variable announceCondvar;
bool finish = false;
diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp
index 63cfde5..c247bf0 100644
--- a/src/network/protocol.cpp
+++ b/src/network/protocol.cpp
@@ -3,10 +3,27 @@
#include <sys/socket.h>
#include <unistd.h>
+#include <cstring>
+#include <mutex>
#include <system_error>
+using std::move;
+using std::scoped_lock;
+
namespace erebos {
+struct NetworkProtocol::ConnectionPriv
+{
+ Connection::Id id() const;
+
+ NetworkProtocol * protocol;
+ const sockaddr_in6 peerAddress;
+
+ mutex cmutex {};
+ vector<uint8_t> buffer {};
+};
+
+
NetworkProtocol::NetworkProtocol():
sock(-1)
{}
@@ -32,6 +49,44 @@ NetworkProtocol::~NetworkProtocol()
{
if (sock >= 0)
close(sock);
+
+ for (auto & c : connections)
+ c->protocol = nullptr;
+}
+
+NetworkProtocol::PollResult NetworkProtocol::poll()
+{
+ sockaddr_in6 addr;
+ if (!recvfrom(buffer, addr))
+ return ProtocolClosed {};
+
+ scoped_lock lock(protocolMutex);
+ for (const auto & c : connections) {
+ if (memcmp(&c->peerAddress, &addr, sizeof addr) == 0) {
+ scoped_lock clock(c->cmutex);
+ buffer.swap(c->buffer);
+ return ConnectionReadReady { c->id() };
+ }
+ }
+
+ auto conn = unique_ptr<ConnectionPriv>(new ConnectionPriv {
+ .protocol = this,
+ .peerAddress = addr,
+ });
+
+ connections.push_back(conn.get());
+ buffer.swap(conn->buffer);
+ return NewConnection { Connection(move(conn)) };
+}
+
+NetworkProtocol::Connection NetworkProtocol::connect(sockaddr_in6 addr)
+{
+ auto conn = unique_ptr<ConnectionPriv>(new ConnectionPriv {
+ .protocol = this,
+ .peerAddress = addr,
+ });
+ connections.push_back(conn.get());
+ return Connection(move(conn));
}
bool NetworkProtocol::recvfrom(vector<uint8_t> & buffer, sockaddr_in6 & addr)
@@ -66,4 +121,78 @@ void NetworkProtocol::shutdown()
::shutdown(sock, SHUT_RDWR);
}
+
+NetworkProtocol::Connection::Id NetworkProtocol::ConnectionPriv::id() const
+{
+ return reinterpret_cast<uintptr_t>(this);
+}
+
+NetworkProtocol::Connection::Connection(unique_ptr<ConnectionPriv> p_):
+ p(move(p_))
+{
+}
+
+NetworkProtocol::Connection::Connection(Connection && other):
+ p(move(other.p))
+{
+}
+
+NetworkProtocol::Connection & NetworkProtocol::Connection::operator=(Connection && other)
+{
+ close();
+ p = move(other.p);
+ return *this;
+}
+
+NetworkProtocol::Connection::~Connection()
+{
+ close();
+}
+
+NetworkProtocol::Connection::Id NetworkProtocol::Connection::id() const
+{
+ return p->id();
+}
+
+const sockaddr_in6 & NetworkProtocol::Connection::peerAddress() const
+{
+ return p->peerAddress;
+}
+
+bool NetworkProtocol::Connection::receive(vector<uint8_t> & buffer)
+{
+ scoped_lock lock(p->cmutex);
+ if (p->buffer.empty())
+ return false;
+
+ buffer.swap(p->buffer);
+ p->buffer.clear();
+ return true;
+}
+
+bool NetworkProtocol::Connection::send(const vector<uint8_t> & buffer)
+{
+ p->protocol->sendto(buffer, p->peerAddress);
+ return true;
+}
+
+void NetworkProtocol::Connection::close()
+{
+ if (not p)
+ return;
+
+ if (p->protocol) {
+ scoped_lock lock(p->protocol->protocolMutex);
+ for (auto it = p->protocol->connections.begin();
+ it != p->protocol->connections.end(); it++) {
+ if ((*it) == p.get()) {
+ p->protocol->connections.erase(it);
+ break;
+ }
+ }
+ }
+
+ p = nullptr;
+}
+
}
diff --git a/src/network/protocol.h b/src/network/protocol.h
index 6a22f3b..a9bbaff 100644
--- a/src/network/protocol.h
+++ b/src/network/protocol.h
@@ -3,10 +3,16 @@
#include <netinet/in.h>
#include <cstdint>
+#include <memory>
+#include <mutex>
+#include <variant>
#include <vector>
namespace erebos {
+using std::mutex;
+using std::unique_ptr;
+using std::variant;
using std::vector;
class NetworkProtocol
@@ -20,6 +26,21 @@ public:
NetworkProtocol & operator=(NetworkProtocol &&);
~NetworkProtocol();
+ class Connection;
+
+ struct NewConnection;
+ struct ConnectionReadReady;
+ struct ProtocolClosed {};
+
+ using PollResult = variant<
+ NewConnection,
+ ConnectionReadReady,
+ ProtocolClosed>;
+
+ PollResult poll();
+
+ Connection connect(sockaddr_in6 addr);
+
bool recvfrom(vector<uint8_t> & buffer, sockaddr_in6 & addr);
void sendto(const vector<uint8_t> & buffer, sockaddr_in addr);
void sendto(const vector<uint8_t> & buffer, sockaddr_in6 addr);
@@ -28,6 +49,40 @@ public:
private:
int sock;
+
+ mutex protocolMutex;
+ vector<uint8_t> buffer;
+
+ struct ConnectionPriv;
+ vector<ConnectionPriv *> connections;
+};
+
+class NetworkProtocol::Connection
+{
+ friend class NetworkProtocol;
+ Connection(unique_ptr<ConnectionPriv> p);
+public:
+ Connection(const Connection &) = delete;
+ Connection(Connection &&);
+ Connection & operator=(const Connection &) = delete;
+ Connection & operator=(Connection &&);
+ ~Connection();
+
+ using Id = uintptr_t;
+ Id id() const;
+
+ const sockaddr_in6 & peerAddress() const;
+
+ bool receive(vector<uint8_t> & buffer);
+ bool send(const vector<uint8_t> & buffer);
+
+ void close();
+
+private:
+ unique_ptr<ConnectionPriv> p;
};
+struct NetworkProtocol::NewConnection { Connection conn; };
+struct NetworkProtocol::ConnectionReadReady { Connection::Id id; };
+
}