summaryrefslogtreecommitdiff
path: root/compiler/simplCore
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/simplCore')
-rw-r--r--compiler/simplCore/FloatIn.hs139
-rw-r--r--compiler/simplCore/SimplUtils.hs7
-rw-r--r--compiler/simplCore/Simplify.hs28
3 files changed, 135 insertions, 39 deletions
diff --git a/compiler/simplCore/FloatIn.hs b/compiler/simplCore/FloatIn.hs
index 2593b1d7a1..04e4d32f5e 100644
--- a/compiler/simplCore/FloatIn.hs
+++ b/compiler/simplCore/FloatIn.hs
@@ -26,8 +26,9 @@ import MkCore
import HscTypes ( ModGuts(..) )
import CoreUtils
import CoreFVs
+import CoreUnfold
import CoreMonad ( CoreM )
-import Id ( isOneShotBndr, idType, isJoinId, isJoinId_maybe )
+import Id
import Var
import Type
import VarSet
@@ -151,7 +152,7 @@ fiExpr dflags to_drop (_, AnnCast expr (co_ann, co))
Cast (fiExpr dflags e_drop expr) co
where
[drop_here, e_drop, co_drop]
- = sepBindsByDropPoint dflags False
+ = sepBindsByDropPoint dflags SepVanilla
[freeVarsOf expr, freeVarsOfAnn co_ann]
to_drop
@@ -173,7 +174,7 @@ fiExpr dflags to_drop ann_expr@(_,AnnApp {})
arg_fvs = map freeVarsOf ann_args
(drop_here : extra_drop : fun_drop : arg_drops)
- = sepBindsByDropPoint dflags False
+ = sepBindsByDropPoint dflags SepVanilla
(extra_fvs : fun_fvs : arg_fvs)
to_drop
-- Shortcut behaviour: if to_drop is empty,
@@ -446,7 +447,7 @@ fiExpr dflags to_drop (_, AnnCase scrut case_bndr _ [(con,alt_bndrs,rhs)])
scrut_fvs = freeVarsOf scrut
[shared_binds, scrut_binds, rhs_binds]
- = sepBindsByDropPoint dflags False
+ = sepBindsByDropPoint dflags SepVanilla
[scrut_fvs, rhs_fvs]
to_drop
@@ -456,16 +457,17 @@ fiExpr dflags to_drop (_, AnnCase scrut case_bndr ty alts)
Case (fiExpr dflags scrut_drops scrut) case_bndr ty
(zipWith fi_alt alts_drops_s alts)
where
- -- Float into the scrut and alts-considered-together just like App
+ -- Float into the scrut and alts-considered-together just like App
[drop_here1, scrut_drops, alts_drops]
- = sepBindsByDropPoint dflags False
+ = sepBindsByDropPoint dflags SepVanilla
[scrut_fvs, all_alts_fvs]
to_drop
- -- Float into the alts with the is_case flag set
+ -- Float into the alts with the SepCase context set
(drop_here2 : alts_drops_s)
| [ _ ] <- alts = [] : [alts_drops]
- | otherwise = sepBindsByDropPoint dflags True alts_fvs alts_drops
+ | otherwise = sepBindsByDropPoint dflags SepCase
+ alts_fvs alts_drops
scrut_fvs = freeVarsOf scrut
alts_fvs = map alt_fvs alts
@@ -491,7 +493,7 @@ fiBind dflags to_drop (AnnNonRec id ann_rhs@(rhs_fvs, rhs)) body_fvs
= ( extra_binds ++ shared_binds -- Land these before
-- See Note [extra_fvs (1,2)]
, FB (unitDVarSet id) rhs_fvs' -- The new binding itself
- (FloatLet (NonRec id rhs'))
+ (FloatLet (NonRec id rhs'))
, body_binds ) -- Land these after
where
@@ -508,7 +510,8 @@ fiBind dflags to_drop (AnnNonRec id ann_rhs@(rhs_fvs, rhs)) body_fvs
-- But do float into join points
[shared_binds, extra_binds, rhs_binds, body_binds]
- = sepBindsByDropPoint dflags False
+ = sepBindsByDropPoint dflags
+ (if isJoinId id then SepNonRecJoin else SepVanilla)
[extra_fvs, rhs_fvs, body_fvs2]
to_drop
@@ -533,7 +536,7 @@ fiBind dflags to_drop (AnnRec bindings) body_fvs
, noFloatIntoRhs Recursive bndr rhs ]
(shared_binds:extra_binds:body_binds:rhss_binds)
- = sepBindsByDropPoint dflags False
+ = sepBindsByDropPoint dflags SepVanilla
(extra_fvs:body_fvs:rhss_fvs)
to_drop
@@ -654,9 +657,41 @@ We have to maintain the order on these drop-point-related lists.
-- pprFIB :: FloatInBinds -> SDoc
-- pprFIB fibs = text "FIB:" <+> ppr [b | FB _ _ b <- fibs]
+data SepCtxt
+ = SepCase
+ | SepNonRecJoin
+ | SepVanilla
+
+{- Note [Floating join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+We push join-point bindings inwards merrily, just like let-bindings.
+They may get floated out again; e.g.
+ join j1 x = e1
+ in join j2 y = ...j1...
+ in ...
+==>
+ join j2 y = join { j1 x = e1 } in ...j1...
+ in ...
+
+Here we might float j1 out again. But we must float it in in case it
+allows an ordinary let-binding to go too. E.g.
+ let x = <thunk>
+ in join j1 x = e1
+ in join j2 y = ...j1...
+ in ...
+===>
+ join j2 y = let { x = <thunk }
+ in join { j1 x = e1 }
+ in ...j1...
+ in ...
+
+Ths is important; now the thunk for 'x' may not be allocated on the
+paths that don't involve j2.
+-}
+
sepBindsByDropPoint
:: DynFlags
- -> Bool -- True <=> is case expression
+ -> SepCtxt
-> [FreeVarSet] -- One set of FVs per drop point
-- Always at least two long!
-> FloatInBinds -- Candidate floaters
@@ -672,15 +707,15 @@ sepBindsByDropPoint
type DropBox = (FreeVarSet, FloatInBinds)
-sepBindsByDropPoint dflags is_case drop_pts floaters
+sepBindsByDropPoint dflags sep_ctxt drop_pts floaters
| null floaters -- Shortcut common case
= [] : [[] | _ <- drop_pts]
| otherwise
- = ASSERT( drop_pts `lengthAtLeast` 2 )
+ = ASSERT( n_alts >= 2 ) -- Invariant on caller
go floaters (map (\fvs -> (fvs, [])) (emptyDVarSet : drop_pts))
where
- n_alts = length drop_pts
+ n_alts = length drop_pts -- n_alts >= 2
go :: FloatInBinds -> [DropBox] -> [FloatInBinds]
-- The *first* one in the argument list is the drop_here set
@@ -697,18 +732,29 @@ sepBindsByDropPoint dflags is_case drop_pts floaters
(used_here : used_in_flags) = [ fvs `intersectsDVarSet` bndrs
| (fvs, _) <- drop_boxes]
- drop_here = used_here || cant_push
+ drop_here = used_here || not want_push
+ want_push = case sep_ctxt of
+ SepCase -> want_case_push
+ SepNonRecJoin -> want_join_push
+ SepVanilla -> want_let_push
n_used_alts = count id used_in_flags -- returns number of Trues in list.
- cant_push
- | is_case = n_used_alts == n_alts -- Used in all, don't push
+ no_duplication = n_used_alts <= 1 -- See Note [Duplicating floats]
+
+ duplicable_float = floatIsDupable dflags bind
+ -- True <=> duplication does not dup much code
+ -- (but it might still duplicate work!)
+ -- See Note [Duplicating floats]
+
+ want_case_push = -- n_used_alts < n_alts && -- Used in all alts, don't push
-- Remember n_alts > 1
- || (n_used_alts > 1 && not (floatIsDupable dflags bind))
- -- floatIsDupable: see Note [Duplicating floats]
+ (no_duplication || duplicable_float)
- | otherwise = floatIsCase bind || n_used_alts > 1
- -- floatIsCase: see Note [Floating primops]
+ want_let_push = not (floatIsCase bind) -- See Note [Floating primops]
+ && no_duplication
+
+ want_join_push = no_duplication -- See Note [Floating join points]
new_boxes | drop_here = (insert here_box : fork_boxes)
| otherwise = (here_box : new_fork_boxes)
@@ -727,18 +773,43 @@ sepBindsByDropPoint dflags is_case drop_pts floaters
{- Note [Duplicating floats]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+no_duplication is true if the binding us used in at most one
+alternative. (Zero is rare; it means the binding is dead.)
+
+If no_duplication is false, we may still float:
-For case expressions we duplicate the binding if it is reasonably
-small, and if it is not used in all the RHSs This is good for
-situations like
+* For /case expressions/ only (SepCase) we duplicate the binding if it
+ is reasonably small, and if it is not used in all the RHSs. This is
+ good for situations like
let x = I# y in
case e of
C -> error x
D -> error x
E -> ...not mentioning x...
-If the thing is used in all RHSs there is nothing gained,
-so we don't duplicate then.
+ If the thing is used in all RHSs there is nothing gained,
+ so we don't duplicate then.
+
+* This is NOT GOOD for other float-in places, like lets (SepVanilla).
+ Consider
+ let x = <small> in
+ let v = ...x...
+ in ...x...
+
+ We definitely don't want to duplicate x into the RHS of v and the
+ body! At least, it would be OK if <small> was a value; but we don't
+ test that.
+
+* For non-recursive join bindings (SepNonRecJoin) we must be equally
+ careful. Eg
+ let x = <small> in
+ join j = ...x...
+ in case f x of
+ A -> j
+ B -> something else
+ C -> j
+ Here we must not duplicate the let-x binding into the RHS of j
+ and the body, or we'll duplicate the redex.
-}
floatedBindsFVs :: FloatInBinds -> FreeVarSet
@@ -754,9 +825,19 @@ wrapFloats (FB _ _ fl : bs) e = wrapFloats bs (wrapFloat fl e)
floatIsDupable :: DynFlags -> FloatBind -> Bool
floatIsDupable dflags (FloatCase scrut _ _ _) = exprIsDupable dflags scrut
-floatIsDupable dflags (FloatLet (Rec prs)) = all (exprIsDupable dflags . snd) prs
-floatIsDupable dflags (FloatLet (NonRec _ r)) = exprIsDupable dflags r
+floatIsDupable dflags (FloatLet (Rec prs)) = -- all (exprIsDupable dflags . snd) prs
+ all (smallEnough dflags) prs
+floatIsDupable dflags (FloatLet (NonRec b r)) = -- exprIsDupable dflags r
+ smallEnough dflags (b,r)
+
+smallEnough :: DynFlags -> (Id,CoreExpr) -> Bool
+smallEnough dflags (_,rhs)
+ = couldBeSmallEnoughToInline dflags (ufUseThreshold dflags) rhs
floatIsCase :: FloatBind -> Bool
floatIsCase (FloatCase {}) = True
floatIsCase (FloatLet {}) = False
+
+--floatIsJoin :: FloatBind -> Bool
+--floatIsJoin (FloatCase {}) = False
+--floatIsJoin (FloatLet b) = isJoinBind b
diff --git a/compiler/simplCore/SimplUtils.hs b/compiler/simplCore/SimplUtils.hs
index ca1b9bd23d..82d20e20c5 100644
--- a/compiler/simplCore/SimplUtils.hs
+++ b/compiler/simplCore/SimplUtils.hs
@@ -1271,6 +1271,7 @@ postInlineUnconditionally env top_lvl bndr occ_info rhs
| exprIsTrivial rhs = True
| otherwise
= case occ_info of
+{-
-- The point of examining occ_info here is that for *non-values*
-- that occur outside a lambda, the call-site inliner won't have
-- a chance (because it doesn't know that the thing
@@ -1285,7 +1286,8 @@ postInlineUnconditionally env top_lvl bndr occ_info rhs
-- in allocation if you miss this out
OneOcc { occ_in_lam = in_lam, occ_int_cxt = int_cxt }
-- OneOcc => no code-duplication issue
- -> smallEnoughToInline dflags unfolding -- Small enough to dup
+ -> not (isJoinId bndr) -- NEW!
+ && smallEnoughToInline dflags unfolding -- Small enough to dup
-- ToDo: consider discount on smallEnoughToInline if int_cxt is true
--
-- NB: Do NOT inline arbitrarily big things, even if one_br is True
@@ -1310,6 +1312,7 @@ postInlineUnconditionally env top_lvl bndr occ_info rhs
-- int_cxt to prevent us inlining inside a lambda without some
-- good reason. See the notes on int_cxt in preInlineUnconditionally
+-}
IAmDead -> True -- This happens; for example, the case_bndr during case of
-- known constructor: case (a,b) of x { (p,q) -> ... }
-- Here x isn't mentioned in the RHS, so we don't want to
@@ -1331,7 +1334,7 @@ postInlineUnconditionally env top_lvl bndr occ_info rhs
where
unfolding = idUnfolding bndr
- dflags = seDynFlags env
+ _dflags = seDynFlags env
active = isActive (sm_phase (getMode env)) (idInlineActivation bndr)
-- See Note [pre/postInlineUnconditionally in gentle mode]
diff --git a/compiler/simplCore/Simplify.hs b/compiler/simplCore/Simplify.hs
index 872973925f..e3237bfcee 100644
--- a/compiler/simplCore/Simplify.hs
+++ b/compiler/simplCore/Simplify.hs
@@ -45,7 +45,7 @@ import BasicTypes ( TopLevelFlag(..), isNotTopLevel, isTopLevel,
RecFlag(..), Arity )
import MonadUtils ( mapAccumLM, liftIO )
import Var ( isTyCoVar )
-import Maybes ( orElse )
+import Maybes ( isJust, orElse )
import Control.Monad
import Outputable
import FastString
@@ -326,6 +326,7 @@ simplNonRecX :: SimplEnv
-- simplified, notably in knownCon. It uses case-binding where necessary.
--
-- Precondition: rhs satisfies the let/app invariant
+-- Not used for JoinIds
simplNonRecX env bndr new_rhs
| ASSERT2( not (isJoinId bndr), ppr bndr )
@@ -350,6 +351,7 @@ completeNonRecX :: TopLevelFlag -> SimplEnv
-> SimplM (SimplFloats, SimplEnv) -- The new binding is in the floats
-- Precondition: rhs satisfies the let/app invariant
-- See Note [CoreSyn let/app invariant] in CoreSyn
+-- Not used for JoinIds
completeNonRecX top_lvl env is_strict old_bndr new_bndr new_rhs
= ASSERT2( not (isJoinId new_bndr), ppr new_bndr )
@@ -549,7 +551,7 @@ makeTrivialWithInfo mode top_lvl occ_fs info expr
-- Now something very like completeBind,
-- but without the postInlineUnconditinoally part
; (arity, is_bot, expr2) <- tryEtaExpandRhs mode var expr1
- ; unf <- mkLetUnfolding (sm_dflags mode) top_lvl InlineRhs var expr2
+ ; unf <- simplVanillaUnfolding (sm_dflags mode) top_lvl InlineRhs var expr2
; let final_id = addLetBndrInfo var arity is_bot unf
bind = NonRec final_id expr2
@@ -3390,15 +3392,25 @@ simplLetUnfolding :: SimplEnv-> TopLevelFlag
simplLetUnfolding env top_lvl cont_mb id new_rhs rhs_ty unf
| isStableUnfolding unf
= simplStableUnfolding env top_lvl cont_mb id unf rhs_ty
- | isExitJoinId id
+
+ | isJust cont_mb -- A join point
+ = simplJoinUnfolding env id new_rhs
+
+ | otherwise
+ = simplVanillaUnfolding (seDynFlags env) top_lvl InlineRhs id new_rhs
+
+-------------------
+simplJoinUnfolding :: SimplEnv -> InId -> OutExpr -> SimplM Unfolding
+simplJoinUnfolding env join_id new_rhs
+ | isExitJoinId join_id
= return noUnfolding -- See Note [Do not inline exit join points] in Exitify
| otherwise
- = mkLetUnfolding (seDynFlags env) top_lvl InlineRhs id new_rhs
+ = return (mkJoinUnfolding (seDynFlags env) new_rhs)
-------------------
-mkLetUnfolding :: DynFlags -> TopLevelFlag -> UnfoldingSource
- -> InId -> OutExpr -> SimplM Unfolding
-mkLetUnfolding dflags top_lvl src id new_rhs
+simplVanillaUnfolding :: DynFlags -> TopLevelFlag -> UnfoldingSource
+ -> InId -> OutExpr -> SimplM Unfolding
+simplVanillaUnfolding dflags top_lvl src id new_rhs
= is_bottoming `seq` -- See Note [Force bottoming field]
return (mkUnfolding dflags src is_top_lvl is_bottoming new_rhs)
-- We make an unfolding *even for loop-breakers*.
@@ -3456,7 +3468,7 @@ simplStableUnfolding env top_lvl mb_cont id unf rhs_ty
-- See Note [Top-level flag on inline rules] in CoreUnfold
_other -- Happens for INLINABLE things
- -> mkLetUnfolding dflags top_lvl src id expr' }
+ -> simplVanillaUnfolding dflags top_lvl src id expr' }
-- If the guidance is UnfIfGoodArgs, this is an INLINABLE
-- unfolding, and we need to make sure the guidance is kept up
-- to date with respect to any changes in the unfolding.