summaryrefslogtreecommitdiff
path: root/Lib/dataclasses.py
diff options
context:
space:
mode:
authorEric V. Smith <ericvsmith@users.noreply.github.com>2018-01-27 19:07:40 -0500
committerGitHub <noreply@github.com>2018-01-27 19:07:40 -0500
commitea8fc52e75363276db23c6a8d7a689f79efce4f9 (patch)
treeca662ba631df1f6e6e32b5b0d95a6b5458d5699c /Lib/dataclasses.py
parent2a2247ce5e1984eb2f2c41b269b38dbb795a60cf (diff)
downloadcpython-git-ea8fc52e75363276db23c6a8d7a689f79efce4f9.tar.gz
bpo-32513: Make it easier to override dunders in dataclasses. (GH-5366)
Class authors no longer need to specify repr=False if they want to provide a custom __repr__ for dataclasses. The same thing applies for the other dunder methods that the dataclass decorator adds. If dataclass finds that a dunder methods is defined in the class, it will not overwrite it.
Diffstat (limited to 'Lib/dataclasses.py')
-rw-r--r--Lib/dataclasses.py306
1 files changed, 224 insertions, 82 deletions
diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py
index 7d30da1aac..fb279cd305 100644
--- a/Lib/dataclasses.py
+++ b/Lib/dataclasses.py
@@ -18,6 +18,142 @@ __all__ = ['dataclass',
'is_dataclass',
]
+# Conditions for adding methods. The boxes indicate what action the
+# dataclass decorator takes. For all of these tables, when I talk
+# about init=, repr=, eq=, order=, hash=, or frozen=, I'm referring
+# to the arguments to the @dataclass decorator. When checking if a
+# dunder method already exists, I mean check for an entry in the
+# class's __dict__. I never check to see if an attribute is defined
+# in a base class.
+
+# Key:
+# +=========+=========================================+
+# + Value | Meaning |
+# +=========+=========================================+
+# | <blank> | No action: no method is added. |
+# +---------+-----------------------------------------+
+# | add | Generated method is added. |
+# +---------+-----------------------------------------+
+# | add* | Generated method is added only if the |
+# | | existing attribute is None and if the |
+# | | user supplied a __eq__ method in the |
+# | | class definition. |
+# +---------+-----------------------------------------+
+# | raise | TypeError is raised. |
+# +---------+-----------------------------------------+
+# | None | Attribute is set to None. |
+# +=========+=========================================+
+
+# __init__
+#
+# +--- init= parameter
+# |
+# v | | |
+# | no | yes | <--- class has __init__ in __dict__?
+# +=======+=======+=======+
+# | False | | |
+# +-------+-------+-------+
+# | True | add | | <- the default
+# +=======+=======+=======+
+
+# __repr__
+#
+# +--- repr= parameter
+# |
+# v | | |
+# | no | yes | <--- class has __repr__ in __dict__?
+# +=======+=======+=======+
+# | False | | |
+# +-------+-------+-------+
+# | True | add | | <- the default
+# +=======+=======+=======+
+
+
+# __setattr__
+# __delattr__
+#
+# +--- frozen= parameter
+# |
+# v | | |
+# | no | yes | <--- class has __setattr__ or __delattr__ in __dict__?
+# +=======+=======+=======+
+# | False | | | <- the default
+# +-------+-------+-------+
+# | True | add | raise |
+# +=======+=======+=======+
+# Raise because not adding these methods would break the "frozen-ness"
+# of the class.
+
+# __eq__
+#
+# +--- eq= parameter
+# |
+# v | | |
+# | no | yes | <--- class has __eq__ in __dict__?
+# +=======+=======+=======+
+# | False | | |
+# +-------+-------+-------+
+# | True | add | | <- the default
+# +=======+=======+=======+
+
+# __lt__
+# __le__
+# __gt__
+# __ge__
+#
+# +--- order= parameter
+# |
+# v | | |
+# | no | yes | <--- class has any comparison method in __dict__?
+# +=======+=======+=======+
+# | False | | | <- the default
+# +-------+-------+-------+
+# | True | add | raise |
+# +=======+=======+=======+
+# Raise because to allow this case would interfere with using
+# functools.total_ordering.
+
+# __hash__
+
+# +------------------- hash= parameter
+# | +----------- eq= parameter
+# | | +--- frozen= parameter
+# | | |
+# v v v | | |
+# | no | yes | <--- class has __hash__ in __dict__?
+# +=========+=======+=======+========+========+
+# | 1 None | False | False | | | No __eq__, use the base class __hash__
+# +---------+-------+-------+--------+--------+
+# | 2 None | False | True | | | No __eq__, use the base class __hash__
+# +---------+-------+-------+--------+--------+
+# | 3 None | True | False | None | | <-- the default, not hashable
+# +---------+-------+-------+--------+--------+
+# | 4 None | True | True | add | add* | Frozen, so hashable
+# +---------+-------+-------+--------+--------+
+# | 5 False | False | False | | |
+# +---------+-------+-------+--------+--------+
+# | 6 False | False | True | | |
+# +---------+-------+-------+--------+--------+
+# | 7 False | True | False | | |
+# +---------+-------+-------+--------+--------+
+# | 8 False | True | True | | |
+# +---------+-------+-------+--------+--------+
+# | 9 True | False | False | add | add* | Has no __eq__, but hashable
+# +---------+-------+-------+--------+--------+
+# |10 True | False | True | add | add* | Has no __eq__, but hashable
+# +---------+-------+-------+--------+--------+
+# |11 True | True | False | add | add* | Not frozen, but hashable
+# +---------+-------+-------+--------+--------+
+# |12 True | True | True | add | add* | Frozen, so hashable
+# +=========+=======+=======+========+========+
+# For boxes that are blank, __hash__ is untouched and therefore
+# inherited from the base class. If the base is object, then
+# id-based hashing is used.
+# Note that a class may have already __hash__=None if it specified an
+# __eq__ method in the class body (not one that was created by
+# @dataclass).
+
+
# Raised when an attempt is made to modify a frozen class.
class FrozenInstanceError(AttributeError): pass
@@ -143,13 +279,13 @@ def _tuple_str(obj_name, fields):
# return "(self.x,self.y)".
# Special case for the 0-tuple.
- if len(fields) == 0:
+ if not fields:
return '()'
# Note the trailing comma, needed if this turns out to be a 1-tuple.
return f'({",".join([f"{obj_name}.{f.name}" for f in fields])},)'
-def _create_fn(name, args, body, globals=None, locals=None,
+def _create_fn(name, args, body, *, globals=None, locals=None,
return_type=MISSING):
# Note that we mutate locals when exec() is called. Caller beware!
if locals is None:
@@ -287,7 +423,7 @@ def _init_fn(fields, frozen, has_post_init, self_name):
body_lines += [f'{self_name}.{_POST_INIT_NAME}({params_str})']
# If no body lines, use 'pass'.
- if len(body_lines) == 0:
+ if not body_lines:
body_lines = ['pass']
locals = {f'_type_{f.name}': f.type for f in fields}
@@ -329,32 +465,6 @@ def _cmp_fn(name, op, self_tuple, other_tuple):
'return NotImplemented'])
-def _set_eq_fns(cls, fields):
- # Create and set the equality comparison methods on cls.
- # Pre-compute self_tuple and other_tuple, then re-use them for
- # each function.
- self_tuple = _tuple_str('self', fields)
- other_tuple = _tuple_str('other', fields)
- for name, op in [('__eq__', '=='),
- ('__ne__', '!='),
- ]:
- _set_attribute(cls, name, _cmp_fn(name, op, self_tuple, other_tuple))
-
-
-def _set_order_fns(cls, fields):
- # Create and set the ordering methods on cls.
- # Pre-compute self_tuple and other_tuple, then re-use them for
- # each function.
- self_tuple = _tuple_str('self', fields)
- other_tuple = _tuple_str('other', fields)
- for name, op in [('__lt__', '<'),
- ('__le__', '<='),
- ('__gt__', '>'),
- ('__ge__', '>='),
- ]:
- _set_attribute(cls, name, _cmp_fn(name, op, self_tuple, other_tuple))
-
-
def _hash_fn(fields):
self_tuple = _tuple_str('self', fields)
return _create_fn('__hash__',
@@ -431,20 +541,20 @@ def _find_fields(cls):
# a Field(), then it contains additional info beyond (and
# possibly including) the actual default value. Pseudo-fields
# ClassVars and InitVars are included, despite the fact that
- # they're not real fields. That's deal with later.
+ # they're not real fields. That's dealt with later.
annotations = getattr(cls, '__annotations__', {})
-
return [_get_field(cls, a_name, a_type)
for a_name, a_type in annotations.items()]
-def _set_attribute(cls, name, value):
- # Raise TypeError if an attribute by this name already exists.
+def _set_new_attribute(cls, name, value):
+ # Never overwrites an existing attribute. Returns True if the
+ # attribute already exists.
if name in cls.__dict__:
- raise TypeError(f'Cannot overwrite attribute {name} '
- f'in {cls.__name__}')
+ return True
setattr(cls, name, value)
+ return False
def _process_class(cls, repr, eq, order, hash, init, frozen):
@@ -495,6 +605,9 @@ def _process_class(cls, repr, eq, order, hash, init, frozen):
# be inherited down.
is_frozen = frozen or cls.__setattr__ is _frozen_setattr
+ # Was this class defined with an __eq__? Used in __hash__ logic.
+ auto_hash_test= '__eq__' in cls.__dict__ and getattr(cls.__dict__, '__hash__', MISSING) is None
+
# If we're generating ordering methods, we must be generating
# the eq methods.
if order and not eq:
@@ -505,62 +618,91 @@ def _process_class(cls, repr, eq, order, hash, init, frozen):
has_post_init = hasattr(cls, _POST_INIT_NAME)
# Include InitVars and regular fields (so, not ClassVars).
- _set_attribute(cls, '__init__',
- _init_fn(list(filter(lambda f: f._field_type
- in (_FIELD, _FIELD_INITVAR),
- fields.values())),
- is_frozen,
- has_post_init,
- # The name to use for the "self" param
- # in __init__. Use "self" if possible.
- '__dataclass_self__' if 'self' in fields
- else 'self',
- ))
+ flds = [f for f in fields.values()
+ if f._field_type in (_FIELD, _FIELD_INITVAR)]
+ _set_new_attribute(cls, '__init__',
+ _init_fn(flds,
+ is_frozen,
+ has_post_init,
+ # The name to use for the "self" param
+ # in __init__. Use "self" if possible.
+ '__dataclass_self__' if 'self' in fields
+ else 'self',
+ ))
# Get the fields as a list, and include only real fields. This is
# used in all of the following methods.
- field_list = list(filter(lambda f: f._field_type is _FIELD,
- fields.values()))
+ field_list = [f for f in fields.values() if f._field_type is _FIELD]
if repr:
- _set_attribute(cls, '__repr__',
- _repr_fn(list(filter(lambda f: f.repr, field_list))))
-
- if is_frozen:
- _set_attribute(cls, '__setattr__', _frozen_setattr)
- _set_attribute(cls, '__delattr__', _frozen_delattr)
-
- generate_hash = False
- if hash is None:
- if eq and frozen:
- # Generate a hash function.
- generate_hash = True
- elif eq and not frozen:
- # Not hashable.
- _set_attribute(cls, '__hash__', None)
- elif not eq:
- # Otherwise, use the base class definition of hash(). That is,
- # don't set anything on this class.
- pass
- else:
- assert "can't get here"
- else:
- generate_hash = hash
- if generate_hash:
- _set_attribute(cls, '__hash__',
- _hash_fn(list(filter(lambda f: f.compare
- if f.hash is None
- else f.hash,
- field_list))))
+ flds = [f for f in field_list if f.repr]
+ _set_new_attribute(cls, '__repr__', _repr_fn(flds))
if eq:
- # Create and __eq__ and __ne__ methods.
- _set_eq_fns(cls, list(filter(lambda f: f.compare, field_list)))
+ # Create _eq__ method. There's no need for a __ne__ method,
+ # since python will call __eq__ and negate it.
+ flds = [f for f in field_list if f.compare]
+ self_tuple = _tuple_str('self', flds)
+ other_tuple = _tuple_str('other', flds)
+ _set_new_attribute(cls, '__eq__',
+ _cmp_fn('__eq__', '==',
+ self_tuple, other_tuple))
if order:
- # Create and __lt__, __le__, __gt__, and __ge__ methods.
- # Create and set the comparison functions.
- _set_order_fns(cls, list(filter(lambda f: f.compare, field_list)))
+ # Create and set the ordering methods.
+ flds = [f for f in field_list if f.compare]
+ self_tuple = _tuple_str('self', flds)
+ other_tuple = _tuple_str('other', flds)
+ for name, op in [('__lt__', '<'),
+ ('__le__', '<='),
+ ('__gt__', '>'),
+ ('__ge__', '>='),
+ ]:
+ if _set_new_attribute(cls, name,
+ _cmp_fn(name, op, self_tuple, other_tuple)):
+ raise TypeError(f'Cannot overwrite attribute {name} '
+ f'in {cls.__name__}. Consider using '
+ 'functools.total_ordering')
+
+ if is_frozen:
+ for name, fn in [('__setattr__', _frozen_setattr),
+ ('__delattr__', _frozen_delattr)]:
+ if _set_new_attribute(cls, name, fn):
+ raise TypeError(f'Cannot overwrite attribute {name} '
+ f'in {cls.__name__}')
+
+ # Decide if/how we're going to create a hash function.
+ # TODO: Move this table to module scope, so it's not recreated
+ # all the time.
+ generate_hash = {(None, False, False): ('', ''),
+ (None, False, True): ('', ''),
+ (None, True, False): ('none', ''),
+ (None, True, True): ('fn', 'fn-x'),
+ (False, False, False): ('', ''),
+ (False, False, True): ('', ''),
+ (False, True, False): ('', ''),
+ (False, True, True): ('', ''),
+ (True, False, False): ('fn', 'fn-x'),
+ (True, False, True): ('fn', 'fn-x'),
+ (True, True, False): ('fn', 'fn-x'),
+ (True, True, True): ('fn', 'fn-x'),
+ }[None if hash is None else bool(hash), # Force bool() if not None.
+ bool(eq),
+ bool(frozen)]['__hash__' in cls.__dict__]
+ # No need to call _set_new_attribute here, since we already know if
+ # we're overwriting a __hash__ or not.
+ if generate_hash == '':
+ # Do nothing.
+ pass
+ elif generate_hash == 'none':
+ cls.__hash__ = None
+ elif generate_hash in ('fn', 'fn-x'):
+ if generate_hash == 'fn' or auto_hash_test:
+ flds = [f for f in field_list
+ if (f.compare if f.hash is None else f.hash)]
+ cls.__hash__ = _hash_fn(flds)
+ else:
+ assert False, f"can't get here: {generate_hash}"
if not getattr(cls, '__doc__'):
# Create a class doc-string.