summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2008-09-19 00:04:38 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2008-09-19 00:04:38 +0000
commit2c2ecbae867801c66b57770d5f7501bd4c0c3474 (patch)
tree5d4e3fcf7838a935f28336a32a7a82c963159a08 /lib/sqlalchemy
parent73b591b8ff7f0b9d277706b4b43d68b1794f602b (diff)
downloadsqlalchemy-2c2ecbae867801c66b57770d5f7501bd4c0c3474.tar.gz
un-stupified insert/update/delete sorting
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/orm/attributes.py7
-rw-r--r--lib/sqlalchemy/orm/mapper.py41
-rw-r--r--lib/sqlalchemy/orm/session.py3
-rw-r--r--lib/sqlalchemy/orm/uowdumper.py20
4 files changed, 18 insertions, 53 deletions
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index 17fea7854..296c019a7 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -781,7 +781,8 @@ class InstanceState(object):
key = None
runid = None
expired_attributes = EMPTY_SET
-
+ insert_order = None
+
def __init__(self, obj, manager):
self.class_ = obj.__class__
self.manager = manager
@@ -797,6 +798,10 @@ class InstanceState(object):
def dispose(self):
del self.session_id
+ @property
+ def sort_key(self):
+ return self.key and self.key[1] or self.insert_order
+
def check_modified(self):
if self.modified:
return True
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 21cbe3f2b..52b90d22a 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -1027,7 +1027,7 @@ class Mapper(object):
def _get_committed_state_attr_by_column(self, state, column, passive=False):
return self._get_col_to_prop(column).getcommitted(state, column, passive=passive)
-
+
def _save_obj(self, states, uowtransaction, postupdate=False, post_update_cols=None, single=False):
"""Issue ``INSERT`` and/or ``UPDATE`` statements for a list of objects.
@@ -1047,9 +1047,7 @@ class Mapper(object):
# if batch=false, call _save_obj separately for each object
if not single and not self.batch:
- def comparator(a, b):
- return cmp(getattr(a, 'insert_order', 0), getattr(b, 'insert_order', 0))
- for state in sorted(states, comparator):
+ for state in _sort_states(states):
self._save_obj([state], uowtransaction, postupdate=postupdate, post_update_cols=post_update_cols, single=True)
return
@@ -1057,10 +1055,10 @@ class Mapper(object):
# organize individual states with the connection to use for insert/update
if 'connection_callable' in uowtransaction.mapper_flush_opts:
connection_callable = uowtransaction.mapper_flush_opts['connection_callable']
- tups = [(state, _state_mapper(state), connection_callable(self, state.obj()), _state_has_identity(state)) for state in states]
+ tups = [(state, _state_mapper(state), connection_callable(self, state.obj()), _state_has_identity(state)) for state in _sort_states(states)]
else:
connection = uowtransaction.transaction.connection(self)
- tups = [(state, _state_mapper(state), connection, _state_has_identity(state)) for state in states]
+ tups = [(state, _state_mapper(state), connection, _state_has_identity(state)) for state in _sort_states(states)]
if not postupdate:
# call before_XXX extensions
@@ -1185,20 +1183,11 @@ class Mapper(object):
clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type_=col.type))
statement = table.update(clause)
- pks = mapper._pks_by_table[table]
- def comparator(a, b):
- for col in pks:
- x = cmp(a[1][col._label], b[1][col._label])
- if x != 0:
- return x
- return 0
- update.sort(comparator)
-
rows = 0
for rec in update:
(state, params, mapper, connection, value_params) = rec
c = connection.execute(statement.values(value_params), params)
- mapper.__postfetch(uowtransaction, connection, table, state, c, c.last_updated_params(), value_params)
+ mapper._postfetch(uowtransaction, connection, table, state, c, c.last_updated_params(), value_params)
# testlib.pragma exempt:__hash__
updated_objects.add((state, connection))
@@ -1209,9 +1198,6 @@ class Mapper(object):
if insert:
statement = table.insert()
- def comparator(a, b):
- return cmp(a[0].insert_order, b[0].insert_order)
- insert.sort(comparator)
for rec in insert:
(state, params, mapper, connection, value_params) = rec
c = connection.execute(statement.values(value_params), params)
@@ -1222,7 +1208,7 @@ class Mapper(object):
for i, col in enumerate(mapper._pks_by_table[table]):
if mapper._get_state_attr_by_column(state, col) is None and len(primary_key) > i:
mapper._set_state_attr_by_column(state, col, primary_key[i])
- mapper.__postfetch(uowtransaction, connection, table, state, c, c.last_inserted_params(), value_params)
+ mapper._postfetch(uowtransaction, connection, table, state, c, c.last_inserted_params(), value_params)
# synchronize newly inserted ids from one table to the next
# TODO: this performs some unnecessary attribute transfers
@@ -1263,7 +1249,7 @@ class Mapper(object):
if 'after_update' in mapper.extension.methods:
mapper.extension.after_update(mapper, connection, state.obj())
- def __postfetch(self, uowtransaction, connection, table, state, resultproxy, params, value_params):
+ def _postfetch(self, uowtransaction, connection, table, state, resultproxy, params, value_params):
"""For a given Table that has just been inserted/updated,
mark as 'expired' those attributes which correspond to columns
that are marked as 'postfetch', and populate attributes which
@@ -1303,10 +1289,10 @@ class Mapper(object):
if 'connection_callable' in uowtransaction.mapper_flush_opts:
connection_callable = uowtransaction.mapper_flush_opts['connection_callable']
- tups = [(state, _state_mapper(state), connection_callable(self, state.obj())) for state in states]
+ tups = [(state, _state_mapper(state), connection_callable(self, state.obj())) for state in _sort_states(states)]
else:
connection = uowtransaction.transaction.connection(self)
- tups = [(state, _state_mapper(state), connection) for state in states]
+ tups = [(state, _state_mapper(state), connection) for state in _sort_states(states)]
for state, mapper, connection in tups:
if 'before_delete' in mapper.extension.methods:
@@ -1335,13 +1321,6 @@ class Mapper(object):
for connection, del_objects in delete.iteritems():
mapper = table_to_mapper[table]
- def comparator(a, b):
- for col in mapper._pks_by_table[table]:
- x = cmp(a[col.key], b[col.key])
- if x != 0:
- return x
- return 0
- del_objects.sort(comparator)
clause = sql.and_()
for col in mapper._pks_by_table[table]:
clause.clauses.append(col == sql.bindparam(col.key, type_=col.type))
@@ -1694,6 +1673,8 @@ def _event_on_init_failure(state, instance, args, kwargs):
instrumenting_mapper, instrumenting_mapper.class_,
state.manager.events.original_init, instance, args, kwargs)
+def _sort_states(states):
+ return sorted(states, lambda a, b:cmp(a.sort_key, b.sort_key))
def _load_scalar_attributes(state, attribute_names):
mapper = _state_mapper(state)
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 66009be01..ad987430a 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -1039,9 +1039,6 @@ class Session(object):
self.identity_map.remove(state)
state.key = instance_key
- if hasattr(state, 'insert_order'):
- delattr(state, 'insert_order')
-
obj = state.obj()
# prevent against last minute dereferences of the object
# TODO: identify a code path where state.obj() is None
diff --git a/lib/sqlalchemy/orm/uowdumper.py b/lib/sqlalchemy/orm/uowdumper.py
index a46f90563..9ae7073b9 100644
--- a/lib/sqlalchemy/orm/uowdumper.py
+++ b/lib/sqlalchemy/orm/uowdumper.py
@@ -45,25 +45,7 @@ class UOWDumper(unitofwork.UOWExecutor):
def save_objects(self, trans, task):
- # sort elements to be inserted by insert order
- def comparator(a, b):
- if a.state is None:
- x = None
- elif not hasattr(a.state, 'insert_order'):
- x = None
- else:
- x = a.state.insert_order
- if b.state is None:
- y = None
- elif not hasattr(b.state, 'insert_order'):
- y = None
- else:
- y = b.state.insert_order
- return cmp(x, y)
-
- l = list(task.polymorphic_tosave_elements)
- l.sort(comparator)
- for rec in l:
+ for rec in sorted(task.polymorphic_tosave_elements, lambda a, b:cmp(a.state.sort_key, b.state.sort_key)):
if rec.listonly:
continue
self.buf.write(self._indent()[:-1] + "+-" + self._repr_task_element(rec) + "\n")