diff options
| -rw-r--r-- | doc/build/changelog/changelog_10.rst | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 103 | ||||
| -rw-r--r-- | test/orm/test_unitofwork.py | 9 | ||||
| -rw-r--r-- | test/orm/test_unitofworkv2.py | 31 |
4 files changed, 91 insertions, 63 deletions
diff --git a/doc/build/changelog/changelog_10.rst b/doc/build/changelog/changelog_10.rst index fb14279ac..439d02c47 100644 --- a/doc/build/changelog/changelog_10.rst +++ b/doc/build/changelog/changelog_10.rst @@ -17,6 +17,17 @@ :version: 1.0.0 .. change:: + :tags: orm, feature + + UPDATE statements can now be batched within an ORM flush + into more performant executemany() call, similarly to how INSERT + statements can be batched; this will be invoked within flush + to the degree that subsequent UPDATE statements for the + same mapping and table involve the identical columns within the + VALUES clause, as well as that no VALUES-level SQL expressions + are embedded. + + .. change:: :tags: engine, bug :tickets: 3163 diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 17ce2e624..9d39c39b0 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -248,9 +248,10 @@ def _collect_insert_commands(base_mapper, uowtransaction, table, has_all_pks = True has_all_defaults = True + has_version_id_generator = mapper.version_id_generator is not False \ + and mapper.version_id_col is not None for col in mapper._cols_by_table[table]: - if col is mapper.version_id_col and \ - mapper.version_id_generator is not False: + if has_version_id_generator and col is mapper.version_id_col: val = mapper.version_id_generator(None) params[col.key] = val else: @@ -305,6 +306,7 @@ def _collect_update_commands(base_mapper, uowtransaction, value_params = {} hasdata = hasnull = False + for col in mapper._cols_by_table[table]: if col is mapper.version_id_col: params[col._label] = \ @@ -341,6 +343,7 @@ def _collect_update_commands(base_mapper, uowtransaction, prop = mapper._columntoproperty[col] history = state.manager[prop.key].impl.get_history( state, state_dict, + attributes.PASSIVE_OFF if col in pks else attributes.PASSIVE_NO_INITIALIZE) if history.added: if isinstance(history.added[0], @@ -381,8 +384,7 @@ def _collect_update_commands(base_mapper, uowtransaction, else: hasdata = True elif col in pks: - value = state.manager[prop.key].impl.get( - state, state_dict) + value = history.unchanged[0] if value is None: hasnull = True params[col._label] = value @@ -500,41 +502,63 @@ def _emit_update_statements(base_mapper, uowtransaction, statement = base_mapper._memo(('update', table), update_stmt) - rows = 0 - for state, state_dict, params, mapper, \ - connection, value_params in update: - - if value_params: - c = connection.execute( - statement.values(value_params), - params) + for (connection, paramkeys, hasvalue), \ + records in groupby( + update, + lambda rec: ( + rec[4], + tuple(sorted(rec[2])), + bool(rec[5])) + ): + + rows = 0 + records = list(records) + if hasvalue: + for state, state_dict, params, mapper, \ + connection, value_params in records: + c = connection.execute( + statement.values(value_params), + params) + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + c, + c.context.compiled_parameters[0], + value_params) + rows += c.rowcount else: + multiparams = [rec[2] for rec in records] c = cached_connections[connection].\ - execute(statement, params) - - _postfetch( - mapper, - uowtransaction, - table, - state, - state_dict, - c, - c.context.compiled_parameters[0], - value_params) - rows += c.rowcount - - if connection.dialect.supports_sane_rowcount: - if rows != len(update): - raise orm_exc.StaleDataError( - "UPDATE statement on table '%s' expected to " - "update %d row(s); %d were matched." % - (table.description, len(update), rows)) - - elif needs_version_id: - util.warn("Dialect %s does not support updated rowcount " - "- versioning cannot be verified." % - c.dialect.dialect_description, - stacklevel=12) + execute(statement, multiparams) + + rows += c.rowcount + for state, state_dict, params, mapper, \ + connection, value_params in records: + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + c, + c.context.compiled_parameters[0], + value_params) + + if connection.dialect.supports_sane_rowcount: + if rows != len(records): + raise orm_exc.StaleDataError( + "UPDATE statement on table '%s' expected to " + "update %d row(s); %d were matched." % + (table.description, len(records), rows)) + + elif needs_version_id: + util.warn("Dialect %s does not support updated rowcount " + "- versioning cannot be verified." % + c.dialect.dialect_description, + stacklevel=12) def _emit_insert_statements(base_mapper, uowtransaction, @@ -833,15 +857,12 @@ def _connections_for_states(base_mapper, uowtransaction, states): connection_callable = \ uowtransaction.session.connection_callable else: - connection = None + connection = uowtransaction.transaction.connection(base_mapper) connection_callable = None for state in _sort_states(states): if connection_callable: connection = connection_callable(base_mapper, state.obj()) - elif not connection: - connection = uowtransaction.transaction.connection( - base_mapper) mapper = _state_mapper(state) diff --git a/test/orm/test_unitofwork.py b/test/orm/test_unitofwork.py index 6eb763213..a54097b03 100644 --- a/test/orm/test_unitofwork.py +++ b/test/orm/test_unitofwork.py @@ -1126,11 +1126,12 @@ class OneToManyTest(_fixtures.FixtureTest): ("UPDATE addresses SET user_id=:user_id " "WHERE addresses.id = :addresses_id", - {'user_id': None, 'addresses_id': a1.id}), + [ + {'user_id': None, 'addresses_id': a1.id}, + {'user_id': u1.id, 'addresses_id': a3.id} + ]), - ("UPDATE addresses SET user_id=:user_id " - "WHERE addresses.id = :addresses_id", - {'user_id': u1.id, 'addresses_id': a3.id})]) + ]) def test_child_move(self): """Moving a child from one parent to another, with a delete. diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py index 9c9296786..c643e6a87 100644 --- a/test/orm/test_unitofworkv2.py +++ b/test/orm/test_unitofworkv2.py @@ -131,12 +131,10 @@ class RudimentaryFlushTest(UOWTest): CompiledSQL( "UPDATE addresses SET user_id=:user_id WHERE " "addresses.id = :addresses_id", - lambda ctx: [{'addresses_id': a1.id, 'user_id': None}] - ), - CompiledSQL( - "UPDATE addresses SET user_id=:user_id WHERE " - "addresses.id = :addresses_id", - lambda ctx: [{'addresses_id': a2.id, 'user_id': None}] + lambda ctx: [ + {'addresses_id': a1.id, 'user_id': None}, + {'addresses_id': a2.id, 'user_id': None} + ] ), CompiledSQL( "DELETE FROM users WHERE users.id = :id", @@ -240,12 +238,10 @@ class RudimentaryFlushTest(UOWTest): CompiledSQL( "UPDATE addresses SET user_id=:user_id WHERE " "addresses.id = :addresses_id", - lambda ctx: [{'addresses_id': a1.id, 'user_id': None}] - ), - CompiledSQL( - "UPDATE addresses SET user_id=:user_id WHERE " - "addresses.id = :addresses_id", - lambda ctx: [{'addresses_id': a2.id, 'user_id': None}] + lambda ctx: [ + {'addresses_id': a1.id, 'user_id': None}, + {'addresses_id': a2.id, 'user_id': None} + ] ), CompiledSQL( "DELETE FROM users WHERE users.id = :id", @@ -732,12 +728,11 @@ class SingleCycleTest(UOWTest): testing.db, sess.flush, AllOf( CompiledSQL( "UPDATE nodes SET parent_id=:parent_id " - "WHERE nodes.id = :nodes_id", lambda ctx: { - 'nodes_id': n3.id, 'parent_id': None}), - CompiledSQL( - "UPDATE nodes SET parent_id=:parent_id " - "WHERE nodes.id = :nodes_id", lambda ctx: { - 'nodes_id': n2.id, 'parent_id': None}), + "WHERE nodes.id = :nodes_id", lambda ctx: [ + {'nodes_id': n3.id, 'parent_id': None}, + {'nodes_id': n2.id, 'parent_id': None} + ] + ) ), CompiledSQL( "DELETE FROM nodes WHERE nodes.id = :id", lambda ctx: { |
