-- | Encrypted channels between two identities.
--
-- A channel is set up in two steps. One side creates a signed
-- 'ChannelRequest' carrying a freshly generated key-exchange public key. The
-- other side answers with a signed 'ChannelAccept' carrying its own public
-- key. Each side then combines its own key-exchange secret with the other
-- side's public key ('dhSecret') to obtain the shared AES-128 key. Messages
-- are protected with AES-128-GCM ('channelEncrypt' / 'channelDecrypt').
module Channel (
    Channel,
    ChannelRequest, ChannelRequestData(..),
    ChannelAccept, ChannelAcceptData(..),

    createChannelRequest,
    acceptChannelRequest,
    acceptedChannel,

    channelEncrypt,
    channelDecrypt,
) where

import Control.Monad
import Control.Monad.Except

import Crypto.Cipher.AES
import Crypto.Cipher.Types
import Crypto.Error
import Crypto.Random

import Data.ByteArray
import qualified Data.ByteArray as BA
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.List
import qualified Data.Text as T

import Identity
import PubKey
import Storage

-- | An established channel: both participants and the symmetric key they
-- share.
data Channel = Channel
    { chPeers :: [Stored (Signed IdentityData)]
    -- ^ Both participants, in sorted order (copied from the request).
    , chKey :: ScrubbedBytes
    -- ^ AES-128 key used by 'channelEncrypt' and 'channelDecrypt'.
    }
    deriving (Show)

-- | A channel request together with its signatures.
type ChannelRequest = Signed ChannelRequestData

-- | Payload of a channel request.
data ChannelRequestData = ChannelRequest
    { crPeers :: [Stored (Signed IdentityData)]
    -- ^ Requester and peer identities, sorted.
    , crKey :: Stored PublicKexKey
    -- ^ Requester's freshly generated key-exchange public key.
    }
    deriving (Show)

-- | A channel accept together with its signatures.
type ChannelAccept = Signed ChannelAcceptData

-- | Payload of a channel accept.
data ChannelAcceptData = ChannelAccept
    { caRequest :: Stored ChannelRequest
    -- ^ The request being answered.
    , caKey :: Stored PublicKexKey
    -- ^ Accepting side's freshly generated key-exchange public key.
    }

-- Each record stores an "enc" field naming the cipher suite. Loading
-- rejects (via 'guard') anything other than "aes-128-gcm", the only suite
-- implemented by this module.
instance Storable Channel where
    store' ch = storeRec $ do
        mapM_ (storeRef "peer") $ chPeers ch
        storeText "enc" $ T.pack "aes-128-gcm"
        storeBinary "key" $ chKey ch

    load' = loadRec $ do
        enc <- loadText "enc"
        guard $ enc == "aes-128-gcm"
        Channel
            <$> loadRefs "peer"
            <*> loadBinary "key"

instance Storable ChannelRequestData where
    store' cr = storeRec $ do
        mapM_ (storeRef "peer") $ crPeers cr
        storeText "enc" $ T.pack "aes-128-gcm"
        storeRef "key" $ crKey cr

    load' = loadRec $ do
        enc <- loadText "enc"
        guard $ enc == "aes-128-gcm"
        ChannelRequest
            <$> loadRefs "peer"
            <*> loadRef "key"

instance Storable ChannelAcceptData where
    store' ca = storeRec $ do
        storeRef "req" $ caRequest ca
        storeText "enc" $ T.pack "aes-128-gcm"
        storeRef "key" $ caKey ca

    load' = loadRec $ do
        enc <- loadText "enc"
        guard $ enc == "aes-128-gcm"
        ChannelAccept
            <$> loadRef "req"
            <*> loadRef "key"


-- | Payload of a stored, signed channel request.
requestData :: Stored ChannelRequest -> ChannelRequestData
requestData = fromStored . signedData . fromStored

-- | Payload of a stored, signed channel accept.
acceptData :: Stored ChannelAccept -> ChannelAcceptData
acceptData = fromStored . signedData . fromStored

-- | Whether the message key of the given identity is among the keys that
-- produced the signatures on the stored object.
signedByIdentity :: UnifiedIdentity -> Stored (Signed a) -> Bool
signedByIdentity ident sobj =
    idKeyMessage ident `elem` map (sigKey . fromStored) (signedSignature $ fromStored sobj)

-- | Peer list for a channel between two identities. It is sorted, so both
-- sides get the same list regardless of which one initiated.
channelPeerList :: UnifiedIdentity -> UnifiedIdentity -> [Stored (Signed IdentityData)]
channelPeerList a b = sort [idData a, idData b]

-- | Number of bytes of the key-exchange secret used as the AES-128 key.
channelKeySize :: Int
channelKeySize = case cipherKeySize (undefined :: AES128) of
    KeySizeFixed n -> n
    _ -> error "Channel.channelKeySize: AES128 key size is expected to be fixed"

-- | Nonce length prepended to every encrypted message, in bytes (96 bits,
-- the nonce size recommended for GCM).
channelNonceSize :: Int
channelNonceSize = 12

-- | GCM authentication tag length appended to every encrypted message, in
-- bytes. This equals the AES block size, which earlier versions of this
-- module used as the tag length.
channelTagSize :: Int
channelTagSize = 16

-- | Initialise AES-128 with the channel key and GCM with the given nonce.
channelAead :: (ByteArrayAccess iv, MonadError String m) => Channel -> iv -> m (AEAD AES128)
channelAead ch iv = do
    cipher <- case cipherInit $ chKey ch of
        CryptoPassed (cipher :: AES128) -> return cipher
        _ -> throwError "failed to init AES128 cipher"
    case aeadInit AEAD_GCM cipher iv of
        CryptoPassed aead -> return aead
        _ -> throwError "failed to init AEAD_GCM"


-- | Create and store a channel request from @self@ to @peer@.
--
-- A new key-exchange key pair is generated with 'generateKeys'. Its secret
-- half is what 'acceptedChannel' later looks up with 'loadKey'. The public
-- half goes into the request, which is signed with @self@'s message key and
-- stored.
--
-- Throws an 'IOError' ("secret key not found") if the secret message key of
-- @self@ cannot be loaded.
createChannelRequest :: (MonadIO m) => Storage -> UnifiedIdentity -> UnifiedIdentity -> m (Stored ChannelRequest)
createChannelRequest st self peer = liftIO $ do
    -- Look up the signing key first, so a missing key does not leave a
    -- freshly generated key-exchange pair behind.
    skey <- loadKey (idKeyMessage self) >>= \case
        Just key -> return key
        Nothing -> ioError $ userError "secret key not found"
    (_, xpublic) <- generateKeys st
    wrappedStore st =<< sign skey =<< wrappedStore st ChannelRequest
        { crPeers = channelPeerList self peer
        , crKey = xpublic
        }

-- | Accept a channel request received from @peer@.
--
-- The request must list exactly @self@ and @peer@ and must be signed by
-- @peer@. On success, a fresh key-exchange pair is generated; a signed
-- 'ChannelAccept' carrying its public half is stored, together with the
-- resulting 'Channel'.
--
-- Errors (via 'throwError'): mismatched peers, missing peer signature, or
-- @self@'s secret message key not being available.
acceptChannelRequest :: (MonadIO m, MonadError String m) => UnifiedIdentity -> UnifiedIdentity -> Stored ChannelRequest -> m (Stored ChannelAccept, Stored Channel)
acceptChannelRequest self peer req = do
    when (crPeers (requestData req) /= channelPeerList self peer) $
        throwError "mismatched peers in channel request"
    unless (signedByIdentity peer req) $
        throwError "channel request not signed by peer"

    let st = storedStorage req
    skey <- liftIO (loadKey $ idKeyMessage self) >>= \case
        Just key -> return key
        Nothing -> throwError "secret key not found"

    liftIO $ do
        (xsecret, xpublic) <- generateKeys st
        acc <- wrappedStore st =<< sign skey =<< wrappedStore st ChannelAccept
            { caRequest = req
            , caKey = xpublic
            }
        -- NOTE(review): the raw key-exchange output is truncated and used as
        -- the AES key without a KDF. Changing this would break compatibility
        -- with existing peers, so it is kept as-is.
        ch <- wrappedStore st Channel
            { chPeers = crPeers $ requestData req
            , chKey = BA.take channelKeySize $ dhSecret xsecret $ fromStored $ crKey $ requestData req
            }
        return (acc, ch)

-- | Complete a channel from a 'ChannelAccept' that @peer@ sent in reply to
-- our request.
--
-- Checks that the original request lists exactly @self@ and @peer@, that
-- the accept is signed by @peer@, and that the original request is signed
-- by @self@. The channel key is then derived from our stored key-exchange
-- secret (the one matching 'crKey') and the peer's 'caKey'.
--
-- Errors (via 'throwError'): any failed check, or the key-exchange secret
-- not being found.
acceptedChannel :: (MonadIO m, MonadError String m) => UnifiedIdentity -> UnifiedIdentity -> Stored ChannelAccept -> m (Stored Channel)
acceptedChannel self peer acc = do
    let st = storedStorage acc
        req = caRequest $ acceptData acc

    when (crPeers (requestData req) /= channelPeerList self peer) $
        throwError "mismatched peers in channel accept"
    unless (signedByIdentity peer acc) $
        throwError "channel accept not signed by peer"
    unless (signedByIdentity self req) $
        throwError "original channel request not signed by us"

    xsecret <- liftIO (loadKey $ crKey $ requestData req) >>= \case
        Just key -> return key
        Nothing -> throwError "secret key not found"

    liftIO $ wrappedStore st Channel
        { chPeers = crPeers $ requestData req
        , chKey = BA.take channelKeySize $ dhSecret xsecret $ fromStored $ caKey $ acceptData acc
        }


-- | Encrypt a message with AES-128-GCM under the channel key.
--
-- Output layout: random nonce ('channelNonceSize' bytes), then ciphertext,
-- then authentication tag ('channelTagSize' bytes). No associated data is
-- authenticated.
--
-- NOTE(review): nonces are drawn at random for every message, so the usual
-- random-nonce GCM limits on messages per key apply.
channelEncrypt :: (ByteArray ba, MonadIO m, MonadError String m) => Channel -> ba -> m ba
channelEncrypt ch plain = do
    (iv :: ByteString) <- liftIO $ getRandomBytes channelNonceSize
    aead <- channelAead ch iv
    let (tag, ctext) = aeadSimpleEncrypt aead B.empty plain channelTagSize
    return $ BA.concat [ convert iv, ctext, convert tag ]

-- | Decrypt and authenticate a message produced by 'channelEncrypt'.
--
-- Errors (via 'throwError'): input shorter than nonce plus tag, cipher
-- initialisation failure, or authentication failure.
channelDecrypt :: (ByteArray ba, MonadError String m) => Channel -> ba -> m ba
channelDecrypt ch body = do
    -- Reject truncated input explicitly, instead of relying on splitAt
    -- clamping a negative length.
    when (BA.length body < channelNonceSize + channelTagSize) $
        throwError "ciphertext too short"
    let (iv, body') = BA.splitAt channelNonceSize body
        (ctext, tag) = BA.splitAt (BA.length body' - channelTagSize) body'
    aead <- channelAead ch iv
    case aeadSimpleDecrypt aead B.empty ctext (AuthTag $ convert tag) of
        Just plain -> return plain
        Nothing -> throwError "failed to decrypt data"