diff options
Diffstat (limited to 'src')
31 files changed, 1234 insertions, 455 deletions
diff --git a/src/Erebos/Attach.hs b/src/Erebos/Attach.hs index fad6197..b7c642f 100644 --- a/src/Erebos/Attach.hs +++ b/src/Erebos/Attach.hs @@ -59,7 +59,7 @@ instance PairingResult AttachIdentity where liftIO $ mapM_ storeKey $ catMaybes [ keyFromData sec pub | sec <- keys, pub <- pkeys ] identity' <- mergeIdentity $ updateIdentity [ lsIdentity $ fromStored slocal ] identity - shared <- makeSharedStateUpdate st (Just owner) (lsShared $ fromStored slocal) + shared <- makeSharedStateUpdate (Just owner) (lsShared $ fromStored slocal) mstore (fromStored slocal) { lsIdentity = idExtData identity' , lsShared = [ shared ] diff --git a/src/Erebos/Chatroom.hs b/src/Erebos/Chatroom.hs index 74456ff..5a86b23 100644 --- a/src/Erebos/Chatroom.hs +++ b/src/Erebos/Chatroom.hs @@ -17,6 +17,7 @@ module Erebos.Chatroom ( joinChatroomAs, joinChatroomAsByStateData, leaveChatroom, leaveChatroomByStateData, getMessagesSinceState, + isSameChatroom, ChatroomSetChange(..), watchChatrooms, @@ -49,6 +50,7 @@ import Data.Set qualified as S import Data.Text (Text) import Data.Time +import Erebos.Conversation.Class import Erebos.Identity import Erebos.PubKey import Erebos.Service @@ -60,6 +62,15 @@ import Erebos.Storage.Merge import Erebos.Util +instance ConversationType ChatroomState ChatMessage where + convMessageFrom = cmsgFrom + convMessageTime = cmsgTime + convMessageText = cmsgText + + convMessageListSince mbSince cstate = + map (, False) $ threadToListSince (maybe [] roomStateMessageData mbSince) (roomStateMessageData cstate) + + data ChatroomData = ChatroomData { rdPrev :: [Stored (Signed ChatroomData)] , rdName :: Maybe Text @@ -294,8 +305,7 @@ createChatroom rdName rdDescription = do } updateLocalState $ updateSharedState $ \rooms -> do - st <- getStorage - (, cstate) <$> storeSetAdd st cstate rooms + (, cstate) <$> storeSetAdd cstate rooms findAndUpdateChatroomState :: (MonadStorage m, MonadHead LocalState m) @@ -309,8 +319,7 @@ findAndUpdateChatroomState f = do upd <- act if roomStateData orig /= roomStateData upd then do - st <- getStorage - roomSet' <- storeSetAdd st upd roomSet + roomSet' <- storeSetAdd upd roomSet return (roomSet', Just upd) else do return (roomSet, Just upd) @@ -422,6 +431,11 @@ leaveChatroomByStateData lookupData = sendRawChatroomMessageByStateData lookupDa getMessagesSinceState :: ChatroomState -> ChatroomState -> [ChatMessage] getMessagesSinceState cur old = threadToListSince (roomStateMessageData old) (roomStateMessageData cur) +isSameChatroom :: ChatroomState -> ChatroomState -> Bool +isSameChatroom rstate rstate' = + let roots = filterAncestors . concatMap storedRoots . roomStateData + in intersectsSorted (roots rstate) (roots rstate') + data ChatroomSetChange = AddedChatroom ChatroomState | RemovedChatroom ChatroomState diff --git a/src/Erebos/Contact.hs b/src/Erebos/Contact.hs index 88e6c44..78a504a 100644 --- a/src/Erebos/Contact.hs +++ b/src/Erebos/Contact.hs @@ -4,6 +4,8 @@ module Erebos.Contact ( contactCustomName, contactName, + ContactData(..), + contactSetName, ContactService, @@ -83,13 +85,12 @@ contactName c = fromJust $ msum contactSetName :: MonadHead LocalState m => Contact -> Text -> Set Contact -> m (Set Contact) contactSetName contact name set = do - st <- getStorage - cdata <- wrappedStore st ContactData + cdata <- mstore ContactData { cdPrev = toComponents contact , cdIdentity = [] , cdName = Just name } - storeSetAdd st (mergeSorted @Contact [cdata]) set + storeSetAdd (mergeSorted @Contact [cdata]) set type ContactService = PairingService ContactAccepted @@ -166,10 +167,9 @@ contactReject = pairingReject @ContactAccepted Proxy finalizeContact :: MonadHead LocalState m => UnifiedIdentity -> m () finalizeContact identity = updateLocalState_ $ updateSharedState_ $ \contacts -> do - st <- getStorage - cdata <- wrappedStore st ContactData + cdata <- mstore ContactData { cdPrev = [] , cdIdentity = idExtDataF $ finalOwner identity , cdName = Nothing } - storeSetAdd st (mergeSorted @Contact [cdata]) contacts + storeSetAdd (mergeSorted @Contact [cdata]) contacts diff --git a/src/Erebos/Conversation.hs b/src/Erebos/Conversation.hs index dee6faa..2c6f967 100644 --- a/src/Erebos/Conversation.hs +++ b/src/Erebos/Conversation.hs @@ -7,9 +7,11 @@ module Erebos.Conversation ( formatMessage, Conversation, + isSameConversation, directMessageConversation, chatroomConversation, chatroomConversationByStateData, + isChatroomStateConversation, reloadConversation, lookupConversations, @@ -31,47 +33,60 @@ import Data.Time.Format import Data.Time.LocalTime import Erebos.Chatroom +import Erebos.Conversation.Class import Erebos.DirectMessage import Erebos.Identity import Erebos.State import Erebos.Storable -data Message = DirectMessageMessage DirectMessage Bool - | ChatroomMessage ChatMessage Bool +data Message = forall conv msg. ConversationType conv msg => Message msg Bool + +withMessage :: (forall conv msg. ConversationType conv msg => msg -> a) -> Message -> a +withMessage f (Message msg _) = f msg messageFrom :: Message -> ComposedIdentity -messageFrom (DirectMessageMessage msg _) = msgFrom msg -messageFrom (ChatroomMessage msg _) = cmsgFrom msg +messageFrom = withMessage convMessageFrom messageTime :: Message -> ZonedTime -messageTime (DirectMessageMessage msg _) = msgTime msg -messageTime (ChatroomMessage msg _) = cmsgTime msg +messageTime = withMessage convMessageTime messageText :: Message -> Maybe Text -messageText (DirectMessageMessage msg _) = Just $ msgText msg -messageText (ChatroomMessage msg _) = cmsgText msg +messageText = withMessage convMessageText messageUnread :: Message -> Bool -messageUnread (DirectMessageMessage _ unread) = unread -messageUnread (ChatroomMessage _ unread) = unread +messageUnread (Message _ unread) = unread formatMessage :: TimeZone -> Message -> String formatMessage tzone msg = concat - [ formatTime defaultTimeLocale "[%H:%M] " $ utcToLocalTime tzone $ zonedTimeToUTC $ messageTime msg + [ if messageUnread msg then "\ESC[93m" else "" + , formatTime defaultTimeLocale "[%H:%M] " $ utcToLocalTime tzone $ zonedTimeToUTC $ messageTime msg , maybe "<unnamed>" T.unpack $ idName $ messageFrom msg , maybe "" ((": "<>) . T.unpack) $ messageText msg + , if messageUnread msg then "\ESC[0m" else "" ] -data Conversation = DirectMessageConversation DirectMessageThread - | ChatroomConversation ChatroomState +data Conversation + = DirectMessageConversation DirectMessageThread + | ChatroomConversation ChatroomState + +withConversation :: (forall conv msg. ConversationType conv msg => conv -> a) -> Conversation -> a +withConversation f (DirectMessageConversation conv) = f conv +withConversation f (ChatroomConversation conv) = f conv + +isSameConversation :: Conversation -> Conversation -> Bool +isSameConversation (DirectMessageConversation t) (DirectMessageConversation t') + = sameIdentity (msgPeer t) (msgPeer t') +isSameConversation (ChatroomConversation rstate) (ChatroomConversation rstate') = isSameChatroom rstate rstate' +isSameConversation _ _ = False directMessageConversation :: MonadHead LocalState m => ComposedIdentity -> m Conversation directMessageConversation peer = do - (find (sameIdentity peer . msgPeer) . toThreadList . lookupSharedValue . lsShared . fromStored <$> getLocalHead) >>= \case + createOrUpdateDirectMessagePeer peer + (find (sameIdentity peer . msgPeer) . dmThreadList . lookupSharedValue . lsShared . fromStored <$> getLocalHead) >>= \case Just thread -> return $ DirectMessageConversation thread - Nothing -> return $ DirectMessageConversation $ DirectMessageThread peer [] [] [] + Nothing -> return $ DirectMessageConversation $ DirectMessageThread peer [] [] [] [] chatroomConversation :: MonadHead LocalState m => ChatroomState -> m (Maybe Conversation) chatroomConversation rstate = chatroomConversationByStateData (head $ roomStateData rstate) @@ -79,13 +94,17 @@ chatroomConversation rstate = chatroomConversationByStateData (head $ roomStateD chatroomConversationByStateData :: MonadHead LocalState m => Stored ChatroomStateData -> m (Maybe Conversation) chatroomConversationByStateData sdata = fmap ChatroomConversation <$> findChatroomByStateData sdata +isChatroomStateConversation :: ChatroomState -> Conversation -> Bool +isChatroomStateConversation rstate (ChatroomConversation rstate') = isSameChatroom rstate rstate' +isChatroomStateConversation _ _ = False + reloadConversation :: MonadHead LocalState m => Conversation -> m Conversation reloadConversation (DirectMessageConversation thread) = directMessageConversation (msgPeer thread) reloadConversation cur@(ChatroomConversation rstate) = fromMaybe cur <$> chatroomConversation rstate -lookupConversations :: MonadHead LocalState m => m [Conversation] -lookupConversations = map DirectMessageConversation . toThreadList . lookupSharedValue . lsShared . fromStored <$> getLocalHead +lookupConversations :: MonadHead LocalState m => m [ Conversation ] +lookupConversations = map DirectMessageConversation . dmThreadList . lookupSharedValue . lsShared . fromStored <$> getLocalHead conversationName :: Conversation -> Text @@ -96,14 +115,13 @@ conversationPeer :: Conversation -> Maybe ComposedIdentity conversationPeer (DirectMessageConversation thread) = Just $ msgPeer thread conversationPeer (ChatroomConversation _) = Nothing -conversationHistory :: Conversation -> [Message] -conversationHistory (DirectMessageConversation thread) = map (\msg -> DirectMessageMessage msg False) $ threadToList thread -conversationHistory (ChatroomConversation rstate) = map (\msg -> ChatroomMessage msg False) $ roomStateMessages rstate +conversationHistory :: Conversation -> [ Message ] +conversationHistory = withConversation $ map (uncurry Message) . convMessageListSince Nothing -sendMessage :: (MonadHead LocalState m, MonadError e m, FromErebosError e) => Conversation -> Text -> m (Maybe Message) -sendMessage (DirectMessageConversation thread) text = fmap Just $ DirectMessageMessage <$> (fromStored <$> sendDirectMessage (msgPeer thread) text) <*> pure False -sendMessage (ChatroomConversation rstate) text = sendChatroomMessage rstate text >> return Nothing +sendMessage :: (MonadHead LocalState m, MonadError e m, FromErebosError e) => Conversation -> Text -> m () +sendMessage (DirectMessageConversation thread) text = sendDirectMessage (msgPeer thread) text +sendMessage (ChatroomConversation rstate) text = sendChatroomMessage rstate text deleteConversation :: (MonadHead LocalState m, MonadError e m, FromErebosError e) => Conversation -> m () deleteConversation (DirectMessageConversation _) = throwOtherError "deleting direct message conversation is not supported" diff --git a/src/Erebos/Conversation/Class.hs b/src/Erebos/Conversation/Class.hs new file mode 100644 index 0000000..6a28651 --- /dev/null +++ b/src/Erebos/Conversation/Class.hs @@ -0,0 +1,16 @@ +module Erebos.Conversation.Class ( + ConversationType(..), +) where + +import Data.Text (Text) +import Data.Time.LocalTime +import Data.Typeable + +import Erebos.Identity + + +class (Typeable conv, Typeable msg) => ConversationType conv msg | conv -> msg, msg -> conv where + convMessageFrom :: msg -> ComposedIdentity + convMessageTime :: msg -> ZonedTime + convMessageText :: msg -> Maybe Text + convMessageListSince :: Maybe conv -> conv -> [ ( msg, Bool ) ] diff --git a/src/Erebos/DirectMessage.hs b/src/Erebos/DirectMessage.hs index 05da865..dd10d35 100644 --- a/src/Erebos/DirectMessage.hs +++ b/src/Erebos/DirectMessage.hs @@ -1,43 +1,62 @@ module Erebos.DirectMessage ( DirectMessage(..), sendDirectMessage, + dmMarkAsSeen, + updateDirectMessagePeer, + createOrUpdateDirectMessagePeer, DirectMessageAttributes(..), defaultDirectMessageAttributes, DirectMessageThreads, - toThreadList, + dmThreadList, DirectMessageThread(..), - threadToList, - messageThreadView, + dmThreadToList, dmThreadToListSince, dmThreadToListUnread, dmThreadToListSinceUnread, + dmThreadView, - watchReceivedMessages, + watchDirectMessageThreads, formatDirectMessage, ) where +import Control.Concurrent.MVar import Control.Monad +import Control.Monad.Except import Control.Monad.Reader import Data.List import Data.Ord -import qualified Data.Set as S +import Data.Set (Set) +import Data.Set qualified as S import Data.Text (Text) -import qualified Data.Text as T +import Data.Text qualified as T import Data.Time.Format import Data.Time.LocalTime +import Erebos.Conversation.Class +import Erebos.Discovery import Erebos.Identity import Erebos.Network +import Erebos.Object import Erebos.Service import Erebos.State import Erebos.Storable import Erebos.Storage.Head import Erebos.Storage.Merge + +instance ConversationType DirectMessageThread DirectMessage where + convMessageFrom = msgFrom + convMessageTime = msgTime + convMessageText = Just . msgText + + convMessageListSince mbSince thread = + threadToListHelper (msgSeen thread) (maybe S.empty (S.fromAscList . msgHead) mbSince) (msgHead thread) + + data DirectMessage = DirectMessage { msgFrom :: ComposedIdentity - , msgPrev :: [Stored DirectMessage] + , msgPrev :: [ Stored DirectMessage ] , msgTime :: ZonedTime , msgText :: Text } @@ -74,7 +93,6 @@ instance Service DirectMessage where let msg = fromStored smsg powner <- asks $ finalOwner . svcPeerIdentity erb <- svcGetLocal - st <- getStorage let DirectMessageThreads prev _ = lookupSharedValue $ lsShared $ fromStored erb sent = findMsgProperty powner msSent prev received = findMsgProperty powner msReceived prev @@ -83,7 +101,7 @@ instance Service DirectMessage where filterAncestors sent == filterAncestors (smsg : sent) then do when (received' /= received) $ do - next <- wrappedStore st $ MessageState + next <- mstore MessageState { msPrev = prev , msPeer = powner , msReady = [] @@ -91,37 +109,43 @@ instance Service DirectMessage where , msReceived = received' , msSeen = [] } - let threads = DirectMessageThreads [next] (messageThreadView [next]) - shared <- makeSharedStateUpdate st threads (lsShared $ fromStored erb) - svcSetLocal =<< wrappedStore st (fromStored erb) { lsShared = [shared] } + let threads = DirectMessageThreads [ next ] (dmThreadView [ next ]) + shared <- makeSharedStateUpdate threads (lsShared $ fromStored erb) + svcSetLocal =<< mstore (fromStored erb) { lsShared = [ shared ] } when (powner `sameIdentity` msgFrom msg) $ do replyStoredRef smsg else join $ asks $ dmOwnerMismatch . svcAttributes - serviceNewPeer = syncDirectMessageToPeer . lookupSharedValue . lsShared . fromStored =<< svcGetLocal + serviceNewPeer = do + syncDirectMessageToPeer . lookupSharedValue . lsShared . fromStored =<< svcGetLocal - serviceStorageWatchers _ = (:[]) $ - SomeStorageWatcher (lookupSharedValue . lsShared . fromStored) syncDirectMessageToPeer + serviceUpdatedPeer = do + updateDirectMessagePeer . finalOwner =<< asks svcPeerIdentity + + serviceStorageWatchers _ = + [ SomeStorageWatcher (lookupSharedValue . lsShared . fromStored) syncDirectMessageToPeer + , GlobalStorageWatcher (lookupSharedValue . lsShared . fromStored) findMissingPeers + ] data MessageState = MessageState - { msPrev :: [Stored MessageState] + { msPrev :: [ Stored MessageState ] , msPeer :: ComposedIdentity - , msReady :: [Stored DirectMessage] - , msSent :: [Stored DirectMessage] - , msReceived :: [Stored DirectMessage] - , msSeen :: [Stored DirectMessage] + , msReady :: [ Stored DirectMessage ] + , msSent :: [ Stored DirectMessage ] + , msReceived :: [ Stored DirectMessage ] + , msSeen :: [ Stored DirectMessage ] } -data DirectMessageThreads = DirectMessageThreads [Stored MessageState] [DirectMessageThread] +data DirectMessageThreads = DirectMessageThreads [ Stored MessageState ] [ DirectMessageThread ] instance Eq DirectMessageThreads where DirectMessageThreads mss _ == DirectMessageThreads mss' _ = mss == mss' -toThreadList :: DirectMessageThreads -> [DirectMessageThread] -toThreadList (DirectMessageThreads _ threads) = threads +dmThreadList :: DirectMessageThreads -> [ DirectMessageThread ] +dmThreadList (DirectMessageThreads _ threads) = threads instance Storable MessageState where store' MessageState {..} = storeRec $ do @@ -143,13 +167,13 @@ instance Storable MessageState where instance Mergeable DirectMessageThreads where type Component DirectMessageThreads = MessageState - mergeSorted mss = DirectMessageThreads mss (messageThreadView mss) + mergeSorted mss = DirectMessageThreads mss (dmThreadView mss) toComponents (DirectMessageThreads mss _) = mss instance SharedType DirectMessageThreads where sharedTypeID _ = mkSharedTypeID "ee793681-5976-466a-b0f0-4e1907d3fade" -findMsgProperty :: Foldable m => Identity m -> (MessageState -> [a]) -> [Stored MessageState] -> [a] +findMsgProperty :: Foldable m => Identity m -> (MessageState -> [ a ]) -> [ Stored MessageState ] -> [ a ] findMsgProperty pid sel mss = concat $ flip findProperty mss $ \x -> do guard $ msPeer x `sameIdentity` pid guard $ not $ null $ sel x @@ -157,11 +181,11 @@ findMsgProperty pid sel mss = concat $ flip findProperty mss $ \x -> do sendDirectMessage :: (Foldable f, Applicative f, MonadHead LocalState m) - => Identity f -> Text -> m (Stored DirectMessage) -sendDirectMessage pid text = updateLocalState $ \ls -> do + => Identity f -> Text -> m () +sendDirectMessage pid text = updateLocalState_ $ \ls -> do let self = localIdentity $ fromStored ls powner = finalOwner pid - flip updateSharedState ls $ \(DirectMessageThreads prev _) -> do + flip updateSharedState_ ls $ \(DirectMessageThreads prev _) -> do let ready = findMsgProperty powner msReady prev received = findMsgProperty powner msReceived prev @@ -175,12 +199,69 @@ sendDirectMessage pid text = updateLocalState $ \ls -> do next <- mstore MessageState { msPrev = prev , msPeer = powner - , msReady = [smsg] + , msReady = [ smsg ] , msSent = [] , msReceived = [] , msSeen = [] } - return (DirectMessageThreads [next] (messageThreadView [next]), smsg) + return $ DirectMessageThreads [ next ] (dmThreadView [ next ]) + +dmMarkAsSeen + :: (Foldable f, Applicative f, MonadHead LocalState m) + => Identity f -> m () +dmMarkAsSeen pid = do + updateLocalState_ $ updateSharedState_ $ \(DirectMessageThreads prev _) -> do + let powner = finalOwner pid + received = findMsgProperty powner msReceived prev + next <- mstore MessageState + { msPrev = prev + , msPeer = powner + , msReady = [] + , msSent = [] + , msReceived = [] + , msSeen = received + } + return $ DirectMessageThreads [ next ] (dmThreadView [ next ]) + +updateDirectMessagePeer + :: (Foldable f, Applicative f, MonadHead LocalState m) + => Identity f -> m () +updateDirectMessagePeer = createOrUpdateDirectMessagePeer' False + +createOrUpdateDirectMessagePeer + :: (Foldable f, Applicative f, MonadHead LocalState m) + => Identity f -> m () +createOrUpdateDirectMessagePeer = createOrUpdateDirectMessagePeer' True + +createOrUpdateDirectMessagePeer' + :: (Foldable f, Applicative f, MonadHead LocalState m) + => Bool -> Identity f -> m () +createOrUpdateDirectMessagePeer' create pid = do + let powner = finalOwner pid + updateLocalState_ $ updateSharedState_ $ \old@(DirectMessageThreads prev threads) -> do + let updatePeerThread = do + next <- mstore MessageState + { msPrev = prev + , msPeer = powner + , msReady = [] + , msSent = [] + , msReceived = [] + , msSeen = [] + } + return $ DirectMessageThreads [ next ] (dmThreadView [ next ]) + case find (sameIdentity powner . msgPeer) threads of + Nothing + | create + -> updatePeerThread + + Just thread + | oldPeer <- msgPeer thread + , newPeer <- updateIdentity (idExtDataF powner) oldPeer + , oldPeer /= newPeer + -> updatePeerThread + + _ -> return old + syncDirectMessageToPeer :: DirectMessageThreads -> ServiceHandler DirectMessage () syncDirectMessageToPeer (DirectMessageThreads mss _) = do @@ -205,28 +286,47 @@ syncDirectMessageToPeer (DirectMessageThreads mss _) = do , msReceived = [] , msSeen = [] } - return $ DirectMessageThreads [next] (messageThreadView [next]) + return $ DirectMessageThreads [ next ] (dmThreadView [ next ]) else do return unchanged +findMissingPeers :: Server -> DirectMessageThreads -> ExceptT ErebosError IO () +findMissingPeers server threads = do + forM_ (dmThreadList threads) $ \thread -> do + when (msgHead thread /= msgReceived thread) $ do + mapM_ (discoverySearch server) $ map (refDigest . storedRef) $ idDataF $ msgPeer thread + data DirectMessageThread = DirectMessageThread { msgPeer :: ComposedIdentity - , msgHead :: [Stored DirectMessage] - , msgSent :: [Stored DirectMessage] - , msgSeen :: [Stored DirectMessage] + , msgHead :: [ Stored DirectMessage ] + , msgSent :: [ Stored DirectMessage ] + , msgSeen :: [ Stored DirectMessage ] + , msgReceived :: [ Stored DirectMessage ] } -threadToList :: DirectMessageThread -> [DirectMessage] -threadToList thread = helper S.empty $ msgHead thread - where helper seen msgs - | msg : msgs' <- filter (`S.notMember` seen) $ reverse $ sortBy (comparing cmpView) msgs = - fromStored msg : helper (S.insert msg seen) (msgs' ++ msgPrev (fromStored msg)) - | otherwise = [] - cmpView msg = (zonedTimeToUTC $ msgTime $ fromStored msg, msg) +dmThreadToList :: DirectMessageThread -> [ DirectMessage ] +dmThreadToList thread = map fst $ threadToListHelper (msgSeen thread) S.empty $ msgHead thread + +dmThreadToListSince :: DirectMessageThread -> DirectMessageThread -> [ DirectMessage ] +dmThreadToListSince since thread = map fst $ threadToListHelper (msgSeen thread) (S.fromAscList $ msgHead since) (msgHead thread) + +dmThreadToListUnread :: DirectMessageThread -> [ ( DirectMessage, Bool ) ] +dmThreadToListUnread thread = threadToListHelper (msgSeen thread) S.empty $ msgHead thread + +dmThreadToListSinceUnread :: DirectMessageThread -> DirectMessageThread -> [ ( DirectMessage, Bool ) ] +dmThreadToListSinceUnread since thread = threadToListHelper (msgSeen thread) (S.fromAscList $ msgHead since) (msgHead thread) + +threadToListHelper :: [ Stored DirectMessage ] -> Set (Stored DirectMessage) -> [ Stored DirectMessage ] -> [ ( DirectMessage, Bool ) ] +threadToListHelper seen used msgs + | msg : msgs' <- filter (`S.notMember` used) $ reverse $ sortBy (comparing cmpView) msgs = + ( fromStored msg, not $ any (msg `precedesOrEquals`) seen ) : threadToListHelper seen (S.insert msg used) (msgs' ++ msgPrev (fromStored msg)) + | otherwise = [] + where + cmpView msg = (zonedTimeToUTC $ msgTime $ fromStored msg, msg) -messageThreadView :: [Stored MessageState] -> [DirectMessageThread] -messageThreadView = helper [] +dmThreadView :: [ Stored MessageState ] -> [ DirectMessageThread ] +dmThreadView = helper [] where helper used ms' = case filterAncestors ms' of mss@(sms : rest) | any (sameIdentity $ msPeer $ fromStored sms) used -> @@ -236,7 +336,7 @@ messageThreadView = helper [] in messageThreadFor peer mss : helper (peer : used) (msPrev (fromStored sms) ++ rest) _ -> [] -messageThreadFor :: ComposedIdentity -> [Stored MessageState] -> DirectMessageThread +messageThreadFor :: ComposedIdentity -> [ Stored MessageState ] -> DirectMessageThread messageThreadFor peer mss = let ready = findMsgProperty peer msReady mss sent = findMsgProperty peer msSent mss @@ -248,15 +348,28 @@ messageThreadFor peer mss = , msgHead = filterAncestors $ ready ++ received , msgSent = filterAncestors $ sent ++ received , msgSeen = filterAncestors $ ready ++ seen + , msgReceived = filterAncestors $ received } -watchReceivedMessages :: Head LocalState -> (Stored DirectMessage -> IO ()) -> IO WatchedHead -watchReceivedMessages h f = do - let self = finalOwner $ localIdentity $ headObject h +watchDirectMessageThreads :: Head LocalState -> (DirectMessageThread -> DirectMessageThread -> IO ()) -> IO WatchedHead +watchDirectMessageThreads h f = do + prevVar <- newMVar Nothing watchHeadWith h (lookupSharedValue . lsShared . headObject) $ \(DirectMessageThreads sms _) -> do - forM_ (map fromStored sms) $ \ms -> do - mapM_ f $ filter (not . sameIdentity self . msgFrom . fromStored) $ msReceived ms + modifyMVar_ prevVar $ \case + Just prev -> do + let addPeer (p : ps) p' + | p `sameIdentity` p' = p : ps + | otherwise = p : addPeer ps p' + addPeer [] p' = [ p' ] + + let peers = foldl' addPeer [] $ map (msPeer . fromStored) $ storedDifference prev sms + forM_ peers $ \peer -> do + f (messageThreadFor peer prev) (messageThreadFor peer sms) + return (Just sms) + + Nothing -> do + return (Just sms) formatDirectMessage :: TimeZone -> DirectMessage -> String formatDirectMessage tzone msg = concat diff --git a/src/Erebos/Discovery.hs b/src/Erebos/Discovery.hs index 2fb0ffe..5590e4c 100644 --- a/src/Erebos/Discovery.hs +++ b/src/Erebos/Discovery.hs @@ -1,4 +1,5 @@ {-# LANGUAGE CPP #-} +{-# LANGUAGE OverloadedStrings #-} module Erebos.Discovery ( DiscoveryService(..), @@ -6,6 +7,7 @@ module Erebos.Discovery ( DiscoveryConnection(..), discoverySearch, + discoverySetupTunnel, ) where import Control.Concurrent @@ -13,7 +15,6 @@ import Control.Monad import Control.Monad.Except import Control.Monad.Reader -import Data.IP qualified as IP import Data.List import Data.Map.Strict (Map) import Data.Map.Strict qualified as M @@ -25,15 +26,17 @@ import Data.Text (Text) import Data.Text qualified as T import Data.Word -import Network.Socket +import Text.Read #ifdef ENABLE_ICE_SUPPORT import Erebos.ICE #endif import Erebos.Identity import Erebos.Network +import Erebos.Network.Address import Erebos.Object import Erebos.Service +import Erebos.Service.Stream import Erebos.Storable @@ -45,18 +48,25 @@ type IceRemoteInfo = Stored Object data DiscoveryService - = DiscoverySelf [ Text ] (Maybe Int) - | DiscoveryAcknowledged [ Text ] (Maybe Text) (Maybe Word16) (Maybe Text) (Maybe Word16) + = DiscoverySelf [ DiscoveryAddress ] (Maybe Int) + | DiscoveryAcknowledged [ DiscoveryAddress ] (Maybe Text) (Maybe Word16) (Maybe Text) (Maybe Word16) | DiscoverySearch (Either Ref RefDigest) - | DiscoveryResult (Either Ref RefDigest) [ Text ] + | DiscoveryResult (Either Ref RefDigest) [ DiscoveryAddress ] | DiscoveryConnectionRequest DiscoveryConnection | DiscoveryConnectionResponse DiscoveryConnection +data DiscoveryAddress + = DiscoveryIP InetAddress PortNumber + | DiscoveryICE + | DiscoveryTunnel + | DiscoveryOther Text + data DiscoveryAttributes = DiscoveryAttributes { discoveryStunPort :: Maybe Word16 , discoveryStunServer :: Maybe Text , discoveryTurnPort :: Maybe Word16 , discoveryTurnServer :: Maybe Text + , discoveryProvideTunnel :: Peer -> PeerAddress -> Bool } defaultDiscoveryAttributes :: DiscoveryAttributes @@ -65,12 +75,14 @@ defaultDiscoveryAttributes = DiscoveryAttributes , discoveryStunServer = Nothing , discoveryTurnPort = Nothing , discoveryTurnServer = Nothing + , discoveryProvideTunnel = \_ _ -> False } data DiscoveryConnection = DiscoveryConnection { dconnSource :: Either Ref RefDigest , dconnTarget :: Either Ref RefDigest , dconnAddress :: Maybe Text + , dconnTunnel :: Bool , dconnIceInfo :: Maybe IceRemoteInfo } @@ -78,6 +90,7 @@ emptyConnection :: Either Ref RefDigest -> Either Ref RefDigest -> DiscoveryConn emptyConnection dconnSource dconnTarget = DiscoveryConnection {..} where dconnAddress = Nothing + dconnTunnel = False dconnIceInfo = Nothing instance Storable DiscoveryService where @@ -101,11 +114,12 @@ instance Storable DiscoveryService where DiscoveryConnectionResponse conn -> storeConnection "response" conn where - storeConnection ctype DiscoveryConnection {..} = do + storeConnection (ctype :: Text) DiscoveryConnection {..} = do storeText "connection" $ ctype either (storeRawRef "source") (storeRawWeak "source") dconnSource either (storeRawRef "target") (storeRawWeak "target") dconnTarget storeMbText "address" dconnAddress + when dconnTunnel $ storeEmpty "tunnel" storeMbRef "ice-info" dconnIceInfo load' = loadRec $ msum @@ -138,7 +152,7 @@ instance Storable DiscoveryService where , loadConnection "response" DiscoveryConnectionResponse ] where - loadConnection ctype ctor = do + loadConnection (ctype :: Text) ctor = do ctype' <- loadText "connection" guard $ ctype == ctype' dconnSource <- msum @@ -150,13 +164,37 @@ instance Storable DiscoveryService where , Right <$> loadRawWeak "target" ] dconnAddress <- loadMbText "address" + dconnTunnel <- isJust <$> loadMbEmpty "tunnel" dconnIceInfo <- loadMbRef "ice-info" return $ ctor DiscoveryConnection {..} +instance StorableText DiscoveryAddress where + toText = \case + DiscoveryIP addr port -> T.unwords [ T.pack $ show addr, T.pack $ show port ] + DiscoveryICE -> "ICE" + DiscoveryTunnel -> "tunnel" + DiscoveryOther str -> str + + fromText str = return $ if + | [ addrStr, portStr ] <- T.words str + , Just addr <- readMaybe $ T.unpack addrStr + , Just port <- readMaybe $ T.unpack portStr + -> DiscoveryIP addr port + + | "ice" <- T.toLower str + -> DiscoveryICE + + | "tunnel" <- str + -> DiscoveryTunnel + + | otherwise + -> DiscoveryOther str + + data DiscoveryPeer = DiscoveryPeer { dpPriority :: Int , dpPeer :: Maybe Peer - , dpAddress :: [ Text ] + , dpAddress :: [ DiscoveryAddress ] , dpIceSession :: Maybe IceSession } @@ -169,7 +207,11 @@ emptyPeer = DiscoveryPeer } data DiscoveryPeerState = DiscoveryPeerState - { dpsStunServer :: Maybe ( Text, Word16 ) + { dpsOurTunnelRequests :: [ ( RefDigest, StreamWriter ) ] + -- ( original target, our write stream ) + , dpsRelayedTunnelRequests :: [ ( RefDigest, ( StreamReader, StreamWriter )) ] + -- ( original source, ( from source, to target )) + , dpsStunServer :: Maybe ( Text, Word16 ) , dpsTurnServer :: Maybe ( Text, Word16 ) , dpsIceConfig :: Maybe IceConfig } @@ -187,7 +229,9 @@ instance Service DiscoveryService where type ServiceState DiscoveryService = DiscoveryPeerState emptyServiceState _ = DiscoveryPeerState - { dpsStunServer = Nothing + { dpsOurTunnelRequests = [] + , dpsRelayedTunnelRequests = [] + , dpsStunServer = Nothing , dpsTurnServer = Nothing , dpsIceConfig = Nothing } @@ -202,26 +246,22 @@ instance Service DiscoveryService where DiscoverySelf addrs priority -> do pid <- asks svcPeerIdentity peer <- asks svcPeer + paddrs <- getPeerAddresses peer + let insertHelper new old | dpPriority new > dpPriority old = new | otherwise = old - matchedAddrs <- fmap catMaybes $ forM addrs $ \addr -> if - | addr == T.pack "ICE" -> do - return $ Just addr - | [ ipaddr, port ] <- words (T.unpack addr) - , DatagramAddress paddr <- peerAddress peer -> do - saddr <- liftIO $ head <$> getAddrInfo (Just $ defaultHints { addrSocketType = Datagram }) (Just ipaddr) (Just port) - return $ if paddr == addrAddress saddr - then Just addr - else Nothing - - | otherwise -> return Nothing + let matchedAddrs = flip filter addrs $ \case + DiscoveryICE -> True + DiscoveryIP ipaddr port -> + DatagramAddress (inetToSockAddr ( ipaddr, port )) `elem` paddrs + _ -> False forM_ (idDataF =<< unfoldOwners pid) $ \sdata -> do let dp = DiscoveryPeer { dpPriority = fromMaybe 0 priority , dpPeer = Just peer - , dpAddress = addrs + , dpAddress = matchedAddrs , dpIceSession = Nothing } svcModifyGlobal $ \s -> s { dgsPeers = M.insertWith insertHelper (refDigest $ storedRef sdata) dp $ dgsPeers s } @@ -233,14 +273,8 @@ instance Service DiscoveryService where (discoveryTurnPort attrs) DiscoveryAcknowledged _ stunServer stunPort turnServer turnPort -> do - paddr <- asks (peerAddress . svcPeer) >>= return . \case - (DatagramAddress saddr) -> case IP.fromSockAddr saddr of - Just (IP.IPv6 ipv6, _) - | (0, 0, 0xffff, ipv4) <- IP.fromIPv6w ipv6 - -> Just $ T.pack $ show (IP.toIPv4w ipv4) - Just (addr, _) - -> Just $ T.pack $ show addr - _ -> Nothing + paddr <- asks svcPeerAddress >>= return . \case + (DatagramAddress saddr) -> T.pack . show . fst <$> inetFromSockAddr saddr _ -> Nothing let toIceServer Nothing Nothing = Nothing @@ -255,10 +289,17 @@ instance Service DiscoveryService where DiscoverySearch edgst -> do dpeer <- M.lookup (either refDigest id edgst) . dgsPeers <$> svcGetGlobal - replyPacket $ DiscoveryResult edgst $ maybe [] dpAddress dpeer + peer <- asks svcPeer + paddr <- asks svcPeerAddress + attrs <- asks svcAttributes + let offerTunnel + | discoveryProvideTunnel attrs peer paddr = (++ [ DiscoveryTunnel ]) + | otherwise = id + replyPacket $ DiscoveryResult edgst $ maybe [] (offerTunnel . dpAddress) dpeer - DiscoveryResult edgst [] -> do - svcPrint $ "Discovery: " ++ show (either refDigest id edgst) ++ " not found" + DiscoveryResult _ [] -> do + -- not found + return () DiscoveryResult edgst addrs -> do let dgst = either refDigest id edgst @@ -269,56 +310,82 @@ instance Service DiscoveryService where discoveryPeer <- asks svcPeer let runAsService = runPeerService @DiscoveryService discoveryPeer - forM_ addrs $ \addr -> if - | addr == T.pack "ICE" - -> do -#ifdef ENABLE_ICE_SUPPORT - getIceConfig >>= \case - Just config -> void $ liftIO $ forkIO $ do - ice <- iceCreateSession config PjIceSessRoleControlling $ \ice -> do - rinfo <- iceRemoteInfo ice - - -- Try to promote weak ref to normal one for older peers: - edgst' <- case edgst of - Left r -> return (Left r) - Right d -> refFromDigest st d >>= \case - Just r -> return (Left r) - Nothing -> return (Right d) - - res <- runExceptT $ sendToPeer discoveryPeer $ - DiscoveryConnectionRequest (emptyConnection (Left $ storedRef $ idData self) edgst') { dconnIceInfo = Just rinfo } - case res of - Right _ -> return () - Left err -> putStrLn $ "Discovery: failed to send connection request: " ++ err - + let tryAddresses = \case + DiscoveryIP ipaddr port : _ -> do + void $ liftIO $ forkIO $ do + let saddr = inetToSockAddr ( ipaddr, port ) + peer <- serverPeer server saddr runAsService $ do - let upd dp = dp { dpIceSession = Just ice } + let upd dp = dp { dpPeer = Just peer } svcModifyGlobal $ \s -> s { dgsPeers = M.alter (Just . upd . fromMaybe emptyPeer) dgst $ dgsPeers s } - Nothing -> do - return () + DiscoveryICE : rest -> do +#ifdef ENABLE_ICE_SUPPORT + getIceConfig >>= \case + Just config -> do + void $ liftIO $ forkIO $ do + ice <- iceCreateSession config PjIceSessRoleControlling $ \ice -> do + rinfo <- iceRemoteInfo ice + + -- Try to promote weak ref to normal one for older peers: + edgst' <- case edgst of + Left r -> return (Left r) + Right d -> refFromDigest st d >>= \case + Just r -> return (Left r) + Nothing -> return (Right d) + + res <- runExceptT $ sendToPeer discoveryPeer $ + DiscoveryConnectionRequest (emptyConnection (Left $ storedRef $ idData self) edgst') { dconnIceInfo = Just rinfo } + case res of + Right _ -> return () + Left err -> putStrLn $ "Discovery: failed to send connection request: " ++ err + + runAsService $ do + let upd dp = dp { dpIceSession = Just ice } + svcModifyGlobal $ \s -> s { dgsPeers = M.alter (Just . upd . fromMaybe emptyPeer) dgst $ dgsPeers s } + + Nothing -> do #endif - return () - - | [ ipaddr, port ] <- words (T.unpack addr) -> do - void $ liftIO $ forkIO $ do - saddr <- head <$> - getAddrInfo (Just $ defaultHints { addrSocketType = Datagram }) (Just ipaddr) (Just port) - peer <- serverPeer server (addrAddress saddr) - runAsService $ do - let upd dp = dp { dpPeer = Just peer } - svcModifyGlobal $ \s -> s { dgsPeers = M.alter (Just . upd . fromMaybe emptyPeer) dgst $ dgsPeers s } + tryAddresses rest - | otherwise -> do - svcPrint $ "Discovery: invalid address in result: " ++ T.unpack addr + DiscoveryTunnel : _ -> do + discoverySetupTunnelResponse dgst + + addr : rest -> do + svcPrint $ "Discovery: unsupported address in result: " ++ T.unpack (toText addr) + tryAddresses rest + + [] -> svcPrint $ "Discovery: no (supported) address received for " <> show dgst + + tryAddresses addrs DiscoveryConnectionRequest conn -> do self <- svcSelf + attrs <- asks svcAttributes let rconn = emptyConnection (dconnSource conn) (dconnTarget conn) if either refDigest id (dconnTarget conn) `elem` identityDigests self then if + -- request for us, create ICE sesssion or tunnel + | dconnTunnel conn -> do + receivedStreams >>= \case + (tunnelReader : _) -> do + tunnelWriter <- openStream + replyPacket $ DiscoveryConnectionResponse rconn + { dconnTunnel = True + } + tunnelVia <- asks svcPeer + tunnelIdentity <- asks svcPeerIdentity + server <- asks svcServer + void $ liftIO $ forkIO $ do + tunnelStreamNumber <- getStreamWriterNumber tunnelWriter + let addr = TunnelAddress {..} + void $ serverPeerCustom server addr + receiveFromTunnel server addr + + [] -> do + svcPrint $ "Discovery: missing stream on tunnel request (endpoint)" + #ifdef ENABLE_ICE_SUPPORT - -- request for us, create ICE sesssion | Just prinfo <- dconnIceInfo conn -> do server <- asks svcServer peer <- asks svcPeer @@ -338,31 +405,72 @@ instance Service DiscoveryService where svcPrint $ "Discovery: unsupported connection request" else do - -- request to some of our peers, relay - mbdp <- M.lookup (either refDigest id $ dconnTarget conn) . dgsPeers <$> svcGetGlobal - case mbdp of + -- request to some of our peers, relay + peer <- asks svcPeer + paddr <- asks svcPeerAddress + mbdp <- M.lookup (either refDigest id $ dconnTarget conn) . dgsPeers <$> svcGetGlobal + streams <- receivedStreams + case mbdp of Nothing -> replyPacket $ DiscoveryConnectionResponse rconn Just dp - | Just dpeer <- dpPeer dp -> do - sendToPeer dpeer $ DiscoveryConnectionRequest conn + | Just dpeer <- dpPeer dp -> if + | dconnTunnel conn -> if + | not (discoveryProvideTunnel attrs peer paddr) -> do + replyPacket $ DiscoveryConnectionResponse rconn + | fromSource : _ <- streams -> do + void $ liftIO $ forkIO $ runPeerService @DiscoveryService dpeer $ do + toTarget <- openStream + svcModify $ \s -> s { dpsRelayedTunnelRequests = + ( either refDigest id $ dconnSource conn, ( fromSource, toTarget )) : dpsRelayedTunnelRequests s } + replyPacket $ DiscoveryConnectionRequest conn + | otherwise -> do + svcPrint $ "Discovery: missing stream on tunnel request (relay)" + | otherwise -> do + sendToPeer dpeer $ DiscoveryConnectionRequest conn | otherwise -> svcPrint $ "Discovery: failed to relay connection request" DiscoveryConnectionResponse conn -> do self <- svcSelf + dps <- svcGet dpeers <- dgsPeers <$> svcGetGlobal + if either refDigest id (dconnSource conn) `elem` identityDigests self - then do + then do -- response to our request, try to connect to the peer server <- asks svcServer - if | Just addr <- dconnAddress conn - , [ipaddr, port] <- words (T.unpack addr) -> do - saddr <- liftIO $ head <$> - getAddrInfo (Just $ defaultHints { addrSocketType = Datagram }) (Just ipaddr) (Just port) - peer <- liftIO $ serverPeer server (addrAddress saddr) + if + | Just addr <- dconnAddress conn + , [ addrStr, portStr ] <- words (T.unpack addr) + , Just ipaddr <- readMaybe addrStr + , Just port <- readMaybe portStr + -> do + let saddr = inetToSockAddr ( ipaddr, port ) + peer <- liftIO $ serverPeer server saddr let upd dp = dp { dpPeer = Just peer } svcModifyGlobal $ \s -> s { dgsPeers = M.alter (Just . upd . fromMaybe emptyPeer) (either refDigest id $ dconnTarget conn) $ dgsPeers s } + | dconnTunnel conn + , Just tunnelWriter <- lookup (either refDigest id (dconnTarget conn)) (dpsOurTunnelRequests dps) + -> do + receivedStreams >>= \case + tunnelReader : _ -> do + tunnelVia <- asks svcPeer + tunnelIdentity <- asks svcPeerIdentity + void $ liftIO $ forkIO $ do + tunnelStreamNumber <- getStreamWriterNumber tunnelWriter + let addr = TunnelAddress {..} + void $ serverPeerCustom server addr + receiveFromTunnel server addr + [] -> do + svcPrint $ "Discovery: missing stream in tunnel response" + liftIO $ closeStream tunnelWriter + + | Just tunnelWriter <- lookup (either refDigest id (dconnTarget conn)) (dpsOurTunnelRequests dps) + -> do + svcPrint $ "Discovery: tunnel request failed" + liftIO $ closeStream tunnelWriter + #ifdef ENABLE_ICE_SUPPORT | Just dp <- M.lookup (either refDigest id $ dconnTarget conn) dpeers , Just ice <- dpIceSession dp @@ -371,24 +479,49 @@ instance Service DiscoveryService where #endif | otherwise -> svcPrint $ "Discovery: connection request failed" - else do - -- response to relayed request - case M.lookup (either refDigest id $ dconnSource conn) dpeers of - Just dp | Just dpeer <- dpPeer dp -> do + else do + -- response to relayed request + streams <- receivedStreams + svcModify $ \s -> s { dpsRelayedTunnelRequests = + filter ((either refDigest id (dconnSource conn) /=) . fst) (dpsRelayedTunnelRequests s) } + + case M.lookup (either refDigest id $ dconnSource conn) dpeers of + Just dp | Just dpeer <- dpPeer dp -> if + -- successful tunnel request + | dconnTunnel conn + , Just ( fromSource, toTarget ) <- lookup (either refDigest id (dconnSource conn)) (dpsRelayedTunnelRequests dps) + , fromTarget : _ <- streams + -> liftIO $ do + toSourceVar <- newEmptyMVar + void $ forkIO $ runPeerService @DiscoveryService dpeer $ do + liftIO . putMVar toSourceVar =<< openStream + svcModify $ \s -> s { dpsRelayedTunnelRequests = + ( either refDigest id $ dconnSource conn, ( fromSource, toTarget )) : dpsRelayedTunnelRequests s } + replyPacket $ DiscoveryConnectionResponse conn + void $ forkIO $ do + relayStream fromSource toTarget + void $ forkIO $ do + toSource <- readMVar toSourceVar + relayStream fromTarget toSource + + -- failed tunnel request + | Just ( _, toTarget ) <- lookup (either refDigest id (dconnSource conn)) (dpsRelayedTunnelRequests dps) + -> do + liftIO $ closeStream toTarget sendToPeer dpeer $ DiscoveryConnectionResponse conn - _ -> svcPrint $ "Discovery: failed to relay connection response" + + | otherwise -> do + sendToPeer dpeer $ DiscoveryConnectionResponse conn + _ -> svcPrint $ "Discovery: failed to relay connection response" serviceNewPeer = do server <- asks svcServer peer <- asks svcPeer - let addrToText saddr = do - ( addr, port ) <- IP.fromSockAddr saddr - Just $ T.pack $ show addr <> " " <> show port addrs <- concat <$> sequence - [ catMaybes . map addrToText <$> liftIO (getServerAddresses server) + [ catMaybes . map (fmap (uncurry DiscoveryIP) . inetFromSockAddr) <$> liftIO (getServerAddresses server) #ifdef ENABLE_ICE_SUPPORT - , return [ T.pack "ICE" ] + , return [ DiscoveryICE ] #endif ] @@ -437,7 +570,7 @@ discoverySearch :: (MonadIO m, MonadError e m, FromErebosError e) => Server -> R discoverySearch server dgst = do peers <- liftIO $ getCurrentPeerList server match <- forM peers $ \peer -> do - peerIdentity peer >>= \case + getPeerIdentity peer >>= \case PeerIdentityFull pid -> do return $ dgst `elem` identityDigests pid _ -> return False @@ -447,3 +580,70 @@ discoverySearch server dgst = do } forM_ peers $ \peer -> do sendToPeer peer $ DiscoverySearch $ Right dgst + + +data TunnelAddress = TunnelAddress + { tunnelVia :: Peer + , tunnelIdentity :: UnifiedIdentity + , tunnelStreamNumber :: Int + , tunnelReader :: StreamReader + , tunnelWriter :: StreamWriter + } + +instance Eq TunnelAddress where + x == y = (==) + (idData (tunnelIdentity x), tunnelStreamNumber x) + (idData (tunnelIdentity y), tunnelStreamNumber y) + +instance Ord TunnelAddress where + compare x y = compare + (idData (tunnelIdentity x), tunnelStreamNumber x) + (idData (tunnelIdentity y), tunnelStreamNumber y) + +instance Show TunnelAddress where + show tunnel = concat + [ "tunnel@" + , show $ refDigest $ storedRef $ idData $ tunnelIdentity tunnel + , "/" <> show (tunnelStreamNumber tunnel) + ] + +instance PeerAddressType TunnelAddress where + sendBytesToAddress TunnelAddress {..} bytes = do + writeStream tunnelWriter bytes + + connectionToAddressClosed TunnelAddress {..} = do + closeStream tunnelWriter + +relayStream :: StreamReader -> StreamWriter -> IO () +relayStream r w = do + p <- readStreamPacket r + writeStreamPacket w p + case p of + StreamClosed {} -> return () + _ -> relayStream r w + +receiveFromTunnel :: Server -> TunnelAddress -> IO () +receiveFromTunnel server taddr = do + p <- readStreamPacket (tunnelReader taddr) + case p of + StreamData {..} -> do + receivedFromCustomAddress server taddr stpData + receiveFromTunnel server taddr + StreamClosed {} -> do + return () + + +discoverySetupTunnel :: Peer -> RefDigest -> IO () +discoverySetupTunnel via target = do + runPeerService via $ do + discoverySetupTunnelResponse target + +discoverySetupTunnelResponse :: RefDigest -> ServiceHandler DiscoveryService () +discoverySetupTunnelResponse target = do + self <- refDigest . storedRef . idData <$> svcSelf + stream <- openStream + svcModify $ \s -> s { dpsOurTunnelRequests = ( target, stream ) : dpsOurTunnelRequests s } + replyPacket $ DiscoveryConnectionRequest + (emptyConnection (Right self) (Right target)) + { dconnTunnel = True + } diff --git a/src/Erebos/ICE.chs b/src/Erebos/ICE.chs index dceeb2c..a3dd9bc 100644 --- a/src/Erebos/ICE.chs +++ b/src/Erebos/ICE.chs @@ -16,7 +16,7 @@ module Erebos.ICE ( iceConnect, iceSend, - iceSetChan, + serverPeerIce, ) where import Control.Arrow @@ -32,7 +32,6 @@ import Data.Text (Text) import Data.Text qualified as T import Data.Text.Encoding qualified as T import Data.Text.Read qualified as T -import Data.Void import Data.Word import Foreign.C.String @@ -43,7 +42,7 @@ import Foreign.Marshal.Array import Foreign.Ptr import Foreign.StablePtr -import Erebos.Flow +import Erebos.Network import Erebos.Object import Erebos.Storable import Erebos.Storage @@ -53,7 +52,7 @@ import Erebos.Storage data IceSession = IceSession { isStrans :: PjIceStrans , _isConfig :: IceConfig - , isChan :: MVar (Either [ByteString] (Flow Void ByteString)) + , isChan :: MVar (Either [ ByteString ] (ByteString -> IO ())) } instance Eq IceSession where @@ -65,6 +64,10 @@ instance Ord IceSession where instance Show IceSession where show _ = "<ICE>" +instance PeerAddressType IceSession where + sendBytesToAddress = iceSend + connectionToAddressClosed = iceDestroy + data IceRemoteInfo = IceRemoteInfo { iriUsernameFrament :: Text @@ -126,9 +129,9 @@ instance StorableText IceCandidate where data PjIceStransCfg newtype IceConfig = IceConfig (ForeignPtr PjIceStransCfg) -foreign import ccall unsafe "pjproject.h &ice_cfg_free" +foreign import ccall unsafe "pjproject.h &erebos_ice_cfg_free" ice_cfg_free :: FunPtr (Ptr PjIceStransCfg -> IO ()) -foreign import ccall unsafe "pjproject.h ice_cfg_create" +foreign import ccall unsafe "pjproject.h erebos_ice_cfg_create" ice_cfg_create :: CString -> Word16 -> CString -> Word16 -> IO (Ptr PjIceStransCfg) iceCreateConfig :: Maybe ( Text, Word16 ) -> Maybe ( Text, Word16 ) -> IO (Maybe IceConfig) @@ -140,7 +143,7 @@ iceCreateConfig stun turn = then return Nothing else Just . IceConfig <$> newForeignPtr ice_cfg_free cfg -foreign import ccall unsafe "pjproject.h ice_cfg_stop_thread" +foreign import ccall unsafe "pjproject.h erebos_ice_cfg_stop_thread" ice_cfg_stop_thread :: Ptr PjIceStransCfg -> IO () iceStopThread :: IceConfig -> IO () @@ -158,13 +161,13 @@ iceCreateSession icfg@(IceConfig fcfg) role cb = do forkIO $ cb sess sess <- IceSession <$> (withForeignPtr fcfg $ \cfg -> - {#call ice_create #} (castPtr cfg) (fromIntegral $ fromEnum role) (castStablePtrToPtr sptr) (castStablePtrToPtr cbptr) + {#call erebos_ice_create #} (castPtr cfg) (fromIntegral $ fromEnum role) (castStablePtrToPtr sptr) (castStablePtrToPtr cbptr) ) <*> pure icfg <*> (newMVar $ Left []) return $ sess -{#fun ice_destroy as ^ { isStrans `IceSession' } -> `()' #} +{#fun erebos_ice_destroy as iceDestroy { isStrans `IceSession' } -> `()' #} iceRemoteInfo :: IceSession -> IO IceRemoteInfo iceRemoteInfo sess = do @@ -179,7 +182,7 @@ iceRemoteInfo sess = do let cptrs = take maxcand $ iterate (`plusPtr` maxlen) bytes pokeArray carr $ take maxcand cptrs - ncand <- {#call ice_encode_session #} (isStrans sess) ufrag pass def carr (fromIntegral maxlen) (fromIntegral maxcand) + ncand <- {#call erebos_ice_encode_session #} (isStrans sess) ufrag pass def carr (fromIntegral maxlen) (fromIntegral maxcand) if ncand < 0 then fail "failed to generate ICE remote info" else IceRemoteInfo <$> (T.pack <$> peekCString ufrag) @@ -196,13 +199,13 @@ iceShow sess = do iceConnect :: IceSession -> IceRemoteInfo -> (IO ()) -> IO () iceConnect sess remote cb = do cbptr <- newStablePtr $ cb - ice_connect sess cbptr + erebos_ice_connect sess cbptr (iriUsernameFrament remote) (iriPassword remote) (iriDefaultCandidate remote) (iriCandidates remote) -{#fun ice_connect { isStrans `IceSession', castStablePtrToPtr `StablePtr (IO ())', +{#fun erebos_ice_connect { isStrans `IceSession', castStablePtrToPtr `StablePtr (IO ())', withText* `Text', withText* `Text', withText* `Text', withTextArray* `[Text]'& } -> `()' #} withText :: Text -> (Ptr CChar -> IO a) -> IO a @@ -218,19 +221,19 @@ withTextArray tsAll f = helper tsAll [] withByteStringLen :: Num n => ByteString -> ((Ptr CChar, n) -> IO a) -> IO a withByteStringLen t f = unsafeUseAsCStringLen t (f . (id *** fromIntegral)) -{#fun ice_send as ^ { isStrans `IceSession', withByteStringLen* `ByteString'& } -> `()' #} +{#fun erebos_ice_send as iceSend { isStrans `IceSession', withByteStringLen* `ByteString'& } -> `()' #} foreign export ccall ice_call_cb :: StablePtr (IO ()) -> IO () ice_call_cb :: StablePtr (IO ()) -> IO () ice_call_cb = join . deRefStablePtr -iceSetChan :: IceSession -> Flow Void ByteString -> IO () -iceSetChan sess chan = do +iceSetServer :: IceSession -> Server -> IO () +iceSetServer sess server = do modifyMVar_ (isChan sess) $ \orig -> do case orig of - Left buf -> mapM_ (writeFlowIO chan) $ reverse buf + Left buf -> mapM_ (receivedFromCustomAddress server sess) $ reverse buf Right _ -> return () - return $ Right chan + return $ Right $ receivedFromCustomAddress server sess foreign export ccall ice_rx_data :: StablePtr IceSession -> Ptr CChar -> Int -> IO () ice_rx_data :: StablePtr IceSession -> Ptr CChar -> Int -> IO () @@ -238,5 +241,12 @@ ice_rx_data sptr buf len = do sess <- deRefStablePtr sptr bs <- packCStringLen (buf, len) modifyMVar_ (isChan sess) $ \case - mc@(Right chan) -> writeFlowIO chan bs >> return mc - Left bss -> return $ Left (bs:bss) + mc@(Right sendToServer) -> sendToServer bs >> return mc + Left bss -> return $ Left (bs : bss) + + +serverPeerIce :: Server -> IceSession -> IO Peer +serverPeerIce server ice = do + peer <- serverPeerCustom server ice + iceSetServer ice server + return peer diff --git a/src/Erebos/ICE/pjproject.c b/src/Erebos/ICE/pjproject.c index e9446fe..8d91eac 100644 --- a/src/Erebos/ICE/pjproject.c +++ b/src/Erebos/ICE/pjproject.c @@ -78,7 +78,7 @@ static void cb_on_ice_complete(pj_ice_strans * strans, { if (status != PJ_SUCCESS) { ice_perror("cb_on_ice_complete", status); - ice_destroy(strans); + erebos_ice_destroy(strans); return; } @@ -139,7 +139,7 @@ exit: pthread_mutex_unlock(&mutex); } -struct erebos_ice_cfg * ice_cfg_create( const char * stun_server, uint16_t stun_port, +struct erebos_ice_cfg * erebos_ice_cfg_create( const char * stun_server, uint16_t stun_port, const char * turn_server, uint16_t turn_port ) { ice_init(); @@ -189,11 +189,11 @@ struct erebos_ice_cfg * ice_cfg_create( const char * stun_server, uint16_t stun_ return ecfg; fail: - ice_cfg_free( ecfg ); + erebos_ice_cfg_free( ecfg ); return NULL; } -void ice_cfg_free( struct erebos_ice_cfg * ecfg ) +void erebos_ice_cfg_free( struct erebos_ice_cfg * ecfg ) { if( ! ecfg ) return; @@ -216,14 +216,14 @@ void ice_cfg_free( struct erebos_ice_cfg * ecfg ) free( ecfg ); } -void ice_cfg_stop_thread( struct erebos_ice_cfg * ecfg ) +void erebos_ice_cfg_stop_thread( struct erebos_ice_cfg * ecfg ) { if( ! ecfg ) return; ecfg->exit = true; } -pj_ice_strans * ice_create( const struct erebos_ice_cfg * ecfg, pj_ice_sess_role role, +pj_ice_strans * erebos_ice_create( const struct erebos_ice_cfg * ecfg, pj_ice_sess_role role, HsStablePtr sptr, HsStablePtr cb ) { ice_init(); @@ -249,7 +249,7 @@ pj_ice_strans * ice_create( const struct erebos_ice_cfg * ecfg, pj_ice_sess_role return res; } -void ice_destroy(pj_ice_strans * strans) +void erebos_ice_destroy(pj_ice_strans * strans) { struct user_data * udata = pj_ice_strans_get_user_data(strans); if (udata->sptr) @@ -264,7 +264,7 @@ void ice_destroy(pj_ice_strans * strans) pj_ice_strans_destroy(strans); } -ssize_t ice_encode_session(pj_ice_strans * strans, char * ufrag, char * pass, +ssize_t erebos_ice_encode_session(pj_ice_strans * strans, char * ufrag, char * pass, char * def, char * candidates[], size_t maxlen, size_t maxcand) { int n; @@ -318,7 +318,7 @@ ssize_t ice_encode_session(pj_ice_strans * strans, char * ufrag, char * pass, return cand_cnt; } -void ice_connect(pj_ice_strans * strans, HsStablePtr cb, +void erebos_ice_connect(pj_ice_strans * strans, HsStablePtr cb, const char * ufrag, const char * pass, const char * defcand, const char * tcandidates[], size_t ncand) { @@ -409,7 +409,7 @@ void ice_connect(pj_ice_strans * strans, HsStablePtr cb, } } -void ice_send(pj_ice_strans * strans, const char * data, size_t len) +void erebos_ice_send(pj_ice_strans * strans, const char * data, size_t len) { if (!pj_ice_strans_sess_is_complete(strans)) { fprintf(stderr, "ICE: negotiation has not been started or is in progress\n"); diff --git a/src/Erebos/ICE/pjproject.h b/src/Erebos/ICE/pjproject.h index c31e227..7a1b96d 100644 --- a/src/Erebos/ICE/pjproject.h +++ b/src/Erebos/ICE/pjproject.h @@ -3,18 +3,18 @@ #include <pjnath.h> #include <HsFFI.h> -struct erebos_ice_cfg * ice_cfg_create( const char * stun_server, uint16_t stun_port, +struct erebos_ice_cfg * erebos_ice_cfg_create( const char * stun_server, uint16_t stun_port, const char * turn_server, uint16_t turn_port ); -void ice_cfg_free( struct erebos_ice_cfg * cfg ); -void ice_cfg_stop_thread( struct erebos_ice_cfg * cfg ); +void erebos_ice_cfg_free( struct erebos_ice_cfg * cfg ); +void erebos_ice_cfg_stop_thread( struct erebos_ice_cfg * cfg ); -pj_ice_strans * ice_create( const struct erebos_ice_cfg *, pj_ice_sess_role role, +pj_ice_strans * erebos_ice_create( const struct erebos_ice_cfg *, pj_ice_sess_role role, HsStablePtr sptr, HsStablePtr cb ); -void ice_destroy(pj_ice_strans * strans); +void erebos_ice_destroy(pj_ice_strans * strans); -ssize_t ice_encode_session(pj_ice_strans *, char * ufrag, char * pass, +ssize_t erebos_ice_encode_session(pj_ice_strans *, char * ufrag, char * pass, char * def, char * candidates[], size_t maxlen, size_t maxcand); -void ice_connect(pj_ice_strans * strans, HsStablePtr cb, +void erebos_ice_connect(pj_ice_strans * strans, HsStablePtr cb, const char * ufrag, const char * pass, const char * defcand, const char * candidates[], size_t ncand); -void ice_send(pj_ice_strans *, const char * data, size_t len); +void erebos_ice_send(pj_ice_strans *, const char * data, size_t len); diff --git a/src/Erebos/Identity.hs b/src/Erebos/Identity.hs index a3f17b5..491df6e 100644 --- a/src/Erebos/Identity.hs +++ b/src/Erebos/Identity.hs @@ -214,29 +214,33 @@ isExtension x = case fromSigned x of BaseIdentityData {} -> False _ -> True -createIdentity :: Storage -> Maybe Text -> Maybe UnifiedIdentity -> IO UnifiedIdentity -createIdentity st name owner = do - (secret, public) <- generateKeys st - (_secretMsg, publicMsg) <- generateKeys st - - let signOwner :: Signed a -> ReaderT Storage IO (Signed a) +createIdentity + :: forall m e. (MonadStorage m, MonadError e m, FromErebosError e, MonadIO m) + => Maybe Text -> Maybe UnifiedIdentity -> m UnifiedIdentity +createIdentity name owner = do + st <- getStorage + ( secret, public ) <- liftIO $ generateKeys st + ( _secretMsg, publicMsg ) <- liftIO $ generateKeys st + + let signOwner :: Signed a -> m (Signed a) signOwner idd | Just o <- owner = do - Just ownerSecret <- loadKeyMb (iddKeyIdentity $ fromSigned $ idData o) + ownerSecret <- maybe (throwOtherError "failed to load private key") return =<< + loadKeyMb (iddKeyIdentity $ fromSigned $ idData o) signAdd ownerSecret idd | otherwise = return idd - Just identity <- flip runReaderT st $ do - baseData <- mstore =<< signOwner =<< sign secret =<< - mstore (emptyIdentityData public) - { iddOwner = idData <$> owner - , iddKeyMessage = Just publicMsg - } - let extOwner = do - odata <- idExtData <$> owner - guard $ isExtension odata - return odata - + baseData <- mstore =<< signOwner =<< sign secret =<< + mstore (emptyIdentityData public) + { iddOwner = idData <$> owner + , iddKeyMessage = Just publicMsg + } + let extOwner = do + odata <- idExtData <$> owner + guard $ isExtension odata + return odata + + maybe (throwOtherError "created invalid identity") return =<< do validateExtendedIdentityF . I.Identity <$> if isJust name || isJust extOwner then mstore =<< signOwner =<< sign secret =<< @@ -245,7 +249,6 @@ createIdentity st name owner = do , ideOwner = extOwner } else return $ baseToExtended baseData - return identity validateIdentity :: Stored (Signed IdentityData) -> Maybe UnifiedIdentity validateIdentity = validateIdentityF . I.Identity @@ -388,13 +391,13 @@ sameIdentity x y = intersectsSorted (roots x) (roots y) roots idt = uniq $ sort $ concatMap storedRoots $ toList $ idDataF idt -unfoldOwners :: (Foldable m) => Identity m -> [ComposedIdentity] +unfoldOwners :: Foldable m => Identity m -> [ComposedIdentity] unfoldOwners = unfoldr (fmap (\i -> (i, idOwner i))) . Just . toComposedIdentity -finalOwner :: (Foldable m, Applicative m) => Identity m -> ComposedIdentity +finalOwner :: Foldable m => Identity m -> ComposedIdentity finalOwner = last . unfoldOwners -displayIdentity :: (Foldable m, Applicative m) => Identity m -> Text +displayIdentity :: Foldable m => Identity m -> Text displayIdentity identity = T.concat [ T.intercalate (T.pack " / ") $ map (fromMaybe (T.pack "<unnamed>") . idName) owners ] diff --git a/src/Erebos/Invite.hs b/src/Erebos/Invite.hs new file mode 100644 index 0000000..f860fbc --- /dev/null +++ b/src/Erebos/Invite.hs @@ -0,0 +1,213 @@ +module Erebos.Invite ( + Invite(..), + InviteData(..), + InviteService, + InviteServiceAttributes(..), + + createSingleContactInvite, + acceptInvite, +) where + +import Control.Arrow +import Control.Monad +import Control.Monad.Except +import Control.Monad.IO.Class +import Control.Monad.Reader + +import Crypto.Random + +import Data.ByteString (ByteString) +import Data.ByteString.Char8 qualified as BC +import Data.Foldable +import Data.Ord +import Data.Text (Text) + +import Erebos.Contact +import Erebos.Identity +import Erebos.Network +import Erebos.Object +import Erebos.PubKey +import Erebos.Service +import Erebos.Set +import Erebos.State +import Erebos.Storable +import Erebos.Storage.Merge +import Erebos.Util + + +data Invite = Invite + { inviteData :: [ Stored InviteData ] + , inviteToken :: Maybe ByteString + , inviteAccepted :: [ Stored (Signed ExtendedIdentityData) ] + , inviteContact :: Maybe Text + } + +data InviteData = InviteData + { invdPrev :: [ Stored InviteData ] + , invdToken :: Maybe ByteString + , invdAccepted :: Maybe (Stored (Signed ExtendedIdentityData)) + , invdContact :: Maybe Text + } + +instance Storable InviteData where + store' x = storeRec $ do + mapM_ (storeRef "PREV") $ invdPrev x + mapM_ (storeBinary "token") $ invdToken x + mapM_ (storeRef "accepted") $ invdAccepted x + mapM_ (storeText "contact") $ invdContact x + + load' = loadRec $ InviteData + <$> loadRefs "PREV" + <*> loadMbBinary "token" + <*> loadMbRef "accepted" + <*> loadMbText "contact" + + +instance Mergeable Invite where + type Component Invite = InviteData + + mergeSorted invdata = Invite + { inviteData = invdata + , inviteToken = findPropertyFirst invdToken invdata + , inviteAccepted = findProperty invdAccepted invdata + , inviteContact = findPropertyFirst invdContact invdata + } + + toComponents = inviteData + +instance SharedType (Set Invite) where + sharedTypeID _ = mkSharedTypeID "78da787a-9380-432e-a51d-532a30d27b3d" + + +createSingleContactInvite :: MonadHead LocalState m => Text -> m Invite +createSingleContactInvite name = do + token <- liftIO $ getRandomBytes 32 + invite <- mergeSorted @Invite . (: []) <$> mstore InviteData + { invdPrev = [] + , invdToken = Just token + , invdAccepted = Nothing + , invdContact = Just name + } + updateLocalState_ $ updateSharedState_ $ \invites -> do + storeSetAdd invite invites + return invite + +identityOwnerDigests :: Foldable f => Identity f -> [ RefDigest ] +identityOwnerDigests pid = map (refDigest . storedRef) $ concatMap toList $ toList $ generations $ idExtDataF $ finalOwner pid + +acceptInvite :: (MonadIO m, MonadError e m, FromErebosError e) => Server -> RefDigest -> ByteString -> m () +acceptInvite server from token = do + let matchPeer peer = do + getPeerIdentity peer >>= \case + PeerIdentityFull pid -> do + return $ from `elem` identityOwnerDigests pid + _ -> return False + liftIO (findPeer server matchPeer) >>= \case + Just peer -> runPeerService @InviteService peer $ do + svcModify (token :) + replyPacket $ AcceptInvite token + Nothing -> do + throwOtherError "peer not found" + + +data InviteService + = AcceptInvite ByteString + | InvalidInvite ByteString + | ContactInvite ByteString (Maybe Text) + | UnknownInvitePacket + +data InviteServiceAttributes = InviteServiceAttributes + { inviteHookAccepted :: ByteString -> ServiceHandler InviteService () + , inviteHookReplyContact :: ByteString -> Maybe Text -> ServiceHandler InviteService () + , inviteHookReplyInvalid :: ByteString -> ServiceHandler InviteService () + } + +defaultInviteServiceAttributes :: InviteServiceAttributes +defaultInviteServiceAttributes = InviteServiceAttributes + { inviteHookAccepted = \_ -> return () + , inviteHookReplyContact = \_ _ -> return () + , inviteHookReplyInvalid = \_ -> return () + } + +instance Storable InviteService where + store' x = storeRec $ case x of + AcceptInvite token -> storeBinary "accept" token + InvalidInvite token -> storeBinary "invalid" token + ContactInvite token mbName -> do + storeBinary "valid" token + maybe (storeEmpty "contact") (storeText "contact") mbName + UnknownInvitePacket -> return () + + load' = loadRec $ msum + [ AcceptInvite <$> loadBinary "accept" + , InvalidInvite <$> loadBinary "invalid" + , ContactInvite <$> loadBinary "valid" <*> msum + [ return Nothing <* loadEmpty "contact" + , Just <$> loadText "contact" + ] + , return UnknownInvitePacket + ] + +instance Service InviteService where + serviceID _ = mkServiceID "70bff715-6856-43a0-8c58-007a06a26eb1" + + type ServiceState InviteService = [ ByteString ] -- accepted invites, waiting for reply + emptyServiceState _ = [] + + type ServiceAttributes InviteService = InviteServiceAttributes + defaultServiceAttributes _ = defaultInviteServiceAttributes + + serviceHandler = fromStored >>> \case + AcceptInvite token -> do + asks (inviteHookAccepted . svcAttributes) >>= ($ token) + invites <- fromSetBy (comparing inviteToken) . lookupSharedValue . lsShared . fromStored <$> getLocalHead + case find ((Just token ==) . inviteToken) invites of + Just invite + | Just name <- inviteContact invite + , [] <- inviteAccepted invite + -> do + identity <- asks svcPeerIdentity + cdata <- mstore ContactData + { cdPrev = [] + , cdIdentity = idExtDataF $ finalOwner identity + , cdName = Just name + } + invdata <- mstore InviteData + { invdPrev = inviteData invite + , invdToken = Nothing + , invdAccepted = Just (idExtData identity) + , invdContact = Nothing + } + updateLocalState_ $ updateSharedState_ $ storeSetAdd (mergeSorted @Contact [ cdata ]) + updateLocalState_ $ updateSharedState_ $ storeSetAdd (mergeSorted @Invite [ invdata ]) + replyPacket $ ContactInvite token Nothing + + | otherwise -> do + replyPacket $ InvalidInvite token + + Nothing -> do + replyPacket $ InvalidInvite token + + InvalidInvite token -> do + asks (inviteHookReplyInvalid . svcAttributes) >>= ($ token) + svcModify $ filter (/= token) + svcPrint $ "Invite " <> BC.unpack (showHex token) <> " rejected as invalid" + + ContactInvite token mbName -> do + asks (inviteHookReplyContact . svcAttributes) >>= ($ mbName) . ($ token) + waitingTokens <- svcGet + if token `elem` waitingTokens + then do + svcSet $ filter (/= token) waitingTokens + identity <- asks svcPeerIdentity + cdata <- mstore ContactData + { cdPrev = [] + , cdIdentity = idExtDataF $ finalOwner identity + , cdName = Nothing + } + updateLocalState_ $ updateSharedState_ $ storeSetAdd (mergeSorted @Contact [ cdata ]) + else do + svcPrint $ "Received unexpected invite response for " <> BC.unpack (showHex token) + + UnknownInvitePacket -> do + svcPrint $ "Received unknown invite packet" diff --git a/src/Erebos/Network.hs b/src/Erebos/Network.hs index b341974..6265bbf 100644 --- a/src/Erebos/Network.hs +++ b/src/Erebos/Network.hs @@ -1,5 +1,3 @@ -{-# LANGUAGE CPP #-} - module Erebos.Network ( Server, startServer, @@ -10,8 +8,8 @@ module Erebos.Network ( ServerOptions(..), serverIdentity, defaultServerOptions, Peer, peerServer, peerStorage, - PeerAddress(..), peerAddress, - PeerIdentity(..), peerIdentity, + PeerAddress(..), getPeerAddress, getPeerAddresses, + PeerIdentity(..), getPeerIdentity, WaitingRef, wrDigest, Service(..), @@ -20,9 +18,7 @@ module Erebos.Network ( serverPeer, serverPeerCustom, -#ifdef ENABLE_ICE_SUPPORT - serverPeerIce, -#endif + findPeer, dropPeer, isPeerDropped, sendToPeer, sendManyToPeer, @@ -66,10 +62,8 @@ import Network.Socket hiding (ControlMessage) import Network.Socket.ByteString qualified as S import Erebos.Error -#ifdef ENABLE_ICE_SUPPORT -import Erebos.ICE -#endif import Erebos.Identity +import Erebos.Network.Address import Erebos.Network.Channel import Erebos.Network.Protocol import Erebos.Object.Internal @@ -121,12 +115,16 @@ getNextPeerChange = atomically . readTChan . serverChanPeer data ServerOptions = ServerOptions { serverPort :: PortNumber , serverLocalDiscovery :: Bool + , serverErrorPrefix :: String + , serverTestLog :: Bool } defaultServerOptions :: ServerOptions defaultServerOptions = ServerOptions { serverPort = discoveryPort , serverLocalDiscovery = True + , serverErrorPrefix = "" + , serverTestLog = False } @@ -141,6 +139,14 @@ data Peer = Peer , peerWaitingRefs :: TMVar [WaitingRef] } +-- | Get current main address of the peer (used to send new packets). +getPeerAddress :: MonadIO m => Peer -> m PeerAddress +getPeerAddress = liftIO . return . peerAddress + +-- | Get all known addresses of given peer. +getPeerAddresses :: MonadIO m => Peer -> m [ PeerAddress ] +getPeerAddresses = fmap (: []) . getPeerAddress + peerServer :: Peer -> Server peerServer = peerServer_ @@ -166,36 +172,24 @@ instance Eq Peer where class (Eq addr, Ord addr, Show addr, Typeable addr) => PeerAddressType addr where sendBytesToAddress :: addr -> ByteString -> IO () + connectionToAddressClosed :: addr -> IO () data PeerAddress = forall addr. PeerAddressType addr => CustomPeerAddress addr | DatagramAddress SockAddr -#ifdef ENABLE_ICE_SUPPORT - | PeerIceSession IceSession -#endif instance Show PeerAddress where show (CustomPeerAddress addr) = show addr - show (DatagramAddress saddr) = unwords $ case IP.fromSockAddr saddr of - Just (IP.IPv6 ipv6, port) - | (0, 0, 0xffff, ipv4) <- IP.fromIPv6w ipv6 - -> [show (IP.toIPv4w ipv4), show port] - Just (addr, port) - -> [show addr, show port] - _ -> [show saddr] - -#ifdef ENABLE_ICE_SUPPORT - show (PeerIceSession ice) = show ice -#endif + show (DatagramAddress saddr) = + case inetFromSockAddr saddr of + Just ( addr, port ) -> unwords [ show addr, show port ] + _ -> show saddr instance Eq PeerAddress where CustomPeerAddress addr == CustomPeerAddress addr' | Just addr'' <- cast addr' = addr == addr'' DatagramAddress addr == DatagramAddress addr' = addr == addr' -#ifdef ENABLE_ICE_SUPPORT - PeerIceSession ice == PeerIceSession ice' = ice == ice' -#endif _ == _ = False instance Ord PeerAddress where @@ -206,20 +200,16 @@ instance Ord PeerAddress where compare _ (CustomPeerAddress _ ) = GT compare (DatagramAddress addr) (DatagramAddress addr') = compare addr addr' -#ifdef ENABLE_ICE_SUPPORT - compare (DatagramAddress _ ) _ = LT - compare _ (DatagramAddress _ ) = GT - compare (PeerIceSession ice ) (PeerIceSession ice') = compare ice ice' -#endif +data PeerIdentity + = PeerIdentityUnknown (TVar [ UnifiedIdentity -> ExceptT ErebosError IO () ]) + | PeerIdentityRef WaitingRef (TVar [ UnifiedIdentity -> ExceptT ErebosError IO () ]) + | PeerIdentityFull UnifiedIdentity -data PeerIdentity = PeerIdentityUnknown (TVar [UnifiedIdentity -> ExceptT ErebosError IO ()]) - | PeerIdentityRef WaitingRef (TVar [UnifiedIdentity -> ExceptT ErebosError IO ()]) - | PeerIdentityFull UnifiedIdentity - -peerIdentity :: MonadIO m => Peer -> m PeerIdentity -peerIdentity = liftIO . atomically . readTVar . peerIdentityVar +-- | Get currently known identity of the given peer +getPeerIdentity :: MonadIO m => Peer -> m PeerIdentity +getPeerIdentity = liftIO . atomically . readTVar . peerIdentityVar data PeerState @@ -277,7 +267,16 @@ startServer serverOptions serverOrigHead logd' serverServices = do let logd = writeTQueue serverErrorLog forkServerThread server $ forever $ do - logd' =<< atomically (readTQueue serverErrorLog) + logd' . (serverErrorPrefix serverOptions <>) =<< atomically (readTQueue serverErrorLog) + + logt <- if + | serverTestLog serverOptions -> do + serverTestLog <- newTQueueIO + forkServerThread server $ forever $ do + logd' =<< atomically (readTQueue serverTestLog) + return $ writeTQueue serverTestLog + | otherwise -> do + return $ \_ -> return () forkServerThread server $ dataResponseWorker server forkServerThread server $ forever $ do @@ -327,13 +326,18 @@ startServer serverOptions serverOrigHead logd' serverServices = do announceUpdate idt forM_ serverServices $ \(SomeService service _) -> do - forM_ (serviceStorageWatchers service) $ \(SomeStorageWatcher sel act) -> do - watchHeadWith serverOrigHead (sel . headStoredObject) $ \x -> do - withMVar serverPeers $ mapM_ $ \peer -> atomically $ do - readTVar (peerIdentityVar peer) >>= \case - PeerIdentityFull _ -> writeTQueue serverIOActions $ do - runPeerService peer $ act x - _ -> return () + forM_ (serviceStorageWatchers service) $ \case + SomeStorageWatcher sel act -> do + watchHeadWith serverOrigHead (sel . headStoredObject) $ \x -> do + withMVar serverPeers $ mapM_ $ \peer -> atomically $ do + readTVar (peerIdentityVar peer) >>= \case + PeerIdentityFull _ -> writeTQueue serverIOActions $ do + runPeerService peer $ act x + _ -> return () + GlobalStorageWatcher sel act -> do + watchHeadWith serverOrigHead (sel . headStoredObject) $ \x -> do + atomically $ writeTQueue serverIOActions $ do + act server x forkServerThread server $ forever $ do (msg, saddr) <- S.recvFrom sock 4096 @@ -345,9 +349,6 @@ startServer serverOptions serverOrigHead logd' serverServices = do case paddr of CustomPeerAddress addr -> sendBytesToAddress addr msg DatagramAddress addr -> void $ S.sendTo sock msg addr -#ifdef ENABLE_ICE_SUPPORT - PeerIceSession ice -> iceSend ice msg -#endif forkServerThread server $ forever $ do readFlowIO serverControlFlow >>= \case @@ -389,9 +390,13 @@ startServer serverOptions serverOrigHead logd' serverServices = do prefs <- forM objs $ storeObject $ peerInStorage peer identity <- readMVar serverIdentity_ let svcs = map someServiceID serverServices - handlePacket identity secure peer chanSvc svcs header prefs + handlePacket paddr identity secure peer chanSvc svcs header prefs peerLoop Nothing -> do + case paddr of + DatagramAddress _ -> return () + CustomPeerAddress caddr -> connectionToAddressClosed caddr + dropPeer peer atomically $ writeTChan serverChanPeer peer peerLoop @@ -399,7 +404,7 @@ startServer serverOptions serverOrigHead logd' serverServices = do ReceivedAnnounce addr _ -> do void $ serverPeer' server addr - erebosNetworkProtocol (headLocalIdentity serverOrigHead) logd protocolRawPath protocolControlFlow + erebosNetworkProtocol (headLocalIdentity serverOrigHead) logd logt protocolRawPath protocolControlFlow forkServerThread server $ withSocketsDo $ do let hints = defaultHints @@ -411,9 +416,9 @@ startServer serverOptions serverOrigHead logd' serverServices = do bracket (open addr) close loop forkServerThread server $ forever $ do - ( peer, svc, ref, streams ) <- atomically $ readTQueue chanSvc + ( peer, paddr, svc, ref, streams ) <- atomically $ readTQueue chanSvc case find ((svc ==) . someServiceID) serverServices of - Just service@(SomeService (_ :: Proxy s) attr) -> runPeerServiceOn (Just ( service, attr )) streams peer (serviceHandler $ wrappedLoad @s ref) + Just service@(SomeService (_ :: Proxy s) attr) -> runPeerServiceOn (Just ( service, attr )) streams paddr peer (serviceHandler $ wrappedLoad @s ref) _ -> atomically $ logd $ "unhandled service '" ++ show (toUUID svc) ++ "'" return server @@ -560,10 +565,10 @@ appendDistinct x (y:ys) | x == y = y : ys | otherwise = y : appendDistinct x ys appendDistinct x [] = [x] -handlePacket :: UnifiedIdentity -> Bool - -> Peer -> TQueue ( Peer, ServiceID, Ref, [ RawStreamReader ]) -> [ ServiceID ] +handlePacket :: PeerAddress -> UnifiedIdentity -> Bool + -> Peer -> TQueue ( Peer, PeerAddress, ServiceID, Ref, [ RawStreamReader ] ) -> [ ServiceID ] -> TransportHeader -> [ PartialRef ] -> IO () -handlePacket identity secure peer chanSvc svcs (TransportHeader headers) prefs = atomically $ do +handlePacket paddr identity secure peer chanSvc svcs (TransportHeader headers) prefs = atomically $ do let server = peerServer peer ochannel <- getPeerChannel peer let sidentity = idData identity @@ -699,7 +704,7 @@ handlePacket identity secure peer chanSvc svcs (TransportHeader headers) prefs = then do streamReaders <- mapM acceptStream $ lookupNewStreams headers void $ newWaitingRef dgst $ \ref -> - liftIO $ atomically $ writeTQueue chanSvc ( peer, svc, ref, streamReaders ) + liftIO $ atomically $ writeTQueue chanSvc ( peer, paddr, svc, ref, streamReaders ) else throwError $ "missing service object " ++ show dgst | otherwise -> addHeader $ Rejected dgst | otherwise -> throwError $ "service ref without type" @@ -779,7 +784,7 @@ finalizedChannel peer@Peer {..} ch self = do -- Notify services about new peer readTVar peerIdentityVar >>= \case - PeerIdentityFull _ -> notifyServicesOfPeer peer + PeerIdentityFull _ -> notifyServicesOfPeer True peer _ -> return () @@ -805,7 +810,7 @@ handleIdentityAnnounce self peer ref = liftIO $ atomically $ do PeerIdentityFull pid | idData pid `precedes` wrappedLoad ref -> validateAndUpdate (idUpdates pid) $ \_ -> do - notifyServicesOfPeer peer + notifyServicesOfPeer False peer _ -> return () @@ -817,15 +822,18 @@ handleIdentityUpdate peer ref = liftIO $ atomically $ do -> do writeTVar (peerIdentityVar peer) $ PeerIdentityFull pid' writeTChan (serverChanPeer $ peerServer peer) peer - when (idData pid /= idData pid') $ notifyServicesOfPeer peer + when (pid /= pid') $ do + notifyServicesOfPeer False peer | otherwise -> return () -notifyServicesOfPeer :: Peer -> STM () -notifyServicesOfPeer peer@Peer { peerServer_ = Server {..} } = do +notifyServicesOfPeer :: Bool -> Peer -> STM () +notifyServicesOfPeer new peer@Peer { peerServer_ = Server {..} } = do writeTQueue serverIOActions $ do + paddr <- getPeerAddress peer forM_ serverServices $ \service@(SomeService _ attrs) -> - runPeerServiceOn (Just ( service, attrs )) [] peer serviceNewPeer + runPeerServiceOn (Just ( service, attrs )) [] paddr peer $ + if new then serviceNewPeer else serviceUpdatedPeer receivedFromCustomAddress :: PeerAddressType addr => Server -> addr -> ByteString -> IO () @@ -853,15 +861,6 @@ serverPeer server paddr = do serverPeerCustom :: PeerAddressType addr => Server -> addr -> IO Peer serverPeerCustom server addr = serverPeer' server (CustomPeerAddress addr) -#ifdef ENABLE_ICE_SUPPORT -serverPeerIce :: Server -> IceSession -> IO Peer -serverPeerIce server@Server {..} ice = do - let paddr = PeerIceSession ice - peer <- serverPeer' server paddr - iceSetChan ice $ mapFlow undefined (paddr,) serverRawPath - return peer -#endif - serverPeer' :: Server -> PeerAddress -> IO Peer serverPeer' server paddr = do (peer, hello) <- modifyMVar (serverPeers server) $ \pvalue -> do @@ -874,6 +873,13 @@ serverPeer' server paddr = do writeFlow (serverControlFlow server) (RequestConnection paddr) return peer +findPeer :: Server -> (Peer -> IO Bool) -> IO (Maybe Peer) +findPeer server test = withMVar (serverPeers server) (helper . M.elems) + where + helper (p : ps) = test p >>= \case True -> return (Just p) + False -> helper ps + helper [] = return Nothing + dropPeer :: MonadIO m => Peer -> m () dropPeer peer = liftIO $ do modifyMVar_ (serverPeers $ peerServer peer) $ \pvalue -> do @@ -983,10 +989,12 @@ lookupService proxy (service@(SomeService (_ :: Proxy t) attr) : rest) lookupService _ [] = Nothing runPeerService :: forall s m. (Service s, MonadIO m) => Peer -> ServiceHandler s () -> m () -runPeerService = runPeerServiceOn Nothing [] +runPeerService peer handler = do + paddr <- getPeerAddress peer + runPeerServiceOn Nothing [] paddr peer handler -runPeerServiceOn :: forall s m. (Service s, MonadIO m) => Maybe ( SomeService, ServiceAttributes s ) -> [ RawStreamReader ] -> Peer -> ServiceHandler s () -> m () -runPeerServiceOn mbservice newStreams peer handler = liftIO $ do +runPeerServiceOn :: forall s m. (Service s, MonadIO m) => Maybe ( SomeService, ServiceAttributes s ) -> [ RawStreamReader ] -> PeerAddress -> Peer -> ServiceHandler s () -> m () +runPeerServiceOn mbservice newStreams paddr peer handler = liftIO $ do let server = peerServer peer proxy = Proxy @s svc = serviceID proxy @@ -1008,6 +1016,7 @@ runPeerServiceOn mbservice newStreams peer handler = liftIO $ do let inp = ServiceInput { svcAttributes = attr , svcPeer = peer + , svcPeerAddress = paddr , svcPeerIdentity = peerId , svcServer = server , svcPrintOp = atomically . logd @@ -1027,7 +1036,7 @@ runPeerServiceOn mbservice newStreams peer handler = liftIO $ do putTMVar (peerServiceState peer) $ M.insert svc (SomeServiceState proxy s') svcs putTMVar (serverServiceStates server) $ M.insert svc (SomeServiceGlobalState proxy gs') global _ -> do - atomically $ logd $ "can't run service handler on peer with incomplete identity " ++ show (peerAddress peer) + atomically $ logd $ "can't run service handler on peer with incomplete identity " ++ show paddr _ -> atomically $ do logd $ "unhandled service '" ++ show (toUUID svc) ++ "'" @@ -1054,30 +1063,11 @@ modifyServiceGlobalState server proxy f = do throwOtherError $ "unhandled service '" ++ show (toUUID svc) ++ "'" -foreign import ccall unsafe "Network/ifaddrs.h join_multicast" cJoinMulticast :: CInt -> Ptr CSize -> IO (Ptr Word32) -foreign import ccall unsafe "Network/ifaddrs.h local_addresses" cLocalAddresses :: Ptr CSize -> IO (Ptr InetAddress) -foreign import ccall unsafe "Network/ifaddrs.h broadcast_addresses" cBroadcastAddresses :: IO (Ptr Word32) +foreign import ccall unsafe "Network/ifaddrs.h erebos_join_multicast" cJoinMulticast :: CInt -> Ptr CSize -> IO (Ptr Word32) +foreign import ccall unsafe "Network/ifaddrs.h erebos_local_addresses" cLocalAddresses :: Ptr CSize -> IO (Ptr InetAddress) +foreign import ccall unsafe "Network/ifaddrs.h erebos_broadcast_addresses" cBroadcastAddresses :: IO (Ptr Word32) foreign import ccall unsafe "stdlib.h free" cFree :: Ptr a -> IO () -data InetAddress = InetAddress { fromInetAddress :: IP.IP } - -instance F.Storable InetAddress where - sizeOf _ = sizeOf (undefined :: CInt) + 16 - alignment _ = 8 - - peek ptr = (unpackFamily <$> peekByteOff ptr 0) >>= \case - AF_INET -> InetAddress . IP.IPv4 . IP.fromHostAddress <$> peekByteOff ptr (sizeOf (undefined :: CInt)) - AF_INET6 -> InetAddress . IP.IPv6 . IP.toIPv6b . map fromIntegral <$> peekArray 16 (ptr `plusPtr` sizeOf (undefined :: CInt) :: Ptr Word8) - _ -> fail "InetAddress: unknown family" - - poke ptr (InetAddress addr) = case addr of - IP.IPv4 ip -> do - pokeByteOff ptr 0 (packFamily AF_INET) - pokeByteOff ptr (sizeOf (undefined :: CInt)) (IP.toHostAddress ip) - IP.IPv6 ip -> do - pokeByteOff ptr 0 (packFamily AF_INET6) - pokeArray (ptr `plusPtr` sizeOf (undefined :: CInt) :: Ptr Word8) (map fromIntegral $ IP.fromIPv6b ip) - joinMulticast :: Socket -> IO [ Word32 ] joinMulticast sock = withFdSocket sock $ \fd -> @@ -1104,7 +1094,7 @@ getServerAddresses Server {..} = do count <- fromIntegral <$> peek pcount res <- peekArray count ptr cFree ptr - return $ map (IP.toSockAddr . (, serverPort serverOptions ) . fromInetAddress) res + return $ map (inetToSockAddr . (, serverPort serverOptions )) res getBroadcastAddresses :: PortNumber -> IO [SockAddr] getBroadcastAddresses port = do diff --git a/src/Erebos/Network.hs-boot b/src/Erebos/Network.hs-boot index af77581..17a5275 100644 --- a/src/Erebos/Network.hs-boot +++ b/src/Erebos/Network.hs-boot @@ -4,5 +4,6 @@ import Erebos.Object.Internal data Server data Peer +data PeerAddress peerStorage :: Peer -> Storage diff --git a/src/Erebos/Network/Address.hs b/src/Erebos/Network/Address.hs new file mode 100644 index 0000000..63f6af1 --- /dev/null +++ b/src/Erebos/Network/Address.hs @@ -0,0 +1,65 @@ +module Erebos.Network.Address ( + InetAddress(..), + inetFromSockAddr, + inetToSockAddr, + + SockAddr, PortNumber, +) where + +import Data.Bifunctor +import Data.IP qualified as IP +import Data.Word + +import Foreign.C.Types +import Foreign.Marshal.Array +import Foreign.Ptr +import Foreign.Storable as F + +import Network.Socket + +import Text.Read + + +newtype InetAddress = InetAddress { fromInetAddress :: IP.IP } + deriving (Eq, Ord) + +instance Show InetAddress where + show (InetAddress ipaddr) + | IP.IPv6 ipv6 <- ipaddr + , ( 0, 0, 0xffff, ipv4 ) <- IP.fromIPv6w ipv6 + = show (IP.toIPv4w ipv4) + + | otherwise + = show ipaddr + +instance Read InetAddress where + readPrec = do + readPrec >>= return . InetAddress . \case + IP.IPv4 ipv4 -> IP.IPv6 $ IP.toIPv6w ( 0, 0, 0xffff, IP.fromIPv4w ipv4 ) + ipaddr -> ipaddr + + readListPrec = readListPrecDefault + +instance F.Storable InetAddress where + sizeOf _ = sizeOf (undefined :: CInt) + 16 + alignment _ = 8 + + peek ptr = (unpackFamily <$> peekByteOff ptr 0) >>= \case + AF_INET -> InetAddress . IP.IPv4 . IP.fromHostAddress <$> peekByteOff ptr (sizeOf (undefined :: CInt)) + AF_INET6 -> InetAddress . IP.IPv6 . IP.toIPv6b . map fromIntegral <$> peekArray 16 (ptr `plusPtr` sizeOf (undefined :: CInt) :: Ptr Word8) + _ -> fail "InetAddress: unknown family" + + poke ptr (InetAddress addr) = case addr of + IP.IPv4 ip -> do + pokeByteOff ptr 0 (packFamily AF_INET) + pokeByteOff ptr (sizeOf (undefined :: CInt)) (IP.toHostAddress ip) + IP.IPv6 ip -> do + pokeByteOff ptr 0 (packFamily AF_INET6) + pokeArray (ptr `plusPtr` sizeOf (undefined :: CInt) :: Ptr Word8) (map fromIntegral $ IP.fromIPv6b ip) + + +inetFromSockAddr :: SockAddr -> Maybe ( InetAddress, PortNumber ) +inetFromSockAddr saddr = first InetAddress <$> IP.fromSockAddr saddr + +inetToSockAddr :: ( InetAddress, PortNumber ) -> SockAddr +inetToSockAddr = IP.toSockAddr . first fromInetAddress diff --git a/src/Erebos/Network/Protocol.hs b/src/Erebos/Network/Protocol.hs index 025f52c..f67e296 100644 --- a/src/Erebos/Network/Protocol.hs +++ b/src/Erebos/Network/Protocol.hs @@ -213,6 +213,7 @@ data GlobalState addr = (Eq addr, Show addr) => GlobalState , gControlFlow :: Flow (ControlRequest addr) (ControlMessage addr) , gNextUp :: TMVar (Connection addr, (Bool, TransportPacket PartialObject)) , gLog :: String -> STM () + , gTestLog :: String -> STM () , gStorage :: PartialStorage , gStartTime :: TimeSpec , gNowVar :: TVar TimeSpec @@ -249,6 +250,12 @@ instance Eq (Connection addr) where connAddress :: Connection addr -> addr connAddress = cAddress +showConnAddress :: forall addr. Connection addr -> String +showConnAddress Connection {..} = helper cGlobalState cAddress + where + helper :: GlobalState addr -> addr -> String + helper GlobalState {} = show + connData :: Connection addr -> Flow (Maybe (Bool, TransportPacket PartialObject)) (SecurityRequirement, TransportPacket Ref, [TransportHeaderItem]) @@ -273,6 +280,7 @@ connClose conn@Connection {..} = do connAddWriteStream :: Connection addr -> STM (Either String (TransportHeaderItem, RawStreamWriter, IO ())) connAddWriteStream conn@Connection {..} = do + let GlobalState {..} = cGlobalState outStreams <- readTVar cOutStreams let doInsert :: Word8 -> [(Word8, Stream)] -> ExceptT String STM ((Word8, Stream), [(Word8, Stream)]) doInsert n (s@(n', _) : rest) | n == n' = @@ -289,14 +297,16 @@ connAddWriteStream conn@Connection {..} = do runExceptT $ do ((streamNumber, stream), outStreams') <- doInsert 1 outStreams lift $ writeTVar cOutStreams outStreams' + lift $ gTestLog $ "net-ostream-open " <> showConnAddress conn <> " " <> show streamNumber <> " " <> show (length outStreams') return ( StreamOpen streamNumber , RawStreamWriter (fromIntegral streamNumber) (sFlowIn stream) - , go cGlobalState streamNumber stream + , go streamNumber stream ) where - go gs@GlobalState {..} streamNumber stream = do + go streamNumber stream = do + let GlobalState {..} = cGlobalState (reserved, msg) <- atomically $ do readTVar (sState stream) >>= \case StreamRunning -> return () @@ -309,6 +319,8 @@ connAddWriteStream conn@Connection {..} = do return (stpData, True, return ()) StreamClosed {} -> do atomically $ do + gTestLog $ "net-ostream-close-send " <> showConnAddress conn <> " " <> show streamNumber + atomically $ do -- wait for ack on all sent stream data waits <- readTVar (sWaitingForAck stream) when (waits > 0) retry @@ -352,7 +364,7 @@ connAddWriteStream conn@Connection {..} = do sendBytes conn mbReserved' bs Nothing -> return () - when cont $ go gs streamNumber stream + when cont $ go streamNumber stream connAddReadStream :: Connection addr -> Word8 -> STM RawStreamReader connAddReadStream Connection {..} streamNumber = do @@ -411,8 +423,10 @@ streamAccepted Connection {..} snum = atomically $ do Nothing -> return () streamClosed :: Connection addr -> Word8 -> IO () -streamClosed Connection {..} snum = atomically $ do - modifyTVar' cOutStreams $ filter ((snum /=) . fst) +streamClosed conn@Connection {..} snum = atomically $ do + streams <- filter ((snum /=) . fst) <$> readTVar cOutStreams + writeTVar cOutStreams streams + gTestLog cGlobalState $ "net-ostream-close-ack " <> showConnAddress conn <> " " <> show snum <> " " <> show (length streams) readStreamToList :: RawStreamReader -> IO (Word64, [(Word64, BC.ByteString)]) readStreamToList stream = readFlowIO (rsrFlow stream) >>= \case @@ -494,10 +508,11 @@ data ControlMessage addr = NewConnection (Connection addr) (Maybe RefDigest) erebosNetworkProtocol :: (Eq addr, Ord addr, Show addr) => UnifiedIdentity -> (String -> STM ()) + -> (String -> STM ()) -> SymFlow (addr, ByteString) -> Flow (ControlRequest addr) (ControlMessage addr) -> IO () -erebosNetworkProtocol initialIdentity gLog gDataFlow gControlFlow = do +erebosNetworkProtocol initialIdentity gLog gTestLog gDataFlow gControlFlow = do gIdentity <- newTVarIO (initialIdentity, []) gConnections <- newTVarIO [] gNextUp <- newEmptyTMVarIO @@ -561,6 +576,7 @@ newConnection cGlobalState@GlobalState {..} addr = do cOutStreams <- newTVar [] let conn = Connection {..} + gTestLog $ "net-conn-new " <> show cAddress writeTVar gConnections (conn : conns) return conn @@ -917,7 +933,10 @@ processOutgoing gs@GlobalState {..} = do , rsOnAck = rsOnAck rs >> onAck }) <$> mbReserved sendBytes conn mbReserved' bs - Nothing -> return () + Nothing -> do + when (isJust mbReserved) $ do + atomically $ do + modifyTVar' cReservedPackets (subtract 1) let waitUntil :: TimeSpec -> TimeSpec -> STM () waitUntil now till = do diff --git a/src/Erebos/Network/ifaddrs.c b/src/Erebos/Network/ifaddrs.c index ff4382a..8139b5e 100644 --- a/src/Erebos/Network/ifaddrs.c +++ b/src/Erebos/Network/ifaddrs.c @@ -22,7 +22,7 @@ #define DISCOVERY_MULTICAST_GROUP "ff12:b6a4:6b1f:969:caee:acc2:5c93:73e1" -uint32_t * join_multicast(int fd, size_t * count) +uint32_t * erebos_join_multicast(int fd, size_t * count) { size_t capacity = 16; *count = 0; @@ -117,7 +117,7 @@ static bool copy_local_address( struct InetAddress * dst, const struct sockaddr #ifndef _WIN32 -struct InetAddress * local_addresses( size_t * count ) +struct InetAddress * erebos_local_addresses( size_t * count ) { struct ifaddrs * addrs; if( getifaddrs( &addrs ) < 0 ) @@ -153,7 +153,7 @@ struct InetAddress * local_addresses( size_t * count ) return ret; } -uint32_t * broadcast_addresses(void) +uint32_t * erebos_broadcast_addresses(void) { struct ifaddrs * addrs; if (getifaddrs(&addrs) < 0) @@ -196,7 +196,7 @@ uint32_t * broadcast_addresses(void) #pragma comment(lib, "ws2_32.lib") -struct InetAddress * local_addresses( size_t * count ) +struct InetAddress * erebos_local_addresses( size_t * count ) { * count = 0; struct InetAddress * ret = NULL; @@ -237,7 +237,7 @@ cleanup: return ret; } -uint32_t * broadcast_addresses(void) +uint32_t * erebos_broadcast_addresses(void) { uint32_t * ret = NULL; SOCKET wsock = INVALID_SOCKET; diff --git a/src/Erebos/Network/ifaddrs.h b/src/Erebos/Network/ifaddrs.h index 2ee45a7..2b3c014 100644 --- a/src/Erebos/Network/ifaddrs.h +++ b/src/Erebos/Network/ifaddrs.h @@ -13,6 +13,6 @@ struct InetAddress uint8_t addr[16]; } __attribute__((packed)); -uint32_t * join_multicast(int fd, size_t * count); -struct InetAddress * local_addresses( size_t * count ); -uint32_t * broadcast_addresses(void); +uint32_t * erebos_join_multicast(int fd, size_t * count); +struct InetAddress * erebos_local_addresses( size_t * count ); +uint32_t * erebos_broadcast_addresses(void); diff --git a/src/Erebos/Object/Internal.hs b/src/Erebos/Object/Internal.hs index 97ca7a3..fe00579 100644 --- a/src/Erebos/Object/Internal.hs +++ b/src/Erebos/Object/Internal.hs @@ -55,26 +55,27 @@ import Control.Monad.Writer import Crypto.Hash import Data.Bifunctor +import Data.ByteArray qualified as BA import Data.ByteString (ByteString) -import qualified Data.ByteArray as BA -import qualified Data.ByteString as B -import qualified Data.ByteString.Char8 as BC -import qualified Data.ByteString.Lazy as BL -import qualified Data.ByteString.Lazy.Char8 as BLC +import Data.ByteString qualified as B +import Data.ByteString.Char8 qualified as BC +import Data.ByteString.Lazy qualified as BL +import Data.ByteString.Lazy.Char8 qualified as BLC import Data.Char import Data.Function import Data.Maybe import Data.Ratio import Data.Set (Set) -import qualified Data.Set as S +import Data.Set qualified as S import Data.Text (Text) -import qualified Data.Text as T +import Data.Text qualified as T import Data.Text.Encoding import Data.Text.Encoding.Error import Data.Time.Calendar import Data.Time.Clock import Data.Time.Format import Data.Time.LocalTime +import Data.Word import System.IO.Unsafe @@ -129,6 +130,7 @@ copyRecItem' st = \case copyObject' :: forall c c'. (StorageCompleteness c, StorageCompleteness c') => Storage' c' -> Object' c -> IO (c (Object' c')) copyObject' _ (Blob bs) = return $ return $ Blob bs copyObject' st (Rec rs) = fmap Rec . sequence <$> mapM (\( n, item ) -> fmap ( n, ) <$> copyRecItem' st item) rs +copyObject' _ (OnDemand size dgst) = return $ return $ OnDemand size dgst copyObject' _ ZeroObject = return $ return ZeroObject copyObject' _ (UnknownObject otype content) = return $ return $ UnknownObject otype content @@ -150,7 +152,8 @@ partialRefFromDigest st dgst = Ref st dgst data Object' c = Blob ByteString - | Rec [(ByteString, RecItem' c)] + | Rec [ ( ByteString, RecItem' c ) ] + | OnDemand Word64 RefDigest | ZeroObject | UnknownObject ByteString ByteString deriving (Show) @@ -176,8 +179,12 @@ type RecItem = RecItem' Complete serializeObject :: Object' c -> BL.ByteString serializeObject = \case Blob cnt -> BL.fromChunks [BC.pack "blob ", BC.pack (show $ B.length cnt), BC.singleton '\n', cnt] - Rec rec -> let cnt = BL.fromChunks $ concatMap (uncurry serializeRecItem) rec - in BL.fromChunks [BC.pack "rec ", BC.pack (show $ BL.length cnt), BC.singleton '\n'] `BL.append` cnt + Rec rec -> + let cnt = BL.fromChunks $ concatMap (uncurry serializeRecItem) rec + in BL.fromChunks [ BC.pack "rec ", BC.pack (show $ BL.length cnt), BC.singleton '\n' ] `BL.append` cnt + OnDemand size dgst -> + let cnt = BC.unlines [ BC.pack (show size), showRefDigest dgst ] + in BL.fromChunks [ BC.pack "ondemand ", BC.pack (show $ B.length cnt), BC.singleton '\n', cnt ] ZeroObject -> BL.empty UnknownObject otype cnt -> BL.fromChunks [ otype, BC.singleton ' ', BC.pack (show $ B.length cnt), BC.singleton '\n', cnt ] @@ -236,46 +243,72 @@ unsafeDeserializeObject st bytes = (line, rest) | Just (otype, len) <- splitObjPrefix line -> do let (content, next) = first BL.toStrict $ BL.splitAt (fromIntegral len) $ BL.drop 1 rest guard $ B.length content == len - (,next) <$> case otype of - _ | otype == BC.pack "blob" -> return $ Blob content - | otype == BC.pack "rec" -> maybe (throwOtherError $ "malformed record item ") - (return . Rec) $ sequence $ map parseRecLine $ mergeCont [] $ BC.lines content - | otherwise -> return $ UnknownObject otype content + (, next) <$> if + | otype == BC.pack "blob" + -> return $ Blob content + | otype == BC.pack "rec" + , Just ritems <- parseRecordBody st content + -> return $ Rec ritems + | otype == BC.pack "ondemand" + , Just ondemand <- parseOnDemand st content + -> return ondemand + | otherwise + -> return $ UnknownObject otype content _ -> throwOtherError $ "malformed object" - where splitObjPrefix line = do - [otype, tlen] <- return $ BLC.words line - (len, rest) <- BLC.readInt tlen - guard $ BL.null rest - return (BL.toStrict otype, len) - - mergeCont cs (a:b:rest) | Just ('\t', b') <- BC.uncons b = mergeCont (b':BC.pack "\n":cs) (a:rest) - mergeCont cs (a:rest) = B.concat (a : reverse cs) : mergeCont [] rest - mergeCont _ [] = [] - - parseRecLine line = do - colon <- BC.elemIndex ':' line - space <- BC.elemIndex ' ' line - guard $ colon < space - let name = B.take colon line - itype = B.take (space-colon-1) $ B.drop (colon+1) line - content = B.drop (space+1) line - - let val = fromMaybe (RecUnknown itype content) $ - case BC.unpack itype of - "e" -> do guard $ B.null content - return RecEmpty - "i" -> do (num, rest) <- BC.readInteger content - guard $ B.null rest - return $ RecInt num - "n" -> RecNum <$> parseRatio content - "t" -> return $ RecText $ decodeUtf8With lenientDecode content - "b" -> RecBinary <$> readHex content - "d" -> RecDate <$> parseTimeM False defaultTimeLocale "%s %z" (BC.unpack content) - "u" -> RecUUID <$> U.fromASCIIBytes content - "r" -> RecRef . Ref st <$> readRefDigest content - "w" -> RecWeak <$> readRefDigest content - _ -> Nothing - return (name, val) + where + splitObjPrefix line = do + [ otype, tlen ] <- return $ BLC.words line + ( len, rest ) <- BLC.readInt tlen + guard $ BL.null rest + return ( BL.toStrict otype, len ) + +parseRecordBody :: Storage' c -> ByteString -> Maybe [ ( ByteString, RecItem' c ) ] +parseRecordBody _ body | B.null body = Just [] +parseRecordBody st body = do + colon <- BC.elemIndex ':' body + space <- BC.elemIndex ' ' $ B.drop (colon + 1) body + let name = B.take colon body + itype = B.take space $ B.drop (colon + 1) body + ( content, remainingBody ) <- parseTabEscapedLines $ B.drop (space + colon + 2) body + + let val = fromMaybe (RecUnknown itype content) $ + case BC.unpack itype of + "e" -> do guard $ B.null content + return RecEmpty + "i" -> do ( num, rest ) <- BC.readInteger content + guard $ B.null rest + return $ RecInt num + "n" -> RecNum <$> parseRatio content + "t" -> return $ RecText $ decodeUtf8With lenientDecode content + "b" -> RecBinary <$> readHex content + "d" -> RecDate <$> parseTimeM False defaultTimeLocale "%s %z" (BC.unpack content) + "u" -> RecUUID <$> U.fromASCIIBytes content + "r" -> RecRef . Ref st <$> readRefDigest content + "w" -> RecWeak <$> readRefDigest content + _ -> Nothing + (( name, val ) :) <$> parseRecordBody st remainingBody + +-- Split given ByteString on the first newline not preceded by tab; replace +-- "\t\n" in the first part with "\n". +parseTabEscapedLines :: ByteString -> Maybe ( ByteString, ByteString ) +parseTabEscapedLines = parseLines [] + where + parseLines linesReversed cur = do + newline <- BC.elemIndex '\n' cur + case BC.indexMaybe cur (newline + 1) of + Just '\t' -> parseLines (B.take (newline + 1) cur : linesReversed) (B.drop (newline + 2) cur) + _ -> Just ( BC.concat $ reverse $ B.take newline cur : linesReversed, B.drop (newline + 1) cur ) + +parseOnDemand :: Storage' c -> ByteString -> Maybe (Object' c) +parseOnDemand _ body = do + newline1 <- BC.elemIndex '\n' body + newline2 <- BC.elemIndex '\n' $ B.drop (newline1 + 1) body + guard (newline1 + newline2 + 2 == B.length body) + ( size, sizeRest ) <- BC.readWord64 (B.take newline1 body) + guard (B.null sizeRest) + dgst <- readRefDigest $ B.take newline2 $ B.drop (newline1 + 1) body + return $ OnDemand size dgst + deserializeObject :: PartialStorage -> BL.ByteString -> Except ErebosError (PartialObject, BL.ByteString) deserializeObject = unsafeDeserializeObject @@ -332,10 +365,12 @@ class Storable a where class Storable a => ZeroStorable a where fromZero :: Storage -> a -data Store = StoreBlob ByteString - | StoreRec (forall c. StorageCompleteness c => Storage' c -> [IO [(ByteString, RecItem' c)]]) - | StoreZero - | StoreUnknown ByteString ByteString +data Store + = StoreBlob ByteString + | StoreRec (forall c. StorageCompleteness c => Storage' c -> [IO [(ByteString, RecItem' c)]]) + | StoreOnDemand Word64 RefDigest + | StoreZero + | StoreUnknown ByteString ByteString evalStore :: StorageCompleteness c => Storage' c -> Store -> IO (Ref' c) evalStore st = unsafeStoreObject st <=< evalStoreObject st @@ -343,6 +378,7 @@ evalStore st = unsafeStoreObject st <=< evalStoreObject st evalStoreObject :: StorageCompleteness c => Storage' c -> Store -> IO (Object' c) evalStoreObject _ (StoreBlob x) = return $ Blob x evalStoreObject s (StoreRec f) = Rec . concat <$> sequence (f s) +evalStoreObject _ (StoreOnDemand size dgst) = return $ OnDemand size dgst evalStoreObject _ StoreZero = return ZeroObject evalStoreObject _ (StoreUnknown otype content) = return $ UnknownObject otype content @@ -379,6 +415,7 @@ instance Storable Object where store' (Rec xs) = StoreRec $ \st -> return $ do Rec xs' <- copyObject st (Rec xs) return xs' + store' (OnDemand size dgst) = StoreOnDemand size dgst store' ZeroObject = StoreZero store' (UnknownObject otype content) = StoreUnknown otype content @@ -703,8 +740,6 @@ loadRawWeaks name = mapMaybe p <$> loadRecItems -type Stored a = Stored' Complete a - instance Storable a => Storable (Stored a) where store st = copyRef st . storedRef store' (Stored _ x) = store' x @@ -714,10 +749,10 @@ instance ZeroStorable a => ZeroStorable (Stored a) where fromZero st = Stored (zeroRef st) $ fromZero st fromStored :: Stored a -> a -fromStored (Stored _ x) = x +fromStored = storedObject' storedRef :: Stored a -> Ref -storedRef (Stored ref _) = ref +storedRef = storedRef' wrappedStore :: MonadIO m => Storable a => Storage -> a -> m (Stored a) wrappedStore st x = do ref <- liftIO $ store st x @@ -726,9 +761,8 @@ wrappedStore st x = do ref <- liftIO $ store st x wrappedLoad :: Storable a => Ref -> Stored a wrappedLoad ref = Stored ref (load ref) -copyStored :: forall c c' m a. (StorageCompleteness c, StorageCompleteness c', MonadIO m) => - Storage' c' -> Stored' c a -> m (LoadResult c (Stored' c' a)) -copyStored st (Stored ref' x) = liftIO $ returnLoadResult . fmap (flip Stored x) <$> copyRef' st ref' +copyStored :: forall m a. MonadIO m => Storage -> Stored a -> m (Stored a) +copyStored st (Stored ref' x) = liftIO $ returnLoadResult . fmap (\r -> Stored r x) <$> copyRef' st ref' -- |Passed function needs to preserve the object representation to be safe unsafeMapStored :: (a -> b) -> Stored a -> Stored b diff --git a/src/Erebos/Pairing.hs b/src/Erebos/Pairing.hs index e3ebf2b..d1fdc79 100644 --- a/src/Erebos/Pairing.hs +++ b/src/Erebos/Pairing.hs @@ -209,7 +209,7 @@ pairingRequest :: forall a m e proxy. (PairingResult a, MonadIO m, MonadError e pairingRequest _ peer = do self <- liftIO $ serverIdentity $ peerServer peer nonce <- liftIO $ getRandomBytes 32 - pid <- peerIdentity peer >>= \case + pid <- getPeerIdentity peer >>= \case PeerIdentityFull pid -> return pid _ -> throwOtherError "incomplete peer identity" sendToPeerWith @(PairingService a) peer $ \case diff --git a/src/Erebos/Service.hs b/src/Erebos/Service.hs index fefc503..303f9db 100644 --- a/src/Erebos/Service.hs +++ b/src/Erebos/Service.hs @@ -51,6 +51,9 @@ class ( serviceNewPeer :: ServiceHandler s () serviceNewPeer = return () + serviceUpdatedPeer :: ServiceHandler s () + serviceUpdatedPeer = return () + type ServiceAttributes s = attr | attr -> s type ServiceAttributes s = Proxy s defaultServiceAttributes :: proxy s -> ServiceAttributes s @@ -104,7 +107,9 @@ someServiceEmptyGlobalState :: SomeService -> SomeServiceGlobalState someServiceEmptyGlobalState (SomeService p _) = SomeServiceGlobalState p (emptyServiceGlobalState p) -data SomeStorageWatcher s = forall a. Eq a => SomeStorageWatcher (Stored LocalState -> a) (a -> ServiceHandler s ()) +data SomeStorageWatcher s + = forall a. Eq a => SomeStorageWatcher (Stored LocalState -> a) (a -> ServiceHandler s ()) + | forall a. Eq a => GlobalStorageWatcher (Stored LocalState -> a) (Server -> a -> ExceptT ErebosError IO ()) mkServiceID :: String -> ServiceID @@ -113,6 +118,7 @@ mkServiceID = maybe (error "Invalid service ID") ServiceID . U.fromString data ServiceInput s = ServiceInput { svcAttributes :: ServiceAttributes s , svcPeer :: Peer + , svcPeerAddress :: PeerAddress , svcPeerIdentity :: UnifiedIdentity , svcServer :: Server , svcPrintOp :: String -> IO () diff --git a/src/Erebos/Set.hs b/src/Erebos/Set.hs index 270c0ba..7453be4 100644 --- a/src/Erebos/Set.hs +++ b/src/Erebos/Set.hs @@ -10,7 +10,6 @@ module Erebos.Set ( ) where import Control.Arrow -import Control.Monad.IO.Class import Data.Function import Data.List @@ -53,14 +52,14 @@ emptySet = Set [] loadSet :: Mergeable a => Ref -> Set a loadSet = mergeSorted . (:[]) . wrappedLoad -storeSetAdd :: (Mergeable a, MonadIO m) => Storage -> a -> Set a -> m (Set a) -storeSetAdd st x (Set prev) = Set . (:[]) <$> wrappedStore st SetItem +storeSetAdd :: (Mergeable a, MonadStorage m) => a -> Set a -> m (Set a) +storeSetAdd x (Set prev) = Set . (: []) <$> mstore SetItem { siPrev = prev , siItem = toComponents x } -storeSetAddComponent :: (Mergeable a, MonadStorage m, MonadIO m) => Stored (Component a) -> Set a -> m (Set a) -storeSetAddComponent component (Set prev) = Set . (:[]) <$> mstore SetItem +storeSetAddComponent :: (Mergeable a, MonadStorage m) => Stored (Component a) -> Set a -> m (Set a) +storeSetAddComponent component (Set prev) = Set . (: []) <$> mstore SetItem { siPrev = prev , siItem = [ component ] } diff --git a/src/Erebos/State.hs b/src/Erebos/State.hs index 076a8c0..06e5c54 100644 --- a/src/Erebos/State.hs +++ b/src/Erebos/State.hs @@ -6,6 +6,7 @@ module Erebos.State ( MonadStorage(..), MonadHead(..), updateLocalHead_, + LocalHeadT(..), updateLocalState, updateLocalState_, updateSharedState, updateSharedState_, @@ -17,9 +18,11 @@ module Erebos.State ( mergeSharedIdentity, ) where +import Control.Monad import Control.Monad.Except import Control.Monad.Reader +import Data.Bifunctor import Data.ByteString (ByteString) import Data.ByteString.Char8 qualified as BC import Data.Typeable @@ -66,7 +69,7 @@ instance Storable LocalState where lsPrev <- loadMbRawWeak "PREV" lsIdentity <- loadRef "id" lsShared <- loadRefs "shared" - lsOther <- filter ((`notElem` [ BC.pack "id", BC.pack "shared" ]) . fst) <$> loadRecItems + lsOther <- filter ((`notElem` [ BC.pack "PREV", BC.pack "id", BC.pack "shared" ]) . fst) <$> loadRecItems return LocalState {..} instance HeadType LocalState where @@ -101,6 +104,35 @@ instance (HeadType a, MonadIO m) => MonadHead a (ReaderT (Head a) m) where snd <$> updateHead h f +newtype LocalHeadT h m a = LocalHeadT { runLocalHeadT :: Storage -> Stored h -> m ( a, Stored h ) } + +instance Functor m => Functor (LocalHeadT h m) where + fmap f (LocalHeadT act) = LocalHeadT $ \st h -> first f <$> act st h + +instance Monad m => Applicative (LocalHeadT h m) where + pure x = LocalHeadT $ \_ h -> pure ( x, h ) + (<*>) = ap + +instance Monad m => Monad (LocalHeadT h m) where + return = pure + LocalHeadT act >>= f = LocalHeadT $ \st h -> do + ( x, h' ) <- act st h + let (LocalHeadT act') = f x + act' st h' + +instance MonadIO m => MonadIO (LocalHeadT h m) where + liftIO act = LocalHeadT $ \_ h -> ( , h ) <$> liftIO act + +instance MonadIO m => MonadStorage (LocalHeadT h m) where + getStorage = LocalHeadT $ \st h -> return ( st, h ) + +instance (HeadType h, MonadIO m) => MonadHead h (LocalHeadT h m) where + updateLocalHead f = LocalHeadT $ \st h -> do + let LocalHeadT act = f h + ( ( h', x ), _ ) <- act st h + return ( x, h' ) + + localIdentity :: LocalState -> UnifiedIdentity localIdentity ls = maybe (error "failed to verify local identity") (updateOwners $ maybe [] idExtDataF $ lookupSharedValue $ lsShared ls) @@ -128,12 +160,11 @@ updateSharedState :: forall a b m. (SharedType a, MonadHead LocalState m) => (a updateSharedState f = \ls -> do let shared = lsShared $ fromStored ls val = lookupSharedValue shared - st <- getStorage (val', x) <- f val (,x) <$> if toComponents val' == toComponents val then return ls - else do shared' <- makeSharedStateUpdate st val' shared - wrappedStore st (fromStored ls) { lsShared = [shared'] } + else do shared' <- makeSharedStateUpdate val' shared + mstore (fromStored ls) { lsShared = [shared'] } lookupSharedValue :: forall a. SharedType a => [Stored SharedState] -> a lookupSharedValue = mergeSorted . filterAncestors . map wrappedLoad . concatMap (ssValue . fromStored) . filterAncestors . helper @@ -141,8 +172,8 @@ lookupSharedValue = mergeSorted . filterAncestors . map wrappedLoad . concatMap | otherwise = helper $ ssPrev (fromStored x) ++ xs helper [] = [] -makeSharedStateUpdate :: forall a m. MonadIO m => SharedType a => Storage -> a -> [Stored SharedState] -> m (Stored SharedState) -makeSharedStateUpdate st val prev = liftIO $ wrappedStore st SharedState +makeSharedStateUpdate :: forall a m. (SharedType a, MonadStorage m) => a -> [ Stored SharedState ] -> m (Stored SharedState) +makeSharedStateUpdate val prev = mstore SharedState { ssPrev = prev , ssType = Just $ sharedTypeID @a Proxy , ssValue = storedRef <$> toComponents val diff --git a/src/Erebos/Storable.hs b/src/Erebos/Storable.hs index caaf525..5ccb180 100644 --- a/src/Erebos/Storable.hs +++ b/src/Erebos/Storable.hs @@ -11,6 +11,7 @@ defined here as well. module Erebos.Storable ( Storable(..), ZeroStorable(..), StorableText(..), StorableDate(..), StorableUUID(..), + StorageCompleteness(..), Store, StoreRec, storeBlob, storeRec, storeZero, diff --git a/src/Erebos/Storage/Backend.hs b/src/Erebos/Storage/Backend.hs index 620d423..59097b6 100644 --- a/src/Erebos/Storage/Backend.hs +++ b/src/Erebos/Storage/Backend.hs @@ -9,12 +9,15 @@ module Erebos.Storage.Backend ( Complete, Partial, Storage, PartialStorage, newStorage, + refDigestBytes, WatchID, startWatchID, nextWatchID, ) where import Control.Concurrent.MVar +import Data.ByteArray qualified as BA +import Data.ByteString (ByteString) import Data.HashTable.IO qualified as HT import Erebos.Object.Internal @@ -26,3 +29,7 @@ newStorage stBackend = do stRefGeneration <- newMVar =<< HT.new stRefRoots <- newMVar =<< HT.new return Storage {..} + + +refDigestBytes :: RefDigest -> ByteString +refDigestBytes = BA.convert diff --git a/src/Erebos/Storage/Head.hs b/src/Erebos/Storage/Head.hs index 3239fe0..285902d 100644 --- a/src/Erebos/Storage/Head.hs +++ b/src/Erebos/Storage/Head.hs @@ -113,7 +113,7 @@ loadHeadRaw st@Storage {..} tid hid = do -- | Reload the given head from storage, returning `Head' with updated object, -- or `Nothing' if there is no longer head with the particular ID in storage. reloadHead :: (HeadType a, MonadIO m) => Head a -> m (Maybe (Head a)) -reloadHead (Head hid (Stored (Ref st _) _)) = loadHead st hid +reloadHead (Head hid val) = loadHead (storedStorage val) hid -- | Store a new `Head' of type 'a' in the storage. storeHead :: forall a m. MonadIO m => HeadType a => Storage -> a -> m (Head a) @@ -232,8 +232,8 @@ watchHeadWith -> (Head a -> b) -- ^ Selector function -> (b -> IO ()) -- ^ Callback -> IO WatchedHead -- ^ Watched head handle -watchHeadWith (Head hid (Stored (Ref st _) _)) sel cb = do - watchHeadRaw st (headTypeID @a Proxy) hid (sel . Head hid . wrappedLoad) cb +watchHeadWith (Head hid val) sel cb = do + watchHeadRaw (storedStorage val) (headTypeID @a Proxy) hid (sel . Head hid . wrappedLoad) cb -- | Watch the given head using raw IDs and a selector from `Ref'. watchHeadRaw :: forall b. Eq b => Storage -> HeadTypeID -> HeadID -> (Ref -> b) -> (b -> IO ()) -> IO WatchedHead diff --git a/src/Erebos/Storage/Internal.hs b/src/Erebos/Storage/Internal.hs index 303beb3..db211bb 100644 --- a/src/Erebos/Storage/Internal.hs +++ b/src/Erebos/Storage/Internal.hs @@ -20,7 +20,7 @@ module Erebos.Storage.Internal ( Generation(..), HeadID(..), HeadTypeID(..), - Stored'(..), storedStorage, + Stored(..), storedStorage, ) where import Control.Arrow @@ -37,6 +37,7 @@ import Data.ByteArray qualified as BA import Data.ByteString (ByteString) import Data.ByteString.Char8 qualified as BC import Data.ByteString.Lazy qualified as BL +import Data.Function import Data.HashTable.IO qualified as HT import Data.Hashable import Data.Kind @@ -239,17 +240,20 @@ newtype HeadID = HeadID UUID newtype HeadTypeID = HeadTypeID UUID deriving (Eq, Ord) -data Stored' c a = Stored (Ref' c) a +data Stored a = Stored + { storedRef' :: Ref + , storedObject' :: a + } deriving (Show) -instance Eq (Stored' c a) where - Stored r1 _ == Stored r2 _ = refDigest r1 == refDigest r2 +instance Eq (Stored a) where + (==) = (==) `on` (refDigest . storedRef') -instance Ord (Stored' c a) where - compare (Stored r1 _) (Stored r2 _) = compare (refDigest r1) (refDigest r2) +instance Ord (Stored a) where + compare = compare `on` (refDigest . storedRef') -storedStorage :: Stored' c a -> Storage' c -storedStorage (Stored (Ref st _) _) = st +storedStorage :: Stored a -> Storage +storedStorage = refStorage . storedRef' type Complete = Identity diff --git a/src/Erebos/Storage/Memory.hs b/src/Erebos/Storage/Memory.hs index 677e8c5..26bb181 100644 --- a/src/Erebos/Storage/Memory.hs +++ b/src/Erebos/Storage/Memory.hs @@ -4,7 +4,8 @@ module Erebos.Storage.Memory ( derivePartialStorage, ) where -import Control.Concurrent.MVar +import Control.Concurrent +import Control.Monad import Data.ByteArray (ScrubbedBytes) import Data.ByteString.Lazy qualified as BL @@ -62,14 +63,19 @@ instance (StorageCompleteness c, Typeable p) => StorageBackend (MemoryStorage p backendReplaceHead StorageMemory {..} tid hid expected new = do res <- modifyMVar memHeads $ \hs -> do - ws <- map wlFun . filter ((==(tid, hid)) . wlHead) . wlList <$> readMVar memWatchers - return $ case partition ((==(tid, hid)) . fst) hs of - ( [] , _ ) -> ( hs, Left Nothing ) + case partition ((==(tid, hid)) . fst) hs of + ( [] , _ ) -> return ( hs, Left Nothing ) (( _, dgst ) : _, hs' ) - | dgst == expected -> ((( tid, hid ), new ) : hs', Right ( new, ws )) - | otherwise -> ( hs, Left $ Just dgst ) + | dgst == expected -> do + ws <- map wlFun . filter ((==(tid, hid)) . wlHead) . wlList <$> readMVar memWatchers + return ((( tid, hid ), new ) : hs', Right ( new, ws )) + | otherwise -> do + return ( hs, Left $ Just dgst ) case res of - Right ( dgst, ws ) -> mapM_ ($ dgst) ws >> return (Right dgst) + Right ( dgst, ws ) -> do + void $ forkIO $ do + mapM_ ($ dgst) ws + return (Right dgst) Left x -> return $ Left x backendWatchHead StorageMemory {..} tid hid cb = modifyMVar memWatchers $ return . watchListAdd tid hid cb diff --git a/src/Erebos/Storage/Merge.hs b/src/Erebos/Storage/Merge.hs index 41725af..8221e91 100644 --- a/src/Erebos/Storage/Merge.hs +++ b/src/Erebos/Storage/Merge.hs @@ -7,7 +7,7 @@ module Erebos.Storage.Merge ( compareGeneration, generationMax, storedGeneration, - generations, + generations, generationsBy, ancestors, precedes, precedesOrEquals, @@ -17,6 +17,8 @@ module Erebos.Storage.Merge ( findProperty, findPropertyFirst, + + storedDifference, ) where import Control.Concurrent.MVar @@ -25,6 +27,8 @@ import Data.ByteString.Char8 qualified as BC import Data.HashTable.IO qualified as HT import Data.Kind import Data.List +import Data.List.NonEmpty (NonEmpty) +import Data.List.NonEmpty qualified as NE import Data.Maybe import Data.Set (Set) import Data.Set qualified as S @@ -52,7 +56,7 @@ merge xs = mergeSorted $ filterAncestors xs storeMerge :: (Mergeable a, Storable a) => [Stored (Component a)] -> IO (Stored a) storeMerge [] = error "merge: empty list" -storeMerge xs@(Stored ref _ : _) = wrappedStore (refStorage ref) $ mergeSorted $ filterAncestors xs +storeMerge xs@(x : _) = wrappedStore (storedStorage x) $ mergeSorted $ filterAncestors xs previous :: Storable a => Stored a -> [Stored a] previous (Stored ref _) = case load ref of @@ -100,16 +104,24 @@ storedGeneration x = -- |Returns list of sets starting with the set of given objects and -- intcrementally adding parents. -generations :: Storable a => [Stored a] -> [Set (Stored a)] -generations = unfoldr gen . (,S.empty) - where gen (hs, cur) = case filter (`S.notMember` cur) hs of - [] -> Nothing - added -> let next = foldr S.insert cur added - in Just (next, (previous =<< added, next)) +generations :: Storable a => [ Stored a ] -> NonEmpty (Set (Stored a)) +generations = generationsBy previous + +-- |Returns list of sets starting with the set of given objects and +-- intcrementally adding parents, with the first parameter being +-- a function to get all the parents of given object. +generationsBy :: Ord a => (a -> [ a ]) -> [ a ] -> NonEmpty (Set a) +generationsBy parents xs = NE.unfoldr gen ( xs, S.fromList xs ) + where + gen ( hs, cur ) = ( cur, ) $ + case filter (`S.notMember` cur) (parents =<< hs) of + [] -> Nothing + added -> let next = foldr S.insert cur added + in Just ( added, next ) -- |Returns set containing all given objects and their ancestors ancestors :: Storable a => [Stored a] -> Set (Stored a) -ancestors = last . (S.empty:) . generations +ancestors = NE.last . generations precedes :: Storable a => Stored a -> Stored a -> Bool precedes x y = not $ x `elem` filterAncestors [x, y] @@ -162,3 +174,18 @@ findPropertyFirst sel = fmap (fromJust . sel . fromStored) . listToMaybe . filte findPropHeads :: forall a b. Storable a => (a -> Maybe b) -> Stored a -> [Stored a] findPropHeads sel sobj | Just _ <- sel $ fromStored sobj = [sobj] | otherwise = findPropHeads sel =<< previous sobj + + +-- | Compute symmetrict difference between two stored histories. In other +-- words, return all 'Stored a' objects reachable (via 'previous') from first +-- given set, but not from the second; and vice versa. +storedDifference :: Storable a => [ Stored a ] -> [ Stored a ] -> [ Stored a ] +storedDifference xs' ys' = + let xs = filterAncestors xs' + ys = filterAncestors ys' + + filteredPrevious blocked zs = filterAncestors (previous zs ++ blocked) `diffSorted` blocked + xg = S.toAscList $ NE.last $ generationsBy (filteredPrevious ys) $ filterAncestors (xs ++ ys) `diffSorted` ys + yg = S.toAscList $ NE.last $ generationsBy (filteredPrevious xs) $ filterAncestors (ys ++ xs) `diffSorted` xs + + in xg `mergeUniq` yg diff --git a/src/Erebos/Sync.hs b/src/Erebos/Sync.hs index d837a14..5f5fdec 100644 --- a/src/Erebos/Sync.hs +++ b/src/Erebos/Sync.hs @@ -31,6 +31,7 @@ instance Service SyncService where else return ls serviceNewPeer = notifyPeer . lsShared . fromStored =<< svcGetLocal + serviceUpdatedPeer = serviceNewPeer serviceStorageWatchers _ = (:[]) $ SomeStorageWatcher (lsShared . fromStored) notifyPeer instance Storable SyncService where diff --git a/src/Erebos/Util.hs b/src/Erebos/Util.hs index 0381c3e..0d53e98 100644 --- a/src/Erebos/Util.hs +++ b/src/Erebos/Util.hs @@ -22,15 +22,16 @@ mergeBy cmp (x : xs) (y : ys) = case cmp x y of mergeBy _ xs [] = xs mergeBy _ [] ys = ys -mergeUniqBy :: (a -> a -> Ordering) -> [a] -> [a] -> [a] -mergeUniqBy cmp (x : xs) (y : ys) = case cmp x y of - LT -> x : mergeBy cmp xs (y : ys) - EQ -> x : mergeBy cmp xs ys - GT -> y : mergeBy cmp (x : xs) ys +mergeUniqBy :: (a -> a -> Ordering) -> [ a ] -> [ a ] -> [ a ] +mergeUniqBy cmp (x : xs) (y : ys) = + case cmp x y of + LT -> x : mergeUniqBy cmp xs (y : ys) + EQ -> x : mergeUniqBy cmp xs ys + GT -> y : mergeUniqBy cmp (x : xs) ys mergeUniqBy _ xs [] = xs mergeUniqBy _ [] ys = ys -mergeUniq :: Ord a => [a] -> [a] -> [a] +mergeUniq :: Ord a => [ a ] -> [ a ] -> [ a ] mergeUniq = mergeUniqBy compare diffSorted :: Ord a => [a] -> [a] -> [a] |