summaryrefslogtreecommitdiff
path: root/compiler/GHC/Tc
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/GHC/Tc')
-rw-r--r--compiler/GHC/Tc/Deriv/Generate.hs141
1 files changed, 82 insertions, 59 deletions
diff --git a/compiler/GHC/Tc/Deriv/Generate.hs b/compiler/GHC/Tc/Deriv/Generate.hs
index 35375bc5a5..3f8893460b 100644
--- a/compiler/GHC/Tc/Deriv/Generate.hs
+++ b/compiler/GHC/Tc/Deriv/Generate.hs
@@ -152,6 +152,12 @@ possibly zero of them). Here's an example, with both \tr{N}ullary and
data Foo ... = N1 | N2 ... | Nn | O1 a b | O2 Int | O3 Double b b | ...
+* We first attempt to compare the constructor tags. If tags don't
+ match - we immediately bail out. Otherwise, we then generate one
+ branch per constructor comparing only the fields as we already
+ know that the tags match. Note that it only makes sense to check
+ the tag if there is more than one data constructor.
+
* For the ordinary constructors (if any), we emit clauses to do The
Usual Thing, e.g.,:
@@ -164,23 +170,29 @@ possibly zero of them). Here's an example, with both \tr{N}ullary and
case (a1 `eqFloat#` a2) of r -> r
for that particular test.
-* For nullary constructors, we emit a
- catch-all clause of the form:
+* For nullary constructors, we emit a catch-all clause that always
+ returns True since we already know that the tags match.
+
+* So, given this data type:
+
+ data T = A | B Int | C Char
- (==) a b = case (dataToTag# a) of { a# ->
- case (dataToTag# b) of { b# ->
- case (a# ==# b#) of {
- r -> r }}}
+ We roughly get:
+
+ (==) a b =
+ case dataToTag# a /= dataToTag# b of
+ True -> False
+ False -> case a of -- Here we already know that tags match
+ B a1 -> case b of
+ B b1 -> a1 == b1 -- Only one branch
+ C a1 -> case b of
+ C b1 -> a1 == b1 -- Only one branch
+ _ -> True -- catch-all to match all nullary ctors
An older approach preferred regular pattern matches in some cases
but with dataToTag# forcing it's argument, and work on improving
join points, this seems no longer necessary.
-* If there aren't any nullary constructors, we emit a simpler
- catch-all:
-
- (==) a b = False
-
* For the @(/=)@ method, we normally just use the default method.
If the type is an enumeration type, we could/may/should? generate
special code that calls @dataToTag#@, much like for @(==)@ shown
@@ -202,58 +214,68 @@ gen_Eq_binds loc dit@(DerivInstTys{ dit_rep_tc = tycon
return (method_binds, emptyBag)
where
all_cons = getPossibleDataCons tycon tycon_args
- (nullary_cons, non_nullary_cons) = partition isNullarySrcDataCon all_cons
-
- -- For nullary constructors, use the getTag stuff.
- (tag_match_cons, pat_match_cons) = (nullary_cons, non_nullary_cons)
- no_tag_match_cons = null tag_match_cons
-
- -- (LHS patterns, result)
- fall_through_eqn :: [([LPat (GhcPass 'Parsed)] , LHsExpr GhcPs)]
- fall_through_eqn
- | no_tag_match_cons -- All constructors have arguments
- = case pat_match_cons of
- [] -> [] -- No constructors; no fall-though case
- [_] -> [] -- One constructor; no fall-though case
- _ -> -- Two or more constructors; add fall-through of
- -- (==) _ _ = False
- [([nlWildPat, nlWildPat], false_Expr)]
-
- | otherwise -- One or more tag_match cons; add fall-through of
- -- extract tags compare for equality,
- -- The case `(C1 x) == (C1 y)` can no longer happen
- -- at this point as it's matched earlier.
- = [([a_Pat, b_Pat],
- untag_Expr [(a_RDR,ah_RDR), (b_RDR,bh_RDR)]
- (genPrimOpApp (nlHsVar ah_RDR) eqInt_RDR (nlHsVar bh_RDR)))]
+ non_nullary_cons = filter (not . isNullarySrcDataCon) all_cons
+
+ -- Generate tag check. See #17240
+ eq_expr_with_tag_check = nlHsCase
+ (nlHsPar (untag_Expr [(a_RDR,ah_RDR), (b_RDR,bh_RDR)]
+ (nlHsOpApp (nlHsVar ah_RDR) neInt_RDR (nlHsVar bh_RDR))))
+ [ mkHsCaseAlt (nlLitPat (HsIntPrim NoSourceText 1)) false_Expr
+ , mkHsCaseAlt nlWildPat (
+ nlHsCase
+ (nlHsVar a_RDR)
+ -- Only one branch to match all nullary constructors
+ -- as we already know the tags match but do not emit
+ -- the branch if there are no nullary constructors
+ (let non_nullary_pats = map pats_etc non_nullary_cons
+ in if null non_nullary_cons
+ then non_nullary_pats
+ else non_nullary_pats ++ [mkHsCaseAlt nlWildPat true_Expr]))
+ ]
method_binds = unitBag eq_bind
- eq_bind
- = mkFunBindEC 2 loc eq_RDR (const true_Expr)
- (map pats_etc pat_match_cons
- ++ fall_through_eqn)
+ eq_bind = mkFunBindEC 2 loc eq_RDR (const true_Expr) binds
+ where
+ binds
+ | null all_cons = []
+ -- Tag checking is redundant when there is only one data constructor
+ | [data_con] <- all_cons
+ , (as_needed, bs_needed, tys_needed) <- gen_con_fields_and_tys data_con
+ , data_con_RDR <- getRdrName data_con
+ , con1_pat <- nlParPat $ nlConVarPat data_con_RDR as_needed
+ , con2_pat <- nlParPat $ nlConVarPat data_con_RDR bs_needed
+ , eq_expr <- nested_eq_expr tys_needed as_needed bs_needed
+ = [([con1_pat, con2_pat], eq_expr)]
+ -- This is an enum (all constructors are nullary) - just do a simple tag check
+ | all isNullarySrcDataCon all_cons
+ = [([a_Pat, b_Pat], untag_Expr [(a_RDR,ah_RDR), (b_RDR,bh_RDR)]
+ (genPrimOpApp (nlHsVar ah_RDR) eqInt_RDR (nlHsVar bh_RDR)))]
+ | otherwise
+ = [([a_Pat, b_Pat], eq_expr_with_tag_check)]
------------------------------------------------------------------
- pats_etc data_con
- = let
- con1_pat = nlParPat $ nlConVarPat data_con_RDR as_needed
- con2_pat = nlParPat $ nlConVarPat data_con_RDR bs_needed
-
- data_con_RDR = getRdrName data_con
- con_arity = length tys_needed
- as_needed = take con_arity as_RDRs
- bs_needed = take con_arity bs_RDRs
- tys_needed = derivDataConInstArgTys data_con dit
- in
- ([con1_pat, con2_pat], nested_eq_expr tys_needed as_needed bs_needed)
+ nested_eq_expr [] [] [] = true_Expr
+ nested_eq_expr tys as bs
+ = foldr1 and_Expr (zipWith3Equal "nested_eq" nested_eq tys as bs)
+ -- Using 'foldr1' here ensures that the derived code is correctly
+ -- associated. See #10859.
where
- nested_eq_expr [] [] [] = true_Expr
- nested_eq_expr tys as bs
- = foldr1 and_Expr (zipWith3Equal "nested_eq" nested_eq tys as bs)
- -- Using 'foldr1' here ensures that the derived code is correctly
- -- associated. See #10859.
- where
- nested_eq ty a b = nlHsPar (eq_Expr ty (nlHsVar a) (nlHsVar b))
+ nested_eq ty a b = nlHsPar (eq_Expr ty (nlHsVar a) (nlHsVar b))
+
+ gen_con_fields_and_tys data_con
+ | tys_needed <- derivDataConInstArgTys data_con dit
+ , con_arity <- length tys_needed
+ , as_needed <- take con_arity as_RDRs
+ , bs_needed <- take con_arity bs_RDRs
+ = (as_needed, bs_needed, tys_needed)
+
+ pats_etc data_con
+ | (as_needed, bs_needed, tys_needed) <- gen_con_fields_and_tys data_con
+ , data_con_RDR <- getRdrName data_con
+ , con1_pat <- nlParPat $ nlConVarPat data_con_RDR as_needed
+ , con2_pat <- nlParPat $ nlConVarPat data_con_RDR bs_needed
+ , fields_eq_expr <- nested_eq_expr tys_needed as_needed bs_needed
+ = mkHsCaseAlt con1_pat (nlHsCase (nlHsVar b_RDR) [mkHsCaseAlt con2_pat fields_eq_expr])
{-
************************************************************************
@@ -1473,7 +1495,7 @@ gfoldl_RDR, gunfold_RDR, toConstr_RDR, dataTypeOf_RDR, mkConstrTag_RDR,
dataCast1_RDR, dataCast2_RDR, gcast1_RDR, gcast2_RDR,
constr_RDR, dataType_RDR,
eqChar_RDR , ltChar_RDR , geChar_RDR , gtChar_RDR , leChar_RDR ,
- eqInt_RDR , ltInt_RDR , geInt_RDR , gtInt_RDR , leInt_RDR ,
+ eqInt_RDR , ltInt_RDR , geInt_RDR , gtInt_RDR , leInt_RDR , neInt_RDR ,
eqInt8_RDR , ltInt8_RDR , geInt8_RDR , gtInt8_RDR , leInt8_RDR ,
eqInt16_RDR , ltInt16_RDR , geInt16_RDR , gtInt16_RDR , leInt16_RDR ,
eqInt32_RDR , ltInt32_RDR , geInt32_RDR , gtInt32_RDR , leInt32_RDR ,
@@ -1513,6 +1535,7 @@ gtChar_RDR = varQual_RDR gHC_PRIM (fsLit "gtChar#")
geChar_RDR = varQual_RDR gHC_PRIM (fsLit "geChar#")
eqInt_RDR = varQual_RDR gHC_PRIM (fsLit "==#")
+neInt_RDR = varQual_RDR gHC_PRIM (fsLit "/=#")
ltInt_RDR = varQual_RDR gHC_PRIM (fsLit "<#" )
leInt_RDR = varQual_RDR gHC_PRIM (fsLit "<=#")
gtInt_RDR = varQual_RDR gHC_PRIM (fsLit ">#" )