From 1d50a136067b59ea8fc6b95b8b22a911f603a605 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roman=20Smr=C5=BE?= Date: Sat, 4 May 2024 21:24:17 +0200 Subject: Network: wait with channel close after delivering all data --- src/Erebos/Network/Protocol.hs | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) (limited to 'src/Erebos/Network') diff --git a/src/Erebos/Network/Protocol.hs b/src/Erebos/Network/Protocol.hs index 59fcdca..3191f16 100644 --- a/src/Erebos/Network/Protocol.hs +++ b/src/Erebos/Network/Protocol.hs @@ -200,6 +200,7 @@ connAddWriteStream conn@Connection {..} = do sState <- newTVar StreamOpening (sFlowIn, sFlowOut) <- newFlow sNextSequence <- newTVar 0 + sWaitingForAck <- newTVar 0 let info = (n, Stream {..}) return (info, info : streams) doInsert _ _ = throwError "all outbound streams in use" @@ -217,10 +218,17 @@ connAddWriteStream conn@Connection {..} = do _ -> retry (,) <$> reservePacket conn <*> readFlow (sFlowOut stream) - let (plain, cont, onAck) = case msg of - StreamData {..} -> (stpData, True, return ()) - StreamClosed {} -> (BC.empty, False, streamClosed conn streamNumber) - -- TODO: send channel closed only after delivering all previous data packets + + (plain, cont, onAck) <- case msg of + StreamData {..} -> do + return (stpData, True, return ()) + StreamClosed {} -> do + atomically $ do + -- wait for ack on all sent stream data + waits <- readTVar (sWaitingForAck stream) + when (waits > 0) retry + return (BC.empty, False, streamClosed conn streamNumber) + let secure = True plainAckedBy = [] mbReserved = Just reserved @@ -247,9 +255,14 @@ connAddWriteStream conn@Connection {..} = do case mbs of Just (bs, ackedBy) -> do + atomically $ do + modifyTVar' (sWaitingForAck stream) (+ 1) let mbReserved' = (\rs -> rs { rsAckedBy = guard (not $ null ackedBy) >> Just (`elem` ackedBy) - , rsOnAck = rsOnAck rs >> onAck + , rsOnAck = do + rsOnAck rs + onAck + atomically $ modifyTVar' (sWaitingForAck stream) (subtract 1) }) <$> mbReserved sendBytes conn mbReserved' bs Nothing -> return () @@ -266,6 +279,7 @@ connAddReadStream Connection {..} streamNumber = do sState <- newTVar StreamRunning (sFlowIn, sFlowOut) <- newFlow sNextSequence <- newTVar 0 + sWaitingForAck <- newTVar 0 let stream = Stream {..} return (stream, (streamNumber, stream) : streams) (stream, inStreams') <- doInsert inStreams @@ -281,6 +295,7 @@ data Stream = Stream , sFlowIn :: Flow Void StreamPacket , sFlowOut :: Flow StreamPacket Void , sNextSequence :: TVar Word64 + , sWaitingForAck :: TVar Word64 } data StreamState = StreamOpening | StreamRunning -- cgit v1.2.3