From 93e583408af5f41f9dde324f198e47fa94e1881e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roman=20Smr=C5=BE?= Date: Sun, 30 Aug 2020 17:31:48 +0200 Subject: Peer connection through ICE --- src/ICE.chs | 200 +++++++++++++++++++++++++++++ src/ICE/pjproject.c | 363 ++++++++++++++++++++++++++++++++++++++++++++++++++++ src/ICE/pjproject.h | 14 ++ src/Main.hs | 61 ++++++++- src/Network.hs | 106 +++++++++------ src/Storage.hs | 8 +- 6 files changed, 710 insertions(+), 42 deletions(-) create mode 100644 src/ICE.chs create mode 100644 src/ICE/pjproject.c create mode 100644 src/ICE/pjproject.h (limited to 'src') diff --git a/src/ICE.chs b/src/ICE.chs new file mode 100644 index 0000000..06ad7aa --- /dev/null +++ b/src/ICE.chs @@ -0,0 +1,200 @@ +{-# LANGUAGE ForeignFunctionInterface #-} +{-# LANGUAGE RecursiveDo #-} + +module ICE ( + IceSession, + IceSessionRole(..), + IceRemoteInfo, + + iceCreate, + iceDestroy, + iceRemoteInfo, + iceShow, + iceConnect, + iceSend, + + iceSetChan, +) where + +import Control.Arrow +import Control.Concurrent.Chan +import Control.Concurrent.MVar +import Control.Monad +import Control.Monad.Except +import Control.Monad.Identity + +import Data.ByteString (ByteString, packCStringLen, useAsCString) +import qualified Data.ByteString.Lazy.Char8 as BLC +import Data.ByteString.Unsafe +import Data.Function +import Data.Text (Text) +import qualified Data.Text as T +import qualified Data.Text.Encoding as T +import qualified Data.Text.Read as T + +import Foreign.C.String +import Foreign.C.Types +import Foreign.Marshal.Alloc +import Foreign.Marshal.Array +import Foreign.Ptr +import Foreign.StablePtr + +import Storage + +#include "pjproject.h" + +data IceSession = IceSession + { isStrans :: PjIceStrans + , isChan :: MVar (Either [ByteString] (MappedChan ByteString)) + } + +data MappedChan a = forall b. MappedChan (a -> b) (Chan b) + +instance Eq IceSession where + (==) = (==) `on` isStrans + +instance Ord IceSession where + compare = compare `on` isStrans + +instance Show IceSession where + show _ = "" + + +data IceRemoteInfo = IceRemoteInfo + { iriUsernameFrament :: Text + , iriPassword :: Text + , iriDefaultCandidate :: Text + , iriCandidates :: [Text] + } + +data IceCandidate = IceCandidate + { icandFoundation :: Text + , icandPriority :: Int + , icandAddr :: Text + , icandPort :: Int + , icandType :: Text + } + +instance Storable IceRemoteInfo where + store' x = storeRec $ do + storeText "ice-ufrag" $ iriUsernameFrament x + storeText "ice-pass" $ iriPassword x + storeText "ice-default" $ iriDefaultCandidate x + mapM_ (storeText "ice-candidate") $ iriCandidates x + + load' = loadRec $ IceRemoteInfo + <$> loadText "ice-ufrag" + <*> loadText "ice-pass" + <*> loadText "ice-default" + <*> loadTexts "ice-candidate" + +instance StorableText IceCandidate where + toText x = T.concat $ + [ icandFoundation x + , T.singleton ' ' + , T.pack $ show $ icandPriority x + , T.singleton ' ' + , icandAddr x + , T.singleton ' ' + , T.pack $ show $ icandPort x + , T.singleton ' ' + , icandType x + ] + + fromText t = case T.words t of + [found, tprio, addr, tport, ctype] + | Right (prio, _) <- T.decimal tprio + , Right (port, _) <- T.decimal tport + -> return $ IceCandidate + { icandFoundation = found + , icandPriority = prio + , icandAddr = addr + , icandPort = port + , icandType = ctype + } + _ -> throwError "failed to parse candidate" + + +{#enum pj_ice_sess_role as IceSessionRole {underscoreToCase} deriving (Show, Eq) #} + +{#pointer *pj_ice_strans as ^ #} + +iceCreate :: IceSessionRole -> (IceSession -> IO ()) -> IO IceSession +iceCreate role cb = do + rec sptr <- newStablePtr sess + cbptr <- newStablePtr $ cb sess + sess <- IceSession + <$> {#call ice_create #} (fromIntegral $ fromEnum role) (castStablePtrToPtr sptr) (castStablePtrToPtr cbptr) + <*> (newMVar $ Left []) + return $ sess + +{#fun ice_destroy as ^ { isStrans `IceSession' } -> `()' #} + +iceRemoteInfo :: IceSession -> IO IceRemoteInfo +iceRemoteInfo sess = + allocaBytes (32*128) $ \bytes -> + allocaArray 29 $ \carr -> do + let (ufrag : pass : def : cptrs) = take 32 $ iterate (`plusPtr` 128) bytes + pokeArray carr cptrs + + ncand <- {#call ice_encode_session #} (isStrans sess) ufrag pass def carr 128 29 + if ncand < 0 then fail "failed to generate ICE remote info" + else IceRemoteInfo + <$> (T.pack <$> peekCString ufrag) + <*> (T.pack <$> peekCString pass) + <*> (T.pack <$> peekCString def) + <*> (mapM (return . T.pack <=< peekCString) $ take (fromIntegral ncand) cptrs) + +iceShow :: IceSession -> IO String +iceShow sess = do + st <- memoryStorage + return . drop 1 . dropWhile (/='\n') . BLC.unpack . runIdentity =<< + ioLoadBytes =<< store st =<< iceRemoteInfo sess + +iceConnect :: IceSession -> IceRemoteInfo -> (IO ()) -> IO () +iceConnect sess remote cb = do + cbptr <- newStablePtr $ cb + ice_connect sess cbptr + (iriUsernameFrament remote) + (iriPassword remote) + (iriDefaultCandidate remote) + (iriCandidates remote) + +{#fun ice_connect { isStrans `IceSession', castStablePtrToPtr `StablePtr (IO ())', + withText* `Text', withText* `Text', withText* `Text', withTextArray* `[Text]'& } -> `()' #} + +withText :: Text -> (Ptr CChar -> IO a) -> IO a +withText t f = useAsCString (T.encodeUtf8 t) f + +withTextArray :: Num n => [Text] -> ((Ptr (Ptr CChar), n) -> IO ()) -> IO () +withTextArray tsAll f = helper tsAll [] + where helper (t:ts) bs = withText t $ \b -> helper ts (b:bs) + helper [] bs = allocaArray (length bs) $ \ptr -> do + pokeArray ptr $ reverse bs + f (ptr, fromIntegral $ length bs) + +withByteStringLen :: Num n => ByteString -> ((Ptr CChar, n) -> IO a) -> IO a +withByteStringLen t f = unsafeUseAsCStringLen t (f . (id *** fromIntegral)) + +{#fun ice_send as ^ { isStrans `IceSession', withByteStringLen* `ByteString'& } -> `()' #} + +foreign export ccall ice_call_cb :: StablePtr (IO ()) -> IO () +ice_call_cb :: StablePtr (IO ()) -> IO () +ice_call_cb = join . deRefStablePtr + +iceSetChan :: IceSession -> (ByteString -> a) -> Chan a -> IO () +iceSetChan sess f chan = do + modifyMVar_ (isChan sess) $ \orig -> do + case orig of + Left buf -> writeList2Chan chan $ map f $ reverse buf + Right _ -> return () + return $ Right $ MappedChan f chan + +foreign export ccall ice_rx_data :: StablePtr IceSession -> Ptr CChar -> Int -> IO () +ice_rx_data :: StablePtr IceSession -> Ptr CChar -> Int -> IO () +ice_rx_data sptr buf len = do + sess <- deRefStablePtr sptr + bs <- packCStringLen (buf, len) + modifyMVar_ (isChan sess) $ \case + mc@(Right (MappedChan f chan)) -> writeChan chan (f bs) >> return mc + Left bss -> return $ Left (bs:bss) diff --git a/src/ICE/pjproject.c b/src/ICE/pjproject.c new file mode 100644 index 0000000..0ae69e9 --- /dev/null +++ b/src/ICE/pjproject.c @@ -0,0 +1,363 @@ +#include "pjproject.h" +#include "ICE_stub.h" + +#include +#include +#include +#include +#include +#include + +static struct +{ + pj_caching_pool cp; + pj_pool_t * pool; + pj_ice_strans_cfg cfg; + pj_sockaddr def_addr; +} ice; + +struct user_data +{ + pj_ice_sess_role role; + HsStablePtr sptr; + HsStablePtr cb_init; + HsStablePtr cb_connect; +}; + +static void ice_perror(const char * msg, pj_status_t status) +{ + char err[PJ_ERR_MSG_SIZE]; + pj_strerror(status, err, sizeof(err)); + fprintf(stderr, "ICE: %s: %s\n", msg, err); +} + +static int ice_worker_thread(void * unused) +{ + PJ_UNUSED_ARG(unused); + + while (true) { + pj_time_val max_timeout = { 0, 0 }; + pj_time_val timeout = { 0, 0 }; + + max_timeout.msec = 500; + + pj_timer_heap_poll(ice.cfg.stun_cfg.timer_heap, &timeout); + + pj_assert(timeout.sec >= 0 && timeout.msec >= 0); + if (timeout.msec >= 1000) + timeout.msec = 999; + + if (PJ_TIME_VAL_GT(timeout, max_timeout)) + timeout = max_timeout; + + int c = pj_ioqueue_poll(ice.cfg.stun_cfg.ioqueue, &timeout); + if (c < 0) + pj_thread_sleep(PJ_TIME_VAL_MSEC(timeout)); + } + + return 0; +} + +static void cb_on_rx_data(pj_ice_strans * strans, unsigned comp_id, + void * pkt, pj_size_t size, + const pj_sockaddr_t * src_addr, unsigned src_addr_len) +{ + struct user_data * udata = pj_ice_strans_get_user_data(strans); + ice_rx_data(udata->sptr, pkt, size); +} + +static void cb_on_ice_complete(pj_ice_strans * strans, + pj_ice_strans_op op, pj_status_t status) +{ + if (status != PJ_SUCCESS) { + ice_perror("cb_on_ice_complete", status); + ice_destroy(strans); + return; + } + + struct user_data * udata = pj_ice_strans_get_user_data(strans); + if (op == PJ_ICE_STRANS_OP_INIT) { + pj_status_t istatus = pj_ice_strans_init_ice(strans, udata->role, NULL, NULL); + if (istatus != PJ_SUCCESS) + ice_perror("error creating session", istatus); + + if (udata->cb_init) { + ice_call_cb(udata->cb_init); + hs_free_stable_ptr(udata->cb_init); + udata->cb_init = NULL; + } + } + + if (op == PJ_ICE_STRANS_OP_NEGOTIATION) { + if (udata->cb_connect) { + ice_call_cb(udata->cb_connect); + hs_free_stable_ptr(udata->cb_connect); + udata->cb_connect = NULL; + } + } +} + +static void ice_init(void) +{ + static bool done = false; + static pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER; + pthread_mutex_lock(&mutex); + + if (done) { + pthread_mutex_unlock(&mutex); + goto exit; + } + + pj_log_set_level(1); + + if (pj_init() != PJ_SUCCESS) { + fprintf(stderr, "pj_init failed\n"); + goto exit; + } + if (pjlib_util_init() != PJ_SUCCESS) { + fprintf(stderr, "pjlib_util_init failed\n"); + goto exit; + } + if (pjnath_init() != PJ_SUCCESS) { + fprintf(stderr, "pjnath_init failed\n"); + goto exit; + } + + pj_caching_pool_init(&ice.cp, NULL, 0); + + pj_ice_strans_cfg_default(&ice.cfg); + ice.cfg.stun_cfg.pf = &ice.cp.factory; + + ice.pool = pj_pool_create(&ice.cp.factory, "ice", 512, 512, NULL); + + if (pj_timer_heap_create(ice.pool, 100, + &ice.cfg.stun_cfg.timer_heap) != PJ_SUCCESS) { + fprintf(stderr, "pj_timer_heap_create failed\n"); + goto exit; + } + + if (pj_ioqueue_create(ice.pool, 16, &ice.cfg.stun_cfg.ioqueue) != PJ_SUCCESS) { + fprintf(stderr, "pj_ioqueue_create failed\n"); + goto exit; + } + + pj_thread_t * thread; + if (pj_thread_create(ice.pool, "ice", &ice_worker_thread, + NULL, 0, 0, &thread) != PJ_SUCCESS) { + fprintf(stderr, "pj_thread_create failed\n"); + goto exit; + } + + ice.cfg.af = pj_AF_INET(); + ice.cfg.opt.aggressive = PJ_TRUE; + + ice.cfg.stun.server.ptr = "discovery1.erebosprotocol.net"; + ice.cfg.stun.server.slen = strlen(ice.cfg.stun.server.ptr); + ice.cfg.stun.port = 29670; + + ice.cfg.turn.server = ice.cfg.stun.server; + ice.cfg.turn.port = ice.cfg.stun.port; + ice.cfg.turn.auth_cred.type = PJ_STUN_AUTH_CRED_STATIC; + ice.cfg.turn.auth_cred.data.static_cred.data_type = PJ_STUN_PASSWD_PLAIN; + ice.cfg.turn.conn_type = PJ_TURN_TP_UDP; + +exit: + done = true; + pthread_mutex_unlock(&mutex); +} + +pj_ice_strans * ice_create(pj_ice_sess_role role, HsStablePtr sptr, HsStablePtr cb) +{ + ice_init(); + + pj_ice_strans * res; + + struct user_data * udata = malloc(sizeof(struct user_data)); + udata->role = role; + udata->sptr = sptr; + udata->cb_init = cb; + + pj_ice_strans_cb icecb = { + .on_rx_data = cb_on_rx_data, + .on_ice_complete = cb_on_ice_complete, + }; + + pj_status_t status = pj_ice_strans_create(NULL, &ice.cfg, 1, + udata, &icecb, &res); + + if (status != PJ_SUCCESS) + ice_perror("error creating ice", status); + + return res; +} + +void ice_destroy(pj_ice_strans * strans) +{ + struct user_data * udata = pj_ice_strans_get_user_data(strans); + if (udata->sptr) + hs_free_stable_ptr(udata->sptr); + if (udata->cb_init) + hs_free_stable_ptr(udata->cb_init); + if (udata->cb_connect) + hs_free_stable_ptr(udata->cb_connect); + free(udata); + + pj_ice_strans_stop_ice(strans); + pj_ice_strans_destroy(strans); +} + +ssize_t ice_encode_session(pj_ice_strans * strans, char * ufrag, char * pass, + char * def, char * candidates[], size_t maxlen, size_t maxcand) +{ + int n; + pj_str_t local_ufrag, local_pwd; + pj_status_t status; + + pj_ice_strans_get_ufrag_pwd(strans, &local_ufrag, &local_pwd, NULL, NULL); + + n = snprintf(ufrag, maxlen, "%.*s", (int) local_ufrag.slen, local_ufrag.ptr); + if (n < 0 || n == maxlen) + return -PJ_ETOOSMALL; + + n = snprintf(pass, maxlen, "%.*s", (int) local_pwd.slen, local_pwd.ptr); + if (n < 0 || n == maxlen) + return -PJ_ETOOSMALL; + + pj_ice_sess_cand cand[PJ_ICE_ST_MAX_CAND]; + char ipaddr[PJ_INET6_ADDRSTRLEN]; + + status = pj_ice_strans_get_def_cand(strans, 1, &cand[0]); + if (status != PJ_SUCCESS) + return -status; + + n = snprintf(def, maxlen, "%s %d", + pj_sockaddr_print(&cand[0].addr, ipaddr, sizeof(ipaddr), 0), + (int) pj_sockaddr_get_port(&cand[0].addr)); + if (n < 0 || n == maxlen) + return -PJ_ETOOSMALL; + + unsigned cand_cnt = PJ_ARRAY_SIZE(cand); + status = pj_ice_strans_enum_cands(strans, 1, &cand_cnt, cand); + if (status != PJ_SUCCESS) + return -status; + + for (unsigned i = 0; i < cand_cnt && i < maxcand; i++) { + char ipaddr[PJ_INET6_ADDRSTRLEN]; + n = snprintf(candidates[i], maxlen, + "%.*s %u %s %u %s", + (int) cand[i].foundation.slen, cand[i].foundation.ptr, + cand[i].prio, + pj_sockaddr_print(&cand[i].addr, ipaddr, sizeof(ipaddr), 0), + (unsigned) pj_sockaddr_get_port(&cand[i].addr), + pj_ice_get_cand_type_name(cand[i].type)); + + if (n < 0 || n == maxlen) + return -PJ_ETOOSMALL; + } + + return cand_cnt; +} + +void ice_connect(pj_ice_strans * strans, HsStablePtr cb, + const char * ufrag, const char * pass, + const char * defcand, const char * tcandidates[], size_t ncand) +{ + unsigned def_port = 0; + char def_addr[80]; + pj_bool_t done = PJ_FALSE; + char line[256]; + pj_ice_sess_cand candidates[PJ_ICE_ST_MAX_CAND]; + + struct user_data * udata = pj_ice_strans_get_user_data(strans); + udata->cb_connect = cb; + + def_addr[0] = '\0'; + + if (ncand == 0) { + fprintf(stderr, "ICE: no candidates\n"); + return; + } + + int cnt = sscanf(defcand, "%s %u", def_addr, &def_port); + if (cnt != 2) { + fprintf(stderr, "ICE: error parsing default candidate\n"); + return; + } + + int okcand = 0; + for (int i = 0; i < ncand; i++) { + char foundation[32], ipaddr[80], type[32]; + int prio, port; + + int cnt = sscanf(tcandidates[i], "%s %d %s %d %s", + foundation, &prio, + ipaddr, &port, + type); + if (cnt != 5) + continue; + + pj_ice_sess_cand * cand = &candidates[okcand]; + pj_bzero(cand, sizeof(*cand)); + + if (strcmp(type, "host") == 0) + cand->type = PJ_ICE_CAND_TYPE_HOST; + else if (strcmp(type, "srflx") == 0) + cand->type = PJ_ICE_CAND_TYPE_SRFLX; + else if (strcmp(type, "relay") == 0) + cand->type = PJ_ICE_CAND_TYPE_RELAYED; + else + continue; + + cand->comp_id = 1; + pj_strdup2(ice.pool, &cand->foundation, foundation); + cand->prio = prio; + + int af = strchr(ipaddr, ':') ? pj_AF_INET6() : pj_AF_INET(); + pj_str_t tmpaddr = pj_str(ipaddr); + pj_sockaddr_init(af, &cand->addr, NULL, 0); + pj_status_t status = pj_sockaddr_set_str_addr(af, &cand->addr, &tmpaddr); + if (status != PJ_SUCCESS) { + fprintf(stderr, "ICE: invalid IP address \"%s\"\n", ipaddr); + continue; + } + + pj_sockaddr_set_port(&cand->addr, (pj_uint16_t)port); + okcand++; + } + + pj_str_t tmp_addr; + pj_status_t status; + + int af = strchr(def_addr, ':') ? pj_AF_INET6() : pj_AF_INET(); + + pj_sockaddr_init(af, &ice.def_addr, NULL, 0); + tmp_addr = pj_str(def_addr); + status = pj_sockaddr_set_str_addr(af, &ice.def_addr, &tmp_addr); + if (status != PJ_SUCCESS) { + fprintf(stderr, "ICE: invalid default IP address \"%s\"\n", def_addr); + return; + } + pj_sockaddr_set_port(&ice.def_addr, (pj_uint16_t) def_port); + + pj_str_t rufrag, rpwd; + status = pj_ice_strans_start_ice(strans, + pj_cstr(&rufrag, ufrag), pj_cstr(&rpwd, pass), + okcand, candidates); + if (status != PJ_SUCCESS) { + ice_perror("error starting ICE", status); + return; + } +} + +void ice_send(pj_ice_strans * strans, const char * data, size_t len) +{ + if (!pj_ice_strans_sess_is_complete(strans)) { + fprintf(stderr, "ICE: negotiation has not been started or is in progress\n"); + return; + } + + pj_status_t status = pj_ice_strans_sendto(strans, 1, data, len, + &ice.def_addr, pj_sockaddr_get_len(&ice.def_addr)); + if (status != PJ_SUCCESS && status != PJ_EPENDING) + ice_perror("error sending data", status); +} diff --git a/src/ICE/pjproject.h b/src/ICE/pjproject.h new file mode 100644 index 0000000..e230e75 --- /dev/null +++ b/src/ICE/pjproject.h @@ -0,0 +1,14 @@ +#pragma once + +#include +#include + +pj_ice_strans * ice_create(pj_ice_sess_role role, HsStablePtr sptr, HsStablePtr cb); +void ice_destroy(pj_ice_strans * strans); + +ssize_t ice_encode_session(pj_ice_strans *, char * ufrag, char * pass, + char * def, char * candidates[], size_t maxlen, size_t maxcand); +void ice_connect(pj_ice_strans * strans, HsStablePtr cb, + const char * ufrag, const char * pass, + const char * defcand, const char * candidates[], size_t ncand); +void ice_send(pj_ice_strans *, const char * data, size_t len); diff --git a/src/Main.hs b/src/Main.hs index 9404517..0e8970f 100644 --- a/src/Main.hs +++ b/src/Main.hs @@ -25,6 +25,7 @@ import System.Environment import Attach import Contact +import ICE import Identity import Message import Network @@ -95,7 +96,8 @@ interactiveLoop st bhost = runInputT defaultSettings $ do haveTerminalUI >>= \case True -> return () False -> error "Requires terminal" extPrint <- getExternalPrint - let extPrintLn str = extPrint $ str ++ "\n"; + let extPrintLn str = extPrint $ case reverse str of ('\n':_) -> str + _ -> str ++ "\n"; server <- liftIO $ do startServer erebosHead extPrintLn bhost [ SomeService @AttachService Proxy @@ -158,7 +160,10 @@ interactiveLoop st bhost = runInputT defaultSettings $ do let loop (Just cstate) = runMaybeT (process cstate) >>= loop loop Nothing = return () - loop $ Just $ CommandState { csPeer = Nothing } + loop $ Just $ CommandState + { csPeer = Nothing + , csIceSessions = [] + } data CommandInput = CommandInput @@ -171,6 +176,7 @@ data CommandInput = CommandInput data CommandState = CommandState { csPeer :: Maybe Peer + , csIceSessions :: [IceSession] } newtype CommandM a = CommandM (ReaderT CommandInput (StateT CommandState (ExceptT String IO)) a) @@ -195,6 +201,11 @@ commands = , ("contacts", cmdContacts) , ("contact-add", cmdContactAdd) , ("contact-accept", cmdContactAccept) + , ("ice-create", cmdIceCreate) + , ("ice-destroy", cmdIceDestroy) + , ("ice-show", cmdIceShow) + , ("ice-connect", cmdIceConnect) + , ("ice-send", cmdIceSend) ] cmdUnknown :: String -> Command @@ -212,8 +223,7 @@ showPeer peer = PeerIdentityUnknown -> "" PeerIdentityRef wref -> "<" ++ BC.unpack (showRefDigest $ wrDigest wref) ++ ">" PeerIdentityFull pid -> T.unpack $ displayIdentity pid - DatagramAddress addr = peerAddress peer - in name ++ " [" ++ show addr ++ "]" + in name ++ " [" ++ show (peerAddress peer) ++ "]" cmdSetPeer :: Int -> Command cmdSetPeer n | n < 1 = liftIO $ putStrLn "Invalid peer index" @@ -276,3 +286,46 @@ cmdContactAccept = join $ contactAccept <$> asks ciPrint <*> asks ciHead <*> (maybe (throwError "no peer selected") return =<< gets csPeer) + +cmdIceCreate :: Command +cmdIceCreate = do + role <- asks ciLine >>= return . \case + 'm':_ -> PjIceSessRoleControlling + 's':_ -> PjIceSessRoleControlled + _ -> PjIceSessRoleUnknown + eprint <- asks ciPrint + sess <- liftIO $ iceCreate role $ eprint <=< iceShow + modify $ \s -> s { csIceSessions = sess : csIceSessions s } + +cmdIceDestroy :: Command +cmdIceDestroy = do + s:ss <- gets csIceSessions + modify $ \st -> st { csIceSessions = ss } + liftIO $ iceDestroy s + +cmdIceShow :: Command +cmdIceShow = do + sess <- gets csIceSessions + eprint <- asks ciPrint + liftIO $ forM_ (zip [1::Int ..] sess) $ \(i, s) -> do + eprint $ "[" ++ show i ++ "]" + eprint =<< iceShow s + +cmdIceConnect :: Command +cmdIceConnect = do + s:_ <- gets csIceSessions + server <- asks ciServer + let loadInfo = BC.getLine >>= \case line | BC.null line -> return [] + | otherwise -> (line:) <$> loadInfo + Right remote <- liftIO $ do + st <- memoryStorage + pst <- derivePartialStorage st + rbytes <- (BL.fromStrict . BC.unlines) <$> loadInfo + copyRef st =<< storeRawBytes pst (BL.fromChunks [ BC.pack "rec ", BC.pack (show (BL.length rbytes)), BC.singleton '\n' ] `BL.append` rbytes) + liftIO $ iceConnect s (load remote) $ void $ serverPeerIce server s + +cmdIceSend :: Command +cmdIceSend = void $ do + s:_ <- gets csIceSessions + server <- asks ciServer + liftIO $ serverPeerIce server s diff --git a/src/Network.hs b/src/Network.hs index 5685627..cbc68b6 100644 --- a/src/Network.hs +++ b/src/Network.hs @@ -9,7 +9,7 @@ module Network ( PeerChannel(..), WaitingRef, wrDigest, Service(..), - serverPeer, + serverPeer, serverPeerIce, sendToPeer, sendToPeerStored, sendToPeerWith, discoveryPort, @@ -31,9 +31,10 @@ import Data.Maybe import Data.Typeable import Network.Socket -import Network.Socket.ByteString (recvFrom, sendTo) +import qualified Network.Socket.ByteString as S import Channel +import ICE import Identity import PubKey import Service @@ -53,7 +54,8 @@ data Server = Server { serverStorage :: Storage , serverIdentity :: MVar UnifiedIdentity , serverSocket :: MVar Socket - , serverPeers :: MVar (Map SockAddr Peer) + , serverChanPacket :: Chan (PeerAddress, BC.ByteString) + , serverPeers :: MVar (Map PeerAddress Peer) , serverChanPeer' :: Chan Peer } @@ -66,7 +68,6 @@ data Peer = Peer , peerIdentity :: PeerIdentity , peerIdentityUpdate :: [WaitingRef] , peerChannel :: PeerChannel - , peerSocket :: Socket , peerStorage :: Storage , peerInStorage :: PartialStorage , peerServiceState :: MVar (M.Map ServiceID SomeServiceState) @@ -75,8 +76,24 @@ data Peer = Peer , peerWaitingRefs :: [WaitingRef] } -data PeerAddress = DatagramAddress SockAddr - deriving (Show) +data PeerAddress = DatagramAddress Socket SockAddr + | PeerIceSession IceSession + +instance Show PeerAddress where + show (DatagramAddress _ addr) = show addr + show (PeerIceSession ice) = show ice + +instance Eq PeerAddress where + DatagramAddress _ addr == DatagramAddress _ addr' = addr == addr' + PeerIceSession ice == PeerIceSession ice' = ice == ice' + _ == _ = False + +instance Ord PeerAddress where + compare (DatagramAddress _ addr) (DatagramAddress _ addr') = compare addr addr' + compare (DatagramAddress _ _ ) _ = LT + compare _ (DatagramAddress _ _ ) = GT + compare (PeerIceSession ice ) (PeerIceSession ice') = compare ice ice' + data PeerIdentity = PeerIdentityUnknown | PeerIdentityRef WaitingRef @@ -184,6 +201,7 @@ receivedWaitingRef nref wr@(WaitingRef _ _ mvar) = do startServer :: Head LocalState -> (String -> IO ()) -> String -> [SomeService] -> IO Server startServer origHead logd bhost services = do let storage = refStorage $ headRef origHead + chanPacket <- newChan chanPeer <- newChan chanSvc <- newChan svcStates <- newMVar M.empty @@ -206,7 +224,7 @@ startServer origHead logd bhost services = do readMVar midentity >>= \identity -> do st <- derivePartialStorage storage baddr:_ <- getAddrInfo (Just $ defaultHints { addrSocketType = Datagram }) (Just bhost) (Just discoveryPort) - void $ sendTo sock (BL.toStrict $ serializeObject $ transportToObject $ TransportHeader [ AnnounceSelf $ partialRef st $ storedRef $ idData identity ]) (addrAddress baddr) + void $ S.sendTo sock (BL.toStrict $ serializeObject $ transportToObject $ TransportHeader [ AnnounceSelf $ partialRef st $ storedRef $ idData identity ]) (addrAddress baddr) threadDelay $ announceIntervalSeconds * 1000 * 1000 let announceUpdate identity = do @@ -220,9 +238,8 @@ startServer origHead logd bhost services = do peer | PeerIdentityFull _ <- peerIdentity peer , ChannelEstablished ch <- peerChannel peer - , DatagramAddress paddr <- peerAddress peer -> runExceptT (channelEncrypt ch plaintext) >>= \case - Right ctext -> void $ sendTo (peerSocket peer) ctext paddr + Right ctext -> void $ sendTo peer ctext Left err -> logd $ "Failed to encrypt data: " ++ err | otherwise -> return () @@ -246,8 +263,12 @@ startServer origHead logd bhost services = do when changedShared $ do mapM_ (shareState idt shared) =<< readMVar peers + void $ forkIO $ forever $ do + (msg, saddr) <- S.recvFrom sock 4096 + writeChan chanPacket (DatagramAddress sock saddr, msg) + forever $ do - (msg, paddr) <- recvFrom sock 4096 + (paddr, msg) <- readChan chanPacket modifyMVar_ peers $ \pvalue -> do let mbpeer = M.lookup paddr pvalue (peer, content, secure) <- if @@ -263,7 +284,7 @@ startServer origHead logd bhost services = do -> return (peer, msg, False) | otherwise - -> (, msg, False) <$> mkPeer storage sock paddr + -> (, msg, False) <$> mkPeer storage paddr case runExcept $ deserializeObjects (peerInStorage peer) $ BL.fromStrict content of Right (obj:objs) @@ -323,13 +344,14 @@ startServer origHead logd bhost services = do logd $ "unhandled service '" ++ show (toUUID svc) ++ "'" return (svcs, global) - | DatagramAddress paddr <- peerAddress peer -> do - logd $ "service packet from peer with incomplete identity " ++ show paddr + | otherwise -> do + logd $ "service packet from peer with incomplete identity " ++ show (peerAddress peer) return Server { serverStorage = storage , serverIdentity = midentity , serverSocket = ssocket + , serverChanPacket = chanPacket , serverPeers = peers , serverChanPeer' = chanPeer } @@ -362,7 +384,6 @@ handlePacket :: (String -> IO ()) -> Head LocalState -> UnifiedIdentity -> Bool -> TransportHeader -> IO (Maybe Peer) handlePacket logd origHead identity secure opeer chanSvc svcs (TransportHeader headers) = do let sidentity = idData identity - DatagramAddress paddr = peerAddress opeer plaintextRefs = map (refDigest . storedRef) $ concatMap (collectStoredObjects . wrappedLoad) $ concat [ [ storedRef sidentity ] , map storedRef $ idUpdates identity @@ -477,7 +498,7 @@ handlePacket logd origHead identity secure opeer chanSvc svcs (TransportHeader h case res of Left err -> do - logd $ "Error in handling packet from " ++ show paddr ++ ": " ++ err + logd $ "Error in handling packet from " ++ show (peerAddress opeer) ++ ": " ++ err return Nothing Right ph -> do when (not $ null $ phHead ph) $ do @@ -488,9 +509,9 @@ handlePacket logd origHead identity secure opeer chanSvc svcs (TransportHeader h case peerChannel $ phPeer ph of ChannelEstablished ch -> do x <- runExceptT (channelEncrypt ch plain) - case x of Right ctext -> void $ sendTo (peerSocket $ phPeer ph) ctext paddr + case x of Right ctext -> void $ sendTo (phPeer ph) ctext Left err -> logd $ "Failed to encrypt data: " ++ err - _ -> void $ sendTo (peerSocket $ phPeer ph) plain paddr + _ -> void $ sendTo (phPeer ph) plain return $ if phPeerChanged ph then Just $ phPeer ph else Nothing @@ -599,13 +620,12 @@ finalizedChannel oh self = do -- Outstanding service packets gets phPeer >>= \case - Peer { peerChannel = ChannelEstablished ch - , peerAddress = DatagramAddress paddr - , peerServiceOutQueue = oqueue - , peerSocket = sock - } -> do - ps <- liftIO $ modifyMVar oqueue $ return . ([],) - forM_ ps $ sendPacket sock paddr ch + peer@Peer + { peerChannel = ChannelEstablished ch + , peerServiceOutQueue = oqueue + } -> do + ps <- liftIO $ modifyMVar oqueue $ return . ([],) + forM_ ps $ sendPacket peer ch _ -> return () @@ -645,18 +665,17 @@ handleServices chan = gets (peerServiceInQueue . phPeer) >>= \case updatePeer $ \p -> p { peerServiceInQueue = queue' } -mkPeer :: Storage -> Socket -> SockAddr -> IO Peer -mkPeer st sock paddr = do +mkPeer :: Storage -> PeerAddress -> IO Peer +mkPeer st paddr = do pst <- deriveEphemeralStorage st ist <- derivePartialStorage pst svcs <- newMVar M.empty oqueue <- newMVar [] return $ Peer - { peerAddress = DatagramAddress paddr + { peerAddress = paddr , peerIdentity = PeerIdentityUnknown , peerIdentityUpdate = [] , peerChannel = ChannelWait - , peerSocket = sock , peerStorage = pst , peerInStorage = ist , peerServiceState = svcs @@ -668,21 +687,35 @@ mkPeer st sock paddr = do serverPeer :: Server -> SockAddr -> IO Peer serverPeer server paddr = do sock <- readMVar $ serverSocket server + serverPeer' server (DatagramAddress sock paddr) + +serverPeerIce :: Server -> IceSession -> IO Peer +serverPeerIce server ice = do + let paddr = PeerIceSession ice + peer <- serverPeer' server paddr + iceSetChan ice (paddr,) $ serverChanPacket server + return peer + +serverPeer' :: Server -> PeerAddress -> IO Peer +serverPeer' server paddr = do (peer, hello) <- modifyMVar (serverPeers server) $ \pvalue -> do case M.lookup paddr pvalue of Just peer -> return (pvalue, (peer, False)) Nothing -> do - peer <- mkPeer (serverStorage server) sock paddr + peer <- mkPeer (serverStorage server) paddr return (M.insert paddr peer pvalue, (peer, True)) when hello $ do identity <- readMVar (serverIdentity server) - void $ sendTo sock - (BL.toStrict $ serializeObject $ transportToObject $ TransportHeader + void $ sendTo peer $ + BL.toStrict $ serializeObject $ transportToObject $ TransportHeader [ AnnounceSelf $ partialRef (peerInStorage peer) $ storedRef $ idData identity ] - ) paddr return peer +sendTo :: Peer -> BC.ByteString -> IO () +sendTo Peer { peerAddress = DatagramAddress sock addr } msg = void $ S.sendTo sock msg addr +sendTo Peer { peerAddress = PeerIceSession ice } msg = iceSend ice msg + sendToPeer :: (Service s, MonadIO m, MonadError String m) => UnifiedIdentity -> Peer -> s -> m () sendToPeer self peer packet = sendToPeerList self peer [ServiceReply (Left packet) True] @@ -701,17 +734,16 @@ sendToPeerList _ peer parts = do packet = TransportPacket header content case peerChannel peer of ChannelEstablished ch -> do - let DatagramAddress paddr = peerAddress peer - sendPacket (peerSocket peer) paddr ch packet + sendPacket peer ch packet _ -> liftIO $ modifyMVar_ (peerServiceOutQueue peer) $ return . (packet:) -sendPacket :: (MonadIO m, MonadError String m) => Socket -> SockAddr -> Channel -> TransportPacket -> m () -sendPacket sock addr ch (TransportPacket header content) = do +sendPacket :: (MonadIO m, MonadError String m) => Peer -> Channel -> TransportPacket -> m () +sendPacket peer ch (TransportPacket header content) = do let plain = BL.toStrict $ BL.concat $ (serializeObject $ transportToObject header) : map lazyLoadBytes content ctext <- channelEncrypt ch plain - void $ liftIO $ sendTo sock ctext addr + void $ liftIO $ sendTo peer ctext sendToPeerWith :: forall s m. (Service s, MonadIO m, MonadError String m) => UnifiedIdentity -> Peer -> (ServiceState s -> ExceptT String IO (Maybe s, ServiceState s)) -> m () sendToPeerWith identity peer fobj = do diff --git a/src/Storage.hs b/src/Storage.hs index 92a1e1f..f73c420 100644 --- a/src/Storage.hs +++ b/src/Storage.hs @@ -35,7 +35,7 @@ module Storage ( loadBlob, loadRec, loadZero, loadInt, loadNum, loadText, loadBinary, loadDate, loadUUID, loadJson, loadRef, loadRawRef, loadMbInt, loadMbNum, loadMbText, loadMbBinary, loadMbDate, loadMbUUID, loadMbJson, loadMbRef, loadMbRawRef, - loadBinaries, loadRefs, loadRawRefs, + loadTexts, loadBinaries, loadRefs, loadRawRefs, loadZRef, Stored, @@ -720,6 +720,12 @@ loadMbText name = asks (lookup (BC.pack name) . snd) >>= \case Just (RecText x) -> Just <$> fromText x Just _ -> throwError $ "Expecting type text of record item '"++name++"'" +loadTexts :: StorableText a => String -> LoadRec [a] +loadTexts name = do + items <- map snd . filter ((BC.pack name ==) . fst) <$> asks snd + forM items $ \case RecText x -> fromText x + _ -> throwError $ "Expecting type text of record item '"++name++"'" + loadBinary :: BA.ByteArray a => String -> LoadRec a loadBinary name = maybe (throwError $ "Missing record item '"++name++"'") return =<< loadMbBinary name -- cgit v1.2.3