diff options
| author | Simon Peyton Jones <simonpj@microsoft.com> | 2011-06-27 08:54:29 +0100 |
|---|---|---|
| committer | Simon Peyton Jones <simonpj@microsoft.com> | 2011-06-27 09:15:50 +0100 |
| commit | 9cb20b488d4986c122b0461a54bc5c970f9d8502 (patch) | |
| tree | b73b5b873fbba9d5e11f5d565ab11156e128e081 | |
| parent | 6ed586542afab8b18bcb45ac649a89e13e258643 (diff) | |
| download | haskell-9cb20b488d4986c122b0461a54bc5c970f9d8502.tar.gz | |
Add case-floating to the float-out pass
There are two things in this patch. First, a new feature.
Given (case x of I# y -> ...)
where 'x' is known to be evaluated, the float-out pass
will float the case outwards towards x's binding. Of
course this doesn't happen if 'x' is evaluated because
of an enclosing case (becuase then the inner case would
be eliminated) but it *does* happen when x is bound by
a constructor with a strict field. This happens in DPH.
Trac #4081.
The second change is a significant refactoring of the
way the let-floater works. Now SetLevels makes a decision
about whether the let (or case) will move, and records
that decision in the FloatSpec flag. This change makes
the whole caboodle much easier to think about.
| -rw-r--r-- | compiler/simplCore/FloatOut.lhs | 295 | ||||
| -rw-r--r-- | compiler/simplCore/SetLevels.lhs | 302 |
2 files changed, 342 insertions, 255 deletions
diff --git a/compiler/simplCore/FloatOut.lhs b/compiler/simplCore/FloatOut.lhs index e5db7d93ce..cf2d7245a7 100644 --- a/compiler/simplCore/FloatOut.lhs +++ b/compiler/simplCore/FloatOut.lhs @@ -16,10 +16,10 @@ import CoreMonad ( FloatOutSwitches(..) ) import DynFlags ( DynFlags, DynFlag(..) ) import ErrUtils ( dumpIfSet_dyn ) import CostCentre ( dupifyCC, CostCentre ) -import Id ( Id, idType, idArity, isBottomingId ) -import Type ( isUnLiftedType ) -import SetLevels ( Level(..), LevelledExpr, LevelledBind, - setLevels, isTopLvl ) +import DataCon ( DataCon ) +import Id ( Id, idArity, isBottomingId ) +import Var ( Var ) +import SetLevels import UniqSupply ( UniqSupply ) import Bag import Util @@ -132,13 +132,16 @@ floatOutwards float_sws dflags us pgm int ntlets, ptext (sLit " Lets floated elsewhere; from "), int lams, ptext (sLit " Lambda groups")]); - return (concat binds_s') + return (bagToList (unionManyBags binds_s')) } -floatTopBind :: LevelledBind -> (FloatStats, [CoreBind]) +floatTopBind :: LevelledBind -> (FloatStats, Bag CoreBind) floatTopBind bind - = case (floatBind bind) of { (fs, floats) -> - (fs, bagToList (flattenFloats floats)) } + = case (floatBind bind) of { (fs, floats, bind') -> + let float_bag = flattenTopFloats floats + in case bind' of + Rec prs -> (fs, unitBag (Rec (addTopFloatPairs float_bag prs))) + NonRec {} -> (fs, float_bag `snocBag` bind') } \end{code} %************************************************************************ @@ -148,45 +151,30 @@ floatTopBind bind %************************************************************************ \begin{code} -floatBind :: LevelledBind -> (FloatStats, FloatBinds) -floatBind (NonRec (TB var level) rhs) - = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') -> +floatBind :: LevelledBind -> (FloatStats, FloatBinds, CoreBind) +floatBind (NonRec (TB var _) rhs) + = case (floatExpr rhs) of { (fs, rhs_floats, rhs') -> -- A tiresome hack: -- see Note [Bottoming floats: eta expansion] in SetLevels let rhs'' | isBottomingId var = etaExpand (idArity var) rhs' | otherwise = rhs' - in (fs, rhs_floats `plusFloats` unitFloat level (NonRec var rhs'')) } + in (fs, rhs_floats, NonRec var rhs'') } floatBind (Rec pairs) = case floatList do_pair pairs of { (fs, rhs_floats, new_pairs) -> - -- NB: the rhs floats may contain references to the - -- bound things. For example - -- f = ...(let v = ...f... in b) ... - if not (isTopLvl dest_lvl) then - -- Find which bindings float out at least one lambda beyond this one - -- These ones can't mention the binders, because they couldn't - -- be escaping a major level if so. - -- The ones that are not going further can join the letrec; - -- they may not be mutually recursive but the occurrence analyser will - -- find that out. In our example we make a Rec thus: - -- v = ...f... - -- f = ... b ... - case (partitionByMajorLevel dest_lvl rhs_floats) of { (floats', heres) -> - (fs, floats' `plusFloats` unitFloat dest_lvl - (Rec (floatsToBindPairs heres new_pairs))) } - else - -- For top level, no need to partition; just make them all recursive - -- (And the partition wouldn't work because they'd all end up in floats') - (fs, unitFloat dest_lvl - (Rec (floatsToBindPairs (flattenFloats rhs_floats) new_pairs))) } + (fs, rhs_floats, Rec (concat new_pairs)) } where - (((TB _ dest_lvl), _) : _) = pairs - - do_pair (TB name level, rhs) - = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') -> - (fs, rhs_floats, (name, rhs')) } + do_pair (TB name spec, rhs) + | isTopLvl dest_lvl -- See Note [floatBind for top level] + = case (floatExpr rhs) of { (fs, rhs_floats, rhs') -> + (fs, emptyFloats, addTopFloatPairs (flattenTopFloats rhs_floats) [(name, rhs')])} + | otherwise + = case (floatBody dest_lvl rhs) of { (fs, rhs_floats, rhs') -> + (fs, rhs_floats, [(name, rhs')]) } + where + dest_lvl = floatSpecLevel spec --------------- floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b]) @@ -196,6 +184,16 @@ floatList f (a:as) = case f a of { (fs_a, binds_a, b) -> (fs_a `add_stats` fs_as, binds_a `plusFloats` binds_as, b:bs) }} \end{code} +Note [floatBind for top level] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +We may have a *nested* binding whose destination level is (FloatMe tOP_LEVEL), thus + letrec { foo <0,0> = .... (let bar<0,0> = .. in ..) .... } +The binding for bar will be in the "tops" part of the floating binds, +and thus not partioned by floatBody. + +We could perhaps get rid of the 'tops' component of the floating binds, +but this case works just as well. + %************************************************************************ @@ -204,94 +202,100 @@ floatList f (a:as) = case f a of { (fs_a, binds_a, b) -> %************************************************************************ \begin{code} -floatExpr, floatRhs, floatCaseAlt - :: Level - -> LevelledExpr - -> (FloatStats, FloatBinds, CoreExpr) - -floatCaseAlt lvl arg -- Used rec rhss, and case-alternative rhss - = case (floatExpr lvl arg) of { (fsa, floats, arg') -> - case (partitionByMajorLevel lvl floats) of { (floats', heres) -> - -- Dump bindings that aren't going to escape from a lambda; - -- in particular, we must dump the ones that are bound by - -- the rec or case alternative +floatBody :: Level + -> LevelledExpr + -> (FloatStats, FloatBinds, CoreExpr) + +floatBody lvl arg -- Used rec rhss, and case-alternative rhss + = case (floatExpr arg) of { (fsa, floats, arg') -> + case (partitionByLevel lvl floats) of { (floats', heres) -> + -- Dump bindings are bound here (fsa, floats', install heres arg') }} ----------------- -floatRhs lvl arg -- Used for nested non-rec rhss, and fn args - -- See Note [Floating out of RHS] - = floatExpr lvl arg - ------------------ -floatExpr _ (Var v) = (zeroStats, emptyFloats, Var v) -floatExpr _ (Type ty) = (zeroStats, emptyFloats, Type ty) -floatExpr _ (Coercion co) = (zeroStats, emptyFloats, Coercion co) -floatExpr _ (Lit lit) = (zeroStats, emptyFloats, Lit lit) +floatExpr :: LevelledExpr + -> (FloatStats, FloatBinds, CoreExpr) +floatExpr (Var v) = (zeroStats, emptyFloats, Var v) +floatExpr (Type ty) = (zeroStats, emptyFloats, Type ty) +floatExpr (Coercion co) = (zeroStats, emptyFloats, Coercion co) +floatExpr (Lit lit) = (zeroStats, emptyFloats, Lit lit) -floatExpr lvl (App e a) - = case (floatExpr lvl e) of { (fse, floats_e, e') -> - case (floatRhs lvl a) of { (fsa, floats_a, a') -> +floatExpr (App e a) + = case (floatExpr e) of { (fse, floats_e, e') -> + case (floatExpr a) of { (fsa, floats_a, a') -> (fse `add_stats` fsa, floats_e `plusFloats` floats_a, App e' a') }} -floatExpr _ lam@(Lam (TB _ lam_lvl) _) +floatExpr lam@(Lam (TB _ lam_spec) _) = let (bndrs_w_lvls, body) = collectBinders lam bndrs = [b | TB b _ <- bndrs_w_lvls] + bndr_lvl = floatSpecLevel lam_spec -- All the binders have the same level -- See SetLevels.lvlLamBndrs in - case (floatExpr lam_lvl body) of { (fs, floats, body1) -> - - -- Dump anything that is captured by this lambda - -- Eg \x -> ...(\y -> let v = <blah> in ...)... - -- We'll have the binding (v = <blah>) in the floats, - -- but must dump it at the lambda-x - case (partitionByLevel lam_lvl floats) of { (floats1, heres) -> - (add_to_stats fs floats1, floats1, mkLams bndrs (install heres body1)) - }} - -floatExpr lvl (Note note@(SCC cc) expr) - = case (floatExpr lvl expr) of { (fs, floating_defns, expr') -> + case (floatBody bndr_lvl body) of { (fs, floats, body') -> + (add_to_stats fs floats, floats, mkLams bndrs body') } + +floatExpr (Note note@(SCC cc) expr) + = case (floatExpr expr) of { (fs, floating_defns, expr') -> let -- Annotate bindings floated outwards past an scc expression -- with the cc. We mark that cc as "duplicated", though. - annotated_defns = wrapCostCentre (dupifyCC cc) floating_defns in (fs, annotated_defns, Note note expr') } -floatExpr lvl (Note note expr) -- Other than SCCs - = case (floatExpr lvl expr) of { (fs, floating_defns, expr') -> +floatExpr (Note note expr) -- Other than SCCs + = case (floatExpr expr) of { (fs, floating_defns, expr') -> (fs, floating_defns, Note note expr') } -floatExpr lvl (Cast expr co) - = case (floatExpr lvl expr) of { (fs, floating_defns, expr') -> +floatExpr (Cast expr co) + = case (floatExpr expr) of { (fs, floating_defns, expr') -> (fs, floating_defns, Cast expr' co) } -floatExpr lvl (Let (NonRec (TB bndr bndr_lvl) rhs) body) - | isUnLiftedType (idType bndr) -- Treat unlifted lets just like a case - -- I.e. floatExpr for rhs, floatCaseAlt for body - = case floatExpr lvl rhs of { (_, rhs_floats, rhs') -> - case floatCaseAlt bndr_lvl body of { (fs, body_floats, body') -> - (fs, rhs_floats `plusFloats` body_floats, Let (NonRec bndr rhs') body') }} - -floatExpr lvl (Let bind body) - = case (floatBind bind) of { (fsb, bind_floats) -> - case (floatExpr lvl body) of { (fse, body_floats, body') -> - case partitionByMajorLevel lvl (bind_floats `plusFloats` body_floats) - of { (floats, heres) -> - -- See Note [Avoiding unnecessary floating] - (add_stats fsb fse, floats, install heres body') } } } - -floatExpr lvl (Case scrut (TB case_bndr case_lvl) ty alts) - = case floatExpr lvl scrut of { (fse, fde, scrut') -> - case floatList float_alt alts of { (fsa, fda, alts') -> - (add_stats fse fsa, fda `plusFloats` fde, Case scrut' case_bndr ty alts') - }} +floatExpr (Let bind body) + = case bind_spec of + FloatMe dest_lvl + -> case (floatBind bind) of { (fsb, bind_floats, bind') -> + case (floatExpr body) of { (fse, body_floats, body') -> + ( add_stats fsb fse + , bind_floats `plusFloats` unitLetFloat dest_lvl bind' + `plusFloats` body_floats + , body') }} + + StayPut bind_lvl -- See Note [Avoiding unnecessary floating] + -> case (floatBind bind) of { (fsb, bind_floats, bind') -> + case (floatBody bind_lvl body) of { (fse, body_floats, body') -> + ( add_stats fsb fse + , bind_floats `plusFloats` body_floats + , Let bind' body') }} + where + bind_spec = case bind of + NonRec (TB _ s) _ -> s + Rec ((TB _ s, _) : _) -> s + Rec [] -> panic "floatExpr:rec" + +floatExpr (Case scrut (TB case_bndr case_spec) ty alts) + = case case_spec of + FloatMe dest_lvl -- Case expression moves + | [(DataAlt con, bndrs, rhs)] <- alts + -> case floatExpr scrut of { (fse, fde, scrut') -> + case floatExpr rhs of { (fsb, fdb, rhs') -> + let + float = unitCaseFloat dest_lvl scrut' + case_bndr con [b | TB b _ <- bndrs] + in + (add_stats fse fsb, fde `plusFloats` float `plusFloats` fdb, rhs') }} + | otherwise + -> pprPanic "Floating multi-case" (ppr alts) + + StayPut bind_lvl -- Case expression stays put + -> case floatExpr scrut of { (fse, fde, scrut') -> + case floatList (float_alt bind_lvl) alts of { (fsa, fda, alts') -> + (add_stats fse fsa, fda `plusFloats` fde, Case scrut' case_bndr ty alts') + }} where - -- Use floatCaseAlt for the alternatives, so that we - -- don't gratuitiously float bindings out of the RHSs - float_alt (con, bs, rhs) - = case (floatCaseAlt case_lvl rhs) of { (fs, rhs_floats, rhs') -> + float_alt bind_lvl (con, bs, rhs) + = case (floatBody bind_lvl rhs) of { (fs, rhs_floats, rhs') -> (fs, rhs_floats, (con, [b | TB b _ <- bs], rhs')) } \end{code} @@ -391,22 +395,40 @@ partitionByMajorLevel. \begin{code} -type FloatBind = CoreBind -- INVARIANT: a FloatBind is always lifted +data FloatBind + = FloatLet FloatLet + | FloatCase CoreExpr Id DataCon [Var] -- case e of y { C ys -> ... } -data FloatBinds = FB !(Bag FloatBind) -- Destined for top level - !MajorEnv -- Levels other than top - -- See Note [Representation of FloatBinds] +type FloatLet = CoreBind -- INVARIANT: a FloatLet is always lifted +type MajorEnv = M.IntMap MinorEnv -- Keyed by major level +type MinorEnv = M.IntMap (Bag FloatBind) -- Keyed by minor level -instance Outputable FloatBinds where - ppr (FB fbs env) = ptext (sLit "FB") <+> (braces $ vcat - [ ptext (sLit "binds =") <+> ppr fbs - , ptext (sLit "env =") <+> ppr env ]) +data FloatBinds = FB !(Bag FloatLet) -- Destined for top level + !MajorEnv -- Levels other than top + -- See Note [Representation of FloatBinds] -type MajorEnv = M.IntMap MinorEnv -- Keyed by major level -type MinorEnv = M.IntMap (Bag FloatBind) -- Keyed by minor level +instance Outputable FloatBind where + ppr (FloatLet b) = ptext (sLit "LET") <+> ppr b + ppr (FloatCase e b c bs) = hang (ptext (sLit "CASE") <+> ppr e <+> ptext (sLit "of") <+> ppr b) + 2 (ppr c <+> ppr bs) -flattenFloats :: FloatBinds -> Bag FloatBind -flattenFloats (FB tops others) = tops `unionBags` flattenMajor others +instance Outputable FloatBinds where + ppr (FB fbs defs) + = ptext (sLit "FB") <+> (braces $ vcat + [ ptext (sLit "tops =") <+> ppr fbs + , ptext (sLit "non-tops =") <+> ppr defs ]) + +flattenTopFloats :: FloatBinds -> Bag CoreBind +flattenTopFloats (FB tops defs) + = ASSERT2( isEmptyBag (flattenMajor defs), ppr defs ) + tops + +addTopFloatPairs :: Bag CoreBind -> [(Id,CoreExpr)] -> [(Id,CoreExpr)] +addTopFloatPairs float_bag prs + = foldrBag add prs float_bag + where + add (NonRec b r) prs = (b,r):prs + add (Rec prs1) prs2 = prs1 ++ prs2 flattenMajor :: MajorEnv -> Bag FloatBind flattenMajor = M.fold (unionBags . flattenMinor) emptyBag @@ -417,13 +439,20 @@ flattenMinor = M.fold unionBags emptyBag emptyFloats :: FloatBinds emptyFloats = FB emptyBag M.empty -unitFloat :: Level -> FloatBind -> FloatBinds -unitFloat lvl@(Level major minor) b +unitCaseFloat :: Level -> CoreExpr -> Id -> DataCon -> [Var] -> FloatBinds +unitCaseFloat (Level major minor) e b con bs + = FB emptyBag (M.singleton major (M.singleton minor (unitBag (FloatCase e b con bs)))) + +unitLetFloat :: Level -> FloatLet -> FloatBinds +unitLetFloat lvl@(Level major minor) b | isTopLvl lvl = FB (unitBag b) M.empty - | otherwise = FB emptyBag (M.singleton major (M.singleton minor (unitBag b))) + | otherwise = FB emptyBag (M.singleton major (M.singleton minor floats)) + where + floats = unitBag (FloatLet b) plusFloats :: FloatBinds -> FloatBinds -> FloatBinds -plusFloats (FB t1 b1) (FB t2 b2) = FB (t1 `unionBags` t2) (b1 `plusMajor` b2) +plusFloats (FB t1 l1) (FB t2 l2) + = FB (t1 `unionBags` t2) (l1 `plusMajor` l2) plusMajor :: MajorEnv -> MajorEnv -> MajorEnv plusMajor = M.unionWith plusMinor @@ -431,26 +460,27 @@ plusMajor = M.unionWith plusMinor plusMinor :: MinorEnv -> MinorEnv -> MinorEnv plusMinor = M.unionWith unionBags -floatsToBindPairs :: Bag FloatBind -> [(Id,CoreExpr)] -> [(Id,CoreExpr)] -floatsToBindPairs floats binds = foldrBag add binds floats - where - add (Rec pairs) binds = pairs ++ binds - add (NonRec binder rhs) binds = (binder,rhs) : binds - install :: Bag FloatBind -> CoreExpr -> CoreExpr install defn_groups expr = foldrBag install_group expr defn_groups where - install_group defns body = Let defns body + install_group (FloatLet defns) body + = Let defns body + install_group (FloatCase e b con bs) body + = Case e b (exprType body) [(DataAlt con, bs, body)] -partitionByMajorLevel, partitionByLevel +partitionByLevel :: Level -- Partitioning level -> FloatBinds -- Defns to be divided into 2 piles... -> (FloatBinds, -- Defns with level strictly < partition level, Bag FloatBind) -- The rest +{- -- ---- partitionByMajorLevel ---- --- Float it if we escape a value lambda, *or* if we get to the top level +-- Float it if we escape a value lambda, +-- *or* if we get to the top level +-- *or* if it's a case-float and its minor level is < current +-- -- If we can get to the top level, say "yes" anyway. This means that -- x = f e -- transforms to @@ -465,6 +495,7 @@ partitionByMajorLevel (Level major _) (FB tops defns) heres = case mb_heres of Nothing -> emptyBag Just h -> flattenMinor h +-} partitionByLevel (Level major minor) (FB tops defns) = (FB tops (outer_maj `plusMajor` M.singleton major outer_min), @@ -480,9 +511,13 @@ partitionByLevel (Level major minor) (FB tops defns) wrapCostCentre :: CostCentre -> FloatBinds -> FloatBinds wrapCostCentre cc (FB tops defns) - = FB (wrap_defns tops) (M.map (M.map wrap_defns) defns) + = FB (mapBag wrap_bind tops) (M.map (M.map wrap_defns) defns) where wrap_defns = mapBag wrap_one - wrap_one (NonRec binder rhs) = NonRec binder (mkSCC cc rhs) - wrap_one (Rec pairs) = Rec (mapSnd (mkSCC cc) pairs) + + wrap_bind (NonRec binder rhs) = NonRec binder (mkSCC cc rhs) + wrap_bind (Rec pairs) = Rec (mapSnd (mkSCC cc) pairs) + + wrap_one (FloatLet bind) = FloatLet (wrap_bind bind) + wrap_one (FloatCase e b c bs) = FloatCase (mkSCC cc e) b c bs \end{code} diff --git a/compiler/simplCore/SetLevels.lhs b/compiler/simplCore/SetLevels.lhs index 21dca615c3..87c8b3d2d8 100644 --- a/compiler/simplCore/SetLevels.lhs +++ b/compiler/simplCore/SetLevels.lhs @@ -46,7 +46,8 @@ module SetLevels ( setLevels, Level(..), tOP_LEVEL, - LevelledBind, LevelledExpr, + LevelledBind, LevelledExpr, LevelledBndr, + FloatSpec(..), floatSpecLevel, incMinorLvl, ltMajLvl, ltLvl, isTopLvl ) where @@ -55,10 +56,10 @@ module SetLevels ( import CoreSyn import CoreMonad ( FloatOutSwitches(..) ) -import CoreUtils ( exprType, mkPiTypes ) +import CoreUtils ( exprType, exprOkForSpeculation, mkPiTypes ) import CoreArity ( exprBotStrictness_maybe ) import CoreFVs -- all of it -import CoreSubst ( Subst, emptySubst, extendInScope, extendInScopeList, +import CoreSubst ( Subst, emptySubst, extendInScope, substBndr, substRecBndrs, extendIdSubst, cloneIdBndr, cloneRecIdBndrs ) import Id import IdInfo @@ -69,7 +70,7 @@ import Demand ( StrictSig, increaseStrictSigArity ) import Name ( getOccName, mkSystemVarName ) import OccName ( occNameString ) import Type ( isUnLiftedType, Type ) -import BasicTypes ( TopLevelFlag(..), Arity ) +import BasicTypes ( Arity ) import UniqSupply import Util import Outputable @@ -83,9 +84,23 @@ import FastString %************************************************************************ \begin{code} +type LevelledExpr = TaggedExpr FloatSpec +type LevelledBind = TaggedBind FloatSpec +type LevelledBndr = TaggedBndr FloatSpec + data Level = Level Int -- Level number of enclosing lambdas Int -- Number of big-lambda and/or case expressions between -- here and the nearest enclosing lambda + +data FloatSpec + = FloatMe Level -- Float to just inside the binding + -- tagged with this level + | StayPut Level -- Stay where it is; binding is + -- tagged with tihs level + +floatSpecLevel :: FloatSpec -> Level +floatSpecLevel (FloatMe l) = l +floatSpecLevel (StayPut l) = l \end{code} The {\em level number} on a (type-)lambda-bound variable is the @@ -143,8 +158,9 @@ inlined into the floated expression, and an importing module won't see the worker at all. \begin{code} -type LevelledExpr = TaggedExpr Level -type LevelledBind = TaggedBind Level +instance Outputable FloatSpec where + ppr (FloatMe l) = char 'F' <> ppr l + ppr (StayPut l) = ppr l tOP_LEVEL :: Level tOP_LEVEL = Level 0 0 @@ -205,12 +221,18 @@ setLevels float_lams binds us ; return (lvld_bind : lvld_binds) } lvlTopBind :: LevelEnv -> Bind Id -> LvlM (LevelledBind, LevelEnv) -lvlTopBind env (NonRec binder rhs) - = lvlBind TopLevel tOP_LEVEL env (AnnNonRec binder (freeVars rhs)) - -- Rhs can have no free vars! +lvlTopBind env (NonRec bndr rhs) + = do rhs' <- lvlExpr tOP_LEVEL env (freeVars rhs) + let bndr' = TB bndr (StayPut tOP_LEVEL) + env' = extendLvlEnv env [bndr'] + return (NonRec bndr' rhs', env') lvlTopBind env (Rec pairs) - = lvlBind TopLevel tOP_LEVEL env (AnnRec [(b,freeVars rhs) | (b,rhs) <- pairs]) + = do let (bndrs,rhss) = unzip pairs + bndrs' = [TB b (StayPut tOP_LEVEL) | b <- bndrs] + env' = extendLvlEnv env bndrs' + rhss' <- mapM (lvlExpr tOP_LEVEL env' . freeVars) rhss + return (Rec (bndrs' `zip` rhss'), env') \end{code} %************************************************************************ @@ -313,41 +335,37 @@ lvlExpr ctxt_lvl env expr@(_, AnnLam {}) = do -- but not nearly so much now non-recursive newtypes are transparent. -- [See SetLevels rev 1.50 for a version with this approach.] -lvlExpr ctxt_lvl env (_, AnnLet (AnnNonRec bndr rhs) body) - | isUnLiftedType (idType bndr) = do - -- Treat unlifted let-bindings (let x = b in e) just like (case b of x -> e) - -- That is, leave it exactly where it is - -- We used to float unlifted bindings too (e.g. to get a cheap primop - -- outside a lambda (to see how, look at lvlBind in rev 1.58) - -- but an unrelated change meant that these unlifed bindings - -- could get to the top level which is bad. And there's not much point; - -- unlifted bindings are always cheap, and so hardly worth floating. - rhs' <- lvlExpr ctxt_lvl env rhs - body' <- lvlExpr incd_lvl env' body - return (Let (NonRec bndr' rhs') body') - where - incd_lvl = incMinorLvl ctxt_lvl - bndr' = TB bndr incd_lvl - env' = extendLvlEnv env [bndr'] - lvlExpr ctxt_lvl env (_, AnnLet bind body) = do - (bind', new_env) <- lvlBind NotTopLevel ctxt_lvl env bind - body' <- lvlExpr ctxt_lvl new_env body + (bind', new_lvl, new_env) <- lvlBind ctxt_lvl env bind + body' <- lvlExpr new_lvl new_env body return (Let bind' body') -lvlExpr ctxt_lvl env (_, AnnCase expr case_bndr ty alts) = do - expr' <- lvlMFE True ctxt_lvl env expr - let alts_env = extendCaseBndrLvlEnv env expr' case_bndr incd_lvl - alts' <- mapM (lvl_alt alts_env) alts - return (Case expr' (TB case_bndr incd_lvl) ty alts') +lvlExpr ctxt_lvl env (_, AnnCase scrut@(scrut_fvs,_) case_bndr ty alts) + = do { scrut' <- lvlMFE True ctxt_lvl env scrut + ; let case_bndr' = TB case_bndr bndr_spec + alts_env = extendCaseBndrLvlEnv env scrut' case_bndr' + ; alts' <- mapM (lvl_alt alts_env) alts + ; return (Case scrut' case_bndr' ty alts') } where - incd_lvl = incMinorLvl ctxt_lvl - - lvl_alt alts_env (con, bs, rhs) = do - rhs' <- lvlMFE True incd_lvl new_env rhs - return (con, bs', rhs') + incd_lvl = incMinorLvl ctxt_lvl + dest_lvl = maxFvLevel (const True) env scrut_fvs + + alt_ctxt_lvl :: Level + bndr_spec :: FloatSpec + (alt_ctxt_lvl, bndr_spec) + | [(DataAlt _, _, _)] <- alts + , exprOkForSpeculation (deAnnotate scrut) + , not (isTopLvl dest_lvl) -- Can't have top-level cases + = (ctxt_lvl, FloatMe dest_lvl) + -- Don't abstact over type variables, hence const True + | otherwise + = (incd_lvl, StayPut incd_lvl) + + lvl_alt alts_env (con, bs, rhs) + = do { rhs' <- lvlMFE True alt_ctxt_lvl new_env rhs + ; return (con, bs', rhs') } where - bs' = [ TB b incd_lvl | b <- bs ] + bs' = [ TB b bndr_spec | b <- bs ] new_env = extendLvlEnv alts_env bs' \end{code} @@ -428,14 +446,14 @@ lvlMFE strict_ctxt ctxt_lvl env ann_expr@(fvs, _) -- This includes coercions, which we don't -- want to float anyway || notWorthFloating ann_expr abs_vars - || not good_destination + || not float_me = -- Don't float it out lvlExpr ctxt_lvl env ann_expr | otherwise -- Float it out! = do expr' <- lvlFloatRhs abs_vars dest_lvl env ann_expr var <- newLvlVar abs_vars ty mb_bot - return (Let (NonRec (TB var dest_lvl) expr') + return (Let (NonRec (TB var (FloatMe dest_lvl)) expr') (mkVarApps (Var var) abs_vars)) where expr = deAnnotate ann_expr @@ -446,16 +464,13 @@ lvlMFE strict_ctxt ctxt_lvl env ann_expr@(fvs, _) -- A decision to float entails let-binding this thing, and we only do -- that if we'll escape a value lambda, or will go to the top level. - good_destination - | dest_lvl `ltMajLvl` ctxt_lvl -- Escapes a value lambda - = True - -- OLD CODE: not (exprIsCheap expr) || isTopLvl dest_lvl - -- see Note [Escaping a value lambda] - - | otherwise -- Does not escape a value lambda - = isTopLvl dest_lvl -- Only float if we are going to the top level - && floatConsts env -- and the floatConsts flag is on - && not strict_ctxt -- Don't float from a strict context + float_me = dest_lvl `ltMajLvl` ctxt_lvl -- Escapes a value lambda + -- OLD CODE: not (exprIsCheap expr) || isTopLvl dest_lvl + -- see Note [Escaping a value lambda] + + || (isTopLvl dest_lvl -- Only float if we are going to the top level + && floatConsts env -- and the floatConsts flag is on + && not strict_ctxt) -- Don't float from a strict context -- We are keen to float something to the top level, even if it does not -- escape a lambda, because then it needs no allocation. But it's controlled -- by a flag, because doing this too early loses opportunities for RULES @@ -465,9 +480,12 @@ lvlMFE strict_ctxt ctxt_lvl env ann_expr@(fvs, _) -- Beware: -- concat = /\ a -> foldr ..a.. (++) [] -- was getting turned into - -- concat = /\ a -> lvl a -- lvl = /\ a -> foldr ..a.. (++) [] + -- concat = /\ a -> lvl a -- which is pretty stupid. Hence the strict_ctxt test + -- + -- Also a strict contxt includes uboxed values, and they + -- can't be bound at top level annotateBotStr :: Id -> Maybe (Arity, StrictSig) -> Id annotateBotStr id Nothing = id @@ -560,30 +578,39 @@ OLD comment was: The binding stuff works for top level too. \begin{code} -lvlBind :: TopLevelFlag -- Used solely to decide whether to clone - -> Level -- Context level; might be Top even for bindings nested in the RHS - -- of a top level binding +lvlBind :: Level -- Context level; might be Top even for bindings + -- nested in the RHS of a top level binding -> LevelEnv -> CoreBindWithFVs - -> LvlM (LevelledBind, LevelEnv) - -lvlBind top_lvl ctxt_lvl env (AnnNonRec bndr rhs@(rhs_fvs,_)) - | isTyVar bndr -- Don't do anything for TyVar binders - -- (simplifier gets rid of them pronto) - = do rhs' <- lvlExpr ctxt_lvl env rhs - return (NonRec (TB bndr ctxt_lvl) rhs', env) - + -> LvlM (LevelledBind, Level, LevelEnv) + +lvlBind ctxt_lvl env (AnnNonRec bndr rhs@(rhs_fvs,_)) + | isTyVar bndr -- Don't do anything for TyVar binders + -- (simplifier gets rid of them pronto) + || not (profitableFloat ctxt_lvl dest_lvl) + || (isTopLvl dest_lvl && isUnLiftedType (idType bndr)) + -- We can't float an unlifted binding to top level, so we don't + -- float it at all. It's a bit brutal, but unlifted bindings + -- aren't expensive either + = -- No float + do rhs' <- lvlExpr ctxt_lvl env rhs + let (env', bndr') = substLetBndrNonRec env bndr bind_lvl + bind_lvl = incMinorLvl ctxt_lvl + tagged_bndr = TB bndr' (StayPut bind_lvl) + return (NonRec tagged_bndr rhs', bind_lvl, env') + + -- Otherwise we are going to float | null abs_vars = do -- No type abstraction; clone existing binder rhs' <- lvlExpr dest_lvl env rhs - (env', bndr') <- cloneVar top_lvl env bndr ctxt_lvl dest_lvl - return (NonRec (TB bndr' dest_lvl) rhs', env') + (env', bndr') <- cloneVar env bndr ctxt_lvl dest_lvl + return (NonRec (TB bndr' (FloatMe dest_lvl)) rhs', ctxt_lvl, env') | otherwise = do -- Yes, type abstraction; create a new binder, extend substitution, etc rhs' <- lvlFloatRhs abs_vars dest_lvl env rhs (env', [bndr']) <- newPolyBndrs dest_lvl env abs_vars [bndr_w_str] - return (NonRec (TB bndr' dest_lvl) rhs', env') + return (NonRec (TB bndr' (FloatMe dest_lvl)) rhs', ctxt_lvl, env') where bind_fvs = rhs_fvs `unionVarSet` idFreeVars bndr @@ -591,15 +618,21 @@ lvlBind top_lvl ctxt_lvl env (AnnNonRec bndr rhs@(rhs_fvs,_)) dest_lvl = destLevel env bind_fvs (isFunction rhs) mb_bot mb_bot = exprBotStrictness_maybe (deAnnotate rhs) bndr_w_str = annotateBotStr bndr mb_bot -\end{code} +lvlBind ctxt_lvl env (AnnRec pairs) + | not (profitableFloat ctxt_lvl dest_lvl) + = do let bind_lvl = incMinorLvl ctxt_lvl + (env', bndrs') = substLetBndrsRec env bndrs bind_lvl + tagged_bndrs = [ TB bndr' (StayPut bind_lvl) + | bndr' <- bndrs' ] + rhss' <- mapM (lvlExpr bind_lvl env') rhss + return (Rec (tagged_bndrs `zip` rhss'), bind_lvl, env') -\begin{code} -lvlBind top_lvl ctxt_lvl env (AnnRec pairs) | null abs_vars - = do (new_env, new_bndrs) <- cloneRecVars top_lvl env bndrs ctxt_lvl dest_lvl + = do (new_env, new_bndrs) <- cloneRecVars env bndrs ctxt_lvl dest_lvl new_rhss <- mapM (lvlExpr ctxt_lvl new_env) rhss - return (Rec ([TB b dest_lvl | b <- new_bndrs] `zip` new_rhss), new_env) + return ( Rec ([TB b (FloatMe dest_lvl) | b <- new_bndrs] `zip` new_rhss) + , ctxt_lvl, new_env) -- ToDo: when enabling the floatLambda stuff, -- I think we want to stop doing this @@ -618,42 +651,50 @@ lvlBind top_lvl ctxt_lvl env (AnnRec pairs) (bndr,rhs) = head pairs (rhs_lvl, abs_vars_w_lvls) = lvlLamBndrs dest_lvl abs_vars rhs_env = extendLvlEnv env abs_vars_w_lvls - (rhs_env', new_bndr) <- cloneVar NotTopLevel rhs_env bndr rhs_lvl rhs_lvl + (rhs_env', new_bndr) <- cloneVar rhs_env bndr rhs_lvl rhs_lvl let (lam_bndrs, rhs_body) = collectAnnBndrs rhs (body_lvl, new_lam_bndrs) = lvlLamBndrs rhs_lvl lam_bndrs body_env = extendLvlEnv rhs_env' new_lam_bndrs new_rhs_body <- lvlExpr body_lvl body_env rhs_body (poly_env, [poly_bndr]) <- newPolyBndrs dest_lvl env abs_vars [bndr] - return (Rec [(TB poly_bndr dest_lvl, - mkLams abs_vars_w_lvls $ - mkLams new_lam_bndrs $ - Let (Rec [(TB new_bndr rhs_lvl, mkLams new_lam_bndrs new_rhs_body)]) - (mkVarApps (Var new_bndr) lam_bndrs))], - poly_env) + return (Rec [(TB poly_bndr (FloatMe dest_lvl) + , mkLams abs_vars_w_lvls $ + mkLams new_lam_bndrs $ + Let (Rec [( TB new_bndr (StayPut rhs_lvl) + , mkLams new_lam_bndrs new_rhs_body)]) + (mkVarApps (Var new_bndr) lam_bndrs))] + , ctxt_lvl + , poly_env) | otherwise = do -- Non-null abs_vars (new_env, new_bndrs) <- newPolyBndrs dest_lvl env abs_vars bndrs new_rhss <- mapM (lvlFloatRhs abs_vars dest_lvl new_env) rhss - return (Rec ([TB b dest_lvl | b <- new_bndrs] `zip` new_rhss), new_env) + return ( Rec ([TB b (FloatMe dest_lvl) | b <- new_bndrs] `zip` new_rhss) + , ctxt_lvl, new_env) where (bndrs,rhss) = unzip pairs -- Finding the free vars of the binding group is annoying - bind_fvs = (unionVarSets [ idFreeVars bndr `unionVarSet` rhs_fvs - | (bndr, (rhs_fvs,_)) <- pairs]) - `minusVarSet` - mkVarSet bndrs + bind_fvs = (unionVarSets [ idFreeVars bndr `unionVarSet` rhs_fvs + | (bndr, (rhs_fvs,_)) <- pairs]) + `minusVarSet` + mkVarSet bndrs dest_lvl = destLevel env bind_fvs (all isFunction rhss) Nothing abs_vars = abstractVars dest_lvl env bind_fvs +profitableFloat :: Level -> Level -> Bool +profitableFloat ctxt_lvl dest_lvl + = (dest_lvl `ltMajLvl` ctxt_lvl) -- Escapes a value lambda + || isTopLvl dest_lvl -- Going all the way to top level + ---------------------------------------------------- -- Three help functions for the type-abstraction case lvlFloatRhs :: [CoreBndr] -> Level -> LevelEnv -> CoreExprWithFVs - -> UniqSM (Expr (TaggedBndr Level)) + -> UniqSM (Expr LevelledBndr) lvlFloatRhs abs_vars dest_lvl env rhs = do rhs' <- lvlExpr rhs_lvl rhs_env rhs return (mkLams abs_vars_w_lvls rhs') @@ -670,7 +711,7 @@ lvlFloatRhs abs_vars dest_lvl env rhs = do %************************************************************************ \begin{code} -lvlLamBndrs :: Level -> [CoreBndr] -> (Level, [TaggedBndr Level]) +lvlLamBndrs :: Level -> [CoreBndr] -> (Level, [LevelledBndr]) -- Compute the levels for the binders of a lambda group -- The binders returned are exactly the same as the ones passed, -- but they are now paired with a level @@ -678,7 +719,7 @@ lvlLamBndrs lvl [] = (lvl, []) lvlLamBndrs lvl bndrs - = (new_lvl, [TB bndr new_lvl | bndr <- bndrs]) + = (new_lvl, [TB bndr (StayPut new_lvl) | bndr <- bndrs]) -- All the new binders get the same level, because -- any floating binding is either going to float past -- all or none. We never separate binders @@ -701,8 +742,9 @@ destLevel env fvs is_function mb_bot , is_function , countFreeIds fvs <= n_args = tOP_LEVEL -- Send functions to top level; see - -- the comments with isFunction - | otherwise = maxIdLevel env fvs + -- the comments with isFunction + | otherwise = maxFvLevel isId env fvs -- Max over Ids only; the tyvars + -- will be abstracted isFunction :: CoreExprWithFVs -> Bool -- The idea here is that we want to float *functions* to @@ -782,7 +824,7 @@ floatConsts le = floatOutConstants (le_switches le) floatPAPs :: LevelEnv -> Bool floatPAPs le = floatOutPartialApplications (le_switches le) -extendLvlEnv :: LevelEnv -> [TaggedBndr Level] -> LevelEnv +extendLvlEnv :: LevelEnv -> [LevelledBndr] -> LevelEnv -- Used when *not* cloning extendLvlEnv le@(LE { le_lvl_env = lvl_env, le_subst = subst, le_env = id_env }) prs @@ -790,7 +832,7 @@ extendLvlEnv le@(LE { le_lvl_env = lvl_env, le_subst = subst, le_env = id_env }) , le_subst = foldl del_subst subst prs , le_env = foldl del_id id_env prs } where - add_lvl env (TB v l) = extendVarEnv env v l + add_lvl env (TB v s) = extendVarEnv env v (floatSpecLevel s) del_subst env (TB v _) = extendInScope env v del_id env (TB v _) = delVarEnv env v -- We must remove any clone for this variable name in case of @@ -807,26 +849,17 @@ extendLvlEnv le@(LE { le_lvl_env = lvl_env, le_subst = subst, le_env = id_env }) -- incorrectly, because the SubstEnv was still lying around. Ouch! -- KSW 2000-07. -extendInScopeEnv :: LevelEnv -> Var -> LevelEnv -extendInScopeEnv le@(LE { le_subst = subst }) v - = le { le_subst = extendInScope subst v } - -extendInScopeEnvList :: LevelEnv -> [Var] -> LevelEnv -extendInScopeEnvList le@(LE { le_subst = subst }) vs - = le { le_subst = extendInScopeList subst vs } - -- extendCaseBndrLvlEnv adds the mapping case-bndr->scrut-var if it can -- (see point 4 of the module overview comment) -extendCaseBndrLvlEnv :: LevelEnv -> Expr (TaggedBndr Level) -> Var -> Level - -> LevelEnv -extendCaseBndrLvlEnv le@(LE { le_lvl_env = lvl_env, le_subst = subst, le_env = id_env }) - (Var scrut_var) case_bndr lvl - = le { le_lvl_env = extendVarEnv lvl_env case_bndr lvl - , le_subst = extendIdSubst subst case_bndr (Var scrut_var) +extendCaseBndrLvlEnv :: LevelEnv -> Expr LevelledBndr + -> LevelledBndr -> LevelEnv +extendCaseBndrLvlEnv le@(LE { le_subst = subst, le_env = id_env }) + (Var scrut_var) (TB case_bndr _) + = le { le_subst = extendIdSubst subst case_bndr (Var scrut_var) , le_env = extendVarEnv id_env case_bndr ([scrut_var], Var scrut_var) } -extendCaseBndrLvlEnv env _scrut case_bndr lvl - = extendLvlEnv env [TB case_bndr lvl] +extendCaseBndrLvlEnv env _scrut case_bndr + = extendLvlEnv env [case_bndr] extendPolyLvlEnv :: Level -> LevelEnv -> [Var] -> [(Var, Var)] -> LevelEnv extendPolyLvlEnv dest_lvl @@ -843,26 +876,27 @@ extendPolyLvlEnv dest_lvl extendCloneLvlEnv :: Level -> LevelEnv -> Subst -> [(Var, Var)] -> LevelEnv extendCloneLvlEnv lvl le@(LE { le_lvl_env = lvl_env, le_env = id_env }) new_subst bndr_pairs - = le { le_lvl_env = foldl add_lvl lvl_env bndr_pairs + = le { le_lvl_env = foldl add_lvl lvl_env bndr_pairs , le_subst = new_subst - , le_env = foldl add_id id_env bndr_pairs } + , le_env = foldl add_id id_env bndr_pairs } where add_lvl env (_, v') = extendVarEnv env v' lvl add_id env (v, v') = extendVarEnv env v ([v'], Var v') -maxIdLevel :: LevelEnv -> VarSet -> Level -maxIdLevel (LE { le_lvl_env = lvl_env, le_env = id_env }) var_set +maxFvLevel :: (Var -> Bool) -> LevelEnv -> VarSet -> Level +maxFvLevel max_me (LE { le_lvl_env = lvl_env, le_env = id_env }) var_set = foldVarSet max_in tOP_LEVEL var_set where - max_in in_var lvl = foldr max_out lvl (case lookupVarEnv id_env in_var of - Just (abs_vars, _) -> abs_vars - Nothing -> [in_var]) + max_in in_var lvl + = foldr max_out lvl (case lookupVarEnv id_env in_var of + Just (abs_vars, _) -> abs_vars + Nothing -> [in_var]) max_out out_var lvl - | isId out_var = case lookupVarEnv lvl_env out_var of + | max_me out_var = case lookupVarEnv lvl_env out_var of Just lvl' -> maxLvl lvl' lvl Nothing -> lvl - | otherwise = lvl -- Ignore tyvars in *maxIdLevel* + | otherwise = lvl -- Ignore some vars depending on max_me lookupVar :: LevelEnv -> Id -> LevelledExpr lookupVar le v = case lookupVarEnv (le_env le) v of @@ -967,12 +1001,32 @@ newLvlVar vars body_ty mb_bot -- The deeply tiresome thing is that we have to apply the substitution -- to the rules inside each Id. Grr. But it matters. -cloneVar :: TopLevelFlag -> LevelEnv -> Id -> Level -> Level -> LvlM (LevelEnv, Id) -cloneVar TopLevel env v _ _ - = return (extendInScopeEnv env v, v) -- Don't clone top level things - -- But do extend the in-scope env, to satisfy the in-scope invariant +substLetBndrNonRec :: LevelEnv -> Id -> Level -> (LevelEnv, Id) +substLetBndrNonRec + le@(LE { le_lvl_env = lvl_env, le_subst = subst, le_env = id_env }) + bndr bind_lvl + = ASSERT( isId bndr ) + (env', bndr' ) + where + (subst', bndr') = substBndr subst bndr + env' = le { le_lvl_env = extendVarEnv lvl_env bndr bind_lvl + , le_subst = subst' + , le_env = delVarEnv id_env bndr } + +substLetBndrsRec :: LevelEnv -> [Id] -> Level -> (LevelEnv, [Id]) +substLetBndrsRec + le@(LE { le_lvl_env = lvl_env, le_subst = subst, le_env = id_env }) + bndrs bind_lvl + = ASSERT( all isId bndrs ) + (env', bndrs') + where + (subst', bndrs') = substRecBndrs subst bndrs + env' = le { le_lvl_env = extendVarEnvList lvl_env [(b,bind_lvl) | b <- bndrs] + , le_subst = subst' + , le_env = delVarEnvList id_env bndrs } -cloneVar NotTopLevel env v ctxt_lvl dest_lvl +cloneVar :: LevelEnv -> Id -> Level -> Level -> LvlM (LevelEnv, Id) +cloneVar env v ctxt_lvl dest_lvl = ASSERT( isId v ) do us <- getUniqueSupplyM let @@ -981,10 +1035,8 @@ cloneVar NotTopLevel env v ctxt_lvl dest_lvl env' = extendCloneLvlEnv dest_lvl env subst' [(v,v2)] return (env', v2) -cloneRecVars :: TopLevelFlag -> LevelEnv -> [Id] -> Level -> Level -> LvlM (LevelEnv, [Id]) -cloneRecVars TopLevel env vs _ _ - = return (extendInScopeEnvList env vs, vs) -- Don't clone top level things -cloneRecVars NotTopLevel env vs ctxt_lvl dest_lvl +cloneRecVars :: LevelEnv -> [Id] -> Level -> Level -> LvlM (LevelEnv, [Id]) +cloneRecVars env vs ctxt_lvl dest_lvl = ASSERT( all isId vs ) do us <- getUniqueSupplyM let |
