summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2016-09-20 11:33:16 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2016-09-20 11:33:16 -0400
commitf8ecdf47f0975b8b4e357fde2008d9aae8c50239 (patch)
tree3efea46680fc5ca957387542f21b5e52eaf1a737 /lib/sqlalchemy/orm
parent881369b949cff44e0017fdc28d9722ef3c26171a (diff)
downloadsqlalchemy-f8ecdf47f0975b8b4e357fde2008d9aae8c50239.tar.gz
Allow SQL expressions to be set on PK columns
Removes an unnecessary transfer of modified PK column value to the params dictionary, so that if the modified PK column is already present in value_params, this remains in effect. Also propagate a new flag through to _emit_update_statements() that will trip "return_defaults()" across the board if a PK col w/ SQL expression change is present, and pull this PK value in _postfetch as well assuming we're an UPDATE. Change-Id: I9ae87f964df9ba8faea8e25e96b8327f968e5d1b Fixes: #3801
Diffstat (limited to 'lib/sqlalchemy/orm')
-rw-r--r--lib/sqlalchemy/orm/persistence.py28
1 files changed, 19 insertions, 9 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 0b029f466..56b028375 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -506,6 +506,7 @@ def _collect_update_commands(
elif not (params or value_params):
continue
+ has_all_pks = True
if bulk:
pk_params = dict(
(propkey_to_col[propkey]._label, state_dict.get(propkey))
@@ -530,7 +531,8 @@ def _collect_update_commands(
else:
# else, use the old value to locate the row
pk_params[col._label] = history.deleted[0]
- params[col.key] = history.added[0]
+ if col in value_params:
+ has_all_pks = False
else:
pk_params[col._label] = history.unchanged[0]
if pk_params[col._label] is None:
@@ -542,7 +544,7 @@ def _collect_update_commands(
params.update(pk_params)
yield (
state, state_dict, params, mapper,
- connection, value_params, has_all_defaults)
+ connection, value_params, has_all_defaults, has_all_pks)
def _collect_post_update_commands(base_mapper, uowtransaction, table,
@@ -636,14 +638,15 @@ def _emit_update_statements(base_mapper, uowtransaction,
cached_stmt = base_mapper._memo(('update', table), update_stmt)
- for (connection, paramkeys, hasvalue, has_all_defaults), \
+ for (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks), \
records in groupby(
update,
lambda rec: (
rec[4], # connection
set(rec[2]), # set of parameter keys
bool(rec[5]), # whether or not we have "value" parameters
- rec[6] # has_all_defaults
+ rec[6], # has_all_defaults
+ rec[7] # has all pks
)
):
rows = 0
@@ -659,7 +662,9 @@ def _emit_update_statements(base_mapper, uowtransaction,
connection.dialect.supports_sane_multi_rowcount
allow_multirow = has_all_defaults and not needs_version_id
- if bookkeeping and not has_all_defaults and \
+ if not has_all_pks:
+ statement = statement.return_defaults()
+ elif bookkeeping and not has_all_defaults and \
mapper.base_mapper.eager_defaults:
statement = statement.return_defaults()
elif mapper.version_id_col is not None:
@@ -667,7 +672,8 @@ def _emit_update_statements(base_mapper, uowtransaction,
if hasvalue:
for state, state_dict, params, mapper, \
- connection, value_params, has_all_defaults in records:
+ connection, value_params, \
+ has_all_defaults, has_all_pks in records:
c = connection.execute(
statement.values(value_params),
params)
@@ -687,7 +693,8 @@ def _emit_update_statements(base_mapper, uowtransaction,
if not allow_multirow:
check_rowcount = assert_singlerow
for state, state_dict, params, mapper, \
- connection, value_params, has_all_defaults in records:
+ connection, value_params, has_all_defaults, \
+ has_all_pks in records:
c = cached_connections[connection].\
execute(statement, params)
@@ -717,7 +724,8 @@ def _emit_update_statements(base_mapper, uowtransaction,
rows += c.rowcount
for state, state_dict, params, mapper, \
- connection, value_params, has_all_defaults in records:
+ connection, value_params, \
+ has_all_defaults, has_all_pks in records:
if bookkeeping:
_postfetch(
mapper,
@@ -1013,7 +1021,9 @@ def _postfetch(mapper, uowtransaction, table,
row = result.context.returned_defaults
if row is not None:
for col in returning_cols:
- if col.primary_key:
+ # pk cols returned from insert are handled
+ # distinctly, don't step on the values here
+ if col.primary_key and result.context.isinsert:
continue
dict_[mapper._columntoproperty[col].key] = row[col]
if refresh_flush: