summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/ext/associationproxy.py3
-rw-r--r--lib/sqlalchemy/orm/attributes.py29
-rw-r--r--lib/sqlalchemy/orm/collections.py6
-rw-r--r--lib/sqlalchemy/orm/dependency.py12
-rw-r--r--lib/sqlalchemy/orm/mapper.py32
-rw-r--r--lib/sqlalchemy/orm/query.py2
-rw-r--r--lib/sqlalchemy/orm/strategies.py4
-rw-r--r--lib/sqlalchemy/orm/unitofwork.py46
-rw-r--r--lib/sqlalchemy/topological.py12
-rw-r--r--lib/sqlalchemy/util.py22
10 files changed, 105 insertions, 63 deletions
diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py
index 0ee59e369..472bd1b2c 100644
--- a/lib/sqlalchemy/ext/associationproxy.py
+++ b/lib/sqlalchemy/ext/associationproxy.py
@@ -326,6 +326,7 @@ class _AssociationList(object):
def __contains__(self, value):
for member in self.col:
+ # testlib.pragma exempt:__eq__
if self._get(member) == value:
return True
return False
@@ -473,6 +474,7 @@ class _AssociationDict(object):
del self.col[key]
def __contains__(self, key):
+ # testlib.pragma exempt:__hash__
return key in self.col
has_key = __contains__
@@ -609,6 +611,7 @@ class _AssociationSet(object):
def __contains__(self, value):
for member in self.col:
+ # testlib.pragma exempt:__eq__
if self._get(member) == value:
return True
return False
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index 189cd52ee..a340394b9 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -8,7 +8,7 @@ import weakref, threading
import UserDict
from sqlalchemy import util
from sqlalchemy.orm import interfaces, collections
-from sqlalchemy.orm.mapper import class_mapper
+from sqlalchemy.orm.mapper import class_mapper, identity_equal
from sqlalchemy import exceptions
@@ -369,6 +369,8 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
super(ScalarObjectAttributeImpl, self).__init__(class_, manager, key,
callable_, trackparent=trackparent, extension=extension,
compare_function=compare_function, mutable_scalars=mutable_scalars, **kwargs)
+ if compare_function is None:
+ self.is_equal = identity_equal
def delete(self, state):
old = self.get(state)
@@ -815,23 +817,26 @@ class AttributeHistory(object):
if hasattr(attr, 'get_collection'):
self._current = current
+
if original is NO_VALUE:
- s = util.Set([])
+ s = util.IdentitySet([])
else:
- s = util.Set(original)
- self._added_items = []
- self._unchanged_items = []
- self._deleted_items = []
+ s = util.IdentitySet(original)
+
+ # FIXME: the tests have an assumption on the collection's ordering
+ self._added_items = util.OrderedIdentitySet()
+ self._unchanged_items = util.OrderedIdentitySet()
+ self._deleted_items = util.OrderedIdentitySet()
if current:
collection = attr.get_collection(state, current)
for a in collection:
if a in s:
- self._unchanged_items.append(a)
+ self._unchanged_items.add(a)
else:
- self._added_items.append(a)
+ self._added_items.add(a)
for a in s:
if a not in self._unchanged_items:
- self._deleted_items.append(a)
+ self._deleted_items.add(a)
else:
self._current = [current]
if attr.is_equal(current, original) is True:
@@ -853,13 +858,13 @@ class AttributeHistory(object):
return len(self._deleted_items) > 0 or len(self._added_items) > 0
def added_items(self):
- return self._added_items
+ return list(self._added_items)
def unchanged_items(self):
- return self._unchanged_items
+ return list(self._unchanged_items)
def deleted_items(self):
- return self._deleted_items
+ return list(self._deleted_items)
class AttributeManager(object):
"""Allow the instrumentation of object attributes."""
diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py
index bf365d267..9e6b0ce75 100644
--- a/lib/sqlalchemy/orm/collections.py
+++ b/lib/sqlalchemy/orm/collections.py
@@ -793,6 +793,7 @@ def _list_decorators():
def remove(fn):
def remove(self, value, _sa_initiator=None):
+ # testlib.pragma exempt:__eq__
fn(self, value)
__del(self, value, _sa_initiator)
_tidy(remove)
@@ -1002,22 +1003,27 @@ def _set_decorators():
def add(fn):
def add(self, value, _sa_initiator=None):
__set(self, value, _sa_initiator)
+ # testlib.pragma exempt:__hash__
fn(self, value)
_tidy(add)
return add
def discard(fn):
def discard(self, value, _sa_initiator=None):
+ # testlib.pragma exempt:__hash__
if value in self:
__del(self, value, _sa_initiator)
+ # testlib.pragma exempt:__hash__
fn(self, value)
_tidy(discard)
return discard
def remove(fn):
def remove(self, value, _sa_initiator=None):
+ # testlib.pragma exempt:__hash__
if value in self:
__del(self, value, _sa_initiator)
+ # testlib.pragma exempt:__hash__
fn(self, value)
_tidy(remove)
return remove
diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py
index a1669e32f..f771dc5d7 100644
--- a/lib/sqlalchemy/orm/dependency.py
+++ b/lib/sqlalchemy/orm/dependency.py
@@ -345,29 +345,29 @@ class ManyToManyDP(DependencyProcessor):
childlist = self.get_object_dependencies(obj, uowcommit, passive=self.passive_deletes)
if childlist is not None:
for child in childlist.deleted_items() + childlist.unchanged_items():
- if child is None or (reverse_dep and (reverse_dep, "manytomany", child, obj) in uowcommit.attributes):
+ if child is None or (reverse_dep and (reverse_dep, "manytomany", id(child), id(obj)) in uowcommit.attributes):
continue
associationrow = {}
self._synchronize(obj, child, associationrow, False, uowcommit)
secondary_delete.append(associationrow)
- uowcommit.attributes[(self, "manytomany", obj, child)] = True
+ uowcommit.attributes[(self, "manytomany", id(obj), id(child))] = True
else:
for obj in deplist:
childlist = self.get_object_dependencies(obj, uowcommit)
if childlist is None: continue
for child in childlist.added_items():
- if child is None or (reverse_dep and (reverse_dep, "manytomany", child, obj) in uowcommit.attributes):
+ if child is None or (reverse_dep and (reverse_dep, "manytomany", id(child), id(obj)) in uowcommit.attributes):
continue
associationrow = {}
self._synchronize(obj, child, associationrow, False, uowcommit)
- uowcommit.attributes[(self, "manytomany", obj, child)] = True
+ uowcommit.attributes[(self, "manytomany", id(obj), id(child))] = True
secondary_insert.append(associationrow)
for child in childlist.deleted_items():
- if child is None or (reverse_dep and (reverse_dep, "manytomany", child, obj) in uowcommit.attributes):
+ if child is None or (reverse_dep and (reverse_dep, "manytomany", id(child), id(obj)) in uowcommit.attributes):
continue
associationrow = {}
self._synchronize(obj, child, associationrow, False, uowcommit)
- uowcommit.attributes[(self, "manytomany", obj, child)] = True
+ uowcommit.attributes[(self, "manytomany", id(obj), id(child))] = True
secondary_delete.append(associationrow)
if secondary_delete:
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 73c8321fc..efc509725 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -1099,7 +1099,8 @@ class Mapper(object):
c = connection.execute(statement.values(value_params), params)
mapper._postfetch(connection, table, obj, c, c.last_updated_params(), value_params)
- updated_objects.add((obj, connection))
+ # testlib.pragma exempt:__hash__
+ updated_objects.add((id(obj), obj, connection))
rows += c.rowcount
if c.supports_sane_rowcount() and rows != len(update):
@@ -1134,13 +1135,14 @@ class Mapper(object):
mapper._synchronizer.execute(obj, obj)
sync(mapper)
- inserted_objects.add((obj, connection))
+ # testlib.pragma exempt:__hash__
+ inserted_objects.add((id(obj), obj, connection))
if not postupdate:
- for obj, connection in inserted_objects:
+ for id_, obj, connection in inserted_objects:
for mapper in object_mapper(obj).iterate_to_root():
if 'after_insert' in mapper.extension.methods:
mapper.extension.after_insert(mapper, connection, obj)
- for obj, connection in updated_objects:
+ for id_, obj, connection in updated_objects:
for mapper in object_mapper(obj).iterate_to_root():
if 'after_update' in mapper.extension.methods:
mapper.extension.after_update(mapper, connection, obj)
@@ -1194,7 +1196,7 @@ class Mapper(object):
for mapper in object_mapper(obj).iterate_to_root():
if 'before_delete' in mapper.extension.methods:
mapper.extension.before_delete(mapper, connection, obj)
-
+
deleted_objects = util.Set()
table_to_mapper = {}
for mapper in self.base_mapper.polymorphic_iterator():
@@ -1217,7 +1219,8 @@ class Mapper(object):
params[col.key] = mapper.get_attr_by_column(obj, col)
if mapper.version_id_col is not None:
params[mapper.version_id_col.key] = mapper.get_attr_by_column(obj, mapper.version_id_col)
- deleted_objects.add((obj, connection))
+ # testlib.pragma exempt:__hash__
+ deleted_objects.add((id(obj), obj, connection))
for connection, del_objects in delete.iteritems():
mapper = table_to_mapper[table]
def comparator(a, b):
@@ -1237,7 +1240,7 @@ class Mapper(object):
if c.supports_sane_multi_rowcount() and c.rowcount != len(del_objects):
raise exceptions.ConcurrentModificationError("Deleted rowcount %d does not match number of objects deleted %d" % (c.rowcount, len(del_objects)))
- for obj, connection in deleted_objects:
+ for id_, obj, connection in deleted_objects:
for mapper in object_mapper(obj).iterate_to_root():
if 'after_delete' in mapper.extension.methods:
mapper.extension.after_delete(mapper, connection, obj)
@@ -1284,7 +1287,7 @@ class Mapper(object):
"""
if recursive is None:
- recursive=util.Set()
+ recursive=util.IdentitySet()
for prop in self.__props.values():
for c in prop.cascade_iterator(type, object, recursive, halt_on=halt_on):
yield c
@@ -1310,7 +1313,7 @@ class Mapper(object):
"""
if recursive is None:
- recursive=util.Set()
+ recursive=util.IdentitySet()
for prop in self.__props.values():
prop.cascade_callable(type, object, callable_, recursive, halt_on=halt_on)
@@ -1516,7 +1519,7 @@ class Mapper(object):
selectcontext.exec_with_path(self, key, populator, instance, row, ispostselect=ispostselect, isnew=isnew, **flags)
if self.non_primary:
- selectcontext.attributes[('populating_mapper', instance)] = self
+ selectcontext.attributes[('populating_mapper', id(instance))] = self
def _post_instance(self, selectcontext, instance):
post_processors = selectcontext.attributes[('post_processors', self, None)]
@@ -1577,6 +1580,15 @@ def has_mapper(object):
return hasattr(object, '_entity_name')
+def identity_equal(a, b):
+ if a is b:
+ return True
+ id_a = getattr(a, '_instance_key', None)
+ id_b = getattr(b, '_instance_key', None)
+ if id_a is None or id_b is None:
+ return False
+ return id_a == id_b
+
def object_mapper(object, entity_name=None, raiseerror=True):
"""Given an object, return the primary Mapper associated with the object instance.
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index bec05a43f..09a3a0f5b 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -691,7 +691,7 @@ class Query(object):
proc[0](context, row)
for instance in context.identity_map.values():
- context.attributes.get(('populating_mapper', instance), object_mapper(instance))._post_instance(context, instance)
+ context.attributes.get(('populating_mapper', id(instance)), object_mapper(instance))._post_instance(context, instance)
# store new stuff in the identity map
for instance in context.identity_map.values():
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index f64330289..b699bfee5 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -611,8 +611,8 @@ class EagerLoader(AbstractRelationLoader):
appender = util.UniqueAppender(collection, 'append_without_event')
# store it in the "scratch" area, which is local to this load operation.
- selectcontext.attributes[(instance, self.key)] = appender
- result_list = selectcontext.attributes[(instance, self.key)]
+ selectcontext.attributes[('appender', id(instance), self.key)] = appender
+ result_list = selectcontext.attributes[('appender', id(instance), self.key)]
if self._should_log_debug:
self.logger.debug("eagerload list instance on %s" % mapperutil.attribute_str(instance, self.key))
diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py
index 0ce354d6f..7f9a4d7d0 100644
--- a/lib/sqlalchemy/orm/unitofwork.py
+++ b/lib/sqlalchemy/orm/unitofwork.py
@@ -93,8 +93,8 @@ class UnitOfWork(object):
else:
self.identity_map = {}
- self.new = util.Set() #OrderedSet()
- self.deleted = util.Set()
+ self.new = util.IdentitySet() #OrderedSet()
+ self.deleted = util.IdentitySet()
self.logger = logging.instance_logger(self, echoflag=session.echo_uow)
def _remove_deleted(self, obj):
@@ -150,7 +150,7 @@ class UnitOfWork(object):
"""
# a little bit of inlining for speed
- return util.Set([x for x in self.identity_map.values()
+ return util.IdentitySet([x for x in self.identity_map.values()
if x not in self.deleted
and (
x._state.modified
@@ -180,13 +180,13 @@ class UnitOfWork(object):
# create the set of all objects we want to operate upon
if objects is not None:
# specific list passed in
- objset = util.Set(objects)
+ objset = util.IdentitySet(objects)
else:
# or just everything
- objset = util.Set(self.identity_map.values()).union(self.new)
+ objset = util.IdentitySet(self.identity_map.values()).union(self.new)
# store objects whose fate has been decided
- processed = util.Set()
+ processed = util.IdentitySet()
# put all saves/updates into the flush context. detect top-level orphans and throw them into deleted.
for obj in self.new.union(dirty).intersection(objset).difference(self.deleted):
@@ -305,7 +305,7 @@ class UOWTransaction(object):
"""
mapper = object_mapper(obj)
task = self.get_task_by_mapper(mapper)
- taskelement = task._objects[obj]
+ taskelement = task._objects[id(obj)]
taskelement.isdelete = "rowswitch"
def unregister_object(self, obj):
@@ -315,7 +315,7 @@ class UOWTransaction(object):
no further operations occur upon the instance."""
mapper = object_mapper(obj)
task = self.get_task_by_mapper(mapper)
- if obj in task._objects:
+ if id(obj) in task._objects:
task.delete(obj)
def is_deleted(self, obj):
@@ -615,11 +615,11 @@ class UOWTask(object):
"""
try:
- rec = self._objects[obj]
+ rec = self._objects[id(obj)]
retval = False
except KeyError:
rec = UOWTaskElement(obj)
- self._objects[obj] = rec
+ self._objects[id(obj)] = rec
retval = True
if not listonly:
rec.listonly = False
@@ -646,7 +646,7 @@ class UOWTask(object):
"""remove the given object from this UOWTask, if present."""
try:
- del self._objects[obj]
+ del self._objects[id(obj)]
except KeyError:
pass
@@ -654,7 +654,7 @@ class UOWTask(object):
"""return True if the given object is contained within this UOWTask or inheriting tasks."""
for task in self.polymorphic_tasks():
- if obj in task._objects:
+ if id(obj) in task._objects:
return True
else:
return False
@@ -663,7 +663,7 @@ class UOWTask(object):
"""return True if the given object is marked as to be deleted within this UOWTask."""
try:
- return self._objects[obj].isdelete
+ return self._objects[id(obj)].isdelete
except KeyError:
return False
@@ -735,9 +735,9 @@ class UOWTask(object):
def get_dependency_task(obj, depprocessor):
try:
- dp = dependencies[obj]
+ dp = dependencies[id(obj)]
except KeyError:
- dp = dependencies.setdefault(obj, {})
+ dp = dependencies.setdefault(id(obj), {})
try:
l = dp[depprocessor]
except KeyError:
@@ -766,7 +766,7 @@ class UOWTask(object):
for subtask in task.polymorphic_tasks():
for taskelement in subtask.elements:
obj = taskelement.obj
- object_to_original_task[obj] = subtask
+ object_to_original_task[id(obj)] = subtask
for dep in deps_by_targettask.get(subtask, []):
# is this dependency involved in one of the cycles ?
if not dependency_in_cycles(dep):
@@ -795,7 +795,7 @@ class UOWTask(object):
# task
if o not in childtask:
childtask.append(o, listonly=True)
- object_to_original_task[o] = childtask
+ object_to_original_task[id(o)] = childtask
# create a tuple representing the "parent/child"
whosdep = dep.whose_dependent_on_who(obj, o)
@@ -821,17 +821,17 @@ class UOWTask(object):
used_tasks = util.Set()
def make_task_tree(node, parenttask, nexttasks):
- originating_task = object_to_original_task[node.item]
+ originating_task = object_to_original_task[id(node.item)]
used_tasks.add(originating_task)
t = nexttasks.get(originating_task, None)
if t is None:
t = UOWTask(self.uowtransaction, originating_task.mapper)
nexttasks[originating_task] = t
- parenttask.append(None, listonly=False, isdelete=originating_task._objects[node.item].isdelete, childtask=t)
- t.append(node.item, originating_task._objects[node.item].listonly, isdelete=originating_task._objects[node.item].isdelete)
+ parenttask.append(None, listonly=False, isdelete=originating_task._objects[id(node.item)].isdelete, childtask=t)
+ t.append(node.item, originating_task._objects[id(node.item)].listonly, isdelete=originating_task._objects[id(node.item)].isdelete)
- if node.item in dependencies:
- for depprocessor, deptask in dependencies[node.item].iteritems():
+ if id(node.item) in dependencies:
+ for depprocessor, deptask in dependencies[id(node.item)].iteritems():
t.cyclical_dependencies.add(depprocessor.branch(deptask))
nd = {}
for n in node.children:
@@ -861,7 +861,7 @@ class UOWTask(object):
# or "delete" members due to inheriting mappers which contain tasks
localtask = UOWTask(self.uowtransaction, t2.mapper)
for obj in t2.elements:
- localtask.append(obj, t2.listonly, isdelete=t2._objects[obj].isdelete)
+ localtask.append(obj, t2.listonly, isdelete=t2._objects[id(obj)].isdelete)
for dep in t2.dependencies:
localtask._dependencies.add(dep)
t.childtasks.insert(0, localtask)
diff --git a/lib/sqlalchemy/topological.py b/lib/sqlalchemy/topological.py
index d38c5cf4a..a47968519 100644
--- a/lib/sqlalchemy/topological.py
+++ b/lib/sqlalchemy/topological.py
@@ -153,20 +153,20 @@ class QueueDependencySorter(object):
nodes = {}
edges = _EdgeCollection()
for item in allitems + [t[0] for t in tuples] + [t[1] for t in tuples]:
- if item not in nodes:
+ if id(item) not in nodes:
node = _Node(item)
- nodes[item] = node
+ nodes[id(item)] = node
for t in tuples:
if t[0] is t[1]:
if allow_self_cycles:
- n = nodes[t[0]]
+ n = nodes[id(t[0])]
n.cycles = util.Set([n])
continue
else:
raise CircularDependencyError("Self-referential dependency detected " + repr(t))
- childnode = nodes[t[1]]
- parentnode = nodes[t[0]]
+ childnode = nodes[id(t[1])]
+ parentnode = nodes[id(t[0])]
edges.add((parentnode, childnode))
queue = []
@@ -202,7 +202,7 @@ class QueueDependencySorter(object):
node = queue.pop()
if not hasattr(node, '_cyclical'):
output.append(node)
- del nodes[node.item]
+ del nodes[id(node.item)]
for childnode in edges.pop_node(node):
queue.append(childnode)
return self._create_batched_tree(output)
diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py
index a4ccaac6a..9ad7e113c 100644
--- a/lib/sqlalchemy/util.py
+++ b/lib/sqlalchemy/util.py
@@ -620,6 +620,7 @@ class IdentitySet(object):
def union(self, iterable):
result = type(self)()
+ # testlib.pragma exempt:__hash__
result._members.update(
Set(self._members.iteritems()).union(_iter_id(iterable)))
return result
@@ -641,6 +642,7 @@ class IdentitySet(object):
def difference(self, iterable):
result = type(self)()
+ # testlib.pragma exempt:__hash__
result._members.update(
Set(self._members.iteritems()).difference(_iter_id(iterable)))
return result
@@ -662,6 +664,7 @@ class IdentitySet(object):
def intersection(self, iterable):
result = type(self)()
+ # testlib.pragma exempt:__hash__
result._members.update(
Set(self._members.iteritems()).intersection(_iter_id(iterable)))
return result
@@ -683,6 +686,7 @@ class IdentitySet(object):
def symmetric_difference(self, iterable):
result = type(self)()
+ # testlib.pragma exempt:__hash__
result._members.update(
Set(self._members.iteritems()).symmetric_difference(_iter_id(iterable)))
return result
@@ -725,13 +729,25 @@ def _iter_id(iterable):
yield id(item), item
+class OrderedIdentitySet(IdentitySet):
+ def __init__(self, iterable=None):
+ IdentitySet.__init__(self)
+ self._members = OrderedDict()
+ if iterable:
+ for o in iterable:
+ self.add(o)
+
+
class UniqueAppender(object):
- """appends items to a collection such that only unique items
- are added."""
+ """Only adds items to a collection once.
+
+ Additional appends() of the same object are ignored. Membership is
+ determined by identity (``is a``) not equality (``==``).
+ """
def __init__(self, data, via=None):
self.data = data
- self._unique = Set()
+ self._unique = IdentitySet()
if via:
self._data_appender = getattr(data, via)
elif hasattr(data, 'append'):