From 191fd3e27e3ef90190f8315c33ba6eb97aeaf5d2 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 14 Aug 2014 15:38:30 -0400 Subject: - proof of concept --- lib/sqlalchemy/orm/events.py | 9 +++++ lib/sqlalchemy/orm/persistence.py | 81 +++++++++++++++++++++------------------ lib/sqlalchemy/orm/session.py | 34 ++++++++++++++++ lib/sqlalchemy/orm/unitofwork.py | 28 +++++++++++++- 4 files changed, 113 insertions(+), 39 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index aa99673ba..097726c62 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -1453,6 +1453,15 @@ class SessionEvents(event.Events): """ + def before_bulk_save(self, session, flush_context, objects): + """""" + + def after_bulk_save(self, session, flush_context, objects): + """""" + + def after_bulk_save_postexec(self, session, flush_context, objects): + """""" + def after_begin(self, session, transaction, connection): """Execute after a transaction is begun on a connection diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 9d39c39b0..511a324be 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -23,7 +23,9 @@ from ..sql import expression from . import loading -def save_obj(base_mapper, states, uowtransaction, single=False): +def save_obj( + base_mapper, states, uowtransaction, single=False, + bookkeeping=True): """Issue ``INSERT`` and/or ``UPDATE`` statements for a list of objects. 
@@ -43,13 +45,14 @@ def save_obj(base_mapper, states, uowtransaction, single=False): states_to_insert, states_to_update = _organize_states_for_save( base_mapper, states, - uowtransaction) + uowtransaction, bookkeeping) cached_connections = _cached_connection_dict(base_mapper) for table, mapper in base_mapper._sorted_tables.items(): insert = _collect_insert_commands(base_mapper, uowtransaction, - table, states_to_insert) + table, states_to_insert, + bookkeeping) update = _collect_update_commands(base_mapper, uowtransaction, table, states_to_update) @@ -65,7 +68,8 @@ def save_obj(base_mapper, states, uowtransaction, single=False): mapper, table, insert) _finalize_insert_update_commands(base_mapper, uowtransaction, - states_to_insert, states_to_update) + states_to_insert, states_to_update, + bookkeeping) def post_update(base_mapper, states, uowtransaction, post_update_cols): @@ -121,7 +125,8 @@ def delete_obj(base_mapper, states, uowtransaction): mapper.dispatch.after_delete(mapper, connection, state) -def _organize_states_for_save(base_mapper, states, uowtransaction): +def _organize_states_for_save( + base_mapper, states, uowtransaction, bookkeeping): """Make an initial pass across a set of states for INSERT or UPDATE. @@ -158,7 +163,7 @@ def _organize_states_for_save(base_mapper, states, uowtransaction): # no instance_key attached to it), and another instance # with the same identity key already exists as persistent. # convert to an UPDATE if so. - if not has_identity and \ + if bookkeeping and not has_identity and \ instance_key in uowtransaction.session.identity_map: instance = \ uowtransaction.session.identity_map[instance_key] @@ -230,7 +235,7 @@ def _organize_states_for_delete(base_mapper, states, uowtransaction): def _collect_insert_commands(base_mapper, uowtransaction, table, - states_to_insert): + states_to_insert, bookkeeping): """Identify sets of values to use in INSERT statements for a list of states. 
@@ -261,12 +266,12 @@ def _collect_insert_commands(base_mapper, uowtransaction, table, value = state_dict.get(prop.key, None) if value is None: - if col in pks: + if bookkeeping and col in pks: has_all_pks = False elif col.default is None and \ col.server_default is None: params[col.key] = value - elif col.server_default is not None and \ + elif bookkeeping and col.server_default is not None and \ mapper.base_mapper.eager_defaults: has_all_defaults = False @@ -756,7 +761,8 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, def _finalize_insert_update_commands(base_mapper, uowtransaction, - states_to_insert, states_to_update): + states_to_insert, states_to_update, + bookkeeping): """finalize state on states that have been inserted or updated, including calling after_insert/after_update events. @@ -765,33 +771,34 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, instance_key, row_switch in states_to_insert + \ states_to_update: - if mapper._readonly_props: - readonly = state.unmodified_intersection( - [p.key for p in mapper._readonly_props - if p.expire_on_flush or p.key not in state.dict] - ) - if readonly: - state._expire_attributes(state.dict, readonly) - - # if eager_defaults option is enabled, load - # all expired cols. Else if we have a version_id_col, make sure - # it isn't expired. 
- toload_now = [] - - if base_mapper.eager_defaults: - toload_now.extend(state._unloaded_non_object) - elif mapper.version_id_col is not None and \ - mapper.version_id_generator is False: - prop = mapper._columntoproperty[mapper.version_id_col] - if prop.key in state.unloaded: - toload_now.extend([prop.key]) - - if toload_now: - state.key = base_mapper._identity_key_from_state(state) - loading.load_on_ident( - uowtransaction.session.query(base_mapper), - state.key, refresh_state=state, - only_load_props=toload_now) + if bookkeeping: + if mapper._readonly_props: + readonly = state.unmodified_intersection( + [p.key for p in mapper._readonly_props + if p.expire_on_flush or p.key not in state.dict] + ) + if readonly: + state._expire_attributes(state.dict, readonly) + + # if eager_defaults option is enabled, load + # all expired cols. Else if we have a version_id_col, make sure + # it isn't expired. + toload_now = [] + + if base_mapper.eager_defaults: + toload_now.extend(state._unloaded_non_object) + elif mapper.version_id_col is not None and \ + mapper.version_id_generator is False: + prop = mapper._columntoproperty[mapper.version_id_col] + if prop.key in state.unloaded: + toload_now.extend([prop.key]) + + if toload_now: + state.key = base_mapper._identity_key_from_state(state) + loading.load_on_ident( + uowtransaction.session.query(base_mapper), + state.key, refresh_state=state, + only_load_props=toload_now) # call after_XXX extensions if not has_identity: diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 036045dba..2455c803a 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -2033,6 +2033,40 @@ class Session(_SessionClassMethods): with util.safe_reraise(): transaction.rollback(_capture_exception=True) + def bulk_save(self, objects): + self._flushing = True + flush_context = UOWTransaction(self) + + if self.dispatch.before_bulk_save: + self.dispatch.before_bulk_save( + self, flush_context, objects) + + 
flush_context.transaction = transaction = self.begin( + subtransactions=True) + try: + self._warn_on_events = True + try: + flush_context.bulk_save(objects) + finally: + self._warn_on_events = False + + self.dispatch.after_bulk_save( + self, flush_context, objects + ) + + flush_context.finalize_flush_changes() + + self.dispatch.after_bulk_save_postexec( + self, flush_context, objects) + + transaction.commit() + + except: + with util.safe_reraise(): + transaction.rollback(_capture_exception=True) + finally: + self._flushing = False + def is_modified(self, instance, include_collections=True, passive=True): """Return ``True`` if the given instance has locally diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 71e61827b..8df24e95a 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -16,6 +16,7 @@ organizes them in order of dependency, and executes. from .. import util, event from ..util import topological from . import attributes, persistence, util as orm_util +import itertools def track_cascade_events(descriptor, prop): @@ -379,14 +380,37 @@ class UOWTransaction(object): execute() method has succeeded and the transaction has been committed. 
""" + if not self.states: + return + states = set(self.states) isdel = set( s for (s, (isdelete, listonly)) in self.states.items() if isdelete ) other = states.difference(isdel) - self.session._remove_newly_deleted(isdel) - self.session._register_newly_persistent(other) + if isdel: + self.session._remove_newly_deleted(isdel) + if other: + self.session._register_newly_persistent(other) + + def bulk_save(self, objects): + for (base_mapper, in_session), states in itertools.groupby( + (attributes.instance_state(obj) for obj in objects), + lambda state: + ( + state.mapper.base_mapper, + state.key is self.session.hash_key + )): + + persistence.save_obj( + base_mapper, list(states), self, bookkeeping=in_session) + + if in_session: + self.states.update( + (state, (False, False)) + for state in states + ) class IterateMappersMixin(object): -- cgit v1.2.1 From 6bc676f56d57d5ea4dc298f63d0e3a77c0f4a4a1 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 14 Aug 2014 17:44:58 -0400 Subject: dev --- lib/sqlalchemy/orm/persistence.py | 59 ++++++++++++++++++++++----------------- lib/sqlalchemy/orm/session.py | 23 +++++++++++---- lib/sqlalchemy/orm/state.py | 15 ++++++++++ lib/sqlalchemy/orm/unitofwork.py | 10 +++---- 4 files changed, 70 insertions(+), 37 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 511a324be..64c8440c4 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -18,7 +18,7 @@ import operator from itertools import groupby from .. import sql, util, exc as sa_exc, schema from . import attributes, sync, exc as orm_exc, evaluator -from .base import _state_mapper, state_str, _attr_as_key +from .base import state_str, _attr_as_key from ..sql import expression from . 
import loading @@ -65,7 +65,8 @@ def save_obj( if insert: _emit_insert_statements(base_mapper, uowtransaction, cached_connections, - mapper, table, insert) + mapper, table, insert, + bookkeeping) _finalize_insert_update_commands(base_mapper, uowtransaction, states_to_insert, states_to_update, @@ -140,13 +141,16 @@ def _organize_states_for_save( states_to_insert = [] states_to_update = [] + instance_key = None for state, dict_, mapper, connection in _connections_for_states( base_mapper, uowtransaction, states): has_identity = bool(state.key) - instance_key = state.key or mapper._identity_key_from_state(state) + + if bookkeeping: + instance_key = state.key or mapper._identity_key_from_state(state) row_switch = None @@ -188,12 +192,12 @@ def _organize_states_for_save( if not has_identity and not row_switch: states_to_insert.append( (state, dict_, mapper, connection, - has_identity, instance_key, row_switch) + has_identity, row_switch) ) else: states_to_update.append( (state, dict_, mapper, connection, - has_identity, instance_key, row_switch) + has_identity, row_switch) ) return states_to_insert, states_to_update @@ -242,7 +246,8 @@ def _collect_insert_commands(base_mapper, uowtransaction, table, """ insert = [] for state, state_dict, mapper, connection, has_identity, \ - instance_key, row_switch in states_to_insert: + row_switch in states_to_insert: + if table not in mapper._pks_by_table: continue @@ -265,13 +270,13 @@ def _collect_insert_commands(base_mapper, uowtransaction, table, prop = mapper._columntoproperty[col] value = state_dict.get(prop.key, None) - if value is None: - if bookkeeping and col in pks: + if bookkeeping and value is None: + if col in pks: has_all_pks = False elif col.default is None and \ col.server_default is None: params[col.key] = value - elif bookkeeping and col.server_default is not None and \ + elif col.server_default is not None and \ mapper.base_mapper.eager_defaults: has_all_defaults = False @@ -301,7 +306,7 @@ def 
_collect_update_commands(base_mapper, uowtransaction, update = [] for state, state_dict, mapper, connection, has_identity, \ - instance_key, row_switch in states_to_update: + row_switch in states_to_update: if table not in mapper._pks_by_table: continue @@ -567,7 +572,8 @@ def _emit_update_statements(base_mapper, uowtransaction, def _emit_insert_statements(base_mapper, uowtransaction, - cached_connections, mapper, table, insert): + cached_connections, mapper, table, insert, + bookkeeping): """Emit INSERT statements corresponding to value lists collected by _collect_insert_commands().""" @@ -593,19 +599,20 @@ def _emit_insert_statements(base_mapper, uowtransaction, c = cached_connections[connection].\ execute(statement, multiparams) - for (state, state_dict, params, mapper_rec, - conn, value_params, has_all_pks, has_all_defaults), \ - last_inserted_params in \ - zip(records, c.context.compiled_parameters): - _postfetch( - mapper_rec, - uowtransaction, - table, - state, - state_dict, - c, - last_inserted_params, - value_params) + if bookkeeping: + for (state, state_dict, params, mapper_rec, + conn, value_params, has_all_pks, has_all_defaults), \ + last_inserted_params in \ + zip(records, c.context.compiled_parameters): + _postfetch( + mapper_rec, + uowtransaction, + table, + state, + state_dict, + c, + last_inserted_params, + value_params) else: if not has_all_defaults and base_mapper.eager_defaults: @@ -768,7 +775,7 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, """ for state, state_dict, mapper, connection, has_identity, \ - instance_key, row_switch in states_to_insert + \ + row_switch in states_to_insert + \ states_to_update: if bookkeeping: @@ -871,7 +878,7 @@ def _connections_for_states(base_mapper, uowtransaction, states): if connection_callable: connection = connection_callable(base_mapper, state.obj()) - mapper = _state_mapper(state) + mapper = state.manager.mapper yield state, state.dict, mapper, connection diff --git 
a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 2455c803a..546355611 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -482,7 +482,7 @@ class Session(_SessionClassMethods): '__contains__', '__iter__', 'add', 'add_all', 'begin', 'begin_nested', 'close', 'commit', 'connection', 'delete', 'execute', 'expire', 'expire_all', 'expunge', 'expunge_all', 'flush', 'get_bind', - 'is_modified', + 'is_modified', 'bulk_save_objects', 'bulk_save_mappings', 'merge', 'query', 'refresh', 'rollback', 'scalar') @@ -2033,31 +2033,42 @@ class Session(_SessionClassMethods): with util.safe_reraise(): transaction.rollback(_capture_exception=True) - def bulk_save(self, objects): + def bulk_save_objects(self, objects): + self._bulk_save((attributes.instance_state(obj) for obj in objects)) + + def bulk_save_mappings(self, mapper, mappings): + mapper = class_mapper(mapper) + + self._bulk_save(( + statelib.MappingState(mapper, mapping) + for mapping in mappings) + ) + + def _bulk_save(self, states): self._flushing = True flush_context = UOWTransaction(self) if self.dispatch.before_bulk_save: self.dispatch.before_bulk_save( - self, flush_context, objects) + self, flush_context, states) flush_context.transaction = transaction = self.begin( subtransactions=True) try: self._warn_on_events = True try: - flush_context.bulk_save(objects) + flush_context.bulk_save(states) finally: self._warn_on_events = False self.dispatch.after_bulk_save( - self, flush_context, objects + self, flush_context, states ) flush_context.finalize_flush_changes() self.dispatch.after_bulk_save_postexec( - self, flush_context, objects) + self, flush_context, states) transaction.commit() diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index fe8ccd222..e941bc1a4 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -580,6 +580,21 @@ class InstanceState(interfaces.InspectionAttr): state._strong_obj = None +class 
MappingState(InstanceState): + committed_state = {} + callables = {} + + def __init__(self, mapper, mapping): + self.class_ = mapper.class_ + self.manager = mapper.class_manager + self.modified = True + self._dict = mapping + + @property + def dict(self): + return self._dict + + class AttributeState(object): """Provide an inspection interface corresponding to a particular attribute on a particular mapped object. diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 8df24e95a..bc8a0f556 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -394,9 +394,9 @@ class UOWTransaction(object): if other: self.session._register_newly_persistent(other) - def bulk_save(self, objects): - for (base_mapper, in_session), states in itertools.groupby( - (attributes.instance_state(obj) for obj in objects), + def bulk_save(self, states): + for (base_mapper, in_session), states_ in itertools.groupby( + states, lambda state: ( state.mapper.base_mapper, @@ -404,12 +404,12 @@ class UOWTransaction(object): )): persistence.save_obj( - base_mapper, list(states), self, bookkeeping=in_session) + base_mapper, list(states_), self, bookkeeping=in_session) if in_session: self.states.update( (state, (False, False)) - for state in states + for state in states_ ) -- cgit v1.2.1 From 591f2e4ed2d455cb2c5b9ece43d79fde4b109510 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 14 Aug 2014 19:47:23 -0400 Subject: - change to be represented as two very fast bulk_insert() and bulk_update() methods --- lib/sqlalchemy/orm/events.py | 9 +- lib/sqlalchemy/orm/persistence.py | 255 ++++++++++++++++++++++++++------------ lib/sqlalchemy/orm/session.py | 57 ++++----- lib/sqlalchemy/orm/state.py | 15 --- lib/sqlalchemy/orm/unitofwork.py | 22 +--- 5 files changed, 213 insertions(+), 145 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index 097726c62..37ea3071b 100644 --- 
a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -1453,13 +1453,16 @@ class SessionEvents(event.Events): """ - def before_bulk_save(self, session, flush_context, objects): + def before_bulk_insert(self, session, flush_context, mapper, mappings): """""" - def after_bulk_save(self, session, flush_context, objects): + def after_bulk_insert(self, session, flush_context, mapper, mappings): """""" - def after_bulk_save_postexec(self, session, flush_context, objects): + def before_bulk_update(self, session, flush_context, mapper, mappings): + """""" + + def after_bulk_update(self, session, flush_context, mapper, mappings): """""" def after_begin(self, session, transaction, connection): diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 64c8440c4..a8d4bd695 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -23,9 +23,104 @@ from ..sql import expression from . import loading +def bulk_insert(mapper, mappings, uowtransaction): + base_mapper = mapper.base_mapper + + cached_connections = _cached_connection_dict(base_mapper) + + if uowtransaction.session.connection_callable: + raise NotImplementedError( + "connection_callable / per-instance sharding " + "not supported in bulk_insert()") + + connection = uowtransaction.transaction.connection(base_mapper) + + for table, sub_mapper in base_mapper._sorted_tables.items(): + if not mapper.isa(sub_mapper): + continue + + to_translate = dict( + (mapper._columntoproperty[col].key, col.key) + for col in mapper._cols_by_table[table] + ) + has_version_generator = mapper.version_id_generator is not False and \ + mapper.version_id_col is not None + multiparams = [] + for mapping in mappings: + params = dict( + (k, mapping.get(v)) for k, v in to_translate.items() + ) + if has_version_generator: + params[mapper.version_id_col.key] = \ + mapper.version_id_generator(None) + multiparams.append(params) + + statement = base_mapper._memo(('insert', 
table), table.insert) + cached_connections[connection].execute(statement, multiparams) + + +def bulk_update(mapper, mappings, uowtransaction): + base_mapper = mapper.base_mapper + + cached_connections = _cached_connection_dict(base_mapper) + + if uowtransaction.session.connection_callable: + raise NotImplementedError( + "connection_callable / per-instance sharding " + "not supported in bulk_update()") + + connection = uowtransaction.transaction.connection(base_mapper) + + for table, sub_mapper in base_mapper._sorted_tables.items(): + if not mapper.isa(sub_mapper): + continue + + needs_version_id = sub_mapper.version_id_col is not None and \ + table.c.contains_column(sub_mapper.version_id_col) + + def update_stmt(): + return _update_stmt_for_mapper(sub_mapper, table, needs_version_id) + + statement = base_mapper._memo(('update', table), update_stmt) + + pks = mapper._pks_by_table[table] + to_translate = dict( + (mapper._columntoproperty[col].key, col._label + if col in pks else col.key) + for col in mapper._cols_by_table[table] + ) + + for colnames, sub_mappings in groupby( + mappings, + lambda mapping: sorted(tuple(mapping.keys()))): + + multiparams = [] + for mapping in sub_mappings: + params = dict( + (to_translate[k], v) for k, v in mapping.items() + ) + multiparams.append(params) + + c = cached_connections[connection].execute(statement, multiparams) + + rows = c.rowcount + + if connection.dialect.supports_sane_rowcount: + if rows != len(multiparams): + raise orm_exc.StaleDataError( + "UPDATE statement on table '%s' expected to " + "update %d row(s); %d were matched." % + (table.description, len(multiparams), rows)) + + elif needs_version_id: + util.warn("Dialect %s does not support updated rowcount " + "- versioning cannot be verified." 
% + c.dialect.dialect_description, + stacklevel=12) + + def save_obj( - base_mapper, states, uowtransaction, single=False, - bookkeeping=True): + base_mapper, states, uowtransaction, single=False): """Issue ``INSERT`` and/or ``UPDATE`` statements for a list of objects. @@ -45,14 +140,13 @@ def save_obj( states_to_insert, states_to_update = _organize_states_for_save( base_mapper, states, - uowtransaction, bookkeeping) + uowtransaction) cached_connections = _cached_connection_dict(base_mapper) for table, mapper in base_mapper._sorted_tables.items(): insert = _collect_insert_commands(base_mapper, uowtransaction, - table, states_to_insert, - bookkeeping) + table, states_to_insert) update = _collect_update_commands(base_mapper, uowtransaction, table, states_to_update) @@ -65,12 +159,11 @@ def save_obj( if insert: _emit_insert_statements(base_mapper, uowtransaction, cached_connections, - mapper, table, insert, - bookkeeping) + mapper, table, insert) - _finalize_insert_update_commands(base_mapper, uowtransaction, - states_to_insert, states_to_update, - bookkeeping) + _finalize_insert_update_commands( + base_mapper, uowtransaction, + states_to_insert, states_to_update) def post_update(base_mapper, states, uowtransaction, post_update_cols): @@ -126,8 +219,7 @@ def delete_obj(base_mapper, states, uowtransaction): mapper.dispatch.after_delete(mapper, connection, state) -def _organize_states_for_save( - base_mapper, states, uowtransaction, bookkeeping): +def _organize_states_for_save(base_mapper, states, uowtransaction): """Make an initial pass across a set of states for INSERT or UPDATE. 
@@ -149,8 +241,7 @@ def _organize_states_for_save( has_identity = bool(state.key) - if bookkeeping: - instance_key = state.key or mapper._identity_key_from_state(state) + instance_key = state.key or mapper._identity_key_from_state(state) row_switch = None @@ -167,7 +258,7 @@ def _organize_states_for_save( # no instance_key attached to it), and another instance # with the same identity key already exists as persistent. # convert to an UPDATE if so. - if bookkeeping and not has_identity and \ + if not has_identity and \ instance_key in uowtransaction.session.identity_map: instance = \ uowtransaction.session.identity_map[instance_key] @@ -239,7 +330,7 @@ def _organize_states_for_delete(base_mapper, states, uowtransaction): def _collect_insert_commands(base_mapper, uowtransaction, table, - states_to_insert, bookkeeping): + states_to_insert): """Identify sets of values to use in INSERT statements for a list of states. @@ -270,7 +361,7 @@ def _collect_insert_commands(base_mapper, uowtransaction, table, prop = mapper._columntoproperty[col] value = state_dict.get(prop.key, None) - if bookkeeping and value is None: + if value is None: if col in pks: has_all_pks = False elif col.default is None and \ @@ -481,6 +572,28 @@ def _collect_delete_commands(base_mapper, uowtransaction, table, return delete +def _update_stmt_for_mapper(mapper, table, needs_version_id): + clause = sql.and_() + + for col in mapper._pks_by_table[table]: + clause.clauses.append(col == sql.bindparam(col._label, + type_=col.type)) + + if needs_version_id: + clause.clauses.append( + mapper.version_id_col == sql.bindparam( + mapper.version_id_col._label, + type_=mapper.version_id_col.type)) + + stmt = table.update(clause) + if mapper.base_mapper.eager_defaults: + stmt = stmt.return_defaults() + elif mapper.version_id_col is not None: + stmt = stmt.return_defaults(mapper.version_id_col) + + return stmt + + def _emit_update_statements(base_mapper, uowtransaction, cached_connections, mapper, table, update): 
"""Emit UPDATE statements corresponding to value lists collected @@ -490,25 +603,7 @@ def _emit_update_statements(base_mapper, uowtransaction, table.c.contains_column(mapper.version_id_col) def update_stmt(): - clause = sql.and_() - - for col in mapper._pks_by_table[table]: - clause.clauses.append(col == sql.bindparam(col._label, - type_=col.type)) - - if needs_version_id: - clause.clauses.append( - mapper.version_id_col == sql.bindparam( - mapper.version_id_col._label, - type_=mapper.version_id_col.type)) - - stmt = table.update(clause) - if mapper.base_mapper.eager_defaults: - stmt = stmt.return_defaults() - elif mapper.version_id_col is not None: - stmt = stmt.return_defaults(mapper.version_id_col) - - return stmt + return _update_stmt_for_mapper(mapper, table, needs_version_id) statement = base_mapper._memo(('update', table), update_stmt) @@ -572,8 +667,7 @@ def _emit_update_statements(base_mapper, uowtransaction, def _emit_insert_statements(base_mapper, uowtransaction, - cached_connections, mapper, table, insert, - bookkeeping): + cached_connections, mapper, table, insert): """Emit INSERT statements corresponding to value lists collected by _collect_insert_commands().""" @@ -599,20 +693,19 @@ def _emit_insert_statements(base_mapper, uowtransaction, c = cached_connections[connection].\ execute(statement, multiparams) - if bookkeeping: - for (state, state_dict, params, mapper_rec, - conn, value_params, has_all_pks, has_all_defaults), \ - last_inserted_params in \ - zip(records, c.context.compiled_parameters): - _postfetch( - mapper_rec, - uowtransaction, - table, - state, - state_dict, - c, - last_inserted_params, - value_params) + for (state, state_dict, params, mapper_rec, + conn, value_params, has_all_pks, has_all_defaults), \ + last_inserted_params in \ + zip(records, c.context.compiled_parameters): + _postfetch( + mapper_rec, + uowtransaction, + table, + state, + state_dict, + c, + last_inserted_params, + value_params) else: if not has_all_defaults and 
base_mapper.eager_defaults: @@ -768,8 +861,7 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, def _finalize_insert_update_commands(base_mapper, uowtransaction, - states_to_insert, states_to_update, - bookkeeping): + states_to_insert, states_to_update): """finalize state on states that have been inserted or updated, including calling after_insert/after_update events. @@ -778,34 +870,33 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, row_switch in states_to_insert + \ states_to_update: - if bookkeeping: - if mapper._readonly_props: - readonly = state.unmodified_intersection( - [p.key for p in mapper._readonly_props - if p.expire_on_flush or p.key not in state.dict] - ) - if readonly: - state._expire_attributes(state.dict, readonly) - - # if eager_defaults option is enabled, load - # all expired cols. Else if we have a version_id_col, make sure - # it isn't expired. - toload_now = [] - - if base_mapper.eager_defaults: - toload_now.extend(state._unloaded_non_object) - elif mapper.version_id_col is not None and \ - mapper.version_id_generator is False: - prop = mapper._columntoproperty[mapper.version_id_col] - if prop.key in state.unloaded: - toload_now.extend([prop.key]) - - if toload_now: - state.key = base_mapper._identity_key_from_state(state) - loading.load_on_ident( - uowtransaction.session.query(base_mapper), - state.key, refresh_state=state, - only_load_props=toload_now) + if mapper._readonly_props: + readonly = state.unmodified_intersection( + [p.key for p in mapper._readonly_props + if p.expire_on_flush or p.key not in state.dict] + ) + if readonly: + state._expire_attributes(state.dict, readonly) + + # if eager_defaults option is enabled, load + # all expired cols. Else if we have a version_id_col, make sure + # it isn't expired. 
+ toload_now = [] + + if base_mapper.eager_defaults: + toload_now.extend(state._unloaded_non_object) + elif mapper.version_id_col is not None and \ + mapper.version_id_generator is False: + prop = mapper._columntoproperty[mapper.version_id_col] + if prop.key in state.unloaded: + toload_now.extend([prop.key]) + + if toload_now: + state.key = base_mapper._identity_key_from_state(state) + loading.load_on_ident( + uowtransaction.session.query(base_mapper), + state.key, refresh_state=state, + only_load_props=toload_now) # call after_XXX extensions if not has_identity: diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 546355611..3199a4332 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -20,6 +20,7 @@ from .base import ( _class_to_mapper, _state_mapper, object_state, _none_set, state_str, instance_str ) +import itertools from .unitofwork import UOWTransaction from . import state as statelib import sys @@ -482,7 +483,8 @@ class Session(_SessionClassMethods): '__contains__', '__iter__', 'add', 'add_all', 'begin', 'begin_nested', 'close', 'commit', 'connection', 'delete', 'execute', 'expire', 'expire_all', 'expunge', 'expunge_all', 'flush', 'get_bind', - 'is_modified', 'bulk_save_objects', 'bulk_save_mappings', + 'is_modified', 'bulk_save_objects', 'bulk_insert_mappings', + 'bulk_update_mappings', 'merge', 'query', 'refresh', 'rollback', 'scalar') @@ -2034,42 +2036,41 @@ class Session(_SessionClassMethods): transaction.rollback(_capture_exception=True) def bulk_save_objects(self, objects): - self._bulk_save((attributes.instance_state(obj) for obj in objects)) + for (mapper, isupdate), states in itertools.groupby( + (attributes.instance_state(obj) for obj in objects), + lambda state: (state.mapper, state.key is not None) + ): + if isupdate: + self.bulk_update_mappings(mapper, (s.dict for s in states)) + else: + self.bulk_insert_mappings(mapper, (s.dict for s in states)) - def bulk_save_mappings(self, mapper, 
mappings): - mapper = class_mapper(mapper) + def bulk_insert_mappings(self, mapper, mappings): + self._bulk_save_mappings(mapper, mappings, False) - self._bulk_save(( - statelib.MappingState(mapper, mapping) - for mapping in mappings) - ) + def bulk_update_mappings(self, mapper, mappings): + self._bulk_save_mappings(mapper, mappings, True) - def _bulk_save(self, states): + def _bulk_save_mappings(self, mapper, mappings, isupdate): + mapper = _class_to_mapper(mapper) self._flushing = True flush_context = UOWTransaction(self) - if self.dispatch.before_bulk_save: - self.dispatch.before_bulk_save( - self, flush_context, states) - flush_context.transaction = transaction = self.begin( subtransactions=True) try: - self._warn_on_events = True - try: - flush_context.bulk_save(states) - finally: - self._warn_on_events = False - - self.dispatch.after_bulk_save( - self, flush_context, states - ) - - flush_context.finalize_flush_changes() - - self.dispatch.after_bulk_save_postexec( - self, flush_context, states) - + if isupdate: + self.dispatch.before_bulk_update( + self, flush_context, mapper, mappings) + flush_context.bulk_update(mapper, mappings) + self.dispatch.after_bulk_update( + self, flush_context, mapper, mappings) + else: + self.dispatch.before_bulk_insert( + self, flush_context, mapper, mappings) + flush_context.bulk_insert(mapper, mappings) + self.dispatch.after_bulk_insert( + self, flush_context, mapper, mappings) transaction.commit() except: diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index e941bc1a4..fe8ccd222 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -580,21 +580,6 @@ class InstanceState(interfaces.InspectionAttr): state._strong_obj = None -class MappingState(InstanceState): - committed_state = {} - callables = {} - - def __init__(self, mapper, mapping): - self.class_ = mapper.class_ - self.manager = mapper.class_manager - self.modified = True - self._dict = mapping - - @property - def dict(self): - 
return self._dict - - class AttributeState(object): """Provide an inspection interface corresponding to a particular attribute on a particular mapped object. diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index bc8a0f556..b3a1519c5 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -394,23 +394,11 @@ class UOWTransaction(object): if other: self.session._register_newly_persistent(other) - def bulk_save(self, states): - for (base_mapper, in_session), states_ in itertools.groupby( - states, - lambda state: - ( - state.mapper.base_mapper, - state.key is self.session.hash_key - )): - - persistence.save_obj( - base_mapper, list(states_), self, bookkeeping=in_session) - - if in_session: - self.states.update( - (state, (False, False)) - for state in states_ - ) + def bulk_insert(self, mapper, mappings): + persistence.bulk_insert(mapper, mappings, self) + + def bulk_update(self, mapper, mappings): + persistence.bulk_update(mapper, mappings, self) class IterateMappersMixin(object): -- cgit v1.2.1 From 8773307257550e86801217f2b77d47047718807a Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 15 Aug 2014 18:22:08 -0400 Subject: - refine this enough so that _collect_insert_commands() seems to be more than twice as fast now (.039 vs. 
.091); bulk_insert() and bulk_update() do their own collection but now both call into _emit_insert_statements() / _emit_update_statements(); the approach seems to have no impact on insert speed, still .85 for the insert test --- lib/sqlalchemy/orm/mapper.py | 35 ++++++ lib/sqlalchemy/orm/persistence.py | 259 +++++++++++++++++++------------------- 2 files changed, 161 insertions(+), 133 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 06ec2bf14..fc15769cd 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1892,6 +1892,41 @@ class Mapper(InspectionAttr): """ + @_memoized_configured_property + def _col_to_propkey(self): + return dict( + ( + table, + [ + (col, self._columntoproperty[col].key) + for col in columns + ] + ) + for table, columns in self._cols_by_table.items() + ) + + @_memoized_configured_property + def _pk_keys_by_table(self): + return dict( + ( + table, + frozenset([col.key for col in pks]) + ) + for table, pks in self._pks_by_table.items() + ) + + @_memoized_configured_property + def _server_default_cols(self): + return dict( + ( + table, + frozenset([ + col for col in columns + if col.server_default is not None]) + ) + for table, columns in self._cols_by_table.items() + ) + @property def selectable(self): """The :func:`.select` construct this :class:`.Mapper` selects from diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index a8d4bd695..782d94dc8 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -34,29 +34,35 @@ def bulk_insert(mapper, mappings, uowtransaction): "not supported in bulk_insert()") connection = uowtransaction.transaction.connection(base_mapper) - + value_params = {} for table, sub_mapper in base_mapper._sorted_tables.items(): if not mapper.isa(sub_mapper): continue - to_translate = dict( - (mapper._columntoproperty[col].key, col.key) - for col in 
mapper._cols_by_table[table] - ) has_version_generator = mapper.version_id_generator is not False and \ mapper.version_id_col is not None - multiparams = [] + + records = [] for mapping in mappings: params = dict( - (k, mapping.get(v)) for k, v in to_translate.items() + (col.key, mapping[propkey]) + for col, propkey in mapper._col_to_propkey[table] + if propkey in mapping ) + if has_version_generator: params[mapper.version_id_col.key] = \ mapper.version_id_generator(None) - multiparams.append(params) - statement = base_mapper._memo(('insert', table), table.insert) - cached_connections[connection].execute(statement, multiparams) + records.append( + (None, None, params, sub_mapper, + connection, value_params, True, True) + ) + + _emit_insert_statements(base_mapper, uowtransaction, + cached_connections, + mapper, table, records, + bookkeeping=False) def bulk_update(mapper, mappings, uowtransaction): @@ -71,52 +77,41 @@ def bulk_update(mapper, mappings, uowtransaction): connection = uowtransaction.transaction.connection(base_mapper) + value_params = {} for table, sub_mapper in base_mapper._sorted_tables.items(): if not mapper.isa(sub_mapper): continue - needs_version_id = sub_mapper.version_id_col is not None and \ - table.c.contains_column(sub_mapper.version_id_col) - - def update_stmt(): - return _update_stmt_for_mapper(sub_mapper, table, needs_version_id) - - statement = base_mapper._memo(('update', table), update_stmt) + label_pks = mapper._pks_by_table[table] + if mapper.version_id_col is not None: + label_pks = label_pks.union([mapper.version_id_col]) - pks = mapper._pks_by_table[table] to_translate = dict( - (mapper._columntoproperty[col].key, col._label - if col in pks else col.key) - for col in mapper._cols_by_table[table] + (propkey, col._label if col in label_pks else col.key) + for col, propkey in mapper._col_to_propkey[table] ) - for colnames, sub_mappings in groupby( - mappings, - lambda mapping: sorted(tuple(mapping.keys()))): - - multiparams = [] - for 
mapping in sub_mappings: - params = dict( - (to_translate[k], v) for k, v in mapping.items() - ) - multiparams.append(params) - - c = cached_connections[connection].execute(statement, multiparams) + records = [] + for mapping in mappings: + params = dict( + (to_translate[k], v) for k, v in mapping.items() + ) - rows = c.rowcount + if mapper.version_id_generator is not False and \ + mapper.version_id_col is not None and \ + mapper.version_id_col.key not in params: + params[mapper.version_id_col.key] = \ + mapper.version_id_generator( + params[mapper.version_id_col._label]) - if connection.dialect.supports_sane_rowcount: - if rows != len(multiparams): - raise orm_exc.StaleDataError( - "UPDATE statement on table '%s' expected to " - "update %d row(s); %d were matched." % - (table.description, len(multiparams), rows)) + records.append( + (None, None, params, sub_mapper, connection, value_params) + ) - elif needs_version_id: - util.warn("Dialect %s does not support updated rowcount " - "- versioning cannot be verified." 
% - c.dialect.dialect_description, - stacklevel=12) + _emit_update_statements(base_mapper, uowtransaction, + cached_connections, + mapper, table, records, + bookkeeping=False) def save_obj( @@ -342,39 +337,36 @@ def _collect_insert_commands(base_mapper, uowtransaction, table, if table not in mapper._pks_by_table: continue - pks = mapper._pks_by_table[table] - params = {} value_params = {} - - has_all_pks = True - has_all_defaults = True - has_version_id_generator = mapper.version_id_generator is not False \ - and mapper.version_id_col is not None - for col in mapper._cols_by_table[table]: - if has_version_id_generator and col is mapper.version_id_col: - val = mapper.version_id_generator(None) - params[col.key] = val + for col, propkey in mapper._col_to_propkey[table]: + if propkey in state_dict: + value = state_dict[propkey] + if isinstance(value, sql.ClauseElement): + value_params[col.key] = value + elif value is not None or ( + not col.primary_key and + not col.server_default and + not col.default): + params[col.key] = value else: - # pull straight from the dict for - # pending objects - prop = mapper._columntoproperty[col] - value = state_dict.get(prop.key, None) + if not col.server_default \ + and not col.default and not col.primary_key: + params[col.key] = None - if value is None: - if col in pks: - has_all_pks = False - elif col.default is None and \ - col.server_default is None: - params[col.key] = value - elif col.server_default is not None and \ - mapper.base_mapper.eager_defaults: - has_all_defaults = False + has_all_pks = mapper._pk_keys_by_table[table].issubset(params) - elif isinstance(value, sql.ClauseElement): - value_params[col] = value - else: - params[col.key] = value + if base_mapper.eager_defaults: + has_all_defaults = mapper._server_default_cols[table].\ + issubset(params) + else: + has_all_defaults = True + + if mapper.version_id_generator is not False \ + and mapper.version_id_col is not None and \ + mapper.version_id_col in 
mapper._cols_by_table[table]: + params[mapper.version_id_col.key] = \ + mapper.version_id_generator(None) insert.append((state, state_dict, params, mapper, connection, value_params, has_all_pks, @@ -572,30 +564,9 @@ def _collect_delete_commands(base_mapper, uowtransaction, table, return delete -def _update_stmt_for_mapper(mapper, table, needs_version_id): - clause = sql.and_() - - for col in mapper._pks_by_table[table]: - clause.clauses.append(col == sql.bindparam(col._label, - type_=col.type)) - - if needs_version_id: - clause.clauses.append( - mapper.version_id_col == sql.bindparam( - mapper.version_id_col._label, - type_=mapper.version_id_col.type)) - - stmt = table.update(clause) - if mapper.base_mapper.eager_defaults: - stmt = stmt.return_defaults() - elif mapper.version_id_col is not None: - stmt = stmt.return_defaults(mapper.version_id_col) - - return stmt - - def _emit_update_statements(base_mapper, uowtransaction, - cached_connections, mapper, table, update): + cached_connections, mapper, table, update, + bookkeeping=True): """Emit UPDATE statements corresponding to value lists collected by _collect_update_commands().""" @@ -603,7 +574,25 @@ def _emit_update_statements(base_mapper, uowtransaction, table.c.contains_column(mapper.version_id_col) def update_stmt(): - return _update_stmt_for_mapper(mapper, table, needs_version_id) + clause = sql.and_() + + for col in mapper._pks_by_table[table]: + clause.clauses.append(col == sql.bindparam(col._label, + type_=col.type)) + + if needs_version_id: + clause.clauses.append( + mapper.version_id_col == sql.bindparam( + mapper.version_id_col._label, + type_=mapper.version_id_col.type)) + + stmt = table.update(clause) + if mapper.base_mapper.eager_defaults: + stmt = stmt.return_defaults() + elif mapper.version_id_col is not None: + stmt = stmt.return_defaults(mapper.version_id_col) + + return stmt statement = base_mapper._memo(('update', table), update_stmt) @@ -624,15 +613,16 @@ def 
_emit_update_statements(base_mapper, uowtransaction, c = connection.execute( statement.values(value_params), params) - _postfetch( - mapper, - uowtransaction, - table, - state, - state_dict, - c, - c.context.compiled_parameters[0], - value_params) + if bookkeeping: + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + c, + c.context.compiled_parameters[0], + value_params) rows += c.rowcount else: multiparams = [rec[2] for rec in records] @@ -640,17 +630,18 @@ def _emit_update_statements(base_mapper, uowtransaction, execute(statement, multiparams) rows += c.rowcount - for state, state_dict, params, mapper, \ - connection, value_params in records: - _postfetch( - mapper, - uowtransaction, - table, - state, - state_dict, - c, - c.context.compiled_parameters[0], - value_params) + if bookkeeping: + for state, state_dict, params, mapper, \ + connection, value_params in records: + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + c, + c.context.compiled_parameters[0], + value_params) if connection.dialect.supports_sane_rowcount: if rows != len(records): @@ -667,7 +658,8 @@ def _emit_update_statements(base_mapper, uowtransaction, def _emit_insert_statements(base_mapper, uowtransaction, - cached_connections, mapper, table, insert): + cached_connections, mapper, table, insert, + bookkeeping=True): """Emit INSERT statements corresponding to value lists collected by _collect_insert_commands().""" @@ -676,11 +668,11 @@ def _emit_insert_statements(base_mapper, uowtransaction, for (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), \ records in groupby(insert, lambda rec: (rec[4], - list(rec[2].keys()), + tuple(sorted(rec[2].keys())), bool(rec[5]), rec[6], rec[7]) ): - if \ + if not bookkeeping or \ ( has_all_defaults or not base_mapper.eager_defaults @@ -693,19 +685,20 @@ def _emit_insert_statements(base_mapper, uowtransaction, c = cached_connections[connection].\ execute(statement, multiparams) - for (state, state_dict, 
params, mapper_rec, - conn, value_params, has_all_pks, has_all_defaults), \ - last_inserted_params in \ - zip(records, c.context.compiled_parameters): - _postfetch( - mapper_rec, - uowtransaction, - table, - state, - state_dict, - c, - last_inserted_params, - value_params) + if bookkeeping: + for (state, state_dict, params, mapper_rec, + conn, value_params, has_all_pks, has_all_defaults), \ + last_inserted_params in \ + zip(records, c.context.compiled_parameters): + _postfetch( + mapper_rec, + uowtransaction, + table, + state, + state_dict, + c, + last_inserted_params, + value_params) else: if not has_all_defaults and base_mapper.eager_defaults: -- cgit v1.2.1 From 84cca0e28660b5d35c35195aa57c89b094fa897d Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 18 Aug 2014 18:30:14 -0400 Subject: dev --- lib/sqlalchemy/orm/persistence.py | 47 +++++++++++++++++---------------------- 1 file changed, 21 insertions(+), 26 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 8d3e90cf4..f9e7eda28 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -34,26 +34,18 @@ def bulk_insert(mapper, mappings, uowtransaction): "not supported in bulk_insert()") connection = uowtransaction.transaction.connection(base_mapper) - value_params = {} for table, sub_mapper in base_mapper._sorted_tables.items(): if not mapper.isa(sub_mapper): continue - has_version_generator = mapper.version_id_generator is not False and \ - mapper.version_id_col is not None - records = [] - for mapping in mappings: - params = dict( - (col.key, mapping[propkey]) - for col, propkey in mapper._col_to_propkey[table] - if propkey in mapping - ) - - if has_version_generator: - params[mapper.version_id_col.key] = \ - mapper.version_id_generator(None) - + for ( + state, state_dict, params, mapper, + connection, value_params, has_all_pks, + has_all_defaults) in _collect_insert_commands(table, ( + (None, 
mapping, sub_mapper, connection) + for mapping in mappings) + ): records.append( (None, None, params, sub_mapper, connection, value_params, True, True) @@ -82,13 +74,13 @@ def bulk_update(mapper, mappings, uowtransaction): if not mapper.isa(sub_mapper): continue - label_pks = mapper._pks_by_table[table] + label_pks = sub_mapper._pks_by_table[table] if mapper.version_id_col is not None: label_pks = label_pks.union([mapper.version_id_col]) to_translate = dict( (propkey, col._label if col in label_pks else col.key) - for col, propkey in mapper._col_to_propkey[table] + for propkey, col in sub_mapper._propkey_to_col[table].items() ) records = [] @@ -350,7 +342,7 @@ def _organize_states_for_delete(base_mapper, states, uowtransaction): yield state, dict_, mapper, bool(state.key), connection -def _collect_insert_commands(table, states_to_insert): +def _collect_insert_commands(table, states_to_insert, bulk=False): """Identify sets of values to use in INSERT statements for a list of states. @@ -374,17 +366,20 @@ def _collect_insert_commands(table, states_to_insert): else: params[col.key] = value - for colkey in mapper._insert_cols_as_none[table].\ - difference(params).difference(value_params): - params[colkey] = None + if not bulk: + for colkey in mapper._insert_cols_as_none[table].\ + difference(params).difference(value_params): + params[colkey] = None - has_all_pks = mapper._pk_keys_by_table[table].issubset(params) + has_all_pks = mapper._pk_keys_by_table[table].issubset(params) - if mapper.base_mapper.eager_defaults: - has_all_defaults = mapper._server_default_cols[table].\ - issubset(params) + if mapper.base_mapper.eager_defaults: + has_all_defaults = mapper._server_default_cols[table].\ + issubset(params) + else: + has_all_defaults = True else: - has_all_defaults = True + has_all_defaults = has_all_pks = True if mapper.version_id_generator is not False \ and mapper.version_id_col is not None and \ -- cgit v1.2.1 From a251001f24e819f1ebc525948437563f52a3a226 Mon Sep 17 
00:00:00 2001 From: Mike Bayer Date: Mon, 18 Aug 2014 18:52:53 -0400 Subject: dev --- lib/sqlalchemy/orm/persistence.py | 37 ++++++++++++++++++------------------- 1 file changed, 18 insertions(+), 19 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index f9e7eda28..145a7783a 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -34,26 +34,25 @@ def bulk_insert(mapper, mappings, uowtransaction): "not supported in bulk_insert()") connection = uowtransaction.transaction.connection(base_mapper) - for table, sub_mapper in base_mapper._sorted_tables.items(): - if not mapper.isa(sub_mapper): + for table, super_mapper in base_mapper._sorted_tables.items(): + if not mapper.isa(super_mapper): continue - records = [] - for ( - state, state_dict, params, mapper, - connection, value_params, has_all_pks, - has_all_defaults) in _collect_insert_commands(table, ( - (None, mapping, sub_mapper, connection) + records = ( + (None, None, params, super_mapper, + connection, value_params, True, True) + for + state, state_dict, params, mp, + conn, value_params, has_all_pks, + has_all_defaults in _collect_insert_commands(table, ( + (None, mapping, super_mapper, connection) for mapping in mappings) - ): - records.append( - (None, None, params, sub_mapper, - connection, value_params, True, True) ) + ) _emit_insert_statements(base_mapper, uowtransaction, cached_connections, - mapper, table, records, + super_mapper, table, records, bookkeeping=False) @@ -70,17 +69,17 @@ def bulk_update(mapper, mappings, uowtransaction): connection = uowtransaction.transaction.connection(base_mapper) value_params = {} - for table, sub_mapper in base_mapper._sorted_tables.items(): - if not mapper.isa(sub_mapper): + for table, super_mapper in base_mapper._sorted_tables.items(): + if not mapper.isa(super_mapper): continue - label_pks = sub_mapper._pks_by_table[table] + label_pks = 
super_mapper._pks_by_table[table] if mapper.version_id_col is not None: label_pks = label_pks.union([mapper.version_id_col]) to_translate = dict( (propkey, col._label if col in label_pks else col.key) - for propkey, col in sub_mapper._propkey_to_col[table].items() + for propkey, col in super_mapper._propkey_to_col[table].items() ) records = [] @@ -97,12 +96,12 @@ def bulk_update(mapper, mappings, uowtransaction): params[mapper.version_id_col._label]) records.append( - (None, None, params, sub_mapper, connection, value_params) + (None, None, params, super_mapper, connection, value_params) ) _emit_update_statements(base_mapper, uowtransaction, cached_connections, - mapper, table, records, + super_mapper, table, records, bookkeeping=False) -- cgit v1.2.1 From 91959122e0a12943e5ff9399024c65ad4d7489e1 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 19 Aug 2014 14:24:56 -0400 Subject: - refinements --- lib/sqlalchemy/orm/events.py | 12 ----- lib/sqlalchemy/orm/mapper.py | 4 ++ lib/sqlalchemy/orm/persistence.py | 107 +++++++++++++++++++++++++------------- lib/sqlalchemy/orm/session.py | 29 ++++------- lib/sqlalchemy/orm/unitofwork.py | 6 --- 5 files changed, 86 insertions(+), 72 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index 37ea3071b..aa99673ba 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -1453,18 +1453,6 @@ class SessionEvents(event.Events): """ - def before_bulk_insert(self, session, flush_context, mapper, mappings): - """""" - - def after_bulk_insert(self, session, flush_context, mapper, mappings): - """""" - - def before_bulk_update(self, session, flush_context, mapper, mappings): - """""" - - def after_bulk_update(self, session, flush_context, mapper, mappings): - """""" - def after_begin(self, session, transaction, connection): """Execute after a transaction is begun on a connection diff --git a/lib/sqlalchemy/orm/mapper.py 
b/lib/sqlalchemy/orm/mapper.py index 89c092b58..b98fbda42 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -2366,6 +2366,10 @@ class Mapper(InspectionAttr): def _primary_key_props(self): return [self._columntoproperty[col] for col in self.primary_key] + @_memoized_configured_property + def _primary_key_propkeys(self): + return set([prop.key for prop in self._primary_key_props]) + def _get_state_attr_by_column( self, state, dict_, column, passive=attributes.PASSIVE_RETURN_NEVER_SET): diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 145a7783a..9c0008925 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -23,17 +23,22 @@ from ..sql import expression from . import loading -def bulk_insert(mapper, mappings, uowtransaction): +def _bulk_insert(mapper, mappings, session_transaction, isstates): base_mapper = mapper.base_mapper cached_connections = _cached_connection_dict(base_mapper) - if uowtransaction.session.connection_callable: + if session_transaction.session.connection_callable: raise NotImplementedError( "connection_callable / per-instance sharding " "not supported in bulk_insert()") - connection = uowtransaction.transaction.connection(base_mapper) + if isstates: + mappings = [state.dict for state in mappings] + else: + mappings = list(mappings) + + connection = session_transaction.connection(base_mapper) for table, super_mapper in base_mapper._sorted_tables.items(): if not mapper.isa(super_mapper): continue @@ -45,61 +50,55 @@ def bulk_insert(mapper, mappings, uowtransaction): state, state_dict, params, mp, conn, value_params, has_all_pks, has_all_defaults in _collect_insert_commands(table, ( - (None, mapping, super_mapper, connection) - for mapping in mappings) + (None, mapping, mapper, connection) + for mapping in mappings), + bulk=True ) ) - _emit_insert_statements(base_mapper, uowtransaction, + _emit_insert_statements(base_mapper, None, 
cached_connections, super_mapper, table, records, bookkeeping=False) -def bulk_update(mapper, mappings, uowtransaction): +def _bulk_update(mapper, mappings, session_transaction, isstates): base_mapper = mapper.base_mapper cached_connections = _cached_connection_dict(base_mapper) - if uowtransaction.session.connection_callable: + def _changed_dict(mapper, state): + return dict( + (k, v) + for k, v in state.dict.items() if k in state.committed_state or k + in mapper._primary_key_propkeys + ) + + if isstates: + mappings = [_changed_dict(mapper, state) for state in mappings] + else: + mappings = list(mappings) + + if session_transaction.session.connection_callable: raise NotImplementedError( "connection_callable / per-instance sharding " "not supported in bulk_update()") - connection = uowtransaction.transaction.connection(base_mapper) + connection = session_transaction.connection(base_mapper) value_params = {} + for table, super_mapper in base_mapper._sorted_tables.items(): if not mapper.isa(super_mapper): continue - label_pks = super_mapper._pks_by_table[table] - if mapper.version_id_col is not None: - label_pks = label_pks.union([mapper.version_id_col]) - - to_translate = dict( - (propkey, col._label if col in label_pks else col.key) - for propkey, col in super_mapper._propkey_to_col[table].items() + records = ( + (None, None, params, super_mapper, connection, value_params) + for + params in _collect_bulk_update_commands(mapper, table, mappings) ) - records = [] - for mapping in mappings: - params = dict( - (to_translate[k], v) for k, v in mapping.items() - ) - - if mapper.version_id_generator is not False and \ - mapper.version_id_col is not None and \ - mapper.version_id_col.key not in params: - params[mapper.version_id_col.key] = \ - mapper.version_id_generator( - params[mapper.version_id_col._label]) - - records.append( - (None, None, params, super_mapper, connection, value_params) - ) - - _emit_update_statements(base_mapper, uowtransaction, + 
_emit_update_statements(base_mapper, None, cached_connections, super_mapper, table, records, bookkeeping=False) @@ -360,7 +359,7 @@ def _collect_insert_commands(table, states_to_insert, bulk=False): col = propkey_to_col[propkey] if value is None: continue - elif isinstance(value, sql.ClauseElement): + elif not bulk and isinstance(value, sql.ClauseElement): value_params[col.key] = value else: params[col.key] = value @@ -481,6 +480,44 @@ def _collect_update_commands(uowtransaction, table, states_to_update): state, state_dict, params, mapper, connection, value_params) +def _collect_bulk_update_commands(mapper, table, mappings): + label_pks = mapper._pks_by_table[table] + if mapper.version_id_col is not None: + label_pks = label_pks.union([mapper.version_id_col]) + + to_translate = dict( + (propkey, col.key if col not in label_pks else col._label) + for propkey, col in mapper._propkey_to_col[table].items() + ) + + for mapping in mappings: + params = dict( + (to_translate[k], mapping[k]) for k in to_translate + if k in mapping and k not in mapper._primary_key_propkeys + ) + + if not params: + continue + + try: + params.update( + (to_translate[k], mapping[k]) for k in + mapper._primary_key_propkeys.intersection(to_translate) + ) + except KeyError as ke: + raise orm_exc.FlushError( + "Can't update table using NULL for primary " + "key attribute: %s" % ke) + + if mapper.version_id_generator is not False and \ + mapper.version_id_col is not None and \ + mapper.version_id_col.key not in params: + params[mapper.version_id_col.key] = \ + mapper.version_id_generator( + params[mapper.version_id_col._label]) + + yield params + def _collect_post_update_commands(base_mapper, uowtransaction, table, states_to_update, post_update_cols): diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 3199a4332..968868e84 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -21,6 +21,7 @@ from .base import ( _none_set, state_str, 
instance_str ) import itertools +from . import persistence from .unitofwork import UOWTransaction from . import state as statelib import sys @@ -2040,37 +2041,27 @@ class Session(_SessionClassMethods): (attributes.instance_state(obj) for obj in objects), lambda state: (state.mapper, state.key is not None) ): - if isupdate: - self.bulk_update_mappings(mapper, (s.dict for s in states)) - else: - self.bulk_insert_mappings(mapper, (s.dict for s in states)) + self._bulk_save_mappings(mapper, states, isupdate, True) def bulk_insert_mappings(self, mapper, mappings): - self._bulk_save_mappings(mapper, mappings, False) + self._bulk_save_mappings(mapper, mappings, False, False) def bulk_update_mappings(self, mapper, mappings): - self._bulk_save_mappings(mapper, mappings, True) + self._bulk_save_mappings(mapper, mappings, True, False) - def _bulk_save_mappings(self, mapper, mappings, isupdate): + def _bulk_save_mappings(self, mapper, mappings, isupdate, isstates): mapper = _class_to_mapper(mapper) self._flushing = True - flush_context = UOWTransaction(self) - flush_context.transaction = transaction = self.begin( + transaction = self.begin( subtransactions=True) try: if isupdate: - self.dispatch.before_bulk_update( - self, flush_context, mapper, mappings) - flush_context.bulk_update(mapper, mappings) - self.dispatch.after_bulk_update( - self, flush_context, mapper, mappings) + persistence._bulk_update( + mapper, mappings, transaction, isstates) else: - self.dispatch.before_bulk_insert( - self, flush_context, mapper, mappings) - flush_context.bulk_insert(mapper, mappings) - self.dispatch.after_bulk_insert( - self, flush_context, mapper, mappings) + persistence._bulk_insert( + mapper, mappings, transaction, isstates) transaction.commit() except: diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index b3a1519c5..05265b13f 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -394,12 +394,6 @@ class 
UOWTransaction(object): if other: self.session._register_newly_persistent(other) - def bulk_insert(self, mapper, mappings): - persistence.bulk_insert(mapper, mappings, self) - - def bulk_update(self, mapper, mappings): - persistence.bulk_update(mapper, mappings, self) - class IterateMappersMixin(object): def _mappers(self, uow): -- cgit v1.2.1 From fcea5c86d3a9097caa04e2e35fa6404a3ef32044 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 19 Aug 2014 18:26:11 -0400 Subject: - rename mapper._primary_key_props to mapper._identity_key_props - ensure bulk update is using all PK cols for all tables --- lib/sqlalchemy/orm/mapper.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 63d23e31d..31c17e69e 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1244,7 +1244,7 @@ class Mapper(InspectionAttr): self._readonly_props = set( self._columntoproperty[col] for col in self._columntoproperty - if self._columntoproperty[col] not in self._primary_key_props and + if self._columntoproperty[col] not in self._identity_key_props and (not hasattr(col, 'table') or col.table not in self._cols_by_table)) @@ -2359,19 +2359,23 @@ class Mapper(InspectionAttr): manager[prop.key]. 
impl.get(state, dict_, attributes.PASSIVE_RETURN_NEVER_SET) - for prop in self._primary_key_props + for prop in self._identity_key_props ] @_memoized_configured_property - def _primary_key_props(self): - # TODO: this should really be called "identity key props", - # as it does not necessarily include primary key columns within - # individual tables + def _identity_key_props(self): return [self._columntoproperty[col] for col in self.primary_key] + @_memoized_configured_property + def _all_pk_props(self): + collection = set() + for table in self.tables: + collection.update(self._pks_by_table[table]) + return collection + @_memoized_configured_property def _primary_key_propkeys(self): - return set([prop.key for prop in self._primary_key_props]) + return set([prop.key for prop in self._all_pk_props]) def _get_state_attr_by_column( self, state, dict_, column, -- cgit v1.2.1 From db70b6e79e263c137f4d282c9c600417636afa25 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 20 Aug 2014 17:15:20 -0400 Subject: - that's it, feature is finished, needs tests --- lib/sqlalchemy/orm/persistence.py | 195 +++++++++++++++++--------------------- 1 file changed, 89 insertions(+), 106 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index c2750eeb3..aa10da9f4 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -15,7 +15,7 @@ in unitofwork.py. """ import operator -from itertools import groupby +from itertools import groupby, chain from .. import sql, util, exc as sa_exc, schema from . 
import attributes, sync, exc as orm_exc, evaluator from .base import state_str, _attr_as_key @@ -86,17 +86,16 @@ def _bulk_update(mapper, mappings, session_transaction, isstates): connection = session_transaction.connection(base_mapper) - value_params = {} - for table, super_mapper in base_mapper._sorted_tables.items(): if not mapper.isa(super_mapper): continue - records = ( - (None, None, params, super_mapper, connection, value_params) - for - params in _collect_bulk_update_commands(mapper, table, mappings) - ) + records = _collect_update_commands(None, table, ( + (None, mapping, mapper, connection, + (mapping[mapper._version_id_prop.key] + if mapper._version_id_prop else None)) + for mapping in mappings + ), bulk=True) _emit_update_statements(base_mapper, None, cached_connections, @@ -158,17 +157,16 @@ def save_obj( _finalize_insert_update_commands( base_mapper, uowtransaction, - ( - (state, state_dict, mapper, connection, False) - for state, state_dict, mapper, connection in states_to_insert - ) - ) - _finalize_insert_update_commands( - base_mapper, uowtransaction, - ( - (state, state_dict, mapper, connection, True) - for state, state_dict, mapper, connection, - update_version_id in states_to_update + chain( + ( + (state, state_dict, mapper, connection, False) + for state, state_dict, mapper, connection in states_to_insert + ), + ( + (state, state_dict, mapper, connection, True) + for state, state_dict, mapper, connection, + update_version_id in states_to_update + ) ) ) @@ -394,7 +392,9 @@ def _collect_insert_commands(table, states_to_insert, bulk=False): has_all_defaults) -def _collect_update_commands(uowtransaction, table, states_to_update): +def _collect_update_commands( + uowtransaction, table, states_to_update, + bulk=False): """Identify sets of values to use in UPDATE statements for a list of states. 
@@ -414,23 +414,32 @@ def _collect_update_commands(uowtransaction, table, states_to_update): pks = mapper._pks_by_table[table] - params = {} value_params = {} propkey_to_col = mapper._propkey_to_col[table] - for propkey in set(propkey_to_col).intersection(state.committed_state): - value = state_dict[propkey] - col = propkey_to_col[propkey] - - if not state.manager[propkey].impl.is_equal( - value, state.committed_state[propkey]): - if isinstance(value, sql.ClauseElement): - value_params[col] = value - else: - params[col.key] = value + if bulk: + params = dict( + (propkey_to_col[propkey].key, state_dict[propkey]) + for propkey in + set(propkey_to_col).intersection(state_dict) + ) + else: + params = {} + for propkey in set(propkey_to_col).intersection( + state.committed_state): + value = state_dict[propkey] + col = propkey_to_col[propkey] + + if not state.manager[propkey].impl.is_equal( + value, state.committed_state[propkey]): + if isinstance(value, sql.ClauseElement): + value_params[col] = value + else: + params[col.key] = value - if update_version_id is not None: + if update_version_id is not None and \ + mapper.version_id_col in mapper._cols_by_table[table]: col = mapper.version_id_col params[col._label] = update_version_id @@ -442,24 +451,33 @@ def _collect_update_commands(uowtransaction, table, states_to_update): if not (params or value_params): continue - pk_params = {} - for col in pks: - propkey = mapper._columntoproperty[col].key - history = state.manager[propkey].impl.get_history( - state, state_dict, attributes.PASSIVE_OFF) - - if history.added: - if not history.deleted or \ - ("pk_cascaded", state, col) in \ - uowtransaction.attributes: - pk_params[col._label] = history.added[0] - params.pop(col.key, None) + if bulk: + pk_params = dict( + (propkey_to_col[propkey]._label, state_dict.get(propkey)) + for propkey in + set(propkey_to_col). 
+ intersection(mapper._pk_keys_by_table[table]) + ) + else: + pk_params = {} + for col in pks: + propkey = mapper._columntoproperty[col].key + + history = state.manager[propkey].impl.get_history( + state, state_dict, attributes.PASSIVE_OFF) + + if history.added: + if not history.deleted or \ + ("pk_cascaded", state, col) in \ + uowtransaction.attributes: + pk_params[col._label] = history.added[0] + params.pop(col.key, None) + else: + # else, use the old value to locate the row + pk_params[col._label] = history.deleted[0] + params[col.key] = history.added[0] else: - # else, use the old value to locate the row - pk_params[col._label] = history.deleted[0] - params[col.key] = history.added[0] - else: - pk_params[col._label] = history.unchanged[0] + pk_params[col._label] = history.unchanged[0] if params or value_params: if None in pk_params.values(): @@ -471,44 +489,6 @@ def _collect_update_commands(uowtransaction, table, states_to_update): state, state_dict, params, mapper, connection, value_params) -def _collect_bulk_update_commands(mapper, table, mappings): - label_pks = mapper._pks_by_table[table] - if mapper.version_id_col is not None: - label_pks = label_pks.union([mapper.version_id_col]) - - to_translate = dict( - (propkey, col.key if col not in label_pks else col._label) - for propkey, col in mapper._propkey_to_col[table].items() - ) - - for mapping in mappings: - params = dict( - (to_translate[k], mapping[k]) for k in to_translate - if k in mapping and k not in mapper._primary_key_propkeys - ) - - if not params: - continue - - try: - params.update( - (to_translate[k], mapping[k]) for k in - mapper._primary_key_propkeys.intersection(to_translate) - ) - except KeyError as ke: - raise orm_exc.FlushError( - "Can't update table using NULL for primary " - "key attribute: %s" % ke) - - if mapper.version_id_generator is not False and \ - mapper.version_id_col is not None and \ - mapper.version_id_col.key not in params: - params[mapper.version_id_col.key] = \ - 
mapper.version_id_generator( - params[mapper.version_id_col._label]) - - yield params - def _collect_post_update_commands(base_mapper, uowtransaction, table, states_to_update, post_update_cols): @@ -569,7 +549,7 @@ def _collect_delete_commands(base_mapper, uowtransaction, table, "key value") if update_version_id is not None and \ - table.c.contains_column(mapper.version_id_col): + mapper.version_id_col in mapper._cols_by_table[table]: params[mapper.version_id_col.key] = update_version_id yield params, connection @@ -581,7 +561,7 @@ def _emit_update_statements(base_mapper, uowtransaction, by _collect_update_commands().""" needs_version_id = mapper.version_id_col is not None and \ - table.c.contains_column(mapper.version_id_col) + mapper.version_id_col in mapper._cols_by_table[table] def update_stmt(): clause = sql.and_() @@ -610,9 +590,9 @@ def _emit_update_statements(base_mapper, uowtransaction, records in groupby( update, lambda rec: ( - rec[4], - tuple(sorted(rec[2])), - bool(rec[5]))): + rec[4], # connection + set(rec[2]), # set of parameter keys + bool(rec[5]))): # whether or not we have "value" parameters rows = 0 records = list(records) @@ -692,12 +672,14 @@ def _emit_insert_statements(base_mapper, uowtransaction, statement = base_mapper._memo(('insert', table), table.insert) for (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), \ - records in groupby(insert, - lambda rec: (rec[4], - tuple(sorted(rec[2].keys())), - bool(rec[5]), - rec[6], rec[7]) - ): + records in groupby( + insert, + lambda rec: ( + rec[4], # connection + set(rec[2]), # parameter keys + bool(rec[5]), # whether we have "value" parameters + rec[6], + rec[7])): if not bookkeeping or \ ( has_all_defaults @@ -785,7 +767,10 @@ def _emit_post_update_statements(base_mapper, uowtransaction, # also group them into common (connection, cols) sets # to support executemany(). 
for key, grouper in groupby( - update, lambda rec: (rec[1], sorted(rec[0])) + update, lambda rec: ( + rec[1], # connection + set(rec[0]) # parameter keys + ) ): connection = key[0] multiparams = [params for params, conn in grouper] @@ -799,7 +784,7 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, by _collect_delete_commands().""" need_version_id = mapper.version_id_col is not None and \ - table.c.contains_column(mapper.version_id_col) + mapper.version_id_col in mapper._cols_by_table[table] def delete_stmt(): clause = sql.and_() @@ -821,12 +806,9 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, statement = base_mapper._memo(('delete', table), delete_stmt) for connection, recs in groupby( delete, - lambda rec: rec[1] + lambda rec: rec[1] # connection ): - del_objects = [ - params - for params, connection in recs - ] + del_objects = [params for params, connection in recs] connection = cached_connections[connection] @@ -931,7 +913,8 @@ def _postfetch(mapper, uowtransaction, table, postfetch_cols = result.context.postfetch_cols returning_cols = result.context.returning_cols - if mapper.version_id_col is not None: + if mapper.version_id_col is not None and \ + mapper.version_id_col in mapper._cols_by_table[table]: prefetch_cols = list(prefetch_cols) + [mapper.version_id_col] if returning_cols: -- cgit v1.2.1 From ccfd26d96916cc7953f1fefa8abed53d4f696c4c Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 2 Sep 2014 19:23:09 -0400 Subject: - add options to get back pk defaults for inserts. times spent start getting barely different... 
--- lib/sqlalchemy/orm/persistence.py | 37 ++++++++++++++++++++++++++----------- lib/sqlalchemy/orm/session.py | 16 +++++++++------- 2 files changed, 35 insertions(+), 18 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 198eeb46f..2a697a6f9 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -23,7 +23,8 @@ from ..sql import expression from . import loading -def _bulk_insert(mapper, mappings, session_transaction, isstates): +def _bulk_insert( + mapper, mappings, session_transaction, isstates, return_defaults): base_mapper = mapper.base_mapper cached_connections = _cached_connection_dict(base_mapper) @@ -34,7 +35,11 @@ def _bulk_insert(mapper, mappings, session_transaction, isstates): "not supported in bulk_insert()") if isstates: - mappings = [state.dict for state in mappings] + if return_defaults: + states = [(state, state.dict) for state in mappings] + mappings = [dict_ for (state, dict_) in states] + else: + mappings = [state.dict for state in mappings] else: mappings = list(mappings) @@ -44,22 +49,30 @@ def _bulk_insert(mapper, mappings, session_transaction, isstates): continue records = ( - (None, None, params, super_mapper, - connection, value_params, True, True) + (None, state_dict, params, super_mapper, + connection, value_params, has_all_pks, has_all_defaults) for state, state_dict, params, mp, conn, value_params, has_all_pks, has_all_defaults in _collect_insert_commands(table, ( (None, mapping, mapper, connection) for mapping in mappings), - bulk=True + bulk=True, return_defaults=return_defaults ) ) - _emit_insert_statements(base_mapper, None, cached_connections, super_mapper, table, records, - bookkeeping=False) + bookkeeping=return_defaults) + + if return_defaults and isstates: + identity_cls = mapper._identity_class + identity_props = [p.key for p in mapper._identity_key_props] + for state, dict_ in states: + state.key = ( + 
identity_cls, + tuple([dict_[key] for key in identity_props]) + ) def _bulk_update(mapper, mappings, session_transaction, isstates): @@ -341,7 +354,9 @@ def _organize_states_for_delete(base_mapper, states, uowtransaction): state, dict_, mapper, connection, update_version_id) -def _collect_insert_commands(table, states_to_insert, bulk=False): +def _collect_insert_commands( + table, states_to_insert, + bulk=False, return_defaults=False): """Identify sets of values to use in INSERT statements for a list of states. @@ -370,6 +385,7 @@ def _collect_insert_commands(table, states_to_insert, bulk=False): difference(params).difference(value_params): params[colkey] = None + if not bulk or return_defaults: has_all_pks = mapper._pk_keys_by_table[table].issubset(params) if mapper.base_mapper.eager_defaults: @@ -884,9 +900,8 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states): toload_now.extend(state._unloaded_non_object) elif mapper.version_id_col is not None and \ mapper.version_id_generator is False: - prop = mapper._columntoproperty[mapper.version_id_col] - if prop.key in state.unloaded: - toload_now.extend([prop.key]) + if mapper._version_id_prop.key in state.unloaded: + toload_now.extend([mapper._version_id_prop.key]) if toload_now: state.key = base_mapper._identity_key_from_state(state) diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index e075b9c71..1611688b0 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -2036,20 +2036,22 @@ class Session(_SessionClassMethods): with util.safe_reraise(): transaction.rollback(_capture_exception=True) - def bulk_save_objects(self, objects): + def bulk_save_objects(self, objects, return_defaults=False): for (mapper, isupdate), states in itertools.groupby( (attributes.instance_state(obj) for obj in objects), lambda state: (state.mapper, state.key is not None) ): - self._bulk_save_mappings(mapper, states, isupdate, True) + self._bulk_save_mappings( + 
mapper, states, isupdate, True, return_defaults) - def bulk_insert_mappings(self, mapper, mappings): - self._bulk_save_mappings(mapper, mappings, False, False) + def bulk_insert_mappings(self, mapper, mappings, return_defaults=False): + self._bulk_save_mappings(mapper, mappings, False, False, return_defaults) def bulk_update_mappings(self, mapper, mappings): - self._bulk_save_mappings(mapper, mappings, True, False) + self._bulk_save_mappings(mapper, mappings, True, False, False) - def _bulk_save_mappings(self, mapper, mappings, isupdate, isstates): + def _bulk_save_mappings( + self, mapper, mappings, isupdate, isstates, return_defaults): mapper = _class_to_mapper(mapper) self._flushing = True @@ -2061,7 +2063,7 @@ class Session(_SessionClassMethods): mapper, mappings, transaction, isstates) else: persistence._bulk_insert( - mapper, mappings, transaction, isstates) + mapper, mappings, transaction, isstates, return_defaults) transaction.commit() except: -- cgit v1.2.1 From 16d9d366bd80b3f9b42c89ceb3e392de15631188 Mon Sep 17 00:00:00 2001 From: jonathan vanasco Date: Fri, 3 Oct 2014 13:15:52 -0400 Subject: * adding 'isouter=False' to sqlalchemy.orm.query.Query (https://bitbucket.org/zzzeek/sqlalchemy/issue/3217/make-join-more-standard-or-improve-error) $ python setup.py develop $ pip install nose $ pip install mock $ ./sqla_nose.py test.orm.test_joins ..................................................................................................... 
---------------------------------------------------------------------- Ran 101 tests in 1.222s OK $ ./sqla_nose.py test.orm ......................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................S......................................................................................................................................................................................................................................................................................................................S......................................................................................................
.................................................................................................................................................................................................................................................S.......S..S.SSS.SS...............................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................S................................S..S........................S...........................................................................................SSS.S.........SSSSSSSS......SSSSSSSSS........SS...SS...............S.............................S..............................................................SS..SS..............................................................................................................S. ---------------------------------------------------------------------- Ran 3103 tests in 82.607s OK (SKIP=46) --- lib/sqlalchemy/orm/query.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 7b2ea7977..eaa3a8dcd 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1740,6 +1740,8 @@ class Query(object): anonymously aliased. Subsequent calls to :meth:`~.Query.filter` and similar will adapt the incoming criterion to the target alias, until :meth:`~.Query.reset_joinpoint` is called. + :param isouter=False: If True, the join used will be a left outer join, + just as if the ``outerjoin()`` method were called. 
:param from_joinpoint=False: When using ``aliased=True``, a setting of True here will cause the join to be from the most recent joined target, rather than starting back from the original @@ -1757,13 +1759,15 @@ class Query(object): SQLAlchemy versions was the primary ORM-level joining interface. """ - aliased, from_joinpoint = kwargs.pop('aliased', False),\ - kwargs.pop('from_joinpoint', False) + aliased, from_joinpoint, isouter = kwargs.pop('aliased', False),\ + kwargs.pop('from_joinpoint', False),\ + kwargs.pop('isouter', False) if kwargs: raise TypeError("unknown arguments: %s" % ','.join(kwargs.keys)) + isouter = isouter return self._join(props, - outerjoin=False, create_aliases=aliased, + outerjoin=isouter, create_aliases=aliased, from_joinpoint=from_joinpoint) def outerjoin(self, *props, **kwargs): @@ -3385,7 +3389,6 @@ class _BundleEntity(_QueryEntity): self.supports_single_entity = self.bundle.single_entity - @property def entity_zero(self): for ent in self._entities: -- cgit v1.2.1 From 0f5a400b77862d2ae8f5d1a326fe9571da8fc0cb Mon Sep 17 00:00:00 2001 From: jonathan vanasco Date: Fri, 17 Oct 2014 19:35:29 -0400 Subject: added docs to clarify that sql statement is already in a dialect --- lib/sqlalchemy/events.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/events.py b/lib/sqlalchemy/events.py index 1ff35b8b0..86bd3653b 100644 --- a/lib/sqlalchemy/events.py +++ b/lib/sqlalchemy/events.py @@ -420,6 +420,12 @@ class ConnectionEvents(event.Events): context, executemany): log.info("Received statement: %s" % statement) + When the methods are called with a `statement` parameter, such as in + :meth:`.after_cursor_execute`, :meth:`.before_cursor_execute` and + :meth:`.dbapi_error`, the statement is the exact SQL string that was + prepared for transmission to the DBAPI ``cursor`` in the connection's + :class:`.Dialect`. 
+ The :meth:`.before_execute` and :meth:`.before_cursor_execute` events can also be established with the ``retval=True`` flag, which allows modification of the statement and parameters to be sent @@ -549,9 +555,8 @@ class ConnectionEvents(event.Events): def before_cursor_execute(self, conn, cursor, statement, parameters, context, executemany): """Intercept low-level cursor execute() events before execution, - receiving the string - SQL statement and DBAPI-specific parameter list to be invoked - against a cursor. + receiving the string SQL statement and DBAPI-specific parameter list to + be invoked against a cursor. This event is a good choice for logging as well as late modifications to the SQL string. It's less ideal for parameter modifications except @@ -571,7 +576,7 @@ class ConnectionEvents(event.Events): :param conn: :class:`.Connection` object :param cursor: DBAPI cursor object - :param statement: string SQL statement + :param statement: string SQL statement, as to be passed to the DBAPI :param parameters: Dictionary, tuple, or list of parameters being passed to the ``execute()`` or ``executemany()`` method of the DBAPI ``cursor``. In some cases may be ``None``. @@ -596,7 +601,7 @@ class ConnectionEvents(event.Events): :param cursor: DBAPI cursor object. Will have results pending if the statement was a SELECT, but these should not be consumed as they will be needed by the :class:`.ResultProxy`. - :param statement: string SQL statement + :param statement: string SQL statement, as passed to the DBAPI :param parameters: Dictionary, tuple, or list of parameters being passed to the ``execute()`` or ``executemany()`` method of the DBAPI ``cursor``. In some cases may be ``None``. 
@@ -640,7 +645,7 @@ class ConnectionEvents(event.Events): :param conn: :class:`.Connection` object :param cursor: DBAPI cursor object - :param statement: string SQL statement + :param statement: string SQL statement, as passed to the DBAPI :param parameters: Dictionary, tuple, or list of parameters being passed to the ``execute()`` or ``executemany()`` method of the DBAPI ``cursor``. In some cases may be ``None``. -- cgit v1.2.1 From 25434e9209af9ee2c05b651bc4fe197541c0bd60 Mon Sep 17 00:00:00 2001 From: Scott Dugas Date: Wed, 22 Oct 2014 15:09:05 -0400 Subject: Support additional args/kwargs on cursor method fdbsql has an optional nested kwarg, which is supported in the actual code, but not in the testing proxy --- lib/sqlalchemy/testing/engines.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index 67c13231e..75bcc58e1 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -284,10 +284,10 @@ class DBAPIProxyCursor(object): """ - def __init__(self, engine, conn): + def __init__(self, engine, conn, *args, **kwargs): self.engine = engine self.connection = conn - self.cursor = conn.cursor() + self.cursor = conn.cursor(*args, **kwargs) def execute(self, stmt, parameters=None, **kw): if parameters: @@ -315,8 +315,10 @@ class DBAPIProxyConnection(object): self.engine = engine self.cursor_cls = cursor_cls - def cursor(self): - return self.cursor_cls(self.engine, self.conn) + def cursor(self, *args, **kwargs): + print "DPA", args + print "DPK", kwargs + return self.cursor_cls(self.engine, self.conn, *args, **kwargs) def close(self): self.conn.close() -- cgit v1.2.1 From 9c0eb840788ed5971f0876958cfb9866c7af918d Mon Sep 17 00:00:00 2001 From: Scott Dugas Date: Thu, 23 Oct 2014 10:24:35 -0400 Subject: Print useful traceback on error _expect_failure was rethrowing the exception without keeping the traceback, so 
it was really hard to find out what was actually wrong --- lib/sqlalchemy/testing/exclusions.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py index 283d89e36..5ce8bcd84 100644 --- a/lib/sqlalchemy/testing/exclusions.py +++ b/lib/sqlalchemy/testing/exclusions.py @@ -12,6 +12,7 @@ from ..util import decorator from . import config from .. import util import inspect +import sys import contextlib @@ -120,20 +121,21 @@ class compound(object): try: return_value = fn(*args, **kw) - except Exception as ex: - self._expect_failure(config, ex, name=fn.__name__) + except Exception: + exc_type, exc_value, exc_traceback = sys.exc_info() + self._expect_failure(config, exc_type, exc_value, exc_traceback, name=fn.__name__) else: self._expect_success(config, name=fn.__name__) return return_value - def _expect_failure(self, config, ex, name='block'): + def _expect_failure(self, config, exc_type, exc_value, exc_traceback, name='block'): for fail in self.fails: if fail(config): print(("%s failed as expected (%s): %s " % ( name, fail._as_string(config), str(ex)))) break else: - raise ex + raise exc_type, exc_value, exc_traceback def _expect_success(self, config, name='block'): if not self.fails: -- cgit v1.2.1 From 2ce9333a24a1f894de4bf028f51eb1de28c10a3d Mon Sep 17 00:00:00 2001 From: Scott Dugas Date: Thu, 23 Oct 2014 13:01:23 -0400 Subject: Forgot to update usage of ex to exc_value --- lib/sqlalchemy/testing/exclusions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py index 5ce8bcd84..c9f81c8b9 100644 --- a/lib/sqlalchemy/testing/exclusions.py +++ b/lib/sqlalchemy/testing/exclusions.py @@ -132,7 +132,7 @@ class compound(object): for fail in self.fails: if fail(config): print(("%s failed as expected (%s): %s " % ( - name, 
fail._as_string(config), str(ex)))) + name, fail._as_string(config), str(exc_value)))) break else: raise exc_type, exc_value, exc_traceback -- cgit v1.2.1 From ebb9d57cb385f49becbf54c6f78647715ddd1c29 Mon Sep 17 00:00:00 2001 From: Scott Dugas Date: Thu, 30 Oct 2014 16:40:36 -0400 Subject: Removed accidental print statements --- lib/sqlalchemy/testing/engines.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index 75bcc58e1..3a3f5be10 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -316,8 +316,6 @@ class DBAPIProxyConnection(object): self.cursor_cls = cursor_cls def cursor(self, *args, **kwargs): - print "DPA", args - print "DPK", kwargs return self.cursor_cls(self.engine, self.conn, *args, **kwargs) def close(self): -- cgit v1.2.1 From 7bf5ac9c1e814c999d4930941935e1d5cfd236bf Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 31 Oct 2014 20:00:42 -0400 Subject: - ensure kwargs are passed for limit clause on a compound select as well, further fixes for #3034 --- lib/sqlalchemy/sql/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index a6c30b7dc..5fa78ad0f 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -813,7 +813,7 @@ class SQLCompiler(Compiled): text += self.order_by_clause(cs, **kwargs) text += (cs._limit_clause is not None or cs._offset_clause is not None) and \ - self.limit_clause(cs) or "" + self.limit_clause(cs, **kwargs) or "" if self.ctes and \ compound_index == 0 and toplevel: -- cgit v1.2.1 From 8d154f84f1a552c290a1ccd802f20940c8cab066 Mon Sep 17 00:00:00 2001 From: Scott Dugas Date: Mon, 3 Nov 2014 15:24:31 -0500 Subject: It now calls raise_from_cause master was updated to call util.raise_from_cause which is better than what I had --- 
lib/sqlalchemy/testing/exclusions.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py index e3d91300d..f94724608 100644 --- a/lib/sqlalchemy/testing/exclusions.py +++ b/lib/sqlalchemy/testing/exclusions.py @@ -12,7 +12,6 @@ from ..util import decorator from . import config from .. import util import inspect -import sys import contextlib @@ -121,18 +120,17 @@ class compound(object): try: return_value = fn(*args, **kw) - except Exception: - exc_type, exc_value, exc_traceback = sys.exc_info() - self._expect_failure(config, exc_type, exc_value, exc_traceback, name=fn.__name__) + except Exception as ex: + self._expect_failure(config, ex, name=fn.__name__) else: self._expect_success(config, name=fn.__name__) return return_value - def _expect_failure(self, config, exc_type, exc_value, exc_traceback, name='block'): + def _expect_failure(self, config, ex, name='block'): for fail in self.fails: if fail(config): print(("%s failed as expected (%s): %s " % ( - name, fail._as_string(config), str(exc_value)))) + name, fail._as_string(config), str(ex)))) break else: util.raise_from_cause(ex) -- cgit v1.2.1 From edec583b459e955a30d40b5c5d8baaed0a2ec1c6 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 5 Nov 2014 04:22:30 -0500 Subject: - Fixed bug regarding expression mutations which could express itself as a "Could not locate column" error when using :class:`.Query` to select from multiple, anonymous column entities when querying against SQLite, as a side effect of the "join rewriting" feature used by the SQLite dialect. 
fixes #3241 --- lib/sqlalchemy/sql/elements.py | 7 +++++++ 1 file changed, 7 insertions(+) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 4d5bb9476..fa9b66024 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -861,6 +861,9 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): expressions and function calls. """ + while self._is_clone_of is not None: + self = self._is_clone_of + return _anonymous_label( '%%(%d %s)s' % (id(self), getattr(self, 'name', 'anon')) ) @@ -2778,6 +2781,10 @@ class Grouping(ColumnElement): def self_group(self, against=None): return self + @property + def _key_label(self): + return self._label + @property def _label(self): return getattr(self.element, '_label', None) or self.anon_label -- cgit v1.2.1 From ea637cef2d9ec54b14fac3620b1cfd47da723f3f Mon Sep 17 00:00:00 2001 From: Paulo Bu Date: Wed, 5 Nov 2014 13:15:08 +0100 Subject: Small improvement on FlushError can't delete error message Output in the error message the table name and the column name. 
--- lib/sqlalchemy/orm/persistence.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 114b79ea5..28254cc10 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -441,9 +441,9 @@ def _collect_delete_commands(base_mapper, uowtransaction, table, state, state_dict, col) if value is None: raise orm_exc.FlushError( - "Can't delete from table " + "Can't delete from table %s " "using NULL for primary " - "key value") + "key value on column %s" % (table, col)) if update_version_id is not None and \ table.c.contains_column(mapper.version_id_col): -- cgit v1.2.1 From 4b09f1423b382336f29722490bab3a4c8c8607ea Mon Sep 17 00:00:00 2001 From: Paulo Bu Date: Thu, 6 Nov 2014 21:14:17 +0100 Subject: Small improvement on FlushError can't update error message Output in the error message the table name and the column name. --- lib/sqlalchemy/orm/persistence.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 28254cc10..6b8d5af14 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -375,12 +375,12 @@ def _collect_update_commands(uowtransaction, table, states_to_update): params[col.key] = history.added[0] else: pk_params[col._label] = history.unchanged[0] + if pk_params[col._label] is None: + raise orm_exc.FlushError( + "Can't update table %s using NULL for primary " + "key value on column %s" % (table, col)) if params or value_params: - if None in pk_params.values(): - raise orm_exc.FlushError( - "Can't update table using NULL for primary " - "key value") params.update(pk_params) yield ( state, state_dict, params, mapper, -- cgit v1.2.1 From a19b2f419cd876b561a3b3c21ebed5c223192883 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 10 Nov 2014 17:37:26 -0500 
Subject: - The :attr:`.Column.key` attribute is now used as the source of anonymous bound parameter names within expressions, to match the existing use of this value as the key when rendered in an INSERT or UPDATE statement. This allows :attr:`.Column.key` to be used as a "substitute" string to work around a difficult column name that doesn't translate well into a bound parameter name. Note that the paramstyle is configurable on :func:`.create_engine` in any case, and most DBAPIs today support a named and positional style. fixes #3245 --- lib/sqlalchemy/sql/elements.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index fa9b66024..734f78632 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -1092,7 +1092,7 @@ class BindParameter(ColumnElement): """ if isinstance(key, ColumnClause): type_ = key.type - key = key.name + key = key.key if required is NO_ARG: required = (value is NO_ARG and callable_ is None) if value is NO_ARG: @@ -3335,7 +3335,7 @@ class ColumnClause(Immutable, ColumnElement): return name def _bind_param(self, operator, obj): - return BindParameter(self.name, obj, + return BindParameter(self.key, obj, _compared_to_operator=operator, _compared_to_type=self.type, unique=True) -- cgit v1.2.1 From 21022f9760e32cf54d59eaccc12cc9e2fea1d37a Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 10 Nov 2014 17:58:09 -0500 Subject: - in lieu of adding a new system of translating bound parameter names for psycopg2 and others, encourage users to take advantage of positional styles by documenting "paramstyle". A section is added to psycopg2 specifically as this is a pretty common spot for named parameters that may be unusually named. fixes #3246. 
--- lib/sqlalchemy/dialects/postgresql/psycopg2.py | 49 ++++++++++++++++++++++++++ lib/sqlalchemy/engine/__init__.py | 11 ++++++ 2 files changed, 60 insertions(+) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index 1a2a1ffe4..f67b2e3b0 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -159,6 +159,55 @@ defaults to ``utf-8``. SQLAlchemy's own unicode encode/decode functionality is steadily becoming obsolete as most DBAPIs now support unicode fully. +Bound Parameter Styles +---------------------- + +The default parameter style for the psycopg2 dialect is "pyformat", where +SQL is rendered using ``%(paramname)s`` style. This format has the limitation +that it does not accommodate the unusual case of parameter names that +actually contain percent or parenthesis symbols; as SQLAlchemy in many cases +generates bound parameter names based on the name of a column, the presence +of these characters in a column name can lead to problems. + +There are two solutions to the issue of a :class:`.schema.Column` that contains +one of these characters in its name. One is to specify the +:paramref:`.schema.Column.key` for columns that have such names:: + + measurement = Table('measurement', metadata, + Column('Size (meters)', Integer, key='size_meters') + ) + +Above, an INSERT statement such as ``measurement.insert()`` will use +``size_meters`` as the parameter name, and a SQL expression such as +``measurement.c.size_meters > 10`` will derive the bound parameter name +from the ``size_meters`` key as well. + +.. versionchanged:: 1.0.0 - SQL expressions will use :attr:`.Column.key` + as the source of naming when anonymous bound parameters are created + in SQL expressions; previously, this behavior only applied to + :meth:`.Table.insert` and :meth:`.Table.update` parameter names. 
+ +The other solution is to use a positional format; psycopg2 allows use of the +"format" paramstyle, which can be passed to +:paramref:`.create_engine.paramstyle`:: + + engine = create_engine( + 'postgresql://scott:tiger@localhost:5432/test', paramstyle='format') + +With the above engine, instead of a statement like:: + + INSERT INTO measurement ("Size (meters)") VALUES (%(Size (meters))s) + {'Size (meters)': 1} + +we instead see:: + + INSERT INTO measurement ("Size (meters)") VALUES (%s) + (1, ) + +Where above, the dictionary style is converted into a tuple with positional +style. + + Transactions ------------ diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py index 68145f5cd..cf75871bf 100644 --- a/lib/sqlalchemy/engine/__init__.py +++ b/lib/sqlalchemy/engine/__init__.py @@ -292,6 +292,17 @@ def create_engine(*args, **kwargs): be used instead. Can be used for testing of DBAPIs as well as to inject "mock" DBAPI implementations into the :class:`.Engine`. + :param paramstyle=None: The `paramstyle <http://legacy.python.org/dev/peps/pep-0249/#paramstyle>`_ + to use when rendering bound parameters. This style defaults to the + one recommended by the DBAPI itself, which is retrieved from the + ``.paramstyle`` attribute of the DBAPI. However, most DBAPIs accept + more than one paramstyle, and in particular it may be desirable + to change a "named" paramstyle into a "positional" one, or vice versa. + When this attribute is passed, it should be one of the values + ``"qmark"``, ``"numeric"``, ``"named"``, ``"format"`` or + ``"pyformat"``, and should correspond to a parameter style known + to be supported by the DBAPI in use. + :param pool=None: an already-constructed instance of :class:`~sqlalchemy.pool.Pool`, such as a :class:`~sqlalchemy.pool.QueuePool` instance.
If non-None, this -- cgit v1.2.1 From b013fb82f5a5d891c6fd776e0e6ed926cdf2ffe1 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 11 Nov 2014 12:34:00 -0500 Subject: - Fixed issue where the columns from a SELECT embedded in an INSERT, either through the values clause or as a "from select", would pollute the column types used in the result set produced by the RETURNING clause when columns from both statements shared the same name, leading to potential errors or mis-adaptation when retrieving the returning rows. fixes #3248 --- lib/sqlalchemy/sql/compiler.py | 8 ++++++++ 1 file changed, 8 insertions(+) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 5fa78ad0f..8f3ede25f 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1729,6 +1729,12 @@ class SQLCompiler(Compiled): ) def visit_insert(self, insert_stmt, **kw): + self.stack.append( + {'correlate_froms': set(), + "iswrapper": False, + "asfrom_froms": set(), + "selectable": insert_stmt}) + self.isinsert = True crud_params = crud._get_crud_params(self, insert_stmt, **kw) @@ -1812,6 +1818,8 @@ class SQLCompiler(Compiled): if self.returning and not self.returning_precedes_values: text += " " + returning_clause + self.stack.pop(-1) + return text def update_limit_clause(self, update_stmt): -- cgit v1.2.1 From 30075f9015c91d945c620af0d84c9c162627aa3c Mon Sep 17 00:00:00 2001 From: Jon Nelson Date: Tue, 11 Nov 2014 21:34:57 -0600 Subject: - don't do inline string interpolation when logging --- lib/sqlalchemy/dialects/mysql/base.py | 2 +- lib/sqlalchemy/orm/strategies.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 2fb054d0c..58eb3afa0 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -2593,7 +2593,7 @@ class 
MySQLDialect(default.DefaultDialect): pass else: self.logger.info( - "Converting unknown KEY type %s to a plain KEY" % flavor) + "Converting unknown KEY type %s to a plain KEY", flavor) pass index_d = {} index_d['name'] = spec['name'] diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index cdb501c14..d95f17f64 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -373,7 +373,7 @@ class LazyLoader(AbstractRelationshipLoader): self._equated_columns[c] = self._equated_columns[col] self.logger.info("%s will use query.get() to " - "optimize instance loads" % self) + "optimize instance loads", self) def init_class_attribute(self, mapper): self.is_class_level = True -- cgit v1.2.1 From 026449c15ff35a9b89c2ca591d3e3cc791857272 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 13 Nov 2014 13:17:38 -0500 Subject: - Fixed a leak which would occur in the unsupported and highly non-recommended use case of replacing a relationship on a fixed mapped class many times, referring to an arbitrarily growing number of target mappers. A warning is emitted when the old relationship is replaced, however if the mapping were already used for querying, the old relationship would still be referenced within some registries. 
fixes #3251 --- lib/sqlalchemy/orm/mapper.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 7e88ba161..863dab5cb 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1581,6 +1581,8 @@ class Mapper(InspectionAttr): self, prop, )) + oldprop = self._props[key] + self._path_registry.pop(oldprop, None) self._props[key] = prop -- cgit v1.2.1 From de9103aae22ba548323a3e469624f02d1d279103 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 14 Nov 2014 11:06:43 -0500 Subject: - correct this to rewrite a multiple profile line correctly --- lib/sqlalchemy/testing/profiling.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py index fcb888f86..6fc51ef32 100644 --- a/lib/sqlalchemy/testing/profiling.py +++ b/lib/sqlalchemy/testing/profiling.py @@ -115,7 +115,11 @@ class ProfileStatsFile(object): per_fn = self.data[test_key] per_platform = per_fn[self.platform_key] counts = per_platform['counts'] - counts[-1] = callcount + current_count = per_platform['current_count'] + if current_count < len(counts): + counts[current_count - 1] = callcount + else: + counts[-1] = callcount if self.write: self._write() -- cgit v1.2.1 From ba926a09b493b37c88e7b435aaccc6b399574057 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 24 Nov 2014 17:35:50 -0500 Subject: - add some logging to path_registry to help debug eager loading issues --- lib/sqlalchemy/orm/path_registry.py | 10 ++++++++++ lib/sqlalchemy/orm/strategy_options.py | 3 +++ 2 files changed, 13 insertions(+) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py index f10a125a8..d4dbf29a0 100644 --- a/lib/sqlalchemy/orm/path_registry.py +++ b/lib/sqlalchemy/orm/path_registry.py @@ -13,6 +13,9 @@ from .. import util from .. 
import exc from itertools import chain from .base import class_mapper +import logging + +log = logging.getLogger(__name__) def _unreduce_path(path): @@ -54,9 +57,11 @@ class PathRegistry(object): self.path == other.path def set(self, attributes, key, value): + log.debug("set '%s' on path '%s' to '%s'", key, self, value) attributes[(key, self.path)] = value def setdefault(self, attributes, key, value): + log.debug("setdefault '%s' on path '%s' to '%s'", key, self, value) attributes.setdefault((key, self.path), value) def get(self, attributes, key, value=None): @@ -184,6 +189,11 @@ class PropRegistry(PathRegistry): self.parent = parent self.path = parent.path + (prop,) + def __str__(self): + return " -> ".join( + str(elem) for elem in self.path + ) + @util.memoized_property def has_entity(self): return hasattr(self.prop, "mapper") diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index 4f986193e..a4107202e 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -176,6 +176,9 @@ class Load(Generative, MapperOption): path = path.entity_path return path + def __str__(self): + return "Load(strategy=%r)" % self.strategy + def _coerce_strat(self, strategy): if strategy is not None: strategy = tuple(sorted(strategy.items())) -- cgit v1.2.1 From de11f9498258182cbb6668b72067ec3f43a90415 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 24 Nov 2014 18:49:32 -0500 Subject: - The :meth:`.PropComparator.of_type` modifier has been improved in conjunction with loader directives such as :func:`.joinedload` and :func:`.contains_eager` such that if two :meth:`.PropComparator.of_type` modifiers of the same base type/path are encountered, they will be joined together into a single "polymorphic" entity, rather than replacing the entity of type A with the one of type B. E.g. 
a joinedload of ``A.b.of_type(BSub1)->BSub1.c`` combined with joinedload of ``A.b.of_type(BSub2)->BSub2.c`` will create a single joinedload of ``A.b.of_type((BSub1, BSub2)) -> BSub1.c, BSub2.c``, without the need for the ``with_polymorphic`` to be explicit in the query. fixes #3256 --- lib/sqlalchemy/orm/strategy_options.py | 5 ++++- lib/sqlalchemy/orm/util.py | 22 +++++++++++++++++++--- lib/sqlalchemy/util/_collections.py | 9 ++++++--- 3 files changed, 29 insertions(+), 7 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index a4107202e..276da2ae0 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -161,11 +161,14 @@ class Load(Generative, MapperOption): ext_info = inspect(ac) path_element = ext_info.mapper + existing = path.entity_path[prop].get( + self.context, "path_with_polymorphic") if not ext_info.is_aliased_class: ac = orm_util.with_polymorphic( ext_info.mapper.base_mapper, ext_info.mapper, aliased=True, - _use_mapper_path=True) + _use_mapper_path=True, + _existing_alias=existing) path.entity_path[prop].set( self.context, "path_with_polymorphic", inspect(ac)) path = path[prop][path_element] diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index ad610a4ac..4be8d19ff 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -543,8 +543,13 @@ class AliasedInsp(InspectionAttr): mapper, self) def __repr__(self): - return '' % ( - id(self), self.class_.__name__) + if self.with_polymorphic_mappers: + with_poly = "(%s)" % ", ".join( + mp.class_.__name__ for mp in self.with_polymorphic_mappers) + else: + with_poly = "" + return '' % ( + id(self), self.class_.__name__, with_poly) inspection._inspects(AliasedClass)(lambda target: target._aliased_insp) @@ -648,7 +653,8 @@ def aliased(element, alias=None, name=None, flat=False, adapt_on_names=False): def with_polymorphic(base, classes, 
selectable=False, flat=False, polymorphic_on=None, aliased=False, - innerjoin=False, _use_mapper_path=False): + innerjoin=False, _use_mapper_path=False, + _existing_alias=None): """Produce an :class:`.AliasedClass` construct which specifies columns for descendant mappers of the given base. @@ -713,6 +719,16 @@ def with_polymorphic(base, classes, selectable=False, only be specified if querying for one specific subtype only """ primary_mapper = _class_to_mapper(base) + if _existing_alias: + assert _existing_alias.mapper is primary_mapper + classes = util.to_set(classes) + new_classes = set([ + mp.class_ for mp in + _existing_alias.with_polymorphic_mappers]) + if classes == new_classes: + return _existing_alias + else: + classes = classes.union(new_classes) mappers, selectable = primary_mapper.\ _with_polymorphic_args(classes, selectable, innerjoin=innerjoin) diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index a1fbc0fa0..d36852698 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -10,9 +10,10 @@ from __future__ import absolute_import import weakref import operator -from .compat import threading, itertools_filterfalse +from .compat import threading, itertools_filterfalse, string_types from . 
import py2k import types +import collections EMPTY_SET = frozenset() @@ -779,10 +780,12 @@ def coerce_generator_arg(arg): def to_list(x, default=None): if x is None: return default - if not isinstance(x, (list, tuple)): + if not isinstance(x, collections.Iterable) or isinstance(x, string_types): return [x] - else: + elif isinstance(x, list): return x + else: + return list(x) def to_set(x): -- cgit v1.2.1 From 212d93366d1c5c3a8e44f8b428eeece6258ae28f Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 25 Nov 2014 18:01:31 -0500 Subject: - The behavioral contract of the :attr:`.ForeignKeyConstraint.columns` collection has been made consistent; this attribute is now a :class:`.ColumnCollection` like that of all other constraints and is initialized at the point when the constraint is associated with a :class:`.Table`. fixes #3243 --- lib/sqlalchemy/dialects/sqlite/base.py | 4 +- lib/sqlalchemy/sql/compiler.py | 6 +- lib/sqlalchemy/sql/schema.py | 100 ++++++++++++++++++++------------- 3 files changed, 65 insertions(+), 45 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 335b35c94..33003297c 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -646,8 +646,8 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): def visit_foreign_key_constraint(self, constraint): - local_table = list(constraint._elements.values())[0].parent.table - remote_table = list(constraint._elements.values())[0].column.table + local_table = constraint.elements[0].parent.table + remote_table = constraint.elements[0].column.table if local_table.schema != remote_table.schema: return None diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 8f3ede25f..b102f0240 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -2286,14 +2286,14 @@ class DDLCompiler(Compiled): formatted_name = 
self.preparer.format_constraint(constraint) if formatted_name is not None: text += "CONSTRAINT %s " % formatted_name - remote_table = list(constraint._elements.values())[0].column.table + remote_table = list(constraint.elements)[0].column.table text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % ( ', '.join(preparer.quote(f.parent.name) - for f in constraint._elements.values()), + for f in constraint.elements), self.define_constraint_remote_table( constraint, remote_table, preparer), ', '.join(preparer.quote(f.column.name) - for f in constraint._elements.values()) + for f in constraint.elements) ) text += self.define_constraint_match(constraint) text += self.define_constraint_cascades(constraint) diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 96cabbf4f..8b2eb12f0 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -1804,7 +1804,7 @@ class ForeignKey(DialectKWArgs, SchemaItem): match=self.match, **self._unvalidated_dialect_kw ) - self.constraint._elements[self.parent] = self + self.constraint._append_element(column, self) self.constraint._set_parent_with_dispatch(table) table.foreign_keys.add(self) @@ -2489,7 +2489,7 @@ class CheckConstraint(Constraint): return self._schema_item_copy(c) -class ForeignKeyConstraint(Constraint): +class ForeignKeyConstraint(ColumnCollectionConstraint): """A table-level FOREIGN KEY constraint. Defines a single column or composite FOREIGN KEY ... REFERENCES @@ -2564,9 +2564,10 @@ class ForeignKeyConstraint(Constraint): .. 
versionadded:: 0.9.2 """ - super(ForeignKeyConstraint, self).\ - __init__(name, deferrable, initially, info=info, **dialect_kw) + Constraint.__init__( + self, name=name, deferrable=deferrable, initially=initially, + info=info, **dialect_kw) self.onupdate = onupdate self.ondelete = ondelete self.link_to_name = link_to_name @@ -2575,14 +2576,12 @@ class ForeignKeyConstraint(Constraint): self.use_alter = use_alter self.match = match - self._elements = util.OrderedDict() - # standalone ForeignKeyConstraint - create # associated ForeignKey objects which will be applied to hosted # Column objects (in col.foreign_keys), either now or when attached # to the Table for string-specified names - for col, refcol in zip(columns, refcolumns): - self._elements[col] = ForeignKey( + self.elements = [ + ForeignKey( refcol, _constraint=self, name=self.name, @@ -2594,25 +2593,36 @@ class ForeignKeyConstraint(Constraint): deferrable=self.deferrable, initially=self.initially, **self.dialect_kwargs - ) + ) for refcol in refcolumns + ] + ColumnCollectionMixin.__init__(self, *columns) if table is not None: + if hasattr(self, "parent"): + assert table is self.parent self._set_parent_with_dispatch(table) - elif columns and \ - isinstance(columns[0], Column) and \ - columns[0].table is not None: - self._set_parent_with_dispatch(columns[0].table) + + def _append_element(self, column, fk): + self.columns.add(column) + self.elements.append(fk) + + @property + def _elements(self): + # legacy - provide a dictionary view of (column_key, fk) + return util.OrderedDict( + zip(self.column_keys, self.elements) + ) @property def _referred_schema(self): - for elem in self._elements.values(): + for elem in self.elements: return elem._referred_schema else: return None def _validate_dest_table(self, table): table_keys = set([elem._table_key() - for elem in self._elements.values()]) + for elem in self.elements]) if None not in table_keys and len(table_keys) > 1: elem0, elem1 = sorted(table_keys)[0:2] raise 
exc.ArgumentError( @@ -2625,38 +2635,48 @@ class ForeignKeyConstraint(Constraint): )) @property - def _col_description(self): - return ", ".join(self._elements) + def column_keys(self): + """Return a list of string keys representing the local + columns in this :class:`.ForeignKeyConstraint`. - @property - def columns(self): - return list(self._elements) + This list is either the original string arguments sent + to the constructor of the :class:`.ForeignKeyConstraint`, + or if the constraint has been initialized with :class:`.Column` + objects, is the string .key of each element. + + .. versionadded:: 1.0.0 + + """ + if hasattr(self, 'table'): + return self.columns.keys() + else: + return [ + col.key if isinstance(col, ColumnElement) + else str(col) for col in self._pending_colargs + ] @property - def elements(self): - return list(self._elements.values()) + def _col_description(self): + return ", ".join(self.column_keys) def _set_parent(self, table): - super(ForeignKeyConstraint, self)._set_parent(table) - - self._validate_dest_table(table) + Constraint._set_parent(self, table) - for col, fk in self._elements.items(): - # string-specified column names now get - # resolved to Column objects - if isinstance(col, util.string_types): - try: - col = table.c[col] - except KeyError: - raise exc.ArgumentError( - "Can't create ForeignKeyConstraint " - "on table '%s': no column " - "named '%s' is present." % (table.description, col)) + try: + ColumnCollectionConstraint._set_parent(self, table) + except KeyError as ke: + raise exc.ArgumentError( + "Can't create ForeignKeyConstraint " + "on table '%s': no column " + "named '%s' is present." 
% (table.description, ke.args[0])) + for col, fk in zip(self.columns, self.elements): if not hasattr(fk, 'parent') or \ fk.parent is not col: fk._set_parent_with_dispatch(col) + self._validate_dest_table(table) + if self.use_alter: def supports_alter(ddl, event, schema_item, bind, **kw): return table in set(kw['tables']) and \ @@ -2669,14 +2689,14 @@ class ForeignKeyConstraint(Constraint): def copy(self, schema=None, target_table=None, **kw): fkc = ForeignKeyConstraint( - [x.parent.key for x in self._elements.values()], + [x.parent.key for x in self.elements], [x._get_colspec( schema=schema, table_name=target_table.name if target_table is not None and x._table_key() == x.parent.table.key else None) - for x in self._elements.values()], + for x in self.elements], name=self.name, onupdate=self.onupdate, ondelete=self.ondelete, @@ -2687,8 +2707,8 @@ class ForeignKeyConstraint(Constraint): match=self.match ) for self_fk, other_fk in zip( - self._elements.values(), - fkc._elements.values()): + self.elements, + fkc.elements): self_fk._schema_item_copy(other_fk) return self._schema_item_copy(fkc) -- cgit v1.2.1 From d69f44b78090c6795b0b73b3befef39af44b6918 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 25 Nov 2014 23:28:54 -0500 Subject: - add a new option --force-write-profiles to rewrite profiles even if they are passing --- lib/sqlalchemy/testing/plugin/plugin_base.py | 5 ++++- lib/sqlalchemy/testing/profiling.py | 8 ++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py index 6696427dc..614a12133 100644 --- a/lib/sqlalchemy/testing/plugin/plugin_base.py +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -93,7 +93,10 @@ def setup_options(make_option): help="Exclude tests with tag ") make_option("--write-profiles", action="store_true", dest="write_profiles", default=False, - help="Write/update profiling data.") + 
help="Write/update failing profiling data.") + make_option("--force-write-profiles", action="store_true", + dest="force_write_profiles", default=False, + help="Unconditionally write/update profiling data.") def configure_follower(follower_ident): diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py index 6fc51ef32..671bbe32d 100644 --- a/lib/sqlalchemy/testing/profiling.py +++ b/lib/sqlalchemy/testing/profiling.py @@ -42,7 +42,11 @@ class ProfileStatsFile(object): """ def __init__(self, filename): - self.write = ( + self.force_write = ( + config.options is not None and + config.options.force_write_profiles + ) + self.write = self.force_write or ( config.options is not None and config.options.write_profiles ) @@ -239,7 +243,7 @@ def count_functions(variance=0.05): deviance = int(callcount * variance) failed = abs(callcount - expected_count) > deviance - if failed: + if failed or _profile_stats.force_write: if _profile_stats.write: _profile_stats.replace(callcount) else: -- cgit v1.2.1 From 79c0aa6b7320f94399634d02997faacbb6ced1d7 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 25 Nov 2014 23:33:47 -0500 Subject: - use self.parent, not table here as there's an attributeerror trap for self.table that behaves differently in py3k --- lib/sqlalchemy/sql/schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 8b2eb12f0..4093d7115 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -2647,7 +2647,7 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): .. versionadded:: 1.0.0 """ - if hasattr(self, 'table'): + if hasattr(self, "parent"): return self.columns.keys() else: return [ -- cgit v1.2.1 From 99e51151244c7028fcc319d60e2e8ad1ba9e22bb Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 26 Nov 2014 13:50:43 -0500 Subject: - changelog, improve docstring/test for #3217. 
fixes #3217 --- lib/sqlalchemy/orm/query.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 884e04bbc..790686288 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1741,7 +1741,13 @@ class Query(object): and similar will adapt the incoming criterion to the target alias, until :meth:`~.Query.reset_joinpoint` is called. :param isouter=False: If True, the join used will be a left outer join, - just as if the ``outerjoin()`` method were called. + just as if the :meth:`.Query.outerjoin` method were called. This + flag is here to maintain consistency with the same flag as accepted + by :meth:`.FromClause.join` and other Core constructs. + + + .. versionadded:: 1.0.0 + :param from_joinpoint=False: When using ``aliased=True``, a setting of True here will cause the join to be from the most recent joined target, rather than starting back from the original -- cgit v1.2.1 From 98c2a679707432e6707ba70f1aebd10b28b861a3 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 29 Nov 2014 14:44:26 -0500 Subject: - Fixed bug in :meth:`.Table.tometadata` method where the :class:`.CheckConstraint` associated with a :class:`.Boolean` or :class:`.Enum` type object would be doubled in the target table. The copy process now tracks the production of this constraint object as local to a type object. 
fixes #3260 --- lib/sqlalchemy/sql/schema.py | 16 ++++++++++------ lib/sqlalchemy/sql/sqltypes.py | 10 +++++----- 2 files changed, 15 insertions(+), 11 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 4093d7115..b90f7fc53 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -824,7 +824,7 @@ class Table(DialectKWArgs, SchemaItem, TableClause): table.append_constraint( c.copy(schema=fk_constraint_schema, target_table=table)) - else: + elif not c._type_bound: table.append_constraint( c.copy(schema=schema, target_table=table)) for index in self.indexes: @@ -1295,7 +1295,7 @@ class Column(SchemaItem, ColumnClause): # Constraint objects plus non-constraint-bound ForeignKey objects args = \ - [c.copy(**kw) for c in self.constraints] + \ + [c.copy(**kw) for c in self.constraints if not c._type_bound] + \ [c.copy(**kw) for c in self.foreign_keys if not c.constraint] type_ = self.type @@ -2254,7 +2254,7 @@ class Constraint(DialectKWArgs, SchemaItem): __visit_name__ = 'constraint' def __init__(self, name=None, deferrable=None, initially=None, - _create_rule=None, info=None, + _create_rule=None, info=None, _type_bound=False, **dialect_kw): """Create a SQL constraint. @@ -2304,6 +2304,7 @@ class Constraint(DialectKWArgs, SchemaItem): if info: self.info = info self._create_rule = _create_rule + self._type_bound = _type_bound util.set_creation_order(self) self._validate_dialect_kwargs(dialect_kw) @@ -2420,7 +2421,7 @@ class CheckConstraint(Constraint): def __init__(self, sqltext, name=None, deferrable=None, initially=None, table=None, info=None, _create_rule=None, - _autoattach=True): + _autoattach=True, _type_bound=False): """Construct a CHECK constraint. 
:param sqltext: @@ -2450,7 +2451,9 @@ class CheckConstraint(Constraint): """ super(CheckConstraint, self).\ - __init__(name, deferrable, initially, _create_rule, info=info) + __init__( + name, deferrable, initially, _create_rule, info=info, + _type_bound=_type_bound) self.sqltext = _literal_as_text(sqltext, warn=False) if table is not None: self._set_parent_with_dispatch(table) @@ -2485,7 +2488,8 @@ class CheckConstraint(Constraint): deferrable=self.deferrable, _create_rule=self._create_rule, table=target_table, - _autoattach=False) + _autoattach=False, + _type_bound=self._type_bound) return self._schema_item_copy(c) diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 2729bc83e..7bf2f337c 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -998,13 +998,11 @@ class SchemaType(SchemaEventTarget): def adapt(self, impltype, **kw): schema = kw.pop('schema', self.schema) - # don't associate with MetaData as the hosting type + # don't associate with self.metadata as the hosting type # is already associated with it, avoid creating event # listeners - metadata = kw.pop('metadata', None) return impltype(name=self.name, schema=schema, - metadata=metadata, inherit_schema=self.inherit_schema, **kw) @@ -1165,7 +1163,8 @@ class Enum(String, SchemaType): type_coerce(column, self).in_(self.enums), name=_defer_name(self.name), _create_rule=util.portable_instancemethod( - self._should_create_constraint) + self._should_create_constraint), + _type_bound=True ) assert e.table is table @@ -1303,7 +1302,8 @@ class Boolean(TypeEngine, SchemaType): type_coerce(column, self).in_([0, 1]), name=_defer_name(self.name), _create_rule=util.portable_instancemethod( - self._should_create_constraint) + self._should_create_constraint), + _type_bound=True ) assert e.table is table -- cgit v1.2.1 From 87bfcf91e9659893f17adf307090bc0a4a8a8f23 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 4 Dec 2014 12:01:19 -0500 Subject: - The 
:meth:`.PGDialect.has_table` method will now query against ``pg_catalog.pg_table_is_visible(c.oid)``, rather than testing for an exact schema match, when the schema name is None; this so that the method will also illustrate that temporary tables are present. Note that this is a behavioral change, as Postgresql allows a non-temporary table to silently overwrite an existing temporary table of the same name, so this changes the behavior of ``checkfirst`` in that unusual scenario. fixes #3264 --- lib/sqlalchemy/dialects/postgresql/base.py | 3 ++- lib/sqlalchemy/testing/suite/test_reflection.py | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index baa640eaa..034ee9076 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1942,7 +1942,8 @@ class PGDialect(default.DefaultDialect): cursor = connection.execute( sql.text( "select relname from pg_class c join pg_namespace n on " - "n.oid=c.relnamespace where n.nspname=current_schema() " + "n.oid=c.relnamespace where " + "pg_catalog.pg_table_is_visible(c.oid) " "and relname=:name", bindparams=[ sql.bindparam('name', util.text_type(table_name), diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index 08b858b47..e58b6f068 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -128,6 +128,10 @@ class ComponentReflectionTest(fixtures.TablesTest): DDL("create temporary view user_tmp_v as " "select * from user_tmp") ) + event.listen( + user_tmp, "before_drop", + DDL("drop view user_tmp_v") + ) @classmethod def define_index(cls, metadata, users): -- cgit v1.2.1 From f5ff86983f9cc7914a89b96da1fd2638677d345b Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 4 Dec 2014 18:29:56 -0500 Subject: - The 
:meth:`.Operators.match` operator is now handled such that the return type is not strictly assumed to be boolean; it now returns a :class:`.Boolean` subclass called :class:`.MatchType`. The type will still produce boolean behavior when used in Python expressions, however the dialect can override its behavior at result time. In the case of MySQL, while the MATCH operator is typically used in a boolean context within an expression, if one actually queries for the value of a match expression, a floating point value is returned; this value is not compatible with SQLAlchemy's C-based boolean processor, so MySQL's result-set behavior now follows that of the :class:`.Float` type. A new operator object ``notmatch_op`` is also added to better allow dialects to define the negation of a match operation. fixes #3263 --- lib/sqlalchemy/dialects/mysql/base.py | 9 +++++++++ lib/sqlalchemy/sql/compiler.py | 9 +++++++-- lib/sqlalchemy/sql/default_comparator.py | 21 ++++++++++++++++----- lib/sqlalchemy/sql/elements.py | 2 +- lib/sqlalchemy/sql/operators.py | 5 +++++ lib/sqlalchemy/sql/sqltypes.py | 17 +++++++++++++++++ lib/sqlalchemy/sql/type_api.py | 2 +- lib/sqlalchemy/types.py | 1 + 8 files changed, 57 insertions(+), 9 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 58eb3afa0..c868f58b2 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -602,6 +602,14 @@ class _StringType(sqltypes.String): to_inspect=[_StringType, sqltypes.String]) +class _MatchType(sqltypes.Float, sqltypes.MatchType): + def __init__(self, **kw): + # TODO: float arguments? 
+ sqltypes.Float.__init__(self) + sqltypes.MatchType.__init__(self) + + + class NUMERIC(_NumericType, sqltypes.NUMERIC): """MySQL NUMERIC type.""" @@ -1544,6 +1552,7 @@ colspecs = { sqltypes.Float: FLOAT, sqltypes.Time: TIME, sqltypes.Enum: ENUM, + sqltypes.MatchType: _MatchType } # Everything 3.23 through 5.1 excepting OpenGIS types. diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index b102f0240..29a7401a1 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -82,6 +82,7 @@ OPERATORS = { operators.eq: ' = ', operators.concat_op: ' || ', operators.match_op: ' MATCH ', + operators.notmatch_op: ' NOT MATCH ', operators.in_op: ' IN ', operators.notin_op: ' NOT IN ', operators.comma_op: ', ', @@ -862,14 +863,18 @@ class SQLCompiler(Compiled): else: return "%s = 0" % self.process(element.element, **kw) - def visit_binary(self, binary, **kw): + def visit_notmatch_op_binary(self, binary, operator, **kw): + return "NOT %s" % self.visit_binary( + binary, override_operator=operators.match_op) + + def visit_binary(self, binary, override_operator=None, **kw): # don't allow "? = ?" 
to render if self.ansi_bind_rules and \ isinstance(binary.left, elements.BindParameter) and \ isinstance(binary.right, elements.BindParameter): kw['literal_binds'] = True - operator_ = binary.operator + operator_ = override_operator or binary.operator disp = getattr(self, "visit_%s_binary" % operator_.__name__, None) if disp: return disp(binary, operator_, **kw) diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index 4f53e2979..d26fdc455 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -68,8 +68,12 @@ class _DefaultColumnComparator(operators.ColumnOperators): def _boolean_compare(self, expr, op, obj, negate=None, reverse=False, _python_is_types=(util.NoneType, bool), + result_type = None, **kwargs): + if result_type is None: + result_type = type_api.BOOLEANTYPE + if isinstance(obj, _python_is_types + (Null, True_, False_)): # allow x ==/!= True/False to be treated as a literal. @@ -80,7 +84,7 @@ class _DefaultColumnComparator(operators.ColumnOperators): return BinaryExpression(expr, _literal_as_text(obj), op, - type_=type_api.BOOLEANTYPE, + type_=result_type, negate=negate, modifiers=kwargs) else: # all other None/True/False uses IS, IS NOT @@ -103,13 +107,13 @@ class _DefaultColumnComparator(operators.ColumnOperators): return BinaryExpression(obj, expr, op, - type_=type_api.BOOLEANTYPE, + type_=result_type, negate=negate, modifiers=kwargs) else: return BinaryExpression(expr, obj, op, - type_=type_api.BOOLEANTYPE, + type_=result_type, negate=negate, modifiers=kwargs) def _binary_operate(self, expr, op, obj, reverse=False, result_type=None, @@ -125,7 +129,8 @@ class _DefaultColumnComparator(operators.ColumnOperators): op, result_type = left.comparator._adapt_expression( op, right.comparator) - return BinaryExpression(left, right, op, type_=result_type) + return BinaryExpression( + left, right, op, type_=result_type, modifiers=kw) def _conjunction_operate(self, expr, 
op, other, **kw): if op is operators.and_: @@ -216,11 +221,16 @@ class _DefaultColumnComparator(operators.ColumnOperators): def _match_impl(self, expr, op, other, **kw): """See :meth:`.ColumnOperators.match`.""" + return self._boolean_compare( expr, operators.match_op, self._check_literal( expr, operators.match_op, other), - **kw) + result_type=type_api.MATCHTYPE, + negate=operators.notmatch_op + if op is operators.match_op else operators.match_op, + **kw + ) def _distinct_impl(self, expr, op, **kw): """See :meth:`.ColumnOperators.distinct`.""" @@ -282,6 +292,7 @@ class _DefaultColumnComparator(operators.ColumnOperators): "isnot": (_boolean_compare, operators.isnot), "collate": (_collate_impl,), "match_op": (_match_impl,), + "notmatch_op": (_match_impl,), "distinct_op": (_distinct_impl,), "between_op": (_between_impl, ), "notbetween_op": (_between_impl, ), diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 734f78632..30965c801 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -2763,7 +2763,7 @@ class BinaryExpression(ColumnElement): self.right, self.negate, negate=self.operator, - type_=type_api.BOOLEANTYPE, + type_=self.type, modifiers=self.modifiers) else: return super(BinaryExpression, self)._negate() diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 945356328..b08e44ab8 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -767,6 +767,10 @@ def match_op(a, b, **kw): return a.match(b, **kw) +def notmatch_op(a, b, **kw): + return a.notmatch(b, **kw) + + def comma_op(a, b): raise NotImplementedError() @@ -834,6 +838,7 @@ _PRECEDENCE = { concat_op: 6, match_op: 6, + notmatch_op: 6, ilike_op: 6, notilike_op: 6, diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 7bf2f337c..94db1d837 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -1654,10 +1654,26 @@ class 
NullType(TypeEngine): comparator_factory = Comparator +class MatchType(Boolean): + """Refers to the return type of the MATCH operator. + + As the :meth:`.Operators.match` is probably the most open-ended + operator in generic SQLAlchemy Core, we can't assume the return type + at SQL evaluation time, as MySQL returns a floating point, not a boolean, + and other backends might do something different. So this type + acts as a placeholder, currently subclassing :class:`.Boolean`. + The type allows dialects to inject result-processing functionality + if needed, and on MySQL will return floating-point values. + + .. versionadded:: 1.0.0 + + """ + NULLTYPE = NullType() BOOLEANTYPE = Boolean() STRINGTYPE = String() INTEGERTYPE = Integer() +MATCHTYPE = MatchType() _type_map = { int: Integer(), @@ -1685,6 +1701,7 @@ type_api.BOOLEANTYPE = BOOLEANTYPE type_api.STRINGTYPE = STRINGTYPE type_api.INTEGERTYPE = INTEGERTYPE type_api.NULLTYPE = NULLTYPE +type_api.MATCHTYPE = MATCHTYPE type_api._type_map = _type_map # this one, there's all kinds of ways to play it, but at the EOD diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 77c6e1b1e..d3e0a008e 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -19,7 +19,7 @@ BOOLEANTYPE = None INTEGERTYPE = None NULLTYPE = None STRINGTYPE = None - +MATCHTYPE = None class TypeEngine(Visitable): """The ultimate base class for all SQL datatypes. diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index b49e389ac..1215bd790 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -51,6 +51,7 @@ from .sql.sqltypes import ( Integer, Interval, LargeBinary, + MatchType, NCHAR, NVARCHAR, NullType, -- cgit v1.2.1 From fda589487b2cb60e8d69f520e0120eeb7c875915 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 4 Dec 2014 19:12:52 -0500 Subject: - Updated the "supports_unicode_statements" flag to True for MySQLdb and Pymysql under Python 2. 
This refers to the SQL statements themselves, not the parameters, and affects issues such as table and column names using non-ASCII characters. These drivers both appear to support Python 2 Unicode objects without issue in modern versions. fixes #3121 --- lib/sqlalchemy/dialects/mysql/mysqldb.py | 2 +- lib/sqlalchemy/dialects/mysql/pymysql.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py index 73210d67a..893c6a9e2 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqldb.py +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -77,7 +77,7 @@ class MySQLIdentifierPreparer_mysqldb(MySQLIdentifierPreparer): class MySQLDialect_mysqldb(MySQLDialect): driver = 'mysqldb' - supports_unicode_statements = False + supports_unicode_statements = True supports_sane_rowcount = True supports_sane_multi_rowcount = True diff --git a/lib/sqlalchemy/dialects/mysql/pymysql.py b/lib/sqlalchemy/dialects/mysql/pymysql.py index 31226cea0..8df2ba03f 100644 --- a/lib/sqlalchemy/dialects/mysql/pymysql.py +++ b/lib/sqlalchemy/dialects/mysql/pymysql.py @@ -31,8 +31,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb): driver = 'pymysql' description_encoding = None - if py3k: - supports_unicode_statements = True + supports_unicode_statements = True @classmethod def dbapi(cls): -- cgit v1.2.1 From e46c71b4198ee9811ea851dbe037f19a74af0b08 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 4 Dec 2014 19:35:00 -0500 Subject: - Added support for CTEs under Oracle. This includes some tweaks to the aliasing syntax, as well as a new CTE feature :meth:`.CTE.suffix_with`, which is useful for adding in special Oracle-specific directives to the CTE. 
fixes #3220 --- lib/sqlalchemy/dialects/oracle/base.py | 21 ++---- lib/sqlalchemy/orm/query.py | 30 +++++++- lib/sqlalchemy/sql/compiler.py | 17 ++++- lib/sqlalchemy/sql/selectable.py | 131 ++++++++++++++++++++++----------- 4 files changed, 138 insertions(+), 61 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 6df38e57e..524ba8115 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -549,6 +549,9 @@ class OracleCompiler(compiler.SQLCompiler): def visit_false(self, expr, **kw): return '0' + def get_cte_preamble(self, recursive): + return "WITH" + def get_select_hint_text(self, byfroms): return " ".join( "/*+ %s */" % text for table, text in byfroms.items() @@ -619,22 +622,10 @@ class OracleCompiler(compiler.SQLCompiler): return (self.dialect.identifier_preparer.format_sequence(seq) + ".nextval") - def visit_alias(self, alias, asfrom=False, ashint=False, **kwargs): - """Oracle doesn't like ``FROM table AS alias``. Is the AS standard - SQL?? 
- """ - - if asfrom or ashint: - alias_name = isinstance(alias.name, expression._truncated_label) and \ - self._truncated_identifier("alias", alias.name) or alias.name + def get_render_as_alias_suffix(self, alias_name_text): + """Oracle doesn't like ``FROM table AS alias``""" - if ashint: - return alias_name - elif asfrom: - return self.process(alias.original, asfrom=asfrom, **kwargs) + \ - " " + self.preparer.format_alias(alias, alias_name) - else: - return self.process(alias.original, **kwargs) + return " " + alias_name_text def returning_clause(self, stmt, returning_cols): columns = [] diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 790686288..9b7747e15 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -75,6 +75,7 @@ class Query(object): _having = None _distinct = False _prefixes = None + _suffixes = None _offset = None _limit = None _for_update_arg = None @@ -1003,7 +1004,7 @@ class Query(object): '_limit', '_offset', '_joinpath', '_joinpoint', '_distinct', '_having', - '_prefixes', + '_prefixes', '_suffixes' ): self.__dict__.pop(attr, None) self._set_select_from([fromclause], True) @@ -2359,12 +2360,38 @@ class Query(object): .. versionadded:: 0.7.7 + .. seealso:: + + :meth:`.HasPrefixes.prefix_with` + """ if self._prefixes: self._prefixes += prefixes else: self._prefixes = prefixes + @_generative() + def suffix_with(self, *suffixes): + """Apply the suffix to the query and return the newly resulting + ``Query``. + + :param \*suffixes: optional suffixes, typically strings, + not using any commas. + + .. versionadded:: 1.0.0 + + .. seealso:: + + :meth:`.Query.prefix_with` + + :meth:`.HasSuffixes.suffix_with` + + """ + if self._suffixes: + self._suffixes += suffixes + else: + self._suffixes = suffixes + def all(self): """Return the results represented by this ``Query`` as a list. 
@@ -2601,6 +2628,7 @@ class Query(object): 'offset': self._offset, 'distinct': self._distinct, 'prefixes': self._prefixes, + 'suffixes': self._suffixes, 'group_by': self._group_by or None, 'having': self._having } diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 29a7401a1..9304bba9f 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1193,12 +1193,16 @@ class SQLCompiler(Compiled): self, asfrom=True, **kwargs ) + if cte._suffixes: + text += " " + self._generate_prefixes( + cte, cte._suffixes, **kwargs) + self.ctes[cte] = text if asfrom: if cte_alias_name: text = self.preparer.format_alias(cte, cte_alias_name) - text += " AS " + cte_name + text += self.get_render_as_alias_suffix(cte_name) else: return self.preparer.format_alias(cte, cte_name) return text @@ -1217,8 +1221,8 @@ class SQLCompiler(Compiled): elif asfrom: ret = alias.original._compiler_dispatch(self, asfrom=True, **kwargs) + \ - " AS " + \ - self.preparer.format_alias(alias, alias_name) + self.get_render_as_alias_suffix( + self.preparer.format_alias(alias, alias_name)) if fromhints and alias in fromhints: ret = self.format_from_hint_text(ret, alias, @@ -1228,6 +1232,9 @@ class SQLCompiler(Compiled): else: return alias.original._compiler_dispatch(self, **kwargs) + def get_render_as_alias_suffix(self, alias_name_text): + return " AS " + alias_name_text + def _add_to_result_map(self, keyname, name, objects, type_): if not self.dialect.case_sensitive: keyname = keyname.lower() @@ -1554,6 +1561,10 @@ class SQLCompiler(Compiled): compound_index == 0 and toplevel: text = self._render_cte_clause() + text + if select._suffixes: + text += " " + self._generate_prefixes( + select, select._suffixes, **kwargs) + self.stack.pop(-1) if asfrom and parens: diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 8198a6733..87029ec2b 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -171,6 
+171,79 @@ class Selectable(ClauseElement): return self +class HasPrefixes(object): + _prefixes = () + + @_generative + def prefix_with(self, *expr, **kw): + """Add one or more expressions following the statement keyword, i.e. + SELECT, INSERT, UPDATE, or DELETE. Generative. + + This is used to support backend-specific prefix keywords such as those + provided by MySQL. + + E.g.:: + + stmt = table.insert().prefix_with("LOW_PRIORITY", dialect="mysql") + + Multiple prefixes can be specified by multiple calls + to :meth:`.prefix_with`. + + :param \*expr: textual or :class:`.ClauseElement` construct which + will be rendered following the INSERT, UPDATE, or DELETE + keyword. + :param \**kw: A single keyword 'dialect' is accepted. This is an + optional string dialect name which will + limit rendering of this prefix to only that dialect. + + """ + dialect = kw.pop('dialect', None) + if kw: + raise exc.ArgumentError("Unsupported argument(s): %s" % + ",".join(kw)) + self._setup_prefixes(expr, dialect) + + def _setup_prefixes(self, prefixes, dialect=None): + self._prefixes = self._prefixes + tuple( + [(_literal_as_text(p, warn=False), dialect) for p in prefixes]) + + +class HasSuffixes(object): + _suffixes = () + + @_generative + def suffix_with(self, *expr, **kw): + """Add one or more expressions following the statement as a whole. + + This is used to support backend-specific suffix keywords on + certain constructs. + + E.g.:: + + stmt = select([col1, col2]).cte().suffix_with( + "cycle empno set y_cycle to 1 default 0", dialect="oracle") + + Multiple suffixes can be specified by multiple calls + to :meth:`.suffix_with`. + + :param \*expr: textual or :class:`.ClauseElement` construct which + will be rendered following the target clause. + :param \**kw: A single keyword 'dialect' is accepted. This is an + optional string dialect name which will + limit rendering of this suffix to only that dialect.
+ + """ + dialect = kw.pop('dialect', None) + if kw: + raise exc.ArgumentError("Unsupported argument(s): %s" % + ",".join(kw)) + self._setup_suffixes(expr, dialect) + + def _setup_suffixes(self, suffixes, dialect=None): + self._suffixes = self._suffixes + tuple( + [(_literal_as_text(p, warn=False), dialect) for p in suffixes]) + + class FromClause(Selectable): """Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement. @@ -1088,7 +1161,7 @@ class Alias(FromClause): return self.element.bind -class CTE(Alias): +class CTE(Generative, HasSuffixes, Alias): """Represent a Common Table Expression. The :class:`.CTE` object is obtained using the @@ -1104,10 +1177,13 @@ class CTE(Alias): name=None, recursive=False, _cte_alias=None, - _restates=frozenset()): + _restates=frozenset(), + _suffixes=None): self.recursive = recursive self._cte_alias = _cte_alias self._restates = _restates + if _suffixes: + self._suffixes = _suffixes super(CTE, self).__init__(selectable, name=name) def alias(self, name=None, flat=False): @@ -1116,6 +1192,7 @@ class CTE(Alias): name=name, recursive=self.recursive, _cte_alias=self, + _suffixes=self._suffixes ) def union(self, other): @@ -1123,7 +1200,8 @@ class CTE(Alias): self.original.union(other), name=self.name, recursive=self.recursive, - _restates=self._restates.union([self]) + _restates=self._restates.union([self]), + _suffixes=self._suffixes ) def union_all(self, other): @@ -1131,7 +1209,8 @@ class CTE(Alias): self.original.union_all(other), name=self.name, recursive=self.recursive, - _restates=self._restates.union([self]) + _restates=self._restates.union([self]), + _suffixes=self._suffixes ) @@ -2118,44 +2197,7 @@ class CompoundSelect(GenerativeSelect): bind = property(bind, _set_bind) -class HasPrefixes(object): - _prefixes = () - - @_generative - def prefix_with(self, *expr, **kw): - """Add one or more expressions following the statement keyword, i.e. - SELECT, INSERT, UPDATE, or DELETE. Generative. 
- - This is used to support backend-specific prefix keywords such as those - provided by MySQL. - - E.g.:: - - stmt = table.insert().prefix_with("LOW_PRIORITY", dialect="mysql") - - Multiple prefixes can be specified by multiple calls - to :meth:`.prefix_with`. - - :param \*expr: textual or :class:`.ClauseElement` construct which - will be rendered following the INSERT, UPDATE, or DELETE - keyword. - :param \**kw: A single keyword 'dialect' is accepted. This is an - optional string dialect name which will - limit rendering of this prefix to only that dialect. - - """ - dialect = kw.pop('dialect', None) - if kw: - raise exc.ArgumentError("Unsupported argument(s): %s" % - ",".join(kw)) - self._setup_prefixes(expr, dialect) - - def _setup_prefixes(self, prefixes, dialect=None): - self._prefixes = self._prefixes + tuple( - [(_literal_as_text(p, warn=False), dialect) for p in prefixes]) - - -class Select(HasPrefixes, GenerativeSelect): +class Select(HasPrefixes, HasSuffixes, GenerativeSelect): """Represents a ``SELECT`` statement. """ @@ -2163,6 +2205,7 @@ class Select(HasPrefixes, GenerativeSelect): __visit_name__ = 'select' _prefixes = () + _suffixes = () _hints = util.immutabledict() _statement_hints = () _distinct = False @@ -2180,6 +2223,7 @@ class Select(HasPrefixes, GenerativeSelect): having=None, correlate=True, prefixes=None, + suffixes=None, **kwargs): """Construct a new :class:`.Select`. @@ -2425,6 +2469,9 @@ class Select(HasPrefixes, GenerativeSelect): if prefixes: self._setup_prefixes(prefixes) + if suffixes: + self._setup_suffixes(suffixes) + GenerativeSelect.__init__(self, **kwargs) @property -- cgit v1.2.1 From edef95379777a9c84ee7dbcbc9a3b58849aa8930 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 4 Dec 2014 20:08:07 -0500 Subject: - New Oracle DDL features for tables, indexes: COMPRESS, BITMAP. Patch courtesy Gabor Gombas. 
fixes #3127 --- lib/sqlalchemy/dialects/oracle/base.py | 165 +++++++++++++++++++++++++++++++-- lib/sqlalchemy/engine/reflection.py | 10 +- 2 files changed, 165 insertions(+), 10 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 524ba8115..9f375da94 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -213,6 +213,8 @@ is reflected and the type is reported as ``DATE``, the time-supporting examining the type of column for use in special Python translations or for migrating schemas to other database backends. +.. _oracle_table_options: + Oracle Table Options ------------------------- @@ -228,15 +230,63 @@ in conjunction with the :class:`.Table` construct: .. versionadded:: 1.0.0 +* ``COMPRESS``:: + + Table('mytable', metadata, Column('data', String(32)), + oracle_compress=True) + + Table('mytable', metadata, Column('data', String(32)), + oracle_compress=6) + + The ``oracle_compress`` parameter accepts either an integer compression + level, or ``True`` to use the default compression level. + +.. versionadded:: 1.0.0 + +.. _oracle_index_options: + +Oracle Specific Index Options +----------------------------- + +Bitmap Indexes +~~~~~~~~~~~~~~ + +You can specify the ``oracle_bitmap`` parameter to create a bitmap index +instead of a B-tree index:: + + Index('my_index', my_table.c.data, oracle_bitmap=True) + +Bitmap indexes cannot be unique and cannot be compressed. SQLAlchemy will not +check for such limitations, only the database will. + +.. versionadded:: 1.0.0 + +Index compression +~~~~~~~~~~~~~~~~~ + +Oracle has a more efficient storage mode for indexes containing lots of +repeated values. 
Use the ``oracle_compress`` parameter to turn on key +compression:: + + Index('my_index', my_table.c.data, oracle_compress=True) + + Index('my_index', my_table.c.data1, my_table.c.data2, unique=True, + oracle_compress=1) + +The ``oracle_compress`` parameter accepts either an integer specifying the +number of prefix columns to compress, or ``True`` to use the default (all +columns for non-unique indexes, all but the last column for unique indexes). + +.. versionadded:: 1.0.0 + """ import re from sqlalchemy import util, sql -from sqlalchemy.engine import default, base, reflection +from sqlalchemy.engine import default, reflection from sqlalchemy.sql import compiler, visitors, expression -from sqlalchemy.sql import (operators as sql_operators, - functions as sql_functions) +from sqlalchemy.sql import operators as sql_operators from sqlalchemy import types as sqltypes, schema as sa_schema from sqlalchemy.types import VARCHAR, NVARCHAR, CHAR, \ BLOB, CLOB, TIMESTAMP, FLOAT @@ -786,9 +836,32 @@ class OracleDDLCompiler(compiler.DDLCompiler): return text - def visit_create_index(self, create, **kw): - return super(OracleDDLCompiler, self).\ - visit_create_index(create, include_schema=True) + def visit_create_index(self, create): + index = create.element + self._verify_index_table(index) + preparer = self.preparer + text = "CREATE " + if index.unique: + text += "UNIQUE " + if index.dialect_options['oracle']['bitmap']: + text += "BITMAP " + text += "INDEX %s ON %s (%s)" % ( + self._prepared_index_name(index, include_schema=True), + preparer.format_table(index.table, use_schema=True), + ', '.join( + self.sql_compiler.process( + expr, + include_table=False, literal_binds=True) + for expr in index.expressions) + ) + if index.dialect_options['oracle']['compress'] is not False: + if index.dialect_options['oracle']['compress'] is True: + text += " COMPRESS" + else: + text += " COMPRESS %d" % ( + index.dialect_options['oracle']['compress'] + ) + return text def
post_create_table(self, table): table_opts = [] @@ -798,6 +871,14 @@ class OracleDDLCompiler(compiler.DDLCompiler): on_commit_options = opts['on_commit'].replace("_", " ").upper() table_opts.append('\n ON COMMIT %s' % on_commit_options) + if opts['compress']: + if opts['compress'] is True: + table_opts.append("\n COMPRESS") + else: + table_opts.append("\n COMPRESS FOR %s" % ( + opts['compress'] + )) + return ''.join(table_opts) @@ -861,7 +942,12 @@ class OracleDialect(default.DefaultDialect): construct_arguments = [ (sa_schema.Table, { "resolve_synonyms": False, - "on_commit": None + "on_commit": None, + "compress": False + }), + (sa_schema.Index, { + "bitmap": False, + "compress": False }) ] @@ -892,6 +978,16 @@ class OracleDialect(default.DefaultDialect): return self.server_version_info and \ self.server_version_info < (9, ) + @property + def _supports_table_compression(self): + return self.server_version_info and \ + self.server_version_info >= (9, 2, ) + + @property + def _supports_table_compress_for(self): + return self.server_version_info and \ + self.server_version_info >= (11, ) + @property def _supports_char_length(self): return not self._is_oracle_8 @@ -1074,6 +1170,50 @@ class OracleDialect(default.DefaultDialect): cursor = connection.execute(s, owner=self.denormalize_name(schema)) return [self.normalize_name(row[0]) for row in cursor] + @reflection.cache + def get_table_options(self, connection, table_name, schema=None, **kw): + options = {} + + resolve_synonyms = kw.get('oracle_resolve_synonyms', False) + dblink = kw.get('dblink', '') + info_cache = kw.get('info_cache') + + (table_name, schema, dblink, synonym) = \ + self._prepare_reflection_args(connection, table_name, schema, + resolve_synonyms, dblink, + info_cache=info_cache) + + params = {"table_name": table_name} + + columns = ["table_name"] + if self._supports_table_compression: + columns.append("compression") + if self._supports_table_compress_for: + columns.append("compress_for") + + text = 
"SELECT %(columns)s "\ + "FROM ALL_TABLES%(dblink)s "\ + "WHERE table_name = :table_name" + + if schema is not None: + params['owner'] = schema + text += " AND owner = :owner " + text = text % {'dblink': dblink, 'columns': ", ".join(columns)} + + result = connection.execute(sql.text(text), **params) + + enabled = dict(DISABLED=False, ENABLED=True) + + row = result.first() + if row: + if "compression" in row and enabled.get(row.compression, False): + if "compress_for" in row: + options['oracle_compress'] = row.compress_for + else: + options['oracle_compress'] = True + + return options + @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): """ @@ -1159,7 +1299,8 @@ class OracleDialect(default.DefaultDialect): params = {'table_name': table_name} text = \ - "SELECT a.index_name, a.column_name, b.uniqueness "\ + "SELECT a.index_name, a.column_name, "\ + "\nb.index_type, b.uniqueness, b.compression, b.prefix_length "\ "\nFROM ALL_IND_COLUMNS%(dblink)s a, "\ "\nALL_INDEXES%(dblink)s b "\ "\nWHERE "\ @@ -1185,6 +1326,7 @@ class OracleDialect(default.DefaultDialect): dblink=dblink, info_cache=kw.get('info_cache')) pkeys = pk_constraint['constrained_columns'] uniqueness = dict(NONUNIQUE=False, UNIQUE=True) + enabled = dict(DISABLED=False, ENABLED=True) oracle_sys_col = re.compile(r'SYS_NC\d+\$', re.IGNORECASE) @@ -1204,10 +1346,15 @@ class OracleDialect(default.DefaultDialect): if rset.index_name != last_index_name: remove_if_primary_key(index) index = dict(name=self.normalize_name(rset.index_name), - column_names=[]) + column_names=[], dialect_options={}) indexes.append(index) index['unique'] = uniqueness.get(rset.uniqueness, False) + if rset.index_type in ('BITMAP', 'FUNCTION-BASED BITMAP'): + index['dialect_options']['oracle_bitmap'] = True + if enabled.get(rset.compression, False): + index['dialect_options']['oracle_compress'] = rset.prefix_length + # filter out Oracle SYS_NC names. 
could also do an outer join # to the all_tab_columns table and check for real col names there. if not oracle_sys_col.match(rset.column_name): diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index 2a1def86a..ebc96f5dd 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -394,6 +394,9 @@ class Inspector(object): unique boolean + dialect_options + dict of dialect-specific index options + :param table_name: string name of the table. For special quoting, use :class:`.quoted_name`. @@ -642,6 +645,8 @@ class Inspector(object): columns = index_d['column_names'] unique = index_d['unique'] flavor = index_d.get('type', 'index') + dialect_options = index_d.get('dialect_options', {}) + duplicates = index_d.get('duplicates_constraint') if include_columns and \ not set(columns).issubset(include_columns): @@ -667,7 +672,10 @@ class Inspector(object): else: idx_cols.append(idx_col) - sa_schema.Index(name, *idx_cols, **dict(unique=unique)) + sa_schema.Index( + name, *idx_cols, + **dict(list(dialect_options.items()) + [('unique', unique)]) + ) def _reflect_unique_constraints( self, table_name, schema, table, cols_by_orig_name, -- cgit v1.2.1 From 41e7253dee168b8c26c4993d27aac11f98c7f9e3 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 5 Dec 2014 12:12:44 -0500 Subject: - The engine-level error handling and wrapping routines will now take effect in all engine connection use cases, including when user-custom connect routines are used via the :paramref:`.create_engine.creator` parameter, as well as when the :class:`.Connection` encounters a connection error on revalidation. 
fixes #3266 --- lib/sqlalchemy/engine/base.py | 74 +++++++++++++++++++++++++++++++++--- lib/sqlalchemy/engine/interfaces.py | 18 ++++++++- lib/sqlalchemy/engine/strategies.py | 11 +----- lib/sqlalchemy/engine/threadlocal.py | 2 +- lib/sqlalchemy/events.py | 6 +++ 5 files changed, 93 insertions(+), 18 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index dd82be1d1..901ab07eb 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -276,7 +276,7 @@ class Connection(Connectable): raise exc.InvalidRequestError( "Can't reconnect until invalid " "transaction is rolled back") - self.__connection = self.engine.raw_connection() + self.__connection = self.engine.raw_connection(self) self.__invalid = False return self.__connection raise exc.ResourceClosedError("This Connection is closed") @@ -1194,7 +1194,8 @@ class Connection(Connectable): # new handle_error event ctx = ExceptionContextImpl( - e, sqlalchemy_exception, self, cursor, statement, + e, sqlalchemy_exception, self.engine, + self, cursor, statement, parameters, context, self._is_disconnect) for fn in self.dispatch.handle_error: @@ -1242,6 +1243,58 @@ class Connection(Connectable): if self.should_close_with_result: self.close() + @classmethod + def _handle_dbapi_exception_noconnection( + cls, e, dialect, engine, connection): + exc_info = sys.exc_info() + + is_disconnect = dialect.is_disconnect(e, None, None) + + should_wrap = isinstance(e, dialect.dbapi.Error) + + if should_wrap: + sqlalchemy_exception = exc.DBAPIError.instance( + None, + None, + e, + dialect.dbapi.Error, + connection_invalidated=is_disconnect) + else: + sqlalchemy_exception = None + + newraise = None + + if engine._has_events: + ctx = ExceptionContextImpl( + e, sqlalchemy_exception, engine, connection, None, None, + None, None, is_disconnect) + for fn in engine.dispatch.handle_error: + try: + # handler returns an exception; + # call next handler in a chain 
+ per_fn = fn(ctx) + if per_fn is not None: + ctx.chained_exception = newraise = per_fn + except Exception as _raised: + # handler raises an exception - stop processing + newraise = _raised + break + + if sqlalchemy_exception and \ + is_disconnect != ctx.is_disconnect: + sqlalchemy_exception.connection_invalidated = \ + is_disconnect = ctx.is_disconnect + + if newraise: + util.raise_from_cause(newraise, exc_info) + elif should_wrap: + util.raise_from_cause( + sqlalchemy_exception, + exc_info + ) + else: + util.reraise(*exc_info) + def default_schema_name(self): return self.engine.dialect.get_default_schema_name(self) @@ -1320,8 +1373,9 @@ class ExceptionContextImpl(ExceptionContext): """Implement the :class:`.ExceptionContext` interface.""" def __init__(self, exception, sqlalchemy_exception, - connection, cursor, statement, parameters, + engine, connection, cursor, statement, parameters, context, is_disconnect): + self.engine = engine self.connection = connection self.sqlalchemy_exception = sqlalchemy_exception self.original_exception = exception @@ -1898,7 +1952,15 @@ class Engine(Connectable, log.Identified): """ return self.run_callable(self.dialect.has_table, table_name, schema) - def raw_connection(self): + def _wrap_pool_connect(self, fn, connection=None): + dialect = self.dialect + try: + return fn() + except dialect.dbapi.Error as e: + Connection._handle_dbapi_exception_noconnection( + e, dialect, self, connection) + + def raw_connection(self, _connection=None): """Return a "raw" DBAPI connection from the connection pool. The returned object is a proxied version of the DBAPI @@ -1914,8 +1976,8 @@ class Engine(Connectable, log.Identified): :meth:`.Engine.connect` method. 
""" - - return self.pool.unique_connection() + return self._wrap_pool_connect( + self.pool.unique_connection, _connection) class OptionEngine(Engine): diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 0ad2efae0..5f66e54b5 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -917,7 +917,23 @@ class ExceptionContext(object): connection = None """The :class:`.Connection` in use during the exception. - This member is always present. + This member is present, except in the case of a failure when + first connecting. + + .. seealso:: + + :attr:`.ExceptionContext.engine` + + + """ + + engine = None + """The :class:`.Engine` in use during the exception. + + This member should always be present, even in the case of a failure + when first connecting. + + .. versionadded:: 1.0.0 """ diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py index 398ef8df6..fd665ad03 100644 --- a/lib/sqlalchemy/engine/strategies.py +++ b/lib/sqlalchemy/engine/strategies.py @@ -86,16 +86,7 @@ class DefaultEngineStrategy(EngineStrategy): pool = pop_kwarg('pool', None) if pool is None: def connect(): - try: - return dialect.connect(*cargs, **cparams) - except dialect.dbapi.Error as e: - invalidated = dialect.is_disconnect(e, None, None) - util.raise_from_cause( - exc.DBAPIError.instance( - None, None, e, dialect.dbapi.Error, - connection_invalidated=invalidated - ) - ) + return dialect.connect(*cargs, **cparams) creator = pop_kwarg('creator', connect) diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py index 637523a0e..71caac626 100644 --- a/lib/sqlalchemy/engine/threadlocal.py +++ b/lib/sqlalchemy/engine/threadlocal.py @@ -59,7 +59,7 @@ class TLEngine(base.Engine): # guards against pool-level reapers, if desired. 
# or not connection.connection.is_valid: connection = self._tl_connection_cls( - self, self.pool.connect(), **kw) + self, self._wrap_pool_connect(self.pool.connect), **kw) self._connections.conn = weakref.ref(connection) return connection._increment_connect() diff --git a/lib/sqlalchemy/events.py b/lib/sqlalchemy/events.py index c144902cd..8600c20f5 100644 --- a/lib/sqlalchemy/events.py +++ b/lib/sqlalchemy/events.py @@ -739,6 +739,12 @@ class ConnectionEvents(event.Events): .. versionadded:: 0.9.7 Added the :meth:`.ConnectionEvents.handle_error` hook. + .. versionchanged:: 1.0.0 The :meth:`.handle_error` event is now + invoked when an :class:`.Engine` fails during the initial + call to :meth:`.Engine.connect`, as well as when a + :class:`.Connection` object encounters an error during a + reconnect operation. + .. versionchanged:: 1.0.0 The :meth:`.handle_error` event is not fired off when a dialect makes use of the ``skip_user_error_events`` execution option. This is used -- cgit v1.2.1 From d204e61f63756f2bbd3322377a283fc995e562ec Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 5 Dec 2014 12:18:11 -0500 Subject: - document / work around that dialect_options isn't necessarily there --- lib/sqlalchemy/engine/reflection.py | 5 ++++- lib/sqlalchemy/testing/suite/test_reflection.py | 2 ++ 2 files changed, 6 insertions(+), 1 deletion(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index ebc96f5dd..25f084c15 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -395,7 +395,10 @@ class Inspector(object): boolean dialect_options - dict of dialect-specific index options + dict of dialect-specific index options. May not be present + for all dialects. + + .. versionadded:: 1.0.0 :param table_name: string name of the table. For special quoting, use :class:`.quoted_name`. 
diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index e58b6f068..3edbdeb8c 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -515,6 +515,8 @@ class ComponentReflectionTest(fixtures.TablesTest): def test_get_temp_table_indexes(self): insp = inspect(self.metadata.bind) indexes = insp.get_indexes('user_tmp') + for ind in indexes: + ind.pop('dialect_options', None) eq_( # TODO: we need to add better filtering for indexes/uq constraints # that are doubled up -- cgit v1.2.1 From 0ce045bd853ec078943c14fc93b87897d2169882 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 5 Dec 2014 14:46:43 -0500 Subject: - The SQLite dialect, when using the :class:`.sqlite.DATE`, :class:`.sqlite.TIME`, or :class:`.sqlite.DATETIME` types, and given a ``storage_format`` that only renders numbers, will render the types in DDL as ``DATE_CHAR``, ``TIME_CHAR``, and ``DATETIME_CHAR``, so that despite the lack of alpha characters in the values, the column will still deliver the "text affinity". Normally this is not needed, as the textual values within the default storage formats already imply text. fixes #3257 --- lib/sqlalchemy/dialects/sqlite/base.py | 60 +++++++++++++++++++++++++++++++++- 1 file changed, 59 insertions(+), 1 deletion(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 33003297c..ccd7f2539 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -9,6 +9,7 @@ .. dialect:: sqlite :name: SQLite +.. _sqlite_datetime: Date and Time Types ------------------- @@ -23,6 +24,20 @@ These types represent dates and times as ISO formatted strings, which also nicely support ordering. There's no reliance on typical "libc" internals for these functions so historical dates are fully supported. 
+Ensuring Text affinity +^^^^^^^^^^^^^^^^^^^^^^ + +The DDL rendered for these types is the standard ``DATE``, ``TIME`` +and ``DATETIME`` indicators. However, custom storage formats can also be +applied to these types. When the +storage format is detected as containing no alpha characters, the DDL for +these types is rendered as ``DATE_CHAR``, ``TIME_CHAR``, and ``DATETIME_CHAR``, +so that the column continues to have textual affinity. + +.. seealso:: + + `Type Affinity <http://www.sqlite.org/datatype3.html>`_ - in the SQLite documentation + .. _sqlite_autoincrement: SQLite Auto Incrementing Behavior @@ -255,7 +270,7 @@ from ... import util from ...engine import default, reflection from ...sql import compiler -from ...types import (BLOB, BOOLEAN, CHAR, DATE, DECIMAL, FLOAT, +from ...types import (BLOB, BOOLEAN, CHAR, DECIMAL, FLOAT, INTEGER, REAL, NUMERIC, SMALLINT, TEXT, TIMESTAMP, VARCHAR) @@ -271,6 +286,25 @@ class _DateTimeMixin(object): if storage_format is not None: self._storage_format = storage_format + @property + def format_is_text_affinity(self): + """return True if the storage format will automatically imply + a TEXT affinity. + + If the storage format contains no non-numeric characters, + it will imply a NUMERIC storage format on SQLite; in this case, + the type will generate its DDL as DATE_CHAR, DATETIME_CHAR, + TIME_CHAR. + + ..
versionadded:: 1.0.0 + + """ + spec = self._storage_format % { + "year": 0, "month": 0, "day": 0, "hour": 0, + "minute": 0, "second": 0, "microsecond": 0 + } + return bool(re.search(r'[^0-9]', spec)) + def adapt(self, cls, **kw): if issubclass(cls, _DateTimeMixin): if self._storage_format: @@ -526,7 +560,9 @@ ischema_names = { 'BOOLEAN': sqltypes.BOOLEAN, 'CHAR': sqltypes.CHAR, 'DATE': sqltypes.DATE, + 'DATE_CHAR': sqltypes.DATE, 'DATETIME': sqltypes.DATETIME, + 'DATETIME_CHAR': sqltypes.DATETIME, 'DOUBLE': sqltypes.FLOAT, 'DECIMAL': sqltypes.DECIMAL, 'FLOAT': sqltypes.FLOAT, @@ -537,6 +573,7 @@ ischema_names = { 'SMALLINT': sqltypes.SMALLINT, 'TEXT': sqltypes.TEXT, 'TIME': sqltypes.TIME, + 'TIME_CHAR': sqltypes.TIME, 'TIMESTAMP': sqltypes.TIMESTAMP, 'VARCHAR': sqltypes.VARCHAR, 'NVARCHAR': sqltypes.NVARCHAR, @@ -670,6 +707,27 @@ class SQLiteTypeCompiler(compiler.GenericTypeCompiler): def visit_large_binary(self, type_): return self.visit_BLOB(type_) + def visit_DATETIME(self, type_): + if not isinstance(type_, _DateTimeMixin) or \ + type_.format_is_text_affinity: + return super(SQLiteTypeCompiler, self).visit_DATETIME(type_) + else: + return "DATETIME_CHAR" + + def visit_DATE(self, type_): + if not isinstance(type_, _DateTimeMixin) or \ + type_.format_is_text_affinity: + return super(SQLiteTypeCompiler, self).visit_DATE(type_) + else: + return "DATE_CHAR" + + def visit_TIME(self, type_): + if not isinstance(type_, _DateTimeMixin) or \ + type_.format_is_text_affinity: + return super(SQLiteTypeCompiler, self).visit_TIME(type_) + else: + return "TIME_CHAR" + class SQLiteIdentifierPreparer(compiler.IdentifierPreparer): reserved_words = set([ -- cgit v1.2.1 From 0639c199a547343d62134d2f233225fd2862ec45 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 5 Dec 2014 16:34:43 -0500 Subject: - move inner calls to _revalidate_connection() outside of existing _handle_dbapi_error(); these are now handled already and the reentrant call is not needed / breaks things. 
Adjustment to 41e7253dee168b8c26c49 / --- lib/sqlalchemy/engine/base.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 901ab07eb..235e1bf43 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -814,11 +814,11 @@ class Connection(Connectable): fn(self, default, multiparams, params) try: - try: - conn = self.__connection - except AttributeError: - conn = self._revalidate_connection() + conn = self.__connection + except AttributeError: + conn = self._revalidate_connection() + try: dialect = self.dialect ctx = dialect.execution_ctx_cls._init_default( dialect, self, conn) @@ -952,11 +952,11 @@ class Connection(Connectable): a :class:`.ResultProxy`.""" try: - try: - conn = self.__connection - except AttributeError: - conn = self._revalidate_connection() + conn = self.__connection + except AttributeError: + conn = self._revalidate_connection() + try: context = constructor(dialect, self, conn, *args) except Exception as e: self._handle_dbapi_exception(e, @@ -1246,6 +1246,7 @@ class Connection(Connectable): @classmethod def _handle_dbapi_exception_noconnection( cls, e, dialect, engine, connection): + exc_info = sys.exc_info() is_disconnect = dialect.is_disconnect(e, None, None) -- cgit v1.2.1 From b8114a357684ab3232ff90ceb0da16dad080d1ac Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 5 Dec 2014 19:08:47 -0500 Subject: - adjust _revalidate_connection() again such that we pass a _wrap=False to it, so that we say we will do the wrapping just once right here in _execute_context() / _execute_default(). 
An adjustment is made to _handle_dbapi_error() to not assume self.__connection in case we are already in an invalidated state further adjustment to 0639c199a547343d62134d2f233225fd2862ec45, 41e7253dee168b8c26c49, #3266 --- lib/sqlalchemy/engine/base.py | 46 ++++++++++++++++++++---------------- lib/sqlalchemy/engine/threadlocal.py | 5 +++- 2 files changed, 30 insertions(+), 21 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 235e1bf43..23348469d 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -265,18 +265,18 @@ class Connection(Connectable): try: return self.__connection except AttributeError: - return self._revalidate_connection() + return self._revalidate_connection(_wrap=True) - def _revalidate_connection(self): + def _revalidate_connection(self, _wrap): if self.__branch_from: - return self.__branch_from._revalidate_connection() - + return self.__branch_from._revalidate_connection(_wrap=_wrap) if self.__can_reconnect and self.__invalid: if self.__transaction is not None: raise exc.InvalidRequestError( "Can't reconnect until invalid " "transaction is rolled back") - self.__connection = self.engine.raw_connection(self) + self.__connection = self.engine.raw_connection( + _connection=self, _wrap=_wrap) self.__invalid = False return self.__connection raise exc.ResourceClosedError("This Connection is closed") @@ -814,11 +814,11 @@ class Connection(Connectable): fn(self, default, multiparams, params) try: - conn = self.__connection - except AttributeError: - conn = self._revalidate_connection() + try: + conn = self.__connection + except AttributeError: + conn = self._revalidate_connection(_wrap=False) - try: dialect = self.dialect ctx = dialect.execution_ctx_cls._init_default( dialect, self, conn) @@ -952,16 +952,17 @@ class Connection(Connectable): a :class:`.ResultProxy`.""" try: - conn = self.__connection - except AttributeError: - conn = 
self._revalidate_connection() + try: + conn = self.__connection + except AttributeError: + conn = self._revalidate_connection(_wrap=False) - try: context = constructor(dialect, self, conn, *args) except Exception as e: - self._handle_dbapi_exception(e, - util.text_type(statement), parameters, - None, None) + self._handle_dbapi_exception( + e, + util.text_type(statement), parameters, + None, None) if context.compiled: context.pre_exec() @@ -1149,7 +1150,10 @@ class Connection(Connectable): self._is_disconnect = \ isinstance(e, self.dialect.dbapi.Error) and \ not self.closed and \ - self.dialect.is_disconnect(e, self.__connection, cursor) + self.dialect.is_disconnect( + e, + self.__connection if not self.invalidated else None, + cursor) if context: context.is_disconnect = self._is_disconnect @@ -1953,7 +1957,9 @@ class Engine(Connectable, log.Identified): """ return self.run_callable(self.dialect.has_table, table_name, schema) - def _wrap_pool_connect(self, fn, connection=None): + def _wrap_pool_connect(self, fn, connection, wrap=True): + if not wrap: + return fn() dialect = self.dialect try: return fn() @@ -1961,7 +1967,7 @@ class Engine(Connectable, log.Identified): Connection._handle_dbapi_exception_noconnection( e, dialect, self, connection) - def raw_connection(self, _connection=None): + def raw_connection(self, _connection=None, _wrap=True): """Return a "raw" DBAPI connection from the connection pool. 
The returned object is a proxied version of the DBAPI @@ -1978,7 +1984,7 @@ class Engine(Connectable, log.Identified): """ return self._wrap_pool_connect( - self.pool.unique_connection, _connection) + self.pool.unique_connection, _connection, _wrap) class OptionEngine(Engine): diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py index 71caac626..824b68fdf 100644 --- a/lib/sqlalchemy/engine/threadlocal.py +++ b/lib/sqlalchemy/engine/threadlocal.py @@ -59,7 +59,10 @@ class TLEngine(base.Engine): # guards against pool-level reapers, if desired. # or not connection.connection.is_valid: connection = self._tl_connection_cls( - self, self._wrap_pool_connect(self.pool.connect), **kw) + self, + self._wrap_pool_connect( + self.pool.connect, connection, wrap=True), + **kw) self._connections.conn = weakref.ref(connection) return connection._increment_connect() -- cgit v1.2.1 From c24423bc2e3fd227bf4a86599e28407bd190ee9e Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 6 Dec 2014 13:29:32 -0500 Subject: - enhance only_on() to work with compound specs - fix "temporary_tables" requirement --- lib/sqlalchemy/testing/exclusions.py | 2 +- lib/sqlalchemy/testing/requirements.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py index f94724608..0aff43ae1 100644 --- a/lib/sqlalchemy/testing/exclusions.py +++ b/lib/sqlalchemy/testing/exclusions.py @@ -425,7 +425,7 @@ def skip(db, reason=None): def only_on(dbs, reason=None): return only_if( - OrPredicate([SpecPredicate(db) for db in util.to_list(dbs)]) + OrPredicate([Predicate.as_predicate(db) for db in util.to_list(dbs)]) ) diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index da3e3128a..5744431cb 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -322,6 +322,11 @@ class 
SuiteRequirements(Requirements): """target dialect supports listing of temporary table names""" return exclusions.closed() + @property + def temporary_tables(self): + """target database supports temporary tables""" + return exclusions.open() + @property def temporary_views(self): """target database supports temporary views""" -- cgit v1.2.1 From c8817e608788799837a91b1d2616227594698d2b Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 6 Dec 2014 13:30:51 -0500 Subject: - SQL Server 2012 now recommends VARCHAR(max), NVARCHAR(max), VARBINARY(max) for large text/binary types. The MSSQL dialect will now respect this based on version detection, as well as the new ``deprecate_large_types`` flag. fixes #3039 --- lib/sqlalchemy/dialects/mssql/base.py | 105 +++++++++++++++++++++++++++++++--- lib/sqlalchemy/sql/sqltypes.py | 2 +- 2 files changed, 97 insertions(+), 10 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index dad02ee0f..5d84975c0 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -226,6 +226,53 @@ The DATE and TIME types are not available for MSSQL 2005 and previous - if a server version below 2008 is detected, DDL for these types will be issued as DATETIME. +.. _mssql_large_type_deprecation: + +Large Text/Binary Type Deprecation +---------------------------------- + +Per `SQL Server 2012/2014 Documentation <http://technet.microsoft.com/en-us/library/ms187993.aspx>`_, +the ``NTEXT``, ``TEXT`` and ``IMAGE`` datatypes are to be removed from SQL Server +in a future release. SQLAlchemy normally relates these types to the +:class:`.UnicodeText`, :class:`.Text` and :class:`.LargeBinary` datatypes. + +In order to accommodate this change, a new flag ``deprecate_large_types`` +is added to the dialect, which will be automatically set based on detection +of the server version in use, if not otherwise set by the user.
The +behavior of this flag is as follows: + +* When this flag is ``True``, the :class:`.UnicodeText`, :class:`.Text` and + :class:`.LargeBinary` datatypes, when used to render DDL, will render the + types ``NVARCHAR(max)``, ``VARCHAR(max)``, and ``VARBINARY(max)``, + respectively. This is a new behavior as of the addition of this flag. + +* When this flag is ``False``, the :class:`.UnicodeText`, :class:`.Text` and + :class:`.LargeBinary` datatypes, when used to render DDL, will render the + types ``NTEXT``, ``TEXT``, and ``IMAGE``, + respectively. This is the long-standing behavior of these types. + +* The flag begins with the value ``None``, before a database connection is + established. If the dialect is used to render DDL without the flag being + set, it is interpreted the same as ``False``. + +* On first connection, the dialect detects if SQL Server version 2012 or greater + is in use; if the flag is still at ``None``, it sets it to ``True`` or + ``False`` based on whether 2012 or greater is detected. + +* The flag can be set to either ``True`` or ``False`` when the dialect + is created, typically via :func:`.create_engine`:: + + eng = create_engine("mssql+pymssql://user:pass@host/db", + deprecate_large_types=True) + +* Complete control over whether the "old" or "new" types are rendered is + available in all SQLAlchemy versions by using the UPPERCASE type objects + instead: :class:`.NVARCHAR`, :class:`.VARCHAR`, :class:`.types.VARBINARY`, + :class:`.TEXT`, :class:`.mssql.NTEXT`, :class:`.mssql.IMAGE` will always remain + fixed and always output exactly that type. + +.. versionadded:: 1.0.0 + .. _mssql_indexes: Clustered Index Support @@ -367,19 +414,20 @@ import operator import re from ... import sql, schema as sa_schema, exc, util -from ...sql import compiler, expression, \ - util as sql_util, cast +from ...sql import compiler, expression, util as sql_util from ... import engine from ...engine import reflection, default from ... 
import types as sqltypes from ...types import INTEGER, BIGINT, SMALLINT, DECIMAL, NUMERIC, \ FLOAT, TIMESTAMP, DATETIME, DATE, BINARY,\ - VARBINARY, TEXT, VARCHAR, NVARCHAR, CHAR, NCHAR + TEXT, VARCHAR, NVARCHAR, CHAR, NCHAR from ...util import update_wrapper from . import information_schema as ischema +# http://sqlserverbuilds.blogspot.com/ +MS_2012_VERSION = (11,) MS_2008_VERSION = (10,) MS_2005_VERSION = (9,) MS_2000_VERSION = (8,) @@ -545,6 +593,26 @@ class NTEXT(sqltypes.UnicodeText): __visit_name__ = 'NTEXT' +class VARBINARY(sqltypes.VARBINARY, sqltypes.LargeBinary): + """The MSSQL VARBINARY type. + + This type extends both :class:`.types.VARBINARY` and + :class:`.types.LargeBinary`. In "deprecate_large_types" mode, + the :class:`.types.LargeBinary` type will produce ``VARBINARY(max)`` + on SQL Server. + + .. versionadded:: 1.0.0 + + .. seealso:: + + :ref:`mssql_large_type_deprecation` + + + + """ + __visit_name__ = 'VARBINARY' + + class IMAGE(sqltypes.LargeBinary): __visit_name__ = 'IMAGE' @@ -683,8 +751,17 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): def visit_unicode(self, type_): return self.visit_NVARCHAR(type_) + def visit_text(self, type_): + if self.dialect.deprecate_large_types: + return self.visit_VARCHAR(type_) + else: + return self.visit_TEXT(type_) + def visit_unicode_text(self, type_): - return self.visit_NTEXT(type_) + if self.dialect.deprecate_large_types: + return self.visit_NVARCHAR(type_) + else: + return self.visit_NTEXT(type_) def visit_NTEXT(self, type_): return self._extend("NTEXT", type_) @@ -717,7 +794,10 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): return self.visit_TIME(type_) def visit_large_binary(self, type_): - return self.visit_IMAGE(type_) + if self.dialect.deprecate_large_types: + return self.visit_VARBINARY(type_) + else: + return self.visit_IMAGE(type_) def visit_IMAGE(self, type_): return "IMAGE" @@ -1370,13 +1450,15 @@ class MSDialect(default.DefaultDialect): query_timeout=None, 
use_scope_identity=True, max_identifier_length=None, - schema_name="dbo", **opts): + schema_name="dbo", + deprecate_large_types=None, **opts): self.query_timeout = int(query_timeout or 0) self.schema_name = schema_name self.use_scope_identity = use_scope_identity self.max_identifier_length = int(max_identifier_length or 0) or \ self.max_identifier_length + self.deprecate_large_types = deprecate_large_types super(MSDialect, self).__init__(**opts) def do_savepoint(self, connection, name): @@ -1390,6 +1472,9 @@ class MSDialect(default.DefaultDialect): def initialize(self, connection): super(MSDialect, self).initialize(connection) + self._setup_version_attributes() + + def _setup_version_attributes(self): if self.server_version_info[0] not in list(range(8, 17)): # FreeTDS with version 4.2 seems to report here # a number like "95.10.255". Don't know what @@ -1405,6 +1490,9 @@ class MSDialect(default.DefaultDialect): self.implicit_returning = True if self.server_version_info >= MS_2008_VERSION: self.supports_multivalues_insert = True + if self.deprecate_large_types is None: + self.deprecate_large_types = \ + self.server_version_info >= MS_2012_VERSION def _get_default_schema_name(self, connection): if self.server_version_info < MS_2005_VERSION: @@ -1592,12 +1680,11 @@ class MSDialect(default.DefaultDialect): if coltype in (MSString, MSChar, MSNVarchar, MSNChar, MSText, MSNText, MSBinary, MSVarBinary, sqltypes.LargeBinary): + if charlen == -1: + charlen = 'max' kwargs['length'] = charlen if collation: kwargs['collation'] = collation - if coltype == MSText or \ - (coltype in (MSString, MSNVarchar) and charlen == -1): - kwargs.pop('length') if coltype is None: util.warn( diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 94db1d837..9a2de39b4 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -894,7 +894,7 @@ class LargeBinary(_Binary): :param length: optional, a length for the column for use in DDL statements, 
for those BLOB types that accept a length - (i.e. MySQL). It does *not* produce a small BINARY/VARBINARY + (i.e. MySQL). It does *not* produce a *lengthed* BINARY/VARBINARY type - use the BINARY/VARBINARY types specifically for those. May be safely omitted if no ``CREATE TABLE`` will be issued. Certain databases may require a -- cgit v1.2.1 From 60e6ac8856e5f7f257e1797280d1510682ae8fb7 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 7 Dec 2014 18:54:52 -0500 Subject: - rework the assert_sql system so that we have a context manager to work with, use events that are local to the engine and to the run and are removed afterwards. --- lib/sqlalchemy/testing/assertions.py | 13 +++-- lib/sqlalchemy/testing/assertsql.py | 92 ++++++++++++++++++++++++++---------- lib/sqlalchemy/testing/engines.py | 3 -- 3 files changed, 75 insertions(+), 33 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index bf7c27a89..66d1f3cb0 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -405,13 +405,16 @@ class AssertsExecutionResults(object): cls.__name__, repr(expected_item))) return True + def sql_execution_asserter(self, db=None): + if db is None: + from . import db as db + + return assertsql.assert_engine(db) + def assert_sql_execution(self, db, callable_, *rules): - assertsql.asserter.add_rules(rules) - try: + with self.sql_execution_asserter(db) as asserter: callable_() - assertsql.asserter.statement_complete() - finally: - assertsql.asserter.clear_rules() + asserter.assert_(*rules) def assert_sql(self, db, callable_, list_, with_sequences=None): if (with_sequences is not None and diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index bcc999fe3..2ac0605a2 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -8,6 +8,9 @@ from ..engine.default import DefaultDialect from .. 
import util import re +import collections +import contextlib +from .. import event class AssertRule(object): @@ -321,39 +324,78 @@ def _process_assertion_statement(query, context): return query -class SQLAssert(object): +class SQLExecuteObserved( + collections.namedtuple( + "SQLExecuteObserved", ["clauseelement", "multiparams", "params"]) +): + def process(self, rules): + if rules is not None: + if not rules: + assert False, \ + 'All rules have been exhausted, but further '\ + 'statements remain' + rule = rules[0] + rule.process_execute( + self.clauseelement, *self.multiparams, **self.params) + if rule.is_consumed(): + rules.pop(0) - rules = None - def add_rules(self, rules): - self.rules = list(rules) +class SQLCursorExecuteObserved( + collections.namedtuple( + "SQLCursorExecuteObserved", + ["statement", "parameters", "context", "executemany"]) +): + def process(self, rules): + if rules: + rule = rules[0] + rule.process_cursor_execute( + self.statement, self.parameters, + self.context, self.executemany) - def statement_complete(self): - for rule in self.rules: + +class SQLAsserter(object): + def __init__(self): + self.accumulated = [] + + def _close(self): + # safety feature in case event.remove + # goes haywire + self._final = self.accumulated + del self.accumulated + + def assert_(self, *rules): + rules = list(rules) + for observed in self._final: + observed.process(rules) + + for rule in rules: if not rule.consume_final(): assert False, \ 'All statements are complete, but pending '\ 'assertion rules remain' - def clear_rules(self): - del self.rules - def execute(self, conn, clauseelement, multiparams, params, result): - if self.rules is not None: - if not self.rules: - assert False, \ - 'All rules have been exhausted, but further '\ - 'statements remain' - rule = self.rules[0] - rule.process_execute(clauseelement, *multiparams, **params) - if rule.is_consumed(): - self.rules.pop(0) +@contextlib.contextmanager +def assert_engine(engine): + asserter = 
SQLAsserter() - def cursor_execute(self, conn, cursor, statement, parameters, - context, executemany): - if self.rules: - rule = self.rules[0] - rule.process_cursor_execute(statement, parameters, context, - executemany) + @event.listens_for(engine, "after_execute") + def execute(conn, clauseelement, multiparams, params, result): + asserter.accumulated.append( + SQLExecuteObserved( + clauseelement, multiparams, params)) -asserter = SQLAssert() + @event.listens_for(engine, "after_cursor_execute") + def cursor_execute(conn, cursor, statement, parameters, + context, executemany): + asserter.accumulated.append( + SQLCursorExecuteObserved( + statement, parameters, context, executemany)) + + try: + yield asserter + finally: + asserter._close() + event.remove(engine, "after_cursor_execute", cursor_execute) + event.remove(engine, "after_execute", execute) diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index 0f6f59401..7d73e7423 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -204,7 +204,6 @@ def testing_engine(url=None, options=None): """Produce an engine configured by --options with optional overrides.""" from sqlalchemy import create_engine - from .assertsql import asserter if not options: use_reaper = True @@ -219,8 +218,6 @@ def testing_engine(url=None, options=None): if isinstance(engine.pool, pool.QueuePool): engine.pool._timeout = 0 engine.pool._max_overflow = 0 - event.listen(engine, 'after_execute', asserter.execute) - event.listen(engine, 'after_cursor_execute', asserter.cursor_execute) if use_reaper: event.listen(engine.pool, 'connect', testing_reaper.connect) event.listen(engine.pool, 'checkout', testing_reaper.checkout) -- cgit v1.2.1 From e257ca6c5268517ec2e9a561372d82dfc10475e8 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 7 Dec 2014 18:55:23 -0500 Subject: - initial tests for bulk --- lib/sqlalchemy/orm/session.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) 
(limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index ef911824c..7dd577230 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -2056,7 +2056,8 @@ class Session(_SessionClassMethods): mapper, states, isupdate, True, return_defaults) def bulk_insert_mappings(self, mapper, mappings, return_defaults=False): - self._bulk_save_mappings(mapper, mappings, False, False, return_defaults) + self._bulk_save_mappings( + mapper, mappings, False, False, return_defaults) def bulk_update_mappings(self, mapper, mappings): self._bulk_save_mappings(mapper, mappings, True, False, False) -- cgit v1.2.1 From c42b8f8eb8f4c324e2469bf3baaa316c214abce5 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 7 Dec 2014 20:21:20 -0500 Subject: - fix inheritance persistence - start writing docs --- lib/sqlalchemy/orm/persistence.py | 15 ++-- lib/sqlalchemy/orm/session.py | 158 ++++++++++++++++++++++++++++++++++++++ lib/sqlalchemy/orm/sync.py | 17 ++++ 3 files changed, 184 insertions(+), 6 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 81024c41f..d94fbb040 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -49,7 +49,7 @@ def _bulk_insert( continue records = ( - (None, state_dict, params, super_mapper, + (None, state_dict, params, mapper, connection, value_params, has_all_pks, has_all_defaults) for state, state_dict, params, mp, @@ -918,7 +918,7 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states): def _postfetch(mapper, uowtransaction, table, - state, dict_, result, params, value_params): + state, dict_, result, params, value_params, bulk=False): """Expire attributes in need of newly persisted database state, after an INSERT or UPDATE statement has proceeded for that state.""" @@ -954,10 +954,13 @@ def _postfetch(mapper, uowtransaction, table, # TODO: this still goes 
a little too often. would be nice to # have definitive list of "columns that changed" here for m, equated_pairs in mapper._table_to_equated[table]: - sync.populate(state, m, state, m, - equated_pairs, - uowtransaction, - mapper.passive_updates) + if state is None: + sync.bulk_populate_inherit_keys(dict_, m, equated_pairs) + else: + sync.populate(state, m, state, m, + equated_pairs, + uowtransaction, + mapper.passive_updates) def _connections_for_states(base_mapper, uowtransaction, states): diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 7dd577230..e07b4554e 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -2048,6 +2048,66 @@ class Session(_SessionClassMethods): transaction.rollback(_capture_exception=True) def bulk_save_objects(self, objects, return_defaults=False): + """Perform a bulk save of the given list of objects. + + The bulk save feature allows mapped objects to be used as the + source of simple INSERT and UPDATE operations which can be more easily + grouped together into higher performing "executemany" + operations; the extraction of data from the objects is also performed + using a lower-latency process that ignores whether or not attributes + have actually been modified in the case of UPDATEs, and also ignores + SQL expressions. + + The objects as given are not added to the session and no additional + state is established on them, unless the ``return_defaults`` flag + is also set. + + .. warning:: + + The bulk save feature allows for a lower-latency INSERT/UPDATE + of rows at the expense of a lack of features. Features such + as object management, relationship handling, and SQL clause + support are bypassed in favor of raw INSERT/UPDATES of records. + + **Please read the list of caveats at :ref:`bulk_operations` + before using this method.** + + :param objects: a list of mapped object instances. 
The mapped + objects are persisted as is, and are **not** associated with the + :class:`.Session` afterwards. + + For each object, whether the object is sent as an INSERT or an + UPDATE is dependent on the same rules used by the :class:`.Session` + in traditional operation; if the object has the + :attr:`.InstanceState.key` + attribute set, then the object is assumed to be "detached" and + will result in an UPDATE. Otherwise, an INSERT is used. + + In the case of an UPDATE, **all** those attributes which are present + and are not part of the primary key are applied to the SET clause + of the UPDATE statement, regardless of whether any change in state + was logged on each attribute; there is no checking of per-attribute + history. The primary key attributes, which are required, + are applied to the WHERE clause. + + :param return_defaults: when True, rows that are missing values which + generate defaults, namely integer primary key defaults and sequences, + will be inserted **one at a time**, so that the primary key value + is available. In particular this will allow joined-inheritance + and other multi-table mappings to insert correctly without the need + to provide primary key values ahead of time; however, + return_defaults mode greatly reduces the performance gains of the + method overall. + + .. seealso:: + + :ref:`bulk_operations` + + :meth:`.Session.bulk_insert_mappings` + + :meth:`.Session.bulk_update_mappings` + + """ for (mapper, isupdate), states in itertools.groupby( (attributes.instance_state(obj) for obj in objects), lambda state: (state.mapper, state.key is not None) @@ -2056,10 +2116,108 @@ class Session(_SessionClassMethods): mapper, states, isupdate, True, return_defaults) def bulk_insert_mappings(self, mapper, mappings, return_defaults=False): + """Perform a bulk insert of the given list of mapping dictionaries. 
+ + The bulk insert feature allows plain Python dictionaries to be used as + the source of simple INSERT operations which can be more easily + grouped together into higher performing "executemany" + operations. Using dictionaries, there is no "history" or session + state management features in use, reducing latency when inserting + large numbers of simple rows. + + The values within the dictionaries as given are typically passed + without modification into Core :meth:`.Insert` constructs, after + organizing the values within them across the tables to which + the given mapper is mapped. + + .. warning:: + + The bulk insert feature allows for a lower-latency INSERT + of rows at the expense of a lack of features. Features such + as relationship handling and SQL clause support are bypassed + in favor of a raw INSERT of records. + + **Please read the list of caveats at :ref:`bulk_operations` + before using this method.** + + :param mapper: a mapped class, or the actual :class:`.Mapper` object, + representing the single kind of object represented within the mapping + list. + + :param mappings: a list of dictionaries, each one containing the state + of the mapped row to be inserted, in terms of the attribute names + on the mapped class. If the mapping refers to multiple tables, + such as a joined-inheritance mapping, each dictionary must contain + all keys to be populated into all tables. + + :param return_defaults: when True, rows that are missing values which + generate defaults, namely integer primary key defaults and sequences, + will be inserted **one at a time**, so that the primary key value + is available. In particular this will allow joined-inheritance + and other multi-table mappings to insert correctly without the need + to provide primary + key values ahead of time; however, return_defaults mode greatly + reduces the performance gains of the method overall. 
If the rows + to be inserted only refer to a single table, then there is no + reason this flag should be set as the returned default information + is not used. + + + .. seealso:: + + :ref:`bulk_operations` + + :meth:`.Session.bulk_save_objects` + + :meth:`.Session.bulk_update_mappings` + + """ self._bulk_save_mappings( mapper, mappings, False, False, return_defaults) def bulk_update_mappings(self, mapper, mappings): + """Perform a bulk update of the given list of mapping dictionaries. + + The bulk update feature allows plain Python dictionaries to be used as + the source of simple UPDATE operations which can be more easily + grouped together into higher performing "executemany" + operations. Using dictionaries, there is no "history" or session + state management features in use, reducing latency when updating + large numbers of simple rows. + + .. warning:: + + The bulk update feature allows for a lower-latency UPDATE + of rows at the expense of a lack of features. Features such + as relationship handling and SQL clause support are bypassed + in favor of a raw UPDATE of records. + + **Please read the list of caveats at :ref:`bulk_operations` + before using this method.** + + :param mapper: a mapped class, or the actual :class:`.Mapper` object, + representing the single kind of object represented within the mapping + list. + + :param mappings: a list of dictionaries, each one containing the state + of the mapped row to be updated, in terms of the attribute names + on the mapped class. If the mapping refers to multiple tables, + such as a joined-inheritance mapping, each dictionary may contain + keys corresponding to all tables. All those keys which are present + and are not part of the primary key are applied to the SET clause + of the UPDATE statement; the primary key values, which are required, + are applied to the WHERE clause. + + + .. 
seealso:: + + :ref:`bulk_operations` + + :meth:`.Session.bulk_insert_mappings` + + :meth:`.Session.bulk_save_objects` + + """ self._bulk_save_mappings(mapper, mappings, True, False, False) def _bulk_save_mappings( diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py index e1ef85c1d..671c7c067 100644 --- a/lib/sqlalchemy/orm/sync.py +++ b/lib/sqlalchemy/orm/sync.py @@ -45,6 +45,23 @@ def populate(source, source_mapper, dest, dest_mapper, uowcommit.attributes[("pk_cascaded", dest, r)] = True +def bulk_populate_inherit_keys( + source_dict, source_mapper, synchronize_pairs): + # a simplified version of populate() used by bulk insert mode + for l, r in synchronize_pairs: + try: + prop = source_mapper._columntoproperty[l] + value = source_dict[prop.key] + except exc.UnmappedColumnError: + _raise_col_to_prop(False, source_mapper, l, source_mapper, r) + + try: + prop = source_mapper._columntoproperty[r] + source_dict[prop.key] = value + except exc.UnmappedColumnError: + _raise_col_to_prop(True, source_mapper, l, source_mapper, r) + + def clear(dest, dest_mapper, synchronize_pairs): for l, r in synchronize_pairs: if r.primary_key and \ -- cgit v1.2.1 From 07cc9e054ae4d5bb9cfc3c1d807b2a0d58a95b69 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 7 Dec 2014 20:36:01 -0500 Subject: - add an option for bulk_save -> update to not do history --- lib/sqlalchemy/orm/persistence.py | 9 +++++++-- lib/sqlalchemy/orm/session.py | 32 +++++++++++++++++++++----------- 2 files changed, 28 insertions(+), 13 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index d94fbb040..f477e1dd7 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -75,7 +75,8 @@ def _bulk_insert( ) -def _bulk_update(mapper, mappings, session_transaction, isstates): +def _bulk_update(mapper, mappings, session_transaction, + isstates, update_changed_only): base_mapper = mapper.base_mapper 
cached_connections = _cached_connection_dict(base_mapper) @@ -88,7 +89,10 @@ def _bulk_update(mapper, mappings, session_transaction, isstates): ) if isstates: - mappings = [_changed_dict(mapper, state) for state in mappings] + if update_changed_only: + mappings = [_changed_dict(mapper, state) for state in mappings] + else: + mappings = [state.dict for state in mappings] else: mappings = list(mappings) @@ -612,6 +616,7 @@ def _emit_update_statements(base_mapper, uowtransaction, rows = 0 records = list(records) + if hasvalue: for state, state_dict, params, mapper, \ connection, value_params in records: diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index e07b4554e..72d393f54 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -2047,7 +2047,8 @@ class Session(_SessionClassMethods): with util.safe_reraise(): transaction.rollback(_capture_exception=True) - def bulk_save_objects(self, objects, return_defaults=False): + def bulk_save_objects( + self, objects, return_defaults=False, update_changed_only=True): """Perform a bulk save of the given list of objects. The bulk save feature allows mapped objects to be used as the @@ -2083,12 +2084,13 @@ class Session(_SessionClassMethods): attribute set, then the object is assumed to be "detached" and will result in an UPDATE. Otherwise, an INSERT is used. - In the case of an UPDATE, **all** those attributes which are present - and are not part of the primary key are applied to the SET clause - of the UPDATE statement, regardless of whether any change in state - was logged on each attribute; there is no checking of per-attribute - history. The primary key attributes, which are required, - are applied to the WHERE clause. + In the case of an UPDATE, statements are grouped based on which + attributes have changed, and are thus to be the subject of each + SET clause. 
If ``update_changed_only`` is False, then all + attributes present within each object are applied to the UPDATE + statement, which may help in allowing the statements to be grouped + together into a larger executemany(), and will also reduce the + overhead of checking history on attributes. :param return_defaults: when True, rows that are missing values which generate defaults, namely integer primary key defaults and sequences, @@ -2099,6 +2101,11 @@ class Session(_SessionClassMethods): return_defaults mode greatly reduces the performance gains of the method overall. + :param update_changed_only: when True, UPDATE statements are rendered + based on those attributes in each state that have logged changes. + When False, all attributes present are rendered into the SET clause + with the exception of primary key attributes. + .. seealso:: :ref:`bulk_operations` @@ -2113,7 +2120,8 @@ class Session(_SessionClassMethods): lambda state: (state.mapper, state.key is not None) ): self._bulk_save_mappings( - mapper, states, isupdate, True, return_defaults) + mapper, states, isupdate, True, + return_defaults, update_changed_only) def bulk_insert_mappings(self, mapper, mappings, return_defaults=False): """Perform a bulk insert of the given list of mapping dictionaries. 
@@ -2218,10 +2226,11 @@ class Session(_SessionClassMethods): :meth:`.Session.bulk_save_objects` """ - self._bulk_save_mappings(mapper, mappings, True, False, False) + self._bulk_save_mappings(mapper, mappings, True, False, False, False) def _bulk_save_mappings( - self, mapper, mappings, isupdate, isstates, return_defaults): + self, mapper, mappings, isupdate, isstates, + return_defaults, update_changed_only): mapper = _class_to_mapper(mapper) self._flushing = True @@ -2230,7 +2239,8 @@ class Session(_SessionClassMethods): try: if isupdate: persistence._bulk_update( - mapper, mappings, transaction, isstates) + mapper, mappings, transaction, + isstates, update_changed_only) else: persistence._bulk_insert( mapper, mappings, transaction, isstates, return_defaults) -- cgit v1.2.1 From 3f1477e2ecf3b2e95a26383490d0e8c363f4d0cc Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 8 Dec 2014 01:10:30 -0500 Subject: - A new series of :class:`.Session` methods which provide hooks directly into the unit of work's facility for emitting INSERT and UPDATE statements has been created. When used correctly, this expert-oriented system can allow ORM-mappings to be used to generate bulk insert and update statements batched into executemany groups, allowing the statements to proceed at speeds that rival direct use of the Core. fixes #3100 --- lib/sqlalchemy/orm/session.py | 59 +++++++++++++++++++++++++++---------------- 1 file changed, 37 insertions(+), 22 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 72d393f54..d40d28154 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -2061,17 +2061,22 @@ class Session(_SessionClassMethods): The objects as given are not added to the session and no additional state is established on them, unless the ``return_defaults`` flag - is also set. + is also set, in which case primary key attributes and server-side + default values will be populated. 
+ + .. versionadded:: 1.0.0 .. warning:: The bulk save feature allows for a lower-latency INSERT/UPDATE - of rows at the expense of a lack of features. Features such - as object management, relationship handling, and SQL clause - support are bypassed in favor of raw INSERT/UPDATES of records. + of rows at the expense of most other unit-of-work features. + Features such as object management, relationship handling, + and SQL clause support are **silently omitted** in favor of raw + INSERT/UPDATES of records. - **Please read the list of caveats at :ref:`bulk_operations` - before using this method.** + **Please read the list of caveats at** :ref:`bulk_operations` + **before using this method, and fully test and confirm the + functionality of all code developed using these systems.** :param objects: a list of mapped object instances. The mapped objects are persisted as is, and are **not** associated with the @@ -2098,8 +2103,8 @@ class Session(_SessionClassMethods): is available. In particular this will allow joined-inheritance and other multi-table mappings to insert correctly without the need to provide primary key values ahead of time; however, - return_defaults mode greatly reduces the performance gains of the - method overall. + :paramref:`.Session.bulk_save_objects.return_defaults` **greatly + reduces the performance gains** of the method overall. :param update_changed_only: when True, UPDATE statements are rendered based on those attributes in each state that have logged changes. @@ -2138,15 +2143,19 @@ class Session(_SessionClassMethods): organizing the values within them across the tables to which the given mapper is mapped. + .. versionadded:: 1.0.0 + .. warning:: The bulk insert feature allows for a lower-latency INSERT - of rows at the expense of a lack of features. Features such - as relationship handling and SQL clause support are bypassed - in favor of a raw INSERT of records. + of rows at the expense of most other unit-of-work features. 
+ Features such as object management, relationship handling, + and SQL clause support are **silently omitted** in favor of raw + INSERT of records. - **Please read the list of caveats at :ref:`bulk_operations` - before using this method.** + **Please read the list of caveats at** :ref:`bulk_operations` + **before using this method, and fully test and confirm the + functionality of all code developed using these systems.** :param mapper: a mapped class, or the actual :class:`.Mapper` object, representing the single kind of object represented within the mapping @@ -2164,8 +2173,10 @@ class Session(_SessionClassMethods): is available. In particular this will allow joined-inheritance and other multi-table mappings to insert correctly without the need to provide primary - key values ahead of time; however, return_defaults mode greatly - reduces the performance gains of the method overall. If the rows + key values ahead of time; however, + :paramref:`.Session.bulk_insert_mappings.return_defaults` + **greatly reduces the performance gains** of the method overall. + If the rows to be inserted only refer to a single table, then there is no reason this flag should be set as the returned default information is not used. @@ -2181,7 +2192,7 @@ class Session(_SessionClassMethods): """ self._bulk_save_mappings( - mapper, mappings, False, False, return_defaults) + mapper, mappings, False, False, return_defaults, False) def bulk_update_mappings(self, mapper, mappings): """Perform a bulk update of the given list of mapping dictionaries. @@ -2193,15 +2204,19 @@ class Session(_SessionClassMethods): state management features in use, reducing latency when updating large numbers of simple rows. + .. versionadded:: 1.0.0 + .. warning:: The bulk update feature allows for a lower-latency UPDATE - of rows at the expense of a lack of features. Features such - as relationship handling and SQL clause support are bypassed - in favor of a raw UPDATE of records. 
- - **Please read the list of caveats at :ref:`bulk_operations` - before using this method.** + of rows at the expense of most other unit-of-work features. + Features such as object management, relationship handling, + and SQL clause support are **silently omitted** in favor of raw + UPDATES of records. + + **Please read the list of caveats at** :ref:`bulk_operations` + **before using this method, and fully test and confirm the + functionality of all code developed using these systems.** :param mapper: a mapped class, or the actual :class:`.Mapper` object, representing the single kind of object represented within the mapping -- cgit v1.2.1 From 6b9f62df10e1b1f557b9077613e5e96a08427460 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 8 Dec 2014 11:18:38 -0500 Subject: - force the _has_events flag to True on engines, so that profiling is more predictable - restore the profiling from before this change --- lib/sqlalchemy/testing/engines.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index 7d73e7423..444a79b70 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -215,6 +215,9 @@ def testing_engine(url=None, options=None): options = config.db_opts engine = create_engine(url, **options) + engine._has_events = True # enable event blocks, helps with + # profiling + if isinstance(engine.pool, pool.QueuePool): engine.pool._timeout = 0 engine.pool._max_overflow = 0 -- cgit v1.2.1 From b7cf11b163dd7d15f56634a41dcceb880821ecf3 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 8 Dec 2014 14:05:20 -0500 Subject: - simplify the "noconnection" error handling, setting _handle_dbapi_exception_noconnection() to only invoke in the case of raw_connection() in the constructor of Connection. in all other cases the Connection proceeds with _handle_dbapi_exception() including revalidate. 
--- lib/sqlalchemy/engine/base.py | 36 +++++++++++++++++++----------------- lib/sqlalchemy/engine/threadlocal.py | 2 +- 2 files changed, 20 insertions(+), 18 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 23348469d..dd8ea275c 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -265,18 +265,20 @@ class Connection(Connectable): try: return self.__connection except AttributeError: - return self._revalidate_connection(_wrap=True) + try: + return self._revalidate_connection() + except Exception as e: + self._handle_dbapi_exception(e, None, None, None, None) - def _revalidate_connection(self, _wrap): + def _revalidate_connection(self): if self.__branch_from: - return self.__branch_from._revalidate_connection(_wrap=_wrap) + return self.__branch_from._revalidate_connection() if self.__can_reconnect and self.__invalid: if self.__transaction is not None: raise exc.InvalidRequestError( "Can't reconnect until invalid " "transaction is rolled back") - self.__connection = self.engine.raw_connection( - _connection=self, _wrap=_wrap) + self.__connection = self.engine.raw_connection(_connection=self) self.__invalid = False return self.__connection raise exc.ResourceClosedError("This Connection is closed") @@ -817,7 +819,7 @@ class Connection(Connectable): try: conn = self.__connection except AttributeError: - conn = self._revalidate_connection(_wrap=False) + conn = self._revalidate_connection() dialect = self.dialect ctx = dialect.execution_ctx_cls._init_default( @@ -955,7 +957,7 @@ class Connection(Connectable): try: conn = self.__connection except AttributeError: - conn = self._revalidate_connection(_wrap=False) + conn = self._revalidate_connection() context = constructor(dialect, self, conn, *args) except Exception as e: @@ -1248,8 +1250,7 @@ class Connection(Connectable): self.close() @classmethod - def _handle_dbapi_exception_noconnection( - cls, e, dialect, engine, 
connection): + def _handle_dbapi_exception_noconnection(cls, e, dialect, engine): exc_info = sys.exc_info() @@ -1271,7 +1272,7 @@ class Connection(Connectable): if engine._has_events: ctx = ExceptionContextImpl( - e, sqlalchemy_exception, engine, connection, None, None, + e, sqlalchemy_exception, engine, None, None, None, None, None, is_disconnect) for fn in engine.dispatch.handle_error: try: @@ -1957,17 +1958,18 @@ class Engine(Connectable, log.Identified): """ return self.run_callable(self.dialect.has_table, table_name, schema) - def _wrap_pool_connect(self, fn, connection, wrap=True): - if not wrap: - return fn() + def _wrap_pool_connect(self, fn, connection): dialect = self.dialect try: return fn() except dialect.dbapi.Error as e: - Connection._handle_dbapi_exception_noconnection( - e, dialect, self, connection) + if connection is None: + Connection._handle_dbapi_exception_noconnection( + e, dialect, self) + else: + util.reraise(*sys.exc_info()) - def raw_connection(self, _connection=None, _wrap=True): + def raw_connection(self, _connection=None): """Return a "raw" DBAPI connection from the connection pool. 
The returned object is a proxied version of the DBAPI @@ -1984,7 +1986,7 @@ class Engine(Connectable, log.Identified): """ return self._wrap_pool_connect( - self.pool.unique_connection, _connection, _wrap) + self.pool.unique_connection, _connection) class OptionEngine(Engine): diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py index 824b68fdf..e64ab09f4 100644 --- a/lib/sqlalchemy/engine/threadlocal.py +++ b/lib/sqlalchemy/engine/threadlocal.py @@ -61,7 +61,7 @@ class TLEngine(base.Engine): connection = self._tl_connection_cls( self, self._wrap_pool_connect( - self.pool.connect, connection, wrap=True), + self.pool.connect, connection), **kw) self._connections.conn = weakref.ref(connection) -- cgit v1.2.1 From 06738f665ea936246a3813ad7de01e98ff8d519a Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 8 Dec 2014 15:15:02 -0500 Subject: - identify another spot where _handle_dbapi_error() needs to do something differently for the case where it is called in an already-invalidated state; don't call upon self.connection --- lib/sqlalchemy/engine/base.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index dd8ea275c..9a8610344 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1243,9 +1243,10 @@ class Connection(Connectable): del self._reentrant_error if self._is_disconnect: del self._is_disconnect - dbapi_conn_wrapper = self.connection - self.engine.pool._invalidate(dbapi_conn_wrapper, e) - self.invalidate(e) + if not self.invalidated: + dbapi_conn_wrapper = self.__connection + self.engine.pool._invalidate(dbapi_conn_wrapper, e) + self.invalidate(e) if self.should_close_with_result: self.close() -- cgit v1.2.1 From 347db81aea9bfe301a9fe1fade644ad099545f3e Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 10 Dec 2014 12:15:14 -0500 Subject: - keep working on fixing #3266, more cases, 
more tests --- lib/sqlalchemy/engine/base.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 9a8610344..918ee0e37 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1926,10 +1926,11 @@ class Engine(Connectable, log.Identified): """ - return self._connection_cls(self, - self.pool.connect(), - close_with_result=close_with_result, - **kwargs) + return self._connection_cls( + self, + self._wrap_pool_connect(self.pool.connect, None), + close_with_result=close_with_result, + **kwargs) def table_names(self, schema=None, connection=None): """Return a list of all table names available in the database. -- cgit v1.2.1 From 3c70f609507ccc6775495cc533265aeb645528cd Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 10 Dec 2014 13:08:53 -0500 Subject: - fix up query update /delete documentation, make warnings a lot clearer, partial fixes for #3252 --- lib/sqlalchemy/orm/query.py | 179 +++++++++++++++++++++++++++----------------- 1 file changed, 110 insertions(+), 69 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 9b7747e15..1afffb90e 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -2725,6 +2725,18 @@ class Query(object): Deletes rows matched by this query from the database. + E.g.:: + + sess.query(User).filter(User.age == 25).\\ + delete(synchronize_session=False) + + sess.query(User).filter(User.age == 25).\\ + delete(synchronize_session='evaluate') + + .. warning:: The :meth:`.Query.delete` method is a "bulk" operation, + which bypasses ORM unit-of-work automation in favor of greater + performance. **Please read all caveats and warnings below.** + :param synchronize_session: chooses the strategy for the removal of matched objects from the session. 
Valid values are: @@ -2743,8 +2755,7 @@ class Query(object): ``'evaluate'`` - Evaluate the query's criteria in Python straight on the objects in the session. If evaluation of the criteria isn't - implemented, an error is raised. In that case you probably - want to use the 'fetch' strategy as a fallback. + implemented, an error is raised. The expression evaluator currently doesn't account for differing string collations between the database and Python. @@ -2752,29 +2763,42 @@ class Query(object): :return: the count of rows matched as returned by the database's "row count" feature. - This method has several key caveats: - - * The method does **not** offer in-Python cascading of relationships - - it is assumed that ON DELETE CASCADE/SET NULL/etc. is configured - for any foreign key references which require it, otherwise the - database may emit an integrity violation if foreign key references - are being enforced. - - After the DELETE, dependent objects in the :class:`.Session` which - were impacted by an ON DELETE may not contain the current - state, or may have been deleted. This issue is resolved once the - :class:`.Session` is expired, - which normally occurs upon :meth:`.Session.commit` or can be forced - by using :meth:`.Session.expire_all`. Accessing an expired object - whose row has been deleted will invoke a SELECT to locate the - row; when the row is not found, an - :class:`~sqlalchemy.orm.exc.ObjectDeletedError` is raised. - - * The :meth:`.MapperEvents.before_delete` and - :meth:`.MapperEvents.after_delete` - events are **not** invoked from this method. Instead, the - :meth:`.SessionEvents.after_bulk_delete` method is provided to act - upon a mass DELETE of entity rows. + .. warning:: **Additional Caveats for bulk query deletes** + + * The method does **not** offer in-Python cascading of + relationships - it is assumed that ON DELETE CASCADE/SET + NULL/etc. 
is configured for any foreign key references + which require it, otherwise the database may emit an + integrity violation if foreign key references are being + enforced. + + After the DELETE, dependent objects in the + :class:`.Session` which were impacted by an ON DELETE + may not contain the current state, or may have been + deleted. This issue is resolved once the + :class:`.Session` is expired, which normally occurs upon + :meth:`.Session.commit` or can be forced by using + :meth:`.Session.expire_all`. Accessing an expired + object whose row has been deleted will invoke a SELECT + to locate the row; when the row is not found, an + :class:`~sqlalchemy.orm.exc.ObjectDeletedError` is + raised. + + * The ``'fetch'`` strategy results in an additional + SELECT statement emitted and will significantly reduce + performance. + + * The ``'evaluate'`` strategy performs a scan of + all matching objects within the :class:`.Session`; if the + contents of the :class:`.Session` are expired, such as + via a preceding :meth:`.Session.commit` call, **this will + result in SELECT queries emitted for every matching object**. + + * The :meth:`.MapperEvents.before_delete` and + :meth:`.MapperEvents.after_delete` + events **are not invoked** from this method. Instead, the + :meth:`.SessionEvents.after_bulk_delete` method is provided to + act upon a mass DELETE of entity rows. .. seealso:: @@ -2797,17 +2821,21 @@ class Query(object): E.g.:: - sess.query(User).filter(User.age == 25).\ - update({User.age: User.age - 10}, synchronize_session='fetch') + sess.query(User).filter(User.age == 25).\\ + update({User.age: User.age - 10}, synchronize_session=False) - - sess.query(User).filter(User.age == 25).\ + sess.query(User).filter(User.age == 25).\\ update({"age": User.age - 10}, synchronize_session='evaluate') + .. warning:: The :meth:`.Query.update` method is a "bulk" operation, + which bypasses ORM unit-of-work automation in favor of greater + performance. 
**Please read all caveats and warnings below.** + + :param values: a dictionary with attributes names, or alternatively - mapped attributes or SQL expressions, as keys, and literal - values or sql expressions as values. + mapped attributes or SQL expressions, as keys, and literal + values or sql expressions as values. .. versionchanged:: 1.0.0 - string names in the values dictionary are now resolved against the mapped entity; previously, these @@ -2815,7 +2843,7 @@ class Query(object): translation. :param synchronize_session: chooses the strategy to update the - attributes on objects in the session. Valid values are: + attributes on objects in the session. Valid values are: ``False`` - don't synchronize the session. This option is the most efficient and is reliable once the session is expired, which @@ -2836,43 +2864,56 @@ class Query(object): string collations between the database and Python. :return: the count of rows matched as returned by the database's - "row count" feature. - - This method has several key caveats: - - * The method does **not** offer in-Python cascading of relationships - - it is assumed that ON UPDATE CASCADE is configured for any foreign - key references which require it, otherwise the database may emit an - integrity violation if foreign key references are being enforced. - - After the UPDATE, dependent objects in the :class:`.Session` which - were impacted by an ON UPDATE CASCADE may not contain the current - state; this issue is resolved once the :class:`.Session` is expired, - which normally occurs upon :meth:`.Session.commit` or can be forced - by using :meth:`.Session.expire_all`. - - * The method supports multiple table updates, as - detailed in :ref:`multi_table_updates`, and this behavior does - extend to support updates of joined-inheritance and other multiple - table mappings. However, the **join condition of an inheritance - mapper is currently not automatically rendered**. 
- Care must be taken in any multiple-table update to explicitly - include the joining condition between those tables, even in mappings - where this is normally automatic. - E.g. if a class ``Engineer`` subclasses ``Employee``, an UPDATE of - the ``Engineer`` local table using criteria against the ``Employee`` - local table might look like:: - - session.query(Engineer).\\ - filter(Engineer.id == Employee.id).\\ - filter(Employee.name == 'dilbert').\\ - update({"engineer_type": "programmer"}) - - * The :meth:`.MapperEvents.before_update` and - :meth:`.MapperEvents.after_update` - events are **not** invoked from this method. Instead, the - :meth:`.SessionEvents.after_bulk_update` method is provided to act - upon a mass UPDATE of entity rows. + "row count" feature. + + .. warning:: **Additional Caveats for bulk query updates** + + * The method does **not** offer in-Python cascading of + relationships - it is assumed that ON UPDATE CASCADE is + configured for any foreign key references which require + it, otherwise the database may emit an integrity + violation if foreign key references are being enforced. + + After the UPDATE, dependent objects in the + :class:`.Session` which were impacted by an ON UPDATE + CASCADE may not contain the current state; this issue is + resolved once the :class:`.Session` is expired, which + normally occurs upon :meth:`.Session.commit` or can be + forced by using :meth:`.Session.expire_all`. + + * The ``'fetch'`` strategy results in an additional + SELECT statement emitted and will significantly reduce + performance. + + * The ``'evaluate'`` strategy performs a scan of + all matching objects within the :class:`.Session`; if the + contents of the :class:`.Session` are expired, such as + via a preceding :meth:`.Session.commit` call, **this will + result in SELECT queries emitted for every matching object**. 
+ + * The method supports multiple table updates, as detailed + in :ref:`multi_table_updates`, and this behavior does + extend to support updates of joined-inheritance and + other multiple table mappings. However, the **join + condition of an inheritance mapper is not + automatically rendered**. Care must be taken in any + multiple-table update to explicitly include the joining + condition between those tables, even in mappings where + this is normally automatic. E.g. if a class ``Engineer`` + subclasses ``Employee``, an UPDATE of the ``Engineer`` + local table using criteria against the ``Employee`` + local table might look like:: + + session.query(Engineer).\\ + filter(Engineer.id == Employee.id).\\ + filter(Employee.name == 'dilbert').\\ + update({"engineer_type": "programmer"}) + + * The :meth:`.MapperEvents.before_update` and + :meth:`.MapperEvents.after_update` + events **are not invoked from this method**. Instead, the + :meth:`.SessionEvents.after_bulk_update` method is provided to + act upon a mass UPDATE of entity rows. .. seealso:: -- cgit v1.2.1 From cf7981f60d485f17465f44c6ff651ae283ade377 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 12 Dec 2014 19:59:11 -0500 Subject: - Added new method :meth:`.Session.invalidate`, functions similarly to :meth:`.Session.close`, except also calls :meth:`.Connection.invalidate` on all connections, guaranteeing that they will not be returned to the connection pool. This is useful in situations e.g. dealing with gevent timeouts when it is not safe to use the connection further, even for rollbacks. 
references #3258 --- lib/sqlalchemy/orm/session.py | 42 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 40 insertions(+), 2 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index d40d28154..507e99b2e 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -435,11 +435,13 @@ class SessionTransaction(object): self.session.dispatch.after_rollback(self.session) - def close(self): + def close(self, invalidate=False): self.session.transaction = self._parent if self._parent is None: for connection, transaction, autoclose in \ set(self._connections.values()): + if invalidate: + connection.invalidate() if autoclose: connection.close() else: @@ -1000,10 +1002,46 @@ class Session(_SessionClassMethods): not use any connection resources until they are first needed. """ + self._close_impl(invalidate=False) + + def invalidate(self): + """Close this Session, using connection invalidation. + + This is a variant of :meth:`.Session.close` that will additionally + ensure that the :meth:`.Connection.invalidate` method will be called + on all :class:`.Connection` objects. This can be called when + the database is known to be in a state where the connections are + no longer safe to be used. + + E.g.:: + + try: + sess = Session() + sess.add(User()) + sess.commit() + except gevent.Timeout: + sess.invalidate() + raise + except: + sess.rollback() + raise + + This clears all items and ends any transaction in progress. + + If this session were created with ``autocommit=False``, a new + transaction is immediately begun. Note that this new transaction does + not use any connection resources until they are first needed. + + .. 
versionadded:: 0.9.9 + + """ + self._close_impl(invalidate=True) + + def _close_impl(self, invalidate): self.expunge_all() if self.transaction is not None: for transaction in self.transaction._iterate_parents(): - transaction.close() + transaction.close(invalidate) def expunge_all(self): """Remove all object instances from this ``Session``. -- cgit v1.2.1 From 91af7337878612b2497269e600eef147a0f5bb30 Mon Sep 17 00:00:00 2001 From: Jon Nelson Date: Tue, 11 Nov 2014 22:46:07 -0600 Subject: - fix unique constraint parsing for sqlite -- may return '' for name, however --- lib/sqlalchemy/dialects/sqlite/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index ccd7f2539..30d8a6ea3 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -1173,7 +1173,7 @@ class SQLiteDialect(default.DefaultDialect): return [] table_data = row[0] - UNIQUE_PATTERN = 'CONSTRAINT (\w+) UNIQUE \(([^\)]+)\)' + UNIQUE_PATTERN = '(?:CONSTRAINT (\w+) )?UNIQUE \(([^\)]+)\)' return [ {'name': name, 'column_names': [col.strip(' "') for col in cols.split(',')]} -- cgit v1.2.1 From 468db416dbf284f0e7dddde90ec9641dc89428c6 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 13 Dec 2014 18:04:11 -0500 Subject: - rework sqlite FK and unique constraint system to combine both PRAGMA and regexp parsing of SQL in order to form a complete picture of constraints + their names. 
fixes #3244 fixes #3261 - factor various PRAGMA work to be centralized into one call --- lib/sqlalchemy/dialects/sqlite/base.py | 299 +++++++++++++++++++++------------ 1 file changed, 187 insertions(+), 112 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 30d8a6ea3..e79299527 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -913,22 +913,9 @@ class SQLiteDialect(default.DefaultDialect): return [row[0] for row in rs] def has_table(self, connection, table_name, schema=None): - quote = self.identifier_preparer.quote_identifier - if schema is not None: - pragma = "PRAGMA %s." % quote(schema) - else: - pragma = "PRAGMA " - qtable = quote(table_name) - statement = "%stable_info(%s)" % (pragma, qtable) - cursor = _pragma_cursor(connection.execute(statement)) - row = cursor.fetchone() - - # consume remaining rows, to work around - # http://www.sqlite.org/cvstrac/tktview?tn=1884 - while not cursor.closed and cursor.fetchone() is not None: - pass - - return row is not None + info = self._get_table_pragma( + connection, "table_info", table_name, schema=schema) + return bool(info) @reflection.cache def get_view_names(self, connection, schema=None, **kw): @@ -970,18 +957,11 @@ class SQLiteDialect(default.DefaultDialect): @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): - quote = self.identifier_preparer.quote_identifier - if schema is not None: - pragma = "PRAGMA %s." 
% quote(schema) - else: - pragma = "PRAGMA " - qtable = quote(table_name) - statement = "%stable_info(%s)" % (pragma, qtable) - c = _pragma_cursor(connection.execute(statement)) + info = self._get_table_pragma( + connection, "table_info", table_name, schema=schema) - rows = c.fetchall() columns = [] - for row in rows: + for row in info: (name, type_, nullable, default, primary_key) = ( row[1], row[2].upper(), not row[3], row[4], row[5]) @@ -1068,92 +1048,192 @@ class SQLiteDialect(default.DefaultDialect): @reflection.cache def get_foreign_keys(self, connection, table_name, schema=None, **kw): - quote = self.identifier_preparer.quote_identifier - if schema is not None: - pragma = "PRAGMA %s." % quote(schema) - else: - pragma = "PRAGMA " - qtable = quote(table_name) - statement = "%sforeign_key_list(%s)" % (pragma, qtable) - c = _pragma_cursor(connection.execute(statement)) - fkeys = [] + # sqlite makes this *extremely difficult*. + # First, use the pragma to get the actual FKs. + pragma_fks = self._get_table_pragma( + connection, "foreign_key_list", + table_name, schema=schema + ) + fks = {} - while True: - row = c.fetchone() - if row is None: - break + + for row in pragma_fks: (numerical_id, rtbl, lcol, rcol) = ( row[0], row[2], row[3], row[4]) - self._parse_fk(fks, fkeys, numerical_id, rtbl, lcol, rcol) - return fkeys + if rcol is None: + rcol = lcol - def _parse_fk(self, fks, fkeys, numerical_id, rtbl, lcol, rcol): - # sqlite won't return rcol if the table was created with REFERENCES - # , no col - if rcol is None: - rcol = lcol + if self._broken_fk_pragma_quotes: + rtbl = re.sub(r'^[\"\[`\']|[\"\]`\']$', '', rtbl) - if self._broken_fk_pragma_quotes: - rtbl = re.sub(r'^[\"\[`\']|[\"\]`\']$', '', rtbl) + if numerical_id in fks: + fk = fks[numerical_id] + else: + fk = fks[numerical_id] = { + 'name': None, + 'constrained_columns': [], + 'referred_schema': None, + 'referred_table': rtbl, + 'referred_columns': [], + } + fks[numerical_id] = fk - try: - fk = 
fks[numerical_id] - except KeyError: - fk = { - 'name': None, - 'constrained_columns': [], - 'referred_schema': None, - 'referred_table': rtbl, - 'referred_columns': [], - } - fkeys.append(fk) - fks[numerical_id] = fk - - if lcol not in fk['constrained_columns']: fk['constrained_columns'].append(lcol) - if rcol not in fk['referred_columns']: fk['referred_columns'].append(rcol) - return fk + + def fk_sig(constrained_columns, referred_table, referred_columns): + return tuple(constrained_columns) + (referred_table,) + \ + tuple(referred_columns) + + # then, parse the actual SQL and attempt to find DDL that matches + # the names as well. SQLite saves the DDL in whatever format + # it was typed in as, so need to be liberal here. + + keys_by_signature = dict( + ( + fk_sig( + fk['constrained_columns'], + fk['referred_table'], fk['referred_columns']), + fk + ) for fk in fks.values() + ) + + table_data = self._get_table_sql(connection, table_name, schema=schema) + if table_data is None: + # system tables, etc. + return [] + + def parse_fks(): + FK_PATTERN = ( + '(?:CONSTRAINT (\w+) +)?' + 'FOREIGN KEY *\( *(.+?) 
*\) +' + 'REFERENCES +(?:(?:"(.+?)")|([a-z0-9_]+)) *\((.+?)\)' + ) + + for match in re.finditer(FK_PATTERN, table_data, re.I): + ( + constraint_name, constrained_columns, + referred_quoted_name, referred_name, + referred_columns) = match.group(1, 2, 3, 4, 5) + constrained_columns = list( + self._find_cols_in_sig(constrained_columns)) + if not referred_columns: + referred_columns = constrained_columns + else: + referred_columns = list( + self._find_cols_in_sig(referred_columns)) + referred_name = referred_quoted_name or referred_name + yield ( + constraint_name, constrained_columns, + referred_name, referred_columns) + fkeys = [] + + for ( + constraint_name, constrained_columns, + referred_name, referred_columns) in parse_fks(): + sig = fk_sig( + constrained_columns, referred_name, referred_columns) + if sig not in keys_by_signature: + util.warn( + "WARNING: SQL-parsed foreign key constraint " + "'%s' could not be located in PRAGMA " + "foreign_keys for table %s" % ( + sig, + table_name + )) + continue + key = keys_by_signature.pop(sig) + key['name'] = constraint_name + fkeys.append(key) + # assume the remainders are the unnamed, inline constraints, just + # use them as is as it's extremely difficult to parse inline + # constraints + fkeys.extend(keys_by_signature.values()) + return fkeys + + def _find_cols_in_sig(self, sig): + for match in re.finditer(r'(?:"(.+?)")|([a-z0-9_]+)', sig, re.I): + yield match.group(1) or match.group(2) + + @reflection.cache + def get_unique_constraints(self, connection, table_name, + schema=None, **kw): + + auto_index_by_sig = {} + for idx in self.get_indexes( + connection, table_name, schema=schema, + include_auto_indexes=True, **kw): + if not idx['name'].startswith("sqlite_autoindex"): + continue + sig = tuple(idx['column_names']) + auto_index_by_sig[sig] = idx + + table_data = self._get_table_sql( + connection, table_name, schema=schema, **kw) + if not table_data: + return [] + + unique_constraints = [] + + def parse_uqs(): + 
UNIQUE_PATTERN = '(?:CONSTRAINT (\w+) +)?UNIQUE *\((.+?)\)' + INLINE_UNIQUE_PATTERN = ( + '(?:(".+?")|([a-z0-9]+)) ' + '+[a-z0-9_ ]+? +UNIQUE') + + for match in re.finditer(UNIQUE_PATTERN, table_data, re.I): + name, cols = match.group(1, 2) + yield name, list(self._find_cols_in_sig(cols)) + + # we need to match inlines as well, as we seek to differentiate + # a UNIQUE constraint from a UNIQUE INDEX, even though these + # are kind of the same thing :) + for match in re.finditer(INLINE_UNIQUE_PATTERN, table_data, re.I): + cols = list( + self._find_cols_in_sig(match.group(1) or match.group(2))) + yield None, cols + + for name, cols in parse_uqs(): + sig = tuple(cols) + if sig in auto_index_by_sig: + auto_index_by_sig.pop(sig) + parsed_constraint = { + 'name': name, + 'column_names': cols + } + unique_constraints.append(parsed_constraint) + # NOTE: auto_index_by_sig might not be empty here, + # the PRIMARY KEY may have an entry. + return unique_constraints @reflection.cache def get_indexes(self, connection, table_name, schema=None, **kw): - quote = self.identifier_preparer.quote_identifier - if schema is not None: - pragma = "PRAGMA %s." % quote(schema) - else: - pragma = "PRAGMA " - include_auto_indexes = kw.pop('include_auto_indexes', False) - qtable = quote(table_name) - statement = "%sindex_list(%s)" % (pragma, qtable) - c = _pragma_cursor(connection.execute(statement)) + pragma_indexes = self._get_table_pragma( + connection, "index_list", table_name, schema=schema) indexes = [] - while True: - row = c.fetchone() - if row is None: - break + + include_auto_indexes = kw.pop('include_auto_indexes', False) + for row in pragma_indexes: # ignore implicit primary key index. 
# http://www.mail-archive.com/sqlite-users@sqlite.org/msg30517.html - elif (not include_auto_indexes and - row[1].startswith('sqlite_autoindex')): + if (not include_auto_indexes and + row[1].startswith('sqlite_autoindex')): continue indexes.append(dict(name=row[1], column_names=[], unique=row[2])) + # loop thru unique indexes to get the column names. for idx in indexes: - statement = "%sindex_info(%s)" % (pragma, quote(idx['name'])) - c = connection.execute(statement) - cols = idx['column_names'] - while True: - row = c.fetchone() - if row is None: - break - cols.append(row[2]) + pragma_index = self._get_table_pragma( + connection, "index_info", idx['name']) + + for row in pragma_index: + idx['column_names'].append(row[2]) return indexes @reflection.cache - def get_unique_constraints(self, connection, table_name, - schema=None, **kw): + def _get_table_sql(self, connection, table_name, schema=None, **kw): try: s = ("SELECT sql FROM " " (SELECT * FROM sqlite_master UNION ALL " @@ -1165,27 +1245,22 @@ class SQLiteDialect(default.DefaultDialect): s = ("SELECT sql FROM sqlite_master WHERE name = '%s' " "AND type = 'table'") % table_name rs = connection.execute(s) - row = rs.fetchone() - if row is None: - # sqlite won't return the schema for the sqlite_master or - # sqlite_temp_master tables from this query. These tables - # don't have any unique constraints anyway. 
- return [] - table_data = row[0] - - UNIQUE_PATTERN = '(?:CONSTRAINT (\w+) )?UNIQUE \(([^\)]+)\)' - return [ - {'name': name, - 'column_names': [col.strip(' "') for col in cols.split(',')]} - for name, cols in re.findall(UNIQUE_PATTERN, table_data) - ] + return rs.scalar() - -def _pragma_cursor(cursor): - """work around SQLite issue whereby cursor.description - is blank when PRAGMA returns no rows.""" - - if cursor.closed: - cursor.fetchone = lambda: None - cursor.fetchall = lambda: [] - return cursor + def _get_table_pragma(self, connection, pragma, table_name, schema=None): + quote = self.identifier_preparer.quote_identifier + if schema is not None: + statement = "PRAGMA %s." % quote(schema) + else: + statement = "PRAGMA " + qtable = quote(table_name) + statement = "%s%s(%s)" % (statement, pragma, qtable) + cursor = connection.execute(statement) + if not cursor.closed: + # work around SQLite issue whereby cursor.description + # is blank when PRAGMA returns no rows: + # http://www.sqlite.org/cvstrac/tktview?tn=1884 + result = cursor.fetchall() + else: + result = [] + return result -- cgit v1.2.1 From 8038cfa0771ff860f48967a6800477ce8a508d65 Mon Sep 17 00:00:00 2001 From: Tony Locke Date: Sun, 24 Aug 2014 16:33:29 +0100 Subject: pg8000 client_encoding in create_engine() The pg8000 dialect now supports the setting of the PostgreSQL parameter client_encoding from create_engine(). --- lib/sqlalchemy/dialects/postgresql/pg8000.py | 61 ++++++++++++++++++++++++---- 1 file changed, 54 insertions(+), 7 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py index 4ccc90208..a76787016 100644 --- a/lib/sqlalchemy/dialects/postgresql/pg8000.py +++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py @@ -13,17 +13,30 @@ postgresql+pg8000://user:password@host:port/dbname[?key=value&key=value...] :url: https://pythonhosted.org/pg8000/ + +.. 
_pg8000_unicode: + Unicode ------- -When communicating with the server, pg8000 **always uses the server-side -character set**. SQLAlchemy has no ability to modify what character set -pg8000 chooses to use, and additionally SQLAlchemy does no unicode conversion -of any kind with the pg8000 backend. The origin of the client encoding setting -is ultimately the CLIENT_ENCODING setting in postgresql.conf. +pg8000 will encode / decode string values between it and the server using the +PostgreSQL ``client_encoding`` parameter; by default this is the value in +the ``postgresql.conf`` file, which often defaults to ``SQL_ASCII``. +Typically, this can be changed to ``utf-8``, as a more useful default:: + + #client_encoding = sql_ascii # actually, defaults to database + # encoding + client_encoding = utf8 + +The ``client_encoding`` can be overridden for a session by executing the SQL: -It is not necessary, though is also harmless, to pass the "encoding" parameter -to :func:`.create_engine` when using pg8000. +SET CLIENT_ENCODING TO 'utf8'; + +SQLAlchemy will execute this SQL on all new connections based on the value +passed to :func:`.create_engine` using the ``client_encoding`` parameter:: + + engine = create_engine( + "postgresql+pg8000://user:pass@host/dbname", client_encoding='utf8') .. 
_pg8000_isolation_level: @@ -133,6 +146,10 @@ class PGDialect_pg8000(PGDialect): } ) + def __init__(self, client_encoding=None, **kwargs): + PGDialect.__init__(self, **kwargs) + self.client_encoding = client_encoding + def initialize(self, connection): if self.dbapi and hasattr(self.dbapi, '__version__'): self._dbapi_version = tuple([ @@ -181,6 +198,16 @@ class PGDialect_pg8000(PGDialect): (level, self.name, ", ".join(self._isolation_lookup)) ) + def set_client_encoding(self, connection, client_encoding): + # adjust for ConnectionFairy possibly being present + if hasattr(connection, 'connection'): + connection = connection.connection + + cursor = connection.cursor() + cursor.execute("SET CLIENT_ENCODING TO '" + client_encoding + "'") + cursor.execute("COMMIT") + cursor.close() + def do_begin_twophase(self, connection, xid): connection.connection.tpc_begin((0, xid, '')) @@ -198,4 +225,24 @@ class PGDialect_pg8000(PGDialect): def do_recover_twophase(self, connection): return [row[1] for row in connection.connection.tpc_recover()] + def on_connect(self): + fns = [] + if self.client_encoding is not None: + def on_connect(conn): + self.set_client_encoding(conn, self.client_encoding) + fns.append(on_connect) + + if self.isolation_level is not None: + def on_connect(conn): + self.set_isolation_level(conn, self.isolation_level) + fns.append(on_connect) + + if len(fns) > 0: + def on_connect(conn): + for fn in fns: + fn(conn) + return on_connect + else: + return None + dialect = PGDialect_pg8000 -- cgit v1.2.1 From c93706fa3319663234e3ab886b65f055bf9ed5da Mon Sep 17 00:00:00 2001 From: Tony Locke Date: Sun, 24 Aug 2014 15:15:17 +0100 Subject: Make pg8000 version detection more robust pg8000 uses Versioneer, which means that development versions have version strings that don't fit into the dotted triple number format. Released versions will always fit the triple format though. 
--- lib/sqlalchemy/dialects/postgresql/pg8000.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py index a76787016..17d83fa61 100644 --- a/lib/sqlalchemy/dialects/postgresql/pg8000.py +++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py @@ -71,6 +71,7 @@ from ... import types as sqltypes from .base import ( PGDialect, PGCompiler, PGIdentifierPreparer, PGExecutionContext, _DECIMAL_TYPES, _FLOAT_TYPES, _INT_TYPES) +import re class _PGNumeric(sqltypes.Numeric): @@ -151,15 +152,19 @@ class PGDialect_pg8000(PGDialect): self.client_encoding = client_encoding def initialize(self, connection): - if self.dbapi and hasattr(self.dbapi, '__version__'): - self._dbapi_version = tuple([ - int(x) for x in - self.dbapi.__version__.split(".")]) - else: - self._dbapi_version = (99, 99, 99) self.supports_sane_multi_rowcount = self._dbapi_version >= (1, 9, 14) super(PGDialect_pg8000, self).initialize(connection) + @util.memoized_property + def _dbapi_version(self): + if self.dbapi and hasattr(self.dbapi, '__version__'): + return tuple( + [ + int(x) for x in re.findall( + r'(\d+)(?:[-\.]?|$)', self.dbapi.__version__)]) + else: + return (99, 99, 99) + @classmethod def dbapi(cls): return __import__('pg8000') -- cgit v1.2.1 From 17e03a0ea86cd92816b4002a203b2b0b2c1a538a Mon Sep 17 00:00:00 2001 From: Tony Locke Date: Sat, 3 Jan 2015 16:59:17 +0000 Subject: Changed pg8000 dialect to cope with native JSON For versions > 1.10.1 pg8000 returns de-serialized JSON objects rather than a string. SQL parameters are still strings though. 
--- lib/sqlalchemy/dialects/postgresql/pg8000.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) (limited to 'lib/sqlalchemy') diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py index 17d83fa61..4bb376a96 100644 --- a/lib/sqlalchemy/dialects/postgresql/pg8000.py +++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py @@ -72,6 +72,7 @@ from .base import ( PGDialect, PGCompiler, PGIdentifierPreparer, PGExecutionContext, _DECIMAL_TYPES, _FLOAT_TYPES, _INT_TYPES) import re +from sqlalchemy.dialects.postgresql.json import JSON class _PGNumeric(sqltypes.Numeric): @@ -102,6 +103,15 @@ class _PGNumericNoBind(_PGNumeric): return None +class _PGJSON(JSON): + + def result_processor(self, dialect, coltype): + if dialect._dbapi_version > (1, 10, 1): + return None # Has native JSON + else: + return super(_PGJSON, self).result_processor(dialect, coltype) + + class PGExecutionContext_pg8000(PGExecutionContext): pass @@ -143,7 +153,8 @@ class PGDialect_pg8000(PGDialect): PGDialect.colspecs, { sqltypes.Numeric: _PGNumericNoBind, - sqltypes.Float: _PGNumeric + sqltypes.Float: _PGNumeric, + JSON: _PGJSON, } ) -- cgit v1.2.1