summaryrefslogtreecommitdiff
path: root/src/network/protocol.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/network/protocol.cpp')
-rw-r--r--src/network/protocol.cpp220
1 files changed, 167 insertions, 53 deletions
diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp
index f001d6c..40aeb47 100644
--- a/src/network/protocol.cpp
+++ b/src/network/protocol.cpp
@@ -23,7 +23,7 @@ struct NetworkProtocol::ConnectionPriv
{
Connection::Id id() const;
- bool send(const PartialStorage &, const Header &,
+ bool send(const PartialStorage &, Header,
const vector<Object> &, bool secure);
NetworkProtocol * protocol;
@@ -76,23 +76,28 @@ NetworkProtocol::PollResult NetworkProtocol::poll()
if (!recvfrom(buffer, addr))
return ProtocolClosed {};
- scoped_lock lock(protocolMutex);
- for (const auto & c : connections) {
- if (memcmp(&c->peerAddress, &addr, sizeof addr) == 0) {
- scoped_lock clock(c->cmutex);
- buffer.swap(c->buffer);
- return ConnectionReadReady { c->id() };
+ {
+ scoped_lock lock(protocolMutex);
+
+ for (const auto & c : connections) {
+ if (memcmp(&c->peerAddress, &addr, sizeof addr) == 0) {
+ scoped_lock clock(c->cmutex);
+ buffer.swap(c->buffer);
+ return ConnectionReadReady { c->id() };
+ }
}
- }
- auto conn = unique_ptr<ConnectionPriv>(new ConnectionPriv {
- .protocol = this,
- .peerAddress = addr,
- });
+ auto pst = self->ref()->storage().deriveEphemeralStorage();
+ if (auto header = Connection::receive(buffer, nullptr, pst)) {
+ if (auto conn = verifyNewConnection(*header, addr))
+ return NewConnection { move(*conn) };
- connections.push_back(conn.get());
- buffer.swap(conn->buffer);
- return NewConnection { Connection(move(conn)) };
+ if (auto ann = header->lookupFirst<Header::AnnounceSelf>())
+ return ReceivedAnnounce { addr, ann->value };
+ }
+ }
+
+ return poll();
}
NetworkProtocol::Connection NetworkProtocol::connect(sockaddr_in6 addr)
@@ -107,10 +112,10 @@ NetworkProtocol::Connection NetworkProtocol::connect(sockaddr_in6 addr)
connections.push_back(conn.get());
vector<Header::Item> header {
- Header::AnnounceSelf { self->ref()->digest() },
+ Header::Initiation { Digest(array<uint8_t, Digest::size> {}) },
Header::Version { defaultVersion },
};
- conn->send(self->ref()->storage(), header, {}, false);
+ conn->send(self->ref()->storage(), move(header), {}, false);
}
return Connection(move(conn));
@@ -179,6 +184,70 @@ void NetworkProtocol::sendto(const vector<uint8_t> & buffer, variant<sockaddr_in
}, vaddr);
}
+void NetworkProtocol::sendCookie(variant<sockaddr_in, sockaddr_in6> addr)
+{
+ auto bytes = Header({
+ Header::CookieSet { generateCookie(addr) },
+ Header::AnnounceSelf { self->ref()->digest() },
+ Header::Version { defaultVersion },
+ }).toObject(self->ref()->storage()).encode();
+
+ sendto(bytes, addr);
+}
+
+optional<NetworkProtocol::Connection> NetworkProtocol::verifyNewConnection(const Header & header, sockaddr_in6 addr)
+{
+ optional<string> version;
+ for (const auto & h : header.items) {
+ if (const auto * ptr = get_if<Header::Version>(&h)) {
+ if (ptr->value == defaultVersion) {
+ version = ptr->value;
+ break;
+ }
+ }
+ }
+ if (!version)
+ return nullopt;
+
+ if (header.lookupFirst<Header::Initiation>()) {
+ sendCookie(addr);
+ }
+
+ else if (auto cookie = header.lookupFirst<Header::CookieEcho>()) {
+ if (verifyCookie(addr, cookie->value)) {
+ auto conn = unique_ptr<ConnectionPriv>(new ConnectionPriv {
+ .protocol = this,
+ .peerAddress = addr,
+ });
+
+ connections.push_back(conn.get());
+ buffer.swap(conn->buffer);
+ return Connection(move(conn));
+ }
+ }
+
+ return nullopt;
+}
+
+NetworkProtocol::Cookie NetworkProtocol::generateCookie(variant<sockaddr_in, sockaddr_in6> vaddr) const
+{
+ vector<uint8_t> cookie;
+ visit([&](auto && addr) {
+ cookie.resize(sizeof addr);
+ memcpy(cookie.data(), &addr, sizeof addr);
+ }, vaddr);
+ return Cookie { cookie };
+}
+
+bool NetworkProtocol::verifyCookie(variant<sockaddr_in, sockaddr_in6> vaddr, const NetworkProtocol::Cookie & cookie) const
+{
+ return visit([&](auto && addr) {
+ if (cookie.value.size() != sizeof addr)
+ return false;
+ return memcmp(cookie.value.data(), &addr, sizeof addr) == 0;
+ }, vaddr);
+}
+
/******************************************************************************/
/* Connection */
/******************************************************************************/
@@ -222,53 +291,76 @@ const sockaddr_in6 & NetworkProtocol::Connection::peerAddress() const
optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const PartialStorage & partStorage)
{
- vector<uint8_t> buf, decrypted;
- auto plainBegin = buf.cbegin();
- auto plainEnd = buf.cbegin();
+ vector<uint8_t> buf;
+
+ Channel * channel = nullptr;
+ unique_ptr<Channel> channelPtr;
{
scoped_lock lock(p->cmutex);
if (p->buffer.empty())
return nullopt;
-
buf.swap(p->buffer);
- if ((buf[0] & 0xE0) == 0x80) {
- Channel * channel = nullptr;
- unique_ptr<Channel> channelPtr;
+ if (holds_alternative<unique_ptr<Channel>>(p->channel)) {
+ channel = std::get<unique_ptr<Channel>>(p->channel).get();
+ } else if (holds_alternative<Stored<ChannelAccept>>(p->channel)) {
+ channelPtr = std::get<Stored<ChannelAccept>>(p->channel)->data->channel();
+ channel = channelPtr.get();
+ }
+ }
+
+ if (auto header = receive(buf, channel, partStorage)) {
+ scoped_lock lock(p->cmutex);
- if (holds_alternative<unique_ptr<Channel>>(p->channel)) {
- channel = std::get<unique_ptr<Channel>>(p->channel).get();
- } else if (holds_alternative<Stored<ChannelAccept>>(p->channel)) {
- channelPtr = std::get<Stored<ChannelAccept>>(p->channel)->data->channel();
- channel = channelPtr.get();
- }
+ if (header->lookupFirst<Header::Initiation>()) {
+ p->protocol->sendCookie(p->peerAddress);
+ return nullopt;
+ }
- if (not channel) {
- std::cerr << "unexpected encrypted packet\n";
- return nullopt;
- }
+ if (holds_alternative<monostate>(p->channel) ||
+ holds_alternative<Cookie>(p->channel))
+ if (const auto * cookie = header->lookupFirst<Header::CookieSet>())
+ p->channel = cookie->value;
- if (auto dec = channel->decrypt(buf.begin() + 1, buf.end(), decrypted, 0)) {
- if (decrypted.empty()) {
- std::cerr << "empty decrypted content\n";
- }
- else if (decrypted[0] == 0x00) {
- plainBegin = decrypted.begin() + 1;
- plainEnd = decrypted.end();
- }
- else {
- std::cerr << "streams not implemented\n";
- return nullopt;
- }
- }
+ return header;
+ }
+ return nullopt;
+}
+
+optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(vector<uint8_t> & buf,
+ Channel * channel,
+ const PartialStorage & partStorage)
+{
+ vector<uint8_t> decrypted;
+ auto plainBegin = buf.cbegin();
+ auto plainEnd = buf.cbegin();
+
+ if ((buf[0] & 0xE0) == 0x80) {
+ if (not channel) {
+ std::cerr << "unexpected encrypted packet\n";
+ return nullopt;
}
- else if ((buf[0] & 0xE0) == 0x60) {
- plainBegin = buf.begin();
- plainEnd = buf.end();
+
+ if (auto dec = channel->decrypt(buf.begin() + 1, buf.end(), decrypted, 0)) {
+ if (decrypted.empty()) {
+ std::cerr << "empty decrypted content\n";
+ }
+ else if (decrypted[0] == 0x00) {
+ plainBegin = decrypted.begin() + 1;
+ plainEnd = decrypted.end();
+ }
+ else {
+ std::cerr << "streams not implemented\n";
+ return nullopt;
+ }
}
}
+ else if ((buf[0] & 0xE0) == 0x60) {
+ plainBegin = buf.begin();
+ plainEnd = buf.end();
+ }
if (auto dec = PartialObject::decodePrefix(partStorage, plainBegin, plainEnd)) {
if (auto header = Header::load(std::get<PartialObject>(*dec))) {
@@ -287,14 +379,14 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const Par
}
bool NetworkProtocol::Connection::send(const PartialStorage & partStorage,
- const Header & header,
+ Header header,
const vector<Object> & objs, bool secure)
{
- return p->send(partStorage, header, objs, secure);
+ return p->send(partStorage, move(header), objs, secure);
}
bool NetworkProtocol::ConnectionPriv::send(const PartialStorage & partStorage,
- const Header & header,
+ Header header,
const vector<Object> & objs, bool secure)
{
vector<uint8_t> data, part, out;
@@ -308,6 +400,10 @@ bool NetworkProtocol::ConnectionPriv::send(const PartialStorage & partStorage,
if (channel || secure)
data.push_back(0x00);
+ else if (const auto * ptr = get_if<Cookie>(&this->channel)) {
+ header.items.push_back(Header::CookieEcho { ptr->value });
+ header.items.push_back(Header::Version { defaultVersion });
+ }
part = header.toObject(partStorage).encode();
data.insert(data.end(), part.begin(), part.end());
@@ -414,6 +510,15 @@ optional<NetworkProtocol::Header> NetworkProtocol::Header::load(const PartialObj
} else if (item.name == "VER") {
if (auto ver = item.asText())
items.emplace_back(Version { *ver });
+ } else if (item.name == "INI") {
+ if (auto ref = item.asRef())
+ items.emplace_back(Initiation { ref->digest() });
+ } else if (item.name == "CKS") {
+ if (auto cookie = item.asBinary())
+ items.emplace_back(CookieSet { *cookie });
+ } else if (item.name == "CKE") {
+ if (auto cookie = item.asBinary())
+ items.emplace_back(CookieEcho { *cookie });
} else if (item.name == "REQ") {
if (auto ref = item.asRef())
items.emplace_back(DataRequest { ref->digest() });
@@ -455,6 +560,15 @@ PartialObject NetworkProtocol::Header::toObject(const PartialStorage & st) const
else if (const auto * ptr = get_if<Version>(&item))
ritems.emplace_back("VER", ptr->value);
+ else if (const auto * ptr = get_if<Initiation>(&item))
+ ritems.emplace_back("INI", st.ref(ptr->value));
+
+ else if (const auto * ptr = get_if<CookieSet>(&item))
+ ritems.emplace_back("CKS", ptr->value.value);
+
+ else if (const auto * ptr = get_if<CookieEcho>(&item))
+ ritems.emplace_back("CKE", ptr->value.value);
+
else if (const auto * ptr = get_if<DataRequest>(&item))
ritems.emplace_back("REQ", st.ref(ptr->value));