diff options
| -rw-r--r-- | include/erebos/pairing.h | 8 | ||||
| -rw-r--r-- | src/pairing.cpp | 132 | 
2 files changed, 84 insertions, 56 deletions
| diff --git a/include/erebos/pairing.h b/include/erebos/pairing.h index 4457426..b8b680a 100644 --- a/include/erebos/pairing.h +++ b/include/erebos/pairing.h @@ -29,6 +29,8 @@ using std::vector;  class PairingServiceBase : public Service  {  public: +	virtual ~PairingServiceBase(); +  	typedef function<void(const Peer &)> RequestInitHook;  	void onRequestInit(RequestInitHook); @@ -49,7 +51,6 @@ private:  	static vector<uint8_t> nonceDigest(const Identity & id1, const Identity & id2,  			const vector<uint8_t> & nonce1, const vector<uint8_t> & nonce2);  	static string confirmationNumber(const vector<uint8_t> &); -	void waitForConfirmation(Peer peer, string confirm);  	RequestInitHook requestInitHook;  	ConfirmHook responseHook; @@ -70,14 +71,17 @@ private:  	};  	struct State { +		mutex lock;  		StatePhase phase;  		vector<uint8_t> nonce;  		vector<uint8_t> peerCheck;  		promise<bool> success;  	}; -	map<Peer, State> peerStates; +	map<Peer, shared_ptr<State>> peerStates;  	mutex stateLock; + +	void waitForConfirmation(Peer peer, weak_ptr<State> state, string confirm, ConfirmHook hook);  };  template<class Result> diff --git a/src/pairing.cpp b/src/pairing.cpp index 4ae215b..c8e2d9f 100644 --- a/src/pairing.cpp +++ b/src/pairing.cpp @@ -2,6 +2,7 @@  #include "service.h" +#include <future>  #include <openssl/rand.h>  #include <arpa/inet.h> @@ -15,6 +16,22 @@ using namespace erebos;  using std::lock_guard;  using std::runtime_error; +using std::scoped_lock; +using std::thread; + +PairingServiceBase::~PairingServiceBase() +{ +	// There may be some threads in waitForConfirmation waiting on client +	// promise, so make sure they do not touch the service state anymore: +	for (auto & [peer, state] : peerStates) { +		scoped_lock lock(state->lock); +		if (state->phase != StatePhase::PairingDone && +				state->phase != StatePhase::PairingFailed) { +			state->success.set_value(false); +			state->phase = StatePhase::PairingFailed; +		} +	} +}  void PairingServiceBase::onRequestInit(RequestInitHook hook)  { @@ -51,27 +68,28 @@ void PairingServiceBase::handle(Context & ctx)  		throw runtime_error("Pairing request for peer without known identity");  	lock_guard lock(stateLock); -	auto & state = peerStates.try_emplace(ctx.peer(), State()).first->second; +	auto & state = peerStates.try_emplace(ctx.peer(), new State()).first->second; +	scoped_lock lock2(state->lock);  	if (auto request = rec->item("request").asBinary()) { -		if (state.phase != StatePhase::NoPairing) +		if (state->phase != StatePhase::NoPairing)  			return;  		if (requestInitHook)  			requestInitHook(ctx.peer()); -		state.phase = StatePhase::PeerRequest; -		state.peerCheck = *request; -		state.nonce.resize(32); -		RAND_bytes(state.nonce.data(), state.nonce.size()); +		state->phase = StatePhase::PeerRequest; +		state->peerCheck = *request; +		state->nonce.resize(32); +		RAND_bytes(state->nonce.data(), state->nonce.size());  		ctx.peer().send(uuid(), Object(Record({ -			{ "response", state.nonce }, +			{ "response", state->nonce },  		})));  	}  	else if (auto response = rec->item("response").asBinary()) { -		if (state.phase != StatePhase::OurRequest) { +		if (state->phase != StatePhase::OurRequest) {  			fprintf(stderr, "Unexpected pairing response.\n"); // TODO  			return;  		} @@ -79,15 +97,15 @@ void PairingServiceBase::handle(Context & ctx)  		if (responseHook) {  			string confirm = confirmationNumber(nonceDigest(  				ctx.peer().server().identity(), *pid,  -				state.nonce, *response)); +				state->nonce, *response));  			std::thread(&PairingServiceBase::waitForConfirmation, -					this, ctx.peer(), confirm).detach(); +					this, ctx.peer(), state, confirm, responseHook).detach();  		} -		state.phase = StatePhase::OurRequestConfirm; +		state->phase = StatePhase::OurRequestConfirm;  		ctx.peer().send(uuid(), Object(Record({ -			{ "reqnonce", state.nonce }, +			{ "reqnonce", state->nonce },  		})));  	} @@ -95,12 +113,12 @@ void PairingServiceBase::handle(Context & ctx)  		auto check = nonceDigest(  				*pid, ctx.peer().server().identity(),  				*reqnonce, vector<uint8_t>()); -		if (check != state.peerCheck) { +		if (check != state->peerCheck) {  			if (requestNonceFailedHook)  				requestNonceFailedHook(ctx.peer()); -			if (state.phase < StatePhase::PairingDone) { -				state.phase = StatePhase::PairingFailed; -				state.success.set_value(false); +			if (state->phase < StatePhase::PairingDone) { +				state->phase = StatePhase::PairingFailed; +				state->success.set_value(false);  			}  			return;  		} @@ -108,26 +126,26 @@ void PairingServiceBase::handle(Context & ctx)  		if (requestHook) {  			string confirm = confirmationNumber(nonceDigest(  				*pid, ctx.peer().server().identity(), -				*reqnonce, state.nonce)); +				*reqnonce, state->nonce));  			std::thread(&PairingServiceBase::waitForConfirmation, -					this, ctx.peer(), confirm).detach(); +					this, ctx.peer(), state, confirm, requestHook).detach();  		} -		state.phase = StatePhase::PeerRequestConfirm; +		state->phase = StatePhase::PeerRequestConfirm;  	}  	else if (auto decline = rec->item("decline").asText()) { -		if (state.phase < StatePhase::PairingDone) { -			state.phase = StatePhase::PairingFailed; -			state.success.set_value(false); +		if (state->phase < StatePhase::PairingDone) { +			state->phase = StatePhase::PairingFailed; +			state->success.set_value(false);  		}  	}  	else { -		if (state.phase == StatePhase::OurRequestReady) { +		if (state->phase == StatePhase::OurRequestReady) {  			handlePairingResult(ctx); -			state.phase = StatePhase::PairingDone; -			state.success.set_value(true); +			state->phase = StatePhase::PairingDone; +			state->success.set_value(true);  		} else {  			result = ctx.ref();  		} @@ -141,16 +159,16 @@ void PairingServiceBase::requestPairing(UUID serviceId, const Peer & peer)  		throw runtime_error("Pairing request for peer without known identity");  	lock_guard lock(stateLock); -	auto & state = peerStates.try_emplace(peer, State()).first->second; +	auto & state = peerStates.try_emplace(peer, new State()).first->second; -	state.phase = StatePhase::OurRequest; -	state.nonce.resize(32); -	RAND_bytes(state.nonce.data(), state.nonce.size()); +	state->phase = StatePhase::OurRequest; +	state->nonce.resize(32); +	RAND_bytes(state->nonce.data(), state->nonce.size());  	vector<Record::Item> items;  	items.emplace_back("request", nonceDigest(  				peer.server().identity(), *pid, -				state.nonce, vector<uint8_t>())); +				state->nonce, vector<uint8_t>()));  	peer.send(serviceId, Object(Record(std::move(items))));  } @@ -179,28 +197,34 @@ string PairingServiceBase::confirmationNumber(const vector<uint8_t> & digest)  	return ret;  } -void PairingServiceBase::waitForConfirmation(Peer peer, string confirm) +void PairingServiceBase::waitForConfirmation(Peer peer, weak_ptr<State> wstate, string confirm, ConfirmHook hook)  { -	ConfirmHook hook;  	future<bool> success; -	{ -		lock_guard lock(stateLock); -		auto & state = peerStates.try_emplace(peer, State()).first->second; -		if (state.phase == StatePhase::OurRequestConfirm) -			hook = responseHook; -		if (state.phase == StatePhase::PeerRequestConfirm) -			hook = requestHook; - -		success = state.success.get_future(); +	if (auto state = wstate.lock()) { +		success = state->success.get_future(); +	} else { +		return;  	} -	bool ok = hook(peer, confirm, std::move(success)).get(); +	bool ok; +	try { +		ok = hook(peer, confirm, std::move(success)).get(); +	} +	catch (const std::future_error & e) { +		if (e.code() == std::future_errc::broken_promise) +			ok = false; +		else +			throw; +	} -	lock_guard lock(stateLock); -	auto & state = peerStates.try_emplace(peer, State()).first->second; +	auto state = wstate.lock(); +	if (!state) +		return; // Server was closed + +	scoped_lock lock(state->lock);  	if (ok) { -		if (state.phase == StatePhase::OurRequestConfirm) { +		if (state->phase == StatePhase::OurRequestConfirm) {  			if (result) {  				peer.server().localHead().update([&] (const Stored<LocalState> & local) {  					Service::Context ctx(new Service::Context::Priv { @@ -212,21 +236,21 @@ void PairingServiceBase::waitForConfirmation(Peer peer, string confirm)  					handlePairingResult(ctx);  					return ctx.local();  				}); -				state.phase = StatePhase::PairingDone; -				state.success.set_value(true); +				state->phase = StatePhase::PairingDone; +				state->success.set_value(true);  			} else { -				state.phase = StatePhase::OurRequestReady; +				state->phase = StatePhase::OurRequestReady;  			} -		} else if (state.phase == StatePhase::PeerRequestConfirm) { +		} else if (state->phase == StatePhase::PeerRequestConfirm) {  			peer.send(uuid(), handlePairingCompleteRef(peer)); -			state.phase = StatePhase::PairingDone; -			state.success.set_value(true); +			state->phase = StatePhase::PairingDone; +			state->success.set_value(true);  		}  	} else { -		if (state.phase != StatePhase::PairingFailed) { +		if (state->phase != StatePhase::PairingFailed) {  			peer.send(uuid(), Object(Record({{ "decline", string() }}))); -			state.phase = StatePhase::PairingFailed; -			state.success.set_value(false); +			state->phase = StatePhase::PairingFailed; +			state->success.set_value(false);  		}  	}  } |