diff options
Diffstat (limited to 'src/Erebos/Network')
| -rw-r--r-- | src/Erebos/Network/Protocol.hs | 190 | 
1 files changed, 177 insertions, 13 deletions
| diff --git a/src/Erebos/Network/Protocol.hs b/src/Erebos/Network/Protocol.hs index d7253e3..e5eb652 100644 --- a/src/Erebos/Network/Protocol.hs +++ b/src/Erebos/Network/Protocol.hs @@ -20,6 +20,13 @@ module Erebos.Network.Protocol (      connGetChannel,      connSetChannel, +    RawStreamReader, RawStreamWriter, +    connAddWriteStream, +    connAddReadStream, +    readStreamToList, +    readObjectsFromStream, +    writeByteStringToStream, +      module Erebos.Flow,  ) where @@ -39,6 +46,8 @@ import Data.List  import Data.Maybe  import Data.Text (Text)  import Data.Text qualified as T +import Data.Void +import Data.Word  import System.Clock @@ -77,6 +86,7 @@ data TransportHeaderItem      | TrChannelAccept RefDigest      | ServiceType ServiceID      | ServiceRef RefDigest +    | StreamOpen Word8      deriving (Eq, Show)  newtype Cookie = Cookie ByteString @@ -111,6 +121,7 @@ transportToObject st (TransportHeader items) = Rec $ map single items                TrChannelAccept dgst -> (BC.pack "CAC", RecRef $ partialRefFromDigest st dgst)                ServiceType stype -> (BC.pack "SVT", RecUUID $ toUUID stype)                ServiceRef dgst -> (BC.pack "SVR", RecRef $ partialRefFromDigest st dgst) +              StreamOpen num -> (BC.pack "STO", RecInt $ fromIntegral num)  transportFromObject :: PartialObject -> Maybe TransportHeader  transportFromObject (Rec items) = case catMaybes $ map single items of @@ -132,6 +143,7 @@ transportFromObject (Rec items) = case catMaybes $ map single items of                | name == BC.pack "CAC", RecRef ref <- content -> Just $ TrChannelAccept $ refDigest ref                | name == BC.pack "SVT", RecUUID uuid <- content -> Just $ ServiceType $ fromUUID uuid                | name == BC.pack "SVR", RecRef ref <- content -> Just $ ServiceRef $ refDigest ref +              | name == BC.pack "STO", RecInt num <- content -> Just $ StreamOpen $ fromIntegral num                | otherwise -> Nothing  transportFromObject _ = Nothing @@ -150,13 +162,16 @@ data GlobalState addr = (Eq addr, Show addr) => GlobalState      }  data Connection addr = Connection -    { cAddress :: addr +    { cGlobalState :: GlobalState addr +    , cAddress :: addr      , cDataUp :: Flow (Bool, TransportPacket PartialObject) (Bool, TransportPacket Ref, [TransportHeaderItem])      , cDataInternal :: Flow (Bool, TransportPacket Ref, [TransportHeaderItem]) (Bool, TransportPacket PartialObject)      , cChannel :: TVar ChannelState      , cSecureOutQueue :: TQueue (Bool, TransportPacket Ref, [TransportHeaderItem])      , cSentPackets :: TVar [SentPacket]      , cToAcknowledge :: TVar [Integer] +    , cInStreams :: TVar [(Word8, Stream)] +    , cOutStreams :: TVar [(Word8, Stream)]      }  connAddress :: Connection addr -> addr @@ -172,6 +187,123 @@ connSetChannel :: Connection addr -> ChannelState -> STM ()  connSetChannel Connection {..} ch = do      writeTVar cChannel ch +connAddWriteStream :: Connection addr -> STM (TransportHeaderItem, RawStreamWriter, IO ()) +connAddWriteStream conn@Connection {..} = do +    let GlobalState {..} = cGlobalState + +    outStreams <- readTVar cOutStreams +    let doInsert n (s@(n', _) : rest) | n == n' = +            fmap (s:) <$> doInsert (n + 1) rest +        doInsert n streams = do +            (sFlowIn, sFlowOut) <- newFlow +            sNextSequence <- newTVar 0 +            let info = (n, Stream {..}) +            return (info, info : streams) +    ((streamNumber, stream), outStreams') <- doInsert 1 outStreams +    writeTVar cOutStreams outStreams' + +    let go = do +            msg <- atomically $ readFlow (sFlowOut stream) +            let (plain, cont) = case msg of +                    StreamData {..} -> (stpData, True) +                    StreamClosed {} -> (BC.empty, False) +                    -- TODO: send channel closed only after delivering all previous data packets +                    -- TODO: free channel number after delivering stream closed +            let secure = True +                plainAckedBy = [] + +            mbch <- atomically (readTVar cChannel) >>= return . \case +                ChannelEstablished ch   -> Just ch +                ChannelOurAccept _ _ ch -> Just ch +                _                       -> Nothing + +            mbs <- case mbch of +                Just ch -> do +                    runExceptT (channelEncrypt ch $ B.concat +                            [ B.singleton streamNumber +                            , B.singleton (fromIntegral (stpSequence msg) :: Word8) +                            , plain +                            ] ) >>= \case +                        Right (ctext, counter) -> do +                            let isAcked = True +                            return $ Just (0x80 `B.cons` ctext, if isAcked then [ AcknowledgedSingle $ fromIntegral counter ] else []) +                        Left err -> do atomically $ gLog $ "Failed to encrypt data: " ++ err +                                       return Nothing +                Nothing | secure    -> return Nothing +                        | otherwise -> return $ Just (plain, plainAckedBy) + +            case mbs of +                Just (bs, ackedBy) -> sendBytes conn bs $ guard (not $ null ackedBy) >> Just (`elem` ackedBy) +                Nothing -> return () + +            when cont go + +    return (StreamOpen streamNumber, sFlowIn stream, go) + +connAddReadStream :: Connection addr -> Word8 -> STM RawStreamReader +connAddReadStream Connection {..} streamNumber = do +    inStreams <- readTVar cInStreams +    let doInsert (s@(n, _) : rest) +            | streamNumber <  n = fmap (s:) <$> doInsert rest +            | streamNumber == n = doInsert rest +        doInsert streams = do +            (sFlowIn, sFlowOut) <- newFlow +            sNextSequence <- newTVar 0 +            let stream = Stream {..} +            return (stream, (streamNumber, stream) : streams) +    (stream, inStreams') <- doInsert inStreams +    writeTVar cInStreams inStreams' +    return $ sFlowOut stream + + +type RawStreamReader = Flow StreamPacket Void +type RawStreamWriter = Flow Void StreamPacket + +data Stream = Stream +    { sFlowIn :: Flow Void StreamPacket +    , sFlowOut :: Flow StreamPacket Void +    , sNextSequence :: TVar Word64 +    } + +data StreamPacket +    = StreamData +        { stpSequence :: Word64 +        , stpData :: BC.ByteString +        } +    | StreamClosed +        { stpSequence :: Word64 +        } + +readStreamToList :: RawStreamReader -> IO (Word64, [(Word64, BC.ByteString)]) +readStreamToList stream = readFlowIO stream >>= \case +    StreamData sq bytes -> fmap ((sq, bytes) :) <$> readStreamToList stream +    StreamClosed sqEnd  -> return (sqEnd, []) + +readObjectsFromStream :: PartialStorage -> RawStreamReader -> IO (Except String [PartialObject]) +readObjectsFromStream st stream = do +    (seqEnd, list) <- readStreamToList stream +    print (seqEnd, length list, list) +    let validate s ((s', bytes) : rest) +            | s == s'   = (bytes : ) <$> validate (s + 1) rest +            | s >  s'   = validate s rest +            | otherwise = throwError "missing object chunk" +        validate s [] +            | s == seqEnd = return [] +            | otherwise = throwError "content length mismatch" +    return $ do +        content <- BL.fromChunks <$> validate 0 list +        deserializeObjects st content + +writeByteStringToStream :: RawStreamWriter -> BL.ByteString -> IO () +writeByteStringToStream stream = go 0 +  where +    go seqNum bstr +        | BL.null bstr = writeFlowIO stream $ StreamClosed seqNum +        | otherwise    = do +            let (cur, rest) = BL.splitAt 500 bstr -- TODO: MTU +            writeFlowIO stream $ StreamData seqNum (BL.toStrict cur) +            go (seqNum + 1) rest +  data WaitingRef = WaitingRef      { wrefStorage :: Storage @@ -260,7 +392,7 @@ findConnection GlobalState {..} addr = do      find ((addr==) . cAddress) <$> readTVar gConnections  newConnection :: GlobalState addr -> addr -> STM (Connection addr) -newConnection GlobalState {..} addr = do +newConnection cGlobalState@GlobalState {..} addr = do      conns <- readTVar gConnections      let cAddress = addr @@ -269,6 +401,8 @@ newConnection GlobalState {..} addr = do      cSecureOutQueue <- newTQueue      cSentPackets <- newTVar []      cToAcknowledge <- newTVar [] +    cInStreams <- newTVar [] +    cOutStreams <- newTVar []      let conn = Connection {..}      writeTVar gConnections (conn : conns) @@ -306,24 +440,30 @@ processIncoming gs@GlobalState {..} = do                          case B.uncons dec of                              Just (0x00, content) -> do                                  objs <- deserialize content -                                return (True, objs, Just counter) +                                return $ Left (True, objs, Just counter) + +                            Just (snum, dec') +                                | snum < 64 +                                , Just (seq8, content) <- B.uncons dec' +                                -> do +                                    return $ Right (snum, seq8, content, counter)                              Just (_, _) -> do -                                throwError "streams not implemented" +                                throwError "unexpected stream header"                              Nothing -> do                                  throwError "empty decrypted content"                      | b .&. 0xE0 == 0x60 -> do                          objs <- deserialize msg -                        return (False, objs, Nothing) +                        return $ Left (False, objs, Nothing)                      | otherwise -> throwError "invalid packet"                  Nothing -> throwError "empty packet"          runExceptT parse >>= \case -            Right (secure, objs, mbcounter) +            Right (Left (secure, objs, mbcounter))                  | hobj:content <- objs                  , Just header@(TransportHeader items) <- transportFromObject hobj                  -> processPacket gs (maybe (Left addr) Right mbconn) secure (TransportPacket header content) >>= \case @@ -342,6 +482,29 @@ processIncoming gs@GlobalState {..} = do                        gLog $ show addr ++ ": invalid objects"                        gLog $ show objs +            Right (Right (snum, seq8, content, counter)) +                | Just Connection {..} <- mbconn +                -> atomically $ do +                    (lookup snum <$> readTVar cInStreams) >>= \case +                        Nothing -> +                            gLog $ "unexpected stream number " ++ show snum + +                        Just Stream {..} -> do +                            expectedSequence <- readTVar sNextSequence +                            let seqFull = expectedSequence - 0x80 + fromIntegral (seq8 - fromIntegral expectedSequence + 0x80 :: Word8) +                            sdata <- if +                                | B.null content -> do +                                    modifyTVar' cInStreams $ filter ((/=snum) . fst) +                                    return $ StreamClosed seqFull +                                | otherwise -> do +                                    writeTVar sNextSequence $ max expectedSequence (seqFull + 1) +                                    return $ StreamData seqFull content +                            writeFlow sFlowIn sdata +                            modifyTVar' cToAcknowledge (fromIntegral counter :) + +                | otherwise -> do +                    atomically $ gLog $ show addr <> ": stream packet without connection" +              Left err -> do                  atomically $ gLog $ show addr <> ": failed to parse packet: " <> err @@ -455,8 +618,9 @@ createCookie GlobalState {} addr = return (Cookie $ BC.pack $ show addr)  verifyCookie :: GlobalState addr -> addr -> Cookie -> IO Bool  verifyCookie GlobalState {} addr (Cookie cookie) = return $ show addr == BC.unpack cookie -resendBytes :: GlobalState addr -> Connection addr -> SentPacket -> IO () -resendBytes GlobalState {..} Connection {..} sp = do +resendBytes :: Connection addr -> SentPacket -> IO () +resendBytes Connection {..} sp = do +    let GlobalState {..} = cGlobalState      now <- getTime MonotonicRaw      atomically $ do          when (isJust $ spAckedBy sp) $ do @@ -466,8 +630,8 @@ resendBytes GlobalState {..} Connection {..} sp = do                  }          writeFlow gDataFlow (cAddress, spData sp) -sendBytes :: GlobalState addr -> Connection addr -> ByteString -> Maybe (TransportHeaderItem -> Bool) -> IO () -sendBytes gs conn bs ackedBy = resendBytes gs conn +sendBytes :: Connection addr -> ByteString -> Maybe (TransportHeaderItem -> Bool) -> IO () +sendBytes conn bs ackedBy = resendBytes conn      SentPacket          { spTime = undefined          , spRetryCount = -1 @@ -526,7 +690,7 @@ processOutgoing gs@GlobalState {..} = do                              | otherwise -> return $ Just (BL.toStrict plain, plainAckedBy)                  case mbs of -                    Just (bs, ackedBy) -> sendBytes gs conn bs $ guard (not $ null ackedBy) >> Just (`elem` ackedBy) +                    Just (bs, ackedBy) -> sendBytes conn bs $ guard (not $ null ackedBy) >> Just (`elem` ackedBy)                      Nothing -> return ()      let retransmitPacket :: Connection addr -> STM (IO ()) @@ -545,7 +709,7 @@ processOutgoing gs@GlobalState {..} = do                     else retry                else do                  writeTVar cSentPackets rest -                return $ resendBytes gs conn sp +                return $ resendBytes conn sp      let handleControlRequests = readFlow gControlFlow >>= \case              RequestConnection addr -> do @@ -561,7 +725,7 @@ processOutgoing gs@GlobalState {..} = do                                  , lazyLoadBytes gInitConfig                                  ]                          writeTVar cChannel ChannelCookieWait -                        return $ sendBytes gs conn packet $ Just $ \case CookieSet {} -> True; _ -> False +                        return $ sendBytes conn packet $ Just $ \case CookieSet {} -> True; _ -> False                      _ -> return $ return ()              SendAnnounce addr -> do |