summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--compiler/stranal/DmdAnal.hs82
-rw-r--r--testsuite/tests/simplCore/should_compile/T13543.hs17
-rw-r--r--testsuite/tests/simplCore/should_compile/T13543.stderr1
-rw-r--r--testsuite/tests/simplCore/should_compile/all.T1
4 files changed, 81 insertions, 20 deletions
diff --git a/compiler/stranal/DmdAnal.hs b/compiler/stranal/DmdAnal.hs
index 304a2becb3..78eefe39a1 100644
--- a/compiler/stranal/DmdAnal.hs
+++ b/compiler/stranal/DmdAnal.hs
@@ -64,20 +64,20 @@ dmdAnalProgram dflags fam_envs binds
dmdAnalTopBind :: AnalEnv
-> CoreBind
-> (AnalEnv, CoreBind)
-dmdAnalTopBind sigs (NonRec id rhs)
- = (extendAnalEnv TopLevel sigs id2 (idStrictness id2), NonRec id2 rhs2)
+dmdAnalTopBind env (NonRec id rhs)
+ = (extendAnalEnv TopLevel env id2 (idStrictness id2), NonRec id2 rhs2)
where
- ( _, _, rhs1) = dmdAnalRhsLetDown TopLevel Nothing sigs id rhs
- ( _, id2, rhs2) = dmdAnalRhsLetDown TopLevel Nothing (nonVirgin sigs) id rhs1
+ ( _, _, rhs1) = dmdAnalRhsLetDown TopLevel Nothing env cleanEvalDmd id rhs
+ ( _, id2, rhs2) = dmdAnalRhsLetDown TopLevel Nothing (nonVirgin env) cleanEvalDmd id rhs1
-- Do two passes to improve CPR information
-- See Note [CPR for thunks]
-- See Note [Optimistic CPR in the "virgin" case]
-- See Note [Initial CPR for strict binders]
-dmdAnalTopBind sigs (Rec pairs)
- = (sigs', Rec pairs')
+dmdAnalTopBind env (Rec pairs)
+ = (env', Rec pairs')
where
- (sigs', _, pairs') = dmdFix TopLevel sigs pairs
+ (env', _, pairs') = dmdFix TopLevel env cleanEvalDmd pairs
-- We get two iterations automatically
-- c.f. the NonRec case above
@@ -308,7 +308,7 @@ dmdAnal' env dmd (Let (NonRec id rhs) body)
dmdAnal' env dmd (Let (NonRec id rhs) body)
= (body_ty2, Let (NonRec id2 rhs') body')
where
- (lazy_fv, id1, rhs') = dmdAnalRhsLetDown NotTopLevel Nothing env id rhs
+ (lazy_fv, id1, rhs') = dmdAnalRhsLetDown NotTopLevel Nothing env dmd id rhs
env1 = extendAnalEnv NotTopLevel env id1 (idStrictness id1)
(body_ty, body') = dmdAnal env1 dmd body
(body_ty1, id2) = annotateBndr env body_ty id1
@@ -329,7 +329,7 @@ dmdAnal' env dmd (Let (NonRec id rhs) body)
dmdAnal' env dmd (Let (Rec pairs) body)
= let
- (env', lazy_fv, pairs') = dmdFix NotTopLevel env pairs
+ (env', lazy_fv, pairs') = dmdFix NotTopLevel env dmd pairs
(body_ty, body') = dmdAnal env' dmd body
body_ty1 = deleteFVs body_ty (map fst pairs)
body_ty2 = addLazyFVs body_ty1 lazy_fv -- see Note [Lazy and unleasheable free variables]
@@ -509,17 +509,17 @@ dmdTransform env var dmd
-- Recursive bindings
dmdFix :: TopLevelFlag
-> AnalEnv -- Does not include bindings for this binding
+ -> CleanDemand
-> [(Id,CoreExpr)]
-> (AnalEnv, DmdEnv, [(Id,CoreExpr)]) -- Binders annotated with stricness info
-dmdFix top_lvl env orig_pairs
+dmdFix top_lvl env let_dmd orig_pairs
= loop 1 initial_pairs
where
bndrs = map fst orig_pairs
-- See Note [Initialising strictness]
initial_pairs | ae_virgin env = [(setIdStrictness id botSig, rhs) | (id, rhs) <- orig_pairs ]
-
| otherwise = orig_pairs
-- If fixed-point iteration does not yield a result we use this instead
@@ -562,7 +562,7 @@ dmdFix top_lvl env orig_pairs
my_downRhs (env, lazy_fv) (id,rhs)
= ((env', lazy_fv'), (id', rhs'))
where
- (lazy_fv1, id', rhs') = dmdAnalRhsLetDown top_lvl (Just bndrs) env id rhs
+ (lazy_fv1, id', rhs') = dmdAnalRhsLetDown top_lvl (Just bndrs) env let_dmd id rhs
lazy_fv' = plusVarEnv_C bothDmd lazy_fv lazy_fv1
env' = extendAnalEnv top_lvl env id (idStrictness id')
@@ -621,18 +621,27 @@ dmdAnalTrivialRhs env id rhs fn
-- This is the LetDown rule in the paper “Higher-Order Cardinality Analysis”.
dmdAnalRhsLetDown :: TopLevelFlag
-> Maybe [Id] -- Just bs <=> recursive, Nothing <=> non-recursive
- -> AnalEnv -> Id -> CoreExpr
+ -> AnalEnv -> CleanDemand
+ -> Id -> CoreExpr
-> (DmdEnv, Id, CoreExpr)
-- Process the RHS of the binding, add the strictness signature
-- to the Id, and augment the environment with the signature as well.
-dmdAnalRhsLetDown top_lvl rec_flag env id rhs
+dmdAnalRhsLetDown top_lvl rec_flag env let_dmd id rhs
| Just fn <- unpackTrivial rhs -- See Note [Demand analysis for trivial right-hand sides]
= dmdAnalTrivialRhs env id rhs fn
| otherwise
= (lazy_fv, id', mkLams bndrs' body')
where
- (bndrs, body) = collectBinders rhs
+ (bndrs, body, body_dmd)
+ = case isJoinId_maybe id of
+ Just join_arity -- See Note [Demand analysis for join points]
+ | (bndrs, body) <- collectNBinders join_arity rhs
+ -> (bndrs, body, let_dmd)
+
+ Nothing | (bndrs, body) <- collectBinders rhs
+ -> (bndrs, body, mkBodyDmd env body)
+
env_body = foldl extendSigsWithLam env bndrs
(body_ty, body') = dmdAnal env_body body_dmd body
body_ty' = removeDmdTyArgs body_ty -- zap possible deep CPR info
@@ -642,10 +651,6 @@ dmdAnalRhsLetDown top_lvl rec_flag env id rhs
id' = set_idStrictness env id sig_ty
-- See Note [NOINLINE and strictness]
- -- See Note [Product demands for function body]
- body_dmd = case deepSplitProductType_maybe (ae_fam_envs env) (exprType body) of
- Nothing -> cleanEvalDmd
- Just (dc, _, _, _) -> cleanEvalProdDmd (dataConRepArity dc)
-- See Note [Aggregated demand for cardinality]
rhs_fv1 = case rec_flag of
@@ -667,6 +672,13 @@ dmdAnalRhsLetDown top_lvl rec_flag env id rhs
|| not (isStrictDmd (idDemandInfo id) || ae_virgin env)
-- See Note [Optimistic CPR in the "virgin" case]
+mkBodyDmd :: AnalEnv -> CoreExpr -> CleanDemand
+-- See Note [Product demands for function body]
+mkBodyDmd env body
+ = case deepSplitProductType_maybe (ae_fam_envs env) (exprType body) of
+ Nothing -> cleanEvalDmd
+ Just (dc, _, _, _) -> cleanEvalProdDmd (dataConRepArity dc)
+
unpackTrivial :: CoreExpr -> Maybe Id
-- Returns (Just v) if the arg is really equal to v, modulo
-- casts, type applications etc
@@ -691,7 +703,37 @@ useLetUp _ (Lam _ _) = False
useLetUp _ _ = True
-{-
+{- Note [Demand analysis for join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Consider
+ g :: (Int,Int) -> Int
+ g (p,q) = p+q
+
+ f :: T -> Int -> Int
+ f x p = g (join j y = (p,y)
+ in case x of
+ A -> j 3
+ B -> j 4
+ C -> (p,7))
+
+If j was a vanilla function definition, we'd analyse its body with
+evalDmd, and think that it was lazy in p. But for join points we can
+do better! We know that j's body will (if called at all) be evaluated
+with the demand that consumes the entire join-binding, in this case
+the argument demand from g. Whizzo! g evaluates both components of
+its argument pair, so p will certainly be evaluated if j is called.
+
+For f to be strict in p, we need /all/ paths to evaluate p; in this
+case the C branch does so too, so we are fine. So, as usual, we need
+to transport demands on free variables to the call site(s). Compare
+Note [Lazy and unleasheable free variables].
+
+The implementation is easy. When analysing a join point, we can
+analyse its body with the demand from the entire join-binding (written
+let_dmd here).
+
+Another win for join points! Trac #13543.
+
Note [Demand analysis for trivial right-hand sides]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Consider
diff --git a/testsuite/tests/simplCore/should_compile/T13543.hs b/testsuite/tests/simplCore/should_compile/T13543.hs
new file mode 100644
index 0000000000..88a0b142b0
--- /dev/null
+++ b/testsuite/tests/simplCore/should_compile/T13543.hs
@@ -0,0 +1,17 @@
+{-# LANGUAGE RankNTypes, GADTs #-}
+
+module Foo where
+
+g :: (Int, Int) -> Int
+{-# NOINLINE g #-}
+g (p,q) = p+q
+
+f :: Int -> Int -> Int -> Int
+f x p q
+ = g (let j y = (p,q)
+ {-# NOINLINE j #-}
+ in
+ case x of
+ 2 -> j 3
+ _ -> j 4)
+
diff --git a/testsuite/tests/simplCore/should_compile/T13543.stderr b/testsuite/tests/simplCore/should_compile/T13543.stderr
new file mode 100644
index 0000000000..0519ecba6e
--- /dev/null
+++ b/testsuite/tests/simplCore/should_compile/T13543.stderr
@@ -0,0 +1 @@
+ \ No newline at end of file
diff --git a/testsuite/tests/simplCore/should_compile/all.T b/testsuite/tests/simplCore/should_compile/all.T
index 7a079c793f..1b45930130 100644
--- a/testsuite/tests/simplCore/should_compile/all.T
+++ b/testsuite/tests/simplCore/should_compile/all.T
@@ -259,3 +259,4 @@ test('T13468',
normal,
run_command,
['$MAKE -s --no-print-directory T13468'])
+test('T13543', normal, compile, ['-ddump-str-signatures'])