module Channel (
    Channel,
    ChannelRequest, ChannelRequestData(..),
    ChannelAccept, ChannelAcceptData(..),

    createChannelRequest,
    acceptChannelRequest,
    acceptedChannel,

    channelEncrypt,
    channelDecrypt,
) where

import Control.Concurrent.MVar
import Control.Monad
import Control.Monad.Except

import Crypto.Cipher.AES
import Crypto.Cipher.Types
import Crypto.Error

import Data.Binary
import Data.ByteArray (ByteArray, Bytes, ScrubbedBytes, append, convert)
import Data.ByteArray qualified as BA
import Data.ByteString qualified as B
import Data.ByteString.Lazy qualified as BL
import Data.List
import Data.Text qualified as T

import Identity
import PubKey
import Storage

data Channel = Channel
    { chPeers :: [Stored (Signed IdentityData)]
    , chKey :: ScrubbedBytes
    , chNonceFixedOur :: Bytes
    , chNonceFixedPeer :: Bytes
    , chNonceCounter :: MVar Word64
    }

type ChannelRequest = Signed ChannelRequestData

data ChannelRequestData = ChannelRequest
    { crPeers :: [Stored (Signed IdentityData)]
    , crKey :: Stored PublicKexKey
    }
    deriving (Show)

type ChannelAccept = Signed ChannelAcceptData

data ChannelAcceptData = ChannelAccept
    { caRequest :: Stored ChannelRequest
    , caKey :: Stored PublicKexKey
    }


instance Storable ChannelRequestData where
    store' cr = storeRec $ do
        mapM_ (storeRef "peer") $ crPeers cr
        storeText "enc" $ T.pack "aes-128-gcm"
        storeRef "key" $ crKey cr

    load' = loadRec $ do
        enc <- loadText "enc"
        guard $ enc == "aes-128-gcm"
        ChannelRequest
            <$> loadRefs "peer"
            <*> loadRef "key"

instance Storable ChannelAcceptData where
    store' ca = storeRec $ do
        storeRef "req" $ caRequest ca
        storeText "enc" $ T.pack "aes-128-gcm"
        storeRef "key" $ caKey ca

    load' = loadRec $ do
        enc <- loadText "enc"
        guard $ enc == "aes-128-gcm"
        ChannelAccept
            <$> loadRef "req"
            <*> loadRef "key"


createChannelRequest :: (MonadIO m) => Storage -> UnifiedIdentity -> UnifiedIdentity -> m (Stored ChannelRequest)
createChannelRequest st self peer = liftIO $ do
    (_, xpublic) <- generateKeys st
    Just skey <- loadKey $ idKeyMessage self
    wrappedStore st =<< sign skey =<< wrappedStore st ChannelRequest { crPeers = sort [idData self, idData peer], crKey = xpublic }

acceptChannelRequest :: (MonadIO m, MonadError String m) => UnifiedIdentity -> UnifiedIdentity -> Stored ChannelRequest -> m (Stored ChannelAccept, Channel)
acceptChannelRequest self peer req = do
    case sequence $ map validateIdentity $ crPeers $ fromStored $ signedData $ fromStored req of
        Nothing -> throwError $ "invalid peers in channel request"
        Just peers -> do
            when (not $ any (self `sameIdentity`) peers) $
                throwError $ "self identity missing in channel request peers"
            when (not $ any (peer `sameIdentity`) peers) $
                throwError $ "peer identity missing in channel request peers"
    when (idKeyMessage peer `notElem` (map (sigKey . fromStored) $ signedSignature $ fromStored req)) $
        throwError $ "channel requent not signed by peer"

    let st = storedStorage req
    ksize <- case cipherKeySize (undefined :: AES128) of
                  KeySizeFixed s -> return s
                  _ -> throwError "expecting fixed key size"
    liftIO $ do
        (xsecret, xpublic) <- generateKeys st
        Just skey <- loadKey $ idKeyMessage self
        acc <- wrappedStore st =<< sign skey =<< wrappedStore st ChannelAccept { caRequest = req, caKey = xpublic }
        counter <- newMVar 0
        return $ (acc,) $ Channel
            { chPeers = crPeers $ fromStored $ signedData $ fromStored req
            , chKey = BA.take ksize $ dhSecret xsecret $
                fromStored $ crKey $ fromStored $ signedData $ fromStored req
            , chNonceFixedOur  = BA.pack [ 2, 0, 0, 0, 0, 0 ]
            , chNonceFixedPeer = BA.pack [ 1, 0, 0, 0, 0, 0 ]
            , chNonceCounter = counter
            }

