#pragma once

#include "channel.h"

#include <erebos/storage.h>

#include <netinet/in.h>

#include <cstdint>
#include <memory>
#include <mutex>
#include <variant>
#include <vector>
#include <optional>

namespace erebos {

// Standard-library names that the declarations below use unqualified,
// grouped by the header that provides them.

// <memory>
using std::unique_ptr;

// <mutex>
using std::mutex;

// <optional>
using std::optional;

// <variant>
using std::variant;

// <vector>
using std::vector;

/**
 * Top-level object of the network protocol declared in this header.
 *
 * The class stores a socket descriptor, a mutex, a byte buffer, an optional
 * local Identity and a list of raw pointers to the private state of its
 * connections. It is move-only: copy construction and copy assignment are
 * deleted, move construction and move assignment are declared.
 *
 * Only declarations are visible in this header; all member functions are
 * defined elsewhere, so the comments below describe the interface rather
 * than the behavior.
 */
class NetworkProtocol
{
public:
	NetworkProtocol();
	// Takes a socket descriptor and the local identity (compare the `sock`
	// and `self` members below).
	// NOTE(review): whether the destructor closes `sock` is not visible
	// here -- confirm ownership of the descriptor in the implementation.
	explicit NetworkProtocol(int sock, Identity self);
	NetworkProtocol(const NetworkProtocol &) = delete;
	NetworkProtocol(NetworkProtocol &&);
	NetworkProtocol & operator=(const NetworkProtocol &) = delete;
	NetworkProtocol & operator=(NetworkProtocol &&);
	~NetworkProtocol();

	// NOTE(review): presumably the value carried by Header::Version --
	// confirm against the implementation.
	static constexpr char defaultVersion[] = "0.1";

	// Nested types, defined after this class further down in this header.
	class Connection;
	class Stream;
	class InStream;
	class OutStream;

	struct Header;
	struct StreamData;

	// Alternatives of PollResult. The first three are defined further down
	// in this header; ProtocolClosed is an empty tag type.
	struct ReceivedAnnounce;
	struct NewConnection;
	struct ConnectionReadReady;
	struct ProtocolClosed {};

	// Value returned by poll(): exactly one of the four alternatives.
	using PollResult = variant<
		ReceivedAnnounce,
		NewConnection,
		ConnectionReadReady,
		ProtocolClosed>;

	// NOTE(review): presumably waits for the next protocol event -- whether
	// it blocks is not visible here; confirm in the implementation.
	PollResult poll();

	// Cookie value as a plain byte vector (see generateCookie() and
	// verifyCookie() below).
	struct Cookie { vector<uint8_t> value; };

	// Channel state of a connection, exposed through Connection::channel().
	// `struct WaitingRef` is named through an elaborated type specifier, so
	// this header does not need its definition.
	using ChannelState = variant<monostate,
		Stored<ChannelRequest>,
		shared_ptr<struct WaitingRef>,
		Stored<ChannelAccept>,
		unique_ptr<Channel>>;

	// Returns a Connection handle for the given IPv6 socket address.
	Connection connect(sockaddr_in6 addr);

	void updateIdentity(Identity self);
	// Accepts either an IPv4 or an IPv6 socket address.
	void announceTo(variant<sockaddr_in, sockaddr_in6> addr);

	void shutdown();

private:
	// NOTE(review): presumably thin wrappers around the socket calls of the
	// same name -- confirm. Note the asymmetry: recvfrom() reports the peer
	// as sockaddr_in6 only, while sendto() accepts IPv4 or IPv6.
	bool recvfrom(vector<uint8_t> & buffer, sockaddr_in6 & addr);
	void sendto(const vector<uint8_t> & buffer, variant<sockaddr_in, sockaddr_in6> addr);

	void sendCookie(variant<sockaddr_in, sockaddr_in6> addr);
	// Returns an empty optional when no connection results from `header`.
	// NOTE(review): the exact acceptance criteria live in the
	// implementation.
	optional<Connection> verifyNewConnection(const Header & header, sockaddr_in6 addr);

