summaryrefslogtreecommitdiff
path: root/src/Erebos/Network/Protocol.hs
diff options
context:
space:
mode:
authorRoman Smrž <roman.smrz@seznam.cz>2024-03-02 21:01:37 +0100
committerRoman Smrž <roman.smrz@seznam.cz>2024-03-20 11:43:19 +0100
commite0a5dbf7164517c79940da5691745cd281e8557e (patch)
tree7eedc85d8a34e18650f88ae91f7fe5e7b6790557 /src/Erebos/Network/Protocol.hs
parent9d2671dc19bdc46d1f0fc976813cb9d63e34c71e (diff)
Network streams
Changelog: Implemented streams in network protocol
Diffstat (limited to 'src/Erebos/Network/Protocol.hs')
-rw-r--r--src/Erebos/Network/Protocol.hs190
1 files changed, 177 insertions, 13 deletions
diff --git a/src/Erebos/Network/Protocol.hs b/src/Erebos/Network/Protocol.hs
index d7253e3..e5eb652 100644
--- a/src/Erebos/Network/Protocol.hs
+++ b/src/Erebos/Network/Protocol.hs
@@ -20,6 +20,13 @@ module Erebos.Network.Protocol (
connGetChannel,
connSetChannel,
+ RawStreamReader, RawStreamWriter,
+ connAddWriteStream,
+ connAddReadStream,
+ readStreamToList,
+ readObjectsFromStream,
+ writeByteStringToStream,
+
module Erebos.Flow,
) where
@@ -39,6 +46,8 @@ import Data.List
import Data.Maybe
import Data.Text (Text)
import Data.Text qualified as T
+import Data.Void
+import Data.Word
import System.Clock
@@ -77,6 +86,7 @@ data TransportHeaderItem
| TrChannelAccept RefDigest
| ServiceType ServiceID
| ServiceRef RefDigest
+ | StreamOpen Word8
deriving (Eq, Show)
newtype Cookie = Cookie ByteString
@@ -111,6 +121,7 @@ transportToObject st (TransportHeader items) = Rec $ map single items
TrChannelAccept dgst -> (BC.pack "CAC", RecRef $ partialRefFromDigest st dgst)
ServiceType stype -> (BC.pack "SVT", RecUUID $ toUUID stype)
ServiceRef dgst -> (BC.pack "SVR", RecRef $ partialRefFromDigest st dgst)
+ StreamOpen num -> (BC.pack "STO", RecInt $ fromIntegral num)
transportFromObject :: PartialObject -> Maybe TransportHeader
transportFromObject (Rec items) = case catMaybes $ map single items of
@@ -132,6 +143,7 @@ transportFromObject (Rec items) = case catMaybes $ map single items of
| name == BC.pack "CAC", RecRef ref <- content -> Just $ TrChannelAccept $ refDigest ref
| name == BC.pack "SVT", RecUUID uuid <- content -> Just $ ServiceType $ fromUUID uuid
| name == BC.pack "SVR", RecRef ref <- content -> Just $ ServiceRef $ refDigest ref
+ | name == BC.pack "STO", RecInt num <- content -> Just $ StreamOpen $ fromIntegral num
| otherwise -> Nothing
transportFromObject _ = Nothing
@@ -150,13 +162,16 @@ data GlobalState addr = (Eq addr, Show addr) => GlobalState
}
data Connection addr = Connection
- { cAddress :: addr
+ { cGlobalState :: GlobalState addr
+ , cAddress :: addr
, cDataUp :: Flow (Bool, TransportPacket PartialObject) (Bool, TransportPacket Ref, [TransportHeaderItem])
, cDataInternal :: Flow (Bool, TransportPacket Ref, [TransportHeaderItem]) (Bool, TransportPacket PartialObject)
, cChannel :: TVar ChannelState
, cSecureOutQueue :: TQueue (Bool, TransportPacket Ref, [TransportHeaderItem])
, cSentPackets :: TVar [SentPacket]
, cToAcknowledge :: TVar [Integer]
+ , cInStreams :: TVar [(Word8, Stream)]
+ , cOutStreams :: TVar [(Word8, Stream)]
}
connAddress :: Connection addr -> addr
@@ -172,6 +187,123 @@ connSetChannel :: Connection addr -> ChannelState -> STM ()
connSetChannel Connection {..} ch = do
writeTVar cChannel ch
+connAddWriteStream :: Connection addr -> STM (TransportHeaderItem, RawStreamWriter, IO ())
+connAddWriteStream conn@Connection {..} = do
+ let GlobalState {..} = cGlobalState
+
+ outStreams <- readTVar cOutStreams
+ let doInsert n (s@(n', _) : rest) | n == n' =
+ fmap (s:) <$> doInsert (n + 1) rest
+ doInsert n streams = do
+ (sFlowIn, sFlowOut) <- newFlow
+ sNextSequence <- newTVar 0
+ let info = (n, Stream {..})
+ return (info, info : streams)
+ ((streamNumber, stream), outStreams') <- doInsert 1 outStreams
+ writeTVar cOutStreams outStreams'
+
+ let go = do
+ msg <- atomically $ readFlow (sFlowOut stream)
+ let (plain, cont) = case msg of
+ StreamData {..} -> (stpData, True)
+ StreamClosed {} -> (BC.empty, False)
+ -- TODO: send channel closed only after delivering all previous data packets
+ -- TODO: free channel number after delivering stream closed
+ let secure = True
+ plainAckedBy = []
+
+ mbch <- atomically (readTVar cChannel) >>= return . \case
+ ChannelEstablished ch -> Just ch
+ ChannelOurAccept _ _ ch -> Just ch
+ _ -> Nothing
+
+ mbs <- case mbch of
+ Just ch -> do
+ runExceptT (channelEncrypt ch $ B.concat
+ [ B.singleton streamNumber
+ , B.singleton (fromIntegral (stpSequence msg) :: Word8)
+ , plain
+ ] ) >>= \case
+ Right (ctext, counter) -> do
+ let isAcked = True
+ return $ Just (0x80 `B.cons` ctext, if isAcked then [ AcknowledgedSingle $ fromIntegral counter ] else [])
+ Left err -> do atomically $ gLog $ "Failed to encrypt data: " ++ err
+ return Nothing
+ Nothing | secure -> return Nothing
+ | otherwise -> return $ Just (plain, plainAckedBy)
+
+ case mbs of
+ Just (bs, ackedBy) -> sendBytes conn bs $ guard (not $ null ackedBy) >> Just (`elem` ackedBy)
+ Nothing -> return ()
+
+ when cont go
+
+ return (StreamOpen streamNumber, sFlowIn stream, go)
+
+connAddReadStream :: Connection addr -> Word8 -> STM RawStreamReader
+connAddReadStream Connection {..} streamNumber = do
+ inStreams <- readTVar cInStreams
+ let doInsert (s@(n, _) : rest)
+ | streamNumber < n = fmap (s:) <$> doInsert rest
+ | streamNumber == n = doInsert rest
+ doInsert streams = do
+ (sFlowIn, sFlowOut) <- newFlow
+ sNextSequence <- newTVar 0
+ let stream = Stream {..}
+ return (stream, (streamNumber, stream) : streams)
+ (stream, inStreams') <- doInsert inStreams
+ writeTVar cInStreams inStreams'
+ return $ sFlowOut stream
+
+
+type RawStreamReader = Flow StreamPacket Void
+type RawStreamWriter = Flow Void StreamPacket
+
+data Stream = Stream
+ { sFlowIn :: Flow Void StreamPacket
+ , sFlowOut :: Flow StreamPacket Void
+ , sNextSequence :: TVar Word64
+ }
+
+data StreamPacket
+ = StreamData
+ { stpSequence :: Word64
+ , stpData :: BC.ByteString
+ }
+ | StreamClosed
+ { stpSequence :: Word64
+ }
+
+readStreamToList :: RawStreamReader -> IO (Word64, [(Word64, BC.ByteString)])
+readStreamToList stream = readFlowIO stream >>= \case
+ StreamData sq bytes -> fmap ((sq, bytes) :) <$> readStreamToList stream
+ StreamClosed sqEnd -> return (sqEnd, [])
+
+readObjectsFromStream :: PartialStorage -> RawStreamReader -> IO (Except String [PartialObject])
+readObjectsFromStream st stream = do
+ (seqEnd, list) <- readStreamToList stream
+ print (seqEnd, length list, list)
+ let validate s ((s', bytes) : rest)
+ | s == s' = (bytes : ) <$> validate (s + 1) rest
+ | s > s' = validate s rest
+ | otherwise = throwError "missing object chunk"
+ validate s []
+ | s == seqEnd = return []
+ | otherwise = throwError "content length mismatch"
+ return $ do
+ content <- BL.fromChunks <$> validate 0 list
+ deserializeObjects st content
+
+writeByteStringToStream :: RawStreamWriter -> BL.ByteString -> IO ()
+writeByteStringToStream stream = go 0
+ where
+ go seqNum bstr
+ | BL.null bstr = writeFlowIO stream $ StreamClosed seqNum
+ | otherwise = do
+ let (cur, rest) = BL.splitAt 500 bstr -- TODO: MTU
+ writeFlowIO stream $ StreamData seqNum (BL.toStrict cur)
+ go (seqNum + 1) rest
+
data WaitingRef = WaitingRef
{ wrefStorage :: Storage
@@ -260,7 +392,7 @@ findConnection GlobalState {..} addr = do
find ((addr==) . cAddress) <$> readTVar gConnections
newConnection :: GlobalState addr -> addr -> STM (Connection addr)
-newConnection GlobalState {..} addr = do
+newConnection cGlobalState@GlobalState {..} addr = do
conns <- readTVar gConnections
let cAddress = addr
@@ -269,6 +401,8 @@ newConnection GlobalState {..} addr = do
cSecureOutQueue <- newTQueue
cSentPackets <- newTVar []
cToAcknowledge <- newTVar []
+ cInStreams <- newTVar []
+ cOutStreams <- newTVar []
let conn = Connection {..}
writeTVar gConnections (conn : conns)
@@ -306,24 +440,30 @@ processIncoming gs@GlobalState {..} = do
case B.uncons dec of
Just (0x00, content) -> do
objs <- deserialize content
- return (True, objs, Just counter)
+ return $ Left (True, objs, Just counter)
+
+ Just (snum, dec')
+ | snum < 64
+ , Just (seq8, content) <- B.uncons dec'
+ -> do
+ return $ Right (snum, seq8, content, counter)
Just (_, _) -> do
- throwError "streams not implemented"
+ throwError "unexpected stream header"
Nothing -> do
throwError "empty decrypted content"
| b .&. 0xE0 == 0x60 -> do
objs <- deserialize msg
- return (False, objs, Nothing)
+ return $ Left (False, objs, Nothing)
| otherwise -> throwError "invalid packet"
Nothing -> throwError "empty packet"
runExceptT parse >>= \case
- Right (secure, objs, mbcounter)
+ Right (Left (secure, objs, mbcounter))
| hobj:content <- objs
, Just header@(TransportHeader items) <- transportFromObject hobj
-> processPacket gs (maybe (Left addr) Right mbconn) secure (TransportPacket header content) >>= \case
@@ -342,6 +482,29 @@ processIncoming gs@GlobalState {..} = do
gLog $ show addr ++ ": invalid objects"
gLog $ show objs
+ Right (Right (snum, seq8, content, counter))
+ | Just Connection {..} <- mbconn
+ -> atomically $ do
+ (lookup snum <$> readTVar cInStreams) >>= \case
+ Nothing ->
+ gLog $ "unexpected stream number " ++ show snum
+
+ Just Stream {..} -> do
+ expectedSequence <- readTVar sNextSequence
+ let seqFull = expectedSequence - 0x80 + fromIntegral (seq8 - fromIntegral expectedSequence + 0x80 :: Word8)
+ sdata <- if
+ | B.null content -> do
+ modifyTVar' cInStreams $ filter ((/=snum) . fst)
+ return $ StreamClosed seqFull
+ | otherwise -> do
+ writeTVar sNextSequence $ max expectedSequence (seqFull + 1)
+ return $ StreamData seqFull content
+ writeFlow sFlowIn sdata
+ modifyTVar' cToAcknowledge (fromIntegral counter :)
+
+ | otherwise -> do
+ atomically $ gLog $ show addr <> ": stream packet without connection"
+
Left err -> do
atomically $ gLog $ show addr <> ": failed to parse packet: " <> err
@@ -455,8 +618,9 @@ createCookie GlobalState {} addr = return (Cookie $ BC.pack $ show addr)
verifyCookie :: GlobalState addr -> addr -> Cookie -> IO Bool
verifyCookie GlobalState {} addr (Cookie cookie) = return $ show addr == BC.unpack cookie
-resendBytes :: GlobalState addr -> Connection addr -> SentPacket -> IO ()
-resendBytes GlobalState {..} Connection {..} sp = do
+resendBytes :: Connection addr -> SentPacket -> IO ()
+resendBytes Connection {..} sp = do
+ let GlobalState {..} = cGlobalState
now <- getTime MonotonicRaw
atomically $ do
when (isJust $ spAckedBy sp) $ do
@@ -466,8 +630,8 @@ resendBytes GlobalState {..} Connection {..} sp = do
}
writeFlow gDataFlow (cAddress, spData sp)
-sendBytes :: GlobalState addr -> Connection addr -> ByteString -> Maybe (TransportHeaderItem -> Bool) -> IO ()
-sendBytes gs conn bs ackedBy = resendBytes gs conn
+sendBytes :: Connection addr -> ByteString -> Maybe (TransportHeaderItem -> Bool) -> IO ()
+sendBytes conn bs ackedBy = resendBytes conn
SentPacket
{ spTime = undefined
, spRetryCount = -1
@@ -526,7 +690,7 @@ processOutgoing gs@GlobalState {..} = do
| otherwise -> return $ Just (BL.toStrict plain, plainAckedBy)
case mbs of
- Just (bs, ackedBy) -> sendBytes gs conn bs $ guard (not $ null ackedBy) >> Just (`elem` ackedBy)
+ Just (bs, ackedBy) -> sendBytes conn bs $ guard (not $ null ackedBy) >> Just (`elem` ackedBy)
Nothing -> return ()
let retransmitPacket :: Connection addr -> STM (IO ())
@@ -545,7 +709,7 @@ processOutgoing gs@GlobalState {..} = do
else retry
else do
writeTVar cSentPackets rest
- return $ resendBytes gs conn sp
+ return $ resendBytes conn sp
let handleControlRequests = readFlow gControlFlow >>= \case
RequestConnection addr -> do
@@ -561,7 +725,7 @@ processOutgoing gs@GlobalState {..} = do
, lazyLoadBytes gInitConfig
]
writeTVar cChannel ChannelCookieWait
- return $ sendBytes gs conn packet $ Just $ \case CookieSet {} -> True; _ -> False
+ return $ sendBytes conn packet $ Just $ \case CookieSet {} -> True; _ -> False
_ -> return $ return ()
SendAnnounce addr -> do