diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/Main.hs | 35 | ||||
-rw-r--r-- | src/Message/Service.hs | 52 | ||||
-rw-r--r-- | src/Network.hs | 35 | ||||
-rw-r--r-- | src/Service.hs | 58 |
4 files changed, 144 insertions, 36 deletions
diff --git a/src/Main.hs b/src/Main.hs index d473f2e..9e87af5 100644 --- a/src/Main.hs +++ b/src/Main.hs @@ -17,7 +17,6 @@ import Data.Char import Data.List import Data.Maybe import qualified Data.Text as T -import Data.Time.Format import Data.Time.LocalTime import System.Console.Haskeline @@ -25,8 +24,10 @@ import System.Environment import Identity import Message +import Message.Service import Network import PubKey +import Service import State import Storage @@ -67,8 +68,10 @@ interactiveLoop st bhost = runInputT defaultSettings $ do False -> error "Requires terminal" extPrint <- getExternalPrint let extPrintLn str = extPrint $ str ++ "\n"; - (chanPeer, chanSvc) <- liftIO $ + chanPeer <- liftIO $ startServer extPrintLn bhost self + [ (T.pack "dmsg", SomeService (emptyServiceState :: DirectMessageService)) + ] peers <- liftIO $ newMVar [] @@ -83,25 +86,6 @@ interactiveLoop st bhost = runInputT defaultSettings $ do let shown = showPeer peer when (Just shown /= (showPeer <$> op)) $ extPrint shown - tzone <- liftIO $ getCurrentTimeZone - void $ liftIO $ forkIO $ forever $ readChan chanSvc >>= \case - (peer, svc, ref) - | svc == T.pack "dmsg" -> do - let smsg = wrappedLoad ref - msg = fromStored smsg - extPrintLn $ formatMessage tzone msg - if | PeerIdentityFull powner <- peerOwner peer - , idData powner == msgFrom msg - -> updateLocalState_ st $ \erb -> do - slist <- case find ((== idData powner) . msgPeer . fromStored) (storedFromSList $ lsMessages $ fromStored erb) of - Just thread -> do thread' <- wrappedStore st (fromStored thread) { msgHead = smsg : msgHead (fromStored thread) } - slistReplaceS thread thread' $ lsMessages $ fromStored erb - Nothing -> slistAdd (emptyDirectThread powner) { msgHead = [smsg] } $ lsMessages $ fromStored erb - wrappedStore st (fromStored erb) { lsMessages = slist } - - | otherwise -> extPrint $ "Owner mismatch" - | otherwise -> extPrint $ "Unknown service: " ++ T.unpack svc - let getInputLines prompt = do Just input <- lift $ getInputLine prompt case reverse input of @@ -229,12 +213,3 @@ cmdUpdateIdentity :: Command cmdUpdateIdentity = void $ do st <- asks $ storedStorage . idData . ciSelf liftIO $ updateIdentity st - - -formatMessage :: TimeZone -> DirectMessage -> String -formatMessage tzone msg = concat - [ formatTime defaultTimeLocale "[%H:%M] " $ utcToLocalTime tzone $ zonedTimeToUTC $ msgTime msg - , maybe "<unnamed>" T.unpack $ iddName $ fromStored $ signedData $ fromStored $ msgFrom msg - , ": " - , T.unpack $ msgText msg - ] diff --git a/src/Message/Service.hs b/src/Message/Service.hs new file mode 100644 index 0000000..a798fb5 --- /dev/null +++ b/src/Message/Service.hs @@ -0,0 +1,52 @@ +module Message.Service ( + DirectMessageService, + formatMessage, +) where + +import Control.Monad.Reader +import Control.Monad.State + +import Data.List +import qualified Data.Text as T +import Data.Time.Format +import Data.Time.LocalTime + +import Identity +import Message +import PubKey +import Service +import State +import Storage + +data DirectMessageService = DirectMessageService + +instance Service DirectMessageService where + type ServicePacket DirectMessageService = DirectMessage + emptyServiceState = DirectMessageService + serviceHandler smsg = do + let msg = fromStored smsg + powner <- asks svcPeerOwner + tzone <- liftIO $ getCurrentTimeZone + svcPrint $ formatMessage tzone msg + if | idData powner == msgFrom msg + -> do erb <- gets svcLocal + let st = storedStorage erb + erb' <- liftIO $ do + slist <- case find ((== idData powner) . msgPeer . fromStored) (storedFromSList $ lsMessages $ fromStored erb) of + Just thread -> do thread' <- wrappedStore st (fromStored thread) { msgHead = smsg : msgHead (fromStored thread) } + slistReplaceS thread thread' $ lsMessages $ fromStored erb + Nothing -> slistAdd (emptyDirectThread powner) { msgHead = [smsg] } $ lsMessages $ fromStored erb + wrappedStore st (fromStored erb) { lsMessages = slist } + modify $ \s -> s { svcLocal = erb' } + return Nothing + + | otherwise -> do svcPrint "Owner mismatch" + return Nothing + +formatMessage :: TimeZone -> DirectMessage -> String +formatMessage tzone msg = concat + [ formatTime defaultTimeLocale "[%H:%M] " $ utcToLocalTime tzone $ zonedTimeToUTC $ msgTime msg + , maybe "<unnamed>" T.unpack $ iddName $ fromStored $ signedData $ fromStored $ msgFrom msg + , ": " + , T.unpack $ msgText msg + ] diff --git a/src/Network.hs b/src/Network.hs index 5d86a24..bff793a 100644 --- a/src/Network.hs +++ b/src/Network.hs @@ -4,6 +4,7 @@ module Network ( PeerIdentity(..), peerIdentityRef, PeerChannel(..), WaitingRef, wrDigest, + Service(..), startServer, sendToPeer, ) where @@ -14,8 +15,6 @@ import Control.Monad import Control.Monad.Except import Control.Monad.State -import Crypto.Random - import qualified Data.ByteString.Char8 as BC import qualified Data.ByteString.Lazy as BL import qualified Data.Map as M @@ -28,6 +27,7 @@ import Network.Socket.ByteString (recvFrom, sendTo) import Channel import Identity import PubKey +import Service import Storage @@ -43,6 +43,7 @@ data Peer = Peer , peerSocket :: Socket , peerStorage :: Storage , peerInStorage :: PartialStorage + , peerServiceState :: M.Map T.Text SomeService , peerServiceQueue :: [(T.Text, WaitingRef)] , peerWaitingRefs :: [WaitingRef] } @@ -149,8 +150,8 @@ receivedWaitingRef nref wr@(WaitingRef _ _ mvar) = do checkWaitingRef wr -startServer :: (String -> IO ()) -> String -> UnifiedIdentity -> IO (Chan Peer, Chan (Peer, T.Text, Ref)) -startServer logd bhost identity = do +startServer :: (String -> IO ()) -> String -> UnifiedIdentity -> [(T.Text, SomeService)] -> IO (Chan Peer) +startServer logd bhost identity services = do let sidentity = idData identity chanPeer <- newChan chanSvc <- newChan @@ -191,6 +192,7 @@ startServer logd bhost identity = do , peerSocket = sock , peerStorage = pst , peerInStorage = ist + , peerServiceState = M.empty , peerServiceQueue = [] , peerWaitingRefs = [] } @@ -220,7 +222,28 @@ startServer logd bhost identity = do addr:_ <- getAddrInfo (Just hints) Nothing (Just discoveryPort) bracket (open addr) close loop - return (chanPeer, chanSvc) + void $ forkIO $ forever $ readChan chanSvc >>= \case + (peer, svc, ref) + | PeerIdentityFull peerId <- peerIdentity peer + , PeerIdentityFull peerOwnerId <- peerOwner peer + , DatagramAddress paddr <- peerAddress peer + -> case maybe (lookup svc services) Just $ M.lookup svc (peerServiceState peer) of + Nothing -> logd $ "unhandled service '" ++ T.unpack svc ++ "'" + Just (SomeService s) -> do + let inp = ServiceInput + { svcPeer = peerId, svcPeerOwner = peerOwnerId + , svcPrintOp = logd + } + (rsp, s') <- handleServicePacket (storedStorage sidentity) inp s (wrappedLoad ref) + modifyMVar_ peers $ return . M.adjust (\p -> p { peerServiceState = M.insert svc (SomeService s') $ peerServiceState p }) paddr + runExceptT (maybe (return ()) (sendToPeer identity peer svc) rsp) >>= \case + Left err -> logd $ "failed to send response to peer: " ++ show err + Right () -> return () + + | DatagramAddress paddr <- peerAddress peer -> do + logd $ "service packet from peer with incomplete identity " ++ show paddr + + return chanPeer type PacketHandler a = StateT PacketHandlerState (ExceptT String IO) a @@ -452,7 +475,7 @@ handleServices chan = gets (peerServiceQueue . phPeer) >>= \case updatePeer $ \p -> p { peerServiceQueue = queue' } -sendToPeer :: (Storable a, MonadIO m, MonadError String m, MonadRandom m) => UnifiedIdentity -> Peer -> T.Text -> a -> m () +sendToPeer :: (Storable a, MonadIO m, MonadError String m) => UnifiedIdentity -> Peer -> T.Text -> a -> m () sendToPeer _ peer@Peer { peerChannel = ChannelEstablished ch } svc obj = do let st = peerInStorage peer ref <- liftIO $ store st obj diff --git a/src/Service.hs b/src/Service.hs new file mode 100644 index 0000000..667196d --- /dev/null +++ b/src/Service.hs @@ -0,0 +1,58 @@ +module Service ( + Service(..), + SomeService(..), + + ServiceHandler, + ServiceInput(..), ServiceState(..), + handleServicePacket, + + svcPrint, +) where + +import Control.Monad.Except +import Control.Monad.Reader +import Control.Monad.State + +import Identity +import State +import Storage + +class Storable (ServicePacket s) => Service s where + type ServicePacket s :: * + emptyServiceState :: s + serviceHandler :: Stored (ServicePacket s) -> ServiceHandler s (Maybe (ServicePacket s)) + +data SomeService = forall s. Service s => SomeService s + +data ServiceInput = ServiceInput + { svcPeer :: UnifiedIdentity + , svcPeerOwner :: UnifiedIdentity + , svcPrintOp :: String -> IO () + } + +data ServiceState s = ServiceState + { svcValue :: s + , svcLocal :: Stored LocalState + } + +newtype ServiceHandler s a = ServiceHandler (ReaderT ServiceInput (StateT (ServiceState s) (ExceptT String IO)) a) + deriving (Functor, Applicative, Monad, MonadReader ServiceInput, MonadState (ServiceState s), MonadIO) + +handleServicePacket :: Service s => Storage -> ServiceInput -> s -> Stored (ServicePacket s) -> IO (Maybe (ServicePacket s), s) +handleServicePacket st input svc packet = do + herb <- loadLocalState st + let erb = wrappedLoad $ headRef herb + sstate = ServiceState { svcValue = svc, svcLocal = erb } + ServiceHandler handler = serviceHandler packet + (runExceptT $ flip runStateT sstate $ flip runReaderT input $ handler) >>= \case + Left err -> do + svcPrintOp input $ "service failed: " ++ err + return (Nothing, svc) + Right (rsp, sstate') + | svcLocal sstate' == svcLocal sstate -> return (rsp, svcValue sstate') + | otherwise -> replaceHead (svcLocal sstate') (Right herb) >>= \case + Left _ -> handleServicePacket st input svc packet + Right _ -> return (rsp, svcValue sstate') + +svcPrint :: String -> ServiceHandler s () +svcPrint str = liftIO . ($str) =<< asks svcPrintOp |