diff options
Diffstat (limited to 'src/Erebos')
| -rw-r--r-- | src/Erebos/Network/Protocol.hs | 67 | 
1 files changed, 49 insertions, 18 deletions
| diff --git a/src/Erebos/Network/Protocol.hs b/src/Erebos/Network/Protocol.hs index f9bb53c..9ac9574 100644 --- a/src/Erebos/Network/Protocol.hs +++ b/src/Erebos/Network/Protocol.hs @@ -197,6 +197,7 @@ connAddWriteStream conn@Connection {..} = do      let doInsert n (s@(n', _) : rest) | n == n' =              fmap (s:) <$> doInsert (n + 1) rest          doInsert n streams = do +            sState <- newTVar StreamOpening              (sFlowIn, sFlowOut) <- newFlow              sNextSequence <- newTVar 0              let info = (n, Stream {..}) @@ -205,9 +206,12 @@ connAddWriteStream conn@Connection {..} = do      writeTVar cOutStreams outStreams'      let go = do -            (reserved, msg) <- atomically $ (,) -                <$> reservePacket conn -                <*> readFlow (sFlowOut stream) +            (reserved, msg) <- atomically $ do +                readTVar (sState stream) >>= \case +                    StreamRunning -> return () +                    _             -> retry +                (,) <$> reservePacket conn +                    <*> readFlow (sFlowOut stream)              let (plain, cont) = case msg of                      StreamData {..} -> (stpData, True)                      StreamClosed {} -> (BC.empty, False) @@ -254,6 +258,7 @@ connAddReadStream Connection {..} streamNumber = do              | streamNumber <  n = fmap (s:) <$> doInsert rest              | streamNumber == n = doInsert rest          doInsert streams = do +            sState <- newTVar StreamRunning              (sFlowIn, sFlowOut) <- newFlow              sNextSequence <- newTVar 0              let stream = Stream {..} @@ -267,11 +272,14 @@ type RawStreamReader = Flow StreamPacket Void  type RawStreamWriter = Flow Void StreamPacket  data Stream = Stream -    { sFlowIn :: Flow Void StreamPacket +    { sState :: TVar StreamState +    , sFlowIn :: Flow Void StreamPacket      , sFlowOut :: Flow StreamPacket Void      , sNextSequence :: TVar Word64      } +data StreamState = StreamOpening | StreamRunning +  data StreamPacket      = StreamData          { stpSequence :: Word64 @@ -281,6 +289,15 @@ data StreamPacket          { stpSequence :: Word64          } +streamAccepted :: Connection addr -> Word8 -> IO () +streamAccepted Connection {..} snum = atomically $ do +    (lookup snum <$> readTVar cOutStreams) >>= \case +        Just Stream {..} -> do +            modifyTVar' sState $ \case +                StreamOpening -> StreamRunning +                x             -> x +        Nothing -> return () +  readStreamToList :: RawStreamReader -> IO (Word64, [(Word64, BC.ByteString)])  readStreamToList stream = readFlowIO stream >>= \case      StreamData sq bytes -> fmap ((sq, bytes) :) <$> readStreamToList stream @@ -336,12 +353,14 @@ data ChannelState = ChannelNone  data ReservedToSend = ReservedToSend      { rsAckedBy :: Maybe (TransportHeaderItem -> Bool) +    , rsOnAck :: IO ()      }  data SentPacket = SentPacket      { spTime :: TimeSpec      , spRetryCount :: Int      , spAckedBy :: Maybe (TransportHeaderItem -> Bool) +    , spOnAck :: IO ()      , spData :: BC.ByteString      } @@ -479,15 +498,17 @@ processIncoming gs@GlobalState {..} = do                  | hobj:content <- objs                  , Just header@(TransportHeader items) <- transportFromObject hobj                  -> processPacket gs (maybe (Left addr) Right mbconn) secure (TransportPacket header content) >>= \case -                    Just (conn@Connection {..}, mbup) -> atomically $ do -                        case mbcounter of -                            Just counter | any isHeaderItemAcknowledged items -> -                                modifyTVar' cToAcknowledge (fromIntegral counter :) -                            _ -> return () -                        processAcknowledgements gs conn items -                        case mbup of -                            Just up -> putTMVar gNextUp (conn, (secure, up)) -                            Nothing -> return () +                    Just (conn@Connection {..}, mbup) -> do +                        ioAfter <- atomically $ do +                            case mbcounter of +                                Just counter | any isHeaderItemAcknowledged items -> +                                    modifyTVar' cToAcknowledge (fromIntegral counter :) +                                _ -> return () +                            case mbup of +                                Just up -> putTMVar gNextUp (conn, (secure, up)) +                                Nothing -> return () +                            processAcknowledgements gs conn items +                        ioAfter                      Nothing -> return ()                  | otherwise -> atomically $ do @@ -641,7 +662,7 @@ reservePacket Connection {..} = do          retry      writeTVar cReservedPackets $ reserved + 1 -    return $ ReservedToSend Nothing +    return $ ReservedToSend Nothing (return ())  resendBytes :: Connection addr -> Maybe ReservedToSend -> SentPacket -> IO ()  resendBytes Connection {..} reserved sp = do @@ -664,6 +685,7 @@ sendBytes conn reserved bs = resendBytes conn reserved          { spTime = undefined          , spRetryCount = -1          , spAckedBy = rsAckedBy =<< reserved +        , spOnAck = maybe (return ()) rsOnAck reserved          , spData = bs          } @@ -710,6 +732,10 @@ processOutgoing gs@GlobalState {..} = do                          (serializeObject $ transportToObject gStorage header)                          : map lazyLoadBytes content +                let onAck = case catMaybes (map (\case StreamOpen n -> Just n; _ -> Nothing) hitems) of +                                 [] -> return () +                                 xs -> sequence_ $ map (streamAccepted conn) xs +                  mbs <- case mbch of                      Just ch -> do                          runExceptT (channelEncrypt ch $ BL.toStrict $ 0x00 `BL.cons` plain) >>= \case @@ -723,7 +749,10 @@ processOutgoing gs@GlobalState {..} = do                  case mbs of                      Just (bs, ackedBy) -> do -                        let mbReserved' = (\rs -> rs { rsAckedBy = guard (not $ null ackedBy) >> Just (`elem` ackedBy) }) <$> mbReserved +                        let mbReserved' = (\rs -> rs +                                { rsAckedBy = guard (not $ null ackedBy) >> Just (`elem` ackedBy) +                                , rsOnAck = rsOnAck rs >> onAck +                                }) <$> mbReserved                          sendBytes conn mbReserved' bs                      Nothing -> return () @@ -785,6 +814,8 @@ processOutgoing gs@GlobalState {..} = do          , [ handleControlRequests ]          ] -processAcknowledgements :: GlobalState addr -> Connection addr -> [TransportHeaderItem] -> STM () -processAcknowledgements GlobalState {} Connection {..} = mapM_ $ \hitem -> do -    modifyTVar' cSentPackets $ filter $ \sp -> not (fromJust (spAckedBy sp) hitem) +processAcknowledgements :: GlobalState addr -> Connection addr -> [TransportHeaderItem] -> STM (IO ()) +processAcknowledgements GlobalState {} Connection {..} header = do +    (acked, notAcked) <- partition (\sp -> any (fromJust (spAckedBy sp)) header) <$> readTVar cSentPackets +    writeTVar cSentPackets notAcked +    return $ sequence_ $ map spOnAck acked |