summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--compiler/simplCore/SimplCore.hs2
-rw-r--r--compiler/simplCore/Simplify.hs120
2 files changed, 86 insertions, 36 deletions
diff --git a/compiler/simplCore/SimplCore.hs b/compiler/simplCore/SimplCore.hs
index a34baa8301..fe6d44625a 100644
--- a/compiler/simplCore/SimplCore.hs
+++ b/compiler/simplCore/SimplCore.hs
@@ -767,7 +767,7 @@ simplifyPgmIO pass@(CoreDoSimplify max_iterations mode)
-- for imported Ids. Eg RULE map my_f = blah
-- If we have a substitution my_f :-> other_f, we'd better
-- apply it to the rule to, or it'll never match
- ; rules1 <- simplRules env1 Nothing rules
+ ; rules1 <- simplRules env1 Nothing rules Nothing
; return (getTopFloatBinds floats, rules1) } ;
diff --git a/compiler/simplCore/Simplify.hs b/compiler/simplCore/Simplify.hs
index 53e3a210de..a60df1c0ad 100644
--- a/compiler/simplCore/Simplify.hs
+++ b/compiler/simplCore/Simplify.hs
@@ -24,7 +24,7 @@ import Id
import MkId ( seqId )
import MkCore ( mkImpossibleExpr, castBottomExpr )
import IdInfo
-import Name ( Name, mkSystemVarName, isExternalName, getOccFS )
+import Name ( mkSystemVarName, isExternalName, getOccFS )
import Coercion hiding ( substCo, substCoVar )
import OptCoercion ( optCoercion )
import FamInstEnv ( topNormaliseType_maybe )
@@ -143,11 +143,11 @@ simplTopBinds env0 binds0
; (floats, env2) <- simpl_binds env1 binds
; return (float `addFloats` floats, env2) }
- simpl_bind env (Rec pairs) = simplRecBind env TopLevel Nothing pairs
- simpl_bind env (NonRec b r) = do { (env', b') <- addBndrRules env b (lookupRecBndr env b)
- ; simplRecOrTopPair env' TopLevel
- NonRecursive Nothing
- b b' r }
+ simpl_bind env (Rec pairs)
+ = simplRecBind env TopLevel Nothing pairs
+ simpl_bind env (NonRec b r)
+ = do { (env', b') <- addBndrRules env b (lookupRecBndr env b) Nothing
+ ; simplRecOrTopPair env' TopLevel NonRecursive Nothing b b' r }
{-
************************************************************************
@@ -160,7 +160,7 @@ simplRecBind is used for
* recursive bindings only
-}
-simplRecBind :: SimplEnv -> TopLevelFlag -> Maybe SimplCont
+simplRecBind :: SimplEnv -> TopLevelFlag -> MaybeJoinCont
-> [(InId, InExpr)]
-> SimplM (SimplFloats, SimplEnv)
simplRecBind env0 top_lvl mb_cont pairs0
@@ -171,7 +171,7 @@ simplRecBind env0 top_lvl mb_cont pairs0
add_rules :: SimplEnv -> (InBndr,InExpr) -> SimplM (SimplEnv, (InBndr, OutBndr, InExpr))
-- Add the (substituted) rules to the binder
add_rules env (bndr, rhs)
- = do { (env', bndr') <- addBndrRules env bndr (lookupRecBndr env bndr)
+ = do { (env', bndr') <- addBndrRules env bndr (lookupRecBndr env bndr) mb_cont
; return (env', (bndr, bndr', rhs)) }
go env [] = return (emptyFloats env, env)
@@ -191,7 +191,7 @@ It assumes the binder has already been simplified, but not its IdInfo.
-}
simplRecOrTopPair :: SimplEnv
- -> TopLevelFlag -> RecFlag -> Maybe SimplCont
+ -> TopLevelFlag -> RecFlag -> MaybeJoinCont
-> InId -> OutBndr -> InExpr -- Binder and rhs
-> SimplM (SimplFloats, SimplEnv)
@@ -616,7 +616,7 @@ Nor does it do the atomic-argument thing
completeBind :: SimplEnv
-> TopLevelFlag -- Flag stuck into unfolding
- -> Maybe SimplCont -- Required only for join point
+ -> MaybeJoinCont -- Required only for join point
-> InId -- Old binder
-> OutId -> OutExpr -- New binder and RHS
-> SimplM (SimplFloats, SimplEnv)
@@ -645,7 +645,7 @@ completeBind env top_lvl mb_cont old_bndr new_bndr new_rhs
-- Simplify the unfolding
; new_unfolding <- simplLetUnfolding env top_lvl mb_cont old_bndr
- final_rhs old_unf
+ final_rhs (idType new_bndr) old_unf
; let final_bndr = addLetBndrInfo new_bndr new_arity is_bot new_unfolding
@@ -1319,7 +1319,8 @@ simplLamBndr :: SimplEnv -> InBndr -> SimplM (SimplEnv, OutBndr)
simplLamBndr env bndr
| isId bndr && isFragileUnfolding old_unf -- Special case
= do { (env1, bndr1) <- simplBinder env bndr
- ; unf' <- simplStableUnfolding env1 NotTopLevel Nothing bndr old_unf
+ ; unf' <- simplStableUnfolding env1 NotTopLevel Nothing bndr
+ old_unf (idType bndr1)
; let bndr2 = bndr1 `setIdUnfolding` unf'
; return (modifyInScope env1 bndr2, bndr2) }
@@ -1378,7 +1379,7 @@ simplNonRecE env bndr (rhs, rhs_se) (bndrs, body) cont
| otherwise
= ASSERT( not (isTyVar bndr) )
do { (env1, bndr1) <- simplNonRecBndr env bndr
- ; (env2, bndr2) <- addBndrRules env1 bndr bndr1
+ ; (env2, bndr2) <- addBndrRules env1 bndr bndr1 Nothing
; (floats1, env3) <- simplLazyBind env2 NotTopLevel NonRecursive bndr bndr2 rhs rhs_se
; (floats2, expr') <- simplLam env3 bndrs body cont
; return (floats1 `addFloats` floats2, expr') }
@@ -1450,6 +1451,33 @@ Here it'd be far better to drop the unfolding and use the actual RHS.
* *
********************************************************************* -}
+{- Note [Rules and unfolding for join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Suppose we have
+
+ simplExpr (join j x = rhs ) cont
+ ( {- RULE j (p:ps) = blah -} )
+ ( {- StableUnfolding j = blah -} )
+ (in blah )
+
+Then we will push 'cont' into the rhs of 'j'. But we should *also* push
+'cont' into the RHS of
+ * Any RULEs for j, e.g. generated by SpecConstr
+ * Any stable unfolding for j, e.g. the result of an INLINE pragma
+
+Simplifying rules and stable-unfoldings happens a bit after
+simplifying the right-hand side, so we remember whether or not it
+is a join point, and what 'cont' is, in a value of type MaybeJoinCont
+
+Trac #13900 wsa caused by forgetting to push 'cont' into the RHS
+of a SpecConstr-generated RULE for a join point.
+-}
+
+type MaybeJoinCont = Maybe SimplCont
+ -- Nothing => Not a join point
+ -- Just k => This is a join binding with continuation k
+ -- See Note [Rules and unfolding for join points]
+
simplNonRecJoinPoint :: SimplEnv -> InId -> InExpr
-> InExpr -> SimplCont
-> SimplM (SimplFloats, OutExpr)
@@ -1465,7 +1493,7 @@ simplNonRecJoinPoint env bndr rhs body cont
-- and wrap wrap_cont around the whole thing
; let res_ty = contResultType cont
; (env1, bndr1) <- simplNonRecJoinBndr env res_ty bndr
- ; (env2, bndr2) <- addBndrRules env1 bndr bndr1
+ ; (env2, bndr2) <- addBndrRules env1 bndr bndr1 (Just cont)
; (floats1, env3) <- simplJoinBind env2 cont bndr bndr2 rhs env
; (floats2, body') <- simplExprF env3 body cont
; return (floats1 `addFloats` floats2, body') }
@@ -3235,13 +3263,13 @@ because we don't know its usage in each RHS separately
-}
simplLetUnfolding :: SimplEnv-> TopLevelFlag
- -> Maybe SimplCont
+ -> MaybeJoinCont
-> InId
- -> OutExpr
+ -> OutExpr -> OutType
-> Unfolding -> SimplM Unfolding
-simplLetUnfolding env top_lvl cont_mb id new_rhs unf
+simplLetUnfolding env top_lvl cont_mb id new_rhs rhs_ty unf
| isStableUnfolding unf
- = simplStableUnfolding env top_lvl cont_mb id unf
+ = simplStableUnfolding env top_lvl cont_mb id unf rhs_ty
| isExitJoinId id
= return noUnfolding -- see Note [Do not inline exit join points]
| otherwise
@@ -3265,26 +3293,26 @@ mkLetUnfolding dflags top_lvl src id new_rhs
-------------------
simplStableUnfolding :: SimplEnv -> TopLevelFlag
- -> Maybe SimplCont -- Just k => a join point with continuation k
+ -> MaybeJoinCont -- Just k => a join point with continuation k
-> InId
- -> Unfolding -> SimplM Unfolding
+ -> Unfolding -> OutType -> SimplM Unfolding
-- Note [Setting the new unfolding]
-simplStableUnfolding env top_lvl mb_cont id unf
+simplStableUnfolding env top_lvl mb_cont id unf rhs_ty
= case unf of
NoUnfolding -> return unf
BootUnfolding -> return unf
OtherCon {} -> return unf
DFunUnfolding { df_bndrs = bndrs, df_con = con, df_args = args }
- -> do { (env', bndrs') <- simplBinders rule_env bndrs
+ -> do { (env', bndrs') <- simplBinders unf_env bndrs
; args' <- mapM (simplExpr env') args
; return (mkDFunUnfolding bndrs' con args') }
CoreUnfolding { uf_tmpl = expr, uf_src = src, uf_guidance = guide }
| isStableSource src
- -> do { expr' <- case mb_cont of
- Just cont -> simplJoinRhs rule_env id expr cont
- Nothing -> simplExpr rule_env expr
+ -> do { expr' <- case mb_cont of -- See Note [Rules and unfolding for join points]
+ Just cont -> simplJoinRhs unf_env id expr cont
+ Nothing -> simplExprC unf_env expr (mkBoringStop rhs_ty)
; case guide of
UnfWhen { ug_arity = arity, ug_unsat_ok = sat_ok } -- Happens for INLINE things
-> let guide' = UnfWhen { ug_arity = arity, ug_unsat_ok = sat_ok
@@ -3308,7 +3336,7 @@ simplStableUnfolding env top_lvl mb_cont id unf
dflags = seDynFlags env
is_top_lvl = isTopLevel top_lvl
act = idInlineActivation id
- rule_env = updMode (updModeForStableUnfoldings act) env
+ unf_env = updMode (updModeForStableUnfoldings act) env
-- See Note [Simplifying inside stable unfoldings] in SimplUtils
{-
@@ -3350,20 +3378,24 @@ to apply in that function's own right-hand side.
See Note [Forming Rec groups] in OccurAnal
-}
-addBndrRules :: SimplEnv -> InBndr -> OutBndr -> SimplM (SimplEnv, OutBndr)
+addBndrRules :: SimplEnv -> InBndr -> OutBndr
+ -> MaybeJoinCont -- Just k for a join point binder
+ -- Nothing otherwise
+ -> SimplM (SimplEnv, OutBndr)
-- Rules are added back into the bin
-addBndrRules env in_id out_id
+addBndrRules env in_id out_id mb_cont
| null old_rules
= return (env, out_id)
| otherwise
- = do { new_rules <- simplRules env (Just (idName out_id)) old_rules
+ = do { new_rules <- simplRules env (Just out_id) old_rules mb_cont
; let final_id = out_id `setIdSpecialisation` mkRuleInfo new_rules
; return (modifyInScope env final_id, final_id) }
where
old_rules = ruleInfoRules (idSpecialisation in_id)
-simplRules :: SimplEnv -> Maybe Name -> [CoreRule] -> SimplM [CoreRule]
-simplRules env mb_new_nm rules
+simplRules :: SimplEnv -> Maybe OutId -> [CoreRule]
+ -> MaybeJoinCont -> SimplM [CoreRule]
+simplRules env mb_new_id rules mb_cont
= mapM simpl_rule rules
where
simpl_rule rule@(BuiltinRule {})
@@ -3373,11 +3405,29 @@ simplRules env mb_new_nm rules
, ru_fn = fn_name, ru_rhs = rhs })
= do { (env', bndrs') <- simplBinders env bndrs
; let rhs_ty = substTy env' (exprType rhs)
- rule_cont = mkBoringStop rhs_ty
- rule_env = updMode updModeForRules env'
+ rhs_cont = case mb_cont of -- See Note [Rules and unfolding for join points]
+ Nothing -> mkBoringStop rhs_ty
+ Just cont -> ASSERT2( join_ok, bad_join_msg )
+ cont
+ rule_env = updMode updModeForRules env'
+ fn_name' = case mb_new_id of
+ Just id -> idName id
+ Nothing -> fn_name
+
+ -- join_ok is an assertion check that the join-arity of the
+ -- binder matches that of the rule, so that pushing the
+ -- continuation into the RHS makes sense
+ join_ok = case mb_new_id of
+ Just id | Just join_arity <- isJoinId_maybe id
+ -> length args == join_arity
+ _ -> False
+ bad_join_msg = vcat [ ppr mb_new_id, ppr rule
+ , ppr (fmap isJoinId_maybe mb_new_id) ]
+
; args' <- mapM (simplExpr rule_env) args
- ; rhs' <- simplExprC rule_env rhs rule_cont
+ ; rhs' <- simplExprC rule_env rhs rhs_cont
; return (rule { ru_bndrs = bndrs'
- , ru_fn = mb_new_nm `orElse` fn_name
+ , ru_fn = fn_name'
, ru_args = args'
, ru_rhs = rhs' }) }
+