summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-08-10 10:53:11 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-08-16 20:05:32 -0400
commit6cef8526226ab6033dfef1f793be87bff2160c04 (patch)
tree693a5b244cbb6f02a78f9d1249e6bd58f90f8bfc
parenta134ec1760df6295d537ff63df7aee83d957bf6a (diff)
downloadsqlalchemy-6cef8526226ab6033dfef1f793be87bff2160c04.tar.gz
Propagate key for collection events
Added new parameter :paramref:`_orm.AttributeEvents.include_key`, which will include the dictionary or list key for operations such as ``__setitem__()`` (e.g. ``obj[key] = value``) and ``__delitem__()`` (e.g. ``del obj[key]``), using a new keyword parameter "key" or "keys", depending on event, e.g. :paramref:`_orm.AttributeEvents.append.key`, :paramref:`_orm.AttributeEvents.bulk_replace.keys`. This allows event handlers to take into account the key that was passed to the operation and is of particular importance for dictionary operations working with :class:`_orm.MappedCollection`. Fixes: #8375 Change-Id: Icc472f7c28848f94e15c94a399cc13a88782e1e4
-rw-r--r--doc/build/changelog/unreleased_20/8375.rst14
-rw-r--r--lib/sqlalchemy/orm/__init__.py1
-rw-r--r--lib/sqlalchemy/orm/attributes.py38
-rw-r--r--lib/sqlalchemy/orm/base.py18
-rw-r--r--lib/sqlalchemy/orm/collections.py70
-rw-r--r--lib/sqlalchemy/orm/events.py71
-rw-r--r--lib/sqlalchemy/orm/interfaces.py1
-rw-r--r--lib/sqlalchemy/orm/unitofwork.py22
-rw-r--r--test/orm/test_attributes.py238
9 files changed, 411 insertions, 62 deletions
diff --git a/doc/build/changelog/unreleased_20/8375.rst b/doc/build/changelog/unreleased_20/8375.rst
new file mode 100644
index 000000000..0fb03275b
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/8375.rst
@@ -0,0 +1,14 @@
+.. change::
+ :tags: feature, orm
+ :tickets: 8375
+
+ Added new parameter :paramref:`_orm.AttributeEvents.include_key`, which
+ will include the dictionary or list key for operations such as
+ ``__setitem__()`` (e.g. ``obj[key] = value``) and ``__delitem__()`` (e.g.
+ ``del obj[key]``), using a new keyword parameter "key" or "keys", depending
+ on event, e.g. :paramref:`_orm.AttributeEvents.append.key`,
+ :paramref:`_orm.AttributeEvents.bulk_replace.keys`. This allows event
+ handlers to take into account the key that was passed to the operation and
+ is of particular importance for dictionary operations working with
+ :class:`_orm.MappedCollection`.
+
diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py
index cda58d6a5..3a0f425fc 100644
--- a/lib/sqlalchemy/orm/__init__.py
+++ b/lib/sqlalchemy/orm/__init__.py
@@ -82,6 +82,7 @@ from .interfaces import InspectionAttrInfo as InspectionAttrInfo
from .interfaces import MANYTOMANY as MANYTOMANY
from .interfaces import MANYTOONE as MANYTOONE
from .interfaces import MapperProperty as MapperProperty
+from .interfaces import NO_KEY as NO_KEY
from .interfaces import ONETOMANY as ONETOMANY
from .interfaces import PropComparator as PropComparator
from .interfaces import UserDefinedOption as UserDefinedOption
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index bb7eda5ac..db86d0810 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -1717,9 +1717,10 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl):
dict_: _InstanceDict,
value: _T,
initiator: Optional[AttributeEventToken],
+ key: Optional[Any],
) -> _T:
for fn in self.dispatch.append:
- value = fn(state, value, initiator or self._append_token)
+ value = fn(state, value, initiator or self._append_token, key=key)
state._modified_event(dict_, self, NO_VALUE, True)
@@ -1734,9 +1735,10 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl):
dict_: _InstanceDict,
value: _T,
initiator: Optional[AttributeEventToken],
+ key: Optional[Any],
) -> _T:
for fn in self.dispatch.append_wo_mutation:
- value = fn(state, value, initiator or self._append_token)
+ value = fn(state, value, initiator or self._append_token, key=key)
return value
@@ -1745,6 +1747,7 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl):
state: InstanceState[Any],
dict_: _InstanceDict,
initiator: Optional[AttributeEventToken],
+ key: Optional[Any],
) -> None:
"""A special event used for pop() operations.
@@ -1762,12 +1765,13 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl):
dict_: _InstanceDict,
value: Any,
initiator: Optional[AttributeEventToken],
+ key: Optional[Any],
) -> None:
if self.trackparent and value is not None:
self.sethasparent(instance_state(value), state, False)
for fn in self.dispatch.remove:
- fn(state, value, initiator or self._remove_token)
+ fn(state, value, initiator or self._remove_token, key=key)
state._modified_event(dict_, self, NO_VALUE, True)
@@ -1825,7 +1829,9 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl):
state, dict_, user_data=None, passive=passive
)
if collection is PASSIVE_NO_RESULT:
- value = self.fire_append_event(state, dict_, value, initiator)
+ value = self.fire_append_event(
+ state, dict_, value, initiator, key=NO_KEY
+ )
assert (
self.key not in dict_
), "Collection was loaded during event handling."
@@ -1847,7 +1853,7 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl):
state, state.dict, user_data=None, passive=passive
)
if collection is PASSIVE_NO_RESULT:
- self.fire_remove_event(state, dict_, value, initiator)
+ self.fire_remove_event(state, dict_, value, initiator, key=NO_KEY)
assert (
self.key not in dict_
), "Collection was loaded during event handling."
@@ -1885,6 +1891,7 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl):
_adapt: bool = True,
) -> None:
iterable = orig_iterable = value
+ new_keys = None
# pulling a new collection first so that an adaptation exception does
# not trigger a lazy load of the old collection.
@@ -1913,14 +1920,18 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl):
if hasattr(iterable, "_sa_iterator"):
iterable = iterable._sa_iterator()
elif setting_type is dict:
+ new_keys = list(iterable)
iterable = iterable.values()
else:
iterable = iter(iterable)
+ elif util.duck_type_collection(iterable) is dict:
+ new_keys = list(value)
+
new_values = list(iterable)
evt = self._bulk_replace_token
- self.dispatch.bulk_replace(state, new_values, evt)
+ self.dispatch.bulk_replace(state, new_values, evt, keys=new_keys)
old = self.get(state, dict_, passive=PASSIVE_ONLY_PERSISTENT)
if old is PASSIVE_NO_RESULT:
@@ -2081,7 +2092,9 @@ def backref_listeners(
)
)
- def emit_backref_from_scalar_set_event(state, child, oldchild, initiator):
+ def emit_backref_from_scalar_set_event(
+ state, child, oldchild, initiator, **kw
+ ):
if oldchild is child:
return child
if (
@@ -2146,7 +2159,9 @@ def backref_listeners(
)
return child
- def emit_backref_from_collection_append_event(state, child, initiator):
+ def emit_backref_from_collection_append_event(
+ state, child, initiator, **kw
+ ):
if child is None:
return
@@ -2180,7 +2195,9 @@ def backref_listeners(
)
return child
- def emit_backref_from_collection_remove_event(state, child, initiator):
+ def emit_backref_from_collection_remove_event(
+ state, child, initiator, **kw
+ ):
if (
child is not None
and child is not PASSIVE_NO_RESULT
@@ -2234,6 +2251,7 @@ def backref_listeners(
emit_backref_from_collection_append_event,
retval=True,
raw=True,
+ include_key=True,
)
else:
event.listen(
@@ -2242,6 +2260,7 @@ def backref_listeners(
emit_backref_from_scalar_set_event,
retval=True,
raw=True,
+ include_key=True,
)
# TODO: need coverage in test/orm/ of remove event
event.listen(
@@ -2250,6 +2269,7 @@ def backref_listeners(
emit_backref_from_collection_remove_event,
retval=True,
raw=True,
+ include_key=True,
)
diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py
index fa653a472..66b7b8c2e 100644
--- a/lib/sqlalchemy/orm/base.py
+++ b/lib/sqlalchemy/orm/base.py
@@ -191,9 +191,21 @@ class PassiveFlag(FastIntFlag):
DEFAULT_MANAGER_ATTR = "_sa_class_manager"
DEFAULT_STATE_ATTR = "_sa_instance_state"
-EXT_CONTINUE = util.symbol("EXT_CONTINUE")
-EXT_STOP = util.symbol("EXT_STOP")
-EXT_SKIP = util.symbol("EXT_SKIP")
+
+class EventConstants(Enum):
+ EXT_CONTINUE = 1
+ EXT_STOP = 2
+ EXT_SKIP = 3
+ NO_KEY = 4
+ """indicates an :class:`.AttributeEvent` event that did not have any
+ key argument.
+
+ .. versionadded:: 2.0
+
+ """
+
+
+EXT_CONTINUE, EXT_STOP, EXT_SKIP, NO_KEY = tuple(EventConstants)
class RelationshipDirection(Enum):
diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py
index f47d00634..5dbd2dc30 100644
--- a/lib/sqlalchemy/orm/collections.py
+++ b/lib/sqlalchemy/orm/collections.py
@@ -125,6 +125,7 @@ from typing import TypeVar
from typing import Union
import weakref
+from .base import NO_KEY
from .. import exc as sa_exc
from .. import util
from ..util.compat import inspect_getfullargspec
@@ -614,7 +615,7 @@ class CollectionAdapter:
def __bool__(self):
return True
- def fire_append_wo_mutation_event(self, item, initiator=None):
+ def fire_append_wo_mutation_event(self, item, initiator=None, key=NO_KEY):
"""Notify that a entity is entering the collection but is already
present.
@@ -635,12 +636,12 @@ class CollectionAdapter:
self._reset_empty()
return self.attr.fire_append_wo_mutation_event(
- self.owner_state, self.owner_state.dict, item, initiator
+ self.owner_state, self.owner_state.dict, item, initiator, key
)
else:
return item
- def fire_append_event(self, item, initiator=None):
+ def fire_append_event(self, item, initiator=None, key=NO_KEY):
"""Notify that a entity has entered the collection.
Initiator is a token owned by the InstrumentedAttribute that
@@ -657,12 +658,12 @@ class CollectionAdapter:
self._reset_empty()
return self.attr.fire_append_event(
- self.owner_state, self.owner_state.dict, item, initiator
+ self.owner_state, self.owner_state.dict, item, initiator, key
)
else:
return item
- def fire_remove_event(self, item, initiator=None):
+ def fire_remove_event(self, item, initiator=None, key=NO_KEY):
"""Notify that a entity has been removed from the collection.
Initiator is the InstrumentedAttribute that initiated the membership
@@ -678,10 +679,10 @@ class CollectionAdapter:
self._reset_empty()
self.attr.fire_remove_event(
- self.owner_state, self.owner_state.dict, item, initiator
+ self.owner_state, self.owner_state.dict, item, initiator, key
)
- def fire_pre_remove_event(self, initiator=None):
+ def fire_pre_remove_event(self, initiator=None, key=NO_KEY):
"""Notify that an entity is about to be removed from the collection.
Only called if the entity cannot be removed after calling
@@ -691,7 +692,10 @@ class CollectionAdapter:
if self.invalidated:
self._warn_invalidated()
self.attr.fire_pre_remove_event(
- self.owner_state, self.owner_state.dict, initiator=initiator
+ self.owner_state,
+ self.owner_state.dict,
+ initiator=initiator,
+ key=key,
)
def __getstate__(self):
@@ -1025,10 +1029,12 @@ def __set_wo_mutation(collection, item, _sa_initiator=None):
if _sa_initiator is not False:
executor = collection._sa_adapter
if executor:
- executor.fire_append_wo_mutation_event(item, _sa_initiator)
+ executor.fire_append_wo_mutation_event(
+ item, _sa_initiator, key=None
+ )
-def __set(collection, item, _sa_initiator=None):
+def __set(collection, item, _sa_initiator, key):
"""Run set events.
This event always occurs before the collection is actually mutated.
@@ -1038,11 +1044,11 @@ def __set(collection, item, _sa_initiator=None):
if _sa_initiator is not False:
executor = collection._sa_adapter
if executor:
- item = executor.fire_append_event(item, _sa_initiator)
+ item = executor.fire_append_event(item, _sa_initiator, key=key)
return item
-def __del(collection, item, _sa_initiator=None):
+def __del(collection, item, _sa_initiator, key):
"""Run del events.
This event occurs before the collection is actually mutated, *except*
@@ -1054,7 +1060,7 @@ def __del(collection, item, _sa_initiator=None):
if _sa_initiator is not False:
executor = collection._sa_adapter
if executor:
- executor.fire_remove_event(item, _sa_initiator)
+ executor.fire_remove_event(item, _sa_initiator, key=key)
def __before_pop(collection, _sa_initiator=None):
@@ -1073,7 +1079,7 @@ def _list_decorators() -> Dict[str, Callable[[_FN], _FN]]:
def append(fn):
def append(self, item, _sa_initiator=None):
- item = __set(self, item, _sa_initiator)
+ item = __set(self, item, _sa_initiator, NO_KEY)
fn(self, item)
_tidy(append)
@@ -1081,7 +1087,7 @@ def _list_decorators() -> Dict[str, Callable[[_FN], _FN]]:
def remove(fn):
def remove(self, value, _sa_initiator=None):
- __del(self, value, _sa_initiator)
+ __del(self, value, _sa_initiator, NO_KEY)
# testlib.pragma exempt:__eq__
fn(self, value)
@@ -1090,7 +1096,7 @@ def _list_decorators() -> Dict[str, Callable[[_FN], _FN]]:
def insert(fn):
def insert(self, index, value):
- value = __set(self, value)
+ value = __set(self, value, None, index)
fn(self, index, value)
_tidy(insert)
@@ -1101,8 +1107,8 @@ def _list_decorators() -> Dict[str, Callable[[_FN], _FN]]:
if not isinstance(index, slice):
existing = self[index]
if existing is not None:
- __del(self, existing)
- value = __set(self, value)
+ __del(self, existing, None, index)
+ value = __set(self, value, None, index)
fn(self, index, value)
else:
# slice assignment requires __delitem__, insert, __len__
@@ -1144,14 +1150,14 @@ def _list_decorators() -> Dict[str, Callable[[_FN], _FN]]:
def __delitem__(self, index):
if not isinstance(index, slice):
item = self[index]
- __del(self, item)
+ __del(self, item, None, index)
fn(self, index)
else:
# slice deletion requires __getslice__ and a slice-groking
# __getitem__ for stepped deletion
# note: not breaking this into atomic dels
for item in self[index]:
- __del(self, item)
+ __del(self, item, None, index)
fn(self, index)
_tidy(__delitem__)
@@ -1180,7 +1186,7 @@ def _list_decorators() -> Dict[str, Callable[[_FN], _FN]]:
def pop(self, index=-1):
__before_pop(self)
item = fn(self, index)
- __del(self, item)
+ __del(self, item, None, index)
return item
_tidy(pop)
@@ -1189,7 +1195,7 @@ def _list_decorators() -> Dict[str, Callable[[_FN], _FN]]:
def clear(fn):
def clear(self, index=-1):
for item in self:
- __del(self, item)
+ __del(self, item, None, index)
fn(self)
_tidy(clear)
@@ -1217,8 +1223,8 @@ def _dict_decorators() -> Dict[str, Callable[[_FN], _FN]]:
def __setitem__(fn):
def __setitem__(self, key, value, _sa_initiator=None):
if key in self:
- __del(self, self[key], _sa_initiator)
- value = __set(self, value, _sa_initiator)
+ __del(self, self[key], _sa_initiator, key)
+ value = __set(self, value, _sa_initiator, key)
fn(self, key, value)
_tidy(__setitem__)
@@ -1227,7 +1233,7 @@ def _dict_decorators() -> Dict[str, Callable[[_FN], _FN]]:
def __delitem__(fn):
def __delitem__(self, key, _sa_initiator=None):
if key in self:
- __del(self, self[key], _sa_initiator)
+ __del(self, self[key], _sa_initiator, key)
fn(self, key)
_tidy(__delitem__)
@@ -1236,7 +1242,7 @@ def _dict_decorators() -> Dict[str, Callable[[_FN], _FN]]:
def clear(fn):
def clear(self):
for key in self:
- __del(self, self[key])
+ __del(self, self[key], None, key)
fn(self)
_tidy(clear)
@@ -1251,7 +1257,7 @@ def _dict_decorators() -> Dict[str, Callable[[_FN], _FN]]:
else:
item = fn(self, key, default)
if _to_del:
- __del(self, item)
+ __del(self, item, None, key)
return item
_tidy(pop)
@@ -1261,7 +1267,7 @@ def _dict_decorators() -> Dict[str, Callable[[_FN], _FN]]:
def popitem(self):
__before_pop(self)
item = fn(self)
- __del(self, item[1])
+ __del(self, item[1], None, 1)
return item
_tidy(popitem)
@@ -1341,7 +1347,7 @@ def _set_decorators() -> Dict[str, Callable[[_FN], _FN]]:
def add(fn):
def add(self, value, _sa_initiator=None):
if value not in self:
- value = __set(self, value, _sa_initiator)
+ value = __set(self, value, _sa_initiator, NO_KEY)
else:
__set_wo_mutation(self, value, _sa_initiator)
# testlib.pragma exempt:__hash__
@@ -1354,7 +1360,7 @@ def _set_decorators() -> Dict[str, Callable[[_FN], _FN]]:
def discard(self, value, _sa_initiator=None):
# testlib.pragma exempt:__hash__
if value in self:
- __del(self, value, _sa_initiator)
+ __del(self, value, _sa_initiator, NO_KEY)
# testlib.pragma exempt:__hash__
fn(self, value)
@@ -1365,7 +1371,7 @@ def _set_decorators() -> Dict[str, Callable[[_FN], _FN]]:
def remove(self, value, _sa_initiator=None):
# testlib.pragma exempt:__hash__
if value in self:
- __del(self, value, _sa_initiator)
+ __del(self, value, _sa_initiator, NO_KEY)
# testlib.pragma exempt:__hash__
fn(self, value)
@@ -1378,7 +1384,7 @@ def _set_decorators() -> Dict[str, Callable[[_FN], _FN]]:
item = fn(self)
# for set in particular, we have no way to access the item
# that will be popped before pop is called.
- __del(self, item)
+ __del(self, item, None, NO_KEY)
return item
_tidy(pop)
diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py
index 680e49981..c17ea1abe 100644
--- a/lib/sqlalchemy/orm/events.py
+++ b/lib/sqlalchemy/orm/events.py
@@ -22,6 +22,7 @@ from . import interfaces
from . import mapperlib
from .attributes import QueryableAttribute
from .base import _mapper_or_none
+from .base import NO_KEY
from .query import Query
from .scoping import scoped_session
from .session import Session
@@ -2288,6 +2289,7 @@ class AttributeEvents(event.Events):
raw=False,
retval=False,
propagate=False,
+ include_key=False,
):
target, fn = event_key.dispatch_target, event_key._listen_fn
@@ -2295,9 +2297,9 @@ class AttributeEvents(event.Events):
if active_history:
target.dispatch._active_history = True
- if not raw or not retval:
+ if not raw or not retval or not include_key:
- def wrap(target, *arg):
+ def wrap(target, *arg, **kw):
if not raw:
target = target.obj()
if not retval:
@@ -2305,10 +2307,16 @@ class AttributeEvents(event.Events):
value = arg[0]
else:
value = None
- fn(target, *arg)
+ if include_key:
+ fn(target, *arg, **kw)
+ else:
+ fn(target, *arg)
return value
else:
- return fn(target, *arg)
+ if include_key:
+ return fn(target, *arg, **kw)
+ else:
+ return fn(target, *arg)
event_key = event_key.with_wrapper(wrap)
@@ -2324,7 +2332,7 @@ class AttributeEvents(event.Events):
if active_history:
mgr[target.key].dispatch._active_history = True
- def append(self, target, value, initiator):
+ def append(self, target, value, initiator, *, key=NO_KEY):
"""Receive a collection append event.
The append event is invoked for each element as it is appended
@@ -2343,6 +2351,19 @@ class AttributeEvents(event.Events):
from its original value by backref handlers in order to control
chained event propagation, as well as be inspected for information
about the source of the event.
+ :param key: When the event is established using the
+ :paramref:`.AttributeEvents.include_key` parameter set to
+ True, this will be the key used in the operation, such as
+ ``collection[some_key_or_index] = value``.
+ The parameter is not passed
+ to the event at all if the the
+ :paramref:`.AttributeEvents.include_key`
+ was not used to set up the event; this is to allow backwards
+ compatibility with existing event handlers that don't include the
+ ``key`` parameter.
+
+ .. versionadded:: 2.0
+
:return: if the event was registered with ``retval=True``,
the given value, or a new effective value, should be returned.
@@ -2355,7 +2376,7 @@ class AttributeEvents(event.Events):
"""
- def append_wo_mutation(self, target, value, initiator):
+ def append_wo_mutation(self, target, value, initiator, *, key=NO_KEY):
"""Receive a collection append event where the collection was not
actually mutated.
@@ -2378,6 +2399,18 @@ class AttributeEvents(event.Events):
from its original value by backref handlers in order to control
chained event propagation, as well as be inspected for information
about the source of the event.
+ :param key: When the event is established using the
+ :paramref:`.AttributeEvents.include_key` parameter set to
+ True, this will be the key used in the operation, such as
+ ``collection[some_key_or_index] = value``.
+ The parameter is not passed
+ to the event at all if the the
+ :paramref:`.AttributeEvents.include_key`
+ was not used to set up the event; this is to allow backwards
+ compatibility with existing event handlers that don't include the
+ ``key`` parameter.
+
+ .. versionadded:: 2.0
:return: No return value is defined for this event.
@@ -2385,7 +2418,7 @@ class AttributeEvents(event.Events):
"""
- def bulk_replace(self, target, values, initiator):
+ def bulk_replace(self, target, values, initiator, *, keys=None):
"""Receive a collection 'bulk replace' event.
This event is invoked for a sequence of values as they are incoming
@@ -2428,6 +2461,17 @@ class AttributeEvents(event.Events):
handler can modify this list in place.
:param initiator: An instance of :class:`.attributes.Event`
representing the initiation of the event.
+ :param keys: When the event is established using the
+ :paramref:`.AttributeEvents.include_key` parameter set to
+ True, this will be the sequence of keys used in the operation,
+ typically only for a dictionary update. The parameter is not passed
+ to the event at all if the the
+ :paramref:`.AttributeEvents.include_key`
+ was not used to set up the event; this is to allow backwards
+ compatibility with existing event handlers that don't include the
+ ``key`` parameter.
+
+ .. versionadded:: 2.0
.. seealso::
@@ -2437,7 +2481,7 @@ class AttributeEvents(event.Events):
"""
- def remove(self, target, value, initiator):
+ def remove(self, target, value, initiator, *, key=NO_KEY):
"""Receive a collection remove event.
:param target: the object instance receiving the event.
@@ -2453,6 +2497,17 @@ class AttributeEvents(event.Events):
passed as a :class:`.attributes.Event` object, and may be
modified by backref handlers within a chain of backref-linked
events.
+ :param key: When the event is established using the
+ :paramref:`.AttributeEvents.include_key` parameter set to
+ True, this will be the key used in the operation, such as
+ ``del collection[some_key_or_index]``. The parameter is not passed
+ to the event at all if the the
+ :paramref:`.AttributeEvents.include_key`
+ was not used to set up the event; this is to allow backwards
+ compatibility with existing event handlers that don't include the
+ ``key`` parameter.
+
+ .. versionadded:: 2.0
:return: No return value is defined for this event.
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index 16062fffa..72f5c6a7b 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -49,6 +49,7 @@ from .base import InspectionAttr as InspectionAttr # noqa: F401
from .base import InspectionAttrInfo as InspectionAttrInfo
from .base import MANYTOMANY as MANYTOMANY # noqa: F401
from .base import MANYTOONE as MANYTOONE # noqa: F401
+from .base import NO_KEY as NO_KEY # noqa: F401
from .base import NotExtension as NotExtension # noqa: F401
from .base import ONETOMANY as ONETOMANY # noqa: F401
from .base import RelationshipDirection as RelationshipDirection # noqa: F401
diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py
index c83ffdb59..5e66653a3 100644
--- a/lib/sqlalchemy/orm/unitofwork.py
+++ b/lib/sqlalchemy/orm/unitofwork.py
@@ -47,7 +47,7 @@ def track_cascade_events(descriptor, prop):
"""
key = prop.key
- def append(state, item, initiator):
+ def append(state, item, initiator, **kw):
# process "save_update" cascade rules for when
# an instance is appended to the list of another instance
@@ -70,7 +70,7 @@ def track_cascade_events(descriptor, prop):
sess._save_or_update_state(item_state)
return item
- def remove(state, item, initiator):
+ def remove(state, item, initiator, **kw):
if item is None:
return
@@ -104,7 +104,7 @@ def track_cascade_events(descriptor, prop):
# item
item_state._orphaned_outside_of_session = True
- def set_(state, newvalue, oldvalue, initiator):
+ def set_(state, newvalue, oldvalue, initiator, **kw):
# process "save_update" cascade rules for when an instance
# is attached to another instance
if oldvalue is newvalue:
@@ -141,10 +141,18 @@ def track_cascade_events(descriptor, prop):
sess.expunge(oldvalue)
return newvalue
- event.listen(descriptor, "append_wo_mutation", append, raw=True)
- event.listen(descriptor, "append", append, raw=True, retval=True)
- event.listen(descriptor, "remove", remove, raw=True, retval=True)
- event.listen(descriptor, "set", set_, raw=True, retval=True)
+ event.listen(
+ descriptor, "append_wo_mutation", append, raw=True, include_key=True
+ )
+ event.listen(
+ descriptor, "append", append, raw=True, retval=True, include_key=True
+ )
+ event.listen(
+ descriptor, "remove", remove, raw=True, retval=True, include_key=True
+ )
+ event.listen(
+ descriptor, "set", set_, raw=True, retval=True, include_key=True
+ )
class UOWTransaction:
diff --git a/test/orm/test_attributes.py b/test/orm/test_attributes.py
index e1274a805..53b306f5b 100644
--- a/test/orm/test_attributes.py
+++ b/test/orm/test_attributes.py
@@ -8,6 +8,8 @@ from sqlalchemy import testing
from sqlalchemy.orm import attributes
from sqlalchemy.orm import exc as orm_exc
from sqlalchemy.orm import instrumentation
+from sqlalchemy.orm import NO_KEY
+from sqlalchemy.orm.collections import attribute_mapped_collection
from sqlalchemy.orm.collections import collection
from sqlalchemy.orm.state import InstanceState
from sqlalchemy.testing import assert_raises
@@ -23,7 +25,6 @@ from sqlalchemy.testing.assertions import assert_warns
from sqlalchemy.testing.util import all_partial_orderings
from sqlalchemy.testing.util import gc_collect
-
# global for pickling tests
MyTest = None
MyTest2 = None
@@ -2576,8 +2577,6 @@ class HistoryTest(fixtures.TestBase):
class Bar(fixtures.BasicEntity):
pass
- from sqlalchemy.orm.collections import attribute_mapped_collection
-
instrumentation.register_class(Foo)
instrumentation.register_class(Bar)
_register_attribute(
@@ -3193,6 +3192,239 @@ class LazyloadHistoryTest(fixtures.TestBase):
)
+class CollectionKeyTest(fixtures.ORMTest):
+ @testing.fixture
+ def dict_collection(self):
+ class Foo(fixtures.BasicEntity):
+ pass
+
+ class Bar(fixtures.BasicEntity):
+ def __init__(self, name):
+ self.name = name
+
+ instrumentation.register_class(Foo)
+ instrumentation.register_class(Bar)
+ _register_attribute(
+ Foo,
+ "someattr",
+ uselist=True,
+ useobject=True,
+ typecallable=attribute_mapped_collection("name"),
+ )
+ _register_attribute(
+ Bar,
+ "name",
+ uselist=False,
+ useobject=False,
+ )
+
+ return Foo, Bar
+
+ @testing.fixture
+ def list_collection(self):
+ class Foo(fixtures.BasicEntity):
+ pass
+
+ class Bar(fixtures.BasicEntity):
+ pass
+
+ instrumentation.register_class(Foo)
+ instrumentation.register_class(Bar)
+ _register_attribute(
+ Foo,
+ "someattr",
+ uselist=True,
+ useobject=True,
+ )
+
+ return Foo, Bar
+
+ def test_listen_w_list_key(self, list_collection):
+ Foo, Bar = list_collection
+
+ m1 = Mock()
+
+ event.listen(Foo.someattr, "append", m1, include_key=True)
+ event.listen(Foo.someattr, "remove", m1, include_key=True)
+
+ f1 = Foo()
+ b1, b2, b3 = Bar(), Bar(), Bar()
+ f1.someattr.append(b1)
+ f1.someattr.append(b2)
+ f1.someattr[1] = b3
+ del f1.someattr[0]
+ append_token, remove_token = (
+ Foo.someattr.impl._append_token,
+ Foo.someattr.impl._remove_token,
+ )
+
+ eq_(
+ m1.mock_calls,
+ [
+ call(
+ f1,
+ b1,
+ append_token,
+ key=NO_KEY,
+ ),
+ call(
+ f1,
+ b2,
+ append_token,
+ key=NO_KEY,
+ ),
+ call(
+ f1,
+ b2,
+ remove_token,
+ key=1,
+ ),
+ call(
+ f1,
+ b3,
+ append_token,
+ key=1,
+ ),
+ call(
+ f1,
+ b1,
+ remove_token,
+ key=0,
+ ),
+ ],
+ )
+
+ def test_listen_w_dict_key(self, dict_collection):
+ Foo, Bar = dict_collection
+
+ m1 = Mock()
+
+ event.listen(Foo.someattr, "append", m1, include_key=True)
+ event.listen(Foo.someattr, "remove", m1, include_key=True)
+
+ f1 = Foo()
+ b1, b2, b3 = Bar("b1"), Bar("b2"), Bar("b3")
+ f1.someattr["k1"] = b1
+ f1.someattr.update({"k2": b2, "k3": b3})
+
+ del f1.someattr["k2"]
+
+ append_token, remove_token = (
+ Foo.someattr.impl._append_token,
+ Foo.someattr.impl._remove_token,
+ )
+
+ eq_(
+ m1.mock_calls,
+ [
+ call(
+ f1,
+ b1,
+ append_token,
+ key="k1",
+ ),
+ call(
+ f1,
+ b2,
+ append_token,
+ key="k2",
+ ),
+ call(
+ f1,
+ b3,
+ append_token,
+ key="k3",
+ ),
+ call(
+ f1,
+ b2,
+ remove_token,
+ key="k2",
+ ),
+ ],
+ )
+
+ def test_dict_bulk_replace_w_key(self, dict_collection):
+ Foo, Bar = dict_collection
+
+ m1 = Mock()
+
+ event.listen(Foo.someattr, "bulk_replace", m1, include_key=True)
+ event.listen(Foo.someattr, "append", m1, include_key=True)
+ event.listen(Foo.someattr, "remove", m1, include_key=True)
+
+ f1 = Foo()
+ b1, b2, b3, b4 = Bar("b1"), Bar("b2"), Bar("b3"), Bar("b4")
+ f1.someattr = {"b1": b1, "b3": b3}
+ f1.someattr = {"b2": b2, "b3": b3, "b4": b4}
+
+ bulk_replace_token = Foo.someattr.impl._bulk_replace_token
+
+ eq_(
+ m1.mock_calls,
+ [
+ call(f1, [b1, b3], bulk_replace_token, keys=["b1", "b3"]),
+ call(f1, b1, bulk_replace_token, key="b1"),
+ call(f1, b3, bulk_replace_token, key="b3"),
+ call(
+ f1,
+ [b2, b3, b4],
+ bulk_replace_token,
+ keys=["b2", "b3", "b4"],
+ ),
+ call(f1, b2, bulk_replace_token, key="b2"),
+ call(f1, b4, bulk_replace_token, key="b4"),
+ call(f1, b1, bulk_replace_token, key=NO_KEY),
+ ],
+ )
+
+ def test_listen_wo_dict_key(self, dict_collection):
+ Foo, Bar = dict_collection
+
+ m1 = Mock()
+
+ event.listen(Foo.someattr, "append", m1)
+ event.listen(Foo.someattr, "remove", m1)
+
+ f1 = Foo()
+ b1, b2, b3 = Bar("b1"), Bar("b2"), Bar("b3")
+ f1.someattr["k1"] = b1
+ f1.someattr.update({"k2": b2, "k3": b3})
+
+ del f1.someattr["k2"]
+
+ append_token, remove_token = (
+ Foo.someattr.impl._append_token,
+ Foo.someattr.impl._remove_token,
+ )
+
+ eq_(
+ m1.mock_calls,
+ [
+ call(
+ f1,
+ b1,
+ append_token,
+ ),
+ call(
+ f1,
+ b2,
+ append_token,
+ ),
+ call(
+ f1,
+ b3,
+ append_token,
+ ),
+ call(
+ f1,
+ b2,
+ remove_token,
+ ),
+ ],
+ )
+
+
class ListenerTest(fixtures.ORMTest):
def test_receive_changes(self):
"""test that Listeners can mutate the given value."""