diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-04-17 10:55:08 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-06-27 21:30:37 -0400 |
| commit | 08c46eea924d23a234bf3feea1a928eb8ae8a00a (patch) | |
| tree | 3795e1d04fa0e35c1e93080320b43c8fe0ed792e /lib/sqlalchemy/orm | |
| parent | 2d9387354f11da322c516412eb5dfe937163c90b (diff) | |
| download | sqlalchemy-08c46eea924d23a234bf3feea1a928eb8ae8a00a.tar.gz | |
ORM executemany returning
Build on #5401 to allow the ORM to take advantage
of executemany INSERT + RETURNING.
Implemented the feature and updated tests.
To support INSERT DEFAULT VALUES, needed to come up with
a new syntax for compiler INSERT INTO table (anycol) VALUES (DEFAULT)
which can then be iterated out for executemany.
Added graceful degrade to plain executemany for PostgreSQL <= 8.2
Renamed EXECUTEMANY_DEFAULT to EXECUTEMANY_PLAIN
Fix issue where unicode identifiers or parameter names wouldn't
work with execute_values() under Py2K, because we have to
encode the statement and therefore have to encode the
insert_single_values_expr too.
Correct issue from #5401 to support executemany + return_defaults
for a PK that is explicitly pre-generated, meaning we aren't actually
getting RETURNING but need to return it from compiled_parameters.
Fixes: #5263
Change-Id: Id68e5c158c4f9ebc33b61c06a448907921c2a657
Diffstat (limited to 'lib/sqlalchemy/orm')
| -rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 156 |
1 files changed, 117 insertions, 39 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 88524dc49..cbe7bde33 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -960,6 +960,7 @@ def _emit_update_statements( c.context.compiled_parameters[0], value_params, True, + c.returned_defaults, ) rows += c.rowcount check_rowcount = assert_singlerow @@ -992,6 +993,7 @@ def _emit_update_statements( c.context.compiled_parameters[0], value_params, True, + c.returned_defaults, ) rows += c.rowcount else: @@ -1028,6 +1030,9 @@ def _emit_update_statements( c.context.compiled_parameters[0], value_params, True, + c.returned_defaults + if not c.context.executemany + else None, ) if check_rowcount: @@ -1086,7 +1091,10 @@ def _emit_insert_statements( and has_all_pks and not hasvalue ): - + # the "we don't need newly generated values back" section. + # here we have all the PKs, all the defaults or we don't want + # to fetch them, or the dialect doesn't support RETURNING at all + # so we have to post-fetch / use lastrowid anyway. records = list(records) multiparams = [rec[2] for rec in records] @@ -1116,63 +1124,132 @@ def _emit_insert_statements( last_inserted_params, value_params, False, + c.returned_defaults + if not c.context.executemany + else None, ) else: _postfetch_bulk_save(mapper_rec, state_dict, table) else: + # here, we need defaults and/or pk values back. 
+ + records = list(records) + if ( + not hasvalue + and connection.dialect.insert_executemany_returning + and len(records) > 1 + ): + do_executemany = True + else: + do_executemany = False + if not has_all_defaults and base_mapper.eager_defaults: statement = statement.return_defaults() elif mapper.version_id_col is not None: statement = statement.return_defaults(mapper.version_id_col) + elif do_executemany: + statement = statement.return_defaults(*table.primary_key) - for ( - state, - state_dict, - params, - mapper_rec, - connection, - value_params, - has_all_pks, - has_all_defaults, - ) in records: + if do_executemany: + multiparams = [rec[2] for rec in records] - if value_params: - result = connection.execute( - statement.values(value_params), params - ) - else: - result = cached_connections[connection].execute( - statement, params - ) + c = cached_connections[connection].execute( + statement, multiparams + ) + if bookkeeping: + for ( + ( + state, + state_dict, + params, + mapper_rec, + conn, + value_params, + has_all_pks, + has_all_defaults, + ), + last_inserted_params, + inserted_primary_key, + returned_defaults, + ) in util.zip_longest( + records, + c.context.compiled_parameters, + c.inserted_primary_key_rows, + c.returned_defaults_rows or (), + ): + for pk, col in zip( + inserted_primary_key, mapper._pks_by_table[table], + ): + prop = mapper_rec._columntoproperty[col] + if state_dict.get(prop.key) is None: + state_dict[prop.key] = pk + + if state: + _postfetch( + mapper_rec, + uowtransaction, + table, + state, + state_dict, + c, + last_inserted_params, + value_params, + False, + returned_defaults, + ) + else: + _postfetch_bulk_save(mapper_rec, state_dict, table) + else: + for ( + state, + state_dict, + params, + mapper_rec, + connection, + value_params, + has_all_pks, + has_all_defaults, + ) in records: + + if value_params: + result = connection.execute( + statement.values(value_params), params + ) + else: + result = cached_connections[connection].execute( + 
statement, params + ) - primary_key = result.inserted_primary_key - if primary_key is not None: - # set primary key attributes + primary_key = result.inserted_primary_key + assert primary_key for pk, col in zip( primary_key, mapper._pks_by_table[table] ): prop = mapper_rec._columntoproperty[col] - if pk is not None and ( + if ( col in value_params or state_dict.get(prop.key) is None ): state_dict[prop.key] = pk - if bookkeeping: - if state: - _postfetch( - mapper_rec, - uowtransaction, - table, - state, - state_dict, - result, - result.context.compiled_parameters[0], - value_params, - False, - ) - else: - _postfetch_bulk_save(mapper_rec, state_dict, table) + if bookkeeping: + if state: + _postfetch( + mapper_rec, + uowtransaction, + table, + state, + state_dict, + result, + result.context.compiled_parameters[0], + value_params, + False, + result.returned_defaults + if not result.context.executemany + else None, + ) + else: + _postfetch_bulk_save(mapper_rec, state_dict, table) def _emit_post_update_statements( @@ -1507,6 +1584,7 @@ def _postfetch( params, value_params, isupdate, + returned_defaults, ): """Expire attributes in need of newly persisted database state, after an INSERT or UPDATE statement has proceeded for that @@ -1527,7 +1605,7 @@ def _postfetch( load_evt_attrs = [] if returning_cols: - row = result.returned_defaults + row = returned_defaults if row is not None: for row_value, col in zip(row, returning_cols): # pk cols returned from insert are handled |
