summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2014-08-14 17:44:58 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2014-08-15 15:53:12 -0400
commit6bc676f56d57d5ea4dc298f63d0e3a77c0f4a4a1 (patch)
tree6cc346a727e50cfd8cbadb73df026fab533b8386 /lib/sqlalchemy
parent191fd3e27e3ef90190f8315c33ba6eb97aeaf5d2 (diff)
downloadsqlalchemy-6bc676f56d57d5ea4dc298f63d0e3a77c0f4a4a1.tar.gz
dev
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/orm/persistence.py59
-rw-r--r--lib/sqlalchemy/orm/session.py23
-rw-r--r--lib/sqlalchemy/orm/state.py15
-rw-r--r--lib/sqlalchemy/orm/unitofwork.py10
4 files changed, 70 insertions, 37 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 511a324be..64c8440c4 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -18,7 +18,7 @@ import operator
from itertools import groupby
from .. import sql, util, exc as sa_exc, schema
from . import attributes, sync, exc as orm_exc, evaluator
-from .base import _state_mapper, state_str, _attr_as_key
+from .base import state_str, _attr_as_key
from ..sql import expression
from . import loading
@@ -65,7 +65,8 @@ def save_obj(
if insert:
_emit_insert_statements(base_mapper, uowtransaction,
cached_connections,
- mapper, table, insert)
+ mapper, table, insert,
+ bookkeeping)
_finalize_insert_update_commands(base_mapper, uowtransaction,
states_to_insert, states_to_update,
@@ -140,13 +141,16 @@ def _organize_states_for_save(
states_to_insert = []
states_to_update = []
+ instance_key = None
for state, dict_, mapper, connection in _connections_for_states(
base_mapper, uowtransaction,
states):
has_identity = bool(state.key)
- instance_key = state.key or mapper._identity_key_from_state(state)
+
+ if bookkeeping:
+ instance_key = state.key or mapper._identity_key_from_state(state)
row_switch = None
@@ -188,12 +192,12 @@ def _organize_states_for_save(
if not has_identity and not row_switch:
states_to_insert.append(
(state, dict_, mapper, connection,
- has_identity, instance_key, row_switch)
+ has_identity, row_switch)
)
else:
states_to_update.append(
(state, dict_, mapper, connection,
- has_identity, instance_key, row_switch)
+ has_identity, row_switch)
)
return states_to_insert, states_to_update
@@ -242,7 +246,8 @@ def _collect_insert_commands(base_mapper, uowtransaction, table,
"""
insert = []
for state, state_dict, mapper, connection, has_identity, \
- instance_key, row_switch in states_to_insert:
+ row_switch in states_to_insert:
+
if table not in mapper._pks_by_table:
continue
@@ -265,13 +270,13 @@ def _collect_insert_commands(base_mapper, uowtransaction, table,
prop = mapper._columntoproperty[col]
value = state_dict.get(prop.key, None)
- if value is None:
- if bookkeeping and col in pks:
+ if bookkeeping and value is None:
+ if col in pks:
has_all_pks = False
elif col.default is None and \
col.server_default is None:
params[col.key] = value
- elif bookkeeping and col.server_default is not None and \
+ elif col.server_default is not None and \
mapper.base_mapper.eager_defaults:
has_all_defaults = False
@@ -301,7 +306,7 @@ def _collect_update_commands(base_mapper, uowtransaction,
update = []
for state, state_dict, mapper, connection, has_identity, \
- instance_key, row_switch in states_to_update:
+ row_switch in states_to_update:
if table not in mapper._pks_by_table:
continue
@@ -567,7 +572,8 @@ def _emit_update_statements(base_mapper, uowtransaction,
def _emit_insert_statements(base_mapper, uowtransaction,
- cached_connections, mapper, table, insert):
+ cached_connections, mapper, table, insert,
+ bookkeeping):
"""Emit INSERT statements corresponding to value lists collected
by _collect_insert_commands()."""
@@ -593,19 +599,20 @@ def _emit_insert_statements(base_mapper, uowtransaction,
c = cached_connections[connection].\
execute(statement, multiparams)
- for (state, state_dict, params, mapper_rec,
- conn, value_params, has_all_pks, has_all_defaults), \
- last_inserted_params in \
- zip(records, c.context.compiled_parameters):
- _postfetch(
- mapper_rec,
- uowtransaction,
- table,
- state,
- state_dict,
- c,
- last_inserted_params,
- value_params)
+ if bookkeeping:
+ for (state, state_dict, params, mapper_rec,
+ conn, value_params, has_all_pks, has_all_defaults), \
+ last_inserted_params in \
+ zip(records, c.context.compiled_parameters):
+ _postfetch(
+ mapper_rec,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ last_inserted_params,
+ value_params)
else:
if not has_all_defaults and base_mapper.eager_defaults:
@@ -768,7 +775,7 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction,
"""
for state, state_dict, mapper, connection, has_identity, \
- instance_key, row_switch in states_to_insert + \
+ row_switch in states_to_insert + \
states_to_update:
if bookkeeping:
@@ -871,7 +878,7 @@ def _connections_for_states(base_mapper, uowtransaction, states):
if connection_callable:
connection = connection_callable(base_mapper, state.obj())
- mapper = _state_mapper(state)
+ mapper = state.manager.mapper
yield state, state.dict, mapper, connection
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 2455c803a..546355611 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -482,7 +482,7 @@ class Session(_SessionClassMethods):
'__contains__', '__iter__', 'add', 'add_all', 'begin', 'begin_nested',
'close', 'commit', 'connection', 'delete', 'execute', 'expire',
'expire_all', 'expunge', 'expunge_all', 'flush', 'get_bind',
- 'is_modified',
+ 'is_modified', 'bulk_save_objects', 'bulk_save_mappings',
'merge', 'query', 'refresh', 'rollback',
'scalar')
@@ -2033,31 +2033,42 @@ class Session(_SessionClassMethods):
with util.safe_reraise():
transaction.rollback(_capture_exception=True)
- def bulk_save(self, objects):
+ def bulk_save_objects(self, objects):
+ self._bulk_save((attributes.instance_state(obj) for obj in objects))
+
+ def bulk_save_mappings(self, mapper, mappings):
+ mapper = class_mapper(mapper)
+
+ self._bulk_save((
+ statelib.MappingState(mapper, mapping)
+ for mapping in mappings)
+ )
+
+ def _bulk_save(self, states):
self._flushing = True
flush_context = UOWTransaction(self)
if self.dispatch.before_bulk_save:
self.dispatch.before_bulk_save(
- self, flush_context, objects)
+ self, flush_context, states)
flush_context.transaction = transaction = self.begin(
subtransactions=True)
try:
self._warn_on_events = True
try:
- flush_context.bulk_save(objects)
+ flush_context.bulk_save(states)
finally:
self._warn_on_events = False
self.dispatch.after_bulk_save(
- self, flush_context, objects
+ self, flush_context, states
)
flush_context.finalize_flush_changes()
self.dispatch.after_bulk_save_postexec(
- self, flush_context, objects)
+ self, flush_context, states)
transaction.commit()
diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py
index fe8ccd222..e941bc1a4 100644
--- a/lib/sqlalchemy/orm/state.py
+++ b/lib/sqlalchemy/orm/state.py
@@ -580,6 +580,21 @@ class InstanceState(interfaces.InspectionAttr):
state._strong_obj = None
+class MappingState(InstanceState):
+ committed_state = {}
+ callables = {}
+
+ def __init__(self, mapper, mapping):
+ self.class_ = mapper.class_
+ self.manager = mapper.class_manager
+ self.modified = True
+ self._dict = mapping
+
+ @property
+ def dict(self):
+ return self._dict
+
+
class AttributeState(object):
"""Provide an inspection interface corresponding
to a particular attribute on a particular mapped object.
diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py
index 8df24e95a..bc8a0f556 100644
--- a/lib/sqlalchemy/orm/unitofwork.py
+++ b/lib/sqlalchemy/orm/unitofwork.py
@@ -394,9 +394,9 @@ class UOWTransaction(object):
if other:
self.session._register_newly_persistent(other)
- def bulk_save(self, objects):
- for (base_mapper, in_session), states in itertools.groupby(
- (attributes.instance_state(obj) for obj in objects),
+ def bulk_save(self, states):
+ for (base_mapper, in_session), states_ in itertools.groupby(
+ states,
lambda state:
(
state.mapper.base_mapper,
@@ -404,12 +404,12 @@ class UOWTransaction(object):
)):
persistence.save_obj(
- base_mapper, list(states), self, bookkeeping=in_session)
+ base_mapper, list(states_), self, bookkeeping=in_session)
if in_session:
self.states.update(
(state, (False, False))
- for state in states
+ for state in states_
)