From 558ea4d565799aa2000af0b1fc6d159447c9868b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roman=20Smr=C5=BE?= Date: Sat, 19 Aug 2023 12:44:35 +0200 Subject: Network: connection initiation with cookie --- src/Network.hs | 98 +++++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 62 insertions(+), 36 deletions(-) (limited to 'src/Network.hs') diff --git a/src/Network.hs b/src/Network.hs index 071a0b0..6685045 100644 --- a/src/Network.hs +++ b/src/Network.hs @@ -39,7 +39,7 @@ import Foreign.Storable import GHC.Conc.Sync (unsafeIOToSTM) -import Network.Socket +import Network.Socket hiding (ControlMessage) import qualified Network.Socket.ByteString as S import Channel @@ -67,7 +67,7 @@ data Server = Server , serverThreads :: MVar [ThreadId] , serverSocket :: MVar Socket , serverRawPath :: SymFlow (PeerAddress, BC.ByteString) - , serverControlFlow :: Flow (Connection PeerAddress) (ControlRequest PeerAddress) + , serverControlFlow :: Flow (ControlMessage PeerAddress) (ControlRequest PeerAddress) , serverDataResponse :: TQueue (Peer, Maybe PartialRef) , serverIOActions :: TQueue (ExceptT String IO ()) , serverServices :: [SomeService] @@ -238,7 +238,9 @@ startServer opt serverOrigHead logd' serverServices = do let idt = headLocalIdentity h changedId <- modifyMVar serverIdentity_ $ \cur -> return (idt, cur /= idt) - when changedId $ announceUpdate idt + when changedId $ do + writeFlowIO serverControlFlow $ UpdateSelfIdentity idt + announceUpdate idt forM_ serverServices $ \(SomeService service _) -> do forM_ (serviceStorageWatchers service) $ \(SomeStorageWatcher sel act) -> do @@ -260,27 +262,44 @@ startServer opt serverOrigHead logd' serverServices = do PeerIceSession ice -> iceSend ice msg forkServerThread server $ forever $ do - conn <- readFlowIO serverControlFlow - let paddr = connAddress conn - peer <- modifyMVar serverPeers $ \pvalue -> do - case M.lookup paddr pvalue of - Just peer -> return (pvalue, peer) - Nothing -> do - peer <- mkPeer server paddr - return (M.insert paddr peer pvalue, peer) - - atomically $ do - readTVar (peerConnection peer) >>= \case - Left packets -> writeFlowBulk (connData conn) $ reverse packets - Right _ -> return () - writeTVar (peerConnection peer) (Right conn) - - forkServerThread server $ forever $ do - (secure, TransportPacket header objs) <- readFlowIO $ connData conn - prefs <- forM objs $ storeObject $ peerInStorage peer - identity <- readMVar serverIdentity_ - let svcs = map someServiceID serverServices - handlePacket identity secure peer chanSvc svcs header prefs + readFlowIO serverControlFlow >>= \case + NewConnection conn mbpid -> do + let paddr = connAddress conn + peer <- modifyMVar serverPeers $ \pvalue -> do + case M.lookup paddr pvalue of + Just peer -> return (pvalue, peer) + Nothing -> do + peer <- mkPeer server paddr + return (M.insert paddr peer pvalue, peer) + + atomically $ do + readTVar (peerConnection peer) >>= \case + Left packets -> writeFlowBulk (connData conn) $ reverse packets + Right _ -> return () + writeTVar (peerConnection peer) (Right conn) + + forkServerThread server $ forever $ do + (secure, TransportPacket header objs) <- readFlowIO $ connData conn + prefs <- forM objs $ storeObject $ peerInStorage peer + identity <- readMVar serverIdentity_ + let svcs = map someServiceID serverServices + handlePacket identity secure peer chanSvc svcs header prefs + + case mbpid of + Just dgst -> do + identity <- readMVar serverIdentity_ + atomically $ runPacketHandler False peer $ do + wref <- newWaitingRef dgst $ handleIdentityAnnounce identity peer + readTVarP (peerIdentityVar peer) >>= \case + PeerIdentityUnknown idwait -> do + addHeader $ AnnounceSelf $ refDigest $ storedRef $ idData identity + writeTVarP (peerIdentityVar peer) $ PeerIdentityRef wref idwait + liftSTM $ writeTChan serverChanPeer peer + _ -> return () + Nothing -> return () + + ReceivedAnnounce addr _ -> do + void $ serverPeer' server addr erebosNetworkProtocol (headLocalIdentity serverOrigHead) logd protocolRawPath protocolControlFlow @@ -343,6 +362,17 @@ newtype PacketHandler a = PacketHandler { unPacketHandler :: StateT PacketHandle instance MonadFail PacketHandler where fail = throwError +runPacketHandler :: Bool -> Peer -> PacketHandler () -> STM () +runPacketHandler secure peer@Peer {..} act = do + let logd = writeTQueue $ serverErrorLog peerServer_ + runExceptT (flip execStateT (PacketHandlerState peer [] [] []) $ unPacketHandler act) >>= \case + Left err -> do + logd $ "Error in handling packet from " ++ show peerAddress ++ ": " ++ err + Right ph -> do + when (not $ null $ phHead ph) $ do + let packet = TransportPacket (TransportHeader $ phHead ph) (phBody ph) + sendToPeerS' secure peer (phAckedBy ph) packet + liftSTM :: STM a -> PacketHandler a liftSTM = PacketHandler . lift . lift @@ -392,7 +422,7 @@ handlePacket identity secure peer chanSvc svcs (TransportHeader headers) prefs = _ -> [] ] - res <- runExceptT $ flip execStateT (PacketHandlerState peer [] [] []) $ unPacketHandler $ do + runPacketHandler secure peer $ do let logd = liftSTM . writeTQueue (serverErrorLog server) forM_ headers $ \case Acknowledged dgst -> do @@ -449,6 +479,8 @@ handlePacket identity secure peer chanSvc svcs (TransportHeader headers) prefs = liftSTM (getPeerChannel peer) >>= \case ChannelNone {} -> process + ChannelCookieWait {} -> process + ChannelCookieReceived {} -> process ChannelOurRequest our | dgst < refDigest (storedRef our) -> process | otherwise -> reject ChannelPeerRequest {} -> process @@ -458,8 +490,11 @@ handlePacket identity secure peer chanSvc svcs (TransportHeader headers) prefs = TrChannelAccept dgst -> do let process = do handleChannelAccept identity $ partialRefFromDigest (peerInStorage peer) dgst + reject = addHeader $ Rejected dgst liftSTM (getPeerChannel peer) >>= \case - ChannelNone {} -> process + ChannelNone {} -> reject + ChannelCookieWait {} -> reject + ChannelCookieReceived {} -> reject ChannelOurRequest {} -> process ChannelPeerRequest {} -> process ChannelOurAccept our _ | dgst < refDigest (storedRef our) -> process @@ -482,15 +517,6 @@ handlePacket identity secure peer chanSvc svcs (TransportHeader headers) prefs = _ -> return () - let logd = writeTQueue (serverErrorLog server) - case res of - Left err -> do - logd $ "Error in handling packet from " ++ show (peerAddress peer) ++ ": " ++ err - Right ph -> do - when (not $ null $ phHead ph) $ do - let packet = TransportPacket (TransportHeader $ phHead ph) (phBody ph) - sendToPeerS' secure peer (phAckedBy ph) packet - withPeerIdentity :: MonadIO m => Peer -> (UnifiedIdentity -> ExceptT String IO ()) -> m () withPeerIdentity peer act = liftIO $ atomically $ readTVar (peerIdentityVar peer) >>= \case @@ -509,7 +535,7 @@ setupChannel identity peer upid = do ] liftIO $ atomically $ do getPeerChannel peer >>= \case - ChannelNone -> do + ChannelCookieReceived {} -> do sendToPeerPlain peer [ Acknowledged reqref, Rejected reqref ] $ TransportPacket (TransportHeader hitems) [storedRef req] setPeerChannel peer $ ChannelOurRequest req -- cgit v1.2.3