summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/persistence.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/persistence.py')
-rw-r--r--lib/sqlalchemy/orm/persistence.py138
1 files changed, 113 insertions, 25 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index abd528986..dfb61c28a 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -31,6 +31,7 @@ from .. import exc as sa_exc
from .. import future
from .. import sql
from .. import util
+from ..engine import cursor as _cursor
from ..sql import operators
from ..sql.elements import BooleanClauseList
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
@@ -398,6 +399,11 @@ def _collect_insert_commands(
None
)
+ if bulk and mapper._set_polymorphic_identity:
+ params.setdefault(
+ mapper._polymorphic_attr_key, mapper.polymorphic_identity
+ )
+
yield (
state,
state_dict,
@@ -411,7 +417,11 @@ def _collect_insert_commands(
def _collect_update_commands(
- uowtransaction, table, states_to_update, bulk=False
+ uowtransaction,
+ table,
+ states_to_update,
+ bulk=False,
+ use_orm_update_stmt=None,
):
"""Identify sets of values to use in UPDATE statements for a
list of states.
@@ -437,7 +447,11 @@ def _collect_update_commands(
pks = mapper._pks_by_table[table]
- value_params = {}
+ if use_orm_update_stmt is not None:
+ # TODO: ordered values, etc
+ value_params = use_orm_update_stmt._values
+ else:
+ value_params = {}
propkey_to_col = mapper._propkey_to_col[table]
@@ -697,6 +711,7 @@ def _emit_update_statements(
table,
update,
bookkeeping=True,
+ use_orm_update_stmt=None,
):
"""Emit UPDATE statements corresponding to value lists collected
by _collect_update_commands()."""
@@ -708,7 +723,7 @@ def _emit_update_statements(
execution_options = {"compiled_cache": base_mapper._compiled_cache}
- def update_stmt():
+ def update_stmt(existing_stmt=None):
clauses = BooleanClauseList._construct_raw(operators.and_)
for col in mapper._pks_by_table[table]:
@@ -725,10 +740,17 @@ def _emit_update_statements(
)
)
- stmt = table.update().where(clauses)
+ if existing_stmt is not None:
+ stmt = existing_stmt.where(clauses)
+ else:
+ stmt = table.update().where(clauses)
return stmt
- cached_stmt = base_mapper._memo(("update", table), update_stmt)
+ if use_orm_update_stmt is not None:
+ cached_stmt = update_stmt(use_orm_update_stmt)
+
+ else:
+ cached_stmt = base_mapper._memo(("update", table), update_stmt)
for (
(connection, paramkeys, hasvalue, has_all_defaults, has_all_pks),
@@ -747,6 +769,15 @@ def _emit_update_statements(
records = list(records)
statement = cached_stmt
+
+ if use_orm_update_stmt is not None:
+ statement = statement._annotate(
+ {
+ "_emit_update_table": table,
+ "_emit_update_mapper": mapper,
+ }
+ )
+
return_defaults = False
if not has_all_pks:
@@ -904,16 +935,35 @@ def _emit_insert_statements(
table,
insert,
bookkeeping=True,
+ use_orm_insert_stmt=None,
+ execution_options=None,
):
"""Emit INSERT statements corresponding to value lists collected
by _collect_insert_commands()."""
- cached_stmt = base_mapper._memo(("insert", table), table.insert)
+ if use_orm_insert_stmt is not None:
+ cached_stmt = use_orm_insert_stmt
+ exec_opt = util.EMPTY_DICT
- execution_options = {"compiled_cache": base_mapper._compiled_cache}
+ # if a user query with RETURNING was passed, we definitely need
+ # to use RETURNING.
+ returning_is_required_anyway = bool(use_orm_insert_stmt._returning)
+ else:
+ returning_is_required_anyway = False
+ cached_stmt = base_mapper._memo(("insert", table), table.insert)
+ exec_opt = {"compiled_cache": base_mapper._compiled_cache}
+
+ if execution_options:
+ execution_options = util.EMPTY_DICT.merge_with(
+ exec_opt, execution_options
+ )
+ else:
+ execution_options = exec_opt
+
+ return_result = None
for (
- (connection, pkeys, hasvalue, has_all_pks, has_all_defaults),
+ (connection, _, hasvalue, has_all_pks, has_all_defaults),
records,
) in groupby(
insert,
@@ -928,17 +978,29 @@ def _emit_insert_statements(
statement = cached_stmt
+ if use_orm_insert_stmt is not None:
+ statement = statement._annotate(
+ {
+ "_emit_insert_table": table,
+ "_emit_insert_mapper": mapper,
+ }
+ )
+
if (
- not bookkeeping
- or (
- has_all_defaults
- or not base_mapper.eager_defaults
- or not base_mapper.local_table.implicit_returning
- or not connection.dialect.insert_returning
+ (
+ not bookkeeping
+ or (
+ has_all_defaults
+ or not base_mapper.eager_defaults
+ or not base_mapper.local_table.implicit_returning
+ or not connection.dialect.insert_returning
+ )
)
+ and not returning_is_required_anyway
and has_all_pks
and not hasvalue
):
+
# the "we don't need newly generated values back" section.
# here we have all the PKs, all the defaults or we don't want
# to fetch them, or the dialect doesn't support RETURNING at all
@@ -946,7 +1008,7 @@ def _emit_insert_statements(
records = list(records)
multiparams = [rec[2] for rec in records]
- c = connection.execute(
+ result = connection.execute(
statement, multiparams, execution_options=execution_options
)
if bookkeeping:
@@ -962,7 +1024,7 @@ def _emit_insert_statements(
has_all_defaults,
),
last_inserted_params,
- ) in zip(records, c.context.compiled_parameters):
+ ) in zip(records, result.context.compiled_parameters):
if state:
_postfetch(
mapper_rec,
@@ -970,19 +1032,20 @@ def _emit_insert_statements(
table,
state,
state_dict,
- c,
+ result,
last_inserted_params,
value_params,
False,
- c.returned_defaults
- if not c.context.executemany
+ result.returned_defaults
+ if not result.context.executemany
else None,
)
else:
_postfetch_bulk_save(mapper_rec, state_dict, table)
else:
- # here, we need defaults and/or pk values back.
+ # here, we need defaults and/or pk values back or we otherwise
+ # know that we are using RETURNING in any case
records = list(records)
if (
@@ -991,6 +1054,16 @@ def _emit_insert_statements(
and len(records) > 1
):
do_executemany = True
+ elif returning_is_required_anyway:
+ if connection.dialect.insert_executemany_returning:
+ do_executemany = True
+ else:
+ raise sa_exc.InvalidRequestError(
+ f"Can't use explicit RETURNING for bulk INSERT "
+ f"operation with "
+ f"{connection.dialect.dialect_description} backend; "
+ f"executemany is not supported with RETURNING"
+ )
else:
do_executemany = False
@@ -998,6 +1071,7 @@ def _emit_insert_statements(
statement = statement.return_defaults(
*mapper._server_default_cols[table]
)
+
if mapper.version_id_col is not None:
statement = statement.return_defaults(mapper.version_id_col)
elif do_executemany:
@@ -1006,10 +1080,16 @@ def _emit_insert_statements(
if do_executemany:
multiparams = [rec[2] for rec in records]
- c = connection.execute(
+ result = connection.execute(
statement, multiparams, execution_options=execution_options
)
+ if use_orm_insert_stmt is not None:
+ if return_result is None:
+ return_result = result
+ else:
+ return_result = return_result.splice_vertically(result)
+
if bookkeeping:
for (
(
@@ -1027,9 +1107,9 @@ def _emit_insert_statements(
returned_defaults,
) in zip_longest(
records,
- c.context.compiled_parameters,
- c.inserted_primary_key_rows,
- c.returned_defaults_rows or (),
+ result.context.compiled_parameters,
+ result.inserted_primary_key_rows,
+ result.returned_defaults_rows or (),
):
if inserted_primary_key is None:
# this is a real problem and means that we didn't
@@ -1062,7 +1142,7 @@ def _emit_insert_statements(
table,
state,
state_dict,
- c,
+ result,
last_inserted_params,
value_params,
False,
@@ -1071,6 +1151,8 @@ def _emit_insert_statements(
else:
_postfetch_bulk_save(mapper_rec, state_dict, table)
else:
+ assert not returning_is_required_anyway
+
for (
state,
state_dict,
@@ -1132,6 +1214,12 @@ def _emit_insert_statements(
else:
_postfetch_bulk_save(mapper_rec, state_dict, table)
+ if use_orm_insert_stmt is not None:
+ if return_result is None:
+ return _cursor.null_dml_result()
+ else:
+ return return_result
+
def _emit_post_update_statements(
base_mapper, uowtransaction, mapper, table, update