summaryrefslogtreecommitdiff
path: root/compiler/GHC/Core
diff options
context:
space:
mode:
authorSylvain Henry <sylvain@haskus.fr>2023-01-27 13:57:11 +0100
committerMarge Bot <ben+marge-bot@smart-cactus.org>2023-04-13 08:50:33 -0400
commit4dd021227559e1bc70cdaed12e45ff5459c33d27 (patch)
tree04497c322c430924c746102f1d679fed3e7396c0 /compiler/GHC/Core
parent593218794199e23cdfc1a94200cbb9f404e28720 (diff)
downloadhaskell-4dd021227559e1bc70cdaed12e45ff5459c33d27.tar.gz
Add quot folding rule (#22152)
(x / l1) / l2 l1 and l2 /= 0 l1*l2 doesn't overflow ==> x / (l1 * l2)
Diffstat (limited to 'compiler/GHC/Core')
-rw-r--r--compiler/GHC/Core/Opt/ConstantFold.hs97
1 files changed, 82 insertions, 15 deletions
diff --git a/compiler/GHC/Core/Opt/ConstantFold.hs b/compiler/GHC/Core/Opt/ConstantFold.hs
index fb863d65cb..42ced5a86a 100644
--- a/compiler/GHC/Core/Opt/ConstantFold.hs
+++ b/compiler/GHC/Core/Opt/ConstantFold.hs
@@ -121,7 +121,9 @@ primOpRules nm = \case
Int8QuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int8Op2 quot)
, leftZero
, rightIdentity oneI8
- , equalArgs $> Lit oneI8 ]
+ , equalArgs $> Lit oneI8
+ , quotFoldingRules int8Ops
+ ]
Int8RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int8Op2 rem)
, leftZero
, oneLit 1 $> Lit zeroI8
@@ -150,7 +152,9 @@ primOpRules nm = \case
, mulFoldingRules Word8MulOp word8Ops
]
Word8QuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word8Op2 quot)
- , rightIdentity oneW8 ]
+ , rightIdentity oneW8
+ , quotFoldingRules word8Ops
+ ]
Word8RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word8Op2 rem)
, leftZero
, oneLit 1 $> Lit zeroW8
@@ -195,7 +199,9 @@ primOpRules nm = \case
Int16QuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int16Op2 quot)
, leftZero
, rightIdentity oneI16
- , equalArgs $> Lit oneI16 ]
+ , equalArgs $> Lit oneI16
+ , quotFoldingRules int16Ops
+ ]
Int16RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int16Op2 rem)
, leftZero
, oneLit 1 $> Lit zeroI16
@@ -224,7 +230,9 @@ primOpRules nm = \case
, mulFoldingRules Word16MulOp word16Ops
]
Word16QuotOp-> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word16Op2 quot)
- , rightIdentity oneW16 ]
+ , rightIdentity oneW16
+ , quotFoldingRules word16Ops
+ ]
Word16RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word16Op2 rem)
, leftZero
, oneLit 1 $> Lit zeroW16
@@ -269,7 +277,9 @@ primOpRules nm = \case
Int32QuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int32Op2 quot)
, leftZero
, rightIdentity oneI32
- , equalArgs $> Lit oneI32 ]
+ , equalArgs $> Lit oneI32
+ , quotFoldingRules int32Ops
+ ]
Int32RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int32Op2 rem)
, leftZero
, oneLit 1 $> Lit zeroI32
@@ -298,7 +308,9 @@ primOpRules nm = \case
, mulFoldingRules Word32MulOp word32Ops
]
Word32QuotOp-> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word32Op2 quot)
- , rightIdentity oneW32 ]
+ , rightIdentity oneW32
+ , quotFoldingRules word32Ops
+ ]
Word32RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word32Op2 rem)
, leftZero
, oneLit 1 $> Lit zeroW32
@@ -342,7 +354,9 @@ primOpRules nm = \case
Int64QuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int64Op2 quot)
, leftZero
, rightIdentity oneI64
- , equalArgs $> Lit oneI64 ]
+ , equalArgs $> Lit oneI64
+ , quotFoldingRules int64Ops
+ ]
Int64RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int64Op2 rem)
, leftZero
, oneLit 1 $> Lit zeroI64
@@ -371,7 +385,9 @@ primOpRules nm = \case
, mulFoldingRules Word64MulOp word64Ops
]
Word64QuotOp-> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word64Op2 quot)
- , rightIdentity oneW64 ]
+ , rightIdentity oneW64
+ , quotFoldingRules word64Ops
+ ]
Word64RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word64Op2 rem)
, leftZero
, oneLit 1 $> Lit zeroW64
@@ -452,7 +468,9 @@ primOpRules nm = \case
IntQuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (intOp2 quot)
, leftZero
, rightIdentityPlatform onei
- , equalArgs >> retLit onei ]
+ , equalArgs >> retLit onei
+ , quotFoldingRules intOps
+ ]
IntRemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (intOp2 rem)
, leftZero
, oneLit 1 >> retLit zeroi
@@ -504,7 +522,9 @@ primOpRules nm = \case
, mulFoldingRules WordMulOp wordOps
]
WordQuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (wordOp2 quot)
- , rightIdentityPlatform onew ]
+ , rightIdentityPlatform onew
+ , quotFoldingRules wordOps
+ ]
WordRemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (wordOp2 rem)
, leftZero
, oneLit 1 >> retLit zerow
@@ -2653,6 +2673,14 @@ orFoldingRules num_ops = do
(orFoldingRules' platform arg1 arg2 num_ops
<|> orFoldingRules' platform arg2 arg1 num_ops)
+quotFoldingRules :: NumOps -> RuleM CoreExpr
+quotFoldingRules num_ops = do
+ env <- getRuleOpts
+ guard (roNumConstantFolding env)
+ [arg1,arg2] <- getArgs
+ platform <- getPlatform
+ liftMaybe (quotFoldingRules' platform arg1 arg2 num_ops)
+
addFoldingRules' :: Platform -> CoreExpr -> CoreExpr -> NumOps -> Maybe CoreExpr
addFoldingRules' platform arg1 arg2 num_ops = case (arg1, arg2) of
@@ -2943,6 +2971,29 @@ orFoldingRules' platform arg1 arg2 num_ops = case (arg1, arg2) of
mkL = Lit . mkNumLiteral platform num_ops
or x y = BinOpApp x (fromJust (numOr num_ops)) y
+quotFoldingRules' :: Platform -> CoreExpr -> CoreExpr -> NumOps -> Maybe CoreExpr
+quotFoldingRules' platform arg1 arg2 num_ops = case (arg1, arg2) of
+
+ -- (x / l1) / l2
+ -- l1 and l2 /= 0
+ -- l1*l2 doesn't overflow
+ -- ==> x / (l1 * l2)
+ (is_div num_ops -> Just (x, L l1), L l2)
+ | l1 /= 0
+ , l2 /= 0
+ -- check that the result of the multiplication is in range
+ , Just l <- mkNumLiteralMaybe platform num_ops (l1 * l2)
+ -> Just (div x (Lit l))
+ -- NB: we could directly return 0 or (-1) in case of overflow,
+ -- but we would need to know
+ -- (1) if we're dealing with a quot or a div operation
+ -- (2) if it's an underflow or an overflow.
+ -- Left as future work for now.
+
+ _ -> Nothing
+ where
+ div x y = BinOpApp x (fromJust (numDiv num_ops)) y
+
is_binop :: PrimOp -> CoreExpr -> Maybe (Arg CoreBndr, Arg CoreBndr)
is_binop op e = case e of
BinOpApp x op' y | op == op' -> Just (x,y)
@@ -2953,12 +3004,13 @@ is_op op e = case e of
App (OpVal op') x | op == op' -> Just x
_ -> Nothing
-is_add, is_sub, is_mul, is_and, is_or :: NumOps -> CoreExpr -> Maybe (Arg CoreBndr, Arg CoreBndr)
+is_add, is_sub, is_mul, is_and, is_or, is_div :: NumOps -> CoreExpr -> Maybe (Arg CoreBndr, Arg CoreBndr)
is_add num_ops e = is_binop (numAdd num_ops) e
is_sub num_ops e = is_binop (numSub num_ops) e
is_mul num_ops e = is_binop (numMul num_ops) e
is_and num_ops e = numAnd num_ops >>= \op -> is_binop op e
is_or num_ops e = numOr num_ops >>= \op -> is_binop op e
+is_div num_ops e = numDiv num_ops >>= \op -> is_binop op e
is_neg :: NumOps -> CoreExpr -> Maybe (Arg CoreBndr)
is_neg num_ops e = numNeg num_ops >>= \op -> is_op op e
@@ -3007,6 +3059,7 @@ data NumOps = NumOps
{ numAdd :: !PrimOp -- ^ Add two numbers
, numSub :: !PrimOp -- ^ Sub two numbers
, numMul :: !PrimOp -- ^ Multiply two numbers
+ , numDiv :: !(Maybe PrimOp) -- ^ Divide two numbers
, numAnd :: !(Maybe PrimOp) -- ^ And two numbers
, numOr :: !(Maybe PrimOp) -- ^ Or two numbers
, numNeg :: !(Maybe PrimOp) -- ^ Negate a number
@@ -3017,15 +3070,20 @@ data NumOps = NumOps
mkNumLiteral :: Platform -> NumOps -> Integer -> Literal
mkNumLiteral platform ops i = mkLitNumberWrap platform (numLitType ops) i
+-- | Create a numeric literal if it is in range
+mkNumLiteralMaybe :: Platform -> NumOps -> Integer -> Maybe Literal
+mkNumLiteralMaybe platform ops i = mkLitNumberMaybe platform (numLitType ops) i
+
int8Ops :: NumOps
int8Ops = NumOps
{ numAdd = Int8AddOp
, numSub = Int8SubOp
, numMul = Int8MulOp
- , numLitType = LitNumInt8
+ , numDiv = Just Int8QuotOp
, numAnd = Nothing
, numOr = Nothing
, numNeg = Just Int8NegOp
+ , numLitType = LitNumInt8
}
word8Ops :: NumOps
@@ -3033,6 +3091,7 @@ word8Ops = NumOps
{ numAdd = Word8AddOp
, numSub = Word8SubOp
, numMul = Word8MulOp
+ , numDiv = Just Word8QuotOp
, numAnd = Just Word8AndOp
, numOr = Just Word8OrOp
, numNeg = Nothing
@@ -3044,10 +3103,11 @@ int16Ops = NumOps
{ numAdd = Int16AddOp
, numSub = Int16SubOp
, numMul = Int16MulOp
- , numLitType = LitNumInt16
+ , numDiv = Just Int16QuotOp
, numAnd = Nothing
, numOr = Nothing
, numNeg = Just Int16NegOp
+ , numLitType = LitNumInt16
}
word16Ops :: NumOps
@@ -3055,6 +3115,7 @@ word16Ops = NumOps
{ numAdd = Word16AddOp
, numSub = Word16SubOp
, numMul = Word16MulOp
+ , numDiv = Just Word16QuotOp
, numAnd = Just Word16AndOp
, numOr = Just Word16OrOp
, numNeg = Nothing
@@ -3066,10 +3127,11 @@ int32Ops = NumOps
{ numAdd = Int32AddOp
, numSub = Int32SubOp
, numMul = Int32MulOp
- , numLitType = LitNumInt32
+ , numDiv = Just Int32QuotOp
, numAnd = Nothing
, numOr = Nothing
, numNeg = Just Int32NegOp
+ , numLitType = LitNumInt32
}
word32Ops :: NumOps
@@ -3077,6 +3139,7 @@ word32Ops = NumOps
{ numAdd = Word32AddOp
, numSub = Word32SubOp
, numMul = Word32MulOp
+ , numDiv = Just Word32QuotOp
, numAnd = Just Word32AndOp
, numOr = Just Word32OrOp
, numNeg = Nothing
@@ -3088,10 +3151,11 @@ int64Ops = NumOps
{ numAdd = Int64AddOp
, numSub = Int64SubOp
, numMul = Int64MulOp
- , numLitType = LitNumInt64
+ , numDiv = Just Int64QuotOp
, numAnd = Nothing
, numOr = Nothing
, numNeg = Just Int64NegOp
+ , numLitType = LitNumInt64
}
word64Ops :: NumOps
@@ -3099,6 +3163,7 @@ word64Ops = NumOps
{ numAdd = Word64AddOp
, numSub = Word64SubOp
, numMul = Word64MulOp
+ , numDiv = Just Word64QuotOp
, numAnd = Just Word64AndOp
, numOr = Just Word64OrOp
, numNeg = Nothing
@@ -3110,6 +3175,7 @@ intOps = NumOps
{ numAdd = IntAddOp
, numSub = IntSubOp
, numMul = IntMulOp
+ , numDiv = Just IntQuotOp
, numAnd = Just IntAndOp
, numOr = Just IntOrOp
, numNeg = Just IntNegOp
@@ -3121,6 +3187,7 @@ wordOps = NumOps
{ numAdd = WordAddOp
, numSub = WordSubOp
, numMul = WordMulOp
+ , numDiv = Just WordQuotOp
, numAnd = Just WordAndOp
, numOr = Just WordOrOp
, numNeg = Nothing