summaryrefslogtreecommitdiff
path: root/compiler/GHC/Core/Opt/Arity.hs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/GHC/Core/Opt/Arity.hs')
-rw-r--r--compiler/GHC/Core/Opt/Arity.hs242
1 files changed, 136 insertions, 106 deletions
diff --git a/compiler/GHC/Core/Opt/Arity.hs b/compiler/GHC/Core/Opt/Arity.hs
index 2b2a7c20ea..7891012792 100644
--- a/compiler/GHC/Core/Opt/Arity.hs
+++ b/compiler/GHC/Core/Opt/Arity.hs
@@ -13,9 +13,12 @@
-- | Arity and eta expansion
module GHC.Core.Opt.Arity
( manifestArity, joinRhsArity, exprArity, typeArity
- , exprEtaExpandArity, findRhsArity, etaExpand
+ , exprEtaExpandArity, findRhsArity
+ , etaExpand, etaExpandAT
, etaExpandToJoinPoint, etaExpandToJoinPointRule
, exprBotStrictness_maybe
+ , ArityType(..), expandableArityType, arityTypeArity
+ , maxWithArity, isBotArityType, idArityType
)
where
@@ -42,7 +45,7 @@ import GHC.Types.Unique
import GHC.Driver.Session ( DynFlags, GeneralFlag(..), gopt )
import GHC.Utils.Outputable
import GHC.Data.FastString
-import GHC.Utils.Misc ( debugIsOn )
+import GHC.Utils.Misc ( lengthAtLeast )
{-
************************************************************************
@@ -486,8 +489,11 @@ Then f :: AT [False,False] ATop
-------------------- Main arity code ----------------------------
-}
--- See Note [ArityType]
-data ArityType = ATop [OneShotInfo] | ABot Arity
+
+data ArityType -- See Note [ArityType]
+ = ATop [OneShotInfo]
+ | ABot Arity
+ deriving( Eq )
-- There is always an explicit lambda
-- to justify the [OneShot], or the Arity
@@ -495,18 +501,45 @@ instance Outputable ArityType where
ppr (ATop os) = text "ATop" <> parens (ppr (length os))
ppr (ABot n) = text "ABot" <> parens (ppr n)
+arityTypeArity :: ArityType -> Arity
+-- The number of value args for the arity type
+arityTypeArity (ATop oss) = length oss
+arityTypeArity (ABot ar) = ar
+
+expandableArityType :: ArityType -> Bool
+-- True <=> eta-expansion will add at least one lambda
+expandableArityType (ATop oss) = not (null oss)
+expandableArityType (ABot ar) = ar /= 0
+
+isBotArityType :: ArityType -> Bool
+isBotArityType (ABot {}) = True
+isBotArityType (ATop {}) = False
+
+arityTypeOneShots :: ArityType -> [OneShotInfo]
+arityTypeOneShots (ATop oss) = oss
+arityTypeOneShots (ABot ar) = replicate ar OneShotLam
+ -- If we are diveging or throwing an exception anyway
+ -- it's fine to push redexes inside the lambdas
+
+botArityType :: ArityType
+botArityType = ABot 0 -- Unit for andArityType
+
+maxWithArity :: ArityType -> Arity -> ArityType
+maxWithArity at@(ABot {}) _ = at
+maxWithArity at@(ATop oss) ar
+ | oss `lengthAtLeast` ar = at
+ | otherwise = ATop (take ar (oss ++ repeat NoOneShotInfo))
+
vanillaArityType :: ArityType
vanillaArityType = ATop [] -- Totally uninformative
-- ^ The Arity returned is the number of value args the
-- expression can be applied to without doing much work
-exprEtaExpandArity :: DynFlags -> CoreExpr -> Arity
+exprEtaExpandArity :: DynFlags -> CoreExpr -> ArityType
-- exprEtaExpandArity is used when eta expanding
-- e ==> \xy -> e x y
exprEtaExpandArity dflags e
- = case (arityType env e) of
- ATop oss -> length oss
- ABot n -> n
+ = arityType env e
where
env = AE { ae_cheap_fn = mk_cheap_fn dflags isCheapApp
, ae_ped_bot = gopt Opt_PedanticBottoms dflags
@@ -529,7 +562,7 @@ mk_cheap_fn dflags cheap_app
----------------------
-findRhsArity :: DynFlags -> Id -> CoreExpr -> Arity -> (Arity, Bool)
+findRhsArity :: DynFlags -> Id -> CoreExpr -> Arity -> ArityType
-- This implements the fixpoint loop for arity analysis
-- See Note [Arity analysis]
-- If findRhsArity e = (n, is_bot) then
@@ -543,44 +576,34 @@ findRhsArity dflags bndr rhs old_arity
-- we stop right away (since arities should not decrease)
-- Result: the common case is that there is just one iteration
where
- is_lam = has_lam rhs
-
- has_lam (Tick _ e) = has_lam e
- has_lam (Lam b e) = isId b || has_lam e
- has_lam _ = False
-
init_cheap_app :: CheapAppFun
init_cheap_app fn n_val_args
| fn == bndr = True -- On the first pass, this binder gets infinite arity
| otherwise = isCheapApp fn n_val_args
- go :: (Arity, Bool) -> (Arity, Bool)
- go cur_info@(cur_arity, _)
- | cur_arity <= old_arity = cur_info
- | new_arity == cur_arity = cur_info
- | otherwise = ASSERT( new_arity < cur_arity )
+ go :: ArityType -> ArityType
+ go cur_atype
+ | cur_arity <= old_arity = cur_atype
+ | new_atype == cur_atype = cur_atype
+ | otherwise =
#if defined(DEBUG)
pprTrace "Exciting arity"
- (vcat [ ppr bndr <+> ppr cur_arity <+> ppr new_arity
+ (vcat [ ppr bndr <+> ppr cur_atype <+> ppr new_atype
, ppr rhs])
#endif
- go new_info
+ go new_atype
where
- new_info@(new_arity, _) = get_arity cheap_app
+ new_atype = get_arity cheap_app
+ cur_arity = arityTypeArity cur_atype
cheap_app :: CheapAppFun
cheap_app fn n_val_args
| fn == bndr = n_val_args < cur_arity
| otherwise = isCheapApp fn n_val_args
- get_arity :: CheapAppFun -> (Arity, Bool)
- get_arity cheap_app
- = case (arityType env rhs) of
- ABot n -> (n, True)
- ATop (os:oss) | isOneShotInfo os || is_lam
- -> (1 + length oss, False) -- Don't expand PAPs/thunks
- ATop _ -> (0, False) -- Note [Eta expanding thunks]
- where
+ get_arity :: CheapAppFun -> ArityType
+ get_arity cheap_app = arityType env rhs
+ where
env = AE { ae_cheap_fn = mk_cheap_fn dflags cheap_app
, ae_ped_bot = gopt Opt_PedanticBottoms dflags
, ae_joins = emptyVarSet }
@@ -613,7 +636,6 @@ write the analysis loop.
The analysis is cheap-and-cheerful because it doesn't deal with
mutual recursion. But the self-recursive case is the important one.
-
Note [Eta expanding through dictionaries]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
If the experimental -fdicts-cheap flag is on, we eta-expand through
@@ -632,24 +654,6 @@ The (foo DInt) is floated out, and makes ineffective a RULE
One could go further and make exprIsCheap reply True to any
dictionary-typed expression, but that's more work.
-
-Note [Eta expanding thunks]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~
-We don't eta-expand
- * Trivial RHSs x = y
- * PAPs x = map g
- * Thunks f = case y of p -> \x -> blah
-
-When we see
- f = case y of p -> \x -> blah
-should we eta-expand it? Well, if 'x' is a one-shot state token
-then 'yes' because 'f' will only be applied once. But otherwise
-we (conservatively) say no. My main reason is to avoid expanding
-PAPSs
- f = g d ==> f = \x. g d x
-because that might in turn make g inline (if it has an inline pragma),
-which we might not want. After all, INLINE pragmas say "inline only
-when saturated" so we don't want to be too gung-ho about saturating!
-}
arityLam :: Id -> ArityType -> ArityType
@@ -673,6 +677,7 @@ arityApp (ATop []) _ = ATop []
arityApp (ATop (_:as)) cheap = floatIn cheap (ATop as)
andArityType :: ArityType -> ArityType -> ArityType -- Used for branches of a 'case'
+-- This is least upper bound in the ArityType lattice
andArityType (ABot n1) (ABot n2) = ABot (n1 `max` n2) -- Note [ABot branches: use max]
andArityType (ATop as) (ABot _) = ATop as
andArityType (ABot _) (ATop bs) = ATop bs
@@ -754,8 +759,7 @@ arityType :: ArityEnv -> CoreExpr -> ArityType
arityType env (Cast e co)
= case arityType env e of
- ATop os -> ATop (take co_arity os)
- -- See Note [Arity trimming]
+ ATop os -> ATop (take co_arity os) -- See Note [Arity trimming]
ABot n | co_arity < n -> ATop (replicate co_arity noOneShotInfo)
| otherwise -> ABot n
where
@@ -769,19 +773,9 @@ arityType env (Cast e co)
arityType env (Var v)
| v `elemVarSet` ae_joins env
- = ABot 0 -- See Note [Eta-expansion and join points]
-
- | strict_sig <- idStrictness v
- , not $ isTopSig strict_sig
- , (ds, res) <- splitStrictSig strict_sig
- , let arity = length ds
- = if isDeadEndDiv res then ABot arity
- else ATop (take arity one_shots)
+ = botArityType -- See Note [Eta-expansion and join points]
| otherwise
- = ATop (take (idArity v) one_shots)
- where
- one_shots :: [OneShotInfo] -- One-shot-ness derived from the type
- one_shots = typeArity (idType v)
+ = idArityType v
-- Lambdas; increase arity
arityType env (Lam x e)
@@ -804,13 +798,13 @@ arityType env (App fun arg )
--
arityType env (Case scrut _ _ alts)
| exprIsDeadEnd scrut || null alts
- = ABot 0 -- Do not eta expand
- -- See Note [Dealing with bottom (1)]
+ = botArityType -- Do not eta expand
+ -- See Note [Dealing with bottom (1)]
| otherwise
= case alts_type of
- ABot n | n>0 -> ATop [] -- Don't eta expand
- | otherwise -> ABot 0 -- if RHS is bottomming
- -- See Note [Dealing with bottom (2)]
+ ABot n | n>0 -> ATop [] -- Don't eta expand
+ | otherwise -> botArityType -- if RHS is bottomming
+ -- See Note [Dealing with bottom (2)]
ATop as | not (ae_ped_bot env) -- See Note [Dealing with bottom (3)]
, ae_cheap_fn env scrut Nothing -> ATop as
@@ -886,7 +880,8 @@ So we do this:
body of the let.
* Dually, when we come to a /call/ of a join point, just no-op
- by returning (ABot 0), the neutral element of ArityType.
+ by returning botArityType, the bottom element of ArityType,
+ which so that: bot `andArityType` x = x
* This works if the join point is bound in the expression we are
taking the arityType of. But if it's bound further out, it makes
@@ -905,6 +900,20 @@ An alternative (roughly equivalent) idea would be to carry an
environment mapping let-bound Ids to their ArityType.
-}
+idArityType :: Id -> ArityType
+idArityType v
+ | strict_sig <- idStrictness v
+ , not $ isTopSig strict_sig
+ , (ds, res) <- splitStrictSig strict_sig
+ , let arity = length ds
+ = if isDeadEndDiv res then ABot arity
+ else ATop (take arity one_shots)
+ | otherwise
+ = ATop (take (idArity v) one_shots)
+ where
+ one_shots :: [OneShotInfo] -- One-shot-ness derived from the type
+ one_shots = typeArity (idType v)
+
{-
%************************************************************************
%* *
@@ -1001,6 +1010,25 @@ which we want to lead to code like
This means that we need to look through type applications and be ready
to re-add floats on the top.
+Note [Eta expansion with ArityType]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+The etaExpandAT function takes an ArityType (not just an Arity) to
+guide eta-expansion. Why? Because we want to preserve one-shot info.
+Consider
+ foo = \x. case x of
+ True -> (\s{os}. blah) |> co
+ False -> wubble
+We'll get an ArityType for foo of (ATop [NoOneShot,OneShot]).
+
+Then we want to eta-expand to
+ foo = \x. (\eta{os}. (case x of ...as before...) eta) |> some_co
+
+That 'eta' binder is fresh, and we really want it to have the
+one-shot flag from the inner \s{osf}. By expanding with the
+ArityType gotten from analysing the RHS, we achieve this neatly.
+
+This makes a big difference to the one-shot monad trick;
+see Note [The one-shot state monad trick] in GHC.Core.Unify.
-}
-- | @etaExpand n e@ returns an expression with
@@ -1013,11 +1041,16 @@ to re-add floats on the top.
-- We should have that:
--
-- > ty = exprType e = exprType e'
-etaExpand :: Arity -- ^ Result should have this number of value args
- -> CoreExpr -- ^ Expression to expand
- -> CoreExpr
+etaExpand :: Arity -> CoreExpr -> CoreExpr
+etaExpandAT :: ArityType -> CoreExpr -> CoreExpr
+
+etaExpand n orig_expr = eta_expand (replicate n NoOneShotInfo) orig_expr
+etaExpandAT at orig_expr = eta_expand (arityTypeOneShots at) orig_expr
+ -- See Note [Eta expansion with ArityType]
+
-- etaExpand arity e = res
-- Then 'res' has at least 'arity' lambdas at the top
+-- See Note [Eta expansion with ArityType]
--
-- etaExpand deals with for-alls. For example:
-- etaExpand 1 E
@@ -1028,21 +1061,23 @@ etaExpand :: Arity -- ^ Result should have this number of value arg
-- It deals with coerces too, though they are now rare
-- so perhaps the extra code isn't worth it
-etaExpand n orig_expr
- = go n orig_expr
+eta_expand :: [OneShotInfo] -> CoreExpr -> CoreExpr
+eta_expand one_shots orig_expr
+ = go one_shots orig_expr
where
-- Strip off existing lambdas and casts before handing off to mkEtaWW
-- Note [Eta expansion and SCCs]
- go 0 expr = expr
- go n (Lam v body) | isTyVar v = Lam v (go n body)
- | otherwise = Lam v (go (n-1) body)
- go n (Cast expr co) = Cast (go n expr) co
- go n expr
+ go [] expr = expr
+ go oss@(_:oss1) (Lam v body) | isTyVar v = Lam v (go oss body)
+ | otherwise = Lam v (go oss1 body)
+ go oss (Cast expr co) = Cast (go oss expr) co
+
+ go oss expr
= -- pprTrace "ee" (vcat [ppr orig_expr, ppr expr, ppr etas]) $
retick $ etaInfoAbs etas (etaInfoApp subst' sexpr etas)
where
in_scope = mkInScopeSet (exprFreeVars expr)
- (in_scope', etas) = mkEtaWW n (ppr orig_expr) in_scope (exprType expr)
+ (in_scope', etas) = mkEtaWW oss (ppr orig_expr) in_scope (exprType expr)
subst' = mkEmptySubst in_scope'
-- Find ticks behind type apps.
@@ -1141,7 +1176,7 @@ etaInfoAppTy _ (EtaCo co : eis) = etaInfoAppTy (coercionRKind co) eis
-- semantically-irrelevant source annotations, so call sites must take care to
-- preserve that info. See Note [Eta expansion and SCCs].
mkEtaWW
- :: Arity
+ :: [OneShotInfo]
-- ^ How many value arguments to eta-expand
-> SDoc
-- ^ The pretty-printed original expression, for warnings.
@@ -1153,36 +1188,29 @@ mkEtaWW
-- The outgoing 'InScopeSet' extends the incoming 'InScopeSet' with the
-- fresh variables in 'EtaInfo'.
-mkEtaWW orig_n ppr_orig_expr in_scope orig_ty
- = go orig_n empty_subst orig_ty []
+mkEtaWW orig_oss ppr_orig_expr in_scope orig_ty
+ = go 0 orig_oss empty_subst orig_ty []
where
empty_subst = mkEmptyTCvSubst in_scope
- go :: Arity -- Number of value args to expand to
+ go :: Int -- For fresh names
+ -> [OneShotInfo] -- Number of value args to expand to
-> TCvSubst -> Type -- We are really looking at subst(ty)
-> [EtaInfo] -- Accumulating parameter
-> (InScopeSet, [EtaInfo])
- go n subst ty eis -- See Note [exprArity invariant]
-
+ go _ [] subst _ eis -- See Note [exprArity invariant]
----------- Done! No more expansion needed
- | n == 0
= (getTCvInScope subst, reverse eis)
+ go n oss@(one_shot:oss1) subst ty eis -- See Note [exprArity invariant]
----------- Forall types (forall a. ty)
| Just (tcv,ty') <- splitForAllTy_maybe ty
- , let (subst', tcv') = Type.substVarBndr subst tcv
- = let ((n_subst, n_tcv), n_n)
- -- We want to have at least 'n' lambdas at the top.
- -- If tcv is a tyvar, it corresponds to one Lambda (/\).
- -- And we won't reduce n.
- -- If tcv is a covar, we could eta-expand the expr with one
- -- lambda \co:ty. e co. In this case we generate a new variable
- -- of the coercion type, update the scope, and reduce n by 1.
- | isTyVar tcv = ((subst', tcv'), n)
- -- covar case:
- | otherwise = (freshEtaId n subst' (unrestricted (varType tcv')), n-1)
- -- Avoid free vars of the original expression
- in go n_n n_subst ty' (EtaVar n_tcv : eis)
+ , (subst', tcv') <- Type.substVarBndr subst tcv
+ , let oss' | isTyVar tcv = oss
+ | otherwise = oss1
+ -- A forall can bind a CoVar, in which case
+ -- we consume one of the [OneShotInfo]
+ = go n oss' subst' ty' (EtaVar tcv' : eis)
----------- Function types (t1 -> t2)
| Just (mult, arg_ty, res_ty) <- splitFunTy_maybe ty
@@ -1190,9 +1218,11 @@ mkEtaWW orig_n ppr_orig_expr in_scope orig_ty
-- See Note [Levity polymorphism invariants] in GHC.Core
-- See also test case typecheck/should_run/EtaExpandLevPoly
- , let (subst', eta_id') = freshEtaId n subst (Scaled mult arg_ty)
- -- Avoid free vars of the original expression
- = go (n-1) subst' res_ty (EtaVar eta_id' : eis)
+ , (subst', eta_id) <- freshEtaId n subst (Scaled mult arg_ty)
+ -- Avoid free vars of the original expression
+
+ , let eta_id' = eta_id `setIdOneShotInfo` one_shot
+ = go (n+1) oss1 subst' res_ty (EtaVar eta_id' : eis)
----------- Newtypes
-- Given this:
@@ -1206,12 +1236,12 @@ mkEtaWW orig_n ppr_orig_expr in_scope orig_ty
-- Remember to apply the substitution to co (#16979)
-- (or we could have applied to ty, but then
-- we'd have had to zap it for the recursive call)
- = go n subst ty' (pushCoercion co' eis)
+ = go n oss subst ty' (pushCoercion co' eis)
| otherwise -- We have an expression of arity > 0,
-- but its type isn't a function, or a binder
-- is levity-polymorphic
- = WARN( True, (ppr orig_n <+> ppr orig_ty) $$ ppr_orig_expr )
+ = WARN( True, (ppr orig_oss <+> ppr orig_ty) $$ ppr_orig_expr )
(getTCvInScope subst, reverse eis)
-- This *can* legitimately happen:
-- e.g. coerce Int (\x. x) Essentially the programmer is