summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2006-10-08 19:30:00 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2006-10-08 19:30:00 +0000
commitef77cfa61b6a894202495d460b055de6fea9eed6 (patch)
tree3e7ae1e4621bfc790f882babe392cae2fcc94119 /lib/sqlalchemy
parentb0ffcbc264f6a92ba5092e5d785a2dbfe418c307 (diff)
downloadsqlalchemy-ef77cfa61b6a894202495d460b055de6fea9eed6.tar.gz
- mapper.save_obj() now functions across all mappers in its polymorphic
series, UOWTask calls mapper appropriately in this manner - polymorphic mappers (i.e. using inheritance) now produces INSERT statements in order of tables across all inherited classes [ticket:321]
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/orm/mapper.py109
-rw-r--r--lib/sqlalchemy/orm/unitofwork.py4
-rw-r--r--lib/sqlalchemy/sql_util.py31
3 files changed, 80 insertions, 64 deletions
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index b60d9c034..f8e9bca07 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -517,6 +517,21 @@ class Mapper(object):
m = m.inherits
return m is self
+ def iterate_to_root(self):
+ m = self
+ while m is not None:
+ yield m
+ m = m.inherits
+
+ def polymorphic_iterator(self):
+ m = self.base_mapper()
+ def iterate(m):
+ yield m
+ for mapper in m._inheriting_mappers:
+ for x in iterate(mapper):
+ yield x
+ return iterate(m)
+
def accept_mapper_option(self, option):
option.process_mapper(self)
@@ -702,9 +717,11 @@ class Mapper(object):
self.columntoproperty[column][0].setattr(obj, value)
def save_obj(self, objects, uowtransaction, postupdate=False, post_update_cols=None, single=False):
- """called by a UnitOfWork object to save objects, which involves either an INSERT or
- an UPDATE statement for each table used by this mapper, for each element of the
- list."""
+ """save a list of objects.
+
+ this method is called within a unit of work flush() process. It saves objects that are mapped not just
+ by this mapper, but inherited mappers as well, so that insert ordering of polymorphic objects is maintained."""
+
self.__log_debug("save_obj() start, " + (single and "non-batched" or "batched"))
# if batch=false, call save_obj separately for each object
@@ -718,37 +735,30 @@ class Mapper(object):
if not postupdate:
for obj in objects:
if not has_identity(obj):
- self.extension.before_insert(self, connection, obj)
+ for mapper in object_mapper(obj).iterate_to_root():
+ mapper.extension.before_insert(mapper, connection, obj)
else:
- self.extension.before_update(self, connection, obj)
+ for mapper in object_mapper(obj).iterate_to_root():
+ mapper.extension.before_update(mapper, connection, obj)
inserted_objects = util.Set()
updated_objects = util.Set()
- for table in self.tables.sort(reverse=False):
- #print "SAVE_OBJ table ", self.class_.__name__, table.name
- # looping through our set of tables, which are all "real" tables, as opposed
- # to our main table which might be a select statement or something non-writeable
-
- # the loop structure is tables on the outer loop, objects on the inner loop.
- # this allows us to bundle inserts/updates on the same table together...although currently
- # they are separate execs via execute(), not executemany()
+
+ table_to_mapper = {}
+ for mapper in self.polymorphic_iterator():
+ for t in mapper.tables:
+ table_to_mapper[t] = mapper
- if not self._has_pks(table):
- #print "NO PKS ?", str(table)
- # if we dont have a full set of primary keys for this table, we cant really
- # do any CRUD with it, so skip. this occurs if we are mapping against a query
- # that joins on other tables so its not really an error condition.
- continue
-
+ for table in sqlutil.TableCollection(list(table_to_mapper.keys())).sort(reverse=False):
# two lists to store parameters for each table/object pair located
insert = []
update = []
- # we have our own idea of the primary key columns
- # for this table, in the case that the user
- # specified custom primary key cols.
for obj in objects:
- instance_key = self.instance_key(obj)
+ mapper = object_mapper(obj)
+ if table not in mapper.tables or not mapper._has_pks(table):
+ continue
+ instance_key = mapper.instance_key(obj)
self.__log_debug("save_obj() instance %s identity %s" % (mapperutil.instance_str(obj), str(instance_key)))
# detect if we have a "pending" instance (i.e. has no instance_key attached to it),
@@ -766,31 +776,31 @@ class Mapper(object):
params = {}
hasdata = False
for col in table.columns:
- if col is self.version_id_col:
+ if col is mapper.version_id_col:
if not isinsert:
- params[col._label] = self._getattrbycolumn(obj, col)
+ params[col._label] = mapper._getattrbycolumn(obj, col)
params[col.key] = params[col._label] + 1
else:
params[col.key] = 1
- elif col in self.pks_by_table[table]:
+ elif col in mapper.pks_by_table[table]:
# column is a primary key ?
if not isinsert:
# doing an UPDATE? put primary key values as "WHERE" parameters
# matching the bindparam we are creating below, i.e. "<tablename>_<colname>"
- params[col._label] = self._getattrbycolumn(obj, col)
+ params[col._label] = mapper._getattrbycolumn(obj, col)
else:
# doing an INSERT, primary key col ?
# if the primary key values are not populated,
# leave them out of the INSERT altogether, since PostGres doesn't want
# them to be present for SERIAL to take effect. A SQLEngine that uses
# explicit sequences will put them back in if they are needed
- value = self._getattrbycolumn(obj, col)
+ value = mapper._getattrbycolumn(obj, col)
if value is not None:
params[col.key] = value
- elif self.polymorphic_on is not None and self.polymorphic_on.shares_lineage(col):
+ elif mapper.polymorphic_on is not None and mapper.polymorphic_on.shares_lineage(col):
if isinsert:
- self.__log_debug("Using polymorphic identity '%s' for insert column '%s'" % (self.polymorphic_identity, col.key))
- value = self.polymorphic_identity
+ self.__log_debug("Using polymorphic identity '%s' for insert column '%s'" % (mapper.polymorphic_identity, col.key))
+ value = mapper.polymorphic_identity
if col.default is None or value is not None:
params[col.key] = value
else:
@@ -805,7 +815,7 @@ class Mapper(object):
params[col.key] = self._getattrbycolumn(obj, col)
hasdata = True
continue
- prop = self._getpropbycolumn(col, False)
+ prop = mapper._getpropbycolumn(col, False)
if prop is None:
continue
history = prop.get_history(obj, passive=True)
@@ -821,7 +831,7 @@ class Mapper(object):
# default. if its None and theres no default, we still might
# not want to put it in the col list but SQLIte doesnt seem to like that
# if theres no columns at all
- value = self._getattrbycolumn(obj, col, False)
+ value = mapper._getattrbycolumn(obj, col, False)
if value is NO_ATTRIBUTE:
continue
if col.default is None or value is not None:
@@ -834,18 +844,19 @@ class Mapper(object):
update.append((obj, params))
else:
insert.append((obj, params))
-
+
+ mapper = table_to_mapper[table]
if len(update):
clause = sql.and_()
- for col in self.pks_by_table[table]:
+ for col in mapper.pks_by_table[table]:
clause.clauses.append(col == sql.bindparam(col._label, type=col.type))
- if self.version_id_col is not None:
- clause.clauses.append(self.version_id_col == sql.bindparam(self.version_id_col._label, type=col.type))
+ if mapper.version_id_col is not None:
+ clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type=col.type))
statement = table.update(clause)
rows = 0
supports_sane_rowcount = True
def comparator(a, b):
- for col in self.pks_by_table[table]:
+ for col in mapper.pks_by_table[table]:
x = cmp(a[1][col._label],b[1][col._label])
if x != 0:
return x
@@ -854,7 +865,7 @@ class Mapper(object):
for rec in update:
(obj, params) = rec
c = connection.execute(statement, params)
- self._postfetch(connection, table, obj, c, c.last_updated_params())
+ mapper._postfetch(connection, table, obj, c, c.last_updated_params())
updated_objects.add(obj)
rows += c.cursor.rowcount
@@ -873,11 +884,11 @@ class Mapper(object):
primary_key = c.last_inserted_ids()
if primary_key is not None:
i = 0
- for col in self.pks_by_table[table]:
- if self._getattrbycolumn(obj, col) is None and len(primary_key) > i:
- self._setattrbycolumn(obj, col, primary_key[i])
+ for col in mapper.pks_by_table[table]:
+ if mapper._getattrbycolumn(obj, col) is None and len(primary_key) > i:
+ mapper._setattrbycolumn(obj, col, primary_key[i])
i+=1
- self._postfetch(connection, table, obj, c, c.last_inserted_params())
+ mapper._postfetch(connection, table, obj, c, c.last_inserted_params())
# synchronize newly inserted ids from one table to the next
def sync(mapper):
@@ -886,12 +897,16 @@ class Mapper(object):
sync(inherit)
if mapper._synchronizer is not None:
mapper._synchronizer.execute(obj, obj)
- sync(self)
+ sync(mapper)
inserted_objects.add(obj)
if not postupdate:
- [self.extension.after_insert(self, connection, obj) for obj in inserted_objects]
- [self.extension.after_update(self, connection, obj) for obj in updated_objects]
+ for obj in inserted_objects:
+ for mapper in object_mapper(obj).iterate_to_root():
+ mapper.extension.after_insert(mapper, connection, obj)
+ for obj in updated_objects:
+ for mapper in object_mapper(obj).iterate_to_root():
+ mapper.extension.after_update(mapper, connection, obj)
def _postfetch(self, connection, table, obj, resultproxy, params):
"""after an INSERT or UPDATE, asks the returned result if PassiveDefaults fired off on the database side
diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py
index 2750b34d1..2beb6a78e 100644
--- a/lib/sqlalchemy/orm/unitofwork.py
+++ b/lib/sqlalchemy/orm/unitofwork.py
@@ -627,8 +627,7 @@ class UOWTask(object):
pass
def _save_objects(self, trans):
- for task in self.polymorphic_tasks():
- task.mapper.save_obj(task.tosave_objects, trans)
+ self.mapper.save_obj(self.polymorphic_tosave_objects, trans)
def _delete_objects(self, trans):
for task in self.polymorphic_tasks():
task.mapper.delete_obj(task.todelete_objects, trans)
@@ -701,6 +700,7 @@ class UOWTask(object):
todelete_elements = property(lambda self:[rec for rec in self.get_elements(polymorphic=False) if rec.isdelete])
tosave_objects = property(lambda self:[rec.obj for rec in self.get_elements(polymorphic=False) if rec.obj is not None and not rec.listonly and rec.isdelete is False])
todelete_objects = property(lambda self:[rec.obj for rec in self.get_elements(polymorphic=False) if rec.obj is not None and not rec.listonly and rec.isdelete is True])
+ polymorphic_tosave_objects = property(lambda self:[rec.obj for rec in self.get_elements(polymorphic=True) if rec.obj is not None and not rec.listonly and rec.isdelete is False])
def _sort_circular_dependencies(self, trans, cycles):
"""for a single task, creates a hierarchical tree of "subtasks" which associate
diff --git a/lib/sqlalchemy/sql_util.py b/lib/sqlalchemy/sql_util.py
index 4015fd244..94caade68 100644
--- a/lib/sqlalchemy/sql_util.py
+++ b/lib/sqlalchemy/sql_util.py
@@ -6,8 +6,18 @@ import sqlalchemy.util as util
class TableCollection(object):
- def __init__(self):
- self.tables = []
+ def __init__(self, tables=None):
+ self.tables = tables or []
+ def __len__(self):
+ return len(self.tables)
+ def __getitem__(self, i):
+ return self.tables[i]
+ def __iter__(self):
+ return iter(self.tables)
+ def __contains__(self, obj):
+ return obj in self.tables
+ def __add__(self, obj):
+ return self.tables + list(obj)
def add(self, table):
self.tables.append(table)
if hasattr(self, '_sorted'):
@@ -29,10 +39,11 @@ class TableCollection(object):
import sqlalchemy.orm.topological
tuples = []
class TVisitor(schema.SchemaVisitor):
- def visit_foreign_key(self, fkey):
+ def visit_foreign_key(_self, fkey):
parent_table = fkey.column.table
- child_table = fkey.parent.table
- tuples.append( ( parent_table, child_table ) )
+ if parent_table in self:
+ child_table = fkey.parent.table
+ tuples.append( ( parent_table, child_table ) )
vis = TVisitor()
for table in self.tables:
table.accept_schema_visitor(vis)
@@ -57,16 +68,6 @@ class TableFinder(TableCollection, sql.ClauseVisitor):
table.accept_visitor(self)
def visit_table(self, table):
self.tables.append(table)
- def __len__(self):
- return len(self.tables)
- def __getitem__(self, i):
- return self.tables[i]
- def __iter__(self):
- return iter(self.tables)
- def __contains__(self, obj):
- return obj in self.tables
- def __add__(self, obj):
- return self.tables + list(obj)
def visit_column(self, column):
if self.check_columns:
column.table.accept_visitor(self)