diff options
author | Simon Peyton Jones <simonpj@microsoft.com> | 2018-08-21 09:56:39 +0100 |
---|---|---|
committer | Simon Peyton Jones <simonpj@microsoft.com> | 2018-08-21 14:00:48 +0100 |
commit | ce6ce788251b6102f5c1b878ffec53ba7ad678b5 (patch) | |
tree | abce984aa54c6c7eaae7678b822086f50481a2a2 | |
parent | 9c4e6c6b1affd410604f8f76ecf56abfcc5cccb6 (diff) | |
download | haskell-ce6ce788251b6102f5c1b878ffec53ba7ad678b5.tar.gz |
Set strictness correctly for JoinIds
We were failing to keep correct strictness info when eta-expanding
join points; Trac #15517. The situation was something like
\q v eta ->
let j x = error "blah
-- STR Lx bottoming!
in case y of
A -> j x eta
B -> blah
C -> j x eta
So we spot j as a join point and eta-expand it. But we must
also adjust the stricness info, else it vlaimes to bottom after
one arg is applied but now it has become two.
I fixed this in two places:
- In CoreOpt.joinPointBinding_maybe, adjust strictness info
- In SimplUtils.tryEtaExpandRhs, return consistent values
for arity and bottom-ness
-rw-r--r-- | compiler/basicTypes/Demand.hs | 21 | ||||
-rw-r--r-- | compiler/coreSyn/CoreOpt.hs | 28 | ||||
-rw-r--r-- | compiler/simplCore/SimplUtils.hs | 9 | ||||
-rw-r--r-- | testsuite/tests/simplCore/should_compile/T15517.hs | 10 | ||||
-rw-r--r-- | testsuite/tests/simplCore/should_compile/T15517a.hs | 96 | ||||
-rw-r--r-- | testsuite/tests/simplCore/should_compile/all.T | 4 |
6 files changed, 160 insertions, 8 deletions
diff --git a/compiler/basicTypes/Demand.hs b/compiler/basicTypes/Demand.hs index 0b0da1349a..071945386e 100644 --- a/compiler/basicTypes/Demand.hs +++ b/compiler/basicTypes/Demand.hs @@ -39,7 +39,7 @@ module Demand ( nopSig, botSig, exnSig, cprProdSig, isTopSig, hasDemandEnvSig, splitStrictSig, strictSigDmdEnv, - increaseStrictSigArity, + increaseStrictSigArity, etaExpandStrictSig, seqDemand, seqDemandList, seqDmdType, seqStrictSig, @@ -1737,8 +1737,23 @@ splitStrictSig (StrictSig (DmdType _ dmds res)) = (dmds, res) increaseStrictSigArity :: Int -> StrictSig -> StrictSig -- Add extra arguments to a strictness signature -increaseStrictSigArity arity_increase (StrictSig (DmdType env dmds res)) - = StrictSig (DmdType env (replicate arity_increase topDmd ++ dmds) res) +increaseStrictSigArity arity_increase sig@(StrictSig dmd_ty@(DmdType env dmds res)) + | isTopDmdType dmd_ty = sig + | arity_increase <= 0 = sig + | otherwise = StrictSig (DmdType env dmds' res) + where + dmds' = replicate arity_increase topDmd ++ dmds + +etaExpandStrictSig :: Arity -> StrictSig -> StrictSig +-- We are expanding (\x y. e) to (\x y z. e z) +-- Add exta demands to the /end/ of the arg demands if necessary +etaExpandStrictSig arity sig@(StrictSig dmd_ty@(DmdType env dmds res)) + | isTopDmdType dmd_ty = sig + | arity_increase <= 0 = sig + | otherwise = StrictSig (DmdType env dmds' res) + where + arity_increase = arity - length dmds + dmds' = dmds ++ replicate arity_increase topDmd isTopSig :: StrictSig -> Bool isTopSig (StrictSig ty) = isTopDmdType ty diff --git a/compiler/coreSyn/CoreOpt.hs b/compiler/coreSyn/CoreOpt.hs index 11cbd1e2c8..ff5ed35517 100644 --- a/compiler/coreSyn/CoreOpt.hs +++ b/compiler/coreSyn/CoreOpt.hs @@ -36,6 +36,7 @@ import Var ( varType ) import VarSet import VarEnv import DataCon +import Demand( etaExpandStrictSig ) import OptCoercion ( optCoercion ) import Type hiding ( substTy, extendTvSubst, extendCvSubst, extendTvSubstList , isInScope, substTyVarBndr, cloneTyVarBndr ) @@ -658,7 +659,11 @@ joinPointBinding_maybe bndr rhs | AlwaysTailCalled join_arity <- tailCallInfo (idOccInfo bndr) , (bndrs, body) <- etaExpandToJoinPoint join_arity rhs - = Just (bndr `asJoinId` join_arity, mkLams bndrs body) + , let str_sig = idStrictness bndr + str_arity = count isId bndrs -- Strictness demands are for Ids only + join_bndr = bndr `asJoinId` join_arity + `setIdStrictness` etaExpandStrictSig str_arity str_sig + = Just (join_bndr, mkLams bndrs body) | otherwise = Nothing @@ -668,6 +673,27 @@ joinPointBindings_maybe bndrs = mapM (uncurry joinPointBinding_maybe) bndrs +{- Note [Strictness and join points] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Suppose we have + + let f = \x. if x>200 then e1 else e1 + +and we know that f is strict in x. Then if we subsequently +discover that f is an arity-2 join point, we'll eta-expand it to + + let f = \x y. if x>200 then e1 else e1 + +and now it's only strict if applied to two arguments. So we should +adjust the strictness info. + +A more common case is when + + f = \x. error ".." + +and again its arity increses (Trac #15517) +-} + {- ********************************************************************* * * exprIsConApp_maybe diff --git a/compiler/simplCore/SimplUtils.hs b/compiler/simplCore/SimplUtils.hs index 83ad059171..ca1b9bd23d 100644 --- a/compiler/simplCore/SimplUtils.hs +++ b/compiler/simplCore/SimplUtils.hs @@ -1511,9 +1511,12 @@ tryEtaExpandRhs :: SimplMode -> OutId -> OutExpr -- (a) rhs' has manifest arity -- (b) if is_bot is True then rhs' applied to n args is guaranteed bottom tryEtaExpandRhs mode bndr rhs - | isJoinId bndr - = return (manifestArity rhs, False, rhs) - -- Note [Do not eta-expand join points] + | Just join_arity <- isJoinId_maybe bndr + = do { let (join_bndrs, join_body) = collectNBinders join_arity rhs + ; return (count isId join_bndrs, exprIsBottom join_body, rhs) } + -- Note [Do not eta-expand join points] + -- But do return the correct arity and bottom-ness, because + -- these are used to set the bndr's IdInfo (Trac #15517) | otherwise = do { (new_arity, is_bot, new_rhs) <- try_expand diff --git a/testsuite/tests/simplCore/should_compile/T15517.hs b/testsuite/tests/simplCore/should_compile/T15517.hs new file mode 100644 index 0000000000..954baccd48 --- /dev/null +++ b/testsuite/tests/simplCore/should_compile/T15517.hs @@ -0,0 +1,10 @@ +{-# LANGUAGE PatternSynonyms #-} +module T15517 where + +data Nat = Z | S Nat + +pattern Zpat = Z + +sfrom :: Nat -> () -> Bool +sfrom Zpat = \_ -> False +sfrom (S Z) = \_ -> False diff --git a/testsuite/tests/simplCore/should_compile/T15517a.hs b/testsuite/tests/simplCore/should_compile/T15517a.hs new file mode 100644 index 0000000000..28ca664969 --- /dev/null +++ b/testsuite/tests/simplCore/should_compile/T15517a.hs @@ -0,0 +1,96 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module T15517a () where + +import Data.Proxy + +newtype Rep (ki :: kon -> *) (phi :: Nat -> *) (code :: [[Atom kon]]) + = Rep (NS (PoA ki phi) code) + +data NA :: (kon -> *) -> (Nat -> *) -> Atom kon -> * where + NA_I :: (IsNat k) => phi k -> NA ki phi (I k) + NA_K :: ki k -> NA ki phi (K k) + +data NP :: (k -> *) -> [k] -> * where + NP0 :: NP p '[] + (:*) :: p x -> NP p xs -> NP p (x : xs) + +class IsNat (n :: Nat) where + getSNat :: Proxy n -> SNat n +instance IsNat Z where + getSNat _ = SZ +instance IsNat n => IsNat (S n) where + getSNat p = SS (getSNat $ proxyUnsuc p) + +proxyUnsuc :: Proxy (S n) -> Proxy n +proxyUnsuc _ = Proxy + +type PoA (ki :: kon -> *) (phi :: Nat -> *) = NP (NA ki phi) + +data Atom kon + = K kon + | I Nat + +data Nat = S Nat | Z +data SNat :: Nat -> * where + SZ :: SNat Z + SS :: SNat n -> SNat (S n) + +data Kon = KInt +data Singl (kon :: Kon) :: * where + SInt :: Int -> Singl KInt + +type family Lkup (n :: Nat) (ks :: [k]) :: k where + Lkup Z (k : ks) = k + Lkup (S n) (k : ks) = Lkup n ks + +data El :: [*] -> Nat -> * where + El :: IsNat ix => Lkup ix fam -> El fam ix + +data NS :: (k -> *) -> [k] -> * where + There :: NS p xs -> NS p (x : xs) + Here :: p x -> NS p (x : xs) + +class Family (ki :: kon -> *) (fam :: [*]) (codes :: [[[Atom kon]]]) + | fam -> ki codes , ki codes -> fam where + sfrom' :: SNat ix -> El fam ix -> Rep ki (El fam) (Lkup ix codes) + +data Rose a = a :>: [Rose a] + | Leaf a + +type FamRoseInt = '[Rose Int, [Rose Int]] + +type CodesRoseInt = + '[ '[ '[K KInt, I (S Z)], '[K KInt]], '[ '[], '[I Z, I (S Z)]]] + +pattern IdxRoseInt = SZ +pattern IdxListRoseInt = SS SZ + +pat1 :: PoA Singl (El FamRoseInt) '[I Z, I (S Z)] + -> NS (PoA Singl (El FamRoseInt)) '[ '[], '[I Z, I (S Z)]] +pat1 d = There (Here d) + +pat2 :: PoA Singl (El FamRoseInt) '[] + -> NS (PoA Singl (El FamRoseInt)) '[ '[], '[I Z, I (S Z)]] +pat2 d = Here d + +pat3 :: PoA Singl (El FamRoseInt) '[K KInt] + -> NS (PoA Singl (El FamRoseInt)) '[ '[K KInt, I (S Z)], '[K KInt]] +pat3 d = There (Here d) + +pat4 :: PoA Singl (El FamRoseInt) '[K KInt, I (S Z)] + -> NS (PoA Singl (El FamRoseInt)) '[ '[K KInt, I (S Z)], '[K KInt]] +pat4 d = Here d + +instance Family Singl FamRoseInt CodesRoseInt where + sfrom' = \case IdxRoseInt -> \case El (x :>: xs) -> Rep (pat4 (NA_K (SInt x) :* (NA_I (El xs) :* NP0))) + El (Leaf x) -> Rep (pat3 (NA_K (SInt x) :* NP0)) + IdxListRoseInt -> \case El [] -> Rep (pat2 NP0) + El (x:xs) -> Rep (pat1 (NA_I (El x) :* (NA_I (El xs) :* NP0))) diff --git a/testsuite/tests/simplCore/should_compile/all.T b/testsuite/tests/simplCore/should_compile/all.T index 595607b628..188f6432fa 100644 --- a/testsuite/tests/simplCore/should_compile/all.T +++ b/testsuite/tests/simplCore/should_compile/all.T @@ -318,4 +318,6 @@ test('T15005', normal, compile, ['-O']) # we omit profiling because it affects the optimiser and makes the test fail test('T15056', [extra_files(['T15056a.hs']), omit_ways(['profasm'])], multimod_compile, ['T15056', '-O -v0 -ddump-rule-firings']) test('T15186', normal, multimod_compile, ['T15186', '-v0']) -test('T15453', normal, compile, ['-dcore-lint -O1']) +test('T15453', normal, compile, ['-O1']) +test('T15517', normal, compile, ['-O0']) +test('T15517a', normal, compile, ['-O0']) |