From 558ea4d565799aa2000af0b1fc6d159447c9868b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roman=20Smr=C5=BE?= Date: Sat, 19 Aug 2023 12:44:35 +0200 Subject: Network: connection initiation with cookie --- src/Network.hs | 98 ++++++++++++------- src/Network/Protocol.hs | 245 ++++++++++++++++++++++++++++++++++++------------ 2 files changed, 245 insertions(+), 98 deletions(-) diff --git a/src/Network.hs b/src/Network.hs index 071a0b0..6685045 100644 --- a/src/Network.hs +++ b/src/Network.hs @@ -39,7 +39,7 @@ import Foreign.Storable import GHC.Conc.Sync (unsafeIOToSTM) -import Network.Socket +import Network.Socket hiding (ControlMessage) import qualified Network.Socket.ByteString as S import Channel @@ -67,7 +67,7 @@ data Server = Server , serverThreads :: MVar [ThreadId] , serverSocket :: MVar Socket , serverRawPath :: SymFlow (PeerAddress, BC.ByteString) - , serverControlFlow :: Flow (Connection PeerAddress) (ControlRequest PeerAddress) + , serverControlFlow :: Flow (ControlMessage PeerAddress) (ControlRequest PeerAddress) , serverDataResponse :: TQueue (Peer, Maybe PartialRef) , serverIOActions :: TQueue (ExceptT String IO ()) , serverServices :: [SomeService] @@ -238,7 +238,9 @@ startServer opt serverOrigHead logd' serverServices = do let idt = headLocalIdentity h changedId <- modifyMVar serverIdentity_ $ \cur -> return (idt, cur /= idt) - when changedId $ announceUpdate idt + when changedId $ do + writeFlowIO serverControlFlow $ UpdateSelfIdentity idt + announceUpdate idt forM_ serverServices $ \(SomeService service _) -> do forM_ (serviceStorageWatchers service) $ \(SomeStorageWatcher sel act) -> do @@ -260,27 +262,44 @@ startServer opt serverOrigHead logd' serverServices = do PeerIceSession ice -> iceSend ice msg forkServerThread server $ forever $ do - conn <- readFlowIO serverControlFlow - let paddr = connAddress conn - peer <- modifyMVar serverPeers $ \pvalue -> do - case M.lookup paddr pvalue of - Just peer -> return (pvalue, peer) - Nothing -> do - peer <- mkPeer server paddr - return (M.insert paddr peer pvalue, peer) - - atomically $ do - readTVar (peerConnection peer) >>= \case - Left packets -> writeFlowBulk (connData conn) $ reverse packets - Right _ -> return () - writeTVar (peerConnection peer) (Right conn) - - forkServerThread server $ forever $ do - (secure, TransportPacket header objs) <- readFlowIO $ connData conn - prefs <- forM objs $ storeObject $ peerInStorage peer - identity <- readMVar serverIdentity_ - let svcs = map someServiceID serverServices - handlePacket identity secure peer chanSvc svcs header prefs + readFlowIO serverControlFlow >>= \case + NewConnection conn mbpid -> do + let paddr = connAddress conn + peer <- modifyMVar serverPeers $ \pvalue -> do + case M.lookup paddr pvalue of + Just peer -> return (pvalue, peer) + Nothing -> do + peer <- mkPeer server paddr + return (M.insert paddr peer pvalue, peer) + + atomically $ do + readTVar (peerConnection peer) >>= \case + Left packets -> writeFlowBulk (connData conn) $ reverse packets + Right _ -> return () + writeTVar (peerConnection peer) (Right conn) + + forkServerThread server $ forever $ do + (secure, TransportPacket header objs) <- readFlowIO $ connData conn + prefs <- forM objs $ storeObject $ peerInStorage peer + identity <- readMVar serverIdentity_ + let svcs = map someServiceID serverServices + handlePacket identity secure peer chanSvc svcs header prefs + + case mbpid of + Just dgst -> do + identity <- readMVar serverIdentity_ + atomically $ runPacketHandler False peer $ do + wref <- newWaitingRef dgst $ handleIdentityAnnounce identity peer + readTVarP (peerIdentityVar peer) >>= \case + PeerIdentityUnknown idwait -> do + addHeader $ AnnounceSelf $ refDigest $ storedRef $ idData identity + writeTVarP (peerIdentityVar peer) $ PeerIdentityRef wref idwait + liftSTM $ writeTChan serverChanPeer peer + _ -> return () + Nothing -> return () + + ReceivedAnnounce addr _ -> do + void $ serverPeer' server addr erebosNetworkProtocol (headLocalIdentity serverOrigHead) logd protocolRawPath protocolControlFlow @@ -343,6 +362,17 @@ newtype PacketHandler a = PacketHandler { unPacketHandler :: StateT PacketHandle instance MonadFail PacketHandler where fail = throwError +runPacketHandler :: Bool -> Peer -> PacketHandler () -> STM () +runPacketHandler secure peer@Peer {..} act = do + let logd = writeTQueue $ serverErrorLog peerServer_ + runExceptT (flip execStateT (PacketHandlerState peer [] [] []) $ unPacketHandler act) >>= \case + Left err -> do + logd $ "Error in handling packet from " ++ show peerAddress ++ ": " ++ err + Right ph -> do + when (not $ null $ phHead ph) $ do + let packet = TransportPacket (TransportHeader $ phHead ph) (phBody ph) + sendToPeerS' secure peer (phAckedBy ph) packet + liftSTM :: STM a -> PacketHandler a liftSTM = PacketHandler . lift . lift @@ -392,7 +422,7 @@ handlePacket identity secure peer chanSvc svcs (TransportHeader headers) prefs = _ -> [] ] - res <- runExceptT $ flip execStateT (PacketHandlerState peer [] [] []) $ unPacketHandler $ do + runPacketHandler secure peer $ do let logd = liftSTM . writeTQueue (serverErrorLog server) forM_ headers $ \case Acknowledged dgst -> do @@ -449,6 +479,8 @@ handlePacket identity secure peer chanSvc svcs (TransportHeader headers) prefs = liftSTM (getPeerChannel peer) >>= \case ChannelNone {} -> process + ChannelCookieWait {} -> process + ChannelCookieReceived {} -> process ChannelOurRequest our | dgst < refDigest (storedRef our) -> process | otherwise -> reject ChannelPeerRequest {} -> process @@ -458,8 +490,11 @@ handlePacket identity secure peer chanSvc svcs (TransportHeader headers) prefs = TrChannelAccept dgst -> do let process = do handleChannelAccept identity $ partialRefFromDigest (peerInStorage peer) dgst + reject = addHeader $ Rejected dgst liftSTM (getPeerChannel peer) >>= \case - ChannelNone {} -> process + ChannelNone {} -> reject + ChannelCookieWait {} -> reject + ChannelCookieReceived {} -> reject ChannelOurRequest {} -> process ChannelPeerRequest {} -> process ChannelOurAccept our _ | dgst < refDigest (storedRef our) -> process @@ -482,15 +517,6 @@ handlePacket identity secure peer chanSvc svcs (TransportHeader headers) prefs = _ -> return () - let logd = writeTQueue (serverErrorLog server) - case res of - Left err -> do - logd $ "Error in handling packet from " ++ show (peerAddress peer) ++ ": " ++ err - Right ph -> do - when (not $ null $ phHead ph) $ do - let packet = TransportPacket (TransportHeader $ phHead ph) (phBody ph) - sendToPeerS' secure peer (phAckedBy ph) packet - withPeerIdentity :: MonadIO m => Peer -> (UnifiedIdentity -> ExceptT String IO ()) -> m () withPeerIdentity peer act = liftIO $ atomically $ readTVar (peerIdentityVar peer) >>= \case @@ -509,7 +535,7 @@ setupChannel identity peer upid = do ] liftIO $ atomically $ do getPeerChannel peer >>= \case - ChannelNone -> do + ChannelCookieReceived {} -> do sendToPeerPlain peer [ Acknowledged reqref, Rejected reqref ] $ TransportPacket (TransportHeader hitems) [storedRef req] setPeerChannel peer $ ChannelOurRequest req diff --git a/src/Network/Protocol.hs b/src/Network/Protocol.hs index db5e767..bd93386 100644 --- a/src/Network/Protocol.hs +++ b/src/Network/Protocol.hs @@ -11,6 +11,7 @@ module Network.Protocol ( ChannelState(..), ControlRequest(..), + ControlMessage(..), erebosNetworkProtocol, Connection, @@ -63,6 +64,9 @@ data TransportHeaderItem = Acknowledged RefDigest | Rejected RefDigest | ProtocolVersion Text + | Initiation RefDigest + | CookieSet Cookie + | CookieEcho Cookie | DataRequest RefDigest | DataResponse RefDigest | AnnounceSelf RefDigest @@ -73,12 +77,18 @@ data TransportHeaderItem | ServiceRef RefDigest deriving (Eq) +newtype Cookie = Cookie ByteString + deriving (Eq) + transportToObject :: PartialStorage -> TransportHeader -> PartialObject transportToObject st (TransportHeader items) = Rec $ map single items where single = \case Acknowledged dgst -> (BC.pack "ACK", RecRef $ partialRefFromDigest st dgst) Rejected dgst -> (BC.pack "REJ", RecRef $ partialRefFromDigest st dgst) ProtocolVersion ver -> (BC.pack "VER", RecText ver) + Initiation dgst -> (BC.pack "INI", RecRef $ partialRefFromDigest st dgst) + CookieSet (Cookie bytes) -> (BC.pack "CKS", RecBinary bytes) + CookieEcho (Cookie bytes) -> (BC.pack "CKE", RecBinary bytes) DataRequest dgst -> (BC.pack "REQ", RecRef $ partialRefFromDigest st dgst) DataResponse dgst -> (BC.pack "RSP", RecRef $ partialRefFromDigest st dgst) AnnounceSelf dgst -> (BC.pack "ANN", RecRef $ partialRefFromDigest st dgst) @@ -96,6 +106,9 @@ transportFromObject (Rec items) = case catMaybes $ map single items of | name == BC.pack "ACK", RecRef ref <- content -> Just $ Acknowledged $ refDigest ref | name == BC.pack "REJ", RecRef ref <- content -> Just $ Rejected $ refDigest ref | name == BC.pack "VER", RecText ver <- content -> Just $ ProtocolVersion ver + | name == BC.pack "INI", RecRef ref <- content -> Just $ Initiation $ refDigest ref + | name == BC.pack "CKS", RecBinary bytes <- content -> Just $ CookieSet (Cookie bytes) + | name == BC.pack "CKE", RecBinary bytes <- content -> Just $ CookieEcho (Cookie bytes) | name == BC.pack "REQ", RecRef ref <- content -> Just $ DataRequest $ refDigest ref | name == BC.pack "RSP", RecRef ref <- content -> Just $ DataResponse $ refDigest ref | name == BC.pack "ANN", RecRef ref <- content -> Just $ AnnounceSelf $ refDigest ref @@ -109,14 +122,15 @@ transportFromObject _ = Nothing data GlobalState addr = (Eq addr, Show addr) => GlobalState - { gIdentity :: TVar UnifiedIdentity + { gIdentity :: TVar (UnifiedIdentity, [UnifiedIdentity]) , gConnections :: TVar [Connection addr] , gDataFlow :: SymFlow (addr, ByteString) - , gControlFlow :: Flow (ControlRequest addr) (Connection addr) + , gControlFlow :: Flow (ControlRequest addr) (ControlMessage addr) , gLog :: String -> STM () , gStorage :: PartialStorage , gNowVar :: TVar TimeSpec , gNextTimeout :: TVar TimeSpec + , gInitConfig :: Ref } data Connection addr = Connection @@ -156,6 +170,8 @@ wrDigest = refDigest . wrefPartial data ChannelState = ChannelNone + | ChannelCookieWait + | ChannelCookieReceived Cookie | ChannelOurRequest (Stored ChannelRequest) | ChannelPeerRequest WaitingRef | ChannelOurAccept (Stored ChannelAccept) Channel @@ -165,29 +181,35 @@ data ChannelState = ChannelNone data SentPacket = SentPacket { spTime :: TimeSpec , spRetryCount :: Int - , spAckedBy :: [TransportHeaderItem] + , spAckedBy :: Maybe (TransportHeaderItem -> Bool) , spData :: BC.ByteString } data ControlRequest addr = RequestConnection addr | SendAnnounce addr + | UpdateSelfIdentity UnifiedIdentity + +data ControlMessage addr = NewConnection (Connection addr) (Maybe RefDigest) + | ReceivedAnnounce addr RefDigest erebosNetworkProtocol :: (Eq addr, Ord addr, Show addr) => UnifiedIdentity -> (String -> STM ()) -> SymFlow (addr, ByteString) - -> Flow (ControlRequest addr) (Connection addr) + -> Flow (ControlRequest addr) (ControlMessage addr) -> IO () erebosNetworkProtocol initialIdentity gLog gDataFlow gControlFlow = do - gIdentity <- newTVarIO initialIdentity + gIdentity <- newTVarIO (initialIdentity, []) gConnections <- newTVarIO [] - gStorage <- derivePartialStorage =<< memoryStorage + mStorage <- memoryStorage + gStorage <- derivePartialStorage mStorage startTime <- getTime MonotonicRaw gNowVar <- newTVarIO startTime gNextTimeout <- newTVarIO startTime + gInitConfig <- store mStorage $ (Rec [] :: Object) let gs = GlobalState {..} @@ -212,31 +234,38 @@ erebosNetworkProtocol initialIdentity gLog gDataFlow gControlFlow = do getConnection :: GlobalState addr -> addr -> STM (Connection addr) -getConnection GlobalState {..} addr = do +getConnection gs addr = do + maybe (newConnection gs addr) return =<< findConnection gs addr + +findConnection :: GlobalState addr -> addr -> STM (Maybe (Connection addr)) +findConnection GlobalState {..} addr = do + find ((addr==) . cAddress) <$> readTVar gConnections + +newConnection :: GlobalState addr -> addr -> STM (Connection addr) +newConnection GlobalState {..} addr = do conns <- readTVar gConnections - case find ((addr==) . cAddress) conns of - Just conn -> return conn - Nothing -> do - let cAddress = addr - (cDataUp, cDataInternal) <- newFlow - cChannel <- newTVar ChannelNone - cSecureOutQueue <- newTQueue - cSentPackets <- newTVar [] - let conn = Connection {..} - - writeTVar gConnections (conn : conns) - writeFlow gControlFlow conn - return conn + + let cAddress = addr + (cDataUp, cDataInternal) <- newFlow + cChannel <- newTVar ChannelNone + cSecureOutQueue <- newTQueue + cSentPackets <- newTVar [] + let conn = Connection {..} + + writeTVar gConnections (conn : conns) + return conn processIncomming :: GlobalState addr -> STM (IO ()) processIncomming gs@GlobalState {..} = do (addr, msg) <- readFlow gDataFlow - conn@Connection {..} <- getConnection gs addr + mbconn <- findConnection gs addr - mbch <- readTVar cChannel >>= return . \case - ChannelEstablished ch -> Just ch - ChannelOurAccept _ ch -> Just ch - _ -> Nothing + mbch <- case mbconn of + Nothing -> return Nothing + Just conn -> readTVar (cChannel conn) >>= return . \case + ChannelEstablished ch -> Just ch + ChannelOurAccept _ ch -> Just ch + _ -> Nothing return $ do let deserialize = liftEither . runExcept . deserializeObjects gStorage . BL.fromStrict @@ -269,30 +298,112 @@ processIncomming gs@GlobalState {..} = do Right (secure, objs) | hobj:content <- objs , Just header@(TransportHeader items) <- transportFromObject hobj - -> atomically $ do - processAcknowledgements gs conn items - writeFlow cDataInternal (secure, TransportPacket header content) + -> processPacket gs (maybe (Left addr) Right mbconn) secure (TransportPacket header content) >>= \case + Just (conn@Connection {..}, mbup) -> atomically $ do + processAcknowledgements gs conn items + case mbup of + Just up -> writeFlow cDataInternal (secure, up) + Nothing -> return () + Nothing -> return () | otherwise -> atomically $ do - gLog $ show cAddress ++ ": invalid objects" + gLog $ show addr ++ ": invalid objects" gLog $ show objs Left err -> do - atomically $ gLog $ show cAddress <> ": failed to parse packet: " <> err - + atomically $ gLog $ show addr <> ": failed to parse packet: " <> err + +processPacket :: GlobalState addr -> Either addr (Connection addr) -> Bool -> TransportPacket a -> IO (Maybe (Connection addr, Maybe (TransportPacket a))) +processPacket gs@GlobalState {..} econn secure packet@(TransportPacket (TransportHeader header) _) = if + | Right conn <- econn, secure + -> return $ Just (conn, Just packet) + + | _:_ <- mapMaybe (\case Initiation x -> Just x; _ -> Nothing) header + , Just ver <- version + -> do + cookie <- createCookie gs addr + atomically $ do + identity <- fst <$> readTVar gIdentity + let reply = BL.toStrict $ serializeObject $ transportToObject gStorage $ TransportHeader + [ CookieSet cookie + , AnnounceSelf $ refDigest $ storedRef $ idData identity + , ProtocolVersion ver + ] + writeFlow gDataFlow (addr, reply) + return Nothing + + | cookie:_ <- mapMaybe (\case CookieSet x -> Just x; _ -> Nothing) header + , Just _ <- version + , Right conn@Connection {..} <- econn + -> do + atomically $ readTVar cChannel >>= \case + ChannelCookieWait -> do + writeTVar cChannel $ ChannelCookieReceived cookie + writeFlow gControlFlow (NewConnection conn mbpid) + return $ Just (conn, Nothing) + _ -> return Nothing + + | Right conn <- econn + -> return $ Just (conn, Just packet) + + | cookie:_ <- mapMaybe (\case CookieEcho x -> Just x; _ -> Nothing) header + , Just _ <- version + -> verifyCookie gs addr cookie >>= \case + True -> do + conn <- atomically $ findConnection gs addr >>= \case + Just conn -> return conn + Nothing -> do + conn <- newConnection gs addr + writeFlow gControlFlow (NewConnection conn mbpid) + return conn + return $ Just (conn, Just packet) + False -> return Nothing + + | dgst:_ <- mapMaybe (\case AnnounceSelf x -> Just x; _ -> Nothing) header + , Just _ <- version + -> do + atomically $ do + (cur, past) <- readTVar gIdentity + when (not $ dgst `elem` map (refDigest . storedRef . idData) (cur : past)) $ do + writeFlow gControlFlow $ ReceivedAnnounce addr dgst + return Nothing + + | otherwise -> return Nothing + + where + addr = either id cAddress econn + mbpid = listToMaybe $ mapMaybe (\case AnnounceSelf dgst -> Just dgst; _ -> Nothing) header + version = listToMaybe $ filter (\v -> ProtocolVersion v `elem` header) protocolVersions + + +createCookie :: GlobalState addr -> addr -> IO Cookie +createCookie GlobalState {} addr = return (Cookie $ BC.pack $ show addr) + +verifyCookie :: GlobalState addr -> addr -> Cookie -> IO Bool +verifyCookie GlobalState {} addr (Cookie cookie) = return $ show addr == BC.unpack cookie + +resendBytes :: GlobalState addr -> Connection addr -> SentPacket -> IO () +resendBytes GlobalState {..} Connection {..} sp = do + now <- getTime MonotonicRaw + atomically $ do + when (isJust $ spAckedBy sp) $ do + modifyTVar' cSentPackets $ (:) sp + { spTime = now + , spRetryCount = spRetryCount sp + 1 + } + writeFlow gDataFlow (cAddress, spData sp) + +sendBytes :: GlobalState addr -> Connection addr -> ByteString -> Maybe (TransportHeaderItem -> Bool) -> IO () +sendBytes gs conn bs ackedBy = resendBytes gs conn + SentPacket + { spTime = undefined + , spRetryCount = -1 + , spAckedBy = ackedBy + , spData = bs + } processOutgoing :: forall addr. GlobalState addr -> STM (IO ()) processOutgoing gs@GlobalState {..} = do - let sendBytes :: Connection addr -> SentPacket -> IO () - sendBytes Connection {..} sp = do - now <- getTime MonotonicRaw - atomically $ do - when (not $ null $ spAckedBy sp) $ do - modifyTVar' cSentPackets $ (:) sp - { spTime = now - , spRetryCount = spRetryCount sp + 1 - } - writeFlow gDataFlow (cAddress, spData sp) let sendNextPacket :: Connection addr -> STM (IO ()) sendNextPacket conn@Connection {..} = do @@ -304,16 +415,20 @@ processOutgoing gs@GlobalState {..} = do | isJust mbch = readTQueue cSecureOutQueue | otherwise = retry - (secure, packet@(TransportPacket header content), ackedBy) <- + (secure, packet@(TransportPacket (TransportHeader hitems) content), ackedBy) <- checkOutstanding <|> readFlow cDataInternal + when (isNothing mbch && secure) $ do + writeTQueue cSecureOutQueue (secure, packet, ackedBy) + + header <- readTVar cChannel >>= return . TransportHeader . \case + ChannelCookieReceived cookie -> CookieEcho cookie : ProtocolVersion protocolVersion : hitems + _ -> hitems + let plain = BL.concat $ (serializeObject $ transportToObject gStorage header) : map lazyLoadBytes content - when (isNothing mbch && secure) $ do - writeTQueue cSecureOutQueue (secure, packet, ackedBy) - return $ do mbs <- case mbch of Just ch -> do @@ -325,13 +440,7 @@ processOutgoing gs@GlobalState {..} = do | otherwise -> return $ Just $ BL.toStrict plain case mbs of - Just bs -> do - sendBytes conn $ SentPacket - { spTime = undefined - , spRetryCount = -1 - , spAckedBy = ackedBy - , spData = bs - } + Just bs -> sendBytes gs conn bs $ guard (not $ null ackedBy) >> Just (`elem` ackedBy) Nothing -> return () let retransmitPacket :: Connection addr -> STM (IO ()) @@ -350,26 +459,38 @@ processOutgoing gs@GlobalState {..} = do else retry else do writeTVar cSentPackets rest - return $ sendBytes conn sp + return $ resendBytes gs conn sp let handleControlRequests = readFlow gControlFlow >>= \case RequestConnection addr -> do - _ <- getConnection gs addr - identity <- readTVar gIdentity - let packet = BL.toStrict $ serializeObject $ transportToObject gStorage $ TransportHeader $ - [ AnnounceSelf $ refDigest $ storedRef $ idData identity - ] ++ map ProtocolVersion protocolVersions - writeFlow gDataFlow (addr, packet) - return $ return () + conn@Connection {..} <- getConnection gs addr + identity <- fst <$> readTVar gIdentity + readTVar cChannel >>= \case + ChannelNone -> do + let packet = BL.toStrict $ BL.concat + [ serializeObject $ transportToObject gStorage $ TransportHeader $ + [ Initiation $ refDigest gInitConfig + , AnnounceSelf $ refDigest $ storedRef $ idData identity + ] ++ map ProtocolVersion protocolVersions + , lazyLoadBytes gInitConfig + ] + writeTVar cChannel ChannelCookieWait + return $ sendBytes gs conn packet $ Just $ \case CookieSet {} -> True; _ -> False + _ -> return $ return () SendAnnounce addr -> do - identity <- readTVar gIdentity + identity <- fst <$> readTVar gIdentity let packet = BL.toStrict $ serializeObject $ transportToObject gStorage $ TransportHeader $ [ AnnounceSelf $ refDigest $ storedRef $ idData identity ] ++ map ProtocolVersion protocolVersions writeFlow gDataFlow (addr, packet) return $ return () + UpdateSelfIdentity nid -> do + (cur, past) <- readTVar gIdentity + writeTVar gIdentity (nid, cur : past) + return $ return () + conns <- readTVar gConnections msum $ concat $ [ map retransmitPacket conns @@ -379,4 +500,4 @@ processOutgoing gs@GlobalState {..} = do processAcknowledgements :: GlobalState addr -> Connection addr -> [TransportHeaderItem] -> STM () processAcknowledgements GlobalState {} Connection {..} = mapM_ $ \hitem -> do - modifyTVar' cSentPackets $ filter $ (hitem `notElem`) . spAckedBy + modifyTVar' cSentPackets $ filter $ \sp -> not (fromJust (spAckedBy sp) hitem) -- cgit v1.2.3