module Network.Ip (
    IpPrefix(..),
    textIpNetwork,

    IpAddress(..),
    textIpAddress,
    textIpAddressCidr,

    allowsSubnets,
    ipSubnet,
    lanSubnet,

    MonadPIO(..),
    MonadSTM(..),
    atomicallyWithIO,

    NetworkNamespace,
    HasNetns(..),
    addNetworkNamespace,
    textNetnsName,
    callOn,

    Link(..),
    Loopback, loopback,
    VEth, addVEth,
    Bridge, addBridge,
    addAddress,
    setMaster,
    linkUp, linkDown,

    Route(..),
    addRoute,
) where

import Control.Concurrent.STM
import Control.Monad
import Control.Monad.Writer

import Data.Function
import Data.Text (Text)
import Data.Text qualified as T
import Data.Typeable
import Data.Word

import System.Process

-- | A network prefix given as its leading octets, most significant first.
-- The prefix length in bits is eight times the number of octets, so
-- @IpPrefix [10, 0]@ denotes @10.0.0.0/16@.
newtype IpPrefix = IpPrefix [Word8] deriving (Eq, Ord)

-- | Render a prefix in CIDR notation, padding the missing octets with zeros,
-- e.g. @IpPrefix [10, 1]@ becomes @"10.1.0.0/16"@.
textIpNetwork :: IpPrefix -> Text
textIpNetwork (IpPrefix octets) = dotted <> "/" <> maskBits
  where
    padded = octets ++ replicate (4 - length octets) 0
    dotted = T.intercalate "." [T.pack (show o) | o <- padded]
    maskBits = T.pack (show (8 * length octets))

-- | A host address: the network prefix it belongs to plus its final octet.
-- When rendered by 'textIpAddress', any octets between the prefix and the
-- final octet are zero.
data IpAddress = IpAddress IpPrefix Word8 deriving (Eq, Ord)

-- | Render an address in dotted-quad form, inserting zero octets between the
-- prefix and the final octet, e.g. @IpAddress (IpPrefix [10]) 5@ becomes
-- @"10.0.0.5"@.
textIpAddress :: IpAddress -> Text
textIpAddress (IpAddress (IpPrefix octets) host) =
    T.intercalate "." (map (T.pack . show) (concat [octets, zeros, [host]]))
  where
    zeros = replicate (3 - length octets) 0

-- | Render an address followed by the length of its prefix, e.g.
-- @IpAddress (IpPrefix [10, 1, 2]) 5@ becomes @"10.1.2.5/24"@.
textIpAddressCidr :: IpAddress -> Text
textIpAddressCidr addr@(IpAddress (IpPrefix octets) _) =
    T.concat [textIpAddress addr, "/", T.pack (show (length octets * 8))]

-- | True when the prefix has at most two octets, i.e. it is shorter than a
-- /24 and can still be extended by another octet.
allowsSubnets :: IpPrefix -> Bool
allowsSubnets (IpPrefix octets) = length octets <= 2

-- | Extend a prefix by one more octet, e.g. @ipSubnet 3 (IpPrefix [10])@ is
-- @IpPrefix [10, 3]@. No length check is performed here; 'allowsSubnets'
-- tells whether the prefix is still short enough.
ipSubnet :: Word8 -> IpPrefix -> IpPrefix
ipSubnet octet (IpPrefix octets) = IpPrefix $ octets <> [octet]

-- | Normalise a prefix to exactly three octets (a /24). Shorter prefixes are
-- padded with zero octets; longer ones are truncated.
lanSubnet :: IpPrefix -> IpPrefix
lanSubnet (IpPrefix octets) = IpPrefix padded
  where
    padded = take 3 (octets <> repeat 0)


-- | Monads that can schedule an 'IO' action. In plain 'IO' the action runs
-- immediately. In @WriterT [IO ()] m@ it is appended to the written log, so
-- that a runner (such as 'atomicallyWithIO') can execute it later.
class Monad m => MonadPIO m where
    -- | Schedule the given action.
    postpone :: IO () -> m ()

instance MonadPIO IO where
    postpone act = act

instance Monad m => MonadPIO (WriterT [IO ()] m) where
    postpone act = tell [act]


-- | Monads in which an 'STM' action can be run as part of the surrounding
-- computation: 'STM' itself, and @WriterT [IO ()]@ layered over such a monad.
class Monad m => MonadSTM m where
    liftSTM :: STM a -> m a

instance MonadSTM STM where
    liftSTM stm = stm

instance MonadSTM m => MonadSTM (WriterT [IO ()] m) where
    liftSTM stm = lift (liftSTM stm)


