diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2021-06-15 15:13:34 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2021-06-17 09:48:52 -0400 |
| commit | 5b3e887f46afdbee312d5efd2a14f7c9b7eeac65 (patch) | |
| tree | 7c12dd2686dc3d26222383d39527b24613e49da3 /lib/sqlalchemy | |
| parent | 29fbbd9cebf5d4a4f21d01a74bcfb6dce923fe1b (diff) | |
| download | sqlalchemy-5b3e887f46afdbee312d5efd2a14f7c9b7eeac65.tar.gz | |
memoize current options and joins w with_entities/with_only_cols
Fixed further regressions in the same area as that of :ticket:`6052` where
loader options as well as invocations of methods like
:meth:`_orm.Query.join` would fail if the left side of the statement for
which the option/join depends upon were replaced by using the
:meth:`_orm.Query.with_entities` method, or when using 2.0 style queries
when using the :meth:`_sql.Select.with_only_columns` method. A new set of
state has been added to the objects which tracks the "left" entities that
the options / join were made against which is memoized when the lead
entities are changed.
Fixes: #6503
Fixes: #6253
Change-Id: I211b2af98b0b20d1263fb15dc513884dcc5de6a4
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/orm/context.py | 203 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/interfaces.py | 12 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/query.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/strategy_options.py | 30 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 28 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 86 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/traversals.py | 18 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 35 |
8 files changed, 322 insertions, 94 deletions
diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index e4448f953..321eeada0 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -322,10 +322,16 @@ class ORMCompileState(CompileState): return loading.instances(result, querycontext) @property - def _mapper_entities(self): - return ( + def _lead_mapper_entities(self): + """return all _MapperEntity objects in the lead entities collection. + + Does **not** include entities that have been replaced by + with_entities(), with_only_columns() + + """ + return [ ent for ent in self._entities if isinstance(ent, _MapperEntity) - ) + ] def _create_with_polymorphic_adapter(self, ext_info, selectable): if ( @@ -405,7 +411,9 @@ class ORMFromStatementCompileState(ORMCompileState): self.use_legacy_query_style, ) - _QueryEntity.to_compile_state(self, statement_container._raw_columns) + _QueryEntity.to_compile_state( + self, statement_container._raw_columns, self._entities + ) self.current_path = statement_container._compile_options._current_path @@ -477,6 +485,8 @@ class ORMFromStatementCompileState(ORMCompileState): class ORMSelectCompileState(ORMCompileState, SelectState): _joinpath = _joinpoint = _EMPTY_DICT + _memoized_entities = _EMPTY_DICT + _from_obj_alias = None _has_mapper_entities = False @@ -572,15 +582,48 @@ class ORMSelectCompileState(ORMCompileState, SelectState): statement._label_style, self.use_legacy_query_style ) - _QueryEntity.to_compile_state(self, select_statement._raw_columns) + if select_statement._memoized_select_entities: + self._memoized_entities = { + memoized_entities: _QueryEntity.to_compile_state( + self, + memoized_entities._raw_columns, + [], + ) + for memoized_entities in ( + select_statement._memoized_select_entities + ) + } + + _QueryEntity.to_compile_state( + self, select_statement._raw_columns, self._entities + ) self.current_path = select_statement._compile_options._current_path self.eager_order_by = () - if toplevel and select_statement._with_options: + if toplevel and ( + select_statement._with_options + or select_statement._memoized_select_entities + ): self.attributes = {"_unbound_load_dedupes": set()} + for ( + memoized_entities + ) in select_statement._memoized_select_entities: + for opt in memoized_entities._with_options: + if opt._is_compile_state: + opt.process_compile_state_replaced_entities( + self, + [ + ent + for ent in self._memoized_entities[ + memoized_entities + ] + if isinstance(ent, _MapperEntity) + ], + ) + for opt in self.select_statement._with_options: if opt._is_compile_state: opt.process_compile_state(self) @@ -626,11 +669,23 @@ class ORMSelectCompileState(ORMCompileState, SelectState): if self.compile_options._set_base_alias: self._set_select_from_alias() + for memoized_entities in query._memoized_select_entities: + if memoized_entities._setup_joins: + self._join( + memoized_entities._setup_joins, + self._memoized_entities[memoized_entities], + ) + if memoized_entities._legacy_setup_joins: + self._legacy_join( + memoized_entities._legacy_setup_joins, + self._memoized_entities[memoized_entities], + ) + if query._setup_joins: - self._join(query._setup_joins) + self._join(query._setup_joins, self._entities) if query._legacy_setup_joins: - self._legacy_join(query._legacy_setup_joins) + self._legacy_join(query._legacy_setup_joins, self._entities) current_adapter = self._get_current_adapter() @@ -782,7 +837,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): # entities will also set up polymorphic adapters for mappers # that have with_polymorphic configured - _QueryEntity.to_compile_state(self, query._raw_columns) + _QueryEntity.to_compile_state(self, query._raw_columns, self._entities) return self @classmethod @@ -921,7 +976,18 @@ class ORMSelectCompileState(ORMCompileState, SelectState): def _all_equivs(self): equivs = {} - for ent in self._mapper_entities: + + for memoized_entities in self._memoized_entities.values(): + for ent in [ + ent + for ent in memoized_entities + if isinstance(ent, _MapperEntity) + ]: + equivs.update(ent.mapper._equivalent_columns) + + for ent in [ + ent for ent in self._entities if isinstance(ent, _MapperEntity) + ]: equivs.update(ent.mapper._equivalent_columns) return equivs @@ -1211,7 +1277,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): return _adapt_clause - def _join(self, args): + def _join(self, args, entities_collection): for (right, onclause, from_, flags) in args: isouter = flags["isouter"] full = flags["full"] @@ -1316,6 +1382,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): # figure out the final "left" and "right" sides and create an # ORMJoin to add to our _from_obj tuple self._join_left_to_right( + entities_collection, left, right, onclause, @@ -1326,7 +1393,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): full, ) - def _legacy_join(self, args): + def _legacy_join(self, args, entities_collection): """consumes arguments from join() or outerjoin(), places them into a consistent format with which to form the actual JOIN constructs. @@ -1474,6 +1541,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): # figure out the final "left" and "right" sides and create an # ORMJoin to add to our _from_obj tuple self._join_left_to_right( + entities_collection, left, right, onclause, @@ -1489,6 +1557,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): def _join_left_to_right( self, + entities_collection, left, right, onclause, @@ -1513,7 +1582,9 @@ class ORMSelectCompileState(ORMCompileState, SelectState): left, replace_from_obj_index, use_entity_index, - ) = self._join_determine_implicit_left_side(left, right, onclause) + ) = self._join_determine_implicit_left_side( + entities_collection, left, right, onclause + ) else: # left is given via a relationship/name, or as explicit left side. # Determine where in our @@ -1522,7 +1593,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): ( replace_from_obj_index, use_entity_index, - ) = self._join_place_explicit_left_side(left) + ) = self._join_place_explicit_left_side(entities_collection, left) if left is right and not create_aliases: raise sa_exc.InvalidRequestError( @@ -1568,9 +1639,9 @@ class ORMSelectCompileState(ORMCompileState, SelectState): # entity_zero.selectable, but if with_polymorphic() were used # might be distinct assert isinstance( - self._entities[use_entity_index], _MapperEntity + entities_collection[use_entity_index], _MapperEntity ) - left_clause = self._entities[use_entity_index].selectable + left_clause = entities_collection[use_entity_index].selectable else: left_clause = left @@ -1585,7 +1656,9 @@ class ORMSelectCompileState(ORMCompileState, SelectState): ) ] - def _join_determine_implicit_left_side(self, left, right, onclause): + def _join_determine_implicit_left_side( + self, entities_collection, left, right, onclause + ): """When join conditions don't express the left side explicitly, determine if an existing FROM or entity in this query can serve as the left hand side. @@ -1635,12 +1708,12 @@ class ORMSelectCompileState(ORMCompileState, SelectState): "to help resolve the ambiguity." % (right,) ) - elif self._entities: + elif entities_collection: # we have no explicit FROMs, so the implicit left has to # come from our list of entities. potential = {} - for entity_index, ent in enumerate(self._entities): + for entity_index, ent in enumerate(entities_collection): entity = ent.entity_zero_or_selectable if entity is None: continue @@ -1689,7 +1762,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): return left, replace_from_obj_index, use_entity_index - def _join_place_explicit_left_side(self, left): + def _join_place_explicit_left_side(self, entities_collection, left): """When join conditions express a left side explicitly, determine where in our existing list of FROM clauses we should join towards, or if we need to make a new join, and if so is it from one of our @@ -1743,10 +1816,10 @@ class ORMSelectCompileState(ORMCompileState, SelectState): # aliasing / adaptation rules present on that entity if any if ( replace_from_obj_index is None - and self._entities + and entities_collection and hasattr(l_info, "mapper") ): - for idx, ent in enumerate(self._entities): + for idx, ent in enumerate(entities_collection): # TODO: should we be checking for multiple mapper entities # matching? if isinstance(ent, _MapperEntity) and ent.corresponds_to(left): @@ -2194,11 +2267,14 @@ class _QueryEntity(object): __slots__ = () @classmethod - def to_compile_state(cls, compile_state, entities): + def to_compile_state(cls, compile_state, entities, entities_collection): + for idx, entity in enumerate(entities): if entity._is_lambda_element: if entity._is_sequence: - cls.to_compile_state(compile_state, entity._resolved) + cls.to_compile_state( + compile_state, entity._resolved, entities_collection + ) continue else: entity = entity._resolved @@ -2206,26 +2282,38 @@ class _QueryEntity(object): if entity.is_clause_element: if entity.is_selectable: if "parententity" in entity._annotations: - _MapperEntity(compile_state, entity) + _MapperEntity( + compile_state, entity, entities_collection + ) else: _ColumnEntity._for_columns( - compile_state, entity._select_iterable, idx + compile_state, + entity._select_iterable, + entities_collection, + idx, ) else: if entity._annotations.get("bundle", False): - _BundleEntity(compile_state, entity) + _BundleEntity( + compile_state, entity, entities_collection + ) elif entity._is_clause_list: # this is legacy only - test_composites.py # test_query_cols_legacy _ColumnEntity._for_columns( - compile_state, entity._select_iterable, idx + compile_state, + entity._select_iterable, + entities_collection, + idx, ) else: _ColumnEntity._for_columns( - compile_state, [entity], idx + compile_state, [entity], entities_collection, idx ) elif entity.is_bundle: - _BundleEntity(compile_state, entity) + _BundleEntity(compile_state, entity, entities_collection) + + return entities_collection class _MapperEntity(_QueryEntity): @@ -2244,8 +2332,8 @@ class _MapperEntity(_QueryEntity): "_polymorphic_discriminator", ) - def __init__(self, compile_state, entity): - compile_state._entities.append(self) + def __init__(self, compile_state, entity, entities_collection): + entities_collection.append(self) if compile_state._primary_entity is None: compile_state._primary_entity = self compile_state._has_mapper_entities = True @@ -2418,7 +2506,12 @@ class _BundleEntity(_QueryEntity): ) def __init__( - self, compile_state, expr, setup_entities=True, parent_bundle=None + self, + compile_state, + expr, + entities_collection, + setup_entities=True, + parent_bundle=None, ): compile_state._has_orm_entities = True @@ -2426,7 +2519,7 @@ class _BundleEntity(_QueryEntity): if parent_bundle: parent_bundle._entities.append(self) else: - compile_state._entities.append(self) + entities_collection.append(self) if isinstance( expr, (attributes.QueryableAttribute, interfaces.PropComparator) @@ -2443,12 +2536,26 @@ class _BundleEntity(_QueryEntity): if setup_entities: for expr in bundle.exprs: if "bundle" in expr._annotations: - _BundleEntity(compile_state, expr, parent_bundle=self) + _BundleEntity( + compile_state, + expr, + entities_collection, + parent_bundle=self, + ) elif isinstance(expr, Bundle): - _BundleEntity(compile_state, expr, parent_bundle=self) + _BundleEntity( + compile_state, + expr, + entities_collection, + parent_bundle=self, + ) else: _ORMColumnEntity._for_columns( - compile_state, [expr], None, parent_bundle=self + compile_state, + [expr], + entities_collection, + None, + parent_bundle=self, ) self.supports_single_entity = self.bundle.single_entity @@ -2516,7 +2623,12 @@ class _ColumnEntity(_QueryEntity): @classmethod def _for_columns( - cls, compile_state, columns, raw_column_index, parent_bundle=None + cls, + compile_state, + columns, + entities_collection, + raw_column_index, + parent_bundle=None, ): for column in columns: annotations = column._annotations @@ -2532,6 +2644,7 @@ class _ColumnEntity(_QueryEntity): _IdentityTokenEntity( compile_state, column, + entities_collection, _entity, raw_column_index, parent_bundle=parent_bundle, @@ -2540,6 +2653,7 @@ class _ColumnEntity(_QueryEntity): _ORMColumnEntity( compile_state, column, + entities_collection, _entity, raw_column_index, parent_bundle=parent_bundle, @@ -2548,6 +2662,7 @@ class _ColumnEntity(_QueryEntity): _RawColumnEntity( compile_state, column, + entities_collection, raw_column_index, parent_bundle=parent_bundle, ) @@ -2630,7 +2745,12 @@ class _RawColumnEntity(_ColumnEntity): ) def __init__( - self, compile_state, column, raw_column_index, parent_bundle=None + self, + compile_state, + column, + entities_collection, + raw_column_index, + parent_bundle=None, ): self.expr = column self.raw_column_index = raw_column_index @@ -2643,7 +2763,7 @@ class _RawColumnEntity(_ColumnEntity): if parent_bundle: parent_bundle._entities.append(self) else: - compile_state._entities.append(self) + entities_collection.append(self) self.column = column self.entity_zero_or_selectable = ( @@ -2690,6 +2810,7 @@ class _ORMColumnEntity(_ColumnEntity): self, compile_state, column, + entities_collection, parententity, raw_column_index, parent_bundle=None, @@ -2729,7 +2850,7 @@ class _ORMColumnEntity(_ColumnEntity): if parent_bundle: parent_bundle._entities.append(self) else: - compile_state._entities.append(self) + entities_collection.append(self) compile_state._has_orm_entities = True diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index c9a601f99..28b4bfb2d 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -750,6 +750,18 @@ class LoaderOption(ORMOption): _is_compile_state = True + def process_compile_state_replaced_entities( + self, compile_state, mapper_entities + ): + """Apply a modification to a given :class:`.CompileState`, + given entities that were replaced by with_only_columns() or + with_entities(). + + .. versionadded:: 1.4.19 + + """ + self.process_compile_state(compile_state) + def process_compile_state(self, compile_state): """Apply a modification to a given :class:`.CompileState`.""" diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index cacfb8d84..7ba31fa7a 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -57,6 +57,7 @@ from ..sql.annotation import SupportsCloneAnnotations from ..sql.base import _entity_namespace_key from ..sql.base import _generative from ..sql.base import Executable +from ..sql.selectable import _MemoizedSelectEntities from ..sql.selectable import _SelectFromElements from ..sql.selectable import ForUpdateArg from ..sql.selectable import GroupedElement @@ -125,6 +126,8 @@ class Query( _legacy_setup_joins = () _label_style = LABEL_STYLE_LEGACY_ORM + _memoized_select_entities = () + _compile_options = ORMCompileState.default_compile_options load_options = QueryContext.default_load_options @@ -1433,6 +1436,7 @@ class Query( limit(1) """ + _MemoizedSelectEntities._generate_for_statement(self) self._set_entities(entities) @_generative diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index e371442fd..91e627525 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -172,13 +172,32 @@ class Load(Generative, LoaderOption): _of_type = None _extra_criteria = () + def process_compile_state_replaced_entities( + self, compile_state, mapper_entities + ): + if not compile_state.compile_options._enable_eagerloads: + return + + # process is being run here so that the options given are validated + # against what the lead entities were, as well as to accommodate + # for the entities having been replaced with equivalents + self._process( + compile_state, + mapper_entities, + not bool(compile_state.current_path), + ) + def process_compile_state(self, compile_state): if not compile_state.compile_options._enable_eagerloads: return - self._process(compile_state, not bool(compile_state.current_path)) + self._process( + compile_state, + compile_state._lead_mapper_entities, + not bool(compile_state.current_path), + ) - def _process(self, compile_state, raiseerr): + def _process(self, compile_state, mapper_entities, raiseerr): is_refresh = compile_state.compile_options._for_refresh_state current_path = compile_state.current_path if current_path: @@ -700,7 +719,7 @@ class _UnboundLoad(Load): state["path"] = tuple(ret) self.__dict__ = state - def _process(self, compile_state, raiseerr): + def _process(self, compile_state, mapper_entities, raiseerr): dedupes = compile_state.attributes["_unbound_load_dedupes"] is_refresh = compile_state.compile_options._for_refresh_state for val in self._to_bind: @@ -709,10 +728,7 @@ class _UnboundLoad(Load): if is_refresh and not val.propagate_to_loaders: continue val._bind_loader( - [ - ent.entity_zero - for ent in compile_state._mapper_entities - ], + [ent.entity_zero for ent in mapper_entities], compile_state.current_path, compile_state.attributes, raiseerr, diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 213f47c40..709106b6b 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -32,7 +32,6 @@ from .base import NO_ARG from .base import PARSE_AUTOCOMMIT from .base import SingletonConstant from .coercions import _document_text_coercion -from .traversals import _get_children from .traversals import HasCopyInternals from .traversals import MemoizedHasCacheKey from .traversals import NO_CACHE @@ -389,33 +388,6 @@ class ClauseElement( """ return traversals.compare(self, other, **kw) - def get_children(self, omit_attrs=(), **kw): - r"""Return immediate child :class:`.visitors.Traversible` - elements of this :class:`.visitors.Traversible`. - - This is used for visit traversal. - - \**kw may contain flags that change the collection that is - returned, for example to return a subset of items in order to - cut down on larger traversals, or to return child items from a - different context (such as schema-level collections instead of - clause-level). - - """ - try: - traverse_internals = self._traverse_internals - except AttributeError: - # user-defined classes may not have a _traverse_internals - return [] - - return itertools.chain.from_iterable( - meth(obj, **kw) - for attrname, obj, meth in _get_children.run_generated_dispatch( - self, traverse_internals, "_generated_get_children_traversal" - ) - if attrname not in omit_attrs and obj is not None - ) - def self_group(self, against=None): """Apply a 'grouping' to this :class:`_expression.ClauseElement`. diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 1610191d1..e1dee091b 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -18,7 +18,9 @@ from operator import attrgetter from . import coercions from . import operators from . import roles +from . import traversals from . import type_api +from . import visitors from .annotation import Annotated from .annotation import SupportsCloneAnnotations from .base import _clone @@ -4131,8 +4133,13 @@ class SelectState(util.MemoizedSlots, CompileState): self.statement = statement self.from_clauses = statement._from_obj + for memoized_entities in statement._memoized_select_entities: + self._setup_joins( + memoized_entities._setup_joins, memoized_entities._raw_columns + ) + if statement._setup_joins: - self._setup_joins(statement._setup_joins) + self._setup_joins(statement._setup_joins, statement._raw_columns) self.froms = self._get_froms(statement) @@ -4361,7 +4368,7 @@ class SelectState(util.MemoizedSlots, CompileState): def all_selected_columns(cls, statement): return [c for c in _select_iterables(statement._raw_columns)] - def _setup_joins(self, args): + def _setup_joins(self, args, raw_columns): for (right, onclause, left, flags) in args: isouter = flags["isouter"] full = flags["full"] @@ -4371,7 +4378,7 @@ class SelectState(util.MemoizedSlots, CompileState): left, replace_from_obj_index, ) = self._join_determine_implicit_left_side( - left, right, onclause + raw_columns, left, right, onclause ) else: (replace_from_obj_index) = self._join_place_explicit_left_side( @@ -4403,7 +4410,9 @@ class SelectState(util.MemoizedSlots, CompileState): ) @util.preload_module("sqlalchemy.sql.util") - def _join_determine_implicit_left_side(self, left, right, onclause): + def _join_determine_implicit_left_side( + self, raw_columns, left, right, onclause + ): """When join conditions don't express the left side explicitly, determine if an existing FROM or entity in this query can serve as the left hand side. @@ -4431,10 +4440,7 @@ class SelectState(util.MemoizedSlots, CompileState): for from_clause in itertools.chain( itertools.chain.from_iterable( - [ - element._from_objects - for element in statement._raw_columns - ] + [element._from_objects for element in raw_columns] ), itertools.chain.from_iterable( [ @@ -4531,6 +4537,47 @@ class _SelectFromElements(object): yield element +class _MemoizedSelectEntities( + traversals.HasCacheKey, traversals.HasCopyInternals, visitors.Traversible +): + __visit_name__ = "memoized_select_entities" + + _traverse_internals = [ + ("_raw_columns", InternalTraversal.dp_clauseelement_list), + ("_setup_joins", InternalTraversal.dp_setup_join_tuple), + ("_legacy_setup_joins", InternalTraversal.dp_setup_join_tuple), + ("_with_options", InternalTraversal.dp_executable_options), + ] + + _annotations = util.EMPTY_DICT + + def _clone(self, **kw): + c = self.__class__.__new__(self.__class__) + c.__dict__ = {k: v for k, v in self.__dict__.items()} + c._is_clone_of = self + return c + + @classmethod + def _generate_for_statement(cls, select_stmt): + if ( + select_stmt._setup_joins + or select_stmt._legacy_setup_joins + or select_stmt._with_options + ): + self = _MemoizedSelectEntities() + self._raw_columns = select_stmt._raw_columns + self._setup_joins = select_stmt._setup_joins + self._legacy_setup_joins = select_stmt._legacy_setup_joins + self._with_options = select_stmt._with_options + + select_stmt._memoized_select_entities += (self,) + select_stmt._raw_columns = ( + select_stmt._setup_joins + ) = ( + select_stmt._legacy_setup_joins + ) = select_stmt._with_options = () + + class Select( HasPrefixes, HasSuffixes, @@ -4559,6 +4606,7 @@ class Select( _setup_joins = () _legacy_setup_joins = () + _memoized_select_entities = () _distinct = False _distinct_on = () @@ -4574,6 +4622,10 @@ class Select( _traverse_internals = ( [ ("_raw_columns", InternalTraversal.dp_clauseelement_list), + ( + "_memoized_select_entities", + InternalTraversal.dp_memoized_select_entities, + ), ("_from_obj", InternalTraversal.dp_clauseelement_list), ("_where_criteria", InternalTraversal.dp_clauseelement_tuple), ("_having_criteria", InternalTraversal.dp_clauseelement_tuple), @@ -5461,16 +5513,14 @@ class Select( # is the case for now. self._assert_no_memoizations() - rc = [] - for c in coercions._expression_collection_was_a_list( - "columns", "Select.with_only_columns", columns - ): - c = coercions.expect(roles.ColumnsClauseRole, c) - # TODO: why are we doing this here? - if isinstance(c, ScalarSelect): - c = c.self_group(against=operators.comma_op) - rc.append(c) - self._raw_columns = rc + _MemoizedSelectEntities._generate_for_statement(self) + + self._raw_columns = [ + coercions.expect(roles.ColumnsClauseRole, c) + for c in coercions._expression_collection_was_a_list( + "columns", "Select.with_only_columns", columns + ) + ] @property def whereclause(self): diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 35f2bd62f..a86d16ef4 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -194,6 +194,8 @@ class HasCacheKey(object): elif ( meth is InternalTraversal.dp_clauseelement_list or meth is InternalTraversal.dp_clauseelement_tuple + or meth + is InternalTraversal.dp_memoized_select_entities ): result += ( attrname, @@ -409,6 +411,9 @@ class _CacheKey(ExtendedInternalTraversal): visit_clauseelement_list = InternalTraversal.dp_clauseelement_list visit_annotations_key = InternalTraversal.dp_annotations_key visit_clauseelement_tuple = InternalTraversal.dp_clauseelement_tuple + visit_memoized_select_entities = ( + InternalTraversal.dp_memoized_select_entities + ) visit_string = ( visit_boolean @@ -799,6 +804,9 @@ class _CopyInternals(InternalTraversal): for (target, onclause, from_, flags) in element ) + def visit_memoized_select_entities(self, attrname, parent, element, **kw): + return self.visit_clauseelement_tuple(attrname, parent, element, **kw) + def visit_dml_ordered_values( self, attrname, parent, element, clone=_clone, **kw ): @@ -919,6 +927,9 @@ class _GetChildren(InternalTraversal): if onclause is not None and not isinstance(onclause, str): yield _flatten_clauseelement(onclause) + def visit_memoized_select_entities(self, element, **kw): + return self.visit_clauseelement_tuple(element, **kw) + def visit_dml_ordered_values(self, element, **kw): for k, v in element: if hasattr(k, "__clause_element__"): @@ -1265,6 +1276,13 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): self.stack.append((l_onclause, r_onclause)) self.stack.append((l_from, r_from)) + def visit_memoized_select_entities( + self, attrname, left_parent, left, right_parent, right, **kw + ): + return self.visit_clauseelement_tuple( + attrname, left_parent, left, right_parent, right, **kw + ) + def visit_table_hint_list( self, attrname, left_parent, left, right_parent, right, **kw ): diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 93ee8eb1c..c750c546a 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -24,6 +24,7 @@ http://techspot.zzzeek.org/2008/01/23/expression-transformations/ . """ from collections import deque +import itertools import operator from .. import exc @@ -119,6 +120,38 @@ class Traversible(util.with_metaclass(TraversibleType)): """ + @util.preload_module("sqlalchemy.sql.traversals") + def get_children(self, omit_attrs=(), **kw): + r"""Return immediate child :class:`.visitors.Traversible` + elements of this :class:`.visitors.Traversible`. + + This is used for visit traversal. + + \**kw may contain flags that change the collection that is + returned, for example to return a subset of items in order to + cut down on larger traversals, or to return child items from a + different context (such as schema-level collections instead of + clause-level). + + """ + + traversals = util.preloaded.sql_traversals + + try: + traverse_internals = self._traverse_internals + except AttributeError: + # user-defined classes may not have a _traverse_internals + return [] + + dispatch = traversals._get_children.run_generated_dispatch + return itertools.chain.from_iterable( + meth(obj, **kw) + for attrname, obj, meth in dispatch( + self, traverse_internals, "_generated_get_children_traversal" + ) + if attrname not in omit_attrs and obj is not None + ) + class _InternalTraversalType(type): def __init__(cls, clsname, bases, clsdict): @@ -393,6 +426,8 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)): dp_setup_join_tuple = symbol("SJ") + dp_memoized_select_entities = symbol("ME") + dp_statement_hint_list = symbol("SH") """Visit the ``_statement_hints`` collection of a :class:`_expression.Select` |
