diff options
author | da-woods <dw-git@d-woods.co.uk> | 2022-09-24 11:24:38 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-09-24 11:24:38 +0100 |
commit | ab1053b2b1171664038488cb6721b9e407fe5679 (patch) | |
tree | 806497f6fb61eefee24fa74c40f320b299541dab | |
parent | 230d5083704d8d7fe32f1998dcc375b18752b8f8 (diff) | |
download | cython-ab1053b2b1171664038488cb6721b9e407fe5679.tar.gz |
Fix arguments like `init=False` being ignored in dataclasses (#4958)
Fixes some of https://github.com/cython/cython/issues/4956
-rw-r--r-- | Cython/Compiler/Dataclass.py | 16 | ||||
-rw-r--r-- | Tools/make_dataclass_tests.py | 14 | ||||
-rw-r--r-- | tests/run/test_dataclasses.pyx | 103 |
3 files changed, 97 insertions, 36 deletions
diff --git a/Cython/Compiler/Dataclass.py b/Cython/Compiler/Dataclass.py index 88e147c58..609520004 100644 --- a/Cython/Compiler/Dataclass.py +++ b/Cython/Compiler/Dataclass.py @@ -269,7 +269,7 @@ def handle_cclass_dataclass(node, dataclass_args, analyse_decs_transform): if not isinstance(v, ExprNodes.BoolNode): error(node.pos, "Arguments passed to cython.dataclasses.dataclass must be True or False") - kwargs[k] = v + kwargs[k] = v.value # remove everything that does not belong into _DataclassParams() kw_only = kwargs.pop("kw_only") @@ -329,12 +329,6 @@ def handle_cclass_dataclass(node, dataclass_args, analyse_decs_transform): def generate_init_code(code, init, node, fields, kw_only): """ - All of these "generate_*_code" functions return a tuple of: - - code string - - placeholder dict (often empty) - - stat list (often empty) - which can then be combined later and processed once. - Notes on CPython generated "__init__": * Implemented in `_init_fn`. * The use of the `dataclasses._HAS_DEFAULT_FACTORY` sentinel value as @@ -346,6 +340,11 @@ def generate_init_code(code, init, node, fields, kw_only): * seen_default and the associated error message are copied directly from Python * Call to user-defined __post_init__ function (if it exists) is copied from CPython. + + Cython behaviour deviates a little here (to be decided if this is right...) + Because the class variable from the assignment does not exist Cython fields will + return None (or whatever their type default is) if not initialized while Python + dataclasses will fall back to looking up the class variable. """ if not init or node.scope.lookup_here("__init__"): return @@ -456,9 +455,6 @@ def generate_cmp_code(code, op, funcname, node, fields): names = [name for name, field in fields.items() if (field.compare.value and not field.is_initvar)] - if not names: - return # no comparable types - code.add_code_lines([ "def %s(self, other):" % funcname, " if not isinstance(other, %s):" % node.class_name, diff --git a/Tools/make_dataclass_tests.py b/Tools/make_dataclass_tests.py index 0c02f3d14..6a3cee7ac 100644 --- a/Tools/make_dataclass_tests.py +++ b/Tools/make_dataclass_tests.py @@ -120,6 +120,12 @@ skip_tests = frozenset( # These tests are probably fine, but the string substitution in this file doesn't get it right ("TestRepr", "test_repr"), ("TestCase", "test_not_in_repr"), + ('TestRepr', 'test_no_repr'), + # class variable doesn't exist in Cython so uninitialized variable appears differently - for now this is deliberate + ('TestInit', 'test_no_init'), + # I believe the test works but the ordering functions do appear in the class dict (and default slot wrappers which + # just raise NotImplementedError + ('TestOrdering', 'test_no_order'), # not possible to add attributes on extension types ("TestCase", "test_post_init_classmethod"), # Bugs @@ -139,18 +145,14 @@ skip_tests = frozenset( ("TestReplace", "test_recursive_repr_misc_attrs"), # recursion error ("TestReplace", "test_recursive_repr_indirection"), # recursion error ("TestReplace", "test_recursive_repr_indirection_two"), # recursion error - ("TestCase", "test_0_field_compare"), # should return False - ("TestCase", "test_1_field_compare"), # order=False is apparently ignored - ("TestOrdering", "test_no_order"), # probably order=False being ignored - ("TestRepr", "test_no_repr"), # turning off repr doesn't work ( "TestCase", "test_intermediate_non_dataclass", ), # issue with propagating through intermediate class - ("TestCase", "test_post_init"), # init=False being ignored ( "TestFrozen", ), # raises AttributeError, not FrozenInstanceError (may be hard to fix) + ('TestCase', 'test_post_init'), # Works except for AttributeError instead of FrozenInstanceError ("TestReplace", "test_frozen"), # AttributeError not FrozenInstanceError ( "TestCase", @@ -158,7 +160,6 @@ skip_tests = frozenset( ), # doesn't define __setattr__ and just relies on Cython to enforce readonly properties ("TestCase", "test_compare_subclasses"), # wrong comparison ("TestCase", "test_simple_compare"), # wrong comparison - ("TestEq", "test_no_eq"), # wrong comparison (probably eq=False being ignored) ( "TestCase", "test_field_named_self", @@ -210,7 +211,6 @@ version_specific_skips = { ), # needs language support for | operator on types } - class DataclassInDecorators(ast.NodeVisitor): found = False diff --git a/tests/run/test_dataclasses.pyx b/tests/run/test_dataclasses.pyx index d9e8fd2b0..8321b9de0 100644 --- a/tests/run/test_dataclasses.pyx +++ b/tests/run/test_dataclasses.pyx @@ -36,6 +36,36 @@ class C_TestCase_test_field_named_object_frozen: @dataclass @cclass +class C0_TestCase_test_0_field_compare: + pass + +@dataclass(order=False) +@cclass +class C1_TestCase_test_0_field_compare: + pass + +@dataclass(order=True) +@cclass +class C_TestCase_test_0_field_compare: + pass + +@dataclass +@cclass +class C0_TestCase_test_1_field_compare: + x: int + +@dataclass(order=False) +@cclass +class C1_TestCase_test_1_field_compare: + x: int + +@dataclass(order=True) +@cclass +class C_TestCase_test_1_field_compare: + x: int + +@dataclass +@cclass class C_TestCase_test_not_in_compare: x: int = 0 y: int = field(compare=False, default=4) @@ -344,19 +374,6 @@ class R_TestCase_test_dataclasses_pickleable: x: int y: List[int] = field(default_factory=list) -@dataclass(init=False) -@cclass -class C_TestInit_test_no_init: - i: int = 0 - -@dataclass(init=False) -@cclass -class C_TestInit_test_no_init_: - i: int = 2 - - def __init__(self): - self.i = 3 - @dataclass @cclass class C_TestInit_test_overwriting_init: @@ -405,6 +422,19 @@ class C_TestRepr_test_overwriting_repr__: def __repr__(self): return 'x' +@dataclass(eq=False) +@cclass +class C_TestEq_test_no_eq: + x: int + +@dataclass(eq=False) +@cclass +class C_TestEq_test_no_eq_: + x: int + + def __eq__(self, other): + return other == 10 + @dataclass @cclass class C_TestEq_test_overwriting_eq: @@ -517,6 +547,39 @@ class TestCase(unittest.TestCase): c = C('foo') self.assertEqual(c.object, 'foo') + def test_0_field_compare(self): + C0 = C0_TestCase_test_0_field_compare + C1 = C1_TestCase_test_0_field_compare + for cls in [C0, C1]: + with self.subTest(cls=cls): + self.assertEqual(cls(), cls()) + for (idx, fn) in enumerate([lambda a, b: a < b, lambda a, b: a <= b, lambda a, b: a > b, lambda a, b: a >= b]): + with self.subTest(idx=idx): + with self.assertRaises(TypeError): + fn(cls(), cls()) + C = C_TestCase_test_0_field_compare + self.assertLessEqual(C(), C()) + self.assertGreaterEqual(C(), C()) + + def test_1_field_compare(self): + C0 = C0_TestCase_test_1_field_compare + C1 = C1_TestCase_test_1_field_compare + for cls in [C0, C1]: + with self.subTest(cls=cls): + self.assertEqual(cls(1), cls(1)) + self.assertNotEqual(cls(0), cls(1)) + for (idx, fn) in enumerate([lambda a, b: a < b, lambda a, b: a <= b, lambda a, b: a > b, lambda a, b: a >= b]): + with self.subTest(idx=idx): + with self.assertRaises(TypeError): + fn(cls(0), cls(0)) + C = C_TestCase_test_1_field_compare + self.assertLess(C(0), C(1)) + self.assertLessEqual(C(0), C(1)) + self.assertLessEqual(C(1), C(1)) + self.assertGreater(C(1), C(0)) + self.assertGreaterEqual(C(1), C(0)) + self.assertGreaterEqual(C(1), C(1)) + def test_not_in_compare(self): C = C_TestCase_test_not_in_compare self.assertEqual(C(), C(0, 20)) @@ -840,12 +903,6 @@ class TestFieldNoAnnotation(unittest.TestCase): class TestInit(unittest.TestCase): - def test_no_init(self): - C = C_TestInit_test_no_init - self.assertEqual(C().i, 0) - C = C_TestInit_test_no_init_ - self.assertEqual(C().i, 3) - def test_overwriting_init(self): C = C_TestInit_test_overwriting_init self.assertEqual(C(3).x, 6) @@ -866,6 +923,14 @@ class TestRepr(unittest.TestCase): class TestEq(unittest.TestCase): + def test_no_eq(self): + C = C_TestEq_test_no_eq + self.assertNotEqual(C(0), C(0)) + c = C(3) + self.assertEqual(c, c) + C = C_TestEq_test_no_eq_ + self.assertEqual(C(3), 10) + def test_overwriting_eq(self): C = C_TestEq_test_overwriting_eq self.assertEqual(C(1), 3) |