summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/erebos/pairing.h8
-rw-r--r--src/pairing.cpp132
2 files changed, 84 insertions, 56 deletions
diff --git a/include/erebos/pairing.h b/include/erebos/pairing.h
index 4457426..b8b680a 100644
--- a/include/erebos/pairing.h
+++ b/include/erebos/pairing.h
@@ -29,6 +29,8 @@ using std::vector;
class PairingServiceBase : public Service
{
public:
+ virtual ~PairingServiceBase();
+
typedef function<void(const Peer &)> RequestInitHook;
void onRequestInit(RequestInitHook);
@@ -49,7 +51,6 @@ private:
static vector<uint8_t> nonceDigest(const Identity & id1, const Identity & id2,
const vector<uint8_t> & nonce1, const vector<uint8_t> & nonce2);
static string confirmationNumber(const vector<uint8_t> &);
- void waitForConfirmation(Peer peer, string confirm);
RequestInitHook requestInitHook;
ConfirmHook responseHook;
@@ -70,14 +71,17 @@ private:
};
struct State {
+ mutex lock;
StatePhase phase;
vector<uint8_t> nonce;
vector<uint8_t> peerCheck;
promise<bool> success;
};
- map<Peer, State> peerStates;
+ map<Peer, shared_ptr<State>> peerStates;
mutex stateLock;
+
+ void waitForConfirmation(Peer peer, weak_ptr<State> state, string confirm, ConfirmHook hook);
};
template<class Result>
diff --git a/src/pairing.cpp b/src/pairing.cpp
index 4ae215b..c8e2d9f 100644
--- a/src/pairing.cpp
+++ b/src/pairing.cpp
@@ -2,6 +2,7 @@
#include "service.h"
+#include <future>
#include <openssl/rand.h>
#include <arpa/inet.h>
@@ -15,6 +16,22 @@ using namespace erebos;
using std::lock_guard;
using std::runtime_error;
+using std::scoped_lock;
+using std::thread;
+
+PairingServiceBase::~PairingServiceBase()
+{
+ // There may be some threads in waitForConfirmation waiting on client
+ // promise, so make sure they do not touch the service state anymore:
+ for (auto & [peer, state] : peerStates) {
+ scoped_lock lock(state->lock);
+ if (state->phase != StatePhase::PairingDone &&
+ state->phase != StatePhase::PairingFailed) {
+ state->success.set_value(false);
+ state->phase = StatePhase::PairingFailed;
+ }
+ }
+}
void PairingServiceBase::onRequestInit(RequestInitHook hook)
{
@@ -51,27 +68,28 @@ void PairingServiceBase::handle(Context & ctx)
throw runtime_error("Pairing request for peer without known identity");
lock_guard lock(stateLock);
- auto & state = peerStates.try_emplace(ctx.peer(), State()).first->second;
+ auto & state = peerStates.try_emplace(ctx.peer(), new State()).first->second;
+ scoped_lock lock2(state->lock);
if (auto request = rec->item("request").asBinary()) {
- if (state.phase != StatePhase::NoPairing)
+ if (state->phase != StatePhase::NoPairing)
return;
if (requestInitHook)
requestInitHook(ctx.peer());
- state.phase = StatePhase::PeerRequest;
- state.peerCheck = *request;
- state.nonce.resize(32);
- RAND_bytes(state.nonce.data(), state.nonce.size());
+ state->phase = StatePhase::PeerRequest;
+ state->peerCheck = *request;
+ state->nonce.resize(32);
+ RAND_bytes(state->nonce.data(), state->nonce.size());
ctx.peer().send(uuid(), Object(Record({
- { "response", state.nonce },
+ { "response", state->nonce },
})));
}
else if (auto response = rec->item("response").asBinary()) {
- if (state.phase != StatePhase::OurRequest) {
+ if (state->phase != StatePhase::OurRequest) {
fprintf(stderr, "Unexpected pairing response.\n"); // TODO
return;
}
@@ -79,15 +97,15 @@ void PairingServiceBase::handle(Context & ctx)
if (responseHook) {
string confirm = confirmationNumber(nonceDigest(
ctx.peer().server().identity(), *pid,
- state.nonce, *response));
+ state->nonce, *response));
std::thread(&PairingServiceBase::waitForConfirmation,
- this, ctx.peer(), confirm).detach();
+ this, ctx.peer(), state, confirm, responseHook).detach();
}
- state.phase = StatePhase::OurRequestConfirm;
+ state->phase = StatePhase::OurRequestConfirm;
ctx.peer().send(uuid(), Object(Record({
- { "reqnonce", state.nonce },
+ { "reqnonce", state->nonce },
})));
}
@@ -95,12 +113,12 @@ void PairingServiceBase::handle(Context & ctx)
auto check = nonceDigest(
*pid, ctx.peer().server().identity(),
*reqnonce, vector<uint8_t>());
- if (check != state.peerCheck) {
+ if (check != state->peerCheck) {
if (requestNonceFailedHook)
requestNonceFailedHook(ctx.peer());
- if (state.phase < StatePhase::PairingDone) {
- state.phase = StatePhase::PairingFailed;
- state.success.set_value(false);
+ if (state->phase < StatePhase::PairingDone) {
+ state->phase = StatePhase::PairingFailed;
+ state->success.set_value(false);
}
return;
}
@@ -108,26 +126,26 @@ void PairingServiceBase::handle(Context & ctx)
if (requestHook) {
string confirm = confirmationNumber(nonceDigest(
*pid, ctx.peer().server().identity(),
- *reqnonce, state.nonce));
+ *reqnonce, state->nonce));
std::thread(&PairingServiceBase::waitForConfirmation,
- this, ctx.peer(), confirm).detach();
+ this, ctx.peer(), state, confirm, requestHook).detach();
}
- state.phase = StatePhase::PeerRequestConfirm;
+ state->phase = StatePhase::PeerRequestConfirm;
}
else if (auto decline = rec->item("decline").asText()) {
- if (state.phase < StatePhase::PairingDone) {
- state.phase = StatePhase::PairingFailed;
- state.success.set_value(false);
+ if (state->phase < StatePhase::PairingDone) {
+ state->phase = StatePhase::PairingFailed;
+ state->success.set_value(false);
}
}
else {
- if (state.phase == StatePhase::OurRequestReady) {
+ if (state->phase == StatePhase::OurRequestReady) {
handlePairingResult(ctx);
- state.phase = StatePhase::PairingDone;
- state.success.set_value(true);
+ state->phase = StatePhase::PairingDone;
+ state->success.set_value(true);
} else {
result = ctx.ref();
}
@@ -141,16 +159,16 @@ void PairingServiceBase::requestPairing(UUID serviceId, const Peer & peer)
throw runtime_error("Pairing request for peer without known identity");
lock_guard lock(stateLock);
- auto & state = peerStates.try_emplace(peer, State()).first->second;
+ auto & state = peerStates.try_emplace(peer, new State()).first->second;
- state.phase = StatePhase::OurRequest;
- state.nonce.resize(32);
- RAND_bytes(state.nonce.data(), state.nonce.size());
+ state->phase = StatePhase::OurRequest;
+ state->nonce.resize(32);
+ RAND_bytes(state->nonce.data(), state->nonce.size());
vector<Record::Item> items;
items.emplace_back("request", nonceDigest(
peer.server().identity(), *pid,
- state.nonce, vector<uint8_t>()));
+ state->nonce, vector<uint8_t>()));
peer.send(serviceId, Object(Record(std::move(items))));
}
@@ -179,28 +197,34 @@ string PairingServiceBase::confirmationNumber(const vector<uint8_t> & digest)
return ret;
}
-void PairingServiceBase::waitForConfirmation(Peer peer, string confirm)
+void PairingServiceBase::waitForConfirmation(Peer peer, weak_ptr<State> wstate, string confirm, ConfirmHook hook)
{
- ConfirmHook hook;
future<bool> success;
- {
- lock_guard lock(stateLock);
- auto & state = peerStates.try_emplace(peer, State()).first->second;
- if (state.phase == StatePhase::OurRequestConfirm)
- hook = responseHook;
- if (state.phase == StatePhase::PeerRequestConfirm)
- hook = requestHook;
-
- success = state.success.get_future();
+ if (auto state = wstate.lock()) {
+ success = state->success.get_future();
+ } else {
+ return;
}
- bool ok = hook(peer, confirm, std::move(success)).get();
+ bool ok;
+ try {
+ ok = hook(peer, confirm, std::move(success)).get();
+ }
+ catch (const std::future_error & e) {
+ if (e.code() == std::future_errc::broken_promise)
+ ok = false;
+ else
+ throw;
+ }
- lock_guard lock(stateLock);
- auto & state = peerStates.try_emplace(peer, State()).first->second;
+ auto state = wstate.lock();
+ if (!state)
+ return; // Server was closed
+
+ scoped_lock lock(state->lock);
if (ok) {
- if (state.phase == StatePhase::OurRequestConfirm) {
+ if (state->phase == StatePhase::OurRequestConfirm) {
if (result) {
peer.server().localHead().update([&] (const Stored<LocalState> & local) {
Service::Context ctx(new Service::Context::Priv {
@@ -212,21 +236,21 @@ void PairingServiceBase::waitForConfirmation(Peer peer, string confirm)
handlePairingResult(ctx);
return ctx.local();
});
- state.phase = StatePhase::PairingDone;
- state.success.set_value(true);
+ state->phase = StatePhase::PairingDone;
+ state->success.set_value(true);
} else {
- state.phase = StatePhase::OurRequestReady;
+ state->phase = StatePhase::OurRequestReady;
}
- } else if (state.phase == StatePhase::PeerRequestConfirm) {
+ } else if (state->phase == StatePhase::PeerRequestConfirm) {
peer.send(uuid(), handlePairingCompleteRef(peer));
- state.phase = StatePhase::PairingDone;
- state.success.set_value(true);
+ state->phase = StatePhase::PairingDone;
+ state->success.set_value(true);
}
} else {
- if (state.phase != StatePhase::PairingFailed) {
+ if (state->phase != StatePhase::PairingFailed) {
peer.send(uuid(), Object(Record({{ "decline", string() }})));
- state.phase = StatePhase::PairingFailed;
- state.success.set_value(false);
+ state->phase = StatePhase::PairingFailed;
+ state->success.set_value(false);
}
}
}