-- | Run a transaction that may have postponed 'IO' actions. The 'STM' part
-- commits atomically first. The postponed actions then run in the order
-- they were written, and finally the transaction's result is returned.
atomicallyWithIO :: MonadIO m => WriterT [IO ()] STM a -> m a
atomicallyWithIO act =
    liftIO $ atomically (runWriterT act) >>= \(result, deferred) ->
        result <$ sequence_ deferred


-- | A network namespace, identified by its name. It also keeps two lists of
-- routes registered through 'addRoute': all configured routes, and the
-- subset currently treated as installed (links going down remove entries
-- from the latter; bringing them up restores them).
--
-- NOTE(review): 'netnsRoutesActive' is written but never read within this
-- module, and the fields are not exported.
data NetworkNamespace = NetworkNamespace
    { netnsName :: Text
    , netnsRoutesConfigured :: TVar [Route]
    , netnsRoutesActive :: TVar [Route]
    }

-- | Namespaces are compared by name only; the route state is ignored.
instance Eq NetworkNamespace where
    a == b = netnsName a == netnsName b

instance Ord NetworkNamespace where
    compare a b = compare (netnsName a) (netnsName b)

-- | Values that belong to a particular network namespace.
class HasNetns a where
    getNetns :: a -> NetworkNamespace

instance HasNetns NetworkNamespace where
    getNetns ns = ns

-- | Create a namespace description with empty route lists. The
-- @ip netns add@ command that creates it in the kernel is postponed rather
-- than run immediately.
addNetworkNamespace :: (MonadPIO m, MonadSTM m) => Text -> m NetworkNamespace
addNetworkNamespace name = do
    postpone $ callCommand $ T.unpack $ T.concat ["ip netns add \"", name, "\""]
    configured <- liftSTM (newTVar [])
    active <- liftSTM (newTVar [])
    pure $ NetworkNamespace
        { netnsName = name
        , netnsRoutesConfigured = configured
        , netnsRoutesActive = active
        }

-- | The name passed to @ip netns@ for this namespace.
textNetnsName :: NetworkNamespace -> Text
textNetnsName ns = netnsName ns

-- | Run a shell command inside the namespace of the given value using
-- @ip netns exec@. The namespace name is double-quoted, and @cmd@ is handed
-- to the shell verbatim. Like 'callCommand', this throws if the command
-- exits with a non-zero status.
callOn :: HasNetns a => a -> Text -> IO ()
callOn n cmd = callCommand (T.unpack command)
  where
    command = T.concat ["ip netns exec \"", textNetnsName (getNetns n), "\" ", cmd]


-- | A network interface, identified by its name and namespace. The phantom
-- type parameter tags the kind of interface ('Loopback', 'VEth', 'Bridge').
data Link a = Link
    { linkName :: Text
    , linkNetns :: NetworkNamespace
    } deriving Eq

