summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSimon Peyton Jones <simonpj@microsoft.com>2022-08-29 10:34:09 +0100
committerMarge Bot <ben+marge-bot@smart-cactus.org>2022-08-31 03:53:54 -0400
commit7f490b1333c17ed27b213d6af8c7275aa9b3de63 (patch)
treefedd410442948ea663f55f7df130e4480cb8a868
parent3a00263248e6176e06f03b7390fade48a9adb373 (diff)
downloadhaskell-7f490b1333c17ed27b213d6af8c7275aa9b3de63.tar.gz
Add a missing trimArityType
This buglet was exposed by #22114, a consequence of my earlier refactoring of arity for join points.
-rw-r--r--compiler/GHC/Core/Opt/Arity.hs156
-rw-r--r--compiler/GHC/Core/Opt/Simplify/Utils.hs156
-rw-r--r--testsuite/tests/simplCore/should_compile/T22114.hs17
-rw-r--r--testsuite/tests/simplCore/should_compile/all.T1
4 files changed, 184 insertions, 146 deletions
diff --git a/compiler/GHC/Core/Opt/Arity.hs b/compiler/GHC/Core/Opt/Arity.hs
index dc4ffbdc7d..b03fe84b14 100644
--- a/compiler/GHC/Core/Opt/Arity.hs
+++ b/compiler/GHC/Core/Opt/Arity.hs
@@ -873,24 +873,49 @@ exprEtaExpandArity opts e
* *
********************************************************************* -}
-findRhsArity :: ArityOpts -> RecFlag -> Id -> CoreExpr -> Arity -> SafeArityType
+findRhsArity :: ArityOpts -> RecFlag -> Id -> CoreExpr
+ -> (Bool, SafeArityType)
-- This implements the fixpoint loop for arity analysis
-- See Note [Arity analysis]
--- If findRhsArity e = (n, is_bot) then
--- (a) any application of e to <n arguments will not do much work,
--- so it is safe to expand e ==> (\x1..xn. e x1 .. xn)
--- (b) if is_bot=True, then e applied to n args is guaranteed bottom
--
--- Returns an ArityType that is guaranteed trimmed to typeArity of 'bndr'
+-- The Bool is True if the returned arity is greater than (exprArity rhs)
+-- so the caller should do eta-expansion
+-- That Bool is never True for join points, which are never eta-expanded
+--
+-- Returns an SafeArityType that is guaranteed trimmed to typeArity of 'bndr'
-- See Note [Arity trimming]
-findRhsArity opts is_rec bndr rhs old_arity
- = case is_rec of
- Recursive -> go 0 botArityType
- NonRecursive -> step init_env
+
+findRhsArity opts is_rec bndr rhs
+ | isJoinId bndr
+ = (False, join_arity_type)
+ -- False: see Note [Do not eta-expand join points]
+ -- But do return the correct arity and bottom-ness, because
+ -- these are used to set the bndr's IdInfo (#15517)
+ -- Note [Invariants on join points] invariant 2b, in GHC.Core
+
+ | otherwise
+ = (arity_increased, non_join_arity_type)
+ -- arity_increased: eta-expand if we'll get more lambdas
+ -- to the top of the RHS
where
+ old_arity = exprArity rhs
+
init_env :: ArityEnv
init_env = findRhsArityEnv opts (isJoinId bndr)
+ -- Non-join-points only
+ non_join_arity_type = case is_rec of
+ Recursive -> go 0 botArityType
+ NonRecursive -> step init_env
+ arity_increased = arityTypeArity non_join_arity_type > old_arity
+
+ -- Join-points only
+ -- See Note [Arity for non-recursive join bindings]
+ -- and Note [Arity for recursive join bindings]
+ join_arity_type = case is_rec of
+ Recursive -> go 0 botArityType
+ NonRecursive -> trimArityType ty_arity (cheapArityType rhs)
+
ty_arity = typeArity (idType bndr)
id_one_shots = idDemandOneShots bndr
@@ -1076,6 +1101,117 @@ But /only/ for called-once demands. Suppose we had
Now we don't want to eta-expand f1 to have 3 args; only two.
Nor, in the case of f2, do we want to push that error call under
a lambda. Hence the takeWhile in combineWithDemandDoneShots.
+
+Note [Do not eta-expand join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Similarly to CPR (see Note [Don't w/w join points for CPR] in
+GHC.Core.Opt.WorkWrap), a join point stands well to gain from its outer binding's
+eta-expansion, and eta-expanding a join point is fraught with issues like how to
+deal with a cast:
+
+ let join $j1 :: IO ()
+ $j1 = ...
+ $j2 :: Int -> IO ()
+ $j2 n = if n > 0 then $j1
+ else ...
+
+ =>
+
+ let join $j1 :: IO ()
+ $j1 = (\eta -> ...)
+ `cast` N:IO :: State# RealWorld -> (# State# RealWorld, ())
+ ~ IO ()
+ $j2 :: Int -> IO ()
+ $j2 n = (\eta -> if n > 0 then $j1
+ else ...)
+ `cast` N:IO :: State# RealWorld -> (# State# RealWorld, ())
+ ~ IO ()
+
+The cast here can't be pushed inside the lambda (since it's not casting to a
+function type), so the lambda has to stay, but it can't because it contains a
+reference to a join point. In fact, $j2 can't be eta-expanded at all. Rather
+than try and detect this situation (and whatever other situations crop up!), we
+don't bother; again, any surrounding eta-expansion will improve these join
+points anyway, since an outer cast can *always* be pushed inside. By the time
+CorePrep comes around, the code is very likely to look more like this:
+
+ let join $j1 :: State# RealWorld -> (# State# RealWorld, ())
+ $j1 = (...) eta
+ $j2 :: Int -> State# RealWorld -> (# State# RealWorld, ())
+ $j2 = if n > 0 then $j1
+ else (...) eta
+
+Note [Arity for recursive join bindings]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Consider
+ f x = joinrec j 0 = \ a b c -> (a,x,b)
+ j n = j (n-1)
+ in j 20
+
+Obviously `f` should get arity 4. But it's a bit tricky:
+
+1. Remember, we don't eta-expand join points; see
+ Note [Do not eta-expand join points].
+
+2. But even though we aren't going to eta-expand it, we still want `j` to get
+ idArity=4, via the findRhsArity fixpoint. Then when we are doing findRhsArity
+ for `f`, we'll call arityType on f's RHS:
+ - At the letrec-binding for `j` we'll whiz up an arity-4 ArityType
+ for `j` (See Note [arityType for non-recursive let-bindings]
+ in GHC.Core.Opt.Arity)b
+ - At the occurrence (j 20) that arity-4 ArityType will leave an arity-3
+ result.
+
+3. All this, even though j's /join-arity/ (stored in the JoinId) is 1.
+ This is is the Main Reason that we want the idArity to sometimes be
+ larger than the join-arity c.f. Note [Invariants on join points] item 2b
+ in GHC.Core.
+
+4. Be very careful of things like this (#21755):
+ g x = let j 0 = \y -> (x,y)
+ j n = expensive n `seq` j (n-1)
+ in j x
+ Here we do /not/ want eta-expand `g`, lest we duplicate all those
+ (expensive n) calls.
+
+ But it's fine: the findRhsArity fixpoint calculation will compute arity-1
+ for `j` (not arity 2); and that's just what we want. But we do need that
+ fixpoint.
+
+ Historical note: an earlier version of GHC did a hack in which we gave
+ join points an ArityType of ABot, but that did not work with this #21755
+ case.
+
+5. arityType does not usually expect to encounter free join points;
+ see GHC.Core.Opt.Arity Note [No free join points in arityType].
+ But consider
+ f x = join j1 y = .... in
+ joinrec j2 z = ...j1 y... in
+ j2 v
+
+ When doing findRhsArity on `j2` we'll encounter the free `j1`.
+ But that is fine, because we aren't going to eta-expand `j2`;
+ we just want to know its arity. So we have a flag am_no_eta,
+ switched on when doing findRhsArity on a join point RHS. If
+ the flag is on, we allow free join points, but not otherwise.
+
+
+Note [Arity for non-recursive join bindings]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Note [Arity for recursive join bindings] deals with recursive join
+bindings. But what about /non-recursive/ones? If we just call
+findRhsArity, it will call arityType. And that can be expensive when
+we have deeply nested join points:
+ join j1 x1 = join j2 x2 = join j3 x3 = blah3
+ in blah2
+ in blah1
+(e.g. test T18698b).
+
+So we call cheapArityType instead. It's good enough for practical
+purposes.
+
+(Side note: maybe we should use cheapArity for the RHS of let bindings
+in the main arityType function.)
-}
diff --git a/compiler/GHC/Core/Opt/Simplify/Utils.hs b/compiler/GHC/Core/Opt/Simplify/Utils.hs
index a6deec63cf..433c67b35a 100644
--- a/compiler/GHC/Core/Opt/Simplify/Utils.hs
+++ b/compiler/GHC/Core/Opt/Simplify/Utils.hs
@@ -102,6 +102,14 @@ bindContextLevel :: BindContext -> TopLevelFlag
bindContextLevel (BC_Let top_lvl _) = top_lvl
bindContextLevel (BC_Join {}) = NotTopLevel
+bindContextRec :: BindContext -> RecFlag
+bindContextRec (BC_Let _ rec_flag) = rec_flag
+bindContextRec (BC_Join rec_flag _) = rec_flag
+
+isJoinBC :: BindContext -> Bool
+isJoinBC (BC_Let {}) = False
+isJoinBC (BC_Join {}) = True
+
{- *********************************************************************
* *
@@ -1776,39 +1784,26 @@ Wrinkles
tryEtaExpandRhs :: SimplEnv -> BindContext -> OutId -> OutExpr
-> SimplM (ArityType, OutExpr)
-- See Note [Eta-expanding at let bindings]
--- If tryEtaExpandRhs rhs = (n, is_bot, rhs') then
--- (a) rhs' has manifest arity n
--- (b) if is_bot is True then rhs' applied to n args is guaranteed bottom
-tryEtaExpandRhs env (BC_Join is_rec _) bndr rhs
- = assertPpr (isJoinId bndr) (ppr bndr) $
- return (arity_type, rhs)
- -- Note [Do not eta-expand join points]
- -- But do return the correct arity and bottom-ness, because
- -- these are used to set the bndr's IdInfo (#15517)
- -- Note [Invariants on join points] invariant 2b, in GHC.Core
- where
- -- See Note [Arity for non-recursive join bindings]
- -- and Note [Arity for recursive join bindings]
- arity_type = case is_rec of
- NonRecursive -> cheapArityType rhs
- Recursive -> findRhsArity (seArityOpts env) Recursive
- bndr rhs (exprArity rhs)
-
-tryEtaExpandRhs env (BC_Let _ is_rec) bndr rhs
- | seEtaExpand env -- Provided eta-expansion is on
- , new_arity > old_arity -- And the current manifest arity isn't enough
+tryEtaExpandRhs env bind_cxt bndr rhs
+ | do_eta_expand -- If the current manifest arity isn't enough
+ -- (never true for join points)
+ , seEtaExpand env -- and eta-expansion is on
, wantEtaExpansion rhs
- = do { tick (EtaExpansion bndr)
+ = -- Do eta-expansion.
+ assertPpr( not (isJoinBC bind_cxt) ) (ppr bndr) $
+ -- assert: this never happens for join points; see GHC.Core.Opt.Arity
+ -- Note [Do not eta-expand join points]
+ do { tick (EtaExpansion bndr)
; return (arity_type, etaExpandAT in_scope arity_type rhs) }
| otherwise
= return (arity_type, rhs)
+
where
in_scope = getInScope env
- old_arity = exprArity rhs
arity_opts = seArityOpts env
- arity_type = findRhsArity arity_opts is_rec bndr rhs old_arity
- new_arity = arityTypeArity arity_type
+ is_rec = bindContextRec bind_cxt
+ (do_eta_expand, arity_type) = findRhsArity arity_opts is_rec bndr rhs
wantEtaExpansion :: CoreExpr -> Bool
-- Mostly True; but False of PAPs which will immediately eta-reduce again
@@ -1894,117 +1889,6 @@ But note that this won't eta-expand, say
Does it matter not eta-expanding such functions? I'm not sure. Perhaps
strictness analysis will have less to bite on?
-Note [Do not eta-expand join points]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-Similarly to CPR (see Note [Don't w/w join points for CPR] in
-GHC.Core.Opt.WorkWrap), a join point stands well to gain from its outer binding's
-eta-expansion, and eta-expanding a join point is fraught with issues like how to
-deal with a cast:
-
- let join $j1 :: IO ()
- $j1 = ...
- $j2 :: Int -> IO ()
- $j2 n = if n > 0 then $j1
- else ...
-
- =>
-
- let join $j1 :: IO ()
- $j1 = (\eta -> ...)
- `cast` N:IO :: State# RealWorld -> (# State# RealWorld, ())
- ~ IO ()
- $j2 :: Int -> IO ()
- $j2 n = (\eta -> if n > 0 then $j1
- else ...)
- `cast` N:IO :: State# RealWorld -> (# State# RealWorld, ())
- ~ IO ()
-
-The cast here can't be pushed inside the lambda (since it's not casting to a
-function type), so the lambda has to stay, but it can't because it contains a
-reference to a join point. In fact, $j2 can't be eta-expanded at all. Rather
-than try and detect this situation (and whatever other situations crop up!), we
-don't bother; again, any surrounding eta-expansion will improve these join
-points anyway, since an outer cast can *always* be pushed inside. By the time
-CorePrep comes around, the code is very likely to look more like this:
-
- let join $j1 :: State# RealWorld -> (# State# RealWorld, ())
- $j1 = (...) eta
- $j2 :: Int -> State# RealWorld -> (# State# RealWorld, ())
- $j2 = if n > 0 then $j1
- else (...) eta
-
-Note [Arity for recursive join bindings]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-Consider
- f x = joinrec j 0 = \ a b c -> (a,x,b)
- j n = j (n-1)
- in j 20
-
-Obviously `f` should get arity 4. But it's a bit tricky:
-
-1. Remember, we don't eta-expand join points; see
- Note [Do not eta-expand join points].
-
-2. But even though we aren't going to eta-expand it, we still want `j` to get
- idArity=4, via the findRhsArity fixpoint. Then when we are doing findRhsArity
- for `f`, we'll call arityType on f's RHS:
- - At the letrec-binding for `j` we'll whiz up an arity-4 ArityType
- for `j` (See Note [arityType for non-recursive let-bindings]
- in GHC.Core.Opt.Arity)b
- - At the occurrence (j 20) that arity-4 ArityType will leave an arity-3
- result.
-
-3. All this, even though j's /join-arity/ (stored in the JoinId) is 1.
- This is is the Main Reason that we want the idArity to sometimes be
- larger than the join-arity c.f. Note [Invariants on join points] item 2b
- in GHC.Core.
-
-4. Be very careful of things like this (#21755):
- g x = let j 0 = \y -> (x,y)
- j n = expensive n `seq` j (n-1)
- in j x
- Here we do /not/ want eta-expand `g`, lest we duplicate all those
- (expensive n) calls.
-
- But it's fine: the findRhsArity fixpoint calculation will compute arity-1
- for `j` (not arity 2); and that's just what we want. But we do need that
- fixpoint.
-
- Historical note: an earlier version of GHC did a hack in which we gave
- join points an ArityType of ABot, but that did not work with this #21755
- case.
-
-5. arityType does not usually expect to encounter free join points;
- see GHC.Core.Opt.Arity Note [No free join points in arityType].
- But consider
- f x = join j1 y = .... in
- joinrec j2 z = ...j1 y... in
- j2 v
-
- When doing findRhsArity on `j2` we'll encounter the free `j1`.
- But that is fine, because we aren't going to eta-expand `j2`;
- we just want to know its arity. So we have a flag am_no_eta,
- switched on when doing findRhsArity on a join point RHS. If
- the flag is on, we allow free join points, but not otherwise.
-
-
-Note [Arity for non-recursive join bindings]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-Note [Arity for recursive join bindings] deals with recursive join
-bindings. But what about /non-recursive/ones? If we just call
-findRhsArity, it will call arityType. And that can be expensive when
-we have deeply nested join points:
- join j1 x1 = join j2 x2 = join j3 x3 = blah3
- in blah2
- in blah1
-(e.g. test T18698b).
-
-So we call cheapArityType instead. It's good enough for practical
-purposes.
-
-(Side note: maybe we should use cheapArity for the RHS of let bindings
-in the main arityType function.)
-
************************************************************************
* *
diff --git a/testsuite/tests/simplCore/should_compile/T22114.hs b/testsuite/tests/simplCore/should_compile/T22114.hs
new file mode 100644
index 0000000000..38a481d01c
--- /dev/null
+++ b/testsuite/tests/simplCore/should_compile/T22114.hs
@@ -0,0 +1,17 @@
+{-# LANGUAGE Haskell2010 #-}
+{-# LANGUAGE TypeFamilies #-}
+
+module T22114 where
+
+import Data.Kind (Type)
+
+value :: [Int] -> () -> Maybe Bool
+value = valu
+ where valu [0] = valuN
+ valu _ = \_ -> Nothing
+
+type family T :: Type where
+ T = () -> Maybe Bool
+
+valuN :: T
+valuN = valuN
diff --git a/testsuite/tests/simplCore/should_compile/all.T b/testsuite/tests/simplCore/should_compile/all.T
index c1a32a7248..1335c2c242 100644
--- a/testsuite/tests/simplCore/should_compile/all.T
+++ b/testsuite/tests/simplCore/should_compile/all.T
@@ -428,3 +428,4 @@ test('T21948', [grep_errmsg(r'^ Arity=5') ], compile, ['-O -ddump-simpl'])
test('T21763', only_ways(['optasm']), compile, ['-O2 -ddump-rules'])
test('T21763a', only_ways(['optasm']), compile, ['-O2 -ddump-rules'])
test('T22028', normal, compile, ['-O -ddump-rule-firings'])
+test('T22114', normal, compile, ['-O'])