diff options
Diffstat (limited to 'src/network/protocol.cpp')
-rw-r--r-- | src/network/protocol.cpp | 220 |
1 files changed, 167 insertions, 53 deletions
diff --git a/src/network/protocol.cpp b/src/network/protocol.cpp index f001d6c..40aeb47 100644 --- a/src/network/protocol.cpp +++ b/src/network/protocol.cpp @@ -23,7 +23,7 @@ struct NetworkProtocol::ConnectionPriv { Connection::Id id() const; - bool send(const PartialStorage &, const Header &, + bool send(const PartialStorage &, Header, const vector<Object> &, bool secure); NetworkProtocol * protocol; @@ -76,23 +76,28 @@ NetworkProtocol::PollResult NetworkProtocol::poll() if (!recvfrom(buffer, addr)) return ProtocolClosed {}; - scoped_lock lock(protocolMutex); - for (const auto & c : connections) { - if (memcmp(&c->peerAddress, &addr, sizeof addr) == 0) { - scoped_lock clock(c->cmutex); - buffer.swap(c->buffer); - return ConnectionReadReady { c->id() }; + { + scoped_lock lock(protocolMutex); + + for (const auto & c : connections) { + if (memcmp(&c->peerAddress, &addr, sizeof addr) == 0) { + scoped_lock clock(c->cmutex); + buffer.swap(c->buffer); + return ConnectionReadReady { c->id() }; + } } - } - auto conn = unique_ptr<ConnectionPriv>(new ConnectionPriv { - .protocol = this, - .peerAddress = addr, - }); + auto pst = self->ref()->storage().deriveEphemeralStorage(); + if (auto header = Connection::receive(buffer, nullptr, pst)) { + if (auto conn = verifyNewConnection(*header, addr)) + return NewConnection { move(*conn) }; - connections.push_back(conn.get()); - buffer.swap(conn->buffer); - return NewConnection { Connection(move(conn)) }; + if (auto ann = header->lookupFirst<Header::AnnounceSelf>()) + return ReceivedAnnounce { addr, ann->value }; + } + } + + return poll(); } NetworkProtocol::Connection NetworkProtocol::connect(sockaddr_in6 addr) @@ -107,10 +112,10 @@ NetworkProtocol::Connection NetworkProtocol::connect(sockaddr_in6 addr) connections.push_back(conn.get()); vector<Header::Item> header { - Header::AnnounceSelf { self->ref()->digest() }, + Header::Initiation { Digest(array<uint8_t, Digest::size> {}) }, Header::Version { defaultVersion }, }; - conn->send(self->ref()->storage(), header, {}, false); + conn->send(self->ref()->storage(), move(header), {}, false); } return Connection(move(conn)); @@ -179,6 +184,70 @@ void NetworkProtocol::sendto(const vector<uint8_t> & buffer, variant<sockaddr_in }, vaddr); } +void NetworkProtocol::sendCookie(variant<sockaddr_in, sockaddr_in6> addr) +{ + auto bytes = Header({ + Header::CookieSet { generateCookie(addr) }, + Header::AnnounceSelf { self->ref()->digest() }, + Header::Version { defaultVersion }, + }).toObject(self->ref()->storage()).encode(); + + sendto(bytes, addr); +} + +optional<NetworkProtocol::Connection> NetworkProtocol::verifyNewConnection(const Header & header, sockaddr_in6 addr) +{ + optional<string> version; + for (const auto & h : header.items) { + if (const auto * ptr = get_if<Header::Version>(&h)) { + if (ptr->value == defaultVersion) { + version = ptr->value; + break; + } + } + } + if (!version) + return nullopt; + + if (header.lookupFirst<Header::Initiation>()) { + sendCookie(addr); + } + + else if (auto cookie = header.lookupFirst<Header::CookieEcho>()) { + if (verifyCookie(addr, cookie->value)) { + auto conn = unique_ptr<ConnectionPriv>(new ConnectionPriv { + .protocol = this, + .peerAddress = addr, + }); + + connections.push_back(conn.get()); + buffer.swap(conn->buffer); + return Connection(move(conn)); + } + } + + return nullopt; +} + +NetworkProtocol::Cookie NetworkProtocol::generateCookie(variant<sockaddr_in, sockaddr_in6> vaddr) const +{ + vector<uint8_t> cookie; + visit([&](auto && addr) { + cookie.resize(sizeof addr); + memcpy(cookie.data(), &addr, sizeof addr); + }, vaddr); + return Cookie { cookie }; +} + +bool NetworkProtocol::verifyCookie(variant<sockaddr_in, sockaddr_in6> vaddr, const NetworkProtocol::Cookie & cookie) const +{ + return visit([&](auto && addr) { + if (cookie.value.size() != sizeof addr) + return false; + return memcmp(cookie.value.data(), &addr, sizeof addr) == 0; + }, vaddr); +} + /******************************************************************************/ /* Connection */ /******************************************************************************/ @@ -222,53 +291,76 @@ const sockaddr_in6 & NetworkProtocol::Connection::peerAddress() const optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const PartialStorage & partStorage) { - vector<uint8_t> buf, decrypted; - auto plainBegin = buf.cbegin(); - auto plainEnd = buf.cbegin(); + vector<uint8_t> buf; + + Channel * channel = nullptr; + unique_ptr<Channel> channelPtr; { scoped_lock lock(p->cmutex); if (p->buffer.empty()) return nullopt; - buf.swap(p->buffer); - if ((buf[0] & 0xE0) == 0x80) { - Channel * channel = nullptr; - unique_ptr<Channel> channelPtr; + if (holds_alternative<unique_ptr<Channel>>(p->channel)) { + channel = std::get<unique_ptr<Channel>>(p->channel).get(); + } else if (holds_alternative<Stored<ChannelAccept>>(p->channel)) { + channelPtr = std::get<Stored<ChannelAccept>>(p->channel)->data->channel(); + channel = channelPtr.get(); + } + } + + if (auto header = receive(buf, channel, partStorage)) { + scoped_lock lock(p->cmutex); - if (holds_alternative<unique_ptr<Channel>>(p->channel)) { - channel = std::get<unique_ptr<Channel>>(p->channel).get(); - } else if (holds_alternative<Stored<ChannelAccept>>(p->channel)) { - channelPtr = std::get<Stored<ChannelAccept>>(p->channel)->data->channel(); - channel = channelPtr.get(); - } + if (header->lookupFirst<Header::Initiation>()) { + p->protocol->sendCookie(p->peerAddress); + return nullopt; + } - if (not channel) { - std::cerr << "unexpected encrypted packet\n"; - return nullopt; - } + if (holds_alternative<monostate>(p->channel) || + holds_alternative<Cookie>(p->channel)) + if (const auto * cookie = header->lookupFirst<Header::CookieSet>()) + p->channel = cookie->value; - if (auto dec = channel->decrypt(buf.begin() + 1, buf.end(), decrypted, 0)) { - if (decrypted.empty()) { - std::cerr << "empty decrypted content\n"; - } - else if (decrypted[0] == 0x00) { - plainBegin = decrypted.begin() + 1; - plainEnd = decrypted.end(); - } - else { - std::cerr << "streams not implemented\n"; - return nullopt; - } - } + return header; + } + return nullopt; +} + +optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(vector<uint8_t> & buf, + Channel * channel, + const PartialStorage & partStorage) +{ + vector<uint8_t> decrypted; + auto plainBegin = buf.cbegin(); + auto plainEnd = buf.cbegin(); + + if ((buf[0] & 0xE0) == 0x80) { + if (not channel) { + std::cerr << "unexpected encrypted packet\n"; + return nullopt; } - else if ((buf[0] & 0xE0) == 0x60) { - plainBegin = buf.begin(); - plainEnd = buf.end(); + + if (auto dec = channel->decrypt(buf.begin() + 1, buf.end(), decrypted, 0)) { + if (decrypted.empty()) { + std::cerr << "empty decrypted content\n"; + } + else if (decrypted[0] == 0x00) { + plainBegin = decrypted.begin() + 1; + plainEnd = decrypted.end(); + } + else { + std::cerr << "streams not implemented\n"; + return nullopt; + } } } + else if ((buf[0] & 0xE0) == 0x60) { + plainBegin = buf.begin(); + plainEnd = buf.end(); + } if (auto dec = PartialObject::decodePrefix(partStorage, plainBegin, plainEnd)) { if (auto header = Header::load(std::get<PartialObject>(*dec))) { @@ -287,14 +379,14 @@ optional<NetworkProtocol::Header> NetworkProtocol::Connection::receive(const Par } bool NetworkProtocol::Connection::send(const PartialStorage & partStorage, - const Header & header, + Header header, const vector<Object> & objs, bool secure) { - return p->send(partStorage, header, objs, secure); + return p->send(partStorage, move(header), objs, secure); } bool NetworkProtocol::ConnectionPriv::send(const PartialStorage & partStorage, - const Header & header, + Header header, const vector<Object> & objs, bool secure) { vector<uint8_t> data, part, out; @@ -308,6 +400,10 @@ bool NetworkProtocol::ConnectionPriv::send(const PartialStorage & partStorage, if (channel || secure) data.push_back(0x00); + else if (const auto * ptr = get_if<Cookie>(&this->channel)) { + header.items.push_back(Header::CookieEcho { ptr->value }); + header.items.push_back(Header::Version { defaultVersion }); + } part = header.toObject(partStorage).encode(); data.insert(data.end(), part.begin(), part.end()); @@ -414,6 +510,15 @@ optional<NetworkProtocol::Header> NetworkProtocol::Header::load(const PartialObj } else if (item.name == "VER") { if (auto ver = item.asText()) items.emplace_back(Version { *ver }); + } else if (item.name == "INI") { + if (auto ref = item.asRef()) + items.emplace_back(Initiation { ref->digest() }); + } else if (item.name == "CKS") { + if (auto cookie = item.asBinary()) + items.emplace_back(CookieSet { *cookie }); + } else if (item.name == "CKE") { + if (auto cookie = item.asBinary()) + items.emplace_back(CookieEcho { *cookie }); } else if (item.name == "REQ") { if (auto ref = item.asRef()) items.emplace_back(DataRequest { ref->digest() }); @@ -455,6 +560,15 @@ PartialObject NetworkProtocol::Header::toObject(const PartialStorage & st) const else if (const auto * ptr = get_if<Version>(&item)) ritems.emplace_back("VER", ptr->value); + else if (const auto * ptr = get_if<Initiation>(&item)) + ritems.emplace_back("INI", st.ref(ptr->value)); + + else if (const auto * ptr = get_if<CookieSet>(&item)) + ritems.emplace_back("CKS", ptr->value.value); + + else if (const auto * ptr = get_if<CookieEcho>(&item)) + ritems.emplace_back("CKE", ptr->value.value); + else if (const auto * ptr = get_if<DataRequest>(&item)) ritems.emplace_back("REQ", st.ref(ptr->value)); |