From 61f745b3c57e4fe78bea8f8a7a48923b364dd874 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roman=20Smr=C5=BE?= Date: Tue, 30 Apr 2024 21:51:58 +0200 Subject: Network: fail when no free stream is available --- src/Erebos/Network/Protocol.hs | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) (limited to 'src/Erebos/Network') diff --git a/src/Erebos/Network/Protocol.hs b/src/Erebos/Network/Protocol.hs index 9ac9574..b79b105 100644 --- a/src/Erebos/Network/Protocol.hs +++ b/src/Erebos/Network/Protocol.hs @@ -36,6 +36,7 @@ import Control.Concurrent.Async import Control.Concurrent.STM import Control.Monad import Control.Monad.Except +import Control.Monad.Trans import Data.Bits import Data.ByteString (ByteString) @@ -189,23 +190,27 @@ connSetChannel :: Connection addr -> ChannelState -> STM () connSetChannel Connection {..} ch = do writeTVar cChannel ch -connAddWriteStream :: Connection addr -> STM (TransportHeaderItem, RawStreamWriter, IO ()) +connAddWriteStream :: Connection addr -> STM (Either String (TransportHeaderItem, RawStreamWriter, IO ())) connAddWriteStream conn@Connection {..} = do - let GlobalState {..} = cGlobalState - outStreams <- readTVar cOutStreams - let doInsert n (s@(n', _) : rest) | n == n' = + let doInsert :: Word8 -> [(Word8, Stream)] -> ExceptT String STM ((Word8, Stream), [(Word8, Stream)]) + doInsert n (s@(n', _) : rest) | n == n' = fmap (s:) <$> doInsert (n + 1) rest - doInsert n streams = do + doInsert n streams | n < 63 = lift $ do sState <- newTVar StreamOpening (sFlowIn, sFlowOut) <- newFlow sNextSequence <- newTVar 0 let info = (n, Stream {..}) return (info, info : streams) - ((streamNumber, stream), outStreams') <- doInsert 1 outStreams - writeTVar cOutStreams outStreams' + doInsert _ _ = throwError "all outbound streams in use" + + runExceptT $ do + ((streamNumber, stream), outStreams') <- doInsert 1 outStreams + lift $ writeTVar cOutStreams outStreams' + return (StreamOpen streamNumber, sFlowIn stream, go cGlobalState streamNumber stream) - let go = do + where + go gs@GlobalState {..} streamNumber stream = do (reserved, msg) <- atomically $ do readTVar (sState stream) >>= \case StreamRunning -> return () @@ -247,9 +252,7 @@ connAddWriteStream conn@Connection {..} = do sendBytes conn mbReserved' bs Nothing -> return () - when cont go - - return (StreamOpen streamNumber, sFlowIn stream, go) + when cont $ go gs streamNumber stream connAddReadStream :: Connection addr -> Word8 -> STM RawStreamReader connAddReadStream Connection {..} streamNumber = do -- cgit v1.2.3