From a40f12cf820b3e11cc72f7b20046c8077ab0d0a5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Roman=20Smr=C5=BE?= <roman.smrz@seznam.cz>
Date: Sun, 27 Aug 2023 10:37:20 +0200
Subject: Network: identity announce and update in protocol object

---
 src/network.cpp          |  30 ++------------
 src/network/protocol.cpp | 101 ++++++++++++++++++++++++++++++++++++-----------
 src/network/protocol.h   |  12 ++++--
 3 files changed, 89 insertions(+), 54 deletions(-)

(limited to 'src')

diff --git a/src/network.cpp b/src/network.cpp
index 7a5a804..a3d1130 100644
--- a/src/network.cpp
+++ b/src/network.cpp
@@ -109,14 +109,7 @@ void Server::addPeer(const string & node, const string & service) const
 
 	for (addrinfo * rp = result.get(); rp != nullptr; rp = rp->ai_next) {
 		if (rp->ai_family == AF_INET6) {
-			Peer & peer = p->getPeer(*(sockaddr_in6 *)rp->ai_addr);
-
-			vector<NetworkProtocol::Header::Item> header;
-			{
-				shared_lock lock(p->selfMutex);
-				header.push_back(NetworkProtocol::Header::AnnounceSelf { p->self.ref()->digest() });
-			}
-			peer.connection.send(peer.partStorage, header, {}, false);
+			p->getPeer(*(sockaddr_in6 *)rp->ai_addr);
 			return;
 		}
 	}
@@ -310,7 +303,7 @@ Server::Priv::Priv(const Head<LocalState> & local, const Identity & self):
 	if (sock < 0)
 		throw std::system_error(errno, std::generic_category());
 
-	protocol = NetworkProtocol(sock);
+	protocol = NetworkProtocol(sock, self);
 
 	int disable = 0;
 	// Should be disabled by default, but try to make sure. On platforms
@@ -415,18 +408,13 @@ void Server::Priv::doAnnounce()
 
 		if (lastAnnounce + announceInterval < now) {
 			shared_lock slock(selfMutex);
-			NetworkProtocol::Header header({
-				NetworkProtocol::Header::AnnounceSelf { self.ref()->digest() },
-			});
-
-			vector<uint8_t> bytes = header.toObject(pst).encode();
 
 			for (const auto & in : bcastAddresses) {
 				sockaddr_in sin = {};
 				sin.sin_family = AF_INET;
 				sin.sin_addr = in;
 				sin.sin_port = htons(discoveryPort);
-				protocol.sendto(bytes, sin);
+				protocol.announceTo(sin);
 			}
 
 			lastAnnounce += announceInterval * ((now - lastAnnounce) / announceInterval);
@@ -659,17 +647,7 @@ void Server::Priv::handleLocalHeadChange(const Head<LocalState> & head)
 	if (auto id = head->identity()) {
 		if (*id != self) {
 			self = *id;
-
-			vector<NetworkProtocol::Header::Item> hitems;
-			for (const auto & r : self.refs())
-				hitems.push_back(NetworkProtocol::Header::AnnounceUpdate { r.digest() });
-			for (const auto & r : self.updates())
-				hitems.push_back(NetworkProtocol::Header::AnnounceUpdate { r.digest() });
-
-			NetworkProtocol::Header header(hitems);
-
-			for (const auto & peer : peers)
-				peer->connection.send(peer->partStorage, header, { **self.ref() }, false);
+			protocol.updateIdentity(*id);
 		}
 	}
 }
diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp
index ede7023..89fa327 100644
--- a/src/network/protocol.cpp
+++ b/src/network/protocol.cpp
@@ -13,6 +13,7 @@ using std::get_if;
 using std::holds_alternative;
 using std::move;
 using std::nullopt;
+using std::runtime_error;
 using std::scoped_lock;
 using std::visit;
 
@@ -22,6 +23,9 @@ struct NetworkProtocol::ConnectionPriv
 {
 	Connection::Id id() const;
 
+	bool send(const PartialStorage &, const Header &,
+			const vector<Object> &, bool secure);
+
 	NetworkProtocol * protocol;
 	const sockaddr_in6 peerAddress;
 
@@ -37,12 +41,14 @@ NetworkProtocol::NetworkProtocol():
 	sock(-1)
 {}
 
-NetworkProtocol::NetworkProtocol(int s):
-	sock(s)
+NetworkProtocol::NetworkProtocol(int s, Identity id):
+	sock(s),
+	self(move(id))
 {}
 
 NetworkProtocol::NetworkProtocol(NetworkProtocol && other):
-	sock(other.sock)
+	sock(other.sock),
+	self(move(other.self))
 {
 	other.sock = -1;
 }
@@ -51,6 +57,7 @@ NetworkProtocol & NetworkProtocol::operator=(NetworkProtocol && other)
 {
 	sock = other.sock;
 	other.sock = -1;
+	self = move(other.self);
 	return *this;
 }
 
@@ -94,10 +101,59 @@ NetworkProtocol::Connection NetworkProtocol::connect(sockaddr_in6 addr)
 		.protocol = this,
 		.peerAddress = addr,
 	});
-	connections.push_back(conn.get());
+
+	{
+		scoped_lock lock(protocolMutex);
+		connections.push_back(conn.get());
+
+		vector<Header::Item> header {
+			Header::AnnounceSelf { self->ref()->digest() },
+		};
+		conn->send(self->ref()->storage(), header, {}, false);
+	}
+
 	return Connection(move(conn));
 }
 
