summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRoman Smrž <roman.smrz@seznam.cz>2024-04-22 21:32:02 +0200
committerRoman Smrž <roman.smrz@seznam.cz>2024-04-22 21:54:04 +0200
commit26f7f5629ef453c12b311c439aacbded0889a63f (patch)
tree3c81197d0da4dc77d3b81e0bec2aa075e03ac3a5
parentf195f4a165b573d92d975b3e489372e88708e687 (diff)
Network: wait for stream open ack before sending any data
-rw-r--r--src/Erebos/Network/Protocol.hs67
1 files changed, 49 insertions, 18 deletions
diff --git a/src/Erebos/Network/Protocol.hs b/src/Erebos/Network/Protocol.hs
index f9bb53c..9ac9574 100644
--- a/src/Erebos/Network/Protocol.hs
+++ b/src/Erebos/Network/Protocol.hs
@@ -197,6 +197,7 @@ connAddWriteStream conn@Connection {..} = do
let doInsert n (s@(n', _) : rest) | n == n' =
fmap (s:) <$> doInsert (n + 1) rest
doInsert n streams = do
+ sState <- newTVar StreamOpening
(sFlowIn, sFlowOut) <- newFlow
sNextSequence <- newTVar 0
let info = (n, Stream {..})
@@ -205,9 +206,12 @@ connAddWriteStream conn@Connection {..} = do
writeTVar cOutStreams outStreams'
let go = do
- (reserved, msg) <- atomically $ (,)
- <$> reservePacket conn
- <*> readFlow (sFlowOut stream)
+ (reserved, msg) <- atomically $ do
+ readTVar (sState stream) >>= \case
+ StreamRunning -> return ()
+ _ -> retry
+ (,) <$> reservePacket conn
+ <*> readFlow (sFlowOut stream)
let (plain, cont) = case msg of
StreamData {..} -> (stpData, True)
StreamClosed {} -> (BC.empty, False)
@@ -254,6 +258,7 @@ connAddReadStream Connection {..} streamNumber = do
| streamNumber < n = fmap (s:) <$> doInsert rest
| streamNumber == n = doInsert rest
doInsert streams = do
+ sState <- newTVar StreamRunning
(sFlowIn, sFlowOut) <- newFlow
sNextSequence <- newTVar 0
let stream = Stream {..}
@@ -267,11 +272,14 @@ type RawStreamReader = Flow StreamPacket Void
type RawStreamWriter = Flow Void StreamPacket
data Stream = Stream
- { sFlowIn :: Flow Void StreamPacket
+ { sState :: TVar StreamState
+ , sFlowIn :: Flow Void StreamPacket
, sFlowOut :: Flow StreamPacket Void
, sNextSequence :: TVar Word64
}
+data StreamState = StreamOpening | StreamRunning
+
data StreamPacket
= StreamData
{ stpSequence :: Word64
@@ -281,6 +289,15 @@ data StreamPacket
{ stpSequence :: Word64
}
+streamAccepted :: Connection addr -> Word8 -> IO ()
+streamAccepted Connection {..} snum = atomically $ do
+ (lookup snum <$> readTVar cOutStreams) >>= \case
+ Just Stream {..} -> do
+ modifyTVar' sState $ \case
+ StreamOpening -> StreamRunning
+ x -> x
+ Nothing -> return ()
+
readStreamToList :: RawStreamReader -> IO (Word64, [(Word64, BC.ByteString)])
readStreamToList stream = readFlowIO stream >>= \case
StreamData sq bytes -> fmap ((sq, bytes) :) <$> readStreamToList stream
@@ -336,12 +353,14 @@ data ChannelState = ChannelNone
data ReservedToSend = ReservedToSend
{ rsAckedBy :: Maybe (TransportHeaderItem -> Bool)
+ , rsOnAck :: IO ()
}
data SentPacket = SentPacket
{ spTime :: TimeSpec
, spRetryCount :: Int
, spAckedBy :: Maybe (TransportHeaderItem -> Bool)
+ , spOnAck :: IO ()
, spData :: BC.ByteString
}
@@ -479,15 +498,17 @@ processIncoming gs@GlobalState {..} = do
| hobj:content <- objs
, Just header@(TransportHeader items) <- transportFromObject hobj
-> processPacket gs (maybe (Left addr) Right mbconn) secure (TransportPacket header content) >>= \case
- Just (conn@Connection {..}, mbup) -> atomically $ do
- case mbcounter of
- Just counter | any isHeaderItemAcknowledged items ->
- modifyTVar' cToAcknowledge (fromIntegral counter :)
- _ -> return ()
- processAcknowledgements gs conn items
- case mbup of
- Just up -> putTMVar gNextUp (conn, (secure, up))
- Nothing -> return ()
+ Just (conn@Connection {..}, mbup) -> do
+ ioAfter <- atomically $ do
+ case mbcounter of
+ Just counter | any isHeaderItemAcknowledged items ->
+ modifyTVar' cToAcknowledge (fromIntegral counter :)
+ _ -> return ()
+ case mbup of
+ Just up -> putTMVar gNextUp (conn, (secure, up))
+ Nothing -> return ()
+ processAcknowledgements gs conn items
+ ioAfter
Nothing -> return ()
| otherwise -> atomically $ do
@@ -641,7 +662,7 @@ reservePacket Connection {..} = do
retry
writeTVar cReservedPackets $ reserved + 1
- return $ ReservedToSend Nothing
+ return $ ReservedToSend Nothing (return ())
resendBytes :: Connection addr -> Maybe ReservedToSend -> SentPacket -> IO ()
resendBytes Connection {..} reserved sp = do
@@ -664,6 +685,7 @@ sendBytes conn reserved bs = resendBytes conn reserved
{ spTime = undefined
, spRetryCount = -1
, spAckedBy = rsAckedBy =<< reserved
+ , spOnAck = maybe (return ()) rsOnAck reserved
, spData = bs
}
@@ -710,6 +732,10 @@ processOutgoing gs@GlobalState {..} = do
(serializeObject $ transportToObject gStorage header)
: map lazyLoadBytes content
+ let onAck = case catMaybes (map (\case StreamOpen n -> Just n; _ -> Nothing) hitems) of
+ [] -> return ()
+ xs -> sequence_ $ map (streamAccepted conn) xs
+
mbs <- case mbch of
Just ch -> do
runExceptT (channelEncrypt ch $ BL.toStrict $ 0x00 `BL.cons` plain) >>= \case
@@ -723,7 +749,10 @@ processOutgoing gs@GlobalState {..} = do
case mbs of
Just (bs, ackedBy) -> do
- let mbReserved' = (\rs -> rs { rsAckedBy = guard (not $ null ackedBy) >> Just (`elem` ackedBy) }) <$> mbReserved
+ let mbReserved' = (\rs -> rs
+ { rsAckedBy = guard (not $ null ackedBy) >> Just (`elem` ackedBy)
+ , rsOnAck = rsOnAck rs >> onAck
+ }) <$> mbReserved
sendBytes conn mbReserved' bs
Nothing -> return ()
@@ -785,6 +814,8 @@ processOutgoing gs@GlobalState {..} = do
, [ handleControlRequests ]
]
-processAcknowledgements :: GlobalState addr -> Connection addr -> [TransportHeaderItem] -> STM ()
-processAcknowledgements GlobalState {} Connection {..} = mapM_ $ \hitem -> do
- modifyTVar' cSentPackets $ filter $ \sp -> not (fromJust (spAckedBy sp) hitem)
+processAcknowledgements :: GlobalState addr -> Connection addr -> [TransportHeaderItem] -> STM (IO ())
+processAcknowledgements GlobalState {} Connection {..} header = do
+ (acked, notAcked) <- partition (\sp -> any (fromJust (spAckedBy sp)) header) <$> readTVar cSentPackets
+ writeTVar cSentPackets notAcked
+ return $ sequence_ $ map spOnAck acked