diff options
43 files changed, 7545 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..cceb219 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +.ghc.environment.* +cabal.project.local +dist-newstyle/ +.erebos diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..a79ff4c --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,5 @@ +# Revision history for erebos + +## 0.1.0 -- 2024-02-10 + +* First version. @@ -0,0 +1,30 @@ +Copyright (c) 2019, Roman Smrž + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + + * Neither the name of Roman Smrž nor the names of other + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..a29c2c7 --- /dev/null +++ b/README.md @@ -0,0 +1,133 @@ +Erebos +====== + +The erebos binary provides simple CLI interface to the decentralized Erebos +messaging service. Local identity is created on the first run. Protocol and +services specification is being written at: + +[http://erebosprotocol.net](http://erebosprotocol.net) + +Erebos identity is based on locally stored cryptographic keys, all +communication is end-to-end encrypted. Multiple devices can be attached to the +same identity, after which they function interchangeably, without any one being +in any way "primary"; messages and other state data are then synchronized +automatically whenever the devices are able to connect with one another. + +Status +------ + +This is experimental implementation of yet unfinished specification, so +changes, especially in the library API, are expected. Storage format and +network protocol should generally remain backward compatible, with their +respective versions to be increased in case of incompatible changes, to allow +for interoperability even in that case. + +Usage +----- + +On the first run, local identity will be created for this device based on +interactive prompts for: + +`Name:` name of the user/owner, which will be shared among all devices +belonging to the same user; keep empty when initializing device that is going +to be attached to already existing identity on other device. + +`Device:` name describing current device, can be empty. + +After the initial setup, the erebos tool presents interactive prompt for +messages and commands. All commands start with the slash (`/`) character, +followed by command name and parameters (if any) separated by spaces. When a +peer or contact is selected, message to send him can be entered directly on the +command prompt. + +### Messaging + +`/peers` +List peers with direct network connection. Peers are discovered automatically +on local network or can be manually added. + +`/contacts` +List known contacts (see below). + +`/<number>` +Select contact or peer `<number>` based on previous `/contacts` or `/peers` +output list. + +`<message>` +Send `<message>` to selected contact. + +`/history` +Show message history for selected contact or peer. + +### Add contacts + +To ensure the identity of the contact and prevent man-in-the-middle attack, +generated verification code needs to be confirmed on both devices to add +contacts to contact list (similar to bluetooth device pairing). Before adding +new contact, list peers using `/peers` command and select one with `/<number>`. + +`/contacts` +List already added contacts. + +`/contact-add` +Add selected peer as contact. Six-digit verification code will be computed +based on peer keys, which will be displayed on both devices and needs to be +checked that both numbers are same. After that it needs to be confirmed using +`/contact-accept` to finish the process. + +`/contact-accept` +Confirm that displayed verification codes are same on both devices and add the +selected peer as contact. The side, which did not initiate the contact adding +process, needs to select the corresponding peer with `/<number>` command first. + +`/contact-reject` +Reject contact request or verification code of selected peer. + +### Attach other devices + +Multiple devices can be attached to single identity to be used by the same +user. After the attachment process completes the roles of the devices are +equivalent, both can send and receive messages independently and those +messages, along with any other sate data, are synchronized automatically +whenever the devices can connect to each other. + +The attachment process and underlying protocol is very similar to the contact +adding described above, so also generates verification code based on peer keys +that needs to be checked and confirmed on both devices to avoid potential +man-in-the-middle attack. + +Before attaching device, list peers using `/peers` command and select the +target device with `/<number>`. + +`/attach` +Attach current device to the selected peer. After the process completes the +owner of the selected peer will become owner of this device as well. Six-digit +verification code will be displayed on both devices and the user needs to check +that both are the same before confirmation using the `/attach-accept` command. + +`/attach-accept` +Confirm that displayed verification codes are same on both devices and complete +the attachment process (or wait for the confirmation on the peer device). The +side, which did not initiate the attachment process, needs to select the +corresponding peer with `/<number>` command first. + +`/attach-reject` +Reject device attachment request or verification code of selected peer. + +### Other + +`/peer-add <host> [<port>]` +Manually add network peer with given hostname or IP address. + +`/update-identity` +Interactively update current identity information + + +Storage +------- + +Data are by default stored within `.erebos` subdirectory of the current working +directory. This can be overriden by `EREBOS_DIR` environment variable. + +Private keys are currently stored in plaintext under the `keys` subdirectory of +the erebos directory. diff --git a/Setup.hs b/Setup.hs new file mode 100644 index 0000000..9a994af --- /dev/null +++ b/Setup.hs @@ -0,0 +1,2 @@ +import Distribution.Simple +main = defaultMain diff --git a/erebos.cabal b/erebos.cabal new file mode 100644 index 0000000..d732385 --- /dev/null +++ b/erebos.cabal @@ -0,0 +1,177 @@ +Cabal-Version: 2.2 + +Name: erebos +Version: 0.1.0 +Synopsis: Decentralized messaging and synchronization +Description: + Library and simple CLI interface implementing the Erebos identity + management, decentralized messaging and synchronization protocol, along + with local storage. + . + Erebos identity is based on locally stored cryptographic keys, all + communication is end-to-end encrypted. Multiple devices can be attached to + the same identity, after which they function interchangeably, without any + one being in any way "primary"; messages and other state data are then + synchronized automatically whenever the devices are able to connect with + one another. + . + See README for usage of the CLI tool. +License: BSD-3-Clause +License-File: LICENSE +Homepage: http://erebosprotocol.net +Author: Roman Smrž <roman.smrz@seznam.cz> +Maintainer: roman.smrz@seznam.cz +Category: Network +Stability: experimental +Build-type: Simple +Extra-Doc-Files: + README.md + CHANGELOG.md +Extra-Source-Files: + src/Erebos/ICE/pjproject.h + +Flag ice + Description: Enable peer discovery with ICE support using pjproject + +source-repository head + type: git + location: git://erebosprotocol.net/erebos + +common common + ghc-options: -Wall + + build-depends: + base >=4.13 && <4.17, + + default-extensions: + DefaultSignatures + ExistentialQuantification + FlexibleContexts + FlexibleInstances + FunctionalDependencies + GeneralizedNewtypeDeriving + ImportQualifiedPost + LambdaCase + MultiWayIf + RankNTypes + RecordWildCards + ScopedTypeVariables + StandaloneDeriving + TypeOperators + TupleSections + TypeApplications + TypeFamilies + TypeFamilyDependencies + + other-extensions: + CPP + ForeignFunctionInterface + OverloadedStrings + RecursiveDo + TemplateHaskell + UndecidableInstances + + if flag(ice) + cpp-options: -DENABLE_ICE_SUPPORT + +library + import: common + default-language: Haskell2010 + + hs-source-dirs: src + exposed-modules: + Erebos.Attach + Erebos.Channel + Erebos.Contact + Erebos.Identity + Erebos.Message + Erebos.Network + Erebos.Network.Protocol + Erebos.Pairing + Erebos.PubKey + Erebos.Service + Erebos.Set + Erebos.State + Erebos.Storage + Erebos.Storage.Key + Erebos.Storage.Merge + Erebos.Sync + + -- Used by test tool: + Erebos.Storage.Internal + other-modules: + Erebos.Flow + Erebos.Storage.List + Erebos.Util + + c-sources: + src/Erebos/Network/ifaddrs.c + include-dirs: + src + + if flag(ice) + exposed-modules: + Erebos.Discovery + Erebos.ICE + c-sources: + src/Erebos/ICE/pjproject.c + include-dirs: + src/Erebos/ICE + includes: + src/Erebos/ICE/pjproject.h + build-tool-depends: c2hs:c2hs + pkgconfig-depends: libpjproject >= 2.9 + + build-depends: + async >=2.2 && <2.3, + binary >=0.8 && <0.11, + bytestring >=0.10 && <0.13, + cereal >= 0.5 && <0.6, + clock >=0.8 && < 0.9, + containers >= 0.6 && <0.8, + cryptonite >=0.25 && <0.31, + deepseq >= 1.4 && <1.6, + directory >= 1.3 && <1.4, + filepath >=1.4 && <1.6, + hashable >=1.3 && <1.5, + hashtables >=1.2 && <1.4, + hinotify >=0.4 && <0.5, + iproute >=1.7 && <1.8, + memory >=0.14 && <0.19, + mime >= 0.4 && < 0.5, + mtl >=2.2 && <2.4, + network >= 3.1 && <3.2, + stm >=2.5 && <2.6, + tagged >= 0.8 && <0.9, + text >= 1.2 && <2.2, + time >= 1.8 && <1.14, + unix >=2.7 && <2.9, + uuid >=1.3 && <1.4, + zlib >=0.6 && <0.8 + +executable erebos + import: common + default-language: Haskell2010 + hs-source-dirs: main + ghc-options: -threaded + + main-is: Main.hs + other-modules: + Paths_erebos + Test + Version + Version.Git + + build-depends: + bytestring, + cryptonite, + directory, + erebos, + haskeline >=0.7 && <0.9, + mtl, + network, + process >=1.6 && <1.7, + template-haskell >=2.18 && <2.22, + text, + time, + transformers >= 0.5 && <0.7, diff --git a/main/Main.hs b/main/Main.hs new file mode 100644 index 0000000..90e4079 --- /dev/null +++ b/main/Main.hs @@ -0,0 +1,494 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE OverloadedStrings #-} + +module Main (main) where + +import Control.Arrow (first) +import Control.Concurrent +import Control.Monad +import Control.Monad.Except +import Control.Monad.Reader +import Control.Monad.State +import Control.Monad.Trans.Maybe + +import Crypto.Random + +import qualified Data.ByteString.Char8 as BC +import qualified Data.ByteString.Lazy as BL +import Data.Char +import Data.List +import Data.Maybe +import Data.Ord +import qualified Data.Text as T +import qualified Data.Text.Encoding as T +import qualified Data.Text.IO as T +import Data.Time.LocalTime +import Data.Typeable + +import Network.Socket + +import System.Console.GetOpt +import System.Console.Haskeline +import System.Environment + +import Erebos.Attach +import Erebos.Contact +#ifdef ENABLE_ICE_SUPPORT +import Erebos.Discovery +import Erebos.ICE +#endif +import Erebos.Identity +import Erebos.Message +import Erebos.Network +import Erebos.PubKey +import Erebos.Service +import Erebos.Set +import Erebos.State +import Erebos.Storage +import Erebos.Storage.Merge +import Erebos.Sync + +import Test +import Version + +data Options = Options + { optServer :: ServerOptions + , optShowVersion :: Bool + } + +defaultOptions :: Options +defaultOptions = Options + { optServer = defaultServerOptions + , optShowVersion = False + } + +options :: [OptDescr (Options -> Options)] +options = + [ Option ['p'] ["port"] + (ReqArg (\p -> so $ \opts -> opts { serverPort = read p }) "PORT") + "local port to bind" + , Option ['s'] ["silent"] + (NoArg (so $ \opts -> opts { serverLocalDiscovery = False })) + "do not send announce packets for local discovery" + , Option ['V'] ["version"] + (NoArg $ \opts -> opts { optShowVersion = True }) + "show version and exit" + ] + where so f opts = opts { optServer = f $ optServer opts } + +main :: IO () +main = do + st <- liftIO $ openStorage . fromMaybe "./.erebos" =<< lookupEnv "EREBOS_DIR" + getArgs >>= \case + ["cat-file", sref] -> do + readRef st (BC.pack sref) >>= \case + Nothing -> error "ref does not exist" + Just ref -> BL.putStr $ lazyLoadBytes ref + + ("cat-file" : objtype : srefs@(_:_)) -> do + sequence <$> (mapM (readRef st . BC.pack) srefs) >>= \case + Nothing -> error "ref does not exist" + Just refs -> case objtype of + "signed" -> forM_ refs $ \ref -> do + let signed = load ref :: Signed Object + BL.putStr $ lazyLoadBytes $ storedRef $ signedData signed + forM_ (signedSignature signed) $ \sig -> do + putStr $ "SIG " + BC.putStrLn $ showRef $ storedRef $ sigKey $ fromStored sig + "identity" -> case validateIdentityF (wrappedLoad <$> refs) of + Just identity -> do + let disp :: Identity m -> IO () + disp idt = do + maybe (return ()) (T.putStrLn . (T.pack "Name: " `T.append`)) $ idName idt + BC.putStrLn . (BC.pack "KeyId: " `BC.append`) . showRefDigest . refDigest . storedRef $ idKeyIdentity idt + BC.putStrLn . (BC.pack "KeyMsg: " `BC.append`) . showRefDigest . refDigest . storedRef $ idKeyMessage idt + case idOwner idt of + Nothing -> return () + Just owner -> do + mapM_ (putStrLn . ("OWNER " ++) . BC.unpack . showRefDigest . refDigest . storedRef) $ idDataF owner + disp owner + disp identity + Nothing -> putStrLn $ "Identity verification failed" + _ -> error $ "unknown object type '" ++ objtype ++ "'" + + ["show-generation", sref] -> readRef st (BC.pack sref) >>= \case + Nothing -> error "ref does not exist" + Just ref -> print $ storedGeneration (wrappedLoad ref :: Stored Object) + + ["update-identity"] -> either fail return <=< runExceptT $ do + runReaderT updateSharedIdentity =<< loadLocalStateHead st + + ("update-identity" : srefs) -> do + sequence <$> mapM (readRef st . BC.pack) srefs >>= \case + Nothing -> error "ref does not exist" + Just refs + | Just idt <- validateIdentityF $ map wrappedLoad refs -> do + BC.putStrLn . showRefDigest . refDigest . storedRef . idData =<< + (either fail return <=< runExceptT $ runReaderT (interactiveIdentityUpdate idt) st) + | otherwise -> error "invalid identity" + + ["test"] -> runTestTool st + + args -> do + opts <- case getOpt Permute options args of + (o, [], []) -> return (foldl (flip id) defaultOptions o) + (_, _, errs) -> ioError (userError (concat errs ++ usageInfo header options)) + where header = "Usage: erebos [OPTION...]" + + if optShowVersion opts + then putStrLn versionLine + else interactiveLoop st opts + + +interactiveLoop :: Storage -> Options -> IO () +interactiveLoop st opts = runInputT defaultSettings $ do + erebosHead <- liftIO $ loadLocalStateHead st + outputStrLn $ T.unpack $ displayIdentity $ headLocalIdentity erebosHead + + haveTerminalUI >>= \case True -> return () + False -> error "Requires terminal" + extPrint <- getExternalPrint + let extPrintLn str = extPrint $ case reverse str of ('\n':_) -> str + _ -> str ++ "\n"; + + _ <- liftIO $ do + tzone <- getCurrentTimeZone + watchReceivedMessages erebosHead $ + extPrintLn . formatMessage tzone . fromStored + + server <- liftIO $ do + startServer (optServer opts) erebosHead extPrintLn + [ someService @AttachService Proxy + , someService @SyncService Proxy + , someService @ContactService Proxy + , someService @DirectMessage Proxy +#ifdef ENABLE_ICE_SUPPORT + , someService @DiscoveryService Proxy +#endif + ] + + peers <- liftIO $ newMVar [] + contextOptions <- liftIO $ newMVar [] + + void $ liftIO $ forkIO $ void $ forever $ do + peer <- getNextPeerChange server + peerIdentity peer >>= \case + pid@(PeerIdentityFull _) -> do + let shown = showPeer pid $ peerAddress peer + let update [] = ([(peer, shown)], Nothing) + update ((p,s):ps) | p == peer = ((peer, shown) : ps, Just s) + | otherwise = first ((p,s):) $ update ps + let ctxUpdate n [] = ([SelectedPeer peer], n) + ctxUpdate n (ctx:ctxs) + | SelectedPeer p <- ctx, p == peer = (ctx:ctxs, n) + | otherwise = first (ctx:) $ ctxUpdate (n + 1) ctxs + op <- modifyMVar peers (return . update) + idx <- modifyMVar contextOptions (return . ctxUpdate (1 :: Int)) + when (Just shown /= op) $ extPrint $ "[" <> show idx <> "] PEER " <> shown + _ -> return () + + let getInputLines prompt = do + Just input <- lift $ getInputLine prompt + case reverse input of + _ | all isSpace input -> getInputLines prompt + '\\':rest -> (reverse ('\n':rest) ++) <$> getInputLines ">> " + _ -> return input + + let process :: CommandState -> MaybeT (InputT IO) CommandState + process cstate = do + pname <- case csContext cstate of + NoContext -> return "" + SelectedPeer peer -> peerIdentity peer >>= return . \case + PeerIdentityFull pid -> maybe "<unnamed>" T.unpack $ idName $ finalOwner pid + PeerIdentityRef wref _ -> "<" ++ BC.unpack (showRefDigest $ wrDigest wref) ++ ">" + PeerIdentityUnknown _ -> "<unknown>" + SelectedContact contact -> return $ T.unpack $ contactName contact + input <- getInputLines $ pname ++ "> " + let (CommandM cmd, line) = case input of + '/':rest -> let (scmd, args) = dropWhile isSpace <$> span (\c -> isAlphaNum c || c == '-') rest + in if not (null scmd) && all isDigit scmd + then (cmdSelectContext $ read scmd, args) + else (fromMaybe (cmdUnknown scmd) $ lookup scmd commands, args) + _ -> (cmdSend, input) + h <- liftIO (reloadHead $ csHead cstate) >>= \case + Just h -> return h + Nothing -> do lift $ lift $ extPrint "current head deleted" + mzero + res <- liftIO $ runExceptT $ flip execStateT cstate { csHead = h } $ runReaderT cmd CommandInput + { ciServer = server + , ciLine = line + , ciPrint = extPrintLn + , ciPeers = liftIO $ readMVar peers + , ciContextOptions = liftIO $ readMVar contextOptions + , ciSetContextOptions = \ctxs -> liftIO $ modifyMVar_ contextOptions $ const $ return ctxs + } + case res of + Right cstate' -> return cstate' + Left err -> do lift $ lift $ extPrint $ "Error: " ++ err + return cstate + + let loop (Just cstate) = runMaybeT (process cstate) >>= loop + loop Nothing = return () + loop $ Just $ CommandState + { csHead = erebosHead + , csContext = NoContext +#ifdef ENABLE_ICE_SUPPORT + , csIceSessions = [] +#endif + , csIcePeer = Nothing + } + + +data CommandInput = CommandInput + { ciServer :: Server + , ciLine :: String + , ciPrint :: String -> IO () + , ciPeers :: CommandM [(Peer, String)] + , ciContextOptions :: CommandM [CommandContext] + , ciSetContextOptions :: [CommandContext] -> Command + } + +data CommandState = CommandState + { csHead :: Head LocalState + , csContext :: CommandContext +#ifdef ENABLE_ICE_SUPPORT + , csIceSessions :: [IceSession] +#endif + , csIcePeer :: Maybe Peer + } + +data CommandContext = NoContext + | SelectedPeer Peer + | SelectedContact Contact + +newtype CommandM a = CommandM (ReaderT CommandInput (StateT CommandState (ExceptT String IO)) a) + deriving (Functor, Applicative, Monad, MonadIO, MonadReader CommandInput, MonadState CommandState, MonadError String) + +instance MonadFail CommandM where + fail = throwError + +instance MonadRandom CommandM where + getRandomBytes = liftIO . getRandomBytes + +instance MonadStorage CommandM where + getStorage = gets $ headStorage . csHead + +instance MonadHead LocalState CommandM where + updateLocalHead f = do + h <- gets csHead + (Just h', x) <- maybe (fail "failed to reload head") (flip updateHead f) =<< reloadHead h + modify $ \s -> s { csHead = h' } + return x + +type Command = CommandM () + +getSelectedPeer :: CommandM Peer +getSelectedPeer = gets csContext >>= \case + SelectedPeer peer -> return peer + _ -> throwError "no peer selected" + +getSelectedIdentity :: CommandM ComposedIdentity +getSelectedIdentity = gets csContext >>= \case + SelectedPeer peer -> peerIdentity peer >>= \case + PeerIdentityFull pid -> return $ toComposedIdentity pid + _ -> throwError "incomplete peer identity" + SelectedContact contact -> case contactIdentity contact of + Just cid -> return cid + Nothing -> throwError "contact without erebos identity" + _ -> throwError "no contact or peer selected" + +commands :: [(String, Command)] +commands = + [ ("history", cmdHistory) + , ("peers", cmdPeers) + , ("peer-add", cmdPeerAdd) + , ("send", cmdSend) + , ("update-identity", cmdUpdateIdentity) + , ("attach", cmdAttach) + , ("attach-accept", cmdAttachAccept) + , ("attach-reject", cmdAttachReject) + , ("contacts", cmdContacts) + , ("contact-add", cmdContactAdd) + , ("contact-accept", cmdContactAccept) + , ("contact-reject", cmdContactReject) +#ifdef ENABLE_ICE_SUPPORT + , ("discovery-init", cmdDiscoveryInit) + , ("discovery", cmdDiscovery) + , ("ice-create", cmdIceCreate) + , ("ice-destroy", cmdIceDestroy) + , ("ice-show", cmdIceShow) + , ("ice-connect", cmdIceConnect) + , ("ice-send", cmdIceSend) +#endif + ] + +cmdUnknown :: String -> Command +cmdUnknown cmd = liftIO $ putStrLn $ "Unknown command: " ++ cmd + +cmdPeers :: Command +cmdPeers = do + peers <- join $ asks ciPeers + set <- asks ciSetContextOptions + set $ map (SelectedPeer . fst) peers + forM_ (zip [1..] peers) $ \(i :: Int, (_, name)) -> do + liftIO $ putStrLn $ "[" ++ show i ++ "] " ++ name + +cmdPeerAdd :: Command +cmdPeerAdd = void $ do + server <- asks ciServer + (hostname, port) <- (words <$> asks ciLine) >>= \case + hostname:p:_ -> return (hostname, p) + [hostname] -> return (hostname, show discoveryPort) + [] -> throwError "missing peer address" + addr:_ <- liftIO $ getAddrInfo (Just $ defaultHints { addrSocketType = Datagram }) (Just hostname) (Just port) + liftIO $ serverPeer server (addrAddress addr) + +showPeer :: PeerIdentity -> PeerAddress -> String +showPeer pidentity paddr = + let name = case pidentity of + PeerIdentityUnknown _ -> "<noid>" + PeerIdentityRef wref _ -> "<" ++ BC.unpack (showRefDigest $ wrDigest wref) ++ ">" + PeerIdentityFull pid -> T.unpack $ displayIdentity pid + in name ++ " [" ++ show paddr ++ "]" + +cmdSelectContext :: Int -> Command +cmdSelectContext n = join (asks ciContextOptions) >>= \ctxs -> if + | n > 0, (ctx : _) <- drop (n - 1) ctxs -> modify $ \s -> s { csContext = ctx } + | otherwise -> throwError "invalid index" + +cmdSend :: Command +cmdSend = void $ do + text <- asks ciLine + powner <- finalOwner <$> getSelectedIdentity + smsg <- sendDirectMessage powner $ T.pack text + tzone <- liftIO $ getCurrentTimeZone + liftIO $ putStrLn $ formatMessage tzone $ fromStored smsg + +cmdHistory :: Command +cmdHistory = void $ do + ehead <- gets csHead + powner <- finalOwner <$> getSelectedIdentity + + case find (sameIdentity powner . msgPeer) $ + toThreadList $ lookupSharedValue $ lsShared $ headObject ehead of + Just thread -> do + tzone <- liftIO $ getCurrentTimeZone + liftIO $ mapM_ (putStrLn . formatMessage tzone) $ reverse $ take 50 $ threadToList thread + Nothing -> do + liftIO $ putStrLn $ "<empty history>" + +cmdUpdateIdentity :: Command +cmdUpdateIdentity = void $ do + runReaderT updateSharedIdentity =<< gets csHead + +cmdAttach :: Command +cmdAttach = attachToOwner =<< getSelectedPeer + +cmdAttachAccept :: Command +cmdAttachAccept = attachAccept =<< getSelectedPeer + +cmdAttachReject :: Command +cmdAttachReject = attachReject =<< getSelectedPeer + +cmdContacts :: Command +cmdContacts = do + args <- words <$> asks ciLine + ehead <- gets csHead + let contacts = fromSetBy (comparing contactName) $ lookupSharedValue $ lsShared $ headObject ehead + verbose = "-v" `elem` args + set <- asks ciSetContextOptions + set $ map SelectedContact contacts + forM_ (zip [1..] contacts) $ \(i :: Int, c) -> liftIO $ do + T.putStrLn $ T.concat + [ "[", T.pack (show i), "] ", contactName c + , case contactIdentity c of + Just idt | cname <- displayIdentity idt + , cname /= contactName c + -> " (" <> cname <> ")" + _ -> "" + , if verbose then " " <> (T.unwords $ map (T.decodeUtf8 . showRef . storedRef) $ maybe [] idDataF $ contactIdentity c) + else "" + ] + +cmdContactAdd :: Command +cmdContactAdd = contactRequest =<< getSelectedPeer + +cmdContactAccept :: Command +cmdContactAccept = contactAccept =<< getSelectedPeer + +cmdContactReject :: Command +cmdContactReject = contactReject =<< getSelectedPeer + +#ifdef ENABLE_ICE_SUPPORT + +cmdDiscoveryInit :: Command +cmdDiscoveryInit = void $ do + server <- asks ciServer + + (hostname, port) <- (words <$> asks ciLine) >>= return . \case + hostname:p:_ -> (hostname, p) + [hostname] -> (hostname, show discoveryPort) + [] -> ("discovery.erebosprotocol.net", show discoveryPort) + addr:_ <- liftIO $ getAddrInfo (Just $ defaultHints { addrSocketType = Datagram }) (Just hostname) (Just port) + peer <- liftIO $ serverPeer server (addrAddress addr) + sendToPeer peer $ DiscoverySelf (T.pack "ICE") 0 + modify $ \s -> s { csIcePeer = Just peer } + +cmdDiscovery :: Command +cmdDiscovery = void $ do + Just peer <- gets csIcePeer + st <- getStorage + sref <- asks ciLine + eprint <- asks ciPrint + liftIO $ readRef st (BC.pack sref) >>= \case + Nothing -> error "ref does not exist" + Just ref -> do + res <- runExceptT $ sendToPeer peer $ DiscoverySearch ref + case res of + Right _ -> return () + Left err -> eprint err + +cmdIceCreate :: Command +cmdIceCreate = do + role <- asks ciLine >>= return . \case + 'm':_ -> PjIceSessRoleControlling + 's':_ -> PjIceSessRoleControlled + _ -> PjIceSessRoleUnknown + eprint <- asks ciPrint + sess <- liftIO $ iceCreate role $ eprint <=< iceShow + modify $ \s -> s { csIceSessions = sess : csIceSessions s } + +cmdIceDestroy :: Command +cmdIceDestroy = do + s:ss <- gets csIceSessions + modify $ \st -> st { csIceSessions = ss } + liftIO $ iceDestroy s + +cmdIceShow :: Command +cmdIceShow = do + sess <- gets csIceSessions + eprint <- asks ciPrint + liftIO $ forM_ (zip [1::Int ..] sess) $ \(i, s) -> do + eprint $ "[" ++ show i ++ "]" + eprint =<< iceShow s + +cmdIceConnect :: Command +cmdIceConnect = do + s:_ <- gets csIceSessions + server <- asks ciServer + let loadInfo = BC.getLine >>= \case line | BC.null line -> return [] + | otherwise -> (line:) <$> loadInfo + Right remote <- liftIO $ do + st <- memoryStorage + pst <- derivePartialStorage st + rbytes <- (BL.fromStrict . BC.unlines) <$> loadInfo + copyRef st =<< storeRawBytes pst (BL.fromChunks [ BC.pack "rec ", BC.pack (show (BL.length rbytes)), BC.singleton '\n' ] `BL.append` rbytes) + liftIO $ iceConnect s (load remote) $ void $ serverPeerIce server s + +cmdIceSend :: Command +cmdIceSend = void $ do + s:_ <- gets csIceSessions + server <- asks ciServer + liftIO $ serverPeerIce server s + +#endif diff --git a/main/Test.hs b/main/Test.hs new file mode 100644 index 0000000..7f0f7d9 --- /dev/null +++ b/main/Test.hs @@ -0,0 +1,550 @@ +module Test ( + runTestTool, +) where + +import Control.Arrow +import Control.Concurrent +import Control.Exception +import Control.Monad.Except +import Control.Monad.Reader +import Control.Monad.State + +import Crypto.Random + +import Data.ByteString qualified as B +import Data.ByteString.Char8 qualified as BC +import Data.ByteString.Lazy qualified as BL +import Data.Foldable +import Data.Ord +import Data.Text (Text) +import Data.Text qualified as T +import Data.Text.Encoding +import Data.Text.IO qualified as T +import Data.Typeable + +import Network.Socket + +import System.IO +import System.IO.Error + +import Erebos.Attach +import Erebos.Contact +import Erebos.Identity +import Erebos.Message +import Erebos.Network +import Erebos.Pairing +import Erebos.PubKey +import Erebos.Service +import Erebos.Set +import Erebos.State +import Erebos.Storage +import Erebos.Storage.Internal (unsafeStoreRawBytes) +import Erebos.Storage.Merge +import Erebos.Sync + + +data TestState = TestState + { tsHead :: Maybe (Head LocalState) + , tsServer :: Maybe RunningServer + , tsWatchedLocalIdentity :: Maybe WatchedHead + , tsWatchedSharedIdentity :: Maybe WatchedHead + } + +data RunningServer = RunningServer + { rsServer :: Server + , rsPeers :: MVar (Int, [(Int, Peer)]) + , rsPeerThread :: ThreadId + } + +initTestState :: TestState +initTestState = TestState + { tsHead = Nothing + , tsServer = Nothing + , tsWatchedLocalIdentity = Nothing + , tsWatchedSharedIdentity = Nothing + } + +data TestInput = TestInput + { tiOutput :: Output + , tiStorage :: Storage + , tiParams :: [Text] + } + + +runTestTool :: Storage -> IO () +runTestTool st = do + out <- newMVar () + let testLoop = getLineMb >>= \case + Just line -> do + case T.words line of + (cname:params) + | Just (CommandM cmd) <- lookup cname commands -> do + runReaderT cmd $ TestInput out st params + | otherwise -> fail $ "Unknown command '" ++ T.unpack cname ++ "'" + [] -> return () + testLoop + + Nothing -> return () + + runExceptT (evalStateT testLoop initTestState) >>= \case + Left x -> hPutStrLn stderr x + Right () -> return () + +getLineMb :: MonadIO m => m (Maybe Text) +getLineMb = liftIO $ catchIOError (Just <$> T.getLine) (\e -> if isEOFError e then return Nothing else ioError e) + +getLines :: MonadIO m => m [Text] +getLines = getLineMb >>= \case + Just line | not (T.null line) -> (line:) <$> getLines + _ -> return [] + +getHead :: CommandM (Head LocalState) +getHead = do + h <- maybe (fail "failed to reload head") return =<< maybe (fail "no current head") reloadHead =<< gets tsHead + modify $ \s -> s { tsHead = Just h } + return h + + +type Output = MVar () + +outLine :: Output -> String -> IO () +outLine mvar line = do + evaluate $ foldl' (flip seq) () line + withMVar mvar $ \() -> do + putStrLn line + hFlush stdout + +cmdOut :: String -> Command +cmdOut line = do + out <- asks tiOutput + liftIO $ outLine out line + + +getPeer :: Text -> CommandM Peer +getPeer spidx = do + Just RunningServer {..} <- gets tsServer + Just peer <- lookup (read $ T.unpack spidx) . snd <$> liftIO (readMVar rsPeers) + return peer + +getPeerIndex :: MVar (Int, [(Int, Peer)]) -> ServiceHandler (PairingService a) Int +getPeerIndex pmvar = do + peer <- asks svcPeer + maybe 0 fst . find ((==peer) . snd) . snd <$> liftIO (readMVar pmvar) + +pairingAttributes :: PairingResult a => proxy (PairingService a) -> Output -> MVar (Int, [(Int, Peer)]) -> String -> PairingAttributes a +pairingAttributes _ out peers prefix = PairingAttributes + { pairingHookRequest = return () + + , pairingHookResponse = \confirm -> do + index <- show <$> getPeerIndex peers + afterCommit $ outLine out $ unwords [prefix ++ "-response", index, confirm] + + , pairingHookRequestNonce = \confirm -> do + index <- show <$> getPeerIndex peers + afterCommit $ outLine out $ unwords [prefix ++ "-request", index, confirm] + + , pairingHookRequestNonceFailed = failed "nonce" + + , pairingHookConfirmedResponse = return () + , pairingHookConfirmedRequest = return () + + , pairingHookAcceptedResponse = do + index <- show <$> getPeerIndex peers + afterCommit $ outLine out $ unwords [prefix ++ "-response-done", index] + + , pairingHookAcceptedRequest = do + index <- show <$> getPeerIndex peers + afterCommit $ outLine out $ unwords [prefix ++ "-request-done", index] + + , pairingHookFailed = \case + PairingUserRejected -> failed "user" + PairingUnexpectedMessage pstate packet -> failed $ "unexpected " ++ strState pstate ++ " " ++ strPacket packet + PairingFailedOther str -> failed $ "other " ++ str + , pairingHookVerifyFailed = failed "verify" + , pairingHookRejected = failed "rejected" + } + where + failed :: PairingResult a => String -> ServiceHandler (PairingService a) () + failed detail = do + ptype <- svcGet >>= return . \case + OurRequest {} -> "response" + OurRequestConfirm {} -> "response" + OurRequestReady -> "response" + PeerRequest {} -> "request" + PeerRequestConfirm -> "request" + _ -> fail "unexpected pairing state" + + index <- show <$> getPeerIndex peers + afterCommit $ outLine out $ prefix ++ "-" ++ ptype ++ "-failed " ++ index ++ " " ++ detail + + strState :: PairingState a -> String + strState = \case + NoPairing -> "none" + OurRequest {} -> "our-request" + OurRequestConfirm {} -> "our-request-confirm" + OurRequestReady -> "our-request-ready" + PeerRequest {} -> "peer-request" + PeerRequestConfirm -> "peer-request-confirm" + PairingDone -> "done" + + strPacket :: PairingService a -> String + strPacket = \case + PairingRequest {} -> "request" + PairingResponse {} -> "response" + PairingRequestNonce {} -> "nonce" + PairingAccept {} -> "accept" + PairingReject -> "reject" + +directMessageAttributes :: Output -> DirectMessageAttributes +directMessageAttributes out = DirectMessageAttributes + { dmOwnerMismatch = afterCommit $ outLine out "dm-owner-mismatch" + } + +dmReceivedWatcher :: Output -> Stored DirectMessage -> IO () +dmReceivedWatcher out smsg = do + let msg = fromStored smsg + outLine out $ unwords + [ "dm-received" + , "from", maybe "<unnamed>" T.unpack $ idName $ msgFrom msg + , "text", T.unpack $ msgText msg + ] + + +newtype CommandM a = CommandM (ReaderT TestInput (StateT TestState (ExceptT String IO)) a) + deriving (Functor, Applicative, Monad, MonadIO, MonadReader TestInput, MonadState TestState, MonadError String) + +instance MonadFail CommandM where + fail = throwError + +instance MonadRandom CommandM where + getRandomBytes = liftIO . getRandomBytes + +instance MonadStorage CommandM where + getStorage = asks tiStorage + +instance MonadHead LocalState CommandM where + updateLocalHead f = do + Just h <- gets tsHead + (Just h', x) <- maybe (fail "failed to reload head") (flip updateHead f) =<< reloadHead h + modify $ \s -> s { tsHead = Just h' } + return x + +type Command = CommandM () + +commands :: [(Text, Command)] +commands = map (T.pack *** id) + [ ("store", cmdStore) + , ("stored-generation", cmdStoredGeneration) + , ("stored-roots", cmdStoredRoots) + , ("stored-set-add", cmdStoredSetAdd) + , ("stored-set-list", cmdStoredSetList) + , ("create-identity", cmdCreateIdentity) + , ("start-server", cmdStartServer) + , ("stop-server", cmdStopServer) + , ("peer-add", cmdPeerAdd) + , ("shared-state-get", cmdSharedStateGet) + , ("shared-state-wait", cmdSharedStateWait) + , ("watch-local-identity", cmdWatchLocalIdentity) + , ("watch-shared-identity", cmdWatchSharedIdentity) + , ("update-local-identity", cmdUpdateLocalIdentity) + , ("update-shared-identity", cmdUpdateSharedIdentity) + , ("attach-to", cmdAttachTo) + , ("attach-accept", cmdAttachAccept) + , ("attach-reject", cmdAttachReject) + , ("contact-request", cmdContactRequest) + , ("contact-accept", cmdContactAccept) + , ("contact-reject", cmdContactReject) + , ("contact-list", cmdContactList) + , ("contact-set-name", cmdContactSetName) + , ("dm-send-peer", cmdDmSendPeer) + , ("dm-send-contact", cmdDmSendContact) + , ("dm-list-peer", cmdDmListPeer) + , ("dm-list-contact", cmdDmListContact) + ] + +cmdStore :: Command +cmdStore = do + st <- asks tiStorage + [otype] <- asks tiParams + ls <- getLines + + let cnt = encodeUtf8 $ T.unlines ls + ref <- liftIO $ unsafeStoreRawBytes st $ BL.fromChunks [encodeUtf8 otype, BC.singleton ' ', BC.pack (show $ B.length cnt), BC.singleton '\n', cnt] + cmdOut $ "store-done " ++ show (refDigest ref) + +cmdStoredGeneration :: Command +cmdStoredGeneration = do + st <- asks tiStorage + [tref] <- asks tiParams + Just ref <- liftIO $ readRef st (encodeUtf8 tref) + cmdOut $ "stored-generation " ++ T.unpack tref ++ " " ++ showGeneration (storedGeneration $ wrappedLoad @Object ref) + +cmdStoredRoots :: Command +cmdStoredRoots = do + st <- asks tiStorage + [tref] <- asks tiParams + Just ref <- liftIO $ readRef st (encodeUtf8 tref) + cmdOut $ "stored-roots " ++ T.unpack tref ++ concatMap ((' ':) . show . refDigest . storedRef) (storedRoots $ wrappedLoad @Object ref) + +cmdStoredSetAdd :: Command +cmdStoredSetAdd = do + st <- asks tiStorage + (item, set) <- asks tiParams >>= liftIO . mapM (readRef st . encodeUtf8) >>= \case + [Just iref, Just sref] -> return (wrappedLoad iref, loadSet @[Stored Object] sref) + [Just iref] -> return (wrappedLoad iref, emptySet) + _ -> fail "unexpected parameters" + set' <- storeSetAdd st [item] set + cmdOut $ "stored-set-add" ++ concatMap ((' ':) . show . refDigest . storedRef) (toComponents set') + +cmdStoredSetList :: Command +cmdStoredSetList = do + st <- asks tiStorage + [tref] <- asks tiParams + Just ref <- liftIO $ readRef st (encodeUtf8 tref) + let items = fromSetBy compare $ loadSet @[Stored Object] ref + forM_ items $ \item -> do + cmdOut $ "stored-set-item" ++ concatMap ((' ':) . show . refDigest . storedRef) item + cmdOut $ "stored-set-done" + +cmdCreateIdentity :: Command +cmdCreateIdentity = do + st <- asks tiStorage + names <- asks tiParams + + h <- liftIO $ do + Just identity <- if null names + then Just <$> createIdentity st Nothing Nothing + else foldrM (\n o -> Just <$> createIdentity st (Just n) o) Nothing names + + shared <- case names of + _:_:_ -> (:[]) <$> makeSharedStateUpdate st (Just $ finalOwner identity) [] + _ -> return [] + + storeHead st $ LocalState + { lsIdentity = idExtData identity + , lsShared = shared + } + + _ <- liftIO . watchReceivedMessages h . dmReceivedWatcher =<< asks tiOutput + modify $ \s -> s { tsHead = Just h } + +cmdStartServer :: Command +cmdStartServer = do + out <- asks tiOutput + + Just h <- gets tsHead + rsPeers <- liftIO $ newMVar (1, []) + rsServer <- liftIO $ startServer defaultServerOptions h (hPutStrLn stderr) + [ someServiceAttr $ pairingAttributes (Proxy @AttachService) out rsPeers "attach" + , someServiceAttr $ pairingAttributes (Proxy @ContactService) out rsPeers "contact" + , someServiceAttr $ directMessageAttributes out + , someService @SyncService Proxy + ] + + rsPeerThread <- liftIO $ forkIO $ void $ forever $ do + peer <- getNextPeerChange rsServer + + let printPeer (idx, p) = do + params <- peerIdentity p >>= return . \case + PeerIdentityFull pid -> ("id":) $ map (maybe "<unnamed>" T.unpack . idName) (unfoldOwners pid) + _ -> [ "addr", show (peerAddress p) ] + outLine out $ unwords $ [ "peer", show idx ] ++ params + + update (nid, []) = printPeer (nid, peer) >> return (nid + 1, [(nid, peer)]) + update cur@(nid, p:ps) | snd p == peer = printPeer p >> return cur + | otherwise = fmap (p:) <$> update (nid, ps) + + modifyMVar_ rsPeers update + + modify $ \s -> s { tsServer = Just RunningServer {..} } + +cmdStopServer :: Command +cmdStopServer = do + Just RunningServer {..} <- gets tsServer + liftIO $ do + killThread rsPeerThread + stopServer rsServer + modify $ \s -> s { tsServer = Nothing } + cmdOut "stop-server-done" + +cmdPeerAdd :: Command +cmdPeerAdd = do + Just RunningServer {..} <- gets tsServer + host:rest <- map T.unpack <$> asks tiParams + + let port = case rest of [] -> show discoveryPort + (p:_) -> p + addr:_ <- liftIO $ getAddrInfo (Just $ defaultHints { addrSocketType = Datagram }) (Just host) (Just port) + void $ liftIO $ serverPeer rsServer (addrAddress addr) + +cmdSharedStateGet :: Command +cmdSharedStateGet = do + h <- getHead + cmdOut $ unwords $ "shared-state-get" : map (show . refDigest . storedRef) (lsShared $ headObject h) + +cmdSharedStateWait :: Command +cmdSharedStateWait = do + st <- asks tiStorage + out <- asks tiOutput + Just h <- gets tsHead + trefs <- asks tiParams + + liftIO $ do + mvar <- newEmptyMVar + w <- watchHeadWith h (lsShared . headObject) $ \cur -> do + mbobjs <- mapM (readRef st . encodeUtf8) trefs + case map wrappedLoad <$> sequence mbobjs of + Just objs | filterAncestors (cur ++ objs) == cur -> do + outLine out $ unwords $ "shared-state-wait" : map T.unpack trefs + void $ forkIO $ unwatchHead =<< takeMVar mvar + _ -> return () + putMVar mvar w + +cmdWatchLocalIdentity :: Command +cmdWatchLocalIdentity = do + Just h <- gets tsHead + Nothing <- gets tsWatchedLocalIdentity + + out <- asks tiOutput + w <- liftIO $ watchHeadWith h headLocalIdentity $ \idt -> do + outLine out $ unwords $ "local-identity" : map (maybe "<unnamed>" T.unpack . idName) (unfoldOwners idt) + modify $ \s -> s { tsWatchedLocalIdentity = Just w } + +cmdWatchSharedIdentity :: Command +cmdWatchSharedIdentity = do + Just h <- gets tsHead + Nothing <- gets tsWatchedSharedIdentity + + out <- asks tiOutput + w <- liftIO $ watchHeadWith h (lookupSharedValue . lsShared . headObject) $ \case + Just (idt :: ComposedIdentity) -> do + outLine out $ unwords $ "shared-identity" : map (maybe "<unnamed>" T.unpack . idName) (unfoldOwners idt) + Nothing -> do + outLine out $ "shared-identity-failed" + modify $ \s -> s { tsWatchedSharedIdentity = Just w } + +cmdUpdateLocalIdentity :: Command +cmdUpdateLocalIdentity = do + [name] <- asks tiParams + updateLocalHead_ $ \ls -> do + Just identity <- return $ validateExtendedIdentity $ lsIdentity $ fromStored ls + let public = idKeyIdentity identity + + secret <- loadKey public + nidata <- maybe (error "created invalid identity") (return . idExtData) . validateExtendedIdentity =<< + mstore =<< sign secret =<< mstore . ExtendedIdentityData =<< return (emptyIdentityExtension $ idData identity) + { idePrev = toList $ idExtDataF identity + , ideName = Just name + } + mstore (fromStored ls) { lsIdentity = nidata } + +cmdUpdateSharedIdentity :: Command +cmdUpdateSharedIdentity = do + [name] <- asks tiParams + updateLocalHead_ $ updateSharedState_ $ \case + Nothing -> throwError "no existing shared identity" + Just identity -> do + let public = idKeyIdentity identity + secret <- loadKey public + uidentity <- mergeIdentity identity + maybe (error "created invalid identity") (return . Just . toComposedIdentity) . validateExtendedIdentity =<< + mstore =<< sign secret =<< mstore . ExtendedIdentityData =<< return (emptyIdentityExtension $ idData uidentity) + { idePrev = toList $ idExtDataF identity + , ideName = Just name + } + +cmdAttachTo :: Command +cmdAttachTo = do + [spidx] <- asks tiParams + attachToOwner =<< getPeer spidx + +cmdAttachAccept :: Command +cmdAttachAccept = do + [spidx] <- asks tiParams + attachAccept =<< getPeer spidx + +cmdAttachReject :: Command +cmdAttachReject = do + [spidx] <- asks tiParams + attachReject =<< getPeer spidx + +cmdContactRequest :: Command +cmdContactRequest = do + [spidx] <- asks tiParams + contactRequest =<< getPeer spidx + +cmdContactAccept :: Command +cmdContactAccept = do + [spidx] <- asks tiParams + contactAccept =<< getPeer spidx + +cmdContactReject :: Command +cmdContactReject = do + [spidx] <- asks tiParams + contactReject =<< getPeer spidx + +cmdContactList :: Command +cmdContactList = do + h <- getHead + let contacts = fromSetBy (comparing contactName) . lookupSharedValue . lsShared . headObject $ h + forM_ contacts $ \c -> do + r:_ <- return $ filterAncestors $ concatMap storedRoots $ toComponents c + cmdOut $ concat + [ "contact-list-item " + , show $ refDigest $ storedRef r + , " " + , T.unpack $ contactName c + , case contactIdentity c of Nothing -> ""; Just idt -> " " ++ T.unpack (displayIdentity idt) + ] + cmdOut "contact-list-done" + +getContact :: Text -> CommandM Contact +getContact cid = do + h <- getHead + let contacts = fromSetBy (comparing contactName) . lookupSharedValue . lsShared . headObject $ h + [contact] <- flip filterM contacts $ \c -> do + r:_ <- return $ filterAncestors $ concatMap storedRoots $ toComponents c + return $ T.pack (show $ refDigest $ storedRef r) == cid + return contact + +cmdContactSetName :: Command +cmdContactSetName = do + [cid, name] <- asks tiParams + contact <- getContact cid + updateLocalHead_ $ updateSharedState_ $ contactSetName contact name + cmdOut "contact-set-name-done" + +cmdDmSendPeer :: Command +cmdDmSendPeer = do + [spidx, msg] <- asks tiParams + PeerIdentityFull to <- peerIdentity =<< getPeer spidx + void $ sendDirectMessage to msg + +cmdDmSendContact :: Command +cmdDmSendContact = do + [cid, msg] <- asks tiParams + Just to <- contactIdentity <$> getContact cid + void $ sendDirectMessage to msg + +dmList :: Foldable f => Identity f -> Command +dmList peer = do + threads <- toThreadList . lookupSharedValue . lsShared . headObject <$> getHead + case find (sameIdentity peer . msgPeer) threads of + Just thread -> do + forM_ (reverse $ threadToList thread) $ \DirectMessage {..} -> cmdOut $ "dm-list-item" + <> " from " <> (maybe "<unnamed>" T.unpack $ idName msgFrom) + <> " text " <> (T.unpack msgText) + Nothing -> return () + cmdOut "dm-list-done" + +cmdDmListPeer :: Command +cmdDmListPeer = do + [spidx] <- asks tiParams + PeerIdentityFull to <- peerIdentity =<< getPeer spidx + dmList to + +cmdDmListContact :: Command +cmdDmListContact = do + [cid] <- asks tiParams + Just to <- contactIdentity <$> getContact cid + dmList to diff --git a/main/Version.hs b/main/Version.hs new file mode 100644 index 0000000..d7583bf --- /dev/null +++ b/main/Version.hs @@ -0,0 +1,19 @@ +{-# LANGUAGE TemplateHaskell #-} + +module Version ( + versionLine, +) where + +import Paths_erebos (version) +import Data.Version (showVersion) +import Version.Git + +{-# NOINLINE versionLine #-} +versionLine :: String +versionLine = do + let ver = case $$tGitVersion of + Just gver + | 'v':v <- gver, not $ all (`elem` ('.': ['0'..'9'])) v + -> "git " <> gver + _ -> "version " <> showVersion version + in "Erebos CLI " <> ver diff --git a/main/Version/Git.hs b/main/Version/Git.hs new file mode 100644 index 0000000..2aae6e3 --- /dev/null +++ b/main/Version/Git.hs @@ -0,0 +1,31 @@ +module Version.Git ( + tGitVersion, +) where + +import Language.Haskell.TH +import Language.Haskell.TH.Syntax + +import System.Directory +import System.Exit +import System.Process + +tGitVersion :: Code Q (Maybe String) +tGitVersion = unsafeCodeCoerce $ do + let git args = do + (ExitSuccess, out, _) <- readCreateProcessWithExitCode + (proc "git" $ [ "--git-dir=./.git", "--work-tree=." ] ++ args) "" + return $ lines out + + mbver <- runIO $ do + doesPathExist "./.git" >>= \case + False -> return Nothing + True -> do + desc:_ <- git [ "describe", "--always", "--dirty= (dirty)" ] + files <- git [ "ls-files" ] + return $ Just (desc, files) + + case mbver of + Just (_, files) -> mapM_ addDependentFile files + Nothing -> return () + + lift (fst <$> mbver :: Maybe String) diff --git a/src/Erebos/Attach.hs b/src/Erebos/Attach.hs new file mode 100644 index 0000000..4fd976f --- /dev/null +++ b/src/Erebos/Attach.hs @@ -0,0 +1,122 @@ +module Erebos.Attach ( + AttachService, + attachToOwner, + attachAccept, + attachReject, +) where + +import Control.Monad.Except +import Control.Monad.Reader + +import Data.ByteArray (ScrubbedBytes) +import Data.Maybe +import Data.Proxy +import qualified Data.Text as T + +import Erebos.Identity +import Erebos.Network +import Erebos.Pairing +import Erebos.PubKey +import Erebos.Service +import Erebos.State +import Erebos.Storage +import Erebos.Storage.Key + +type AttachService = PairingService AttachIdentity + +data AttachIdentity = AttachIdentity (Stored (Signed IdentityData)) [ScrubbedBytes] + +instance Storable AttachIdentity where + store' (AttachIdentity x keys) = storeRec $ do + storeRef "identity" x + mapM_ (storeBinary "skey") keys + + load' = loadRec $ AttachIdentity + <$> loadRef "identity" + <*> loadBinaries "skey" + +instance PairingResult AttachIdentity where + pairingServiceID _ = mkServiceID "4995a5f9-2d4d-48e9-ad3b-0bf1c2a1be7f" + + type PairingVerifiedResult AttachIdentity = (UnifiedIdentity, [ScrubbedBytes]) + + pairingVerifyResult (AttachIdentity sdata keys) = do + curid <- lsIdentity . fromStored <$> svcGetLocal + secret <- loadKey $ eiddKeyIdentity $ fromSigned curid + sdata' <- mstore =<< signAdd secret (fromStored sdata) + return $ do + guard $ iddKeyIdentity (fromSigned sdata) == + eiddKeyIdentity (fromSigned curid) + identity <- validateIdentity sdata' + guard $ iddPrev (fromSigned $ idData identity) == [eiddStoredBase curid] + return (identity, keys) + + pairingFinalizeRequest (identity, keys) = updateLocalHead_ $ \slocal -> do + let owner = finalOwner identity + st <- getStorage + pkeys <- mapM (copyStored st) [ idKeyIdentity owner, idKeyMessage owner ] + liftIO $ mapM_ storeKey $ catMaybes [ keyFromData sec pub | sec <- keys, pub <- pkeys ] + + identity' <- mergeIdentity $ updateIdentity [ lsIdentity $ fromStored slocal ] identity + shared <- makeSharedStateUpdate st (Just owner) (lsShared $ fromStored slocal) + mstore (fromStored slocal) + { lsIdentity = idExtData identity' + , lsShared = [ shared ] + } + + pairingFinalizeResponse = do + owner <- mergeSharedIdentity + pid <- asks svcPeerIdentity + secret <- loadKey $ idKeyIdentity owner + identity <- mstore =<< sign secret =<< mstore (emptyIdentityData $ idKeyIdentity pid) + { iddPrev = [idData pid], iddOwner = Just (idData owner) } + skeys <- map keyGetData . catMaybes <$> mapM loadKeyMb [ idKeyIdentity owner, idKeyMessage owner ] + return $ AttachIdentity identity skeys + + defaultPairingAttributes _ = PairingAttributes + { pairingHookRequest = do + peer <- asks $ svcPeerIdentity + svcPrint $ "Attach from " ++ T.unpack (displayIdentity peer) ++ " initiated" + + , pairingHookResponse = \confirm -> do + peer <- asks $ svcPeerIdentity + svcPrint $ "Attach to " ++ T.unpack (displayIdentity peer) ++ ": " ++ confirm + + , pairingHookRequestNonce = \confirm -> do + peer <- asks $ svcPeerIdentity + svcPrint $ "Attach from " ++ T.unpack (displayIdentity peer) ++ ": " ++ confirm + + , pairingHookRequestNonceFailed = do + peer <- asks $ svcPeerIdentity + svcPrint $ "Failed attach from " ++ T.unpack (displayIdentity peer) + + , pairingHookConfirmedResponse = do + svcPrint $ "Confirmed peer, waiting for updated identity" + + , pairingHookConfirmedRequest = do + svcPrint $ "Attachment confirmed by peer" + + , pairingHookAcceptedResponse = do + svcPrint $ "Accepted updated identity" + + , pairingHookAcceptedRequest = do + svcPrint $ "Accepted new attached device, seding updated identity" + + , pairingHookVerifyFailed = do + svcPrint $ "Failed to verify new identity" + + , pairingHookRejected = do + svcPrint $ "Attachment rejected by peer" + + , pairingHookFailed = \_ -> do + svcPrint $ "Attachement failed" + } + +attachToOwner :: (MonadIO m, MonadError String m) => Peer -> m () +attachToOwner = pairingRequest @AttachIdentity Proxy + +attachAccept :: (MonadIO m, MonadError String m) => Peer -> m () +attachAccept = pairingAccept @AttachIdentity Proxy + +attachReject :: (MonadIO m, MonadError String m) => Peer -> m () +attachReject = pairingReject @AttachIdentity Proxy diff --git a/src/Erebos/Channel.hs b/src/Erebos/Channel.hs new file mode 100644 index 0000000..c10f971 --- /dev/null +++ b/src/Erebos/Channel.hs @@ -0,0 +1,174 @@ +module Erebos.Channel ( + Channel, + ChannelRequest, ChannelRequestData(..), + ChannelAccept, ChannelAcceptData(..), + + createChannelRequest, + acceptChannelRequest, + acceptedChannel, + + channelEncrypt, + channelDecrypt, +) where + +import Control.Concurrent.MVar +import Control.Monad +import Control.Monad.Except + +import Crypto.Cipher.ChaChaPoly1305 +import Crypto.Error + +import Data.Binary +import Data.ByteArray (ByteArray, Bytes, ScrubbedBytes, convert) +import Data.ByteArray qualified as BA +import Data.ByteString.Lazy qualified as BL +import Data.List + +import Erebos.Identity +import Erebos.PubKey +import Erebos.Storage + +data Channel = Channel + { chPeers :: [Stored (Signed IdentityData)] + , chKey :: ScrubbedBytes + , chNonceFixedOur :: Bytes + , chNonceFixedPeer :: Bytes + , chCounterNextOut :: MVar Word64 + , chCounterNextIn :: MVar Word64 + } + +type ChannelRequest = Signed ChannelRequestData + +data ChannelRequestData = ChannelRequest + { crPeers :: [Stored (Signed IdentityData)] + , crKey :: Stored PublicKexKey + } + deriving (Show) + +type ChannelAccept = Signed ChannelAcceptData + +data ChannelAcceptData = ChannelAccept + { caRequest :: Stored ChannelRequest + , caKey :: Stored PublicKexKey + } + + +instance Storable ChannelRequestData where + store' cr = storeRec $ do + mapM_ (storeRef "peer") $ crPeers cr + storeRef "key" $ crKey cr + + load' = loadRec $ do + ChannelRequest + <$> loadRefs "peer" + <*> loadRef "key" + +instance Storable ChannelAcceptData where + store' ca = storeRec $ do + storeRef "req" $ caRequest ca + storeRef "key" $ caKey ca + + load' = loadRec $ do + ChannelAccept + <$> loadRef "req" + <*> loadRef "key" + + +keySize :: Int +keySize = 32 + +createChannelRequest :: (MonadStorage m, MonadIO m, MonadError String m) => UnifiedIdentity -> UnifiedIdentity -> m (Stored ChannelRequest) +createChannelRequest self peer = do + (_, xpublic) <- liftIO . generateKeys =<< getStorage + skey <- loadKey $ idKeyMessage self + mstore =<< sign skey =<< mstore ChannelRequest { crPeers = sort [idData self, idData peer], crKey = xpublic } + +acceptChannelRequest :: (MonadStorage m, MonadIO m, MonadError String m) => UnifiedIdentity -> UnifiedIdentity -> Stored ChannelRequest -> m (Stored ChannelAccept, Channel) +acceptChannelRequest self peer req = do + case sequence $ map validateIdentity $ crPeers $ fromStored $ signedData $ fromStored req of + Nothing -> throwError $ "invalid peers in channel request" + Just peers -> do + when (not $ any (self `sameIdentity`) peers) $ + throwError $ "self identity missing in channel request peers" + when (not $ any (peer `sameIdentity`) peers) $ + throwError $ "peer identity missing in channel request peers" + when (idKeyMessage peer `notElem` (map (sigKey . fromStored) $ signedSignature $ fromStored req)) $ + throwError $ "channel requent not signed by peer" + + (xsecret, xpublic) <- liftIO . generateKeys =<< getStorage + skey <- loadKey $ idKeyMessage self + acc <- mstore =<< sign skey =<< mstore ChannelAccept { caRequest = req, caKey = xpublic } + liftIO $ do + let chPeers = crPeers $ fromStored $ signedData $ fromStored req + chKey = BA.take keySize $ dhSecret xsecret $ + fromStored $ crKey $ fromStored $ signedData $ fromStored req + chNonceFixedOur = BA.pack [ 2, 0, 0, 0 ] + chNonceFixedPeer = BA.pack [ 1, 0, 0, 0 ] + chCounterNextOut <- newMVar 0 + chCounterNextIn <- newMVar 0 + + return (acc, Channel {..}) + +acceptedChannel :: (MonadIO m, MonadError String m) => UnifiedIdentity -> UnifiedIdentity -> Stored ChannelAccept -> m Channel +acceptedChannel self peer acc = do + let req = caRequest $ fromStored $ signedData $ fromStored acc + case sequence $ map validateIdentity $ crPeers $ fromStored $ signedData $ fromStored req of + Nothing -> throwError $ "invalid peers in channel accept" + Just peers -> do + when (not $ any (self `sameIdentity`) peers) $ + throwError $ "self identity missing in channel accept peers" + when (not $ any (peer `sameIdentity`) peers) $ + throwError $ "peer identity missing in channel accept peers" + when (idKeyMessage peer `notElem` (map (sigKey . fromStored) $ signedSignature $ fromStored acc)) $ + throwError $ "channel accept not signed by peer" + when (idKeyMessage self `notElem` (map (sigKey . fromStored) $ signedSignature $ fromStored req)) $ + throwError $ "original channel request not signed by us" + + xsecret <- loadKey $ crKey $ fromStored $ signedData $ fromStored req + let chPeers = crPeers $ fromStored $ signedData $ fromStored req + chKey = BA.take keySize $ dhSecret xsecret $ + fromStored $ caKey $ fromStored $ signedData $ fromStored acc + chNonceFixedOur = BA.pack [ 1, 0, 0, 0 ] + chNonceFixedPeer = BA.pack [ 2, 0, 0, 0 ] + chCounterNextOut <- liftIO $ newMVar 0 + chCounterNextIn <- liftIO $ newMVar 0 + + return Channel {..} + + +channelEncrypt :: (ByteArray ba, MonadIO m, MonadError String m) => Channel -> ba -> m (ba, Word64) +channelEncrypt Channel {..} plain = do + count <- liftIO $ modifyMVar chCounterNextOut $ \c -> return (c + 1, c) + let cbytes = convert $ BL.toStrict $ encode count + nonce = nonce8 chNonceFixedOur cbytes + state <- case initialize chKey =<< nonce of + CryptoPassed state -> return state + CryptoFailed err -> throwError $ "failed to init chacha-poly1305 cipher: " <> show err + + let (ctext, state') = encrypt plain state + tag = finalize state' + return (BA.concat [ convert $ BA.drop 7 cbytes, ctext, convert tag ], count) + +channelDecrypt :: (ByteArray ba, MonadIO m, MonadError String m) => Channel -> ba -> m (ba, Word64) +channelDecrypt Channel {..} body = do + when (BA.length body < 17) $ do + throwError $ "invalid encrypted data length" + + expectedCount <- liftIO $ readMVar chCounterNextIn + let countByte = body `BA.index` 0 + body' = BA.dropView body 1 + guessedCount = expectedCount - 128 + fromIntegral (countByte - fromIntegral expectedCount + 128 :: Word8) + nonce = nonce8 chNonceFixedPeer $ convert $ BL.toStrict $ encode guessedCount + blen = BA.length body' - 16 + ctext = BA.takeView body' blen + tag = BA.dropView body' blen + state <- case initialize chKey =<< nonce of + CryptoPassed state -> return state + CryptoFailed err -> throwError $ "failed to init chacha-poly1305 cipher: " <> show err + + let (plain, state') = decrypt (convert ctext) state + when (not $ tag `BA.constEq` finalize state') $ do + throwError $ "tag validation falied" + + liftIO $ modifyMVar_ chCounterNextIn $ return . max (guessedCount + 1) + return (plain, guessedCount) diff --git a/src/Erebos/Contact.hs b/src/Erebos/Contact.hs new file mode 100644 index 0000000..d90aa50 --- /dev/null +++ b/src/Erebos/Contact.hs @@ -0,0 +1,175 @@ +module Erebos.Contact ( + Contact, + contactIdentity, + contactCustomName, + contactName, + + contactSetName, + + ContactService, + contactRequest, + contactAccept, + contactReject, +) where + +import Control.Monad +import Control.Monad.Except +import Control.Monad.Reader + +import Data.Maybe +import Data.Proxy +import Data.Text (Text) +import qualified Data.Text as T + +import Erebos.Identity +import Erebos.Network +import Erebos.Pairing +import Erebos.PubKey +import Erebos.Service +import Erebos.Set +import Erebos.State +import Erebos.Storage +import Erebos.Storage.Merge + +data Contact = Contact + { contactData :: [Stored ContactData] + , contactIdentity_ :: Maybe ComposedIdentity + , contactCustomName_ :: Maybe Text + } + +data ContactData = ContactData + { cdPrev :: [Stored ContactData] + , cdIdentity :: [Stored (Signed ExtendedIdentityData)] + , cdName :: Maybe Text + } + +instance Storable ContactData where + store' x = storeRec $ do + mapM_ (storeRef "PREV") $ cdPrev x + mapM_ (storeRef "identity") $ cdIdentity x + storeMbText "name" $ cdName x + + load' = loadRec $ ContactData + <$> loadRefs "PREV" + <*> loadRefs "identity" + <*> loadMbText "name" + +instance Mergeable Contact where + type Component Contact = ContactData + + mergeSorted cdata = Contact + { contactData = cdata + , contactIdentity_ = validateExtendedIdentityF $ concat $ findProperty ((\case [] -> Nothing; xs -> Just xs) . cdIdentity) cdata + , contactCustomName_ = findPropertyFirst cdName cdata + } + + toComponents = contactData + +instance SharedType (Set Contact) where + sharedTypeID _ = mkSharedTypeID "34fbb61e-6022-405f-b1b3-a5a1abecd25e" + +contactIdentity :: Contact -> Maybe ComposedIdentity +contactIdentity = contactIdentity_ + +contactCustomName :: Contact -> Maybe Text +contactCustomName = contactCustomName_ + +contactName :: Contact -> Text +contactName c = fromJust $ msum + [ contactCustomName c + , idName =<< contactIdentity c + , Just T.empty + ] + +contactSetName :: MonadHead LocalState m => Contact -> Text -> Set Contact -> m (Set Contact) +contactSetName contact name set = do + st <- getStorage + cdata <- wrappedStore st ContactData + { cdPrev = toComponents contact + , cdIdentity = [] + , cdName = Just name + } + storeSetAdd st (mergeSorted @Contact [cdata]) set + + +type ContactService = PairingService ContactAccepted + +data ContactAccepted = ContactAccepted + +instance Storable ContactAccepted where + store' ContactAccepted = storeRec $ do + storeText "accept" "" + load' = loadRec $ do + (_ :: T.Text) <- loadText "accept" + return ContactAccepted + +instance PairingResult ContactAccepted where + pairingServiceID _ = mkServiceID "d9c37368-0da1-4280-93e9-d9bd9a198084" + + pairingVerifyResult = return . Just + + pairingFinalizeRequest ContactAccepted = do + pid <- asks svcPeerIdentity + finalizeContact pid + + pairingFinalizeResponse = do + pid <- asks svcPeerIdentity + finalizeContact pid + return ContactAccepted + + defaultPairingAttributes _ = PairingAttributes + { pairingHookRequest = do + peer <- asks $ svcPeerIdentity + svcPrint $ "Contact pairing from " ++ T.unpack (displayIdentity peer) ++ " initiated" + + , pairingHookResponse = \confirm -> do + peer <- asks $ svcPeerIdentity + svcPrint $ "Confirm contact " ++ T.unpack (displayIdentity $ finalOwner peer) ++ ": " ++ confirm + + , pairingHookRequestNonce = \confirm -> do + peer <- asks $ svcPeerIdentity + svcPrint $ "Contact request from " ++ T.unpack (displayIdentity $ finalOwner peer) ++ ": " ++ confirm + + , pairingHookRequestNonceFailed = do + peer <- asks $ svcPeerIdentity + svcPrint $ "Failed contact request from " ++ T.unpack (displayIdentity peer) + + , pairingHookConfirmedResponse = do + svcPrint $ "Contact accepted, waiting for peer confirmation" + + , pairingHookConfirmedRequest = do + svcPrint $ "Contact confirmed by peer" + + , pairingHookAcceptedResponse = do + svcPrint $ "Contact accepted" + + , pairingHookAcceptedRequest = do + svcPrint $ "Contact accepted" + + , pairingHookVerifyFailed = return () + + , pairingHookRejected = do + svcPrint $ "Contact rejected by peer" + + , pairingHookFailed = \_ -> do + svcPrint $ "Contact failed" + } + +contactRequest :: (MonadIO m, MonadError String m) => Peer -> m () +contactRequest = pairingRequest @ContactAccepted Proxy + +contactAccept :: (MonadIO m, MonadError String m) => Peer -> m () +contactAccept = pairingAccept @ContactAccepted Proxy + +contactReject :: (MonadIO m, MonadError String m) => Peer -> m () +contactReject = pairingReject @ContactAccepted Proxy + +finalizeContact :: MonadHead LocalState m => UnifiedIdentity -> m () +finalizeContact identity = updateLocalHead_ $ updateSharedState_ $ \contacts -> do + st <- getStorage + cdata <- wrappedStore st ContactData + { cdPrev = [] + , cdIdentity = idExtDataF $ finalOwner identity + , cdName = Nothing + } + storeSetAdd st (mergeSorted @Contact [cdata]) contacts diff --git a/src/Erebos/Discovery.hs b/src/Erebos/Discovery.hs new file mode 100644 index 0000000..86bdbe7 --- /dev/null +++ b/src/Erebos/Discovery.hs @@ -0,0 +1,222 @@ +module Erebos.Discovery ( + DiscoveryService(..), + DiscoveryConnection(..) +) where + +import Control.Concurrent +import Control.Monad.Except +import Control.Monad.Reader + +import Data.Map.Strict (Map) +import qualified Data.Map.Strict as M +import Data.Maybe +import Data.Text (Text) +import qualified Data.Text as T + +import Network.Socket + +import Erebos.ICE +import Erebos.Identity +import Erebos.Network +import Erebos.Service +import Erebos.Storage + + +keepaliveSeconds :: Int +keepaliveSeconds = 20 + + +data DiscoveryService = DiscoverySelf Text Int + | DiscoveryAcknowledged Text + | DiscoverySearch Ref + | DiscoveryResult Ref (Maybe Text) + | DiscoveryConnectionRequest DiscoveryConnection + | DiscoveryConnectionResponse DiscoveryConnection + +data DiscoveryConnection = DiscoveryConnection + { dconnSource :: Ref + , dconnTarget :: Ref + , dconnAddress :: Maybe Text + , dconnIceSession :: Maybe IceRemoteInfo + } + +emptyConnection :: Ref -> Ref -> DiscoveryConnection +emptyConnection source target = DiscoveryConnection source target Nothing Nothing + +instance Storable DiscoveryService where + store' x = storeRec $ do + case x of + DiscoverySelf addr priority -> do + storeText "self" addr + storeInt "priority" priority + DiscoveryAcknowledged addr -> do + storeText "ack" addr + DiscoverySearch ref -> storeRawRef "search" ref + DiscoveryResult ref addr -> do + storeRawRef "result" ref + storeMbText "address" addr + DiscoveryConnectionRequest conn -> storeConnection "request" conn + DiscoveryConnectionResponse conn -> storeConnection "response" conn + + where storeConnection ctype conn = do + storeText "connection" $ ctype + storeRawRef "source" $ dconnSource conn + storeRawRef "target" $ dconnTarget conn + storeMbText "address" $ dconnAddress conn + storeMbRef "ice-session" $ dconnIceSession conn + + load' = loadRec $ msum + [ DiscoverySelf + <$> loadText "self" + <*> loadInt "priority" + , DiscoveryAcknowledged + <$> loadText "ack" + , DiscoverySearch <$> loadRawRef "search" + , DiscoveryResult + <$> loadRawRef "result" + <*> loadMbText "address" + , loadConnection "request" DiscoveryConnectionRequest + , loadConnection "response" DiscoveryConnectionResponse + ] + where loadConnection ctype ctor = do + ctype' <- loadText "connection" + guard $ ctype == ctype' + return . ctor =<< DiscoveryConnection + <$> loadRawRef "source" + <*> loadRawRef "target" + <*> loadMbText "address" + <*> loadMbRef "ice-session" + +data DiscoveryPeer = DiscoveryPeer + { dpPriority :: Int + , dpPeer :: Maybe Peer + , dpAddress :: Maybe Text + , dpIceSession :: Maybe IceSession + } + +instance Service DiscoveryService where + serviceID _ = mkServiceID "dd59c89c-69cc-4703-b75b-4ddcd4b3c23b" + + type ServiceGlobalState DiscoveryService = Map RefDigest DiscoveryPeer + emptyServiceGlobalState _ = M.empty + + serviceHandler msg = case fromStored msg of + DiscoverySelf addr priority -> do + pid <- asks svcPeerIdentity + peer <- asks svcPeer + let insertHelper new old | dpPriority new > dpPriority old = new + | otherwise = old + mbaddr <- case words (T.unpack addr) of + [ipaddr, port] | DatagramAddress paddr <- peerAddress peer -> do + saddr <- liftIO $ head <$> getAddrInfo (Just $ defaultHints { addrSocketType = Datagram }) (Just ipaddr) (Just port) + return $ if paddr == addrAddress saddr + then Just addr + else Nothing + _ -> return Nothing + forM_ (idDataF =<< unfoldOwners pid) $ \s -> + svcModifyGlobal $ M.insertWith insertHelper (refDigest $ storedRef s) $ + DiscoveryPeer priority (Just peer) mbaddr Nothing + replyPacket $ DiscoveryAcknowledged $ fromMaybe (T.pack "ICE") mbaddr + + DiscoveryAcknowledged addr -> do + when (addr == T.pack "ICE") $ do + -- keep-alive packet from behind NAT + peer <- asks svcPeer + liftIO $ void $ forkIO $ do + threadDelay (keepaliveSeconds * 1000 * 1000) + res <- runExceptT $ sendToPeer peer $ DiscoverySelf addr 0 + case res of + Right _ -> return () + Left err -> putStrLn $ "Discovery: failed to send keep-alive: " ++ err + + DiscoverySearch ref -> do + addr <- M.lookup (refDigest ref) <$> svcGetGlobal + replyPacket $ DiscoveryResult ref $ fromMaybe (T.pack "ICE") . dpAddress <$> addr + + DiscoveryResult ref Nothing -> do + svcPrint $ "Discovery: " ++ show (refDigest ref) ++ " not found" + + DiscoveryResult ref (Just addr) -> do + -- TODO: check if we really requested that + server <- asks svcServer + if addr == T.pack "ICE" + then do + self <- svcSelf + peer <- asks svcPeer + ice <- liftIO $ iceCreate PjIceSessRoleControlling $ \ice -> do + rinfo <- iceRemoteInfo ice + res <- runExceptT $ sendToPeer peer $ + DiscoveryConnectionRequest (emptyConnection (storedRef $ idData self) ref) { dconnIceSession = Just rinfo } + case res of + Right _ -> return () + Left err -> putStrLn $ "Discovery: failed to send connection request: " ++ err + + svcModifyGlobal $ M.insert (refDigest ref) $ + DiscoveryPeer 0 Nothing Nothing (Just ice) + else do + case words (T.unpack addr) of + [ipaddr, port] -> do + saddr <- liftIO $ head <$> + getAddrInfo (Just $ defaultHints { addrSocketType = Datagram }) (Just ipaddr) (Just port) + peer <- liftIO $ serverPeer server (addrAddress saddr) + svcModifyGlobal $ M.insert (refDigest ref) $ + DiscoveryPeer 0 (Just peer) Nothing Nothing + + _ -> svcPrint $ "Discovery: invalid address in result: " ++ T.unpack addr + + DiscoveryConnectionRequest conn -> do + self <- svcSelf + let rconn = emptyConnection (dconnSource conn) (dconnTarget conn) + if refDigest (dconnTarget conn) `elem` (map (refDigest . storedRef) $ idDataF =<< unfoldOwners self) + then do + -- request for us, create ICE sesssion + server <- asks svcServer + peer <- asks svcPeer + liftIO $ void $ iceCreate PjIceSessRoleControlled $ \ice -> do + rinfo <- iceRemoteInfo ice + res <- runExceptT $ sendToPeer peer $ DiscoveryConnectionResponse rconn { dconnIceSession = Just rinfo } + case res of + Right _ -> do + case dconnIceSession conn of + Just prinfo -> iceConnect ice prinfo $ void $ serverPeerIce server ice + Nothing -> putStrLn $ "Discovery: connection request without ICE remote info" + Left err -> putStrLn $ "Discovery: failed to send connection response: " ++ err + + else do + -- request to some of our peers, relay + mbdp <- M.lookup (refDigest $ dconnTarget conn) <$> svcGetGlobal + case mbdp of + Nothing -> replyPacket $ DiscoveryConnectionResponse rconn + Just dp | Just addr <- dpAddress dp -> do + replyPacket $ DiscoveryConnectionResponse rconn { dconnAddress = Just addr } + | Just dpeer <- dpPeer dp -> do + sendToPeer dpeer $ DiscoveryConnectionRequest conn + | otherwise -> svcPrint $ "Discovery: failed to relay connection request" + + DiscoveryConnectionResponse conn -> do + self <- svcSelf + dpeers <- svcGetGlobal + if refDigest (dconnSource conn) `elem` (map (refDigest . storedRef) $ idDataF =<< unfoldOwners self) + then do + -- response to our request, try to connect to the peer + server <- asks svcServer + if | Just addr <- dconnAddress conn + , [ipaddr, port] <- words (T.unpack addr) -> do + saddr <- liftIO $ head <$> + getAddrInfo (Just $ defaultHints { addrSocketType = Datagram }) (Just ipaddr) (Just port) + peer <- liftIO $ serverPeer server (addrAddress saddr) + svcModifyGlobal $ M.insert (refDigest $ dconnTarget conn) $ + DiscoveryPeer 0 (Just peer) Nothing Nothing + + | Just dp <- M.lookup (refDigest $ dconnTarget conn) dpeers + , Just ice <- dpIceSession dp + , Just rinfo <- dconnIceSession conn -> do + liftIO $ iceConnect ice rinfo $ void $ serverPeerIce server ice + + | otherwise -> svcPrint $ "Discovery: connection request failed" + else do + -- response to relayed request + case M.lookup (refDigest $ dconnSource conn) dpeers of + Just dp | Just dpeer <- dpPeer dp -> do + sendToPeer dpeer $ DiscoveryConnectionResponse conn + _ -> svcPrint $ "Discovery: failed to relay connection response" diff --git a/src/Erebos/Flow.hs b/src/Erebos/Flow.hs new file mode 100644 index 0000000..ba2607a --- /dev/null +++ b/src/Erebos/Flow.hs @@ -0,0 +1,73 @@ +module Erebos.Flow ( + Flow, SymFlow, + newFlow, newFlowIO, + readFlow, tryReadFlow, canReadFlow, + writeFlow, writeFlowBulk, tryWriteFlow, canWriteFlow, + readFlowIO, writeFlowIO, + + mapFlow, +) where + +import Control.Concurrent.STM + + +data Flow r w = Flow (TMVar [r]) (TMVar [w]) + | forall r' w'. MappedFlow (r' -> r) (w -> w') (Flow r' w') + +type SymFlow a = Flow a a + +newFlow :: STM (Flow a b, Flow b a) +newFlow = do + x <- newEmptyTMVar + y <- newEmptyTMVar + return (Flow x y, Flow y x) + +newFlowIO :: IO (Flow a b, Flow b a) +newFlowIO = atomically newFlow + +readFlow :: Flow r w -> STM r +readFlow (Flow rvar _) = takeTMVar rvar >>= \case + (x:[]) -> return x + (x:xs) -> putTMVar rvar xs >> return x + [] -> error "Flow: empty list" +readFlow (MappedFlow f _ up) = f <$> readFlow up + +tryReadFlow :: Flow r w -> STM (Maybe r) +tryReadFlow (Flow rvar _) = tryTakeTMVar rvar >>= \case + Just (x:[]) -> return (Just x) + Just (x:xs) -> putTMVar rvar xs >> return (Just x) + Just [] -> error "Flow: empty list" + Nothing -> return Nothing +tryReadFlow (MappedFlow f _ up) = fmap f <$> tryReadFlow up + +canReadFlow :: Flow r w -> STM Bool +canReadFlow (Flow rvar _) = not <$> isEmptyTMVar rvar +canReadFlow (MappedFlow _ _ up) = canReadFlow up + +writeFlow :: Flow r w -> w -> STM () +writeFlow (Flow _ wvar) = putTMVar wvar . (:[]) +writeFlow (MappedFlow _ f up) = writeFlow up . f + +writeFlowBulk :: Flow r w -> [w] -> STM () +writeFlowBulk _ [] = return () +writeFlowBulk (Flow _ wvar) xs = putTMVar wvar xs +writeFlowBulk (MappedFlow _ f up) xs = writeFlowBulk up $ map f xs + +tryWriteFlow :: Flow r w -> w -> STM Bool +tryWriteFlow (Flow _ wvar) = tryPutTMVar wvar . (:[]) +tryWriteFlow (MappedFlow _ f up) = tryWriteFlow up . f + +canWriteFlow :: Flow r w -> STM Bool +canWriteFlow (Flow _ wvar) = isEmptyTMVar wvar +canWriteFlow (MappedFlow _ _ up) = canWriteFlow up + +readFlowIO :: Flow r w -> IO r +readFlowIO path = atomically $ readFlow path + +writeFlowIO :: Flow r w -> w -> IO () +writeFlowIO path = atomically . writeFlow path + + +mapFlow :: (r -> r') -> (w' -> w) -> Flow r w -> Flow r' w' +mapFlow rf wf (MappedFlow rf' wf' up) = MappedFlow (rf . rf') (wf' . wf) up +mapFlow rf wf up = MappedFlow rf wf up diff --git a/src/Erebos/ICE.chs b/src/Erebos/ICE.chs new file mode 100644 index 0000000..096ee0d --- /dev/null +++ b/src/Erebos/ICE.chs @@ -0,0 +1,205 @@ +{-# LANGUAGE ForeignFunctionInterface #-} +{-# LANGUAGE RecursiveDo #-} + +module Erebos.ICE ( + IceSession, + IceSessionRole(..), + IceRemoteInfo, + + iceCreate, + iceDestroy, + iceRemoteInfo, + iceShow, + iceConnect, + iceSend, + + iceSetChan, +) where + +import Control.Arrow +import Control.Concurrent.MVar +import Control.Monad +import Control.Monad.Except +import Control.Monad.Identity + +import Data.ByteString (ByteString, packCStringLen, useAsCString) +import qualified Data.ByteString.Lazy.Char8 as BLC +import Data.ByteString.Unsafe +import Data.Function +import Data.Text (Text) +import qualified Data.Text as T +import qualified Data.Text.Encoding as T +import qualified Data.Text.Read as T +import Data.Void + +import Foreign.C.String +import Foreign.C.Types +import Foreign.Marshal.Alloc +import Foreign.Marshal.Array +import Foreign.Ptr +import Foreign.StablePtr + +import Erebos.Flow +import Erebos.Storage + +#include "pjproject.h" + +data IceSession = IceSession + { isStrans :: PjIceStrans + , isChan :: MVar (Either [ByteString] (Flow Void ByteString)) + } + +instance Eq IceSession where + (==) = (==) `on` isStrans + +instance Ord IceSession where + compare = compare `on` isStrans + +instance Show IceSession where + show _ = "<ICE>" + + +data IceRemoteInfo = IceRemoteInfo + { iriUsernameFrament :: Text + , iriPassword :: Text + , iriDefaultCandidate :: Text + , iriCandidates :: [Text] + } + +data IceCandidate = IceCandidate + { icandFoundation :: Text + , icandPriority :: Int + , icandAddr :: Text + , icandPort :: Int + , icandType :: Text + } + +instance Storable IceRemoteInfo where + store' x = storeRec $ do + storeText "ice-ufrag" $ iriUsernameFrament x + storeText "ice-pass" $ iriPassword x + storeText "ice-default" $ iriDefaultCandidate x + mapM_ (storeText "ice-candidate") $ iriCandidates x + + load' = loadRec $ IceRemoteInfo + <$> loadText "ice-ufrag" + <*> loadText "ice-pass" + <*> loadText "ice-default" + <*> loadTexts "ice-candidate" + +instance StorableText IceCandidate where + toText x = T.concat $ + [ icandFoundation x + , T.singleton ' ' + , T.pack $ show $ icandPriority x + , T.singleton ' ' + , icandAddr x + , T.singleton ' ' + , T.pack $ show $ icandPort x + , T.singleton ' ' + , icandType x + ] + + fromText t = case T.words t of + [found, tprio, addr, tport, ctype] + | Right (prio, _) <- T.decimal tprio + , Right (port, _) <- T.decimal tport + -> return $ IceCandidate + { icandFoundation = found + , icandPriority = prio + , icandAddr = addr + , icandPort = port + , icandType = ctype + } + _ -> throwError "failed to parse candidate" + + +{#enum pj_ice_sess_role as IceSessionRole {underscoreToCase} deriving (Show, Eq) #} + +{#pointer *pj_ice_strans as ^ #} + +iceCreate :: IceSessionRole -> (IceSession -> IO ()) -> IO IceSession +iceCreate role cb = do + rec sptr <- newStablePtr sess + cbptr <- newStablePtr $ cb sess + sess <- IceSession + <$> {#call ice_create #} (fromIntegral $ fromEnum role) (castStablePtrToPtr sptr) (castStablePtrToPtr cbptr) + <*> (newMVar $ Left []) + return $ sess + +{#fun ice_destroy as ^ { isStrans `IceSession' } -> `()' #} + +iceRemoteInfo :: IceSession -> IO IceRemoteInfo +iceRemoteInfo sess = do + let maxlen = 128 + maxcand = 29 + + allocaBytes maxlen $ \ufrag -> + allocaBytes maxlen $ \pass -> + allocaBytes maxlen $ \def -> + allocaBytes (maxcand*maxlen) $ \bytes -> + allocaArray maxcand $ \carr -> do + let cptrs = take maxcand $ iterate (`plusPtr` maxlen) bytes + pokeArray carr $ take maxcand cptrs + + ncand <- {#call ice_encode_session #} (isStrans sess) ufrag pass def carr (fromIntegral maxlen) (fromIntegral maxcand) + if ncand < 0 then fail "failed to generate ICE remote info" + else IceRemoteInfo + <$> (T.pack <$> peekCString ufrag) + <*> (T.pack <$> peekCString pass) + <*> (T.pack <$> peekCString def) + <*> (mapM (return . T.pack <=< peekCString) $ take (fromIntegral ncand) cptrs) + +iceShow :: IceSession -> IO String +iceShow sess = do + st <- memoryStorage + return . drop 1 . dropWhile (/='\n') . BLC.unpack . runIdentity =<< + ioLoadBytes =<< store st =<< iceRemoteInfo sess + +iceConnect :: IceSession -> IceRemoteInfo -> (IO ()) -> IO () +iceConnect sess remote cb = do + cbptr <- newStablePtr $ cb + ice_connect sess cbptr + (iriUsernameFrament remote) + (iriPassword remote) + (iriDefaultCandidate remote) + (iriCandidates remote) + +{#fun ice_connect { isStrans `IceSession', castStablePtrToPtr `StablePtr (IO ())', + withText* `Text', withText* `Text', withText* `Text', withTextArray* `[Text]'& } -> `()' #} + +withText :: Text -> (Ptr CChar -> IO a) -> IO a +withText t f = useAsCString (T.encodeUtf8 t) f + +withTextArray :: Num n => [Text] -> ((Ptr (Ptr CChar), n) -> IO ()) -> IO () +withTextArray tsAll f = helper tsAll [] + where helper (t:ts) bs = withText t $ \b -> helper ts (b:bs) + helper [] bs = allocaArray (length bs) $ \ptr -> do + pokeArray ptr $ reverse bs + f (ptr, fromIntegral $ length bs) + +withByteStringLen :: Num n => ByteString -> ((Ptr CChar, n) -> IO a) -> IO a +withByteStringLen t f = unsafeUseAsCStringLen t (f . (id *** fromIntegral)) + +{#fun ice_send as ^ { isStrans `IceSession', withByteStringLen* `ByteString'& } -> `()' #} + +foreign export ccall ice_call_cb :: StablePtr (IO ()) -> IO () +ice_call_cb :: StablePtr (IO ()) -> IO () +ice_call_cb = join . deRefStablePtr + +iceSetChan :: IceSession -> Flow Void ByteString -> IO () +iceSetChan sess chan = do + modifyMVar_ (isChan sess) $ \orig -> do + case orig of + Left buf -> mapM_ (writeFlowIO chan) $ reverse buf + Right _ -> return () + return $ Right chan + +foreign export ccall ice_rx_data :: StablePtr IceSession -> Ptr CChar -> Int -> IO () +ice_rx_data :: StablePtr IceSession -> Ptr CChar -> Int -> IO () +ice_rx_data sptr buf len = do + sess <- deRefStablePtr sptr + bs <- packCStringLen (buf, len) + modifyMVar_ (isChan sess) $ \case + mc@(Right chan) -> writeFlowIO chan bs >> return mc + Left bss -> return $ Left (bs:bss) diff --git a/src/Erebos/ICE/pjproject.c b/src/Erebos/ICE/pjproject.c new file mode 100644 index 0000000..bb06b1f --- /dev/null +++ b/src/Erebos/ICE/pjproject.c @@ -0,0 +1,363 @@ +#include "pjproject.h" +#include "Erebos/ICE_stub.h" + +#include <stdio.h> +#include <stdlib.h> +#include <stdbool.h> +#include <pthread.h> +#include <pjlib.h> +#include <pjlib-util.h> + +static struct +{ + pj_caching_pool cp; + pj_pool_t * pool; + pj_ice_strans_cfg cfg; + pj_sockaddr def_addr; +} ice; + +struct user_data +{ + pj_ice_sess_role role; + HsStablePtr sptr; + HsStablePtr cb_init; + HsStablePtr cb_connect; +}; + +static void ice_perror(const char * msg, pj_status_t status) +{ + char err[PJ_ERR_MSG_SIZE]; + pj_strerror(status, err, sizeof(err)); + fprintf(stderr, "ICE: %s: %s\n", msg, err); +} + +static int ice_worker_thread(void * unused) +{ + PJ_UNUSED_ARG(unused); + + while (true) { + pj_time_val max_timeout = { 0, 0 }; + pj_time_val timeout = { 0, 0 }; + + max_timeout.msec = 500; + + pj_timer_heap_poll(ice.cfg.stun_cfg.timer_heap, &timeout); + + pj_assert(timeout.sec >= 0 && timeout.msec >= 0); + if (timeout.msec >= 1000) + timeout.msec = 999; + + if (PJ_TIME_VAL_GT(timeout, max_timeout)) + timeout = max_timeout; + + int c = pj_ioqueue_poll(ice.cfg.stun_cfg.ioqueue, &timeout); + if (c < 0) + pj_thread_sleep(PJ_TIME_VAL_MSEC(timeout)); + } + + return 0; +} + +static void cb_on_rx_data(pj_ice_strans * strans, unsigned comp_id, + void * pkt, pj_size_t size, + const pj_sockaddr_t * src_addr, unsigned src_addr_len) +{ + struct user_data * udata = pj_ice_strans_get_user_data(strans); + ice_rx_data(udata->sptr, pkt, size); +} + +static void cb_on_ice_complete(pj_ice_strans * strans, + pj_ice_strans_op op, pj_status_t status) +{ + if (status != PJ_SUCCESS) { + ice_perror("cb_on_ice_complete", status); + ice_destroy(strans); + return; + } + + struct user_data * udata = pj_ice_strans_get_user_data(strans); + if (op == PJ_ICE_STRANS_OP_INIT) { + pj_status_t istatus = pj_ice_strans_init_ice(strans, udata->role, NULL, NULL); + if (istatus != PJ_SUCCESS) + ice_perror("error creating session", istatus); + + if (udata->cb_init) { + ice_call_cb(udata->cb_init); + hs_free_stable_ptr(udata->cb_init); + udata->cb_init = NULL; + } + } + + if (op == PJ_ICE_STRANS_OP_NEGOTIATION) { + if (udata->cb_connect) { + ice_call_cb(udata->cb_connect); + hs_free_stable_ptr(udata->cb_connect); + udata->cb_connect = NULL; + } + } +} + +static void ice_init(void) +{ + static bool done = false; + static pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER; + pthread_mutex_lock(&mutex); + + if (done) { + pthread_mutex_unlock(&mutex); + goto exit; + } + + pj_log_set_level(1); + + if (pj_init() != PJ_SUCCESS) { + fprintf(stderr, "pj_init failed\n"); + goto exit; + } + if (pjlib_util_init() != PJ_SUCCESS) { + fprintf(stderr, "pjlib_util_init failed\n"); + goto exit; + } + if (pjnath_init() != PJ_SUCCESS) { + fprintf(stderr, "pjnath_init failed\n"); + goto exit; + } + + pj_caching_pool_init(&ice.cp, NULL, 0); + + pj_ice_strans_cfg_default(&ice.cfg); + ice.cfg.stun_cfg.pf = &ice.cp.factory; + + ice.pool = pj_pool_create(&ice.cp.factory, "ice", 512, 512, NULL); + + if (pj_timer_heap_create(ice.pool, 100, + &ice.cfg.stun_cfg.timer_heap) != PJ_SUCCESS) { + fprintf(stderr, "pj_timer_heap_create failed\n"); + goto exit; + } + + if (pj_ioqueue_create(ice.pool, 16, &ice.cfg.stun_cfg.ioqueue) != PJ_SUCCESS) { + fprintf(stderr, "pj_ioqueue_create failed\n"); + goto exit; + } + + pj_thread_t * thread; + if (pj_thread_create(ice.pool, "ice", &ice_worker_thread, + NULL, 0, 0, &thread) != PJ_SUCCESS) { + fprintf(stderr, "pj_thread_create failed\n"); + goto exit; + } + + ice.cfg.af = pj_AF_INET(); + ice.cfg.opt.aggressive = PJ_TRUE; + + ice.cfg.stun.server.ptr = "discovery1.erebosprotocol.net"; + ice.cfg.stun.server.slen = strlen(ice.cfg.stun.server.ptr); + ice.cfg.stun.port = 29670; + + ice.cfg.turn.server = ice.cfg.stun.server; + ice.cfg.turn.port = ice.cfg.stun.port; + ice.cfg.turn.auth_cred.type = PJ_STUN_AUTH_CRED_STATIC; + ice.cfg.turn.auth_cred.data.static_cred.data_type = PJ_STUN_PASSWD_PLAIN; + ice.cfg.turn.conn_type = PJ_TURN_TP_UDP; + +exit: + done = true; + pthread_mutex_unlock(&mutex); +} + +pj_ice_strans * ice_create(pj_ice_sess_role role, HsStablePtr sptr, HsStablePtr cb) +{ + ice_init(); + + pj_ice_strans * res; + + struct user_data * udata = malloc(sizeof(struct user_data)); + udata->role = role; + udata->sptr = sptr; + udata->cb_init = cb; + + pj_ice_strans_cb icecb = { + .on_rx_data = cb_on_rx_data, + .on_ice_complete = cb_on_ice_complete, + }; + + pj_status_t status = pj_ice_strans_create(NULL, &ice.cfg, 1, + udata, &icecb, &res); + + if (status != PJ_SUCCESS) + ice_perror("error creating ice", status); + + return res; +} + +void ice_destroy(pj_ice_strans * strans) +{ + struct user_data * udata = pj_ice_strans_get_user_data(strans); + if (udata->sptr) + hs_free_stable_ptr(udata->sptr); + if (udata->cb_init) + hs_free_stable_ptr(udata->cb_init); + if (udata->cb_connect) + hs_free_stable_ptr(udata->cb_connect); + free(udata); + + pj_ice_strans_stop_ice(strans); + pj_ice_strans_destroy(strans); +} + +ssize_t ice_encode_session(pj_ice_strans * strans, char * ufrag, char * pass, + char * def, char * candidates[], size_t maxlen, size_t maxcand) +{ + int n; + pj_str_t local_ufrag, local_pwd; + pj_status_t status; + + pj_ice_strans_get_ufrag_pwd(strans, &local_ufrag, &local_pwd, NULL, NULL); + + n = snprintf(ufrag, maxlen, "%.*s", (int) local_ufrag.slen, local_ufrag.ptr); + if (n < 0 || n == maxlen) + return -PJ_ETOOSMALL; + + n = snprintf(pass, maxlen, "%.*s", (int) local_pwd.slen, local_pwd.ptr); + if (n < 0 || n == maxlen) + return -PJ_ETOOSMALL; + + pj_ice_sess_cand cand[PJ_ICE_ST_MAX_CAND]; + char ipaddr[PJ_INET6_ADDRSTRLEN]; + + status = pj_ice_strans_get_def_cand(strans, 1, &cand[0]); + if (status != PJ_SUCCESS) + return -status; + + n = snprintf(def, maxlen, "%s %d", + pj_sockaddr_print(&cand[0].addr, ipaddr, sizeof(ipaddr), 0), + (int) pj_sockaddr_get_port(&cand[0].addr)); + if (n < 0 || n == maxlen) + return -PJ_ETOOSMALL; + + unsigned cand_cnt = PJ_ARRAY_SIZE(cand); + status = pj_ice_strans_enum_cands(strans, 1, &cand_cnt, cand); + if (status != PJ_SUCCESS) + return -status; + + for (unsigned i = 0; i < cand_cnt && i < maxcand; i++) { + char ipaddr[PJ_INET6_ADDRSTRLEN]; + n = snprintf(candidates[i], maxlen, + "%.*s %u %s %u %s", + (int) cand[i].foundation.slen, cand[i].foundation.ptr, + cand[i].prio, + pj_sockaddr_print(&cand[i].addr, ipaddr, sizeof(ipaddr), 0), + (unsigned) pj_sockaddr_get_port(&cand[i].addr), + pj_ice_get_cand_type_name(cand[i].type)); + + if (n < 0 || n == maxlen) + return -PJ_ETOOSMALL; + } + + return cand_cnt; +} + +void ice_connect(pj_ice_strans * strans, HsStablePtr cb, + const char * ufrag, const char * pass, + const char * defcand, const char * tcandidates[], size_t ncand) +{ + unsigned def_port = 0; + char def_addr[80]; + pj_bool_t done = PJ_FALSE; + char line[256]; + pj_ice_sess_cand candidates[PJ_ICE_ST_MAX_CAND]; + + struct user_data * udata = pj_ice_strans_get_user_data(strans); + udata->cb_connect = cb; + + def_addr[0] = '\0'; + + if (ncand == 0) { + fprintf(stderr, "ICE: no candidates\n"); + return; + } + + int cnt = sscanf(defcand, "%s %u", def_addr, &def_port); + if (cnt != 2) { + fprintf(stderr, "ICE: error parsing default candidate\n"); + return; + } + + int okcand = 0; + for (int i = 0; i < ncand; i++) { + char foundation[32], ipaddr[80], type[32]; + int prio, port; + + int cnt = sscanf(tcandidates[i], "%s %d %s %d %s", + foundation, &prio, + ipaddr, &port, + type); + if (cnt != 5) + continue; + + pj_ice_sess_cand * cand = &candidates[okcand]; + pj_bzero(cand, sizeof(*cand)); + + if (strcmp(type, "host") == 0) + cand->type = PJ_ICE_CAND_TYPE_HOST; + else if (strcmp(type, "srflx") == 0) + cand->type = PJ_ICE_CAND_TYPE_SRFLX; + else if (strcmp(type, "relay") == 0) + cand->type = PJ_ICE_CAND_TYPE_RELAYED; + else + continue; + + cand->comp_id = 1; + pj_strdup2(ice.pool, &cand->foundation, foundation); + cand->prio = prio; + + int af = strchr(ipaddr, ':') ? pj_AF_INET6() : pj_AF_INET(); + pj_str_t tmpaddr = pj_str(ipaddr); + pj_sockaddr_init(af, &cand->addr, NULL, 0); + pj_status_t status = pj_sockaddr_set_str_addr(af, &cand->addr, &tmpaddr); + if (status != PJ_SUCCESS) { + fprintf(stderr, "ICE: invalid IP address \"%s\"\n", ipaddr); + continue; + } + + pj_sockaddr_set_port(&cand->addr, (pj_uint16_t)port); + okcand++; + } + + pj_str_t tmp_addr; + pj_status_t status; + + int af = strchr(def_addr, ':') ? pj_AF_INET6() : pj_AF_INET(); + + pj_sockaddr_init(af, &ice.def_addr, NULL, 0); + tmp_addr = pj_str(def_addr); + status = pj_sockaddr_set_str_addr(af, &ice.def_addr, &tmp_addr); + if (status != PJ_SUCCESS) { + fprintf(stderr, "ICE: invalid default IP address \"%s\"\n", def_addr); + return; + } + pj_sockaddr_set_port(&ice.def_addr, (pj_uint16_t) def_port); + + pj_str_t rufrag, rpwd; + status = pj_ice_strans_start_ice(strans, + pj_cstr(&rufrag, ufrag), pj_cstr(&rpwd, pass), + okcand, candidates); + if (status != PJ_SUCCESS) { + ice_perror("error starting ICE", status); + return; + } +} + +void ice_send(pj_ice_strans * strans, const char * data, size_t len) +{ + if (!pj_ice_strans_sess_is_complete(strans)) { + fprintf(stderr, "ICE: negotiation has not been started or is in progress\n"); + return; + } + + pj_status_t status = pj_ice_strans_sendto(strans, 1, data, len, + &ice.def_addr, pj_sockaddr_get_len(&ice.def_addr)); + if (status != PJ_SUCCESS && status != PJ_EPENDING) + ice_perror("error sending data", status); +} diff --git a/src/Erebos/ICE/pjproject.h b/src/Erebos/ICE/pjproject.h new file mode 100644 index 0000000..e230e75 --- /dev/null +++ b/src/Erebos/ICE/pjproject.h @@ -0,0 +1,14 @@ +#pragma once + +#include <pjnath.h> +#include <HsFFI.h> + +pj_ice_strans * ice_create(pj_ice_sess_role role, HsStablePtr sptr, HsStablePtr cb); +void ice_destroy(pj_ice_strans * strans); + +ssize_t ice_encode_session(pj_ice_strans *, char * ufrag, char * pass, + char * def, char * candidates[], size_t maxlen, size_t maxcand); +void ice_connect(pj_ice_strans * strans, HsStablePtr cb, + const char * ufrag, const char * pass, + const char * defcand, const char * candidates[], size_t ncand); +void ice_send(pj_ice_strans *, const char * data, size_t len); diff --git a/src/Erebos/Identity.hs b/src/Erebos/Identity.hs new file mode 100644 index 0000000..8761fde --- /dev/null +++ b/src/Erebos/Identity.hs @@ -0,0 +1,402 @@ +{-# LANGUAGE UndecidableInstances #-} + +module Erebos.Identity ( + Identity, ComposedIdentity, UnifiedIdentity, + IdentityData(..), ExtendedIdentityData(..), IdentityExtension(..), + idData, idDataF, idExtData, idExtDataF, + idName, idOwner, idUpdates, idKeyIdentity, idKeyMessage, + eiddBase, eiddStoredBase, + eiddName, eiddOwner, eiddKeyIdentity, eiddKeyMessage, + + emptyIdentityData, + emptyIdentityExtension, + createIdentity, + validateIdentity, validateIdentityF, validateIdentityFE, + validateExtendedIdentity, validateExtendedIdentityF, validateExtendedIdentityFE, + loadIdentity, loadUnifiedIdentity, + + mergeIdentity, toUnifiedIdentity, toComposedIdentity, + updateIdentity, updateOwners, + sameIdentity, + + unfoldOwners, + finalOwner, + displayIdentity, +) where + +import Control.Arrow +import Control.Monad +import Control.Monad.Except +import Control.Monad.Identity qualified as I +import Control.Monad.Reader + +import Data.Either +import Data.Foldable +import Data.Function +import Data.List +import Data.Maybe +import Data.Ord +import Data.Set (Set) +import qualified Data.Set as S +import Data.Text (Text) +import qualified Data.Text as T + +import Erebos.PubKey +import Erebos.Storage +import Erebos.Storage.Merge +import Erebos.Util + +data Identity m = IdentityKind m => Identity + { idData_ :: m (Stored (Signed ExtendedIdentityData)) + , idName_ :: Maybe Text + , idOwner_ :: Maybe ComposedIdentity + , idUpdates_ :: [Stored (Signed ExtendedIdentityData)] + , idKeyIdentity_ :: Stored PublicKey + , idKeyMessage_ :: Stored PublicKey + } + +deriving instance Show (m (Stored (Signed ExtendedIdentityData))) => Show (Identity m) + +class (Functor f, Foldable f) => IdentityKind f where + ikFilterAncestors :: Storable a => f (Stored a) -> f (Stored a) + +instance IdentityKind I.Identity where + ikFilterAncestors = id + +instance IdentityKind [] where + ikFilterAncestors = filterAncestors + +type ComposedIdentity = Identity [] +type UnifiedIdentity = Identity I.Identity + +instance Eq (m (Stored (Signed ExtendedIdentityData))) => Eq (Identity m) where + (==) = (==) `on` (idData_ &&& idUpdates_) + +data IdentityData = IdentityData + { iddPrev :: [Stored (Signed IdentityData)] + , iddName :: Maybe Text + , iddOwner :: Maybe (Stored (Signed IdentityData)) + , iddKeyIdentity :: Stored PublicKey + , iddKeyMessage :: Maybe (Stored PublicKey) + } + deriving (Show) + +data IdentityExtension = IdentityExtension + { idePrev :: [Stored (Signed ExtendedIdentityData)] + , ideBase :: Stored (Signed IdentityData) + , ideName :: Maybe Text + , ideOwner :: Maybe (Stored (Signed ExtendedIdentityData)) + } + deriving (Show) + +data ExtendedIdentityData = BaseIdentityData IdentityData + | ExtendedIdentityData IdentityExtension + deriving (Show) + +baseToExtended :: Stored (Signed IdentityData) -> Stored (Signed ExtendedIdentityData) +baseToExtended = unsafeMapStored (unsafeMapSigned BaseIdentityData) + +instance Storable IdentityData where + store' idt = storeRec $ do + mapM_ (storeRef "SPREV") $ iddPrev idt + storeMbText "name" $ iddName idt + storeMbRef "owner" $ iddOwner idt + storeRef "key-id" $ iddKeyIdentity idt + storeMbRef "key-msg" $ iddKeyMessage idt + + load' = loadRec $ IdentityData + <$> loadRefs "SPREV" + <*> loadMbText "name" + <*> loadMbRef "owner" + <*> loadRef "key-id" + <*> loadMbRef "key-msg" + +instance Storable IdentityExtension where + store' IdentityExtension {..} = storeRec $ do + mapM_ (storeRef "SPREV") idePrev + storeRef "SBASE" ideBase + storeMbText "name" ideName + storeMbRef "owner" ideOwner + + load' = loadRec $ IdentityExtension + <$> loadRefs "SPREV" + <*> loadRef "SBASE" + <*> loadMbText "name" + <*> loadMbRef "owner" + +instance Storable ExtendedIdentityData where + store' (BaseIdentityData idata) = store' idata + store' (ExtendedIdentityData idata) = store' idata + + load' = msum + [ BaseIdentityData <$> load' + , ExtendedIdentityData <$> load' + ] + +instance Mergeable (Maybe ComposedIdentity) where + type Component (Maybe ComposedIdentity) = Signed ExtendedIdentityData + mergeSorted = validateExtendedIdentityF + toComponents = maybe [] idExtDataF + +idData :: UnifiedIdentity -> Stored (Signed IdentityData) +idData = I.runIdentity . idDataF + +idDataF :: Identity m -> m (Stored (Signed IdentityData)) +idDataF idt@Identity {} = ikFilterAncestors . fmap eiddStoredBase . idData_ $ idt + +idExtData :: UnifiedIdentity -> Stored (Signed ExtendedIdentityData) +idExtData = I.runIdentity . idExtDataF + +idExtDataF :: Identity m -> m (Stored (Signed ExtendedIdentityData)) +idExtDataF = idData_ + +idName :: Identity m -> Maybe Text +idName = idName_ + +idOwner :: Identity m -> Maybe ComposedIdentity +idOwner = idOwner_ + +idUpdates :: Identity m -> [Stored (Signed ExtendedIdentityData)] +idUpdates = idUpdates_ + +idKeyIdentity :: Identity m -> Stored PublicKey +idKeyIdentity = idKeyIdentity_ + +idKeyMessage :: Identity m -> Stored PublicKey +idKeyMessage = idKeyMessage_ + +eiddPrev :: ExtendedIdentityData -> [Stored (Signed ExtendedIdentityData)] +eiddPrev (BaseIdentityData idata) = baseToExtended <$> iddPrev idata +eiddPrev (ExtendedIdentityData IdentityExtension {..}) = baseToExtended ideBase : idePrev + +eiddBase :: ExtendedIdentityData -> IdentityData +eiddBase (BaseIdentityData idata) = idata +eiddBase (ExtendedIdentityData IdentityExtension {..}) = fromSigned ideBase + +eiddStoredBase :: Stored (Signed ExtendedIdentityData) -> Stored (Signed IdentityData) +eiddStoredBase ext = case fromSigned ext of + (BaseIdentityData idata) -> unsafeMapStored (unsafeMapSigned (const idata)) ext + (ExtendedIdentityData IdentityExtension {..}) -> ideBase + +eiddName :: ExtendedIdentityData -> Maybe Text +eiddName (BaseIdentityData idata) = iddName idata +eiddName (ExtendedIdentityData IdentityExtension {..}) = ideName + +eiddOwner :: ExtendedIdentityData -> Maybe (Stored (Signed ExtendedIdentityData)) +eiddOwner (BaseIdentityData idata) = baseToExtended <$> iddOwner idata +eiddOwner (ExtendedIdentityData IdentityExtension {..}) = ideOwner + +eiddKeyIdentity :: ExtendedIdentityData -> Stored PublicKey +eiddKeyIdentity = iddKeyIdentity . eiddBase + +eiddKeyMessage :: ExtendedIdentityData -> Maybe (Stored PublicKey) +eiddKeyMessage = iddKeyMessage . eiddBase + + +emptyIdentityData :: Stored PublicKey -> IdentityData +emptyIdentityData key = IdentityData + { iddName = Nothing + , iddPrev = [] + , iddOwner = Nothing + , iddKeyIdentity = key + , iddKeyMessage = Nothing + } + +emptyIdentityExtension :: Stored (Signed IdentityData) -> IdentityExtension +emptyIdentityExtension base = IdentityExtension + { idePrev = [] + , ideBase = base + , ideName = Nothing + , ideOwner = Nothing + } + +isExtension :: Stored (Signed ExtendedIdentityData) -> Bool +isExtension x = case fromSigned x of BaseIdentityData {} -> False + _ -> True + + +createIdentity :: Storage -> Maybe Text -> Maybe UnifiedIdentity -> IO UnifiedIdentity +createIdentity st name owner = do + (secret, public) <- generateKeys st + (_secretMsg, publicMsg) <- generateKeys st + + let signOwner :: Signed a -> ReaderT Storage IO (Signed a) + signOwner idd + | Just o <- owner = do + Just ownerSecret <- loadKeyMb (iddKeyIdentity $ fromSigned $ idData o) + signAdd ownerSecret idd + | otherwise = return idd + + Just identity <- flip runReaderT st $ do + baseData <- mstore =<< signOwner =<< sign secret =<< + mstore (emptyIdentityData public) + { iddOwner = idData <$> owner + , iddKeyMessage = Just publicMsg + } + let extOwner = do + odata <- idExtData <$> owner + guard $ isExtension odata + return odata + + validateExtendedIdentityF . I.Identity <$> + if isJust name || isJust extOwner + then mstore =<< signOwner =<< sign secret =<< + mstore . ExtendedIdentityData =<< return (emptyIdentityExtension baseData) + { ideName = name + , ideOwner = extOwner + } + else return $ baseToExtended baseData + return identity + +validateIdentity :: Stored (Signed IdentityData) -> Maybe UnifiedIdentity +validateIdentity = validateIdentityF . I.Identity + +validateIdentityF :: IdentityKind m => m (Stored (Signed IdentityData)) -> Maybe (Identity m) +validateIdentityF = either (const Nothing) Just . runExcept . validateIdentityFE + +validateIdentityFE :: IdentityKind m => m (Stored (Signed IdentityData)) -> Except String (Identity m) +validateIdentityFE = validateExtendedIdentityFE . fmap baseToExtended + +validateExtendedIdentity :: Stored (Signed ExtendedIdentityData) -> Maybe UnifiedIdentity +validateExtendedIdentity = validateExtendedIdentityF . I.Identity + +validateExtendedIdentityF :: IdentityKind m => m (Stored (Signed ExtendedIdentityData)) -> Maybe (Identity m) +validateExtendedIdentityF = either (const Nothing) Just . runExcept . validateExtendedIdentityFE + +validateExtendedIdentityFE :: IdentityKind m => m (Stored (Signed ExtendedIdentityData)) -> Except String (Identity m) +validateExtendedIdentityFE mdata = do + let idata = ikFilterAncestors mdata + when (null idata) $ throwError "null data" + mapM_ verifySignatures $ gatherPrevious S.empty $ toList idata + Identity + <$> pure idata + <*> pure (lookupProperty eiddName idata) + <*> case lookupProperty eiddOwner idata of + Nothing -> return Nothing + Just owner -> return <$> validateExtendedIdentityFE [owner] + <*> pure [] + <*> pure (eiddKeyIdentity $ fromSigned $ minimum idata) + <*> case lookupProperty eiddKeyMessage idata of + Nothing -> throwError "no message key" + Just mk -> return mk + +loadIdentity :: String -> LoadRec ComposedIdentity +loadIdentity name = maybe (throwError "identity validation failed") return . validateExtendedIdentityF =<< loadRefs name + +loadUnifiedIdentity :: String -> LoadRec UnifiedIdentity +loadUnifiedIdentity name = maybe (throwError "identity validation failed") return . validateExtendedIdentity =<< loadRef name + + +gatherPrevious :: Set (Stored (Signed ExtendedIdentityData)) -> [Stored (Signed ExtendedIdentityData)] -> Set (Stored (Signed ExtendedIdentityData)) +gatherPrevious res (n:ns) | n `S.member` res = gatherPrevious res ns + | otherwise = gatherPrevious (S.insert n res) $ (eiddPrev $ fromSigned n) ++ ns +gatherPrevious res [] = res + +verifySignatures :: Stored (Signed ExtendedIdentityData) -> Except String () +verifySignatures sidd = do + let idd = fromSigned sidd + required = concat + [ [ eiddKeyIdentity idd ] + , map (eiddKeyIdentity . fromSigned) $ eiddPrev idd + , map (eiddKeyIdentity . fromSigned) $ toList $ eiddOwner idd + ] + unless (all (fromStored sidd `isSignedBy`) required) $ do + throwError "signature verification failed" + +lookupProperty :: forall a m. Foldable m => (ExtendedIdentityData -> Maybe a) -> m (Stored (Signed ExtendedIdentityData)) -> Maybe a +lookupProperty sel topHeads = findResult filteredLayers + where findPropHeads :: Stored (Signed ExtendedIdentityData) -> [(Stored (Signed ExtendedIdentityData), a)] + findPropHeads sobj | Just x <- sel $ fromSigned sobj = [(sobj, x)] + | otherwise = findPropHeads =<< (eiddPrev $ fromSigned sobj) + + propHeads :: [(Stored (Signed ExtendedIdentityData), a)] + propHeads = findPropHeads =<< toList topHeads + + historyLayers :: [Set (Stored (Signed ExtendedIdentityData))] + historyLayers = generations $ map fst propHeads + + filteredLayers :: [[(Stored (Signed ExtendedIdentityData), a)]] + filteredLayers = scanl (\cur obsolete -> filter ((`S.notMember` obsolete) . fst) cur) propHeads historyLayers + + findResult ([(_, x)] : _) = Just x + findResult ([] : _) = Nothing + findResult [] = Nothing + findResult [xs] = Just $ snd $ minimumBy (comparing fst) xs + findResult (_:rest) = findResult rest + +mergeIdentity :: (MonadStorage m, MonadError String m, MonadIO m) => Identity f -> m UnifiedIdentity +mergeIdentity idt | Just idt' <- toUnifiedIdentity idt = return idt' +mergeIdentity idt@Identity {..} = do + (owner, ownerData) <- case idOwner_ of + Nothing -> return (Nothing, Nothing) + Just cowner | Just owner <- toUnifiedIdentity cowner -> return (Just owner, Nothing) + | otherwise -> do owner <- mergeIdentity cowner + return (Just owner, Just $ idData owner) + + let public = idKeyIdentity idt + secret <- loadKey public + + unifiedBaseData <- + case toList $ idDataF idt of + [idata] -> return idata + idatas -> mstore =<< sign secret =<< mstore (emptyIdentityData public) + { iddPrev = idatas, iddOwner = ownerData } + + case filter isExtension $ toList $ idExtDataF idt of + [] -> return Identity { idData_ = I.Identity (baseToExtended unifiedBaseData), idOwner_ = toComposedIdentity <$> owner, .. } + extdata -> do + unifiedExtendedData <- mstore =<< sign secret =<< + (mstore . ExtendedIdentityData) (emptyIdentityExtension unifiedBaseData) + { idePrev = extdata } + return Identity { idData_ = I.Identity unifiedExtendedData, idOwner_ = toComposedIdentity <$> owner, .. } + + +toUnifiedIdentity :: Identity m -> Maybe UnifiedIdentity +toUnifiedIdentity Identity {..} + | [sdata] <- toList idData_ = Just Identity { idData_ = I.Identity sdata, .. } + | otherwise = Nothing + +toComposedIdentity :: Identity m -> ComposedIdentity +toComposedIdentity Identity {..} = Identity { idData_ = toList idData_ + , idOwner_ = toComposedIdentity <$> idOwner_ + , .. + } + +updateIdentity :: [Stored (Signed ExtendedIdentityData)] -> Identity m -> ComposedIdentity +updateIdentity [] orig = toComposedIdentity orig +updateIdentity updates orig@Identity {} = + case validateExtendedIdentityF $ ourUpdates ++ idata of + Just updated -> updated + { idOwner_ = updateIdentity ownerUpdates <$> idOwner_ updated + , idUpdates_ = ownerUpdates + } + Nothing -> toComposedIdentity orig + where idata = toList $ idData_ orig + idataRoots = foldl' mergeUniq [] $ map storedRoots idata + (ourUpdates, ownerUpdates) = partitionEithers $ flip map (filterAncestors $ updates ++ idUpdates_ orig) $ + -- if an update is related to anything in idData_, use it here, otherwise push to owners + \u -> if storedRoots u `intersectsSorted` idataRoots + then Left u + else Right u + +updateOwners :: [Stored (Signed ExtendedIdentityData)] -> Identity m -> Identity m +updateOwners updates orig@Identity { idOwner_ = Just owner, idUpdates_ = cupdates } = + orig { idOwner_ = Just $ updateIdentity updates owner, idUpdates_ = filterAncestors (updates ++ cupdates) } +updateOwners _ orig@Identity { idOwner_ = Nothing } = orig + +sameIdentity :: (Foldable m, Foldable m') => Identity m -> Identity m' -> Bool +sameIdentity x y = not $ S.null $ S.intersection (refset x) (refset y) + where refset idt = foldr S.insert (ancestors $ toList $ idDataF idt) (idDataF idt) + + +unfoldOwners :: (Foldable m) => Identity m -> [ComposedIdentity] +unfoldOwners = unfoldr (fmap (\i -> (i, idOwner i))) . Just . toComposedIdentity + +finalOwner :: (Foldable m, Applicative m) => Identity m -> ComposedIdentity +finalOwner = last . unfoldOwners + +displayIdentity :: (Foldable m, Applicative m) => Identity m -> Text +displayIdentity identity = T.concat + [ T.intercalate (T.pack " / ") $ map (fromMaybe (T.pack "<unnamed>") . idName) owners + ] + where owners = reverse $ unfoldOwners identity diff --git a/src/Erebos/Message.hs b/src/Erebos/Message.hs new file mode 100644 index 0000000..50e4acb --- /dev/null +++ b/src/Erebos/Message.hs @@ -0,0 +1,266 @@ +module Erebos.Message ( + DirectMessage(..), + sendDirectMessage, + + DirectMessageAttributes(..), + defaultDirectMessageAttributes, + + DirectMessageThreads, + toThreadList, + + DirectMessageThread(..), + threadToList, + messageThreadView, + + watchReceivedMessages, + formatMessage, +) where + +import Control.Monad.Except +import Control.Monad.Reader + +import Data.List +import Data.Ord +import qualified Data.Set as S +import Data.Text (Text) +import qualified Data.Text as T +import Data.Time.Format +import Data.Time.LocalTime + +import Erebos.Identity +import Erebos.Network +import Erebos.Service +import Erebos.State +import Erebos.Storage +import Erebos.Storage.Merge + +data DirectMessage = DirectMessage + { msgFrom :: ComposedIdentity + , msgPrev :: [Stored DirectMessage] + , msgTime :: ZonedTime + , msgText :: Text + } + +instance Storable DirectMessage where + store' msg = storeRec $ do + mapM_ (storeRef "from") $ idExtDataF $ msgFrom msg + mapM_ (storeRef "PREV") $ msgPrev msg + storeDate "time" $ msgTime msg + storeText "text" $ msgText msg + + load' = loadRec $ DirectMessage + <$> loadIdentity "from" + <*> loadRefs "PREV" + <*> loadDate "time" + <*> loadText "text" + +data DirectMessageAttributes = DirectMessageAttributes + { dmOwnerMismatch :: ServiceHandler DirectMessage () + } + +defaultDirectMessageAttributes :: DirectMessageAttributes +defaultDirectMessageAttributes = DirectMessageAttributes + { dmOwnerMismatch = svcPrint "Owner mismatch" + } + +instance Service DirectMessage where + serviceID _ = mkServiceID "c702076c-4928-4415-8b6b-3e839eafcb0d" + + type ServiceAttributes DirectMessage = DirectMessageAttributes + defaultServiceAttributes _ = defaultDirectMessageAttributes + + serviceHandler smsg = do + let msg = fromStored smsg + powner <- asks $ finalOwner . svcPeerIdentity + erb <- svcGetLocal + st <- getStorage + let DirectMessageThreads prev _ = lookupSharedValue $ lsShared $ fromStored erb + sent = findMsgProperty powner msSent prev + received = findMsgProperty powner msReceived prev + received' = filterAncestors $ smsg : received + if powner `sameIdentity` msgFrom msg || + filterAncestors sent == filterAncestors (smsg : sent) + then do + when (received' /= received) $ do + next <- wrappedStore st $ MessageState + { msPrev = prev + , msPeer = powner + , msReady = [] + , msSent = [] + , msReceived = received' + , msSeen = [] + } + let threads = DirectMessageThreads [next] (messageThreadView [next]) + shared <- makeSharedStateUpdate st threads (lsShared $ fromStored erb) + svcSetLocal =<< wrappedStore st (fromStored erb) { lsShared = [shared] } + + when (powner `sameIdentity` msgFrom msg) $ do + replyStoredRef smsg + + else join $ asks $ dmOwnerMismatch . svcAttributes + + serviceNewPeer = syncDirectMessageToPeer . lookupSharedValue . lsShared . fromStored =<< svcGetLocal + + serviceStorageWatchers _ = (:[]) $ + SomeStorageWatcher (lookupSharedValue . lsShared . fromStored) syncDirectMessageToPeer + + +data MessageState = MessageState + { msPrev :: [Stored MessageState] + , msPeer :: ComposedIdentity + , msReady :: [Stored DirectMessage] + , msSent :: [Stored DirectMessage] + , msReceived :: [Stored DirectMessage] + , msSeen :: [Stored DirectMessage] + } + +data DirectMessageThreads = DirectMessageThreads [Stored MessageState] [DirectMessageThread] + +instance Eq DirectMessageThreads where + DirectMessageThreads mss _ == DirectMessageThreads mss' _ = mss == mss' + +toThreadList :: DirectMessageThreads -> [DirectMessageThread] +toThreadList (DirectMessageThreads _ threads) = threads + +instance Storable MessageState where + store' MessageState {..} = storeRec $ do + mapM_ (storeRef "PREV") msPrev + mapM_ (storeRef "peer") $ idDataF msPeer + mapM_ (storeRef "ready") msReady + mapM_ (storeRef "sent") msSent + mapM_ (storeRef "received") msReceived + mapM_ (storeRef "seen") msSeen + + load' = loadRec $ do + msPrev <- loadRefs "PREV" + msPeer <- loadIdentity "peer" + msReady <- loadRefs "ready" + msSent <- loadRefs "sent" + msReceived <- loadRefs "received" + msSeen <- loadRefs "seen" + return MessageState {..} + +instance Mergeable DirectMessageThreads where + type Component DirectMessageThreads = MessageState + mergeSorted mss = DirectMessageThreads mss (messageThreadView mss) + toComponents (DirectMessageThreads mss _) = mss + +instance SharedType DirectMessageThreads where + sharedTypeID _ = mkSharedTypeID "ee793681-5976-466a-b0f0-4e1907d3fade" + +findMsgProperty :: Foldable m => Identity m -> (MessageState -> [a]) -> [Stored MessageState] -> [a] +findMsgProperty pid sel mss = concat $ flip findProperty mss $ \x -> do + guard $ msPeer x `sameIdentity` pid + guard $ not $ null $ sel x + return $ sel x + + +sendDirectMessage :: (Foldable f, Applicative f, MonadHead LocalState m, MonadError String m) + => Identity f -> Text -> m (Stored DirectMessage) +sendDirectMessage pid text = updateLocalHead $ \ls -> do + let self = localIdentity $ fromStored ls + powner = finalOwner pid + flip updateSharedState ls $ \(DirectMessageThreads prev _) -> do + let ready = findMsgProperty powner msReady prev + received = findMsgProperty powner msReceived prev + + time <- liftIO getZonedTime + smsg <- mstore DirectMessage + { msgFrom = toComposedIdentity $ finalOwner self + , msgPrev = filterAncestors $ ready ++ received + , msgTime = time + , msgText = text + } + next <- mstore MessageState + { msPrev = prev + , msPeer = powner + , msReady = [smsg] + , msSent = [] + , msReceived = [] + , msSeen = [] + } + return (DirectMessageThreads [next] (messageThreadView [next]), smsg) + +syncDirectMessageToPeer :: DirectMessageThreads -> ServiceHandler DirectMessage () +syncDirectMessageToPeer (DirectMessageThreads mss _) = do + pid <- finalOwner <$> asks svcPeerIdentity + peer <- asks svcPeer + let thread = messageThreadFor pid mss + mapM_ (sendToPeerStored peer) $ msgHead thread + updateLocalHead_ $ \ls -> do + let powner = finalOwner pid + flip updateSharedState_ ls $ \unchanged@(DirectMessageThreads prev _) -> do + let ready = findMsgProperty powner msReady prev + sent = findMsgProperty powner msSent prev + sent' = filterAncestors (ready ++ sent) + + if sent' /= sent + then do + next <- mstore MessageState + { msPrev = prev + , msPeer = powner + , msReady = [] + , msSent = sent' + , msReceived = [] + , msSeen = [] + } + return $ DirectMessageThreads [next] (messageThreadView [next]) + else do + return unchanged + + +data DirectMessageThread = DirectMessageThread + { msgPeer :: ComposedIdentity + , msgHead :: [Stored DirectMessage] + , msgSent :: [Stored DirectMessage] + , msgSeen :: [Stored DirectMessage] + } + +threadToList :: DirectMessageThread -> [DirectMessage] +threadToList thread = helper S.empty $ msgHead thread + where helper seen msgs + | msg : msgs' <- filter (`S.notMember` seen) $ reverse $ sortBy (comparing cmpView) msgs = + fromStored msg : helper (S.insert msg seen) (msgs' ++ msgPrev (fromStored msg)) + | otherwise = [] + cmpView msg = (zonedTimeToUTC $ msgTime $ fromStored msg, msg) + +messageThreadView :: [Stored MessageState] -> [DirectMessageThread] +messageThreadView = helper [] + where helper used ms' = case filterAncestors ms' of + mss@(sms : rest) + | any (sameIdentity $ msPeer $ fromStored sms) used -> + helper used $ msPrev (fromStored sms) ++ rest + | otherwise -> + let peer = msPeer $ fromStored sms + in messageThreadFor peer mss : helper (peer : used) (msPrev (fromStored sms) ++ rest) + _ -> [] + +messageThreadFor :: ComposedIdentity -> [Stored MessageState] -> DirectMessageThread +messageThreadFor peer mss = + let ready = findMsgProperty peer msReady mss + sent = findMsgProperty peer msSent mss + received = findMsgProperty peer msReceived mss + seen = findMsgProperty peer msSeen mss + + in DirectMessageThread + { msgPeer = peer + , msgHead = filterAncestors $ ready ++ received + , msgSent = filterAncestors $ sent ++ received + , msgSeen = filterAncestors $ ready ++ seen + } + + +watchReceivedMessages :: Head LocalState -> (Stored DirectMessage -> IO ()) -> IO WatchedHead +watchReceivedMessages h f = do + let self = finalOwner $ localIdentity $ headObject h + watchHeadWith h (lookupSharedValue . lsShared . headObject) $ \(DirectMessageThreads sms _) -> do + forM_ (map fromStored sms) $ \ms -> do + mapM_ f $ filter (not . sameIdentity self . msgFrom . fromStored) $ msReceived ms + +formatMessage :: TimeZone -> DirectMessage -> String +formatMessage tzone msg = concat + [ formatTime defaultTimeLocale "[%H:%M] " $ utcToLocalTime tzone $ zonedTimeToUTC $ msgTime msg + , maybe "<unnamed>" T.unpack $ idName $ msgFrom msg + , ": " + , T.unpack $ msgText msg + ] diff --git a/src/Erebos/Network.hs b/src/Erebos/Network.hs new file mode 100644 index 0000000..b26ada5 --- /dev/null +++ b/src/Erebos/Network.hs @@ -0,0 +1,815 @@ +{-# LANGUAGE CPP #-} + +module Erebos.Network ( + Server, + startServer, + stopServer, + getNextPeerChange, + ServerOptions(..), serverIdentity, defaultServerOptions, + + Peer, peerServer, peerStorage, + PeerAddress(..), peerAddress, + PeerIdentity(..), peerIdentity, + WaitingRef, wrDigest, + Service(..), + serverPeer, +#ifdef ENABLE_ICE_SUPPORT + serverPeerIce, +#endif + sendToPeer, sendToPeerStored, sendToPeerWith, + runPeerService, + + discoveryPort, +) where + +import Control.Concurrent +import Control.Concurrent.STM +import Control.Exception +import Control.Monad +import Control.Monad.Except +import Control.Monad.Reader +import Control.Monad.State + +import qualified Data.ByteString.Char8 as BC +import Data.Function +import Data.IP qualified as IP +import Data.List +import Data.Map (Map) +import qualified Data.Map as M +import Data.Maybe +import Data.Typeable +import Data.Word + +import Foreign.Ptr +import Foreign.Storable + +import GHC.Conc.Sync (unsafeIOToSTM) + +import Network.Socket hiding (ControlMessage) +import qualified Network.Socket.ByteString as S + +import Erebos.Channel +#ifdef ENABLE_ICE_SUPPORT +import Erebos.ICE +#endif +import Erebos.Identity +import Erebos.Network.Protocol +import Erebos.PubKey +import Erebos.Service +import Erebos.State +import Erebos.Storage +import Erebos.Storage.Key +import Erebos.Storage.Merge + + +discoveryPort :: PortNumber +discoveryPort = 29665 + +announceIntervalSeconds :: Int +announceIntervalSeconds = 60 + + +data Server = Server + { serverStorage :: Storage + , serverOrigHead :: Head LocalState + , serverIdentity_ :: MVar UnifiedIdentity + , serverThreads :: MVar [ThreadId] + , serverSocket :: MVar Socket + , serverRawPath :: SymFlow (PeerAddress, BC.ByteString) + , serverControlFlow :: Flow (ControlMessage PeerAddress) (ControlRequest PeerAddress) + , serverDataResponse :: TQueue (Peer, Maybe PartialRef) + , serverIOActions :: TQueue (ExceptT String IO ()) + , serverServices :: [SomeService] + , serverServiceStates :: TMVar (M.Map ServiceID SomeServiceGlobalState) + , serverPeers :: MVar (Map PeerAddress Peer) + , serverChanPeer :: TChan Peer + , serverErrorLog :: TQueue String + } + +serverIdentity :: Server -> IO UnifiedIdentity +serverIdentity = readMVar . serverIdentity_ + +getNextPeerChange :: Server -> IO Peer +getNextPeerChange = atomically . readTChan . serverChanPeer + +data ServerOptions = ServerOptions + { serverPort :: PortNumber + , serverLocalDiscovery :: Bool + } + +defaultServerOptions :: ServerOptions +defaultServerOptions = ServerOptions + { serverPort = discoveryPort + , serverLocalDiscovery = True + } + + +data Peer = Peer + { peerAddress :: PeerAddress + , peerServer_ :: Server + , peerConnection :: TVar (Either [(Bool, TransportPacket Ref, [TransportHeaderItem])] (Connection PeerAddress)) + , peerIdentityVar :: TVar PeerIdentity + , peerStorage_ :: Storage + , peerInStorage :: PartialStorage + , peerServiceState :: TMVar (M.Map ServiceID SomeServiceState) + , peerWaitingRefs :: TMVar [WaitingRef] + } + +peerServer :: Peer -> Server +peerServer = peerServer_ + +peerStorage :: Peer -> Storage +peerStorage = peerStorage_ + +getPeerChannel :: Peer -> STM ChannelState +getPeerChannel Peer {..} = either (const $ return ChannelNone) connGetChannel =<< readTVar peerConnection + +setPeerChannel :: Peer -> ChannelState -> STM () +setPeerChannel Peer {..} ch = do + readTVar peerConnection >>= \case + Left _ -> retry + Right conn -> connSetChannel conn ch + +instance Eq Peer where + (==) = (==) `on` peerIdentityVar + +data PeerAddress = DatagramAddress SockAddr +#ifdef ENABLE_ICE_SUPPORT + | PeerIceSession IceSession +#endif + +instance Show PeerAddress where + show (DatagramAddress saddr) = unwords $ case IP.fromSockAddr saddr of + Just (IP.IPv6 ipv6, port) + | (0, 0, 0xffff, ipv4) <- IP.fromIPv6w ipv6 + -> [show (IP.toIPv4w ipv4), show port] + Just (addr, port) + -> [show addr, show port] + _ -> [show saddr] +#ifdef ENABLE_ICE_SUPPORT + show (PeerIceSession ice) = show ice +#endif + +instance Eq PeerAddress where + DatagramAddress addr == DatagramAddress addr' = addr == addr' +#ifdef ENABLE_ICE_SUPPORT + PeerIceSession ice == PeerIceSession ice' = ice == ice' + _ == _ = False +#endif + +instance Ord PeerAddress where + compare (DatagramAddress addr) (DatagramAddress addr') = compare addr addr' +#ifdef ENABLE_ICE_SUPPORT + compare (DatagramAddress _ ) _ = LT + compare _ (DatagramAddress _ ) = GT + compare (PeerIceSession ice ) (PeerIceSession ice') = compare ice ice' +#endif + + +data PeerIdentity = PeerIdentityUnknown (TVar [UnifiedIdentity -> ExceptT String IO ()]) + | PeerIdentityRef WaitingRef (TVar [UnifiedIdentity -> ExceptT String IO ()]) + | PeerIdentityFull UnifiedIdentity + +peerIdentity :: MonadIO m => Peer -> m PeerIdentity +peerIdentity = liftIO . atomically . readTVar . peerIdentityVar + + +lookupServiceType :: [TransportHeaderItem] -> Maybe ServiceID +lookupServiceType (ServiceType stype : _) = Just stype +lookupServiceType (_ : hs) = lookupServiceType hs +lookupServiceType [] = Nothing + + +newWaitingRef :: RefDigest -> (Ref -> WaitingRefCallback) -> PacketHandler WaitingRef +newWaitingRef dgst act = do + peer@Peer {..} <- gets phPeer + wref <- WaitingRef peerStorage_ (partialRefFromDigest peerInStorage dgst) act <$> liftSTM (newTVar (Left [])) + modifyTMVarP peerWaitingRefs (wref:) + liftSTM $ writeTQueue (serverDataResponse $ peerServer peer) (peer, Nothing) + return wref + + +forkServerThread :: Server -> IO () -> IO () +forkServerThread server act = modifyMVar_ (serverThreads server) $ \ts -> do + (:ts) <$> forkIO act + +startServer :: ServerOptions -> Head LocalState -> (String -> IO ()) -> [SomeService] -> IO Server +startServer opt serverOrigHead logd' serverServices = do + let serverStorage = headStorage serverOrigHead + serverIdentity_ <- newMVar $ headLocalIdentity serverOrigHead + serverThreads <- newMVar [] + serverSocket <- newEmptyMVar + (serverRawPath, protocolRawPath) <- newFlowIO + (serverControlFlow, protocolControlFlow) <- newFlowIO + serverDataResponse <- newTQueueIO + serverIOActions <- newTQueueIO + serverServiceStates <- newTMVarIO M.empty + serverPeers <- newMVar M.empty + serverChanPeer <- newTChanIO + serverErrorLog <- newTQueueIO + let server = Server {..} + + chanSvc <- newTQueueIO + + let logd = writeTQueue serverErrorLog + forkServerThread server $ forever $ do + logd' =<< atomically (readTQueue serverErrorLog) + + forkServerThread server $ dataResponseWorker server + forkServerThread server $ forever $ do + either (atomically . logd) return =<< runExceptT =<< + atomically (readTQueue serverIOActions) + + broadcastAddreses <- getBroadcastAddresses discoveryPort + + let open addr = do + sock <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr) + putMVar serverSocket sock + setSocketOption sock ReuseAddr 1 + setSocketOption sock Broadcast 1 + withFdSocket sock setCloseOnExecIfNeeded + bind sock (addrAddress addr) + return sock + + loop sock = do + when (serverLocalDiscovery opt) $ forkServerThread server $ forever $ do + atomically $ writeFlowBulk serverControlFlow $ map (SendAnnounce . DatagramAddress) broadcastAddreses + threadDelay $ announceIntervalSeconds * 1000 * 1000 + + let announceUpdate identity = do + st <- derivePartialStorage serverStorage + let selfRef = partialRef st $ storedRef $ idExtData identity + updateRefs = map refDigest $ selfRef : map (partialRef st . storedRef) (idUpdates identity) + ackedBy = concat [[ Acknowledged r, Rejected r, DataRequest r ] | r <- updateRefs ] + hitems = map AnnounceUpdate updateRefs + packet = TransportPacket (TransportHeader $ hitems) [] + + ps <- readMVar serverPeers + forM_ ps $ \peer -> atomically $ do + ((,) <$> readTVar (peerIdentityVar peer) <*> getPeerChannel peer) >>= \case + (PeerIdentityFull _, ChannelEstablished _) -> + sendToPeerS peer ackedBy packet + _ -> return () + + void $ watchHead serverOrigHead $ \h -> do + let idt = headLocalIdentity h + changedId <- modifyMVar serverIdentity_ $ \cur -> + return (idt, cur /= idt) + when changedId $ do + writeFlowIO serverControlFlow $ UpdateSelfIdentity idt + announceUpdate idt + + forM_ serverServices $ \(SomeService service _) -> do + forM_ (serviceStorageWatchers service) $ \(SomeStorageWatcher sel act) -> do + watchHeadWith serverOrigHead (sel . headStoredObject) $ \x -> do + withMVar serverPeers $ mapM_ $ \peer -> atomically $ do + readTVar (peerIdentityVar peer) >>= \case + PeerIdentityFull _ -> writeTQueue serverIOActions $ do + runPeerService peer $ act x + _ -> return () + + forkServerThread server $ forever $ do + (msg, saddr) <- S.recvFrom sock 4096 + writeFlowIO serverRawPath (DatagramAddress saddr, msg) + + forkServerThread server $ forever $ do + (paddr, msg) <- readFlowIO serverRawPath + case paddr of + DatagramAddress addr -> void $ S.sendTo sock msg addr +#ifdef ENABLE_ICE_SUPPORT + PeerIceSession ice -> iceSend ice msg +#endif + + forkServerThread server $ forever $ do + readFlowIO serverControlFlow >>= \case + NewConnection conn mbpid -> do + let paddr = connAddress conn + peer <- modifyMVar serverPeers $ \pvalue -> do + case M.lookup paddr pvalue of + Just peer -> return (pvalue, peer) + Nothing -> do + peer <- mkPeer server paddr + return (M.insert paddr peer pvalue, peer) + + forkServerThread server $ do + atomically $ do + readTVar (peerConnection peer) >>= \case + Left packets -> writeFlowBulk (connData conn) $ reverse packets + Right _ -> return () + writeTVar (peerConnection peer) (Right conn) + + case mbpid of + Just dgst -> do + identity <- readMVar serverIdentity_ + atomically $ runPacketHandler False peer $ do + wref <- newWaitingRef dgst $ handleIdentityAnnounce identity peer + readTVarP (peerIdentityVar peer) >>= \case + PeerIdentityUnknown idwait -> do + addHeader $ AnnounceSelf $ refDigest $ storedRef $ idData identity + writeTVarP (peerIdentityVar peer) $ PeerIdentityRef wref idwait + liftSTM $ writeTChan serverChanPeer peer + _ -> return () + Nothing -> return () + + forever $ do + (secure, TransportPacket header objs) <- readFlowIO $ connData conn + prefs <- forM objs $ storeObject $ peerInStorage peer + identity <- readMVar serverIdentity_ + let svcs = map someServiceID serverServices + handlePacket identity secure peer chanSvc svcs header prefs + + ReceivedAnnounce addr _ -> do + void $ serverPeer' server addr + + erebosNetworkProtocol (headLocalIdentity serverOrigHead) logd protocolRawPath protocolControlFlow + + forkServerThread server $ withSocketsDo $ do + let hints = defaultHints + { addrFlags = [AI_PASSIVE] + , addrFamily = AF_INET6 + , addrSocketType = Datagram + } + addr:_ <- getAddrInfo (Just hints) Nothing (Just $ show $ serverPort opt) + bracket (open addr) close loop + + forkServerThread server $ forever $ do + (peer, svc, ref) <- atomically $ readTQueue chanSvc + case find ((svc ==) . someServiceID) serverServices of + Just service@(SomeService (_ :: Proxy s) attr) -> runPeerServiceOn (Just (service, attr)) peer (serviceHandler $ wrappedLoad @s ref) + _ -> atomically $ logd $ "unhandled service '" ++ show (toUUID svc) ++ "'" + + return server + +stopServer :: Server -> IO () +stopServer Server {..} = do + mapM_ killThread =<< takeMVar serverThreads + +dataResponseWorker :: Server -> IO () +dataResponseWorker server = forever $ do + (peer, npref) <- atomically (readTQueue $ serverDataResponse server) + + wait <- atomically $ takeTMVar (peerWaitingRefs peer) + list <- forM wait $ \wr@WaitingRef { wrefStatus = tvar } -> + atomically (readTVar tvar) >>= \case + Left ds -> case maybe id (filter . (/=) . refDigest) npref $ ds of + [] -> copyRef (wrefStorage wr) (wrefPartial wr) >>= \case + Right ref -> do + atomically (writeTVar tvar $ Right ref) + forkServerThread server $ runExceptT (wrefAction wr ref) >>= \case + Left err -> atomically $ writeTQueue (serverErrorLog server) err + Right () -> return () + + return (Nothing, []) + Left dgst -> do + atomically (writeTVar tvar $ Left [dgst]) + return (Just wr, [dgst]) + ds' -> do + atomically (writeTVar tvar $ Left ds') + return (Just wr, []) + Right _ -> return (Nothing, []) + atomically $ putTMVar (peerWaitingRefs peer) $ catMaybes $ map fst list + + let reqs = concat $ map snd list + when (not $ null reqs) $ do + let packet = TransportPacket (TransportHeader $ map DataRequest reqs) [] + ackedBy = concat [[ Rejected r, DataResponse r ] | r <- reqs ] + atomically $ sendToPeerPlain peer ackedBy packet + + +newtype PacketHandler a = PacketHandler { unPacketHandler :: StateT PacketHandlerState (ExceptT String STM) a } + deriving (Functor, Applicative, Monad, MonadState PacketHandlerState, MonadError String) + +instance MonadFail PacketHandler where + fail = throwError + +runPacketHandler :: Bool -> Peer -> PacketHandler () -> STM () +runPacketHandler secure peer@Peer {..} act = do + let logd = writeTQueue $ serverErrorLog peerServer_ + runExceptT (flip execStateT (PacketHandlerState peer [] [] []) $ unPacketHandler act) >>= \case + Left err -> do + logd $ "Error in handling packet from " ++ show peerAddress ++ ": " ++ err + Right ph -> do + when (not $ null $ phHead ph) $ do + let packet = TransportPacket (TransportHeader $ phHead ph) (phBody ph) + sendToPeerS' secure peer (phAckedBy ph) packet + +liftSTM :: STM a -> PacketHandler a +liftSTM = PacketHandler . lift . lift + +readTVarP :: TVar a -> PacketHandler a +readTVarP = liftSTM . readTVar + +writeTVarP :: TVar a -> a -> PacketHandler () +writeTVarP v = liftSTM . writeTVar v + +modifyTMVarP :: TMVar a -> (a -> a) -> PacketHandler () +modifyTMVarP v f = liftSTM $ putTMVar v . f =<< takeTMVar v + +data PacketHandlerState = PacketHandlerState + { phPeer :: Peer + , phHead :: [TransportHeaderItem] + , phAckedBy :: [TransportHeaderItem] + , phBody :: [Ref] + } + +addHeader :: TransportHeaderItem -> PacketHandler () +addHeader h = modify $ \ph -> ph { phHead = h `appendDistinct` phHead ph } + +addAckedBy :: [TransportHeaderItem] -> PacketHandler () +addAckedBy hs = modify $ \ph -> ph { phAckedBy = foldr appendDistinct (phAckedBy ph) hs } + +addBody :: Ref -> PacketHandler () +addBody r = modify $ \ph -> ph { phBody = r `appendDistinct` phBody ph } + +appendDistinct :: Eq a => a -> [a] -> [a] +appendDistinct x (y:ys) | x == y = y : ys + | otherwise = y : appendDistinct x ys +appendDistinct x [] = [x] + +handlePacket :: UnifiedIdentity -> Bool + -> Peer -> TQueue (Peer, ServiceID, Ref) -> [ServiceID] + -> TransportHeader -> [PartialRef] -> IO () +handlePacket identity secure peer chanSvc svcs (TransportHeader headers) prefs = atomically $ do + let server = peerServer peer + ochannel <- getPeerChannel peer + let sidentity = idData identity + plaintextRefs = map (refDigest . storedRef) $ concatMap (collectStoredObjects . wrappedLoad) $ concat + [ [ storedRef sidentity ] + , map storedRef $ idUpdates identity + , case ochannel of + ChannelOurRequest _ req -> [ storedRef req ] + ChannelOurAccept _ acc _ -> [ storedRef acc ] + _ -> [] + ] + + runPacketHandler secure peer $ do + let logd = liftSTM . writeTQueue (serverErrorLog server) + forM_ headers $ \case + Acknowledged dgst -> do + liftSTM (getPeerChannel peer) >>= \case + ChannelOurAccept _ acc ch | refDigest (storedRef acc) == dgst -> do + liftSTM $ finalizedChannel peer ch identity + _ -> return () + + Rejected dgst -> do + logd $ "rejected by peer: " ++ show dgst + + DataRequest dgst + | secure || dgst `elem` plaintextRefs -> do + Right mref <- liftSTM $ unsafeIOToSTM $ + copyRef (peerStorage peer) $ + partialRefFromDigest (peerInStorage peer) dgst + addHeader $ DataResponse dgst + addAckedBy [ Acknowledged dgst, Rejected dgst ] + addBody $ mref + | otherwise -> do + logd $ "unauthorized data request for " ++ show dgst + addHeader $ Rejected dgst + + DataResponse dgst -> if + | Just pref <- find ((==dgst) . refDigest) prefs -> do + when (not secure) $ do + addHeader $ Acknowledged dgst + liftSTM $ writeTQueue (serverDataResponse server) (peer, Just pref) + | otherwise -> throwError $ "mismatched data response " ++ show dgst + + AnnounceSelf dgst + | dgst == refDigest (storedRef sidentity) -> return () + | otherwise -> do + wref <- newWaitingRef dgst $ handleIdentityAnnounce identity peer + readTVarP (peerIdentityVar peer) >>= \case + PeerIdentityUnknown idwait -> do + addHeader $ AnnounceSelf $ refDigest $ storedRef $ idData identity + writeTVarP (peerIdentityVar peer) $ PeerIdentityRef wref idwait + liftSTM $ writeTChan (serverChanPeer $ peerServer peer) peer + _ -> return () + + AnnounceUpdate dgst -> do + readTVarP (peerIdentityVar peer) >>= \case + PeerIdentityFull _ -> do + void $ newWaitingRef dgst $ handleIdentityUpdate peer + _ -> return () + + TrChannelRequest dgst -> do + let process cookie = do + addHeader $ Acknowledged dgst + wref <- newWaitingRef dgst $ handleChannelRequest peer identity + liftSTM $ setPeerChannel peer $ ChannelPeerRequest cookie wref + reject = addHeader $ Rejected dgst + + liftSTM (getPeerChannel peer) >>= \case + ChannelNone {} -> return () + ChannelCookieWait {} -> return () + ChannelCookieReceived cookie -> process $ Just cookie + ChannelCookieConfirmed cookie -> process $ Just cookie + ChannelOurRequest mbcookie our | dgst < refDigest (storedRef our) -> process mbcookie + | otherwise -> reject + ChannelPeerRequest mbcookie _ -> process mbcookie + ChannelOurAccept {} -> reject + ChannelEstablished {} -> process Nothing + + TrChannelAccept dgst -> do + let process = do + handleChannelAccept identity $ partialRefFromDigest (peerInStorage peer) dgst + reject = addHeader $ Rejected dgst + liftSTM (getPeerChannel peer) >>= \case + ChannelNone {} -> reject + ChannelCookieWait {} -> reject + ChannelCookieReceived {} -> reject + ChannelCookieConfirmed {} -> reject + ChannelOurRequest {} -> process + ChannelPeerRequest {} -> process + ChannelOurAccept _ our _ | dgst < refDigest (storedRef our) -> process + | otherwise -> addHeader $ Rejected dgst + ChannelEstablished {} -> process + + ServiceType _ -> return () + ServiceRef dgst + | not secure -> throwError $ "service packet without secure channel" + | Just svc <- lookupServiceType headers -> if + | svc `elem` svcs -> do + if dgst `elem` map refDigest prefs || True {- TODO: used by Message service to confirm receive -} + then do + void $ newWaitingRef dgst $ \ref -> + liftIO $ atomically $ writeTQueue chanSvc (peer, svc, ref) + else throwError $ "missing service object " ++ show dgst + | otherwise -> addHeader $ Rejected dgst + | otherwise -> throwError $ "service ref without type" + + _ -> return () + + +withPeerIdentity :: MonadIO m => Peer -> (UnifiedIdentity -> ExceptT String IO ()) -> m () +withPeerIdentity peer act = liftIO $ atomically $ readTVar (peerIdentityVar peer) >>= \case + PeerIdentityUnknown tvar -> modifyTVar' tvar (act:) + PeerIdentityRef _ tvar -> modifyTVar' tvar (act:) + PeerIdentityFull idt -> writeTQueue (serverIOActions $ peerServer peer) (act idt) + + +setupChannel :: UnifiedIdentity -> Peer -> UnifiedIdentity -> WaitingRefCallback +setupChannel identity peer upid = do + req <- flip runReaderT (peerStorage peer) $ createChannelRequest identity upid + let reqref = refDigest $ storedRef req + let hitems = + [ TrChannelRequest reqref + , AnnounceSelf $ refDigest $ storedRef $ idData identity + ] + liftIO $ atomically $ do + getPeerChannel peer >>= \case + ChannelCookieConfirmed cookie -> do + sendToPeerPlain peer [ Acknowledged reqref, Rejected reqref ] $ + TransportPacket (TransportHeader hitems) [storedRef req] + setPeerChannel peer $ ChannelOurRequest (Just cookie) req + _ -> return () + +handleChannelRequest :: Peer -> UnifiedIdentity -> Ref -> WaitingRefCallback +handleChannelRequest peer identity req = do + withPeerIdentity peer $ \upid -> do + (acc, ch) <- flip runReaderT (peerStorage peer) $ acceptChannelRequest identity upid (wrappedLoad req) + liftIO $ atomically $ do + getPeerChannel peer >>= \case + ChannelPeerRequest mbcookie wr | wrDigest wr == refDigest req -> do + setPeerChannel peer $ ChannelOurAccept mbcookie acc ch + let accref = refDigest $ storedRef acc + header = TrChannelAccept accref + ackedBy = [ Acknowledged accref, Rejected accref ] + sendToPeerPlain peer ackedBy $ TransportPacket (TransportHeader [header]) $ concat + [ [ storedRef $ acc ] + , [ storedRef $ signedData $ fromStored acc ] + , [ storedRef $ caKey $ fromStored $ signedData $ fromStored acc ] + , map storedRef $ signedSignature $ fromStored acc + ] + _ -> writeTQueue (serverErrorLog $ peerServer peer) $ "unexpected channel request" + +handleChannelAccept :: UnifiedIdentity -> PartialRef -> PacketHandler () +handleChannelAccept identity accref = do + peer <- gets phPeer + liftSTM $ writeTQueue (serverIOActions $ peerServer peer) $ do + withPeerIdentity peer $ \upid -> do + copyRef (peerStorage peer) accref >>= \case + Right acc -> do + ch <- acceptedChannel identity upid (wrappedLoad acc) + liftIO $ atomically $ do + sendToPeerS peer [] $ TransportPacket (TransportHeader [Acknowledged $ refDigest accref]) [] + finalizedChannel peer ch identity + + Left dgst -> throwError $ "missing accept data " ++ BC.unpack (showRefDigest dgst) + + +finalizedChannel :: Peer -> Channel -> UnifiedIdentity -> STM () +finalizedChannel peer@Peer {..} ch self = do + setPeerChannel peer $ ChannelEstablished ch + + -- Identity update + writeTQueue (serverIOActions peerServer_) $ liftIO $ atomically $ do + let selfRef = refDigest $ storedRef $ idExtData $ self + updateRefs = selfRef : map (refDigest . storedRef) (idUpdates self) + ackedBy = concat [[ Acknowledged r, Rejected r, DataRequest r ] | r <- updateRefs ] + sendToPeerS peer ackedBy $ flip TransportPacket [] $ TransportHeader $ map AnnounceUpdate updateRefs + + -- Notify services about new peer + readTVar peerIdentityVar >>= \case + PeerIdentityFull _ -> notifyServicesOfPeer peer + _ -> return () + + +handleIdentityAnnounce :: UnifiedIdentity -> Peer -> Ref -> WaitingRefCallback +handleIdentityAnnounce self peer ref = liftIO $ atomically $ do + let validateAndUpdate upds act = case validateIdentity $ wrappedLoad ref of + Just pid' -> do + let pid = fromMaybe pid' $ toUnifiedIdentity (updateIdentity upds pid') + writeTVar (peerIdentityVar peer) $ PeerIdentityFull pid + writeTChan (serverChanPeer $ peerServer peer) peer + act pid + writeTQueue (serverIOActions $ peerServer peer) $ do + setupChannel self peer pid + Nothing -> return () + + readTVar (peerIdentityVar peer) >>= \case + PeerIdentityRef wref wact + | wrDigest wref == refDigest ref + -> validateAndUpdate [] $ \pid -> do + mapM_ (writeTQueue (serverIOActions $ peerServer peer) . ($ pid)) . + reverse =<< readTVar wact + + PeerIdentityFull pid + | idData pid `precedes` wrappedLoad ref + -> validateAndUpdate (idUpdates pid) $ \_ -> do + notifyServicesOfPeer peer + + _ -> return () + +handleIdentityUpdate :: Peer -> Ref -> WaitingRefCallback +handleIdentityUpdate peer ref = liftIO $ atomically $ do + pidentity <- readTVar (peerIdentityVar peer) + if | PeerIdentityFull pid <- pidentity + , Just pid' <- toUnifiedIdentity $ updateIdentity [wrappedLoad ref] pid + -> do + writeTVar (peerIdentityVar peer) $ PeerIdentityFull pid' + writeTChan (serverChanPeer $ peerServer peer) peer + when (idData pid /= idData pid') $ notifyServicesOfPeer peer + + | otherwise -> return () + +notifyServicesOfPeer :: Peer -> STM () +notifyServicesOfPeer peer@Peer { peerServer_ = Server {..} } = do + writeTQueue serverIOActions $ do + forM_ serverServices $ \service@(SomeService _ attrs) -> + runPeerServiceOn (Just (service, attrs)) peer serviceNewPeer + + +mkPeer :: Server -> PeerAddress -> IO Peer +mkPeer peerServer_ peerAddress = do + peerConnection <- newTVarIO (Left []) + peerIdentityVar <- newTVarIO . PeerIdentityUnknown =<< newTVarIO [] + peerStorage_ <- deriveEphemeralStorage $ serverStorage peerServer_ + peerInStorage <- derivePartialStorage peerStorage_ + peerServiceState <- newTMVarIO M.empty + peerWaitingRefs <- newTMVarIO [] + return Peer {..} + +serverPeer :: Server -> SockAddr -> IO Peer +serverPeer server paddr = do + serverPeer' server (DatagramAddress paddr) + +#ifdef ENABLE_ICE_SUPPORT +serverPeerIce :: Server -> IceSession -> IO Peer +serverPeerIce server@Server {..} ice = do + let paddr = PeerIceSession ice + peer <- serverPeer' server paddr + iceSetChan ice $ mapFlow undefined (paddr,) serverRawPath + return peer +#endif + +serverPeer' :: Server -> PeerAddress -> IO Peer +serverPeer' server paddr = do + (peer, hello) <- modifyMVar (serverPeers server) $ \pvalue -> do + case M.lookup paddr pvalue of + Just peer -> return (pvalue, (peer, False)) + Nothing -> do + peer <- mkPeer server paddr + return (M.insert paddr peer pvalue, (peer, True)) + when hello $ atomically $ do + writeFlow (serverControlFlow server) (RequestConnection paddr) + return peer + + +sendToPeer :: (Service s, MonadIO m) => Peer -> s -> m () +sendToPeer peer packet = sendToPeerList peer [ServiceReply (Left packet) True] + +sendToPeerStored :: (Service s, MonadIO m) => Peer -> Stored s -> m () +sendToPeerStored peer spacket = sendToPeerList peer [ServiceReply (Right spacket) True] + +sendToPeerList :: (Service s, MonadIO m) => Peer -> [ServiceReply s] -> m () +sendToPeerList peer parts = do + let st = peerStorage peer + srefs <- liftIO $ fmap catMaybes $ forM parts $ \case + ServiceReply (Left x) use -> Just . (,use) <$> store st x + ServiceReply (Right sx) use -> return $ Just (storedRef sx, use) + ServiceFinally act -> act >> return Nothing + let dgsts = map (refDigest . fst) srefs + let content = map fst $ filter snd srefs + header = TransportHeader (ServiceType (serviceID $ head parts) : map ServiceRef dgsts) + packet = TransportPacket header content + ackedBy = concat [[ Acknowledged r, Rejected r, DataRequest r ] | r <- dgsts ] + liftIO $ atomically $ sendToPeerS peer ackedBy packet + +sendToPeerS' :: Bool -> Peer -> [TransportHeaderItem] -> TransportPacket Ref -> STM () +sendToPeerS' secure Peer {..} ackedBy packet = do + readTVar peerConnection >>= \case + Left xs -> writeTVar peerConnection $ Left $ (secure, packet, ackedBy) : xs + Right conn -> writeFlow (connData conn) (secure, packet, ackedBy) + +sendToPeerS :: Peer -> [TransportHeaderItem] -> TransportPacket Ref -> STM () +sendToPeerS = sendToPeerS' True + +sendToPeerPlain :: Peer -> [TransportHeaderItem] -> TransportPacket Ref -> STM () +sendToPeerPlain = sendToPeerS' False + +sendToPeerWith :: forall s m. (Service s, MonadIO m, MonadError String m) => Peer -> (ServiceState s -> ExceptT String IO (Maybe s, ServiceState s)) -> m () +sendToPeerWith peer fobj = do + let sproxy = Proxy @s + sid = serviceID sproxy + res <- liftIO $ do + svcs <- atomically $ takeTMVar (peerServiceState peer) + (svcs', res) <- runExceptT (fobj $ fromMaybe (emptyServiceState sproxy) $ fromServiceState sproxy =<< M.lookup sid svcs) >>= \case + Right (obj, s') -> return $ (M.insert sid (SomeServiceState sproxy s') svcs, Right obj) + Left err -> return $ (svcs, Left err) + atomically $ putTMVar (peerServiceState peer) svcs' + return res + + case res of + Right (Just obj) -> sendToPeer peer obj + Right Nothing -> return () + Left err -> throwError err + + +lookupService :: forall s. Service s => Proxy s -> [SomeService] -> Maybe (SomeService, ServiceAttributes s) +lookupService proxy (service@(SomeService (_ :: Proxy t) attr) : rest) + | Just (Refl :: s :~: t) <- eqT = Just (service, attr) + | otherwise = lookupService proxy rest +lookupService _ [] = Nothing + +runPeerService :: forall s m. (Service s, MonadIO m) => Peer -> ServiceHandler s () -> m () +runPeerService = runPeerServiceOn Nothing + +runPeerServiceOn :: forall s m. (Service s, MonadIO m) => Maybe (SomeService, ServiceAttributes s) -> Peer -> ServiceHandler s () -> m () +runPeerServiceOn mbservice peer handler = liftIO $ do + let server = peerServer peer + proxy = Proxy @s + svc = serviceID proxy + logd = writeTQueue (serverErrorLog server) + case mbservice `mplus` lookupService proxy (serverServices server) of + Just (service, attr) -> + atomically (readTVar (peerIdentityVar peer)) >>= \case + PeerIdentityFull peerId -> do + (global, svcs) <- atomically $ (,) + <$> takeTMVar (serverServiceStates server) + <*> takeTMVar (peerServiceState peer) + case (fromMaybe (someServiceEmptyState service) $ M.lookup svc svcs, + fromMaybe (someServiceEmptyGlobalState service) $ M.lookup svc global) of + ((SomeServiceState (_ :: Proxy ps) ps), + (SomeServiceGlobalState (_ :: Proxy gs) gs)) -> do + Just (Refl :: s :~: ps) <- return $ eqT + Just (Refl :: s :~: gs) <- return $ eqT + + let inp = ServiceInput + { svcAttributes = attr + , svcPeer = peer + , svcPeerIdentity = peerId + , svcServer = server + , svcPrintOp = atomically . logd + } + reloadHead (serverOrigHead server) >>= \case + Nothing -> atomically $ do + logd $ "current head deleted" + putTMVar (peerServiceState peer) svcs + putTMVar (serverServiceStates server) global + Just h -> do + (rsp, (s', gs')) <- runServiceHandler h inp ps gs handler + moveKeys (peerStorage peer) (serverStorage server) + when (not (null rsp)) $ do + sendToPeerList peer rsp + atomically $ do + putTMVar (peerServiceState peer) $ M.insert svc (SomeServiceState proxy s') svcs + putTMVar (serverServiceStates server) $ M.insert svc (SomeServiceGlobalState proxy gs') global + _ -> do + atomically $ logd $ "can't run service handler on peer with incomplete identity " ++ show (peerAddress peer) + + _ -> atomically $ do + logd $ "unhandled service '" ++ show (toUUID svc) ++ "'" + + +foreign import ccall unsafe "Network/ifaddrs.h broadcast_addresses" cBroadcastAddresses :: IO (Ptr Word32) +foreign import ccall unsafe "stdlib.h free" cFree :: Ptr Word32 -> IO () + +getBroadcastAddresses :: PortNumber -> IO [SockAddr] +getBroadcastAddresses port = do + ptr <- cBroadcastAddresses + let parse i = do + w <- peekElemOff ptr i + if w == 0 then return [] + else (SockAddrInet port w:) <$> parse (i + 1) + addrs <- parse 0 + cFree ptr + return addrs diff --git a/src/Erebos/Network.hs-boot b/src/Erebos/Network.hs-boot new file mode 100644 index 0000000..849bfc1 --- /dev/null +++ b/src/Erebos/Network.hs-boot @@ -0,0 +1,8 @@ +module Erebos.Network where + +import Erebos.Storage + +data Server +data Peer + +peerStorage :: Peer -> Storage diff --git a/src/Erebos/Network/Protocol.hs b/src/Erebos/Network/Protocol.hs new file mode 100644 index 0000000..d7253e3 --- /dev/null +++ b/src/Erebos/Network/Protocol.hs @@ -0,0 +1,589 @@ +module Erebos.Network.Protocol ( + TransportPacket(..), + transportToObject, + TransportHeader(..), + TransportHeaderItem(..), + + WaitingRef(..), + WaitingRefCallback, + wrDigest, + + ChannelState(..), + + ControlRequest(..), + ControlMessage(..), + erebosNetworkProtocol, + + Connection, + connAddress, + connData, + connGetChannel, + connSetChannel, + + module Erebos.Flow, +) where + +import Control.Applicative +import Control.Concurrent +import Control.Concurrent.Async +import Control.Concurrent.STM +import Control.Monad +import Control.Monad.Except + +import Data.Bits +import Data.ByteString (ByteString) +import Data.ByteString qualified as B +import Data.ByteString.Char8 qualified as BC +import Data.ByteString.Lazy qualified as BL +import Data.List +import Data.Maybe +import Data.Text (Text) +import Data.Text qualified as T + +import System.Clock + +import Erebos.Channel +import Erebos.Flow +import Erebos.Identity +import Erebos.Service +import Erebos.Storage + + +protocolVersion :: Text +protocolVersion = T.pack "0.1" + +protocolVersions :: [Text] +protocolVersions = [protocolVersion] + + +data TransportPacket a = TransportPacket TransportHeader [a] + +data TransportHeader = TransportHeader [TransportHeaderItem] + deriving (Show) + +data TransportHeaderItem + = Acknowledged RefDigest + | AcknowledgedSingle Integer + | Rejected RefDigest + | ProtocolVersion Text + | Initiation RefDigest + | CookieSet Cookie + | CookieEcho Cookie + | DataRequest RefDigest + | DataResponse RefDigest + | AnnounceSelf RefDigest + | AnnounceUpdate RefDigest + | TrChannelRequest RefDigest + | TrChannelAccept RefDigest + | ServiceType ServiceID + | ServiceRef RefDigest + deriving (Eq, Show) + +newtype Cookie = Cookie ByteString + deriving (Eq, Show) + +isHeaderItemAcknowledged :: TransportHeaderItem -> Bool +isHeaderItemAcknowledged = \case + Acknowledged {} -> False + AcknowledgedSingle {} -> False + Rejected {} -> False + ProtocolVersion {} -> False + Initiation {} -> False + CookieSet {} -> False + CookieEcho {} -> False + _ -> True + +transportToObject :: PartialStorage -> TransportHeader -> PartialObject +transportToObject st (TransportHeader items) = Rec $ map single items + where single = \case + Acknowledged dgst -> (BC.pack "ACK", RecRef $ partialRefFromDigest st dgst) + AcknowledgedSingle num -> (BC.pack "ACK", RecInt num) + Rejected dgst -> (BC.pack "REJ", RecRef $ partialRefFromDigest st dgst) + ProtocolVersion ver -> (BC.pack "VER", RecText ver) + Initiation dgst -> (BC.pack "INI", RecRef $ partialRefFromDigest st dgst) + CookieSet (Cookie bytes) -> (BC.pack "CKS", RecBinary bytes) + CookieEcho (Cookie bytes) -> (BC.pack "CKE", RecBinary bytes) + DataRequest dgst -> (BC.pack "REQ", RecRef $ partialRefFromDigest st dgst) + DataResponse dgst -> (BC.pack "RSP", RecRef $ partialRefFromDigest st dgst) + AnnounceSelf dgst -> (BC.pack "ANN", RecRef $ partialRefFromDigest st dgst) + AnnounceUpdate dgst -> (BC.pack "ANU", RecRef $ partialRefFromDigest st dgst) + TrChannelRequest dgst -> (BC.pack "CRQ", RecRef $ partialRefFromDigest st dgst) + TrChannelAccept dgst -> (BC.pack "CAC", RecRef $ partialRefFromDigest st dgst) + ServiceType stype -> (BC.pack "SVT", RecUUID $ toUUID stype) + ServiceRef dgst -> (BC.pack "SVR", RecRef $ partialRefFromDigest st dgst) + +transportFromObject :: PartialObject -> Maybe TransportHeader +transportFromObject (Rec items) = case catMaybes $ map single items of + [] -> Nothing + titems -> Just $ TransportHeader titems + where single (name, content) = if + | name == BC.pack "ACK", RecRef ref <- content -> Just $ Acknowledged $ refDigest ref + | name == BC.pack "ACK", RecInt num <- content -> Just $ AcknowledgedSingle num + | name == BC.pack "REJ", RecRef ref <- content -> Just $ Rejected $ refDigest ref + | name == BC.pack "VER", RecText ver <- content -> Just $ ProtocolVersion ver + | name == BC.pack "INI", RecRef ref <- content -> Just $ Initiation $ refDigest ref + | name == BC.pack "CKS", RecBinary bytes <- content -> Just $ CookieSet (Cookie bytes) + | name == BC.pack "CKE", RecBinary bytes <- content -> Just $ CookieEcho (Cookie bytes) + | name == BC.pack "REQ", RecRef ref <- content -> Just $ DataRequest $ refDigest ref + | name == BC.pack "RSP", RecRef ref <- content -> Just $ DataResponse $ refDigest ref + | name == BC.pack "ANN", RecRef ref <- content -> Just $ AnnounceSelf $ refDigest ref + | name == BC.pack "ANU", RecRef ref <- content -> Just $ AnnounceUpdate $ refDigest ref + | name == BC.pack "CRQ", RecRef ref <- content -> Just $ TrChannelRequest $ refDigest ref + | name == BC.pack "CAC", RecRef ref <- content -> Just $ TrChannelAccept $ refDigest ref + | name == BC.pack "SVT", RecUUID uuid <- content -> Just $ ServiceType $ fromUUID uuid + | name == BC.pack "SVR", RecRef ref <- content -> Just $ ServiceRef $ refDigest ref + | otherwise -> Nothing +transportFromObject _ = Nothing + + +data GlobalState addr = (Eq addr, Show addr) => GlobalState + { gIdentity :: TVar (UnifiedIdentity, [UnifiedIdentity]) + , gConnections :: TVar [Connection addr] + , gDataFlow :: SymFlow (addr, ByteString) + , gControlFlow :: Flow (ControlRequest addr) (ControlMessage addr) + , gNextUp :: TMVar (Connection addr, (Bool, TransportPacket PartialObject)) + , gLog :: String -> STM () + , gStorage :: PartialStorage + , gNowVar :: TVar TimeSpec + , gNextTimeout :: TVar TimeSpec + , gInitConfig :: Ref + } + +data Connection addr = Connection + { cAddress :: addr + , cDataUp :: Flow (Bool, TransportPacket PartialObject) (Bool, TransportPacket Ref, [TransportHeaderItem]) + , cDataInternal :: Flow (Bool, TransportPacket Ref, [TransportHeaderItem]) (Bool, TransportPacket PartialObject) + , cChannel :: TVar ChannelState + , cSecureOutQueue :: TQueue (Bool, TransportPacket Ref, [TransportHeaderItem]) + , cSentPackets :: TVar [SentPacket] + , cToAcknowledge :: TVar [Integer] + } + +connAddress :: Connection addr -> addr +connAddress = cAddress + +connData :: Connection addr -> Flow (Bool, TransportPacket PartialObject) (Bool, TransportPacket Ref, [TransportHeaderItem]) +connData = cDataUp + +connGetChannel :: Connection addr -> STM ChannelState +connGetChannel Connection {..} = readTVar cChannel + +connSetChannel :: Connection addr -> ChannelState -> STM () +connSetChannel Connection {..} ch = do + writeTVar cChannel ch + + +data WaitingRef = WaitingRef + { wrefStorage :: Storage + , wrefPartial :: PartialRef + , wrefAction :: Ref -> WaitingRefCallback + , wrefStatus :: TVar (Either [RefDigest] Ref) + } + +type WaitingRefCallback = ExceptT String IO () + +wrDigest :: WaitingRef -> RefDigest +wrDigest = refDigest . wrefPartial + + +data ChannelState = ChannelNone + | ChannelCookieWait -- sent initiation, waiting for response + | ChannelCookieReceived Cookie -- received cookie, but no cookie echo yet + | ChannelCookieConfirmed Cookie -- received cookie echo, no need to send from our side + | ChannelOurRequest (Maybe Cookie) (Stored ChannelRequest) + | ChannelPeerRequest (Maybe Cookie) WaitingRef + | ChannelOurAccept (Maybe Cookie) (Stored ChannelAccept) Channel + | ChannelEstablished Channel + + +data SentPacket = SentPacket + { spTime :: TimeSpec + , spRetryCount :: Int + , spAckedBy :: Maybe (TransportHeaderItem -> Bool) + , spData :: BC.ByteString + } + + +data ControlRequest addr = RequestConnection addr + | SendAnnounce addr + | UpdateSelfIdentity UnifiedIdentity + +data ControlMessage addr = NewConnection (Connection addr) (Maybe RefDigest) + | ReceivedAnnounce addr RefDigest + + +erebosNetworkProtocol :: (Eq addr, Ord addr, Show addr) + => UnifiedIdentity + -> (String -> STM ()) + -> SymFlow (addr, ByteString) + -> Flow (ControlRequest addr) (ControlMessage addr) + -> IO () +erebosNetworkProtocol initialIdentity gLog gDataFlow gControlFlow = do + gIdentity <- newTVarIO (initialIdentity, []) + gConnections <- newTVarIO [] + gNextUp <- newEmptyTMVarIO + mStorage <- memoryStorage + gStorage <- derivePartialStorage mStorage + + startTime <- getTime MonotonicRaw + gNowVar <- newTVarIO startTime + gNextTimeout <- newTVarIO startTime + gInitConfig <- store mStorage $ (Rec [] :: Object) + + let gs = GlobalState {..} + + let signalTimeouts = forever $ do + now <- getTime MonotonicRaw + next <- atomically $ do + writeTVar gNowVar now + readTVar gNextTimeout + + let waitTill time + | time > now = threadDelay $ fromInteger (toNanoSecs (time - now)) `div` 1000 + | otherwise = threadDelay maxBound + waitForUpdate = atomically $ do + next' <- readTVar gNextTimeout + when (next' == next) retry + + race_ (waitTill next) waitForUpdate + + race_ signalTimeouts $ forever $ join $ atomically $ + passUpIncoming gs <|> processIncoming gs <|> processOutgoing gs + + +getConnection :: GlobalState addr -> addr -> STM (Connection addr) +getConnection gs addr = do + maybe (newConnection gs addr) return =<< findConnection gs addr + +findConnection :: GlobalState addr -> addr -> STM (Maybe (Connection addr)) +findConnection GlobalState {..} addr = do + find ((addr==) . cAddress) <$> readTVar gConnections + +newConnection :: GlobalState addr -> addr -> STM (Connection addr) +newConnection GlobalState {..} addr = do + conns <- readTVar gConnections + + let cAddress = addr + (cDataUp, cDataInternal) <- newFlow + cChannel <- newTVar ChannelNone + cSecureOutQueue <- newTQueue + cSentPackets <- newTVar [] + cToAcknowledge <- newTVar [] + let conn = Connection {..} + + writeTVar gConnections (conn : conns) + return conn + +passUpIncoming :: GlobalState addr -> STM (IO ()) +passUpIncoming GlobalState {..} = do + (Connection {..}, up) <- takeTMVar gNextUp + writeFlow cDataInternal up + return $ return () + +processIncoming :: GlobalState addr -> STM (IO ()) +processIncoming gs@GlobalState {..} = do + guard =<< isEmptyTMVar gNextUp + guard =<< canWriteFlow gControlFlow + + (addr, msg) <- readFlow gDataFlow + mbconn <- findConnection gs addr + + mbch <- case mbconn of + Nothing -> return Nothing + Just conn -> readTVar (cChannel conn) >>= return . \case + ChannelEstablished ch -> Just ch + ChannelOurAccept _ _ ch -> Just ch + _ -> Nothing + + return $ do + let deserialize = liftEither . runExcept . deserializeObjects gStorage . BL.fromStrict + let parse = case B.uncons msg of + Just (b, enc) + | b .&. 0xE0 == 0x80 -> do + ch <- maybe (throwError "unexpected encrypted packet") return mbch + (dec, counter) <- channelDecrypt ch enc + + case B.uncons dec of + Just (0x00, content) -> do + objs <- deserialize content + return (True, objs, Just counter) + + Just (_, _) -> do + throwError "streams not implemented" + + Nothing -> do + throwError "empty decrypted content" + + | b .&. 0xE0 == 0x60 -> do + objs <- deserialize msg + return (False, objs, Nothing) + + | otherwise -> throwError "invalid packet" + + Nothing -> throwError "empty packet" + + runExceptT parse >>= \case + Right (secure, objs, mbcounter) + | hobj:content <- objs + , Just header@(TransportHeader items) <- transportFromObject hobj + -> processPacket gs (maybe (Left addr) Right mbconn) secure (TransportPacket header content) >>= \case + Just (conn@Connection {..}, mbup) -> atomically $ do + case mbcounter of + Just counter | any isHeaderItemAcknowledged items -> + modifyTVar' cToAcknowledge (fromIntegral counter :) + _ -> return () + processAcknowledgements gs conn items + case mbup of + Just up -> putTMVar gNextUp (conn, (secure, up)) + Nothing -> return () + Nothing -> return () + + | otherwise -> atomically $ do + gLog $ show addr ++ ": invalid objects" + gLog $ show objs + + Left err -> do + atomically $ gLog $ show addr <> ": failed to parse packet: " <> err + +processPacket :: GlobalState addr -> Either addr (Connection addr) -> Bool -> TransportPacket a -> IO (Maybe (Connection addr, Maybe (TransportPacket a))) +processPacket gs@GlobalState {..} econn secure packet@(TransportPacket (TransportHeader header) _) = if + -- Established secure communication + | Right conn <- econn, secure + -> return $ Just (conn, Just packet) + + -- Plaintext communication with cookies to prove origin + | cookie:_ <- mapMaybe (\case CookieEcho x -> Just x; _ -> Nothing) header + -> verifyCookie gs addr cookie >>= \case + True -> do + atomically $ do + conn@Connection {..} <- getConnection gs addr + channel <- readTVar cChannel + let received = listToMaybe $ mapMaybe (\case CookieSet x -> Just x; _ -> Nothing) header + case received `mplus` channelCurrentCookie channel of + Just current -> do + cookieEchoReceived gs conn mbpid current + return $ Just (conn, Just packet) + Nothing -> do + gLog $ show addr <> ": missing cookie set, dropping " <> show header + return $ Nothing + + False -> do + atomically $ gLog $ show addr <> ": cookie verification failed, dropping " <> show header + return Nothing + + -- Response to initiation packet + | cookie:_ <- mapMaybe (\case CookieSet x -> Just x; _ -> Nothing) header + , Just _ <- version + , Right conn@Connection {..} <- econn + -> do + atomically $ readTVar cChannel >>= \case + ChannelCookieWait -> do + writeTVar cChannel $ ChannelCookieReceived cookie + writeFlow gControlFlow (NewConnection conn mbpid) + return $ Just (conn, Nothing) + _ -> return Nothing + + -- Initiation packet + | _:_ <- mapMaybe (\case Initiation x -> Just x; _ -> Nothing) header + , Just ver <- version + -> do + cookie <- createCookie gs addr + atomically $ do + identity <- fst <$> readTVar gIdentity + let reply = BL.toStrict $ serializeObject $ transportToObject gStorage $ TransportHeader + [ CookieSet cookie + , AnnounceSelf $ refDigest $ storedRef $ idData identity + , ProtocolVersion ver + ] + writeFlow gDataFlow (addr, reply) + return Nothing + + -- Announce packet outside any connection + | dgst:_ <- mapMaybe (\case AnnounceSelf x -> Just x; _ -> Nothing) header + , Just _ <- version + -> do + atomically $ do + (cur, past) <- readTVar gIdentity + when (not $ dgst `elem` map (refDigest . storedRef . idData) (cur : past)) $ do + writeFlow gControlFlow $ ReceivedAnnounce addr dgst + return Nothing + + | otherwise -> do + atomically $ gLog $ show addr <> ": dropping packet " <> show header + return Nothing + + where + addr = either id cAddress econn + mbpid = listToMaybe $ mapMaybe (\case AnnounceSelf dgst -> Just dgst; _ -> Nothing) header + version = listToMaybe $ filter (\v -> ProtocolVersion v `elem` header) protocolVersions + +channelCurrentCookie :: ChannelState -> Maybe Cookie +channelCurrentCookie = \case + ChannelCookieReceived cookie -> Just cookie + ChannelCookieConfirmed cookie -> Just cookie + ChannelOurRequest mbcookie _ -> mbcookie + ChannelPeerRequest mbcookie _ -> mbcookie + ChannelOurAccept mbcookie _ _ -> mbcookie + _ -> Nothing + +cookieEchoReceived :: GlobalState addr -> Connection addr -> Maybe RefDigest -> Cookie -> STM () +cookieEchoReceived GlobalState {..} conn@Connection {..} mbpid cookieSet = do + readTVar cChannel >>= \case + ChannelNone -> newConn + ChannelCookieWait -> newConn + ChannelCookieReceived {} -> update + _ -> return () + where + update = do + writeTVar cChannel $ ChannelCookieConfirmed cookieSet + newConn = do + update + writeFlow gControlFlow (NewConnection conn mbpid) + +generateCookieHeaders :: GlobalState addr -> addr -> ChannelState -> IO [TransportHeaderItem] +generateCookieHeaders gs addr ch = catMaybes <$> sequence [ echoHeader, setHeader ] + where + echoHeader = return $ CookieEcho <$> channelCurrentCookie ch + setHeader = case ch of + ChannelCookieWait {} -> Just . CookieSet <$> createCookie gs addr + ChannelCookieReceived {} -> Just . CookieSet <$> createCookie gs addr + _ -> return Nothing + +createCookie :: GlobalState addr -> addr -> IO Cookie +createCookie GlobalState {} addr = return (Cookie $ BC.pack $ show addr) + +verifyCookie :: GlobalState addr -> addr -> Cookie -> IO Bool +verifyCookie GlobalState {} addr (Cookie cookie) = return $ show addr == BC.unpack cookie + +resendBytes :: GlobalState addr -> Connection addr -> SentPacket -> IO () +resendBytes GlobalState {..} Connection {..} sp = do + now <- getTime MonotonicRaw + atomically $ do + when (isJust $ spAckedBy sp) $ do + modifyTVar' cSentPackets $ (:) sp + { spTime = now + , spRetryCount = spRetryCount sp + 1 + } + writeFlow gDataFlow (cAddress, spData sp) + +sendBytes :: GlobalState addr -> Connection addr -> ByteString -> Maybe (TransportHeaderItem -> Bool) -> IO () +sendBytes gs conn bs ackedBy = resendBytes gs conn + SentPacket + { spTime = undefined + , spRetryCount = -1 + , spAckedBy = ackedBy + , spData = bs + } + +processOutgoing :: forall addr. GlobalState addr -> STM (IO ()) +processOutgoing gs@GlobalState {..} = do + + let sendNextPacket :: Connection addr -> STM (IO ()) + sendNextPacket conn@Connection {..} = do + channel <- readTVar cChannel + let mbch = case channel of + ChannelEstablished ch -> Just ch + _ -> Nothing + + let checkOutstanding + | isJust mbch = readTQueue cSecureOutQueue + | otherwise = retry + + checkAcknowledgements + | isJust mbch = do + acks <- readTVar cToAcknowledge + if null acks then retry + else return (True, TransportPacket (TransportHeader []) [], []) + | otherwise = retry + + (secure, packet@(TransportPacket (TransportHeader hitems) content), plainAckedBy) <- + checkOutstanding <|> readFlow cDataInternal <|> checkAcknowledgements + + when (isNothing mbch && secure) $ do + writeTQueue cSecureOutQueue (secure, packet, plainAckedBy) + + acknowledge <- case mbch of + Nothing -> return [] + Just _ -> swapTVar cToAcknowledge [] + + return $ do + cookieHeaders <- generateCookieHeaders gs cAddress channel + let header = TransportHeader $ map AcknowledgedSingle acknowledge ++ cookieHeaders ++ hitems + + let plain = BL.concat $ + (serializeObject $ transportToObject gStorage header) + : map lazyLoadBytes content + + mbs <- case mbch of + Just ch -> do + runExceptT (channelEncrypt ch $ BL.toStrict $ 0x00 `BL.cons` plain) >>= \case + Right (ctext, counter) -> do + let isAcked = any isHeaderItemAcknowledged hitems + return $ Just (0x80 `B.cons` ctext, if isAcked then [ AcknowledgedSingle $ fromIntegral counter ] else []) + Left err -> do atomically $ gLog $ "Failed to encrypt data: " ++ err + return Nothing + Nothing | secure -> return Nothing + | otherwise -> return $ Just (BL.toStrict plain, plainAckedBy) + + case mbs of + Just (bs, ackedBy) -> sendBytes gs conn bs $ guard (not $ null ackedBy) >> Just (`elem` ackedBy) + Nothing -> return () + + let retransmitPacket :: Connection addr -> STM (IO ()) + retransmitPacket conn@Connection {..} = do + now <- readTVar gNowVar + (sp, rest) <- readTVar cSentPackets >>= \case + sps@(_:_) -> return (last sps, init sps) + _ -> retry + let nextTry = spTime sp + fromNanoSecs 1000000000 + if now < nextTry + then do + nextTimeout <- readTVar gNextTimeout + if nextTimeout <= now || nextTry < nextTimeout + then do writeTVar gNextTimeout nextTry + return $ return () + else retry + else do + writeTVar cSentPackets rest + return $ resendBytes gs conn sp + + let handleControlRequests = readFlow gControlFlow >>= \case + RequestConnection addr -> do + conn@Connection {..} <- getConnection gs addr + identity <- fst <$> readTVar gIdentity + readTVar cChannel >>= \case + ChannelNone -> do + let packet = BL.toStrict $ BL.concat + [ serializeObject $ transportToObject gStorage $ TransportHeader $ + [ Initiation $ refDigest gInitConfig + , AnnounceSelf $ refDigest $ storedRef $ idData identity + ] ++ map ProtocolVersion protocolVersions + , lazyLoadBytes gInitConfig + ] + writeTVar cChannel ChannelCookieWait + return $ sendBytes gs conn packet $ Just $ \case CookieSet {} -> True; _ -> False + _ -> return $ return () + + SendAnnounce addr -> do + identity <- fst <$> readTVar gIdentity + let packet = BL.toStrict $ serializeObject $ transportToObject gStorage $ TransportHeader $ + [ AnnounceSelf $ refDigest $ storedRef $ idData identity + ] ++ map ProtocolVersion protocolVersions + writeFlow gDataFlow (addr, packet) + return $ return () + + UpdateSelfIdentity nid -> do + (cur, past) <- readTVar gIdentity + writeTVar gIdentity (nid, cur : past) + return $ return () + + conns <- readTVar gConnections + msum $ concat $ + [ map retransmitPacket conns + , map sendNextPacket conns + , [ handleControlRequests ] + ] + +processAcknowledgements :: GlobalState addr -> Connection addr -> [TransportHeaderItem] -> STM () +processAcknowledgements GlobalState {} Connection {..} = mapM_ $ \hitem -> do + modifyTVar' cSentPackets $ filter $ \sp -> not (fromJust (spAckedBy sp) hitem) diff --git a/src/Erebos/Network/ifaddrs.c b/src/Erebos/Network/ifaddrs.c new file mode 100644 index 0000000..37c3e00 --- /dev/null +++ b/src/Erebos/Network/ifaddrs.c @@ -0,0 +1,41 @@ +#include "ifaddrs.h" + +#include <arpa/inet.h> +#include <ifaddrs.h> +#include <net/if.h> +#include <stdlib.h> +#include <sys/types.h> +#include <endian.h> + +uint32_t * broadcast_addresses(void) +{ + struct ifaddrs * addrs; + if (getifaddrs(&addrs) < 0) + return 0; + + size_t capacity = 16, count = 0; + uint32_t * ret = malloc(sizeof(uint32_t) * capacity); + + for (struct ifaddrs * ifa = addrs; ifa; ifa = ifa->ifa_next) { + if (ifa->ifa_addr && ifa->ifa_addr->sa_family == AF_INET && + ifa->ifa_flags & IFF_BROADCAST) { + if (count + 2 >= capacity) { + capacity *= 2; + uint32_t * nret = realloc(ret, sizeof(uint32_t) * capacity); + if (nret) { + ret = nret; + } else { + free(ret); + return 0; + } + } + + ret[count] = ((struct sockaddr_in*)ifa->ifa_broadaddr)->sin_addr.s_addr; + count++; + } + } + + freeifaddrs(addrs); + ret[count] = 0; + return ret; +} diff --git a/src/Erebos/Network/ifaddrs.h b/src/Erebos/Network/ifaddrs.h new file mode 100644 index 0000000..06d26ec --- /dev/null +++ b/src/Erebos/Network/ifaddrs.h @@ -0,0 +1,3 @@ +#include <stdint.h> + +uint32_t * broadcast_addresses(void); diff --git a/src/Erebos/Pairing.hs b/src/Erebos/Pairing.hs new file mode 100644 index 0000000..4541f6e --- /dev/null +++ b/src/Erebos/Pairing.hs @@ -0,0 +1,241 @@ +module Erebos.Pairing ( + PairingService(..), + PairingState(..), + PairingAttributes(..), + PairingResult(..), + PairingFailureReason(..), + + pairingRequest, + pairingAccept, + pairingReject, +) where + +import Control.Monad.Except +import Control.Monad.Reader + +import Crypto.Random + +import Data.Bits +import Data.ByteArray (Bytes, convert) +import qualified Data.ByteArray as BA +import qualified Data.ByteString.Char8 as BC +import Data.Kind +import Data.Maybe +import Data.Typeable +import Data.Word + +import Erebos.Identity +import Erebos.Network +import Erebos.PubKey +import Erebos.Service +import Erebos.State +import Erebos.Storage + +data PairingService a = PairingRequest (Stored (Signed IdentityData)) (Stored (Signed IdentityData)) RefDigest + | PairingResponse Bytes + | PairingRequestNonce Bytes + | PairingAccept a + | PairingReject + +data PairingState a = NoPairing + | OurRequest UnifiedIdentity UnifiedIdentity Bytes + | OurRequestConfirm (Maybe (PairingVerifiedResult a)) + | OurRequestReady + | PeerRequest UnifiedIdentity UnifiedIdentity Bytes RefDigest + | PeerRequestConfirm + | PairingDone + +data PairingFailureReason a = PairingUserRejected + | PairingUnexpectedMessage (PairingState a) (PairingService a) + | PairingFailedOther String + +data PairingAttributes a = PairingAttributes + { pairingHookRequest :: ServiceHandler (PairingService a) () + , pairingHookResponse :: String -> ServiceHandler (PairingService a) () + , pairingHookRequestNonce :: String -> ServiceHandler (PairingService a) () + , pairingHookRequestNonceFailed :: ServiceHandler (PairingService a) () + , pairingHookConfirmedResponse :: ServiceHandler (PairingService a) () + , pairingHookConfirmedRequest :: ServiceHandler (PairingService a) () + , pairingHookAcceptedResponse :: ServiceHandler (PairingService a) () + , pairingHookAcceptedRequest :: ServiceHandler (PairingService a) () + , pairingHookVerifyFailed :: ServiceHandler (PairingService a) () + , pairingHookRejected :: ServiceHandler (PairingService a) () + , pairingHookFailed :: PairingFailureReason a -> ServiceHandler (PairingService a) () + } + +class (Typeable a, Storable a) => PairingResult a where + type PairingVerifiedResult a :: Type + type PairingVerifiedResult a = a + + pairingServiceID :: proxy a -> ServiceID + pairingVerifyResult :: a -> ServiceHandler (PairingService a) (Maybe (PairingVerifiedResult a)) + pairingFinalizeRequest :: PairingVerifiedResult a -> ServiceHandler (PairingService a) () + pairingFinalizeResponse :: ServiceHandler (PairingService a) a + defaultPairingAttributes :: proxy (PairingService a) -> PairingAttributes a + + +instance Storable a => Storable (PairingService a) where + store' (PairingRequest idReq idRsp x) = storeRec $ do + storeRef "id-req" idReq + storeRef "id-rsp" idRsp + storeBinary "request" x + store' (PairingResponse x) = storeRec $ storeBinary "response" x + store' (PairingRequestNonce x) = storeRec $ storeBinary "reqnonce" x + store' (PairingAccept x) = store' x + store' (PairingReject) = storeRec $ storeEmpty "reject" + + load' = do + res <- loadRec $ do + (req :: Maybe Bytes) <- loadMbBinary "request" + idReq <- loadMbRef "id-req" + idRsp <- loadMbRef "id-rsp" + rsp <- loadMbBinary "response" + rnonce <- loadMbBinary "reqnonce" + rej <- loadMbEmpty "reject" + return $ catMaybes + [ PairingRequest <$> idReq <*> idRsp <*> (refDigestFromByteString =<< req) + , PairingResponse <$> rsp + , PairingRequestNonce <$> rnonce + , const PairingReject <$> rej + ] + case res of + x:_ -> return x + [] -> PairingAccept <$> load' + + +instance PairingResult a => Service (PairingService a) where + serviceID _ = pairingServiceID @a Proxy + + type ServiceAttributes (PairingService a) = PairingAttributes a + defaultServiceAttributes = defaultPairingAttributes + + type ServiceState (PairingService a) = PairingState a + emptyServiceState _ = NoPairing + + serviceHandler spacket = ((,fromStored spacket) <$> svcGet) >>= \case + (NoPairing, PairingRequest pdata sdata confirm) -> do + self <- maybe (throwError "failed to validate received identity") return $ validateIdentity sdata + self' <- maybe (throwError "failed to validate own identity") return . + validateExtendedIdentity . lsIdentity . fromStored =<< svcGetLocal + when (not $ self `sameIdentity` self') $ do + throwError "pairing request to different identity" + + peer <- maybe (throwError "failed to validate received peer identity") return $ validateIdentity pdata + peer' <- asks $ svcPeerIdentity + when (not $ peer `sameIdentity` peer') $ do + throwError "pairing request from different identity" + + join $ asks $ pairingHookRequest . svcAttributes + nonce <- liftIO $ getRandomBytes 32 + svcSet $ PeerRequest peer self nonce confirm + replyPacket $ PairingResponse nonce + (NoPairing, _) -> return () + + (PairingDone, _) -> return () + (_, PairingReject) -> do + join $ asks $ pairingHookRejected . svcAttributes + svcSet NoPairing + + (OurRequest self peer nonce, PairingResponse pnonce) -> do + hook <- asks $ pairingHookResponse . svcAttributes + hook $ confirmationNumber $ nonceDigest self peer nonce pnonce + svcSet $ OurRequestConfirm Nothing + replyPacket $ PairingRequestNonce nonce + x@(OurRequest {}, _) -> reject $ uncurry PairingUnexpectedMessage x + + (OurRequestConfirm _, PairingAccept x) -> do + flip catchError (reject . PairingFailedOther) $ do + pairingVerifyResult x >>= \case + Just x' -> do + join $ asks $ pairingHookConfirmedRequest . svcAttributes + svcSet $ OurRequestConfirm (Just x') + Nothing -> do + join $ asks $ pairingHookVerifyFailed . svcAttributes + svcSet NoPairing + replyPacket PairingReject + + x@(OurRequestConfirm _, _) -> reject $ uncurry PairingUnexpectedMessage x + + (OurRequestReady, PairingAccept x) -> do + flip catchError (reject . PairingFailedOther) $ do + pairingVerifyResult x >>= \case + Just x' -> do + pairingFinalizeRequest x' + join $ asks $ pairingHookAcceptedResponse . svcAttributes + svcSet $ PairingDone + Nothing -> do + join $ asks $ pairingHookVerifyFailed . svcAttributes + throwError "" + x@(OurRequestReady, _) -> reject $ uncurry PairingUnexpectedMessage x + + (PeerRequest peer self nonce dgst, PairingRequestNonce pnonce) -> do + if dgst == nonceDigest peer self pnonce BA.empty + then do hook <- asks $ pairingHookRequestNonce . svcAttributes + hook $ confirmationNumber $ nonceDigest peer self pnonce nonce + svcSet PeerRequestConfirm + else do join $ asks $ pairingHookRequestNonceFailed . svcAttributes + svcSet NoPairing + replyPacket PairingReject + x@(PeerRequest {}, _) -> reject $ uncurry PairingUnexpectedMessage x + x@(PeerRequestConfirm, _) -> reject $ uncurry PairingUnexpectedMessage x + +reject :: PairingResult a => PairingFailureReason a -> ServiceHandler (PairingService a) () +reject reason = do + join $ asks $ flip pairingHookFailed reason . svcAttributes + svcSet NoPairing + replyPacket PairingReject + + +nonceDigest :: UnifiedIdentity -> UnifiedIdentity -> Bytes -> Bytes -> RefDigest +nonceDigest idReq idRsp nonceReq nonceRsp = hashToRefDigest $ serializeObject $ Rec + [ (BC.pack "id-req", RecRef $ storedRef $ idData idReq) + , (BC.pack "id-rsp", RecRef $ storedRef $ idData idRsp) + , (BC.pack "nonce-req", RecBinary $ convert nonceReq) + , (BC.pack "nonce-rsp", RecBinary $ convert nonceRsp) + ] + +confirmationNumber :: RefDigest -> String +confirmationNumber dgst = + case map fromIntegral $ BA.unpack dgst :: [Word32] of + (a:b:c:d:_) -> let str = show $ ((a `shift` 24) .|. (b `shift` 16) .|. (c `shift` 8) .|. d) `mod` (10 ^ len) + in replicate (len - length str) '0' ++ str + _ -> "" + where len = 6 + +pairingRequest :: forall a m proxy. (PairingResult a, MonadIO m, MonadError String m) => proxy a -> Peer -> m () +pairingRequest _ peer = do + self <- liftIO $ serverIdentity $ peerServer peer + nonce <- liftIO $ getRandomBytes 32 + pid <- peerIdentity peer >>= \case + PeerIdentityFull pid -> return pid + _ -> throwError "incomplete peer identity" + sendToPeerWith @(PairingService a) peer $ \case + NoPairing -> return (Just $ PairingRequest (idData self) (idData pid) (nonceDigest self pid nonce BA.empty), OurRequest self pid nonce) + _ -> throwError "already in progress" + +pairingAccept :: forall a m proxy. (PairingResult a, MonadIO m, MonadError String m) => proxy a -> Peer -> m () +pairingAccept _ peer = runPeerService @(PairingService a) peer $ do + svcGet >>= \case + NoPairing -> throwError $ "none in progress" + OurRequest {} -> throwError $ "waiting for peer" + OurRequestConfirm Nothing -> do + join $ asks $ pairingHookConfirmedResponse . svcAttributes + svcSet OurRequestReady + OurRequestConfirm (Just verified) -> do + join $ asks $ pairingHookAcceptedResponse . svcAttributes + pairingFinalizeRequest verified + svcSet PairingDone + OurRequestReady -> throwError $ "already accepted, waiting for peer" + PeerRequest {} -> throwError $ "waiting for peer" + PeerRequestConfirm -> do + join $ asks $ pairingHookAcceptedRequest . svcAttributes + replyPacket . PairingAccept =<< pairingFinalizeResponse + svcSet PairingDone + PairingDone -> throwError $ "already done" + +pairingReject :: forall a m proxy. (PairingResult a, MonadIO m, MonadError String m) => proxy a -> Peer -> m () +pairingReject _ peer = runPeerService @(PairingService a) peer $ do + svcGet >>= \case + NoPairing -> throwError $ "none in progress" + PairingDone -> throwError $ "already done" + _ -> reject PairingUserRejected diff --git a/src/Erebos/PubKey.hs b/src/Erebos/PubKey.hs new file mode 100644 index 0000000..09a8e02 --- /dev/null +++ b/src/Erebos/PubKey.hs @@ -0,0 +1,156 @@ +module Erebos.PubKey ( + PublicKey, SecretKey, + KeyPair(generateKeys), loadKey, loadKeyMb, + Signature(sigKey), Signed, signedData, signedSignature, + sign, signAdd, isSignedBy, + fromSigned, + unsafeMapSigned, + + PublicKexKey, SecretKexKey, + dhSecret, +) where + +import Control.Monad +import Control.Monad.Except + +import Crypto.Error +import qualified Crypto.PubKey.Ed25519 as ED +import qualified Crypto.PubKey.Curve25519 as CX + +import Data.ByteArray +import Data.ByteString (ByteString) +import qualified Data.Text as T + +import Erebos.Storage +import Erebos.Storage.Key + +data PublicKey = PublicKey ED.PublicKey + deriving (Show) + +data SecretKey = SecretKey ED.SecretKey (Stored PublicKey) + +data Signature = Signature + { sigKey :: Stored PublicKey + , sigSignature :: ED.Signature + } + deriving (Show) + +data Signed a = Signed + { signedData_ :: Stored a + , signedSignature_ :: [Stored Signature] + } + deriving (Show) + +signedData :: Signed a -> Stored a +signedData = signedData_ + +signedSignature :: Signed a -> [Stored Signature] +signedSignature = signedSignature_ + +instance KeyPair SecretKey PublicKey where + keyGetPublic (SecretKey _ pub) = pub + keyGetData (SecretKey sec _) = convert sec + keyFromData kdata spub = do + skey <- maybeCryptoError $ ED.secretKey kdata + let PublicKey pkey = fromStored spub + guard $ ED.toPublic skey == pkey + return $ SecretKey skey spub + generateKeys st = do + secret <- ED.generateSecretKey + public <- wrappedStore st $ PublicKey $ ED.toPublic secret + let pair = SecretKey secret public + storeKey pair + return (pair, public) + +instance Storable PublicKey where + store' (PublicKey pk) = storeRec $ do + storeText "type" $ T.pack "ed25519" + storeBinary "pubkey" pk + + load' = loadRec $ do + ktype <- loadText "type" + guard $ ktype == "ed25519" + maybe (throwError "Public key decoding failed") (return . PublicKey) . + maybeCryptoError . (ED.publicKey :: ByteString -> CryptoFailable ED.PublicKey) =<< + loadBinary "pubkey" + +instance Storable Signature where + store' sig = storeRec $ do + storeRef "key" $ sigKey sig + storeBinary "sig" $ sigSignature sig + + load' = loadRec $ Signature + <$> loadRef "key" + <*> loadSignature "sig" + where loadSignature = maybe (throwError "Signature decoding failed") return . + maybeCryptoError . (ED.signature :: ByteString -> CryptoFailable ED.Signature) <=< loadBinary + +instance Storable a => Storable (Signed a) where + store' sig = storeRec $ do + storeRef "SDATA" $ signedData sig + mapM_ (storeRef "sig") $ signedSignature sig + + load' = loadRec $ do + sdata <- loadRef "SDATA" + sigs <- loadRefs "sig" + forM_ sigs $ \sig -> do + let PublicKey pubkey = fromStored $ sigKey $ fromStored sig + when (not $ ED.verify pubkey (storedRef sdata) $ sigSignature $ fromStored sig) $ + throwError "signature verification failed" + return $ Signed sdata sigs + +sign :: MonadStorage m => SecretKey -> Stored a -> m (Signed a) +sign secret val = signAdd secret $ Signed val [] + +signAdd :: MonadStorage m => SecretKey -> Signed a -> m (Signed a) +signAdd (SecretKey secret spublic) (Signed val sigs) = do + let PublicKey public = fromStored spublic + sig = ED.sign secret public $ storedRef val + ssig <- mstore $ Signature spublic sig + return $ Signed val (ssig : sigs) + +isSignedBy :: Signed a -> Stored PublicKey -> Bool +isSignedBy sig key = key `elem` map (sigKey . fromStored) (signedSignature sig) + +fromSigned :: Stored (Signed a) -> a +fromSigned = fromStored . signedData . fromStored + +-- |Passed function needs to preserve the object representation to be safe +unsafeMapSigned :: (a -> b) -> Signed a -> Signed b +unsafeMapSigned f signed = signed { signedData_ = unsafeMapStored f (signedData_ signed) } + + +data PublicKexKey = PublicKexKey CX.PublicKey + deriving (Show) + +data SecretKexKey = SecretKexKey CX.SecretKey (Stored PublicKexKey) + +instance KeyPair SecretKexKey PublicKexKey where + keyGetPublic (SecretKexKey _ pub) = pub + keyGetData (SecretKexKey sec _) = convert sec + keyFromData kdata spub = do + skey <- maybeCryptoError $ CX.secretKey kdata + let PublicKexKey pkey = fromStored spub + guard $ CX.toPublic skey == pkey + return $ SecretKexKey skey spub + generateKeys st = do + secret <- CX.generateSecretKey + public <- wrappedStore st $ PublicKexKey $ CX.toPublic secret + let pair = SecretKexKey secret public + storeKey pair + return (pair, public) + +instance Storable PublicKexKey where + store' (PublicKexKey pk) = storeRec $ do + storeText "type" $ T.pack "x25519" + storeBinary "pubkey" pk + + load' = loadRec $ do + ktype <- loadText "type" + guard $ ktype == "x25519" + maybe (throwError "public key decoding failed") (return . PublicKexKey) . + maybeCryptoError . (CX.publicKey :: ScrubbedBytes -> CryptoFailable CX.PublicKey) =<< + loadBinary "pubkey" + +dhSecret :: SecretKexKey -> PublicKexKey -> ScrubbedBytes +dhSecret (SecretKexKey secret _) (PublicKexKey public) = convert $ CX.dh public secret diff --git a/src/Erebos/Service.hs b/src/Erebos/Service.hs new file mode 100644 index 0000000..f8428d1 --- /dev/null +++ b/src/Erebos/Service.hs @@ -0,0 +1,190 @@ +module Erebos.Service ( + Service(..), + SomeService(..), someService, someServiceAttr, someServiceID, + SomeServiceState(..), fromServiceState, someServiceEmptyState, + SomeServiceGlobalState(..), fromServiceGlobalState, someServiceEmptyGlobalState, + SomeStorageWatcher(..), + ServiceID, mkServiceID, + + ServiceHandler, + ServiceInput(..), + ServiceReply(..), + runServiceHandler, + + svcGet, svcSet, svcModify, + svcGetGlobal, svcSetGlobal, svcModifyGlobal, + svcGetLocal, svcSetLocal, + + svcSelf, + svcPrint, + + replyPacket, replyStored, replyStoredRef, + afterCommit, +) where + +import Control.Monad.Except +import Control.Monad.Reader +import Control.Monad.State +import Control.Monad.Writer + +import Data.Kind +import Data.Typeable +import Data.UUID (UUID) +import qualified Data.UUID as U + +import Erebos.Identity +import {-# SOURCE #-} Erebos.Network +import Erebos.State +import Erebos.Storage + +class (Typeable s, Storable s, Typeable (ServiceState s), Typeable (ServiceGlobalState s)) => Service s where + serviceID :: proxy s -> ServiceID + serviceHandler :: Stored s -> ServiceHandler s () + + serviceNewPeer :: ServiceHandler s () + serviceNewPeer = return () + + type ServiceAttributes s = attr | attr -> s + type ServiceAttributes s = Proxy s + defaultServiceAttributes :: proxy s -> ServiceAttributes s + default defaultServiceAttributes :: ServiceAttributes s ~ Proxy s => proxy s -> ServiceAttributes s + defaultServiceAttributes _ = Proxy + + type ServiceState s :: Type + type ServiceState s = () + emptyServiceState :: proxy s -> ServiceState s + default emptyServiceState :: ServiceState s ~ () => proxy s -> ServiceState s + emptyServiceState _ = () + + type ServiceGlobalState s :: Type + type ServiceGlobalState s = () + emptyServiceGlobalState :: proxy s -> ServiceGlobalState s + default emptyServiceGlobalState :: ServiceGlobalState s ~ () => proxy s -> ServiceGlobalState s + emptyServiceGlobalState _ = () + + serviceStorageWatchers :: proxy s -> [SomeStorageWatcher s] + serviceStorageWatchers _ = [] + + +data SomeService = forall s. Service s => SomeService (Proxy s) (ServiceAttributes s) + +someService :: forall s proxy. Service s => proxy s -> SomeService +someService _ = SomeService @s Proxy (defaultServiceAttributes @s Proxy) + +someServiceAttr :: forall s. Service s => ServiceAttributes s -> SomeService +someServiceAttr attr = SomeService @s Proxy attr + +someServiceID :: SomeService -> ServiceID +someServiceID (SomeService s _) = serviceID s + +data SomeServiceState = forall s. Service s => SomeServiceState (Proxy s) (ServiceState s) + +fromServiceState :: Service s => proxy s -> SomeServiceState -> Maybe (ServiceState s) +fromServiceState _ (SomeServiceState _ s) = cast s + +someServiceEmptyState :: SomeService -> SomeServiceState +someServiceEmptyState (SomeService p _) = SomeServiceState p (emptyServiceState p) + +data SomeServiceGlobalState = forall s. Service s => SomeServiceGlobalState (Proxy s) (ServiceGlobalState s) + +fromServiceGlobalState :: Service s => proxy s -> SomeServiceGlobalState -> Maybe (ServiceGlobalState s) +fromServiceGlobalState _ (SomeServiceGlobalState _ s) = cast s + +someServiceEmptyGlobalState :: SomeService -> SomeServiceGlobalState +someServiceEmptyGlobalState (SomeService p _) = SomeServiceGlobalState p (emptyServiceGlobalState p) + + +data SomeStorageWatcher s = forall a. Eq a => SomeStorageWatcher (Stored LocalState -> a) (a -> ServiceHandler s ()) + + +newtype ServiceID = ServiceID UUID + deriving (Eq, Ord, Show, StorableUUID) + +mkServiceID :: String -> ServiceID +mkServiceID = maybe (error "Invalid service ID") ServiceID . U.fromString + +data ServiceInput s = ServiceInput + { svcAttributes :: ServiceAttributes s + , svcPeer :: Peer + , svcPeerIdentity :: UnifiedIdentity + , svcServer :: Server + , svcPrintOp :: String -> IO () + } + +data ServiceReply s = ServiceReply (Either s (Stored s)) Bool + | ServiceFinally (IO ()) + +data ServiceHandlerState s = ServiceHandlerState + { svcValue :: ServiceState s + , svcGlobal :: ServiceGlobalState s + , svcLocal :: Stored LocalState + } + +newtype ServiceHandler s a = ServiceHandler (ReaderT (ServiceInput s) (WriterT [ServiceReply s] (StateT (ServiceHandlerState s) (ExceptT String IO))) a) + deriving (Functor, Applicative, Monad, MonadReader (ServiceInput s), MonadWriter [ServiceReply s], MonadState (ServiceHandlerState s), MonadError String, MonadIO) + +instance MonadStorage (ServiceHandler s) where + getStorage = asks $ peerStorage . svcPeer + +instance MonadHead LocalState (ServiceHandler s) where + updateLocalHead f = do + (ls, x) <- f =<< gets svcLocal + modify $ \s -> s { svcLocal = ls } + return x + +runServiceHandler :: Service s => Head LocalState -> ServiceInput s -> ServiceState s -> ServiceGlobalState s -> ServiceHandler s () -> IO ([ServiceReply s], (ServiceState s, ServiceGlobalState s)) +runServiceHandler h input svc global shandler = do + let sstate = ServiceHandlerState { svcValue = svc, svcGlobal = global, svcLocal = headStoredObject h } + ServiceHandler handler = shandler + (runExceptT $ flip runStateT sstate $ execWriterT $ flip runReaderT input $ handler) >>= \case + Left err -> do + svcPrintOp input $ "service failed: " ++ err + return ([], (svc, global)) + Right (rsp, sstate') + | svcLocal sstate' == svcLocal sstate -> return (rsp, (svcValue sstate', svcGlobal sstate')) + | otherwise -> replaceHead h (svcLocal sstate') >>= \case + Left (Just h') -> runServiceHandler h' input svc global shandler + _ -> return (rsp, (svcValue sstate', svcGlobal sstate')) + +svcGet :: ServiceHandler s (ServiceState s) +svcGet = gets svcValue + +svcSet :: ServiceState s -> ServiceHandler s () +svcSet x = modify $ \st -> st { svcValue = x } + +svcModify :: (ServiceState s -> ServiceState s) -> ServiceHandler s () +svcModify f = modify $ \st -> st { svcValue = f (svcValue st) } + +svcGetGlobal :: ServiceHandler s (ServiceGlobalState s) +svcGetGlobal = gets svcGlobal + +svcSetGlobal :: ServiceGlobalState s -> ServiceHandler s () +svcSetGlobal x = modify $ \st -> st { svcGlobal = x } + +svcModifyGlobal :: (ServiceGlobalState s -> ServiceGlobalState s) -> ServiceHandler s () +svcModifyGlobal f = modify $ \st -> st { svcGlobal = f (svcGlobal st) } + +svcGetLocal :: ServiceHandler s (Stored LocalState) +svcGetLocal = gets svcLocal + +svcSetLocal :: Stored LocalState -> ServiceHandler s () +svcSetLocal x = modify $ \st -> st { svcLocal = x } + +svcSelf :: ServiceHandler s UnifiedIdentity +svcSelf = maybe (throwError "failed to validate own identity") return . + validateExtendedIdentity . lsIdentity . fromStored =<< svcGetLocal + +svcPrint :: String -> ServiceHandler s () +svcPrint str = afterCommit . ($ str) =<< asks svcPrintOp + +replyPacket :: Service s => s -> ServiceHandler s () +replyPacket x = tell [ServiceReply (Left x) True] + +replyStored :: Service s => Stored s -> ServiceHandler s () +replyStored x = tell [ServiceReply (Right x) True] + +replyStoredRef :: Service s => Stored s -> ServiceHandler s () +replyStoredRef x = tell [ServiceReply (Right x) False] + +afterCommit :: IO () -> ServiceHandler s () +afterCommit x = tell [ServiceFinally x] diff --git a/src/Erebos/Set.hs b/src/Erebos/Set.hs new file mode 100644 index 0000000..0abe02d --- /dev/null +++ b/src/Erebos/Set.hs @@ -0,0 +1,78 @@ +module Erebos.Set ( + Set, + + emptySet, + loadSet, + storeSetAdd, + + fromSetBy, +) where + +import Control.Arrow +import Control.Monad.IO.Class + +import Data.Function +import Data.List +import Data.Map (Map) +import Data.Map qualified as M +import Data.Maybe +import Data.Ord + +import Erebos.Storage +import Erebos.Storage.Merge +import Erebos.Util + +data Set a = Set [Stored (SetItem (Component a))] + +data SetItem a = SetItem + { siPrev :: [Stored (SetItem a)] + , siItem :: [Stored a] + } + +instance Storable a => Storable (SetItem a) where + store' x = storeRec $ do + mapM_ (storeRef "PREV") $ siPrev x + mapM_ (storeRef "item") $ siItem x + + load' = loadRec $ SetItem + <$> loadRefs "PREV" + <*> loadRefs "item" + +instance Mergeable a => Mergeable (Set a) where + type Component (Set a) = SetItem (Component a) + mergeSorted = Set + toComponents (Set items) = items + + +emptySet :: Set a +emptySet = Set [] + +loadSet :: Mergeable a => Ref -> Set a +loadSet = mergeSorted . (:[]) . wrappedLoad + +storeSetAdd :: (Mergeable a, MonadIO m) => Storage -> a -> Set a -> m (Set a) +storeSetAdd st x (Set prev) = Set . (:[]) <$> wrappedStore st SetItem + { siPrev = prev + , siItem = toComponents x + } + + +fromSetBy :: forall a. Mergeable a => (a -> a -> Ordering) -> Set a -> [a] +fromSetBy cmp (Set heads) = sortBy cmp $ map merge $ groupRelated items + where + -- gather all item components in the set history + items :: [Stored (Component a)] + items = walkAncestors (siItem . fromStored) heads + + -- map individual roots to full root set as joined in history of individual items + rootToRootSet :: Map RefDigest [RefDigest] + rootToRootSet = foldl' (\m rs -> foldl' (\m' r -> M.insertWith (\a b -> uniq $ sort $ a++b) r rs m') m rs) M.empty $ + map (map (refDigest . storedRef) . storedRoots) items + + -- get full root set for given item component + storedRootSet :: Stored (Component a) -> [RefDigest] + storedRootSet = fromJust . flip M.lookup rootToRootSet . refDigest . storedRef . head . storedRoots + + -- group components of single item, i.e. components sharing some root + groupRelated :: [Stored (Component a)] -> [[Stored (Component a)]] + groupRelated = map (map fst) . groupBy ((==) `on` snd) . sortBy (comparing snd) . map (id &&& storedRootSet) diff --git a/src/Erebos/State.hs b/src/Erebos/State.hs new file mode 100644 index 0000000..1f0bf7d --- /dev/null +++ b/src/Erebos/State.hs @@ -0,0 +1,199 @@ +module Erebos.State ( + LocalState(..), + SharedState, SharedType(..), + SharedTypeID, mkSharedTypeID, + + MonadHead(..), + updateLocalHead_, + + loadLocalStateHead, + + updateSharedState, updateSharedState_, + lookupSharedValue, makeSharedStateUpdate, + + localIdentity, + headLocalIdentity, + + mergeSharedIdentity, + updateSharedIdentity, + interactiveIdentityUpdate, +) where + +import Control.Monad.Except +import Control.Monad.Reader + +import Data.Foldable +import Data.Maybe +import qualified Data.Text as T +import qualified Data.Text.IO as T +import Data.Typeable +import Data.UUID (UUID) +import qualified Data.UUID as U + +import System.IO + +import Erebos.Identity +import Erebos.PubKey +import Erebos.Storage +import Erebos.Storage.Merge + +data LocalState = LocalState + { lsIdentity :: Stored (Signed ExtendedIdentityData) + , lsShared :: [Stored SharedState] + } + +data SharedState = SharedState + { ssPrev :: [Stored SharedState] + , ssType :: Maybe SharedTypeID + , ssValue :: [Ref] + } + +newtype SharedTypeID = SharedTypeID UUID + deriving (Eq, Ord, StorableUUID) + +mkSharedTypeID :: String -> SharedTypeID +mkSharedTypeID = maybe (error "Invalid shared type ID") SharedTypeID . U.fromString + +class Mergeable a => SharedType a where + sharedTypeID :: proxy a -> SharedTypeID + +instance Storable LocalState where + store' st = storeRec $ do + storeRef "id" $ lsIdentity st + mapM_ (storeRef "shared") $ lsShared st + + load' = loadRec $ LocalState + <$> loadRef "id" + <*> loadRefs "shared" + +instance HeadType LocalState where + headTypeID _ = mkHeadTypeID "1d7491a9-7bcb-4eaa-8f13-c8c4c4087e4e" + +instance Storable SharedState where + store' st = storeRec $ do + mapM_ (storeRef "PREV") $ ssPrev st + storeMbUUID "type" $ ssType st + mapM_ (storeRawRef "value") $ ssValue st + + load' = loadRec $ SharedState + <$> loadRefs "PREV" + <*> loadMbUUID "type" + <*> loadRawRefs "value" + +instance SharedType (Maybe ComposedIdentity) where + sharedTypeID _ = mkSharedTypeID "0c6c1fe0-f2d7-4891-926b-c332449f7871" + + +class (MonadIO m, MonadStorage m) => MonadHead a m where + updateLocalHead :: (Stored a -> m (Stored a, b)) -> m b + +updateLocalHead_ :: MonadHead a m => (Stored a -> m (Stored a)) -> m () +updateLocalHead_ f = updateLocalHead (fmap (,()) . f) + +instance (HeadType a, MonadIO m) => MonadHead a (ReaderT (Head a) m) where + updateLocalHead f = do + h <- ask + snd <$> updateHead h f + + +loadLocalStateHead :: MonadIO m => Storage -> m (Head LocalState) +loadLocalStateHead st = loadHeads st >>= \case + (h:_) -> return h + [] -> liftIO $ do + putStr "Name: " + hFlush stdout + name <- T.getLine + + putStr "Device: " + hFlush stdout + devName <- T.getLine + + owner <- if + | T.null name -> return Nothing + | otherwise -> Just <$> createIdentity st (Just name) Nothing + + identity <- createIdentity st (if T.null devName then Nothing else Just devName) owner + + shared <- wrappedStore st $ SharedState + { ssPrev = [] + , ssType = Just $ sharedTypeID @(Maybe ComposedIdentity) Proxy + , ssValue = [storedRef $ idExtData $ fromMaybe identity owner] + } + storeHead st $ LocalState + { lsIdentity = idExtData identity + , lsShared = [shared] + } + +localIdentity :: LocalState -> UnifiedIdentity +localIdentity ls = maybe (error "failed to verify local identity") + (updateOwners $ maybe [] idExtDataF $ lookupSharedValue $ lsShared ls) + (validateExtendedIdentity $ lsIdentity ls) + +headLocalIdentity :: Head LocalState -> UnifiedIdentity +headLocalIdentity = localIdentity . headObject + + +updateSharedState_ :: forall a m. (SharedType a, MonadHead LocalState m) => (a -> m a) -> Stored LocalState -> m (Stored LocalState) +updateSharedState_ f = fmap fst <$> updateSharedState (fmap (,()) . f) + +updateSharedState :: forall a b m. (SharedType a, MonadHead LocalState m) => (a -> m (a, b)) -> Stored LocalState -> m (Stored LocalState, b) +updateSharedState f = \ls -> do + let shared = lsShared $ fromStored ls + val = lookupSharedValue shared + st <- getStorage + (val', x) <- f val + (,x) <$> if toComponents val' == toComponents val + then return ls + else do shared' <- makeSharedStateUpdate st val' shared + wrappedStore st (fromStored ls) { lsShared = [shared'] } + +lookupSharedValue :: forall a. SharedType a => [Stored SharedState] -> a +lookupSharedValue = mergeSorted . filterAncestors . map wrappedLoad . concatMap (ssValue . fromStored) . filterAncestors . helper + where helper (x:xs) | Just sid <- ssType (fromStored x), sid == sharedTypeID @a Proxy = x : helper xs + | otherwise = helper $ ssPrev (fromStored x) ++ xs + helper [] = [] + +makeSharedStateUpdate :: forall a m. MonadIO m => SharedType a => Storage -> a -> [Stored SharedState] -> m (Stored SharedState) +makeSharedStateUpdate st val prev = liftIO $ wrappedStore st SharedState + { ssPrev = prev + , ssType = Just $ sharedTypeID @a Proxy + , ssValue = storedRef <$> toComponents val + } + + +mergeSharedIdentity :: (MonadHead LocalState m, MonadError String m) => m UnifiedIdentity +mergeSharedIdentity = updateLocalHead $ updateSharedState $ \case + Just cidentity -> do + identity <- mergeIdentity cidentity + return (Just $ toComposedIdentity identity, identity) + Nothing -> throwError "no existing shared identity" + +updateSharedIdentity :: (MonadHead LocalState m, MonadError String m) => m () +updateSharedIdentity = updateLocalHead_ $ updateSharedState_ $ \case + Just identity -> do + Just . toComposedIdentity <$> interactiveIdentityUpdate identity + Nothing -> throwError "no existing shared identity" + +interactiveIdentityUpdate :: (Foldable f, MonadStorage m, MonadIO m, MonadError String m) => Identity f -> m UnifiedIdentity +interactiveIdentityUpdate identity = do + let public = idKeyIdentity identity + + name <- liftIO $ do + T.putStr $ T.concat $ concat + [ [ T.pack "Name" ] + , case idName identity of + Just name -> [T.pack " [", name, T.pack "]"] + Nothing -> [] + , [ T.pack ": " ] + ] + hFlush stdout + T.getLine + + if | T.null name -> mergeIdentity identity + | otherwise -> do + secret <- loadKey public + maybe (throwError "created invalid identity") return . validateIdentity =<< + mstore =<< sign secret =<< mstore (emptyIdentityData public) + { iddPrev = toList $ idDataF identity + , iddName = Just name + } diff --git a/src/Erebos/Storage.hs b/src/Erebos/Storage.hs new file mode 100644 index 0000000..bad1b37 --- /dev/null +++ b/src/Erebos/Storage.hs @@ -0,0 +1,1014 @@ +module Erebos.Storage ( + Storage, PartialStorage, StorageCompleteness, + openStorage, memoryStorage, + deriveEphemeralStorage, derivePartialStorage, + + Ref, PartialRef, RefDigest, + refDigest, + readRef, showRef, showRefDigest, + refDigestFromByteString, hashToRefDigest, + copyRef, partialRef, partialRefFromDigest, + + Object, PartialObject, Object'(..), RecItem, RecItem'(..), + serializeObject, deserializeObject, deserializeObjects, + ioLoadObject, ioLoadBytes, + storeRawBytes, lazyLoadBytes, + storeObject, + collectObjects, collectStoredObjects, + + Head, HeadType(..), + HeadTypeID, mkHeadTypeID, + headId, headStorage, headRef, headObject, headStoredObject, + loadHeads, loadHead, reloadHead, + storeHead, replaceHead, updateHead, updateHead_, + + WatchedHead, + watchHead, watchHeadWith, unwatchHead, + + MonadStorage(..), + + Storable(..), ZeroStorable(..), + StorableText(..), StorableDate(..), StorableUUID(..), + + Store, StoreRec, + evalStore, evalStoreObject, + storeBlob, storeRec, storeZero, + storeEmpty, storeInt, storeNum, storeText, storeBinary, storeDate, storeUUID, storeRef, storeRawRef, + storeMbEmpty, storeMbInt, storeMbNum, storeMbText, storeMbBinary, storeMbDate, storeMbUUID, storeMbRef, storeMbRawRef, + storeZRef, + + Load, LoadRec, + evalLoad, + loadCurrentRef, loadCurrentObject, + loadRecCurrentRef, loadRecItems, + + loadBlob, loadRec, loadZero, + loadEmpty, loadInt, loadNum, loadText, loadBinary, loadDate, loadUUID, loadRef, loadRawRef, + loadMbEmpty, loadMbInt, loadMbNum, loadMbText, loadMbBinary, loadMbDate, loadMbUUID, loadMbRef, loadMbRawRef, + loadTexts, loadBinaries, loadRefs, loadRawRefs, + loadZRef, + + Stored, + fromStored, storedRef, + wrappedStore, wrappedLoad, + copyStored, + unsafeMapStored, + + StoreInfo(..), makeStoreInfo, + + StoredHistory, + fromHistory, fromHistoryAt, storedFromHistory, storedHistoryList, + beginHistory, modifyHistory, +) where + +import qualified Codec.MIME.Type as MIME +import qualified Codec.MIME.Parse as MIME + +import Control.Applicative +import Control.Arrow +import Control.Concurrent +import Control.Exception +import Control.Monad +import Control.Monad.Except +import Control.Monad.Reader +import Control.Monad.Writer + +import Crypto.Hash + +import Data.ByteString (ByteString) +import qualified Data.ByteArray as BA +import qualified Data.ByteString as B +import qualified Data.ByteString.Char8 as BC +import qualified Data.ByteString.Lazy as BL +import qualified Data.ByteString.Lazy.Char8 as BLC +import Data.Char +import Data.Function +import qualified Data.HashTable.IO as HT +import Data.List +import qualified Data.Map as M +import Data.Maybe +import Data.Ratio +import Data.Set (Set) +import qualified Data.Set as S +import Data.Text (Text) +import qualified Data.Text as T +import Data.Text.Encoding +import Data.Text.Encoding.Error +import Data.Time.Calendar +import Data.Time.Clock +import Data.Time.Format +import Data.Time.LocalTime +import Data.Typeable +import Data.UUID (UUID) +import qualified Data.UUID as U +import qualified Data.UUID.V4 as U + +import System.Directory +import System.FilePath +import System.INotify +import System.IO.Error +import System.IO.Unsafe + +import Erebos.Storage.Internal + + +type Storage = Storage' Complete +type PartialStorage = Storage' Partial + +openStorage :: FilePath -> IO Storage +openStorage path = do + createDirectoryIfMissing True $ path ++ "/objects" + createDirectoryIfMissing True $ path ++ "/heads" + watchers <- newMVar ([], WatchList 1 []) + refgen <- newMVar =<< HT.new + refroots <- newMVar =<< HT.new + return $ Storage + { stBacking = StorageDir path watchers + , stParent = Nothing + , stRefGeneration = refgen + , stRefRoots = refroots + } + +memoryStorage' :: IO (Storage' c') +memoryStorage' = do + backing <- StorageMemory <$> newMVar [] <*> newMVar M.empty <*> newMVar M.empty <*> newMVar (WatchList 1 []) + refgen <- newMVar =<< HT.new + refroots <- newMVar =<< HT.new + return $ Storage + { stBacking = backing + , stParent = Nothing + , stRefGeneration = refgen + , stRefRoots = refroots + } + +memoryStorage :: IO Storage +memoryStorage = memoryStorage' + +deriveEphemeralStorage :: Storage -> IO Storage +deriveEphemeralStorage parent = do + st <- memoryStorage + return $ st { stParent = Just parent } + +derivePartialStorage :: Storage -> IO PartialStorage +derivePartialStorage parent = do + st <- memoryStorage' + return $ st { stParent = Just parent } + +type Ref = Ref' Complete +type PartialRef = Ref' Partial + +zeroRef :: Storage' c -> Ref' c +zeroRef s = Ref s (RefDigest h) + where h = case digestFromByteString $ B.replicate (hashDigestSize $ digestAlgo h) 0 of + Nothing -> error $ "Failed to create zero hash" + Just h' -> h' + digestAlgo :: Digest a -> a + digestAlgo = undefined + +isZeroRef :: Ref' c -> Bool +isZeroRef (Ref _ h) = all (==0) $ BA.unpack h + + +refFromDigest :: Storage' c -> RefDigest -> IO (Maybe (Ref' c)) +refFromDigest st dgst = fmap (const $ Ref st dgst) <$> ioLoadBytesFromStorage st dgst + +readRef :: Storage -> ByteString -> IO (Maybe Ref) +readRef s b = + case readRefDigest b of + Nothing -> return Nothing + Just dgst -> refFromDigest s dgst + +copyRef' :: forall c c'. (StorageCompleteness c, StorageCompleteness c') => Storage' c' -> Ref' c -> IO (c (Ref' c')) +copyRef' st ref'@(Ref _ dgst) = refFromDigest st dgst >>= \case Just ref -> return $ return ref + Nothing -> doCopy + where doCopy = do mbobj' <- ioLoadObject ref' + mbobj <- sequence $ copyObject' st <$> mbobj' + sequence $ unsafeStoreObject st <$> join mbobj + +copyObject' :: forall c c'. (StorageCompleteness c, StorageCompleteness c') => Storage' c' -> Object' c -> IO (c (Object' c')) +copyObject' _ (Blob bs) = return $ return $ Blob bs +copyObject' st (Rec rs) = fmap Rec . sequence <$> mapM copyItem rs + where copyItem :: (ByteString, RecItem' c) -> IO (c (ByteString, RecItem' c')) + copyItem (n, item) = fmap (n,) <$> case item of + RecEmpty -> return $ return $ RecEmpty + RecInt x -> return $ return $ RecInt x + RecNum x -> return $ return $ RecNum x + RecText x -> return $ return $ RecText x + RecBinary x -> return $ return $ RecBinary x + RecDate x -> return $ return $ RecDate x + RecUUID x -> return $ return $ RecUUID x + RecRef x -> fmap RecRef <$> copyRef' st x +copyObject' _ ZeroObject = return $ return ZeroObject + +copyRef :: forall c c' m. (StorageCompleteness c, StorageCompleteness c', MonadIO m) => Storage' c' -> Ref' c -> m (LoadResult c (Ref' c')) +copyRef st ref' = liftIO $ returnLoadResult <$> copyRef' st ref' + +copyObject :: forall c c'. (StorageCompleteness c, StorageCompleteness c') => Storage' c' -> Object' c -> IO (LoadResult c (Object' c')) +copyObject st obj' = returnLoadResult <$> copyObject' st obj' + +partialRef :: PartialStorage -> Ref -> PartialRef +partialRef st (Ref _ dgst) = Ref st dgst + +partialRefFromDigest :: PartialStorage -> RefDigest -> PartialRef +partialRefFromDigest st dgst = Ref st dgst + + +data Object' c + = Blob ByteString + | Rec [(ByteString, RecItem' c)] + | ZeroObject + deriving (Show) + +type Object = Object' Complete +type PartialObject = Object' Partial + +data RecItem' c + = RecEmpty + | RecInt Integer + | RecNum Rational + | RecText Text + | RecBinary ByteString + | RecDate ZonedTime + | RecUUID UUID + | RecRef (Ref' c) + deriving (Show) + +type RecItem = RecItem' Complete + +serializeObject :: Object' c -> BL.ByteString +serializeObject = \case + Blob cnt -> BL.fromChunks [BC.pack "blob ", BC.pack (show $ B.length cnt), BC.singleton '\n', cnt] + Rec rec -> let cnt = BL.fromChunks $ concatMap (uncurry serializeRecItem) rec + in BL.fromChunks [BC.pack "rec ", BC.pack (show $ BL.length cnt), BC.singleton '\n'] `BL.append` cnt + ZeroObject -> BL.empty + +-- |Serializes and stores object data without ony dependencies, so is safe only +-- if all the referenced objects are already stored or reference is partial. +unsafeStoreObject :: Storage' c -> Object' c -> IO (Ref' c) +unsafeStoreObject storage = \case + ZeroObject -> return $ zeroRef storage + obj -> unsafeStoreRawBytes storage $ serializeObject obj + +storeObject :: PartialStorage -> PartialObject -> IO PartialRef +storeObject = unsafeStoreObject + +storeRawBytes :: PartialStorage -> BL.ByteString -> IO PartialRef +storeRawBytes = unsafeStoreRawBytes + +serializeRecItem :: ByteString -> RecItem' c -> [ByteString] +serializeRecItem name (RecEmpty) = [name, BC.pack ":e", BC.singleton ' ', BC.singleton '\n'] +serializeRecItem name (RecInt x) = [name, BC.pack ":i", BC.singleton ' ', BC.pack (show x), BC.singleton '\n'] +serializeRecItem name (RecNum x) = [name, BC.pack ":n", BC.singleton ' ', BC.pack (showRatio x), BC.singleton '\n'] +serializeRecItem name (RecText x) = [name, BC.pack ":t", BC.singleton ' ', escaped, BC.singleton '\n'] + where escaped = BC.concatMap escape $ encodeUtf8 x + escape '\n' = BC.pack "\n\t" + escape c = BC.singleton c +serializeRecItem name (RecBinary x) = [name, BC.pack ":b ", showHex x, BC.singleton '\n'] +serializeRecItem name (RecDate x) = [name, BC.pack ":d", BC.singleton ' ', BC.pack (formatTime defaultTimeLocale "%s %z" x), BC.singleton '\n'] +serializeRecItem name (RecUUID x) = [name, BC.pack ":u", BC.singleton ' ', U.toASCIIBytes x, BC.singleton '\n'] +serializeRecItem name (RecRef x) = [name, BC.pack ":r ", showRef x, BC.singleton '\n'] + +lazyLoadObject :: forall c. StorageCompleteness c => Ref' c -> LoadResult c (Object' c) +lazyLoadObject = returnLoadResult . unsafePerformIO . ioLoadObject + +ioLoadObject :: forall c. StorageCompleteness c => Ref' c -> IO (c (Object' c)) +ioLoadObject ref | isZeroRef ref = return $ return ZeroObject +ioLoadObject ref@(Ref st rhash) = do + file' <- ioLoadBytes ref + return $ do + file <- file' + let chash = hashToRefDigest file + when (chash /= rhash) $ error $ "Hash mismatch on object " ++ BC.unpack (showRef ref) {- TODO throw -} + return $ case runExcept $ unsafeDeserializeObject st file of + Left err -> error $ err ++ ", ref " ++ BC.unpack (showRef ref) {- TODO throw -} + Right (x, rest) | BL.null rest -> x + | otherwise -> error $ "Superfluous content after " ++ BC.unpack (showRef ref) {- TODO throw -} + +lazyLoadBytes :: forall c. StorageCompleteness c => Ref' c -> LoadResult c BL.ByteString +lazyLoadBytes ref | isZeroRef ref = returnLoadResult (return BL.empty :: c BL.ByteString) +lazyLoadBytes ref = returnLoadResult $ unsafePerformIO $ ioLoadBytes ref + +unsafeDeserializeObject :: Storage' c -> BL.ByteString -> Except String (Object' c, BL.ByteString) +unsafeDeserializeObject _ bytes | BL.null bytes = return (ZeroObject, bytes) +unsafeDeserializeObject st bytes = + case BLC.break (=='\n') bytes of + (line, rest) | Just (otype, len) <- splitObjPrefix line -> do + let (content, next) = first BL.toStrict $ BL.splitAt (fromIntegral len) $ BL.drop 1 rest + guard $ B.length content == len + (,next) <$> case otype of + _ | otype == BC.pack "blob" -> return $ Blob content + | otype == BC.pack "rec" -> maybe (throwError $ "Malformed record item ") + (return . Rec) $ sequence $ map parseRecLine $ mergeCont [] $ BC.lines content + | otherwise -> throwError $ "Unknown object type" + _ -> throwError $ "Malformed object" + where splitObjPrefix line = do + [otype, tlen] <- return $ BLC.words line + (len, rest) <- BLC.readInt tlen + guard $ BL.null rest + return (BL.toStrict otype, len) + + mergeCont cs (a:b:rest) | Just ('\t', b') <- BC.uncons b = mergeCont (b':BC.pack "\n":cs) (a:rest) + mergeCont cs (a:rest) = B.concat (a : reverse cs) : mergeCont [] rest + mergeCont _ [] = [] + + parseRecLine line = do + colon <- BC.elemIndex ':' line + space <- BC.elemIndex ' ' line + guard $ colon < space + let name = B.take colon line + itype = B.take (space-colon-1) $ B.drop (colon+1) line + content = B.drop (space+1) line + + val <- case BC.unpack itype of + "e" -> do guard $ B.null content + return RecEmpty + "i" -> do (num, rest) <- BC.readInteger content + guard $ B.null rest + return $ RecInt num + "n" -> RecNum <$> parseRatio content + "t" -> return $ RecText $ decodeUtf8With lenientDecode content + "b" -> RecBinary <$> readHex content + "d" -> RecDate <$> parseTimeM False defaultTimeLocale "%s %z" (BC.unpack content) + "u" -> RecUUID <$> U.fromASCIIBytes content + "r" -> RecRef . Ref st <$> readRefDigest content + _ -> Nothing + return (name, val) + +deserializeObject :: PartialStorage -> BL.ByteString -> Except String (PartialObject, BL.ByteString) +deserializeObject = unsafeDeserializeObject + +deserializeObjects :: PartialStorage -> BL.ByteString -> Except String [PartialObject] +deserializeObjects _ bytes | BL.null bytes = return [] +deserializeObjects st bytes = do (obj, rest) <- deserializeObject st bytes + (obj:) <$> deserializeObjects st rest + + +collectObjects :: Object -> [Object] +collectObjects obj = obj : map fromStored (fst $ collectOtherStored S.empty obj) + +collectStoredObjects :: Stored Object -> [Stored Object] +collectStoredObjects obj = obj : (fst $ collectOtherStored S.empty $ fromStored obj) + +collectOtherStored :: Set RefDigest -> Object -> ([Stored Object], Set RefDigest) +collectOtherStored seen (Rec items) = foldr helper ([], seen) $ map snd items + where helper (RecRef ref) (xs, s) | r <- refDigest ref + , r `S.notMember` s + = let o = wrappedLoad ref + (xs', s') = collectOtherStored (S.insert r s) $ fromStored o + in ((o : xs') ++ xs, s') + helper _ (xs, s) = (xs, s) +collectOtherStored seen _ = ([], seen) + + +type Head = Head' Complete + +headId :: Head a -> HeadID +headId (Head uuid _) = uuid + +headStorage :: Head a -> Storage +headStorage = refStorage . headRef + +headRef :: Head a -> Ref +headRef (Head _ sx) = storedRef sx + +headObject :: Head a -> a +headObject (Head _ sx) = fromStored sx + +headStoredObject :: Head a -> Stored a +headStoredObject (Head _ sx) = sx + +deriving instance StorableUUID HeadID +deriving instance StorableUUID HeadTypeID + +mkHeadTypeID :: String -> HeadTypeID +mkHeadTypeID = maybe (error "Invalid head type ID") HeadTypeID . U.fromString + +class Storable a => HeadType a where + headTypeID :: proxy a -> HeadTypeID + + +headTypePath :: FilePath -> HeadTypeID -> FilePath +headTypePath spath (HeadTypeID tid) = spath </> "heads" </> U.toString tid + +headPath :: FilePath -> HeadTypeID -> HeadID -> FilePath +headPath spath tid (HeadID hid) = headTypePath spath tid </> U.toString hid + +loadHeads :: forall a m. MonadIO m => HeadType a => Storage -> m [Head a] +loadHeads s@(Storage { stBacking = StorageDir { dirPath = spath }}) = liftIO $ do + let hpath = headTypePath spath $ headTypeID @a Proxy + + files <- filterM (doesFileExist . (hpath </>)) =<< + handleJust (\e -> guard (isDoesNotExistError e)) (const $ return []) + (getDirectoryContents hpath) + fmap catMaybes $ forM files $ \hname -> do + case U.fromString hname of + Just hid -> do + (h:_) <- BC.lines <$> B.readFile (hpath </> hname) + Just ref <- readRef s h + return $ Just $ Head (HeadID hid) $ wrappedLoad ref + Nothing -> return Nothing +loadHeads Storage { stBacking = StorageMemory { memHeads = theads } } = liftIO $ do + let toHead ((tid, hid), ref) | tid == headTypeID @a Proxy = Just $ Head hid $ wrappedLoad ref + | otherwise = Nothing + catMaybes . map toHead <$> readMVar theads + +loadHead :: forall a m. (HeadType a, MonadIO m) => Storage -> HeadID -> m (Maybe (Head a)) +loadHead s@(Storage { stBacking = StorageDir { dirPath = spath }}) hid = liftIO $ do + handleJust (guard . isDoesNotExistError) (const $ return Nothing) $ do + (h:_) <- BC.lines <$> B.readFile (headPath spath (headTypeID @a Proxy) hid) + Just ref <- readRef s h + return $ Just $ Head hid $ wrappedLoad ref +loadHead Storage { stBacking = StorageMemory { memHeads = theads } } hid = liftIO $ do + fmap (Head hid . wrappedLoad) . lookup (headTypeID @a Proxy, hid) <$> readMVar theads + +reloadHead :: (HeadType a, MonadIO m) => Head a -> m (Maybe (Head a)) +reloadHead (Head hid (Stored (Ref st _) _)) = loadHead st hid + +storeHead :: forall a m. MonadIO m => HeadType a => Storage -> a -> m (Head a) +storeHead st obj = liftIO $ do + let tid = headTypeID @a Proxy + hid <- HeadID <$> U.nextRandom + stored <- wrappedStore st obj + case stBacking st of + StorageDir { dirPath = spath } -> do + Right () <- writeFileChecked (headPath spath tid hid) Nothing $ + showRef (storedRef stored) `B.append` BC.singleton '\n' + return () + StorageMemory { memHeads = theads } -> do + modifyMVar_ theads $ return . (((tid, hid), storedRef stored) :) + return $ Head hid stored + +replaceHead :: forall a m. (HeadType a, MonadIO m) => Head a -> Stored a -> m (Either (Maybe (Head a)) (Head a)) +replaceHead prev@(Head hid pobj) stored' = liftIO $ do + let st = headStorage prev + tid = headTypeID @a Proxy + stored <- copyStored st stored' + case stBacking st of + StorageDir { dirPath = spath } -> do + let filename = headPath spath tid hid + showRefL r = showRef r `B.append` BC.singleton '\n' + + writeFileChecked filename (Just $ showRefL $ headRef prev) (showRefL $ storedRef stored) >>= \case + Left Nothing -> return $ Left Nothing + Left (Just bs) -> do Just oref <- readRef st $ BC.takeWhile (/='\n') bs + return $ Left $ Just $ Head hid $ wrappedLoad oref + Right () -> return $ Right $ Head hid stored + + StorageMemory { memHeads = theads, memWatchers = twatch } -> do + res <- modifyMVar theads $ \hs -> do + ws <- map wlFun . filter ((==(tid, hid)) . wlHead) . wlList <$> readMVar twatch + return $ case partition ((==(tid, hid)) . fst) hs of + ([] , _ ) -> (hs, Left Nothing) + ((_, r):_, hs') | r == storedRef pobj -> (((tid, hid), storedRef stored) : hs', + Right (Head hid stored, ws)) + | otherwise -> (hs, Left $ Just $ Head hid $ wrappedLoad r) + case res of + Right (h, ws) -> mapM_ ($ headRef h) ws >> return (Right h) + Left x -> return $ Left x + +updateHead :: (HeadType a, MonadIO m) => Head a -> (Stored a -> m (Stored a, b)) -> m (Maybe (Head a), b) +updateHead h f = do + (o, x) <- f $ headStoredObject h + replaceHead h o >>= \case + Right h' -> return (Just h', x) + Left Nothing -> return (Nothing, x) + Left (Just h') -> updateHead h' f + +updateHead_ :: (HeadType a, MonadIO m) => Head a -> (Stored a -> m (Stored a)) -> m (Maybe (Head a)) +updateHead_ h = fmap fst . updateHead h . (fmap (,()) .) + + +data WatchedHead = forall a. WatchedHead Storage WatchID (MVar a) + +watchHead :: forall a. HeadType a => Head a -> (Head a -> IO ()) -> IO WatchedHead +watchHead h = watchHeadWith h id + +watchHeadWith :: forall a b. (HeadType a, Eq b) => Head a -> (Head a -> b) -> (b -> IO ()) -> IO WatchedHead +watchHeadWith oh@(Head hid (Stored (Ref st _) _)) sel cb = do + memo <- newEmptyMVar + let tid = headTypeID @a Proxy + addWatcher wl = (wl', WatchedHead st (wlNext wl) memo) + where wl' = wl { wlNext = wlNext wl + 1 + , wlList = WatchListItem + { wlID = wlNext wl + , wlHead = (tid, hid) + , wlFun = \r -> do + let x = sel $ Head hid $ wrappedLoad r + modifyMVar_ memo $ \prev -> do + when (x /= prev) $ cb x + return x + } : wlList wl + } + + watched <- case stBacking st of + StorageDir { dirPath = spath, dirWatchers = mvar } -> modifyMVar mvar $ \(ilist, wl) -> do + ilist' <- case lookup tid ilist of + Just _ -> return ilist + Nothing -> do + inotify <- initINotify + void $ addWatch inotify [Move] (BC.pack $ headTypePath spath tid) $ \case + MovedIn { filePath = fpath } | Just ihid <- HeadID <$> U.fromASCIIBytes fpath -> do + loadHead @a st ihid >>= \case + Just h -> mapM_ ($ headRef h) . map wlFun . filter ((== (tid, ihid)) . wlHead) . wlList . snd =<< readMVar mvar + Nothing -> return () + _ -> return () + return $ (tid, inotify) : ilist + return $ first (ilist',) $ addWatcher wl + + StorageMemory { memWatchers = mvar } -> modifyMVar mvar $ return . addWatcher + + cur <- sel . maybe oh id <$> reloadHead oh + cb cur + putMVar memo cur + + return watched + +unwatchHead :: WatchedHead -> IO () +unwatchHead (WatchedHead st wid _) = do + let delWatcher wl = wl { wlList = filter ((/=wid) . wlID) $ wlList wl } + case stBacking st of + StorageDir { dirWatchers = mvar } -> modifyMVar_ mvar $ return . second delWatcher + StorageMemory { memWatchers = mvar } -> modifyMVar_ mvar $ return . delWatcher + + +class Monad m => MonadStorage m where + getStorage :: m Storage + mstore :: Storable a => a -> m (Stored a) + + default mstore :: MonadIO m => Storable a => a -> m (Stored a) + mstore x = do + st <- getStorage + wrappedStore st x + +instance MonadIO m => MonadStorage (ReaderT Storage m) where + getStorage = ask + +instance MonadIO m => MonadStorage (ReaderT (Head a) m) where + getStorage = asks $ headStorage + + +class Storable a where + store' :: a -> Store + load' :: Load a + + store :: StorageCompleteness c => Storage' c -> a -> IO (Ref' c) + store st = evalStore st . store' + load :: Ref -> a + load = evalLoad load' + +class Storable a => ZeroStorable a where + fromZero :: Storage -> a + +data Store = StoreBlob ByteString + | StoreRec (forall c. StorageCompleteness c => Storage' c -> [IO [(ByteString, RecItem' c)]]) + | StoreZero + +evalStore :: StorageCompleteness c => Storage' c -> Store -> IO (Ref' c) +evalStore st = unsafeStoreObject st <=< evalStoreObject st + +evalStoreObject :: StorageCompleteness c => Storage' c -> Store -> IO (Object' c) +evalStoreObject _ (StoreBlob x) = return $ Blob x +evalStoreObject s (StoreRec f) = Rec . concat <$> sequence (f s) +evalStoreObject _ StoreZero = return ZeroObject + +newtype StoreRecM c a = StoreRecM (ReaderT (Storage' c) (Writer [IO [(ByteString, RecItem' c)]]) a) + deriving (Functor, Applicative, Monad) + +type StoreRec c = StoreRecM c () + +newtype Load a = Load (ReaderT (Ref, Object) (Either String) a) + deriving (Functor, Applicative, Alternative, Monad, MonadPlus, MonadError String) + +evalLoad :: Load a -> Ref -> a +evalLoad (Load f) ref = either (error {- TODO throw -} . ((BC.unpack (showRef ref) ++ ": ")++)) id $ runReaderT f (ref, lazyLoadObject ref) + +loadCurrentRef :: Load Ref +loadCurrentRef = Load $ asks fst + +loadCurrentObject :: Load Object +loadCurrentObject = Load $ asks snd + +newtype LoadRec a = LoadRec (ReaderT (Ref, [(ByteString, RecItem)]) (Either String) a) + deriving (Functor, Applicative, Alternative, Monad, MonadPlus, MonadError String) + +loadRecCurrentRef :: LoadRec Ref +loadRecCurrentRef = LoadRec $ asks fst + +loadRecItems :: LoadRec [(ByteString, RecItem)] +loadRecItems = LoadRec $ asks snd + + +instance Storable Object where + store' (Blob bs) = StoreBlob bs + store' (Rec xs) = StoreRec $ \st -> return $ do + Rec xs' <- copyObject st (Rec xs) + return xs' + store' ZeroObject = StoreZero + + load' = loadCurrentObject + + store st = unsafeStoreObject st <=< copyObject st + load = lazyLoadObject + +instance Storable ByteString where + store' = storeBlob + load' = loadBlob id + +instance Storable a => Storable [a] where + store' [] = storeZero + store' (x:xs) = storeRec $ do + storeRef "i" x + storeRef "n" xs + + load' = loadCurrentObject >>= \case + ZeroObject -> return [] + _ -> loadRec $ (:) + <$> loadRef "i" + <*> loadRef "n" + +instance Storable a => ZeroStorable [a] where + fromZero _ = [] + + +storeBlob :: ByteString -> Store +storeBlob = StoreBlob + +storeRec :: (forall c. StorageCompleteness c => StoreRec c) -> Store +storeRec sr = StoreRec $ do + let StoreRecM r = sr + execWriter . runReaderT r + +storeZero :: Store +storeZero = StoreZero + + +class StorableText a where + toText :: a -> Text + fromText :: MonadError String m => Text -> m a + +instance StorableText Text where + toText = id; fromText = return + +instance StorableText [Char] where + toText = T.pack; fromText = return . T.unpack + +instance StorableText MIME.Type where + toText = MIME.showType + fromText = maybe (throwError "Malformed MIME type") return . MIME.parseMIMEType + + +class StorableDate a where + toDate :: a -> ZonedTime + fromDate :: ZonedTime -> a + +instance StorableDate ZonedTime where + toDate = id; fromDate = id + +instance StorableDate UTCTime where + toDate = utcToZonedTime utc + fromDate = zonedTimeToUTC + +instance StorableDate Day where + toDate day = toDate $ UTCTime day 0 + fromDate = utctDay . fromDate + + +class StorableUUID a where + toUUID :: a -> UUID + fromUUID :: UUID -> a + +instance StorableUUID UUID where + toUUID = id; fromUUID = id + + +storeEmpty :: String -> StoreRec c +storeEmpty name = StoreRecM $ tell [return [(BC.pack name, RecEmpty)]] + +storeMbEmpty :: String -> Maybe () -> StoreRec c +storeMbEmpty name = maybe (return ()) (const $ storeEmpty name) + +storeInt :: Integral a => String -> a -> StoreRec c +storeInt name x = StoreRecM $ tell [return [(BC.pack name, RecInt $ toInteger x)]] + +storeMbInt :: Integral a => String -> Maybe a -> StoreRec c +storeMbInt name = maybe (return ()) (storeInt name) + +storeNum :: (Real a, Fractional a) => String -> a -> StoreRec c +storeNum name x = StoreRecM $ tell [return [(BC.pack name, RecNum $ toRational x)]] + +storeMbNum :: (Real a, Fractional a) => String -> Maybe a -> StoreRec c +storeMbNum name = maybe (return ()) (storeNum name) + +storeText :: StorableText a => String -> a -> StoreRec c +storeText name x = StoreRecM $ tell [return [(BC.pack name, RecText $ toText x)]] + +storeMbText :: StorableText a => String -> Maybe a -> StoreRec c +storeMbText name = maybe (return ()) (storeText name) + +storeBinary :: BA.ByteArrayAccess a => String -> a -> StoreRec c +storeBinary name x = StoreRecM $ tell [return [(BC.pack name, RecBinary $ BA.convert x)]] + +storeMbBinary :: BA.ByteArrayAccess a => String -> Maybe a -> StoreRec c +storeMbBinary name = maybe (return ()) (storeBinary name) + +storeDate :: StorableDate a => String -> a -> StoreRec c +storeDate name x = StoreRecM $ tell [return [(BC.pack name, RecDate $ toDate x)]] + +storeMbDate :: StorableDate a => String -> Maybe a -> StoreRec c +storeMbDate name = maybe (return ()) (storeDate name) + +storeUUID :: StorableUUID a => String -> a -> StoreRec c +storeUUID name x = StoreRecM $ tell [return [(BC.pack name, RecUUID $ toUUID x)]] + +storeMbUUID :: StorableUUID a => String -> Maybe a -> StoreRec c +storeMbUUID name = maybe (return ()) (storeUUID name) + +storeRef :: Storable a => StorageCompleteness c => String -> a -> StoreRec c +storeRef name x = StoreRecM $ do + s <- ask + tell $ (:[]) $ do + ref <- store s x + return [(BC.pack name, RecRef ref)] + +storeMbRef :: Storable a => StorageCompleteness c => String -> Maybe a -> StoreRec c +storeMbRef name = maybe (return ()) (storeRef name) + +storeRawRef :: StorageCompleteness c => String -> Ref -> StoreRec c +storeRawRef name ref = StoreRecM $ do + st <- ask + tell $ (:[]) $ do + ref' <- copyRef st ref + return [(BC.pack name, RecRef ref')] + +storeMbRawRef :: StorageCompleteness c => String -> Maybe Ref -> StoreRec c +storeMbRawRef name = maybe (return ()) (storeRawRef name) + +storeZRef :: (ZeroStorable a, StorageCompleteness c) => String -> a -> StoreRec c +storeZRef name x = StoreRecM $ do + s <- ask + tell $ (:[]) $ do + ref <- store s x + return $ if isZeroRef ref then [] + else [(BC.pack name, RecRef ref)] + + +loadBlob :: (ByteString -> a) -> Load a +loadBlob f = loadCurrentObject >>= \case + Blob x -> return $ f x + _ -> throwError "Expecting blob" + +loadRec :: LoadRec a -> Load a +loadRec (LoadRec lrec) = loadCurrentObject >>= \case + Rec rs -> do + ref <- loadCurrentRef + either throwError return $ runReaderT lrec (ref, rs) + _ -> throwError "Expecting record" + +loadZero :: a -> Load a +loadZero x = loadCurrentObject >>= \case + ZeroObject -> return x + _ -> throwError "Expecting zero" + + +loadEmpty :: String -> LoadRec () +loadEmpty name = maybe (throwError $ "Missing record item '"++name++"'") return =<< loadMbEmpty name + +loadMbEmpty :: String -> LoadRec (Maybe ()) +loadMbEmpty name = (lookup (BC.pack name) <$> loadRecItems) >>= \case + Nothing -> return Nothing + Just (RecEmpty) -> return (Just ()) + Just _ -> throwError $ "Expecting type int of record item '"++name++"'" + +loadInt :: Num a => String -> LoadRec a +loadInt name = maybe (throwError $ "Missing record item '"++name++"'") return =<< loadMbInt name + +loadMbInt :: Num a => String -> LoadRec (Maybe a) +loadMbInt name = (lookup (BC.pack name) <$> loadRecItems) >>= \case + Nothing -> return Nothing + Just (RecInt x) -> return (Just $ fromInteger x) + Just _ -> throwError $ "Expecting type int of record item '"++name++"'" + +loadNum :: (Real a, Fractional a) => String -> LoadRec a +loadNum name = maybe (throwError $ "Missing record item '"++name++"'") return =<< loadMbNum name + +loadMbNum :: (Real a, Fractional a) => String -> LoadRec (Maybe a) +loadMbNum name = (lookup (BC.pack name) <$> loadRecItems) >>= \case + Nothing -> return Nothing + Just (RecNum x) -> return (Just $ fromRational x) + Just _ -> throwError $ "Expecting type number of record item '"++name++"'" + +loadText :: StorableText a => String -> LoadRec a +loadText name = maybe (throwError $ "Missing record item '"++name++"'") return =<< loadMbText name + +loadMbText :: StorableText a => String -> LoadRec (Maybe a) +loadMbText name = (lookup (BC.pack name) <$> loadRecItems) >>= \case + Nothing -> return Nothing + Just (RecText x) -> Just <$> fromText x + Just _ -> throwError $ "Expecting type text of record item '"++name++"'" + +loadTexts :: StorableText a => String -> LoadRec [a] +loadTexts name = do + items <- map snd . filter ((BC.pack name ==) . fst) <$> loadRecItems + forM items $ \case RecText x -> fromText x + _ -> throwError $ "Expecting type text of record item '"++name++"'" + +loadBinary :: BA.ByteArray a => String -> LoadRec a +loadBinary name = maybe (throwError $ "Missing record item '"++name++"'") return =<< loadMbBinary name + +loadMbBinary :: BA.ByteArray a => String -> LoadRec (Maybe a) +loadMbBinary name = (lookup (BC.pack name) <$> loadRecItems) >>= \case + Nothing -> return Nothing + Just (RecBinary x) -> return $ Just $ BA.convert x + Just _ -> throwError $ "Expecting type binary of record item '"++name++"'" + +loadBinaries :: BA.ByteArray a => String -> LoadRec [a] +loadBinaries name = do + items <- map snd . filter ((BC.pack name ==) . fst) <$> loadRecItems + forM items $ \case RecBinary x -> return $ BA.convert x + _ -> throwError $ "Expecting type binary of record item '"++name++"'" + +loadDate :: StorableDate a => String -> LoadRec a +loadDate name = maybe (throwError $ "Missing record item '"++name++"'") return =<< loadMbDate name + +loadMbDate :: StorableDate a => String -> LoadRec (Maybe a) +loadMbDate name = (lookup (BC.pack name) <$> loadRecItems) >>= \case + Nothing -> return Nothing + Just (RecDate x) -> return $ Just $ fromDate x + Just _ -> throwError $ "Expecting type date of record item '"++name++"'" + +loadUUID :: StorableUUID a => String -> LoadRec a +loadUUID name = maybe (throwError $ "Missing record iteem '"++name++"'") return =<< loadMbUUID name + +loadMbUUID :: StorableUUID a => String -> LoadRec (Maybe a) +loadMbUUID name = (lookup (BC.pack name) <$> loadRecItems) >>= \case + Nothing -> return Nothing + Just (RecUUID x) -> return $ Just $ fromUUID x + Just _ -> throwError $ "Expecting type UUID of record item '"++name++"'" + +loadRawRef :: String -> LoadRec Ref +loadRawRef name = maybe (throwError $ "Missing record item '"++name++"'") return =<< loadMbRawRef name + +loadMbRawRef :: String -> LoadRec (Maybe Ref) +loadMbRawRef name = (lookup (BC.pack name) <$> loadRecItems) >>= \case + Nothing -> return Nothing + Just (RecRef x) -> return (Just x) + Just _ -> throwError $ "Expecting type ref of record item '"++name++"'" + +loadRawRefs :: String -> LoadRec [Ref] +loadRawRefs name = do + items <- map snd . filter ((BC.pack name ==) . fst) <$> loadRecItems + forM items $ \case RecRef x -> return x + _ -> throwError $ "Expecting type ref of record item '"++name++"'" + +loadRef :: Storable a => String -> LoadRec a +loadRef name = load <$> loadRawRef name + +loadMbRef :: Storable a => String -> LoadRec (Maybe a) +loadMbRef name = fmap load <$> loadMbRawRef name + +loadRefs :: Storable a => String -> LoadRec [a] +loadRefs name = map load <$> loadRawRefs name + +loadZRef :: ZeroStorable a => String -> LoadRec a +loadZRef name = loadMbRef name >>= \case + Nothing -> do Ref st _ <- loadRecCurrentRef + return $ fromZero st + Just x -> return x + + +type Stored a = Stored' Complete a + +instance Storable a => Storable (Stored a) where + store st = copyRef st . storedRef + store' (Stored _ x) = store' x + load' = Stored <$> loadCurrentRef <*> load' + +instance ZeroStorable a => ZeroStorable (Stored a) where + fromZero st = Stored (zeroRef st) $ fromZero st + +fromStored :: Stored a -> a +fromStored (Stored _ x) = x + +storedRef :: Stored a -> Ref +storedRef (Stored ref _) = ref + +wrappedStore :: MonadIO m => Storable a => Storage -> a -> m (Stored a) +wrappedStore st x = do ref <- liftIO $ store st x + return $ Stored ref x + +wrappedLoad :: Storable a => Ref -> Stored a +wrappedLoad ref = Stored ref (load ref) + +copyStored :: forall c c' m a. (StorageCompleteness c, StorageCompleteness c', MonadIO m) => + Storage' c' -> Stored' c a -> m (LoadResult c (Stored' c' a)) +copyStored st (Stored ref' x) = liftIO $ returnLoadResult . fmap (flip Stored x) <$> copyRef' st ref' + +-- |Passed function needs to preserve the object representation to be safe +unsafeMapStored :: (a -> b) -> Stored a -> Stored b +unsafeMapStored f (Stored ref x) = Stored ref (f x) + + +data StoreInfo = StoreInfo + { infoDate :: ZonedTime + , infoNote :: Maybe Text + } + deriving (Show) + +makeStoreInfo :: IO StoreInfo +makeStoreInfo = StoreInfo + <$> getZonedTime + <*> pure Nothing + +storeInfoRec :: StoreInfo -> StoreRec c +storeInfoRec info = do + storeDate "date" $ infoDate info + storeMbText "note" $ infoNote info + +loadInfoRec :: LoadRec StoreInfo +loadInfoRec = StoreInfo + <$> loadDate "date" + <*> loadMbText "note" + + +data History a = History StoreInfo (Stored a) (Maybe (StoredHistory a)) + deriving (Show) + +type StoredHistory a = Stored (History a) + +instance Storable a => Storable (History a) where + store' (History si x prev) = storeRec $ do + storeInfoRec si + storeMbRef "prev" prev + storeRef "item" x + + load' = loadRec $ History + <$> loadInfoRec + <*> loadRef "item" + <*> loadMbRef "prev" + +fromHistory :: StoredHistory a -> a +fromHistory = fromStored . storedFromHistory + +fromHistoryAt :: ZonedTime -> StoredHistory a -> Maybe a +fromHistoryAt zat = fmap (fromStored . snd) . listToMaybe . dropWhile ((at<) . zonedTimeToUTC . fst) . storedHistoryTimedList + where at = zonedTimeToUTC zat + +storedFromHistory :: StoredHistory a -> Stored a +storedFromHistory sh = let History _ item _ = fromStored sh + in item + +storedHistoryList :: StoredHistory a -> [Stored a] +storedHistoryList = map snd . storedHistoryTimedList + +storedHistoryTimedList :: StoredHistory a -> [(ZonedTime, Stored a)] +storedHistoryTimedList sh = let History hinfo item prev = fromStored sh + in (infoDate hinfo, item) : maybe [] storedHistoryTimedList prev + +beginHistory :: Storable a => Storage -> StoreInfo -> a -> IO (StoredHistory a) +beginHistory st si x = do sx <- wrappedStore st x + wrappedStore st $ History si sx Nothing + +modifyHistory :: Storable a => StoreInfo -> (a -> a) -> StoredHistory a -> IO (StoredHistory a) +modifyHistory si f prev@(Stored (Ref st _) _) = do + sx <- wrappedStore st $ f $ fromHistory prev + wrappedStore st $ History si sx (Just prev) + + +showRatio :: Rational -> String +showRatio r = case decimalRatio r of + Just (n, 1) -> show n + Just (n', d) -> let n = abs n' + in (if n' < 0 then "-" else "") ++ show (n `div` d) ++ "." ++ + (concatMap (show.(`mod` 10).snd) $ reverse $ takeWhile ((>1).fst) $ zip (iterate (`div` 10) d) (iterate (`div` 10) (n `mod` d))) + Nothing -> show (numerator r) ++ "/" ++ show (denominator r) + +decimalRatio :: Rational -> Maybe (Integer, Integer) +decimalRatio r = do + let n = numerator r + d = denominator r + (c2, d') = takeFactors 2 d + (c5, d'') = takeFactors 5 d' + guard $ d'' == 1 + let m = if c2 > c5 then 5 ^ (c2 - c5) + else 2 ^ (c5 - c2) + return (n * m, d * m) + +takeFactors :: Integer -> Integer -> (Integer, Integer) +takeFactors f n | n `mod` f == 0 = let (c, n') = takeFactors f (n `div` f) + in (c+1, n') + | otherwise = (0, n) + +parseRatio :: ByteString -> Maybe Rational +parseRatio bs = case BC.groupBy ((==) `on` isNumber) bs of + (m:xs) | m == BC.pack "-" -> negate <$> positive xs + xs -> positive xs + where positive = \case + [bx] -> fromInteger . fst <$> BC.readInteger bx + [bx, op, by] -> do + (x, _) <- BC.readInteger bx + (y, _) <- BC.readInteger by + case BC.unpack op of + "." -> return $ (x % 1) + (y % (10 ^ BC.length by)) + "/" -> return $ x % y + _ -> Nothing + _ -> Nothing diff --git a/src/Erebos/Storage/Internal.hs b/src/Erebos/Storage/Internal.hs new file mode 100644 index 0000000..47d344f --- /dev/null +++ b/src/Erebos/Storage/Internal.hs @@ -0,0 +1,273 @@ +module Erebos.Storage.Internal where + +import Codec.Compression.Zlib + +import Control.Arrow +import Control.Concurrent +import Control.DeepSeq +import Control.Exception +import Control.Monad +import Control.Monad.Identity + +import Crypto.Hash + +import Data.Bits +import Data.ByteArray (ByteArray, ByteArrayAccess, ScrubbedBytes) +import qualified Data.ByteArray as BA +import Data.ByteString (ByteString) +import qualified Data.ByteString as B +import qualified Data.ByteString.Char8 as BC +import qualified Data.ByteString.Lazy as BL +import Data.Char +import Data.Function +import Data.Hashable +import qualified Data.HashTable.IO as HT +import Data.Kind +import Data.List +import Data.Map (Map) +import qualified Data.Map as M +import Data.UUID (UUID) + +import Foreign.Storable (peek) + +import System.Directory +import System.FilePath +import System.INotify (INotify) +import System.IO +import System.IO.Error +import System.IO.Unsafe (unsafePerformIO) +import System.Posix.Files +import System.Posix.IO + + +data Storage' c = Storage + { stBacking :: StorageBacking c + , stParent :: Maybe (Storage' Identity) + , stRefGeneration :: MVar (HT.BasicHashTable RefDigest Generation) + , stRefRoots :: MVar (HT.BasicHashTable RefDigest [RefDigest]) + } + +instance Eq (Storage' c) where + (==) = (==) `on` (stBacking &&& stParent) + +instance Show (Storage' c) where + show st@(Storage { stBacking = StorageDir { dirPath = path }}) = "dir" ++ showParentStorage st ++ ":" ++ path + show st@(Storage { stBacking = StorageMemory {} }) = "mem" ++ showParentStorage st + +showParentStorage :: Storage' c -> String +showParentStorage Storage { stParent = Nothing } = "" +showParentStorage Storage { stParent = Just st } = "@" ++ show st + +data StorageBacking c + = StorageDir { dirPath :: FilePath + , dirWatchers :: MVar ([(HeadTypeID, INotify)], WatchList c) + } + | StorageMemory { memHeads :: MVar [((HeadTypeID, HeadID), Ref' c)] + , memObjs :: MVar (Map RefDigest BL.ByteString) + , memKeys :: MVar (Map RefDigest ScrubbedBytes) + , memWatchers :: MVar (WatchList c) + } + deriving (Eq) + +newtype WatchID = WatchID Int + deriving (Eq, Ord, Num) + +data WatchList c = WatchList + { wlNext :: WatchID + , wlList :: [WatchListItem c] + } + +data WatchListItem c = WatchListItem + { wlID :: WatchID + , wlHead :: (HeadTypeID, HeadID) + , wlFun :: Ref' c -> IO () + } + + +newtype RefDigest = RefDigest (Digest Blake2b_256) + deriving (Eq, Ord, NFData, ByteArrayAccess) + +instance Show RefDigest where + show = BC.unpack . showRefDigest + +data Ref' c = Ref (Storage' c) RefDigest + +instance Eq (Ref' c) where + Ref _ d1 == Ref _ d2 = d1 == d2 + +instance Show (Ref' c) where + show ref@(Ref st _) = show st ++ ":" ++ BC.unpack (showRef ref) + +instance ByteArrayAccess (Ref' c) where + length (Ref _ dgst) = BA.length dgst + withByteArray (Ref _ dgst) = BA.withByteArray dgst + +instance Hashable RefDigest where + hashWithSalt salt ref = salt `xor` unsafePerformIO (BA.withByteArray ref peek) + +instance Hashable (Ref' c) where + hashWithSalt salt ref = salt `xor` unsafePerformIO (BA.withByteArray ref peek) + +refStorage :: Ref' c -> Storage' c +refStorage (Ref st _) = st + +refDigest :: Ref' c -> RefDigest +refDigest (Ref _ dgst) = dgst + +showRef :: Ref' c -> ByteString +showRef = showRefDigest . refDigest + +showRefDigestParts :: RefDigest -> (ByteString, ByteString) +showRefDigestParts x = (BC.pack "blake2", showHex x) + +showRefDigest :: RefDigest -> ByteString +showRefDigest = showRefDigestParts >>> \(alg, hex) -> alg <> BC.pack "#" <> hex + +readRefDigest :: ByteString -> Maybe RefDigest +readRefDigest x = case BC.split '#' x of + [alg, dgst] | BA.convert alg == BC.pack "blake2" -> + refDigestFromByteString =<< readHex @ByteString dgst + _ -> Nothing + +refDigestFromByteString :: ByteArrayAccess ba => ba -> Maybe RefDigest +refDigestFromByteString = fmap RefDigest . digestFromByteString + +hashToRefDigest :: BL.ByteString -> RefDigest +hashToRefDigest = RefDigest . hashFinalize . hashUpdates hashInit . BL.toChunks + +showHex :: ByteArrayAccess ba => ba -> ByteString +showHex = B.concat . map showHexByte . BA.unpack + where showHexChar x | x < 10 = x + o '0' + | otherwise = x + o 'a' - 10 + showHexByte x = B.pack [ showHexChar (x `div` 16), showHexChar (x `mod` 16) ] + o = fromIntegral . ord + +readHex :: ByteArray ba => ByteString -> Maybe ba +readHex = return . BA.concat <=< readHex' + where readHex' bs | B.null bs = Just [] + readHex' bs = do (bx, bs') <- B.uncons bs + (by, bs'') <- B.uncons bs' + x <- hexDigit bx + y <- hexDigit by + (B.singleton (x * 16 + y) :) <$> readHex' bs'' + hexDigit x | x >= o '0' && x <= o '9' = Just $ x - o '0' + | x >= o 'a' && x <= o 'z' = Just $ x - o 'a' + 10 + | otherwise = Nothing + o = fromIntegral . ord + + +newtype Generation = Generation Int + deriving (Eq, Show) + +data Head' c a = Head HeadID (Stored' c a) + deriving (Eq, Show) + +newtype HeadID = HeadID UUID + deriving (Eq, Ord, Show) + +newtype HeadTypeID = HeadTypeID UUID + deriving (Eq, Ord) + +data Stored' c a = Stored (Ref' c) a + deriving (Show) + +instance Eq (Stored' c a) where + Stored r1 _ == Stored r2 _ = refDigest r1 == refDigest r2 + +instance Ord (Stored' c a) where + compare (Stored r1 _) (Stored r2 _) = compare (refDigest r1) (refDigest r2) + +storedStorage :: Stored' c a -> Storage' c +storedStorage (Stored (Ref st _) _) = st + + +type Complete = Identity +type Partial = Either RefDigest + +class (Traversable compl, Monad compl) => StorageCompleteness compl where + type LoadResult compl a :: Type + returnLoadResult :: compl a -> LoadResult compl a + ioLoadBytes :: Ref' compl -> IO (compl BL.ByteString) + +instance StorageCompleteness Complete where + type LoadResult Complete a = a + returnLoadResult = runIdentity + ioLoadBytes ref@(Ref st dgst) = maybe (error $ "Ref not found in complete storage: "++show ref) Identity + <$> ioLoadBytesFromStorage st dgst + +instance StorageCompleteness Partial where + type LoadResult Partial a = Either RefDigest a + returnLoadResult = id + ioLoadBytes (Ref st dgst) = maybe (Left dgst) Right <$> ioLoadBytesFromStorage st dgst + +unsafeStoreRawBytes :: Storage' c -> BL.ByteString -> IO (Ref' c) +unsafeStoreRawBytes st raw = do + let dgst = hashToRefDigest raw + case stBacking st of + StorageDir { dirPath = sdir } -> writeFileOnce (refPath sdir dgst) $ compress raw + StorageMemory { memObjs = tobjs } -> + dgst `deepseq` -- the TVar may be accessed when evaluating the data to be written + modifyMVar_ tobjs (return . M.insert dgst raw) + return $ Ref st dgst + +ioLoadBytesFromStorage :: Storage' c -> RefDigest -> IO (Maybe BL.ByteString) +ioLoadBytesFromStorage st dgst = loadCurrent st >>= + \case Just bytes -> return $ Just bytes + Nothing | Just parent <- stParent st -> ioLoadBytesFromStorage parent dgst + | otherwise -> return Nothing + where loadCurrent Storage { stBacking = StorageDir { dirPath = spath } } = handleJust (guard . isDoesNotExistError) (const $ return Nothing) $ + Just . decompress . BL.fromChunks . (:[]) <$> (B.readFile $ refPath spath dgst) + loadCurrent Storage { stBacking = StorageMemory { memObjs = tobjs } } = M.lookup dgst <$> readMVar tobjs + +refPath :: FilePath -> RefDigest -> FilePath +refPath spath rdgst = intercalate "/" [spath, "objects", BC.unpack alg, pref, rest] + where (alg, dgst) = showRefDigestParts rdgst + (pref, rest) = splitAt 2 $ BC.unpack dgst + + +openLockFile :: FilePath -> IO Handle +openLockFile path = do + createDirectoryIfMissing True (takeDirectory path) + fd <- retry 10 $ + openFd path WriteOnly (Just $ unionFileModes ownerReadMode ownerWriteMode) (defaultFileFlags { exclusive = True }) + fdToHandle fd + where + retry :: Int -> IO a -> IO a + retry 0 act = act + retry n act = catchJust (\e -> if isAlreadyExistsError e then Just () else Nothing) + act (\_ -> threadDelay (100 * 1000) >> retry (n - 1) act) + +writeFileOnce :: FilePath -> BL.ByteString -> IO () +writeFileOnce file content = bracket (openLockFile locked) + hClose $ \h -> do + fileExist file >>= \case + True -> removeLink locked + False -> do BL.hPut h content + hFlush h + rename locked file + where locked = file ++ ".lock" + +writeFileChecked :: FilePath -> Maybe ByteString -> ByteString -> IO (Either (Maybe ByteString) ()) +writeFileChecked file prev content = bracket (openLockFile locked) + hClose $ \h -> do + (prev,) <$> fileExist file >>= \case + (Nothing, True) -> do + current <- B.readFile file + removeLink locked + return $ Left $ Just current + (Nothing, False) -> do B.hPut h content + hFlush h + rename locked file + return $ Right () + (Just expected, True) -> do + current <- B.readFile file + if current == expected then do B.hPut h content + hFlush h + rename locked file + return $ return () + else do removeLink locked + return $ Left $ Just current + (Just _, False) -> do + removeLink locked + return $ Left Nothing + where locked = file ++ ".lock" diff --git a/src/Erebos/Storage/Key.hs b/src/Erebos/Storage/Key.hs new file mode 100644 index 0000000..4a97976 --- /dev/null +++ b/src/Erebos/Storage/Key.hs @@ -0,0 +1,84 @@ +module Erebos.Storage.Key ( + KeyPair(..), + storeKey, loadKey, loadKeyMb, + moveKeys, +) where + +import Control.Concurrent.MVar +import Control.Monad +import Control.Monad.Except + +import Data.ByteArray +import qualified Data.ByteString.Char8 as BC +import qualified Data.ByteString.Lazy as BL +import qualified Data.Map as M + +import System.Directory +import System.FilePath +import System.IO.Error + +import Erebos.Storage +import Erebos.Storage.Internal + +class Storable pub => KeyPair sec pub | sec -> pub, pub -> sec where + generateKeys :: Storage -> IO (sec, Stored pub) + keyGetPublic :: sec -> Stored pub + keyGetData :: sec -> ScrubbedBytes + keyFromData :: ScrubbedBytes -> Stored pub -> Maybe sec + + +keyFilePath :: KeyPair sec pub => FilePath -> Stored pub -> FilePath +keyFilePath sdir pkey = sdir </> "keys" </> (BC.unpack $ showRef $ storedRef pkey) + +storeKey :: KeyPair sec pub => sec -> IO () +storeKey key = do + let spub = keyGetPublic key + case stBacking $ storedStorage spub of + StorageDir { dirPath = dir } -> writeFileOnce (keyFilePath dir spub) (BL.fromStrict $ convert $ keyGetData key) + StorageMemory { memKeys = kstore } -> modifyMVar_ kstore $ return . M.insert (refDigest $ storedRef spub) (keyGetData key) + +loadKey :: (KeyPair sec pub, MonadIO m, MonadError String m) => Stored pub -> m sec +loadKey pub = maybe (throwError $ "secret key not found for " <> show (storedRef pub)) return =<< loadKeyMb pub + +loadKeyMb :: (KeyPair sec pub, MonadIO m) => Stored pub -> m (Maybe sec) +loadKeyMb spub = liftIO $ run $ storedStorage spub + where + run st = tryOneLevel (stBacking st) >>= \case + key@Just {} -> return key + Nothing | Just parent <- stParent st -> run parent + | otherwise -> return Nothing + tryOneLevel = \case + StorageDir { dirPath = dir } -> tryIOError (BC.readFile (keyFilePath dir spub)) >>= \case + Right kdata -> return $ keyFromData (convert kdata) spub + Left _ -> return Nothing + StorageMemory { memKeys = kstore } -> (flip keyFromData spub <=< M.lookup (refDigest $ storedRef spub)) <$> readMVar kstore + +moveKeys :: MonadIO m => Storage -> Storage -> m () +moveKeys from to = liftIO $ do + case (stBacking from, stBacking to) of + (StorageDir { dirPath = fromPath }, StorageDir { dirPath = toPath }) -> do + files <- listDirectory (fromPath </> "keys") + forM_ files $ \file -> do + renameFile (fromPath </> "keys" </> file) (toPath </> "keys" </> file) + + (StorageDir { dirPath = fromPath }, StorageMemory { memKeys = toKeys }) -> do + let move m file + | Just dgst <- readRefDigest (BC.pack file) = do + let path = fromPath </> "keys" </> file + key <- convert <$> BC.readFile path + removeFile path + return $ M.insert dgst key m + | otherwise = return m + files <- listDirectory (fromPath </> "keys") + modifyMVar_ toKeys $ \keys -> foldM move keys files + + (StorageMemory { memKeys = fromKeys }, StorageDir { dirPath = toPath }) -> do + modifyMVar_ fromKeys $ \keys -> do + forM_ (M.assocs keys) $ \(dgst, key) -> + writeFileOnce (toPath </> "keys" </> (BC.unpack $ showRefDigest dgst)) (BL.fromStrict $ convert key) + return M.empty + + (StorageMemory { memKeys = fromKeys }, StorageMemory { memKeys = toKeys }) -> do + modifyMVar_ fromKeys $ \fkeys -> do + modifyMVar_ toKeys $ return . M.union fkeys + return M.empty diff --git a/src/Erebos/Storage/List.hs b/src/Erebos/Storage/List.hs new file mode 100644 index 0000000..f0f8786 --- /dev/null +++ b/src/Erebos/Storage/List.hs @@ -0,0 +1,154 @@ +module Erebos.Storage.List ( + StoredList, + emptySList, fromSList, storedFromSList, + slistAdd, slistAddS, + -- TODO slistInsert, slistInsertS, + slistRemove, slistReplace, slistReplaceS, + -- TODO mapFromSList, updateOld, + + -- TODO StoreUpdate(..), + -- TODO withStoredListItem, withStoredListItemS, +) where + +import Data.List +import Data.Maybe +import qualified Data.Set as S + +import Erebos.Storage +import Erebos.Storage.Internal +import Erebos.Storage.Merge + +data List a = ListNil + | ListItem { listPrev :: [StoredList a] + , listItem :: Maybe (Stored a) + , listRemove :: Maybe (Stored (List a)) + } + +type StoredList a = Stored (List a) + +instance Storable a => Storable (List a) where + store' ListNil = storeZero + store' x@ListItem {} = storeRec $ do + mapM_ (storeRef "PREV") $ listPrev x + mapM_ (storeRef "item") $ listItem x + mapM_ (storeRef "remove") $ listRemove x + + load' = loadCurrentObject >>= \case + ZeroObject -> return ListNil + _ -> loadRec $ ListItem <$> loadRefs "PREV" + <*> loadMbRef "item" + <*> loadMbRef "remove" + +instance Storable a => ZeroStorable (List a) where + fromZero _ = ListNil + + +emptySList :: Storable a => Storage -> IO (StoredList a) +emptySList st = wrappedStore st ListNil + +groupsFromSLists :: forall a. Storable a => StoredList a -> [[Stored a]] +groupsFromSLists = helperSelect S.empty . (:[]) + where + helperSelect :: S.Set (StoredList a) -> [StoredList a] -> [[Stored a]] + helperSelect rs xxs | x:xs <- sort $ filterRemoved rs xxs = helper rs x xs + | otherwise = [] + + helper :: S.Set (StoredList a) -> StoredList a -> [StoredList a] -> [[Stored a]] + helper rs x xs + | ListNil <- fromStored x + = [] + + | Just rm <- listRemove (fromStored x) + , ans <- ancestors [x] + , (other, collision) <- partition (S.null . S.intersection ans . ancestors . (:[])) xs + , cont <- helperSelect (rs `S.union` ancestors [rm]) $ concatMap (listPrev . fromStored) (x : collision) ++ other + = case catMaybes $ map (listItem . fromStored) (x : collision) of + [] -> cont + xis -> xis : cont + + | otherwise = case listItem (fromStored x) of + Nothing -> helperSelect rs $ listPrev (fromStored x) ++ xs + Just xi -> [xi] : (helperSelect rs $ listPrev (fromStored x) ++ xs) + + filterRemoved :: S.Set (StoredList a) -> [StoredList a] -> [StoredList a] + filterRemoved rs = filter (S.null . S.intersection rs . ancestors . (:[])) + +fromSList :: Mergeable a => StoredList (Component a) -> [a] +fromSList = map merge . groupsFromSLists + +storedFromSList :: (Mergeable a, Storable a) => StoredList (Component a) -> IO [Stored a] +storedFromSList = mapM storeMerge . groupsFromSLists + +slistAdd :: Storable a => a -> StoredList a -> IO (StoredList a) +slistAdd x prev@(Stored (Ref st _) _) = do + sx <- wrappedStore st x + slistAddS sx prev + +slistAddS :: Storable a => Stored a -> StoredList a -> IO (StoredList a) +slistAddS sx prev@(Stored (Ref st _) _) = wrappedStore st (ListItem [prev] (Just sx) Nothing) + +{- TODO +slistInsert :: Storable a => Stored a -> a -> StoredList a -> IO (StoredList a) +slistInsert after x prev@(Stored (Ref st _) _) = do + sx <- wrappedStore st x + slistInsertS after sx prev + +slistInsertS :: Storable a => Stored a -> Stored a -> StoredList a -> IO (StoredList a) +slistInsertS after sx prev@(Stored (Ref st _) _) = wrappedStore st $ ListItem Nothing (findSListRef after prev) (Just sx) prev +-} + +slistRemove :: Storable a => Stored a -> StoredList a -> IO (StoredList a) +slistRemove rm prev@(Stored (Ref st _) _) = wrappedStore st $ ListItem [prev] Nothing (findSListRef rm prev) + +slistReplace :: Storable a => Stored a -> a -> StoredList a -> IO (StoredList a) +slistReplace rm x prev@(Stored (Ref st _) _) = do + sx <- wrappedStore st x + slistReplaceS rm sx prev + +slistReplaceS :: Storable a => Stored a -> Stored a -> StoredList a -> IO (StoredList a) +slistReplaceS rm sx prev@(Stored (Ref st _) _) = wrappedStore st $ ListItem [prev] (Just sx) (findSListRef rm prev) + +findSListRef :: Stored a -> StoredList a -> Maybe (StoredList a) +findSListRef _ (Stored _ ListNil) = Nothing +findSListRef x cur | listItem (fromStored cur) == Just x = Just cur + | otherwise = listToMaybe $ catMaybes $ map (findSListRef x) $ listPrev $ fromStored cur + +{- TODO +mapFromSList :: Storable a => StoredList a -> Map RefDigest (Stored a) +mapFromSList list = helper list M.empty + where helper :: Storable a => StoredList a -> Map RefDigest (Stored a) -> Map RefDigest (Stored a) + helper (Stored _ ListNil) cur = cur + helper (Stored _ (ListItem (Just rref) _ (Just x) rest)) cur = + let rxref = case load rref of + ListItem _ _ (Just rx) _ -> sameType rx x $ storedRef rx + _ -> error "mapFromSList: malformed list" + in helper rest $ case M.lookup (refDigest $ storedRef x) cur of + Nothing -> M.insert (refDigest rxref) x cur + Just x' -> M.insert (refDigest rxref) x' cur + helper (Stored _ (ListItem _ _ _ rest)) cur = helper rest cur + sameType :: a -> a -> b -> b + sameType _ _ x = x + +updateOld :: Map RefDigest (Stored a) -> Stored a -> Stored a +updateOld m x = fromMaybe x $ M.lookup (refDigest $ storedRef x) m + + +data StoreUpdate a = StoreKeep + | StoreReplace a + | StoreRemove + +withStoredListItem :: (Storable a) => (a -> Bool) -> StoredList a -> (a -> IO (StoreUpdate a)) -> IO (StoredList a) +withStoredListItem p list f = withStoredListItemS (p . fromStored) list (suMap (wrappedStore $ storedStorage list) <=< f . fromStored) + where suMap :: Monad m => (a -> m b) -> StoreUpdate a -> m (StoreUpdate b) + suMap _ StoreKeep = return StoreKeep + suMap g (StoreReplace x) = return . StoreReplace =<< g x + suMap _ StoreRemove = return StoreRemove + +withStoredListItemS :: (Storable a) => (Stored a -> Bool) -> StoredList a -> (Stored a -> IO (StoreUpdate (Stored a))) -> IO (StoredList a) +withStoredListItemS p list f = do + case find p $ storedFromSList list of + Just sx -> f sx >>= \case StoreKeep -> return list + StoreReplace nx -> slistReplaceS sx nx list + StoreRemove -> slistRemove sx list + Nothing -> return list +-} diff --git a/src/Erebos/Storage/Merge.hs b/src/Erebos/Storage/Merge.hs new file mode 100644 index 0000000..7234b87 --- /dev/null +++ b/src/Erebos/Storage/Merge.hs @@ -0,0 +1,156 @@ +module Erebos.Storage.Merge ( + Mergeable(..), + merge, storeMerge, + + Generation, + showGeneration, + compareGeneration, generationMax, + storedGeneration, + + generations, + ancestors, + precedes, + filterAncestors, + storedRoots, + walkAncestors, + + findProperty, + findPropertyFirst, +) where + +import Control.Concurrent.MVar + +import Data.ByteString.Char8 qualified as BC +import Data.HashTable.IO qualified as HT +import Data.Kind +import Data.List +import Data.Maybe +import Data.Set (Set) +import Data.Set qualified as S + +import System.IO.Unsafe (unsafePerformIO) + +import Erebos.Storage +import Erebos.Storage.Internal +import Erebos.Util + +class Storable (Component a) => Mergeable a where + type Component a :: Type + mergeSorted :: [Stored (Component a)] -> a + toComponents :: a -> [Stored (Component a)] + +instance Mergeable [Stored Object] where + type Component [Stored Object] = Object + mergeSorted = id + toComponents = id + +merge :: Mergeable a => [Stored (Component a)] -> a +merge [] = error "merge: empty list" +merge xs = mergeSorted $ filterAncestors xs + +storeMerge :: (Mergeable a, Storable a) => [Stored (Component a)] -> IO (Stored a) +storeMerge [] = error "merge: empty list" +storeMerge xs@(Stored ref _ : _) = wrappedStore (refStorage ref) $ mergeSorted $ filterAncestors xs + +previous :: Storable a => Stored a -> [Stored a] +previous (Stored ref _) = case load ref of + Rec items | Just (RecRef dref) <- lookup (BC.pack "SDATA") items + , Rec ditems <- load dref -> + map wrappedLoad $ catMaybes $ map (\case RecRef r -> Just r; _ -> Nothing) $ + map snd $ filter ((`elem` [ BC.pack "SPREV", BC.pack "SBASE" ]) . fst) ditems + + | otherwise -> + map wrappedLoad $ catMaybes $ map (\case RecRef r -> Just r; _ -> Nothing) $ + map snd $ filter ((`elem` [ BC.pack "PREV", BC.pack "BASE" ]) . fst) items + _ -> [] + + +nextGeneration :: [Generation] -> Generation +nextGeneration = foldl' helper (Generation 0) + where helper (Generation c) (Generation n) | c <= n = Generation (n + 1) + | otherwise = Generation c + +showGeneration :: Generation -> String +showGeneration (Generation x) = show x + +compareGeneration :: Generation -> Generation -> Maybe Ordering +compareGeneration (Generation x) (Generation y) = Just $ compare x y + +generationMax :: Storable a => [Stored a] -> Maybe (Stored a) +generationMax (x : xs) = Just $ snd $ foldl' helper (storedGeneration x, x) xs + where helper (mg, mx) y = let yg = storedGeneration y + in case compareGeneration mg yg of + Just LT -> (yg, y) + _ -> (mg, mx) +generationMax [] = Nothing + +storedGeneration :: Storable a => Stored a -> Generation +storedGeneration x = + unsafePerformIO $ withMVar (stRefGeneration $ refStorage $ storedRef x) $ \ht -> do + let doLookup y = HT.lookup ht (refDigest $ storedRef y) >>= \case + Just gen -> return gen + Nothing -> do + gen <- nextGeneration <$> mapM doLookup (previous y) + HT.insert ht (refDigest $ storedRef y) gen + return gen + doLookup x + + +generations :: Storable a => [Stored a] -> [Set (Stored a)] +generations = unfoldr gen . (,S.empty) + where gen (hs, cur) = case filter (`S.notMember` cur) $ previous =<< hs of + [] -> Nothing + added -> let next = foldr S.insert cur added + in Just (next, (added, next)) + +ancestors :: Storable a => [Stored a] -> Set (Stored a) +ancestors = last . (S.empty:) . generations + +precedes :: Storable a => Stored a -> Stored a -> Bool +precedes x y = not $ x `elem` filterAncestors [x, y] + +filterAncestors :: Storable a => [Stored a] -> [Stored a] +filterAncestors [x] = [x] +filterAncestors xs = let xs' = uniq $ sort xs + in helper xs' xs' + where helper remains walk = case generationMax walk of + Just x -> let px = previous x + remains' = filter (\r -> all (/=r) px) remains + in helper remains' $ uniq $ sort (px ++ filter (/=x) walk) + Nothing -> remains + +storedRoots :: Storable a => Stored a -> [Stored a] +storedRoots x = do + let st = refStorage $ storedRef x + unsafePerformIO $ withMVar (stRefRoots st) $ \ht -> do + let doLookup y = HT.lookup ht (refDigest $ storedRef y) >>= \case + Just roots -> return roots + Nothing -> do + roots <- case previous y of + [] -> return [refDigest $ storedRef y] + ps -> map (refDigest . storedRef) . filterAncestors . map (wrappedLoad @Object . Ref st) . concat <$> mapM doLookup ps + HT.insert ht (refDigest $ storedRef y) roots + return roots + map (wrappedLoad . Ref st) <$> doLookup x + +walkAncestors :: (Storable a, Monoid m) => (Stored a -> m) -> [Stored a] -> m +walkAncestors f = helper . sortBy cmp + where + helper (x : y : xs) | x == y = helper (x : xs) + helper (x : xs) = f x <> helper (mergeBy cmp (sortBy cmp (previous x)) xs) + helper [] = mempty + + cmp x y = case compareGeneration (storedGeneration x) (storedGeneration y) of + Just LT -> GT + Just GT -> LT + _ -> compare x y + +findProperty :: forall a b. Storable a => (a -> Maybe b) -> [Stored a] -> [b] +findProperty sel = map (fromJust . sel . fromStored) . filterAncestors . (findPropHeads sel =<<) + +findPropertyFirst :: forall a b. Storable a => (a -> Maybe b) -> [Stored a] -> Maybe b +findPropertyFirst sel = fmap (fromJust . sel . fromStored) . listToMaybe . filterAncestors . (findPropHeads sel =<<) + +findPropHeads :: forall a b. Storable a => (a -> Maybe b) -> Stored a -> [Stored a] +findPropHeads sel sobj | Just _ <- sel $ fromStored sobj = [sobj] + | otherwise = findPropHeads sel =<< previous sobj diff --git a/src/Erebos/Sync.hs b/src/Erebos/Sync.hs new file mode 100644 index 0000000..04b5f11 --- /dev/null +++ b/src/Erebos/Sync.hs @@ -0,0 +1,46 @@ +module Erebos.Sync ( + SyncService(..), +) where + +import Control.Monad +import Control.Monad.Reader + +import Data.List + +import Erebos.Identity +import Erebos.Service +import Erebos.State +import Erebos.Storage +import Erebos.Storage.Merge + +data SyncService = SyncPacket (Stored SharedState) + +instance Service SyncService where + serviceID _ = mkServiceID "a4f538d0-4e50-4082-8e10-7e3ec2af175d" + + serviceHandler packet = do + let SyncPacket added = fromStored packet + pid <- asks svcPeerIdentity + self <- svcSelf + when (finalOwner pid `sameIdentity` finalOwner self) $ do + updateLocalHead_ $ \ls -> do + let current = sort $ lsShared $ fromStored ls + updated = filterAncestors (added : current) + if current /= updated + then mstore (fromStored ls) { lsShared = updated } + else return ls + + serviceNewPeer = notifyPeer . lsShared . fromStored =<< svcGetLocal + serviceStorageWatchers _ = (:[]) $ SomeStorageWatcher (lsShared . fromStored) notifyPeer + +instance Storable SyncService where + store' (SyncPacket smsg) = store' smsg + load' = SyncPacket <$> load' + +notifyPeer :: [Stored SharedState] -> ServiceHandler SyncService () +notifyPeer shared = do + pid <- asks svcPeerIdentity + self <- svcSelf + when (finalOwner pid `sameIdentity` finalOwner self) $ do + forM_ shared $ \sh -> + replyStoredRef =<< (mstore . SyncPacket) sh diff --git a/src/Erebos/Util.hs b/src/Erebos/Util.hs new file mode 100644 index 0000000..ffca9c7 --- /dev/null +++ b/src/Erebos/Util.hs @@ -0,0 +1,37 @@ +module Erebos.Util where + +uniq :: Eq a => [a] -> [a] +uniq (x:y:xs) | x == y = uniq (x:xs) + | otherwise = x : uniq (y:xs) +uniq xs = xs + +mergeBy :: (a -> a -> Ordering) -> [a] -> [a] -> [a] +mergeBy cmp (x : xs) (y : ys) = case cmp x y of + LT -> x : mergeBy cmp xs (y : ys) + EQ -> x : y : mergeBy cmp xs ys + GT -> y : mergeBy cmp (x : xs) ys +mergeBy _ xs [] = xs +mergeBy _ [] ys = ys + +mergeUniqBy :: (a -> a -> Ordering) -> [a] -> [a] -> [a] +mergeUniqBy cmp (x : xs) (y : ys) = case cmp x y of + LT -> x : mergeBy cmp xs (y : ys) + EQ -> x : mergeBy cmp xs ys + GT -> y : mergeBy cmp (x : xs) ys +mergeUniqBy _ xs [] = xs +mergeUniqBy _ [] ys = ys + +mergeUniq :: Ord a => [a] -> [a] -> [a] +mergeUniq = mergeUniqBy compare + +diffSorted :: Ord a => [a] -> [a] -> [a] +diffSorted (x:xs) (y:ys) | x < y = x : diffSorted xs (y:ys) + | x > y = diffSorted (x:xs) ys + | otherwise = diffSorted xs (y:ys) +diffSorted xs _ = xs + +intersectsSorted :: Ord a => [a] -> [a] -> Bool +intersectsSorted (x:xs) (y:ys) | x < y = intersectsSorted xs (y:ys) + | x > y = intersectsSorted (x:xs) ys + | otherwise = True +intersectsSorted _ _ = False diff --git a/attach.test b/test/attach.test index 33a1483..33a1483 100644 --- a/attach.test +++ b/test/attach.test diff --git a/contact.test b/test/contact.test index 438aa1f..438aa1f 100644 --- a/contact.test +++ b/test/contact.test diff --git a/discovery.test b/test/discovery.test index 2aaaf24..2aaaf24 100644 --- a/discovery.test +++ b/test/discovery.test diff --git a/message.test b/test/message.test index 307f11a..307f11a 100644 --- a/message.test +++ b/test/message.test diff --git a/storage.test b/test/storage.test index 9bf468e..9bf468e 100644 --- a/storage.test +++ b/test/storage.test diff --git a/sync.test b/test/sync.test index ea9595d..ea9595d 100644 --- a/sync.test +++ b/test/sync.test |