From f8ecdf47f0975b8b4e357fde2008d9aae8c50239 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 20 Sep 2016 11:33:16 -0400 Subject: Allow SQL expressions to be set on PK columns Removes an unnecessary transfer of modified PK column value to the params dictionary, so that if the modified PK column is already present in value_params, this remains in effect. Also propagate a new flag through to _emit_update_statements() that will trip "return_defaults()" across the board if a PK col w/ SQL expression change is present, and pull this PK value in _postfetch as well assuming we're an UPDATE. Change-Id: I9ae87f964df9ba8faea8e25e96b8327f968e5d1b Fixes: #3801 --- lib/sqlalchemy/orm/persistence.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) (limited to 'lib/sqlalchemy/orm') diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 0b029f466..56b028375 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -506,6 +506,7 @@ def _collect_update_commands( elif not (params or value_params): continue + has_all_pks = True if bulk: pk_params = dict( (propkey_to_col[propkey]._label, state_dict.get(propkey)) @@ -530,7 +531,8 @@ def _collect_update_commands( else: # else, use the old value to locate the row pk_params[col._label] = history.deleted[0] - params[col.key] = history.added[0] + if col in value_params: + has_all_pks = False else: pk_params[col._label] = history.unchanged[0] if pk_params[col._label] is None: @@ -542,7 +544,7 @@ def _collect_update_commands( params.update(pk_params) yield ( state, state_dict, params, mapper, - connection, value_params, has_all_defaults) + connection, value_params, has_all_defaults, has_all_pks) def _collect_post_update_commands(base_mapper, uowtransaction, table, @@ -636,14 +638,15 @@ def _emit_update_statements(base_mapper, uowtransaction, cached_stmt = base_mapper._memo(('update', table), update_stmt) - for (connection, paramkeys, hasvalue, has_all_defaults), \ + for (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks), \ records in groupby( update, lambda rec: ( rec[4], # connection set(rec[2]), # set of parameter keys bool(rec[5]), # whether or not we have "value" parameters - rec[6] # has_all_defaults + rec[6], # has_all_defaults + rec[7] # has all pks ) ): rows = 0 @@ -659,7 +662,9 @@ def _emit_update_statements(base_mapper, uowtransaction, connection.dialect.supports_sane_multi_rowcount allow_multirow = has_all_defaults and not needs_version_id - if bookkeeping and not has_all_defaults and \ + if not has_all_pks: + statement = statement.return_defaults() + elif bookkeeping and not has_all_defaults and \ mapper.base_mapper.eager_defaults: statement = statement.return_defaults() elif mapper.version_id_col is not None: @@ -667,7 +672,8 @@ def _emit_update_statements(base_mapper, uowtransaction, if hasvalue: for state, state_dict, params, mapper, \ - connection, value_params, has_all_defaults in records: + connection, value_params, \ + has_all_defaults, has_all_pks in records: c = connection.execute( statement.values(value_params), params) @@ -687,7 +693,8 @@ def _emit_update_statements(base_mapper, uowtransaction, if not allow_multirow: check_rowcount = assert_singlerow for state, state_dict, params, mapper, \ - connection, value_params, has_all_defaults in records: + connection, value_params, has_all_defaults, \ + has_all_pks in records: c = cached_connections[connection].\ execute(statement, params) @@ -717,7 +724,8 @@ def _emit_update_statements(base_mapper, uowtransaction, rows += c.rowcount for state, state_dict, params, mapper, \ - connection, value_params, has_all_defaults in records: + connection, value_params, \ + has_all_defaults, has_all_pks in records: if bookkeeping: _postfetch( mapper, @@ -1013,7 +1021,9 @@ def _postfetch(mapper, uowtransaction, table, row = result.context.returned_defaults if row is not None: for col in returning_cols: - if col.primary_key: + # pk cols returned from insert are handled + # distinctly, don't step on the values here + if col.primary_key and result.context.isinsert: continue dict_[mapper._columntoproperty[col].key] = row[col] if refresh_flush: -- cgit v1.2.1