summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2020-04-17 10:55:08 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2020-06-27 21:30:37 -0400
commit08c46eea924d23a234bf3feea1a928eb8ae8a00a (patch)
tree3795e1d04fa0e35c1e93080320b43c8fe0ed792e /lib
parent2d9387354f11da322c516412eb5dfe937163c90b (diff)
downloadsqlalchemy-08c46eea924d23a234bf3feea1a928eb8ae8a00a.tar.gz
ORM executemany returning
Build on #5401 to allow the ORM to take advanage of executemany INSERT + RETURNING. Implemented the feature updated tests to support INSERT DEFAULT VALUES, needed to come up with a new syntax for compiler INSERT INTO table (anycol) VALUES (DEFAULT) which can then be iterated out for executemany. Added graceful degrade to plain executemany for PostgreSQL <= 8.2 Renamed EXECUTEMANY_DEFAULT to EXECUTEMANY_PLAIN Fix issue where unicode identifiers or parameter names wouldn't work with execute_values() under Py2K, because we have to encode the statement and therefore have to encode the insert_single_values_expr too. Correct issue from #5401 to support executemany + return_defaults for a PK that is explicitly pre-generated, meaning we aren't actually getting RETURNING but need to return it from compiled_parameters. Fixes: #5263 Change-Id: Id68e5c158c4f9ebc33b61c06a448907921c2a657
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/dialects/postgresql/psycopg2.py24
-rw-r--r--lib/sqlalchemy/engine/default.py7
-rw-r--r--lib/sqlalchemy/orm/persistence.py156
-rw-r--r--lib/sqlalchemy/sql/crud.py6
-rw-r--r--lib/sqlalchemy/testing/assertions.py4
-rw-r--r--lib/sqlalchemy/testing/assertsql.py8
6 files changed, 156 insertions, 49 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
index 6364838a6..850e5717c 100644
--- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py
+++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
@@ -643,7 +643,7 @@ class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer):
pass
-EXECUTEMANY_DEFAULT = util.symbol("executemany_default", canonical=0)
+EXECUTEMANY_PLAIN = util.symbol("executemany_plain", canonical=0)
EXECUTEMANY_BATCH = util.symbol("executemany_batch", canonical=1)
EXECUTEMANY_VALUES = util.symbol("executemany_values", canonical=2)
EXECUTEMANY_VALUES_PLUS_BATCH = util.symbol(
@@ -655,6 +655,12 @@ EXECUTEMANY_VALUES_PLUS_BATCH = util.symbol(
class PGDialect_psycopg2(PGDialect):
driver = "psycopg2"
if util.py2k:
+ # turn off supports_unicode_statements for Python 2. psycopg2 supports
+ # unicode statements in Py2K. But! it does not support unicode *bound
+ # parameter names* because it uses the Python "%" operator to
+ # interpolate these into the string, and this fails. So for Py2K, we
+ # have to use full-on encoding for statements and parameters before
+ # passing to cursor.execute().
supports_unicode_statements = False
supports_server_side_cursors = True
@@ -714,7 +720,7 @@ class PGDialect_psycopg2(PGDialect):
self.executemany_mode = util.symbol.parse_user_argument(
executemany_mode,
{
- EXECUTEMANY_DEFAULT: [None],
+ EXECUTEMANY_PLAIN: [None],
EXECUTEMANY_BATCH: ["batch"],
EXECUTEMANY_VALUES: ["values_only"],
EXECUTEMANY_VALUES_PLUS_BATCH: ["values_plus_batch", "values"],
@@ -747,7 +753,12 @@ class PGDialect_psycopg2(PGDialect):
and self._hstore_oids(connection.connection) is not None
)
- # http://initd.org/psycopg/docs/news.html#what-s-new-in-psycopg-2-0-9
+ # PGDialect.initialize() checks server version for <= 8.2 and sets
+ # this flag to False if so
+ if not self.full_returning:
+ self.insert_executemany_returning = False
+ self.executemany_mode = EXECUTEMANY_PLAIN
+
self.supports_sane_multi_rowcount = not (
self.executemany_mode & EXECUTEMANY_BATCH
)
@@ -876,6 +887,9 @@ class PGDialect_psycopg2(PGDialect):
executemany_values = (
"(%s)" % context.compiled.insert_single_values_expr
)
+ if not self.supports_unicode_statements:
+ executemany_values = executemany_values.encode(self.encoding)
+
# guard for statement that was altered via event hook or similar
if executemany_values not in statement:
executemany_values = None
@@ -883,10 +897,6 @@ class PGDialect_psycopg2(PGDialect):
executemany_values = None
if executemany_values:
- # Currently, SQLAlchemy does not pass "RETURNING" statements
- # into executemany(), since no DBAPI has ever supported that
- # until the introduction of psycopg2's executemany_values, so
- # we are not yet using the fetch=True flag.
statement = statement.replace(executemany_values, "%s")
if self.executemany_values_page_size:
kwargs = {"page_size": self.executemany_values_page_size}
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 790f68de7..f2f30455a 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -824,7 +824,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
if self.isinsert or self.isupdate or self.isdelete:
self.is_crud = True
self._is_explicit_returning = bool(compiled.statement._returning)
- self._is_implicit_returning = (
+ self._is_implicit_returning = bool(
compiled.returning and not compiled.statement._returning
)
@@ -1291,11 +1291,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
result.out_parameters = out_parameters
def _setup_dml_or_text_result(self):
- if self.isinsert and not self.executemany:
+ if self.isinsert:
if (
not self._is_implicit_returning
and not self.compiled.inline
and self.dialect.postfetch_lastrowid
+ and not self.executemany
):
self._setup_ins_pk_from_lastrowid()
@@ -1375,7 +1376,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
getter = self.compiled._inserted_primary_key_from_lastrowid_getter
self.inserted_primary_key_rows = [
- getter(None, self.compiled_parameters[0])
+ getter(None, param) for param in self.compiled_parameters
]
def _setup_ins_pk_from_implicit_returning(self, result, rows):
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 88524dc49..cbe7bde33 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -960,6 +960,7 @@ def _emit_update_statements(
c.context.compiled_parameters[0],
value_params,
True,
+ c.returned_defaults,
)
rows += c.rowcount
check_rowcount = assert_singlerow
@@ -992,6 +993,7 @@ def _emit_update_statements(
c.context.compiled_parameters[0],
value_params,
True,
+ c.returned_defaults,
)
rows += c.rowcount
else:
@@ -1028,6 +1030,9 @@ def _emit_update_statements(
c.context.compiled_parameters[0],
value_params,
True,
+ c.returned_defaults
+ if not c.context.executemany
+ else None,
)
if check_rowcount:
@@ -1086,7 +1091,10 @@ def _emit_insert_statements(
and has_all_pks
and not hasvalue
):
-
+ # the "we don't need newly generated values back" section.
+ # here we have all the PKs, all the defaults or we don't want
+ # to fetch them, or the dialect doesn't support RETURNING at all
+ # so we have to post-fetch / use lastrowid anyway.
records = list(records)
multiparams = [rec[2] for rec in records]
@@ -1116,63 +1124,132 @@ def _emit_insert_statements(
last_inserted_params,
value_params,
False,
+ c.returned_defaults
+ if not c.context.executemany
+ else None,
)
else:
_postfetch_bulk_save(mapper_rec, state_dict, table)
else:
+ # here, we need defaults and/or pk values back.
+
+ records = list(records)
+ if (
+ not hasvalue
+ and connection.dialect.insert_executemany_returning
+ and len(records) > 1
+ ):
+ do_executemany = True
+ else:
+ do_executemany = False
+
if not has_all_defaults and base_mapper.eager_defaults:
statement = statement.return_defaults()
elif mapper.version_id_col is not None:
statement = statement.return_defaults(mapper.version_id_col)
+ elif do_executemany:
+ statement = statement.return_defaults(*table.primary_key)
- for (
- state,
- state_dict,
- params,
- mapper_rec,
- connection,
- value_params,
- has_all_pks,
- has_all_defaults,
- ) in records:
+ if do_executemany:
+ multiparams = [rec[2] for rec in records]
- if value_params:
- result = connection.execute(
- statement.values(value_params), params
- )
- else:
- result = cached_connections[connection].execute(
- statement, params
- )
+ c = cached_connections[connection].execute(
+ statement, multiparams
+ )
+ if bookkeeping:
+ for (
+ (
+ state,
+ state_dict,
+ params,
+ mapper_rec,
+ conn,
+ value_params,
+ has_all_pks,
+ has_all_defaults,
+ ),
+ last_inserted_params,
+ inserted_primary_key,
+ returned_defaults,
+ ) in util.zip_longest(
+ records,
+ c.context.compiled_parameters,
+ c.inserted_primary_key_rows,
+ c.returned_defaults_rows or (),
+ ):
+ for pk, col in zip(
+ inserted_primary_key, mapper._pks_by_table[table],
+ ):
+ prop = mapper_rec._columntoproperty[col]
+ if state_dict.get(prop.key) is None:
+ state_dict[prop.key] = pk
+
+ if state:
+ _postfetch(
+ mapper_rec,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ last_inserted_params,
+ value_params,
+ False,
+ returned_defaults,
+ )
+ else:
+ _postfetch_bulk_save(mapper_rec, state_dict, table)
+ else:
+ for (
+ state,
+ state_dict,
+ params,
+ mapper_rec,
+ connection,
+ value_params,
+ has_all_pks,
+ has_all_defaults,
+ ) in records:
+
+ if value_params:
+ result = connection.execute(
+ statement.values(value_params), params
+ )
+ else:
+ result = cached_connections[connection].execute(
+ statement, params
+ )
- primary_key = result.inserted_primary_key
- if primary_key is not None:
- # set primary key attributes
+ primary_key = result.inserted_primary_key
+ assert primary_key
for pk, col in zip(
primary_key, mapper._pks_by_table[table]
):
prop = mapper_rec._columntoproperty[col]
- if pk is not None and (
+ if (
col in value_params
or state_dict.get(prop.key) is None
):
state_dict[prop.key] = pk
- if bookkeeping:
- if state:
- _postfetch(
- mapper_rec,
- uowtransaction,
- table,
- state,
- state_dict,
- result,
- result.context.compiled_parameters[0],
- value_params,
- False,
- )
- else:
- _postfetch_bulk_save(mapper_rec, state_dict, table)
+ if bookkeeping:
+ if state:
+ _postfetch(
+ mapper_rec,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ result,
+ result.context.compiled_parameters[0],
+ value_params,
+ False,
+ result.returned_defaults
+ if not result.context.executemany
+ else None,
+ )
+ else:
+ _postfetch_bulk_save(mapper_rec, state_dict, table)
def _emit_post_update_statements(
@@ -1507,6 +1584,7 @@ def _postfetch(
params,
value_params,
isupdate,
+ returned_defaults,
):
"""Expire attributes in need of newly persisted database state,
after an INSERT or UPDATE statement has proceeded for that
@@ -1527,7 +1605,7 @@ def _postfetch(
load_evt_attrs = []
if returning_cols:
- row = result.returned_defaults
+ row = returned_defaults
if row is not None:
for row_value, col in zip(row, returning_cols):
# pk cols returned from insert are handled
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index c80d95a2c..85112f850 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -157,6 +157,12 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
values = _extend_values_for_multiparams(
compiler, stmt, compile_state, values, kw
)
+ elif not values and compiler.for_executemany:
+ # convert an "INSERT DEFAULT VALUES"
+ # into INSERT (firstcol) VALUES (DEFAULT) which can be turned
+ # into an in-place multi values. This supports
+ # insert_executemany_returning mode :)
+ values = [(stmt.table.columns[0], "DEFAULT")]
return values
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index 1ea366dac..998dde66b 100644
--- a/lib/sqlalchemy/testing/assertions.py
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -343,6 +343,7 @@ class AssertsCompiledSQL(object):
result,
params=None,
checkparams=None,
+ for_executemany=False,
check_literal_execute=None,
check_post_param=None,
dialect=None,
@@ -391,6 +392,9 @@ class AssertsCompiledSQL(object):
if render_postcompile:
compile_kwargs["render_postcompile"] = True
+ if for_executemany:
+ kw["for_executemany"] = True
+
if render_schema_translate:
kw["render_schema_translate"] = True
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
index ef324635e..caf61a806 100644
--- a/lib/sqlalchemy/testing/assertsql.py
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -325,6 +325,14 @@ class EachOf(AssertRule):
super(EachOf, self).no_more_statements()
+class Conditional(EachOf):
+ def __init__(self, condition, rules, else_rules):
+ if condition:
+ super(Conditional, self).__init__(*rules)
+ else:
+ super(Conditional, self).__init__(*else_rules)
+
+
class Or(AllOf):
def process_statement(self, execute_observed):
for rule in self.rules: