summaryrefslogtreecommitdiff
path: root/src/Erebos/Channel.hs
blob: 5f66637f37911868a3cf08e47d0a5a0b6e896d03 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
module Erebos.Channel (
    Channel,
    ChannelRequest, ChannelRequestData(..),
    ChannelAccept, ChannelAcceptData(..),

    createChannelRequest,
    acceptChannelRequest,
    acceptedChannel,

    channelEncrypt,
    channelDecrypt,
) where

import Control.Concurrent.MVar
import Control.Monad
import Control.Monad.Except
import Control.Monad.IO.Class

import Crypto.Cipher.ChaChaPoly1305
import Crypto.Error

import Data.Binary
import Data.ByteArray (ByteArray, Bytes, ScrubbedBytes, convert)
import Data.ByteArray qualified as BA
import Data.ByteString.Lazy qualified as BL
import Data.List

import Erebos.Identity
import Erebos.PubKey
import Erebos.Storage

-- | State of an established encrypted channel with a peer.
--
-- Messages are encrypted with ChaCha20-Poly1305; each direction uses its own
-- 4-byte nonce prefix combined with a 64-bit message counter.
data Channel = Channel
    { chPeers :: [Stored (Signed IdentityData)]
      -- ^ Identities listed in the channel request this channel was set up from.
    , chKey :: ScrubbedBytes
      -- ^ Symmetric cipher key, truncated to 'keySize' bytes of the DH output.
    , chNonceFixedOur :: Bytes
      -- ^ 4-byte nonce prefix used when encrypting outgoing messages.
    , chNonceFixedPeer :: Bytes
      -- ^ 4-byte nonce prefix expected on incoming messages.
    , chCounterNextOut :: MVar Word64
      -- ^ Counter value to use for the next outgoing message.
    , chCounterNextIn :: MVar Word64
      -- ^ One past the highest counter successfully decrypted so far; used
      -- to reconstruct the full counter of incoming messages.
    }

-- | A signed request to set up a channel.
type ChannelRequest = Signed ChannelRequestData

-- | Contents of a channel request.
data ChannelRequestData = ChannelRequest
    { crPeers :: [Stored (Signed IdentityData)]
      -- ^ Identities of the channel participants ('createChannelRequest'
      -- stores the requester and the peer, sorted).
    , crKey :: Stored PublicKexKey
      -- ^ Key-exchange public key freshly generated by the requester.
    }
    deriving (Show)

-- | A signed acceptance of a 'ChannelRequest'.
type ChannelAccept = Signed ChannelAcceptData

-- | Contents of a channel acceptance.
data ChannelAcceptData = ChannelAccept
    { caRequest :: Stored ChannelRequest
      -- ^ The request being accepted.
    , caKey :: Stored PublicKexKey
      -- ^ Key-exchange public key freshly generated by the accepting side.
    }


-- | Stored as a record with one "peer" reference per listed identity and a
-- single "key" reference for the requester's key-exchange public key.
instance Storable ChannelRequestData where
    store' ChannelRequest {..} = storeRec $ do
        forM_ crPeers $ storeRef "peer"
        storeRef "key" crKey

    load' = loadRec $
        ChannelRequest <$> loadRefs "peer" <*> loadRef "key"

-- | Stored as a record with a "req" reference to the accepted request and a
-- "key" reference to the accepting side's key-exchange public key.
instance Storable ChannelAcceptData where
    store' ChannelAccept {..} = storeRec $ do
        storeRef "req" caRequest
        storeRef "key" caKey

    load' = loadRec $
        ChannelAccept <$> loadRef "req" <*> loadRef "key"


-- | Length in bytes of the symmetric channel key (256 bits).
keySize :: Int
keySize = 256 `div` 8

-- | Create and store a channel request from @self@ to @peer@.
--
-- A fresh key-exchange key pair is generated; only its public half goes
-- into the request. The request lists both identities in sorted order and
-- is signed with @self@'s message key before being stored.
createChannelRequest :: (MonadStorage m, MonadIO m, MonadError String m) => UnifiedIdentity -> UnifiedIdentity -> m (Stored ChannelRequest)
createChannelRequest self peer = do
    storage <- getStorage
    (_, xpublic) <- liftIO $ generateKeys storage
    skey <- loadKey $ idKeyMessage self
    let request = ChannelRequest
            { crPeers = sort [ idData self, idData peer ]
            , crKey = xpublic
            }
    mstore request >>= sign skey >>= mstore

