diff options
Diffstat (limited to 'src/Erebos')
| -rw-r--r-- | src/Erebos/Attach.hs | 123 | ||||
| -rw-r--r-- | src/Erebos/Channel.hs | 175 | ||||
| -rw-r--r-- | src/Erebos/Contact.hs | 175 | ||||
| -rw-r--r-- | src/Erebos/Discovery.hs | 222 | ||||
| -rw-r--r-- | src/Erebos/Flow.hs | 73 | ||||
| -rw-r--r-- | src/Erebos/ICE.chs | 205 | ||||
| -rw-r--r-- | src/Erebos/ICE/pjproject.c | 363 | ||||
| -rw-r--r-- | src/Erebos/ICE/pjproject.h | 14 | ||||
| -rw-r--r-- | src/Erebos/Identity.hs | 402 | ||||
| -rw-r--r-- | src/Erebos/Message.hs | 267 | ||||
| -rw-r--r-- | src/Erebos/Network.hs | 860 | ||||
| -rw-r--r-- | src/Erebos/Network.hs-boot | 8 | ||||
| -rw-r--r-- | src/Erebos/Network/Protocol.hs | 753 | ||||
| -rw-r--r-- | src/Erebos/Network/ifaddrs.c | 41 | ||||
| -rw-r--r-- | src/Erebos/Network/ifaddrs.h | 3 | ||||
| -rw-r--r-- | src/Erebos/Pairing.hs | 242 | ||||
| -rw-r--r-- | src/Erebos/PubKey.hs | 156 | ||||
| -rw-r--r-- | src/Erebos/Service.hs | 190 | ||||
| -rw-r--r-- | src/Erebos/Set.hs | 78 | ||||
| -rw-r--r-- | src/Erebos/State.hs | 199 | ||||
| -rw-r--r-- | src/Erebos/Storage.hs | 1007 | ||||
| -rw-r--r-- | src/Erebos/Storage/Internal.hs | 282 | ||||
| -rw-r--r-- | src/Erebos/Storage/Key.hs | 85 | ||||
| -rw-r--r-- | src/Erebos/Storage/List.hs | 154 | ||||
| -rw-r--r-- | src/Erebos/Storage/Merge.hs | 156 | ||||
| -rw-r--r-- | src/Erebos/Sync.hs | 46 | ||||
| -rw-r--r-- | src/Erebos/Util.hs | 37 | 
27 files changed, 6316 insertions, 0 deletions
| diff --git a/src/Erebos/Attach.hs b/src/Erebos/Attach.hs new file mode 100644 index 0000000..bd2f521 --- /dev/null +++ b/src/Erebos/Attach.hs @@ -0,0 +1,123 @@ +module Erebos.Attach ( +    AttachService, +    attachToOwner, +    attachAccept, +    attachReject, +) where + +import Control.Monad +import Control.Monad.Except +import Control.Monad.Reader + +import Data.ByteArray (ScrubbedBytes) +import Data.Maybe +import Data.Proxy +import qualified Data.Text as T + +import Erebos.Identity +import Erebos.Network +import Erebos.Pairing +import Erebos.PubKey +import Erebos.Service +import Erebos.State +import Erebos.Storage +import Erebos.Storage.Key + +type AttachService = PairingService AttachIdentity + +data AttachIdentity = AttachIdentity (Stored (Signed IdentityData)) [ScrubbedBytes] + +instance Storable AttachIdentity where +    store' (AttachIdentity x keys) = storeRec $ do +         storeRef "identity" x +         mapM_ (storeBinary "skey") keys + +    load' = loadRec $ AttachIdentity +        <$> loadRef "identity" +        <*> loadBinaries "skey" + +instance PairingResult AttachIdentity where +    pairingServiceID _ = mkServiceID "4995a5f9-2d4d-48e9-ad3b-0bf1c2a1be7f" + +    type PairingVerifiedResult AttachIdentity = (UnifiedIdentity, [ScrubbedBytes]) + +    pairingVerifyResult (AttachIdentity sdata keys) = do +        curid <- lsIdentity . fromStored <$> svcGetLocal +        secret <- loadKey $ eiddKeyIdentity $ fromSigned curid +        sdata' <- mstore =<< signAdd secret (fromStored sdata) +        return $ do +            guard $ iddKeyIdentity (fromSigned sdata) == +                eiddKeyIdentity (fromSigned curid) +            identity <- validateIdentity sdata' +            guard $ iddPrev (fromSigned $ idData identity) == [eiddStoredBase curid] +            return (identity, keys) + +    pairingFinalizeRequest (identity, keys) = updateLocalHead_ $ \slocal -> do +        let owner = finalOwner identity +        st <- getStorage +        pkeys <- mapM (copyStored st) [ idKeyIdentity owner, idKeyMessage owner ] +        liftIO $ mapM_ storeKey $ catMaybes [ keyFromData sec pub | sec <- keys, pub <- pkeys ] + +        identity' <- mergeIdentity $ updateIdentity [ lsIdentity $ fromStored slocal ] identity +        shared <- makeSharedStateUpdate st (Just owner) (lsShared $ fromStored slocal) +        mstore (fromStored slocal) +            { lsIdentity = idExtData identity' +            , lsShared = [ shared ] +            } + +    pairingFinalizeResponse = do +        owner <- mergeSharedIdentity +        pid <- asks svcPeerIdentity +        secret <- loadKey $ idKeyIdentity owner +        identity <- mstore =<< sign secret =<< mstore (emptyIdentityData $ idKeyIdentity pid) +            { iddPrev = [idData pid], iddOwner = Just (idData owner) } +        skeys <- map keyGetData . catMaybes <$> mapM loadKeyMb [ idKeyIdentity owner, idKeyMessage owner ] +        return $ AttachIdentity identity skeys + +    defaultPairingAttributes _ = PairingAttributes +        { pairingHookRequest = do +            peer <- asks $ svcPeerIdentity +            svcPrint $ "Attach from " ++ T.unpack (displayIdentity peer) ++ " initiated" + +        , pairingHookResponse = \confirm -> do +            peer <- asks $ svcPeerIdentity +            svcPrint $ "Attach to " ++ T.unpack (displayIdentity peer) ++ ": " ++ confirm + +        , pairingHookRequestNonce = \confirm -> do +            peer <- asks $ svcPeerIdentity +            svcPrint $ "Attach from " ++ T.unpack (displayIdentity peer) ++ ": " ++ confirm + +        , pairingHookRequestNonceFailed = do +            peer <- asks $ svcPeerIdentity +            svcPrint $ "Failed attach from " ++ T.unpack (displayIdentity peer) + +        , pairingHookConfirmedResponse = do +            svcPrint $ "Confirmed peer, waiting for updated identity" + +        , pairingHookConfirmedRequest = do +            svcPrint $ "Attachment confirmed by peer" + +        , pairingHookAcceptedResponse = do +            svcPrint $ "Accepted updated identity" + +        , pairingHookAcceptedRequest = do +            svcPrint $ "Accepted new attached device, seding updated identity" + +        , pairingHookVerifyFailed = do +            svcPrint $ "Failed to verify new identity" + +        , pairingHookRejected = do +            svcPrint $ "Attachment rejected by peer" + +        , pairingHookFailed = \_ -> do +            svcPrint $ "Attachement failed" +        } + +attachToOwner :: (MonadIO m, MonadError String m) => Peer -> m () +attachToOwner = pairingRequest @AttachIdentity Proxy + +attachAccept :: (MonadIO m, MonadError String m) => Peer -> m () +attachAccept = pairingAccept @AttachIdentity Proxy + +attachReject :: (MonadIO m, MonadError String m) => Peer -> m () +attachReject = pairingReject @AttachIdentity Proxy diff --git a/src/Erebos/Channel.hs b/src/Erebos/Channel.hs new file mode 100644 index 0000000..5f66637 --- /dev/null +++ b/src/Erebos/Channel.hs @@ -0,0 +1,175 @@ +module Erebos.Channel ( +    Channel, +    ChannelRequest, ChannelRequestData(..), +    ChannelAccept, ChannelAcceptData(..), + +    createChannelRequest, +    acceptChannelRequest, +    acceptedChannel, + +    channelEncrypt, +    channelDecrypt, +) where + +import Control.Concurrent.MVar +import Control.Monad +import Control.Monad.Except +import Control.Monad.IO.Class + +import Crypto.Cipher.ChaChaPoly1305 +import Crypto.Error + +import Data.Binary +import Data.ByteArray (ByteArray, Bytes, ScrubbedBytes, convert) +import Data.ByteArray qualified as BA +import Data.ByteString.Lazy qualified as BL +import Data.List + +import Erebos.Identity +import Erebos.PubKey +import Erebos.Storage + +data Channel = Channel +    { chPeers :: [Stored (Signed IdentityData)] +    , chKey :: ScrubbedBytes +    , chNonceFixedOur :: Bytes +    , chNonceFixedPeer :: Bytes +    , chCounterNextOut :: MVar Word64 +    , chCounterNextIn :: MVar Word64 +    } + +type ChannelRequest = Signed ChannelRequestData + +data ChannelRequestData = ChannelRequest +    { crPeers :: [Stored (Signed IdentityData)] +    , crKey :: Stored PublicKexKey +    } +    deriving (Show) + +type ChannelAccept = Signed ChannelAcceptData + +data ChannelAcceptData = ChannelAccept +    { caRequest :: Stored ChannelRequest +    , caKey :: Stored PublicKexKey +    } + + +instance Storable ChannelRequestData where +    store' cr = storeRec $ do +        mapM_ (storeRef "peer") $ crPeers cr +        storeRef "key" $ crKey cr + +    load' = loadRec $ do +        ChannelRequest +            <$> loadRefs "peer" +            <*> loadRef "key" + +instance Storable ChannelAcceptData where +    store' ca = storeRec $ do +        storeRef "req" $ caRequest ca +        storeRef "key" $ caKey ca + +    load' = loadRec $ do +        ChannelAccept +            <$> loadRef "req" +            <*> loadRef "key" + + +keySize :: Int +keySize = 32 + +createChannelRequest :: (MonadStorage m, MonadIO m, MonadError String m) => UnifiedIdentity -> UnifiedIdentity -> m (Stored ChannelRequest) +createChannelRequest self peer = do +    (_, xpublic) <- liftIO . generateKeys =<< getStorage +    skey <- loadKey $ idKeyMessage self +    mstore =<< sign skey =<< mstore ChannelRequest { crPeers = sort [idData self, idData peer], crKey = xpublic } + +acceptChannelRequest :: (MonadStorage m, MonadIO m, MonadError String m) => UnifiedIdentity -> UnifiedIdentity -> Stored ChannelRequest -> m (Stored ChannelAccept, Channel) +acceptChannelRequest self peer req = do +    case sequence $ map validateIdentity $ crPeers $ fromStored $ signedData $ fromStored req of +        Nothing -> throwError $ "invalid peers in channel request" +        Just peers -> do +            when (not $ any (self `sameIdentity`) peers) $ +                throwError $ "self identity missing in channel request peers" +            when (not $ any (peer `sameIdentity`) peers) $ +                throwError $ "peer identity missing in channel request peers" +    when (idKeyMessage peer `notElem` (map (sigKey . fromStored) $ signedSignature $ fromStored req)) $ +        throwError $ "channel requent not signed by peer" + +    (xsecret, xpublic) <- liftIO . generateKeys =<< getStorage +    skey <- loadKey $ idKeyMessage self +    acc <- mstore =<< sign skey =<< mstore ChannelAccept { caRequest = req, caKey = xpublic } +    liftIO $ do +        let chPeers = crPeers $ fromStored $ signedData $ fromStored req +            chKey = BA.take keySize $ dhSecret xsecret $ +                fromStored $ crKey $ fromStored $ signedData $ fromStored req +            chNonceFixedOur  = BA.pack [ 2, 0, 0, 0 ] +            chNonceFixedPeer = BA.pack [ 1, 0, 0, 0 ] +        chCounterNextOut <- newMVar 0 +        chCounterNextIn <- newMVar 0 + +        return (acc, Channel {..}) + +acceptedChannel :: (MonadIO m, MonadError String m) => UnifiedIdentity -> UnifiedIdentity -> Stored ChannelAccept -> m Channel +acceptedChannel self peer acc = do +    let req = caRequest $ fromStored $ signedData $ fromStored acc +    case sequence $ map validateIdentity $ crPeers $ fromStored $ signedData $ fromStored req of +        Nothing -> throwError $ "invalid peers in channel accept" +        Just peers -> do +            when (not $ any (self `sameIdentity`) peers) $ +                throwError $ "self identity missing in channel accept peers" +            when (not $ any (peer `sameIdentity`) peers) $ +                throwError $ "peer identity missing in channel accept peers" +    when (idKeyMessage peer `notElem` (map (sigKey . fromStored) $ signedSignature $ fromStored acc)) $ +        throwError $ "channel accept not signed by peer" +    when (idKeyMessage self `notElem` (map (sigKey . fromStored) $ signedSignature $ fromStored req)) $ +        throwError $ "original channel request not signed by us" + +    xsecret <- loadKey $ crKey $ fromStored $ signedData $ fromStored req +    let chPeers = crPeers $ fromStored $ signedData $ fromStored req +        chKey = BA.take keySize $ dhSecret xsecret $ +            fromStored $ caKey $ fromStored $ signedData $ fromStored acc +        chNonceFixedOur  = BA.pack [ 1, 0, 0, 0 ] +        chNonceFixedPeer = BA.pack [ 2, 0, 0, 0 ] +    chCounterNextOut <- liftIO $ newMVar 0 +    chCounterNextIn <- liftIO $ newMVar 0 + +    return Channel {..} + + +channelEncrypt :: (ByteArray ba, MonadIO m, MonadError String m) => Channel -> ba -> m (ba, Word64) +channelEncrypt Channel {..} plain = do +    count <- liftIO $ modifyMVar chCounterNextOut $ \c -> return (c + 1, c) +    let cbytes = convert $ BL.toStrict $ encode count +        nonce = nonce8 chNonceFixedOur cbytes +    state <- case initialize chKey =<< nonce of +        CryptoPassed state -> return state +        CryptoFailed err -> throwError $ "failed to init chacha-poly1305 cipher: " <> show err + +    let (ctext, state') = encrypt plain state +        tag = finalize state' +    return (BA.concat [ convert $ BA.drop 7 cbytes, ctext, convert tag ], count) + +channelDecrypt :: (ByteArray ba, MonadIO m, MonadError String m) => Channel -> ba -> m (ba, Word64) +channelDecrypt Channel {..} body = do +    when (BA.length body < 17) $ do +        throwError $ "invalid encrypted data length" + +    expectedCount <- liftIO $ readMVar chCounterNextIn +    let countByte = body `BA.index` 0 +        body' = BA.dropView body 1 +        guessedCount = expectedCount - 128 + fromIntegral (countByte - fromIntegral expectedCount + 128 :: Word8) +        nonce = nonce8 chNonceFixedPeer $ convert $ BL.toStrict $ encode guessedCount +        blen = BA.length body' - 16 +        ctext = BA.takeView body' blen +        tag = BA.dropView body' blen +    state <- case initialize chKey =<< nonce of +        CryptoPassed state -> return state +        CryptoFailed err -> throwError $ "failed to init chacha-poly1305 cipher: " <> show err + +    let (plain, state') = decrypt (convert ctext) state +    when (not $ tag `BA.constEq` finalize state') $ do +        throwError $ "tag validation falied" + +    liftIO $ modifyMVar_ chCounterNextIn $ return . max (guessedCount + 1) +    return (plain, guessedCount) diff --git a/src/Erebos/Contact.hs b/src/Erebos/Contact.hs new file mode 100644 index 0000000..d90aa50 --- /dev/null +++ b/src/Erebos/Contact.hs @@ -0,0 +1,175 @@ +module Erebos.Contact ( +    Contact, +    contactIdentity, +    contactCustomName, +    contactName, + +    contactSetName, + +    ContactService, +    contactRequest, +    contactAccept, +    contactReject, +) where + +import Control.Monad +import Control.Monad.Except +import Control.Monad.Reader + +import Data.Maybe +import Data.Proxy +import Data.Text (Text) +import qualified Data.Text as T + +import Erebos.Identity +import Erebos.Network +import Erebos.Pairing +import Erebos.PubKey +import Erebos.Service +import Erebos.Set +import Erebos.State +import Erebos.Storage +import Erebos.Storage.Merge + +data Contact = Contact +    { contactData :: [Stored ContactData] +    , contactIdentity_ :: Maybe ComposedIdentity +    , contactCustomName_ :: Maybe Text +    } + +data ContactData = ContactData +    { cdPrev :: [Stored ContactData] +    , cdIdentity :: [Stored (Signed ExtendedIdentityData)] +    , cdName :: Maybe Text +    } + +instance Storable ContactData where +    store' x = storeRec $ do +        mapM_ (storeRef "PREV") $ cdPrev x +        mapM_ (storeRef "identity") $ cdIdentity x +        storeMbText "name" $ cdName x + +    load' = loadRec $ ContactData +        <$> loadRefs "PREV" +        <*> loadRefs "identity" +        <*> loadMbText "name" + +instance Mergeable Contact where +    type Component Contact = ContactData + +    mergeSorted cdata = Contact +        { contactData = cdata +        , contactIdentity_ = validateExtendedIdentityF $ concat $ findProperty ((\case [] -> Nothing; xs -> Just xs) . cdIdentity) cdata +        , contactCustomName_ = findPropertyFirst cdName cdata +        } + +    toComponents = contactData + +instance SharedType (Set Contact) where +    sharedTypeID _ = mkSharedTypeID "34fbb61e-6022-405f-b1b3-a5a1abecd25e" + +contactIdentity :: Contact -> Maybe ComposedIdentity +contactIdentity = contactIdentity_ + +contactCustomName :: Contact -> Maybe Text +contactCustomName = contactCustomName_ + +contactName :: Contact -> Text +contactName c = fromJust $ msum +    [ contactCustomName c +    , idName =<< contactIdentity c +    , Just T.empty +    ] + +contactSetName :: MonadHead LocalState m => Contact -> Text -> Set Contact -> m (Set Contact) +contactSetName contact name set = do +    st <- getStorage +    cdata <- wrappedStore st ContactData +        { cdPrev = toComponents contact +        , cdIdentity = [] +        , cdName = Just name +        } +    storeSetAdd st (mergeSorted @Contact [cdata]) set + + +type ContactService = PairingService ContactAccepted + +data ContactAccepted = ContactAccepted + +instance Storable ContactAccepted where +    store' ContactAccepted = storeRec $ do +        storeText "accept" "" +    load' = loadRec $ do +        (_ :: T.Text) <- loadText "accept" +        return ContactAccepted + +instance PairingResult ContactAccepted where +    pairingServiceID _ = mkServiceID "d9c37368-0da1-4280-93e9-d9bd9a198084" + +    pairingVerifyResult = return . Just + +    pairingFinalizeRequest ContactAccepted = do +        pid <- asks svcPeerIdentity +        finalizeContact pid + +    pairingFinalizeResponse = do +        pid <- asks svcPeerIdentity +        finalizeContact pid +        return ContactAccepted + +    defaultPairingAttributes _ = PairingAttributes +        { pairingHookRequest = do +            peer <- asks $ svcPeerIdentity +            svcPrint $ "Contact pairing from " ++ T.unpack (displayIdentity peer) ++ " initiated" + +        , pairingHookResponse = \confirm -> do +            peer <- asks $ svcPeerIdentity +            svcPrint $ "Confirm contact " ++ T.unpack (displayIdentity $ finalOwner peer) ++ ": " ++ confirm + +        , pairingHookRequestNonce = \confirm -> do +            peer <- asks $ svcPeerIdentity +            svcPrint $ "Contact request from " ++ T.unpack (displayIdentity $ finalOwner peer) ++ ": " ++ confirm + +        , pairingHookRequestNonceFailed = do +            peer <- asks $ svcPeerIdentity +            svcPrint $ "Failed contact request from " ++ T.unpack (displayIdentity peer) + +        , pairingHookConfirmedResponse = do +            svcPrint $ "Contact accepted, waiting for peer confirmation" + +        , pairingHookConfirmedRequest = do +            svcPrint $ "Contact confirmed by peer" + +        , pairingHookAcceptedResponse = do +            svcPrint $ "Contact accepted" + +        , pairingHookAcceptedRequest = do +            svcPrint $ "Contact accepted" + +        , pairingHookVerifyFailed = return () + +        , pairingHookRejected = do +            svcPrint $ "Contact rejected by peer" + +        , pairingHookFailed = \_ -> do +            svcPrint $ "Contact failed" +        } + +contactRequest :: (MonadIO m, MonadError String m) => Peer -> m () +contactRequest = pairingRequest @ContactAccepted Proxy + +contactAccept :: (MonadIO m, MonadError String m) => Peer -> m () +contactAccept = pairingAccept @ContactAccepted Proxy + +contactReject :: (MonadIO m, MonadError String m) => Peer -> m () +contactReject = pairingReject @ContactAccepted Proxy + +finalizeContact :: MonadHead LocalState m => UnifiedIdentity -> m () +finalizeContact identity = updateLocalHead_ $ updateSharedState_ $ \contacts -> do +    st <- getStorage +    cdata <- wrappedStore st ContactData +        { cdPrev = [] +        , cdIdentity = idExtDataF $ finalOwner identity +        , cdName = Nothing +        } +    storeSetAdd st (mergeSorted @Contact [cdata]) contacts diff --git a/src/Erebos/Discovery.hs b/src/Erebos/Discovery.hs new file mode 100644 index 0000000..86bdbe7 --- /dev/null +++ b/src/Erebos/Discovery.hs @@ -0,0 +1,222 @@ +module Erebos.Discovery ( +    DiscoveryService(..), +    DiscoveryConnection(..) +) where + +import Control.Concurrent +import Control.Monad.Except +import Control.Monad.Reader + +import Data.Map.Strict (Map) +import qualified Data.Map.Strict as M +import Data.Maybe +import Data.Text (Text) +import qualified Data.Text as T + +import Network.Socket + +import Erebos.ICE +import Erebos.Identity +import Erebos.Network +import Erebos.Service +import Erebos.Storage + + +keepaliveSeconds :: Int +keepaliveSeconds = 20 + + +data DiscoveryService = DiscoverySelf Text Int +                      | DiscoveryAcknowledged Text +                      | DiscoverySearch Ref +                      | DiscoveryResult Ref (Maybe Text) +                      | DiscoveryConnectionRequest DiscoveryConnection +                      | DiscoveryConnectionResponse DiscoveryConnection + +data DiscoveryConnection = DiscoveryConnection +    { dconnSource :: Ref +    , dconnTarget :: Ref +    , dconnAddress :: Maybe Text +    , dconnIceSession :: Maybe IceRemoteInfo +    } + +emptyConnection :: Ref -> Ref -> DiscoveryConnection +emptyConnection source target = DiscoveryConnection source target Nothing Nothing + +instance Storable DiscoveryService where +    store' x = storeRec $ do +        case x of +            DiscoverySelf addr priority -> do +                storeText "self" addr +                storeInt "priority" priority +            DiscoveryAcknowledged addr -> do +                storeText "ack" addr +            DiscoverySearch ref -> storeRawRef "search" ref +            DiscoveryResult ref addr -> do +                storeRawRef "result" ref +                storeMbText "address" addr +            DiscoveryConnectionRequest conn -> storeConnection "request" conn +            DiscoveryConnectionResponse conn -> storeConnection "response" conn + +        where storeConnection ctype conn = do +                  storeText "connection" $ ctype +                  storeRawRef "source" $ dconnSource conn +                  storeRawRef "target" $ dconnTarget conn +                  storeMbText "address" $ dconnAddress conn +                  storeMbRef "ice-session" $ dconnIceSession conn + +    load' = loadRec $ msum +            [ DiscoverySelf +                <$> loadText "self" +                <*> loadInt "priority" +            , DiscoveryAcknowledged +                <$> loadText "ack" +            , DiscoverySearch <$> loadRawRef "search" +            , DiscoveryResult +                <$> loadRawRef "result" +                <*> loadMbText "address" +            , loadConnection "request" DiscoveryConnectionRequest +            , loadConnection "response" DiscoveryConnectionResponse +            ] +        where loadConnection ctype ctor = do +                  ctype' <- loadText "connection" +                  guard $ ctype == ctype' +                  return . ctor =<< DiscoveryConnection +                      <$> loadRawRef "source" +                      <*> loadRawRef "target" +                      <*> loadMbText "address" +                      <*> loadMbRef "ice-session" + +data DiscoveryPeer = DiscoveryPeer +    { dpPriority :: Int +    , dpPeer :: Maybe Peer +    , dpAddress :: Maybe Text +    , dpIceSession :: Maybe IceSession +    } + +instance Service DiscoveryService where +    serviceID _ = mkServiceID "dd59c89c-69cc-4703-b75b-4ddcd4b3c23b" + +    type ServiceGlobalState DiscoveryService = Map RefDigest DiscoveryPeer +    emptyServiceGlobalState _ = M.empty + +    serviceHandler msg = case fromStored msg of +        DiscoverySelf addr priority -> do +            pid <- asks svcPeerIdentity +            peer <- asks svcPeer +            let insertHelper new old | dpPriority new > dpPriority old = new +                                     | otherwise                       = old +            mbaddr <- case words (T.unpack addr) of +                [ipaddr, port] | DatagramAddress paddr <- peerAddress peer -> do +                    saddr <- liftIO $ head <$> getAddrInfo (Just $ defaultHints { addrSocketType = Datagram }) (Just ipaddr) (Just port) +                    return $ if paddr == addrAddress saddr +                                then Just addr +                                else Nothing +                _ -> return Nothing +            forM_ (idDataF =<< unfoldOwners pid) $ \s -> +                svcModifyGlobal $ M.insertWith insertHelper (refDigest $ storedRef s) $ +                    DiscoveryPeer priority (Just peer) mbaddr Nothing +            replyPacket $ DiscoveryAcknowledged $ fromMaybe (T.pack "ICE") mbaddr + +        DiscoveryAcknowledged addr -> do +            when (addr == T.pack "ICE") $ do +                -- keep-alive packet from behind NAT +                peer <- asks svcPeer +                liftIO $ void $ forkIO $ do +                    threadDelay (keepaliveSeconds * 1000 * 1000) +                    res <- runExceptT $ sendToPeer peer $ DiscoverySelf addr 0 +                    case res of +                        Right _ -> return () +                        Left err -> putStrLn $ "Discovery: failed to send keep-alive: " ++ err + +        DiscoverySearch ref -> do +            addr <- M.lookup (refDigest ref) <$> svcGetGlobal +            replyPacket $ DiscoveryResult ref $ fromMaybe (T.pack "ICE") . dpAddress <$> addr + +        DiscoveryResult ref Nothing -> do +            svcPrint $ "Discovery: " ++ show (refDigest ref) ++ " not found" + +        DiscoveryResult ref (Just addr) -> do +            -- TODO: check if we really requested that +            server <- asks svcServer +            if addr == T.pack "ICE" +               then do +                    self <- svcSelf +                    peer <- asks svcPeer +                    ice <- liftIO $ iceCreate PjIceSessRoleControlling $ \ice -> do +                        rinfo <- iceRemoteInfo ice +                        res <- runExceptT $ sendToPeer peer $ +                            DiscoveryConnectionRequest (emptyConnection (storedRef $ idData self) ref) { dconnIceSession = Just rinfo } +                        case res of +                            Right _ -> return () +                            Left err -> putStrLn $ "Discovery: failed to send connection request: " ++ err + +                    svcModifyGlobal $ M.insert (refDigest ref) $ +                        DiscoveryPeer 0 Nothing Nothing (Just ice) +               else do +                    case words (T.unpack addr) of +                        [ipaddr, port] -> do +                            saddr <- liftIO $ head <$> +                                getAddrInfo (Just $ defaultHints { addrSocketType = Datagram }) (Just ipaddr) (Just port) +                            peer <- liftIO $ serverPeer server (addrAddress saddr) +                            svcModifyGlobal $ M.insert (refDigest ref) $ +                                DiscoveryPeer 0 (Just peer) Nothing Nothing + +                        _ -> svcPrint $ "Discovery: invalid address in result: " ++ T.unpack addr + +        DiscoveryConnectionRequest conn -> do +            self <- svcSelf +            let rconn = emptyConnection (dconnSource conn) (dconnTarget conn) +            if refDigest (dconnTarget conn) `elem` (map (refDigest . storedRef) $ idDataF =<< unfoldOwners self) +               then do +                    -- request for us, create ICE sesssion +                    server <- asks svcServer +                    peer <- asks svcPeer +                    liftIO $ void $ iceCreate PjIceSessRoleControlled $ \ice -> do +                        rinfo <- iceRemoteInfo ice +                        res <- runExceptT $ sendToPeer peer $ DiscoveryConnectionResponse rconn { dconnIceSession = Just rinfo } +                        case res of +                            Right _ -> do +                                case dconnIceSession conn of +                                    Just prinfo -> iceConnect ice prinfo $ void $ serverPeerIce server ice +                                    Nothing -> putStrLn $ "Discovery: connection request without ICE remote info" +                            Left err -> putStrLn $ "Discovery: failed to send connection response: " ++ err + +               else do +                    -- request to some of our peers, relay +                    mbdp <- M.lookup (refDigest $ dconnTarget conn) <$> svcGetGlobal +                    case mbdp of +                        Nothing -> replyPacket $ DiscoveryConnectionResponse rconn +                        Just dp | Just addr <- dpAddress dp -> do +                                    replyPacket $ DiscoveryConnectionResponse rconn { dconnAddress = Just addr } +                                | Just dpeer <- dpPeer dp -> do +                                    sendToPeer dpeer $ DiscoveryConnectionRequest conn +                                | otherwise -> svcPrint $ "Discovery: failed to relay connection request" + +        DiscoveryConnectionResponse conn -> do +            self <- svcSelf +            dpeers <- svcGetGlobal +            if refDigest (dconnSource conn) `elem` (map (refDigest . storedRef) $ idDataF =<< unfoldOwners self) +               then do +                    -- response to our request, try to connect to the peer +                    server <- asks svcServer +                    if  | Just addr <- dconnAddress conn +                        , [ipaddr, port] <- words (T.unpack addr) -> do +                            saddr <- liftIO $ head <$> +                                getAddrInfo (Just $ defaultHints { addrSocketType = Datagram }) (Just ipaddr) (Just port) +                            peer <- liftIO $ serverPeer server (addrAddress saddr) +                            svcModifyGlobal $ M.insert (refDigest $ dconnTarget conn) $ +                                DiscoveryPeer 0 (Just peer) Nothing Nothing + +                        | Just dp <- M.lookup (refDigest $ dconnTarget conn) dpeers +                        , Just ice <- dpIceSession dp +                        , Just rinfo <- dconnIceSession conn -> do +                            liftIO $ iceConnect ice rinfo $ void $ serverPeerIce server ice + +                        | otherwise -> svcPrint $ "Discovery: connection request failed" +               else do +                    -- response to relayed request +                    case M.lookup (refDigest $ dconnSource conn) dpeers of +                        Just dp | Just dpeer <- dpPeer dp -> do +                            sendToPeer dpeer $ DiscoveryConnectionResponse conn +                        _ -> svcPrint $ "Discovery: failed to relay connection response" diff --git a/src/Erebos/Flow.hs b/src/Erebos/Flow.hs new file mode 100644 index 0000000..ba2607a --- /dev/null +++ b/src/Erebos/Flow.hs @@ -0,0 +1,73 @@ +module Erebos.Flow ( +    Flow, SymFlow, +    newFlow, newFlowIO, +    readFlow, tryReadFlow, canReadFlow, +    writeFlow, writeFlowBulk, tryWriteFlow, canWriteFlow, +    readFlowIO, writeFlowIO, + +    mapFlow, +) where + +import Control.Concurrent.STM + + +data Flow r w = Flow (TMVar [r]) (TMVar [w]) +              | forall r' w'. MappedFlow (r' -> r) (w -> w') (Flow r' w') + +type SymFlow a = Flow a a + +newFlow :: STM (Flow a b, Flow b a) +newFlow = do +    x <- newEmptyTMVar +    y <- newEmptyTMVar +    return (Flow x y, Flow y x) + +newFlowIO :: IO (Flow a b, Flow b a) +newFlowIO = atomically newFlow + +readFlow :: Flow r w -> STM r +readFlow (Flow rvar _) = takeTMVar rvar >>= \case +    (x:[]) -> return x +    (x:xs) -> putTMVar rvar xs >> return x +    [] -> error "Flow: empty list" +readFlow (MappedFlow f _ up) = f <$> readFlow up + +tryReadFlow :: Flow r w -> STM (Maybe r) +tryReadFlow (Flow rvar _) = tryTakeTMVar rvar >>= \case +    Just (x:[]) -> return (Just x) +    Just (x:xs) -> putTMVar rvar xs >> return (Just x) +    Just [] -> error "Flow: empty list" +    Nothing -> return Nothing +tryReadFlow (MappedFlow f _ up) = fmap f <$> tryReadFlow up + +canReadFlow :: Flow r w -> STM Bool +canReadFlow (Flow rvar _) = not <$> isEmptyTMVar rvar +canReadFlow (MappedFlow _ _ up) = canReadFlow up + +writeFlow :: Flow r w -> w -> STM () +writeFlow (Flow _ wvar) = putTMVar wvar . (:[]) +writeFlow (MappedFlow _ f up) = writeFlow up . f + +writeFlowBulk :: Flow r w -> [w] -> STM () +writeFlowBulk _ [] = return () +writeFlowBulk (Flow _ wvar) xs = putTMVar wvar xs +writeFlowBulk (MappedFlow _ f up) xs = writeFlowBulk up $ map f xs + +tryWriteFlow :: Flow r w -> w -> STM Bool +tryWriteFlow (Flow _ wvar) = tryPutTMVar wvar . (:[]) +tryWriteFlow (MappedFlow _ f up) = tryWriteFlow up . f + +canWriteFlow :: Flow r w -> STM Bool +canWriteFlow (Flow _ wvar) = isEmptyTMVar wvar +canWriteFlow (MappedFlow _ _ up) = canWriteFlow up + +readFlowIO :: Flow r w -> IO r +readFlowIO path = atomically $ readFlow path + +writeFlowIO :: Flow r w -> w -> IO () +writeFlowIO path = atomically . writeFlow path + + +mapFlow :: (r -> r') -> (w' -> w) -> Flow r w -> Flow r' w' +mapFlow rf wf (MappedFlow rf' wf' up) = MappedFlow (rf . rf') (wf' . wf) up +mapFlow rf wf up = MappedFlow rf wf up diff --git a/src/Erebos/ICE.chs b/src/Erebos/ICE.chs new file mode 100644 index 0000000..096ee0d --- /dev/null +++ b/src/Erebos/ICE.chs @@ -0,0 +1,205 @@ +{-# LANGUAGE ForeignFunctionInterface #-} +{-# LANGUAGE RecursiveDo #-} + +module Erebos.ICE ( +    IceSession, +    IceSessionRole(..), +    IceRemoteInfo, + +    iceCreate, +    iceDestroy, +    iceRemoteInfo, +    iceShow, +    iceConnect, +    iceSend, + +    iceSetChan, +) where + +import Control.Arrow +import Control.Concurrent.MVar +import Control.Monad +import Control.Monad.Except +import Control.Monad.Identity + +import Data.ByteString (ByteString, packCStringLen, useAsCString) +import qualified Data.ByteString.Lazy.Char8 as BLC +import Data.ByteString.Unsafe +import Data.Function +import Data.Text (Text) +import qualified Data.Text as T +import qualified Data.Text.Encoding as T +import qualified Data.Text.Read as T +import Data.Void + +import Foreign.C.String +import Foreign.C.Types +import Foreign.Marshal.Alloc +import Foreign.Marshal.Array +import Foreign.Ptr +import Foreign.StablePtr + +import Erebos.Flow +import Erebos.Storage + +#include "pjproject.h" + +data IceSession = IceSession +    { isStrans :: PjIceStrans +    , isChan :: MVar (Either [ByteString] (Flow Void ByteString)) +    } + +instance Eq IceSession where +    (==) = (==) `on` isStrans + +instance Ord IceSession where +    compare = compare `on` isStrans + +instance Show IceSession where +    show _ = "<ICE>" + + +data IceRemoteInfo = IceRemoteInfo +    { iriUsernameFrament :: Text +    , iriPassword :: Text +    , iriDefaultCandidate :: Text +    , iriCandidates :: [Text] +    } + +data IceCandidate = IceCandidate +    { icandFoundation :: Text +    , icandPriority :: Int +    , icandAddr :: Text +    , icandPort :: Int +    , icandType :: Text +    } + +instance Storable IceRemoteInfo where +    store' x = storeRec $ do +        storeText "ice-ufrag" $ iriUsernameFrament x +        storeText "ice-pass" $ iriPassword x +        storeText "ice-default" $ iriDefaultCandidate x +        mapM_ (storeText "ice-candidate") $ iriCandidates x + +    load' = loadRec $ IceRemoteInfo +        <$> loadText "ice-ufrag" +        <*> loadText "ice-pass" +        <*> loadText "ice-default" +        <*> loadTexts "ice-candidate" + +instance StorableText IceCandidate where +    toText x = T.concat $ +        [ icandFoundation x +        , T.singleton ' ' +        , T.pack $ show $ icandPriority x +        , T.singleton ' ' +        , icandAddr x +        , T.singleton ' ' +        , T.pack $ show $ icandPort x +        , T.singleton ' ' +        , icandType x +        ] + +    fromText t = case T.words t of +        [found, tprio, addr, tport, ctype] +            | Right (prio, _) <- T.decimal tprio +            , Right (port, _) <- T.decimal tport +            -> return $ IceCandidate +                { icandFoundation = found +                , icandPriority = prio +                , icandAddr = addr +                , icandPort = port +                , icandType = ctype +                } +        _ -> throwError "failed to parse candidate" + + +{#enum pj_ice_sess_role as IceSessionRole {underscoreToCase} deriving (Show, Eq) #} + +{#pointer *pj_ice_strans as ^ #} + +iceCreate :: IceSessionRole -> (IceSession -> IO ()) -> IO IceSession +iceCreate role cb = do +    rec sptr <- newStablePtr sess +        cbptr <- newStablePtr $ cb sess +        sess <- IceSession +            <$> {#call ice_create #} (fromIntegral $ fromEnum role) (castStablePtrToPtr sptr) (castStablePtrToPtr cbptr) +            <*> (newMVar $ Left []) +    return $ sess + +{#fun ice_destroy as ^ { isStrans `IceSession' } -> `()' #} + +iceRemoteInfo :: IceSession -> IO IceRemoteInfo +iceRemoteInfo sess = do +    let maxlen = 128 +        maxcand = 29 + +    allocaBytes maxlen $ \ufrag -> +        allocaBytes maxlen $ \pass -> +        allocaBytes maxlen $ \def -> +        allocaBytes (maxcand*maxlen) $ \bytes -> +        allocaArray maxcand $ \carr -> do +        let cptrs = take maxcand $ iterate (`plusPtr` maxlen) bytes +        pokeArray carr $ take maxcand cptrs + +        ncand <- {#call ice_encode_session #} (isStrans sess) ufrag pass def carr (fromIntegral maxlen) (fromIntegral maxcand) +        if ncand < 0 then fail "failed to generate ICE remote info" +                     else IceRemoteInfo +                              <$> (T.pack <$> peekCString ufrag) +                              <*> (T.pack <$> peekCString pass) +                              <*> (T.pack <$> peekCString def) +                              <*> (mapM (return . T.pack <=< peekCString) $ take (fromIntegral ncand) cptrs) + +iceShow :: IceSession -> IO String +iceShow sess = do +    st <- memoryStorage +    return . drop 1 . dropWhile (/='\n') . BLC.unpack . runIdentity =<< +        ioLoadBytes =<< store st =<< iceRemoteInfo sess + +iceConnect :: IceSession -> IceRemoteInfo -> (IO ()) -> IO () +iceConnect sess remote cb = do +    cbptr <- newStablePtr $ cb +    ice_connect sess cbptr +        (iriUsernameFrament remote) +        (iriPassword remote) +        (iriDefaultCandidate remote) +        (iriCandidates remote) + +{#fun ice_connect { isStrans `IceSession', castStablePtrToPtr `StablePtr (IO ())', +    withText* `Text',  withText* `Text', withText* `Text', withTextArray* `[Text]'& } -> `()' #} + +withText :: Text -> (Ptr CChar -> IO a) -> IO a +withText t f = useAsCString (T.encodeUtf8 t) f + +withTextArray :: Num n => [Text] -> ((Ptr (Ptr CChar), n) -> IO ()) -> IO () +withTextArray tsAll f = helper tsAll [] +    where helper (t:ts) bs = withText t $ \b -> helper ts (b:bs) +          helper [] bs = allocaArray (length bs) $ \ptr -> do +              pokeArray ptr $ reverse bs +              f (ptr, fromIntegral $ length bs) + +withByteStringLen :: Num n => ByteString -> ((Ptr CChar, n) -> IO a) -> IO a +withByteStringLen t f = unsafeUseAsCStringLen t (f . (id *** fromIntegral)) + +{#fun ice_send as ^ { isStrans `IceSession', withByteStringLen* `ByteString'& } -> `()' #} + +foreign export ccall ice_call_cb :: StablePtr (IO ()) -> IO () +ice_call_cb :: StablePtr (IO ()) -> IO () +ice_call_cb = join . deRefStablePtr + +iceSetChan :: IceSession -> Flow Void ByteString -> IO () +iceSetChan sess chan = do +    modifyMVar_ (isChan sess) $ \orig -> do +        case orig of +             Left buf -> mapM_ (writeFlowIO chan) $ reverse buf +             Right _ -> return () +        return $ Right chan + +foreign export ccall ice_rx_data :: StablePtr IceSession -> Ptr CChar -> Int -> IO () +ice_rx_data :: StablePtr IceSession -> Ptr CChar -> Int -> IO () +ice_rx_data sptr buf len = do +    sess <- deRefStablePtr sptr +    bs <- packCStringLen (buf, len) +    modifyMVar_ (isChan sess) $ \case +            mc@(Right chan) -> writeFlowIO chan bs >> return mc +            Left bss -> return $ Left (bs:bss) diff --git a/src/Erebos/ICE/pjproject.c b/src/Erebos/ICE/pjproject.c new file mode 100644 index 0000000..bb06b1f --- /dev/null +++ b/src/Erebos/ICE/pjproject.c @@ -0,0 +1,363 @@ +#include "pjproject.h" +#include "Erebos/ICE_stub.h" + +#include <stdio.h> +#include <stdlib.h> +#include <stdbool.h> +#include <pthread.h> +#include <pjlib.h> +#include <pjlib-util.h> + +static struct +{ +	pj_caching_pool cp; +	pj_pool_t * pool; +	pj_ice_strans_cfg cfg; +	pj_sockaddr def_addr; +} ice; + +struct user_data +{ +	pj_ice_sess_role role; +	HsStablePtr sptr; +	HsStablePtr cb_init; +	HsStablePtr cb_connect; +}; + +static void ice_perror(const char * msg, pj_status_t status) +{ +	char err[PJ_ERR_MSG_SIZE]; +	pj_strerror(status, err, sizeof(err)); +	fprintf(stderr, "ICE: %s: %s\n", msg, err); +} + +static int ice_worker_thread(void * unused) +{ +	PJ_UNUSED_ARG(unused); + +	while (true) { +		pj_time_val max_timeout = { 0, 0 }; +		pj_time_val timeout = { 0, 0 }; + +		max_timeout.msec = 500; + +		pj_timer_heap_poll(ice.cfg.stun_cfg.timer_heap, &timeout); + +		pj_assert(timeout.sec >= 0 && timeout.msec >= 0); +		if (timeout.msec >= 1000) +			timeout.msec = 999; + +		if (PJ_TIME_VAL_GT(timeout, max_timeout)) +			timeout = max_timeout; + +		int c = pj_ioqueue_poll(ice.cfg.stun_cfg.ioqueue, &timeout); +		if (c < 0) +			pj_thread_sleep(PJ_TIME_VAL_MSEC(timeout)); +	} + +	return 0; +} + +static void cb_on_rx_data(pj_ice_strans * strans, unsigned comp_id, +		void * pkt, pj_size_t size, +		const pj_sockaddr_t * src_addr, unsigned src_addr_len) +{ +	struct user_data * udata = pj_ice_strans_get_user_data(strans); +	ice_rx_data(udata->sptr, pkt, size); +} + +static void cb_on_ice_complete(pj_ice_strans * strans, +		pj_ice_strans_op op, pj_status_t status) +{ +	if (status != PJ_SUCCESS) { +		ice_perror("cb_on_ice_complete", status); +		ice_destroy(strans); +		return; +	} + +	struct user_data * udata = pj_ice_strans_get_user_data(strans); +	if (op == PJ_ICE_STRANS_OP_INIT) { +		pj_status_t istatus = pj_ice_strans_init_ice(strans, udata->role, NULL, NULL); +		if (istatus != PJ_SUCCESS) +			ice_perror("error creating session", istatus); + +		if (udata->cb_init) { +			ice_call_cb(udata->cb_init); +			hs_free_stable_ptr(udata->cb_init); +			udata->cb_init = NULL; +		} +	} + +	if (op == PJ_ICE_STRANS_OP_NEGOTIATION) { +		if (udata->cb_connect) { +			ice_call_cb(udata->cb_connect); +			hs_free_stable_ptr(udata->cb_connect); +			udata->cb_connect = NULL; +		} +	} +} + +static void ice_init(void) +{ +	static bool done = false; +	static pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER; +	pthread_mutex_lock(&mutex); + +	if (done) { +		pthread_mutex_unlock(&mutex); +		goto exit; +	} + +	pj_log_set_level(1); + +	if (pj_init() != PJ_SUCCESS) { +		fprintf(stderr, "pj_init failed\n"); +		goto exit; +	} +	if (pjlib_util_init() != PJ_SUCCESS) { +		fprintf(stderr, "pjlib_util_init failed\n"); +		goto exit; +	} +	if (pjnath_init() != PJ_SUCCESS) { +		fprintf(stderr, "pjnath_init failed\n"); +		goto exit; +	} + +	pj_caching_pool_init(&ice.cp, NULL, 0); + +	pj_ice_strans_cfg_default(&ice.cfg); +	ice.cfg.stun_cfg.pf = &ice.cp.factory; + +	ice.pool = pj_pool_create(&ice.cp.factory, "ice", 512, 512, NULL); + +	if (pj_timer_heap_create(ice.pool, 100, +				&ice.cfg.stun_cfg.timer_heap) != PJ_SUCCESS) { +		fprintf(stderr, "pj_timer_heap_create failed\n"); +		goto exit; +	} + +	if (pj_ioqueue_create(ice.pool, 16, &ice.cfg.stun_cfg.ioqueue) != PJ_SUCCESS) { +		fprintf(stderr, "pj_ioqueue_create failed\n"); +		goto exit; +	} + +	pj_thread_t * thread; +	if (pj_thread_create(ice.pool, "ice", &ice_worker_thread, +				NULL, 0, 0, &thread) != PJ_SUCCESS) { +		fprintf(stderr, "pj_thread_create failed\n"); +		goto exit; +	} + +	ice.cfg.af = pj_AF_INET(); +	ice.cfg.opt.aggressive = PJ_TRUE; + +	ice.cfg.stun.server.ptr = "discovery1.erebosprotocol.net"; +	ice.cfg.stun.server.slen = strlen(ice.cfg.stun.server.ptr); +	ice.cfg.stun.port = 29670; + +	ice.cfg.turn.server = ice.cfg.stun.server; +	ice.cfg.turn.port = ice.cfg.stun.port; +	ice.cfg.turn.auth_cred.type = PJ_STUN_AUTH_CRED_STATIC; +	ice.cfg.turn.auth_cred.data.static_cred.data_type = PJ_STUN_PASSWD_PLAIN; +	ice.cfg.turn.conn_type = PJ_TURN_TP_UDP; + +exit: +	done = true; +	pthread_mutex_unlock(&mutex); +} + +pj_ice_strans * ice_create(pj_ice_sess_role role, HsStablePtr sptr, HsStablePtr cb) +{ +	ice_init(); + +	pj_ice_strans * res; + +	struct user_data * udata = malloc(sizeof(struct user_data)); +	udata->role = role; +	udata->sptr = sptr; +	udata->cb_init = cb; + +	pj_ice_strans_cb icecb = { +		.on_rx_data = cb_on_rx_data, +		.on_ice_complete = cb_on_ice_complete, +	}; + +	pj_status_t status = pj_ice_strans_create(NULL, &ice.cfg, 1, +			udata, &icecb, &res); + +	if (status != PJ_SUCCESS) +		ice_perror("error creating ice", status); + +	return res; +} + +void ice_destroy(pj_ice_strans * strans) +{ +	struct user_data * udata = pj_ice_strans_get_user_data(strans); +	if (udata->sptr) +		hs_free_stable_ptr(udata->sptr); +	if (udata->cb_init) +		hs_free_stable_ptr(udata->cb_init); +	if (udata->cb_connect) +		hs_free_stable_ptr(udata->cb_connect); +	free(udata); + +	pj_ice_strans_stop_ice(strans); +	pj_ice_strans_destroy(strans); +} + +ssize_t ice_encode_session(pj_ice_strans * strans, char * ufrag, char * pass, +		char * def, char * candidates[], size_t maxlen, size_t maxcand) +{ +	int n; +	pj_str_t local_ufrag, local_pwd; +	pj_status_t status; + +	pj_ice_strans_get_ufrag_pwd(strans, &local_ufrag, &local_pwd, NULL, NULL); + +	n = snprintf(ufrag, maxlen, "%.*s", (int) local_ufrag.slen, local_ufrag.ptr); +	if (n < 0 || n == maxlen) +		return -PJ_ETOOSMALL; + +	n = snprintf(pass, maxlen, "%.*s", (int) local_pwd.slen, local_pwd.ptr); +	if (n < 0 || n == maxlen) +		return -PJ_ETOOSMALL; + +	pj_ice_sess_cand cand[PJ_ICE_ST_MAX_CAND]; +	char ipaddr[PJ_INET6_ADDRSTRLEN]; + +	status = pj_ice_strans_get_def_cand(strans, 1, &cand[0]); +	if (status != PJ_SUCCESS) +		return -status; + +	n = snprintf(def, maxlen, "%s %d", +			pj_sockaddr_print(&cand[0].addr, ipaddr, sizeof(ipaddr), 0), +			(int) pj_sockaddr_get_port(&cand[0].addr)); +	if (n < 0 || n == maxlen) +		return -PJ_ETOOSMALL; + +	unsigned cand_cnt = PJ_ARRAY_SIZE(cand); +	status = pj_ice_strans_enum_cands(strans, 1, &cand_cnt, cand); +	if (status != PJ_SUCCESS) +		return -status; + +	for (unsigned i = 0; i < cand_cnt && i < maxcand; i++) { +		char ipaddr[PJ_INET6_ADDRSTRLEN]; +		n = snprintf(candidates[i], maxlen, +				"%.*s %u %s %u %s", +				(int) cand[i].foundation.slen, cand[i].foundation.ptr, +				cand[i].prio, +				pj_sockaddr_print(&cand[i].addr, ipaddr, sizeof(ipaddr), 0), +				(unsigned) pj_sockaddr_get_port(&cand[i].addr), +				pj_ice_get_cand_type_name(cand[i].type)); + +		if (n < 0 || n == maxlen) +			return -PJ_ETOOSMALL; +	} + +	return cand_cnt; +} + +void ice_connect(pj_ice_strans * strans, HsStablePtr cb, +		const char * ufrag, const char * pass, +		const char * defcand, const char * tcandidates[], size_t ncand) +{ +	unsigned def_port = 0; +	char     def_addr[80]; +	pj_bool_t done = PJ_FALSE; +	char line[256]; +	pj_ice_sess_cand candidates[PJ_ICE_ST_MAX_CAND]; + +	struct user_data * udata = pj_ice_strans_get_user_data(strans); +	udata->cb_connect = cb; + +	def_addr[0] = '\0'; + +	if (ncand == 0) { +		fprintf(stderr, "ICE: no candidates\n"); +		return; +	} + +	int cnt = sscanf(defcand, "%s %u", def_addr, &def_port); +	if (cnt != 2) { +		fprintf(stderr, "ICE: error parsing default candidate\n"); +		return; +	} + +	int okcand = 0; +	for (int i = 0; i < ncand; i++) { +		char foundation[32], ipaddr[80], type[32]; +		int prio, port; + +		int cnt = sscanf(tcandidates[i], "%s %d %s %d %s", +				foundation, &prio, +				ipaddr, &port, +				type); +		if (cnt != 5) +			continue; + +		pj_ice_sess_cand * cand = &candidates[okcand]; +		pj_bzero(cand, sizeof(*cand)); + +		if (strcmp(type, "host") == 0) +			cand->type = PJ_ICE_CAND_TYPE_HOST; +		else if (strcmp(type, "srflx") == 0) +			cand->type = PJ_ICE_CAND_TYPE_SRFLX; +		else if (strcmp(type, "relay") == 0) +			cand->type = PJ_ICE_CAND_TYPE_RELAYED; +		else +			continue; + +		cand->comp_id = 1; +		pj_strdup2(ice.pool, &cand->foundation, foundation); +		cand->prio = prio; + +		int af = strchr(ipaddr, ':') ? pj_AF_INET6() : pj_AF_INET(); +		pj_str_t tmpaddr = pj_str(ipaddr); +		pj_sockaddr_init(af, &cand->addr, NULL, 0); +		pj_status_t status = pj_sockaddr_set_str_addr(af, &cand->addr, &tmpaddr); +		if (status != PJ_SUCCESS) { +			fprintf(stderr, "ICE: invalid IP address \"%s\"\n", ipaddr); +			continue; +		} + +		pj_sockaddr_set_port(&cand->addr, (pj_uint16_t)port); +		okcand++; +	} + +	pj_str_t tmp_addr; +	pj_status_t status; + +	int af = strchr(def_addr, ':') ? pj_AF_INET6() : pj_AF_INET(); + +	pj_sockaddr_init(af, &ice.def_addr, NULL, 0); +	tmp_addr = pj_str(def_addr); +	status = pj_sockaddr_set_str_addr(af, &ice.def_addr, &tmp_addr); +	if (status != PJ_SUCCESS) { +		fprintf(stderr, "ICE: invalid default IP address \"%s\"\n", def_addr); +		return; +	} +	pj_sockaddr_set_port(&ice.def_addr, (pj_uint16_t) def_port); + +	pj_str_t rufrag, rpwd; +	status = pj_ice_strans_start_ice(strans, +			pj_cstr(&rufrag, ufrag), pj_cstr(&rpwd, pass), +			okcand, candidates); +	if (status != PJ_SUCCESS) { +		ice_perror("error starting ICE", status); +		return; +	} +} + +void ice_send(pj_ice_strans * strans, const char * data, size_t len) +{ +	if (!pj_ice_strans_sess_is_complete(strans)) { +		fprintf(stderr, "ICE: negotiation has not been started or is in progress\n"); +		return; +	} + +	pj_status_t status = pj_ice_strans_sendto(strans, 1, data, len, +			&ice.def_addr, pj_sockaddr_get_len(&ice.def_addr)); +	if (status != PJ_SUCCESS && status != PJ_EPENDING) +		ice_perror("error sending data", status); +} diff --git a/src/Erebos/ICE/pjproject.h b/src/Erebos/ICE/pjproject.h new file mode 100644 index 0000000..e230e75 --- /dev/null +++ b/src/Erebos/ICE/pjproject.h @@ -0,0 +1,14 @@ +#pragma once + +#include <pjnath.h> +#include <HsFFI.h> + +pj_ice_strans * ice_create(pj_ice_sess_role role, HsStablePtr sptr, HsStablePtr cb); +void ice_destroy(pj_ice_strans * strans); + +ssize_t ice_encode_session(pj_ice_strans *, char * ufrag, char * pass, +		char * def, char * candidates[], size_t maxlen, size_t maxcand); +void ice_connect(pj_ice_strans * strans, HsStablePtr cb, +		const char * ufrag, const char * pass, +		const char * defcand, const char * candidates[], size_t ncand); +void ice_send(pj_ice_strans *, const char * data, size_t len); diff --git a/src/Erebos/Identity.hs b/src/Erebos/Identity.hs new file mode 100644 index 0000000..8761fde --- /dev/null +++ b/src/Erebos/Identity.hs @@ -0,0 +1,402 @@ +{-# LANGUAGE UndecidableInstances #-} + +module Erebos.Identity ( +    Identity, ComposedIdentity, UnifiedIdentity, +    IdentityData(..), ExtendedIdentityData(..), IdentityExtension(..), +    idData, idDataF, idExtData, idExtDataF, +    idName, idOwner, idUpdates, idKeyIdentity, idKeyMessage, +    eiddBase, eiddStoredBase, +    eiddName, eiddOwner, eiddKeyIdentity, eiddKeyMessage, + +    emptyIdentityData, +    emptyIdentityExtension, +    createIdentity, +    validateIdentity, validateIdentityF, validateIdentityFE, +    validateExtendedIdentity, validateExtendedIdentityF, validateExtendedIdentityFE, +    loadIdentity, loadUnifiedIdentity, + +    mergeIdentity, toUnifiedIdentity, toComposedIdentity, +    updateIdentity, updateOwners, +    sameIdentity, + +    unfoldOwners, +    finalOwner, +    displayIdentity, +) where + +import Control.Arrow +import Control.Monad +import Control.Monad.Except +import Control.Monad.Identity qualified as I +import Control.Monad.Reader + +import Data.Either +import Data.Foldable +import Data.Function +import Data.List +import Data.Maybe +import Data.Ord +import Data.Set (Set) +import qualified Data.Set as S +import Data.Text (Text) +import qualified Data.Text as T + +import Erebos.PubKey +import Erebos.Storage +import Erebos.Storage.Merge +import Erebos.Util + +data Identity m = IdentityKind m => Identity +    { idData_ :: m (Stored (Signed ExtendedIdentityData)) +    , idName_ :: Maybe Text +    , idOwner_ :: Maybe ComposedIdentity +    , idUpdates_ :: [Stored (Signed ExtendedIdentityData)] +    , idKeyIdentity_ :: Stored PublicKey +    , idKeyMessage_ :: Stored PublicKey +    } + +deriving instance Show (m (Stored (Signed ExtendedIdentityData))) => Show (Identity m) + +class (Functor f, Foldable f) => IdentityKind f where +    ikFilterAncestors :: Storable a => f (Stored a) -> f (Stored a) + +instance IdentityKind I.Identity where +    ikFilterAncestors = id + +instance IdentityKind [] where +    ikFilterAncestors = filterAncestors + +type ComposedIdentity = Identity [] +type UnifiedIdentity = Identity I.Identity + +instance Eq (m (Stored (Signed ExtendedIdentityData))) => Eq (Identity m) where +    (==) = (==) `on` (idData_ &&& idUpdates_) + +data IdentityData = IdentityData +    { iddPrev :: [Stored (Signed IdentityData)] +    , iddName :: Maybe Text +    , iddOwner :: Maybe (Stored (Signed IdentityData)) +    , iddKeyIdentity :: Stored PublicKey +    , iddKeyMessage :: Maybe (Stored PublicKey) +    } +    deriving (Show) + +data IdentityExtension = IdentityExtension +    { idePrev :: [Stored (Signed ExtendedIdentityData)] +    , ideBase :: Stored (Signed IdentityData) +    , ideName :: Maybe Text +    , ideOwner :: Maybe (Stored (Signed ExtendedIdentityData)) +    } +    deriving (Show) + +data ExtendedIdentityData = BaseIdentityData IdentityData +                          | ExtendedIdentityData IdentityExtension +    deriving (Show) + +baseToExtended :: Stored (Signed IdentityData) -> Stored (Signed ExtendedIdentityData) +baseToExtended = unsafeMapStored (unsafeMapSigned BaseIdentityData) + +instance Storable IdentityData where +    store' idt = storeRec $ do +        mapM_ (storeRef "SPREV") $ iddPrev idt +        storeMbText "name" $ iddName idt +        storeMbRef "owner" $ iddOwner idt +        storeRef "key-id" $ iddKeyIdentity idt +        storeMbRef "key-msg" $ iddKeyMessage idt + +    load' = loadRec $ IdentityData +        <$> loadRefs "SPREV" +        <*> loadMbText "name" +        <*> loadMbRef "owner" +        <*> loadRef "key-id" +        <*> loadMbRef "key-msg" + +instance Storable IdentityExtension where +    store' IdentityExtension {..} = storeRec $ do +        mapM_ (storeRef "SPREV") idePrev +        storeRef "SBASE" ideBase +        storeMbText "name" ideName +        storeMbRef "owner" ideOwner + +    load' = loadRec $ IdentityExtension +        <$> loadRefs "SPREV" +        <*> loadRef "SBASE" +        <*> loadMbText "name" +        <*> loadMbRef "owner" + +instance Storable ExtendedIdentityData where +    store' (BaseIdentityData idata) = store' idata +    store' (ExtendedIdentityData idata) = store' idata + +    load' = msum +        [ BaseIdentityData <$> load' +        , ExtendedIdentityData <$> load' +        ] + +instance Mergeable (Maybe ComposedIdentity) where +    type Component (Maybe ComposedIdentity) = Signed ExtendedIdentityData +    mergeSorted = validateExtendedIdentityF +    toComponents = maybe [] idExtDataF + +idData :: UnifiedIdentity -> Stored (Signed IdentityData) +idData = I.runIdentity . idDataF + +idDataF :: Identity m -> m (Stored (Signed IdentityData)) +idDataF idt@Identity {} = ikFilterAncestors . fmap eiddStoredBase . idData_ $ idt + +idExtData :: UnifiedIdentity -> Stored (Signed ExtendedIdentityData) +idExtData = I.runIdentity . idExtDataF + +idExtDataF :: Identity m -> m (Stored (Signed ExtendedIdentityData)) +idExtDataF = idData_ + +idName :: Identity m -> Maybe Text +idName = idName_ + +idOwner :: Identity m -> Maybe ComposedIdentity +idOwner = idOwner_ + +idUpdates :: Identity m -> [Stored (Signed ExtendedIdentityData)] +idUpdates = idUpdates_ + +idKeyIdentity :: Identity m -> Stored PublicKey +idKeyIdentity = idKeyIdentity_ + +idKeyMessage :: Identity m -> Stored PublicKey +idKeyMessage = idKeyMessage_ + +eiddPrev :: ExtendedIdentityData -> [Stored (Signed ExtendedIdentityData)] +eiddPrev (BaseIdentityData idata) = baseToExtended <$> iddPrev idata +eiddPrev (ExtendedIdentityData IdentityExtension {..}) = baseToExtended ideBase : idePrev + +eiddBase :: ExtendedIdentityData -> IdentityData +eiddBase (BaseIdentityData idata) = idata +eiddBase (ExtendedIdentityData IdentityExtension {..}) = fromSigned ideBase + +eiddStoredBase :: Stored (Signed ExtendedIdentityData) -> Stored (Signed IdentityData) +eiddStoredBase ext = case fromSigned ext of +                          (BaseIdentityData idata) -> unsafeMapStored (unsafeMapSigned (const idata)) ext +                          (ExtendedIdentityData IdentityExtension {..}) -> ideBase + +eiddName :: ExtendedIdentityData -> Maybe Text +eiddName (BaseIdentityData idata) = iddName idata +eiddName (ExtendedIdentityData IdentityExtension {..}) = ideName + +eiddOwner :: ExtendedIdentityData -> Maybe (Stored (Signed ExtendedIdentityData)) +eiddOwner (BaseIdentityData idata) = baseToExtended <$> iddOwner idata +eiddOwner (ExtendedIdentityData IdentityExtension {..}) = ideOwner + +eiddKeyIdentity :: ExtendedIdentityData -> Stored PublicKey +eiddKeyIdentity = iddKeyIdentity . eiddBase + +eiddKeyMessage :: ExtendedIdentityData -> Maybe (Stored PublicKey) +eiddKeyMessage = iddKeyMessage . eiddBase + + +emptyIdentityData :: Stored PublicKey -> IdentityData +emptyIdentityData key = IdentityData +    { iddName = Nothing +    , iddPrev = [] +    , iddOwner = Nothing +    , iddKeyIdentity = key +    , iddKeyMessage = Nothing +    } + +emptyIdentityExtension :: Stored (Signed IdentityData) -> IdentityExtension +emptyIdentityExtension base = IdentityExtension +    { idePrev = [] +    , ideBase = base +    , ideName = Nothing +    , ideOwner = Nothing +    } + +isExtension :: Stored (Signed ExtendedIdentityData) -> Bool +isExtension x = case fromSigned x of BaseIdentityData {} -> False +                                     _ -> True + + +createIdentity :: Storage -> Maybe Text -> Maybe UnifiedIdentity -> IO UnifiedIdentity +createIdentity st name owner = do +    (secret, public) <- generateKeys st +    (_secretMsg, publicMsg) <- generateKeys st + +    let signOwner :: Signed a -> ReaderT Storage IO (Signed a) +        signOwner idd +            | Just o <- owner = do +                Just ownerSecret <- loadKeyMb (iddKeyIdentity $ fromSigned $ idData o) +                signAdd ownerSecret idd +            | otherwise = return idd + +    Just identity <- flip runReaderT st $ do +        baseData <- mstore =<< signOwner =<< sign secret =<< +            mstore (emptyIdentityData public) +                { iddOwner = idData <$> owner +                , iddKeyMessage = Just publicMsg +                } +        let extOwner = do +                odata <- idExtData <$> owner +                guard $ isExtension odata +                return odata + +        validateExtendedIdentityF . I.Identity <$> +            if isJust name || isJust extOwner +               then mstore =<< signOwner =<< sign secret =<< +                       mstore . ExtendedIdentityData =<< return (emptyIdentityExtension baseData) +                       { ideName = name +                       , ideOwner = extOwner +                       } +               else return $ baseToExtended baseData +    return identity + +validateIdentity :: Stored (Signed IdentityData) -> Maybe UnifiedIdentity +validateIdentity = validateIdentityF . I.Identity + +validateIdentityF :: IdentityKind m => m (Stored (Signed IdentityData)) -> Maybe (Identity m) +validateIdentityF = either (const Nothing) Just . runExcept . validateIdentityFE + +validateIdentityFE :: IdentityKind m => m (Stored (Signed IdentityData)) -> Except String (Identity m) +validateIdentityFE = validateExtendedIdentityFE . fmap baseToExtended + +validateExtendedIdentity :: Stored (Signed ExtendedIdentityData) -> Maybe UnifiedIdentity +validateExtendedIdentity = validateExtendedIdentityF . I.Identity + +validateExtendedIdentityF :: IdentityKind m => m (Stored (Signed ExtendedIdentityData)) -> Maybe (Identity m) +validateExtendedIdentityF = either (const Nothing) Just . runExcept . validateExtendedIdentityFE + +validateExtendedIdentityFE :: IdentityKind m => m (Stored (Signed ExtendedIdentityData)) -> Except String (Identity m) +validateExtendedIdentityFE mdata = do +    let idata = ikFilterAncestors mdata +    when (null idata) $ throwError "null data" +    mapM_ verifySignatures $ gatherPrevious S.empty $ toList idata +    Identity +        <$> pure idata +        <*> pure (lookupProperty eiddName idata) +        <*> case lookupProperty eiddOwner idata of +                 Nothing    -> return Nothing +                 Just owner -> return <$> validateExtendedIdentityFE [owner] +        <*> pure [] +        <*> pure (eiddKeyIdentity $ fromSigned $ minimum idata) +        <*> case lookupProperty eiddKeyMessage idata of +                 Nothing -> throwError "no message key" +                 Just mk -> return mk + +loadIdentity :: String -> LoadRec ComposedIdentity +loadIdentity name = maybe (throwError "identity validation failed") return . validateExtendedIdentityF =<< loadRefs name + +loadUnifiedIdentity :: String -> LoadRec UnifiedIdentity +loadUnifiedIdentity name = maybe (throwError "identity validation failed") return . validateExtendedIdentity =<< loadRef name + + +gatherPrevious :: Set (Stored (Signed ExtendedIdentityData)) -> [Stored (Signed ExtendedIdentityData)] -> Set (Stored (Signed ExtendedIdentityData)) +gatherPrevious res (n:ns) | n `S.member` res = gatherPrevious res ns +                          | otherwise        = gatherPrevious (S.insert n res) $ (eiddPrev $ fromSigned n) ++ ns +gatherPrevious res [] = res + +verifySignatures :: Stored (Signed ExtendedIdentityData) -> Except String () +verifySignatures sidd = do +    let idd = fromSigned sidd +        required = concat +            [ [ eiddKeyIdentity idd ] +            , map (eiddKeyIdentity . fromSigned) $ eiddPrev idd +            , map (eiddKeyIdentity . fromSigned) $ toList $ eiddOwner idd +            ] +    unless (all (fromStored sidd `isSignedBy`) required) $ do +        throwError "signature verification failed" + +lookupProperty :: forall a m. Foldable m => (ExtendedIdentityData -> Maybe a) -> m (Stored (Signed ExtendedIdentityData)) -> Maybe a +lookupProperty sel topHeads = findResult filteredLayers +    where findPropHeads :: Stored (Signed ExtendedIdentityData) -> [(Stored (Signed ExtendedIdentityData), a)] +          findPropHeads sobj | Just x <- sel $ fromSigned sobj = [(sobj, x)] +                             | otherwise = findPropHeads =<< (eiddPrev $ fromSigned sobj) + +          propHeads :: [(Stored (Signed ExtendedIdentityData), a)] +          propHeads = findPropHeads =<< toList topHeads + +          historyLayers :: [Set (Stored (Signed ExtendedIdentityData))] +          historyLayers = generations $ map fst propHeads + +          filteredLayers :: [[(Stored (Signed ExtendedIdentityData), a)]] +          filteredLayers = scanl (\cur obsolete -> filter ((`S.notMember` obsolete) . fst) cur) propHeads historyLayers + +          findResult ([(_, x)] : _) = Just x +          findResult ([] : _) = Nothing +          findResult [] = Nothing +          findResult [xs] = Just $ snd $ minimumBy (comparing fst) xs +          findResult (_:rest) = findResult rest + +mergeIdentity :: (MonadStorage m, MonadError String m, MonadIO m) => Identity f -> m UnifiedIdentity +mergeIdentity idt | Just idt' <- toUnifiedIdentity idt = return idt' +mergeIdentity idt@Identity {..} = do +    (owner, ownerData) <- case idOwner_ of +        Nothing -> return (Nothing, Nothing) +        Just cowner | Just owner <- toUnifiedIdentity cowner -> return (Just owner, Nothing) +                    | otherwise -> do owner <- mergeIdentity cowner +                                      return (Just owner, Just $ idData owner) + +    let public = idKeyIdentity idt +    secret <- loadKey public + +    unifiedBaseData <- +        case toList $ idDataF idt of +            [idata] -> return idata +            idatas -> mstore =<< sign secret =<< mstore (emptyIdentityData public) +                { iddPrev = idatas, iddOwner = ownerData } + +    case filter isExtension $ toList $ idExtDataF idt of +        [] -> return Identity { idData_ = I.Identity (baseToExtended unifiedBaseData), idOwner_ = toComposedIdentity <$> owner, .. } +        extdata -> do +            unifiedExtendedData <- mstore =<< sign secret =<< +                (mstore . ExtendedIdentityData) (emptyIdentityExtension unifiedBaseData) +                    { idePrev = extdata } +            return Identity { idData_ = I.Identity unifiedExtendedData, idOwner_ = toComposedIdentity <$> owner, .. } + + +toUnifiedIdentity :: Identity m -> Maybe UnifiedIdentity +toUnifiedIdentity Identity {..} +    | [sdata] <- toList idData_ = Just Identity { idData_ = I.Identity sdata, .. } +    | otherwise = Nothing + +toComposedIdentity :: Identity m -> ComposedIdentity +toComposedIdentity Identity {..} = Identity { idData_ = toList idData_ +                                            , idOwner_ = toComposedIdentity <$> idOwner_ +                                            , .. +                                            } + +updateIdentity :: [Stored (Signed ExtendedIdentityData)] -> Identity m -> ComposedIdentity +updateIdentity [] orig = toComposedIdentity orig +updateIdentity updates orig@Identity {} = +    case validateExtendedIdentityF $ ourUpdates ++ idata of +         Just updated -> updated +             { idOwner_ = updateIdentity ownerUpdates <$> idOwner_ updated +             , idUpdates_ = ownerUpdates +             } +         Nothing -> toComposedIdentity orig +    where idata = toList $ idData_ orig +          idataRoots = foldl' mergeUniq [] $ map storedRoots idata +          (ourUpdates, ownerUpdates) = partitionEithers $ flip map (filterAncestors $ updates ++ idUpdates_ orig) $ +              -- if an update is related to anything in idData_, use it here, otherwise push to owners +              \u -> if storedRoots u `intersectsSorted` idataRoots +                       then Left u +                       else Right u + +updateOwners :: [Stored (Signed ExtendedIdentityData)] -> Identity m -> Identity m +updateOwners updates orig@Identity { idOwner_ = Just owner, idUpdates_ = cupdates } = +    orig { idOwner_ = Just $ updateIdentity updates owner, idUpdates_ = filterAncestors (updates ++ cupdates) } +updateOwners _ orig@Identity { idOwner_ = Nothing } = orig + +sameIdentity :: (Foldable m, Foldable m') => Identity m -> Identity m' -> Bool +sameIdentity x y = not $ S.null $ S.intersection (refset x) (refset y) +    where refset idt = foldr S.insert (ancestors $ toList $ idDataF idt) (idDataF idt) + + +unfoldOwners :: (Foldable m) => Identity m -> [ComposedIdentity] +unfoldOwners = unfoldr (fmap (\i -> (i, idOwner i))) . Just . toComposedIdentity + +finalOwner :: (Foldable m, Applicative m) => Identity m -> ComposedIdentity +finalOwner = last . unfoldOwners + +displayIdentity :: (Foldable m, Applicative m) => Identity m -> Text +displayIdentity identity = T.concat +    [ T.intercalate (T.pack " / ") $ map (fromMaybe (T.pack "<unnamed>") . idName) owners +    ] +    where owners = reverse $ unfoldOwners identity diff --git a/src/Erebos/Message.hs b/src/Erebos/Message.hs new file mode 100644 index 0000000..ea86ca0 --- /dev/null +++ b/src/Erebos/Message.hs @@ -0,0 +1,267 @@ +module Erebos.Message ( +    DirectMessage(..), +    sendDirectMessage, + +    DirectMessageAttributes(..), +    defaultDirectMessageAttributes, + +    DirectMessageThreads, +    toThreadList, + +    DirectMessageThread(..), +    threadToList, +    messageThreadView, + +    watchReceivedMessages, +    formatMessage, +) where + +import Control.Monad +import Control.Monad.Except +import Control.Monad.Reader + +import Data.List +import Data.Ord +import qualified Data.Set as S +import Data.Text (Text) +import qualified Data.Text as T +import Data.Time.Format +import Data.Time.LocalTime + +import Erebos.Identity +import Erebos.Network +import Erebos.Service +import Erebos.State +import Erebos.Storage +import Erebos.Storage.Merge + +data DirectMessage = DirectMessage +    { msgFrom :: ComposedIdentity +    , msgPrev :: [Stored DirectMessage] +    , msgTime :: ZonedTime +    , msgText :: Text +    } + +instance Storable DirectMessage where +    store' msg = storeRec $ do +        mapM_ (storeRef "from") $ idExtDataF $ msgFrom msg +        mapM_ (storeRef "PREV") $ msgPrev msg +        storeDate "time" $ msgTime msg +        storeText "text" $ msgText msg + +    load' = loadRec $ DirectMessage +        <$> loadIdentity "from" +        <*> loadRefs "PREV" +        <*> loadDate "time" +        <*> loadText "text" + +data DirectMessageAttributes = DirectMessageAttributes +    { dmOwnerMismatch :: ServiceHandler DirectMessage () +    } + +defaultDirectMessageAttributes :: DirectMessageAttributes +defaultDirectMessageAttributes = DirectMessageAttributes +    { dmOwnerMismatch = svcPrint "Owner mismatch" +    } + +instance Service DirectMessage where +    serviceID _ = mkServiceID "c702076c-4928-4415-8b6b-3e839eafcb0d" + +    type ServiceAttributes DirectMessage = DirectMessageAttributes +    defaultServiceAttributes _ = defaultDirectMessageAttributes + +    serviceHandler smsg = do +        let msg = fromStored smsg +        powner <- asks $ finalOwner . svcPeerIdentity +        erb <- svcGetLocal +        st <- getStorage +        let DirectMessageThreads prev _ = lookupSharedValue $ lsShared $ fromStored erb +            sent = findMsgProperty powner msSent prev +            received = findMsgProperty powner msReceived prev +            received' = filterAncestors $ smsg : received +        if powner `sameIdentity` msgFrom msg || +               filterAncestors sent == filterAncestors (smsg : sent) +           then do +               when (received' /= received) $ do +                   next <- wrappedStore st $ MessageState +                       { msPrev = prev +                       , msPeer = powner +                       , msReady = [] +                       , msSent = [] +                       , msReceived = received' +                       , msSeen = [] +                       } +                   let threads = DirectMessageThreads [next] (messageThreadView [next]) +                   shared <- makeSharedStateUpdate st threads (lsShared $ fromStored erb) +                   svcSetLocal =<< wrappedStore st (fromStored erb) { lsShared = [shared] } + +               when (powner `sameIdentity` msgFrom msg) $ do +                   replyStoredRef smsg + +           else join $ asks $ dmOwnerMismatch . svcAttributes + +    serviceNewPeer = syncDirectMessageToPeer . lookupSharedValue . lsShared . fromStored =<< svcGetLocal + +    serviceStorageWatchers _ = (:[]) $ +        SomeStorageWatcher (lookupSharedValue . lsShared . fromStored) syncDirectMessageToPeer + + +data MessageState = MessageState +    { msPrev :: [Stored MessageState] +    , msPeer :: ComposedIdentity +    , msReady :: [Stored DirectMessage] +    , msSent :: [Stored DirectMessage] +    , msReceived :: [Stored DirectMessage] +    , msSeen :: [Stored DirectMessage] +    } + +data DirectMessageThreads = DirectMessageThreads [Stored MessageState] [DirectMessageThread] + +instance Eq DirectMessageThreads where +    DirectMessageThreads mss _ == DirectMessageThreads mss' _ = mss == mss' + +toThreadList :: DirectMessageThreads -> [DirectMessageThread] +toThreadList (DirectMessageThreads _ threads) = threads + +instance Storable MessageState where +    store' MessageState {..} = storeRec $ do +        mapM_ (storeRef "PREV") msPrev +        mapM_ (storeRef "peer") $ idDataF msPeer +        mapM_ (storeRef "ready") msReady +        mapM_ (storeRef "sent") msSent +        mapM_ (storeRef "received") msReceived +        mapM_ (storeRef "seen") msSeen + +    load' = loadRec $ do +        msPrev <- loadRefs "PREV" +        msPeer <- loadIdentity "peer" +        msReady <- loadRefs "ready" +        msSent <- loadRefs "sent" +        msReceived <- loadRefs "received" +        msSeen <- loadRefs "seen" +        return MessageState {..} + +instance Mergeable DirectMessageThreads where +    type Component DirectMessageThreads = MessageState +    mergeSorted mss = DirectMessageThreads mss (messageThreadView mss) +    toComponents (DirectMessageThreads mss _) = mss + +instance SharedType DirectMessageThreads where +    sharedTypeID _ = mkSharedTypeID "ee793681-5976-466a-b0f0-4e1907d3fade" + +findMsgProperty :: Foldable m => Identity m -> (MessageState -> [a]) -> [Stored MessageState] -> [a] +findMsgProperty pid sel mss = concat $ flip findProperty mss $ \x -> do +    guard $ msPeer x `sameIdentity` pid +    guard $ not $ null $ sel x +    return $ sel x + + +sendDirectMessage :: (Foldable f, Applicative f, MonadHead LocalState m, MonadError String m) +                  => Identity f -> Text -> m (Stored DirectMessage) +sendDirectMessage pid text = updateLocalHead $ \ls -> do +    let self = localIdentity $ fromStored ls +        powner = finalOwner pid +    flip updateSharedState ls $ \(DirectMessageThreads prev _) -> do +        let ready = findMsgProperty powner msReady prev +            received = findMsgProperty powner msReceived prev + +        time <- liftIO getZonedTime +        smsg <- mstore DirectMessage +            { msgFrom = toComposedIdentity $ finalOwner self +            , msgPrev = filterAncestors $ ready ++ received +            , msgTime = time +            , msgText = text +            } +        next <- mstore MessageState +            { msPrev = prev +            , msPeer = powner +            , msReady = [smsg] +            , msSent = [] +            , msReceived = [] +            , msSeen = [] +            } +        return (DirectMessageThreads [next] (messageThreadView [next]), smsg) + +syncDirectMessageToPeer :: DirectMessageThreads -> ServiceHandler DirectMessage () +syncDirectMessageToPeer (DirectMessageThreads mss _) = do +    pid <- finalOwner <$> asks svcPeerIdentity +    peer <- asks svcPeer +    let thread = messageThreadFor pid mss +    mapM_ (sendToPeerStored peer) $ msgHead thread +    updateLocalHead_ $ \ls -> do +        let powner = finalOwner pid +        flip updateSharedState_ ls $ \unchanged@(DirectMessageThreads prev _) -> do +            let ready = findMsgProperty powner msReady prev +                sent = findMsgProperty powner msSent prev +                sent' = filterAncestors (ready ++ sent) + +            if sent' /= sent +              then do +                next <- mstore MessageState +                    { msPrev = prev +                    , msPeer = powner +                    , msReady = [] +                    , msSent = sent' +                    , msReceived = [] +                    , msSeen = [] +                    } +                return $ DirectMessageThreads [next] (messageThreadView [next]) +              else do +                return unchanged + + +data DirectMessageThread = DirectMessageThread +    { msgPeer :: ComposedIdentity +    , msgHead :: [Stored DirectMessage] +    , msgSent :: [Stored DirectMessage] +    , msgSeen :: [Stored DirectMessage] +    } + +threadToList :: DirectMessageThread -> [DirectMessage] +threadToList thread = helper S.empty $ msgHead thread +    where helper seen msgs +              | msg : msgs' <- filter (`S.notMember` seen) $ reverse $ sortBy (comparing cmpView) msgs = +                  fromStored msg : helper (S.insert msg seen) (msgs' ++ msgPrev (fromStored msg)) +              | otherwise = [] +          cmpView msg = (zonedTimeToUTC $ msgTime $ fromStored msg, msg) + +messageThreadView :: [Stored MessageState] -> [DirectMessageThread] +messageThreadView = helper [] +    where helper used ms' = case filterAncestors ms' of +              mss@(sms : rest) +                  | any (sameIdentity $ msPeer $ fromStored sms) used -> +                      helper used $ msPrev (fromStored sms) ++ rest +                  | otherwise -> +                      let peer = msPeer $ fromStored sms +                       in messageThreadFor peer mss : helper (peer : used) (msPrev (fromStored sms) ++ rest) +              _ -> [] + +messageThreadFor :: ComposedIdentity -> [Stored MessageState] -> DirectMessageThread +messageThreadFor peer mss = +    let ready = findMsgProperty peer msReady mss +        sent = findMsgProperty peer msSent mss +        received = findMsgProperty peer msReceived mss +        seen = findMsgProperty peer msSeen mss + +     in DirectMessageThread +         { msgPeer = peer +         , msgHead = filterAncestors $ ready ++ received +         , msgSent = filterAncestors $ sent ++ received +         , msgSeen = filterAncestors $ ready ++ seen +         } + + +watchReceivedMessages :: Head LocalState -> (Stored DirectMessage -> IO ()) -> IO WatchedHead +watchReceivedMessages h f = do +    let self = finalOwner $ localIdentity $ headObject h +    watchHeadWith h (lookupSharedValue . lsShared . headObject) $ \(DirectMessageThreads sms _) -> do +        forM_ (map fromStored sms) $ \ms -> do +            mapM_ f $ filter (not . sameIdentity self . msgFrom . fromStored) $ msReceived ms + +formatMessage :: TimeZone -> DirectMessage -> String +formatMessage tzone msg = concat +    [ formatTime defaultTimeLocale "[%H:%M] " $ utcToLocalTime tzone $ zonedTimeToUTC $ msgTime msg +    , maybe "<unnamed>" T.unpack $ idName $ msgFrom msg +    , ": " +    , T.unpack $ msgText msg +    ] diff --git a/src/Erebos/Network.hs b/src/Erebos/Network.hs new file mode 100644 index 0000000..7c6a61e --- /dev/null +++ b/src/Erebos/Network.hs @@ -0,0 +1,860 @@ +{-# LANGUAGE CPP #-} + +module Erebos.Network ( +    Server, +    startServer, +    stopServer, +    getNextPeerChange, +    ServerOptions(..), serverIdentity, defaultServerOptions, + +    Peer, peerServer, peerStorage, +    PeerAddress(..), peerAddress, +    PeerIdentity(..), peerIdentity, +    WaitingRef, wrDigest, +    Service(..), +    serverPeer, +#ifdef ENABLE_ICE_SUPPORT +    serverPeerIce, +#endif +    sendToPeer, sendToPeerStored, sendToPeerWith, +    runPeerService, + +    discoveryPort, +) where + +import Control.Concurrent +import Control.Concurrent.STM +import Control.Exception +import Control.Monad +import Control.Monad.Except +import Control.Monad.Reader +import Control.Monad.State + +import Data.ByteString.Char8 qualified as BC +import Data.ByteString.Lazy qualified as BL +import Data.Function +import Data.IP qualified as IP +import Data.List +import Data.Map (Map) +import qualified Data.Map as M +import Data.Maybe +import Data.Typeable +import Data.Word + +import Foreign.Ptr +import Foreign.Storable + +import GHC.Conc.Sync (unsafeIOToSTM) + +import Network.Socket hiding (ControlMessage) +import qualified Network.Socket.ByteString as S + +import Erebos.Channel +#ifdef ENABLE_ICE_SUPPORT +import Erebos.ICE +#endif +import Erebos.Identity +import Erebos.Network.Protocol +import Erebos.PubKey +import Erebos.Service +import Erebos.State +import Erebos.Storage +import Erebos.Storage.Key +import Erebos.Storage.Merge + + +discoveryPort :: PortNumber +discoveryPort = 29665 + +announceIntervalSeconds :: Int +announceIntervalSeconds = 60 + + +data Server = Server +    { serverStorage :: Storage +    , serverOrigHead :: Head LocalState +    , serverIdentity_ :: MVar UnifiedIdentity +    , serverThreads :: MVar [ThreadId] +    , serverSocket :: MVar Socket +    , serverRawPath :: SymFlow (PeerAddress, BC.ByteString) +    , serverControlFlow :: Flow (ControlMessage PeerAddress) (ControlRequest PeerAddress) +    , serverDataResponse :: TQueue (Peer, Maybe PartialRef) +    , serverIOActions :: TQueue (ExceptT String IO ()) +    , serverServices :: [SomeService] +    , serverServiceStates :: TMVar (M.Map ServiceID SomeServiceGlobalState) +    , serverPeers :: MVar (Map PeerAddress Peer) +    , serverChanPeer :: TChan Peer +    , serverErrorLog :: TQueue String +    } + +serverIdentity :: Server -> IO UnifiedIdentity +serverIdentity = readMVar . serverIdentity_ + +getNextPeerChange :: Server -> IO Peer +getNextPeerChange = atomically . readTChan . serverChanPeer + +data ServerOptions = ServerOptions +    { serverPort :: PortNumber +    , serverLocalDiscovery :: Bool +    } + +defaultServerOptions :: ServerOptions +defaultServerOptions = ServerOptions +    { serverPort = discoveryPort +    , serverLocalDiscovery = True +    } + + +data Peer = Peer +    { peerAddress :: PeerAddress +    , peerServer_ :: Server +    , peerConnection :: TVar (Either [(Bool, TransportPacket Ref, [TransportHeaderItem])] (Connection PeerAddress)) +    , peerIdentityVar :: TVar PeerIdentity +    , peerStorage_ :: Storage +    , peerInStorage :: PartialStorage +    , peerServiceState :: TMVar (M.Map ServiceID SomeServiceState) +    , peerWaitingRefs :: TMVar [WaitingRef] +    } + +peerServer :: Peer -> Server +peerServer = peerServer_ + +peerStorage :: Peer -> Storage +peerStorage = peerStorage_ + +getPeerChannel :: Peer -> STM ChannelState +getPeerChannel Peer {..} = either (const $ return ChannelNone) connGetChannel =<< readTVar peerConnection + +setPeerChannel :: Peer -> ChannelState -> STM () +setPeerChannel Peer {..} ch = do +    readTVar peerConnection >>= \case +        Left _ -> retry +        Right conn -> connSetChannel conn ch + +instance Eq Peer where +    (==) = (==) `on` peerIdentityVar + +data PeerAddress = DatagramAddress SockAddr +#ifdef ENABLE_ICE_SUPPORT +                 | PeerIceSession IceSession +#endif + +instance Show PeerAddress where +    show (DatagramAddress saddr) = unwords $ case IP.fromSockAddr saddr of +        Just (IP.IPv6 ipv6, port) +            | (0, 0, 0xffff, ipv4) <- IP.fromIPv6w ipv6 +            -> [show (IP.toIPv4w ipv4), show port] +        Just (addr, port) +            -> [show addr, show port] +        _ -> [show saddr] +#ifdef ENABLE_ICE_SUPPORT +    show (PeerIceSession ice) = show ice +#endif + +instance Eq PeerAddress where +    DatagramAddress addr == DatagramAddress addr' = addr == addr' +#ifdef ENABLE_ICE_SUPPORT +    PeerIceSession ice   == PeerIceSession ice'   = ice == ice' +    _                    == _                     = False +#endif + +instance Ord PeerAddress where +    compare (DatagramAddress addr) (DatagramAddress addr') = compare addr addr' +#ifdef ENABLE_ICE_SUPPORT +    compare (DatagramAddress _   ) _                       = LT +    compare _                      (DatagramAddress _    ) = GT +    compare (PeerIceSession ice  ) (PeerIceSession ice')   = compare ice ice' +#endif + + +data PeerIdentity = PeerIdentityUnknown (TVar [UnifiedIdentity -> ExceptT String IO ()]) +                  | PeerIdentityRef WaitingRef (TVar [UnifiedIdentity -> ExceptT String IO ()]) +                  | PeerIdentityFull UnifiedIdentity + +peerIdentity :: MonadIO m => Peer -> m PeerIdentity +peerIdentity = liftIO . atomically . readTVar . peerIdentityVar + + +lookupServiceType :: [TransportHeaderItem] -> Maybe ServiceID +lookupServiceType (ServiceType stype : _) = Just stype +lookupServiceType (_ : hs) = lookupServiceType hs +lookupServiceType [] = Nothing + +lookupNewStreams :: [TransportHeaderItem] -> [Word8] +lookupNewStreams (StreamOpen num : rest) = num : lookupNewStreams rest +lookupNewStreams (_ : rest) = lookupNewStreams rest +lookupNewStreams [] = [] + + +newWaitingRef :: RefDigest -> (Ref -> WaitingRefCallback) -> PacketHandler WaitingRef +newWaitingRef dgst act = do +    peer@Peer {..} <- gets phPeer +    wref <- WaitingRef peerStorage_ (partialRefFromDigest peerInStorage dgst) act <$> liftSTM (newTVar (Left [])) +    modifyTMVarP peerWaitingRefs (wref:) +    liftSTM $ writeTQueue (serverDataResponse $ peerServer peer) (peer, Nothing) +    return wref + + +forkServerThread :: Server -> IO () -> IO () +forkServerThread server act = modifyMVar_ (serverThreads server) $ \ts -> do +    (:ts) <$> forkIO act + +startServer :: ServerOptions -> Head LocalState -> (String -> IO ()) -> [SomeService] -> IO Server +startServer opt serverOrigHead logd' serverServices = do +    let serverStorage = headStorage serverOrigHead +    serverIdentity_ <- newMVar $ headLocalIdentity serverOrigHead +    serverThreads <- newMVar [] +    serverSocket <- newEmptyMVar +    (serverRawPath, protocolRawPath) <- newFlowIO +    (serverControlFlow, protocolControlFlow) <- newFlowIO +    serverDataResponse <- newTQueueIO +    serverIOActions <- newTQueueIO +    serverServiceStates <- newTMVarIO M.empty +    serverPeers <- newMVar M.empty +    serverChanPeer <- newTChanIO +    serverErrorLog <- newTQueueIO +    let server = Server {..} + +    chanSvc <- newTQueueIO + +    let logd = writeTQueue serverErrorLog +    forkServerThread server $ forever $ do +        logd' =<< atomically (readTQueue serverErrorLog) + +    forkServerThread server $ dataResponseWorker server +    forkServerThread server $ forever $ do +        either (atomically . logd) return =<< runExceptT =<< +            atomically (readTQueue serverIOActions) + +    broadcastAddreses <- getBroadcastAddresses discoveryPort + +    let open addr = do +            sock <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr) +            putMVar serverSocket sock +            setSocketOption sock ReuseAddr 1 +            setSocketOption sock Broadcast 1 +            withFdSocket sock setCloseOnExecIfNeeded +            bind sock (addrAddress addr) +            return sock + +        loop sock = do +            when (serverLocalDiscovery opt) $ forkServerThread server $ forever $ do +                atomically $ writeFlowBulk serverControlFlow $ map (SendAnnounce . DatagramAddress) broadcastAddreses +                threadDelay $ announceIntervalSeconds * 1000 * 1000 + +            let announceUpdate identity = do +                    st <- derivePartialStorage serverStorage +                    let selfRef = partialRef st $ storedRef $ idExtData identity +                        updateRefs = map refDigest $ selfRef : map (partialRef st . storedRef) (idUpdates identity) +                        ackedBy = concat [[ Acknowledged r, Rejected r, DataRequest r ] | r <- updateRefs ] +                        hitems = map AnnounceUpdate updateRefs +                        packet = TransportPacket (TransportHeader $  hitems) [] + +                    ps <- readMVar serverPeers +                    forM_ ps $ \peer -> atomically $ do +                        ((,) <$> readTVar (peerIdentityVar peer) <*> getPeerChannel peer) >>= \case +                            (PeerIdentityFull _, ChannelEstablished _) -> +                                sendToPeerS peer ackedBy packet +                            _  -> return () + +            void $ watchHead serverOrigHead $ \h -> do +                let idt = headLocalIdentity h +                changedId <- modifyMVar serverIdentity_ $ \cur -> +                    return (idt, cur /= idt) +                when changedId $ do +                    writeFlowIO serverControlFlow $ UpdateSelfIdentity idt +                    announceUpdate idt + +            forM_ serverServices $ \(SomeService service _) -> do +                forM_ (serviceStorageWatchers service) $ \(SomeStorageWatcher sel act) -> do +                    watchHeadWith serverOrigHead (sel . headStoredObject) $ \x -> do +                        withMVar serverPeers $ mapM_ $ \peer -> atomically $ do +                            readTVar (peerIdentityVar peer) >>= \case +                                PeerIdentityFull _ -> writeTQueue serverIOActions $ do +                                    runPeerService peer $ act x +                                _ -> return () + +            forkServerThread server $ forever $ do +                (msg, saddr) <- S.recvFrom sock 4096 +                writeFlowIO serverRawPath (DatagramAddress saddr, msg) + +            forkServerThread server $ forever $ do +                (paddr, msg) <- readFlowIO serverRawPath +                case paddr of +                    DatagramAddress addr -> void $ S.sendTo sock msg addr +#ifdef ENABLE_ICE_SUPPORT +                    PeerIceSession ice   -> iceSend ice msg +#endif + +            forkServerThread server $ forever $ do +                readFlowIO serverControlFlow >>= \case +                    NewConnection conn mbpid -> do +                        let paddr = connAddress conn +                        peer <- modifyMVar serverPeers $ \pvalue -> do +                            case M.lookup paddr pvalue of +                                Just peer -> return (pvalue, peer) +                                Nothing -> do +                                    peer <- mkPeer server paddr +                                    return (M.insert paddr peer pvalue, peer) + +                        forkServerThread server $ do +                            atomically $ do +                                readTVar (peerConnection peer) >>= \case +                                    Left packets -> writeFlowBulk (connData conn) $ reverse packets +                                    Right _ -> return () +                                writeTVar (peerConnection peer) (Right conn) + +                            case mbpid of +                                Just dgst -> do +                                    identity <- readMVar serverIdentity_ +                                    atomically $ runPacketHandler False peer $ do +                                        wref <- newWaitingRef dgst $ handleIdentityAnnounce identity peer +                                        readTVarP (peerIdentityVar peer) >>= \case +                                            PeerIdentityUnknown idwait -> do +                                                addHeader $ AnnounceSelf $ refDigest $ storedRef $ idData identity +                                                writeTVarP (peerIdentityVar peer) $ PeerIdentityRef wref idwait +                                                liftSTM $ writeTChan serverChanPeer peer +                                            _ -> return () +                                Nothing -> return () + +                            forever $ do +                                (secure, TransportPacket header objs) <- readFlowIO $ connData conn +                                prefs <- forM objs $ storeObject $ peerInStorage peer +                                identity <- readMVar serverIdentity_ +                                let svcs = map someServiceID serverServices +                                handlePacket identity secure peer chanSvc svcs header prefs + +                    ReceivedAnnounce addr _ -> do +                        void $ serverPeer' server addr + +            erebosNetworkProtocol (headLocalIdentity serverOrigHead) logd protocolRawPath protocolControlFlow + +    forkServerThread server $ withSocketsDo $ do +        let hints = defaultHints +              { addrFlags = [AI_PASSIVE] +              , addrFamily = AF_INET6 +              , addrSocketType = Datagram +              } +        addr:_ <- getAddrInfo (Just hints) Nothing (Just $ show $ serverPort opt) +        bracket (open addr) close loop + +    forkServerThread server $ forever $ do +        (peer, svc, ref) <- atomically $ readTQueue chanSvc +        case find ((svc ==) . someServiceID) serverServices of +            Just service@(SomeService (_ :: Proxy s) attr) -> runPeerServiceOn (Just (service, attr)) peer (serviceHandler $ wrappedLoad @s ref) +            _ -> atomically $ logd $ "unhandled service '" ++ show (toUUID svc) ++ "'" + +    return server + +stopServer :: Server -> IO () +stopServer Server {..} = do +    mapM_ killThread =<< takeMVar serverThreads + +dataResponseWorker :: Server -> IO () +dataResponseWorker server = forever $ do +    (peer, npref) <- atomically (readTQueue $ serverDataResponse server) + +    wait <- atomically $ takeTMVar (peerWaitingRefs peer) +    list <- forM wait $ \wr@WaitingRef { wrefStatus = tvar } -> +        atomically (readTVar tvar) >>= \case +            Left ds -> case maybe id (filter . (/=) . refDigest) npref $ ds of +                [] -> copyRef (wrefStorage wr) (wrefPartial wr) >>= \case +                          Right ref -> do +                              atomically (writeTVar tvar $ Right ref) +                              forkServerThread server $ runExceptT (wrefAction wr ref) >>= \case +                                  Left err -> atomically $ writeTQueue (serverErrorLog server) err +                                  Right () -> return () + +                              return (Nothing, []) +                          Left dgst -> do +                              atomically (writeTVar tvar $ Left [dgst]) +                              return (Just wr, [dgst]) +                ds' -> do +                    atomically (writeTVar tvar $ Left ds') +                    return (Just wr, []) +            Right _ -> return (Nothing, []) +    atomically $ putTMVar (peerWaitingRefs peer) $ catMaybes $ map fst list + +    let reqs = concat $ map snd list +    when (not $ null reqs) $ do +        let packet = TransportPacket (TransportHeader $ map DataRequest reqs) [] +            ackedBy = concat [[ Rejected r, DataResponse r ] | r <- reqs ] +        atomically $ sendToPeerPlain peer ackedBy packet + + +newtype PacketHandler a = PacketHandler { unPacketHandler :: StateT PacketHandlerState (ExceptT String STM) a } +    deriving (Functor, Applicative, Monad, MonadState PacketHandlerState, MonadError String) + +instance MonadFail PacketHandler where +    fail = throwError + +runPacketHandler :: Bool -> Peer -> PacketHandler () -> STM () +runPacketHandler secure peer@Peer {..} act = do +    let logd = writeTQueue $ serverErrorLog peerServer_ +    runExceptT (flip execStateT (PacketHandlerState peer [] [] []) $ unPacketHandler act) >>= \case +        Left err -> do +            logd $ "Error in handling packet from " ++ show peerAddress ++ ": " ++ err +        Right ph -> do +            when (not $ null $ phHead ph) $ do +                let packet = TransportPacket (TransportHeader $ phHead ph) (phBody ph) +                sendToPeerS' secure peer (phAckedBy ph) packet + +liftSTM :: STM a -> PacketHandler a +liftSTM = PacketHandler . lift . lift + +readTVarP :: TVar a -> PacketHandler a +readTVarP = liftSTM . readTVar + +writeTVarP :: TVar a -> a -> PacketHandler () +writeTVarP v = liftSTM . writeTVar v + +modifyTMVarP :: TMVar a -> (a -> a) -> PacketHandler () +modifyTMVarP v f = liftSTM $ putTMVar v . f =<< takeTMVar v + +data PacketHandlerState = PacketHandlerState +    { phPeer :: Peer +    , phHead :: [TransportHeaderItem] +    , phAckedBy :: [TransportHeaderItem] +    , phBody :: [Ref] +    } + +addHeader :: TransportHeaderItem -> PacketHandler () +addHeader h = modify $ \ph -> ph { phHead = h `appendDistinct` phHead ph } + +addAckedBy :: [TransportHeaderItem] -> PacketHandler () +addAckedBy hs = modify $ \ph -> ph { phAckedBy = foldr appendDistinct (phAckedBy ph) hs } + +addBody :: Ref -> PacketHandler () +addBody r = modify $ \ph -> ph { phBody = r `appendDistinct` phBody ph } + +openStream :: PacketHandler RawStreamWriter +openStream = do +    Peer {..} <- gets phPeer +    conn <- readTVarP peerConnection >>= \case +        Right conn -> return conn +        _          -> throwError "can't open stream without established connection" +    (hdr, writer, handler) <- liftSTM $ connAddWriteStream conn +    liftSTM $ writeTQueue (serverIOActions peerServer_) (liftIO $ forkServerThread peerServer_ handler) +    addHeader hdr +    return writer + +acceptStream :: Word8 -> PacketHandler RawStreamReader +acceptStream streamNumber = do +    Peer {..} <- gets phPeer +    conn <- readTVarP peerConnection >>= \case +        Right conn -> return conn +        _          -> throwError "can't accept stream without established connection" +    liftSTM $ connAddReadStream conn streamNumber + +appendDistinct :: Eq a => a -> [a] -> [a] +appendDistinct x (y:ys) | x == y    = y : ys +                        | otherwise = y : appendDistinct x ys +appendDistinct x [] = [x] + +handlePacket :: UnifiedIdentity -> Bool +    -> Peer -> TQueue (Peer, ServiceID, Ref) -> [ServiceID] +    -> TransportHeader -> [PartialRef] -> IO () +handlePacket identity secure peer chanSvc svcs (TransportHeader headers) prefs = atomically $ do +    let server = peerServer peer +    ochannel <- getPeerChannel peer +    let sidentity = idData identity +        plaintextRefs = map (refDigest . storedRef) $ concatMap (collectStoredObjects . wrappedLoad) $ concat +            [ [ storedRef sidentity ] +            , map storedRef $ idUpdates identity +            , case ochannel of +                   ChannelOurRequest _ req  -> [ storedRef req ] +                   ChannelOurAccept _ acc _ -> [ storedRef acc ] +                   _                        -> [] +            ] + +    runPacketHandler secure peer $ do +        let logd = liftSTM . writeTQueue (serverErrorLog server) +        forM_ headers $ \case +            Acknowledged dgst -> do +                liftSTM (getPeerChannel peer) >>= \case +                    ChannelOurAccept _ acc ch | refDigest (storedRef acc) == dgst -> do +                        liftSTM $ finalizedChannel peer ch identity +                    _ -> return () + +            Rejected dgst -> do +                logd $ "rejected by peer: " ++ show dgst + +            DataRequest dgst +                | secure || dgst `elem` plaintextRefs -> do +                    Right mref <- liftSTM $ unsafeIOToSTM $ +                        copyRef (peerStorage peer) $ +                        partialRefFromDigest (peerInStorage peer) dgst +                    addHeader $ DataResponse dgst +                    addAckedBy [ Acknowledged dgst, Rejected dgst ] +                    let bytes = lazyLoadBytes mref +                    -- TODO: MTU +                    if (secure && BL.length bytes > 500) +                      then do +                        stream <- openStream +                        liftSTM $ writeTQueue (serverIOActions server) $ void $ liftIO $ forkIO $ do +                            writeByteStringToStream stream bytes +                      else do +                        addBody $ mref +                | otherwise -> do +                    logd $ "unauthorized data request for " ++ show dgst +                    addHeader $ Rejected dgst + +            DataResponse dgst -> if +                | Just pref <- find ((==dgst) . refDigest) prefs -> do +                    when (not secure) $ do +                        addHeader $ Acknowledged dgst +                    liftSTM $ writeTQueue (serverDataResponse server) (peer, Just pref) + +                | streamNumber : _ <- lookupNewStreams headers -> do +                    streamReader <- acceptStream streamNumber +                    liftSTM $ writeTQueue (serverIOActions server) $ void $ liftIO $ forkIO $ do +                        (runExcept <$> readObjectsFromStream (peerInStorage peer) streamReader) >>= \case +                            Left err -> atomically $ writeTQueue (serverErrorLog server) $ +                                "failed to receive object from stream: " <> err +                            Right objs -> do +                                forM_ objs $ \obj -> do +                                    pref <- storeObject (peerInStorage peer) obj +                                    atomically $ writeTQueue (serverDataResponse server) (peer, Just pref) + +                | otherwise -> throwError $ "mismatched data response " ++ show dgst + +            AnnounceSelf dgst +                | dgst == refDigest (storedRef sidentity) -> return () +                | otherwise -> do +                    wref <- newWaitingRef dgst $ handleIdentityAnnounce identity peer +                    readTVarP (peerIdentityVar peer) >>= \case +                        PeerIdentityUnknown idwait -> do +                            addHeader $ AnnounceSelf $ refDigest $ storedRef $ idData identity +                            writeTVarP (peerIdentityVar peer) $ PeerIdentityRef wref idwait +                            liftSTM $ writeTChan (serverChanPeer $ peerServer peer) peer +                        _ -> return () + +            AnnounceUpdate dgst -> do +                readTVarP (peerIdentityVar peer) >>= \case +                    PeerIdentityFull _ -> do +                        void $ newWaitingRef dgst $ handleIdentityUpdate peer +                    _ -> return () + +            TrChannelRequest dgst -> do +                let process cookie = do +                        addHeader $ Acknowledged dgst +                        wref <- newWaitingRef dgst $ handleChannelRequest peer identity +                        liftSTM $ setPeerChannel peer $ ChannelPeerRequest cookie wref +                    reject = addHeader $ Rejected dgst + +                liftSTM (getPeerChannel peer) >>= \case +                    ChannelNone {} -> return () +                    ChannelCookieWait {} -> return () +                    ChannelCookieReceived cookie -> process $ Just cookie +                    ChannelCookieConfirmed cookie -> process $ Just cookie +                    ChannelOurRequest mbcookie our | dgst < refDigest (storedRef our) -> process mbcookie +                                                   | otherwise -> reject +                    ChannelPeerRequest mbcookie _ -> process mbcookie +                    ChannelOurAccept {} -> reject +                    ChannelEstablished {} -> process Nothing + +            TrChannelAccept dgst -> do +                let process = do +                        handleChannelAccept identity $ partialRefFromDigest (peerInStorage peer) dgst +                    reject = addHeader $ Rejected dgst +                liftSTM (getPeerChannel peer) >>= \case +                    ChannelNone {} -> reject +                    ChannelCookieWait {} -> reject +                    ChannelCookieReceived {} -> reject +                    ChannelCookieConfirmed {} -> reject +                    ChannelOurRequest {} -> process +                    ChannelPeerRequest {} -> process +                    ChannelOurAccept _ our _ | dgst < refDigest (storedRef our) -> process +                                             | otherwise -> addHeader $ Rejected dgst +                    ChannelEstablished {} -> process + +            ServiceType _ -> return () +            ServiceRef dgst +                | not secure -> throwError $ "service packet without secure channel" +                | Just svc <- lookupServiceType headers -> if +                    | svc `elem` svcs -> do +                        if dgst `elem` map refDigest prefs || True {- TODO: used by Message service to confirm receive -} +                           then do +                                void $ newWaitingRef dgst $ \ref -> +                                    liftIO $ atomically $ writeTQueue chanSvc (peer, svc, ref) +                           else throwError $ "missing service object " ++ show dgst +                    | otherwise -> addHeader $ Rejected dgst +                | otherwise -> throwError $ "service ref without type" + +            _ -> return () + + +withPeerIdentity :: MonadIO m => Peer -> (UnifiedIdentity -> ExceptT String IO ()) -> m () +withPeerIdentity peer act = liftIO $ atomically $ readTVar (peerIdentityVar peer) >>= \case +    PeerIdentityUnknown tvar -> modifyTVar' tvar (act:) +    PeerIdentityRef _ tvar -> modifyTVar' tvar (act:) +    PeerIdentityFull idt -> writeTQueue (serverIOActions $ peerServer peer) (act idt) + + +setupChannel :: UnifiedIdentity -> Peer -> UnifiedIdentity -> WaitingRefCallback +setupChannel identity peer upid = do +    req <- flip runReaderT (peerStorage peer) $ createChannelRequest identity upid +    let reqref = refDigest $ storedRef req +    let hitems = +            [ TrChannelRequest reqref +            , AnnounceSelf $ refDigest $ storedRef $ idData identity +            ] +    liftIO $ atomically $ do +        getPeerChannel peer >>= \case +            ChannelCookieConfirmed cookie -> do +                sendToPeerPlain peer [ Acknowledged reqref, Rejected reqref ] $ +                    TransportPacket (TransportHeader hitems) [storedRef req] +                setPeerChannel peer $ ChannelOurRequest (Just cookie) req +            _ -> return () + +handleChannelRequest :: Peer -> UnifiedIdentity -> Ref -> WaitingRefCallback +handleChannelRequest peer identity req = do +    withPeerIdentity peer $ \upid -> do +        (acc, ch) <- flip runReaderT (peerStorage peer) $ acceptChannelRequest identity upid (wrappedLoad req) +        liftIO $ atomically $ do +            getPeerChannel peer >>= \case +                ChannelPeerRequest mbcookie wr | wrDigest wr == refDigest req -> do +                    setPeerChannel peer $ ChannelOurAccept mbcookie acc ch +                    let accref = refDigest $ storedRef acc +                        header = TrChannelAccept accref +                        ackedBy = [ Acknowledged accref, Rejected accref ] +                    sendToPeerPlain peer ackedBy $ TransportPacket (TransportHeader [header]) $ concat +                        [ [ storedRef $ acc ] +                        , [ storedRef $ signedData $ fromStored acc ] +                        , [ storedRef $ caKey $ fromStored $ signedData $ fromStored acc ] +                        , map storedRef $ signedSignature $ fromStored acc +                        ] +                _ -> writeTQueue (serverErrorLog $ peerServer peer) $ "unexpected channel request" + +handleChannelAccept :: UnifiedIdentity -> PartialRef -> PacketHandler () +handleChannelAccept identity accref = do +    peer <- gets phPeer +    liftSTM $ writeTQueue (serverIOActions $ peerServer peer) $ do +        withPeerIdentity peer $ \upid -> do +            copyRef (peerStorage peer) accref >>= \case +                Right acc -> do +                    ch <- acceptedChannel identity upid (wrappedLoad acc) +                    liftIO $ atomically $ do +                        sendToPeerS peer [] $ TransportPacket (TransportHeader [Acknowledged $ refDigest accref]) [] +                        finalizedChannel peer ch identity + +                Left dgst -> throwError $ "missing accept data " ++ BC.unpack (showRefDigest dgst) + + +finalizedChannel :: Peer -> Channel -> UnifiedIdentity -> STM () +finalizedChannel peer@Peer {..} ch self = do +    setPeerChannel peer $ ChannelEstablished ch + +    -- Identity update +    writeTQueue (serverIOActions peerServer_) $ liftIO $ atomically $ do +        let selfRef = refDigest $ storedRef $ idExtData $ self +            updateRefs = selfRef : map (refDigest . storedRef) (idUpdates self) +            ackedBy = concat [[ Acknowledged r, Rejected r, DataRequest r ] | r <- updateRefs ] +        sendToPeerS peer ackedBy $ flip TransportPacket [] $ TransportHeader $ map AnnounceUpdate updateRefs + +    -- Notify services about new peer +    readTVar peerIdentityVar >>= \case +        PeerIdentityFull _ -> notifyServicesOfPeer peer +        _ -> return () + + +handleIdentityAnnounce :: UnifiedIdentity -> Peer -> Ref -> WaitingRefCallback +handleIdentityAnnounce self peer ref = liftIO $ atomically $ do +    let validateAndUpdate upds act = case validateIdentity $ wrappedLoad ref of +            Just pid' -> do +                let pid = fromMaybe pid' $ toUnifiedIdentity (updateIdentity upds pid') +                writeTVar (peerIdentityVar peer) $ PeerIdentityFull pid +                writeTChan (serverChanPeer $ peerServer peer) peer +                act pid +                writeTQueue (serverIOActions $ peerServer peer) $ do +                    setupChannel self peer pid +            Nothing -> return () + +    readTVar (peerIdentityVar peer) >>= \case +        PeerIdentityRef wref wact +            | wrDigest wref == refDigest ref +            -> validateAndUpdate [] $ \pid -> do +                mapM_ (writeTQueue (serverIOActions $ peerServer peer) . ($ pid)) . +                    reverse =<< readTVar wact + +        PeerIdentityFull pid +            | idData pid `precedes` wrappedLoad ref +            -> validateAndUpdate (idUpdates pid) $ \_ -> do +                notifyServicesOfPeer peer + +        _ -> return () + +handleIdentityUpdate :: Peer -> Ref -> WaitingRefCallback +handleIdentityUpdate peer ref = liftIO $ atomically $ do +    pidentity <- readTVar (peerIdentityVar peer) +    if  | PeerIdentityFull pid <- pidentity +        , Just pid' <- toUnifiedIdentity $ updateIdentity [wrappedLoad ref] pid +        -> do +            writeTVar (peerIdentityVar peer) $ PeerIdentityFull pid' +            writeTChan (serverChanPeer $ peerServer peer) peer +            when (idData pid /= idData pid') $ notifyServicesOfPeer peer + +        | otherwise -> return () + +notifyServicesOfPeer :: Peer -> STM () +notifyServicesOfPeer peer@Peer { peerServer_ = Server {..} } = do +    writeTQueue serverIOActions $ do +        forM_ serverServices $ \service@(SomeService _ attrs) -> +            runPeerServiceOn (Just (service, attrs)) peer serviceNewPeer + + +mkPeer :: Server -> PeerAddress -> IO Peer +mkPeer peerServer_ peerAddress = do +    peerConnection <- newTVarIO (Left []) +    peerIdentityVar <- newTVarIO . PeerIdentityUnknown =<< newTVarIO [] +    peerStorage_ <- deriveEphemeralStorage $ serverStorage peerServer_ +    peerInStorage <- derivePartialStorage peerStorage_ +    peerServiceState <- newTMVarIO M.empty +    peerWaitingRefs <- newTMVarIO [] +    return Peer {..} + +serverPeer :: Server -> SockAddr -> IO Peer +serverPeer server paddr = do +    serverPeer' server (DatagramAddress paddr) + +#ifdef ENABLE_ICE_SUPPORT +serverPeerIce :: Server -> IceSession -> IO Peer +serverPeerIce server@Server {..} ice = do +    let paddr = PeerIceSession ice +    peer <- serverPeer' server paddr +    iceSetChan ice $ mapFlow undefined (paddr,) serverRawPath +    return peer +#endif + +serverPeer' :: Server -> PeerAddress -> IO Peer +serverPeer' server paddr = do +    (peer, hello) <- modifyMVar (serverPeers server) $ \pvalue -> do +        case M.lookup paddr pvalue of +             Just peer -> return (pvalue, (peer, False)) +             Nothing -> do +                 peer <- mkPeer server paddr +                 return (M.insert paddr peer pvalue, (peer, True)) +    when hello $ atomically $ do +        writeFlow (serverControlFlow server) (RequestConnection paddr) +    return peer + + +sendToPeer :: (Service s, MonadIO m) => Peer -> s -> m () +sendToPeer peer packet = sendToPeerList peer [ServiceReply (Left packet) True] + +sendToPeerStored :: (Service s, MonadIO m) => Peer -> Stored s -> m () +sendToPeerStored peer spacket = sendToPeerList peer [ServiceReply (Right spacket) True] + +sendToPeerList :: (Service s, MonadIO m) => Peer -> [ServiceReply s] -> m () +sendToPeerList peer parts = do +    let st = peerStorage peer +    srefs <- liftIO $ fmap catMaybes $ forM parts $ \case +        ServiceReply (Left x) use -> Just . (,use) <$> store st x +        ServiceReply (Right sx) use -> return $ Just (storedRef sx, use) +        ServiceFinally act -> act >> return Nothing +    let dgsts = map (refDigest . fst) srefs +    let content = map fst $ filter (\(ref, use) -> use && BL.length (lazyLoadBytes ref) < 500) srefs -- TODO: MTU +        header = TransportHeader (ServiceType (serviceID $ head parts) : map ServiceRef dgsts) +        packet = TransportPacket header content +        ackedBy = concat [[ Acknowledged r, Rejected r, DataRequest r ] | r <- dgsts ] +    liftIO $ atomically $ sendToPeerS peer ackedBy packet + +sendToPeerS' :: Bool -> Peer -> [TransportHeaderItem] -> TransportPacket Ref -> STM () +sendToPeerS' secure Peer {..} ackedBy packet = do +    readTVar peerConnection >>= \case +        Left xs -> writeTVar peerConnection $ Left $ (secure, packet, ackedBy) : xs +        Right conn -> writeFlow (connData conn) (secure, packet, ackedBy) + +sendToPeerS :: Peer -> [TransportHeaderItem] -> TransportPacket Ref -> STM () +sendToPeerS = sendToPeerS' True + +sendToPeerPlain :: Peer -> [TransportHeaderItem] -> TransportPacket Ref -> STM () +sendToPeerPlain = sendToPeerS' False + +sendToPeerWith :: forall s m. (Service s, MonadIO m, MonadError String m) => Peer -> (ServiceState s -> ExceptT String IO (Maybe s, ServiceState s)) -> m () +sendToPeerWith peer fobj = do +    let sproxy = Proxy @s +        sid = serviceID sproxy +    res <- liftIO $ do +        svcs <- atomically $ takeTMVar (peerServiceState peer) +        (svcs', res) <- runExceptT (fobj $ fromMaybe (emptyServiceState sproxy) $ fromServiceState sproxy =<< M.lookup sid svcs) >>= \case +            Right (obj, s') -> return $ (M.insert sid (SomeServiceState sproxy s') svcs, Right obj) +            Left err -> return $ (svcs, Left err) +        atomically $ putTMVar (peerServiceState peer) svcs' +        return res + +    case res of +         Right (Just obj) -> sendToPeer peer obj +         Right Nothing -> return () +         Left err -> throwError err + + +lookupService :: forall s. Service s => Proxy s -> [SomeService] -> Maybe (SomeService, ServiceAttributes s) +lookupService proxy (service@(SomeService (_ :: Proxy t) attr) : rest) +    | Just (Refl :: s :~: t) <- eqT = Just (service, attr) +    | otherwise = lookupService proxy rest +lookupService _ [] = Nothing + +runPeerService :: forall s m. (Service s, MonadIO m) => Peer -> ServiceHandler s () -> m () +runPeerService = runPeerServiceOn Nothing + +runPeerServiceOn :: forall s m. (Service s, MonadIO m) => Maybe (SomeService, ServiceAttributes s) -> Peer -> ServiceHandler s () -> m () +runPeerServiceOn mbservice peer handler = liftIO $ do +    let server = peerServer peer +        proxy = Proxy @s +        svc = serviceID proxy +        logd = writeTQueue (serverErrorLog server) +    case mbservice `mplus` lookupService proxy (serverServices server) of +        Just (service, attr) -> +            atomically (readTVar (peerIdentityVar peer)) >>= \case +                PeerIdentityFull peerId -> do +                    (global, svcs) <- atomically $ (,) +                        <$> takeTMVar (serverServiceStates server) +                        <*> takeTMVar (peerServiceState peer) +                    case (fromMaybe (someServiceEmptyState service) $ M.lookup svc svcs, +                            fromMaybe (someServiceEmptyGlobalState service) $ M.lookup svc global) of +                        ((SomeServiceState (_ :: Proxy ps) ps), +                                (SomeServiceGlobalState (_ :: Proxy gs) gs)) -> do +                            Just (Refl :: s :~: ps) <- return $ eqT +                            Just (Refl :: s :~: gs) <- return $ eqT + +                            let inp = ServiceInput +                                    { svcAttributes = attr +                                    , svcPeer = peer +                                    , svcPeerIdentity = peerId +                                    , svcServer = server +                                    , svcPrintOp = atomically . logd +                                    } +                            reloadHead (serverOrigHead server) >>= \case +                                Nothing -> atomically $ do +                                    logd $ "current head deleted" +                                    putTMVar (peerServiceState peer) svcs +                                    putTMVar (serverServiceStates server) global +                                Just h -> do +                                    (rsp, (s', gs')) <- runServiceHandler h inp ps gs handler +                                    moveKeys (peerStorage peer) (serverStorage server) +                                    when (not (null rsp)) $ do +                                        sendToPeerList peer rsp +                                    atomically $ do +                                        putTMVar (peerServiceState peer) $ M.insert svc (SomeServiceState proxy s') svcs +                                        putTMVar (serverServiceStates server) $ M.insert svc (SomeServiceGlobalState proxy gs') global +                _ -> do +                    atomically $ logd $ "can't run service handler on peer with incomplete identity " ++ show (peerAddress peer) + +        _ -> atomically $ do +            logd $ "unhandled service '" ++ show (toUUID svc) ++ "'" + + +foreign import ccall unsafe "Network/ifaddrs.h broadcast_addresses" cBroadcastAddresses :: IO (Ptr Word32) +foreign import ccall unsafe "stdlib.h free" cFree :: Ptr Word32 -> IO () + +getBroadcastAddresses :: PortNumber -> IO [SockAddr] +getBroadcastAddresses port = do +    ptr <- cBroadcastAddresses +    let parse i = do +            w <- peekElemOff ptr i +            if w == 0 then return [] +                      else (SockAddrInet port w:) <$> parse (i + 1) +    addrs <- parse 0 +    cFree ptr +    return addrs diff --git a/src/Erebos/Network.hs-boot b/src/Erebos/Network.hs-boot new file mode 100644 index 0000000..849bfc1 --- /dev/null +++ b/src/Erebos/Network.hs-boot @@ -0,0 +1,8 @@ +module Erebos.Network where + +import Erebos.Storage + +data Server +data Peer + +peerStorage :: Peer -> Storage diff --git a/src/Erebos/Network/Protocol.hs b/src/Erebos/Network/Protocol.hs new file mode 100644 index 0000000..e5eb652 --- /dev/null +++ b/src/Erebos/Network/Protocol.hs @@ -0,0 +1,753 @@ +module Erebos.Network.Protocol ( +    TransportPacket(..), +    transportToObject, +    TransportHeader(..), +    TransportHeaderItem(..), + +    WaitingRef(..), +    WaitingRefCallback, +    wrDigest, + +    ChannelState(..), + +    ControlRequest(..), +    ControlMessage(..), +    erebosNetworkProtocol, + +    Connection, +    connAddress, +    connData, +    connGetChannel, +    connSetChannel, + +    RawStreamReader, RawStreamWriter, +    connAddWriteStream, +    connAddReadStream, +    readStreamToList, +    readObjectsFromStream, +    writeByteStringToStream, + +    module Erebos.Flow, +) where + +import Control.Applicative +import Control.Concurrent +import Control.Concurrent.Async +import Control.Concurrent.STM +import Control.Monad +import Control.Monad.Except + +import Data.Bits +import Data.ByteString (ByteString) +import Data.ByteString qualified as B +import Data.ByteString.Char8 qualified as BC +import Data.ByteString.Lazy qualified as BL +import Data.List +import Data.Maybe +import Data.Text (Text) +import Data.Text qualified as T +import Data.Void +import Data.Word + +import System.Clock + +import Erebos.Channel +import Erebos.Flow +import Erebos.Identity +import Erebos.Service +import Erebos.Storage + + +protocolVersion :: Text +protocolVersion = T.pack "0.1" + +protocolVersions :: [Text] +protocolVersions = [protocolVersion] + + +data TransportPacket a = TransportPacket TransportHeader [a] + +data TransportHeader = TransportHeader [TransportHeaderItem] +    deriving (Show) + +data TransportHeaderItem +    = Acknowledged RefDigest +    | AcknowledgedSingle Integer +    | Rejected RefDigest +    | ProtocolVersion Text +    | Initiation RefDigest +    | CookieSet Cookie +    | CookieEcho Cookie +    | DataRequest RefDigest +    | DataResponse RefDigest +    | AnnounceSelf RefDigest +    | AnnounceUpdate RefDigest +    | TrChannelRequest RefDigest +    | TrChannelAccept RefDigest +    | ServiceType ServiceID +    | ServiceRef RefDigest +    | StreamOpen Word8 +    deriving (Eq, Show) + +newtype Cookie = Cookie ByteString +    deriving (Eq, Show) + +isHeaderItemAcknowledged :: TransportHeaderItem -> Bool +isHeaderItemAcknowledged = \case +    Acknowledged {} -> False +    AcknowledgedSingle {} -> False +    Rejected {} -> False +    ProtocolVersion {} -> False +    Initiation {} -> False +    CookieSet {} -> False +    CookieEcho {} -> False +    _ -> True + +transportToObject :: PartialStorage -> TransportHeader -> PartialObject +transportToObject st (TransportHeader items) = Rec $ map single items +    where single = \case +              Acknowledged dgst -> (BC.pack "ACK", RecRef $ partialRefFromDigest st dgst) +              AcknowledgedSingle num -> (BC.pack "ACK", RecInt num) +              Rejected dgst -> (BC.pack "REJ", RecRef $ partialRefFromDigest st dgst) +              ProtocolVersion ver -> (BC.pack "VER", RecText ver) +              Initiation dgst -> (BC.pack "INI", RecRef $ partialRefFromDigest st dgst) +              CookieSet (Cookie bytes) -> (BC.pack "CKS", RecBinary bytes) +              CookieEcho (Cookie bytes) -> (BC.pack "CKE", RecBinary bytes) +              DataRequest dgst -> (BC.pack "REQ", RecRef $ partialRefFromDigest st dgst) +              DataResponse dgst -> (BC.pack "RSP", RecRef $ partialRefFromDigest st dgst) +              AnnounceSelf dgst -> (BC.pack "ANN", RecRef $ partialRefFromDigest st dgst) +              AnnounceUpdate dgst -> (BC.pack "ANU", RecRef $ partialRefFromDigest st dgst) +              TrChannelRequest dgst -> (BC.pack "CRQ", RecRef $ partialRefFromDigest st dgst) +              TrChannelAccept dgst -> (BC.pack "CAC", RecRef $ partialRefFromDigest st dgst) +              ServiceType stype -> (BC.pack "SVT", RecUUID $ toUUID stype) +              ServiceRef dgst -> (BC.pack "SVR", RecRef $ partialRefFromDigest st dgst) +              StreamOpen num -> (BC.pack "STO", RecInt $ fromIntegral num) + +transportFromObject :: PartialObject -> Maybe TransportHeader +transportFromObject (Rec items) = case catMaybes $ map single items of +                                       [] -> Nothing +                                       titems -> Just $ TransportHeader titems +    where single (name, content) = if +              | name == BC.pack "ACK", RecRef ref <- content -> Just $ Acknowledged $ refDigest ref +              | name == BC.pack "ACK", RecInt num <- content -> Just $ AcknowledgedSingle num +              | name == BC.pack "REJ", RecRef ref <- content -> Just $ Rejected $ refDigest ref +              | name == BC.pack "VER", RecText ver <- content -> Just $ ProtocolVersion ver +              | name == BC.pack "INI", RecRef ref <- content -> Just $ Initiation $ refDigest ref +              | name == BC.pack "CKS", RecBinary bytes <- content -> Just $ CookieSet (Cookie bytes) +              | name == BC.pack "CKE", RecBinary bytes <- content -> Just $ CookieEcho (Cookie bytes) +              | name == BC.pack "REQ", RecRef ref <- content -> Just $ DataRequest $ refDigest ref +              | name == BC.pack "RSP", RecRef ref <- content -> Just $ DataResponse $ refDigest ref +              | name == BC.pack "ANN", RecRef ref <- content -> Just $ AnnounceSelf $ refDigest ref +              | name == BC.pack "ANU", RecRef ref <- content -> Just $ AnnounceUpdate $ refDigest ref +              | name == BC.pack "CRQ", RecRef ref <- content -> Just $ TrChannelRequest $ refDigest ref +              | name == BC.pack "CAC", RecRef ref <- content -> Just $ TrChannelAccept $ refDigest ref +              | name == BC.pack "SVT", RecUUID uuid <- content -> Just $ ServiceType $ fromUUID uuid +              | name == BC.pack "SVR", RecRef ref <- content -> Just $ ServiceRef $ refDigest ref +              | name == BC.pack "STO", RecInt num <- content -> Just $ StreamOpen $ fromIntegral num +              | otherwise -> Nothing +transportFromObject _ = Nothing + + +data GlobalState addr = (Eq addr, Show addr) => GlobalState +    { gIdentity :: TVar (UnifiedIdentity, [UnifiedIdentity]) +    , gConnections :: TVar [Connection addr] +    , gDataFlow :: SymFlow (addr, ByteString) +    , gControlFlow :: Flow (ControlRequest addr) (ControlMessage addr) +    , gNextUp :: TMVar (Connection addr, (Bool, TransportPacket PartialObject)) +    , gLog :: String -> STM () +    , gStorage :: PartialStorage +    , gNowVar :: TVar TimeSpec +    , gNextTimeout :: TVar TimeSpec +    , gInitConfig :: Ref +    } + +data Connection addr = Connection +    { cGlobalState :: GlobalState addr +    , cAddress :: addr +    , cDataUp :: Flow (Bool, TransportPacket PartialObject) (Bool, TransportPacket Ref, [TransportHeaderItem]) +    , cDataInternal :: Flow (Bool, TransportPacket Ref, [TransportHeaderItem]) (Bool, TransportPacket PartialObject) +    , cChannel :: TVar ChannelState +    , cSecureOutQueue :: TQueue (Bool, TransportPacket Ref, [TransportHeaderItem]) +    , cSentPackets :: TVar [SentPacket] +    , cToAcknowledge :: TVar [Integer] +    , cInStreams :: TVar [(Word8, Stream)] +    , cOutStreams :: TVar [(Word8, Stream)] +    } + +connAddress :: Connection addr -> addr +connAddress = cAddress + +connData :: Connection addr -> Flow (Bool, TransportPacket PartialObject) (Bool, TransportPacket Ref, [TransportHeaderItem]) +connData = cDataUp + +connGetChannel :: Connection addr -> STM ChannelState +connGetChannel Connection {..} = readTVar cChannel + +connSetChannel :: Connection addr -> ChannelState -> STM () +connSetChannel Connection {..} ch = do +    writeTVar cChannel ch + +connAddWriteStream :: Connection addr -> STM (TransportHeaderItem, RawStreamWriter, IO ()) +connAddWriteStream conn@Connection {..} = do +    let GlobalState {..} = cGlobalState + +    outStreams <- readTVar cOutStreams +    let doInsert n (s@(n', _) : rest) | n == n' = +            fmap (s:) <$> doInsert (n + 1) rest +        doInsert n streams = do +            (sFlowIn, sFlowOut) <- newFlow +            sNextSequence <- newTVar 0 +            let info = (n, Stream {..}) +            return (info, info : streams) +    ((streamNumber, stream), outStreams') <- doInsert 1 outStreams +    writeTVar cOutStreams outStreams' + +    let go = do +            msg <- atomically $ readFlow (sFlowOut stream) +            let (plain, cont) = case msg of +                    StreamData {..} -> (stpData, True) +                    StreamClosed {} -> (BC.empty, False) +                    -- TODO: send channel closed only after delivering all previous data packets +                    -- TODO: free channel number after delivering stream closed +            let secure = True +                plainAckedBy = [] + +            mbch <- atomically (readTVar cChannel) >>= return . \case +                ChannelEstablished ch   -> Just ch +                ChannelOurAccept _ _ ch -> Just ch +                _                       -> Nothing + +            mbs <- case mbch of +                Just ch -> do +                    runExceptT (channelEncrypt ch $ B.concat +                            [ B.singleton streamNumber +                            , B.singleton (fromIntegral (stpSequence msg) :: Word8) +                            , plain +                            ] ) >>= \case +                        Right (ctext, counter) -> do +                            let isAcked = True +                            return $ Just (0x80 `B.cons` ctext, if isAcked then [ AcknowledgedSingle $ fromIntegral counter ] else []) +                        Left err -> do atomically $ gLog $ "Failed to encrypt data: " ++ err +                                       return Nothing +                Nothing | secure    -> return Nothing +                        | otherwise -> return $ Just (plain, plainAckedBy) + +            case mbs of +                Just (bs, ackedBy) -> sendBytes conn bs $ guard (not $ null ackedBy) >> Just (`elem` ackedBy) +                Nothing -> return () + +            when cont go + +    return (StreamOpen streamNumber, sFlowIn stream, go) + +connAddReadStream :: Connection addr -> Word8 -> STM RawStreamReader +connAddReadStream Connection {..} streamNumber = do +    inStreams <- readTVar cInStreams +    let doInsert (s@(n, _) : rest) +            | streamNumber <  n = fmap (s:) <$> doInsert rest +            | streamNumber == n = doInsert rest +        doInsert streams = do +            (sFlowIn, sFlowOut) <- newFlow +            sNextSequence <- newTVar 0 +            let stream = Stream {..} +            return (stream, (streamNumber, stream) : streams) +    (stream, inStreams') <- doInsert inStreams +    writeTVar cInStreams inStreams' +    return $ sFlowOut stream + + +type RawStreamReader = Flow StreamPacket Void +type RawStreamWriter = Flow Void StreamPacket + +data Stream = Stream +    { sFlowIn :: Flow Void StreamPacket +    , sFlowOut :: Flow StreamPacket Void +    , sNextSequence :: TVar Word64 +    } + +data StreamPacket +    = StreamData +        { stpSequence :: Word64 +        , stpData :: BC.ByteString +        } +    | StreamClosed +        { stpSequence :: Word64 +        } + +readStreamToList :: RawStreamReader -> IO (Word64, [(Word64, BC.ByteString)]) +readStreamToList stream = readFlowIO stream >>= \case +    StreamData sq bytes -> fmap ((sq, bytes) :) <$> readStreamToList stream +    StreamClosed sqEnd  -> return (sqEnd, []) + +readObjectsFromStream :: PartialStorage -> RawStreamReader -> IO (Except String [PartialObject]) +readObjectsFromStream st stream = do +    (seqEnd, list) <- readStreamToList stream +    print (seqEnd, length list, list) +    let validate s ((s', bytes) : rest) +            | s == s'   = (bytes : ) <$> validate (s + 1) rest +            | s >  s'   = validate s rest +            | otherwise = throwError "missing object chunk" +        validate s [] +            | s == seqEnd = return [] +            | otherwise = throwError "content length mismatch" +    return $ do +        content <- BL.fromChunks <$> validate 0 list +        deserializeObjects st content + +writeByteStringToStream :: RawStreamWriter -> BL.ByteString -> IO () +writeByteStringToStream stream = go 0 +  where +    go seqNum bstr +        | BL.null bstr = writeFlowIO stream $ StreamClosed seqNum +        | otherwise    = do +            let (cur, rest) = BL.splitAt 500 bstr -- TODO: MTU +            writeFlowIO stream $ StreamData seqNum (BL.toStrict cur) +            go (seqNum + 1) rest + + +data WaitingRef = WaitingRef +    { wrefStorage :: Storage +    , wrefPartial :: PartialRef +    , wrefAction :: Ref -> WaitingRefCallback +    , wrefStatus :: TVar (Either [RefDigest] Ref) +    } + +type WaitingRefCallback = ExceptT String IO () + +wrDigest :: WaitingRef -> RefDigest +wrDigest = refDigest . wrefPartial + + +data ChannelState = ChannelNone +                  | ChannelCookieWait -- sent initiation, waiting for response +                  | ChannelCookieReceived Cookie -- received cookie, but no cookie echo yet +                  | ChannelCookieConfirmed Cookie -- received cookie echo, no need to send from our side +                  | ChannelOurRequest (Maybe Cookie) (Stored ChannelRequest) +                  | ChannelPeerRequest (Maybe Cookie) WaitingRef +                  | ChannelOurAccept (Maybe Cookie) (Stored ChannelAccept) Channel +                  | ChannelEstablished Channel + + +data SentPacket = SentPacket +    { spTime :: TimeSpec +    , spRetryCount :: Int +    , spAckedBy :: Maybe (TransportHeaderItem -> Bool) +    , spData :: BC.ByteString +    } + + +data ControlRequest addr = RequestConnection addr +                         | SendAnnounce addr +                         | UpdateSelfIdentity UnifiedIdentity + +data ControlMessage addr = NewConnection (Connection addr) (Maybe RefDigest) +                         | ReceivedAnnounce addr RefDigest + + +erebosNetworkProtocol :: (Eq addr, Ord addr, Show addr) +                      => UnifiedIdentity +                      -> (String -> STM ()) +                      -> SymFlow (addr, ByteString) +                      -> Flow (ControlRequest addr) (ControlMessage addr) +                      -> IO () +erebosNetworkProtocol initialIdentity gLog gDataFlow gControlFlow = do +    gIdentity <- newTVarIO (initialIdentity, []) +    gConnections <- newTVarIO [] +    gNextUp <- newEmptyTMVarIO +    mStorage <- memoryStorage +    gStorage <- derivePartialStorage mStorage + +    startTime <- getTime MonotonicRaw +    gNowVar <- newTVarIO startTime +    gNextTimeout <- newTVarIO startTime +    gInitConfig <- store mStorage $ (Rec [] :: Object) + +    let gs = GlobalState {..} + +    let signalTimeouts = forever $ do +            now <- getTime MonotonicRaw +            next <- atomically $ do +                writeTVar gNowVar now +                readTVar gNextTimeout + +            let waitTill time +                    | time > now = threadDelay $ fromInteger (toNanoSecs (time - now)) `div` 1000 +                    | otherwise = threadDelay maxBound +                waitForUpdate = atomically $ do +                    next' <- readTVar gNextTimeout +                    when (next' == next) retry + +            race_ (waitTill next) waitForUpdate + +    race_ signalTimeouts $ forever $ join $ atomically $ +        passUpIncoming gs <|> processIncoming gs <|> processOutgoing gs + + +getConnection :: GlobalState addr -> addr -> STM (Connection addr) +getConnection gs addr = do +    maybe (newConnection gs addr) return =<< findConnection gs addr + +findConnection :: GlobalState addr -> addr -> STM (Maybe (Connection addr)) +findConnection GlobalState {..} addr = do +    find ((addr==) . cAddress) <$> readTVar gConnections + +newConnection :: GlobalState addr -> addr -> STM (Connection addr) +newConnection cGlobalState@GlobalState {..} addr = do +    conns <- readTVar gConnections + +    let cAddress = addr +    (cDataUp, cDataInternal) <- newFlow +    cChannel <- newTVar ChannelNone +    cSecureOutQueue <- newTQueue +    cSentPackets <- newTVar [] +    cToAcknowledge <- newTVar [] +    cInStreams <- newTVar [] +    cOutStreams <- newTVar [] +    let conn = Connection {..} + +    writeTVar gConnections (conn : conns) +    return conn + +passUpIncoming :: GlobalState addr -> STM (IO ()) +passUpIncoming GlobalState {..} = do +    (Connection {..}, up) <- takeTMVar gNextUp +    writeFlow cDataInternal up +    return $ return () + +processIncoming :: GlobalState addr -> STM (IO ()) +processIncoming gs@GlobalState {..} = do +    guard =<< isEmptyTMVar gNextUp +    guard =<< canWriteFlow gControlFlow + +    (addr, msg) <- readFlow gDataFlow +    mbconn <- findConnection gs addr + +    mbch <- case mbconn of +        Nothing -> return Nothing +        Just conn -> readTVar (cChannel conn) >>= return . \case +            ChannelEstablished ch   -> Just ch +            ChannelOurAccept _ _ ch -> Just ch +            _                       -> Nothing + +    return $ do +        let deserialize = liftEither . runExcept . deserializeObjects gStorage . BL.fromStrict +        let parse = case B.uncons msg of +                Just (b, enc) +                    | b .&. 0xE0 == 0x80 -> do +                        ch <- maybe (throwError "unexpected encrypted packet") return mbch +                        (dec, counter) <- channelDecrypt ch enc + +                        case B.uncons dec of +                            Just (0x00, content) -> do +                                objs <- deserialize content +                                return $ Left (True, objs, Just counter) + +                            Just (snum, dec') +                                | snum < 64 +                                , Just (seq8, content) <- B.uncons dec' +                                -> do +                                    return $ Right (snum, seq8, content, counter) + +                            Just (_, _) -> do +                                throwError "unexpected stream header" + +                            Nothing -> do +                                throwError "empty decrypted content" + +                    | b .&. 0xE0 == 0x60 -> do +                        objs <- deserialize msg +                        return $ Left (False, objs, Nothing) + +                    | otherwise -> throwError "invalid packet" + +                Nothing -> throwError "empty packet" + +        runExceptT parse >>= \case +            Right (Left (secure, objs, mbcounter)) +                | hobj:content <- objs +                , Just header@(TransportHeader items) <- transportFromObject hobj +                -> processPacket gs (maybe (Left addr) Right mbconn) secure (TransportPacket header content) >>= \case +                    Just (conn@Connection {..}, mbup) -> atomically $ do +                        case mbcounter of +                            Just counter | any isHeaderItemAcknowledged items -> +                                modifyTVar' cToAcknowledge (fromIntegral counter :) +                            _ -> return () +                        processAcknowledgements gs conn items +                        case mbup of +                            Just up -> putTMVar gNextUp (conn, (secure, up)) +                            Nothing -> return () +                    Nothing -> return () + +                | otherwise -> atomically $ do +                      gLog $ show addr ++ ": invalid objects" +                      gLog $ show objs + +            Right (Right (snum, seq8, content, counter)) +                | Just Connection {..} <- mbconn +                -> atomically $ do +                    (lookup snum <$> readTVar cInStreams) >>= \case +                        Nothing -> +                            gLog $ "unexpected stream number " ++ show snum + +                        Just Stream {..} -> do +                            expectedSequence <- readTVar sNextSequence +                            let seqFull = expectedSequence - 0x80 + fromIntegral (seq8 - fromIntegral expectedSequence + 0x80 :: Word8) +                            sdata <- if +                                | B.null content -> do +                                    modifyTVar' cInStreams $ filter ((/=snum) . fst) +                                    return $ StreamClosed seqFull +                                | otherwise -> do +                                    writeTVar sNextSequence $ max expectedSequence (seqFull + 1) +                                    return $ StreamData seqFull content +                            writeFlow sFlowIn sdata +                            modifyTVar' cToAcknowledge (fromIntegral counter :) + +                | otherwise -> do +                    atomically $ gLog $ show addr <> ": stream packet without connection" + +            Left err -> do +                atomically $ gLog $ show addr <> ": failed to parse packet: " <> err + +processPacket :: GlobalState addr -> Either addr (Connection addr) -> Bool -> TransportPacket a -> IO (Maybe (Connection addr, Maybe (TransportPacket a))) +processPacket gs@GlobalState {..} econn secure packet@(TransportPacket (TransportHeader header) _) = if +    -- Established secure communication +    | Right conn <- econn, secure +    -> return $ Just (conn, Just packet) + +    -- Plaintext communication with cookies to prove origin +    | cookie:_ <- mapMaybe (\case CookieEcho x -> Just x; _ -> Nothing) header +    -> verifyCookie gs addr cookie >>= \case +        True -> do +            atomically $ do +                conn@Connection {..} <- getConnection gs addr +                channel <- readTVar cChannel +                let received = listToMaybe $ mapMaybe (\case CookieSet x -> Just x; _ -> Nothing) header +                case received `mplus` channelCurrentCookie channel of +                    Just current -> do +                        cookieEchoReceived gs conn mbpid current +                        return $ Just (conn, Just packet) +                    Nothing -> do +                        gLog $ show addr <> ": missing cookie set, dropping " <> show header +                        return $ Nothing + +        False -> do +            atomically $ gLog $ show addr <> ": cookie verification failed, dropping " <> show header +            return Nothing + +    -- Response to initiation packet +    | cookie:_ <- mapMaybe (\case CookieSet x -> Just x; _ -> Nothing) header +    , Just _ <- version +    , Right conn@Connection {..} <- econn +    -> do +        atomically $ readTVar cChannel >>= \case +            ChannelCookieWait -> do +                writeTVar cChannel $ ChannelCookieReceived cookie +                writeFlow gControlFlow (NewConnection conn mbpid) +                return $ Just (conn, Nothing) +            _ -> return Nothing + +    -- Initiation packet +    | _:_ <- mapMaybe (\case Initiation x -> Just x; _ -> Nothing) header +    , Just ver <- version +    -> do +        cookie <- createCookie gs addr +        atomically $ do +            identity <- fst <$> readTVar gIdentity +            let reply = BL.toStrict $ serializeObject $ transportToObject gStorage $ TransportHeader +                    [ CookieSet cookie +                    , AnnounceSelf $ refDigest $ storedRef $ idData identity +                    , ProtocolVersion ver +                    ] +            writeFlow gDataFlow (addr, reply) +        return Nothing + +    -- Announce packet outside any connection +    | dgst:_ <- mapMaybe (\case AnnounceSelf x -> Just x; _ -> Nothing) header +    , Just _ <- version +    -> do +        atomically $ do +            (cur, past) <- readTVar gIdentity +            when (not $ dgst `elem` map (refDigest . storedRef . idData) (cur : past)) $ do +                writeFlow gControlFlow $ ReceivedAnnounce addr dgst +        return Nothing + +    | otherwise -> do +        atomically $ gLog $ show addr <> ": dropping packet " <> show header +        return Nothing + +  where +    addr = either id cAddress econn +    mbpid = listToMaybe $ mapMaybe (\case AnnounceSelf dgst -> Just dgst; _ -> Nothing) header +    version = listToMaybe $ filter (\v -> ProtocolVersion v `elem` header) protocolVersions + +channelCurrentCookie :: ChannelState -> Maybe Cookie +channelCurrentCookie = \case +    ChannelCookieReceived cookie -> Just cookie +    ChannelCookieConfirmed cookie -> Just cookie +    ChannelOurRequest mbcookie _ -> mbcookie +    ChannelPeerRequest mbcookie _ -> mbcookie +    ChannelOurAccept mbcookie _ _ -> mbcookie +    _ -> Nothing + +cookieEchoReceived :: GlobalState addr -> Connection addr -> Maybe RefDigest -> Cookie -> STM () +cookieEchoReceived GlobalState {..} conn@Connection {..} mbpid cookieSet = do +    readTVar cChannel >>= \case +        ChannelNone -> newConn +        ChannelCookieWait -> newConn +        ChannelCookieReceived {} -> update +        _ -> return () +  where +    update = do +        writeTVar cChannel $ ChannelCookieConfirmed cookieSet +    newConn = do +        update +        writeFlow gControlFlow (NewConnection conn mbpid) + +generateCookieHeaders :: GlobalState addr -> addr -> ChannelState -> IO [TransportHeaderItem] +generateCookieHeaders gs addr ch = catMaybes <$> sequence [ echoHeader, setHeader ] +  where +    echoHeader = return $ CookieEcho <$> channelCurrentCookie ch +    setHeader = case ch of +        ChannelCookieWait {} -> Just . CookieSet <$> createCookie gs addr +        ChannelCookieReceived {} -> Just . CookieSet <$> createCookie gs addr +        _ -> return Nothing + +createCookie :: GlobalState addr -> addr -> IO Cookie +createCookie GlobalState {} addr = return (Cookie $ BC.pack $ show addr) + +verifyCookie :: GlobalState addr -> addr -> Cookie -> IO Bool +verifyCookie GlobalState {} addr (Cookie cookie) = return $ show addr == BC.unpack cookie + +resendBytes :: Connection addr -> SentPacket -> IO () +resendBytes Connection {..} sp = do +    let GlobalState {..} = cGlobalState +    now <- getTime MonotonicRaw +    atomically $ do +        when (isJust $ spAckedBy sp) $ do +            modifyTVar' cSentPackets $ (:) sp +                { spTime = now +                , spRetryCount = spRetryCount sp + 1 +                } +        writeFlow gDataFlow (cAddress, spData sp) + +sendBytes :: Connection addr -> ByteString -> Maybe (TransportHeaderItem -> Bool) -> IO () +sendBytes conn bs ackedBy = resendBytes conn +    SentPacket +        { spTime = undefined +        , spRetryCount = -1 +        , spAckedBy = ackedBy +        , spData = bs +        } + +processOutgoing :: forall addr. GlobalState addr -> STM (IO ()) +processOutgoing gs@GlobalState {..} = do + +    let sendNextPacket :: Connection addr -> STM (IO ()) +        sendNextPacket conn@Connection {..} = do +            channel <- readTVar cChannel +            let mbch = case channel of +                    ChannelEstablished ch -> Just ch +                    _                     -> Nothing + +            let checkOutstanding +                    | isJust mbch = readTQueue cSecureOutQueue +                    | otherwise = retry + +                checkAcknowledgements +                    | isJust mbch = do +                        acks <- readTVar cToAcknowledge +                        if null acks then retry +                                     else return (True, TransportPacket (TransportHeader []) [], []) +                    | otherwise = retry + +            (secure, packet@(TransportPacket (TransportHeader hitems) content), plainAckedBy) <- +                checkOutstanding <|> readFlow cDataInternal <|> checkAcknowledgements + +            when (isNothing mbch && secure) $ do +                writeTQueue cSecureOutQueue (secure, packet, plainAckedBy) + +            acknowledge <- case mbch of +                Nothing -> return [] +                Just _ -> swapTVar cToAcknowledge [] + +            return $ do +                cookieHeaders <- generateCookieHeaders gs cAddress channel +                let header = TransportHeader $ map AcknowledgedSingle acknowledge ++ cookieHeaders ++ hitems + +                let plain = BL.concat $ +                        (serializeObject $ transportToObject gStorage header) +                        : map lazyLoadBytes content + +                mbs <- case mbch of +                    Just ch -> do +                        runExceptT (channelEncrypt ch $ BL.toStrict $ 0x00 `BL.cons` plain) >>= \case +                            Right (ctext, counter) -> do +                                let isAcked = any isHeaderItemAcknowledged hitems +                                return $ Just (0x80 `B.cons` ctext, if isAcked then [ AcknowledgedSingle $ fromIntegral counter ] else []) +                            Left err -> do atomically $ gLog $ "Failed to encrypt data: " ++ err +                                           return Nothing +                    Nothing | secure    -> return Nothing +                            | otherwise -> return $ Just (BL.toStrict plain, plainAckedBy) + +                case mbs of +                    Just (bs, ackedBy) -> sendBytes conn bs $ guard (not $ null ackedBy) >> Just (`elem` ackedBy) +                    Nothing -> return () + +    let retransmitPacket :: Connection addr -> STM (IO ()) +        retransmitPacket conn@Connection {..} = do +            now <- readTVar gNowVar +            (sp, rest) <- readTVar cSentPackets >>= \case +                sps@(_:_) -> return (last sps, init sps) +                _         -> retry +            let nextTry = spTime sp + fromNanoSecs 1000000000 +            if now < nextTry +              then do +                nextTimeout <- readTVar gNextTimeout +                if nextTimeout <= now || nextTry < nextTimeout +                   then do writeTVar gNextTimeout nextTry +                           return $ return () +                   else retry +              else do +                writeTVar cSentPackets rest +                return $ resendBytes conn sp + +    let handleControlRequests = readFlow gControlFlow >>= \case +            RequestConnection addr -> do +                conn@Connection {..} <- getConnection gs addr +                identity <- fst <$> readTVar gIdentity +                readTVar cChannel >>= \case +                    ChannelNone -> do +                        let packet = BL.toStrict $ BL.concat +                                [ serializeObject $ transportToObject gStorage $ TransportHeader $ +                                    [ Initiation $ refDigest gInitConfig +                                    , AnnounceSelf $ refDigest $ storedRef $ idData identity +                                    ] ++ map ProtocolVersion protocolVersions +                                , lazyLoadBytes gInitConfig +                                ] +                        writeTVar cChannel ChannelCookieWait +                        return $ sendBytes conn packet $ Just $ \case CookieSet {} -> True; _ -> False +                    _ -> return $ return () + +            SendAnnounce addr -> do +                identity <- fst <$> readTVar gIdentity +                let packet = BL.toStrict $ serializeObject $ transportToObject gStorage $ TransportHeader $ +                        [ AnnounceSelf $ refDigest $ storedRef $ idData identity +                        ] ++ map ProtocolVersion protocolVersions +                writeFlow gDataFlow (addr, packet) +                return $ return () + +            UpdateSelfIdentity nid -> do +                (cur, past) <- readTVar gIdentity +                writeTVar gIdentity (nid, cur : past) +                return $ return () + +    conns <- readTVar gConnections +    msum $ concat $ +        [ map retransmitPacket conns +        , map sendNextPacket conns +        , [ handleControlRequests ] +        ] + +processAcknowledgements :: GlobalState addr -> Connection addr -> [TransportHeaderItem] -> STM () +processAcknowledgements GlobalState {} Connection {..} = mapM_ $ \hitem -> do +    modifyTVar' cSentPackets $ filter $ \sp -> not (fromJust (spAckedBy sp) hitem) diff --git a/src/Erebos/Network/ifaddrs.c b/src/Erebos/Network/ifaddrs.c new file mode 100644 index 0000000..37c3e00 --- /dev/null +++ b/src/Erebos/Network/ifaddrs.c @@ -0,0 +1,41 @@ +#include "ifaddrs.h" + +#include <arpa/inet.h> +#include <ifaddrs.h> +#include <net/if.h> +#include <stdlib.h> +#include <sys/types.h> +#include <endian.h> + +uint32_t * broadcast_addresses(void) +{ +	struct ifaddrs * addrs; +	if (getifaddrs(&addrs) < 0) +		return 0; + +	size_t capacity = 16, count = 0; +	uint32_t * ret = malloc(sizeof(uint32_t) * capacity); + +	for (struct ifaddrs * ifa = addrs; ifa; ifa = ifa->ifa_next) { +		if (ifa->ifa_addr && ifa->ifa_addr->sa_family == AF_INET && +				ifa->ifa_flags & IFF_BROADCAST) { +			if (count + 2 >= capacity) { +				capacity *= 2; +				uint32_t * nret = realloc(ret, sizeof(uint32_t) * capacity); +				if (nret) { +					ret = nret; +				} else { +					free(ret); +					return 0; +				} +			} + +			ret[count] = ((struct sockaddr_in*)ifa->ifa_broadaddr)->sin_addr.s_addr; +			count++; +		} +	} + +	freeifaddrs(addrs); +	ret[count] = 0; +	return ret; +} diff --git a/src/Erebos/Network/ifaddrs.h b/src/Erebos/Network/ifaddrs.h new file mode 100644 index 0000000..06d26ec --- /dev/null +++ b/src/Erebos/Network/ifaddrs.h @@ -0,0 +1,3 @@ +#include <stdint.h> + +uint32_t * broadcast_addresses(void); diff --git a/src/Erebos/Pairing.hs b/src/Erebos/Pairing.hs new file mode 100644 index 0000000..2166e71 --- /dev/null +++ b/src/Erebos/Pairing.hs @@ -0,0 +1,242 @@ +module Erebos.Pairing ( +    PairingService(..), +    PairingState(..), +    PairingAttributes(..), +    PairingResult(..), +    PairingFailureReason(..), + +    pairingRequest, +    pairingAccept, +    pairingReject, +) where + +import Control.Monad +import Control.Monad.Except +import Control.Monad.Reader + +import Crypto.Random + +import Data.Bits +import Data.ByteArray (Bytes, convert) +import qualified Data.ByteArray as BA +import qualified Data.ByteString.Char8 as BC +import Data.Kind +import Data.Maybe +import Data.Typeable +import Data.Word + +import Erebos.Identity +import Erebos.Network +import Erebos.PubKey +import Erebos.Service +import Erebos.State +import Erebos.Storage + +data PairingService a = PairingRequest (Stored (Signed IdentityData)) (Stored (Signed IdentityData)) RefDigest +                      | PairingResponse Bytes +                      | PairingRequestNonce Bytes +                      | PairingAccept a +                      | PairingReject + +data PairingState a = NoPairing +                    | OurRequest UnifiedIdentity UnifiedIdentity Bytes +                    | OurRequestConfirm (Maybe (PairingVerifiedResult a)) +                    | OurRequestReady +                    | PeerRequest UnifiedIdentity UnifiedIdentity Bytes RefDigest +                    | PeerRequestConfirm +                    | PairingDone + +data PairingFailureReason a = PairingUserRejected +                            | PairingUnexpectedMessage (PairingState a) (PairingService a) +                            | PairingFailedOther String + +data PairingAttributes a = PairingAttributes +    { pairingHookRequest :: ServiceHandler (PairingService a) () +    , pairingHookResponse :: String -> ServiceHandler (PairingService a) () +    , pairingHookRequestNonce :: String -> ServiceHandler (PairingService a) () +    , pairingHookRequestNonceFailed :: ServiceHandler (PairingService a) () +    , pairingHookConfirmedResponse :: ServiceHandler (PairingService a) () +    , pairingHookConfirmedRequest :: ServiceHandler (PairingService a) () +    , pairingHookAcceptedResponse :: ServiceHandler (PairingService a) () +    , pairingHookAcceptedRequest :: ServiceHandler (PairingService a) () +    , pairingHookVerifyFailed :: ServiceHandler (PairingService a) () +    , pairingHookRejected :: ServiceHandler (PairingService a) () +    , pairingHookFailed :: PairingFailureReason a -> ServiceHandler (PairingService a) () +    } + +class (Typeable a, Storable a) => PairingResult a where +    type PairingVerifiedResult a :: Type +    type PairingVerifiedResult a = a + +    pairingServiceID :: proxy a -> ServiceID +    pairingVerifyResult :: a -> ServiceHandler (PairingService a) (Maybe (PairingVerifiedResult a)) +    pairingFinalizeRequest :: PairingVerifiedResult a -> ServiceHandler (PairingService a) () +    pairingFinalizeResponse :: ServiceHandler (PairingService a) a +    defaultPairingAttributes :: proxy (PairingService a) -> PairingAttributes a + + +instance Storable a => Storable (PairingService a) where +    store' (PairingRequest idReq idRsp x) = storeRec $ do +        storeRef "id-req" idReq +        storeRef "id-rsp" idRsp +        storeBinary "request" x +    store' (PairingResponse x) = storeRec $ storeBinary "response" x +    store' (PairingRequestNonce x) = storeRec $ storeBinary "reqnonce" x +    store' (PairingAccept x) = store' x +    store' (PairingReject) = storeRec $ storeEmpty "reject" + +    load' = do +        res <- loadRec $ do +            (req :: Maybe Bytes) <- loadMbBinary "request" +            idReq <- loadMbRef "id-req" +            idRsp <- loadMbRef "id-rsp" +            rsp <- loadMbBinary "response" +            rnonce <- loadMbBinary "reqnonce" +            rej <- loadMbEmpty "reject" +            return $ catMaybes +                    [ PairingRequest <$> idReq <*> idRsp <*> (refDigestFromByteString =<< req) +                    , PairingResponse <$> rsp +                    , PairingRequestNonce <$> rnonce +                    , const PairingReject <$> rej +                    ] +        case res of +             x:_ -> return x +             [] -> PairingAccept <$> load' + + +instance PairingResult a => Service (PairingService a) where +    serviceID _ = pairingServiceID @a Proxy + +    type ServiceAttributes (PairingService a) = PairingAttributes a +    defaultServiceAttributes = defaultPairingAttributes + +    type ServiceState (PairingService a) = PairingState a +    emptyServiceState _ = NoPairing + +    serviceHandler spacket = ((,fromStored spacket) <$> svcGet) >>= \case +        (NoPairing, PairingRequest pdata sdata confirm) -> do +            self <- maybe (throwError "failed to validate received identity") return $ validateIdentity sdata +            self' <- maybe (throwError "failed to validate own identity") return . +                validateExtendedIdentity . lsIdentity . fromStored =<< svcGetLocal +            when (not $ self `sameIdentity` self') $ do +                throwError "pairing request to different identity" + +            peer <- maybe (throwError "failed to validate received peer identity") return $ validateIdentity pdata +            peer' <- asks $ svcPeerIdentity +            when (not $ peer `sameIdentity` peer') $ do +                throwError "pairing request from different identity" + +            join $ asks $ pairingHookRequest . svcAttributes +            nonce <- liftIO $ getRandomBytes 32 +            svcSet $ PeerRequest peer self nonce confirm +            replyPacket $ PairingResponse nonce +        (NoPairing, _) -> return () + +        (PairingDone, _) -> return () +        (_, PairingReject) -> do +            join $ asks $ pairingHookRejected . svcAttributes +            svcSet NoPairing + +        (OurRequest self peer nonce, PairingResponse pnonce) -> do +            hook <- asks $ pairingHookResponse . svcAttributes +            hook $ confirmationNumber $ nonceDigest self peer nonce pnonce +            svcSet $ OurRequestConfirm Nothing +            replyPacket $ PairingRequestNonce nonce +        x@(OurRequest {}, _) -> reject $ uncurry PairingUnexpectedMessage x + +        (OurRequestConfirm _, PairingAccept x) -> do +            flip catchError (reject . PairingFailedOther) $ do +                pairingVerifyResult x >>= \case +                    Just x' -> do +                        join $ asks $ pairingHookConfirmedRequest . svcAttributes +                        svcSet $ OurRequestConfirm (Just x') +                    Nothing -> do +                        join $ asks $ pairingHookVerifyFailed . svcAttributes +                        svcSet NoPairing +                        replyPacket PairingReject + +        x@(OurRequestConfirm _, _) -> reject $ uncurry PairingUnexpectedMessage x + +        (OurRequestReady, PairingAccept x) -> do +            flip catchError (reject . PairingFailedOther) $ do +                pairingVerifyResult x >>= \case +                    Just x' -> do +                        pairingFinalizeRequest x' +                        join $ asks $ pairingHookAcceptedResponse . svcAttributes +                        svcSet $ PairingDone +                    Nothing -> do +                        join $ asks $ pairingHookVerifyFailed . svcAttributes +                        throwError "" +        x@(OurRequestReady, _) -> reject $ uncurry PairingUnexpectedMessage x + +        (PeerRequest peer self nonce dgst, PairingRequestNonce pnonce) -> do +            if dgst == nonceDigest peer self pnonce BA.empty +               then do hook <- asks $ pairingHookRequestNonce . svcAttributes +                       hook $ confirmationNumber $ nonceDigest peer self pnonce nonce +                       svcSet PeerRequestConfirm +               else do join $ asks $ pairingHookRequestNonceFailed . svcAttributes +                       svcSet NoPairing +                       replyPacket PairingReject +        x@(PeerRequest {}, _) -> reject $ uncurry PairingUnexpectedMessage x +        x@(PeerRequestConfirm, _) -> reject $ uncurry PairingUnexpectedMessage x + +reject :: PairingResult a => PairingFailureReason a -> ServiceHandler (PairingService a) () +reject reason = do +    join $ asks $ flip pairingHookFailed reason . svcAttributes +    svcSet NoPairing +    replyPacket PairingReject + + +nonceDigest :: UnifiedIdentity -> UnifiedIdentity -> Bytes -> Bytes -> RefDigest +nonceDigest idReq idRsp nonceReq nonceRsp = hashToRefDigest $ serializeObject $ Rec +        [ (BC.pack "id-req", RecRef $ storedRef $ idData idReq) +        , (BC.pack "id-rsp", RecRef $ storedRef $ idData idRsp) +        , (BC.pack "nonce-req", RecBinary $ convert nonceReq) +        , (BC.pack "nonce-rsp", RecBinary $ convert nonceRsp) +        ] + +confirmationNumber :: RefDigest -> String +confirmationNumber dgst = +    case map fromIntegral $ BA.unpack dgst :: [Word32] of +         (a:b:c:d:_) -> let str = show $ ((a `shift` 24) .|. (b `shift` 16) .|. (c `shift` 8) .|. d) `mod` (10 ^ len) +                         in replicate (len - length str) '0' ++ str +         _ -> "" +    where len = 6 + +pairingRequest :: forall a m proxy. (PairingResult a, MonadIO m, MonadError String m) => proxy a -> Peer -> m () +pairingRequest _ peer = do +    self <- liftIO $ serverIdentity $ peerServer peer +    nonce <- liftIO $ getRandomBytes 32 +    pid <- peerIdentity peer >>= \case +        PeerIdentityFull pid -> return pid +        _ -> throwError "incomplete peer identity" +    sendToPeerWith @(PairingService a) peer $ \case +        NoPairing -> return (Just $ PairingRequest (idData self) (idData pid) (nonceDigest self pid nonce BA.empty), OurRequest self pid nonce) +        _ -> throwError "already in progress" + +pairingAccept :: forall a m proxy. (PairingResult a, MonadIO m, MonadError String m) => proxy a -> Peer -> m () +pairingAccept _ peer = runPeerService @(PairingService a) peer $ do +    svcGet >>= \case +        NoPairing -> throwError $ "none in progress" +        OurRequest {} -> throwError $ "waiting for peer" +        OurRequestConfirm Nothing -> do +            join $ asks $ pairingHookConfirmedResponse . svcAttributes +            svcSet OurRequestReady +        OurRequestConfirm (Just verified) -> do +            join $ asks $ pairingHookAcceptedResponse . svcAttributes +            pairingFinalizeRequest verified +            svcSet PairingDone +        OurRequestReady -> throwError $ "already accepted, waiting for peer" +        PeerRequest {} -> throwError $ "waiting for peer" +        PeerRequestConfirm -> do +            join $ asks $ pairingHookAcceptedRequest . svcAttributes +            replyPacket . PairingAccept =<< pairingFinalizeResponse +            svcSet PairingDone +        PairingDone -> throwError $ "already done" + +pairingReject :: forall a m proxy. (PairingResult a, MonadIO m, MonadError String m) => proxy a -> Peer -> m () +pairingReject _ peer = runPeerService @(PairingService a) peer $ do +    svcGet >>= \case +        NoPairing -> throwError $ "none in progress" +        PairingDone -> throwError $ "already done" +        _ -> reject PairingUserRejected diff --git a/src/Erebos/PubKey.hs b/src/Erebos/PubKey.hs new file mode 100644 index 0000000..09a8e02 --- /dev/null +++ b/src/Erebos/PubKey.hs @@ -0,0 +1,156 @@ +module Erebos.PubKey ( +    PublicKey, SecretKey, +    KeyPair(generateKeys), loadKey, loadKeyMb, +    Signature(sigKey), Signed, signedData, signedSignature, +    sign, signAdd, isSignedBy, +    fromSigned, +    unsafeMapSigned, + +    PublicKexKey, SecretKexKey, +    dhSecret, +) where + +import Control.Monad +import Control.Monad.Except + +import Crypto.Error +import qualified Crypto.PubKey.Ed25519 as ED +import qualified Crypto.PubKey.Curve25519 as CX + +import Data.ByteArray +import Data.ByteString (ByteString) +import qualified Data.Text as T + +import Erebos.Storage +import Erebos.Storage.Key + +data PublicKey = PublicKey ED.PublicKey +    deriving (Show) + +data SecretKey = SecretKey ED.SecretKey (Stored PublicKey) + +data Signature = Signature +    { sigKey :: Stored PublicKey +    , sigSignature :: ED.Signature +    } +    deriving (Show) + +data Signed a = Signed +    { signedData_ :: Stored a +    , signedSignature_ :: [Stored Signature] +    } +    deriving (Show) + +signedData :: Signed a -> Stored a +signedData = signedData_ + +signedSignature :: Signed a -> [Stored Signature] +signedSignature = signedSignature_ + +instance KeyPair SecretKey PublicKey where +    keyGetPublic (SecretKey _ pub) = pub +    keyGetData (SecretKey sec _) = convert sec +    keyFromData kdata spub = do +        skey <- maybeCryptoError $ ED.secretKey kdata +        let PublicKey pkey = fromStored spub +        guard $ ED.toPublic skey == pkey +        return $ SecretKey skey spub +    generateKeys st = do +        secret <- ED.generateSecretKey +        public <- wrappedStore st $ PublicKey $ ED.toPublic secret +        let pair = SecretKey secret public +        storeKey pair +        return (pair, public) + +instance Storable PublicKey where +    store' (PublicKey pk) = storeRec $ do +        storeText "type" $ T.pack "ed25519" +        storeBinary "pubkey" pk + +    load' = loadRec $ do +        ktype <- loadText "type" +        guard $ ktype == "ed25519" +        maybe (throwError "Public key decoding failed") (return . PublicKey) . +            maybeCryptoError . (ED.publicKey :: ByteString -> CryptoFailable ED.PublicKey) =<< +                loadBinary "pubkey" + +instance Storable Signature where +    store' sig = storeRec $ do +        storeRef "key" $ sigKey sig +        storeBinary "sig" $ sigSignature sig + +    load' = loadRec $ Signature +        <$> loadRef "key" +        <*> loadSignature "sig" +        where loadSignature = maybe (throwError "Signature decoding failed") return . +                  maybeCryptoError . (ED.signature :: ByteString -> CryptoFailable ED.Signature) <=< loadBinary + +instance Storable a => Storable (Signed a) where +    store' sig = storeRec $ do +        storeRef "SDATA" $ signedData sig +        mapM_ (storeRef "sig") $ signedSignature sig + +    load' = loadRec $ do +        sdata <- loadRef "SDATA" +        sigs <- loadRefs "sig" +        forM_ sigs $ \sig -> do +            let PublicKey pubkey = fromStored $ sigKey $ fromStored sig +            when (not $ ED.verify pubkey (storedRef sdata) $ sigSignature $ fromStored sig) $ +                throwError "signature verification failed" +        return $ Signed sdata sigs + +sign :: MonadStorage m => SecretKey -> Stored a -> m (Signed a) +sign secret val = signAdd secret $ Signed val [] + +signAdd :: MonadStorage m => SecretKey -> Signed a -> m (Signed a) +signAdd (SecretKey secret spublic) (Signed val sigs) = do +    let PublicKey public = fromStored spublic +        sig = ED.sign secret public $ storedRef val +    ssig <- mstore $ Signature spublic sig +    return $ Signed val (ssig : sigs) + +isSignedBy :: Signed a -> Stored PublicKey -> Bool +isSignedBy sig key = key `elem` map (sigKey . fromStored) (signedSignature sig) + +fromSigned :: Stored (Signed a) -> a +fromSigned = fromStored . signedData . fromStored + +-- |Passed function needs to preserve the object representation to be safe +unsafeMapSigned :: (a -> b) -> Signed a -> Signed b +unsafeMapSigned f signed = signed { signedData_ = unsafeMapStored f (signedData_ signed) } + + +data PublicKexKey = PublicKexKey CX.PublicKey +    deriving (Show) + +data SecretKexKey = SecretKexKey CX.SecretKey (Stored PublicKexKey) + +instance KeyPair SecretKexKey PublicKexKey where +    keyGetPublic (SecretKexKey _ pub) = pub +    keyGetData (SecretKexKey sec _) = convert sec +    keyFromData kdata spub = do +        skey <- maybeCryptoError $ CX.secretKey kdata +        let PublicKexKey pkey = fromStored spub +        guard $ CX.toPublic skey == pkey +        return $ SecretKexKey skey spub +    generateKeys st = do +        secret <- CX.generateSecretKey +        public <- wrappedStore st $ PublicKexKey $ CX.toPublic secret +        let pair = SecretKexKey secret public +        storeKey pair +        return (pair, public) + +instance Storable PublicKexKey where +    store' (PublicKexKey pk) = storeRec $ do +        storeText "type" $ T.pack "x25519" +        storeBinary "pubkey" pk + +    load' = loadRec $ do +        ktype <- loadText "type" +        guard $ ktype == "x25519" +        maybe (throwError "public key decoding failed") (return . PublicKexKey) . +            maybeCryptoError . (CX.publicKey :: ScrubbedBytes -> CryptoFailable CX.PublicKey) =<< +                loadBinary "pubkey" + +dhSecret :: SecretKexKey -> PublicKexKey -> ScrubbedBytes +dhSecret (SecretKexKey secret _) (PublicKexKey public) = convert $ CX.dh public secret diff --git a/src/Erebos/Service.hs b/src/Erebos/Service.hs new file mode 100644 index 0000000..f8428d1 --- /dev/null +++ b/src/Erebos/Service.hs @@ -0,0 +1,190 @@ +module Erebos.Service ( +    Service(..), +    SomeService(..), someService, someServiceAttr, someServiceID, +    SomeServiceState(..), fromServiceState, someServiceEmptyState, +    SomeServiceGlobalState(..), fromServiceGlobalState, someServiceEmptyGlobalState, +    SomeStorageWatcher(..), +    ServiceID, mkServiceID, + +    ServiceHandler, +    ServiceInput(..), +    ServiceReply(..), +    runServiceHandler, + +    svcGet, svcSet, svcModify, +    svcGetGlobal, svcSetGlobal, svcModifyGlobal, +    svcGetLocal, svcSetLocal, + +    svcSelf, +    svcPrint, + +    replyPacket, replyStored, replyStoredRef, +    afterCommit, +) where + +import Control.Monad.Except +import Control.Monad.Reader +import Control.Monad.State +import Control.Monad.Writer + +import Data.Kind +import Data.Typeable +import Data.UUID (UUID) +import qualified Data.UUID as U + +import Erebos.Identity +import {-# SOURCE #-} Erebos.Network +import Erebos.State +import Erebos.Storage + +class (Typeable s, Storable s, Typeable (ServiceState s), Typeable (ServiceGlobalState s)) => Service s where +    serviceID :: proxy s -> ServiceID +    serviceHandler :: Stored s -> ServiceHandler s () + +    serviceNewPeer :: ServiceHandler s () +    serviceNewPeer = return () + +    type ServiceAttributes s = attr | attr -> s +    type ServiceAttributes s = Proxy s +    defaultServiceAttributes :: proxy s -> ServiceAttributes s +    default defaultServiceAttributes :: ServiceAttributes s ~ Proxy s => proxy s -> ServiceAttributes s +    defaultServiceAttributes _ = Proxy + +    type ServiceState s :: Type +    type ServiceState s = () +    emptyServiceState :: proxy s -> ServiceState s +    default emptyServiceState :: ServiceState s ~ () => proxy s -> ServiceState s +    emptyServiceState _ = () + +    type ServiceGlobalState s :: Type +    type ServiceGlobalState s = () +    emptyServiceGlobalState :: proxy s -> ServiceGlobalState s +    default emptyServiceGlobalState :: ServiceGlobalState s ~ () => proxy s -> ServiceGlobalState s +    emptyServiceGlobalState _ = () + +    serviceStorageWatchers :: proxy s -> [SomeStorageWatcher s] +    serviceStorageWatchers _ = [] + + +data SomeService = forall s. Service s => SomeService (Proxy s) (ServiceAttributes s) + +someService :: forall s proxy. Service s => proxy s -> SomeService +someService _ = SomeService @s Proxy (defaultServiceAttributes @s Proxy) + +someServiceAttr :: forall s. Service s => ServiceAttributes s -> SomeService +someServiceAttr attr = SomeService @s Proxy attr + +someServiceID :: SomeService -> ServiceID +someServiceID (SomeService s _) = serviceID s + +data SomeServiceState = forall s. Service s => SomeServiceState (Proxy s) (ServiceState s) + +fromServiceState :: Service s => proxy s -> SomeServiceState -> Maybe (ServiceState s) +fromServiceState _ (SomeServiceState _ s) = cast s + +someServiceEmptyState :: SomeService -> SomeServiceState +someServiceEmptyState (SomeService p _) = SomeServiceState p (emptyServiceState p) + +data SomeServiceGlobalState = forall s. Service s => SomeServiceGlobalState (Proxy s) (ServiceGlobalState s) + +fromServiceGlobalState :: Service s => proxy s -> SomeServiceGlobalState -> Maybe (ServiceGlobalState s) +fromServiceGlobalState _ (SomeServiceGlobalState _ s) = cast s + +someServiceEmptyGlobalState :: SomeService -> SomeServiceGlobalState +someServiceEmptyGlobalState (SomeService p _) = SomeServiceGlobalState p (emptyServiceGlobalState p) + + +data SomeStorageWatcher s = forall a. Eq a => SomeStorageWatcher (Stored LocalState -> a) (a -> ServiceHandler s ()) + + +newtype ServiceID = ServiceID UUID +    deriving (Eq, Ord, Show, StorableUUID) + +mkServiceID :: String -> ServiceID +mkServiceID = maybe (error "Invalid service ID") ServiceID . U.fromString + +data ServiceInput s = ServiceInput +    { svcAttributes :: ServiceAttributes s +    , svcPeer :: Peer +    , svcPeerIdentity :: UnifiedIdentity +    , svcServer :: Server +    , svcPrintOp :: String -> IO () +    } + +data ServiceReply s = ServiceReply (Either s (Stored s)) Bool +                    | ServiceFinally (IO ()) + +data ServiceHandlerState s = ServiceHandlerState +    { svcValue :: ServiceState s +    , svcGlobal :: ServiceGlobalState s +    , svcLocal :: Stored LocalState +    } + +newtype ServiceHandler s a = ServiceHandler (ReaderT (ServiceInput s) (WriterT [ServiceReply s] (StateT (ServiceHandlerState s) (ExceptT String IO))) a) +    deriving (Functor, Applicative, Monad, MonadReader (ServiceInput s), MonadWriter [ServiceReply s], MonadState (ServiceHandlerState s), MonadError String, MonadIO) + +instance MonadStorage (ServiceHandler s) where +    getStorage = asks $ peerStorage . svcPeer + +instance MonadHead LocalState (ServiceHandler s) where +    updateLocalHead f = do +        (ls, x) <- f =<< gets svcLocal +        modify $ \s -> s { svcLocal = ls } +        return x + +runServiceHandler :: Service s => Head LocalState -> ServiceInput s -> ServiceState s -> ServiceGlobalState s -> ServiceHandler s () -> IO ([ServiceReply s], (ServiceState s, ServiceGlobalState s)) +runServiceHandler h input svc global shandler = do +    let sstate = ServiceHandlerState { svcValue = svc, svcGlobal = global, svcLocal = headStoredObject h } +        ServiceHandler handler = shandler +    (runExceptT $ flip runStateT sstate $ execWriterT $ flip runReaderT input $ handler) >>= \case +        Left err -> do +            svcPrintOp input $ "service failed: " ++ err +            return ([], (svc, global)) +        Right (rsp, sstate') +            | svcLocal sstate' == svcLocal sstate -> return (rsp, (svcValue sstate', svcGlobal sstate')) +            | otherwise -> replaceHead h (svcLocal sstate') >>= \case +                Left (Just h') -> runServiceHandler h' input svc global shandler +                _              -> return (rsp, (svcValue sstate', svcGlobal sstate')) + +svcGet :: ServiceHandler s (ServiceState s) +svcGet = gets svcValue + +svcSet :: ServiceState s -> ServiceHandler s () +svcSet x = modify $ \st -> st { svcValue = x } + +svcModify :: (ServiceState s -> ServiceState s) -> ServiceHandler s () +svcModify f = modify $ \st -> st { svcValue = f (svcValue st) } + +svcGetGlobal :: ServiceHandler s (ServiceGlobalState s) +svcGetGlobal = gets svcGlobal + +svcSetGlobal :: ServiceGlobalState s -> ServiceHandler s () +svcSetGlobal x = modify $ \st -> st { svcGlobal = x } + +svcModifyGlobal :: (ServiceGlobalState s -> ServiceGlobalState s) -> ServiceHandler s () +svcModifyGlobal f = modify $ \st -> st { svcGlobal = f (svcGlobal st) } + +svcGetLocal :: ServiceHandler s (Stored LocalState) +svcGetLocal = gets svcLocal + +svcSetLocal :: Stored LocalState -> ServiceHandler s () +svcSetLocal x = modify $ \st -> st { svcLocal = x } + +svcSelf :: ServiceHandler s UnifiedIdentity +svcSelf = maybe (throwError "failed to validate own identity") return . +        validateExtendedIdentity . lsIdentity . fromStored =<< svcGetLocal + +svcPrint :: String -> ServiceHandler s () +svcPrint str = afterCommit . ($ str) =<< asks svcPrintOp + +replyPacket :: Service s => s -> ServiceHandler s () +replyPacket x = tell [ServiceReply (Left x) True] + +replyStored :: Service s => Stored s -> ServiceHandler s () +replyStored x = tell [ServiceReply (Right x) True] + +replyStoredRef :: Service s => Stored s -> ServiceHandler s () +replyStoredRef x = tell [ServiceReply (Right x) False] + +afterCommit :: IO () -> ServiceHandler s () +afterCommit x = tell [ServiceFinally x] diff --git a/src/Erebos/Set.hs b/src/Erebos/Set.hs new file mode 100644 index 0000000..0abe02d --- /dev/null +++ b/src/Erebos/Set.hs @@ -0,0 +1,78 @@ +module Erebos.Set ( +    Set, + +    emptySet, +    loadSet, +    storeSetAdd, + +    fromSetBy, +) where + +import Control.Arrow +import Control.Monad.IO.Class + +import Data.Function +import Data.List +import Data.Map (Map) +import Data.Map qualified as M +import Data.Maybe +import Data.Ord + +import Erebos.Storage +import Erebos.Storage.Merge +import Erebos.Util + +data Set a = Set [Stored (SetItem (Component a))] + +data SetItem a = SetItem +    { siPrev :: [Stored (SetItem a)] +    , siItem :: [Stored a] +    } + +instance Storable a => Storable (SetItem a) where +    store' x = storeRec $ do +        mapM_ (storeRef "PREV") $ siPrev x +        mapM_ (storeRef "item") $ siItem x + +    load' = loadRec $ SetItem +        <$> loadRefs "PREV" +        <*> loadRefs "item" + +instance Mergeable a => Mergeable (Set a) where +    type Component (Set a) = SetItem (Component a) +    mergeSorted = Set +    toComponents (Set items) = items + + +emptySet :: Set a +emptySet = Set [] + +loadSet :: Mergeable a => Ref -> Set a +loadSet = mergeSorted . (:[]) . wrappedLoad + +storeSetAdd :: (Mergeable a, MonadIO m) => Storage -> a -> Set a -> m (Set a) +storeSetAdd st x (Set prev) = Set . (:[]) <$> wrappedStore st SetItem +    { siPrev = prev +    , siItem = toComponents x +    } + + +fromSetBy :: forall a. Mergeable a => (a -> a -> Ordering) -> Set a -> [a] +fromSetBy cmp (Set heads) = sortBy cmp $ map merge $ groupRelated items +  where +    -- gather all item components in the set history +    items :: [Stored (Component a)] +    items = walkAncestors (siItem . fromStored) heads + +    -- map individual roots to full root set as joined in history of individual items +    rootToRootSet :: Map RefDigest [RefDigest] +    rootToRootSet = foldl' (\m rs -> foldl' (\m' r -> M.insertWith (\a b -> uniq $ sort $ a++b) r rs m') m rs) M.empty $ +        map (map (refDigest . storedRef) . storedRoots) items + +    -- get full root set for given item component +    storedRootSet :: Stored (Component a) -> [RefDigest] +    storedRootSet = fromJust . flip M.lookup rootToRootSet . refDigest . storedRef . head . storedRoots + +    -- group components of single item, i.e. components sharing some root +    groupRelated :: [Stored (Component a)] -> [[Stored (Component a)]] +    groupRelated = map (map fst) . groupBy ((==) `on` snd) . sortBy (comparing snd) . map (id &&& storedRootSet) diff --git a/src/Erebos/State.hs b/src/Erebos/State.hs new file mode 100644 index 0000000..1f0bf7d --- /dev/null +++ b/src/Erebos/State.hs @@ -0,0 +1,199 @@ +module Erebos.State ( +    LocalState(..), +    SharedState, SharedType(..), +    SharedTypeID, mkSharedTypeID, + +    MonadHead(..), +    updateLocalHead_, + +    loadLocalStateHead, + +    updateSharedState, updateSharedState_, +    lookupSharedValue, makeSharedStateUpdate, + +    localIdentity, +    headLocalIdentity, + +    mergeSharedIdentity, +    updateSharedIdentity, +    interactiveIdentityUpdate, +) where + +import Control.Monad.Except +import Control.Monad.Reader + +import Data.Foldable +import Data.Maybe +import qualified Data.Text as T +import qualified Data.Text.IO as T +import Data.Typeable +import Data.UUID (UUID) +import qualified Data.UUID as U + +import System.IO + +import Erebos.Identity +import Erebos.PubKey +import Erebos.Storage +import Erebos.Storage.Merge + +data LocalState = LocalState +    { lsIdentity :: Stored (Signed ExtendedIdentityData) +    , lsShared :: [Stored SharedState] +    } + +data SharedState = SharedState +    { ssPrev :: [Stored SharedState] +    , ssType :: Maybe SharedTypeID +    , ssValue :: [Ref] +    } + +newtype SharedTypeID = SharedTypeID UUID +    deriving (Eq, Ord, StorableUUID) + +mkSharedTypeID :: String -> SharedTypeID +mkSharedTypeID = maybe (error "Invalid shared type ID") SharedTypeID . U.fromString + +class Mergeable a => SharedType a where +    sharedTypeID :: proxy a -> SharedTypeID + +instance Storable LocalState where +    store' st = storeRec $ do +        storeRef "id" $ lsIdentity st +        mapM_ (storeRef "shared") $ lsShared st + +    load' = loadRec $ LocalState +        <$> loadRef "id" +        <*> loadRefs "shared" + +instance HeadType LocalState where +    headTypeID _ = mkHeadTypeID "1d7491a9-7bcb-4eaa-8f13-c8c4c4087e4e" + +instance Storable SharedState where +    store' st = storeRec $ do +        mapM_ (storeRef "PREV") $ ssPrev st +        storeMbUUID "type" $ ssType st +        mapM_ (storeRawRef "value") $ ssValue st + +    load' = loadRec $ SharedState +        <$> loadRefs "PREV" +        <*> loadMbUUID "type" +        <*> loadRawRefs "value" + +instance SharedType (Maybe ComposedIdentity) where +    sharedTypeID _ = mkSharedTypeID "0c6c1fe0-f2d7-4891-926b-c332449f7871" + + +class (MonadIO m, MonadStorage m) => MonadHead a m where +    updateLocalHead :: (Stored a -> m (Stored a, b)) -> m b + +updateLocalHead_ :: MonadHead a m => (Stored a -> m (Stored a)) -> m () +updateLocalHead_ f = updateLocalHead (fmap (,()) . f) + +instance (HeadType a, MonadIO m) => MonadHead a (ReaderT (Head a) m) where +    updateLocalHead f = do +        h <- ask +        snd <$> updateHead h f + + +loadLocalStateHead :: MonadIO m => Storage -> m (Head LocalState) +loadLocalStateHead st = loadHeads st >>= \case +    (h:_) -> return h +    [] -> liftIO $ do +        putStr "Name: " +        hFlush stdout +        name <- T.getLine + +        putStr "Device: " +        hFlush stdout +        devName <- T.getLine + +        owner <- if +            | T.null name -> return Nothing +            | otherwise -> Just <$> createIdentity st (Just name) Nothing + +        identity <- createIdentity st (if T.null devName then Nothing else Just devName) owner + +        shared <- wrappedStore st $ SharedState +            { ssPrev = [] +            , ssType = Just $ sharedTypeID @(Maybe ComposedIdentity) Proxy +            , ssValue = [storedRef $ idExtData $ fromMaybe identity owner] +            } +        storeHead st $ LocalState +            { lsIdentity = idExtData identity +            , lsShared = [shared] +            } + +localIdentity :: LocalState -> UnifiedIdentity +localIdentity ls = maybe (error "failed to verify local identity") +    (updateOwners $ maybe [] idExtDataF $ lookupSharedValue $ lsShared ls) +    (validateExtendedIdentity $ lsIdentity ls) + +headLocalIdentity :: Head LocalState -> UnifiedIdentity +headLocalIdentity = localIdentity . headObject + + +updateSharedState_ :: forall a m. (SharedType a, MonadHead LocalState m) => (a -> m a) -> Stored LocalState -> m (Stored LocalState) +updateSharedState_ f = fmap fst <$> updateSharedState (fmap (,()) . f) + +updateSharedState :: forall a b m. (SharedType a, MonadHead LocalState m) => (a -> m (a, b)) -> Stored LocalState -> m (Stored LocalState, b) +updateSharedState f = \ls -> do +    let shared = lsShared $ fromStored ls +        val = lookupSharedValue shared +    st <- getStorage +    (val', x) <- f val +    (,x) <$> if toComponents val' == toComponents val +                then return ls +                else do shared' <- makeSharedStateUpdate st val' shared +                        wrappedStore st (fromStored ls) { lsShared = [shared'] } + +lookupSharedValue :: forall a. SharedType a => [Stored SharedState] -> a +lookupSharedValue = mergeSorted . filterAncestors . map wrappedLoad . concatMap (ssValue . fromStored) . filterAncestors . helper +    where helper (x:xs) | Just sid <- ssType (fromStored x), sid == sharedTypeID @a Proxy = x : helper xs +                        | otherwise = helper $ ssPrev (fromStored x) ++ xs +          helper [] = [] + +makeSharedStateUpdate :: forall a m. MonadIO m => SharedType a => Storage -> a -> [Stored SharedState] -> m (Stored SharedState) +makeSharedStateUpdate st val prev = liftIO $ wrappedStore st SharedState +    { ssPrev = prev +    , ssType = Just $ sharedTypeID @a Proxy +    , ssValue = storedRef <$> toComponents val +    } + + +mergeSharedIdentity :: (MonadHead LocalState m, MonadError String m) => m UnifiedIdentity +mergeSharedIdentity = updateLocalHead $ updateSharedState $ \case +    Just cidentity -> do +        identity <- mergeIdentity cidentity +        return (Just $ toComposedIdentity identity, identity) +    Nothing -> throwError "no existing shared identity" + +updateSharedIdentity :: (MonadHead LocalState m, MonadError String m) => m () +updateSharedIdentity = updateLocalHead_ $ updateSharedState_ $ \case +    Just identity -> do +        Just . toComposedIdentity <$> interactiveIdentityUpdate identity +    Nothing -> throwError "no existing shared identity" + +interactiveIdentityUpdate :: (Foldable f, MonadStorage m, MonadIO m, MonadError String m) => Identity f -> m UnifiedIdentity +interactiveIdentityUpdate identity = do +    let public = idKeyIdentity identity + +    name <- liftIO $ do +        T.putStr $ T.concat $ concat +            [ [ T.pack "Name" ] +            , case idName identity of +                   Just name -> [T.pack " [", name, T.pack "]"] +                   Nothing -> [] +            , [ T.pack ": " ] +            ] +        hFlush stdout +        T.getLine + +    if  | T.null name -> mergeIdentity identity +        | otherwise -> do +            secret <- loadKey public +            maybe (throwError "created invalid identity") return . validateIdentity =<< +                mstore =<< sign secret =<< mstore (emptyIdentityData public) +                { iddPrev = toList $ idDataF identity +                , iddName = Just name +                } diff --git a/src/Erebos/Storage.hs b/src/Erebos/Storage.hs new file mode 100644 index 0000000..0511814 --- /dev/null +++ b/src/Erebos/Storage.hs @@ -0,0 +1,1007 @@ +module Erebos.Storage ( +    Storage, PartialStorage, StorageCompleteness, +    openStorage, memoryStorage, +    deriveEphemeralStorage, derivePartialStorage, + +    Ref, PartialRef, RefDigest, +    refDigest, +    readRef, showRef, showRefDigest, +    refDigestFromByteString, hashToRefDigest, +    copyRef, partialRef, partialRefFromDigest, + +    Object, PartialObject, Object'(..), RecItem, RecItem'(..), +    serializeObject, deserializeObject, deserializeObjects, +    ioLoadObject, ioLoadBytes, +    storeRawBytes, lazyLoadBytes, +    storeObject, +    collectObjects, collectStoredObjects, + +    Head, HeadType(..), +    HeadTypeID, mkHeadTypeID, +    headId, headStorage, headRef, headObject, headStoredObject, +    loadHeads, loadHead, reloadHead, +    storeHead, replaceHead, updateHead, updateHead_, + +    WatchedHead, +    watchHead, watchHeadWith, unwatchHead, + +    MonadStorage(..), + +    Storable(..), ZeroStorable(..), +    StorableText(..), StorableDate(..), StorableUUID(..), + +    Store, StoreRec, +    evalStore, evalStoreObject, +    storeBlob, storeRec, storeZero, +    storeEmpty, storeInt, storeNum, storeText, storeBinary, storeDate, storeUUID, storeRef, storeRawRef, +    storeMbEmpty, storeMbInt, storeMbNum, storeMbText, storeMbBinary, storeMbDate, storeMbUUID, storeMbRef, storeMbRawRef, +    storeZRef, + +    Load, LoadRec, +    evalLoad, +    loadCurrentRef, loadCurrentObject, +    loadRecCurrentRef, loadRecItems, + +    loadBlob, loadRec, loadZero, +    loadEmpty, loadInt, loadNum, loadText, loadBinary, loadDate, loadUUID, loadRef, loadRawRef, +    loadMbEmpty, loadMbInt, loadMbNum, loadMbText, loadMbBinary, loadMbDate, loadMbUUID, loadMbRef, loadMbRawRef, +    loadTexts, loadBinaries, loadRefs, loadRawRefs, +    loadZRef, + +    Stored, +    fromStored, storedRef, +    wrappedStore, wrappedLoad, +    copyStored, +    unsafeMapStored, + +    StoreInfo(..), makeStoreInfo, + +    StoredHistory, +    fromHistory, fromHistoryAt, storedFromHistory, storedHistoryList, +    beginHistory, modifyHistory, +) where + +import Control.Applicative +import Control.Arrow +import Control.Concurrent +import Control.Exception +import Control.Monad +import Control.Monad.Except +import Control.Monad.Reader +import Control.Monad.Writer + +import Crypto.Hash + +import Data.ByteString (ByteString) +import qualified Data.ByteArray as BA +import qualified Data.ByteString as B +import qualified Data.ByteString.Char8 as BC +import qualified Data.ByteString.Lazy as BL +import qualified Data.ByteString.Lazy.Char8 as BLC +import Data.Char +import Data.Function +import qualified Data.HashTable.IO as HT +import Data.List +import qualified Data.Map as M +import Data.Maybe +import Data.Ratio +import Data.Set (Set) +import qualified Data.Set as S +import Data.Text (Text) +import qualified Data.Text as T +import Data.Text.Encoding +import Data.Text.Encoding.Error +import Data.Time.Calendar +import Data.Time.Clock +import Data.Time.Format +import Data.Time.LocalTime +import Data.Typeable +import Data.UUID (UUID) +import qualified Data.UUID as U +import qualified Data.UUID.V4 as U + +import System.Directory +import System.FilePath +import System.INotify +import System.IO.Error +import System.IO.Unsafe + +import Erebos.Storage.Internal + + +type Storage = Storage' Complete +type PartialStorage = Storage' Partial + +openStorage :: FilePath -> IO Storage +openStorage path = do +    createDirectoryIfMissing True $ path ++ "/objects" +    createDirectoryIfMissing True $ path ++ "/heads" +    watchers <- newMVar ([], WatchList 1 []) +    refgen <- newMVar =<< HT.new +    refroots <- newMVar =<< HT.new +    return $ Storage +        { stBacking = StorageDir path watchers +        , stParent = Nothing +        , stRefGeneration = refgen +        , stRefRoots = refroots +        } + +memoryStorage' :: IO (Storage' c') +memoryStorage' = do +    backing <- StorageMemory <$> newMVar [] <*> newMVar M.empty <*> newMVar M.empty <*> newMVar (WatchList 1 []) +    refgen <- newMVar =<< HT.new +    refroots <- newMVar =<< HT.new +    return $ Storage +        { stBacking = backing +        , stParent = Nothing +        , stRefGeneration = refgen +        , stRefRoots = refroots +        } + +memoryStorage :: IO Storage +memoryStorage = memoryStorage' + +deriveEphemeralStorage :: Storage -> IO Storage +deriveEphemeralStorage parent = do +    st <- memoryStorage +    return $ st { stParent = Just parent } + +derivePartialStorage :: Storage -> IO PartialStorage +derivePartialStorage parent = do +    st <- memoryStorage' +    return $ st { stParent = Just parent } + +type Ref = Ref' Complete +type PartialRef = Ref' Partial + +zeroRef :: Storage' c -> Ref' c +zeroRef s = Ref s (RefDigest h) +    where h = case digestFromByteString $ B.replicate (hashDigestSize $ digestAlgo h) 0 of +                   Nothing -> error $ "Failed to create zero hash" +                   Just h' -> h' +          digestAlgo :: Digest a -> a +          digestAlgo = undefined + +isZeroRef :: Ref' c -> Bool +isZeroRef (Ref _ h) = all (==0) $ BA.unpack h + + +refFromDigest :: Storage' c -> RefDigest -> IO (Maybe (Ref' c)) +refFromDigest st dgst = fmap (const $ Ref st dgst) <$> ioLoadBytesFromStorage st dgst + +readRef :: Storage -> ByteString -> IO (Maybe Ref) +readRef s b = +    case readRefDigest b of +         Nothing -> return Nothing +         Just dgst -> refFromDigest s dgst + +copyRef' :: forall c c'. (StorageCompleteness c, StorageCompleteness c') => Storage' c' -> Ref' c -> IO (c (Ref' c')) +copyRef' st ref'@(Ref _ dgst) = refFromDigest st dgst >>= \case Just ref -> return $ return ref +                                                                Nothing  -> doCopy +    where doCopy = do mbobj' <- ioLoadObject ref' +                      mbobj <- sequence $ copyObject' st <$> mbobj' +                      sequence $ unsafeStoreObject st <$> join mbobj + +copyObject' :: forall c c'. (StorageCompleteness c, StorageCompleteness c') => Storage' c' -> Object' c -> IO (c (Object' c')) +copyObject' _ (Blob bs) = return $ return $ Blob bs +copyObject' st (Rec rs) = fmap Rec . sequence <$> mapM copyItem rs +    where copyItem :: (ByteString, RecItem' c) -> IO (c (ByteString, RecItem' c')) +          copyItem (n, item) = fmap (n,) <$> case item of +              RecEmpty -> return $ return $ RecEmpty +              RecInt x -> return $ return $ RecInt x +              RecNum x -> return $ return $ RecNum x +              RecText x -> return $ return $ RecText x +              RecBinary x -> return $ return $ RecBinary x +              RecDate x -> return $ return $ RecDate x +              RecUUID x -> return $ return $ RecUUID x +              RecRef x -> fmap RecRef <$> copyRef' st x +copyObject' _ ZeroObject = return $ return ZeroObject + +copyRef :: forall c c' m. (StorageCompleteness c, StorageCompleteness c', MonadIO m) => Storage' c' -> Ref' c -> m (LoadResult c (Ref' c')) +copyRef st ref' = liftIO $ returnLoadResult <$> copyRef' st ref' + +copyObject :: forall c c'. (StorageCompleteness c, StorageCompleteness c') => Storage' c' -> Object' c -> IO (LoadResult c (Object' c')) +copyObject st obj' = returnLoadResult <$> copyObject' st obj' + +partialRef :: PartialStorage -> Ref -> PartialRef +partialRef st (Ref _ dgst) = Ref st dgst + +partialRefFromDigest :: PartialStorage -> RefDigest -> PartialRef +partialRefFromDigest st dgst = Ref st dgst + + +data Object' c +    = Blob ByteString +    | Rec [(ByteString, RecItem' c)] +    | ZeroObject +    deriving (Show) + +type Object = Object' Complete +type PartialObject = Object' Partial + +data RecItem' c +    = RecEmpty +    | RecInt Integer +    | RecNum Rational +    | RecText Text +    | RecBinary ByteString +    | RecDate ZonedTime +    | RecUUID UUID +    | RecRef (Ref' c) +    deriving (Show) + +type RecItem = RecItem' Complete + +serializeObject :: Object' c -> BL.ByteString +serializeObject = \case +    Blob cnt -> BL.fromChunks [BC.pack "blob ", BC.pack (show $ B.length cnt), BC.singleton '\n', cnt] +    Rec rec -> let cnt = BL.fromChunks $ concatMap (uncurry serializeRecItem) rec +                in BL.fromChunks [BC.pack "rec ", BC.pack (show $ BL.length cnt), BC.singleton '\n'] `BL.append` cnt +    ZeroObject -> BL.empty + +-- |Serializes and stores object data without ony dependencies, so is safe only +-- if all the referenced objects are already stored or reference is partial. +unsafeStoreObject :: Storage' c -> Object' c -> IO (Ref' c) +unsafeStoreObject storage = \case +    ZeroObject -> return $ zeroRef storage +    obj -> unsafeStoreRawBytes storage $ serializeObject obj + +storeObject :: PartialStorage -> PartialObject -> IO PartialRef +storeObject = unsafeStoreObject + +storeRawBytes :: PartialStorage -> BL.ByteString -> IO PartialRef +storeRawBytes = unsafeStoreRawBytes + +serializeRecItem :: ByteString -> RecItem' c -> [ByteString] +serializeRecItem name (RecEmpty) = [name, BC.pack ":e", BC.singleton ' ', BC.singleton '\n'] +serializeRecItem name (RecInt x) = [name, BC.pack ":i", BC.singleton ' ', BC.pack (show x), BC.singleton '\n'] +serializeRecItem name (RecNum x) = [name, BC.pack ":n", BC.singleton ' ', BC.pack (showRatio x), BC.singleton '\n'] +serializeRecItem name (RecText x) = [name, BC.pack ":t", BC.singleton ' ', escaped, BC.singleton '\n'] +    where escaped = BC.concatMap escape $ encodeUtf8 x +          escape '\n' = BC.pack "\n\t" +          escape c    = BC.singleton c +serializeRecItem name (RecBinary x) = [name, BC.pack ":b ", showHex x, BC.singleton '\n'] +serializeRecItem name (RecDate x) = [name, BC.pack ":d", BC.singleton ' ', BC.pack (formatTime defaultTimeLocale "%s %z" x), BC.singleton '\n'] +serializeRecItem name (RecUUID x) = [name, BC.pack ":u", BC.singleton ' ', U.toASCIIBytes x, BC.singleton '\n'] +serializeRecItem name (RecRef x) = [name, BC.pack ":r ", showRef x, BC.singleton '\n'] + +lazyLoadObject :: forall c. StorageCompleteness c => Ref' c -> LoadResult c (Object' c) +lazyLoadObject = returnLoadResult . unsafePerformIO . ioLoadObject + +ioLoadObject :: forall c. StorageCompleteness c => Ref' c -> IO (c (Object' c)) +ioLoadObject ref | isZeroRef ref = return $ return ZeroObject +ioLoadObject ref@(Ref st rhash) = do +    file' <- ioLoadBytes ref +    return $ do +        file <- file' +        let chash = hashToRefDigest file +        when (chash /= rhash) $ error $ "Hash mismatch on object " ++ BC.unpack (showRef ref) {- TODO throw -} +        return $ case runExcept $ unsafeDeserializeObject st file of +                      Left err -> error $ err ++ ", ref " ++ BC.unpack (showRef ref) {- TODO throw -} +                      Right (x, rest) | BL.null rest -> x +                                      | otherwise -> error $ "Superfluous content after " ++ BC.unpack (showRef ref) {- TODO throw -} + +lazyLoadBytes :: forall c. StorageCompleteness c => Ref' c -> LoadResult c BL.ByteString +lazyLoadBytes ref | isZeroRef ref = returnLoadResult (return BL.empty :: c BL.ByteString) +lazyLoadBytes ref = returnLoadResult $ unsafePerformIO $ ioLoadBytes ref + +unsafeDeserializeObject :: Storage' c -> BL.ByteString -> Except String (Object' c, BL.ByteString) +unsafeDeserializeObject _  bytes | BL.null bytes = return (ZeroObject, bytes) +unsafeDeserializeObject st bytes = +    case BLC.break (=='\n') bytes of +        (line, rest) | Just (otype, len) <- splitObjPrefix line -> do +            let (content, next) = first BL.toStrict $ BL.splitAt (fromIntegral len) $ BL.drop 1 rest +            guard $ B.length content == len +            (,next) <$> case otype of +                 _ | otype == BC.pack "blob" -> return $ Blob content +                   | otype == BC.pack "rec" -> maybe (throwError $ "Malformed record item ") +                                                   (return . Rec) $ sequence $ map parseRecLine $ mergeCont [] $ BC.lines content +                   | otherwise -> throwError $ "Unknown object type" +        _ -> throwError $ "Malformed object" +    where splitObjPrefix line = do +              [otype, tlen] <- return $ BLC.words line +              (len, rest) <- BLC.readInt tlen +              guard $ BL.null rest +              return (BL.toStrict otype, len) + +          mergeCont cs (a:b:rest) | Just ('\t', b') <- BC.uncons b = mergeCont (b':BC.pack "\n":cs) (a:rest) +          mergeCont cs (a:rest) = B.concat (a : reverse cs) : mergeCont [] rest +          mergeCont _ [] = [] + +          parseRecLine line = do +              colon <- BC.elemIndex ':' line +              space <- BC.elemIndex ' ' line +              guard $ colon < space +              let name = B.take colon line +                  itype = B.take (space-colon-1) $ B.drop (colon+1) line +                  content = B.drop (space+1) line + +              val <- case BC.unpack itype of +                          "e" -> do guard $ B.null content +                                    return RecEmpty +                          "i" -> do (num, rest) <- BC.readInteger content +                                    guard $ B.null rest +                                    return $ RecInt num +                          "n" -> RecNum <$> parseRatio content +                          "t" -> return $ RecText $ decodeUtf8With lenientDecode content +                          "b" -> RecBinary <$> readHex content +                          "d" -> RecDate <$> parseTimeM False defaultTimeLocale "%s %z" (BC.unpack content) +                          "u" -> RecUUID <$> U.fromASCIIBytes content +                          "r" -> RecRef . Ref st <$> readRefDigest content +                          _   -> Nothing +              return (name, val) + +deserializeObject :: PartialStorage -> BL.ByteString -> Except String (PartialObject, BL.ByteString) +deserializeObject = unsafeDeserializeObject + +deserializeObjects :: PartialStorage -> BL.ByteString -> Except String [PartialObject] +deserializeObjects _  bytes | BL.null bytes = return [] +deserializeObjects st bytes = do (obj, rest) <- deserializeObject st bytes +                                 (obj:) <$> deserializeObjects st rest + + +collectObjects :: Object -> [Object] +collectObjects obj = obj : map fromStored (fst $ collectOtherStored S.empty obj) + +collectStoredObjects :: Stored Object -> [Stored Object] +collectStoredObjects obj = obj : (fst $ collectOtherStored S.empty $ fromStored obj) + +collectOtherStored :: Set RefDigest -> Object -> ([Stored Object], Set RefDigest) +collectOtherStored seen (Rec items) = foldr helper ([], seen) $ map snd items +    where helper (RecRef ref) (xs, s) | r <- refDigest ref +                                      , r `S.notMember` s +                                      = let o = wrappedLoad ref +                                            (xs', s') = collectOtherStored (S.insert r s) $ fromStored o +                                         in ((o : xs') ++ xs, s') +          helper _          (xs, s) = (xs, s) +collectOtherStored seen _ = ([], seen) + + +type Head = Head' Complete + +headId :: Head a -> HeadID +headId (Head uuid _) = uuid + +headStorage :: Head a -> Storage +headStorage = refStorage . headRef + +headRef :: Head a -> Ref +headRef (Head _ sx) = storedRef sx + +headObject :: Head a -> a +headObject (Head _ sx) = fromStored sx + +headStoredObject :: Head a -> Stored a +headStoredObject (Head _ sx) = sx + +deriving instance StorableUUID HeadID +deriving instance StorableUUID HeadTypeID + +mkHeadTypeID :: String -> HeadTypeID +mkHeadTypeID = maybe (error "Invalid head type ID") HeadTypeID . U.fromString + +class Storable a => HeadType a where +    headTypeID :: proxy a -> HeadTypeID + + +headTypePath :: FilePath -> HeadTypeID -> FilePath +headTypePath spath (HeadTypeID tid) = spath </> "heads" </> U.toString tid + +headPath :: FilePath -> HeadTypeID -> HeadID -> FilePath +headPath spath tid (HeadID hid) = headTypePath spath tid </> U.toString hid + +loadHeads :: forall a m. MonadIO m => HeadType a => Storage -> m [Head a] +loadHeads s@(Storage { stBacking = StorageDir { dirPath = spath }}) = liftIO $ do +    let hpath = headTypePath spath $ headTypeID @a Proxy + +    files <- filterM (doesFileExist . (hpath </>)) =<< +        handleJust (\e -> guard (isDoesNotExistError e)) (const $ return []) +        (getDirectoryContents hpath) +    fmap catMaybes $ forM files $ \hname -> do +        case U.fromString hname of +             Just hid -> do +                 (h:_) <- BC.lines <$> B.readFile (hpath </> hname) +                 Just ref <- readRef s h +                 return $ Just $ Head (HeadID hid) $ wrappedLoad ref +             Nothing -> return Nothing +loadHeads Storage { stBacking = StorageMemory { memHeads = theads } } = liftIO $ do +    let toHead ((tid, hid), ref) | tid == headTypeID @a Proxy = Just $ Head hid $ wrappedLoad ref +                                 | otherwise                  = Nothing +    catMaybes . map toHead <$> readMVar theads + +loadHead :: forall a m. (HeadType a, MonadIO m) => Storage -> HeadID -> m (Maybe (Head a)) +loadHead s@(Storage { stBacking = StorageDir { dirPath = spath }}) hid = liftIO $ do +    handleJust (guard . isDoesNotExistError) (const $ return Nothing) $ do +        (h:_) <- BC.lines <$> B.readFile (headPath spath (headTypeID @a Proxy) hid) +        Just ref <- readRef s h +        return $ Just $ Head hid $ wrappedLoad ref +loadHead Storage { stBacking = StorageMemory { memHeads = theads } } hid = liftIO $ do +    fmap (Head hid . wrappedLoad) . lookup (headTypeID @a Proxy, hid) <$> readMVar theads + +reloadHead :: (HeadType a, MonadIO m) => Head a -> m (Maybe (Head a)) +reloadHead (Head hid (Stored (Ref st _) _)) = loadHead st hid + +storeHead :: forall a m. MonadIO m => HeadType a => Storage -> a -> m (Head a) +storeHead st obj = liftIO $ do +    let tid = headTypeID @a Proxy +    hid <- HeadID <$> U.nextRandom +    stored <- wrappedStore st obj +    case stBacking st of +         StorageDir { dirPath = spath } -> do +             Right () <- writeFileChecked (headPath spath tid hid) Nothing $ +                 showRef (storedRef stored) `B.append` BC.singleton '\n' +             return () +         StorageMemory { memHeads = theads } -> do +             modifyMVar_ theads $ return . (((tid, hid), storedRef stored) :) +    return $ Head hid stored + +replaceHead :: forall a m. (HeadType a, MonadIO m) => Head a -> Stored a -> m (Either (Maybe (Head a)) (Head a)) +replaceHead prev@(Head hid pobj) stored' = liftIO $ do +    let st = headStorage prev +        tid = headTypeID @a Proxy +    stored <- copyStored st stored' +    case stBacking st of +         StorageDir { dirPath = spath } -> do +             let filename = headPath spath tid hid +                 showRefL r = showRef r `B.append` BC.singleton '\n' + +             writeFileChecked filename (Just $ showRefL $ headRef prev) (showRefL $ storedRef stored) >>= \case +                 Left Nothing -> return $ Left Nothing +                 Left (Just bs) -> do Just oref <- readRef st $ BC.takeWhile (/='\n') bs +                                      return $ Left $ Just $ Head hid $ wrappedLoad oref +                 Right () -> return $ Right $ Head hid stored + +         StorageMemory { memHeads = theads, memWatchers = twatch } -> do +             res <- modifyMVar theads $ \hs -> do +                 ws <- map wlFun . filter ((==(tid, hid)) . wlHead) . wlList <$> readMVar twatch +                 return $ case partition ((==(tid, hid)) . fst) hs of +                     ([] , _  ) -> (hs, Left Nothing) +                     ((_, r):_, hs') | r == storedRef pobj -> (((tid, hid), storedRef stored) : hs', +                                                                  Right (Head hid stored, ws)) +                                     | otherwise -> (hs, Left $ Just $ Head hid $ wrappedLoad r) +             case res of +                  Right (h, ws) -> mapM_ ($ headRef h) ws >> return (Right h) +                  Left x -> return $ Left x + +updateHead :: (HeadType a, MonadIO m) => Head a -> (Stored a -> m (Stored a, b)) -> m (Maybe (Head a), b) +updateHead h f = do +    (o, x) <- f $ headStoredObject h +    replaceHead h o >>= \case +        Right h' -> return (Just h', x) +        Left Nothing -> return (Nothing, x) +        Left (Just h') -> updateHead h' f + +updateHead_ :: (HeadType a, MonadIO m) => Head a -> (Stored a -> m (Stored a)) -> m (Maybe (Head a)) +updateHead_ h = fmap fst . updateHead h . (fmap (,()) .) + + +data WatchedHead = forall a. WatchedHead Storage WatchID (MVar a) + +watchHead :: forall a. HeadType a => Head a -> (Head a -> IO ()) -> IO WatchedHead +watchHead h = watchHeadWith h id + +watchHeadWith :: forall a b. (HeadType a, Eq b) => Head a -> (Head a -> b) -> (b -> IO ()) -> IO WatchedHead +watchHeadWith oh@(Head hid (Stored (Ref st _) _)) sel cb = do +    memo <- newEmptyMVar +    let tid = headTypeID @a Proxy +        addWatcher wl = (wl', WatchedHead st (wlNext wl) memo) +            where wl' = wl { wlNext = wlNext wl + 1 +                           , wlList = WatchListItem +                               { wlID = wlNext wl +                               , wlHead = (tid, hid) +                               , wlFun = \r -> do +                                   let x = sel $ Head hid $ wrappedLoad r +                                   modifyMVar_ memo $ \prev -> do +                                       when (x /= prev) $ cb x +                                       return x +                               } : wlList wl +                           } + +    watched <- case stBacking st of +         StorageDir { dirPath = spath, dirWatchers = mvar } -> modifyMVar mvar $ \(ilist, wl) -> do +             ilist' <- case lookup tid ilist of +                 Just _ -> return ilist +                 Nothing -> do +                     inotify <- initINotify +                     void $ addWatch inotify [Move] (BC.pack $ headTypePath spath tid) $ \case +                         MovedIn { filePath = fpath } | Just ihid <- HeadID <$> U.fromASCIIBytes fpath -> do +                             loadHead @a st ihid >>= \case +                                 Just h -> mapM_ ($ headRef h) . map wlFun . filter ((== (tid, ihid)) . wlHead) . wlList . snd =<< readMVar mvar +                                 Nothing -> return () +                         _ -> return () +                     return $ (tid, inotify) : ilist +             return $ first (ilist',) $ addWatcher wl + +         StorageMemory { memWatchers = mvar } -> modifyMVar mvar $ return . addWatcher + +    cur <- sel . maybe oh id <$> reloadHead oh +    cb cur +    putMVar memo cur + +    return watched + +unwatchHead :: WatchedHead -> IO () +unwatchHead (WatchedHead st wid _) = do +    let delWatcher wl = wl { wlList = filter ((/=wid) . wlID) $ wlList wl } +    case stBacking st of +        StorageDir { dirWatchers = mvar } -> modifyMVar_ mvar $ return . second delWatcher +        StorageMemory { memWatchers = mvar } -> modifyMVar_ mvar $ return . delWatcher + + +class Monad m => MonadStorage m where +    getStorage :: m Storage +    mstore :: Storable a => a -> m (Stored a) + +    default mstore :: MonadIO m => Storable a => a -> m (Stored a) +    mstore x = do +        st <- getStorage +        wrappedStore st x + +instance MonadIO m => MonadStorage (ReaderT Storage m) where +    getStorage = ask + +instance MonadIO m => MonadStorage (ReaderT (Head a) m) where +    getStorage = asks $ headStorage + + +class Storable a where +    store' :: a -> Store +    load' :: Load a + +    store :: StorageCompleteness c => Storage' c -> a -> IO (Ref' c) +    store st = evalStore st . store' +    load :: Ref -> a +    load = evalLoad load' + +class Storable a => ZeroStorable a where +    fromZero :: Storage -> a + +data Store = StoreBlob ByteString +           | StoreRec (forall c. StorageCompleteness c => Storage' c -> [IO [(ByteString, RecItem' c)]]) +           | StoreZero + +evalStore :: StorageCompleteness c => Storage' c -> Store -> IO (Ref' c) +evalStore st = unsafeStoreObject st <=< evalStoreObject st + +evalStoreObject :: StorageCompleteness c => Storage' c -> Store -> IO (Object' c) +evalStoreObject _ (StoreBlob x) = return $ Blob x +evalStoreObject s (StoreRec f) = Rec . concat <$> sequence (f s) +evalStoreObject _ StoreZero = return ZeroObject + +newtype StoreRecM c a = StoreRecM (ReaderT (Storage' c) (Writer [IO [(ByteString, RecItem' c)]]) a) +    deriving (Functor, Applicative, Monad) + +type StoreRec c = StoreRecM c () + +newtype Load a = Load (ReaderT (Ref, Object) (Except String) a) +    deriving (Functor, Applicative, Alternative, Monad, MonadPlus, MonadError String) + +evalLoad :: Load a -> Ref -> a +evalLoad (Load f) ref = either (error {- TODO throw -} . ((BC.unpack (showRef ref) ++ ": ")++)) id $ runExcept $ runReaderT f (ref, lazyLoadObject ref) + +loadCurrentRef :: Load Ref +loadCurrentRef = Load $ asks fst + +loadCurrentObject :: Load Object +loadCurrentObject = Load $ asks snd + +newtype LoadRec a = LoadRec (ReaderT (Ref, [(ByteString, RecItem)]) (Except String) a) +    deriving (Functor, Applicative, Alternative, Monad, MonadPlus, MonadError String) + +loadRecCurrentRef :: LoadRec Ref +loadRecCurrentRef = LoadRec $ asks fst + +loadRecItems :: LoadRec [(ByteString, RecItem)] +loadRecItems = LoadRec $ asks snd + + +instance Storable Object where +    store' (Blob bs) = StoreBlob bs +    store' (Rec xs) = StoreRec $ \st -> return $ do +        Rec xs' <- copyObject st (Rec xs) +        return xs' +    store' ZeroObject = StoreZero + +    load' = loadCurrentObject + +    store st = unsafeStoreObject st <=< copyObject st +    load = lazyLoadObject + +instance Storable ByteString where +    store' = storeBlob +    load' = loadBlob id + +instance Storable a => Storable [a] where +    store' []     = storeZero +    store' (x:xs) = storeRec $ do +        storeRef "i" x +        storeRef "n" xs + +    load' = loadCurrentObject >>= \case +                ZeroObject -> return [] +                _          -> loadRec $ (:) +                                  <$> loadRef "i" +                                  <*> loadRef "n" + +instance Storable a => ZeroStorable [a] where +    fromZero _ = [] + + +storeBlob :: ByteString -> Store +storeBlob = StoreBlob + +storeRec :: (forall c. StorageCompleteness c => StoreRec c) -> Store +storeRec sr = StoreRec $ do +    let StoreRecM r = sr +    execWriter . runReaderT r + +storeZero :: Store +storeZero = StoreZero + + +class StorableText a where +    toText :: a -> Text +    fromText :: MonadError String m => Text -> m a + +instance StorableText Text where +    toText = id; fromText = return + +instance StorableText [Char] where +    toText = T.pack; fromText = return . T.unpack + + +class StorableDate a where +    toDate :: a -> ZonedTime +    fromDate :: ZonedTime -> a + +instance StorableDate ZonedTime where +    toDate = id; fromDate = id + +instance StorableDate UTCTime where +    toDate = utcToZonedTime utc +    fromDate = zonedTimeToUTC + +instance StorableDate Day where +    toDate day = toDate $ UTCTime day 0 +    fromDate = utctDay . fromDate + + +class StorableUUID a where +    toUUID :: a -> UUID +    fromUUID :: UUID -> a + +instance StorableUUID UUID where +    toUUID = id; fromUUID = id + + +storeEmpty :: String -> StoreRec c +storeEmpty name = StoreRecM $ tell [return [(BC.pack name, RecEmpty)]] + +storeMbEmpty :: String -> Maybe () -> StoreRec c +storeMbEmpty name = maybe (return ()) (const $ storeEmpty name) + +storeInt :: Integral a => String -> a -> StoreRec c +storeInt name x = StoreRecM $ tell [return [(BC.pack name, RecInt $ toInteger x)]] + +storeMbInt :: Integral a => String -> Maybe a -> StoreRec c +storeMbInt name = maybe (return ()) (storeInt name) + +storeNum :: (Real a, Fractional a) => String -> a -> StoreRec c +storeNum name x = StoreRecM $ tell [return [(BC.pack name, RecNum $ toRational x)]] + +storeMbNum :: (Real a, Fractional a) => String -> Maybe a -> StoreRec c +storeMbNum name = maybe (return ()) (storeNum name) + +storeText :: StorableText a => String -> a -> StoreRec c +storeText name x = StoreRecM $ tell [return [(BC.pack name, RecText $ toText x)]] + +storeMbText :: StorableText a => String -> Maybe a -> StoreRec c +storeMbText name = maybe (return ()) (storeText name) + +storeBinary :: BA.ByteArrayAccess a => String -> a -> StoreRec c +storeBinary name x = StoreRecM $ tell [return [(BC.pack name, RecBinary $ BA.convert x)]] + +storeMbBinary :: BA.ByteArrayAccess a => String -> Maybe a -> StoreRec c +storeMbBinary name = maybe (return ()) (storeBinary name) + +storeDate :: StorableDate a => String -> a -> StoreRec c +storeDate name x = StoreRecM $ tell [return [(BC.pack name, RecDate $ toDate x)]] + +storeMbDate :: StorableDate a => String -> Maybe a -> StoreRec c +storeMbDate name = maybe (return ()) (storeDate name) + +storeUUID :: StorableUUID a => String -> a -> StoreRec c +storeUUID name x = StoreRecM $ tell [return [(BC.pack name, RecUUID $ toUUID x)]] + +storeMbUUID :: StorableUUID a => String -> Maybe a -> StoreRec c +storeMbUUID name = maybe (return ()) (storeUUID name) + +storeRef :: Storable a => StorageCompleteness c => String -> a -> StoreRec c +storeRef name x = StoreRecM $ do +    s <- ask +    tell $ (:[]) $ do +        ref <- store s x +        return [(BC.pack name, RecRef ref)] + +storeMbRef :: Storable a => StorageCompleteness c => String -> Maybe a -> StoreRec c +storeMbRef name = maybe (return ()) (storeRef name) + +storeRawRef :: StorageCompleteness c => String -> Ref -> StoreRec c +storeRawRef name ref = StoreRecM $ do +    st <- ask +    tell $ (:[]) $ do +        ref' <- copyRef st ref +        return [(BC.pack name, RecRef ref')] + +storeMbRawRef :: StorageCompleteness c => String -> Maybe Ref -> StoreRec c +storeMbRawRef name = maybe (return ()) (storeRawRef name) + +storeZRef :: (ZeroStorable a, StorageCompleteness c) => String -> a -> StoreRec c +storeZRef name x = StoreRecM $ do +    s <- ask +    tell $ (:[]) $ do +        ref <- store s x +        return $ if isZeroRef ref then [] +                                  else [(BC.pack name, RecRef ref)] + + +loadBlob :: (ByteString -> a) -> Load a +loadBlob f = loadCurrentObject >>= \case +    Blob x -> return $ f x +    _      -> throwError "Expecting blob" + +loadRec :: LoadRec a -> Load a +loadRec (LoadRec lrec) = loadCurrentObject >>= \case +    Rec rs -> do +        ref <- loadCurrentRef +        either throwError return $ runExcept $ runReaderT lrec (ref, rs) +    _ -> throwError "Expecting record" + +loadZero :: a -> Load a +loadZero x = loadCurrentObject >>= \case +    ZeroObject -> return x +    _          -> throwError "Expecting zero" + + +loadEmpty :: String -> LoadRec () +loadEmpty name = maybe (throwError $ "Missing record item '"++name++"'") return =<< loadMbEmpty name + +loadMbEmpty :: String -> LoadRec (Maybe ()) +loadMbEmpty name = (lookup (BC.pack name) <$> loadRecItems) >>= \case +    Nothing -> return Nothing +    Just (RecEmpty) -> return (Just ()) +    Just _ -> throwError $ "Expecting type int of record item '"++name++"'" + +loadInt :: Num a => String -> LoadRec a +loadInt name = maybe (throwError $ "Missing record item '"++name++"'") return =<< loadMbInt name + +loadMbInt :: Num a => String -> LoadRec (Maybe a) +loadMbInt name = (lookup (BC.pack name) <$> loadRecItems) >>= \case +    Nothing -> return Nothing +    Just (RecInt x) -> return (Just $ fromInteger x) +    Just _ -> throwError $ "Expecting type int of record item '"++name++"'" + +loadNum :: (Real a, Fractional a) => String -> LoadRec a +loadNum name = maybe (throwError $ "Missing record item '"++name++"'") return =<< loadMbNum name + +loadMbNum :: (Real a, Fractional a) => String -> LoadRec (Maybe a) +loadMbNum name = (lookup (BC.pack name) <$> loadRecItems) >>= \case +    Nothing -> return Nothing +    Just (RecNum x) -> return (Just $ fromRational x) +    Just _ -> throwError $ "Expecting type number of record item '"++name++"'" + +loadText :: StorableText a => String -> LoadRec a +loadText name = maybe (throwError $ "Missing record item '"++name++"'") return =<< loadMbText name + +loadMbText :: StorableText a => String -> LoadRec (Maybe a) +loadMbText name = (lookup (BC.pack name) <$> loadRecItems) >>= \case +    Nothing -> return Nothing +    Just (RecText x) -> Just <$> fromText x +    Just _ -> throwError $ "Expecting type text of record item '"++name++"'" + +loadTexts :: StorableText a => String -> LoadRec [a] +loadTexts name = do +    items <- map snd . filter ((BC.pack name ==) . fst) <$> loadRecItems +    forM items $ \case RecText x -> fromText x +                       _ -> throwError $ "Expecting type text of record item '"++name++"'" + +loadBinary :: BA.ByteArray a => String -> LoadRec a +loadBinary name = maybe (throwError $ "Missing record item '"++name++"'") return =<< loadMbBinary name + +loadMbBinary :: BA.ByteArray a => String -> LoadRec (Maybe a) +loadMbBinary name = (lookup (BC.pack name) <$> loadRecItems) >>= \case +    Nothing -> return Nothing +    Just (RecBinary x) -> return $ Just $ BA.convert x +    Just _ -> throwError $ "Expecting type binary of record item '"++name++"'" + +loadBinaries :: BA.ByteArray a => String -> LoadRec [a] +loadBinaries name = do +    items <- map snd . filter ((BC.pack name ==) . fst) <$> loadRecItems +    forM items $ \case RecBinary x -> return $ BA.convert x +                       _ -> throwError $ "Expecting type binary of record item '"++name++"'" + +loadDate :: StorableDate a => String -> LoadRec a +loadDate name = maybe (throwError $ "Missing record item '"++name++"'") return =<< loadMbDate name + +loadMbDate :: StorableDate a => String -> LoadRec (Maybe a) +loadMbDate name = (lookup (BC.pack name) <$> loadRecItems) >>= \case +    Nothing -> return Nothing +    Just (RecDate x) -> return $ Just $ fromDate x +    Just _ -> throwError $ "Expecting type date of record item '"++name++"'" + +loadUUID :: StorableUUID a => String -> LoadRec a +loadUUID name = maybe (throwError $ "Missing record iteem '"++name++"'") return =<< loadMbUUID name + +loadMbUUID :: StorableUUID a => String -> LoadRec (Maybe a) +loadMbUUID name = (lookup (BC.pack name) <$> loadRecItems) >>= \case +    Nothing -> return Nothing +    Just (RecUUID x) -> return $ Just $ fromUUID x +    Just _ -> throwError $ "Expecting type UUID of record item '"++name++"'" + +loadRawRef :: String -> LoadRec Ref +loadRawRef name = maybe (throwError $ "Missing record item '"++name++"'") return =<< loadMbRawRef name + +loadMbRawRef :: String -> LoadRec (Maybe Ref) +loadMbRawRef name = (lookup (BC.pack name) <$> loadRecItems) >>= \case +    Nothing -> return Nothing +    Just (RecRef x) -> return (Just x) +    Just _ -> throwError $ "Expecting type ref of record item '"++name++"'" + +loadRawRefs :: String -> LoadRec [Ref] +loadRawRefs name = do +    items <- map snd . filter ((BC.pack name ==) . fst) <$> loadRecItems +    forM items $ \case RecRef x -> return x +                       _ -> throwError $ "Expecting type ref of record item '"++name++"'" + +loadRef :: Storable a => String -> LoadRec a +loadRef name = load <$> loadRawRef name + +loadMbRef :: Storable a => String -> LoadRec (Maybe a) +loadMbRef name = fmap load <$> loadMbRawRef name + +loadRefs :: Storable a => String -> LoadRec [a] +loadRefs name = map load <$> loadRawRefs name + +loadZRef :: ZeroStorable a => String -> LoadRec a +loadZRef name = loadMbRef name >>= \case +                    Nothing -> do Ref st _ <- loadRecCurrentRef +                                  return $ fromZero st +                    Just x  -> return x + + +type Stored a = Stored' Complete a + +instance Storable a => Storable (Stored a) where +    store st = copyRef st . storedRef +    store' (Stored _ x) = store' x +    load' = Stored <$> loadCurrentRef <*> load' + +instance ZeroStorable a => ZeroStorable (Stored a) where +    fromZero st = Stored (zeroRef st) $ fromZero st + +fromStored :: Stored a -> a +fromStored (Stored _ x) = x + +storedRef :: Stored a -> Ref +storedRef (Stored ref _) = ref + +wrappedStore :: MonadIO m => Storable a => Storage -> a -> m (Stored a) +wrappedStore st x = do ref <- liftIO $ store st x +                       return $ Stored ref x + +wrappedLoad :: Storable a => Ref -> Stored a +wrappedLoad ref = Stored ref (load ref) + +copyStored :: forall c c' m a. (StorageCompleteness c, StorageCompleteness c', MonadIO m) => +    Storage' c' -> Stored' c a -> m (LoadResult c (Stored' c' a)) +copyStored st (Stored ref' x) = liftIO $ returnLoadResult . fmap (flip Stored x) <$> copyRef' st ref' + +-- |Passed function needs to preserve the object representation to be safe +unsafeMapStored :: (a -> b) -> Stored a -> Stored b +unsafeMapStored f (Stored ref x) = Stored ref (f x) + + +data StoreInfo = StoreInfo +    { infoDate :: ZonedTime +    , infoNote :: Maybe Text +    } +    deriving (Show) + +makeStoreInfo :: IO StoreInfo +makeStoreInfo = StoreInfo +    <$> getZonedTime +    <*> pure Nothing + +storeInfoRec :: StoreInfo -> StoreRec c +storeInfoRec info = do +    storeDate "date" $ infoDate info +    storeMbText "note" $ infoNote info + +loadInfoRec :: LoadRec StoreInfo +loadInfoRec = StoreInfo +    <$> loadDate "date" +    <*> loadMbText "note" + + +data History a = History StoreInfo (Stored a) (Maybe (StoredHistory a)) +    deriving (Show) + +type StoredHistory a = Stored (History a) + +instance Storable a => Storable (History a) where +    store' (History si x prev) = storeRec $ do +        storeInfoRec si +        storeMbRef "prev" prev +        storeRef "item" x + +    load' = loadRec $ History +        <$> loadInfoRec +        <*> loadRef "item" +        <*> loadMbRef "prev" + +fromHistory :: StoredHistory a -> a +fromHistory = fromStored . storedFromHistory + +fromHistoryAt :: ZonedTime -> StoredHistory a -> Maybe a +fromHistoryAt zat = fmap (fromStored . snd) . listToMaybe . dropWhile ((at<) . zonedTimeToUTC . fst) . storedHistoryTimedList +    where at = zonedTimeToUTC zat + +storedFromHistory :: StoredHistory a -> Stored a +storedFromHistory sh = let History _ item _ = fromStored sh +                        in item + +storedHistoryList :: StoredHistory a -> [Stored a] +storedHistoryList = map snd . storedHistoryTimedList + +storedHistoryTimedList :: StoredHistory a -> [(ZonedTime, Stored a)] +storedHistoryTimedList sh = let History hinfo item prev = fromStored sh +                             in (infoDate hinfo, item) : maybe [] storedHistoryTimedList prev + +beginHistory :: Storable a => Storage -> StoreInfo -> a -> IO (StoredHistory a) +beginHistory st si x = do sx <- wrappedStore st x +                          wrappedStore st $ History si sx Nothing + +modifyHistory :: Storable a => StoreInfo -> (a -> a) -> StoredHistory a -> IO (StoredHistory a) +modifyHistory si f prev@(Stored (Ref st _) _) = do +    sx <- wrappedStore st $ f $ fromHistory prev +    wrappedStore st $ History si sx (Just prev) + + +showRatio :: Rational -> String +showRatio r = case decimalRatio r of +                   Just (n, 1) -> show n +                   Just (n', d) -> let n = abs n' +                                    in (if n' < 0 then "-" else "") ++ show (n `div` d) ++ "." ++ +                                       (concatMap (show.(`mod` 10).snd) $ reverse $ takeWhile ((>1).fst) $ zip (iterate (`div` 10) d) (iterate (`div` 10) (n `mod` d))) +                   Nothing -> show (numerator r) ++ "/" ++ show (denominator r) + +decimalRatio :: Rational -> Maybe (Integer, Integer) +decimalRatio r = do +    let n = numerator r +        d = denominator r +        (c2, d') = takeFactors 2 d +        (c5, d'') = takeFactors 5 d' +    guard $ d'' == 1 +    let m = if c2 > c5 then 5 ^ (c2 - c5) +                       else 2 ^ (c5 - c2) +    return (n * m, d * m) + +takeFactors :: Integer -> Integer -> (Integer, Integer) +takeFactors f n | n `mod` f == 0 = let (c, n') = takeFactors f (n `div` f) +                                    in (c+1, n') +                | otherwise = (0, n) + +parseRatio :: ByteString -> Maybe Rational +parseRatio bs = case BC.groupBy ((==) `on` isNumber) bs of +                     (m:xs) | m == BC.pack "-" -> negate <$> positive xs +                     xs                        -> positive xs +    where positive = \case +              [bx] -> fromInteger . fst <$> BC.readInteger bx +              [bx, op, by] -> do +                  (x, _) <- BC.readInteger bx +                  (y, _) <- BC.readInteger by +                  case BC.unpack op of +                       "." -> return $ (x % 1) + (y % (10 ^ BC.length by)) +                       "/" -> return $ x % y +                       _   -> Nothing +              _ -> Nothing diff --git a/src/Erebos/Storage/Internal.hs b/src/Erebos/Storage/Internal.hs new file mode 100644 index 0000000..a61e705 --- /dev/null +++ b/src/Erebos/Storage/Internal.hs @@ -0,0 +1,282 @@ +{-# LANGUAGE CPP #-} + +module Erebos.Storage.Internal where + +import Codec.Compression.Zlib + +import Control.Arrow +import Control.Concurrent +import Control.DeepSeq +import Control.Exception +import Control.Monad +import Control.Monad.Identity + +import Crypto.Hash + +import Data.Bits +import Data.ByteArray (ByteArray, ByteArrayAccess, ScrubbedBytes) +import qualified Data.ByteArray as BA +import Data.ByteString (ByteString) +import qualified Data.ByteString as B +import qualified Data.ByteString.Char8 as BC +import qualified Data.ByteString.Lazy as BL +import Data.Char +import Data.Function +import Data.Hashable +import qualified Data.HashTable.IO as HT +import Data.Kind +import Data.List +import Data.Map (Map) +import qualified Data.Map as M +import Data.UUID (UUID) + +import Foreign.Storable (peek) + +import System.Directory +import System.FilePath +import System.INotify (INotify) +import System.IO +import System.IO.Error +import System.IO.Unsafe (unsafePerformIO) +import System.Posix.Files +import System.Posix.IO + + +data Storage' c = Storage +    { stBacking :: StorageBacking c +    , stParent :: Maybe (Storage' Identity) +    , stRefGeneration :: MVar (HT.BasicHashTable RefDigest Generation) +    , stRefRoots :: MVar (HT.BasicHashTable RefDigest [RefDigest]) +    } + +instance Eq (Storage' c) where +    (==) = (==) `on` (stBacking &&& stParent) + +instance Show (Storage' c) where +    show st@(Storage { stBacking = StorageDir { dirPath = path }}) = "dir" ++ showParentStorage st ++ ":" ++ path +    show st@(Storage { stBacking = StorageMemory {} }) = "mem" ++ showParentStorage st + +showParentStorage :: Storage' c -> String +showParentStorage Storage { stParent = Nothing } = "" +showParentStorage Storage { stParent = Just st } = "@" ++ show st + +data StorageBacking c +         = StorageDir { dirPath :: FilePath +                      , dirWatchers :: MVar ([(HeadTypeID, INotify)], WatchList c) +                      } +         | StorageMemory { memHeads :: MVar [((HeadTypeID, HeadID), Ref' c)] +                         , memObjs :: MVar (Map RefDigest BL.ByteString) +                         , memKeys :: MVar (Map RefDigest ScrubbedBytes) +                         , memWatchers :: MVar (WatchList c) +                         } +    deriving (Eq) + +newtype WatchID = WatchID Int +    deriving (Eq, Ord, Num) + +data WatchList c = WatchList +    { wlNext :: WatchID +    , wlList :: [WatchListItem c] +    } + +data WatchListItem c = WatchListItem +    { wlID :: WatchID +    , wlHead :: (HeadTypeID, HeadID) +    , wlFun :: Ref' c -> IO () +    } + + +newtype RefDigest = RefDigest (Digest Blake2b_256) +    deriving (Eq, Ord, NFData, ByteArrayAccess) + +instance Show RefDigest where +    show = BC.unpack . showRefDigest + +data Ref' c = Ref (Storage' c) RefDigest + +instance Eq (Ref' c) where +    Ref _ d1 == Ref _ d2  =  d1 == d2 + +instance Show (Ref' c) where +    show ref@(Ref st _) = show st ++ ":" ++ BC.unpack (showRef ref) + +instance ByteArrayAccess (Ref' c) where +    length (Ref _ dgst) = BA.length dgst +    withByteArray (Ref _ dgst) = BA.withByteArray dgst + +instance Hashable RefDigest where +    hashWithSalt salt ref = salt `xor` unsafePerformIO (BA.withByteArray ref peek) + +instance Hashable (Ref' c) where +    hashWithSalt salt ref = salt `xor` unsafePerformIO (BA.withByteArray ref peek) + +refStorage :: Ref' c -> Storage' c +refStorage (Ref st _) = st + +refDigest :: Ref' c -> RefDigest +refDigest (Ref _ dgst) = dgst + +showRef :: Ref' c -> ByteString +showRef = showRefDigest . refDigest + +showRefDigestParts :: RefDigest -> (ByteString, ByteString) +showRefDigestParts x = (BC.pack "blake2", showHex x) + +showRefDigest :: RefDigest -> ByteString +showRefDigest = showRefDigestParts >>> \(alg, hex) -> alg <> BC.pack "#" <> hex + +readRefDigest :: ByteString -> Maybe RefDigest +readRefDigest x = case BC.split '#' x of +                       [alg, dgst] | BA.convert alg == BC.pack "blake2" -> +                           refDigestFromByteString =<< readHex @ByteString dgst +                       _ -> Nothing + +refDigestFromByteString :: ByteArrayAccess ba => ba -> Maybe RefDigest +refDigestFromByteString = fmap RefDigest . digestFromByteString + +hashToRefDigest :: BL.ByteString -> RefDigest +hashToRefDigest = RefDigest . hashFinalize . hashUpdates hashInit . BL.toChunks + +showHex :: ByteArrayAccess ba => ba -> ByteString +showHex = B.concat . map showHexByte . BA.unpack +    where showHexChar x | x < 10    = x + o '0' +                        | otherwise = x + o 'a' - 10 +          showHexByte x = B.pack [ showHexChar (x `div` 16), showHexChar (x `mod` 16) ] +          o = fromIntegral . ord + +readHex :: ByteArray ba => ByteString -> Maybe ba +readHex = return . BA.concat <=< readHex' +    where readHex' bs | B.null bs = Just [] +          readHex' bs = do (bx, bs') <- B.uncons bs +                           (by, bs'') <- B.uncons bs' +                           x <- hexDigit bx +                           y <- hexDigit by +                           (B.singleton (x * 16 + y) :) <$> readHex' bs'' +          hexDigit x | x >= o '0' && x <= o '9' = Just $ x - o '0' +                     | x >= o 'a' && x <= o 'z' = Just $ x - o 'a' + 10 +                     | otherwise                = Nothing +          o = fromIntegral . ord + + +newtype Generation = Generation Int +    deriving (Eq, Show) + +data Head' c a = Head HeadID (Stored' c a) +    deriving (Eq, Show) + +newtype HeadID = HeadID UUID +    deriving (Eq, Ord, Show) + +newtype HeadTypeID = HeadTypeID UUID +    deriving (Eq, Ord) + +data Stored' c a = Stored (Ref' c) a +    deriving (Show) + +instance Eq (Stored' c a) where +    Stored r1 _ == Stored r2 _  =  refDigest r1 == refDigest r2 + +instance Ord (Stored' c a) where +    compare (Stored r1 _) (Stored r2 _) = compare (refDigest r1) (refDigest r2) + +storedStorage :: Stored' c a -> Storage' c +storedStorage (Stored (Ref st _) _) = st + + +type Complete = Identity +type Partial = Either RefDigest + +class (Traversable compl, Monad compl) => StorageCompleteness compl where +    type LoadResult compl a :: Type +    returnLoadResult :: compl a -> LoadResult compl a +    ioLoadBytes :: Ref' compl -> IO (compl BL.ByteString) + +instance StorageCompleteness Complete where +    type LoadResult Complete a = a +    returnLoadResult = runIdentity +    ioLoadBytes ref@(Ref st dgst) = maybe (error $ "Ref not found in complete storage: "++show ref) Identity +        <$> ioLoadBytesFromStorage st dgst + +instance StorageCompleteness Partial where +    type LoadResult Partial a = Either RefDigest a +    returnLoadResult = id +    ioLoadBytes (Ref st dgst) = maybe (Left dgst) Right <$> ioLoadBytesFromStorage st dgst + +unsafeStoreRawBytes :: Storage' c -> BL.ByteString -> IO (Ref' c) +unsafeStoreRawBytes st raw = do +    let dgst = hashToRefDigest raw +    case stBacking st of +         StorageDir { dirPath = sdir } -> writeFileOnce (refPath sdir dgst) $ compress raw +         StorageMemory { memObjs = tobjs } -> +             dgst `deepseq` -- the TVar may be accessed when evaluating the data to be written +                 modifyMVar_ tobjs (return . M.insert dgst raw) +    return $ Ref st dgst + +ioLoadBytesFromStorage :: Storage' c -> RefDigest -> IO (Maybe BL.ByteString) +ioLoadBytesFromStorage st dgst = loadCurrent st >>= +    \case Just bytes -> return $ Just bytes +          Nothing | Just parent <- stParent st -> ioLoadBytesFromStorage parent dgst +                  | otherwise                  -> return Nothing +    where loadCurrent Storage { stBacking = StorageDir { dirPath = spath } } = handleJust (guard . isDoesNotExistError) (const $ return Nothing) $ +              Just . decompress . BL.fromChunks . (:[]) <$> (B.readFile $ refPath spath dgst) +          loadCurrent Storage { stBacking = StorageMemory { memObjs = tobjs } } = M.lookup dgst <$> readMVar tobjs + +refPath :: FilePath -> RefDigest -> FilePath +refPath spath rdgst = intercalate "/" [spath, "objects", BC.unpack alg, pref, rest] +    where (alg, dgst) = showRefDigestParts rdgst +          (pref, rest) = splitAt 2 $ BC.unpack dgst + + +openLockFile :: FilePath -> IO Handle +openLockFile path = do +    createDirectoryIfMissing True (takeDirectory path) +    fd <- retry 10 $ +#if MIN_VERSION_unix(2,8,0) +        openFd path WriteOnly defaultFileFlags +            { creat = Just $ unionFileModes ownerReadMode ownerWriteMode +            , exclusive = True +            } +#else +        openFd path WriteOnly (Just $ unionFileModes ownerReadMode ownerWriteMode) (defaultFileFlags { exclusive = True }) +#endif +    fdToHandle fd +  where +    retry :: Int -> IO a -> IO a +    retry 0 act = act +    retry n act = catchJust (\e -> if isAlreadyExistsError e then Just () else Nothing) +                      act (\_ -> threadDelay (100 * 1000) >> retry (n - 1) act) + +writeFileOnce :: FilePath -> BL.ByteString -> IO () +writeFileOnce file content = bracket (openLockFile locked) +    hClose $ \h -> do +        fileExist file >>= \case +            True  -> removeLink locked +            False -> do BL.hPut h content +                        hFlush h +                        rename locked file +    where locked = file ++ ".lock" + +writeFileChecked :: FilePath -> Maybe ByteString -> ByteString -> IO (Either (Maybe ByteString) ()) +writeFileChecked file prev content = bracket (openLockFile locked) +    hClose $ \h -> do +        (prev,) <$> fileExist file >>= \case +            (Nothing, True) -> do +                current <- B.readFile file +                removeLink locked +                return $ Left $ Just current +            (Nothing, False) -> do B.hPut h content +                                   hFlush h +                                   rename locked file +                                   return $ Right () +            (Just expected, True) -> do +                current <- B.readFile file +                if current == expected then do B.hPut h content +                                               hFlush h +                                               rename locked file +                                               return $ return () +                                       else do removeLink locked +                                               return $ Left $ Just current +            (Just _, False) -> do +                removeLink locked +                return $ Left Nothing +    where locked = file ++ ".lock" diff --git a/src/Erebos/Storage/Key.hs b/src/Erebos/Storage/Key.hs new file mode 100644 index 0000000..b6afc20 --- /dev/null +++ b/src/Erebos/Storage/Key.hs @@ -0,0 +1,85 @@ +module Erebos.Storage.Key ( +    KeyPair(..), +    storeKey, loadKey, loadKeyMb, +    moveKeys, +) where + +import Control.Concurrent.MVar +import Control.Monad +import Control.Monad.Except +import Control.Monad.IO.Class + +import Data.ByteArray +import qualified Data.ByteString.Char8 as BC +import qualified Data.ByteString.Lazy as BL +import qualified Data.Map as M + +import System.Directory +import System.FilePath +import System.IO.Error + +import Erebos.Storage +import Erebos.Storage.Internal + +class Storable pub => KeyPair sec pub | sec -> pub, pub -> sec where +    generateKeys :: Storage -> IO (sec, Stored pub) +    keyGetPublic :: sec -> Stored pub +    keyGetData :: sec -> ScrubbedBytes +    keyFromData :: ScrubbedBytes -> Stored pub -> Maybe sec + + +keyFilePath :: KeyPair sec pub => FilePath -> Stored pub -> FilePath +keyFilePath sdir pkey = sdir </> "keys" </> (BC.unpack $ showRef $ storedRef pkey) + +storeKey :: KeyPair sec pub => sec -> IO () +storeKey key = do +    let spub = keyGetPublic key +    case stBacking $ storedStorage spub of +         StorageDir { dirPath = dir } -> writeFileOnce (keyFilePath dir spub) (BL.fromStrict $ convert $ keyGetData key) +         StorageMemory { memKeys = kstore } -> modifyMVar_ kstore $ return . M.insert (refDigest $ storedRef spub) (keyGetData key) + +loadKey :: (KeyPair sec pub, MonadIO m, MonadError String m) => Stored pub -> m sec +loadKey pub = maybe (throwError $ "secret key not found for " <> show (storedRef pub)) return =<< loadKeyMb pub + +loadKeyMb :: (KeyPair sec pub, MonadIO m) => Stored pub -> m (Maybe sec) +loadKeyMb spub = liftIO $ run $ storedStorage spub +  where +    run st = tryOneLevel (stBacking st) >>= \case +        key@Just {} -> return key +        Nothing | Just parent <- stParent st -> run parent +                | otherwise -> return Nothing +    tryOneLevel = \case +        StorageDir { dirPath = dir } -> tryIOError (BC.readFile (keyFilePath dir spub)) >>= \case +            Right kdata -> return $ keyFromData (convert kdata) spub +            Left _ -> return Nothing +        StorageMemory { memKeys = kstore } -> (flip keyFromData spub <=< M.lookup (refDigest $ storedRef spub)) <$> readMVar kstore + +moveKeys :: MonadIO m => Storage -> Storage -> m () +moveKeys from to = liftIO $ do +    case (stBacking from, stBacking to) of +        (StorageDir { dirPath = fromPath }, StorageDir { dirPath = toPath }) -> do +            files <- listDirectory (fromPath </> "keys") +            forM_ files $ \file -> do +                renameFile (fromPath </> "keys" </> file) (toPath </> "keys" </> file) + +        (StorageDir { dirPath = fromPath }, StorageMemory { memKeys = toKeys }) -> do +            let move m file +                    | Just dgst <- readRefDigest (BC.pack file) = do +                        let path = fromPath </> "keys" </> file +                        key <- convert <$> BC.readFile path +                        removeFile path +                        return $ M.insert dgst key m +                    | otherwise = return m +            files <- listDirectory (fromPath </> "keys") +            modifyMVar_ toKeys $ \keys -> foldM move keys files + +        (StorageMemory { memKeys = fromKeys }, StorageDir { dirPath = toPath }) -> do +            modifyMVar_ fromKeys $ \keys -> do +                forM_ (M.assocs keys) $ \(dgst, key) -> +                    writeFileOnce (toPath </> "keys" </> (BC.unpack $ showRefDigest dgst)) (BL.fromStrict $ convert key) +                return M.empty + +        (StorageMemory { memKeys = fromKeys }, StorageMemory { memKeys = toKeys }) -> do +            modifyMVar_ fromKeys $ \fkeys -> do +                modifyMVar_ toKeys $ return . M.union fkeys +                return M.empty diff --git a/src/Erebos/Storage/List.hs b/src/Erebos/Storage/List.hs new file mode 100644 index 0000000..f0f8786 --- /dev/null +++ b/src/Erebos/Storage/List.hs @@ -0,0 +1,154 @@ +module Erebos.Storage.List ( +    StoredList, +    emptySList, fromSList, storedFromSList, +    slistAdd, slistAddS, +    -- TODO slistInsert, slistInsertS, +    slistRemove, slistReplace, slistReplaceS, +    -- TODO mapFromSList, updateOld, + +    -- TODO StoreUpdate(..), +    -- TODO withStoredListItem, withStoredListItemS, +) where + +import Data.List +import Data.Maybe +import qualified Data.Set as S + +import Erebos.Storage +import Erebos.Storage.Internal +import Erebos.Storage.Merge + +data List a = ListNil +            | ListItem { listPrev :: [StoredList a] +                       , listItem :: Maybe (Stored a) +                       , listRemove :: Maybe (Stored (List a)) +                       } + +type StoredList a = Stored (List a) + +instance Storable a => Storable (List a) where +    store' ListNil = storeZero +    store' x@ListItem {} = storeRec $ do +        mapM_ (storeRef "PREV") $ listPrev x +        mapM_ (storeRef "item") $ listItem x +        mapM_ (storeRef "remove") $ listRemove x + +    load' = loadCurrentObject >>= \case +        ZeroObject -> return ListNil +        _ -> loadRec $ ListItem <$> loadRefs "PREV" +                                <*> loadMbRef "item" +                                <*> loadMbRef "remove" + +instance Storable a => ZeroStorable (List a) where +    fromZero _ = ListNil + + +emptySList :: Storable a => Storage -> IO (StoredList a) +emptySList st = wrappedStore st ListNil + +groupsFromSLists :: forall a. Storable a => StoredList a -> [[Stored a]] +groupsFromSLists = helperSelect S.empty . (:[]) +  where +    helperSelect :: S.Set (StoredList a) -> [StoredList a] -> [[Stored a]] +    helperSelect rs xxs | x:xs <- sort $ filterRemoved rs xxs = helper rs x xs +                        | otherwise = [] + +    helper :: S.Set (StoredList a) -> StoredList a -> [StoredList a] -> [[Stored a]] +    helper rs x xs +        | ListNil <- fromStored x +        = [] + +        | Just rm <- listRemove (fromStored x) +        , ans <- ancestors [x] +        , (other, collision) <- partition (S.null . S.intersection ans . ancestors . (:[])) xs +        , cont <- helperSelect (rs `S.union` ancestors [rm]) $ concatMap (listPrev . fromStored) (x : collision) ++ other +        = case catMaybes $ map (listItem . fromStored) (x : collision) of +               [] -> cont +               xis -> xis : cont + +        | otherwise = case listItem (fromStored x) of +                           Nothing -> helperSelect rs $ listPrev (fromStored x) ++ xs +                           Just xi -> [xi] : (helperSelect rs $ listPrev (fromStored x) ++ xs) + +    filterRemoved :: S.Set (StoredList a) -> [StoredList a] -> [StoredList a] +    filterRemoved rs = filter (S.null . S.intersection rs . ancestors . (:[])) + +fromSList :: Mergeable a => StoredList (Component a) -> [a] +fromSList = map merge . groupsFromSLists + +storedFromSList :: (Mergeable a, Storable a) => StoredList (Component a) -> IO [Stored a] +storedFromSList = mapM storeMerge . groupsFromSLists + +slistAdd :: Storable a => a -> StoredList a -> IO (StoredList a) +slistAdd x prev@(Stored (Ref st _) _) = do +    sx <- wrappedStore st x +    slistAddS sx prev + +slistAddS :: Storable a => Stored a -> StoredList a -> IO (StoredList a) +slistAddS sx prev@(Stored (Ref st _) _) = wrappedStore st (ListItem [prev] (Just sx) Nothing) + +{- TODO +slistInsert :: Storable a => Stored a -> a -> StoredList a -> IO (StoredList a) +slistInsert after x prev@(Stored (Ref st _) _) = do +    sx <- wrappedStore st x +    slistInsertS after sx prev + +slistInsertS :: Storable a => Stored a -> Stored a -> StoredList a -> IO (StoredList a) +slistInsertS after sx prev@(Stored (Ref st _) _) = wrappedStore st $ ListItem Nothing (findSListRef after prev) (Just sx) prev +-} + +slistRemove :: Storable a => Stored a -> StoredList a -> IO (StoredList a) +slistRemove rm prev@(Stored (Ref st _) _) = wrappedStore st $ ListItem [prev] Nothing (findSListRef rm prev) + +slistReplace :: Storable a => Stored a -> a -> StoredList a -> IO (StoredList a) +slistReplace rm x prev@(Stored (Ref st _) _) = do +    sx <- wrappedStore st x +    slistReplaceS rm sx prev + +slistReplaceS :: Storable a => Stored a -> Stored a -> StoredList a -> IO (StoredList a) +slistReplaceS rm sx prev@(Stored (Ref st _) _) = wrappedStore st $ ListItem [prev] (Just sx) (findSListRef rm prev) + +findSListRef :: Stored a -> StoredList a -> Maybe (StoredList a) +findSListRef _ (Stored _ ListNil) = Nothing +findSListRef x cur | listItem (fromStored cur) == Just x = Just cur +                   | otherwise                           = listToMaybe $ catMaybes $ map (findSListRef x) $ listPrev $ fromStored cur + +{- TODO +mapFromSList :: Storable a => StoredList a -> Map RefDigest (Stored a) +mapFromSList list = helper list M.empty +    where helper :: Storable a => StoredList a -> Map RefDigest (Stored a) -> Map RefDigest (Stored a) +          helper (Stored _ ListNil) cur = cur +          helper (Stored _ (ListItem (Just rref) _ (Just x) rest)) cur = +              let rxref = case load rref of +                               ListItem _ _ (Just rx) _  -> sameType rx x $ storedRef rx +                               _ -> error "mapFromSList: malformed list" +               in helper rest $ case M.lookup (refDigest $ storedRef x) cur of +                                     Nothing -> M.insert (refDigest rxref) x cur +                                     Just x' -> M.insert (refDigest rxref) x' cur +          helper (Stored _ (ListItem _ _ _ rest)) cur = helper rest cur +          sameType :: a -> a -> b -> b +          sameType _ _ x = x + +updateOld :: Map RefDigest (Stored a) -> Stored a -> Stored a +updateOld m x = fromMaybe x $ M.lookup (refDigest $ storedRef x) m + + +data StoreUpdate a = StoreKeep +                   | StoreReplace a +                   | StoreRemove + +withStoredListItem :: (Storable a) => (a -> Bool) -> StoredList a -> (a -> IO (StoreUpdate a)) -> IO (StoredList a) +withStoredListItem p list f = withStoredListItemS (p . fromStored) list (suMap (wrappedStore $ storedStorage list) <=< f . fromStored) +    where suMap :: Monad m => (a -> m b) -> StoreUpdate a -> m (StoreUpdate b) +          suMap _ StoreKeep = return StoreKeep +          suMap g (StoreReplace x) = return . StoreReplace =<< g x +          suMap _ StoreRemove = return StoreRemove + +withStoredListItemS :: (Storable a) => (Stored a -> Bool) -> StoredList a -> (Stored a -> IO (StoreUpdate (Stored a))) -> IO (StoredList a) +withStoredListItemS p list f = do +    case find p $ storedFromSList list of +         Just sx -> f sx >>= \case StoreKeep -> return list +                                   StoreReplace nx -> slistReplaceS sx nx list +                                   StoreRemove -> slistRemove sx list +         Nothing -> return list +-} diff --git a/src/Erebos/Storage/Merge.hs b/src/Erebos/Storage/Merge.hs new file mode 100644 index 0000000..7234b87 --- /dev/null +++ b/src/Erebos/Storage/Merge.hs @@ -0,0 +1,156 @@ +module Erebos.Storage.Merge ( +    Mergeable(..), +    merge, storeMerge, + +    Generation, +    showGeneration, +    compareGeneration, generationMax, +    storedGeneration, + +    generations, +    ancestors, +    precedes, +    filterAncestors, +    storedRoots, +    walkAncestors, + +    findProperty, +    findPropertyFirst, +) where + +import Control.Concurrent.MVar + +import Data.ByteString.Char8 qualified as BC +import Data.HashTable.IO qualified as HT +import Data.Kind +import Data.List +import Data.Maybe +import Data.Set (Set) +import Data.Set qualified as S + +import System.IO.Unsafe (unsafePerformIO) + +import Erebos.Storage +import Erebos.Storage.Internal +import Erebos.Util + +class Storable (Component a) => Mergeable a where +    type Component a :: Type +    mergeSorted :: [Stored (Component a)] -> a +    toComponents :: a -> [Stored (Component a)] + +instance Mergeable [Stored Object] where +    type Component [Stored Object] = Object +    mergeSorted = id +    toComponents = id + +merge :: Mergeable a => [Stored (Component a)] -> a +merge [] = error "merge: empty list" +merge xs = mergeSorted $ filterAncestors xs + +storeMerge :: (Mergeable a, Storable a) => [Stored (Component a)] -> IO (Stored a) +storeMerge [] = error "merge: empty list" +storeMerge xs@(Stored ref _ : _) = wrappedStore (refStorage ref) $ mergeSorted $ filterAncestors xs + +previous :: Storable a => Stored a -> [Stored a] +previous (Stored ref _) = case load ref of +    Rec items | Just (RecRef dref) <- lookup (BC.pack "SDATA") items +              , Rec ditems <- load dref -> +                    map wrappedLoad $ catMaybes $ map (\case RecRef r -> Just r; _ -> Nothing) $ +                        map snd $ filter ((`elem` [ BC.pack "SPREV", BC.pack "SBASE" ]) . fst) ditems + +              | otherwise -> +                    map wrappedLoad $ catMaybes $ map (\case RecRef r -> Just r; _ -> Nothing) $ +                        map snd $ filter ((`elem` [ BC.pack "PREV", BC.pack "BASE" ]) . fst) items +    _ -> [] + + +nextGeneration :: [Generation] -> Generation +nextGeneration = foldl' helper (Generation 0) +    where helper (Generation c) (Generation n) | c <= n    = Generation (n + 1) +                                               | otherwise = Generation c + +showGeneration :: Generation -> String +showGeneration (Generation x) = show x + +compareGeneration :: Generation -> Generation -> Maybe Ordering +compareGeneration (Generation x) (Generation y) = Just $ compare x y + +generationMax :: Storable a => [Stored a] -> Maybe (Stored a) +generationMax (x : xs) = Just $ snd $ foldl' helper (storedGeneration x, x) xs +    where helper (mg, mx) y = let yg = storedGeneration y +                               in case compareGeneration mg yg of +                                       Just LT -> (yg, y) +                                       _       -> (mg, mx) +generationMax [] = Nothing + +storedGeneration :: Storable a => Stored a -> Generation +storedGeneration x = +    unsafePerformIO $ withMVar (stRefGeneration $ refStorage $ storedRef x) $ \ht -> do +        let doLookup y = HT.lookup ht (refDigest $ storedRef y) >>= \case +                Just gen -> return gen +                Nothing -> do +                    gen <- nextGeneration <$> mapM doLookup (previous y) +                    HT.insert ht (refDigest $ storedRef y) gen +                    return gen +        doLookup x + + +generations :: Storable a => [Stored a] -> [Set (Stored a)] +generations = unfoldr gen . (,S.empty) +    where gen (hs, cur) = case filter (`S.notMember` cur) $ previous =<< hs of +              []    -> Nothing +              added -> let next = foldr S.insert cur added +                        in Just (next, (added, next)) + +ancestors :: Storable a => [Stored a] -> Set (Stored a) +ancestors = last . (S.empty:) . generations + +precedes :: Storable a => Stored a -> Stored a -> Bool +precedes x y = not $ x `elem` filterAncestors [x, y] + +filterAncestors :: Storable a => [Stored a] -> [Stored a] +filterAncestors [x] = [x] +filterAncestors xs = let xs' = uniq $ sort xs +                      in helper xs' xs' +    where helper remains walk = case generationMax walk of +                                     Just x -> let px = previous x +                                                   remains' = filter (\r -> all (/=r) px) remains +                                                in helper remains' $ uniq $ sort (px ++ filter (/=x) walk) +                                     Nothing -> remains + +storedRoots :: Storable a => Stored a -> [Stored a] +storedRoots x = do +    let st = refStorage $ storedRef x +    unsafePerformIO $ withMVar (stRefRoots st) $ \ht -> do +        let doLookup y = HT.lookup ht (refDigest $ storedRef y) >>= \case +                Just roots -> return roots +                Nothing -> do +                    roots <- case previous y of +                        [] -> return [refDigest $ storedRef y] +                        ps -> map (refDigest . storedRef) . filterAncestors . map (wrappedLoad @Object . Ref st) . concat <$> mapM doLookup ps +                    HT.insert ht (refDigest $ storedRef y) roots +                    return roots +        map (wrappedLoad . Ref st) <$> doLookup x + +walkAncestors :: (Storable a, Monoid m) => (Stored a -> m) -> [Stored a] -> m +walkAncestors f = helper . sortBy cmp +  where +    helper (x : y : xs) | x == y = helper (x : xs) +    helper (x : xs) = f x <> helper (mergeBy cmp (sortBy cmp (previous x)) xs) +    helper [] = mempty + +    cmp x y = case compareGeneration (storedGeneration x) (storedGeneration y) of +                   Just LT -> GT +                   Just GT -> LT +                   _ -> compare x y + +findProperty :: forall a b. Storable a => (a -> Maybe b) -> [Stored a] -> [b] +findProperty sel = map (fromJust . sel . fromStored) . filterAncestors . (findPropHeads sel =<<) + +findPropertyFirst :: forall a b. Storable a => (a -> Maybe b) -> [Stored a] -> Maybe b +findPropertyFirst sel = fmap (fromJust . sel . fromStored) . listToMaybe . filterAncestors . (findPropHeads sel =<<) + +findPropHeads :: forall a b. Storable a => (a -> Maybe b) -> Stored a -> [Stored a] +findPropHeads sel sobj | Just _ <- sel $ fromStored sobj = [sobj] +                       | otherwise = findPropHeads sel =<< previous sobj diff --git a/src/Erebos/Sync.hs b/src/Erebos/Sync.hs new file mode 100644 index 0000000..04b5f11 --- /dev/null +++ b/src/Erebos/Sync.hs @@ -0,0 +1,46 @@ +module Erebos.Sync ( +    SyncService(..), +) where + +import Control.Monad +import Control.Monad.Reader + +import Data.List + +import Erebos.Identity +import Erebos.Service +import Erebos.State +import Erebos.Storage +import Erebos.Storage.Merge + +data SyncService = SyncPacket (Stored SharedState) + +instance Service SyncService where +    serviceID _ = mkServiceID "a4f538d0-4e50-4082-8e10-7e3ec2af175d" + +    serviceHandler packet = do +        let SyncPacket added = fromStored packet +        pid <- asks svcPeerIdentity +        self <- svcSelf +        when (finalOwner pid `sameIdentity` finalOwner self) $ do +            updateLocalHead_ $ \ls -> do +                let current = sort $ lsShared $ fromStored ls +                    updated = filterAncestors (added : current) +                if current /= updated +                   then mstore (fromStored ls) { lsShared = updated } +                   else return ls + +    serviceNewPeer = notifyPeer . lsShared . fromStored =<< svcGetLocal +    serviceStorageWatchers _ = (:[]) $ SomeStorageWatcher (lsShared . fromStored) notifyPeer + +instance Storable SyncService where +    store' (SyncPacket smsg) = store' smsg +    load' = SyncPacket <$> load' + +notifyPeer :: [Stored SharedState] -> ServiceHandler SyncService () +notifyPeer shared = do +    pid <- asks svcPeerIdentity +    self <- svcSelf +    when (finalOwner pid `sameIdentity` finalOwner self) $ do +        forM_ shared $ \sh -> +            replyStoredRef =<< (mstore . SyncPacket) sh diff --git a/src/Erebos/Util.hs b/src/Erebos/Util.hs new file mode 100644 index 0000000..ffca9c7 --- /dev/null +++ b/src/Erebos/Util.hs @@ -0,0 +1,37 @@ +module Erebos.Util where + +uniq :: Eq a => [a] -> [a] +uniq (x:y:xs) | x == y    = uniq (x:xs) +              | otherwise = x : uniq (y:xs) +uniq xs = xs + +mergeBy :: (a -> a -> Ordering) -> [a] -> [a] -> [a] +mergeBy cmp (x : xs) (y : ys) = case cmp x y of +                                     LT -> x : mergeBy cmp xs (y : ys) +                                     EQ -> x : y : mergeBy cmp xs ys +                                     GT -> y : mergeBy cmp (x : xs) ys +mergeBy _ xs [] = xs +mergeBy _ [] ys = ys + +mergeUniqBy :: (a -> a -> Ordering) -> [a] -> [a] -> [a] +mergeUniqBy cmp (x : xs) (y : ys) = case cmp x y of +                                         LT -> x : mergeBy cmp xs (y : ys) +                                         EQ -> x : mergeBy cmp xs ys +                                         GT -> y : mergeBy cmp (x : xs) ys +mergeUniqBy _ xs [] = xs +mergeUniqBy _ [] ys = ys + +mergeUniq :: Ord a => [a] -> [a] -> [a] +mergeUniq = mergeUniqBy compare + +diffSorted :: Ord a => [a] -> [a] -> [a] +diffSorted (x:xs) (y:ys) | x < y     = x : diffSorted xs (y:ys) +                         | x > y     = diffSorted (x:xs) ys +                         | otherwise = diffSorted xs (y:ys) +diffSorted xs _ = xs + +intersectsSorted :: Ord a => [a] -> [a] -> Bool +intersectsSorted (x:xs) (y:ys) | x < y     = intersectsSorted xs (y:ys) +                               | x > y     = intersectsSorted (x:xs) ys +                               | otherwise = True +intersectsSorted _ _ = False |