summaryrefslogtreecommitdiff
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/dataclasses.py6
-rw-r--r--Lib/test/test_dataclasses.py26
2 files changed, 30 insertions, 2 deletions
diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py
index 3f85d859b1..b327462080 100644
--- a/Lib/dataclasses.py
+++ b/Lib/dataclasses.py
@@ -808,8 +808,10 @@ def _get_field(cls, a_name, a_type, default_kw_only):
raise TypeError(f'field {f.name} is a ClassVar but specifies '
'kw_only')
- # For real fields, disallow mutable defaults for known types.
- if f._field_type is _FIELD and isinstance(f.default, (list, dict, set)):
+ # For real fields, disallow mutable defaults. Use unhashable as a proxy
+ # indicator for mutability. Read the __hash__ attribute from the class,
+ # not the instance.
+ if f._field_type is _FIELD and f.default.__class__.__hash__ is None:
raise ValueError(f'mutable default {type(f.default)} for field '
f'{f.name} is not allowed: use default_factory')
diff --git a/Lib/test/test_dataclasses.py b/Lib/test/test_dataclasses.py
index ef5009ab11..69e7685083 100644
--- a/Lib/test/test_dataclasses.py
+++ b/Lib/test/test_dataclasses.py
@@ -501,6 +501,32 @@ class TestCase(unittest.TestCase):
self.assertNotEqual(C(3), C(4, 10))
self.assertNotEqual(C(3, 10), C(4, 10))
+ def test_no_unhashable_default(self):
+ # See bpo-44674.
+ class Unhashable:
+ __hash__ = None
+
+ unhashable_re = 'mutable default .* for field a is not allowed'
+ with self.assertRaisesRegex(ValueError, unhashable_re):
+ @dataclass
+ class A:
+ a: dict = {}
+
+ with self.assertRaisesRegex(ValueError, unhashable_re):
+ @dataclass
+ class A:
+ a: Any = Unhashable()
+
+ # Make sure that the machinery looking for hashability is using the
+ # class's __hash__, not the instance's __hash__.
+ with self.assertRaisesRegex(ValueError, unhashable_re):
+ unhashable = Unhashable()
+ # This shouldn't make the variable hashable.
+ unhashable.__hash__ = lambda: 0
+ @dataclass
+ class A:
+ a: Any = unhashable
+
def test_hash_field_rules(self):
# Test all 6 cases of:
# hash=True/False/None