diff options
Diffstat (limited to 'lib/sqlalchemy/orm/persistence.py')
| -rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 138 |
1 files changed, 113 insertions, 25 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index abd528986..dfb61c28a 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -31,6 +31,7 @@ from .. import exc as sa_exc from .. import future from .. import sql from .. import util +from ..engine import cursor as _cursor from ..sql import operators from ..sql.elements import BooleanClauseList from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL @@ -398,6 +399,11 @@ def _collect_insert_commands( None ) + if bulk and mapper._set_polymorphic_identity: + params.setdefault( + mapper._polymorphic_attr_key, mapper.polymorphic_identity + ) + yield ( state, state_dict, @@ -411,7 +417,11 @@ def _collect_insert_commands( def _collect_update_commands( - uowtransaction, table, states_to_update, bulk=False + uowtransaction, + table, + states_to_update, + bulk=False, + use_orm_update_stmt=None, ): """Identify sets of values to use in UPDATE statements for a list of states. @@ -437,7 +447,11 @@ def _collect_update_commands( pks = mapper._pks_by_table[table] - value_params = {} + if use_orm_update_stmt is not None: + # TODO: ordered values, etc + value_params = use_orm_update_stmt._values + else: + value_params = {} propkey_to_col = mapper._propkey_to_col[table] @@ -697,6 +711,7 @@ def _emit_update_statements( table, update, bookkeeping=True, + use_orm_update_stmt=None, ): """Emit UPDATE statements corresponding to value lists collected by _collect_update_commands().""" @@ -708,7 +723,7 @@ def _emit_update_statements( execution_options = {"compiled_cache": base_mapper._compiled_cache} - def update_stmt(): + def update_stmt(existing_stmt=None): clauses = BooleanClauseList._construct_raw(operators.and_) for col in mapper._pks_by_table[table]: @@ -725,10 +740,17 @@ def _emit_update_statements( ) ) - stmt = table.update().where(clauses) + if existing_stmt is not None: + stmt = existing_stmt.where(clauses) + else: + stmt = table.update().where(clauses) return stmt - cached_stmt = base_mapper._memo(("update", table), update_stmt) + if use_orm_update_stmt is not None: + cached_stmt = update_stmt(use_orm_update_stmt) + + else: + cached_stmt = base_mapper._memo(("update", table), update_stmt) for ( (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks), @@ -747,6 +769,15 @@ def _emit_update_statements( records = list(records) statement = cached_stmt + + if use_orm_update_stmt is not None: + statement = statement._annotate( + { + "_emit_update_table": table, + "_emit_update_mapper": mapper, + } + ) + return_defaults = False if not has_all_pks: @@ -904,16 +935,35 @@ def _emit_insert_statements( table, insert, bookkeeping=True, + use_orm_insert_stmt=None, + execution_options=None, ): """Emit INSERT statements corresponding to value lists collected by _collect_insert_commands().""" - cached_stmt = base_mapper._memo(("insert", table), table.insert) + if use_orm_insert_stmt is not None: + cached_stmt = use_orm_insert_stmt + exec_opt = util.EMPTY_DICT - execution_options = {"compiled_cache": base_mapper._compiled_cache} + # if a user query with RETURNING was passed, we definitely need + # to use RETURNING. + returning_is_required_anyway = bool(use_orm_insert_stmt._returning) + else: + returning_is_required_anyway = False + cached_stmt = base_mapper._memo(("insert", table), table.insert) + exec_opt = {"compiled_cache": base_mapper._compiled_cache} + + if execution_options: + execution_options = util.EMPTY_DICT.merge_with( + exec_opt, execution_options + ) + else: + execution_options = exec_opt + + return_result = None for ( - (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), + (connection, _, hasvalue, has_all_pks, has_all_defaults), records, ) in groupby( insert, @@ -928,17 +978,29 @@ def _emit_insert_statements( statement = cached_stmt + if use_orm_insert_stmt is not None: + statement = statement._annotate( + { + "_emit_insert_table": table, + "_emit_insert_mapper": mapper, + } + ) + if ( - not bookkeeping - or ( - has_all_defaults - or not base_mapper.eager_defaults - or not base_mapper.local_table.implicit_returning - or not connection.dialect.insert_returning + ( + not bookkeeping + or ( + has_all_defaults + or not base_mapper.eager_defaults + or not base_mapper.local_table.implicit_returning + or not connection.dialect.insert_returning + ) ) + and not returning_is_required_anyway and has_all_pks and not hasvalue ): + # the "we don't need newly generated values back" section. # here we have all the PKs, all the defaults or we don't want # to fetch them, or the dialect doesn't support RETURNING at all @@ -946,7 +1008,7 @@ def _emit_insert_statements( records = list(records) multiparams = [rec[2] for rec in records] - c = connection.execute( + result = connection.execute( statement, multiparams, execution_options=execution_options ) if bookkeeping: @@ -962,7 +1024,7 @@ def _emit_insert_statements( has_all_defaults, ), last_inserted_params, - ) in zip(records, c.context.compiled_parameters): + ) in zip(records, result.context.compiled_parameters): if state: _postfetch( mapper_rec, @@ -970,19 +1032,20 @@ def _emit_insert_statements( table, state, state_dict, - c, + result, last_inserted_params, value_params, False, - c.returned_defaults - if not c.context.executemany + result.returned_defaults + if not result.context.executemany else None, ) else: _postfetch_bulk_save(mapper_rec, state_dict, table) else: - # here, we need defaults and/or pk values back. + # here, we need defaults and/or pk values back or we otherwise + # know that we are using RETURNING in any case records = list(records) if ( @@ -991,6 +1054,16 @@ def _emit_insert_statements( and len(records) > 1 ): do_executemany = True + elif returning_is_required_anyway: + if connection.dialect.insert_executemany_returning: + do_executemany = True + else: + raise sa_exc.InvalidRequestError( + f"Can't use explicit RETURNING for bulk INSERT " + f"operation with " + f"{connection.dialect.dialect_description} backend; " + f"executemany is not supported with RETURNING" + ) else: do_executemany = False @@ -998,6 +1071,7 @@ def _emit_insert_statements( statement = statement.return_defaults( *mapper._server_default_cols[table] ) + if mapper.version_id_col is not None: statement = statement.return_defaults(mapper.version_id_col) elif do_executemany: @@ -1006,10 +1080,16 @@ def _emit_insert_statements( if do_executemany: multiparams = [rec[2] for rec in records] - c = connection.execute( + result = connection.execute( statement, multiparams, execution_options=execution_options ) + if use_orm_insert_stmt is not None: + if return_result is None: + return_result = result + else: + return_result = return_result.splice_vertically(result) + if bookkeeping: for ( ( @@ -1027,9 +1107,9 @@ def _emit_insert_statements( returned_defaults, ) in zip_longest( records, - c.context.compiled_parameters, - c.inserted_primary_key_rows, - c.returned_defaults_rows or (), + result.context.compiled_parameters, + result.inserted_primary_key_rows, + result.returned_defaults_rows or (), ): if inserted_primary_key is None: # this is a real problem and means that we didn't @@ -1062,7 +1142,7 @@ def _emit_insert_statements( table, state, state_dict, - c, + result, last_inserted_params, value_params, False, @@ -1071,6 +1151,8 @@ def _emit_insert_statements( else: _postfetch_bulk_save(mapper_rec, state_dict, table) else: + assert not returning_is_required_anyway + for ( state, state_dict, @@ -1132,6 +1214,12 @@ def _emit_insert_statements( else: _postfetch_bulk_save(mapper_rec, state_dict, table) + if use_orm_insert_stmt is not None: + if return_result is None: + return _cursor.null_dml_result() + else: + return return_result + def _emit_post_update_statements( base_mapper, uowtransaction, mapper, table, update |
