diff options
66 files changed, 12174 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..8c573be --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +.ghc.environment.* +cabal.project.local +dist-newstyle/ +.erebos +.test/ +.minici/ diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..2beacb6 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,68 @@ +# Revision history for erebos + +## 0.1.9 -- 2025-07-08 + +* Option to show details or delete a conversation by giving index parameter without first selecting it +* Improved handling of ICE connections +* Automatic discovery of peers for pending direct messages + +## 0.1.8.1 -- 2025-03-29 + +* Fix build from sdist (add missing include) + +## 0.1.8 -- 2025-03-28 + +* Discovery service without requiring ICE support +* Added `/delete` command to delete chatrooms for current user +* Ignore record items with unexpected type +* Support GHC 9.12 + +## 0.1.7 -- 2024-10-30 + +* Chatroom-specific identity +* Secure cookie for connection initialization +* Support multiple public peers +* Handle unknown object and record item types +* Keep unknown items in local state + +## 0.1.6 -- 2024-08-12 + +* Chatroom members list and join/leave commands +* Fix sending multiple data responses in a stream +* Added `--storage`/`--memory-storage` command-line options +* Compatibility with GHC up to 9.10 +* Local discovery with IPv6 + +## 0.1.5 -- 2024-07-16 + +* Public chatrooms for multiple participants +* Send keep-alive packets on idle connection +* Windows support + +## 0.1.4 -- 2024-06-11 + +* Added `/conversations` command to list and select conversations +* Added `/details` command for info about selected conversation +* Handle peer reconnection after its restart +* Support non-interactive mode without tty + +## 0.1.3 -- 2024-05-05 + +* Enable/disable network services by command-line parameters +* Tab-completion of command name +* Implemented streams in network protocol +* Compatibility with GHC up to 9.8 + +## 0.1.2 -- 2024-02-20 + +* Compatibility with GHC up to 9.6 +* Pruned unnecessary dependencies and fixed bounds + +## 0.1.1 -- 2024-02-18 + +* Added build flag to enable/disable ICE support with pjproject. +* Added `-V` command-line switch to show version. + +## 0.1.0 -- 2024-02-10 + +* First version. @@ -0,0 +1,30 @@ +Copyright (c) 2019, Roman Smrž + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + + * Neither the name of Roman Smrž nor the names of other + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..8957dd7 --- /dev/null +++ b/README.md @@ -0,0 +1,236 @@ +Erebos +====== + +The erebos binary provides simple CLI interface to the decentralized Erebos +messaging service. Local identity is created on the first run. Protocol and +services specification is being written at: + +[https://erebosprotocol.net/spec](https://erebosprotocol.net/spec) + +Erebos identity is based on locally stored cryptographic keys, all +communication is end-to-end encrypted. Multiple devices can be attached to the +same identity, after which they function interchangeably, without any one being +in any way "primary"; messages and other state data are then synchronized +automatically whenever the devices are able to connect with one another. + +Status +------ + +This is experimental implementation of yet unfinished specification, so +changes, especially in the library API, are expected. Storage format and +network protocol should generally remain backward compatible, with their +respective versions to be increased in case of incompatible changes, to allow +for interoperability even in that case. + +Usage +----- + +On the first run, local identity will be created for this device based on +interactive prompts for: + +`Name:` name of the user/owner, which will be shared among all devices +belonging to the same user; keep empty when initializing device that is going +to be attached to already existing identity on other device. + +`Device:` name describing current device, can be empty. + +After the initial setup, the erebos tool presents interactive prompt for +messages and commands. All commands start with the slash (`/`) character, +followed by command name and parameters (if any) separated by spaces. When +a conversation is selected, message to send there is entered directly on +the command prompt. + +The session can be terminated either by end-of-input (typically `Ctrl-d`) or +using the `/quit` command. + +### Example + +Start `erebos` CLI and create new identity: +``` +Name: Some Name +Device: First device +Some Name / First device +> +``` + +Add public peer: +``` +> /peer-add-public +[1] PEER NEW <unnamed> [37.221.243.57 29665] +[1] PEER UPD discovery1.erebosprotocol.net [37.221.243.57 29665] +``` + +Select the peer and send it a message, the public server just responds with +automatic echo message: +``` +> /1 +discovery1.erebosprotocol.net> hello +[18:55] Some Name: hello +[18:55] discovery1.erebosprotocol.net: Echo: hello +``` + +List chatrooms known to the peers: +``` +> /chatrooms +[1] Test chatroom +[2] Second test chatroom +``` + +Enter a chatroom and send a message there: +``` +> /1 +Test chatroom> Hi +Test chatroom [19:03] Some Name: Hi +``` + +### Messaging + +`/peers` +: List peers with direct network connection. Peers are discovered automatically + on local network or can be manually added. + +`/contacts` +: List known contacts (see below). + +`/conversations` +: List started conversations with contacts or other peers. + +`/<number>` +: Select conversation, contact or peer `<number>` based on the last + `/conversations`, `/contacts` or `/peers` output list. + +`<message>` +: Send `<message>` to selected conversation. + +`/history [<number>]` +: Show message history of the selected conversation, or the one identified by + `<number>` if given. + +`/details [<number>]` +: Show information about the selected conversations, contact or peer; or the + one identified by `<number>` if given. + +### Chatrooms + +Currently only public unmoderated chatrooms are supported, which means that any +network peer is allowed to read and post to the chatroom. Individual messages +are signed, so message author can not be forged. + +`/chatrooms` +: List known chatrooms. + +`/chatroom-create-public [<name>]` +: Create public unmoderated chatroom. Room name can be passed as command + argument or entered interactively. + +`/members` +: List members of the chatroom – usesers who sent any message or joined via the +`join` command. + +`/join` +: Join chatroom without sending text message. + +`/join-as <name>` +: Join chatroom using a new identity with a name `<name>`. This new identity is + unrelated to the main one, and will be used for any future messages sent to + this chatroom. + +`/leave` +: Leave the chatroom. User will no longer be listed as a member and erebos tool + will no longer collect message of this chatroom. + +`/delete [<number>]` +: Delete the chatroom (currently selected one, or the one identified by + `<number>`); this action is only synchronized with devices belonging to the + current user and does not affect the chatroom state for others. Due to the + storage design, the chatroom data will not be purged from the local state + history, but the chatroom will no longer be listed as available and no futher + updates for this chatroom will be collected or shared with other peers. + +### Add contacts + +To ensure the identity of the contact and prevent man-in-the-middle attack, +generated verification code needs to be confirmed on both devices to add +contacts to contact list (similar to bluetooth device pairing). Before adding +new contact, list peers using `/peers` command and select one with `/<number>`. + +`/contacts` +: List already added contacts. + +`/contact-add` +: Add selected peer as contact. Six-digit verification code will be computed + based on peer keys, which will be displayed on both devices and needs to be + checked that both numbers are same. After that it needs to be confirmed using + `/contact-accept` to finish the process. + +`/contact-accept` +: Confirm that displayed verification codes are same on both devices and add + the selected peer as contact. The side, which did not initiate the contact + adding process, needs to select the corresponding peer with `/<number>` + command first. + +`/contact-reject` +: Reject contact request or verification code of selected peer. + +### Attach other devices + +Multiple devices can be attached to single identity to be used by the same +user. After the attachment process completes the roles of the devices are +equivalent, both can send and receive messages independently and those +messages, along with any other sate data, are synchronized automatically +whenever the devices can connect to each other. + +The attachment process and underlying protocol is very similar to the contact +adding described above, so also generates verification code based on peer keys +that needs to be checked and confirmed on both devices to avoid potential +man-in-the-middle attack. + +Before attaching device, list peers using `/peers` command and select the +target device with `/<number>`. + +`/attach` +: Attach current device to the selected peer. After the process completes the + owner of the selected peer will become owner of this device as well. + Six-digit verification code will be displayed on both devices and the user + needs to check that both are the same before confirmation using the + `/attach-accept` command. + +`/attach-accept` +: Confirm that displayed verification codes are same on both devices and + complete the attachment process (or wait for the confirmation on the peer + device). The side, which did not initiate the attachment process, needs to + select the corresponding peer with `/<number>` command first. + +`/attach-reject` +: Reject device attachment request or verification code of selected peer. + +### Other + +`/peer-add <host> [<port>]` +: Manually add network peer with given hostname or IP address. + +`/peer-add-public` +: Add known public network peer(s). + +`/peer-drop` +: Drop the currently selected peer. Afterwards, the connection can be + re-established by either side. + +`/update-identity` +: Interactively update current identity information + +`/quit` +: Quit the erebos tool. + + +Storage +------- + +Data are by default stored under `XDG_DATA_HOME`, typically +`$HOME/.local/share/erebos`, unless there is an erebos storage already +in `.erebos` subdirectory of the current working directory, in which case the +latter one in used instead. This can be overriden by `EREBOS_DIR` environment +variable. + +Private keys are currently stored in plaintext under the `keys` subdirectory of +the erebos directory. diff --git a/Setup.hs b/Setup.hs new file mode 100644 index 0000000..9a994af --- /dev/null +++ b/Setup.hs @@ -0,0 +1,2 @@ +import Distribution.Simple +main = defaultMain diff --git a/erebos-tester.yaml b/erebos-tester.yaml new file mode 100644 index 0000000..a44f080 --- /dev/null +++ b/erebos-tester.yaml @@ -0,0 +1 @@ +tests: test/**/*.test diff --git a/erebos.cabal b/erebos.cabal new file mode 100644 index 0000000..9a49273 --- /dev/null +++ b/erebos.cabal @@ -0,0 +1,235 @@ +Cabal-Version: 3.0 + +Name: erebos +Version: 0.1.9 +Synopsis: Decentralized messaging and synchronization +Description: + Library and simple CLI interface implementing the Erebos identity + management, decentralized messaging and synchronization protocol, along + with local storage. + + Erebos identity is based on locally stored cryptographic keys, all + communication is end-to-end encrypted. Multiple devices can be attached to + the same identity, after which they function interchangeably, without any + one being in any way "primary"; messages and other state data are then + synchronized automatically whenever the devices are able to connect with + one another. + + See README for usage of the CLI tool. +License: BSD-3-Clause +License-File: LICENSE +Homepage: https://erebosprotocol.net/erebos +Author: Roman Smrž <roman.smrz@seznam.cz> +Maintainer: roman.smrz@seznam.cz +Category: Network +Stability: experimental +Build-type: Simple +Extra-Doc-Files: + README.md + CHANGELOG.md +Extra-Source-Files: + src/Erebos/ICE/pjproject.h + src/Erebos/Network/ifaddrs.h + +Flag ice + Description: Enable peer discovery with ICE support using pjproject + +Flag ci + description: Options for CI testing + default: False + manual: True + +Flag cryptonite + description: Use deprecated 'cryptonite' package + default: False + +source-repository head + type: git + location: https://code.erebosprotocol.net/erebos + +common common + ghc-options: + -Wall + -Wno-x-partial + -fdefer-typed-holes + + if flag(ci) + ghc-options: + -Werror + -- sometimes needed for backward/forward compatibility: + -Wno-error=unused-imports + + build-depends: + base ^>= { 4.15, 4.16, 4.17, 4.18, 4.19, 4.20, 4.21 }, + + default-extensions: + DefaultSignatures + ExistentialQuantification + FlexibleContexts + FlexibleInstances + FunctionalDependencies + GeneralizedNewtypeDeriving + ImportQualifiedPost + LambdaCase + MultiWayIf + RankNTypes + RecordWildCards + ScopedTypeVariables + StandaloneDeriving + TypeOperators + TupleSections + TypeApplications + TypeFamilies + TypeFamilyDependencies + + other-extensions: + CPP + ForeignFunctionInterface + OverloadedStrings + RecursiveDo + TemplateHaskell + UndecidableInstances + + if flag(ice) + cpp-options: -DENABLE_ICE_SUPPORT + +library + import: common + default-language: Haskell2010 + + hs-source-dirs: src + exposed-modules: + Erebos.Attach + Erebos.Chatroom + Erebos.Contact + Erebos.Conversation + Erebos.DirectMessage + Erebos.Discovery + Erebos.Error + Erebos.Identity + Erebos.Network + Erebos.Object + Erebos.Pairing + Erebos.PubKey + Erebos.Service + Erebos.Service.Stream + Erebos.Set + Erebos.State + Erebos.Storable + Erebos.Storage + Erebos.Storage.Backend + Erebos.Storage.Head + Erebos.Storage.Key + Erebos.Storage.Merge + Erebos.Sync + + other-modules: + Erebos.Flow + Erebos.Network.Channel + Erebos.Network.Protocol + Erebos.Object.Internal + Erebos.Storage.Disk + Erebos.Storage.Internal + Erebos.Storage.Memory + Erebos.Storage.Platform + Erebos.UUID + Erebos.Util + + c-sources: + src/Erebos/Network/ifaddrs.c + include-dirs: + src + includes: + src/Erebos/Network/ifaddrs.h + + if flag(ice) + exposed-modules: + Erebos.ICE + c-sources: + src/Erebos/ICE/pjproject.c + include-dirs: + src/Erebos/ICE + includes: + src/Erebos/ICE/pjproject.h + build-tool-depends: c2hs:c2hs + pkgconfig-depends: libpjproject >= 2.9 + + build-depends: + async >=2.2 && <2.3, + binary >=0.8 && <0.11, + bytestring >=0.10 && <0.13, + clock >=0.8 && < 0.9, + containers ^>= { 0.6, 0.7, 0.8 }, + deepseq >= 1.4 && <1.6, + directory >= 1.3 && <1.4, + filepath >=1.4 && <1.6, + fsnotify ^>= { 0.3, 0.4 }, + hashable ^>= { 1.3, 1.4, 1.5 }, + hashtables ^>= { 1.2, 1.3, 1.4 }, + iproute >=1.7.12 && <1.8, + memory >=0.14 && <0.19, + mtl >=2.2 && <2.4, + network ^>= { 3.1, 3.2 }, + stm >=2.5 && <2.6, + text >= 1.2 && <2.2, + time ^>= { 1.8, 1.9, 1.10, 1.11, 1.12, 1.13, 1.14 }, + uuid-types ^>= { 1.0.4 }, + zlib >=0.6 && <0.8 + + if !flag(cryptonite) + build-depends: + crypton ^>= { 0.34, 1.0 }, + else + build-depends: + cryptonite >=0.25 && <0.31, + + if os(windows) + hs-source-dirs: src/windows + build-depends: + Win32 ^>= { 2.14 }, + else + hs-source-dirs: src/unix + build-depends: + unix ^>= { 2.7, 2.8 }, + +executable erebos + import: common + default-language: Haskell2010 + hs-source-dirs: main + ghc-options: -threaded + + main-is: Main.hs + other-modules: + Paths_erebos + State + Terminal + Test + Test.Service + Version + Version.Git + WebSocket + autogen-modules: + Paths_erebos + + build-depends: + ansi-terminal ^>= { 0.11, 1.0, 1.1 }, + bytestring, + directory, + erebos, + mtl, + network, + process >=1.6 && <1.7, + stm, + template-haskell ^>= { 2.17, 2.18, 2.19, 2.20, 2.21, 2.22, 2.23 }, + text, + time, + transformers >= 0.5 && <0.7, + uuid-types, + websockets ^>= { 0.12.7, 0.13 }, + + if !flag(cryptonite) + build-depends: + crypton, + else + build-depends: + cryptonite, diff --git a/main/Main.hs b/main/Main.hs new file mode 100644 index 0000000..403e5e9 --- /dev/null +++ b/main/Main.hs @@ -0,0 +1,1039 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE OverloadedStrings #-} + +module Main (main) where + +import Control.Arrow (first) +import Control.Concurrent +import Control.Exception +import Control.Monad +import Control.Monad.Except +import Control.Monad.Reader +import Control.Monad.State +import Control.Monad.Trans.Maybe +import Control.Monad.Writer + +import Crypto.Random + +import qualified Data.ByteString.Char8 as BC +import qualified Data.ByteString.Lazy as BL +import Data.Char +import Data.List +import Data.Maybe +import Data.Ord +import Data.Text (Text) +import Data.Text qualified as T +import Data.Text.Encoding qualified as T +import Data.Text.IO qualified as T +import Data.Time.Format +import Data.Time.LocalTime +import Data.Typeable + +import Network.Socket + +import System.Console.GetOpt +import System.Directory +import System.Environment +import System.Exit +import System.IO + +import Erebos.Attach +import Erebos.Contact +import Erebos.Chatroom +import Erebos.Conversation +import Erebos.DirectMessage +import Erebos.Discovery +#ifdef ENABLE_ICE_SUPPORT +import Erebos.ICE +#endif +import Erebos.Identity +import Erebos.Network +import Erebos.Object +import Erebos.PubKey +import Erebos.Service +import Erebos.Set +import Erebos.State +import Erebos.Storable +import Erebos.Storage +import Erebos.Storage.Merge +import Erebos.Sync + +import State +import Terminal +import Test +import Version +import WebSocket + +data Options = Options + { optServer :: ServerOptions + , optServices :: [ServiceOption] + , optStorage :: StorageOption + , optChatroomAutoSubscribe :: Maybe Int + , optDmBotEcho :: Maybe Text + , optWebSocketServer :: Maybe Int + , optShowHelp :: Bool + , optShowVersion :: Bool + } + +data StorageOption = DefaultStorage + | FilesystemStorage FilePath + | MemoryStorage + +data ServiceOption = ServiceOption + { soptName :: String + , soptService :: SomeService + , soptEnabled :: Bool + , soptDescription :: String + } + +defaultOptions :: Options +defaultOptions = Options + { optServer = defaultServerOptions + , optServices = availableServices + , optStorage = DefaultStorage + , optChatroomAutoSubscribe = Nothing + , optDmBotEcho = Nothing + , optWebSocketServer = Nothing + , optShowHelp = False + , optShowVersion = False + } + +availableServices :: [ServiceOption] +availableServices = + [ ServiceOption "attach" (someService @AttachService Proxy) + True "attach (to) other devices" + , ServiceOption "sync" (someService @SyncService Proxy) + True "synchronization with attached devices" + , ServiceOption "chatroom" (someService @ChatroomService Proxy) + True "chatrooms with multiple participants" + , ServiceOption "contact" (someService @ContactService Proxy) + True "create contacts with network peers" + , ServiceOption "dm" (someService @DirectMessage Proxy) + True "direct messages" + , ServiceOption "discovery" (someService @DiscoveryService Proxy) + True "peer discovery" + ] + +options :: [ OptDescr (Options -> Writer [ String ] Options) ] +options = + [ Option ['p'] ["port"] + (ReqArg (\p -> so $ \opts -> opts { serverPort = read p }) "<port>") + "local port to bind" + , Option ['s'] ["silent"] + (NoArg (so $ \opts -> opts { serverLocalDiscovery = False })) + "do not send announce packets for local discovery" + , Option [] [ "storage" ] + (ReqArg (\path -> \opts -> return opts { optStorage = FilesystemStorage path }) "<path>") + "use storage in <path>" + , Option [] [ "memory-storage" ] + (NoArg (\opts -> return opts { optStorage = MemoryStorage })) + "use memory storage" + , Option [] ["chatroom-auto-subscribe"] + (ReqArg (\count -> \opts -> return opts { optChatroomAutoSubscribe = Just (read count) }) "<count>") + "automatically subscribe for up to <count> chatrooms" +#ifdef ENABLE_ICE_SUPPORT + , Option [] [ "discovery-stun-port" ] + (ReqArg (\value -> serviceAttr $ \attrs -> return attrs { discoveryStunPort = Just (read value) }) "<port>") + "offer specified <port> to discovery peers for STUN protocol" + , Option [] [ "discovery-stun-server" ] + (ReqArg (\value -> serviceAttr $ \attrs -> return attrs { discoveryStunServer = Just (read value) }) "<server>") + "offer <server> (domain name or IP address) to discovery peers for STUN protocol" + , Option [] [ "discovery-turn-port" ] + (ReqArg (\value -> serviceAttr $ \attrs -> return attrs { discoveryTurnPort = Just (read value) }) "<port>") + "offer specified <port> to discovery peers for TURN protocol" + , Option [] [ "discovery-turn-server" ] + (ReqArg (\value -> serviceAttr $ \attrs -> return attrs { discoveryTurnServer = Just (read value) }) "<server>") + "offer <server> (domain name or IP address) to discovery peers for TURN protocol" +#endif + , Option [] [ "discovery-tunnel" ] + (OptArg (\value -> \opts -> do + fun <- provideTunnelFun value + serviceAttr (\attrs -> return attrs { discoveryProvideTunnel = fun }) opts) "<peer-type>") + "offer to provide tunnel for peers of given <peer-type>, possible values: all, none, websocket" + , Option [] ["dm-bot-echo"] + (ReqArg (\prefix -> \opts -> return opts { optDmBotEcho = Just (T.pack prefix) }) "<prefix>") + "automatically reply to direct messages with the same text prefixed with <prefix>" + , Option [] [ "websocket-server" ] + (ReqArg (\value -> \opts -> return opts { optWebSocketServer = Just (read value) }) "<port>") + "start WebSocket server on given port" + , Option ['h'] ["help"] + (NoArg $ \opts -> return opts { optShowHelp = True }) + "show this help and exit" + , Option ['V'] ["version"] + (NoArg $ \opts -> return opts { optShowVersion = True }) + "show version and exit" + ] + where + so f opts = return opts { optServer = f $ optServer opts } + + updateService :: (Service s, Monad m, Typeable m) => (ServiceAttributes s -> m (ServiceAttributes s)) -> SomeService -> m SomeService + updateService f some@(SomeService proxy attrs) + | Just f' <- cast f = SomeService proxy <$> f' attrs + | otherwise = return some + + serviceAttr :: (Service s, Monad m, Typeable m) => (ServiceAttributes s -> m (ServiceAttributes s)) -> Options -> m Options + serviceAttr f opts = do + services' <- forM (optServices opts) $ \sopt -> do + service <- updateService f (soptService sopt) + return sopt { soptService = service } + return opts { optServices = services' } + + provideTunnelFun :: Maybe String -> Writer [ String ] (Peer -> Bool) + provideTunnelFun Nothing = return $ const True + provideTunnelFun (Just "all") = return $ const True + provideTunnelFun (Just "none") = return $ const False + provideTunnelFun (Just "websocket") = return $ \peer -> + case peerAddress peer of + CustomPeerAddress addr | Just WebSocketAddress {} <- cast addr -> True + _ -> False + provideTunnelFun (Just name) = do + tell [ "Invalid value of --discovery-tunnel: ‘" <> name <> "’\n" ] + return $ const False + +servicesOptions :: [ OptDescr (Options -> Writer [ String ] Options) ] +servicesOptions = concatMap helper $ "all" : map soptName availableServices + where + helper name = + [ Option [] [ "enable-" <> name ] (NoArg $ so $ change name $ \sopt -> sopt { soptEnabled = True }) "" + , Option [] [ "disable-" <> name ] (NoArg $ so $ change name $ \sopt -> sopt { soptEnabled = False }) "" + ] + so f opts = return opts { optServices = f $ optServices opts } + change :: String -> (ServiceOption -> ServiceOption) -> [ServiceOption] -> [ServiceOption] + change name f (s : ss) + | soptName s == name || name == "all" + = f s : change name f ss + | otherwise = s : change name f ss + change _ _ [] = [] + +getDefaultStorageDir :: IO FilePath +getDefaultStorageDir = do + lookupEnv "EREBOS_DIR" >>= \case + Just dir -> return dir + Nothing -> doesFileExist "./.erebos/erebos-storage" >>= \case + True -> return "./.erebos" + False -> getXdgDirectory XdgData "erebos" + +main :: IO () +main = do + let printErrors errs = do + progName <- getProgName + hPutStrLn stderr $ concat errs <> "Try `" <> progName <> " --help' for more information." + exitFailure + (opts, args) <- (getOpt RequireOrder (options ++ servicesOptions) <$> getArgs) >>= \case + (wo, args, []) -> + case runWriter (foldM (flip ($)) defaultOptions wo) of + ( o, [] ) -> return ( o, args ) + ( _, errs ) -> printErrors errs + (_, _, errs) -> printErrors errs + + st <- liftIO $ case optStorage opts of + DefaultStorage -> openStorage =<< getDefaultStorageDir + FilesystemStorage path -> openStorage path + MemoryStorage -> memoryStorage + + case args of + ["cat-file", sref] -> do + readRef st (BC.pack sref) >>= \case + Nothing -> error "ref does not exist" + Just ref -> BL.putStr $ lazyLoadBytes ref + + ("cat-file" : objtype : srefs@(_:_)) -> do + sequence <$> (mapM (readRef st . BC.pack) srefs) >>= \case + Nothing -> error "ref does not exist" + Just refs -> case objtype of + "signed" -> forM_ refs $ \ref -> do + let signed = load ref :: Signed Object + BL.putStr $ lazyLoadBytes $ storedRef $ signedData signed + forM_ (signedSignature signed) $ \sig -> do + putStr $ "SIG " + BC.putStrLn $ showRef $ storedRef $ sigKey $ fromStored sig + "identity" -> case validateExtendedIdentityF (wrappedLoad <$> refs) of + Just identity -> do + let disp :: Identity m -> IO () + disp idt = do + maybe (return ()) (T.putStrLn . (T.pack "Name: " `T.append`)) $ idName idt + BC.putStrLn . (BC.pack "KeyId: " `BC.append`) . showRefDigest . refDigest . storedRef $ idKeyIdentity idt + BC.putStrLn . (BC.pack "KeyMsg: " `BC.append`) . showRefDigest . refDigest . storedRef $ idKeyMessage idt + case idOwner idt of + Nothing -> return () + Just owner -> do + mapM_ (putStrLn . ("OWNER " ++) . BC.unpack . showRefDigest . refDigest . storedRef) $ idExtDataF owner + disp owner + disp identity + Nothing -> putStrLn $ "Identity verification failed" + _ -> error $ "unknown object type '" ++ objtype ++ "'" + + ["show-generation", sref] -> readRef st (BC.pack sref) >>= \case + Nothing -> error "ref does not exist" + Just ref -> print $ storedGeneration (wrappedLoad ref :: Stored Object) + + ["update-identity"] -> do + withTerminal noCompletion $ \term -> do + either (fail . showErebosError) return <=< runExceptT $ do + runReaderT (updateSharedIdentity term) =<< loadLocalStateHead term st + + ("update-identity" : srefs) -> do + withTerminal noCompletion $ \term -> do + sequence <$> mapM (readRef st . BC.pack) srefs >>= \case + Nothing -> error "ref does not exist" + Just refs + | Just idt <- validateIdentityF $ map wrappedLoad refs -> do + BC.putStrLn . showRefDigest . refDigest . storedRef . idData =<< + (either (fail . showErebosError) return <=< runExceptT $ runReaderT (interactiveIdentityUpdate term idt) st) + | otherwise -> error "invalid identity" + + ["test"] -> runTestTool st + + [] -> do + let header = "Usage: erebos [OPTION...]" + serviceDesc ServiceOption {..} = padService (" " <> soptName) <> soptDescription + + padTo n str = str <> replicate (n - length str) ' ' + padOpt = padTo 37 + padService = padTo 16 + + if | optShowHelp opts -> putStr $ usageInfo header options <> unlines + ( + [ padOpt " --enable-<service>" <> "enable network service <service>" + , padOpt " --disable-<service>" <> "disable network service <service>" + , padOpt " --enable-all" <> "enable all network services" + , padOpt " --disable-all" <> "disable all network services" + , "" + , "Available network services:" + ] ++ map serviceDesc availableServices + ) + | optShowVersion opts -> putStrLn versionLine + | otherwise -> interactiveLoop st opts + + (cmdname : _) -> do + hPutStrLn stderr $ "Unknown command `" <> cmdname <> "'" + exitFailure + + +interactiveLoop :: Storage -> Options -> IO () +interactiveLoop st opts = withTerminal commandCompletion $ \term -> do + erebosHead <- liftIO $ loadLocalStateHead term st + void $ printLine term $ T.unpack $ displayIdentity $ headLocalIdentity erebosHead + + let tui = hasTerminalUI term + let extPrintLn = void . printLine term + + let getInputLinesTui :: Either CommandState String -> MaybeT IO String + getInputLinesTui eprompt = do + prompt <- case eprompt of + Left cstate -> do + pname <- case csContext cstate of + NoContext -> return "" + SelectedPeer peer -> peerIdentity peer >>= return . \case + PeerIdentityFull pid -> maybe "<unnamed>" T.unpack $ idName $ finalOwner pid + PeerIdentityRef wref _ -> "<" ++ BC.unpack (showRefDigest $ wrDigest wref) ++ ">" + PeerIdentityUnknown _ -> "<unknown>" + SelectedContact contact -> return $ T.unpack $ contactName contact + SelectedChatroom rstate -> return $ T.unpack $ fromMaybe (T.pack "<unnamed>") $ roomName =<< roomStateRoom rstate + SelectedConversation conv -> return $ T.unpack $ conversationName conv + return $ pname ++ "> " + Right prompt -> return prompt + lift $ setPrompt term prompt + join $ lift $ getInputLine term $ \case + Just input@('/' : _) -> KeepPrompt $ return input + Just input -> ErasePrompt $ case reverse input of + _ | all isSpace input -> getInputLinesTui eprompt + '\\':rest -> (reverse ('\n':rest) ++) <$> getInputLinesTui (Right ">> ") + _ -> return input + Nothing -> KeepPrompt mzero + + getInputCommandTui cstate = do + input <- getInputLinesTui cstate + let (CommandM cmd, line) = case input of + '/':rest -> let (scmd, args) = dropWhile isSpace <$> span (\c -> isAlphaNum c || c == '-') rest + in if not (null scmd) && all isDigit scmd + then (cmdSelectContext, scmd) + else (fromMaybe (cmdUnknown scmd) $ lookup scmd commands, args) + _ -> (cmdSend, input) + return (cmd, line) + + getInputLinesPipe = do + join $ lift $ getInputLine term $ KeepPrompt . \case + Just input -> return input + Nothing -> liftIO $ forever $ threadDelay 100000000 + + getInputCommandPipe _ = do + input <- getInputLinesPipe + let (scmd, args) = dropWhile isSpace <$> span (\c -> isAlphaNum c || c == '-') input + let (CommandM cmd, line) = (fromMaybe (cmdUnknown scmd) $ lookup scmd commands, args) + return (cmd, line) + + let getInputCommand = if tui then getInputCommandTui . Left + else getInputCommandPipe + + _ <- liftIO $ do + tzone <- getCurrentTimeZone + watchReceivedMessages erebosHead $ \smsg -> do + let msg = fromStored smsg + extPrintLn $ formatDirectMessage tzone msg + case optDmBotEcho opts of + Nothing -> return () + Just prefix -> do + res <- runExceptT $ flip runReaderT erebosHead $ sendDirectMessage (msgFrom msg) (prefix <> msgText msg) + case res of + Right reply -> extPrintLn $ formatDirectMessage tzone $ fromStored reply + Left err -> extPrintLn $ "Failed to send dm echo: " <> err + + peers <- liftIO $ newMVar [] + contextOptions <- liftIO $ newMVar [] + chatroomSetVar <- liftIO $ newEmptyMVar + + let autoSubscribe = optChatroomAutoSubscribe opts + chatroomList = fromSetBy (comparing roomStateData) . lookupSharedValue . lsShared . headObject $ erebosHead + watched <- if isJust autoSubscribe || any roomStateSubscribe chatroomList + then fmap Just $ liftIO $ watchChatroomsForCli extPrintLn erebosHead chatroomSetVar contextOptions autoSubscribe + else return Nothing + + server <- liftIO $ do + startServer (optServer opts) erebosHead extPrintLn $ + map soptService $ filter soptEnabled $ optServices opts + + case optWebSocketServer opts of + Just port -> startWebsocketServer server "::" port extPrintLn + Nothing -> return () + + void $ liftIO $ forkIO $ void $ forever $ do + peer <- getNextPeerChange server + peerIdentity peer >>= \case + pid@(PeerIdentityFull _) -> do + dropped <- isPeerDropped peer + let shown = showPeer pid $ peerAddress peer + let update [] = ([(peer, shown)], (Nothing, "NEW")) + update ((p,s):ps) + | p == peer && dropped = (ps, (Nothing, "DEL")) + | p == peer = ((peer, shown) : ps, (Just s, "UPD")) + | otherwise = first ((p,s):) $ update ps + let ctxUpdate n [] = ([SelectedPeer peer], n) + ctxUpdate n (ctx:ctxs) + | SelectedPeer p <- ctx, p == peer = (ctx:ctxs, n) + | otherwise = first (ctx:) $ ctxUpdate (n + 1) ctxs + (op, updateType) <- modifyMVar peers (return . update) + let updateType' = if dropped then "DEL" else updateType + idx <- modifyMVar contextOptions (return . ctxUpdate (1 :: Int)) + when (Just shown /= op) $ extPrintLn $ "[" <> show idx <> "] PEER " <> updateType' <> " " <> shown + _ -> return () + + let process :: CommandState -> MaybeT IO CommandState + process cstate = do + (cmd, line) <- getInputCommand cstate + h <- liftIO (reloadHead $ csHead cstate) >>= \case + Just h -> return h + Nothing -> do lift $ extPrintLn "current head deleted" + mzero + res <- liftIO $ runExceptT $ flip execStateT cstate { csHead = h } $ runReaderT cmd CommandInput + { ciServer = server + , ciTerminal = term + , ciLine = line + , ciPrint = extPrintLn + , ciOptions = opts + , ciPeers = liftIO $ modifyMVar peers $ \ps -> do + ps' <- filterM (fmap not . isPeerDropped . fst) ps + return (ps', ps') + , ciContextOptions = liftIO $ readMVar contextOptions + , ciSetContextOptions = \ctxs -> liftIO $ modifyMVar_ contextOptions $ const $ return ctxs + , ciContextOptionsVar = contextOptions + , ciChatroomSetVar = chatroomSetVar + } + case res of + Right cstate' + | csQuit cstate' -> mzero + | otherwise -> return cstate' + Left err -> do + lift $ extPrintLn $ "Error: " ++ showErebosError err + return cstate + + let loop (Just cstate) = runMaybeT (process cstate) >>= loop + loop Nothing = return () + loop $ Just $ CommandState + { csHead = erebosHead + , csContext = NoContext +#ifdef ENABLE_ICE_SUPPORT + , csIceSessions = [] +#endif + , csIcePeer = Nothing + , csWatchChatrooms = watched + , csQuit = False + } + + +data CommandInput = CommandInput + { ciServer :: Server + , ciTerminal :: Terminal + , ciLine :: String + , ciPrint :: String -> IO () + , ciOptions :: Options + , ciPeers :: CommandM [(Peer, String)] + , ciContextOptions :: CommandM [CommandContext] + , ciSetContextOptions :: [CommandContext] -> Command + , ciContextOptionsVar :: MVar [ CommandContext ] + , ciChatroomSetVar :: MVar (Set ChatroomState) + } + +data CommandState = CommandState + { csHead :: Head LocalState + , csContext :: CommandContext +#ifdef ENABLE_ICE_SUPPORT + , csIceSessions :: [IceSession] +#endif + , csIcePeer :: Maybe Peer + , csWatchChatrooms :: Maybe WatchedHead + , csQuit :: Bool + } + +data CommandContext = NoContext + | SelectedPeer Peer + | SelectedContact Contact + | SelectedChatroom ChatroomState + | SelectedConversation Conversation + +newtype CommandM a = CommandM (ReaderT CommandInput (StateT CommandState (ExceptT ErebosError IO)) a) + deriving (Functor, Applicative, Monad, MonadReader CommandInput, MonadState CommandState, MonadError ErebosError) + +instance MonadFail CommandM where + fail = throwOtherError + +instance MonadIO CommandM where + liftIO act = CommandM (liftIO (try act)) >>= \case + Left (e :: SomeException) -> throwOtherError (show e) + Right x -> return x + +instance MonadRandom CommandM where + getRandomBytes = liftIO . getRandomBytes + +instance MonadStorage CommandM where + getStorage = gets $ headStorage . csHead + +instance MonadHead LocalState CommandM where + updateLocalHead f = do + h <- gets csHead + (Just h', x) <- maybe (fail "failed to reload head") (flip updateHead f) =<< reloadHead h + modify $ \s -> s { csHead = h' } + return x + +type Command = CommandM () + +getSelectedPeer :: CommandM Peer +getSelectedPeer = gets csContext >>= \case + SelectedPeer peer -> return peer + _ -> throwOtherError "no peer selected" + +getSelectedChatroom :: CommandM ChatroomState +getSelectedChatroom = gets csContext >>= \case + SelectedChatroom rstate -> return rstate + _ -> throwOtherError "no chatroom selected" + +getSelectedConversation :: CommandM Conversation +getSelectedConversation = gets csContext >>= getConversationFromContext + +getConversationFromContext :: CommandContext -> CommandM Conversation +getConversationFromContext = \case + SelectedPeer peer -> peerIdentity peer >>= \case + PeerIdentityFull pid -> directMessageConversation $ finalOwner pid + _ -> throwOtherError "incomplete peer identity" + SelectedContact contact -> case contactIdentity contact of + Just cid -> directMessageConversation cid + Nothing -> throwOtherError "contact without erebos identity" + SelectedChatroom rstate -> + chatroomConversation rstate >>= \case + Just conv -> return conv + Nothing -> throwOtherError "invalid chatroom" + SelectedConversation conv -> reloadConversation conv + _ -> throwOtherError "no contact, peer or conversation selected" + +getSelectedOrManualContext :: CommandM CommandContext +getSelectedOrManualContext = do + asks ciLine >>= \case + "" -> gets csContext + str | all isDigit str -> getContextByIndex (read str) + _ -> throwOtherError "invalid index" + +commands :: [(String, Command)] +commands = + [ ("history", cmdHistory) + , ("peers", cmdPeers) + , ("peer-add", cmdPeerAdd) + , ("peer-add-public", cmdPeerAddPublic) + , ("peer-drop", cmdPeerDrop) + , ("send", cmdSend) + , ("delete", cmdDelete) + , ("update-identity", cmdUpdateIdentity) + , ("attach", cmdAttach) + , ("attach-accept", cmdAttachAccept) + , ("attach-reject", cmdAttachReject) + , ("chatrooms", cmdChatrooms) + , ("chatroom-create-public", cmdChatroomCreatePublic) + , ("contacts", cmdContacts) + , ("contact-add", cmdContactAdd) + , ("contact-accept", cmdContactAccept) + , ("contact-reject", cmdContactReject) + , ("conversations", cmdConversations) + , ("details", cmdDetails) + , ("discovery-init", cmdDiscoveryInit) + , ("discovery", cmdDiscovery) +#ifdef ENABLE_ICE_SUPPORT + , ("ice-create", cmdIceCreate) + , ("ice-destroy", cmdIceDestroy) + , ("ice-show", cmdIceShow) + , ("ice-connect", cmdIceConnect) + , ("ice-send", cmdIceSend) +#endif + , ("join", cmdJoin) + , ("join-as", cmdJoinAs) + , ("leave", cmdLeave) + , ("members", cmdMembers) + , ("select", cmdSelectContext) + , ("quit", cmdQuit) + ] + +commandCompletion :: CompletionFunc IO +commandCompletion = completeWordWithPrev Nothing [ ' ', '\t', '\n', '\r' ] $ curry $ \case + ([], '/':pref) -> return . map (simpleCompletion . ('/':)) . filter (pref `isPrefixOf`) $ sortedCommandNames + _ -> return [] + where + sortedCommandNames = sort $ map fst commands + + +cmdPutStrLn :: String -> Command +cmdPutStrLn str = do + term <- asks ciTerminal + void $ liftIO $ printLine term str + +cmdUnknown :: String -> Command +cmdUnknown cmd = cmdPutStrLn $ "Unknown command: " ++ cmd + +cmdPeers :: Command +cmdPeers = do + peers <- join $ asks ciPeers + set <- asks ciSetContextOptions + set $ map (SelectedPeer . fst) peers + forM_ (zip [1..] peers) $ \(i :: Int, (_, name)) -> do + cmdPutStrLn $ "[" ++ show i ++ "] " ++ name + +cmdPeerAdd :: Command +cmdPeerAdd = void $ do + server <- asks ciServer + (hostname, port) <- (words <$> asks ciLine) >>= \case + hostname:p:_ -> return (hostname, p) + [hostname] -> return (hostname, show discoveryPort) + [] -> throwOtherError "missing peer address" + addr:_ <- liftIO $ getAddrInfo (Just $ defaultHints { addrSocketType = Datagram }) (Just hostname) (Just port) + liftIO $ serverPeer server (addrAddress addr) + +cmdPeerAddPublic :: Command +cmdPeerAddPublic = do + server <- asks ciServer + liftIO $ mapM_ (serverPeer server . addrAddress) =<< gather 'a' + where + gather c + | c <= 'z' = do + let hints = Just $ defaultHints { addrSocketType = Datagram } + hostname = Just $ c : ".discovery.erebosprotocol.net" + service = Just $ show discoveryPort + handle (\(_ :: IOException) -> return []) (getAddrInfo hints hostname service) >>= \case + addr : _ -> (addr :) <$> gather (succ c) + [] -> return [] + + | otherwise = do + return [] + +cmdPeerDrop :: Command +cmdPeerDrop = do + dropPeer =<< getSelectedPeer + modify $ \s -> s { csContext = NoContext } + +showPeer :: PeerIdentity -> PeerAddress -> String +showPeer pidentity paddr = + let name = case pidentity of + PeerIdentityUnknown _ -> "<noid>" + PeerIdentityRef wref _ -> "<" ++ BC.unpack (showRefDigest $ wrDigest wref) ++ ">" + PeerIdentityFull pid -> T.unpack $ displayIdentity pid + in name ++ " [" ++ show paddr ++ "]" + +cmdJoin :: Command +cmdJoin = joinChatroom =<< getSelectedChatroom + +cmdJoinAs :: Command +cmdJoinAs = do + name <- asks ciLine + st <- getStorage + identity <- liftIO $ createIdentity st (Just $ T.pack name) Nothing + joinChatroomAs identity =<< getSelectedChatroom + +cmdLeave :: Command +cmdLeave = leaveChatroom =<< getSelectedChatroom + +cmdMembers :: Command +cmdMembers = do + Just room <- findChatroomByStateData . head . roomStateData =<< getSelectedChatroom + forM_ (chatroomMembers room) $ \x -> do + cmdPutStrLn $ maybe "<unnamed>" T.unpack $ idName x + +getContextByIndex :: Int -> CommandM CommandContext +getContextByIndex n = do + join (asks ciContextOptions) >>= \ctxs -> if + | n > 0, (ctx : _) <- drop (n - 1) ctxs -> return ctx + | otherwise -> throwOtherError "invalid index" + +cmdSelectContext :: Command +cmdSelectContext = do + n <- read <$> asks ciLine + ctx <- getContextByIndex n + modify $ \s -> s { csContext = ctx } + case ctx of + SelectedChatroom rstate -> do + when (not (roomStateSubscribe rstate)) $ do + chatroomSetSubscribe (head $ roomStateData rstate) True + _ -> return () + +cmdSend :: Command +cmdSend = void $ do + text <- asks ciLine + conv <- getSelectedConversation + sendMessage conv (T.pack text) >>= \case + Just msg -> do + tzone <- liftIO $ getCurrentTimeZone + cmdPutStrLn $ formatMessage tzone msg + Nothing -> return () + +cmdDelete :: Command +cmdDelete = void $ do + deleteConversation =<< getConversationFromContext =<< getSelectedOrManualContext + modify $ \s -> s { csContext = NoContext } + +cmdHistory :: Command +cmdHistory = void $ do + conv <- getConversationFromContext =<< getSelectedOrManualContext + case conversationHistory conv of + thread@(_:_) -> do + tzone <- liftIO $ getCurrentTimeZone + mapM_ (cmdPutStrLn . formatMessage tzone) $ reverse $ take 50 thread + [] -> do + cmdPutStrLn $ "<empty history>" + +cmdUpdateIdentity :: Command +cmdUpdateIdentity = void $ do + term <- asks ciTerminal + runReaderT (updateSharedIdentity term) =<< gets csHead + +cmdAttach :: Command +cmdAttach = attachToOwner =<< getSelectedPeer + +cmdAttachAccept :: Command +cmdAttachAccept = attachAccept =<< getSelectedPeer + +cmdAttachReject :: Command +cmdAttachReject = attachReject =<< getSelectedPeer + +watchChatroomsForCli :: (String -> IO ()) -> Head LocalState -> MVar (Set ChatroomState) -> MVar [ CommandContext ] -> Maybe Int -> IO WatchedHead +watchChatroomsForCli eprint h chatroomSetVar contextVar autoSubscribe = do + subscribedNumVar <- newEmptyMVar + + let ctxUpdate updateType (idx :: Int) rstate = \case + SelectedChatroom rstate' : rest + | currentRoots <- filterAncestors (concatMap storedRoots $ roomStateData rstate) + , any ((`intersectsSorted` currentRoots) . storedRoots) $ roomStateData rstate' + -> do + eprint $ "[" <> show idx <> "] CHATROOM " <> updateType <> " " <> name + return (SelectedChatroom rstate : rest) + selected : rest + -> do + (selected : ) <$> ctxUpdate updateType (idx + 1) rstate rest + [] + -> do + eprint $ "[" <> show idx <> "] CHATROOM " <> updateType <> " " <> name + return [ SelectedChatroom rstate ] + where + name = maybe "<unnamed>" T.unpack $ roomName =<< roomStateRoom rstate + + watchChatrooms h $ \set -> \case + Nothing -> do + let chatroomList = filter (not . roomStateDeleted) $ fromSetBy (comparing roomStateData) set + (subscribed, notSubscribed) = partition roomStateSubscribe chatroomList + subscribedNum = length subscribed + + putMVar chatroomSetVar set + putMVar subscribedNumVar subscribedNum + + case autoSubscribe of + Nothing -> return () + Just num -> do + forM_ (take (num - subscribedNum) notSubscribed) $ \rstate -> do + (runExceptT $ flip runReaderT h $ chatroomSetSubscribe (head $ roomStateData rstate) True) >>= \case + Right () -> return () + Left err -> eprint (showErebosError err) + + Just diff -> do + modifyMVar_ chatroomSetVar $ return . const set + forM_ diff $ \case + AddedChatroom rstate -> do + modifyMVar_ contextVar $ ctxUpdate "NEW" 1 rstate + modifyMVar_ subscribedNumVar $ return . if roomStateSubscribe rstate then (+ 1) else id + + RemovedChatroom rstate -> do + modifyMVar_ contextVar $ ctxUpdate "DEL" 1 rstate + modifyMVar_ subscribedNumVar $ return . if roomStateSubscribe rstate then subtract 1 else id + + UpdatedChatroom oldroom rstate -> do + when (any ((\rsd -> not (null (rsdRoom rsd))) . fromStored) (roomStateData rstate)) $ do + modifyMVar_ contextVar $ ctxUpdate "UPD" 1 rstate + when (any (not . null . rsdMessages . fromStored) (roomStateData rstate)) $ do + tzone <- getCurrentTimeZone + forM_ (reverse $ getMessagesSinceState rstate oldroom) $ \msg -> do + eprint $ concat $ + [ maybe "<unnamed>" T.unpack $ roomName =<< cmsgRoom msg + , formatTime defaultTimeLocale " [%H:%M] " $ utcToLocalTime tzone $ zonedTimeToUTC $ cmsgTime msg + , maybe "<unnamed>" T.unpack $ idName $ cmsgFrom msg + , if cmsgLeave msg then " left" else "" + , maybe (if cmsgLeave msg then "" else " joined") ((": " ++) . T.unpack) $ cmsgText msg + ] + modifyMVar_ subscribedNumVar $ return + . (if roomStateSubscribe rstate then (+ 1) else id) + . (if roomStateSubscribe oldroom then subtract 1 else id) + +ensureWatchedChatrooms :: Command +ensureWatchedChatrooms = do + gets csWatchChatrooms >>= \case + Nothing -> do + eprint <- asks ciPrint + h <- gets csHead + chatroomSetVar <- asks ciChatroomSetVar + contextVar <- asks ciContextOptionsVar + autoSubscribe <- asks $ optChatroomAutoSubscribe . ciOptions + watched <- liftIO $ watchChatroomsForCli eprint h chatroomSetVar contextVar autoSubscribe + modify $ \s -> s { csWatchChatrooms = Just watched } + Just _ -> return () + +cmdChatrooms :: Command +cmdChatrooms = do + ensureWatchedChatrooms + chatroomSetVar <- asks ciChatroomSetVar + chatroomList <- filter (not . roomStateDeleted) . fromSetBy (comparing roomStateData) <$> liftIO (readMVar chatroomSetVar) + set <- asks ciSetContextOptions + set $ map SelectedChatroom chatroomList + forM_ (zip [1..] chatroomList) $ \(i :: Int, rstate) -> do + cmdPutStrLn $ "[" ++ show i ++ "] " ++ maybe "<unnamed>" T.unpack (roomName =<< roomStateRoom rstate) + +cmdChatroomCreatePublic :: Command +cmdChatroomCreatePublic = do + term <- asks ciTerminal + name <- asks ciLine >>= \case + line | not (null line) -> return $ T.pack line + _ -> liftIO $ do + setPrompt term "Name: " + getInputLine term $ KeepPrompt . maybe T.empty T.pack + + ensureWatchedChatrooms + void $ createChatroom + (if T.null name then Nothing else Just name) + Nothing + + +cmdContacts :: Command +cmdContacts = do + args <- words <$> asks ciLine + ehead <- gets csHead + let contacts = fromSetBy (comparing contactName) $ lookupSharedValue $ lsShared $ headObject ehead + verbose = "-v" `elem` args + set <- asks ciSetContextOptions + set $ map SelectedContact contacts + forM_ (zip [1..] contacts) $ \(i :: Int, c) -> do + cmdPutStrLn $ T.unpack $ T.concat + [ "[", T.pack (show i), "] ", contactName c + , case contactIdentity c of + Just idt | cname <- displayIdentity idt + , cname /= contactName c + -> " (" <> cname <> ")" + _ -> "" + , if verbose then " " <> (T.unwords $ map (T.decodeUtf8 . showRef . storedRef) $ maybe [] idDataF $ contactIdentity c) + else "" + ] + +cmdContactAdd :: Command +cmdContactAdd = contactRequest =<< getSelectedPeer + +cmdContactAccept :: Command +cmdContactAccept = contactAccept =<< getSelectedPeer + +cmdContactReject :: Command +cmdContactReject = contactReject =<< getSelectedPeer + +cmdConversations :: Command +cmdConversations = do + conversations <- lookupConversations + set <- asks ciSetContextOptions + set $ map SelectedConversation conversations + forM_ (zip [1..] conversations) $ \(i :: Int, conv) -> do + cmdPutStrLn $ "[" ++ show i ++ "] " ++ T.unpack (conversationName conv) + +cmdDetails :: Command +cmdDetails = do + getSelectedOrManualContext >>= \case + SelectedPeer peer -> do + cmdPutStrLn $ unlines + [ "Network peer:" + , " " <> show (peerAddress peer) + ] + peerIdentity peer >>= \case + PeerIdentityUnknown _ -> do + cmdPutStrLn $ "unknown identity" + PeerIdentityRef wref _ -> do + cmdPutStrLn $ "Identity ref:" + cmdPutStrLn $ " " <> BC.unpack (showRefDigest $ wrDigest wref) + PeerIdentityFull pid -> printContactOrIdentityDetails pid + + SelectedContact contact -> do + printContactDetails contact + + SelectedChatroom rstate -> do + cmdPutStrLn $ "Chatroom: " <> (T.unpack $ fromMaybe (T.pack "<unnamed>") $ roomName =<< roomStateRoom rstate) + + SelectedConversation conv -> do + case conversationPeer conv of + Just pid -> printContactOrIdentityDetails pid + Nothing -> cmdPutStrLn $ "(conversation without peer)" + + NoContext -> cmdPutStrLn "nothing selected" + where + printContactOrIdentityDetails cid = do + contacts <- fromSetBy (comparing contactName) . lookupSharedValue . lsShared . fromStored <$> getLocalHead + case find (maybe False (sameIdentity cid) . contactIdentity) contacts of + Just contact -> printContactDetails contact + Nothing -> printIdentityDetails cid + + printContactDetails contact = do + cmdPutStrLn $ "Contact:" + prefix <- case contactCustomName contact of + Just name -> do + cmdPutStrLn $ " " <> T.unpack name + return $ Just "alias of" + Nothing -> do + return $ Nothing + + case contactIdentity contact of + Just cid -> do + printIdentityDetailsBody prefix cid + Nothing -> do + cmdPutStrLn $ " (without erebos identity)" + + printIdentityDetails identity = do + cmdPutStrLn $ "Identity:" + printIdentityDetailsBody Nothing identity + + printIdentityDetailsBody prefix identity = do + forM_ (zip (False : repeat True) $ unfoldOwners identity) $ \(owned, cpid) -> do + cmdPutStrLn $ unwords $ concat + [ [ " " ] + , if owned then [ "owned by" ] else maybeToList prefix + , [ maybe "<unnamed>" T.unpack (idName cpid) ] + , map (BC.unpack . showRefDigest . refDigest . storedRef) $ idExtDataF cpid + ] + +cmdDiscoveryInit :: Command +cmdDiscoveryInit = void $ do + server <- asks ciServer + + (hostname, port) <- (words <$> asks ciLine) >>= return . \case + hostname:p:_ -> (hostname, p) + [hostname] -> (hostname, show discoveryPort) + [] -> ("discovery.erebosprotocol.net", show discoveryPort) + addr:_ <- liftIO $ getAddrInfo (Just $ defaultHints { addrSocketType = Datagram }) (Just hostname) (Just port) + peer <- liftIO $ serverPeer server (addrAddress addr) + sendToPeer peer $ DiscoverySelf [ T.pack "ICE" ] Nothing + modify $ \s -> s { csIcePeer = Just peer } + +cmdDiscovery :: Command +cmdDiscovery = void $ do + server <- asks ciServer + sref <- asks ciLine + case readRefDigest (BC.pack sref) of + Nothing -> throwOtherError "failed to parse ref" + Just dgst -> discoverySearch server dgst + +#ifdef ENABLE_ICE_SUPPORT + +cmdIceCreate :: Command +cmdIceCreate = do + let getRole = \case + 'm':_ -> PjIceSessRoleControlling + 's':_ -> PjIceSessRoleControlled + _ -> PjIceSessRoleUnknown + + ( role, stun, turn ) <- asks (words . ciLine) >>= \case + [] -> return ( PjIceSessRoleControlling, Nothing, Nothing ) + [ role ] -> return + ( getRole role, Nothing, Nothing ) + [ role, server ] -> return + ( getRole role + , Just ( T.pack server, 0 ) + , Just ( T.pack server, 0 ) + ) + [ role, server, port ] -> return + ( getRole role + , Just ( T.pack server, read port ) + , Just ( T.pack server, read port ) + ) + [ role, stunServer, stunPort, turnServer, turnPort ] -> return + ( getRole role + , Just ( T.pack stunServer, read stunPort ) + , Just ( T.pack turnServer, read turnPort ) + ) + _ -> throwOtherError "invalid parameters" + + eprint <- asks ciPrint + Just cfg <- liftIO $ iceCreateConfig stun turn + sess <- liftIO $ iceCreateSession cfg role $ eprint <=< iceShow + modify $ \s -> s { csIceSessions = sess : csIceSessions s } + +cmdIceDestroy :: Command +cmdIceDestroy = do + s:ss <- gets csIceSessions + modify $ \st -> st { csIceSessions = ss } + liftIO $ iceDestroy s + +cmdIceShow :: Command +cmdIceShow = do + sess <- gets csIceSessions + eprint <- asks ciPrint + liftIO $ forM_ (zip [1::Int ..] sess) $ \(i, s) -> do + eprint $ "[" ++ show i ++ "]" + eprint =<< iceShow s + +cmdIceConnect :: Command +cmdIceConnect = do + s:_ <- gets csIceSessions + server <- asks ciServer + term <- asks ciTerminal + let loadInfo = + getInputLine term (KeepPrompt . maybe BC.empty BC.pack) >>= \case + line | BC.null line -> return [] + | otherwise -> (line :) <$> loadInfo + Right remote <- liftIO $ do + st <- memoryStorage + pst <- derivePartialStorage st + setPrompt term "" + rbytes <- (BL.fromStrict . BC.unlines) <$> loadInfo + copyRef st =<< storeRawBytes pst (BL.fromChunks [ BC.pack "rec ", BC.pack (show (BL.length rbytes)), BC.singleton '\n' ] `BL.append` rbytes) + liftIO $ iceConnect s (load remote) $ void $ serverPeerIce server s + +cmdIceSend :: Command +cmdIceSend = void $ do + s:_ <- gets csIceSessions + server <- asks ciServer + liftIO $ serverPeerIce server s + +#endif + +cmdQuit :: Command +cmdQuit = modify $ \s -> s { csQuit = True } + + +intersectsSorted :: Ord a => [a] -> [a] -> Bool +intersectsSorted (x:xs) (y:ys) | x < y = intersectsSorted xs (y:ys) + | x > y = intersectsSorted (x:xs) ys + | otherwise = True +intersectsSorted _ _ = False diff --git a/main/State.hs b/main/State.hs new file mode 100644 index 0000000..150178e --- /dev/null +++ b/main/State.hs @@ -0,0 +1,80 @@ +module State ( + loadLocalStateHead, + updateSharedIdentity, + interactiveIdentityUpdate, +) where + +import Control.Monad.Except +import Control.Monad.IO.Class + +import Data.Foldable +import Data.Maybe +import Data.Proxy +import Data.Text qualified as T + +import Erebos.Error +import Erebos.Identity +import Erebos.PubKey +import Erebos.State +import Erebos.Storable +import Erebos.Storage + +import Terminal + + +loadLocalStateHead :: MonadIO m => Terminal -> Storage -> m (Head LocalState) +loadLocalStateHead term st = loadHeads st >>= \case + (h:_) -> return h + [] -> liftIO $ do + setPrompt term "Name: " + name <- getInputLine term $ KeepPrompt . maybe T.empty T.pack + + setPrompt term "Device: " + devName <- getInputLine term $ KeepPrompt . maybe T.empty T.pack + + owner <- if + | T.null name -> return Nothing + | otherwise -> Just <$> createIdentity st (Just name) Nothing + + identity <- createIdentity st (if T.null devName then Nothing else Just devName) owner + + shared <- wrappedStore st $ SharedState + { ssPrev = [] + , ssType = Just $ sharedTypeID @(Maybe ComposedIdentity) Proxy + , ssValue = [ storedRef $ idExtData $ fromMaybe identity owner ] + } + storeHead st $ LocalState + { lsPrev = Nothing + , lsIdentity = idExtData identity + , lsShared = [ shared ] + , lsOther = [] + } + + +updateSharedIdentity :: (MonadHead LocalState m, MonadError e m, FromErebosError e) => Terminal -> m () +updateSharedIdentity term = updateLocalState_ $ updateSharedState_ $ \case + Just identity -> do + Just . toComposedIdentity <$> interactiveIdentityUpdate term identity + Nothing -> throwOtherError "no existing shared identity" + +interactiveIdentityUpdate :: (Foldable f, MonadStorage m, MonadIO m, MonadError e m, FromErebosError e) => Terminal -> Identity f -> m UnifiedIdentity +interactiveIdentityUpdate term fidentity = do + identity <- mergeIdentity fidentity + name <- liftIO $ do + setPrompt term $ T.unpack $ T.concat $ concat + [ [ T.pack "Name" ] + , case idName identity of + Just name -> [T.pack " [", name, T.pack "]"] + Nothing -> [] + , [ T.pack ": " ] + ] + getInputLine term $ KeepPrompt . maybe T.empty T.pack + + if | T.null name -> return identity + | otherwise -> do + secret <- loadKey $ idKeyIdentity identity + maybe (throwOtherError "created invalid identity") return . validateExtendedIdentity =<< + mstore =<< sign secret =<< mstore . ExtendedIdentityData =<< return (emptyIdentityExtension $ idData identity) + { idePrev = toList $ idExtDataF identity + , ideName = Just name + } diff --git a/main/Terminal.hs b/main/Terminal.hs new file mode 100644 index 0000000..150bd8c --- /dev/null +++ b/main/Terminal.hs @@ -0,0 +1,347 @@ +{-# LANGUAGE CPP #-} + +module Terminal ( + Terminal, + hasTerminalUI, + withTerminal, + setPrompt, + getInputLine, + InputHandling(..), + + TerminalLine, + printLine, + + printBottomLines, + clearBottomLines, + + CompletionFunc, Completion, + noCompletion, + simpleCompletion, + completeWordWithPrev, +) where + +import Control.Arrow +import Control.Concurrent +import Control.Concurrent.STM +import Control.Exception +import Control.Monad + +import Data.Char +import Data.List +import Data.Text (Text) +import Data.Text qualified as T + +import System.Console.ANSI +import System.IO +import System.IO.Error + + +data Terminal = Terminal + { termLock :: MVar () + , termAnsi :: Bool + , termCompletionFunc :: CompletionFunc IO + , termPrompt :: TVar String + , termShowPrompt :: TVar Bool + , termInput :: TVar ( String, String ) + , termBottomLines :: TVar [ String ] + } + +data TerminalLine = TerminalLine + { tlTerminal :: Terminal + } + +data Input + = InputChar Char + | InputMoveRight + | InputMoveLeft + | InputMoveEnd + | InputMoveStart + | InputBackspace + | InputClear + | InputBackWord + | InputEnd + | InputEscape String + deriving (Eq, Ord, Show) + + +data InputHandling a + = KeepPrompt a + | ErasePrompt a + + +hasTerminalUI :: Terminal -> Bool +hasTerminalUI = termAnsi + +initTerminal :: CompletionFunc IO -> IO Terminal +initTerminal termCompletionFunc = do + termLock <- newMVar () +#if MIN_VERSION_ansi_terminal(1, 0, 1) + termAnsi <- hNowSupportsANSI stdout +#else + termAnsi <- hSupportsANSI stdout +#endif + termPrompt <- newTVarIO "" + termShowPrompt <- newTVarIO False + termInput <- newTVarIO ( "", "" ) + termBottomLines <- newTVarIO [] + return Terminal {..} + +bracketSet :: IO a -> (a -> IO b) -> a -> IO c -> IO c +bracketSet get set val = bracket (get <* set val) set . const + +withTerminal :: CompletionFunc IO -> (Terminal -> IO a) -> IO a +withTerminal compl act = do + term <- initTerminal compl + + bracketSet (hGetEcho stdin) (hSetEcho stdin) False $ + bracketSet (hGetBuffering stdin) (hSetBuffering stdin) NoBuffering $ + bracketSet (hGetBuffering stdout) (hSetBuffering stdout) (BlockBuffering Nothing) $ + act term + + +termPutStr :: Terminal -> String -> IO () +termPutStr Terminal {..} str = do + withMVar termLock $ \_ -> do + putStr str + hFlush stdout + + +getInput :: IO Input +getInput = do + handleJust (guard . isEOFError) (\() -> return InputEnd) $ getChar >>= \case + '\ESC' -> do + esc <- readEsc + case parseEsc esc of + Just ( 'C' , [] ) -> return InputMoveRight + Just ( 'D' , [] ) -> return InputMoveLeft + _ -> return (InputEscape esc) + '\b' -> return InputBackspace + '\DEL' -> return InputBackspace + '\NAK' -> return InputClear + '\ETB' -> return InputBackWord + '\SOH' -> return InputMoveStart + '\ENQ' -> return InputMoveEnd + '\EOT' -> return InputEnd + c -> return (InputChar c) + where + readEsc = getChar >>= \case + c | c == '\ESC' || isAlpha c -> return [ c ] + | otherwise -> (c :) <$> readEsc + + parseEsc = \case + '[' : c : [] -> do + Just ( c, [] ) + _ -> Nothing + + +getInputLine :: Terminal -> (Maybe String -> InputHandling a) -> IO a +getInputLine term@Terminal {..} handleResult = do + withMVar termLock $ \_ -> do + prompt <- atomically $ do + writeTVar termShowPrompt True + readTVar termPrompt + putStr $ prompt <> "\ESC[K" + drawBottomLines term + hFlush stdout + (handleResult <$> go) >>= \case + KeepPrompt x -> do + termPutStr term "\n\ESC[J" + return x + ErasePrompt x -> do + termPutStr term "\r\ESC[J" + return x + where + go = getInput >>= \case + InputChar '\n' -> do + atomically $ do + ( pre, post ) <- readTVar termInput + writeTVar termInput ( "", "" ) + writeTVar termShowPrompt False + writeTVar termBottomLines [] + return $ Just $ pre ++ post + + InputChar '\t' -> do + options <- withMVar termLock $ const $ do + ( pre, post ) <- atomically $ readTVar termInput + let updatePrompt pre' = do + prompt <- atomically $ do + writeTVar termInput ( pre', post ) + getCurrentPromptLine term + putStr $ "\r" <> prompt + hFlush stdout + + termCompletionFunc ( T.pack pre, T.pack post ) >>= \case + + ( unused, [ compl ] ) -> do + updatePrompt $ T.unpack unused ++ T.unpack (replacement compl) ++ if isFinished compl then " " else "" + return [] + + ( unused, completions@(c : cs) ) -> do + let commonPrefixes' x y = fmap (\( common, _, _ ) -> common) $ T.commonPrefixes x y + case foldl' (\mbcommon cur -> commonPrefixes' cur =<< mbcommon) (Just $ replacement c) (fmap replacement cs) of + Just common -> updatePrompt $ T.unpack unused ++ T.unpack common + Nothing -> return () + return $ map replacement completions + + ( _, [] ) -> do + return [] + + printBottomLines term $ T.unpack $ T.unlines options + go + + InputChar c | isPrint c -> withInput $ \case + ( _, post ) -> do + writeTVar termInput . first (++ [ c ]) =<< readTVar termInput + return $ c : (if null post then "" else "\ESC[s" <> post <> "\ESC[u") + + InputChar _ -> go + + InputMoveRight -> withInput $ \case + ( pre, c : post ) -> do + writeTVar termInput ( pre ++ [ c ], post ) + return $ "\ESC[C" + _ -> return "" + + InputMoveLeft -> withInput $ \case + ( pre@(_ : _), post ) -> do + writeTVar termInput ( init pre, last pre : post ) + return $ "\ESC[D" + _ -> return "" + + InputBackspace -> withInput $ \case + ( pre@(_ : _), post ) -> do + writeTVar termInput ( init pre, post ) + return $ "\b\ESC[K" <> (if null post then "" else "\ESC[s" <> post <> "\ESC[u") + _ -> return "" + + InputClear -> withInput $ \_ -> do + writeTVar termInput ( "", "" ) + ("\r\ESC[K" <>) <$> getCurrentPromptLine term + + InputBackWord -> withInput $ \( pre, post ) -> do + let pre' = reverse $ dropWhile (not . isSpace) $ dropWhile isSpace $ reverse pre + writeTVar termInput ( pre', post ) + ("\r\ESC[K" <>) <$> getCurrentPromptLine term + + InputMoveStart -> withInput $ \( pre, post ) -> do + writeTVar termInput ( "", pre <> post ) + return $ "\ESC[" <> show (length pre) <> "D" + + InputMoveEnd -> withInput $ \( pre, post ) -> do + writeTVar termInput ( pre <> post, "" ) + return $ "\ESC[" <> show (length post) <> "C" + + InputEnd -> do + atomically (readTVar termInput) >>= \case + ( "", "" ) -> return Nothing + _ -> go + + InputEscape _ -> go + + withInput f = do + withMVar termLock $ const $ do + str <- atomically $ f =<< readTVar termInput + when (not $ null str) $ do + putStr str + hFlush stdout + go + + +getCurrentPromptLine :: Terminal -> STM String +getCurrentPromptLine Terminal {..} = do + prompt <- readTVar termPrompt + ( pre, post ) <- readTVar termInput + return $ prompt <> pre <> "\ESC[s" <> post <> "\ESC[u" + +setPrompt :: Terminal -> String -> IO () +setPrompt term@Terminal {..} prompt = do + withMVar termLock $ \_ -> do + join $ atomically $ do + writeTVar termPrompt prompt + readTVar termShowPrompt >>= \case + True -> do + promptLine <- getCurrentPromptLine term + return $ do + putStr $ "\r\ESC[K" <> promptLine + hFlush stdout + False -> return $ return () + +printLine :: Terminal -> String -> IO TerminalLine +printLine tlTerminal@Terminal {..} str = do + withMVar termLock $ \_ -> do + promptLine <- atomically $ do + readTVar termShowPrompt >>= \case + True -> getCurrentPromptLine tlTerminal + False -> return "" + putStr $ "\r\ESC[K" <> str <> "\n\ESC[K" <> promptLine + drawBottomLines tlTerminal + hFlush stdout + return TerminalLine {..} + + +printBottomLines :: Terminal -> String -> IO () +printBottomLines term@Terminal {..} str = do + case lines str of + [] -> clearBottomLines term + blines -> do + withMVar termLock $ \_ -> do + atomically $ writeTVar termBottomLines blines + drawBottomLines term + hFlush stdout + +clearBottomLines :: Terminal -> IO () +clearBottomLines Terminal {..} = do + withMVar termLock $ \_ -> do + atomically (readTVar termBottomLines) >>= \case + [] -> return () + _:_ -> do + atomically $ writeTVar termBottomLines [] + putStr $ "\ESC[s\n\ESC[J\ESC[u" + hFlush stdout + +drawBottomLines :: Terminal -> IO () +drawBottomLines Terminal {..} = do + atomically (readTVar termBottomLines) >>= \case + blines@( firstLine : otherLines ) -> do + ( shift ) <- atomically $ do + readTVar termShowPrompt >>= \case + True -> do + prompt <- readTVar termPrompt + ( pre, _ ) <- readTVar termInput + return (displayWidth (prompt <> pre) + 1) + False -> do + return 0 + putStr $ concat + [ "\n\ESC[J", firstLine, concat (map ('\n' :) otherLines) + , "\ESC[", show (length blines), "F" + , "\ESC[", show shift, "G" + ] + [] -> return () + + +displayWidth :: String -> Int +displayWidth = \case + ('\ESC' : '[' : rest) -> displayWidth $ drop 1 $ dropWhile (not . isAlpha) rest + ('\ESC' : _ : rest) -> displayWidth rest + (_ : rest) -> 1 + displayWidth rest + [] -> 0 + + +type CompletionFunc m = ( Text, Text ) -> m ( Text, [ Completion ] ) + +data Completion = Completion + { replacement :: Text + , isFinished :: Bool + } + +noCompletion :: Monad m => CompletionFunc m +noCompletion ( l, _ ) = return ( l, [] ) + +completeWordWithPrev :: Monad m => Maybe Char -> [ Char ] -> (String -> String -> m [ Completion ]) -> CompletionFunc m +completeWordWithPrev _ spaceChars fun ( l, _ ) = do + let lastSpaceIndex = snd $ T.foldl' (\( i, found ) c -> if c `elem` spaceChars then ( i + 1, i ) else ( i + 1, found )) ( 1, 0 ) l + let ( pre, word ) = T.splitAt lastSpaceIndex l + ( pre, ) <$> fun (T.unpack pre) (T.unpack word) + +simpleCompletion :: String -> Completion +simpleCompletion str = Completion (T.pack str) True diff --git a/main/Test.hs b/main/Test.hs new file mode 100644 index 0000000..b07bd87 --- /dev/null +++ b/main/Test.hs @@ -0,0 +1,985 @@ +{-# LANGUAGE OverloadedStrings #-} + +module Test ( + runTestTool, +) where + +import Control.Arrow +import Control.Concurrent +import Control.Exception +import Control.Monad +import Control.Monad.Except +import Control.Monad.Reader +import Control.Monad.State + +import Crypto.Random + +import Data.Bool +import Data.ByteString qualified as B +import Data.ByteString.Char8 qualified as BC +import Data.ByteString.Lazy qualified as BL +import Data.ByteString.Lazy.Char8 qualified as BL +import Data.Foldable +import Data.Ord +import Data.Text (Text) +import Data.Text qualified as T +import Data.Text.Encoding +import Data.Text.IO qualified as T +import Data.Typeable +import Data.UUID.Types qualified as U + +import Network.Socket + +import System.IO +import System.IO.Error + +import Erebos.Attach +import Erebos.Chatroom +import Erebos.Contact +import Erebos.DirectMessage +import Erebos.Discovery +import Erebos.Identity +import Erebos.Network +import Erebos.Object +import Erebos.Pairing +import Erebos.PubKey +import Erebos.Service +import Erebos.Service.Stream +import Erebos.Set +import Erebos.State +import Erebos.Storable +import Erebos.Storage +import Erebos.Storage.Head +import Erebos.Storage.Merge +import Erebos.Sync + +import Test.Service + + +data TestState = TestState + { tsHead :: Maybe (Head LocalState) + , tsServer :: Maybe RunningServer + , tsWatchedHeads :: [ ( Int, WatchedHead ) ] + , tsWatchedHeadNext :: Int + , tsWatchedLocalIdentity :: Maybe WatchedHead + , tsWatchedSharedIdentity :: Maybe WatchedHead + } + +data RunningServer = RunningServer + { rsServer :: Server + , rsPeers :: MVar ( Int, [ TestPeer ] ) + , rsPeerThread :: ThreadId + } + +data TestPeer = TestPeer + { tpIndex :: Int + , tpPeer :: Peer + , tpStreamReaders :: MVar [ (Int, StreamReader ) ] + , tpStreamWriters :: MVar [ (Int, StreamWriter ) ] + } + +initTestState :: TestState +initTestState = TestState + { tsHead = Nothing + , tsServer = Nothing + , tsWatchedHeads = [] + , tsWatchedHeadNext = 1 + , tsWatchedLocalIdentity = Nothing + , tsWatchedSharedIdentity = Nothing + } + +data TestInput = TestInput + { tiOutput :: Output + , tiStorage :: Storage + , tiParams :: [Text] + } + + +runTestTool :: Storage -> IO () +runTestTool st = do + out <- newMVar () + let testLoop = getLineMb >>= \case + Just line -> do + case T.words line of + (cname:params) + | Just (CommandM cmd) <- lookup cname commands -> do + runReaderT cmd $ TestInput out st params + | otherwise -> fail $ "Unknown command '" ++ T.unpack cname ++ "'" + [] -> return () + testLoop + + Nothing -> return () + + runExceptT (evalStateT testLoop initTestState) >>= \case + Left x -> B.hPutStr stderr $ (`BC.snoc` '\n') $ BC.pack (showErebosError x) + Right () -> return () + +getLineMb :: MonadIO m => m (Maybe Text) +getLineMb = liftIO $ catchIOError (Just <$> T.getLine) (\e -> if isEOFError e then return Nothing else ioError e) + +getLines :: MonadIO m => m [Text] +getLines = getLineMb >>= \case + Just line | not (T.null line) -> (line:) <$> getLines + _ -> return [] + +getHead :: CommandM (Head LocalState) +getHead = do + h <- maybe (fail "failed to reload head") return =<< maybe (fail "no current head") reloadHead =<< gets tsHead + modify $ \s -> s { tsHead = Just h } + return h + + +type Output = MVar () + +outLine :: Output -> String -> IO () +outLine mvar line = do + evaluate $ foldl' (flip seq) () line + withMVar mvar $ \() -> do + B.putStr $ (`BC.snoc` '\n') $ BC.pack line + hFlush stdout + +cmdOut :: String -> Command +cmdOut line = do + out <- asks tiOutput + liftIO $ outLine out line + + +getPeer :: Text -> CommandM Peer +getPeer spidx = tpPeer <$> getTestPeer spidx + +getTestPeer :: Text -> CommandM TestPeer +getTestPeer spidx = do + Just RunningServer {..} <- gets tsServer + Just peer <- find (((read $ T.unpack spidx) ==) . tpIndex) . snd <$> liftIO (readMVar rsPeers) + return peer + +getPeerIndex :: MVar ( Int, [ TestPeer ] ) -> ServiceHandler s Int +getPeerIndex pmvar = do + peer <- asks svcPeer + maybe 0 tpIndex . find ((peer ==) . tpPeer) . snd <$> liftIO (readMVar pmvar) + +pairingAttributes :: PairingResult a => proxy (PairingService a) -> Output -> MVar ( Int, [ TestPeer ] ) -> String -> PairingAttributes a +pairingAttributes _ out peers prefix = PairingAttributes + { pairingHookRequest = return () + + , pairingHookResponse = \confirm -> do + index <- show <$> getPeerIndex peers + afterCommit $ outLine out $ unwords [prefix ++ "-response", index, confirm] + + , pairingHookRequestNonce = \confirm -> do + index <- show <$> getPeerIndex peers + afterCommit $ outLine out $ unwords [prefix ++ "-request", index, confirm] + + , pairingHookRequestNonceFailed = failed "nonce" + + , pairingHookConfirmedResponse = return () + , pairingHookConfirmedRequest = return () + + , pairingHookAcceptedResponse = do + index <- show <$> getPeerIndex peers + afterCommit $ outLine out $ unwords [prefix ++ "-response-done", index] + + , pairingHookAcceptedRequest = do + index <- show <$> getPeerIndex peers + afterCommit $ outLine out $ unwords [prefix ++ "-request-done", index] + + , pairingHookFailed = \case + PairingUserRejected -> failed "user" + PairingUnexpectedMessage pstate packet -> failed $ "unexpected " ++ strState pstate ++ " " ++ strPacket packet + PairingFailedOther err -> failed $ "other " ++ showErebosError err + , pairingHookVerifyFailed = failed "verify" + , pairingHookRejected = failed "rejected" + } + where + failed :: PairingResult a => String -> ServiceHandler (PairingService a) () + failed detail = do + ptype <- svcGet >>= return . \case + OurRequest {} -> "response" + OurRequestConfirm {} -> "response" + OurRequestReady -> "response" + PeerRequest {} -> "request" + PeerRequestConfirm -> "request" + _ -> fail "unexpected pairing state" + + index <- show <$> getPeerIndex peers + afterCommit $ outLine out $ prefix ++ "-" ++ ptype ++ "-failed " ++ index ++ " " ++ detail + + strState :: PairingState a -> String + strState = \case + NoPairing -> "none" + OurRequest {} -> "our-request" + OurRequestConfirm {} -> "our-request-confirm" + OurRequestReady -> "our-request-ready" + PeerRequest {} -> "peer-request" + PeerRequestConfirm -> "peer-request-confirm" + PairingDone -> "done" + + strPacket :: PairingService a -> String + strPacket = \case + PairingRequest {} -> "request" + PairingResponse {} -> "response" + PairingRequestNonce {} -> "nonce" + PairingAccept {} -> "accept" + PairingReject -> "reject" + +directMessageAttributes :: Output -> DirectMessageAttributes +directMessageAttributes out = DirectMessageAttributes + { dmOwnerMismatch = afterCommit $ outLine out "dm-owner-mismatch" + } + +discoveryAttributes :: DiscoveryAttributes +discoveryAttributes = (defaultServiceAttributes Proxy) + { discoveryProvideTunnel = const True + } + +dmReceivedWatcher :: Output -> Stored DirectMessage -> IO () +dmReceivedWatcher out smsg = do + let msg = fromStored smsg + outLine out $ unwords + [ "dm-received" + , "from", maybe "<unnamed>" T.unpack $ idName $ msgFrom msg + , "text", T.unpack $ msgText msg + ] + + +newtype CommandM a = CommandM (ReaderT TestInput (StateT TestState (ExceptT ErebosError IO)) a) + deriving (Functor, Applicative, Monad, MonadIO, MonadReader TestInput, MonadState TestState, MonadError ErebosError) + +instance MonadFail CommandM where + fail = throwOtherError + +instance MonadRandom CommandM where + getRandomBytes = liftIO . getRandomBytes + +instance MonadStorage CommandM where + getStorage = asks tiStorage + +instance MonadHead LocalState CommandM where + updateLocalHead f = do + Just h <- gets tsHead + (Just h', x) <- maybe (fail "failed to reload head") (flip updateHead f) =<< reloadHead h + modify $ \s -> s { tsHead = Just h' } + return x + +type Command = CommandM () + +commands :: [(Text, Command)] +commands = map (T.pack *** id) + [ ("store", cmdStore) + , ("load", cmdLoad) + , ("stored-generation", cmdStoredGeneration) + , ("stored-roots", cmdStoredRoots) + , ("stored-set-add", cmdStoredSetAdd) + , ("stored-set-list", cmdStoredSetList) + , ("head-create", cmdHeadCreate) + , ("head-replace", cmdHeadReplace) + , ("head-watch", cmdHeadWatch) + , ("head-unwatch", cmdHeadUnwatch) + , ("create-identity", cmdCreateIdentity) + , ("identity-info", cmdIdentityInfo) + , ("start-server", cmdStartServer) + , ("stop-server", cmdStopServer) + , ("peer-add", cmdPeerAdd) + , ("peer-drop", cmdPeerDrop) + , ("peer-list", cmdPeerList) + , ("test-message-send", cmdTestMessageSend) + , ("test-stream-open", cmdTestStreamOpen) + , ("test-stream-close", cmdTestStreamClose) + , ("test-stream-send", cmdTestStreamSend) + , ("local-state-get", cmdLocalStateGet) + , ("local-state-replace", cmdLocalStateReplace) + , ("local-state-wait", cmdLocalStateWait) + , ("shared-state-get", cmdSharedStateGet) + , ("shared-state-wait", cmdSharedStateWait) + , ("watch-local-identity", cmdWatchLocalIdentity) + , ("watch-shared-identity", cmdWatchSharedIdentity) + , ("update-local-identity", cmdUpdateLocalIdentity) + , ("update-shared-identity", cmdUpdateSharedIdentity) + , ("attach-to", cmdAttachTo) + , ("attach-accept", cmdAttachAccept) + , ("attach-reject", cmdAttachReject) + , ("contact-request", cmdContactRequest) + , ("contact-accept", cmdContactAccept) + , ("contact-reject", cmdContactReject) + , ("contact-list", cmdContactList) + , ("contact-set-name", cmdContactSetName) + , ("dm-send-peer", cmdDmSendPeer) + , ("dm-send-contact", cmdDmSendContact) + , ("dm-send-identity", cmdDmSendIdentity) + , ("dm-list-peer", cmdDmListPeer) + , ("dm-list-contact", cmdDmListContact) + , ("chatroom-create", cmdChatroomCreate) + , ("chatroom-delete", cmdChatroomDelete) + , ("chatroom-list-local", cmdChatroomListLocal) + , ("chatroom-watch-local", cmdChatroomWatchLocal) + , ("chatroom-set-name", cmdChatroomSetName) + , ("chatroom-subscribe", cmdChatroomSubscribe) + , ("chatroom-unsubscribe", cmdChatroomUnsubscribe) + , ("chatroom-members", cmdChatroomMembers) + , ("chatroom-join", cmdChatroomJoin) + , ("chatroom-join-as", cmdChatroomJoinAs) + , ("chatroom-leave", cmdChatroomLeave) + , ("chatroom-message-send", cmdChatroomMessageSend) + , ("discovery-connect", cmdDiscoveryConnect) + , ("discovery-tunnel", cmdDiscoveryTunnel) + ] + +cmdStore :: Command +cmdStore = do + st <- asks tiStorage + pst <- liftIO $ derivePartialStorage st + [otype] <- asks tiParams + ls <- getLines + + let cnt = encodeUtf8 $ T.unlines ls + full = BL.fromChunks + [ encodeUtf8 otype + , BC.singleton ' ' + , BC.pack (show $ B.length cnt) + , BC.singleton '\n', cnt + ] + liftIO (copyRef st =<< storeRawBytes pst full) >>= \case + Right ref -> cmdOut $ "store-done " ++ show (refDigest ref) + Left _ -> cmdOut $ "store-failed" + +cmdLoad :: Command +cmdLoad = do + st <- asks tiStorage + [ tref ] <- asks tiParams + Just ref <- liftIO $ readRef st $ encodeUtf8 tref + let obj = load @Object ref + header : content <- return $ BL.lines $ serializeObject obj + cmdOut $ "load-type " <> T.unpack (decodeUtf8 $ BL.toStrict header) + forM_ content $ \line -> do + cmdOut $ "load-line " <> T.unpack (decodeUtf8 $ BL.toStrict line) + cmdOut "load-done" + +cmdStoredGeneration :: Command +cmdStoredGeneration = do + st <- asks tiStorage + [tref] <- asks tiParams + Just ref <- liftIO $ readRef st (encodeUtf8 tref) + cmdOut $ "stored-generation " ++ T.unpack tref ++ " " ++ showGeneration (storedGeneration $ wrappedLoad @Object ref) + +cmdStoredRoots :: Command +cmdStoredRoots = do + st <- asks tiStorage + [tref] <- asks tiParams + Just ref <- liftIO $ readRef st (encodeUtf8 tref) + cmdOut $ "stored-roots " ++ T.unpack tref ++ concatMap ((' ':) . show . refDigest . storedRef) (storedRoots $ wrappedLoad @Object ref) + +cmdStoredSetAdd :: Command +cmdStoredSetAdd = do + st <- asks tiStorage + (item, set) <- asks tiParams >>= liftIO . mapM (readRef st . encodeUtf8) >>= \case + [Just iref, Just sref] -> return (wrappedLoad iref, loadSet @[Stored Object] sref) + [Just iref] -> return (wrappedLoad iref, emptySet) + _ -> fail "unexpected parameters" + set' <- storeSetAdd st [item] set + cmdOut $ "stored-set-add" ++ concatMap ((' ':) . show . refDigest . storedRef) (toComponents set') + +cmdStoredSetList :: Command +cmdStoredSetList = do + st <- asks tiStorage + [tref] <- asks tiParams + Just ref <- liftIO $ readRef st (encodeUtf8 tref) + let items = fromSetBy compare $ loadSet @[Stored Object] ref + forM_ items $ \item -> do + cmdOut $ "stored-set-item" ++ concatMap ((' ':) . show . refDigest . storedRef) item + cmdOut $ "stored-set-done" + +cmdHeadCreate :: Command +cmdHeadCreate = do + [ ttid, tref ] <- asks tiParams + st <- asks tiStorage + Just tid <- return $ fromUUID <$> U.fromText ttid + Just ref <- liftIO $ readRef st (encodeUtf8 tref) + + h <- storeHeadRaw st tid ref + cmdOut $ unwords $ [ "head-create-done", show (toUUID tid), show (toUUID h) ] + +cmdHeadReplace :: Command +cmdHeadReplace = do + [ ttid, thid, told, tnew ] <- asks tiParams + st <- asks tiStorage + Just tid <- return $ fmap fromUUID $ U.fromText ttid + Just hid <- return $ fmap fromUUID $ U.fromText thid + Just old <- liftIO $ readRef st (encodeUtf8 told) + Just new <- liftIO $ readRef st (encodeUtf8 tnew) + + replaceHeadRaw st tid hid old new >>= cmdOut . unwords . \case + Left Nothing -> [ "head-replace-fail", T.unpack ttid, T.unpack thid, T.unpack told, T.unpack tnew ] + Left (Just r) -> [ "head-replace-fail", T.unpack ttid, T.unpack thid, T.unpack told, T.unpack tnew, show (refDigest r) ] + Right _ -> [ "head-replace-done", T.unpack ttid, T.unpack thid, T.unpack told, T.unpack tnew ] + +cmdHeadWatch :: Command +cmdHeadWatch = do + [ ttid, thid ] <- asks tiParams + st <- asks tiStorage + Just tid <- return $ fmap fromUUID $ U.fromText ttid + Just hid <- return $ fmap fromUUID $ U.fromText thid + + out <- asks tiOutput + wid <- gets tsWatchedHeadNext + + watched <- liftIO $ watchHeadRaw st tid hid id $ \r -> do + outLine out $ unwords [ "head-watch-cb", show wid, show $ refDigest r ] + + modify $ \s -> s + { tsWatchedHeads = ( wid, watched ) : tsWatchedHeads s + , tsWatchedHeadNext = wid + 1 + } + + cmdOut $ unwords $ [ "head-watch-done", T.unpack ttid, T.unpack thid, show wid ] + +cmdHeadUnwatch :: Command +cmdHeadUnwatch = do + [ twid ] <- asks tiParams + let wid = read (T.unpack twid) + Just watched <- lookup wid <$> gets tsWatchedHeads + liftIO $ unwatchHead watched + cmdOut $ unwords [ "head-unwatch-done", show wid ] + +initTestHead :: Head LocalState -> Command +initTestHead h = do + _ <- liftIO . watchReceivedMessages h . dmReceivedWatcher =<< asks tiOutput + modify $ \s -> s { tsHead = Just h } + +loadTestHead :: CommandM (Head LocalState) +loadTestHead = do + st <- asks tiStorage + h <- loadHeads st >>= \case + h : _ -> return h + [] -> fail "no local head found" + initTestHead h + return h + +getOrLoadHead :: CommandM (Head LocalState) +getOrLoadHead = do + gets tsHead >>= \case + Just h -> return h + Nothing -> loadTestHead + +cmdCreateIdentity :: Command +cmdCreateIdentity = do + st <- asks tiStorage + names <- asks tiParams + + h <- liftIO $ do + Just identity <- if null names + then Just <$> createIdentity st Nothing Nothing + else foldrM (\n o -> Just <$> createIdentity st (Just n) o) Nothing names + + shared <- case names of + _:_:_ -> (:[]) <$> makeSharedStateUpdate st (Just $ finalOwner identity) [] + _ -> return [] + + storeHead st $ LocalState + { lsPrev = Nothing + , lsIdentity = idExtData identity + , lsShared = shared + , lsOther = [] + } + initTestHead h + cmdOut $ unwords [ "create-identity-done", "ref", show $ refDigest $ storedRef $ lsIdentity $ headObject h ] + +cmdIdentityInfo :: Command +cmdIdentityInfo = do + st <- asks tiStorage + [ tref ] <- asks tiParams + Just ref <- liftIO $ readRef st $ encodeUtf8 tref + let sidata = wrappedLoad ref + idata = fromSigned sidata + cmdOut $ unwords $ concat + [ [ "identity-info" ] + , [ "ref", T.unpack tref ] + , [ "base", show $ refDigest $ storedRef $ eiddStoredBase sidata ] + , maybe [] (\owner -> [ "owner", show $ refDigest $ storedRef owner ]) $ eiddOwner idata + , maybe [] (\name -> [ "name", T.unpack name ]) $ eiddName idata + ] + +cmdStartServer :: Command +cmdStartServer = do + out <- asks tiOutput + + let parseParams = \case + (name : value : rest) + | name == "services" -> T.splitOn "," value + | otherwise -> parseParams rest + _ -> [] + serviceNames <- parseParams <$> asks tiParams + + h <- getOrLoadHead + rsPeers <- liftIO $ newMVar (1, []) + services <- forM serviceNames $ \case + "attach" -> return $ someServiceAttr $ pairingAttributes (Proxy @AttachService) out rsPeers "attach" + "chatroom" -> return $ someService @ChatroomService Proxy + "contact" -> return $ someServiceAttr $ pairingAttributes (Proxy @ContactService) out rsPeers "contact" + "discovery" -> return $ someServiceAttr $ discoveryAttributes + "dm" -> return $ someServiceAttr $ directMessageAttributes out + "sync" -> return $ someService @SyncService Proxy + "test" -> return $ someServiceAttr $ (defaultServiceAttributes Proxy) + { testMessageReceived = \obj otype len sref -> do + liftIO $ do + void $ store (headStorage h) obj + outLine out $ unwords [ "test-message-received", otype, len, sref ] + , testStreamsReceived = \streams -> do + pidx <- getPeerIndex rsPeers + liftIO $ do + nums <- mapM getStreamReaderNumber streams + outLine out $ unwords $ "test-stream-open-from" : show pidx : map show nums + forM_ (zip nums streams) $ \( num, stream ) -> void $ forkIO $ do + let go = readStreamPacket stream >>= \case + StreamData seqNum bytes -> do + outLine out $ unwords [ "test-stream-received", show pidx, show num, show seqNum, BC.unpack bytes ] + go + StreamClosed seqNum -> do + outLine out $ unwords [ "test-stream-closed-from", show pidx, show num, show seqNum ] + go + } + sname -> throwOtherError $ "unknown service `" <> T.unpack sname <> "'" + + rsServer <- liftIO $ startServer defaultServerOptions h (B.hPutStr stderr . (`BC.snoc` '\n') . BC.pack) services + + rsPeerThread <- liftIO $ forkIO $ void $ forever $ do + peer <- getNextPeerChange rsServer + + let printPeer TestPeer {..} = do + params <- peerIdentity tpPeer >>= return . \case + PeerIdentityFull pid -> ("id":) $ map (maybe "<unnamed>" T.unpack . idName) (unfoldOwners pid) + _ -> [ "addr", show (peerAddress tpPeer) ] + outLine out $ unwords $ [ "peer", show tpIndex ] ++ params + + update ( tpIndex, [] ) = do + tpPeer <- return peer + tpStreamReaders <- newMVar [] + tpStreamWriters <- newMVar [] + let tp = TestPeer {..} + printPeer tp + return ( tpIndex + 1, [ tp ] ) + + update cur@( nid, p : ps ) + | tpPeer p == peer = printPeer p >> return cur + | otherwise = fmap (p :) <$> update ( nid, ps ) + + modifyMVar_ rsPeers update + + modify $ \s -> s { tsServer = Just RunningServer {..} } + +cmdStopServer :: Command +cmdStopServer = do + Just RunningServer {..} <- gets tsServer + liftIO $ do + killThread rsPeerThread + stopServer rsServer + modify $ \s -> s { tsServer = Nothing } + cmdOut "stop-server-done" + +cmdPeerAdd :: Command +cmdPeerAdd = do + Just RunningServer {..} <- gets tsServer + host:rest <- map T.unpack <$> asks tiParams + + let port = case rest of [] -> show discoveryPort + (p:_) -> p + addr:_ <- liftIO $ getAddrInfo (Just $ defaultHints { addrSocketType = Datagram }) (Just host) (Just port) + void $ liftIO $ serverPeer rsServer (addrAddress addr) + +cmdPeerDrop :: Command +cmdPeerDrop = do + [spidx] <- asks tiParams + peer <- getPeer spidx + liftIO $ dropPeer peer + +cmdPeerList :: Command +cmdPeerList = do + Just RunningServer {..} <- gets tsServer + peers <- liftIO $ getCurrentPeerList rsServer + tpeers <- liftIO $ readMVar rsPeers + forM_ peers $ \peer -> do + Just tp <- return $ find ((peer ==) . tpPeer) . snd $ tpeers + mbpid <- peerIdentity peer + cmdOut $ unwords $ concat + [ [ "peer-list-item", show (tpIndex tp) ] + , [ "addr", show (peerAddress peer) ] + , case mbpid of PeerIdentityFull pid -> ("id":) $ map (maybe "<unnamed>" T.unpack . idName) (unfoldOwners pid) + _ -> [] + ] + cmdOut "peer-list-done" + + +cmdTestMessageSend :: Command +cmdTestMessageSend = do + spidx : trefs <- asks tiParams + st <- asks tiStorage + Just refs <- liftIO $ fmap sequence $ mapM (readRef st . encodeUtf8) trefs + peer <- getPeer spidx + sendManyToPeer peer $ map (TestMessage . wrappedLoad) refs + cmdOut "test-message-send done" + +cmdTestStreamOpen :: Command +cmdTestStreamOpen = do + spidx : rest <- asks tiParams + tp <- getTestPeer spidx + count <- case rest of + [] -> return 1 + tcount : _ -> return $ read $ T.unpack tcount + + out <- asks tiOutput + runPeerService (tpPeer tp) $ do + streams <- openTestStreams count + afterCommit $ do + nums <- mapM getStreamWriterNumber streams + modifyMVar_ (tpStreamWriters tp) $ return . (++ zip nums streams) + outLine out $ unwords $ "test-stream-open-done" + : T.unpack spidx + : map show nums + +cmdTestStreamClose :: Command +cmdTestStreamClose = do + [ spidx, sid ] <- asks tiParams + tp <- getTestPeer spidx + Just stream <- lookup (read $ T.unpack sid) <$> liftIO (readMVar (tpStreamWriters tp)) + liftIO $ closeStream stream + cmdOut $ unwords [ "test-stream-close-done", T.unpack spidx, T.unpack sid ] + +cmdTestStreamSend :: Command +cmdTestStreamSend = do + [ spidx, sid, content ] <- asks tiParams + tp <- getTestPeer spidx + Just stream <- lookup (read $ T.unpack sid) <$> liftIO (readMVar (tpStreamWriters tp)) + liftIO $ writeStream stream $ encodeUtf8 content + cmdOut $ unwords [ "test-stream-send-done", T.unpack spidx, T.unpack sid ] + +cmdLocalStateGet :: Command +cmdLocalStateGet = do + h <- getHead + cmdOut $ unwords $ "local-state-get" : map (show . refDigest . storedRef) [ headStoredObject h ] + +cmdLocalStateReplace :: Command +cmdLocalStateReplace = do + st <- asks tiStorage + [ told, tnew ] <- asks tiParams + Just rold <- liftIO $ readRef st $ encodeUtf8 told + Just rnew <- liftIO $ readRef st $ encodeUtf8 tnew + ok <- updateLocalHead @LocalState $ \ls -> do + if storedRef ls == rold + then return ( wrappedLoad rnew, True ) + else return ( ls, False ) + cmdOut $ if ok then "local-state-replace-done" else "local-state-replace-failed" + +localStateWaitHelper :: Storable a => String -> (Head LocalState -> [ Stored a ]) -> Command +localStateWaitHelper label sel = do + st <- asks tiStorage + out <- asks tiOutput + h <- getOrLoadHead + trefs <- asks tiParams + + liftIO $ do + mvar <- newEmptyMVar + w <- watchHeadWith h sel $ \cur -> do + mbobjs <- mapM (readRef st . encodeUtf8) trefs + case map wrappedLoad <$> sequence mbobjs of + Just objs | filterAncestors (cur ++ objs) == cur -> do + outLine out $ unwords $ label : map T.unpack trefs + void $ forkIO $ unwatchHead =<< takeMVar mvar + _ -> return () + putMVar mvar w + +cmdLocalStateWait :: Command +cmdLocalStateWait = localStateWaitHelper "local-state-wait" ((: []) . headStoredObject) + +cmdSharedStateGet :: Command +cmdSharedStateGet = do + h <- getHead + cmdOut $ unwords $ "shared-state-get" : map (show . refDigest . storedRef) (lsShared $ headObject h) + +cmdSharedStateWait :: Command +cmdSharedStateWait = localStateWaitHelper "shared-state-wait" (lsShared . headObject) + +cmdWatchLocalIdentity :: Command +cmdWatchLocalIdentity = do + h <- getOrLoadHead + Nothing <- gets tsWatchedLocalIdentity + + out <- asks tiOutput + w <- liftIO $ watchHeadWith h headLocalIdentity $ \idt -> do + outLine out $ unwords $ "local-identity" : map (maybe "<unnamed>" T.unpack . idName) (unfoldOwners idt) + modify $ \s -> s { tsWatchedLocalIdentity = Just w } + +cmdWatchSharedIdentity :: Command +cmdWatchSharedIdentity = do + h <- getOrLoadHead + Nothing <- gets tsWatchedSharedIdentity + + out <- asks tiOutput + w <- liftIO $ watchHeadWith h (lookupSharedValue . lsShared . headObject) $ \case + Just (idt :: ComposedIdentity) -> do + outLine out $ unwords $ "shared-identity" : map (maybe "<unnamed>" T.unpack . idName) (unfoldOwners idt) + Nothing -> do + outLine out $ "shared-identity-failed" + modify $ \s -> s { tsWatchedSharedIdentity = Just w } + +cmdUpdateLocalIdentity :: Command +cmdUpdateLocalIdentity = do + [name] <- asks tiParams + updateLocalState_ $ \ls -> do + Just identity <- return $ validateExtendedIdentity $ lsIdentity $ fromStored ls + let public = idKeyIdentity identity + + secret <- loadKey public + nidata <- maybe (error "created invalid identity") (return . idExtData) . validateExtendedIdentity =<< + mstore =<< sign secret =<< mstore . ExtendedIdentityData =<< return (emptyIdentityExtension $ idData identity) + { idePrev = toList $ idExtDataF identity + , ideName = Just name + } + mstore (fromStored ls) { lsIdentity = nidata } + +cmdUpdateSharedIdentity :: Command +cmdUpdateSharedIdentity = do + [name] <- asks tiParams + updateLocalState_ $ updateSharedState_ $ \case + Nothing -> throwOtherError "no existing shared identity" + Just identity -> do + let public = idKeyIdentity identity + secret <- loadKey public + uidentity <- mergeIdentity identity + maybe (error "created invalid identity") (return . Just . toComposedIdentity) . validateExtendedIdentity =<< + mstore =<< sign secret =<< mstore . ExtendedIdentityData =<< return (emptyIdentityExtension $ idData uidentity) + { idePrev = toList $ idExtDataF identity + , ideName = Just name + } + +cmdAttachTo :: Command +cmdAttachTo = do + [spidx] <- asks tiParams + attachToOwner =<< getPeer spidx + +cmdAttachAccept :: Command +cmdAttachAccept = do + [spidx] <- asks tiParams + attachAccept =<< getPeer spidx + +cmdAttachReject :: Command +cmdAttachReject = do + [spidx] <- asks tiParams + attachReject =<< getPeer spidx + +cmdContactRequest :: Command +cmdContactRequest = do + [spidx] <- asks tiParams + contactRequest =<< getPeer spidx + +cmdContactAccept :: Command +cmdContactAccept = do + [spidx] <- asks tiParams + contactAccept =<< getPeer spidx + +cmdContactReject :: Command +cmdContactReject = do + [spidx] <- asks tiParams + contactReject =<< getPeer spidx + +cmdContactList :: Command +cmdContactList = do + h <- getHead + let contacts = fromSetBy (comparing contactName) . lookupSharedValue . lsShared . headObject $ h + forM_ contacts $ \c -> do + r:_ <- return $ filterAncestors $ concatMap storedRoots $ toComponents c + cmdOut $ concat + [ "contact-list-item " + , show $ refDigest $ storedRef r + , " " + , T.unpack $ contactName c + , case contactIdentity c of Nothing -> ""; Just idt -> " " ++ T.unpack (displayIdentity idt) + ] + cmdOut "contact-list-done" + +getContact :: Text -> CommandM Contact +getContact cid = do + h <- getHead + let contacts = fromSetBy (comparing contactName) . lookupSharedValue . lsShared . headObject $ h + [contact] <- flip filterM contacts $ \c -> do + r:_ <- return $ filterAncestors $ concatMap storedRoots $ toComponents c + return $ T.pack (show $ refDigest $ storedRef r) == cid + return contact + +cmdContactSetName :: Command +cmdContactSetName = do + [cid, name] <- asks tiParams + contact <- getContact cid + updateLocalState_ $ updateSharedState_ $ contactSetName contact name + cmdOut "contact-set-name-done" + +cmdDmSendPeer :: Command +cmdDmSendPeer = do + [spidx, msg] <- asks tiParams + PeerIdentityFull to <- peerIdentity =<< getPeer spidx + void $ sendDirectMessage to msg + +cmdDmSendContact :: Command +cmdDmSendContact = do + [cid, msg] <- asks tiParams + Just to <- contactIdentity <$> getContact cid + void $ sendDirectMessage to msg + +cmdDmSendIdentity :: Command +cmdDmSendIdentity = do + st <- asks tiStorage + [ tid, msg ] <- asks tiParams + Just ref <- liftIO $ readRef st $ encodeUtf8 tid + Just to <- return $ validateExtendedIdentity $ wrappedLoad ref + void $ sendDirectMessage to msg + +dmList :: Foldable f => Identity f -> Command +dmList peer = do + threads <- toThreadList . lookupSharedValue . lsShared . headObject <$> getHead + case find (sameIdentity peer . msgPeer) threads of + Just thread -> do + forM_ (reverse $ threadToList thread) $ \DirectMessage {..} -> cmdOut $ "dm-list-item" + <> " from " <> (maybe "<unnamed>" T.unpack $ idName msgFrom) + <> " text " <> (T.unpack msgText) + Nothing -> return () + cmdOut "dm-list-done" + +cmdDmListPeer :: Command +cmdDmListPeer = do + [spidx] <- asks tiParams + PeerIdentityFull to <- peerIdentity =<< getPeer spidx + dmList to + +cmdDmListContact :: Command +cmdDmListContact = do + [cid] <- asks tiParams + Just to <- contactIdentity <$> getContact cid + dmList to + +cmdChatroomCreate :: Command +cmdChatroomCreate = do + [name] <- asks tiParams + room <- createChatroom (Just name) Nothing + cmdOut $ unwords $ "chatroom-create-done" : chatroomInfo room + +cmdChatroomDelete :: Command +cmdChatroomDelete = do + [ cid ] <- asks tiParams + sdata <- getChatroomStateData cid + deleteChatroomByStateData sdata + cmdOut $ unwords [ "chatroom-delete-done", T.unpack cid ] + +getChatroomStateData :: Text -> CommandM (Stored ChatroomStateData) +getChatroomStateData tref = do + st <- asks tiStorage + Just ref <- liftIO $ readRef st (encodeUtf8 tref) + return $ wrappedLoad ref + +cmdChatroomSetName :: Command +cmdChatroomSetName = do + [cid, name] <- asks tiParams + sdata <- getChatroomStateData cid + updateChatroomByStateData sdata (Just name) Nothing >>= \case + Just room -> cmdOut $ unwords $ "chatroom-set-name-done" : chatroomInfo room + Nothing -> cmdOut "chatroom-set-name-failed" + +cmdChatroomListLocal :: Command +cmdChatroomListLocal = do + [] <- asks tiParams + rooms <- listChatrooms + forM_ rooms $ \room -> do + cmdOut $ unwords $ "chatroom-list-item" : chatroomInfo room + cmdOut "chatroom-list-done" + +cmdChatroomWatchLocal :: Command +cmdChatroomWatchLocal = do + [] <- asks tiParams + h <- getOrLoadHead + out <- asks tiOutput + void $ watchChatrooms h $ \_ -> \case + Nothing -> return () + Just diff -> forM_ diff $ \case + AddedChatroom room -> outLine out $ unwords $ "chatroom-watched-added" : chatroomInfo room + RemovedChatroom room -> outLine out $ unwords $ "chatroom-watched-removed" : chatroomInfo room + UpdatedChatroom oldroom room -> do + when (any ((\rsd -> not (null (rsdRoom rsd)) || not (null (rsdSubscribe rsd))) . fromStored) (roomStateData room)) $ do + outLine out $ unwords $ concat + [ [ "chatroom-watched-updated" ], chatroomInfo room + , [ "old" ], map (show . refDigest . storedRef) (roomStateData oldroom) + , [ "new" ], map (show . refDigest . storedRef) (roomStateData room) + ] + when (any (not . null . rsdMessages . fromStored) (roomStateData room)) $ do + forM_ (reverse $ getMessagesSinceState room oldroom) $ \msg -> do + outLine out $ unwords $ concat + [ [ "chatroom-message-new" ] + , [ show . refDigest . storedRef . head . filterAncestors . concatMap storedRoots . toComponents $ room ] + , [ "room", maybe "<unnamed>" T.unpack $ roomName =<< cmsgRoom msg ] + , [ "from", maybe "<unnamed>" T.unpack $ idName $ cmsgFrom msg ] + , if cmsgLeave msg then [ "leave" ] else [] + , maybe [] (("text":) . (:[]) . T.unpack) $ cmsgText msg + ] + +chatroomInfo :: ChatroomState -> [String] +chatroomInfo room = + [ show . refDigest . storedRef . head . filterAncestors . concatMap storedRoots . toComponents $ room + , maybe "<unnamed>" T.unpack $ roomName =<< roomStateRoom room + , "sub " <> bool "false" "true" (roomStateSubscribe room) + ] + +cmdChatroomSubscribe :: Command +cmdChatroomSubscribe = do + [ cid ] <- asks tiParams + to <- getChatroomStateData cid + void $ chatroomSetSubscribe to True + +cmdChatroomUnsubscribe :: Command +cmdChatroomUnsubscribe = do + [ cid ] <- asks tiParams + to <- getChatroomStateData cid + void $ chatroomSetSubscribe to False + +cmdChatroomMembers :: Command +cmdChatroomMembers = do + [ cid ] <- asks tiParams + Just chatroom <- findChatroomByStateData =<< getChatroomStateData cid + forM_ (chatroomMembers chatroom) $ \user -> do + cmdOut $ unwords [ "chatroom-members-item", maybe "<unnamed>" T.unpack $ idName user ] + cmdOut "chatroom-members-done" + +cmdChatroomJoin :: Command +cmdChatroomJoin = do + [ cid ] <- asks tiParams + joinChatroomByStateData =<< getChatroomStateData cid + cmdOut "chatroom-join-done" + +cmdChatroomJoinAs :: Command +cmdChatroomJoinAs = do + [ cid, name ] <- asks tiParams + st <- asks tiStorage + identity <- liftIO $ createIdentity st (Just name) Nothing + joinChatroomAsByStateData identity =<< getChatroomStateData cid + cmdOut $ unwords [ "chatroom-join-as-done", T.unpack cid ] + +cmdChatroomLeave :: Command +cmdChatroomLeave = do + [ cid ] <- asks tiParams + leaveChatroomByStateData =<< getChatroomStateData cid + cmdOut "chatroom-leave-done" + +cmdChatroomMessageSend :: Command +cmdChatroomMessageSend = do + [cid, msg] <- asks tiParams + to <- getChatroomStateData cid + void $ sendChatroomMessageByStateData to msg + +cmdDiscoveryConnect :: Command +cmdDiscoveryConnect = do + [ tref ] <- asks tiParams + Just dgst <- return $ readRefDigest $ encodeUtf8 tref + Just RunningServer {..} <- gets tsServer + discoverySearch rsServer dgst + +cmdDiscoveryTunnel :: Command +cmdDiscoveryTunnel = do + [ tvia, ttarget ] <- asks tiParams + via <- getPeer tvia + Just target <- return $ readRefDigest $ encodeUtf8 ttarget + liftIO $ discoverySetupTunnel via target diff --git a/main/Test/Service.hs b/main/Test/Service.hs new file mode 100644 index 0000000..c0be07d --- /dev/null +++ b/main/Test/Service.hs @@ -0,0 +1,57 @@ +module Test.Service ( + TestMessage(..), + TestMessageAttributes(..), + + openTestStreams, +) where + +import Control.Monad +import Control.Monad.Reader + +import Data.ByteString.Lazy.Char8 qualified as BL + +import Erebos.Network +import Erebos.Object +import Erebos.Service +import Erebos.Service.Stream +import Erebos.Storable + +data TestMessage = TestMessage (Stored Object) + +data TestMessageAttributes = TestMessageAttributes + { testMessageReceived :: Object -> String -> String -> String -> ServiceHandler TestMessage () + , testStreamsReceived :: [ StreamReader ] -> ServiceHandler TestMessage () + } + +instance Storable TestMessage where + store' (TestMessage msg) = store' msg + load' = TestMessage <$> load' + +instance Service TestMessage where + serviceID _ = mkServiceID "cb46b92c-9203-4694-8370-8742d8ac9dc8" + + type ServiceAttributes TestMessage = TestMessageAttributes + defaultServiceAttributes _ = TestMessageAttributes + { testMessageReceived = \_ _ _ _ -> return () + , testStreamsReceived = \_ -> return () + } + + serviceHandler smsg = do + let TestMessage sobj = fromStored smsg + obj = fromStored sobj + case map BL.unpack $ BL.words $ BL.takeWhile (/='\n') $ serializeObject obj of + [otype, len] -> do + cb <- asks $ testMessageReceived . svcAttributes + cb obj otype len (show $ refDigest $ storedRef sobj) + _ -> return () + + streams <- receivedStreams + when (not $ null streams) $ do + cb <- asks $ testStreamsReceived . svcAttributes + cb streams + + +openTestStreams :: Int -> ServiceHandler TestMessage [ StreamWriter ] +openTestStreams count = do + replyPacket . TestMessage =<< mstore (Rec []) + replicateM count openStream diff --git a/main/Version.hs b/main/Version.hs new file mode 100644 index 0000000..71af694 --- /dev/null +++ b/main/Version.hs @@ -0,0 +1,23 @@ +{-# LANGUAGE TemplateHaskell #-} + +-- "Pattern match is redundant" warning can be generated based on template +-- haskell $$tGitVersion value +{-# OPTIONS_GHC -Wno-error=overlapping-patterns #-} + +module Version ( + versionLine, +) where + +import Paths_erebos (version) +import Data.Version (showVersion) +import Version.Git + +{-# NOINLINE versionLine #-} +versionLine :: String +versionLine = do + let ver = case $$tGitVersion of + Just gver + | 'v':v <- gver, not $ all (`elem` ('.': ['0'..'9'])) v + -> "git " <> gver + _ -> "version " <> showVersion version + in "Erebos CLI " <> ver diff --git a/main/Version/Git.hs b/main/Version/Git.hs new file mode 100644 index 0000000..2aae6e3 --- /dev/null +++ b/main/Version/Git.hs @@ -0,0 +1,31 @@ +module Version.Git ( + tGitVersion, +) where + +import Language.Haskell.TH +import Language.Haskell.TH.Syntax + +import System.Directory +import System.Exit +import System.Process + +tGitVersion :: Code Q (Maybe String) +tGitVersion = unsafeCodeCoerce $ do + let git args = do + (ExitSuccess, out, _) <- readCreateProcessWithExitCode + (proc "git" $ [ "--git-dir=./.git", "--work-tree=." ] ++ args) "" + return $ lines out + + mbver <- runIO $ do + doesPathExist "./.git" >>= \case + False -> return Nothing + True -> do + desc:_ <- git [ "describe", "--always", "--dirty= (dirty)" ] + files <- git [ "ls-files" ] + return $ Just (desc, files) + + case mbver of + Just (_, files) -> mapM_ addDependentFile files + Nothing -> return () + + lift (fst <$> mbver :: Maybe String) diff --git a/main/WebSocket.hs b/main/WebSocket.hs new file mode 100644 index 0000000..5c49aa1 --- /dev/null +++ b/main/WebSocket.hs @@ -0,0 +1,46 @@ +module WebSocket ( + WebSocketAddress(..), + startWebsocketServer, +) where + +import Control.Concurrent +import Control.Exception +import Control.Monad + +import Data.ByteString.Lazy qualified as BL +import Data.Unique + +import Erebos.Network + +import Network.WebSockets qualified as WS + + +data WebSocketAddress = WebSocketAddress Unique WS.Connection + +instance Eq WebSocketAddress where + WebSocketAddress u _ == WebSocketAddress u' _ = u == u' + +instance Ord WebSocketAddress where + compare (WebSocketAddress u _) (WebSocketAddress u' _) = compare u u' + +instance Show WebSocketAddress where + show (WebSocketAddress _ _) = "websocket" + +instance PeerAddressType WebSocketAddress where + sendBytesToAddress (WebSocketAddress _ conn) msg = do + WS.sendDataMessage conn $ WS.Binary $ BL.fromStrict msg + +startWebsocketServer :: Server -> String -> Int -> (String -> IO ()) -> IO () +startWebsocketServer server addr port logd = do + void $ forkIO $ do + WS.runServer addr port $ \pending -> do + conn <- WS.acceptRequest pending + u <- newUnique + let paddr = WebSocketAddress u conn + void $ serverPeerCustom server paddr + handle (\(e :: SomeException) -> logd $ "WebSocket thread exception: " ++ show e) $ do + WS.withPingThread conn 30 (return ()) $ do + forever $ do + WS.receiveDataMessage conn >>= \case + WS.Binary msg -> receivedFromCustomAddress server paddr $ BL.toStrict msg + WS.Text {} -> logd $ "unexpected websocket text message" diff --git a/minici.yaml b/minici.yaml new file mode 100644 index 0000000..333878c --- /dev/null +++ b/minici.yaml @@ -0,0 +1,13 @@ +job build: + shell: + - cabal build -fci + - mkdir build + - cp $(cabal list-bin erebos) build/erebos + artifact erebos: + path: build/erebos + +job test: + uses: + - build.erebos + shell: + - EREBOS_TEST_TOOL='build/erebos test' erebos-tester -v diff --git a/src/Erebos/Attach.hs b/src/Erebos/Attach.hs new file mode 100644 index 0000000..fad6197 --- /dev/null +++ b/src/Erebos/Attach.hs @@ -0,0 +1,123 @@ +module Erebos.Attach ( + AttachService, + attachToOwner, + attachAccept, + attachReject, +) where + +import Control.Monad +import Control.Monad.Except +import Control.Monad.Reader + +import Data.ByteArray (ScrubbedBytes) +import Data.Maybe +import Data.Proxy +import qualified Data.Text as T + +import Erebos.Identity +import Erebos.Network +import Erebos.Pairing +import Erebos.PubKey +import Erebos.Service +import Erebos.State +import Erebos.Storable +import Erebos.Storage.Key + +type AttachService = PairingService AttachIdentity + +data AttachIdentity = AttachIdentity (Stored (Signed IdentityData)) [ScrubbedBytes] + +instance Storable AttachIdentity where + store' (AttachIdentity x keys) = storeRec $ do + storeRef "identity" x + mapM_ (storeBinary "skey") keys + + load' = loadRec $ AttachIdentity + <$> loadRef "identity" + <*> loadBinaries "skey" + +instance PairingResult AttachIdentity where + pairingServiceID _ = mkServiceID "4995a5f9-2d4d-48e9-ad3b-0bf1c2a1be7f" + + type PairingVerifiedResult AttachIdentity = (UnifiedIdentity, [ScrubbedBytes]) + + pairingVerifyResult (AttachIdentity sdata keys) = do + curid <- lsIdentity . fromStored <$> svcGetLocal + secret <- loadKey $ eiddKeyIdentity $ fromSigned curid + sdata' <- mstore =<< signAdd secret (fromStored sdata) + return $ do + guard $ iddKeyIdentity (fromSigned sdata) == + eiddKeyIdentity (fromSigned curid) + identity <- validateIdentity sdata' + guard $ iddPrev (fromSigned $ idData identity) == [eiddStoredBase curid] + return (identity, keys) + + pairingFinalizeRequest (identity, keys) = updateLocalState_ $ \slocal -> do + let owner = finalOwner identity + st <- getStorage + pkeys <- mapM (copyStored st) [ idKeyIdentity owner, idKeyMessage owner ] + liftIO $ mapM_ storeKey $ catMaybes [ keyFromData sec pub | sec <- keys, pub <- pkeys ] + + identity' <- mergeIdentity $ updateIdentity [ lsIdentity $ fromStored slocal ] identity + shared <- makeSharedStateUpdate st (Just owner) (lsShared $ fromStored slocal) + mstore (fromStored slocal) + { lsIdentity = idExtData identity' + , lsShared = [ shared ] + } + + pairingFinalizeResponse = do + owner <- mergeSharedIdentity + pid <- asks svcPeerIdentity + secret <- loadKey $ idKeyIdentity owner + identity <- mstore =<< sign secret =<< mstore (emptyIdentityData $ idKeyIdentity pid) + { iddPrev = [idData pid], iddOwner = Just (idData owner) } + skeys <- map keyGetData . catMaybes <$> mapM loadKeyMb [ idKeyIdentity owner, idKeyMessage owner ] + return $ AttachIdentity identity skeys + + defaultPairingAttributes _ = PairingAttributes + { pairingHookRequest = do + peer <- asks $ svcPeerIdentity + svcPrint $ "Attach from " ++ T.unpack (displayIdentity peer) ++ " initiated" + + , pairingHookResponse = \confirm -> do + peer <- asks $ svcPeerIdentity + svcPrint $ "Attach to " ++ T.unpack (displayIdentity peer) ++ ": " ++ confirm + + , pairingHookRequestNonce = \confirm -> do + peer <- asks $ svcPeerIdentity + svcPrint $ "Attach from " ++ T.unpack (displayIdentity peer) ++ ": " ++ confirm + + , pairingHookRequestNonceFailed = do + peer <- asks $ svcPeerIdentity + svcPrint $ "Failed attach from " ++ T.unpack (displayIdentity peer) + + , pairingHookConfirmedResponse = do + svcPrint $ "Confirmed peer, waiting for updated identity" + + , pairingHookConfirmedRequest = do + svcPrint $ "Attachment confirmed by peer" + + , pairingHookAcceptedResponse = do + svcPrint $ "Accepted updated identity" + + , pairingHookAcceptedRequest = do + svcPrint $ "Accepted new attached device, seding updated identity" + + , pairingHookVerifyFailed = do + svcPrint $ "Failed to verify new identity" + + , pairingHookRejected = do + svcPrint $ "Attachment rejected by peer" + + , pairingHookFailed = \_ -> do + svcPrint $ "Attachement failed" + } + +attachToOwner :: (MonadIO m, MonadError e m, FromErebosError e) => Peer -> m () +attachToOwner = pairingRequest @AttachIdentity Proxy + +attachAccept :: (MonadIO m, MonadError e m, FromErebosError e) => Peer -> m () +attachAccept = pairingAccept @AttachIdentity Proxy + +attachReject :: (MonadIO m, MonadError e m, FromErebosError e) => Peer -> m () +attachReject = pairingReject @AttachIdentity Proxy diff --git a/src/Erebos/Chatroom.hs b/src/Erebos/Chatroom.hs new file mode 100644 index 0000000..74456ff --- /dev/null +++ b/src/Erebos/Chatroom.hs @@ -0,0 +1,648 @@ +module Erebos.Chatroom ( + Chatroom(..), + ChatroomData(..), + validateChatroom, + + ChatroomState(..), + ChatroomStateData(..), + createChatroom, + deleteChatroomByStateData, + updateChatroomByStateData, + listChatrooms, + findChatroomByRoomData, + findChatroomByStateData, + chatroomSetSubscribe, + chatroomMembers, + joinChatroom, joinChatroomByStateData, + joinChatroomAs, joinChatroomAsByStateData, + leaveChatroom, leaveChatroomByStateData, + getMessagesSinceState, + + ChatroomSetChange(..), + watchChatrooms, + + ChatMessage, + cmsgFrom, cmsgReplyTo, cmsgTime, cmsgText, cmsgLeave, + cmsgRoom, cmsgRoomData, + ChatMessageData(..), + sendChatroomMessage, + sendChatroomMessageByStateData, + + ChatroomService(..), +) where + +import Control.Arrow +import Control.Monad +import Control.Monad.Except +import Control.Monad.IO.Class + +import Data.Bool +import Data.Either +import Data.Foldable +import Data.Function +import Data.IORef +import Data.List +import Data.Maybe +import Data.Monoid +import Data.Ord +import Data.Set qualified as S +import Data.Text (Text) +import Data.Time + +import Erebos.Identity +import Erebos.PubKey +import Erebos.Service +import Erebos.Set +import Erebos.State +import Erebos.Storable +import Erebos.Storage.Head +import Erebos.Storage.Merge +import Erebos.Util + + +data ChatroomData = ChatroomData + { rdPrev :: [Stored (Signed ChatroomData)] + , rdName :: Maybe Text + , rdDescription :: Maybe Text + , rdKey :: Stored PublicKey + } + +data Chatroom = Chatroom + { roomData :: [Stored (Signed ChatroomData)] + , roomName :: Maybe Text + , roomDescription :: Maybe Text + , roomKey :: Stored PublicKey + } + +instance Storable ChatroomData where + store' ChatroomData {..} = storeRec $ do + mapM_ (storeRef "SPREV") rdPrev + storeMbText "name" rdName + storeMbText "description" rdDescription + storeRef "key" rdKey + + load' = loadRec $ do + rdPrev <- loadRefs "SPREV" + rdName <- loadMbText "name" + rdDescription <- loadMbText "description" + rdKey <- loadRef "key" + return ChatroomData {..} + +validateChatroom :: [Stored (Signed ChatroomData)] -> Except String Chatroom +validateChatroom roomData = do + when (null roomData) $ throwError "null data" + when (not $ getAll $ walkAncestors verifySignatures roomData) $ do + throwError "signature verification failed" + + let roomName = findPropertyFirst (rdName . fromStored . signedData) roomData + roomDescription = findPropertyFirst (rdDescription . fromStored . signedData) roomData + roomKey <- maybe (throwError "missing key") return $ + findPropertyFirst (Just . rdKey . fromStored . signedData) roomData + return Chatroom {..} + where + verifySignatures sdata = + let rdata = fromSigned sdata + required = concat + [ [ rdKey rdata ] + , map (rdKey . fromSigned) $ rdPrev rdata + ] + in All $ all (fromStored sdata `isSignedBy`) required + + +data ChatMessageData = ChatMessageData + { mdPrev :: [Stored (Signed ChatMessageData)] + , mdRoom :: [Stored (Signed ChatroomData)] + , mdFrom :: ComposedIdentity + , mdReplyTo :: Maybe (Stored (Signed ChatMessageData)) + , mdTime :: ZonedTime + , mdText :: Maybe Text + , mdLeave :: Bool + } + +data ChatMessage = ChatMessage + { cmsgData :: Stored (Signed ChatMessageData) + } + +validateSingleMessage :: Stored (Signed ChatMessageData) -> Maybe ChatMessage +validateSingleMessage sdata = do + guard $ fromStored sdata `isSignedBy` idKeyMessage (mdFrom (fromSigned sdata)) + return $ ChatMessage sdata + +cmsgFrom :: ChatMessage -> ComposedIdentity +cmsgFrom = mdFrom . fromSigned . cmsgData + +cmsgReplyTo :: ChatMessage -> Maybe ChatMessage +cmsgReplyTo = fmap ChatMessage . mdReplyTo . fromSigned . cmsgData + +cmsgTime :: ChatMessage -> ZonedTime +cmsgTime = mdTime . fromSigned . cmsgData + +cmsgText :: ChatMessage -> Maybe Text +cmsgText = mdText . fromSigned . cmsgData + +cmsgLeave :: ChatMessage -> Bool +cmsgLeave = mdLeave . fromSigned . cmsgData + +cmsgRoom :: ChatMessage -> Maybe Chatroom +cmsgRoom = either (const Nothing) Just . runExcept . validateChatroom . cmsgRoomData + +cmsgRoomData :: ChatMessage -> [ Stored (Signed ChatroomData) ] +cmsgRoomData = concat . findProperty ((\case [] -> Nothing; xs -> Just xs) . mdRoom . fromStored . signedData) . (: []) . cmsgData + +instance Storable ChatMessageData where + store' ChatMessageData {..} = storeRec $ do + mapM_ (storeRef "SPREV") mdPrev + mapM_ (storeRef "room") mdRoom + mapM_ (storeRef "from") $ idExtDataF mdFrom + storeMbRef "reply-to" mdReplyTo + storeDate "time" mdTime + storeMbText "text" mdText + when mdLeave $ storeEmpty "leave" + + load' = loadRec $ do + mdPrev <- loadRefs "SPREV" + mdRoom <- loadRefs "room" + mdFrom <- loadIdentity "from" + mdReplyTo <- loadMbRef "reply-to" + mdTime <- loadDate "time" + mdText <- loadMbText "text" + mdLeave <- isJust <$> loadMbEmpty "leave" + return ChatMessageData {..} + +threadToListSince :: [ Stored (Signed ChatMessageData) ] -> [ Stored (Signed ChatMessageData) ] -> [ ChatMessage ] +threadToListSince since thread = helper (S.fromList since) thread + where + helper :: S.Set (Stored (Signed ChatMessageData)) -> [Stored (Signed ChatMessageData)] -> [ChatMessage] + helper seen msgs + | msg : msgs' <- filter (`S.notMember` seen) $ reverse $ sortBy (comparing cmpView) msgs = + maybe id (:) (validateSingleMessage msg) $ + helper (S.insert msg seen) (msgs' ++ mdPrev (fromSigned msg)) + | otherwise = [] + cmpView msg = (zonedTimeToUTC $ mdTime $ fromSigned msg, msg) + +sendChatroomMessage + :: (MonadStorage m, MonadHead LocalState m, MonadError e m, FromErebosError e) + => ChatroomState -> Text -> m () +sendChatroomMessage rstate msg = sendChatroomMessageByStateData (head $ roomStateData rstate) msg + +sendChatroomMessageByStateData + :: (MonadStorage m, MonadHead LocalState m, MonadError e m, FromErebosError e) + => Stored ChatroomStateData -> Text -> m () +sendChatroomMessageByStateData lookupData msg = sendRawChatroomMessageByStateData lookupData Nothing Nothing (Just msg) False + +sendRawChatroomMessageByStateData + :: (MonadStorage m, MonadHead LocalState m, MonadError e m, FromErebosError e) + => Stored ChatroomStateData -> Maybe UnifiedIdentity -> Maybe (Stored (Signed ChatMessageData)) -> Maybe Text -> Bool -> m () +sendRawChatroomMessageByStateData lookupData mbIdentity mdReplyTo mdText mdLeave = void $ findAndUpdateChatroomState $ \cstate -> do + guard $ any (lookupData `precedesOrEquals`) $ roomStateData cstate + Just $ do + mdFrom <- finalOwner <$> if + | Just identity <- mbIdentity -> return identity + | Just identity <- roomStateIdentity cstate -> return identity + | otherwise -> localIdentity . fromStored <$> getLocalHead + secret <- loadKey $ idKeyMessage mdFrom + mdTime <- liftIO getZonedTime + let mdPrev = roomStateMessageData cstate + mdRoom = if null (roomStateMessageData cstate) + then maybe [] roomData (roomStateRoom cstate) + else [] + + mdata <- mstore =<< sign secret =<< mstore ChatMessageData {..} + mergeSorted . (:[]) <$> mstore emptyChatroomStateData + { rsdPrev = roomStateData cstate + , rsdSubscribe = Just (not mdLeave) + , rsdIdentity = mbIdentity + , rsdMessages = [ mdata ] + } + + +data ChatroomStateData = ChatroomStateData + { rsdPrev :: [Stored ChatroomStateData] + , rsdRoom :: [Stored (Signed ChatroomData)] + , rsdDelete :: Bool + , rsdSubscribe :: Maybe Bool + , rsdIdentity :: Maybe UnifiedIdentity + , rsdMessages :: [Stored (Signed ChatMessageData)] + } + +emptyChatroomStateData :: ChatroomStateData +emptyChatroomStateData = ChatroomStateData + { rsdPrev = [] + , rsdRoom = [] + , rsdDelete = False + , rsdSubscribe = Nothing + , rsdIdentity = Nothing + , rsdMessages = [] + } + +data ChatroomState = ChatroomState + { roomStateData :: [Stored ChatroomStateData] + , roomStateRoom :: Maybe Chatroom + , roomStateMessageData :: [Stored (Signed ChatMessageData)] + , roomStateDeleted :: Bool + , roomStateSubscribe :: Bool + , roomStateIdentity :: Maybe UnifiedIdentity + , roomStateMessages :: [ChatMessage] + } + +instance Storable ChatroomStateData where + store' ChatroomStateData {..} = storeRec $ do + forM_ rsdPrev $ storeRef "PREV" + forM_ rsdRoom $ storeRef "room" + when rsdDelete $ storeEmpty "delete" + forM_ rsdSubscribe $ storeInt "subscribe" . bool @Int 0 1 + forM_ rsdIdentity $ storeRef "id" . idExtData + forM_ rsdMessages $ storeRef "msg" + + load' = loadRec $ do + rsdPrev <- loadRefs "PREV" + rsdRoom <- loadRefs "room" + rsdDelete <- isJust <$> loadMbEmpty "delete" + rsdSubscribe <- fmap ((/=) @Int 0) <$> loadMbInt "subscribe" + rsdIdentity <- loadMbUnifiedIdentity "id" + rsdMessages <- loadRefs "msg" + return ChatroomStateData {..} + +instance Mergeable ChatroomState where + type Component ChatroomState = ChatroomStateData + + mergeSorted roomStateData = + let roomStateRoom = either (const Nothing) Just $ runExcept $ + validateChatroom $ concat $ findProperty ((\case [] -> Nothing; xs -> Just xs) . rsdRoom) roomStateData + roomStateMessageData = filterAncestors $ concat $ flip findProperty roomStateData $ \case + ChatroomStateData {..} | null rsdMessages -> Nothing + | otherwise -> Just rsdMessages + roomStateDeleted = any (rsdDelete . fromStored) roomStateData + roomStateSubscribe = not roomStateDeleted && (fromMaybe False $ findPropertyFirst rsdSubscribe roomStateData) + roomStateIdentity = findPropertyFirst rsdIdentity roomStateData + roomStateMessages = threadToListSince [] $ concatMap (rsdMessages . fromStored) roomStateData + in ChatroomState {..} + + toComponents = roomStateData + +instance SharedType (Set ChatroomState) where + sharedTypeID _ = mkSharedTypeID "7bc71cbf-bc43-42b1-b413-d3a2c9a2aae0" + +createChatroom :: (MonadStorage m, MonadHead LocalState m, MonadIO m, MonadError e m, FromErebosError e) => Maybe Text -> Maybe Text -> m ChatroomState +createChatroom rdName rdDescription = do + (secret, rdKey) <- liftIO . generateKeys =<< getStorage + let rdPrev = [] + rdata <- mstore =<< sign secret =<< mstore ChatroomData {..} + cstate <- mergeSorted . (:[]) <$> mstore emptyChatroomStateData + { rsdRoom = [ rdata ] + , rsdSubscribe = Just True + } + + updateLocalState $ updateSharedState $ \rooms -> do + st <- getStorage + (, cstate) <$> storeSetAdd st cstate rooms + +findAndUpdateChatroomState + :: (MonadStorage m, MonadHead LocalState m) + => (ChatroomState -> Maybe (m ChatroomState)) + -> m (Maybe ChatroomState) +findAndUpdateChatroomState f = do + updateLocalState $ updateSharedState $ \roomSet -> do + let roomList = fromSetBy (comparing $ roomName <=< roomStateRoom) roomSet + case catMaybes $ map (\x -> (x,) <$> f x) roomList of + ((orig, act) : _) -> do + upd <- act + if roomStateData orig /= roomStateData upd + then do + st <- getStorage + roomSet' <- storeSetAdd st upd roomSet + return (roomSet', Just upd) + else do + return (roomSet, Just upd) + [] -> return (roomSet, Nothing) + +deleteChatroomByStateData + :: (MonadStorage m, MonadHead LocalState m, MonadError e m, FromErebosError e) + => Stored ChatroomStateData -> m () +deleteChatroomByStateData lookupData = void $ findAndUpdateChatroomState $ \cstate -> do + guard $ any (lookupData `precedesOrEquals`) $ roomStateData cstate + Just $ do + mergeSorted . (:[]) <$> mstore emptyChatroomStateData + { rsdPrev = roomStateData cstate + , rsdDelete = True + } + +updateChatroomByStateData + :: (MonadStorage m, MonadHead LocalState m, MonadError e m, FromErebosError e) + => Stored ChatroomStateData + -> Maybe Text + -> Maybe Text + -> m (Maybe ChatroomState) +updateChatroomByStateData lookupData newName newDesc = findAndUpdateChatroomState $ \cstate -> do + guard $ any (lookupData `precedesOrEquals`) $ roomStateData cstate + room <- roomStateRoom cstate + Just $ do + secret <- loadKey $ roomKey room + rdata <- mstore =<< sign secret =<< mstore ChatroomData + { rdPrev = roomData room + , rdName = newName + , rdDescription = newDesc + , rdKey = roomKey room + } + mergeSorted . (:[]) <$> mstore emptyChatroomStateData + { rsdPrev = roomStateData cstate + , rsdRoom = [ rdata ] + , rsdSubscribe = Just True + } + + +listChatrooms :: MonadHead LocalState m => m [ChatroomState] +listChatrooms = filter (not . roomStateDeleted) . + fromSetBy (comparing $ roomName <=< roomStateRoom) . + lookupSharedValue . lsShared . fromStored <$> getLocalHead + +findChatroom :: MonadHead LocalState m => (ChatroomState -> Bool) -> m (Maybe ChatroomState) +findChatroom p = do + list <- map snd . chatroomSetToList . lookupSharedValue . lsShared . fromStored <$> getLocalHead + return $ find p list + +findChatroomByRoomData :: MonadHead LocalState m => Stored (Signed ChatroomData) -> m (Maybe ChatroomState) +findChatroomByRoomData cdata = findChatroom $ + maybe False (any (cdata `precedesOrEquals`) . roomData) . roomStateRoom + +findChatroomByStateData :: MonadHead LocalState m => Stored ChatroomStateData -> m (Maybe ChatroomState) +findChatroomByStateData cdata = findChatroom $ any (cdata `precedesOrEquals`) . roomStateData + +chatroomSetSubscribe + :: (MonadStorage m, MonadHead LocalState m, MonadError e m, FromErebosError e) + => Stored ChatroomStateData -> Bool -> m () +chatroomSetSubscribe lookupData subscribe = void $ findAndUpdateChatroomState $ \cstate -> do + guard $ any (lookupData `precedesOrEquals`) $ roomStateData cstate + Just $ do + mergeSorted . (:[]) <$> mstore emptyChatroomStateData + { rsdPrev = roomStateData cstate + , rsdSubscribe = Just subscribe + } + +chatroomMembers :: ChatroomState -> [ ComposedIdentity ] +chatroomMembers ChatroomState {..} = + map (mdFrom . fromSigned . head) $ + filter (any $ not . mdLeave . fromSigned) $ -- keep only users that hasn't left + map (filterAncestors . map snd) $ -- gather message data per each identity and filter ancestors + groupBy ((==) `on` fst) $ -- group on identity root + sortBy (comparing fst) $ -- sort by first root of identity data + map (\x -> ( head . filterAncestors . concatMap storedRoots . idDataF . mdFrom . fromSigned $ x, x )) $ + toList $ ancestors $ roomStateMessageData + +joinChatroom + :: (MonadStorage m, MonadHead LocalState m, MonadError e m, FromErebosError e) + => ChatroomState -> m () +joinChatroom rstate = joinChatroomByStateData (head $ roomStateData rstate) + +joinChatroomByStateData + :: (MonadStorage m, MonadHead LocalState m, MonadError e m, FromErebosError e) + => Stored ChatroomStateData -> m () +joinChatroomByStateData lookupData = sendRawChatroomMessageByStateData lookupData Nothing Nothing Nothing False + +joinChatroomAs + :: (MonadStorage m, MonadHead LocalState m, MonadError e m, FromErebosError e) + => UnifiedIdentity -> ChatroomState -> m () +joinChatroomAs identity rstate = joinChatroomAsByStateData identity (head $ roomStateData rstate) + +joinChatroomAsByStateData + :: (MonadStorage m, MonadHead LocalState m, MonadError e m, FromErebosError e) + => UnifiedIdentity -> Stored ChatroomStateData -> m () +joinChatroomAsByStateData identity lookupData = sendRawChatroomMessageByStateData lookupData (Just identity) Nothing Nothing False + +leaveChatroom + :: (MonadStorage m, MonadHead LocalState m, MonadError e m, FromErebosError e) + => ChatroomState -> m () +leaveChatroom rstate = leaveChatroomByStateData (head $ roomStateData rstate) + +leaveChatroomByStateData + :: (MonadStorage m, MonadHead LocalState m, MonadError e m, FromErebosError e) + => Stored ChatroomStateData -> m () +leaveChatroomByStateData lookupData = sendRawChatroomMessageByStateData lookupData Nothing Nothing Nothing True + +getMessagesSinceState :: ChatroomState -> ChatroomState -> [ChatMessage] +getMessagesSinceState cur old = threadToListSince (roomStateMessageData old) (roomStateMessageData cur) + + +data ChatroomSetChange = AddedChatroom ChatroomState + | RemovedChatroom ChatroomState + | UpdatedChatroom ChatroomState ChatroomState + +watchChatrooms :: MonadIO m => Head LocalState -> (Set ChatroomState -> Maybe [ChatroomSetChange] -> IO ()) -> m WatchedHead +watchChatrooms h f = liftIO $ do + lastVar <- newIORef Nothing + watchHeadWith h (lookupSharedValue . lsShared . headObject) $ \cur -> do + let curList = chatroomSetToList cur + mbLast <- readIORef lastVar + writeIORef lastVar $ Just curList + f cur $ do + lastList <- mbLast + return $ makeChatroomDiff lastList curList + +chatroomSetToList :: Set ChatroomState -> [(Stored ChatroomStateData, ChatroomState)] +chatroomSetToList = map (cmp &&& id) . filter (not . roomStateDeleted) . fromSetBy (comparing cmp) + where + cmp :: ChatroomState -> Stored ChatroomStateData + cmp = head . filterAncestors . concatMap storedRoots . toComponents + +makeChatroomDiff + :: [(Stored ChatroomStateData, ChatroomState)] + -> [(Stored ChatroomStateData, ChatroomState)] + -> [ChatroomSetChange] +makeChatroomDiff (x@(cx, vx) : xs) (y@(cy, vy) : ys) + | cx < cy = RemovedChatroom vx : makeChatroomDiff xs (y : ys) + | cx > cy = AddedChatroom vy : makeChatroomDiff (x : xs) ys + | roomStateData vx /= roomStateData vy = UpdatedChatroom vx vy : makeChatroomDiff xs ys + | otherwise = makeChatroomDiff xs ys +makeChatroomDiff xs [] = map (RemovedChatroom . snd) xs +makeChatroomDiff [] ys = map (AddedChatroom . snd) ys + + +data ChatroomService = ChatroomService + { chatRoomQuery :: Bool + , chatRoomInfo :: [Stored (Signed ChatroomData)] + , chatRoomSubscribe :: [Stored (Signed ChatroomData)] + , chatRoomUnsubscribe :: [Stored (Signed ChatroomData)] + , chatRoomMessage :: [Stored (Signed ChatMessageData)] + } + deriving (Eq) + +emptyPacket :: ChatroomService +emptyPacket = ChatroomService + { chatRoomQuery = False + , chatRoomInfo = [] + , chatRoomSubscribe = [] + , chatRoomUnsubscribe = [] + , chatRoomMessage = [] + } + +instance Storable ChatroomService where + store' ChatroomService {..} = storeRec $ do + when chatRoomQuery $ storeEmpty "room-query" + forM_ chatRoomInfo $ storeRef "room-info" + forM_ chatRoomSubscribe $ storeRef "room-subscribe" + forM_ chatRoomUnsubscribe $ storeRef "room-unsubscribe" + forM_ chatRoomMessage $ storeRef "room-message" + + load' = loadRec $ do + chatRoomQuery <- isJust <$> loadMbEmpty "room-query" + chatRoomInfo <- loadRefs "room-info" + chatRoomSubscribe <- loadRefs "room-subscribe" + chatRoomUnsubscribe <- loadRefs "room-unsubscribe" + chatRoomMessage <- loadRefs "room-message" + return ChatroomService {..} + +data PeerState = PeerState + { psSendRoomUpdates :: Bool + , psLastList :: [(Stored ChatroomStateData, ChatroomState)] + , psSubscribedTo :: [ Stored (Signed ChatroomData) ] -- least root for each room + } + +instance Service ChatroomService where + serviceID _ = mkServiceID "627657ae-3e39-468a-8381-353395ef4386" + + type ServiceState ChatroomService = PeerState + emptyServiceState _ = PeerState + { psSendRoomUpdates = False + , psLastList = [] + , psSubscribedTo = [] + } + + serviceHandler spacket = do + let ChatroomService {..} = fromStored spacket + + previouslyUpdated <- psSendRoomUpdates <$> svcGet + svcModify $ \s -> s { psSendRoomUpdates = True } + + when (not previouslyUpdated) $ do + syncChatroomsToPeer . lookupSharedValue . lsShared . fromStored =<< getLocalHead + + when chatRoomQuery $ do + rooms <- listChatrooms + replyPacket emptyPacket + { chatRoomInfo = concatMap roomData $ catMaybes $ map roomStateRoom rooms + } + + when (not $ null chatRoomInfo) $ do + updateLocalState_ $ updateSharedState_ $ \roomSet -> do + let rooms = fromSetBy (comparing $ roomName <=< roomStateRoom) roomSet + upd set (roomInfo :: Stored (Signed ChatroomData)) = do + let currentRoots = storedRoots roomInfo + isCurrentRoom = any ((`intersectsSorted` currentRoots) . storedRoots) . + maybe [] roomData . roomStateRoom + + let prev = concatMap roomStateData $ filter isCurrentRoom rooms + prevRoom = filterAncestors $ concat $ findProperty ((\case [] -> Nothing; xs -> Just xs) . rsdRoom) prev + room = filterAncestors $ (roomInfo : ) prevRoom + + -- update local state only if we got roomInfo not present there + if roomInfo `notElem` prevRoom && roomInfo `elem` room + then do + sdata <- mstore emptyChatroomStateData + { rsdPrev = prev + , rsdRoom = room + } + storeSetAddComponent sdata set + else return set + foldM upd roomSet chatRoomInfo + + forM_ chatRoomSubscribe $ \subscribeData -> do + mbRoomState <- findChatroomByRoomData subscribeData + forM_ mbRoomState $ \roomState -> + forM (roomStateRoom roomState) $ \room -> do + let leastRoot = head . filterAncestors . concatMap storedRoots . roomData $ room + svcModify $ \ps -> ps { psSubscribedTo = leastRoot : psSubscribedTo ps } + replyPacket emptyPacket + { chatRoomMessage = roomStateMessageData roomState + } + + forM_ chatRoomUnsubscribe $ \unsubscribeData -> do + mbRoomState <- findChatroomByRoomData unsubscribeData + forM_ (mbRoomState >>= roomStateRoom) $ \room -> do + let leastRoot = head . filterAncestors . concatMap storedRoots . roomData $ room + svcModify $ \ps -> ps { psSubscribedTo = filter (/= leastRoot) (psSubscribedTo ps) } + + when (not (null chatRoomMessage)) $ do + updateLocalState_ $ updateSharedState_ $ \roomSet -> do + let rooms = fromSetBy (comparing $ roomName <=< roomStateRoom) roomSet + upd set (msgData :: Stored (Signed ChatMessageData)) + | Just msg <- validateSingleMessage msgData = do + let roomInfo = cmsgRoomData msg + currentRoots = filterAncestors $ concatMap storedRoots roomInfo + isCurrentRoom = any ((`intersectsSorted` currentRoots) . storedRoots) . + maybe [] roomData . roomStateRoom + + let prevData = concatMap roomStateData $ filter isCurrentRoom rooms + prev = mergeSorted prevData + prevMessages = roomStateMessageData prev + messages = filterAncestors $ msgData : prevMessages + + -- update local state only if subscribed and we got some new messages + if roomStateSubscribe prev && messages /= prevMessages + then do + sdata <- mstore emptyChatroomStateData + { rsdPrev = prevData + , rsdMessages = messages + } + storeSetAddComponent sdata set + else return set + | otherwise = return set + foldM upd roomSet chatRoomMessage + + serviceNewPeer = do + replyPacket emptyPacket { chatRoomQuery = True } + + serviceStorageWatchers _ = (:[]) $ + SomeStorageWatcher (lookupSharedValue . lsShared . fromStored) syncChatroomsToPeer + +syncChatroomsToPeer :: Set ChatroomState -> ServiceHandler ChatroomService () +syncChatroomsToPeer set = do + ps@PeerState {..} <- svcGet + when psSendRoomUpdates $ do + let curList = chatroomSetToList set + diff = makeChatroomDiff psLastList curList + + roomUpdates <- fmap (concat . catMaybes) $ + forM diff $ return . \case + AddedChatroom room -> roomData <$> roomStateRoom room + RemovedChatroom {} -> Nothing + UpdatedChatroom oldroom room + | roomStateData oldroom /= roomStateData room -> roomData <$> roomStateRoom room + | otherwise -> Nothing + + (subscribe, unsubscribe) <- fmap (partitionEithers . concat . catMaybes) $ + forM diff $ return . \case + AddedChatroom room + | roomStateSubscribe room + -> map Left . roomData <$> roomStateRoom room + RemovedChatroom oldroom + | roomStateSubscribe oldroom + -> map Right . roomData <$> roomStateRoom oldroom + UpdatedChatroom oldroom room + | roomStateSubscribe oldroom /= roomStateSubscribe room + -> map (if roomStateSubscribe room then Left else Right) . roomData <$> roomStateRoom room + _ -> Nothing + + messages <- fmap concat $ do + let leastRootFor = head . filterAncestors . concatMap storedRoots . roomData + forM diff $ return . \case + AddedChatroom rstate + | Just room <- roomStateRoom rstate + , leastRootFor room `elem` psSubscribedTo + -> roomStateMessageData rstate + UpdatedChatroom oldstate rstate + | Just room <- roomStateRoom rstate + , leastRootFor room `elem` psSubscribedTo + , roomStateMessageData oldstate /= roomStateMessageData rstate + -> roomStateMessageData rstate + _ -> [] + + let packet = emptyPacket + { chatRoomInfo = roomUpdates + , chatRoomSubscribe = subscribe + , chatRoomUnsubscribe = unsubscribe + , chatRoomMessage = messages + } + + when (packet /= emptyPacket) $ do + replyPacket packet + svcSet $ ps { psLastList = curList } diff --git a/src/Erebos/Contact.hs b/src/Erebos/Contact.hs new file mode 100644 index 0000000..88e6c44 --- /dev/null +++ b/src/Erebos/Contact.hs @@ -0,0 +1,175 @@ +module Erebos.Contact ( + Contact, + contactIdentity, + contactCustomName, + contactName, + + contactSetName, + + ContactService, + contactRequest, + contactAccept, + contactReject, +) where + +import Control.Monad +import Control.Monad.Except +import Control.Monad.Reader + +import Data.Maybe +import Data.Proxy +import Data.Text (Text) +import qualified Data.Text as T + +import Erebos.Identity +import Erebos.Network +import Erebos.Pairing +import Erebos.PubKey +import Erebos.Service +import Erebos.Set +import Erebos.State +import Erebos.Storable +import Erebos.Storage.Merge + +data Contact = Contact + { contactData :: [Stored ContactData] + , contactIdentity_ :: Maybe ComposedIdentity + , contactCustomName_ :: Maybe Text + } + +data ContactData = ContactData + { cdPrev :: [Stored ContactData] + , cdIdentity :: [Stored (Signed ExtendedIdentityData)] + , cdName :: Maybe Text + } + +instance Storable ContactData where + store' x = storeRec $ do + mapM_ (storeRef "PREV") $ cdPrev x + mapM_ (storeRef "identity") $ cdIdentity x + storeMbText "name" $ cdName x + + load' = loadRec $ ContactData + <$> loadRefs "PREV" + <*> loadRefs "identity" + <*> loadMbText "name" + +instance Mergeable Contact where + type Component Contact = ContactData + + mergeSorted cdata = Contact + { contactData = cdata + , contactIdentity_ = validateExtendedIdentityF $ concat $ findProperty ((\case [] -> Nothing; xs -> Just xs) . cdIdentity) cdata + , contactCustomName_ = findPropertyFirst cdName cdata + } + + toComponents = contactData + +instance SharedType (Set Contact) where + sharedTypeID _ = mkSharedTypeID "34fbb61e-6022-405f-b1b3-a5a1abecd25e" + +contactIdentity :: Contact -> Maybe ComposedIdentity +contactIdentity = contactIdentity_ + +contactCustomName :: Contact -> Maybe Text +contactCustomName = contactCustomName_ + +contactName :: Contact -> Text +contactName c = fromJust $ msum + [ contactCustomName c + , idName =<< contactIdentity c + , Just T.empty + ] + +contactSetName :: MonadHead LocalState m => Contact -> Text -> Set Contact -> m (Set Contact) +contactSetName contact name set = do + st <- getStorage + cdata <- wrappedStore st ContactData + { cdPrev = toComponents contact + , cdIdentity = [] + , cdName = Just name + } + storeSetAdd st (mergeSorted @Contact [cdata]) set + + +type ContactService = PairingService ContactAccepted + +data ContactAccepted = ContactAccepted + +instance Storable ContactAccepted where + store' ContactAccepted = storeRec $ do + storeText "accept" "" + load' = loadRec $ do + (_ :: T.Text) <- loadText "accept" + return ContactAccepted + +instance PairingResult ContactAccepted where + pairingServiceID _ = mkServiceID "d9c37368-0da1-4280-93e9-d9bd9a198084" + + pairingVerifyResult = return . Just + + pairingFinalizeRequest ContactAccepted = do + pid <- asks svcPeerIdentity + finalizeContact pid + + pairingFinalizeResponse = do + pid <- asks svcPeerIdentity + finalizeContact pid + return ContactAccepted + + defaultPairingAttributes _ = PairingAttributes + { pairingHookRequest = do + peer <- asks $ svcPeerIdentity + svcPrint $ "Contact pairing from " ++ T.unpack (displayIdentity peer) ++ " initiated" + + , pairingHookResponse = \confirm -> do + peer <- asks $ svcPeerIdentity + svcPrint $ "Confirm contact " ++ T.unpack (displayIdentity $ finalOwner peer) ++ ": " ++ confirm + + , pairingHookRequestNonce = \confirm -> do + peer <- asks $ svcPeerIdentity + svcPrint $ "Contact request from " ++ T.unpack (displayIdentity $ finalOwner peer) ++ ": " ++ confirm + + , pairingHookRequestNonceFailed = do + peer <- asks $ svcPeerIdentity + svcPrint $ "Failed contact request from " ++ T.unpack (displayIdentity peer) + + , pairingHookConfirmedResponse = do + svcPrint $ "Contact accepted, waiting for peer confirmation" + + , pairingHookConfirmedRequest = do + svcPrint $ "Contact confirmed by peer" + + , pairingHookAcceptedResponse = do + svcPrint $ "Contact accepted" + + , pairingHookAcceptedRequest = do + svcPrint $ "Contact accepted" + + , pairingHookVerifyFailed = return () + + , pairingHookRejected = do + svcPrint $ "Contact rejected by peer" + + , pairingHookFailed = \_ -> do + svcPrint $ "Contact failed" + } + +contactRequest :: (MonadIO m, MonadError e m, FromErebosError e) => Peer -> m () +contactRequest = pairingRequest @ContactAccepted Proxy + +contactAccept :: (MonadIO m, MonadError e m, FromErebosError e) => Peer -> m () +contactAccept = pairingAccept @ContactAccepted Proxy + +contactReject :: (MonadIO m, MonadError e m, FromErebosError e) => Peer -> m () +contactReject = pairingReject @ContactAccepted Proxy + +finalizeContact :: MonadHead LocalState m => UnifiedIdentity -> m () +finalizeContact identity = updateLocalState_ $ updateSharedState_ $ \contacts -> do + st <- getStorage + cdata <- wrappedStore st ContactData + { cdPrev = [] + , cdIdentity = idExtDataF $ finalOwner identity + , cdName = Nothing + } + storeSetAdd st (mergeSorted @Contact [cdata]) contacts diff --git a/src/Erebos/Conversation.hs b/src/Erebos/Conversation.hs new file mode 100644 index 0000000..187fddd --- /dev/null +++ b/src/Erebos/Conversation.hs @@ -0,0 +1,110 @@ +module Erebos.Conversation ( + Message, + messageFrom, + messageTime, + messageText, + messageUnread, + formatMessage, + + Conversation, + directMessageConversation, + chatroomConversation, + chatroomConversationByStateData, + reloadConversation, + lookupConversations, + + conversationName, + conversationPeer, + conversationHistory, + + sendMessage, + deleteConversation, +) where + +import Control.Monad.Except + +import Data.List +import Data.Maybe +import Data.Text (Text) +import Data.Text qualified as T +import Data.Time.Format +import Data.Time.LocalTime + +import Erebos.Chatroom +import Erebos.DirectMessage +import Erebos.Identity +import Erebos.State +import Erebos.Storable + + +data Message = DirectMessageMessage DirectMessage Bool + | ChatroomMessage ChatMessage Bool + +messageFrom :: Message -> ComposedIdentity +messageFrom (DirectMessageMessage msg _) = msgFrom msg +messageFrom (ChatroomMessage msg _) = cmsgFrom msg + +messageTime :: Message -> ZonedTime +messageTime (DirectMessageMessage msg _) = msgTime msg +messageTime (ChatroomMessage msg _) = cmsgTime msg + +messageText :: Message -> Maybe Text +messageText (DirectMessageMessage msg _) = Just $ msgText msg +messageText (ChatroomMessage msg _) = cmsgText msg + +messageUnread :: Message -> Bool +messageUnread (DirectMessageMessage _ unread) = unread +messageUnread (ChatroomMessage _ unread) = unread + +formatMessage :: TimeZone -> Message -> String +formatMessage tzone msg = concat + [ formatTime defaultTimeLocale "[%H:%M] " $ utcToLocalTime tzone $ zonedTimeToUTC $ messageTime msg + , maybe "<unnamed>" T.unpack $ idName $ messageFrom msg + , maybe "" ((": "<>) . T.unpack) $ messageText msg + ] + + +data Conversation = DirectMessageConversation DirectMessageThread + | ChatroomConversation ChatroomState + +directMessageConversation :: MonadHead LocalState m => ComposedIdentity -> m Conversation +directMessageConversation peer = do + (find (sameIdentity peer . msgPeer) . toThreadList . lookupSharedValue . lsShared . fromStored <$> getLocalHead) >>= \case + Just thread -> return $ DirectMessageConversation thread + Nothing -> return $ DirectMessageConversation $ DirectMessageThread peer [] [] [] [] + +chatroomConversation :: MonadHead LocalState m => ChatroomState -> m (Maybe Conversation) +chatroomConversation rstate = chatroomConversationByStateData (head $ roomStateData rstate) + +chatroomConversationByStateData :: MonadHead LocalState m => Stored ChatroomStateData -> m (Maybe Conversation) +chatroomConversationByStateData sdata = fmap ChatroomConversation <$> findChatroomByStateData sdata + +reloadConversation :: MonadHead LocalState m => Conversation -> m Conversation +reloadConversation (DirectMessageConversation thread) = directMessageConversation (msgPeer thread) +reloadConversation cur@(ChatroomConversation rstate) = + fromMaybe cur <$> chatroomConversation rstate + +lookupConversations :: MonadHead LocalState m => m [Conversation] +lookupConversations = map DirectMessageConversation . toThreadList . lookupSharedValue . lsShared . fromStored <$> getLocalHead + + +conversationName :: Conversation -> Text +conversationName (DirectMessageConversation thread) = fromMaybe (T.pack "<unnamed>") $ idName $ msgPeer thread +conversationName (ChatroomConversation rstate) = fromMaybe (T.pack "<unnamed>") $ roomName =<< roomStateRoom rstate + +conversationPeer :: Conversation -> Maybe ComposedIdentity +conversationPeer (DirectMessageConversation thread) = Just $ msgPeer thread +conversationPeer (ChatroomConversation _) = Nothing + +conversationHistory :: Conversation -> [Message] +conversationHistory (DirectMessageConversation thread) = map (\msg -> DirectMessageMessage msg False) $ threadToList thread +conversationHistory (ChatroomConversation rstate) = map (\msg -> ChatroomMessage msg False) $ roomStateMessages rstate + + +sendMessage :: (MonadHead LocalState m, MonadError e m, FromErebosError e) => Conversation -> Text -> m (Maybe Message) +sendMessage (DirectMessageConversation thread) text = fmap Just $ DirectMessageMessage <$> (fromStored <$> sendDirectMessage (msgPeer thread) text) <*> pure False +sendMessage (ChatroomConversation rstate) text = sendChatroomMessage rstate text >> return Nothing + +deleteConversation :: (MonadHead LocalState m, MonadError e m, FromErebosError e) => Conversation -> m () +deleteConversation (DirectMessageConversation _) = throwOtherError "deleting direct message conversation is not supported" +deleteConversation (ChatroomConversation rstate) = deleteChatroomByStateData (head $ roomStateData rstate) diff --git a/src/Erebos/DirectMessage.hs b/src/Erebos/DirectMessage.hs new file mode 100644 index 0000000..dc6724c --- /dev/null +++ b/src/Erebos/DirectMessage.hs @@ -0,0 +1,280 @@ +module Erebos.DirectMessage ( + DirectMessage(..), + sendDirectMessage, + + DirectMessageAttributes(..), + defaultDirectMessageAttributes, + + DirectMessageThreads, + toThreadList, + + DirectMessageThread(..), + threadToList, + messageThreadView, + + watchReceivedMessages, + formatDirectMessage, +) where + +import Control.Monad +import Control.Monad.Except +import Control.Monad.Reader + +import Data.List +import Data.Ord +import qualified Data.Set as S +import Data.Text (Text) +import qualified Data.Text as T +import Data.Time.Format +import Data.Time.LocalTime + +import Erebos.Discovery +import Erebos.Identity +import Erebos.Network +import Erebos.Object +import Erebos.Service +import Erebos.State +import Erebos.Storable +import Erebos.Storage.Head +import Erebos.Storage.Merge + +data DirectMessage = DirectMessage + { msgFrom :: ComposedIdentity + , msgPrev :: [Stored DirectMessage] + , msgTime :: ZonedTime + , msgText :: Text + } + +instance Storable DirectMessage where + store' msg = storeRec $ do + mapM_ (storeRef "from") $ idExtDataF $ msgFrom msg + mapM_ (storeRef "PREV") $ msgPrev msg + storeDate "time" $ msgTime msg + storeText "text" $ msgText msg + + load' = loadRec $ DirectMessage + <$> loadIdentity "from" + <*> loadRefs "PREV" + <*> loadDate "time" + <*> loadText "text" + +data DirectMessageAttributes = DirectMessageAttributes + { dmOwnerMismatch :: ServiceHandler DirectMessage () + } + +defaultDirectMessageAttributes :: DirectMessageAttributes +defaultDirectMessageAttributes = DirectMessageAttributes + { dmOwnerMismatch = svcPrint "Owner mismatch" + } + +instance Service DirectMessage where + serviceID _ = mkServiceID "c702076c-4928-4415-8b6b-3e839eafcb0d" + + type ServiceAttributes DirectMessage = DirectMessageAttributes + defaultServiceAttributes _ = defaultDirectMessageAttributes + + serviceHandler smsg = do + let msg = fromStored smsg + powner <- asks $ finalOwner . svcPeerIdentity + erb <- svcGetLocal + st <- getStorage + let DirectMessageThreads prev _ = lookupSharedValue $ lsShared $ fromStored erb + sent = findMsgProperty powner msSent prev + received = findMsgProperty powner msReceived prev + received' = filterAncestors $ smsg : received + if powner `sameIdentity` msgFrom msg || + filterAncestors sent == filterAncestors (smsg : sent) + then do + when (received' /= received) $ do + next <- wrappedStore st $ MessageState + { msPrev = prev + , msPeer = powner + , msReady = [] + , msSent = [] + , msReceived = received' + , msSeen = [] + } + let threads = DirectMessageThreads [next] (messageThreadView [next]) + shared <- makeSharedStateUpdate st threads (lsShared $ fromStored erb) + svcSetLocal =<< wrappedStore st (fromStored erb) { lsShared = [shared] } + + when (powner `sameIdentity` msgFrom msg) $ do + replyStoredRef smsg + + else join $ asks $ dmOwnerMismatch . svcAttributes + + serviceNewPeer = syncDirectMessageToPeer . lookupSharedValue . lsShared . fromStored =<< svcGetLocal + + serviceStorageWatchers _ = + [ SomeStorageWatcher (lookupSharedValue . lsShared . fromStored) syncDirectMessageToPeer + , GlobalStorageWatcher (lookupSharedValue . lsShared . fromStored) findMissingPeers + ] + + +data MessageState = MessageState + { msPrev :: [Stored MessageState] + , msPeer :: ComposedIdentity + , msReady :: [Stored DirectMessage] + , msSent :: [Stored DirectMessage] + , msReceived :: [Stored DirectMessage] + , msSeen :: [Stored DirectMessage] + } + +data DirectMessageThreads = DirectMessageThreads [Stored MessageState] [DirectMessageThread] + +instance Eq DirectMessageThreads where + DirectMessageThreads mss _ == DirectMessageThreads mss' _ = mss == mss' + +toThreadList :: DirectMessageThreads -> [DirectMessageThread] +toThreadList (DirectMessageThreads _ threads) = threads + +instance Storable MessageState where + store' MessageState {..} = storeRec $ do + mapM_ (storeRef "PREV") msPrev + mapM_ (storeRef "peer") $ idExtDataF msPeer + mapM_ (storeRef "ready") msReady + mapM_ (storeRef "sent") msSent + mapM_ (storeRef "received") msReceived + mapM_ (storeRef "seen") msSeen + + load' = loadRec $ do + msPrev <- loadRefs "PREV" + msPeer <- loadIdentity "peer" + msReady <- loadRefs "ready" + msSent <- loadRefs "sent" + msReceived <- loadRefs "received" + msSeen <- loadRefs "seen" + return MessageState {..} + +instance Mergeable DirectMessageThreads where + type Component DirectMessageThreads = MessageState + mergeSorted mss = DirectMessageThreads mss (messageThreadView mss) + toComponents (DirectMessageThreads mss _) = mss + +instance SharedType DirectMessageThreads where + sharedTypeID _ = mkSharedTypeID "ee793681-5976-466a-b0f0-4e1907d3fade" + +findMsgProperty :: Foldable m => Identity m -> (MessageState -> [a]) -> [Stored MessageState] -> [a] +findMsgProperty pid sel mss = concat $ flip findProperty mss $ \x -> do + guard $ msPeer x `sameIdentity` pid + guard $ not $ null $ sel x + return $ sel x + + +sendDirectMessage :: (Foldable f, Applicative f, MonadHead LocalState m) + => Identity f -> Text -> m (Stored DirectMessage) +sendDirectMessage pid text = updateLocalState $ \ls -> do + let self = localIdentity $ fromStored ls + powner = finalOwner pid + flip updateSharedState ls $ \(DirectMessageThreads prev _) -> do + let ready = findMsgProperty powner msReady prev + received = findMsgProperty powner msReceived prev + + time <- liftIO getZonedTime + smsg <- mstore DirectMessage + { msgFrom = toComposedIdentity $ finalOwner self + , msgPrev = filterAncestors $ ready ++ received + , msgTime = time + , msgText = text + } + next <- mstore MessageState + { msPrev = prev + , msPeer = powner + , msReady = [smsg] + , msSent = [] + , msReceived = [] + , msSeen = [] + } + return (DirectMessageThreads [next] (messageThreadView [next]), smsg) + +syncDirectMessageToPeer :: DirectMessageThreads -> ServiceHandler DirectMessage () +syncDirectMessageToPeer (DirectMessageThreads mss _) = do + pid <- finalOwner <$> asks svcPeerIdentity + peer <- asks svcPeer + let thread = messageThreadFor pid mss + mapM_ (sendToPeerStored peer) $ msgHead thread + updateLocalState_ $ \ls -> do + let powner = finalOwner pid + flip updateSharedState_ ls $ \unchanged@(DirectMessageThreads prev _) -> do + let ready = findMsgProperty powner msReady prev + sent = findMsgProperty powner msSent prev + sent' = filterAncestors (ready ++ sent) + + if sent' /= sent + then do + next <- mstore MessageState + { msPrev = prev + , msPeer = powner + , msReady = [] + , msSent = sent' + , msReceived = [] + , msSeen = [] + } + return $ DirectMessageThreads [next] (messageThreadView [next]) + else do + return unchanged + +findMissingPeers :: Server -> DirectMessageThreads -> ExceptT ErebosError IO () +findMissingPeers server threads = do + forM_ (toThreadList threads) $ \thread -> do + when (msgHead thread /= msgReceived thread) $ do + mapM_ (discoverySearch server) $ map (refDigest . storedRef) $ idDataF $ msgPeer thread + + +data DirectMessageThread = DirectMessageThread + { msgPeer :: ComposedIdentity + , msgHead :: [ Stored DirectMessage ] + , msgSent :: [ Stored DirectMessage ] + , msgSeen :: [ Stored DirectMessage ] + , msgReceived :: [ Stored DirectMessage ] + } + +threadToList :: DirectMessageThread -> [DirectMessage] +threadToList thread = helper S.empty $ msgHead thread + where helper seen msgs + | msg : msgs' <- filter (`S.notMember` seen) $ reverse $ sortBy (comparing cmpView) msgs = + fromStored msg : helper (S.insert msg seen) (msgs' ++ msgPrev (fromStored msg)) + | otherwise = [] + cmpView msg = (zonedTimeToUTC $ msgTime $ fromStored msg, msg) + +messageThreadView :: [Stored MessageState] -> [DirectMessageThread] +messageThreadView = helper [] + where helper used ms' = case filterAncestors ms' of + mss@(sms : rest) + | any (sameIdentity $ msPeer $ fromStored sms) used -> + helper used $ msPrev (fromStored sms) ++ rest + | otherwise -> + let peer = msPeer $ fromStored sms + in messageThreadFor peer mss : helper (peer : used) (msPrev (fromStored sms) ++ rest) + _ -> [] + +messageThreadFor :: ComposedIdentity -> [Stored MessageState] -> DirectMessageThread +messageThreadFor peer mss = + let ready = findMsgProperty peer msReady mss + sent = findMsgProperty peer msSent mss + received = findMsgProperty peer msReceived mss + seen = findMsgProperty peer msSeen mss + + in DirectMessageThread + { msgPeer = peer + , msgHead = filterAncestors $ ready ++ received + , msgSent = filterAncestors $ sent ++ received + , msgSeen = filterAncestors $ ready ++ seen + , msgReceived = filterAncestors $ received + } + + +watchReceivedMessages :: Head LocalState -> (Stored DirectMessage -> IO ()) -> IO WatchedHead +watchReceivedMessages h f = do + let self = finalOwner $ localIdentity $ headObject h + watchHeadWith h (lookupSharedValue . lsShared . headObject) $ \(DirectMessageThreads sms _) -> do + forM_ (map fromStored sms) $ \ms -> do + mapM_ f $ filter (not . sameIdentity self . msgFrom . fromStored) $ msReceived ms + +formatDirectMessage :: TimeZone -> DirectMessage -> String +formatDirectMessage tzone msg = concat + [ formatTime defaultTimeLocale "[%H:%M] " $ utcToLocalTime tzone $ zonedTimeToUTC $ msgTime msg + , maybe "<unnamed>" T.unpack $ idName $ msgFrom msg + , ": " + , T.unpack $ msgText msg + ] diff --git a/src/Erebos/Discovery.hs b/src/Erebos/Discovery.hs new file mode 100644 index 0000000..ede9cc9 --- /dev/null +++ b/src/Erebos/Discovery.hs @@ -0,0 +1,595 @@ +{-# LANGUAGE CPP #-} + +module Erebos.Discovery ( + DiscoveryService(..), + DiscoveryAttributes(..), + DiscoveryConnection(..), + + discoverySearch, + discoverySetupTunnel, +) where + +import Control.Concurrent +import Control.Monad +import Control.Monad.Except +import Control.Monad.Reader + +import Data.IP qualified as IP +import Data.List +import Data.Map.Strict (Map) +import Data.Map.Strict qualified as M +import Data.Maybe +import Data.Proxy +import Data.Set (Set) +import Data.Set qualified as S +import Data.Text (Text) +import Data.Text qualified as T +import Data.Word + +import Network.Socket + +#ifdef ENABLE_ICE_SUPPORT +import Erebos.ICE +#endif +import Erebos.Identity +import Erebos.Network +import Erebos.Object +import Erebos.Service +import Erebos.Service.Stream +import Erebos.Storable + + +#ifndef ENABLE_ICE_SUPPORT +type IceConfig = () +type IceSession = () +type IceRemoteInfo = Stored Object +#endif + + +data DiscoveryService + = DiscoverySelf [ Text ] (Maybe Int) + | DiscoveryAcknowledged [ Text ] (Maybe Text) (Maybe Word16) (Maybe Text) (Maybe Word16) + | DiscoverySearch (Either Ref RefDigest) + | DiscoveryResult (Either Ref RefDigest) [ Text ] + | DiscoveryConnectionRequest DiscoveryConnection + | DiscoveryConnectionResponse DiscoveryConnection + +data DiscoveryAttributes = DiscoveryAttributes + { discoveryStunPort :: Maybe Word16 + , discoveryStunServer :: Maybe Text + , discoveryTurnPort :: Maybe Word16 + , discoveryTurnServer :: Maybe Text + , discoveryProvideTunnel :: Peer -> Bool + } + +defaultDiscoveryAttributes :: DiscoveryAttributes +defaultDiscoveryAttributes = DiscoveryAttributes + { discoveryStunPort = Nothing + , discoveryStunServer = Nothing + , discoveryTurnPort = Nothing + , discoveryTurnServer = Nothing + , discoveryProvideTunnel = const False + } + +data DiscoveryConnection = DiscoveryConnection + { dconnSource :: Either Ref RefDigest + , dconnTarget :: Either Ref RefDigest + , dconnAddress :: Maybe Text + , dconnTunnel :: Bool + , dconnIceInfo :: Maybe IceRemoteInfo + } + +emptyConnection :: Either Ref RefDigest -> Either Ref RefDigest -> DiscoveryConnection +emptyConnection dconnSource dconnTarget = DiscoveryConnection {..} + where + dconnAddress = Nothing + dconnTunnel = False + dconnIceInfo = Nothing + +instance Storable DiscoveryService where + store' x = storeRec $ do + case x of + DiscoverySelf addrs priority -> do + mapM_ (storeText "self") addrs + mapM_ (storeInt "priority") priority + DiscoveryAcknowledged addrs stunServer stunPort turnServer turnPort -> do + if null addrs then storeEmpty "ack" + else mapM_ (storeText "ack") addrs + storeMbText "stun-server" stunServer + storeMbInt "stun-port" stunPort + storeMbText "turn-server" turnServer + storeMbInt "turn-port" turnPort + DiscoverySearch edgst -> either (storeRawRef "search") (storeRawWeak "search") edgst + DiscoveryResult edgst addr -> do + either (storeRawRef "result") (storeRawWeak "result") edgst + mapM_ (storeText "address") addr + DiscoveryConnectionRequest conn -> storeConnection "request" conn + DiscoveryConnectionResponse conn -> storeConnection "response" conn + + where + storeConnection ctype DiscoveryConnection {..} = do + storeText "connection" $ ctype + either (storeRawRef "source") (storeRawWeak "source") dconnSource + either (storeRawRef "target") (storeRawWeak "target") dconnTarget + storeMbText "address" dconnAddress + when dconnTunnel $ storeEmpty "tunnel" + storeMbRef "ice-info" dconnIceInfo + + load' = loadRec $ msum + [ do + addrs <- loadTexts "self" + guard (not $ null addrs) + DiscoverySelf addrs + <$> loadMbInt "priority" + , do + addrs <- loadTexts "ack" + mbEmpty <- loadMbEmpty "ack" + guard (not (null addrs) || isJust mbEmpty) + DiscoveryAcknowledged + <$> pure addrs + <*> loadMbText "stun-server" + <*> loadMbInt "stun-port" + <*> loadMbText "turn-server" + <*> loadMbInt "turn-port" + , DiscoverySearch <$> msum + [ Left <$> loadRawRef "search" + , Right <$> loadRawWeak "search" + ] + , DiscoveryResult + <$> msum + [ Left <$> loadRawRef "result" + , Right <$> loadRawWeak "result" + ] + <*> loadTexts "address" + , loadConnection "request" DiscoveryConnectionRequest + , loadConnection "response" DiscoveryConnectionResponse + ] + where + loadConnection ctype ctor = do + ctype' <- loadText "connection" + guard $ ctype == ctype' + dconnSource <- msum + [ Left <$> loadRawRef "source" + , Right <$> loadRawWeak "source" + ] + dconnTarget <- msum + [ Left <$> loadRawRef "target" + , Right <$> loadRawWeak "target" + ] + dconnAddress <- loadMbText "address" + dconnTunnel <- isJust <$> loadMbEmpty "tunnel" + dconnIceInfo <- loadMbRef "ice-info" + return $ ctor DiscoveryConnection {..} + +data DiscoveryPeer = DiscoveryPeer + { dpPriority :: Int + , dpPeer :: Maybe Peer + , dpAddress :: [ Text ] + , dpIceSession :: Maybe IceSession + } + +emptyPeer :: DiscoveryPeer +emptyPeer = DiscoveryPeer + { dpPriority = 0 + , dpPeer = Nothing + , dpAddress = [] + , dpIceSession = Nothing + } + +data DiscoveryPeerState = DiscoveryPeerState + { dpsOurTunnelRequests :: [ ( RefDigest, StreamWriter ) ] + -- ( original target, our write stream ) + , dpsRelayedTunnelRequests :: [ ( RefDigest, ( StreamReader, StreamWriter )) ] + -- ( original source, ( from source, to target )) + , dpsStunServer :: Maybe ( Text, Word16 ) + , dpsTurnServer :: Maybe ( Text, Word16 ) + , dpsIceConfig :: Maybe IceConfig + } + +data DiscoveryGlobalState = DiscoveryGlobalState + { dgsPeers :: Map RefDigest DiscoveryPeer + , dgsSearchingFor :: Set RefDigest + } + +instance Service DiscoveryService where + serviceID _ = mkServiceID "dd59c89c-69cc-4703-b75b-4ddcd4b3c23c" + + type ServiceAttributes DiscoveryService = DiscoveryAttributes + defaultServiceAttributes _ = defaultDiscoveryAttributes + + type ServiceState DiscoveryService = DiscoveryPeerState + emptyServiceState _ = DiscoveryPeerState + { dpsOurTunnelRequests = [] + , dpsRelayedTunnelRequests = [] + , dpsStunServer = Nothing + , dpsTurnServer = Nothing + , dpsIceConfig = Nothing + } + + type ServiceGlobalState DiscoveryService = DiscoveryGlobalState + emptyServiceGlobalState _ = DiscoveryGlobalState + { dgsPeers = M.empty + , dgsSearchingFor = S.empty + } + + serviceHandler msg = case fromStored msg of + DiscoverySelf addrs priority -> do + pid <- asks svcPeerIdentity + peer <- asks svcPeer + let insertHelper new old | dpPriority new > dpPriority old = new + | otherwise = old + matchedAddrs <- fmap catMaybes $ forM addrs $ \addr -> if + | addr == T.pack "ICE" -> do + return $ Just addr + + | [ ipaddr, port ] <- words (T.unpack addr) + , DatagramAddress paddr <- peerAddress peer -> do + saddr <- liftIO $ head <$> getAddrInfo (Just $ defaultHints { addrSocketType = Datagram }) (Just ipaddr) (Just port) + return $ if paddr == addrAddress saddr + then Just addr + else Nothing + + | otherwise -> return Nothing + + forM_ (idDataF =<< unfoldOwners pid) $ \sdata -> do + let dp = DiscoveryPeer + { dpPriority = fromMaybe 0 priority + , dpPeer = Just peer + , dpAddress = addrs + , dpIceSession = Nothing + } + svcModifyGlobal $ \s -> s { dgsPeers = M.insertWith insertHelper (refDigest $ storedRef sdata) dp $ dgsPeers s } + attrs <- asks svcAttributes + replyPacket $ DiscoveryAcknowledged matchedAddrs + (discoveryStunServer attrs) + (discoveryStunPort attrs) + (discoveryTurnServer attrs) + (discoveryTurnPort attrs) + + DiscoveryAcknowledged _ stunServer stunPort turnServer turnPort -> do + paddr <- asks (peerAddress . svcPeer) >>= return . \case + (DatagramAddress saddr) -> case IP.fromSockAddr saddr of + Just (IP.IPv6 ipv6, _) + | (0, 0, 0xffff, ipv4) <- IP.fromIPv6w ipv6 + -> Just $ T.pack $ show (IP.toIPv4w ipv4) + Just (addr, _) + -> Just $ T.pack $ show addr + _ -> Nothing + _ -> Nothing + + let toIceServer Nothing Nothing = Nothing + toIceServer Nothing (Just port) = ( , port) <$> paddr + toIceServer (Just server) Nothing = Just ( server, 0 ) + toIceServer (Just server) (Just port) = Just ( server, port ) + + svcModify $ \s -> s + { dpsStunServer = toIceServer stunServer stunPort + , dpsTurnServer = toIceServer turnServer turnPort + } + + DiscoverySearch edgst -> do + dpeer <- M.lookup (either refDigest id edgst) . dgsPeers <$> svcGetGlobal + replyPacket $ DiscoveryResult edgst $ maybe [] dpAddress dpeer + + DiscoveryResult _ [] -> do + -- not found + return () + + DiscoveryResult edgst addrs -> do + let dgst = either refDigest id edgst + -- TODO: check if we really requested that + server <- asks svcServer + st <- getStorage + self <- svcSelf + discoveryPeer <- asks svcPeer + let runAsService = runPeerService @DiscoveryService discoveryPeer + + forM_ addrs $ \addr -> if + | addr == T.pack "ICE" + -> do +#ifdef ENABLE_ICE_SUPPORT + getIceConfig >>= \case + Just config -> void $ liftIO $ forkIO $ do + ice <- iceCreateSession config PjIceSessRoleControlling $ \ice -> do + rinfo <- iceRemoteInfo ice + + -- Try to promote weak ref to normal one for older peers: + edgst' <- case edgst of + Left r -> return (Left r) + Right d -> refFromDigest st d >>= \case + Just r -> return (Left r) + Nothing -> return (Right d) + + res <- runExceptT $ sendToPeer discoveryPeer $ + DiscoveryConnectionRequest (emptyConnection (Left $ storedRef $ idData self) edgst') { dconnIceInfo = Just rinfo } + case res of + Right _ -> return () + Left err -> putStrLn $ "Discovery: failed to send connection request: " ++ err + + runAsService $ do + let upd dp = dp { dpIceSession = Just ice } + svcModifyGlobal $ \s -> s { dgsPeers = M.alter (Just . upd . fromMaybe emptyPeer) dgst $ dgsPeers s } + + Nothing -> do + return () +#endif + return () + + | [ ipaddr, port ] <- words (T.unpack addr) -> do + void $ liftIO $ forkIO $ do + saddr <- head <$> + getAddrInfo (Just $ defaultHints { addrSocketType = Datagram }) (Just ipaddr) (Just port) + peer <- serverPeer server (addrAddress saddr) + runAsService $ do + let upd dp = dp { dpPeer = Just peer } + svcModifyGlobal $ \s -> s { dgsPeers = M.alter (Just . upd . fromMaybe emptyPeer) dgst $ dgsPeers s } + + | otherwise -> do + svcPrint $ "Discovery: invalid address in result: " ++ T.unpack addr + + DiscoveryConnectionRequest conn -> do + self <- svcSelf + attrs <- asks svcAttributes + let rconn = emptyConnection (dconnSource conn) (dconnTarget conn) + if either refDigest id (dconnTarget conn) `elem` identityDigests self + then if + -- request for us, create ICE sesssion or tunnel + | dconnTunnel conn -> do + receivedStreams >>= \case + (tunnelReader : _) -> do + tunnelWriter <- openStream + replyPacket $ DiscoveryConnectionResponse rconn + { dconnTunnel = True + } + tunnelVia <- asks svcPeer + tunnelIdentity <- asks svcPeerIdentity + server <- asks svcServer + void $ liftIO $ forkIO $ do + tunnelStreamNumber <- getStreamWriterNumber tunnelWriter + let addr = TunnelAddress {..} + void $ serverPeerCustom server addr + receiveFromTunnel server addr + + [] -> do + svcPrint $ "Discovery: missing stream on tunnel request (endpoint)" + +#ifdef ENABLE_ICE_SUPPORT + | Just prinfo <- dconnIceInfo conn -> do + server <- asks svcServer + peer <- asks svcPeer + getIceConfig >>= \case + Just config -> do + liftIO $ void $ iceCreateSession config PjIceSessRoleControlled $ \ice -> do + rinfo <- iceRemoteInfo ice + res <- runExceptT $ sendToPeer peer $ DiscoveryConnectionResponse rconn { dconnIceInfo = Just rinfo } + case res of + Right _ -> iceConnect ice prinfo $ void $ serverPeerIce server ice + Left err -> putStrLn $ "Discovery: failed to send connection response: " ++ err + Nothing -> do + return () +#endif + + | otherwise -> do + svcPrint $ "Discovery: unsupported connection request" + + else do + -- request to some of our peers, relay + peer <- asks svcPeer + mbdp <- M.lookup (either refDigest id $ dconnTarget conn) . dgsPeers <$> svcGetGlobal + streams <- receivedStreams + case mbdp of + Nothing -> replyPacket $ DiscoveryConnectionResponse rconn + Just dp + | Just dpeer <- dpPeer dp -> if + | dconnTunnel conn -> if + | not (discoveryProvideTunnel attrs peer) -> do + replyPacket $ DiscoveryConnectionResponse rconn + | fromSource : _ <- streams -> do + void $ liftIO $ forkIO $ runPeerService @DiscoveryService dpeer $ do + toTarget <- openStream + svcModify $ \s -> s { dpsRelayedTunnelRequests = + ( either refDigest id $ dconnSource conn, ( fromSource, toTarget )) : dpsRelayedTunnelRequests s } + replyPacket $ DiscoveryConnectionRequest conn + | otherwise -> do + svcPrint $ "Discovery: missing stream on tunnel request (relay)" + | otherwise -> do + sendToPeer dpeer $ DiscoveryConnectionRequest conn + | otherwise -> svcPrint $ "Discovery: failed to relay connection request" + + DiscoveryConnectionResponse conn -> do + self <- svcSelf + dps <- svcGet + dpeers <- dgsPeers <$> svcGetGlobal + + if either refDigest id (dconnSource conn) `elem` identityDigests self + then do + -- response to our request, try to connect to the peer + server <- asks svcServer + if | Just addr <- dconnAddress conn + , [ipaddr, port] <- words (T.unpack addr) -> do + saddr <- liftIO $ head <$> + getAddrInfo (Just $ defaultHints { addrSocketType = Datagram }) (Just ipaddr) (Just port) + peer <- liftIO $ serverPeer server (addrAddress saddr) + let upd dp = dp { dpPeer = Just peer } + svcModifyGlobal $ \s -> s + { dgsPeers = M.alter (Just . upd . fromMaybe emptyPeer) (either refDigest id $ dconnTarget conn) $ dgsPeers s } + + | dconnTunnel conn + , Just tunnelWriter <- lookup (either refDigest id (dconnTarget conn)) (dpsOurTunnelRequests dps) + -> do + receivedStreams >>= \case + tunnelReader : _ -> do + tunnelVia <- asks svcPeer + tunnelIdentity <- asks svcPeerIdentity + void $ liftIO $ forkIO $ do + tunnelStreamNumber <- getStreamWriterNumber tunnelWriter + let addr = TunnelAddress {..} + void $ serverPeerCustom server addr + receiveFromTunnel server addr + [] -> svcPrint $ "Discovery: missing stream in tunnel response" + +#ifdef ENABLE_ICE_SUPPORT + | Just dp <- M.lookup (either refDigest id $ dconnTarget conn) dpeers + , Just ice <- dpIceSession dp + , Just rinfo <- dconnIceInfo conn -> do + liftIO $ iceConnect ice rinfo $ void $ serverPeerIce server ice +#endif + + | otherwise -> svcPrint $ "Discovery: connection request failed" + else do + -- response to relayed request + streams <- receivedStreams + svcModify $ \s -> s { dpsRelayedTunnelRequests = + filter ((either refDigest id (dconnSource conn) /=) . fst) (dpsRelayedTunnelRequests s) } + + case M.lookup (either refDigest id $ dconnSource conn) dpeers of + Just dp | Just dpeer <- dpPeer dp -> if + | dconnTunnel conn + , Just ( fromSource, toTarget ) <- lookup (either refDigest id (dconnSource conn)) (dpsRelayedTunnelRequests dps) + , fromTarget : _ <- streams + -> liftIO $ do + toSourceVar <- newEmptyMVar + void $ forkIO $ runPeerService @DiscoveryService dpeer $ do + liftIO . putMVar toSourceVar =<< openStream + svcModify $ \s -> s { dpsRelayedTunnelRequests = + ( either refDigest id $ dconnSource conn, ( fromSource, toTarget )) : dpsRelayedTunnelRequests s } + replyPacket $ DiscoveryConnectionResponse conn + void $ forkIO $ do + relayStream fromSource toTarget + void $ forkIO $ do + toSource <- readMVar toSourceVar + relayStream fromTarget toSource + + | otherwise -> do + sendToPeer dpeer $ DiscoveryConnectionResponse conn + _ -> svcPrint $ "Discovery: failed to relay connection response" + + serviceNewPeer = do + server <- asks svcServer + peer <- asks svcPeer + + let addrToText saddr = do + ( addr, port ) <- IP.fromSockAddr saddr + Just $ T.pack $ show addr <> " " <> show port + addrs <- concat <$> sequence + [ catMaybes . map addrToText <$> liftIO (getServerAddresses server) +#ifdef ENABLE_ICE_SUPPORT + , return [ T.pack "ICE" ] +#endif + ] + + pid <- asks svcPeerIdentity + gs <- svcGetGlobal + let searchingFor = foldl' (flip S.delete) (dgsSearchingFor gs) (identityDigests pid) + svcModifyGlobal $ \s -> s { dgsSearchingFor = searchingFor } + + when (not $ null addrs) $ do + sendToPeer peer $ DiscoverySelf addrs Nothing + forM_ searchingFor $ \dgst -> do + sendToPeer peer $ DiscoverySearch (Right dgst) + +#ifdef ENABLE_ICE_SUPPORT + serviceStopServer _ _ _ pstates = do + forM_ pstates $ \( _, DiscoveryPeerState {..} ) -> do + mapM_ iceStopThread dpsIceConfig +#endif + + +identityDigests :: Foldable f => Identity f -> [ RefDigest ] +identityDigests pid = map (refDigest . storedRef) $ idDataF =<< unfoldOwners pid + + +getIceConfig :: ServiceHandler DiscoveryService (Maybe IceConfig) +getIceConfig = do + dpsIceConfig <$> svcGet >>= \case + Just cfg -> return $ Just cfg + Nothing -> do +#ifdef ENABLE_ICE_SUPPORT + stun <- dpsStunServer <$> svcGet + turn <- dpsTurnServer <$> svcGet + liftIO (iceCreateConfig stun turn) >>= \case + Just cfg -> do + svcModify $ \s -> s { dpsIceConfig = Just cfg } + return $ Just cfg + Nothing -> do + svcPrint $ "Discovery: failed to create ICE config" + return Nothing +#else + return Nothing +#endif + + +discoverySearch :: (MonadIO m, MonadError e m, FromErebosError e) => Server -> RefDigest -> m () +discoverySearch server dgst = do + peers <- liftIO $ getCurrentPeerList server + match <- forM peers $ \peer -> do + peerIdentity peer >>= \case + PeerIdentityFull pid -> do + return $ dgst `elem` identityDigests pid + _ -> return False + when (not $ or match) $ do + modifyServiceGlobalState server (Proxy @DiscoveryService) $ \s -> (, ()) s + { dgsSearchingFor = S.insert dgst $ dgsSearchingFor s + } + forM_ peers $ \peer -> do + sendToPeer peer $ DiscoverySearch $ Right dgst + + +data TunnelAddress = TunnelAddress + { tunnelVia :: Peer + , tunnelIdentity :: UnifiedIdentity + , tunnelStreamNumber :: Int + , tunnelReader :: StreamReader + , tunnelWriter :: StreamWriter + } + +instance Eq TunnelAddress where + x == y = (==) + (idData (tunnelIdentity x), tunnelStreamNumber x) + (idData (tunnelIdentity y), tunnelStreamNumber y) + +instance Ord TunnelAddress where + compare x y = compare + (idData (tunnelIdentity x), tunnelStreamNumber x) + (idData (tunnelIdentity y), tunnelStreamNumber y) + +instance Show TunnelAddress where + show tunnel = concat + [ "tunnel@" + , show $ refDigest $ storedRef $ idData $ tunnelIdentity tunnel + , "/" <> show (tunnelStreamNumber tunnel) + ] + +instance PeerAddressType TunnelAddress where + sendBytesToAddress TunnelAddress {..} bytes = do + writeStream tunnelWriter bytes + +relayStream :: StreamReader -> StreamWriter -> IO () +relayStream r w = do + p <- readStreamPacket r + writeStreamPacket w p + case p of + StreamClosed {} -> return () + _ -> relayStream r w + +receiveFromTunnel :: Server -> TunnelAddress -> IO () +receiveFromTunnel server taddr = do + p <- readStreamPacket (tunnelReader taddr) + case p of + StreamData {..} -> do + receivedFromCustomAddress server taddr stpData + receiveFromTunnel server taddr + StreamClosed {} -> do + return () + + +discoverySetupTunnel :: Peer -> RefDigest -> IO () +discoverySetupTunnel via target = do + runPeerService via $ do + self <- refDigest . storedRef . idData <$> svcSelf + stream <- openStream + svcModify $ \s -> s { dpsOurTunnelRequests = ( target, stream ) : dpsOurTunnelRequests s } + replyPacket $ DiscoveryConnectionRequest + (emptyConnection (Right self) (Right target)) + { dconnTunnel = True + } diff --git a/src/Erebos/Error.hs b/src/Erebos/Error.hs new file mode 100644 index 0000000..3bb8736 --- /dev/null +++ b/src/Erebos/Error.hs @@ -0,0 +1,39 @@ +module Erebos.Error ( + ErebosError(..), + showErebosError, + + FromErebosError(..), + throwOtherError, +) where + +import Control.Monad.Except + + +data ErebosError + = ManyErrors [ ErebosError ] + | OtherError String + +showErebosError :: ErebosError -> String +showErebosError (ManyErrors errs) = unlines $ map showErebosError errs +showErebosError (OtherError str) = str + +instance Semigroup ErebosError where + ManyErrors [] <> b = b + a <> ManyErrors [] = a + ManyErrors a <> ManyErrors b = ManyErrors (a ++ b) + ManyErrors a <> b = ManyErrors (a ++ [ b ]) + a <> ManyErrors b = ManyErrors (a : b) + a@OtherError {} <> b@OtherError {} = ManyErrors [ a, b ] + +instance Monoid ErebosError where + mempty = ManyErrors [] + + +class FromErebosError e where + fromErebosError :: ErebosError -> e + +instance FromErebosError ErebosError where + fromErebosError = id + +throwOtherError :: (MonadError e m, FromErebosError e) => String -> m a +throwOtherError = throwError . fromErebosError . OtherError diff --git a/src/Erebos/Flow.hs b/src/Erebos/Flow.hs new file mode 100644 index 0000000..1e1a521 --- /dev/null +++ b/src/Erebos/Flow.hs @@ -0,0 +1,72 @@ +module Erebos.Flow ( + Flow, SymFlow, + newFlow, newFlowIO, + readFlow, tryReadFlow, canReadFlow, + writeFlow, writeFlowBulk, tryWriteFlow, canWriteFlow, + readFlowIO, writeFlowIO, + + mapFlow, +) where + +import Control.Concurrent.STM + + +data Flow r w + = Flow (TBQueue r) (TBQueue w) + | forall r' w'. MappedFlow (r' -> r) (w -> w') (Flow r' w') + +type SymFlow a = Flow a a + +newFlow :: STM (Flow a b, Flow b a) +newFlow = do + x <- newTBQueue 16 + y <- newTBQueue 16 + return (Flow x y, Flow y x) + +newFlowIO :: IO (Flow a b, Flow b a) +newFlowIO = atomically newFlow + +readFlow :: Flow r w -> STM r +readFlow (Flow rvar _) = readTBQueue rvar +readFlow (MappedFlow f _ up) = f <$> readFlow up + +tryReadFlow :: Flow r w -> STM (Maybe r) +tryReadFlow (Flow rvar _) = tryReadTBQueue rvar +tryReadFlow (MappedFlow f _ up) = fmap f <$> tryReadFlow up + +canReadFlow :: Flow r w -> STM Bool +canReadFlow (Flow rvar _) = not <$> isEmptyTBQueue rvar +canReadFlow (MappedFlow _ _ up) = canReadFlow up + +writeFlow :: Flow r w -> w -> STM () +writeFlow (Flow _ wvar) = writeTBQueue wvar +writeFlow (MappedFlow _ f up) = writeFlow up . f + +writeFlowBulk :: Flow r w -> [w] -> STM () +writeFlowBulk _ [] = return () +writeFlowBulk (Flow _ wvar) xs = mapM_ (writeTBQueue wvar) xs +writeFlowBulk (MappedFlow _ f up) xs = writeFlowBulk up $ map f xs + +tryWriteFlow :: Flow r w -> w -> STM Bool +tryWriteFlow (Flow _ wvar) x = do + isFullTBQueue wvar >>= \case + True -> return False + False -> do + writeTBQueue wvar x + return True +tryWriteFlow (MappedFlow _ f up) x = tryWriteFlow up $ f x + +canWriteFlow :: Flow r w -> STM Bool +canWriteFlow (Flow _ wvar) = not <$> isFullTBQueue wvar +canWriteFlow (MappedFlow _ _ up) = canWriteFlow up + +readFlowIO :: Flow r w -> IO r +readFlowIO path = atomically $ readFlow path + +writeFlowIO :: Flow r w -> w -> IO () +writeFlowIO path = atomically . writeFlow path + + +mapFlow :: (r -> r') -> (w' -> w) -> Flow r w -> Flow r' w' +mapFlow rf wf (MappedFlow rf' wf' up) = MappedFlow (rf . rf') (wf' . wf) up +mapFlow rf wf up = MappedFlow rf wf up diff --git a/src/Erebos/ICE.chs b/src/Erebos/ICE.chs new file mode 100644 index 0000000..dceeb2c --- /dev/null +++ b/src/Erebos/ICE.chs @@ -0,0 +1,242 @@ +{-# LANGUAGE ForeignFunctionInterface #-} +{-# LANGUAGE RecursiveDo #-} + +module Erebos.ICE ( + IceSession, + IceSessionRole(..), + IceConfig, + IceRemoteInfo, + + iceCreateConfig, + iceStopThread, + iceCreateSession, + iceDestroy, + iceRemoteInfo, + iceShow, + iceConnect, + iceSend, + + iceSetChan, +) where + +import Control.Arrow +import Control.Concurrent +import Control.Monad +import Control.Monad.Identity + +import Data.ByteString (ByteString, packCStringLen, useAsCString) +import Data.ByteString.Lazy.Char8 qualified as BLC +import Data.ByteString.Unsafe +import Data.Function +import Data.Text (Text) +import Data.Text qualified as T +import Data.Text.Encoding qualified as T +import Data.Text.Read qualified as T +import Data.Void +import Data.Word + +import Foreign.C.String +import Foreign.C.Types +import Foreign.ForeignPtr +import Foreign.Marshal.Alloc +import Foreign.Marshal.Array +import Foreign.Ptr +import Foreign.StablePtr + +import Erebos.Flow +import Erebos.Object +import Erebos.Storable +import Erebos.Storage + +#include "pjproject.h" + +data IceSession = IceSession + { isStrans :: PjIceStrans + , _isConfig :: IceConfig + , isChan :: MVar (Either [ByteString] (Flow Void ByteString)) + } + +instance Eq IceSession where + (==) = (==) `on` isStrans + +instance Ord IceSession where + compare = compare `on` isStrans + +instance Show IceSession where + show _ = "<ICE>" + + +data IceRemoteInfo = IceRemoteInfo + { iriUsernameFrament :: Text + , iriPassword :: Text + , iriDefaultCandidate :: Text + , iriCandidates :: [Text] + } + +data IceCandidate = IceCandidate + { icandFoundation :: Text + , icandPriority :: Int + , icandAddr :: Text + , icandPort :: Int + , icandType :: Text + } + +instance Storable IceRemoteInfo where + store' x = storeRec $ do + storeText "ice-ufrag" $ iriUsernameFrament x + storeText "ice-pass" $ iriPassword x + storeText "ice-default" $ iriDefaultCandidate x + mapM_ (storeText "ice-candidate") $ iriCandidates x + + load' = loadRec $ IceRemoteInfo + <$> loadText "ice-ufrag" + <*> loadText "ice-pass" + <*> loadText "ice-default" + <*> loadTexts "ice-candidate" + +instance StorableText IceCandidate where + toText x = T.concat $ + [ icandFoundation x + , T.singleton ' ' + , T.pack $ show $ icandPriority x + , T.singleton ' ' + , icandAddr x + , T.singleton ' ' + , T.pack $ show $ icandPort x + , T.singleton ' ' + , icandType x + ] + + fromText t = case T.words t of + [found, tprio, addr, tport, ctype] + | Right (prio, _) <- T.decimal tprio + , Right (port, _) <- T.decimal tport + -> return $ IceCandidate + { icandFoundation = found + , icandPriority = prio + , icandAddr = addr + , icandPort = port + , icandType = ctype + } + _ -> throwOtherError "failed to parse candidate" + + +{#enum pj_ice_sess_role as IceSessionRole {underscoreToCase} deriving (Show, Eq) #} + +data PjIceStransCfg +newtype IceConfig = IceConfig (ForeignPtr PjIceStransCfg) + +foreign import ccall unsafe "pjproject.h &ice_cfg_free" + ice_cfg_free :: FunPtr (Ptr PjIceStransCfg -> IO ()) +foreign import ccall unsafe "pjproject.h ice_cfg_create" + ice_cfg_create :: CString -> Word16 -> CString -> Word16 -> IO (Ptr PjIceStransCfg) + +iceCreateConfig :: Maybe ( Text, Word16 ) -> Maybe ( Text, Word16 ) -> IO (Maybe IceConfig) +iceCreateConfig stun turn = + maybe ($ nullPtr) (withText . fst) stun $ \cstun -> + maybe ($ nullPtr) (withText . fst) turn $ \cturn -> do + cfg <- ice_cfg_create cstun (maybe 0 snd stun) cturn (maybe 0 snd turn) + if cfg == nullPtr + then return Nothing + else Just . IceConfig <$> newForeignPtr ice_cfg_free cfg + +foreign import ccall unsafe "pjproject.h ice_cfg_stop_thread" + ice_cfg_stop_thread :: Ptr PjIceStransCfg -> IO () + +iceStopThread :: IceConfig -> IO () +iceStopThread (IceConfig fcfg) = withForeignPtr fcfg ice_cfg_stop_thread + +{#pointer *pj_ice_strans as ^ #} + +iceCreateSession :: IceConfig -> IceSessionRole -> (IceSession -> IO ()) -> IO IceSession +iceCreateSession icfg@(IceConfig fcfg) role cb = do + rec sptr <- newStablePtr sess + cbptr <- newStablePtr $ do + -- The callback may be called directly from pj_ice_strans_create or later + -- from a different thread; make sure we use a different thread here + -- to avoid deadlock on accessing 'sess'. + forkIO $ cb sess + sess <- IceSession + <$> (withForeignPtr fcfg $ \cfg -> + {#call ice_create #} (castPtr cfg) (fromIntegral $ fromEnum role) (castStablePtrToPtr sptr) (castStablePtrToPtr cbptr) + ) + <*> pure icfg + <*> (newMVar $ Left []) + return $ sess + +{#fun ice_destroy as ^ { isStrans `IceSession' } -> `()' #} + +iceRemoteInfo :: IceSession -> IO IceRemoteInfo +iceRemoteInfo sess = do + let maxlen = 128 + maxcand = 29 + + allocaBytes maxlen $ \ufrag -> + allocaBytes maxlen $ \pass -> + allocaBytes maxlen $ \def -> + allocaBytes (maxcand*maxlen) $ \bytes -> + allocaArray maxcand $ \carr -> do + let cptrs = take maxcand $ iterate (`plusPtr` maxlen) bytes + pokeArray carr $ take maxcand cptrs + + ncand <- {#call ice_encode_session #} (isStrans sess) ufrag pass def carr (fromIntegral maxlen) (fromIntegral maxcand) + if ncand < 0 then fail "failed to generate ICE remote info" + else IceRemoteInfo + <$> (T.pack <$> peekCString ufrag) + <*> (T.pack <$> peekCString pass) + <*> (T.pack <$> peekCString def) + <*> (mapM (return . T.pack <=< peekCString) $ take (fromIntegral ncand) cptrs) + +iceShow :: IceSession -> IO String +iceShow sess = do + st <- memoryStorage + return . drop 1 . dropWhile (/='\n') . BLC.unpack . runIdentity =<< + ioLoadBytes =<< store st =<< iceRemoteInfo sess + +iceConnect :: IceSession -> IceRemoteInfo -> (IO ()) -> IO () +iceConnect sess remote cb = do + cbptr <- newStablePtr $ cb + ice_connect sess cbptr + (iriUsernameFrament remote) + (iriPassword remote) + (iriDefaultCandidate remote) + (iriCandidates remote) + +{#fun ice_connect { isStrans `IceSession', castStablePtrToPtr `StablePtr (IO ())', + withText* `Text', withText* `Text', withText* `Text', withTextArray* `[Text]'& } -> `()' #} + +withText :: Text -> (Ptr CChar -> IO a) -> IO a +withText t f = useAsCString (T.encodeUtf8 t) f + +withTextArray :: Num n => [Text] -> ((Ptr (Ptr CChar), n) -> IO ()) -> IO () +withTextArray tsAll f = helper tsAll [] + where helper (t:ts) bs = withText t $ \b -> helper ts (b:bs) + helper [] bs = allocaArray (length bs) $ \ptr -> do + pokeArray ptr $ reverse bs + f (ptr, fromIntegral $ length bs) + +withByteStringLen :: Num n => ByteString -> ((Ptr CChar, n) -> IO a) -> IO a +withByteStringLen t f = unsafeUseAsCStringLen t (f . (id *** fromIntegral)) + +{#fun ice_send as ^ { isStrans `IceSession', withByteStringLen* `ByteString'& } -> `()' #} + +foreign export ccall ice_call_cb :: StablePtr (IO ()) -> IO () +ice_call_cb :: StablePtr (IO ()) -> IO () +ice_call_cb = join . deRefStablePtr + +iceSetChan :: IceSession -> Flow Void ByteString -> IO () +iceSetChan sess chan = do + modifyMVar_ (isChan sess) $ \orig -> do + case orig of + Left buf -> mapM_ (writeFlowIO chan) $ reverse buf + Right _ -> return () + return $ Right chan + +foreign export ccall ice_rx_data :: StablePtr IceSession -> Ptr CChar -> Int -> IO () +ice_rx_data :: StablePtr IceSession -> Ptr CChar -> Int -> IO () +ice_rx_data sptr buf len = do + sess <- deRefStablePtr sptr + bs <- packCStringLen (buf, len) + modifyMVar_ (isChan sess) $ \case + mc@(Right chan) -> writeFlowIO chan bs >> return mc + Left bss -> return $ Left (bs:bss) diff --git a/src/Erebos/ICE/pjproject.c b/src/Erebos/ICE/pjproject.c new file mode 100644 index 0000000..e9446fe --- /dev/null +++ b/src/Erebos/ICE/pjproject.c @@ -0,0 +1,423 @@ +#include "pjproject.h" +#include "Erebos/ICE_stub.h" + +#include <stdatomic.h> +#include <stdio.h> +#include <stdlib.h> +#include <stdbool.h> +#include <pthread.h> +#include <pjlib.h> +#include <pjlib-util.h> + +static struct +{ + pj_caching_pool cp; + pj_pool_t * pool; + pj_sockaddr def_addr; +} ice; + +struct erebos_ice_cfg +{ + pj_ice_strans_cfg cfg; + pj_thread_t * thread; + atomic_bool exit; +}; + +struct user_data +{ + pj_ice_sess_role role; + HsStablePtr sptr; + HsStablePtr cb_init; + HsStablePtr cb_connect; +}; + +static void ice_perror(const char * msg, pj_status_t status) +{ + char err[PJ_ERR_MSG_SIZE]; + pj_strerror(status, err, sizeof(err)); + fprintf(stderr, "ICE: %s: %s\n", msg, err); +} + +static int ice_worker_thread( void * vcfg ) +{ + struct erebos_ice_cfg * ecfg = (struct erebos_ice_cfg *)( vcfg ); + + while( ! ecfg->exit ){ + pj_time_val max_timeout = { 0, 0 }; + pj_time_val timeout = { 0, 0 }; + + max_timeout.msec = 500; + + pj_timer_heap_poll( ecfg->cfg.stun_cfg.timer_heap, &timeout ); + + pj_assert(timeout.sec >= 0 && timeout.msec >= 0); + if (timeout.msec >= 1000) + timeout.msec = 999; + + if (PJ_TIME_VAL_GT(timeout, max_timeout)) + timeout = max_timeout; + + int c = pj_ioqueue_poll( ecfg->cfg.stun_cfg.ioqueue, &timeout ); + if (c < 0) + pj_thread_sleep(PJ_TIME_VAL_MSEC(timeout)); + } + + return 0; +} + +static void cb_on_rx_data(pj_ice_strans * strans, unsigned comp_id, + void * pkt, pj_size_t size, + const pj_sockaddr_t * src_addr, unsigned src_addr_len) +{ + struct user_data * udata = pj_ice_strans_get_user_data(strans); + ice_rx_data(udata->sptr, pkt, size); +} + +static void cb_on_ice_complete(pj_ice_strans * strans, + pj_ice_strans_op op, pj_status_t status) +{ + if (status != PJ_SUCCESS) { + ice_perror("cb_on_ice_complete", status); + ice_destroy(strans); + return; + } + + struct user_data * udata = pj_ice_strans_get_user_data(strans); + if (op == PJ_ICE_STRANS_OP_INIT) { + pj_status_t istatus = pj_ice_strans_init_ice(strans, udata->role, NULL, NULL); + if (istatus != PJ_SUCCESS) + ice_perror("error creating session", istatus); + + if (udata->cb_init) { + ice_call_cb(udata->cb_init); + hs_free_stable_ptr(udata->cb_init); + udata->cb_init = NULL; + } + } + + if (op == PJ_ICE_STRANS_OP_NEGOTIATION) { + if (udata->cb_connect) { + ice_call_cb(udata->cb_connect); + hs_free_stable_ptr(udata->cb_connect); + udata->cb_connect = NULL; + } + } +} + +static void ice_init(void) +{ + static bool done = false; + static pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER; + pthread_mutex_lock(&mutex); + + if (done) { + pthread_mutex_unlock(&mutex); + return; + } + + pj_log_set_level(1); + + if (pj_init() != PJ_SUCCESS) { + fprintf(stderr, "pj_init failed\n"); + goto exit; + } + if (pjlib_util_init() != PJ_SUCCESS) { + fprintf(stderr, "pjlib_util_init failed\n"); + goto exit; + } + if (pjnath_init() != PJ_SUCCESS) { + fprintf(stderr, "pjnath_init failed\n"); + goto exit; + } + + pj_caching_pool_init(&ice.cp, NULL, 0); + + ice.pool = pj_pool_create(&ice.cp.factory, "ice", 512, 512, NULL); + +exit: + done = true; + pthread_mutex_unlock(&mutex); +} + +struct erebos_ice_cfg * ice_cfg_create( const char * stun_server, uint16_t stun_port, + const char * turn_server, uint16_t turn_port ) +{ + ice_init(); + + struct erebos_ice_cfg * ecfg = malloc( sizeof(struct erebos_ice_cfg) ); + pj_ice_strans_cfg_default( &ecfg->cfg ); + ecfg->exit = false; + ecfg->thread = NULL; + + ecfg->cfg.stun_cfg.pf = &ice.cp.factory; + if( pj_timer_heap_create( ice.pool, 100, + &ecfg->cfg.stun_cfg.timer_heap ) != PJ_SUCCESS ){ + fprintf( stderr, "pj_timer_heap_create failed\n" ); + goto fail; + } + + if( pj_ioqueue_create( ice.pool, 16, &ecfg->cfg.stun_cfg.ioqueue ) != PJ_SUCCESS ){ + fprintf( stderr, "pj_ioqueue_create failed\n" ); + goto fail; + } + + if( pj_thread_create( ice.pool, NULL, &ice_worker_thread, + ecfg, 0, 0, &ecfg->thread ) != PJ_SUCCESS ){ + fprintf( stderr, "pj_thread_create failed\n" ); + goto fail; + } + + ecfg->cfg.af = pj_AF_INET(); + ecfg->cfg.opt.aggressive = PJ_TRUE; + + if( stun_server ){ + ecfg->cfg.stun.server.ptr = malloc( strlen( stun_server )); + pj_strcpy2( &ecfg->cfg.stun.server, stun_server ); + if( stun_port ) + ecfg->cfg.stun.port = stun_port; + } + + if( turn_server ){ + ecfg->cfg.turn.server.ptr = malloc( strlen( turn_server )); + pj_strcpy2( &ecfg->cfg.turn.server, turn_server ); + if( turn_port ) + ecfg->cfg.turn.port = turn_port; + ecfg->cfg.turn.auth_cred.type = PJ_STUN_AUTH_CRED_STATIC; + ecfg->cfg.turn.auth_cred.data.static_cred.data_type = PJ_STUN_PASSWD_PLAIN; + ecfg->cfg.turn.conn_type = PJ_TURN_TP_UDP; + } + + return ecfg; +fail: + ice_cfg_free( ecfg ); + return NULL; +} + +void ice_cfg_free( struct erebos_ice_cfg * ecfg ) +{ + if( ! ecfg ) + return; + + ecfg->exit = true; + pj_thread_join( ecfg->thread ); + + if( ecfg->cfg.turn.server.ptr ) + free( ecfg->cfg.turn.server.ptr ); + + if( ecfg->cfg.stun.server.ptr ) + free( ecfg->cfg.stun.server.ptr ); + + if( ecfg->cfg.stun_cfg.ioqueue ) + pj_ioqueue_destroy( ecfg->cfg.stun_cfg.ioqueue ); + + if( ecfg->cfg.stun_cfg.timer_heap ) + pj_timer_heap_destroy( ecfg->cfg.stun_cfg.timer_heap ); + + free( ecfg ); +} + +void ice_cfg_stop_thread( struct erebos_ice_cfg * ecfg ) +{ + if( ! ecfg ) + return; + ecfg->exit = true; +} + +pj_ice_strans * ice_create( const struct erebos_ice_cfg * ecfg, pj_ice_sess_role role, + HsStablePtr sptr, HsStablePtr cb ) +{ + ice_init(); + + pj_ice_strans * res; + + struct user_data * udata = calloc( 1, sizeof( struct user_data )); + udata->role = role; + udata->sptr = sptr; + udata->cb_init = cb; + + pj_ice_strans_cb icecb = { + .on_rx_data = cb_on_rx_data, + .on_ice_complete = cb_on_ice_complete, + }; + + pj_status_t status = pj_ice_strans_create( NULL, &ecfg->cfg, 1, + udata, &icecb, &res ); + + if (status != PJ_SUCCESS) + ice_perror("error creating ice", status); + + return res; +} + +void ice_destroy(pj_ice_strans * strans) +{ + struct user_data * udata = pj_ice_strans_get_user_data(strans); + if (udata->sptr) + hs_free_stable_ptr(udata->sptr); + if (udata->cb_init) + hs_free_stable_ptr(udata->cb_init); + if (udata->cb_connect) + hs_free_stable_ptr(udata->cb_connect); + free(udata); + + pj_ice_strans_stop_ice(strans); + pj_ice_strans_destroy(strans); +} + +ssize_t ice_encode_session(pj_ice_strans * strans, char * ufrag, char * pass, + char * def, char * candidates[], size_t maxlen, size_t maxcand) +{ + int n; + pj_str_t local_ufrag, local_pwd; + pj_status_t status; + + status = pj_ice_strans_get_ufrag_pwd( strans, &local_ufrag, &local_pwd, NULL, NULL ); + if( status != PJ_SUCCESS ) + return -status; + + n = snprintf(ufrag, maxlen, "%.*s", (int) local_ufrag.slen, local_ufrag.ptr); + if (n < 0 || n == maxlen) + return -PJ_ETOOSMALL; + + n = snprintf(pass, maxlen, "%.*s", (int) local_pwd.slen, local_pwd.ptr); + if (n < 0 || n == maxlen) + return -PJ_ETOOSMALL; + + pj_ice_sess_cand cand[PJ_ICE_ST_MAX_CAND]; + char ipaddr[PJ_INET6_ADDRSTRLEN]; + + status = pj_ice_strans_get_def_cand(strans, 1, &cand[0]); + if (status != PJ_SUCCESS) + return -status; + + n = snprintf(def, maxlen, "%s %d", + pj_sockaddr_print(&cand[0].addr, ipaddr, sizeof(ipaddr), 0), + (int) pj_sockaddr_get_port(&cand[0].addr)); + if (n < 0 || n == maxlen) + return -PJ_ETOOSMALL; + + unsigned cand_cnt = PJ_ARRAY_SIZE(cand); + status = pj_ice_strans_enum_cands(strans, 1, &cand_cnt, cand); + if (status != PJ_SUCCESS) + return -status; + + for (unsigned i = 0; i < cand_cnt && i < maxcand; i++) { + char ipaddr[PJ_INET6_ADDRSTRLEN]; + n = snprintf(candidates[i], maxlen, + "%.*s %u %s %u %s", + (int) cand[i].foundation.slen, cand[i].foundation.ptr, + cand[i].prio, + pj_sockaddr_print(&cand[i].addr, ipaddr, sizeof(ipaddr), 0), + (unsigned) pj_sockaddr_get_port(&cand[i].addr), + pj_ice_get_cand_type_name(cand[i].type)); + + if (n < 0 || n == maxlen) + return -PJ_ETOOSMALL; + } + + return cand_cnt; +} + +void ice_connect(pj_ice_strans * strans, HsStablePtr cb, + const char * ufrag, const char * pass, + const char * defcand, const char * tcandidates[], size_t ncand) +{ + unsigned def_port = 0; + char def_addr[80]; + pj_bool_t done = PJ_FALSE; + char line[256]; + pj_ice_sess_cand candidates[PJ_ICE_ST_MAX_CAND]; + + struct user_data * udata = pj_ice_strans_get_user_data(strans); + udata->cb_connect = cb; + + def_addr[0] = '\0'; + + if (ncand == 0) { + fprintf(stderr, "ICE: no candidates\n"); + return; + } + + int cnt = sscanf(defcand, "%s %u", def_addr, &def_port); + if (cnt != 2) { + fprintf(stderr, "ICE: error parsing default candidate\n"); + return; + } + + int okcand = 0; + for (int i = 0; i < ncand; i++) { + char foundation[32], ipaddr[80], type[32]; + int prio, port; + + int cnt = sscanf(tcandidates[i], "%s %d %s %d %s", + foundation, &prio, + ipaddr, &port, + type); + if (cnt != 5) + continue; + + pj_ice_sess_cand * cand = &candidates[okcand]; + pj_bzero(cand, sizeof(*cand)); + + if (strcmp(type, "host") == 0) + cand->type = PJ_ICE_CAND_TYPE_HOST; + else if (strcmp(type, "srflx") == 0) + cand->type = PJ_ICE_CAND_TYPE_SRFLX; + else if (strcmp(type, "relay") == 0) + cand->type = PJ_ICE_CAND_TYPE_RELAYED; + else + continue; + + cand->comp_id = 1; + pj_strdup2(ice.pool, &cand->foundation, foundation); + cand->prio = prio; + + int af = strchr(ipaddr, ':') ? pj_AF_INET6() : pj_AF_INET(); + pj_str_t tmpaddr = pj_str(ipaddr); + pj_sockaddr_init(af, &cand->addr, NULL, 0); + pj_status_t status = pj_sockaddr_set_str_addr(af, &cand->addr, &tmpaddr); + if (status != PJ_SUCCESS) { + fprintf(stderr, "ICE: invalid IP address \"%s\"\n", ipaddr); + continue; + } + + pj_sockaddr_set_port(&cand->addr, (pj_uint16_t)port); + okcand++; + } + + pj_str_t tmp_addr; + pj_status_t status; + + int af = strchr(def_addr, ':') ? pj_AF_INET6() : pj_AF_INET(); + + pj_sockaddr_init(af, &ice.def_addr, NULL, 0); + tmp_addr = pj_str(def_addr); + status = pj_sockaddr_set_str_addr(af, &ice.def_addr, &tmp_addr); + if (status != PJ_SUCCESS) { + fprintf(stderr, "ICE: invalid default IP address \"%s\"\n", def_addr); + return; + } + pj_sockaddr_set_port(&ice.def_addr, (pj_uint16_t) def_port); + + pj_str_t rufrag, rpwd; + status = pj_ice_strans_start_ice(strans, + pj_cstr(&rufrag, ufrag), pj_cstr(&rpwd, pass), + okcand, candidates); + if (status != PJ_SUCCESS) { + ice_perror("error starting ICE", status); + return; + } +} + +void ice_send(pj_ice_strans * strans, const char * data, size_t len) +{ + if (!pj_ice_strans_sess_is_complete(strans)) { + fprintf(stderr, "ICE: negotiation has not been started or is in progress\n"); + return; + } + + pj_status_t status = pj_ice_strans_sendto2(strans, 1, data, len, + &ice.def_addr, pj_sockaddr_get_len(&ice.def_addr)); + if (status != PJ_SUCCESS && status != PJ_EPENDING) + ice_perror("error sending data", status); +} diff --git a/src/Erebos/ICE/pjproject.h b/src/Erebos/ICE/pjproject.h new file mode 100644 index 0000000..c31e227 --- /dev/null +++ b/src/Erebos/ICE/pjproject.h @@ -0,0 +1,20 @@ +#pragma once + +#include <pjnath.h> +#include <HsFFI.h> + +struct erebos_ice_cfg * ice_cfg_create( const char * stun_server, uint16_t stun_port, + const char * turn_server, uint16_t turn_port ); +void ice_cfg_free( struct erebos_ice_cfg * cfg ); +void ice_cfg_stop_thread( struct erebos_ice_cfg * cfg ); + +pj_ice_strans * ice_create( const struct erebos_ice_cfg *, pj_ice_sess_role role, + HsStablePtr sptr, HsStablePtr cb ); +void ice_destroy(pj_ice_strans * strans); + +ssize_t ice_encode_session(pj_ice_strans *, char * ufrag, char * pass, + char * def, char * candidates[], size_t maxlen, size_t maxcand); +void ice_connect(pj_ice_strans * strans, HsStablePtr cb, + const char * ufrag, const char * pass, + const char * defcand, const char * candidates[], size_t ncand); +void ice_send(pj_ice_strans *, const char * data, size_t len); diff --git a/src/Erebos/Identity.hs b/src/Erebos/Identity.hs new file mode 100644 index 0000000..a3f17b5 --- /dev/null +++ b/src/Erebos/Identity.hs @@ -0,0 +1,401 @@ +{-# LANGUAGE UndecidableInstances #-} + +module Erebos.Identity ( + Identity, ComposedIdentity, UnifiedIdentity, + IdentityData(..), ExtendedIdentityData(..), IdentityExtension(..), + idData, idDataF, idExtData, idExtDataF, + idName, idOwner, idUpdates, idKeyIdentity, idKeyMessage, + eiddBase, eiddStoredBase, + eiddName, eiddOwner, eiddKeyIdentity, eiddKeyMessage, + + emptyIdentityData, + emptyIdentityExtension, + createIdentity, + validateIdentity, validateIdentityF, validateIdentityFE, + validateExtendedIdentity, validateExtendedIdentityF, validateExtendedIdentityFE, + loadIdentity, loadMbIdentity, loadUnifiedIdentity, loadMbUnifiedIdentity, + + mergeIdentity, toUnifiedIdentity, toComposedIdentity, + updateIdentity, updateOwners, + sameIdentity, + + unfoldOwners, + finalOwner, + displayIdentity, +) where + +import Control.Arrow +import Control.Monad +import Control.Monad.Except +import Control.Monad.Identity qualified as I +import Control.Monad.Reader + +import Data.Either +import Data.Foldable +import Data.Function +import Data.List +import Data.Maybe +import Data.Set (Set) +import qualified Data.Set as S +import Data.Text (Text) +import qualified Data.Text as T + +import Erebos.PubKey +import Erebos.Storable +import Erebos.Storage.Merge +import Erebos.Util + +data Identity m = IdentityKind m => Identity + { idData_ :: m (Stored (Signed ExtendedIdentityData)) + , idName_ :: Maybe Text + , idOwner_ :: Maybe ComposedIdentity + , idUpdates_ :: [Stored (Signed ExtendedIdentityData)] + , idKeyIdentity_ :: Stored PublicKey + , idKeyMessage_ :: Stored PublicKey + } + +deriving instance Show (m (Stored (Signed ExtendedIdentityData))) => Show (Identity m) + +class (Functor f, Foldable f) => IdentityKind f where + ikFilterAncestors :: Storable a => f (Stored a) -> f (Stored a) + +instance IdentityKind I.Identity where + ikFilterAncestors = id + +instance IdentityKind [] where + ikFilterAncestors = filterAncestors + +type ComposedIdentity = Identity [] +type UnifiedIdentity = Identity I.Identity + +instance Eq (m (Stored (Signed ExtendedIdentityData))) => Eq (Identity m) where + (==) = (==) `on` (idData_ &&& idUpdates_) + +data IdentityData = IdentityData + { iddPrev :: [Stored (Signed IdentityData)] + , iddName :: Maybe Text + , iddOwner :: Maybe (Stored (Signed IdentityData)) + , iddKeyIdentity :: Stored PublicKey + , iddKeyMessage :: Maybe (Stored PublicKey) + } + deriving (Show) + +data IdentityExtension = IdentityExtension + { idePrev :: [Stored (Signed ExtendedIdentityData)] + , ideBase :: Stored (Signed IdentityData) + , ideName :: Maybe Text + , ideOwner :: Maybe (Stored (Signed ExtendedIdentityData)) + } + deriving (Show) + +data ExtendedIdentityData = BaseIdentityData IdentityData + | ExtendedIdentityData IdentityExtension + deriving (Show) + +baseToExtended :: Stored (Signed IdentityData) -> Stored (Signed ExtendedIdentityData) +baseToExtended = unsafeMapStored (unsafeMapSigned BaseIdentityData) + +instance Storable IdentityData where + store' idt = storeRec $ do + mapM_ (storeRef "SPREV") $ iddPrev idt + storeMbText "name" $ iddName idt + storeMbRef "owner" $ iddOwner idt + storeRef "key-id" $ iddKeyIdentity idt + storeMbRef "key-msg" $ iddKeyMessage idt + + load' = loadRec $ IdentityData + <$> loadRefs "SPREV" + <*> loadMbText "name" + <*> loadMbRef "owner" + <*> loadRef "key-id" + <*> loadMbRef "key-msg" + +instance Storable IdentityExtension where + store' IdentityExtension {..} = storeRec $ do + mapM_ (storeRef "SPREV") idePrev + storeRef "SBASE" ideBase + storeMbText "name" ideName + storeMbRef "owner" ideOwner + + load' = loadRec $ IdentityExtension + <$> loadRefs "SPREV" + <*> loadRef "SBASE" + <*> loadMbText "name" + <*> loadMbRef "owner" + +instance Storable ExtendedIdentityData where + store' (BaseIdentityData idata) = store' idata + store' (ExtendedIdentityData idata) = store' idata + + load' = msum + [ BaseIdentityData <$> load' + , ExtendedIdentityData <$> load' + ] + +instance Mergeable (Maybe ComposedIdentity) where + type Component (Maybe ComposedIdentity) = Signed ExtendedIdentityData + mergeSorted = validateExtendedIdentityF + toComponents = maybe [] idExtDataF + +idData :: UnifiedIdentity -> Stored (Signed IdentityData) +idData = I.runIdentity . idDataF + +idDataF :: Identity m -> m (Stored (Signed IdentityData)) +idDataF idt@Identity {} = ikFilterAncestors . fmap eiddStoredBase . idData_ $ idt + +idExtData :: UnifiedIdentity -> Stored (Signed ExtendedIdentityData) +idExtData = I.runIdentity . idExtDataF + +idExtDataF :: Identity m -> m (Stored (Signed ExtendedIdentityData)) +idExtDataF = idData_ + +idName :: Identity m -> Maybe Text +idName = idName_ + +idOwner :: Identity m -> Maybe ComposedIdentity +idOwner = idOwner_ + +idUpdates :: Identity m -> [Stored (Signed ExtendedIdentityData)] +idUpdates = idUpdates_ + +idKeyIdentity :: Identity m -> Stored PublicKey +idKeyIdentity = idKeyIdentity_ + +idKeyMessage :: Identity m -> Stored PublicKey +idKeyMessage = idKeyMessage_ + +eiddPrev :: ExtendedIdentityData -> [Stored (Signed ExtendedIdentityData)] +eiddPrev (BaseIdentityData idata) = baseToExtended <$> iddPrev idata +eiddPrev (ExtendedIdentityData IdentityExtension {..}) = baseToExtended ideBase : idePrev + +eiddBase :: ExtendedIdentityData -> IdentityData +eiddBase (BaseIdentityData idata) = idata +eiddBase (ExtendedIdentityData IdentityExtension {..}) = fromSigned ideBase + +eiddStoredBase :: Stored (Signed ExtendedIdentityData) -> Stored (Signed IdentityData) +eiddStoredBase ext = case fromSigned ext of + (BaseIdentityData idata) -> unsafeMapStored (unsafeMapSigned (const idata)) ext + (ExtendedIdentityData IdentityExtension {..}) -> ideBase + +eiddName :: ExtendedIdentityData -> Maybe Text +eiddName (BaseIdentityData idata) = iddName idata +eiddName (ExtendedIdentityData IdentityExtension {..}) = ideName + +eiddOwner :: ExtendedIdentityData -> Maybe (Stored (Signed ExtendedIdentityData)) +eiddOwner (BaseIdentityData idata) = baseToExtended <$> iddOwner idata +eiddOwner (ExtendedIdentityData IdentityExtension {..}) = ideOwner + +eiddKeyIdentity :: ExtendedIdentityData -> Stored PublicKey +eiddKeyIdentity = iddKeyIdentity . eiddBase + +eiddKeyMessage :: ExtendedIdentityData -> Maybe (Stored PublicKey) +eiddKeyMessage = iddKeyMessage . eiddBase + + +emptyIdentityData :: Stored PublicKey -> IdentityData +emptyIdentityData key = IdentityData + { iddName = Nothing + , iddPrev = [] + , iddOwner = Nothing + , iddKeyIdentity = key + , iddKeyMessage = Nothing + } + +emptyIdentityExtension :: Stored (Signed IdentityData) -> IdentityExtension +emptyIdentityExtension base = IdentityExtension + { idePrev = [] + , ideBase = base + , ideName = Nothing + , ideOwner = Nothing + } + +isExtension :: Stored (Signed ExtendedIdentityData) -> Bool +isExtension x = case fromSigned x of BaseIdentityData {} -> False + _ -> True + + +createIdentity :: Storage -> Maybe Text -> Maybe UnifiedIdentity -> IO UnifiedIdentity +createIdentity st name owner = do + (secret, public) <- generateKeys st + (_secretMsg, publicMsg) <- generateKeys st + + let signOwner :: Signed a -> ReaderT Storage IO (Signed a) + signOwner idd + | Just o <- owner = do + Just ownerSecret <- loadKeyMb (iddKeyIdentity $ fromSigned $ idData o) + signAdd ownerSecret idd + | otherwise = return idd + + Just identity <- flip runReaderT st $ do + baseData <- mstore =<< signOwner =<< sign secret =<< + mstore (emptyIdentityData public) + { iddOwner = idData <$> owner + , iddKeyMessage = Just publicMsg + } + let extOwner = do + odata <- idExtData <$> owner + guard $ isExtension odata + return odata + + validateExtendedIdentityF . I.Identity <$> + if isJust name || isJust extOwner + then mstore =<< signOwner =<< sign secret =<< + mstore . ExtendedIdentityData =<< return (emptyIdentityExtension baseData) + { ideName = name + , ideOwner = extOwner + } + else return $ baseToExtended baseData + return identity + +validateIdentity :: Stored (Signed IdentityData) -> Maybe UnifiedIdentity +validateIdentity = validateIdentityF . I.Identity + +validateIdentityF :: IdentityKind m => m (Stored (Signed IdentityData)) -> Maybe (Identity m) +validateIdentityF = either (const Nothing) Just . runExcept . validateIdentityFE + +validateIdentityFE :: IdentityKind m => m (Stored (Signed IdentityData)) -> Except String (Identity m) +validateIdentityFE = validateExtendedIdentityFE . fmap baseToExtended + +validateExtendedIdentity :: Stored (Signed ExtendedIdentityData) -> Maybe UnifiedIdentity +validateExtendedIdentity = validateExtendedIdentityF . I.Identity + +validateExtendedIdentityF :: IdentityKind m => m (Stored (Signed ExtendedIdentityData)) -> Maybe (Identity m) +validateExtendedIdentityF = either (const Nothing) Just . runExcept . validateExtendedIdentityFE + +validateExtendedIdentityFE :: IdentityKind m => m (Stored (Signed ExtendedIdentityData)) -> Except String (Identity m) +validateExtendedIdentityFE mdata = do + let idata = ikFilterAncestors mdata + when (null idata) $ throwError "null data" + mapM_ verifySignatures $ gatherPrevious S.empty $ toList idata + Identity + <$> pure idata + <*> pure (lookupProperty eiddName idata) + <*> case lookupProperty eiddOwner idata of + Nothing -> return Nothing + Just owner -> return <$> validateExtendedIdentityFE [owner] + <*> pure [] + <*> pure (eiddKeyIdentity $ fromSigned $ minimum idata) + <*> case lookupProperty eiddKeyMessage idata of + Nothing -> throwError "no message key" + Just mk -> return mk + +loadIdentity :: String -> LoadRec ComposedIdentity +loadIdentity name = maybe (throwOtherError "identity validation failed") return . validateExtendedIdentityF =<< loadRefs name + +loadMbIdentity :: String -> LoadRec (Maybe ComposedIdentity) +loadMbIdentity name = return . validateExtendedIdentityF =<< loadRefs name + +loadUnifiedIdentity :: String -> LoadRec UnifiedIdentity +loadUnifiedIdentity name = maybe (throwOtherError "identity validation failed") return . validateExtendedIdentity =<< loadRef name + +loadMbUnifiedIdentity :: String -> LoadRec (Maybe UnifiedIdentity) +loadMbUnifiedIdentity name = return . (validateExtendedIdentity =<<) =<< loadMbRef name + + +gatherPrevious :: Set (Stored (Signed ExtendedIdentityData)) -> [Stored (Signed ExtendedIdentityData)] -> Set (Stored (Signed ExtendedIdentityData)) +gatherPrevious res (n:ns) | n `S.member` res = gatherPrevious res ns + | otherwise = gatherPrevious (S.insert n res) $ (eiddPrev $ fromSigned n) ++ ns +gatherPrevious res [] = res + +verifySignatures :: Stored (Signed ExtendedIdentityData) -> Except String () +verifySignatures sidd = do + let idd = fromSigned sidd + required = concat + [ [ eiddKeyIdentity idd ] + , map (eiddKeyIdentity . fromSigned) $ eiddPrev idd + , map (eiddKeyIdentity . fromSigned) $ toList $ eiddOwner idd + ] + unless (all (fromStored sidd `isSignedBy`) required) $ do + throwError "signature verification failed" + +lookupProperty :: forall a m. Foldable m => (ExtendedIdentityData -> Maybe a) -> m (Stored (Signed ExtendedIdentityData)) -> Maybe a +lookupProperty sel topHeads = findResult propHeads + where + findPropHeads :: Stored (Signed ExtendedIdentityData) -> [ Stored (Signed ExtendedIdentityData) ] + findPropHeads sobj | Just _ <- sel $ fromSigned sobj = [ sobj ] + | otherwise = findPropHeads =<< (eiddPrev $ fromSigned sobj) + + propHeads :: [ Stored (Signed ExtendedIdentityData) ] + propHeads = filterAncestors $ findPropHeads =<< toList topHeads + + findResult :: [ Stored (Signed ExtendedIdentityData) ] -> Maybe a + findResult [] = Nothing + findResult xs = sel $ fromSigned $ minimum xs + +mergeIdentity :: (MonadStorage m, MonadError e m, FromErebosError e, MonadIO m) => Identity f -> m UnifiedIdentity +mergeIdentity idt | Just idt' <- toUnifiedIdentity idt = return idt' +mergeIdentity idt@Identity {..} = do + (owner, ownerData) <- case idOwner_ of + Nothing -> return (Nothing, Nothing) + Just cowner | Just owner <- toUnifiedIdentity cowner -> return (Just owner, Nothing) + | otherwise -> do owner <- mergeIdentity cowner + return (Just owner, Just $ idData owner) + + let public = idKeyIdentity idt + secret <- loadKey public + + unifiedBaseData <- + case toList $ idDataF idt of + [idata] -> return idata + idatas -> mstore =<< sign secret =<< mstore (emptyIdentityData public) + { iddPrev = idatas, iddOwner = ownerData } + + case filter isExtension $ toList $ idExtDataF idt of + [] -> return Identity { idData_ = I.Identity (baseToExtended unifiedBaseData), idOwner_ = toComposedIdentity <$> owner, .. } + extdata -> do + unifiedExtendedData <- mstore =<< sign secret =<< + (mstore . ExtendedIdentityData) (emptyIdentityExtension unifiedBaseData) + { idePrev = extdata } + return Identity { idData_ = I.Identity unifiedExtendedData, idOwner_ = toComposedIdentity <$> owner, .. } + + +toUnifiedIdentity :: Identity m -> Maybe UnifiedIdentity +toUnifiedIdentity Identity {..} + | [sdata] <- toList idData_ = Just Identity { idData_ = I.Identity sdata, .. } + | otherwise = Nothing + +toComposedIdentity :: Identity m -> ComposedIdentity +toComposedIdentity Identity {..} = Identity { idData_ = toList idData_ + , idOwner_ = toComposedIdentity <$> idOwner_ + , .. + } + +updateIdentity :: [Stored (Signed ExtendedIdentityData)] -> Identity m -> ComposedIdentity +updateIdentity [] orig = toComposedIdentity orig +updateIdentity updates orig@Identity {} = + case validateExtendedIdentityF $ ourUpdates ++ idata of + Just updated -> updated + { idOwner_ = updateIdentity ownerUpdates <$> idOwner_ updated + , idUpdates_ = ownerUpdates + } + Nothing -> toComposedIdentity orig + where idata = toList $ idData_ orig + idataRoots = foldl' mergeUniq [] $ map storedRoots idata + (ourUpdates, ownerUpdates) = partitionEithers $ flip map (filterAncestors $ updates ++ idUpdates_ orig) $ + -- if an update is related to anything in idData_, use it here, otherwise push to owners + \u -> if storedRoots u `intersectsSorted` idataRoots + then Left u + else Right u + +updateOwners :: [Stored (Signed ExtendedIdentityData)] -> Identity m -> Identity m +updateOwners updates orig@Identity { idOwner_ = Just owner, idUpdates_ = cupdates } = + orig { idOwner_ = Just $ updateIdentity updates owner, idUpdates_ = filterAncestors (updates ++ cupdates) } +updateOwners _ orig@Identity { idOwner_ = Nothing } = orig + +sameIdentity :: (Foldable m, Foldable m') => Identity m -> Identity m' -> Bool +sameIdentity x y = intersectsSorted (roots x) (roots y) + where + roots idt = uniq $ sort $ concatMap storedRoots $ toList $ idDataF idt + + +unfoldOwners :: (Foldable m) => Identity m -> [ComposedIdentity] +unfoldOwners = unfoldr (fmap (\i -> (i, idOwner i))) . Just . toComposedIdentity + +finalOwner :: (Foldable m, Applicative m) => Identity m -> ComposedIdentity +finalOwner = last . unfoldOwners + +displayIdentity :: (Foldable m, Applicative m) => Identity m -> Text +displayIdentity identity = T.concat + [ T.intercalate (T.pack " / ") $ map (fromMaybe (T.pack "<unnamed>") . idName) owners + ] + where owners = reverse $ unfoldOwners identity diff --git a/src/Erebos/Network.hs b/src/Erebos/Network.hs new file mode 100644 index 0000000..8da4c8d --- /dev/null +++ b/src/Erebos/Network.hs @@ -0,0 +1,1126 @@ +{-# LANGUAGE CPP #-} + +module Erebos.Network ( + Server, + startServer, + stopServer, + getCurrentPeerList, + getNextPeerChange, + getServerAddresses, + ServerOptions(..), serverIdentity, defaultServerOptions, + + Peer, peerServer, peerStorage, + PeerAddress(..), peerAddress, + PeerIdentity(..), peerIdentity, + WaitingRef, wrDigest, + Service(..), + + PeerAddressType(..), + receivedFromCustomAddress, + + serverPeer, + serverPeerCustom, +#ifdef ENABLE_ICE_SUPPORT + serverPeerIce, +#endif + dropPeer, + isPeerDropped, + sendToPeer, sendManyToPeer, + sendToPeerStored, sendManyToPeerStored, + sendToPeerWith, + runPeerService, + modifyServiceGlobalState, + + discoveryPort, +) where + +import Control.Concurrent +import Control.Concurrent.STM +import Control.Exception +import Control.Monad +import Control.Monad.Except +import Control.Monad.Reader +import Control.Monad.State + +import Data.ByteString (ByteString) +import Data.ByteString.Char8 qualified as BC +import Data.ByteString.Lazy qualified as BL +import Data.Function +import Data.IP qualified as IP +import Data.List +import Data.Map (Map) +import Data.Map qualified as M +import Data.Maybe +import Data.Typeable +import Data.Word + +import Foreign.C.Types +import Foreign.Marshal.Alloc +import Foreign.Marshal.Array +import Foreign.Ptr +import Foreign.Storable as F + +import GHC.Conc.Sync (unsafeIOToSTM) + +import Network.Socket hiding (ControlMessage) +import Network.Socket.ByteString qualified as S + +import Erebos.Error +#ifdef ENABLE_ICE_SUPPORT +import Erebos.ICE +#endif +import Erebos.Identity +import Erebos.Network.Channel +import Erebos.Network.Protocol +import Erebos.Object.Internal +import Erebos.PubKey +import Erebos.Service +import Erebos.State +import Erebos.Storage +import Erebos.Storage.Key +import Erebos.Storage.Merge + + +discoveryPort :: PortNumber +discoveryPort = 29665 + +discoveryMulticastGroup :: HostAddress6 +discoveryMulticastGroup = tupleToHostAddress6 (0xff12, 0xb6a4, 0x6b1f, 0x0969, 0xcaee, 0xacc2, 0x5c93, 0x73e1) -- ff12:b6a4:6b1f:969:caee:acc2:5c93:73e1 + +announceIntervalSeconds :: Int +announceIntervalSeconds = 60 + + +data Server = Server + { serverStorage :: Storage + , serverOptions :: ServerOptions + , serverOrigHead :: Head LocalState + , serverIdentity_ :: MVar UnifiedIdentity + , serverThreads :: MVar [ThreadId] + , serverSocket :: MVar Socket + , serverRawPath :: SymFlow (PeerAddress, BC.ByteString) + , serverControlFlow :: Flow (ControlMessage PeerAddress) (ControlRequest PeerAddress) + , serverDataResponse :: TQueue (Peer, Maybe PartialRef) + , serverIOActions :: TQueue (ExceptT ErebosError IO ()) + , serverServices :: [SomeService] + , serverServiceStates :: TMVar (M.Map ServiceID SomeServiceGlobalState) + , serverPeers :: MVar (Map PeerAddress Peer) + , serverChanPeer :: TChan Peer + , serverErrorLog :: TQueue String + } + +serverIdentity :: Server -> IO UnifiedIdentity +serverIdentity = readMVar . serverIdentity_ + +getCurrentPeerList :: Server -> IO [Peer] +getCurrentPeerList = fmap M.elems . readMVar . serverPeers + +getNextPeerChange :: Server -> IO Peer +getNextPeerChange = atomically . readTChan . serverChanPeer + +data ServerOptions = ServerOptions + { serverPort :: PortNumber + , serverLocalDiscovery :: Bool + } + +defaultServerOptions :: ServerOptions +defaultServerOptions = ServerOptions + { serverPort = discoveryPort + , serverLocalDiscovery = True + } + + +data Peer = Peer + { peerAddress :: PeerAddress + , peerServer_ :: Server + , peerState :: TVar PeerState + , peerIdentityVar :: TVar PeerIdentity + , peerStorage_ :: Storage + , peerInStorage :: PartialStorage + , peerServiceState :: TMVar (M.Map ServiceID SomeServiceState) + , peerWaitingRefs :: TMVar [WaitingRef] + } + +peerServer :: Peer -> Server +peerServer = peerServer_ + +peerStorage :: Peer -> Storage +peerStorage = peerStorage_ + +getPeerChannel :: Peer -> STM ChannelState +getPeerChannel Peer {..} = + readTVar peerState >>= \case + PeerInit _ -> return ChannelNone + PeerConnected conn -> connGetChannel conn + PeerDropped -> return ChannelClosed + +setPeerChannel :: Peer -> ChannelState -> STM () +setPeerChannel Peer {..} ch = do + readTVar peerState >>= \case + PeerInit _ -> retry + PeerConnected conn -> connSetChannel conn ch + PeerDropped -> return () + +instance Eq Peer where + (==) = (==) `on` peerIdentityVar + +class (Eq addr, Ord addr, Show addr, Typeable addr) => PeerAddressType addr where + sendBytesToAddress :: addr -> ByteString -> IO () + +data PeerAddress + = forall addr. PeerAddressType addr => CustomPeerAddress addr + | DatagramAddress SockAddr +#ifdef ENABLE_ICE_SUPPORT + | PeerIceSession IceSession +#endif + +instance Show PeerAddress where + show (CustomPeerAddress addr) = show addr + + show (DatagramAddress saddr) = unwords $ case IP.fromSockAddr saddr of + Just (IP.IPv6 ipv6, port) + | (0, 0, 0xffff, ipv4) <- IP.fromIPv6w ipv6 + -> [show (IP.toIPv4w ipv4), show port] + Just (addr, port) + -> [show addr, show port] + _ -> [show saddr] + +#ifdef ENABLE_ICE_SUPPORT + show (PeerIceSession ice) = show ice +#endif + +instance Eq PeerAddress where + CustomPeerAddress addr == CustomPeerAddress addr' + | Just addr'' <- cast addr' = addr == addr'' + DatagramAddress addr == DatagramAddress addr' = addr == addr' +#ifdef ENABLE_ICE_SUPPORT + PeerIceSession ice == PeerIceSession ice' = ice == ice' +#endif + _ == _ = False + +instance Ord PeerAddress where + compare (CustomPeerAddress addr) (CustomPeerAddress addr') + | Just addr'' <- cast addr' = compare addr addr'' + | otherwise = compare (typeOf addr) (typeOf addr') + compare (CustomPeerAddress _ ) _ = LT + compare _ (CustomPeerAddress _ ) = GT + + compare (DatagramAddress addr) (DatagramAddress addr') = compare addr addr' +#ifdef ENABLE_ICE_SUPPORT + compare (DatagramAddress _ ) _ = LT + compare _ (DatagramAddress _ ) = GT + + compare (PeerIceSession ice ) (PeerIceSession ice') = compare ice ice' +#endif + + +data PeerIdentity = PeerIdentityUnknown (TVar [UnifiedIdentity -> ExceptT ErebosError IO ()]) + | PeerIdentityRef WaitingRef (TVar [UnifiedIdentity -> ExceptT ErebosError IO ()]) + | PeerIdentityFull UnifiedIdentity + +peerIdentity :: MonadIO m => Peer -> m PeerIdentity +peerIdentity = liftIO . atomically . readTVar . peerIdentityVar + + +data PeerState + = PeerInit [ ( SecurityRequirement, TransportPacket Ref, [ TransportHeaderItem ] ) ] + | PeerConnected (Connection PeerAddress) + | PeerDropped + + +lookupServiceType :: [TransportHeaderItem] -> Maybe ServiceID +lookupServiceType (ServiceType stype : _) = Just stype +lookupServiceType (_ : hs) = lookupServiceType hs +lookupServiceType [] = Nothing + +lookupNewStreams :: [TransportHeaderItem] -> [Word8] +lookupNewStreams (StreamOpen num : rest) = num : lookupNewStreams rest +lookupNewStreams (_ : rest) = lookupNewStreams rest +lookupNewStreams [] = [] + + +newWaitingRef :: RefDigest -> (Ref -> WaitingRefCallback) -> PacketHandler WaitingRef +newWaitingRef dgst act = do + peer@Peer {..} <- gets phPeer + wref <- WaitingRef peerStorage_ (partialRefFromDigest peerInStorage dgst) act <$> liftSTM (newTVar (Left [])) + modifyTMVarP peerWaitingRefs (wref:) + liftSTM $ writeTQueue (serverDataResponse $ peerServer peer) (peer, Nothing) + return wref + + +forkServerThread :: Server -> IO () -> IO () +forkServerThread server act = do + modifyMVar_ (serverThreads server) $ \ts -> do + t <- forkIO $ do + t <- myThreadId + act + modifyMVar_ (serverThreads server) $ return . filter (/=t) + return (t:ts) + +startServer :: ServerOptions -> Head LocalState -> (String -> IO ()) -> [SomeService] -> IO Server +startServer serverOptions serverOrigHead logd' serverServices = do + let serverStorage = headStorage serverOrigHead + serverIdentity_ <- newMVar $ headLocalIdentity serverOrigHead + serverThreads <- newMVar [] + serverSocket <- newEmptyMVar + (serverRawPath, protocolRawPath) <- newFlowIO + (serverControlFlow, protocolControlFlow) <- newFlowIO + serverDataResponse <- newTQueueIO + serverIOActions <- newTQueueIO + serverServiceStates <- newTMVarIO M.empty + serverPeers <- newMVar M.empty + serverChanPeer <- newTChanIO + serverErrorLog <- newTQueueIO + let server = Server {..} + + chanSvc <- newTQueueIO + + let logd = writeTQueue serverErrorLog + forkServerThread server $ forever $ do + logd' =<< atomically (readTQueue serverErrorLog) + + forkServerThread server $ dataResponseWorker server + forkServerThread server $ forever $ do + either (atomically . logd . showErebosError) return =<< runExceptT =<< + atomically (readTQueue serverIOActions) + + let open addr = do + sock <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr) + putMVar serverSocket sock + setSocketOption sock ReuseAddr 1 + setSocketOption sock Broadcast 1 + withFdSocket sock setCloseOnExecIfNeeded + bind sock (addrAddress addr) + return sock + + loop sock = do + when (serverLocalDiscovery serverOptions) $ forkServerThread server $ do + announceAddreses <- fmap concat $ sequence $ + [ map (SockAddrInet6 discoveryPort 0 discoveryMulticastGroup) <$> joinMulticast sock + , getBroadcastAddresses discoveryPort + ] + forever $ do + atomically $ writeFlowBulk serverControlFlow $ map (SendAnnounce . DatagramAddress) announceAddreses + threadDelay $ announceIntervalSeconds * 1000 * 1000 + + let announceUpdate identity = do + st <- derivePartialStorage serverStorage + let selfRef = partialRef st $ storedRef $ idExtData identity + updateRefs = map refDigest $ selfRef : map (partialRef st . storedRef) (idUpdates identity) + ackedBy = concat [[ Acknowledged r, Rejected r, DataRequest r ] | r <- updateRefs ] + hitems = map AnnounceUpdate updateRefs + packet = TransportPacket (TransportHeader $ hitems) [] + + ps <- readMVar serverPeers + forM_ ps $ \peer -> atomically $ do + ((,) <$> readTVar (peerIdentityVar peer) <*> getPeerChannel peer) >>= \case + (PeerIdentityFull _, ChannelEstablished _) -> + sendToPeerS peer ackedBy packet + _ -> return () + + void $ watchHead serverOrigHead $ \h -> do + let idt = headLocalIdentity h + changedId <- modifyMVar serverIdentity_ $ \cur -> + return (idt, cur /= idt) + when changedId $ do + writeFlowIO serverControlFlow $ UpdateSelfIdentity idt + announceUpdate idt + + forM_ serverServices $ \(SomeService service _) -> do + forM_ (serviceStorageWatchers service) $ \case + SomeStorageWatcher sel act -> do + watchHeadWith serverOrigHead (sel . headStoredObject) $ \x -> do + withMVar serverPeers $ mapM_ $ \peer -> atomically $ do + readTVar (peerIdentityVar peer) >>= \case + PeerIdentityFull _ -> writeTQueue serverIOActions $ do + runPeerService peer $ act x + _ -> return () + GlobalStorageWatcher sel act -> do + watchHeadWith serverOrigHead (sel . headStoredObject) $ \x -> do + atomically $ writeTQueue serverIOActions $ do + act server x + + forkServerThread server $ forever $ do + (msg, saddr) <- S.recvFrom sock 4096 + writeFlowIO serverRawPath (DatagramAddress saddr, msg) + + forkServerThread server $ forever $ do + (paddr, msg) <- readFlowIO serverRawPath + handle (\(e :: SomeException) -> atomically . logd $ "failed to send packet to " ++ show paddr ++ ": " ++ show e) $ do + case paddr of + CustomPeerAddress addr -> sendBytesToAddress addr msg + DatagramAddress addr -> void $ S.sendTo sock msg addr +#ifdef ENABLE_ICE_SUPPORT + PeerIceSession ice -> iceSend ice msg +#endif + + forkServerThread server $ forever $ do + readFlowIO serverControlFlow >>= \case + NewConnection conn mbpid -> do + let paddr = connAddress conn + peer <- modifyMVar serverPeers $ \pvalue -> do + case M.lookup paddr pvalue of + Just peer -> return (pvalue, peer) + Nothing -> do + peer <- mkPeer server paddr + return (M.insert paddr peer pvalue, peer) + + forkServerThread server $ do + atomically $ do + readTVar (peerState peer) >>= \case + PeerInit packets -> do + writeFlowBulk (connData conn) $ reverse packets + writeTVar (peerState peer) (PeerConnected conn) + PeerConnected _ -> do + writeTVar (peerState peer) (PeerConnected conn) + PeerDropped -> do + connClose conn + + case mbpid of + Just dgst -> do + identity <- readMVar serverIdentity_ + atomically $ runPacketHandler False peer $ do + wref <- newWaitingRef dgst $ handleIdentityAnnounce identity peer + readTVarP (peerIdentityVar peer) >>= \case + PeerIdentityUnknown idwait -> do + addHeader $ AnnounceSelf $ refDigest $ storedRef $ idData identity + writeTVarP (peerIdentityVar peer) $ PeerIdentityRef wref idwait + liftSTM $ writeTChan serverChanPeer peer + _ -> return () + Nothing -> return () + + let peerLoop = readFlowIO (connData conn) >>= \case + Just (secure, TransportPacket header objs) -> do + prefs <- forM objs $ storeObject $ peerInStorage peer + identity <- readMVar serverIdentity_ + let svcs = map someServiceID serverServices + handlePacket identity secure peer chanSvc svcs header prefs + peerLoop + Nothing -> do + dropPeer peer + atomically $ writeTChan serverChanPeer peer + peerLoop + + ReceivedAnnounce addr _ -> do + void $ serverPeer' server addr + + erebosNetworkProtocol (headLocalIdentity serverOrigHead) logd protocolRawPath protocolControlFlow + + forkServerThread server $ withSocketsDo $ do + let hints = defaultHints + { addrFlags = [AI_PASSIVE] + , addrFamily = AF_INET6 + , addrSocketType = Datagram + } + addr:_ <- getAddrInfo (Just hints) Nothing (Just $ show $ serverPort serverOptions) + bracket (open addr) close loop + + forkServerThread server $ forever $ do + ( peer, svc, ref, streams ) <- atomically $ readTQueue chanSvc + case find ((svc ==) . someServiceID) serverServices of + Just service@(SomeService (_ :: Proxy s) attr) -> runPeerServiceOn (Just ( service, attr )) streams peer (serviceHandler $ wrappedLoad @s ref) + _ -> atomically $ logd $ "unhandled service '" ++ show (toUUID svc) ++ "'" + + return server + +stopServer :: Server -> IO () +stopServer server@Server {..} = do + withMVar serverPeers $ \peers -> do + ( global, peerStates ) <- atomically $ (,) + <$> takeTMVar serverServiceStates + <*> (forM (M.elems peers) $ \p@Peer {..} -> ( p, ) <$> takeTMVar peerServiceState) + + forM_ global $ \(SomeServiceGlobalState (proxy :: Proxy s) gs) -> do + ps <- forM peerStates $ \( peer, states ) -> + return $ ( peer, ) $ case M.lookup (serviceID proxy) states of + Just (SomeServiceState (_ :: Proxy ps) pstate) + | Just (Refl :: s :~: ps) <- eqT + -> pstate + _ -> emptyServiceState proxy + serviceStopServer proxy server gs ps + mapM_ killThread =<< takeMVar serverThreads + +dataResponseWorker :: Server -> IO () +dataResponseWorker server = forever $ do + (peer, npref) <- atomically (readTQueue $ serverDataResponse server) + + wait <- atomically $ takeTMVar (peerWaitingRefs peer) + list <- forM wait $ \wr@WaitingRef { wrefStatus = tvar } -> + atomically (readTVar tvar) >>= \case + Left ds -> case maybe id (filter . (/=) . refDigest) npref $ ds of + [] -> copyRef (wrefStorage wr) (wrefPartial wr) >>= \case + Right ref -> do + atomically (writeTVar tvar $ Right ref) + forkServerThread server $ runExceptT (wrefAction wr ref) >>= \case + Left err -> atomically $ writeTQueue (serverErrorLog server) (showErebosError err) + Right () -> return () + + return (Nothing, []) + Left dgst -> do + atomically (writeTVar tvar $ Left [dgst]) + return (Just wr, [dgst]) + ds' -> do + atomically (writeTVar tvar $ Left ds') + return (Just wr, []) + Right _ -> return (Nothing, []) + atomically $ putTMVar (peerWaitingRefs peer) $ catMaybes $ map fst list + + let reqs = concat $ map snd list + when (not $ null reqs) $ do + let packet = TransportPacket (TransportHeader $ map DataRequest reqs) [] + ackedBy = concat [[ Rejected r, DataResponse r ] | r <- reqs ] + atomically $ sendToPeerPlain peer ackedBy packet + + +newtype PacketHandler a = PacketHandler { unPacketHandler :: StateT PacketHandlerState (ExceptT String STM) a } + deriving (Functor, Applicative, Monad, MonadState PacketHandlerState, MonadError String) + +instance MonadFail PacketHandler where + fail = throwError + +runPacketHandler :: Bool -> Peer -> PacketHandler () -> STM () +runPacketHandler secure peer@Peer {..} act = do + let logd = writeTQueue $ serverErrorLog peerServer_ + runExceptT (flip execStateT (PacketHandlerState peer [] [] [] Nothing False) $ unPacketHandler act) >>= \case + Left err -> do + logd $ "Error in handling packet from " ++ show peerAddress ++ ": " ++ err + Right ph -> do + when (not $ null $ phHead ph) $ do + body <- case phBodyStream ph of + Nothing -> return $ phBody ph + Just stream -> do + writeTQueue (serverIOActions peerServer_) $ void $ liftIO $ forkIO $ do + writeByteStringToStream stream $ BL.concat $ map lazyLoadBytes $ phBody ph + return [] + let packet = TransportPacket (TransportHeader $ phHead ph) body + secreq = case (secure, phPlaintextReply ph) of + (True, _) -> EncryptedOnly + (False, False) -> PlaintextAllowed + (False, True) -> PlaintextOnly + sendToPeerS' secreq peer (phAckedBy ph) packet + +liftSTM :: STM a -> PacketHandler a +liftSTM = PacketHandler . lift . lift + +readTVarP :: TVar a -> PacketHandler a +readTVarP = liftSTM . readTVar + +writeTVarP :: TVar a -> a -> PacketHandler () +writeTVarP v = liftSTM . writeTVar v + +modifyTMVarP :: TMVar a -> (a -> a) -> PacketHandler () +modifyTMVarP v f = liftSTM $ putTMVar v . f =<< takeTMVar v + +data PacketHandlerState = PacketHandlerState + { phPeer :: Peer + , phHead :: [TransportHeaderItem] + , phAckedBy :: [TransportHeaderItem] + , phBody :: [Ref] + , phBodyStream :: Maybe RawStreamWriter + , phPlaintextReply :: Bool + } + +addHeader :: TransportHeaderItem -> PacketHandler () +addHeader h = modify $ \ph -> ph { phHead = h `appendDistinct` phHead ph } + +addAckedBy :: [TransportHeaderItem] -> PacketHandler () +addAckedBy hs = modify $ \ph -> ph { phAckedBy = foldr appendDistinct (phAckedBy ph) hs } + +addBody :: Ref -> PacketHandler () +addBody r = modify $ \ph -> ph { phBody = r `appendDistinct` phBody ph } + +sendBodyAsStream :: PacketHandler () +sendBodyAsStream = do + gets phBodyStream >>= \case + Nothing -> do + stream <- openStream + modify $ \ph -> ph { phBodyStream = Just stream } + Just _ -> return () + +keepPlaintextReply :: PacketHandler () +keepPlaintextReply = modify $ \ph -> ph { phPlaintextReply = True } + +openStream :: PacketHandler RawStreamWriter +openStream = do + Peer {..} <- gets phPeer + conn <- readTVarP peerState >>= \case + PeerConnected conn -> return conn + _ -> throwError "can't open stream without established connection" + (hdr, writer, handler) <- liftEither =<< liftSTM (connAddWriteStream conn) + + liftSTM $ writeTQueue (serverIOActions peerServer_) (liftIO $ forkServerThread peerServer_ handler) + addHeader hdr + return writer + +acceptStream :: Word8 -> PacketHandler RawStreamReader +acceptStream streamNumber = do + Peer {..} <- gets phPeer + conn <- readTVarP peerState >>= \case + PeerConnected conn -> return conn + _ -> throwError "can't accept stream without established connection" + liftSTM $ connAddReadStream conn streamNumber + +appendDistinct :: Eq a => a -> [a] -> [a] +appendDistinct x (y:ys) | x == y = y : ys + | otherwise = y : appendDistinct x ys +appendDistinct x [] = [x] + +handlePacket :: UnifiedIdentity -> Bool + -> Peer -> TQueue ( Peer, ServiceID, Ref, [ RawStreamReader ]) -> [ ServiceID ] + -> TransportHeader -> [ PartialRef ] -> IO () +handlePacket identity secure peer chanSvc svcs (TransportHeader headers) prefs = atomically $ do + let server = peerServer peer + ochannel <- getPeerChannel peer + let sidentity = idData identity + plaintextRefs = map (refDigest . storedRef) $ concatMap (collectStoredObjects . wrappedLoad) $ concat + [ [ storedRef sidentity ] + , map storedRef $ idUpdates identity + , case ochannel of + ChannelOurRequest req -> [ storedRef req ] + ChannelOurAccept acc _ -> [ storedRef acc ] + _ -> [] + ] + + runPacketHandler secure peer $ do + let logd = liftSTM . writeTQueue (serverErrorLog server) + forM_ headers $ \case + Acknowledged dgst -> do + liftSTM (getPeerChannel peer) >>= \case + ChannelOurAccept acc ch | refDigest (storedRef acc) == dgst -> do + liftSTM $ finalizedChannel peer ch identity + _ -> return () + + Rejected dgst + | peerRequest : _ <- mapMaybe (\case TrChannelRequest d -> Just d; _ -> Nothing) headers + , peerRequest < dgst + -> return () -- Our request was rejected due to lower priority + + | otherwise -> logd $ "rejected by peer: " ++ show dgst + + DataRequest dgst + | secure || dgst `elem` plaintextRefs -> do + Right mref <- liftSTM $ unsafeIOToSTM $ + copyRef (peerStorage peer) $ + partialRefFromDigest (peerInStorage peer) dgst + addHeader $ DataResponse dgst + addAckedBy [ Acknowledged dgst, Rejected dgst ] + + -- Plaintext request may indicate the peer has restarted/changed or + -- otherwise lost the channel, so keep the reply plaintext as well. + when (not secure) keepPlaintextReply + + -- TODO: MTU + when (secure && BL.length (lazyLoadBytes mref) > 500) + sendBodyAsStream + + addBody $ mref + | otherwise -> do + logd $ "unauthorized data request for " ++ show dgst + addHeader $ Rejected dgst + + DataResponse dgst -> if + | Just pref <- find ((==dgst) . refDigest) prefs -> do + when (not secure) $ do + addHeader $ Acknowledged dgst + liftSTM $ writeTQueue (serverDataResponse server) (peer, Just pref) + + | streamNumber : _ <- lookupNewStreams headers -> do + streamReader <- acceptStream streamNumber + liftSTM $ writeTQueue (serverIOActions server) $ void $ liftIO $ forkIO $ do + (runExcept <$> readObjectsFromStream (peerInStorage peer) streamReader) >>= \case + Left err -> atomically $ writeTQueue (serverErrorLog server) $ + "failed to receive object from stream: " <> showErebosError err + Right objs -> do + forM_ objs $ \obj -> do + pref <- storeObject (peerInStorage peer) obj + atomically $ writeTQueue (serverDataResponse server) (peer, Just pref) + + | otherwise -> throwError $ "mismatched data response " ++ show dgst + + AnnounceSelf dgst + | dgst == refDigest (storedRef sidentity) -> return () + | otherwise -> do + wref <- newWaitingRef dgst $ handleIdentityAnnounce identity peer + readTVarP (peerIdentityVar peer) >>= \case + PeerIdentityUnknown idwait -> do + addHeader $ AnnounceSelf $ refDigest $ storedRef $ idData identity + writeTVarP (peerIdentityVar peer) $ PeerIdentityRef wref idwait + liftSTM $ writeTChan (serverChanPeer $ peerServer peer) peer + _ -> return () + + AnnounceUpdate dgst -> do + readTVarP (peerIdentityVar peer) >>= \case + PeerIdentityFull _ -> do + void $ newWaitingRef dgst $ handleIdentityUpdate peer + _ -> return () + + TrChannelRequest dgst -> do + let process = do + addHeader $ Acknowledged dgst + wref <- newWaitingRef dgst $ handleChannelRequest peer identity + liftSTM $ setPeerChannel peer $ ChannelPeerRequest wref + reject = addHeader $ Rejected dgst + + liftSTM (getPeerChannel peer) >>= \case + ChannelNone {} -> return () + ChannelCookieWait {} -> return () + ChannelCookieReceived {} -> process + ChannelCookieConfirmed {} -> process + ChannelOurRequest our + | dgst < refDigest (storedRef our) -> process + | otherwise -> do + -- Reject peer channel request with lower priority + addHeader $ TrChannelRequest $ refDigest $ storedRef our + reject + ChannelPeerRequest prev + | dgst == wrDigest prev -> addHeader $ Acknowledged dgst + | otherwise -> process + ChannelOurAccept {} -> reject + ChannelEstablished {} -> process + ChannelClosed {} -> return () + + TrChannelAccept dgst -> do + let process = do + handleChannelAccept identity $ partialRefFromDigest (peerInStorage peer) dgst + reject = addHeader $ Rejected dgst + liftSTM (getPeerChannel peer) >>= \case + ChannelNone {} -> reject + ChannelCookieWait {} -> reject + ChannelCookieReceived {} -> reject + ChannelCookieConfirmed {} -> reject + ChannelOurRequest {} -> process + ChannelPeerRequest {} -> process + ChannelOurAccept our _ | dgst < refDigest (storedRef our) -> process + | otherwise -> addHeader $ Rejected dgst + ChannelEstablished {} -> process + ChannelClosed {} -> return () + + ServiceType _ -> return () + ServiceRef dgst + | not secure -> throwError $ "service packet without secure channel" + | Just svc <- lookupServiceType headers -> if + | svc `elem` svcs -> do + if dgst `elem` map refDigest prefs || True {- TODO: used by Message service to confirm receive -} + then do + streamReaders <- mapM acceptStream $ lookupNewStreams headers + void $ newWaitingRef dgst $ \ref -> + liftIO $ atomically $ writeTQueue chanSvc ( peer, svc, ref, streamReaders ) + else throwError $ "missing service object " ++ show dgst + | otherwise -> addHeader $ Rejected dgst + | otherwise -> throwError $ "service ref without type" + + _ -> return () + + +withPeerIdentity :: MonadIO m => Peer -> (UnifiedIdentity -> ExceptT ErebosError IO ()) -> m () +withPeerIdentity peer act = liftIO $ atomically $ readTVar (peerIdentityVar peer) >>= \case + PeerIdentityUnknown tvar -> modifyTVar' tvar (act:) + PeerIdentityRef _ tvar -> modifyTVar' tvar (act:) + PeerIdentityFull idt -> writeTQueue (serverIOActions $ peerServer peer) (act idt) + + +setupChannel :: UnifiedIdentity -> Peer -> UnifiedIdentity -> WaitingRefCallback +setupChannel identity peer upid = do + req <- flip runReaderT (peerStorage peer) $ createChannelRequest identity upid + let reqref = refDigest $ storedRef req + let hitems = + [ TrChannelRequest reqref + , AnnounceSelf $ refDigest $ storedRef $ idData identity + ] + let sendChannelRequest = do + sendToPeerPlain peer [ Acknowledged reqref, Rejected reqref ] $ + TransportPacket (TransportHeader hitems) [storedRef req] + setPeerChannel peer $ ChannelOurRequest req + liftIO $ atomically $ do + getPeerChannel peer >>= \case + ChannelCookieReceived -> sendChannelRequest + ChannelCookieConfirmed -> sendChannelRequest + _ -> return () + +handleChannelRequest :: Peer -> UnifiedIdentity -> Ref -> WaitingRefCallback +handleChannelRequest peer identity req = do + withPeerIdentity peer $ \upid -> do + (acc, ch) <- flip runReaderT (peerStorage peer) $ acceptChannelRequest identity upid (wrappedLoad req) + liftIO $ atomically $ do + getPeerChannel peer >>= \case + ChannelPeerRequest wr | wrDigest wr == refDigest req -> do + setPeerChannel peer $ ChannelOurAccept acc ch + let accref = refDigest $ storedRef acc + header = TrChannelAccept accref + ackedBy = [ Acknowledged accref, Rejected accref ] + sendToPeerPlain peer ackedBy $ TransportPacket (TransportHeader [header]) $ concat + [ [ storedRef $ acc ] + , [ storedRef $ signedData $ fromStored acc ] + , [ storedRef $ caKey $ fromStored $ signedData $ fromStored acc ] + , map storedRef $ signedSignature $ fromStored acc + ] + _ -> writeTQueue (serverErrorLog $ peerServer peer) $ "unexpected channel request" + +handleChannelAccept :: UnifiedIdentity -> PartialRef -> PacketHandler () +handleChannelAccept identity accref = do + peer <- gets phPeer + liftSTM $ writeTQueue (serverIOActions $ peerServer peer) $ do + withPeerIdentity peer $ \upid -> do + copyRef (peerStorage peer) accref >>= \case + Right acc -> do + ch <- acceptedChannel identity upid (wrappedLoad acc) + liftIO $ atomically $ do + sendToPeerS peer [] $ TransportPacket (TransportHeader [Acknowledged $ refDigest accref]) [] + finalizedChannel peer ch identity + + Left dgst -> throwOtherError $ "missing accept data " ++ BC.unpack (showRefDigest dgst) + + +finalizedChannel :: Peer -> Channel -> UnifiedIdentity -> STM () +finalizedChannel peer@Peer {..} ch self = do + setPeerChannel peer $ ChannelEstablished ch + + -- Identity update + writeTQueue (serverIOActions peerServer_) $ liftIO $ atomically $ do + let selfRef = refDigest $ storedRef $ idExtData $ self + updateRefs = selfRef : map (refDigest . storedRef) (idUpdates self) + ackedBy = concat [[ Acknowledged r, Rejected r, DataRequest r ] | r <- updateRefs ] + sendToPeerS peer ackedBy $ flip TransportPacket [] $ TransportHeader $ map AnnounceUpdate updateRefs + + -- Notify services about new peer + readTVar peerIdentityVar >>= \case + PeerIdentityFull _ -> notifyServicesOfPeer peer + _ -> return () + + +handleIdentityAnnounce :: UnifiedIdentity -> Peer -> Ref -> WaitingRefCallback +handleIdentityAnnounce self peer ref = liftIO $ atomically $ do + let validateAndUpdate upds act = case validateIdentity $ wrappedLoad ref of + Just pid' -> do + let pid = fromMaybe pid' $ toUnifiedIdentity (updateIdentity upds pid') + writeTVar (peerIdentityVar peer) $ PeerIdentityFull pid + writeTChan (serverChanPeer $ peerServer peer) peer + act pid + writeTQueue (serverIOActions $ peerServer peer) $ do + setupChannel self peer pid + Nothing -> return () + + readTVar (peerIdentityVar peer) >>= \case + PeerIdentityRef wref wact + | wrDigest wref == refDigest ref + -> validateAndUpdate [] $ \pid -> do + mapM_ (writeTQueue (serverIOActions $ peerServer peer) . ($ pid)) . + reverse =<< readTVar wact + + PeerIdentityFull pid + | idData pid `precedes` wrappedLoad ref + -> validateAndUpdate (idUpdates pid) $ \_ -> do + notifyServicesOfPeer peer + + _ -> return () + +handleIdentityUpdate :: Peer -> Ref -> WaitingRefCallback +handleIdentityUpdate peer ref = liftIO $ atomically $ do + pidentity <- readTVar (peerIdentityVar peer) + if | PeerIdentityFull pid <- pidentity + , Just pid' <- toUnifiedIdentity $ updateIdentity [wrappedLoad ref] pid + -> do + writeTVar (peerIdentityVar peer) $ PeerIdentityFull pid' + writeTChan (serverChanPeer $ peerServer peer) peer + when (idData pid /= idData pid') $ notifyServicesOfPeer peer + + | otherwise -> return () + +notifyServicesOfPeer :: Peer -> STM () +notifyServicesOfPeer peer@Peer { peerServer_ = Server {..} } = do + writeTQueue serverIOActions $ do + forM_ serverServices $ \service@(SomeService _ attrs) -> + runPeerServiceOn (Just ( service, attrs )) [] peer serviceNewPeer + + +receivedFromCustomAddress :: PeerAddressType addr => Server -> addr -> ByteString -> IO () +receivedFromCustomAddress Server {..} addr msg = do + writeFlowIO serverRawPath ( CustomPeerAddress addr, msg ) + +mkPeer :: Server -> PeerAddress -> IO Peer +mkPeer peerServer_ peerAddress = do + peerState <- newTVarIO (PeerInit []) + peerIdentityVar <- newTVarIO . PeerIdentityUnknown =<< newTVarIO [] + peerStorage_ <- deriveEphemeralStorage $ serverStorage peerServer_ + peerInStorage <- derivePartialStorage peerStorage_ + peerServiceState <- newTMVarIO M.empty + peerWaitingRefs <- newTMVarIO [] + return Peer {..} + +serverPeer :: Server -> SockAddr -> IO Peer +serverPeer server paddr = do + let paddr' = case IP.fromSockAddr paddr of + Just (IP.IPv4 ipv4, port) + -> IP.toSockAddr (IP.IPv6 $ IP.toIPv6w (0, 0, 0xffff, IP.fromIPv4w ipv4), port) + _ -> paddr + serverPeer' server (DatagramAddress paddr') + +serverPeerCustom :: PeerAddressType addr => Server -> addr -> IO Peer +serverPeerCustom server addr = serverPeer' server (CustomPeerAddress addr) + +#ifdef ENABLE_ICE_SUPPORT +serverPeerIce :: Server -> IceSession -> IO Peer +serverPeerIce server@Server {..} ice = do + let paddr = PeerIceSession ice + peer <- serverPeer' server paddr + iceSetChan ice $ mapFlow undefined (paddr,) serverRawPath + return peer +#endif + +serverPeer' :: Server -> PeerAddress -> IO Peer +serverPeer' server paddr = do + (peer, hello) <- modifyMVar (serverPeers server) $ \pvalue -> do + case M.lookup paddr pvalue of + Just peer -> return (pvalue, (peer, False)) + Nothing -> do + peer <- mkPeer server paddr + return (M.insert paddr peer pvalue, (peer, True)) + when hello $ atomically $ do + writeFlow (serverControlFlow server) (RequestConnection paddr) + return peer + +dropPeer :: MonadIO m => Peer -> m () +dropPeer peer = liftIO $ do + modifyMVar_ (serverPeers $ peerServer peer) $ \pvalue -> do + atomically $ do + readTVar (peerState peer) >>= \case + PeerConnected conn -> connClose conn + _ -> return() + writeTVar (peerState peer) PeerDropped + return $ M.delete (peerAddress peer) pvalue + +isPeerDropped :: MonadIO m => Peer -> m Bool +isPeerDropped peer = liftIO $ atomically $ readTVar (peerState peer) >>= \case + PeerDropped -> return True + _ -> return False + +sendToPeer :: (Service s, MonadIO m) => Peer -> s -> m () +sendToPeer peer = sendManyToPeer peer . (: []) + +sendManyToPeer :: (Service s, MonadIO m) => Peer -> [ s ] -> m () +sendManyToPeer peer = sendToPeerList peer . map (\part -> ServiceReply (Left part) True) + +sendToPeerStored :: (Service s, MonadIO m) => Peer -> Stored s -> m () +sendToPeerStored peer = sendManyToPeerStored peer . (: []) + +sendManyToPeerStored :: (Service s, MonadIO m) => Peer -> [ Stored s ] -> m () +sendManyToPeerStored peer = sendToPeerList peer . map (\part -> ServiceReply (Right part) True) + +sendToPeerList :: (Service s, MonadIO m) => Peer -> [ ServiceReply s ] -> m () +sendToPeerList peer parts = do + let st = peerStorage peer + res <- runExceptT $ do + srefs <- liftIO $ fmap catMaybes $ forM parts $ \case + ServiceReply (Left x) use -> Just . (,use) <$> store st x + ServiceReply (Right sx) use -> return $ Just (storedRef sx, use) + _ -> return Nothing + + streamHeaders <- concat <$> do + (liftEither =<<) $ liftIO $ atomically $ runExceptT $ do + forM parts $ \case + ServiceOpenStream cb -> do + conn <- lift (readTVar (peerState peer)) >>= \case + PeerConnected conn -> return conn + _ -> throwError "can't open stream without established connection" + (hdr, writer, handler) <- liftEither =<< lift (connAddWriteStream conn) + + lift $ writeTQueue (serverIOActions (peerServer peer)) $ do + liftIO $ forkServerThread (peerServer peer) handler + return [ ( hdr, cb writer ) ] + _ -> return [] + liftIO $ sequence_ $ map snd streamHeaders + + liftIO $ forM_ parts $ \case + ServiceFinally act -> act + _ -> return () + + let dgsts = map (refDigest . fst) srefs + let content = map fst $ filter (\(ref, use) -> use && BL.length (lazyLoadBytes ref) < 500) srefs -- TODO: MTU + header = TransportHeader $ concat + [ [ ServiceType (serviceID $ head parts) ] + , map ServiceRef dgsts + , map fst streamHeaders + ] + packet = TransportPacket header content + ackedBy = concat [[ Acknowledged r, Rejected r, DataRequest r ] | r <- dgsts ] + liftIO $ atomically $ sendToPeerS peer ackedBy packet + + case res of + Right () -> return () + Left err -> liftIO $ atomically $ writeTQueue (serverErrorLog $ peerServer peer) $ + "failed to send packet to " <> show (peerAddress peer) <> ": " <> err + +sendToPeerS' :: SecurityRequirement -> Peer -> [TransportHeaderItem] -> TransportPacket Ref -> STM () +sendToPeerS' secure Peer {..} ackedBy packet = do + readTVar peerState >>= \case + PeerInit xs -> writeTVar peerState $ PeerInit $ (secure, packet, ackedBy) : xs + PeerConnected conn -> writeFlow (connData conn) (secure, packet, ackedBy) + PeerDropped -> return () + +sendToPeerS :: Peer -> [TransportHeaderItem] -> TransportPacket Ref -> STM () +sendToPeerS = sendToPeerS' EncryptedOnly + +sendToPeerPlain :: Peer -> [TransportHeaderItem] -> TransportPacket Ref -> STM () +sendToPeerPlain = sendToPeerS' PlaintextAllowed + +sendToPeerWith :: forall s m e. (Service s, MonadIO m, MonadError e m, FromErebosError e) => Peer -> (ServiceState s -> ExceptT ErebosError IO (Maybe s, ServiceState s)) -> m () +sendToPeerWith peer fobj = do + let sproxy = Proxy @s + sid = serviceID sproxy + res <- liftIO $ do + svcs <- atomically $ takeTMVar (peerServiceState peer) + (svcs', res) <- runExceptT (fobj $ fromMaybe (emptyServiceState sproxy) $ fromServiceState sproxy =<< M.lookup sid svcs) >>= \case + Right (obj, s') -> return $ (M.insert sid (SomeServiceState sproxy s') svcs, Right obj) + Left err -> return $ (svcs, Left err) + atomically $ putTMVar (peerServiceState peer) svcs' + return res + + case res of + Right (Just obj) -> sendToPeer peer obj + Right Nothing -> return () + Left err -> throwError $ fromErebosError err + + +lookupService :: forall s proxy. Service s => proxy s -> [SomeService] -> Maybe (SomeService, ServiceAttributes s) +lookupService proxy (service@(SomeService (_ :: Proxy t) attr) : rest) + | Just (Refl :: s :~: t) <- eqT = Just (service, attr) + | otherwise = lookupService proxy rest +lookupService _ [] = Nothing + +runPeerService :: forall s m. (Service s, MonadIO m) => Peer -> ServiceHandler s () -> m () +runPeerService = runPeerServiceOn Nothing [] + +runPeerServiceOn :: forall s m. (Service s, MonadIO m) => Maybe ( SomeService, ServiceAttributes s ) -> [ RawStreamReader ] -> Peer -> ServiceHandler s () -> m () +runPeerServiceOn mbservice newStreams peer handler = liftIO $ do + let server = peerServer peer + proxy = Proxy @s + svc = serviceID proxy + logd = writeTQueue (serverErrorLog server) + case mbservice `mplus` lookupService proxy (serverServices server) of + Just (service, attr) -> + atomically (readTVar (peerIdentityVar peer)) >>= \case + PeerIdentityFull peerId -> do + (global, svcs) <- atomically $ (,) + <$> takeTMVar (serverServiceStates server) + <*> takeTMVar (peerServiceState peer) + case (fromMaybe (someServiceEmptyState service) $ M.lookup svc svcs, + fromMaybe (someServiceEmptyGlobalState service) $ M.lookup svc global) of + ((SomeServiceState (_ :: Proxy ps) ps), + (SomeServiceGlobalState (_ :: Proxy gs) gs)) -> do + Just (Refl :: s :~: ps) <- return $ eqT + Just (Refl :: s :~: gs) <- return $ eqT + + let inp = ServiceInput + { svcAttributes = attr + , svcPeer = peer + , svcPeerIdentity = peerId + , svcServer = server + , svcPrintOp = atomically . logd + , svcNewStreams = newStreams + } + reloadHead (serverOrigHead server) >>= \case + Nothing -> atomically $ do + logd $ "current head deleted" + putTMVar (peerServiceState peer) svcs + putTMVar (serverServiceStates server) global + Just h -> do + (rsp, (s', gs')) <- runServiceHandler h inp ps gs handler + moveKeys (peerStorage peer) (serverStorage server) + when (not (null rsp)) $ do + sendToPeerList peer rsp + atomically $ do + putTMVar (peerServiceState peer) $ M.insert svc (SomeServiceState proxy s') svcs + putTMVar (serverServiceStates server) $ M.insert svc (SomeServiceGlobalState proxy gs') global + _ -> do + atomically $ logd $ "can't run service handler on peer with incomplete identity " ++ show (peerAddress peer) + + _ -> atomically $ do + logd $ "unhandled service '" ++ show (toUUID svc) ++ "'" + +modifyServiceGlobalState + :: forall s a m e proxy. (Service s, MonadIO m, MonadError e m, FromErebosError e) + => Server -> proxy s + -> (ServiceGlobalState s -> ( ServiceGlobalState s, a )) + -> m a +modifyServiceGlobalState server proxy f = do + let svc = serviceID proxy + case lookupService proxy (serverServices server) of + Just ( service, _ ) -> do + liftIO $ atomically $ do + global <- takeTMVar (serverServiceStates server) + ( global', res ) <- case fromMaybe (someServiceEmptyGlobalState service) $ M.lookup svc global of + SomeServiceGlobalState (_ :: Proxy gs) gs -> do + (Refl :: s :~: gs) <- return $ fromMaybe (error "service ID mismatch in global map") eqT + let ( gs', res ) = f gs + return ( M.insert svc (SomeServiceGlobalState (Proxy @s) gs') global, res ) + putTMVar (serverServiceStates server) global' + return res + Nothing -> do + throwOtherError $ "unhandled service '" ++ show (toUUID svc) ++ "'" + + +foreign import ccall unsafe "Network/ifaddrs.h join_multicast" cJoinMulticast :: CInt -> Ptr CSize -> IO (Ptr Word32) +foreign import ccall unsafe "Network/ifaddrs.h local_addresses" cLocalAddresses :: Ptr CSize -> IO (Ptr InetAddress) +foreign import ccall unsafe "Network/ifaddrs.h broadcast_addresses" cBroadcastAddresses :: IO (Ptr Word32) +foreign import ccall unsafe "stdlib.h free" cFree :: Ptr a -> IO () + +data InetAddress = InetAddress { fromInetAddress :: IP.IP } + +instance F.Storable InetAddress where + sizeOf _ = sizeOf (undefined :: CInt) + 16 + alignment _ = 8 + + peek ptr = (unpackFamily <$> peekByteOff ptr 0) >>= \case + AF_INET -> InetAddress . IP.IPv4 . IP.fromHostAddress <$> peekByteOff ptr (sizeOf (undefined :: CInt)) + AF_INET6 -> InetAddress . IP.IPv6 . IP.toIPv6b . map fromIntegral <$> peekArray 16 (ptr `plusPtr` sizeOf (undefined :: CInt) :: Ptr Word8) + _ -> fail "InetAddress: unknown family" + + poke ptr (InetAddress addr) = case addr of + IP.IPv4 ip -> do + pokeByteOff ptr 0 (packFamily AF_INET) + pokeByteOff ptr (sizeOf (undefined :: CInt)) (IP.toHostAddress ip) + IP.IPv6 ip -> do + pokeByteOff ptr 0 (packFamily AF_INET6) + pokeArray (ptr `plusPtr` sizeOf (undefined :: CInt) :: Ptr Word8) (map fromIntegral $ IP.fromIPv6b ip) + +joinMulticast :: Socket -> IO [ Word32 ] +joinMulticast sock = + withFdSocket sock $ \fd -> + alloca $ \pcount -> do + ptr <- cJoinMulticast fd pcount + if ptr == nullPtr + then do + return [] + else do + count <- fromIntegral <$> peek pcount + res <- forM [ 0 .. count - 1 ] $ \i -> + peekElemOff ptr i + cFree ptr + return res + +getServerAddresses :: Server -> IO [ SockAddr ] +getServerAddresses Server {..} = do + alloca $ \pcount -> do + ptr <- cLocalAddresses pcount + if ptr == nullPtr + then do + return [] + else do + count <- fromIntegral <$> peek pcount + res <- peekArray count ptr + cFree ptr + return $ map (IP.toSockAddr . (, serverPort serverOptions ) . fromInetAddress) res + +getBroadcastAddresses :: PortNumber -> IO [SockAddr] +getBroadcastAddresses port = do + ptr <- cBroadcastAddresses + let parse i = do + w <- peekElemOff ptr i + if w == 0 then return [] + else (SockAddrInet port w:) <$> parse (i + 1) + if ptr == nullPtr + then return [] + else do + addrs <- parse 0 + cFree ptr + return addrs diff --git a/src/Erebos/Network.hs-boot b/src/Erebos/Network.hs-boot new file mode 100644 index 0000000..af77581 --- /dev/null +++ b/src/Erebos/Network.hs-boot @@ -0,0 +1,8 @@ +module Erebos.Network where + +import Erebos.Object.Internal + +data Server +data Peer + +peerStorage :: Peer -> Storage diff --git a/src/Erebos/Network/Channel.hs b/src/Erebos/Network/Channel.hs new file mode 100644 index 0000000..d9679bd --- /dev/null +++ b/src/Erebos/Network/Channel.hs @@ -0,0 +1,175 @@ +module Erebos.Network.Channel ( + Channel, + ChannelRequest, ChannelRequestData(..), + ChannelAccept, ChannelAcceptData(..), + + createChannelRequest, + acceptChannelRequest, + acceptedChannel, + + channelEncrypt, + channelDecrypt, +) where + +import Control.Concurrent.MVar +import Control.Monad +import Control.Monad.Except +import Control.Monad.IO.Class + +import Crypto.Cipher.ChaChaPoly1305 +import Crypto.Error + +import Data.Binary +import Data.ByteArray (ByteArray, Bytes, ScrubbedBytes, convert) +import Data.ByteArray qualified as BA +import Data.ByteString.Lazy qualified as BL +import Data.List + +import Erebos.Identity +import Erebos.PubKey +import Erebos.Storable + +data Channel = Channel + { chPeers :: [Stored (Signed IdentityData)] + , chKey :: ScrubbedBytes + , chNonceFixedOur :: Bytes + , chNonceFixedPeer :: Bytes + , chCounterNextOut :: MVar Word64 + , chCounterNextIn :: MVar Word64 + } + +type ChannelRequest = Signed ChannelRequestData + +data ChannelRequestData = ChannelRequest + { crPeers :: [Stored (Signed IdentityData)] + , crKey :: Stored PublicKexKey + } + deriving (Show) + +type ChannelAccept = Signed ChannelAcceptData + +data ChannelAcceptData = ChannelAccept + { caRequest :: Stored ChannelRequest + , caKey :: Stored PublicKexKey + } + + +instance Storable ChannelRequestData where + store' cr = storeRec $ do + mapM_ (storeRef "peer") $ crPeers cr + storeRef "key" $ crKey cr + + load' = loadRec $ do + ChannelRequest + <$> loadRefs "peer" + <*> loadRef "key" + +instance Storable ChannelAcceptData where + store' ca = storeRec $ do + storeRef "req" $ caRequest ca + storeRef "key" $ caKey ca + + load' = loadRec $ do + ChannelAccept + <$> loadRef "req" + <*> loadRef "key" + + +keySize :: Int +keySize = 32 + +createChannelRequest :: (MonadStorage m, MonadIO m, MonadError e m, FromErebosError e) => UnifiedIdentity -> UnifiedIdentity -> m (Stored ChannelRequest) +createChannelRequest self peer = do + (_, xpublic) <- liftIO . generateKeys =<< getStorage + skey <- loadKey $ idKeyMessage self + mstore =<< sign skey =<< mstore ChannelRequest { crPeers = sort [idData self, idData peer], crKey = xpublic } + +acceptChannelRequest :: (MonadStorage m, MonadIO m, MonadError e m, FromErebosError e) => UnifiedIdentity -> UnifiedIdentity -> Stored ChannelRequest -> m (Stored ChannelAccept, Channel) +acceptChannelRequest self peer req = do + case sequence $ map validateIdentity $ crPeers $ fromStored $ signedData $ fromStored req of + Nothing -> throwOtherError $ "invalid peers in channel request" + Just peers -> do + when (not $ any (self `sameIdentity`) peers) $ + throwOtherError $ "self identity missing in channel request peers" + when (not $ any (peer `sameIdentity`) peers) $ + throwOtherError $ "peer identity missing in channel request peers" + when (idKeyMessage peer `notElem` (map (sigKey . fromStored) $ signedSignature $ fromStored req)) $ + throwOtherError $ "channel requent not signed by peer" + + (xsecret, xpublic) <- liftIO . generateKeys =<< getStorage + skey <- loadKey $ idKeyMessage self + acc <- mstore =<< sign skey =<< mstore ChannelAccept { caRequest = req, caKey = xpublic } + liftIO $ do + let chPeers = crPeers $ fromStored $ signedData $ fromStored req + chKey = BA.take keySize $ dhSecret xsecret $ + fromStored $ crKey $ fromStored $ signedData $ fromStored req + chNonceFixedOur = BA.pack [ 2, 0, 0, 0 ] + chNonceFixedPeer = BA.pack [ 1, 0, 0, 0 ] + chCounterNextOut <- newMVar 0 + chCounterNextIn <- newMVar 0 + + return (acc, Channel {..}) + +acceptedChannel :: (MonadIO m, MonadError e m, FromErebosError e) => UnifiedIdentity -> UnifiedIdentity -> Stored ChannelAccept -> m Channel +acceptedChannel self peer acc = do + let req = caRequest $ fromStored $ signedData $ fromStored acc + case sequence $ map validateIdentity $ crPeers $ fromStored $ signedData $ fromStored req of + Nothing -> throwOtherError $ "invalid peers in channel accept" + Just peers -> do + when (not $ any (self `sameIdentity`) peers) $ + throwOtherError $ "self identity missing in channel accept peers" + when (not $ any (peer `sameIdentity`) peers) $ + throwOtherError $ "peer identity missing in channel accept peers" + when (idKeyMessage peer `notElem` (map (sigKey . fromStored) $ signedSignature $ fromStored acc)) $ + throwOtherError $ "channel accept not signed by peer" + when (idKeyMessage self `notElem` (map (sigKey . fromStored) $ signedSignature $ fromStored req)) $ + throwOtherError $ "original channel request not signed by us" + + xsecret <- loadKey $ crKey $ fromStored $ signedData $ fromStored req + let chPeers = crPeers $ fromStored $ signedData $ fromStored req + chKey = BA.take keySize $ dhSecret xsecret $ + fromStored $ caKey $ fromStored $ signedData $ fromStored acc + chNonceFixedOur = BA.pack [ 1, 0, 0, 0 ] + chNonceFixedPeer = BA.pack [ 2, 0, 0, 0 ] + chCounterNextOut <- liftIO $ newMVar 0 + chCounterNextIn <- liftIO $ newMVar 0 + + return Channel {..} + + +channelEncrypt :: (ByteArray ba, MonadIO m, MonadError e m, FromErebosError e) => Channel -> ba -> m (ba, Word64) +channelEncrypt Channel {..} plain = do + count <- liftIO $ modifyMVar chCounterNextOut $ \c -> return (c + 1, c) + let cbytes = convert $ BL.toStrict $ encode count + nonce = nonce8 chNonceFixedOur cbytes + state <- case initialize chKey =<< nonce of + CryptoPassed state -> return state + CryptoFailed err -> throwOtherError $ "failed to init chacha-poly1305 cipher: " <> show err + + let (ctext, state') = encrypt plain state + tag = finalize state' + return (BA.concat [ convert $ BA.drop 7 cbytes, ctext, convert tag ], count) + +channelDecrypt :: (ByteArray ba, MonadIO m, MonadError e m, FromErebosError e) => Channel -> ba -> m (ba, Word64) +channelDecrypt Channel {..} body = do + when (BA.length body < 17) $ do + throwOtherError $ "invalid encrypted data length" + + expectedCount <- liftIO $ readMVar chCounterNextIn + let countByte = body `BA.index` 0 + body' = BA.dropView body 1 + guessedCount = expectedCount - 128 + fromIntegral (countByte - fromIntegral expectedCount + 128 :: Word8) + nonce = nonce8 chNonceFixedPeer $ convert $ BL.toStrict $ encode guessedCount + blen = BA.length body' - 16 + ctext = BA.takeView body' blen + tag = BA.dropView body' blen + state <- case initialize chKey =<< nonce of + CryptoPassed state -> return state + CryptoFailed err -> throwOtherError $ "failed to init chacha-poly1305 cipher: " <> show err + + let (plain, state') = decrypt (convert ctext) state + when (not $ tag `BA.constEq` finalize state') $ do + throwOtherError $ "tag validation falied" + + liftIO $ modifyMVar_ chCounterNextIn $ return . max (guessedCount + 1) + return (plain, guessedCount) diff --git a/src/Erebos/Network/Protocol.hs b/src/Erebos/Network/Protocol.hs new file mode 100644 index 0000000..025f52c --- /dev/null +++ b/src/Erebos/Network/Protocol.hs @@ -0,0 +1,1006 @@ +module Erebos.Network.Protocol ( + TransportPacket(..), + transportToObject, + TransportHeader(..), + TransportHeaderItem(..), + ServiceID(..), + SecurityRequirement(..), + + WaitingRef(..), + WaitingRefCallback, + wrDigest, + + ChannelState(..), + + ControlRequest(..), + ControlMessage(..), + erebosNetworkProtocol, + + Connection, + connAddress, + connData, + connGetChannel, + connSetChannel, + connClose, + + RawStreamReader(..), RawStreamWriter(..), + StreamPacket(..), + connAddWriteStream, + connAddReadStream, + readStreamToList, + readObjectsFromStream, + writeByteStringToStream, + + module Erebos.Flow, +) where + +import Control.Applicative +import Control.Concurrent +import Control.Concurrent.Async +import Control.Concurrent.STM +import Control.Exception +import Control.Monad +import Control.Monad.Except +import Control.Monad.Trans + +import Crypto.Cipher.ChaChaPoly1305 qualified as C +import Crypto.MAC.Poly1305 qualified as C (Auth(..), authTag) +import Crypto.Error +import Crypto.Random + +import Data.Binary +import Data.Binary.Get +import Data.Binary.Put +import Data.Bits +import Data.ByteArray (Bytes, ScrubbedBytes) +import Data.ByteArray qualified as BA +import Data.ByteString (ByteString) +import Data.ByteString qualified as B +import Data.ByteString.Char8 qualified as BC +import Data.ByteString.Lazy qualified as BL +import Data.Function +import Data.List +import Data.Maybe +import Data.Text (Text) +import Data.Text qualified as T +import Data.Void + +import System.Clock + +import Erebos.Flow +import Erebos.Identity +import Erebos.Network.Channel +import Erebos.Object +import Erebos.Storable +import Erebos.Storage +import Erebos.UUID (UUID) + + +protocolVersion :: Text +protocolVersion = T.pack "0.1" + +protocolVersions :: [Text] +protocolVersions = [protocolVersion] + +keepAliveInternal :: TimeSpec +keepAliveInternal = fromNanoSecs $ 30 * 10^(9 :: Int) + + +data TransportPacket a = TransportPacket TransportHeader [a] + +data TransportHeader = TransportHeader [TransportHeaderItem] + deriving (Show) + +data TransportHeaderItem + = Acknowledged RefDigest + | AcknowledgedSingle Integer + | Rejected RefDigest + | ProtocolVersion Text + | Initiation RefDigest + | CookieSet Cookie + | CookieEcho Cookie + | DataRequest RefDigest + | DataResponse RefDigest + | AnnounceSelf RefDigest + | AnnounceUpdate RefDigest + | TrChannelRequest RefDigest + | TrChannelAccept RefDigest + | ServiceType ServiceID + | ServiceRef RefDigest + | StreamOpen Word8 + deriving (Eq, Show) + +newtype ServiceID = ServiceID UUID + deriving (Eq, Ord, Show, StorableUUID) + +newtype Cookie = Cookie ByteString + deriving (Eq, Show) + +data SecurityRequirement = PlaintextOnly + | PlaintextAllowed + | EncryptedOnly + deriving (Eq, Ord) + +data ParsedCookie = ParsedCookie + { cookieNonce :: C.Nonce + , cookieValidity :: Word32 + , cookieContent :: ByteString + , cookieMac :: C.Auth + } + +instance Eq ParsedCookie where + (==) = (==) `on` (\c -> ( BA.convert (cookieNonce c) :: ByteString, cookieValidity c, cookieContent c, cookieMac c )) + +instance Show ParsedCookie where + show ParsedCookie {..} = show (nonce, cookieValidity, cookieContent, mac) + where C.Auth mac = cookieMac + nonce = BA.convert cookieNonce :: ByteString + +instance Binary ParsedCookie where + put ParsedCookie {..} = do + putByteString $ BA.convert cookieNonce + putWord32be cookieValidity + putByteString $ BA.convert cookieMac + putByteString cookieContent + + get = do + Just cookieNonce <- maybeCryptoError . C.nonce12 <$> getByteString 12 + cookieValidity <- getWord32be + Just cookieMac <- maybeCryptoError . C.authTag <$> getByteString 16 + cookieContent <- BL.toStrict <$> getRemainingLazyByteString + return ParsedCookie {..} + +isHeaderItemAcknowledged :: TransportHeaderItem -> Bool +isHeaderItemAcknowledged = \case + Acknowledged {} -> False + AcknowledgedSingle {} -> False + Rejected {} -> False + ProtocolVersion {} -> False + Initiation {} -> False + CookieSet {} -> False + CookieEcho {} -> False + _ -> True + +transportToObject :: PartialStorage -> TransportHeader -> PartialObject +transportToObject st (TransportHeader items) = Rec $ map single items + where single = \case + Acknowledged dgst -> (BC.pack "ACK", RecRef $ partialRefFromDigest st dgst) + AcknowledgedSingle num -> (BC.pack "ACK", RecInt num) + Rejected dgst -> (BC.pack "REJ", RecRef $ partialRefFromDigest st dgst) + ProtocolVersion ver -> (BC.pack "VER", RecText ver) + Initiation dgst -> (BC.pack "INI", RecRef $ partialRefFromDigest st dgst) + CookieSet (Cookie bytes) -> (BC.pack "CKS", RecBinary bytes) + CookieEcho (Cookie bytes) -> (BC.pack "CKE", RecBinary bytes) + DataRequest dgst -> (BC.pack "REQ", RecRef $ partialRefFromDigest st dgst) + DataResponse dgst -> (BC.pack "RSP", RecRef $ partialRefFromDigest st dgst) + AnnounceSelf dgst -> (BC.pack "ANN", RecRef $ partialRefFromDigest st dgst) + AnnounceUpdate dgst -> (BC.pack "ANU", RecRef $ partialRefFromDigest st dgst) + TrChannelRequest dgst -> (BC.pack "CRQ", RecRef $ partialRefFromDigest st dgst) + TrChannelAccept dgst -> (BC.pack "CAC", RecRef $ partialRefFromDigest st dgst) + ServiceType stype -> (BC.pack "SVT", RecUUID $ toUUID stype) + ServiceRef dgst -> (BC.pack "SVR", RecRef $ partialRefFromDigest st dgst) + StreamOpen num -> (BC.pack "STO", RecInt $ fromIntegral num) + +transportFromObject :: PartialObject -> Maybe TransportHeader +transportFromObject (Rec items) = case catMaybes $ map single items of + [] -> Nothing + titems -> Just $ TransportHeader titems + where single (name, content) = if + | name == BC.pack "ACK", RecRef ref <- content -> Just $ Acknowledged $ refDigest ref + | name == BC.pack "ACK", RecInt num <- content -> Just $ AcknowledgedSingle num + | name == BC.pack "REJ", RecRef ref <- content -> Just $ Rejected $ refDigest ref + | name == BC.pack "VER", RecText ver <- content -> Just $ ProtocolVersion ver + | name == BC.pack "INI", RecRef ref <- content -> Just $ Initiation $ refDigest ref + | name == BC.pack "CKS", RecBinary bytes <- content -> Just $ CookieSet (Cookie bytes) + | name == BC.pack "CKE", RecBinary bytes <- content -> Just $ CookieEcho (Cookie bytes) + | name == BC.pack "REQ", RecRef ref <- content -> Just $ DataRequest $ refDigest ref + | name == BC.pack "RSP", RecRef ref <- content -> Just $ DataResponse $ refDigest ref + | name == BC.pack "ANN", RecRef ref <- content -> Just $ AnnounceSelf $ refDigest ref + | name == BC.pack "ANU", RecRef ref <- content -> Just $ AnnounceUpdate $ refDigest ref + | name == BC.pack "CRQ", RecRef ref <- content -> Just $ TrChannelRequest $ refDigest ref + | name == BC.pack "CAC", RecRef ref <- content -> Just $ TrChannelAccept $ refDigest ref + | name == BC.pack "SVT", RecUUID uuid <- content -> Just $ ServiceType $ fromUUID uuid + | name == BC.pack "SVR", RecRef ref <- content -> Just $ ServiceRef $ refDigest ref + | name == BC.pack "STO", RecInt num <- content -> Just $ StreamOpen $ fromIntegral num + | otherwise -> Nothing +transportFromObject _ = Nothing + + +data GlobalState addr = (Eq addr, Show addr) => GlobalState + { gIdentity :: TVar (UnifiedIdentity, [UnifiedIdentity]) + , gConnections :: TVar [Connection addr] + , gDataFlow :: SymFlow (addr, ByteString) + , gControlFlow :: Flow (ControlRequest addr) (ControlMessage addr) + , gNextUp :: TMVar (Connection addr, (Bool, TransportPacket PartialObject)) + , gLog :: String -> STM () + , gStorage :: PartialStorage + , gStartTime :: TimeSpec + , gNowVar :: TVar TimeSpec + , gNextTimeout :: TVar TimeSpec + , gInitConfig :: Ref + , gCookieKey :: ScrubbedBytes + , gCookieStartTime :: Word32 + } + +data Connection addr = Connection + { cGlobalState :: GlobalState addr + , cAddress :: addr + , cDataUp :: Flow + (Maybe (Bool, TransportPacket PartialObject)) + (SecurityRequirement, TransportPacket Ref, [TransportHeaderItem]) + , cDataInternal :: Flow + (SecurityRequirement, TransportPacket Ref, [TransportHeaderItem]) + (Maybe (Bool, TransportPacket PartialObject)) + , cChannel :: TVar ChannelState + , cCookie :: TVar (Maybe Cookie) + , cSecureOutQueue :: TQueue (SecurityRequirement, TransportPacket Ref, [TransportHeaderItem]) + , cMaxInFlightPackets :: TVar Int + , cReservedPackets :: TVar Int + , cSentPackets :: TVar [SentPacket] + , cToAcknowledge :: TVar [Integer] + , cNextKeepAlive :: TVar (Maybe TimeSpec) + , cInStreams :: TVar [(Word8, Stream)] + , cOutStreams :: TVar [(Word8, Stream)] + } + +instance Eq (Connection addr) where + (==) = (==) `on` cChannel + +connAddress :: Connection addr -> addr +connAddress = cAddress + +connData :: Connection addr -> Flow + (Maybe (Bool, TransportPacket PartialObject)) + (SecurityRequirement, TransportPacket Ref, [TransportHeaderItem]) +connData = cDataUp + +connGetChannel :: Connection addr -> STM ChannelState +connGetChannel Connection {..} = readTVar cChannel + +connSetChannel :: Connection addr -> ChannelState -> STM () +connSetChannel Connection {..} ch = do + writeTVar cChannel ch + +connClose :: Connection addr -> STM () +connClose conn@Connection {..} = do + let GlobalState {..} = cGlobalState + readTVar cChannel >>= \case + ChannelClosed -> return () + _ -> do + writeTVar cChannel ChannelClosed + writeTVar gConnections . filter (/=conn) =<< readTVar gConnections + writeFlow cDataInternal Nothing + +connAddWriteStream :: Connection addr -> STM (Either String (TransportHeaderItem, RawStreamWriter, IO ())) +connAddWriteStream conn@Connection {..} = do + outStreams <- readTVar cOutStreams + let doInsert :: Word8 -> [(Word8, Stream)] -> ExceptT String STM ((Word8, Stream), [(Word8, Stream)]) + doInsert n (s@(n', _) : rest) | n == n' = + fmap (s:) <$> doInsert (n + 1) rest + doInsert n streams | n < 63 = lift $ do + sState <- newTVar StreamOpening + (sFlowIn, sFlowOut) <- newFlow + sNextSequence <- newTVar 0 + sWaitingForAck <- newTVar 0 + let info = (n, Stream {..}) + return (info, info : streams) + doInsert _ _ = throwError "all outbound streams in use" + + runExceptT $ do + ((streamNumber, stream), outStreams') <- doInsert 1 outStreams + lift $ writeTVar cOutStreams outStreams' + return + ( StreamOpen streamNumber + , RawStreamWriter (fromIntegral streamNumber) (sFlowIn stream) + , go cGlobalState streamNumber stream + ) + + where + go gs@GlobalState {..} streamNumber stream = do + (reserved, msg) <- atomically $ do + readTVar (sState stream) >>= \case + StreamRunning -> return () + _ -> retry + (,) <$> reservePacket conn + <*> readFlow (sFlowOut stream) + + (plain, cont, onAck) <- case msg of + StreamData {..} -> do + return (stpData, True, return ()) + StreamClosed {} -> do + atomically $ do + -- wait for ack on all sent stream data + waits <- readTVar (sWaitingForAck stream) + when (waits > 0) retry + return (BC.empty, False, streamClosed conn streamNumber) + + let secure = True + plainAckedBy = [] + mbReserved = Just reserved + + mbch <- atomically (readTVar cChannel) >>= return . \case + ChannelEstablished ch -> Just ch + ChannelOurAccept _ ch -> Just ch + _ -> Nothing + + mbs <- case mbch of + Just ch -> do + runExceptT (channelEncrypt ch $ B.concat + [ B.singleton streamNumber + , B.singleton (fromIntegral (stpSequence msg) :: Word8) + , plain + ] ) >>= \case + Right (ctext, counter) -> do + let isAcked = True + return $ Just (0x80 `B.cons` ctext, if isAcked then [ AcknowledgedSingle $ fromIntegral counter ] else []) + Left err -> do atomically $ gLog $ "Failed to encrypt data: " ++ showErebosError err + return Nothing + Nothing | secure -> return Nothing + | otherwise -> return $ Just (plain, plainAckedBy) + + case mbs of + Just (bs, ackedBy) -> do + atomically $ do + modifyTVar' (sWaitingForAck stream) (+ 1) + let mbReserved' = (\rs -> rs + { rsAckedBy = guard (not $ null ackedBy) >> Just (`elem` ackedBy) + , rsOnAck = do + rsOnAck rs + onAck + atomically $ modifyTVar' (sWaitingForAck stream) (subtract 1) + }) <$> mbReserved + sendBytes conn mbReserved' bs + Nothing -> return () + + when cont $ go gs streamNumber stream + +connAddReadStream :: Connection addr -> Word8 -> STM RawStreamReader +connAddReadStream Connection {..} streamNumber = do + inStreams <- readTVar cInStreams + let doInsert (s@(n, _) : rest) + | streamNumber < n = fmap (s:) <$> doInsert rest + | streamNumber == n = doInsert rest + doInsert streams = do + sState <- newTVar StreamRunning + (sFlowIn, sFlowOut) <- newFlow + sNextSequence <- newTVar 0 + sWaitingForAck <- newTVar 0 + let stream = Stream {..} + return ( streamNumber, stream, (streamNumber, stream) : streams ) + ( num, stream, inStreams' ) <- doInsert inStreams + writeTVar cInStreams inStreams' + return $ RawStreamReader (fromIntegral num) (sFlowOut stream) + + +data RawStreamReader = RawStreamReader + { rsrNum :: Int + , rsrFlow :: Flow StreamPacket Void + } + +data RawStreamWriter = RawStreamWriter + { rswNum :: Int + , rswFlow :: Flow Void StreamPacket + } + +data Stream = Stream + { sState :: TVar StreamState + , sFlowIn :: Flow Void StreamPacket + , sFlowOut :: Flow StreamPacket Void + , sNextSequence :: TVar Word64 + , sWaitingForAck :: TVar Word64 + } + +data StreamState = StreamOpening | StreamRunning + +data StreamPacket + = StreamData + { stpSequence :: Word64 + , stpData :: BC.ByteString + } + | StreamClosed + { stpSequence :: Word64 + } + +streamAccepted :: Connection addr -> Word8 -> IO () +streamAccepted Connection {..} snum = atomically $ do + (lookup snum <$> readTVar cOutStreams) >>= \case + Just Stream {..} -> do + modifyTVar' sState $ \case + StreamOpening -> StreamRunning + x -> x + Nothing -> return () + +streamClosed :: Connection addr -> Word8 -> IO () +streamClosed Connection {..} snum = atomically $ do + modifyTVar' cOutStreams $ filter ((snum /=) . fst) + +readStreamToList :: RawStreamReader -> IO (Word64, [(Word64, BC.ByteString)]) +readStreamToList stream = readFlowIO (rsrFlow stream) >>= \case + StreamData sq bytes -> fmap ((sq, bytes) :) <$> readStreamToList stream + StreamClosed sqEnd -> return (sqEnd, []) + +readObjectsFromStream :: PartialStorage -> RawStreamReader -> IO (Except ErebosError [PartialObject]) +readObjectsFromStream st stream = do + (seqEnd, list) <- readStreamToList stream + let validate s ((s', bytes) : rest) + | s == s' = (bytes : ) <$> validate (s + 1) rest + | s > s' = validate s rest + | otherwise = throwOtherError "missing object chunk" + validate s [] + | s == seqEnd = return [] + | otherwise = throwOtherError "content length mismatch" + return $ do + content <- BL.fromChunks <$> validate 0 list + deserializeObjects st content + +writeByteStringToStream :: RawStreamWriter -> BL.ByteString -> IO () +writeByteStringToStream stream = go 0 + where + go seqNum bstr + | BL.null bstr = writeFlowIO (rswFlow stream) $ StreamClosed seqNum + | otherwise = do + let (cur, rest) = BL.splitAt 500 bstr -- TODO: MTU + writeFlowIO (rswFlow stream) $ StreamData seqNum (BL.toStrict cur) + go (seqNum + 1) rest + + +data WaitingRef = WaitingRef + { wrefStorage :: Storage + , wrefPartial :: PartialRef + , wrefAction :: Ref -> WaitingRefCallback + , wrefStatus :: TVar (Either [RefDigest] Ref) + } + +type WaitingRefCallback = ExceptT ErebosError IO () + +wrDigest :: WaitingRef -> RefDigest +wrDigest = refDigest . wrefPartial + + +data ChannelState = ChannelNone + | ChannelCookieWait -- sent initiation, waiting for response + | ChannelCookieReceived -- received cookie, but no cookie echo yet + | ChannelCookieConfirmed -- received cookie echo, no need to send from our side + | ChannelOurRequest (Stored ChannelRequest) + | ChannelPeerRequest WaitingRef + | ChannelOurAccept (Stored ChannelAccept) Channel + | ChannelEstablished Channel + | ChannelClosed + +data ReservedToSend = ReservedToSend + { rsAckedBy :: Maybe (TransportHeaderItem -> Bool) + , rsOnAck :: IO () + , rsOnFail :: IO () + } + +data SentPacket = SentPacket + { spTime :: TimeSpec + , spRetryCount :: Int + , spAckedBy :: Maybe (TransportHeaderItem -> Bool) + , spOnAck :: IO () + , spOnFail :: IO () + , spData :: BC.ByteString + } + + +data ControlRequest addr = RequestConnection addr + | SendAnnounce addr + | UpdateSelfIdentity UnifiedIdentity + +data ControlMessage addr = NewConnection (Connection addr) (Maybe RefDigest) + | ReceivedAnnounce addr RefDigest + + +erebosNetworkProtocol :: (Eq addr, Ord addr, Show addr) + => UnifiedIdentity + -> (String -> STM ()) + -> SymFlow (addr, ByteString) + -> Flow (ControlRequest addr) (ControlMessage addr) + -> IO () +erebosNetworkProtocol initialIdentity gLog gDataFlow gControlFlow = do + gIdentity <- newTVarIO (initialIdentity, []) + gConnections <- newTVarIO [] + gNextUp <- newEmptyTMVarIO + mStorage <- memoryStorage + gStorage <- derivePartialStorage mStorage + + gStartTime <- getTime Monotonic + gNowVar <- newTVarIO gStartTime + gNextTimeout <- newTVarIO gStartTime + gInitConfig <- store mStorage $ (Rec [] :: Object) + + gCookieKey <- getRandomBytes 32 + gCookieStartTime <- runGet getWord32host . BL.pack . BA.unpack @ScrubbedBytes <$> getRandomBytes 4 + + let gs = GlobalState {..} + + let signalTimeouts = forever $ do + now <- getTime Monotonic + next <- atomically $ do + writeTVar gNowVar now + readTVar gNextTimeout + + let waitTill time + | time > now = threadDelay $ fromInteger (toNanoSecs (time - now)) `div` 1000 + | otherwise = threadDelay maxBound + waitForUpdate = atomically $ do + next' <- readTVar gNextTimeout + when (next' == next) retry + + race_ (waitTill next) waitForUpdate + + race_ signalTimeouts $ forever $ do + io <- atomically $ do + passUpIncoming gs <|> processIncoming gs <|> processOutgoing gs + catch io $ \(e :: SomeException) -> atomically $ gLog $ "exception during network protocol handling: " <> show e + + +getConnection :: GlobalState addr -> addr -> STM (Connection addr) +getConnection gs addr = do + maybe (newConnection gs addr) return =<< findConnection gs addr + +findConnection :: GlobalState addr -> addr -> STM (Maybe (Connection addr)) +findConnection GlobalState {..} addr = do + find ((addr==) . cAddress) <$> readTVar gConnections + +newConnection :: GlobalState addr -> addr -> STM (Connection addr) +newConnection cGlobalState@GlobalState {..} addr = do + conns <- readTVar gConnections + + let cAddress = addr + (cDataUp, cDataInternal) <- newFlow + cChannel <- newTVar ChannelNone + cCookie <- newTVar Nothing + cSecureOutQueue <- newTQueue + cMaxInFlightPackets <- newTVar 4 + cReservedPackets <- newTVar 0 + cSentPackets <- newTVar [] + cToAcknowledge <- newTVar [] + cNextKeepAlive <- newTVar Nothing + cInStreams <- newTVar [] + cOutStreams <- newTVar [] + let conn = Connection {..} + + writeTVar gConnections (conn : conns) + return conn + +passUpIncoming :: GlobalState addr -> STM (IO ()) +passUpIncoming GlobalState {..} = do + (Connection {..}, up) <- takeTMVar gNextUp + writeFlow cDataInternal (Just up) + return $ return () + +processIncoming :: GlobalState addr -> STM (IO ()) +processIncoming gs@GlobalState {..} = do + guard =<< isEmptyTMVar gNextUp + guard =<< canWriteFlow gControlFlow + + (addr, msg) <- readFlow gDataFlow + mbconn <- findConnection gs addr + + mbch <- case mbconn of + Nothing -> return Nothing + Just conn -> readTVar (cChannel conn) >>= return . \case + ChannelEstablished ch -> Just ch + ChannelOurAccept _ ch -> Just ch + _ -> Nothing + + return $ do + let deserialize = liftEither . runExcept . deserializeObjects gStorage . BL.fromStrict + let parse = case B.uncons msg of + Just (b, enc) + | b .&. 0xE0 == 0x80 -> do + ch <- maybe (throwOtherError "unexpected encrypted packet") return mbch + (dec, counter) <- channelDecrypt ch enc + + case B.uncons dec of + Just (0x00, content) -> do + objs <- deserialize content + return $ Left (True, objs, Just counter) + + Just (snum, dec') + | snum < 64 + , Just (seq8, content) <- B.uncons dec' + -> do + return $ Right (snum, seq8, content, counter) + + Just (_, _) -> do + throwOtherError "unexpected stream header" + + Nothing -> do + throwOtherError "empty decrypted content" + + | b .&. 0xE0 == 0x60 -> do + objs <- deserialize msg + return $ Left (False, objs, Nothing) + + | otherwise -> throwOtherError "invalid packet" + + Nothing -> throwOtherError "empty packet" + + now <- getTime Monotonic + runExceptT parse >>= \case + Right (Left (secure, objs, mbcounter)) + | hobj:content <- objs + , Just header@(TransportHeader items) <- transportFromObject hobj + -> processPacket gs (maybe (Left addr) Right mbconn) secure (TransportPacket header content) >>= \case + Just (conn@Connection {..}, mbup) -> do + ioAfter <- atomically $ do + case mbcounter of + Just counter | any isHeaderItemAcknowledged items -> + modifyTVar' cToAcknowledge (fromIntegral counter :) + _ -> return () + case mbup of + Just up -> putTMVar gNextUp (conn, (secure, up)) + Nothing -> return () + updateKeepAlive conn now + processAcknowledgements gs conn items + ioAfter + Nothing -> return () + + | otherwise -> atomically $ do + gLog $ show addr ++ ": invalid objects" + gLog $ show objs + + Right (Right (snum, seq8, content, counter)) + | Just conn@Connection {..} <- mbconn + -> atomically $ do + updateKeepAlive conn now + (lookup snum <$> readTVar cInStreams) >>= \case + Nothing -> + gLog $ "unexpected stream number " ++ show snum + + Just Stream {..} -> do + expectedSequence <- readTVar sNextSequence + let seqFull = expectedSequence - 0x80 + fromIntegral (seq8 - fromIntegral expectedSequence + 0x80 :: Word8) + sdata <- if + | B.null content -> do + modifyTVar' cInStreams $ filter ((/=snum) . fst) + return $ StreamClosed seqFull + | otherwise -> do + writeTVar sNextSequence $ max expectedSequence (seqFull + 1) + return $ StreamData seqFull content + writeFlow sFlowIn sdata + modifyTVar' cToAcknowledge (fromIntegral counter :) + + | otherwise -> do + atomically $ gLog $ show addr <> ": stream packet without connection" + + Left err -> do + atomically $ gLog $ show addr <> ": failed to parse packet: " <> showErebosError err + +processPacket :: GlobalState addr -> Either addr (Connection addr) -> Bool -> TransportPacket a -> IO (Maybe (Connection addr, Maybe (TransportPacket a))) +processPacket gs@GlobalState {..} econn secure packet@(TransportPacket (TransportHeader header) _) = if + -- Established secure communication + | Right conn <- econn, secure + -> return $ Just (conn, Just packet) + + -- Plaintext communication with cookies to prove origin + | cookie:_ <- mapMaybe (\case CookieEcho x -> Just x; _ -> Nothing) header + -> verifyCookie gs addr cookie >>= \case + True -> do + atomically $ do + conn@Connection {..} <- getConnection gs addr + oldCookie <- readTVar cCookie + let received = listToMaybe $ mapMaybe (\case CookieSet x -> Just x; _ -> Nothing) header + case received `mplus` oldCookie of + Just current -> do + writeTVar cCookie (Just current) + cookieEchoReceived gs conn mbpid + return $ Just (conn, Just packet) + Nothing -> do + gLog $ show addr <> ": missing cookie set, dropping " <> show header + return $ Nothing + + False -> do + atomically $ gLog $ show addr <> ": cookie verification failed, dropping " <> show header + return Nothing + + -- Response to initiation packet + | cookie:_ <- mapMaybe (\case CookieSet x -> Just x; _ -> Nothing) header + , Just _ <- version + , Right conn@Connection {..} <- econn + -> do + atomically $ readTVar cChannel >>= \case + ChannelCookieWait -> do + writeTVar cChannel $ ChannelCookieReceived + writeTVar cCookie $ Just cookie + writeFlow gControlFlow (NewConnection conn mbpid) + return $ Just (conn, Nothing) + _ -> return Nothing + + -- Initiation packet + | _:_ <- mapMaybe (\case Initiation x -> Just x; _ -> Nothing) header + , Just ver <- version + -> do + cookie <- createCookie gs addr + atomically $ do + identity <- fst <$> readTVar gIdentity + let reply = BL.toStrict $ serializeObject $ transportToObject gStorage $ TransportHeader + [ CookieSet cookie + , AnnounceSelf $ refDigest $ storedRef $ idData identity + , ProtocolVersion ver + ] + writeFlow gDataFlow (addr, reply) + return Nothing + + -- Announce packet outside any connection + | dgst:_ <- mapMaybe (\case AnnounceSelf x -> Just x; _ -> Nothing) header + , Just _ <- version + -> do + atomically $ do + (cur, past) <- readTVar gIdentity + when (not $ dgst `elem` map (refDigest . storedRef . idData) (cur : past)) $ do + writeFlow gControlFlow $ ReceivedAnnounce addr dgst + return Nothing + + | otherwise -> do + atomically $ gLog $ show addr <> ": dropping packet " <> show header + return Nothing + + where + addr = either id cAddress econn + mbpid = listToMaybe $ mapMaybe (\case AnnounceSelf dgst -> Just dgst; _ -> Nothing) header + version = listToMaybe $ filter (\v -> ProtocolVersion v `elem` header) protocolVersions + +cookieEchoReceived :: GlobalState addr -> Connection addr -> Maybe RefDigest -> STM () +cookieEchoReceived GlobalState {..} conn@Connection {..} mbpid = do + readTVar cChannel >>= \case + ChannelNone -> newConn + ChannelCookieWait -> newConn + ChannelCookieReceived {} -> update + _ -> return () + where + update = do + writeTVar cChannel ChannelCookieConfirmed + newConn = do + update + writeFlow gControlFlow (NewConnection conn mbpid) + +generateCookieHeaders :: Connection addr -> ChannelState -> IO [TransportHeaderItem] +generateCookieHeaders Connection {..} ch = catMaybes <$> sequence [ echoHeader, setHeader ] + where + echoHeader = fmap CookieEcho <$> atomically (readTVar cCookie) + setHeader = case ch of + ChannelCookieWait {} -> Just . CookieSet <$> createCookie cGlobalState cAddress + ChannelCookieReceived {} -> Just . CookieSet <$> createCookie cGlobalState cAddress + _ -> return Nothing + +createCookie :: GlobalState addr -> addr -> IO Cookie +createCookie GlobalState {..} addr = do + (nonceBytes :: Bytes) <- getRandomBytes 12 + validUntil <- (fromNanoSecs (60 * 10^(9 :: Int)) +) <$> getTime Monotonic + let validSecondsFromStart = fromIntegral $ toNanoSecs (validUntil - gStartTime) `div` (10^(9 :: Int)) + cookieValidity = validSecondsFromStart - gCookieStartTime + plainContent = BC.pack (show addr) + throwCryptoErrorIO $ do + cookieNonce <- C.nonce12 nonceBytes + st1 <- C.initialize gCookieKey cookieNonce + let st2 = C.finalizeAAD $ C.appendAAD (BL.toStrict $ runPut $ putWord32be cookieValidity) st1 + (cookieContent, st3) = C.encrypt plainContent st2 + cookieMac = C.finalize st3 + return $ Cookie $ BL.toStrict $ encode $ ParsedCookie {..} + +verifyCookie :: GlobalState addr -> addr -> Cookie -> IO Bool +verifyCookie GlobalState {..} addr (Cookie cookie) = do + ctime <- getTime Monotonic + return $ fromMaybe False $ do + ( _, _, ParsedCookie {..} ) <- either (const Nothing) Just $ decodeOrFail $ BL.fromStrict cookie + maybeCryptoError $ do + st1 <- C.initialize gCookieKey cookieNonce + let st2 = C.finalizeAAD $ C.appendAAD (BL.toStrict $ runPut $ putWord32be cookieValidity) st1 + (plainContent, st3) = C.decrypt cookieContent st2 + mac = C.finalize st3 + + validSecondsFromStart = fromIntegral $ cookieValidity + gCookieStartTime + validUntil = gStartTime + fromNanoSecs (validSecondsFromStart * (10^(9 :: Int))) + return $ and + [ mac == cookieMac + , ctime <= validUntil + , show addr == BC.unpack plainContent + ] + +reservePacket :: Connection addr -> STM ReservedToSend +reservePacket conn@Connection {..} = do + maxPackets <- readTVar cMaxInFlightPackets + reserved <- readTVar cReservedPackets + sent <- length <$> readTVar cSentPackets + + when (sent + reserved >= maxPackets) $ do + retry + + writeTVar cReservedPackets $ reserved + 1 + return $ ReservedToSend Nothing (return ()) (atomically $ connClose conn) + +resendBytes :: Connection addr -> Maybe ReservedToSend -> SentPacket -> IO () +resendBytes conn@Connection {..} reserved sp = do + let GlobalState {..} = cGlobalState + now <- getTime Monotonic + atomically $ do + when (isJust reserved) $ do + modifyTVar' cReservedPackets (subtract 1) + + when (isJust $ spAckedBy sp) $ do + modifyTVar' cSentPackets $ (:) sp + { spTime = now + , spRetryCount = spRetryCount sp + 1 + } + writeFlow gDataFlow (cAddress, spData sp) + updateKeepAlive conn now + +sendBytes :: Connection addr -> Maybe ReservedToSend -> ByteString -> IO () +sendBytes conn reserved bs = resendBytes conn reserved + SentPacket + { spTime = undefined + , spRetryCount = -1 + , spAckedBy = rsAckedBy =<< reserved + , spOnAck = maybe (return ()) rsOnAck reserved + , spOnFail = maybe (return ()) rsOnFail reserved + , spData = bs + } + +updateKeepAlive :: Connection addr -> TimeSpec -> STM () +updateKeepAlive Connection {..} now = do + let next = now + keepAliveInternal + writeTVar cNextKeepAlive $ Just next + + +processOutgoing :: forall addr. GlobalState addr -> STM (IO ()) +processOutgoing gs@GlobalState {..} = do + + let sendNextPacket :: Connection addr -> STM (IO ()) + sendNextPacket conn@Connection {..} = do + channel <- readTVar cChannel + let mbch = case channel of + ChannelEstablished ch -> Just ch + _ -> Nothing + + let checkOutstanding + | isJust mbch = do + (,) <$> readTQueue cSecureOutQueue <*> (Just <$> reservePacket conn) + | otherwise = retry + + checkDataInternal = do + (,) <$> readFlow cDataInternal <*> (Just <$> reservePacket conn) + + checkAcknowledgements + | isJust mbch = do + acks <- readTVar cToAcknowledge + if null acks then retry + else return ((EncryptedOnly, TransportPacket (TransportHeader []) [], []), Nothing) + | otherwise = retry + + ((secure, packet@(TransportPacket (TransportHeader hitems) content), plainAckedBy), mbReserved) <- + checkOutstanding <|> checkDataInternal <|> checkAcknowledgements + + when (isNothing mbch && secure >= EncryptedOnly) $ do + writeTQueue cSecureOutQueue (secure, packet, plainAckedBy) + + acknowledge <- case mbch of + Nothing -> return [] + Just _ -> swapTVar cToAcknowledge [] + + return $ do + let onAck = sequence_ $ map (streamAccepted conn) $ + catMaybes (map (\case StreamOpen n -> Just n; _ -> Nothing) hitems) + + let mkPlain extraHeaders + | combinedHeaderItems@(_:_) <- map AcknowledgedSingle acknowledge ++ extraHeaders ++ hitems = + BL.concat $ + (serializeObject $ transportToObject gStorage $ TransportHeader combinedHeaderItems) + : map lazyLoadBytes content + | otherwise = BL.empty + + let usePlaintext = do + plain <- mkPlain <$> generateCookieHeaders conn channel + return $ Just (BL.toStrict plain, plainAckedBy) + + let useEncryption ch = do + plain <- mkPlain <$> return [] + runExceptT (channelEncrypt ch $ BL.toStrict $ 0x00 `BL.cons` plain) >>= \case + Right (ctext, counter) -> do + let isAcked = any isHeaderItemAcknowledged hitems + return $ Just (0x80 `B.cons` ctext, if isAcked then [ AcknowledgedSingle $ fromIntegral counter ] else []) + Left err -> do atomically $ gLog $ "Failed to encrypt data: " ++ showErebosError err + return Nothing + + mbs <- case (secure, mbch) of + (PlaintextOnly, _) -> usePlaintext + (PlaintextAllowed, Nothing) -> usePlaintext + (_, Just ch) -> useEncryption ch + (EncryptedOnly, Nothing) -> return Nothing + + case mbs of + Just (bs, ackedBy) -> do + let mbReserved' = (\rs -> rs + { rsAckedBy = guard (not $ null ackedBy) >> Just (`elem` ackedBy) + , rsOnAck = rsOnAck rs >> onAck + }) <$> mbReserved + sendBytes conn mbReserved' bs + Nothing -> return () + + let waitUntil :: TimeSpec -> TimeSpec -> STM () + waitUntil now till = do + nextTimeout <- readTVar gNextTimeout + if nextTimeout <= now || till < nextTimeout + then writeTVar gNextTimeout till + else retry + + let retransmitPacket :: Connection addr -> STM (IO ()) + retransmitPacket conn@Connection {..} = do + now <- readTVar gNowVar + (sp, rest) <- readTVar cSentPackets >>= \case + sps@(_:_) -> return (last sps, init sps) + _ -> retry + let nextTry = spTime sp + fromNanoSecs 1000000000 + if | now < nextTry -> do + waitUntil now nextTry + return $ return () + | spRetryCount sp < 2 -> do + reserved <- reservePacket conn + writeTVar cSentPackets rest + return $ resendBytes conn (Just reserved) sp + | otherwise -> do + return $ spOnFail sp + + let handleControlRequests = readFlow gControlFlow >>= \case + RequestConnection addr -> do + conn@Connection {..} <- getConnection gs addr + identity <- fst <$> readTVar gIdentity + readTVar cChannel >>= \case + ChannelNone -> do + reserved <- reservePacket conn + let packet = BL.toStrict $ BL.concat + [ serializeObject $ transportToObject gStorage $ TransportHeader $ + [ Initiation $ refDigest gInitConfig + , AnnounceSelf $ refDigest $ storedRef $ idData identity + ] ++ map ProtocolVersion protocolVersions + , lazyLoadBytes gInitConfig + ] + writeTVar cChannel ChannelCookieWait + let reserved' = reserved { rsAckedBy = Just $ \case CookieSet {} -> True; _ -> False } + return $ sendBytes conn (Just reserved') packet + _ -> return $ return () + + SendAnnounce addr -> do + identity <- fst <$> readTVar gIdentity + let packet = BL.toStrict $ serializeObject $ transportToObject gStorage $ TransportHeader $ + [ AnnounceSelf $ refDigest $ storedRef $ idData identity + ] ++ map ProtocolVersion protocolVersions + writeFlow gDataFlow (addr, packet) + return $ return () + + UpdateSelfIdentity nid -> do + (cur, past) <- readTVar gIdentity + writeTVar gIdentity (nid, cur : past) + return $ return () + + let sendKeepAlive :: Connection addr -> STM (IO ()) + sendKeepAlive Connection {..} = do + readTVar cNextKeepAlive >>= \case + Nothing -> retry + Just next -> do + now <- readTVar gNowVar + if next <= now + then do + writeTVar cNextKeepAlive Nothing + identity <- fst <$> readTVar gIdentity + let header = TransportHeader [ AnnounceSelf $ refDigest $ storedRef $ idData identity ] + writeTQueue cSecureOutQueue (EncryptedOnly, TransportPacket header [], []) + else do + waitUntil now next + return $ return () + + conns <- readTVar gConnections + msum $ concat $ + [ map retransmitPacket conns + , map sendNextPacket conns + , [ handleControlRequests ] + , map sendKeepAlive conns + ] + +processAcknowledgements :: GlobalState addr -> Connection addr -> [TransportHeaderItem] -> STM (IO ()) +processAcknowledgements GlobalState {} Connection {..} header = do + (acked, notAcked) <- partition (\sp -> any (fromJust (spAckedBy sp)) header) <$> readTVar cSentPackets + writeTVar cSentPackets notAcked + return $ sequence_ $ map spOnAck acked diff --git a/src/Erebos/Network/ifaddrs.c b/src/Erebos/Network/ifaddrs.c new file mode 100644 index 0000000..ff4382a --- /dev/null +++ b/src/Erebos/Network/ifaddrs.c @@ -0,0 +1,279 @@ +#include "ifaddrs.h" + +#include <errno.h> +#include <stdbool.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> + +#ifndef _WIN32 +#include <arpa/inet.h> +#include <net/if.h> +#include <netinet/in.h> +#include <ifaddrs.h> +#include <endian.h> +#include <sys/types.h> +#include <sys/socket.h> +#else +#include <winsock2.h> +#include <ws2ipdef.h> +#include <ws2tcpip.h> +#endif + +#define DISCOVERY_MULTICAST_GROUP "ff12:b6a4:6b1f:969:caee:acc2:5c93:73e1" + +uint32_t * join_multicast(int fd, size_t * count) +{ + size_t capacity = 16; + *count = 0; + uint32_t * interfaces = malloc(sizeof(uint32_t) * capacity); + +#ifdef _WIN32 + interfaces[0] = 0; + *count = 1; +#else + struct ifaddrs * addrs; + if (getifaddrs(&addrs) < 0) + return 0; + + for (struct ifaddrs * ifa = addrs; ifa; ifa = ifa->ifa_next) { + if( ifa->ifa_addr && ifa->ifa_addr->sa_family == AF_INET6 && + ! (ifa->ifa_flags & IFF_LOOPBACK) && + (ifa->ifa_flags & IFF_MULTICAST) && + ! IN6_IS_ADDR_LINKLOCAL( & ((struct sockaddr_in6 *) ifa->ifa_addr)->sin6_addr ) ){ + int idx = if_nametoindex(ifa->ifa_name); + + bool seen = false; + for (size_t i = 0; i < *count; i++) { + if (interfaces[i] == idx) { + seen = true; + break; + } + } + if (seen) + continue; + + if (*count + 1 >= capacity) { + capacity *= 2; + uint32_t * nret = realloc(interfaces, sizeof(uint32_t) * capacity); + if (nret) { + interfaces = nret; + } else { + free(interfaces); + *count = 0; + return NULL; + } + } + + interfaces[*count] = idx; + (*count)++; + } + } + + freeifaddrs(addrs); +#endif + + for (size_t i = 0; i < *count; i++) { + struct ipv6_mreq group; + group.ipv6mr_interface = interfaces[i]; + inet_pton(AF_INET6, DISCOVERY_MULTICAST_GROUP, &group.ipv6mr_multiaddr); + int ret = setsockopt(fd, IPPROTO_IPV6, IPV6_ADD_MEMBERSHIP, + (const void *) &group, sizeof(group)); + if (ret < 0) + fprintf(stderr, "IPV6_ADD_MEMBERSHIP failed: %s\n", strerror(errno)); + } + + return interfaces; +} + +static bool copy_local_address( struct InetAddress * dst, const struct sockaddr * src ) +{ + int family = src->sa_family; + + if( family == AF_INET ){ + struct in_addr * addr = & (( struct sockaddr_in * ) src)->sin_addr; + if (! ((ntohl( addr->s_addr ) & 0xff000000) == 0x7f000000) && // loopback + ! ((ntohl( addr->s_addr ) & 0xffff0000) == 0xa9fe0000) // link-local + ){ + dst->family = family; + memcpy( & dst->addr, addr, sizeof( * addr )); + return true; + } + } + + if( family == AF_INET6 ){ + struct in6_addr * addr = & (( struct sockaddr_in6 * ) src)->sin6_addr; + if (! IN6_IS_ADDR_LOOPBACK( addr ) && + ! IN6_IS_ADDR_LINKLOCAL( addr ) + ){ + dst->family = family; + memcpy( & dst->addr, addr, sizeof( * addr )); + return true; + } + } + + return false; +} + +#ifndef _WIN32 + +struct InetAddress * local_addresses( size_t * count ) +{ + struct ifaddrs * addrs; + if( getifaddrs( &addrs ) < 0 ) + return 0; + + * count = 0; + size_t capacity = 16; + struct InetAddress * ret = malloc( sizeof(* ret) * capacity ); + + for( struct ifaddrs * ifa = addrs; ifa; ifa = ifa->ifa_next ){ + if ( ifa->ifa_addr ){ + int family = ifa->ifa_addr->sa_family; + if( family == AF_INET || family == AF_INET6 ){ + if( (* count) >= capacity ){ + capacity *= 2; + struct InetAddress * nret = realloc( ret, sizeof(* ret) * capacity ); + if (nret) { + ret = nret; + } else { + free( ret ); + freeifaddrs( addrs ); + return 0; + } + } + + if( copy_local_address( & ret[ * count ], ifa->ifa_addr )) + (* count)++; + } + } + } + + freeifaddrs(addrs); + return ret; +} + +uint32_t * broadcast_addresses(void) +{ + struct ifaddrs * addrs; + if (getifaddrs(&addrs) < 0) + return 0; + + size_t capacity = 16, count = 0; + uint32_t * ret = malloc(sizeof(uint32_t) * capacity); + + for (struct ifaddrs * ifa = addrs; ifa; ifa = ifa->ifa_next) { + if (ifa->ifa_addr && ifa->ifa_addr->sa_family == AF_INET && + ifa->ifa_flags & IFF_BROADCAST) { + if (count + 2 >= capacity) { + capacity *= 2; + uint32_t * nret = realloc(ret, sizeof(uint32_t) * capacity); + if (nret) { + ret = nret; + } else { + free(ret); + freeifaddrs(addrs); + return 0; + } + } + + ret[count] = ((struct sockaddr_in*)ifa->ifa_broadaddr)->sin_addr.s_addr; + count++; + } + } + + freeifaddrs(addrs); + ret[count] = 0; + return ret; +} + +#else // _WIN32 + +#include <winsock2.h> +#include <ws2tcpip.h> +#include <iptypes.h> +#include <iphlpapi.h> + +#pragma comment(lib, "ws2_32.lib") + +struct InetAddress * local_addresses( size_t * count ) +{ + * count = 0; + struct InetAddress * ret = NULL; + + ULONG bufsize = 15000; + IP_ADAPTER_ADDRESSES * buf = NULL; + + DWORD rv = 0; + + do { + buf = realloc( buf, bufsize ); + rv = GetAdaptersAddresses( AF_UNSPEC, 0, NULL, buf, & bufsize ); + + if( rv == ERROR_BUFFER_OVERFLOW ) + continue; + } while (0); + + if( rv == NO_ERROR ){ + size_t capacity = 16; + ret = malloc( sizeof( * ret ) * capacity ); + + for( IP_ADAPTER_ADDRESSES * cur = (IP_ADAPTER_ADDRESSES *) buf; + cur && (* count) < capacity; + cur = cur->Next ){ + + for( IP_ADAPTER_UNICAST_ADDRESS * curAddr = cur->FirstUnicastAddress; + curAddr && (* count) < capacity; + curAddr = curAddr->Next ){ + + if( copy_local_address( & ret[ * count ], curAddr->Address.lpSockaddr )) + (* count)++; + } + } + } + +cleanup: + free( buf ); + return ret; +} + +uint32_t * broadcast_addresses(void) +{ + uint32_t * ret = NULL; + SOCKET wsock = INVALID_SOCKET; + + struct WSAData wsaData; + if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) + return NULL; + + wsock = WSASocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP, NULL, 0, 0); + if (wsock == INVALID_SOCKET) + goto cleanup; + + INTERFACE_INFO InterfaceList[32]; + unsigned long nBytesReturned; + + if (WSAIoctl(wsock, SIO_GET_INTERFACE_LIST, 0, 0, + InterfaceList, sizeof(InterfaceList), + &nBytesReturned, 0, 0) == SOCKET_ERROR) + goto cleanup; + + int numInterfaces = nBytesReturned / sizeof(INTERFACE_INFO); + + size_t capacity = 16, count = 0; + ret = malloc(sizeof(uint32_t) * capacity); + + for (int i = 0; i < numInterfaces && count < capacity - 1; i++) + if (InterfaceList[i].iiFlags & IFF_BROADCAST) + ret[count++] = InterfaceList[i].iiBroadcastAddress.AddressIn.sin_addr.s_addr; + + ret[count] = 0; +cleanup: + if (wsock != INVALID_SOCKET) + closesocket(wsock); + WSACleanup(); + + return ret; +} + +#endif diff --git a/src/Erebos/Network/ifaddrs.h b/src/Erebos/Network/ifaddrs.h new file mode 100644 index 0000000..2ee45a7 --- /dev/null +++ b/src/Erebos/Network/ifaddrs.h @@ -0,0 +1,18 @@ +#include <stddef.h> +#include <stdint.h> + +#ifndef _WIN32 +#include <sys/socket.h> +#else +#include <winsock2.h> +#endif + +struct InetAddress +{ + int family; + uint8_t addr[16]; +} __attribute__((packed)); + +uint32_t * join_multicast(int fd, size_t * count); +struct InetAddress * local_addresses( size_t * count ); +uint32_t * broadcast_addresses(void); diff --git a/src/Erebos/Object.hs b/src/Erebos/Object.hs new file mode 100644 index 0000000..f00b63d --- /dev/null +++ b/src/Erebos/Object.hs @@ -0,0 +1,23 @@ +{-| +Description: Core Erebos objects and references + +Data types and functions for working with "raw" Erebos objects and references. +-} + +module Erebos.Object ( + Object, PartialObject, Object'(..), + serializeObject, deserializeObject, deserializeObjects, + ioLoadObject, ioLoadBytes, + storeRawBytes, lazyLoadBytes, + + RecItem, RecItem'(..), + + Ref, PartialRef, RefDigest, + refDigest, refFromDigest, + readRef, showRef, + readRefDigest, showRefDigest, + refDigestFromByteString, hashToRefDigest, + copyRef, partialRef, partialRefFromDigest, +) where + +import Erebos.Object.Internal diff --git a/src/Erebos/Object/Internal.hs b/src/Erebos/Object/Internal.hs new file mode 100644 index 0000000..fdb587a --- /dev/null +++ b/src/Erebos/Object/Internal.hs @@ -0,0 +1,772 @@ +module Erebos.Object.Internal ( + Storage, PartialStorage, StorageCompleteness, + + Ref, PartialRef, RefDigest, + refDigest, refFromDigest, + readRef, showRef, + readRefDigest, showRefDigest, + refDigestFromByteString, hashToRefDigest, + copyRef, partialRef, partialRefFromDigest, + + Object, PartialObject, Object'(..), RecItem, RecItem'(..), + serializeObject, deserializeObject, deserializeObjects, + ioLoadObject, ioLoadBytes, + storeRawBytes, lazyLoadBytes, + storeObject, + collectObjects, collectStoredObjects, + + MonadStorage(..), + + Storable(..), ZeroStorable(..), + StorableText(..), StorableDate(..), StorableUUID(..), + + Store, StoreRec, + evalStore, evalStoreObject, + storeBlob, storeRec, storeZero, + storeEmpty, storeInt, storeNum, storeText, storeBinary, storeDate, storeUUID, storeRef, storeRawRef, storeWeak, storeRawWeak, + storeMbEmpty, storeMbInt, storeMbNum, storeMbText, storeMbBinary, storeMbDate, storeMbUUID, storeMbRef, storeMbRawRef, storeMbWeak, storeMbRawWeak, + storeZRef, storeZWeak, + storeRecItems, + + Load, LoadRec, + evalLoad, + loadCurrentRef, loadCurrentObject, + loadRecCurrentRef, loadRecItems, + + loadBlob, loadRec, loadZero, + loadEmpty, loadInt, loadNum, loadText, loadBinary, loadDate, loadUUID, loadRef, loadRawRef, loadRawWeak, + loadMbEmpty, loadMbInt, loadMbNum, loadMbText, loadMbBinary, loadMbDate, loadMbUUID, loadMbRef, loadMbRawRef, loadMbRawWeak, + loadTexts, loadBinaries, loadRefs, loadRawRefs, loadRawWeaks, + loadZRef, + + Stored, + fromStored, storedRef, + wrappedStore, wrappedLoad, + copyStored, + unsafeMapStored, +) where + +import Control.Applicative +import Control.Monad +import Control.Monad.Except +import Control.Monad.Reader +import Control.Monad.Writer + +import Crypto.Hash + +import Data.Bifunctor +import Data.ByteString (ByteString) +import qualified Data.ByteArray as BA +import qualified Data.ByteString as B +import qualified Data.ByteString.Char8 as BC +import qualified Data.ByteString.Lazy as BL +import qualified Data.ByteString.Lazy.Char8 as BLC +import Data.Char +import Data.Function +import Data.Maybe +import Data.Ratio +import Data.Set (Set) +import qualified Data.Set as S +import Data.Text (Text) +import qualified Data.Text as T +import Data.Text.Encoding +import Data.Text.Encoding.Error +import Data.Time.Calendar +import Data.Time.Clock +import Data.Time.Format +import Data.Time.LocalTime + +import System.IO.Unsafe + +import Erebos.Error +import Erebos.Storage.Internal +import Erebos.UUID (UUID) +import Erebos.UUID qualified as U +import Erebos.Util + + +zeroRef :: Storage' c -> Ref' c +zeroRef s = Ref s (RefDigest h) + where h = case digestFromByteString $ B.replicate (hashDigestSize $ digestAlgo h) 0 of + Nothing -> error $ "Failed to create zero hash" + Just h' -> h' + digestAlgo :: Digest a -> a + digestAlgo = undefined + +isZeroRef :: Ref' c -> Bool +isZeroRef (Ref _ h) = all (==0) $ BA.unpack h + + +refFromDigest :: Storage' c -> RefDigest -> IO (Maybe (Ref' c)) +refFromDigest st dgst = fmap (const $ Ref st dgst) <$> ioLoadBytesFromStorage st dgst + +readRef :: Storage -> ByteString -> IO (Maybe Ref) +readRef s b = + case readRefDigest b of + Nothing -> return Nothing + Just dgst -> refFromDigest s dgst + +copyRef' :: forall c c'. (StorageCompleteness c, StorageCompleteness c') => Storage' c' -> Ref' c -> IO (c (Ref' c')) +copyRef' st ref'@(Ref _ dgst) = refFromDigest st dgst >>= \case Just ref -> return $ return ref + Nothing -> doCopy + where doCopy = do mbobj' <- ioLoadObject ref' + mbobj <- sequence $ copyObject' st <$> mbobj' + sequence $ unsafeStoreObject st <$> join mbobj + +copyRecItem' :: forall c c'. (StorageCompleteness c, StorageCompleteness c') => Storage' c' -> RecItem' c -> IO (c (RecItem' c')) +copyRecItem' st = \case + RecEmpty -> return $ return $ RecEmpty + RecInt x -> return $ return $ RecInt x + RecNum x -> return $ return $ RecNum x + RecText x -> return $ return $ RecText x + RecBinary x -> return $ return $ RecBinary x + RecDate x -> return $ return $ RecDate x + RecUUID x -> return $ return $ RecUUID x + RecRef x -> fmap RecRef <$> copyRef' st x + RecWeak x -> return $ return $ RecWeak x + RecUnknown t x -> return $ return $ RecUnknown t x + +copyObject' :: forall c c'. (StorageCompleteness c, StorageCompleteness c') => Storage' c' -> Object' c -> IO (c (Object' c')) +copyObject' _ (Blob bs) = return $ return $ Blob bs +copyObject' st (Rec rs) = fmap Rec . sequence <$> mapM (\( n, item ) -> fmap ( n, ) <$> copyRecItem' st item) rs +copyObject' _ ZeroObject = return $ return ZeroObject +copyObject' _ (UnknownObject otype content) = return $ return $ UnknownObject otype content + +copyRef :: forall c c' m. (StorageCompleteness c, StorageCompleteness c', MonadIO m) => Storage' c' -> Ref' c -> m (LoadResult c (Ref' c')) +copyRef st ref' = liftIO $ returnLoadResult <$> copyRef' st ref' + +copyRecItem :: forall c c' m. (StorageCompleteness c, StorageCompleteness c', MonadIO m) => Storage' c' -> RecItem' c -> m (LoadResult c (RecItem' c')) +copyRecItem st item' = liftIO $ returnLoadResult <$> copyRecItem' st item' + +copyObject :: forall c c'. (StorageCompleteness c, StorageCompleteness c') => Storage' c' -> Object' c -> IO (LoadResult c (Object' c')) +copyObject st obj' = returnLoadResult <$> copyObject' st obj' + +partialRef :: PartialStorage -> Ref -> PartialRef +partialRef st (Ref _ dgst) = Ref st dgst + +partialRefFromDigest :: PartialStorage -> RefDigest -> PartialRef +partialRefFromDigest st dgst = Ref st dgst + + +data Object' c + = Blob ByteString + | Rec [(ByteString, RecItem' c)] + | ZeroObject + | UnknownObject ByteString ByteString + deriving (Show) + +type Object = Object' Complete +type PartialObject = Object' Partial + +data RecItem' c + = RecEmpty + | RecInt Integer + | RecNum Rational + | RecText Text + | RecBinary ByteString + | RecDate ZonedTime + | RecUUID UUID + | RecRef (Ref' c) + | RecWeak RefDigest + | RecUnknown ByteString ByteString + deriving (Show) + +type RecItem = RecItem' Complete + +serializeObject :: Object' c -> BL.ByteString +serializeObject = \case + Blob cnt -> BL.fromChunks [BC.pack "blob ", BC.pack (show $ B.length cnt), BC.singleton '\n', cnt] + Rec rec -> let cnt = BL.fromChunks $ concatMap (uncurry serializeRecItem) rec + in BL.fromChunks [BC.pack "rec ", BC.pack (show $ BL.length cnt), BC.singleton '\n'] `BL.append` cnt + ZeroObject -> BL.empty + UnknownObject otype cnt -> BL.fromChunks [ otype, BC.singleton ' ', BC.pack (show $ B.length cnt), BC.singleton '\n', cnt ] + +-- |Serializes and stores object data without ony dependencies, so is safe only +-- if all the referenced objects are already stored or reference is partial. +unsafeStoreObject :: Storage' c -> Object' c -> IO (Ref' c) +unsafeStoreObject storage = \case + ZeroObject -> return $ zeroRef storage + obj -> unsafeStoreRawBytes storage $ serializeObject obj + +storeObject :: PartialStorage -> PartialObject -> IO PartialRef +storeObject = unsafeStoreObject + +storeRawBytes :: PartialStorage -> BL.ByteString -> IO PartialRef +storeRawBytes = unsafeStoreRawBytes + +serializeRecItem :: ByteString -> RecItem' c -> [ByteString] +serializeRecItem name (RecEmpty) = [name, BC.pack ":e", BC.singleton ' ', BC.singleton '\n'] +serializeRecItem name (RecInt x) = [name, BC.pack ":i", BC.singleton ' ', BC.pack (show x), BC.singleton '\n'] +serializeRecItem name (RecNum x) = [name, BC.pack ":n", BC.singleton ' ', BC.pack (showRatio x), BC.singleton '\n'] +serializeRecItem name (RecText x) = [name, BC.pack ":t", BC.singleton ' ', escaped, BC.singleton '\n'] + where escaped = BC.concatMap escape $ encodeUtf8 x + escape '\n' = BC.pack "\n\t" + escape c = BC.singleton c +serializeRecItem name (RecBinary x) = [name, BC.pack ":b ", showHex x, BC.singleton '\n'] +serializeRecItem name (RecDate x) = [name, BC.pack ":d", BC.singleton ' ', BC.pack (formatTime defaultTimeLocale "%s %z" x), BC.singleton '\n'] +serializeRecItem name (RecUUID x) = [name, BC.pack ":u", BC.singleton ' ', U.toASCIIBytes x, BC.singleton '\n'] +serializeRecItem name (RecRef x) = [name, BC.pack ":r ", showRef x, BC.singleton '\n'] +serializeRecItem name (RecWeak x) = [name, BC.pack ":w ", showRefDigest x, BC.singleton '\n'] +serializeRecItem name (RecUnknown t x) = [ name, BC.singleton ':', t, BC.singleton ' ', x, BC.singleton '\n' ] + +lazyLoadObject :: forall c. StorageCompleteness c => Ref' c -> LoadResult c (Object' c) +lazyLoadObject = returnLoadResult . unsafePerformIO . ioLoadObject + +ioLoadObject :: forall c. StorageCompleteness c => Ref' c -> IO (c (Object' c)) +ioLoadObject ref | isZeroRef ref = return $ return ZeroObject +ioLoadObject ref@(Ref st rhash) = do + file' <- ioLoadBytes ref + return $ do + file <- file' + let chash = hashToRefDigest file + when (chash /= rhash) $ error $ "Hash mismatch on object " ++ BC.unpack (showRef ref) {- TODO throw -} + return $ case runExcept $ unsafeDeserializeObject st file of + Left err -> error $ showErebosError err ++ ", ref " ++ BC.unpack (showRef ref) {- TODO throw -} + Right (x, rest) | BL.null rest -> x + | otherwise -> error $ "Superfluous content after " ++ BC.unpack (showRef ref) {- TODO throw -} + +lazyLoadBytes :: forall c. StorageCompleteness c => Ref' c -> LoadResult c BL.ByteString +lazyLoadBytes ref | isZeroRef ref = returnLoadResult (return BL.empty :: c BL.ByteString) +lazyLoadBytes ref = returnLoadResult $ unsafePerformIO $ ioLoadBytes ref + +unsafeDeserializeObject :: Storage' c -> BL.ByteString -> Except ErebosError (Object' c, BL.ByteString) +unsafeDeserializeObject _ bytes | BL.null bytes = return (ZeroObject, bytes) +unsafeDeserializeObject st bytes = + case BLC.break (=='\n') bytes of + (line, rest) | Just (otype, len) <- splitObjPrefix line -> do + let (content, next) = first BL.toStrict $ BL.splitAt (fromIntegral len) $ BL.drop 1 rest + guard $ B.length content == len + (,next) <$> case otype of + _ | otype == BC.pack "blob" -> return $ Blob content + | otype == BC.pack "rec" -> maybe (throwOtherError $ "malformed record item ") + (return . Rec) $ sequence $ map parseRecLine $ mergeCont [] $ BC.lines content + | otherwise -> return $ UnknownObject otype content + _ -> throwOtherError $ "malformed object" + where splitObjPrefix line = do + [otype, tlen] <- return $ BLC.words line + (len, rest) <- BLC.readInt tlen + guard $ BL.null rest + return (BL.toStrict otype, len) + + mergeCont cs (a:b:rest) | Just ('\t', b') <- BC.uncons b = mergeCont (b':BC.pack "\n":cs) (a:rest) + mergeCont cs (a:rest) = B.concat (a : reverse cs) : mergeCont [] rest + mergeCont _ [] = [] + + parseRecLine line = do + colon <- BC.elemIndex ':' line + space <- BC.elemIndex ' ' line + guard $ colon < space + let name = B.take colon line + itype = B.take (space-colon-1) $ B.drop (colon+1) line + content = B.drop (space+1) line + + let val = fromMaybe (RecUnknown itype content) $ + case BC.unpack itype of + "e" -> do guard $ B.null content + return RecEmpty + "i" -> do (num, rest) <- BC.readInteger content + guard $ B.null rest + return $ RecInt num + "n" -> RecNum <$> parseRatio content + "t" -> return $ RecText $ decodeUtf8With lenientDecode content + "b" -> RecBinary <$> readHex content + "d" -> RecDate <$> parseTimeM False defaultTimeLocale "%s %z" (BC.unpack content) + "u" -> RecUUID <$> U.fromASCIIBytes content + "r" -> RecRef . Ref st <$> readRefDigest content + "w" -> RecWeak <$> readRefDigest content + _ -> Nothing + return (name, val) + +deserializeObject :: PartialStorage -> BL.ByteString -> Except ErebosError (PartialObject, BL.ByteString) +deserializeObject = unsafeDeserializeObject + +deserializeObjects :: PartialStorage -> BL.ByteString -> Except ErebosError [PartialObject] +deserializeObjects _ bytes | BL.null bytes = return [] +deserializeObjects st bytes = do (obj, rest) <- deserializeObject st bytes + (obj:) <$> deserializeObjects st rest + + +collectObjects :: Object -> [Object] +collectObjects obj = obj : map fromStored (fst $ collectOtherStored S.empty obj) + +collectStoredObjects :: Stored Object -> [Stored Object] +collectStoredObjects obj = obj : (fst $ collectOtherStored S.empty $ fromStored obj) + +collectOtherStored :: Set RefDigest -> Object -> ([Stored Object], Set RefDigest) +collectOtherStored seen (Rec items) = foldr helper ([], seen) $ map snd items + where helper (RecRef ref) (xs, s) | r <- refDigest ref + , r `S.notMember` s + = let o = wrappedLoad ref + (xs', s') = collectOtherStored (S.insert r s) $ fromStored o + in ((o : xs') ++ xs, s') + helper _ (xs, s) = (xs, s) +collectOtherStored seen _ = ([], seen) + + +deriving instance StorableUUID HeadID +deriving instance StorableUUID HeadTypeID + + +class Monad m => MonadStorage m where + getStorage :: m Storage + mstore :: Storable a => a -> m (Stored a) + + default mstore :: MonadIO m => Storable a => a -> m (Stored a) + mstore x = do + st <- getStorage + wrappedStore st x + +instance MonadIO m => MonadStorage (ReaderT Storage m) where + getStorage = ask + + +class Storable a where + store' :: a -> Store + load' :: Load a + + store :: StorageCompleteness c => Storage' c -> a -> IO (Ref' c) + store st = evalStore st . store' + load :: Ref -> a + load = evalLoad load' + +class Storable a => ZeroStorable a where + fromZero :: Storage -> a + +data Store = StoreBlob ByteString + | StoreRec (forall c. StorageCompleteness c => Storage' c -> [IO [(ByteString, RecItem' c)]]) + | StoreZero + | StoreUnknown ByteString ByteString + +evalStore :: StorageCompleteness c => Storage' c -> Store -> IO (Ref' c) +evalStore st = unsafeStoreObject st <=< evalStoreObject st + +evalStoreObject :: StorageCompleteness c => Storage' c -> Store -> IO (Object' c) +evalStoreObject _ (StoreBlob x) = return $ Blob x +evalStoreObject s (StoreRec f) = Rec . concat <$> sequence (f s) +evalStoreObject _ StoreZero = return ZeroObject +evalStoreObject _ (StoreUnknown otype content) = return $ UnknownObject otype content + +newtype StoreRecM c a = StoreRecM (ReaderT (Storage' c) (Writer [IO [(ByteString, RecItem' c)]]) a) + deriving (Functor, Applicative, Monad) + +type StoreRec c = StoreRecM c () + +newtype Load a = Load (ReaderT (Ref, Object) (Except ErebosError) a) + deriving (Functor, Applicative, Alternative, Monad, MonadPlus, MonadError ErebosError) + +evalLoad :: Load a -> Ref -> a +evalLoad (Load f) ref = either (error {- TODO throw -} . ((BC.unpack (showRef ref) ++ ": ") ++) . showErebosError) id $ + runExcept $ runReaderT f (ref, lazyLoadObject ref) + +loadCurrentRef :: Load Ref +loadCurrentRef = Load $ asks fst + +loadCurrentObject :: Load Object +loadCurrentObject = Load $ asks snd + +newtype LoadRec a = LoadRec (ReaderT (Ref, [(ByteString, RecItem)]) (Except ErebosError) a) + deriving (Functor, Applicative, Alternative, Monad, MonadPlus, MonadError ErebosError) + +loadRecCurrentRef :: LoadRec Ref +loadRecCurrentRef = LoadRec $ asks fst + +loadRecItems :: LoadRec [(ByteString, RecItem)] +loadRecItems = LoadRec $ asks snd + + +instance Storable Object where + store' (Blob bs) = StoreBlob bs + store' (Rec xs) = StoreRec $ \st -> return $ do + Rec xs' <- copyObject st (Rec xs) + return xs' + store' ZeroObject = StoreZero + store' (UnknownObject otype content) = StoreUnknown otype content + + load' = loadCurrentObject + + store st = unsafeStoreObject st <=< copyObject st + load = lazyLoadObject + +instance Storable ByteString where + store' = storeBlob + load' = loadBlob id + +instance Storable a => Storable [a] where + store' [] = storeZero + store' (x:xs) = storeRec $ do + storeRef "i" x + storeRef "n" xs + + load' = loadCurrentObject >>= \case + ZeroObject -> return [] + _ -> loadRec $ (:) + <$> loadRef "i" + <*> loadRef "n" + +instance Storable a => ZeroStorable [a] where + fromZero _ = [] + + +storeBlob :: ByteString -> Store +storeBlob = StoreBlob + +storeRec :: (forall c. StorageCompleteness c => StoreRec c) -> Store +storeRec sr = StoreRec $ do + let StoreRecM r = sr + execWriter . runReaderT r + +storeZero :: Store +storeZero = StoreZero + + +class StorableText a where + toText :: a -> Text + fromText :: MonadError ErebosError m => Text -> m a + +instance StorableText Text where + toText = id; fromText = return + +instance StorableText [Char] where + toText = T.pack; fromText = return . T.unpack + + +class StorableDate a where + toDate :: a -> ZonedTime + fromDate :: ZonedTime -> a + +instance StorableDate ZonedTime where + toDate = id; fromDate = id + +instance StorableDate UTCTime where + toDate = utcToZonedTime utc + fromDate = zonedTimeToUTC + +instance StorableDate Day where + toDate day = toDate $ UTCTime day 0 + fromDate = utctDay . fromDate + + +class StorableUUID a where + toUUID :: a -> UUID + fromUUID :: UUID -> a + +instance StorableUUID UUID where + toUUID = id; fromUUID = id + + +storeEmpty :: String -> StoreRec c +storeEmpty name = StoreRecM $ tell [return [(BC.pack name, RecEmpty)]] + +storeMbEmpty :: String -> Maybe () -> StoreRec c +storeMbEmpty name = maybe (return ()) (const $ storeEmpty name) + +storeInt :: Integral a => String -> a -> StoreRec c +storeInt name x = StoreRecM $ tell [return [(BC.pack name, RecInt $ toInteger x)]] + +storeMbInt :: Integral a => String -> Maybe a -> StoreRec c +storeMbInt name = maybe (return ()) (storeInt name) + +storeNum :: (Real a, Fractional a) => String -> a -> StoreRec c +storeNum name x = StoreRecM $ tell [return [(BC.pack name, RecNum $ toRational x)]] + +storeMbNum :: (Real a, Fractional a) => String -> Maybe a -> StoreRec c +storeMbNum name = maybe (return ()) (storeNum name) + +storeText :: StorableText a => String -> a -> StoreRec c +storeText name x = StoreRecM $ tell [return [(BC.pack name, RecText $ toText x)]] + +storeMbText :: StorableText a => String -> Maybe a -> StoreRec c +storeMbText name = maybe (return ()) (storeText name) + +storeBinary :: BA.ByteArrayAccess a => String -> a -> StoreRec c +storeBinary name x = StoreRecM $ tell [return [(BC.pack name, RecBinary $ BA.convert x)]] + +storeMbBinary :: BA.ByteArrayAccess a => String -> Maybe a -> StoreRec c +storeMbBinary name = maybe (return ()) (storeBinary name) + +storeDate :: StorableDate a => String -> a -> StoreRec c +storeDate name x = StoreRecM $ tell [return [(BC.pack name, RecDate $ toDate x)]] + +storeMbDate :: StorableDate a => String -> Maybe a -> StoreRec c +storeMbDate name = maybe (return ()) (storeDate name) + +storeUUID :: StorableUUID a => String -> a -> StoreRec c +storeUUID name x = StoreRecM $ tell [return [(BC.pack name, RecUUID $ toUUID x)]] + +storeMbUUID :: StorableUUID a => String -> Maybe a -> StoreRec c +storeMbUUID name = maybe (return ()) (storeUUID name) + +storeRef :: Storable a => StorageCompleteness c => String -> a -> StoreRec c +storeRef name x = StoreRecM $ do + s <- ask + tell $ (:[]) $ do + ref <- store s x + return [(BC.pack name, RecRef ref)] + +storeMbRef :: Storable a => StorageCompleteness c => String -> Maybe a -> StoreRec c +storeMbRef name = maybe (return ()) (storeRef name) + +storeRawRef :: StorageCompleteness c => String -> Ref -> StoreRec c +storeRawRef name ref = StoreRecM $ do + st <- ask + tell $ (:[]) $ do + ref' <- copyRef st ref + return [(BC.pack name, RecRef ref')] + +storeMbRawRef :: StorageCompleteness c => String -> Maybe Ref -> StoreRec c +storeMbRawRef name = maybe (return ()) (storeRawRef name) + +storeZRef :: (ZeroStorable a, StorageCompleteness c) => String -> a -> StoreRec c +storeZRef name x = StoreRecM $ do + s <- ask + tell $ (:[]) $ do + ref <- store s x + return $ if isZeroRef ref then [] + else [(BC.pack name, RecRef ref)] + +storeWeak :: Storable a => StorageCompleteness c => String -> a -> StoreRec c +storeWeak name x = StoreRecM $ do + s <- ask + tell $ (:[]) $ do + ref <- store s x + return [ ( BC.pack name, RecWeak $ refDigest ref ) ] + +storeMbWeak :: Storable a => StorageCompleteness c => String -> Maybe a -> StoreRec c +storeMbWeak name = maybe (return ()) (storeWeak name) + +storeRawWeak :: StorageCompleteness c => String -> RefDigest -> StoreRec c +storeRawWeak name dgst = StoreRecM $ do + tell $ (:[]) $ do + return [ ( BC.pack name, RecWeak dgst ) ] + +storeMbRawWeak :: StorageCompleteness c => String -> Maybe RefDigest -> StoreRec c +storeMbRawWeak name = maybe (return ()) (storeRawWeak name) + +storeZWeak :: (ZeroStorable a, StorageCompleteness c) => String -> a -> StoreRec c +storeZWeak name x = StoreRecM $ do + s <- ask + tell $ (:[]) $ do + ref <- store s x + return $ if isZeroRef ref then [] + else [ ( BC.pack name, RecWeak $ refDigest ref ) ] + + +storeRecItems :: StorageCompleteness c => [ ( ByteString, RecItem ) ] -> StoreRec c +storeRecItems items = StoreRecM $ do + st <- ask + tell $ flip map items $ \( name, value ) -> do + value' <- copyRecItem st value + return [ ( name, value' ) ] + +loadBlob :: (ByteString -> a) -> Load a +loadBlob f = loadCurrentObject >>= \case + Blob x -> return $ f x + _ -> throwOtherError "Expecting blob" + +loadRec :: LoadRec a -> Load a +loadRec (LoadRec lrec) = loadCurrentObject >>= \case + Rec rs -> do + ref <- loadCurrentRef + either throwError return $ runExcept $ runReaderT lrec (ref, rs) + _ -> throwOtherError "Expecting record" + +loadZero :: a -> Load a +loadZero x = loadCurrentObject >>= \case + ZeroObject -> return x + _ -> throwOtherError "Expecting zero" + + +loadEmpty :: String -> LoadRec () +loadEmpty name = maybe (throwOtherError $ "Missing record item '"++name++"'") return =<< loadMbEmpty name + +loadMbEmpty :: String -> LoadRec (Maybe ()) +loadMbEmpty name = listToMaybe . mapMaybe p <$> loadRecItems + where + bname = BC.pack name + p ( name', RecEmpty ) | name' == bname + = Just () + p _ = Nothing + +loadInt :: Num a => String -> LoadRec a +loadInt name = maybe (throwOtherError $ "Missing record item '"++name++"'") return =<< loadMbInt name + +loadMbInt :: Num a => String -> LoadRec (Maybe a) +loadMbInt name = listToMaybe . mapMaybe p <$> loadRecItems + where + bname = BC.pack name + p ( name', RecInt x ) | name' == bname + = Just (fromInteger x) + p _ = Nothing + +loadNum :: (Real a, Fractional a) => String -> LoadRec a +loadNum name = maybe (throwOtherError $ "Missing record item '"++name++"'") return =<< loadMbNum name + +loadMbNum :: (Real a, Fractional a) => String -> LoadRec (Maybe a) +loadMbNum name = listToMaybe . mapMaybe p <$> loadRecItems + where + bname = BC.pack name + p ( name', RecNum x ) | name' == bname + = Just (fromRational x) + p _ = Nothing + +loadText :: StorableText a => String -> LoadRec a +loadText name = maybe (throwOtherError $ "Missing record item '"++name++"'") return =<< loadMbText name + +loadMbText :: StorableText a => String -> LoadRec (Maybe a) +loadMbText name = listToMaybe <$> loadTexts name + +loadTexts :: StorableText a => String -> LoadRec [a] +loadTexts name = sequence . mapMaybe p =<< loadRecItems + where + bname = BC.pack name + p ( name', RecText x ) | name' == bname + = Just (fromText x) + p _ = Nothing + +loadBinary :: BA.ByteArray a => String -> LoadRec a +loadBinary name = maybe (throwOtherError $ "Missing record item '"++name++"'") return =<< loadMbBinary name + +loadMbBinary :: BA.ByteArray a => String -> LoadRec (Maybe a) +loadMbBinary name = listToMaybe <$> loadBinaries name + +loadBinaries :: BA.ByteArray a => String -> LoadRec [a] +loadBinaries name = mapMaybe p <$> loadRecItems + where + bname = BC.pack name + p ( name', RecBinary x ) | name' == bname + = Just (BA.convert x) + p _ = Nothing + +loadDate :: StorableDate a => String -> LoadRec a +loadDate name = maybe (throwOtherError $ "Missing record item '"++name++"'") return =<< loadMbDate name + +loadMbDate :: StorableDate a => String -> LoadRec (Maybe a) +loadMbDate name = listToMaybe . mapMaybe p <$> loadRecItems + where + bname = BC.pack name + p ( name', RecDate x ) | name' == bname + = Just (fromDate x) + p _ = Nothing + +loadUUID :: StorableUUID a => String -> LoadRec a +loadUUID name = maybe (throwOtherError $ "Missing record iteem '"++name++"'") return =<< loadMbUUID name + +loadMbUUID :: StorableUUID a => String -> LoadRec (Maybe a) +loadMbUUID name = listToMaybe . mapMaybe p <$> loadRecItems + where + bname = BC.pack name + p ( name', RecUUID x ) | name' == bname + = Just (fromUUID x) + p _ = Nothing + +loadRawRef :: String -> LoadRec Ref +loadRawRef name = maybe (throwOtherError $ "Missing record item '"++name++"'") return =<< loadMbRawRef name + +loadMbRawRef :: String -> LoadRec (Maybe Ref) +loadMbRawRef name = listToMaybe <$> loadRawRefs name + +loadRawRefs :: String -> LoadRec [Ref] +loadRawRefs name = mapMaybe p <$> loadRecItems + where + bname = BC.pack name + p ( name', RecRef x ) | name' == bname = Just x + p _ = Nothing + +loadRef :: Storable a => String -> LoadRec a +loadRef name = load <$> loadRawRef name + +loadMbRef :: Storable a => String -> LoadRec (Maybe a) +loadMbRef name = fmap load <$> loadMbRawRef name + +loadRefs :: Storable a => String -> LoadRec [a] +loadRefs name = map load <$> loadRawRefs name + +loadZRef :: ZeroStorable a => String -> LoadRec a +loadZRef name = loadMbRef name >>= \case + Nothing -> do Ref st _ <- loadRecCurrentRef + return $ fromZero st + Just x -> return x + +loadRawWeak :: String -> LoadRec RefDigest +loadRawWeak name = maybe (throwOtherError $ "Missing record item '"++name++"'") return =<< loadMbRawWeak name + +loadMbRawWeak :: String -> LoadRec (Maybe RefDigest) +loadMbRawWeak name = listToMaybe <$> loadRawWeaks name + +loadRawWeaks :: String -> LoadRec [ RefDigest ] +loadRawWeaks name = mapMaybe p <$> loadRecItems + where + bname = BC.pack name + p ( name', RecRef x ) | name' == bname = Just (refDigest x) + p ( name', RecWeak x ) | name' == bname = Just x + p _ = Nothing + + + +instance Storable a => Storable (Stored a) where + store st = copyRef st . storedRef + store' (Stored _ x) = store' x + load' = Stored <$> loadCurrentRef <*> load' + +instance ZeroStorable a => ZeroStorable (Stored a) where + fromZero st = Stored (zeroRef st) $ fromZero st + +fromStored :: Stored a -> a +fromStored = storedObject' + +storedRef :: Stored a -> Ref +storedRef = storedRef' + +wrappedStore :: MonadIO m => Storable a => Storage -> a -> m (Stored a) +wrappedStore st x = do ref <- liftIO $ store st x + return $ Stored ref x + +wrappedLoad :: Storable a => Ref -> Stored a +wrappedLoad ref = Stored ref (load ref) + +copyStored :: forall m a. MonadIO m => Storage -> Stored a -> m (Stored a) +copyStored st (Stored ref' x) = liftIO $ returnLoadResult . fmap (\r -> Stored r x) <$> copyRef' st ref' + +-- |Passed function needs to preserve the object representation to be safe +unsafeMapStored :: (a -> b) -> Stored a -> Stored b +unsafeMapStored f (Stored ref x) = Stored ref (f x) + + +showRatio :: Rational -> String +showRatio r = case decimalRatio r of + Just (n, 1) -> show n + Just (n', d) -> let n = abs n' + in (if n' < 0 then "-" else "") ++ show (n `div` d) ++ "." ++ + (concatMap (show.(`mod` 10).snd) $ reverse $ takeWhile ((>1).fst) $ zip (iterate (`div` 10) d) (iterate (`div` 10) (n `mod` d))) + Nothing -> show (numerator r) ++ "/" ++ show (denominator r) + +decimalRatio :: Rational -> Maybe (Integer, Integer) +decimalRatio r = do + let n = numerator r + d = denominator r + (c2, d') = takeFactors 2 d + (c5, d'') = takeFactors 5 d' + guard $ d'' == 1 + let m = if c2 > c5 then 5 ^ (c2 - c5) + else 2 ^ (c5 - c2) + return (n * m, d * m) + +takeFactors :: Integer -> Integer -> (Integer, Integer) +takeFactors f n | n `mod` f == 0 = let (c, n') = takeFactors f (n `div` f) + in (c+1, n') + | otherwise = (0, n) + +parseRatio :: ByteString -> Maybe Rational +parseRatio bs = case BC.groupBy ((==) `on` isNumber) bs of + (m:xs) | m == BC.pack "-" -> negate <$> positive xs + xs -> positive xs + where positive = \case + [bx] -> fromInteger . fst <$> BC.readInteger bx + [bx, op, by] -> do + (x, _) <- BC.readInteger bx + (y, _) <- BC.readInteger by + case BC.unpack op of + "." -> return $ (x % 1) + (y % (10 ^ BC.length by)) + "/" -> return $ x % y + _ -> Nothing + _ -> Nothing diff --git a/src/Erebos/Pairing.hs b/src/Erebos/Pairing.hs new file mode 100644 index 0000000..e3ebf2b --- /dev/null +++ b/src/Erebos/Pairing.hs @@ -0,0 +1,244 @@ +module Erebos.Pairing ( + PairingService(..), + PairingState(..), + PairingAttributes(..), + PairingResult(..), + PairingFailureReason(..), + + pairingRequest, + pairingAccept, + pairingReject, +) where + +import Control.Monad +import Control.Monad.Except +import Control.Monad.Reader + +import Crypto.Random + +import Data.Bits +import Data.ByteArray qualified as BA +import Data.ByteString (ByteString) +import Data.ByteString qualified as BS +import Data.ByteString.Char8 qualified as BC +import Data.Kind +import Data.Maybe +import Data.Typeable +import Data.Word + +import Erebos.Identity +import Erebos.Network +import Erebos.Object +import Erebos.PubKey +import Erebos.Service +import Erebos.State +import Erebos.Storable + +data PairingService a = PairingRequest (Stored (Signed IdentityData)) (Stored (Signed IdentityData)) RefDigest + | PairingResponse ByteString + | PairingRequestNonce ByteString + | PairingAccept a + | PairingReject + +data PairingState a = NoPairing + | OurRequest UnifiedIdentity UnifiedIdentity ByteString + | OurRequestConfirm (Maybe (PairingVerifiedResult a)) + | OurRequestReady + | PeerRequest UnifiedIdentity UnifiedIdentity ByteString RefDigest + | PeerRequestConfirm + | PairingDone + +data PairingFailureReason a = PairingUserRejected + | PairingUnexpectedMessage (PairingState a) (PairingService a) + | PairingFailedOther ErebosError + +data PairingAttributes a = PairingAttributes + { pairingHookRequest :: ServiceHandler (PairingService a) () + , pairingHookResponse :: String -> ServiceHandler (PairingService a) () + , pairingHookRequestNonce :: String -> ServiceHandler (PairingService a) () + , pairingHookRequestNonceFailed :: ServiceHandler (PairingService a) () + , pairingHookConfirmedResponse :: ServiceHandler (PairingService a) () + , pairingHookConfirmedRequest :: ServiceHandler (PairingService a) () + , pairingHookAcceptedResponse :: ServiceHandler (PairingService a) () + , pairingHookAcceptedRequest :: ServiceHandler (PairingService a) () + , pairingHookVerifyFailed :: ServiceHandler (PairingService a) () + , pairingHookRejected :: ServiceHandler (PairingService a) () + , pairingHookFailed :: PairingFailureReason a -> ServiceHandler (PairingService a) () + } + +class (Typeable a, Storable a) => PairingResult a where + type PairingVerifiedResult a :: Type + type PairingVerifiedResult a = a + + pairingServiceID :: proxy a -> ServiceID + pairingVerifyResult :: a -> ServiceHandler (PairingService a) (Maybe (PairingVerifiedResult a)) + pairingFinalizeRequest :: PairingVerifiedResult a -> ServiceHandler (PairingService a) () + pairingFinalizeResponse :: ServiceHandler (PairingService a) a + defaultPairingAttributes :: proxy (PairingService a) -> PairingAttributes a + + +instance Storable a => Storable (PairingService a) where + store' (PairingRequest idReq idRsp x) = storeRec $ do + storeRef "id-req" idReq + storeRef "id-rsp" idRsp + storeBinary "request" x + store' (PairingResponse x) = storeRec $ storeBinary "response" x + store' (PairingRequestNonce x) = storeRec $ storeBinary "reqnonce" x + store' (PairingAccept x) = store' x + store' (PairingReject) = storeRec $ storeEmpty "reject" + + load' = do + res <- loadRec $ do + (req :: Maybe ByteString) <- loadMbBinary "request" + idReq <- loadMbRef "id-req" + idRsp <- loadMbRef "id-rsp" + rsp <- loadMbBinary "response" + rnonce <- loadMbBinary "reqnonce" + rej <- loadMbEmpty "reject" + return $ catMaybes + [ PairingRequest <$> idReq <*> idRsp <*> (refDigestFromByteString =<< req) + , PairingResponse <$> rsp + , PairingRequestNonce <$> rnonce + , const PairingReject <$> rej + ] + case res of + x:_ -> return x + [] -> PairingAccept <$> load' + + +instance PairingResult a => Service (PairingService a) where + serviceID _ = pairingServiceID @a Proxy + + type ServiceAttributes (PairingService a) = PairingAttributes a + defaultServiceAttributes = defaultPairingAttributes + + type ServiceState (PairingService a) = PairingState a + emptyServiceState _ = NoPairing + + serviceHandler spacket = ((,fromStored spacket) <$> svcGet) >>= \case + (NoPairing, PairingRequest pdata sdata confirm) -> do + self <- maybe (throwOtherError "failed to validate received identity") return $ validateIdentity sdata + self' <- maybe (throwOtherError "failed to validate own identity") return . + validateExtendedIdentity . lsIdentity . fromStored =<< svcGetLocal + when (not $ self `sameIdentity` self') $ do + throwOtherError "pairing request to different identity" + + peer <- maybe (throwOtherError "failed to validate received peer identity") return $ validateIdentity pdata + peer' <- asks $ svcPeerIdentity + when (not $ peer `sameIdentity` peer') $ do + throwOtherError "pairing request from different identity" + + join $ asks $ pairingHookRequest . svcAttributes + nonce <- liftIO $ getRandomBytes 32 + svcSet $ PeerRequest peer self nonce confirm + replyPacket $ PairingResponse nonce + (NoPairing, _) -> return () + + (PairingDone, _) -> return () + (_, PairingReject) -> do + join $ asks $ pairingHookRejected . svcAttributes + svcSet NoPairing + + (OurRequest self peer nonce, PairingResponse pnonce) -> do + hook <- asks $ pairingHookResponse . svcAttributes + hook $ confirmationNumber $ nonceDigest self peer nonce pnonce + svcSet $ OurRequestConfirm Nothing + replyPacket $ PairingRequestNonce nonce + x@(OurRequest {}, _) -> reject $ uncurry PairingUnexpectedMessage x + + (OurRequestConfirm _, PairingAccept x) -> do + flip catchError (reject . PairingFailedOther) $ do + pairingVerifyResult x >>= \case + Just x' -> do + join $ asks $ pairingHookConfirmedRequest . svcAttributes + svcSet $ OurRequestConfirm (Just x') + Nothing -> do + join $ asks $ pairingHookVerifyFailed . svcAttributes + svcSet NoPairing + replyPacket PairingReject + + x@(OurRequestConfirm _, _) -> reject $ uncurry PairingUnexpectedMessage x + + (OurRequestReady, PairingAccept x) -> do + flip catchError (reject . PairingFailedOther) $ do + pairingVerifyResult x >>= \case + Just x' -> do + pairingFinalizeRequest x' + join $ asks $ pairingHookAcceptedResponse . svcAttributes + svcSet $ PairingDone + Nothing -> do + join $ asks $ pairingHookVerifyFailed . svcAttributes + throwOtherError "" + x@(OurRequestReady, _) -> reject $ uncurry PairingUnexpectedMessage x + + (PeerRequest peer self nonce dgst, PairingRequestNonce pnonce) -> do + if dgst == nonceDigest peer self pnonce BS.empty + then do hook <- asks $ pairingHookRequestNonce . svcAttributes + hook $ confirmationNumber $ nonceDigest peer self pnonce nonce + svcSet PeerRequestConfirm + else do join $ asks $ pairingHookRequestNonceFailed . svcAttributes + svcSet NoPairing + replyPacket PairingReject + x@(PeerRequest {}, _) -> reject $ uncurry PairingUnexpectedMessage x + x@(PeerRequestConfirm, _) -> reject $ uncurry PairingUnexpectedMessage x + +reject :: PairingResult a => PairingFailureReason a -> ServiceHandler (PairingService a) () +reject reason = do + join $ asks $ flip pairingHookFailed reason . svcAttributes + svcSet NoPairing + replyPacket PairingReject + + +nonceDigest :: UnifiedIdentity -> UnifiedIdentity -> ByteString -> ByteString -> RefDigest +nonceDigest idReq idRsp nonceReq nonceRsp = hashToRefDigest $ serializeObject $ Rec + [ (BC.pack "id-req", RecRef $ storedRef $ idData idReq) + , (BC.pack "id-rsp", RecRef $ storedRef $ idData idRsp) + , (BC.pack "nonce-req", RecBinary nonceReq) + , (BC.pack "nonce-rsp", RecBinary nonceRsp) + ] + +confirmationNumber :: RefDigest -> String +confirmationNumber dgst = + case map fromIntegral $ BA.unpack dgst :: [Word32] of + (a:b:c:d:_) -> let str = show $ ((a `shift` 24) .|. (b `shift` 16) .|. (c `shift` 8) .|. d) `mod` (10 ^ len) + in replicate (len - length str) '0' ++ str + _ -> "" + where len = 6 + +pairingRequest :: forall a m e proxy. (PairingResult a, MonadIO m, MonadError e m, FromErebosError e) => proxy a -> Peer -> m () +pairingRequest _ peer = do + self <- liftIO $ serverIdentity $ peerServer peer + nonce <- liftIO $ getRandomBytes 32 + pid <- peerIdentity peer >>= \case + PeerIdentityFull pid -> return pid + _ -> throwOtherError "incomplete peer identity" + sendToPeerWith @(PairingService a) peer $ \case + NoPairing -> return (Just $ PairingRequest (idData self) (idData pid) (nonceDigest self pid nonce BS.empty), OurRequest self pid nonce) + _ -> throwOtherError "already in progress" + +pairingAccept :: forall a m e proxy. (PairingResult a, MonadIO m, MonadError e m, FromErebosError e) => proxy a -> Peer -> m () +pairingAccept _ peer = runPeerService @(PairingService a) peer $ do + svcGet >>= \case + NoPairing -> throwOtherError $ "none in progress" + OurRequest {} -> throwOtherError $ "waiting for peer" + OurRequestConfirm Nothing -> do + join $ asks $ pairingHookConfirmedResponse . svcAttributes + svcSet OurRequestReady + OurRequestConfirm (Just verified) -> do + join $ asks $ pairingHookAcceptedResponse . svcAttributes + pairingFinalizeRequest verified + svcSet PairingDone + OurRequestReady -> throwOtherError $ "already accepted, waiting for peer" + PeerRequest {} -> throwOtherError $ "waiting for peer" + PeerRequestConfirm -> do + join $ asks $ pairingHookAcceptedRequest . svcAttributes + replyPacket . PairingAccept =<< pairingFinalizeResponse + svcSet PairingDone + PairingDone -> throwOtherError $ "already done" + +pairingReject :: forall a m e proxy. (PairingResult a, MonadIO m, MonadError e m, FromErebosError e) => proxy a -> Peer -> m () +pairingReject _ peer = runPeerService @(PairingService a) peer $ do + svcGet >>= \case + NoPairing -> throwOtherError $ "none in progress" + PairingDone -> throwOtherError $ "already done" + _ -> reject PairingUserRejected diff --git a/src/Erebos/PubKey.hs b/src/Erebos/PubKey.hs new file mode 100644 index 0000000..a2ee519 --- /dev/null +++ b/src/Erebos/PubKey.hs @@ -0,0 +1,155 @@ +module Erebos.PubKey ( + PublicKey, SecretKey, + KeyPair(generateKeys), loadKey, loadKeyMb, + Signature(sigKey), Signed, signedData, signedSignature, + sign, signAdd, isSignedBy, + fromSigned, + unsafeMapSigned, + + PublicKexKey, SecretKexKey, + dhSecret, +) where + +import Control.Monad + +import Crypto.Error +import qualified Crypto.PubKey.Ed25519 as ED +import qualified Crypto.PubKey.Curve25519 as CX + +import Data.ByteArray +import Data.ByteString (ByteString) +import qualified Data.Text as T + +import Erebos.Storable +import Erebos.Storage.Key + +data PublicKey = PublicKey ED.PublicKey + deriving (Show) + +data SecretKey = SecretKey ED.SecretKey (Stored PublicKey) + +data Signature = Signature + { sigKey :: Stored PublicKey + , sigSignature :: ED.Signature + } + deriving (Show) + +data Signed a = Signed + { signedData_ :: Stored a + , signedSignature_ :: [Stored Signature] + } + deriving (Show) + +signedData :: Signed a -> Stored a +signedData = signedData_ + +signedSignature :: Signed a -> [Stored Signature] +signedSignature = signedSignature_ + +instance KeyPair SecretKey PublicKey where + keyGetPublic (SecretKey _ pub) = pub + keyGetData (SecretKey sec _) = convert sec + keyFromData kdata spub = do + skey <- maybeCryptoError $ ED.secretKey kdata + let PublicKey pkey = fromStored spub + guard $ ED.toPublic skey == pkey + return $ SecretKey skey spub + generateKeys st = do + secret <- ED.generateSecretKey + public <- wrappedStore st $ PublicKey $ ED.toPublic secret + let pair = SecretKey secret public + storeKey pair + return (pair, public) + +instance Storable PublicKey where + store' (PublicKey pk) = storeRec $ do + storeText "type" $ T.pack "ed25519" + storeBinary "pubkey" pk + + load' = loadRec $ do + ktype <- loadText "type" + guard $ ktype == "ed25519" + maybe (throwOtherError "public key decoding failed") (return . PublicKey) . + maybeCryptoError . (ED.publicKey :: ByteString -> CryptoFailable ED.PublicKey) =<< + loadBinary "pubkey" + +instance Storable Signature where + store' sig = storeRec $ do + storeRef "key" $ sigKey sig + storeBinary "sig" $ sigSignature sig + + load' = loadRec $ Signature + <$> loadRef "key" + <*> loadSignature "sig" + where loadSignature = maybe (throwOtherError "signature decoding failed") return . + maybeCryptoError . (ED.signature :: ByteString -> CryptoFailable ED.Signature) <=< loadBinary + +instance Storable a => Storable (Signed a) where + store' sig = storeRec $ do + storeRef "SDATA" $ signedData sig + mapM_ (storeRef "sig") $ signedSignature sig + + load' = loadRec $ do + sdata <- loadRef "SDATA" + sigs <- loadRefs "sig" + forM_ sigs $ \sig -> do + let PublicKey pubkey = fromStored $ sigKey $ fromStored sig + when (not $ ED.verify pubkey (storedRef sdata) $ sigSignature $ fromStored sig) $ + throwOtherError "signature verification failed" + return $ Signed sdata sigs + +sign :: MonadStorage m => SecretKey -> Stored a -> m (Signed a) +sign secret val = signAdd secret $ Signed val [] + +signAdd :: MonadStorage m => SecretKey -> Signed a -> m (Signed a) +signAdd (SecretKey secret spublic) (Signed val sigs) = do + let PublicKey public = fromStored spublic + sig = ED.sign secret public $ storedRef val + ssig <- mstore $ Signature spublic sig + return $ Signed val (ssig : sigs) + +isSignedBy :: Signed a -> Stored PublicKey -> Bool +isSignedBy sig key = key `elem` map (sigKey . fromStored) (signedSignature sig) + +fromSigned :: Stored (Signed a) -> a +fromSigned = fromStored . signedData . fromStored + +-- |Passed function needs to preserve the object representation to be safe +unsafeMapSigned :: (a -> b) -> Signed a -> Signed b +unsafeMapSigned f signed = signed { signedData_ = unsafeMapStored f (signedData_ signed) } + + +data PublicKexKey = PublicKexKey CX.PublicKey + deriving (Show) + +data SecretKexKey = SecretKexKey CX.SecretKey (Stored PublicKexKey) + +instance KeyPair SecretKexKey PublicKexKey where + keyGetPublic (SecretKexKey _ pub) = pub + keyGetData (SecretKexKey sec _) = convert sec + keyFromData kdata spub = do + skey <- maybeCryptoError $ CX.secretKey kdata + let PublicKexKey pkey = fromStored spub + guard $ CX.toPublic skey == pkey + return $ SecretKexKey skey spub + generateKeys st = do + secret <- CX.generateSecretKey + public <- wrappedStore st $ PublicKexKey $ CX.toPublic secret + let pair = SecretKexKey secret public + storeKey pair + return (pair, public) + +instance Storable PublicKexKey where + store' (PublicKexKey pk) = storeRec $ do + storeText "type" $ T.pack "x25519" + storeBinary "pubkey" pk + + load' = loadRec $ do + ktype <- loadText "type" + guard $ ktype == "x25519" + maybe (throwOtherError "public key decoding failed") (return . PublicKexKey) . + maybeCryptoError . (CX.publicKey :: ScrubbedBytes -> CryptoFailable CX.PublicKey) =<< + loadBinary "pubkey" + +dhSecret :: SecretKexKey -> PublicKexKey -> ScrubbedBytes +dhSecret (SecretKexKey secret _) (PublicKexKey public) = convert $ CX.dh public secret diff --git a/src/Erebos/Service.hs b/src/Erebos/Service.hs new file mode 100644 index 0000000..4499ef9 --- /dev/null +++ b/src/Erebos/Service.hs @@ -0,0 +1,202 @@ +module Erebos.Service ( + Service(..), + SomeService(..), someService, someServiceAttr, someServiceID, + SomeServiceState(..), fromServiceState, someServiceEmptyState, + SomeServiceGlobalState(..), fromServiceGlobalState, someServiceEmptyGlobalState, + SomeStorageWatcher(..), + ServiceID, mkServiceID, + + ServiceHandler, + ServiceInput(..), + ServiceReply(..), + runServiceHandler, + + svcGet, svcSet, svcModify, + svcGetGlobal, svcSetGlobal, svcModifyGlobal, + svcGetLocal, svcSetLocal, + + svcSelf, + svcPrint, + + replyPacket, replyStored, replyStoredRef, + afterCommit, +) where + +import Control.Monad.Except +import Control.Monad.Reader +import Control.Monad.State +import Control.Monad.Writer + +import Data.Kind +import Data.Typeable + +import Erebos.Identity +import {-# SOURCE #-} Erebos.Network +import Erebos.Network.Protocol +import Erebos.State +import Erebos.Storable +import Erebos.Storage.Head +import Erebos.UUID qualified as U + +class ( + Typeable s, Storable s, + Typeable (ServiceAttributes s), + Typeable (ServiceState s), + Typeable (ServiceGlobalState s) + ) => Service s where + + serviceID :: proxy s -> ServiceID + serviceHandler :: Stored s -> ServiceHandler s () + + serviceNewPeer :: ServiceHandler s () + serviceNewPeer = return () + + type ServiceAttributes s = attr | attr -> s + type ServiceAttributes s = Proxy s + defaultServiceAttributes :: proxy s -> ServiceAttributes s + default defaultServiceAttributes :: ServiceAttributes s ~ Proxy s => proxy s -> ServiceAttributes s + defaultServiceAttributes _ = Proxy + + type ServiceState s :: Type + type ServiceState s = () + emptyServiceState :: proxy s -> ServiceState s + default emptyServiceState :: ServiceState s ~ () => proxy s -> ServiceState s + emptyServiceState _ = () + + type ServiceGlobalState s :: Type + type ServiceGlobalState s = () + emptyServiceGlobalState :: proxy s -> ServiceGlobalState s + default emptyServiceGlobalState :: ServiceGlobalState s ~ () => proxy s -> ServiceGlobalState s + emptyServiceGlobalState _ = () + + serviceStorageWatchers :: proxy s -> [SomeStorageWatcher s] + serviceStorageWatchers _ = [] + + serviceStopServer :: proxy s -> Server -> ServiceGlobalState s -> [ ( Peer, ServiceState s ) ] -> IO () + serviceStopServer _ _ _ _ = return () + + +data SomeService = forall s. Service s => SomeService (Proxy s) (ServiceAttributes s) + +someService :: forall s proxy. Service s => proxy s -> SomeService +someService _ = SomeService @s Proxy (defaultServiceAttributes @s Proxy) + +someServiceAttr :: forall s. Service s => ServiceAttributes s -> SomeService +someServiceAttr attr = SomeService @s Proxy attr + +someServiceID :: SomeService -> ServiceID +someServiceID (SomeService s _) = serviceID s + +data SomeServiceState = forall s. Service s => SomeServiceState (Proxy s) (ServiceState s) + +fromServiceState :: Service s => proxy s -> SomeServiceState -> Maybe (ServiceState s) +fromServiceState _ (SomeServiceState _ s) = cast s + +someServiceEmptyState :: SomeService -> SomeServiceState +someServiceEmptyState (SomeService p _) = SomeServiceState p (emptyServiceState p) + +data SomeServiceGlobalState = forall s. Service s => SomeServiceGlobalState (Proxy s) (ServiceGlobalState s) + +fromServiceGlobalState :: Service s => proxy s -> SomeServiceGlobalState -> Maybe (ServiceGlobalState s) +fromServiceGlobalState _ (SomeServiceGlobalState _ s) = cast s + +someServiceEmptyGlobalState :: SomeService -> SomeServiceGlobalState +someServiceEmptyGlobalState (SomeService p _) = SomeServiceGlobalState p (emptyServiceGlobalState p) + + +data SomeStorageWatcher s + = forall a. Eq a => SomeStorageWatcher (Stored LocalState -> a) (a -> ServiceHandler s ()) + | forall a. Eq a => GlobalStorageWatcher (Stored LocalState -> a) (Server -> a -> ExceptT ErebosError IO ()) + + +mkServiceID :: String -> ServiceID +mkServiceID = maybe (error "Invalid service ID") ServiceID . U.fromString + +data ServiceInput s = ServiceInput + { svcAttributes :: ServiceAttributes s + , svcPeer :: Peer + , svcPeerIdentity :: UnifiedIdentity + , svcServer :: Server + , svcPrintOp :: String -> IO () + , svcNewStreams :: [ RawStreamReader ] + } + +data ServiceReply s + = ServiceReply (Either s (Stored s)) Bool + | ServiceOpenStream (RawStreamWriter -> IO ()) + | ServiceFinally (IO ()) + +data ServiceHandlerState s = ServiceHandlerState + { svcValue :: ServiceState s + , svcGlobal :: ServiceGlobalState s + , svcLocal :: Stored LocalState + } + +newtype ServiceHandler s a = ServiceHandler (ReaderT (ServiceInput s) (WriterT [ServiceReply s] (StateT (ServiceHandlerState s) (ExceptT ErebosError IO))) a) + deriving (Functor, Applicative, Monad, MonadReader (ServiceInput s), MonadWriter [ServiceReply s], MonadState (ServiceHandlerState s), MonadError ErebosError, MonadIO) + +instance MonadStorage (ServiceHandler s) where + getStorage = asks $ peerStorage . svcPeer + +instance MonadHead LocalState (ServiceHandler s) where + updateLocalHead f = do + (ls, x) <- f =<< gets svcLocal + modify $ \s -> s { svcLocal = ls } + return x + +runServiceHandler :: Service s => Head LocalState -> ServiceInput s -> ServiceState s -> ServiceGlobalState s -> ServiceHandler s () -> IO ([ServiceReply s], (ServiceState s, ServiceGlobalState s)) +runServiceHandler h input svc global shandler = do + let sstate = ServiceHandlerState { svcValue = svc, svcGlobal = global, svcLocal = headStoredObject h } + ServiceHandler handler = shandler + (runExceptT $ flip runStateT sstate $ execWriterT $ flip runReaderT input $ handler) >>= \case + Left err -> do + svcPrintOp input $ "service failed: " ++ showErebosError err + return ([], (svc, global)) + Right (rsp, sstate') + | svcLocal sstate' == svcLocal sstate -> return (rsp, (svcValue sstate', svcGlobal sstate')) + | otherwise -> replaceHead h (svcLocal sstate') >>= \case + Left (Just h') -> runServiceHandler h' input svc global shandler + _ -> return (rsp, (svcValue sstate', svcGlobal sstate')) + +svcGet :: ServiceHandler s (ServiceState s) +svcGet = gets svcValue + +svcSet :: ServiceState s -> ServiceHandler s () +svcSet x = modify $ \st -> st { svcValue = x } + +svcModify :: (ServiceState s -> ServiceState s) -> ServiceHandler s () +svcModify f = modify $ \st -> st { svcValue = f (svcValue st) } + +svcGetGlobal :: ServiceHandler s (ServiceGlobalState s) +svcGetGlobal = gets svcGlobal + +svcSetGlobal :: ServiceGlobalState s -> ServiceHandler s () +svcSetGlobal x = modify $ \st -> st { svcGlobal = x } + +svcModifyGlobal :: (ServiceGlobalState s -> ServiceGlobalState s) -> ServiceHandler s () +svcModifyGlobal f = modify $ \st -> st { svcGlobal = f (svcGlobal st) } + +svcGetLocal :: ServiceHandler s (Stored LocalState) +svcGetLocal = gets svcLocal + +svcSetLocal :: Stored LocalState -> ServiceHandler s () +svcSetLocal x = modify $ \st -> st { svcLocal = x } + +svcSelf :: ServiceHandler s UnifiedIdentity +svcSelf = maybe (throwOtherError "failed to validate own identity") return . + validateExtendedIdentity . lsIdentity . fromStored =<< svcGetLocal + +svcPrint :: String -> ServiceHandler s () +svcPrint str = afterCommit . ($ str) =<< asks svcPrintOp + +replyPacket :: Service s => s -> ServiceHandler s () +replyPacket x = tell [ServiceReply (Left x) True] + +replyStored :: Service s => Stored s -> ServiceHandler s () +replyStored x = tell [ServiceReply (Right x) True] + +replyStoredRef :: Service s => Stored s -> ServiceHandler s () +replyStoredRef x = tell [ServiceReply (Right x) False] + +afterCommit :: IO () -> ServiceHandler s () +afterCommit x = tell [ServiceFinally x] diff --git a/src/Erebos/Service/Stream.hs b/src/Erebos/Service/Stream.hs new file mode 100644 index 0000000..67df4d7 --- /dev/null +++ b/src/Erebos/Service/Stream.hs @@ -0,0 +1,74 @@ +module Erebos.Service.Stream ( + StreamPacket(..), + StreamReader, getStreamReaderNumber, + StreamWriter, getStreamWriterNumber, + openStream, receivedStreams, + readStreamPacket, writeStreamPacket, + writeStream, + closeStream, +) where + +import Control.Concurrent.MVar +import Control.Monad.Reader +import Control.Monad.Writer + +import Data.ByteString (ByteString) +import Data.Word + +import Erebos.Flow +import Erebos.Network +import Erebos.Network.Protocol +import Erebos.Service + + +data StreamReader = StreamReader RawStreamReader + +getStreamReaderNumber :: StreamReader -> IO Int +getStreamReaderNumber (StreamReader stream) = return $ rsrNum stream + +data StreamWriter = StreamWriter (MVar StreamWriterData) + +data StreamWriterData = StreamWriterData + { swdStream :: RawStreamWriter + , swdSequence :: Maybe Word64 + } + +getStreamWriterNumber :: StreamWriter -> IO Int +getStreamWriterNumber (StreamWriter stream) = rswNum . swdStream <$> readMVar stream + + +openStream :: Service s => ServiceHandler s StreamWriter +openStream = do + mvar <- liftIO newEmptyMVar + tell [ ServiceOpenStream $ \stream -> putMVar mvar $ StreamWriterData stream (Just 0) ] + return $ StreamWriter mvar + +receivedStreams :: Service s => ServiceHandler s [ StreamReader ] +receivedStreams = do + map StreamReader <$> asks svcNewStreams + +readStreamPacket :: StreamReader -> IO StreamPacket +readStreamPacket (StreamReader stream) = do + readFlowIO (rsrFlow stream) + +writeStreamPacket :: StreamWriter -> StreamPacket -> IO () +writeStreamPacket (StreamWriter mvar) packet = do + withMVar mvar $ \swd -> do + writeFlowIO (rswFlow $ swdStream swd) packet + +writeStream :: StreamWriter -> ByteString -> IO () +writeStream (StreamWriter mvar) bytes = do + modifyMVar_ mvar $ \swd -> do + case swdSequence swd of + Just seqNum -> do + writeFlowIO (rswFlow $ swdStream swd) $ StreamData seqNum bytes + return swd { swdSequence = Just (seqNum + 1) } + Nothing -> do + fail "writeStream: stream closed" + +closeStream :: StreamWriter -> IO () +closeStream (StreamWriter mvar) = do + withMVar mvar $ \swd -> do + case swdSequence swd of + Just seqNum -> writeFlowIO (rswFlow $ swdStream swd) $ StreamClosed seqNum + Nothing -> fail "closeStream: stream already closed" diff --git a/src/Erebos/Set.hs b/src/Erebos/Set.hs new file mode 100644 index 0000000..270c0ba --- /dev/null +++ b/src/Erebos/Set.hs @@ -0,0 +1,87 @@ +module Erebos.Set ( + Set, + + emptySet, + loadSet, + storeSetAdd, + storeSetAddComponent, + + fromSetBy, +) where + +import Control.Arrow +import Control.Monad.IO.Class + +import Data.Function +import Data.List +import Data.Map (Map) +import Data.Map qualified as M +import Data.Maybe +import Data.Ord + +import Erebos.Object +import Erebos.Storable +import Erebos.Storage.Merge +import Erebos.Util + +data Set a = Set [Stored (SetItem (Component a))] + deriving (Eq) + +data SetItem a = SetItem + { siPrev :: [Stored (SetItem a)] + , siItem :: [Stored a] + } + +instance Storable a => Storable (SetItem a) where + store' x = storeRec $ do + mapM_ (storeRef "PREV") $ siPrev x + mapM_ (storeRef "item") $ siItem x + + load' = loadRec $ SetItem + <$> loadRefs "PREV" + <*> loadRefs "item" + +instance Mergeable a => Mergeable (Set a) where + type Component (Set a) = SetItem (Component a) + mergeSorted = Set + toComponents (Set items) = items + + +emptySet :: Set a +emptySet = Set [] + +loadSet :: Mergeable a => Ref -> Set a +loadSet = mergeSorted . (:[]) . wrappedLoad + +storeSetAdd :: (Mergeable a, MonadIO m) => Storage -> a -> Set a -> m (Set a) +storeSetAdd st x (Set prev) = Set . (:[]) <$> wrappedStore st SetItem + { siPrev = prev + , siItem = toComponents x + } + +storeSetAddComponent :: (Mergeable a, MonadStorage m, MonadIO m) => Stored (Component a) -> Set a -> m (Set a) +storeSetAddComponent component (Set prev) = Set . (:[]) <$> mstore SetItem + { siPrev = prev + , siItem = [ component ] + } + + +fromSetBy :: forall a. Mergeable a => (a -> a -> Ordering) -> Set a -> [a] +fromSetBy cmp (Set heads) = sortBy cmp $ map merge $ groupRelated items + where + -- gather all item components in the set history + items :: [Stored (Component a)] + items = walkAncestors (siItem . fromStored) heads + + -- map individual roots to full root set as joined in history of individual items + rootToRootSet :: Map RefDigest [RefDigest] + rootToRootSet = foldl' (\m rs -> foldl' (\m' r -> M.insertWith (\a b -> uniq $ sort $ a++b) r rs m') m rs) M.empty $ + map (map (refDigest . storedRef) . storedRoots) items + + -- get full root set for given item component + storedRootSet :: Stored (Component a) -> [RefDigest] + storedRootSet = fromJust . flip M.lookup rootToRootSet . refDigest . storedRef . head . storedRoots + + -- group components of single item, i.e. components sharing some root + groupRelated :: [Stored (Component a)] -> [[Stored (Component a)]] + groupRelated = map (map fst) . groupBy ((==) `on` snd) . sortBy (comparing snd) . map (id &&& storedRootSet) diff --git a/src/Erebos/State.hs b/src/Erebos/State.hs new file mode 100644 index 0000000..7cd82de --- /dev/null +++ b/src/Erebos/State.hs @@ -0,0 +1,157 @@ +module Erebos.State ( + LocalState(..), + SharedState(..), SharedType(..), + SharedTypeID, mkSharedTypeID, + + MonadStorage(..), + MonadHead(..), + updateLocalHead_, + + updateLocalState, updateLocalState_, + updateSharedState, updateSharedState_, + lookupSharedValue, makeSharedStateUpdate, + + localIdentity, + headLocalIdentity, + + mergeSharedIdentity, +) where + +import Control.Monad.Except +import Control.Monad.Reader + +import Data.ByteString (ByteString) +import Data.ByteString.Char8 qualified as BC +import Data.Typeable + +import Erebos.Identity +import Erebos.Object +import Erebos.PubKey +import Erebos.Storable +import Erebos.Storage.Head +import Erebos.Storage.Merge +import Erebos.UUID (UUID) +import Erebos.UUID qualified as U + +data LocalState = LocalState + { lsPrev :: Maybe RefDigest + , lsIdentity :: Stored (Signed ExtendedIdentityData) + , lsShared :: [Stored SharedState] + , lsOther :: [ ( ByteString, RecItem ) ] + } + +data SharedState = SharedState + { ssPrev :: [Stored SharedState] + , ssType :: Maybe SharedTypeID + , ssValue :: [Ref] + } + +newtype SharedTypeID = SharedTypeID UUID + deriving (Eq, Ord, StorableUUID) + +mkSharedTypeID :: String -> SharedTypeID +mkSharedTypeID = maybe (error "Invalid shared type ID") SharedTypeID . U.fromString + +class Mergeable a => SharedType a where + sharedTypeID :: proxy a -> SharedTypeID + +instance Storable LocalState where + store' LocalState {..} = storeRec $ do + mapM_ (storeRawWeak "PREV") lsPrev + storeRef "id" lsIdentity + mapM_ (storeRef "shared") lsShared + storeRecItems lsOther + + load' = loadRec $ do + lsPrev <- loadMbRawWeak "PREV" + lsIdentity <- loadRef "id" + lsShared <- loadRefs "shared" + lsOther <- filter ((`notElem` [ BC.pack "PREV", BC.pack "id", BC.pack "shared" ]) . fst) <$> loadRecItems + return LocalState {..} + +instance HeadType LocalState where + headTypeID _ = mkHeadTypeID "1d7491a9-7bcb-4eaa-8f13-c8c4c4087e4e" + +instance Storable SharedState where + store' st = storeRec $ do + mapM_ (storeRef "PREV") $ ssPrev st + storeMbUUID "type" $ ssType st + mapM_ (storeRawRef "value") $ ssValue st + + load' = loadRec $ SharedState + <$> loadRefs "PREV" + <*> loadMbUUID "type" + <*> loadRawRefs "value" + +instance SharedType (Maybe ComposedIdentity) where + sharedTypeID _ = mkSharedTypeID "0c6c1fe0-f2d7-4891-926b-c332449f7871" + + +class (MonadIO m, MonadStorage m) => MonadHead a m where + updateLocalHead :: (Stored a -> m (Stored a, b)) -> m b + getLocalHead :: m (Stored a) + getLocalHead = updateLocalHead $ \x -> return (x, x) + +updateLocalHead_ :: MonadHead a m => (Stored a -> m (Stored a)) -> m () +updateLocalHead_ f = updateLocalHead (fmap (,()) . f) + +instance (HeadType a, MonadIO m) => MonadHead a (ReaderT (Head a) m) where + updateLocalHead f = do + h <- ask + snd <$> updateHead h f + + +localIdentity :: LocalState -> UnifiedIdentity +localIdentity ls = maybe (error "failed to verify local identity") + (updateOwners $ maybe [] idExtDataF $ lookupSharedValue $ lsShared ls) + (validateExtendedIdentity $ lsIdentity ls) + +headLocalIdentity :: Head LocalState -> UnifiedIdentity +headLocalIdentity = localIdentity . headObject + + +updateLocalState :: forall m b. MonadHead LocalState m => (Stored LocalState -> m ( Stored LocalState, b )) -> m b +updateLocalState f = updateLocalHead $ \ls -> do + ( ls', x ) <- f ls + (, x) <$> if ls' == ls + then return ls' + else mstore (fromStored ls') { lsPrev = Just $ refDigest (storedRef ls) } + +updateLocalState_ :: forall m. MonadHead LocalState m => (Stored LocalState -> m (Stored LocalState)) -> m () +updateLocalState_ f = updateLocalState (fmap (,()) . f) + + +updateSharedState_ :: forall a m. (SharedType a, MonadHead LocalState m) => (a -> m a) -> Stored LocalState -> m (Stored LocalState) +updateSharedState_ f = fmap fst <$> updateSharedState (fmap (,()) . f) + +updateSharedState :: forall a b m. (SharedType a, MonadHead LocalState m) => (a -> m (a, b)) -> Stored LocalState -> m (Stored LocalState, b) +updateSharedState f = \ls -> do + let shared = lsShared $ fromStored ls + val = lookupSharedValue shared + st <- getStorage + (val', x) <- f val + (,x) <$> if toComponents val' == toComponents val + then return ls + else do shared' <- makeSharedStateUpdate st val' shared + wrappedStore st (fromStored ls) { lsShared = [shared'] } + +lookupSharedValue :: forall a. SharedType a => [Stored SharedState] -> a +lookupSharedValue = mergeSorted . filterAncestors . map wrappedLoad . concatMap (ssValue . fromStored) . filterAncestors . helper + where helper (x:xs) | Just sid <- ssType (fromStored x), sid == sharedTypeID @a Proxy = x : helper xs + | otherwise = helper $ ssPrev (fromStored x) ++ xs + helper [] = [] + +makeSharedStateUpdate :: forall a m. MonadIO m => SharedType a => Storage -> a -> [Stored SharedState] -> m (Stored SharedState) +makeSharedStateUpdate st val prev = liftIO $ wrappedStore st SharedState + { ssPrev = prev + , ssType = Just $ sharedTypeID @a Proxy + , ssValue = storedRef <$> toComponents val + } + + +mergeSharedIdentity :: (MonadHead LocalState m, MonadError e m, FromErebosError e) => m UnifiedIdentity +mergeSharedIdentity = updateLocalState $ updateSharedState $ \case + Just cidentity -> do + identity <- mergeIdentity cidentity + return (Just $ toComposedIdentity identity, identity) + Nothing -> throwOtherError "no existing shared identity" diff --git a/src/Erebos/Storable.hs b/src/Erebos/Storable.hs new file mode 100644 index 0000000..caaf525 --- /dev/null +++ b/src/Erebos/Storable.hs @@ -0,0 +1,44 @@ +{-| +Description: Encoding custom types into Erebos objects + +Module provides the 'Storable' class for types that can be serialized to/from +Erebos objects, along with various helpers, mostly for encoding using records. + +The 'Stored' wrapper for objects actually encoded and stored in some storage is +defined here as well. +-} + +module Erebos.Storable ( + Storable(..), ZeroStorable(..), + StorableText(..), StorableDate(..), StorableUUID(..), + + Store, StoreRec, + storeBlob, storeRec, storeZero, + storeEmpty, storeInt, storeNum, storeText, storeBinary, storeDate, storeUUID, storeRef, storeRawRef, storeWeak, storeRawWeak, + storeMbEmpty, storeMbInt, storeMbNum, storeMbText, storeMbBinary, storeMbDate, storeMbUUID, storeMbRef, storeMbRawRef, storeMbWeak, storeMbRawWeak, + storeZRef, storeZWeak, + storeRecItems, + + Load, LoadRec, + loadCurrentRef, loadCurrentObject, + loadRecCurrentRef, loadRecItems, + + loadBlob, loadRec, loadZero, + loadEmpty, loadInt, loadNum, loadText, loadBinary, loadDate, loadUUID, loadRef, loadRawRef, loadRawWeak, + loadMbEmpty, loadMbInt, loadMbNum, loadMbText, loadMbBinary, loadMbDate, loadMbUUID, loadMbRef, loadMbRawRef, loadMbRawWeak, + loadTexts, loadBinaries, loadRefs, loadRawRefs, loadRawWeaks, + loadZRef, + + Stored, + fromStored, storedRef, + wrappedStore, wrappedLoad, + copyStored, + unsafeMapStored, + + Storage, MonadStorage(..), + + module Erebos.Error, +) where + +import Erebos.Error +import Erebos.Object.Internal diff --git a/src/Erebos/Storage.hs b/src/Erebos/Storage.hs new file mode 100644 index 0000000..f1cce84 --- /dev/null +++ b/src/Erebos/Storage.hs @@ -0,0 +1,29 @@ +{-| +Description: Working with storage and heads + +Provides functions for opening 'Storage' backed either by disk or memory. For +conveniance also function for working with 'Head's are reexported here. +-} + +module Erebos.Storage ( + Storage, PartialStorage, + openStorage, memoryStorage, + deriveEphemeralStorage, derivePartialStorage, + + Head, HeadType, + HeadID, HeadTypeID, + headId, headStorage, headRef, headObject, headStoredObject, + loadHeads, loadHead, reloadHead, + storeHead, replaceHead, updateHead, updateHead_, + + WatchedHead, + watchHead, watchHeadWith, unwatchHead, + watchHeadRaw, + + MonadStorage(..), +) where + +import Erebos.Object.Internal +import Erebos.Storage.Disk +import Erebos.Storage.Head +import Erebos.Storage.Memory diff --git a/src/Erebos/Storage/Backend.hs b/src/Erebos/Storage/Backend.hs new file mode 100644 index 0000000..620d423 --- /dev/null +++ b/src/Erebos/Storage/Backend.hs @@ -0,0 +1,28 @@ +{-| +Description: Implement custom storage backend + +Exports type class, which can be used to create custom 'Storage' backend. +-} + +module Erebos.Storage.Backend ( + StorageBackend(..), + Complete, Partial, + Storage, PartialStorage, + newStorage, + + WatchID, startWatchID, nextWatchID, +) where + +import Control.Concurrent.MVar + +import Data.HashTable.IO qualified as HT + +import Erebos.Object.Internal +import Erebos.Storage.Internal + + +newStorage :: StorageBackend bck => bck -> IO (Storage' (BackendCompleteness bck)) +newStorage stBackend = do + stRefGeneration <- newMVar =<< HT.new + stRefRoots <- newMVar =<< HT.new + return Storage {..} diff --git a/src/Erebos/Storage/Disk.hs b/src/Erebos/Storage/Disk.hs new file mode 100644 index 0000000..8e35940 --- /dev/null +++ b/src/Erebos/Storage/Disk.hs @@ -0,0 +1,230 @@ +module Erebos.Storage.Disk ( + openStorage, +) where + +import Codec.Compression.Zlib + +import Control.Arrow +import Control.Concurrent +import Control.Exception +import Control.Monad + +import Data.ByteArray qualified as BA +import Data.ByteString (ByteString) +import Data.ByteString qualified as B +import Data.ByteString.Char8 qualified as BC +import Data.ByteString.Lazy qualified as BL +import Data.ByteString.Lazy.Char8 qualified as BLC +import Data.Function +import Data.List +import Data.Maybe + +import System.Directory +import System.FSNotify +import System.FilePath +import System.IO +import System.IO.Error + +import Erebos.Object +import Erebos.Storage.Backend +import Erebos.Storage.Head +import Erebos.Storage.Internal +import Erebos.Storage.Platform +import Erebos.UUID qualified as U + + +data DiskStorage = StorageDir + { dirPath :: FilePath + , dirWatchers :: MVar ( Maybe WatchManager, [ HeadTypeID ], WatchList ) + } + +instance Eq DiskStorage where + (==) = (==) `on` dirPath + +instance Show DiskStorage where + show StorageDir { dirPath = path } = "dir:" ++ path + +instance StorageBackend DiskStorage where + backendLoadBytes StorageDir {..} dgst = + handleJust (guard . isDoesNotExistError) (const $ return Nothing) $ + Just . decompress . BL.fromChunks . (:[]) <$> (B.readFile $ refPath dirPath dgst) + backendStoreBytes StorageDir {..} dgst = writeFileOnce (refPath dirPath dgst) . compress + + + backendLoadHeads StorageDir {..} tid = do + let hpath = headTypePath dirPath tid + + files <- filterM (doesFileExist . (hpath </>)) =<< + handleJust (\e -> guard (isDoesNotExistError e)) (const $ return []) + (getDirectoryContents hpath) + fmap catMaybes $ forM files $ \hname -> do + case U.fromString hname of + Just hid -> do + content <- B.readFile (hpath </> hname) + return $ do + (h : _) <- Just (BC.lines content) + dgst <- readRefDigest h + Just $ ( HeadID hid, dgst ) + Nothing -> return Nothing + + backendLoadHead StorageDir {..} tid hid = do + handleJust (guard . isDoesNotExistError) (const $ return Nothing) $ do + (h:_) <- BC.lines <$> B.readFile (headPath dirPath tid hid) + return $ readRefDigest h + + backendStoreHead StorageDir {..} tid hid dgst = do + Right () <- writeFileChecked (headPath dirPath tid hid) Nothing $ + showRefDigest dgst `B.append` BC.singleton '\n' + return () + + backendReplaceHead StorageDir {..} tid hid expected new = do + let filename = headPath dirPath tid hid + showDgstL r = showRefDigest r `B.append` BC.singleton '\n' + + writeFileChecked filename (Just $ showDgstL expected) (showDgstL new) >>= \case + Left Nothing -> return $ Left Nothing + Left (Just bs) -> do Just cur <- return $ readRefDigest $ BC.takeWhile (/='\n') bs + return $ Left $ Just cur + Right () -> return $ Right new + + backendWatchHead st@StorageDir {..} tid hid cb = do + modifyMVar dirWatchers $ \( mbmanager, ilist, wl ) -> do + manager <- maybe startManager return mbmanager + ilist' <- case tid `elem` ilist of + True -> return ilist + False -> do + void $ watchDir manager (headTypePath dirPath tid) (const True) $ \case + ev@Added {} | Just ihid <- HeadID <$> U.fromString (takeFileName (eventPath ev)) -> do + backendLoadHead st tid ihid >>= \case + Just dgst -> do + (_, _, iwl) <- readMVar dirWatchers + mapM_ ($ dgst) . map wlFun . filter ((== (tid, ihid)) . wlHead) . wlList $ iwl + Nothing -> return () + _ -> return () + return $ tid : ilist + return $ first ( Just manager, ilist', ) $ watchListAdd tid hid cb wl + + backendUnwatchHead StorageDir {..} wid = do + modifyMVar_ dirWatchers $ \( mbmanager, ilist, wl ) -> do + return ( mbmanager, ilist, watchListDel wid wl ) + + + backendListKeys StorageDir {..} = do + catMaybes . map (readRefDigest . BC.pack) <$> + listDirectory (keyDirPath dirPath) + + backendLoadKey StorageDir {..} dgst = do + tryIOError (BC.readFile (keyFilePath dirPath dgst)) >>= \case + Right kdata -> return $ Just $ BA.convert kdata + Left _ -> return Nothing + + backendStoreKey StorageDir {..} dgst key = do + writeFileOnce (keyFilePath dirPath dgst) (BL.fromStrict $ BA.convert key) + + backendRemoveKey StorageDir {..} dgst = do + void $ tryIOError (removeFile $ keyFilePath dirPath dgst) + + +storageVersion :: String +storageVersion = "0.1" + +openStorage :: FilePath -> IO Storage +openStorage path = modifyIOError annotate $ do + let versionFileName = "erebos-storage" + let versionPath = path </> versionFileName + let writeVersionFile = writeFileOnce versionPath $ BLC.pack $ storageVersion <> "\n" + + maybeVersion <- handleJust (guard . isDoesNotExistError) (const $ return Nothing) $ + Just <$> readFile versionPath + version <- case maybeVersion of + Just versionContent -> do + return $ takeWhile (/= '\n') versionContent + + Nothing -> do + files <- handleJust (guard . isDoesNotExistError) (const $ return []) $ + listDirectory path + when (not $ or + [ null files + , versionFileName `elem` files + , (versionFileName ++ ".lock") `elem` files + , "objects" `elem` files && "heads" `elem` files + ]) $ do + fail "directory is neither empty, nor an existing erebos storage" + + createDirectoryIfMissing True $ path + writeVersionFile + takeWhile (/= '\n') <$> readFile versionPath + + when (version /= storageVersion) $ do + fail $ "unsupported storage version " <> version + + createDirectoryIfMissing True $ path </> "objects" + createDirectoryIfMissing True $ path </> "heads" + watchers <- newMVar ( Nothing, [], WatchList startWatchID [] ) + newStorage $ StorageDir path watchers + where + annotate e = annotateIOError e "failed to open storage" Nothing (Just path) + + +refPath :: FilePath -> RefDigest -> FilePath +refPath spath rdgst = intercalate "/" [ spath, "objects", BC.unpack alg, pref, rest ] + where (alg, dgst) = showRefDigestParts rdgst + (pref, rest) = splitAt 2 $ BC.unpack dgst + +headTypePath :: FilePath -> HeadTypeID -> FilePath +headTypePath spath (HeadTypeID tid) = spath </> "heads" </> U.toString tid + +headPath :: FilePath -> HeadTypeID -> HeadID -> FilePath +headPath spath tid (HeadID hid) = headTypePath spath tid </> U.toString hid + +keyDirPath :: FilePath -> FilePath +keyDirPath sdir = sdir </> "keys" + +keyFilePath :: FilePath -> RefDigest -> FilePath +keyFilePath sdir dgst = keyDirPath sdir </> (BC.unpack $ showRefDigest dgst) + + +openLockFile :: FilePath -> IO Handle +openLockFile path = do + createDirectoryIfMissing True (takeDirectory path) + retry 10 $ createFileExclusive path + where + retry :: Int -> IO a -> IO a + retry 0 act = act + retry n act = catchJust (\e -> if isAlreadyExistsError e then Just () else Nothing) + act (\_ -> threadDelay (100 * 1000) >> retry (n - 1) act) + +writeFileOnce :: FilePath -> BL.ByteString -> IO () +writeFileOnce file content = bracket (openLockFile locked) + hClose $ \h -> do + doesFileExist file >>= \case + True -> removeFile locked + False -> do BL.hPut h content + hClose h + renameFile locked file + where locked = file ++ ".lock" + +writeFileChecked :: FilePath -> Maybe ByteString -> ByteString -> IO (Either (Maybe ByteString) ()) +writeFileChecked file prev content = bracket (openLockFile locked) + hClose $ \h -> do + (prev,) <$> doesFileExist file >>= \case + (Nothing, True) -> do + current <- B.readFile file + removeFile locked + return $ Left $ Just current + (Nothing, False) -> do B.hPut h content + hClose h + renameFile locked file + return $ Right () + (Just expected, True) -> do + current <- B.readFile file + if current == expected then do B.hPut h content + hClose h + renameFile locked file + return $ return () + else do removeFile locked + return $ Left $ Just current + (Just _, False) -> do + removeFile locked + return $ Left Nothing + where locked = file ++ ".lock" diff --git a/src/Erebos/Storage/Head.hs b/src/Erebos/Storage/Head.hs new file mode 100644 index 0000000..285902d --- /dev/null +++ b/src/Erebos/Storage/Head.hs @@ -0,0 +1,258 @@ +{-| +Description: Define, use and watch heads + +Provides data types and functions for reading, writing or watching `Head's. +Type class `HeadType' is used to define custom new `Head' types. +-} + +module Erebos.Storage.Head ( + -- * Head type and accessors + Head, HeadType(..), + HeadID, HeadTypeID, mkHeadTypeID, + headId, headStorage, headRef, headObject, headStoredObject, + + -- * Loading and storing heads + loadHeads, loadHead, reloadHead, + storeHead, replaceHead, updateHead, updateHead_, + loadHeadRaw, storeHeadRaw, replaceHeadRaw, + + -- * Watching heads + WatchedHead, + watchHead, watchHeadWith, unwatchHead, + watchHeadRaw, +) where + +import Control.Concurrent +import Control.Monad +import Control.Monad.Reader + +import Data.Bifunctor +import Data.Typeable + +import Erebos.Object +import Erebos.Storable +import Erebos.Storage.Backend +import Erebos.Storage.Internal +import Erebos.UUID qualified as U + + +-- | Represents loaded Erebos storage head, along with the object it pointed to +-- at the time it was loaded. +-- +-- Each possible head type has associated unique ID, represented as +-- `HeadTypeID'. For each type, there can be multiple individual heads in given +-- storage, each also identified by unique ID (`HeadID'). +data Head a = Head HeadID (Stored a) + deriving (Eq, Show) + +-- | Instances of this class can be used as objects pointed to by heads in +-- Erebos storage. Each such type must be `Storable' and have a unique ID. +-- +-- To create a custom head type, generate a new UUID and assign it to the type using +-- `mkHeadTypeID': +-- +-- > instance HeadType MyType where +-- > headTypeID _ = mkHeadTypeID "86e8033d-c476-4f81-9b7c-fd36b9144475" +class Storable a => HeadType a where + headTypeID :: proxy a -> HeadTypeID + -- ^ Get the ID of the given head type; must be unique for each `HeadType' instance. + +instance MonadIO m => MonadStorage (ReaderT (Head a) m) where + getStorage = asks $ headStorage + + +-- | Get `HeadID' associated with given `Head'. +headId :: Head a -> HeadID +headId (Head uuid _) = uuid + +-- | Get storage from which the `Head' was loaded. +headStorage :: Head a -> Storage +headStorage = refStorage . headRef + +-- | Get `Ref' of the `Head'\'s associated object. +headRef :: Head a -> Ref +headRef (Head _ sx) = storedRef sx + +-- | Get the object the `Head' pointed to when it was loaded. +headObject :: Head a -> a +headObject (Head _ sx) = fromStored sx + +-- | Get the object the `Head' pointed to when it was loaded as a `Stored' value. +headStoredObject :: Head a -> Stored a +headStoredObject (Head _ sx) = sx + +-- | Create `HeadTypeID' from string representation of UUID. +mkHeadTypeID :: String -> HeadTypeID +mkHeadTypeID = maybe (error "Invalid head type ID") HeadTypeID . U.fromString + + +-- | Load all `Head's of type @a@ from storage. +loadHeads :: forall a m. MonadIO m => HeadType a => Storage -> m [Head a] +loadHeads st@Storage {..} = + map (uncurry Head . fmap (wrappedLoad . Ref st)) + <$> liftIO (backendLoadHeads stBackend (headTypeID @a Proxy)) + +-- | Try to load a `Head' of type @a@ from storage. +loadHead + :: forall a m. (HeadType a, MonadIO m) + => Storage -- ^ Storage from which to load the head + -> HeadID -- ^ ID of the particular head + -> m (Maybe (Head a)) -- ^ Head object, or `Nothing' if not found +loadHead st hid = fmap (Head hid . wrappedLoad) <$> loadHeadRaw st (headTypeID @a Proxy) hid + +-- | Try to load `Head' using a raw head and type IDs, getting `Ref' if found. +loadHeadRaw + :: forall m. MonadIO m + => Storage -- ^ Storage from which to load the head + -> HeadTypeID -- ^ ID of the head type + -> HeadID -- ^ ID of the particular head + -> m (Maybe Ref) -- ^ `Ref' pointing to the head object, or `Nothing' if not found +loadHeadRaw st@Storage {..} tid hid = do + fmap (Ref st) <$> liftIO (backendLoadHead stBackend tid hid) + +-- | Reload the given head from storage, returning `Head' with updated object, +-- or `Nothing' if there is no longer head with the particular ID in storage. +reloadHead :: (HeadType a, MonadIO m) => Head a -> m (Maybe (Head a)) +reloadHead (Head hid val) = loadHead (storedStorage val) hid + +-- | Store a new `Head' of type 'a' in the storage. +storeHead :: forall a m. MonadIO m => HeadType a => Storage -> a -> m (Head a) +storeHead st obj = do + let tid = headTypeID @a Proxy + stored <- wrappedStore st obj + hid <- storeHeadRaw st tid (storedRef stored) + return $ Head hid stored + +-- | Store a new `Head' in the storage, using the raw `HeadTypeID' and `Ref', +-- the function returns the assigned `HeadID' of the new head. +storeHeadRaw :: forall m. MonadIO m => Storage -> HeadTypeID -> Ref -> m HeadID +storeHeadRaw Storage {..} tid ref = liftIO $ do + hid <- HeadID <$> U.nextRandom + backendStoreHead stBackend tid hid (refDigest ref) + return hid + +-- | Try to replace existing `Head' of type @a@ in the storage. Function fails +-- if the head value in storage changed after being loaded here; for automatic +-- retry see `updateHead'. +replaceHead + :: forall a m. (HeadType a, MonadIO m) + => Head a -- ^ Existing head, associated object is supposed to match the one in storage + -> Stored a -- ^ Intended new value + -> m (Either (Maybe (Head a)) (Head a)) + -- ^ + -- [@`Left' `Nothing'@]: + -- Nothing was stored – the head no longer exists in storage. + -- [@`Left' (`Just' h)@]: + -- Nothing was stored – the head value in storage does not match + -- the first parameter, but is @h@ instead. + -- [@`Right' h@]: + -- Head value was updated in storage, the new head is @h@ (which is + -- the same as first parameter with associated object replaced by + -- the second parameter). +replaceHead prev@(Head hid pobj) stored' = liftIO $ do + let st = headStorage prev + tid = headTypeID @a Proxy + stored <- copyStored st stored' + bimap (fmap $ Head hid . wrappedLoad) (const $ Head hid stored) <$> + replaceHeadRaw st tid hid (storedRef pobj) (storedRef stored) + +-- | Try to replace existing head using raw IDs and `Ref's. +replaceHeadRaw + :: forall m. MonadIO m + => Storage -- ^ Storage to use + -> HeadTypeID -- ^ ID of the head type + -> HeadID -- ^ ID of the particular head + -> Ref -- ^ Expected value in storage + -> Ref -- ^ Intended new value + -> m (Either (Maybe Ref) Ref) + -- ^ + -- [@`Left' `Nothing'@]: + -- Nothing was stored – the head no longer exists in storage. + -- [@`Left' (`Just' r)@]: + -- Nothing was stored – the head value in storage does not match + -- the expected value, but is @r@ instead. + -- [@`Right' r@]: + -- Head value was updated in storage, the new head value is @r@ + -- (which is the same as the indended value). +replaceHeadRaw st@Storage {..} tid hid prev new = liftIO $ do + _ <- copyRef st new + bimap (fmap $ Ref st) (Ref st) <$> backendReplaceHead stBackend tid hid (refDigest prev) (refDigest new) + +-- | Update existing existing `Head' of type @a@ in the storage, using a given +-- function. The update function may be called multiple times in case the head +-- content changes concurrently during evaluation. +updateHead + :: (HeadType a, MonadIO m) + => Head a -- ^ Existing head to be updated + -> (Stored a -> m ( Stored a, b )) + -- ^ Function that gets current value of the head and returns updated + -- value, along with a custom extra value to be returned from + -- `updateHead' call. The function may be called multiple times. + -> m ( Maybe (Head a), b ) + -- ^ First element contains either the new head as @`Just' h@, or + -- `Nothing' in case the head no longer exists in storage. Second + -- element is the value from last call to the update function. +updateHead h f = do + (o, x) <- f $ headStoredObject h + replaceHead h o >>= \case + Right h' -> return (Just h', x) + Left Nothing -> return (Nothing, x) + Left (Just h') -> updateHead h' f + +-- | Update existing existing `Head' of type @a@ in the storage, using a given +-- function. The update function may be called multiple times in case the head +-- content changes concurrently during evaluation. +updateHead_ + :: (HeadType a, MonadIO m) + => Head a -- ^ Existing head to be updated + -> (Stored a -> m (Stored a)) + -- ^ Function that gets current value of the head and returns updated + -- value; may be called multiple times. + -> m (Maybe (Head a)) + -- ^ The new head as @`Just' h@, or `Nothing' in case the head no + -- longer exists in storage. +updateHead_ h = fmap fst . updateHead h . (fmap (,()) .) + + +-- | Represents a handle of a watched head, which can be used to cancel the +-- watching. +data WatchedHead = forall a. WatchedHead Storage WatchID (MVar a) + +-- | Watch the given head. The callback will be called with the current head +-- value, and then again each time the head changes. +watchHead :: forall a. HeadType a => Head a -> (Head a -> IO ()) -> IO WatchedHead +watchHead h = watchHeadWith h id + +-- | Watch the given head using custom selector function. The callback will be +-- called with the value derived from current head state, and then again each +-- time the selected value changes according to its `Eq' instance. +watchHeadWith + :: forall a b. (HeadType a, Eq b) + => Head a -- ^ Head to watch + -> (Head a -> b) -- ^ Selector function + -> (b -> IO ()) -- ^ Callback + -> IO WatchedHead -- ^ Watched head handle +watchHeadWith (Head hid val) sel cb = do + watchHeadRaw (storedStorage val) (headTypeID @a Proxy) hid (sel . Head hid . wrappedLoad) cb + +-- | Watch the given head using raw IDs and a selector from `Ref'. +watchHeadRaw :: forall b. Eq b => Storage -> HeadTypeID -> HeadID -> (Ref -> b) -> (b -> IO ()) -> IO WatchedHead +watchHeadRaw st@Storage {..} tid hid sel cb = do + memo <- newEmptyMVar + let cb' dgst = do + let x = sel (Ref st dgst) + modifyMVar_ memo $ \prev -> do + when (Just x /= prev) $ cb x + return $ Just x + wid <- backendWatchHead stBackend tid hid cb' + + cur <- fmap sel <$> loadHeadRaw st tid hid + maybe (return ()) cb cur + putMVar memo cur + + return $ WatchedHead st wid memo + +-- | Stop watching previously watched head. +unwatchHead :: WatchedHead -> IO () +unwatchHead (WatchedHead Storage {..} wid _) = do + backendUnwatchHead stBackend wid diff --git a/src/Erebos/Storage/Internal.hs b/src/Erebos/Storage/Internal.hs new file mode 100644 index 0000000..db211bb --- /dev/null +++ b/src/Erebos/Storage/Internal.hs @@ -0,0 +1,291 @@ +module Erebos.Storage.Internal ( + Storage'(..), Storage, PartialStorage, + Ref'(..), Ref, PartialRef, + RefDigest(..), + WatchID, startWatchID, nextWatchID, + WatchList(..), WatchListItem(..), watchListAdd, watchListDel, + + refStorage, + refDigest, refDigestFromByteString, + showRef, showRefDigest, showRefDigestParts, + readRefDigest, + hashToRefDigest, + + StorageCompleteness(..), + StorageBackend(..), + Complete, Partial, + + unsafeStoreRawBytes, + ioLoadBytesFromStorage, + + Generation(..), + HeadID(..), HeadTypeID(..), + Stored(..), storedStorage, +) where + +import Control.Arrow +import Control.Concurrent +import Control.DeepSeq +import Control.Exception +import Control.Monad.Identity + +import Crypto.Hash + +import Data.Bits +import Data.ByteArray (ByteArrayAccess, ScrubbedBytes) +import Data.ByteArray qualified as BA +import Data.ByteString (ByteString) +import Data.ByteString.Char8 qualified as BC +import Data.ByteString.Lazy qualified as BL +import Data.Function +import Data.HashTable.IO qualified as HT +import Data.Hashable +import Data.Kind +import Data.Typeable + +import Foreign.Storable (peek) + +import System.IO.Unsafe (unsafePerformIO) + +import Erebos.UUID (UUID) +import Erebos.Util + + +data Storage' c = forall bck. (StorageBackend bck, BackendCompleteness bck ~ c) => Storage + { stBackend :: bck + , stRefGeneration :: MVar (HT.BasicHashTable RefDigest Generation) + , stRefRoots :: MVar (HT.BasicHashTable RefDigest [RefDigest]) + } + +type Storage = Storage' Complete +type PartialStorage = Storage' Partial + +instance Eq (Storage' c) where + Storage { stBackend = b } == Storage { stBackend = b' } + | Just b'' <- cast b' = b == b'' + | otherwise = False + +instance Show (Storage' c) where + show Storage { stBackend = b } = show b ++ showParentStorage b + +showParentStorage :: StorageBackend bck => bck -> String +showParentStorage bck + | Just (st :: Storage) <- cast (backendParent bck) = "@" ++ show st + | Just (st :: PartialStorage) <- cast (backendParent bck) = "@" ++ show st + | otherwise = "" + + +class (Eq bck, Show bck, Typeable bck, Typeable (BackendParent bck)) => StorageBackend bck where + type BackendCompleteness bck :: Type -> Type + type BackendCompleteness bck = Complete + + type BackendParent bck :: Type + type BackendParent bck = () + backendParent :: bck -> BackendParent bck + default backendParent :: BackendParent bck ~ () => bck -> BackendParent bck + backendParent _ = () + + + backendLoadBytes :: bck -> RefDigest -> IO (Maybe BL.ByteString) + default backendLoadBytes :: BackendParent bck ~ Storage => bck -> RefDigest -> IO (Maybe BL.ByteString) + backendLoadBytes bck = case backendParent bck of Storage { stBackend = bck' } -> backendLoadBytes bck' + + backendStoreBytes :: bck -> RefDigest -> BL.ByteString -> IO () + default backendStoreBytes :: BackendParent bck ~ Storage => bck -> RefDigest -> BL.ByteString -> IO () + backendStoreBytes bck = case backendParent bck of Storage { stBackend = bck' } -> backendStoreBytes bck' + + + backendLoadHeads :: bck -> HeadTypeID -> IO [ ( HeadID, RefDigest ) ] + default backendLoadHeads :: BackendParent bck ~ Storage => bck -> HeadTypeID -> IO [ ( HeadID, RefDigest ) ] + backendLoadHeads bck = case backendParent bck of Storage { stBackend = bck' } -> backendLoadHeads bck' + + backendLoadHead :: bck -> HeadTypeID -> HeadID -> IO (Maybe RefDigest) + default backendLoadHead :: BackendParent bck ~ Storage => bck -> HeadTypeID -> HeadID -> IO (Maybe RefDigest) + backendLoadHead bck = case backendParent bck of Storage { stBackend = bck' } -> backendLoadHead bck' + + backendStoreHead :: bck -> HeadTypeID -> HeadID -> RefDigest -> IO () + default backendStoreHead :: BackendParent bck ~ Storage => bck -> HeadTypeID -> HeadID -> RefDigest -> IO () + backendStoreHead bck = case backendParent bck of Storage { stBackend = bck' } -> backendStoreHead bck' + + backendReplaceHead :: bck -> HeadTypeID -> HeadID -> RefDigest -> RefDigest -> IO (Either (Maybe RefDigest) RefDigest) + default backendReplaceHead :: BackendParent bck ~ Storage => bck -> HeadTypeID -> HeadID -> RefDigest -> RefDigest -> IO (Either (Maybe RefDigest) RefDigest) + backendReplaceHead bck = case backendParent bck of Storage { stBackend = bck' } -> backendReplaceHead bck' + + backendWatchHead :: bck -> HeadTypeID -> HeadID -> (RefDigest -> IO ()) -> IO WatchID + default backendWatchHead :: BackendParent bck ~ Storage => bck -> HeadTypeID -> HeadID -> (RefDigest -> IO ()) -> IO WatchID + backendWatchHead bck = case backendParent bck of Storage { stBackend = bck' } -> backendWatchHead bck' + + backendUnwatchHead :: bck -> WatchID -> IO () + default backendUnwatchHead :: BackendParent bck ~ Storage => bck -> WatchID -> IO () + backendUnwatchHead bck = case backendParent bck of Storage { stBackend = bck' } -> backendUnwatchHead bck' + + + backendListKeys :: bck -> IO [ RefDigest ] + default backendListKeys :: BackendParent bck ~ Storage => bck -> IO [ RefDigest ] + backendListKeys bck = case backendParent bck of Storage { stBackend = bck' } -> backendListKeys bck' + + backendLoadKey :: bck -> RefDigest -> IO (Maybe ScrubbedBytes) + default backendLoadKey :: BackendParent bck ~ Storage => bck -> RefDigest -> IO (Maybe ScrubbedBytes) + backendLoadKey bck = case backendParent bck of Storage { stBackend = bck' } -> backendLoadKey bck' + + backendStoreKey :: bck -> RefDigest -> ScrubbedBytes -> IO () + default backendStoreKey :: BackendParent bck ~ Storage => bck -> RefDigest -> ScrubbedBytes -> IO () + backendStoreKey bck = case backendParent bck of Storage { stBackend = bck' } -> backendStoreKey bck' + + backendRemoveKey :: bck -> RefDigest -> IO () + default backendRemoveKey :: BackendParent bck ~ Storage => bck -> RefDigest -> IO () + backendRemoveKey bck = case backendParent bck of Storage { stBackend = bck' } -> backendRemoveKey bck' + + + +newtype WatchID = WatchID Int + deriving (Eq, Ord) + +startWatchID :: WatchID +startWatchID = WatchID 1 + +nextWatchID :: WatchID -> WatchID +nextWatchID (WatchID n) = WatchID (n + 1) + +data WatchList = WatchList + { wlNext :: WatchID + , wlList :: [ WatchListItem ] + } + +data WatchListItem = WatchListItem + { wlID :: WatchID + , wlHead :: ( HeadTypeID, HeadID ) + , wlFun :: RefDigest -> IO () + } + +watchListAdd :: HeadTypeID -> HeadID -> (RefDigest -> IO ()) -> WatchList -> ( WatchList, WatchID ) +watchListAdd tid hid cb wl = ( wl', wlNext wl ) + where + wl' = wl + { wlNext = nextWatchID (wlNext wl) + , wlList = WatchListItem + { wlID = wlNext wl + , wlHead = (tid, hid) + , wlFun = cb + } : wlList wl + } + +watchListDel :: WatchID -> WatchList -> WatchList +watchListDel wid wl = wl { wlList = filter ((/= wid) . wlID) $ wlList wl } + + +newtype RefDigest = RefDigest (Digest Blake2b_256) + deriving (Eq, Ord, NFData, ByteArrayAccess) + +instance Show RefDigest where + show = BC.unpack . showRefDigest + +data Ref' c = Ref (Storage' c) RefDigest + +type Ref = Ref' Complete +type PartialRef = Ref' Partial + +instance Eq (Ref' c) where + Ref _ d1 == Ref _ d2 = d1 == d2 + +instance Show (Ref' c) where + show ref@(Ref st _) = show st ++ ":" ++ BC.unpack (showRef ref) + +instance ByteArrayAccess (Ref' c) where + length (Ref _ dgst) = BA.length dgst + withByteArray (Ref _ dgst) = BA.withByteArray dgst + +instance Hashable RefDigest where + hashWithSalt salt ref = salt `xor` unsafePerformIO (BA.withByteArray ref peek) + +instance Hashable (Ref' c) where + hashWithSalt salt ref = salt `xor` unsafePerformIO (BA.withByteArray ref peek) + +refStorage :: Ref' c -> Storage' c +refStorage (Ref st _) = st + +refDigest :: Ref' c -> RefDigest +refDigest (Ref _ dgst) = dgst + +showRef :: Ref' c -> ByteString +showRef = showRefDigest . refDigest + +showRefDigestParts :: RefDigest -> (ByteString, ByteString) +showRefDigestParts x = (BC.pack "blake2", showHex x) + +showRefDigest :: RefDigest -> ByteString +showRefDigest = showRefDigestParts >>> \(alg, hex) -> alg <> BC.pack "#" <> hex + +readRefDigest :: ByteString -> Maybe RefDigest +readRefDigest x = case BC.split '#' x of + [alg, dgst] | BA.convert alg == BC.pack "blake2" -> + refDigestFromByteString =<< readHex dgst + _ -> Nothing + +refDigestFromByteString :: ByteString -> Maybe RefDigest +refDigestFromByteString = fmap RefDigest . digestFromByteString + +hashToRefDigest :: BL.ByteString -> RefDigest +hashToRefDigest = RefDigest . hashFinalize . hashUpdates hashInit . BL.toChunks + + +newtype Generation = Generation Int + deriving (Eq, Show) + +-- | UUID of individual Erebos storage head. +newtype HeadID = HeadID UUID + deriving (Eq, Ord, Show) + +-- | UUID of Erebos storage head type. +newtype HeadTypeID = HeadTypeID UUID + deriving (Eq, Ord) + +data Stored a = Stored + { storedRef' :: Ref + , storedObject' :: a + } + deriving (Show) + +instance Eq (Stored a) where + (==) = (==) `on` (refDigest . storedRef') + +instance Ord (Stored a) where + compare = compare `on` (refDigest . storedRef') + +storedStorage :: Stored a -> Storage +storedStorage = refStorage . storedRef' + + +type Complete = Identity +type Partial = Either RefDigest + +class (Traversable compl, Monad compl, Typeable compl) => StorageCompleteness compl where + type LoadResult compl a :: Type + returnLoadResult :: compl a -> LoadResult compl a + ioLoadBytes :: Ref' compl -> IO (compl BL.ByteString) + +instance StorageCompleteness Complete where + type LoadResult Complete a = a + returnLoadResult = runIdentity + ioLoadBytes ref@(Ref st dgst) = maybe (error $ "Ref not found in complete storage: "++show ref) Identity + <$> ioLoadBytesFromStorage st dgst + +instance StorageCompleteness Partial where + type LoadResult Partial a = Either RefDigest a + returnLoadResult = id + ioLoadBytes (Ref st dgst) = maybe (Left dgst) Right <$> ioLoadBytesFromStorage st dgst + +unsafeStoreRawBytes :: Storage' c -> BL.ByteString -> IO (Ref' c) +unsafeStoreRawBytes st@Storage {..} raw = do + dgst <- evaluate $ force $ hashToRefDigest raw + backendStoreBytes stBackend dgst raw + return $ Ref st dgst + +ioLoadBytesFromStorage :: Storage' c -> RefDigest -> IO (Maybe BL.ByteString) +ioLoadBytesFromStorage Storage {..} dgst = + backendLoadBytes stBackend dgst >>= \case + Just bytes -> return $ Just bytes + Nothing + | Just (parent :: Storage) <- cast (backendParent stBackend) -> ioLoadBytesFromStorage parent dgst + | Just (parent :: PartialStorage) <- cast (backendParent stBackend) -> ioLoadBytesFromStorage parent dgst + | otherwise -> return Nothing diff --git a/src/Erebos/Storage/Key.hs b/src/Erebos/Storage/Key.hs new file mode 100644 index 0000000..b615f16 --- /dev/null +++ b/src/Erebos/Storage/Key.hs @@ -0,0 +1,52 @@ +module Erebos.Storage.Key ( + KeyPair(..), + storeKey, loadKey, loadKeyMb, + moveKeys, +) where + +import Control.Monad +import Control.Monad.Except +import Control.Monad.IO.Class + +import Data.ByteArray +import Data.Typeable + +import Erebos.Storable +import Erebos.Storage.Internal + +class Storable pub => KeyPair sec pub | sec -> pub, pub -> sec where + generateKeys :: Storage -> IO (sec, Stored pub) + keyGetPublic :: sec -> Stored pub + keyGetData :: sec -> ScrubbedBytes + keyFromData :: ScrubbedBytes -> Stored pub -> Maybe sec + + +storeKey :: KeyPair sec pub => sec -> IO () +storeKey key = do + let spub = keyGetPublic key + case storedStorage spub of + Storage {..} -> backendStoreKey stBackend (refDigest $ storedRef spub) (keyGetData key) + +loadKey :: (KeyPair sec pub, MonadIO m, MonadError e m, FromErebosError e) => Stored pub -> m sec +loadKey pub = maybe (throwOtherError $ "secret key not found for " <> show (storedRef pub)) return =<< loadKeyMb pub + +loadKeyMb :: forall sec pub m. (KeyPair sec pub, MonadIO m) => Stored pub -> m (Maybe sec) +loadKeyMb spub = liftIO $ run $ storedStorage spub + where + run :: Storage' c -> IO (Maybe sec) + run Storage {..} = backendLoadKey stBackend (refDigest $ storedRef spub) >>= \case + Just bytes -> return $ keyFromData bytes spub + Nothing + | Just (parent :: Storage) <- cast (backendParent stBackend) -> run parent + | Just (parent :: PartialStorage) <- cast (backendParent stBackend) -> run parent + | otherwise -> return Nothing + +moveKeys :: MonadIO m => Storage -> Storage -> m () +moveKeys Storage { stBackend = from } Storage { stBackend = to } = liftIO $ do + keys <- backendListKeys from + forM_ keys $ \key -> do + backendLoadKey from key >>= \case + Just sec -> do + backendStoreKey to key sec + backendRemoveKey from key + Nothing -> return () diff --git a/src/Erebos/Storage/List.hs b/src/Erebos/Storage/List.hs new file mode 100644 index 0000000..f0f8786 --- /dev/null +++ b/src/Erebos/Storage/List.hs @@ -0,0 +1,154 @@ +module Erebos.Storage.List ( + StoredList, + emptySList, fromSList, storedFromSList, + slistAdd, slistAddS, + -- TODO slistInsert, slistInsertS, + slistRemove, slistReplace, slistReplaceS, + -- TODO mapFromSList, updateOld, + + -- TODO StoreUpdate(..), + -- TODO withStoredListItem, withStoredListItemS, +) where + +import Data.List +import Data.Maybe +import qualified Data.Set as S + +import Erebos.Storage +import Erebos.Storage.Internal +import Erebos.Storage.Merge + +data List a = ListNil + | ListItem { listPrev :: [StoredList a] + , listItem :: Maybe (Stored a) + , listRemove :: Maybe (Stored (List a)) + } + +type StoredList a = Stored (List a) + +instance Storable a => Storable (List a) where + store' ListNil = storeZero + store' x@ListItem {} = storeRec $ do + mapM_ (storeRef "PREV") $ listPrev x + mapM_ (storeRef "item") $ listItem x + mapM_ (storeRef "remove") $ listRemove x + + load' = loadCurrentObject >>= \case + ZeroObject -> return ListNil + _ -> loadRec $ ListItem <$> loadRefs "PREV" + <*> loadMbRef "item" + <*> loadMbRef "remove" + +instance Storable a => ZeroStorable (List a) where + fromZero _ = ListNil + + +emptySList :: Storable a => Storage -> IO (StoredList a) +emptySList st = wrappedStore st ListNil + +groupsFromSLists :: forall a. Storable a => StoredList a -> [[Stored a]] +groupsFromSLists = helperSelect S.empty . (:[]) + where + helperSelect :: S.Set (StoredList a) -> [StoredList a] -> [[Stored a]] + helperSelect rs xxs | x:xs <- sort $ filterRemoved rs xxs = helper rs x xs + | otherwise = [] + + helper :: S.Set (StoredList a) -> StoredList a -> [StoredList a] -> [[Stored a]] + helper rs x xs + | ListNil <- fromStored x + = [] + + | Just rm <- listRemove (fromStored x) + , ans <- ancestors [x] + , (other, collision) <- partition (S.null . S.intersection ans . ancestors . (:[])) xs + , cont <- helperSelect (rs `S.union` ancestors [rm]) $ concatMap (listPrev . fromStored) (x : collision) ++ other + = case catMaybes $ map (listItem . fromStored) (x : collision) of + [] -> cont + xis -> xis : cont + + | otherwise = case listItem (fromStored x) of + Nothing -> helperSelect rs $ listPrev (fromStored x) ++ xs + Just xi -> [xi] : (helperSelect rs $ listPrev (fromStored x) ++ xs) + + filterRemoved :: S.Set (StoredList a) -> [StoredList a] -> [StoredList a] + filterRemoved rs = filter (S.null . S.intersection rs . ancestors . (:[])) + +fromSList :: Mergeable a => StoredList (Component a) -> [a] +fromSList = map merge . groupsFromSLists + +storedFromSList :: (Mergeable a, Storable a) => StoredList (Component a) -> IO [Stored a] +storedFromSList = mapM storeMerge . groupsFromSLists + +slistAdd :: Storable a => a -> StoredList a -> IO (StoredList a) +slistAdd x prev@(Stored (Ref st _) _) = do + sx <- wrappedStore st x + slistAddS sx prev + +slistAddS :: Storable a => Stored a -> StoredList a -> IO (StoredList a) +slistAddS sx prev@(Stored (Ref st _) _) = wrappedStore st (ListItem [prev] (Just sx) Nothing) + +{- TODO +slistInsert :: Storable a => Stored a -> a -> StoredList a -> IO (StoredList a) +slistInsert after x prev@(Stored (Ref st _) _) = do + sx <- wrappedStore st x + slistInsertS after sx prev + +slistInsertS :: Storable a => Stored a -> Stored a -> StoredList a -> IO (StoredList a) +slistInsertS after sx prev@(Stored (Ref st _) _) = wrappedStore st $ ListItem Nothing (findSListRef after prev) (Just sx) prev +-} + +slistRemove :: Storable a => Stored a -> StoredList a -> IO (StoredList a) +slistRemove rm prev@(Stored (Ref st _) _) = wrappedStore st $ ListItem [prev] Nothing (findSListRef rm prev) + +slistReplace :: Storable a => Stored a -> a -> StoredList a -> IO (StoredList a) +slistReplace rm x prev@(Stored (Ref st _) _) = do + sx <- wrappedStore st x + slistReplaceS rm sx prev + +slistReplaceS :: Storable a => Stored a -> Stored a -> StoredList a -> IO (StoredList a) +slistReplaceS rm sx prev@(Stored (Ref st _) _) = wrappedStore st $ ListItem [prev] (Just sx) (findSListRef rm prev) + +findSListRef :: Stored a -> StoredList a -> Maybe (StoredList a) +findSListRef _ (Stored _ ListNil) = Nothing +findSListRef x cur | listItem (fromStored cur) == Just x = Just cur + | otherwise = listToMaybe $ catMaybes $ map (findSListRef x) $ listPrev $ fromStored cur + +{- TODO +mapFromSList :: Storable a => StoredList a -> Map RefDigest (Stored a) +mapFromSList list = helper list M.empty + where helper :: Storable a => StoredList a -> Map RefDigest (Stored a) -> Map RefDigest (Stored a) + helper (Stored _ ListNil) cur = cur + helper (Stored _ (ListItem (Just rref) _ (Just x) rest)) cur = + let rxref = case load rref of + ListItem _ _ (Just rx) _ -> sameType rx x $ storedRef rx + _ -> error "mapFromSList: malformed list" + in helper rest $ case M.lookup (refDigest $ storedRef x) cur of + Nothing -> M.insert (refDigest rxref) x cur + Just x' -> M.insert (refDigest rxref) x' cur + helper (Stored _ (ListItem _ _ _ rest)) cur = helper rest cur + sameType :: a -> a -> b -> b + sameType _ _ x = x + +updateOld :: Map RefDigest (Stored a) -> Stored a -> Stored a +updateOld m x = fromMaybe x $ M.lookup (refDigest $ storedRef x) m + + +data StoreUpdate a = StoreKeep + | StoreReplace a + | StoreRemove + +withStoredListItem :: (Storable a) => (a -> Bool) -> StoredList a -> (a -> IO (StoreUpdate a)) -> IO (StoredList a) +withStoredListItem p list f = withStoredListItemS (p . fromStored) list (suMap (wrappedStore $ storedStorage list) <=< f . fromStored) + where suMap :: Monad m => (a -> m b) -> StoreUpdate a -> m (StoreUpdate b) + suMap _ StoreKeep = return StoreKeep + suMap g (StoreReplace x) = return . StoreReplace =<< g x + suMap _ StoreRemove = return StoreRemove + +withStoredListItemS :: (Storable a) => (Stored a -> Bool) -> StoredList a -> (Stored a -> IO (StoreUpdate (Stored a))) -> IO (StoredList a) +withStoredListItemS p list f = do + case find p $ storedFromSList list of + Just sx -> f sx >>= \case StoreKeep -> return list + StoreReplace nx -> slistReplaceS sx nx list + StoreRemove -> slistRemove sx list + Nothing -> return list +-} diff --git a/src/Erebos/Storage/Memory.hs b/src/Erebos/Storage/Memory.hs new file mode 100644 index 0000000..677e8c5 --- /dev/null +++ b/src/Erebos/Storage/Memory.hs @@ -0,0 +1,101 @@ +module Erebos.Storage.Memory ( + memoryStorage, + deriveEphemeralStorage, + derivePartialStorage, +) where + +import Control.Concurrent.MVar + +import Data.ByteArray (ScrubbedBytes) +import Data.ByteString.Lazy qualified as BL +import Data.Function +import Data.Kind +import Data.List +import Data.Map (Map) +import Data.Map qualified as M +import Data.Maybe +import Data.Typeable + +import Erebos.Object +import Erebos.Storage.Backend +import Erebos.Storage.Head +import Erebos.Storage.Internal + + +data MemoryStorage p (c :: Type -> Type) = StorageMemory + { memParent :: p + , memHeads :: MVar [ (( HeadTypeID, HeadID ), RefDigest ) ] + , memObjs :: MVar (Map RefDigest BL.ByteString) + , memKeys :: MVar (Map RefDigest ScrubbedBytes) + , memWatchers :: MVar WatchList + } + +instance Eq (MemoryStorage p c) where + (==) = (==) `on` memObjs + +instance Show (MemoryStorage p c) where + show StorageMemory {} = "mem" + +instance (StorageCompleteness c, Typeable p) => StorageBackend (MemoryStorage p c) where + type BackendCompleteness (MemoryStorage p c) = c + type BackendParent (MemoryStorage p c) = p + backendParent = memParent + + backendLoadBytes StorageMemory {..} dgst = + M.lookup dgst <$> readMVar memObjs + + backendStoreBytes StorageMemory {..} dgst raw = + modifyMVar_ memObjs (return . M.insert dgst raw) + + + backendLoadHeads StorageMemory {..} tid = do + let toRes ( ( tid', hid ), dgst ) + | tid' == tid = Just ( hid, dgst ) + | otherwise = Nothing + catMaybes . map toRes <$> readMVar memHeads + + backendLoadHead StorageMemory {..} tid hid = + lookup (tid, hid) <$> readMVar memHeads + + backendStoreHead StorageMemory {..} tid hid dgst = + modifyMVar_ memHeads $ return . (( ( tid, hid ), dgst ) :) + + backendReplaceHead StorageMemory {..} tid hid expected new = do + res <- modifyMVar memHeads $ \hs -> do + ws <- map wlFun . filter ((==(tid, hid)) . wlHead) . wlList <$> readMVar memWatchers + return $ case partition ((==(tid, hid)) . fst) hs of + ( [] , _ ) -> ( hs, Left Nothing ) + (( _, dgst ) : _, hs' ) + | dgst == expected -> ((( tid, hid ), new ) : hs', Right ( new, ws )) + | otherwise -> ( hs, Left $ Just dgst ) + case res of + Right ( dgst, ws ) -> mapM_ ($ dgst) ws >> return (Right dgst) + Left x -> return $ Left x + + backendWatchHead StorageMemory {..} tid hid cb = modifyMVar memWatchers $ return . watchListAdd tid hid cb + + backendUnwatchHead StorageMemory {..} wid = modifyMVar_ memWatchers $ return . watchListDel wid + + + backendListKeys StorageMemory {..} = M.keys <$> readMVar memKeys + backendLoadKey StorageMemory {..} dgst = M.lookup dgst <$> readMVar memKeys + backendStoreKey StorageMemory {..} dgst key = modifyMVar_ memKeys $ return . M.insert dgst key + backendRemoveKey StorageMemory {..} dgst = modifyMVar_ memKeys $ return . M.delete dgst + + +memoryStorage' :: (StorageCompleteness c, Typeable p) => p -> IO (Storage' c) +memoryStorage' memParent = do + memHeads <- newMVar [] + memObjs <- newMVar M.empty + memKeys <- newMVar M.empty + memWatchers <- newMVar (WatchList startWatchID []) + newStorage $ StorageMemory {..} + +memoryStorage :: IO Storage +memoryStorage = memoryStorage' () + +deriveEphemeralStorage :: Storage -> IO Storage +deriveEphemeralStorage parent = memoryStorage' parent + +derivePartialStorage :: Storage -> IO PartialStorage +derivePartialStorage parent = memoryStorage' parent diff --git a/src/Erebos/Storage/Merge.hs b/src/Erebos/Storage/Merge.hs new file mode 100644 index 0000000..a41a65f --- /dev/null +++ b/src/Erebos/Storage/Merge.hs @@ -0,0 +1,164 @@ +module Erebos.Storage.Merge ( + Mergeable(..), + merge, storeMerge, + + Generation, + showGeneration, + compareGeneration, generationMax, + storedGeneration, + + generations, + ancestors, + precedes, + precedesOrEquals, + filterAncestors, + storedRoots, + walkAncestors, + + findProperty, + findPropertyFirst, +) where + +import Control.Concurrent.MVar + +import Data.ByteString.Char8 qualified as BC +import Data.HashTable.IO qualified as HT +import Data.Kind +import Data.List +import Data.Maybe +import Data.Set (Set) +import Data.Set qualified as S + +import System.IO.Unsafe (unsafePerformIO) + +import Erebos.Object +import Erebos.Storable +import Erebos.Storage.Internal +import Erebos.Util + +class Storable (Component a) => Mergeable a where + type Component a :: Type + mergeSorted :: [Stored (Component a)] -> a + toComponents :: a -> [Stored (Component a)] + +instance Mergeable [Stored Object] where + type Component [Stored Object] = Object + mergeSorted = id + toComponents = id + +merge :: Mergeable a => [Stored (Component a)] -> a +merge [] = error "merge: empty list" +merge xs = mergeSorted $ filterAncestors xs + +storeMerge :: (Mergeable a, Storable a) => [Stored (Component a)] -> IO (Stored a) +storeMerge [] = error "merge: empty list" +storeMerge xs@(x : _) = wrappedStore (storedStorage x) $ mergeSorted $ filterAncestors xs + +previous :: Storable a => Stored a -> [Stored a] +previous (Stored ref _) = case load ref of + Rec items | Just (RecRef dref) <- lookup (BC.pack "SDATA") items + , Rec ditems <- load dref -> + map wrappedLoad $ catMaybes $ map (\case RecRef r -> Just r; _ -> Nothing) $ + map snd $ filter ((`elem` [ BC.pack "SPREV", BC.pack "SBASE" ]) . fst) ditems + + | otherwise -> + map wrappedLoad $ catMaybes $ map (\case RecRef r -> Just r; _ -> Nothing) $ + map snd $ filter ((`elem` [ BC.pack "PREV", BC.pack "BASE" ]) . fst) items + _ -> [] + + +nextGeneration :: [Generation] -> Generation +nextGeneration = foldl' helper (Generation 0) + where helper (Generation c) (Generation n) | c <= n = Generation (n + 1) + | otherwise = Generation c + +showGeneration :: Generation -> String +showGeneration (Generation x) = show x + +compareGeneration :: Generation -> Generation -> Maybe Ordering +compareGeneration (Generation x) (Generation y) = Just $ compare x y + +generationMax :: Storable a => [Stored a] -> Maybe (Stored a) +generationMax (x : xs) = Just $ snd $ foldl' helper (storedGeneration x, x) xs + where helper (mg, mx) y = let yg = storedGeneration y + in case compareGeneration mg yg of + Just LT -> (yg, y) + _ -> (mg, mx) +generationMax [] = Nothing + +storedGeneration :: Storable a => Stored a -> Generation +storedGeneration x = + unsafePerformIO $ withMVar (stRefGeneration $ refStorage $ storedRef x) $ \ht -> do + let doLookup y = HT.lookup ht (refDigest $ storedRef y) >>= \case + Just gen -> return gen + Nothing -> do + gen <- nextGeneration <$> mapM doLookup (previous y) + HT.insert ht (refDigest $ storedRef y) gen + return gen + doLookup x + + +-- |Returns list of sets starting with the set of given objects and +-- intcrementally adding parents. +generations :: Storable a => [Stored a] -> [Set (Stored a)] +generations = unfoldr gen . (,S.empty) + where gen (hs, cur) = case filter (`S.notMember` cur) hs of + [] -> Nothing + added -> let next = foldr S.insert cur added + in Just (next, (previous =<< added, next)) + +-- |Returns set containing all given objects and their ancestors +ancestors :: Storable a => [Stored a] -> Set (Stored a) +ancestors = last . (S.empty:) . generations + +precedes :: Storable a => Stored a -> Stored a -> Bool +precedes x y = not $ x `elem` filterAncestors [x, y] + +precedesOrEquals :: Storable a => Stored a -> Stored a -> Bool +precedesOrEquals x y = filterAncestors [ x, y ] == [ y ] + +filterAncestors :: Storable a => [Stored a] -> [Stored a] +filterAncestors [x] = [x] +filterAncestors xs = let xs' = uniq $ sort xs + in helper xs' xs' + where helper remains walk = case generationMax walk of + Just x -> let px = previous x + remains' = filter (\r -> all (/=r) px) remains + in helper remains' $ uniq $ sort (px ++ filter (/=x) walk) + Nothing -> remains + +storedRoots :: Storable a => Stored a -> [Stored a] +storedRoots x = do + let st = refStorage $ storedRef x + unsafePerformIO $ withMVar (stRefRoots st) $ \ht -> do + let doLookup y = HT.lookup ht (refDigest $ storedRef y) >>= \case + Just roots -> return roots + Nothing -> do + roots <- case previous y of + [] -> return [refDigest $ storedRef y] + ps -> map (refDigest . storedRef) . filterAncestors . map (wrappedLoad @Object . Ref st) . concat <$> mapM doLookup ps + HT.insert ht (refDigest $ storedRef y) roots + return roots + map (wrappedLoad . Ref st) <$> doLookup x + +walkAncestors :: (Storable a, Monoid m) => (Stored a -> m) -> [Stored a] -> m +walkAncestors f = helper . sortBy cmp + where + helper (x : y : xs) | x == y = helper (x : xs) + helper (x : xs) = f x <> helper (mergeBy cmp (sortBy cmp (previous x)) xs) + helper [] = mempty + + cmp x y = case compareGeneration (storedGeneration x) (storedGeneration y) of + Just LT -> GT + Just GT -> LT + _ -> compare x y + +findProperty :: forall a b. Storable a => (a -> Maybe b) -> [Stored a] -> [b] +findProperty sel = map (fromJust . sel . fromStored) . filterAncestors . (findPropHeads sel =<<) + +findPropertyFirst :: forall a b. Storable a => (a -> Maybe b) -> [Stored a] -> Maybe b +findPropertyFirst sel = fmap (fromJust . sel . fromStored) . listToMaybe . filterAncestors . (findPropHeads sel =<<) + +findPropHeads :: forall a b. Storable a => (a -> Maybe b) -> Stored a -> [Stored a] +findPropHeads sel sobj | Just _ <- sel $ fromStored sobj = [sobj] + | otherwise = findPropHeads sel =<< previous sobj diff --git a/src/Erebos/Sync.hs b/src/Erebos/Sync.hs new file mode 100644 index 0000000..d837a14 --- /dev/null +++ b/src/Erebos/Sync.hs @@ -0,0 +1,46 @@ +module Erebos.Sync ( + SyncService(..), +) where + +import Control.Monad +import Control.Monad.Reader + +import Data.List + +import Erebos.Identity +import Erebos.Service +import Erebos.State +import Erebos.Storable +import Erebos.Storage.Merge + +data SyncService = SyncPacket (Stored SharedState) + +instance Service SyncService where + serviceID _ = mkServiceID "a4f538d0-4e50-4082-8e10-7e3ec2af175d" + + serviceHandler packet = do + let SyncPacket added = fromStored packet + pid <- asks svcPeerIdentity + self <- svcSelf + when (finalOwner pid `sameIdentity` finalOwner self) $ do + updateLocalState_ $ \ls -> do + let current = sort $ lsShared $ fromStored ls + updated = filterAncestors (added : current) + if current /= updated + then mstore (fromStored ls) { lsShared = updated } + else return ls + + serviceNewPeer = notifyPeer . lsShared . fromStored =<< svcGetLocal + serviceStorageWatchers _ = (:[]) $ SomeStorageWatcher (lsShared . fromStored) notifyPeer + +instance Storable SyncService where + store' (SyncPacket smsg) = store' smsg + load' = SyncPacket <$> load' + +notifyPeer :: [Stored SharedState] -> ServiceHandler SyncService () +notifyPeer shared = do + pid <- asks svcPeerIdentity + self <- svcSelf + when (finalOwner pid `sameIdentity` finalOwner self) $ do + forM_ shared $ \sh -> + replyStoredRef =<< (mstore . SyncPacket) sh diff --git a/src/Erebos/UUID.hs b/src/Erebos/UUID.hs new file mode 100644 index 0000000..128d450 --- /dev/null +++ b/src/Erebos/UUID.hs @@ -0,0 +1,24 @@ +module Erebos.UUID ( + UUID, + toString, fromString, + toText, fromText, + toASCIIBytes, fromASCIIBytes, + nextRandom, +) where + +import Crypto.Random.Entropy + +import Data.Bits +import Data.ByteString qualified as BS +import Data.ByteString.Lazy qualified as BSL +import Data.Maybe +import Data.UUID.Types + +nextRandom :: IO UUID +nextRandom = do + [ b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, ba, bb, bc, bd, be, bf ] + <- BS.unpack <$> getEntropy 16 + let version = 4 + b6' = b6 .&. 0x0f .|. (version `shiftL` 4) + b8' = b8 .&. 0x3f .|. 0x80 + return $ fromJust $ fromByteString $ BSL.pack [ b0, b1, b2, b3, b4, b5, b6', b7, b8', b9, ba, bb, bc, bd, be, bf ] diff --git a/src/Erebos/Util.hs b/src/Erebos/Util.hs new file mode 100644 index 0000000..0381c3e --- /dev/null +++ b/src/Erebos/Util.hs @@ -0,0 +1,67 @@ +module Erebos.Util where + +import Control.Monad + +import Data.ByteArray (ByteArray, ByteArrayAccess) +import Data.ByteArray qualified as BA +import Data.ByteString (ByteString) +import Data.ByteString qualified as B +import Data.Char + + +uniq :: Eq a => [a] -> [a] +uniq (x:y:xs) | x == y = uniq (x:xs) + | otherwise = x : uniq (y:xs) +uniq xs = xs + +mergeBy :: (a -> a -> Ordering) -> [a] -> [a] -> [a] +mergeBy cmp (x : xs) (y : ys) = case cmp x y of + LT -> x : mergeBy cmp xs (y : ys) + EQ -> x : y : mergeBy cmp xs ys + GT -> y : mergeBy cmp (x : xs) ys +mergeBy _ xs [] = xs +mergeBy _ [] ys = ys + +mergeUniqBy :: (a -> a -> Ordering) -> [a] -> [a] -> [a] +mergeUniqBy cmp (x : xs) (y : ys) = case cmp x y of + LT -> x : mergeBy cmp xs (y : ys) + EQ -> x : mergeBy cmp xs ys + GT -> y : mergeBy cmp (x : xs) ys +mergeUniqBy _ xs [] = xs +mergeUniqBy _ [] ys = ys + +mergeUniq :: Ord a => [a] -> [a] -> [a] +mergeUniq = mergeUniqBy compare + +diffSorted :: Ord a => [a] -> [a] -> [a] +diffSorted (x:xs) (y:ys) | x < y = x : diffSorted xs (y:ys) + | x > y = diffSorted (x:xs) ys + | otherwise = diffSorted xs (y:ys) +diffSorted xs _ = xs + +intersectsSorted :: Ord a => [a] -> [a] -> Bool +intersectsSorted (x:xs) (y:ys) | x < y = intersectsSorted xs (y:ys) + | x > y = intersectsSorted (x:xs) ys + | otherwise = True +intersectsSorted _ _ = False + + +showHex :: ByteArrayAccess ba => ba -> ByteString +showHex = B.concat . map showHexByte . BA.unpack + where showHexChar x | x < 10 = x + o '0' + | otherwise = x + o 'a' - 10 + showHexByte x = B.pack [ showHexChar (x `div` 16), showHexChar (x `mod` 16) ] + o = fromIntegral . ord + +readHex :: ByteArray ba => ByteString -> Maybe ba +readHex = return . BA.concat <=< readHex' + where readHex' bs | B.null bs = Just [] + readHex' bs = do (bx, bs') <- B.uncons bs + (by, bs'') <- B.uncons bs' + x <- hexDigit bx + y <- hexDigit by + (B.singleton (x * 16 + y) :) <$> readHex' bs'' + hexDigit x | x >= o '0' && x <= o '9' = Just $ x - o '0' + | x >= o 'a' && x <= o 'z' = Just $ x - o 'a' + 10 + | otherwise = Nothing + o = fromIntegral . ord diff --git a/src/unix/Erebos/Storage/Platform.hs b/src/unix/Erebos/Storage/Platform.hs new file mode 100644 index 0000000..2198f61 --- /dev/null +++ b/src/unix/Erebos/Storage/Platform.hs @@ -0,0 +1,20 @@ +{-# LANGUAGE CPP #-} + +module Erebos.Storage.Platform ( + createFileExclusive, +) where + +import System.IO +import System.Posix.Files +import System.Posix.IO + +createFileExclusive :: FilePath -> IO Handle +createFileExclusive path = fdToHandle =<< do +#if MIN_VERSION_unix(2,8,0) + openFd path WriteOnly defaultFileFlags + { creat = Just $ unionFileModes ownerReadMode ownerWriteMode + , exclusive = True + } +#else + openFd path WriteOnly (Just $ unionFileModes ownerReadMode ownerWriteMode) (defaultFileFlags { exclusive = True }) +#endif diff --git a/src/windows/Erebos/Storage/Platform.hs b/src/windows/Erebos/Storage/Platform.hs new file mode 100644 index 0000000..76c940b --- /dev/null +++ b/src/windows/Erebos/Storage/Platform.hs @@ -0,0 +1,13 @@ +module Erebos.Storage.Platform ( + createFileExclusive, +) where + +import Data.Bits + +import System.IO +import System.Win32.File +import System.Win32.Types + +createFileExclusive :: FilePath -> IO Handle +createFileExclusive path = do + hANDLEToHandle =<< createFile path gENERIC_WRITE (fILE_SHARE_READ .|. fILE_SHARE_DELETE) Nothing cREATE_NEW fILE_ATTRIBUTE_NORMAL Nothing diff --git a/attach.test b/test/attach.test index afbdd0e..afbdd0e 100644 --- a/attach.test +++ b/test/attach.test diff --git a/chatroom.test b/test/chatroom.test index 54f9b2a..54f9b2a 100644 --- a/chatroom.test +++ b/test/chatroom.test diff --git a/common.test b/test/common.test index 89941f0..89941f0 100644 --- a/common.test +++ b/test/common.test diff --git a/contact.test b/test/contact.test index 978f8a6..978f8a6 100644 --- a/contact.test +++ b/test/contact.test diff --git a/discovery.test b/test/discovery.test index 9453e65..9453e65 100644 --- a/discovery.test +++ b/test/discovery.test diff --git a/message.test b/test/message.test index 2990d0f..2990d0f 100644 --- a/message.test +++ b/test/message.test diff --git a/network.test b/test/network.test index 0f49a1e..0f49a1e 100644 --- a/network.test +++ b/test/network.test diff --git a/storage.test b/test/storage.test index 2230eac..2230eac 100644 --- a/storage.test +++ b/test/storage.test diff --git a/sync.test b/test/sync.test index d465b11..d465b11 100644 --- a/sync.test +++ b/test/sync.test |