diff options
| -rw-r--r-- | compiler/coreSyn/CoreOpt.hs | 44 | ||||
| -rw-r--r-- | compiler/simplCore/OccurAnal.hs | 68 | ||||
| -rw-r--r-- | testsuite/tests/simplCore/should_compile/T14650.hs | 76 | ||||
| -rw-r--r-- | testsuite/tests/simplCore/should_compile/all.T | 1 | 
4 files changed, 136 insertions, 53 deletions
| diff --git a/compiler/coreSyn/CoreOpt.hs b/compiler/coreSyn/CoreOpt.hs index 4240647d58..0f35e8f3ac 100644 --- a/compiler/coreSyn/CoreOpt.hs +++ b/compiler/coreSyn/CoreOpt.hs @@ -22,7 +22,7 @@ module CoreOpt (  import GhcPrelude -import CoreArity( joinRhsArity, etaExpandToJoinPoint ) +import CoreArity( etaExpandToJoinPoint )  import CoreSyn  import CoreSubst @@ -646,58 +646,18 @@ joinPointBinding_maybe bndr rhs    = Just (bndr, rhs)    | AlwaysTailCalled join_arity <- tailCallInfo (idOccInfo bndr) -  , not (bad_unfolding join_arity (idUnfolding bndr))    , (bndrs, body) <- etaExpandToJoinPoint join_arity rhs    = Just (bndr `asJoinId` join_arity, mkLams bndrs body)    | otherwise    = Nothing -  where -    -- bad_unfolding returns True if we should /not/ convert a non-join-id -    -- into a join-id, even though it is AlwaysTailCalled -    -- See Note [Join points and INLINE pragmas] -    bad_unfolding join_arity (CoreUnfolding { uf_src = src, uf_tmpl = rhs }) -      = isStableSource src && join_arity > joinRhsArity rhs -    bad_unfolding _ (DFunUnfolding {}) -      = True -    bad_unfolding _ _ -      = False -  joinPointBindings_maybe :: [(InBndr, InExpr)] -> Maybe [(InBndr, InExpr)]  joinPointBindings_maybe bndrs    = mapM (uncurry joinPointBinding_maybe) bndrs -{- Note [Join points and INLINE pragmas] -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Consider -   f x = let g = \x. not  -- Arity 1 -             {-# INLINE g #-} -         in case x of -              A -> g True True -              B -> g True False -              C -> blah2 - -Here 'g' is always tail-called applied to 2 args, but the stable -unfolding captured by the INLINE pragma has arity 1.  If we try to -convert g to be a join point, its unfolding will still have arity 1 -(since it is stable, and we don't meddle with stable unfoldings), and -Lint will complain (see Note [Invariants on join points], (2a), in -CoreSyn.  Trac #13413. 
- -Moreover, since g is going to be inlined anyway, there is no benefit -from making it a join point. - -If it is recursive, and uselessly marked INLINE, this will stop us -making it a join point, which is annoying.  But occasionally -(notably in class methods; see Note [Instances and loop breakers] in -TcInstDcls) we mark recursive things as INLINE but the recursion -unravels; so ignoring INLINE pragmas on recursive things isn't good -either. - - -************************************************************************ +{- *********************************************************************  *                                                                      *           exprIsConApp_maybe  *                                                                      * diff --git a/compiler/simplCore/OccurAnal.hs b/compiler/simplCore/OccurAnal.hs index bcc84100a1..b0987d5da0 100644 --- a/compiler/simplCore/OccurAnal.hs +++ b/compiler/simplCore/OccurAnal.hs @@ -25,6 +25,7 @@ import CoreSyn  import CoreFVs  import CoreUtils        ( exprIsTrivial, isDefaultAlt, isExpandableApp,                            stripTicksTopE, mkTicks ) +import CoreArity        ( joinRhsArity )  import Id  import IdInfo  import Name( localiseName ) @@ -2664,9 +2665,8 @@ tagRecBinders lvl body_uds triples             , AlwaysTailCalled arity <- tailCallInfo occ             = Just arity             | otherwise -           = ASSERT(not will_be_joins) -- Should be AlwaysTailCalled if we're -                                       -- making join points! -             Nothing +           = ASSERT(not will_be_joins) -- Should be AlwaysTailCalled if +             Nothing                   -- we are making join points!       -- 3. Compute final usage details from adjusted RHS details       adj_uds   = body_uds +++ combineUsageDetailsList rhs_udss' @@ -2694,10 +2694,15 @@ setBinderOcc occ_info bndr  -- | Decide whether some bindings should be made into join points or not.  
-- Returns `False` if they can't be join points. Note that it's an --- all-or-nothing decision, as if multiple binders are given, they're assumed to --- be mutually recursive. +-- all-or-nothing decision, as if multiple binders are given, they're +-- assumed to be mutually recursive.  -- --- See Note [Invariants for join points] in CoreSyn. +-- It must, however, be a final decision. If we say "True" for 'f', +-- and then subsequently decide /not/ to make 'f' into a join point, then +-- the decision about another binding 'g' might be invalidated if (say) +-- 'f' tail-calls 'g'. +-- +-- See Note [Invariants on join points] in CoreSyn.  decideJoinPointHood :: TopLevelFlag -> UsageDetails                      -> [CoreBndr]                      -> Bool @@ -2721,6 +2726,9 @@ decideJoinPointHood NotTopLevel usage bndrs          AlwaysTailCalled arity <- tailCallInfo (lookupDetails usage bndr)        , -- Invariant 1 as applied to LHSes of rules          all (ok_rule arity) (idCoreRules bndr) +        -- Invariant 2a: stable unfoldings +        -- See Note [Join points and INLINE pragmas] +      , ok_unfolding arity (realIdUnfolding bndr)          -- Invariant 4: Satisfies polymorphism rule        , isValidJoinPointType arity (idType bndr)        = True @@ -2732,14 +2740,52 @@ decideJoinPointHood NotTopLevel usage bndrs        = args `lengthIs` join_arity          -- Invariant 1 as applied to LHSes of rules +    -- ok_unfolding returns False if we should /not/ convert a non-join-id +    -- into a join-id, even though it is AlwaysTailCalled +    ok_unfolding join_arity (CoreUnfolding { uf_src = src, uf_tmpl = rhs }) +      = not (isStableSource src && join_arity > joinRhsArity rhs) +    ok_unfolding _ (DFunUnfolding {}) +      = False +    ok_unfolding _ _ +      = True +  willBeJoinId_maybe :: CoreBndr -> Maybe JoinArity  willBeJoinId_maybe bndr -  | AlwaysTailCalled arity <- tailCallInfo (idOccInfo bndr) -  = Just arity -  | otherwise -  = isJoinId_maybe bndr +  = case 
tailCallInfo (idOccInfo bndr) of +      AlwaysTailCalled arity -> Just arity +      _                      -> isJoinId_maybe bndr + + +{- Note [Join points and INLINE pragmas] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Consider +   f x = let g = \x. not  -- Arity 1 +             {-# INLINE g #-} +         in case x of +              A -> g True True +              B -> g True False +              C -> blah2 + +Here 'g' is always tail-called applied to 2 args, but the stable +unfolding captured by the INLINE pragma has arity 1.  If we try to +convert g to be a join point, its unfolding will still have arity 1 +(since it is stable, and we don't meddle with stable unfoldings), and +Lint will complain (see Note [Invariants on join points], (2a), in +CoreSyn.  Trac #13413. + +Moreover, since g is going to be inlined anyway, there is no benefit +from making it a join point. + +If it is recursive, and uselessly marked INLINE, this will stop us +making it a join point, which is annoying.  But occasionally +(notably in class methods; see Note [Instances and loop breakers] in +TcInstDcls) we mark recursive things as INLINE but the recursion +unravels; so ignoring INLINE pragmas on recursive things isn't good +either. + +See Invariant 2a of Note [Invariants on join points] in CoreSyn + -{-  ************************************************************************  *                                                                      *  \subsection{Operations over OccInfo} diff --git a/testsuite/tests/simplCore/should_compile/T14650.hs b/testsuite/tests/simplCore/should_compile/T14650.hs new file mode 100644 index 0000000000..b9eac20021 --- /dev/null +++ b/testsuite/tests/simplCore/should_compile/T14650.hs @@ -0,0 +1,76 @@ +module MergeSort (
 +  msortBy
 + ) where
 +
 +infixl 7 :%  -- fixity of the LenListAnd constructor
 +infixr 6 :&  -- fixity of the non-empty Stack constructor
 +
 +data LenList a = LL {-# UNPACK #-} !Int Bool [a]  -- strict unpacked Int, then a lazy Bool and a lazy list; presumably: element count, "the list holds exactly that many elements" flag, elements (cf. the `take` calls in mergeLs) -- TODO confirm
 +
 +data LenListAnd a b = {-# UNPACK #-} !(LenList a) :% b  -- a strict LenList paired with a lazy second component
 +
 +data Stack a  -- stack whose entries are LenLists
 +  = End  -- the empty stack
 +  | {-# UNPACK #-} !(LenList a) :& (Stack a)  -- strict top entry in front of a lazy remainder of the stack
 +
 +msortBy :: (a -> a -> Ordering) -> [a] -> [a]  -- the module's only export; takes a comparison and a list. NOTE(review): presumably a run-detecting merge sort -- not verified here
 +msortBy cmp = mergeSplit End where  -- starts mergeSplit with the empty stack; every helper below is local and closes over cmp
 +  splitAsc n _ _ _ | n `seq` False = undefined  -- guard can never succeed (n `seq` False is False); its only effect is to force n
 +  splitAsc n as _ [] = LL n True as :% []  -- input exhausted: flag True, empty remainder
 +  splitAsc n as a bs@(b:bs') = case cmp a b of
 +    GT -> LL n False as :% bs  -- cmp a b is GT: stop; `as` is returned untouched with count n and flag False, bs is the remainder
 +    _  -> splitAsc (n + 1) as b bs'  -- otherwise count one more element and continue from b
 +
 +  splitDesc n _ _ _ | n `seq` False = undefined  -- never-succeeding guard that only forces n
 +  splitDesc n rs a [] = LL n True (a:rs) :% []  -- input exhausted: a is consed onto the accumulator rs
 +  splitDesc n rs a bs@(b:bs') = case cmp a b of
 +    GT -> splitDesc (n + 1) (a:rs) b bs'  -- cmp a b is GT: cons a onto rs (so rs is built in reverse) and continue from b
 +    _  -> LL n True (a:rs) :% bs  -- otherwise stop; bs is the remainder
 +
 +  mergeLL (LL na fa as) (LL nb fb bs) = LL (na + nb) True $ mergeLs na as nb bs where  -- result count is na + nb and the result flag is always True
 +    mergeLs nx  _ ny  _ | nx `seq` ny `seq` False = undefined  -- never-succeeding guard that only forces both counts
 +    mergeLs  0  _ ny ys = if fb then ys else take ny ys  -- left count is 0: the rest is ys, cut to ny elements unless fb is set
 +    mergeLs  _ [] ny ys = if fb then ys else take ny ys  -- left list is empty: same result as the previous equation
 +    mergeLs nx xs  0  _ = if fa then xs else take nx xs  -- right count is 0: the rest is xs, cut to nx elements unless fa is set
 +    mergeLs nx xs  _ [] = if fa then xs else take nx xs  -- right list is empty: same result as the previous equation
 +    mergeLs nx xs@(x:xs') ny ys@(y:ys') = case cmp x y of
 +      GT -> y:mergeLs nx xs (ny - 1) ys'  -- y is emitted only when cmp x y is GT
 +      _  -> x:mergeLs (nx - 1) xs' ny ys  -- so x is emitted on EQ as well as LT
 +
 +  push ssx px@(LL nx _ _) = case ssx of  -- combines px with the stack ssx; the result is (top entry :% remaining stack)
 +    End -> px :% ssx  -- empty stack: nothing to merge
 +    py@(LL ny _ _) :& ssy -> case ssy of
 +      End
 +        | nx >= ny -> mergeLL py px :% ssy  -- single entry py on the stack: merge it with px when nx >= ny
 +      pz@(LL nz _ _) :& ssz
 +        | nx >= ny || nx + ny >= nz -> case nx > nz of
 +            False -> push ssy $ mergeLL py px  -- nx <= nz: merge py with px and push the result onto ssy
 +            _     -> case push ssz $ mergeLL pz py of  -- nx > nz: merge the two stack entries pz and py first
 +              pz' :% ssz' -> push (pz' :& ssz') px  -- then retry pushing px onto the rebuilt stack
 +      _ -> px :% ssx  -- reached whenever the guards above fail: leave the stack unchanged
 +
 +  mergeAll _ px | px `seq` False = undefined  -- never-succeeding guard that only forces px
 +  mergeAll ssx px@(LL nx _ xs) = case ssx of  -- folds the stack ssx into px and returns a plain list
 +    End -> xs  -- empty stack: the result is px's list
 +    py@(LL _ _ _) :& ssy -> case ssy of
 +      End -> case mergeLL py px of  -- one entry left: merge it with px
 +        LL _ _ xys -> xys  -- and return just the merged list
 +      pz@(LL nz _ _) :& ssz -> case nx > nz of
 +        False -> mergeAll ssy $ mergeLL py px  -- nx <= nz: merge py with px and continue with ssy
 +        _     -> case push ssz $ mergeLL pz py of  -- nx > nz: merge pz with py and push that onto ssz first
 +          pz' :% ssz' -> mergeAll (pz' :& ssz') px  -- then continue with the rebuilt stack
 +
 +  mergeSplit ss _ | ss `seq` False = undefined  -- never-succeeding guard that only forces ss
 +  mergeSplit ss [] = case ss of  -- no input left
 +    End -> []  -- nothing on the stack either
 +    px :& ss' -> mergeAll ss' px  -- merge everything on the stack into its top entry
 +  mergeSplit ss as@(a:as') = case as' of
 +    [] -> mergeAll ss $ LL 1 True as  -- a single remaining element becomes an entry with count 1
 +    b:bs -> case cmp a b of
 +      GT -> case splitDesc 2 [a] b bs of  -- cmp a b is GT: use splitDesc, starting at count 2 with accumulator [a]
 +        px :% rs -> case push ss px of  -- push the new entry onto the stack
 +          px' :% ss' -> mergeSplit (px' :& ss') rs  -- and continue with the remainder rs
 +      _  -> case splitAsc 2 as b bs of  -- otherwise: use splitAsc, starting at count 2 and passing the whole list as
 +        px :% rs -> case push ss px of  -- push the new entry onto the stack
 +          px' :% ss' -> mergeSplit (px' :& ss') rs  -- and continue with the remainder rs
 +  {-# INLINABLE mergeSplit #-}  -- pragma on the where-bound mergeSplit; NOTE(review): likely what triggers the stable-unfolding / join-point interaction this test (T14650) covers -- confirm against the ticket
 diff --git a/testsuite/tests/simplCore/should_compile/all.T b/testsuite/tests/simplCore/should_compile/all.T index e51e8f7db4..e681ca7363 100644 --- a/testsuite/tests/simplCore/should_compile/all.T +++ b/testsuite/tests/simplCore/should_compile/all.T @@ -289,3 +289,4 @@ test('T14152a', [extra_files(['T14152.hs']), pre_cmd('cp T14152.hs T14152a.hs'),                   only_ways(['optasm']), check_errmsg(r'dead code') ],                  compile, ['-fno-exitification -ddump-simpl'])  test('T13990', normal, compile, ['-dcore-lint -O']) +test('T14650', normal, compile, ['-O2']) | 