	// Both cookie helpers are const: they do not modify the protocol object.
	Cookie generateCookie(variant<sockaddr_in, sockaddr_in6> addr) const;
	bool verifyCookie(variant<sockaddr_in, sockaddr_in6> addr, const Cookie & cookie) const;

	int sock;

	// NOTE(review): which members protocolMutex protects is not visible in
	// this header -- confirm in the implementation.
	mutex protocolMutex;
	vector<uint8_t> buffer;

	// Local identity; optional, so a protocol object can exist without one.
	optional<Identity> self;

	// Private per-connection state; only declared here.
	struct ConnectionPriv;
	// Raw pointers to per-connection state. Connection keeps its
	// ConnectionPriv in a unique_ptr (Connection::p).
	// NOTE(review): these therefore look like non-owning back-references --
	// confirm how entries are removed when a Connection is destroyed.
	vector<ConnectionPriv *> connections;
};

/**
 * Move-only handle for one connection of a NetworkProtocol.
 *
 * All state sits behind `p`, a unique_ptr to NetworkProtocol::ConnectionPriv,
 * which is declared privately in NetworkProtocol and defined elsewhere. The
 * constructor is private and NetworkProtocol is a friend, so instances are
 * created by NetworkProtocol (for example as the return value of
 * NetworkProtocol::connect()).
 *
 * Only declarations are visible in this header; the member functions are
 * defined elsewhere.
 */
class NetworkProtocol::Connection
{
	friend class NetworkProtocol;
	// Private: takes ownership of the given private state object.
	Connection(unique_ptr<ConnectionPriv> p);
public:
	Connection(const Connection &) = delete;
	Connection(Connection &&);
	Connection & operator=(const Connection &) = delete;
	Connection & operator=(Connection &&);
	~Connection();

	// Integer identifier of a connection; ConnectionReadReady carries a
	// value of this type.
	// NOTE(review): uintptr_t suggests the id is derived from a pointer
	// value -- confirm in the implementation.
	using Id = uintptr_t;
	Id id() const;

	const sockaddr_in6 & peerAddress() const;
	size_t mtu() const;

	// Returns a Header, or an empty optional when there is none to return.
	optional<Header> receive(const PartialStorage &);
	// Sends a header together with a list of objects. The meaning of the
	// bool result and of `secure` is defined by the implementation.
	bool send(const PartialStorage &, NetworkProtocol::Header,
			const vector<Object> &, bool secure);
	// Sends a single stream data chunk.
	bool send( const StreamData & chunk );

	void close();

	// Stream handles are shared_ptr-owned; the incoming variant takes the
	// stream id from the caller, the outgoing variant takes no id.
	shared_ptr< InStream > openInStream( uint8_t sid );
	shared_ptr< OutStream > openOutStream();

	// temporary:
	ChannelState & channel();
	void trySendOutQueue();

private:
	// Static helper yielding either nothing (monostate), a Header or a
	// StreamData for the given buffer. `buf` and `secure` are taken by
	// non-const reference, so the function may modify both.
	static variant< monostate, Header, StreamData >
		parsePacket(vector<uint8_t> & buf,
				Channel * channel, const PartialStorage & st,
				optional<uint64_t> & secure);

	// Sole data member: the owned private state.
	unique_ptr<ConnectionPriv> p;
};

/**
 * Common base of InStream and OutStream.
 *
 * Holds a constant stream id, a closed flag, a write buffer, a read buffer
 * with a read position iterator, and a mutex. The constructor is protected,
 * so objects are created through derived classes or by the befriended
 * NetworkProtocol / NetworkProtocol::Connection.
 *
 * The class declares no virtual functions and no virtual destructor, so a
 * derived object must not be deleted through a `Stream *`.
 */
class NetworkProtocol::Stream
{
	friend class NetworkProtocol;
	friend class NetworkProtocol::Connection;

protected:
	Stream(uint8_t id_);

public:
	void close();

protected:
	// NOTE(review): the `Locked` suffix presumably means the caller must
	// already hold streamMutex -- confirm in the implementation.
	bool hasDataLocked() const;

