summaryrefslogtreecommitdiff
path: root/compiler/coreSyn/CoreArity.hs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/coreSyn/CoreArity.hs')
-rw-r--r--compiler/coreSyn/CoreArity.hs82
1 files changed, 46 insertions, 36 deletions
diff --git a/compiler/coreSyn/CoreArity.hs b/compiler/coreSyn/CoreArity.hs
index 3f429d1ad2..d15da87aac 100644
--- a/compiler/coreSyn/CoreArity.hs
+++ b/compiler/coreSyn/CoreArity.hs
@@ -18,6 +18,8 @@ module CoreArity (
#include "HsVersions.h"
+import GhcPrelude
+
import CoreSyn
import CoreFVs
import CoreUtils
@@ -521,61 +523,60 @@ mk_cheap_fn dflags cheap_app
----------------------
-findRhsArity :: DynFlags -> Id -> CoreExpr -> Arity -> Arity
+findRhsArity :: DynFlags -> Id -> CoreExpr -> Arity -> (Arity, Bool)
-- This implements the fixpoint loop for arity analysis
-- See Note [Arity analysis]
+-- If findRhsArity e = (n, is_bot) then
+-- (a) any application of e to <n arguments will not do much work,
+-- so it is safe to expand e ==> (\x1..xn. e x1 .. xn)
+-- (b) if is_bot=True, then e applied to n args is guaranteed bottom
findRhsArity dflags bndr rhs old_arity
- = go (rhsEtaExpandArity dflags init_cheap_app rhs)
+ = go (get_arity init_cheap_app)
-- We always call exprEtaExpandArity once, but usually
-- that produces a result equal to old_arity, and then
-- we stop right away (since arities should not decrease)
-- Result: the common case is that there is just one iteration
where
+ is_lam = has_lam rhs
+
+ has_lam (Tick _ e) = has_lam e
+ has_lam (Lam b e) = isId b || has_lam e
+ has_lam _ = False
+
init_cheap_app :: CheapAppFun
init_cheap_app fn n_val_args
| fn == bndr = True -- On the first pass, this binder gets infinite arity
| otherwise = isCheapApp fn n_val_args
- go :: Arity -> Arity
- go cur_arity
- | cur_arity <= old_arity = cur_arity
- | new_arity == cur_arity = cur_arity
+ go :: (Arity, Bool) -> (Arity, Bool)
+ go cur_info@(cur_arity, _)
+ | cur_arity <= old_arity = cur_info
+ | new_arity == cur_arity = cur_info
| otherwise = ASSERT( new_arity < cur_arity )
#if defined(DEBUG)
pprTrace "Exciting arity"
(vcat [ ppr bndr <+> ppr cur_arity <+> ppr new_arity
- , ppr rhs])
+ , ppr rhs])
#endif
- go new_arity
+ go new_info
where
- new_arity = rhsEtaExpandArity dflags cheap_app rhs
+ new_info@(new_arity, _) = get_arity cheap_app
cheap_app :: CheapAppFun
cheap_app fn n_val_args
| fn == bndr = n_val_args < cur_arity
| otherwise = isCheapApp fn n_val_args
--- ^ The Arity returned is the number of value args the
--- expression can be applied to without doing much work
-rhsEtaExpandArity :: DynFlags -> CheapAppFun -> CoreExpr -> Arity
--- exprEtaExpandArity is used when eta expanding
--- e ==> \xy -> e x y
-rhsEtaExpandArity dflags cheap_app e
- = case (arityType env e) of
- ATop (os:oss)
- | isOneShotInfo os || has_lam e -> 1 + length oss
- -- Don't expand PAPs/thunks
- -- Note [Eta expanding thunks]
- | otherwise -> 0
- ATop [] -> 0
- ABot n -> n
- where
- env = AE { ae_cheap_fn = mk_cheap_fn dflags cheap_app
- , ae_ped_bot = gopt Opt_PedanticBottoms dflags }
-
- has_lam (Tick _ e) = has_lam e
- has_lam (Lam b e) = isId b || has_lam e
- has_lam _ = False
+ get_arity :: CheapAppFun -> (Arity, Bool)
+ get_arity cheap_app
+ = case (arityType env rhs) of
+ ABot n -> (n, True)
+ ATop (os:oss) | isOneShotInfo os || is_lam
+ -> (1 + length oss, False) -- Don't expand PAPs/thunks
+ ATop _ -> (0, False) -- Note [Eta expanding thunks]
+ where
+ env = AE { ae_cheap_fn = mk_cheap_fn dflags cheap_app
+ , ae_ped_bot = gopt Opt_PedanticBottoms dflags }
{-
Note [Arity analysis]
@@ -936,7 +937,7 @@ etaExpand n orig_expr
-- See Note [Eta expansion and source notes]
(expr', args) = collectArgs expr
(ticks, expr'') = stripTicksTop tickishFloatable expr'
- sexpr = foldl App expr'' args
+ sexpr = foldl' App expr'' args
retick expr = foldr mkTick expr ticks
-- Abstraction Application
@@ -1036,10 +1037,19 @@ mkEtaWW orig_n orig_expr in_scope orig_ty
| n == 0
= (getTCvInScope subst, reverse eis)
- | Just (tv,ty') <- splitForAllTy_maybe ty
- , let (subst', tv') = Type.substTyVarBndr subst tv
+ | Just (tcv,ty') <- splitForAllTy_maybe ty
+ , let (subst', tcv') = Type.substVarBndr subst tcv
+ = let ((n_subst, n_tcv), n_n)
+ -- We want to have at least 'n' lambdas at the top.
+ -- If tcv is a tyvar, it corresponds to one Lambda (/\).
+ -- And we won't reduce n.
+ -- If tcv is a covar, we could eta-expand the expr with one
+ -- lambda \co:ty. e co. In this case we generate a new variable
+ -- of the coercion type, update the scope, and reduce n by 1.
+ | isTyVar tcv = ((subst', tcv'), n)
+ | otherwise = (freshEtaId n subst' (varType tcv'), n-1)
-- Avoid free vars of the original expression
- = go n subst' ty' (EtaVar tv' : eis)
+ in go n_n n_subst ty' (EtaVar n_tcv : eis)
| Just (arg_ty, res_ty) <- splitFunTy_maybe ty
, not (isTypeLevPoly arg_ty)
@@ -1122,8 +1132,8 @@ etaBodyForJoinPoint need_args body
= (reverse rev_bs, e)
go n ty subst rev_bs e
| Just (tv, res_ty) <- splitForAllTy_maybe ty
- , let (subst', tv') = Type.substTyVarBndr subst tv
- = go (n-1) res_ty subst' (tv' : rev_bs) (e `App` Type (mkTyVarTy tv'))
+ , let (subst', tv') = Type.substVarBndr subst tv
+ = go (n-1) res_ty subst' (tv' : rev_bs) (e `App` varToCoreExpr tv')
| Just (arg_ty, res_ty) <- splitFunTy_maybe ty
, let (subst', b) = freshEtaId n subst arg_ty
= go (n-1) res_ty subst' (b : rev_bs) (e `App` Var b)