summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/build/changelog/changelog_10.rst11
-rw-r--r--lib/sqlalchemy/orm/persistence.py103
-rw-r--r--test/orm/test_unitofwork.py9
-rw-r--r--test/orm/test_unitofworkv2.py31
4 files changed, 91 insertions, 63 deletions
diff --git a/doc/build/changelog/changelog_10.rst b/doc/build/changelog/changelog_10.rst
index fb14279ac..439d02c47 100644
--- a/doc/build/changelog/changelog_10.rst
+++ b/doc/build/changelog/changelog_10.rst
@@ -17,6 +17,17 @@
:version: 1.0.0
.. change::
+ :tags: orm, feature
+
+ UPDATE statements can now be batched within an ORM flush
+ into more performant executemany() call, similarly to how INSERT
+ statements can be batched; this will be invoked within flush
+ to the degree that subsequent UPDATE statements for the
+ same mapping and table involve the identical columns within the
+ VALUES clause, as well as that no VALUES-level SQL expressions
+ are embedded.
+
+ .. change::
:tags: engine, bug
:tickets: 3163
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 17ce2e624..9d39c39b0 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -248,9 +248,10 @@ def _collect_insert_commands(base_mapper, uowtransaction, table,
has_all_pks = True
has_all_defaults = True
+ has_version_id_generator = mapper.version_id_generator is not False \
+ and mapper.version_id_col is not None
for col in mapper._cols_by_table[table]:
- if col is mapper.version_id_col and \
- mapper.version_id_generator is not False:
+ if has_version_id_generator and col is mapper.version_id_col:
val = mapper.version_id_generator(None)
params[col.key] = val
else:
@@ -305,6 +306,7 @@ def _collect_update_commands(base_mapper, uowtransaction,
value_params = {}
hasdata = hasnull = False
+
for col in mapper._cols_by_table[table]:
if col is mapper.version_id_col:
params[col._label] = \
@@ -341,6 +343,7 @@ def _collect_update_commands(base_mapper, uowtransaction,
prop = mapper._columntoproperty[col]
history = state.manager[prop.key].impl.get_history(
state, state_dict,
+ attributes.PASSIVE_OFF if col in pks else
attributes.PASSIVE_NO_INITIALIZE)
if history.added:
if isinstance(history.added[0],
@@ -381,8 +384,7 @@ def _collect_update_commands(base_mapper, uowtransaction,
else:
hasdata = True
elif col in pks:
- value = state.manager[prop.key].impl.get(
- state, state_dict)
+ value = history.unchanged[0]
if value is None:
hasnull = True
params[col._label] = value
@@ -500,41 +502,63 @@ def _emit_update_statements(base_mapper, uowtransaction,
statement = base_mapper._memo(('update', table), update_stmt)
- rows = 0
- for state, state_dict, params, mapper, \
- connection, value_params in update:
-
- if value_params:
- c = connection.execute(
- statement.values(value_params),
- params)
+ for (connection, paramkeys, hasvalue), \
+ records in groupby(
+ update,
+ lambda rec: (
+ rec[4],
+ tuple(sorted(rec[2])),
+ bool(rec[5]))
+ ):
+
+ rows = 0
+ records = list(records)
+ if hasvalue:
+ for state, state_dict, params, mapper, \
+ connection, value_params in records:
+ c = connection.execute(
+ statement.values(value_params),
+ params)
+ _postfetch(
+ mapper,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ c.context.compiled_parameters[0],
+ value_params)
+ rows += c.rowcount
else:
+ multiparams = [rec[2] for rec in records]
c = cached_connections[connection].\
- execute(statement, params)
-
- _postfetch(
- mapper,
- uowtransaction,
- table,
- state,
- state_dict,
- c,
- c.context.compiled_parameters[0],
- value_params)
- rows += c.rowcount
-
- if connection.dialect.supports_sane_rowcount:
- if rows != len(update):
- raise orm_exc.StaleDataError(
- "UPDATE statement on table '%s' expected to "
- "update %d row(s); %d were matched." %
- (table.description, len(update), rows))
-
- elif needs_version_id:
- util.warn("Dialect %s does not support updated rowcount "
- "- versioning cannot be verified." %
- c.dialect.dialect_description,
- stacklevel=12)
+ execute(statement, multiparams)
+
+ rows += c.rowcount
+ for state, state_dict, params, mapper, \
+ connection, value_params in records:
+ _postfetch(
+ mapper,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ c.context.compiled_parameters[0],
+ value_params)
+
+ if connection.dialect.supports_sane_rowcount:
+ if rows != len(records):
+ raise orm_exc.StaleDataError(
+ "UPDATE statement on table '%s' expected to "
+ "update %d row(s); %d were matched." %
+ (table.description, len(records), rows))
+
+ elif needs_version_id:
+ util.warn("Dialect %s does not support updated rowcount "
+ "- versioning cannot be verified." %
+ c.dialect.dialect_description,
+ stacklevel=12)
def _emit_insert_statements(base_mapper, uowtransaction,
@@ -833,15 +857,12 @@ def _connections_for_states(base_mapper, uowtransaction, states):
connection_callable = \
uowtransaction.session.connection_callable
else:
- connection = None
+ connection = uowtransaction.transaction.connection(base_mapper)
connection_callable = None
for state in _sort_states(states):
if connection_callable:
connection = connection_callable(base_mapper, state.obj())
- elif not connection:
- connection = uowtransaction.transaction.connection(
- base_mapper)
mapper = _state_mapper(state)
diff --git a/test/orm/test_unitofwork.py b/test/orm/test_unitofwork.py
index 6eb763213..a54097b03 100644
--- a/test/orm/test_unitofwork.py
+++ b/test/orm/test_unitofwork.py
@@ -1126,11 +1126,12 @@ class OneToManyTest(_fixtures.FixtureTest):
("UPDATE addresses SET user_id=:user_id "
"WHERE addresses.id = :addresses_id",
- {'user_id': None, 'addresses_id': a1.id}),
+ [
+ {'user_id': None, 'addresses_id': a1.id},
+ {'user_id': u1.id, 'addresses_id': a3.id}
+ ]),
- ("UPDATE addresses SET user_id=:user_id "
- "WHERE addresses.id = :addresses_id",
- {'user_id': u1.id, 'addresses_id': a3.id})])
+ ])
def test_child_move(self):
"""Moving a child from one parent to another, with a delete.
diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py
index 9c9296786..c643e6a87 100644
--- a/test/orm/test_unitofworkv2.py
+++ b/test/orm/test_unitofworkv2.py
@@ -131,12 +131,10 @@ class RudimentaryFlushTest(UOWTest):
CompiledSQL(
"UPDATE addresses SET user_id=:user_id WHERE "
"addresses.id = :addresses_id",
- lambda ctx: [{'addresses_id': a1.id, 'user_id': None}]
- ),
- CompiledSQL(
- "UPDATE addresses SET user_id=:user_id WHERE "
- "addresses.id = :addresses_id",
- lambda ctx: [{'addresses_id': a2.id, 'user_id': None}]
+ lambda ctx: [
+ {'addresses_id': a1.id, 'user_id': None},
+ {'addresses_id': a2.id, 'user_id': None}
+ ]
),
CompiledSQL(
"DELETE FROM users WHERE users.id = :id",
@@ -240,12 +238,10 @@ class RudimentaryFlushTest(UOWTest):
CompiledSQL(
"UPDATE addresses SET user_id=:user_id WHERE "
"addresses.id = :addresses_id",
- lambda ctx: [{'addresses_id': a1.id, 'user_id': None}]
- ),
- CompiledSQL(
- "UPDATE addresses SET user_id=:user_id WHERE "
- "addresses.id = :addresses_id",
- lambda ctx: [{'addresses_id': a2.id, 'user_id': None}]
+ lambda ctx: [
+ {'addresses_id': a1.id, 'user_id': None},
+ {'addresses_id': a2.id, 'user_id': None}
+ ]
),
CompiledSQL(
"DELETE FROM users WHERE users.id = :id",
@@ -732,12 +728,11 @@ class SingleCycleTest(UOWTest):
testing.db, sess.flush, AllOf(
CompiledSQL(
"UPDATE nodes SET parent_id=:parent_id "
- "WHERE nodes.id = :nodes_id", lambda ctx: {
- 'nodes_id': n3.id, 'parent_id': None}),
- CompiledSQL(
- "UPDATE nodes SET parent_id=:parent_id "
- "WHERE nodes.id = :nodes_id", lambda ctx: {
- 'nodes_id': n2.id, 'parent_id': None}),
+ "WHERE nodes.id = :nodes_id", lambda ctx: [
+ {'nodes_id': n3.id, 'parent_id': None},
+ {'nodes_id': n2.id, 'parent_id': None}
+ ]
+ )
),
CompiledSQL(
"DELETE FROM nodes WHERE nodes.id = :id", lambda ctx: {