diff options
author | Jason Kirtland <jek@discorporate.us> | 2007-11-03 20:23:26 +0000 |
---|---|---|
committer | Jason Kirtland <jek@discorporate.us> | 2007-11-03 20:23:26 +0000 |
commit | 429e69db67baa8fc93ff2b55361ba2831cc26144 (patch) | |
tree | d513a330fff0a9ff6313e78f28d8f12a7d399ade /lib | |
parent | 4210a1ef236e0fbc65878dd1a1ddcc8e13d43c45 (diff) | |
download | sqlalchemy-429e69db67baa8fc93ff2b55361ba2831cc26144.tar.gz |
- Removed equality, truth and hash() testing of mapped instances. Mapped
classes can now implement arbitrary __eq__ and friends. [ticket:676]
Diffstat (limited to 'lib')
-rw-r--r-- | lib/sqlalchemy/ext/associationproxy.py | 3 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/attributes.py | 29 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/collections.py | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/dependency.py | 12 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 32 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/query.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/strategies.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/unitofwork.py | 46 | ||||
-rw-r--r-- | lib/sqlalchemy/topological.py | 12 | ||||
-rw-r--r-- | lib/sqlalchemy/util.py | 22 |
10 files changed, 105 insertions, 63 deletions
diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index 0ee59e369..472bd1b2c 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -326,6 +326,7 @@ class _AssociationList(object): def __contains__(self, value): for member in self.col: + # testlib.pragma exempt:__eq__ if self._get(member) == value: return True return False @@ -473,6 +474,7 @@ class _AssociationDict(object): del self.col[key] def __contains__(self, key): + # testlib.pragma exempt:__hash__ return key in self.col has_key = __contains__ @@ -609,6 +611,7 @@ class _AssociationSet(object): def __contains__(self, value): for member in self.col: + # testlib.pragma exempt:__eq__ if self._get(member) == value: return True return False diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 189cd52ee..a340394b9 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -8,7 +8,7 @@ import weakref, threading import UserDict from sqlalchemy import util from sqlalchemy.orm import interfaces, collections -from sqlalchemy.orm.mapper import class_mapper +from sqlalchemy.orm.mapper import class_mapper, identity_equal from sqlalchemy import exceptions @@ -369,6 +369,8 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): super(ScalarObjectAttributeImpl, self).__init__(class_, manager, key, callable_, trackparent=trackparent, extension=extension, compare_function=compare_function, mutable_scalars=mutable_scalars, **kwargs) + if compare_function is None: + self.is_equal = identity_equal def delete(self, state): old = self.get(state) @@ -815,23 +817,26 @@ class AttributeHistory(object): if hasattr(attr, 'get_collection'): self._current = current + if original is NO_VALUE: - s = util.Set([]) + s = util.IdentitySet([]) else: - s = util.Set(original) - self._added_items = [] - self._unchanged_items = [] - self._deleted_items = [] + s = util.IdentitySet(original) + + # FIXME: the tests have an assumption on the collection's ordering + self._added_items = util.OrderedIdentitySet() + self._unchanged_items = util.OrderedIdentitySet() + self._deleted_items = util.OrderedIdentitySet() if current: collection = attr.get_collection(state, current) for a in collection: if a in s: - self._unchanged_items.append(a) + self._unchanged_items.add(a) else: - self._added_items.append(a) + self._added_items.add(a) for a in s: if a not in self._unchanged_items: - self._deleted_items.append(a) + self._deleted_items.add(a) else: self._current = [current] if attr.is_equal(current, original) is True: @@ -853,13 +858,13 @@ class AttributeHistory(object): return len(self._deleted_items) > 0 or len(self._added_items) > 0 def added_items(self): - return self._added_items + return list(self._added_items) def unchanged_items(self): - return self._unchanged_items + return list(self._unchanged_items) def deleted_items(self): - return self._deleted_items + return list(self._deleted_items) class AttributeManager(object): """Allow the instrumentation of object attributes.""" diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index bf365d267..9e6b0ce75 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -793,6 +793,7 @@ def _list_decorators(): def remove(fn): def remove(self, value, _sa_initiator=None): + # testlib.pragma exempt:__eq__ fn(self, value) __del(self, value, _sa_initiator) _tidy(remove) @@ -1002,22 +1003,27 @@ def _set_decorators(): def add(fn): def add(self, value, _sa_initiator=None): __set(self, value, _sa_initiator) + # testlib.pragma exempt:__hash__ fn(self, value) _tidy(add) return add def discard(fn): def discard(self, value, _sa_initiator=None): + # testlib.pragma exempt:__hash__ if value in self: __del(self, value, _sa_initiator) + # testlib.pragma exempt:__hash__ fn(self, value) _tidy(discard) return discard def remove(fn): def remove(self, value, _sa_initiator=None): + # testlib.pragma exempt:__hash__ if value in self: __del(self, value, _sa_initiator) + # testlib.pragma exempt:__hash__ fn(self, value) _tidy(remove) return remove diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py index a1669e32f..f771dc5d7 100644 --- a/lib/sqlalchemy/orm/dependency.py +++ b/lib/sqlalchemy/orm/dependency.py @@ -345,29 +345,29 @@ class ManyToManyDP(DependencyProcessor): childlist = self.get_object_dependencies(obj, uowcommit, passive=self.passive_deletes) if childlist is not None: for child in childlist.deleted_items() + childlist.unchanged_items(): - if child is None or (reverse_dep and (reverse_dep, "manytomany", child, obj) in uowcommit.attributes): + if child is None or (reverse_dep and (reverse_dep, "manytomany", id(child), id(obj)) in uowcommit.attributes): continue associationrow = {} self._synchronize(obj, child, associationrow, False, uowcommit) secondary_delete.append(associationrow) - uowcommit.attributes[(self, "manytomany", obj, child)] = True + uowcommit.attributes[(self, "manytomany", id(obj), id(child))] = True else: for obj in deplist: childlist = self.get_object_dependencies(obj, uowcommit) if childlist is None: continue for child in childlist.added_items(): - if child is None or (reverse_dep and (reverse_dep, "manytomany", child, obj) in uowcommit.attributes): + if child is None or (reverse_dep and (reverse_dep, "manytomany", id(child), id(obj)) in uowcommit.attributes): continue associationrow = {} self._synchronize(obj, child, associationrow, False, uowcommit) - uowcommit.attributes[(self, "manytomany", obj, child)] = True + uowcommit.attributes[(self, "manytomany", id(obj), id(child))] = True secondary_insert.append(associationrow) for child in childlist.deleted_items(): - if child is None or (reverse_dep and (reverse_dep, "manytomany", child, obj) in uowcommit.attributes): + if child is None or (reverse_dep and (reverse_dep, "manytomany", id(child), id(obj)) in uowcommit.attributes): continue associationrow = {} self._synchronize(obj, child, associationrow, False, uowcommit) - uowcommit.attributes[(self, "manytomany", obj, child)] = True + uowcommit.attributes[(self, "manytomany", id(obj), id(child))] = True secondary_delete.append(associationrow) if secondary_delete: diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 73c8321fc..efc509725 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1099,7 +1099,8 @@ class Mapper(object): c = connection.execute(statement.values(value_params), params) mapper._postfetch(connection, table, obj, c, c.last_updated_params(), value_params) - updated_objects.add((obj, connection)) + # testlib.pragma exempt:__hash__ + updated_objects.add((id(obj), obj, connection)) rows += c.rowcount if c.supports_sane_rowcount() and rows != len(update): @@ -1134,13 +1135,14 @@ class Mapper(object): mapper._synchronizer.execute(obj, obj) sync(mapper) - inserted_objects.add((obj, connection)) + # testlib.pragma exempt:__hash__ + inserted_objects.add((id(obj), obj, connection)) if not postupdate: - for obj, connection in inserted_objects: + for id_, obj, connection in inserted_objects: for mapper in object_mapper(obj).iterate_to_root(): if 'after_insert' in mapper.extension.methods: mapper.extension.after_insert(mapper, connection, obj) - for obj, connection in updated_objects: + for id_, obj, connection in updated_objects: for mapper in object_mapper(obj).iterate_to_root(): if 'after_update' in mapper.extension.methods: mapper.extension.after_update(mapper, connection, obj) @@ -1194,7 +1196,7 @@ class Mapper(object): for mapper in object_mapper(obj).iterate_to_root(): if 'before_delete' in mapper.extension.methods: mapper.extension.before_delete(mapper, connection, obj) - + deleted_objects = util.Set() table_to_mapper = {} for mapper in self.base_mapper.polymorphic_iterator(): @@ -1217,7 +1219,8 @@ class Mapper(object): params[col.key] = mapper.get_attr_by_column(obj, col) if mapper.version_id_col is not None: params[mapper.version_id_col.key] = mapper.get_attr_by_column(obj, mapper.version_id_col) - deleted_objects.add((obj, connection)) + # testlib.pragma exempt:__hash__ + deleted_objects.add((id(obj), obj, connection)) for connection, del_objects in delete.iteritems(): mapper = table_to_mapper[table] def comparator(a, b): @@ -1237,7 +1240,7 @@ class Mapper(object): if c.supports_sane_multi_rowcount() and c.rowcount != len(del_objects): raise exceptions.ConcurrentModificationError("Deleted rowcount %d does not match number of objects deleted %d" % (c.rowcount, len(del_objects))) - for obj, connection in deleted_objects: + for id_, obj, connection in deleted_objects: for mapper in object_mapper(obj).iterate_to_root(): if 'after_delete' in mapper.extension.methods: mapper.extension.after_delete(mapper, connection, obj) @@ -1284,7 +1287,7 @@ class Mapper(object): """ if recursive is None: - recursive=util.Set() + recursive=util.IdentitySet() for prop in self.__props.values(): for c in prop.cascade_iterator(type, object, recursive, halt_on=halt_on): yield c @@ -1310,7 +1313,7 @@ class Mapper(object): """ if recursive is None: - recursive=util.Set() + recursive=util.IdentitySet() for prop in self.__props.values(): prop.cascade_callable(type, object, callable_, recursive, halt_on=halt_on) @@ -1516,7 +1519,7 @@ class Mapper(object): selectcontext.exec_with_path(self, key, populator, instance, row, ispostselect=ispostselect, isnew=isnew, **flags) if self.non_primary: - selectcontext.attributes[('populating_mapper', instance)] = self + selectcontext.attributes[('populating_mapper', id(instance))] = self def _post_instance(self, selectcontext, instance): post_processors = selectcontext.attributes[('post_processors', self, None)] @@ -1577,6 +1580,15 @@ def has_mapper(object): return hasattr(object, '_entity_name') +def identity_equal(a, b): + if a is b: + return True + id_a = getattr(a, '_instance_key', None) + id_b = getattr(b, '_instance_key', None) + if id_a is None or id_b is None: + return False + return id_a == id_b + def object_mapper(object, entity_name=None, raiseerror=True): """Given an object, return the primary Mapper associated with the object instance. diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index bec05a43f..09a3a0f5b 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -691,7 +691,7 @@ class Query(object): proc[0](context, row) for instance in context.identity_map.values(): - context.attributes.get(('populating_mapper', instance), object_mapper(instance))._post_instance(context, instance) + context.attributes.get(('populating_mapper', id(instance)), object_mapper(instance))._post_instance(context, instance) # store new stuff in the identity map for instance in context.identity_map.values(): diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index f64330289..b699bfee5 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -611,8 +611,8 @@ class EagerLoader(AbstractRelationLoader): appender = util.UniqueAppender(collection, 'append_without_event') # store it in the "scratch" area, which is local to this load operation. - selectcontext.attributes[(instance, self.key)] = appender - result_list = selectcontext.attributes[(instance, self.key)] + selectcontext.attributes[('appender', id(instance), self.key)] = appender + result_list = selectcontext.attributes[('appender', id(instance), self.key)] if self._should_log_debug: self.logger.debug("eagerload list instance on %s" % mapperutil.attribute_str(instance, self.key)) diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 0ce354d6f..7f9a4d7d0 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -93,8 +93,8 @@ class UnitOfWork(object): else: self.identity_map = {} - self.new = util.Set() #OrderedSet() - self.deleted = util.Set() + self.new = util.IdentitySet() #OrderedSet() + self.deleted = util.IdentitySet() self.logger = logging.instance_logger(self, echoflag=session.echo_uow) def _remove_deleted(self, obj): @@ -150,7 +150,7 @@ class UnitOfWork(object): """ # a little bit of inlining for speed - return util.Set([x for x in self.identity_map.values() + return util.IdentitySet([x for x in self.identity_map.values() if x not in self.deleted and ( x._state.modified @@ -180,13 +180,13 @@ class UnitOfWork(object): # create the set of all objects we want to operate upon if objects is not None: # specific list passed in - objset = util.Set(objects) + objset = util.IdentitySet(objects) else: # or just everything - objset = util.Set(self.identity_map.values()).union(self.new) + objset = util.IdentitySet(self.identity_map.values()).union(self.new) # store objects whose fate has been decided - processed = util.Set() + processed = util.IdentitySet() # put all saves/updates into the flush context. detect top-level orphans and throw them into deleted. for obj in self.new.union(dirty).intersection(objset).difference(self.deleted): @@ -305,7 +305,7 @@ class UOWTransaction(object): """ mapper = object_mapper(obj) task = self.get_task_by_mapper(mapper) - taskelement = task._objects[obj] + taskelement = task._objects[id(obj)] taskelement.isdelete = "rowswitch" def unregister_object(self, obj): @@ -315,7 +315,7 @@ class UOWTransaction(object): no further operations occur upon the instance.""" mapper = object_mapper(obj) task = self.get_task_by_mapper(mapper) - if obj in task._objects: + if id(obj) in task._objects: task.delete(obj) def is_deleted(self, obj): @@ -615,11 +615,11 @@ class UOWTask(object): """ try: - rec = self._objects[obj] + rec = self._objects[id(obj)] retval = False except KeyError: rec = UOWTaskElement(obj) - self._objects[obj] = rec + self._objects[id(obj)] = rec retval = True if not listonly: rec.listonly = False @@ -646,7 +646,7 @@ class UOWTask(object): """remove the given object from this UOWTask, if present.""" try: - del self._objects[obj] + del self._objects[id(obj)] except KeyError: pass @@ -654,7 +654,7 @@ class UOWTask(object): """return True if the given object is contained within this UOWTask or inheriting tasks.""" for task in self.polymorphic_tasks(): - if obj in task._objects: + if id(obj) in task._objects: return True else: return False @@ -663,7 +663,7 @@ class UOWTask(object): """return True if the given object is marked as to be deleted within this UOWTask.""" try: - return self._objects[obj].isdelete + return self._objects[id(obj)].isdelete except KeyError: return False @@ -735,9 +735,9 @@ class UOWTask(object): def get_dependency_task(obj, depprocessor): try: - dp = dependencies[obj] + dp = dependencies[id(obj)] except KeyError: - dp = dependencies.setdefault(obj, {}) + dp = dependencies.setdefault(id(obj), {}) try: l = dp[depprocessor] except KeyError: @@ -766,7 +766,7 @@ class UOWTask(object): for subtask in task.polymorphic_tasks(): for taskelement in subtask.elements: obj = taskelement.obj - object_to_original_task[obj] = subtask + object_to_original_task[id(obj)] = subtask for dep in deps_by_targettask.get(subtask, []): # is this dependency involved in one of the cycles ? if not dependency_in_cycles(dep): @@ -795,7 +795,7 @@ class UOWTask(object): # task if o not in childtask: childtask.append(o, listonly=True) - object_to_original_task[o] = childtask + object_to_original_task[id(o)] = childtask # create a tuple representing the "parent/child" whosdep = dep.whose_dependent_on_who(obj, o) @@ -821,17 +821,17 @@ class UOWTask(object): used_tasks = util.Set() def make_task_tree(node, parenttask, nexttasks): - originating_task = object_to_original_task[node.item] + originating_task = object_to_original_task[id(node.item)] used_tasks.add(originating_task) t = nexttasks.get(originating_task, None) if t is None: t = UOWTask(self.uowtransaction, originating_task.mapper) nexttasks[originating_task] = t - parenttask.append(None, listonly=False, isdelete=originating_task._objects[node.item].isdelete, childtask=t) - t.append(node.item, originating_task._objects[node.item].listonly, isdelete=originating_task._objects[node.item].isdelete) + parenttask.append(None, listonly=False, isdelete=originating_task._objects[id(node.item)].isdelete, childtask=t) + t.append(node.item, originating_task._objects[id(node.item)].listonly, isdelete=originating_task._objects[id(node.item)].isdelete) - if node.item in dependencies: - for depprocessor, deptask in dependencies[node.item].iteritems(): + if id(node.item) in dependencies: + for depprocessor, deptask in dependencies[id(node.item)].iteritems(): t.cyclical_dependencies.add(depprocessor.branch(deptask)) nd = {} for n in node.children: @@ -861,7 +861,7 @@ class UOWTask(object): # or "delete" members due to inheriting mappers which contain tasks localtask = UOWTask(self.uowtransaction, t2.mapper) for obj in t2.elements: - localtask.append(obj, t2.listonly, isdelete=t2._objects[obj].isdelete) + localtask.append(obj, t2.listonly, isdelete=t2._objects[id(obj)].isdelete) for dep in t2.dependencies: localtask._dependencies.add(dep) t.childtasks.insert(0, localtask) diff --git a/lib/sqlalchemy/topological.py b/lib/sqlalchemy/topological.py index d38c5cf4a..a47968519 100644 --- a/lib/sqlalchemy/topological.py +++ b/lib/sqlalchemy/topological.py @@ -153,20 +153,20 @@ class QueueDependencySorter(object): nodes = {} edges = _EdgeCollection() for item in allitems + [t[0] for t in tuples] + [t[1] for t in tuples]: - if item not in nodes: + if id(item) not in nodes: node = _Node(item) - nodes[item] = node + nodes[id(item)] = node for t in tuples: if t[0] is t[1]: if allow_self_cycles: - n = nodes[t[0]] + n = nodes[id(t[0])] n.cycles = util.Set([n]) continue else: raise CircularDependencyError("Self-referential dependency detected " + repr(t)) - childnode = nodes[t[1]] - parentnode = nodes[t[0]] + childnode = nodes[id(t[1])] + parentnode = nodes[id(t[0])] edges.add((parentnode, childnode)) queue = [] @@ -202,7 +202,7 @@ class QueueDependencySorter(object): node = queue.pop() if not hasattr(node, '_cyclical'): output.append(node) - del nodes[node.item] + del nodes[id(node.item)] for childnode in edges.pop_node(node): queue.append(childnode) return self._create_batched_tree(output) diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index a4ccaac6a..9ad7e113c 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -620,6 +620,7 @@ class IdentitySet(object): def union(self, iterable): result = type(self)() + # testlib.pragma exempt:__hash__ result._members.update( Set(self._members.iteritems()).union(_iter_id(iterable))) return result @@ -641,6 +642,7 @@ class IdentitySet(object): def difference(self, iterable): result = type(self)() + # testlib.pragma exempt:__hash__ result._members.update( Set(self._members.iteritems()).difference(_iter_id(iterable))) return result @@ -662,6 +664,7 @@ class IdentitySet(object): def intersection(self, iterable): result = type(self)() + # testlib.pragma exempt:__hash__ result._members.update( Set(self._members.iteritems()).intersection(_iter_id(iterable))) return result @@ -683,6 +686,7 @@ class IdentitySet(object): def symmetric_difference(self, iterable): result = type(self)() + # testlib.pragma exempt:__hash__ result._members.update( Set(self._members.iteritems()).symmetric_difference(_iter_id(iterable))) return result @@ -725,13 +729,25 @@ def _iter_id(iterable): yield id(item), item +class OrderedIdentitySet(IdentitySet): + def __init__(self, iterable=None): + IdentitySet.__init__(self) + self._members = OrderedDict() + if iterable: + for o in iterable: + self.add(o) + + class UniqueAppender(object): - """appends items to a collection such that only unique items - are added.""" + """Only adds items to a collection once. + + Additional appends() of the same object are ignored. Membership is + determined by identity (``is a``) not equality (``==``). + """ def __init__(self, data, via=None): self.data = data - self._unique = Set() + self._unique = IdentitySet() if via: self._data_appender = getattr(data, via) elif hasattr(data, 'append'): |