summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--compiler/coreSyn/CoreUtils.lhs25
1 files changed, 16 insertions, 9 deletions
diff --git a/compiler/coreSyn/CoreUtils.lhs b/compiler/coreSyn/CoreUtils.lhs
index c872ac311e..7c0a2d406e 100644
--- a/compiler/coreSyn/CoreUtils.lhs
+++ b/compiler/coreSyn/CoreUtils.lhs
@@ -45,6 +45,7 @@ module CoreUtils (
import CoreSyn
import PprCore
+import CoreFVs( exprFreeVars )
import Var
import SrcLoc
import VarEnv
@@ -1529,6 +1530,11 @@ are going to avoid allocating this thing altogether.
There are some particularly delicate points here:
+* We want to eta-reduce if doing so leaves a trivial expression,
+ *including* a cast. For example
+ \x. f |> co --> f |> co
+ (provided co doesn't mention x)
+
* Eta reduction is not valid in general:
\x. bot /= bot
This matters, partly for old-fashioned correctness reasons but,
@@ -1545,7 +1551,7 @@ There are some particularly delicate points here:
Result: seg-fault because the boolean case actually gets a function value.
See Trac #1947.
- So it's important to to the right thing.
+ So it's important to do the right thing.
* Note [Arity care]: we need to be careful if we just look at f's
arity. Currently (Dec07), f's arity is visible in its own RHS (see
@@ -1616,7 +1622,11 @@ tryEtaReduce bndrs body
-- See Note [Eta reduction with casted arguments]
-- for why we have an accumulating coercion
go [] fun co
- | ok_fun fun = Just (mkCast fun co)
+ | ok_fun fun
+ , let result = mkCast fun co
+ , not (any (`elemVarSet` exprFreeVars result) bndrs)
+ = Just result -- Check for any of the binders free in the result
+ -- *including* the accumulated coercion
go (b : bs) (App fun arg) co
| Just co' <- ok_arg b arg co
@@ -1626,13 +1636,10 @@ tryEtaReduce bndrs body
---------------
-- Note [Eta reduction conditions]
- ok_fun (App fun (Type ty))
- | not (any (`elemVarSet` tyVarsOfType ty) bndrs)
- = ok_fun fun
- ok_fun (Var fun_id)
- = not (fun_id `elem` bndrs)
- && (ok_fun_id fun_id || all ok_lam bndrs)
- ok_fun _fun = False
+ ok_fun (App fun (Type {})) = ok_fun fun
+ ok_fun (Cast fun _) = ok_fun fun
+ ok_fun (Var fun_id) = ok_fun_id fun_id || all ok_lam bndrs
+ ok_fun _fun = False
---------------
ok_fun_id fun = fun_arity fun >= incoming_arity