summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/persistence.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2014-08-19 14:24:56 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2014-08-19 14:24:56 -0400
commit91959122e0a12943e5ff9399024c65ad4d7489e1 (patch)
treeb1645a23ca575b4b2c87529029b44ee389bbf67d /lib/sqlalchemy/orm/persistence.py
parenta251001f24e819f1ebc525948437563f52a3a226 (diff)
downloadsqlalchemy-91959122e0a12943e5ff9399024c65ad4d7489e1.tar.gz
- refinements
Diffstat (limited to 'lib/sqlalchemy/orm/persistence.py')
-rw-r--r--lib/sqlalchemy/orm/persistence.py107
1 files changed, 72 insertions, 35 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 145a7783a..9c0008925 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -23,17 +23,22 @@ from ..sql import expression
from . import loading
-def bulk_insert(mapper, mappings, uowtransaction):
+def _bulk_insert(mapper, mappings, session_transaction, isstates):
base_mapper = mapper.base_mapper
cached_connections = _cached_connection_dict(base_mapper)
- if uowtransaction.session.connection_callable:
+ if session_transaction.session.connection_callable:
raise NotImplementedError(
"connection_callable / per-instance sharding "
"not supported in bulk_insert()")
- connection = uowtransaction.transaction.connection(base_mapper)
+ if isstates:
+ mappings = [state.dict for state in mappings]
+ else:
+ mappings = list(mappings)
+
+ connection = session_transaction.connection(base_mapper)
for table, super_mapper in base_mapper._sorted_tables.items():
if not mapper.isa(super_mapper):
continue
@@ -45,61 +50,55 @@ def bulk_insert(mapper, mappings, uowtransaction):
state, state_dict, params, mp,
conn, value_params, has_all_pks,
has_all_defaults in _collect_insert_commands(table, (
- (None, mapping, super_mapper, connection)
- for mapping in mappings)
+ (None, mapping, mapper, connection)
+ for mapping in mappings),
+ bulk=True
)
)
- _emit_insert_statements(base_mapper, uowtransaction,
+ _emit_insert_statements(base_mapper, None,
cached_connections,
super_mapper, table, records,
bookkeeping=False)
-def bulk_update(mapper, mappings, uowtransaction):
+def _bulk_update(mapper, mappings, session_transaction, isstates):
base_mapper = mapper.base_mapper
cached_connections = _cached_connection_dict(base_mapper)
- if uowtransaction.session.connection_callable:
+ def _changed_dict(mapper, state):
+ return dict(
+ (k, v)
+ for k, v in state.dict.items() if k in state.committed_state or k
+ in mapper._primary_key_propkeys
+ )
+
+ if isstates:
+ mappings = [_changed_dict(mapper, state) for state in mappings]
+ else:
+ mappings = list(mappings)
+
+ if session_transaction.session.connection_callable:
raise NotImplementedError(
"connection_callable / per-instance sharding "
"not supported in bulk_update()")
- connection = uowtransaction.transaction.connection(base_mapper)
+ connection = session_transaction.connection(base_mapper)
value_params = {}
+
for table, super_mapper in base_mapper._sorted_tables.items():
if not mapper.isa(super_mapper):
continue
- label_pks = super_mapper._pks_by_table[table]
- if mapper.version_id_col is not None:
- label_pks = label_pks.union([mapper.version_id_col])
-
- to_translate = dict(
- (propkey, col._label if col in label_pks else col.key)
- for propkey, col in super_mapper._propkey_to_col[table].items()
+ records = (
+ (None, None, params, super_mapper, connection, value_params)
+ for
+ params in _collect_bulk_update_commands(mapper, table, mappings)
)
- records = []
- for mapping in mappings:
- params = dict(
- (to_translate[k], v) for k, v in mapping.items()
- )
-
- if mapper.version_id_generator is not False and \
- mapper.version_id_col is not None and \
- mapper.version_id_col.key not in params:
- params[mapper.version_id_col.key] = \
- mapper.version_id_generator(
- params[mapper.version_id_col._label])
-
- records.append(
- (None, None, params, super_mapper, connection, value_params)
- )
-
- _emit_update_statements(base_mapper, uowtransaction,
+ _emit_update_statements(base_mapper, None,
cached_connections,
super_mapper, table, records,
bookkeeping=False)
@@ -360,7 +359,7 @@ def _collect_insert_commands(table, states_to_insert, bulk=False):
col = propkey_to_col[propkey]
if value is None:
continue
- elif isinstance(value, sql.ClauseElement):
+ elif not bulk and isinstance(value, sql.ClauseElement):
value_params[col.key] = value
else:
params[col.key] = value
@@ -481,6 +480,44 @@ def _collect_update_commands(uowtransaction, table, states_to_update):
state, state_dict, params, mapper,
connection, value_params)
+def _collect_bulk_update_commands(mapper, table, mappings):
+ label_pks = mapper._pks_by_table[table]
+ if mapper.version_id_col is not None:
+ label_pks = label_pks.union([mapper.version_id_col])
+
+ to_translate = dict(
+ (propkey, col.key if col not in label_pks else col._label)
+ for propkey, col in mapper._propkey_to_col[table].items()
+ )
+
+ for mapping in mappings:
+ params = dict(
+ (to_translate[k], mapping[k]) for k in to_translate
+ if k in mapping and k not in mapper._primary_key_propkeys
+ )
+
+ if not params:
+ continue
+
+ try:
+ params.update(
+ (to_translate[k], mapping[k]) for k in
+ mapper._primary_key_propkeys.intersection(to_translate)
+ )
+ except KeyError as ke:
+ raise orm_exc.FlushError(
+ "Can't update table using NULL for primary "
+ "key attribute: %s" % ke)
+
+ if mapper.version_id_generator is not False and \
+ mapper.version_id_col is not None and \
+ mapper.version_id_col.key not in params:
+ params[mapper.version_id_col.key] = \
+ mapper.version_id_generator(
+ params[mapper.version_id_col._label])
+
+ yield params
+
def _collect_post_update_commands(base_mapper, uowtransaction, table,
states_to_update, post_update_cols):