-- | A 'Link' with its kind tag hidden, so that links of different kinds can
-- be stored in the same field (e.g. a route's device).
data SomeLink = forall a. Typeable a => SomeLink (Link a)

-- | Two hidden links are equal when they carry the same kind tag and are
-- equal as 'Link's. Links with different kind tags are never equal.
instance Eq SomeLink where
    SomeLink x == SomeLink y = maybe False (x ==) (cast y)

-- | Apply a function that works on links of any kind to a hidden link.
liftSomeLink :: (forall a. Link a -> b) -> SomeLink -> b
liftSomeLink f someLink = case someLink of
    SomeLink l -> f l

instance HasNetns (Link a) where
    getNetns l = linkNetns l

instance HasNetns SomeLink where
    getNetns (SomeLink l) = linkNetns l

-- | Kind tag for the loopback interface.
data Loopback

-- | The @lo@ interface of the given namespace. No command is run; the
-- interface is presumably expected to exist already in every namespace.
loopback :: HasNetns n => n -> Link Loopback
loopback n = Link { linkName = "lo", linkNetns = getNetns n }

-- | Kind tag for virtual Ethernet links.
data VEth

-- | Create a veth pair (postponed). The first end, @name@, is created in the
-- namespace of @netns@. Its peer, @name'@, is created by the same command
-- directly in the namespace of @netns'@.
addVEth :: (HasNetns n, HasNetns n', MonadPIO m) => (n, Text) -> (n', Text) -> m (Link VEth, Link VEth)
addVEth (netns, name) (netns', name') = do
    let here = getNetns netns
        there = getNetns netns'
    postpone $ callOn netns $ T.concat
        [ "ip link add \"", name, "\" type veth peer name \"", name'
        , "\" netns \"", textNetnsName there, "\""
        ]
    pure (Link name here, Link name' there)

-- | Kind tag for bridge devices.
data Bridge

-- | Create a bridge device named @name@ in the namespace of @netns@ (the
-- @ip link add ... type bridge@ command is postponed).
addBridge :: (HasNetns n, MonadPIO m) => n -> Text -> m (Link Bridge)
addBridge netns name = Link name (getNetns netns) <$ postpone create
  where
    create = callOn netns $ "ip link add name \"" <> name <> "\" type bridge"

-- | Assign @addr@ to @link@ (the @ip addr add@ command is postponed). The
-- prefix length comes from the address's prefix, and the broadcast address
-- is set explicitly.
--
-- The broadcast address has every host octet set to 255:
-- @a.b.c.255@ for a three-octet prefix, @a.b.255.255@ for a two-octet
-- prefix, and so on.
addAddress :: MonadPIO m => Link a -> IpAddress -> m ()
addAddress link addr@(IpAddress (IpPrefix prefix) _) = do
    -- Pad the prefix with 255s up to three octets. 'textIpAddress' then
    -- renders the all-ones host part instead of inserting zero octets.
    -- Prefixes that already have three or more octets are left unchanged.
    let bcast = IpAddress (IpPrefix (prefix ++ replicate (3 - length prefix) 255)) 255
    postpone $ callOn link $ "ip addr add " <> textIpAddressCidr addr <> " broadcast " <> textIpAddress bcast <> " dev \"" <> linkName link <> "\""

-- | Attach a link to a bridge (the @ip link set ... master ...@ command is
-- postponed). If the link and the bridge are in different namespaces, the
-- postponed action fails with an IO error instead of running the command.
setMaster :: MonadPIO m => Link a -> Link Bridge -> m ()
setMaster link bridge = postpone $
    if getNetns link /= getNetns bridge
        then fail "link and bridge in different network namespaces"
        else callOn link $ "ip link set dev \"" <> linkName link <> "\" master \"" <> linkName bridge <> "\""

-- | Bring a link up (postponed). Every route configured for this link via
-- 'addRoute' is appended to the namespace's active list and installed again
-- with @ip route add@ after the link is up.
--
-- NOTE(review): the active list is not deduplicated, and every route is
-- re-added. Calling this on a link that is already up therefore duplicates
-- entries, and presumably fails on routes that are still installed. Confirm
-- that callers only use it on links that are down.
linkUp :: (Typeable a, MonadPIO m, MonadSTM m) => Link a -> m ()
linkUp link = do
    let ns = getNetns link
        usesThisLink route = routeDev route == SomeLink link
    configured <- liftSTM $ readTVar (netnsRoutesConfigured ns)
    let routes = filter usesThisLink configured
    liftSTM $ modifyTVar (netnsRoutesActive ns) (routes ++)
    postpone $ do
        callOn link $ "ip link set dev \"" <> linkName link <> "\" up"
        -- The kernel dropped these routes when the link went down; re-add them.
        mapM_ applyRoute routes

-- | Take a link down (postponed). The kernel removes routes through a device
-- that goes down, so this link's routes are dropped from the namespace's
-- active list. They remain in the configured list for 'linkUp' to restore.
linkDown :: (Typeable a, MonadPIO m, MonadSTM m) => Link a -> m ()
linkDown link = do
    let activeVar = netnsRoutesActive (getNetns link)
        viaOtherDevice route = routeDev route /= SomeLink link
    liftSTM (modifyTVar activeVar (filter viaOtherDevice))
    postpone (callOn link ("ip link set dev \"" <> linkName link <> "\" down"))


-- | A route as installed by @ip route add PREFIX via VIA dev DEV src SRC@.
data Route = Route
    { routePrefix :: IpPrefix -- ^ destination network
    , routeVia :: IpAddress   -- ^ next-hop gateway
    , routeDev :: SomeLink    -- ^ outgoing device
    , routeSrc :: IpAddress   -- ^ preferred source address
    }

-- | Register a route in the link's namespace, adding it to both the
-- configured and the active route lists, and postpone the @ip route add@
-- that installs it.
addRoute :: Typeable a => IpPrefix -> IpAddress -> Link a -> IpAddress -> WriterT [IO ()] STM ()
addRoute prefix via link src = do
    let ns = getNetns link
        route = Route
            { routePrefix = prefix
            , routeVia = via
            , routeDev = SomeLink link
            , routeSrc = src
            }
    lift $ do
        modifyTVar (netnsRoutesConfigured ns) (route :)
        modifyTVar (netnsRoutesActive ns) (route :)
    postpone (applyRoute route)

-- | Install a route by running @ip route add@ inside the namespace of the
-- route's device. The device name is double-quoted, as in every other
-- command built in this module, so the shell passes it through as a single
-- word. Throws if the command exits with a non-zero status.
applyRoute :: Route -> IO ()
applyRoute route = callOn dev $ T.concat
    [ "ip route add ", textIpNetwork (routePrefix route)
    , " via ", textIpAddress (routeVia route)
    , " dev \"", liftSomeLink linkName dev, "\""
    , " src ", textIpAddress (routeSrc route)
    ]
  where
    dev = routeDev route