diff options
| -rw-r--r-- | src/network.cpp | 87 | ||||
| -rw-r--r-- | src/network.h | 6 | ||||
| -rw-r--r-- | src/network/protocol.cpp | 129 | ||||
| -rw-r--r-- | src/network/protocol.h | 55 | 
4 files changed, 253 insertions, 24 deletions
| diff --git a/src/network.cpp b/src/network.cpp index b5dfd68..786e752 100644 --- a/src/network.cpp +++ b/src/network.cpp @@ -175,7 +175,7 @@ optional<Identity> Peer::identity() const  const sockaddr_in6 & Peer::address() const  {  	if (auto speer = p->speer.lock()) -		return speer->addr; +		return speer->connection.peerAddress();  	throw runtime_error("Server no longer running");  } @@ -373,36 +373,49 @@ void Server::Priv::doListen()  	for (; !finish; lock.lock()) {  		lock.unlock(); -		sockaddr_in6 paddr; -		if (not protocol.recvfrom(buf, paddr)) +		Peer * peer = nullptr; +		auto res = protocol.poll(); + +		if (holds_alternative<NetworkProtocol::ProtocolClosed>(res))  			break; -		if (isSelfAddress(paddr)) +		if (holds_alternative<NetworkProtocol::NewConnection>(res)) { +			auto & conn = get<NetworkProtocol::NewConnection>(res).conn; +			if (not isSelfAddress(conn.peerAddress())) +				peer = &addPeer(move(conn)); +		} + +		if (holds_alternative<NetworkProtocol::ConnectionReadReady>(res)) { +			peer = findPeer(get<NetworkProtocol::ConnectionReadReady>(res).id); +		} + +		if (!peer)  			continue; -		auto & peer = getPeer(paddr); +		if (not peer->connection.receive(buf)) +			continue;  		current = &buf; -		if (holds_alternative<unique_ptr<Channel>>(peer.channel)) { -			if (auto dec = std::get<unique_ptr<Channel>>(peer.channel)->decrypt(buf)) { +		if (holds_alternative<unique_ptr<Channel>>(peer->channel)) { +			if (auto dec = std::get<unique_ptr<Channel>>(peer->channel)->decrypt(buf)) {  				decrypted = std::move(*dec);  				current = &decrypted;  			} -		} else if (holds_alternative<Stored<ChannelAccept>>(peer.channel)) { -			if (auto dec = std::get<Stored<ChannelAccept>>(peer.channel)-> +		} else if (holds_alternative<Stored<ChannelAccept>>(peer->channel)) { +			if (auto dec = std::get<Stored<ChannelAccept>>(peer->channel)->  					data->channel()->decrypt(buf)) {  				decrypted = std::move(*dec);  				current = &decrypted;  			}  		} -		if (auto dec = PartialObject::decodePrefix(peer.partStorage, +		if (auto dec = PartialObject::decodePrefix(peer->partStorage,  				current->begin(), current->end())) {  			if (auto header = TransportHeader::load(std::get<PartialObject>(*dec))) {  				auto pos = std::get<1>(*dec); -				while (auto cdec = PartialObject::decodePrefix(peer.partStorage, +				while (auto cdec = PartialObject::decodePrefix(peer->partStorage,  							pos, current->end())) { -					peer.partStorage.storeObject(std::get<PartialObject>(*cdec)); +					peer->partStorage.storeObject(std::get<PartialObject>(*cdec));  					pos = std::get<1>(*cdec);  				} @@ -411,15 +424,15 @@ void Server::Priv::doListen()  				scoped_lock hlock(dataMutex);  				shared_lock slock(selfMutex); -				handlePacket(peer, *header, reply); -				peer.updateIdentity(reply); -				peer.updateChannel(reply); -				peer.updateService(reply); +				handlePacket(*peer, *header, reply); +				peer->updateIdentity(reply); +				peer->updateChannel(reply); +				peer->updateService(reply);  				if (!reply.header().empty()) -					peer.send(TransportHeader(reply.header()), reply.body(), false); +					peer->send(TransportHeader(reply.header()), reply.body(), false); -				peer.trySendOutQueue(); +				peer->trySendOutQueue();  			}  		} else {  			std::cerr << "invalid packet\n"; @@ -468,18 +481,48 @@ bool Server::Priv::isSelfAddress(const sockaddr_in6 & paddr)  	return false;  } +Server::Peer * Server::Priv::findPeer(NetworkProtocol::Connection::Id cid) const +{ +	scoped_lock lock(dataMutex); + +	for (auto & peer : peers) +		if (peer->connection.id() == cid) +			return peer.get(); + +	return nullptr; +} +  Server::Peer & Server::Priv::getPeer(const sockaddr_in6 & paddr)  {  	scoped_lock lock(dataMutex);  	for (auto & peer : peers) -		if (memcmp(&peer->addr, &paddr, sizeof paddr) == 0) +		if (memcmp(&peer->connection.peerAddress(), &paddr, sizeof paddr) == 0)  			return *peer;  	auto st = self.ref()->storage().deriveEphemeralStorage();  	shared_ptr<Peer> peer(new Peer {  		.server = *this, -		.addr = paddr, +		.connection = protocol.connect(paddr), +		.identity = monostate(), +		.identityUpdates = {}, +		.channel = monostate(), +		.tempStorage = st, +		.partStorage = st.derivePartialStorage(), +		}); +	peers.push_back(peer); +	plist.p->push(peer); +	return *peer; +} + +Server::Peer & Server::Priv::addPeer(NetworkProtocol::Connection conn) +{ +	scoped_lock lock(dataMutex); + +	auto st = self.ref()->storage().deriveEphemeralStorage(); +	shared_ptr<Peer> peer(new Peer { +		.server = *this, +		.connection = move(conn),  		.identity = monostate(),  		.identityUpdates = {},  		.channel = monostate(), @@ -695,7 +738,7 @@ void Server::Peer::send(const TransportHeader & header, const vector<Object> & o  		out = std::move(data);  	if (!out.empty()) -		server.protocol.sendto(out, addr); +		connection.send(out);  }  void Server::Peer::updateIdentity(ReplyBuilder &) @@ -831,7 +874,7 @@ void Server::Peer::trySendOutQueue()  	for (const auto & data : secureOutQueue) {  		auto out = std::get<unique_ptr<Channel>>(channel)->encrypt(data); -		server.protocol.sendto(out, addr); +		connection.send(out);  	}  	secureOutQueue.clear(); diff --git a/src/network.h b/src/network.h index c242ac5..74231bf 100644 --- a/src/network.h +++ b/src/network.h @@ -44,7 +44,7 @@ struct Server::Peer  	Peer & operator=(const Peer &) = delete;  	Priv & server; -	const sockaddr_in6 addr; +	NetworkProtocol::Connection connection;  	variant<monostate,  		shared_ptr<struct WaitingRef>, @@ -157,7 +157,9 @@ struct Server::Priv  	void doAnnounce();  	bool isSelfAddress(const sockaddr_in6 & paddr); +	Peer * findPeer(NetworkProtocol::Connection::Id cid) const;  	Peer & getPeer(const sockaddr_in6 & paddr); +	Peer & addPeer(NetworkProtocol::Connection conn);  	void handlePacket(Peer &, const TransportHeader &, ReplyBuilder &);  	void handleLocalHeadChange(const Head<LocalState> &); @@ -165,7 +167,7 @@ struct Server::Priv  	constexpr static uint16_t discoveryPort { 29665 };  	constexpr static chrono::seconds announceInterval { 60 }; -	mutex dataMutex; +	mutable mutex dataMutex;  	condition_variable announceCondvar;  	bool finish = false; diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp index 63cfde5..c247bf0 100644 --- a/src/network/protocol.cpp +++ b/src/network/protocol.cpp @@ -3,10 +3,27 @@  #include <sys/socket.h>  #include <unistd.h> +#include <cstring> +#include <mutex>  #include <system_error> +using std::move; +using std::scoped_lock; +  namespace erebos { +struct NetworkProtocol::ConnectionPriv +{ +	Connection::Id id() const; + +	NetworkProtocol * protocol; +	const sockaddr_in6 peerAddress; + +	mutex cmutex {}; +	vector<uint8_t> buffer {}; +}; + +  NetworkProtocol::NetworkProtocol():  	sock(-1)  {} @@ -32,6 +49,44 @@ NetworkProtocol::~NetworkProtocol()  {  	if (sock >= 0)  		close(sock); + +	for (auto & c : connections) +		c->protocol = nullptr; +} + +NetworkProtocol::PollResult NetworkProtocol::poll() +{ +	sockaddr_in6 addr; +	if (!recvfrom(buffer, addr)) +		return ProtocolClosed {}; + +	scoped_lock lock(protocolMutex); +	for (const auto & c : connections) { +		if (memcmp(&c->peerAddress, &addr, sizeof addr) == 0) { +			scoped_lock clock(c->cmutex); +			buffer.swap(c->buffer); +			return ConnectionReadReady { c->id() }; +		} +	} + +	auto conn = unique_ptr<ConnectionPriv>(new ConnectionPriv { +		.protocol = this, +		.peerAddress = addr, +	}); + +	connections.push_back(conn.get()); +	buffer.swap(conn->buffer); +	return NewConnection { Connection(move(conn)) }; +} + +NetworkProtocol::Connection NetworkProtocol::connect(sockaddr_in6 addr) +{ +	auto conn = unique_ptr<ConnectionPriv>(new ConnectionPriv { +		.protocol = this, +		.peerAddress = addr, +	}); +	connections.push_back(conn.get()); +	return Connection(move(conn));  }  bool NetworkProtocol::recvfrom(vector<uint8_t> & buffer, sockaddr_in6 & addr) @@ -66,4 +121,78 @@ void NetworkProtocol::shutdown()  	::shutdown(sock, SHUT_RDWR);  } + +NetworkProtocol::Connection::Id NetworkProtocol::ConnectionPriv::id() const +{ +	return reinterpret_cast<uintptr_t>(this); +} + +NetworkProtocol::Connection::Connection(unique_ptr<ConnectionPriv> p_): +	p(move(p_)) +{ +} + +NetworkProtocol::Connection::Connection(Connection && other): +	p(move(other.p)) +{ +} + +NetworkProtocol::Connection & NetworkProtocol::Connection::operator=(Connection && other) +{ +	close(); +	p = move(other.p); +	return *this; +} + +NetworkProtocol::Connection::~Connection() +{ +	close(); +} + +NetworkProtocol::Connection::Id NetworkProtocol::Connection::id() const +{ +	return p->id(); +} + +const sockaddr_in6 & NetworkProtocol::Connection::peerAddress() const +{ +	return p->peerAddress; +} + +bool NetworkProtocol::Connection::receive(vector<uint8_t> & buffer) +{ +	scoped_lock lock(p->cmutex); +	if (p->buffer.empty()) +		return false; + +	buffer.swap(p->buffer); +	p->buffer.clear(); +	return true; +} + +bool NetworkProtocol::Connection::send(const vector<uint8_t> & buffer) +{ +	p->protocol->sendto(buffer, p->peerAddress); +	return true; +} + +void NetworkProtocol::Connection::close() +{ +	if (not p) +		return; + +	if (p->protocol) { +		scoped_lock lock(p->protocol->protocolMutex); +		for (auto it = p->protocol->connections.begin(); +				it != p->protocol->connections.end(); it++) { +			if ((*it) == p.get()) { +				p->protocol->connections.erase(it); +				break; +			} +		} +	} + +	p = nullptr; +} +  } diff --git a/src/network/protocol.h b/src/network/protocol.h index 6a22f3b..a9bbaff 100644 --- a/src/network/protocol.h +++ b/src/network/protocol.h @@ -3,10 +3,16 @@  #include <netinet/in.h>  #include <cstdint> +#include <memory> +#include <mutex> +#include <variant>  #include <vector>  namespace erebos { +using std::mutex; +using std::unique_ptr; +using std::variant;  using std::vector;  class NetworkProtocol @@ -20,6 +26,21 @@ public:  	NetworkProtocol & operator=(NetworkProtocol &&);  	~NetworkProtocol(); +	class Connection; + +	struct NewConnection; +	struct ConnectionReadReady; +	struct ProtocolClosed {}; + +	using PollResult = variant< +		NewConnection, +		ConnectionReadReady, +		ProtocolClosed>; + +	PollResult poll(); + +	Connection connect(sockaddr_in6 addr); +  	bool recvfrom(vector<uint8_t> & buffer, sockaddr_in6 & addr);  	void sendto(const vector<uint8_t> & buffer, sockaddr_in addr);  	void sendto(const vector<uint8_t> & buffer, sockaddr_in6 addr); @@ -28,6 +49,40 @@ public:  private:  	int sock; + +	mutex protocolMutex; +	vector<uint8_t> buffer; + +	struct ConnectionPriv; +	vector<ConnectionPriv *> connections; +}; + +class NetworkProtocol::Connection +{ +	friend class NetworkProtocol; +	Connection(unique_ptr<ConnectionPriv> p); +public: +	Connection(const Connection &) = delete; +	Connection(Connection &&); +	Connection & operator=(const Connection &) = delete; +	Connection & operator=(Connection &&); +	~Connection(); + +	using Id = uintptr_t; +	Id id() const; + +	const sockaddr_in6 & peerAddress() const; + +	bool receive(vector<uint8_t> & buffer); +	bool send(const vector<uint8_t> & buffer); + +	void close(); + +private: +	unique_ptr<ConnectionPriv> p;  }; +struct NetworkProtocol::NewConnection { Connection conn; }; +struct NetworkProtocol::ConnectionReadReady { Connection::Id id; }; +  } |