summaryrefslogtreecommitdiff
path: root/src/Channel.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Channel.hs')
-rw-r--r--src/Channel.hs66
1 files changed, 43 insertions, 23 deletions
diff --git a/src/Channel.hs b/src/Channel.hs
index 50e1b81..ad88190 100644
--- a/src/Channel.hs
+++ b/src/Channel.hs
@@ -13,7 +13,6 @@ module Channel (
import Control.Monad
import Control.Monad.Except
-import Control.Monad.Fail
import Crypto.Cipher.AES
import Crypto.Cipher.Types
@@ -43,6 +42,7 @@ data ChannelRequestData = ChannelRequest
{ crPeers :: [Stored (Signed IdentityData)]
, crKey :: Stored PublicKexKey
}
+ deriving (Show)
type ChannelAccept = Signed ChannelAcceptData
@@ -68,11 +68,15 @@ instance Storable Channel where
instance Storable ChannelRequestData where
store' cr = storeRec $ do
mapM_ (storeRef "peer") $ crPeers cr
+ storeText "enc" $ T.pack "aes-128-gcm"
storeRef "key" $ crKey cr
- load' = loadRec $ ChannelRequest
- <$> loadRefs "peer"
- <*> loadRef "key"
+ load' = loadRec $ do
+ enc <- loadText "enc"
+ guard $ enc == "aes-128-gcm"
+ ChannelRequest
+ <$> loadRefs "peer"
+ <*> loadRef "key"
instance Storable ChannelAcceptData where
store' ca = storeRec $ do
@@ -88,16 +92,18 @@ instance Storable ChannelAcceptData where
<*> loadRef "key"
-createChannelRequest :: Storage -> UnifiedIdentity -> UnifiedIdentity -> IO (Stored ChannelRequest)
-createChannelRequest st self peer = do
+createChannelRequest :: (MonadIO m) => Storage -> UnifiedIdentity -> UnifiedIdentity -> m (Stored ChannelRequest)
+createChannelRequest st self peer = liftIO $ do
(_, xpublic) <- generateKeys st
Just skey <- loadKey $ idKeyMessage self
wrappedStore st =<< sign skey =<< wrappedStore st ChannelRequest { crPeers = sort [idData self, idData peer], crKey = xpublic }
-acceptChannelRequest :: UnifiedIdentity -> UnifiedIdentity -> Stored ChannelRequest -> ExceptT [String] IO (Stored ChannelAccept, Stored Channel)
+acceptChannelRequest :: (MonadIO m, MonadError String m) => UnifiedIdentity -> UnifiedIdentity -> Stored ChannelRequest -> m (Stored ChannelAccept, Stored Channel)
acceptChannelRequest self peer req = do
- guard $ (crPeers $ fromStored $ signedData $ fromStored req) == sort (map idData [self, peer])
- guard $ (idKeyMessage peer) `elem` (map (sigKey . fromStored) $ signedSignature $ fromStored req)
+ when ((crPeers $ fromStored $ signedData $ fromStored req) /= sort (map idData [self, peer])) $
+ throwError $ "mismatched peers in channel request"
+ when (idKeyMessage peer `notElem` (map (sigKey . fromStored) $ signedSignature $ fromStored req)) $
+ throwError $ "channel requent not signed by peer"
let st = storedStorage req
KeySizeFixed ksize = cipherKeySize (undefined :: AES128)
@@ -112,17 +118,22 @@ acceptChannelRequest self peer req = do
}
return (acc, ch)
-acceptedChannel :: UnifiedIdentity -> UnifiedIdentity -> Stored ChannelAccept -> ExceptT [String] IO (Stored Channel)
+acceptedChannel :: (MonadIO m, MonadError String m) => UnifiedIdentity -> UnifiedIdentity -> Stored ChannelAccept -> m (Stored Channel)
acceptedChannel self peer acc = do
let st = storedStorage acc
req = caRequest $ fromStored $ signedData $ fromStored acc
KeySizeFixed ksize = cipherKeySize (undefined :: AES128)
- guard $ (crPeers $ fromStored $ signedData $ fromStored req) == sort (map idData [self, peer])
- guard $ idKeyMessage peer `elem` (map (sigKey . fromStored) $ signedSignature $ fromStored acc)
- guard $ idKeyMessage self `elem` (map (sigKey . fromStored) $ signedSignature $ fromStored req)
+ when ((crPeers $ fromStored $ signedData $ fromStored req) /= sort (map idData [self, peer])) $
+ throwError $ "mismatched peers in channel accept"
+ when (idKeyMessage peer `notElem` (map (sigKey . fromStored) $ signedSignature $ fromStored acc)) $
+ throwError $ "channel accept not signed by peer"
+ when (idKeyMessage self `notElem` (map (sigKey . fromStored) $ signedSignature $ fromStored req)) $
+ throwError $ "original channel request not signed by us"
- Just xsecret <- liftIO $ loadKey $ crKey $ fromStored $ signedData $ fromStored req
+ xsecret <- liftIO (loadKey $ crKey $ fromStored $ signedData $ fromStored req) >>= \case
+ Just key -> return key
+ Nothing -> throwError $ "secret key not found"
liftIO $ wrappedStore st Channel
{ chPeers = crPeers $ fromStored $ signedData $ fromStored req
, chKey = BA.take ksize $ dhSecret xsecret $
@@ -130,21 +141,30 @@ acceptedChannel self peer acc = do
}
-channelEncrypt :: (ByteArray ba, MonadRandom m, MonadFail m) => Channel -> ba -> m ba
+channelEncrypt :: (ByteArray ba, MonadIO m, MonadError String m) => Channel -> ba -> m ba
channelEncrypt ch plain = do
- CryptoPassed (cipher :: AES128) <- return $ cipherInit $ chKey ch
+ cipher <- case cipherInit $ chKey ch of
+ CryptoPassed (cipher :: AES128) -> return cipher
+ _ -> throwError "failed to init AES128 cipher"
let bsize = blockSize cipher
- (iv :: ByteString) <- getRandomBytes 12
- CryptoPassed aead <- return $ aeadInit AEAD_GCM cipher iv
+ (iv :: ByteString) <- liftIO $ getRandomBytes 12
+ aead <- case aeadInit AEAD_GCM cipher iv of
+ CryptoPassed aead -> return aead
+ _ -> throwError "failed to init AEAD_GCM"
let (tag, ctext) = aeadSimpleEncrypt aead B.empty plain bsize
return $ BA.concat [ convert iv, ctext, convert tag ]
-channelDecrypt :: (ByteArray ba, MonadFail m) => Channel -> ba -> m ba
+channelDecrypt :: (ByteArray ba, MonadError String m) => Channel -> ba -> m ba
channelDecrypt ch body = do
- CryptoPassed (cipher :: AES128) <- return $ cipherInit $ chKey ch
+ cipher <- case cipherInit $ chKey ch of
+ CryptoPassed (cipher :: AES128) -> return cipher
+ _ -> throwError "failed to init AES128 cipher"
let bsize = blockSize cipher
(iv, body') = BA.splitAt 12 body
(ctext, tag) = BA.splitAt (BA.length body' - bsize) body'
- CryptoPassed aead <- return $ aeadInit AEAD_GCM cipher iv
- Just plain <- return $ aeadSimpleDecrypt aead B.empty ctext (AuthTag $ convert tag)
- return plain
+ aead <- case aeadInit AEAD_GCM cipher iv of
+ CryptoPassed aead -> return aead
+ _ -> throwError "failed to init AEAD_GCM"
+ case aeadSimpleDecrypt aead B.empty ctext (AuthTag $ convert tag) of
+ Just plain -> return plain
+ Nothing -> throwError "failed to decrypt data"