diff options
Diffstat (limited to 'compiler/simplCore')
-rw-r--r-- | compiler/simplCore/FloatIn.hs | 139 | ||||
-rw-r--r-- | compiler/simplCore/SimplUtils.hs | 7 | ||||
-rw-r--r-- | compiler/simplCore/Simplify.hs | 28 |
3 files changed, 135 insertions, 39 deletions
diff --git a/compiler/simplCore/FloatIn.hs b/compiler/simplCore/FloatIn.hs index 2593b1d7a1..04e4d32f5e 100644 --- a/compiler/simplCore/FloatIn.hs +++ b/compiler/simplCore/FloatIn.hs @@ -26,8 +26,9 @@ import MkCore import HscTypes ( ModGuts(..) ) import CoreUtils import CoreFVs +import CoreUnfold import CoreMonad ( CoreM ) -import Id ( isOneShotBndr, idType, isJoinId, isJoinId_maybe ) +import Id import Var import Type import VarSet @@ -151,7 +152,7 @@ fiExpr dflags to_drop (_, AnnCast expr (co_ann, co)) Cast (fiExpr dflags e_drop expr) co where [drop_here, e_drop, co_drop] - = sepBindsByDropPoint dflags False + = sepBindsByDropPoint dflags SepVanilla [freeVarsOf expr, freeVarsOfAnn co_ann] to_drop @@ -173,7 +174,7 @@ fiExpr dflags to_drop ann_expr@(_,AnnApp {}) arg_fvs = map freeVarsOf ann_args (drop_here : extra_drop : fun_drop : arg_drops) - = sepBindsByDropPoint dflags False + = sepBindsByDropPoint dflags SepVanilla (extra_fvs : fun_fvs : arg_fvs) to_drop -- Shortcut behaviour: if to_drop is empty, @@ -446,7 +447,7 @@ fiExpr dflags to_drop (_, AnnCase scrut case_bndr _ [(con,alt_bndrs,rhs)]) scrut_fvs = freeVarsOf scrut [shared_binds, scrut_binds, rhs_binds] - = sepBindsByDropPoint dflags False + = sepBindsByDropPoint dflags SepVanilla [scrut_fvs, rhs_fvs] to_drop @@ -456,16 +457,17 @@ fiExpr dflags to_drop (_, AnnCase scrut case_bndr ty alts) Case (fiExpr dflags scrut_drops scrut) case_bndr ty (zipWith fi_alt alts_drops_s alts) where - -- Float into the scrut and alts-considered-together just like App + -- Float into the scrut and alts-considered-together just like App [drop_here1, scrut_drops, alts_drops] - = sepBindsByDropPoint dflags False + = sepBindsByDropPoint dflags SepVanilla [scrut_fvs, all_alts_fvs] to_drop - -- Float into the alts with the is_case flag set + -- Float into the alts with the SepCase context set (drop_here2 : alts_drops_s) | [ _ ] <- alts = [] : [alts_drops] - | otherwise = sepBindsByDropPoint dflags True alts_fvs alts_drops + | 
otherwise = sepBindsByDropPoint dflags SepCase + alts_fvs alts_drops scrut_fvs = freeVarsOf scrut alts_fvs = map alt_fvs alts @@ -491,7 +493,7 @@ fiBind dflags to_drop (AnnNonRec id ann_rhs@(rhs_fvs, rhs)) body_fvs = ( extra_binds ++ shared_binds -- Land these before -- See Note [extra_fvs (1,2)] , FB (unitDVarSet id) rhs_fvs' -- The new binding itself - (FloatLet (NonRec id rhs')) + (FloatLet (NonRec id rhs')) , body_binds ) -- Land these after where @@ -508,7 +510,8 @@ fiBind dflags to_drop (AnnNonRec id ann_rhs@(rhs_fvs, rhs)) body_fvs -- But do float into join points [shared_binds, extra_binds, rhs_binds, body_binds] - = sepBindsByDropPoint dflags False + = sepBindsByDropPoint dflags + (if isJoinId id then SepNonRecJoin else SepVanilla) [extra_fvs, rhs_fvs, body_fvs2] to_drop @@ -533,7 +536,7 @@ fiBind dflags to_drop (AnnRec bindings) body_fvs , noFloatIntoRhs Recursive bndr rhs ] (shared_binds:extra_binds:body_binds:rhss_binds) - = sepBindsByDropPoint dflags False + = sepBindsByDropPoint dflags SepVanilla (extra_fvs:body_fvs:rhss_fvs) to_drop @@ -654,9 +657,41 @@ We have to maintain the order on these drop-point-related lists. -- pprFIB :: FloatInBinds -> SDoc -- pprFIB fibs = text "FIB:" <+> ppr [b | FB _ _ b <- fibs] +data SepCtxt + = SepCase + | SepNonRecJoin + | SepVanilla + +{- Note [Floating join points] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +We push join-point bindings inwards merrily, just like let-bindings. +They may get floated out again; e.g. + join j1 x = e1 + in join j2 y = ...j1... + in ... +==> + join j2 y = join { j1 x = e1 } in ...j1... + in ... + +Here we might float j1 out again. But we must float it in in case it +allows an ordinary let-binding to go too. E.g. + let x = <thunk> + in join j1 x = e1 + in join j2 y = ...j1... + in ... +===> + join j2 y = let { x = <thunk> } + in join { j1 x = e1 } + in ...j1... + in ... + +This is important; now the thunk for 'x' may not be allocated on the +paths that don't involve j2. 
+-} + sepBindsByDropPoint :: DynFlags - -> Bool -- True <=> is case expression + -> SepCtxt -> [FreeVarSet] -- One set of FVs per drop point -- Always at least two long! -> FloatInBinds -- Candidate floaters @@ -672,15 +707,15 @@ sepBindsByDropPoint type DropBox = (FreeVarSet, FloatInBinds) -sepBindsByDropPoint dflags is_case drop_pts floaters +sepBindsByDropPoint dflags sep_ctxt drop_pts floaters | null floaters -- Shortcut common case = [] : [[] | _ <- drop_pts] | otherwise - = ASSERT( drop_pts `lengthAtLeast` 2 ) + = ASSERT( n_alts >= 2 ) -- Invariant on caller go floaters (map (\fvs -> (fvs, [])) (emptyDVarSet : drop_pts)) where - n_alts = length drop_pts + n_alts = length drop_pts -- n_alts >= 2 go :: FloatInBinds -> [DropBox] -> [FloatInBinds] -- The *first* one in the argument list is the drop_here set @@ -697,18 +732,29 @@ sepBindsByDropPoint dflags is_case drop_pts floaters (used_here : used_in_flags) = [ fvs `intersectsDVarSet` bndrs | (fvs, _) <- drop_boxes] - drop_here = used_here || cant_push + drop_here = used_here || not want_push + want_push = case sep_ctxt of + SepCase -> want_case_push + SepNonRecJoin -> want_join_push + SepVanilla -> want_let_push n_used_alts = count id used_in_flags -- returns number of Trues in list. - cant_push - | is_case = n_used_alts == n_alts -- Used in all, don't push + no_duplication = n_used_alts <= 1 -- See Note [Duplicating floats] + + duplicable_float = floatIsDupable dflags bind + -- True <=> duplication does not dup much code + -- (but it might still duplicate work!) 
+ -- See Note [Duplicating floats] + + want_case_push = -- n_used_alts < n_alts && -- Used in all alts, don't push -- Remember n_alts > 1 - || (n_used_alts > 1 && not (floatIsDupable dflags bind)) - -- floatIsDupable: see Note [Duplicating floats] + (no_duplication || duplicable_float) - | otherwise = floatIsCase bind || n_used_alts > 1 - -- floatIsCase: see Note [Floating primops] + want_let_push = not (floatIsCase bind) -- See Note [Floating primops] + && no_duplication + + want_join_push = no_duplication -- See Note [Floating join points] new_boxes | drop_here = (insert here_box : fork_boxes) | otherwise = (here_box : new_fork_boxes) @@ -727,18 +773,43 @@ sepBindsByDropPoint dflags is_case drop_pts floaters {- Note [Duplicating floats] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +no_duplication is true if the binding is used in at most one +alternative. (Zero is rare; it means the binding is dead.) + +If no_duplication is false, we may still float: -For case expressions we duplicate the binding if it is reasonably -small, and if it is not used in all the RHSs This is good for -situations like +* For /case expressions/ only (SepCase) we duplicate the binding if it + is reasonably small, and if it is not used in all the RHSs. This is + good for situations like let x = I# y in case e of C -> error x D -> error x E -> ...not mentioning x... -If the thing is used in all RHSs there is nothing gained, -so we don't duplicate then. + If the thing is used in all RHSs there is nothing gained, + so we don't duplicate then. + +* This is NOT GOOD for other float-in places, like lets (SepVanilla). + Consider + let x = <small> in + let v = ...x... + in ...x... + + We definitely don't want to duplicate x into the RHS of v and the + body! At least, it would be OK if <small> was a value; but we don't + test that. + +* For non-recursive join bindings (SepNonRecJoin) we must be equally + careful. Eg + let x = <small> in + join j = ...x... 
+ in case f x of + A -> j + B -> something else + C -> j + Here we must not duplicate the let-x binding into the RHS of j + and the body, or we'll duplicate the redex. -} floatedBindsFVs :: FloatInBinds -> FreeVarSet @@ -754,9 +825,19 @@ wrapFloats (FB _ _ fl : bs) e = wrapFloats bs (wrapFloat fl e) floatIsDupable :: DynFlags -> FloatBind -> Bool floatIsDupable dflags (FloatCase scrut _ _ _) = exprIsDupable dflags scrut -floatIsDupable dflags (FloatLet (Rec prs)) = all (exprIsDupable dflags . snd) prs -floatIsDupable dflags (FloatLet (NonRec _ r)) = exprIsDupable dflags r +floatIsDupable dflags (FloatLet (Rec prs)) = -- all (exprIsDupable dflags . snd) prs + all (smallEnough dflags) prs +floatIsDupable dflags (FloatLet (NonRec b r)) = -- exprIsDupable dflags r + smallEnough dflags (b,r) + +smallEnough :: DynFlags -> (Id,CoreExpr) -> Bool +smallEnough dflags (_,rhs) + = couldBeSmallEnoughToInline dflags (ufUseThreshold dflags) rhs floatIsCase :: FloatBind -> Bool floatIsCase (FloatCase {}) = True floatIsCase (FloatLet {}) = False + +--floatIsJoin :: FloatBind -> Bool +--floatIsJoin (FloatCase {}) = False +--floatIsJoin (FloatLet b) = isJoinBind b diff --git a/compiler/simplCore/SimplUtils.hs b/compiler/simplCore/SimplUtils.hs index ca1b9bd23d..82d20e20c5 100644 --- a/compiler/simplCore/SimplUtils.hs +++ b/compiler/simplCore/SimplUtils.hs @@ -1271,6 +1271,7 @@ postInlineUnconditionally env top_lvl bndr occ_info rhs | exprIsTrivial rhs = True | otherwise = case occ_info of +{- -- The point of examining occ_info here is that for *non-values* -- that occur outside a lambda, the call-site inliner won't have -- a chance (because it doesn't know that the thing @@ -1285,7 +1286,8 @@ postInlineUnconditionally env top_lvl bndr occ_info rhs -- in allocation if you miss this out OneOcc { occ_in_lam = in_lam, occ_int_cxt = int_cxt } -- OneOcc => no code-duplication issue - -> smallEnoughToInline dflags unfolding -- Small enough to dup + -> not (isJoinId bndr) -- NEW! 
+ && smallEnoughToInline dflags unfolding -- Small enough to dup -- ToDo: consider discount on smallEnoughToInline if int_cxt is true -- -- NB: Do NOT inline arbitrarily big things, even if one_br is True @@ -1310,6 +1312,7 @@ postInlineUnconditionally env top_lvl bndr occ_info rhs -- int_cxt to prevent us inlining inside a lambda without some -- good reason. See the notes on int_cxt in preInlineUnconditionally +-} IAmDead -> True -- This happens; for example, the case_bndr during case of -- known constructor: case (a,b) of x { (p,q) -> ... } -- Here x isn't mentioned in the RHS, so we don't want to @@ -1331,7 +1334,7 @@ postInlineUnconditionally env top_lvl bndr occ_info rhs where unfolding = idUnfolding bndr - dflags = seDynFlags env + _dflags = seDynFlags env active = isActive (sm_phase (getMode env)) (idInlineActivation bndr) -- See Note [pre/postInlineUnconditionally in gentle mode] diff --git a/compiler/simplCore/Simplify.hs b/compiler/simplCore/Simplify.hs index 872973925f..e3237bfcee 100644 --- a/compiler/simplCore/Simplify.hs +++ b/compiler/simplCore/Simplify.hs @@ -45,7 +45,7 @@ import BasicTypes ( TopLevelFlag(..), isNotTopLevel, isTopLevel, RecFlag(..), Arity ) import MonadUtils ( mapAccumLM, liftIO ) import Var ( isTyCoVar ) -import Maybes ( orElse ) +import Maybes ( isJust, orElse ) import Control.Monad import Outputable import FastString @@ -326,6 +326,7 @@ simplNonRecX :: SimplEnv -- simplified, notably in knownCon. It uses case-binding where necessary. 
-- -- Precondition: rhs satisfies the let/app invariant +-- Not used for JoinIds simplNonRecX env bndr new_rhs | ASSERT2( not (isJoinId bndr), ppr bndr ) @@ -350,6 +351,7 @@ completeNonRecX :: TopLevelFlag -> SimplEnv -> SimplM (SimplFloats, SimplEnv) -- The new binding is in the floats -- Precondition: rhs satisfies the let/app invariant -- See Note [CoreSyn let/app invariant] in CoreSyn +-- Not used for JoinIds completeNonRecX top_lvl env is_strict old_bndr new_bndr new_rhs = ASSERT2( not (isJoinId new_bndr), ppr new_bndr ) @@ -549,7 +551,7 @@ makeTrivialWithInfo mode top_lvl occ_fs info expr -- Now something very like completeBind, -- but without the postInlineUnconditinoally part ; (arity, is_bot, expr2) <- tryEtaExpandRhs mode var expr1 - ; unf <- mkLetUnfolding (sm_dflags mode) top_lvl InlineRhs var expr2 + ; unf <- simplVanillaUnfolding (sm_dflags mode) top_lvl InlineRhs var expr2 ; let final_id = addLetBndrInfo var arity is_bot unf bind = NonRec final_id expr2 @@ -3390,15 +3392,25 @@ simplLetUnfolding :: SimplEnv-> TopLevelFlag simplLetUnfolding env top_lvl cont_mb id new_rhs rhs_ty unf | isStableUnfolding unf = simplStableUnfolding env top_lvl cont_mb id unf rhs_ty - | isExitJoinId id + + | isJust cont_mb -- A join point + = simplJoinUnfolding env id new_rhs + + | otherwise + = simplVanillaUnfolding (seDynFlags env) top_lvl InlineRhs id new_rhs + +------------------- +simplJoinUnfolding :: SimplEnv -> InId -> OutExpr -> SimplM Unfolding +simplJoinUnfolding env join_id new_rhs + | isExitJoinId join_id = return noUnfolding -- See Note [Do not inline exit join points] in Exitify | otherwise - = mkLetUnfolding (seDynFlags env) top_lvl InlineRhs id new_rhs + = return (mkJoinUnfolding (seDynFlags env) new_rhs) ------------------- -mkLetUnfolding :: DynFlags -> TopLevelFlag -> UnfoldingSource - -> InId -> OutExpr -> SimplM Unfolding -mkLetUnfolding dflags top_lvl src id new_rhs +simplVanillaUnfolding :: DynFlags -> TopLevelFlag -> UnfoldingSource + -> InId -> 
OutExpr -> SimplM Unfolding +simplVanillaUnfolding dflags top_lvl src id new_rhs = is_bottoming `seq` -- See Note [Force bottoming field] return (mkUnfolding dflags src is_top_lvl is_bottoming new_rhs) -- We make an unfolding *even for loop-breakers*. @@ -3456,7 +3468,7 @@ simplStableUnfolding env top_lvl mb_cont id unf rhs_ty -- See Note [Top-level flag on inline rules] in CoreUnfold _other -- Happens for INLINABLE things - -> mkLetUnfolding dflags top_lvl src id expr' } + -> simplVanillaUnfolding dflags top_lvl src id expr' } -- If the guidance is UnfIfGoodArgs, this is an INLINABLE -- unfolding, and we need to make sure the guidance is kept up -- to date with respect to any changes in the unfolding. |