diff options
author | Ben Gamari <ben@smart-cactus.org> | 2020-04-19 11:16:32 -0400 |
---|---|---|
committer | Ben Gamari <ben@smart-cactus.org> | 2020-05-22 15:14:51 -0400 |
commit | 5a2cdd5c2f69992a59fa57b9dc46f6ed71864ef6 (patch) | |
tree | 041d5ae04a2f37d0d9b27cd932d7fbe80ba49af4 /compiler/GHC/Core/Lint.hs | |
parent | 31f1c568e7c9562d58ae10dbcd74d67da8156021 (diff) | |
download | haskell-wip/runRW.tar.gz |
Allow simplification through runRW#wip/runRW
Because runRW# inlines so late, we were previously able to do very
little simplification across it. For instance, given even a simple
program like
case runRW# (\s -> let n = I# 42# in n) of
I# n# -> f n#
we previously had no way to avoid the allocation of the I#.
This patch allows the simplifier to push strict contexts into the
continuation of a runRW# application, as explained in
in Note [Simplification of runRW#] in GHC.CoreToStg.Prep.
Fixes #15127.
Metric Increase:
T9961
Metric Decrease:
ManyConstructors
Co-Authored-By: Simon Peyton-Jone <simonpj@microsoft.com>
Diffstat (limited to 'compiler/GHC/Core/Lint.hs')
-rw-r--r-- | compiler/GHC/Core/Lint.hs | 128 |
1 files changed, 100 insertions, 28 deletions
diff --git a/compiler/GHC/Core/Lint.hs b/compiler/GHC/Core/Lint.hs index 226e50a8bd..724ccf17ab 100644 --- a/compiler/GHC/Core/Lint.hs +++ b/compiler/GHC/Core/Lint.hs @@ -685,22 +685,9 @@ lintRhs :: Id -> CoreExpr -> LintM LintedType -- its OccInfo and join-pointer-hood lintRhs bndr rhs | Just arity <- isJoinId_maybe bndr - = lint_join_lams arity arity True rhs + = lintJoinLams arity (Just bndr) rhs | AlwaysTailCalled arity <- tailCallInfo (idOccInfo bndr) - = lint_join_lams arity arity False rhs - where - lint_join_lams 0 _ _ rhs - = lintCoreExpr rhs - - lint_join_lams n tot enforce (Lam var expr) - = lintLambda var $ lint_join_lams (n-1) tot enforce expr - - lint_join_lams n tot True _other - = failWithL $ mkBadJoinArityMsg bndr tot (tot-n) rhs - lint_join_lams _ _ False rhs - = markAllJoinsBad $ lintCoreExpr rhs - -- Future join point, not yet eta-expanded - -- Body is not a tail position + = lintJoinLams arity Nothing rhs -- Allow applications of the data constructor @StaticPtr@ at the top -- but produce errors otherwise. @@ -722,6 +709,22 @@ lintRhs _bndr rhs = fmap lf_check_static_ptrs getLintFlags >>= go binders0 go _ = markAllJoinsBad $ lintCoreExpr rhs +-- | Lint the RHS of a join point with expected join arity of @n@ (see Note +-- [Join points] in GHC.Core). +lintJoinLams :: JoinArity -> Maybe Id -> CoreExpr -> LintM LintedType +lintJoinLams join_arity enforce rhs + = go join_arity rhs + where + go 0 rhs = lintCoreExpr rhs + go n (Lam var expr) = lintLambda var $ go (n-1) expr + -- N.B. join points can be cast. e.g. we consider ((\x -> ...) `cast` ...) + -- to be a join point at join arity 1. + go n _other | Just bndr <- enforce -- Join point with too few RHS lambdas + = failWithL $ mkBadJoinArityMsg bndr join_arity n rhs + | otherwise -- Future join point, not yet eta-expanded + = markAllJoinsBad $ lintCoreExpr rhs + -- Body of lambda is not a tail position + lintIdUnfolding :: Id -> Type -> Unfolding -> LintM () lintIdUnfolding bndr bndr_ty uf | isStableUnfolding uf @@ -762,6 +765,40 @@ we will check any unfolding after it has been unfolded; checking the unfolding beforehand is merely an optimization, and one that actively hurts us here. +Note [Linting of runRW#] +~~~~~~~~~~~~~~~~~~~~~~~~ +runRW# has some very peculiar behavior (see Note [runRW magic] in +GHC.CoreToStg.Prep) which CoreLint must accommodate. + +As described in Note [Casts and lambdas] in +GHC.Core.Opt.Simplify.Utils, the simplifier pushes casts out of +lambdas. Concretely, the simplifier will transform + + runRW# @r @ty (\s -> expr `cast` co) + +into + + runRW# @r @ty ((\s -> expr) `cast` co) + +Consequently we need to handle the case that the continuation is a +cast of a lambda. See Note [Casts and lambdas] in +GHC.Core.Opt.Simplify.Utils. + +In the event that the continuation is headed by a lambda (which +will bind the State# token) we can safely allow calls to join +points since CorePrep is going to apply the continuation to +RealWorld. + +In the case that the continuation is not a lambda we lint the +continuation disallowing join points, to rule out things like, + + join j = ... + in runRW# @r @ty ( + let x = jump j + in x + ) + + ************************************************************************ * * \subsection[lintCoreExpr]{lintCoreExpr} @@ -776,6 +813,18 @@ type LintedCoercion = Coercion type LintedTyCoVar = TyCoVar type LintedId = Id +-- | Lint an expression cast through the given coercion, returning the type +-- resulting from the cast. +lintCastExpr :: CoreExpr -> LintedType -> Coercion -> LintM LintedType +lintCastExpr expr expr_ty co + = do { co' <- lintCoercion co + ; let (Pair from_ty to_ty, role) = coercionKindRole co' + ; checkValueType to_ty $ + text "target of cast" <+> quotes (ppr co') + ; lintRole co' Representational role + ; ensureEqTys from_ty expr_ty (mkCastErr expr co' from_ty expr_ty) + ; return to_ty } + lintCoreExpr :: CoreExpr -> LintM LintedType -- The returned type has the substitution from the monad -- already applied to it: @@ -793,14 +842,8 @@ lintCoreExpr (Lit lit) = return (literalType lit) lintCoreExpr (Cast expr co) - = do { expr_ty <- markAllJoinsBad $ lintCoreExpr expr - ; co' <- lintCoercion co - ; let (Pair from_ty to_ty, role) = coercionKindRole co' - ; checkValueType to_ty $ - text "target of cast" <+> quotes (ppr co') - ; lintRole co' Representational role - ; ensureEqTys from_ty expr_ty (mkCastErr expr co' from_ty expr_ty) - ; return to_ty } + = do expr_ty <- markAllJoinsBad $ lintCoreExpr expr + lintCastExpr expr expr_ty co lintCoreExpr (Tick tickish expr) = do case tickish of @@ -860,6 +903,31 @@ lintCoreExpr e@(Let (Rec pairs) body) bndrs = map fst pairs lintCoreExpr e@(App _ _) + | Var fun <- fun + , fun `hasKey` runRWKey + -- N.B. we may have an over-saturated application of the form: + -- runRW (\s -> \x -> ...) y + , arg_ty1 : arg_ty2 : arg3 : rest <- args + = do { fun_ty1 <- lintCoreArg (idType fun) arg_ty1 + ; fun_ty2 <- lintCoreArg fun_ty1 arg_ty2 + -- See Note [Linting of runRW#] + ; let lintRunRWCont :: CoreArg -> LintM LintedType + lintRunRWCont (Cast expr co) = do + ty <- lintRunRWCont expr + lintCastExpr expr ty co + lintRunRWCont expr@(Lam _ _) = do + lintJoinLams 1 (Just fun) expr + lintRunRWCont other = markAllJoinsBad $ lintCoreExpr other + -- TODO: Look through ticks? + ; arg3_ty <- lintRunRWCont arg3 + ; app_ty <- lintValApp arg3 fun_ty2 arg3_ty + ; lintCoreArgs app_ty rest } + + | Var fun <- fun + , fun `hasKey` runRWKey + = failWithL (text "Invalid runRW# application") + + | otherwise = do { fun_ty <- lintCoreFun fun (length args) ; lintCoreArgs fun_ty args } where @@ -1139,11 +1207,15 @@ lintTyApp fun_ty arg_ty = failWithL (mkTyAppMsg fun_ty arg_ty) ----------------- + +-- | @lintValApp arg fun_ty arg_ty@ lints an application of @fun arg@ +-- where @fun :: fun_ty@ and @arg :: arg_ty@, returning the type of the +-- application. lintValApp :: CoreExpr -> LintedType -> LintedType -> LintM LintedType lintValApp arg fun_ty arg_ty - | Just (arg,res) <- splitFunTy_maybe fun_ty - = do { ensureEqTys arg arg_ty err1 - ; return res } + | Just (arg_ty', res_ty') <- splitFunTy_maybe fun_ty + = do { ensureEqTys arg_ty' arg_ty err1 + ; return res_ty' } | otherwise = failWithL err2 where @@ -2780,11 +2852,11 @@ mkInvalidJoinPointMsg var ty 2 (ppr var <+> dcolon <+> ppr ty) mkBadJoinArityMsg :: Var -> Int -> Int -> CoreExpr -> SDoc -mkBadJoinArityMsg var ar nlams rhs +mkBadJoinArityMsg var ar n rhs = vcat [ text "Join point has too few lambdas", text "Join var:" <+> ppr var, text "Join arity:" <+> ppr ar, - text "Number of lambdas:" <+> ppr nlams, + text "Number of lambdas:" <+> ppr (ar - n), text "Rhs = " <+> ppr rhs ] |