summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/orm/persistence.py103
1 files changed, 62 insertions, 41 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 17ce2e624..9d39c39b0 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -248,9 +248,10 @@ def _collect_insert_commands(base_mapper, uowtransaction, table,
has_all_pks = True
has_all_defaults = True
+ has_version_id_generator = mapper.version_id_generator is not False \
+ and mapper.version_id_col is not None
for col in mapper._cols_by_table[table]:
- if col is mapper.version_id_col and \
- mapper.version_id_generator is not False:
+ if has_version_id_generator and col is mapper.version_id_col:
val = mapper.version_id_generator(None)
params[col.key] = val
else:
@@ -305,6 +306,7 @@ def _collect_update_commands(base_mapper, uowtransaction,
value_params = {}
hasdata = hasnull = False
+
for col in mapper._cols_by_table[table]:
if col is mapper.version_id_col:
params[col._label] = \
@@ -341,6 +343,7 @@ def _collect_update_commands(base_mapper, uowtransaction,
prop = mapper._columntoproperty[col]
history = state.manager[prop.key].impl.get_history(
state, state_dict,
+ attributes.PASSIVE_OFF if col in pks else
attributes.PASSIVE_NO_INITIALIZE)
if history.added:
if isinstance(history.added[0],
@@ -381,8 +384,7 @@ def _collect_update_commands(base_mapper, uowtransaction,
else:
hasdata = True
elif col in pks:
- value = state.manager[prop.key].impl.get(
- state, state_dict)
+ value = history.unchanged[0]
if value is None:
hasnull = True
params[col._label] = value
@@ -500,41 +502,63 @@ def _emit_update_statements(base_mapper, uowtransaction,
statement = base_mapper._memo(('update', table), update_stmt)
- rows = 0
- for state, state_dict, params, mapper, \
- connection, value_params in update:
-
- if value_params:
- c = connection.execute(
- statement.values(value_params),
- params)
+ for (connection, paramkeys, hasvalue), \
+ records in groupby(
+ update,
+ lambda rec: (
+ rec[4],
+ tuple(sorted(rec[2])),
+ bool(rec[5]))
+ ):
+
+ rows = 0
+ records = list(records)
+ if hasvalue:
+ for state, state_dict, params, mapper, \
+ connection, value_params in records:
+ c = connection.execute(
+ statement.values(value_params),
+ params)
+ _postfetch(
+ mapper,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ c.context.compiled_parameters[0],
+ value_params)
+ rows += c.rowcount
else:
+ multiparams = [rec[2] for rec in records]
c = cached_connections[connection].\
- execute(statement, params)
-
- _postfetch(
- mapper,
- uowtransaction,
- table,
- state,
- state_dict,
- c,
- c.context.compiled_parameters[0],
- value_params)
- rows += c.rowcount
-
- if connection.dialect.supports_sane_rowcount:
- if rows != len(update):
- raise orm_exc.StaleDataError(
- "UPDATE statement on table '%s' expected to "
- "update %d row(s); %d were matched." %
- (table.description, len(update), rows))
-
- elif needs_version_id:
- util.warn("Dialect %s does not support updated rowcount "
- "- versioning cannot be verified." %
- c.dialect.dialect_description,
- stacklevel=12)
+ execute(statement, multiparams)
+
+ rows += c.rowcount
+ for state, state_dict, params, mapper, \
+ connection, value_params in records:
+ _postfetch(
+ mapper,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ c.context.compiled_parameters[0],
+ value_params)
+
+ if connection.dialect.supports_sane_rowcount:
+ if rows != len(records):
+ raise orm_exc.StaleDataError(
+ "UPDATE statement on table '%s' expected to "
+ "update %d row(s); %d were matched." %
+ (table.description, len(records), rows))
+
+ elif needs_version_id:
+ util.warn("Dialect %s does not support updated rowcount "
+ "- versioning cannot be verified." %
+ c.dialect.dialect_description,
+ stacklevel=12)
def _emit_insert_statements(base_mapper, uowtransaction,
@@ -833,15 +857,12 @@ def _connections_for_states(base_mapper, uowtransaction, states):
connection_callable = \
uowtransaction.session.connection_callable
else:
- connection = None
+ connection = uowtransaction.transaction.connection(base_mapper)
connection_callable = None
for state in _sort_states(states):
if connection_callable:
connection = connection_callable(base_mapper, state.obj())
- elif not connection:
- connection = uowtransaction.transaction.connection(
- base_mapper)
mapper = _state_mapper(state)