diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-04-17 10:55:08 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-06-27 21:30:37 -0400 |
| commit | 08c46eea924d23a234bf3feea1a928eb8ae8a00a (patch) | |
| tree | 3795e1d04fa0e35c1e93080320b43c8fe0ed792e /lib | |
| parent | 2d9387354f11da322c516412eb5dfe937163c90b (diff) | |
| download | sqlalchemy-08c46eea924d23a234bf3feea1a928eb8ae8a00a.tar.gz | |
ORM executemany returning
Build on #5401 to allow the ORM to take advantage
of executemany INSERT + RETURNING.
Implemented the feature and updated tests.
To support INSERT DEFAULT VALUES, we needed to come up with
a new compiler syntax, INSERT INTO table (anycol) VALUES (DEFAULT),
which can then be iterated out for executemany.
Added graceful degrade to plain executemany for PostgreSQL <= 8.2
Renamed EXECUTEMANY_DEFAULT to EXECUTEMANY_PLAIN
Fix issue where unicode identifiers or parameter names wouldn't
work with execute_values() under Py2K, because we have to
encode the statement and therefore have to encode the
insert_single_values_expr too.
Correct issue from #5401 to support executemany + return_defaults
for a PK that is explicitly pre-generated, meaning we aren't actually
getting RETURNING but need to return it from compiled_parameters.
Fixes: #5263
Change-Id: Id68e5c158c4f9ebc33b61c06a448907921c2a657
Diffstat (limited to 'lib')
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/psycopg2.py | 24 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/default.py | 7 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 156 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/crud.py | 6 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/assertions.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/assertsql.py | 8 |
6 files changed, 156 insertions, 49 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index 6364838a6..850e5717c 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -643,7 +643,7 @@ class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer): pass -EXECUTEMANY_DEFAULT = util.symbol("executemany_default", canonical=0) +EXECUTEMANY_PLAIN = util.symbol("executemany_plain", canonical=0) EXECUTEMANY_BATCH = util.symbol("executemany_batch", canonical=1) EXECUTEMANY_VALUES = util.symbol("executemany_values", canonical=2) EXECUTEMANY_VALUES_PLUS_BATCH = util.symbol( @@ -655,6 +655,12 @@ EXECUTEMANY_VALUES_PLUS_BATCH = util.symbol( class PGDialect_psycopg2(PGDialect): driver = "psycopg2" if util.py2k: + # turn off supports_unicode_statements for Python 2. psycopg2 supports + # unicode statements in Py2K. But! it does not support unicode *bound + # parameter names* because it uses the Python "%" operator to + # interpolate these into the string, and this fails. So for Py2K, we + # have to use full-on encoding for statements and parameters before + # passing to cursor.execute(). 
supports_unicode_statements = False supports_server_side_cursors = True @@ -714,7 +720,7 @@ class PGDialect_psycopg2(PGDialect): self.executemany_mode = util.symbol.parse_user_argument( executemany_mode, { - EXECUTEMANY_DEFAULT: [None], + EXECUTEMANY_PLAIN: [None], EXECUTEMANY_BATCH: ["batch"], EXECUTEMANY_VALUES: ["values_only"], EXECUTEMANY_VALUES_PLUS_BATCH: ["values_plus_batch", "values"], @@ -747,7 +753,12 @@ class PGDialect_psycopg2(PGDialect): and self._hstore_oids(connection.connection) is not None ) - # http://initd.org/psycopg/docs/news.html#what-s-new-in-psycopg-2-0-9 + # PGDialect.initialize() checks server version for <= 8.2 and sets + # this flag to False if so + if not self.full_returning: + self.insert_executemany_returning = False + self.executemany_mode = EXECUTEMANY_PLAIN + self.supports_sane_multi_rowcount = not ( self.executemany_mode & EXECUTEMANY_BATCH ) @@ -876,6 +887,9 @@ class PGDialect_psycopg2(PGDialect): executemany_values = ( "(%s)" % context.compiled.insert_single_values_expr ) + if not self.supports_unicode_statements: + executemany_values = executemany_values.encode(self.encoding) + # guard for statement that was altered via event hook or similar if executemany_values not in statement: executemany_values = None @@ -883,10 +897,6 @@ class PGDialect_psycopg2(PGDialect): executemany_values = None if executemany_values: - # Currently, SQLAlchemy does not pass "RETURNING" statements - # into executemany(), since no DBAPI has ever supported that - # until the introduction of psycopg2's executemany_values, so - # we are not yet using the fetch=True flag. 
statement = statement.replace(executemany_values, "%s") if self.executemany_values_page_size: kwargs = {"page_size": self.executemany_values_page_size} diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 790f68de7..f2f30455a 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -824,7 +824,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): if self.isinsert or self.isupdate or self.isdelete: self.is_crud = True self._is_explicit_returning = bool(compiled.statement._returning) - self._is_implicit_returning = ( + self._is_implicit_returning = bool( compiled.returning and not compiled.statement._returning ) @@ -1291,11 +1291,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext): result.out_parameters = out_parameters def _setup_dml_or_text_result(self): - if self.isinsert and not self.executemany: + if self.isinsert: if ( not self._is_implicit_returning and not self.compiled.inline and self.dialect.postfetch_lastrowid + and not self.executemany ): self._setup_ins_pk_from_lastrowid() @@ -1375,7 +1376,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): getter = self.compiled._inserted_primary_key_from_lastrowid_getter self.inserted_primary_key_rows = [ - getter(None, self.compiled_parameters[0]) + getter(None, param) for param in self.compiled_parameters ] def _setup_ins_pk_from_implicit_returning(self, result, rows): diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 88524dc49..cbe7bde33 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -960,6 +960,7 @@ def _emit_update_statements( c.context.compiled_parameters[0], value_params, True, + c.returned_defaults, ) rows += c.rowcount check_rowcount = assert_singlerow @@ -992,6 +993,7 @@ def _emit_update_statements( c.context.compiled_parameters[0], value_params, True, + c.returned_defaults, ) rows += c.rowcount else: @@ -1028,6 +1030,9 @@ def 
_emit_update_statements( c.context.compiled_parameters[0], value_params, True, + c.returned_defaults + if not c.context.executemany + else None, ) if check_rowcount: @@ -1086,7 +1091,10 @@ def _emit_insert_statements( and has_all_pks and not hasvalue ): - + # the "we don't need newly generated values back" section. + # here we have all the PKs, all the defaults or we don't want + # to fetch them, or the dialect doesn't support RETURNING at all + # so we have to post-fetch / use lastrowid anyway. records = list(records) multiparams = [rec[2] for rec in records] @@ -1116,63 +1124,132 @@ def _emit_insert_statements( last_inserted_params, value_params, False, + c.returned_defaults + if not c.context.executemany + else None, ) else: _postfetch_bulk_save(mapper_rec, state_dict, table) else: + # here, we need defaults and/or pk values back. + + records = list(records) + if ( + not hasvalue + and connection.dialect.insert_executemany_returning + and len(records) > 1 + ): + do_executemany = True + else: + do_executemany = False + if not has_all_defaults and base_mapper.eager_defaults: statement = statement.return_defaults() elif mapper.version_id_col is not None: statement = statement.return_defaults(mapper.version_id_col) + elif do_executemany: + statement = statement.return_defaults(*table.primary_key) - for ( - state, - state_dict, - params, - mapper_rec, - connection, - value_params, - has_all_pks, - has_all_defaults, - ) in records: + if do_executemany: + multiparams = [rec[2] for rec in records] - if value_params: - result = connection.execute( - statement.values(value_params), params - ) - else: - result = cached_connections[connection].execute( - statement, params - ) + c = cached_connections[connection].execute( + statement, multiparams + ) + if bookkeeping: + for ( + ( + state, + state_dict, + params, + mapper_rec, + conn, + value_params, + has_all_pks, + has_all_defaults, + ), + last_inserted_params, + inserted_primary_key, + returned_defaults, + ) in 
util.zip_longest( + records, + c.context.compiled_parameters, + c.inserted_primary_key_rows, + c.returned_defaults_rows or (), + ): + for pk, col in zip( + inserted_primary_key, mapper._pks_by_table[table], + ): + prop = mapper_rec._columntoproperty[col] + if state_dict.get(prop.key) is None: + state_dict[prop.key] = pk + + if state: + _postfetch( + mapper_rec, + uowtransaction, + table, + state, + state_dict, + c, + last_inserted_params, + value_params, + False, + returned_defaults, + ) + else: + _postfetch_bulk_save(mapper_rec, state_dict, table) + else: + for ( + state, + state_dict, + params, + mapper_rec, + connection, + value_params, + has_all_pks, + has_all_defaults, + ) in records: + + if value_params: + result = connection.execute( + statement.values(value_params), params + ) + else: + result = cached_connections[connection].execute( + statement, params + ) - primary_key = result.inserted_primary_key - if primary_key is not None: - # set primary key attributes + primary_key = result.inserted_primary_key + assert primary_key for pk, col in zip( primary_key, mapper._pks_by_table[table] ): prop = mapper_rec._columntoproperty[col] - if pk is not None and ( + if ( col in value_params or state_dict.get(prop.key) is None ): state_dict[prop.key] = pk - if bookkeeping: - if state: - _postfetch( - mapper_rec, - uowtransaction, - table, - state, - state_dict, - result, - result.context.compiled_parameters[0], - value_params, - False, - ) - else: - _postfetch_bulk_save(mapper_rec, state_dict, table) + if bookkeeping: + if state: + _postfetch( + mapper_rec, + uowtransaction, + table, + state, + state_dict, + result, + result.context.compiled_parameters[0], + value_params, + False, + result.returned_defaults + if not result.context.executemany + else None, + ) + else: + _postfetch_bulk_save(mapper_rec, state_dict, table) def _emit_post_update_statements( @@ -1507,6 +1584,7 @@ def _postfetch( params, value_params, isupdate, + returned_defaults, ): """Expire attributes in 
need of newly persisted database state, after an INSERT or UPDATE statement has proceeded for that @@ -1527,7 +1605,7 @@ def _postfetch( load_evt_attrs = [] if returning_cols: - row = result.returned_defaults + row = returned_defaults if row is not None: for row_value, col in zip(row, returning_cols): # pk cols returned from insert are handled diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index c80d95a2c..85112f850 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -157,6 +157,12 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): values = _extend_values_for_multiparams( compiler, stmt, compile_state, values, kw ) + elif not values and compiler.for_executemany: + # convert an "INSERT DEFAULT VALUES" + # into INSERT (firstcol) VALUES (DEFAULT) which can be turned + # into an in-place multi values. This supports + # insert_executemany_returning mode :) + values = [(stmt.table.columns[0], "DEFAULT")] return values diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index 1ea366dac..998dde66b 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -343,6 +343,7 @@ class AssertsCompiledSQL(object): result, params=None, checkparams=None, + for_executemany=False, check_literal_execute=None, check_post_param=None, dialect=None, @@ -391,6 +392,9 @@ class AssertsCompiledSQL(object): if render_postcompile: compile_kwargs["render_postcompile"] = True + if for_executemany: + kw["for_executemany"] = True + if render_schema_translate: kw["render_schema_translate"] = True diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index ef324635e..caf61a806 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -325,6 +325,14 @@ class EachOf(AssertRule): super(EachOf, self).no_more_statements() +class Conditional(EachOf): + def __init__(self, condition, rules, else_rules): + if 
condition: + super(Conditional, self).__init__(*rules) + else: + super(Conditional, self).__init__(*else_rules) + + class Or(AllOf): def process_statement(self, execute_observed): for rule in self.rules: |
