author     Mike Bayer <mike_mp@zzzcomputing.com>    2014-08-18 16:32:48 -0400
committer  Mike Bayer <mike_mp@zzzcomputing.com>    2014-08-18 16:32:59 -0400
commit     06dec268e53e999bd348ef2ca148def066ca30d6 (patch)
tree       6c63af264205f4f0da0925e70657088c7d950632 /lib
parent     d39927ec20dd0b66f4ab3aab3e4e67b3814186ce (diff)
download   sqlalchemy-06dec268e53e999bd348ef2ca148def066ca30d6.tar.gz
- organize persistence methods in terms of generators; narrow down the
argument lists and generator items for each function to just what that
function needs. This will help make them more broadly reusable for bulk
operations.
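As a rough, self-contained sketch of the generator-chaining idea described
above (the names and data below are hypothetical stand-ins, not the real ORM
internals), each stage yields only the fields the next stage needs and the
final emit step consumes the chain:

    from itertools import groupby

    def organize(rows):
        # stand-in for an _organize_states_* step: yield one record per row
        for row in rows:
            yield row["obj"], row["params"], row["connection"]

    def collect(organized):
        # stand-in for a _collect_*_commands step: narrow each record down
        # to what the emit step actually needs (params + connection)
        for obj, params, connection in organized:
            yield params, connection

    def emit(collected):
        # stand-in for an _emit_*_statements step: group by connection and
        # "execute" each batch with executemany-style multiparams
        for connection, recs in groupby(collected, key=lambda rec: rec[1]):
            multiparams = [params for params, _conn in recs]
            print(connection, multiparams)

    rows = [
        {"obj": "a", "params": {"x": 1}, "connection": "conn1"},
        {"obj": "b", "params": {"x": 2}, "connection": "conn1"},
    ]
    emit(collect(organize(rows)))   # no intermediate lists are built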
Diffstat (limited to 'lib')
-rw-r--r--  lib/sqlalchemy/orm/persistence.py | 187
1 file changed, 94 insertions(+), 93 deletions(-)
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 228cfef3a..c7850ac1d 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -40,32 +40,58 @@ def save_obj(base_mapper, states, uowtransaction, single=False):
save_obj(base_mapper, [state], uowtransaction, single=True)
return
- states_to_insert, states_to_update = _organize_states_for_save(
- base_mapper,
- states,
- uowtransaction)
-
+ states_to_update = []
+ states_to_insert = []
cached_connections = _cached_connection_dict(base_mapper)
- for table, mapper in base_mapper._sorted_tables.items():
- insert = _collect_insert_commands(base_mapper, uowtransaction,
- table, states_to_insert)
-
- update = _collect_update_commands(base_mapper, uowtransaction,
- table, states_to_update)
-
- if update:
- _emit_update_statements(base_mapper, uowtransaction,
- cached_connections,
- mapper, table, update)
-
- if insert:
- _emit_insert_statements(base_mapper, uowtransaction,
- cached_connections,
- mapper, table, insert)
+ for (state, dict_, mapper, connection,
+ has_identity, row_switch) in _organize_states_for_save(
+ base_mapper, states, uowtransaction
+ ):
+ if has_identity or row_switch:
+ states_to_update.append(
+ (state, dict_, mapper, connection,
+ has_identity, row_switch)
+ )
+ else:
+ states_to_insert.append(
+ (state, dict_, mapper, connection,
+ has_identity, row_switch)
+ )
- _finalize_insert_update_commands(base_mapper, uowtransaction,
- states_to_insert, states_to_update)
+ for table, mapper in base_mapper._sorted_tables.items():
+ if table not in mapper._pks_by_table:
+ continue
+ insert = (
+ (state, state_dict, mapper, connection)
+ for state, state_dict, mapper, connection, has_identity,
+ row_switch in states_to_insert
+ )
+ insert = _collect_insert_commands(table, insert)
+
+ update = (
+ (state, state_dict, mapper, connection, row_switch)
+ for state, state_dict, mapper, connection, has_identity,
+ row_switch in states_to_update
+ )
+ update = _collect_update_commands(uowtransaction, table, update)
+
+ _emit_update_statements(base_mapper, uowtransaction,
+ cached_connections,
+ mapper, table, update)
+
+ _emit_insert_statements(base_mapper, uowtransaction,
+ cached_connections,
+ mapper, table, insert)
+
+ _finalize_insert_update_commands(
+ base_mapper, uowtransaction,
+ (
+ (state, state_dict, mapper, connection, has_identity)
+ for state, state_dict, mapper, connection, has_identity,
+ row_switch in states_to_insert + states_to_update
+ )
+ )
def post_update(base_mapper, states, uowtransaction, post_update_cols):
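A minimal sketch (with hypothetical tuples, not real InstanceState objects) of
the single-pass partitioning save_obj() now performs over the organizing
generator, splitting records into INSERT and UPDATE buckets:

    def partition_for_save(organized):
        to_insert, to_update = [], []
        for rec in organized:
            name, has_identity, row_switch = rec
            # an existing identity or a row switch means UPDATE, otherwise INSERT
            (to_update if has_identity or row_switch else to_insert).append(rec)
        return to_insert, to_update

    records = iter([
        ("pending", False, False),
        ("persistent", True, False),
        ("switched", False, True),
    ])
    inserts, updates = partition_for_save(records)
    assert [r[0] for r in inserts] == ["pending"]
    assert [r[0] for r in updates] == ["persistent", "switched"]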
@@ -75,19 +101,20 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols):
"""
cached_connections = _cached_connection_dict(base_mapper)
- states_to_update = _organize_states_for_post_update(
+ states_to_update = list(_organize_states_for_post_update(
base_mapper,
- states, uowtransaction)
+ states, uowtransaction))
for table, mapper in base_mapper._sorted_tables.items():
+ if table not in mapper._pks_by_table:
+ continue
update = _collect_post_update_commands(base_mapper, uowtransaction,
table, states_to_update,
post_update_cols)
- if update:
- _emit_post_update_statements(base_mapper, uowtransaction,
- cached_connections,
- mapper, table, update)
+ _emit_post_update_statements(base_mapper, uowtransaction,
+ cached_connections,
+ mapper, table, update)
def delete_obj(base_mapper, states, uowtransaction):
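Because the organizing step now returns a generator, post_update() (and
delete_obj() below) materializes it with list() before the per-table loop: a
generator can only be consumed once, while the same set of states has to be
walked again for every table. A tiny illustration of the pitfall, using a
throwaway generator:

    def gen():
        yield 1
        yield 2

    g = gen()
    assert list(g) == [1, 2]
    assert list(g) == []                    # exhausted: a second pass sees nothing

    states = list(gen())                    # materialize once up front...
    assert [s for s in states] == [1, 2]
    assert [s for s in states] == [1, 2]    # ...so each table's pass sees all states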
@@ -100,19 +127,21 @@ def delete_obj(base_mapper, states, uowtransaction):
cached_connections = _cached_connection_dict(base_mapper)
- states_to_delete = _organize_states_for_delete(
+ states_to_delete = list(_organize_states_for_delete(
base_mapper,
states,
- uowtransaction)
+ uowtransaction))
table_to_mapper = base_mapper._sorted_tables
for table in reversed(list(table_to_mapper.keys())):
+ mapper = table_to_mapper[table]
+ if table not in mapper._pks_by_table:
+ continue
+
delete = _collect_delete_commands(base_mapper, uowtransaction,
table, states_to_delete)
- mapper = table_to_mapper[table]
-
_emit_delete_statements(base_mapper, uowtransaction,
cached_connections, mapper, table, delete)
@@ -133,9 +162,6 @@ def _organize_states_for_save(base_mapper, states, uowtransaction):
"""
- states_to_insert = []
- states_to_update = []
-
for state, dict_, mapper, connection in _connections_for_states(
base_mapper, uowtransaction,
states):
@@ -181,18 +207,8 @@ def _organize_states_for_save(base_mapper, states, uowtransaction):
uowtransaction.remove_state_actions(existing)
row_switch = existing
- if not has_identity and not row_switch:
- states_to_insert.append(
- (state, dict_, mapper, connection,
- has_identity, row_switch)
- )
- else:
- states_to_update.append(
- (state, dict_, mapper, connection,
- has_identity, row_switch)
- )
-
- return states_to_insert, states_to_update
+ yield (state, dict_, mapper, connection,
+ has_identity, row_switch)
def _organize_states_for_post_update(base_mapper, states,
@@ -205,8 +221,7 @@ def _organize_states_for_post_update(base_mapper, states,
the execution per state.
"""
- return list(_connections_for_states(base_mapper, uowtransaction,
- states))
+ return _connections_for_states(base_mapper, uowtransaction, states)
def _organize_states_for_delete(base_mapper, states, uowtransaction):
@@ -217,28 +232,21 @@ def _organize_states_for_delete(base_mapper, states, uowtransaction):
mapper, the connection to use for the execution per state.
"""
- states_to_delete = []
-
for state, dict_, mapper, connection in _connections_for_states(
base_mapper, uowtransaction,
states):
mapper.dispatch.before_delete(mapper, connection, state)
- states_to_delete.append((state, dict_, mapper,
- bool(state.key), connection))
- return states_to_delete
+ yield state, dict_, mapper, bool(state.key), connection
-def _collect_insert_commands(base_mapper, uowtransaction, table,
- states_to_insert):
+def _collect_insert_commands(table, states_to_insert):
"""Identify sets of values to use in INSERT statements for a
list of states.
"""
- insert = []
- for state, state_dict, mapper, connection, has_identity, \
- row_switch in states_to_insert:
+ for state, state_dict, mapper, connection in states_to_insert:
if table not in mapper._pks_by_table:
continue
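The organizers and collectors above all follow the same conversion: rather
than appending to a local list and returning it, each function yields its
records as they are produced, so the caller can chain generators without
building intermediate lists. A minimal before/after sketch of that refactor
(hypothetical collector, not the real _collect_insert_commands):

    # before: build a list and return it
    def collect_list(states):
        out = []
        for state in states:
            out.append((state, {"col": state}))
        return out

    # after: yield each record as it is produced; the caller iterates lazily
    def collect_gen(states):
        for state in states:
            yield state, {"col": state}

    assert collect_list([1, 2]) == list(collect_gen([1, 2]))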
@@ -262,7 +270,7 @@ def _collect_insert_commands(base_mapper, uowtransaction, table,
has_all_pks = mapper._pk_keys_by_table[table].issubset(params)
- if base_mapper.eager_defaults:
+ if mapper.base_mapper.eager_defaults:
has_all_defaults = mapper._server_default_cols[table].\
issubset(params)
else:
@@ -274,14 +282,13 @@ def _collect_insert_commands(base_mapper, uowtransaction, table,
params[mapper.version_id_col.key] = \
mapper.version_id_generator(None)
- insert.append((state, state_dict, params, mapper,
- connection, value_params, has_all_pks,
- has_all_defaults))
- return insert
+ yield (
+ state, state_dict, params, mapper,
+ connection, value_params, has_all_pks,
+ has_all_defaults)
-def _collect_update_commands(base_mapper, uowtransaction,
- table, states_to_update):
+def _collect_update_commands(uowtransaction, table, states_to_update):
"""Identify sets of values to use in UPDATE statements for a
list of states.
@@ -293,9 +300,7 @@ def _collect_update_commands(base_mapper, uowtransaction,
"""
- update = []
- for state, state_dict, mapper, connection, has_identity, \
- row_switch in states_to_update:
+ for state, state_dict, mapper, connection, row_switch in states_to_update:
if table not in mapper._pks_by_table:
continue
@@ -368,9 +373,9 @@ def _collect_update_commands(base_mapper, uowtransaction,
"Can't update table using NULL for primary "
"key value")
params.update(pk_params)
- update.append((state, state_dict, params, mapper,
- connection, value_params))
- return update
+ yield (
+ state, state_dict, params, mapper,
+ connection, value_params)
def _collect_post_update_commands(base_mapper, uowtransaction, table,
@@ -380,7 +385,6 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table,
"""
- update = []
for state, state_dict, mapper, connection in states_to_update:
if table not in mapper._pks_by_table:
continue
@@ -405,9 +409,7 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table,
params[col.key] = value
hasdata = True
if hasdata:
- update.append((state, state_dict, params, mapper,
- connection))
- return update
+ yield params, connection
def _collect_delete_commands(base_mapper, uowtransaction, table,
@@ -415,15 +417,12 @@ def _collect_delete_commands(base_mapper, uowtransaction, table,
"""Identify values to use in DELETE statements for a list of
states to be deleted."""
- delete = util.defaultdict(list)
-
for state, state_dict, mapper, has_identity, connection \
in states_to_delete:
if not has_identity or table not in mapper._pks_by_table:
continue
params = {}
- delete[connection].append(params)
for col in mapper._pks_by_table[table]:
params[col.key] = \
value = \
@@ -441,7 +440,7 @@ def _collect_delete_commands(base_mapper, uowtransaction, table,
mapper._get_committed_state_attr_by_column(
state, state_dict,
mapper.version_id_col)
- return delete
+ yield params, connection
def _emit_update_statements(base_mapper, uowtransaction,
@@ -481,8 +480,7 @@ def _emit_update_statements(base_mapper, uowtransaction,
lambda rec: (
rec[4],
tuple(sorted(rec[2])),
- bool(rec[5]))
- ):
+ bool(rec[5]))):
rows = 0
records = list(records)
@@ -652,11 +650,10 @@ def _emit_post_update_statements(base_mapper, uowtransaction,
# also group them into common (connection, cols) sets
# to support executemany().
for key, grouper in groupby(
- update, lambda rec: (rec[4], list(rec[2].keys()))
+ update, lambda rec: (rec[1], sorted(rec[0]))
):
connection = key[0]
- multiparams = [params for state, state_dict,
- params, mapper, conn in grouper]
+ multiparams = [params for params, conn in grouper]
cached_connections[connection].\
execute(statement, multiparams)
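The grouping key here matters because itertools.groupby only merges
consecutive records whose keys compare equal: keying on the connection plus
the sorted parameter names lets records that target the same column set (in
any dict order) collapse into one executemany-style batch. A small sketch with
hypothetical (params, connection) records:

    from itertools import groupby

    records = [
        ({"x": 1, "y": 2}, "conn1"),
        ({"y": 3, "x": 4}, "conn1"),   # same column set, different order -> same batch
        ({"x": 5}, "conn1"),           # different column set -> new batch
    ]
    for (connection, cols), recs in groupby(
            records, key=lambda rec: (rec[1], sorted(rec[0]))):
        multiparams = [params for params, _conn in recs]
        print(connection, cols, multiparams)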
@@ -686,8 +683,15 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
return table.delete(clause)
- for connection, del_objects in delete.items():
- statement = base_mapper._memo(('delete', table), delete_stmt)
+ statement = base_mapper._memo(('delete', table), delete_stmt)
+ for connection, recs in groupby(
+ delete,
+ lambda rec: rec[1]
+ ):
+ del_objects = [
+ params
+ for params, connection in recs
+ ]
connection = cached_connections[connection]
@@ -740,15 +744,12 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
)
-def _finalize_insert_update_commands(base_mapper, uowtransaction,
- states_to_insert, states_to_update):
+def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
"""finalize state on states that have been inserted or updated,
including calling after_insert/after_update events.
"""
- for state, state_dict, mapper, connection, has_identity, \
- row_switch in states_to_insert + \
- states_to_update:
+ for state, state_dict, mapper, connection, has_identity in states:
if mapper._readonly_props:
readonly = state.unmodified_intersection(