diff options
Diffstat (limited to 'src/Erebos')
-rw-r--r-- | src/Erebos/Chatroom.hs | 216 | ||||
-rw-r--r-- | src/Erebos/Conversation.hs | 36 | ||||
-rw-r--r-- | src/Erebos/Identity.hs | 33 | ||||
-rw-r--r-- | src/Erebos/Network.hs | 120 | ||||
-rw-r--r-- | src/Erebos/Network/Protocol.hs | 164 | ||||
-rw-r--r-- | src/Erebos/Network/ifaddrs.c | 132 | ||||
-rw-r--r-- | src/Erebos/Network/ifaddrs.h | 2 | ||||
-rw-r--r-- | src/Erebos/Storage/Internal.hs | 6 | ||||
-rw-r--r-- | src/Erebos/Storage/Key.hs | 7 | ||||
-rw-r--r-- | src/Erebos/Storage/Merge.hs | 7 |
10 files changed, 596 insertions, 127 deletions
diff --git a/src/Erebos/Chatroom.hs b/src/Erebos/Chatroom.hs index 673c59f..c8b5805 100644 --- a/src/Erebos/Chatroom.hs +++ b/src/Erebos/Chatroom.hs @@ -11,14 +11,20 @@ module Erebos.Chatroom ( findChatroomByRoomData, findChatroomByStateData, chatroomSetSubscribe, + chatroomMembers, + joinChatroom, joinChatroomByStateData, + leaveChatroom, leaveChatroomByStateData, getMessagesSinceState, ChatroomSetChange(..), watchChatrooms, - ChatMessage, cmsgFrom, cmsgReplyTo, cmsgTime, cmsgText, cmsgLeave, + ChatMessage, + cmsgFrom, cmsgReplyTo, cmsgTime, cmsgText, cmsgLeave, + cmsgRoom, cmsgRoomData, ChatMessageData(..), - chatroomMessageByStateData, + sendChatroomMessage, + sendChatroomMessageByStateData, ChatroomService(..), ) where @@ -29,6 +35,9 @@ import Control.Monad.Except import Control.Monad.IO.Class import Data.Bool +import Data.Either +import Data.Foldable +import Data.Function import Data.IORef import Data.List import Data.Maybe @@ -111,6 +120,11 @@ data ChatMessage = ChatMessage { cmsgData :: Stored (Signed ChatMessageData) } +validateSingleMessage :: Stored (Signed ChatMessageData) -> Maybe ChatMessage +validateSingleMessage sdata = do + guard $ fromStored sdata `isSignedBy` idKeyMessage (mdFrom (fromSigned sdata)) + return $ ChatMessage sdata + cmsgFrom :: ChatMessage -> ComposedIdentity cmsgFrom = mdFrom . fromSigned . cmsgData @@ -126,6 +140,12 @@ cmsgText = mdText . fromSigned . cmsgData cmsgLeave :: ChatMessage -> Bool cmsgLeave = mdLeave . fromSigned . cmsgData +cmsgRoom :: ChatMessage -> Maybe Chatroom +cmsgRoom = either (const Nothing) Just . runExcept . validateChatroom . cmsgRoomData + +cmsgRoomData :: ChatMessage -> [ Stored (Signed ChatroomData) ] +cmsgRoomData = concat . findProperty ((\case [] -> Nothing; xs -> Just xs) . mdRoom . fromStored . signedData) . (: []) . cmsgData + instance Storable ChatMessageData where store' ChatMessageData {..} = storeRec $ do mapM_ (storeRef "SPREV") mdPrev @@ -146,37 +166,42 @@ instance Storable ChatMessageData where mdLeave <- isJust <$> loadMbEmpty "leave" return ChatMessageData {..} -threadToList :: [Stored (Signed ChatMessageData)] -> [ChatMessage] -threadToList thread = helper S.empty $ thread +threadToListSince :: [ Stored (Signed ChatMessageData) ] -> [ Stored (Signed ChatMessageData) ] -> [ ChatMessage ] +threadToListSince since thread = helper (S.fromList since) thread where helper :: S.Set (Stored (Signed ChatMessageData)) -> [Stored (Signed ChatMessageData)] -> [ChatMessage] helper seen msgs | msg : msgs' <- filter (`S.notMember` seen) $ reverse $ sortBy (comparing cmpView) msgs = - messageFromData msg : helper (S.insert msg seen) (msgs' ++ mdPrev (fromSigned msg)) + maybe id (:) (validateSingleMessage msg) $ + helper (S.insert msg seen) (msgs' ++ mdPrev (fromSigned msg)) | otherwise = [] cmpView msg = (zonedTimeToUTC $ mdTime $ fromSigned msg, msg) - messageFromData :: Stored (Signed ChatMessageData) -> ChatMessage - messageFromData sdata = ChatMessage { cmsgData = sdata } +sendChatroomMessage + :: (MonadStorage m, MonadHead LocalState m, MonadError String m) + => ChatroomState -> Text -> m () +sendChatroomMessage rstate msg = sendChatroomMessageByStateData (head $ roomStateData rstate) msg -chatroomMessageByStateData +sendChatroomMessageByStateData :: (MonadStorage m, MonadHead LocalState m, MonadError String m) => Stored ChatroomStateData -> Text -> m () -chatroomMessageByStateData lookupData msg = void $ findAndUpdateChatroomState $ \cstate -> do +sendChatroomMessageByStateData lookupData msg = sendRawChatroomMessageByStateData lookupData Nothing (Just msg) False + +sendRawChatroomMessageByStateData + :: (MonadStorage m, MonadHead LocalState m, MonadError String m) + => Stored ChatroomStateData -> Maybe (Stored (Signed ChatMessageData)) -> Maybe Text -> Bool -> m () +sendRawChatroomMessageByStateData lookupData mdReplyTo mdText mdLeave = void $ findAndUpdateChatroomState $ \cstate -> do guard $ any (lookupData `precedesOrEquals`) $ roomStateData cstate Just $ do - self <- finalOwner . localIdentity . fromStored <$> getLocalHead - secret <- loadKey $ idKeyMessage self - time <- liftIO getZonedTime - mdata <- mstore =<< sign secret =<< mstore ChatMessageData - { mdPrev = roomStateMessageData cstate - , mdRoom = [] - , mdFrom = self - , mdReplyTo = Nothing - , mdTime = time - , mdText = Just msg - , mdLeave = False - } + mdFrom <- finalOwner . localIdentity . fromStored <$> getLocalHead + secret <- loadKey $ idKeyMessage mdFrom + mdTime <- liftIO getZonedTime + let mdPrev = roomStateMessageData cstate + mdRoom = if null (roomStateMessageData cstate) + then maybe [] roomData (roomStateRoom cstate) + else [] + + mdata <- mstore =<< sign secret =<< mstore ChatMessageData {..} mergeSorted . (:[]) <$> mstore ChatroomStateData { rsdPrev = roomStateData cstate , rsdRoom = [] @@ -224,7 +249,7 @@ instance Mergeable ChatroomState where ChatroomStateData {..} | null rsdMessages -> Nothing | otherwise -> Just rsdMessages roomStateSubscribe = fromMaybe False $ findPropertyFirst rsdSubscribe roomStateData - roomStateMessages = threadToList $ concatMap (rsdMessages . fromStored) roomStateData + roomStateMessages = threadToListSince [] $ concatMap (rsdMessages . fromStored) roomStateData in ChatroomState {..} toComponents = roomStateData @@ -321,11 +346,38 @@ chatroomSetSubscribe lookupData subscribe = void $ findAndUpdateChatroomState $ , rsdMessages = [] } +chatroomMembers :: ChatroomState -> [ ComposedIdentity ] +chatroomMembers ChatroomState {..} = + map (mdFrom . fromSigned . head) $ + filter (any $ not . mdLeave . fromSigned) $ -- keep only users that hasn't left + map (filterAncestors . map snd) $ -- gather message data per each identity and filter ancestors + groupBy ((==) `on` fst) $ -- group on identity root + sortBy (comparing fst) $ -- sort by first root of identity data + map (\x -> ( head . filterAncestors . concatMap storedRoots . idDataF . mdFrom . fromSigned $ x, x )) $ + toList $ ancestors $ roomStateMessageData + +joinChatroom + :: (MonadStorage m, MonadHead LocalState m, MonadError String m) + => ChatroomState -> m () +joinChatroom rstate = joinChatroomByStateData (head $ roomStateData rstate) + +joinChatroomByStateData + :: (MonadStorage m, MonadHead LocalState m, MonadError String m) + => Stored ChatroomStateData -> m () +joinChatroomByStateData lookupData = sendRawChatroomMessageByStateData lookupData Nothing Nothing False + +leaveChatroom + :: (MonadStorage m, MonadHead LocalState m, MonadError String m) + => ChatroomState -> m () +leaveChatroom rstate = leaveChatroomByStateData (head $ roomStateData rstate) + +leaveChatroomByStateData + :: (MonadStorage m, MonadHead LocalState m, MonadError String m) + => Stored ChatroomStateData -> m () +leaveChatroomByStateData lookupData = sendRawChatroomMessageByStateData lookupData Nothing Nothing True + getMessagesSinceState :: ChatroomState -> ChatroomState -> [ChatMessage] -getMessagesSinceState cur old = takeWhile notOld (roomStateMessages cur) - where - notOld msg = cmsgData msg `notElem` roomStateMessageData old - -- TODO: parallel message threads +getMessagesSinceState cur old = threadToListSince (roomStateMessageData old) (roomStateMessageData cur) data ChatroomSetChange = AddedChatroom ChatroomState @@ -365,13 +417,18 @@ makeChatroomDiff [] ys = map (AddedChatroom . snd) ys data ChatroomService = ChatroomService { chatRoomQuery :: Bool , chatRoomInfo :: [Stored (Signed ChatroomData)] + , chatRoomSubscribe :: [Stored (Signed ChatroomData)] + , chatRoomUnsubscribe :: [Stored (Signed ChatroomData)] , chatRoomMessage :: [Stored (Signed ChatMessageData)] } + deriving (Eq) emptyPacket :: ChatroomService emptyPacket = ChatroomService { chatRoomQuery = False , chatRoomInfo = [] + , chatRoomSubscribe = [] + , chatRoomUnsubscribe = [] , chatRoomMessage = [] } @@ -379,17 +436,22 @@ instance Storable ChatroomService where store' ChatroomService {..} = storeRec $ do when chatRoomQuery $ storeEmpty "room-query" forM_ chatRoomInfo $ storeRef "room-info" + forM_ chatRoomSubscribe $ storeRef "room-subscribe" + forM_ chatRoomUnsubscribe $ storeRef "room-unsubscribe" forM_ chatRoomMessage $ storeRef "room-message" load' = loadRec $ do chatRoomQuery <- isJust <$> loadMbEmpty "room-query" chatRoomInfo <- loadRefs "room-info" + chatRoomSubscribe <- loadRefs "room-subscribe" + chatRoomUnsubscribe <- loadRefs "room-unsubscribe" chatRoomMessage <- loadRefs "room-message" return ChatroomService {..} data PeerState = PeerState { psSendRoomUpdates :: Bool , psLastList :: [(Stored ChatroomStateData, ChatroomState)] + , psSubscribedTo :: [ Stored (Signed ChatroomData) ] -- least root for each room } instance Service ChatroomService where @@ -399,12 +461,18 @@ instance Service ChatroomService where emptyServiceState _ = PeerState { psSendRoomUpdates = False , psLastList = [] + , psSubscribedTo = [] } serviceHandler spacket = do let ChatroomService {..} = fromStored spacket + + previouslyUpdated <- psSendRoomUpdates <$> svcGet svcModify $ \s -> s { psSendRoomUpdates = True } + when (not previouslyUpdated) $ do + syncChatroomsToPeer . lookupSharedValue . lsShared . fromStored =<< getLocalHead + when chatRoomQuery $ do rooms <- listChatrooms replyPacket emptyPacket @@ -420,7 +488,7 @@ instance Service ChatroomService where maybe [] roomData . roomStateRoom let prev = concatMap roomStateData $ filter isCurrentRoom rooms - prevRoom = concatMap (rsdRoom . fromStored) prev + prevRoom = filterAncestors $ concat $ findProperty ((\case [] -> Nothing; xs -> Just xs) . rsdRoom) prev room = filterAncestors $ (roomInfo : ) prevRoom -- update local state only if we got roomInfo not present there @@ -436,6 +504,51 @@ instance Service ChatroomService where else return set foldM upd roomSet chatRoomInfo + forM_ chatRoomSubscribe $ \subscribeData -> do + mbRoomState <- findChatroomByRoomData subscribeData + forM_ mbRoomState $ \roomState -> + forM (roomStateRoom roomState) $ \room -> do + let leastRoot = head . filterAncestors . concatMap storedRoots . roomData $ room + svcModify $ \ps -> ps { psSubscribedTo = leastRoot : psSubscribedTo ps } + replyPacket emptyPacket + { chatRoomMessage = roomStateMessageData roomState + } + + forM_ chatRoomUnsubscribe $ \unsubscribeData -> do + mbRoomState <- findChatroomByRoomData unsubscribeData + forM_ (mbRoomState >>= roomStateRoom) $ \room -> do + let leastRoot = head . filterAncestors . concatMap storedRoots . roomData $ room + svcModify $ \ps -> ps { psSubscribedTo = filter (/= leastRoot) (psSubscribedTo ps) } + + when (not (null chatRoomMessage)) $ do + updateLocalHead_ $ updateSharedState_ $ \roomSet -> do + let rooms = fromSetBy (comparing $ roomName <=< roomStateRoom) roomSet + upd set (msgData :: Stored (Signed ChatMessageData)) + | Just msg <- validateSingleMessage msgData = do + let roomInfo = cmsgRoomData msg + currentRoots = filterAncestors $ concatMap storedRoots roomInfo + isCurrentRoom = any ((`intersectsSorted` currentRoots) . storedRoots) . + maybe [] roomData . roomStateRoom + + let prevData = concatMap roomStateData $ filter isCurrentRoom rooms + prev = mergeSorted prevData + prevMessages = roomStateMessageData prev + messages = filterAncestors $ msgData : prevMessages + + -- update local state only if subscribed and we got some new messages + if roomStateSubscribe prev && messages /= prevMessages + then do + sdata <- mstore ChatroomStateData + { rsdPrev = prevData + , rsdRoom = [] + , rsdSubscribe = Nothing + , rsdMessages = messages + } + storeSetAddComponent sdata set + else return set + | otherwise = return set + foldM upd roomSet chatRoomMessage + serviceNewPeer = do replyPacket emptyPacket { chatRoomQuery = True } @@ -447,11 +560,50 @@ syncChatroomsToPeer set = do ps@PeerState {..} <- svcGet when psSendRoomUpdates $ do let curList = chatroomSetToList set - updates <- fmap (concat . catMaybes) $ - forM (makeChatroomDiff psLastList curList) $ return . \case + diff = makeChatroomDiff psLastList curList + + roomUpdates <- fmap (concat . catMaybes) $ + forM diff $ return . \case AddedChatroom room -> roomData <$> roomStateRoom room RemovedChatroom {} -> Nothing - UpdatedChatroom _ room -> roomData <$> roomStateRoom room - when (not $ null updates) $ do - replyPacket $ emptyPacket { chatRoomInfo = updates } + UpdatedChatroom oldroom room + | roomStateData oldroom /= roomStateData room -> roomData <$> roomStateRoom room + | otherwise -> Nothing + + (subscribe, unsubscribe) <- fmap (partitionEithers . concat . catMaybes) $ + forM diff $ return . \case + AddedChatroom room + | roomStateSubscribe room + -> map Left . roomData <$> roomStateRoom room + RemovedChatroom oldroom + | roomStateSubscribe oldroom + -> map Right . roomData <$> roomStateRoom oldroom + UpdatedChatroom oldroom room + | roomStateSubscribe oldroom /= roomStateSubscribe room + -> map (if roomStateSubscribe room then Left else Right) . roomData <$> roomStateRoom room + _ -> Nothing + + messages <- fmap concat $ do + let leastRootFor = head . filterAncestors . concatMap storedRoots . roomData + forM diff $ return . \case + AddedChatroom rstate + | Just room <- roomStateRoom rstate + , leastRootFor room `elem` psSubscribedTo + -> roomStateMessageData rstate + UpdatedChatroom oldstate rstate + | Just room <- roomStateRoom rstate + , leastRootFor room `elem` psSubscribedTo + , roomStateMessageData oldstate /= roomStateMessageData rstate + -> roomStateMessageData rstate + _ -> [] + + let packet = emptyPacket + { chatRoomInfo = roomUpdates + , chatRoomSubscribe = subscribe + , chatRoomUnsubscribe = unsubscribe + , chatRoomMessage = messages + } + + when (packet /= emptyPacket) $ do + replyPacket packet svcSet $ ps { psLastList = curList } diff --git a/src/Erebos/Conversation.hs b/src/Erebos/Conversation.hs index 94d2399..63475bd 100644 --- a/src/Erebos/Conversation.hs +++ b/src/Erebos/Conversation.hs @@ -1,12 +1,15 @@ module Erebos.Conversation ( Message, messageFrom, + messageTime, messageText, messageUnread, formatMessage, Conversation, directMessageConversation, + chatroomConversation, + chatroomConversationByStateData, reloadConversation, lookupConversations, @@ -23,30 +26,45 @@ import Data.List import Data.Maybe import Data.Text (Text) import Data.Text qualified as T +import Data.Time.Format import Data.Time.LocalTime import Erebos.Identity +import Erebos.Chatroom import Erebos.Message hiding (formatMessage) import Erebos.State import Erebos.Storage data Message = DirectMessageMessage DirectMessage Bool + | ChatroomMessage ChatMessage Bool messageFrom :: Message -> ComposedIdentity messageFrom (DirectMessageMessage msg _) = msgFrom msg +messageFrom (ChatroomMessage msg _) = cmsgFrom msg + +messageTime :: Message -> ZonedTime +messageTime (DirectMessageMessage msg _) = msgTime msg +messageTime (ChatroomMessage msg _) = cmsgTime msg messageText :: Message -> Maybe Text messageText (DirectMessageMessage msg _) = Just $ msgText msg +messageText (ChatroomMessage msg _) = cmsgText msg messageUnread :: Message -> Bool messageUnread (DirectMessageMessage _ unread) = unread +messageUnread (ChatroomMessage _ unread) = unread formatMessage :: TimeZone -> Message -> String -formatMessage tzone (DirectMessageMessage msg _) = formatDirectMessage tzone msg +formatMessage tzone msg = concat + [ formatTime defaultTimeLocale "[%H:%M] " $ utcToLocalTime tzone $ zonedTimeToUTC $ messageTime msg + , maybe "<unnamed>" T.unpack $ idName $ messageFrom msg + , maybe "" ((": "<>) . T.unpack) $ messageText msg + ] data Conversation = DirectMessageConversation DirectMessageThread + | ChatroomConversation ChatroomState directMessageConversation :: MonadHead LocalState m => ComposedIdentity -> m Conversation directMessageConversation peer = do @@ -54,8 +72,16 @@ directMessageConversation peer = do Just thread -> return $ DirectMessageConversation thread Nothing -> return $ DirectMessageConversation $ DirectMessageThread peer [] [] [] +chatroomConversation :: MonadHead LocalState m => ChatroomState -> m (Maybe Conversation) +chatroomConversation rstate = chatroomConversationByStateData (head $ roomStateData rstate) + +chatroomConversationByStateData :: MonadHead LocalState m => Stored ChatroomStateData -> m (Maybe Conversation) +chatroomConversationByStateData sdata = fmap ChatroomConversation <$> findChatroomByStateData sdata + reloadConversation :: MonadHead LocalState m => Conversation -> m Conversation reloadConversation (DirectMessageConversation thread) = directMessageConversation (msgPeer thread) +reloadConversation cur@(ChatroomConversation rstate) = + fromMaybe cur <$> chatroomConversation rstate lookupConversations :: MonadHead LocalState m => m [Conversation] lookupConversations = map DirectMessageConversation . toThreadList . lookupSharedValue . lsShared . fromStored <$> getLocalHead @@ -63,13 +89,17 @@ lookupConversations = map DirectMessageConversation . toThreadList . lookupShare conversationName :: Conversation -> Text conversationName (DirectMessageConversation thread) = fromMaybe (T.pack "<unnamed>") $ idName $ msgPeer thread +conversationName (ChatroomConversation rstate) = fromMaybe (T.pack "<unnamed>") $ roomName =<< roomStateRoom rstate conversationPeer :: Conversation -> Maybe ComposedIdentity conversationPeer (DirectMessageConversation thread) = Just $ msgPeer thread +conversationPeer (ChatroomConversation _) = Nothing conversationHistory :: Conversation -> [Message] conversationHistory (DirectMessageConversation thread) = map (\msg -> DirectMessageMessage msg False) $ threadToList thread +conversationHistory (ChatroomConversation rstate) = map (\msg -> ChatroomMessage msg False) $ roomStateMessages rstate -sendMessage :: (MonadHead LocalState m, MonadError String m) => Conversation -> Text -> m Message -sendMessage (DirectMessageConversation thread) text = DirectMessageMessage <$> (fromStored <$> sendDirectMessage (msgPeer thread) text) <*> pure False +sendMessage :: (MonadHead LocalState m, MonadError String m) => Conversation -> Text -> m (Maybe Message) +sendMessage (DirectMessageConversation thread) text = fmap Just $ DirectMessageMessage <$> (fromStored <$> sendDirectMessage (msgPeer thread) text) <*> pure False +sendMessage (ChatroomConversation rstate) text = sendChatroomMessage rstate text >> return Nothing diff --git a/src/Erebos/Identity.hs b/src/Erebos/Identity.hs index 8761fde..f2094f6 100644 --- a/src/Erebos/Identity.hs +++ b/src/Erebos/Identity.hs @@ -35,7 +35,6 @@ import Data.Foldable import Data.Function import Data.List import Data.Maybe -import Data.Ord import Data.Set (Set) import qualified Data.Set as S import Data.Text (Text) @@ -304,25 +303,18 @@ verifySignatures sidd = do throwError "signature verification failed" lookupProperty :: forall a m. Foldable m => (ExtendedIdentityData -> Maybe a) -> m (Stored (Signed ExtendedIdentityData)) -> Maybe a -lookupProperty sel topHeads = findResult filteredLayers - where findPropHeads :: Stored (Signed ExtendedIdentityData) -> [(Stored (Signed ExtendedIdentityData), a)] - findPropHeads sobj | Just x <- sel $ fromSigned sobj = [(sobj, x)] - | otherwise = findPropHeads =<< (eiddPrev $ fromSigned sobj) +lookupProperty sel topHeads = findResult propHeads + where + findPropHeads :: Stored (Signed ExtendedIdentityData) -> [ Stored (Signed ExtendedIdentityData) ] + findPropHeads sobj | Just _ <- sel $ fromSigned sobj = [ sobj ] + | otherwise = findPropHeads =<< (eiddPrev $ fromSigned sobj) - propHeads :: [(Stored (Signed ExtendedIdentityData), a)] - propHeads = findPropHeads =<< toList topHeads + propHeads :: [ Stored (Signed ExtendedIdentityData) ] + propHeads = filterAncestors $ findPropHeads =<< toList topHeads - historyLayers :: [Set (Stored (Signed ExtendedIdentityData))] - historyLayers = generations $ map fst propHeads - - filteredLayers :: [[(Stored (Signed ExtendedIdentityData), a)]] - filteredLayers = scanl (\cur obsolete -> filter ((`S.notMember` obsolete) . fst) cur) propHeads historyLayers - - findResult ([(_, x)] : _) = Just x - findResult ([] : _) = Nothing - findResult [] = Nothing - findResult [xs] = Just $ snd $ minimumBy (comparing fst) xs - findResult (_:rest) = findResult rest + findResult :: [ Stored (Signed ExtendedIdentityData) ] -> Maybe a + findResult [] = Nothing + findResult xs = sel $ fromSigned $ minimum xs mergeIdentity :: (MonadStorage m, MonadError String m, MonadIO m) => Identity f -> m UnifiedIdentity mergeIdentity idt | Just idt' <- toUnifiedIdentity idt = return idt' @@ -385,8 +377,9 @@ updateOwners updates orig@Identity { idOwner_ = Just owner, idUpdates_ = cupdate updateOwners _ orig@Identity { idOwner_ = Nothing } = orig sameIdentity :: (Foldable m, Foldable m') => Identity m -> Identity m' -> Bool -sameIdentity x y = not $ S.null $ S.intersection (refset x) (refset y) - where refset idt = foldr S.insert (ancestors $ toList $ idDataF idt) (idDataF idt) +sameIdentity x y = intersectsSorted (roots x) (roots y) + where + roots idt = uniq $ sort $ concatMap storedRoots $ toList $ idDataF idt unfoldOwners :: (Foldable m) => Identity m -> [ComposedIdentity] diff --git a/src/Erebos/Network.hs b/src/Erebos/Network.hs index 41b6279..2064d1c 100644 --- a/src/Erebos/Network.hs +++ b/src/Erebos/Network.hs @@ -19,7 +19,9 @@ module Erebos.Network ( #endif dropPeer, isPeerDropped, - sendToPeer, sendToPeerStored, sendToPeerWith, + sendToPeer, sendManyToPeer, + sendToPeerStored, sendManyToPeerStored, + sendToPeerWith, runPeerService, discoveryPort, @@ -52,6 +54,9 @@ import GHC.Conc.Sync (unsafeIOToSTM) import Network.Socket hiding (ControlMessage) import qualified Network.Socket.ByteString as S +import Foreign.C.Types +import Foreign.Marshal.Alloc + import Erebos.Channel #ifdef ENABLE_ICE_SUPPORT import Erebos.ICE @@ -69,6 +74,9 @@ import Erebos.Storage.Merge discoveryPort :: PortNumber discoveryPort = 29665 +discoveryMulticastGroup :: HostAddress6 +discoveryMulticastGroup = tupleToHostAddress6 (0xff12, 0xb6a4, 0x6b1f, 0x0969, 0xcaee, 0xacc2, 0x5c93, 0x73e1) -- ff12:b6a4:6b1f:969:caee:acc2:5c93:73e1 + announceIntervalSeconds :: Int announceIntervalSeconds = 60 @@ -247,8 +255,6 @@ startServer opt serverOrigHead logd' serverServices = do either (atomically . logd) return =<< runExceptT =<< atomically (readTQueue serverIOActions) - broadcastAddreses <- getBroadcastAddresses discoveryPort - let open addr = do sock <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr) putMVar serverSocket sock @@ -259,9 +265,14 @@ startServer opt serverOrigHead logd' serverServices = do return sock loop sock = do - when (serverLocalDiscovery opt) $ forkServerThread server $ forever $ do - atomically $ writeFlowBulk serverControlFlow $ map (SendAnnounce . DatagramAddress) broadcastAddreses - threadDelay $ announceIntervalSeconds * 1000 * 1000 + when (serverLocalDiscovery opt) $ forkServerThread server $ do + announceAddreses <- fmap concat $ sequence $ + [ map (SockAddrInet6 discoveryPort 0 discoveryMulticastGroup) <$> joinMulticast sock + , getBroadcastAddresses discoveryPort + ] + forever $ do + atomically $ writeFlowBulk serverControlFlow $ map (SendAnnounce . DatagramAddress) announceAddreses + threadDelay $ announceIntervalSeconds * 1000 * 1000 let announceUpdate identity = do st <- derivePartialStorage serverStorage @@ -301,10 +312,11 @@ startServer opt serverOrigHead logd' serverServices = do forkServerThread server $ forever $ do (paddr, msg) <- readFlowIO serverRawPath - case paddr of - DatagramAddress addr -> void $ S.sendTo sock msg addr + handle (\(e :: IOException) -> atomically . logd $ "failed to send packet to " ++ show paddr ++ ": " ++ show e) $ do + case paddr of + DatagramAddress addr -> void $ S.sendTo sock msg addr #ifdef ENABLE_ICE_SUPPORT - PeerIceSession ice -> iceSend ice msg + PeerIceSession ice -> iceSend ice msg #endif forkServerThread server $ forever $ do @@ -421,12 +433,18 @@ instance MonadFail PacketHandler where runPacketHandler :: Bool -> Peer -> PacketHandler () -> STM () runPacketHandler secure peer@Peer {..} act = do let logd = writeTQueue $ serverErrorLog peerServer_ - runExceptT (flip execStateT (PacketHandlerState peer [] [] [] False) $ unPacketHandler act) >>= \case + runExceptT (flip execStateT (PacketHandlerState peer [] [] [] Nothing False) $ unPacketHandler act) >>= \case Left err -> do logd $ "Error in handling packet from " ++ show peerAddress ++ ": " ++ err Right ph -> do when (not $ null $ phHead ph) $ do - let packet = TransportPacket (TransportHeader $ phHead ph) (phBody ph) + body <- case phBodyStream ph of + Nothing -> return $ phBody ph + Just stream -> do + writeTQueue (serverIOActions peerServer_) $ void $ liftIO $ forkIO $ do + writeByteStringToStream stream $ BL.concat $ map lazyLoadBytes $ phBody ph + return [] + let packet = TransportPacket (TransportHeader $ phHead ph) body secreq = case (secure, phPlaintextReply ph) of (True, _) -> EncryptedOnly (False, False) -> PlaintextAllowed @@ -450,6 +468,7 @@ data PacketHandlerState = PacketHandlerState , phHead :: [TransportHeaderItem] , phAckedBy :: [TransportHeaderItem] , phBody :: [Ref] + , phBodyStream :: Maybe RawStreamWriter , phPlaintextReply :: Bool } @@ -462,6 +481,14 @@ addAckedBy hs = modify $ \ph -> ph { phAckedBy = foldr appendDistinct (phAckedBy addBody :: Ref -> PacketHandler () addBody r = modify $ \ph -> ph { phBody = r `appendDistinct` phBody ph } +sendBodyAsStream :: PacketHandler () +sendBodyAsStream = do + gets phBodyStream >>= \case + Nothing -> do + stream <- openStream + modify $ \ph -> ph { phBodyStream = Just stream } + Just _ -> return () + keepPlaintextReply :: PacketHandler () keepPlaintextReply = modify $ \ph -> ph { phPlaintextReply = True } @@ -517,8 +544,12 @@ handlePacket identity secure peer chanSvc svcs (TransportHeader headers) prefs = liftSTM $ finalizedChannel peer ch identity _ -> return () - Rejected dgst -> do - logd $ "rejected by peer: " ++ show dgst + Rejected dgst + | peerRequest : _ <- mapMaybe (\case TrChannelRequest d -> Just d; _ -> Nothing) headers + , peerRequest < dgst + -> return () -- Our request was rejected due to lower priority + + | otherwise -> logd $ "rejected by peer: " ++ show dgst DataRequest dgst | secure || dgst `elem` plaintextRefs -> do @@ -532,15 +563,11 @@ handlePacket identity secure peer chanSvc svcs (TransportHeader headers) prefs = -- otherwise lost the channel, so keep the reply plaintext as well. when (not secure) keepPlaintextReply - let bytes = lazyLoadBytes mref -- TODO: MTU - if (secure && BL.length bytes > 500) - then do - stream <- openStream - liftSTM $ writeTQueue (serverIOActions server) $ void $ liftIO $ forkIO $ do - writeByteStringToStream stream bytes - else do - addBody $ mref + when (secure && BL.length (lazyLoadBytes mref) > 500) + sendBodyAsStream + + addBody $ mref | otherwise -> do logd $ "unauthorized data request for " ++ show dgst addHeader $ Rejected dgst @@ -593,9 +620,15 @@ handlePacket identity secure peer chanSvc svcs (TransportHeader headers) prefs = ChannelCookieWait {} -> return () ChannelCookieReceived {} -> process ChannelCookieConfirmed {} -> process - ChannelOurRequest our | dgst < refDigest (storedRef our) -> process - | otherwise -> reject - ChannelPeerRequest {} -> process + ChannelOurRequest our + | dgst < refDigest (storedRef our) -> process + | otherwise -> do + -- Reject peer channel request with lower priority + addHeader $ TrChannelRequest $ refDigest $ storedRef our + reject + ChannelPeerRequest prev + | dgst == wrDigest prev -> addHeader $ Acknowledged dgst + | otherwise -> process ChannelOurAccept {} -> reject ChannelEstablished {} -> process ChannelClosed {} -> return () @@ -647,12 +680,14 @@ setupChannel identity peer upid = do [ TrChannelRequest reqref , AnnounceSelf $ refDigest $ storedRef $ idData identity ] + let sendChannelRequest = do + sendToPeerPlain peer [ Acknowledged reqref, Rejected reqref ] $ + TransportPacket (TransportHeader hitems) [storedRef req] + setPeerChannel peer $ ChannelOurRequest req liftIO $ atomically $ do getPeerChannel peer >>= \case - ChannelCookieConfirmed -> do - sendToPeerPlain peer [ Acknowledged reqref, Rejected reqref ] $ - TransportPacket (TransportHeader hitems) [storedRef req] - setPeerChannel peer $ ChannelOurRequest req + ChannelCookieReceived -> sendChannelRequest + ChannelCookieConfirmed -> sendChannelRequest _ -> return () handleChannelRequest :: Peer -> UnifiedIdentity -> Ref -> WaitingRefCallback @@ -806,10 +841,16 @@ isPeerDropped peer = liftIO $ atomically $ readTVar (peerState peer) >>= \case _ -> return False sendToPeer :: (Service s, MonadIO m) => Peer -> s -> m () -sendToPeer peer packet = sendToPeerList peer [ServiceReply (Left packet) True] +sendToPeer peer = sendManyToPeer peer . (: []) + +sendManyToPeer :: (Service s, MonadIO m) => Peer -> [ s ] -> m () +sendManyToPeer peer = sendToPeerList peer . map (\part -> ServiceReply (Left part) True) sendToPeerStored :: (Service s, MonadIO m) => Peer -> Stored s -> m () -sendToPeerStored peer spacket = sendToPeerList peer [ServiceReply (Right spacket) True] +sendToPeerStored peer = sendManyToPeerStored peer . (: []) + +sendManyToPeerStored :: (Service s, MonadIO m) => Peer -> [ Stored s ] -> m () +sendManyToPeerStored peer = sendToPeerList peer . map (\part -> ServiceReply (Right part) True) sendToPeerList :: (Service s, MonadIO m) => Peer -> [ServiceReply s] -> m () sendToPeerList peer parts = do @@ -912,9 +953,19 @@ runPeerServiceOn mbservice peer handler = liftIO $ do logd $ "unhandled service '" ++ show (toUUID svc) ++ "'" +foreign import ccall unsafe "Network/ifaddrs.h join_multicast" cJoinMulticast :: CInt -> Ptr CSize -> IO (Ptr Word32) foreign import ccall unsafe "Network/ifaddrs.h broadcast_addresses" cBroadcastAddresses :: IO (Ptr Word32) foreign import ccall unsafe "stdlib.h free" cFree :: Ptr Word32 -> IO () +joinMulticast :: Socket -> IO [ Word32 ] +joinMulticast sock = + withFdSocket sock $ \fd -> + alloca $ \pcount -> do + ptr <- cJoinMulticast fd pcount + count <- fromIntegral <$> peek pcount + forM [ 0 .. count - 1 ] $ \i -> + peekElemOff ptr i + getBroadcastAddresses :: PortNumber -> IO [SockAddr] getBroadcastAddresses port = do ptr <- cBroadcastAddresses @@ -922,6 +973,9 @@ getBroadcastAddresses port = do w <- peekElemOff ptr i if w == 0 then return [] else (SockAddrInet port w:) <$> parse (i + 1) - addrs <- parse 0 - cFree ptr - return addrs + if ptr == nullPtr + then return [] + else do + addrs <- parse 0 + cFree ptr + return addrs diff --git a/src/Erebos/Network/Protocol.hs b/src/Erebos/Network/Protocol.hs index a009ad1..2955473 100644 --- a/src/Erebos/Network/Protocol.hs +++ b/src/Erebos/Network/Protocol.hs @@ -40,7 +40,17 @@ import Control.Monad import Control.Monad.Except import Control.Monad.Trans +import Crypto.Cipher.ChaChaPoly1305 qualified as C +import Crypto.MAC.Poly1305 qualified as C (Auth(..), authTag) +import Crypto.Error +import Crypto.Random + +import Data.Binary +import Data.Binary.Get +import Data.Binary.Put import Data.Bits +import Data.ByteArray (Bytes, ScrubbedBytes) +import Data.ByteArray qualified as BA import Data.ByteString (ByteString) import Data.ByteString qualified as B import Data.ByteString.Char8 qualified as BC @@ -51,7 +61,6 @@ import Data.Maybe import Data.Text (Text) import Data.Text qualified as T import Data.Void -import Data.Word import System.Clock @@ -68,6 +77,9 @@ protocolVersion = T.pack "0.1" protocolVersions :: [Text] protocolVersions = [protocolVersion] +keepAliveInternal :: TimeSpec +keepAliveInternal = fromNanoSecs $ 30 * 10^(9 :: Int) + data TransportPacket a = TransportPacket TransportHeader [a] @@ -93,14 +105,41 @@ data TransportHeaderItem | StreamOpen Word8 deriving (Eq, Show) -newtype Cookie = Cookie ByteString - deriving (Eq, Show) - data SecurityRequirement = PlaintextOnly | PlaintextAllowed | EncryptedOnly deriving (Eq, Ord) +data Cookie = Cookie + { cookieNonce :: C.Nonce + , cookieValidity :: Word32 + , cookieContent :: ByteString + , cookieMac :: C.Auth + } + +instance Eq Cookie where + (==) = (==) `on` (\c -> ( BA.convert (cookieNonce c) :: ByteString, cookieValidity c, cookieContent c, cookieMac c )) + + +instance Show Cookie where + show Cookie {..} = show (nonce, cookieValidity, cookieContent, mac) + where C.Auth mac = cookieMac + nonce = BA.convert cookieNonce :: ByteString + +instance Binary Cookie where + put Cookie {..} = do + putByteString $ BA.convert cookieNonce + putWord32be cookieValidity + putByteString $ BA.convert cookieMac + putByteString cookieContent + + get = do + Just cookieNonce <- maybeCryptoError . C.nonce12 <$> getByteString 12 + cookieValidity <- getWord32be + Just cookieMac <- maybeCryptoError . C.authTag <$> getByteString 16 + cookieContent <- BL.toStrict <$> getRemainingLazyByteString + return Cookie {..} + isHeaderItemAcknowledged :: TransportHeaderItem -> Bool isHeaderItemAcknowledged = \case Acknowledged {} -> False @@ -120,8 +159,8 @@ transportToObject st (TransportHeader items) = Rec $ map single items Rejected dgst -> (BC.pack "REJ", RecRef $ partialRefFromDigest st dgst) ProtocolVersion ver -> (BC.pack "VER", RecText ver) Initiation dgst -> (BC.pack "INI", RecRef $ partialRefFromDigest st dgst) - CookieSet (Cookie bytes) -> (BC.pack "CKS", RecBinary bytes) - CookieEcho (Cookie bytes) -> (BC.pack "CKE", RecBinary bytes) + CookieSet cookie -> (BC.pack "CKS", RecBinary $ BL.toStrict $ encode cookie) + CookieEcho cookie -> (BC.pack "CKE", RecBinary $ BL.toStrict $ encode cookie) DataRequest dgst -> (BC.pack "REQ", RecRef $ partialRefFromDigest st dgst) DataResponse dgst -> (BC.pack "RSP", RecRef $ partialRefFromDigest st dgst) AnnounceSelf dgst -> (BC.pack "ANN", RecRef $ partialRefFromDigest st dgst) @@ -142,8 +181,12 @@ transportFromObject (Rec items) = case catMaybes $ map single items of | name == BC.pack "REJ", RecRef ref <- content -> Just $ Rejected $ refDigest ref | name == BC.pack "VER", RecText ver <- content -> Just $ ProtocolVersion ver | name == BC.pack "INI", RecRef ref <- content -> Just $ Initiation $ refDigest ref - | name == BC.pack "CKS", RecBinary bytes <- content -> Just $ CookieSet (Cookie bytes) - | name == BC.pack "CKE", RecBinary bytes <- content -> Just $ CookieEcho (Cookie bytes) + | name == BC.pack "CKS", RecBinary bytes <- content + , Right (_, _, cookie) <- decodeOrFail (BL.fromStrict bytes) + -> Just $ CookieSet cookie + | name == BC.pack "CKE", RecBinary bytes <- content + , Right (_, _, cookie) <- decodeOrFail (BL.fromStrict bytes) + -> Just $ CookieEcho cookie | name == BC.pack "REQ", RecRef ref <- content -> Just $ DataRequest $ refDigest ref | name == BC.pack "RSP", RecRef ref <- content -> Just $ DataResponse $ refDigest ref | name == BC.pack "ANN", RecRef ref <- content -> Just $ AnnounceSelf $ refDigest ref @@ -165,9 +208,12 @@ data GlobalState addr = (Eq addr, Show addr) => GlobalState , gNextUp :: TMVar (Connection addr, (Bool, TransportPacket PartialObject)) , gLog :: String -> STM () , gStorage :: PartialStorage + , gStartTime :: TimeSpec , gNowVar :: TVar TimeSpec , gNextTimeout :: TVar TimeSpec , gInitConfig :: Ref + , gCookieKey :: ScrubbedBytes + , gCookieStartTime :: Word32 } data Connection addr = Connection @@ -186,6 +232,7 @@ data Connection addr = Connection , cReservedPackets :: TVar Int , cSentPackets :: TVar [SentPacket] , cToAcknowledge :: TVar [Integer] + , cNextKeepAlive :: TVar (Maybe TimeSpec) , cInStreams :: TVar [(Word8, Stream)] , cOutStreams :: TVar [(Word8, Stream)] } @@ -440,15 +487,18 @@ erebosNetworkProtocol initialIdentity gLog gDataFlow gControlFlow = do mStorage <- memoryStorage gStorage <- derivePartialStorage mStorage - startTime <- getTime MonotonicRaw - gNowVar <- newTVarIO startTime - gNextTimeout <- newTVarIO startTime + gStartTime <- getTime Monotonic + gNowVar <- newTVarIO gStartTime + gNextTimeout <- newTVarIO gStartTime gInitConfig <- store mStorage $ (Rec [] :: Object) + gCookieKey <- getRandomBytes 32 + gCookieStartTime <- runGet getWord32host . BL.pack . BA.unpack @ScrubbedBytes <$> getRandomBytes 4 + let gs = GlobalState {..} let signalTimeouts = forever $ do - now <- getTime MonotonicRaw + now <- getTime Monotonic next <- atomically $ do writeTVar gNowVar now readTVar gNextTimeout @@ -487,6 +537,7 @@ newConnection cGlobalState@GlobalState {..} addr = do cReservedPackets <- newTVar 0 cSentPackets <- newTVar [] cToAcknowledge <- newTVar [] + cNextKeepAlive <- newTVar Nothing cInStreams <- newTVar [] cOutStreams <- newTVar [] let conn = Connection {..} @@ -548,6 +599,7 @@ processIncoming gs@GlobalState {..} = do Nothing -> throwError "empty packet" + now <- getTime Monotonic runExceptT parse >>= \case Right (Left (secure, objs, mbcounter)) | hobj:content <- objs @@ -562,6 +614,7 @@ processIncoming gs@GlobalState {..} = do case mbup of Just up -> putTMVar gNextUp (conn, (secure, up)) Nothing -> return () + updateKeepAlive conn now processAcknowledgements gs conn items ioAfter Nothing -> return () @@ -571,8 +624,9 @@ processIncoming gs@GlobalState {..} = do gLog $ show objs Right (Right (snum, seq8, content, counter)) - | Just Connection {..} <- mbconn + | Just conn@Connection {..} <- mbconn -> atomically $ do + updateKeepAlive conn now (lookup snum <$> readTVar cInStreams) >>= \case Nothing -> gLog $ "unexpected stream number " ++ show snum @@ -694,11 +748,36 @@ generateCookieHeaders Connection {..} ch = catMaybes <$> sequence [ echoHeader, _ -> return Nothing createCookie :: GlobalState addr -> addr -> IO Cookie -createCookie GlobalState {} addr = return (Cookie $ BC.pack $ show addr) +createCookie GlobalState {..} addr = do + (nonceBytes :: Bytes) <- getRandomBytes 12 + validUntil <- (fromNanoSecs (60 * 10^(9 :: Int)) +) <$> getTime Monotonic + let validSecondsFromStart = fromIntegral $ toNanoSecs (validUntil - gStartTime) `div` (10^(9 :: Int)) + cookieValidity = validSecondsFromStart - gCookieStartTime + plainContent = BC.pack (show addr) + throwCryptoErrorIO $ do + cookieNonce <- C.nonce12 nonceBytes + st1 <- C.initialize gCookieKey cookieNonce + let st2 = C.finalizeAAD $ C.appendAAD (BL.toStrict $ runPut $ putWord32be cookieValidity) st1 + (cookieContent, st3) = C.encrypt plainContent st2 + cookieMac = C.finalize st3 + return $ Cookie {..} verifyCookie :: GlobalState addr -> addr -> Cookie -> IO Bool -verifyCookie GlobalState {} addr (Cookie cookie) = return $ show addr == BC.unpack cookie - +verifyCookie GlobalState {..} addr Cookie {..} = do + ctime <- getTime Monotonic + return $ fromMaybe False $ maybeCryptoError $ do + st1 <- C.initialize gCookieKey cookieNonce + let st2 = C.finalizeAAD $ C.appendAAD (BL.toStrict $ runPut $ putWord32be cookieValidity) st1 + (plainContent, st3) = C.decrypt cookieContent st2 + mac = C.finalize st3 + + validSecondsFromStart = fromIntegral $ cookieValidity + gCookieStartTime + validUntil = gStartTime + fromNanoSecs (validSecondsFromStart * (10^(9 :: Int))) + return $ and + [ mac == cookieMac + , ctime <= validUntil + , show addr == BC.unpack plainContent + ] reservePacket :: Connection addr -> STM ReservedToSend reservePacket conn@Connection {..} = do @@ -713,9 +792,9 @@ reservePacket conn@Connection {..} = do return $ ReservedToSend Nothing (return ()) (atomically $ connClose conn) resendBytes :: Connection addr -> Maybe ReservedToSend -> SentPacket -> IO () -resendBytes Connection {..} reserved sp = do +resendBytes conn@Connection {..} reserved sp = do let GlobalState {..} = cGlobalState - now <- getTime MonotonicRaw + now <- getTime Monotonic atomically $ do when (isJust reserved) $ do modifyTVar' cReservedPackets (subtract 1) @@ -726,6 +805,7 @@ resendBytes Connection {..} reserved sp = do , spRetryCount = spRetryCount sp + 1 } writeFlow gDataFlow (cAddress, spData sp) + updateKeepAlive conn now sendBytes :: Connection addr -> Maybe ReservedToSend -> ByteString -> IO () sendBytes conn reserved bs = resendBytes conn reserved @@ -738,6 +818,12 @@ sendBytes conn reserved bs = resendBytes conn reserved , spData = bs } +updateKeepAlive :: Connection addr -> TimeSpec -> STM () +updateKeepAlive Connection {..} now = do + let next = now + keepAliveInternal + writeTVar cNextKeepAlive $ Just next + + processOutgoing :: forall addr. GlobalState addr -> STM (IO ()) processOutgoing gs@GlobalState {..} = do @@ -777,11 +863,12 @@ processOutgoing gs@GlobalState {..} = do let onAck = sequence_ $ map (streamAccepted conn) $ catMaybes (map (\case StreamOpen n -> Just n; _ -> Nothing) hitems) - let mkPlain extraHeaders = - let header = TransportHeader $ map AcknowledgedSingle acknowledge ++ extraHeaders ++ hitems - in BL.concat $ - (serializeObject $ transportToObject gStorage header) - : map lazyLoadBytes content + let mkPlain extraHeaders + | combinedHeaderItems@(_:_) <- map AcknowledgedSingle acknowledge ++ extraHeaders ++ hitems = + BL.concat $ + (serializeObject $ transportToObject gStorage $ TransportHeader combinedHeaderItems) + : map lazyLoadBytes content + | otherwise = BL.empty let usePlaintext = do plain <- mkPlain <$> generateCookieHeaders conn channel @@ -811,6 +898,13 @@ processOutgoing gs@GlobalState {..} = do sendBytes conn mbReserved' bs Nothing -> return () + let waitUntil :: TimeSpec -> TimeSpec -> STM () + waitUntil now till = do + nextTimeout <- readTVar gNextTimeout + if nextTimeout <= now || till < nextTimeout + then writeTVar gNextTimeout till + else retry + let retransmitPacket :: Connection addr -> STM (IO ()) retransmitPacket conn@Connection {..} = do now <- readTVar gNowVar @@ -819,11 +913,8 @@ processOutgoing gs@GlobalState {..} = do _ -> retry let nextTry = spTime sp + fromNanoSecs 1000000000 if | now < nextTry -> do - nextTimeout <- readTVar gNextTimeout - if nextTimeout <= now || nextTry < nextTimeout - then do writeTVar gNextTimeout nextTry - return $ return () - else retry + waitUntil now nextTry + return $ return () | spRetryCount sp < 2 -> do reserved <- reservePacket conn writeTVar cSentPackets rest @@ -863,11 +954,28 @@ processOutgoing gs@GlobalState {..} = do writeTVar gIdentity (nid, cur : past) return $ return () + let sendKeepAlive :: Connection addr -> STM (IO ()) + sendKeepAlive Connection {..} = do + readTVar cNextKeepAlive >>= \case + Nothing -> retry + Just next -> do + now <- readTVar gNowVar + if next <= now + then do + writeTVar cNextKeepAlive Nothing + identity <- fst <$> readTVar gIdentity + let header = TransportHeader [ AnnounceSelf $ refDigest $ storedRef $ idData identity ] + writeTQueue cSecureOutQueue (EncryptedOnly, TransportPacket header [], []) + else do + waitUntil now next + return $ return () + conns <- readTVar gConnections msum $ concat $ [ map retransmitPacket conns , map sendNextPacket conns , [ handleControlRequests ] + , map sendKeepAlive conns ] processAcknowledgements :: GlobalState addr -> Connection addr -> [TransportHeaderItem] -> STM (IO ()) diff --git a/src/Erebos/Network/ifaddrs.c b/src/Erebos/Network/ifaddrs.c index 37c3e00..70685bc 100644 --- a/src/Erebos/Network/ifaddrs.c +++ b/src/Erebos/Network/ifaddrs.c @@ -1,11 +1,89 @@ #include "ifaddrs.h" +#include <errno.h> +#include <stdbool.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> + +#ifndef _WIN32 #include <arpa/inet.h> -#include <ifaddrs.h> #include <net/if.h> -#include <stdlib.h> -#include <sys/types.h> +#include <ifaddrs.h> #include <endian.h> +#include <sys/types.h> +#include <sys/socket.h> +#else +#include <winsock2.h> +#include <ws2ipdef.h> +#include <ws2tcpip.h> +#endif + +#define DISCOVERY_MULTICAST_GROUP "ff12:b6a4:6b1f:969:caee:acc2:5c93:73e1" + +uint32_t * join_multicast(int fd, size_t * count) +{ + size_t capacity = 16; + *count = 0; + uint32_t * interfaces = malloc(sizeof(uint32_t) * capacity); + +#ifdef _WIN32 + interfaces[0] = 0; + *count = 1; +#else + struct ifaddrs * addrs; + if (getifaddrs(&addrs) < 0) + return 0; + + for (struct ifaddrs * ifa = addrs; ifa; ifa = ifa->ifa_next) { + if (ifa->ifa_addr && ifa->ifa_addr->sa_family == AF_INET6 && + !(ifa->ifa_flags & IFF_LOOPBACK)) { + int idx = if_nametoindex(ifa->ifa_name); + + bool seen = false; + for (size_t i = 0; i < *count; i++) { + if (interfaces[i] == idx) { + seen = true; + break; + } + } + if (seen) + continue; + + if (*count + 1 >= capacity) { + capacity *= 2; + uint32_t * nret = realloc(interfaces, sizeof(uint32_t) * capacity); + if (nret) { + interfaces = nret; + } else { + free(interfaces); + *count = 0; + return NULL; + } + } + + interfaces[*count] = idx; + (*count)++; + } + } + + freeifaddrs(addrs); +#endif + + for (size_t i = 0; i < *count; i++) { + struct ipv6_mreq group; + group.ipv6mr_interface = interfaces[i]; + inet_pton(AF_INET6, DISCOVERY_MULTICAST_GROUP, &group.ipv6mr_multiaddr); + int ret = setsockopt(fd, IPPROTO_IPV6, IPV6_ADD_MEMBERSHIP, + (const void *) &group, sizeof(group)); + if (ret < 0) + fprintf(stderr, "IPV6_ADD_MEMBERSHIP failed: %s\n", strerror(errno)); + } + + return interfaces; +} + +#ifndef _WIN32 uint32_t * broadcast_addresses(void) { @@ -39,3 +117,51 @@ uint32_t * broadcast_addresses(void) ret[count] = 0; return ret; } + +#else // _WIN32 + +#include <winsock2.h> +#include <ws2tcpip.h> + +#pragma comment(lib, "ws2_32.lib") + +uint32_t * broadcast_addresses(void) +{ + uint32_t * ret = NULL; + SOCKET wsock = INVALID_SOCKET; + + struct WSAData wsaData; + if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) + return NULL; + + wsock = WSASocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP, NULL, 0, 0); + if (wsock == INVALID_SOCKET) + goto cleanup; + + INTERFACE_INFO InterfaceList[32]; + unsigned long nBytesReturned; + + if (WSAIoctl(wsock, SIO_GET_INTERFACE_LIST, 0, 0, + InterfaceList, sizeof(InterfaceList), + &nBytesReturned, 0, 0) == SOCKET_ERROR) + goto cleanup; + + int numInterfaces = nBytesReturned / sizeof(INTERFACE_INFO); + + size_t capacity = 16, count = 0; + ret = malloc(sizeof(uint32_t) * capacity); + + for (int i = 0; i < numInterfaces && count < capacity - 1; i++) + if (InterfaceList[i].iiFlags & IFF_BROADCAST) + ret[count++] = InterfaceList[i].iiBroadcastAddress.AddressIn.sin_addr.s_addr; + + ret[count] = 0; +cleanup: + if (wsock != INVALID_SOCKET) + closesocket(wsock); + WSACleanup(); + + return ret; +} + +#endif diff --git a/src/Erebos/Network/ifaddrs.h b/src/Erebos/Network/ifaddrs.h index 06d26ec..8852ec6 100644 --- a/src/Erebos/Network/ifaddrs.h +++ b/src/Erebos/Network/ifaddrs.h @@ -1,3 +1,5 @@ +#include <stddef.h> #include <stdint.h> +uint32_t * join_multicast(int fd, size_t * count); uint32_t * broadcast_addresses(void); diff --git a/src/Erebos/Storage/Internal.hs b/src/Erebos/Storage/Internal.hs index d419a5e..8b794d8 100644 --- a/src/Erebos/Storage/Internal.hs +++ b/src/Erebos/Storage/Internal.hs @@ -241,7 +241,7 @@ writeFileOnce file content = bracket (openLockFile locked) doesFileExist file >>= \case True -> removeFile locked False -> do BL.hPut h content - hFlush h + hClose h renameFile locked file where locked = file ++ ".lock" @@ -254,13 +254,13 @@ writeFileChecked file prev content = bracket (openLockFile locked) removeFile locked return $ Left $ Just current (Nothing, False) -> do B.hPut h content - hFlush h + hClose h renameFile locked file return $ Right () (Just expected, True) -> do current <- B.readFile file if current == expected then do B.hPut h content - hFlush h + hClose h renameFile locked file return $ return () else do removeFile locked diff --git a/src/Erebos/Storage/Key.hs b/src/Erebos/Storage/Key.hs index b6afc20..5da79e3 100644 --- a/src/Erebos/Storage/Key.hs +++ b/src/Erebos/Storage/Key.hs @@ -80,6 +80,7 @@ moveKeys from to = liftIO $ do return M.empty (StorageMemory { memKeys = fromKeys }, StorageMemory { memKeys = toKeys }) -> do - modifyMVar_ fromKeys $ \fkeys -> do - modifyMVar_ toKeys $ return . M.union fkeys - return M.empty + when (fromKeys /= toKeys) $ do + modifyMVar_ fromKeys $ \fkeys -> do + modifyMVar_ toKeys $ return . M.union fkeys + return M.empty diff --git a/src/Erebos/Storage/Merge.hs b/src/Erebos/Storage/Merge.hs index 9d9db13..a3b0fd7 100644 --- a/src/Erebos/Storage/Merge.hs +++ b/src/Erebos/Storage/Merge.hs @@ -97,13 +97,16 @@ storedGeneration x = doLookup x +-- |Returns list of sets starting with the set of given objects and +-- intcrementally adding parents. generations :: Storable a => [Stored a] -> [Set (Stored a)] generations = unfoldr gen . (,S.empty) - where gen (hs, cur) = case filter (`S.notMember` cur) $ previous =<< hs of + where gen (hs, cur) = case filter (`S.notMember` cur) hs of [] -> Nothing added -> let next = foldr S.insert cur added - in Just (next, (added, next)) + in Just (next, (previous =<< added, next)) +-- |Returns set containing all given objects and their ancestors ancestors :: Storable a => [Stored a] -> Set (Stored a) ancestors = last . (S.empty:) . generations |