summaryrefslogtreecommitdiff
path: root/compiler/specialise
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/specialise')
-rw-r--r--compiler/specialise/Rules.hs10
-rw-r--r--compiler/specialise/SpecConstr.hs100
2 files changed, 103 insertions, 7 deletions
diff --git a/compiler/specialise/Rules.hs b/compiler/specialise/Rules.hs
index 319404ef15..b6025955ac 100644
--- a/compiler/specialise/Rules.hs
+++ b/compiler/specialise/Rules.hs
@@ -1148,10 +1148,10 @@ is so important.
-- string for the purposes of error reporting
ruleCheckProgram :: CompilerPhase -- ^ Rule activation test
-> String -- ^ Rule pattern
- -> RuleEnv -- ^ Database of rules
+ -> (Id -> [CoreRule]) -- ^ Rules for an Id
-> CoreProgram -- ^ Bindings to check in
-> SDoc -- ^ Resulting check message
-ruleCheckProgram phase rule_pat rule_base binds
+ruleCheckProgram phase rule_pat rules binds
| isEmptyBag results
= text "Rule check results: no rule application sites"
| otherwise
@@ -1164,7 +1164,7 @@ ruleCheckProgram phase rule_pat rule_base binds
, rc_id_unf = idUnfolding -- Not quite right
-- Should use activeUnfolding
, rc_pattern = rule_pat
- , rc_rule_base = rule_base }
+ , rc_rules = rules }
results = unionManyBags (map (ruleCheckBind env) binds)
line = text (replicate 20 '-')
@@ -1172,7 +1172,7 @@ data RuleCheckEnv = RuleCheckEnv {
rc_is_active :: Activation -> Bool,
rc_id_unf :: IdUnfoldingFun,
rc_pattern :: String,
- rc_rule_base :: RuleEnv
+ rc_rules :: Id -> [CoreRule]
}
ruleCheckBind :: RuleCheckEnv -> CoreBind -> Bag SDoc
@@ -1206,7 +1206,7 @@ ruleCheckFun env fn args
| null name_match_rules = emptyBag
| otherwise = unitBag (ruleAppCheck_help env fn args name_match_rules)
where
- name_match_rules = filter match (getRules (rc_rule_base env) fn)
+ name_match_rules = filter match (rc_rules env fn)
match rule = (rc_pattern env) `isPrefixOf` unpackFS (ruleName rule)
ruleAppCheck_help :: RuleCheckEnv -> Id -> [CoreExpr] -> [CoreRule] -> SDoc
diff --git a/compiler/specialise/SpecConstr.hs b/compiler/specialise/SpecConstr.hs
index efd56ce77c..f62f7d0778 100644
--- a/compiler/specialise/SpecConstr.hs
+++ b/compiler/specialise/SpecConstr.hs
@@ -57,6 +57,7 @@ import UniqFM
import MonadUtils
import Control.Monad ( zipWithM )
import Data.List
+import Data.Maybe ( fromMaybe )
import PrelNames ( specTyConName )
import Module
import TyCon ( TyCon )
@@ -1509,6 +1510,7 @@ data OneSpec =
OS { os_pat :: CallPat -- Call pattern that generated this specialisation
, os_rule :: CoreRule -- Rule connecting original id with the specialisation
, os_id :: OutId -- Spec id
+ , os_orig_id :: OutId -- The original id
, os_rhs :: OutExpr } -- Spec rhs
noSpecInfo :: SpecInfo
@@ -1522,7 +1524,8 @@ specNonRec :: ScEnv
-- plus details of specialisations
specNonRec env body_usg rhs_info
- = specialise env (scu_calls body_usg) rhs_info
+ = addPatUsages env (scu_calls body_usg) <$>
+ specialise env (scu_calls body_usg) rhs_info
(noSpecInfo { si_mb_unspec = Just (ri_rhs_usg rhs_info) })
----------------------
@@ -1533,7 +1536,8 @@ specRec :: TopLevelFlag -> ScEnv
-- plus details of specialisations
specRec top_lvl env body_usg rhs_infos
- = go 1 seed_calls nullUsage init_spec_infos
+ = addPatUsagess env (scu_calls body_usg) <$>
+ go 1 seed_calls nullUsage init_spec_infos
where
(seed_calls, init_spec_infos) -- Note [Seeding top-level recursive groups]
| isTopLevel top_lvl
@@ -1754,8 +1758,64 @@ spec_one env fn arg_bndrs body (call_pat@(qvars, pats), rule_number)
-- See Note [Transfer activation]
; return (spec_usg, OS { os_pat = call_pat, os_rule = rule
, os_id = spec_id
+ , os_orig_id = fn
, os_rhs = spec_rhs }) }
+-- | List version of the pattern-usage extension: the incoming usage is
+-- combined with the extra usage computed by 'extraPatUsages' for every
+-- 'SpecInfo' in the list. The 'SpecInfo's themselves are returned unchanged.
+-- See Note [ArgOcc from calls to specialized functions]
+addPatUsagess :: ScEnv -> CallEnv -> (ScUsage, [SpecInfo]) -> (ScUsage, [SpecInfo])
+addPatUsagess env body_calls (usg, spec_infos)
+  = (usg `combineUsage` combineUsages per_binding, spec_infos)
+  where
+    -- One extra usage per binding in the group
+    per_binding = map (extraPatUsages env body_calls) spec_infos
+
+-- | Combine the incoming usage with the extra usage that 'extraPatUsages'
+-- derives from the given 'SpecInfo' and the calls found in the body.
+-- The 'SpecInfo' is passed through untouched.
+addPatUsages :: ScEnv -> CallEnv -> (ScUsage, SpecInfo) -> (ScUsage, SpecInfo)
+addPatUsages env body_calls (usg, spec_info)
+  = (combineUsage usg (extraPatUsages env body_calls spec_info), spec_info)
+
+-- | Extra usage implied by the specialisations recorded in a 'SpecInfo'.
+--
+-- Every specialisation ('si_specs') is paired with every call to its
+-- original function ('os_orig_id') recorded in @body_calls@, and the usages
+-- computed by 'patToCallUsage' for each pair are combined. A function with no
+-- recorded calls contributes nothing.
+extraPatUsages :: ScEnv -> CallEnv -> SpecInfo -> ScUsage
+extraPatUsages env body_calls si = combineUsages
+  [ patToCallUsage env (os_pat os) call
+  | os   <- si_specs si
+  , call <- fromMaybe [] $ lookupVarEnv body_calls (os_orig_id os)
+  ]
+
+-- | Compare the argument patterns of one specialisation with the actual
+-- arguments of one call and collect the occurrence information this implies.
+--
+-- Patterns and arguments are paired positionally; 'zipWith' stops at the
+-- shorter of the two lists, so surplus patterns or arguments are ignored.
+patToCallUsage :: ScEnv -> CallPat -> Call -> ScUsage
+patToCallUsage env (_qvars, pats) (Call _ args _)
+  = combineUsages $ zipWith go pats args
+  where
+    go :: CoreExpr -> CoreExpr -> ScUsage
+    -- The interesting case: the argument is a variable bound as RecArg and
+    -- the pattern yields a ScrutOcc, so record that occurrence for the
+    -- variable. If either guard fails we fall through to the cases below.
+    go pat (Var v)
+      | Just RecArg <- lookupHowBound env v
+      , arg_occ@ScrutOcc{} <- patToArgOcc pat -- skip if we get UnkOcc
+      = nullUsage { scu_occs = unitVarEnv v arg_occ }
+
+    -- Transparent cases: look through ticks and casts on either side
+    go (Tick _ p) e = go p e
+    go (Cast p _) e = go p e
+    go p (Tick _ e) = go p e
+    go p (Cast e _) = go p e
+
+    -- Traverse the tree: match function and argument parts pairwise
+    go (App pf pa) (App f a)
+      = go pf f `combineUsage` go pa a
+
+    -- Boring catch-all
+    go _ _ = nullUsage
+
+-- | Convert a pattern expression into an 'ArgOcc'.
+--
+-- An application whose head is a data constructor worker becomes a
+-- 'ScrutOcc' mapping that constructor to the occurrences of its non-type
+-- arguments (computed recursively). Everything else is 'UnkOcc'.
+patToArgOcc :: CoreExpr -> ArgOcc
+patToArgOcc e@App{}
+  | (Var f, args) <- collectArgs e
+  , Just dc <- isDataConWorkId_maybe f
+  = ScrutOcc (unitUFM dc (map patToArgOcc (filter (not . isTypeArg) args)))
+patToArgOcc _ = UnkOcc
-- See Note [Strictness information in worker binders]
handOutStrictnessInformation :: [Demand] -> [Var] -> [Var]
@@ -1792,6 +1852,42 @@ calcSpecStrictness fn qvars pats
go_one env _ _ = env
{-
+Note [ArgOcc from calls to specialized functions]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+We collect the ArgOcc to find out which parameters are being scrutinized in the
+function body, and only generate specializations when they would lead to some
+optimization. In
+
+ foo x = … case x of (a,b) -> …
+
+We are willing to specialize foo. If we have
+
+ foo x = … bar x …
+ where bar y = …
+
+we normally don’t. But what if we specialize bar? Then we have
+
+ foo x = … bar x …
+ where $sbar a b = …
+ bar y = …
+ {-# RULE forall a b. bar (a,b) = $sbar a b #-}
+
+and now it would be beneficial to create a specialized version of foo that
+calls $sbar directly.
+
+To achieve this, after we specialize bar, we look at the calls to it (found in
+scu_calls), and all the specializations that we created. If there is a call `bar x`
+and a specialization pattern `(a,b)`, then we treat that as if we found a case
+analysis of x, and include `x ↦ ScrutOcc` in scu_occs. This unblocks specialization
+of foo, and so on.
+
+(We might want to generalize this to any call to `baz x` where `baz` has
+rewrite rules that match on constructor arguments, not only for when _we_ _just_
+created specializations.)
+
+(See #14951)
+
Note [spec_usg includes rhs_usg]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In calls to 'specialise', the returned ScUsage must include the rhs_usg in