From 26f7f5629ef453c12b311c439aacbded0889a63f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roman=20Smr=C5=BE?= Date: Mon, 22 Apr 2024 21:32:02 +0200 Subject: Network: wait for stream open ack before sending any data --- src/Erebos/Network/Protocol.hs | 67 ++++++++++++++++++++++++++++++------------ 1 file changed, 49 insertions(+), 18 deletions(-) (limited to 'src/Erebos') diff --git a/src/Erebos/Network/Protocol.hs b/src/Erebos/Network/Protocol.hs index f9bb53c..9ac9574 100644 --- a/src/Erebos/Network/Protocol.hs +++ b/src/Erebos/Network/Protocol.hs @@ -197,6 +197,7 @@ connAddWriteStream conn@Connection {..} = do let doInsert n (s@(n', _) : rest) | n == n' = fmap (s:) <$> doInsert (n + 1) rest doInsert n streams = do + sState <- newTVar StreamOpening (sFlowIn, sFlowOut) <- newFlow sNextSequence <- newTVar 0 let info = (n, Stream {..}) @@ -205,9 +206,12 @@ connAddWriteStream conn@Connection {..} = do writeTVar cOutStreams outStreams' let go = do - (reserved, msg) <- atomically $ (,) - <$> reservePacket conn - <*> readFlow (sFlowOut stream) + (reserved, msg) <- atomically $ do + readTVar (sState stream) >>= \case + StreamRunning -> return () + _ -> retry + (,) <$> reservePacket conn + <*> readFlow (sFlowOut stream) let (plain, cont) = case msg of StreamData {..} -> (stpData, True) StreamClosed {} -> (BC.empty, False) @@ -254,6 +258,7 @@ connAddReadStream Connection {..} streamNumber = do | streamNumber < n = fmap (s:) <$> doInsert rest | streamNumber == n = doInsert rest doInsert streams = do + sState <- newTVar StreamRunning (sFlowIn, sFlowOut) <- newFlow sNextSequence <- newTVar 0 let stream = Stream {..} @@ -267,11 +272,14 @@ type RawStreamReader = Flow StreamPacket Void type RawStreamWriter = Flow Void StreamPacket data Stream = Stream - { sFlowIn :: Flow Void StreamPacket + { sState :: TVar StreamState + , sFlowIn :: Flow Void StreamPacket , sFlowOut :: Flow StreamPacket Void , sNextSequence :: TVar Word64 } +data StreamState = StreamOpening | StreamRunning + data StreamPacket = StreamData { stpSequence :: Word64 @@ -281,6 +289,15 @@ data StreamPacket { stpSequence :: Word64 } +streamAccepted :: Connection addr -> Word8 -> IO () +streamAccepted Connection {..} snum = atomically $ do + (lookup snum <$> readTVar cOutStreams) >>= \case + Just Stream {..} -> do + modifyTVar' sState $ \case + StreamOpening -> StreamRunning + x -> x + Nothing -> return () + readStreamToList :: RawStreamReader -> IO (Word64, [(Word64, BC.ByteString)]) readStreamToList stream = readFlowIO stream >>= \case StreamData sq bytes -> fmap ((sq, bytes) :) <$> readStreamToList stream @@ -336,12 +353,14 @@ data ChannelState = ChannelNone data ReservedToSend = ReservedToSend { rsAckedBy :: Maybe (TransportHeaderItem -> Bool) + , rsOnAck :: IO () } data SentPacket = SentPacket { spTime :: TimeSpec , spRetryCount :: Int , spAckedBy :: Maybe (TransportHeaderItem -> Bool) + , spOnAck :: IO () , spData :: BC.ByteString } @@ -479,15 +498,17 @@ processIncoming gs@GlobalState {..} = do | hobj:content <- objs , Just header@(TransportHeader items) <- transportFromObject hobj -> processPacket gs (maybe (Left addr) Right mbconn) secure (TransportPacket header content) >>= \case - Just (conn@Connection {..}, mbup) -> atomically $ do - case mbcounter of - Just counter | any isHeaderItemAcknowledged items -> - modifyTVar' cToAcknowledge (fromIntegral counter :) - _ -> return () - processAcknowledgements gs conn items - case mbup of - Just up -> putTMVar gNextUp (conn, (secure, up)) - Nothing -> return () + Just (conn@Connection {..}, mbup) -> do + ioAfter <- atomically $ do + case mbcounter of + Just counter | any isHeaderItemAcknowledged items -> + modifyTVar' cToAcknowledge (fromIntegral counter :) + _ -> return () + case mbup of + Just up -> putTMVar gNextUp (conn, (secure, up)) + Nothing -> return () + processAcknowledgements gs conn items + ioAfter Nothing -> return () | otherwise -> atomically $ do @@ -641,7 +662,7 @@ reservePacket Connection {..} = do retry writeTVar cReservedPackets $ reserved + 1 - return $ ReservedToSend Nothing + return $ ReservedToSend Nothing (return ()) resendBytes :: Connection addr -> Maybe ReservedToSend -> SentPacket -> IO () resendBytes Connection {..} reserved sp = do @@ -664,6 +685,7 @@ sendBytes conn reserved bs = resendBytes conn reserved { spTime = undefined , spRetryCount = -1 , spAckedBy = rsAckedBy =<< reserved + , spOnAck = maybe (return ()) rsOnAck reserved , spData = bs } @@ -710,6 +732,10 @@ processOutgoing gs@GlobalState {..} = do (serializeObject $ transportToObject gStorage header) : map lazyLoadBytes content + let onAck = case catMaybes (map (\case StreamOpen n -> Just n; _ -> Nothing) hitems) of + [] -> return () + xs -> sequence_ $ map (streamAccepted conn) xs + mbs <- case mbch of Just ch -> do runExceptT (channelEncrypt ch $ BL.toStrict $ 0x00 `BL.cons` plain) >>= \case @@ -723,7 +749,10 @@ processOutgoing gs@GlobalState {..} = do case mbs of Just (bs, ackedBy) -> do - let mbReserved' = (\rs -> rs { rsAckedBy = guard (not $ null ackedBy) >> Just (`elem` ackedBy) }) <$> mbReserved + let mbReserved' = (\rs -> rs + { rsAckedBy = guard (not $ null ackedBy) >> Just (`elem` ackedBy) + , rsOnAck = rsOnAck rs >> onAck + }) <$> mbReserved sendBytes conn mbReserved' bs Nothing -> return () @@ -785,6 +814,8 @@ processOutgoing gs@GlobalState {..} = do , [ handleControlRequests ] ] -processAcknowledgements :: GlobalState addr -> Connection addr -> [TransportHeaderItem] -> STM () -processAcknowledgements GlobalState {} Connection {..} = mapM_ $ \hitem -> do - modifyTVar' cSentPackets $ filter $ \sp -> not (fromJust (spAckedBy sp) hitem) +processAcknowledgements :: GlobalState addr -> Connection addr -> [TransportHeaderItem] -> STM (IO ()) +processAcknowledgements GlobalState {} Connection {..} header = do + (acked, notAcked) <- partition (\sp -> any (fromJust (spAckedBy sp)) header) <$> readTVar cSentPackets + writeTVar cSentPackets notAcked + return $ sequence_ $ map spOnAck acked -- cgit v1.2.3