	// Both return a byte count as size_t.
	size_t writeLocked( const uint8_t * buf, size_t size );
	size_t readLocked( uint8_t * buf, size_t size );

public:
	// Stream identifier; const, so it is fixed at construction.
	const uint8_t id;

protected:
	bool closed { false };
	vector< uint8_t > writeBuffer;
	vector< uint8_t > readBuffer;
	// NOTE(review): presumably a position inside readBuffer. If so, any
	// operation that reallocates readBuffer invalidates it -- verify that
	// the implementation refreshes it accordingly.
	vector< uint8_t >::const_iterator readPtr;
	// mutable so that const member functions can lock it.
	mutable mutex streamMutex;
};

/**
 * Receiving side of a stream, derived from Stream.
 *
 * Adds a 64-bit `nextSequence` counter and a list of chunks stored in
 * `outOfOrderChunks`. The constructor is protected; NetworkProtocol and
 * NetworkProtocol::Connection are friends (see Connection::openInStream()).
 */
class NetworkProtocol::InStream : public NetworkProtocol::Stream
{
	friend class NetworkProtocol;
	friend class NetworkProtocol::Connection;

protected:
	// Forwards the stream id to the Stream base class.
	InStream(uint8_t id): Stream( id ) {}

public:
	bool isComplete() const;
	// Returns the stream contents as one byte vector.
	// NOTE(review): whether this waits for completion is not visible here.
	vector< uint8_t > readAll();
	// Reads into `buf`, which has room for `size` bytes; returns a byte count.
	size_t read( uint8_t * buf, size_t size );

protected:
	// Takes the chunk by value.
	void writeChunk( StreamData chunk );
	// NOTE(review): presumably returns whether the chunk could be consumed
	// at the current sequence position; the `Locked` suffix presumably means
	// the caller holds streamMutex -- confirm both.
	bool tryUseChunkLocked( const StreamData & chunk );

private:
	uint64_t nextSequence { 0 };
	// NOTE(review): by its name, holds chunks that arrived ahead of
	// nextSequence -- confirm in the implementation.
	vector< StreamData > outOfOrderChunks;
};

/**
 * Sending side of a stream, derived from Stream.
 *
 * Adds a 64-bit `nextSequence` counter. The constructor is protected;
 * NetworkProtocol and NetworkProtocol::Connection are friends (see
 * Connection::openOutStream()).
 */
class NetworkProtocol::OutStream : public NetworkProtocol::Stream
{
	friend class NetworkProtocol;
	friend class NetworkProtocol::Connection;

protected:
	// Forwards the stream id to the Stream base class.
	OutStream(uint8_t id): Stream( id ) {}

public:
	// Writes up to `size` bytes from `buf`; returns a byte count.
	size_t write( const uint8_t * buf, size_t size );

private:
	// Produces the next StreamData chunk for the given size.
	// NOTE(review): the `Locked` suffix presumably means the caller holds
	// streamMutex -- confirm in the implementation.
	StreamData getNextChunkLocked( size_t size );

