diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/Channel.hs | 128 | ||||
| -rw-r--r-- | src/Network/Protocol.hs | 4 | 
2 files changed, 63 insertions, 69 deletions
| diff --git a/src/Channel.hs b/src/Channel.hs index b273392..a1773bd 100644 --- a/src/Channel.hs +++ b/src/Channel.hs @@ -15,17 +15,14 @@ import Control.Concurrent.MVar  import Control.Monad  import Control.Monad.Except -import Crypto.Cipher.AES -import Crypto.Cipher.Types +import Crypto.Cipher.ChaChaPoly1305  import Crypto.Error  import Data.Binary -import Data.ByteArray (ByteArray, Bytes, ScrubbedBytes, append, convert) +import Data.ByteArray (ByteArray, Bytes, ScrubbedBytes, convert)  import Data.ByteArray qualified as BA -import Data.ByteString qualified as B  import Data.ByteString.Lazy qualified as BL  import Data.List -import Data.Text qualified as T  import Identity  import PubKey @@ -36,7 +33,8 @@ data Channel = Channel      , chKey :: ScrubbedBytes      , chNonceFixedOur :: Bytes      , chNonceFixedPeer :: Bytes -    , chNonceCounter :: MVar Word64 +    , chCounterNextOut :: MVar Word64 +    , chCounterNextIn :: MVar Word64      }  type ChannelRequest = Signed ChannelRequestData @@ -58,12 +56,9 @@ data ChannelAcceptData = ChannelAccept  instance Storable ChannelRequestData where      store' cr = storeRec $ do          mapM_ (storeRef "peer") $ crPeers cr -        storeText "enc" $ T.pack "aes-128-gcm"          storeRef "key" $ crKey cr      load' = loadRec $ do -        enc <- loadText "enc" -        guard $ enc == "aes-128-gcm"          ChannelRequest              <$> loadRefs "peer"              <*> loadRef "key" @@ -71,17 +66,17 @@ instance Storable ChannelRequestData where  instance Storable ChannelAcceptData where      store' ca = storeRec $ do          storeRef "req" $ caRequest ca -        storeText "enc" $ T.pack "aes-128-gcm"          storeRef "key" $ caKey ca      load' = loadRec $ do -        enc <- loadText "enc" -        guard $ enc == "aes-128-gcm"          ChannelAccept              <$> loadRef "req"              <*> loadRef "key" +keySize :: Int +keySize = 32 +  createChannelRequest :: (MonadIO m) => Storage -> UnifiedIdentity -> UnifiedIdentity -> m (Stored ChannelRequest)  createChannelRequest st self peer = liftIO $ do      (_, xpublic) <- generateKeys st @@ -101,30 +96,23 @@ acceptChannelRequest self peer req = do          throwError $ "channel requent not signed by peer"      let st = storedStorage req -    ksize <- case cipherKeySize (undefined :: AES128) of -                  KeySizeFixed s -> return s -                  _ -> throwError "expecting fixed key size"      liftIO $ do          (xsecret, xpublic) <- generateKeys st          Just skey <- loadKey $ idKeyMessage self          acc <- wrappedStore st =<< sign skey =<< wrappedStore st ChannelAccept { caRequest = req, caKey = xpublic } -        counter <- newMVar 0 -        return $ (acc,) $ Channel -            { chPeers = crPeers $ fromStored $ signedData $ fromStored req -            , chKey = BA.take ksize $ dhSecret xsecret $ +        let chPeers = crPeers $ fromStored $ signedData $ fromStored req +            chKey = BA.take keySize $ dhSecret xsecret $                  fromStored $ crKey $ fromStored $ signedData $ fromStored req -            , chNonceFixedOur  = BA.pack [ 2, 0, 0, 0, 0, 0 ] -            , chNonceFixedPeer = BA.pack [ 1, 0, 0, 0, 0, 0 ] -            , chNonceCounter = counter -            } +            chNonceFixedOur  = BA.pack [ 2, 0, 0, 0 ] +            chNonceFixedPeer = BA.pack [ 1, 0, 0, 0 ] +        chCounterNextOut <- newMVar 0 +        chCounterNextIn <- newMVar 0 + +        return (acc, Channel {..})  acceptedChannel :: (MonadIO m, MonadError String m) => UnifiedIdentity -> UnifiedIdentity -> Stored ChannelAccept -> m Channel  acceptedChannel self peer acc = do      let req = caRequest $ fromStored $ signedData $ fromStored acc -    ksize <- case cipherKeySize (undefined :: AES128) of -                  KeySizeFixed s -> return s -                  _ -> throwError "expecting fixed key size" -      case sequence $ map validateIdentity $ crPeers $ fromStored $ signedData $ fromStored req of          Nothing -> throwError $ "invalid peers in channel accept"          Just peers -> do @@ -140,44 +128,50 @@ acceptedChannel self peer acc = do      xsecret <- liftIO (loadKey $ crKey $ fromStored $ signedData $ fromStored req) >>= \case          Just key -> return key          Nothing  -> throwError $ "secret key not found" -    counter <- liftIO $ newMVar 0 -    return $ Channel -        { chPeers = crPeers $ fromStored $ signedData $ fromStored req -        , chKey = BA.take ksize $ dhSecret xsecret $ +    let chPeers = crPeers $ fromStored $ signedData $ fromStored req +        chKey = BA.take keySize $ dhSecret xsecret $              fromStored $ caKey $ fromStored $ signedData $ fromStored acc -        , chNonceFixedOur  = BA.pack [ 1, 0, 0, 0, 0, 0 ] -        , chNonceFixedPeer = BA.pack [ 2, 0, 0, 0, 0, 0 ] -        , chNonceCounter = counter -        } - - -channelEncrypt :: (ByteArray ba, MonadIO m, MonadError String m) => Channel -> ba -> m ba -channelEncrypt ch plain = do -    cipher <- case cipherInit $ chKey ch of -                   CryptoPassed (cipher :: AES128) -> return cipher -                   _ -> throwError "failed to init AES128 cipher" -    let bsize = blockSize cipher -    count <- liftIO $ modifyMVar (chNonceCounter ch) $ \c -> return (c + 1, c) -    let cbytes = convert $ BL.toStrict $ BL.drop 2 $ encode count -        iv = chNonceFixedOur ch `append` cbytes -    aead <- case aeadInit AEAD_GCM cipher iv of -                 CryptoPassed aead -> return aead -                 _ -> throwError "failed to init AEAD_GCM" -    let (tag, ctext) = aeadSimpleEncrypt aead B.empty plain bsize -    return $ BA.concat [ BA.pack [ 0, 0 ], convert cbytes, ctext, convert tag ] - -channelDecrypt :: (ByteArray ba, MonadError String m) => Channel -> ba -> m ba -channelDecrypt ch body = do -    cipher <- case cipherInit $ chKey ch of -                   CryptoPassed (cipher :: AES128) -> return cipher -                   _ -> throwError "failed to init AES128 cipher" -    let bsize = blockSize cipher -        (cbytes, body') = BA.splitAt 8 body -        iv = chNonceFixedPeer ch `append` convert (BA.drop 2 cbytes) -        (ctext, tag) = BA.splitAt (BA.length body' - bsize) body' -    aead <- case aeadInit AEAD_GCM cipher iv of -                 CryptoPassed aead -> return aead -                 _ -> throwError "failed to init AEAD_GCM" -    case aeadSimpleDecrypt aead B.empty ctext (AuthTag $ convert tag) of -         Just plain -> return plain -         Nothing -> throwError "failed to decrypt data" +        chNonceFixedOur  = BA.pack [ 1, 0, 0, 0 ] +        chNonceFixedPeer = BA.pack [ 2, 0, 0, 0 ] +    chCounterNextOut <- liftIO $ newMVar 0 +    chCounterNextIn <- liftIO $ newMVar 0 + +    return Channel {..} + + +channelEncrypt :: (ByteArray ba, MonadIO m, MonadError String m) => Channel -> ba -> m (ba, Word64) +channelEncrypt Channel {..} plain = do +    count <- liftIO $ modifyMVar chCounterNextOut $ \c -> return (c + 1, c) +    let cbytes = convert $ BL.toStrict $ encode count +        nonce = nonce8 chNonceFixedOur cbytes +    state <- case initialize chKey =<< nonce of +        CryptoPassed state -> return state +        CryptoFailed err -> throwError $ "failed to init chacha-poly1305 cipher: " <> show err + +    let (ctext, state') = encrypt plain state +        tag = finalize state' +    return (BA.concat [ convert $ BA.drop 7 cbytes, ctext, convert tag ], count) + +channelDecrypt :: (ByteArray ba, MonadIO m, MonadError String m) => Channel -> ba -> m (ba, Word64) +channelDecrypt Channel {..} body = do +    when (BA.length body < 17) $ do +        throwError $ "invalid encrypted data length" + +    expectedCount <- liftIO $ readMVar chCounterNextIn +    let countByte = body `BA.index` 0 +        body' = BA.dropView body 1 +        guessedCount = expectedCount - 128 + fromIntegral (countByte - fromIntegral expectedCount + 128 :: Word8) +        nonce = nonce8 chNonceFixedPeer $ convert $ BL.toStrict $ encode guessedCount +        blen = BA.length body' - 16 +        ctext = BA.takeView body' blen +        tag = BA.dropView body' blen +    state <- case initialize chKey =<< nonce of +        CryptoPassed state -> return state +        CryptoFailed err -> throwError $ "failed to init chacha-poly1305 cipher: " <> show err + +    let (plain, state') = decrypt (convert ctext) state +    when (not $ tag `BA.constEq` finalize state') $ do +        throwError $ "tag validation falied" + +    liftIO $ modifyMVar_ chCounterNextIn $ return . max (guessedCount + 1) +    return (plain, guessedCount) diff --git a/src/Network/Protocol.hs b/src/Network/Protocol.hs index 554d93c..054c0fb 100644 --- a/src/Network/Protocol.hs +++ b/src/Network/Protocol.hs @@ -223,7 +223,7 @@ processIncomming gs@GlobalState {..} = do                  Just (b, enc)                      | b .&. 0xE0 == 0x80 -> do                          ch <- maybe (throwError "unexpected encrypted packet") return mbch -                        dec <- channelDecrypt ch enc +                        (dec, _) <- channelDecrypt ch enc                          case B.uncons dec of                              Just (0x00, content) -> do @@ -297,7 +297,7 @@ processOutgoing gs@GlobalState {..} = do                  mbs <- case mbch of                      Just ch -> do                          runExceptT (channelEncrypt ch $ BL.toStrict $ 0x00 `BL.cons` plain) >>= \case -                            Right ctext -> return $ Just $ 0x80 `B.cons` ctext +                            Right (ctext, _) -> return $ Just $ 0x80 `B.cons` ctext                              Left err -> do atomically $ gLog $ "Failed to encrypt data: " ++ err                                             return Nothing                      Nothing | secure    -> return Nothing |