diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2016-09-20 11:33:16 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2016-09-20 11:33:16 -0400 |
| commit | f8ecdf47f0975b8b4e357fde2008d9aae8c50239 (patch) | |
| tree | 3efea46680fc5ca957387542f21b5e52eaf1a737 /lib/sqlalchemy/orm | |
| parent | 881369b949cff44e0017fdc28d9722ef3c26171a (diff) | |
| download | sqlalchemy-f8ecdf47f0975b8b4e357fde2008d9aae8c50239.tar.gz | |
Allow SQL expressions to be set on PK columns
Removes an unnecessary transfer of modified PK column
value to the params dictionary, so that if the modified PK column
is already present in value_params, this remains in effect. Also
propagate a new flag through to _emit_update_statements() that will
trip "return_defaults()" across the board if a PK col w/ SQL expression
change is present, and pull this PK value in _postfetch as well assuming
we're an UPDATE.
Change-Id: I9ae87f964df9ba8faea8e25e96b8327f968e5d1b
Fixes: #3801
Diffstat (limited to 'lib/sqlalchemy/orm')
| -rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 28 |
1 files changed, 19 insertions, 9 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 0b029f466..56b028375 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -506,6 +506,7 @@ def _collect_update_commands( elif not (params or value_params): continue + has_all_pks = True if bulk: pk_params = dict( (propkey_to_col[propkey]._label, state_dict.get(propkey)) @@ -530,7 +531,8 @@ def _collect_update_commands( else: # else, use the old value to locate the row pk_params[col._label] = history.deleted[0] - params[col.key] = history.added[0] + if col in value_params: + has_all_pks = False else: pk_params[col._label] = history.unchanged[0] if pk_params[col._label] is None: @@ -542,7 +544,7 @@ def _collect_update_commands( params.update(pk_params) yield ( state, state_dict, params, mapper, - connection, value_params, has_all_defaults) + connection, value_params, has_all_defaults, has_all_pks) def _collect_post_update_commands(base_mapper, uowtransaction, table, @@ -636,14 +638,15 @@ def _emit_update_statements(base_mapper, uowtransaction, cached_stmt = base_mapper._memo(('update', table), update_stmt) - for (connection, paramkeys, hasvalue, has_all_defaults), \ + for (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks), \ records in groupby( update, lambda rec: ( rec[4], # connection set(rec[2]), # set of parameter keys bool(rec[5]), # whether or not we have "value" parameters - rec[6] # has_all_defaults + rec[6], # has_all_defaults + rec[7] # has all pks ) ): rows = 0 @@ -659,7 +662,9 @@ def _emit_update_statements(base_mapper, uowtransaction, connection.dialect.supports_sane_multi_rowcount allow_multirow = has_all_defaults and not needs_version_id - if bookkeeping and not has_all_defaults and \ + if not has_all_pks: + statement = statement.return_defaults() + elif bookkeeping and not has_all_defaults and \ mapper.base_mapper.eager_defaults: statement = statement.return_defaults() elif mapper.version_id_col is not None: @@ -667,7 +672,8 @@ def _emit_update_statements(base_mapper, uowtransaction, if hasvalue: for state, state_dict, params, mapper, \ - connection, value_params, has_all_defaults in records: + connection, value_params, \ + has_all_defaults, has_all_pks in records: c = connection.execute( statement.values(value_params), params) @@ -687,7 +693,8 @@ def _emit_update_statements(base_mapper, uowtransaction, if not allow_multirow: check_rowcount = assert_singlerow for state, state_dict, params, mapper, \ - connection, value_params, has_all_defaults in records: + connection, value_params, has_all_defaults, \ + has_all_pks in records: c = cached_connections[connection].\ execute(statement, params) @@ -717,7 +724,8 @@ def _emit_update_statements(base_mapper, uowtransaction, rows += c.rowcount for state, state_dict, params, mapper, \ - connection, value_params, has_all_defaults in records: + connection, value_params, \ + has_all_defaults, has_all_pks in records: if bookkeeping: _postfetch( mapper, @@ -1013,7 +1021,9 @@ def _postfetch(mapper, uowtransaction, table, row = result.context.returned_defaults if row is not None: for col in returning_cols: - if col.primary_key: + # pk cols returned from insert are handled + # distinctly, don't step on the values here + if col.primary_key and result.context.isinsert: continue dict_[mapper._columntoproperty[col].key] = row[col] if refresh_flush: |
