From bda62efef1ad38779f23b38b4e1436f06fb9c7c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roman=20Smr=C5=BE?= Date: Tue, 1 Aug 2023 23:01:30 +0200 Subject: Network protocol refactoring with explicit data flows --- src/Network/Protocol.hs | 325 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 325 insertions(+) create mode 100644 src/Network/Protocol.hs (limited to 'src/Network/Protocol.hs') diff --git a/src/Network/Protocol.hs b/src/Network/Protocol.hs new file mode 100644 index 0000000..adc9471 --- /dev/null +++ b/src/Network/Protocol.hs @@ -0,0 +1,325 @@ +module Network.Protocol ( + TransportPacket(..), + transportToObject, + TransportHeader(..), + TransportHeaderItem(..), + + WaitingRef(..), + WaitingRefCallback, + wrDigest, + + ChannelState(..), + + erebosNetworkProtocol, + + Connection, + connAddress, + connData, + connGetChannel, + connSetChannel, + + module Flow, +) where + +import Control.Applicative +import Control.Concurrent +import Control.Concurrent.Async +import Control.Concurrent.STM +import Control.Monad +import Control.Monad.Except + +import Data.ByteString (ByteString) +import Data.ByteString.Char8 qualified as BC +import Data.ByteString.Lazy qualified as BL +import Data.List +import Data.Maybe + +import System.Clock + +import Channel +import Flow +import Service +import Storage + + +data TransportPacket a = TransportPacket TransportHeader [a] + +data TransportHeader = TransportHeader [TransportHeaderItem] + +data TransportHeaderItem + = Acknowledged RefDigest + | Rejected RefDigest + | DataRequest RefDigest + | DataResponse RefDigest + | AnnounceSelf RefDigest + | AnnounceUpdate RefDigest + | TrChannelRequest RefDigest + | TrChannelAccept RefDigest + | ServiceType ServiceID + | ServiceRef RefDigest + deriving (Eq) + +transportToObject :: PartialStorage -> TransportHeader -> PartialObject +transportToObject st (TransportHeader items) = Rec $ map single items + where single = \case + Acknowledged dgst -> (BC.pack "ACK", RecRef $ partialRefFromDigest st dgst) + Rejected dgst -> (BC.pack "REJ", RecRef $ partialRefFromDigest st dgst) + DataRequest dgst -> (BC.pack "REQ", RecRef $ partialRefFromDigest st dgst) + DataResponse dgst -> (BC.pack "RSP", RecRef $ partialRefFromDigest st dgst) + AnnounceSelf dgst -> (BC.pack "ANN", RecRef $ partialRefFromDigest st dgst) + AnnounceUpdate dgst -> (BC.pack "ANU", RecRef $ partialRefFromDigest st dgst) + TrChannelRequest dgst -> (BC.pack "CRQ", RecRef $ partialRefFromDigest st dgst) + TrChannelAccept dgst -> (BC.pack "CAC", RecRef $ partialRefFromDigest st dgst) + ServiceType stype -> (BC.pack "STP", RecUUID $ toUUID stype) + ServiceRef dgst -> (BC.pack "SRF", RecRef $ partialRefFromDigest st dgst) + +transportFromObject :: PartialObject -> Maybe TransportHeader +transportFromObject (Rec items) = case catMaybes $ map single items of + [] -> Nothing + titems -> Just $ TransportHeader titems + where single (name, content) = if + | name == BC.pack "ACK", RecRef ref <- content -> Just $ Acknowledged $ refDigest ref + | name == BC.pack "REJ", RecRef ref <- content -> Just $ Rejected $ refDigest ref + | name == BC.pack "REQ", RecRef ref <- content -> Just $ DataRequest $ refDigest ref + | name == BC.pack "RSP", RecRef ref <- content -> Just $ DataResponse $ refDigest ref + | name == BC.pack "ANN", RecRef ref <- content -> Just $ AnnounceSelf $ refDigest ref + | name == BC.pack "ANU", RecRef ref <- content -> Just $ AnnounceUpdate $ refDigest ref + | name == BC.pack "CRQ", RecRef ref <- content -> Just $ TrChannelRequest $ refDigest ref + | name == BC.pack "CAC", RecRef ref <- content -> Just $ TrChannelAccept $ refDigest ref + | name == BC.pack "STP", RecUUID uuid <- content -> Just $ ServiceType $ fromUUID uuid + | name == BC.pack "SRF", RecRef ref <- content -> Just $ ServiceRef $ refDigest ref + | otherwise -> Nothing +transportFromObject _ = Nothing + + +data GlobalState addr = (Eq addr, Show addr) => GlobalState + { gConnections :: TVar [Connection addr] + , gDataFlow :: SymFlow (addr, ByteString) + , gConnectionFlow :: Flow addr (Connection addr) + , gLog :: String -> STM () + , gStorage :: PartialStorage + , gNowVar :: TVar TimeSpec + , gNextTimeout :: TVar TimeSpec + } + +data Connection addr = Connection + { cAddress :: addr + , cDataUp :: Flow (Bool, TransportPacket PartialObject) (Bool, TransportPacket Ref, [TransportHeaderItem]) + , cDataInternal :: Flow (Bool, TransportPacket Ref, [TransportHeaderItem]) (Bool, TransportPacket PartialObject) + , cChannel :: TVar ChannelState + , cSecureOutQueue :: TQueue (Bool, TransportPacket Ref, [TransportHeaderItem]) + , cSentPackets :: TVar [SentPacket] + } + +connAddress :: Connection addr -> addr +connAddress = cAddress + +connData :: Connection addr -> Flow (Bool, TransportPacket PartialObject) (Bool, TransportPacket Ref, [TransportHeaderItem]) +connData = cDataUp + +connGetChannel :: Connection addr -> STM ChannelState +connGetChannel Connection {..} = readTVar cChannel + +connSetChannel :: Connection addr -> ChannelState -> STM () +connSetChannel Connection {..} ch = do + writeTVar cChannel ch + + +data WaitingRef = WaitingRef + { wrefStorage :: Storage + , wrefPartial :: PartialRef + , wrefAction :: Ref -> WaitingRefCallback + , wrefStatus :: TVar (Either [RefDigest] Ref) + } + +type WaitingRefCallback = ExceptT String IO () + +wrDigest :: WaitingRef -> RefDigest +wrDigest = refDigest . wrefPartial + + +data ChannelState = ChannelNone + | ChannelOurRequest (Stored ChannelRequest) + | ChannelPeerRequest WaitingRef + | ChannelOurAccept (Stored ChannelAccept) Channel + | ChannelEstablished Channel + + +data SentPacket = SentPacket + { spTime :: TimeSpec + , spRetryCount :: Int + , spAckedBy :: [TransportHeaderItem] + , spData :: BC.ByteString + } + + +erebosNetworkProtocol :: (Eq addr, Ord addr, Show addr) + => (String -> STM ()) + -> SymFlow (addr, ByteString) + -> Flow addr (Connection addr) + -> IO () +erebosNetworkProtocol gLog gDataFlow gConnectionFlow = do + gConnections <- newTVarIO [] + gStorage <- derivePartialStorage =<< memoryStorage + + startTime <- getTime MonotonicRaw + gNowVar <- newTVarIO startTime + gNextTimeout <- newTVarIO startTime + + let gs = GlobalState {..} + + let signalTimeouts = forever $ do + now <- getTime MonotonicRaw + next <- atomically $ do + writeTVar gNowVar now + readTVar gNextTimeout + + let waitTill time + | time > now = threadDelay $ fromInteger (toNanoSecs (time - now)) `div` 1000 + | otherwise = threadDelay maxBound + waitForUpdate = atomically $ do + next' <- readTVar gNextTimeout + when (next' == next) retry + + race_ (waitTill next) waitForUpdate + + race_ signalTimeouts $ forever $ join $ atomically $ + processIncomming gs <|> processOutgoing gs + + + +getConnection :: GlobalState addr -> addr -> STM (Connection addr) +getConnection GlobalState {..} addr = do + conns <- readTVar gConnections + case find ((addr==) . cAddress) conns of + Just conn -> return conn + Nothing -> do + let cAddress = addr + (cDataUp, cDataInternal) <- newFlow + cChannel <- newTVar ChannelNone + cSecureOutQueue <- newTQueue + cSentPackets <- newTVar [] + let conn = Connection {..} + + writeTVar gConnections (conn : conns) + writeFlow gConnectionFlow conn + return conn + +processIncomming :: GlobalState addr -> STM (IO ()) +processIncomming gs@GlobalState {..} = do + (addr, msg) <- readFlow gDataFlow + conn@Connection {..} <- getConnection gs addr + + mbch <- readTVar cChannel >>= return . \case + ChannelEstablished ch -> Just ch + ChannelOurAccept _ ch -> Just ch + _ -> Nothing + + return $ do + (content, secure) <- do + if | Just ch <- mbch + -> runExceptT (channelDecrypt ch msg) >>= \case + Right plain -> return (plain, True) + _ -> return (msg, False) + + | otherwise + -> return (msg, False) + + case runExcept $ deserializeObjects gStorage $ BL.fromStrict content of + Right (obj:objs) + | Just header@(TransportHeader items) <- transportFromObject obj -> atomically $ do + processAcknowledgements gs conn items + writeFlow cDataInternal (secure, TransportPacket header objs) + + | otherwise -> atomically $ do + gLog $ show cAddress ++ ": invalid objects" + gLog $ show objs + + _ -> do atomically $ gLog $ show cAddress ++ ": invalid objects" + + +processOutgoing :: forall addr. GlobalState addr -> STM (IO ()) +processOutgoing gs@GlobalState {..} = do + let sendBytes :: Connection addr -> SentPacket -> IO () + sendBytes Connection {..} sp = do + now <- getTime MonotonicRaw + atomically $ do + when (not $ null $ spAckedBy sp) $ do + modifyTVar' cSentPackets $ (:) sp + { spTime = now + , spRetryCount = spRetryCount sp + 1 + } + writeFlow gDataFlow (cAddress, spData sp) + + let sendNextPacket :: Connection addr -> STM (IO ()) + sendNextPacket conn@Connection {..} = do + mbch <- readTVar cChannel >>= return . \case + ChannelEstablished ch -> Just ch + _ -> Nothing + + let checkOutstanding + | isJust mbch = readTQueue cSecureOutQueue + | otherwise = retry + + (secure, packet@(TransportPacket header content), ackedBy) <- + checkOutstanding <|> readFlow cDataInternal + + let plain = BL.toStrict $ BL.concat $ + (serializeObject $ transportToObject gStorage header) + : map lazyLoadBytes content + + when (isNothing mbch && secure) $ do + writeTQueue cSecureOutQueue (secure, packet, ackedBy) + + return $ do + mbs <- case mbch of + Just ch -> do + runExceptT (channelEncrypt ch plain) >>= \case + Right ctext -> return $ Just ctext + Left err -> do atomically $ gLog $ "Failed to encrypt data: " ++ err + return Nothing + Nothing | secure -> return Nothing + | otherwise -> return $ Just plain + + case mbs of + Just bs -> do + sendBytes conn $ SentPacket + { spTime = undefined + , spRetryCount = -1 + , spAckedBy = ackedBy + , spData = bs + } + Nothing -> return () + + let retransmitPacket :: Connection addr -> STM (IO ()) + retransmitPacket conn@Connection {..} = do + now <- readTVar gNowVar + (sp, rest) <- readTVar cSentPackets >>= \case + sps@(_:_) -> return (last sps, init sps) + _ -> retry + let nextTry = spTime sp + fromNanoSecs 1000000000 + if now < nextTry + then do + nextTimeout <- readTVar gNextTimeout + if nextTimeout <= now || nextTry < nextTimeout + then do writeTVar gNextTimeout nextTry + return $ return () + else retry + else do + writeTVar cSentPackets rest + return $ sendBytes conn sp + + let establishNewConnection = do + _ <- getConnection gs =<< readFlow gConnectionFlow + return $ return () + + conns <- readTVar gConnections + msum $ concat $ + [ map retransmitPacket conns + , map sendNextPacket conns + , [ establishNewConnection ] + ] + +processAcknowledgements :: GlobalState addr -> Connection addr -> [TransportHeaderItem] -> STM () +processAcknowledgements GlobalState {} Connection {..} = mapM_ $ \hitem -> do + modifyTVar' cSentPackets $ filter $ (hitem `notElem`) . spAckedBy -- cgit v1.2.3