summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRoman Smrž <roman.smrz@seznam.cz>2024-12-05 22:14:21 +0100
committerRoman Smrž <roman.smrz@seznam.cz>2024-12-07 09:16:25 +0100
commit51d78df83fc69df8e54cb72212a91576da8bf5b0 (patch)
tree3e26ce743ad7ea72d7ca152f63fb597adb7fcd92
parent1b26af0b8da3bf9527d92978b3f23c851c749510 (diff)
Arguments for user-defined functions
-rw-r--r--src/Parser.hs36
-rw-r--r--src/Parser/Core.hs3
-rw-r--r--src/Test.hs118
3 files changed, 137 insertions, 20 deletions
diff --git a/src/Parser.hs b/src/Parser.hs
index ab44833..940bd60 100644
--- a/src/Parser.hs
+++ b/src/Parser.hs
@@ -9,6 +9,7 @@ import Control.Monad.State
import Data.Map qualified as M
import Data.Maybe
+import Data.Proxy
import Data.Set qualified as S
import Data.Text qualified as T
import Data.Text.Lazy qualified as TL
@@ -41,15 +42,46 @@ parseDefinition = label "symbol definition" $ toplevel ToplevelDefinition $ do
def <- localState $ L.indentBlock scn $ do
wsymbol "def"
name <- varName
+ argsDecl <- functionArguments (\off _ -> return . ( off, )) varName mzero (\_ -> return . VarName)
+ atypes <- forM argsDecl $ \( off, vname :: VarName ) -> do
+ tvar <- newTypeVar
+ modify $ \s -> s { testVars = ( vname, ExprTypeVar tvar ) : testVars s }
+ return ( off, vname, tvar )
choice
[ do
- symbol ":"
+ osymbol ":"
let finish steps = do
- return $ ( name, SomeExpr $ mconcat steps )
+ atypes' <- getInferredTypes atypes
+ ( name, ) . SomeExpr . ArgsReq atypes' . FunctionAbstraction <$> replaceDynArgs (mconcat steps)
return $ L.IndentSome Nothing finish testStep
]
modify $ \s -> s { testVars = fmap someExprType def : testVars s }
return def
+ where
+ getInferredTypes atypes = forM atypes $ \( off, vname, tvar@(TypeVar tvarname) ) -> do
+ let err msg = do
+ registerParseError . FancyError off . S.singleton . ErrorFail $ T.unpack msg
+ return ( vname, SomeArgumentType (OptionalArgument @DynamicType) )
+ gets (M.lookup tvar . testTypeUnif) >>= \case
+ Just (ExprTypePrim (_ :: Proxy a)) -> return ( vname, SomeArgumentType (RequiredArgument @a) )
+ Just (ExprTypeVar (TypeVar tvar')) -> err $ "ambiguous type for ‘" <> textVarName vname <> " : " <> tvar' <> "’"
+ Just (ExprTypeFunction {}) -> err $ "unsupported function type of ‘" <> textVarName vname <> "’"
+ Nothing -> err $ "ambiguous type for ‘" <> textVarName vname <> " : " <> tvarname <> "’"
+
+ replaceDynArgs :: forall a. Expr a -> TestParser (Expr a)
+ replaceDynArgs expr = do
+ unif <- gets testTypeUnif
+ return $ mapExpr (go unif) expr
+ where
+ go :: forall b. M.Map TypeVar SomeExprType -> Expr b -> Expr b
+ go unif = \case
+ ArgsApp args body -> ArgsApp (fmap replaceArgs args) body
+ where
+ replaceArgs (SomeExpr (DynVariable tvar sline vname))
+ | Just (ExprTypePrim (Proxy :: Proxy v)) <- M.lookup tvar unif
+ = SomeExpr (Variable sline vname :: Expr v)
+ replaceArgs (SomeExpr e) = SomeExpr (go unif e)
+ e -> e
parseTestModule :: FilePath -> TestParser Module
parseTestModule absPath = do
diff --git a/src/Parser/Core.hs b/src/Parser/Core.hs
index 57b2eb4..5fb4c5f 100644
--- a/src/Parser/Core.hs
+++ b/src/Parser/Core.hs
@@ -200,7 +200,8 @@ localState :: TestParser a -> TestParser a
localState inner = do
s <- get
x <- inner
- put s
+ s' <- get
+ put s { testNextTypeVar = testNextTypeVar s', testTypeUnif = testTypeUnif s' }
return x
toplevel :: (a -> Toplevel) -> TestParser a -> TestParser Toplevel
diff --git a/src/Test.hs b/src/Test.hs
index e6cc415..3db7919 100644
--- a/src/Test.hs
+++ b/src/Test.hs
@@ -19,7 +19,7 @@ module Test (
RecordSelector(..),
ExprListUnpacker(..),
ExprEnumerator(..),
- Expr(..), varExpr, eval, evalSome,
+ Expr(..), varExpr, mapExpr, eval, evalSome,
Traced(..), EvalTrace, VarNameSelectors, gatherVars,
AppAnnotation(..),
@@ -34,6 +34,7 @@ import Control.Monad
import Control.Monad.Reader
import Data.Char
+import Data.Foldable
import Data.List
import Data.Map (Map)
import Data.Map qualified as M
@@ -198,18 +199,51 @@ data SomeExprType
| forall a. ExprType a => ExprTypeFunction (FunctionArguments SomeArgumentType) (Proxy a)
someExprType :: SomeExpr -> SomeExprType
-someExprType (SomeExpr (DynVariable tvar _ _)) = ExprTypeVar tvar
-someExprType (SomeExpr fun@(FunVariable params _ _)) = ExprTypeFunction params (proxyOfFunctionType fun)
+someExprType (SomeExpr expr) = go expr
where
+ go :: forall e. ExprType e => Expr e -> SomeExprType
+ go = \case
+ DynVariable tvar _ _ -> ExprTypeVar tvar
+ (e :: Expr a)
+ | IsFunType <- asFunType e -> ExprTypeFunction (gof e) (proxyOfFunctionType e)
+ | otherwise -> ExprTypePrim (Proxy @a)
+
+ gof :: forall e. ExprType e => Expr (FunctionType e) -> FunctionArguments SomeArgumentType
+ gof = \case
+ Let _ _ _ body -> gof body
+ Variable {} -> error "someExprType: gof: variable"
+ FunVariable params _ _ -> params
+ ArgsReq args body -> fmap snd args <> gof body
+ ArgsApp (FunctionArguments used) body ->
+ let FunctionArguments args = gof body
+ in FunctionArguments $ args `M.difference` used
+ FunctionAbstraction {} -> mempty
+ FunctionEval {} -> error "someExprType: gof: function eval"
+ Pure {} -> error "someExprType: gof: pure"
+ App {} -> error "someExprType: gof: app"
+ Undefined {} -> error "someExprType: gof: undefined"
+
proxyOfFunctionType :: Expr (FunctionType a) -> Proxy a
proxyOfFunctionType _ = Proxy
-someExprType (SomeExpr (_ :: Expr a)) = ExprTypePrim (Proxy @a)
textSomeExprType :: SomeExprType -> Text
textSomeExprType (ExprTypePrim p) = textExprType p
textSomeExprType (ExprTypeVar (TypeVar name)) = name
textSomeExprType (ExprTypeFunction _ r) = "function:" <> textExprType r
+data AsFunType a
+ = forall b. (a ~ FunctionType b, ExprType b) => IsFunType
+ | NotFunType
+
+asFunType :: Expr a -> AsFunType a
+asFunType = \case
+ Let _ _ _ expr -> asFunType expr
+ FunVariable {} -> IsFunType
+ ArgsReq {} -> IsFunType
+ ArgsApp {} -> IsFunType
+ FunctionAbstraction {} -> IsFunType
+ _ -> NotFunType
+
data SomeVarValue = forall a. ExprType a => SomeVarValue (VarValue a)
@@ -269,8 +303,10 @@ data Expr a where
Variable :: ExprType a => SourceLine -> VarName -> Expr a
DynVariable :: TypeVar -> SourceLine -> VarName -> Expr DynamicType
FunVariable :: ExprType a => FunctionArguments SomeArgumentType -> SourceLine -> VarName -> Expr (FunctionType a)
- ArgsApp :: FunctionArguments SomeExpr -> Expr (FunctionType a) -> Expr (FunctionType a)
- FunctionEval :: Expr (FunctionType a) -> Expr a
+ ArgsReq :: ExprType a => FunctionArguments ( VarName, SomeArgumentType ) -> Expr (FunctionType a) -> Expr (FunctionType a)
+ ArgsApp :: ExprType a => FunctionArguments SomeExpr -> Expr (FunctionType a) -> Expr (FunctionType a)
+ FunctionAbstraction :: ExprType a => Expr a -> Expr (FunctionType a)
+ FunctionEval :: ExprType a => Expr (FunctionType a) -> Expr a
LambdaAbstraction :: ExprType a => TypedVarName a -> Expr b -> Expr (a -> b)
Pure :: a -> Expr a
App :: AppAnnotation b -> Expr (a -> b) -> Expr a -> Expr b
@@ -298,6 +334,27 @@ instance Monoid a => Monoid (Expr a) where
varExpr :: ExprType a => SourceLine -> TypedVarName a -> Expr a
varExpr sline (TypedVarName name) = Variable sline name
+mapExpr :: forall a. (forall b. Expr b -> Expr b) -> Expr a -> Expr a
+mapExpr f = go
+ where
+ go :: forall c. Expr c -> Expr c
+ go = \case
+ Let sline vname vval expr -> f $ Let sline vname (go vval) (go expr)
+ e@Variable {} -> f e
+ e@DynVariable {} -> f e
+ e@FunVariable {} -> f e
+ ArgsReq args expr -> f $ ArgsReq args (go expr)
+ ArgsApp args expr -> f $ ArgsApp (fmap (\(SomeExpr e) -> SomeExpr (go e)) args) (go expr)
+ FunctionAbstraction expr -> f $ FunctionAbstraction (go expr)
+ FunctionEval expr -> f $ FunctionEval (go expr)
+ LambdaAbstraction tvar expr -> f $ LambdaAbstraction tvar (go expr)
+ e@Pure {} -> f e
+ App ann efun earg -> f $ App ann (go efun) (go earg)
+ e@Concat {} -> f e
+ e@Regex {} -> f e
+ e@Undefined {} -> f e
+ Trace expr -> f $ Trace (go expr)
+
newtype SimpleEval a = SimpleEval (Reader VariableDictionary a)
deriving (Functor, Applicative, Monad)
@@ -319,12 +376,21 @@ eval = \case
val <- eval valExpr
withVar name val $ eval expr
Variable sline name -> fromSomeVarValue sline name =<< lookupVar name
- DynVariable _ _ _ -> fail "ambiguous type"
+ DynVariable _ _ name -> fail $ "ambiguous type of ‘" <> unpackVarName name <> "’"
FunVariable _ sline name -> funFromSomeVarValue sline name =<< lookupVar name
+ ArgsReq (FunctionArguments req) efun -> do
+ dict <- askDictionary
+ return $ FunctionType $ \(FunctionArguments args) ->
+ let used = M.intersectionWith (\value ( vname, _ ) -> ( vname, value )) args req
+ FunctionType fun = runSimpleEval (eval efun) (toList used ++ dict)
+ in fun $ FunctionArguments $ args `M.difference` req
ArgsApp eargs efun -> do
FunctionType fun <- eval efun
args <- mapM evalSome eargs
return $ FunctionType $ \args' -> fun (args <> args')
+ FunctionAbstraction expr -> do
+ val <- eval expr
+ return $ FunctionType $ const val
FunctionEval efun -> do
FunctionType fun <- eval efun
return $ fun mempty
@@ -343,10 +409,18 @@ eval = \case
Trace expr -> Traced <$> gatherVars expr <*> eval expr
evalSome :: MonadEval m => SomeExpr -> m SomeVarValue
-evalSome (SomeExpr expr) = fmap SomeVarValue $ VarValue
- <$> gatherVars expr
- <*> pure mempty
- <*> (const . const <$> eval expr)
+evalSome (SomeExpr expr)
+ | IsFunType <- asFunType expr = do
+ FunctionType fun <- eval expr
+ fmap SomeVarValue $ VarValue
+ <$> gatherVars expr
+ <*> pure (exprArgs expr)
+ <*> pure (const fun)
+ | otherwise = do
+ fmap SomeVarValue $ VarValue
+ <$> gatherVars expr
+ <*> pure mempty
+ <*> (const . const <$> eval expr)
data Traced a = Traced EvalTrace a
@@ -364,10 +438,12 @@ gatherVars = fmap (uniqOn fst . sortOn fst) . helper
| otherwise -> maybe [] (\x -> [ (( var, [] ), x ) ]) <$> tryLookupVar var
DynVariable _ _ var -> maybe [] (\x -> [ (( var, [] ), x ) ]) <$> tryLookupVar var
FunVariable _ _ var -> maybe [] (\x -> [ (( var, [] ), x ) ]) <$> tryLookupVar var
+ ArgsReq args expr -> withDictionary (filter ((`notElem` map fst (toList args)) . fst)) $ helper expr
ArgsApp (FunctionArguments args) fun -> do
v <- helper fun
vs <- mapM (\(SomeExpr e) -> helper e) $ M.elems args
return $ concat (v : vs)
+ FunctionAbstraction expr -> helper expr
FunctionEval efun -> helper efun
LambdaAbstraction (TypedVarName var) expr -> withDictionary (filter ((var /=) . fst)) $ helper expr
Pure _ -> return []
@@ -403,11 +479,19 @@ anull :: FunctionArguments a -> Bool
anull (FunctionArguments args) = M.null args
exprArgs :: Expr (FunctionType a) -> FunctionArguments SomeArgumentType
-exprArgs (FunVariable args _ _) = args
-exprArgs (ArgsApp (FunctionArguments applied) expr) =
- let FunctionArguments args = exprArgs expr
- in FunctionArguments (args `M.difference` applied)
-exprArgs _ = error "exprArgs on unexpected type"
+exprArgs = \case
+ Let _ _ _ expr -> exprArgs expr
+ Variable {} -> mempty
+ FunVariable args _ _ -> args
+ ArgsReq args expr -> fmap snd args <> exprArgs expr
+ ArgsApp (FunctionArguments applied) expr ->
+ let FunctionArguments args = exprArgs expr
+ in FunctionArguments (args `M.difference` applied)
+ FunctionAbstraction {} -> mempty
+ FunctionEval {} -> mempty
+ Pure {} -> error "exprArgs: pure"
+ App {} -> error "exprArgs: app"
+ Undefined {} -> error "exprArgs: undefined"
funFromSomeVarValue :: forall a m. (ExprType a, MonadFail m) => SourceLine -> VarName -> SomeVarValue -> m (FunctionType a)
funFromSomeVarValue sline name (SomeVarValue (VarValue _ args value :: VarValue b)) = do
@@ -416,7 +500,7 @@ funFromSomeVarValue sline name (SomeVarValue (VarValue _ args value :: VarValue
FunctionType <$> cast (value sline)
where
err = T.unpack $ T.concat [ T.pack "expected function returning ", textExprType @a Proxy, T.pack ", but variable '", textVarName name, T.pack "' has ",
- (if anull args then "type" else "function type returting ") <> textExprType @b Proxy ]
+ (if anull args then "type " else "function type returting ") <> textExprType @b Proxy ]
data SomeArgumentType = forall a. ExprType a => SomeArgumentType (ArgumentType a)