diff options
| -rw-r--r-- | erebos.cabal | 5 | ||||
| -rw-r--r-- | src/Main.hs | 35 | ||||
| -rw-r--r-- | src/Message/Service.hs | 52 | ||||
| -rw-r--r-- | src/Network.hs | 35 | ||||
| -rw-r--r-- | src/Service.hs | 58 | 
5 files changed, 148 insertions, 37 deletions
| diff --git a/erebos.cabal b/erebos.cabal index 391584b..8218d91 100644 --- a/erebos.cabal +++ b/erebos.cabal @@ -21,15 +21,18 @@ executable erebos    other-modules:       Identity,                         Channel,                         Message, +                       Message.Service                         Network,                         PubKey, +                       Service                         State,                         Storage,                         Storage.Internal                         Storage.Key                         Util -  default-extensions:  FlexibleContexts, +  default-extensions:  ExistentialQuantification +                       FlexibleContexts,                         FlexibleInstances,                         FunctionalDependencies,                         GeneralizedNewtypeDeriving diff --git a/src/Main.hs b/src/Main.hs index d473f2e..9e87af5 100644 --- a/src/Main.hs +++ b/src/Main.hs @@ -17,7 +17,6 @@ import Data.Char  import Data.List  import Data.Maybe  import qualified Data.Text as T -import Data.Time.Format  import Data.Time.LocalTime  import System.Console.Haskeline @@ -25,8 +24,10 @@ import System.Environment  import Identity  import Message +import Message.Service  import Network  import PubKey +import Service  import State  import Storage @@ -67,8 +68,10 @@ interactiveLoop st bhost = runInputT defaultSettings $ do                               False -> error "Requires terminal"      extPrint <- getExternalPrint      let extPrintLn str = extPrint $ str ++ "\n"; -    (chanPeer, chanSvc) <- liftIO $ +    chanPeer <- liftIO $          startServer extPrintLn bhost self +            [ (T.pack "dmsg", SomeService (emptyServiceState :: DirectMessageService)) +            ]      peers <- liftIO $ newMVar [] @@ -83,25 +86,6 @@ interactiveLoop st bhost = runInputT defaultSettings $ do                   let shown = showPeer peer                   when (Just shown /= (showPeer <$> op)) $ extPrint shown -    tzone <- liftIO $ getCurrentTimeZone -    void $ liftIO $ forkIO $ forever $ readChan chanSvc >>= \case -        (peer, svc, ref) -            | svc == T.pack "dmsg" -> do -                let smsg = wrappedLoad ref -                    msg = fromStored smsg -                extPrintLn $ formatMessage tzone msg -                if | PeerIdentityFull powner <- peerOwner peer -                   , idData powner == msgFrom msg -                   -> updateLocalState_ st $ \erb -> do -                          slist <- case find ((== idData powner) . msgPeer . fromStored) (storedFromSList $ lsMessages $ fromStored erb) of -                                        Just thread -> do thread' <- wrappedStore st (fromStored thread) { msgHead = smsg : msgHead (fromStored thread) } -                                                          slistReplaceS thread thread' $ lsMessages $ fromStored erb -                                        Nothing -> slistAdd (emptyDirectThread powner) { msgHead = [smsg] } $ lsMessages $ fromStored erb -                          wrappedStore st (fromStored erb) { lsMessages = slist } - -                   | otherwise -> extPrint $ "Owner mismatch" -            | otherwise -> extPrint $ "Unknown service: " ++ T.unpack svc -      let getInputLines prompt = do              Just input <- lift $ getInputLine prompt              case reverse input of @@ -229,12 +213,3 @@ cmdUpdateIdentity :: Command  cmdUpdateIdentity = void $ do      st <- asks $ storedStorage . idData . ciSelf      liftIO $ updateIdentity st - - -formatMessage :: TimeZone -> DirectMessage -> String -formatMessage tzone msg = concat -    [ formatTime defaultTimeLocale "[%H:%M] " $ utcToLocalTime tzone $ zonedTimeToUTC $ msgTime msg -    , maybe "<unnamed>" T.unpack $ iddName $ fromStored $ signedData $ fromStored $ msgFrom msg -    , ": " -    , T.unpack $ msgText msg -    ] diff --git a/src/Message/Service.hs b/src/Message/Service.hs new file mode 100644 index 0000000..a798fb5 --- /dev/null +++ b/src/Message/Service.hs @@ -0,0 +1,52 @@ +module Message.Service ( +    DirectMessageService, +    formatMessage, +) where + +import Control.Monad.Reader +import Control.Monad.State + +import Data.List +import qualified Data.Text as T +import Data.Time.Format +import Data.Time.LocalTime + +import Identity +import Message +import PubKey +import Service +import State +import Storage + +data DirectMessageService = DirectMessageService + +instance Service DirectMessageService where +    type ServicePacket DirectMessageService = DirectMessage +    emptyServiceState = DirectMessageService +    serviceHandler smsg = do +        let msg = fromStored smsg +        powner <- asks svcPeerOwner +        tzone <- liftIO $ getCurrentTimeZone +        svcPrint $ formatMessage tzone msg +        if | idData powner == msgFrom msg +           -> do erb <- gets svcLocal +                 let st = storedStorage erb +                 erb' <- liftIO $ do +                     slist <- case find ((== idData powner) . msgPeer . fromStored) (storedFromSList $ lsMessages $ fromStored erb) of +                                   Just thread -> do thread' <- wrappedStore st (fromStored thread) { msgHead = smsg : msgHead (fromStored thread) } +                                                     slistReplaceS thread thread' $ lsMessages $ fromStored erb +                                   Nothing -> slistAdd (emptyDirectThread powner) { msgHead = [smsg] } $ lsMessages $ fromStored erb +                     wrappedStore st (fromStored erb) { lsMessages = slist } +                 modify $ \s -> s { svcLocal = erb' } +                 return Nothing + +           | otherwise -> do svcPrint "Owner mismatch" +                             return Nothing + +formatMessage :: TimeZone -> DirectMessage -> String +formatMessage tzone msg = concat +    [ formatTime defaultTimeLocale "[%H:%M] " $ utcToLocalTime tzone $ zonedTimeToUTC $ msgTime msg +    , maybe "<unnamed>" T.unpack $ iddName $ fromStored $ signedData $ fromStored $ msgFrom msg +    , ": " +    , T.unpack $ msgText msg +    ] diff --git a/src/Network.hs b/src/Network.hs index 5d86a24..bff793a 100644 --- a/src/Network.hs +++ b/src/Network.hs @@ -4,6 +4,7 @@ module Network (      PeerIdentity(..), peerIdentityRef,      PeerChannel(..),      WaitingRef, wrDigest, +    Service(..),      startServer,      sendToPeer,  ) where @@ -14,8 +15,6 @@ import Control.Monad  import Control.Monad.Except  import Control.Monad.State -import Crypto.Random -  import qualified Data.ByteString.Char8 as BC  import qualified Data.ByteString.Lazy as BL  import qualified Data.Map as M @@ -28,6 +27,7 @@ import Network.Socket.ByteString (recvFrom, sendTo)  import Channel  import Identity  import PubKey +import Service  import Storage @@ -43,6 +43,7 @@ data Peer = Peer      , peerSocket :: Socket      , peerStorage :: Storage      , peerInStorage :: PartialStorage +    , peerServiceState :: M.Map T.Text SomeService      , peerServiceQueue :: [(T.Text, WaitingRef)]      , peerWaitingRefs :: [WaitingRef]      } @@ -149,8 +150,8 @@ receivedWaitingRef nref wr@(WaitingRef _ _ mvar) = do      checkWaitingRef wr -startServer :: (String -> IO ()) -> String -> UnifiedIdentity -> IO (Chan Peer, Chan (Peer, T.Text, Ref)) -startServer logd bhost identity = do +startServer :: (String -> IO ()) -> String -> UnifiedIdentity -> [(T.Text, SomeService)] -> IO (Chan Peer) +startServer logd bhost identity services = do      let sidentity = idData identity      chanPeer <- newChan      chanSvc <- newChan @@ -191,6 +192,7 @@ startServer logd bhost identity = do                                    , peerSocket = sock                                    , peerStorage = pst                                    , peerInStorage = ist +                                  , peerServiceState = M.empty                                    , peerServiceQueue = []                                    , peerWaitingRefs = []                                    } @@ -220,7 +222,28 @@ startServer logd bhost identity = do          addr:_ <- getAddrInfo (Just hints) Nothing (Just discoveryPort)          bracket (open addr) close loop -    return (chanPeer, chanSvc) +    void $ forkIO $ forever $ readChan chanSvc >>= \case +        (peer, svc, ref) +            | PeerIdentityFull peerId <- peerIdentity peer +            , PeerIdentityFull peerOwnerId <- peerOwner peer +            , DatagramAddress paddr <- peerAddress peer +            -> case maybe (lookup svc services) Just $ M.lookup svc (peerServiceState peer) of +                    Nothing -> logd $ "unhandled service '" ++ T.unpack svc ++ "'" +                    Just (SomeService s) -> do +                        let inp = ServiceInput +                                { svcPeer = peerId, svcPeerOwner = peerOwnerId +                                , svcPrintOp = logd +                                } +                        (rsp, s') <- handleServicePacket (storedStorage sidentity) inp s (wrappedLoad ref) +                        modifyMVar_ peers $ return . M.adjust (\p -> p { peerServiceState = M.insert svc (SomeService s') $ peerServiceState p }) paddr +                        runExceptT (maybe (return ()) (sendToPeer identity peer svc) rsp) >>= \case +                            Left err -> logd $ "failed to send response to peer: " ++ show err +                            Right () -> return () + +            | DatagramAddress paddr <- peerAddress peer -> do +                logd $ "service packet from peer with incomplete identity " ++ show paddr + +    return chanPeer  type PacketHandler a = StateT PacketHandlerState (ExceptT String IO) a @@ -452,7 +475,7 @@ handleServices chan = gets (peerServiceQueue . phPeer) >>= \case          updatePeer $ \p -> p { peerServiceQueue = queue' } -sendToPeer :: (Storable a, MonadIO m, MonadError String m, MonadRandom m) => UnifiedIdentity -> Peer -> T.Text -> a -> m () +sendToPeer :: (Storable a, MonadIO m, MonadError String m) => UnifiedIdentity -> Peer -> T.Text -> a -> m ()  sendToPeer _ peer@Peer { peerChannel = ChannelEstablished ch } svc obj = do      let st = peerInStorage peer      ref <- liftIO $ store st obj diff --git a/src/Service.hs b/src/Service.hs new file mode 100644 index 0000000..667196d --- /dev/null +++ b/src/Service.hs @@ -0,0 +1,58 @@ +module Service ( +    Service(..), +    SomeService(..), + +    ServiceHandler, +    ServiceInput(..), ServiceState(..), +    handleServicePacket, + +    svcPrint, +) where + +import Control.Monad.Except +import Control.Monad.Reader +import Control.Monad.State + +import Identity +import State +import Storage + +class Storable (ServicePacket s) => Service s where +    type ServicePacket s :: * +    emptyServiceState :: s +    serviceHandler :: Stored (ServicePacket s) -> ServiceHandler s (Maybe (ServicePacket s)) + +data SomeService = forall s. Service s => SomeService s + +data ServiceInput = ServiceInput +    { svcPeer :: UnifiedIdentity +    , svcPeerOwner :: UnifiedIdentity +    , svcPrintOp :: String -> IO () +    } + +data ServiceState s = ServiceState +    { svcValue :: s +    , svcLocal :: Stored LocalState +    } + +newtype ServiceHandler s a = ServiceHandler (ReaderT ServiceInput (StateT (ServiceState s) (ExceptT String IO)) a) +    deriving (Functor, Applicative, Monad, MonadReader ServiceInput, MonadState (ServiceState s), MonadIO) + +handleServicePacket :: Service s => Storage -> ServiceInput -> s -> Stored (ServicePacket s) -> IO (Maybe (ServicePacket s), s) +handleServicePacket st input svc packet = do +    herb <- loadLocalState st +    let erb = wrappedLoad $ headRef herb +        sstate = ServiceState { svcValue = svc, svcLocal = erb } +        ServiceHandler handler = serviceHandler packet +    (runExceptT $ flip runStateT sstate $ flip runReaderT input $ handler) >>= \case +        Left err -> do +            svcPrintOp input $ "service failed: " ++ err +            return (Nothing, svc) +        Right (rsp, sstate') +            | svcLocal sstate' == svcLocal sstate -> return (rsp, svcValue sstate') +            | otherwise -> replaceHead (svcLocal sstate') (Right herb) >>= \case +                Left  _ -> handleServicePacket st input svc packet +                Right _ -> return (rsp, svcValue sstate') + +svcPrint :: String -> ServiceHandler s () +svcPrint str = liftIO . ($str) =<< asks svcPrintOp |