diff options
author | Simon Peyton Jones <simonpj@microsoft.com> | 2018-03-26 17:16:14 +0100 |
---|---|---|
committer | Simon Peyton Jones <simonpj@microsoft.com> | 2018-03-27 09:29:13 +0100 |
commit | a7628dcd2cb570fe41de247af6aa71a89177a9b9 (patch) | |
tree | 4e8b617ea94c30680654898e2bf314a16b077bc3 /compiler/simplCore/Simplify.hs | |
parent | 9cc6a182fc6f6d331774f0818bba5755188328cd (diff) | |
download | haskell-a7628dcd2cb570fe41de247af6aa71a89177a9b9.tar.gz |
Deal with join points with RULES
Trac #13900 showed that when we have a join point that
has a RULE, we must push the continuation into the RHS
of the RULE.
See Note [Rules and unfolding for join points]
It's hard to tickle this bug, so I have not added a regression test.
Diffstat (limited to 'compiler/simplCore/Simplify.hs')
-rw-r--r-- | compiler/simplCore/Simplify.hs | 120 |
1 files changed, 85 insertions, 35 deletions
diff --git a/compiler/simplCore/Simplify.hs b/compiler/simplCore/Simplify.hs index 53e3a210de..a60df1c0ad 100644 --- a/compiler/simplCore/Simplify.hs +++ b/compiler/simplCore/Simplify.hs @@ -24,7 +24,7 @@ import Id import MkId ( seqId ) import MkCore ( mkImpossibleExpr, castBottomExpr ) import IdInfo -import Name ( Name, mkSystemVarName, isExternalName, getOccFS ) +import Name ( mkSystemVarName, isExternalName, getOccFS ) import Coercion hiding ( substCo, substCoVar ) import OptCoercion ( optCoercion ) import FamInstEnv ( topNormaliseType_maybe ) @@ -143,11 +143,11 @@ simplTopBinds env0 binds0 ; (floats, env2) <- simpl_binds env1 binds ; return (float `addFloats` floats, env2) } - simpl_bind env (Rec pairs) = simplRecBind env TopLevel Nothing pairs - simpl_bind env (NonRec b r) = do { (env', b') <- addBndrRules env b (lookupRecBndr env b) - ; simplRecOrTopPair env' TopLevel - NonRecursive Nothing - b b' r } + simpl_bind env (Rec pairs) + = simplRecBind env TopLevel Nothing pairs + simpl_bind env (NonRec b r) + = do { (env', b') <- addBndrRules env b (lookupRecBndr env b) Nothing + ; simplRecOrTopPair env' TopLevel NonRecursive Nothing b b' r } {- ************************************************************************ @@ -160,7 +160,7 @@ simplRecBind is used for * recursive bindings only -} -simplRecBind :: SimplEnv -> TopLevelFlag -> Maybe SimplCont +simplRecBind :: SimplEnv -> TopLevelFlag -> MaybeJoinCont -> [(InId, InExpr)] -> SimplM (SimplFloats, SimplEnv) simplRecBind env0 top_lvl mb_cont pairs0 @@ -171,7 +171,7 @@ simplRecBind env0 top_lvl mb_cont pairs0 add_rules :: SimplEnv -> (InBndr,InExpr) -> SimplM (SimplEnv, (InBndr, OutBndr, InExpr)) -- Add the (substituted) rules to the binder add_rules env (bndr, rhs) - = do { (env', bndr') <- addBndrRules env bndr (lookupRecBndr env bndr) + = do { (env', bndr') <- addBndrRules env bndr (lookupRecBndr env bndr) mb_cont ; return (env', (bndr, bndr', rhs)) } go env [] = return (emptyFloats env, env) @@ -191,7 +191,7 @@ It assumes the binder has already been simplified, but not its IdInfo. -} simplRecOrTopPair :: SimplEnv - -> TopLevelFlag -> RecFlag -> Maybe SimplCont + -> TopLevelFlag -> RecFlag -> MaybeJoinCont -> InId -> OutBndr -> InExpr -- Binder and rhs -> SimplM (SimplFloats, SimplEnv) @@ -616,7 +616,7 @@ Nor does it do the atomic-argument thing completeBind :: SimplEnv -> TopLevelFlag -- Flag stuck into unfolding - -> Maybe SimplCont -- Required only for join point + -> MaybeJoinCont -- Required only for join point -> InId -- Old binder -> OutId -> OutExpr -- New binder and RHS -> SimplM (SimplFloats, SimplEnv) @@ -645,7 +645,7 @@ completeBind env top_lvl mb_cont old_bndr new_bndr new_rhs -- Simplify the unfolding ; new_unfolding <- simplLetUnfolding env top_lvl mb_cont old_bndr - final_rhs old_unf + final_rhs (idType new_bndr) old_unf ; let final_bndr = addLetBndrInfo new_bndr new_arity is_bot new_unfolding @@ -1319,7 +1319,8 @@ simplLamBndr :: SimplEnv -> InBndr -> SimplM (SimplEnv, OutBndr) simplLamBndr env bndr | isId bndr && isFragileUnfolding old_unf -- Special case = do { (env1, bndr1) <- simplBinder env bndr - ; unf' <- simplStableUnfolding env1 NotTopLevel Nothing bndr old_unf + ; unf' <- simplStableUnfolding env1 NotTopLevel Nothing bndr + old_unf (idType bndr1) ; let bndr2 = bndr1 `setIdUnfolding` unf' ; return (modifyInScope env1 bndr2, bndr2) } @@ -1378,7 +1379,7 @@ simplNonRecE env bndr (rhs, rhs_se) (bndrs, body) cont | otherwise = ASSERT( not (isTyVar bndr) ) do { (env1, bndr1) <- simplNonRecBndr env bndr - ; (env2, bndr2) <- addBndrRules env1 bndr bndr1 + ; (env2, bndr2) <- addBndrRules env1 bndr bndr1 Nothing ; (floats1, env3) <- simplLazyBind env2 NotTopLevel NonRecursive bndr bndr2 rhs rhs_se ; (floats2, expr') <- simplLam env3 bndrs body cont ; return (floats1 `addFloats` floats2, expr') } @@ -1450,6 +1451,33 @@ Here it'd be far better to drop the unfolding and use the actual RHS. * * ********************************************************************* -} +{- Note [Rules and unfolding for join points] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Suppose we have + + simplExpr (join j x = rhs ) cont + ( {- RULE j (p:ps) = blah -} ) + ( {- StableUnfolding j = blah -} ) + (in blah ) + +Then we will push 'cont' into the rhs of 'j'. But we should *also* push +'cont' into the RHS of + * Any RULEs for j, e.g. generated by SpecConstr + * Any stable unfolding for j, e.g. the result of an INLINE pragma + +Simplifying rules and stable-unfoldings happens a bit after +simplifying the right-hand side, so we remember whether or not it +is a join point, and what 'cont' is, in a value of type MaybeJoinCont + +Trac #13900 wsa caused by forgetting to push 'cont' into the RHS +of a SpecConstr-generated RULE for a join point. +-} + +type MaybeJoinCont = Maybe SimplCont + -- Nothing => Not a join point + -- Just k => This is a join binding with continuation k + -- See Note [Rules and unfolding for join points] + simplNonRecJoinPoint :: SimplEnv -> InId -> InExpr -> InExpr -> SimplCont -> SimplM (SimplFloats, OutExpr) @@ -1465,7 +1493,7 @@ simplNonRecJoinPoint env bndr rhs body cont -- and wrap wrap_cont around the whole thing ; let res_ty = contResultType cont ; (env1, bndr1) <- simplNonRecJoinBndr env res_ty bndr - ; (env2, bndr2) <- addBndrRules env1 bndr bndr1 + ; (env2, bndr2) <- addBndrRules env1 bndr bndr1 (Just cont) ; (floats1, env3) <- simplJoinBind env2 cont bndr bndr2 rhs env ; (floats2, body') <- simplExprF env3 body cont ; return (floats1 `addFloats` floats2, body') } @@ -3235,13 +3263,13 @@ because we don't know its usage in each RHS separately -} simplLetUnfolding :: SimplEnv-> TopLevelFlag - -> Maybe SimplCont + -> MaybeJoinCont -> InId - -> OutExpr + -> OutExpr -> OutType -> Unfolding -> SimplM Unfolding -simplLetUnfolding env top_lvl cont_mb id new_rhs unf +simplLetUnfolding env top_lvl cont_mb id new_rhs rhs_ty unf | isStableUnfolding unf - = simplStableUnfolding env top_lvl cont_mb id unf + = simplStableUnfolding env top_lvl cont_mb id unf rhs_ty | isExitJoinId id = return noUnfolding -- see Note [Do not inline exit join points] | otherwise @@ -3265,26 +3293,26 @@ mkLetUnfolding dflags top_lvl src id new_rhs ------------------- simplStableUnfolding :: SimplEnv -> TopLevelFlag - -> Maybe SimplCont -- Just k => a join point with continuation k + -> MaybeJoinCont -- Just k => a join point with continuation k -> InId - -> Unfolding -> SimplM Unfolding + -> Unfolding -> OutType -> SimplM Unfolding -- Note [Setting the new unfolding] -simplStableUnfolding env top_lvl mb_cont id unf +simplStableUnfolding env top_lvl mb_cont id unf rhs_ty = case unf of NoUnfolding -> return unf BootUnfolding -> return unf OtherCon {} -> return unf DFunUnfolding { df_bndrs = bndrs, df_con = con, df_args = args } - -> do { (env', bndrs') <- simplBinders rule_env bndrs + -> do { (env', bndrs') <- simplBinders unf_env bndrs ; args' <- mapM (simplExpr env') args ; return (mkDFunUnfolding bndrs' con args') } CoreUnfolding { uf_tmpl = expr, uf_src = src, uf_guidance = guide } | isStableSource src - -> do { expr' <- case mb_cont of - Just cont -> simplJoinRhs rule_env id expr cont - Nothing -> simplExpr rule_env expr + -> do { expr' <- case mb_cont of -- See Note [Rules and unfolding for join points] + Just cont -> simplJoinRhs unf_env id expr cont + Nothing -> simplExprC unf_env expr (mkBoringStop rhs_ty) ; case guide of UnfWhen { ug_arity = arity, ug_unsat_ok = sat_ok } -- Happens for INLINE things -> let guide' = UnfWhen { ug_arity = arity, ug_unsat_ok = sat_ok @@ -3308,7 +3336,7 @@ simplStableUnfolding env top_lvl mb_cont id unf dflags = seDynFlags env is_top_lvl = isTopLevel top_lvl act = idInlineActivation id - rule_env = updMode (updModeForStableUnfoldings act) env + unf_env = updMode (updModeForStableUnfoldings act) env -- See Note [Simplifying inside stable unfoldings] in SimplUtils {- @@ -3350,20 +3378,24 @@ to apply in that function's own right-hand side. See Note [Forming Rec groups] in OccurAnal -} -addBndrRules :: SimplEnv -> InBndr -> OutBndr -> SimplM (SimplEnv, OutBndr) +addBndrRules :: SimplEnv -> InBndr -> OutBndr + -> MaybeJoinCont -- Just k for a join point binder + -- Nothing otherwise + -> SimplM (SimplEnv, OutBndr) -- Rules are added back into the bin -addBndrRules env in_id out_id +addBndrRules env in_id out_id mb_cont | null old_rules = return (env, out_id) | otherwise - = do { new_rules <- simplRules env (Just (idName out_id)) old_rules + = do { new_rules <- simplRules env (Just out_id) old_rules mb_cont ; let final_id = out_id `setIdSpecialisation` mkRuleInfo new_rules ; return (modifyInScope env final_id, final_id) } where old_rules = ruleInfoRules (idSpecialisation in_id) -simplRules :: SimplEnv -> Maybe Name -> [CoreRule] -> SimplM [CoreRule] -simplRules env mb_new_nm rules +simplRules :: SimplEnv -> Maybe OutId -> [CoreRule] + -> MaybeJoinCont -> SimplM [CoreRule] +simplRules env mb_new_id rules mb_cont = mapM simpl_rule rules where simpl_rule rule@(BuiltinRule {}) @@ -3373,11 +3405,29 @@ simplRules env mb_new_nm rules , ru_fn = fn_name, ru_rhs = rhs }) = do { (env', bndrs') <- simplBinders env bndrs ; let rhs_ty = substTy env' (exprType rhs) - rule_cont = mkBoringStop rhs_ty - rule_env = updMode updModeForRules env' + rhs_cont = case mb_cont of -- See Note [Rules and unfolding for join points] + Nothing -> mkBoringStop rhs_ty + Just cont -> ASSERT2( join_ok, bad_join_msg ) + cont + rule_env = updMode updModeForRules env' + fn_name' = case mb_new_id of + Just id -> idName id + Nothing -> fn_name + + -- join_ok is an assertion check that the join-arity of the + -- binder matches that of the rule, so that pushing the + -- continuation into the RHS makes sense + join_ok = case mb_new_id of + Just id | Just join_arity <- isJoinId_maybe id + -> length args == join_arity + _ -> False + bad_join_msg = vcat [ ppr mb_new_id, ppr rule + , ppr (fmap isJoinId_maybe mb_new_id) ] + ; args' <- mapM (simplExpr rule_env) args - ; rhs' <- simplExprC rule_env rhs rule_cont + ; rhs' <- simplExprC rule_env rhs rhs_cont ; return (rule { ru_bndrs = bndrs' - , ru_fn = mb_new_nm `orElse` fn_name + , ru_fn = fn_name' , ru_args = args' , ru_rhs = rhs' }) } + |