diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2014-08-19 14:24:56 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2014-08-19 14:24:56 -0400 |
| commit | 91959122e0a12943e5ff9399024c65ad4d7489e1 (patch) | |
| tree | b1645a23ca575b4b2c87529029b44ee389bbf67d /lib/sqlalchemy/orm/persistence.py | |
| parent | a251001f24e819f1ebc525948437563f52a3a226 (diff) | |
| download | sqlalchemy-91959122e0a12943e5ff9399024c65ad4d7489e1.tar.gz | |
- refinements
Diffstat (limited to 'lib/sqlalchemy/orm/persistence.py')
| -rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 107 |
1 files changed, 72 insertions, 35 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 145a7783a..9c0008925 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -23,17 +23,22 @@ from ..sql import expression from . import loading -def bulk_insert(mapper, mappings, uowtransaction): +def _bulk_insert(mapper, mappings, session_transaction, isstates): base_mapper = mapper.base_mapper cached_connections = _cached_connection_dict(base_mapper) - if uowtransaction.session.connection_callable: + if session_transaction.session.connection_callable: raise NotImplementedError( "connection_callable / per-instance sharding " "not supported in bulk_insert()") - connection = uowtransaction.transaction.connection(base_mapper) + if isstates: + mappings = [state.dict for state in mappings] + else: + mappings = list(mappings) + + connection = session_transaction.connection(base_mapper) for table, super_mapper in base_mapper._sorted_tables.items(): if not mapper.isa(super_mapper): continue @@ -45,61 +50,55 @@ def bulk_insert(mapper, mappings, uowtransaction): state, state_dict, params, mp, conn, value_params, has_all_pks, has_all_defaults in _collect_insert_commands(table, ( - (None, mapping, super_mapper, connection) - for mapping in mappings) + (None, mapping, mapper, connection) + for mapping in mappings), + bulk=True ) ) - _emit_insert_statements(base_mapper, uowtransaction, + _emit_insert_statements(base_mapper, None, cached_connections, super_mapper, table, records, bookkeeping=False) -def bulk_update(mapper, mappings, uowtransaction): +def _bulk_update(mapper, mappings, session_transaction, isstates): base_mapper = mapper.base_mapper cached_connections = _cached_connection_dict(base_mapper) - if uowtransaction.session.connection_callable: + def _changed_dict(mapper, state): + return dict( + (k, v) + for k, v in state.dict.items() if k in state.committed_state or k + in mapper._primary_key_propkeys + ) + + if isstates: + mappings = [_changed_dict(mapper, state) for state in mappings] + else: + mappings = list(mappings) + + if session_transaction.session.connection_callable: raise NotImplementedError( "connection_callable / per-instance sharding " "not supported in bulk_update()") - connection = uowtransaction.transaction.connection(base_mapper) + connection = session_transaction.connection(base_mapper) value_params = {} + for table, super_mapper in base_mapper._sorted_tables.items(): if not mapper.isa(super_mapper): continue - label_pks = super_mapper._pks_by_table[table] - if mapper.version_id_col is not None: - label_pks = label_pks.union([mapper.version_id_col]) - - to_translate = dict( - (propkey, col._label if col in label_pks else col.key) - for propkey, col in super_mapper._propkey_to_col[table].items() + records = ( + (None, None, params, super_mapper, connection, value_params) + for + params in _collect_bulk_update_commands(mapper, table, mappings) ) - records = [] - for mapping in mappings: - params = dict( - (to_translate[k], v) for k, v in mapping.items() - ) - - if mapper.version_id_generator is not False and \ - mapper.version_id_col is not None and \ - mapper.version_id_col.key not in params: - params[mapper.version_id_col.key] = \ - mapper.version_id_generator( - params[mapper.version_id_col._label]) - - records.append( - (None, None, params, super_mapper, connection, value_params) - ) - - _emit_update_statements(base_mapper, uowtransaction, + _emit_update_statements(base_mapper, None, cached_connections, super_mapper, table, records, bookkeeping=False) @@ -360,7 +359,7 @@ def _collect_insert_commands(table, states_to_insert, bulk=False): col = propkey_to_col[propkey] if value is None: continue - elif isinstance(value, sql.ClauseElement): + elif not bulk and isinstance(value, sql.ClauseElement): value_params[col.key] = value else: params[col.key] = value @@ -481,6 +480,44 @@ def _collect_update_commands(uowtransaction, table, states_to_update): state, state_dict, params, mapper, connection, value_params) +def _collect_bulk_update_commands(mapper, table, mappings): + label_pks = mapper._pks_by_table[table] + if mapper.version_id_col is not None: + label_pks = label_pks.union([mapper.version_id_col]) + + to_translate = dict( + (propkey, col.key if col not in label_pks else col._label) + for propkey, col in mapper._propkey_to_col[table].items() + ) + + for mapping in mappings: + params = dict( + (to_translate[k], mapping[k]) for k in to_translate + if k in mapping and k not in mapper._primary_key_propkeys + ) + + if not params: + continue + + try: + params.update( + (to_translate[k], mapping[k]) for k in + mapper._primary_key_propkeys.intersection(to_translate) + ) + except KeyError as ke: + raise orm_exc.FlushError( + "Can't update table using NULL for primary " + "key attribute: %s" % ke) + + if mapper.version_id_generator is not False and \ + mapper.version_id_col is not None and \ + mapper.version_id_col.key not in params: + params[mapper.version_id_col.key] = \ + mapper.version_id_generator( + params[mapper.version_id_col._label]) + + yield params + def _collect_post_update_commands(base_mapper, uowtransaction, table, states_to_update, post_update_cols): |
