diff options
Diffstat (limited to 'lib')
| -rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 103 |
1 files changed, 62 insertions, 41 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 17ce2e624..9d39c39b0 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -248,9 +248,10 @@ def _collect_insert_commands(base_mapper, uowtransaction, table, has_all_pks = True has_all_defaults = True + has_version_id_generator = mapper.version_id_generator is not False \ + and mapper.version_id_col is not None for col in mapper._cols_by_table[table]: - if col is mapper.version_id_col and \ - mapper.version_id_generator is not False: + if has_version_id_generator and col is mapper.version_id_col: val = mapper.version_id_generator(None) params[col.key] = val else: @@ -305,6 +306,7 @@ def _collect_update_commands(base_mapper, uowtransaction, value_params = {} hasdata = hasnull = False + for col in mapper._cols_by_table[table]: if col is mapper.version_id_col: params[col._label] = \ @@ -341,6 +343,7 @@ def _collect_update_commands(base_mapper, uowtransaction, prop = mapper._columntoproperty[col] history = state.manager[prop.key].impl.get_history( state, state_dict, + attributes.PASSIVE_OFF if col in pks else attributes.PASSIVE_NO_INITIALIZE) if history.added: if isinstance(history.added[0], @@ -381,8 +384,7 @@ def _collect_update_commands(base_mapper, uowtransaction, else: hasdata = True elif col in pks: - value = state.manager[prop.key].impl.get( - state, state_dict) + value = history.unchanged[0] if value is None: hasnull = True params[col._label] = value @@ -500,41 +502,63 @@ def _emit_update_statements(base_mapper, uowtransaction, statement = base_mapper._memo(('update', table), update_stmt) - rows = 0 - for state, state_dict, params, mapper, \ - connection, value_params in update: - - if value_params: - c = connection.execute( - statement.values(value_params), - params) + for (connection, paramkeys, hasvalue), \ + records in groupby( + update, + lambda rec: ( + rec[4], + tuple(sorted(rec[2])), + bool(rec[5])) + ): + + rows = 0 + records = list(records) + if hasvalue: + for state, state_dict, params, mapper, \ + connection, value_params in records: + c = connection.execute( + statement.values(value_params), + params) + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + c, + c.context.compiled_parameters[0], + value_params) + rows += c.rowcount else: + multiparams = [rec[2] for rec in records] c = cached_connections[connection].\ - execute(statement, params) - - _postfetch( - mapper, - uowtransaction, - table, - state, - state_dict, - c, - c.context.compiled_parameters[0], - value_params) - rows += c.rowcount - - if connection.dialect.supports_sane_rowcount: - if rows != len(update): - raise orm_exc.StaleDataError( - "UPDATE statement on table '%s' expected to " - "update %d row(s); %d were matched." % - (table.description, len(update), rows)) - - elif needs_version_id: - util.warn("Dialect %s does not support updated rowcount " - "- versioning cannot be verified." % - c.dialect.dialect_description, - stacklevel=12) + execute(statement, multiparams) + + rows += c.rowcount + for state, state_dict, params, mapper, \ + connection, value_params in records: + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + c, + c.context.compiled_parameters[0], + value_params) + + if connection.dialect.supports_sane_rowcount: + if rows != len(records): + raise orm_exc.StaleDataError( + "UPDATE statement on table '%s' expected to " + "update %d row(s); %d were matched." % + (table.description, len(records), rows)) + + elif needs_version_id: + util.warn("Dialect %s does not support updated rowcount " + "- versioning cannot be verified." % + c.dialect.dialect_description, + stacklevel=12) def _emit_insert_statements(base_mapper, uowtransaction, @@ -833,15 +857,12 @@ def _connections_for_states(base_mapper, uowtransaction, states): connection_callable = \ uowtransaction.session.connection_callable else: - connection = None + connection = uowtransaction.transaction.connection(base_mapper) connection_callable = None for state in _sort_states(states): if connection_callable: connection = connection_callable(base_mapper, state.obj()) - elif not connection: - connection = uowtransaction.transaction.connection( - base_mapper) mapper = _state_mapper(state) |