	uint64_t nextSequence { 0 };
};

/// PollResult alternative: a socket address paired with a digest.
struct NetworkProtocol::ReceivedAnnounce
{
	sockaddr_in6 addr;
	Digest digest;
};
/// PollResult alternative: carries the Connection handle itself (move-only).
struct NetworkProtocol::NewConnection
{
	Connection conn;
};
/// PollResult alternative: refers to a connection by its Connection::Id.
struct NetworkProtocol::ConnectionReadReady
{
	Connection::Id id;
};

/**
 * A protocol header: an ordered list of typed items.
 *
 * Each item kind has its own wrapper struct with a single `value` field.
 * The distinct wrapper types let the same payload type (for example Digest)
 * appear as several different alternatives of `Item`.
 */
struct NetworkProtocol::Header
{
	// Item kinds. Header::ChannelRequest and Header::ChannelAccept are nested
	// types of Header and are distinct from the ChannelRequest /
	// ChannelAccept types named in NetworkProtocol::ChannelState.
	struct Acknowledged { Digest value; };
	struct AcknowledgedSingle { uint64_t value; };
	struct Version { string value; };
	struct Initiation { Digest value; };
	struct CookieSet { Cookie value; };
	struct CookieEcho { Cookie value; };
	struct DataRequest { Digest value; };
	struct DataResponse { Digest value; };
	struct AnnounceSelf { Digest value; };
	struct AnnounceUpdate { Digest value; };
	struct ChannelRequest { Digest value; };
	struct ChannelAccept { Digest value; };
	struct ServiceType { UUID value; };
	struct ServiceRef { Digest value; };
	struct StreamOpen { uint8_t value; };

	// One header item: exactly one of the kinds above.
	using Item = variant<
		Acknowledged,
		AcknowledgedSingle,
		Version,
		Initiation,
		CookieSet,
		CookieEcho,
		DataRequest,
		DataResponse,
		AnnounceSelf,
		AnnounceUpdate,
		ChannelRequest,
		ChannelAccept,
		ServiceType,
		ServiceRef,
		StreamOpen>;

	static constexpr size_t itemSize = 78; /* estimate for size of ref-containing headers */

	// Copies the given items into the header. The constructor is not
	// explicit, so a vector<Item> converts implicitly to a Header.
	Header(const vector<Item> & items): items(items) {}
	// Build a header from a stored representation; an empty optional is
	// returned when no header is produced.
	// NOTE(review): the exact failure conditions live in the implementation.
	static optional<Header> load(const PartialRef &);
	static optional<Header> load(const PartialObject &);
	// Converts the header to a PartialObject for the given storage.
	PartialObject toObject(const PartialStorage &) const;

	// Returns a pointer to the first item holding alternative T, or nullptr
	// when no item holds T (defined below in this header).
	template<class T> const T * lookupFirst() const;
	// NOTE(review): presumably true when the header contains an
	// acknowledgement item -- confirm in the implementation.
	bool isAcknowledged() const;

	vector<Item> items;
};

/**
 * One chunk of stream payload: stream id, sequence number and data bytes.
 *
 * NOTE(review): `sequence` is 8 bits wide, while InStream::nextSequence and
 * OutStream::nextSequence are uint64_t. Verify that the implementation
 * handles the wrap-around after 256 chunks when converting between the two.
 */
struct NetworkProtocol::StreamData
{
	uint8_t id;
	uint8_t sequence;
	vector< uint8_t > data;
};

/**
 * Finds the first header item that holds alternative T.
 *
 * Items are scanned in their stored order. The returned pointer refers into
 * `items`; nullptr is returned when no item holds T.
 */
template<class T>
const T * NetworkProtocol::Header::lookupFirst() const
{
	for (auto it = items.begin(); it != items.end(); ++it) {
		const T * found = std::get_if<T>(&*it);
		if (found == nullptr)
			continue;
		return found;
	}
	return nullptr;
}

/// Equality of two header items; only declared here, the definition is not
/// part of this header.
bool operator==(const NetworkProtocol::Header::Item &, const NetworkProtocol::Header::Item &);

/// Inequality of two header items, expressed as the negation of operator==.
inline bool operator!=(const NetworkProtocol::Header::Item & lhs,
		const NetworkProtocol::Header::Item & rhs)
{
	return !(lhs == rhs);
}

inline bool operator==(const NetworkProtocol::Cookie & left, const NetworkProtocol::Cookie & right)
{ return left.value == right.value; }

}