Diffstat (limited to 'lib/sqlalchemy/orm/persistence.py')
-rw-r--r--   lib/sqlalchemy/orm/persistence.py   385
1 file changed, 268 insertions(+), 117 deletions(-)
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 6b8d5af14..c3b2d7bcb 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -15,7 +15,7 @@ in unitofwork.py.
"""
import operator
-from itertools import groupby
+from itertools import groupby, chain
from .. import sql, util, exc as sa_exc, schema
from . import attributes, sync, exc as orm_exc, evaluator
from .base import state_str, _attr_as_key, _entity_descriptor
@@ -23,7 +23,105 @@ from ..sql import expression
from . import loading
-def save_obj(base_mapper, states, uowtransaction, single=False):
+def _bulk_insert(
+        mapper, mappings, session_transaction, isstates, return_defaults):
+    base_mapper = mapper.base_mapper
+
+    cached_connections = _cached_connection_dict(base_mapper)
+
+    if session_transaction.session.connection_callable:
+        raise NotImplementedError(
+            "connection_callable / per-instance sharding "
+            "not supported in bulk_insert()")
+
+    if isstates:
+        if return_defaults:
+            states = [(state, state.dict) for state in mappings]
+            mappings = [dict_ for (state, dict_) in states]
+        else:
+            mappings = [state.dict for state in mappings]
+    else:
+        mappings = list(mappings)
+
+    connection = session_transaction.connection(base_mapper)
+    for table, super_mapper in base_mapper._sorted_tables.items():
+        if not mapper.isa(super_mapper):
+            continue
+
+        records = (
+            (None, state_dict, params, mapper,
+                connection, value_params, has_all_pks, has_all_defaults)
+            for
+            state, state_dict, params, mp,
+            conn, value_params, has_all_pks,
+            has_all_defaults in _collect_insert_commands(table, (
+                (None, mapping, mapper, connection)
+                for mapping in mappings),
+                bulk=True, return_defaults=return_defaults
+            )
+        )
+        _emit_insert_statements(base_mapper, None,
+                                cached_connections,
+                                super_mapper, table, records,
+                                bookkeeping=return_defaults)
+
+    if return_defaults and isstates:
+        identity_cls = mapper._identity_class
+        identity_props = [p.key for p in mapper._identity_key_props]
+        for state, dict_ in states:
+            state.key = (
+                identity_cls,
+                tuple([dict_[key] for key in identity_props])
+            )
+
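For orientation, a minimal sketch of how this path is reached, assuming the
Session.bulk_insert_mappings() front end that accompanies these internals
(the User mapping and the in-memory engine are illustrative only):

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session

    Base = declarative_base()

    class User(Base):
        __tablename__ = 'user'
        id = Column(Integer, primary_key=True)
        name = Column(String)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = Session(engine)

    # plain dictionaries, one per row; no ORM states are created, so this
    # arrives at _bulk_insert() with isstates=False
    session.bulk_insert_mappings(User, [
        {"id": 1, "name": "a"},
        {"id": 2, "name": "b"},
    ])
    session.commit()

bulk_save_objects() takes instances rather than dictionaries and reaches the
same function with isstates=True.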
+
+def _bulk_update(mapper, mappings, session_transaction,
+                 isstates, update_changed_only):
+    base_mapper = mapper.base_mapper
+
+    cached_connections = _cached_connection_dict(base_mapper)
+
+    def _changed_dict(mapper, state):
+        return dict(
+            (k, v)
+            for k, v in state.dict.items() if k in state.committed_state or k
+            in mapper._primary_key_propkeys
+        )
+
+    if isstates:
+        if update_changed_only:
+            mappings = [_changed_dict(mapper, state) for state in mappings]
+        else:
+            mappings = [state.dict for state in mappings]
+    else:
+        mappings = list(mappings)
+
+    if session_transaction.session.connection_callable:
+        raise NotImplementedError(
+            "connection_callable / per-instance sharding "
+            "not supported in bulk_update()")
+
+    connection = session_transaction.connection(base_mapper)
+
+    for table, super_mapper in base_mapper._sorted_tables.items():
+        if not mapper.isa(super_mapper):
+            continue
+
+        records = _collect_update_commands(None, table, (
+            (None, mapping, mapper, connection,
+                (mapping[mapper._version_id_prop.key]
+                    if mapper._version_id_prop else None))
+            for mapping in mappings
+        ), bulk=True)
+
+        _emit_update_statements(base_mapper, None,
+                                cached_connections,
+                                super_mapper, table, records,
+                                bookkeeping=False)
+
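The update counterpart, again hedged on the Session.bulk_update_mappings()
front end and reusing the User mapping from the sketch above; each mapping
must carry its primary key, which the bulk branch of
_collect_update_commands() below uses to target the row:

    # isstates=False on this path; a version counter is consulted only if
    # the mapping includes the version id attribute (_version_id_prop above)
    session.bulk_update_mappings(User, [
        {"id": 1, "name": "a2"},
        {"id": 2, "name": "b2"},
    ])
    session.commit()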
+
+def save_obj(
+        base_mapper, states, uowtransaction, single=False):
"""Issue ``INSERT`` and/or ``UPDATE`` statements for a list
of objects.
@@ -76,17 +174,16 @@ def save_obj(base_mapper, states, uowtransaction, single=False):
_finalize_insert_update_commands(
base_mapper, uowtransaction,
- (
- (state, state_dict, mapper, connection, False)
- for state, state_dict, mapper, connection in states_to_insert
- )
- )
- _finalize_insert_update_commands(
- base_mapper, uowtransaction,
- (
- (state, state_dict, mapper, connection, True)
- for state, state_dict, mapper, connection,
- update_version_id in states_to_update
+ chain(
+ (
+ (state, state_dict, mapper, connection, False)
+ for state, state_dict, mapper, connection in states_to_insert
+ ),
+ (
+ (state, state_dict, mapper, connection, True)
+ for state, state_dict, mapper, connection,
+ update_version_id in states_to_update
+ )
)
)
@@ -261,7 +358,9 @@ def _organize_states_for_delete(base_mapper, states, uowtransaction):
state, dict_, mapper, connection, update_version_id)
-def _collect_insert_commands(table, states_to_insert):
+def _collect_insert_commands(
+ table, states_to_insert,
+ bulk=False, return_defaults=False):
"""Identify sets of values to use in INSERT statements for a
list of states.
@@ -280,22 +379,26 @@ def _collect_insert_commands(table, states_to_insert):
col = propkey_to_col[propkey]
if value is None:
continue
- elif isinstance(value, sql.ClauseElement):
+ elif not bulk and isinstance(value, sql.ClauseElement):
value_params[col.key] = value
else:
params[col.key] = value
- for colkey in mapper._insert_cols_as_none[table].\
- difference(params).difference(value_params):
- params[colkey] = None
+ if not bulk:
+ for colkey in mapper._insert_cols_as_none[table].\
+ difference(params).difference(value_params):
+ params[colkey] = None
- has_all_pks = mapper._pk_keys_by_table[table].issubset(params)
+ if not bulk or return_defaults:
+ has_all_pks = mapper._pk_keys_by_table[table].issubset(params)
- if mapper.base_mapper.eager_defaults:
- has_all_defaults = mapper._server_default_cols[table].\
- issubset(params)
+ if mapper.base_mapper.eager_defaults:
+ has_all_defaults = mapper._server_default_cols[table].\
+ issubset(params)
+ else:
+ has_all_defaults = True
else:
- has_all_defaults = True
+ has_all_defaults = has_all_pks = True
if mapper.version_id_generator is not False \
and mapper.version_id_col is not None and \
@@ -309,7 +412,9 @@ def _collect_insert_commands(table, states_to_insert):
has_all_defaults)
-def _collect_update_commands(uowtransaction, table, states_to_update):
+def _collect_update_commands(
+ uowtransaction, table, states_to_update,
+ bulk=False):
"""Identify sets of values to use in UPDATE statements for a
list of states.
@@ -329,23 +434,32 @@ def _collect_update_commands(uowtransaction, table, states_to_update):
pks = mapper._pks_by_table[table]
- params = {}
value_params = {}
propkey_to_col = mapper._propkey_to_col[table]
- for propkey in set(propkey_to_col).intersection(state.committed_state):
- value = state_dict[propkey]
- col = propkey_to_col[propkey]
-
- if not state.manager[propkey].impl.is_equal(
- value, state.committed_state[propkey]):
- if isinstance(value, sql.ClauseElement):
- value_params[col] = value
- else:
- params[col.key] = value
+ if bulk:
+ params = dict(
+ (propkey_to_col[propkey].key, state_dict[propkey])
+ for propkey in
+ set(propkey_to_col).intersection(state_dict)
+ )
+ else:
+ params = {}
+ for propkey in set(propkey_to_col).intersection(
+ state.committed_state):
+ value = state_dict[propkey]
+ col = propkey_to_col[propkey]
+
+ if not state.manager[propkey].impl.is_equal(
+ value, state.committed_state[propkey]):
+ if isinstance(value, sql.ClauseElement):
+ value_params[col] = value
+ else:
+ params[col.key] = value
- if update_version_id is not None:
+ if update_version_id is not None and \
+ mapper.version_id_col in mapper._cols_by_table[table]:
col = mapper.version_id_col
params[col._label] = update_version_id
@@ -357,28 +471,37 @@ def _collect_update_commands(uowtransaction, table, states_to_update):
if not (params or value_params):
continue
- pk_params = {}
- for col in pks:
- propkey = mapper._columntoproperty[col].key
- history = state.manager[propkey].impl.get_history(
- state, state_dict, attributes.PASSIVE_OFF)
-
- if history.added:
- if not history.deleted or \
- ("pk_cascaded", state, col) in \
- uowtransaction.attributes:
- pk_params[col._label] = history.added[0]
- params.pop(col.key, None)
+ if bulk:
+ pk_params = dict(
+ (propkey_to_col[propkey]._label, state_dict.get(propkey))
+ for propkey in
+ set(propkey_to_col).
+ intersection(mapper._pk_keys_by_table[table])
+ )
+ else:
+ pk_params = {}
+ for col in pks:
+ propkey = mapper._columntoproperty[col].key
+
+ history = state.manager[propkey].impl.get_history(
+ state, state_dict, attributes.PASSIVE_OFF)
+
+ if history.added:
+ if not history.deleted or \
+ ("pk_cascaded", state, col) in \
+ uowtransaction.attributes:
+ pk_params[col._label] = history.added[0]
+ params.pop(col.key, None)
+ else:
+ # else, use the old value to locate the row
+ pk_params[col._label] = history.deleted[0]
+ params[col.key] = history.added[0]
else:
- # else, use the old value to locate the row
- pk_params[col._label] = history.deleted[0]
- params[col.key] = history.added[0]
- else:
- pk_params[col._label] = history.unchanged[0]
- if pk_params[col._label] is None:
- raise orm_exc.FlushError(
- "Can't update table %s using NULL for primary "
- "key value on column %s" % (table, col))
+ pk_params[col._label] = history.unchanged[0]
+ if pk_params[col._label] is None:
+ raise orm_exc.FlushError(
+ "Can't update table %s using NULL for primary "
+ "key value on column %s" % (table, col))
if params or value_params:
params.update(pk_params)
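To make the bulk branch concrete: for the illustrative user table above with
primary key column id, the dictionaries built here come straight from the
incoming mapping, with the WHERE-clause values keyed by the column's _label:

    # illustrative values only; a mapping that omits its primary key yields
    # a None pk_params entry and an UPDATE that matches no rows
    mapping = {"id": 5, "name": "new"}
    pk_params = {"user_id": mapping.get("id")}  # binds the WHERE criteria
    params = {"name": "new", "id": 5}           # SET-clause values
    params.update(pk_params)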
@@ -446,18 +569,19 @@ def _collect_delete_commands(base_mapper, uowtransaction, table,
"key value on column %s" % (table, col))
if update_version_id is not None and \
- table.c.contains_column(mapper.version_id_col):
+ mapper.version_id_col in mapper._cols_by_table[table]:
params[mapper.version_id_col.key] = update_version_id
yield params, connection
def _emit_update_statements(base_mapper, uowtransaction,
- cached_connections, mapper, table, update):
+ cached_connections, mapper, table, update,
+ bookkeeping=True):
"""Emit UPDATE statements corresponding to value lists collected
by _collect_update_commands()."""
needs_version_id = mapper.version_id_col is not None and \
- table.c.contains_column(mapper.version_id_col)
+ mapper.version_id_col in mapper._cols_by_table[table]
def update_stmt():
clause = sql.and_()
@@ -486,32 +610,42 @@ def _emit_update_statements(base_mapper, uowtransaction,
records in groupby(
update,
lambda rec: (
- rec[4],
- tuple(sorted(rec[2])),
- bool(rec[5]))):
+ rec[4], # connection
+ set(rec[2]), # set of parameter keys
+ bool(rec[5]))): # whether or not we have "value" parameters
rows = 0
records = list(records)
+
+ # TODO: would be super-nice to not have to determine this boolean
+ # inside the loop here, given that 99.9999% of the time there's
+ # only one connection in use
+ assert_singlerow = connection.dialect.supports_sane_rowcount
+ assert_multirow = assert_singlerow and \
+ connection.dialect.supports_sane_multi_rowcount
+ allow_multirow = not needs_version_id or assert_multirow
+
if hasvalue:
for state, state_dict, params, mapper, \
connection, value_params in records:
c = connection.execute(
statement.values(value_params),
params)
- _postfetch(
- mapper,
- uowtransaction,
- table,
- state,
- state_dict,
- c,
- c.context.compiled_parameters[0],
- value_params)
+ if bookkeeping:
+ _postfetch(
+ mapper,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ c.context.compiled_parameters[0],
+ value_params)
rows += c.rowcount
+ check_rowcount = True
else:
- if needs_version_id and \
- not connection.dialect.supports_sane_multi_rowcount and \
- connection.dialect.supports_sane_rowcount:
+ if not allow_multirow:
+ check_rowcount = assert_singlerow
for state, state_dict, params, mapper, \
connection, value_params in records:
c = cached_connections[connection].\
@@ -528,6 +662,12 @@ def _emit_update_statements(base_mapper, uowtransaction,
rows += c.rowcount
else:
multiparams = [rec[2] for rec in records]
+
+ check_rowcount = assert_multirow or (
+ assert_singlerow and
+ len(multiparams) == 1
+ )
+
c = cached_connections[connection].\
execute(statement, multiparams)
@@ -544,7 +684,7 @@ def _emit_update_statements(base_mapper, uowtransaction,
c.context.compiled_parameters[0],
value_params)
- if connection.dialect.supports_sane_rowcount:
+ if check_rowcount:
if rows != len(records):
raise orm_exc.StaleDataError(
"UPDATE statement on table '%s' expected to "
@@ -558,20 +698,23 @@ def _emit_update_statements(base_mapper, uowtransaction,
def _emit_insert_statements(base_mapper, uowtransaction,
- cached_connections, mapper, table, insert):
+ cached_connections, mapper, table, insert,
+ bookkeeping=True):
"""Emit INSERT statements corresponding to value lists collected
by _collect_insert_commands()."""
statement = base_mapper._memo(('insert', table), table.insert)
for (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), \
- records in groupby(insert,
- lambda rec: (rec[4],
- tuple(sorted(rec[2].keys())),
- bool(rec[5]),
- rec[6], rec[7])
- ):
- if \
+ records in groupby(
+ insert,
+ lambda rec: (
+ rec[4], # connection
+ set(rec[2]), # parameter keys
+ bool(rec[5]), # whether we have "value" parameters
+ rec[6],
+ rec[7])):
+ if not bookkeeping or \
(
has_all_defaults
or not base_mapper.eager_defaults
@@ -584,19 +727,20 @@ def _emit_insert_statements(base_mapper, uowtransaction,
c = cached_connections[connection].\
execute(statement, multiparams)
- for (state, state_dict, params, mapper_rec,
- conn, value_params, has_all_pks, has_all_defaults), \
- last_inserted_params in \
- zip(records, c.context.compiled_parameters):
- _postfetch(
- mapper_rec,
- uowtransaction,
- table,
- state,
- state_dict,
- c,
- last_inserted_params,
- value_params)
+ if bookkeeping:
+ for (state, state_dict, params, mapper_rec,
+ conn, value_params, has_all_pks, has_all_defaults), \
+ last_inserted_params in \
+ zip(records, c.context.compiled_parameters):
+ _postfetch(
+ mapper_rec,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ last_inserted_params,
+ value_params)
else:
if not has_all_defaults and base_mapper.eager_defaults:
@@ -657,7 +801,10 @@ def _emit_post_update_statements(base_mapper, uowtransaction,
# also group them into common (connection, cols) sets
# to support executemany().
for key, grouper in groupby(
- update, lambda rec: (rec[1], sorted(rec[0]))
+ update, lambda rec: (
+ rec[1], # connection
+ set(rec[0]) # parameter keys
+ )
):
connection = key[0]
multiparams = [params for params, conn in grouper]
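Each of the _emit_* functions leans on this same itertools.groupby batching
pattern; a standalone illustration with synthetic records (note that groupby
only merges adjacent items, so non-adjacent records sharing a key simply form
separate batches):

    from itertools import groupby

    records = [
        ({"id": 1, "x": 10}, "conn_a"),
        ({"id": 2, "x": 20}, "conn_a"),
        ({"id": 3, "x": 30}, "conn_b"),
    ]
    for connection, recs in groupby(records, lambda rec: rec[1]):
        multiparams = [params for params, conn in recs]
        # one executemany() per group, e.g.:
        # cached_connections[connection].execute(statement, multiparams)
        print(connection, multiparams)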
@@ -671,7 +818,7 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
by _collect_delete_commands()."""
need_version_id = mapper.version_id_col is not None and \
- table.c.contains_column(mapper.version_id_col)
+ mapper.version_id_col in mapper._cols_by_table[table]
def delete_stmt():
clause = sql.and_()
@@ -693,12 +840,9 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
statement = base_mapper._memo(('delete', table), delete_stmt)
for connection, recs in groupby(
delete,
- lambda rec: rec[1]
+ lambda rec: rec[1] # connection
):
- del_objects = [
- params
- for params, connection in recs
- ]
+ del_objects = [params for params, connection in recs]
connection = cached_connections[connection]
@@ -775,9 +919,8 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
toload_now.extend(state._unloaded_non_object)
elif mapper.version_id_col is not None and \
mapper.version_id_generator is False:
- prop = mapper._columntoproperty[mapper.version_id_col]
- if prop.key in state.unloaded:
- toload_now.extend([prop.key])
+ if mapper._version_id_prop.key in state.unloaded:
+ toload_now.extend([mapper._version_id_prop.key])
if toload_now:
state.key = base_mapper._identity_key_from_state(state)
@@ -794,7 +937,7 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
def _postfetch(mapper, uowtransaction, table,
- state, dict_, result, params, value_params):
+ state, dict_, result, params, value_params, bulk=False):
"""Expire attributes in need of newly persisted database state,
after an INSERT or UPDATE statement has proceeded for that
state."""
@@ -803,7 +946,8 @@ def _postfetch(mapper, uowtransaction, table,
postfetch_cols = result.context.compiled.postfetch
returning_cols = result.context.compiled.returning
- if mapper.version_id_col is not None:
+ if mapper.version_id_col is not None and \
+ mapper.version_id_col in mapper._cols_by_table[table]:
prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
if returning_cols:
@@ -829,10 +973,13 @@ def _postfetch(mapper, uowtransaction, table,
# TODO: this still goes a little too often. would be nice to
# have definitive list of "columns that changed" here
for m, equated_pairs in mapper._table_to_equated[table]:
- sync.populate(state, m, state, m,
- equated_pairs,
- uowtransaction,
- mapper.passive_updates)
+ if state is None:
+ sync.bulk_populate_inherit_keys(dict_, m, equated_pairs)
+ else:
+ sync.populate(state, m, state, m,
+ equated_pairs,
+ uowtransaction,
+ mapper.passive_updates)
def _connections_for_states(base_mapper, uowtransaction, states):
@@ -883,6 +1030,7 @@ class BulkUD(object):
def __init__(self, query):
self.query = query.enable_eagerloads(False)
+ self.mapper = self.query._bind_mapper()
@property
def session(self):
@@ -977,6 +1125,7 @@ class BulkFetch(BulkUD):
self.primary_table.primary_key)
self.matched_rows = session.execute(
select_stmt,
+ mapper=self.mapper,
params=query._params).fetchall()
@@ -987,7 +1136,6 @@ class BulkUpdate(BulkUD):
super(BulkUpdate, self).__init__(query)
self.query._no_select_modifiers("update")
self.values = values
- self.mapper = self.query._mapper_zero_or_none()
@classmethod
def factory(cls, query, synchronize_session, values):
@@ -1033,7 +1181,8 @@ class BulkUpdate(BulkUD):
self.context.whereclause, values)
self.result = self.query.session.execute(
- update_stmt, params=self.query._params)
+ update_stmt, params=self.query._params,
+ mapper=self.mapper)
self.rowcount = self.result.rowcount
def _do_post(self):
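A hedged sketch of the public entry points that run through BulkUpdate and
BulkDelete; with a Session bound to several engines, the mapper= argument now
passed to session.execute() lets the session select the bind for the entity
being targeted (User, Address, and the engine URLs are illustrative):

    from sqlalchemy import create_engine
    from sqlalchemy.orm import Session

    # User and Address are mapped classes defined elsewhere (hypothetical)
    session = Session(binds={User: create_engine("postgresql://a/db"),
                             Address: create_engine("postgresql://b/db")})

    # UPDATE ... WHERE ... routed to User's bind via the mapper= argument
    session.query(User).filter(User.name == "x").update(
        {"name": "y"}, synchronize_session="evaluate")

    # DELETE likewise; 'fetch' runs through BulkFetch first to snapshot
    # the matched rows
    session.query(User).filter(User.name == "y").delete(
        synchronize_session="fetch")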
@@ -1060,8 +1209,10 @@ class BulkDelete(BulkUD):
delete_stmt = sql.delete(self.primary_table,
self.context.whereclause)
- self.result = self.query.session.execute(delete_stmt,
- params=self.query._params)
+ self.result = self.query.session.execute(
+ delete_stmt,
+ params=self.query._params,
+ mapper=self.mapper)
self.rowcount = self.result.rowcount
def _do_post(self):