diff options
author | Joachim Breitner <mail@joachim-breitner.de> | 2018-04-06 17:26:45 -0400 |
---|---|---|
committer | Joachim Breitner <mail@joachim-breitner.de> | 2018-04-09 11:25:06 -0400 |
commit | b14c03737574895718eed786a60dfdfd42ab49ce (patch) | |
tree | cb649dc1d68784fe8a8b8e8bdc5d37a252fdfcf4 /compiler/simplCore | |
parent | 8b823f270e53627ddca1a993c05f1ab556742d96 (diff) | |
download | haskell-b14c03737574895718eed786a60dfdfd42ab49ce.tar.gz |
Some cleanup of the Exitification code
based on a thorough review by Simon in comments
https://ghc.haskell.org/trac/ghc/ticket/14152#comment:33
through 37.
The changes are:
* `isExitJoinId` is moved to `SimplUtils`, because
it is only valid when occurrence information is up-to-date.
* Abstracted variables are properly sorted using `sortQuantVars`
* Exitification does not set occ info.
And then minor quibles to notes and avoiding some unhelpful shadowing
of local names.
Differential Revision: https://phabricator.haskell.org/D4576
Diffstat (limited to 'compiler/simplCore')
-rw-r--r-- | compiler/simplCore/Exitify.hs | 45 | ||||
-rw-r--r-- | compiler/simplCore/SimplUtils.hs | 12 |
2 files changed, 36 insertions, 21 deletions
diff --git a/compiler/simplCore/Exitify.hs b/compiler/simplCore/Exitify.hs index cf6a930d3e..570186e219 100644 --- a/compiler/simplCore/Exitify.hs +++ b/compiler/simplCore/Exitify.hs @@ -48,16 +48,19 @@ import VarEnv import CoreFVs import FastString import Type +import MkCore ( sortQuantVars ) import Data.Bifunctor import Control.Monad -- | Traverses the AST, simply to find all joinrecs and call 'exitify' on them. +-- The really interesting function is exitify exitifyProgram :: CoreProgram -> CoreProgram exitifyProgram binds = map goTopLvl binds where goTopLvl (NonRec v e) = NonRec v (go in_scope_toplvl e) goTopLvl (Rec pairs) = Rec (map (second (go in_scope_toplvl)) pairs) + -- Top-level bindings are never join points in_scope_toplvl = emptyInScopeSet `extendInScopeSetList` bindersOfBinds binds @@ -91,6 +94,10 @@ exitifyProgram binds = map goTopLvl binds is_join_rec = any (isJoinId . fst) pairs in_scope' = in_scope `extendInScopeSetList` bindersOf (Rec pairs) + +-- | State Monad used inside `exitify` +type ExitifyM = State [(JoinId, CoreExpr)] + -- | Given a recursive group of a joinrec, identifies “exit paths” and binds them as -- join-points outside the joinrec. exitify :: InScopeSet -> [(Var,CoreExpr)] -> (CoreExpr -> CoreExpr) @@ -120,11 +127,13 @@ exitify in_scope pairs = -- checks if there are no more recursive calls, if so, abstracts over -- variables bound on the way and lifts it out as a join point. -- - -- It uses a state monad to keep track of floated binds + -- ExitifyM is a state monad to keep track of floated binds go :: [Var] -- ^ variables to abstract over -> CoreExprWithFVs -- ^ current expression in tail position - -> State [(Id, CoreExpr)] CoreExpr + -> ExitifyM CoreExpr + -- We first look at the expression (no matter what it shape is) + -- and determine if we can turn it into a exit join point go captured ann_e -- Do not touch an expression that is already a join jump where all arguments -- are captured variables. See Note [Idempotency] @@ -145,13 +154,13 @@ exitify in_scope pairs = -- We have something to float out! | is_exit = do -- Assemble the RHS of the exit join point - let rhs = mkLams args e + let rhs = mkLams abs_vars e ty = exprType rhs let avoid = in_scope `extendInScopeSetList` captured -- Remember this binding under a suitable name - v <- addExit avoid ty (length args) rhs + v <- addExit avoid ty (length abs_vars) rhs -- And jump to it from here - return $ mkVarApps (Var v) args + return $ mkVarApps (Var v) abs_vars where -- An exit expression has no recursive calls is_exit = disjointVarSet fvs recursive_calls @@ -166,14 +175,17 @@ exitify in_scope pairs = is_interesting = anyVarSet isLocalId (fvs `minusVarSet` mkVarSet captured) -- The possible arguments of this exit join point - args = filter (`elemVarSet` fvs) captured + abs_vars = sortQuantVars $ filter (`elemVarSet` fvs) captured -- We cannot abstract over join points - captures_join_points = any isJoinId args + captures_join_points = any isJoinId abs_vars e = deAnnotate ann_e fvs = dVarSetToVarSet (freeVarsOf ann_e) + -- We could not turn it into a exit joint point. So now recurse + -- into all expression where eligible exit join points might sit, + -- i.e. into all tail-call positions: -- Case right hand sides are in tail-call position go captured (_, AnnCase scrut bndr ty alts) = do @@ -211,6 +223,8 @@ exitify in_scope pairs = return $ Let bind body' where bind = deAnnBind ann_bind + -- Cannot be turned into an exit join point, but also has no + -- tail-call subexpression. Nothing to do here. go _ ann_e = return (deAnnotate ann_e) @@ -227,14 +241,6 @@ mkExitJoinId in_scope ty join_arity = do where exit_id_tmpl = mkSysLocal (fsLit "exit") initExitJoinUnique ty `asJoinId` join_arity - `setIdOccInfo` exit_occ_info - - -- See Note [Do not inline exit join points] - exit_occ_info = - OneOcc { occ_in_lam = True - , occ_one_br = True - , occ_int_cxt = False - , occ_tail = AlwaysTailCalled join_arity } addExit :: InScopeSet -> Type -> JoinArity -> CoreExpr -> ExitifyM JoinId addExit in_scope ty join_arity rhs = do @@ -245,8 +251,6 @@ addExit in_scope ty join_arity rhs = do return v -type ExitifyM = State [(JoinId, CoreExpr)] - {- Note [Interesting expression] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -381,6 +385,8 @@ joinrecs are nested. Further downside of A: If the exitify function returns annotated expressions, it would have to ensure that the annotations are correct. +We therefore choose B, and calculate the free variables in `exitify`. + Note [Do not inline exit join points] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -399,7 +405,8 @@ To prevent this, we need to recognize exit join points, and then disable inlining. Exit join points, recognizeable using `isExitJoinId` are join points with an -occurence in a recursive group, and can be recognized using `isExitJoinId`. +occurence in a recursive group, and can be recognized (after the occurence +analyzer ran!) using `isExitJoinId`. This function detects joinpoints with `occ_in_lam (idOccinfo id) == True`, because the lambdas of a non-recursive join point are not considered for `occ_in_lam`. For example, in the following code, `j1` is /not/ marked @@ -408,8 +415,6 @@ occ_in_lam, because `j2` is called only once. join j1 x = x+1 join j2 y = join j1 (y+2) -We create exit join point ids with such an `OccInfo`, see `exit_occ_info`. - To prevent inlining, we check for isExitJoinId * In `preInlineUnconditionally` directly. * In `simplLetUnfolding` we simply give exit join points no unfolding, which diff --git a/compiler/simplCore/SimplUtils.hs b/compiler/simplCore/SimplUtils.hs index db26af426e..7c0689d9be 100644 --- a/compiler/simplCore/SimplUtils.hs +++ b/compiler/simplCore/SimplUtils.hs @@ -30,7 +30,10 @@ module SimplUtils ( addValArgTo, addCastTo, addTyArgTo, argInfoExpr, argInfoAppArgs, pushSimplifiedArgs, - abstractFloats + abstractFloats, + + -- Utilities + isExitJoinId ) where #include "HsVersions.h" @@ -2199,6 +2202,13 @@ in PrelRules) mkCase3 _dflags scrut bndr alts_ty alts = return (Case scrut bndr alts_ty alts) +-- See Note [Exitification] and Note [Do not inline exit join points] in Exitify.hs +-- This lives here (and not in Id) becuase occurrence info is only valid on +-- InIds, so it's crucial that isExitJoinId is only called on freshly +-- occ-analysed code. It's not a generic function you can call anywhere. +isExitJoinId :: Var -> Bool +isExitJoinId id = isJoinId id && isOneOcc (idOccInfo id) && occ_in_lam (idOccInfo id) + {- Note [Dead binders] ~~~~~~~~~~~~~~~~~~~~ |