summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorda-woods <dw-git@d-woods.co.uk>2022-09-24 11:24:38 +0100
committerGitHub <noreply@github.com>2022-09-24 11:24:38 +0100
commitab1053b2b1171664038488cb6721b9e407fe5679 (patch)
tree806497f6fb61eefee24fa74c40f320b299541dab
parent230d5083704d8d7fe32f1998dcc375b18752b8f8 (diff)
downloadcython-ab1053b2b1171664038488cb6721b9e407fe5679.tar.gz
Fix arguments like `init=False` being ignored in dataclasses (#4958)
Fixes some of https://github.com/cython/cython/issues/4956
-rw-r--r--Cython/Compiler/Dataclass.py16
-rw-r--r--Tools/make_dataclass_tests.py14
-rw-r--r--tests/run/test_dataclasses.pyx103
3 files changed, 97 insertions, 36 deletions
diff --git a/Cython/Compiler/Dataclass.py b/Cython/Compiler/Dataclass.py
index 88e147c58..609520004 100644
--- a/Cython/Compiler/Dataclass.py
+++ b/Cython/Compiler/Dataclass.py
@@ -269,7 +269,7 @@ def handle_cclass_dataclass(node, dataclass_args, analyse_decs_transform):
if not isinstance(v, ExprNodes.BoolNode):
error(node.pos,
"Arguments passed to cython.dataclasses.dataclass must be True or False")
- kwargs[k] = v
+ kwargs[k] = v.value
# remove everything that does not belong into _DataclassParams()
kw_only = kwargs.pop("kw_only")
@@ -329,12 +329,6 @@ def handle_cclass_dataclass(node, dataclass_args, analyse_decs_transform):
def generate_init_code(code, init, node, fields, kw_only):
"""
- All of these "generate_*_code" functions return a tuple of:
- - code string
- - placeholder dict (often empty)
- - stat list (often empty)
- which can then be combined later and processed once.
-
Notes on CPython generated "__init__":
* Implemented in `_init_fn`.
* The use of the `dataclasses._HAS_DEFAULT_FACTORY` sentinel value as
@@ -346,6 +340,11 @@ def generate_init_code(code, init, node, fields, kw_only):
* seen_default and the associated error message are copied directly from Python
* Call to user-defined __post_init__ function (if it exists) is copied from
CPython.
+
+ Cython behaviour deviates a little here (to be decided if this is right...)
+ Because the class variable from the assignment does not exist Cython fields will
+ return None (or whatever their type default is) if not initialized while Python
+ dataclasses will fall back to looking up the class variable.
"""
if not init or node.scope.lookup_here("__init__"):
return
@@ -456,9 +455,6 @@ def generate_cmp_code(code, op, funcname, node, fields):
names = [name for name, field in fields.items() if (field.compare.value and not field.is_initvar)]
- if not names:
- return # no comparable types
-
code.add_code_lines([
"def %s(self, other):" % funcname,
" if not isinstance(other, %s):" % node.class_name,
diff --git a/Tools/make_dataclass_tests.py b/Tools/make_dataclass_tests.py
index 0c02f3d14..6a3cee7ac 100644
--- a/Tools/make_dataclass_tests.py
+++ b/Tools/make_dataclass_tests.py
@@ -120,6 +120,12 @@ skip_tests = frozenset(
# These tests are probably fine, but the string substitution in this file doesn't get it right
("TestRepr", "test_repr"),
("TestCase", "test_not_in_repr"),
+ ('TestRepr', 'test_no_repr'),
+ # class variable doesn't exist in Cython so uninitialized variable appears differently - for now this is deliberate
+ ('TestInit', 'test_no_init'),
+ # I believe the test works but the ordering functions do appear in the class dict (and default slot wrappers which
+ # just raise NotImplementedError
+ ('TestOrdering', 'test_no_order'),
# not possible to add attributes on extension types
("TestCase", "test_post_init_classmethod"),
# Bugs
@@ -139,18 +145,14 @@ skip_tests = frozenset(
("TestReplace", "test_recursive_repr_misc_attrs"), # recursion error
("TestReplace", "test_recursive_repr_indirection"), # recursion error
("TestReplace", "test_recursive_repr_indirection_two"), # recursion error
- ("TestCase", "test_0_field_compare"), # should return False
- ("TestCase", "test_1_field_compare"), # order=False is apparently ignored
- ("TestOrdering", "test_no_order"), # probably order=False being ignored
- ("TestRepr", "test_no_repr"), # turning off repr doesn't work
(
"TestCase",
"test_intermediate_non_dataclass",
), # issue with propagating through intermediate class
- ("TestCase", "test_post_init"), # init=False being ignored
(
"TestFrozen",
), # raises AttributeError, not FrozenInstanceError (may be hard to fix)
+ ('TestCase', 'test_post_init'), # Works except for AttributeError instead of FrozenInstanceError
("TestReplace", "test_frozen"), # AttributeError not FrozenInstanceError
(
"TestCase",
@@ -158,7 +160,6 @@ skip_tests = frozenset(
), # doesn't define __setattr__ and just relies on Cython to enforce readonly properties
("TestCase", "test_compare_subclasses"), # wrong comparison
("TestCase", "test_simple_compare"), # wrong comparison
- ("TestEq", "test_no_eq"), # wrong comparison (probably eq=False being ignored)
(
"TestCase",
"test_field_named_self",
@@ -210,7 +211,6 @@ version_specific_skips = {
), # needs language support for | operator on types
}
-
class DataclassInDecorators(ast.NodeVisitor):
found = False
diff --git a/tests/run/test_dataclasses.pyx b/tests/run/test_dataclasses.pyx
index d9e8fd2b0..8321b9de0 100644
--- a/tests/run/test_dataclasses.pyx
+++ b/tests/run/test_dataclasses.pyx
@@ -36,6 +36,36 @@ class C_TestCase_test_field_named_object_frozen:
@dataclass
@cclass
+class C0_TestCase_test_0_field_compare:
+ pass
+
+@dataclass(order=False)
+@cclass
+class C1_TestCase_test_0_field_compare:
+ pass
+
+@dataclass(order=True)
+@cclass
+class C_TestCase_test_0_field_compare:
+ pass
+
+@dataclass
+@cclass
+class C0_TestCase_test_1_field_compare:
+ x: int
+
+@dataclass(order=False)
+@cclass
+class C1_TestCase_test_1_field_compare:
+ x: int
+
+@dataclass(order=True)
+@cclass
+class C_TestCase_test_1_field_compare:
+ x: int
+
+@dataclass
+@cclass
class C_TestCase_test_not_in_compare:
x: int = 0
y: int = field(compare=False, default=4)
@@ -344,19 +374,6 @@ class R_TestCase_test_dataclasses_pickleable:
x: int
y: List[int] = field(default_factory=list)
-@dataclass(init=False)
-@cclass
-class C_TestInit_test_no_init:
- i: int = 0
-
-@dataclass(init=False)
-@cclass
-class C_TestInit_test_no_init_:
- i: int = 2
-
- def __init__(self):
- self.i = 3
-
@dataclass
@cclass
class C_TestInit_test_overwriting_init:
@@ -405,6 +422,19 @@ class C_TestRepr_test_overwriting_repr__:
def __repr__(self):
return 'x'
+@dataclass(eq=False)
+@cclass
+class C_TestEq_test_no_eq:
+ x: int
+
+@dataclass(eq=False)
+@cclass
+class C_TestEq_test_no_eq_:
+ x: int
+
+ def __eq__(self, other):
+ return other == 10
+
@dataclass
@cclass
class C_TestEq_test_overwriting_eq:
@@ -517,6 +547,39 @@ class TestCase(unittest.TestCase):
c = C('foo')
self.assertEqual(c.object, 'foo')
+ def test_0_field_compare(self):
+ C0 = C0_TestCase_test_0_field_compare
+ C1 = C1_TestCase_test_0_field_compare
+ for cls in [C0, C1]:
+ with self.subTest(cls=cls):
+ self.assertEqual(cls(), cls())
+ for (idx, fn) in enumerate([lambda a, b: a < b, lambda a, b: a <= b, lambda a, b: a > b, lambda a, b: a >= b]):
+ with self.subTest(idx=idx):
+ with self.assertRaises(TypeError):
+ fn(cls(), cls())
+ C = C_TestCase_test_0_field_compare
+ self.assertLessEqual(C(), C())
+ self.assertGreaterEqual(C(), C())
+
+ def test_1_field_compare(self):
+ C0 = C0_TestCase_test_1_field_compare
+ C1 = C1_TestCase_test_1_field_compare
+ for cls in [C0, C1]:
+ with self.subTest(cls=cls):
+ self.assertEqual(cls(1), cls(1))
+ self.assertNotEqual(cls(0), cls(1))
+ for (idx, fn) in enumerate([lambda a, b: a < b, lambda a, b: a <= b, lambda a, b: a > b, lambda a, b: a >= b]):
+ with self.subTest(idx=idx):
+ with self.assertRaises(TypeError):
+ fn(cls(0), cls(0))
+ C = C_TestCase_test_1_field_compare
+ self.assertLess(C(0), C(1))
+ self.assertLessEqual(C(0), C(1))
+ self.assertLessEqual(C(1), C(1))
+ self.assertGreater(C(1), C(0))
+ self.assertGreaterEqual(C(1), C(0))
+ self.assertGreaterEqual(C(1), C(1))
+
def test_not_in_compare(self):
C = C_TestCase_test_not_in_compare
self.assertEqual(C(), C(0, 20))
@@ -840,12 +903,6 @@ class TestFieldNoAnnotation(unittest.TestCase):
class TestInit(unittest.TestCase):
- def test_no_init(self):
- C = C_TestInit_test_no_init
- self.assertEqual(C().i, 0)
- C = C_TestInit_test_no_init_
- self.assertEqual(C().i, 3)
-
def test_overwriting_init(self):
C = C_TestInit_test_overwriting_init
self.assertEqual(C(3).x, 6)
@@ -866,6 +923,14 @@ class TestRepr(unittest.TestCase):
class TestEq(unittest.TestCase):
+ def test_no_eq(self):
+ C = C_TestEq_test_no_eq
+ self.assertNotEqual(C(0), C(0))
+ c = C(3)
+ self.assertEqual(c, c)
+ C = C_TestEq_test_no_eq_
+ self.assertEqual(C(3), 10)
+
def test_overwriting_eq(self):
C = C_TestEq_test_overwriting_eq
self.assertEqual(C(1), 3)