diff options
Diffstat (limited to 'src/network')
| -rw-r--r-- | src/network/protocol.cpp | 251 | ||||
| -rw-r--r-- | src/network/protocol.h | 83 | 
2 files changed, 312 insertions, 22 deletions
| diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp index b781693..dbf1c40 100644 --- a/src/network/protocol.cpp +++ b/src/network/protocol.cpp @@ -15,6 +15,7 @@ using std::move;  using std::nullopt;  using std::runtime_error;  using std::scoped_lock; +using std::to_string;  using std::visit;  namespace erebos { @@ -25,6 +26,7 @@ struct NetworkProtocol::ConnectionPriv  	bool send(const PartialStorage &, Header,  			const vector<Object> &, bool secure); +	bool send( const StreamData & chunk );  	NetworkProtocol * protocol;  	const sockaddr_in6 peerAddress; @@ -37,7 +39,12 @@ struct NetworkProtocol::ConnectionPriv  	ChannelState channel = monostate();  	vector<vector<uint8_t>> secureOutQueue {}; +	size_t mtu = 500; // TODO: MTU +  	vector<uint64_t> toAcknowledge {}; + +	vector< shared_ptr< InStream >> inStreams {}; +	vector< shared_ptr< OutStream >> outStreams {};  }; @@ -80,16 +87,26 @@ NetworkProtocol::PollResult NetworkProtocol::poll()  		scoped_lock lock(protocolMutex);  		for (const auto & c : connections) { +			vector< StreamData > streamChunks; +			bool sendAck = false;  			{  				scoped_lock clock(c->cmutex); -				if (c->toAcknowledge.empty()) -					continue; - -				if (not holds_alternative<unique_ptr<Channel>>(c->channel)) -					continue; +				sendAck = not c->toAcknowledge.empty() && +					holds_alternative< unique_ptr< Channel >>( c->channel ); + +				for (const auto & s : c->outStreams) { +					scoped_lock slock(s->streamMutex); +					while (s->hasDataLocked()) +						streamChunks.push_back( s->getNextChunkLocked( c->mtu ) ); +				} +			} +			if (sendAck) { +				auto pst = self->ref()->storage().deriveEphemeralStorage(); +				c->send(pst, Header {{}}, {}, true); +			} +			for (const auto & chunk : streamChunks) { +				c->send( chunk );  			} -			auto pst = self->ref()->storage().deriveEphemeralStorage(); -			c->send(pst, Header {{}}, {}, true);  		}  	} @@ -110,7 +127,8 @@ NetworkProtocol::PollResult NetworkProtocol::poll()  		auto pst = self->ref()->storage().deriveEphemeralStorage();  		optional<uint64_t> secure = false; -		if (auto header = Connection::parsePacket(buffer, nullptr, pst, secure)) { +		auto parsed = Connection::parsePacket(buffer, nullptr, pst, secure); +		if (const auto * header = get_if< Header >( &parsed )) {  			if (auto conn = verifyNewConnection(*header, addr))  				return NewConnection { move(*conn) }; @@ -335,13 +353,14 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const Par  	}  	optional<uint64_t> secure = false; -	if (auto header = parsePacket(buf, channel, partStorage, secure)) { +	auto parsed = parsePacket(buf, channel, partStorage, secure); +	if (const auto * header = get_if< Header >( &parsed )) {  		scoped_lock lock(p->cmutex);  		if (secure) {  			if (header->isAcknowledged())  				p->toAcknowledge.push_back(*secure); -			return header; +			return *header;  		}  		if (const auto * cookieEcho = header->lookupFirst<Header::CookieEcho>()) { @@ -353,13 +372,13 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const Par  			if (const auto * cookieSet = header->lookupFirst<Header::CookieSet>())  				p->receivedCookie = cookieSet->value; -			return header; +			return *header;  		}  		if (holds_alternative<monostate>(p->channel)) {  			if (const auto * cookieSet = header->lookupFirst<Header::CookieSet>()) {  				p->receivedCookie = cookieSet->value; -				return header; +				return *header;  			}  		} @@ -368,10 +387,36 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const Par  			return nullopt;  		}  	} +	else if( auto * sdata = get_if< StreamData >( &parsed )){ +		scoped_lock lock(p->cmutex); +		if (secure) +			p->toAcknowledge.push_back(*secure); + +		InStream * stream = nullptr; +		for (const auto & s : p->inStreams) { +			if (s->id == sdata->id) { +				stream = s.get(); +				break; +			} +		} +		if (not stream) { +			std::cerr << "unexpected stream number\n"; +			return nullopt; +		} + +		stream->writeChunk( move(*sdata) ); +		if( stream->closed ) +			p->inStreams.erase( +					std::remove_if( p->inStreams.begin(), p->inStreams.end(), +						[&]( auto & sptr ) { return sptr.get() == stream; } ), +					p->inStreams.end() ); +		return nullopt; +	}  	return nullopt;  } -optional<NetworkProtocol::Header> NetworkProtocol::Connection::parsePacket(vector<uint8_t> & buf, +variant< monostate, NetworkProtocol::Header, NetworkProtocol::StreamData > +NetworkProtocol::Connection::parsePacket(vector<uint8_t> & buf,  		Channel * channel, const PartialStorage & partStorage,  		optional<uint64_t> & secure)  { @@ -384,7 +429,7 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::parsePacket(vecto  	if ((buf[0] & 0xE0) == 0x80) {  		if (not channel) {  			std::cerr << "unexpected encrypted packet\n"; -			return nullopt; +			return monostate();  		}  		if ((secure = channel->decrypt(buf.begin() + 1, buf.end(), decrypted, 0))) { @@ -395,9 +440,17 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::parsePacket(vecto  				plainBegin = decrypted.begin() + 1;  				plainEnd = decrypted.end();  			} +			else if (decrypted[0] < 0x40) { +				StreamData sdata; +				sdata.id = decrypted[0]; +				sdata.sequence = decrypted[1]; +				sdata.data.resize( decrypted.size() - 2 ); +				std::copy(decrypted.begin() + 2, decrypted.end(), sdata.data.begin()); +				return sdata; +			}  			else { -				std::cerr << "streams not implemented\n"; -				return nullopt; +				std::cerr << "unexpected stream header\n"; +				return monostate();  			}  		}  	} @@ -414,12 +467,12 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::parsePacket(vecto  				pos = std::get<1>(*cdec);  			} -			return header; +			return *header;  		}  	}  	std::cerr << "invalid packet\n"; -	return nullopt; +	return monostate();  }  bool NetworkProtocol::Connection::send(const PartialStorage & partStorage, @@ -483,6 +536,37 @@ bool NetworkProtocol::ConnectionPriv::send(const PartialStorage & partStorage,  	return true;  } +bool NetworkProtocol::Connection::send( const StreamData & chunk ) +{ +	return p->send( chunk ); +} + +bool NetworkProtocol::ConnectionPriv::send( const StreamData & chunk ) +{ +	vector<uint8_t> data, out; + +	{ +		scoped_lock clock( cmutex ); + +		Channel * channel = nullptr; +		if (auto uptr = get_if< unique_ptr< Channel >>( &this->channel )) +			channel = uptr->get(); +		if (not channel) +			return false; + +		data.push_back( chunk.id ); +		data.push_back( static_cast< uint8_t >( chunk.sequence )); +		data.insert( data.end(), chunk.data.begin(), chunk.data.end() ); + +		out.push_back( 0x80 ); +		channel->encrypt( data.begin(), data.end(), out, 1 ); +	} + +	protocol->sendto( out, peerAddress ); +	return true; +} + +  void NetworkProtocol::Connection::close()  {  	if (not p) @@ -502,6 +586,28 @@ void NetworkProtocol::Connection::close()  	p = nullptr;  } +shared_ptr< NetworkProtocol::InStream > NetworkProtocol::Connection::openInStream( uint8_t sid ) +{ +	scoped_lock lock( p->cmutex ); +	for (const auto & s : p->inStreams) +		if (s->id == sid) +			throw runtime_error("inbound stream " + to_string(sid) + " already open"); + +	p->inStreams.emplace_back( new InStream( sid )); +	return p->inStreams.back(); +} + +shared_ptr< NetworkProtocol::OutStream > NetworkProtocol::Connection::openOutStream( uint8_t sid ) +{ +	scoped_lock lock( p->cmutex ); +	for (const auto & s : p->outStreams) +		if (s->id == sid) +			throw runtime_error("outbound stream " + to_string(sid) + " already open"); + +	p->outStreams.emplace_back( new OutStream( sid )); +	return p->outStreams.back(); +} +  NetworkProtocol::ChannelState & NetworkProtocol::Connection::channel()  {  	return p->channel; @@ -530,6 +636,109 @@ void NetworkProtocol::Connection::trySendOutQueue()  } +bool NetworkProtocol::Stream::hasDataLocked() const +{ +	return not writeBuffer.empty() || not readBuffer.empty(); +} + +size_t NetworkProtocol::Stream::writeLocked( const uint8_t * buf, size_t size ) +{ +	writeBuffer.insert( writeBuffer.end(), buf, buf + size ); +	return size; +} + +size_t NetworkProtocol::Stream::readLocked( uint8_t * buf, size_t size ) +{ +	size_t res = 0; +	if (readPtr < readBuffer.end()) { +		res = std::min( size, static_cast< size_t >( readBuffer.end() - readPtr )); +		std::copy_n( readPtr, res, buf ); +		readPtr += res; +	} +	if (res < size && not writeBuffer.empty()) { +		std::swap( readBuffer, writeBuffer ); +		readPtr = readBuffer.begin(); +		writeBuffer.clear(); +		return res + readLocked( buf + res, size - res ); +	} +	return res; +} + +bool NetworkProtocol::InStream::isComplete() const +{ +	scoped_lock lock( streamMutex ); +	return closed && outOfOrderChunks.empty(); +} + +vector< uint8_t > NetworkProtocol::InStream::readAll() +{ +	scoped_lock lock( streamMutex ); +	if (readBuffer.empty()) { +		vector< uint8_t > res; +		std::swap( res, writeBuffer ); +		return res; +	} + +	readBuffer.insert( readBuffer.end(), writeBuffer.begin(), writeBuffer.end() ); +	writeBuffer.clear(); + +	vector< uint8_t > res; +	std::swap( res, readBuffer ); +	readPtr = readBuffer.begin(); +	return res; +} + +size_t NetworkProtocol::InStream::read( uint8_t * buf, size_t size ) +{ +	scoped_lock lock( streamMutex ); +	return readLocked( buf, size ); +} + +void NetworkProtocol::InStream::writeChunk( StreamData chunk ) +{ +	scoped_lock lock( streamMutex ); +	if( tryUseChunkLocked( chunk )) { +		auto it = outOfOrderChunks.begin(); +		while( it != outOfOrderChunks.end() && tryUseChunkLocked( *it )) +			it++; +		outOfOrderChunks.erase( outOfOrderChunks.begin(), it ); +	} else { +		auto it = outOfOrderChunks.begin(); +		while( it < outOfOrderChunks.end() && +				it->sequence - static_cast< uint8_t >( nextSequence ) +				< chunk.sequence - static_cast< uint8_t >( nextSequence )) +			it++; +		outOfOrderChunks.insert( it, move(chunk) ); +	} +} + +bool NetworkProtocol::InStream::tryUseChunkLocked( const StreamData & chunk ) +{ +	if( chunk.sequence != static_cast< uint8_t >( nextSequence )) +		return false; + +	if( chunk.data.empty() ) +		closed = true; +	else +		writeLocked( chunk.data.data(), chunk.data.size() ); +	nextSequence++; +	return true; +} + +NetworkProtocol::StreamData NetworkProtocol::OutStream::getNextChunkLocked( size_t size ) +{ +	StreamData res; +	res.id = id; +	res.sequence = nextSequence++, + +	res.data.resize( size ); +	size = readLocked( res.data.data(), size ); +	res.data.resize( size ); + +	return res; +} + +  /******************************************************************************/  /* Header                                                                     */  /******************************************************************************/ @@ -600,6 +809,9 @@ optional<NetworkProtocol::Header> NetworkProtocol::Header::load(const PartialObj  		} else if (item.name == "SVR") {  			if (auto ref = item.asRef())  				items.emplace_back(ServiceRef { ref->digest() }); +		} else if (item.name == "STO") { +			if (auto num = item.asInteger()) +				items.emplace_back( StreamOpen{ static_cast< uint8_t >( *num )});  		}  	} @@ -652,6 +864,9 @@ PartialObject NetworkProtocol::Header::toObject(const PartialStorage & st) const  		else if (const auto * ptr = get_if<ServiceRef>(&item))  			ritems.emplace_back("SVR", st.ref(ptr->value)); + +		else if (const auto * ptr = get_if< StreamOpen >( &item )) +			ritems.emplace_back("STO", Record::Item::Integer( ptr->value ));  	}  	return PartialObject(PartialRecord(std::move(ritems))); diff --git a/src/network/protocol.h b/src/network/protocol.h index ba40744..2592c9f 100644 --- a/src/network/protocol.h +++ b/src/network/protocol.h @@ -35,8 +35,12 @@ public:  	static constexpr char defaultVersion[] = "0.1";  	class Connection; +	class Stream; +	class InStream; +	class OutStream;  	struct Header; +	struct StreamData;  	struct ReceivedAnnounce;  	struct NewConnection; @@ -106,21 +110,83 @@ public:  	optional<Header> receive(const PartialStorage &);  	bool send(const PartialStorage &, NetworkProtocol::Header,  			const vector<Object> &, bool secure); +	bool send( const StreamData & chunk );  	void close(); +	shared_ptr< InStream > openInStream( uint8_t sid ); +	shared_ptr< OutStream > openOutStream( uint8_t sid ); +  	// temporary:  	ChannelState & channel();  	void trySendOutQueue();  private: -	static optional<Header> parsePacket(vector<uint8_t> & buf, -			Channel * channel, const PartialStorage & st, -			optional<uint64_t> & secure); +	static variant< monostate, Header, StreamData > +		parsePacket(vector<uint8_t> & buf, +				Channel * channel, const PartialStorage & st, +				optional<uint64_t> & secure);  	unique_ptr<ConnectionPriv> p;  }; +class NetworkProtocol::Stream +{ +	friend class NetworkProtocol; +	friend class NetworkProtocol::Connection; + +protected: +	Stream(uint8_t id_): id( id_ ) {} + +	bool hasDataLocked() const; + +	size_t writeLocked( const uint8_t * buf, size_t size ); +	size_t readLocked( uint8_t * buf, size_t size ); + +	uint8_t id; +	bool closed { false }; +	vector< uint8_t > writeBuffer; +	vector< uint8_t > readBuffer; +	vector< uint8_t >::const_iterator readPtr; +	mutable mutex streamMutex; +}; + +class NetworkProtocol::InStream : public NetworkProtocol::Stream +{ +	friend class NetworkProtocol; +	friend class NetworkProtocol::Connection; + +protected: +	InStream(uint8_t id): Stream( id ) {} + +public: +	bool isComplete() const; +	vector< uint8_t > readAll(); +	size_t read( uint8_t * buf, size_t size ); + +protected: +	void writeChunk( StreamData chunk ); +	bool tryUseChunkLocked( const StreamData & chunk ); + +private: +	uint64_t nextSequence { 0 }; +	vector< StreamData > outOfOrderChunks; +}; + +class NetworkProtocol::OutStream : public NetworkProtocol::Stream +{ +	friend class NetworkProtocol; +	friend class NetworkProtocol::Connection; + +protected: +	OutStream(uint8_t id): Stream( id ) {} + +private: +	StreamData getNextChunkLocked( size_t size ); + +	uint64_t nextSequence { 0 }; +}; +  struct NetworkProtocol::ReceivedAnnounce { sockaddr_in6 addr; Digest digest; };  struct NetworkProtocol::NewConnection { Connection conn; };  struct NetworkProtocol::ConnectionReadReady { Connection::Id id; }; @@ -141,6 +207,7 @@ struct NetworkProtocol::Header  	struct ChannelAccept { Digest value; };  	struct ServiceType { UUID value; };  	struct ServiceRef { Digest value; }; +	struct StreamOpen { uint8_t value; };  	using Item = variant<  		Acknowledged, @@ -156,7 +223,8 @@ struct NetworkProtocol::Header  		ChannelRequest,  		ChannelAccept,  		ServiceType, -		ServiceRef>; +		ServiceRef, +		StreamOpen>;  	Header(const vector<Item> & items): items(items) {}  	static optional<Header> load(const PartialRef &); @@ -169,6 +237,13 @@ struct NetworkProtocol::Header  	vector<Item> items;  }; +struct NetworkProtocol::StreamData +{ +	uint8_t id; +	uint8_t sequence; +	vector< uint8_t > data; +}; +  template<class T>  const T * NetworkProtocol::Header::lookupFirst() const  { |