+void NetworkProtocol::updateIdentity(Identity id)
+{
+	scoped_lock lock(protocolMutex);
+	self = move(id);
+
+	vector<Header::Item> hitems;
+	for (const auto & r : self->refs())
+		hitems.push_back(Header::AnnounceUpdate { r.digest() });
+	for (const auto & r : self->updates())
+		hitems.push_back(Header::AnnounceUpdate { r.digest() });
+
+	Header header(hitems);
+
+	for (const auto & conn : connections)
+		conn->send(self->ref()->storage(), header, { **self->ref() }, false);
+}
+
+void NetworkProtocol::announceTo(variant<sockaddr_in, sockaddr_in6> addr)
+{
+	vector<uint8_t> bytes;
+	{
+		scoped_lock lock(protocolMutex);
+
+		if (!self)
+			throw runtime_error("NetworkProtocol::announceTo without self identity");
+
+		bytes = Header({
+			Header::AnnounceSelf { self->ref()->digest() },
+		}).toObject(self->ref()->storage()).encode();
+	}
+
+	sendto(bytes, addr);
+}
+
+void NetworkProtocol::shutdown()
+{
+	::shutdown(sock, SHUT_RDWR);
+}
+
 bool NetworkProtocol::recvfrom(vector<uint8_t> & buffer, sockaddr_in6 & addr)
 {
 	socklen_t addrlen = sizeof(addr);
@@ -113,24 +169,14 @@ bool NetworkProtocol::recvfrom(vector<uint8_t> & buffer, sockaddr_in6 & addr)
 	return true;
 }
 
-void NetworkProtocol::sendto(const vector<uint8_t> & buffer, sockaddr_in addr)
+void NetworkProtocol::sendto(const vector<uint8_t> & buffer, variant<sockaddr_in, sockaddr_in6> vaddr)
 {
-	::sendto(sock, buffer.data(), buffer.size(), 0,
-			(sockaddr *) &addr, sizeof(addr));
+	visit([&](auto && addr) {
+		::sendto(sock, buffer.data(), buffer.size(), 0,
+				(sockaddr *) &addr, sizeof(addr));
+	}, vaddr);
 }
 
-void NetworkProtocol::sendto(const vector<uint8_t> & buffer, sockaddr_in6 addr)
-{
-	::sendto(sock, buffer.data(), buffer.size(), 0,
-			(sockaddr *) &addr, sizeof(addr));
-}
-
-void NetworkProtocol::shutdown()
-{
-	::shutdown(sock, SHUT_RDWR);
-}
-
-
 /******************************************************************************/
 /* Connection                                                                 */
 /******************************************************************************/
@@ -241,15 +287,22 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const Par
 bool NetworkProtocol::Connection::send(const PartialStorage & partStorage,
 		const Header & header,
 		const vector<Object> & objs, bool secure)
+{
+	return p->send(partStorage, header, objs, secure);
+}
+
+bool NetworkProtocol::ConnectionPriv::send(const PartialStorage & partStorage,
+		const Header & header,
+		const vector<Object> & objs, bool secure)
 {
 	vector<uint8_t> data, part, out;
 
 	{
-		scoped_lock clock(p->cmutex);
+		scoped_lock clock(cmutex);
 
 		Channel * channel = nullptr;
-		if (holds_alternative<unique_ptr<Channel>>(p->channel))
-			channel = std::get<unique_ptr<Channel>>(p->channel).get();
+		if (auto uptr = get_if<unique_ptr<Channel>>(&this->channel))
+			channel = uptr->get();
 
 		if (channel || secure)
 			data.push_back(0x00);
@@ -265,14 +318,14 @@ bool NetworkProtocol::Connection::send(const PartialStorage & partStorage,
 			out.push_back(0x80);
 			channel->encrypt(data.begin(), data.end(), out, 1);
 		} else if (secure) {
-			p->secureOutQueue.emplace_back(move(data));
+			secureOutQueue.emplace_back(move(data));
 		} else {
 			out = std::move(data);
 		}
 	}
 
 	if (not out.empty())
-		p->protocol->sendto(out, p->peerAddress);
+		protocol->sendto(out, peerAddress);
 
 	return true;
 }
diff --git a/src/network/protocol.h b/src/network/protocol.h
index df29c05..51f8d59 100644
--- a/src/network/protocol.h
+++ b/src/network/protocol.h
@@ -25,7 +25,7 @@ class NetworkProtocol
 {
 public:
 	NetworkProtocol();
-	explicit NetworkProtocol(int sock);
+	explicit NetworkProtocol(int sock, Identity self);
 	NetworkProtocol(const NetworkProtocol &) = delete;
 	NetworkProtocol(NetworkProtocol &&);
 	NetworkProtocol & operator=(const NetworkProtocol &) = delete;
@@ -55,18 +55,22 @@ public:
 
 	Connection connect(sockaddr_in6 addr);
 
-	bool recvfrom(vector<uint8_t> & buffer, sockaddr_in6 & addr);
-	void sendto(const vector<uint8_t> & buffer, sockaddr_in addr);
-	void sendto(const vector<uint8_t> & buffer, sockaddr_in6 addr);
+	void updateIdentity(Identity self);
+	void announceTo(variant<sockaddr_in, sockaddr_in6> addr);
 
 	void shutdown();
 
 private:
+	bool recvfrom(vector<uint8_t> & buffer, sockaddr_in6 & addr);
+	void sendto(const vector<uint8_t> & buffer, variant<sockaddr_in, sockaddr_in6> addr);
+
 	int sock;
 
 	mutex protocolMutex;
 	vector<uint8_t> buffer;
 
+	optional<Identity> self;
+
 	struct ConnectionPriv;
 	vector<ConnectionPriv *> connections;
 };
-- 
cgit v1.2.3