diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2015-12-14 17:24:47 -0500 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2015-12-14 17:30:21 -0500 |
| commit | 0e4c4d7efc08d04c3c0ae960428b08ada37e4a91 (patch) | |
| tree | 4421c6681b9bc6025c5baccffbe5d61b901c48da /lib/sqlalchemy/orm | |
| parent | 7d96ad4d535dc02a8ab1384df1db94dea2a045b5 (diff) | |
| download | sqlalchemy-0e4c4d7efc08d04c3c0ae960428b08ada37e4a91.tar.gz | |
- Fixed bug in :meth:`.Update.return_defaults` which would cause all
insert-default holding columns not otherwise included in the SET
clause (such as primary key cols) to get rendered into the RETURNING
even though this is an UPDATE.
- Major fixes to the :paramref:`.Mapper.eager_defaults` flag, this
flag would not be honored correctly in the case that multiple
UPDATE statements were to be emitted, either as part of a flush
or a bulk update operation. Additionally, RETURNING
would be emitted unnecessarily within update statements.
fixes #3609
Diffstat (limited to 'lib/sqlalchemy/orm')
| -rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 14 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 36 |
2 files changed, 36 insertions, 14 deletions
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 5ade4b966..95aa14a26 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1970,12 +1970,24 @@ class Mapper(InspectionAttr): ( table, frozenset([ - col for col in columns + col.key for col in columns if col.server_default is not None]) ) for table, columns in self._cols_by_table.items() ) + @_memoized_configured_property + def _server_onupdate_default_cols(self): + return dict( + ( + table, + frozenset([ + col.key for col in columns + if col.server_onupdate is not None]) + ) + for table, columns in self._cols_by_table.items() + ) + @property def selectable(self): """The :func:`.select` construct this :class:`.Mapper` selects from diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 768c1146a..88c96e94c 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -448,6 +448,7 @@ def _collect_update_commands( set(propkey_to_col).intersection(state_dict).difference( mapper._pk_keys_by_table[table]) ) + has_all_defaults = True else: params = {} for propkey in set(propkey_to_col).intersection( @@ -463,6 +464,12 @@ def _collect_update_commands( value, state.committed_state[propkey]) is not True: params[col.key] = value + if mapper.base_mapper.eager_defaults: + has_all_defaults = mapper._server_onupdate_default_cols[table].\ + issubset(params) + else: + has_all_defaults = True + if update_version_id is not None and \ mapper.version_id_col in mapper._cols_by_table[table]: @@ -529,7 +536,7 @@ def _collect_update_commands( params.update(pk_params) yield ( state, state_dict, params, mapper, - connection, value_params) + connection, value_params, has_all_defaults) def _collect_post_update_commands(base_mapper, uowtransaction, table, @@ -619,23 +626,20 @@ def _emit_update_statements(base_mapper, uowtransaction, type_=mapper.version_id_col.type)) stmt = table.update(clause) - if mapper.base_mapper.eager_defaults: - stmt = stmt.return_defaults() - elif mapper.version_id_col is not None: - stmt = stmt.return_defaults(mapper.version_id_col) - return stmt statement = base_mapper._memo(('update', table), update_stmt) - for (connection, paramkeys, hasvalue), \ + for (connection, paramkeys, hasvalue, has_all_defaults), \ records in groupby( update, lambda rec: ( rec[4], # connection set(rec[2]), # set of parameter keys - bool(rec[5]))): # whether or not we have "value" parameters - + bool(rec[5]), # whether or not we have "value" parameters + rec[6] # has_all_defaults + ) + ): rows = 0 records = list(records) @@ -645,11 +649,16 @@ def _emit_update_statements(base_mapper, uowtransaction, assert_singlerow = connection.dialect.supports_sane_rowcount assert_multirow = assert_singlerow and \ connection.dialect.supports_sane_multi_rowcount - allow_multirow = not needs_version_id + allow_multirow = has_all_defaults and not needs_version_id + + if bookkeeping and mapper.base_mapper.eager_defaults: + statement = statement.return_defaults() + elif mapper.version_id_col is not None: + statement = statement.return_defaults(mapper.version_id_col) if hasvalue: for state, state_dict, params, mapper, \ - connection, value_params in records: + connection, value_params, has_all_defaults in records: c = connection.execute( statement.values(value_params), params) @@ -669,7 +678,7 @@ def _emit_update_statements(base_mapper, uowtransaction, if not allow_multirow: check_rowcount = assert_singlerow for state, state_dict, params, mapper, \ - connection, value_params in records: + connection, value_params, has_all_defaults in records: c = cached_connections[connection].\ execute(statement, params) @@ -699,7 +708,7 @@ def _emit_update_statements(base_mapper, uowtransaction, rows += c.rowcount for state, state_dict, params, mapper, \ - connection, value_params in records: + connection, value_params, has_all_defaults in records: if bookkeeping: _postfetch( mapper, @@ -741,6 +750,7 @@ def _emit_insert_statements(base_mapper, uowtransaction, bool(rec[5]), # whether we have "value" parameters rec[6], rec[7])): + if not bookkeeping or \ ( has_all_defaults |
