summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/Main.hs35
-rw-r--r--src/Message/Service.hs52
-rw-r--r--src/Network.hs35
-rw-r--r--src/Service.hs58
4 files changed, 144 insertions, 36 deletions
diff --git a/src/Main.hs b/src/Main.hs
index d473f2e..9e87af5 100644
--- a/src/Main.hs
+++ b/src/Main.hs
@@ -17,7 +17,6 @@ import Data.Char
import Data.List
import Data.Maybe
import qualified Data.Text as T
-import Data.Time.Format
import Data.Time.LocalTime
import System.Console.Haskeline
@@ -25,8 +24,10 @@ import System.Environment
import Identity
import Message
+import Message.Service
import Network
import PubKey
+import Service
import State
import Storage
@@ -67,8 +68,10 @@ interactiveLoop st bhost = runInputT defaultSettings $ do
False -> error "Requires terminal"
extPrint <- getExternalPrint
let extPrintLn str = extPrint $ str ++ "\n";
- (chanPeer, chanSvc) <- liftIO $
+ chanPeer <- liftIO $
startServer extPrintLn bhost self
+ [ (T.pack "dmsg", SomeService (emptyServiceState :: DirectMessageService))
+ ]
peers <- liftIO $ newMVar []
@@ -83,25 +86,6 @@ interactiveLoop st bhost = runInputT defaultSettings $ do
let shown = showPeer peer
when (Just shown /= (showPeer <$> op)) $ extPrint shown
- tzone <- liftIO $ getCurrentTimeZone
- void $ liftIO $ forkIO $ forever $ readChan chanSvc >>= \case
- (peer, svc, ref)
- | svc == T.pack "dmsg" -> do
- let smsg = wrappedLoad ref
- msg = fromStored smsg
- extPrintLn $ formatMessage tzone msg
- if | PeerIdentityFull powner <- peerOwner peer
- , idData powner == msgFrom msg
- -> updateLocalState_ st $ \erb -> do
- slist <- case find ((== idData powner) . msgPeer . fromStored) (storedFromSList $ lsMessages $ fromStored erb) of
- Just thread -> do thread' <- wrappedStore st (fromStored thread) { msgHead = smsg : msgHead (fromStored thread) }
- slistReplaceS thread thread' $ lsMessages $ fromStored erb
- Nothing -> slistAdd (emptyDirectThread powner) { msgHead = [smsg] } $ lsMessages $ fromStored erb
- wrappedStore st (fromStored erb) { lsMessages = slist }
-
- | otherwise -> extPrint $ "Owner mismatch"
- | otherwise -> extPrint $ "Unknown service: " ++ T.unpack svc
-
let getInputLines prompt = do
Just input <- lift $ getInputLine prompt
case reverse input of
@@ -229,12 +213,3 @@ cmdUpdateIdentity :: Command
cmdUpdateIdentity = void $ do
st <- asks $ storedStorage . idData . ciSelf
liftIO $ updateIdentity st
-
-
-formatMessage :: TimeZone -> DirectMessage -> String
-formatMessage tzone msg = concat
- [ formatTime defaultTimeLocale "[%H:%M] " $ utcToLocalTime tzone $ zonedTimeToUTC $ msgTime msg
- , maybe "<unnamed>" T.unpack $ iddName $ fromStored $ signedData $ fromStored $ msgFrom msg
- , ": "
- , T.unpack $ msgText msg
- ]
diff --git a/src/Message/Service.hs b/src/Message/Service.hs
new file mode 100644
index 0000000..a798fb5
--- /dev/null
+++ b/src/Message/Service.hs
@@ -0,0 +1,52 @@
+module Message.Service (
+ DirectMessageService,
+ formatMessage,
+) where
+
+import Control.Monad.Reader
+import Control.Monad.State
+
+import Data.List
+import qualified Data.Text as T
+import Data.Time.Format
+import Data.Time.LocalTime
+
+import Identity
+import Message
+import PubKey
+import Service
+import State
+import Storage
+
+data DirectMessageService = DirectMessageService
+
+instance Service DirectMessageService where
+ type ServicePacket DirectMessageService = DirectMessage
+ emptyServiceState = DirectMessageService
+ serviceHandler smsg = do
+ let msg = fromStored smsg
+ powner <- asks svcPeerOwner
+ tzone <- liftIO $ getCurrentTimeZone
+ svcPrint $ formatMessage tzone msg
+ if | idData powner == msgFrom msg
+ -> do erb <- gets svcLocal
+ let st = storedStorage erb
+ erb' <- liftIO $ do
+ slist <- case find ((== idData powner) . msgPeer . fromStored) (storedFromSList $ lsMessages $ fromStored erb) of
+ Just thread -> do thread' <- wrappedStore st (fromStored thread) { msgHead = smsg : msgHead (fromStored thread) }
+ slistReplaceS thread thread' $ lsMessages $ fromStored erb
+ Nothing -> slistAdd (emptyDirectThread powner) { msgHead = [smsg] } $ lsMessages $ fromStored erb
+ wrappedStore st (fromStored erb) { lsMessages = slist }
+ modify $ \s -> s { svcLocal = erb' }
+ return Nothing
+
+ | otherwise -> do svcPrint "Owner mismatch"
+ return Nothing
+
+formatMessage :: TimeZone -> DirectMessage -> String
+formatMessage tzone msg = concat
+ [ formatTime defaultTimeLocale "[%H:%M] " $ utcToLocalTime tzone $ zonedTimeToUTC $ msgTime msg
+ , maybe "<unnamed>" T.unpack $ iddName $ fromStored $ signedData $ fromStored $ msgFrom msg
+ , ": "
+ , T.unpack $ msgText msg
+ ]
diff --git a/src/Network.hs b/src/Network.hs
index 5d86a24..bff793a 100644
--- a/src/Network.hs
+++ b/src/Network.hs
@@ -4,6 +4,7 @@ module Network (
PeerIdentity(..), peerIdentityRef,
PeerChannel(..),
WaitingRef, wrDigest,
+ Service(..),
startServer,
sendToPeer,
) where
@@ -14,8 +15,6 @@ import Control.Monad
import Control.Monad.Except
import Control.Monad.State
-import Crypto.Random
-
import qualified Data.ByteString.Char8 as BC
import qualified Data.ByteString.Lazy as BL
import qualified Data.Map as M
@@ -28,6 +27,7 @@ import Network.Socket.ByteString (recvFrom, sendTo)
import Channel
import Identity
import PubKey
+import Service
import Storage
@@ -43,6 +43,7 @@ data Peer = Peer
, peerSocket :: Socket
, peerStorage :: Storage
, peerInStorage :: PartialStorage
+ , peerServiceState :: M.Map T.Text SomeService
, peerServiceQueue :: [(T.Text, WaitingRef)]
, peerWaitingRefs :: [WaitingRef]
}
@@ -149,8 +150,8 @@ receivedWaitingRef nref wr@(WaitingRef _ _ mvar) = do
checkWaitingRef wr
-startServer :: (String -> IO ()) -> String -> UnifiedIdentity -> IO (Chan Peer, Chan (Peer, T.Text, Ref))
-startServer logd bhost identity = do
+startServer :: (String -> IO ()) -> String -> UnifiedIdentity -> [(T.Text, SomeService)] -> IO (Chan Peer)
+startServer logd bhost identity services = do
let sidentity = idData identity
chanPeer <- newChan
chanSvc <- newChan
@@ -191,6 +192,7 @@ startServer logd bhost identity = do
, peerSocket = sock
, peerStorage = pst
, peerInStorage = ist
+ , peerServiceState = M.empty
, peerServiceQueue = []
, peerWaitingRefs = []
}
@@ -220,7 +222,28 @@ startServer logd bhost identity = do
addr:_ <- getAddrInfo (Just hints) Nothing (Just discoveryPort)
bracket (open addr) close loop
- return (chanPeer, chanSvc)
+ void $ forkIO $ forever $ readChan chanSvc >>= \case
+ (peer, svc, ref)
+ | PeerIdentityFull peerId <- peerIdentity peer
+ , PeerIdentityFull peerOwnerId <- peerOwner peer
+ , DatagramAddress paddr <- peerAddress peer
+ -> case maybe (lookup svc services) Just $ M.lookup svc (peerServiceState peer) of
+ Nothing -> logd $ "unhandled service '" ++ T.unpack svc ++ "'"
+ Just (SomeService s) -> do
+ let inp = ServiceInput
+ { svcPeer = peerId, svcPeerOwner = peerOwnerId
+ , svcPrintOp = logd
+ }
+ (rsp, s') <- handleServicePacket (storedStorage sidentity) inp s (wrappedLoad ref)
+ modifyMVar_ peers $ return . M.adjust (\p -> p { peerServiceState = M.insert svc (SomeService s') $ peerServiceState p }) paddr
+ runExceptT (maybe (return ()) (sendToPeer identity peer svc) rsp) >>= \case
+ Left err -> logd $ "failed to send response to peer: " ++ show err
+ Right () -> return ()
+
+ | DatagramAddress paddr <- peerAddress peer -> do
+ logd $ "service packet from peer with incomplete identity " ++ show paddr
+
+ return chanPeer
type PacketHandler a = StateT PacketHandlerState (ExceptT String IO) a
@@ -452,7 +475,7 @@ handleServices chan = gets (peerServiceQueue . phPeer) >>= \case
updatePeer $ \p -> p { peerServiceQueue = queue' }
-sendToPeer :: (Storable a, MonadIO m, MonadError String m, MonadRandom m) => UnifiedIdentity -> Peer -> T.Text -> a -> m ()
+sendToPeer :: (Storable a, MonadIO m, MonadError String m) => UnifiedIdentity -> Peer -> T.Text -> a -> m ()
sendToPeer _ peer@Peer { peerChannel = ChannelEstablished ch } svc obj = do
let st = peerInStorage peer
ref <- liftIO $ store st obj
diff --git a/src/Service.hs b/src/Service.hs
new file mode 100644
index 0000000..667196d
--- /dev/null
+++ b/src/Service.hs
@@ -0,0 +1,58 @@
+module Service (
+ Service(..),
+ SomeService(..),
+
+ ServiceHandler,
+ ServiceInput(..), ServiceState(..),
+ handleServicePacket,
+
+ svcPrint,
+) where
+
+import Control.Monad.Except
+import Control.Monad.Reader
+import Control.Monad.State
+
+import Identity
+import State
+import Storage
+
+class Storable (ServicePacket s) => Service s where
+ type ServicePacket s :: *
+ emptyServiceState :: s
+ serviceHandler :: Stored (ServicePacket s) -> ServiceHandler s (Maybe (ServicePacket s))
+
+data SomeService = forall s. Service s => SomeService s
+
+data ServiceInput = ServiceInput
+ { svcPeer :: UnifiedIdentity
+ , svcPeerOwner :: UnifiedIdentity
+ , svcPrintOp :: String -> IO ()
+ }
+
+data ServiceState s = ServiceState
+ { svcValue :: s
+ , svcLocal :: Stored LocalState
+ }
+
+newtype ServiceHandler s a = ServiceHandler (ReaderT ServiceInput (StateT (ServiceState s) (ExceptT String IO)) a)
+ deriving (Functor, Applicative, Monad, MonadReader ServiceInput, MonadState (ServiceState s), MonadIO)
+
+handleServicePacket :: Service s => Storage -> ServiceInput -> s -> Stored (ServicePacket s) -> IO (Maybe (ServicePacket s), s)
+handleServicePacket st input svc packet = do
+ herb <- loadLocalState st
+ let erb = wrappedLoad $ headRef herb
+ sstate = ServiceState { svcValue = svc, svcLocal = erb }
+ ServiceHandler handler = serviceHandler packet
+ (runExceptT $ flip runStateT sstate $ flip runReaderT input $ handler) >>= \case
+ Left err -> do
+ svcPrintOp input $ "service failed: " ++ err
+ return (Nothing, svc)
+ Right (rsp, sstate')
+ | svcLocal sstate' == svcLocal sstate -> return (rsp, svcValue sstate')
+ | otherwise -> replaceHead (svcLocal sstate') (Right herb) >>= \case
+ Left _ -> handleServicePacket st input svc packet
+ Right _ -> return (rsp, svcValue sstate')
+
+svcPrint :: String -> ServiceHandler s ()
+svcPrint str = liftIO . ($str) =<< asks svcPrintOp