diff options
Diffstat (limited to 'src/Erebos/Network')
-rw-r--r-- | src/Erebos/Network/Address.hs | 65 | ||||
-rw-r--r-- | src/Erebos/Network/Protocol.hs | 76 | ||||
-rw-r--r-- | src/Erebos/Network/ifaddrs.c | 10 | ||||
-rw-r--r-- | src/Erebos/Network/ifaddrs.h | 6 |
4 files changed, 130 insertions, 27 deletions
diff --git a/src/Erebos/Network/Address.hs b/src/Erebos/Network/Address.hs new file mode 100644 index 0000000..63f6af1 --- /dev/null +++ b/src/Erebos/Network/Address.hs @@ -0,0 +1,65 @@ +module Erebos.Network.Address ( + InetAddress(..), + inetFromSockAddr, + inetToSockAddr, + + SockAddr, PortNumber, +) where + +import Data.Bifunctor +import Data.IP qualified as IP +import Data.Word + +import Foreign.C.Types +import Foreign.Marshal.Array +import Foreign.Ptr +import Foreign.Storable as F + +import Network.Socket + +import Text.Read + + +newtype InetAddress = InetAddress { fromInetAddress :: IP.IP } + deriving (Eq, Ord) + +instance Show InetAddress where + show (InetAddress ipaddr) + | IP.IPv6 ipv6 <- ipaddr + , ( 0, 0, 0xffff, ipv4 ) <- IP.fromIPv6w ipv6 + = show (IP.toIPv4w ipv4) + + | otherwise + = show ipaddr + +instance Read InetAddress where + readPrec = do + readPrec >>= return . InetAddress . \case + IP.IPv4 ipv4 -> IP.IPv6 $ IP.toIPv6w ( 0, 0, 0xffff, IP.fromIPv4w ipv4 ) + ipaddr -> ipaddr + + readListPrec = readListPrecDefault + +instance F.Storable InetAddress where + sizeOf _ = sizeOf (undefined :: CInt) + 16 + alignment _ = 8 + + peek ptr = (unpackFamily <$> peekByteOff ptr 0) >>= \case + AF_INET -> InetAddress . IP.IPv4 . IP.fromHostAddress <$> peekByteOff ptr (sizeOf (undefined :: CInt)) + AF_INET6 -> InetAddress . IP.IPv6 . IP.toIPv6b . map fromIntegral <$> peekArray 16 (ptr `plusPtr` sizeOf (undefined :: CInt) :: Ptr Word8) + _ -> fail "InetAddress: unknown family" + + poke ptr (InetAddress addr) = case addr of + IP.IPv4 ip -> do + pokeByteOff ptr 0 (packFamily AF_INET) + pokeByteOff ptr (sizeOf (undefined :: CInt)) (IP.toHostAddress ip) + IP.IPv6 ip -> do + pokeByteOff ptr 0 (packFamily AF_INET6) + pokeArray (ptr `plusPtr` sizeOf (undefined :: CInt) :: Ptr Word8) (map fromIntegral $ IP.fromIPv6b ip) + + +inetFromSockAddr :: SockAddr -> Maybe ( InetAddress, PortNumber ) +inetFromSockAddr saddr = first InetAddress <$> IP.fromSockAddr saddr + +inetToSockAddr :: ( InetAddress, PortNumber ) -> SockAddr +inetToSockAddr = IP.toSockAddr . first fromInetAddress diff --git a/src/Erebos/Network/Protocol.hs b/src/Erebos/Network/Protocol.hs index c340503..f67e296 100644 --- a/src/Erebos/Network/Protocol.hs +++ b/src/Erebos/Network/Protocol.hs @@ -3,6 +3,7 @@ module Erebos.Network.Protocol ( transportToObject, TransportHeader(..), TransportHeaderItem(..), + ServiceID(..), SecurityRequirement(..), WaitingRef(..), @@ -22,7 +23,8 @@ module Erebos.Network.Protocol ( connSetChannel, connClose, - RawStreamReader, RawStreamWriter, + RawStreamReader(..), RawStreamWriter(..), + StreamPacket(..), connAddWriteStream, connAddReadStream, readStreamToList, @@ -36,6 +38,7 @@ import Control.Applicative import Control.Concurrent import Control.Concurrent.Async import Control.Concurrent.STM +import Control.Exception import Control.Monad import Control.Monad.Except import Control.Monad.Trans @@ -68,9 +71,9 @@ import Erebos.Flow import Erebos.Identity import Erebos.Network.Channel import Erebos.Object -import Erebos.Service import Erebos.Storable import Erebos.Storage +import Erebos.UUID (UUID) protocolVersion :: Text @@ -107,6 +110,9 @@ data TransportHeaderItem | StreamOpen Word8 deriving (Eq, Show) +newtype ServiceID = ServiceID UUID + deriving (Eq, Ord, Show, StorableUUID) + newtype Cookie = Cookie ByteString deriving (Eq, Show) @@ -207,6 +213,7 @@ data GlobalState addr = (Eq addr, Show addr) => GlobalState , gControlFlow :: Flow (ControlRequest addr) (ControlMessage addr) , gNextUp :: TMVar (Connection addr, (Bool, TransportPacket PartialObject)) , gLog :: String -> STM () + , gTestLog :: String -> STM () , gStorage :: PartialStorage , gStartTime :: TimeSpec , gNowVar :: TVar TimeSpec @@ -243,6 +250,12 @@ instance Eq (Connection addr) where connAddress :: Connection addr -> addr connAddress = cAddress +showConnAddress :: forall addr. Connection addr -> String +showConnAddress Connection {..} = helper cGlobalState cAddress + where + helper :: GlobalState addr -> addr -> String + helper GlobalState {} = show + connData :: Connection addr -> Flow (Maybe (Bool, TransportPacket PartialObject)) (SecurityRequirement, TransportPacket Ref, [TransportHeaderItem]) @@ -267,6 +280,7 @@ connClose conn@Connection {..} = do connAddWriteStream :: Connection addr -> STM (Either String (TransportHeaderItem, RawStreamWriter, IO ())) connAddWriteStream conn@Connection {..} = do + let GlobalState {..} = cGlobalState outStreams <- readTVar cOutStreams let doInsert :: Word8 -> [(Word8, Stream)] -> ExceptT String STM ((Word8, Stream), [(Word8, Stream)]) doInsert n (s@(n', _) : rest) | n == n' = @@ -283,10 +297,16 @@ connAddWriteStream conn@Connection {..} = do runExceptT $ do ((streamNumber, stream), outStreams') <- doInsert 1 outStreams lift $ writeTVar cOutStreams outStreams' - return (StreamOpen streamNumber, sFlowIn stream, go cGlobalState streamNumber stream) + lift $ gTestLog $ "net-ostream-open " <> showConnAddress conn <> " " <> show streamNumber <> " " <> show (length outStreams') + return + ( StreamOpen streamNumber + , RawStreamWriter (fromIntegral streamNumber) (sFlowIn stream) + , go streamNumber stream + ) where - go gs@GlobalState {..} streamNumber stream = do + go streamNumber stream = do + let GlobalState {..} = cGlobalState (reserved, msg) <- atomically $ do readTVar (sState stream) >>= \case StreamRunning -> return () @@ -299,6 +319,8 @@ connAddWriteStream conn@Connection {..} = do return (stpData, True, return ()) StreamClosed {} -> do atomically $ do + gTestLog $ "net-ostream-close-send " <> showConnAddress conn <> " " <> show streamNumber + atomically $ do -- wait for ack on all sent stream data waits <- readTVar (sWaitingForAck stream) when (waits > 0) retry @@ -342,7 +364,7 @@ connAddWriteStream conn@Connection {..} = do sendBytes conn mbReserved' bs Nothing -> return () - when cont $ go gs streamNumber stream + when cont $ go streamNumber stream connAddReadStream :: Connection addr -> Word8 -> STM RawStreamReader connAddReadStream Connection {..} streamNumber = do @@ -356,14 +378,21 @@ connAddReadStream Connection {..} streamNumber = do sNextSequence <- newTVar 0 sWaitingForAck <- newTVar 0 let stream = Stream {..} - return (stream, (streamNumber, stream) : streams) - (stream, inStreams') <- doInsert inStreams + return ( streamNumber, stream, (streamNumber, stream) : streams ) + ( num, stream, inStreams' ) <- doInsert inStreams writeTVar cInStreams inStreams' - return $ sFlowOut stream + return $ RawStreamReader (fromIntegral num) (sFlowOut stream) -type RawStreamReader = Flow StreamPacket Void -type RawStreamWriter = Flow Void StreamPacket +data RawStreamReader = RawStreamReader + { rsrNum :: Int + , rsrFlow :: Flow StreamPacket Void + } + +data RawStreamWriter = RawStreamWriter + { rswNum :: Int + , rswFlow :: Flow Void StreamPacket + } data Stream = Stream { sState :: TVar StreamState @@ -394,11 +423,13 @@ streamAccepted Connection {..} snum = atomically $ do Nothing -> return () streamClosed :: Connection addr -> Word8 -> IO () -streamClosed Connection {..} snum = atomically $ do - modifyTVar' cOutStreams $ filter ((snum /=) . fst) +streamClosed conn@Connection {..} snum = atomically $ do + streams <- filter ((snum /=) . fst) <$> readTVar cOutStreams + writeTVar cOutStreams streams + gTestLog cGlobalState $ "net-ostream-close-ack " <> showConnAddress conn <> " " <> show snum <> " " <> show (length streams) readStreamToList :: RawStreamReader -> IO (Word64, [(Word64, BC.ByteString)]) -readStreamToList stream = readFlowIO stream >>= \case +readStreamToList stream = readFlowIO (rsrFlow stream) >>= \case StreamData sq bytes -> fmap ((sq, bytes) :) <$> readStreamToList stream StreamClosed sqEnd -> return (sqEnd, []) @@ -420,10 +451,10 @@ writeByteStringToStream :: RawStreamWriter -> BL.ByteString -> IO () writeByteStringToStream stream = go 0 where go seqNum bstr - | BL.null bstr = writeFlowIO stream $ StreamClosed seqNum + | BL.null bstr = writeFlowIO (rswFlow stream) $ StreamClosed seqNum | otherwise = do let (cur, rest) = BL.splitAt 500 bstr -- TODO: MTU - writeFlowIO stream $ StreamData seqNum (BL.toStrict cur) + writeFlowIO (rswFlow stream) $ StreamData seqNum (BL.toStrict cur) go (seqNum + 1) rest @@ -477,10 +508,11 @@ data ControlMessage addr = NewConnection (Connection addr) (Maybe RefDigest) erebosNetworkProtocol :: (Eq addr, Ord addr, Show addr) => UnifiedIdentity -> (String -> STM ()) + -> (String -> STM ()) -> SymFlow (addr, ByteString) -> Flow (ControlRequest addr) (ControlMessage addr) -> IO () -erebosNetworkProtocol initialIdentity gLog gDataFlow gControlFlow = do +erebosNetworkProtocol initialIdentity gLog gTestLog gDataFlow gControlFlow = do gIdentity <- newTVarIO (initialIdentity, []) gConnections <- newTVarIO [] gNextUp <- newEmptyTMVarIO @@ -512,8 +544,10 @@ erebosNetworkProtocol initialIdentity gLog gDataFlow gControlFlow = do race_ (waitTill next) waitForUpdate - race_ signalTimeouts $ forever $ join $ atomically $ - passUpIncoming gs <|> processIncoming gs <|> processOutgoing gs + race_ signalTimeouts $ forever $ do + io <- atomically $ do + passUpIncoming gs <|> processIncoming gs <|> processOutgoing gs + catch io $ \(e :: SomeException) -> atomically $ gLog $ "exception during network protocol handling: " <> show e getConnection :: GlobalState addr -> addr -> STM (Connection addr) @@ -542,6 +576,7 @@ newConnection cGlobalState@GlobalState {..} addr = do cOutStreams <- newTVar [] let conn = Connection {..} + gTestLog $ "net-conn-new " <> show cAddress writeTVar gConnections (conn : conns) return conn @@ -898,7 +933,10 @@ processOutgoing gs@GlobalState {..} = do , rsOnAck = rsOnAck rs >> onAck }) <$> mbReserved sendBytes conn mbReserved' bs - Nothing -> return () + Nothing -> do + when (isJust mbReserved) $ do + atomically $ do + modifyTVar' cReservedPackets (subtract 1) let waitUntil :: TimeSpec -> TimeSpec -> STM () waitUntil now till = do diff --git a/src/Erebos/Network/ifaddrs.c b/src/Erebos/Network/ifaddrs.c index ff4382a..8139b5e 100644 --- a/src/Erebos/Network/ifaddrs.c +++ b/src/Erebos/Network/ifaddrs.c @@ -22,7 +22,7 @@ #define DISCOVERY_MULTICAST_GROUP "ff12:b6a4:6b1f:969:caee:acc2:5c93:73e1" -uint32_t * join_multicast(int fd, size_t * count) +uint32_t * erebos_join_multicast(int fd, size_t * count) { size_t capacity = 16; *count = 0; @@ -117,7 +117,7 @@ static bool copy_local_address( struct InetAddress * dst, const struct sockaddr #ifndef _WIN32 -struct InetAddress * local_addresses( size_t * count ) +struct InetAddress * erebos_local_addresses( size_t * count ) { struct ifaddrs * addrs; if( getifaddrs( &addrs ) < 0 ) @@ -153,7 +153,7 @@ struct InetAddress * local_addresses( size_t * count ) return ret; } -uint32_t * broadcast_addresses(void) +uint32_t * erebos_broadcast_addresses(void) { struct ifaddrs * addrs; if (getifaddrs(&addrs) < 0) @@ -196,7 +196,7 @@ uint32_t * broadcast_addresses(void) #pragma comment(lib, "ws2_32.lib") -struct InetAddress * local_addresses( size_t * count ) +struct InetAddress * erebos_local_addresses( size_t * count ) { * count = 0; struct InetAddress * ret = NULL; @@ -237,7 +237,7 @@ cleanup: return ret; } -uint32_t * broadcast_addresses(void) +uint32_t * erebos_broadcast_addresses(void) { uint32_t * ret = NULL; SOCKET wsock = INVALID_SOCKET; diff --git a/src/Erebos/Network/ifaddrs.h b/src/Erebos/Network/ifaddrs.h index 2ee45a7..2b3c014 100644 --- a/src/Erebos/Network/ifaddrs.h +++ b/src/Erebos/Network/ifaddrs.h @@ -13,6 +13,6 @@ struct InetAddress uint8_t addr[16]; } __attribute__((packed)); -uint32_t * join_multicast(int fd, size_t * count); -struct InetAddress * local_addresses( size_t * count ); -uint32_t * broadcast_addresses(void); +uint32_t * erebos_join_multicast(int fd, size_t * count); +struct InetAddress * erebos_local_addresses( size_t * count ); +uint32_t * erebos_broadcast_addresses(void); |