From 09015df5e93de837bdbe0ad87469762dbdda4e6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roman=20Smr=C5=BE?= Date: Sat, 13 Nov 2021 22:14:54 +0100 Subject: Pairing: properly handle lingering threads after server stops --- src/pairing.cpp | 132 +++++++++++++++++++++++++++++++++----------------------- 1 file changed, 78 insertions(+), 54 deletions(-) (limited to 'src') diff --git a/src/pairing.cpp b/src/pairing.cpp index 4ae215b..c8e2d9f 100644 --- a/src/pairing.cpp +++ b/src/pairing.cpp @@ -2,6 +2,7 @@ #include "service.h" +#include #include #include @@ -15,6 +16,22 @@ using namespace erebos; using std::lock_guard; using std::runtime_error; +using std::scoped_lock; +using std::thread; + +PairingServiceBase::~PairingServiceBase() +{ + // There may be some threads in waitForConfirmation waiting on client + // promise, so make sure they do not touch the service state anymore: + for (auto & [peer, state] : peerStates) { + scoped_lock lock(state->lock); + if (state->phase != StatePhase::PairingDone && + state->phase != StatePhase::PairingFailed) { + state->success.set_value(false); + state->phase = StatePhase::PairingFailed; + } + } +} void PairingServiceBase::onRequestInit(RequestInitHook hook) { @@ -51,27 +68,28 @@ void PairingServiceBase::handle(Context & ctx) throw runtime_error("Pairing request for peer without known identity"); lock_guard lock(stateLock); - auto & state = peerStates.try_emplace(ctx.peer(), State()).first->second; + auto & state = peerStates.try_emplace(ctx.peer(), new State()).first->second; + scoped_lock lock2(state->lock); if (auto request = rec->item("request").asBinary()) { - if (state.phase != StatePhase::NoPairing) + if (state->phase != StatePhase::NoPairing) return; if (requestInitHook) requestInitHook(ctx.peer()); - state.phase = StatePhase::PeerRequest; - state.peerCheck = *request; - state.nonce.resize(32); - RAND_bytes(state.nonce.data(), state.nonce.size()); + state->phase = StatePhase::PeerRequest; + state->peerCheck = *request; + state->nonce.resize(32); + RAND_bytes(state->nonce.data(), state->nonce.size()); ctx.peer().send(uuid(), Object(Record({ - { "response", state.nonce }, + { "response", state->nonce }, }))); } else if (auto response = rec->item("response").asBinary()) { - if (state.phase != StatePhase::OurRequest) { + if (state->phase != StatePhase::OurRequest) { fprintf(stderr, "Unexpected pairing response.\n"); // TODO return; } @@ -79,15 +97,15 @@ void PairingServiceBase::handle(Context & ctx) if (responseHook) { string confirm = confirmationNumber(nonceDigest( ctx.peer().server().identity(), *pid, - state.nonce, *response)); + state->nonce, *response)); std::thread(&PairingServiceBase::waitForConfirmation, - this, ctx.peer(), confirm).detach(); + this, ctx.peer(), state, confirm, responseHook).detach(); } - state.phase = StatePhase::OurRequestConfirm; + state->phase = StatePhase::OurRequestConfirm; ctx.peer().send(uuid(), Object(Record({ - { "reqnonce", state.nonce }, + { "reqnonce", state->nonce }, }))); } @@ -95,12 +113,12 @@ void PairingServiceBase::handle(Context & ctx) auto check = nonceDigest( *pid, ctx.peer().server().identity(), *reqnonce, vector()); - if (check != state.peerCheck) { + if (check != state->peerCheck) { if (requestNonceFailedHook) requestNonceFailedHook(ctx.peer()); - if (state.phase < StatePhase::PairingDone) { - state.phase = StatePhase::PairingFailed; - state.success.set_value(false); + if (state->phase < StatePhase::PairingDone) { + state->phase = StatePhase::PairingFailed; + state->success.set_value(false); } return; } @@ -108,26 +126,26 @@ void PairingServiceBase::handle(Context & ctx) if (requestHook) { string confirm = confirmationNumber(nonceDigest( *pid, ctx.peer().server().identity(), - *reqnonce, state.nonce)); + *reqnonce, state->nonce)); std::thread(&PairingServiceBase::waitForConfirmation, - this, ctx.peer(), confirm).detach(); + this, ctx.peer(), state, confirm, requestHook).detach(); } - state.phase = StatePhase::PeerRequestConfirm; + state->phase = StatePhase::PeerRequestConfirm; } else if (auto decline = rec->item("decline").asText()) { - if (state.phase < StatePhase::PairingDone) { - state.phase = StatePhase::PairingFailed; - state.success.set_value(false); + if (state->phase < StatePhase::PairingDone) { + state->phase = StatePhase::PairingFailed; + state->success.set_value(false); } } else { - if (state.phase == StatePhase::OurRequestReady) { + if (state->phase == StatePhase::OurRequestReady) { handlePairingResult(ctx); - state.phase = StatePhase::PairingDone; - state.success.set_value(true); + state->phase = StatePhase::PairingDone; + state->success.set_value(true); } else { result = ctx.ref(); } @@ -141,16 +159,16 @@ void PairingServiceBase::requestPairing(UUID serviceId, const Peer & peer) throw runtime_error("Pairing request for peer without known identity"); lock_guard lock(stateLock); - auto & state = peerStates.try_emplace(peer, State()).first->second; + auto & state = peerStates.try_emplace(peer, new State()).first->second; - state.phase = StatePhase::OurRequest; - state.nonce.resize(32); - RAND_bytes(state.nonce.data(), state.nonce.size()); + state->phase = StatePhase::OurRequest; + state->nonce.resize(32); + RAND_bytes(state->nonce.data(), state->nonce.size()); vector items; items.emplace_back("request", nonceDigest( peer.server().identity(), *pid, - state.nonce, vector())); + state->nonce, vector())); peer.send(serviceId, Object(Record(std::move(items)))); } @@ -179,28 +197,34 @@ string PairingServiceBase::confirmationNumber(const vector & digest) return ret; } -void PairingServiceBase::waitForConfirmation(Peer peer, string confirm) +void PairingServiceBase::waitForConfirmation(Peer peer, weak_ptr wstate, string confirm, ConfirmHook hook) { - ConfirmHook hook; future success; - { - lock_guard lock(stateLock); - auto & state = peerStates.try_emplace(peer, State()).first->second; - if (state.phase == StatePhase::OurRequestConfirm) - hook = responseHook; - if (state.phase == StatePhase::PeerRequestConfirm) - hook = requestHook; - - success = state.success.get_future(); + if (auto state = wstate.lock()) { + success = state->success.get_future(); + } else { + return; } - bool ok = hook(peer, confirm, std::move(success)).get(); + bool ok; + try { + ok = hook(peer, confirm, std::move(success)).get(); + } + catch (const std::future_error & e) { + if (e.code() == std::future_errc::broken_promise) + ok = false; + else + throw; + } - lock_guard lock(stateLock); - auto & state = peerStates.try_emplace(peer, State()).first->second; + auto state = wstate.lock(); + if (!state) + return; // Server was closed + + scoped_lock lock(state->lock); if (ok) { - if (state.phase == StatePhase::OurRequestConfirm) { + if (state->phase == StatePhase::OurRequestConfirm) { if (result) { peer.server().localHead().update([&] (const Stored & local) { Service::Context ctx(new Service::Context::Priv { @@ -212,21 +236,21 @@ void PairingServiceBase::waitForConfirmation(Peer peer, string confirm) handlePairingResult(ctx); return ctx.local(); }); - state.phase = StatePhase::PairingDone; - state.success.set_value(true); + state->phase = StatePhase::PairingDone; + state->success.set_value(true); } else { - state.phase = StatePhase::OurRequestReady; + state->phase = StatePhase::OurRequestReady; } - } else if (state.phase == StatePhase::PeerRequestConfirm) { + } else if (state->phase == StatePhase::PeerRequestConfirm) { peer.send(uuid(), handlePairingCompleteRef(peer)); - state.phase = StatePhase::PairingDone; - state.success.set_value(true); + state->phase = StatePhase::PairingDone; + state->success.set_value(true); } } else { - if (state.phase != StatePhase::PairingFailed) { + if (state->phase != StatePhase::PairingFailed) { peer.send(uuid(), Object(Record({{ "decline", string() }}))); - state.phase = StatePhase::PairingFailed; - state.success.set_value(false); + state->phase = StatePhase::PairingFailed; + state->success.set_value(false); } } } -- cgit v1.2.3