summaryrefslogtreecommitdiff
path: root/compiler/simplCore/SimplUtils.hs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/simplCore/SimplUtils.hs')
-rw-r--r--compiler/simplCore/SimplUtils.hs76
1 files changed, 74 insertions, 2 deletions
diff --git a/compiler/simplCore/SimplUtils.hs b/compiler/simplCore/SimplUtils.hs
index 48dce1d090..6c4737507a 100644
--- a/compiler/simplCore/SimplUtils.hs
+++ b/compiler/simplCore/SimplUtils.hs
@@ -60,6 +60,8 @@ import Util
import MonadUtils
import Outputable
import Pair
+import PrelRules
+import Literal
import Control.Monad ( when )
@@ -1752,9 +1754,46 @@ mkCase tries these things
False -> False
and similar friends.
+
+3. Scrutinee Constant Folding
+
+ case x op# k# of _ { ===> case x of _ {
+ a1# -> e1 (a1# inv_op# k#) -> e1
+ a2# -> e2 (a2# inv_op# k#) -> e2
+ ... ...
+ DEFAULT -> ed DEFAULT -> ed
+
+ where (x op# k#) inv_op# k# == x
+
+ And similarly for commuted arguments and for some unary operations.
+
+ The purpose of this transformation is not only to avoid an arithmetic
+ operation at runtime but to allow other transformations to apply in cascade.
+
+ Example with the "Merge Nested Cases" optimization (from #12877):
+
+ main = case t of t0
+ 0## -> ...
+ DEFAULT -> case t0 `minusWord#` 1## of t1
+ 0## -> ...
+ DEFAUT -> case t1 `minusWord#` 1## of t2
+ 0## -> ...
+ DEFAULT -> case t2 `minusWord#` 1## of _
+ 0## -> ...
+ DEFAULT -> ...
+
+ becomes:
+
+ main = case t of _
+ 0## -> ...
+ 1## -> ...
+ 2## -> ...
+ 3## -> ...
+ DEFAULT -> ...
+
-}
-mkCase, mkCase1, mkCase2
+mkCase, mkCase1, mkCase2, mkCase3
:: DynFlags
-> OutExpr -> OutId
-> OutType -> [OutAlt] -- Alternatives in standard (increasing) order
@@ -1848,9 +1887,42 @@ mkCase1 _dflags scrut case_bndr _ alts@((_,_,rhs1) : _) -- Identity case
mkCase1 dflags scrut bndr alts_ty alts = mkCase2 dflags scrut bndr alts_ty alts
--------------------------------------------------
+-- 2. Scrutinee Constant Folding
+--------------------------------------------------
+
+mkCase2 dflags scrut bndr alts_ty alts
+ | gopt Opt_CaseFolding dflags
+ , Just (scrut',f) <- caseRules scrut
+ = mkCase3 dflags scrut' bndr alts_ty (map (mapAlt f) alts)
+ | otherwise
+ = mkCase3 dflags scrut bndr alts_ty alts
+ where
+ -- We need to keep the correct association between the scrutinee and its
+ -- binder if the latter isn't dead. Hence we wrap rhs of alternatives with
+ -- "let bndr = ... in":
+ --
+ -- case v + 10 of y =====> case v of y
+ -- 20 -> e1 10 -> let y = 20 in e1
+ -- DEFAULT -> e2 DEFAULT -> let y = v + 10 in e2
+ --
+ -- Other transformations give: =====> case v of y'
+ -- 10 -> let y = 20 in e1
+ -- DEFAULT -> let y = y' + 10 in e2
+ --
+ wrap_rhs l rhs
+ | isDeadBinder bndr = rhs
+ | otherwise = Let (NonRec bndr l) rhs
+
+ mapAlt f alt@(c,bs,e) = case c of
+ DEFAULT -> (c, bs, wrap_rhs scrut e)
+ LitAlt l
+ | isLitValue l -> (LitAlt (mapLitValue f l), bs, wrap_rhs (Lit l) e)
+ _ -> pprPanic "Unexpected alternative (mkCase2)" (ppr alt)
+
+--------------------------------------------------
-- Catch-all
--------------------------------------------------
-mkCase2 _dflags scrut bndr alts_ty alts
+mkCase3 _dflags scrut bndr alts_ty alts
= return (Case scrut bndr alts_ty alts)
{-