diff options
Diffstat (limited to 'src/Erebos/Network/Protocol.hs')
-rw-r--r-- | src/Erebos/Network/Protocol.hs | 190 |
1 files changed, 177 insertions, 13 deletions
diff --git a/src/Erebos/Network/Protocol.hs b/src/Erebos/Network/Protocol.hs index d7253e3..e5eb652 100644 --- a/src/Erebos/Network/Protocol.hs +++ b/src/Erebos/Network/Protocol.hs @@ -20,6 +20,13 @@ module Erebos.Network.Protocol ( connGetChannel, connSetChannel, + RawStreamReader, RawStreamWriter, + connAddWriteStream, + connAddReadStream, + readStreamToList, + readObjectsFromStream, + writeByteStringToStream, + module Erebos.Flow, ) where @@ -39,6 +46,8 @@ import Data.List import Data.Maybe import Data.Text (Text) import Data.Text qualified as T +import Data.Void +import Data.Word import System.Clock @@ -77,6 +86,7 @@ data TransportHeaderItem | TrChannelAccept RefDigest | ServiceType ServiceID | ServiceRef RefDigest + | StreamOpen Word8 deriving (Eq, Show) newtype Cookie = Cookie ByteString @@ -111,6 +121,7 @@ transportToObject st (TransportHeader items) = Rec $ map single items TrChannelAccept dgst -> (BC.pack "CAC", RecRef $ partialRefFromDigest st dgst) ServiceType stype -> (BC.pack "SVT", RecUUID $ toUUID stype) ServiceRef dgst -> (BC.pack "SVR", RecRef $ partialRefFromDigest st dgst) + StreamOpen num -> (BC.pack "STO", RecInt $ fromIntegral num) transportFromObject :: PartialObject -> Maybe TransportHeader transportFromObject (Rec items) = case catMaybes $ map single items of @@ -132,6 +143,7 @@ transportFromObject (Rec items) = case catMaybes $ map single items of | name == BC.pack "CAC", RecRef ref <- content -> Just $ TrChannelAccept $ refDigest ref | name == BC.pack "SVT", RecUUID uuid <- content -> Just $ ServiceType $ fromUUID uuid | name == BC.pack "SVR", RecRef ref <- content -> Just $ ServiceRef $ refDigest ref + | name == BC.pack "STO", RecInt num <- content -> Just $ StreamOpen $ fromIntegral num | otherwise -> Nothing transportFromObject _ = Nothing @@ -150,13 +162,16 @@ data GlobalState addr = (Eq addr, Show addr) => GlobalState } data Connection addr = Connection - { cAddress :: addr + { cGlobalState :: GlobalState addr + , cAddress :: addr , cDataUp :: Flow (Bool, TransportPacket PartialObject) (Bool, TransportPacket Ref, [TransportHeaderItem]) , cDataInternal :: Flow (Bool, TransportPacket Ref, [TransportHeaderItem]) (Bool, TransportPacket PartialObject) , cChannel :: TVar ChannelState , cSecureOutQueue :: TQueue (Bool, TransportPacket Ref, [TransportHeaderItem]) , cSentPackets :: TVar [SentPacket] , cToAcknowledge :: TVar [Integer] + , cInStreams :: TVar [(Word8, Stream)] + , cOutStreams :: TVar [(Word8, Stream)] } connAddress :: Connection addr -> addr @@ -172,6 +187,123 @@ connSetChannel :: Connection addr -> ChannelState -> STM () connSetChannel Connection {..} ch = do writeTVar cChannel ch +connAddWriteStream :: Connection addr -> STM (TransportHeaderItem, RawStreamWriter, IO ()) +connAddWriteStream conn@Connection {..} = do + let GlobalState {..} = cGlobalState + + outStreams <- readTVar cOutStreams + let doInsert n (s@(n', _) : rest) | n == n' = + fmap (s:) <$> doInsert (n + 1) rest + doInsert n streams = do + (sFlowIn, sFlowOut) <- newFlow + sNextSequence <- newTVar 0 + let info = (n, Stream {..}) + return (info, info : streams) + ((streamNumber, stream), outStreams') <- doInsert 1 outStreams + writeTVar cOutStreams outStreams' + + let go = do + msg <- atomically $ readFlow (sFlowOut stream) + let (plain, cont) = case msg of + StreamData {..} -> (stpData, True) + StreamClosed {} -> (BC.empty, False) + -- TODO: send channel closed only after delivering all previous data packets + -- TODO: free channel number after delivering stream closed + let secure = True + plainAckedBy = [] + + mbch <- atomically (readTVar cChannel) >>= return . \case + ChannelEstablished ch -> Just ch + ChannelOurAccept _ _ ch -> Just ch + _ -> Nothing + + mbs <- case mbch of + Just ch -> do + runExceptT (channelEncrypt ch $ B.concat + [ B.singleton streamNumber + , B.singleton (fromIntegral (stpSequence msg) :: Word8) + , plain + ] ) >>= \case + Right (ctext, counter) -> do + let isAcked = True + return $ Just (0x80 `B.cons` ctext, if isAcked then [ AcknowledgedSingle $ fromIntegral counter ] else []) + Left err -> do atomically $ gLog $ "Failed to encrypt data: " ++ err + return Nothing + Nothing | secure -> return Nothing + | otherwise -> return $ Just (plain, plainAckedBy) + + case mbs of + Just (bs, ackedBy) -> sendBytes conn bs $ guard (not $ null ackedBy) >> Just (`elem` ackedBy) + Nothing -> return () + + when cont go + + return (StreamOpen streamNumber, sFlowIn stream, go) + +connAddReadStream :: Connection addr -> Word8 -> STM RawStreamReader +connAddReadStream Connection {..} streamNumber = do + inStreams <- readTVar cInStreams + let doInsert (s@(n, _) : rest) + | streamNumber < n = fmap (s:) <$> doInsert rest + | streamNumber == n = doInsert rest + doInsert streams = do + (sFlowIn, sFlowOut) <- newFlow + sNextSequence <- newTVar 0 + let stream = Stream {..} + return (stream, (streamNumber, stream) : streams) + (stream, inStreams') <- doInsert inStreams + writeTVar cInStreams inStreams' + return $ sFlowOut stream + + +type RawStreamReader = Flow StreamPacket Void +type RawStreamWriter = Flow Void StreamPacket + +data Stream = Stream + { sFlowIn :: Flow Void StreamPacket + , sFlowOut :: Flow StreamPacket Void + , sNextSequence :: TVar Word64 + } + +data StreamPacket + = StreamData + { stpSequence :: Word64 + , stpData :: BC.ByteString + } + | StreamClosed + { stpSequence :: Word64 + } + +readStreamToList :: RawStreamReader -> IO (Word64, [(Word64, BC.ByteString)]) +readStreamToList stream = readFlowIO stream >>= \case + StreamData sq bytes -> fmap ((sq, bytes) :) <$> readStreamToList stream + StreamClosed sqEnd -> return (sqEnd, []) + +readObjectsFromStream :: PartialStorage -> RawStreamReader -> IO (Except String [PartialObject]) +readObjectsFromStream st stream = do + (seqEnd, list) <- readStreamToList stream + print (seqEnd, length list, list) + let validate s ((s', bytes) : rest) + | s == s' = (bytes : ) <$> validate (s + 1) rest + | s > s' = validate s rest + | otherwise = throwError "missing object chunk" + validate s [] + | s == seqEnd = return [] + | otherwise = throwError "content length mismatch" + return $ do + content <- BL.fromChunks <$> validate 0 list + deserializeObjects st content + +writeByteStringToStream :: RawStreamWriter -> BL.ByteString -> IO () +writeByteStringToStream stream = go 0 + where + go seqNum bstr + | BL.null bstr = writeFlowIO stream $ StreamClosed seqNum + | otherwise = do + let (cur, rest) = BL.splitAt 500 bstr -- TODO: MTU + writeFlowIO stream $ StreamData seqNum (BL.toStrict cur) + go (seqNum + 1) rest + data WaitingRef = WaitingRef { wrefStorage :: Storage @@ -260,7 +392,7 @@ findConnection GlobalState {..} addr = do find ((addr==) . cAddress) <$> readTVar gConnections newConnection :: GlobalState addr -> addr -> STM (Connection addr) -newConnection GlobalState {..} addr = do +newConnection cGlobalState@GlobalState {..} addr = do conns <- readTVar gConnections let cAddress = addr @@ -269,6 +401,8 @@ newConnection GlobalState {..} addr = do cSecureOutQueue <- newTQueue cSentPackets <- newTVar [] cToAcknowledge <- newTVar [] + cInStreams <- newTVar [] + cOutStreams <- newTVar [] let conn = Connection {..} writeTVar gConnections (conn : conns) @@ -306,24 +440,30 @@ processIncoming gs@GlobalState {..} = do case B.uncons dec of Just (0x00, content) -> do objs <- deserialize content - return (True, objs, Just counter) + return $ Left (True, objs, Just counter) + + Just (snum, dec') + | snum < 64 + , Just (seq8, content) <- B.uncons dec' + -> do + return $ Right (snum, seq8, content, counter) Just (_, _) -> do - throwError "streams not implemented" + throwError "unexpected stream header" Nothing -> do throwError "empty decrypted content" | b .&. 0xE0 == 0x60 -> do objs <- deserialize msg - return (False, objs, Nothing) + return $ Left (False, objs, Nothing) | otherwise -> throwError "invalid packet" Nothing -> throwError "empty packet" runExceptT parse >>= \case - Right (secure, objs, mbcounter) + Right (Left (secure, objs, mbcounter)) | hobj:content <- objs , Just header@(TransportHeader items) <- transportFromObject hobj -> processPacket gs (maybe (Left addr) Right mbconn) secure (TransportPacket header content) >>= \case @@ -342,6 +482,29 @@ processIncoming gs@GlobalState {..} = do gLog $ show addr ++ ": invalid objects" gLog $ show objs + Right (Right (snum, seq8, content, counter)) + | Just Connection {..} <- mbconn + -> atomically $ do + (lookup snum <$> readTVar cInStreams) >>= \case + Nothing -> + gLog $ "unexpected stream number " ++ show snum + + Just Stream {..} -> do + expectedSequence <- readTVar sNextSequence + let seqFull = expectedSequence - 0x80 + fromIntegral (seq8 - fromIntegral expectedSequence + 0x80 :: Word8) + sdata <- if + | B.null content -> do + modifyTVar' cInStreams $ filter ((/=snum) . fst) + return $ StreamClosed seqFull + | otherwise -> do + writeTVar sNextSequence $ max expectedSequence (seqFull + 1) + return $ StreamData seqFull content + writeFlow sFlowIn sdata + modifyTVar' cToAcknowledge (fromIntegral counter :) + + | otherwise -> do + atomically $ gLog $ show addr <> ": stream packet without connection" + Left err -> do atomically $ gLog $ show addr <> ": failed to parse packet: " <> err @@ -455,8 +618,9 @@ createCookie GlobalState {} addr = return (Cookie $ BC.pack $ show addr) verifyCookie :: GlobalState addr -> addr -> Cookie -> IO Bool verifyCookie GlobalState {} addr (Cookie cookie) = return $ show addr == BC.unpack cookie -resendBytes :: GlobalState addr -> Connection addr -> SentPacket -> IO () -resendBytes GlobalState {..} Connection {..} sp = do +resendBytes :: Connection addr -> SentPacket -> IO () +resendBytes Connection {..} sp = do + let GlobalState {..} = cGlobalState now <- getTime MonotonicRaw atomically $ do when (isJust $ spAckedBy sp) $ do @@ -466,8 +630,8 @@ resendBytes GlobalState {..} Connection {..} sp = do } writeFlow gDataFlow (cAddress, spData sp) -sendBytes :: GlobalState addr -> Connection addr -> ByteString -> Maybe (TransportHeaderItem -> Bool) -> IO () -sendBytes gs conn bs ackedBy = resendBytes gs conn +sendBytes :: Connection addr -> ByteString -> Maybe (TransportHeaderItem -> Bool) -> IO () +sendBytes conn bs ackedBy = resendBytes conn SentPacket { spTime = undefined , spRetryCount = -1 @@ -526,7 +690,7 @@ processOutgoing gs@GlobalState {..} = do | otherwise -> return $ Just (BL.toStrict plain, plainAckedBy) case mbs of - Just (bs, ackedBy) -> sendBytes gs conn bs $ guard (not $ null ackedBy) >> Just (`elem` ackedBy) + Just (bs, ackedBy) -> sendBytes conn bs $ guard (not $ null ackedBy) >> Just (`elem` ackedBy) Nothing -> return () let retransmitPacket :: Connection addr -> STM (IO ()) @@ -545,7 +709,7 @@ processOutgoing gs@GlobalState {..} = do else retry else do writeTVar cSentPackets rest - return $ resendBytes gs conn sp + return $ resendBytes conn sp let handleControlRequests = readFlow gControlFlow >>= \case RequestConnection addr -> do @@ -561,7 +725,7 @@ processOutgoing gs@GlobalState {..} = do , lazyLoadBytes gInitConfig ] writeTVar cChannel ChannelCookieWait - return $ sendBytes gs conn packet $ Just $ \case CookieSet {} -> True; _ -> False + return $ sendBytes conn packet $ Just $ \case CookieSet {} -> True; _ -> False _ -> return $ return () SendAnnounce addr -> do |