From 0905fe68591a3dad83f87d5ac805b674c0b88c76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roman=20Smr=C5=BE?= Date: Fri, 1 May 2026 20:25:20 +0200 Subject: Arbitrary type expression as function arguments --- src/Parser/Core.hs | 11 +++++++++++ src/Parser/Expr.hs | 29 ++++++++++++++++++++++------- src/Script/Expr.hs | 26 +++++++++++++++++--------- test/asset/parser/function.et | 6 ++++++ 4 files changed, 56 insertions(+), 16 deletions(-) diff --git a/src/Parser/Core.hs b/src/Parser/Core.hs index e1a4035..c12afdd 100644 --- a/src/Parser/Core.hs +++ b/src/Parser/Core.hs @@ -273,6 +273,17 @@ unifySomeExpr off stype sexpr@(SomeExpr expr) _ <- unify off (ExprTypeVar tvar) (someExprType sexpr) return sexpr + | ExprTypeFunction args res <- stype + = case someExprType sexpr of + ExprTypeFunction args' res' -> do + _ <- unify off args args' + _ <- unify off res res' + return sexpr + _ -> do + _ <- unify off args (ExprTypeArguments mempty) + SomeExpr expr' <- unifySomeExpr off res sexpr + return $ SomeExpr $ FunctionAbstraction expr' + | otherwise = do parseError $ FancyError off $ S.singleton $ ErrorFail $ T.unpack $ diff --git a/src/Parser/Expr.hs b/src/Parser/Expr.hs index c12d004..16c2b45 100644 --- a/src/Parser/Expr.hs +++ b/src/Parser/Expr.hs @@ -162,7 +162,7 @@ quotedString = label "string" $ lexeme $ do regex :: TestParser (Expr Regex) regex = label "regular expression" $ lexeme $ do off <- stateOffset <$> getParserState - void $ char '/' + void $ try $ char '/' <* notFollowedBy (char '=') -- TODO: better parsing rules for regexes let inner = choice [ char '/' >> return [] , takeWhile1P Nothing (`notElem` ['/', '\\', '$']) >>= \s -> (Pure (RegexPart (TL.toStrict s)) :) <$> inner @@ -418,12 +418,27 @@ constructor = label "constructor" $ do functionCall :: TestParser SomeExpr functionCall = do sline <- getSourceLine - (variable <|> constructor) >>= \case - SomeExpr e'@(FunVariable argTypes _ _) -> do - let check = checkFunctionArguments argTypes - args <- functionArguments check (someExpr FunctionTerm) literal (\poff -> lookupVarExpr poff sline . VarName) - return $ SomeExpr $ ArgsApp args e' - e -> return e + off <- stateOffset <$> getParserState + + fun <- variable <|> constructor + FunctionArguments margs <- functionArguments (\poff _ e -> return ( poff, e )) (someExpr FunctionTerm) literal (\poff -> lookupVarExpr poff sline . VarName) + if M.null margs + then return fun + else do + dict <- newTypeVar + res <- newTypeVar + SomeExpr (expr :: Expr fa) <- unifySomeExpr off (ExprTypeFunction (ExprTypeVar dict) (ExprTypeVar res)) fun + Just (ExprTypeArguments argTypes) <- M.lookup dict <$> gets testTypeUnif + args <- fmap (FunctionArguments . M.fromAscList) $ mapM (\( kw, ( poff, e ) ) -> ( kw, ) <$> checkFunctionArguments argTypes poff kw e) $ M.toAscList margs + M.lookup res <$> gets testTypeUnif >>= \case + Just (ExprTypePrim (_ :: Proxy a)) + | Just (Refl :: FunctionType a :~: fa) <- eqT + -> return $ SomeExpr $ ArgsApp args expr + | otherwise -> error $ "type mismatch after function unification: " <> show ( typeRep (Proxy @(FunctionType a)), typeRep (Proxy @fa) ) + _ + | Just (Refl :: FunctionType DynamicType :~: fa) <- eqT + -> return $ SomeExpr $ ArgsApp args expr + | otherwise -> error $ "type mismatch after function unification: " <> show ( typeRep (Proxy @(FunctionType DynamicType)), typeRep (Proxy @fa) ) recordSelector :: SomeExpr -> TestParser SomeExpr recordSelector (SomeExpr expr) = do diff --git a/src/Script/Expr.hs b/src/Script/Expr.hs index a975ef5..aae898a 100644 --- a/src/Script/Expr.hs +++ b/src/Script/Expr.hs @@ -57,7 +57,7 @@ data Expr a where Let :: forall a b. ExprType b => SourceLine -> TypedVarName b -> Expr b -> Expr a -> Expr a Variable :: ExprType a => SourceLine -> FqVarName -> Expr a DynVariable :: SomeExprType -> SourceLine -> FqVarName -> Expr DynamicType - FunVariable :: ExprType a => FunctionArguments SomeArgumentType -> SourceLine -> FqVarName -> Expr (FunctionType a) + FunVariable :: ExprType a => SomeExprType -> SourceLine -> FqVarName -> Expr (FunctionType a) OptVariable :: ExprType a => SourceLine -> FqVarName -> Expr (Maybe a) ArgsReq :: ExprType a => FunctionArguments ( VarName, SomeArgumentType ) -> Expr (FunctionType a) -> Expr (FunctionType a) ArgsApp :: ExprType a => FunctionArguments SomeExpr -> Expr (FunctionType a) -> Expr (FunctionType a) @@ -286,7 +286,8 @@ data SomeExprType = forall a. ExprType a => ExprTypePrim (Proxy a) | forall a. ExprTypeConstr1 a => ExprTypeConstr1 (Proxy a) | ExprTypeVar TypeVar - | ExprTypeFunction (FunctionArguments SomeArgumentType) SomeExprType + | ExprTypeFunction SomeExprType SomeExprType + | ExprTypeArguments (FunctionArguments SomeArgumentType) | ExprTypeApp SomeExprType [ SomeExprType ] | ExprTypeForall TypeVar SomeExprType @@ -296,13 +297,14 @@ someExprType (SomeExpr expr) = go expr go :: forall e. ExprType e => Expr e -> SomeExprType go = \case DynVariable stype _ _ -> stype + e@(FunVariable args _ _) -> ExprTypeFunction args (ExprTypePrim (proxyOfFunctionType e)) HideType stype _ -> stype TypeLambda tvar stype _ -> ExprTypeForall tvar stype ArgsReq args inner -> exprTypeFunction (fmap snd args) (go inner) ArgsApp (FunctionArguments used) inner - | ExprTypeFunction (FunctionArguments args) x <- go inner - -> ExprTypeFunction (FunctionArguments (args `M.difference` used)) x + | ExprTypeFunction (ExprTypeArguments (FunctionArguments args)) x <- go inner + -> ExprTypeFunction (ExprTypeArguments (FunctionArguments (args `M.difference` used))) x FunctionAbstraction inner -> exprTypeFunction mempty (go inner) FunctionEval _ inner | ExprTypeFunction _ x <- go inner -> x @@ -310,8 +312,11 @@ someExprType (SomeExpr expr) = go expr (_ :: Expr a) -> ExprTypePrim (Proxy @a) exprTypeFunction :: FunctionArguments SomeArgumentType -> SomeExprType -> SomeExprType - exprTypeFunction args (ExprTypeFunction args' inner) = ExprTypeFunction (args <> args') inner - exprTypeFunction args inner = ExprTypeFunction args inner + exprTypeFunction args (ExprTypeFunction (ExprTypeArguments args') inner) = ExprTypeFunction (ExprTypeArguments (args <> args')) inner + exprTypeFunction args inner = ExprTypeFunction (ExprTypeArguments args) inner + + proxyOfFunctionType :: Expr (FunctionType a) -> Proxy a + proxyOfFunctionType _ = Proxy renameTypeVar :: TypeVar -> TypeVar -> Expr a -> Expr a @@ -353,7 +358,8 @@ renameVarInType a b = go ExprTypeConstr1 {} -> orig ExprTypeVar tvar | tvar == a -> ExprTypeVar b | otherwise -> orig - ExprTypeFunction {} -> orig + ExprTypeFunction args result -> ExprTypeFunction (go args) (go result) + ExprTypeArguments args -> ExprTypeArguments (fmap (\(SomeArgumentType atype stype) -> SomeArgumentType atype (go stype)) args) ExprTypeApp c xs -> ExprTypeApp (go c) (map go xs) ExprTypeForall tvar stype | tvar == a -> orig @@ -369,6 +375,7 @@ textSomeExprType = go [] go [] (ExprTypeConstr1 _) = "" go _ (ExprTypeVar (TypeVar name)) = name go _ (ExprTypeFunction _ r) = "function:" <> textSomeExprType r + go _ (ExprTypeArguments _) = "{…}" go _ (ExprTypeApp c xs) = go (map textSomeExprType xs) c go _ (ExprTypeForall (TypeVar name) ctype) = "∀" <> name <> "." <> go [] ctype @@ -429,7 +436,7 @@ textSomeVarValue (SomeVarValue (VarValue _ args value)) someVarValueType :: SomeVarValue -> SomeExprType someVarValueType (SomeVarValue (VarValue _ args _ :: VarValue a)) | anull args = ExprTypePrim (Proxy @a) - | otherwise = ExprTypeFunction args (ExprTypePrim (Proxy @a)) + | otherwise = ExprTypeFunction (ExprTypeArguments args) (ExprTypePrim (Proxy @a)) newtype ArgumentKeyword = ArgumentKeyword Text @@ -445,7 +452,8 @@ exprArgs :: Expr (FunctionType a) -> FunctionArguments SomeArgumentType exprArgs = \case Let _ _ _ expr -> exprArgs expr Variable {} -> mempty - FunVariable args _ _ -> args + FunVariable (ExprTypeArguments args) _ _ -> args + FunVariable _ _ _ -> error "exprArgs: type-var args" ArgsReq args expr -> fmap snd args <> exprArgs expr ArgsApp (FunctionArguments applied) expr -> let FunctionArguments args = exprArgs expr diff --git a/test/asset/parser/function.et b/test/asset/parser/function.et index 3eca414..2a096b9 100644 --- a/test/asset/parser/function.et +++ b/test/asset/parser/function.et @@ -4,6 +4,12 @@ def g (x) and y = (x + (y+1)) test Test: guard (1 == 1) + guard (1 /= 2) + let x = 2 + guard (x == x) + guard (x /= 1) + guard (x /= x + 1) + guard (f 1 and 2 == 4) guard (f 1 and 2 == g 1 and 2) guard (f 1 and (g 2 and 3) == g 1 and 2 + 4) -- cgit v1.2.3