summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2021-06-15 15:13:34 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2021-06-17 09:48:52 -0400
commit5b3e887f46afdbee312d5efd2a14f7c9b7eeac65 (patch)
tree7c12dd2686dc3d26222383d39527b24613e49da3 /lib/sqlalchemy
parent29fbbd9cebf5d4a4f21d01a74bcfb6dce923fe1b (diff)
downloadsqlalchemy-5b3e887f46afdbee312d5efd2a14f7c9b7eeac65.tar.gz
memoize current options and joins w with_entities/with_only_cols
Fixed further regressions in the same area as that of :ticket:`6052` where loader options as well as invocations of methods like :meth:`_orm.Query.join` would fail if the left side of the statement for which the option/join depends upon were replaced by using the :meth:`_orm.Query.with_entities` method, or when using 2.0 style queries when using the :meth:`_sql.Select.with_only_columns` method. A new set of state has been added to the objects which tracks the "left" entities that the options / join were made against which is memoized when the lead entities are changed. Fixes: #6503 Fixes: #6253 Change-Id: I211b2af98b0b20d1263fb15dc513884dcc5de6a4
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/orm/context.py203
-rw-r--r--lib/sqlalchemy/orm/interfaces.py12
-rw-r--r--lib/sqlalchemy/orm/query.py4
-rw-r--r--lib/sqlalchemy/orm/strategy_options.py30
-rw-r--r--lib/sqlalchemy/sql/elements.py28
-rw-r--r--lib/sqlalchemy/sql/selectable.py86
-rw-r--r--lib/sqlalchemy/sql/traversals.py18
-rw-r--r--lib/sqlalchemy/sql/visitors.py35
8 files changed, 322 insertions, 94 deletions
diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py
index e4448f953..321eeada0 100644
--- a/lib/sqlalchemy/orm/context.py
+++ b/lib/sqlalchemy/orm/context.py
@@ -322,10 +322,16 @@ class ORMCompileState(CompileState):
return loading.instances(result, querycontext)
@property
- def _mapper_entities(self):
- return (
+ def _lead_mapper_entities(self):
+ """return all _MapperEntity objects in the lead entities collection.
+
+ Does **not** include entities that have been replaced by
+ with_entities(), with_only_columns()
+
+ """
+ return [
ent for ent in self._entities if isinstance(ent, _MapperEntity)
- )
+ ]
def _create_with_polymorphic_adapter(self, ext_info, selectable):
if (
@@ -405,7 +411,9 @@ class ORMFromStatementCompileState(ORMCompileState):
self.use_legacy_query_style,
)
- _QueryEntity.to_compile_state(self, statement_container._raw_columns)
+ _QueryEntity.to_compile_state(
+ self, statement_container._raw_columns, self._entities
+ )
self.current_path = statement_container._compile_options._current_path
@@ -477,6 +485,8 @@ class ORMFromStatementCompileState(ORMCompileState):
class ORMSelectCompileState(ORMCompileState, SelectState):
_joinpath = _joinpoint = _EMPTY_DICT
+ _memoized_entities = _EMPTY_DICT
+
_from_obj_alias = None
_has_mapper_entities = False
@@ -572,15 +582,48 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
statement._label_style, self.use_legacy_query_style
)
- _QueryEntity.to_compile_state(self, select_statement._raw_columns)
+ if select_statement._memoized_select_entities:
+ self._memoized_entities = {
+ memoized_entities: _QueryEntity.to_compile_state(
+ self,
+ memoized_entities._raw_columns,
+ [],
+ )
+ for memoized_entities in (
+ select_statement._memoized_select_entities
+ )
+ }
+
+ _QueryEntity.to_compile_state(
+ self, select_statement._raw_columns, self._entities
+ )
self.current_path = select_statement._compile_options._current_path
self.eager_order_by = ()
- if toplevel and select_statement._with_options:
+ if toplevel and (
+ select_statement._with_options
+ or select_statement._memoized_select_entities
+ ):
self.attributes = {"_unbound_load_dedupes": set()}
+ for (
+ memoized_entities
+ ) in select_statement._memoized_select_entities:
+ for opt in memoized_entities._with_options:
+ if opt._is_compile_state:
+ opt.process_compile_state_replaced_entities(
+ self,
+ [
+ ent
+ for ent in self._memoized_entities[
+ memoized_entities
+ ]
+ if isinstance(ent, _MapperEntity)
+ ],
+ )
+
for opt in self.select_statement._with_options:
if opt._is_compile_state:
opt.process_compile_state(self)
@@ -626,11 +669,23 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
if self.compile_options._set_base_alias:
self._set_select_from_alias()
+ for memoized_entities in query._memoized_select_entities:
+ if memoized_entities._setup_joins:
+ self._join(
+ memoized_entities._setup_joins,
+ self._memoized_entities[memoized_entities],
+ )
+ if memoized_entities._legacy_setup_joins:
+ self._legacy_join(
+ memoized_entities._legacy_setup_joins,
+ self._memoized_entities[memoized_entities],
+ )
+
if query._setup_joins:
- self._join(query._setup_joins)
+ self._join(query._setup_joins, self._entities)
if query._legacy_setup_joins:
- self._legacy_join(query._legacy_setup_joins)
+ self._legacy_join(query._legacy_setup_joins, self._entities)
current_adapter = self._get_current_adapter()
@@ -782,7 +837,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
# entities will also set up polymorphic adapters for mappers
# that have with_polymorphic configured
- _QueryEntity.to_compile_state(self, query._raw_columns)
+ _QueryEntity.to_compile_state(self, query._raw_columns, self._entities)
return self
@classmethod
@@ -921,7 +976,18 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
def _all_equivs(self):
equivs = {}
- for ent in self._mapper_entities:
+
+ for memoized_entities in self._memoized_entities.values():
+ for ent in [
+ ent
+ for ent in memoized_entities
+ if isinstance(ent, _MapperEntity)
+ ]:
+ equivs.update(ent.mapper._equivalent_columns)
+
+ for ent in [
+ ent for ent in self._entities if isinstance(ent, _MapperEntity)
+ ]:
equivs.update(ent.mapper._equivalent_columns)
return equivs
@@ -1211,7 +1277,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
return _adapt_clause
- def _join(self, args):
+ def _join(self, args, entities_collection):
for (right, onclause, from_, flags) in args:
isouter = flags["isouter"]
full = flags["full"]
@@ -1316,6 +1382,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
# figure out the final "left" and "right" sides and create an
# ORMJoin to add to our _from_obj tuple
self._join_left_to_right(
+ entities_collection,
left,
right,
onclause,
@@ -1326,7 +1393,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
full,
)
- def _legacy_join(self, args):
+ def _legacy_join(self, args, entities_collection):
"""consumes arguments from join() or outerjoin(), places them into a
consistent format with which to form the actual JOIN constructs.
@@ -1474,6 +1541,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
# figure out the final "left" and "right" sides and create an
# ORMJoin to add to our _from_obj tuple
self._join_left_to_right(
+ entities_collection,
left,
right,
onclause,
@@ -1489,6 +1557,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
def _join_left_to_right(
self,
+ entities_collection,
left,
right,
onclause,
@@ -1513,7 +1582,9 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
left,
replace_from_obj_index,
use_entity_index,
- ) = self._join_determine_implicit_left_side(left, right, onclause)
+ ) = self._join_determine_implicit_left_side(
+ entities_collection, left, right, onclause
+ )
else:
# left is given via a relationship/name, or as explicit left side.
# Determine where in our
@@ -1522,7 +1593,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
(
replace_from_obj_index,
use_entity_index,
- ) = self._join_place_explicit_left_side(left)
+ ) = self._join_place_explicit_left_side(entities_collection, left)
if left is right and not create_aliases:
raise sa_exc.InvalidRequestError(
@@ -1568,9 +1639,9 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
# entity_zero.selectable, but if with_polymorphic() were used
# might be distinct
assert isinstance(
- self._entities[use_entity_index], _MapperEntity
+ entities_collection[use_entity_index], _MapperEntity
)
- left_clause = self._entities[use_entity_index].selectable
+ left_clause = entities_collection[use_entity_index].selectable
else:
left_clause = left
@@ -1585,7 +1656,9 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
)
]
- def _join_determine_implicit_left_side(self, left, right, onclause):
+ def _join_determine_implicit_left_side(
+ self, entities_collection, left, right, onclause
+ ):
"""When join conditions don't express the left side explicitly,
determine if an existing FROM or entity in this query
can serve as the left hand side.
@@ -1635,12 +1708,12 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
"to help resolve the ambiguity." % (right,)
)
- elif self._entities:
+ elif entities_collection:
# we have no explicit FROMs, so the implicit left has to
# come from our list of entities.
potential = {}
- for entity_index, ent in enumerate(self._entities):
+ for entity_index, ent in enumerate(entities_collection):
entity = ent.entity_zero_or_selectable
if entity is None:
continue
@@ -1689,7 +1762,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
return left, replace_from_obj_index, use_entity_index
- def _join_place_explicit_left_side(self, left):
+ def _join_place_explicit_left_side(self, entities_collection, left):
"""When join conditions express a left side explicitly, determine
where in our existing list of FROM clauses we should join towards,
or if we need to make a new join, and if so is it from one of our
@@ -1743,10 +1816,10 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
# aliasing / adaptation rules present on that entity if any
if (
replace_from_obj_index is None
- and self._entities
+ and entities_collection
and hasattr(l_info, "mapper")
):
- for idx, ent in enumerate(self._entities):
+ for idx, ent in enumerate(entities_collection):
# TODO: should we be checking for multiple mapper entities
# matching?
if isinstance(ent, _MapperEntity) and ent.corresponds_to(left):
@@ -2194,11 +2267,14 @@ class _QueryEntity(object):
__slots__ = ()
@classmethod
- def to_compile_state(cls, compile_state, entities):
+ def to_compile_state(cls, compile_state, entities, entities_collection):
+
for idx, entity in enumerate(entities):
if entity._is_lambda_element:
if entity._is_sequence:
- cls.to_compile_state(compile_state, entity._resolved)
+ cls.to_compile_state(
+ compile_state, entity._resolved, entities_collection
+ )
continue
else:
entity = entity._resolved
@@ -2206,26 +2282,38 @@ class _QueryEntity(object):
if entity.is_clause_element:
if entity.is_selectable:
if "parententity" in entity._annotations:
- _MapperEntity(compile_state, entity)
+ _MapperEntity(
+ compile_state, entity, entities_collection
+ )
else:
_ColumnEntity._for_columns(
- compile_state, entity._select_iterable, idx
+ compile_state,
+ entity._select_iterable,
+ entities_collection,
+ idx,
)
else:
if entity._annotations.get("bundle", False):
- _BundleEntity(compile_state, entity)
+ _BundleEntity(
+ compile_state, entity, entities_collection
+ )
elif entity._is_clause_list:
# this is legacy only - test_composites.py
# test_query_cols_legacy
_ColumnEntity._for_columns(
- compile_state, entity._select_iterable, idx
+ compile_state,
+ entity._select_iterable,
+ entities_collection,
+ idx,
)
else:
_ColumnEntity._for_columns(
- compile_state, [entity], idx
+ compile_state, [entity], entities_collection, idx
)
elif entity.is_bundle:
- _BundleEntity(compile_state, entity)
+ _BundleEntity(compile_state, entity, entities_collection)
+
+ return entities_collection
class _MapperEntity(_QueryEntity):
@@ -2244,8 +2332,8 @@ class _MapperEntity(_QueryEntity):
"_polymorphic_discriminator",
)
- def __init__(self, compile_state, entity):
- compile_state._entities.append(self)
+ def __init__(self, compile_state, entity, entities_collection):
+ entities_collection.append(self)
if compile_state._primary_entity is None:
compile_state._primary_entity = self
compile_state._has_mapper_entities = True
@@ -2418,7 +2506,12 @@ class _BundleEntity(_QueryEntity):
)
def __init__(
- self, compile_state, expr, setup_entities=True, parent_bundle=None
+ self,
+ compile_state,
+ expr,
+ entities_collection,
+ setup_entities=True,
+ parent_bundle=None,
):
compile_state._has_orm_entities = True
@@ -2426,7 +2519,7 @@ class _BundleEntity(_QueryEntity):
if parent_bundle:
parent_bundle._entities.append(self)
else:
- compile_state._entities.append(self)
+ entities_collection.append(self)
if isinstance(
expr, (attributes.QueryableAttribute, interfaces.PropComparator)
@@ -2443,12 +2536,26 @@ class _BundleEntity(_QueryEntity):
if setup_entities:
for expr in bundle.exprs:
if "bundle" in expr._annotations:
- _BundleEntity(compile_state, expr, parent_bundle=self)
+ _BundleEntity(
+ compile_state,
+ expr,
+ entities_collection,
+ parent_bundle=self,
+ )
elif isinstance(expr, Bundle):
- _BundleEntity(compile_state, expr, parent_bundle=self)
+ _BundleEntity(
+ compile_state,
+ expr,
+ entities_collection,
+ parent_bundle=self,
+ )
else:
_ORMColumnEntity._for_columns(
- compile_state, [expr], None, parent_bundle=self
+ compile_state,
+ [expr],
+ entities_collection,
+ None,
+ parent_bundle=self,
)
self.supports_single_entity = self.bundle.single_entity
@@ -2516,7 +2623,12 @@ class _ColumnEntity(_QueryEntity):
@classmethod
def _for_columns(
- cls, compile_state, columns, raw_column_index, parent_bundle=None
+ cls,
+ compile_state,
+ columns,
+ entities_collection,
+ raw_column_index,
+ parent_bundle=None,
):
for column in columns:
annotations = column._annotations
@@ -2532,6 +2644,7 @@ class _ColumnEntity(_QueryEntity):
_IdentityTokenEntity(
compile_state,
column,
+ entities_collection,
_entity,
raw_column_index,
parent_bundle=parent_bundle,
@@ -2540,6 +2653,7 @@ class _ColumnEntity(_QueryEntity):
_ORMColumnEntity(
compile_state,
column,
+ entities_collection,
_entity,
raw_column_index,
parent_bundle=parent_bundle,
@@ -2548,6 +2662,7 @@ class _ColumnEntity(_QueryEntity):
_RawColumnEntity(
compile_state,
column,
+ entities_collection,
raw_column_index,
parent_bundle=parent_bundle,
)
@@ -2630,7 +2745,12 @@ class _RawColumnEntity(_ColumnEntity):
)
def __init__(
- self, compile_state, column, raw_column_index, parent_bundle=None
+ self,
+ compile_state,
+ column,
+ entities_collection,
+ raw_column_index,
+ parent_bundle=None,
):
self.expr = column
self.raw_column_index = raw_column_index
@@ -2643,7 +2763,7 @@ class _RawColumnEntity(_ColumnEntity):
if parent_bundle:
parent_bundle._entities.append(self)
else:
- compile_state._entities.append(self)
+ entities_collection.append(self)
self.column = column
self.entity_zero_or_selectable = (
@@ -2690,6 +2810,7 @@ class _ORMColumnEntity(_ColumnEntity):
self,
compile_state,
column,
+ entities_collection,
parententity,
raw_column_index,
parent_bundle=None,
@@ -2729,7 +2850,7 @@ class _ORMColumnEntity(_ColumnEntity):
if parent_bundle:
parent_bundle._entities.append(self)
else:
- compile_state._entities.append(self)
+ entities_collection.append(self)
compile_state._has_orm_entities = True
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index c9a601f99..28b4bfb2d 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -750,6 +750,18 @@ class LoaderOption(ORMOption):
_is_compile_state = True
+ def process_compile_state_replaced_entities(
+ self, compile_state, mapper_entities
+ ):
+ """Apply a modification to a given :class:`.CompileState`,
+ given entities that were replaced by with_only_columns() or
+ with_entities().
+
+ .. versionadded:: 1.4.19
+
+ """
+ self.process_compile_state(compile_state)
+
def process_compile_state(self, compile_state):
"""Apply a modification to a given :class:`.CompileState`."""
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index cacfb8d84..7ba31fa7a 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -57,6 +57,7 @@ from ..sql.annotation import SupportsCloneAnnotations
from ..sql.base import _entity_namespace_key
from ..sql.base import _generative
from ..sql.base import Executable
+from ..sql.selectable import _MemoizedSelectEntities
from ..sql.selectable import _SelectFromElements
from ..sql.selectable import ForUpdateArg
from ..sql.selectable import GroupedElement
@@ -125,6 +126,8 @@ class Query(
_legacy_setup_joins = ()
_label_style = LABEL_STYLE_LEGACY_ORM
+ _memoized_select_entities = ()
+
_compile_options = ORMCompileState.default_compile_options
load_options = QueryContext.default_load_options
@@ -1433,6 +1436,7 @@ class Query(
limit(1)
"""
+ _MemoizedSelectEntities._generate_for_statement(self)
self._set_entities(entities)
@_generative
diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py
index e371442fd..91e627525 100644
--- a/lib/sqlalchemy/orm/strategy_options.py
+++ b/lib/sqlalchemy/orm/strategy_options.py
@@ -172,13 +172,32 @@ class Load(Generative, LoaderOption):
_of_type = None
_extra_criteria = ()
+ def process_compile_state_replaced_entities(
+ self, compile_state, mapper_entities
+ ):
+ if not compile_state.compile_options._enable_eagerloads:
+ return
+
+ # process is being run here so that the options given are validated
+ # against what the lead entities were, as well as to accommodate
+ # for the entities having been replaced with equivalents
+ self._process(
+ compile_state,
+ mapper_entities,
+ not bool(compile_state.current_path),
+ )
+
def process_compile_state(self, compile_state):
if not compile_state.compile_options._enable_eagerloads:
return
- self._process(compile_state, not bool(compile_state.current_path))
+ self._process(
+ compile_state,
+ compile_state._lead_mapper_entities,
+ not bool(compile_state.current_path),
+ )
- def _process(self, compile_state, raiseerr):
+ def _process(self, compile_state, mapper_entities, raiseerr):
is_refresh = compile_state.compile_options._for_refresh_state
current_path = compile_state.current_path
if current_path:
@@ -700,7 +719,7 @@ class _UnboundLoad(Load):
state["path"] = tuple(ret)
self.__dict__ = state
- def _process(self, compile_state, raiseerr):
+ def _process(self, compile_state, mapper_entities, raiseerr):
dedupes = compile_state.attributes["_unbound_load_dedupes"]
is_refresh = compile_state.compile_options._for_refresh_state
for val in self._to_bind:
@@ -709,10 +728,7 @@ class _UnboundLoad(Load):
if is_refresh and not val.propagate_to_loaders:
continue
val._bind_loader(
- [
- ent.entity_zero
- for ent in compile_state._mapper_entities
- ],
+ [ent.entity_zero for ent in mapper_entities],
compile_state.current_path,
compile_state.attributes,
raiseerr,
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 213f47c40..709106b6b 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -32,7 +32,6 @@ from .base import NO_ARG
from .base import PARSE_AUTOCOMMIT
from .base import SingletonConstant
from .coercions import _document_text_coercion
-from .traversals import _get_children
from .traversals import HasCopyInternals
from .traversals import MemoizedHasCacheKey
from .traversals import NO_CACHE
@@ -389,33 +388,6 @@ class ClauseElement(
"""
return traversals.compare(self, other, **kw)
- def get_children(self, omit_attrs=(), **kw):
- r"""Return immediate child :class:`.visitors.Traversible`
- elements of this :class:`.visitors.Traversible`.
-
- This is used for visit traversal.
-
- \**kw may contain flags that change the collection that is
- returned, for example to return a subset of items in order to
- cut down on larger traversals, or to return child items from a
- different context (such as schema-level collections instead of
- clause-level).
-
- """
- try:
- traverse_internals = self._traverse_internals
- except AttributeError:
- # user-defined classes may not have a _traverse_internals
- return []
-
- return itertools.chain.from_iterable(
- meth(obj, **kw)
- for attrname, obj, meth in _get_children.run_generated_dispatch(
- self, traverse_internals, "_generated_get_children_traversal"
- )
- if attrname not in omit_attrs and obj is not None
- )
-
def self_group(self, against=None):
"""Apply a 'grouping' to this :class:`_expression.ClauseElement`.
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 1610191d1..e1dee091b 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -18,7 +18,9 @@ from operator import attrgetter
from . import coercions
from . import operators
from . import roles
+from . import traversals
from . import type_api
+from . import visitors
from .annotation import Annotated
from .annotation import SupportsCloneAnnotations
from .base import _clone
@@ -4131,8 +4133,13 @@ class SelectState(util.MemoizedSlots, CompileState):
self.statement = statement
self.from_clauses = statement._from_obj
+ for memoized_entities in statement._memoized_select_entities:
+ self._setup_joins(
+ memoized_entities._setup_joins, memoized_entities._raw_columns
+ )
+
if statement._setup_joins:
- self._setup_joins(statement._setup_joins)
+ self._setup_joins(statement._setup_joins, statement._raw_columns)
self.froms = self._get_froms(statement)
@@ -4361,7 +4368,7 @@ class SelectState(util.MemoizedSlots, CompileState):
def all_selected_columns(cls, statement):
return [c for c in _select_iterables(statement._raw_columns)]
- def _setup_joins(self, args):
+ def _setup_joins(self, args, raw_columns):
for (right, onclause, left, flags) in args:
isouter = flags["isouter"]
full = flags["full"]
@@ -4371,7 +4378,7 @@ class SelectState(util.MemoizedSlots, CompileState):
left,
replace_from_obj_index,
) = self._join_determine_implicit_left_side(
- left, right, onclause
+ raw_columns, left, right, onclause
)
else:
(replace_from_obj_index) = self._join_place_explicit_left_side(
@@ -4403,7 +4410,9 @@ class SelectState(util.MemoizedSlots, CompileState):
)
@util.preload_module("sqlalchemy.sql.util")
- def _join_determine_implicit_left_side(self, left, right, onclause):
+ def _join_determine_implicit_left_side(
+ self, raw_columns, left, right, onclause
+ ):
"""When join conditions don't express the left side explicitly,
determine if an existing FROM or entity in this query
can serve as the left hand side.
@@ -4431,10 +4440,7 @@ class SelectState(util.MemoizedSlots, CompileState):
for from_clause in itertools.chain(
itertools.chain.from_iterable(
- [
- element._from_objects
- for element in statement._raw_columns
- ]
+ [element._from_objects for element in raw_columns]
),
itertools.chain.from_iterable(
[
@@ -4531,6 +4537,47 @@ class _SelectFromElements(object):
yield element
+class _MemoizedSelectEntities(
+ traversals.HasCacheKey, traversals.HasCopyInternals, visitors.Traversible
+):
+ __visit_name__ = "memoized_select_entities"
+
+ _traverse_internals = [
+ ("_raw_columns", InternalTraversal.dp_clauseelement_list),
+ ("_setup_joins", InternalTraversal.dp_setup_join_tuple),
+ ("_legacy_setup_joins", InternalTraversal.dp_setup_join_tuple),
+ ("_with_options", InternalTraversal.dp_executable_options),
+ ]
+
+ _annotations = util.EMPTY_DICT
+
+ def _clone(self, **kw):
+ c = self.__class__.__new__(self.__class__)
+ c.__dict__ = {k: v for k, v in self.__dict__.items()}
+ c._is_clone_of = self
+ return c
+
+ @classmethod
+ def _generate_for_statement(cls, select_stmt):
+ if (
+ select_stmt._setup_joins
+ or select_stmt._legacy_setup_joins
+ or select_stmt._with_options
+ ):
+ self = _MemoizedSelectEntities()
+ self._raw_columns = select_stmt._raw_columns
+ self._setup_joins = select_stmt._setup_joins
+ self._legacy_setup_joins = select_stmt._legacy_setup_joins
+ self._with_options = select_stmt._with_options
+
+ select_stmt._memoized_select_entities += (self,)
+ select_stmt._raw_columns = (
+ select_stmt._setup_joins
+ ) = (
+ select_stmt._legacy_setup_joins
+ ) = select_stmt._with_options = ()
+
+
class Select(
HasPrefixes,
HasSuffixes,
@@ -4559,6 +4606,7 @@ class Select(
_setup_joins = ()
_legacy_setup_joins = ()
+ _memoized_select_entities = ()
_distinct = False
_distinct_on = ()
@@ -4574,6 +4622,10 @@ class Select(
_traverse_internals = (
[
("_raw_columns", InternalTraversal.dp_clauseelement_list),
+ (
+ "_memoized_select_entities",
+ InternalTraversal.dp_memoized_select_entities,
+ ),
("_from_obj", InternalTraversal.dp_clauseelement_list),
("_where_criteria", InternalTraversal.dp_clauseelement_tuple),
("_having_criteria", InternalTraversal.dp_clauseelement_tuple),
@@ -5461,16 +5513,14 @@ class Select(
# is the case for now.
self._assert_no_memoizations()
- rc = []
- for c in coercions._expression_collection_was_a_list(
- "columns", "Select.with_only_columns", columns
- ):
- c = coercions.expect(roles.ColumnsClauseRole, c)
- # TODO: why are we doing this here?
- if isinstance(c, ScalarSelect):
- c = c.self_group(against=operators.comma_op)
- rc.append(c)
- self._raw_columns = rc
+ _MemoizedSelectEntities._generate_for_statement(self)
+
+ self._raw_columns = [
+ coercions.expect(roles.ColumnsClauseRole, c)
+ for c in coercions._expression_collection_was_a_list(
+ "columns", "Select.with_only_columns", columns
+ )
+ ]
@property
def whereclause(self):
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index 35f2bd62f..a86d16ef4 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -194,6 +194,8 @@ class HasCacheKey(object):
elif (
meth is InternalTraversal.dp_clauseelement_list
or meth is InternalTraversal.dp_clauseelement_tuple
+ or meth
+ is InternalTraversal.dp_memoized_select_entities
):
result += (
attrname,
@@ -409,6 +411,9 @@ class _CacheKey(ExtendedInternalTraversal):
visit_clauseelement_list = InternalTraversal.dp_clauseelement_list
visit_annotations_key = InternalTraversal.dp_annotations_key
visit_clauseelement_tuple = InternalTraversal.dp_clauseelement_tuple
+ visit_memoized_select_entities = (
+ InternalTraversal.dp_memoized_select_entities
+ )
visit_string = (
visit_boolean
@@ -799,6 +804,9 @@ class _CopyInternals(InternalTraversal):
for (target, onclause, from_, flags) in element
)
+ def visit_memoized_select_entities(self, attrname, parent, element, **kw):
+ return self.visit_clauseelement_tuple(attrname, parent, element, **kw)
+
def visit_dml_ordered_values(
self, attrname, parent, element, clone=_clone, **kw
):
@@ -919,6 +927,9 @@ class _GetChildren(InternalTraversal):
if onclause is not None and not isinstance(onclause, str):
yield _flatten_clauseelement(onclause)
+ def visit_memoized_select_entities(self, element, **kw):
+ return self.visit_clauseelement_tuple(element, **kw)
+
def visit_dml_ordered_values(self, element, **kw):
for k, v in element:
if hasattr(k, "__clause_element__"):
@@ -1265,6 +1276,13 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
self.stack.append((l_onclause, r_onclause))
self.stack.append((l_from, r_from))
+ def visit_memoized_select_entities(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return self.visit_clauseelement_tuple(
+ attrname, left_parent, left, right_parent, right, **kw
+ )
+
def visit_table_hint_list(
self, attrname, left_parent, left, right_parent, right, **kw
):
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 93ee8eb1c..c750c546a 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -24,6 +24,7 @@ http://techspot.zzzeek.org/2008/01/23/expression-transformations/ .
"""
from collections import deque
+import itertools
import operator
from .. import exc
@@ -119,6 +120,38 @@ class Traversible(util.with_metaclass(TraversibleType)):
"""
+ @util.preload_module("sqlalchemy.sql.traversals")
+ def get_children(self, omit_attrs=(), **kw):
+ r"""Return immediate child :class:`.visitors.Traversible`
+ elements of this :class:`.visitors.Traversible`.
+
+ This is used for visit traversal.
+
+ \**kw may contain flags that change the collection that is
+ returned, for example to return a subset of items in order to
+ cut down on larger traversals, or to return child items from a
+ different context (such as schema-level collections instead of
+ clause-level).
+
+ """
+
+ traversals = util.preloaded.sql_traversals
+
+ try:
+ traverse_internals = self._traverse_internals
+ except AttributeError:
+ # user-defined classes may not have a _traverse_internals
+ return []
+
+ dispatch = traversals._get_children.run_generated_dispatch
+ return itertools.chain.from_iterable(
+ meth(obj, **kw)
+ for attrname, obj, meth in dispatch(
+ self, traverse_internals, "_generated_get_children_traversal"
+ )
+ if attrname not in omit_attrs and obj is not None
+ )
+
class _InternalTraversalType(type):
def __init__(cls, clsname, bases, clsdict):
@@ -393,6 +426,8 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
dp_setup_join_tuple = symbol("SJ")
+ dp_memoized_select_entities = symbol("ME")
+
dp_statement_hint_list = symbol("SH")
"""Visit the ``_statement_hints`` collection of a
:class:`_expression.Select`