summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/network.cpp30
-rw-r--r--src/network/protocol.cpp101
-rw-r--r--src/network/protocol.h12
3 files changed, 89 insertions, 54 deletions
diff --git a/src/network.cpp b/src/network.cpp
index 7a5a804..a3d1130 100644
--- a/src/network.cpp
+++ b/src/network.cpp
@@ -109,14 +109,7 @@ void Server::addPeer(const string & node, const string & service) const
for (addrinfo * rp = result.get(); rp != nullptr; rp = rp->ai_next) {
if (rp->ai_family == AF_INET6) {
- Peer & peer = p->getPeer(*(sockaddr_in6 *)rp->ai_addr);
-
- vector<NetworkProtocol::Header::Item> header;
- {
- shared_lock lock(p->selfMutex);
- header.push_back(NetworkProtocol::Header::AnnounceSelf { p->self.ref()->digest() });
- }
- peer.connection.send(peer.partStorage, header, {}, false);
+ p->getPeer(*(sockaddr_in6 *)rp->ai_addr);
return;
}
}
@@ -310,7 +303,7 @@ Server::Priv::Priv(const Head<LocalState> & local, const Identity & self):
if (sock < 0)
throw std::system_error(errno, std::generic_category());
- protocol = NetworkProtocol(sock);
+ protocol = NetworkProtocol(sock, self);
int disable = 0;
// Should be disabled by default, but try to make sure. On platforms
@@ -415,18 +408,13 @@ void Server::Priv::doAnnounce()
if (lastAnnounce + announceInterval < now) {
shared_lock slock(selfMutex);
- NetworkProtocol::Header header({
- NetworkProtocol::Header::AnnounceSelf { self.ref()->digest() },
- });
-
- vector<uint8_t> bytes = header.toObject(pst).encode();
for (const auto & in : bcastAddresses) {
sockaddr_in sin = {};
sin.sin_family = AF_INET;
sin.sin_addr = in;
sin.sin_port = htons(discoveryPort);
- protocol.sendto(bytes, sin);
+ protocol.announceTo(sin);
}
lastAnnounce += announceInterval * ((now - lastAnnounce) / announceInterval);
@@ -659,17 +647,7 @@ void Server::Priv::handleLocalHeadChange(const Head<LocalState> & head)
if (auto id = head->identity()) {
if (*id != self) {
self = *id;
-
- vector<NetworkProtocol::Header::Item> hitems;
- for (const auto & r : self.refs())
- hitems.push_back(NetworkProtocol::Header::AnnounceUpdate { r.digest() });
- for (const auto & r : self.updates())
- hitems.push_back(NetworkProtocol::Header::AnnounceUpdate { r.digest() });
-
- NetworkProtocol::Header header(hitems);
-
- for (const auto & peer : peers)
- peer->connection.send(peer->partStorage, header, { **self.ref() }, false);
+ protocol.updateIdentity(*id);
}
}
}
diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp
index ede7023..89fa327 100644
--- a/src/network/protocol.cpp
+++ b/src/network/protocol.cpp
@@ -13,6 +13,7 @@ using std::get_if;
using std::holds_alternative;
using std::move;
using std::nullopt;
+using std::runtime_error;
using std::scoped_lock;
using std::visit;
@@ -22,6 +23,9 @@ struct NetworkProtocol::ConnectionPriv
{
Connection::Id id() const;
+ bool send(const PartialStorage &, const Header &,
+ const vector<Object> &, bool secure);
+
NetworkProtocol * protocol;
const sockaddr_in6 peerAddress;
@@ -37,12 +41,14 @@ NetworkProtocol::NetworkProtocol():
sock(-1)
{}
-NetworkProtocol::NetworkProtocol(int s):
- sock(s)
+NetworkProtocol::NetworkProtocol(int s, Identity id):
+ sock(s),
+ self(move(id))
{}
NetworkProtocol::NetworkProtocol(NetworkProtocol && other):
- sock(other.sock)
+ sock(other.sock),
+ self(move(other.self))
{
other.sock = -1;
}
@@ -51,6 +57,7 @@ NetworkProtocol & NetworkProtocol::operator=(NetworkProtocol && other)
{
sock = other.sock;
other.sock = -1;
+ self = move(other.self);
return *this;
}
@@ -94,10 +101,59 @@ NetworkProtocol::Connection NetworkProtocol::connect(sockaddr_in6 addr)
.protocol = this,
.peerAddress = addr,
});
- connections.push_back(conn.get());
+
+ {
+ scoped_lock lock(protocolMutex);
+ connections.push_back(conn.get());
+
+ vector<Header::Item> header {
+ Header::AnnounceSelf { self->ref()->digest() },
+ };
+ conn->send(self->ref()->storage(), header, {}, false);
+ }
+
return Connection(move(conn));
}
+void NetworkProtocol::updateIdentity(Identity id)
+{
+ scoped_lock lock(protocolMutex);
+ self = move(id);
+
+ vector<Header::Item> hitems;
+ for (const auto & r : self->refs())
+ hitems.push_back(Header::AnnounceUpdate { r.digest() });
+ for (const auto & r : self->updates())
+ hitems.push_back(Header::AnnounceUpdate { r.digest() });
+
+ Header header(hitems);
+
+ for (const auto & conn : connections)
+ conn->send(self->ref()->storage(), header, { **self->ref() }, false);
+}
+
+void NetworkProtocol::announceTo(variant<sockaddr_in, sockaddr_in6> addr)
+{
+ vector<uint8_t> bytes;
+ {
+ scoped_lock lock(protocolMutex);
+
+ if (!self)
+ throw runtime_error("NetworkProtocol::announceTo without self identity");
+
+ bytes = Header({
+ Header::AnnounceSelf { self->ref()->digest() },
+ }).toObject(self->ref()->storage()).encode();
+ }
+
+ sendto(bytes, addr);
+}
+
+void NetworkProtocol::shutdown()
+{
+ ::shutdown(sock, SHUT_RDWR);
+}
+
bool NetworkProtocol::recvfrom(vector<uint8_t> & buffer, sockaddr_in6 & addr)
{
socklen_t addrlen = sizeof(addr);
@@ -113,24 +169,14 @@ bool NetworkProtocol::recvfrom(vector<uint8_t> & buffer, sockaddr_in6 & addr)
return true;
}
-void NetworkProtocol::sendto(const vector<uint8_t> & buffer, sockaddr_in addr)
+void NetworkProtocol::sendto(const vector<uint8_t> & buffer, variant<sockaddr_in, sockaddr_in6> vaddr)
{
- ::sendto(sock, buffer.data(), buffer.size(), 0,
- (sockaddr *) &addr, sizeof(addr));
+ visit([&](auto && addr) {
+ ::sendto(sock, buffer.data(), buffer.size(), 0,
+ (sockaddr *) &addr, sizeof(addr));
+ }, vaddr);
}
-void NetworkProtocol::sendto(const vector<uint8_t> & buffer, sockaddr_in6 addr)
-{
- ::sendto(sock, buffer.data(), buffer.size(), 0,
- (sockaddr *) &addr, sizeof(addr));
-}
-
-void NetworkProtocol::shutdown()
-{
- ::shutdown(sock, SHUT_RDWR);
-}
-
-
/******************************************************************************/
/* Connection */
/******************************************************************************/
@@ -242,14 +288,21 @@ bool NetworkProtocol::Connection::send(const PartialStorage & partStorage,
const Header & header,
const vector<Object> & objs, bool secure)
{
+ return p->send(partStorage, header, objs, secure);
+}
+
+bool NetworkProtocol::ConnectionPriv::send(const PartialStorage & partStorage,
+ const Header & header,
+ const vector<Object> & objs, bool secure)
+{
vector<uint8_t> data, part, out;
{
- scoped_lock clock(p->cmutex);
+ scoped_lock clock(cmutex);
Channel * channel = nullptr;
- if (holds_alternative<unique_ptr<Channel>>(p->channel))
- channel = std::get<unique_ptr<Channel>>(p->channel).get();
+ if (auto uptr = get_if<unique_ptr<Channel>>(&this->channel))
+ channel = uptr->get();
if (channel || secure)
data.push_back(0x00);
@@ -265,14 +318,14 @@ bool NetworkProtocol::Connection::send(const PartialStorage & partStorage,
out.push_back(0x80);
channel->encrypt(data.begin(), data.end(), out, 1);
} else if (secure) {
- p->secureOutQueue.emplace_back(move(data));
+ secureOutQueue.emplace_back(move(data));
} else {
out = std::move(data);
}
}
if (not out.empty())
- p->protocol->sendto(out, p->peerAddress);
+ protocol->sendto(out, peerAddress);
return true;
}
diff --git a/src/network/protocol.h b/src/network/protocol.h
index df29c05..51f8d59 100644
--- a/src/network/protocol.h
+++ b/src/network/protocol.h
@@ -25,7 +25,7 @@ class NetworkProtocol
{
public:
NetworkProtocol();
- explicit NetworkProtocol(int sock);
+ explicit NetworkProtocol(int sock, Identity self);
NetworkProtocol(const NetworkProtocol &) = delete;
NetworkProtocol(NetworkProtocol &&);
NetworkProtocol & operator=(const NetworkProtocol &) = delete;
@@ -55,18 +55,22 @@ public:
Connection connect(sockaddr_in6 addr);
- bool recvfrom(vector<uint8_t> & buffer, sockaddr_in6 & addr);
- void sendto(const vector<uint8_t> & buffer, sockaddr_in addr);
- void sendto(const vector<uint8_t> & buffer, sockaddr_in6 addr);
+ void updateIdentity(Identity self);
+ void announceTo(variant<sockaddr_in, sockaddr_in6> addr);
void shutdown();
private:
+ bool recvfrom(vector<uint8_t> & buffer, sockaddr_in6 & addr);
+ void sendto(const vector<uint8_t> & buffer, variant<sockaddr_in, sockaddr_in6> addr);
+
int sock;
mutex protocolMutex;
vector<uint8_t> buffer;
+ optional<Identity> self;
+
struct ConnectionPriv;
vector<ConnectionPriv *> connections;
};