summaryrefslogtreecommitdiff
path: root/src/Network/Protocol.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Network/Protocol.hs')
-rw-r--r--src/Network/Protocol.hs325
1 files changed, 325 insertions, 0 deletions
diff --git a/src/Network/Protocol.hs b/src/Network/Protocol.hs
new file mode 100644
index 0000000..adc9471
--- /dev/null
+++ b/src/Network/Protocol.hs
@@ -0,0 +1,325 @@
+module Network.Protocol (
+ TransportPacket(..),
+ transportToObject,
+ TransportHeader(..),
+ TransportHeaderItem(..),
+
+ WaitingRef(..),
+ WaitingRefCallback,
+ wrDigest,
+
+ ChannelState(..),
+
+ erebosNetworkProtocol,
+
+ Connection,
+ connAddress,
+ connData,
+ connGetChannel,
+ connSetChannel,
+
+ module Flow,
+) where
+
+import Control.Applicative
+import Control.Concurrent
+import Control.Concurrent.Async
+import Control.Concurrent.STM
+import Control.Monad
+import Control.Monad.Except
+
+import Data.ByteString (ByteString)
+import Data.ByteString.Char8 qualified as BC
+import Data.ByteString.Lazy qualified as BL
+import Data.List
+import Data.Maybe
+
+import System.Clock
+
+import Channel
+import Flow
+import Service
+import Storage
+
+
+data TransportPacket a = TransportPacket TransportHeader [a]
+
+data TransportHeader = TransportHeader [TransportHeaderItem]
+
+data TransportHeaderItem
+ = Acknowledged RefDigest
+ | Rejected RefDigest
+ | DataRequest RefDigest
+ | DataResponse RefDigest
+ | AnnounceSelf RefDigest
+ | AnnounceUpdate RefDigest
+ | TrChannelRequest RefDigest
+ | TrChannelAccept RefDigest
+ | ServiceType ServiceID
+ | ServiceRef RefDigest
+ deriving (Eq)
+
+transportToObject :: PartialStorage -> TransportHeader -> PartialObject
+transportToObject st (TransportHeader items) = Rec $ map single items
+ where single = \case
+ Acknowledged dgst -> (BC.pack "ACK", RecRef $ partialRefFromDigest st dgst)
+ Rejected dgst -> (BC.pack "REJ", RecRef $ partialRefFromDigest st dgst)
+ DataRequest dgst -> (BC.pack "REQ", RecRef $ partialRefFromDigest st dgst)
+ DataResponse dgst -> (BC.pack "RSP", RecRef $ partialRefFromDigest st dgst)
+ AnnounceSelf dgst -> (BC.pack "ANN", RecRef $ partialRefFromDigest st dgst)
+ AnnounceUpdate dgst -> (BC.pack "ANU", RecRef $ partialRefFromDigest st dgst)
+ TrChannelRequest dgst -> (BC.pack "CRQ", RecRef $ partialRefFromDigest st dgst)
+ TrChannelAccept dgst -> (BC.pack "CAC", RecRef $ partialRefFromDigest st dgst)
+ ServiceType stype -> (BC.pack "STP", RecUUID $ toUUID stype)
+ ServiceRef dgst -> (BC.pack "SRF", RecRef $ partialRefFromDigest st dgst)
+
+transportFromObject :: PartialObject -> Maybe TransportHeader
+transportFromObject (Rec items) = case catMaybes $ map single items of
+ [] -> Nothing
+ titems -> Just $ TransportHeader titems
+ where single (name, content) = if
+ | name == BC.pack "ACK", RecRef ref <- content -> Just $ Acknowledged $ refDigest ref
+ | name == BC.pack "REJ", RecRef ref <- content -> Just $ Rejected $ refDigest ref
+ | name == BC.pack "REQ", RecRef ref <- content -> Just $ DataRequest $ refDigest ref
+ | name == BC.pack "RSP", RecRef ref <- content -> Just $ DataResponse $ refDigest ref
+ | name == BC.pack "ANN", RecRef ref <- content -> Just $ AnnounceSelf $ refDigest ref
+ | name == BC.pack "ANU", RecRef ref <- content -> Just $ AnnounceUpdate $ refDigest ref
+ | name == BC.pack "CRQ", RecRef ref <- content -> Just $ TrChannelRequest $ refDigest ref
+ | name == BC.pack "CAC", RecRef ref <- content -> Just $ TrChannelAccept $ refDigest ref
+ | name == BC.pack "STP", RecUUID uuid <- content -> Just $ ServiceType $ fromUUID uuid
+ | name == BC.pack "SRF", RecRef ref <- content -> Just $ ServiceRef $ refDigest ref
+ | otherwise -> Nothing
+transportFromObject _ = Nothing
+
+
+data GlobalState addr = (Eq addr, Show addr) => GlobalState
+ { gConnections :: TVar [Connection addr]
+ , gDataFlow :: SymFlow (addr, ByteString)
+ , gConnectionFlow :: Flow addr (Connection addr)
+ , gLog :: String -> STM ()
+ , gStorage :: PartialStorage
+ , gNowVar :: TVar TimeSpec
+ , gNextTimeout :: TVar TimeSpec
+ }
+
+data Connection addr = Connection
+ { cAddress :: addr
+ , cDataUp :: Flow (Bool, TransportPacket PartialObject) (Bool, TransportPacket Ref, [TransportHeaderItem])
+ , cDataInternal :: Flow (Bool, TransportPacket Ref, [TransportHeaderItem]) (Bool, TransportPacket PartialObject)
+ , cChannel :: TVar ChannelState
+ , cSecureOutQueue :: TQueue (Bool, TransportPacket Ref, [TransportHeaderItem])
+ , cSentPackets :: TVar [SentPacket]
+ }
+
+connAddress :: Connection addr -> addr
+connAddress = cAddress
+
+connData :: Connection addr -> Flow (Bool, TransportPacket PartialObject) (Bool, TransportPacket Ref, [TransportHeaderItem])
+connData = cDataUp
+
+connGetChannel :: Connection addr -> STM ChannelState
+connGetChannel Connection {..} = readTVar cChannel
+
+connSetChannel :: Connection addr -> ChannelState -> STM ()
+connSetChannel Connection {..} ch = do
+ writeTVar cChannel ch
+
+
+data WaitingRef = WaitingRef
+ { wrefStorage :: Storage
+ , wrefPartial :: PartialRef
+ , wrefAction :: Ref -> WaitingRefCallback
+ , wrefStatus :: TVar (Either [RefDigest] Ref)
+ }
+
+type WaitingRefCallback = ExceptT String IO ()
+
+wrDigest :: WaitingRef -> RefDigest
+wrDigest = refDigest . wrefPartial
+
+
+data ChannelState = ChannelNone
+ | ChannelOurRequest (Stored ChannelRequest)
+ | ChannelPeerRequest WaitingRef
+ | ChannelOurAccept (Stored ChannelAccept) Channel
+ | ChannelEstablished Channel
+
+
+data SentPacket = SentPacket
+ { spTime :: TimeSpec
+ , spRetryCount :: Int
+ , spAckedBy :: [TransportHeaderItem]
+ , spData :: BC.ByteString
+ }
+
+
+erebosNetworkProtocol :: (Eq addr, Ord addr, Show addr)
+ => (String -> STM ())
+ -> SymFlow (addr, ByteString)
+ -> Flow addr (Connection addr)
+ -> IO ()
+erebosNetworkProtocol gLog gDataFlow gConnectionFlow = do
+ gConnections <- newTVarIO []
+ gStorage <- derivePartialStorage =<< memoryStorage
+
+ startTime <- getTime MonotonicRaw
+ gNowVar <- newTVarIO startTime
+ gNextTimeout <- newTVarIO startTime
+
+ let gs = GlobalState {..}
+
+ let signalTimeouts = forever $ do
+ now <- getTime MonotonicRaw
+ next <- atomically $ do
+ writeTVar gNowVar now
+ readTVar gNextTimeout
+
+ let waitTill time
+ | time > now = threadDelay $ fromInteger (toNanoSecs (time - now)) `div` 1000
+ | otherwise = threadDelay maxBound
+ waitForUpdate = atomically $ do
+ next' <- readTVar gNextTimeout
+ when (next' == next) retry
+
+ race_ (waitTill next) waitForUpdate
+
+ race_ signalTimeouts $ forever $ join $ atomically $
+ processIncomming gs <|> processOutgoing gs
+
+
+
+getConnection :: GlobalState addr -> addr -> STM (Connection addr)
+getConnection GlobalState {..} addr = do
+ conns <- readTVar gConnections
+ case find ((addr==) . cAddress) conns of
+ Just conn -> return conn
+ Nothing -> do
+ let cAddress = addr
+ (cDataUp, cDataInternal) <- newFlow
+ cChannel <- newTVar ChannelNone
+ cSecureOutQueue <- newTQueue
+ cSentPackets <- newTVar []
+ let conn = Connection {..}
+
+ writeTVar gConnections (conn : conns)
+ writeFlow gConnectionFlow conn
+ return conn
+
+processIncomming :: GlobalState addr -> STM (IO ())
+processIncomming gs@GlobalState {..} = do
+ (addr, msg) <- readFlow gDataFlow
+ conn@Connection {..} <- getConnection gs addr
+
+ mbch <- readTVar cChannel >>= return . \case
+ ChannelEstablished ch -> Just ch
+ ChannelOurAccept _ ch -> Just ch
+ _ -> Nothing
+
+ return $ do
+ (content, secure) <- do
+ if | Just ch <- mbch
+ -> runExceptT (channelDecrypt ch msg) >>= \case
+ Right plain -> return (plain, True)
+ _ -> return (msg, False)
+
+ | otherwise
+ -> return (msg, False)
+
+ case runExcept $ deserializeObjects gStorage $ BL.fromStrict content of
+ Right (obj:objs)
+ | Just header@(TransportHeader items) <- transportFromObject obj -> atomically $ do
+ processAcknowledgements gs conn items
+ writeFlow cDataInternal (secure, TransportPacket header objs)
+
+ | otherwise -> atomically $ do
+ gLog $ show cAddress ++ ": invalid objects"
+ gLog $ show objs
+
+ _ -> do atomically $ gLog $ show cAddress ++ ": invalid objects"
+
+
+processOutgoing :: forall addr. GlobalState addr -> STM (IO ())
+processOutgoing gs@GlobalState {..} = do
+ let sendBytes :: Connection addr -> SentPacket -> IO ()
+ sendBytes Connection {..} sp = do
+ now <- getTime MonotonicRaw
+ atomically $ do
+ when (not $ null $ spAckedBy sp) $ do
+ modifyTVar' cSentPackets $ (:) sp
+ { spTime = now
+ , spRetryCount = spRetryCount sp + 1
+ }
+ writeFlow gDataFlow (cAddress, spData sp)
+
+ let sendNextPacket :: Connection addr -> STM (IO ())
+ sendNextPacket conn@Connection {..} = do
+ mbch <- readTVar cChannel >>= return . \case
+ ChannelEstablished ch -> Just ch
+ _ -> Nothing
+
+ let checkOutstanding
+ | isJust mbch = readTQueue cSecureOutQueue
+ | otherwise = retry
+
+ (secure, packet@(TransportPacket header content), ackedBy) <-
+ checkOutstanding <|> readFlow cDataInternal
+
+ let plain = BL.toStrict $ BL.concat $
+ (serializeObject $ transportToObject gStorage header)
+ : map lazyLoadBytes content
+
+ when (isNothing mbch && secure) $ do
+ writeTQueue cSecureOutQueue (secure, packet, ackedBy)
+
+ return $ do
+ mbs <- case mbch of
+ Just ch -> do
+ runExceptT (channelEncrypt ch plain) >>= \case
+ Right ctext -> return $ Just ctext
+ Left err -> do atomically $ gLog $ "Failed to encrypt data: " ++ err
+ return Nothing
+ Nothing | secure -> return Nothing
+ | otherwise -> return $ Just plain
+
+ case mbs of
+ Just bs -> do
+ sendBytes conn $ SentPacket
+ { spTime = undefined
+ , spRetryCount = -1
+ , spAckedBy = ackedBy
+ , spData = bs
+ }
+ Nothing -> return ()
+
+ let retransmitPacket :: Connection addr -> STM (IO ())
+ retransmitPacket conn@Connection {..} = do
+ now <- readTVar gNowVar
+ (sp, rest) <- readTVar cSentPackets >>= \case
+ sps@(_:_) -> return (last sps, init sps)
+ _ -> retry
+ let nextTry = spTime sp + fromNanoSecs 1000000000
+ if now < nextTry
+ then do
+ nextTimeout <- readTVar gNextTimeout
+ if nextTimeout <= now || nextTry < nextTimeout
+ then do writeTVar gNextTimeout nextTry
+ return $ return ()
+ else retry
+ else do
+ writeTVar cSentPackets rest
+ return $ sendBytes conn sp
+
+ let establishNewConnection = do
+ _ <- getConnection gs =<< readFlow gConnectionFlow
+ return $ return ()
+
+ conns <- readTVar gConnections
+ msum $ concat $
+ [ map retransmitPacket conns
+ , map sendNextPacket conns
+ , [ establishNewConnection ]
+ ]
+
+processAcknowledgements :: GlobalState addr -> Connection addr -> [TransportHeaderItem] -> STM ()
+processAcknowledgements GlobalState {} Connection {..} = mapM_ $ \hitem -> do
+ modifyTVar' cSentPackets $ filter $ (hitem `notElem`) . spAckedBy