diff options
51 files changed, 9810 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..8c573be --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +.ghc.environment.* +cabal.project.local +dist-newstyle/ +.erebos +.test/ +.minici/ diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..3d26fab --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,35 @@ +# Revision history for erebos + +## 0.1.5 -- 2024-07-16 + +* Public chatrooms for multiple participants +* Send keep-alive packets on idle connection +* Windows support + +## 0.1.4 -- 2024-06-11 + +* Added `/conversations` command to list and select conversations +* Added `/details` command for info about selected conversation +* Handle peer reconnection after its restart +* Support non-interactive mode without tty + +## 0.1.3 -- 2024-05-05 + +* Enable/disable network services by command-line parameters +* Tab-completion of command name +* Implemented streams in network protocol +* Compatibility with GHC up to 9.8 + +## 0.1.2 -- 2024-02-20 + +* Compatibility with GHC up to 9.6 +* Pruned unnecessary dependencies and fixed bounds + +## 0.1.1 -- 2024-02-18 + +* Added build flag to enable/disable ICE support with pjproject. +* Added `-V` command-line switch to show version. + +## 0.1.0 -- 2024-02-10 + +* First version. @@ -0,0 +1,30 @@ +Copyright (c) 2019, Roman Smrž + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + + * Neither the name of Roman Smrž nor the names of other + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..9535aab --- /dev/null +++ b/README.md @@ -0,0 +1,218 @@ +Erebos +====== + +The erebos binary provides simple CLI interface to the decentralized Erebos +messaging service. Local identity is created on the first run. Protocol and +services specification is being written at: + +[https://erebosprotocol.net/spec](https://erebosprotocol.net/spec) + +Erebos identity is based on locally stored cryptographic keys, all +communication is end-to-end encrypted. Multiple devices can be attached to the +same identity, after which they function interchangeably, without any one being +in any way "primary"; messages and other state data are then synchronized +automatically whenever the devices are able to connect with one another. + +Status +------ + +This is experimental implementation of yet unfinished specification, so +changes, especially in the library API, are expected. Storage format and +network protocol should generally remain backward compatible, with their +respective versions to be increased in case of incompatible changes, to allow +for interoperability even in that case. + +Usage +----- + +On the first run, local identity will be created for this device based on +interactive prompts for: + +`Name:` name of the user/owner, which will be shared among all devices +belonging to the same user; keep empty when initializing device that is going +to be attached to already existing identity on other device. + +`Device:` name describing current device, can be empty. + +After the initial setup, the erebos tool presents interactive prompt for +messages and commands. All commands start with the slash (`/`) character, +followed by command name and parameters (if any) separated by spaces. When +a conversation is selected, message to send there is entered directly on +the command prompt. + +The session can be terminated either by end-of-input (typically `Ctrl-d`) or +using the `/quit` command. + +### Example + +Start `erebos` CLI and create new identity: +``` +Name: Some Name +Device: First device +Some Name / First device +> +``` + +Add public peer: +``` +> /peer-add-public +[1] PEER NEW <unnamed> [37.221.243.57 29665] +[1] PEER UPD discovery1.erebosprotocol.net [37.221.243.57 29665] +``` + +Select the peer and send it a message, the public server just responds with +automatic echo message: +``` +> /1 +discovery1.erebosprotocol.net> hello +[18:55] Some Name: hello +[18:55] discovery1.erebosprotocol.net: Echo: hello +``` + +List chatrooms known to the peers: +``` +> /chatrooms +[1] Test chatroom +[2] Second test chatroom +``` + +Enter a chatroom and send a message there: +``` +> /1 +Test chatroom> Hi +Test chatroom [19:03] Some Name: Hi +``` + +### Messaging + +`/peers` +: List peers with direct network connection. Peers are discovered automatically + on local network or can be manually added. + +`/contacts` +: List known contacts (see below). + +`/conversations` +: List started conversations with contacts or other peers. + +`/<number>` +: Select conversation, contact or peer `<number>` based on the last + `/conversations`, `/contacts` or `/peers` output list. + +`<message>` +: Send `<message>` to selected conversation. + +`/history` +: Show message history of the selected conversation. + +`/details` +: Show information about the selected conversations, contact or peer. + +### Chatrooms + +Currently only public unmoderated chatrooms are supported, which means that any +network peer is allowed to read and post to the chatroom. Individual messages +are signed, so message author can not be forged. + +`/chatrooms` +: List known chatrooms. + +`/chatroom-create-public [<name>]` +: Create public unmoderated chatroom. Room name can be passed as command + argument or entered interactively. + +`/members` +: List members of the chatroom – usesers who sent any message or joined via the +`join` command. + +`/join` +: Join chatroom without sending text message. + +`/leave` +: Leave the chatroom. User will no longer be listed as a member and erebos tool + will no longer collect message of this chatroom. + +### Add contacts + +To ensure the identity of the contact and prevent man-in-the-middle attack, +generated verification code needs to be confirmed on both devices to add +contacts to contact list (similar to bluetooth device pairing). Before adding +new contact, list peers using `/peers` command and select one with `/<number>`. + +`/contacts` +: List already added contacts. + +`/contact-add` +: Add selected peer as contact. Six-digit verification code will be computed + based on peer keys, which will be displayed on both devices and needs to be + checked that both numbers are same. After that it needs to be confirmed using + `/contact-accept` to finish the process. + +`/contact-accept` +: Confirm that displayed verification codes are same on both devices and add + the selected peer as contact. The side, which did not initiate the contact + adding process, needs to select the corresponding peer with `/<number>` + command first. + +`/contact-reject` +: Reject contact request or verification code of selected peer. + +### Attach other devices + +Multiple devices can be attached to single identity to be used by the same +user. After the attachment process completes the roles of the devices are +equivalent, both can send and receive messages independently and those +messages, along with any other sate data, are synchronized automatically +whenever the devices can connect to each other. + +The attachment process and underlying protocol is very similar to the contact +adding described above, so also generates verification code based on peer keys +that needs to be checked and confirmed on both devices to avoid potential +man-in-the-middle attack. + +Before attaching device, list peers using `/peers` command and select the +target device with `/<number>`. + +`/attach` +: Attach current device to the selected peer. After the process completes the + owner of the selected peer will become owner of this device as well. + Six-digit verification code will be displayed on both devices and the user + needs to check that both are the same before confirmation using the + `/attach-accept` command. + +`/attach-accept` +: Confirm that displayed verification codes are same on both devices and + complete the attachment process (or wait for the confirmation on the peer + device). The side, which did not initiate the attachment process, needs to + select the corresponding peer with `/<number>` command first. + +`/attach-reject` +: Reject device attachment request or verification code of selected peer. + +### Other + +`/peer-add <host> [<port>]` +: Manually add network peer with given hostname or IP address. + +`/peer-add-public` +: Add known public network peer(s). + +`/peer-drop` +: Drop the currently selected peer. Afterwards, the connection can be + re-established by either side. + +`/update-identity` +: Interactively update current identity information + +`/quit` +: Quit the erebos tool. + + +Storage +------- + +Data are by default stored within `.erebos` subdirectory of the current working +directory. This can be overriden by `EREBOS_DIR` environment variable. + +Private keys are currently stored in plaintext under the `keys` subdirectory of +the erebos directory. diff --git a/Setup.hs b/Setup.hs new file mode 100644 index 0000000..9a994af --- /dev/null +++ b/Setup.hs @@ -0,0 +1,2 @@ +import Distribution.Simple +main = defaultMain diff --git a/erebos-tester.yaml b/erebos-tester.yaml new file mode 100644 index 0000000..a44f080 --- /dev/null +++ b/erebos-tester.yaml @@ -0,0 +1 @@ +tests: test/**/*.test diff --git a/erebos.cabal b/erebos.cabal new file mode 100644 index 0000000..347d785 --- /dev/null +++ b/erebos.cabal @@ -0,0 +1,201 @@ +Cabal-Version: 3.0 + +Name: erebos +Version: 0.1.5 +Synopsis: Decentralized messaging and synchronization +Description: + Library and simple CLI interface implementing the Erebos identity + management, decentralized messaging and synchronization protocol, along + with local storage. + . + Erebos identity is based on locally stored cryptographic keys, all + communication is end-to-end encrypted. Multiple devices can be attached to + the same identity, after which they function interchangeably, without any + one being in any way "primary"; messages and other state data are then + synchronized automatically whenever the devices are able to connect with + one another. + . + See README for usage of the CLI tool. +License: BSD-3-Clause +License-File: LICENSE +Homepage: https://erebosprotocol.net/erebos +Author: Roman Smrž <roman.smrz@seznam.cz> +Maintainer: roman.smrz@seznam.cz +Category: Network +Stability: experimental +Build-type: Simple +Extra-Doc-Files: + README.md + CHANGELOG.md +Extra-Source-Files: + src/Erebos/ICE/pjproject.h + +Flag ice + Description: Enable peer discovery with ICE support using pjproject + +Flag ci + description: Options for CI testing + default: False + manual: True + +source-repository head + type: git + location: git://erebosprotocol.net/erebos + +common common + ghc-options: + -Wall + -fdefer-typed-holes + + if flag(ci) + ghc-options: + -Werror + -- sometimes needed for backward/forward compatibility: + -Wno-error=unused-imports + + build-depends: + base ^>= { 4.15, 4.16, 4.17, 4.18, 4.19, 4.20 }, + + default-extensions: + DefaultSignatures + ExistentialQuantification + FlexibleContexts + FlexibleInstances + FunctionalDependencies + GeneralizedNewtypeDeriving + ImportQualifiedPost + LambdaCase + MultiWayIf + RankNTypes + RecordWildCards + ScopedTypeVariables + StandaloneDeriving + TypeOperators + TupleSections + TypeApplications + TypeFamilies + TypeFamilyDependencies + + other-extensions: + CPP + ForeignFunctionInterface + OverloadedStrings + RecursiveDo + TemplateHaskell + UndecidableInstances + + if flag(ice) + cpp-options: -DENABLE_ICE_SUPPORT + +library + import: common + default-language: Haskell2010 + + hs-source-dirs: src + exposed-modules: + Erebos.Attach + Erebos.Channel + Erebos.Chatroom + Erebos.Contact + Erebos.Conversation + Erebos.Identity + Erebos.Message + Erebos.Network + Erebos.Network.Protocol + Erebos.Pairing + Erebos.PubKey + Erebos.Service + Erebos.Set + Erebos.State + Erebos.Storage + Erebos.Storage.Key + Erebos.Storage.Merge + Erebos.Sync + + -- Used by test tool: + Erebos.Storage.Internal + other-modules: + Erebos.Flow + Erebos.Storage.Platform + Erebos.Util + + c-sources: + src/Erebos/Network/ifaddrs.c + include-dirs: + src + + if flag(ice) + exposed-modules: + Erebos.Discovery + Erebos.ICE + c-sources: + src/Erebos/ICE/pjproject.c + include-dirs: + src/Erebos/ICE + includes: + src/Erebos/ICE/pjproject.h + build-tool-depends: c2hs:c2hs + pkgconfig-depends: libpjproject >= 2.9 + + build-depends: + async >=2.2 && <2.3, + binary >=0.8 && <0.11, + bytestring >=0.10 && <0.13, + clock >=0.8 && < 0.9, + containers >= 0.6 && <0.8, + cryptonite >=0.25 && <0.31, + deepseq >= 1.4 && <1.6, + directory >= 1.3 && <1.4, + filepath >=1.4 && <1.6, + fsnotify ^>= { 0.4 }, + hashable >=1.3 && <1.5, + hashtables >=1.2 && <1.4, + iproute >=1.7.12 && <1.8, + memory >=0.14 && <0.19, + mtl >=2.2 && <2.4, + network >= 3.1 && <3.2, + stm >=2.5 && <2.6, + text >= 1.2 && <2.2, + time >= 1.8 && <1.14, + uuid >=1.3 && <1.4, + zlib >=0.6 && <0.8 + + if os(windows) + hs-source-dirs: src/windows + build-depends: + Win32 ^>= { 2.14 }, + else + hs-source-dirs: src/unix + build-depends: + unix ^>= { 2.7, 2.8 }, + +executable erebos + import: common + default-language: Haskell2010 + hs-source-dirs: main + ghc-options: -threaded + + main-is: Main.hs + other-modules: + Paths_erebos + Test + Test.Service + Version + Version.Git + autogen-modules: + Paths_erebos + + build-depends: + bytestring, + cryptonite, + directory, + erebos, + haskeline >=0.7 && <0.9, + mtl, + network, + process >=1.6 && <1.7, + template-haskell ^>= { 2.17, 2.18, 2.19, 2.20, 2.21, 2.22 }, + text, + time, + transformers >= 0.5 && <0.7, + uuid, diff --git a/main/Main.hs b/main/Main.hs new file mode 100644 index 0000000..94c0418 --- /dev/null +++ b/main/Main.hs @@ -0,0 +1,904 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE OverloadedStrings #-} + +module Main (main) where + +import Control.Arrow (first) +import Control.Concurrent +import Control.Exception +import Control.Monad +import Control.Monad.Except +import Control.Monad.Reader +import Control.Monad.State +import Control.Monad.Trans.Maybe + +import Crypto.Random + +import qualified Data.ByteString.Char8 as BC +import qualified Data.ByteString.Lazy as BL +import Data.Char +import Data.List +import Data.Maybe +import Data.Ord +import Data.Text (Text) +import Data.Text qualified as T +import Data.Text.Encoding qualified as T +import Data.Text.IO qualified as T +import Data.Time.Format +import Data.Time.LocalTime +import Data.Typeable + +import Network.Socket + +import System.Console.GetOpt +import System.Console.Haskeline +import System.Environment +import System.Exit +import System.IO + +import Erebos.Attach +import Erebos.Contact +import Erebos.Chatroom +import Erebos.Conversation +#ifdef ENABLE_ICE_SUPPORT +import Erebos.Discovery +import Erebos.ICE +#endif +import Erebos.Identity +import Erebos.Message hiding (formatMessage) +import Erebos.Network +import Erebos.PubKey +import Erebos.Service +import Erebos.Set +import Erebos.State +import Erebos.Storage +import Erebos.Storage.Merge +import Erebos.Sync + +import Test +import Version + +data Options = Options + { optServer :: ServerOptions + , optServices :: [ServiceOption] + , optStorage :: StorageOption + , optChatroomAutoSubscribe :: Maybe Int + , optDmBotEcho :: Maybe Text + , optShowHelp :: Bool + , optShowVersion :: Bool + } + +data StorageOption = DefaultStorage + | FilesystemStorage FilePath + | MemoryStorage + +data ServiceOption = ServiceOption + { soptName :: String + , soptService :: SomeService + , soptEnabled :: Bool + , soptDescription :: String + } + +defaultOptions :: Options +defaultOptions = Options + { optServer = defaultServerOptions + , optServices = availableServices + , optStorage = DefaultStorage + , optChatroomAutoSubscribe = Nothing + , optDmBotEcho = Nothing + , optShowHelp = False + , optShowVersion = False + } + +availableServices :: [ServiceOption] +availableServices = + [ ServiceOption "attach" (someService @AttachService Proxy) + True "attach (to) other devices" + , ServiceOption "sync" (someService @SyncService Proxy) + True "synchronization with attached devices" + , ServiceOption "chatroom" (someService @ChatroomService Proxy) + True "chatrooms with multiple participants" + , ServiceOption "contact" (someService @ContactService Proxy) + True "create contacts with network peers" + , ServiceOption "dm" (someService @DirectMessage Proxy) + True "direct messages" +#ifdef ENABLE_ICE_SUPPORT + , ServiceOption "discovery" (someService @DiscoveryService Proxy) + True "peer discovery" +#endif + ] + +options :: [OptDescr (Options -> Options)] +options = + [ Option ['p'] ["port"] + (ReqArg (\p -> so $ \opts -> opts { serverPort = read p }) "<port>") + "local port to bind" + , Option ['s'] ["silent"] + (NoArg (so $ \opts -> opts { serverLocalDiscovery = False })) + "do not send announce packets for local discovery" + , Option [] [ "storage" ] + (ReqArg (\path -> \opts -> opts { optStorage = FilesystemStorage path }) "<path>") + "use storage in <path>" + , Option [] [ "memory-storage" ] + (NoArg (\opts -> opts { optStorage = MemoryStorage })) + "use memory storage" + , Option [] ["chatroom-auto-subscribe"] + (ReqArg (\count -> \opts -> opts { optChatroomAutoSubscribe = Just (read count) }) "<count>") + "automatically subscribe for up to <count> chatrooms" + , Option [] ["dm-bot-echo"] + (ReqArg (\prefix -> \opts -> opts { optDmBotEcho = Just (T.pack prefix) }) "<prefix>") + "automatically reply to direct messages with the same text prefixed with <prefix>" + , Option ['h'] ["help"] + (NoArg $ \opts -> opts { optShowHelp = True }) + "show this help and exit" + , Option ['V'] ["version"] + (NoArg $ \opts -> opts { optShowVersion = True }) + "show version and exit" + ] + where so f opts = opts { optServer = f $ optServer opts } + +servicesOptions :: [OptDescr (Options -> Options)] +servicesOptions = concatMap helper $ "all" : map soptName availableServices + where + helper name = + [ Option [] ["enable-" <> name] (NoArg $ so $ change name $ \sopt -> sopt { soptEnabled = True }) "" + , Option [] ["disable-" <> name] (NoArg $ so $ change name $ \sopt -> sopt { soptEnabled = False }) "" + ] + so f opts = opts { optServices = f $ optServices opts } + change :: String -> (ServiceOption -> ServiceOption) -> [ServiceOption] -> [ServiceOption] + change name f (s : ss) + | soptName s == name || name == "all" + = f s : change name f ss + | otherwise = s : change name f ss + change _ _ [] = [] + +main :: IO () +main = do + (opts, args) <- (getOpt RequireOrder (options ++ servicesOptions) <$> getArgs) >>= \case + (o, args, []) -> do + return (foldl (flip id) defaultOptions o, args) + (_, _, errs) -> do + progName <- getProgName + hPutStrLn stderr $ concat errs <> "Try `" <> progName <> " --help' for more information." + exitFailure + + st <- liftIO $ case optStorage opts of + DefaultStorage -> openStorage . fromMaybe "./.erebos" =<< lookupEnv "EREBOS_DIR" + FilesystemStorage path -> openStorage path + MemoryStorage -> memoryStorage + + case args of + ["cat-file", sref] -> do + readRef st (BC.pack sref) >>= \case + Nothing -> error "ref does not exist" + Just ref -> BL.putStr $ lazyLoadBytes ref + + ("cat-file" : objtype : srefs@(_:_)) -> do + sequence <$> (mapM (readRef st . BC.pack) srefs) >>= \case + Nothing -> error "ref does not exist" + Just refs -> case objtype of + "signed" -> forM_ refs $ \ref -> do + let signed = load ref :: Signed Object + BL.putStr $ lazyLoadBytes $ storedRef $ signedData signed + forM_ (signedSignature signed) $ \sig -> do + putStr $ "SIG " + BC.putStrLn $ showRef $ storedRef $ sigKey $ fromStored sig + "identity" -> case validateExtendedIdentityF (wrappedLoad <$> refs) of + Just identity -> do + let disp :: Identity m -> IO () + disp idt = do + maybe (return ()) (T.putStrLn . (T.pack "Name: " `T.append`)) $ idName idt + BC.putStrLn . (BC.pack "KeyId: " `BC.append`) . showRefDigest . refDigest . storedRef $ idKeyIdentity idt + BC.putStrLn . (BC.pack "KeyMsg: " `BC.append`) . showRefDigest . refDigest . storedRef $ idKeyMessage idt + case idOwner idt of + Nothing -> return () + Just owner -> do + mapM_ (putStrLn . ("OWNER " ++) . BC.unpack . showRefDigest . refDigest . storedRef) $ idExtDataF owner + disp owner + disp identity + Nothing -> putStrLn $ "Identity verification failed" + _ -> error $ "unknown object type '" ++ objtype ++ "'" + + ["show-generation", sref] -> readRef st (BC.pack sref) >>= \case + Nothing -> error "ref does not exist" + Just ref -> print $ storedGeneration (wrappedLoad ref :: Stored Object) + + ["update-identity"] -> either fail return <=< runExceptT $ do + runReaderT updateSharedIdentity =<< loadLocalStateHead st + + ("update-identity" : srefs) -> do + sequence <$> mapM (readRef st . BC.pack) srefs >>= \case + Nothing -> error "ref does not exist" + Just refs + | Just idt <- validateIdentityF $ map wrappedLoad refs -> do + BC.putStrLn . showRefDigest . refDigest . storedRef . idData =<< + (either fail return <=< runExceptT $ runReaderT (interactiveIdentityUpdate idt) st) + | otherwise -> error "invalid identity" + + ["test"] -> runTestTool st + + [] -> do + let header = "Usage: erebos [OPTION...]" + serviceDesc ServiceOption {..} = padService (" " <> soptName) <> soptDescription + + padTo n str = str <> replicate (n - length str) ' ' + padOpt = padTo 37 + padService = padTo 16 + + if | optShowHelp opts -> putStr $ usageInfo header options <> unlines + ( + [ padOpt " --enable-<service>" <> "enable network service <service>" + , padOpt " --disable-<service>" <> "disable network service <service>" + , padOpt " --enable-all" <> "enable all network services" + , padOpt " --disable-all" <> "disable all network services" + , "" + , "Available network services:" + ] ++ map serviceDesc availableServices + ) + | optShowVersion opts -> putStrLn versionLine + | otherwise -> interactiveLoop st opts + + (cmdname : _) -> do + hPutStrLn stderr $ "Unknown command `" <> cmdname <> "'" + exitFailure + + +inputSettings :: Settings IO +inputSettings = setComplete commandCompletion $ defaultSettings + +interactiveLoop :: Storage -> Options -> IO () +interactiveLoop st opts = runInputT inputSettings $ do + erebosHead <- liftIO $ loadLocalStateHead st + outputStrLn $ T.unpack $ displayIdentity $ headLocalIdentity erebosHead + + tui <- haveTerminalUI + extPrint <- getExternalPrint + let extPrintLn str = do + let str' = case reverse str of ('\n':_) -> str + _ -> str ++ "\n"; + extPrint $! str' -- evaluate str before calling extPrint to avoid blinking + + let getInputLinesTui eprompt = do + prompt <- case eprompt of + Left cstate -> do + pname <- case csContext cstate of + NoContext -> return "" + SelectedPeer peer -> peerIdentity peer >>= return . \case + PeerIdentityFull pid -> maybe "<unnamed>" T.unpack $ idName $ finalOwner pid + PeerIdentityRef wref _ -> "<" ++ BC.unpack (showRefDigest $ wrDigest wref) ++ ">" + PeerIdentityUnknown _ -> "<unknown>" + SelectedContact contact -> return $ T.unpack $ contactName contact + SelectedChatroom rstate -> return $ T.unpack $ fromMaybe (T.pack "<unnamed>") $ roomName =<< roomStateRoom rstate + SelectedConversation conv -> return $ T.unpack $ conversationName conv + return $ pname ++ "> " + Right prompt -> return prompt + Just input <- lift $ getInputLine prompt + case reverse input of + _ | all isSpace input -> getInputLinesTui eprompt + '\\':rest -> (reverse ('\n':rest) ++) <$> getInputLinesTui (Right ">> ") + _ -> return input + + getInputCommandTui cstate = do + input <- getInputLinesTui cstate + let (CommandM cmd, line) = case input of + '/':rest -> let (scmd, args) = dropWhile isSpace <$> span (\c -> isAlphaNum c || c == '-') rest + in if not (null scmd) && all isDigit scmd + then (cmdSelectContext, scmd) + else (fromMaybe (cmdUnknown scmd) $ lookup scmd commands, args) + _ -> (cmdSend, input) + return (cmd, line) + + getInputLinesPipe = do + lift (getInputLine "") >>= \case + Just input -> return input + Nothing -> liftIO $ forever $ threadDelay 100000000 + + getInputCommandPipe _ = do + input <- getInputLinesPipe + let (scmd, args) = dropWhile isSpace <$> span (\c -> isAlphaNum c || c == '-') input + let (CommandM cmd, line) = (fromMaybe (cmdUnknown scmd) $ lookup scmd commands, args) + return (cmd, line) + + let getInputCommand = if tui then getInputCommandTui . Left + else getInputCommandPipe + + _ <- liftIO $ do + tzone <- getCurrentTimeZone + watchReceivedMessages erebosHead $ \smsg -> do + let msg = fromStored smsg + extPrintLn $ formatDirectMessage tzone msg + case optDmBotEcho opts of + Nothing -> return () + Just prefix -> do + res <- runExceptT $ flip runReaderT erebosHead $ sendDirectMessage (msgFrom msg) (prefix <> msgText msg) + case res of + Right reply -> extPrintLn $ formatDirectMessage tzone $ fromStored reply + Left err -> extPrintLn $ "Failed to send dm echo: " <> err + + peers <- liftIO $ newMVar [] + contextOptions <- liftIO $ newMVar [] + chatroomSetVar <- liftIO $ newEmptyMVar + + let autoSubscribe = optChatroomAutoSubscribe opts + chatroomList = fromSetBy (comparing roomStateData) . lookupSharedValue . lsShared . headObject $ erebosHead + watched <- if isJust autoSubscribe || any roomStateSubscribe chatroomList + then fmap Just $ liftIO $ watchChatroomsForCli extPrintLn erebosHead chatroomSetVar contextOptions autoSubscribe + else return Nothing + + server <- liftIO $ do + startServer (optServer opts) erebosHead extPrintLn $ + map soptService $ filter soptEnabled $ optServices opts + + void $ liftIO $ forkIO $ void $ forever $ do + peer <- getNextPeerChange server + peerIdentity peer >>= \case + pid@(PeerIdentityFull _) -> do + dropped <- isPeerDropped peer + let shown = showPeer pid $ peerAddress peer + let update [] = ([(peer, shown)], (Nothing, "NEW")) + update ((p,s):ps) + | p == peer && dropped = (ps, (Nothing, "DEL")) + | p == peer = ((peer, shown) : ps, (Just s, "UPD")) + | otherwise = first ((p,s):) $ update ps + let ctxUpdate n [] = ([SelectedPeer peer], n) + ctxUpdate n (ctx:ctxs) + | SelectedPeer p <- ctx, p == peer = (ctx:ctxs, n) + | otherwise = first (ctx:) $ ctxUpdate (n + 1) ctxs + (op, updateType) <- modifyMVar peers (return . update) + let updateType' = if dropped then "DEL" else updateType + idx <- modifyMVar contextOptions (return . ctxUpdate (1 :: Int)) + when (Just shown /= op) $ extPrintLn $ "[" <> show idx <> "] PEER " <> updateType' <> " " <> shown + _ -> return () + + let process :: CommandState -> MaybeT (InputT IO) CommandState + process cstate = do + (cmd, line) <- getInputCommand cstate + h <- liftIO (reloadHead $ csHead cstate) >>= \case + Just h -> return h + Nothing -> do lift $ lift $ extPrintLn "current head deleted" + mzero + res <- liftIO $ runExceptT $ flip execStateT cstate { csHead = h } $ runReaderT cmd CommandInput + { ciServer = server + , ciLine = line + , ciPrint = extPrintLn + , ciOptions = opts + , ciPeers = liftIO $ modifyMVar peers $ \ps -> do + ps' <- filterM (fmap not . isPeerDropped . fst) ps + return (ps', ps') + , ciContextOptions = liftIO $ readMVar contextOptions + , ciSetContextOptions = \ctxs -> liftIO $ modifyMVar_ contextOptions $ const $ return ctxs + , ciContextOptionsVar = contextOptions + , ciChatroomSetVar = chatroomSetVar + } + case res of + Right cstate' + | csQuit cstate' -> mzero + | otherwise -> return cstate' + Left err -> do + lift $ lift $ extPrintLn $ "Error: " ++ err + return cstate + + let loop (Just cstate) = runMaybeT (process cstate) >>= loop + loop Nothing = return () + loop $ Just $ CommandState + { csHead = erebosHead + , csContext = NoContext +#ifdef ENABLE_ICE_SUPPORT + , csIceSessions = [] +#endif + , csIcePeer = Nothing + , csWatchChatrooms = watched + , csQuit = False + } + + +data CommandInput = CommandInput + { ciServer :: Server + , ciLine :: String + , ciPrint :: String -> IO () + , ciOptions :: Options + , ciPeers :: CommandM [(Peer, String)] + , ciContextOptions :: CommandM [CommandContext] + , ciSetContextOptions :: [CommandContext] -> Command + , ciContextOptionsVar :: MVar [ CommandContext ] + , ciChatroomSetVar :: MVar (Set ChatroomState) + } + +data CommandState = CommandState + { csHead :: Head LocalState + , csContext :: CommandContext +#ifdef ENABLE_ICE_SUPPORT + , csIceSessions :: [IceSession] +#endif + , csIcePeer :: Maybe Peer + , csWatchChatrooms :: Maybe WatchedHead + , csQuit :: Bool + } + +data CommandContext = NoContext + | SelectedPeer Peer + | SelectedContact Contact + | SelectedChatroom ChatroomState + | SelectedConversation Conversation + +newtype CommandM a = CommandM (ReaderT CommandInput (StateT CommandState (ExceptT String IO)) a) + deriving (Functor, Applicative, Monad, MonadReader CommandInput, MonadState CommandState, MonadError String) + +instance MonadFail CommandM where + fail = throwError + +instance MonadIO CommandM where + liftIO act = CommandM (liftIO (try act)) >>= \case + Left (e :: SomeException) -> throwError (show e) + Right x -> return x + +instance MonadRandom CommandM where + getRandomBytes = liftIO . getRandomBytes + +instance MonadStorage CommandM where + getStorage = gets $ headStorage . csHead + +instance MonadHead LocalState CommandM where + updateLocalHead f = do + h <- gets csHead + (Just h', x) <- maybe (fail "failed to reload head") (flip updateHead f) =<< reloadHead h + modify $ \s -> s { csHead = h' } + return x + +type Command = CommandM () + +getSelectedPeer :: CommandM Peer +getSelectedPeer = gets csContext >>= \case + SelectedPeer peer -> return peer + _ -> throwError "no peer selected" + +getSelectedChatroom :: CommandM ChatroomState +getSelectedChatroom = gets csContext >>= \case + SelectedChatroom rstate -> return rstate + _ -> throwError "no chatroom selected" + +getSelectedConversation :: CommandM Conversation +getSelectedConversation = gets csContext >>= \case + SelectedPeer peer -> peerIdentity peer >>= \case + PeerIdentityFull pid -> directMessageConversation $ finalOwner pid + _ -> throwError "incomplete peer identity" + SelectedContact contact -> case contactIdentity contact of + Just cid -> directMessageConversation cid + Nothing -> throwError "contact without erebos identity" + SelectedChatroom rstate -> + chatroomConversation rstate >>= \case + Just conv -> return conv + Nothing -> throwError "invalid chatroom" + SelectedConversation conv -> reloadConversation conv + _ -> throwError "no contact, peer or conversation selected" + +commands :: [(String, Command)] +commands = + [ ("history", cmdHistory) + , ("peers", cmdPeers) + , ("peer-add", cmdPeerAdd) + , ("peer-add-public", cmdPeerAddPublic) + , ("peer-drop", cmdPeerDrop) + , ("send", cmdSend) + , ("update-identity", cmdUpdateIdentity) + , ("attach", cmdAttach) + , ("attach-accept", cmdAttachAccept) + , ("attach-reject", cmdAttachReject) + , ("chatrooms", cmdChatrooms) + , ("chatroom-create-public", cmdChatroomCreatePublic) + , ("contacts", cmdContacts) + , ("contact-add", cmdContactAdd) + , ("contact-accept", cmdContactAccept) + , ("contact-reject", cmdContactReject) + , ("conversations", cmdConversations) + , ("details", cmdDetails) +#ifdef ENABLE_ICE_SUPPORT + , ("discovery-init", cmdDiscoveryInit) + , ("discovery", cmdDiscovery) + , ("ice-create", cmdIceCreate) + , ("ice-destroy", cmdIceDestroy) + , ("ice-show", cmdIceShow) + , ("ice-connect", cmdIceConnect) + , ("ice-send", cmdIceSend) +#endif + , ("join", cmdJoin) + , ("leave", cmdLeave) + , ("members", cmdMembers) + , ("select", cmdSelectContext) + , ("quit", cmdQuit) + ] + +commandCompletion :: CompletionFunc IO +commandCompletion = completeWordWithPrev Nothing [ ' ', '\t', '\n', '\r' ] $ curry $ \case + ([], '/':pref) -> return . map (simpleCompletion . ('/':)) . filter (pref `isPrefixOf`) $ sortedCommandNames + _ -> return [] + where + sortedCommandNames = sort $ map fst commands + + +cmdUnknown :: String -> Command +cmdUnknown cmd = liftIO $ putStrLn $ "Unknown command: " ++ cmd + +cmdPeers :: Command +cmdPeers = do + peers <- join $ asks ciPeers + set <- asks ciSetContextOptions + set $ map (SelectedPeer . fst) peers + forM_ (zip [1..] peers) $ \(i :: Int, (_, name)) -> do + liftIO $ putStrLn $ "[" ++ show i ++ "] " ++ name + +cmdPeerAdd :: Command +cmdPeerAdd = void $ do + server <- asks ciServer + (hostname, port) <- (words <$> asks ciLine) >>= \case + hostname:p:_ -> return (hostname, p) + [hostname] -> return (hostname, show discoveryPort) + [] -> throwError "missing peer address" + addr:_ <- liftIO $ getAddrInfo (Just $ defaultHints { addrSocketType = Datagram }) (Just hostname) (Just port) + liftIO $ serverPeer server (addrAddress addr) + +cmdPeerAddPublic :: Command +cmdPeerAddPublic = do + server <- asks ciServer + addr:_ <- liftIO $ getAddrInfo (Just $ defaultHints { addrSocketType = Datagram }) (Just "discovery1.erebosprotocol.net") (Just (show discoveryPort)) + void $ liftIO $ serverPeer server (addrAddress addr) + +cmdPeerDrop :: Command +cmdPeerDrop = do + dropPeer =<< getSelectedPeer + modify $ \s -> s { csContext = NoContext } + +showPeer :: PeerIdentity -> PeerAddress -> String +showPeer pidentity paddr = + let name = case pidentity of + PeerIdentityUnknown _ -> "<noid>" + PeerIdentityRef wref _ -> "<" ++ BC.unpack (showRefDigest $ wrDigest wref) ++ ">" + PeerIdentityFull pid -> T.unpack $ displayIdentity pid + in name ++ " [" ++ show paddr ++ "]" + +cmdJoin :: Command +cmdJoin = joinChatroom =<< getSelectedChatroom + +cmdLeave :: Command +cmdLeave = leaveChatroom =<< getSelectedChatroom + +cmdMembers :: Command +cmdMembers = do + Just room <- findChatroomByStateData . head . roomStateData =<< getSelectedChatroom + forM_ (chatroomMembers room) $ \x -> do + liftIO $ putStrLn $ maybe "<unnamed>" T.unpack $ idName x + + +cmdSelectContext :: Command +cmdSelectContext = do + n <- read <$> asks ciLine + join (asks ciContextOptions) >>= \ctxs -> if + | n > 0, (ctx : _) <- drop (n - 1) ctxs -> do + modify $ \s -> s { csContext = ctx } + case ctx of + SelectedChatroom rstate -> do + when (not (roomStateSubscribe rstate)) $ do + chatroomSetSubscribe (head $ roomStateData rstate) True + _ -> return () + | otherwise -> throwError "invalid index" + +cmdSend :: Command +cmdSend = void $ do + text <- asks ciLine + conv <- getSelectedConversation + sendMessage conv (T.pack text) >>= \case + Just msg -> do + tzone <- liftIO $ getCurrentTimeZone + liftIO $ putStrLn $ formatMessage tzone msg + Nothing -> return () + +cmdHistory :: Command +cmdHistory = void $ do + conv <- getSelectedConversation + case conversationHistory conv of + thread@(_:_) -> do + tzone <- liftIO $ getCurrentTimeZone + liftIO $ mapM_ (putStrLn . formatMessage tzone) $ reverse $ take 50 thread + [] -> do + liftIO $ putStrLn $ "<empty history>" + +cmdUpdateIdentity :: Command +cmdUpdateIdentity = void $ do + runReaderT updateSharedIdentity =<< gets csHead + +cmdAttach :: Command +cmdAttach = attachToOwner =<< getSelectedPeer + +cmdAttachAccept :: Command +cmdAttachAccept = attachAccept =<< getSelectedPeer + +cmdAttachReject :: Command +cmdAttachReject = attachReject =<< getSelectedPeer + +watchChatroomsForCli :: (String -> IO ()) -> Head LocalState -> MVar (Set ChatroomState) -> MVar [ CommandContext ] -> Maybe Int -> IO WatchedHead +watchChatroomsForCli eprint h chatroomSetVar contextVar autoSubscribe = do + subscribedNumVar <- newEmptyMVar + + let ctxUpdate updateType (idx :: Int) rstate = \case + SelectedChatroom rstate' : rest + | currentRoots <- filterAncestors (concatMap storedRoots $ roomStateData rstate) + , any ((`intersectsSorted` currentRoots) . storedRoots) $ roomStateData rstate' + -> do + eprint $ "[" <> show idx <> "] CHATROOM " <> updateType <> " " <> name + return (SelectedChatroom rstate : rest) + selected : rest + -> do + (selected : ) <$> ctxUpdate updateType (idx + 1) rstate rest + [] + -> do + eprint $ "[" <> show idx <> "] CHATROOM " <> updateType <> " " <> name + return [ SelectedChatroom rstate ] + where + name = maybe "<unnamed>" T.unpack $ roomName =<< roomStateRoom rstate + + watchChatrooms h $ \set -> \case + Nothing -> do + let chatroomList = fromSetBy (comparing roomStateData) set + (subscribed, notSubscribed) = partition roomStateSubscribe chatroomList + subscribedNum = length subscribed + + putMVar chatroomSetVar set + putMVar subscribedNumVar subscribedNum + + case autoSubscribe of + Nothing -> return () + Just num -> do + forM_ (take (num - subscribedNum) notSubscribed) $ \rstate -> do + (runExceptT $ flip runReaderT h $ chatroomSetSubscribe (head $ roomStateData rstate) True) >>= \case + Right () -> return () + Left err -> eprint err + + Just diff -> do + modifyMVar_ chatroomSetVar $ return . const set + forM_ diff $ \case + AddedChatroom rstate -> do + modifyMVar_ contextVar $ ctxUpdate "NEW" 1 rstate + modifyMVar_ subscribedNumVar $ return . if roomStateSubscribe rstate then (+ 1) else id + + RemovedChatroom rstate -> do + modifyMVar_ contextVar $ ctxUpdate "DEL" 1 rstate + modifyMVar_ subscribedNumVar $ return . if roomStateSubscribe rstate then subtract 1 else id + + UpdatedChatroom oldroom rstate -> do + when (any ((\rsd -> not (null (rsdRoom rsd))) . fromStored) (roomStateData rstate)) $ do + modifyMVar_ contextVar $ ctxUpdate "UPD" 1 rstate + when (any (not . null . rsdMessages . fromStored) (roomStateData rstate)) $ do + tzone <- getCurrentTimeZone + forM_ (reverse $ getMessagesSinceState rstate oldroom) $ \msg -> do + eprint $ concat $ + [ maybe "<unnamed>" T.unpack $ roomName =<< cmsgRoom msg + , formatTime defaultTimeLocale " [%H:%M] " $ utcToLocalTime tzone $ zonedTimeToUTC $ cmsgTime msg + , maybe "<unnamed>" T.unpack $ idName $ cmsgFrom msg + , if cmsgLeave msg then " left" else "" + , maybe (if cmsgLeave msg then "" else " joined") ((": " ++) . T.unpack) $ cmsgText msg + ] + modifyMVar_ subscribedNumVar $ return + . (if roomStateSubscribe rstate then (+ 1) else id) + . (if roomStateSubscribe oldroom then subtract 1 else id) + +ensureWatchedChatrooms :: Command +ensureWatchedChatrooms = do + gets csWatchChatrooms >>= \case + Nothing -> do + eprint <- asks ciPrint + h <- gets csHead + chatroomSetVar <- asks ciChatroomSetVar + contextVar <- asks ciContextOptionsVar + autoSubscribe <- asks $ optChatroomAutoSubscribe . ciOptions + watched <- liftIO $ watchChatroomsForCli eprint h chatroomSetVar contextVar autoSubscribe + modify $ \s -> s { csWatchChatrooms = Just watched } + Just _ -> return () + +cmdChatrooms :: Command +cmdChatrooms = do + ensureWatchedChatrooms + chatroomSetVar <- asks ciChatroomSetVar + chatroomList <- fromSetBy (comparing roomStateData) <$> liftIO (readMVar chatroomSetVar) + set <- asks ciSetContextOptions + set $ map SelectedChatroom chatroomList + forM_ (zip [1..] chatroomList) $ \(i :: Int, rstate) -> do + liftIO $ putStrLn $ "[" ++ show i ++ "] " ++ maybe "<unnamed>" T.unpack (roomName =<< roomStateRoom rstate) + +cmdChatroomCreatePublic :: Command +cmdChatroomCreatePublic = do + name <- asks ciLine >>= \case + line | not (null line) -> return $ T.pack line + _ -> liftIO $ do + T.putStr $ T.pack "Name: " + hFlush stdout + T.getLine + + ensureWatchedChatrooms + void $ createChatroom + (if T.null name then Nothing else Just name) + Nothing + + +cmdContacts :: Command +cmdContacts = do + args <- words <$> asks ciLine + ehead <- gets csHead + let contacts = fromSetBy (comparing contactName) $ lookupSharedValue $ lsShared $ headObject ehead + verbose = "-v" `elem` args + set <- asks ciSetContextOptions + set $ map SelectedContact contacts + forM_ (zip [1..] contacts) $ \(i :: Int, c) -> liftIO $ do + T.putStrLn $ T.concat + [ "[", T.pack (show i), "] ", contactName c + , case contactIdentity c of + Just idt | cname <- displayIdentity idt + , cname /= contactName c + -> " (" <> cname <> ")" + _ -> "" + , if verbose then " " <> (T.unwords $ map (T.decodeUtf8 . showRef . storedRef) $ maybe [] idDataF $ contactIdentity c) + else "" + ] + +cmdContactAdd :: Command +cmdContactAdd = contactRequest =<< getSelectedPeer + +cmdContactAccept :: Command +cmdContactAccept = contactAccept =<< getSelectedPeer + +cmdContactReject :: Command +cmdContactReject = contactReject =<< getSelectedPeer + +cmdConversations :: Command +cmdConversations = do + conversations <- lookupConversations + set <- asks ciSetContextOptions + set $ map SelectedConversation conversations + forM_ (zip [1..] conversations) $ \(i :: Int, conv) -> do + liftIO $ putStrLn $ "[" ++ show i ++ "] " ++ T.unpack (conversationName conv) + +cmdDetails :: Command +cmdDetails = do + gets csContext >>= \case + SelectedPeer peer -> do + liftIO $ putStr $ unlines + [ "Network peer:" + , " " <> show (peerAddress peer) + ] + peerIdentity peer >>= \case + PeerIdentityUnknown _ -> liftIO $ do + putStrLn $ "unknown identity" + PeerIdentityRef wref _ -> liftIO $ do + putStrLn $ "Identity ref:" + putStrLn $ " " <> BC.unpack (showRefDigest $ wrDigest wref) + PeerIdentityFull pid -> printContactOrIdentityDetails pid + + SelectedContact contact -> do + printContactDetails contact + + SelectedChatroom rstate -> do + liftIO $ putStrLn $ "Chatroom: " <> (T.unpack $ fromMaybe (T.pack "<unnamed>") $ roomName =<< roomStateRoom rstate) + + SelectedConversation conv -> do + case conversationPeer conv of + Just pid -> printContactOrIdentityDetails pid + Nothing -> liftIO $ putStrLn $ "(conversation without peer)" + + NoContext -> liftIO $ putStrLn "nothing selected" + where + printContactOrIdentityDetails cid = do + contacts <- fromSetBy (comparing contactName) . lookupSharedValue . lsShared . fromStored <$> getLocalHead + case find (maybe False (sameIdentity cid) . contactIdentity) contacts of + Just contact -> printContactDetails contact + Nothing -> printIdentityDetails cid + + printContactDetails contact = liftIO $ do + putStrLn $ "Contact:" + prefix <- case contactCustomName contact of + Just name -> do + putStrLn $ " " <> T.unpack name + return $ Just "alias of" + Nothing -> do + return $ Nothing + + case contactIdentity contact of + Just cid -> do + printIdentityDetailsBody prefix cid + Nothing -> do + putStrLn $ " (without erebos identity)" + + printIdentityDetails identity = liftIO $ do + putStrLn $ "Identity:" + printIdentityDetailsBody Nothing identity + + printIdentityDetailsBody prefix identity = do + forM_ (zip (False : repeat True) $ unfoldOwners identity) $ \(owned, cpid) -> do + putStrLn $ unwords $ concat + [ [ " " ] + , if owned then [ "owned by" ] else maybeToList prefix + , [ maybe "<unnamed>" T.unpack (idName cpid) ] + , map (BC.unpack . showRefDigest . refDigest . storedRef) $ idExtDataF cpid + ] + +#ifdef ENABLE_ICE_SUPPORT + +cmdDiscoveryInit :: Command +cmdDiscoveryInit = void $ do + server <- asks ciServer + + (hostname, port) <- (words <$> asks ciLine) >>= return . \case + hostname:p:_ -> (hostname, p) + [hostname] -> (hostname, show discoveryPort) + [] -> ("discovery.erebosprotocol.net", show discoveryPort) + addr:_ <- liftIO $ getAddrInfo (Just $ defaultHints { addrSocketType = Datagram }) (Just hostname) (Just port) + peer <- liftIO $ serverPeer server (addrAddress addr) + sendToPeer peer $ DiscoverySelf (T.pack "ICE") 0 + modify $ \s -> s { csIcePeer = Just peer } + +cmdDiscovery :: Command +cmdDiscovery = void $ do + Just peer <- gets csIcePeer + st <- getStorage + sref <- asks ciLine + eprint <- asks ciPrint + liftIO $ readRef st (BC.pack sref) >>= \case + Nothing -> error "ref does not exist" + Just ref -> do + res <- runExceptT $ sendToPeer peer $ DiscoverySearch ref + case res of + Right _ -> return () + Left err -> eprint err + +cmdIceCreate :: Command +cmdIceCreate = do + role <- asks ciLine >>= return . \case + 'm':_ -> PjIceSessRoleControlling + 's':_ -> PjIceSessRoleControlled + _ -> PjIceSessRoleUnknown + eprint <- asks ciPrint + sess <- liftIO $ iceCreate role $ eprint <=< iceShow + modify $ \s -> s { csIceSessions = sess : csIceSessions s } + +cmdIceDestroy :: Command +cmdIceDestroy = do + s:ss <- gets csIceSessions + modify $ \st -> st { csIceSessions = ss } + liftIO $ iceDestroy s + +cmdIceShow :: Command +cmdIceShow = do + sess <- gets csIceSessions + eprint <- asks ciPrint + liftIO $ forM_ (zip [1::Int ..] sess) $ \(i, s) -> do + eprint $ "[" ++ show i ++ "]" + eprint =<< iceShow s + +cmdIceConnect :: Command +cmdIceConnect = do + s:_ <- gets csIceSessions + server <- asks ciServer + let loadInfo = BC.getLine >>= \case line | BC.null line -> return [] + | otherwise -> (line:) <$> loadInfo + Right remote <- liftIO $ do + st <- memoryStorage + pst <- derivePartialStorage st + rbytes <- (BL.fromStrict . BC.unlines) <$> loadInfo + copyRef st =<< storeRawBytes pst (BL.fromChunks [ BC.pack "rec ", BC.pack (show (BL.length rbytes)), BC.singleton '\n' ] `BL.append` rbytes) + liftIO $ iceConnect s (load remote) $ void $ serverPeerIce server s + +cmdIceSend :: Command +cmdIceSend = void $ do + s:_ <- gets csIceSessions + server <- asks ciServer + liftIO $ serverPeerIce server s + +#endif + +cmdQuit :: Command +cmdQuit = modify $ \s -> s { csQuit = True } + + +intersectsSorted :: Ord a => [a] -> [a] -> Bool +intersectsSorted (x:xs) (y:ys) | x < y = intersectsSorted xs (y:ys) + | x > y = intersectsSorted (x:xs) ys + | otherwise = True +intersectsSorted _ _ = False diff --git a/main/Test.hs b/main/Test.hs new file mode 100644 index 0000000..c6448b8 --- /dev/null +++ b/main/Test.hs @@ -0,0 +1,785 @@ +module Test ( + runTestTool, +) where + +import Control.Arrow +import Control.Concurrent +import Control.Exception +import Control.Monad +import Control.Monad.Except +import Control.Monad.Reader +import Control.Monad.State + +import Crypto.Random + +import Data.Bool +import Data.ByteString qualified as B +import Data.ByteString.Char8 qualified as BC +import Data.ByteString.Lazy qualified as BL +import Data.Foldable +import Data.Ord +import Data.Text (Text) +import Data.Text qualified as T +import Data.Text.Encoding +import Data.Text.IO qualified as T +import Data.Typeable +import Data.UUID qualified as U + +import Network.Socket + +import System.IO +import System.IO.Error + +import Erebos.Attach +import Erebos.Chatroom +import Erebos.Contact +import Erebos.Identity +import Erebos.Message +import Erebos.Network +import Erebos.Pairing +import Erebos.PubKey +import Erebos.Service +import Erebos.Set +import Erebos.State +import Erebos.Storage +import Erebos.Storage.Internal (unsafeStoreRawBytes) +import Erebos.Storage.Merge +import Erebos.Sync + +import Test.Service + + +data TestState = TestState + { tsHead :: Maybe (Head LocalState) + , tsServer :: Maybe RunningServer + , tsWatchedHeads :: [ ( Int, WatchedHead ) ] + , tsWatchedHeadNext :: Int + , tsWatchedLocalIdentity :: Maybe WatchedHead + , tsWatchedSharedIdentity :: Maybe WatchedHead + } + +data RunningServer = RunningServer + { rsServer :: Server + , rsPeers :: MVar (Int, [(Int, Peer)]) + , rsPeerThread :: ThreadId + } + +initTestState :: TestState +initTestState = TestState + { tsHead = Nothing + , tsServer = Nothing + , tsWatchedHeads = [] + , tsWatchedHeadNext = 1 + , tsWatchedLocalIdentity = Nothing + , tsWatchedSharedIdentity = Nothing + } + +data TestInput = TestInput + { tiOutput :: Output + , tiStorage :: Storage + , tiParams :: [Text] + } + + +runTestTool :: Storage -> IO () +runTestTool st = do + out <- newMVar () + let testLoop = getLineMb >>= \case + Just line -> do + case T.words line of + (cname:params) + | Just (CommandM cmd) <- lookup cname commands -> do + runReaderT cmd $ TestInput out st params + | otherwise -> fail $ "Unknown command '" ++ T.unpack cname ++ "'" + [] -> return () + testLoop + + Nothing -> return () + + runExceptT (evalStateT testLoop initTestState) >>= \case + Left x -> B.hPutStr stderr $ (`BC.snoc` '\n') $ BC.pack x + Right () -> return () + +getLineMb :: MonadIO m => m (Maybe Text) +getLineMb = liftIO $ catchIOError (Just <$> T.getLine) (\e -> if isEOFError e then return Nothing else ioError e) + +getLines :: MonadIO m => m [Text] +getLines = getLineMb >>= \case + Just line | not (T.null line) -> (line:) <$> getLines + _ -> return [] + +getHead :: CommandM (Head LocalState) +getHead = do + h <- maybe (fail "failed to reload head") return =<< maybe (fail "no current head") reloadHead =<< gets tsHead + modify $ \s -> s { tsHead = Just h } + return h + + +type Output = MVar () + +outLine :: Output -> String -> IO () +outLine mvar line = do + evaluate $ foldl' (flip seq) () line + withMVar mvar $ \() -> do + B.putStr $ (`BC.snoc` '\n') $ BC.pack line + hFlush stdout + +cmdOut :: String -> Command +cmdOut line = do + out <- asks tiOutput + liftIO $ outLine out line + + +getPeer :: Text -> CommandM Peer +getPeer spidx = do + Just RunningServer {..} <- gets tsServer + Just peer <- lookup (read $ T.unpack spidx) . snd <$> liftIO (readMVar rsPeers) + return peer + +getPeerIndex :: MVar (Int, [(Int, Peer)]) -> ServiceHandler (PairingService a) Int +getPeerIndex pmvar = do + peer <- asks svcPeer + maybe 0 fst . find ((==peer) . snd) . snd <$> liftIO (readMVar pmvar) + +pairingAttributes :: PairingResult a => proxy (PairingService a) -> Output -> MVar (Int, [(Int, Peer)]) -> String -> PairingAttributes a +pairingAttributes _ out peers prefix = PairingAttributes + { pairingHookRequest = return () + + , pairingHookResponse = \confirm -> do + index <- show <$> getPeerIndex peers + afterCommit $ outLine out $ unwords [prefix ++ "-response", index, confirm] + + , pairingHookRequestNonce = \confirm -> do + index <- show <$> getPeerIndex peers + afterCommit $ outLine out $ unwords [prefix ++ "-request", index, confirm] + + , pairingHookRequestNonceFailed = failed "nonce" + + , pairingHookConfirmedResponse = return () + , pairingHookConfirmedRequest = return () + + , pairingHookAcceptedResponse = do + index <- show <$> getPeerIndex peers + afterCommit $ outLine out $ unwords [prefix ++ "-response-done", index] + + , pairingHookAcceptedRequest = do + index <- show <$> getPeerIndex peers + afterCommit $ outLine out $ unwords [prefix ++ "-request-done", index] + + , pairingHookFailed = \case + PairingUserRejected -> failed "user" + PairingUnexpectedMessage pstate packet -> failed $ "unexpected " ++ strState pstate ++ " " ++ strPacket packet + PairingFailedOther str -> failed $ "other " ++ str + , pairingHookVerifyFailed = failed "verify" + , pairingHookRejected = failed "rejected" + } + where + failed :: PairingResult a => String -> ServiceHandler (PairingService a) () + failed detail = do + ptype <- svcGet >>= return . \case + OurRequest {} -> "response" + OurRequestConfirm {} -> "response" + OurRequestReady -> "response" + PeerRequest {} -> "request" + PeerRequestConfirm -> "request" + _ -> fail "unexpected pairing state" + + index <- show <$> getPeerIndex peers + afterCommit $ outLine out $ prefix ++ "-" ++ ptype ++ "-failed " ++ index ++ " " ++ detail + + strState :: PairingState a -> String + strState = \case + NoPairing -> "none" + OurRequest {} -> "our-request" + OurRequestConfirm {} -> "our-request-confirm" + OurRequestReady -> "our-request-ready" + PeerRequest {} -> "peer-request" + PeerRequestConfirm -> "peer-request-confirm" + PairingDone -> "done" + + strPacket :: PairingService a -> String + strPacket = \case + PairingRequest {} -> "request" + PairingResponse {} -> "response" + PairingRequestNonce {} -> "nonce" + PairingAccept {} -> "accept" + PairingReject -> "reject" + +directMessageAttributes :: Output -> DirectMessageAttributes +directMessageAttributes out = DirectMessageAttributes + { dmOwnerMismatch = afterCommit $ outLine out "dm-owner-mismatch" + } + +dmReceivedWatcher :: Output -> Stored DirectMessage -> IO () +dmReceivedWatcher out smsg = do + let msg = fromStored smsg + outLine out $ unwords + [ "dm-received" + , "from", maybe "<unnamed>" T.unpack $ idName $ msgFrom msg + , "text", T.unpack $ msgText msg + ] + + +newtype CommandM a = CommandM (ReaderT TestInput (StateT TestState (ExceptT String IO)) a) + deriving (Functor, Applicative, Monad, MonadIO, MonadReader TestInput, MonadState TestState, MonadError String) + +instance MonadFail CommandM where + fail = throwError + +instance MonadRandom CommandM where + getRandomBytes = liftIO . getRandomBytes + +instance MonadStorage CommandM where + getStorage = asks tiStorage + +instance MonadHead LocalState CommandM where + updateLocalHead f = do + Just h <- gets tsHead + (Just h', x) <- maybe (fail "failed to reload head") (flip updateHead f) =<< reloadHead h + modify $ \s -> s { tsHead = Just h' } + return x + +type Command = CommandM () + +commands :: [(Text, Command)] +commands = map (T.pack *** id) + [ ("store", cmdStore) + , ("stored-generation", cmdStoredGeneration) + , ("stored-roots", cmdStoredRoots) + , ("stored-set-add", cmdStoredSetAdd) + , ("stored-set-list", cmdStoredSetList) + , ("head-create", cmdHeadCreate) + , ("head-replace", cmdHeadReplace) + , ("head-watch", cmdHeadWatch) + , ("head-unwatch", cmdHeadUnwatch) + , ("create-identity", cmdCreateIdentity) + , ("start-server", cmdStartServer) + , ("stop-server", cmdStopServer) + , ("peer-add", cmdPeerAdd) + , ("peer-drop", cmdPeerDrop) + , ("peer-list", cmdPeerList) + , ("test-message-send", cmdTestMessageSend) + , ("shared-state-get", cmdSharedStateGet) + , ("shared-state-wait", cmdSharedStateWait) + , ("watch-local-identity", cmdWatchLocalIdentity) + , ("watch-shared-identity", cmdWatchSharedIdentity) + , ("update-local-identity", cmdUpdateLocalIdentity) + , ("update-shared-identity", cmdUpdateSharedIdentity) + , ("attach-to", cmdAttachTo) + , ("attach-accept", cmdAttachAccept) + , ("attach-reject", cmdAttachReject) + , ("contact-request", cmdContactRequest) + , ("contact-accept", cmdContactAccept) + , ("contact-reject", cmdContactReject) + , ("contact-list", cmdContactList) + , ("contact-set-name", cmdContactSetName) + , ("dm-send-peer", cmdDmSendPeer) + , ("dm-send-contact", cmdDmSendContact) + , ("dm-list-peer", cmdDmListPeer) + , ("dm-list-contact", cmdDmListContact) + , ("chatroom-create", cmdChatroomCreate) + , ("chatroom-list-local", cmdChatroomListLocal) + , ("chatroom-watch-local", cmdChatroomWatchLocal) + , ("chatroom-set-name", cmdChatroomSetName) + , ("chatroom-subscribe", cmdChatroomSubscribe) + , ("chatroom-unsubscribe", cmdChatroomUnsubscribe) + , ("chatroom-members", cmdChatroomMembers) + , ("chatroom-join", cmdChatroomJoin) + , ("chatroom-leave", cmdChatroomLeave) + , ("chatroom-message-send", cmdChatroomMessageSend) + ] + +cmdStore :: Command +cmdStore = do + st <- asks tiStorage + [otype] <- asks tiParams + ls <- getLines + + let cnt = encodeUtf8 $ T.unlines ls + ref <- liftIO $ unsafeStoreRawBytes st $ BL.fromChunks [encodeUtf8 otype, BC.singleton ' ', BC.pack (show $ B.length cnt), BC.singleton '\n', cnt] + cmdOut $ "store-done " ++ show (refDigest ref) + +cmdStoredGeneration :: Command +cmdStoredGeneration = do + st <- asks tiStorage + [tref] <- asks tiParams + Just ref <- liftIO $ readRef st (encodeUtf8 tref) + cmdOut $ "stored-generation " ++ T.unpack tref ++ " " ++ showGeneration (storedGeneration $ wrappedLoad @Object ref) + +cmdStoredRoots :: Command +cmdStoredRoots = do + st <- asks tiStorage + [tref] <- asks tiParams + Just ref <- liftIO $ readRef st (encodeUtf8 tref) + cmdOut $ "stored-roots " ++ T.unpack tref ++ concatMap ((' ':) . show . refDigest . storedRef) (storedRoots $ wrappedLoad @Object ref) + +cmdStoredSetAdd :: Command +cmdStoredSetAdd = do + st <- asks tiStorage + (item, set) <- asks tiParams >>= liftIO . mapM (readRef st . encodeUtf8) >>= \case + [Just iref, Just sref] -> return (wrappedLoad iref, loadSet @[Stored Object] sref) + [Just iref] -> return (wrappedLoad iref, emptySet) + _ -> fail "unexpected parameters" + set' <- storeSetAdd st [item] set + cmdOut $ "stored-set-add" ++ concatMap ((' ':) . show . refDigest . storedRef) (toComponents set') + +cmdStoredSetList :: Command +cmdStoredSetList = do + st <- asks tiStorage + [tref] <- asks tiParams + Just ref <- liftIO $ readRef st (encodeUtf8 tref) + let items = fromSetBy compare $ loadSet @[Stored Object] ref + forM_ items $ \item -> do + cmdOut $ "stored-set-item" ++ concatMap ((' ':) . show . refDigest . storedRef) item + cmdOut $ "stored-set-done" + +cmdHeadCreate :: Command +cmdHeadCreate = do + [ ttid, tref ] <- asks tiParams + st <- asks tiStorage + Just tid <- return $ fromUUID <$> U.fromText ttid + Just ref <- liftIO $ readRef st (encodeUtf8 tref) + + h <- storeHeadRaw st tid ref + cmdOut $ unwords $ [ "head-create-done", show (toUUID tid), show (toUUID h) ] + +cmdHeadReplace :: Command +cmdHeadReplace = do + [ ttid, thid, told, tnew ] <- asks tiParams + st <- asks tiStorage + Just tid <- return $ fmap fromUUID $ U.fromText ttid + Just hid <- return $ fmap fromUUID $ U.fromText thid + Just old <- liftIO $ readRef st (encodeUtf8 told) + Just new <- liftIO $ readRef st (encodeUtf8 tnew) + + replaceHeadRaw st tid hid old new >>= cmdOut . unwords . \case + Left Nothing -> [ "head-replace-fail", T.unpack ttid, T.unpack thid, T.unpack told, T.unpack tnew ] + Left (Just r) -> [ "head-replace-fail", T.unpack ttid, T.unpack thid, T.unpack told, T.unpack tnew, show (refDigest r) ] + Right _ -> [ "head-replace-done", T.unpack ttid, T.unpack thid, T.unpack told, T.unpack tnew ] + +cmdHeadWatch :: Command +cmdHeadWatch = do + [ ttid, thid ] <- asks tiParams + st <- asks tiStorage + Just tid <- return $ fmap fromUUID $ U.fromText ttid + Just hid <- return $ fmap fromUUID $ U.fromText thid + + out <- asks tiOutput + wid <- gets tsWatchedHeadNext + + watched <- liftIO $ watchHeadRaw st tid hid id $ \r -> do + outLine out $ unwords [ "head-watch-cb", show wid, show $ refDigest r ] + + modify $ \s -> s + { tsWatchedHeads = ( wid, watched ) : tsWatchedHeads s + , tsWatchedHeadNext = wid + 1 + } + + cmdOut $ unwords $ [ "head-watch-done", T.unpack ttid, T.unpack thid, show wid ] + +cmdHeadUnwatch :: Command +cmdHeadUnwatch = do + [ twid ] <- asks tiParams + let wid = read (T.unpack twid) + Just watched <- lookup wid <$> gets tsWatchedHeads + liftIO $ unwatchHead watched + cmdOut $ unwords [ "head-unwatch-done", show wid ] + +initTestHead :: Head LocalState -> Command +initTestHead h = do + _ <- liftIO . watchReceivedMessages h . dmReceivedWatcher =<< asks tiOutput + modify $ \s -> s { tsHead = Just h } + +loadTestHead :: CommandM (Head LocalState) +loadTestHead = do + st <- asks tiStorage + h <- loadHeads st >>= \case + h : _ -> return h + [] -> fail "no local head found" + initTestHead h + return h + +getOrLoadHead :: CommandM (Head LocalState) +getOrLoadHead = do + gets tsHead >>= \case + Just h -> return h + Nothing -> loadTestHead + +cmdCreateIdentity :: Command +cmdCreateIdentity = do + st <- asks tiStorage + names <- asks tiParams + + h <- liftIO $ do + Just identity <- if null names + then Just <$> createIdentity st Nothing Nothing + else foldrM (\n o -> Just <$> createIdentity st (Just n) o) Nothing names + + shared <- case names of + _:_:_ -> (:[]) <$> makeSharedStateUpdate st (Just $ finalOwner identity) [] + _ -> return [] + + storeHead st $ LocalState + { lsIdentity = idExtData identity + , lsShared = shared + } + initTestHead h + +cmdStartServer :: Command +cmdStartServer = do + out <- asks tiOutput + + h <- getOrLoadHead + rsPeers <- liftIO $ newMVar (1, []) + rsServer <- liftIO $ startServer defaultServerOptions h (B.hPutStr stderr . (`BC.snoc` '\n') . BC.pack) + [ someServiceAttr $ pairingAttributes (Proxy @AttachService) out rsPeers "attach" + , someServiceAttr $ pairingAttributes (Proxy @ContactService) out rsPeers "contact" + , someServiceAttr $ directMessageAttributes out + , someService @SyncService Proxy + , someService @ChatroomService Proxy + , someServiceAttr $ (defaultServiceAttributes Proxy) + { testMessageReceived = \otype len sref -> + liftIO $ outLine out $ unwords ["test-message-received", otype, len, sref] + } + ] + + rsPeerThread <- liftIO $ forkIO $ void $ forever $ do + peer <- getNextPeerChange rsServer + + let printPeer (idx, p) = do + params <- peerIdentity p >>= return . \case + PeerIdentityFull pid -> ("id":) $ map (maybe "<unnamed>" T.unpack . idName) (unfoldOwners pid) + _ -> [ "addr", show (peerAddress p) ] + outLine out $ unwords $ [ "peer", show idx ] ++ params + + update (nid, []) = printPeer (nid, peer) >> return (nid + 1, [(nid, peer)]) + update cur@(nid, p:ps) | snd p == peer = printPeer p >> return cur + | otherwise = fmap (p:) <$> update (nid, ps) + + modifyMVar_ rsPeers update + + modify $ \s -> s { tsServer = Just RunningServer {..} } + +cmdStopServer :: Command +cmdStopServer = do + Just RunningServer {..} <- gets tsServer + liftIO $ do + killThread rsPeerThread + stopServer rsServer + modify $ \s -> s { tsServer = Nothing } + cmdOut "stop-server-done" + +cmdPeerAdd :: Command +cmdPeerAdd = do + Just RunningServer {..} <- gets tsServer + host:rest <- map T.unpack <$> asks tiParams + + let port = case rest of [] -> show discoveryPort + (p:_) -> p + addr:_ <- liftIO $ getAddrInfo (Just $ defaultHints { addrSocketType = Datagram }) (Just host) (Just port) + void $ liftIO $ serverPeer rsServer (addrAddress addr) + +cmdPeerDrop :: Command +cmdPeerDrop = do + [spidx] <- asks tiParams + peer <- getPeer spidx + liftIO $ dropPeer peer + +cmdPeerList :: Command +cmdPeerList = do + Just RunningServer {..} <- gets tsServer + peers <- liftIO $ getCurrentPeerList rsServer + tpeers <- liftIO $ readMVar rsPeers + forM_ peers $ \peer -> do + Just (n, _) <- return $ find ((peer==).snd) . snd $ tpeers + mbpid <- peerIdentity peer + cmdOut $ unwords $ concat + [ [ "peer-list-item", show n ] + , [ "addr", show (peerAddress peer) ] + , case mbpid of PeerIdentityFull pid -> ("id":) $ map (maybe "<unnamed>" T.unpack . idName) (unfoldOwners pid) + _ -> [] + ] + cmdOut "peer-list-done" + + +cmdTestMessageSend :: Command +cmdTestMessageSend = do + spidx : trefs <- asks tiParams + st <- asks tiStorage + Just refs <- liftIO $ fmap sequence $ mapM (readRef st . encodeUtf8) trefs + peer <- getPeer spidx + sendManyToPeer peer $ map (TestMessage . wrappedLoad) refs + cmdOut "test-message-send done" + +cmdSharedStateGet :: Command +cmdSharedStateGet = do + h <- getHead + cmdOut $ unwords $ "shared-state-get" : map (show . refDigest . storedRef) (lsShared $ headObject h) + +cmdSharedStateWait :: Command +cmdSharedStateWait = do + st <- asks tiStorage + out <- asks tiOutput + h <- getOrLoadHead + trefs <- asks tiParams + + liftIO $ do + mvar <- newEmptyMVar + w <- watchHeadWith h (lsShared . headObject) $ \cur -> do + mbobjs <- mapM (readRef st . encodeUtf8) trefs + case map wrappedLoad <$> sequence mbobjs of + Just objs | filterAncestors (cur ++ objs) == cur -> do + outLine out $ unwords $ "shared-state-wait" : map T.unpack trefs + void $ forkIO $ unwatchHead =<< takeMVar mvar + _ -> return () + putMVar mvar w + +cmdWatchLocalIdentity :: Command +cmdWatchLocalIdentity = do + h <- getOrLoadHead + Nothing <- gets tsWatchedLocalIdentity + + out <- asks tiOutput + w <- liftIO $ watchHeadWith h headLocalIdentity $ \idt -> do + outLine out $ unwords $ "local-identity" : map (maybe "<unnamed>" T.unpack . idName) (unfoldOwners idt) + modify $ \s -> s { tsWatchedLocalIdentity = Just w } + +cmdWatchSharedIdentity :: Command +cmdWatchSharedIdentity = do + h <- getOrLoadHead + Nothing <- gets tsWatchedSharedIdentity + + out <- asks tiOutput + w <- liftIO $ watchHeadWith h (lookupSharedValue . lsShared . headObject) $ \case + Just (idt :: ComposedIdentity) -> do + outLine out $ unwords $ "shared-identity" : map (maybe "<unnamed>" T.unpack . idName) (unfoldOwners idt) + Nothing -> do + outLine out $ "shared-identity-failed" + modify $ \s -> s { tsWatchedSharedIdentity = Just w } + +cmdUpdateLocalIdentity :: Command +cmdUpdateLocalIdentity = do + [name] <- asks tiParams + updateLocalHead_ $ \ls -> do + Just identity <- return $ validateExtendedIdentity $ lsIdentity $ fromStored ls + let public = idKeyIdentity identity + + secret <- loadKey public + nidata <- maybe (error "created invalid identity") (return . idExtData) . validateExtendedIdentity =<< + mstore =<< sign secret =<< mstore . ExtendedIdentityData =<< return (emptyIdentityExtension $ idData identity) + { idePrev = toList $ idExtDataF identity + , ideName = Just name + } + mstore (fromStored ls) { lsIdentity = nidata } + +cmdUpdateSharedIdentity :: Command +cmdUpdateSharedIdentity = do + [name] <- asks tiParams + updateLocalHead_ $ updateSharedState_ $ \case + Nothing -> throwError "no existing shared identity" + Just identity -> do + let public = idKeyIdentity identity + secret <- loadKey public + uidentity <- mergeIdentity identity + maybe (error "created invalid identity") (return . Just . toComposedIdentity) . validateExtendedIdentity =<< + mstore =<< sign secret =<< mstore . ExtendedIdentityData =<< return (emptyIdentityExtension $ idData uidentity) + { idePrev = toList $ idExtDataF identity + , ideName = Just name + } + +cmdAttachTo :: Command +cmdAttachTo = do + [spidx] <- asks tiParams + attachToOwner =<< getPeer spidx + +cmdAttachAccept :: Command +cmdAttachAccept = do + [spidx] <- asks tiParams + attachAccept =<< getPeer spidx + +cmdAttachReject :: Command +cmdAttachReject = do + [spidx] <- asks tiParams + attachReject =<< getPeer spidx + +cmdContactRequest :: Command +cmdContactRequest = do + [spidx] <- asks tiParams + contactRequest =<< getPeer spidx + +cmdContactAccept :: Command +cmdContactAccept = do + [spidx] <- asks tiParams + contactAccept =<< getPeer spidx + +cmdContactReject :: Command +cmdContactReject = do + [spidx] <- asks tiParams + contactReject =<< getPeer spidx + +cmdContactList :: Command +cmdContactList = do + h <- getHead + let contacts = fromSetBy (comparing contactName) . lookupSharedValue . lsShared . headObject $ h + forM_ contacts $ \c -> do + r:_ <- return $ filterAncestors $ concatMap storedRoots $ toComponents c + cmdOut $ concat + [ "contact-list-item " + , show $ refDigest $ storedRef r + , " " + , T.unpack $ contactName c + , case contactIdentity c of Nothing -> ""; Just idt -> " " ++ T.unpack (displayIdentity idt) + ] + cmdOut "contact-list-done" + +getContact :: Text -> CommandM Contact +getContact cid = do + h <- getHead + let contacts = fromSetBy (comparing contactName) . lookupSharedValue . lsShared . headObject $ h + [contact] <- flip filterM contacts $ \c -> do + r:_ <- return $ filterAncestors $ concatMap storedRoots $ toComponents c + return $ T.pack (show $ refDigest $ storedRef r) == cid + return contact + +cmdContactSetName :: Command +cmdContactSetName = do + [cid, name] <- asks tiParams + contact <- getContact cid + updateLocalHead_ $ updateSharedState_ $ contactSetName contact name + cmdOut "contact-set-name-done" + +cmdDmSendPeer :: Command +cmdDmSendPeer = do + [spidx, msg] <- asks tiParams + PeerIdentityFull to <- peerIdentity =<< getPeer spidx + void $ sendDirectMessage to msg + +cmdDmSendContact :: Command +cmdDmSendContact = do + [cid, msg] <- asks tiParams + Just to <- contactIdentity <$> getContact cid + void $ sendDirectMessage to msg + +dmList :: Foldable f => Identity f -> Command +dmList peer = do + threads <- toThreadList . lookupSharedValue . lsShared . headObject <$> getHead + case find (sameIdentity peer . msgPeer) threads of + Just thread -> do + forM_ (reverse $ threadToList thread) $ \DirectMessage {..} -> cmdOut $ "dm-list-item" + <> " from " <> (maybe "<unnamed>" T.unpack $ idName msgFrom) + <> " text " <> (T.unpack msgText) + Nothing -> return () + cmdOut "dm-list-done" + +cmdDmListPeer :: Command +cmdDmListPeer = do + [spidx] <- asks tiParams + PeerIdentityFull to <- peerIdentity =<< getPeer spidx + dmList to + +cmdDmListContact :: Command +cmdDmListContact = do + [cid] <- asks tiParams + Just to <- contactIdentity <$> getContact cid + dmList to + +cmdChatroomCreate :: Command +cmdChatroomCreate = do + [name] <- asks tiParams + room <- createChatroom (Just name) Nothing + cmdOut $ unwords $ "chatroom-create-done" : chatroomInfo room + +getChatroomStateData :: Text -> CommandM (Stored ChatroomStateData) +getChatroomStateData tref = do + st <- asks tiStorage + Just ref <- liftIO $ readRef st (encodeUtf8 tref) + return $ wrappedLoad ref + +cmdChatroomSetName :: Command +cmdChatroomSetName = do + [cid, name] <- asks tiParams + sdata <- getChatroomStateData cid + updateChatroomByStateData sdata (Just name) Nothing >>= \case + Just room -> cmdOut $ unwords $ "chatroom-set-name-done" : chatroomInfo room + Nothing -> cmdOut "chatroom-set-name-failed" + +cmdChatroomListLocal :: Command +cmdChatroomListLocal = do + [] <- asks tiParams + rooms <- listChatrooms + forM_ rooms $ \room -> do + cmdOut $ unwords $ "chatroom-list-item" : chatroomInfo room + cmdOut "chatroom-list-done" + +cmdChatroomWatchLocal :: Command +cmdChatroomWatchLocal = do + [] <- asks tiParams + h <- getHead + out <- asks tiOutput + void $ watchChatrooms h $ \_ -> \case + Nothing -> return () + Just diff -> forM_ diff $ \case + AddedChatroom room -> outLine out $ unwords $ "chatroom-watched-added" : chatroomInfo room + RemovedChatroom room -> outLine out $ unwords $ "chatroom-watched-removed" : chatroomInfo room + UpdatedChatroom oldroom room -> do + when (any ((\rsd -> not (null (rsdRoom rsd)) || not (null (rsdSubscribe rsd))) . fromStored) (roomStateData room)) $ do + outLine out $ unwords $ concat + [ [ "chatroom-watched-updated" ], chatroomInfo room + , [ "old" ], map (show . refDigest . storedRef) (roomStateData oldroom) + , [ "new" ], map (show . refDigest . storedRef) (roomStateData room) + ] + when (any (not . null . rsdMessages . fromStored) (roomStateData room)) $ do + forM_ (reverse $ getMessagesSinceState room oldroom) $ \msg -> do + outLine out $ unwords $ concat + [ [ "chatroom-message-new" ] + , [ show . refDigest . storedRef . head . filterAncestors . concatMap storedRoots . toComponents $ room ] + , [ "room", maybe "<unnamed>" T.unpack $ roomName =<< cmsgRoom msg ] + , [ "from", maybe "<unnamed>" T.unpack $ idName $ cmsgFrom msg ] + , if cmsgLeave msg then [ "leave" ] else [] + , maybe [] (("text":) . (:[]) . T.unpack) $ cmsgText msg + ] + +chatroomInfo :: ChatroomState -> [String] +chatroomInfo room = + [ show . refDigest . storedRef . head . filterAncestors . concatMap storedRoots . toComponents $ room + , maybe "<unnamed>" T.unpack $ roomName =<< roomStateRoom room + , "sub " <> bool "false" "true" (roomStateSubscribe room) + ] + +cmdChatroomSubscribe :: Command +cmdChatroomSubscribe = do + [ cid ] <- asks tiParams + to <- getChatroomStateData cid + void $ chatroomSetSubscribe to True + +cmdChatroomUnsubscribe :: Command +cmdChatroomUnsubscribe = do + [ cid ] <- asks tiParams + to <- getChatroomStateData cid + void $ chatroomSetSubscribe to False + +cmdChatroomMembers :: Command +cmdChatroomMembers = do + [ cid ] <- asks tiParams + Just chatroom <- findChatroomByStateData =<< getChatroomStateData cid + forM_ (chatroomMembers chatroom) $ \user -> do + cmdOut $ unwords [ "chatroom-members-item", maybe "<unnamed>" T.unpack $ idName user ] + cmdOut "chatroom-members-done" + +cmdChatroomJoin :: Command +cmdChatroomJoin = do + [ cid ] <- asks tiParams + joinChatroomByStateData =<< getChatroomStateData cid + cmdOut "chatroom-join-done" + +cmdChatroomLeave :: Command +cmdChatroomLeave = do + [ cid ] <- asks tiParams + leaveChatroomByStateData =<< getChatroomStateData cid + cmdOut "chatroom-leave-done" + +cmdChatroomMessageSend :: Command +cmdChatroomMessageSend = do + [cid, msg] <- asks tiParams + to <- getChatroomStateData cid + void $ sendChatroomMessageByStateData to msg diff --git a/main/Test/Service.hs b/main/Test/Service.hs new file mode 100644 index 0000000..1018e0d --- /dev/null +++ b/main/Test/Service.hs @@ -0,0 +1,36 @@ +module Test.Service ( + TestMessage(..), + TestMessageAttributes(..), +) where + +import Control.Monad.Reader + +import Data.ByteString.Lazy.Char8 qualified as BL + +import Erebos.Network +import Erebos.Service +import Erebos.Storage + +data TestMessage = TestMessage (Stored Object) + +data TestMessageAttributes = TestMessageAttributes + { testMessageReceived :: String -> String -> String -> ServiceHandler TestMessage () + } + +instance Storable TestMessage where + store' (TestMessage msg) = store' msg + load' = TestMessage <$> load' + +instance Service TestMessage where + serviceID _ = mkServiceID "cb46b92c-9203-4694-8370-8742d8ac9dc8" + + type ServiceAttributes TestMessage = TestMessageAttributes + defaultServiceAttributes _ = TestMessageAttributes (\_ _ _ -> return ()) + + serviceHandler smsg = do + let TestMessage sobj = fromStored smsg + case map BL.unpack $ BL.words $ BL.takeWhile (/='\n') $ serializeObject $ fromStored sobj of + [otype, len] -> do + cb <- asks $ testMessageReceived . svcAttributes + cb otype len (show $ refDigest $ storedRef sobj) + _ -> return () diff --git a/main/Version.hs b/main/Version.hs new file mode 100644 index 0000000..71af694 --- /dev/null +++ b/main/Version.hs @@ -0,0 +1,23 @@ +{-# LANGUAGE TemplateHaskell #-} + +-- "Pattern match is redundant" warning can be generated based on template +-- haskell $$tGitVersion value +{-# OPTIONS_GHC -Wno-error=overlapping-patterns #-} + +module Version ( + versionLine, +) where + +import Paths_erebos (version) +import Data.Version (showVersion) +import Version.Git + +{-# NOINLINE versionLine #-} +versionLine :: String +versionLine = do + let ver = case $$tGitVersion of + Just gver + | 'v':v <- gver, not $ all (`elem` ('.': ['0'..'9'])) v + -> "git " <> gver + _ -> "version " <> showVersion version + in "Erebos CLI " <> ver diff --git a/main/Version/Git.hs b/main/Version/Git.hs new file mode 100644 index 0000000..2aae6e3 --- /dev/null +++ b/main/Version/Git.hs @@ -0,0 +1,31 @@ +module Version.Git ( + tGitVersion, +) where + +import Language.Haskell.TH +import Language.Haskell.TH.Syntax + +import System.Directory +import System.Exit +import System.Process + +tGitVersion :: Code Q (Maybe String) +tGitVersion = unsafeCodeCoerce $ do + let git args = do + (ExitSuccess, out, _) <- readCreateProcessWithExitCode + (proc "git" $ [ "--git-dir=./.git", "--work-tree=." ] ++ args) "" + return $ lines out + + mbver <- runIO $ do + doesPathExist "./.git" >>= \case + False -> return Nothing + True -> do + desc:_ <- git [ "describe", "--always", "--dirty= (dirty)" ] + files <- git [ "ls-files" ] + return $ Just (desc, files) + + case mbver of + Just (_, files) -> mapM_ addDependentFile files + Nothing -> return () + + lift (fst <$> mbver :: Maybe String) diff --git a/minici.yaml b/minici.yaml new file mode 100644 index 0000000..333878c --- /dev/null +++ b/minici.yaml @@ -0,0 +1,13 @@ +job build: + shell: + - cabal build -fci + - mkdir build + - cp $(cabal list-bin erebos) build/erebos + artifact erebos: + path: build/erebos + +job test: + uses: + - build.erebos + shell: + - EREBOS_TEST_TOOL='build/erebos test' erebos-tester -v diff --git a/src/Erebos/Attach.hs b/src/Erebos/Attach.hs new file mode 100644 index 0000000..bd2f521 --- /dev/null +++ b/src/Erebos/Attach.hs @@ -0,0 +1,123 @@ +module Erebos.Attach ( + AttachService, + attachToOwner, + attachAccept, + attachReject, +) where + +import Control.Monad +import Control.Monad.Except +import Control.Monad.Reader + +import Data.ByteArray (ScrubbedBytes) +import Data.Maybe +import Data.Proxy +import qualified Data.Text as T + +import Erebos.Identity +import Erebos.Network +import Erebos.Pairing +import Erebos.PubKey +import Erebos.Service +import Erebos.State +import Erebos.Storage +import Erebos.Storage.Key + +type AttachService = PairingService AttachIdentity + +data AttachIdentity = AttachIdentity (Stored (Signed IdentityData)) [ScrubbedBytes] + +instance Storable AttachIdentity where + store' (AttachIdentity x keys) = storeRec $ do + storeRef "identity" x + mapM_ (storeBinary "skey") keys + + load' = loadRec $ AttachIdentity + <$> loadRef "identity" + <*> loadBinaries "skey" + +instance PairingResult AttachIdentity where + pairingServiceID _ = mkServiceID "4995a5f9-2d4d-48e9-ad3b-0bf1c2a1be7f" + + type PairingVerifiedResult AttachIdentity = (UnifiedIdentity, [ScrubbedBytes]) + + pairingVerifyResult (AttachIdentity sdata keys) = do + curid <- lsIdentity . fromStored <$> svcGetLocal + secret <- loadKey $ eiddKeyIdentity $ fromSigned curid + sdata' <- mstore =<< signAdd secret (fromStored sdata) + return $ do + guard $ iddKeyIdentity (fromSigned sdata) == + eiddKeyIdentity (fromSigned curid) + identity <- validateIdentity sdata' + guard $ iddPrev (fromSigned $ idData identity) == [eiddStoredBase curid] + return (identity, keys) + + pairingFinalizeRequest (identity, keys) = updateLocalHead_ $ \slocal -> do + let owner = finalOwner identity + st <- getStorage + pkeys <- mapM (copyStored st) [ idKeyIdentity owner, idKeyMessage owner ] + liftIO $ mapM_ storeKey $ catMaybes [ keyFromData sec pub | sec <- keys, pub <- pkeys ] + + identity' <- mergeIdentity $ updateIdentity [ lsIdentity $ fromStored slocal ] identity + shared <- makeSharedStateUpdate st (Just owner) (lsShared $ fromStored slocal) + mstore (fromStored slocal) + { lsIdentity = idExtData identity' + , lsShared = [ shared ] + } + + pairingFinalizeResponse = do + owner <- mergeSharedIdentity + pid <- asks svcPeerIdentity + secret <- loadKey $ idKeyIdentity owner + identity <- mstore =<< sign secret =<< mstore (emptyIdentityData $ idKeyIdentity pid) + { iddPrev = [idData pid], iddOwner = Just (idData owner) } + skeys <- map keyGetData . catMaybes <$> mapM loadKeyMb [ idKeyIdentity owner, idKeyMessage owner ] + return $ AttachIdentity identity skeys + + defaultPairingAttributes _ = PairingAttributes + { pairingHookRequest = do + peer <- asks $ svcPeerIdentity + svcPrint $ "Attach from " ++ T.unpack (displayIdentity peer) ++ " initiated" + + , pairingHookResponse = \confirm -> do + peer <- asks $ svcPeerIdentity + svcPrint $ "Attach to " ++ T.unpack (displayIdentity peer) ++ ": " ++ confirm + + , pairingHookRequestNonce = \confirm -> do + peer <- asks $ svcPeerIdentity + svcPrint $ "Attach from " ++ T.unpack (displayIdentity peer) ++ ": " ++ confirm + + , pairingHookRequestNonceFailed = do + peer <- asks $ svcPeerIdentity + svcPrint $ "Failed attach from " ++ T.unpack (displayIdentity peer) + + , pairingHookConfirmedResponse = do + svcPrint $ "Confirmed peer, waiting for updated identity" + + , pairingHookConfirmedRequest = do + svcPrint $ "Attachment confirmed by peer" + + , pairingHookAcceptedResponse = do + svcPrint $ "Accepted updated identity" + + , pairingHookAcceptedRequest = do + svcPrint $ "Accepted new attached device, seding updated identity" + + , pairingHookVerifyFailed = do + svcPrint $ "Failed to verify new identity" + + , pairingHookRejected = do + svcPrint $ "Attachment rejected by peer" + + , pairingHookFailed = \_ -> do + svcPrint $ "Attachement failed" + } + +attachToOwner :: (MonadIO m, MonadError String m) => Peer -> m () +attachToOwner = pairingRequest @AttachIdentity Proxy + +attachAccept :: (MonadIO m, MonadError String m) => Peer -> m () +attachAccept = pairingAccept @AttachIdentity Proxy + +attachReject :: (MonadIO m, MonadError String m) => Peer -> m () +attachReject = pairingReject @AttachIdentity Proxy diff --git a/src/Erebos/Channel.hs b/src/Erebos/Channel.hs new file mode 100644 index 0000000..5f66637 --- /dev/null +++ b/src/Erebos/Channel.hs @@ -0,0 +1,175 @@ +module Erebos.Channel ( + Channel, + ChannelRequest, ChannelRequestData(..), + ChannelAccept, ChannelAcceptData(..), + + createChannelRequest, + acceptChannelRequest, + acceptedChannel, + + channelEncrypt, + channelDecrypt, +) where + +import Control.Concurrent.MVar +import Control.Monad +import Control.Monad.Except +import Control.Monad.IO.Class + +import Crypto.Cipher.ChaChaPoly1305 +import Crypto.Error + +import Data.Binary +import Data.ByteArray (ByteArray, Bytes, ScrubbedBytes, convert) +import Data.ByteArray qualified as BA +import Data.ByteString.Lazy qualified as BL +import Data.List + +import Erebos.Identity +import Erebos.PubKey +import Erebos.Storage + +data Channel = Channel + { chPeers :: [Stored (Signed IdentityData)] + , chKey :: ScrubbedBytes + , chNonceFixedOur :: Bytes + , chNonceFixedPeer :: Bytes + , chCounterNextOut :: MVar Word64 + , chCounterNextIn :: MVar Word64 + } + +type ChannelRequest = Signed ChannelRequestData + +data ChannelRequestData = ChannelRequest + { crPeers :: [Stored (Signed IdentityData)] + , crKey :: Stored PublicKexKey + } + deriving (Show) + +type ChannelAccept = Signed ChannelAcceptData + +data ChannelAcceptData = ChannelAccept + { caRequest :: Stored ChannelRequest + , caKey :: Stored PublicKexKey + } + + +instance Storable ChannelRequestData where + store' cr = storeRec $ do + mapM_ (storeRef "peer") $ crPeers cr + storeRef "key" $ crKey cr + + load' = loadRec $ do + ChannelRequest + <$> loadRefs "peer" + <*> loadRef "key" + +instance Storable ChannelAcceptData where + store' ca = storeRec $ do + storeRef "req" $ caRequest ca + storeRef "key" $ caKey ca + + load' = loadRec $ do + ChannelAccept + <$> loadRef "req" + <*> loadRef "key" + + +keySize :: Int +keySize = 32 + +createChannelRequest :: (MonadStorage m, MonadIO m, MonadError String m) => UnifiedIdentity -> UnifiedIdentity -> m (Stored ChannelRequest) +createChannelRequest self peer = do + (_, xpublic) <- liftIO . generateKeys =<< getStorage + skey <- loadKey $ idKeyMessage self + mstore =<< sign skey =<< mstore ChannelRequest { crPeers = sort [idData self, idData peer], crKey = xpublic } + +acceptChannelRequest :: (MonadStorage m, MonadIO m, MonadError String m) => UnifiedIdentity -> UnifiedIdentity -> Stored ChannelRequest -> m (Stored ChannelAccept, Channel) +acceptChannelRequest self peer req = do + case sequence $ map validateIdentity $ crPeers $ fromStored $ signedData $ fromStored req of + Nothing -> throwError $ "invalid peers in channel request" + Just peers -> do + when (not $ any (self `sameIdentity`) peers) $ + throwError $ "self identity missing in channel request peers" + when (not $ any (peer `sameIdentity`) peers) $ + throwError $ "peer identity missing in channel request peers" + when (idKeyMessage peer `notElem` (map (sigKey . fromStored) $ signedSignature $ fromStored req)) $ + throwError $ "channel requent not signed by peer" + + (xsecret, xpublic) <- liftIO . generateKeys =<< getStorage + skey <- loadKey $ idKeyMessage self + acc <- mstore =<< sign skey =<< mstore ChannelAccept { caRequest = req, caKey = xpublic } + liftIO $ do + let chPeers = crPeers $ fromStored $ signedData $ fromStored req + chKey = BA.take keySize $ dhSecret xsecret $ + fromStored $ crKey $ fromStored $ signedData $ fromStored req + chNonceFixedOur = BA.pack [ 2, 0, 0, 0 ] + chNonceFixedPeer = BA.pack [ 1, 0, 0, 0 ] + chCounterNextOut <- newMVar 0 + chCounterNextIn <- newMVar 0 + + return (acc, Channel {..}) + +acceptedChannel :: (MonadIO m, MonadError String m) => UnifiedIdentity -> UnifiedIdentity -> Stored ChannelAccept -> m Channel +acceptedChannel self peer acc = do + let req = caRequest $ fromStored $ signedData $ fromStored acc + case sequence $ map validateIdentity $ crPeers $ fromStored $ signedData $ fromStored req of + Nothing -> throwError $ "invalid peers in channel accept" + Just peers -> do + when (not $ any (self `sameIdentity`) peers) $ + throwError $ "self identity missing in channel accept peers" + when (not $ any (peer `sameIdentity`) peers) $ + throwError $ "peer identity missing in channel accept peers" + when (idKeyMessage peer `notElem` (map (sigKey . fromStored) $ signedSignature $ fromStored acc)) $ + throwError $ "channel accept not signed by peer" + when (idKeyMessage self `notElem` (map (sigKey . fromStored) $ signedSignature $ fromStored req)) $ + throwError $ "original channel request not signed by us" + + xsecret <- loadKey $ crKey $ fromStored $ signedData $ fromStored req + let chPeers = crPeers $ fromStored $ signedData $ fromStored req + chKey = BA.take keySize $ dhSecret xsecret $ + fromStored $ caKey $ fromStored $ signedData $ fromStored acc + chNonceFixedOur = BA.pack [ 1, 0, 0, 0 ] + chNonceFixedPeer = BA.pack [ 2, 0, 0, 0 ] + chCounterNextOut <- liftIO $ newMVar 0 + chCounterNextIn <- liftIO $ newMVar 0 + + return Channel {..} + + +channelEncrypt :: (ByteArray ba, MonadIO m, MonadError String m) => Channel -> ba -> m (ba, Word64) +channelEncrypt Channel {..} plain = do + count <- liftIO $ modifyMVar chCounterNextOut $ \c -> return (c + 1, c) + let cbytes = convert $ BL.toStrict $ encode count + nonce = nonce8 chNonceFixedOur cbytes + state <- case initialize chKey =<< nonce of + CryptoPassed state -> return state + CryptoFailed err -> throwError $ "failed to init chacha-poly1305 cipher: " <> show err + + let (ctext, state') = encrypt plain state + tag = finalize state' + return (BA.concat [ convert $ BA.drop 7 cbytes, ctext, convert tag ], count) + +channelDecrypt :: (ByteArray ba, MonadIO m, MonadError String m) => Channel -> ba -> m (ba, Word64) +channelDecrypt Channel {..} body = do + when (BA.length body < 17) $ do + throwError $ "invalid encrypted data length" + + expectedCount <- liftIO $ readMVar chCounterNextIn + let countByte = body `BA.index` 0 + body' = BA.dropView body 1 + guessedCount = expectedCount - 128 + fromIntegral (countByte - fromIntegral expectedCount + 128 :: Word8) + nonce = nonce8 chNonceFixedPeer $ convert $ BL.toStrict $ encode guessedCount + blen = BA.length body' - 16 + ctext = BA.takeView body' blen + tag = BA.dropView body' blen + state <- case initialize chKey =<< nonce of + CryptoPassed state -> return state + CryptoFailed err -> throwError $ "failed to init chacha-poly1305 cipher: " <> show err + + let (plain, state') = decrypt (convert ctext) state + when (not $ tag `BA.constEq` finalize state') $ do + throwError $ "tag validation falied" + + liftIO $ modifyMVar_ chCounterNextIn $ return . max (guessedCount + 1) + return (plain, guessedCount) diff --git a/src/Erebos/Chatroom.hs b/src/Erebos/Chatroom.hs new file mode 100644 index 0000000..c8b5805 --- /dev/null +++ b/src/Erebos/Chatroom.hs @@ -0,0 +1,609 @@ +module Erebos.Chatroom ( + Chatroom(..), + ChatroomData(..), + validateChatroom, + + ChatroomState(..), + ChatroomStateData(..), + createChatroom, + updateChatroomByStateData, + listChatrooms, + findChatroomByRoomData, + findChatroomByStateData, + chatroomSetSubscribe, + chatroomMembers, + joinChatroom, joinChatroomByStateData, + leaveChatroom, leaveChatroomByStateData, + getMessagesSinceState, + + ChatroomSetChange(..), + watchChatrooms, + + ChatMessage, + cmsgFrom, cmsgReplyTo, cmsgTime, cmsgText, cmsgLeave, + cmsgRoom, cmsgRoomData, + ChatMessageData(..), + sendChatroomMessage, + sendChatroomMessageByStateData, + + ChatroomService(..), +) where + +import Control.Arrow +import Control.Monad +import Control.Monad.Except +import Control.Monad.IO.Class + +import Data.Bool +import Data.Either +import Data.Foldable +import Data.Function +import Data.IORef +import Data.List +import Data.Maybe +import Data.Monoid +import Data.Ord +import Data.Set qualified as S +import Data.Text (Text) +import Data.Time + +import Erebos.Identity +import Erebos.PubKey +import Erebos.Service +import Erebos.Set +import Erebos.State +import Erebos.Storage +import Erebos.Storage.Merge +import Erebos.Util + + +data ChatroomData = ChatroomData + { rdPrev :: [Stored (Signed ChatroomData)] + , rdName :: Maybe Text + , rdDescription :: Maybe Text + , rdKey :: Stored PublicKey + } + +data Chatroom = Chatroom + { roomData :: [Stored (Signed ChatroomData)] + , roomName :: Maybe Text + , roomDescription :: Maybe Text + , roomKey :: Stored PublicKey + } + +instance Storable ChatroomData where + store' ChatroomData {..} = storeRec $ do + mapM_ (storeRef "SPREV") rdPrev + storeMbText "name" rdName + storeMbText "description" rdDescription + storeRef "key" rdKey + + load' = loadRec $ do + rdPrev <- loadRefs "SPREV" + rdName <- loadMbText "name" + rdDescription <- loadMbText "description" + rdKey <- loadRef "key" + return ChatroomData {..} + +validateChatroom :: [Stored (Signed ChatroomData)] -> Except String Chatroom +validateChatroom roomData = do + when (null roomData) $ throwError "null data" + when (not $ getAll $ walkAncestors verifySignatures roomData) $ do + throwError "signature verification failed" + + let roomName = findPropertyFirst (rdName . fromStored . signedData) roomData + roomDescription = findPropertyFirst (rdDescription . fromStored . signedData) roomData + roomKey <- maybe (throwError "missing key") return $ + findPropertyFirst (Just . rdKey . fromStored . signedData) roomData + return Chatroom {..} + where + verifySignatures sdata = + let rdata = fromSigned sdata + required = concat + [ [ rdKey rdata ] + , map (rdKey . fromSigned) $ rdPrev rdata + ] + in All $ all (fromStored sdata `isSignedBy`) required + + +data ChatMessageData = ChatMessageData + { mdPrev :: [Stored (Signed ChatMessageData)] + , mdRoom :: [Stored (Signed ChatroomData)] + , mdFrom :: ComposedIdentity + , mdReplyTo :: Maybe (Stored (Signed ChatMessageData)) + , mdTime :: ZonedTime + , mdText :: Maybe Text + , mdLeave :: Bool + } + +data ChatMessage = ChatMessage + { cmsgData :: Stored (Signed ChatMessageData) + } + +validateSingleMessage :: Stored (Signed ChatMessageData) -> Maybe ChatMessage +validateSingleMessage sdata = do + guard $ fromStored sdata `isSignedBy` idKeyMessage (mdFrom (fromSigned sdata)) + return $ ChatMessage sdata + +cmsgFrom :: ChatMessage -> ComposedIdentity +cmsgFrom = mdFrom . fromSigned . cmsgData + +cmsgReplyTo :: ChatMessage -> Maybe ChatMessage +cmsgReplyTo = fmap ChatMessage . mdReplyTo . fromSigned . cmsgData + +cmsgTime :: ChatMessage -> ZonedTime +cmsgTime = mdTime . fromSigned . cmsgData + +cmsgText :: ChatMessage -> Maybe Text +cmsgText = mdText . fromSigned . cmsgData + +cmsgLeave :: ChatMessage -> Bool +cmsgLeave = mdLeave . fromSigned . cmsgData + +cmsgRoom :: ChatMessage -> Maybe Chatroom +cmsgRoom = either (const Nothing) Just . runExcept . validateChatroom . cmsgRoomData + +cmsgRoomData :: ChatMessage -> [ Stored (Signed ChatroomData) ] +cmsgRoomData = concat . findProperty ((\case [] -> Nothing; xs -> Just xs) . mdRoom . fromStored . signedData) . (: []) . cmsgData + +instance Storable ChatMessageData where + store' ChatMessageData {..} = storeRec $ do + mapM_ (storeRef "SPREV") mdPrev + mapM_ (storeRef "room") mdRoom + mapM_ (storeRef "from") $ idExtDataF mdFrom + storeMbRef "reply-to" mdReplyTo + storeDate "time" mdTime + storeMbText "text" mdText + when mdLeave $ storeEmpty "leave" + + load' = loadRec $ do + mdPrev <- loadRefs "SPREV" + mdRoom <- loadRefs "room" + mdFrom <- loadIdentity "from" + mdReplyTo <- loadMbRef "reply-to" + mdTime <- loadDate "time" + mdText <- loadMbText "text" + mdLeave <- isJust <$> loadMbEmpty "leave" + return ChatMessageData {..} + +threadToListSince :: [ Stored (Signed ChatMessageData) ] -> [ Stored (Signed ChatMessageData) ] -> [ ChatMessage ] +threadToListSince since thread = helper (S.fromList since) thread + where + helper :: S.Set (Stored (Signed ChatMessageData)) -> [Stored (Signed ChatMessageData)] -> [ChatMessage] + helper seen msgs + | msg : msgs' <- filter (`S.notMember` seen) $ reverse $ sortBy (comparing cmpView) msgs = + maybe id (:) (validateSingleMessage msg) $ + helper (S.insert msg seen) (msgs' ++ mdPrev (fromSigned msg)) + | otherwise = [] + cmpView msg = (zonedTimeToUTC $ mdTime $ fromSigned msg, msg) + +sendChatroomMessage + :: (MonadStorage m, MonadHead LocalState m, MonadError String m) + => ChatroomState -> Text -> m () +sendChatroomMessage rstate msg = sendChatroomMessageByStateData (head $ roomStateData rstate) msg + +sendChatroomMessageByStateData + :: (MonadStorage m, MonadHead LocalState m, MonadError String m) + => Stored ChatroomStateData -> Text -> m () +sendChatroomMessageByStateData lookupData msg = sendRawChatroomMessageByStateData lookupData Nothing (Just msg) False + +sendRawChatroomMessageByStateData + :: (MonadStorage m, MonadHead LocalState m, MonadError String m) + => Stored ChatroomStateData -> Maybe (Stored (Signed ChatMessageData)) -> Maybe Text -> Bool -> m () +sendRawChatroomMessageByStateData lookupData mdReplyTo mdText mdLeave = void $ findAndUpdateChatroomState $ \cstate -> do + guard $ any (lookupData `precedesOrEquals`) $ roomStateData cstate + Just $ do + mdFrom <- finalOwner . localIdentity . fromStored <$> getLocalHead + secret <- loadKey $ idKeyMessage mdFrom + mdTime <- liftIO getZonedTime + let mdPrev = roomStateMessageData cstate + mdRoom = if null (roomStateMessageData cstate) + then maybe [] roomData (roomStateRoom cstate) + else [] + + mdata <- mstore =<< sign secret =<< mstore ChatMessageData {..} + mergeSorted . (:[]) <$> mstore ChatroomStateData + { rsdPrev = roomStateData cstate + , rsdRoom = [] + , rsdSubscribe = Just True + , rsdMessages = [ mdata ] + } + + +data ChatroomStateData = ChatroomStateData + { rsdPrev :: [Stored ChatroomStateData] + , rsdRoom :: [Stored (Signed ChatroomData)] + , rsdSubscribe :: Maybe Bool + , rsdMessages :: [Stored (Signed ChatMessageData)] + } + +data ChatroomState = ChatroomState + { roomStateData :: [Stored ChatroomStateData] + , roomStateRoom :: Maybe Chatroom + , roomStateMessageData :: [Stored (Signed ChatMessageData)] + , roomStateSubscribe :: Bool + , roomStateMessages :: [ChatMessage] + } + +instance Storable ChatroomStateData where + store' ChatroomStateData {..} = storeRec $ do + forM_ rsdPrev $ storeRef "PREV" + forM_ rsdRoom $ storeRef "room" + forM_ rsdSubscribe $ storeInt "subscribe" . bool @Int 0 1 + forM_ rsdMessages $ storeRef "msg" + + load' = loadRec $ do + rsdPrev <- loadRefs "PREV" + rsdRoom <- loadRefs "room" + rsdSubscribe <- fmap ((/=) @Int 0) <$> loadMbInt "subscribe" + rsdMessages <- loadRefs "msg" + return ChatroomStateData {..} + +instance Mergeable ChatroomState where + type Component ChatroomState = ChatroomStateData + + mergeSorted roomStateData = + let roomStateRoom = either (const Nothing) Just $ runExcept $ + validateChatroom $ concat $ findProperty ((\case [] -> Nothing; xs -> Just xs) . rsdRoom) roomStateData + roomStateMessageData = filterAncestors $ concat $ flip findProperty roomStateData $ \case + ChatroomStateData {..} | null rsdMessages -> Nothing + | otherwise -> Just rsdMessages + roomStateSubscribe = fromMaybe False $ findPropertyFirst rsdSubscribe roomStateData + roomStateMessages = threadToListSince [] $ concatMap (rsdMessages . fromStored) roomStateData + in ChatroomState {..} + + toComponents = roomStateData + +instance SharedType (Set ChatroomState) where + sharedTypeID _ = mkSharedTypeID "7bc71cbf-bc43-42b1-b413-d3a2c9a2aae0" + +createChatroom :: (MonadStorage m, MonadHead LocalState m, MonadIO m, MonadError String m) => Maybe Text -> Maybe Text -> m ChatroomState +createChatroom rdName rdDescription = do + (secret, rdKey) <- liftIO . generateKeys =<< getStorage + let rdPrev = [] + rdata <- mstore =<< sign secret =<< mstore ChatroomData {..} + cstate <- mergeSorted . (:[]) <$> mstore ChatroomStateData + { rsdPrev = [] + , rsdRoom = [ rdata ] + , rsdSubscribe = Just True + , rsdMessages = [] + } + + updateLocalHead $ updateSharedState $ \rooms -> do + st <- getStorage + (, cstate) <$> storeSetAdd st cstate rooms + +findAndUpdateChatroomState + :: (MonadStorage m, MonadHead LocalState m) + => (ChatroomState -> Maybe (m ChatroomState)) + -> m (Maybe ChatroomState) +findAndUpdateChatroomState f = do + updateLocalHead $ updateSharedState $ \roomSet -> do + let roomList = fromSetBy (comparing $ roomName <=< roomStateRoom) roomSet + case catMaybes $ map (\x -> (x,) <$> f x) roomList of + ((orig, act) : _) -> do + upd <- act + if roomStateData orig /= roomStateData upd + then do + st <- getStorage + roomSet' <- storeSetAdd st upd roomSet + return (roomSet', Just upd) + else do + return (roomSet, Just upd) + [] -> return (roomSet, Nothing) + +updateChatroomByStateData + :: (MonadStorage m, MonadHead LocalState m, MonadError String m) + => Stored ChatroomStateData + -> Maybe Text + -> Maybe Text + -> m (Maybe ChatroomState) +updateChatroomByStateData lookupData newName newDesc = findAndUpdateChatroomState $ \cstate -> do + guard $ any (lookupData `precedesOrEquals`) $ roomStateData cstate + room <- roomStateRoom cstate + Just $ do + secret <- loadKey $ roomKey room + rdata <- mstore =<< sign secret =<< mstore ChatroomData + { rdPrev = roomData room + , rdName = newName + , rdDescription = newDesc + , rdKey = roomKey room + } + mergeSorted . (:[]) <$> mstore ChatroomStateData + { rsdPrev = roomStateData cstate + , rsdRoom = [ rdata ] + , rsdSubscribe = Just True + , rsdMessages = [] + } + + +listChatrooms :: MonadHead LocalState m => m [ChatroomState] +listChatrooms = fromSetBy (comparing $ roomName <=< roomStateRoom) . + lookupSharedValue . lsShared . fromStored <$> getLocalHead + +findChatroom :: MonadHead LocalState m => (ChatroomState -> Bool) -> m (Maybe ChatroomState) +findChatroom p = do + list <- map snd . chatroomSetToList . lookupSharedValue . lsShared . fromStored <$> getLocalHead + return $ find p list + +findChatroomByRoomData :: MonadHead LocalState m => Stored (Signed ChatroomData) -> m (Maybe ChatroomState) +findChatroomByRoomData cdata = findChatroom $ + maybe False (any (cdata `precedesOrEquals`) . roomData) . roomStateRoom + +findChatroomByStateData :: MonadHead LocalState m => Stored ChatroomStateData -> m (Maybe ChatroomState) +findChatroomByStateData cdata = findChatroom $ any (cdata `precedesOrEquals`) . roomStateData + +chatroomSetSubscribe + :: (MonadStorage m, MonadHead LocalState m, MonadError String m) + => Stored ChatroomStateData -> Bool -> m () +chatroomSetSubscribe lookupData subscribe = void $ findAndUpdateChatroomState $ \cstate -> do + guard $ any (lookupData `precedesOrEquals`) $ roomStateData cstate + Just $ do + mergeSorted . (:[]) <$> mstore ChatroomStateData + { rsdPrev = roomStateData cstate + , rsdRoom = [] + , rsdSubscribe = Just subscribe + , rsdMessages = [] + } + +chatroomMembers :: ChatroomState -> [ ComposedIdentity ] +chatroomMembers ChatroomState {..} = + map (mdFrom . fromSigned . head) $ + filter (any $ not . mdLeave . fromSigned) $ -- keep only users that hasn't left + map (filterAncestors . map snd) $ -- gather message data per each identity and filter ancestors + groupBy ((==) `on` fst) $ -- group on identity root + sortBy (comparing fst) $ -- sort by first root of identity data + map (\x -> ( head . filterAncestors . concatMap storedRoots . idDataF . mdFrom . fromSigned $ x, x )) $ + toList $ ancestors $ roomStateMessageData + +joinChatroom + :: (MonadStorage m, MonadHead LocalState m, MonadError String m) + => ChatroomState -> m () +joinChatroom rstate = joinChatroomByStateData (head $ roomStateData rstate) + +joinChatroomByStateData + :: (MonadStorage m, MonadHead LocalState m, MonadError String m) + => Stored ChatroomStateData -> m () +joinChatroomByStateData lookupData = sendRawChatroomMessageByStateData lookupData Nothing Nothing False + +leaveChatroom + :: (MonadStorage m, MonadHead LocalState m, MonadError String m) + => ChatroomState -> m () +leaveChatroom rstate = leaveChatroomByStateData (head $ roomStateData rstate) + +leaveChatroomByStateData + :: (MonadStorage m, MonadHead LocalState m, MonadError String m) + => Stored ChatroomStateData -> m () +leaveChatroomByStateData lookupData = sendRawChatroomMessageByStateData lookupData Nothing Nothing True + +getMessagesSinceState :: ChatroomState -> ChatroomState -> [ChatMessage] +getMessagesSinceState cur old = threadToListSince (roomStateMessageData old) (roomStateMessageData cur) + + +data ChatroomSetChange = AddedChatroom ChatroomState + | RemovedChatroom ChatroomState + | UpdatedChatroom ChatroomState ChatroomState + +watchChatrooms :: MonadIO m => Head LocalState -> (Set ChatroomState -> Maybe [ChatroomSetChange] -> IO ()) -> m WatchedHead +watchChatrooms h f = liftIO $ do + lastVar <- newIORef Nothing + watchHeadWith h (lookupSharedValue . lsShared . headObject) $ \cur -> do + let curList = chatroomSetToList cur + mbLast <- readIORef lastVar + writeIORef lastVar $ Just curList + f cur $ do + lastList <- mbLast + return $ makeChatroomDiff lastList curList + +chatroomSetToList :: Set ChatroomState -> [(Stored ChatroomStateData, ChatroomState)] +chatroomSetToList = map (cmp &&& id) . fromSetBy (comparing cmp) + where + cmp :: ChatroomState -> Stored ChatroomStateData + cmp = head . filterAncestors . concatMap storedRoots . toComponents + +makeChatroomDiff + :: [(Stored ChatroomStateData, ChatroomState)] + -> [(Stored ChatroomStateData, ChatroomState)] + -> [ChatroomSetChange] +makeChatroomDiff (x@(cx, vx) : xs) (y@(cy, vy) : ys) + | cx < cy = RemovedChatroom vx : makeChatroomDiff xs (y : ys) + | cx > cy = AddedChatroom vy : makeChatroomDiff (x : xs) ys + | roomStateData vx /= roomStateData vy = UpdatedChatroom vx vy : makeChatroomDiff xs ys + | otherwise = makeChatroomDiff xs ys +makeChatroomDiff xs [] = map (RemovedChatroom . snd) xs +makeChatroomDiff [] ys = map (AddedChatroom . snd) ys + + +data ChatroomService = ChatroomService + { chatRoomQuery :: Bool + , chatRoomInfo :: [Stored (Signed ChatroomData)] + , chatRoomSubscribe :: [Stored (Signed ChatroomData)] + , chatRoomUnsubscribe :: [Stored (Signed ChatroomData)] + , chatRoomMessage :: [Stored (Signed ChatMessageData)] + } + deriving (Eq) + +emptyPacket :: ChatroomService +emptyPacket = ChatroomService + { chatRoomQuery = False + , chatRoomInfo = [] + , chatRoomSubscribe = [] + , chatRoomUnsubscribe = [] + , chatRoomMessage = [] + } + +instance Storable ChatroomService where + store' ChatroomService {..} = storeRec $ do + when chatRoomQuery $ storeEmpty "room-query" + forM_ chatRoomInfo $ storeRef "room-info" + forM_ chatRoomSubscribe $ storeRef "room-subscribe" + forM_ chatRoomUnsubscribe $ storeRef "room-unsubscribe" + forM_ chatRoomMessage $ storeRef "room-message" + + load' = loadRec $ do + chatRoomQuery <- isJust <$> loadMbEmpty "room-query" + chatRoomInfo <- loadRefs "room-info" + chatRoomSubscribe <- loadRefs "room-subscribe" + chatRoomUnsubscribe <- loadRefs "room-unsubscribe" + chatRoomMessage <- loadRefs "room-message" + return ChatroomService {..} + +data PeerState = PeerState + { psSendRoomUpdates :: Bool + , psLastList :: [(Stored ChatroomStateData, ChatroomState)] + , psSubscribedTo :: [ Stored (Signed ChatroomData) ] -- least root for each room + } + +instance Service ChatroomService where + serviceID _ = mkServiceID "627657ae-3e39-468a-8381-353395ef4386" + + type ServiceState ChatroomService = PeerState + emptyServiceState _ = PeerState + { psSendRoomUpdates = False + , psLastList = [] + , psSubscribedTo = [] + } + + serviceHandler spacket = do + let ChatroomService {..} = fromStored spacket + + previouslyUpdated <- psSendRoomUpdates <$> svcGet + svcModify $ \s -> s { psSendRoomUpdates = True } + + when (not previouslyUpdated) $ do + syncChatroomsToPeer . lookupSharedValue . lsShared . fromStored =<< getLocalHead + + when chatRoomQuery $ do + rooms <- listChatrooms + replyPacket emptyPacket + { chatRoomInfo = concatMap roomData $ catMaybes $ map roomStateRoom rooms + } + + when (not $ null chatRoomInfo) $ do + updateLocalHead_ $ updateSharedState_ $ \roomSet -> do + let rooms = fromSetBy (comparing $ roomName <=< roomStateRoom) roomSet + upd set (roomInfo :: Stored (Signed ChatroomData)) = do + let currentRoots = storedRoots roomInfo + isCurrentRoom = any ((`intersectsSorted` currentRoots) . storedRoots) . + maybe [] roomData . roomStateRoom + + let prev = concatMap roomStateData $ filter isCurrentRoom rooms + prevRoom = filterAncestors $ concat $ findProperty ((\case [] -> Nothing; xs -> Just xs) . rsdRoom) prev + room = filterAncestors $ (roomInfo : ) prevRoom + + -- update local state only if we got roomInfo not present there + if roomInfo `notElem` prevRoom && roomInfo `elem` room + then do + sdata <- mstore ChatroomStateData + { rsdPrev = prev + , rsdRoom = room + , rsdSubscribe = Nothing + , rsdMessages = [] + } + storeSetAddComponent sdata set + else return set + foldM upd roomSet chatRoomInfo + + forM_ chatRoomSubscribe $ \subscribeData -> do + mbRoomState <- findChatroomByRoomData subscribeData + forM_ mbRoomState $ \roomState -> + forM (roomStateRoom roomState) $ \room -> do + let leastRoot = head . filterAncestors . concatMap storedRoots . roomData $ room + svcModify $ \ps -> ps { psSubscribedTo = leastRoot : psSubscribedTo ps } + replyPacket emptyPacket + { chatRoomMessage = roomStateMessageData roomState + } + + forM_ chatRoomUnsubscribe $ \unsubscribeData -> do + mbRoomState <- findChatroomByRoomData unsubscribeData + forM_ (mbRoomState >>= roomStateRoom) $ \room -> do + let leastRoot = head . filterAncestors . concatMap storedRoots . roomData $ room + svcModify $ \ps -> ps { psSubscribedTo = filter (/= leastRoot) (psSubscribedTo ps) } + + when (not (null chatRoomMessage)) $ do + updateLocalHead_ $ updateSharedState_ $ \roomSet -> do + let rooms = fromSetBy (comparing $ roomName <=< roomStateRoom) roomSet + upd set (msgData :: Stored (Signed ChatMessageData)) + | Just msg <- validateSingleMessage msgData = do + let roomInfo = cmsgRoomData msg + currentRoots = filterAncestors $ concatMap storedRoots roomInfo + isCurrentRoom = any ((`intersectsSorted` currentRoots) . storedRoots) . + maybe [] roomData . roomStateRoom + + let prevData = concatMap roomStateData $ filter isCurrentRoom rooms + prev = mergeSorted prevData + prevMessages = roomStateMessageData prev + messages = filterAncestors $ msgData : prevMessages + + -- update local state only if subscribed and we got some new messages + if roomStateSubscribe prev && messages /= prevMessages + then do + sdata <- mstore ChatroomStateData + { rsdPrev = prevData + , rsdRoom = [] + , rsdSubscribe = Nothing + , rsdMessages = messages + } + storeSetAddComponent sdata set + else return set + | otherwise = return set + foldM upd roomSet chatRoomMessage + + serviceNewPeer = do + replyPacket emptyPacket { chatRoomQuery = True } + + serviceStorageWatchers _ = (:[]) $ + SomeStorageWatcher (lookupSharedValue . lsShared . fromStored) syncChatroomsToPeer + +syncChatroomsToPeer :: Set ChatroomState -> ServiceHandler ChatroomService () +syncChatroomsToPeer set = do + ps@PeerState {..} <- svcGet + when psSendRoomUpdates $ do + let curList = chatroomSetToList set + diff = makeChatroomDiff psLastList curList + + roomUpdates <- fmap (concat . catMaybes) $ + forM diff $ return . \case + AddedChatroom room -> roomData <$> roomStateRoom room + RemovedChatroom {} -> Nothing + UpdatedChatroom oldroom room + | roomStateData oldroom /= roomStateData room -> roomData <$> roomStateRoom room + | otherwise -> Nothing + + (subscribe, unsubscribe) <- fmap (partitionEithers . concat . catMaybes) $ + forM diff $ return . \case + AddedChatroom room + | roomStateSubscribe room + -> map Left . roomData <$> roomStateRoom room + RemovedChatroom oldroom + | roomStateSubscribe oldroom + -> map Right . roomData <$> roomStateRoom oldroom + UpdatedChatroom oldroom room + | roomStateSubscribe oldroom /= roomStateSubscribe room + -> map (if roomStateSubscribe room then Left else Right) . roomData <$> roomStateRoom room + _ -> Nothing + + messages <- fmap concat $ do + let leastRootFor = head . filterAncestors . concatMap storedRoots . roomData + forM diff $ return . \case + AddedChatroom rstate + | Just room <- roomStateRoom rstate + , leastRootFor room `elem` psSubscribedTo + -> roomStateMessageData rstate + UpdatedChatroom oldstate rstate + | Just room <- roomStateRoom rstate + , leastRootFor room `elem` psSubscribedTo + , roomStateMessageData oldstate /= roomStateMessageData rstate + -> roomStateMessageData rstate + _ -> [] + + let packet = emptyPacket + { chatRoomInfo = roomUpdates + , chatRoomSubscribe = subscribe + , chatRoomUnsubscribe = unsubscribe + , chatRoomMessage = messages + } + + when (packet /= emptyPacket) $ do + replyPacket packet + svcSet $ ps { psLastList = curList } diff --git a/src/Erebos/Contact.hs b/src/Erebos/Contact.hs new file mode 100644 index 0000000..d90aa50 --- /dev/null +++ b/src/Erebos/Contact.hs @@ -0,0 +1,175 @@ +module Erebos.Contact ( + Contact, + contactIdentity, + contactCustomName, + contactName, + + contactSetName, + + ContactService, + contactRequest, + contactAccept, + contactReject, +) where + +import Control.Monad +import Control.Monad.Except +import Control.Monad.Reader + +import Data.Maybe +import Data.Proxy +import Data.Text (Text) +import qualified Data.Text as T + +import Erebos.Identity +import Erebos.Network +import Erebos.Pairing +import Erebos.PubKey +import Erebos.Service +import Erebos.Set +import Erebos.State +import Erebos.Storage +import Erebos.Storage.Merge + +data Contact = Contact + { contactData :: [Stored ContactData] + , contactIdentity_ :: Maybe ComposedIdentity + , contactCustomName_ :: Maybe Text + } + +data ContactData = ContactData + { cdPrev :: [Stored ContactData] + , cdIdentity :: [Stored (Signed ExtendedIdentityData)] + , cdName :: Maybe Text + } + +instance Storable ContactData where + store' x = storeRec $ do + mapM_ (storeRef "PREV") $ cdPrev x + mapM_ (storeRef "identity") $ cdIdentity x + storeMbText "name" $ cdName x + + load' = loadRec $ ContactData + <$> loadRefs "PREV" + <*> loadRefs "identity" + <*> loadMbText "name" + +instance Mergeable Contact where + type Component Contact = ContactData + + mergeSorted cdata = Contact + { contactData = cdata + , contactIdentity_ = validateExtendedIdentityF $ concat $ findProperty ((\case [] -> Nothing; xs -> Just xs) . cdIdentity) cdata + , contactCustomName_ = findPropertyFirst cdName cdata + } + + toComponents = contactData + +instance SharedType (Set Contact) where + sharedTypeID _ = mkSharedTypeID "34fbb61e-6022-405f-b1b3-a5a1abecd25e" + +contactIdentity :: Contact -> Maybe ComposedIdentity +contactIdentity = contactIdentity_ + +contactCustomName :: Contact -> Maybe Text +contactCustomName = contactCustomName_ + +contactName :: Contact -> Text +contactName c = fromJust $ msum + [ contactCustomName c + , idName =<< contactIdentity c + , Just T.empty + ] + +contactSetName :: MonadHead LocalState m => Contact -> Text -> Set Contact -> m (Set Contact) +contactSetName contact name set = do + st <- getStorage + cdata <- wrappedStore st ContactData + { cdPrev = toComponents contact + , cdIdentity = [] + , cdName = Just name + } + storeSetAdd st (mergeSorted @Contact [cdata]) set + + +type ContactService = PairingService ContactAccepted + +data ContactAccepted = ContactAccepted + +instance Storable ContactAccepted where + store' ContactAccepted = storeRec $ do + storeText "accept" "" + load' = loadRec $ do + (_ :: T.Text) <- loadText "accept" + return ContactAccepted + +instance PairingResult ContactAccepted where + pairingServiceID _ = mkServiceID "d9c37368-0da1-4280-93e9-d9bd9a198084" + + pairingVerifyResult = return . Just + + pairingFinalizeRequest ContactAccepted = do + pid <- asks svcPeerIdentity + finalizeContact pid + + pairingFinalizeResponse = do + pid <- asks svcPeerIdentity + finalizeContact pid + return ContactAccepted + + defaultPairingAttributes _ = PairingAttributes + { pairingHookRequest = do + peer <- asks $ svcPeerIdentity + svcPrint $ "Contact pairing from " ++ T.unpack (displayIdentity peer) ++ " initiated" + + , pairingHookResponse = \confirm -> do + peer <- asks $ svcPeerIdentity + svcPrint $ "Confirm contact " ++ T.unpack (displayIdentity $ finalOwner peer) ++ ": " ++ confirm + + , pairingHookRequestNonce = \confirm -> do + peer <- asks $ svcPeerIdentity + svcPrint $ "Contact request from " ++ T.unpack (displayIdentity $ finalOwner peer) ++ ": " ++ confirm + + , pairingHookRequestNonceFailed = do + peer <- asks $ svcPeerIdentity + svcPrint $ "Failed contact request from " ++ T.unpack (displayIdentity peer) + + , pairingHookConfirmedResponse = do + svcPrint $ "Contact accepted, waiting for peer confirmation" + + , pairingHookConfirmedRequest = do + svcPrint $ "Contact confirmed by peer" + + , pairingHookAcceptedResponse = do + svcPrint $ "Contact accepted" + + , pairingHookAcceptedRequest = do + svcPrint $ "Contact accepted" + + , pairingHookVerifyFailed = return () + + , pairingHookRejected = do + svcPrint $ "Contact rejected by peer" + + , pairingHookFailed = \_ -> do + svcPrint $ "Contact failed" + } + +contactRequest :: (MonadIO m, MonadError String m) => Peer -> m () +contactRequest = pairingRequest @ContactAccepted Proxy + +contactAccept :: (MonadIO m, MonadError String m) => Peer -> m () +contactAccept = pairingAccept @ContactAccepted Proxy + +contactReject :: (MonadIO m, MonadError String m) => Peer -> m () +contactReject = pairingReject @ContactAccepted Proxy + +finalizeContact :: MonadHead LocalState m => UnifiedIdentity -> m () +finalizeContact identity = updateLocalHead_ $ updateSharedState_ $ \contacts -> do + st <- getStorage + cdata <- wrappedStore st ContactData + { cdPrev = [] + , cdIdentity = idExtDataF $ finalOwner identity + , cdName = Nothing + } + storeSetAdd st (mergeSorted @Contact [cdata]) contacts diff --git a/src/Erebos/Conversation.hs b/src/Erebos/Conversation.hs new file mode 100644 index 0000000..63475bd --- /dev/null +++ b/src/Erebos/Conversation.hs @@ -0,0 +1,105 @@ +module Erebos.Conversation ( + Message, + messageFrom, + messageTime, + messageText, + messageUnread, + formatMessage, + + Conversation, + directMessageConversation, + chatroomConversation, + chatroomConversationByStateData, + reloadConversation, + lookupConversations, + + conversationName, + conversationPeer, + conversationHistory, + + sendMessage, +) where + +import Control.Monad.Except + +import Data.List +import Data.Maybe +import Data.Text (Text) +import Data.Text qualified as T +import Data.Time.Format +import Data.Time.LocalTime + +import Erebos.Identity +import Erebos.Chatroom +import Erebos.Message hiding (formatMessage) +import Erebos.State +import Erebos.Storage + + +data Message = DirectMessageMessage DirectMessage Bool + | ChatroomMessage ChatMessage Bool + +messageFrom :: Message -> ComposedIdentity +messageFrom (DirectMessageMessage msg _) = msgFrom msg +messageFrom (ChatroomMessage msg _) = cmsgFrom msg + +messageTime :: Message -> ZonedTime +messageTime (DirectMessageMessage msg _) = msgTime msg +messageTime (ChatroomMessage msg _) = cmsgTime msg + +messageText :: Message -> Maybe Text +messageText (DirectMessageMessage msg _) = Just $ msgText msg +messageText (ChatroomMessage msg _) = cmsgText msg + +messageUnread :: Message -> Bool +messageUnread (DirectMessageMessage _ unread) = unread +messageUnread (ChatroomMessage _ unread) = unread + +formatMessage :: TimeZone -> Message -> String +formatMessage tzone msg = concat + [ formatTime defaultTimeLocale "[%H:%M] " $ utcToLocalTime tzone $ zonedTimeToUTC $ messageTime msg + , maybe "<unnamed>" T.unpack $ idName $ messageFrom msg + , maybe "" ((": "<>) . T.unpack) $ messageText msg + ] + + +data Conversation = DirectMessageConversation DirectMessageThread + | ChatroomConversation ChatroomState + +directMessageConversation :: MonadHead LocalState m => ComposedIdentity -> m Conversation +directMessageConversation peer = do + (find (sameIdentity peer . msgPeer) . toThreadList . lookupSharedValue . lsShared . fromStored <$> getLocalHead) >>= \case + Just thread -> return $ DirectMessageConversation thread + Nothing -> return $ DirectMessageConversation $ DirectMessageThread peer [] [] [] + +chatroomConversation :: MonadHead LocalState m => ChatroomState -> m (Maybe Conversation) +chatroomConversation rstate = chatroomConversationByStateData (head $ roomStateData rstate) + +chatroomConversationByStateData :: MonadHead LocalState m => Stored ChatroomStateData -> m (Maybe Conversation) +chatroomConversationByStateData sdata = fmap ChatroomConversation <$> findChatroomByStateData sdata + +reloadConversation :: MonadHead LocalState m => Conversation -> m Conversation +reloadConversation (DirectMessageConversation thread) = directMessageConversation (msgPeer thread) +reloadConversation cur@(ChatroomConversation rstate) = + fromMaybe cur <$> chatroomConversation rstate + +lookupConversations :: MonadHead LocalState m => m [Conversation] +lookupConversations = map DirectMessageConversation . toThreadList . lookupSharedValue . lsShared . fromStored <$> getLocalHead + + +conversationName :: Conversation -> Text +conversationName (DirectMessageConversation thread) = fromMaybe (T.pack "<unnamed>") $ idName $ msgPeer thread +conversationName (ChatroomConversation rstate) = fromMaybe (T.pack "<unnamed>") $ roomName =<< roomStateRoom rstate + +conversationPeer :: Conversation -> Maybe ComposedIdentity +conversationPeer (DirectMessageConversation thread) = Just $ msgPeer thread +conversationPeer (ChatroomConversation _) = Nothing + +conversationHistory :: Conversation -> [Message] +conversationHistory (DirectMessageConversation thread) = map (\msg -> DirectMessageMessage msg False) $ threadToList thread +conversationHistory (ChatroomConversation rstate) = map (\msg -> ChatroomMessage msg False) $ roomStateMessages rstate + + +sendMessage :: (MonadHead LocalState m, MonadError String m) => Conversation -> Text -> m (Maybe Message) +sendMessage (DirectMessageConversation thread) text = fmap Just $ DirectMessageMessage <$> (fromStored <$> sendDirectMessage (msgPeer thread) text) <*> pure False +sendMessage (ChatroomConversation rstate) text = sendChatroomMessage rstate text >> return Nothing diff --git a/src/Erebos/Discovery.hs b/src/Erebos/Discovery.hs new file mode 100644 index 0000000..48df9c3 --- /dev/null +++ b/src/Erebos/Discovery.hs @@ -0,0 +1,223 @@ +module Erebos.Discovery ( + DiscoveryService(..), + DiscoveryConnection(..) +) where + +import Control.Concurrent +import Control.Monad +import Control.Monad.Except +import Control.Monad.Reader + +import Data.Map.Strict (Map) +import qualified Data.Map.Strict as M +import Data.Maybe +import Data.Text (Text) +import qualified Data.Text as T + +import Network.Socket + +import Erebos.ICE +import Erebos.Identity +import Erebos.Network +import Erebos.Service +import Erebos.Storage + + +keepaliveSeconds :: Int +keepaliveSeconds = 20 + + +data DiscoveryService = DiscoverySelf Text Int + | DiscoveryAcknowledged Text + | DiscoverySearch Ref + | DiscoveryResult Ref (Maybe Text) + | DiscoveryConnectionRequest DiscoveryConnection + | DiscoveryConnectionResponse DiscoveryConnection + +data DiscoveryConnection = DiscoveryConnection + { dconnSource :: Ref + , dconnTarget :: Ref + , dconnAddress :: Maybe Text + , dconnIceSession :: Maybe IceRemoteInfo + } + +emptyConnection :: Ref -> Ref -> DiscoveryConnection +emptyConnection source target = DiscoveryConnection source target Nothing Nothing + +instance Storable DiscoveryService where + store' x = storeRec $ do + case x of + DiscoverySelf addr priority -> do + storeText "self" addr + storeInt "priority" priority + DiscoveryAcknowledged addr -> do + storeText "ack" addr + DiscoverySearch ref -> storeRawRef "search" ref + DiscoveryResult ref addr -> do + storeRawRef "result" ref + storeMbText "address" addr + DiscoveryConnectionRequest conn -> storeConnection "request" conn + DiscoveryConnectionResponse conn -> storeConnection "response" conn + + where storeConnection ctype conn = do + storeText "connection" $ ctype + storeRawRef "source" $ dconnSource conn + storeRawRef "target" $ dconnTarget conn + storeMbText "address" $ dconnAddress conn + storeMbRef "ice-session" $ dconnIceSession conn + + load' = loadRec $ msum + [ DiscoverySelf + <$> loadText "self" + <*> loadInt "priority" + , DiscoveryAcknowledged + <$> loadText "ack" + , DiscoverySearch <$> loadRawRef "search" + , DiscoveryResult + <$> loadRawRef "result" + <*> loadMbText "address" + , loadConnection "request" DiscoveryConnectionRequest + , loadConnection "response" DiscoveryConnectionResponse + ] + where loadConnection ctype ctor = do + ctype' <- loadText "connection" + guard $ ctype == ctype' + return . ctor =<< DiscoveryConnection + <$> loadRawRef "source" + <*> loadRawRef "target" + <*> loadMbText "address" + <*> loadMbRef "ice-session" + +data DiscoveryPeer = DiscoveryPeer + { dpPriority :: Int + , dpPeer :: Maybe Peer + , dpAddress :: Maybe Text + , dpIceSession :: Maybe IceSession + } + +instance Service DiscoveryService where + serviceID _ = mkServiceID "dd59c89c-69cc-4703-b75b-4ddcd4b3c23b" + + type ServiceGlobalState DiscoveryService = Map RefDigest DiscoveryPeer + emptyServiceGlobalState _ = M.empty + + serviceHandler msg = case fromStored msg of + DiscoverySelf addr priority -> do + pid <- asks svcPeerIdentity + peer <- asks svcPeer + let insertHelper new old | dpPriority new > dpPriority old = new + | otherwise = old + mbaddr <- case words (T.unpack addr) of + [ipaddr, port] | DatagramAddress paddr <- peerAddress peer -> do + saddr <- liftIO $ head <$> getAddrInfo (Just $ defaultHints { addrSocketType = Datagram }) (Just ipaddr) (Just port) + return $ if paddr == addrAddress saddr + then Just addr + else Nothing + _ -> return Nothing + forM_ (idDataF =<< unfoldOwners pid) $ \s -> + svcModifyGlobal $ M.insertWith insertHelper (refDigest $ storedRef s) $ + DiscoveryPeer priority (Just peer) mbaddr Nothing + replyPacket $ DiscoveryAcknowledged $ fromMaybe (T.pack "ICE") mbaddr + + DiscoveryAcknowledged addr -> do + when (addr == T.pack "ICE") $ do + -- keep-alive packet from behind NAT + peer <- asks svcPeer + liftIO $ void $ forkIO $ do + threadDelay (keepaliveSeconds * 1000 * 1000) + res <- runExceptT $ sendToPeer peer $ DiscoverySelf addr 0 + case res of + Right _ -> return () + Left err -> putStrLn $ "Discovery: failed to send keep-alive: " ++ err + + DiscoverySearch ref -> do + addr <- M.lookup (refDigest ref) <$> svcGetGlobal + replyPacket $ DiscoveryResult ref $ fromMaybe (T.pack "ICE") . dpAddress <$> addr + + DiscoveryResult ref Nothing -> do + svcPrint $ "Discovery: " ++ show (refDigest ref) ++ " not found" + + DiscoveryResult ref (Just addr) -> do + -- TODO: check if we really requested that + server <- asks svcServer + if addr == T.pack "ICE" + then do + self <- svcSelf + peer <- asks svcPeer + ice <- liftIO $ iceCreate PjIceSessRoleControlling $ \ice -> do + rinfo <- iceRemoteInfo ice + res <- runExceptT $ sendToPeer peer $ + DiscoveryConnectionRequest (emptyConnection (storedRef $ idData self) ref) { dconnIceSession = Just rinfo } + case res of + Right _ -> return () + Left err -> putStrLn $ "Discovery: failed to send connection request: " ++ err + + svcModifyGlobal $ M.insert (refDigest ref) $ + DiscoveryPeer 0 Nothing Nothing (Just ice) + else do + case words (T.unpack addr) of + [ipaddr, port] -> do + saddr <- liftIO $ head <$> + getAddrInfo (Just $ defaultHints { addrSocketType = Datagram }) (Just ipaddr) (Just port) + peer <- liftIO $ serverPeer server (addrAddress saddr) + svcModifyGlobal $ M.insert (refDigest ref) $ + DiscoveryPeer 0 (Just peer) Nothing Nothing + + _ -> svcPrint $ "Discovery: invalid address in result: " ++ T.unpack addr + + DiscoveryConnectionRequest conn -> do + self <- svcSelf + let rconn = emptyConnection (dconnSource conn) (dconnTarget conn) + if refDigest (dconnTarget conn) `elem` (map (refDigest . storedRef) $ idDataF =<< unfoldOwners self) + then do + -- request for us, create ICE sesssion + server <- asks svcServer + peer <- asks svcPeer + liftIO $ void $ iceCreate PjIceSessRoleControlled $ \ice -> do + rinfo <- iceRemoteInfo ice + res <- runExceptT $ sendToPeer peer $ DiscoveryConnectionResponse rconn { dconnIceSession = Just rinfo } + case res of + Right _ -> do + case dconnIceSession conn of + Just prinfo -> iceConnect ice prinfo $ void $ serverPeerIce server ice + Nothing -> putStrLn $ "Discovery: connection request without ICE remote info" + Left err -> putStrLn $ "Discovery: failed to send connection response: " ++ err + + else do + -- request to some of our peers, relay + mbdp <- M.lookup (refDigest $ dconnTarget conn) <$> svcGetGlobal + case mbdp of + Nothing -> replyPacket $ DiscoveryConnectionResponse rconn + Just dp | Just addr <- dpAddress dp -> do + replyPacket $ DiscoveryConnectionResponse rconn { dconnAddress = Just addr } + | Just dpeer <- dpPeer dp -> do + sendToPeer dpeer $ DiscoveryConnectionRequest conn + | otherwise -> svcPrint $ "Discovery: failed to relay connection request" + + DiscoveryConnectionResponse conn -> do + self <- svcSelf + dpeers <- svcGetGlobal + if refDigest (dconnSource conn) `elem` (map (refDigest . storedRef) $ idDataF =<< unfoldOwners self) + then do + -- response to our request, try to connect to the peer + server <- asks svcServer + if | Just addr <- dconnAddress conn + , [ipaddr, port] <- words (T.unpack addr) -> do + saddr <- liftIO $ head <$> + getAddrInfo (Just $ defaultHints { addrSocketType = Datagram }) (Just ipaddr) (Just port) + peer <- liftIO $ serverPeer server (addrAddress saddr) + svcModifyGlobal $ M.insert (refDigest $ dconnTarget conn) $ + DiscoveryPeer 0 (Just peer) Nothing Nothing + + | Just dp <- M.lookup (refDigest $ dconnTarget conn) dpeers + , Just ice <- dpIceSession dp + , Just rinfo <- dconnIceSession conn -> do + liftIO $ iceConnect ice rinfo $ void $ serverPeerIce server ice + + | otherwise -> svcPrint $ "Discovery: connection request failed" + else do + -- response to relayed request + case M.lookup (refDigest $ dconnSource conn) dpeers of + Just dp | Just dpeer <- dpPeer dp -> do + sendToPeer dpeer $ DiscoveryConnectionResponse conn + _ -> svcPrint $ "Discovery: failed to relay connection response" diff --git a/src/Erebos/Flow.hs b/src/Erebos/Flow.hs new file mode 100644 index 0000000..ba2607a --- /dev/null +++ b/src/Erebos/Flow.hs @@ -0,0 +1,73 @@ +module Erebos.Flow ( + Flow, SymFlow, + newFlow, newFlowIO, + readFlow, tryReadFlow, canReadFlow, + writeFlow, writeFlowBulk, tryWriteFlow, canWriteFlow, + readFlowIO, writeFlowIO, + + mapFlow, +) where + +import Control.Concurrent.STM + + +data Flow r w = Flow (TMVar [r]) (TMVar [w]) + | forall r' w'. MappedFlow (r' -> r) (w -> w') (Flow r' w') + +type SymFlow a = Flow a a + +newFlow :: STM (Flow a b, Flow b a) +newFlow = do + x <- newEmptyTMVar + y <- newEmptyTMVar + return (Flow x y, Flow y x) + +newFlowIO :: IO (Flow a b, Flow b a) +newFlowIO = atomically newFlow + +readFlow :: Flow r w -> STM r +readFlow (Flow rvar _) = takeTMVar rvar >>= \case + (x:[]) -> return x + (x:xs) -> putTMVar rvar xs >> return x + [] -> error "Flow: empty list" +readFlow (MappedFlow f _ up) = f <$> readFlow up + +tryReadFlow :: Flow r w -> STM (Maybe r) +tryReadFlow (Flow rvar _) = tryTakeTMVar rvar >>= \case + Just (x:[]) -> return (Just x) + Just (x:xs) -> putTMVar rvar xs >> return (Just x) + Just [] -> error "Flow: empty list" + Nothing -> return Nothing +tryReadFlow (MappedFlow f _ up) = fmap f <$> tryReadFlow up + +canReadFlow :: Flow r w -> STM Bool +canReadFlow (Flow rvar _) = not <$> isEmptyTMVar rvar +canReadFlow (MappedFlow _ _ up) = canReadFlow up + +writeFlow :: Flow r w -> w -> STM () +writeFlow (Flow _ wvar) = putTMVar wvar . (:[]) +writeFlow (MappedFlow _ f up) = writeFlow up . f + +writeFlowBulk :: Flow r w -> [w] -> STM () +writeFlowBulk _ [] = return () +writeFlowBulk (Flow _ wvar) xs = putTMVar wvar xs +writeFlowBulk (MappedFlow _ f up) xs = writeFlowBulk up $ map f xs + +tryWriteFlow :: Flow r w -> w -> STM Bool +tryWriteFlow (Flow _ wvar) = tryPutTMVar wvar . (:[]) +tryWriteFlow (MappedFlow _ f up) = tryWriteFlow up . f + +canWriteFlow :: Flow r w -> STM Bool +canWriteFlow (Flow _ wvar) = isEmptyTMVar wvar +canWriteFlow (MappedFlow _ _ up) = canWriteFlow up + +readFlowIO :: Flow r w -> IO r +readFlowIO path = atomically $ readFlow path + +writeFlowIO :: Flow r w -> w -> IO () +writeFlowIO path = atomically . writeFlow path + + +mapFlow :: (r -> r') -> (w' -> w) -> Flow r w -> Flow r' w' +mapFlow rf wf (MappedFlow rf' wf' up) = MappedFlow (rf . rf') (wf' . wf) up +mapFlow rf wf up = MappedFlow rf wf up diff --git a/src/Erebos/ICE.chs b/src/Erebos/ICE.chs new file mode 100644 index 0000000..096ee0d --- /dev/null +++ b/src/Erebos/ICE.chs @@ -0,0 +1,205 @@ +{-# LANGUAGE ForeignFunctionInterface #-} +{-# LANGUAGE RecursiveDo #-} + +module Erebos.ICE ( + IceSession, + IceSessionRole(..), + IceRemoteInfo, + + iceCreate, + iceDestroy, + iceRemoteInfo, + iceShow, + iceConnect, + iceSend, + + iceSetChan, +) where + +import Control.Arrow +import Control.Concurrent.MVar +import Control.Monad +import Control.Monad.Except +import Control.Monad.Identity + +import Data.ByteString (ByteString, packCStringLen, useAsCString) +import qualified Data.ByteString.Lazy.Char8 as BLC +import Data.ByteString.Unsafe +import Data.Function +import Data.Text (Text) +import qualified Data.Text as T +import qualified Data.Text.Encoding as T +import qualified Data.Text.Read as T +import Data.Void + +import Foreign.C.String +import Foreign.C.Types +import Foreign.Marshal.Alloc +import Foreign.Marshal.Array +import Foreign.Ptr +import Foreign.StablePtr + +import Erebos.Flow +import Erebos.Storage + +#include "pjproject.h" + +data IceSession = IceSession + { isStrans :: PjIceStrans + , isChan :: MVar (Either [ByteString] (Flow Void ByteString)) + } + +instance Eq IceSession where + (==) = (==) `on` isStrans + +instance Ord IceSession where + compare = compare `on` isStrans + +instance Show IceSession where + show _ = "<ICE>" + + +data IceRemoteInfo = IceRemoteInfo + { iriUsernameFrament :: Text + , iriPassword :: Text + , iriDefaultCandidate :: Text + , iriCandidates :: [Text] + } + +data IceCandidate = IceCandidate + { icandFoundation :: Text + , icandPriority :: Int + , icandAddr :: Text + , icandPort :: Int + , icandType :: Text + } + +instance Storable IceRemoteInfo where + store' x = storeRec $ do + storeText "ice-ufrag" $ iriUsernameFrament x + storeText "ice-pass" $ iriPassword x + storeText "ice-default" $ iriDefaultCandidate x + mapM_ (storeText "ice-candidate") $ iriCandidates x + + load' = loadRec $ IceRemoteInfo + <$> loadText "ice-ufrag" + <*> loadText "ice-pass" + <*> loadText "ice-default" + <*> loadTexts "ice-candidate" + +instance StorableText IceCandidate where + toText x = T.concat $ + [ icandFoundation x + , T.singleton ' ' + , T.pack $ show $ icandPriority x + , T.singleton ' ' + , icandAddr x + , T.singleton ' ' + , T.pack $ show $ icandPort x + , T.singleton ' ' + , icandType x + ] + + fromText t = case T.words t of + [found, tprio, addr, tport, ctype] + | Right (prio, _) <- T.decimal tprio + , Right (port, _) <- T.decimal tport + -> return $ IceCandidate + { icandFoundation = found + , icandPriority = prio + , icandAddr = addr + , icandPort = port + , icandType = ctype + } + _ -> throwError "failed to parse candidate" + + +{#enum pj_ice_sess_role as IceSessionRole {underscoreToCase} deriving (Show, Eq) #} + +{#pointer *pj_ice_strans as ^ #} + +iceCreate :: IceSessionRole -> (IceSession -> IO ()) -> IO IceSession +iceCreate role cb = do + rec sptr <- newStablePtr sess + cbptr <- newStablePtr $ cb sess + sess <- IceSession + <$> {#call ice_create #} (fromIntegral $ fromEnum role) (castStablePtrToPtr sptr) (castStablePtrToPtr cbptr) + <*> (newMVar $ Left []) + return $ sess + +{#fun ice_destroy as ^ { isStrans `IceSession' } -> `()' #} + +iceRemoteInfo :: IceSession -> IO IceRemoteInfo +iceRemoteInfo sess = do + let maxlen = 128 + maxcand = 29 + + allocaBytes maxlen $ \ufrag -> + allocaBytes maxlen $ \pass -> + allocaBytes maxlen $ \def -> + allocaBytes (maxcand*maxlen) $ \bytes -> + allocaArray maxcand $ \carr -> do + let cptrs = take maxcand $ iterate (`plusPtr` maxlen) bytes + pokeArray carr $ take maxcand cptrs + + ncand <- {#call ice_encode_session #} (isStrans sess) ufrag pass def carr (fromIntegral maxlen) (fromIntegral maxcand) + if ncand < 0 then fail "failed to generate ICE remote info" + else IceRemoteInfo + <$> (T.pack <$> peekCString ufrag) + <*> (T.pack <$> peekCString pass) + <*> (T.pack <$> peekCString def) + <*> (mapM (return . T.pack <=< peekCString) $ take (fromIntegral ncand) cptrs) + +iceShow :: IceSession -> IO String +iceShow sess = do + st <- memoryStorage + return . drop 1 . dropWhile (/='\n') . BLC.unpack . runIdentity =<< + ioLoadBytes =<< store st =<< iceRemoteInfo sess + +iceConnect :: IceSession -> IceRemoteInfo -> (IO ()) -> IO () +iceConnect sess remote cb = do + cbptr <- newStablePtr $ cb + ice_connect sess cbptr + (iriUsernameFrament remote) + (iriPassword remote) + (iriDefaultCandidate remote) + (iriCandidates remote) + +{#fun ice_connect { isStrans `IceSession', castStablePtrToPtr `StablePtr (IO ())', + withText* `Text', withText* `Text', withText* `Text', withTextArray* `[Text]'& } -> `()' #} + +withText :: Text -> (Ptr CChar -> IO a) -> IO a +withText t f = useAsCString (T.encodeUtf8 t) f + +withTextArray :: Num n => [Text] -> ((Ptr (Ptr CChar), n) -> IO ()) -> IO () +withTextArray tsAll f = helper tsAll [] + where helper (t:ts) bs = withText t $ \b -> helper ts (b:bs) + helper [] bs = allocaArray (length bs) $ \ptr -> do + pokeArray ptr $ reverse bs + f (ptr, fromIntegral $ length bs) + +withByteStringLen :: Num n => ByteString -> ((Ptr CChar, n) -> IO a) -> IO a +withByteStringLen t f = unsafeUseAsCStringLen t (f . (id *** fromIntegral)) + +{#fun ice_send as ^ { isStrans `IceSession', withByteStringLen* `ByteString'& } -> `()' #} + +foreign export ccall ice_call_cb :: StablePtr (IO ()) -> IO () +ice_call_cb :: StablePtr (IO ()) -> IO () +ice_call_cb = join . deRefStablePtr + +iceSetChan :: IceSession -> Flow Void ByteString -> IO () +iceSetChan sess chan = do + modifyMVar_ (isChan sess) $ \orig -> do + case orig of + Left buf -> mapM_ (writeFlowIO chan) $ reverse buf + Right _ -> return () + return $ Right chan + +foreign export ccall ice_rx_data :: StablePtr IceSession -> Ptr CChar -> Int -> IO () +ice_rx_data :: StablePtr IceSession -> Ptr CChar -> Int -> IO () +ice_rx_data sptr buf len = do + sess <- deRefStablePtr sptr + bs <- packCStringLen (buf, len) + modifyMVar_ (isChan sess) $ \case + mc@(Right chan) -> writeFlowIO chan bs >> return mc + Left bss -> return $ Left (bs:bss) diff --git a/src/Erebos/ICE/pjproject.c b/src/Erebos/ICE/pjproject.c new file mode 100644 index 0000000..bb06b1f --- /dev/null +++ b/src/Erebos/ICE/pjproject.c @@ -0,0 +1,363 @@ +#include "pjproject.h" +#include "Erebos/ICE_stub.h" + +#include <stdio.h> +#include <stdlib.h> +#include <stdbool.h> +#include <pthread.h> +#include <pjlib.h> +#include <pjlib-util.h> + +static struct +{ + pj_caching_pool cp; + pj_pool_t * pool; + pj_ice_strans_cfg cfg; + pj_sockaddr def_addr; +} ice; + +struct user_data +{ + pj_ice_sess_role role; + HsStablePtr sptr; + HsStablePtr cb_init; + HsStablePtr cb_connect; +}; + +static void ice_perror(const char * msg, pj_status_t status) +{ + char err[PJ_ERR_MSG_SIZE]; + pj_strerror(status, err, sizeof(err)); + fprintf(stderr, "ICE: %s: %s\n", msg, err); +} + +static int ice_worker_thread(void * unused) +{ + PJ_UNUSED_ARG(unused); + + while (true) { + pj_time_val max_timeout = { 0, 0 }; + pj_time_val timeout = { 0, 0 }; + + max_timeout.msec = 500; + + pj_timer_heap_poll(ice.cfg.stun_cfg.timer_heap, &timeout); + + pj_assert(timeout.sec >= 0 && timeout.msec >= 0); + if (timeout.msec >= 1000) + timeout.msec = 999; + + if (PJ_TIME_VAL_GT(timeout, max_timeout)) + timeout = max_timeout; + + int c = pj_ioqueue_poll(ice.cfg.stun_cfg.ioqueue, &timeout); + if (c < 0) + pj_thread_sleep(PJ_TIME_VAL_MSEC(timeout)); + } + + return 0; +} + +static void cb_on_rx_data(pj_ice_strans * strans, unsigned comp_id, + void * pkt, pj_size_t size, + const pj_sockaddr_t * src_addr, unsigned src_addr_len) +{ + struct user_data * udata = pj_ice_strans_get_user_data(strans); + ice_rx_data(udata->sptr, pkt, size); +} + +static void cb_on_ice_complete(pj_ice_strans * strans, + pj_ice_strans_op op, pj_status_t status) +{ + if (status != PJ_SUCCESS) { + ice_perror("cb_on_ice_complete", status); + ice_destroy(strans); + return; + } + + struct user_data * udata = pj_ice_strans_get_user_data(strans); + if (op == PJ_ICE_STRANS_OP_INIT) { + pj_status_t istatus = pj_ice_strans_init_ice(strans, udata->role, NULL, NULL); + if (istatus != PJ_SUCCESS) + ice_perror("error creating session", istatus); + + if (udata->cb_init) { + ice_call_cb(udata->cb_init); + hs_free_stable_ptr(udata->cb_init); + udata->cb_init = NULL; + } + } + + if (op == PJ_ICE_STRANS_OP_NEGOTIATION) { + if (udata->cb_connect) { + ice_call_cb(udata->cb_connect); + hs_free_stable_ptr(udata->cb_connect); + udata->cb_connect = NULL; + } + } +} + +static void ice_init(void) +{ + static bool done = false; + static pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER; + pthread_mutex_lock(&mutex); + + if (done) { + pthread_mutex_unlock(&mutex); + goto exit; + } + + pj_log_set_level(1); + + if (pj_init() != PJ_SUCCESS) { + fprintf(stderr, "pj_init failed\n"); + goto exit; + } + if (pjlib_util_init() != PJ_SUCCESS) { + fprintf(stderr, "pjlib_util_init failed\n"); + goto exit; + } + if (pjnath_init() != PJ_SUCCESS) { + fprintf(stderr, "pjnath_init failed\n"); + goto exit; + } + + pj_caching_pool_init(&ice.cp, NULL, 0); + + pj_ice_strans_cfg_default(&ice.cfg); + ice.cfg.stun_cfg.pf = &ice.cp.factory; + + ice.pool = pj_pool_create(&ice.cp.factory, "ice", 512, 512, NULL); + + if (pj_timer_heap_create(ice.pool, 100, + &ice.cfg.stun_cfg.timer_heap) != PJ_SUCCESS) { + fprintf(stderr, "pj_timer_heap_create failed\n"); + goto exit; + } + + if (pj_ioqueue_create(ice.pool, 16, &ice.cfg.stun_cfg.ioqueue) != PJ_SUCCESS) { + fprintf(stderr, "pj_ioqueue_create failed\n"); + goto exit; + } + + pj_thread_t * thread; + if (pj_thread_create(ice.pool, "ice", &ice_worker_thread, + NULL, 0, 0, &thread) != PJ_SUCCESS) { + fprintf(stderr, "pj_thread_create failed\n"); + goto exit; + } + + ice.cfg.af = pj_AF_INET(); + ice.cfg.opt.aggressive = PJ_TRUE; + + ice.cfg.stun.server.ptr = "discovery1.erebosprotocol.net"; + ice.cfg.stun.server.slen = strlen(ice.cfg.stun.server.ptr); + ice.cfg.stun.port = 29670; + + ice.cfg.turn.server = ice.cfg.stun.server; + ice.cfg.turn.port = ice.cfg.stun.port; + ice.cfg.turn.auth_cred.type = PJ_STUN_AUTH_CRED_STATIC; + ice.cfg.turn.auth_cred.data.static_cred.data_type = PJ_STUN_PASSWD_PLAIN; + ice.cfg.turn.conn_type = PJ_TURN_TP_UDP; + +exit: + done = true; + pthread_mutex_unlock(&mutex); +} + +pj_ice_strans * ice_create(pj_ice_sess_role role, HsStablePtr sptr, HsStablePtr cb) +{ + ice_init(); + + pj_ice_strans * res; + + struct user_data * udata = malloc(sizeof(struct user_data)); + udata->role = role; + udata->sptr = sptr; + udata->cb_init = cb; + + pj_ice_strans_cb icecb = { + .on_rx_data = cb_on_rx_data, + .on_ice_complete = cb_on_ice_complete, + }; + + pj_status_t status = pj_ice_strans_create(NULL, &ice.cfg, 1, + udata, &icecb, &res); + + if (status != PJ_SUCCESS) + ice_perror("error creating ice", status); + + return res; +} + +void ice_destroy(pj_ice_strans * strans) +{ + struct user_data * udata = pj_ice_strans_get_user_data(strans); + if (udata->sptr) + hs_free_stable_ptr(udata->sptr); + if (udata->cb_init) + hs_free_stable_ptr(udata->cb_init); + if (udata->cb_connect) + hs_free_stable_ptr(udata->cb_connect); + free(udata); + + pj_ice_strans_stop_ice(strans); + pj_ice_strans_destroy(strans); +} + +ssize_t ice_encode_session(pj_ice_strans * strans, char * ufrag, char * pass, + char * def, char * candidates[], size_t maxlen, size_t maxcand) +{ + int n; + pj_str_t local_ufrag, local_pwd; + pj_status_t status; + + pj_ice_strans_get_ufrag_pwd(strans, &local_ufrag, &local_pwd, NULL, NULL); + + n = snprintf(ufrag, maxlen, "%.*s", (int) local_ufrag.slen, local_ufrag.ptr); + if (n < 0 || n == maxlen) + return -PJ_ETOOSMALL; + + n = snprintf(pass, maxlen, "%.*s", (int) local_pwd.slen, local_pwd.ptr); + if (n < 0 || n == maxlen) + return -PJ_ETOOSMALL; + + pj_ice_sess_cand cand[PJ_ICE_ST_MAX_CAND]; + char ipaddr[PJ_INET6_ADDRSTRLEN]; + + status = pj_ice_strans_get_def_cand(strans, 1, &cand[0]); + if (status != PJ_SUCCESS) + return -status; + + n = snprintf(def, maxlen, "%s %d", + pj_sockaddr_print(&cand[0].addr, ipaddr, sizeof(ipaddr), 0), + (int) pj_sockaddr_get_port(&cand[0].addr)); + if (n < 0 || n == maxlen) + return -PJ_ETOOSMALL; + + unsigned cand_cnt = PJ_ARRAY_SIZE(cand); + status = pj_ice_strans_enum_cands(strans, 1, &cand_cnt, cand); + if (status != PJ_SUCCESS) + return -status; + + for (unsigned i = 0; i < cand_cnt && i < maxcand; i++) { + char ipaddr[PJ_INET6_ADDRSTRLEN]; + n = snprintf(candidates[i], maxlen, + "%.*s %u %s %u %s", + (int) cand[i].foundation.slen, cand[i].foundation.ptr, + cand[i].prio, + pj_sockaddr_print(&cand[i].addr, ipaddr, sizeof(ipaddr), 0), + (unsigned) pj_sockaddr_get_port(&cand[i].addr), + pj_ice_get_cand_type_name(cand[i].type)); + + if (n < 0 || n == maxlen) + return -PJ_ETOOSMALL; + } + + return cand_cnt; +} + +void ice_connect(pj_ice_strans * strans, HsStablePtr cb, + const char * ufrag, const char * pass, + const char * defcand, const char * tcandidates[], size_t ncand) +{ + unsigned def_port = 0; + char def_addr[80]; + pj_bool_t done = PJ_FALSE; + char line[256]; + pj_ice_sess_cand candidates[PJ_ICE_ST_MAX_CAND]; + + struct user_data * udata = pj_ice_strans_get_user_data(strans); + udata->cb_connect = cb; + + def_addr[0] = '\0'; + + if (ncand == 0) { + fprintf(stderr, "ICE: no candidates\n"); + return; + } + + int cnt = sscanf(defcand, "%s %u", def_addr, &def_port); + if (cnt != 2) { + fprintf(stderr, "ICE: error parsing default candidate\n"); + return; + } + + int okcand = 0; + for (int i = 0; i < ncand; i++) { + char foundation[32], ipaddr[80], type[32]; + int prio, port; + + int cnt = sscanf(tcandidates[i], "%s %d %s %d %s", + foundation, &prio, + ipaddr, &port, + type); + if (cnt != 5) + continue; + + pj_ice_sess_cand * cand = &candidates[okcand]; + pj_bzero(cand, sizeof(*cand)); + + if (strcmp(type, "host") == 0) + cand->type = PJ_ICE_CAND_TYPE_HOST; + else if (strcmp(type, "srflx") == 0) + cand->type = PJ_ICE_CAND_TYPE_SRFLX; + else if (strcmp(type, "relay") == 0) + cand->type = PJ_ICE_CAND_TYPE_RELAYED; + else + continue; + + cand->comp_id = 1; + pj_strdup2(ice.pool, &cand->foundation, foundation); + cand->prio = prio; + + int af = strchr(ipaddr, ':') ? pj_AF_INET6() : pj_AF_INET(); + pj_str_t tmpaddr = pj_str(ipaddr); + pj_sockaddr_init(af, &cand->addr, NULL, 0); + pj_status_t status = pj_sockaddr_set_str_addr(af, &cand->addr, &tmpaddr); + if (status != PJ_SUCCESS) { + fprintf(stderr, "ICE: invalid IP address \"%s\"\n", ipaddr); + continue; + } + + pj_sockaddr_set_port(&cand->addr, (pj_uint16_t)port); + okcand++; + } + + pj_str_t tmp_addr; + pj_status_t status; + + int af = strchr(def_addr, ':') ? pj_AF_INET6() : pj_AF_INET(); + + pj_sockaddr_init(af, &ice.def_addr, NULL, 0); + tmp_addr = pj_str(def_addr); + status = pj_sockaddr_set_str_addr(af, &ice.def_addr, &tmp_addr); + if (status != PJ_SUCCESS) { + fprintf(stderr, "ICE: invalid default IP address \"%s\"\n", def_addr); + return; + } + pj_sockaddr_set_port(&ice.def_addr, (pj_uint16_t) def_port); + + pj_str_t rufrag, rpwd; + status = pj_ice_strans_start_ice(strans, + pj_cstr(&rufrag, ufrag), pj_cstr(&rpwd, pass), + okcand, candidates); + if (status != PJ_SUCCESS) { + ice_perror("error starting ICE", status); + return; + } +} + +void ice_send(pj_ice_strans * strans, const char * data, size_t len) +{ + if (!pj_ice_strans_sess_is_complete(strans)) { + fprintf(stderr, "ICE: negotiation has not been started or is in progress\n"); + return; + } + + pj_status_t status = pj_ice_strans_sendto(strans, 1, data, len, + &ice.def_addr, pj_sockaddr_get_len(&ice.def_addr)); + if (status != PJ_SUCCESS && status != PJ_EPENDING) + ice_perror("error sending data", status); +} diff --git a/src/Erebos/ICE/pjproject.h b/src/Erebos/ICE/pjproject.h new file mode 100644 index 0000000..e230e75 --- /dev/null +++ b/src/Erebos/ICE/pjproject.h @@ -0,0 +1,14 @@ +#pragma once + +#include <pjnath.h> +#include <HsFFI.h> + +pj_ice_strans * ice_create(pj_ice_sess_role role, HsStablePtr sptr, HsStablePtr cb); +void ice_destroy(pj_ice_strans * strans); + +ssize_t ice_encode_session(pj_ice_strans *, char * ufrag, char * pass, + char * def, char * candidates[], size_t maxlen, size_t maxcand); +void ice_connect(pj_ice_strans * strans, HsStablePtr cb, + const char * ufrag, const char * pass, + const char * defcand, const char * candidates[], size_t ncand); +void ice_send(pj_ice_strans *, const char * data, size_t len); diff --git a/src/Erebos/Identity.hs b/src/Erebos/Identity.hs new file mode 100644 index 0000000..f2094f6 --- /dev/null +++ b/src/Erebos/Identity.hs @@ -0,0 +1,395 @@ +{-# LANGUAGE UndecidableInstances #-} + +module Erebos.Identity ( + Identity, ComposedIdentity, UnifiedIdentity, + IdentityData(..), ExtendedIdentityData(..), IdentityExtension(..), + idData, idDataF, idExtData, idExtDataF, + idName, idOwner, idUpdates, idKeyIdentity, idKeyMessage, + eiddBase, eiddStoredBase, + eiddName, eiddOwner, eiddKeyIdentity, eiddKeyMessage, + + emptyIdentityData, + emptyIdentityExtension, + createIdentity, + validateIdentity, validateIdentityF, validateIdentityFE, + validateExtendedIdentity, validateExtendedIdentityF, validateExtendedIdentityFE, + loadIdentity, loadUnifiedIdentity, + + mergeIdentity, toUnifiedIdentity, toComposedIdentity, + updateIdentity, updateOwners, + sameIdentity, + + unfoldOwners, + finalOwner, + displayIdentity, +) where + +import Control.Arrow +import Control.Monad +import Control.Monad.Except +import Control.Monad.Identity qualified as I +import Control.Monad.Reader + +import Data.Either +import Data.Foldable +import Data.Function +import Data.List +import Data.Maybe +import Data.Set (Set) +import qualified Data.Set as S +import Data.Text (Text) +import qualified Data.Text as T + +import Erebos.PubKey +import Erebos.Storage +import Erebos.Storage.Merge +import Erebos.Util + +data Identity m = IdentityKind m => Identity + { idData_ :: m (Stored (Signed ExtendedIdentityData)) + , idName_ :: Maybe Text + , idOwner_ :: Maybe ComposedIdentity + , idUpdates_ :: [Stored (Signed ExtendedIdentityData)] + , idKeyIdentity_ :: Stored PublicKey + , idKeyMessage_ :: Stored PublicKey + } + +deriving instance Show (m (Stored (Signed ExtendedIdentityData))) => Show (Identity m) + +class (Functor f, Foldable f) => IdentityKind f where + ikFilterAncestors :: Storable a => f (Stored a) -> f (Stored a) + +instance IdentityKind I.Identity where + ikFilterAncestors = id + +instance IdentityKind [] where + ikFilterAncestors = filterAncestors + +type ComposedIdentity = Identity [] +type UnifiedIdentity = Identity I.Identity + +instance Eq (m (Stored (Signed ExtendedIdentityData))) => Eq (Identity m) where + (==) = (==) `on` (idData_ &&& idUpdates_) + +data IdentityData = IdentityData + { iddPrev :: [Stored (Signed IdentityData)] + , iddName :: Maybe Text + , iddOwner :: Maybe (Stored (Signed IdentityData)) + , iddKeyIdentity :: Stored PublicKey + , iddKeyMessage :: Maybe (Stored PublicKey) + } + deriving (Show) + +data IdentityExtension = IdentityExtension + { idePrev :: [Stored (Signed ExtendedIdentityData)] + , ideBase :: Stored (Signed IdentityData) + , ideName :: Maybe Text + , ideOwner :: Maybe (Stored (Signed ExtendedIdentityData)) + } + deriving (Show) + +data ExtendedIdentityData = BaseIdentityData IdentityData + | ExtendedIdentityData IdentityExtension + deriving (Show) + +baseToExtended :: Stored (Signed IdentityData) -> Stored (Signed ExtendedIdentityData) +baseToExtended = unsafeMapStored (unsafeMapSigned BaseIdentityData) + +instance Storable IdentityData where + store' idt = storeRec $ do + mapM_ (storeRef "SPREV") $ iddPrev idt + storeMbText "name" $ iddName idt + storeMbRef "owner" $ iddOwner idt + storeRef "key-id" $ iddKeyIdentity idt + storeMbRef "key-msg" $ iddKeyMessage idt + + load' = loadRec $ IdentityData + <$> loadRefs "SPREV" + <*> loadMbText "name" + <*> loadMbRef "owner" + <*> loadRef "key-id" + <*> loadMbRef "key-msg" + +instance Storable IdentityExtension where + store' IdentityExtension {..} = storeRec $ do + mapM_ (storeRef "SPREV") idePrev + storeRef "SBASE" ideBase + storeMbText "name" ideName + storeMbRef "owner" ideOwner + + load' = loadRec $ IdentityExtension + <$> loadRefs "SPREV" + <*> loadRef "SBASE" + <*> loadMbText "name" + <*> loadMbRef "owner" + +instance Storable ExtendedIdentityData where + store' (BaseIdentityData idata) = store' idata + store' (ExtendedIdentityData idata) = store' idata + + load' = msum + [ BaseIdentityData <$> load' + , ExtendedIdentityData <$> load' + ] + +instance Mergeable (Maybe ComposedIdentity) where + type Component (Maybe ComposedIdentity) = Signed ExtendedIdentityData + mergeSorted = validateExtendedIdentityF + toComponents = maybe [] idExtDataF + +idData :: UnifiedIdentity -> Stored (Signed IdentityData) +idData = I.runIdentity . idDataF + +idDataF :: Identity m -> m (Stored (Signed IdentityData)) +idDataF idt@Identity {} = ikFilterAncestors . fmap eiddStoredBase . idData_ $ idt + +idExtData :: UnifiedIdentity -> Stored (Signed ExtendedIdentityData) +idExtData = I.runIdentity . idExtDataF + +idExtDataF :: Identity m -> m (Stored (Signed ExtendedIdentityData)) +idExtDataF = idData_ + +idName :: Identity m -> Maybe Text +idName = idName_ + +idOwner :: Identity m -> Maybe ComposedIdentity +idOwner = idOwner_ + +idUpdates :: Identity m -> [Stored (Signed ExtendedIdentityData)] +idUpdates = idUpdates_ + +idKeyIdentity :: Identity m -> Stored PublicKey +idKeyIdentity = idKeyIdentity_ + +idKeyMessage :: Identity m -> Stored PublicKey +idKeyMessage = idKeyMessage_ + +eiddPrev :: ExtendedIdentityData -> [Stored (Signed ExtendedIdentityData)] +eiddPrev (BaseIdentityData idata) = baseToExtended <$> iddPrev idata +eiddPrev (ExtendedIdentityData IdentityExtension {..}) = baseToExtended ideBase : idePrev + +eiddBase :: ExtendedIdentityData -> IdentityData +eiddBase (BaseIdentityData idata) = idata +eiddBase (ExtendedIdentityData IdentityExtension {..}) = fromSigned ideBase + +eiddStoredBase :: Stored (Signed ExtendedIdentityData) -> Stored (Signed IdentityData) +eiddStoredBase ext = case fromSigned ext of + (BaseIdentityData idata) -> unsafeMapStored (unsafeMapSigned (const idata)) ext + (ExtendedIdentityData IdentityExtension {..}) -> ideBase + +eiddName :: ExtendedIdentityData -> Maybe Text +eiddName (BaseIdentityData idata) = iddName idata +eiddName (ExtendedIdentityData IdentityExtension {..}) = ideName + +eiddOwner :: ExtendedIdentityData -> Maybe (Stored (Signed ExtendedIdentityData)) +eiddOwner (BaseIdentityData idata) = baseToExtended <$> iddOwner idata +eiddOwner (ExtendedIdentityData IdentityExtension {..}) = ideOwner + +eiddKeyIdentity :: ExtendedIdentityData -> Stored PublicKey +eiddKeyIdentity = iddKeyIdentity . eiddBase + +eiddKeyMessage :: ExtendedIdentityData -> Maybe (Stored PublicKey) +eiddKeyMessage = iddKeyMessage . eiddBase + + +emptyIdentityData :: Stored PublicKey -> IdentityData +emptyIdentityData key = IdentityData + { iddName = Nothing + , iddPrev = [] + , iddOwner = Nothing + , iddKeyIdentity = key + , iddKeyMessage = Nothing + } + +emptyIdentityExtension :: Stored (Signed IdentityData) -> IdentityExtension +emptyIdentityExtension base = IdentityExtension + { idePrev = [] + , ideBase = base + , ideName = Nothing + , ideOwner = Nothing + } + +isExtension :: Stored (Signed ExtendedIdentityData) -> Bool +isExtension x = case fromSigned x of BaseIdentityData {} -> False + _ -> True + + +createIdentity :: Storage -> Maybe Text -> Maybe UnifiedIdentity -> IO UnifiedIdentity +createIdentity st name owner = do + (secret, public) <- generateKeys st + (_secretMsg, publicMsg) <- generateKeys st + + let signOwner :: Signed a -> ReaderT Storage IO (Signed a) + signOwner idd + | Just o <- owner = do + Just ownerSecret <- loadKeyMb (iddKeyIdentity $ fromSigned $ idData o) + signAdd ownerSecret idd + | otherwise = return idd + + Just identity <- flip runReaderT st $ do + baseData <- mstore =<< signOwner =<< sign secret =<< + mstore (emptyIdentityData public) + { iddOwner = idData <$> owner + , iddKeyMessage = Just publicMsg + } + let extOwner = do + odata <- idExtData <$> owner + guard $ isExtension odata + return odata + + validateExtendedIdentityF . I.Identity <$> + if isJust name || isJust extOwner + then mstore =<< signOwner =<< sign secret =<< + mstore . ExtendedIdentityData =<< return (emptyIdentityExtension baseData) + { ideName = name + , ideOwner = extOwner + } + else return $ baseToExtended baseData + return identity + +validateIdentity :: Stored (Signed IdentityData) -> Maybe UnifiedIdentity +validateIdentity = validateIdentityF . I.Identity + +validateIdentityF :: IdentityKind m => m (Stored (Signed IdentityData)) -> Maybe (Identity m) +validateIdentityF = either (const Nothing) Just . runExcept . validateIdentityFE + +validateIdentityFE :: IdentityKind m => m (Stored (Signed IdentityData)) -> Except String (Identity m) +validateIdentityFE = validateExtendedIdentityFE . fmap baseToExtended + +validateExtendedIdentity :: Stored (Signed ExtendedIdentityData) -> Maybe UnifiedIdentity +validateExtendedIdentity = validateExtendedIdentityF . I.Identity + +validateExtendedIdentityF :: IdentityKind m => m (Stored (Signed ExtendedIdentityData)) -> Maybe (Identity m) +validateExtendedIdentityF = either (const Nothing) Just . runExcept . validateExtendedIdentityFE + +validateExtendedIdentityFE :: IdentityKind m => m (Stored (Signed ExtendedIdentityData)) -> Except String (Identity m) +validateExtendedIdentityFE mdata = do + let idata = ikFilterAncestors mdata + when (null idata) $ throwError "null data" + mapM_ verifySignatures $ gatherPrevious S.empty $ toList idata + Identity + <$> pure idata + <*> pure (lookupProperty eiddName idata) + <*> case lookupProperty eiddOwner idata of + Nothing -> return Nothing + Just owner -> return <$> validateExtendedIdentityFE [owner] + <*> pure [] + <*> pure (eiddKeyIdentity $ fromSigned $ minimum idata) + <*> case lookupProperty eiddKeyMessage idata of + Nothing -> throwError "no message key" + Just mk -> return mk + +loadIdentity :: String -> LoadRec ComposedIdentity +loadIdentity name = maybe (throwError "identity validation failed") return . validateExtendedIdentityF =<< loadRefs name + +loadUnifiedIdentity :: String -> LoadRec UnifiedIdentity +loadUnifiedIdentity name = maybe (throwError "identity validation failed") return . validateExtendedIdentity =<< loadRef name + + +gatherPrevious :: Set (Stored (Signed ExtendedIdentityData)) -> [Stored (Signed ExtendedIdentityData)] -> Set (Stored (Signed ExtendedIdentityData)) +gatherPrevious res (n:ns) | n `S.member` res = gatherPrevious res ns + | otherwise = gatherPrevious (S.insert n res) $ (eiddPrev $ fromSigned n) ++ ns +gatherPrevious res [] = res + +verifySignatures :: Stored (Signed ExtendedIdentityData) -> Except String () +verifySignatures sidd = do + let idd = fromSigned sidd + required = concat + [ [ eiddKeyIdentity idd ] + , map (eiddKeyIdentity . fromSigned) $ eiddPrev idd + , map (eiddKeyIdentity . fromSigned) $ toList $ eiddOwner idd + ] + unless (all (fromStored sidd `isSignedBy`) required) $ do + throwError "signature verification failed" + +lookupProperty :: forall a m. Foldable m => (ExtendedIdentityData -> Maybe a) -> m (Stored (Signed ExtendedIdentityData)) -> Maybe a +lookupProperty sel topHeads = findResult propHeads + where + findPropHeads :: Stored (Signed ExtendedIdentityData) -> [ Stored (Signed ExtendedIdentityData) ] + findPropHeads sobj | Just _ <- sel $ fromSigned sobj = [ sobj ] + | otherwise = findPropHeads =<< (eiddPrev $ fromSigned sobj) + + propHeads :: [ Stored (Signed ExtendedIdentityData) ] + propHeads = filterAncestors $ findPropHeads =<< toList topHeads + + findResult :: [ Stored (Signed ExtendedIdentityData) ] -> Maybe a + findResult [] = Nothing + findResult xs = sel $ fromSigned $ minimum xs + +mergeIdentity :: (MonadStorage m, MonadError String m, MonadIO m) => Identity f -> m UnifiedIdentity +mergeIdentity idt | Just idt' <- toUnifiedIdentity idt = return idt' +mergeIdentity idt@Identity {..} = do + (owner, ownerData) <- case idOwner_ of + Nothing -> return (Nothing, Nothing) + Just cowner | Just owner <- toUnifiedIdentity cowner -> return (Just owner, Nothing) + | otherwise -> do owner <- mergeIdentity cowner + return (Just owner, Just $ idData owner) + + let public = idKeyIdentity idt + secret <- loadKey public + + unifiedBaseData <- + case toList $ idDataF idt of + [idata] -> return idata + idatas -> mstore =<< sign secret =<< mstore (emptyIdentityData public) + { iddPrev = idatas, iddOwner = ownerData } + + case filter isExtension $ toList $ idExtDataF idt of + [] -> return Identity { idData_ = I.Identity (baseToExtended unifiedBaseData), idOwner_ = toComposedIdentity <$> owner, .. } + extdata -> do + unifiedExtendedData <- mstore =<< sign secret =<< + (mstore . ExtendedIdentityData) (emptyIdentityExtension unifiedBaseData) + { idePrev = extdata } + return Identity { idData_ = I.Identity unifiedExtendedData, idOwner_ = toComposedIdentity <$> owner, .. } + + +toUnifiedIdentity :: Identity m -> Maybe UnifiedIdentity +toUnifiedIdentity Identity {..} + | [sdata] <- toList idData_ = Just Identity { idData_ = I.Identity sdata, .. } + | otherwise = Nothing + +toComposedIdentity :: Identity m -> ComposedIdentity +toComposedIdentity Identity {..} = Identity { idData_ = toList idData_ + , idOwner_ = toComposedIdentity <$> idOwner_ + , .. + } + +updateIdentity :: [Stored (Signed ExtendedIdentityData)] -> Identity m -> ComposedIdentity +updateIdentity [] orig = toComposedIdentity orig +updateIdentity updates orig@Identity {} = + case validateExtendedIdentityF $ ourUpdates ++ idata of + Just updated -> updated + { idOwner_ = updateIdentity ownerUpdates <$> idOwner_ updated + , idUpdates_ = ownerUpdates + } + Nothing -> toComposedIdentity orig + where idata = toList $ idData_ orig + idataRoots = foldl' mergeUniq [] $ map storedRoots idata + (ourUpdates, ownerUpdates) = partitionEithers $ flip map (filterAncestors $ updates ++ idUpdates_ orig) $ + -- if an update is related to anything in idData_, use it here, otherwise push to owners + \u -> if storedRoots u `intersectsSorted` idataRoots + then Left u + else Right u + +updateOwners :: [Stored (Signed ExtendedIdentityData)] -> Identity m -> Identity m +updateOwners updates orig@Identity { idOwner_ = Just owner, idUpdates_ = cupdates } = + orig { idOwner_ = Just $ updateIdentity updates owner, idUpdates_ = filterAncestors (updates ++ cupdates) } +updateOwners _ orig@Identity { idOwner_ = Nothing } = orig + +sameIdentity :: (Foldable m, Foldable m') => Identity m -> Identity m' -> Bool +sameIdentity x y = intersectsSorted (roots x) (roots y) + where + roots idt = uniq $ sort $ concatMap storedRoots $ toList $ idDataF idt + + +unfoldOwners :: (Foldable m) => Identity m -> [ComposedIdentity] +unfoldOwners = unfoldr (fmap (\i -> (i, idOwner i))) . Just . toComposedIdentity + +finalOwner :: (Foldable m, Applicative m) => Identity m -> ComposedIdentity +finalOwner = last . unfoldOwners + +displayIdentity :: (Foldable m, Applicative m) => Identity m -> Text +displayIdentity identity = T.concat + [ T.intercalate (T.pack " / ") $ map (fromMaybe (T.pack "<unnamed>") . idName) owners + ] + where owners = reverse $ unfoldOwners identity diff --git a/src/Erebos/Message.hs b/src/Erebos/Message.hs new file mode 100644 index 0000000..5ef27f3 --- /dev/null +++ b/src/Erebos/Message.hs @@ -0,0 +1,272 @@ +module Erebos.Message ( + DirectMessage(..), + sendDirectMessage, + + DirectMessageAttributes(..), + defaultDirectMessageAttributes, + + DirectMessageThreads, + toThreadList, + + DirectMessageThread(..), + threadToList, + messageThreadView, + + watchReceivedMessages, + formatMessage, + formatDirectMessage, +) where + +import Control.Monad +import Control.Monad.Except +import Control.Monad.Reader + +import Data.List +import Data.Ord +import qualified Data.Set as S +import Data.Text (Text) +import qualified Data.Text as T +import Data.Time.Format +import Data.Time.LocalTime + +import Erebos.Identity +import Erebos.Network +import Erebos.Service +import Erebos.State +import Erebos.Storage +import Erebos.Storage.Merge + +data DirectMessage = DirectMessage + { msgFrom :: ComposedIdentity + , msgPrev :: [Stored DirectMessage] + , msgTime :: ZonedTime + , msgText :: Text + } + +instance Storable DirectMessage where + store' msg = storeRec $ do + mapM_ (storeRef "from") $ idExtDataF $ msgFrom msg + mapM_ (storeRef "PREV") $ msgPrev msg + storeDate "time" $ msgTime msg + storeText "text" $ msgText msg + + load' = loadRec $ DirectMessage + <$> loadIdentity "from" + <*> loadRefs "PREV" + <*> loadDate "time" + <*> loadText "text" + +data DirectMessageAttributes = DirectMessageAttributes + { dmOwnerMismatch :: ServiceHandler DirectMessage () + } + +defaultDirectMessageAttributes :: DirectMessageAttributes +defaultDirectMessageAttributes = DirectMessageAttributes + { dmOwnerMismatch = svcPrint "Owner mismatch" + } + +instance Service DirectMessage where + serviceID _ = mkServiceID "c702076c-4928-4415-8b6b-3e839eafcb0d" + + type ServiceAttributes DirectMessage = DirectMessageAttributes + defaultServiceAttributes _ = defaultDirectMessageAttributes + + serviceHandler smsg = do + let msg = fromStored smsg + powner <- asks $ finalOwner . svcPeerIdentity + erb <- svcGetLocal + st <- getStorage + let DirectMessageThreads prev _ = lookupSharedValue $ lsShared $ fromStored erb + sent = findMsgProperty powner msSent prev + received = findMsgProperty powner msReceived prev + received' = filterAncestors $ smsg : received + if powner `sameIdentity` msgFrom msg || + filterAncestors sent == filterAncestors (smsg : sent) + then do + when (received' /= received) $ do + next <- wrappedStore st $ MessageState + { msPrev = prev + , msPeer = powner + , msReady = [] + , msSent = [] + , msReceived = received' + , msSeen = [] + } + let threads = DirectMessageThreads [next] (messageThreadView [next]) + shared <- makeSharedStateUpdate st threads (lsShared $ fromStored erb) + svcSetLocal =<< wrappedStore st (fromStored erb) { lsShared = [shared] } + + when (powner `sameIdentity` msgFrom msg) $ do + replyStoredRef smsg + + else join $ asks $ dmOwnerMismatch . svcAttributes + + serviceNewPeer = syncDirectMessageToPeer . lookupSharedValue . lsShared . fromStored =<< svcGetLocal + + serviceStorageWatchers _ = (:[]) $ + SomeStorageWatcher (lookupSharedValue . lsShared . fromStored) syncDirectMessageToPeer + + +data MessageState = MessageState + { msPrev :: [Stored MessageState] + , msPeer :: ComposedIdentity + , msReady :: [Stored DirectMessage] + , msSent :: [Stored DirectMessage] + , msReceived :: [Stored DirectMessage] + , msSeen :: [Stored DirectMessage] + } + +data DirectMessageThreads = DirectMessageThreads [Stored MessageState] [DirectMessageThread] + +instance Eq DirectMessageThreads where + DirectMessageThreads mss _ == DirectMessageThreads mss' _ = mss == mss' + +toThreadList :: DirectMessageThreads -> [DirectMessageThread] +toThreadList (DirectMessageThreads _ threads) = threads + +instance Storable MessageState where + store' MessageState {..} = storeRec $ do + mapM_ (storeRef "PREV") msPrev + mapM_ (storeRef "peer") $ idExtDataF msPeer + mapM_ (storeRef "ready") msReady + mapM_ (storeRef "sent") msSent + mapM_ (storeRef "received") msReceived + mapM_ (storeRef "seen") msSeen + + load' = loadRec $ do + msPrev <- loadRefs "PREV" + msPeer <- loadIdentity "peer" + msReady <- loadRefs "ready" + msSent <- loadRefs "sent" + msReceived <- loadRefs "received" + msSeen <- loadRefs "seen" + return MessageState {..} + +instance Mergeable DirectMessageThreads where + type Component DirectMessageThreads = MessageState + mergeSorted mss = DirectMessageThreads mss (messageThreadView mss) + toComponents (DirectMessageThreads mss _) = mss + +instance SharedType DirectMessageThreads where + sharedTypeID _ = mkSharedTypeID "ee793681-5976-466a-b0f0-4e1907d3fade" + +findMsgProperty :: Foldable m => Identity m -> (MessageState -> [a]) -> [Stored MessageState] -> [a] +findMsgProperty pid sel mss = concat $ flip findProperty mss $ \x -> do + guard $ msPeer x `sameIdentity` pid + guard $ not $ null $ sel x + return $ sel x + + +sendDirectMessage :: (Foldable f, Applicative f, MonadHead LocalState m, MonadError String m) + => Identity f -> Text -> m (Stored DirectMessage) +sendDirectMessage pid text = updateLocalHead $ \ls -> do + let self = localIdentity $ fromStored ls + powner = finalOwner pid + flip updateSharedState ls $ \(DirectMessageThreads prev _) -> do + let ready = findMsgProperty powner msReady prev + received = findMsgProperty powner msReceived prev + + time <- liftIO getZonedTime + smsg <- mstore DirectMessage + { msgFrom = toComposedIdentity $ finalOwner self + , msgPrev = filterAncestors $ ready ++ received + , msgTime = time + , msgText = text + } + next <- mstore MessageState + { msPrev = prev + , msPeer = powner + , msReady = [smsg] + , msSent = [] + , msReceived = [] + , msSeen = [] + } + return (DirectMessageThreads [next] (messageThreadView [next]), smsg) + +syncDirectMessageToPeer :: DirectMessageThreads -> ServiceHandler DirectMessage () +syncDirectMessageToPeer (DirectMessageThreads mss _) = do + pid <- finalOwner <$> asks svcPeerIdentity + peer <- asks svcPeer + let thread = messageThreadFor pid mss + mapM_ (sendToPeerStored peer) $ msgHead thread + updateLocalHead_ $ \ls -> do + let powner = finalOwner pid + flip updateSharedState_ ls $ \unchanged@(DirectMessageThreads prev _) -> do + let ready = findMsgProperty powner msReady prev + sent = findMsgProperty powner msSent prev + sent' = filterAncestors (ready ++ sent) + + if sent' /= sent + then do + next <- mstore MessageState + { msPrev = prev + , msPeer = powner + , msReady = [] + , msSent = sent' + , msReceived = [] + , msSeen = [] + } + return $ DirectMessageThreads [next] (messageThreadView [next]) + else do + return unchanged + + +data DirectMessageThread = DirectMessageThread + { msgPeer :: ComposedIdentity + , msgHead :: [Stored DirectMessage] + , msgSent :: [Stored DirectMessage] + , msgSeen :: [Stored DirectMessage] + } + +threadToList :: DirectMessageThread -> [DirectMessage] +threadToList thread = helper S.empty $ msgHead thread + where helper seen msgs + | msg : msgs' <- filter (`S.notMember` seen) $ reverse $ sortBy (comparing cmpView) msgs = + fromStored msg : helper (S.insert msg seen) (msgs' ++ msgPrev (fromStored msg)) + | otherwise = [] + cmpView msg = (zonedTimeToUTC $ msgTime $ fromStored msg, msg) + +messageThreadView :: [Stored MessageState] -> [DirectMessageThread] +messageThreadView = helper [] + where helper used ms' = case filterAncestors ms' of + mss@(sms : rest) + | any (sameIdentity $ msPeer $ fromStored sms) used -> + helper used $ msPrev (fromStored sms) ++ rest + | otherwise -> + let peer = msPeer $ fromStored sms + in messageThreadFor peer mss : helper (peer : used) (msPrev (fromStored sms) ++ rest) + _ -> [] + +messageThreadFor :: ComposedIdentity -> [Stored MessageState] -> DirectMessageThread +messageThreadFor peer mss = + let ready = findMsgProperty peer msReady mss + sent = findMsgProperty peer msSent mss + received = findMsgProperty peer msReceived mss + seen = findMsgProperty peer msSeen mss + + in DirectMessageThread + { msgPeer = peer + , msgHead = filterAncestors $ ready ++ received + , msgSent = filterAncestors $ sent ++ received + , msgSeen = filterAncestors $ ready ++ seen + } + + +watchReceivedMessages :: Head LocalState -> (Stored DirectMessage -> IO ()) -> IO WatchedHead +watchReceivedMessages h f = do + let self = finalOwner $ localIdentity $ headObject h + watchHeadWith h (lookupSharedValue . lsShared . headObject) $ \(DirectMessageThreads sms _) -> do + forM_ (map fromStored sms) $ \ms -> do + mapM_ f $ filter (not . sameIdentity self . msgFrom . fromStored) $ msReceived ms + +{-# DEPRECATED formatMessage "use formatDirectMessage instead" #-} +formatMessage :: TimeZone -> DirectMessage -> String +formatMessage = formatDirectMessage + +formatDirectMessage :: TimeZone -> DirectMessage -> String +formatDirectMessage tzone msg = concat + [ formatTime defaultTimeLocale "[%H:%M] " $ utcToLocalTime tzone $ zonedTimeToUTC $ msgTime msg + , maybe "<unnamed>" T.unpack $ idName $ msgFrom msg + , ": " + , T.unpack $ msgText msg + ] diff --git a/src/Erebos/Network.hs b/src/Erebos/Network.hs new file mode 100644 index 0000000..2064d1c --- /dev/null +++ b/src/Erebos/Network.hs @@ -0,0 +1,981 @@ +{-# LANGUAGE CPP #-} + +module Erebos.Network ( + Server, + startServer, + stopServer, + getCurrentPeerList, + getNextPeerChange, + ServerOptions(..), serverIdentity, defaultServerOptions, + + Peer, peerServer, peerStorage, + PeerAddress(..), peerAddress, + PeerIdentity(..), peerIdentity, + WaitingRef, wrDigest, + Service(..), + serverPeer, +#ifdef ENABLE_ICE_SUPPORT + serverPeerIce, +#endif + dropPeer, + isPeerDropped, + sendToPeer, sendManyToPeer, + sendToPeerStored, sendManyToPeerStored, + sendToPeerWith, + runPeerService, + + discoveryPort, +) where + +import Control.Concurrent +import Control.Concurrent.STM +import Control.Exception +import Control.Monad +import Control.Monad.Except +import Control.Monad.Reader +import Control.Monad.State + +import Data.ByteString.Char8 qualified as BC +import Data.ByteString.Lazy qualified as BL +import Data.Function +import Data.IP qualified as IP +import Data.List +import Data.Map (Map) +import qualified Data.Map as M +import Data.Maybe +import Data.Typeable +import Data.Word + +import Foreign.Ptr +import Foreign.Storable + +import GHC.Conc.Sync (unsafeIOToSTM) + +import Network.Socket hiding (ControlMessage) +import qualified Network.Socket.ByteString as S + +import Foreign.C.Types +import Foreign.Marshal.Alloc + +import Erebos.Channel +#ifdef ENABLE_ICE_SUPPORT +import Erebos.ICE +#endif +import Erebos.Identity +import Erebos.Network.Protocol +import Erebos.PubKey +import Erebos.Service +import Erebos.State +import Erebos.Storage +import Erebos.Storage.Key +import Erebos.Storage.Merge + + +discoveryPort :: PortNumber +discoveryPort = 29665 + +discoveryMulticastGroup :: HostAddress6 +discoveryMulticastGroup = tupleToHostAddress6 (0xff12, 0xb6a4, 0x6b1f, 0x0969, 0xcaee, 0xacc2, 0x5c93, 0x73e1) -- ff12:b6a4:6b1f:969:caee:acc2:5c93:73e1 + +announceIntervalSeconds :: Int +announceIntervalSeconds = 60 + + +data Server = Server + { serverStorage :: Storage + , serverOrigHead :: Head LocalState + , serverIdentity_ :: MVar UnifiedIdentity + , serverThreads :: MVar [ThreadId] + , serverSocket :: MVar Socket + , serverRawPath :: SymFlow (PeerAddress, BC.ByteString) + , serverControlFlow :: Flow (ControlMessage PeerAddress) (ControlRequest PeerAddress) + , serverDataResponse :: TQueue (Peer, Maybe PartialRef) + , serverIOActions :: TQueue (ExceptT String IO ()) + , serverServices :: [SomeService] + , serverServiceStates :: TMVar (M.Map ServiceID SomeServiceGlobalState) + , serverPeers :: MVar (Map PeerAddress Peer) + , serverChanPeer :: TChan Peer + , serverErrorLog :: TQueue String + } + +serverIdentity :: Server -> IO UnifiedIdentity +serverIdentity = readMVar . serverIdentity_ + +getCurrentPeerList :: Server -> IO [Peer] +getCurrentPeerList = fmap M.elems . readMVar . serverPeers + +getNextPeerChange :: Server -> IO Peer +getNextPeerChange = atomically . readTChan . serverChanPeer + +data ServerOptions = ServerOptions + { serverPort :: PortNumber + , serverLocalDiscovery :: Bool + } + +defaultServerOptions :: ServerOptions +defaultServerOptions = ServerOptions + { serverPort = discoveryPort + , serverLocalDiscovery = True + } + + +data Peer = Peer + { peerAddress :: PeerAddress + , peerServer_ :: Server + , peerState :: TVar PeerState + , peerIdentityVar :: TVar PeerIdentity + , peerStorage_ :: Storage + , peerInStorage :: PartialStorage + , peerServiceState :: TMVar (M.Map ServiceID SomeServiceState) + , peerWaitingRefs :: TMVar [WaitingRef] + } + +peerServer :: Peer -> Server +peerServer = peerServer_ + +peerStorage :: Peer -> Storage +peerStorage = peerStorage_ + +getPeerChannel :: Peer -> STM ChannelState +getPeerChannel Peer {..} = + readTVar peerState >>= \case + PeerInit _ -> return ChannelNone + PeerConnected conn -> connGetChannel conn + PeerDropped -> return ChannelClosed + +setPeerChannel :: Peer -> ChannelState -> STM () +setPeerChannel Peer {..} ch = do + readTVar peerState >>= \case + PeerInit _ -> retry + PeerConnected conn -> connSetChannel conn ch + PeerDropped -> return () + +instance Eq Peer where + (==) = (==) `on` peerIdentityVar + +data PeerAddress = DatagramAddress SockAddr +#ifdef ENABLE_ICE_SUPPORT + | PeerIceSession IceSession +#endif + +instance Show PeerAddress where + show (DatagramAddress saddr) = unwords $ case IP.fromSockAddr saddr of + Just (IP.IPv6 ipv6, port) + | (0, 0, 0xffff, ipv4) <- IP.fromIPv6w ipv6 + -> [show (IP.toIPv4w ipv4), show port] + Just (addr, port) + -> [show addr, show port] + _ -> [show saddr] +#ifdef ENABLE_ICE_SUPPORT + show (PeerIceSession ice) = show ice +#endif + +instance Eq PeerAddress where + DatagramAddress addr == DatagramAddress addr' = addr == addr' +#ifdef ENABLE_ICE_SUPPORT + PeerIceSession ice == PeerIceSession ice' = ice == ice' + _ == _ = False +#endif + +instance Ord PeerAddress where + compare (DatagramAddress addr) (DatagramAddress addr') = compare addr addr' +#ifdef ENABLE_ICE_SUPPORT + compare (DatagramAddress _ ) _ = LT + compare _ (DatagramAddress _ ) = GT + compare (PeerIceSession ice ) (PeerIceSession ice') = compare ice ice' +#endif + + +data PeerIdentity = PeerIdentityUnknown (TVar [UnifiedIdentity -> ExceptT String IO ()]) + | PeerIdentityRef WaitingRef (TVar [UnifiedIdentity -> ExceptT String IO ()]) + | PeerIdentityFull UnifiedIdentity + +peerIdentity :: MonadIO m => Peer -> m PeerIdentity +peerIdentity = liftIO . atomically . readTVar . peerIdentityVar + + +data PeerState = PeerInit [(SecurityRequirement, TransportPacket Ref, [TransportHeaderItem])] + | PeerConnected (Connection PeerAddress) + | PeerDropped + + +lookupServiceType :: [TransportHeaderItem] -> Maybe ServiceID +lookupServiceType (ServiceType stype : _) = Just stype +lookupServiceType (_ : hs) = lookupServiceType hs +lookupServiceType [] = Nothing + +lookupNewStreams :: [TransportHeaderItem] -> [Word8] +lookupNewStreams (StreamOpen num : rest) = num : lookupNewStreams rest +lookupNewStreams (_ : rest) = lookupNewStreams rest +lookupNewStreams [] = [] + + +newWaitingRef :: RefDigest -> (Ref -> WaitingRefCallback) -> PacketHandler WaitingRef +newWaitingRef dgst act = do + peer@Peer {..} <- gets phPeer + wref <- WaitingRef peerStorage_ (partialRefFromDigest peerInStorage dgst) act <$> liftSTM (newTVar (Left [])) + modifyTMVarP peerWaitingRefs (wref:) + liftSTM $ writeTQueue (serverDataResponse $ peerServer peer) (peer, Nothing) + return wref + + +forkServerThread :: Server -> IO () -> IO () +forkServerThread server act = do + modifyMVar_ (serverThreads server) $ \ts -> do + t <- forkIO $ do + t <- myThreadId + act + modifyMVar_ (serverThreads server) $ return . filter (/=t) + return (t:ts) + +startServer :: ServerOptions -> Head LocalState -> (String -> IO ()) -> [SomeService] -> IO Server +startServer opt serverOrigHead logd' serverServices = do + let serverStorage = headStorage serverOrigHead + serverIdentity_ <- newMVar $ headLocalIdentity serverOrigHead + serverThreads <- newMVar [] + serverSocket <- newEmptyMVar + (serverRawPath, protocolRawPath) <- newFlowIO + (serverControlFlow, protocolControlFlow) <- newFlowIO + serverDataResponse <- newTQueueIO + serverIOActions <- newTQueueIO + serverServiceStates <- newTMVarIO M.empty + serverPeers <- newMVar M.empty + serverChanPeer <- newTChanIO + serverErrorLog <- newTQueueIO + let server = Server {..} + + chanSvc <- newTQueueIO + + let logd = writeTQueue serverErrorLog + forkServerThread server $ forever $ do + logd' =<< atomically (readTQueue serverErrorLog) + + forkServerThread server $ dataResponseWorker server + forkServerThread server $ forever $ do + either (atomically . logd) return =<< runExceptT =<< + atomically (readTQueue serverIOActions) + + let open addr = do + sock <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr) + putMVar serverSocket sock + setSocketOption sock ReuseAddr 1 + setSocketOption sock Broadcast 1 + withFdSocket sock setCloseOnExecIfNeeded + bind sock (addrAddress addr) + return sock + + loop sock = do + when (serverLocalDiscovery opt) $ forkServerThread server $ do + announceAddreses <- fmap concat $ sequence $ + [ map (SockAddrInet6 discoveryPort 0 discoveryMulticastGroup) <$> joinMulticast sock + , getBroadcastAddresses discoveryPort + ] + forever $ do + atomically $ writeFlowBulk serverControlFlow $ map (SendAnnounce . DatagramAddress) announceAddreses + threadDelay $ announceIntervalSeconds * 1000 * 1000 + + let announceUpdate identity = do + st <- derivePartialStorage serverStorage + let selfRef = partialRef st $ storedRef $ idExtData identity + updateRefs = map refDigest $ selfRef : map (partialRef st . storedRef) (idUpdates identity) + ackedBy = concat [[ Acknowledged r, Rejected r, DataRequest r ] | r <- updateRefs ] + hitems = map AnnounceUpdate updateRefs + packet = TransportPacket (TransportHeader $ hitems) [] + + ps <- readMVar serverPeers + forM_ ps $ \peer -> atomically $ do + ((,) <$> readTVar (peerIdentityVar peer) <*> getPeerChannel peer) >>= \case + (PeerIdentityFull _, ChannelEstablished _) -> + sendToPeerS peer ackedBy packet + _ -> return () + + void $ watchHead serverOrigHead $ \h -> do + let idt = headLocalIdentity h + changedId <- modifyMVar serverIdentity_ $ \cur -> + return (idt, cur /= idt) + when changedId $ do + writeFlowIO serverControlFlow $ UpdateSelfIdentity idt + announceUpdate idt + + forM_ serverServices $ \(SomeService service _) -> do + forM_ (serviceStorageWatchers service) $ \(SomeStorageWatcher sel act) -> do + watchHeadWith serverOrigHead (sel . headStoredObject) $ \x -> do + withMVar serverPeers $ mapM_ $ \peer -> atomically $ do + readTVar (peerIdentityVar peer) >>= \case + PeerIdentityFull _ -> writeTQueue serverIOActions $ do + runPeerService peer $ act x + _ -> return () + + forkServerThread server $ forever $ do + (msg, saddr) <- S.recvFrom sock 4096 + writeFlowIO serverRawPath (DatagramAddress saddr, msg) + + forkServerThread server $ forever $ do + (paddr, msg) <- readFlowIO serverRawPath + handle (\(e :: IOException) -> atomically . logd $ "failed to send packet to " ++ show paddr ++ ": " ++ show e) $ do + case paddr of + DatagramAddress addr -> void $ S.sendTo sock msg addr +#ifdef ENABLE_ICE_SUPPORT + PeerIceSession ice -> iceSend ice msg +#endif + + forkServerThread server $ forever $ do + readFlowIO serverControlFlow >>= \case + NewConnection conn mbpid -> do + let paddr = connAddress conn + peer <- modifyMVar serverPeers $ \pvalue -> do + case M.lookup paddr pvalue of + Just peer -> return (pvalue, peer) + Nothing -> do + peer <- mkPeer server paddr + return (M.insert paddr peer pvalue, peer) + + forkServerThread server $ do + atomically $ do + readTVar (peerState peer) >>= \case + PeerInit packets -> do + writeFlowBulk (connData conn) $ reverse packets + writeTVar (peerState peer) (PeerConnected conn) + PeerConnected _ -> do + writeTVar (peerState peer) (PeerConnected conn) + PeerDropped -> do + connClose conn + + case mbpid of + Just dgst -> do + identity <- readMVar serverIdentity_ + atomically $ runPacketHandler False peer $ do + wref <- newWaitingRef dgst $ handleIdentityAnnounce identity peer + readTVarP (peerIdentityVar peer) >>= \case + PeerIdentityUnknown idwait -> do + addHeader $ AnnounceSelf $ refDigest $ storedRef $ idData identity + writeTVarP (peerIdentityVar peer) $ PeerIdentityRef wref idwait + liftSTM $ writeTChan serverChanPeer peer + _ -> return () + Nothing -> return () + + let peerLoop = readFlowIO (connData conn) >>= \case + Just (secure, TransportPacket header objs) -> do + prefs <- forM objs $ storeObject $ peerInStorage peer + identity <- readMVar serverIdentity_ + let svcs = map someServiceID serverServices + handlePacket identity secure peer chanSvc svcs header prefs + peerLoop + Nothing -> do + dropPeer peer + atomically $ writeTChan serverChanPeer peer + peerLoop + + ReceivedAnnounce addr _ -> do + void $ serverPeer' server addr + + erebosNetworkProtocol (headLocalIdentity serverOrigHead) logd protocolRawPath protocolControlFlow + + forkServerThread server $ withSocketsDo $ do + let hints = defaultHints + { addrFlags = [AI_PASSIVE] + , addrFamily = AF_INET6 + , addrSocketType = Datagram + } + addr:_ <- getAddrInfo (Just hints) Nothing (Just $ show $ serverPort opt) + bracket (open addr) close loop + + forkServerThread server $ forever $ do + (peer, svc, ref) <- atomically $ readTQueue chanSvc + case find ((svc ==) . someServiceID) serverServices of + Just service@(SomeService (_ :: Proxy s) attr) -> runPeerServiceOn (Just (service, attr)) peer (serviceHandler $ wrappedLoad @s ref) + _ -> atomically $ logd $ "unhandled service '" ++ show (toUUID svc) ++ "'" + + return server + +stopServer :: Server -> IO () +stopServer Server {..} = do + mapM_ killThread =<< takeMVar serverThreads + +dataResponseWorker :: Server -> IO () +dataResponseWorker server = forever $ do + (peer, npref) <- atomically (readTQueue $ serverDataResponse server) + + wait <- atomically $ takeTMVar (peerWaitingRefs peer) + list <- forM wait $ \wr@WaitingRef { wrefStatus = tvar } -> + atomically (readTVar tvar) >>= \case + Left ds -> case maybe id (filter . (/=) . refDigest) npref $ ds of + [] -> copyRef (wrefStorage wr) (wrefPartial wr) >>= \case + Right ref -> do + atomically (writeTVar tvar $ Right ref) + forkServerThread server $ runExceptT (wrefAction wr ref) >>= \case + Left err -> atomically $ writeTQueue (serverErrorLog server) err + Right () -> return () + + return (Nothing, []) + Left dgst -> do + atomically (writeTVar tvar $ Left [dgst]) + return (Just wr, [dgst]) + ds' -> do + atomically (writeTVar tvar $ Left ds') + return (Just wr, []) + Right _ -> return (Nothing, []) + atomically $ putTMVar (peerWaitingRefs peer) $ catMaybes $ map fst list + + let reqs = concat $ map snd list + when (not $ null reqs) $ do + let packet = TransportPacket (TransportHeader $ map DataRequest reqs) [] + ackedBy = concat [[ Rejected r, DataResponse r ] | r <- reqs ] + atomically $ sendToPeerPlain peer ackedBy packet + + +newtype PacketHandler a = PacketHandler { unPacketHandler :: StateT PacketHandlerState (ExceptT String STM) a } + deriving (Functor, Applicative, Monad, MonadState PacketHandlerState, MonadError String) + +instance MonadFail PacketHandler where + fail = throwError + +runPacketHandler :: Bool -> Peer -> PacketHandler () -> STM () +runPacketHandler secure peer@Peer {..} act = do + let logd = writeTQueue $ serverErrorLog peerServer_ + runExceptT (flip execStateT (PacketHandlerState peer [] [] [] Nothing False) $ unPacketHandler act) >>= \case + Left err -> do + logd $ "Error in handling packet from " ++ show peerAddress ++ ": " ++ err + Right ph -> do + when (not $ null $ phHead ph) $ do + body <- case phBodyStream ph of + Nothing -> return $ phBody ph + Just stream -> do + writeTQueue (serverIOActions peerServer_) $ void $ liftIO $ forkIO $ do + writeByteStringToStream stream $ BL.concat $ map lazyLoadBytes $ phBody ph + return [] + let packet = TransportPacket (TransportHeader $ phHead ph) body + secreq = case (secure, phPlaintextReply ph) of + (True, _) -> EncryptedOnly + (False, False) -> PlaintextAllowed + (False, True) -> PlaintextOnly + sendToPeerS' secreq peer (phAckedBy ph) packet + +liftSTM :: STM a -> PacketHandler a +liftSTM = PacketHandler . lift . lift + +readTVarP :: TVar a -> PacketHandler a +readTVarP = liftSTM . readTVar + +writeTVarP :: TVar a -> a -> PacketHandler () +writeTVarP v = liftSTM . writeTVar v + +modifyTMVarP :: TMVar a -> (a -> a) -> PacketHandler () +modifyTMVarP v f = liftSTM $ putTMVar v . f =<< takeTMVar v + +data PacketHandlerState = PacketHandlerState + { phPeer :: Peer + , phHead :: [TransportHeaderItem] + , phAckedBy :: [TransportHeaderItem] + , phBody :: [Ref] + , phBodyStream :: Maybe RawStreamWriter + , phPlaintextReply :: Bool + } + +addHeader :: TransportHeaderItem -> PacketHandler () +addHeader h = modify $ \ph -> ph { phHead = h `appendDistinct` phHead ph } + +addAckedBy :: [TransportHeaderItem] -> PacketHandler () +addAckedBy hs = modify $ \ph -> ph { phAckedBy = foldr appendDistinct (phAckedBy ph) hs } + +addBody :: Ref -> PacketHandler () +addBody r = modify $ \ph -> ph { phBody = r `appendDistinct` phBody ph } + +sendBodyAsStream :: PacketHandler () +sendBodyAsStream = do + gets phBodyStream >>= \case + Nothing -> do + stream <- openStream + modify $ \ph -> ph { phBodyStream = Just stream } + Just _ -> return () + +keepPlaintextReply :: PacketHandler () +keepPlaintextReply = modify $ \ph -> ph { phPlaintextReply = True } + +openStream :: PacketHandler RawStreamWriter +openStream = do + Peer {..} <- gets phPeer + conn <- readTVarP peerState >>= \case + PeerConnected conn -> return conn + _ -> throwError "can't open stream without established connection" + (hdr, writer, handler) <- liftSTM (connAddWriteStream conn) >>= \case + Right res -> return res + Left err -> throwError err + + liftSTM $ writeTQueue (serverIOActions peerServer_) (liftIO $ forkServerThread peerServer_ handler) + addHeader hdr + return writer + +acceptStream :: Word8 -> PacketHandler RawStreamReader +acceptStream streamNumber = do + Peer {..} <- gets phPeer + conn <- readTVarP peerState >>= \case + PeerConnected conn -> return conn + _ -> throwError "can't accept stream without established connection" + liftSTM $ connAddReadStream conn streamNumber + +appendDistinct :: Eq a => a -> [a] -> [a] +appendDistinct x (y:ys) | x == y = y : ys + | otherwise = y : appendDistinct x ys +appendDistinct x [] = [x] + +handlePacket :: UnifiedIdentity -> Bool + -> Peer -> TQueue (Peer, ServiceID, Ref) -> [ServiceID] + -> TransportHeader -> [PartialRef] -> IO () +handlePacket identity secure peer chanSvc svcs (TransportHeader headers) prefs = atomically $ do + let server = peerServer peer + ochannel <- getPeerChannel peer + let sidentity = idData identity + plaintextRefs = map (refDigest . storedRef) $ concatMap (collectStoredObjects . wrappedLoad) $ concat + [ [ storedRef sidentity ] + , map storedRef $ idUpdates identity + , case ochannel of + ChannelOurRequest req -> [ storedRef req ] + ChannelOurAccept acc _ -> [ storedRef acc ] + _ -> [] + ] + + runPacketHandler secure peer $ do + let logd = liftSTM . writeTQueue (serverErrorLog server) + forM_ headers $ \case + Acknowledged dgst -> do + liftSTM (getPeerChannel peer) >>= \case + ChannelOurAccept acc ch | refDigest (storedRef acc) == dgst -> do + liftSTM $ finalizedChannel peer ch identity + _ -> return () + + Rejected dgst + | peerRequest : _ <- mapMaybe (\case TrChannelRequest d -> Just d; _ -> Nothing) headers + , peerRequest < dgst + -> return () -- Our request was rejected due to lower priority + + | otherwise -> logd $ "rejected by peer: " ++ show dgst + + DataRequest dgst + | secure || dgst `elem` plaintextRefs -> do + Right mref <- liftSTM $ unsafeIOToSTM $ + copyRef (peerStorage peer) $ + partialRefFromDigest (peerInStorage peer) dgst + addHeader $ DataResponse dgst + addAckedBy [ Acknowledged dgst, Rejected dgst ] + + -- Plaintext request may indicate the peer has restarted/changed or + -- otherwise lost the channel, so keep the reply plaintext as well. + when (not secure) keepPlaintextReply + + -- TODO: MTU + when (secure && BL.length (lazyLoadBytes mref) > 500) + sendBodyAsStream + + addBody $ mref + | otherwise -> do + logd $ "unauthorized data request for " ++ show dgst + addHeader $ Rejected dgst + + DataResponse dgst -> if + | Just pref <- find ((==dgst) . refDigest) prefs -> do + when (not secure) $ do + addHeader $ Acknowledged dgst + liftSTM $ writeTQueue (serverDataResponse server) (peer, Just pref) + + | streamNumber : _ <- lookupNewStreams headers -> do + streamReader <- acceptStream streamNumber + liftSTM $ writeTQueue (serverIOActions server) $ void $ liftIO $ forkIO $ do + (runExcept <$> readObjectsFromStream (peerInStorage peer) streamReader) >>= \case + Left err -> atomically $ writeTQueue (serverErrorLog server) $ + "failed to receive object from stream: " <> err + Right objs -> do + forM_ objs $ \obj -> do + pref <- storeObject (peerInStorage peer) obj + atomically $ writeTQueue (serverDataResponse server) (peer, Just pref) + + | otherwise -> throwError $ "mismatched data response " ++ show dgst + + AnnounceSelf dgst + | dgst == refDigest (storedRef sidentity) -> return () + | otherwise -> do + wref <- newWaitingRef dgst $ handleIdentityAnnounce identity peer + readTVarP (peerIdentityVar peer) >>= \case + PeerIdentityUnknown idwait -> do + addHeader $ AnnounceSelf $ refDigest $ storedRef $ idData identity + writeTVarP (peerIdentityVar peer) $ PeerIdentityRef wref idwait + liftSTM $ writeTChan (serverChanPeer $ peerServer peer) peer + _ -> return () + + AnnounceUpdate dgst -> do + readTVarP (peerIdentityVar peer) >>= \case + PeerIdentityFull _ -> do + void $ newWaitingRef dgst $ handleIdentityUpdate peer + _ -> return () + + TrChannelRequest dgst -> do + let process = do + addHeader $ Acknowledged dgst + wref <- newWaitingRef dgst $ handleChannelRequest peer identity + liftSTM $ setPeerChannel peer $ ChannelPeerRequest wref + reject = addHeader $ Rejected dgst + + liftSTM (getPeerChannel peer) >>= \case + ChannelNone {} -> return () + ChannelCookieWait {} -> return () + ChannelCookieReceived {} -> process + ChannelCookieConfirmed {} -> process + ChannelOurRequest our + | dgst < refDigest (storedRef our) -> process + | otherwise -> do + -- Reject peer channel request with lower priority + addHeader $ TrChannelRequest $ refDigest $ storedRef our + reject + ChannelPeerRequest prev + | dgst == wrDigest prev -> addHeader $ Acknowledged dgst + | otherwise -> process + ChannelOurAccept {} -> reject + ChannelEstablished {} -> process + ChannelClosed {} -> return () + + TrChannelAccept dgst -> do + let process = do + handleChannelAccept identity $ partialRefFromDigest (peerInStorage peer) dgst + reject = addHeader $ Rejected dgst + liftSTM (getPeerChannel peer) >>= \case + ChannelNone {} -> reject + ChannelCookieWait {} -> reject + ChannelCookieReceived {} -> reject + ChannelCookieConfirmed {} -> reject + ChannelOurRequest {} -> process + ChannelPeerRequest {} -> process + ChannelOurAccept our _ | dgst < refDigest (storedRef our) -> process + | otherwise -> addHeader $ Rejected dgst + ChannelEstablished {} -> process + ChannelClosed {} -> return () + + ServiceType _ -> return () + ServiceRef dgst + | not secure -> throwError $ "service packet without secure channel" + | Just svc <- lookupServiceType headers -> if + | svc `elem` svcs -> do + if dgst `elem` map refDigest prefs || True {- TODO: used by Message service to confirm receive -} + then do + void $ newWaitingRef dgst $ \ref -> + liftIO $ atomically $ writeTQueue chanSvc (peer, svc, ref) + else throwError $ "missing service object " ++ show dgst + | otherwise -> addHeader $ Rejected dgst + | otherwise -> throwError $ "service ref without type" + + _ -> return () + + +withPeerIdentity :: MonadIO m => Peer -> (UnifiedIdentity -> ExceptT String IO ()) -> m () +withPeerIdentity peer act = liftIO $ atomically $ readTVar (peerIdentityVar peer) >>= \case + PeerIdentityUnknown tvar -> modifyTVar' tvar (act:) + PeerIdentityRef _ tvar -> modifyTVar' tvar (act:) + PeerIdentityFull idt -> writeTQueue (serverIOActions $ peerServer peer) (act idt) + + +setupChannel :: UnifiedIdentity -> Peer -> UnifiedIdentity -> WaitingRefCallback +setupChannel identity peer upid = do + req <- flip runReaderT (peerStorage peer) $ createChannelRequest identity upid + let reqref = refDigest $ storedRef req + let hitems = + [ TrChannelRequest reqref + , AnnounceSelf $ refDigest $ storedRef $ idData identity + ] + let sendChannelRequest = do + sendToPeerPlain peer [ Acknowledged reqref, Rejected reqref ] $ + TransportPacket (TransportHeader hitems) [storedRef req] + setPeerChannel peer $ ChannelOurRequest req + liftIO $ atomically $ do + getPeerChannel peer >>= \case + ChannelCookieReceived -> sendChannelRequest + ChannelCookieConfirmed -> sendChannelRequest + _ -> return () + +handleChannelRequest :: Peer -> UnifiedIdentity -> Ref -> WaitingRefCallback +handleChannelRequest peer identity req = do + withPeerIdentity peer $ \upid -> do + (acc, ch) <- flip runReaderT (peerStorage peer) $ acceptChannelRequest identity upid (wrappedLoad req) + liftIO $ atomically $ do + getPeerChannel peer >>= \case + ChannelPeerRequest wr | wrDigest wr == refDigest req -> do + setPeerChannel peer $ ChannelOurAccept acc ch + let accref = refDigest $ storedRef acc + header = TrChannelAccept accref + ackedBy = [ Acknowledged accref, Rejected accref ] + sendToPeerPlain peer ackedBy $ TransportPacket (TransportHeader [header]) $ concat + [ [ storedRef $ acc ] + , [ storedRef $ signedData $ fromStored acc ] + , [ storedRef $ caKey $ fromStored $ signedData $ fromStored acc ] + , map storedRef $ signedSignature $ fromStored acc + ] + _ -> writeTQueue (serverErrorLog $ peerServer peer) $ "unexpected channel request" + +handleChannelAccept :: UnifiedIdentity -> PartialRef -> PacketHandler () +handleChannelAccept identity accref = do + peer <- gets phPeer + liftSTM $ writeTQueue (serverIOActions $ peerServer peer) $ do + withPeerIdentity peer $ \upid -> do + copyRef (peerStorage peer) accref >>= \case + Right acc -> do + ch <- acceptedChannel identity upid (wrappedLoad acc) + liftIO $ atomically $ do + sendToPeerS peer [] $ TransportPacket (TransportHeader [Acknowledged $ refDigest accref]) [] + finalizedChannel peer ch identity + + Left dgst -> throwError $ "missing accept data " ++ BC.unpack (showRefDigest dgst) + + +finalizedChannel :: Peer -> Channel -> UnifiedIdentity -> STM () +finalizedChannel peer@Peer {..} ch self = do + setPeerChannel peer $ ChannelEstablished ch + + -- Identity update + writeTQueue (serverIOActions peerServer_) $ liftIO $ atomically $ do + let selfRef = refDigest $ storedRef $ idExtData $ self + updateRefs = selfRef : map (refDigest . storedRef) (idUpdates self) + ackedBy = concat [[ Acknowledged r, Rejected r, DataRequest r ] | r <- updateRefs ] + sendToPeerS peer ackedBy $ flip TransportPacket [] $ TransportHeader $ map AnnounceUpdate updateRefs + + -- Notify services about new peer + readTVar peerIdentityVar >>= \case + PeerIdentityFull _ -> notifyServicesOfPeer peer + _ -> return () + + +handleIdentityAnnounce :: UnifiedIdentity -> Peer -> Ref -> WaitingRefCallback +handleIdentityAnnounce self peer ref = liftIO $ atomically $ do + let validateAndUpdate upds act = case validateIdentity $ wrappedLoad ref of + Just pid' -> do + let pid = fromMaybe pid' $ toUnifiedIdentity (updateIdentity upds pid') + writeTVar (peerIdentityVar peer) $ PeerIdentityFull pid + writeTChan (serverChanPeer $ peerServer peer) peer + act pid + writeTQueue (serverIOActions $ peerServer peer) $ do + setupChannel self peer pid + Nothing -> return () + + readTVar (peerIdentityVar peer) >>= \case + PeerIdentityRef wref wact + | wrDigest wref == refDigest ref + -> validateAndUpdate [] $ \pid -> do + mapM_ (writeTQueue (serverIOActions $ peerServer peer) . ($ pid)) . + reverse =<< readTVar wact + + PeerIdentityFull pid + | idData pid `precedes` wrappedLoad ref + -> validateAndUpdate (idUpdates pid) $ \_ -> do + notifyServicesOfPeer peer + + _ -> return () + +handleIdentityUpdate :: Peer -> Ref -> WaitingRefCallback +handleIdentityUpdate peer ref = liftIO $ atomically $ do + pidentity <- readTVar (peerIdentityVar peer) + if | PeerIdentityFull pid <- pidentity + , Just pid' <- toUnifiedIdentity $ updateIdentity [wrappedLoad ref] pid + -> do + writeTVar (peerIdentityVar peer) $ PeerIdentityFull pid' + writeTChan (serverChanPeer $ peerServer peer) peer + when (idData pid /= idData pid') $ notifyServicesOfPeer peer + + | otherwise -> return () + +notifyServicesOfPeer :: Peer -> STM () +notifyServicesOfPeer peer@Peer { peerServer_ = Server {..} } = do + writeTQueue serverIOActions $ do + forM_ serverServices $ \service@(SomeService _ attrs) -> + runPeerServiceOn (Just (service, attrs)) peer serviceNewPeer + + +mkPeer :: Server -> PeerAddress -> IO Peer +mkPeer peerServer_ peerAddress = do + peerState <- newTVarIO (PeerInit []) + peerIdentityVar <- newTVarIO . PeerIdentityUnknown =<< newTVarIO [] + peerStorage_ <- deriveEphemeralStorage $ serverStorage peerServer_ + peerInStorage <- derivePartialStorage peerStorage_ + peerServiceState <- newTMVarIO M.empty + peerWaitingRefs <- newTMVarIO [] + return Peer {..} + +serverPeer :: Server -> SockAddr -> IO Peer +serverPeer server paddr = do + let paddr' = case IP.fromSockAddr paddr of + Just (IP.IPv4 ipv4, port) + -> IP.toSockAddr (IP.IPv6 $ IP.toIPv6w (0, 0, 0xffff, IP.fromIPv4w ipv4), port) + _ -> paddr + serverPeer' server (DatagramAddress paddr') + +#ifdef ENABLE_ICE_SUPPORT +serverPeerIce :: Server -> IceSession -> IO Peer +serverPeerIce server@Server {..} ice = do + let paddr = PeerIceSession ice + peer <- serverPeer' server paddr + iceSetChan ice $ mapFlow undefined (paddr,) serverRawPath + return peer +#endif + +serverPeer' :: Server -> PeerAddress -> IO Peer +serverPeer' server paddr = do + (peer, hello) <- modifyMVar (serverPeers server) $ \pvalue -> do + case M.lookup paddr pvalue of + Just peer -> return (pvalue, (peer, False)) + Nothing -> do + peer <- mkPeer server paddr + return (M.insert paddr peer pvalue, (peer, True)) + when hello $ atomically $ do + writeFlow (serverControlFlow server) (RequestConnection paddr) + return peer + +dropPeer :: MonadIO m => Peer -> m () +dropPeer peer = liftIO $ do + modifyMVar_ (serverPeers $ peerServer peer) $ \pvalue -> do + atomically $ do + readTVar (peerState peer) >>= \case + PeerConnected conn -> connClose conn + _ -> return() + writeTVar (peerState peer) PeerDropped + return $ M.delete (peerAddress peer) pvalue + +isPeerDropped :: MonadIO m => Peer -> m Bool +isPeerDropped peer = liftIO $ atomically $ readTVar (peerState peer) >>= \case + PeerDropped -> return True + _ -> return False + +sendToPeer :: (Service s, MonadIO m) => Peer -> s -> m () +sendToPeer peer = sendManyToPeer peer . (: []) + +sendManyToPeer :: (Service s, MonadIO m) => Peer -> [ s ] -> m () +sendManyToPeer peer = sendToPeerList peer . map (\part -> ServiceReply (Left part) True) + +sendToPeerStored :: (Service s, MonadIO m) => Peer -> Stored s -> m () +sendToPeerStored peer = sendManyToPeerStored peer . (: []) + +sendManyToPeerStored :: (Service s, MonadIO m) => Peer -> [ Stored s ] -> m () +sendManyToPeerStored peer = sendToPeerList peer . map (\part -> ServiceReply (Right part) True) + +sendToPeerList :: (Service s, MonadIO m) => Peer -> [ServiceReply s] -> m () +sendToPeerList peer parts = do + let st = peerStorage peer + srefs <- liftIO $ fmap catMaybes $ forM parts $ \case + ServiceReply (Left x) use -> Just . (,use) <$> store st x + ServiceReply (Right sx) use -> return $ Just (storedRef sx, use) + ServiceFinally act -> act >> return Nothing + let dgsts = map (refDigest . fst) srefs + let content = map fst $ filter (\(ref, use) -> use && BL.length (lazyLoadBytes ref) < 500) srefs -- TODO: MTU + header = TransportHeader (ServiceType (serviceID $ head parts) : map ServiceRef dgsts) + packet = TransportPacket header content + ackedBy = concat [[ Acknowledged r, Rejected r, DataRequest r ] | r <- dgsts ] + liftIO $ atomically $ sendToPeerS peer ackedBy packet + +sendToPeerS' :: SecurityRequirement -> Peer -> [TransportHeaderItem] -> TransportPacket Ref -> STM () +sendToPeerS' secure Peer {..} ackedBy packet = do + readTVar peerState >>= \case + PeerInit xs -> writeTVar peerState $ PeerInit $ (secure, packet, ackedBy) : xs + PeerConnected conn -> writeFlow (connData conn) (secure, packet, ackedBy) + PeerDropped -> return () + +sendToPeerS :: Peer -> [TransportHeaderItem] -> TransportPacket Ref -> STM () +sendToPeerS = sendToPeerS' EncryptedOnly + +sendToPeerPlain :: Peer -> [TransportHeaderItem] -> TransportPacket Ref -> STM () +sendToPeerPlain = sendToPeerS' PlaintextAllowed + +sendToPeerWith :: forall s m. (Service s, MonadIO m, MonadError String m) => Peer -> (ServiceState s -> ExceptT String IO (Maybe s, ServiceState s)) -> m () +sendToPeerWith peer fobj = do + let sproxy = Proxy @s + sid = serviceID sproxy + res <- liftIO $ do + svcs <- atomically $ takeTMVar (peerServiceState peer) + (svcs', res) <- runExceptT (fobj $ fromMaybe (emptyServiceState sproxy) $ fromServiceState sproxy =<< M.lookup sid svcs) >>= \case + Right (obj, s') -> return $ (M.insert sid (SomeServiceState sproxy s') svcs, Right obj) + Left err -> return $ (svcs, Left err) + atomically $ putTMVar (peerServiceState peer) svcs' + return res + + case res of + Right (Just obj) -> sendToPeer peer obj + Right Nothing -> return () + Left err -> throwError err + + +lookupService :: forall s. Service s => Proxy s -> [SomeService] -> Maybe (SomeService, ServiceAttributes s) +lookupService proxy (service@(SomeService (_ :: Proxy t) attr) : rest) + | Just (Refl :: s :~: t) <- eqT = Just (service, attr) + | otherwise = lookupService proxy rest +lookupService _ [] = Nothing + +runPeerService :: forall s m. (Service s, MonadIO m) => Peer -> ServiceHandler s () -> m () +runPeerService = runPeerServiceOn Nothing + +runPeerServiceOn :: forall s m. (Service s, MonadIO m) => Maybe (SomeService, ServiceAttributes s) -> Peer -> ServiceHandler s () -> m () +runPeerServiceOn mbservice peer handler = liftIO $ do + let server = peerServer peer + proxy = Proxy @s + svc = serviceID proxy + logd = writeTQueue (serverErrorLog server) + case mbservice `mplus` lookupService proxy (serverServices server) of + Just (service, attr) -> + atomically (readTVar (peerIdentityVar peer)) >>= \case + PeerIdentityFull peerId -> do + (global, svcs) <- atomically $ (,) + <$> takeTMVar (serverServiceStates server) + <*> takeTMVar (peerServiceState peer) + case (fromMaybe (someServiceEmptyState service) $ M.lookup svc svcs, + fromMaybe (someServiceEmptyGlobalState service) $ M.lookup svc global) of + ((SomeServiceState (_ :: Proxy ps) ps), + (SomeServiceGlobalState (_ :: Proxy gs) gs)) -> do + Just (Refl :: s :~: ps) <- return $ eqT + Just (Refl :: s :~: gs) <- return $ eqT + + let inp = ServiceInput + { svcAttributes = attr + , svcPeer = peer + , svcPeerIdentity = peerId + , svcServer = server + , svcPrintOp = atomically . logd + } + reloadHead (serverOrigHead server) >>= \case + Nothing -> atomically $ do + logd $ "current head deleted" + putTMVar (peerServiceState peer) svcs + putTMVar (serverServiceStates server) global + Just h -> do + (rsp, (s', gs')) <- runServiceHandler h inp ps gs handler + moveKeys (peerStorage peer) (serverStorage server) + when (not (null rsp)) $ do + sendToPeerList peer rsp + atomically $ do + putTMVar (peerServiceState peer) $ M.insert svc (SomeServiceState proxy s') svcs + putTMVar (serverServiceStates server) $ M.insert svc (SomeServiceGlobalState proxy gs') global + _ -> do + atomically $ logd $ "can't run service handler on peer with incomplete identity " ++ show (peerAddress peer) + + _ -> atomically $ do + logd $ "unhandled service '" ++ show (toUUID svc) ++ "'" + + +foreign import ccall unsafe "Network/ifaddrs.h join_multicast" cJoinMulticast :: CInt -> Ptr CSize -> IO (Ptr Word32) +foreign import ccall unsafe "Network/ifaddrs.h broadcast_addresses" cBroadcastAddresses :: IO (Ptr Word32) +foreign import ccall unsafe "stdlib.h free" cFree :: Ptr Word32 -> IO () + +joinMulticast :: Socket -> IO [ Word32 ] +joinMulticast sock = + withFdSocket sock $ \fd -> + alloca $ \pcount -> do + ptr <- cJoinMulticast fd pcount + count <- fromIntegral <$> peek pcount + forM [ 0 .. count - 1 ] $ \i -> + peekElemOff ptr i + +getBroadcastAddresses :: PortNumber -> IO [SockAddr] +getBroadcastAddresses port = do + ptr <- cBroadcastAddresses + let parse i = do + w <- peekElemOff ptr i + if w == 0 then return [] + else (SockAddrInet port w:) <$> parse (i + 1) + if ptr == nullPtr + then return [] + else do + addrs <- parse 0 + cFree ptr + return addrs diff --git a/src/Erebos/Network.hs-boot b/src/Erebos/Network.hs-boot new file mode 100644 index 0000000..849bfc1 --- /dev/null +++ b/src/Erebos/Network.hs-boot @@ -0,0 +1,8 @@ +module Erebos.Network where + +import Erebos.Storage + +data Server +data Peer + +peerStorage :: Peer -> Storage diff --git a/src/Erebos/Network/Protocol.hs b/src/Erebos/Network/Protocol.hs new file mode 100644 index 0000000..ded0b05 --- /dev/null +++ b/src/Erebos/Network/Protocol.hs @@ -0,0 +1,914 @@ +module Erebos.Network.Protocol ( + TransportPacket(..), + transportToObject, + TransportHeader(..), + TransportHeaderItem(..), + SecurityRequirement(..), + + WaitingRef(..), + WaitingRefCallback, + wrDigest, + + ChannelState(..), + + ControlRequest(..), + ControlMessage(..), + erebosNetworkProtocol, + + Connection, + connAddress, + connData, + connGetChannel, + connSetChannel, + connClose, + + RawStreamReader, RawStreamWriter, + connAddWriteStream, + connAddReadStream, + readStreamToList, + readObjectsFromStream, + writeByteStringToStream, + + module Erebos.Flow, +) where + +import Control.Applicative +import Control.Concurrent +import Control.Concurrent.Async +import Control.Concurrent.STM +import Control.Monad +import Control.Monad.Except +import Control.Monad.Trans + +import Data.Bits +import Data.ByteString (ByteString) +import Data.ByteString qualified as B +import Data.ByteString.Char8 qualified as BC +import Data.ByteString.Lazy qualified as BL +import Data.Function +import Data.List +import Data.Maybe +import Data.Text (Text) +import Data.Text qualified as T +import Data.Void +import Data.Word + +import System.Clock + +import Erebos.Channel +import Erebos.Flow +import Erebos.Identity +import Erebos.Service +import Erebos.Storage + + +protocolVersion :: Text +protocolVersion = T.pack "0.1" + +protocolVersions :: [Text] +protocolVersions = [protocolVersion] + +keepAliveInternal :: TimeSpec +keepAliveInternal = fromNanoSecs $ 30 * 10^(9 :: Int) + + +data TransportPacket a = TransportPacket TransportHeader [a] + +data TransportHeader = TransportHeader [TransportHeaderItem] + deriving (Show) + +data TransportHeaderItem + = Acknowledged RefDigest + | AcknowledgedSingle Integer + | Rejected RefDigest + | ProtocolVersion Text + | Initiation RefDigest + | CookieSet Cookie + | CookieEcho Cookie + | DataRequest RefDigest + | DataResponse RefDigest + | AnnounceSelf RefDigest + | AnnounceUpdate RefDigest + | TrChannelRequest RefDigest + | TrChannelAccept RefDigest + | ServiceType ServiceID + | ServiceRef RefDigest + | StreamOpen Word8 + deriving (Eq, Show) + +newtype Cookie = Cookie ByteString + deriving (Eq, Show) + +data SecurityRequirement = PlaintextOnly + | PlaintextAllowed + | EncryptedOnly + deriving (Eq, Ord) + +isHeaderItemAcknowledged :: TransportHeaderItem -> Bool +isHeaderItemAcknowledged = \case + Acknowledged {} -> False + AcknowledgedSingle {} -> False + Rejected {} -> False + ProtocolVersion {} -> False + Initiation {} -> False + CookieSet {} -> False + CookieEcho {} -> False + _ -> True + +transportToObject :: PartialStorage -> TransportHeader -> PartialObject +transportToObject st (TransportHeader items) = Rec $ map single items + where single = \case + Acknowledged dgst -> (BC.pack "ACK", RecRef $ partialRefFromDigest st dgst) + AcknowledgedSingle num -> (BC.pack "ACK", RecInt num) + Rejected dgst -> (BC.pack "REJ", RecRef $ partialRefFromDigest st dgst) + ProtocolVersion ver -> (BC.pack "VER", RecText ver) + Initiation dgst -> (BC.pack "INI", RecRef $ partialRefFromDigest st dgst) + CookieSet (Cookie bytes) -> (BC.pack "CKS", RecBinary bytes) + CookieEcho (Cookie bytes) -> (BC.pack "CKE", RecBinary bytes) + DataRequest dgst -> (BC.pack "REQ", RecRef $ partialRefFromDigest st dgst) + DataResponse dgst -> (BC.pack "RSP", RecRef $ partialRefFromDigest st dgst) + AnnounceSelf dgst -> (BC.pack "ANN", RecRef $ partialRefFromDigest st dgst) + AnnounceUpdate dgst -> (BC.pack "ANU", RecRef $ partialRefFromDigest st dgst) + TrChannelRequest dgst -> (BC.pack "CRQ", RecRef $ partialRefFromDigest st dgst) + TrChannelAccept dgst -> (BC.pack "CAC", RecRef $ partialRefFromDigest st dgst) + ServiceType stype -> (BC.pack "SVT", RecUUID $ toUUID stype) + ServiceRef dgst -> (BC.pack "SVR", RecRef $ partialRefFromDigest st dgst) + StreamOpen num -> (BC.pack "STO", RecInt $ fromIntegral num) + +transportFromObject :: PartialObject -> Maybe TransportHeader +transportFromObject (Rec items) = case catMaybes $ map single items of + [] -> Nothing + titems -> Just $ TransportHeader titems + where single (name, content) = if + | name == BC.pack "ACK", RecRef ref <- content -> Just $ Acknowledged $ refDigest ref + | name == BC.pack "ACK", RecInt num <- content -> Just $ AcknowledgedSingle num + | name == BC.pack "REJ", RecRef ref <- content -> Just $ Rejected $ refDigest ref + | name == BC.pack "VER", RecText ver <- content -> Just $ ProtocolVersion ver + | name == BC.pack "INI", RecRef ref <- content -> Just $ Initiation $ refDigest ref + | name == BC.pack "CKS", RecBinary bytes <- content -> Just $ CookieSet (Cookie bytes) + | name == BC.pack "CKE", RecBinary bytes <- content -> Just $ CookieEcho (Cookie bytes) + | name == BC.pack "REQ", RecRef ref <- content -> Just $ DataRequest $ refDigest ref + | name == BC.pack "RSP", RecRef ref <- content -> Just $ DataResponse $ refDigest ref + | name == BC.pack "ANN", RecRef ref <- content -> Just $ AnnounceSelf $ refDigest ref + | name == BC.pack "ANU", RecRef ref <- content -> Just $ AnnounceUpdate $ refDigest ref + | name == BC.pack "CRQ", RecRef ref <- content -> Just $ TrChannelRequest $ refDigest ref + | name == BC.pack "CAC", RecRef ref <- content -> Just $ TrChannelAccept $ refDigest ref + | name == BC.pack "SVT", RecUUID uuid <- content -> Just $ ServiceType $ fromUUID uuid + | name == BC.pack "SVR", RecRef ref <- content -> Just $ ServiceRef $ refDigest ref + | name == BC.pack "STO", RecInt num <- content -> Just $ StreamOpen $ fromIntegral num + | otherwise -> Nothing +transportFromObject _ = Nothing + + +data GlobalState addr = (Eq addr, Show addr) => GlobalState + { gIdentity :: TVar (UnifiedIdentity, [UnifiedIdentity]) + , gConnections :: TVar [Connection addr] + , gDataFlow :: SymFlow (addr, ByteString) + , gControlFlow :: Flow (ControlRequest addr) (ControlMessage addr) + , gNextUp :: TMVar (Connection addr, (Bool, TransportPacket PartialObject)) + , gLog :: String -> STM () + , gStorage :: PartialStorage + , gNowVar :: TVar TimeSpec + , gNextTimeout :: TVar TimeSpec + , gInitConfig :: Ref + } + +data Connection addr = Connection + { cGlobalState :: GlobalState addr + , cAddress :: addr + , cDataUp :: Flow + (Maybe (Bool, TransportPacket PartialObject)) + (SecurityRequirement, TransportPacket Ref, [TransportHeaderItem]) + , cDataInternal :: Flow + (SecurityRequirement, TransportPacket Ref, [TransportHeaderItem]) + (Maybe (Bool, TransportPacket PartialObject)) + , cChannel :: TVar ChannelState + , cCookie :: TVar (Maybe Cookie) + , cSecureOutQueue :: TQueue (SecurityRequirement, TransportPacket Ref, [TransportHeaderItem]) + , cMaxInFlightPackets :: TVar Int + , cReservedPackets :: TVar Int + , cSentPackets :: TVar [SentPacket] + , cToAcknowledge :: TVar [Integer] + , cNextKeepAlive :: TVar (Maybe TimeSpec) + , cInStreams :: TVar [(Word8, Stream)] + , cOutStreams :: TVar [(Word8, Stream)] + } + +instance Eq (Connection addr) where + (==) = (==) `on` cChannel + +connAddress :: Connection addr -> addr +connAddress = cAddress + +connData :: Connection addr -> Flow + (Maybe (Bool, TransportPacket PartialObject)) + (SecurityRequirement, TransportPacket Ref, [TransportHeaderItem]) +connData = cDataUp + +connGetChannel :: Connection addr -> STM ChannelState +connGetChannel Connection {..} = readTVar cChannel + +connSetChannel :: Connection addr -> ChannelState -> STM () +connSetChannel Connection {..} ch = do + writeTVar cChannel ch + +connClose :: Connection addr -> STM () +connClose conn@Connection {..} = do + let GlobalState {..} = cGlobalState + readTVar cChannel >>= \case + ChannelClosed -> return () + _ -> do + writeTVar cChannel ChannelClosed + writeTVar gConnections . filter (/=conn) =<< readTVar gConnections + writeFlow cDataInternal Nothing + +connAddWriteStream :: Connection addr -> STM (Either String (TransportHeaderItem, RawStreamWriter, IO ())) +connAddWriteStream conn@Connection {..} = do + outStreams <- readTVar cOutStreams + let doInsert :: Word8 -> [(Word8, Stream)] -> ExceptT String STM ((Word8, Stream), [(Word8, Stream)]) + doInsert n (s@(n', _) : rest) | n == n' = + fmap (s:) <$> doInsert (n + 1) rest + doInsert n streams | n < 63 = lift $ do + sState <- newTVar StreamOpening + (sFlowIn, sFlowOut) <- newFlow + sNextSequence <- newTVar 0 + sWaitingForAck <- newTVar 0 + let info = (n, Stream {..}) + return (info, info : streams) + doInsert _ _ = throwError "all outbound streams in use" + + runExceptT $ do + ((streamNumber, stream), outStreams') <- doInsert 1 outStreams + lift $ writeTVar cOutStreams outStreams' + return (StreamOpen streamNumber, sFlowIn stream, go cGlobalState streamNumber stream) + + where + go gs@GlobalState {..} streamNumber stream = do + (reserved, msg) <- atomically $ do + readTVar (sState stream) >>= \case + StreamRunning -> return () + _ -> retry + (,) <$> reservePacket conn + <*> readFlow (sFlowOut stream) + + (plain, cont, onAck) <- case msg of + StreamData {..} -> do + return (stpData, True, return ()) + StreamClosed {} -> do + atomically $ do + -- wait for ack on all sent stream data + waits <- readTVar (sWaitingForAck stream) + when (waits > 0) retry + return (BC.empty, False, streamClosed conn streamNumber) + + let secure = True + plainAckedBy = [] + mbReserved = Just reserved + + mbch <- atomically (readTVar cChannel) >>= return . \case + ChannelEstablished ch -> Just ch + ChannelOurAccept _ ch -> Just ch + _ -> Nothing + + mbs <- case mbch of + Just ch -> do + runExceptT (channelEncrypt ch $ B.concat + [ B.singleton streamNumber + , B.singleton (fromIntegral (stpSequence msg) :: Word8) + , plain + ] ) >>= \case + Right (ctext, counter) -> do + let isAcked = True + return $ Just (0x80 `B.cons` ctext, if isAcked then [ AcknowledgedSingle $ fromIntegral counter ] else []) + Left err -> do atomically $ gLog $ "Failed to encrypt data: " ++ err + return Nothing + Nothing | secure -> return Nothing + | otherwise -> return $ Just (plain, plainAckedBy) + + case mbs of + Just (bs, ackedBy) -> do + atomically $ do + modifyTVar' (sWaitingForAck stream) (+ 1) + let mbReserved' = (\rs -> rs + { rsAckedBy = guard (not $ null ackedBy) >> Just (`elem` ackedBy) + , rsOnAck = do + rsOnAck rs + onAck + atomically $ modifyTVar' (sWaitingForAck stream) (subtract 1) + }) <$> mbReserved + sendBytes conn mbReserved' bs + Nothing -> return () + + when cont $ go gs streamNumber stream + +connAddReadStream :: Connection addr -> Word8 -> STM RawStreamReader +connAddReadStream Connection {..} streamNumber = do + inStreams <- readTVar cInStreams + let doInsert (s@(n, _) : rest) + | streamNumber < n = fmap (s:) <$> doInsert rest + | streamNumber == n = doInsert rest + doInsert streams = do + sState <- newTVar StreamRunning + (sFlowIn, sFlowOut) <- newFlow + sNextSequence <- newTVar 0 + sWaitingForAck <- newTVar 0 + let stream = Stream {..} + return (stream, (streamNumber, stream) : streams) + (stream, inStreams') <- doInsert inStreams + writeTVar cInStreams inStreams' + return $ sFlowOut stream + + +type RawStreamReader = Flow StreamPacket Void +type RawStreamWriter = Flow Void StreamPacket + +data Stream = Stream + { sState :: TVar StreamState + , sFlowIn :: Flow Void StreamPacket + , sFlowOut :: Flow StreamPacket Void + , sNextSequence :: TVar Word64 + , sWaitingForAck :: TVar Word64 + } + +data StreamState = StreamOpening | StreamRunning + +data StreamPacket + = StreamData + { stpSequence :: Word64 + , stpData :: BC.ByteString + } + | StreamClosed + { stpSequence :: Word64 + } + +streamAccepted :: Connection addr -> Word8 -> IO () +streamAccepted Connection {..} snum = atomically $ do + (lookup snum <$> readTVar cOutStreams) >>= \case + Just Stream {..} -> do + modifyTVar' sState $ \case + StreamOpening -> StreamRunning + x -> x + Nothing -> return () + +streamClosed :: Connection addr -> Word8 -> IO () +streamClosed Connection {..} snum = atomically $ do + modifyTVar' cOutStreams $ filter ((snum /=) . fst) + +readStreamToList :: RawStreamReader -> IO (Word64, [(Word64, BC.ByteString)]) +readStreamToList stream = readFlowIO stream >>= \case + StreamData sq bytes -> fmap ((sq, bytes) :) <$> readStreamToList stream + StreamClosed sqEnd -> return (sqEnd, []) + +readObjectsFromStream :: PartialStorage -> RawStreamReader -> IO (Except String [PartialObject]) +readObjectsFromStream st stream = do + (seqEnd, list) <- readStreamToList stream + let validate s ((s', bytes) : rest) + | s == s' = (bytes : ) <$> validate (s + 1) rest + | s > s' = validate s rest + | otherwise = throwError "missing object chunk" + validate s [] + | s == seqEnd = return [] + | otherwise = throwError "content length mismatch" + return $ do + content <- BL.fromChunks <$> validate 0 list + deserializeObjects st content + +writeByteStringToStream :: RawStreamWriter -> BL.ByteString -> IO () +writeByteStringToStream stream = go 0 + where + go seqNum bstr + | BL.null bstr = writeFlowIO stream $ StreamClosed seqNum + | otherwise = do + let (cur, rest) = BL.splitAt 500 bstr -- TODO: MTU + writeFlowIO stream $ StreamData seqNum (BL.toStrict cur) + go (seqNum + 1) rest + + +data WaitingRef = WaitingRef + { wrefStorage :: Storage + , wrefPartial :: PartialRef + , wrefAction :: Ref -> WaitingRefCallback + , wrefStatus :: TVar (Either [RefDigest] Ref) + } + +type WaitingRefCallback = ExceptT String IO () + +wrDigest :: WaitingRef -> RefDigest +wrDigest = refDigest . wrefPartial + + +data ChannelState = ChannelNone + | ChannelCookieWait -- sent initiation, waiting for response + | ChannelCookieReceived -- received cookie, but no cookie echo yet + | ChannelCookieConfirmed -- received cookie echo, no need to send from our side + | ChannelOurRequest (Stored ChannelRequest) + | ChannelPeerRequest WaitingRef + | ChannelOurAccept (Stored ChannelAccept) Channel + | ChannelEstablished Channel + | ChannelClosed + +data ReservedToSend = ReservedToSend + { rsAckedBy :: Maybe (TransportHeaderItem -> Bool) + , rsOnAck :: IO () + , rsOnFail :: IO () + } + +data SentPacket = SentPacket + { spTime :: TimeSpec + , spRetryCount :: Int + , spAckedBy :: Maybe (TransportHeaderItem -> Bool) + , spOnAck :: IO () + , spOnFail :: IO () + , spData :: BC.ByteString + } + + +data ControlRequest addr = RequestConnection addr + | SendAnnounce addr + | UpdateSelfIdentity UnifiedIdentity + +data ControlMessage addr = NewConnection (Connection addr) (Maybe RefDigest) + | ReceivedAnnounce addr RefDigest + + +erebosNetworkProtocol :: (Eq addr, Ord addr, Show addr) + => UnifiedIdentity + -> (String -> STM ()) + -> SymFlow (addr, ByteString) + -> Flow (ControlRequest addr) (ControlMessage addr) + -> IO () +erebosNetworkProtocol initialIdentity gLog gDataFlow gControlFlow = do + gIdentity <- newTVarIO (initialIdentity, []) + gConnections <- newTVarIO [] + gNextUp <- newEmptyTMVarIO + mStorage <- memoryStorage + gStorage <- derivePartialStorage mStorage + + startTime <- getTime Monotonic + gNowVar <- newTVarIO startTime + gNextTimeout <- newTVarIO startTime + gInitConfig <- store mStorage $ (Rec [] :: Object) + + let gs = GlobalState {..} + + let signalTimeouts = forever $ do + now <- getTime Monotonic + next <- atomically $ do + writeTVar gNowVar now + readTVar gNextTimeout + + let waitTill time + | time > now = threadDelay $ fromInteger (toNanoSecs (time - now)) `div` 1000 + | otherwise = threadDelay maxBound + waitForUpdate = atomically $ do + next' <- readTVar gNextTimeout + when (next' == next) retry + + race_ (waitTill next) waitForUpdate + + race_ signalTimeouts $ forever $ join $ atomically $ + passUpIncoming gs <|> processIncoming gs <|> processOutgoing gs + + +getConnection :: GlobalState addr -> addr -> STM (Connection addr) +getConnection gs addr = do + maybe (newConnection gs addr) return =<< findConnection gs addr + +findConnection :: GlobalState addr -> addr -> STM (Maybe (Connection addr)) +findConnection GlobalState {..} addr = do + find ((addr==) . cAddress) <$> readTVar gConnections + +newConnection :: GlobalState addr -> addr -> STM (Connection addr) +newConnection cGlobalState@GlobalState {..} addr = do + conns <- readTVar gConnections + + let cAddress = addr + (cDataUp, cDataInternal) <- newFlow + cChannel <- newTVar ChannelNone + cCookie <- newTVar Nothing + cSecureOutQueue <- newTQueue + cMaxInFlightPackets <- newTVar 4 + cReservedPackets <- newTVar 0 + cSentPackets <- newTVar [] + cToAcknowledge <- newTVar [] + cNextKeepAlive <- newTVar Nothing + cInStreams <- newTVar [] + cOutStreams <- newTVar [] + let conn = Connection {..} + + writeTVar gConnections (conn : conns) + return conn + +passUpIncoming :: GlobalState addr -> STM (IO ()) +passUpIncoming GlobalState {..} = do + (Connection {..}, up) <- takeTMVar gNextUp + writeFlow cDataInternal (Just up) + return $ return () + +processIncoming :: GlobalState addr -> STM (IO ()) +processIncoming gs@GlobalState {..} = do + guard =<< isEmptyTMVar gNextUp + guard =<< canWriteFlow gControlFlow + + (addr, msg) <- readFlow gDataFlow + mbconn <- findConnection gs addr + + mbch <- case mbconn of + Nothing -> return Nothing + Just conn -> readTVar (cChannel conn) >>= return . \case + ChannelEstablished ch -> Just ch + ChannelOurAccept _ ch -> Just ch + _ -> Nothing + + return $ do + let deserialize = liftEither . runExcept . deserializeObjects gStorage . BL.fromStrict + let parse = case B.uncons msg of + Just (b, enc) + | b .&. 0xE0 == 0x80 -> do + ch <- maybe (throwError "unexpected encrypted packet") return mbch + (dec, counter) <- channelDecrypt ch enc + + case B.uncons dec of + Just (0x00, content) -> do + objs <- deserialize content + return $ Left (True, objs, Just counter) + + Just (snum, dec') + | snum < 64 + , Just (seq8, content) <- B.uncons dec' + -> do + return $ Right (snum, seq8, content, counter) + + Just (_, _) -> do + throwError "unexpected stream header" + + Nothing -> do + throwError "empty decrypted content" + + | b .&. 0xE0 == 0x60 -> do + objs <- deserialize msg + return $ Left (False, objs, Nothing) + + | otherwise -> throwError "invalid packet" + + Nothing -> throwError "empty packet" + + now <- getTime Monotonic + runExceptT parse >>= \case + Right (Left (secure, objs, mbcounter)) + | hobj:content <- objs + , Just header@(TransportHeader items) <- transportFromObject hobj + -> processPacket gs (maybe (Left addr) Right mbconn) secure (TransportPacket header content) >>= \case + Just (conn@Connection {..}, mbup) -> do + ioAfter <- atomically $ do + case mbcounter of + Just counter | any isHeaderItemAcknowledged items -> + modifyTVar' cToAcknowledge (fromIntegral counter :) + _ -> return () + case mbup of + Just up -> putTMVar gNextUp (conn, (secure, up)) + Nothing -> return () + updateKeepAlive conn now + processAcknowledgements gs conn items + ioAfter + Nothing -> return () + + | otherwise -> atomically $ do + gLog $ show addr ++ ": invalid objects" + gLog $ show objs + + Right (Right (snum, seq8, content, counter)) + | Just conn@Connection {..} <- mbconn + -> atomically $ do + updateKeepAlive conn now + (lookup snum <$> readTVar cInStreams) >>= \case + Nothing -> + gLog $ "unexpected stream number " ++ show snum + + Just Stream {..} -> do + expectedSequence <- readTVar sNextSequence + let seqFull = expectedSequence - 0x80 + fromIntegral (seq8 - fromIntegral expectedSequence + 0x80 :: Word8) + sdata <- if + | B.null content -> do + modifyTVar' cInStreams $ filter ((/=snum) . fst) + return $ StreamClosed seqFull + | otherwise -> do + writeTVar sNextSequence $ max expectedSequence (seqFull + 1) + return $ StreamData seqFull content + writeFlow sFlowIn sdata + modifyTVar' cToAcknowledge (fromIntegral counter :) + + | otherwise -> do + atomically $ gLog $ show addr <> ": stream packet without connection" + + Left err -> do + atomically $ gLog $ show addr <> ": failed to parse packet: " <> err + +processPacket :: GlobalState addr -> Either addr (Connection addr) -> Bool -> TransportPacket a -> IO (Maybe (Connection addr, Maybe (TransportPacket a))) +processPacket gs@GlobalState {..} econn secure packet@(TransportPacket (TransportHeader header) _) = if + -- Established secure communication + | Right conn <- econn, secure + -> return $ Just (conn, Just packet) + + -- Plaintext communication with cookies to prove origin + | cookie:_ <- mapMaybe (\case CookieEcho x -> Just x; _ -> Nothing) header + -> verifyCookie gs addr cookie >>= \case + True -> do + atomically $ do + conn@Connection {..} <- getConnection gs addr + oldCookie <- readTVar cCookie + let received = listToMaybe $ mapMaybe (\case CookieSet x -> Just x; _ -> Nothing) header + case received `mplus` oldCookie of + Just current -> do + writeTVar cCookie (Just current) + cookieEchoReceived gs conn mbpid + return $ Just (conn, Just packet) + Nothing -> do + gLog $ show addr <> ": missing cookie set, dropping " <> show header + return $ Nothing + + False -> do + atomically $ gLog $ show addr <> ": cookie verification failed, dropping " <> show header + return Nothing + + -- Response to initiation packet + | cookie:_ <- mapMaybe (\case CookieSet x -> Just x; _ -> Nothing) header + , Just _ <- version + , Right conn@Connection {..} <- econn + -> do + atomically $ readTVar cChannel >>= \case + ChannelCookieWait -> do + writeTVar cChannel $ ChannelCookieReceived + writeTVar cCookie $ Just cookie + writeFlow gControlFlow (NewConnection conn mbpid) + return $ Just (conn, Nothing) + _ -> return Nothing + + -- Initiation packet + | _:_ <- mapMaybe (\case Initiation x -> Just x; _ -> Nothing) header + , Just ver <- version + -> do + cookie <- createCookie gs addr + atomically $ do + identity <- fst <$> readTVar gIdentity + let reply = BL.toStrict $ serializeObject $ transportToObject gStorage $ TransportHeader + [ CookieSet cookie + , AnnounceSelf $ refDigest $ storedRef $ idData identity + , ProtocolVersion ver + ] + writeFlow gDataFlow (addr, reply) + return Nothing + + -- Announce packet outside any connection + | dgst:_ <- mapMaybe (\case AnnounceSelf x -> Just x; _ -> Nothing) header + , Just _ <- version + -> do + atomically $ do + (cur, past) <- readTVar gIdentity + when (not $ dgst `elem` map (refDigest . storedRef . idData) (cur : past)) $ do + writeFlow gControlFlow $ ReceivedAnnounce addr dgst + return Nothing + + | otherwise -> do + atomically $ gLog $ show addr <> ": dropping packet " <> show header + return Nothing + + where + addr = either id cAddress econn + mbpid = listToMaybe $ mapMaybe (\case AnnounceSelf dgst -> Just dgst; _ -> Nothing) header + version = listToMaybe $ filter (\v -> ProtocolVersion v `elem` header) protocolVersions + +cookieEchoReceived :: GlobalState addr -> Connection addr -> Maybe RefDigest -> STM () +cookieEchoReceived GlobalState {..} conn@Connection {..} mbpid = do + readTVar cChannel >>= \case + ChannelNone -> newConn + ChannelCookieWait -> newConn + ChannelCookieReceived {} -> update + _ -> return () + where + update = do + writeTVar cChannel ChannelCookieConfirmed + newConn = do + update + writeFlow gControlFlow (NewConnection conn mbpid) + +generateCookieHeaders :: Connection addr -> ChannelState -> IO [TransportHeaderItem] +generateCookieHeaders Connection {..} ch = catMaybes <$> sequence [ echoHeader, setHeader ] + where + echoHeader = fmap CookieEcho <$> atomically (readTVar cCookie) + setHeader = case ch of + ChannelCookieWait {} -> Just . CookieSet <$> createCookie cGlobalState cAddress + ChannelCookieReceived {} -> Just . CookieSet <$> createCookie cGlobalState cAddress + _ -> return Nothing + +createCookie :: GlobalState addr -> addr -> IO Cookie +createCookie GlobalState {} addr = return (Cookie $ BC.pack $ show addr) + +verifyCookie :: GlobalState addr -> addr -> Cookie -> IO Bool +verifyCookie GlobalState {} addr (Cookie cookie) = return $ show addr == BC.unpack cookie + + +reservePacket :: Connection addr -> STM ReservedToSend +reservePacket conn@Connection {..} = do + maxPackets <- readTVar cMaxInFlightPackets + reserved <- readTVar cReservedPackets + sent <- length <$> readTVar cSentPackets + + when (sent + reserved >= maxPackets) $ do + retry + + writeTVar cReservedPackets $ reserved + 1 + return $ ReservedToSend Nothing (return ()) (atomically $ connClose conn) + +resendBytes :: Connection addr -> Maybe ReservedToSend -> SentPacket -> IO () +resendBytes conn@Connection {..} reserved sp = do + let GlobalState {..} = cGlobalState + now <- getTime Monotonic + atomically $ do + when (isJust reserved) $ do + modifyTVar' cReservedPackets (subtract 1) + + when (isJust $ spAckedBy sp) $ do + modifyTVar' cSentPackets $ (:) sp + { spTime = now + , spRetryCount = spRetryCount sp + 1 + } + writeFlow gDataFlow (cAddress, spData sp) + updateKeepAlive conn now + +sendBytes :: Connection addr -> Maybe ReservedToSend -> ByteString -> IO () +sendBytes conn reserved bs = resendBytes conn reserved + SentPacket + { spTime = undefined + , spRetryCount = -1 + , spAckedBy = rsAckedBy =<< reserved + , spOnAck = maybe (return ()) rsOnAck reserved + , spOnFail = maybe (return ()) rsOnFail reserved + , spData = bs + } + +updateKeepAlive :: Connection addr -> TimeSpec -> STM () +updateKeepAlive Connection {..} now = do + let next = now + keepAliveInternal + writeTVar cNextKeepAlive $ Just next + + +processOutgoing :: forall addr. GlobalState addr -> STM (IO ()) +processOutgoing gs@GlobalState {..} = do + + let sendNextPacket :: Connection addr -> STM (IO ()) + sendNextPacket conn@Connection {..} = do + channel <- readTVar cChannel + let mbch = case channel of + ChannelEstablished ch -> Just ch + _ -> Nothing + + let checkOutstanding + | isJust mbch = do + (,) <$> readTQueue cSecureOutQueue <*> (Just <$> reservePacket conn) + | otherwise = retry + + checkDataInternal = do + (,) <$> readFlow cDataInternal <*> (Just <$> reservePacket conn) + + checkAcknowledgements + | isJust mbch = do + acks <- readTVar cToAcknowledge + if null acks then retry + else return ((EncryptedOnly, TransportPacket (TransportHeader []) [], []), Nothing) + | otherwise = retry + + ((secure, packet@(TransportPacket (TransportHeader hitems) content), plainAckedBy), mbReserved) <- + checkOutstanding <|> checkDataInternal <|> checkAcknowledgements + + when (isNothing mbch && secure >= EncryptedOnly) $ do + writeTQueue cSecureOutQueue (secure, packet, plainAckedBy) + + acknowledge <- case mbch of + Nothing -> return [] + Just _ -> swapTVar cToAcknowledge [] + + return $ do + let onAck = sequence_ $ map (streamAccepted conn) $ + catMaybes (map (\case StreamOpen n -> Just n; _ -> Nothing) hitems) + + let mkPlain extraHeaders + | combinedHeaderItems@(_:_) <- map AcknowledgedSingle acknowledge ++ extraHeaders ++ hitems = + BL.concat $ + (serializeObject $ transportToObject gStorage $ TransportHeader combinedHeaderItems) + : map lazyLoadBytes content + | otherwise = BL.empty + + let usePlaintext = do + plain <- mkPlain <$> generateCookieHeaders conn channel + return $ Just (BL.toStrict plain, plainAckedBy) + + let useEncryption ch = do + plain <- mkPlain <$> return [] + runExceptT (channelEncrypt ch $ BL.toStrict $ 0x00 `BL.cons` plain) >>= \case + Right (ctext, counter) -> do + let isAcked = any isHeaderItemAcknowledged hitems + return $ Just (0x80 `B.cons` ctext, if isAcked then [ AcknowledgedSingle $ fromIntegral counter ] else []) + Left err -> do atomically $ gLog $ "Failed to encrypt data: " ++ err + return Nothing + + mbs <- case (secure, mbch) of + (PlaintextOnly, _) -> usePlaintext + (PlaintextAllowed, Nothing) -> usePlaintext + (_, Just ch) -> useEncryption ch + (EncryptedOnly, Nothing) -> return Nothing + + case mbs of + Just (bs, ackedBy) -> do + let mbReserved' = (\rs -> rs + { rsAckedBy = guard (not $ null ackedBy) >> Just (`elem` ackedBy) + , rsOnAck = rsOnAck rs >> onAck + }) <$> mbReserved + sendBytes conn mbReserved' bs + Nothing -> return () + + let waitUntil :: TimeSpec -> TimeSpec -> STM () + waitUntil now till = do + nextTimeout <- readTVar gNextTimeout + if nextTimeout <= now || till < nextTimeout + then writeTVar gNextTimeout till + else retry + + let retransmitPacket :: Connection addr -> STM (IO ()) + retransmitPacket conn@Connection {..} = do + now <- readTVar gNowVar + (sp, rest) <- readTVar cSentPackets >>= \case + sps@(_:_) -> return (last sps, init sps) + _ -> retry + let nextTry = spTime sp + fromNanoSecs 1000000000 + if | now < nextTry -> do + waitUntil now nextTry + return $ return () + | spRetryCount sp < 2 -> do + reserved <- reservePacket conn + writeTVar cSentPackets rest + return $ resendBytes conn (Just reserved) sp + | otherwise -> do + return $ spOnFail sp + + let handleControlRequests = readFlow gControlFlow >>= \case + RequestConnection addr -> do + conn@Connection {..} <- getConnection gs addr + identity <- fst <$> readTVar gIdentity + readTVar cChannel >>= \case + ChannelNone -> do + reserved <- reservePacket conn + let packet = BL.toStrict $ BL.concat + [ serializeObject $ transportToObject gStorage $ TransportHeader $ + [ Initiation $ refDigest gInitConfig + , AnnounceSelf $ refDigest $ storedRef $ idData identity + ] ++ map ProtocolVersion protocolVersions + , lazyLoadBytes gInitConfig + ] + writeTVar cChannel ChannelCookieWait + let reserved' = reserved { rsAckedBy = Just $ \case CookieSet {} -> True; _ -> False } + return $ sendBytes conn (Just reserved') packet + _ -> return $ return () + + SendAnnounce addr -> do + identity <- fst <$> readTVar gIdentity + let packet = BL.toStrict $ serializeObject $ transportToObject gStorage $ TransportHeader $ + [ AnnounceSelf $ refDigest $ storedRef $ idData identity + ] ++ map ProtocolVersion protocolVersions + writeFlow gDataFlow (addr, packet) + return $ return () + + UpdateSelfIdentity nid -> do + (cur, past) <- readTVar gIdentity + writeTVar gIdentity (nid, cur : past) + return $ return () + + let sendKeepAlive :: Connection addr -> STM (IO ()) + sendKeepAlive Connection {..} = do + readTVar cNextKeepAlive >>= \case + Nothing -> retry + Just next -> do + now <- readTVar gNowVar + if next <= now + then do + writeTVar cNextKeepAlive Nothing + identity <- fst <$> readTVar gIdentity + let header = TransportHeader [ AnnounceSelf $ refDigest $ storedRef $ idData identity ] + writeTQueue cSecureOutQueue (EncryptedOnly, TransportPacket header [], []) + else do + waitUntil now next + return $ return () + + conns <- readTVar gConnections + msum $ concat $ + [ map retransmitPacket conns + , map sendNextPacket conns + , [ handleControlRequests ] + , map sendKeepAlive conns + ] + +processAcknowledgements :: GlobalState addr -> Connection addr -> [TransportHeaderItem] -> STM (IO ()) +processAcknowledgements GlobalState {} Connection {..} header = do + (acked, notAcked) <- partition (\sp -> any (fromJust (spAckedBy sp)) header) <$> readTVar cSentPackets + writeTVar cSentPackets notAcked + return $ sequence_ $ map spOnAck acked diff --git a/src/Erebos/Network/ifaddrs.c b/src/Erebos/Network/ifaddrs.c new file mode 100644 index 0000000..70685bc --- /dev/null +++ b/src/Erebos/Network/ifaddrs.c @@ -0,0 +1,167 @@ +#include "ifaddrs.h" + +#include <errno.h> +#include <stdbool.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> + +#ifndef _WIN32 +#include <arpa/inet.h> +#include <net/if.h> +#include <ifaddrs.h> +#include <endian.h> +#include <sys/types.h> +#include <sys/socket.h> +#else +#include <winsock2.h> +#include <ws2ipdef.h> +#include <ws2tcpip.h> +#endif + +#define DISCOVERY_MULTICAST_GROUP "ff12:b6a4:6b1f:969:caee:acc2:5c93:73e1" + +uint32_t * join_multicast(int fd, size_t * count) +{ + size_t capacity = 16; + *count = 0; + uint32_t * interfaces = malloc(sizeof(uint32_t) * capacity); + +#ifdef _WIN32 + interfaces[0] = 0; + *count = 1; +#else + struct ifaddrs * addrs; + if (getifaddrs(&addrs) < 0) + return 0; + + for (struct ifaddrs * ifa = addrs; ifa; ifa = ifa->ifa_next) { + if (ifa->ifa_addr && ifa->ifa_addr->sa_family == AF_INET6 && + !(ifa->ifa_flags & IFF_LOOPBACK)) { + int idx = if_nametoindex(ifa->ifa_name); + + bool seen = false; + for (size_t i = 0; i < *count; i++) { + if (interfaces[i] == idx) { + seen = true; + break; + } + } + if (seen) + continue; + + if (*count + 1 >= capacity) { + capacity *= 2; + uint32_t * nret = realloc(interfaces, sizeof(uint32_t) * capacity); + if (nret) { + interfaces = nret; + } else { + free(interfaces); + *count = 0; + return NULL; + } + } + + interfaces[*count] = idx; + (*count)++; + } + } + + freeifaddrs(addrs); +#endif + + for (size_t i = 0; i < *count; i++) { + struct ipv6_mreq group; + group.ipv6mr_interface = interfaces[i]; + inet_pton(AF_INET6, DISCOVERY_MULTICAST_GROUP, &group.ipv6mr_multiaddr); + int ret = setsockopt(fd, IPPROTO_IPV6, IPV6_ADD_MEMBERSHIP, + (const void *) &group, sizeof(group)); + if (ret < 0) + fprintf(stderr, "IPV6_ADD_MEMBERSHIP failed: %s\n", strerror(errno)); + } + + return interfaces; +} + +#ifndef _WIN32 + +uint32_t * broadcast_addresses(void) +{ + struct ifaddrs * addrs; + if (getifaddrs(&addrs) < 0) + return 0; + + size_t capacity = 16, count = 0; + uint32_t * ret = malloc(sizeof(uint32_t) * capacity); + + for (struct ifaddrs * ifa = addrs; ifa; ifa = ifa->ifa_next) { + if (ifa->ifa_addr && ifa->ifa_addr->sa_family == AF_INET && + ifa->ifa_flags & IFF_BROADCAST) { + if (count + 2 >= capacity) { + capacity *= 2; + uint32_t * nret = realloc(ret, sizeof(uint32_t) * capacity); + if (nret) { + ret = nret; + } else { + free(ret); + return 0; + } + } + + ret[count] = ((struct sockaddr_in*)ifa->ifa_broadaddr)->sin_addr.s_addr; + count++; + } + } + + freeifaddrs(addrs); + ret[count] = 0; + return ret; +} + +#else // _WIN32 + +#include <winsock2.h> +#include <ws2tcpip.h> + +#pragma comment(lib, "ws2_32.lib") + +uint32_t * broadcast_addresses(void) +{ + uint32_t * ret = NULL; + SOCKET wsock = INVALID_SOCKET; + + struct WSAData wsaData; + if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) + return NULL; + + wsock = WSASocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP, NULL, 0, 0); + if (wsock == INVALID_SOCKET) + goto cleanup; + + INTERFACE_INFO InterfaceList[32]; + unsigned long nBytesReturned; + + if (WSAIoctl(wsock, SIO_GET_INTERFACE_LIST, 0, 0, + InterfaceList, sizeof(InterfaceList), + &nBytesReturned, 0, 0) == SOCKET_ERROR) + goto cleanup; + + int numInterfaces = nBytesReturned / sizeof(INTERFACE_INFO); + + size_t capacity = 16, count = 0; + ret = malloc(sizeof(uint32_t) * capacity); + + for (int i = 0; i < numInterfaces && count < capacity - 1; i++) + if (InterfaceList[i].iiFlags & IFF_BROADCAST) + ret[count++] = InterfaceList[i].iiBroadcastAddress.AddressIn.sin_addr.s_addr; + + ret[count] = 0; +cleanup: + if (wsock != INVALID_SOCKET) + closesocket(wsock); + WSACleanup(); + + return ret; +} + +#endif diff --git a/src/Erebos/Network/ifaddrs.h b/src/Erebos/Network/ifaddrs.h new file mode 100644 index 0000000..8852ec6 --- /dev/null +++ b/src/Erebos/Network/ifaddrs.h @@ -0,0 +1,5 @@ +#include <stddef.h> +#include <stdint.h> + +uint32_t * join_multicast(int fd, size_t * count); +uint32_t * broadcast_addresses(void); diff --git a/src/Erebos/Pairing.hs b/src/Erebos/Pairing.hs new file mode 100644 index 0000000..2166e71 --- /dev/null +++ b/src/Erebos/Pairing.hs @@ -0,0 +1,242 @@ +module Erebos.Pairing ( + PairingService(..), + PairingState(..), + PairingAttributes(..), + PairingResult(..), + PairingFailureReason(..), + + pairingRequest, + pairingAccept, + pairingReject, +) where + +import Control.Monad +import Control.Monad.Except +import Control.Monad.Reader + +import Crypto.Random + +import Data.Bits +import Data.ByteArray (Bytes, convert) +import qualified Data.ByteArray as BA +import qualified Data.ByteString.Char8 as BC +import Data.Kind +import Data.Maybe +import Data.Typeable +import Data.Word + +import Erebos.Identity +import Erebos.Network +import Erebos.PubKey +import Erebos.Service +import Erebos.State +import Erebos.Storage + +data PairingService a = PairingRequest (Stored (Signed IdentityData)) (Stored (Signed IdentityData)) RefDigest + | PairingResponse Bytes + | PairingRequestNonce Bytes + | PairingAccept a + | PairingReject + +data PairingState a = NoPairing + | OurRequest UnifiedIdentity UnifiedIdentity Bytes + | OurRequestConfirm (Maybe (PairingVerifiedResult a)) + | OurRequestReady + | PeerRequest UnifiedIdentity UnifiedIdentity Bytes RefDigest + | PeerRequestConfirm + | PairingDone + +data PairingFailureReason a = PairingUserRejected + | PairingUnexpectedMessage (PairingState a) (PairingService a) + | PairingFailedOther String + +data PairingAttributes a = PairingAttributes + { pairingHookRequest :: ServiceHandler (PairingService a) () + , pairingHookResponse :: String -> ServiceHandler (PairingService a) () + , pairingHookRequestNonce :: String -> ServiceHandler (PairingService a) () + , pairingHookRequestNonceFailed :: ServiceHandler (PairingService a) () + , pairingHookConfirmedResponse :: ServiceHandler (PairingService a) () + , pairingHookConfirmedRequest :: ServiceHandler (PairingService a) () + , pairingHookAcceptedResponse :: ServiceHandler (PairingService a) () + , pairingHookAcceptedRequest :: ServiceHandler (PairingService a) () + , pairingHookVerifyFailed :: ServiceHandler (PairingService a) () + , pairingHookRejected :: ServiceHandler (PairingService a) () + , pairingHookFailed :: PairingFailureReason a -> ServiceHandler (PairingService a) () + } + +class (Typeable a, Storable a) => PairingResult a where + type PairingVerifiedResult a :: Type + type PairingVerifiedResult a = a + + pairingServiceID :: proxy a -> ServiceID + pairingVerifyResult :: a -> ServiceHandler (PairingService a) (Maybe (PairingVerifiedResult a)) + pairingFinalizeRequest :: PairingVerifiedResult a -> ServiceHandler (PairingService a) () + pairingFinalizeResponse :: ServiceHandler (PairingService a) a + defaultPairingAttributes :: proxy (PairingService a) -> PairingAttributes a + + +instance Storable a => Storable (PairingService a) where + store' (PairingRequest idReq idRsp x) = storeRec $ do + storeRef "id-req" idReq + storeRef "id-rsp" idRsp + storeBinary "request" x + store' (PairingResponse x) = storeRec $ storeBinary "response" x + store' (PairingRequestNonce x) = storeRec $ storeBinary "reqnonce" x + store' (PairingAccept x) = store' x + store' (PairingReject) = storeRec $ storeEmpty "reject" + + load' = do + res <- loadRec $ do + (req :: Maybe Bytes) <- loadMbBinary "request" + idReq <- loadMbRef "id-req" + idRsp <- loadMbRef "id-rsp" + rsp <- loadMbBinary "response" + rnonce <- loadMbBinary "reqnonce" + rej <- loadMbEmpty "reject" + return $ catMaybes + [ PairingRequest <$> idReq <*> idRsp <*> (refDigestFromByteString =<< req) + , PairingResponse <$> rsp + , PairingRequestNonce <$> rnonce + , const PairingReject <$> rej + ] + case res of + x:_ -> return x + [] -> PairingAccept <$> load' + + +instance PairingResult a => Service (PairingService a) where + serviceID _ = pairingServiceID @a Proxy + + type ServiceAttributes (PairingService a) = PairingAttributes a + defaultServiceAttributes = defaultPairingAttributes + + type ServiceState (PairingService a) = PairingState a + emptyServiceState _ = NoPairing + + serviceHandler spacket = ((,fromStored spacket) <$> svcGet) >>= \case + (NoPairing, PairingRequest pdata sdata confirm) -> do + self <- maybe (throwError "failed to validate received identity") return $ validateIdentity sdata + self' <- maybe (throwError "failed to validate own identity") return . + validateExtendedIdentity . lsIdentity . fromStored =<< svcGetLocal + when (not $ self `sameIdentity` self') $ do + throwError "pairing request to different identity" + + peer <- maybe (throwError "failed to validate received peer identity") return $ validateIdentity pdata + peer' <- asks $ svcPeerIdentity + when (not $ peer `sameIdentity` peer') $ do + throwError "pairing request from different identity" + + join $ asks $ pairingHookRequest . svcAttributes + nonce <- liftIO $ getRandomBytes 32 + svcSet $ PeerRequest peer self nonce confirm + replyPacket $ PairingResponse nonce + (NoPairing, _) -> return () + + (PairingDone, _) -> return () + (_, PairingReject) -> do + join $ asks $ pairingHookRejected . svcAttributes + svcSet NoPairing + + (OurRequest self peer nonce, PairingResponse pnonce) -> do + hook <- asks $ pairingHookResponse . svcAttributes + hook $ confirmationNumber $ nonceDigest self peer nonce pnonce + svcSet $ OurRequestConfirm Nothing + replyPacket $ PairingRequestNonce nonce + x@(OurRequest {}, _) -> reject $ uncurry PairingUnexpectedMessage x + + (OurRequestConfirm _, PairingAccept x) -> do + flip catchError (reject . PairingFailedOther) $ do + pairingVerifyResult x >>= \case + Just x' -> do + join $ asks $ pairingHookConfirmedRequest . svcAttributes + svcSet $ OurRequestConfirm (Just x') + Nothing -> do + join $ asks $ pairingHookVerifyFailed . svcAttributes + svcSet NoPairing + replyPacket PairingReject + + x@(OurRequestConfirm _, _) -> reject $ uncurry PairingUnexpectedMessage x + + (OurRequestReady, PairingAccept x) -> do + flip catchError (reject . PairingFailedOther) $ do + pairingVerifyResult x >>= \case + Just x' -> do + pairingFinalizeRequest x' + join $ asks $ pairingHookAcceptedResponse . svcAttributes + svcSet $ PairingDone + Nothing -> do + join $ asks $ pairingHookVerifyFailed . svcAttributes + throwError "" + x@(OurRequestReady, _) -> reject $ uncurry PairingUnexpectedMessage x + + (PeerRequest peer self nonce dgst, PairingRequestNonce pnonce) -> do + if dgst == nonceDigest peer self pnonce BA.empty + then do hook <- asks $ pairingHookRequestNonce . svcAttributes + hook $ confirmationNumber $ nonceDigest peer self pnonce nonce + svcSet PeerRequestConfirm + else do join $ asks $ pairingHookRequestNonceFailed . svcAttributes + svcSet NoPairing + replyPacket PairingReject + x@(PeerRequest {}, _) -> reject $ uncurry PairingUnexpectedMessage x + x@(PeerRequestConfirm, _) -> reject $ uncurry PairingUnexpectedMessage x + +reject :: PairingResult a => PairingFailureReason a -> ServiceHandler (PairingService a) () +reject reason = do + join $ asks $ flip pairingHookFailed reason . svcAttributes + svcSet NoPairing + replyPacket PairingReject + + +nonceDigest :: UnifiedIdentity -> UnifiedIdentity -> Bytes -> Bytes -> RefDigest +nonceDigest idReq idRsp nonceReq nonceRsp = hashToRefDigest $ serializeObject $ Rec + [ (BC.pack "id-req", RecRef $ storedRef $ idData idReq) + , (BC.pack "id-rsp", RecRef $ storedRef $ idData idRsp) + , (BC.pack "nonce-req", RecBinary $ convert nonceReq) + , (BC.pack "nonce-rsp", RecBinary $ convert nonceRsp) + ] + +confirmationNumber :: RefDigest -> String +confirmationNumber dgst = + case map fromIntegral $ BA.unpack dgst :: [Word32] of + (a:b:c:d:_) -> let str = show $ ((a `shift` 24) .|. (b `shift` 16) .|. (c `shift` 8) .|. d) `mod` (10 ^ len) + in replicate (len - length str) '0' ++ str + _ -> "" + where len = 6 + +pairingRequest :: forall a m proxy. (PairingResult a, MonadIO m, MonadError String m) => proxy a -> Peer -> m () +pairingRequest _ peer = do + self <- liftIO $ serverIdentity $ peerServer peer + nonce <- liftIO $ getRandomBytes 32 + pid <- peerIdentity peer >>= \case + PeerIdentityFull pid -> return pid + _ -> throwError "incomplete peer identity" + sendToPeerWith @(PairingService a) peer $ \case + NoPairing -> return (Just $ PairingRequest (idData self) (idData pid) (nonceDigest self pid nonce BA.empty), OurRequest self pid nonce) + _ -> throwError "already in progress" + +pairingAccept :: forall a m proxy. (PairingResult a, MonadIO m, MonadError String m) => proxy a -> Peer -> m () +pairingAccept _ peer = runPeerService @(PairingService a) peer $ do + svcGet >>= \case + NoPairing -> throwError $ "none in progress" + OurRequest {} -> throwError $ "waiting for peer" + OurRequestConfirm Nothing -> do + join $ asks $ pairingHookConfirmedResponse . svcAttributes + svcSet OurRequestReady + OurRequestConfirm (Just verified) -> do + join $ asks $ pairingHookAcceptedResponse . svcAttributes + pairingFinalizeRequest verified + svcSet PairingDone + OurRequestReady -> throwError $ "already accepted, waiting for peer" + PeerRequest {} -> throwError $ "waiting for peer" + PeerRequestConfirm -> do + join $ asks $ pairingHookAcceptedRequest . svcAttributes + replyPacket . PairingAccept =<< pairingFinalizeResponse + svcSet PairingDone + PairingDone -> throwError $ "already done" + +pairingReject :: forall a m proxy. (PairingResult a, MonadIO m, MonadError String m) => proxy a -> Peer -> m () +pairingReject _ peer = runPeerService @(PairingService a) peer $ do + svcGet >>= \case + NoPairing -> throwError $ "none in progress" + PairingDone -> throwError $ "already done" + _ -> reject PairingUserRejected diff --git a/src/Erebos/PubKey.hs b/src/Erebos/PubKey.hs new file mode 100644 index 0000000..09a8e02 --- /dev/null +++ b/src/Erebos/PubKey.hs @@ -0,0 +1,156 @@ +module Erebos.PubKey ( + PublicKey, SecretKey, + KeyPair(generateKeys), loadKey, loadKeyMb, + Signature(sigKey), Signed, signedData, signedSignature, + sign, signAdd, isSignedBy, + fromSigned, + unsafeMapSigned, + + PublicKexKey, SecretKexKey, + dhSecret, +) where + +import Control.Monad +import Control.Monad.Except + +import Crypto.Error +import qualified Crypto.PubKey.Ed25519 as ED +import qualified Crypto.PubKey.Curve25519 as CX + +import Data.ByteArray +import Data.ByteString (ByteString) +import qualified Data.Text as T + +import Erebos.Storage +import Erebos.Storage.Key + +data PublicKey = PublicKey ED.PublicKey + deriving (Show) + +data SecretKey = SecretKey ED.SecretKey (Stored PublicKey) + +data Signature = Signature + { sigKey :: Stored PublicKey + , sigSignature :: ED.Signature + } + deriving (Show) + +data Signed a = Signed + { signedData_ :: Stored a + , signedSignature_ :: [Stored Signature] + } + deriving (Show) + +signedData :: Signed a -> Stored a +signedData = signedData_ + +signedSignature :: Signed a -> [Stored Signature] +signedSignature = signedSignature_ + +instance KeyPair SecretKey PublicKey where + keyGetPublic (SecretKey _ pub) = pub + keyGetData (SecretKey sec _) = convert sec + keyFromData kdata spub = do + skey <- maybeCryptoError $ ED.secretKey kdata + let PublicKey pkey = fromStored spub + guard $ ED.toPublic skey == pkey + return $ SecretKey skey spub + generateKeys st = do + secret <- ED.generateSecretKey + public <- wrappedStore st $ PublicKey $ ED.toPublic secret + let pair = SecretKey secret public + storeKey pair + return (pair, public) + +instance Storable PublicKey where + store' (PublicKey pk) = storeRec $ do + storeText "type" $ T.pack "ed25519" + storeBinary "pubkey" pk + + load' = loadRec $ do + ktype <- loadText "type" + guard $ ktype == "ed25519" + maybe (throwError "Public key decoding failed") (return . PublicKey) . + maybeCryptoError . (ED.publicKey :: ByteString -> CryptoFailable ED.PublicKey) =<< + loadBinary "pubkey" + +instance Storable Signature where + store' sig = storeRec $ do + storeRef "key" $ sigKey sig + storeBinary "sig" $ sigSignature sig + + load' = loadRec $ Signature + <$> loadRef "key" + <*> loadSignature "sig" + where loadSignature = maybe (throwError "Signature decoding failed") return . + maybeCryptoError . (ED.signature :: ByteString -> CryptoFailable ED.Signature) <=< loadBinary + +instance Storable a => Storable (Signed a) where + store' sig = storeRec $ do + storeRef "SDATA" $ signedData sig + mapM_ (storeRef "sig") $ signedSignature sig + + load' = loadRec $ do + sdata <- loadRef "SDATA" + sigs <- loadRefs "sig" + forM_ sigs $ \sig -> do + let PublicKey pubkey = fromStored $ sigKey $ fromStored sig + when (not $ ED.verify pubkey (storedRef sdata) $ sigSignature $ fromStored sig) $ + throwError "signature verification failed" + return $ Signed sdata sigs + +sign :: MonadStorage m => SecretKey -> Stored a -> m (Signed a) +sign secret val = signAdd secret $ Signed val [] + +signAdd :: MonadStorage m => SecretKey -> Signed a -> m (Signed a) +signAdd (SecretKey secret spublic) (Signed val sigs) = do + let PublicKey public = fromStored spublic + sig = ED.sign secret public $ storedRef val + ssig <- mstore $ Signature spublic sig + return $ Signed val (ssig : sigs) + +isSignedBy :: Signed a -> Stored PublicKey -> Bool +isSignedBy sig key = key `elem` map (sigKey . fromStored) (signedSignature sig) + +fromSigned :: Stored (Signed a) -> a +fromSigned = fromStored . signedData . fromStored + +-- |Passed function needs to preserve the object representation to be safe +unsafeMapSigned :: (a -> b) -> Signed a -> Signed b +unsafeMapSigned f signed = signed { signedData_ = unsafeMapStored f (signedData_ signed) } + + +data PublicKexKey = PublicKexKey CX.PublicKey + deriving (Show) + +data SecretKexKey = SecretKexKey CX.SecretKey (Stored PublicKexKey) + +instance KeyPair SecretKexKey PublicKexKey where + keyGetPublic (SecretKexKey _ pub) = pub + keyGetData (SecretKexKey sec _) = convert sec + keyFromData kdata spub = do + skey <- maybeCryptoError $ CX.secretKey kdata + let PublicKexKey pkey = fromStored spub + guard $ CX.toPublic skey == pkey + return $ SecretKexKey skey spub + generateKeys st = do + secret <- CX.generateSecretKey + public <- wrappedStore st $ PublicKexKey $ CX.toPublic secret + let pair = SecretKexKey secret public + storeKey pair + return (pair, public) + +instance Storable PublicKexKey where + store' (PublicKexKey pk) = storeRec $ do + storeText "type" $ T.pack "x25519" + storeBinary "pubkey" pk + + load' = loadRec $ do + ktype <- loadText "type" + guard $ ktype == "x25519" + maybe (throwError "public key decoding failed") (return . PublicKexKey) . + maybeCryptoError . (CX.publicKey :: ScrubbedBytes -> CryptoFailable CX.PublicKey) =<< + loadBinary "pubkey" + +dhSecret :: SecretKexKey -> PublicKexKey -> ScrubbedBytes +dhSecret (SecretKexKey secret _) (PublicKexKey public) = convert $ CX.dh public secret diff --git a/src/Erebos/Service.hs b/src/Erebos/Service.hs new file mode 100644 index 0000000..f8428d1 --- /dev/null +++ b/src/Erebos/Service.hs @@ -0,0 +1,190 @@ +module Erebos.Service ( + Service(..), + SomeService(..), someService, someServiceAttr, someServiceID, + SomeServiceState(..), fromServiceState, someServiceEmptyState, + SomeServiceGlobalState(..), fromServiceGlobalState, someServiceEmptyGlobalState, + SomeStorageWatcher(..), + ServiceID, mkServiceID, + + ServiceHandler, + ServiceInput(..), + ServiceReply(..), + runServiceHandler, + + svcGet, svcSet, svcModify, + svcGetGlobal, svcSetGlobal, svcModifyGlobal, + svcGetLocal, svcSetLocal, + + svcSelf, + svcPrint, + + replyPacket, replyStored, replyStoredRef, + afterCommit, +) where + +import Control.Monad.Except +import Control.Monad.Reader +import Control.Monad.State +import Control.Monad.Writer + +import Data.Kind +import Data.Typeable +import Data.UUID (UUID) +import qualified Data.UUID as U + +import Erebos.Identity +import {-# SOURCE #-} Erebos.Network +import Erebos.State +import Erebos.Storage + +class (Typeable s, Storable s, Typeable (ServiceState s), Typeable (ServiceGlobalState s)) => Service s where + serviceID :: proxy s -> ServiceID + serviceHandler :: Stored s -> ServiceHandler s () + + serviceNewPeer :: ServiceHandler s () + serviceNewPeer = return () + + type ServiceAttributes s = attr | attr -> s + type ServiceAttributes s = Proxy s + defaultServiceAttributes :: proxy s -> ServiceAttributes s + default defaultServiceAttributes :: ServiceAttributes s ~ Proxy s => proxy s -> ServiceAttributes s + defaultServiceAttributes _ = Proxy + + type ServiceState s :: Type + type ServiceState s = () + emptyServiceState :: proxy s -> ServiceState s + default emptyServiceState :: ServiceState s ~ () => proxy s -> ServiceState s + emptyServiceState _ = () + + type ServiceGlobalState s :: Type + type ServiceGlobalState s = () + emptyServiceGlobalState :: proxy s -> ServiceGlobalState s + default emptyServiceGlobalState :: ServiceGlobalState s ~ () => proxy s -> ServiceGlobalState s + emptyServiceGlobalState _ = () + + serviceStorageWatchers :: proxy s -> [SomeStorageWatcher s] + serviceStorageWatchers _ = [] + + +data SomeService = forall s. Service s => SomeService (Proxy s) (ServiceAttributes s) + +someService :: forall s proxy. Service s => proxy s -> SomeService +someService _ = SomeService @s Proxy (defaultServiceAttributes @s Proxy) + +someServiceAttr :: forall s. Service s => ServiceAttributes s -> SomeService +someServiceAttr attr = SomeService @s Proxy attr + +someServiceID :: SomeService -> ServiceID +someServiceID (SomeService s _) = serviceID s + +data SomeServiceState = forall s. Service s => SomeServiceState (Proxy s) (ServiceState s) + +fromServiceState :: Service s => proxy s -> SomeServiceState -> Maybe (ServiceState s) +fromServiceState _ (SomeServiceState _ s) = cast s + +someServiceEmptyState :: SomeService -> SomeServiceState +someServiceEmptyState (SomeService p _) = SomeServiceState p (emptyServiceState p) + +data SomeServiceGlobalState = forall s. Service s => SomeServiceGlobalState (Proxy s) (ServiceGlobalState s) + +fromServiceGlobalState :: Service s => proxy s -> SomeServiceGlobalState -> Maybe (ServiceGlobalState s) +fromServiceGlobalState _ (SomeServiceGlobalState _ s) = cast s + +someServiceEmptyGlobalState :: SomeService -> SomeServiceGlobalState +someServiceEmptyGlobalState (SomeService p _) = SomeServiceGlobalState p (emptyServiceGlobalState p) + + +data SomeStorageWatcher s = forall a. Eq a => SomeStorageWatcher (Stored LocalState -> a) (a -> ServiceHandler s ()) + + +newtype ServiceID = ServiceID UUID + deriving (Eq, Ord, Show, StorableUUID) + +mkServiceID :: String -> ServiceID +mkServiceID = maybe (error "Invalid service ID") ServiceID . U.fromString + +data ServiceInput s = ServiceInput + { svcAttributes :: ServiceAttributes s + , svcPeer :: Peer + , svcPeerIdentity :: UnifiedIdentity + , svcServer :: Server + , svcPrintOp :: String -> IO () + } + +data ServiceReply s = ServiceReply (Either s (Stored s)) Bool + | ServiceFinally (IO ()) + +data ServiceHandlerState s = ServiceHandlerState + { svcValue :: ServiceState s + , svcGlobal :: ServiceGlobalState s + , svcLocal :: Stored LocalState + } + +newtype ServiceHandler s a = ServiceHandler (ReaderT (ServiceInput s) (WriterT [ServiceReply s] (StateT (ServiceHandlerState s) (ExceptT String IO))) a) + deriving (Functor, Applicative, Monad, MonadReader (ServiceInput s), MonadWriter [ServiceReply s], MonadState (ServiceHandlerState s), MonadError String, MonadIO) + +instance MonadStorage (ServiceHandler s) where + getStorage = asks $ peerStorage . svcPeer + +instance MonadHead LocalState (ServiceHandler s) where + updateLocalHead f = do + (ls, x) <- f =<< gets svcLocal + modify $ \s -> s { svcLocal = ls } + return x + +runServiceHandler :: Service s => Head LocalState -> ServiceInput s -> ServiceState s -> ServiceGlobalState s -> ServiceHandler s () -> IO ([ServiceReply s], (ServiceState s, ServiceGlobalState s)) +runServiceHandler h input svc global shandler = do + let sstate = ServiceHandlerState { svcValue = svc, svcGlobal = global, svcLocal = headStoredObject h } + ServiceHandler handler = shandler + (runExceptT $ flip runStateT sstate $ execWriterT $ flip runReaderT input $ handler) >>= \case + Left err -> do + svcPrintOp input $ "service failed: " ++ err + return ([], (svc, global)) + Right (rsp, sstate') + | svcLocal sstate' == svcLocal sstate -> return (rsp, (svcValue sstate', svcGlobal sstate')) + | otherwise -> replaceHead h (svcLocal sstate') >>= \case + Left (Just h') -> runServiceHandler h' input svc global shandler + _ -> return (rsp, (svcValue sstate', svcGlobal sstate')) + +svcGet :: ServiceHandler s (ServiceState s) +svcGet = gets svcValue + +svcSet :: ServiceState s -> ServiceHandler s () +svcSet x = modify $ \st -> st { svcValue = x } + +svcModify :: (ServiceState s -> ServiceState s) -> ServiceHandler s () +svcModify f = modify $ \st -> st { svcValue = f (svcValue st) } + +svcGetGlobal :: ServiceHandler s (ServiceGlobalState s) +svcGetGlobal = gets svcGlobal + +svcSetGlobal :: ServiceGlobalState s -> ServiceHandler s () +svcSetGlobal x = modify $ \st -> st { svcGlobal = x } + +svcModifyGlobal :: (ServiceGlobalState s -> ServiceGlobalState s) -> ServiceHandler s () +svcModifyGlobal f = modify $ \st -> st { svcGlobal = f (svcGlobal st) } + +svcGetLocal :: ServiceHandler s (Stored LocalState) +svcGetLocal = gets svcLocal + +svcSetLocal :: Stored LocalState -> ServiceHandler s () +svcSetLocal x = modify $ \st -> st { svcLocal = x } + +svcSelf :: ServiceHandler s UnifiedIdentity +svcSelf = maybe (throwError "failed to validate own identity") return . + validateExtendedIdentity . lsIdentity . fromStored =<< svcGetLocal + +svcPrint :: String -> ServiceHandler s () +svcPrint str = afterCommit . ($ str) =<< asks svcPrintOp + +replyPacket :: Service s => s -> ServiceHandler s () +replyPacket x = tell [ServiceReply (Left x) True] + +replyStored :: Service s => Stored s -> ServiceHandler s () +replyStored x = tell [ServiceReply (Right x) True] + +replyStoredRef :: Service s => Stored s -> ServiceHandler s () +replyStoredRef x = tell [ServiceReply (Right x) False] + +afterCommit :: IO () -> ServiceHandler s () +afterCommit x = tell [ServiceFinally x] diff --git a/src/Erebos/Set.hs b/src/Erebos/Set.hs new file mode 100644 index 0000000..c5edd56 --- /dev/null +++ b/src/Erebos/Set.hs @@ -0,0 +1,86 @@ +module Erebos.Set ( + Set, + + emptySet, + loadSet, + storeSetAdd, + storeSetAddComponent, + + fromSetBy, +) where + +import Control.Arrow +import Control.Monad.IO.Class + +import Data.Function +import Data.List +import Data.Map (Map) +import Data.Map qualified as M +import Data.Maybe +import Data.Ord + +import Erebos.Storage +import Erebos.Storage.Merge +import Erebos.Util + +data Set a = Set [Stored (SetItem (Component a))] + deriving (Eq) + +data SetItem a = SetItem + { siPrev :: [Stored (SetItem a)] + , siItem :: [Stored a] + } + +instance Storable a => Storable (SetItem a) where + store' x = storeRec $ do + mapM_ (storeRef "PREV") $ siPrev x + mapM_ (storeRef "item") $ siItem x + + load' = loadRec $ SetItem + <$> loadRefs "PREV" + <*> loadRefs "item" + +instance Mergeable a => Mergeable (Set a) where + type Component (Set a) = SetItem (Component a) + mergeSorted = Set + toComponents (Set items) = items + + +emptySet :: Set a +emptySet = Set [] + +loadSet :: Mergeable a => Ref -> Set a +loadSet = mergeSorted . (:[]) . wrappedLoad + +storeSetAdd :: (Mergeable a, MonadIO m) => Storage -> a -> Set a -> m (Set a) +storeSetAdd st x (Set prev) = Set . (:[]) <$> wrappedStore st SetItem + { siPrev = prev + , siItem = toComponents x + } + +storeSetAddComponent :: (Mergeable a, MonadStorage m, MonadIO m) => Stored (Component a) -> Set a -> m (Set a) +storeSetAddComponent component (Set prev) = Set . (:[]) <$> mstore SetItem + { siPrev = prev + , siItem = [ component ] + } + + +fromSetBy :: forall a. Mergeable a => (a -> a -> Ordering) -> Set a -> [a] +fromSetBy cmp (Set heads) = sortBy cmp $ map merge $ groupRelated items + where + -- gather all item components in the set history + items :: [Stored (Component a)] + items = walkAncestors (siItem . fromStored) heads + + -- map individual roots to full root set as joined in history of individual items + rootToRootSet :: Map RefDigest [RefDigest] + rootToRootSet = foldl' (\m rs -> foldl' (\m' r -> M.insertWith (\a b -> uniq $ sort $ a++b) r rs m') m rs) M.empty $ + map (map (refDigest . storedRef) . storedRoots) items + + -- get full root set for given item component + storedRootSet :: Stored (Component a) -> [RefDigest] + storedRootSet = fromJust . flip M.lookup rootToRootSet . refDigest . storedRef . head . storedRoots + + -- group components of single item, i.e. components sharing some root + groupRelated :: [Stored (Component a)] -> [[Stored (Component a)]] + groupRelated = map (map fst) . groupBy ((==) `on` snd) . sortBy (comparing snd) . map (id &&& storedRootSet) diff --git a/src/Erebos/State.hs b/src/Erebos/State.hs new file mode 100644 index 0000000..324127a --- /dev/null +++ b/src/Erebos/State.hs @@ -0,0 +1,201 @@ +module Erebos.State ( + LocalState(..), + SharedState, SharedType(..), + SharedTypeID, mkSharedTypeID, + + MonadHead(..), + updateLocalHead_, + + loadLocalStateHead, + + updateSharedState, updateSharedState_, + lookupSharedValue, makeSharedStateUpdate, + + localIdentity, + headLocalIdentity, + + mergeSharedIdentity, + updateSharedIdentity, + interactiveIdentityUpdate, +) where + +import Control.Monad.Except +import Control.Monad.Reader + +import Data.Foldable +import Data.Maybe +import qualified Data.Text as T +import qualified Data.Text.IO as T +import Data.Typeable +import Data.UUID (UUID) +import qualified Data.UUID as U + +import System.IO + +import Erebos.Identity +import Erebos.PubKey +import Erebos.Storage +import Erebos.Storage.Merge + +data LocalState = LocalState + { lsIdentity :: Stored (Signed ExtendedIdentityData) + , lsShared :: [Stored SharedState] + } + +data SharedState = SharedState + { ssPrev :: [Stored SharedState] + , ssType :: Maybe SharedTypeID + , ssValue :: [Ref] + } + +newtype SharedTypeID = SharedTypeID UUID + deriving (Eq, Ord, StorableUUID) + +mkSharedTypeID :: String -> SharedTypeID +mkSharedTypeID = maybe (error "Invalid shared type ID") SharedTypeID . U.fromString + +class Mergeable a => SharedType a where + sharedTypeID :: proxy a -> SharedTypeID + +instance Storable LocalState where + store' st = storeRec $ do + storeRef "id" $ lsIdentity st + mapM_ (storeRef "shared") $ lsShared st + + load' = loadRec $ LocalState + <$> loadRef "id" + <*> loadRefs "shared" + +instance HeadType LocalState where + headTypeID _ = mkHeadTypeID "1d7491a9-7bcb-4eaa-8f13-c8c4c4087e4e" + +instance Storable SharedState where + store' st = storeRec $ do + mapM_ (storeRef "PREV") $ ssPrev st + storeMbUUID "type" $ ssType st + mapM_ (storeRawRef "value") $ ssValue st + + load' = loadRec $ SharedState + <$> loadRefs "PREV" + <*> loadMbUUID "type" + <*> loadRawRefs "value" + +instance SharedType (Maybe ComposedIdentity) where + sharedTypeID _ = mkSharedTypeID "0c6c1fe0-f2d7-4891-926b-c332449f7871" + + +class (MonadIO m, MonadStorage m) => MonadHead a m where + updateLocalHead :: (Stored a -> m (Stored a, b)) -> m b + getLocalHead :: m (Stored a) + getLocalHead = updateLocalHead $ \x -> return (x, x) + +updateLocalHead_ :: MonadHead a m => (Stored a -> m (Stored a)) -> m () +updateLocalHead_ f = updateLocalHead (fmap (,()) . f) + +instance (HeadType a, MonadIO m) => MonadHead a (ReaderT (Head a) m) where + updateLocalHead f = do + h <- ask + snd <$> updateHead h f + + +loadLocalStateHead :: MonadIO m => Storage -> m (Head LocalState) +loadLocalStateHead st = loadHeads st >>= \case + (h:_) -> return h + [] -> liftIO $ do + putStr "Name: " + hFlush stdout + name <- T.getLine + + putStr "Device: " + hFlush stdout + devName <- T.getLine + + owner <- if + | T.null name -> return Nothing + | otherwise -> Just <$> createIdentity st (Just name) Nothing + + identity <- createIdentity st (if T.null devName then Nothing else Just devName) owner + + shared <- wrappedStore st $ SharedState + { ssPrev = [] + , ssType = Just $ sharedTypeID @(Maybe ComposedIdentity) Proxy + , ssValue = [storedRef $ idExtData $ fromMaybe identity owner] + } + storeHead st $ LocalState + { lsIdentity = idExtData identity + , lsShared = [shared] + } + +localIdentity :: LocalState -> UnifiedIdentity +localIdentity ls = maybe (error "failed to verify local identity") + (updateOwners $ maybe [] idExtDataF $ lookupSharedValue $ lsShared ls) + (validateExtendedIdentity $ lsIdentity ls) + +headLocalIdentity :: Head LocalState -> UnifiedIdentity +headLocalIdentity = localIdentity . headObject + + +updateSharedState_ :: forall a m. (SharedType a, MonadHead LocalState m) => (a -> m a) -> Stored LocalState -> m (Stored LocalState) +updateSharedState_ f = fmap fst <$> updateSharedState (fmap (,()) . f) + +updateSharedState :: forall a b m. (SharedType a, MonadHead LocalState m) => (a -> m (a, b)) -> Stored LocalState -> m (Stored LocalState, b) +updateSharedState f = \ls -> do + let shared = lsShared $ fromStored ls + val = lookupSharedValue shared + st <- getStorage + (val', x) <- f val + (,x) <$> if toComponents val' == toComponents val + then return ls + else do shared' <- makeSharedStateUpdate st val' shared + wrappedStore st (fromStored ls) { lsShared = [shared'] } + +lookupSharedValue :: forall a. SharedType a => [Stored SharedState] -> a +lookupSharedValue = mergeSorted . filterAncestors . map wrappedLoad . concatMap (ssValue . fromStored) . filterAncestors . helper + where helper (x:xs) | Just sid <- ssType (fromStored x), sid == sharedTypeID @a Proxy = x : helper xs + | otherwise = helper $ ssPrev (fromStored x) ++ xs + helper [] = [] + +makeSharedStateUpdate :: forall a m. MonadIO m => SharedType a => Storage -> a -> [Stored SharedState] -> m (Stored SharedState) +makeSharedStateUpdate st val prev = liftIO $ wrappedStore st SharedState + { ssPrev = prev + , ssType = Just $ sharedTypeID @a Proxy + , ssValue = storedRef <$> toComponents val + } + + +mergeSharedIdentity :: (MonadHead LocalState m, MonadError String m) => m UnifiedIdentity +mergeSharedIdentity = updateLocalHead $ updateSharedState $ \case + Just cidentity -> do + identity <- mergeIdentity cidentity + return (Just $ toComposedIdentity identity, identity) + Nothing -> throwError "no existing shared identity" + +updateSharedIdentity :: (MonadHead LocalState m, MonadError String m) => m () +updateSharedIdentity = updateLocalHead_ $ updateSharedState_ $ \case + Just identity -> do + Just . toComposedIdentity <$> interactiveIdentityUpdate identity + Nothing -> throwError "no existing shared identity" + +interactiveIdentityUpdate :: (Foldable f, MonadStorage m, MonadIO m, MonadError String m) => Identity f -> m UnifiedIdentity +interactiveIdentityUpdate identity = do + let public = idKeyIdentity identity + + name <- liftIO $ do + T.putStr $ T.concat $ concat + [ [ T.pack "Name" ] + , case idName identity of + Just name -> [T.pack " [", name, T.pack "]"] + Nothing -> [] + , [ T.pack ": " ] + ] + hFlush stdout + T.getLine + + if | T.null name -> mergeIdentity identity + | otherwise -> do + secret <- loadKey public + maybe (throwError "created invalid identity") return . validateIdentity =<< + mstore =<< sign secret =<< mstore (emptyIdentityData public) + { iddPrev = toList $ idDataF identity + , iddName = Just name + } diff --git a/src/Erebos/Storage.hs b/src/Erebos/Storage.hs new file mode 100644 index 0000000..2e6653a --- /dev/null +++ b/src/Erebos/Storage.hs @@ -0,0 +1,1053 @@ +module Erebos.Storage ( + Storage, PartialStorage, StorageCompleteness, + openStorage, memoryStorage, + deriveEphemeralStorage, derivePartialStorage, + + Ref, PartialRef, RefDigest, + refDigest, + readRef, showRef, showRefDigest, + refDigestFromByteString, hashToRefDigest, + copyRef, partialRef, partialRefFromDigest, + + Object, PartialObject, Object'(..), RecItem, RecItem'(..), + serializeObject, deserializeObject, deserializeObjects, + ioLoadObject, ioLoadBytes, + storeRawBytes, lazyLoadBytes, + storeObject, + collectObjects, collectStoredObjects, + + Head, HeadType(..), + HeadTypeID, mkHeadTypeID, + headId, headStorage, headRef, headObject, headStoredObject, + loadHeads, loadHead, reloadHead, + storeHead, replaceHead, updateHead, updateHead_, + loadHeadRaw, storeHeadRaw, replaceHeadRaw, + + WatchedHead, + watchHead, watchHeadWith, unwatchHead, + watchHeadRaw, + + MonadStorage(..), + + Storable(..), ZeroStorable(..), + StorableText(..), StorableDate(..), StorableUUID(..), + + Store, StoreRec, + evalStore, evalStoreObject, + storeBlob, storeRec, storeZero, + storeEmpty, storeInt, storeNum, storeText, storeBinary, storeDate, storeUUID, storeRef, storeRawRef, + storeMbEmpty, storeMbInt, storeMbNum, storeMbText, storeMbBinary, storeMbDate, storeMbUUID, storeMbRef, storeMbRawRef, + storeZRef, + + Load, LoadRec, + evalLoad, + loadCurrentRef, loadCurrentObject, + loadRecCurrentRef, loadRecItems, + + loadBlob, loadRec, loadZero, + loadEmpty, loadInt, loadNum, loadText, loadBinary, loadDate, loadUUID, loadRef, loadRawRef, + loadMbEmpty, loadMbInt, loadMbNum, loadMbText, loadMbBinary, loadMbDate, loadMbUUID, loadMbRef, loadMbRawRef, + loadTexts, loadBinaries, loadRefs, loadRawRefs, + loadZRef, + + Stored, + fromStored, storedRef, + wrappedStore, wrappedLoad, + copyStored, + unsafeMapStored, + + StoreInfo(..), makeStoreInfo, + + StoredHistory, + fromHistory, fromHistoryAt, storedFromHistory, storedHistoryList, + beginHistory, modifyHistory, +) where + +import Control.Applicative +import Control.Concurrent +import Control.Exception +import Control.Monad +import Control.Monad.Except +import Control.Monad.Reader +import Control.Monad.Writer + +import Crypto.Hash + +import Data.Bifunctor +import Data.ByteString (ByteString) +import qualified Data.ByteArray as BA +import qualified Data.ByteString as B +import qualified Data.ByteString.Char8 as BC +import qualified Data.ByteString.Lazy as BL +import qualified Data.ByteString.Lazy.Char8 as BLC +import Data.Char +import Data.Function +import qualified Data.HashTable.IO as HT +import Data.List +import qualified Data.Map as M +import Data.Maybe +import Data.Ratio +import Data.Set (Set) +import qualified Data.Set as S +import Data.Text (Text) +import qualified Data.Text as T +import Data.Text.Encoding +import Data.Text.Encoding.Error +import Data.Time.Calendar +import Data.Time.Clock +import Data.Time.Format +import Data.Time.LocalTime +import Data.Typeable +import Data.UUID (UUID) +import qualified Data.UUID as U +import qualified Data.UUID.V4 as U + +import System.Directory +import System.FSNotify +import System.FilePath +import System.IO.Error +import System.IO.Unsafe + +import Erebos.Storage.Internal + + +type Storage = Storage' Complete +type PartialStorage = Storage' Partial + +storageVersion :: String +storageVersion = "0.1" + +openStorage :: FilePath -> IO Storage +openStorage path = modifyIOError annotate $ do + let versionFileName = "erebos-storage" + let versionPath = path </> versionFileName + let writeVersionFile = writeFile versionPath $ storageVersion <> "\n" + + doesDirectoryExist path >>= \case + True -> do + listDirectory path >>= \case + files@(_:_) + | versionFileName `elem` files -> do + readFile versionPath >>= \case + content | (ver:_) <- lines content, ver == storageVersion -> return () + | otherwise -> fail "unsupported storage version" + + | "objects" `notElem` files || "heads" `notElem` files -> do + fail "directory is neither empty, nor an existing erebos storage" + + _ -> writeVersionFile + False -> do + createDirectoryIfMissing True $ path + writeVersionFile + + createDirectoryIfMissing True $ path </> "objects" + createDirectoryIfMissing True $ path </> "heads" + watchers <- newMVar (Nothing, [], WatchList 1 []) + refgen <- newMVar =<< HT.new + refroots <- newMVar =<< HT.new + return $ Storage + { stBacking = StorageDir path watchers + , stParent = Nothing + , stRefGeneration = refgen + , stRefRoots = refroots + } + where + annotate e = annotateIOError e "failed to open storage" Nothing (Just path) + +memoryStorage' :: IO (Storage' c') +memoryStorage' = do + backing <- StorageMemory <$> newMVar [] <*> newMVar M.empty <*> newMVar M.empty <*> newMVar (WatchList 1 []) + refgen <- newMVar =<< HT.new + refroots <- newMVar =<< HT.new + return $ Storage + { stBacking = backing + , stParent = Nothing + , stRefGeneration = refgen + , stRefRoots = refroots + } + +memoryStorage :: IO Storage +memoryStorage = memoryStorage' + +deriveEphemeralStorage :: Storage -> IO Storage +deriveEphemeralStorage parent = do + st <- memoryStorage + return $ st { stParent = Just parent } + +derivePartialStorage :: Storage -> IO PartialStorage +derivePartialStorage parent = do + st <- memoryStorage' + return $ st { stParent = Just parent } + +type Ref = Ref' Complete +type PartialRef = Ref' Partial + +zeroRef :: Storage' c -> Ref' c +zeroRef s = Ref s (RefDigest h) + where h = case digestFromByteString $ B.replicate (hashDigestSize $ digestAlgo h) 0 of + Nothing -> error $ "Failed to create zero hash" + Just h' -> h' + digestAlgo :: Digest a -> a + digestAlgo = undefined + +isZeroRef :: Ref' c -> Bool +isZeroRef (Ref _ h) = all (==0) $ BA.unpack h + + +refFromDigest :: Storage' c -> RefDigest -> IO (Maybe (Ref' c)) +refFromDigest st dgst = fmap (const $ Ref st dgst) <$> ioLoadBytesFromStorage st dgst + +readRef :: Storage -> ByteString -> IO (Maybe Ref) +readRef s b = + case readRefDigest b of + Nothing -> return Nothing + Just dgst -> refFromDigest s dgst + +copyRef' :: forall c c'. (StorageCompleteness c, StorageCompleteness c') => Storage' c' -> Ref' c -> IO (c (Ref' c')) +copyRef' st ref'@(Ref _ dgst) = refFromDigest st dgst >>= \case Just ref -> return $ return ref + Nothing -> doCopy + where doCopy = do mbobj' <- ioLoadObject ref' + mbobj <- sequence $ copyObject' st <$> mbobj' + sequence $ unsafeStoreObject st <$> join mbobj + +copyObject' :: forall c c'. (StorageCompleteness c, StorageCompleteness c') => Storage' c' -> Object' c -> IO (c (Object' c')) +copyObject' _ (Blob bs) = return $ return $ Blob bs +copyObject' st (Rec rs) = fmap Rec . sequence <$> mapM copyItem rs + where copyItem :: (ByteString, RecItem' c) -> IO (c (ByteString, RecItem' c')) + copyItem (n, item) = fmap (n,) <$> case item of + RecEmpty -> return $ return $ RecEmpty + RecInt x -> return $ return $ RecInt x + RecNum x -> return $ return $ RecNum x + RecText x -> return $ return $ RecText x + RecBinary x -> return $ return $ RecBinary x + RecDate x -> return $ return $ RecDate x + RecUUID x -> return $ return $ RecUUID x + RecRef x -> fmap RecRef <$> copyRef' st x +copyObject' _ ZeroObject = return $ return ZeroObject + +copyRef :: forall c c' m. (StorageCompleteness c, StorageCompleteness c', MonadIO m) => Storage' c' -> Ref' c -> m (LoadResult c (Ref' c')) +copyRef st ref' = liftIO $ returnLoadResult <$> copyRef' st ref' + +copyObject :: forall c c'. (StorageCompleteness c, StorageCompleteness c') => Storage' c' -> Object' c -> IO (LoadResult c (Object' c')) +copyObject st obj' = returnLoadResult <$> copyObject' st obj' + +partialRef :: PartialStorage -> Ref -> PartialRef +partialRef st (Ref _ dgst) = Ref st dgst + +partialRefFromDigest :: PartialStorage -> RefDigest -> PartialRef +partialRefFromDigest st dgst = Ref st dgst + + +data Object' c + = Blob ByteString + | Rec [(ByteString, RecItem' c)] + | ZeroObject + deriving (Show) + +type Object = Object' Complete +type PartialObject = Object' Partial + +data RecItem' c + = RecEmpty + | RecInt Integer + | RecNum Rational + | RecText Text + | RecBinary ByteString + | RecDate ZonedTime + | RecUUID UUID + | RecRef (Ref' c) + deriving (Show) + +type RecItem = RecItem' Complete + +serializeObject :: Object' c -> BL.ByteString +serializeObject = \case + Blob cnt -> BL.fromChunks [BC.pack "blob ", BC.pack (show $ B.length cnt), BC.singleton '\n', cnt] + Rec rec -> let cnt = BL.fromChunks $ concatMap (uncurry serializeRecItem) rec + in BL.fromChunks [BC.pack "rec ", BC.pack (show $ BL.length cnt), BC.singleton '\n'] `BL.append` cnt + ZeroObject -> BL.empty + +-- |Serializes and stores object data without ony dependencies, so is safe only +-- if all the referenced objects are already stored or reference is partial. +unsafeStoreObject :: Storage' c -> Object' c -> IO (Ref' c) +unsafeStoreObject storage = \case + ZeroObject -> return $ zeroRef storage + obj -> unsafeStoreRawBytes storage $ serializeObject obj + +storeObject :: PartialStorage -> PartialObject -> IO PartialRef +storeObject = unsafeStoreObject + +storeRawBytes :: PartialStorage -> BL.ByteString -> IO PartialRef +storeRawBytes = unsafeStoreRawBytes + +serializeRecItem :: ByteString -> RecItem' c -> [ByteString] +serializeRecItem name (RecEmpty) = [name, BC.pack ":e", BC.singleton ' ', BC.singleton '\n'] +serializeRecItem name (RecInt x) = [name, BC.pack ":i", BC.singleton ' ', BC.pack (show x), BC.singleton '\n'] +serializeRecItem name (RecNum x) = [name, BC.pack ":n", BC.singleton ' ', BC.pack (showRatio x), BC.singleton '\n'] +serializeRecItem name (RecText x) = [name, BC.pack ":t", BC.singleton ' ', escaped, BC.singleton '\n'] + where escaped = BC.concatMap escape $ encodeUtf8 x + escape '\n' = BC.pack "\n\t" + escape c = BC.singleton c +serializeRecItem name (RecBinary x) = [name, BC.pack ":b ", showHex x, BC.singleton '\n'] +serializeRecItem name (RecDate x) = [name, BC.pack ":d", BC.singleton ' ', BC.pack (formatTime defaultTimeLocale "%s %z" x), BC.singleton '\n'] +serializeRecItem name (RecUUID x) = [name, BC.pack ":u", BC.singleton ' ', U.toASCIIBytes x, BC.singleton '\n'] +serializeRecItem name (RecRef x) = [name, BC.pack ":r ", showRef x, BC.singleton '\n'] + +lazyLoadObject :: forall c. StorageCompleteness c => Ref' c -> LoadResult c (Object' c) +lazyLoadObject = returnLoadResult . unsafePerformIO . ioLoadObject + +ioLoadObject :: forall c. StorageCompleteness c => Ref' c -> IO (c (Object' c)) +ioLoadObject ref | isZeroRef ref = return $ return ZeroObject +ioLoadObject ref@(Ref st rhash) = do + file' <- ioLoadBytes ref + return $ do + file <- file' + let chash = hashToRefDigest file + when (chash /= rhash) $ error $ "Hash mismatch on object " ++ BC.unpack (showRef ref) {- TODO throw -} + return $ case runExcept $ unsafeDeserializeObject st file of + Left err -> error $ err ++ ", ref " ++ BC.unpack (showRef ref) {- TODO throw -} + Right (x, rest) | BL.null rest -> x + | otherwise -> error $ "Superfluous content after " ++ BC.unpack (showRef ref) {- TODO throw -} + +lazyLoadBytes :: forall c. StorageCompleteness c => Ref' c -> LoadResult c BL.ByteString +lazyLoadBytes ref | isZeroRef ref = returnLoadResult (return BL.empty :: c BL.ByteString) +lazyLoadBytes ref = returnLoadResult $ unsafePerformIO $ ioLoadBytes ref + +unsafeDeserializeObject :: Storage' c -> BL.ByteString -> Except String (Object' c, BL.ByteString) +unsafeDeserializeObject _ bytes | BL.null bytes = return (ZeroObject, bytes) +unsafeDeserializeObject st bytes = + case BLC.break (=='\n') bytes of + (line, rest) | Just (otype, len) <- splitObjPrefix line -> do + let (content, next) = first BL.toStrict $ BL.splitAt (fromIntegral len) $ BL.drop 1 rest + guard $ B.length content == len + (,next) <$> case otype of + _ | otype == BC.pack "blob" -> return $ Blob content + | otype == BC.pack "rec" -> maybe (throwError $ "Malformed record item ") + (return . Rec) $ sequence $ map parseRecLine $ mergeCont [] $ BC.lines content + | otherwise -> throwError $ "Unknown object type" + _ -> throwError $ "Malformed object" + where splitObjPrefix line = do + [otype, tlen] <- return $ BLC.words line + (len, rest) <- BLC.readInt tlen + guard $ BL.null rest + return (BL.toStrict otype, len) + + mergeCont cs (a:b:rest) | Just ('\t', b') <- BC.uncons b = mergeCont (b':BC.pack "\n":cs) (a:rest) + mergeCont cs (a:rest) = B.concat (a : reverse cs) : mergeCont [] rest + mergeCont _ [] = [] + + parseRecLine line = do + colon <- BC.elemIndex ':' line + space <- BC.elemIndex ' ' line + guard $ colon < space + let name = B.take colon line + itype = B.take (space-colon-1) $ B.drop (colon+1) line + content = B.drop (space+1) line + + val <- case BC.unpack itype of + "e" -> do guard $ B.null content + return RecEmpty + "i" -> do (num, rest) <- BC.readInteger content + guard $ B.null rest + return $ RecInt num + "n" -> RecNum <$> parseRatio content + "t" -> return $ RecText $ decodeUtf8With lenientDecode content + "b" -> RecBinary <$> readHex content + "d" -> RecDate <$> parseTimeM False defaultTimeLocale "%s %z" (BC.unpack content) + "u" -> RecUUID <$> U.fromASCIIBytes content + "r" -> RecRef . Ref st <$> readRefDigest content + _ -> Nothing + return (name, val) + +deserializeObject :: PartialStorage -> BL.ByteString -> Except String (PartialObject, BL.ByteString) +deserializeObject = unsafeDeserializeObject + +deserializeObjects :: PartialStorage -> BL.ByteString -> Except String [PartialObject] +deserializeObjects _ bytes | BL.null bytes = return [] +deserializeObjects st bytes = do (obj, rest) <- deserializeObject st bytes + (obj:) <$> deserializeObjects st rest + + +collectObjects :: Object -> [Object] +collectObjects obj = obj : map fromStored (fst $ collectOtherStored S.empty obj) + +collectStoredObjects :: Stored Object -> [Stored Object] +collectStoredObjects obj = obj : (fst $ collectOtherStored S.empty $ fromStored obj) + +collectOtherStored :: Set RefDigest -> Object -> ([Stored Object], Set RefDigest) +collectOtherStored seen (Rec items) = foldr helper ([], seen) $ map snd items + where helper (RecRef ref) (xs, s) | r <- refDigest ref + , r `S.notMember` s + = let o = wrappedLoad ref + (xs', s') = collectOtherStored (S.insert r s) $ fromStored o + in ((o : xs') ++ xs, s') + helper _ (xs, s) = (xs, s) +collectOtherStored seen _ = ([], seen) + + +type Head = Head' Complete + +headId :: Head a -> HeadID +headId (Head uuid _) = uuid + +headStorage :: Head a -> Storage +headStorage = refStorage . headRef + +headRef :: Head a -> Ref +headRef (Head _ sx) = storedRef sx + +headObject :: Head a -> a +headObject (Head _ sx) = fromStored sx + +headStoredObject :: Head a -> Stored a +headStoredObject (Head _ sx) = sx + +deriving instance StorableUUID HeadID +deriving instance StorableUUID HeadTypeID + +mkHeadTypeID :: String -> HeadTypeID +mkHeadTypeID = maybe (error "Invalid head type ID") HeadTypeID . U.fromString + +class Storable a => HeadType a where + headTypeID :: proxy a -> HeadTypeID + + +headTypePath :: FilePath -> HeadTypeID -> FilePath +headTypePath spath (HeadTypeID tid) = spath </> "heads" </> U.toString tid + +headPath :: FilePath -> HeadTypeID -> HeadID -> FilePath +headPath spath tid (HeadID hid) = headTypePath spath tid </> U.toString hid + +loadHeads :: forall a m. MonadIO m => HeadType a => Storage -> m [Head a] +loadHeads s@(Storage { stBacking = StorageDir { dirPath = spath }}) = liftIO $ do + let hpath = headTypePath spath $ headTypeID @a Proxy + + files <- filterM (doesFileExist . (hpath </>)) =<< + handleJust (\e -> guard (isDoesNotExistError e)) (const $ return []) + (getDirectoryContents hpath) + fmap catMaybes $ forM files $ \hname -> do + case U.fromString hname of + Just hid -> do + (h:_) <- BC.lines <$> B.readFile (hpath </> hname) + Just ref <- readRef s h + return $ Just $ Head (HeadID hid) $ wrappedLoad ref + Nothing -> return Nothing +loadHeads Storage { stBacking = StorageMemory { memHeads = theads } } = liftIO $ do + let toHead ((tid, hid), ref) | tid == headTypeID @a Proxy = Just $ Head hid $ wrappedLoad ref + | otherwise = Nothing + catMaybes . map toHead <$> readMVar theads + +loadHead :: forall a m. (HeadType a, MonadIO m) => Storage -> HeadID -> m (Maybe (Head a)) +loadHead st hid = fmap (Head hid . wrappedLoad) <$> loadHeadRaw st (headTypeID @a Proxy) hid + +loadHeadRaw :: forall m. MonadIO m => Storage -> HeadTypeID -> HeadID -> m (Maybe Ref) +loadHeadRaw s@(Storage { stBacking = StorageDir { dirPath = spath }}) tid hid = liftIO $ do + handleJust (guard . isDoesNotExistError) (const $ return Nothing) $ do + (h:_) <- BC.lines <$> B.readFile (headPath spath tid hid) + Just ref <- readRef s h + return $ Just ref +loadHeadRaw Storage { stBacking = StorageMemory { memHeads = theads } } tid hid = liftIO $ do + lookup (tid, hid) <$> readMVar theads + +reloadHead :: (HeadType a, MonadIO m) => Head a -> m (Maybe (Head a)) +reloadHead (Head hid (Stored (Ref st _) _)) = loadHead st hid + +storeHead :: forall a m. MonadIO m => HeadType a => Storage -> a -> m (Head a) +storeHead st obj = do + let tid = headTypeID @a Proxy + stored <- wrappedStore st obj + hid <- storeHeadRaw st tid (storedRef stored) + return $ Head hid stored + +storeHeadRaw :: forall m. MonadIO m => Storage -> HeadTypeID -> Ref -> m HeadID +storeHeadRaw st tid ref = liftIO $ do + hid <- HeadID <$> U.nextRandom + case stBacking st of + StorageDir { dirPath = spath } -> do + Right () <- writeFileChecked (headPath spath tid hid) Nothing $ + showRef ref `B.append` BC.singleton '\n' + return () + StorageMemory { memHeads = theads } -> do + modifyMVar_ theads $ return . (((tid, hid), ref) :) + return hid + +replaceHead :: forall a m. (HeadType a, MonadIO m) => Head a -> Stored a -> m (Either (Maybe (Head a)) (Head a)) +replaceHead prev@(Head hid pobj) stored' = liftIO $ do + let st = headStorage prev + tid = headTypeID @a Proxy + stored <- copyStored st stored' + bimap (fmap $ Head hid . wrappedLoad) (const $ Head hid stored) <$> + replaceHeadRaw st tid hid (storedRef pobj) (storedRef stored) + +replaceHeadRaw :: forall m. MonadIO m => Storage -> HeadTypeID -> HeadID -> Ref -> Ref -> m (Either (Maybe Ref) Ref) +replaceHeadRaw st tid hid prev new = liftIO $ do + case stBacking st of + StorageDir { dirPath = spath } -> do + let filename = headPath spath tid hid + showRefL r = showRef r `B.append` BC.singleton '\n' + + writeFileChecked filename (Just $ showRefL prev) (showRefL new) >>= \case + Left Nothing -> return $ Left Nothing + Left (Just bs) -> do Just oref <- readRef st $ BC.takeWhile (/='\n') bs + return $ Left $ Just oref + Right () -> return $ Right new + + StorageMemory { memHeads = theads, memWatchers = twatch } -> do + res <- modifyMVar theads $ \hs -> do + ws <- map wlFun . filter ((==(tid, hid)) . wlHead) . wlList <$> readMVar twatch + return $ case partition ((==(tid, hid)) . fst) hs of + ([] , _ ) -> (hs, Left Nothing) + ((_, r):_, hs') | r == prev -> (((tid, hid), new) : hs', + Right (new, ws)) + | otherwise -> (hs, Left $ Just r) + case res of + Right (r, ws) -> mapM_ ($ r) ws >> return (Right r) + Left x -> return $ Left x + +updateHead :: (HeadType a, MonadIO m) => Head a -> (Stored a -> m (Stored a, b)) -> m (Maybe (Head a), b) +updateHead h f = do + (o, x) <- f $ headStoredObject h + replaceHead h o >>= \case + Right h' -> return (Just h', x) + Left Nothing -> return (Nothing, x) + Left (Just h') -> updateHead h' f + +updateHead_ :: (HeadType a, MonadIO m) => Head a -> (Stored a -> m (Stored a)) -> m (Maybe (Head a)) +updateHead_ h = fmap fst . updateHead h . (fmap (,()) .) + + +data WatchedHead = forall a. WatchedHead Storage WatchID (MVar a) + +watchHead :: forall a. HeadType a => Head a -> (Head a -> IO ()) -> IO WatchedHead +watchHead h = watchHeadWith h id + +watchHeadWith :: forall a b. (HeadType a, Eq b) => Head a -> (Head a -> b) -> (b -> IO ()) -> IO WatchedHead +watchHeadWith (Head hid (Stored (Ref st _) _)) sel cb = do + watchHeadRaw st (headTypeID @a Proxy) hid (sel . Head hid . wrappedLoad) cb + +watchHeadRaw :: forall b. Eq b => Storage -> HeadTypeID -> HeadID -> (Ref -> b) -> (b -> IO ()) -> IO WatchedHead +watchHeadRaw st tid hid sel cb = do + memo <- newEmptyMVar + let addWatcher wl = (wl', WatchedHead st (wlNext wl) memo) + where wl' = wl { wlNext = wlNext wl + 1 + , wlList = WatchListItem + { wlID = wlNext wl + , wlHead = (tid, hid) + , wlFun = \r -> do + let x = sel r + modifyMVar_ memo $ \prev -> do + when (Just x /= prev) $ cb x + return $ Just x + } : wlList wl + } + + watched <- case stBacking st of + StorageDir { dirPath = spath, dirWatchers = mvar } -> modifyMVar mvar $ \(mbmanager, ilist, wl) -> do + manager <- maybe startManager return mbmanager + ilist' <- case tid `elem` ilist of + True -> return ilist + False -> do + void $ watchDir manager (headTypePath spath tid) (const True) $ \case + Added { eventPath = fpath } | Just ihid <- HeadID <$> U.fromString (takeFileName fpath) -> do + loadHeadRaw st tid ihid >>= \case + Just ref -> do + (_, _, iwl) <- readMVar mvar + mapM_ ($ ref) . map wlFun . filter ((== (tid, ihid)) . wlHead) . wlList $ iwl + Nothing -> return () + _ -> return () + return $ tid : ilist + return $ first ( Just manager, ilist', ) $ addWatcher wl + + StorageMemory { memWatchers = mvar } -> modifyMVar mvar $ return . addWatcher + + cur <- fmap sel <$> loadHeadRaw st tid hid + maybe (return ()) cb cur + putMVar memo cur + + return watched + +unwatchHead :: WatchedHead -> IO () +unwatchHead (WatchedHead st wid _) = do + let delWatcher wl = wl { wlList = filter ((/=wid) . wlID) $ wlList wl } + case stBacking st of + StorageDir { dirWatchers = mvar } -> modifyMVar_ mvar $ return . second delWatcher + StorageMemory { memWatchers = mvar } -> modifyMVar_ mvar $ return . delWatcher + + +class Monad m => MonadStorage m where + getStorage :: m Storage + mstore :: Storable a => a -> m (Stored a) + + default mstore :: MonadIO m => Storable a => a -> m (Stored a) + mstore x = do + st <- getStorage + wrappedStore st x + +instance MonadIO m => MonadStorage (ReaderT Storage m) where + getStorage = ask + +instance MonadIO m => MonadStorage (ReaderT (Head a) m) where + getStorage = asks $ headStorage + + +class Storable a where + store' :: a -> Store + load' :: Load a + + store :: StorageCompleteness c => Storage' c -> a -> IO (Ref' c) + store st = evalStore st . store' + load :: Ref -> a + load = evalLoad load' + +class Storable a => ZeroStorable a where + fromZero :: Storage -> a + +data Store = StoreBlob ByteString + | StoreRec (forall c. StorageCompleteness c => Storage' c -> [IO [(ByteString, RecItem' c)]]) + | StoreZero + +evalStore :: StorageCompleteness c => Storage' c -> Store -> IO (Ref' c) +evalStore st = unsafeStoreObject st <=< evalStoreObject st + +evalStoreObject :: StorageCompleteness c => Storage' c -> Store -> IO (Object' c) +evalStoreObject _ (StoreBlob x) = return $ Blob x +evalStoreObject s (StoreRec f) = Rec . concat <$> sequence (f s) +evalStoreObject _ StoreZero = return ZeroObject + +newtype StoreRecM c a = StoreRecM (ReaderT (Storage' c) (Writer [IO [(ByteString, RecItem' c)]]) a) + deriving (Functor, Applicative, Monad) + +type StoreRec c = StoreRecM c () + +newtype Load a = Load (ReaderT (Ref, Object) (Except String) a) + deriving (Functor, Applicative, Alternative, Monad, MonadPlus, MonadError String) + +evalLoad :: Load a -> Ref -> a +evalLoad (Load f) ref = either (error {- TODO throw -} . ((BC.unpack (showRef ref) ++ ": ")++)) id $ runExcept $ runReaderT f (ref, lazyLoadObject ref) + +loadCurrentRef :: Load Ref +loadCurrentRef = Load $ asks fst + +loadCurrentObject :: Load Object +loadCurrentObject = Load $ asks snd + +newtype LoadRec a = LoadRec (ReaderT (Ref, [(ByteString, RecItem)]) (Except String) a) + deriving (Functor, Applicative, Alternative, Monad, MonadPlus, MonadError String) + +loadRecCurrentRef :: LoadRec Ref +loadRecCurrentRef = LoadRec $ asks fst + +loadRecItems :: LoadRec [(ByteString, RecItem)] +loadRecItems = LoadRec $ asks snd + + +instance Storable Object where + store' (Blob bs) = StoreBlob bs + store' (Rec xs) = StoreRec $ \st -> return $ do + Rec xs' <- copyObject st (Rec xs) + return xs' + store' ZeroObject = StoreZero + + load' = loadCurrentObject + + store st = unsafeStoreObject st <=< copyObject st + load = lazyLoadObject + +instance Storable ByteString where + store' = storeBlob + load' = loadBlob id + +instance Storable a => Storable [a] where + store' [] = storeZero + store' (x:xs) = storeRec $ do + storeRef "i" x + storeRef "n" xs + + load' = loadCurrentObject >>= \case + ZeroObject -> return [] + _ -> loadRec $ (:) + <$> loadRef "i" + <*> loadRef "n" + +instance Storable a => ZeroStorable [a] where + fromZero _ = [] + + +storeBlob :: ByteString -> Store +storeBlob = StoreBlob + +storeRec :: (forall c. StorageCompleteness c => StoreRec c) -> Store +storeRec sr = StoreRec $ do + let StoreRecM r = sr + execWriter . runReaderT r + +storeZero :: Store +storeZero = StoreZero + + +class StorableText a where + toText :: a -> Text + fromText :: MonadError String m => Text -> m a + +instance StorableText Text where + toText = id; fromText = return + +instance StorableText [Char] where + toText = T.pack; fromText = return . T.unpack + + +class StorableDate a where + toDate :: a -> ZonedTime + fromDate :: ZonedTime -> a + +instance StorableDate ZonedTime where + toDate = id; fromDate = id + +instance StorableDate UTCTime where + toDate = utcToZonedTime utc + fromDate = zonedTimeToUTC + +instance StorableDate Day where + toDate day = toDate $ UTCTime day 0 + fromDate = utctDay . fromDate + + +class StorableUUID a where + toUUID :: a -> UUID + fromUUID :: UUID -> a + +instance StorableUUID UUID where + toUUID = id; fromUUID = id + + +storeEmpty :: String -> StoreRec c +storeEmpty name = StoreRecM $ tell [return [(BC.pack name, RecEmpty)]] + +storeMbEmpty :: String -> Maybe () -> StoreRec c +storeMbEmpty name = maybe (return ()) (const $ storeEmpty name) + +storeInt :: Integral a => String -> a -> StoreRec c +storeInt name x = StoreRecM $ tell [return [(BC.pack name, RecInt $ toInteger x)]] + +storeMbInt :: Integral a => String -> Maybe a -> StoreRec c +storeMbInt name = maybe (return ()) (storeInt name) + +storeNum :: (Real a, Fractional a) => String -> a -> StoreRec c +storeNum name x = StoreRecM $ tell [return [(BC.pack name, RecNum $ toRational x)]] + +storeMbNum :: (Real a, Fractional a) => String -> Maybe a -> StoreRec c +storeMbNum name = maybe (return ()) (storeNum name) + +storeText :: StorableText a => String -> a -> StoreRec c +storeText name x = StoreRecM $ tell [return [(BC.pack name, RecText $ toText x)]] + +storeMbText :: StorableText a => String -> Maybe a -> StoreRec c +storeMbText name = maybe (return ()) (storeText name) + +storeBinary :: BA.ByteArrayAccess a => String -> a -> StoreRec c +storeBinary name x = StoreRecM $ tell [return [(BC.pack name, RecBinary $ BA.convert x)]] + +storeMbBinary :: BA.ByteArrayAccess a => String -> Maybe a -> StoreRec c +storeMbBinary name = maybe (return ()) (storeBinary name) + +storeDate :: StorableDate a => String -> a -> StoreRec c +storeDate name x = StoreRecM $ tell [return [(BC.pack name, RecDate $ toDate x)]] + +storeMbDate :: StorableDate a => String -> Maybe a -> StoreRec c +storeMbDate name = maybe (return ()) (storeDate name) + +storeUUID :: StorableUUID a => String -> a -> StoreRec c +storeUUID name x = StoreRecM $ tell [return [(BC.pack name, RecUUID $ toUUID x)]] + +storeMbUUID :: StorableUUID a => String -> Maybe a -> StoreRec c +storeMbUUID name = maybe (return ()) (storeUUID name) + +storeRef :: Storable a => StorageCompleteness c => String -> a -> StoreRec c +storeRef name x = StoreRecM $ do + s <- ask + tell $ (:[]) $ do + ref <- store s x + return [(BC.pack name, RecRef ref)] + +storeMbRef :: Storable a => StorageCompleteness c => String -> Maybe a -> StoreRec c +storeMbRef name = maybe (return ()) (storeRef name) + +storeRawRef :: StorageCompleteness c => String -> Ref -> StoreRec c +storeRawRef name ref = StoreRecM $ do + st <- ask + tell $ (:[]) $ do + ref' <- copyRef st ref + return [(BC.pack name, RecRef ref')] + +storeMbRawRef :: StorageCompleteness c => String -> Maybe Ref -> StoreRec c +storeMbRawRef name = maybe (return ()) (storeRawRef name) + +storeZRef :: (ZeroStorable a, StorageCompleteness c) => String -> a -> StoreRec c +storeZRef name x = StoreRecM $ do + s <- ask + tell $ (:[]) $ do + ref <- store s x + return $ if isZeroRef ref then [] + else [(BC.pack name, RecRef ref)] + + +loadBlob :: (ByteString -> a) -> Load a +loadBlob f = loadCurrentObject >>= \case + Blob x -> return $ f x + _ -> throwError "Expecting blob" + +loadRec :: LoadRec a -> Load a +loadRec (LoadRec lrec) = loadCurrentObject >>= \case + Rec rs -> do + ref <- loadCurrentRef + either throwError return $ runExcept $ runReaderT lrec (ref, rs) + _ -> throwError "Expecting record" + +loadZero :: a -> Load a +loadZero x = loadCurrentObject >>= \case + ZeroObject -> return x + _ -> throwError "Expecting zero" + + +loadEmpty :: String -> LoadRec () +loadEmpty name = maybe (throwError $ "Missing record item '"++name++"'") return =<< loadMbEmpty name + +loadMbEmpty :: String -> LoadRec (Maybe ()) +loadMbEmpty name = (lookup (BC.pack name) <$> loadRecItems) >>= \case + Nothing -> return Nothing + Just (RecEmpty) -> return (Just ()) + Just _ -> throwError $ "Expecting type int of record item '"++name++"'" + +loadInt :: Num a => String -> LoadRec a +loadInt name = maybe (throwError $ "Missing record item '"++name++"'") return =<< loadMbInt name + +loadMbInt :: Num a => String -> LoadRec (Maybe a) +loadMbInt name = (lookup (BC.pack name) <$> loadRecItems) >>= \case + Nothing -> return Nothing + Just (RecInt x) -> return (Just $ fromInteger x) + Just _ -> throwError $ "Expecting type int of record item '"++name++"'" + +loadNum :: (Real a, Fractional a) => String -> LoadRec a +loadNum name = maybe (throwError $ "Missing record item '"++name++"'") return =<< loadMbNum name + +loadMbNum :: (Real a, Fractional a) => String -> LoadRec (Maybe a) +loadMbNum name = (lookup (BC.pack name) <$> loadRecItems) >>= \case + Nothing -> return Nothing + Just (RecNum x) -> return (Just $ fromRational x) + Just _ -> throwError $ "Expecting type number of record item '"++name++"'" + +loadText :: StorableText a => String -> LoadRec a +loadText name = maybe (throwError $ "Missing record item '"++name++"'") return =<< loadMbText name + +loadMbText :: StorableText a => String -> LoadRec (Maybe a) +loadMbText name = (lookup (BC.pack name) <$> loadRecItems) >>= \case + Nothing -> return Nothing + Just (RecText x) -> Just <$> fromText x + Just _ -> throwError $ "Expecting type text of record item '"++name++"'" + +loadTexts :: StorableText a => String -> LoadRec [a] +loadTexts name = do + items <- map snd . filter ((BC.pack name ==) . fst) <$> loadRecItems + forM items $ \case RecText x -> fromText x + _ -> throwError $ "Expecting type text of record item '"++name++"'" + +loadBinary :: BA.ByteArray a => String -> LoadRec a +loadBinary name = maybe (throwError $ "Missing record item '"++name++"'") return =<< loadMbBinary name + +loadMbBinary :: BA.ByteArray a => String -> LoadRec (Maybe a) +loadMbBinary name = (lookup (BC.pack name) <$> loadRecItems) >>= \case + Nothing -> return Nothing + Just (RecBinary x) -> return $ Just $ BA.convert x + Just _ -> throwError $ "Expecting type binary of record item '"++name++"'" + +loadBinaries :: BA.ByteArray a => String -> LoadRec [a] +loadBinaries name = do + items <- map snd . filter ((BC.pack name ==) . fst) <$> loadRecItems + forM items $ \case RecBinary x -> return $ BA.convert x + _ -> throwError $ "Expecting type binary of record item '"++name++"'" + +loadDate :: StorableDate a => String -> LoadRec a +loadDate name = maybe (throwError $ "Missing record item '"++name++"'") return =<< loadMbDate name + +loadMbDate :: StorableDate a => String -> LoadRec (Maybe a) +loadMbDate name = (lookup (BC.pack name) <$> loadRecItems) >>= \case + Nothing -> return Nothing + Just (RecDate x) -> return $ Just $ fromDate x + Just _ -> throwError $ "Expecting type date of record item '"++name++"'" + +loadUUID :: StorableUUID a => String -> LoadRec a +loadUUID name = maybe (throwError $ "Missing record iteem '"++name++"'") return =<< loadMbUUID name + +loadMbUUID :: StorableUUID a => String -> LoadRec (Maybe a) +loadMbUUID name = (lookup (BC.pack name) <$> loadRecItems) >>= \case + Nothing -> return Nothing + Just (RecUUID x) -> return $ Just $ fromUUID x + Just _ -> throwError $ "Expecting type UUID of record item '"++name++"'" + +loadRawRef :: String -> LoadRec Ref +loadRawRef name = maybe (throwError $ "Missing record item '"++name++"'") return =<< loadMbRawRef name + +loadMbRawRef :: String -> LoadRec (Maybe Ref) +loadMbRawRef name = (lookup (BC.pack name) <$> loadRecItems) >>= \case + Nothing -> return Nothing + Just (RecRef x) -> return (Just x) + Just _ -> throwError $ "Expecting type ref of record item '"++name++"'" + +loadRawRefs :: String -> LoadRec [Ref] +loadRawRefs name = do + items <- map snd . filter ((BC.pack name ==) . fst) <$> loadRecItems + forM items $ \case RecRef x -> return x + _ -> throwError $ "Expecting type ref of record item '"++name++"'" + +loadRef :: Storable a => String -> LoadRec a +loadRef name = load <$> loadRawRef name + +loadMbRef :: Storable a => String -> LoadRec (Maybe a) +loadMbRef name = fmap load <$> loadMbRawRef name + +loadRefs :: Storable a => String -> LoadRec [a] +loadRefs name = map load <$> loadRawRefs name + +loadZRef :: ZeroStorable a => String -> LoadRec a +loadZRef name = loadMbRef name >>= \case + Nothing -> do Ref st _ <- loadRecCurrentRef + return $ fromZero st + Just x -> return x + + +type Stored a = Stored' Complete a + +instance Storable a => Storable (Stored a) where + store st = copyRef st . storedRef + store' (Stored _ x) = store' x + load' = Stored <$> loadCurrentRef <*> load' + +instance ZeroStorable a => ZeroStorable (Stored a) where + fromZero st = Stored (zeroRef st) $ fromZero st + +fromStored :: Stored a -> a +fromStored (Stored _ x) = x + +storedRef :: Stored a -> Ref +storedRef (Stored ref _) = ref + +wrappedStore :: MonadIO m => Storable a => Storage -> a -> m (Stored a) +wrappedStore st x = do ref <- liftIO $ store st x + return $ Stored ref x + +wrappedLoad :: Storable a => Ref -> Stored a +wrappedLoad ref = Stored ref (load ref) + +copyStored :: forall c c' m a. (StorageCompleteness c, StorageCompleteness c', MonadIO m) => + Storage' c' -> Stored' c a -> m (LoadResult c (Stored' c' a)) +copyStored st (Stored ref' x) = liftIO $ returnLoadResult . fmap (flip Stored x) <$> copyRef' st ref' + +-- |Passed function needs to preserve the object representation to be safe +unsafeMapStored :: (a -> b) -> Stored a -> Stored b +unsafeMapStored f (Stored ref x) = Stored ref (f x) + + +data StoreInfo = StoreInfo + { infoDate :: ZonedTime + , infoNote :: Maybe Text + } + deriving (Show) + +makeStoreInfo :: IO StoreInfo +makeStoreInfo = StoreInfo + <$> getZonedTime + <*> pure Nothing + +storeInfoRec :: StoreInfo -> StoreRec c +storeInfoRec info = do + storeDate "date" $ infoDate info + storeMbText "note" $ infoNote info + +loadInfoRec :: LoadRec StoreInfo +loadInfoRec = StoreInfo + <$> loadDate "date" + <*> loadMbText "note" + + +data History a = History StoreInfo (Stored a) (Maybe (StoredHistory a)) + deriving (Show) + +type StoredHistory a = Stored (History a) + +instance Storable a => Storable (History a) where + store' (History si x prev) = storeRec $ do + storeInfoRec si + storeMbRef "prev" prev + storeRef "item" x + + load' = loadRec $ History + <$> loadInfoRec + <*> loadRef "item" + <*> loadMbRef "prev" + +fromHistory :: StoredHistory a -> a +fromHistory = fromStored . storedFromHistory + +fromHistoryAt :: ZonedTime -> StoredHistory a -> Maybe a +fromHistoryAt zat = fmap (fromStored . snd) . listToMaybe . dropWhile ((at<) . zonedTimeToUTC . fst) . storedHistoryTimedList + where at = zonedTimeToUTC zat + +storedFromHistory :: StoredHistory a -> Stored a +storedFromHistory sh = let History _ item _ = fromStored sh + in item + +storedHistoryList :: StoredHistory a -> [Stored a] +storedHistoryList = map snd . storedHistoryTimedList + +storedHistoryTimedList :: StoredHistory a -> [(ZonedTime, Stored a)] +storedHistoryTimedList sh = let History hinfo item prev = fromStored sh + in (infoDate hinfo, item) : maybe [] storedHistoryTimedList prev + +beginHistory :: Storable a => Storage -> StoreInfo -> a -> IO (StoredHistory a) +beginHistory st si x = do sx <- wrappedStore st x + wrappedStore st $ History si sx Nothing + +modifyHistory :: Storable a => StoreInfo -> (a -> a) -> StoredHistory a -> IO (StoredHistory a) +modifyHistory si f prev@(Stored (Ref st _) _) = do + sx <- wrappedStore st $ f $ fromHistory prev + wrappedStore st $ History si sx (Just prev) + + +showRatio :: Rational -> String +showRatio r = case decimalRatio r of + Just (n, 1) -> show n + Just (n', d) -> let n = abs n' + in (if n' < 0 then "-" else "") ++ show (n `div` d) ++ "." ++ + (concatMap (show.(`mod` 10).snd) $ reverse $ takeWhile ((>1).fst) $ zip (iterate (`div` 10) d) (iterate (`div` 10) (n `mod` d))) + Nothing -> show (numerator r) ++ "/" ++ show (denominator r) + +decimalRatio :: Rational -> Maybe (Integer, Integer) +decimalRatio r = do + let n = numerator r + d = denominator r + (c2, d') = takeFactors 2 d + (c5, d'') = takeFactors 5 d' + guard $ d'' == 1 + let m = if c2 > c5 then 5 ^ (c2 - c5) + else 2 ^ (c5 - c2) + return (n * m, d * m) + +takeFactors :: Integer -> Integer -> (Integer, Integer) +takeFactors f n | n `mod` f == 0 = let (c, n') = takeFactors f (n `div` f) + in (c+1, n') + | otherwise = (0, n) + +parseRatio :: ByteString -> Maybe Rational +parseRatio bs = case BC.groupBy ((==) `on` isNumber) bs of + (m:xs) | m == BC.pack "-" -> negate <$> positive xs + xs -> positive xs + where positive = \case + [bx] -> fromInteger . fst <$> BC.readInteger bx + [bx, op, by] -> do + (x, _) <- BC.readInteger bx + (y, _) <- BC.readInteger by + case BC.unpack op of + "." -> return $ (x % 1) + (y % (10 ^ BC.length by)) + "/" -> return $ x % y + _ -> Nothing + _ -> Nothing diff --git a/src/Erebos/Storage/Internal.hs b/src/Erebos/Storage/Internal.hs new file mode 100644 index 0000000..8b794d8 --- /dev/null +++ b/src/Erebos/Storage/Internal.hs @@ -0,0 +1,271 @@ +module Erebos.Storage.Internal where + +import Codec.Compression.Zlib + +import Control.Arrow +import Control.Concurrent +import Control.DeepSeq +import Control.Exception +import Control.Monad +import Control.Monad.Identity + +import Crypto.Hash + +import Data.Bits +import Data.ByteArray (ByteArray, ByteArrayAccess, ScrubbedBytes) +import qualified Data.ByteArray as BA +import Data.ByteString (ByteString) +import qualified Data.ByteString as B +import qualified Data.ByteString.Char8 as BC +import qualified Data.ByteString.Lazy as BL +import Data.Char +import Data.Function +import Data.Hashable +import qualified Data.HashTable.IO as HT +import Data.Kind +import Data.List +import Data.Map (Map) +import qualified Data.Map as M +import Data.UUID (UUID) + +import Foreign.Storable (peek) + +import System.Directory +import System.FSNotify (WatchManager) +import System.FilePath +import System.IO +import System.IO.Error +import System.IO.Unsafe (unsafePerformIO) + +import Erebos.Storage.Platform + + +data Storage' c = Storage + { stBacking :: StorageBacking c + , stParent :: Maybe (Storage' Identity) + , stRefGeneration :: MVar (HT.BasicHashTable RefDigest Generation) + , stRefRoots :: MVar (HT.BasicHashTable RefDigest [RefDigest]) + } + +instance Eq (Storage' c) where + (==) = (==) `on` (stBacking &&& stParent) + +instance Show (Storage' c) where + show st@(Storage { stBacking = StorageDir { dirPath = path }}) = "dir" ++ showParentStorage st ++ ":" ++ path + show st@(Storage { stBacking = StorageMemory {} }) = "mem" ++ showParentStorage st + +showParentStorage :: Storage' c -> String +showParentStorage Storage { stParent = Nothing } = "" +showParentStorage Storage { stParent = Just st } = "@" ++ show st + +data StorageBacking c + = StorageDir { dirPath :: FilePath + , dirWatchers :: MVar ( Maybe WatchManager, [ HeadTypeID ], WatchList c ) + } + | StorageMemory { memHeads :: MVar [((HeadTypeID, HeadID), Ref' c)] + , memObjs :: MVar (Map RefDigest BL.ByteString) + , memKeys :: MVar (Map RefDigest ScrubbedBytes) + , memWatchers :: MVar (WatchList c) + } + deriving (Eq) + +newtype WatchID = WatchID Int + deriving (Eq, Ord, Num) + +data WatchList c = WatchList + { wlNext :: WatchID + , wlList :: [WatchListItem c] + } + +data WatchListItem c = WatchListItem + { wlID :: WatchID + , wlHead :: (HeadTypeID, HeadID) + , wlFun :: Ref' c -> IO () + } + + +newtype RefDigest = RefDigest (Digest Blake2b_256) + deriving (Eq, Ord, NFData, ByteArrayAccess) + +instance Show RefDigest where + show = BC.unpack . showRefDigest + +data Ref' c = Ref (Storage' c) RefDigest + +instance Eq (Ref' c) where + Ref _ d1 == Ref _ d2 = d1 == d2 + +instance Show (Ref' c) where + show ref@(Ref st _) = show st ++ ":" ++ BC.unpack (showRef ref) + +instance ByteArrayAccess (Ref' c) where + length (Ref _ dgst) = BA.length dgst + withByteArray (Ref _ dgst) = BA.withByteArray dgst + +instance Hashable RefDigest where + hashWithSalt salt ref = salt `xor` unsafePerformIO (BA.withByteArray ref peek) + +instance Hashable (Ref' c) where + hashWithSalt salt ref = salt `xor` unsafePerformIO (BA.withByteArray ref peek) + +refStorage :: Ref' c -> Storage' c +refStorage (Ref st _) = st + +refDigest :: Ref' c -> RefDigest +refDigest (Ref _ dgst) = dgst + +showRef :: Ref' c -> ByteString +showRef = showRefDigest . refDigest + +showRefDigestParts :: RefDigest -> (ByteString, ByteString) +showRefDigestParts x = (BC.pack "blake2", showHex x) + +showRefDigest :: RefDigest -> ByteString +showRefDigest = showRefDigestParts >>> \(alg, hex) -> alg <> BC.pack "#" <> hex + +readRefDigest :: ByteString -> Maybe RefDigest +readRefDigest x = case BC.split '#' x of + [alg, dgst] | BA.convert alg == BC.pack "blake2" -> + refDigestFromByteString =<< readHex @ByteString dgst + _ -> Nothing + +refDigestFromByteString :: ByteArrayAccess ba => ba -> Maybe RefDigest +refDigestFromByteString = fmap RefDigest . digestFromByteString + +hashToRefDigest :: BL.ByteString -> RefDigest +hashToRefDigest = RefDigest . hashFinalize . hashUpdates hashInit . BL.toChunks + +showHex :: ByteArrayAccess ba => ba -> ByteString +showHex = B.concat . map showHexByte . BA.unpack + where showHexChar x | x < 10 = x + o '0' + | otherwise = x + o 'a' - 10 + showHexByte x = B.pack [ showHexChar (x `div` 16), showHexChar (x `mod` 16) ] + o = fromIntegral . ord + +readHex :: ByteArray ba => ByteString -> Maybe ba +readHex = return . BA.concat <=< readHex' + where readHex' bs | B.null bs = Just [] + readHex' bs = do (bx, bs') <- B.uncons bs + (by, bs'') <- B.uncons bs' + x <- hexDigit bx + y <- hexDigit by + (B.singleton (x * 16 + y) :) <$> readHex' bs'' + hexDigit x | x >= o '0' && x <= o '9' = Just $ x - o '0' + | x >= o 'a' && x <= o 'z' = Just $ x - o 'a' + 10 + | otherwise = Nothing + o = fromIntegral . ord + + +newtype Generation = Generation Int + deriving (Eq, Show) + +data Head' c a = Head HeadID (Stored' c a) + deriving (Eq, Show) + +newtype HeadID = HeadID UUID + deriving (Eq, Ord, Show) + +newtype HeadTypeID = HeadTypeID UUID + deriving (Eq, Ord) + +data Stored' c a = Stored (Ref' c) a + deriving (Show) + +instance Eq (Stored' c a) where + Stored r1 _ == Stored r2 _ = refDigest r1 == refDigest r2 + +instance Ord (Stored' c a) where + compare (Stored r1 _) (Stored r2 _) = compare (refDigest r1) (refDigest r2) + +storedStorage :: Stored' c a -> Storage' c +storedStorage (Stored (Ref st _) _) = st + + +type Complete = Identity +type Partial = Either RefDigest + +class (Traversable compl, Monad compl) => StorageCompleteness compl where + type LoadResult compl a :: Type + returnLoadResult :: compl a -> LoadResult compl a + ioLoadBytes :: Ref' compl -> IO (compl BL.ByteString) + +instance StorageCompleteness Complete where + type LoadResult Complete a = a + returnLoadResult = runIdentity + ioLoadBytes ref@(Ref st dgst) = maybe (error $ "Ref not found in complete storage: "++show ref) Identity + <$> ioLoadBytesFromStorage st dgst + +instance StorageCompleteness Partial where + type LoadResult Partial a = Either RefDigest a + returnLoadResult = id + ioLoadBytes (Ref st dgst) = maybe (Left dgst) Right <$> ioLoadBytesFromStorage st dgst + +unsafeStoreRawBytes :: Storage' c -> BL.ByteString -> IO (Ref' c) +unsafeStoreRawBytes st raw = do + let dgst = hashToRefDigest raw + case stBacking st of + StorageDir { dirPath = sdir } -> writeFileOnce (refPath sdir dgst) $ compress raw + StorageMemory { memObjs = tobjs } -> + dgst `deepseq` -- the TVar may be accessed when evaluating the data to be written + modifyMVar_ tobjs (return . M.insert dgst raw) + return $ Ref st dgst + +ioLoadBytesFromStorage :: Storage' c -> RefDigest -> IO (Maybe BL.ByteString) +ioLoadBytesFromStorage st dgst = loadCurrent st >>= + \case Just bytes -> return $ Just bytes + Nothing | Just parent <- stParent st -> ioLoadBytesFromStorage parent dgst + | otherwise -> return Nothing + where loadCurrent Storage { stBacking = StorageDir { dirPath = spath } } = handleJust (guard . isDoesNotExistError) (const $ return Nothing) $ + Just . decompress . BL.fromChunks . (:[]) <$> (B.readFile $ refPath spath dgst) + loadCurrent Storage { stBacking = StorageMemory { memObjs = tobjs } } = M.lookup dgst <$> readMVar tobjs + +refPath :: FilePath -> RefDigest -> FilePath +refPath spath rdgst = intercalate "/" [spath, "objects", BC.unpack alg, pref, rest] + where (alg, dgst) = showRefDigestParts rdgst + (pref, rest) = splitAt 2 $ BC.unpack dgst + + +openLockFile :: FilePath -> IO Handle +openLockFile path = do + createDirectoryIfMissing True (takeDirectory path) + retry 10 $ createFileExclusive path + where + retry :: Int -> IO a -> IO a + retry 0 act = act + retry n act = catchJust (\e -> if isAlreadyExistsError e then Just () else Nothing) + act (\_ -> threadDelay (100 * 1000) >> retry (n - 1) act) + +writeFileOnce :: FilePath -> BL.ByteString -> IO () +writeFileOnce file content = bracket (openLockFile locked) + hClose $ \h -> do + doesFileExist file >>= \case + True -> removeFile locked + False -> do BL.hPut h content + hClose h + renameFile locked file + where locked = file ++ ".lock" + +writeFileChecked :: FilePath -> Maybe ByteString -> ByteString -> IO (Either (Maybe ByteString) ()) +writeFileChecked file prev content = bracket (openLockFile locked) + hClose $ \h -> do + (prev,) <$> doesFileExist file >>= \case + (Nothing, True) -> do + current <- B.readFile file + removeFile locked + return $ Left $ Just current + (Nothing, False) -> do B.hPut h content + hClose h + renameFile locked file + return $ Right () + (Just expected, True) -> do + current <- B.readFile file + if current == expected then do B.hPut h content + hClose h + renameFile locked file + return $ return () + else do removeFile locked + return $ Left $ Just current + (Just _, False) -> do + removeFile locked + return $ Left Nothing + where locked = file ++ ".lock" diff --git a/src/Erebos/Storage/Key.hs b/src/Erebos/Storage/Key.hs new file mode 100644 index 0000000..5da79e3 --- /dev/null +++ b/src/Erebos/Storage/Key.hs @@ -0,0 +1,86 @@ +module Erebos.Storage.Key ( + KeyPair(..), + storeKey, loadKey, loadKeyMb, + moveKeys, +) where + +import Control.Concurrent.MVar +import Control.Monad +import Control.Monad.Except +import Control.Monad.IO.Class + +import Data.ByteArray +import qualified Data.ByteString.Char8 as BC +import qualified Data.ByteString.Lazy as BL +import qualified Data.Map as M + +import System.Directory +import System.FilePath +import System.IO.Error + +import Erebos.Storage +import Erebos.Storage.Internal + +class Storable pub => KeyPair sec pub | sec -> pub, pub -> sec where + generateKeys :: Storage -> IO (sec, Stored pub) + keyGetPublic :: sec -> Stored pub + keyGetData :: sec -> ScrubbedBytes + keyFromData :: ScrubbedBytes -> Stored pub -> Maybe sec + + +keyFilePath :: KeyPair sec pub => FilePath -> Stored pub -> FilePath +keyFilePath sdir pkey = sdir </> "keys" </> (BC.unpack $ showRef $ storedRef pkey) + +storeKey :: KeyPair sec pub => sec -> IO () +storeKey key = do + let spub = keyGetPublic key + case stBacking $ storedStorage spub of + StorageDir { dirPath = dir } -> writeFileOnce (keyFilePath dir spub) (BL.fromStrict $ convert $ keyGetData key) + StorageMemory { memKeys = kstore } -> modifyMVar_ kstore $ return . M.insert (refDigest $ storedRef spub) (keyGetData key) + +loadKey :: (KeyPair sec pub, MonadIO m, MonadError String m) => Stored pub -> m sec +loadKey pub = maybe (throwError $ "secret key not found for " <> show (storedRef pub)) return =<< loadKeyMb pub + +loadKeyMb :: (KeyPair sec pub, MonadIO m) => Stored pub -> m (Maybe sec) +loadKeyMb spub = liftIO $ run $ storedStorage spub + where + run st = tryOneLevel (stBacking st) >>= \case + key@Just {} -> return key + Nothing | Just parent <- stParent st -> run parent + | otherwise -> return Nothing + tryOneLevel = \case + StorageDir { dirPath = dir } -> tryIOError (BC.readFile (keyFilePath dir spub)) >>= \case + Right kdata -> return $ keyFromData (convert kdata) spub + Left _ -> return Nothing + StorageMemory { memKeys = kstore } -> (flip keyFromData spub <=< M.lookup (refDigest $ storedRef spub)) <$> readMVar kstore + +moveKeys :: MonadIO m => Storage -> Storage -> m () +moveKeys from to = liftIO $ do + case (stBacking from, stBacking to) of + (StorageDir { dirPath = fromPath }, StorageDir { dirPath = toPath }) -> do + files <- listDirectory (fromPath </> "keys") + forM_ files $ \file -> do + renameFile (fromPath </> "keys" </> file) (toPath </> "keys" </> file) + + (StorageDir { dirPath = fromPath }, StorageMemory { memKeys = toKeys }) -> do + let move m file + | Just dgst <- readRefDigest (BC.pack file) = do + let path = fromPath </> "keys" </> file + key <- convert <$> BC.readFile path + removeFile path + return $ M.insert dgst key m + | otherwise = return m + files <- listDirectory (fromPath </> "keys") + modifyMVar_ toKeys $ \keys -> foldM move keys files + + (StorageMemory { memKeys = fromKeys }, StorageDir { dirPath = toPath }) -> do + modifyMVar_ fromKeys $ \keys -> do + forM_ (M.assocs keys) $ \(dgst, key) -> + writeFileOnce (toPath </> "keys" </> (BC.unpack $ showRefDigest dgst)) (BL.fromStrict $ convert key) + return M.empty + + (StorageMemory { memKeys = fromKeys }, StorageMemory { memKeys = toKeys }) -> do + when (fromKeys /= toKeys) $ do + modifyMVar_ fromKeys $ \fkeys -> do + modifyMVar_ toKeys $ return . M.union fkeys + return M.empty diff --git a/src/Erebos/Storage/List.hs b/src/Erebos/Storage/List.hs new file mode 100644 index 0000000..f0f8786 --- /dev/null +++ b/src/Erebos/Storage/List.hs @@ -0,0 +1,154 @@ +module Erebos.Storage.List ( + StoredList, + emptySList, fromSList, storedFromSList, + slistAdd, slistAddS, + -- TODO slistInsert, slistInsertS, + slistRemove, slistReplace, slistReplaceS, + -- TODO mapFromSList, updateOld, + + -- TODO StoreUpdate(..), + -- TODO withStoredListItem, withStoredListItemS, +) where + +import Data.List +import Data.Maybe +import qualified Data.Set as S + +import Erebos.Storage +import Erebos.Storage.Internal +import Erebos.Storage.Merge + +data List a = ListNil + | ListItem { listPrev :: [StoredList a] + , listItem :: Maybe (Stored a) + , listRemove :: Maybe (Stored (List a)) + } + +type StoredList a = Stored (List a) + +instance Storable a => Storable (List a) where + store' ListNil = storeZero + store' x@ListItem {} = storeRec $ do + mapM_ (storeRef "PREV") $ listPrev x + mapM_ (storeRef "item") $ listItem x + mapM_ (storeRef "remove") $ listRemove x + + load' = loadCurrentObject >>= \case + ZeroObject -> return ListNil + _ -> loadRec $ ListItem <$> loadRefs "PREV" + <*> loadMbRef "item" + <*> loadMbRef "remove" + +instance Storable a => ZeroStorable (List a) where + fromZero _ = ListNil + + +emptySList :: Storable a => Storage -> IO (StoredList a) +emptySList st = wrappedStore st ListNil + +groupsFromSLists :: forall a. Storable a => StoredList a -> [[Stored a]] +groupsFromSLists = helperSelect S.empty . (:[]) + where + helperSelect :: S.Set (StoredList a) -> [StoredList a] -> [[Stored a]] + helperSelect rs xxs | x:xs <- sort $ filterRemoved rs xxs = helper rs x xs + | otherwise = [] + + helper :: S.Set (StoredList a) -> StoredList a -> [StoredList a] -> [[Stored a]] + helper rs x xs + | ListNil <- fromStored x + = [] + + | Just rm <- listRemove (fromStored x) + , ans <- ancestors [x] + , (other, collision) <- partition (S.null . S.intersection ans . ancestors . (:[])) xs + , cont <- helperSelect (rs `S.union` ancestors [rm]) $ concatMap (listPrev . fromStored) (x : collision) ++ other + = case catMaybes $ map (listItem . fromStored) (x : collision) of + [] -> cont + xis -> xis : cont + + | otherwise = case listItem (fromStored x) of + Nothing -> helperSelect rs $ listPrev (fromStored x) ++ xs + Just xi -> [xi] : (helperSelect rs $ listPrev (fromStored x) ++ xs) + + filterRemoved :: S.Set (StoredList a) -> [StoredList a] -> [StoredList a] + filterRemoved rs = filter (S.null . S.intersection rs . ancestors . (:[])) + +fromSList :: Mergeable a => StoredList (Component a) -> [a] +fromSList = map merge . groupsFromSLists + +storedFromSList :: (Mergeable a, Storable a) => StoredList (Component a) -> IO [Stored a] +storedFromSList = mapM storeMerge . groupsFromSLists + +slistAdd :: Storable a => a -> StoredList a -> IO (StoredList a) +slistAdd x prev@(Stored (Ref st _) _) = do + sx <- wrappedStore st x + slistAddS sx prev + +slistAddS :: Storable a => Stored a -> StoredList a -> IO (StoredList a) +slistAddS sx prev@(Stored (Ref st _) _) = wrappedStore st (ListItem [prev] (Just sx) Nothing) + +{- TODO +slistInsert :: Storable a => Stored a -> a -> StoredList a -> IO (StoredList a) +slistInsert after x prev@(Stored (Ref st _) _) = do + sx <- wrappedStore st x + slistInsertS after sx prev + +slistInsertS :: Storable a => Stored a -> Stored a -> StoredList a -> IO (StoredList a) +slistInsertS after sx prev@(Stored (Ref st _) _) = wrappedStore st $ ListItem Nothing (findSListRef after prev) (Just sx) prev +-} + +slistRemove :: Storable a => Stored a -> StoredList a -> IO (StoredList a) +slistRemove rm prev@(Stored (Ref st _) _) = wrappedStore st $ ListItem [prev] Nothing (findSListRef rm prev) + +slistReplace :: Storable a => Stored a -> a -> StoredList a -> IO (StoredList a) +slistReplace rm x prev@(Stored (Ref st _) _) = do + sx <- wrappedStore st x + slistReplaceS rm sx prev + +slistReplaceS :: Storable a => Stored a -> Stored a -> StoredList a -> IO (StoredList a) +slistReplaceS rm sx prev@(Stored (Ref st _) _) = wrappedStore st $ ListItem [prev] (Just sx) (findSListRef rm prev) + +findSListRef :: Stored a -> StoredList a -> Maybe (StoredList a) +findSListRef _ (Stored _ ListNil) = Nothing +findSListRef x cur | listItem (fromStored cur) == Just x = Just cur + | otherwise = listToMaybe $ catMaybes $ map (findSListRef x) $ listPrev $ fromStored cur + +{- TODO +mapFromSList :: Storable a => StoredList a -> Map RefDigest (Stored a) +mapFromSList list = helper list M.empty + where helper :: Storable a => StoredList a -> Map RefDigest (Stored a) -> Map RefDigest (Stored a) + helper (Stored _ ListNil) cur = cur + helper (Stored _ (ListItem (Just rref) _ (Just x) rest)) cur = + let rxref = case load rref of + ListItem _ _ (Just rx) _ -> sameType rx x $ storedRef rx + _ -> error "mapFromSList: malformed list" + in helper rest $ case M.lookup (refDigest $ storedRef x) cur of + Nothing -> M.insert (refDigest rxref) x cur + Just x' -> M.insert (refDigest rxref) x' cur + helper (Stored _ (ListItem _ _ _ rest)) cur = helper rest cur + sameType :: a -> a -> b -> b + sameType _ _ x = x + +updateOld :: Map RefDigest (Stored a) -> Stored a -> Stored a +updateOld m x = fromMaybe x $ M.lookup (refDigest $ storedRef x) m + + +data StoreUpdate a = StoreKeep + | StoreReplace a + | StoreRemove + +withStoredListItem :: (Storable a) => (a -> Bool) -> StoredList a -> (a -> IO (StoreUpdate a)) -> IO (StoredList a) +withStoredListItem p list f = withStoredListItemS (p . fromStored) list (suMap (wrappedStore $ storedStorage list) <=< f . fromStored) + where suMap :: Monad m => (a -> m b) -> StoreUpdate a -> m (StoreUpdate b) + suMap _ StoreKeep = return StoreKeep + suMap g (StoreReplace x) = return . StoreReplace =<< g x + suMap _ StoreRemove = return StoreRemove + +withStoredListItemS :: (Storable a) => (Stored a -> Bool) -> StoredList a -> (Stored a -> IO (StoreUpdate (Stored a))) -> IO (StoredList a) +withStoredListItemS p list f = do + case find p $ storedFromSList list of + Just sx -> f sx >>= \case StoreKeep -> return list + StoreReplace nx -> slistReplaceS sx nx list + StoreRemove -> slistRemove sx list + Nothing -> return list +-} diff --git a/src/Erebos/Storage/Merge.hs b/src/Erebos/Storage/Merge.hs new file mode 100644 index 0000000..a3b0fd7 --- /dev/null +++ b/src/Erebos/Storage/Merge.hs @@ -0,0 +1,163 @@ +module Erebos.Storage.Merge ( + Mergeable(..), + merge, storeMerge, + + Generation, + showGeneration, + compareGeneration, generationMax, + storedGeneration, + + generations, + ancestors, + precedes, + precedesOrEquals, + filterAncestors, + storedRoots, + walkAncestors, + + findProperty, + findPropertyFirst, +) where + +import Control.Concurrent.MVar + +import Data.ByteString.Char8 qualified as BC +import Data.HashTable.IO qualified as HT +import Data.Kind +import Data.List +import Data.Maybe +import Data.Set (Set) +import Data.Set qualified as S + +import System.IO.Unsafe (unsafePerformIO) + +import Erebos.Storage +import Erebos.Storage.Internal +import Erebos.Util + +class Storable (Component a) => Mergeable a where + type Component a :: Type + mergeSorted :: [Stored (Component a)] -> a + toComponents :: a -> [Stored (Component a)] + +instance Mergeable [Stored Object] where + type Component [Stored Object] = Object + mergeSorted = id + toComponents = id + +merge :: Mergeable a => [Stored (Component a)] -> a +merge [] = error "merge: empty list" +merge xs = mergeSorted $ filterAncestors xs + +storeMerge :: (Mergeable a, Storable a) => [Stored (Component a)] -> IO (Stored a) +storeMerge [] = error "merge: empty list" +storeMerge xs@(Stored ref _ : _) = wrappedStore (refStorage ref) $ mergeSorted $ filterAncestors xs + +previous :: Storable a => Stored a -> [Stored a] +previous (Stored ref _) = case load ref of + Rec items | Just (RecRef dref) <- lookup (BC.pack "SDATA") items + , Rec ditems <- load dref -> + map wrappedLoad $ catMaybes $ map (\case RecRef r -> Just r; _ -> Nothing) $ + map snd $ filter ((`elem` [ BC.pack "SPREV", BC.pack "SBASE" ]) . fst) ditems + + | otherwise -> + map wrappedLoad $ catMaybes $ map (\case RecRef r -> Just r; _ -> Nothing) $ + map snd $ filter ((`elem` [ BC.pack "PREV", BC.pack "BASE" ]) . fst) items + _ -> [] + + +nextGeneration :: [Generation] -> Generation +nextGeneration = foldl' helper (Generation 0) + where helper (Generation c) (Generation n) | c <= n = Generation (n + 1) + | otherwise = Generation c + +showGeneration :: Generation -> String +showGeneration (Generation x) = show x + +compareGeneration :: Generation -> Generation -> Maybe Ordering +compareGeneration (Generation x) (Generation y) = Just $ compare x y + +generationMax :: Storable a => [Stored a] -> Maybe (Stored a) +generationMax (x : xs) = Just $ snd $ foldl' helper (storedGeneration x, x) xs + where helper (mg, mx) y = let yg = storedGeneration y + in case compareGeneration mg yg of + Just LT -> (yg, y) + _ -> (mg, mx) +generationMax [] = Nothing + +storedGeneration :: Storable a => Stored a -> Generation +storedGeneration x = + unsafePerformIO $ withMVar (stRefGeneration $ refStorage $ storedRef x) $ \ht -> do + let doLookup y = HT.lookup ht (refDigest $ storedRef y) >>= \case + Just gen -> return gen + Nothing -> do + gen <- nextGeneration <$> mapM doLookup (previous y) + HT.insert ht (refDigest $ storedRef y) gen + return gen + doLookup x + + +-- |Returns list of sets starting with the set of given objects and +-- intcrementally adding parents. +generations :: Storable a => [Stored a] -> [Set (Stored a)] +generations = unfoldr gen . (,S.empty) + where gen (hs, cur) = case filter (`S.notMember` cur) hs of + [] -> Nothing + added -> let next = foldr S.insert cur added + in Just (next, (previous =<< added, next)) + +-- |Returns set containing all given objects and their ancestors +ancestors :: Storable a => [Stored a] -> Set (Stored a) +ancestors = last . (S.empty:) . generations + +precedes :: Storable a => Stored a -> Stored a -> Bool +precedes x y = not $ x `elem` filterAncestors [x, y] + +precedesOrEquals :: Storable a => Stored a -> Stored a -> Bool +precedesOrEquals x y = filterAncestors [ x, y ] == [ y ] + +filterAncestors :: Storable a => [Stored a] -> [Stored a] +filterAncestors [x] = [x] +filterAncestors xs = let xs' = uniq $ sort xs + in helper xs' xs' + where helper remains walk = case generationMax walk of + Just x -> let px = previous x + remains' = filter (\r -> all (/=r) px) remains + in helper remains' $ uniq $ sort (px ++ filter (/=x) walk) + Nothing -> remains + +storedRoots :: Storable a => Stored a -> [Stored a] +storedRoots x = do + let st = refStorage $ storedRef x + unsafePerformIO $ withMVar (stRefRoots st) $ \ht -> do + let doLookup y = HT.lookup ht (refDigest $ storedRef y) >>= \case + Just roots -> return roots + Nothing -> do + roots <- case previous y of + [] -> return [refDigest $ storedRef y] + ps -> map (refDigest . storedRef) . filterAncestors . map (wrappedLoad @Object . Ref st) . concat <$> mapM doLookup ps + HT.insert ht (refDigest $ storedRef y) roots + return roots + map (wrappedLoad . Ref st) <$> doLookup x + +walkAncestors :: (Storable a, Monoid m) => (Stored a -> m) -> [Stored a] -> m +walkAncestors f = helper . sortBy cmp + where + helper (x : y : xs) | x == y = helper (x : xs) + helper (x : xs) = f x <> helper (mergeBy cmp (sortBy cmp (previous x)) xs) + helper [] = mempty + + cmp x y = case compareGeneration (storedGeneration x) (storedGeneration y) of + Just LT -> GT + Just GT -> LT + _ -> compare x y + +findProperty :: forall a b. Storable a => (a -> Maybe b) -> [Stored a] -> [b] +findProperty sel = map (fromJust . sel . fromStored) . filterAncestors . (findPropHeads sel =<<) + +findPropertyFirst :: forall a b. Storable a => (a -> Maybe b) -> [Stored a] -> Maybe b +findPropertyFirst sel = fmap (fromJust . sel . fromStored) . listToMaybe . filterAncestors . (findPropHeads sel =<<) + +findPropHeads :: forall a b. Storable a => (a -> Maybe b) -> Stored a -> [Stored a] +findPropHeads sel sobj | Just _ <- sel $ fromStored sobj = [sobj] + | otherwise = findPropHeads sel =<< previous sobj diff --git a/src/Erebos/Sync.hs b/src/Erebos/Sync.hs new file mode 100644 index 0000000..04b5f11 --- /dev/null +++ b/src/Erebos/Sync.hs @@ -0,0 +1,46 @@ +module Erebos.Sync ( + SyncService(..), +) where + +import Control.Monad +import Control.Monad.Reader + +import Data.List + +import Erebos.Identity +import Erebos.Service +import Erebos.State +import Erebos.Storage +import Erebos.Storage.Merge + +data SyncService = SyncPacket (Stored SharedState) + +instance Service SyncService where + serviceID _ = mkServiceID "a4f538d0-4e50-4082-8e10-7e3ec2af175d" + + serviceHandler packet = do + let SyncPacket added = fromStored packet + pid <- asks svcPeerIdentity + self <- svcSelf + when (finalOwner pid `sameIdentity` finalOwner self) $ do + updateLocalHead_ $ \ls -> do + let current = sort $ lsShared $ fromStored ls + updated = filterAncestors (added : current) + if current /= updated + then mstore (fromStored ls) { lsShared = updated } + else return ls + + serviceNewPeer = notifyPeer . lsShared . fromStored =<< svcGetLocal + serviceStorageWatchers _ = (:[]) $ SomeStorageWatcher (lsShared . fromStored) notifyPeer + +instance Storable SyncService where + store' (SyncPacket smsg) = store' smsg + load' = SyncPacket <$> load' + +notifyPeer :: [Stored SharedState] -> ServiceHandler SyncService () +notifyPeer shared = do + pid <- asks svcPeerIdentity + self <- svcSelf + when (finalOwner pid `sameIdentity` finalOwner self) $ do + forM_ shared $ \sh -> + replyStoredRef =<< (mstore . SyncPacket) sh diff --git a/src/Erebos/Util.hs b/src/Erebos/Util.hs new file mode 100644 index 0000000..ffca9c7 --- /dev/null +++ b/src/Erebos/Util.hs @@ -0,0 +1,37 @@ +module Erebos.Util where + +uniq :: Eq a => [a] -> [a] +uniq (x:y:xs) | x == y = uniq (x:xs) + | otherwise = x : uniq (y:xs) +uniq xs = xs + +mergeBy :: (a -> a -> Ordering) -> [a] -> [a] -> [a] +mergeBy cmp (x : xs) (y : ys) = case cmp x y of + LT -> x : mergeBy cmp xs (y : ys) + EQ -> x : y : mergeBy cmp xs ys + GT -> y : mergeBy cmp (x : xs) ys +mergeBy _ xs [] = xs +mergeBy _ [] ys = ys + +mergeUniqBy :: (a -> a -> Ordering) -> [a] -> [a] -> [a] +mergeUniqBy cmp (x : xs) (y : ys) = case cmp x y of + LT -> x : mergeBy cmp xs (y : ys) + EQ -> x : mergeBy cmp xs ys + GT -> y : mergeBy cmp (x : xs) ys +mergeUniqBy _ xs [] = xs +mergeUniqBy _ [] ys = ys + +mergeUniq :: Ord a => [a] -> [a] -> [a] +mergeUniq = mergeUniqBy compare + +diffSorted :: Ord a => [a] -> [a] -> [a] +diffSorted (x:xs) (y:ys) | x < y = x : diffSorted xs (y:ys) + | x > y = diffSorted (x:xs) ys + | otherwise = diffSorted xs (y:ys) +diffSorted xs _ = xs + +intersectsSorted :: Ord a => [a] -> [a] -> Bool +intersectsSorted (x:xs) (y:ys) | x < y = intersectsSorted xs (y:ys) + | x > y = intersectsSorted (x:xs) ys + | otherwise = True +intersectsSorted _ _ = False diff --git a/src/unix/Erebos/Storage/Platform.hs b/src/unix/Erebos/Storage/Platform.hs new file mode 100644 index 0000000..2198f61 --- /dev/null +++ b/src/unix/Erebos/Storage/Platform.hs @@ -0,0 +1,20 @@ +{-# LANGUAGE CPP #-} + +module Erebos.Storage.Platform ( + createFileExclusive, +) where + +import System.IO +import System.Posix.Files +import System.Posix.IO + +createFileExclusive :: FilePath -> IO Handle +createFileExclusive path = fdToHandle =<< do +#if MIN_VERSION_unix(2,8,0) + openFd path WriteOnly defaultFileFlags + { creat = Just $ unionFileModes ownerReadMode ownerWriteMode + , exclusive = True + } +#else + openFd path WriteOnly (Just $ unionFileModes ownerReadMode ownerWriteMode) (defaultFileFlags { exclusive = True }) +#endif diff --git a/src/windows/Erebos/Storage/Platform.hs b/src/windows/Erebos/Storage/Platform.hs new file mode 100644 index 0000000..76c940b --- /dev/null +++ b/src/windows/Erebos/Storage/Platform.hs @@ -0,0 +1,13 @@ +module Erebos.Storage.Platform ( + createFileExclusive, +) where + +import Data.Bits + +import System.IO +import System.Win32.File +import System.Win32.Types + +createFileExclusive :: FilePath -> IO Handle +createFileExclusive path = do + hANDLEToHandle =<< createFile path gENERIC_WRITE (fILE_SHARE_READ .|. fILE_SHARE_DELETE) Nothing cREATE_NEW fILE_ATTRIBUTE_NORMAL Nothing diff --git a/attach.test b/test/attach.test index 33a1483..33a1483 100644 --- a/attach.test +++ b/test/attach.test diff --git a/chatroom.test b/test/chatroom.test index 93de1ff..93de1ff 100644 --- a/chatroom.test +++ b/test/chatroom.test diff --git a/contact.test b/test/contact.test index 438aa1f..438aa1f 100644 --- a/contact.test +++ b/test/contact.test diff --git a/message.test b/test/message.test index 307f11a..307f11a 100644 --- a/message.test +++ b/test/message.test diff --git a/network.test b/test/network.test index efd508f..efd508f 100644 --- a/network.test +++ b/test/network.test diff --git a/storage.test b/test/storage.test index 0369807..0369807 100644 --- a/storage.test +++ b/test/storage.test diff --git a/sync.test b/test/sync.test index ea9595d..ea9595d 100644 --- a/sync.test +++ b/test/sync.test |