diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2014-08-14 17:44:58 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2014-08-15 15:53:12 -0400 |
| commit | 6bc676f56d57d5ea4dc298f63d0e3a77c0f4a4a1 (patch) | |
| tree | 6cc346a727e50cfd8cbadb73df026fab533b8386 /lib/sqlalchemy | |
| parent | 191fd3e27e3ef90190f8315c33ba6eb97aeaf5d2 (diff) | |
| download | sqlalchemy-6bc676f56d57d5ea4dc298f63d0e3a77c0f4a4a1.tar.gz | |
dev
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 59 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/session.py | 23 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/state.py | 15 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/unitofwork.py | 10 |
4 files changed, 70 insertions, 37 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 511a324be..64c8440c4 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -18,7 +18,7 @@ import operator from itertools import groupby from .. import sql, util, exc as sa_exc, schema from . import attributes, sync, exc as orm_exc, evaluator -from .base import _state_mapper, state_str, _attr_as_key +from .base import state_str, _attr_as_key from ..sql import expression from . import loading @@ -65,7 +65,8 @@ def save_obj( if insert: _emit_insert_statements(base_mapper, uowtransaction, cached_connections, - mapper, table, insert) + mapper, table, insert, + bookkeeping) _finalize_insert_update_commands(base_mapper, uowtransaction, states_to_insert, states_to_update, @@ -140,13 +141,16 @@ def _organize_states_for_save( states_to_insert = [] states_to_update = [] + instance_key = None for state, dict_, mapper, connection in _connections_for_states( base_mapper, uowtransaction, states): has_identity = bool(state.key) - instance_key = state.key or mapper._identity_key_from_state(state) + + if bookkeeping: + instance_key = state.key or mapper._identity_key_from_state(state) row_switch = None @@ -188,12 +192,12 @@ def _organize_states_for_save( if not has_identity and not row_switch: states_to_insert.append( (state, dict_, mapper, connection, - has_identity, instance_key, row_switch) + has_identity, row_switch) ) else: states_to_update.append( (state, dict_, mapper, connection, - has_identity, instance_key, row_switch) + has_identity, row_switch) ) return states_to_insert, states_to_update @@ -242,7 +246,8 @@ def _collect_insert_commands(base_mapper, uowtransaction, table, """ insert = [] for state, state_dict, mapper, connection, has_identity, \ - instance_key, row_switch in states_to_insert: + row_switch in states_to_insert: + if table not in mapper._pks_by_table: continue @@ -265,13 +270,13 @@ def _collect_insert_commands(base_mapper, uowtransaction, table, prop = mapper._columntoproperty[col] value = state_dict.get(prop.key, None) - if value is None: - if bookkeeping and col in pks: + if bookkeeping and value is None: + if col in pks: has_all_pks = False elif col.default is None and \ col.server_default is None: params[col.key] = value - elif bookkeeping and col.server_default is not None and \ + elif col.server_default is not None and \ mapper.base_mapper.eager_defaults: has_all_defaults = False @@ -301,7 +306,7 @@ def _collect_update_commands(base_mapper, uowtransaction, update = [] for state, state_dict, mapper, connection, has_identity, \ - instance_key, row_switch in states_to_update: + row_switch in states_to_update: if table not in mapper._pks_by_table: continue @@ -567,7 +572,8 @@ def _emit_update_statements(base_mapper, uowtransaction, def _emit_insert_statements(base_mapper, uowtransaction, - cached_connections, mapper, table, insert): + cached_connections, mapper, table, insert, + bookkeeping): """Emit INSERT statements corresponding to value lists collected by _collect_insert_commands().""" @@ -593,19 +599,20 @@ def _emit_insert_statements(base_mapper, uowtransaction, c = cached_connections[connection].\ execute(statement, multiparams) - for (state, state_dict, params, mapper_rec, - conn, value_params, has_all_pks, has_all_defaults), \ - last_inserted_params in \ - zip(records, c.context.compiled_parameters): - _postfetch( - mapper_rec, - uowtransaction, - table, - state, - state_dict, - c, - last_inserted_params, - value_params) + if bookkeeping: + for (state, state_dict, params, mapper_rec, + conn, value_params, has_all_pks, has_all_defaults), \ + last_inserted_params in \ + zip(records, c.context.compiled_parameters): + _postfetch( + mapper_rec, + uowtransaction, + table, + state, + state_dict, + c, + last_inserted_params, + value_params) else: if not has_all_defaults and base_mapper.eager_defaults: @@ -768,7 +775,7 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, """ for state, state_dict, mapper, connection, has_identity, \ - instance_key, row_switch in states_to_insert + \ + row_switch in states_to_insert + \ states_to_update: if bookkeeping: @@ -871,7 +878,7 @@ def _connections_for_states(base_mapper, uowtransaction, states): if connection_callable: connection = connection_callable(base_mapper, state.obj()) - mapper = _state_mapper(state) + mapper = state.manager.mapper yield state, state.dict, mapper, connection diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 2455c803a..546355611 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -482,7 +482,7 @@ class Session(_SessionClassMethods): '__contains__', '__iter__', 'add', 'add_all', 'begin', 'begin_nested', 'close', 'commit', 'connection', 'delete', 'execute', 'expire', 'expire_all', 'expunge', 'expunge_all', 'flush', 'get_bind', - 'is_modified', + 'is_modified', 'bulk_save_objects', 'bulk_save_mappings', 'merge', 'query', 'refresh', 'rollback', 'scalar') @@ -2033,31 +2033,42 @@ class Session(_SessionClassMethods): with util.safe_reraise(): transaction.rollback(_capture_exception=True) - def bulk_save(self, objects): + def bulk_save_objects(self, objects): + self._bulk_save((attributes.instance_state(obj) for obj in objects)) + + def bulk_save_mappings(self, mapper, mappings): + mapper = class_mapper(mapper) + + self._bulk_save(( + statelib.MappingState(mapper, mapping) + for mapping in mappings) + ) + + def _bulk_save(self, states): self._flushing = True flush_context = UOWTransaction(self) if self.dispatch.before_bulk_save: self.dispatch.before_bulk_save( - self, flush_context, objects) + self, flush_context, states) flush_context.transaction = transaction = self.begin( subtransactions=True) try: self._warn_on_events = True try: - flush_context.bulk_save(objects) + flush_context.bulk_save(states) finally: self._warn_on_events = False self.dispatch.after_bulk_save( - self, flush_context, objects + self, flush_context, states ) flush_context.finalize_flush_changes() self.dispatch.after_bulk_save_postexec( - self, flush_context, objects) + self, flush_context, states) transaction.commit() diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index fe8ccd222..e941bc1a4 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -580,6 +580,21 @@ class InstanceState(interfaces.InspectionAttr): state._strong_obj = None +class MappingState(InstanceState): + committed_state = {} + callables = {} + + def __init__(self, mapper, mapping): + self.class_ = mapper.class_ + self.manager = mapper.class_manager + self.modified = True + self._dict = mapping + + @property + def dict(self): + return self._dict + + class AttributeState(object): """Provide an inspection interface corresponding to a particular attribute on a particular mapped object. diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 8df24e95a..bc8a0f556 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -394,9 +394,9 @@ class UOWTransaction(object): if other: self.session._register_newly_persistent(other) - def bulk_save(self, objects): - for (base_mapper, in_session), states in itertools.groupby( - (attributes.instance_state(obj) for obj in objects), + def bulk_save(self, states): + for (base_mapper, in_session), states_ in itertools.groupby( + states, lambda state: ( state.mapper.base_mapper, @@ -404,12 +404,12 @@ class UOWTransaction(object): )): persistence.save_obj( - base_mapper, list(states), self, bookkeeping=in_session) + base_mapper, list(states_), self, bookkeeping=in_session) if in_session: self.states.update( (state, (False, False)) - for state in states + for state in states_ ) |
