diff options
Diffstat (limited to 'compiler/GHC/Stg/Lift')
-rw-r--r-- | compiler/GHC/Stg/Lift/Analysis.hs | 565 | ||||
-rw-r--r-- | compiler/GHC/Stg/Lift/Monad.hs | 348 |
2 files changed, 913 insertions, 0 deletions
diff --git a/compiler/GHC/Stg/Lift/Analysis.hs b/compiler/GHC/Stg/Lift/Analysis.hs new file mode 100644 index 0000000000..02d439cef7 --- /dev/null +++ b/compiler/GHC/Stg/Lift/Analysis.hs @@ -0,0 +1,565 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE DataKinds #-} + +-- | Provides the heuristics for when it's beneficial to lambda lift bindings. +-- Most significantly, this employs a cost model to estimate impact on heap +-- allocations, by looking at an STG expression's 'Skeleton'. +module GHC.Stg.Lift.Analysis ( + -- * #when# When to lift + -- $when + + -- * #clogro# Estimating closure growth + -- $clogro + + -- * AST annotation + Skeleton(..), BinderInfo(..), binderInfoBndr, + LlStgBinding, LlStgExpr, LlStgRhs, LlStgAlt, tagSkeletonTopBind, + -- * Lifting decision + goodToLift, + closureGrowth -- Exported just for the docs + ) where + +import GhcPrelude + +import BasicTypes +import Demand +import DynFlags +import Id +import SMRep ( WordOff ) +import GHC.Stg.Syntax +import qualified GHC.StgToCmm.ArgRep as StgToCmm.ArgRep +import qualified GHC.StgToCmm.Closure as StgToCmm.Closure +import qualified GHC.StgToCmm.Layout as StgToCmm.Layout +import Outputable +import Util +import VarSet + +import Data.Maybe ( mapMaybe ) + +-- Note [When to lift] +-- ~~~~~~~~~~~~~~~~~~~ +-- $when +-- The analysis proceeds in two steps: +-- +-- 1. It tags the syntax tree with analysis information in the form of +-- 'BinderInfo' at each binder and 'Skeleton's at each let-binding +-- by 'tagSkeletonTopBind' and friends. +-- 2. The resulting syntax tree is treated by the "GHC.Stg.Lift" +-- module, calling out to 'goodToLift' to decide if a binding is worthwhile +-- to lift. +-- 'goodToLift' consults argument occurrence information in 'BinderInfo' +-- and estimates 'closureGrowth', for which it needs the 'Skeleton'. 
+-- +-- So the annotations from 'tagSkeletonTopBind' ultimately fuel 'goodToLift', +-- which employs a number of heuristics to identify and exclude lambda lifting +-- opportunities deemed non-beneficial: +-- +-- [Top-level bindings] can't be lifted. +-- [Thunks] and data constructors shouldn't be lifted in order not to destroy +-- sharing. +-- [Argument occurrences] #arg_occs# of binders prevent them from being lifted. +-- Doing the lift would re-introduce the very allocation at call sites that +-- we tried to get rid of in the first place. We capture analysis +-- information in 'BinderInfo'. Note that we also consider a nullary +-- application as an argument occurrence, because it would turn into an n-ary +-- partial application created by a generic apply function. This occurs in +-- CPS-heavy code like the CS benchmark. +-- [Join points] should not be lifted, simply because there's no reduction in +-- allocation to be had. +-- [Abstracting over join points] destroys join points, because they end up as +-- arguments to the lifted function. +-- [Abstracting over known local functions] turns a known call into an unknown +-- call (e.g. some @stg_ap_*@), which is generally slower. Can be turned off +-- with @-fstg-lift-lams-known@. +-- [Calling convention] Don't lift when the resulting function would have a +-- higher arity than available argument registers for the calling convention. +-- Can be influenced with @-fstg-lift-(non)rec-args(-any)@. +-- [Closure growth] introduced when former free variables have to be available +-- at call sites may actually lead to an increase in overall allocations +-- resulting from a lift. Estimating closure growth is described in +-- "GHC.Stg.Lift.Analysis#clogro" and is what most of this module is ultimately +-- concerned with. +-- +-- There's a <https://gitlab.haskell.org/ghc/ghc/wikis/late-lam-lift wiki page> with +-- some more background and history. 
+ +-- Note [Estimating closure growth] +-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +-- $clogro +-- We estimate closure growth by abstracting the syntax tree into a 'Skeleton', +-- capturing only syntactic details relevant to 'closureGrowth', such as +-- +-- * 'ClosureSk', representing closure allocation. +-- * 'RhsSk', representing a RHS of a binding and how many times it's called +-- by an appropriate 'DmdShell'. +-- * 'AltSk', 'BothSk' and 'NilSk' for choice, sequence and empty element. +-- +-- This abstraction is mostly so that the main analysis function 'closureGrowth' +-- can stay simple and focused. Also, skeletons tend to be much smaller than +-- the syntax tree they abstract, so it makes sense to construct them once +-- and operate on them instead of the actual syntax tree. +-- +-- A more detailed treatment of computing closure growth, including examples, +-- can be found in the paper referenced from the +-- <https://gitlab.haskell.org/ghc/ghc/wikis/late-lam-lift wiki page>. + +llTrace :: String -> SDoc -> a -> a +llTrace _ _ c = c +-- llTrace a b c = pprTrace a b c + +type instance BinderP 'LiftLams = BinderInfo +type instance XRhsClosure 'LiftLams = DIdSet +type instance XLet 'LiftLams = Skeleton +type instance XLetNoEscape 'LiftLams = Skeleton + +freeVarsOfRhs :: (XRhsClosure pass ~ DIdSet) => GenStgRhs pass -> DIdSet +freeVarsOfRhs (StgRhsCon _ _ args) = mkDVarSet [ id | StgVarArg id <- args ] +freeVarsOfRhs (StgRhsClosure fvs _ _ _ _) = fvs + +-- | Captures details of the syntax tree relevant to the cost model, such as +-- closures, multi-shot lambdas and case expressions. 
+data Skeleton + = ClosureSk !Id !DIdSet {- ^ free vars -} !Skeleton + | RhsSk !DmdShell {- ^ how often the RHS was entered -} !Skeleton + | AltSk !Skeleton !Skeleton + | BothSk !Skeleton !Skeleton + | NilSk + +bothSk :: Skeleton -> Skeleton -> Skeleton +bothSk NilSk b = b +bothSk a NilSk = a +bothSk a b = BothSk a b + +altSk :: Skeleton -> Skeleton -> Skeleton +altSk NilSk b = b +altSk a NilSk = a +altSk a b = AltSk a b + +rhsSk :: DmdShell -> Skeleton -> Skeleton +rhsSk _ NilSk = NilSk +rhsSk body_dmd skel = RhsSk body_dmd skel + +-- | The type used in binder positions in 'GenStgExpr's. +data BinderInfo + = BindsClosure !Id !Bool -- ^ Let(-no-escape)-bound thing with a flag + -- indicating whether it occurs as an argument + -- or in a nullary application + -- (see "GHC.Stg.Lift.Analysis#arg_occs"). + | BoringBinder !Id -- ^ Every other kind of binder + +-- | Gets the bound 'Id' out a 'BinderInfo'. +binderInfoBndr :: BinderInfo -> Id +binderInfoBndr (BoringBinder bndr) = bndr +binderInfoBndr (BindsClosure bndr _) = bndr + +-- | Returns 'Nothing' for 'BoringBinder's and 'Just' the flag indicating +-- occurrences as argument or in a nullary applications otherwise. +binderInfoOccursAsArg :: BinderInfo -> Maybe Bool +binderInfoOccursAsArg BoringBinder{} = Nothing +binderInfoOccursAsArg (BindsClosure _ b) = Just b + +instance Outputable Skeleton where + ppr NilSk = text "" + ppr (AltSk l r) = vcat + [ text "{ " <+> ppr l + , text "ALT" + , text " " <+> ppr r + , text "}" + ] + ppr (BothSk l r) = ppr l $$ ppr r + ppr (ClosureSk f fvs body) = ppr f <+> ppr fvs $$ nest 2 (ppr body) + ppr (RhsSk body_dmd body) = hcat + [ text "λ[" + , ppr str + , text ", " + , ppr use + , text "]. " + , ppr body + ] + where + str + | isStrictDmd body_dmd = '1' + | otherwise = '0' + use + | isAbsDmd body_dmd = '0' + | isUsedOnce body_dmd = '1' + | otherwise = 'ω' + +instance Outputable BinderInfo where + ppr = ppr . 
binderInfoBndr + +instance OutputableBndr BinderInfo where + pprBndr b = pprBndr b . binderInfoBndr + pprPrefixOcc = pprPrefixOcc . binderInfoBndr + pprInfixOcc = pprInfixOcc . binderInfoBndr + bndrIsJoin_maybe = bndrIsJoin_maybe . binderInfoBndr + +mkArgOccs :: [StgArg] -> IdSet +mkArgOccs = mkVarSet . mapMaybe stg_arg_var + where + stg_arg_var (StgVarArg occ) = Just occ + stg_arg_var _ = Nothing + +-- | Tags every binder with its 'BinderInfo' and let bindings with their +-- 'Skeleton's. +tagSkeletonTopBind :: CgStgBinding -> LlStgBinding +-- NilSk is OK when tagging top-level bindings. Also, top-level things are never +-- lambda-lifted, so no need to track their argument occurrences. They can also +-- never be let-no-escapes (thus we pass False). +tagSkeletonTopBind bind = bind' + where + (_, _, _, bind') = tagSkeletonBinding False NilSk emptyVarSet bind + +-- | Tags binders of an 'StgExpr' with its 'BinderInfo' and let bindings with +-- their 'Skeleton's. Additionally, returns its 'Skeleton' and the set of binder +-- occurrences in argument and nullary application position +-- (cf. "GHC.Stg.Lift.Analysis#arg_occs"). +tagSkeletonExpr :: CgStgExpr -> (Skeleton, IdSet, LlStgExpr) +tagSkeletonExpr (StgLit lit) + = (NilSk, emptyVarSet, StgLit lit) +tagSkeletonExpr (StgConApp con args tys) + = (NilSk, mkArgOccs args, StgConApp con args tys) +tagSkeletonExpr (StgOpApp op args ty) + = (NilSk, mkArgOccs args, StgOpApp op args ty) +tagSkeletonExpr (StgApp f args) + = (NilSk, arg_occs, StgApp f args) + where + arg_occs + -- This checks for nullary applications, which we treat the same as + -- argument occurrences, see "GHC.Stg.Lift.Analysis#arg_occs". 
+ | null args = unitVarSet f + | otherwise = mkArgOccs args +tagSkeletonExpr (StgLam _ _) = pprPanic "stgLiftLams" (text "StgLam") +tagSkeletonExpr (StgCase scrut bndr ty alts) + = (skel, arg_occs, StgCase scrut' bndr' ty alts') + where + (scrut_skel, scrut_arg_occs, scrut') = tagSkeletonExpr scrut + (alt_skels, alt_arg_occss, alts') = mapAndUnzip3 tagSkeletonAlt alts + skel = bothSk scrut_skel (foldr altSk NilSk alt_skels) + arg_occs = unionVarSets (scrut_arg_occs:alt_arg_occss) `delVarSet` bndr + bndr' = BoringBinder bndr +tagSkeletonExpr (StgTick t e) + = (skel, arg_occs, StgTick t e') + where + (skel, arg_occs, e') = tagSkeletonExpr e +tagSkeletonExpr (StgLet _ bind body) = tagSkeletonLet False body bind +tagSkeletonExpr (StgLetNoEscape _ bind body) = tagSkeletonLet True body bind + +mkLet :: Bool -> Skeleton -> LlStgBinding -> LlStgExpr -> LlStgExpr +mkLet True = StgLetNoEscape +mkLet _ = StgLet + +tagSkeletonLet + :: Bool + -- ^ Is the binding a let-no-escape? + -> CgStgExpr + -- ^ Let body + -> CgStgBinding + -- ^ Binding group + -> (Skeleton, IdSet, LlStgExpr) + -- ^ RHS skeletons, argument occurrences and annotated binding +tagSkeletonLet is_lne body bind + = (let_skel, arg_occs, mkLet is_lne scope bind' body') + where + (body_skel, body_arg_occs, body') = tagSkeletonExpr body + (let_skel, arg_occs, scope, bind') + = tagSkeletonBinding is_lne body_skel body_arg_occs bind + +tagSkeletonBinding + :: Bool + -- ^ Is the binding a let-no-escape? 
+ -> Skeleton + -- ^ Let body skeleton + -> IdSet + -- ^ Argument occurrences in the body + -> CgStgBinding + -- ^ Binding group + -> (Skeleton, IdSet, Skeleton, LlStgBinding) + -- ^ Let skeleton, argument occurrences, scope skeleton of binding and + -- the annotated binding +tagSkeletonBinding is_lne body_skel body_arg_occs (StgNonRec bndr rhs) + = (let_skel, arg_occs, scope, bind') + where + (rhs_skel, rhs_arg_occs, rhs') = tagSkeletonRhs bndr rhs + arg_occs = (body_arg_occs `unionVarSet` rhs_arg_occs) `delVarSet` bndr + bind_skel + | is_lne = rhs_skel -- no closure is allocated for let-no-escapes + | otherwise = ClosureSk bndr (freeVarsOfRhs rhs) rhs_skel + let_skel = bothSk body_skel bind_skel + occurs_as_arg = bndr `elemVarSet` body_arg_occs + -- Compared to the recursive case, this exploits the fact that @bndr@ is + -- never free in @rhs@. + scope = body_skel + bind' = StgNonRec (BindsClosure bndr occurs_as_arg) rhs' +tagSkeletonBinding is_lne body_skel body_arg_occs (StgRec pairs) + = (let_skel, arg_occs, scope, StgRec pairs') + where + (bndrs, _) = unzip pairs + -- Local recursive STG bindings also regard the defined binders as free + -- vars. We want to delete those for our cost model, as these are known + -- calls anyway when we add them to the same top-level recursive group as + -- the top-level binding currently being analysed. + skel_occs_rhss' = map (uncurry tagSkeletonRhs) pairs + rhss_arg_occs = map sndOf3 skel_occs_rhss' + scope_occs = unionVarSets (body_arg_occs:rhss_arg_occs) + arg_occs = scope_occs `delVarSetList` bndrs + -- @skel_rhss@ aren't yet wrapped in closures. We'll do that in a moment, + -- but we also need the un-wrapped skeletons for calculating the @scope@ + -- of the group, as the outer closures don't contribute to closure growth + -- when we lift this specific binding. + scope = foldr (bothSk . fstOf3) body_skel skel_occs_rhss' + -- Now we can build the actual Skeleton for the expression just by + -- iterating over each bind pair. 
+ (bind_skels, pairs') = unzip (zipWith single_bind bndrs skel_occs_rhss') + let_skel = foldr bothSk body_skel bind_skels + single_bind bndr (skel_rhs, _, rhs') = (bind_skel, (bndr', rhs')) + where + -- Here, we finally add the closure around each @skel_rhs@. + bind_skel + | is_lne = skel_rhs -- no closure is allocated for let-no-escapes + | otherwise = ClosureSk bndr fvs skel_rhs + fvs = freeVarsOfRhs rhs' `dVarSetMinusVarSet` mkVarSet bndrs + bndr' = BindsClosure bndr (bndr `elemVarSet` scope_occs) + +tagSkeletonRhs :: Id -> CgStgRhs -> (Skeleton, IdSet, LlStgRhs) +tagSkeletonRhs _ (StgRhsCon ccs dc args) + = (NilSk, mkArgOccs args, StgRhsCon ccs dc args) +tagSkeletonRhs bndr (StgRhsClosure fvs ccs upd bndrs body) + = (rhs_skel, body_arg_occs, StgRhsClosure fvs ccs upd bndrs' body') + where + bndrs' = map BoringBinder bndrs + (body_skel, body_arg_occs, body') = tagSkeletonExpr body + rhs_skel = rhsSk (rhsDmdShell bndr) body_skel + +-- | How many times will the lambda body of the RHS bound to the given +-- identifier be evaluated, relative to its defining context? This function +-- computes the answer in form of a 'DmdShell'. +rhsDmdShell :: Id -> DmdShell +rhsDmdShell bndr + | is_thunk = oneifyDmd ds + | otherwise = peelManyCalls (idArity bndr) cd + where + is_thunk = idArity bndr == 0 + -- Let's pray idDemandInfo is still OK after unarise... + (ds, cd) = toCleanDmd (idDemandInfo bndr) + +tagSkeletonAlt :: CgStgAlt -> (Skeleton, IdSet, LlStgAlt) +tagSkeletonAlt (con, bndrs, rhs) + = (alt_skel, arg_occs, (con, map BoringBinder bndrs, rhs')) + where + (alt_skel, alt_arg_occs, rhs') = tagSkeletonExpr rhs + arg_occs = alt_arg_occs `delVarSetList` bndrs + +-- | Combines several heuristics to decide whether to lambda-lift a given +-- @let@-binding to top-level. See "GHC.Stg.Lift.Analysis#when" for details. +goodToLift + :: DynFlags + -> TopLevelFlag + -> RecFlag + -> (DIdSet -> DIdSet) -- ^ An expander function, turning 'InId's into + -- 'OutId's. 
See 'GHC.Stg.Lift.Monad.liftedIdsExpander'. + -> [(BinderInfo, LlStgRhs)] + -> Skeleton + -> Maybe DIdSet -- ^ @Just abs_ids@ <=> This binding is beneficial to + -- lift and @abs_ids@ are the variables it would + -- abstract over +goodToLift dflags top_lvl rec_flag expander pairs scope = decide + [ ("top-level", isTopLevel top_lvl) -- keep in sync with Note [When to lift] + , ("memoized", any_memoized) + , ("argument occurrences", arg_occs) + , ("join point", is_join_point) + , ("abstracts join points", abstracts_join_ids) + , ("abstracts known local function", abstracts_known_local_fun) + , ("args spill on stack", args_spill_on_stack) + , ("increases allocation", inc_allocs) + ] where + decide deciders + | not (fancy_or deciders) + = llTrace "stgLiftLams:lifting" + (ppr bndrs <+> ppr abs_ids $$ + ppr allocs $$ + ppr scope) $ + Just abs_ids + | otherwise + = Nothing + ppr_deciders = vcat . map (text . fst) . filter snd + fancy_or deciders + = llTrace "stgLiftLams:goodToLift" (ppr bndrs $$ ppr_deciders deciders) $ + any snd deciders + + bndrs = map (binderInfoBndr . fst) pairs + bndrs_set = mkVarSet bndrs + rhss = map snd pairs + + -- First objective: Calculate @abs_ids@, e.g. the former free variables + -- the lifted binding would abstract over. We have to merge the free + -- variables of all RHS to get the set of variables that will have to be + -- passed through parameters. + fvs = unionDVarSets (map freeVarsOfRhs rhss) + -- To lift the binding to top-level, we want to delete the lifted binders + -- themselves from the free var set. Local let bindings track recursive + -- occurrences in their free variable set. We neither want to apply our + -- cost model to them (see 'tagSkeletonRhs'), nor pass them as parameters + -- when lifted, as these are known calls. We call the resulting set the + -- identifiers we abstract over, thus @abs_ids@. These are all 'OutId's. + -- We will save the set in 'LiftM.e_expansions' for each of the variables + -- if we perform the lift. 
+ abs_ids = expander (delDVarSetList fvs bndrs) + + -- We don't lift updatable thunks or constructors + any_memoized = any is_memoized_rhs rhss + is_memoized_rhs StgRhsCon{} = True + is_memoized_rhs (StgRhsClosure _ _ upd _ _) = isUpdatable upd + + -- Don't lift binders occurring as arguments. This would result in complex + -- argument expressions which would have to be given a name, reintroducing + -- the very allocation at each call site that we wanted to get rid off in + -- the first place. + arg_occs = or (mapMaybe (binderInfoOccursAsArg . fst) pairs) + + -- These don't allocate anyway. + is_join_point = any isJoinId bndrs + + -- Abstracting over join points/let-no-escapes spoils them. + abstracts_join_ids = any isJoinId (dVarSetElems abs_ids) + + -- Abstracting over known local functions that aren't floated themselves + -- turns a known, fast call into an unknown, slow call: + -- + -- let f x = ... + -- g y = ... f x ... -- this was a known call + -- in g 4 + -- + -- After lifting @g@, but not @f@: + -- + -- l_g f y = ... f y ... -- this is now an unknown call + -- let f x = ... + -- in l_g f 4 + -- + -- We can abuse the results of arity analysis for this: + -- idArity f > 0 ==> known + known_fun id = idArity id > 0 + abstracts_known_local_fun + = not (liftLamsKnown dflags) && any known_fun (dVarSetElems abs_ids) + + -- Number of arguments of a RHS in the current binding group if we decide + -- to lift it + n_args + = length + . StgToCmm.Closure.nonVoidIds -- void parameters don't appear in Cmm + . (dVarSetElems abs_ids ++) + . rhsLambdaBndrs + max_n_args + | isRec rec_flag = liftLamsRecArgs dflags + | otherwise = liftLamsNonRecArgs dflags + -- We have 5 hardware registers on x86_64 to pass arguments in. Any excess + -- args are passed on the stack, which means slow memory accesses + args_spill_on_stack + | Just n <- max_n_args = maximum (map n_args rhss) > n + | otherwise = False + + -- We only perform the lift if allocations didn't increase. 
+ -- Note that @clo_growth@ will be 'infinity' if there was positive growth + -- under a multi-shot lambda. + -- Also, abstracting over LNEs is unacceptable. LNEs might return + -- unlifted tuples, which idClosureFootprint can't cope with. + inc_allocs = abstracts_join_ids || allocs > 0 + allocs = clo_growth + mkIntWithInf (negate closuresSize) + -- We calculate and then add up the size of each binding's closure. + -- GHC does not currently share closure environments, and we either lift + -- the entire recursive binding group or none of it. + closuresSize = sum $ flip map rhss $ \rhs -> + closureSize dflags + . dVarSetElems + . expander + . flip dVarSetMinusVarSet bndrs_set + $ freeVarsOfRhs rhs + clo_growth = closureGrowth expander (idClosureFootprint dflags) bndrs_set abs_ids scope + +rhsLambdaBndrs :: LlStgRhs -> [Id] +rhsLambdaBndrs StgRhsCon{} = [] +rhsLambdaBndrs (StgRhsClosure _ _ _ bndrs _) = map binderInfoBndr bndrs + +-- | The size in words of a function closure closing over the given 'Id's, +-- including the header. +closureSize :: DynFlags -> [Id] -> WordOff +closureSize dflags ids = words + sTD_HDR_SIZE dflags + -- We go through sTD_HDR_SIZE rather than fixedHdrSizeW so that we don't + -- optimise differently when profiling is enabled. + where + (words, _, _) + -- Functions have a StdHeader (as opposed to ThunkHeader). + = StgToCmm.Layout.mkVirtHeapOffsets dflags StgToCmm.Layout.StdHeader + . StgToCmm.Closure.addIdReps + . StgToCmm.Closure.nonVoidIds + $ ids + +-- | The number of words a single 'Id' adds to a closure's size. +-- Note that this can't handle unboxed tuples (which may still be present in +-- let-no-escapes, even after Unarise), in which case +-- @'GHC.StgToCmm.Closure.idPrimRep'@ will crash. +idClosureFootprint:: DynFlags -> Id -> WordOff +idClosureFootprint dflags + = StgToCmm.ArgRep.argRepSizeW dflags + . 
StgToCmm.ArgRep.idArgRep + +-- | @closureGrowth expander sizer f fvs@ computes the closure growth in words +-- as a result of lifting @f@ to top-level. If there was any growing closure +-- under a multi-shot lambda, the result will be 'infinity'. +-- Also see "GHC.Stg.Lift.Analysis#clogro". +closureGrowth + :: (DIdSet -> DIdSet) + -- ^ Expands outer free ids that were lifted to their free vars + -> (Id -> Int) + -- ^ Computes the closure footprint of an identifier + -> IdSet + -- ^ Binding group for which lifting is to be decided + -> DIdSet + -- ^ Free vars of the whole binding group prior to lifting it. These must be + -- available at call sites if we decide to lift the binding group. + -> Skeleton + -- ^ Abstraction of the scope of the function + -> IntWithInf + -- ^ Closure growth. 'infinity' indicates there was growth under a + -- (multi-shot) lambda. +closureGrowth expander sizer group abs_ids = go + where + go NilSk = 0 + go (BothSk a b) = go a + go b + go (AltSk a b) = max (go a) (go b) + go (ClosureSk _ clo_fvs rhs) + -- If no binder of the @group@ occurs free in the closure, the lifting + -- won't have any effect on it and we can omit the recursive call. + | n_occs == 0 = 0 + -- Otherwise, we account the cost of allocating the closure and add it to + -- the closure growth of its RHS. + | otherwise = mkIntWithInf cost + go rhs + where + n_occs = sizeDVarSet (clo_fvs' `dVarSetIntersectVarSet` group) + -- What we close over considering prior lifting decisions + clo_fvs' = expander clo_fvs + -- Variables that would additionally occur free in the closure body if + -- we lift @f@ + newbies = abs_ids `minusDVarSet` clo_fvs' + -- Lifting @f@ removes @f@ from the closure but adds all @newbies@ + cost = foldDVarSet (\id size -> sizer id + size) 0 newbies - n_occs + go (RhsSk body_dmd body) + -- The conservative assumption would be that + -- 1. Every RHS with positive growth would be called multiple times, + -- modulo thunks. + -- 2. 
Every RHS with negative growth wouldn't be called at all. + -- + -- In the first case, we'd have to return 'infinity', while in the + -- second case, we'd have to return 0. But we can do far better + -- considering information from the demand analyser, which provides us + -- with conservative estimates on minimum and maximum evaluation + -- cardinality. The @body_dmd@ part of 'RhsSk' is the result of + -- 'rhsDmdShell' and accurately captures the cardinality of the RHSs body + -- relative to its defining context. + | isAbsDmd body_dmd = 0 + | cg <= 0 = if isStrictDmd body_dmd then cg else 0 + | isUsedOnce body_dmd = cg + | otherwise = infinity + where + cg = go body diff --git a/compiler/GHC/Stg/Lift/Monad.hs b/compiler/GHC/Stg/Lift/Monad.hs new file mode 100644 index 0000000000..7d17e53cd9 --- /dev/null +++ b/compiler/GHC/Stg/Lift/Monad.hs @@ -0,0 +1,348 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE TypeFamilies #-} + +-- | Hides away distracting bookkeeping while lambda lifting into a 'LiftM' +-- monad. 
+module GHC.Stg.Lift.Monad ( + decomposeStgBinding, mkStgBinding, + Env (..), + -- * #floats# Handling floats + -- $floats + FloatLang (..), collectFloats, -- Exported just for the docs + -- * Transformation monad + LiftM, runLiftM, withCaffyness, + -- ** Adding bindings + startBindingGroup, endBindingGroup, addTopStringLit, addLiftedBinding, + -- ** Substitution and binders + withSubstBndr, withSubstBndrs, withLiftedBndr, withLiftedBndrs, + -- ** Occurrences + substOcc, isLifted, formerFreeVars, liftedIdsExpander + ) where + +#include "HsVersions.h" + +import GhcPrelude + +import BasicTypes +import CostCentre ( isCurrentCCS, dontCareCCS ) +import DynFlags +import FastString +import Id +import IdInfo +import Name +import Outputable +import OrdList +import GHC.Stg.Subst +import GHC.Stg.Syntax +import Type +import UniqSupply +import Util +import VarEnv +import VarSet + +import Control.Arrow ( second ) +import Control.Monad.Trans.Class +import Control.Monad.Trans.RWS.Strict ( RWST, runRWST ) +import qualified Control.Monad.Trans.RWS.Strict as RWS +import Control.Monad.Trans.Cont ( ContT (..) ) +import Data.ByteString ( ByteString ) + +-- | @uncurry 'mkStgBinding' . 'decomposeStgBinding' = id@ +decomposeStgBinding :: GenStgBinding pass -> (RecFlag, [(BinderP pass, GenStgRhs pass)]) +decomposeStgBinding (StgRec pairs) = (Recursive, pairs) +decomposeStgBinding (StgNonRec bndr rhs) = (NonRecursive, [(bndr, rhs)]) + +mkStgBinding :: RecFlag -> [(BinderP pass, GenStgRhs pass)] -> GenStgBinding pass +mkStgBinding Recursive = StgRec +mkStgBinding NonRecursive = uncurry StgNonRec . head + +-- | Environment threaded around in a scoped, @Reader@-like fashion. +data Env + = Env + { e_dflags :: !DynFlags + -- ^ Read-only. + , e_subst :: !Subst + -- ^ We need to track the renamings of local 'InId's to their lifted 'OutId', + -- because shadowing might make a closure's free variables unavailable at its + -- call sites. 
Consider: + -- @ + -- let f y = x + y in let x = 4 in f x + -- @ + -- Here, @f@ can't be lifted to top-level, because its free variable @x@ isn't + -- available at its call site. + , e_expansions :: !(IdEnv DIdSet) + -- ^ Lifted 'Id's don't occur as free variables in any closure anymore, because + -- they are bound at the top-level. Every occurrence must supply the formerly + -- free variables of the lifted 'Id', so they in turn become free variables of + -- the call sites. This environment tracks this expansion from lifted 'Id's to + -- their free variables. + -- + -- 'InId's to 'OutId's. + -- + -- Invariant: 'Id's not present in this map won't be substituted. + , e_in_caffy_context :: !Bool + -- ^ Are we currently analysing within a caffy context (e.g. the containing + -- top-level binder's 'idCafInfo' is 'MayHaveCafRefs')? If not, we can safely + -- assume that functions we lift out aren't caffy either. + } + +emptyEnv :: DynFlags -> Env +emptyEnv dflags = Env dflags emptySubst emptyVarEnv False + + +-- Note [Handling floats] +-- ~~~~~~~~~~~~~~~~~~~~~~ +-- $floats +-- Consider the following expression: +-- +-- @ +-- f x = +-- let g y = ... f y ... +-- in g x +-- @ +-- +-- What happens when we want to lift @g@? Normally, we'd put the lifted @l_g@ +-- binding above the binding for @f@: +-- +-- @ +-- g f y = ... f y ... +-- f x = g f x +-- @ +-- +-- But this very unnecessarily turns a known call to @f@ into an unknown one, in +-- addition to complicating matters for the analysis. +-- Instead, we'd really like to put both functions in the same recursive group, +-- thereby preserving the known call: +-- +-- @ +-- Rec { +-- g y = ... f y ... +-- f x = g x +-- } +-- @ +-- +-- But we don't want this to happen for just /any/ binding. That would create +-- possibly huge recursive groups in the process, calling for an occurrence +-- analyser on STG. 
+-- So, we need to track when we lift a binding out of a recursive RHS and add +-- the binding to the same recursive group as the enclosing recursive binding +-- (which must have either already been at the top-level or decided to be +-- lifted itself in order to preserve the known call). +-- +-- This is done by expressing this kind of nesting structure as a 'Writer' over +-- @['FloatLang']@ and flattening this expression in 'runLiftM' by a call to +-- 'collectFloats'. +-- API-wise, the analysis will not need to know about the whole 'FloatLang' +-- business and will just manipulate it indirectly through actions in 'LiftM'. + +-- | We need to detect when we are lifting something out of the RHS of a +-- recursive binding (cf. "GHC.Stg.Lift.Monad#floats"), in which case that +-- binding needs to be added to the same top-level recursive group. This +-- requires we detect a certain nesting structure, which is encoded by +-- 'StartBindingGroup' and 'EndBindingGroup'. +-- +-- Note that 'collectFloats' will only ever care if the current binding to be +-- lifted (through 'LiftedBinding') will occur inside such a binding group or +-- not, i.e. it doesn't care about the nesting level as long as it's greater than 0. +data FloatLang + = StartBindingGroup + | EndBindingGroup + | PlainTopBinding OutStgTopBinding + | LiftedBinding OutStgBinding + +instance Outputable FloatLang where + ppr StartBindingGroup = char '(' + ppr EndBindingGroup = char ')' + ppr (PlainTopBinding StgTopStringLit{}) = text "<str>" + ppr (PlainTopBinding (StgTopLifted b)) = ppr (LiftedBinding b) + ppr (LiftedBinding bind) = (if isRec rec then char 'r' else char 'n') <+> ppr (map fst pairs) + where + (rec, pairs) = decomposeStgBinding bind + +-- | Flattens an expression in @['FloatLang']@ into an STG program, see #floats. +-- Important pre-conditions: The nesting of opening 'StartBindingGroup's and +-- closing 'EndBindingGroup's is balanced. 
-- Also, it is crucial that every binding
-- group has at least one recursive binding inside. Otherwise there's no point
-- in announcing the binding group in the first place and an @ASSERT@ will
-- trigger.
collectFloats :: [FloatLang] -> [OutStgTopBinding]
collectFloats = go (0 :: Int) []
  where
    -- @go n binds floats@: @n@ is the number of binding groups that are
    -- currently open, @binds@ are the lifted bindings collected since the
    -- outermost open group started (most recent first, as they are consed on).
    go 0 [] [] = []
    -- Input exhausted although the first equation did not match, i.e. a group
    -- is still open or bindings are still pending.
    go _ _ [] = pprPanic "collectFloats" (text "unterminated group")
    go n binds (f:rest) = case f of
      StartBindingGroup -> go (n+1) binds rest
      EndBindingGroup
        | n == 0 -> pprPanic "collectFloats" (text "no group to end")
        -- Closing the outermost group: emit everything collected so far as
        -- one recursive top-level binding and start afresh.
        | n == 1 -> StgTopLifted (merge_binds binds) : go 0 [] rest
        -- Closing a nested group: keep collecting into the enclosing group.
        | otherwise -> go (n-1) binds rest
      PlainTopBinding top_bind
        | n == 0 -> top_bind : go n binds rest
        | otherwise -> pprPanic "collectFloats" (text "plain top binding inside group")
      LiftedBinding bind
        -- Outside of any group, a lifted binding becomes a top-level binding
        -- of its own; inside a group it is only recorded for later merging.
        | n == 0 -> StgTopLifted (rm_cccs bind) : go n binds rest
        | otherwise -> go n (bind:binds) rest

    -- Applies @f@ to every right-hand side of a binding, going through
    -- 'decomposeStgBinding' and 'mkStgBinding'.
    map_rhss f = uncurry mkStgBinding . second (map (second f)) . decomposeStgBinding
    rm_cccs = map_rhss removeRhsCCCS
    -- Merges all collected bindings into a single 'StgRec', after applying
    -- 'removeRhsCCCS' to each right-hand side. Asserts that at least one of
    -- the collected bindings is itself a 'StgRec'.
    merge_binds binds = ASSERT( any is_rec binds )
                        StgRec (concatMap (snd . decomposeStgBinding . rm_cccs) binds)
    is_rec StgRec{} = True
    is_rec _ = False

-- | Replaces the cost centre stack of a right-hand side by 'dontCareCCS' if it
-- satisfies 'isCurrentCCS'. Any other right-hand side is returned unchanged.
--
-- Omitting this makes for strange closure allocation schemes that crash the
-- GC.
removeRhsCCCS :: GenStgRhs pass -> GenStgRhs pass
removeRhsCCCS (StgRhsClosure ext ccs upd bndrs body)
  | isCurrentCCS ccs
  = StgRhsClosure ext dontCareCCS upd bndrs body
removeRhsCCCS (StgRhsCon ccs con args)
  | isCurrentCCS ccs
  = StgRhsCon dontCareCCS con args
removeRhsCCCS rhs = rhs

-- | The analysis monad consists of the following 'RWST' components:
--
--     * 'Env': Reader-like context. Contains a substitution, info about
--       how lifted identifiers are to be expanded into applications and details
--       such as 'DynFlags' and a flag helping with determining if a lifted
--       binding is caffy.
--
--     * @'OrdList' 'FloatLang'@: Writer output for the resulting STG program.
--
--     * No pure state component
--
--     * But wrapping around 'UniqSM' for generating fresh lifted binders.
--       (The @uniqAway@ approach could give the same name to two different
--       lifted binders, so this is necessary.)
newtype LiftM a
  = LiftM { unwrapLiftM :: RWST Env (OrdList FloatLang) () UniqSM a }
  deriving (Functor, Applicative, Monad)

-- | The 'DynFlags' are read from the 'e_dflags' field of the 'Env'.
instance HasDynFlags LiftM where
  getDynFlags = LiftM (RWS.asks e_dflags)

-- | All operations are delegated to the underlying 'UniqSM' via 'lift'.
instance MonadUnique LiftM where
  getUniqueSupplyM = LiftM (lift getUniqueSupplyM)
  getUniqueM = LiftM (lift getUniqueM)
  getUniquesM = LiftM (lift getUniquesM)

-- | Runs a 'LiftM' action in the environment @'emptyEnv' dflags@ with the given
-- unique supply. The action's result is discarded; only its 'FloatLang' output
-- is kept and turned into top-level bindings by 'collectFloats'.
runLiftM :: DynFlags -> UniqSupply -> LiftM () -> [OutStgTopBinding]
runLiftM dflags us (LiftM m) = collectFloats (fromOL floats)
  where
    (_, _, floats) = initUs_ us (runRWST m (emptyEnv dflags) ())

-- | Assumes a given caffyness for the execution of the passed action, which
-- influences the 'cafInfo' of lifted bindings.
withCaffyness :: Bool -> LiftM a -> LiftM a
withCaffyness caffy action
  = LiftM (RWS.local (\e -> e { e_in_caffy_context = caffy }) (unwrapLiftM action))

-- | Writes a plain 'StgTopStringLit' to the output.
addTopStringLit :: OutId -> ByteString -> LiftM ()
addTopStringLit id = LiftM . RWS.tell . unitOL . PlainTopBinding . StgTopStringLit id

-- | Starts a recursive binding group. See #floats# and 'collectFloats'.
startBindingGroup :: LiftM ()
startBindingGroup = LiftM $ RWS.tell $ unitOL $ StartBindingGroup

-- | Ends a recursive binding group. See #floats# and 'collectFloats'.
endBindingGroup :: LiftM ()
endBindingGroup = LiftM $ RWS.tell $ unitOL $ EndBindingGroup

-- | Lifts a binding to top-level. Depending on whether it's declared inside
-- a recursive RHS (see #floats# and 'collectFloats'), this might be added to
-- an existing recursive top-level binding group.
addLiftedBinding :: OutStgBinding -> LiftM ()
addLiftedBinding = LiftM . RWS.tell . unitOL . LiftedBinding

-- | Takes a binder and a continuation which is called with the substituted
-- binder. The continuation will be evaluated in a 'LiftM' context in which that
-- binder is deemed in scope. Think of it as a 'RWS.local' computation: After
-- the continuation finishes, the new binding won't be in scope anymore.
withSubstBndr :: Id -> (Id -> LiftM a) -> LiftM a
withSubstBndr bndr inner = LiftM $ do
  subst <- RWS.asks e_subst
  -- 'substBndr' yields the substituted binder together with the substitution
  -- that is installed (only) for the continuation.
  let (bndr', subst') = substBndr bndr subst
  RWS.local (\e -> e { e_subst = subst' }) (unwrapLiftM (inner bndr'))

-- | See 'withSubstBndr'. Applies it to every binder of the traversable, in
-- traversal order, by chaining the continuations through 'ContT'.
withSubstBndrs :: Traversable f => f Id -> (f Id -> LiftM a) -> LiftM a
withSubstBndrs = runContT . traverse (ContT . withSubstBndr)

-- | Similarly to 'withSubstBndr', this function takes a set of variables to
-- abstract over, the binder to lift (and generate a fresh, substituted name
-- for) and a continuation in which that fresh, lifted binder is in scope.
--
-- It takes care of all the details involved with copying and adjusting the
-- binder, fresh name generation and caffyness.
withLiftedBndr :: DIdSet -> Id -> (Id -> LiftM a) -> LiftM a
withLiftedBndr abs_ids bndr inner = do
  uniq <- getUniqueM
  -- The lifted binder is named after the original one, prefixed by @$l@.
  let str = "$l" ++ occNameString (getOccName bndr)
  -- Its type is built by 'mkLamTypes' from the abstracted variables and the
  -- type of the original binder.
  let ty = mkLamTypes (dVarSetElems abs_ids) (idType bndr)
  -- When the enclosing top-level binding is not caffy, then the lifted
  -- binding will not be caffy either. If we don't recognize this, non-caffy
  -- things call caffy things and then codegen screws up.
  in_caffy_ctxt <- LiftM (RWS.asks e_in_caffy_context)
  let caf_info = if in_caffy_ctxt then MayHaveCafRefs else NoCafRefs
  let bndr'
        -- See Note [transferPolyIdInfo] in Id.hs. We need to do this at least
        -- for arity information.
        = transferPolyIdInfo bndr (dVarSetElems abs_ids)
        -- Otherwise we confuse code gen if bndr was not caffy: the new bndr is
        -- assumed to be caffy and will need an SRT. Transitive call sites might
        -- not be caffy themselves and subsequently will miss a static link
        -- field in their closure. Chaos ensues.
        . flip setIdCafInfo caf_info
        . mkSysLocal (mkFastString str) uniq
        $ ty
  -- For the continuation only: map @bndr@ to @bndr'@ in the substitution
  -- (with @bndr'@ added to its in-scope set) and record @abs_ids@ as the
  -- expansion of @bndr@.
  LiftM $ RWS.local
    (\e -> e
      { e_subst = extendSubst bndr bndr' $ extendInScope bndr' $ e_subst e
      , e_expansions = extendVarEnv (e_expansions e) bndr abs_ids
      })
    (unwrapLiftM (inner bndr'))

-- | See 'withLiftedBndr'. Applies it, with the same set of abstracted
-- variables, to every binder of the traversable.
withLiftedBndrs :: Traversable f => DIdSet -> f Id -> (f Id -> LiftM a) -> LiftM a
withLiftedBndrs abs_ids = runContT . traverse (ContT . withLiftedBndr abs_ids)

-- | Substitutes a binder /occurrence/, which was brought in scope earlier by
-- 'withSubstBndr'\/'withLiftedBndr'.
substOcc :: Id -> LiftM Id
substOcc id = LiftM (RWS.asks (lookupIdSubst id . e_subst))

-- | Whether the given binding was decided to be lambda lifted, i.e. whether
-- it has an entry in 'e_expansions'.
isLifted :: InId -> LiftM Bool
isLifted bndr = LiftM (RWS.asks (elemVarEnv bndr . e_expansions))

-- | Returns an empty list for a binding that was not lifted and the list of all
-- local variables the binding abstracts over (so, exactly the additional
-- arguments at adjusted call sites) otherwise.
formerFreeVars :: InId -> LiftM [OutId]
formerFreeVars f = LiftM $ do
  expansions <- RWS.asks e_expansions
  pure $ case lookupVarEnv expansions f of
    Nothing -> []
    Just fvs -> dVarSetElems fvs

-- | Creates an /expander function/ for the current set of lifted binders.
-- This expander function will replace any 'InId' by their corresponding 'OutId'
-- and, in addition, will expand any lifted binders by the former free variables
-- it abstracts over.
liftedIdsExpander :: LiftM (DIdSet -> DIdSet)
liftedIdsExpander = LiftM $ do
  expansions <- RWS.asks e_expansions
  subst <- RWS.asks e_subst
  -- We use @noWarnLookupIdSubst@ here in order to suppress "not in scope"
  -- warnings generated by 'lookupIdSubst' due to local bindings within RHS.
  -- These are not in the InScopeSet of @subst@ and extending the InScopeSet in
  -- @goodToLift@/@closureGrowth@ before passing it on to @expander@ is too much
  -- trouble.
  let go set fv = case lookupVarEnv expansions fv of
        Nothing -> extendDVarSet set (noWarnLookupIdSubst fv subst) -- Not lifted
        -- Lifted: the binder itself is dropped in favour of its expansion.
        Just fvs' -> unionDVarSet set fvs'
  let expander fvs = foldl' go emptyDVarSet (dVarSetElems fvs)
  pure expander