diff options
Diffstat (limited to 'lib/sqlalchemy/orm/persistence.py')
-rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 134 |
1 files changed, 90 insertions, 44 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 1f5507edf..b0fa620e3 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -1,5 +1,5 @@ # orm/persistence.py -# Copyright (C) 2005-2013 the SQLAlchemy authors and contributors <see AUTHORS file> +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file> # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php @@ -17,7 +17,7 @@ import operator from itertools import groupby from .. import sql, util, exc as sa_exc, schema from . import attributes, sync, exc as orm_exc, evaluator -from .util import _state_mapper, state_str, _attr_as_key +from .base import _state_mapper, state_str, _attr_as_key from ..sql import expression from . import loading @@ -61,7 +61,7 @@ def save_obj(base_mapper, states, uowtransaction, single=False): if insert: _emit_insert_statements(base_mapper, uowtransaction, cached_connections, - table, insert) + mapper, table, insert) _finalize_insert_update_commands(base_mapper, uowtransaction, states_to_insert, states_to_update) @@ -246,9 +246,12 @@ def _collect_insert_commands(base_mapper, uowtransaction, table, value_params = {} has_all_pks = True + has_all_defaults = True for col in mapper._cols_by_table[table]: - if col is mapper.version_id_col: - params[col.key] = mapper.version_id_generator(None) + if col is mapper.version_id_col and \ + mapper.version_id_generator is not False: + val = mapper.version_id_generator(None) + params[col.key] = val else: # pull straight from the dict for # pending objects @@ -261,6 +264,9 @@ def _collect_insert_commands(base_mapper, uowtransaction, table, elif col.default is None and \ col.server_default is None: params[col.key] = value + elif col.server_default is not None and \ + mapper.base_mapper.eager_defaults: + has_all_defaults = False elif isinstance(value, sql.ClauseElement): value_params[col] = value @@ -268,7 +274,8 @@ def _collect_insert_commands(base_mapper, uowtransaction, table, params[col.key] = value insert.append((state, state_dict, params, mapper, - connection, value_params, has_all_pks)) + connection, value_params, has_all_pks, + has_all_defaults)) return insert @@ -315,19 +322,20 @@ def _collect_update_commands(base_mapper, uowtransaction, params[col.key] = history.added[0] hasdata = True else: - params[col.key] = mapper.version_id_generator( - params[col._label]) - - # HACK: check for history, in case the - # history is only - # in a different table than the one - # where the version_id_col is. - for prop in mapper._columntoproperty.values(): - history = attributes.get_state_history( - state, prop.key, - attributes.PASSIVE_NO_INITIALIZE) - if history.added: - hasdata = True + if mapper.version_id_generator is not False: + val = mapper.version_id_generator(params[col._label]) + params[col.key] = val + + # HACK: check for history, in case the + # history is only + # in a different table than the one + # where the version_id_col is. + for prop in mapper._columntoproperty.values(): + history = attributes.get_state_history( + state, prop.key, + attributes.PASSIVE_NO_INITIALIZE) + if history.added: + hasdata = True else: prop = mapper._columntoproperty[col] history = attributes.get_state_history( @@ -409,6 +417,7 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table, mapper._get_state_attr_by_column( state, state_dict, col) + elif col in post_update_cols: prop = mapper._columntoproperty[col] history = attributes.get_state_history( @@ -478,7 +487,13 @@ def _emit_update_statements(base_mapper, uowtransaction, sql.bindparam(mapper.version_id_col._label, type_=mapper.version_id_col.type)) - return table.update(clause) + stmt = table.update(clause) + if mapper.base_mapper.eager_defaults: + stmt = stmt.return_defaults() + elif mapper.version_id_col is not None: + stmt = stmt.return_defaults(mapper.version_id_col) + + return stmt statement = base_mapper._memo(('update', table), update_stmt) @@ -500,8 +515,7 @@ def _emit_update_statements(base_mapper, uowtransaction, table, state, state_dict, - c.context.prefetch_cols, - c.context.postfetch_cols, + c, c.context.compiled_parameters[0], value_params) rows += c.rowcount @@ -521,44 +535,55 @@ def _emit_update_statements(base_mapper, uowtransaction, def _emit_insert_statements(base_mapper, uowtransaction, - cached_connections, table, insert): + cached_connections, mapper, table, insert): """Emit INSERT statements corresponding to value lists collected by _collect_insert_commands().""" statement = base_mapper._memo(('insert', table), table.insert) - for (connection, pkeys, hasvalue, has_all_pks), \ + for (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), \ records in groupby(insert, lambda rec: (rec[4], list(rec[2].keys()), bool(rec[5]), - rec[6]) + rec[6], rec[7]) ): - if has_all_pks and not hasvalue: + if \ + ( + has_all_defaults + or not base_mapper.eager_defaults + or not connection.dialect.implicit_returning + ) and has_all_pks and not hasvalue: + records = list(records) multiparams = [rec[2] for rec in records] + c = cached_connections[connection].\ execute(statement, multiparams) - for (state, state_dict, params, mapper, - conn, value_params, has_all_pks), \ + for (state, state_dict, params, mapper_rec, + conn, value_params, has_all_pks, has_all_defaults), \ last_inserted_params in \ zip(records, c.context.compiled_parameters): _postfetch( - mapper, + mapper_rec, uowtransaction, table, state, state_dict, - c.context.prefetch_cols, - c.context.postfetch_cols, + c, last_inserted_params, value_params) else: - for state, state_dict, params, mapper, \ + if not has_all_defaults and base_mapper.eager_defaults: + statement = statement.return_defaults() + elif mapper.version_id_col is not None: + statement = statement.return_defaults(mapper.version_id_col) + + for state, state_dict, params, mapper_rec, \ connection, value_params, \ - has_all_pks in records: + has_all_pks, has_all_defaults in records: if value_params: result = connection.execute( @@ -574,23 +599,22 @@ def _emit_insert_statements(base_mapper, uowtransaction, # set primary key attributes for pk, col in zip(primary_key, mapper._pks_by_table[table]): - prop = mapper._columntoproperty[col] + prop = mapper_rec._columntoproperty[col] if state_dict.get(prop.key) is None: # TODO: would rather say: #state_dict[prop.key] = pk - mapper._set_state_attr_by_column( + mapper_rec._set_state_attr_by_column( state, state_dict, col, pk) _postfetch( - mapper, + mapper_rec, uowtransaction, table, state, state_dict, - result.context.prefetch_cols, - result.context.postfetch_cols, + result, result.context.compiled_parameters[0], value_params) @@ -699,14 +723,25 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, if readonly: state._expire_attributes(state.dict, readonly) - # if eager_defaults option is enabled, - # refresh whatever has been expired. - if base_mapper.eager_defaults and state.unloaded: + # if eager_defaults option is enabled, load + # all expired cols. Else if we have a version_id_col, make sure + # it isn't expired. + toload_now = [] + + if base_mapper.eager_defaults: + toload_now.extend(state._unloaded_non_object) + elif mapper.version_id_col is not None and \ + mapper.version_id_generator is False: + prop = mapper._columntoproperty[mapper.version_id_col] + if prop.key in state.unloaded: + toload_now.extend([prop.key]) + + if toload_now: state.key = base_mapper._identity_key_from_state(state) loading.load_on_ident( uowtransaction.session.query(base_mapper), state.key, refresh_state=state, - only_load_props=state.unloaded) + only_load_props=toload_now) # call after_XXX extensions if not has_identity: @@ -716,15 +751,26 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, def _postfetch(mapper, uowtransaction, table, - state, dict_, prefetch_cols, postfetch_cols, - params, value_params): + state, dict_, result, params, value_params): """Expire attributes in need of newly persisted database state, after an INSERT or UPDATE statement has proceeded for that state.""" + prefetch_cols = result.context.prefetch_cols + postfetch_cols = result.context.postfetch_cols + returning_cols = result.context.returning_cols + if mapper.version_id_col is not None: prefetch_cols = list(prefetch_cols) + [mapper.version_id_col] + if returning_cols: + row = result.context.returned_defaults + if row is not None: + for col in returning_cols: + if col.primary_key: + continue + mapper._set_state_attr_by_column(state, dict_, col, row[col]) + for c in prefetch_cols: if c.key in params and c in mapper._columntoproperty: mapper._set_state_attr_by_column(state, dict_, c, params[c.key]) |