acceptedChannel :: (MonadIO m, MonadError String m) => UnifiedIdentity -> UnifiedIdentity -> Stored ChannelAccept -> m Channel
acceptedChannel self peer acc = do
    let req = caRequest $ fromStored $ signedData $ fromStored acc
    ksize <- case cipherKeySize (undefined :: AES128) of
                  KeySizeFixed s -> return s
                  _ -> throwError "expecting fixed key size"

    case sequence $ map validateIdentity $ crPeers $ fromStored $ signedData $ fromStored req of
        Nothing -> throwError $ "invalid peers in channel accept"
        Just peers -> do
            when (not $ any (self `sameIdentity`) peers) $
                throwError $ "self identity missing in channel accept peers"
            when (not $ any (peer `sameIdentity`) peers) $
                throwError $ "peer identity missing in channel accept peers"
    when (idKeyMessage peer `notElem` (map (sigKey . fromStored) $ signedSignature $ fromStored acc)) $
        throwError $ "channel accept not signed by peer"
    when (idKeyMessage self `notElem` (map (sigKey . fromStored) $ signedSignature $ fromStored req)) $
        throwError $ "original channel request not signed by us"

    xsecret <- liftIO (loadKey $ crKey $ fromStored $ signedData $ fromStored req) >>= \case
        Just key -> return key
        Nothing  -> throwError $ "secret key not found"
    counter <- liftIO $ newMVar 0
    return $ Channel
        { chPeers = crPeers $ fromStored $ signedData $ fromStored req
        , chKey = BA.take ksize $ dhSecret xsecret $
            fromStored $ caKey $ fromStored $ signedData $ fromStored acc
        , chNonceFixedOur  = BA.pack [ 1, 0, 0, 0, 0, 0 ]
        , chNonceFixedPeer = BA.pack [ 2, 0, 0, 0, 0, 0 ]
        , chNonceCounter = counter
        }


channelEncrypt :: (ByteArray ba, MonadIO m, MonadError String m) => Channel -> ba -> m ba
channelEncrypt ch plain = do
    cipher <- case cipherInit $ chKey ch of
                   CryptoPassed (cipher :: AES128) -> return cipher
                   _ -> throwError "failed to init AES128 cipher"
    let bsize = blockSize cipher
    count <- liftIO $ modifyMVar (chNonceCounter ch) $ \c -> return (c + 1, c)
    let cbytes = convert $ BL.toStrict $ BL.drop 2 $ encode count
        iv = chNonceFixedOur ch `append` cbytes
    aead <- case aeadInit AEAD_GCM cipher iv of
                 CryptoPassed aead -> return aead
                 _ -> throwError "failed to init AEAD_GCM"
    let (tag, ctext) = aeadSimpleEncrypt aead B.empty plain bsize
    return $ BA.concat [ BA.pack [ 0, 0 ], convert cbytes, ctext, convert tag ]

channelDecrypt :: (ByteArray ba, MonadError String m) => Channel -> ba -> m ba
channelDecrypt ch body = do
    cipher <- case cipherInit $ chKey ch of
                   CryptoPassed (cipher :: AES128) -> return cipher
                   _ -> throwError "failed to init AES128 cipher"
    let bsize = blockSize cipher
        (cbytes, body') = BA.splitAt 8 body
        iv = chNonceFixedPeer ch `append` convert (BA.drop 2 cbytes)
        (ctext, tag) = BA.splitAt (BA.length body' - bsize) body'
    aead <- case aeadInit AEAD_GCM cipher iv of
                 CryptoPassed aead -> return aead
                 _ -> throwError "failed to init AEAD_GCM"
    case aeadSimpleDecrypt aead B.empty ctext (AuthTag $ convert tag) of
         Just plain -> return plain
         Nothing -> throwError "failed to decrypt data"