diff options
author | Sylvain Henry <sylvain@haskus.fr> | 2023-01-27 13:57:11 +0100 |
---|---|---|
committer | Marge Bot <ben+marge-bot@smart-cactus.org> | 2023-04-13 08:50:33 -0400 |
commit | 4dd021227559e1bc70cdaed12e45ff5459c33d27 (patch) | |
tree | 04497c322c430924c746102f1d679fed3e7396c0 /compiler/GHC/Core | |
parent | 593218794199e23cdfc1a94200cbb9f404e28720 (diff) | |
download | haskell-4dd021227559e1bc70cdaed12e45ff5459c33d27.tar.gz |
Add quot folding rule (#22152)
(x / l1) / l2
l1 and l2 /= 0
l1*l2 doesn't overflow
==> x / (l1 * l2)
Diffstat (limited to 'compiler/GHC/Core')
-rw-r--r-- | compiler/GHC/Core/Opt/ConstantFold.hs | 97 |
1 files changed, 82 insertions, 15 deletions
diff --git a/compiler/GHC/Core/Opt/ConstantFold.hs b/compiler/GHC/Core/Opt/ConstantFold.hs index fb863d65cb..42ced5a86a 100644 --- a/compiler/GHC/Core/Opt/ConstantFold.hs +++ b/compiler/GHC/Core/Opt/ConstantFold.hs @@ -121,7 +121,9 @@ primOpRules nm = \case Int8QuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int8Op2 quot) , leftZero , rightIdentity oneI8 - , equalArgs $> Lit oneI8 ] + , equalArgs $> Lit oneI8 + , quotFoldingRules int8Ops + ] Int8RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int8Op2 rem) , leftZero , oneLit 1 $> Lit zeroI8 @@ -150,7 +152,9 @@ primOpRules nm = \case , mulFoldingRules Word8MulOp word8Ops ] Word8QuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word8Op2 quot) - , rightIdentity oneW8 ] + , rightIdentity oneW8 + , quotFoldingRules word8Ops + ] Word8RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word8Op2 rem) , leftZero , oneLit 1 $> Lit zeroW8 @@ -195,7 +199,9 @@ primOpRules nm = \case Int16QuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int16Op2 quot) , leftZero , rightIdentity oneI16 - , equalArgs $> Lit oneI16 ] + , equalArgs $> Lit oneI16 + , quotFoldingRules int16Ops + ] Int16RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int16Op2 rem) , leftZero , oneLit 1 $> Lit zeroI16 @@ -224,7 +230,9 @@ primOpRules nm = \case , mulFoldingRules Word16MulOp word16Ops ] Word16QuotOp-> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word16Op2 quot) - , rightIdentity oneW16 ] + , rightIdentity oneW16 + , quotFoldingRules word16Ops + ] Word16RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word16Op2 rem) , leftZero , oneLit 1 $> Lit zeroW16 @@ -269,7 +277,9 @@ primOpRules nm = \case Int32QuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int32Op2 quot) , leftZero , rightIdentity oneI32 - , equalArgs $> Lit oneI32 ] + , equalArgs $> Lit oneI32 + , quotFoldingRules int32Ops + ] Int32RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int32Op2 rem) , leftZero , oneLit 1 $> Lit zeroI32 @@ -298,7 +308,9 @@ primOpRules nm = \case , mulFoldingRules Word32MulOp word32Ops ] Word32QuotOp-> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word32Op2 quot) - , rightIdentity oneW32 ] + , rightIdentity oneW32 + , quotFoldingRules word32Ops + ] Word32RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word32Op2 rem) , leftZero , oneLit 1 $> Lit zeroW32 @@ -342,7 +354,9 @@ primOpRules nm = \case Int64QuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int64Op2 quot) , leftZero , rightIdentity oneI64 - , equalArgs $> Lit oneI64 ] + , equalArgs $> Lit oneI64 + , quotFoldingRules int64Ops + ] Int64RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (int64Op2 rem) , leftZero , oneLit 1 $> Lit zeroI64 @@ -371,7 +385,9 @@ primOpRules nm = \case , mulFoldingRules Word64MulOp word64Ops ] Word64QuotOp-> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word64Op2 quot) - , rightIdentity oneW64 ] + , rightIdentity oneW64 + , quotFoldingRules word64Ops + ] Word64RemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (word64Op2 rem) , leftZero , oneLit 1 $> Lit zeroW64 @@ -452,7 +468,9 @@ primOpRules nm = \case IntQuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (intOp2 quot) , leftZero , rightIdentityPlatform onei - , equalArgs >> retLit onei ] + , equalArgs >> retLit onei + , quotFoldingRules intOps + ] IntRemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (intOp2 rem) , leftZero , oneLit 1 >> retLit zeroi @@ -504,7 +522,9 @@ primOpRules nm = \case , mulFoldingRules WordMulOp wordOps ] WordQuotOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (wordOp2 quot) - , rightIdentityPlatform onew ] + , rightIdentityPlatform onew + , quotFoldingRules wordOps + ] WordRemOp -> mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (wordOp2 rem) , leftZero , oneLit 1 >> retLit zerow @@ -2653,6 +2673,14 @@ orFoldingRules num_ops = do (orFoldingRules' platform arg1 arg2 num_ops <|> orFoldingRules' platform arg2 arg1 num_ops) +quotFoldingRules :: NumOps -> RuleM CoreExpr +quotFoldingRules num_ops = do + env <- getRuleOpts + guard (roNumConstantFolding env) + [arg1,arg2] <- getArgs + platform <- getPlatform + liftMaybe (quotFoldingRules' platform arg1 arg2 num_ops) + addFoldingRules' :: Platform -> CoreExpr -> CoreExpr -> NumOps -> Maybe CoreExpr addFoldingRules' platform arg1 arg2 num_ops = case (arg1, arg2) of @@ -2943,6 +2971,29 @@ orFoldingRules' platform arg1 arg2 num_ops = case (arg1, arg2) of mkL = Lit . mkNumLiteral platform num_ops or x y = BinOpApp x (fromJust (numOr num_ops)) y +quotFoldingRules' :: Platform -> CoreExpr -> CoreExpr -> NumOps -> Maybe CoreExpr +quotFoldingRules' platform arg1 arg2 num_ops = case (arg1, arg2) of + + -- (x / l1) / l2 + -- l1 and l2 /= 0 + -- l1*l2 doesn't overflow + -- ==> x / (l1 * l2) + (is_div num_ops -> Just (x, L l1), L l2) + | l1 /= 0 + , l2 /= 0 + -- check that the result of the multiplication is in range + , Just l <- mkNumLiteralMaybe platform num_ops (l1 * l2) + -> Just (div x (Lit l)) + -- NB: we could directly return 0 or (-1) in case of overflow, + -- but we would need to know + -- (1) if we're dealing with a quot or a div operation + -- (2) if it's an underflow or an overflow. + -- Left as future work for now. + + _ -> Nothing + where + div x y = BinOpApp x (fromJust (numDiv num_ops)) y + is_binop :: PrimOp -> CoreExpr -> Maybe (Arg CoreBndr, Arg CoreBndr) is_binop op e = case e of BinOpApp x op' y | op == op' -> Just (x,y) @@ -2953,12 +3004,13 @@ is_op op e = case e of App (OpVal op') x | op == op' -> Just x _ -> Nothing -is_add, is_sub, is_mul, is_and, is_or :: NumOps -> CoreExpr -> Maybe (Arg CoreBndr, Arg CoreBndr) +is_add, is_sub, is_mul, is_and, is_or, is_div :: NumOps -> CoreExpr -> Maybe (Arg CoreBndr, Arg CoreBndr) is_add num_ops e = is_binop (numAdd num_ops) e is_sub num_ops e = is_binop (numSub num_ops) e is_mul num_ops e = is_binop (numMul num_ops) e is_and num_ops e = numAnd num_ops >>= \op -> is_binop op e is_or num_ops e = numOr num_ops >>= \op -> is_binop op e +is_div num_ops e = numDiv num_ops >>= \op -> is_binop op e is_neg :: NumOps -> CoreExpr -> Maybe (Arg CoreBndr) is_neg num_ops e = numNeg num_ops >>= \op -> is_op op e @@ -3007,6 +3059,7 @@ data NumOps = NumOps { numAdd :: !PrimOp -- ^ Add two numbers , numSub :: !PrimOp -- ^ Sub two numbers , numMul :: !PrimOp -- ^ Multiply two numbers + , numDiv :: !(Maybe PrimOp) -- ^ Divide two numbers , numAnd :: !(Maybe PrimOp) -- ^ And two numbers , numOr :: !(Maybe PrimOp) -- ^ Or two numbers , numNeg :: !(Maybe PrimOp) -- ^ Negate a number @@ -3017,15 +3070,20 @@ data NumOps = NumOps mkNumLiteral :: Platform -> NumOps -> Integer -> Literal mkNumLiteral platform ops i = mkLitNumberWrap platform (numLitType ops) i +-- | Create a numeric literal if it is in range +mkNumLiteralMaybe :: Platform -> NumOps -> Integer -> Maybe Literal +mkNumLiteralMaybe platform ops i = mkLitNumberMaybe platform (numLitType ops) i + int8Ops :: NumOps int8Ops = NumOps { numAdd = Int8AddOp , numSub = Int8SubOp , numMul = Int8MulOp - , numLitType = LitNumInt8 + , numDiv = Just Int8QuotOp , numAnd = Nothing , numOr = Nothing , numNeg = Just Int8NegOp + , numLitType = LitNumInt8 } word8Ops :: NumOps @@ -3033,6 +3091,7 @@ word8Ops = NumOps { numAdd = Word8AddOp , numSub = Word8SubOp , numMul = Word8MulOp + , numDiv = Just Word8QuotOp , numAnd = Just Word8AndOp , numOr = Just Word8OrOp , numNeg = Nothing @@ -3044,10 +3103,11 @@ int16Ops = NumOps { numAdd = Int16AddOp , numSub = Int16SubOp , numMul = Int16MulOp - , numLitType = LitNumInt16 + , numDiv = Just Int16QuotOp , numAnd = Nothing , numOr = Nothing , numNeg = Just Int16NegOp + , numLitType = LitNumInt16 } word16Ops :: NumOps @@ -3055,6 +3115,7 @@ word16Ops = NumOps { numAdd = Word16AddOp , numSub = Word16SubOp , numMul = Word16MulOp + , numDiv = Just Word16QuotOp , numAnd = Just Word16AndOp , numOr = Just Word16OrOp , numNeg = Nothing @@ -3066,10 +3127,11 @@ int32Ops = NumOps { numAdd = Int32AddOp , numSub = Int32SubOp , numMul = Int32MulOp - , numLitType = LitNumInt32 + , numDiv = Just Int32QuotOp , numAnd = Nothing , numOr = Nothing , numNeg = Just Int32NegOp + , numLitType = LitNumInt32 } word32Ops :: NumOps @@ -3077,6 +3139,7 @@ word32Ops = NumOps { numAdd = Word32AddOp , numSub = Word32SubOp , numMul = Word32MulOp + , numDiv = Just Word32QuotOp , numAnd = Just Word32AndOp , numOr = Just Word32OrOp , numNeg = Nothing @@ -3088,10 +3151,11 @@ int64Ops = NumOps { numAdd = Int64AddOp , numSub = Int64SubOp , numMul = Int64MulOp - , numLitType = LitNumInt64 + , numDiv = Just Int64QuotOp , numAnd = Nothing , numOr = Nothing , numNeg = Just Int64NegOp + , numLitType = LitNumInt64 } word64Ops :: NumOps @@ -3099,6 +3163,7 @@ word64Ops = NumOps { numAdd = Word64AddOp , numSub = Word64SubOp , numMul = Word64MulOp + , numDiv = Just Word64QuotOp , numAnd = Just Word64AndOp , numOr = Just Word64OrOp , numNeg = Nothing @@ -3110,6 +3175,7 @@ intOps = NumOps { numAdd = IntAddOp , numSub = IntSubOp , numMul = IntMulOp + , numDiv = Just IntQuotOp , numAnd = Just IntAndOp , numOr = Just IntOrOp , numNeg = Just IntNegOp @@ -3121,6 +3187,7 @@ wordOps = NumOps { numAdd = WordAddOp , numSub = WordSubOp , numMul = WordMulOp + , numDiv = Just WordQuotOp , numAnd = Just WordAndOp , numOr = Just WordOrOp , numNeg = Nothing |