summaryrefslogtreecommitdiff
path: root/main/WebSocket.hs
blob: 79cb1415ad286fd79138f9b7f215d235e94a6e3d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
{-# LANGUAGE OverloadedStrings #-}

module WebSocket (
    WebSocketAddress(..),
    WebSocketOptions(..), defaultWebSocketOptions,
    startWebsocketServer,
) where

import Control.Concurrent
import Control.Exception
import Control.Monad

import Data.ByteString.Char8 qualified as BC
import Data.ByteString.Lazy qualified as BL
import Data.Unique

import Erebos.Network

import Network.WebSockets qualified as WS


-- | Address of a peer connected through the WebSocket server.
--
-- Fields, in order:
--
--   * a 'Unique' tag; the 'Eq' and 'Ord' instances compare only this field,
--   * an optional textual client address (filled from the @X-Real-IP@
--     request header by 'startWebsocketServer'), used by 'Show',
--   * the connection on which outgoing data is sent.
data WebSocketAddress = WebSocketAddress Unique (Maybe String) WS.Connection

-- | Two addresses are equal exactly when their 'Unique' tags are equal;
-- the client address and the connection handle play no part.
instance Eq WebSocketAddress where
    a == b = tag a == tag b
      where
        tag (WebSocketAddress u _ _) = u

-- | Orders addresses by their 'Unique' tag alone, consistent with 'Eq'.
instance Ord WebSocketAddress where
    compare a b = compare (tagOf a) (tagOf b)
      where
        tagOf (WebSocketAddress u _ _) = u

-- | Renders as @websocket@, followed by a space and the client address
-- when one is known.
instance Show WebSocketAddress where
    show (WebSocketAddress _ mbaddr _) = maybe "websocket" ("websocket " <>) mbaddr

-- | Sending and closing for WebSocket peers, both performed on the
-- connection stored in the address.
instance PeerAddressType WebSocketAddress where
    -- Each outgoing payload goes out as a single binary data message.
    sendBytesToAddress (WebSocketAddress _ _ conn) msg =
        WS.sendDataMessage conn $ WS.Binary $ BL.fromStrict msg

    -- Send a close frame. The connection may already be gone, in which case
    -- 'WS.ConnectionClosed' is ignored. Only 'WS.ConnectionException' is
    -- intercepted; every other exception propagates without being caught.
    connectionToAddressClosed (WebSocketAddress _ _ conn) =
        WS.sendClose conn BL.empty `catch` \case
            WS.ConnectionClosed -> return ()
            e -> throwIO e


-- | Settings for 'startWebsocketServer'.
data WebSocketOptions = WebSocketOptions
    { wsAddress :: String -- ^ host address passed to 'WS.runServer' to listen on
    , wsPort :: Int -- ^ port passed to 'WS.runServer' to listen on
    , wsDebugLog :: Bool -- ^ when 'True', every incoming request is logged
    }

-- | Default settings: listen on address @::@ and port 80, with request
-- debug logging turned off.
defaultWebSocketOptions :: WebSocketOptions
defaultWebSocketOptions =
    WebSocketOptions
        { wsDebugLog = False
        , wsPort = 80
        , wsAddress = "::"
        }


-- | Start a WebSocket server on 'wsAddress' / 'wsPort' in a background
-- thread and return immediately.
--
-- Each accepted connection is registered with @server@ as a custom peer
-- ('WebSocketAddress'). Binary messages received from it are handed to
-- 'receivedFromCustomAddress'; text messages are logged and otherwise
-- ignored. Pings are sent every 30 seconds via 'WS.withPingThread'.
--
-- The peer is dropped from @server@ whenever its receive loop ends, for any
-- reason. Ending through a client close request or an already closed
-- connection is silent; every other exception is reported through @logd@.
startWebsocketServer :: Server -> (String -> IO ()) -> WebSocketOptions -> IO ()
startWebsocketServer server logd WebSocketOptions {..} = do
    void $ forkIO $ WS.runServer wsAddress wsPort $ \pending -> do
        let request = WS.pendingRequest pending
        when wsDebugLog $ logd $ "WebSocket request: " <> show request
        conn <- WS.acceptRequest pending
        u <- newUnique
        -- Client address reported in the X-Real-IP header (e.g. by a reverse
        -- proxy), if present; it only affects how the address is shown.
        let mbaddr = BC.unpack <$> lookup "X-Real-IP" (WS.requestHeaders request)
        let paddr = WebSocketAddress u mbaddr conn

        -- Expected ways for the loop to end are not logged.
        let handler e
                | Just WS.CloseRequest {} <- fromException e = return ()
                | Just WS.ConnectionClosed <- fromException e = return ()
                | otherwise = logd $ "WebSocket thread exception: " ++ show e

            receiveLoop = WS.withPingThread conn 30 (return ()) $ forever $ do
                WS.receiveDataMessage conn >>= \case
                    WS.Binary msg -> receivedFromCustomAddress server paddr $ BL.toStrict msg
                    WS.Text {} -> logd $ "unexpected websocket text message"

        -- Register the peer and guarantee it is dropped again however the
        -- receive loop terminates, including unexpected or asynchronous
        -- exceptions, so a dead connection never stays registered.
        bracket_
            (void $ serverPeerCustom server paddr)
            (dropPeerAddress server $ CustomPeerAddress paddr)
            (handle handler receiveLoop)