diff options
author | nineonine <mail4chemik@gmail.com> | 2021-11-10 00:52:06 -0800 |
---|---|---|
committer | Marge Bot <ben+marge-bot@smart-cactus.org> | 2022-03-20 21:16:06 -0400 |
commit | c842611fc72d987519cd9fab1c351135ae93665e (patch) | |
tree | aa4365b4050a0733887a4d4f0291e6e9d52cc801 /compiler/GHC/Tc | |
parent | d45bb70178e044bc8b6e8215da7bc8ed0c95f2cb (diff) | |
download | haskell-c842611fc72d987519cd9fab1c351135ae93665e.tar.gz |
Revamp derived Eq instance code generation (#17240)
This patch improves code generation for derived Eq instances.
The idea is to use 'dataToTag#' to evaluate both arguments and compare
their constructor tags first.
This allows us to 'short-circuit' when the tags do not match.
Unfortunately, inner evals are still present when we branch
on tags. This is due to the way the 'dataToTag#' primop
evaluates its argument in the code generator. Issue #21207 was
created to explore further optimizations.
Metric Decrease:
LargeRecord
Diffstat (limited to 'compiler/GHC/Tc')
-rw-r--r-- | compiler/GHC/Tc/Deriv/Generate.hs | 141 |
1 files changed, 82 insertions, 59 deletions
diff --git a/compiler/GHC/Tc/Deriv/Generate.hs b/compiler/GHC/Tc/Deriv/Generate.hs index 35375bc5a5..3f8893460b 100644 --- a/compiler/GHC/Tc/Deriv/Generate.hs +++ b/compiler/GHC/Tc/Deriv/Generate.hs @@ -152,6 +152,12 @@ possibly zero of them). Here's an example, with both \tr{N}ullary and data Foo ... = N1 | N2 ... | Nn | O1 a b | O2 Int | O3 Double b b | ... +* We first attempt to compare the constructor tags. If tags don't + match - we immediately bail out. Otherwise, we then generate one + branch per constructor comparing only the fields as we already + know that the tags match. Note that it only makes sense to check + the tag if there is more than one data constructor. + * For the ordinary constructors (if any), we emit clauses to do The Usual Thing, e.g.,: @@ -164,23 +170,29 @@ possibly zero of them). Here's an example, with both \tr{N}ullary and case (a1 `eqFloat#` a2) of r -> r for that particular test. -* For nullary constructors, we emit a - catch-all clause of the form: +* For nullary constructors, we emit a catch-all clause that always + returns True since we already know that the tags match. + +* So, given this data type: + + data T = A | B Int | C Char - (==) a b = case (dataToTag# a) of { a# -> - case (dataToTag# b) of { b# -> - case (a# ==# b#) of { - r -> r }}} + We roughly get: + + (==) a b = + case dataToTag# a /= dataToTag# b of + True -> False + False -> case a of -- Here we already know that tags match + B a1 -> case b of + B b1 -> a1 == b1 -- Only one branch + C a1 -> case b of + C b1 -> a1 == b1 -- Only one branch + _ -> True -- catch-all to match all nullary ctors An older approach preferred regular pattern matches in some cases but with dataToTag# forcing it's argument, and work on improving join points, this seems no longer necessary. -* If there aren't any nullary constructors, we emit a simpler - catch-all: - - (==) a b = False - * For the @(/=)@ method, we normally just use the default method. 
If the type is an enumeration type, we could/may/should? generate special code that calls @dataToTag#@, much like for @(==)@ shown @@ -202,58 +214,68 @@ gen_Eq_binds loc dit@(DerivInstTys{ dit_rep_tc = tycon return (method_binds, emptyBag) where all_cons = getPossibleDataCons tycon tycon_args - (nullary_cons, non_nullary_cons) = partition isNullarySrcDataCon all_cons - - -- For nullary constructors, use the getTag stuff. - (tag_match_cons, pat_match_cons) = (nullary_cons, non_nullary_cons) - no_tag_match_cons = null tag_match_cons - - -- (LHS patterns, result) - fall_through_eqn :: [([LPat (GhcPass 'Parsed)] , LHsExpr GhcPs)] - fall_through_eqn - | no_tag_match_cons -- All constructors have arguments - = case pat_match_cons of - [] -> [] -- No constructors; no fall-though case - [_] -> [] -- One constructor; no fall-though case - _ -> -- Two or more constructors; add fall-through of - -- (==) _ _ = False - [([nlWildPat, nlWildPat], false_Expr)] - - | otherwise -- One or more tag_match cons; add fall-through of - -- extract tags compare for equality, - -- The case `(C1 x) == (C1 y)` can no longer happen - -- at this point as it's matched earlier. - = [([a_Pat, b_Pat], - untag_Expr [(a_RDR,ah_RDR), (b_RDR,bh_RDR)] - (genPrimOpApp (nlHsVar ah_RDR) eqInt_RDR (nlHsVar bh_RDR)))] + non_nullary_cons = filter (not . isNullarySrcDataCon) all_cons + + -- Generate tag check. 
See #17240 + eq_expr_with_tag_check = nlHsCase + (nlHsPar (untag_Expr [(a_RDR,ah_RDR), (b_RDR,bh_RDR)] + (nlHsOpApp (nlHsVar ah_RDR) neInt_RDR (nlHsVar bh_RDR)))) + [ mkHsCaseAlt (nlLitPat (HsIntPrim NoSourceText 1)) false_Expr + , mkHsCaseAlt nlWildPat ( + nlHsCase + (nlHsVar a_RDR) + -- Only one branch to match all nullary constructors + -- as we already know the tags match but do not emit + -- the branch if there are no nullary constructors + (let non_nullary_pats = map pats_etc non_nullary_cons + in if null non_nullary_cons + then non_nullary_pats + else non_nullary_pats ++ [mkHsCaseAlt nlWildPat true_Expr])) + ] method_binds = unitBag eq_bind - eq_bind - = mkFunBindEC 2 loc eq_RDR (const true_Expr) - (map pats_etc pat_match_cons - ++ fall_through_eqn) + eq_bind = mkFunBindEC 2 loc eq_RDR (const true_Expr) binds + where + binds + | null all_cons = [] + -- Tag checking is redundant when there is only one data constructor + | [data_con] <- all_cons + , (as_needed, bs_needed, tys_needed) <- gen_con_fields_and_tys data_con + , data_con_RDR <- getRdrName data_con + , con1_pat <- nlParPat $ nlConVarPat data_con_RDR as_needed + , con2_pat <- nlParPat $ nlConVarPat data_con_RDR bs_needed + , eq_expr <- nested_eq_expr tys_needed as_needed bs_needed + = [([con1_pat, con2_pat], eq_expr)] + -- This is an enum (all constructors are nullary) - just do a simple tag check + | all isNullarySrcDataCon all_cons + = [([a_Pat, b_Pat], untag_Expr [(a_RDR,ah_RDR), (b_RDR,bh_RDR)] + (genPrimOpApp (nlHsVar ah_RDR) eqInt_RDR (nlHsVar bh_RDR)))] + | otherwise + = [([a_Pat, b_Pat], eq_expr_with_tag_check)] ------------------------------------------------------------------ - pats_etc data_con - = let - con1_pat = nlParPat $ nlConVarPat data_con_RDR as_needed - con2_pat = nlParPat $ nlConVarPat data_con_RDR bs_needed - - data_con_RDR = getRdrName data_con - con_arity = length tys_needed - as_needed = take con_arity as_RDRs - bs_needed = take con_arity bs_RDRs - tys_needed = 
derivDataConInstArgTys data_con dit - in - ([con1_pat, con2_pat], nested_eq_expr tys_needed as_needed bs_needed) + nested_eq_expr [] [] [] = true_Expr + nested_eq_expr tys as bs + = foldr1 and_Expr (zipWith3Equal "nested_eq" nested_eq tys as bs) + -- Using 'foldr1' here ensures that the derived code is correctly + -- associated. See #10859. where - nested_eq_expr [] [] [] = true_Expr - nested_eq_expr tys as bs - = foldr1 and_Expr (zipWith3Equal "nested_eq" nested_eq tys as bs) - -- Using 'foldr1' here ensures that the derived code is correctly - -- associated. See #10859. - where - nested_eq ty a b = nlHsPar (eq_Expr ty (nlHsVar a) (nlHsVar b)) + nested_eq ty a b = nlHsPar (eq_Expr ty (nlHsVar a) (nlHsVar b)) + + gen_con_fields_and_tys data_con + | tys_needed <- derivDataConInstArgTys data_con dit + , con_arity <- length tys_needed + , as_needed <- take con_arity as_RDRs + , bs_needed <- take con_arity bs_RDRs + = (as_needed, bs_needed, tys_needed) + + pats_etc data_con + | (as_needed, bs_needed, tys_needed) <- gen_con_fields_and_tys data_con + , data_con_RDR <- getRdrName data_con + , con1_pat <- nlParPat $ nlConVarPat data_con_RDR as_needed + , con2_pat <- nlParPat $ nlConVarPat data_con_RDR bs_needed + , fields_eq_expr <- nested_eq_expr tys_needed as_needed bs_needed + = mkHsCaseAlt con1_pat (nlHsCase (nlHsVar b_RDR) [mkHsCaseAlt con2_pat fields_eq_expr]) {- ************************************************************************ @@ -1473,7 +1495,7 @@ gfoldl_RDR, gunfold_RDR, toConstr_RDR, dataTypeOf_RDR, mkConstrTag_RDR, dataCast1_RDR, dataCast2_RDR, gcast1_RDR, gcast2_RDR, constr_RDR, dataType_RDR, eqChar_RDR , ltChar_RDR , geChar_RDR , gtChar_RDR , leChar_RDR , - eqInt_RDR , ltInt_RDR , geInt_RDR , gtInt_RDR , leInt_RDR , + eqInt_RDR , ltInt_RDR , geInt_RDR , gtInt_RDR , leInt_RDR , neInt_RDR , eqInt8_RDR , ltInt8_RDR , geInt8_RDR , gtInt8_RDR , leInt8_RDR , eqInt16_RDR , ltInt16_RDR , geInt16_RDR , gtInt16_RDR , leInt16_RDR , eqInt32_RDR , ltInt32_RDR , 
geInt32_RDR , gtInt32_RDR , leInt32_RDR , @@ -1513,6 +1535,7 @@ gtChar_RDR = varQual_RDR gHC_PRIM (fsLit "gtChar#") geChar_RDR = varQual_RDR gHC_PRIM (fsLit "geChar#") eqInt_RDR = varQual_RDR gHC_PRIM (fsLit "==#") +neInt_RDR = varQual_RDR gHC_PRIM (fsLit "/=#") ltInt_RDR = varQual_RDR gHC_PRIM (fsLit "<#" ) leInt_RDR = varQual_RDR gHC_PRIM (fsLit "<=#") gtInt_RDR = varQual_RDR gHC_PRIM (fsLit ">#" ) |