diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/network.cpp | 30 | ||||
-rw-r--r-- | src/network/protocol.cpp | 101 | ||||
-rw-r--r-- | src/network/protocol.h | 12 |
3 files changed, 89 insertions, 54 deletions
diff --git a/src/network.cpp b/src/network.cpp index 7a5a804..a3d1130 100644 --- a/src/network.cpp +++ b/src/network.cpp @@ -109,14 +109,7 @@ void Server::addPeer(const string & node, const string & service) const for (addrinfo * rp = result.get(); rp != nullptr; rp = rp->ai_next) { if (rp->ai_family == AF_INET6) { - Peer & peer = p->getPeer(*(sockaddr_in6 *)rp->ai_addr); - - vector<NetworkProtocol::Header::Item> header; - { - shared_lock lock(p->selfMutex); - header.push_back(NetworkProtocol::Header::AnnounceSelf { p->self.ref()->digest() }); - } - peer.connection.send(peer.partStorage, header, {}, false); + p->getPeer(*(sockaddr_in6 *)rp->ai_addr); return; } } @@ -310,7 +303,7 @@ Server::Priv::Priv(const Head<LocalState> & local, const Identity & self): if (sock < 0) throw std::system_error(errno, std::generic_category()); - protocol = NetworkProtocol(sock); + protocol = NetworkProtocol(sock, self); int disable = 0; // Should be disabled by default, but try to make sure. On platforms @@ -415,18 +408,13 @@ void Server::Priv::doAnnounce() if (lastAnnounce + announceInterval < now) { shared_lock slock(selfMutex); - NetworkProtocol::Header header({ - NetworkProtocol::Header::AnnounceSelf { self.ref()->digest() }, - }); - - vector<uint8_t> bytes = header.toObject(pst).encode(); for (const auto & in : bcastAddresses) { sockaddr_in sin = {}; sin.sin_family = AF_INET; sin.sin_addr = in; sin.sin_port = htons(discoveryPort); - protocol.sendto(bytes, sin); + protocol.announceTo(sin); } lastAnnounce += announceInterval * ((now - lastAnnounce) / announceInterval); @@ -659,17 +647,7 @@ void Server::Priv::handleLocalHeadChange(const Head<LocalState> & head) if (auto id = head->identity()) { if (*id != self) { self = *id; - - vector<NetworkProtocol::Header::Item> hitems; - for (const auto & r : self.refs()) - hitems.push_back(NetworkProtocol::Header::AnnounceUpdate { r.digest() }); - for (const auto & r : self.updates()) - hitems.push_back(NetworkProtocol::Header::AnnounceUpdate { r.digest() }); - - NetworkProtocol::Header header(hitems); - - for (const auto & peer : peers) - peer->connection.send(peer->partStorage, header, { **self.ref() }, false); + protocol.updateIdentity(*id); } } } diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp index ede7023..89fa327 100644 --- a/src/network/protocol.cpp +++ b/src/network/protocol.cpp @@ -13,6 +13,7 @@ using std::get_if; using std::holds_alternative; using std::move; using std::nullopt; +using std::runtime_error; using std::scoped_lock; using std::visit; @@ -22,6 +23,9 @@ struct NetworkProtocol::ConnectionPriv { Connection::Id id() const; + bool send(const PartialStorage &, const Header &, + const vector<Object> &, bool secure); + NetworkProtocol * protocol; const sockaddr_in6 peerAddress; @@ -37,12 +41,14 @@ NetworkProtocol::NetworkProtocol(): sock(-1) {} -NetworkProtocol::NetworkProtocol(int s): - sock(s) +NetworkProtocol::NetworkProtocol(int s, Identity id): + sock(s), + self(move(id)) {} NetworkProtocol::NetworkProtocol(NetworkProtocol && other): - sock(other.sock) + sock(other.sock), + self(move(other.self)) { other.sock = -1; } @@ -51,6 +57,7 @@ NetworkProtocol & NetworkProtocol::operator=(NetworkProtocol && other) { sock = other.sock; other.sock = -1; + self = move(other.self); return *this; } @@ -94,10 +101,59 @@ NetworkProtocol::Connection NetworkProtocol::connect(sockaddr_in6 addr) .protocol = this, .peerAddress = addr, }); - connections.push_back(conn.get()); + + { + scoped_lock lock(protocolMutex); + connections.push_back(conn.get()); + + vector<Header::Item> header { + Header::AnnounceSelf { self->ref()->digest() }, + }; + conn->send(self->ref()->storage(), header, {}, false); + } + return Connection(move(conn)); } +void NetworkProtocol::updateIdentity(Identity id) +{ + scoped_lock lock(protocolMutex); + self = move(id); + + vector<Header::Item> hitems; + for (const auto & r : self->refs()) + hitems.push_back(Header::AnnounceUpdate { r.digest() }); + for (const auto & r : self->updates()) + hitems.push_back(Header::AnnounceUpdate { r.digest() }); + + Header header(hitems); + + for (const auto & conn : connections) + conn->send(self->ref()->storage(), header, { **self->ref() }, false); +} + +void NetworkProtocol::announceTo(variant<sockaddr_in, sockaddr_in6> addr) +{ + vector<uint8_t> bytes; + { + scoped_lock lock(protocolMutex); + + if (!self) + throw runtime_error("NetworkProtocol::announceTo without self identity"); + + bytes = Header({ + Header::AnnounceSelf { self->ref()->digest() }, + }).toObject(self->ref()->storage()).encode(); + } + + sendto(bytes, addr); +} + +void NetworkProtocol::shutdown() +{ + ::shutdown(sock, SHUT_RDWR); +} + bool NetworkProtocol::recvfrom(vector<uint8_t> & buffer, sockaddr_in6 & addr) { socklen_t addrlen = sizeof(addr); @@ -113,24 +169,14 @@ bool NetworkProtocol::recvfrom(vector<uint8_t> & buffer, sockaddr_in6 & addr) return true; } -void NetworkProtocol::sendto(const vector<uint8_t> & buffer, sockaddr_in addr) +void NetworkProtocol::sendto(const vector<uint8_t> & buffer, variant<sockaddr_in, sockaddr_in6> vaddr) { - ::sendto(sock, buffer.data(), buffer.size(), 0, - (sockaddr *) &addr, sizeof(addr)); + visit([&](auto && addr) { + ::sendto(sock, buffer.data(), buffer.size(), 0, + (sockaddr *) &addr, sizeof(addr)); + }, vaddr); } -void NetworkProtocol::sendto(const vector<uint8_t> & buffer, sockaddr_in6 addr) -{ - ::sendto(sock, buffer.data(), buffer.size(), 0, - (sockaddr *) &addr, sizeof(addr)); -} - -void NetworkProtocol::shutdown() -{ - ::shutdown(sock, SHUT_RDWR); -} - - /******************************************************************************/ /* Connection */ /******************************************************************************/ @@ -242,14 +288,21 @@ bool NetworkProtocol::Connection::send(const PartialStorage & partStorage, const Header & header, const vector<Object> & objs, bool secure) { + return p->send(partStorage, header, objs, secure); +} + +bool NetworkProtocol::ConnectionPriv::send(const PartialStorage & partStorage, + const Header & header, + const vector<Object> & objs, bool secure) +{ vector<uint8_t> data, part, out; { - scoped_lock clock(p->cmutex); + scoped_lock clock(cmutex); Channel * channel = nullptr; - if (holds_alternative<unique_ptr<Channel>>(p->channel)) - channel = std::get<unique_ptr<Channel>>(p->channel).get(); + if (auto uptr = get_if<unique_ptr<Channel>>(&this->channel)) + channel = uptr->get(); if (channel || secure) data.push_back(0x00); @@ -265,14 +318,14 @@ bool NetworkProtocol::Connection::send(const PartialStorage & partStorage, out.push_back(0x80); channel->encrypt(data.begin(), data.end(), out, 1); } else if (secure) { - p->secureOutQueue.emplace_back(move(data)); + secureOutQueue.emplace_back(move(data)); } else { out = std::move(data); } } if (not out.empty()) - p->protocol->sendto(out, p->peerAddress); + protocol->sendto(out, peerAddress); return true; } diff --git a/src/network/protocol.h b/src/network/protocol.h index df29c05..51f8d59 100644 --- a/src/network/protocol.h +++ b/src/network/protocol.h @@ -25,7 +25,7 @@ class NetworkProtocol { public: NetworkProtocol(); - explicit NetworkProtocol(int sock); + explicit NetworkProtocol(int sock, Identity self); NetworkProtocol(const NetworkProtocol &) = delete; NetworkProtocol(NetworkProtocol &&); NetworkProtocol & operator=(const NetworkProtocol &) = delete; @@ -55,18 +55,22 @@ public: Connection connect(sockaddr_in6 addr); - bool recvfrom(vector<uint8_t> & buffer, sockaddr_in6 & addr); - void sendto(const vector<uint8_t> & buffer, sockaddr_in addr); - void sendto(const vector<uint8_t> & buffer, sockaddr_in6 addr); + void updateIdentity(Identity self); + void announceTo(variant<sockaddr_in, sockaddr_in6> addr); void shutdown(); private: + bool recvfrom(vector<uint8_t> & buffer, sockaddr_in6 & addr); + void sendto(const vector<uint8_t> & buffer, variant<sockaddr_in, sockaddr_in6> addr); + int sock; mutex protocolMutex; vector<uint8_t> buffer; + optional<Identity> self; + struct ConnectionPriv; vector<ConnectionPriv *> connections; }; |