From e0a5dbf7164517c79940da5691745cd281e8557e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roman=20Smr=C5=BE?= Date: Sat, 2 Mar 2024 21:01:37 +0100 Subject: Network streams Changelog: Implemented streams in network protocol --- src/Erebos/Network.hs | 51 ++++++++++- src/Erebos/Network/Protocol.hs | 190 ++++++++++++++++++++++++++++++++++++++--- test/network.test | 2 +- 3 files changed, 226 insertions(+), 17 deletions(-) diff --git a/src/Erebos/Network.hs b/src/Erebos/Network.hs index b26ada5..7c6a61e 100644 --- a/src/Erebos/Network.hs +++ b/src/Erebos/Network.hs @@ -30,7 +30,8 @@ import Control.Monad.Except import Control.Monad.Reader import Control.Monad.State -import qualified Data.ByteString.Char8 as BC +import Data.ByteString.Char8 qualified as BC +import Data.ByteString.Lazy qualified as BL import Data.Function import Data.IP qualified as IP import Data.List @@ -179,6 +180,11 @@ lookupServiceType (ServiceType stype : _) = Just stype lookupServiceType (_ : hs) = lookupServiceType hs lookupServiceType [] = Nothing +lookupNewStreams :: [TransportHeaderItem] -> [Word8] +lookupNewStreams (StreamOpen num : rest) = num : lookupNewStreams rest +lookupNewStreams (_ : rest) = lookupNewStreams rest +lookupNewStreams [] = [] + newWaitingRef :: RefDigest -> (Ref -> WaitingRefCallback) -> PacketHandler WaitingRef newWaitingRef dgst act = do @@ -421,6 +427,25 @@ addAckedBy hs = modify $ \ph -> ph { phAckedBy = foldr appendDistinct (phAckedBy addBody :: Ref -> PacketHandler () addBody r = modify $ \ph -> ph { phBody = r `appendDistinct` phBody ph } +openStream :: PacketHandler RawStreamWriter +openStream = do + Peer {..} <- gets phPeer + conn <- readTVarP peerConnection >>= \case + Right conn -> return conn + _ -> throwError "can't open stream without established connection" + (hdr, writer, handler) <- liftSTM $ connAddWriteStream conn + liftSTM $ writeTQueue (serverIOActions peerServer_) (liftIO $ forkServerThread peerServer_ handler) + addHeader hdr + return writer + +acceptStream :: Word8 -> PacketHandler RawStreamReader +acceptStream streamNumber = do + Peer {..} <- gets phPeer + conn <- readTVarP peerConnection >>= \case + Right conn -> return conn + _ -> throwError "can't accept stream without established connection" + liftSTM $ connAddReadStream conn streamNumber + appendDistinct :: Eq a => a -> [a] -> [a] appendDistinct x (y:ys) | x == y = y : ys | otherwise = y : appendDistinct x ys @@ -461,7 +486,15 @@ handlePacket identity secure peer chanSvc svcs (TransportHeader headers) prefs = partialRefFromDigest (peerInStorage peer) dgst addHeader $ DataResponse dgst addAckedBy [ Acknowledged dgst, Rejected dgst ] - addBody $ mref + let bytes = lazyLoadBytes mref + -- TODO: MTU + if (secure && BL.length bytes > 500) + then do + stream <- openStream + liftSTM $ writeTQueue (serverIOActions server) $ void $ liftIO $ forkIO $ do + writeByteStringToStream stream bytes + else do + addBody $ mref | otherwise -> do logd $ "unauthorized data request for " ++ show dgst addHeader $ Rejected dgst @@ -471,6 +504,18 @@ handlePacket identity secure peer chanSvc svcs (TransportHeader headers) prefs = when (not secure) $ do addHeader $ Acknowledged dgst liftSTM $ writeTQueue (serverDataResponse server) (peer, Just pref) + + | streamNumber : _ <- lookupNewStreams headers -> do + streamReader <- acceptStream streamNumber + liftSTM $ writeTQueue (serverIOActions server) $ void $ liftIO $ forkIO $ do + (runExcept <$> readObjectsFromStream (peerInStorage peer) streamReader) >>= \case + Left err -> atomically $ writeTQueue (serverErrorLog server) $ + "failed to receive object from stream: " <> err + Right objs -> do + forM_ objs $ \obj -> do + pref <- storeObject (peerInStorage peer) obj + atomically $ writeTQueue (serverDataResponse server) (peer, Just pref) + | otherwise -> throwError $ "mismatched data response " ++ show dgst AnnounceSelf dgst @@ -708,7 +753,7 @@ sendToPeerList peer parts = do ServiceReply (Right sx) use -> return $ Just (storedRef sx, use) ServiceFinally act -> act >> return Nothing let dgsts = map (refDigest . fst) srefs - let content = map fst $ filter snd srefs + let content = map fst $ filter (\(ref, use) -> use && BL.length (lazyLoadBytes ref) < 500) srefs -- TODO: MTU header = TransportHeader (ServiceType (serviceID $ head parts) : map ServiceRef dgsts) packet = TransportPacket header content ackedBy = concat [[ Acknowledged r, Rejected r, DataRequest r ] | r <- dgsts ] diff --git a/src/Erebos/Network/Protocol.hs b/src/Erebos/Network/Protocol.hs index d7253e3..e5eb652 100644 --- a/src/Erebos/Network/Protocol.hs +++ b/src/Erebos/Network/Protocol.hs @@ -20,6 +20,13 @@ module Erebos.Network.Protocol ( connGetChannel, connSetChannel, + RawStreamReader, RawStreamWriter, + connAddWriteStream, + connAddReadStream, + readStreamToList, + readObjectsFromStream, + writeByteStringToStream, + module Erebos.Flow, ) where @@ -39,6 +46,8 @@ import Data.List import Data.Maybe import Data.Text (Text) import Data.Text qualified as T +import Data.Void +import Data.Word import System.Clock @@ -77,6 +86,7 @@ data TransportHeaderItem | TrChannelAccept RefDigest | ServiceType ServiceID | ServiceRef RefDigest + | StreamOpen Word8 deriving (Eq, Show) newtype Cookie = Cookie ByteString @@ -111,6 +121,7 @@ transportToObject st (TransportHeader items) = Rec $ map single items TrChannelAccept dgst -> (BC.pack "CAC", RecRef $ partialRefFromDigest st dgst) ServiceType stype -> (BC.pack "SVT", RecUUID $ toUUID stype) ServiceRef dgst -> (BC.pack "SVR", RecRef $ partialRefFromDigest st dgst) + StreamOpen num -> (BC.pack "STO", RecInt $ fromIntegral num) transportFromObject :: PartialObject -> Maybe TransportHeader transportFromObject (Rec items) = case catMaybes $ map single items of @@ -132,6 +143,7 @@ transportFromObject (Rec items) = case catMaybes $ map single items of | name == BC.pack "CAC", RecRef ref <- content -> Just $ TrChannelAccept $ refDigest ref | name == BC.pack "SVT", RecUUID uuid <- content -> Just $ ServiceType $ fromUUID uuid | name == BC.pack "SVR", RecRef ref <- content -> Just $ ServiceRef $ refDigest ref + | name == BC.pack "STO", RecInt num <- content -> Just $ StreamOpen $ fromIntegral num | otherwise -> Nothing transportFromObject _ = Nothing @@ -150,13 +162,16 @@ data GlobalState addr = (Eq addr, Show addr) => GlobalState } data Connection addr = Connection - { cAddress :: addr + { cGlobalState :: GlobalState addr + , cAddress :: addr , cDataUp :: Flow (Bool, TransportPacket PartialObject) (Bool, TransportPacket Ref, [TransportHeaderItem]) , cDataInternal :: Flow (Bool, TransportPacket Ref, [TransportHeaderItem]) (Bool, TransportPacket PartialObject) , cChannel :: TVar ChannelState , cSecureOutQueue :: TQueue (Bool, TransportPacket Ref, [TransportHeaderItem]) , cSentPackets :: TVar [SentPacket] , cToAcknowledge :: TVar [Integer] + , cInStreams :: TVar [(Word8, Stream)] + , cOutStreams :: TVar [(Word8, Stream)] } connAddress :: Connection addr -> addr @@ -172,6 +187,123 @@ connSetChannel :: Connection addr -> ChannelState -> STM () connSetChannel Connection {..} ch = do writeTVar cChannel ch +connAddWriteStream :: Connection addr -> STM (TransportHeaderItem, RawStreamWriter, IO ()) +connAddWriteStream conn@Connection {..} = do + let GlobalState {..} = cGlobalState + + outStreams <- readTVar cOutStreams + let doInsert n (s@(n', _) : rest) | n == n' = + fmap (s:) <$> doInsert (n + 1) rest + doInsert n streams = do + (sFlowIn, sFlowOut) <- newFlow + sNextSequence <- newTVar 0 + let info = (n, Stream {..}) + return (info, info : streams) + ((streamNumber, stream), outStreams') <- doInsert 1 outStreams + writeTVar cOutStreams outStreams' + + let go = do + msg <- atomically $ readFlow (sFlowOut stream) + let (plain, cont) = case msg of + StreamData {..} -> (stpData, True) + StreamClosed {} -> (BC.empty, False) + -- TODO: send channel closed only after delivering all previous data packets + -- TODO: free channel number after delivering stream closed + let secure = True + plainAckedBy = [] + + mbch <- atomically (readTVar cChannel) >>= return . \case + ChannelEstablished ch -> Just ch + ChannelOurAccept _ _ ch -> Just ch + _ -> Nothing + + mbs <- case mbch of + Just ch -> do + runExceptT (channelEncrypt ch $ B.concat + [ B.singleton streamNumber + , B.singleton (fromIntegral (stpSequence msg) :: Word8) + , plain + ] ) >>= \case + Right (ctext, counter) -> do + let isAcked = True + return $ Just (0x80 `B.cons` ctext, if isAcked then [ AcknowledgedSingle $ fromIntegral counter ] else []) + Left err -> do atomically $ gLog $ "Failed to encrypt data: " ++ err + return Nothing + Nothing | secure -> return Nothing + | otherwise -> return $ Just (plain, plainAckedBy) + + case mbs of + Just (bs, ackedBy) -> sendBytes conn bs $ guard (not $ null ackedBy) >> Just (`elem` ackedBy) + Nothing -> return () + + when cont go + + return (StreamOpen streamNumber, sFlowIn stream, go) + +connAddReadStream :: Connection addr -> Word8 -> STM RawStreamReader +connAddReadStream Connection {..} streamNumber = do + inStreams <- readTVar cInStreams + let doInsert (s@(n, _) : rest) + | streamNumber < n = fmap (s:) <$> doInsert rest + | streamNumber == n = doInsert rest + doInsert streams = do + (sFlowIn, sFlowOut) <- newFlow + sNextSequence <- newTVar 0 + let stream = Stream {..} + return (stream, (streamNumber, stream) : streams) + (stream, inStreams') <- doInsert inStreams + writeTVar cInStreams inStreams' + return $ sFlowOut stream + + +type RawStreamReader = Flow StreamPacket Void +type RawStreamWriter = Flow Void StreamPacket + +data Stream = Stream + { sFlowIn :: Flow Void StreamPacket + , sFlowOut :: Flow StreamPacket Void + , sNextSequence :: TVar Word64 + } + +data StreamPacket + = StreamData + { stpSequence :: Word64 + , stpData :: BC.ByteString + } + | StreamClosed + { stpSequence :: Word64 + } + +readStreamToList :: RawStreamReader -> IO (Word64, [(Word64, BC.ByteString)]) +readStreamToList stream = readFlowIO stream >>= \case + StreamData sq bytes -> fmap ((sq, bytes) :) <$> readStreamToList stream + StreamClosed sqEnd -> return (sqEnd, []) + +readObjectsFromStream :: PartialStorage -> RawStreamReader -> IO (Except String [PartialObject]) +readObjectsFromStream st stream = do + (seqEnd, list) <- readStreamToList stream + print (seqEnd, length list, list) + let validate s ((s', bytes) : rest) + | s == s' = (bytes : ) <$> validate (s + 1) rest + | s > s' = validate s rest + | otherwise = throwError "missing object chunk" + validate s [] + | s == seqEnd = return [] + | otherwise = throwError "content length mismatch" + return $ do + content <- BL.fromChunks <$> validate 0 list + deserializeObjects st content + +writeByteStringToStream :: RawStreamWriter -> BL.ByteString -> IO () +writeByteStringToStream stream = go 0 + where + go seqNum bstr + | BL.null bstr = writeFlowIO stream $ StreamClosed seqNum + | otherwise = do + let (cur, rest) = BL.splitAt 500 bstr -- TODO: MTU + writeFlowIO stream $ StreamData seqNum (BL.toStrict cur) + go (seqNum + 1) rest + data WaitingRef = WaitingRef { wrefStorage :: Storage @@ -260,7 +392,7 @@ findConnection GlobalState {..} addr = do find ((addr==) . cAddress) <$> readTVar gConnections newConnection :: GlobalState addr -> addr -> STM (Connection addr) -newConnection GlobalState {..} addr = do +newConnection cGlobalState@GlobalState {..} addr = do conns <- readTVar gConnections let cAddress = addr @@ -269,6 +401,8 @@ newConnection GlobalState {..} addr = do cSecureOutQueue <- newTQueue cSentPackets <- newTVar [] cToAcknowledge <- newTVar [] + cInStreams <- newTVar [] + cOutStreams <- newTVar [] let conn = Connection {..} writeTVar gConnections (conn : conns) @@ -306,24 +440,30 @@ processIncoming gs@GlobalState {..} = do case B.uncons dec of Just (0x00, content) -> do objs <- deserialize content - return (True, objs, Just counter) + return $ Left (True, objs, Just counter) + + Just (snum, dec') + | snum < 64 + , Just (seq8, content) <- B.uncons dec' + -> do + return $ Right (snum, seq8, content, counter) Just (_, _) -> do - throwError "streams not implemented" + throwError "unexpected stream header" Nothing -> do throwError "empty decrypted content" | b .&. 0xE0 == 0x60 -> do objs <- deserialize msg - return (False, objs, Nothing) + return $ Left (False, objs, Nothing) | otherwise -> throwError "invalid packet" Nothing -> throwError "empty packet" runExceptT parse >>= \case - Right (secure, objs, mbcounter) + Right (Left (secure, objs, mbcounter)) | hobj:content <- objs , Just header@(TransportHeader items) <- transportFromObject hobj -> processPacket gs (maybe (Left addr) Right mbconn) secure (TransportPacket header content) >>= \case @@ -342,6 +482,29 @@ processIncoming gs@GlobalState {..} = do gLog $ show addr ++ ": invalid objects" gLog $ show objs + Right (Right (snum, seq8, content, counter)) + | Just Connection {..} <- mbconn + -> atomically $ do + (lookup snum <$> readTVar cInStreams) >>= \case + Nothing -> + gLog $ "unexpected stream number " ++ show snum + + Just Stream {..} -> do + expectedSequence <- readTVar sNextSequence + let seqFull = expectedSequence - 0x80 + fromIntegral (seq8 - fromIntegral expectedSequence + 0x80 :: Word8) + sdata <- if + | B.null content -> do + modifyTVar' cInStreams $ filter ((/=snum) . fst) + return $ StreamClosed seqFull + | otherwise -> do + writeTVar sNextSequence $ max expectedSequence (seqFull + 1) + return $ StreamData seqFull content + writeFlow sFlowIn sdata + modifyTVar' cToAcknowledge (fromIntegral counter :) + + | otherwise -> do + atomically $ gLog $ show addr <> ": stream packet without connection" + Left err -> do atomically $ gLog $ show addr <> ": failed to parse packet: " <> err @@ -455,8 +618,9 @@ createCookie GlobalState {} addr = return (Cookie $ BC.pack $ show addr) verifyCookie :: GlobalState addr -> addr -> Cookie -> IO Bool verifyCookie GlobalState {} addr (Cookie cookie) = return $ show addr == BC.unpack cookie -resendBytes :: GlobalState addr -> Connection addr -> SentPacket -> IO () -resendBytes GlobalState {..} Connection {..} sp = do +resendBytes :: Connection addr -> SentPacket -> IO () +resendBytes Connection {..} sp = do + let GlobalState {..} = cGlobalState now <- getTime MonotonicRaw atomically $ do when (isJust $ spAckedBy sp) $ do @@ -466,8 +630,8 @@ resendBytes GlobalState {..} Connection {..} sp = do } writeFlow gDataFlow (cAddress, spData sp) -sendBytes :: GlobalState addr -> Connection addr -> ByteString -> Maybe (TransportHeaderItem -> Bool) -> IO () -sendBytes gs conn bs ackedBy = resendBytes gs conn +sendBytes :: Connection addr -> ByteString -> Maybe (TransportHeaderItem -> Bool) -> IO () +sendBytes conn bs ackedBy = resendBytes conn SentPacket { spTime = undefined , spRetryCount = -1 @@ -526,7 +690,7 @@ processOutgoing gs@GlobalState {..} = do | otherwise -> return $ Just (BL.toStrict plain, plainAckedBy) case mbs of - Just (bs, ackedBy) -> sendBytes gs conn bs $ guard (not $ null ackedBy) >> Just (`elem` ackedBy) + Just (bs, ackedBy) -> sendBytes conn bs $ guard (not $ null ackedBy) >> Just (`elem` ackedBy) Nothing -> return () let retransmitPacket :: Connection addr -> STM (IO ()) @@ -545,7 +709,7 @@ processOutgoing gs@GlobalState {..} = do else retry else do writeTVar cSentPackets rest - return $ resendBytes gs conn sp + return $ resendBytes conn sp let handleControlRequests = readFlow gControlFlow >>= \case RequestConnection addr -> do @@ -561,7 +725,7 @@ processOutgoing gs@GlobalState {..} = do , lazyLoadBytes gInitConfig ] writeTVar cChannel ChannelCookieWait - return $ sendBytes gs conn packet $ Just $ \case CookieSet {} -> True; _ -> False + return $ sendBytes conn packet $ Just $ \case CookieSet {} -> True; _ -> False _ -> return $ return () SendAnnounce addr -> do diff --git a/test/network.test b/test/network.test index 0b9fecb..3df7376 100644 --- a/test/network.test +++ b/test/network.test @@ -133,7 +133,7 @@ test LargeData: /peer 1 addr ${p1.node.ip} 29665/ /peer 1 id Device1/ - for i in [0..1]: + for i in [0..10]: with p1: send "store blob" for j in [1 .. i * 10]: -- cgit v1.2.3