diff options
Diffstat (limited to 'src/Erebos')
-rw-r--r-- | src/Erebos/Network.hs | 90 | ||||
-rw-r--r-- | src/Erebos/Network/Protocol.hs | 12 |
2 files changed, 80 insertions, 22 deletions
diff --git a/src/Erebos/Network.hs b/src/Erebos/Network.hs index a01bdd1..f234971 100644 --- a/src/Erebos/Network.hs +++ b/src/Erebos/Network.hs @@ -4,6 +4,7 @@ module Erebos.Network ( Server, startServer, stopServer, + getCurrentPeerList, getNextPeerChange, ServerOptions(..), serverIdentity, defaultServerOptions, @@ -16,6 +17,8 @@ module Erebos.Network ( #ifdef ENABLE_ICE_SUPPORT serverPeerIce, #endif + dropPeer, + isPeerDropped, sendToPeer, sendToPeerStored, sendToPeerWith, runPeerService, @@ -90,6 +93,9 @@ data Server = Server serverIdentity :: Server -> IO UnifiedIdentity serverIdentity = readMVar . serverIdentity_ +getCurrentPeerList :: Server -> IO [Peer] +getCurrentPeerList = fmap M.elems . readMVar . serverPeers + getNextPeerChange :: Server -> IO Peer getNextPeerChange = atomically . readTChan . serverChanPeer @@ -108,7 +114,7 @@ defaultServerOptions = ServerOptions data Peer = Peer { peerAddress :: PeerAddress , peerServer_ :: Server - , peerConnection :: TVar (Either [(SecurityRequirement, TransportPacket Ref, [TransportHeaderItem])] (Connection PeerAddress)) + , peerState :: TVar PeerState , peerIdentityVar :: TVar PeerIdentity , peerStorage_ :: Storage , peerInStorage :: PartialStorage @@ -123,13 +129,18 @@ peerStorage :: Peer -> Storage peerStorage = peerStorage_ getPeerChannel :: Peer -> STM ChannelState -getPeerChannel Peer {..} = either (const $ return ChannelNone) connGetChannel =<< readTVar peerConnection +getPeerChannel Peer {..} = + readTVar peerState >>= \case + PeerInit _ -> return ChannelNone + PeerConnected conn -> connGetChannel conn + PeerDropped -> return ChannelClosed setPeerChannel :: Peer -> ChannelState -> STM () setPeerChannel Peer {..} ch = do - readTVar peerConnection >>= \case - Left _ -> retry - Right conn -> connSetChannel conn ch + readTVar peerState >>= \case + PeerInit _ -> retry + PeerConnected conn -> connSetChannel conn ch + PeerDropped -> return () instance Eq Peer where (==) = (==) `on` peerIdentityVar @@ -175,6 +186,11 @@ peerIdentity :: MonadIO m => Peer -> m PeerIdentity peerIdentity = liftIO . atomically . readTVar . peerIdentityVar +data PeerState = PeerInit [(SecurityRequirement, TransportPacket Ref, [TransportHeaderItem])] + | PeerConnected (Connection PeerAddress) + | PeerDropped + + lookupServiceType :: [TransportHeaderItem] -> Maybe ServiceID lookupServiceType (ServiceType stype : _) = Just stype lookupServiceType (_ : hs) = lookupServiceType hs @@ -196,8 +212,13 @@ newWaitingRef dgst act = do forkServerThread :: Server -> IO () -> IO () -forkServerThread server act = modifyMVar_ (serverThreads server) $ \ts -> do - (:ts) <$> forkIO act +forkServerThread server act = do + modifyMVar_ (serverThreads server) $ \ts -> do + t <- forkIO $ do + t <- myThreadId + act + modifyMVar_ (serverThreads server) $ return . filter (/=t) + return (t:ts) startServer :: ServerOptions -> Head LocalState -> (String -> IO ()) -> [SomeService] -> IO Server startServer opt serverOrigHead logd' serverServices = do @@ -299,10 +320,14 @@ startServer opt serverOrigHead logd' serverServices = do forkServerThread server $ do atomically $ do - readTVar (peerConnection peer) >>= \case - Left packets -> writeFlowBulk (connData conn) $ reverse packets - Right _ -> return () - writeTVar (peerConnection peer) (Right conn) + readTVar (peerState peer) >>= \case + PeerInit packets -> do + writeFlowBulk (connData conn) $ reverse packets + writeTVar (peerState peer) (PeerConnected conn) + PeerConnected _ -> do + writeTVar (peerState peer) (PeerConnected conn) + PeerDropped -> do + connClose conn case mbpid of Just dgst -> do @@ -438,9 +463,9 @@ keepPlaintextReply = modify $ \ph -> ph { phPlaintextReply = True } openStream :: PacketHandler RawStreamWriter openStream = do Peer {..} <- gets phPeer - conn <- readTVarP peerConnection >>= \case - Right conn -> return conn - _ -> throwError "can't open stream without established connection" + conn <- readTVarP peerState >>= \case + PeerConnected conn -> return conn + _ -> throwError "can't open stream without established connection" (hdr, writer, handler) <- liftSTM (connAddWriteStream conn) >>= \case Right res -> return res Left err -> throwError err @@ -452,9 +477,9 @@ openStream = do acceptStream :: Word8 -> PacketHandler RawStreamReader acceptStream streamNumber = do Peer {..} <- gets phPeer - conn <- readTVarP peerConnection >>= \case - Right conn -> return conn - _ -> throwError "can't accept stream without established connection" + conn <- readTVarP peerState >>= \case + PeerConnected conn -> return conn + _ -> throwError "can't accept stream without established connection" liftSTM $ connAddReadStream conn streamNumber appendDistinct :: Eq a => a -> [a] -> [a] @@ -568,6 +593,7 @@ handlePacket identity secure peer chanSvc svcs (TransportHeader headers) prefs = ChannelPeerRequest {} -> process ChannelOurAccept {} -> reject ChannelEstablished {} -> process + ChannelClosed {} -> return () TrChannelAccept dgst -> do let process = do @@ -583,6 +609,7 @@ handlePacket identity secure peer chanSvc svcs (TransportHeader headers) prefs = ChannelOurAccept our _ | dgst < refDigest (storedRef our) -> process | otherwise -> addHeader $ Rejected dgst ChannelEstablished {} -> process + ChannelClosed {} -> return () ServiceType _ -> return () ServiceRef dgst @@ -721,7 +748,7 @@ notifyServicesOfPeer peer@Peer { peerServer_ = Server {..} } = do mkPeer :: Server -> PeerAddress -> IO Peer mkPeer peerServer_ peerAddress = do - peerConnection <- newTVarIO (Left []) + peerState <- newTVarIO (PeerInit []) peerIdentityVar <- newTVarIO . PeerIdentityUnknown =<< newTVarIO [] peerStorage_ <- deriveEphemeralStorage $ serverStorage peerServer_ peerInStorage <- derivePartialStorage peerStorage_ @@ -731,7 +758,11 @@ mkPeer peerServer_ peerAddress = do serverPeer :: Server -> SockAddr -> IO Peer serverPeer server paddr = do - serverPeer' server (DatagramAddress paddr) + let paddr' = case IP.fromSockAddr paddr of + Just (IP.IPv4 ipv4, port) + -> IP.toSockAddr (IP.IPv6 $ IP.toIPv6w (0, 0, 0xffff, IP.fromIPv4w ipv4), port) + _ -> paddr + serverPeer' server (DatagramAddress paddr') #ifdef ENABLE_ICE_SUPPORT serverPeerIce :: Server -> IceSession -> IO Peer @@ -754,6 +785,20 @@ serverPeer' server paddr = do writeFlow (serverControlFlow server) (RequestConnection paddr) return peer +dropPeer :: MonadIO m => Peer -> m () +dropPeer peer = liftIO $ do + modifyMVar_ (serverPeers $ peerServer peer) $ \pvalue -> do + atomically $ do + readTVar (peerState peer) >>= \case + PeerConnected conn -> connClose conn + _ -> return() + writeTVar (peerState peer) PeerDropped + return $ M.delete (peerAddress peer) pvalue + +isPeerDropped :: MonadIO m => Peer -> m Bool +isPeerDropped peer = liftIO $ atomically $ readTVar (peerState peer) >>= \case + PeerDropped -> return True + _ -> return False sendToPeer :: (Service s, MonadIO m) => Peer -> s -> m () sendToPeer peer packet = sendToPeerList peer [ServiceReply (Left packet) True] @@ -777,9 +822,10 @@ sendToPeerList peer parts = do sendToPeerS' :: SecurityRequirement -> Peer -> [TransportHeaderItem] -> TransportPacket Ref -> STM () sendToPeerS' secure Peer {..} ackedBy packet = do - readTVar peerConnection >>= \case - Left xs -> writeTVar peerConnection $ Left $ (secure, packet, ackedBy) : xs - Right conn -> writeFlow (connData conn) (secure, packet, ackedBy) + readTVar peerState >>= \case + PeerInit xs -> writeTVar peerState $ PeerInit $ (secure, packet, ackedBy) : xs + PeerConnected conn -> writeFlow (connData conn) (secure, packet, ackedBy) + PeerDropped -> return () sendToPeerS :: Peer -> [TransportHeaderItem] -> TransportPacket Ref -> STM () sendToPeerS = sendToPeerS' EncryptedOnly diff --git a/src/Erebos/Network/Protocol.hs b/src/Erebos/Network/Protocol.hs index a669988..26bd615 100644 --- a/src/Erebos/Network/Protocol.hs +++ b/src/Erebos/Network/Protocol.hs @@ -20,6 +20,7 @@ module Erebos.Network.Protocol ( connData, connGetChannel, connSetChannel, + connClose, RawStreamReader, RawStreamWriter, connAddWriteStream, @@ -44,6 +45,7 @@ import Data.ByteString (ByteString) import Data.ByteString qualified as B import Data.ByteString.Char8 qualified as BC import Data.ByteString.Lazy qualified as BL +import Data.Function import Data.List import Data.Maybe import Data.Text (Text) @@ -184,6 +186,9 @@ data Connection addr = Connection , cOutStreams :: TVar [(Word8, Stream)] } +instance Eq (Connection addr) where + (==) = (==) `on` cChannel + connAddress :: Connection addr -> addr connAddress = cAddress @@ -197,6 +202,12 @@ connSetChannel :: Connection addr -> ChannelState -> STM () connSetChannel Connection {..} ch = do writeTVar cChannel ch +connClose :: Connection addr -> STM () +connClose conn@Connection {..} = do + let GlobalState {..} = cGlobalState + writeTVar cChannel ChannelClosed + writeTVar gConnections . filter (/=conn) =<< readTVar gConnections + connAddWriteStream :: Connection addr -> STM (Either String (TransportHeaderItem, RawStreamWriter, IO ())) connAddWriteStream conn@Connection {..} = do outStreams <- readTVar cOutStreams @@ -380,6 +391,7 @@ data ChannelState = ChannelNone | ChannelPeerRequest WaitingRef | ChannelOurAccept (Stored ChannelAccept) Channel | ChannelEstablished Channel + | ChannelClosed data ReservedToSend = ReservedToSend { rsAckedBy :: Maybe (TransportHeaderItem -> Bool) |