summaryrefslogtreecommitdiff
path: root/src/network
diff options
context:
space:
mode:
Diffstat (limited to 'src/network')
-rw-r--r--src/network/protocol.cpp220
-rw-r--r--src/network/protocol.h40
2 files changed, 205 insertions, 55 deletions
diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp
index f001d6c..40aeb47 100644
--- a/src/network/protocol.cpp
+++ b/src/network/protocol.cpp
@@ -23,7 +23,7 @@ struct NetworkProtocol::ConnectionPriv
{
Connection::Id id() const;
- bool send(const PartialStorage &, const Header &,
+ bool send(const PartialStorage &, Header,
const vector<Object> &, bool secure);
NetworkProtocol * protocol;
@@ -76,23 +76,28 @@ NetworkProtocol::PollResult NetworkProtocol::poll()
if (!recvfrom(buffer, addr))
return ProtocolClosed {};
- scoped_lock lock(protocolMutex);
- for (const auto & c : connections) {
- if (memcmp(&c->peerAddress, &addr, sizeof addr) == 0) {
- scoped_lock clock(c->cmutex);
- buffer.swap(c->buffer);
- return ConnectionReadReady { c->id() };
+ {
+ scoped_lock lock(protocolMutex);
+
+ for (const auto & c : connections) {
+ if (memcmp(&c->peerAddress, &addr, sizeof addr) == 0) {
+ scoped_lock clock(c->cmutex);
+ buffer.swap(c->buffer);
+ return ConnectionReadReady { c->id() };
+ }
}
- }
- auto conn = unique_ptr<ConnectionPriv>(new ConnectionPriv {
- .protocol = this,
- .peerAddress = addr,
- });
+ auto pst = self->ref()->storage().deriveEphemeralStorage();
+ if (auto header = Connection::receive(buffer, nullptr, pst)) {
+ if (auto conn = verifyNewConnection(*header, addr))
+ return NewConnection { move(*conn) };
- connections.push_back(conn.get());
- buffer.swap(conn->buffer);
- return NewConnection { Connection(move(conn)) };
+ if (auto ann = header->lookupFirst<Header::AnnounceSelf>())
+ return ReceivedAnnounce { addr, ann->value };
+ }
+ }
+
+ return poll();
}
NetworkProtocol::Connection NetworkProtocol::connect(sockaddr_in6 addr)
@@ -107,10 +112,10 @@ NetworkProtocol::Connection NetworkProtocol::connect(sockaddr_in6 addr)
connections.push_back(conn.get());
vector<Header::Item> header {
- Header::AnnounceSelf { self->ref()->digest() },
+ Header::Initiation { Digest(array<uint8_t, Digest::size> {}) },
Header::Version { defaultVersion },
};
- conn->send(self->ref()->storage(), header, {}, false);
+ conn->send(self->ref()->storage(), move(header), {}, false);
}
return Connection(move(conn));
@@ -179,6 +184,70 @@ void NetworkProtocol::sendto(const vector<uint8_t> & buffer, variant<sockaddr_in
}, vaddr);
}
+void NetworkProtocol::sendCookie(variant<sockaddr_in, sockaddr_in6> addr)
+{
+ auto bytes = Header({
+ Header::CookieSet { generateCookie(addr) },
+ Header::AnnounceSelf { self->ref()->digest() },
+ Header::Version { defaultVersion },
+ }).toObject(self->ref()->storage()).encode();
+
+ sendto(bytes, addr);
+}
+
+optional<NetworkProtocol::Connection> NetworkProtocol::verifyNewConnection(const Header & header, sockaddr_in6 addr)
+{
+ optional<string> version;
+ for (const auto & h : header.items) {
+ if (const auto * ptr = get_if<Header::Version>(&h)) {
+ if (ptr->value == defaultVersion) {
+ version = ptr->value;
+ break;
+ }
+ }
+ }
+ if (!version)
+ return nullopt;
+
+ if (header.lookupFirst<Header::Initiation>()) {
+ sendCookie(addr);
+ }
+
+ else if (auto cookie = header.lookupFirst<Header::CookieEcho>()) {
+ if (verifyCookie(addr, cookie->value)) {
+ auto conn = unique_ptr<ConnectionPriv>(new ConnectionPriv {
+ .protocol = this,
+ .peerAddress = addr,
+ });
+
+ connections.push_back(conn.get());
+ buffer.swap(conn->buffer);
+ return Connection(move(conn));
+ }
+ }
+
+ return nullopt;
+}
+
+NetworkProtocol::Cookie NetworkProtocol::generateCookie(variant<sockaddr_in, sockaddr_in6> vaddr) const
+{
+ vector<uint8_t> cookie;
+ visit([&](auto && addr) {
+ cookie.resize(sizeof addr);
+ memcpy(cookie.data(), &addr, sizeof addr);
+ }, vaddr);
+ return Cookie { cookie };
+}
+
+bool NetworkProtocol::verifyCookie(variant<sockaddr_in, sockaddr_in6> vaddr, const NetworkProtocol::Cookie & cookie) const
+{
+ return visit([&](auto && addr) {
+ if (cookie.value.size() != sizeof addr)
+ return false;
+ return memcmp(cookie.value.data(), &addr, sizeof addr) == 0;
+ }, vaddr);
+}
+
/******************************************************************************/
/* Connection */
/******************************************************************************/
@@ -222,53 +291,76 @@ const sockaddr_in6 & NetworkProtocol::Connection::peerAddress() const
optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const PartialStorage & partStorage)
{
- vector<uint8_t> buf, decrypted;
- auto plainBegin = buf.cbegin();
- auto plainEnd = buf.cbegin();
+ vector<uint8_t> buf;
+
+ Channel * channel = nullptr;
+ unique_ptr<Channel> channelPtr;
{
scoped_lock lock(p->cmutex);
if (p->buffer.empty())
return nullopt;
-
buf.swap(p->buffer);
- if ((buf[0] & 0xE0) == 0x80) {
- Channel * channel = nullptr;
- unique_ptr<Channel> channelPtr;
+ if (holds_alternative<unique_ptr<Channel>>(p->channel)) {
+ channel = std::get<unique_ptr<Channel>>(p->channel).get();
+ } else if (holds_alternative<Stored<ChannelAccept>>(p->channel)) {
+ channelPtr = std::get<Stored<ChannelAccept>>(p->channel)->data->channel();
+ channel = channelPtr.get();
+ }
+ }
+
+ if (auto header = receive(buf, channel, partStorage)) {
+ scoped_lock lock(p->cmutex);
- if (holds_alternative<unique_ptr<Channel>>(p->channel)) {
- channel = std::get<unique_ptr<Channel>>(p->channel).get();
- } else if (holds_alternative<Stored<ChannelAccept>>(p->channel)) {
- channelPtr = std::get<Stored<ChannelAccept>>(p->channel)->data->channel();
- channel = channelPtr.get();
- }
+ if (header->lookupFirst<Header::Initiation>()) {
+ p->protocol->sendCookie(p->peerAddress);
+ return nullopt;
+ }
- if (not channel) {
- std::cerr << "unexpected encrypted packet\n";
- return nullopt;
- }
+ if (holds_alternative<monostate>(p->channel) ||
+ holds_alternative<Cookie>(p->channel))
+ if (const auto * cookie = header->lookupFirst<Header::CookieSet>())
+ p->channel = cookie->value;
- if (auto dec = channel->decrypt(buf.begin() + 1, buf.end(), decrypted, 0)) {
- if (decrypted.empty()) {
- std::cerr << "empty decrypted content\n";
- }
- else if (decrypted[0] == 0x00) {
- plainBegin = decrypted.begin() + 1;
- plainEnd = decrypted.end();
- }
- else {
- std::cerr << "streams not implemented\n";
- return nullopt;
- }
- }
+ return header;
+ }
+ return nullopt;
+}
+
+optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(vector<uint8_t> & buf,
+ Channel * channel,
+ const PartialStorage & partStorage)
+{
+ vector<uint8_t> decrypted;
+ auto plainBegin = buf.cbegin();
+ auto plainEnd = buf.cbegin();
+
+ if ((buf[0] & 0xE0) == 0x80) {
+ if (not channel) {
+ std::cerr << "unexpected encrypted packet\n";
+ return nullopt;
}
- else if ((buf[0] & 0xE0) == 0x60) {
- plainBegin = buf.begin();
- plainEnd = buf.end();
+
+ if (auto dec = channel->decrypt(buf.begin() + 1, buf.end(), decrypted, 0)) {
+ if (decrypted.empty()) {
+ std::cerr << "empty decrypted content\n";
+ }
+ else if (decrypted[0] == 0x00) {
+ plainBegin = decrypted.begin() + 1;
+ plainEnd = decrypted.end();
+ }
+ else {
+ std::cerr << "streams not implemented\n";
+ return nullopt;
+ }
}
}
+ else if ((buf[0] & 0xE0) == 0x60) {
+ plainBegin = buf.begin();
+ plainEnd = buf.end();
+ }
if (auto dec = PartialObject::decodePrefix(partStorage, plainBegin, plainEnd)) {
if (auto header = Header::load(std::get<PartialObject>(*dec))) {
@@ -287,14 +379,14 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const Par
}
bool NetworkProtocol::Connection::send(const PartialStorage & partStorage,
- const Header & header,
+ Header header,
const vector<Object> & objs, bool secure)
{
- return p->send(partStorage, header, objs, secure);
+ return p->send(partStorage, move(header), objs, secure);
}
bool NetworkProtocol::ConnectionPriv::send(const PartialStorage & partStorage,
- const Header & header,
+ Header header,
const vector<Object> & objs, bool secure)
{
vector<uint8_t> data, part, out;
@@ -308,6 +400,10 @@ bool NetworkProtocol::ConnectionPriv::send(const PartialStorage & partStorage,
if (channel || secure)
data.push_back(0x00);
+ else if (const auto * ptr = get_if<Cookie>(&this->channel)) {
+ header.items.push_back(Header::CookieEcho { ptr->value });
+ header.items.push_back(Header::Version { defaultVersion });
+ }
part = header.toObject(partStorage).encode();
data.insert(data.end(), part.begin(), part.end());
@@ -414,6 +510,15 @@ optional<NetworkProtocol::Header> NetworkProtocol::Header::load(const PartialObj
} else if (item.name == "VER") {
if (auto ver = item.asText())
items.emplace_back(Version { *ver });
+ } else if (item.name == "INI") {
+ if (auto ref = item.asRef())
+ items.emplace_back(Initiation { ref->digest() });
+ } else if (item.name == "CKS") {
+ if (auto cookie = item.asBinary())
+ items.emplace_back(CookieSet { *cookie });
+ } else if (item.name == "CKE") {
+ if (auto cookie = item.asBinary())
+ items.emplace_back(CookieEcho { *cookie });
} else if (item.name == "REQ") {
if (auto ref = item.asRef())
items.emplace_back(DataRequest { ref->digest() });
@@ -455,6 +560,15 @@ PartialObject NetworkProtocol::Header::toObject(const PartialStorage & st) const
else if (const auto * ptr = get_if<Version>(&item))
ritems.emplace_back("VER", ptr->value);
+ else if (const auto * ptr = get_if<Initiation>(&item))
+ ritems.emplace_back("INI", st.ref(ptr->value));
+
+ else if (const auto * ptr = get_if<CookieSet>(&item))
+ ritems.emplace_back("CKS", ptr->value.value);
+
+ else if (const auto * ptr = get_if<CookieEcho>(&item))
+ ritems.emplace_back("CKE", ptr->value.value);
+
else if (const auto * ptr = get_if<DataRequest>(&item))
ritems.emplace_back("REQ", st.ref(ptr->value));
diff --git a/src/network/protocol.h b/src/network/protocol.h
index 545585e..dda2ffb 100644
--- a/src/network/protocol.h
+++ b/src/network/protocol.h
@@ -38,18 +38,23 @@ public:
struct Header;
+ struct ReceivedAnnounce;
struct NewConnection;
struct ConnectionReadReady;
struct ProtocolClosed {};
using PollResult = variant<
+ ReceivedAnnounce,
NewConnection,
ConnectionReadReady,
ProtocolClosed>;
PollResult poll();
+ struct Cookie { vector<uint8_t> value; };
+
using ChannelState = variant<monostate,
+ Cookie,
Stored<ChannelRequest>,
shared_ptr<struct WaitingRef>,
Stored<ChannelAccept>,
@@ -66,6 +71,12 @@ private:
bool recvfrom(vector<uint8_t> & buffer, sockaddr_in6 & addr);
void sendto(const vector<uint8_t> & buffer, variant<sockaddr_in, sockaddr_in6> addr);
+ void sendCookie(variant<sockaddr_in, sockaddr_in6> addr);
+ optional<Connection> verifyNewConnection(const Header & header, sockaddr_in6 addr);
+
+ Cookie generateCookie(variant<sockaddr_in, sockaddr_in6> addr) const;
+ bool verifyCookie(variant<sockaddr_in, sockaddr_in6> addr, const Cookie & cookie) const;
+
int sock;
mutex protocolMutex;
@@ -94,7 +105,7 @@ public:
const sockaddr_in6 & peerAddress() const;
optional<Header> receive(const PartialStorage &);
- bool send(const PartialStorage &, const NetworkProtocol::Header &,
+ bool send(const PartialStorage &, NetworkProtocol::Header,
const vector<Object> &, bool secure);
void close();
@@ -104,9 +115,14 @@ public:
void trySendOutQueue();
private:
+ static optional<Header> receive(vector<uint8_t> & buf,
+ Channel * channel,
+ const PartialStorage & st);
+
unique_ptr<ConnectionPriv> p;
};
+struct NetworkProtocol::ReceivedAnnounce { sockaddr_in6 addr; Digest digest; };
struct NetworkProtocol::NewConnection { Connection conn; };
struct NetworkProtocol::ConnectionReadReady { Connection::Id id; };
@@ -114,6 +130,9 @@ struct NetworkProtocol::Header
{
struct Acknowledged { Digest value; };
struct Version { string value; };
+ struct Initiation { Digest value; };
+ struct CookieSet { Cookie value; };
+ struct CookieEcho { Cookie value; };
struct DataRequest { Digest value; };
struct DataResponse { Digest value; };
struct AnnounceSelf { Digest value; };
@@ -126,6 +145,9 @@ struct NetworkProtocol::Header
using Item = variant<
Acknowledged,
Version,
+ Initiation,
+ CookieSet,
+ CookieEcho,
DataRequest,
DataResponse,
AnnounceSelf,
@@ -140,14 +162,28 @@ struct NetworkProtocol::Header
static optional<Header> load(const PartialObject &);
PartialObject toObject(const PartialStorage &) const;
- const vector<Item> items;
+ template<class T> const T * lookupFirst() const;
+
+ vector<Item> items;
};
+template<class T>
+const T * NetworkProtocol::Header::lookupFirst() const
+{
+ for (const auto & h : items)
+ if (auto ptr = std::get_if<T>(&h))
+ return ptr;
+ return nullptr;
+}
+
bool operator==(const NetworkProtocol::Header::Item &, const NetworkProtocol::Header::Item &);
inline bool operator!=(const NetworkProtocol::Header::Item & left,
const NetworkProtocol::Header::Item & right)
{ return not (left == right); }
+inline bool operator==(const NetworkProtocol::Cookie & left, const NetworkProtocol::Cookie & right)
+{ return left.value == right.value; }
+
class ReplyBuilder
{
public: