summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRoman Smrž <roman.smrz@seznam.cz>2023-08-09 22:48:11 +0200
committerRoman Smrž <roman.smrz@seznam.cz>2023-08-27 12:01:16 +0200
commit740c55ac1989ba5093af9350a63820a818ff0202 (patch)
tree26d6257b4871059858bcb421fdd5be5718603428
parentee1dce0d8d3a2f08dac579a0453b69a37110d2ae (diff)
Switch to ChaCha20-Poly1305 AEAD scheme
-rw-r--r--src/Channel.hs128
-rw-r--r--src/Network/Protocol.hs4
2 files changed, 63 insertions, 69 deletions
diff --git a/src/Channel.hs b/src/Channel.hs
index b273392..a1773bd 100644
--- a/src/Channel.hs
+++ b/src/Channel.hs
@@ -15,17 +15,14 @@ import Control.Concurrent.MVar
import Control.Monad
import Control.Monad.Except
-import Crypto.Cipher.AES
-import Crypto.Cipher.Types
+import Crypto.Cipher.ChaChaPoly1305
import Crypto.Error
import Data.Binary
-import Data.ByteArray (ByteArray, Bytes, ScrubbedBytes, append, convert)
+import Data.ByteArray (ByteArray, Bytes, ScrubbedBytes, convert)
import Data.ByteArray qualified as BA
-import Data.ByteString qualified as B
import Data.ByteString.Lazy qualified as BL
import Data.List
-import Data.Text qualified as T
import Identity
import PubKey
@@ -36,7 +33,8 @@ data Channel = Channel
, chKey :: ScrubbedBytes
, chNonceFixedOur :: Bytes
, chNonceFixedPeer :: Bytes
- , chNonceCounter :: MVar Word64
+ , chCounterNextOut :: MVar Word64
+ , chCounterNextIn :: MVar Word64
}
type ChannelRequest = Signed ChannelRequestData
@@ -58,12 +56,9 @@ data ChannelAcceptData = ChannelAccept
instance Storable ChannelRequestData where
store' cr = storeRec $ do
mapM_ (storeRef "peer") $ crPeers cr
- storeText "enc" $ T.pack "aes-128-gcm"
storeRef "key" $ crKey cr
load' = loadRec $ do
- enc <- loadText "enc"
- guard $ enc == "aes-128-gcm"
ChannelRequest
<$> loadRefs "peer"
<*> loadRef "key"
@@ -71,17 +66,17 @@ instance Storable ChannelRequestData where
instance Storable ChannelAcceptData where
store' ca = storeRec $ do
storeRef "req" $ caRequest ca
- storeText "enc" $ T.pack "aes-128-gcm"
storeRef "key" $ caKey ca
load' = loadRec $ do
- enc <- loadText "enc"
- guard $ enc == "aes-128-gcm"
ChannelAccept
<$> loadRef "req"
<*> loadRef "key"
+keySize :: Int
+keySize = 32
+
createChannelRequest :: (MonadIO m) => Storage -> UnifiedIdentity -> UnifiedIdentity -> m (Stored ChannelRequest)
createChannelRequest st self peer = liftIO $ do
(_, xpublic) <- generateKeys st
@@ -101,30 +96,23 @@ acceptChannelRequest self peer req = do
throwError $ "channel requent not signed by peer"
let st = storedStorage req
- ksize <- case cipherKeySize (undefined :: AES128) of
- KeySizeFixed s -> return s
- _ -> throwError "expecting fixed key size"
liftIO $ do
(xsecret, xpublic) <- generateKeys st
Just skey <- loadKey $ idKeyMessage self
acc <- wrappedStore st =<< sign skey =<< wrappedStore st ChannelAccept { caRequest = req, caKey = xpublic }
- counter <- newMVar 0
- return $ (acc,) $ Channel
- { chPeers = crPeers $ fromStored $ signedData $ fromStored req
- , chKey = BA.take ksize $ dhSecret xsecret $
+ let chPeers = crPeers $ fromStored $ signedData $ fromStored req
+ chKey = BA.take keySize $ dhSecret xsecret $
fromStored $ crKey $ fromStored $ signedData $ fromStored req
- , chNonceFixedOur = BA.pack [ 2, 0, 0, 0, 0, 0 ]
- , chNonceFixedPeer = BA.pack [ 1, 0, 0, 0, 0, 0 ]
- , chNonceCounter = counter
- }
+ chNonceFixedOur = BA.pack [ 2, 0, 0, 0 ]
+ chNonceFixedPeer = BA.pack [ 1, 0, 0, 0 ]
+ chCounterNextOut <- newMVar 0
+ chCounterNextIn <- newMVar 0
+
+ return (acc, Channel {..})
acceptedChannel :: (MonadIO m, MonadError String m) => UnifiedIdentity -> UnifiedIdentity -> Stored ChannelAccept -> m Channel
acceptedChannel self peer acc = do
let req = caRequest $ fromStored $ signedData $ fromStored acc
- ksize <- case cipherKeySize (undefined :: AES128) of
- KeySizeFixed s -> return s
- _ -> throwError "expecting fixed key size"
-
case sequence $ map validateIdentity $ crPeers $ fromStored $ signedData $ fromStored req of
Nothing -> throwError $ "invalid peers in channel accept"
Just peers -> do
@@ -140,44 +128,50 @@ acceptedChannel self peer acc = do
xsecret <- liftIO (loadKey $ crKey $ fromStored $ signedData $ fromStored req) >>= \case
Just key -> return key
Nothing -> throwError $ "secret key not found"
- counter <- liftIO $ newMVar 0
- return $ Channel
- { chPeers = crPeers $ fromStored $ signedData $ fromStored req
- , chKey = BA.take ksize $ dhSecret xsecret $
+ let chPeers = crPeers $ fromStored $ signedData $ fromStored req
+ chKey = BA.take keySize $ dhSecret xsecret $
fromStored $ caKey $ fromStored $ signedData $ fromStored acc
- , chNonceFixedOur = BA.pack [ 1, 0, 0, 0, 0, 0 ]
- , chNonceFixedPeer = BA.pack [ 2, 0, 0, 0, 0, 0 ]
- , chNonceCounter = counter
- }
-
-
-channelEncrypt :: (ByteArray ba, MonadIO m, MonadError String m) => Channel -> ba -> m ba
-channelEncrypt ch plain = do
- cipher <- case cipherInit $ chKey ch of
- CryptoPassed (cipher :: AES128) -> return cipher
- _ -> throwError "failed to init AES128 cipher"
- let bsize = blockSize cipher
- count <- liftIO $ modifyMVar (chNonceCounter ch) $ \c -> return (c + 1, c)
- let cbytes = convert $ BL.toStrict $ BL.drop 2 $ encode count
- iv = chNonceFixedOur ch `append` cbytes
- aead <- case aeadInit AEAD_GCM cipher iv of
- CryptoPassed aead -> return aead
- _ -> throwError "failed to init AEAD_GCM"
- let (tag, ctext) = aeadSimpleEncrypt aead B.empty plain bsize
- return $ BA.concat [ BA.pack [ 0, 0 ], convert cbytes, ctext, convert tag ]
-
-channelDecrypt :: (ByteArray ba, MonadError String m) => Channel -> ba -> m ba
-channelDecrypt ch body = do
- cipher <- case cipherInit $ chKey ch of
- CryptoPassed (cipher :: AES128) -> return cipher
- _ -> throwError "failed to init AES128 cipher"
- let bsize = blockSize cipher
- (cbytes, body') = BA.splitAt 8 body
- iv = chNonceFixedPeer ch `append` convert (BA.drop 2 cbytes)
- (ctext, tag) = BA.splitAt (BA.length body' - bsize) body'
- aead <- case aeadInit AEAD_GCM cipher iv of
- CryptoPassed aead -> return aead
- _ -> throwError "failed to init AEAD_GCM"
- case aeadSimpleDecrypt aead B.empty ctext (AuthTag $ convert tag) of
- Just plain -> return plain
- Nothing -> throwError "failed to decrypt data"
+ chNonceFixedOur = BA.pack [ 1, 0, 0, 0 ]
+ chNonceFixedPeer = BA.pack [ 2, 0, 0, 0 ]
+ chCounterNextOut <- liftIO $ newMVar 0
+ chCounterNextIn <- liftIO $ newMVar 0
+
+ return Channel {..}
+
+
+channelEncrypt :: (ByteArray ba, MonadIO m, MonadError String m) => Channel -> ba -> m (ba, Word64)
+channelEncrypt Channel {..} plain = do
+ count <- liftIO $ modifyMVar chCounterNextOut $ \c -> return (c + 1, c)
+ let cbytes = convert $ BL.toStrict $ encode count
+ nonce = nonce8 chNonceFixedOur cbytes
+ state <- case initialize chKey =<< nonce of
+ CryptoPassed state -> return state
+ CryptoFailed err -> throwError $ "failed to init chacha-poly1305 cipher: " <> show err
+
+ let (ctext, state') = encrypt plain state
+ tag = finalize state'
+ return (BA.concat [ convert $ BA.drop 7 cbytes, ctext, convert tag ], count)
+
+channelDecrypt :: (ByteArray ba, MonadIO m, MonadError String m) => Channel -> ba -> m (ba, Word64)
+channelDecrypt Channel {..} body = do
+ when (BA.length body < 17) $ do
+ throwError $ "invalid encrypted data length"
+
+ expectedCount <- liftIO $ readMVar chCounterNextIn
+ let countByte = body `BA.index` 0
+ body' = BA.dropView body 1
+ guessedCount = expectedCount - 128 + fromIntegral (countByte - fromIntegral expectedCount + 128 :: Word8)
+ nonce = nonce8 chNonceFixedPeer $ convert $ BL.toStrict $ encode guessedCount
+ blen = BA.length body' - 16
+ ctext = BA.takeView body' blen
+ tag = BA.dropView body' blen
+ state <- case initialize chKey =<< nonce of
+ CryptoPassed state -> return state
+ CryptoFailed err -> throwError $ "failed to init chacha-poly1305 cipher: " <> show err
+
+ let (plain, state') = decrypt (convert ctext) state
+ when (not $ tag `BA.constEq` finalize state') $ do
+ throwError $ "tag validation falied"
+
+ liftIO $ modifyMVar_ chCounterNextIn $ return . max (guessedCount + 1)
+ return (plain, guessedCount)
diff --git a/src/Network/Protocol.hs b/src/Network/Protocol.hs
index 554d93c..054c0fb 100644
--- a/src/Network/Protocol.hs
+++ b/src/Network/Protocol.hs
@@ -223,7 +223,7 @@ processIncomming gs@GlobalState {..} = do
Just (b, enc)
| b .&. 0xE0 == 0x80 -> do
ch <- maybe (throwError "unexpected encrypted packet") return mbch
- dec <- channelDecrypt ch enc
+ (dec, _) <- channelDecrypt ch enc
case B.uncons dec of
Just (0x00, content) -> do
@@ -297,7 +297,7 @@ processOutgoing gs@GlobalState {..} = do
mbs <- case mbch of
Just ch -> do
runExceptT (channelEncrypt ch $ BL.toStrict $ 0x00 `BL.cons` plain) >>= \case
- Right ctext -> return $ Just $ 0x80 `B.cons` ctext
+ Right (ctext, _) -> return $ Just $ 0x80 `B.cons` ctext
Left err -> do atomically $ gLog $ "Failed to encrypt data: " ++ err
return Nothing
Nothing | secure -> return Nothing