summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2006-06-05 23:28:44 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2006-06-05 23:28:44 +0000
commita4252a12b0e1411cea7a636025ef9b97cb824f17 (patch)
tree3c3521a7df2c3102d95aae21173cddf53d9d4f89 /lib/sqlalchemy/orm
parent20563ddaddaff86235285f55a504bc8a43763776 (diff)
downloadsqlalchemy-a4252a12b0e1411cea7a636025ef9b97cb824f17.tar.gz
HashSet is gone; most sets now use the built-in set() on Python 2.4+, falling back to sets.Set on older versions.
Ordered set functionality is supplied by a subclass of sets.Set.
Diffstat (limited to 'lib/sqlalchemy/orm')
-rw-r--r--lib/sqlalchemy/orm/mapper.py26
-rw-r--r--lib/sqlalchemy/orm/properties.py6
-rw-r--r--lib/sqlalchemy/orm/session.py2
-rw-r--r--lib/sqlalchemy/orm/topological.py116
-rw-r--r--lib/sqlalchemy/orm/unitofwork.py50
-rw-r--r--lib/sqlalchemy/orm/util.py6
6 files changed, 45 insertions, 161 deletions
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index eba220384..64889b9a6 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -56,16 +56,16 @@ class Mapper(object):
# uber-pendantic style of making mapper chain, as various testbase/
# threadlocal/assignmapper combinations keep putting dupes etc. in the list
# TODO: do something that isnt 21 lines....
- extlist = util.HashSet()
+ extlist = util.Set()
for ext_class in global_extensions:
if isinstance(ext_class, MapperExtension):
- extlist.append(ext_class)
+ extlist.add(ext_class)
else:
- extlist.append(ext_class())
+ extlist.add(ext_class())
if extension is not None:
for ext_obj in util.to_list(extension):
- extlist.append(ext_obj)
+ extlist.add(ext_obj)
self.extension = None
previous = None
@@ -87,7 +87,7 @@ class Mapper(object):
self._options = {}
self.always_refresh = always_refresh
self.version_id_col = version_id_col
- self._inheriting_mappers = sets.Set()
+ self._inheriting_mappers = util.Set()
self.polymorphic_on = polymorphic_on
if polymorphic_map is None:
self.polymorphic_map = {}
@@ -146,7 +146,7 @@ class Mapper(object):
# stricter set of tables to create "sync rules" by,based on the immediate
# inherited table, rather than all inherited tables
self._synchronizer = sync.ClauseSynchronizer(self, self, sync.ONETOMANY)
- self._synchronizer.compile(self.mapped_table.onclause, util.HashSet([inherits.local_table]), sqlutil.TableFinder(self.local_table))
+ self._synchronizer.compile(self.mapped_table.onclause, util.Set([inherits.local_table]), sqlutil.TableFinder(self.local_table))
else:
self._synchronizer = None
self.mapped_table = self.local_table
@@ -182,19 +182,19 @@ class Mapper(object):
self.pks_by_table = {}
if primary_key is not None:
for k in primary_key:
- self.pks_by_table.setdefault(k.table, util.HashSet(ordered=True)).append(k)
+ self.pks_by_table.setdefault(k.table, util.OrderedSet()).add(k)
if k.table != self.mapped_table:
# associate pk cols from subtables to the "main" table
- self.pks_by_table.setdefault(self.mapped_table, util.HashSet(ordered=True)).append(k)
+ self.pks_by_table.setdefault(self.mapped_table, util.OrderedSet()).add(k)
# TODO: need local_table properly accounted for when custom primary key is sent
else:
for t in self.tables + [self.mapped_table]:
try:
l = self.pks_by_table[t]
except KeyError:
- l = self.pks_by_table.setdefault(t, util.HashSet(ordered=True))
+ l = self.pks_by_table.setdefault(t, util.OrderedSet())
for k in t.primary_key:
- l.append(k)
+ l.add(k)
if len(self.pks_by_table[self.mapped_table]) == 0:
raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name))
@@ -582,7 +582,7 @@ class Mapper(object):
params[col.key] = params[col._label] + 1
else:
params[col.key] = 1
- elif self.pks_by_table[table].contains(col):
+ elif col in self.pks_by_table[table]:
# column is a primary key ?
if not isinsert:
# doing an UPDATE? put primary key values as "WHERE" parameters
@@ -756,14 +756,14 @@ class Mapper(object):
def cascade_iterator(self, type, object, callable_=None, recursive=None):
if recursive is None:
- recursive=sets.Set()
+ recursive=util.Set()
for prop in self.props.values():
for c in prop.cascade_iterator(type, object, recursive):
yield c
def cascade_callable(self, type, object, callable_, recursive=None):
if recursive is None:
- recursive=sets.Set()
+ recursive=util.Set()
for prop in self.props.values():
prop.cascade_callable(type, object, callable_, recursive)
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 34529a136..0b609e300 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -271,7 +271,7 @@ class PropertyLoader(mapper.MapperProperty):
"""searches through the primary join condition to determine which side
has the foreign key - from this we return
the "foreign key" for this property which helps determine one-to-many/many-to-one."""
- foreignkeys = sets.Set()
+ foreignkeys = util.Set()
def foo(binary):
if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
return
@@ -305,8 +305,8 @@ class PropertyLoader(mapper.MapperProperty):
The list of rules is used within commits by the _synchronize() method when dependent
objects are processed."""
- parent_tables = util.HashSet(self.parent.tables + [self.parent.mapped_table])
- target_tables = util.HashSet(self.mapper.tables + [self.mapper.mapped_table])
+ parent_tables = util.Set(self.parent.tables + [self.parent.mapped_table])
+ target_tables = util.Set(self.mapper.tables + [self.mapper.mapped_table])
self.syncrules = sync.ClauseSynchronizer(self.parent, self.mapper, self.direction)
if self.direction == sync.MANYTOMANY:
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index bd1750165..1ba6d1a35 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -393,7 +393,7 @@ class Session(object):
def __contains__(self, obj):
return self._is_attached(obj) and (obj in self.uow.new or self.uow.has_key(obj._instance_key))
def __iter__(self):
- return iter(self.uow.new + self.uow.identity_map.values())
+ return iter(list(self.uow.new) + self.uow.identity_map.values())
def _get(self, key):
return self.uow._get(key)
def has_key(self, key):
diff --git a/lib/sqlalchemy/orm/topological.py b/lib/sqlalchemy/orm/topological.py
index 89e760039..d9ec5cde9 100644
--- a/lib/sqlalchemy/orm/topological.py
+++ b/lib/sqlalchemy/orm/topological.py
@@ -231,119 +231,3 @@ class QueueDependencySorter(object):
else:
return cycled_edges
-class TreeDependencySorter(object):
- """
- this is my first topological sorting algorithm. its crazy, but matched my thinking
- at the time. it also creates the kind of structure I want. but, I am not 100% sure
- it works in all cases since I always did really poorly in linear algebra. anyway,
- I got the other one above to produce a tree structure too so we should be OK.
- """
- class Node:
- """represents a node in a tree. stores an 'item' which represents the
- dependent thing we are talking about. if node 'a' is an ancestor node of
- node 'b', it means 'a's item is *not* dependent on that of 'b'."""
- def __init__(self, item):
- #print "new node on " + str(item)
- self.item = item
- self.children = HashSet()
- self.parent = None
- def append(self, node):
- """appends the given node as a child on this node. removes the node from
- its preexisting parent."""
- if node.parent is not None:
- del node.parent.children[node]
- self.children.append(node)
- node.parent = self
- def is_descendant_of(self, node):
- """returns true if this node is a descendant of the given node"""
- n = self
- while n is not None:
- if n is node:
- return True
- else:
- n = n.parent
- return False
- def get_root(self):
- """returns the highest ancestor node of this node, i.e. which has no parent"""
- n = self
- while n.parent is not None:
- n = n.parent
- return n
- def get_sibling_ancestor(self, node):
- """returns the node which is:
- - an ancestor of this node
- - is a sibling of the given node
- - not an ancestor of the given node
-
- - else returns this node's root node."""
- n = self
- while n.parent is not None and n.parent is not node.parent and not node.is_descendant_of(n.parent):
- n = n.parent
- return n
- def __str__(self):
- return self.safestr({})
- def safestr(self, hash, indent = 0):
- if hash.has_key(self):
- return (' ' * indent) + "RECURSIVE:%s(%s, %s)" % (str(self.item), repr(id(self)), self.parent and repr(id(self.parent)) or 'None')
- hash[self] = True
- return (' ' * indent) + "%s (idself=%s, idparent=%s)" % (str(self.item), repr(id(self)), self.parent and repr(id(self.parent)) or "None") + "\n" + string.join([n.safestr(hash, indent + 1) for n in self.children], '')
- def describe(self):
- return "%s (idself=%s)" % (str(self.item), repr(id(self)))
-
- def __init__(self, tuples, allitems):
- self.tuples = tuples
- self.allitems = allitems
-
- def sort(self):
- (tuples, allitems) = (self.tuples, self.allitems)
-
- nodes = {}
- # make nodes for all the items and store in the hash
- for item in allitems + [t[0] for t in tuples] + [t[1] for t in tuples]:
- if not nodes.has_key(item):
- nodes[item] = TreeDependencySorter.Node(item)
-
- # loop through tuples
- for tup in tuples:
- (parent, child) = (tup[0], tup[1])
- # get parent node
- parentnode = nodes[parent]
-
- # if parent is child, mark "circular" attribute on the node
- if parent is child:
- parentnode.circular = True
- # and just continue
- continue
-
- # get child node
- childnode = nodes[child]
-
- if parentnode.parent is childnode:
- # check for "a switch"
- t = parentnode.item
- parentnode.item = childnode.item
- childnode.item = t
- nodes[parentnode.item] = parentnode
- nodes[childnode.item] = childnode
- elif parentnode.is_descendant_of(childnode):
- # check for a line thats backwards with nodes in between, this is a
- # circular dependency (although confirmation on this would be helpful)
- raise FlushError("Circular dependency detected")
- elif not childnode.is_descendant_of(parentnode):
- # if relationship doesnt exist, connect nodes together
- root = childnode.get_sibling_ancestor(parentnode)
- parentnode.append(root)
-
-
- # now we have a collection of subtrees which represent dependencies.
- # go through the collection root nodes wire them together into one tree
- head = None
- for node in nodes.values():
- if node.parent is None:
- if head is not None:
- head.append(node)
- else:
- head = node
- #print str(head)
- return head
- \ No newline at end of file
diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py
index 9e9778cad..b8c6939f7 100644
--- a/lib/sqlalchemy/orm/unitofwork.py
+++ b/lib/sqlalchemy/orm/unitofwork.py
@@ -104,10 +104,10 @@ class UnitOfWork(object):
self.identity_map = weakref.WeakValueDictionary()
self.attributes = global_attributes
- self.new = util.HashSet(ordered = True)
- self.dirty = util.HashSet()
+ self.new = util.OrderedSet()
+ self.dirty = util.Set()
- self.deleted = util.HashSet()
+ self.deleted = util.Set()
def get(self, class_, *id):
"""given a class and a list of primary key values in their table-order, locates the mapper
@@ -149,15 +149,15 @@ class UnitOfWork(object):
if hasattr(obj, "_instance_key"):
del self.identity_map[obj._instance_key]
try:
- del self.deleted[obj]
+ self.deleted.remove(obj)
except KeyError:
pass
try:
- del self.dirty[obj]
+ self.dirty.remove(obj)
except KeyError:
pass
try:
- del self.new[obj]
+ self.new.remove(obj)
except KeyError:
pass
#self.attributes.commit(obj)
@@ -183,11 +183,11 @@ class UnitOfWork(object):
def register_clean(self, obj):
try:
- del self.dirty[obj]
+ self.dirty.remove(obj)
except KeyError:
pass
try:
- del self.new[obj]
+ self.new.remove(obj)
except KeyError:
pass
if not hasattr(obj, '_instance_key'):
@@ -199,26 +199,26 @@ class UnitOfWork(object):
def register_new(self, obj):
if hasattr(obj, '_instance_key'):
raise InvalidRequestError("Object '%s' already has an identity - it cant be registered as new" % repr(obj))
- if not self.new.contains(obj):
- self.new.append(obj)
+ if obj not in self.new:
+ self.new.add(obj)
self.unregister_deleted(obj)
def register_dirty(self, obj):
- if not self.dirty.contains(obj):
+ if obj not in self.dirty:
self._validate_obj(obj)
- self.dirty.append(obj)
+ self.dirty.add(obj)
self.unregister_deleted(obj)
def is_dirty(self, obj):
- if not self.dirty.contains(obj):
+ if obj not in self.dirty:
return False
else:
return True
def register_deleted(self, obj):
- if not self.deleted.contains(obj):
+ if obj not in self.deleted:
self._validate_obj(obj)
- self.deleted.append(obj)
+ self.deleted.add(obj)
def unregister_deleted(self, obj):
try:
@@ -230,14 +230,14 @@ class UnitOfWork(object):
flush_context = UOWTransaction(self, session)
if objects is not None:
- objset = sets.Set(objects)
+ objset = util.Set(objects)
else:
objset = None
for obj in [n for n in self.new] + [d for d in self.dirty]:
if objset is not None and not obj in objset:
continue
- if self.deleted.contains(obj):
+ if obj in self.deleted:
continue
flush_context.register_object(obj)
@@ -262,11 +262,11 @@ class UnitOfWork(object):
"""'rolls back' the attributes that have been changed on an object instance."""
self.attributes.rollback(obj)
try:
- del self.dirty[obj]
+ self.dirty.remove(obj)
except KeyError:
pass
try:
- del self.deleted[obj]
+ self.deleted.remove(obj)
except KeyError:
pass
@@ -277,7 +277,7 @@ class UOWTransaction(object):
self.uow = uow
self.session = session
# unique list of all the mappers we come across
- self.mappers = sets.Set()
+ self.mappers = util.Set()
self.dependencies = {}
self.tasks = {}
self.__modified = False
@@ -463,7 +463,7 @@ class UOWTransaction(object):
def _get_noninheriting_mappers(self):
"""returns a list of UOWTasks whose mappers are not inheriting from the mapper of another UOWTask.
i.e., this returns the root UOWTasks for all the inheritance hierarchies represented in this UOWTransaction."""
- mappers = sets.Set()
+ mappers = util.Set()
for task in self.tasks.values():
base = task.mapper.base_mapper()
mappers.add(base)
@@ -580,7 +580,7 @@ class UOWTask(object):
# a list of UOWDependencyProcessors which are executed after saves and
# before deletes, to synchronize data to dependent objects
- self.dependencies = sets.Set()
+ self.dependencies = util.Set()
# a list of UOWTasks that are dependent on this UOWTask, which
# are to be executed after this UOWTask performs saves and post-save
@@ -589,7 +589,7 @@ class UOWTask(object):
# a list of UOWTasks that correspond to Mappers which are inheriting
# mappers of this UOWTask's Mapper
- #self.inheriting_tasks = sets.Set()
+ #self.inheriting_tasks = util.Set()
# whether this UOWTask is circular, meaning it holds a second
# UOWTask that contains a special row-based dependency structure.
@@ -603,7 +603,7 @@ class UOWTask(object):
# set of dependencies, referencing sub-UOWTasks attached to this
# one which represent portions of the total list of objects.
# this is used for the row-based "circular sort"
- self.cyclical_dependencies = sets.Set()
+ self.cyclical_dependencies = util.Set()
def is_empty(self):
return len(self.objects) == 0 and len(self.dependencies) == 0 and len(self.childtasks) == 0
@@ -773,7 +773,7 @@ class UOWTask(object):
allobjects += [e.obj for e in task.get_elements(polymorphic=True)]
tuples = []
- cycles = sets.Set(cycles)
+ cycles = util.Set(cycles)
#print "BEGIN CIRC SORT-------"
#print "PRE-CIRC:"
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 19cb21367..86799b311 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -4,13 +4,13 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import sets
+import sqlalchemy.util as util
import sqlalchemy.sql as sql
class CascadeOptions(object):
"""keeps track of the options sent to relation().cascade"""
def __init__(self, arg=""):
- values = sets.Set([c.strip() for c in arg.split(',')])
+ values = util.Set([c.strip() for c in arg.split(',')])
self.delete_orphan = "delete-orphan" in values
self.delete = "delete" in values or self.delete_orphan or "all" in values
self.save_update = "save-update" in values or "all" in values
@@ -22,7 +22,7 @@ class CascadeOptions(object):
def polymorphic_union(table_map, typecolname, aliasname='p_union'):
- colnames = sets.Set()
+ colnames = util.Set()
colnamemaps = {}
for key in table_map.keys():