-- | Validate a channel request received from @peer@ and answer it.
--
-- The request must list only valid identities, among them both @self@ and
-- @peer@, and must carry a signature made with @peer@'s message key. On
-- success a fresh key-exchange key pair is generated, a 'ChannelAccept'
-- signed with @self@'s message key is stored, and it is returned together
-- with the resulting 'Channel'. Validation failures are reported through
-- 'throwError'.
acceptChannelRequest :: (MonadStorage m, MonadIO m, MonadError String m) => UnifiedIdentity -> UnifiedIdentity -> Stored ChannelRequest -> m (Stored ChannelAccept, Channel)
acceptChannelRequest self peer req = do
    let reqData = fromStored $ signedData $ fromStored req
        reqSigners = map (sigKey . fromStored) $ signedSignature $ fromStored req

    case traverse validateIdentity $ crPeers reqData of
        Nothing -> throwError "invalid peers in channel request"
        Just peers -> do
            unless (any (self `sameIdentity`) peers) $
                throwError "self identity missing in channel request peers"
            unless (any (peer `sameIdentity`) peers) $
                throwError "peer identity missing in channel request peers"
    unless (idKeyMessage peer `elem` reqSigners) $
        throwError "channel request not signed by peer"

    (xsecret, xpublic) <- liftIO . generateKeys =<< getStorage
    skey <- loadKey $ idKeyMessage self
    acc <- mstore =<< sign skey =<< mstore ChannelAccept { caRequest = req, caKey = xpublic }

    chCounterNextOut <- liftIO $ newMVar 0
    chCounterNextIn <- liftIO $ newMVar 0
    let chPeers = crPeers reqData
        -- shared key from our fresh secret and the requester's public key
        chKey = BA.take keySize $ dhSecret xsecret $ fromStored $ crKey reqData
        -- the accepting side sends with prefix 2 and receives prefix 1; the
        -- requesting side ('acceptedChannel') uses the mirrored assignment,
        -- so the two directions never share a nonce
        chNonceFixedOur  = BA.pack [ 2, 0, 0, 0 ]
        chNonceFixedPeer = BA.pack [ 1, 0, 0, 0 ]
    return (acc, Channel {..})

-- | Finish channel setup on the requesting side once @peer@'s
-- 'ChannelAccept' has arrived.
--
-- Checks that the embedded request lists only valid identities, among them
-- both @self@ and @peer@, that the acceptance is signed by @peer@, and that
-- the original request is signed by @self@. The channel key is then derived
-- from our stored secret for the request key and the peer's acceptance key.
acceptedChannel :: (MonadIO m, MonadError String m) => UnifiedIdentity -> UnifiedIdentity -> Stored ChannelAccept -> m Channel
acceptedChannel self peer acc = do
    let accData = fromStored $ signedData $ fromStored acc
        req = caRequest accData
        reqData = fromStored $ signedData $ fromStored req
        accSigners = map (sigKey . fromStored) $ signedSignature $ fromStored acc
        reqSigners = map (sigKey . fromStored) $ signedSignature $ fromStored req

    peers <- maybe (throwError "invalid peers in channel accept") return $
        traverse validateIdentity $ crPeers reqData
    unless (any (self `sameIdentity`) peers) $
        throwError "self identity missing in channel accept peers"
    unless (any (peer `sameIdentity`) peers) $
        throwError "peer identity missing in channel accept peers"
    unless (idKeyMessage peer `elem` accSigners) $
        throwError "channel accept not signed by peer"
    unless (idKeyMessage self `elem` reqSigners) $
        throwError "original channel request not signed by us"

    xsecret <- loadKey $ crKey reqData
    outCounter <- liftIO $ newMVar 0
    inCounter <- liftIO $ newMVar 0
    -- The requesting side sends with nonce prefix 1 and expects prefix 2.
    return Channel
        { chPeers = crPeers reqData
        , chKey = BA.take keySize $ dhSecret xsecret $ fromStored $ caKey accData
        , chNonceFixedOur = BA.pack [ 1, 0, 0, 0 ]
        , chNonceFixedPeer = BA.pack [ 2, 0, 0, 0 ]
        , chCounterNextOut = outCounter
        , chCounterNextIn = inCounter
        }


