summaryrefslogtreecommitdiff
path: root/compiler/GHC/Core/Lint.hs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/GHC/Core/Lint.hs')
-rw-r--r--compiler/GHC/Core/Lint.hs128
1 files changed, 100 insertions, 28 deletions
diff --git a/compiler/GHC/Core/Lint.hs b/compiler/GHC/Core/Lint.hs
index 226e50a8bd..724ccf17ab 100644
--- a/compiler/GHC/Core/Lint.hs
+++ b/compiler/GHC/Core/Lint.hs
@@ -685,22 +685,9 @@ lintRhs :: Id -> CoreExpr -> LintM LintedType
-- its OccInfo and join-pointer-hood
lintRhs bndr rhs
| Just arity <- isJoinId_maybe bndr
- = lint_join_lams arity arity True rhs
+ = lintJoinLams arity (Just bndr) rhs
| AlwaysTailCalled arity <- tailCallInfo (idOccInfo bndr)
- = lint_join_lams arity arity False rhs
- where
- lint_join_lams 0 _ _ rhs
- = lintCoreExpr rhs
-
- lint_join_lams n tot enforce (Lam var expr)
- = lintLambda var $ lint_join_lams (n-1) tot enforce expr
-
- lint_join_lams n tot True _other
- = failWithL $ mkBadJoinArityMsg bndr tot (tot-n) rhs
- lint_join_lams _ _ False rhs
- = markAllJoinsBad $ lintCoreExpr rhs
- -- Future join point, not yet eta-expanded
- -- Body is not a tail position
+ = lintJoinLams arity Nothing rhs
-- Allow applications of the data constructor @StaticPtr@ at the top
-- but produce errors otherwise.
@@ -722,6 +709,22 @@ lintRhs _bndr rhs = fmap lf_check_static_ptrs getLintFlags >>= go
binders0
go _ = markAllJoinsBad $ lintCoreExpr rhs
+-- | Lint the RHS of a join point with expected join arity of @n@ (see Note
+-- [Join points] in GHC.Core).
+lintJoinLams :: JoinArity -> Maybe Id -> CoreExpr -> LintM LintedType
+lintJoinLams join_arity enforce rhs
+ = go join_arity rhs
+ where
+ go 0 rhs = lintCoreExpr rhs
+ go n (Lam var expr) = lintLambda var $ go (n-1) expr
+ -- N.B. join points can be cast. e.g. we consider ((\x -> ...) `cast` ...)
+ -- to be a join point at join arity 1.
+ go n _other | Just bndr <- enforce -- Join point with too few RHS lambdas
+ = failWithL $ mkBadJoinArityMsg bndr join_arity n rhs
+ | otherwise -- Future join point, not yet eta-expanded
+ = markAllJoinsBad $ lintCoreExpr rhs
+ -- Body of lambda is not a tail position
+
lintIdUnfolding :: Id -> Type -> Unfolding -> LintM ()
lintIdUnfolding bndr bndr_ty uf
| isStableUnfolding uf
@@ -762,6 +765,40 @@ we will check any unfolding after it has been unfolded; checking the
unfolding beforehand is merely an optimization, and one that actively
hurts us here.
+Note [Linting of runRW#]
+~~~~~~~~~~~~~~~~~~~~~~~~
+runRW# has some very peculiar behavior (see Note [runRW magic] in
+GHC.CoreToStg.Prep) which CoreLint must accommodate.
+
+As described in Note [Casts and lambdas] in
+GHC.Core.Opt.Simplify.Utils, the simplifier pushes casts out of
+lambdas. Concretely, the simplifier will transform
+
+ runRW# @r @ty (\s -> expr `cast` co)
+
+into
+
+ runRW# @r @ty ((\s -> expr) `cast` co)
+
+Consequently we need to handle the case that the continuation is a
+cast of a lambda. See Note [Casts and lambdas] in
+GHC.Core.Opt.Simplify.Utils.
+
+In the event that the continuation is headed by a lambda (which
+will bind the State# token) we can safely allow calls to join
+points since CorePrep is going to apply the continuation to
+RealWorld.
+
+In the case that the continuation is not a lambda we lint the
+continuation disallowing join points, to rule out things like,
+
+ join j = ...
+ in runRW# @r @ty (
+ let x = jump j
+ in x
+ )
+
+
************************************************************************
* *
\subsection[lintCoreExpr]{lintCoreExpr}
@@ -776,6 +813,18 @@ type LintedCoercion = Coercion
type LintedTyCoVar = TyCoVar
type LintedId = Id
+-- | Lint an expression cast through the given coercion, returning the type
+-- resulting from the cast.
+lintCastExpr :: CoreExpr -> LintedType -> Coercion -> LintM LintedType
+lintCastExpr expr expr_ty co
+ = do { co' <- lintCoercion co
+ ; let (Pair from_ty to_ty, role) = coercionKindRole co'
+ ; checkValueType to_ty $
+ text "target of cast" <+> quotes (ppr co')
+ ; lintRole co' Representational role
+ ; ensureEqTys from_ty expr_ty (mkCastErr expr co' from_ty expr_ty)
+ ; return to_ty }
+
lintCoreExpr :: CoreExpr -> LintM LintedType
-- The returned type has the substitution from the monad
-- already applied to it:
@@ -793,14 +842,8 @@ lintCoreExpr (Lit lit)
= return (literalType lit)
lintCoreExpr (Cast expr co)
- = do { expr_ty <- markAllJoinsBad $ lintCoreExpr expr
- ; co' <- lintCoercion co
- ; let (Pair from_ty to_ty, role) = coercionKindRole co'
- ; checkValueType to_ty $
- text "target of cast" <+> quotes (ppr co')
- ; lintRole co' Representational role
- ; ensureEqTys from_ty expr_ty (mkCastErr expr co' from_ty expr_ty)
- ; return to_ty }
+ = do expr_ty <- markAllJoinsBad $ lintCoreExpr expr
+ lintCastExpr expr expr_ty co
lintCoreExpr (Tick tickish expr)
= do case tickish of
@@ -860,6 +903,31 @@ lintCoreExpr e@(Let (Rec pairs) body)
bndrs = map fst pairs
lintCoreExpr e@(App _ _)
+ | Var fun <- fun
+ , fun `hasKey` runRWKey
+ -- N.B. we may have an over-saturated application of the form:
+ -- runRW (\s -> \x -> ...) y
+ , arg_ty1 : arg_ty2 : arg3 : rest <- args
+ = do { fun_ty1 <- lintCoreArg (idType fun) arg_ty1
+ ; fun_ty2 <- lintCoreArg fun_ty1 arg_ty2
+ -- See Note [Linting of runRW#]
+ ; let lintRunRWCont :: CoreArg -> LintM LintedType
+ lintRunRWCont (Cast expr co) = do
+ ty <- lintRunRWCont expr
+ lintCastExpr expr ty co
+ lintRunRWCont expr@(Lam _ _) = do
+ lintJoinLams 1 (Just fun) expr
+ lintRunRWCont other = markAllJoinsBad $ lintCoreExpr other
+ -- TODO: Look through ticks?
+ ; arg3_ty <- lintRunRWCont arg3
+ ; app_ty <- lintValApp arg3 fun_ty2 arg3_ty
+ ; lintCoreArgs app_ty rest }
+
+ | Var fun <- fun
+ , fun `hasKey` runRWKey
+ = failWithL (text "Invalid runRW# application")
+
+ | otherwise
= do { fun_ty <- lintCoreFun fun (length args)
; lintCoreArgs fun_ty args }
where
@@ -1139,11 +1207,15 @@ lintTyApp fun_ty arg_ty
= failWithL (mkTyAppMsg fun_ty arg_ty)
-----------------
+
+-- | @lintValApp arg fun_ty arg_ty@ lints an application of @fun arg@
+-- where @fun :: fun_ty@ and @arg :: arg_ty@, returning the type of the
+-- application.
lintValApp :: CoreExpr -> LintedType -> LintedType -> LintM LintedType
lintValApp arg fun_ty arg_ty
- | Just (arg,res) <- splitFunTy_maybe fun_ty
- = do { ensureEqTys arg arg_ty err1
- ; return res }
+ | Just (arg_ty', res_ty') <- splitFunTy_maybe fun_ty
+ = do { ensureEqTys arg_ty' arg_ty err1
+ ; return res_ty' }
| otherwise
= failWithL err2
where
@@ -2780,11 +2852,11 @@ mkInvalidJoinPointMsg var ty
2 (ppr var <+> dcolon <+> ppr ty)
mkBadJoinArityMsg :: Var -> Int -> Int -> CoreExpr -> SDoc
-mkBadJoinArityMsg var ar nlams rhs
+mkBadJoinArityMsg var ar n rhs
= vcat [ text "Join point has too few lambdas",
text "Join var:" <+> ppr var,
text "Join arity:" <+> ppr ar,
- text "Number of lambdas:" <+> ppr nlams,
+ text "Number of lambdas:" <+> ppr (ar - n),
text "Rhs = " <+> ppr rhs
]