summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSimon Peyton Jones <simonpj@microsoft.com>2018-08-21 09:56:39 +0100
committerSimon Peyton Jones <simonpj@microsoft.com>2018-08-21 14:00:48 +0100
commitce6ce788251b6102f5c1b878ffec53ba7ad678b5 (patch)
treeabce984aa54c6c7eaae7678b822086f50481a2a2
parent9c4e6c6b1affd410604f8f76ecf56abfcc5cccb6 (diff)
downloadhaskell-ce6ce788251b6102f5c1b878ffec53ba7ad678b5.tar.gz
Set strictness correctly for JoinIds
We were failing to keep correct strictness info when eta-expanding join points; Trac #15517. The situation was something like \q v eta -> let j x = error "blah -- STR Lx bottoming! in case y of A -> j x eta B -> blah C -> j x eta So we spot j as a join point and eta-expand it. But we must also adjust the stricness info, else it vlaimes to bottom after one arg is applied but now it has become two. I fixed this in two places: - In CoreOpt.joinPointBinding_maybe, adjust strictness info - In SimplUtils.tryEtaExpandRhs, return consistent values for arity and bottom-ness
-rw-r--r--compiler/basicTypes/Demand.hs21
-rw-r--r--compiler/coreSyn/CoreOpt.hs28
-rw-r--r--compiler/simplCore/SimplUtils.hs9
-rw-r--r--testsuite/tests/simplCore/should_compile/T15517.hs10
-rw-r--r--testsuite/tests/simplCore/should_compile/T15517a.hs96
-rw-r--r--testsuite/tests/simplCore/should_compile/all.T4
6 files changed, 160 insertions, 8 deletions
diff --git a/compiler/basicTypes/Demand.hs b/compiler/basicTypes/Demand.hs
index 0b0da1349a..071945386e 100644
--- a/compiler/basicTypes/Demand.hs
+++ b/compiler/basicTypes/Demand.hs
@@ -39,7 +39,7 @@ module Demand (
nopSig, botSig, exnSig, cprProdSig,
isTopSig, hasDemandEnvSig,
splitStrictSig, strictSigDmdEnv,
- increaseStrictSigArity,
+ increaseStrictSigArity, etaExpandStrictSig,
seqDemand, seqDemandList, seqDmdType, seqStrictSig,
@@ -1737,8 +1737,23 @@ splitStrictSig (StrictSig (DmdType _ dmds res)) = (dmds, res)
increaseStrictSigArity :: Int -> StrictSig -> StrictSig
-- Add extra arguments to a strictness signature
-increaseStrictSigArity arity_increase (StrictSig (DmdType env dmds res))
- = StrictSig (DmdType env (replicate arity_increase topDmd ++ dmds) res)
+increaseStrictSigArity arity_increase sig@(StrictSig dmd_ty@(DmdType env dmds res))
+ | isTopDmdType dmd_ty = sig
+ | arity_increase <= 0 = sig
+ | otherwise = StrictSig (DmdType env dmds' res)
+ where
+ dmds' = replicate arity_increase topDmd ++ dmds
+
+etaExpandStrictSig :: Arity -> StrictSig -> StrictSig
+-- We are expanding (\x y. e) to (\x y z. e z)
+-- Add exta demands to the /end/ of the arg demands if necessary
+etaExpandStrictSig arity sig@(StrictSig dmd_ty@(DmdType env dmds res))
+ | isTopDmdType dmd_ty = sig
+ | arity_increase <= 0 = sig
+ | otherwise = StrictSig (DmdType env dmds' res)
+ where
+ arity_increase = arity - length dmds
+ dmds' = dmds ++ replicate arity_increase topDmd
isTopSig :: StrictSig -> Bool
isTopSig (StrictSig ty) = isTopDmdType ty
diff --git a/compiler/coreSyn/CoreOpt.hs b/compiler/coreSyn/CoreOpt.hs
index 11cbd1e2c8..ff5ed35517 100644
--- a/compiler/coreSyn/CoreOpt.hs
+++ b/compiler/coreSyn/CoreOpt.hs
@@ -36,6 +36,7 @@ import Var ( varType )
import VarSet
import VarEnv
import DataCon
+import Demand( etaExpandStrictSig )
import OptCoercion ( optCoercion )
import Type hiding ( substTy, extendTvSubst, extendCvSubst, extendTvSubstList
, isInScope, substTyVarBndr, cloneTyVarBndr )
@@ -658,7 +659,11 @@ joinPointBinding_maybe bndr rhs
| AlwaysTailCalled join_arity <- tailCallInfo (idOccInfo bndr)
, (bndrs, body) <- etaExpandToJoinPoint join_arity rhs
- = Just (bndr `asJoinId` join_arity, mkLams bndrs body)
+ , let str_sig = idStrictness bndr
+ str_arity = count isId bndrs -- Strictness demands are for Ids only
+ join_bndr = bndr `asJoinId` join_arity
+ `setIdStrictness` etaExpandStrictSig str_arity str_sig
+ = Just (join_bndr, mkLams bndrs body)
| otherwise
= Nothing
@@ -668,6 +673,27 @@ joinPointBindings_maybe bndrs
= mapM (uncurry joinPointBinding_maybe) bndrs
+{- Note [Strictness and join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Suppose we have
+
+ let f = \x. if x>200 then e1 else e1
+
+and we know that f is strict in x. Then if we subsequently
+discover that f is an arity-2 join point, we'll eta-expand it to
+
+ let f = \x y. if x>200 then e1 else e1
+
+and now it's only strict if applied to two arguments. So we should
+adjust the strictness info.
+
+A more common case is when
+
+ f = \x. error ".."
+
+and again its arity increses (Trac #15517)
+-}
+
{- *********************************************************************
* *
exprIsConApp_maybe
diff --git a/compiler/simplCore/SimplUtils.hs b/compiler/simplCore/SimplUtils.hs
index 83ad059171..ca1b9bd23d 100644
--- a/compiler/simplCore/SimplUtils.hs
+++ b/compiler/simplCore/SimplUtils.hs
@@ -1511,9 +1511,12 @@ tryEtaExpandRhs :: SimplMode -> OutId -> OutExpr
-- (a) rhs' has manifest arity
-- (b) if is_bot is True then rhs' applied to n args is guaranteed bottom
tryEtaExpandRhs mode bndr rhs
- | isJoinId bndr
- = return (manifestArity rhs, False, rhs)
- -- Note [Do not eta-expand join points]
+ | Just join_arity <- isJoinId_maybe bndr
+ = do { let (join_bndrs, join_body) = collectNBinders join_arity rhs
+ ; return (count isId join_bndrs, exprIsBottom join_body, rhs) }
+ -- Note [Do not eta-expand join points]
+ -- But do return the correct arity and bottom-ness, because
+ -- these are used to set the bndr's IdInfo (Trac #15517)
| otherwise
= do { (new_arity, is_bot, new_rhs) <- try_expand
diff --git a/testsuite/tests/simplCore/should_compile/T15517.hs b/testsuite/tests/simplCore/should_compile/T15517.hs
new file mode 100644
index 0000000000..954baccd48
--- /dev/null
+++ b/testsuite/tests/simplCore/should_compile/T15517.hs
@@ -0,0 +1,10 @@
+{-# LANGUAGE PatternSynonyms #-}
+module T15517 where
+
+data Nat = Z | S Nat
+
+pattern Zpat = Z
+
+sfrom :: Nat -> () -> Bool
+sfrom Zpat = \_ -> False
+sfrom (S Z) = \_ -> False
diff --git a/testsuite/tests/simplCore/should_compile/T15517a.hs b/testsuite/tests/simplCore/should_compile/T15517a.hs
new file mode 100644
index 0000000000..28ca664969
--- /dev/null
+++ b/testsuite/tests/simplCore/should_compile/T15517a.hs
@@ -0,0 +1,96 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE FunctionalDependencies #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module T15517a () where
+
+import Data.Proxy
+
+newtype Rep (ki :: kon -> *) (phi :: Nat -> *) (code :: [[Atom kon]])
+ = Rep (NS (PoA ki phi) code)
+
+data NA :: (kon -> *) -> (Nat -> *) -> Atom kon -> * where
+ NA_I :: (IsNat k) => phi k -> NA ki phi (I k)
+ NA_K :: ki k -> NA ki phi (K k)
+
+data NP :: (k -> *) -> [k] -> * where
+ NP0 :: NP p '[]
+ (:*) :: p x -> NP p xs -> NP p (x : xs)
+
+class IsNat (n :: Nat) where
+ getSNat :: Proxy n -> SNat n
+instance IsNat Z where
+ getSNat _ = SZ
+instance IsNat n => IsNat (S n) where
+ getSNat p = SS (getSNat $ proxyUnsuc p)
+
+proxyUnsuc :: Proxy (S n) -> Proxy n
+proxyUnsuc _ = Proxy
+
+type PoA (ki :: kon -> *) (phi :: Nat -> *) = NP (NA ki phi)
+
+data Atom kon
+ = K kon
+ | I Nat
+
+data Nat = S Nat | Z
+data SNat :: Nat -> * where
+ SZ :: SNat Z
+ SS :: SNat n -> SNat (S n)
+
+data Kon = KInt
+data Singl (kon :: Kon) :: * where
+ SInt :: Int -> Singl KInt
+
+type family Lkup (n :: Nat) (ks :: [k]) :: k where
+ Lkup Z (k : ks) = k
+ Lkup (S n) (k : ks) = Lkup n ks
+
+data El :: [*] -> Nat -> * where
+ El :: IsNat ix => Lkup ix fam -> El fam ix
+
+data NS :: (k -> *) -> [k] -> * where
+ There :: NS p xs -> NS p (x : xs)
+ Here :: p x -> NS p (x : xs)
+
+class Family (ki :: kon -> *) (fam :: [*]) (codes :: [[[Atom kon]]])
+ | fam -> ki codes , ki codes -> fam where
+ sfrom' :: SNat ix -> El fam ix -> Rep ki (El fam) (Lkup ix codes)
+
+data Rose a = a :>: [Rose a]
+ | Leaf a
+
+type FamRoseInt = '[Rose Int, [Rose Int]]
+
+type CodesRoseInt =
+ '[ '[ '[K KInt, I (S Z)], '[K KInt]], '[ '[], '[I Z, I (S Z)]]]
+
+pattern IdxRoseInt = SZ
+pattern IdxListRoseInt = SS SZ
+
+pat1 :: PoA Singl (El FamRoseInt) '[I Z, I (S Z)]
+ -> NS (PoA Singl (El FamRoseInt)) '[ '[], '[I Z, I (S Z)]]
+pat1 d = There (Here d)
+
+pat2 :: PoA Singl (El FamRoseInt) '[]
+ -> NS (PoA Singl (El FamRoseInt)) '[ '[], '[I Z, I (S Z)]]
+pat2 d = Here d
+
+pat3 :: PoA Singl (El FamRoseInt) '[K KInt]
+ -> NS (PoA Singl (El FamRoseInt)) '[ '[K KInt, I (S Z)], '[K KInt]]
+pat3 d = There (Here d)
+
+pat4 :: PoA Singl (El FamRoseInt) '[K KInt, I (S Z)]
+ -> NS (PoA Singl (El FamRoseInt)) '[ '[K KInt, I (S Z)], '[K KInt]]
+pat4 d = Here d
+
+instance Family Singl FamRoseInt CodesRoseInt where
+ sfrom' = \case IdxRoseInt -> \case El (x :>: xs) -> Rep (pat4 (NA_K (SInt x) :* (NA_I (El xs) :* NP0)))
+ El (Leaf x) -> Rep (pat3 (NA_K (SInt x) :* NP0))
+ IdxListRoseInt -> \case El [] -> Rep (pat2 NP0)
+ El (x:xs) -> Rep (pat1 (NA_I (El x) :* (NA_I (El xs) :* NP0)))
diff --git a/testsuite/tests/simplCore/should_compile/all.T b/testsuite/tests/simplCore/should_compile/all.T
index 595607b628..188f6432fa 100644
--- a/testsuite/tests/simplCore/should_compile/all.T
+++ b/testsuite/tests/simplCore/should_compile/all.T
@@ -318,4 +318,6 @@ test('T15005', normal, compile, ['-O'])
# we omit profiling because it affects the optimiser and makes the test fail
test('T15056', [extra_files(['T15056a.hs']), omit_ways(['profasm'])], multimod_compile, ['T15056', '-O -v0 -ddump-rule-firings'])
test('T15186', normal, multimod_compile, ['T15186', '-v0'])
-test('T15453', normal, compile, ['-dcore-lint -O1'])
+test('T15453', normal, compile, ['-O1'])
+test('T15517', normal, compile, ['-O0'])
+test('T15517a', normal, compile, ['-O0'])