summaryrefslogtreecommitdiff
path: root/src/Erebos
diff options
context:
space:
mode:
Diffstat (limited to 'src/Erebos')
-rw-r--r--src/Erebos/Chatroom.hs216
-rw-r--r--src/Erebos/Conversation.hs36
-rw-r--r--src/Erebos/Identity.hs33
-rw-r--r--src/Erebos/Network.hs120
-rw-r--r--src/Erebos/Network/Protocol.hs164
-rw-r--r--src/Erebos/Network/ifaddrs.c132
-rw-r--r--src/Erebos/Network/ifaddrs.h2
-rw-r--r--src/Erebos/Storage/Internal.hs6
-rw-r--r--src/Erebos/Storage/Key.hs7
-rw-r--r--src/Erebos/Storage/Merge.hs7
10 files changed, 596 insertions, 127 deletions
diff --git a/src/Erebos/Chatroom.hs b/src/Erebos/Chatroom.hs
index 673c59f..c8b5805 100644
--- a/src/Erebos/Chatroom.hs
+++ b/src/Erebos/Chatroom.hs
@@ -11,14 +11,20 @@ module Erebos.Chatroom (
findChatroomByRoomData,
findChatroomByStateData,
chatroomSetSubscribe,
+ chatroomMembers,
+ joinChatroom, joinChatroomByStateData,
+ leaveChatroom, leaveChatroomByStateData,
getMessagesSinceState,
ChatroomSetChange(..),
watchChatrooms,
- ChatMessage, cmsgFrom, cmsgReplyTo, cmsgTime, cmsgText, cmsgLeave,
+ ChatMessage,
+ cmsgFrom, cmsgReplyTo, cmsgTime, cmsgText, cmsgLeave,
+ cmsgRoom, cmsgRoomData,
ChatMessageData(..),
- chatroomMessageByStateData,
+ sendChatroomMessage,
+ sendChatroomMessageByStateData,
ChatroomService(..),
) where
@@ -29,6 +35,9 @@ import Control.Monad.Except
import Control.Monad.IO.Class
import Data.Bool
+import Data.Either
+import Data.Foldable
+import Data.Function
import Data.IORef
import Data.List
import Data.Maybe
@@ -111,6 +120,11 @@ data ChatMessage = ChatMessage
{ cmsgData :: Stored (Signed ChatMessageData)
}
+validateSingleMessage :: Stored (Signed ChatMessageData) -> Maybe ChatMessage
+validateSingleMessage sdata = do
+ guard $ fromStored sdata `isSignedBy` idKeyMessage (mdFrom (fromSigned sdata))
+ return $ ChatMessage sdata
+
cmsgFrom :: ChatMessage -> ComposedIdentity
cmsgFrom = mdFrom . fromSigned . cmsgData
@@ -126,6 +140,12 @@ cmsgText = mdText . fromSigned . cmsgData
cmsgLeave :: ChatMessage -> Bool
cmsgLeave = mdLeave . fromSigned . cmsgData
+cmsgRoom :: ChatMessage -> Maybe Chatroom
+cmsgRoom = either (const Nothing) Just . runExcept . validateChatroom . cmsgRoomData
+
+cmsgRoomData :: ChatMessage -> [ Stored (Signed ChatroomData) ]
+cmsgRoomData = concat . findProperty ((\case [] -> Nothing; xs -> Just xs) . mdRoom . fromStored . signedData) . (: []) . cmsgData
+
instance Storable ChatMessageData where
store' ChatMessageData {..} = storeRec $ do
mapM_ (storeRef "SPREV") mdPrev
@@ -146,37 +166,42 @@ instance Storable ChatMessageData where
mdLeave <- isJust <$> loadMbEmpty "leave"
return ChatMessageData {..}
-threadToList :: [Stored (Signed ChatMessageData)] -> [ChatMessage]
-threadToList thread = helper S.empty $ thread
+threadToListSince :: [ Stored (Signed ChatMessageData) ] -> [ Stored (Signed ChatMessageData) ] -> [ ChatMessage ]
+threadToListSince since thread = helper (S.fromList since) thread
where
helper :: S.Set (Stored (Signed ChatMessageData)) -> [Stored (Signed ChatMessageData)] -> [ChatMessage]
helper seen msgs
| msg : msgs' <- filter (`S.notMember` seen) $ reverse $ sortBy (comparing cmpView) msgs =
- messageFromData msg : helper (S.insert msg seen) (msgs' ++ mdPrev (fromSigned msg))
+ maybe id (:) (validateSingleMessage msg) $
+ helper (S.insert msg seen) (msgs' ++ mdPrev (fromSigned msg))
| otherwise = []
cmpView msg = (zonedTimeToUTC $ mdTime $ fromSigned msg, msg)
- messageFromData :: Stored (Signed ChatMessageData) -> ChatMessage
- messageFromData sdata = ChatMessage { cmsgData = sdata }
+sendChatroomMessage
+ :: (MonadStorage m, MonadHead LocalState m, MonadError String m)
+ => ChatroomState -> Text -> m ()
+sendChatroomMessage rstate msg = sendChatroomMessageByStateData (head $ roomStateData rstate) msg
-chatroomMessageByStateData
+sendChatroomMessageByStateData
:: (MonadStorage m, MonadHead LocalState m, MonadError String m)
=> Stored ChatroomStateData -> Text -> m ()
-chatroomMessageByStateData lookupData msg = void $ findAndUpdateChatroomState $ \cstate -> do
+sendChatroomMessageByStateData lookupData msg = sendRawChatroomMessageByStateData lookupData Nothing (Just msg) False
+
+sendRawChatroomMessageByStateData
+ :: (MonadStorage m, MonadHead LocalState m, MonadError String m)
+ => Stored ChatroomStateData -> Maybe (Stored (Signed ChatMessageData)) -> Maybe Text -> Bool -> m ()
+sendRawChatroomMessageByStateData lookupData mdReplyTo mdText mdLeave = void $ findAndUpdateChatroomState $ \cstate -> do
guard $ any (lookupData `precedesOrEquals`) $ roomStateData cstate
Just $ do
- self <- finalOwner . localIdentity . fromStored <$> getLocalHead
- secret <- loadKey $ idKeyMessage self
- time <- liftIO getZonedTime
- mdata <- mstore =<< sign secret =<< mstore ChatMessageData
- { mdPrev = roomStateMessageData cstate
- , mdRoom = []
- , mdFrom = self
- , mdReplyTo = Nothing
- , mdTime = time
- , mdText = Just msg
- , mdLeave = False
- }
+ mdFrom <- finalOwner . localIdentity . fromStored <$> getLocalHead
+ secret <- loadKey $ idKeyMessage mdFrom
+ mdTime <- liftIO getZonedTime
+ let mdPrev = roomStateMessageData cstate
+ mdRoom = if null (roomStateMessageData cstate)
+ then maybe [] roomData (roomStateRoom cstate)
+ else []
+
+ mdata <- mstore =<< sign secret =<< mstore ChatMessageData {..}
mergeSorted . (:[]) <$> mstore ChatroomStateData
{ rsdPrev = roomStateData cstate
, rsdRoom = []
@@ -224,7 +249,7 @@ instance Mergeable ChatroomState where
ChatroomStateData {..} | null rsdMessages -> Nothing
| otherwise -> Just rsdMessages
roomStateSubscribe = fromMaybe False $ findPropertyFirst rsdSubscribe roomStateData
- roomStateMessages = threadToList $ concatMap (rsdMessages . fromStored) roomStateData
+ roomStateMessages = threadToListSince [] $ concatMap (rsdMessages . fromStored) roomStateData
in ChatroomState {..}
toComponents = roomStateData
@@ -321,11 +346,38 @@ chatroomSetSubscribe lookupData subscribe = void $ findAndUpdateChatroomState $
, rsdMessages = []
}
+chatroomMembers :: ChatroomState -> [ ComposedIdentity ]
+chatroomMembers ChatroomState {..} =
+ map (mdFrom . fromSigned . head) $
+ filter (any $ not . mdLeave . fromSigned) $ -- keep only users that hasn't left
+ map (filterAncestors . map snd) $ -- gather message data per each identity and filter ancestors
+ groupBy ((==) `on` fst) $ -- group on identity root
+ sortBy (comparing fst) $ -- sort by first root of identity data
+ map (\x -> ( head . filterAncestors . concatMap storedRoots . idDataF . mdFrom . fromSigned $ x, x )) $
+ toList $ ancestors $ roomStateMessageData
+
+joinChatroom
+ :: (MonadStorage m, MonadHead LocalState m, MonadError String m)
+ => ChatroomState -> m ()
+joinChatroom rstate = joinChatroomByStateData (head $ roomStateData rstate)
+
+joinChatroomByStateData
+ :: (MonadStorage m, MonadHead LocalState m, MonadError String m)
+ => Stored ChatroomStateData -> m ()
+joinChatroomByStateData lookupData = sendRawChatroomMessageByStateData lookupData Nothing Nothing False
+
+leaveChatroom
+ :: (MonadStorage m, MonadHead LocalState m, MonadError String m)
+ => ChatroomState -> m ()
+leaveChatroom rstate = leaveChatroomByStateData (head $ roomStateData rstate)
+
+leaveChatroomByStateData
+ :: (MonadStorage m, MonadHead LocalState m, MonadError String m)
+ => Stored ChatroomStateData -> m ()
+leaveChatroomByStateData lookupData = sendRawChatroomMessageByStateData lookupData Nothing Nothing True
+
getMessagesSinceState :: ChatroomState -> ChatroomState -> [ChatMessage]
-getMessagesSinceState cur old = takeWhile notOld (roomStateMessages cur)
- where
- notOld msg = cmsgData msg `notElem` roomStateMessageData old
- -- TODO: parallel message threads
+getMessagesSinceState cur old = threadToListSince (roomStateMessageData old) (roomStateMessageData cur)
data ChatroomSetChange = AddedChatroom ChatroomState
@@ -365,13 +417,18 @@ makeChatroomDiff [] ys = map (AddedChatroom . snd) ys
data ChatroomService = ChatroomService
{ chatRoomQuery :: Bool
, chatRoomInfo :: [Stored (Signed ChatroomData)]
+ , chatRoomSubscribe :: [Stored (Signed ChatroomData)]
+ , chatRoomUnsubscribe :: [Stored (Signed ChatroomData)]
, chatRoomMessage :: [Stored (Signed ChatMessageData)]
}
+ deriving (Eq)
emptyPacket :: ChatroomService
emptyPacket = ChatroomService
{ chatRoomQuery = False
, chatRoomInfo = []
+ , chatRoomSubscribe = []
+ , chatRoomUnsubscribe = []
, chatRoomMessage = []
}
@@ -379,17 +436,22 @@ instance Storable ChatroomService where
store' ChatroomService {..} = storeRec $ do
when chatRoomQuery $ storeEmpty "room-query"
forM_ chatRoomInfo $ storeRef "room-info"
+ forM_ chatRoomSubscribe $ storeRef "room-subscribe"
+ forM_ chatRoomUnsubscribe $ storeRef "room-unsubscribe"
forM_ chatRoomMessage $ storeRef "room-message"
load' = loadRec $ do
chatRoomQuery <- isJust <$> loadMbEmpty "room-query"
chatRoomInfo <- loadRefs "room-info"
+ chatRoomSubscribe <- loadRefs "room-subscribe"
+ chatRoomUnsubscribe <- loadRefs "room-unsubscribe"
chatRoomMessage <- loadRefs "room-message"
return ChatroomService {..}
data PeerState = PeerState
{ psSendRoomUpdates :: Bool
, psLastList :: [(Stored ChatroomStateData, ChatroomState)]
+ , psSubscribedTo :: [ Stored (Signed ChatroomData) ] -- least root for each room
}
instance Service ChatroomService where
@@ -399,12 +461,18 @@ instance Service ChatroomService where
emptyServiceState _ = PeerState
{ psSendRoomUpdates = False
, psLastList = []
+ , psSubscribedTo = []
}
serviceHandler spacket = do
let ChatroomService {..} = fromStored spacket
+
+ previouslyUpdated <- psSendRoomUpdates <$> svcGet
svcModify $ \s -> s { psSendRoomUpdates = True }
+ when (not previouslyUpdated) $ do
+ syncChatroomsToPeer . lookupSharedValue . lsShared . fromStored =<< getLocalHead
+
when chatRoomQuery $ do
rooms <- listChatrooms
replyPacket emptyPacket
@@ -420,7 +488,7 @@ instance Service ChatroomService where
maybe [] roomData . roomStateRoom
let prev = concatMap roomStateData $ filter isCurrentRoom rooms
- prevRoom = concatMap (rsdRoom . fromStored) prev
+ prevRoom = filterAncestors $ concat $ findProperty ((\case [] -> Nothing; xs -> Just xs) . rsdRoom) prev
room = filterAncestors $ (roomInfo : ) prevRoom
-- update local state only if we got roomInfo not present there
@@ -436,6 +504,51 @@ instance Service ChatroomService where
else return set
foldM upd roomSet chatRoomInfo
+ forM_ chatRoomSubscribe $ \subscribeData -> do
+ mbRoomState <- findChatroomByRoomData subscribeData
+ forM_ mbRoomState $ \roomState ->
+ forM (roomStateRoom roomState) $ \room -> do
+ let leastRoot = head . filterAncestors . concatMap storedRoots . roomData $ room
+ svcModify $ \ps -> ps { psSubscribedTo = leastRoot : psSubscribedTo ps }
+ replyPacket emptyPacket
+ { chatRoomMessage = roomStateMessageData roomState
+ }
+
+ forM_ chatRoomUnsubscribe $ \unsubscribeData -> do
+ mbRoomState <- findChatroomByRoomData unsubscribeData
+ forM_ (mbRoomState >>= roomStateRoom) $ \room -> do
+ let leastRoot = head . filterAncestors . concatMap storedRoots . roomData $ room
+ svcModify $ \ps -> ps { psSubscribedTo = filter (/= leastRoot) (psSubscribedTo ps) }
+
+ when (not (null chatRoomMessage)) $ do
+ updateLocalHead_ $ updateSharedState_ $ \roomSet -> do
+ let rooms = fromSetBy (comparing $ roomName <=< roomStateRoom) roomSet
+ upd set (msgData :: Stored (Signed ChatMessageData))
+ | Just msg <- validateSingleMessage msgData = do
+ let roomInfo = cmsgRoomData msg
+ currentRoots = filterAncestors $ concatMap storedRoots roomInfo
+ isCurrentRoom = any ((`intersectsSorted` currentRoots) . storedRoots) .
+ maybe [] roomData . roomStateRoom
+
+ let prevData = concatMap roomStateData $ filter isCurrentRoom rooms
+ prev = mergeSorted prevData
+ prevMessages = roomStateMessageData prev
+ messages = filterAncestors $ msgData : prevMessages
+
+ -- update local state only if subscribed and we got some new messages
+ if roomStateSubscribe prev && messages /= prevMessages
+ then do
+ sdata <- mstore ChatroomStateData
+ { rsdPrev = prevData
+ , rsdRoom = []
+ , rsdSubscribe = Nothing
+ , rsdMessages = messages
+ }
+ storeSetAddComponent sdata set
+ else return set
+ | otherwise = return set
+ foldM upd roomSet chatRoomMessage
+
serviceNewPeer = do
replyPacket emptyPacket { chatRoomQuery = True }
@@ -447,11 +560,50 @@ syncChatroomsToPeer set = do
ps@PeerState {..} <- svcGet
when psSendRoomUpdates $ do
let curList = chatroomSetToList set
- updates <- fmap (concat . catMaybes) $
- forM (makeChatroomDiff psLastList curList) $ return . \case
+ diff = makeChatroomDiff psLastList curList
+
+ roomUpdates <- fmap (concat . catMaybes) $
+ forM diff $ return . \case
AddedChatroom room -> roomData <$> roomStateRoom room
RemovedChatroom {} -> Nothing
- UpdatedChatroom _ room -> roomData <$> roomStateRoom room
- when (not $ null updates) $ do
- replyPacket $ emptyPacket { chatRoomInfo = updates }
+ UpdatedChatroom oldroom room
+ | roomStateData oldroom /= roomStateData room -> roomData <$> roomStateRoom room
+ | otherwise -> Nothing
+
+ (subscribe, unsubscribe) <- fmap (partitionEithers . concat . catMaybes) $
+ forM diff $ return . \case
+ AddedChatroom room
+ | roomStateSubscribe room
+ -> map Left . roomData <$> roomStateRoom room
+ RemovedChatroom oldroom
+ | roomStateSubscribe oldroom
+ -> map Right . roomData <$> roomStateRoom oldroom
+ UpdatedChatroom oldroom room
+ | roomStateSubscribe oldroom /= roomStateSubscribe room
+ -> map (if roomStateSubscribe room then Left else Right) . roomData <$> roomStateRoom room
+ _ -> Nothing
+
+ messages <- fmap concat $ do
+ let leastRootFor = head . filterAncestors . concatMap storedRoots . roomData
+ forM diff $ return . \case
+ AddedChatroom rstate
+ | Just room <- roomStateRoom rstate
+ , leastRootFor room `elem` psSubscribedTo
+ -> roomStateMessageData rstate
+ UpdatedChatroom oldstate rstate
+ | Just room <- roomStateRoom rstate
+ , leastRootFor room `elem` psSubscribedTo
+ , roomStateMessageData oldstate /= roomStateMessageData rstate
+ -> roomStateMessageData rstate
+ _ -> []
+
+ let packet = emptyPacket
+ { chatRoomInfo = roomUpdates
+ , chatRoomSubscribe = subscribe
+ , chatRoomUnsubscribe = unsubscribe
+ , chatRoomMessage = messages
+ }
+
+ when (packet /= emptyPacket) $ do
+ replyPacket packet
svcSet $ ps { psLastList = curList }
diff --git a/src/Erebos/Conversation.hs b/src/Erebos/Conversation.hs
index 94d2399..63475bd 100644
--- a/src/Erebos/Conversation.hs
+++ b/src/Erebos/Conversation.hs
@@ -1,12 +1,15 @@
module Erebos.Conversation (
Message,
messageFrom,
+ messageTime,
messageText,
messageUnread,
formatMessage,
Conversation,
directMessageConversation,
+ chatroomConversation,
+ chatroomConversationByStateData,
reloadConversation,
lookupConversations,
@@ -23,30 +26,45 @@ import Data.List
import Data.Maybe
import Data.Text (Text)
import Data.Text qualified as T
+import Data.Time.Format
import Data.Time.LocalTime
import Erebos.Identity
+import Erebos.Chatroom
import Erebos.Message hiding (formatMessage)
import Erebos.State
import Erebos.Storage
data Message = DirectMessageMessage DirectMessage Bool
+ | ChatroomMessage ChatMessage Bool
messageFrom :: Message -> ComposedIdentity
messageFrom (DirectMessageMessage msg _) = msgFrom msg
+messageFrom (ChatroomMessage msg _) = cmsgFrom msg
+
+messageTime :: Message -> ZonedTime
+messageTime (DirectMessageMessage msg _) = msgTime msg
+messageTime (ChatroomMessage msg _) = cmsgTime msg
messageText :: Message -> Maybe Text
messageText (DirectMessageMessage msg _) = Just $ msgText msg
+messageText (ChatroomMessage msg _) = cmsgText msg
messageUnread :: Message -> Bool
messageUnread (DirectMessageMessage _ unread) = unread
+messageUnread (ChatroomMessage _ unread) = unread
formatMessage :: TimeZone -> Message -> String
-formatMessage tzone (DirectMessageMessage msg _) = formatDirectMessage tzone msg
+formatMessage tzone msg = concat
+ [ formatTime defaultTimeLocale "[%H:%M] " $ utcToLocalTime tzone $ zonedTimeToUTC $ messageTime msg
+ , maybe "<unnamed>" T.unpack $ idName $ messageFrom msg
+ , maybe "" ((": "<>) . T.unpack) $ messageText msg
+ ]
data Conversation = DirectMessageConversation DirectMessageThread
+ | ChatroomConversation ChatroomState
directMessageConversation :: MonadHead LocalState m => ComposedIdentity -> m Conversation
directMessageConversation peer = do
@@ -54,8 +72,16 @@ directMessageConversation peer = do
Just thread -> return $ DirectMessageConversation thread
Nothing -> return $ DirectMessageConversation $ DirectMessageThread peer [] [] []
+chatroomConversation :: MonadHead LocalState m => ChatroomState -> m (Maybe Conversation)
+chatroomConversation rstate = chatroomConversationByStateData (head $ roomStateData rstate)
+
+chatroomConversationByStateData :: MonadHead LocalState m => Stored ChatroomStateData -> m (Maybe Conversation)
+chatroomConversationByStateData sdata = fmap ChatroomConversation <$> findChatroomByStateData sdata
+
reloadConversation :: MonadHead LocalState m => Conversation -> m Conversation
reloadConversation (DirectMessageConversation thread) = directMessageConversation (msgPeer thread)
+reloadConversation cur@(ChatroomConversation rstate) =
+ fromMaybe cur <$> chatroomConversation rstate
lookupConversations :: MonadHead LocalState m => m [Conversation]
lookupConversations = map DirectMessageConversation . toThreadList . lookupSharedValue . lsShared . fromStored <$> getLocalHead
@@ -63,13 +89,17 @@ lookupConversations = map DirectMessageConversation . toThreadList . lookupShare
conversationName :: Conversation -> Text
conversationName (DirectMessageConversation thread) = fromMaybe (T.pack "<unnamed>") $ idName $ msgPeer thread
+conversationName (ChatroomConversation rstate) = fromMaybe (T.pack "<unnamed>") $ roomName =<< roomStateRoom rstate
conversationPeer :: Conversation -> Maybe ComposedIdentity
conversationPeer (DirectMessageConversation thread) = Just $ msgPeer thread
+conversationPeer (ChatroomConversation _) = Nothing
conversationHistory :: Conversation -> [Message]
conversationHistory (DirectMessageConversation thread) = map (\msg -> DirectMessageMessage msg False) $ threadToList thread
+conversationHistory (ChatroomConversation rstate) = map (\msg -> ChatroomMessage msg False) $ roomStateMessages rstate
-sendMessage :: (MonadHead LocalState m, MonadError String m) => Conversation -> Text -> m Message
-sendMessage (DirectMessageConversation thread) text = DirectMessageMessage <$> (fromStored <$> sendDirectMessage (msgPeer thread) text) <*> pure False
+sendMessage :: (MonadHead LocalState m, MonadError String m) => Conversation -> Text -> m (Maybe Message)
+sendMessage (DirectMessageConversation thread) text = fmap Just $ DirectMessageMessage <$> (fromStored <$> sendDirectMessage (msgPeer thread) text) <*> pure False
+sendMessage (ChatroomConversation rstate) text = sendChatroomMessage rstate text >> return Nothing
diff --git a/src/Erebos/Identity.hs b/src/Erebos/Identity.hs
index 8761fde..f2094f6 100644
--- a/src/Erebos/Identity.hs
+++ b/src/Erebos/Identity.hs
@@ -35,7 +35,6 @@ import Data.Foldable
import Data.Function
import Data.List
import Data.Maybe
-import Data.Ord
import Data.Set (Set)
import qualified Data.Set as S
import Data.Text (Text)
@@ -304,25 +303,18 @@ verifySignatures sidd = do
throwError "signature verification failed"
lookupProperty :: forall a m. Foldable m => (ExtendedIdentityData -> Maybe a) -> m (Stored (Signed ExtendedIdentityData)) -> Maybe a
-lookupProperty sel topHeads = findResult filteredLayers
- where findPropHeads :: Stored (Signed ExtendedIdentityData) -> [(Stored (Signed ExtendedIdentityData), a)]
- findPropHeads sobj | Just x <- sel $ fromSigned sobj = [(sobj, x)]
- | otherwise = findPropHeads =<< (eiddPrev $ fromSigned sobj)
+lookupProperty sel topHeads = findResult propHeads
+ where
+ findPropHeads :: Stored (Signed ExtendedIdentityData) -> [ Stored (Signed ExtendedIdentityData) ]
+ findPropHeads sobj | Just _ <- sel $ fromSigned sobj = [ sobj ]
+ | otherwise = findPropHeads =<< (eiddPrev $ fromSigned sobj)
- propHeads :: [(Stored (Signed ExtendedIdentityData), a)]
- propHeads = findPropHeads =<< toList topHeads
+ propHeads :: [ Stored (Signed ExtendedIdentityData) ]
+ propHeads = filterAncestors $ findPropHeads =<< toList topHeads
- historyLayers :: [Set (Stored (Signed ExtendedIdentityData))]
- historyLayers = generations $ map fst propHeads
-
- filteredLayers :: [[(Stored (Signed ExtendedIdentityData), a)]]
- filteredLayers = scanl (\cur obsolete -> filter ((`S.notMember` obsolete) . fst) cur) propHeads historyLayers
-
- findResult ([(_, x)] : _) = Just x
- findResult ([] : _) = Nothing
- findResult [] = Nothing
- findResult [xs] = Just $ snd $ minimumBy (comparing fst) xs
- findResult (_:rest) = findResult rest
+ findResult :: [ Stored (Signed ExtendedIdentityData) ] -> Maybe a
+ findResult [] = Nothing
+ findResult xs = sel $ fromSigned $ minimum xs
mergeIdentity :: (MonadStorage m, MonadError String m, MonadIO m) => Identity f -> m UnifiedIdentity
mergeIdentity idt | Just idt' <- toUnifiedIdentity idt = return idt'
@@ -385,8 +377,9 @@ updateOwners updates orig@Identity { idOwner_ = Just owner, idUpdates_ = cupdate
updateOwners _ orig@Identity { idOwner_ = Nothing } = orig
sameIdentity :: (Foldable m, Foldable m') => Identity m -> Identity m' -> Bool
-sameIdentity x y = not $ S.null $ S.intersection (refset x) (refset y)
- where refset idt = foldr S.insert (ancestors $ toList $ idDataF idt) (idDataF idt)
+sameIdentity x y = intersectsSorted (roots x) (roots y)
+ where
+ roots idt = uniq $ sort $ concatMap storedRoots $ toList $ idDataF idt
unfoldOwners :: (Foldable m) => Identity m -> [ComposedIdentity]
diff --git a/src/Erebos/Network.hs b/src/Erebos/Network.hs
index 41b6279..2064d1c 100644
--- a/src/Erebos/Network.hs
+++ b/src/Erebos/Network.hs
@@ -19,7 +19,9 @@ module Erebos.Network (
#endif
dropPeer,
isPeerDropped,
- sendToPeer, sendToPeerStored, sendToPeerWith,
+ sendToPeer, sendManyToPeer,
+ sendToPeerStored, sendManyToPeerStored,
+ sendToPeerWith,
runPeerService,
discoveryPort,
@@ -52,6 +54,9 @@ import GHC.Conc.Sync (unsafeIOToSTM)
import Network.Socket hiding (ControlMessage)
import qualified Network.Socket.ByteString as S
+import Foreign.C.Types
+import Foreign.Marshal.Alloc
+
import Erebos.Channel
#ifdef ENABLE_ICE_SUPPORT
import Erebos.ICE
@@ -69,6 +74,9 @@ import Erebos.Storage.Merge
discoveryPort :: PortNumber
discoveryPort = 29665
+discoveryMulticastGroup :: HostAddress6
+discoveryMulticastGroup = tupleToHostAddress6 (0xff12, 0xb6a4, 0x6b1f, 0x0969, 0xcaee, 0xacc2, 0x5c93, 0x73e1) -- ff12:b6a4:6b1f:969:caee:acc2:5c93:73e1
+
announceIntervalSeconds :: Int
announceIntervalSeconds = 60
@@ -247,8 +255,6 @@ startServer opt serverOrigHead logd' serverServices = do
either (atomically . logd) return =<< runExceptT =<<
atomically (readTQueue serverIOActions)
- broadcastAddreses <- getBroadcastAddresses discoveryPort
-
let open addr = do
sock <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)
putMVar serverSocket sock
@@ -259,9 +265,14 @@ startServer opt serverOrigHead logd' serverServices = do
return sock
loop sock = do
- when (serverLocalDiscovery opt) $ forkServerThread server $ forever $ do
- atomically $ writeFlowBulk serverControlFlow $ map (SendAnnounce . DatagramAddress) broadcastAddreses
- threadDelay $ announceIntervalSeconds * 1000 * 1000
+ when (serverLocalDiscovery opt) $ forkServerThread server $ do
+ announceAddreses <- fmap concat $ sequence $
+ [ map (SockAddrInet6 discoveryPort 0 discoveryMulticastGroup) <$> joinMulticast sock
+ , getBroadcastAddresses discoveryPort
+ ]
+ forever $ do
+ atomically $ writeFlowBulk serverControlFlow $ map (SendAnnounce . DatagramAddress) announceAddreses
+ threadDelay $ announceIntervalSeconds * 1000 * 1000
let announceUpdate identity = do
st <- derivePartialStorage serverStorage
@@ -301,10 +312,11 @@ startServer opt serverOrigHead logd' serverServices = do
forkServerThread server $ forever $ do
(paddr, msg) <- readFlowIO serverRawPath
- case paddr of
- DatagramAddress addr -> void $ S.sendTo sock msg addr
+ handle (\(e :: IOException) -> atomically . logd $ "failed to send packet to " ++ show paddr ++ ": " ++ show e) $ do
+ case paddr of
+ DatagramAddress addr -> void $ S.sendTo sock msg addr
#ifdef ENABLE_ICE_SUPPORT
- PeerIceSession ice -> iceSend ice msg
+ PeerIceSession ice -> iceSend ice msg
#endif
forkServerThread server $ forever $ do
@@ -421,12 +433,18 @@ instance MonadFail PacketHandler where
runPacketHandler :: Bool -> Peer -> PacketHandler () -> STM ()
runPacketHandler secure peer@Peer {..} act = do
let logd = writeTQueue $ serverErrorLog peerServer_
- runExceptT (flip execStateT (PacketHandlerState peer [] [] [] False) $ unPacketHandler act) >>= \case
+ runExceptT (flip execStateT (PacketHandlerState peer [] [] [] Nothing False) $ unPacketHandler act) >>= \case
Left err -> do
logd $ "Error in handling packet from " ++ show peerAddress ++ ": " ++ err
Right ph -> do
when (not $ null $ phHead ph) $ do
- let packet = TransportPacket (TransportHeader $ phHead ph) (phBody ph)
+ body <- case phBodyStream ph of
+ Nothing -> return $ phBody ph
+ Just stream -> do
+ writeTQueue (serverIOActions peerServer_) $ void $ liftIO $ forkIO $ do
+ writeByteStringToStream stream $ BL.concat $ map lazyLoadBytes $ phBody ph
+ return []
+ let packet = TransportPacket (TransportHeader $ phHead ph) body
secreq = case (secure, phPlaintextReply ph) of
(True, _) -> EncryptedOnly
(False, False) -> PlaintextAllowed
@@ -450,6 +468,7 @@ data PacketHandlerState = PacketHandlerState
, phHead :: [TransportHeaderItem]
, phAckedBy :: [TransportHeaderItem]
, phBody :: [Ref]
+ , phBodyStream :: Maybe RawStreamWriter
, phPlaintextReply :: Bool
}
@@ -462,6 +481,14 @@ addAckedBy hs = modify $ \ph -> ph { phAckedBy = foldr appendDistinct (phAckedBy
addBody :: Ref -> PacketHandler ()
addBody r = modify $ \ph -> ph { phBody = r `appendDistinct` phBody ph }
+sendBodyAsStream :: PacketHandler ()
+sendBodyAsStream = do
+ gets phBodyStream >>= \case
+ Nothing -> do
+ stream <- openStream
+ modify $ \ph -> ph { phBodyStream = Just stream }
+ Just _ -> return ()
+
keepPlaintextReply :: PacketHandler ()
keepPlaintextReply = modify $ \ph -> ph { phPlaintextReply = True }
@@ -517,8 +544,12 @@ handlePacket identity secure peer chanSvc svcs (TransportHeader headers) prefs =
liftSTM $ finalizedChannel peer ch identity
_ -> return ()
- Rejected dgst -> do
- logd $ "rejected by peer: " ++ show dgst
+ Rejected dgst
+ | peerRequest : _ <- mapMaybe (\case TrChannelRequest d -> Just d; _ -> Nothing) headers
+ , peerRequest < dgst
+ -> return () -- Our request was rejected due to lower priority
+
+ | otherwise -> logd $ "rejected by peer: " ++ show dgst
DataRequest dgst
| secure || dgst `elem` plaintextRefs -> do
@@ -532,15 +563,11 @@ handlePacket identity secure peer chanSvc svcs (TransportHeader headers) prefs =
-- otherwise lost the channel, so keep the reply plaintext as well.
when (not secure) keepPlaintextReply
- let bytes = lazyLoadBytes mref
-- TODO: MTU
- if (secure && BL.length bytes > 500)
- then do
- stream <- openStream
- liftSTM $ writeTQueue (serverIOActions server) $ void $ liftIO $ forkIO $ do
- writeByteStringToStream stream bytes
- else do
- addBody $ mref
+ when (secure && BL.length (lazyLoadBytes mref) > 500)
+ sendBodyAsStream
+
+ addBody $ mref
| otherwise -> do
logd $ "unauthorized data request for " ++ show dgst
addHeader $ Rejected dgst
@@ -593,9 +620,15 @@ handlePacket identity secure peer chanSvc svcs (TransportHeader headers) prefs =
ChannelCookieWait {} -> return ()
ChannelCookieReceived {} -> process
ChannelCookieConfirmed {} -> process
- ChannelOurRequest our | dgst < refDigest (storedRef our) -> process
- | otherwise -> reject
- ChannelPeerRequest {} -> process
+ ChannelOurRequest our
+ | dgst < refDigest (storedRef our) -> process
+ | otherwise -> do
+ -- Reject peer channel request with lower priority
+ addHeader $ TrChannelRequest $ refDigest $ storedRef our
+ reject
+ ChannelPeerRequest prev
+ | dgst == wrDigest prev -> addHeader $ Acknowledged dgst
+ | otherwise -> process
ChannelOurAccept {} -> reject
ChannelEstablished {} -> process
ChannelClosed {} -> return ()
@@ -647,12 +680,14 @@ setupChannel identity peer upid = do
[ TrChannelRequest reqref
, AnnounceSelf $ refDigest $ storedRef $ idData identity
]
+ let sendChannelRequest = do
+ sendToPeerPlain peer [ Acknowledged reqref, Rejected reqref ] $
+ TransportPacket (TransportHeader hitems) [storedRef req]
+ setPeerChannel peer $ ChannelOurRequest req
liftIO $ atomically $ do
getPeerChannel peer >>= \case
- ChannelCookieConfirmed -> do
- sendToPeerPlain peer [ Acknowledged reqref, Rejected reqref ] $
- TransportPacket (TransportHeader hitems) [storedRef req]
- setPeerChannel peer $ ChannelOurRequest req
+ ChannelCookieReceived -> sendChannelRequest
+ ChannelCookieConfirmed -> sendChannelRequest
_ -> return ()
handleChannelRequest :: Peer -> UnifiedIdentity -> Ref -> WaitingRefCallback
@@ -806,10 +841,16 @@ isPeerDropped peer = liftIO $ atomically $ readTVar (peerState peer) >>= \case
_ -> return False
sendToPeer :: (Service s, MonadIO m) => Peer -> s -> m ()
-sendToPeer peer packet = sendToPeerList peer [ServiceReply (Left packet) True]
+sendToPeer peer = sendManyToPeer peer . (: [])
+
+sendManyToPeer :: (Service s, MonadIO m) => Peer -> [ s ] -> m ()
+sendManyToPeer peer = sendToPeerList peer . map (\part -> ServiceReply (Left part) True)
sendToPeerStored :: (Service s, MonadIO m) => Peer -> Stored s -> m ()
-sendToPeerStored peer spacket = sendToPeerList peer [ServiceReply (Right spacket) True]
+sendToPeerStored peer = sendManyToPeerStored peer . (: [])
+
+sendManyToPeerStored :: (Service s, MonadIO m) => Peer -> [ Stored s ] -> m ()
+sendManyToPeerStored peer = sendToPeerList peer . map (\part -> ServiceReply (Right part) True)
sendToPeerList :: (Service s, MonadIO m) => Peer -> [ServiceReply s] -> m ()
sendToPeerList peer parts = do
@@ -912,9 +953,19 @@ runPeerServiceOn mbservice peer handler = liftIO $ do
logd $ "unhandled service '" ++ show (toUUID svc) ++ "'"
+foreign import ccall unsafe "Network/ifaddrs.h join_multicast" cJoinMulticast :: CInt -> Ptr CSize -> IO (Ptr Word32)
foreign import ccall unsafe "Network/ifaddrs.h broadcast_addresses" cBroadcastAddresses :: IO (Ptr Word32)
foreign import ccall unsafe "stdlib.h free" cFree :: Ptr Word32 -> IO ()
+joinMulticast :: Socket -> IO [ Word32 ]
+joinMulticast sock =
+ withFdSocket sock $ \fd ->
+ alloca $ \pcount -> do
+ ptr <- cJoinMulticast fd pcount
+ count <- fromIntegral <$> peek pcount
+ forM [ 0 .. count - 1 ] $ \i ->
+ peekElemOff ptr i
+
getBroadcastAddresses :: PortNumber -> IO [SockAddr]
getBroadcastAddresses port = do
ptr <- cBroadcastAddresses
@@ -922,6 +973,9 @@ getBroadcastAddresses port = do
w <- peekElemOff ptr i
if w == 0 then return []
else (SockAddrInet port w:) <$> parse (i + 1)
- addrs <- parse 0
- cFree ptr
- return addrs
+ if ptr == nullPtr
+ then return []
+ else do
+ addrs <- parse 0
+ cFree ptr
+ return addrs
diff --git a/src/Erebos/Network/Protocol.hs b/src/Erebos/Network/Protocol.hs
index a009ad1..2955473 100644
--- a/src/Erebos/Network/Protocol.hs
+++ b/src/Erebos/Network/Protocol.hs
@@ -40,7 +40,17 @@ import Control.Monad
import Control.Monad.Except
import Control.Monad.Trans
+import Crypto.Cipher.ChaChaPoly1305 qualified as C
+import Crypto.MAC.Poly1305 qualified as C (Auth(..), authTag)
+import Crypto.Error
+import Crypto.Random
+
+import Data.Binary
+import Data.Binary.Get
+import Data.Binary.Put
import Data.Bits
+import Data.ByteArray (Bytes, ScrubbedBytes)
+import Data.ByteArray qualified as BA
import Data.ByteString (ByteString)
import Data.ByteString qualified as B
import Data.ByteString.Char8 qualified as BC
@@ -51,7 +61,6 @@ import Data.Maybe
import Data.Text (Text)
import Data.Text qualified as T
import Data.Void
-import Data.Word
import System.Clock
@@ -68,6 +77,9 @@ protocolVersion = T.pack "0.1"
protocolVersions :: [Text]
protocolVersions = [protocolVersion]
+keepAliveInternal :: TimeSpec
+keepAliveInternal = fromNanoSecs $ 30 * 10^(9 :: Int)
+
data TransportPacket a = TransportPacket TransportHeader [a]
@@ -93,14 +105,41 @@ data TransportHeaderItem
| StreamOpen Word8
deriving (Eq, Show)
-newtype Cookie = Cookie ByteString
- deriving (Eq, Show)
-
data SecurityRequirement = PlaintextOnly
| PlaintextAllowed
| EncryptedOnly
deriving (Eq, Ord)
+data Cookie = Cookie
+ { cookieNonce :: C.Nonce
+ , cookieValidity :: Word32
+ , cookieContent :: ByteString
+ , cookieMac :: C.Auth
+ }
+
+instance Eq Cookie where
+ (==) = (==) `on` (\c -> ( BA.convert (cookieNonce c) :: ByteString, cookieValidity c, cookieContent c, cookieMac c ))
+
+
+instance Show Cookie where
+ show Cookie {..} = show (nonce, cookieValidity, cookieContent, mac)
+ where C.Auth mac = cookieMac
+ nonce = BA.convert cookieNonce :: ByteString
+
+instance Binary Cookie where
+ put Cookie {..} = do
+ putByteString $ BA.convert cookieNonce
+ putWord32be cookieValidity
+ putByteString $ BA.convert cookieMac
+ putByteString cookieContent
+
+ get = do
+ Just cookieNonce <- maybeCryptoError . C.nonce12 <$> getByteString 12
+ cookieValidity <- getWord32be
+ Just cookieMac <- maybeCryptoError . C.authTag <$> getByteString 16
+ cookieContent <- BL.toStrict <$> getRemainingLazyByteString
+ return Cookie {..}
+
isHeaderItemAcknowledged :: TransportHeaderItem -> Bool
isHeaderItemAcknowledged = \case
Acknowledged {} -> False
@@ -120,8 +159,8 @@ transportToObject st (TransportHeader items) = Rec $ map single items
Rejected dgst -> (BC.pack "REJ", RecRef $ partialRefFromDigest st dgst)
ProtocolVersion ver -> (BC.pack "VER", RecText ver)
Initiation dgst -> (BC.pack "INI", RecRef $ partialRefFromDigest st dgst)
- CookieSet (Cookie bytes) -> (BC.pack "CKS", RecBinary bytes)
- CookieEcho (Cookie bytes) -> (BC.pack "CKE", RecBinary bytes)
+ CookieSet cookie -> (BC.pack "CKS", RecBinary $ BL.toStrict $ encode cookie)
+ CookieEcho cookie -> (BC.pack "CKE", RecBinary $ BL.toStrict $ encode cookie)
DataRequest dgst -> (BC.pack "REQ", RecRef $ partialRefFromDigest st dgst)
DataResponse dgst -> (BC.pack "RSP", RecRef $ partialRefFromDigest st dgst)
AnnounceSelf dgst -> (BC.pack "ANN", RecRef $ partialRefFromDigest st dgst)
@@ -142,8 +181,12 @@ transportFromObject (Rec items) = case catMaybes $ map single items of
| name == BC.pack "REJ", RecRef ref <- content -> Just $ Rejected $ refDigest ref
| name == BC.pack "VER", RecText ver <- content -> Just $ ProtocolVersion ver
| name == BC.pack "INI", RecRef ref <- content -> Just $ Initiation $ refDigest ref
- | name == BC.pack "CKS", RecBinary bytes <- content -> Just $ CookieSet (Cookie bytes)
- | name == BC.pack "CKE", RecBinary bytes <- content -> Just $ CookieEcho (Cookie bytes)
+ | name == BC.pack "CKS", RecBinary bytes <- content
+ , Right (_, _, cookie) <- decodeOrFail (BL.fromStrict bytes)
+ -> Just $ CookieSet cookie
+ | name == BC.pack "CKE", RecBinary bytes <- content
+ , Right (_, _, cookie) <- decodeOrFail (BL.fromStrict bytes)
+ -> Just $ CookieEcho cookie
| name == BC.pack "REQ", RecRef ref <- content -> Just $ DataRequest $ refDigest ref
| name == BC.pack "RSP", RecRef ref <- content -> Just $ DataResponse $ refDigest ref
| name == BC.pack "ANN", RecRef ref <- content -> Just $ AnnounceSelf $ refDigest ref
@@ -165,9 +208,12 @@ data GlobalState addr = (Eq addr, Show addr) => GlobalState
, gNextUp :: TMVar (Connection addr, (Bool, TransportPacket PartialObject))
, gLog :: String -> STM ()
, gStorage :: PartialStorage
+ , gStartTime :: TimeSpec
, gNowVar :: TVar TimeSpec
, gNextTimeout :: TVar TimeSpec
, gInitConfig :: Ref
+ , gCookieKey :: ScrubbedBytes
+ , gCookieStartTime :: Word32
}
data Connection addr = Connection
@@ -186,6 +232,7 @@ data Connection addr = Connection
, cReservedPackets :: TVar Int
, cSentPackets :: TVar [SentPacket]
, cToAcknowledge :: TVar [Integer]
+ , cNextKeepAlive :: TVar (Maybe TimeSpec)
, cInStreams :: TVar [(Word8, Stream)]
, cOutStreams :: TVar [(Word8, Stream)]
}
@@ -440,15 +487,18 @@ erebosNetworkProtocol initialIdentity gLog gDataFlow gControlFlow = do
mStorage <- memoryStorage
gStorage <- derivePartialStorage mStorage
- startTime <- getTime MonotonicRaw
- gNowVar <- newTVarIO startTime
- gNextTimeout <- newTVarIO startTime
+ gStartTime <- getTime Monotonic
+ gNowVar <- newTVarIO gStartTime
+ gNextTimeout <- newTVarIO gStartTime
gInitConfig <- store mStorage $ (Rec [] :: Object)
+ gCookieKey <- getRandomBytes 32
+ gCookieStartTime <- runGet getWord32host . BL.pack . BA.unpack @ScrubbedBytes <$> getRandomBytes 4
+
let gs = GlobalState {..}
let signalTimeouts = forever $ do
- now <- getTime MonotonicRaw
+ now <- getTime Monotonic
next <- atomically $ do
writeTVar gNowVar now
readTVar gNextTimeout
@@ -487,6 +537,7 @@ newConnection cGlobalState@GlobalState {..} addr = do
cReservedPackets <- newTVar 0
cSentPackets <- newTVar []
cToAcknowledge <- newTVar []
+ cNextKeepAlive <- newTVar Nothing
cInStreams <- newTVar []
cOutStreams <- newTVar []
let conn = Connection {..}
@@ -548,6 +599,7 @@ processIncoming gs@GlobalState {..} = do
Nothing -> throwError "empty packet"
+ now <- getTime Monotonic
runExceptT parse >>= \case
Right (Left (secure, objs, mbcounter))
| hobj:content <- objs
@@ -562,6 +614,7 @@ processIncoming gs@GlobalState {..} = do
case mbup of
Just up -> putTMVar gNextUp (conn, (secure, up))
Nothing -> return ()
+ updateKeepAlive conn now
processAcknowledgements gs conn items
ioAfter
Nothing -> return ()
@@ -571,8 +624,9 @@ processIncoming gs@GlobalState {..} = do
gLog $ show objs
Right (Right (snum, seq8, content, counter))
- | Just Connection {..} <- mbconn
+ | Just conn@Connection {..} <- mbconn
-> atomically $ do
+ updateKeepAlive conn now
(lookup snum <$> readTVar cInStreams) >>= \case
Nothing ->
gLog $ "unexpected stream number " ++ show snum
@@ -694,11 +748,36 @@ generateCookieHeaders Connection {..} ch = catMaybes <$> sequence [ echoHeader,
_ -> return Nothing
createCookie :: GlobalState addr -> addr -> IO Cookie
-createCookie GlobalState {} addr = return (Cookie $ BC.pack $ show addr)
+createCookie GlobalState {..} addr = do
+ (nonceBytes :: Bytes) <- getRandomBytes 12
+ validUntil <- (fromNanoSecs (60 * 10^(9 :: Int)) +) <$> getTime Monotonic
+ let validSecondsFromStart = fromIntegral $ toNanoSecs (validUntil - gStartTime) `div` (10^(9 :: Int))
+ cookieValidity = validSecondsFromStart - gCookieStartTime
+ plainContent = BC.pack (show addr)
+ throwCryptoErrorIO $ do
+ cookieNonce <- C.nonce12 nonceBytes
+ st1 <- C.initialize gCookieKey cookieNonce
+ let st2 = C.finalizeAAD $ C.appendAAD (BL.toStrict $ runPut $ putWord32be cookieValidity) st1
+ (cookieContent, st3) = C.encrypt plainContent st2
+ cookieMac = C.finalize st3
+ return $ Cookie {..}
verifyCookie :: GlobalState addr -> addr -> Cookie -> IO Bool
-verifyCookie GlobalState {} addr (Cookie cookie) = return $ show addr == BC.unpack cookie
-
+verifyCookie GlobalState {..} addr Cookie {..} = do
+ ctime <- getTime Monotonic
+ return $ fromMaybe False $ maybeCryptoError $ do
+ st1 <- C.initialize gCookieKey cookieNonce
+ let st2 = C.finalizeAAD $ C.appendAAD (BL.toStrict $ runPut $ putWord32be cookieValidity) st1
+ (plainContent, st3) = C.decrypt cookieContent st2
+ mac = C.finalize st3
+
+ validSecondsFromStart = fromIntegral $ cookieValidity + gCookieStartTime
+ validUntil = gStartTime + fromNanoSecs (validSecondsFromStart * (10^(9 :: Int)))
+ return $ and
+ [ mac == cookieMac
+ , ctime <= validUntil
+ , show addr == BC.unpack plainContent
+ ]
reservePacket :: Connection addr -> STM ReservedToSend
reservePacket conn@Connection {..} = do
@@ -713,9 +792,9 @@ reservePacket conn@Connection {..} = do
return $ ReservedToSend Nothing (return ()) (atomically $ connClose conn)
resendBytes :: Connection addr -> Maybe ReservedToSend -> SentPacket -> IO ()
-resendBytes Connection {..} reserved sp = do
+resendBytes conn@Connection {..} reserved sp = do
let GlobalState {..} = cGlobalState
- now <- getTime MonotonicRaw
+ now <- getTime Monotonic
atomically $ do
when (isJust reserved) $ do
modifyTVar' cReservedPackets (subtract 1)
@@ -726,6 +805,7 @@ resendBytes Connection {..} reserved sp = do
, spRetryCount = spRetryCount sp + 1
}
writeFlow gDataFlow (cAddress, spData sp)
+ updateKeepAlive conn now
sendBytes :: Connection addr -> Maybe ReservedToSend -> ByteString -> IO ()
sendBytes conn reserved bs = resendBytes conn reserved
@@ -738,6 +818,12 @@ sendBytes conn reserved bs = resendBytes conn reserved
, spData = bs
}
+updateKeepAlive :: Connection addr -> TimeSpec -> STM ()
+updateKeepAlive Connection {..} now = do
+ let next = now + keepAliveInternal
+ writeTVar cNextKeepAlive $ Just next
+
+
processOutgoing :: forall addr. GlobalState addr -> STM (IO ())
processOutgoing gs@GlobalState {..} = do
@@ -777,11 +863,12 @@ processOutgoing gs@GlobalState {..} = do
let onAck = sequence_ $ map (streamAccepted conn) $
catMaybes (map (\case StreamOpen n -> Just n; _ -> Nothing) hitems)
- let mkPlain extraHeaders =
- let header = TransportHeader $ map AcknowledgedSingle acknowledge ++ extraHeaders ++ hitems
- in BL.concat $
- (serializeObject $ transportToObject gStorage header)
- : map lazyLoadBytes content
+ let mkPlain extraHeaders
+ | combinedHeaderItems@(_:_) <- map AcknowledgedSingle acknowledge ++ extraHeaders ++ hitems =
+ BL.concat $
+ (serializeObject $ transportToObject gStorage $ TransportHeader combinedHeaderItems)
+ : map lazyLoadBytes content
+ | otherwise = BL.empty
let usePlaintext = do
plain <- mkPlain <$> generateCookieHeaders conn channel
@@ -811,6 +898,13 @@ processOutgoing gs@GlobalState {..} = do
sendBytes conn mbReserved' bs
Nothing -> return ()
+ let waitUntil :: TimeSpec -> TimeSpec -> STM ()
+ waitUntil now till = do
+ nextTimeout <- readTVar gNextTimeout
+ if nextTimeout <= now || till < nextTimeout
+ then writeTVar gNextTimeout till
+ else retry
+
let retransmitPacket :: Connection addr -> STM (IO ())
retransmitPacket conn@Connection {..} = do
now <- readTVar gNowVar
@@ -819,11 +913,8 @@ processOutgoing gs@GlobalState {..} = do
_ -> retry
let nextTry = spTime sp + fromNanoSecs 1000000000
if | now < nextTry -> do
- nextTimeout <- readTVar gNextTimeout
- if nextTimeout <= now || nextTry < nextTimeout
- then do writeTVar gNextTimeout nextTry
- return $ return ()
- else retry
+ waitUntil now nextTry
+ return $ return ()
| spRetryCount sp < 2 -> do
reserved <- reservePacket conn
writeTVar cSentPackets rest
@@ -863,11 +954,28 @@ processOutgoing gs@GlobalState {..} = do
writeTVar gIdentity (nid, cur : past)
return $ return ()
+ let sendKeepAlive :: Connection addr -> STM (IO ())
+ sendKeepAlive Connection {..} = do
+ readTVar cNextKeepAlive >>= \case
+ Nothing -> retry
+ Just next -> do
+ now <- readTVar gNowVar
+ if next <= now
+ then do
+ writeTVar cNextKeepAlive Nothing
+ identity <- fst <$> readTVar gIdentity
+ let header = TransportHeader [ AnnounceSelf $ refDigest $ storedRef $ idData identity ]
+ writeTQueue cSecureOutQueue (EncryptedOnly, TransportPacket header [], [])
+ else do
+ waitUntil now next
+ return $ return ()
+
conns <- readTVar gConnections
msum $ concat $
[ map retransmitPacket conns
, map sendNextPacket conns
, [ handleControlRequests ]
+ , map sendKeepAlive conns
]
processAcknowledgements :: GlobalState addr -> Connection addr -> [TransportHeaderItem] -> STM (IO ())
diff --git a/src/Erebos/Network/ifaddrs.c b/src/Erebos/Network/ifaddrs.c
index 37c3e00..70685bc 100644
--- a/src/Erebos/Network/ifaddrs.c
+++ b/src/Erebos/Network/ifaddrs.c
@@ -1,11 +1,89 @@
#include "ifaddrs.h"
+#include <errno.h>
+#include <stdbool.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#ifndef _WIN32
#include <arpa/inet.h>
-#include <ifaddrs.h>
#include <net/if.h>
-#include <stdlib.h>
-#include <sys/types.h>
+#include <ifaddrs.h>
#include <endian.h>
+#include <sys/types.h>
+#include <sys/socket.h>
+#else
+#include <winsock2.h>
+#include <ws2ipdef.h>
+#include <ws2tcpip.h>
+#endif
+
+#define DISCOVERY_MULTICAST_GROUP "ff12:b6a4:6b1f:969:caee:acc2:5c93:73e1"
+
+uint32_t * join_multicast(int fd, size_t * count)
+{
+ size_t capacity = 16;
+ *count = 0;
+ uint32_t * interfaces = malloc(sizeof(uint32_t) * capacity);
+
+#ifdef _WIN32
+ interfaces[0] = 0;
+ *count = 1;
+#else
+ struct ifaddrs * addrs;
+ if (getifaddrs(&addrs) < 0)
+ return 0;
+
+ for (struct ifaddrs * ifa = addrs; ifa; ifa = ifa->ifa_next) {
+ if (ifa->ifa_addr && ifa->ifa_addr->sa_family == AF_INET6 &&
+ !(ifa->ifa_flags & IFF_LOOPBACK)) {
+ int idx = if_nametoindex(ifa->ifa_name);
+
+ bool seen = false;
+ for (size_t i = 0; i < *count; i++) {
+ if (interfaces[i] == idx) {
+ seen = true;
+ break;
+ }
+ }
+ if (seen)
+ continue;
+
+ if (*count + 1 >= capacity) {
+ capacity *= 2;
+ uint32_t * nret = realloc(interfaces, sizeof(uint32_t) * capacity);
+ if (nret) {
+ interfaces = nret;
+ } else {
+ free(interfaces);
+ *count = 0;
+ return NULL;
+ }
+ }
+
+ interfaces[*count] = idx;
+ (*count)++;
+ }
+ }
+
+ freeifaddrs(addrs);
+#endif
+
+ for (size_t i = 0; i < *count; i++) {
+ struct ipv6_mreq group;
+ group.ipv6mr_interface = interfaces[i];
+ inet_pton(AF_INET6, DISCOVERY_MULTICAST_GROUP, &group.ipv6mr_multiaddr);
+ int ret = setsockopt(fd, IPPROTO_IPV6, IPV6_ADD_MEMBERSHIP,
+ (const void *) &group, sizeof(group));
+ if (ret < 0)
+ fprintf(stderr, "IPV6_ADD_MEMBERSHIP failed: %s\n", strerror(errno));
+ }
+
+ return interfaces;
+}
+
+#ifndef _WIN32
uint32_t * broadcast_addresses(void)
{
@@ -39,3 +117,51 @@ uint32_t * broadcast_addresses(void)
ret[count] = 0;
return ret;
}
+
+#else // _WIN32
+
+#include <winsock2.h>
+#include <ws2tcpip.h>
+
+#pragma comment(lib, "ws2_32.lib")
+
+uint32_t * broadcast_addresses(void)
+{
+ uint32_t * ret = NULL;
+ SOCKET wsock = INVALID_SOCKET;
+
+ struct WSAData wsaData;
+ if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0)
+ return NULL;
+
+ wsock = WSASocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP, NULL, 0, 0);
+ if (wsock == INVALID_SOCKET)
+ goto cleanup;
+
+ INTERFACE_INFO InterfaceList[32];
+ unsigned long nBytesReturned;
+
+ if (WSAIoctl(wsock, SIO_GET_INTERFACE_LIST, 0, 0,
+ InterfaceList, sizeof(InterfaceList),
+ &nBytesReturned, 0, 0) == SOCKET_ERROR)
+ goto cleanup;
+
+ int numInterfaces = nBytesReturned / sizeof(INTERFACE_INFO);
+
+ size_t capacity = 16, count = 0;
+ ret = malloc(sizeof(uint32_t) * capacity);
+
+ for (int i = 0; i < numInterfaces && count < capacity - 1; i++)
+ if (InterfaceList[i].iiFlags & IFF_BROADCAST)
+ ret[count++] = InterfaceList[i].iiBroadcastAddress.AddressIn.sin_addr.s_addr;
+
+ ret[count] = 0;
+cleanup:
+ if (wsock != INVALID_SOCKET)
+ closesocket(wsock);
+ WSACleanup();
+
+ return ret;
+}
+
+#endif
diff --git a/src/Erebos/Network/ifaddrs.h b/src/Erebos/Network/ifaddrs.h
index 06d26ec..8852ec6 100644
--- a/src/Erebos/Network/ifaddrs.h
+++ b/src/Erebos/Network/ifaddrs.h
@@ -1,3 +1,5 @@
+#include <stddef.h>
#include <stdint.h>
+uint32_t * join_multicast(int fd, size_t * count);
uint32_t * broadcast_addresses(void);
diff --git a/src/Erebos/Storage/Internal.hs b/src/Erebos/Storage/Internal.hs
index d419a5e..8b794d8 100644
--- a/src/Erebos/Storage/Internal.hs
+++ b/src/Erebos/Storage/Internal.hs
@@ -241,7 +241,7 @@ writeFileOnce file content = bracket (openLockFile locked)
doesFileExist file >>= \case
True -> removeFile locked
False -> do BL.hPut h content
- hFlush h
+ hClose h
renameFile locked file
where locked = file ++ ".lock"
@@ -254,13 +254,13 @@ writeFileChecked file prev content = bracket (openLockFile locked)
removeFile locked
return $ Left $ Just current
(Nothing, False) -> do B.hPut h content
- hFlush h
+ hClose h
renameFile locked file
return $ Right ()
(Just expected, True) -> do
current <- B.readFile file
if current == expected then do B.hPut h content
- hFlush h
+ hClose h
renameFile locked file
return $ return ()
else do removeFile locked
diff --git a/src/Erebos/Storage/Key.hs b/src/Erebos/Storage/Key.hs
index b6afc20..5da79e3 100644
--- a/src/Erebos/Storage/Key.hs
+++ b/src/Erebos/Storage/Key.hs
@@ -80,6 +80,7 @@ moveKeys from to = liftIO $ do
return M.empty
(StorageMemory { memKeys = fromKeys }, StorageMemory { memKeys = toKeys }) -> do
- modifyMVar_ fromKeys $ \fkeys -> do
- modifyMVar_ toKeys $ return . M.union fkeys
- return M.empty
+ when (fromKeys /= toKeys) $ do
+ modifyMVar_ fromKeys $ \fkeys -> do
+ modifyMVar_ toKeys $ return . M.union fkeys
+ return M.empty
diff --git a/src/Erebos/Storage/Merge.hs b/src/Erebos/Storage/Merge.hs
index 9d9db13..a3b0fd7 100644
--- a/src/Erebos/Storage/Merge.hs
+++ b/src/Erebos/Storage/Merge.hs
@@ -97,13 +97,16 @@ storedGeneration x =
doLookup x
+-- |Returns list of sets starting with the set of given objects and
+-- intcrementally adding parents.
generations :: Storable a => [Stored a] -> [Set (Stored a)]
generations = unfoldr gen . (,S.empty)
- where gen (hs, cur) = case filter (`S.notMember` cur) $ previous =<< hs of
+ where gen (hs, cur) = case filter (`S.notMember` cur) hs of
[] -> Nothing
added -> let next = foldr S.insert cur added
- in Just (next, (added, next))
+ in Just (next, (previous =<< added, next))
+-- |Returns set containing all given objects and their ancestors
ancestors :: Storable a => [Stored a] -> Set (Stored a)
ancestors = last . (S.empty:) . generations