-- | Encrypt @plain@ for the channel peer.
--
-- Takes the next outgoing counter value (post-incrementing it atomically)
-- and uses its 8-byte binary encoding, prefixed by our fixed nonce part, as
-- the ChaCha20-Poly1305 nonce. The result is the last byte of the encoded
-- counter, then the ciphertext, then the authentication tag. The counter
-- value used is returned alongside.
channelEncrypt :: (ByteArray ba, MonadIO m, MonadError String m) => Channel -> ba -> m (ba, Word64)
channelEncrypt ch plain = do
    count <- liftIO $ modifyMVar (chCounterNextOut ch) $ \c -> return (c + 1, c)
    let counterBytes = convert (BL.toStrict (encode count)) :: Bytes
    st0 <- case nonce8 (chNonceFixedOur ch) counterBytes >>= initialize (chKey ch) of
        CryptoFailed err -> throwError $ "failed to init chacha-poly1305 cipher: " <> show err
        CryptoPassed st -> return st

    let (ctext, st1) = encrypt plain st0
        header = convert (BA.drop 7 counterBytes)
        trailer = convert (finalize st1)
    return (BA.concat [ header, ctext, trailer ], count)

-- | Decrypt and authenticate a message produced by the peer's
-- 'channelEncrypt'.
--
-- The message carries only the lowest byte of the sender's counter. The full
-- 64-bit counter is reconstructed as the value with that low byte inside a
-- 256-wide window around the next expected incoming counter, and is used to
-- rebuild the nonce. The tag is compared in constant time. On success the
-- expected incoming counter is raised to one past the reconstructed counter
-- (never lowered), and the plaintext is returned together with that counter.
--
-- NOTE(review): older counters inside the window are not rejected here, so
-- replay protection (if needed) must come from elsewhere — confirm.
channelDecrypt :: (ByteArray ba, MonadIO m, MonadError String m) => Channel -> ba -> m (ba, Word64)
channelDecrypt Channel {..} body = do
    -- 1 counter byte + 16 tag bytes is the minimum (empty plaintext)
    when (BA.length body < 17) $
        throwError "invalid encrypted data length"

    expectedCount <- liftIO $ readMVar chCounterNextIn
    let countByte = body `BA.index` 0
        body' = BA.dropView body 1
        guessedCount = guessCounter expectedCount countByte
        nonce = nonce8 chNonceFixedPeer $ convert $ BL.toStrict $ encode guessedCount
        blen = BA.length body' - 16
        ctext = BA.takeView body' blen
        tag = BA.dropView body' blen
    state <- case initialize chKey =<< nonce of
        CryptoPassed state -> return state
        CryptoFailed err -> throwError $ "failed to init chacha-poly1305 cipher: " <> show err

    let (plain, state') = decrypt (convert ctext) state
    unless (tag `BA.constEq` finalize state') $
        throwError "tag validation failed"

    liftIO $ modifyMVar_ chCounterNextIn $ return . max (guessedCount + 1)
    return (plain, guessedCount)
  where
    -- Reconstruct the full counter from its low byte. Normally the window is
    -- [expected - 128, expected + 127]. While expected < 128 that window would
    -- reach below zero, and the Word64 subtraction would wrap to a huge,
    -- never-used counter. In that case the window is shifted to [0, 255]
    -- instead.
    guessCounter :: Word64 -> Word8 -> Word64
    guessCounter expected low
        | expected >= 128 || delta >= 128 - expected = expected - 128 + delta
        | otherwise = expected + 128 + delta
      where
        -- position of the received low byte within the window, in [0, 255]
        delta = fromIntegral (low - fromIntegral expected + 128 :: Word8) :: Word64