summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
Diffstat (limited to 'compiler')
-rw-r--r--compiler/basicTypes/Id.hs6
-rw-r--r--compiler/basicTypes/Unique.hs4
-rw-r--r--compiler/coreSyn/CoreLint.hs1
-rw-r--r--compiler/coreSyn/CoreSyn.hs10
-rw-r--r--compiler/ghc.cabal.in1
-rw-r--r--compiler/main/DynFlags.hs6
-rw-r--r--compiler/simplCore/CoreMonad.hs2
-rw-r--r--compiler/simplCore/Exitify.hs442
-rw-r--r--compiler/simplCore/SimplCore.hs8
-rw-r--r--compiler/simplCore/SimplUtils.hs1
-rw-r--r--compiler/simplCore/Simplify.hs3
11 files changed, 478 insertions, 6 deletions
diff --git a/compiler/basicTypes/Id.hs b/compiler/basicTypes/Id.hs
index e1902ff853..63ca38cb61 100644
--- a/compiler/basicTypes/Id.hs
+++ b/compiler/basicTypes/Id.hs
@@ -74,7 +74,7 @@ module Id (
DictId, isDictId, isEvVar,
-- ** Join variables
- JoinId, isJoinId, isJoinId_maybe, idJoinArity,
+ JoinId, isJoinId, isJoinId_maybe, idJoinArity, isExitJoinId,
asJoinId, asJoinId_maybe, zapJoinId,
-- ** Inline pragma stuff
@@ -497,6 +497,10 @@ isJoinId_maybe id
_ -> Nothing
| otherwise = Nothing
+-- see Note [Exitification] and see Note [Do not inline exit join points]
+isExitJoinId :: Var -> Bool
+isExitJoinId id = isJoinId id && isOneOcc (idOccInfo id) && occ_in_lam (idOccInfo id)
+
idDataCon :: Id -> DataCon
-- ^ Get from either the worker or the wrapper 'Id' to the 'DataCon'. Currently used only in the desugarer.
--
diff --git a/compiler/basicTypes/Unique.hs b/compiler/basicTypes/Unique.hs
index a2792e196a..30de08ebd4 100644
--- a/compiler/basicTypes/Unique.hs
+++ b/compiler/basicTypes/Unique.hs
@@ -36,6 +36,7 @@ module Unique (
deriveUnique, -- Ditto
newTagUnique, -- Used in CgCase
initTyVarUnique,
+ initExitJoinUnique,
nonDetCmpUnique,
isValidKnownKeyUnique, -- Used in PrelInfo.knownKeyNamesOkay
@@ -436,3 +437,6 @@ mkVarOccUnique fs = mkUnique 'i' (uniqueOfFS fs)
mkDataOccUnique fs = mkUnique 'd' (uniqueOfFS fs)
mkTvOccUnique fs = mkUnique 'v' (uniqueOfFS fs)
mkTcOccUnique fs = mkUnique 'c' (uniqueOfFS fs)
+
+initExitJoinUnique :: Unique
+initExitJoinUnique = mkUnique 's' 0
diff --git a/compiler/coreSyn/CoreLint.hs b/compiler/coreSyn/CoreLint.hs
index 6b6d8d9d1e..96c34852ba 100644
--- a/compiler/coreSyn/CoreLint.hs
+++ b/compiler/coreSyn/CoreLint.hs
@@ -268,6 +268,7 @@ coreDumpFlag (CoreDoFloatOutwards {}) = Just Opt_D_verbose_core2core
coreDumpFlag CoreLiberateCase = Just Opt_D_verbose_core2core
coreDumpFlag CoreDoStaticArgs = Just Opt_D_verbose_core2core
coreDumpFlag CoreDoCallArity = Just Opt_D_dump_call_arity
+coreDumpFlag CoreDoExitify = Just Opt_D_dump_exitify
coreDumpFlag CoreDoStrictness = Just Opt_D_dump_stranal
coreDumpFlag CoreDoWorkerWrapper = Just Opt_D_dump_worker_wrapper
coreDumpFlag CoreDoSpecialising = Just Opt_D_dump_spec
diff --git a/compiler/coreSyn/CoreSyn.hs b/compiler/coreSyn/CoreSyn.hs
index 9333d0fcf8..c931bf187d 100644
--- a/compiler/coreSyn/CoreSyn.hs
+++ b/compiler/coreSyn/CoreSyn.hs
@@ -77,7 +77,7 @@ module CoreSyn (
collectAnnArgs, collectAnnArgsTicks,
-- ** Operations on annotations
- deAnnotate, deAnnotate', deAnnAlt,
+ deAnnotate, deAnnotate', deAnnAlt, deAnnBind,
collectAnnBndrs, collectNAnnBndrs,
-- * Orphanhood
@@ -2160,16 +2160,16 @@ deAnnotate' (AnnTick tick body) = Tick tick (deAnnotate body)
deAnnotate' (AnnLet bind body)
= Let (deAnnBind bind) (deAnnotate body)
- where
- deAnnBind (AnnNonRec var rhs) = NonRec var (deAnnotate rhs)
- deAnnBind (AnnRec pairs) = Rec [(v,deAnnotate rhs) | (v,rhs) <- pairs]
-
deAnnotate' (AnnCase scrut v t alts)
= Case (deAnnotate scrut) v t (map deAnnAlt alts)
deAnnAlt :: AnnAlt bndr annot -> Alt bndr
deAnnAlt (con,args,rhs) = (con,args,deAnnotate rhs)
+deAnnBind :: AnnBind b annot -> Bind b
+deAnnBind (AnnNonRec var rhs) = NonRec var (deAnnotate rhs)
+deAnnBind (AnnRec pairs) = Rec [(v,deAnnotate rhs) | (v,rhs) <- pairs]
+
-- | As 'collectBinders' but for 'AnnExpr' rather than 'Expr'
collectAnnBndrs :: AnnExpr bndr annot -> ([bndr], AnnExpr bndr annot)
collectAnnBndrs e
diff --git a/compiler/ghc.cabal.in b/compiler/ghc.cabal.in
index d3cbe9563b..acfaba9b73 100644
--- a/compiler/ghc.cabal.in
+++ b/compiler/ghc.cabal.in
@@ -429,6 +429,7 @@ Library
StgSyn
CallArity
DmdAnal
+ Exitify
WorkWrap
WwLib
FamInst
diff --git a/compiler/main/DynFlags.hs b/compiler/main/DynFlags.hs
index 7602b719cc..1b1837fdc3 100644
--- a/compiler/main/DynFlags.hs
+++ b/compiler/main/DynFlags.hs
@@ -360,6 +360,7 @@ data DumpFlag
| Opt_D_dump_prep
| Opt_D_dump_stg
| Opt_D_dump_call_arity
+ | Opt_D_dump_exitify
| Opt_D_dump_stranal
| Opt_D_dump_str_signatures
| Opt_D_dump_tc
@@ -428,6 +429,7 @@ data GeneralFlag
-- optimisation opts
| Opt_CallArity
+ | Opt_Exitification
| Opt_Strictness
| Opt_LateDmdAnal
| Opt_KillAbsence
@@ -3005,6 +3007,8 @@ dynamic_flags_deps = [
(setDumpFlag Opt_D_dump_stg)
, make_ord_flag defGhcFlag "ddump-call-arity"
(setDumpFlag Opt_D_dump_call_arity)
+ , make_ord_flag defGhcFlag "ddump-exitify"
+ (setDumpFlag Opt_D_dump_exitify)
, make_ord_flag defGhcFlag "ddump-stranal"
(setDumpFlag Opt_D_dump_stranal)
, make_ord_flag defGhcFlag "ddump-str-signatures"
@@ -3706,6 +3710,7 @@ fFlagsDeps = [
flagGhciSpec "break-on-exception" Opt_BreakOnException,
flagSpec "building-cabal-package" Opt_BuildingCabalPackage,
flagSpec "call-arity" Opt_CallArity,
+ flagSpec "exitification" Opt_Exitification,
flagSpec "case-merge" Opt_CaseMerge,
flagSpec "case-folding" Opt_CaseFolding,
flagSpec "cmm-elim-common-blocks" Opt_CmmElimCommonBlocks,
@@ -4159,6 +4164,7 @@ optLevelFlags -- see Note [Documenting optimisation flags]
, ([0], Opt_OmitInterfacePragmas)
, ([1,2], Opt_CallArity)
+ , ([1,2], Opt_Exitification)
, ([1,2], Opt_CaseMerge)
, ([1,2], Opt_CaseFolding)
, ([1,2], Opt_CmmElimCommonBlocks)
diff --git a/compiler/simplCore/CoreMonad.hs b/compiler/simplCore/CoreMonad.hs
index 35790ab925..107440a768 100644
--- a/compiler/simplCore/CoreMonad.hs
+++ b/compiler/simplCore/CoreMonad.hs
@@ -114,6 +114,7 @@ data CoreToDo -- These are diff core-to-core passes,
| CoreDoPrintCore
| CoreDoStaticArgs
| CoreDoCallArity
+ | CoreDoExitify
| CoreDoStrictness
| CoreDoWorkerWrapper
| CoreDoSpecialising
@@ -141,6 +142,7 @@ instance Outputable CoreToDo where
ppr CoreLiberateCase = text "Liberate case"
ppr CoreDoStaticArgs = text "Static argument"
ppr CoreDoCallArity = text "Called arity analysis"
+ ppr CoreDoExitify = text "Exitification transformation"
ppr CoreDoStrictness = text "Demand analysis"
ppr CoreDoWorkerWrapper = text "Worker Wrapper binds"
ppr CoreDoSpecialising = text "Specialise"
diff --git a/compiler/simplCore/Exitify.hs b/compiler/simplCore/Exitify.hs
new file mode 100644
index 0000000000..2d3b5aff55
--- /dev/null
+++ b/compiler/simplCore/Exitify.hs
@@ -0,0 +1,442 @@
+module Exitify ( exitifyProgram ) where
+
+{-
+Note [Exitification]
+~~~~~~~~~~~~~~~~~~~~
+
+This module implements Exitification. The goal is to pull as much code out of
+recursive functions as possible, as the simplifier is better at inlining into
+call-sites that are not in recursive functions.
+
+Example:
+
+ let t = foo bar
+ joinrec go 0 x y = t (x*x)
+ go (n-1) x y = jump go (n-1) (x+y)
+ in …
+
+We’d like to inline `t`, but that does not happen: Because t is a thunk and is
+used in a recursive function, doing so might lose sharing in general. In
+this case, however, `t` is on the _exit path_ of `go`, so called at most once.
+How do we make this clearly visible to the simplifier?
+
+A code path (i.e., an expression in a tail-recursive position) in a recursive
+function is an exit path if it does not contain a recursive call. We can bind
+this expression outside the recursive function, as a join-point.
+
+Example result:
+
+ let t = foo bar
+ join exit x = t (x*x)
+ joinrec go 0 x y = jump exit x
+ go (n-1) x y = jump go (n-1) (x+y)
+ in …
+
+Now `t` is no longer in a recursive function, and good things happen!
+-}
+
+import GhcPrelude
+import Var
+import Id
+import IdInfo
+import CoreSyn
+import CoreUtils
+import State
+import Unique
+import VarSet
+import VarEnv
+import CoreFVs
+import FastString
+import Type
+
+import Data.Bifunctor
+import Control.Monad
+
+-- | Traverses the AST, simply to find all joinrecs and call 'exitify' on them.
+exitifyProgram :: CoreProgram -> CoreProgram
+exitifyProgram binds = map goTopLvl binds
+ where
+ goTopLvl (NonRec v e) = NonRec v (go in_scope_toplvl e)
+ goTopLvl (Rec pairs) = Rec (map (second (go in_scope_toplvl)) pairs)
+
+ in_scope_toplvl = emptyInScopeSet `extendInScopeSetList` bindersOfBinds binds
+
+ go :: InScopeSet -> CoreExpr -> CoreExpr
+ go _ e@(Var{}) = e
+ go _ e@(Lit {}) = e
+ go _ e@(Type {}) = e
+ go _ e@(Coercion {}) = e
+
+ go in_scope (Lam v e') = Lam v (go in_scope' e')
+ where in_scope' = in_scope `extendInScopeSet` v
+ go in_scope (App e1 e2) = App (go in_scope e1) (go in_scope e2)
+ go in_scope (Case scrut bndr ty alts)
+ = Case (go in_scope scrut) bndr ty (map (goAlt in_scope') alts)
+ where in_scope' = in_scope `extendInScopeSet` bndr
+ go in_scope (Cast e' c) = Cast (go in_scope e') c
+ go in_scope (Tick t e') = Tick t (go in_scope e')
+ go in_scope (Let bind body) = goBind in_scope bind (go in_scope' body)
+ where in_scope' = in_scope `extendInScopeSetList` bindersOf bind
+
+ goAlt :: InScopeSet -> CoreAlt -> CoreAlt
+ goAlt in_scope (dc, pats, rhs) = (dc, pats, go in_scope' rhs)
+ where in_scope' = in_scope `extendInScopeSetList` pats
+
+ goBind :: InScopeSet -> CoreBind -> (CoreExpr -> CoreExpr)
+ goBind in_scope (NonRec v rhs) = Let (NonRec v (go in_scope rhs))
+ goBind in_scope (Rec pairs)
+ | is_join_rec = exitify in_scope' pairs'
+ | otherwise = Let (Rec pairs')
+ where pairs' = map (second (go in_scope')) pairs
+ is_join_rec = any (isJoinId . fst) pairs
+ in_scope' = in_scope `extendInScopeSetList` bindersOf (Rec pairs)
+
+-- | Given a recursive group of a joinrec, identifies “exit paths” and binds them as
+-- join-points outside the joinrec.
+exitify :: InScopeSet -> [(Var,CoreExpr)] -> (CoreExpr -> CoreExpr)
+exitify in_scope pairs =
+ \body ->mkExitLets exits (mkLetRec pairs' body)
+ where
+ mkExitLets ((exitId, exitRhs):exits') = mkLetNonRec exitId exitRhs . mkExitLets exits'
+ mkExitLets [] = id
+
+ -- We need the set of free variables of many subexpressions here, so
+ -- annotate the AST with them
+ -- see Note [Calculating free variables]
+ ann_pairs = map (second freeVars) pairs
+
+ -- Which are the recursive calls?
+ recursive_calls = mkVarSet $ map fst pairs
+
+ (pairs',exits) = (`runState` []) $ do
+ forM ann_pairs $ \(x,rhs) -> do
+ -- go past the lambdas of the join point
+ let (args, body) = collectNAnnBndrs (idJoinArity x) rhs
+ body' <- go args body
+ let rhs' = mkLams args body'
+ return (x, rhs')
+
+ -- main working function. Goes through the RHS (tail-call positions only),
+ -- checks if there are no more recursive calls, if so, abstracts over
+ -- variables bound on the way and lifts it out as a join point.
+ --
+ -- It uses a state monad to keep track of floated binds
+ go :: [Var] -- ^ variables to abstract over
+ -> CoreExprWithFVs -- ^ current expression in tail position
+ -> State [(Id, CoreExpr)] CoreExpr
+
+ go captured ann_e
+ -- Do not touch an expression that is already a join jump where all arguments
+ -- are captured variables. See Note [Idempotency]
+ -- But _do_ float join jumps with interesting arguments.
+ -- See Note [Jumps can be interesting]
+ | (Var f, args) <- collectArgs e
+ , isJoinId f
+ , all isCapturedVarArg args
+ = return e
+
+ -- Do not touch a boring expression (see Note [Interesting expression])
+ | is_exit, not is_interesting = return e
+
+ -- Cannot float out if local join points are used, as
+ -- we cannot abstract over them
+ | is_exit, captures_join_points = return e
+
+ -- We have something to float out!
+ | is_exit = do
+ -- Assemble the RHS of the exit join point
+ let rhs = mkLams args e
+ ty = exprType rhs
+ let avoid = in_scope `extendInScopeSetList` captured
+ -- Remember this binding under a suitable name
+ v <- addExit avoid ty (length args) rhs
+ -- And jump to it from here
+ return $ mkVarApps (Var v) args
+ where
+ -- An exit expression has no recursive calls
+ is_exit = disjointVarSet fvs recursive_calls
+
+ -- Used to detect exit expressoins that are already proper exit jumps
+ isCapturedVarArg (Var v) = v `elem` captured
+ isCapturedVarArg _ = False
+
+ -- An interesting exit expression has free, non-imported
+ -- variables from outside the recursive group
+ -- See Note [Interesting expression]
+ is_interesting = anyVarSet isLocalId (fvs `minusVarSet` mkVarSet captured)
+
+ -- The possible arguments of this exit join point
+ args = filter (`elemVarSet` fvs) captured
+
+ -- We cannot abstract over join points
+ captures_join_points = any isJoinId args
+
+ e = deAnnotate ann_e
+ fvs = dVarSetToVarSet (freeVarsOf ann_e)
+
+
+ -- Case right hand sides are in tail-call position
+ go captured (_, AnnCase scrut bndr ty alts) = do
+ alts' <- forM alts $ \(dc, pats, rhs) -> do
+ rhs' <- go (captured ++ [bndr] ++ pats) rhs
+ return (dc, pats, rhs')
+ return $ Case (deAnnotate scrut) bndr ty alts'
+
+ go captured (_, AnnLet ann_bind body)
+ -- join point, RHS and body are in tail-call position
+ | AnnNonRec j rhs <- ann_bind
+ , Just join_arity <- isJoinId_maybe j
+ = do let (params, join_body) = collectNAnnBndrs join_arity rhs
+ join_body' <- go (captured ++ params) join_body
+ let rhs' = mkLams params join_body'
+ body' <- go (captured ++ [j]) body
+ return $ Let (NonRec j rhs') body'
+
+ -- rec join point, RHSs and body are in tail-call position
+ | AnnRec pairs <- ann_bind
+ , isJoinId (fst (head pairs))
+ = do let js = map fst pairs
+ pairs' <- forM pairs $ \(j,rhs) -> do
+ let join_arity = idJoinArity j
+ (params, join_body) = collectNAnnBndrs join_arity rhs
+ join_body' <- go (captured ++ js ++ params) join_body
+ let rhs' = mkLams params join_body'
+ return (j, rhs')
+ body' <- go (captured ++ js) body
+ return $ Let (Rec pairs') body'
+
+ -- normal Let, only the body is in tail-call position
+ | otherwise
+ = do body' <- go (captured ++ bindersOf bind ) body
+ return $ Let bind body'
+ where bind = deAnnBind ann_bind
+
+ go _ ann_e = return (deAnnotate ann_e)
+
+
+-- Picks a new unique, which is disjoint from
+-- * the free variables of the whole joinrec
+-- * any bound variables (captured)
+-- * any exit join points created so far.
+mkExitJoinId :: InScopeSet -> Type -> JoinArity -> ExitifyM JoinId
+mkExitJoinId in_scope ty join_arity = do
+ fs <- get
+ let avoid = in_scope `extendInScopeSetList` (map fst fs)
+ `extendInScopeSet` exit_id_tmpl -- just cosmetics
+ return (uniqAway avoid exit_id_tmpl)
+ where
+ exit_id_tmpl = mkSysLocal (fsLit "exit") initExitJoinUnique ty
+ `asJoinId` join_arity
+ `setIdOccInfo` exit_occ_info
+
+ -- See Note [Do not inline exit join points]
+ exit_occ_info =
+ OneOcc { occ_in_lam = True
+ , occ_one_br = True
+ , occ_int_cxt = False
+ , occ_tail = AlwaysTailCalled join_arity }
+
+addExit :: InScopeSet -> Type -> JoinArity -> CoreExpr -> ExitifyM JoinId
+addExit in_scope ty join_arity rhs = do
+ -- Pick a suitable name
+ v <- mkExitJoinId in_scope ty join_arity
+ fs <- get
+ put ((v,rhs):fs)
+ return v
+
+
+type ExitifyM = State [(JoinId, CoreExpr)]
+
+{-
+Note [Interesting expression]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+We do not want this to happen:
+
+ joinrec go 0 x y = x
+ go (n-1) x y = jump go (n-1) (x+y)
+ in …
+==>
+ join exit x = x
+ joinrec go 0 x y = jump exit x
+ go (n-1) x y = jump go (n-1) (x+y)
+ in …
+
+because the floated exit path (`x`) is simply a parameter of `go`; there are
+not useful interactions exposed this way.
+
+Neither do we want this to happen
+
+ joinrec go 0 x y = x+x
+ go (n-1) x y = jump go (n-1) (x+y)
+ in …
+==>
+ join exit x = x+x
+ joinrec go 0 x y = jump exit x
+ go (n-1) x y = jump go (n-1) (x+y)
+ in …
+
+where the floated expression `x+x` is a bit more complicated, but still not
+intersting.
+
+Expressions are interesting when they move an occurrence of a variable outside
+the recursive `go` that can benefit from being obviously called once, for example:
+ * a local thunk that can then be inlined (see example in note [Exitification])
+ * the parameter of a function, where the demand analyzer then can then
+ see that it is called at most once, and hence improve the function’s
+ strictness signature
+
+So we only hoist an exit expression out if it mentiones at least one free,
+non-imported variable.
+
+Note [Jumps can be interesting]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+A jump to a join point can be interesting, if its arguments contain free
+non-exported variables (z in the following example):
+
+ joinrec go 0 x y = jump j (x+z)
+ go (n-1) x y = jump go (n-1) (x+y)
+ in …
+==>
+ join exit x y = jump j (x+z)
+ joinrec go 0 x y = jump exit x
+ go (n-1) x y = jump go (n-1) (x+y)
+
+
+The join point itself can be interesting, even if none if
+its arguments are (assume `g` to be an imported function that, on its own, does
+not make this interesting):
+
+ join j y = map f y
+ joinrec go 0 x y = jump j (map g x)
+ go (n-1) x y = jump go (n-1) (x+y)
+ in …
+
+Here, `j` would not be inlined because we do not inline something that looks
+like an exit join point (see Note [Do not inline exit join points]).
+
+But after exitification we have
+
+ join j y = map f y
+ join exit x = jump j (map g x)
+ joinrec go 0 x y = jump j (map g x)
+ go (n-1) x y = jump go (n-1) (x+y)
+ in …
+
+and now we can inline `j` and this will allow `map/map` to fire.
+
+
+Note [Idempotency]
+~~~~~~~~~~~~~~~~~~
+
+We do not want this to happen, where we replace the floated expression with
+essentially the same expression:
+
+ join exit x = t (x*x)
+ joinrec go 0 x y = jump exit x
+ go (n-1) x y = jump go (n-1) (x+y)
+ in …
+==>
+ join exit x = t (x*x)
+ join exit' x = jump exit x
+ joinrec go 0 x y = jump exit' x
+ go (n-1) x y = jump go (n-1) (x+y)
+ in …
+
+So when the RHS is a join jump, and all of its arguments are captured variables,
+then we leave it in place.
+
+Note that `jump exit x` in this example looks interesting, as `exit` is a free
+variable. Therefore, idempotency does not simply follow from floating only
+interesting expressions.
+
+Note [Calculating free variables]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+We have two options where to annotate the tree with free variables:
+
+ A) The whole tree.
+ B) Each individual joinrec as we come across it.
+
+Downside of A: We pay the price on the whole module, even outside any joinrecs.
+Downside of B: We pay the price per joinrec, possibly multiple times when
+joinrecs are nested.
+
+Further downside of A: If the exitify function returns annotated expressions,
+it would have to ensure that the annotations are correct.
+
+
+Note [Do not inline exit join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+When we have
+
+ let t = foo bar
+ join exit x = t (x*x)
+ joinrec go 0 x y = jump exit x
+ go (n-1) x y = jump go (n-1) (x+y)
+ in …
+
+we do not want the simplifier to simply inline `exit` back in (which it happily
+would).
+
+To prevent this, we need to recognize exit join points, and then disable
+inlining.
+
+Exit join points, recognizeable using `isExitJoinId` are join points with an
+occurence in a recursive group, and can be recognized using `isExitJoinId`.
+This function detects joinpoints with `occ_in_lam (idOccinfo id) == True`,
+because the lambdas of a non-recursive join point are not considered for
+`occ_in_lam`. For example, in the following code, `j1` is /not/ marked
+occ_in_lam, because `j2` is called only once.
+
+ join j1 x = x+1
+ join j2 y = join j1 (y+2)
+
+We create exit join point ids with such an `OccInfo`, see `exit_occ_info`.
+
+To prevent inlining, we check for that in `preInlineUnconditionally` directly.
+For `postInlineUnconditionally` and unfolding-based inlining, the function
+`simplLetUnfolding` simply gives exit join points no unfolding, which prevents
+this kind of inlining.
+
+Note [Placement of the exitification pass]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+I (Joachim) experimented with multiple positions for the Exitification pass in
+the Core2Core pipeline:
+
+ A) Before the `simpl_phases`
+ B) Between the `simpl_phases` and the "main" simplifier pass
+ C) After demand_analyser
+ D) Before the final simplification phase
+
+Here is the table (this is without inlining join exit points in the final
+simplifier run):
+
+ Program | Allocs | Instrs
+ | ABCD.log A.log B.log C.log D.log | ABCD.log A.log B.log C.log D.log
+----------------|---------------------------------------------------|-------------------------------------------------
+ fannkuch-redux | -99.9% +0.0% -99.9% -99.9% -99.9% | -3.9% +0.5% -3.0% -3.9% -3.9%
+ fasta | -0.0% +0.0% +0.0% -0.0% -0.0% | -8.5% +0.0% +0.0% -0.0% -8.5%
+ fem | 0.0% 0.0% 0.0% 0.0% +0.0% | -2.2% -0.1% -0.1% -2.1% -2.1%
+ fish | 0.0% 0.0% 0.0% 0.0% +0.0% | -3.1% +0.0% -1.1% -1.1% -0.0%
+ k-nucleotide | -91.3% -91.0% -91.0% -91.3% -91.3% | -6.3% +11.4% +11.4% -6.3% -6.2%
+ scs | -0.0% -0.0% -0.0% -0.0% -0.0% | -3.4% -3.0% -3.1% -3.3% -3.3%
+ simple | -6.0% 0.0% -6.0% -6.0% +0.0% | -3.4% +0.0% -5.2% -3.4% -0.1%
+ spectral-norm | -0.0% 0.0% 0.0% -0.0% +0.0% | -2.7% +0.0% -2.7% -5.4% -5.4%
+----------------|---------------------------------------------------|-------------------------------------------------
+ Min | -95.0% -91.0% -95.0% -95.0% -95.0% | -8.5% -3.0% -5.2% -6.3% -8.5%
+ Max | +0.2% +0.2% +0.2% +0.2% +1.5% | +0.4% +11.4% +11.4% +0.4% +1.5%
+ Geometric Mean | -4.7% -2.1% -4.7% -4.7% -4.6% | -0.4% +0.1% -0.1% -0.3% -0.2%
+
+Position A is disqualified, as it does not get rid of the allocations in
+fannkuch-redux.
+Position A and B are disqualified because it increases instructions in k-nucleotide.
+Positions C and D have their advantages: C decreases allocations in simpl, but D instructions in fasta.
+
+Assuming we have a budget of _one_ run of Exitification, then C wins (but we
+could get more from running it multiple times, as seen in fish).
+
+-}
diff --git a/compiler/simplCore/SimplCore.hs b/compiler/simplCore/SimplCore.hs
index afb0804aa8..be44ca86a9 100644
--- a/compiler/simplCore/SimplCore.hs
+++ b/compiler/simplCore/SimplCore.hs
@@ -45,6 +45,7 @@ import Specialise ( specProgram)
import SpecConstr ( specConstrProgram)
import DmdAnal ( dmdAnalProgram )
import CallArity ( callArityAnalProgram )
+import Exitify ( exitifyProgram )
import WorkWrap ( wwTopBinds )
import Vectorise ( vectorise )
import SrcLoc
@@ -122,6 +123,7 @@ getCoreToDo dflags
max_iter = maxSimplIterations dflags
rule_check = ruleCheck dflags
call_arity = gopt Opt_CallArity dflags
+ exitification = gopt Opt_Exitification dflags
strictness = gopt Opt_Strictness dflags
full_laziness = gopt Opt_FullLaziness dflags
do_specialise = gopt Opt_Specialise dflags
@@ -308,6 +310,9 @@ getCoreToDo dflags
runWhen strictness demand_analyser,
+ runWhen exitification CoreDoExitify,
+ -- See note [Placement of the exitification pass]
+
runWhen full_laziness $
CoreDoFloatOutwards FloatOutSwitches {
floatOutLambdas = floatLamArgs dflags,
@@ -476,6 +481,9 @@ doCorePass CoreDoStaticArgs = {-# SCC "StaticArgs" #-}
doCorePass CoreDoCallArity = {-# SCC "CallArity" #-}
doPassD callArityAnalProgram
+doCorePass CoreDoExitify = {-# SCC "Exitify" #-}
+ doPass exitifyProgram
+
doCorePass CoreDoStrictness = {-# SCC "NewStranal" #-}
doPassDFM dmdAnalProgram
diff --git a/compiler/simplCore/SimplUtils.hs b/compiler/simplCore/SimplUtils.hs
index ebdda8f62a..9420081d84 100644
--- a/compiler/simplCore/SimplUtils.hs
+++ b/compiler/simplCore/SimplUtils.hs
@@ -1090,6 +1090,7 @@ preInlineUnconditionally env top_lvl bndr rhs
| isStableUnfolding (idUnfolding bndr) = False -- Note [Stable unfoldings and preInlineUnconditionally]
| isTopLevel top_lvl && isBottomingId bndr = False -- Note [Top-level bottoming Ids]
| isCoVar bndr = False -- Note [Do not inline CoVars unconditionally]
+ | isExitJoinId bndr = False
| otherwise = case idOccInfo bndr of
IAmDead -> True -- Happens in ((\x.1) v)
occ@OneOcc { occ_one_br = True }
diff --git a/compiler/simplCore/Simplify.hs b/compiler/simplCore/Simplify.hs
index adcd017454..1e1b6ee27e 100644
--- a/compiler/simplCore/Simplify.hs
+++ b/compiler/simplCore/Simplify.hs
@@ -51,6 +51,7 @@ import Util
import ErrUtils
import Module ( moduleName, pprModuleName )
+
{-
The guts of the simplifier is in this module, but the driver loop for
the simplifier is in SimplCore.hs.
@@ -3235,6 +3236,8 @@ simplLetUnfolding :: SimplEnv-> TopLevelFlag
simplLetUnfolding env top_lvl cont_mb id new_rhs unf
| isStableUnfolding unf
= simplStableUnfolding env top_lvl cont_mb id unf
+ | isExitJoinId id
+ = return noUnfolding -- see Note [Do not inline exit join points]
| otherwise
= mkLetUnfolding (seDynFlags env) top_lvl InlineRhs id new_rhs