summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/build/changelog/unreleased_14/6503.rst13
-rw-r--r--lib/sqlalchemy/orm/context.py203
-rw-r--r--lib/sqlalchemy/orm/interfaces.py12
-rw-r--r--lib/sqlalchemy/orm/query.py4
-rw-r--r--lib/sqlalchemy/orm/strategy_options.py30
-rw-r--r--lib/sqlalchemy/sql/elements.py28
-rw-r--r--lib/sqlalchemy/sql/selectable.py86
-rw-r--r--lib/sqlalchemy/sql/traversals.py18
-rw-r--r--lib/sqlalchemy/sql/visitors.py35
-rw-r--r--test/orm/test_cache_key.py106
-rw-r--r--test/orm/test_joins.py37
-rw-r--r--test/orm/test_options.py102
-rw-r--r--test/profiles.txt20
-rw-r--r--test/sql/test_compare.py8
-rw-r--r--test/sql/test_external_traversal.py23
-rw-r--r--test/sql/test_select.py27
16 files changed, 641 insertions, 111 deletions
diff --git a/doc/build/changelog/unreleased_14/6503.rst b/doc/build/changelog/unreleased_14/6503.rst
new file mode 100644
index 000000000..a2d50bc99
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/6503.rst
@@ -0,0 +1,13 @@
+.. change::
+ :tags: bug, orm, regression
+ :tickets: 6503, 6253
+
+ Fixed further regressions in the same area as that of :ticket:`6052` where
+ loader options as well as invocations of methods like
+ :meth:`_orm.Query.join` would fail if the left side of the statement for
+ which the option/join depends upon were replaced by using the
+ :meth:`_orm.Query.with_entities` method, or when using 2.0 style queries
+ when using the :meth:`_sql.Select.with_only_columns` method. A new set of
+ state has been added to the objects which tracks the "left" entities that
+ the options / join were made against which is memoized when the lead
+ entities are changed.
diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py
index e4448f953..321eeada0 100644
--- a/lib/sqlalchemy/orm/context.py
+++ b/lib/sqlalchemy/orm/context.py
@@ -322,10 +322,16 @@ class ORMCompileState(CompileState):
return loading.instances(result, querycontext)
@property
- def _mapper_entities(self):
- return (
+ def _lead_mapper_entities(self):
+ """return all _MapperEntity objects in the lead entities collection.
+
+ Does **not** include entities that have been replaced by
+ with_entities(), with_only_columns()
+
+ """
+ return [
ent for ent in self._entities if isinstance(ent, _MapperEntity)
- )
+ ]
def _create_with_polymorphic_adapter(self, ext_info, selectable):
if (
@@ -405,7 +411,9 @@ class ORMFromStatementCompileState(ORMCompileState):
self.use_legacy_query_style,
)
- _QueryEntity.to_compile_state(self, statement_container._raw_columns)
+ _QueryEntity.to_compile_state(
+ self, statement_container._raw_columns, self._entities
+ )
self.current_path = statement_container._compile_options._current_path
@@ -477,6 +485,8 @@ class ORMFromStatementCompileState(ORMCompileState):
class ORMSelectCompileState(ORMCompileState, SelectState):
_joinpath = _joinpoint = _EMPTY_DICT
+ _memoized_entities = _EMPTY_DICT
+
_from_obj_alias = None
_has_mapper_entities = False
@@ -572,15 +582,48 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
statement._label_style, self.use_legacy_query_style
)
- _QueryEntity.to_compile_state(self, select_statement._raw_columns)
+ if select_statement._memoized_select_entities:
+ self._memoized_entities = {
+ memoized_entities: _QueryEntity.to_compile_state(
+ self,
+ memoized_entities._raw_columns,
+ [],
+ )
+ for memoized_entities in (
+ select_statement._memoized_select_entities
+ )
+ }
+
+ _QueryEntity.to_compile_state(
+ self, select_statement._raw_columns, self._entities
+ )
self.current_path = select_statement._compile_options._current_path
self.eager_order_by = ()
- if toplevel and select_statement._with_options:
+ if toplevel and (
+ select_statement._with_options
+ or select_statement._memoized_select_entities
+ ):
self.attributes = {"_unbound_load_dedupes": set()}
+ for (
+ memoized_entities
+ ) in select_statement._memoized_select_entities:
+ for opt in memoized_entities._with_options:
+ if opt._is_compile_state:
+ opt.process_compile_state_replaced_entities(
+ self,
+ [
+ ent
+ for ent in self._memoized_entities[
+ memoized_entities
+ ]
+ if isinstance(ent, _MapperEntity)
+ ],
+ )
+
for opt in self.select_statement._with_options:
if opt._is_compile_state:
opt.process_compile_state(self)
@@ -626,11 +669,23 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
if self.compile_options._set_base_alias:
self._set_select_from_alias()
+ for memoized_entities in query._memoized_select_entities:
+ if memoized_entities._setup_joins:
+ self._join(
+ memoized_entities._setup_joins,
+ self._memoized_entities[memoized_entities],
+ )
+ if memoized_entities._legacy_setup_joins:
+ self._legacy_join(
+ memoized_entities._legacy_setup_joins,
+ self._memoized_entities[memoized_entities],
+ )
+
if query._setup_joins:
- self._join(query._setup_joins)
+ self._join(query._setup_joins, self._entities)
if query._legacy_setup_joins:
- self._legacy_join(query._legacy_setup_joins)
+ self._legacy_join(query._legacy_setup_joins, self._entities)
current_adapter = self._get_current_adapter()
@@ -782,7 +837,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
# entities will also set up polymorphic adapters for mappers
# that have with_polymorphic configured
- _QueryEntity.to_compile_state(self, query._raw_columns)
+ _QueryEntity.to_compile_state(self, query._raw_columns, self._entities)
return self
@classmethod
@@ -921,7 +976,18 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
def _all_equivs(self):
equivs = {}
- for ent in self._mapper_entities:
+
+ for memoized_entities in self._memoized_entities.values():
+ for ent in [
+ ent
+ for ent in memoized_entities
+ if isinstance(ent, _MapperEntity)
+ ]:
+ equivs.update(ent.mapper._equivalent_columns)
+
+ for ent in [
+ ent for ent in self._entities if isinstance(ent, _MapperEntity)
+ ]:
equivs.update(ent.mapper._equivalent_columns)
return equivs
@@ -1211,7 +1277,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
return _adapt_clause
- def _join(self, args):
+ def _join(self, args, entities_collection):
for (right, onclause, from_, flags) in args:
isouter = flags["isouter"]
full = flags["full"]
@@ -1316,6 +1382,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
# figure out the final "left" and "right" sides and create an
# ORMJoin to add to our _from_obj tuple
self._join_left_to_right(
+ entities_collection,
left,
right,
onclause,
@@ -1326,7 +1393,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
full,
)
- def _legacy_join(self, args):
+ def _legacy_join(self, args, entities_collection):
"""consumes arguments from join() or outerjoin(), places them into a
consistent format with which to form the actual JOIN constructs.
@@ -1474,6 +1541,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
# figure out the final "left" and "right" sides and create an
# ORMJoin to add to our _from_obj tuple
self._join_left_to_right(
+ entities_collection,
left,
right,
onclause,
@@ -1489,6 +1557,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
def _join_left_to_right(
self,
+ entities_collection,
left,
right,
onclause,
@@ -1513,7 +1582,9 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
left,
replace_from_obj_index,
use_entity_index,
- ) = self._join_determine_implicit_left_side(left, right, onclause)
+ ) = self._join_determine_implicit_left_side(
+ entities_collection, left, right, onclause
+ )
else:
# left is given via a relationship/name, or as explicit left side.
# Determine where in our
@@ -1522,7 +1593,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
(
replace_from_obj_index,
use_entity_index,
- ) = self._join_place_explicit_left_side(left)
+ ) = self._join_place_explicit_left_side(entities_collection, left)
if left is right and not create_aliases:
raise sa_exc.InvalidRequestError(
@@ -1568,9 +1639,9 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
# entity_zero.selectable, but if with_polymorphic() were used
# might be distinct
assert isinstance(
- self._entities[use_entity_index], _MapperEntity
+ entities_collection[use_entity_index], _MapperEntity
)
- left_clause = self._entities[use_entity_index].selectable
+ left_clause = entities_collection[use_entity_index].selectable
else:
left_clause = left
@@ -1585,7 +1656,9 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
)
]
- def _join_determine_implicit_left_side(self, left, right, onclause):
+ def _join_determine_implicit_left_side(
+ self, entities_collection, left, right, onclause
+ ):
"""When join conditions don't express the left side explicitly,
determine if an existing FROM or entity in this query
can serve as the left hand side.
@@ -1635,12 +1708,12 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
"to help resolve the ambiguity." % (right,)
)
- elif self._entities:
+ elif entities_collection:
# we have no explicit FROMs, so the implicit left has to
# come from our list of entities.
potential = {}
- for entity_index, ent in enumerate(self._entities):
+ for entity_index, ent in enumerate(entities_collection):
entity = ent.entity_zero_or_selectable
if entity is None:
continue
@@ -1689,7 +1762,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
return left, replace_from_obj_index, use_entity_index
- def _join_place_explicit_left_side(self, left):
+ def _join_place_explicit_left_side(self, entities_collection, left):
"""When join conditions express a left side explicitly, determine
where in our existing list of FROM clauses we should join towards,
or if we need to make a new join, and if so is it from one of our
@@ -1743,10 +1816,10 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
# aliasing / adaptation rules present on that entity if any
if (
replace_from_obj_index is None
- and self._entities
+ and entities_collection
and hasattr(l_info, "mapper")
):
- for idx, ent in enumerate(self._entities):
+ for idx, ent in enumerate(entities_collection):
# TODO: should we be checking for multiple mapper entities
# matching?
if isinstance(ent, _MapperEntity) and ent.corresponds_to(left):
@@ -2194,11 +2267,14 @@ class _QueryEntity(object):
__slots__ = ()
@classmethod
- def to_compile_state(cls, compile_state, entities):
+ def to_compile_state(cls, compile_state, entities, entities_collection):
+
for idx, entity in enumerate(entities):
if entity._is_lambda_element:
if entity._is_sequence:
- cls.to_compile_state(compile_state, entity._resolved)
+ cls.to_compile_state(
+ compile_state, entity._resolved, entities_collection
+ )
continue
else:
entity = entity._resolved
@@ -2206,26 +2282,38 @@ class _QueryEntity(object):
if entity.is_clause_element:
if entity.is_selectable:
if "parententity" in entity._annotations:
- _MapperEntity(compile_state, entity)
+ _MapperEntity(
+ compile_state, entity, entities_collection
+ )
else:
_ColumnEntity._for_columns(
- compile_state, entity._select_iterable, idx
+ compile_state,
+ entity._select_iterable,
+ entities_collection,
+ idx,
)
else:
if entity._annotations.get("bundle", False):
- _BundleEntity(compile_state, entity)
+ _BundleEntity(
+ compile_state, entity, entities_collection
+ )
elif entity._is_clause_list:
# this is legacy only - test_composites.py
# test_query_cols_legacy
_ColumnEntity._for_columns(
- compile_state, entity._select_iterable, idx
+ compile_state,
+ entity._select_iterable,
+ entities_collection,
+ idx,
)
else:
_ColumnEntity._for_columns(
- compile_state, [entity], idx
+ compile_state, [entity], entities_collection, idx
)
elif entity.is_bundle:
- _BundleEntity(compile_state, entity)
+ _BundleEntity(compile_state, entity, entities_collection)
+
+ return entities_collection
class _MapperEntity(_QueryEntity):
@@ -2244,8 +2332,8 @@ class _MapperEntity(_QueryEntity):
"_polymorphic_discriminator",
)
- def __init__(self, compile_state, entity):
- compile_state._entities.append(self)
+ def __init__(self, compile_state, entity, entities_collection):
+ entities_collection.append(self)
if compile_state._primary_entity is None:
compile_state._primary_entity = self
compile_state._has_mapper_entities = True
@@ -2418,7 +2506,12 @@ class _BundleEntity(_QueryEntity):
)
def __init__(
- self, compile_state, expr, setup_entities=True, parent_bundle=None
+ self,
+ compile_state,
+ expr,
+ entities_collection,
+ setup_entities=True,
+ parent_bundle=None,
):
compile_state._has_orm_entities = True
@@ -2426,7 +2519,7 @@ class _BundleEntity(_QueryEntity):
if parent_bundle:
parent_bundle._entities.append(self)
else:
- compile_state._entities.append(self)
+ entities_collection.append(self)
if isinstance(
expr, (attributes.QueryableAttribute, interfaces.PropComparator)
@@ -2443,12 +2536,26 @@ class _BundleEntity(_QueryEntity):
if setup_entities:
for expr in bundle.exprs:
if "bundle" in expr._annotations:
- _BundleEntity(compile_state, expr, parent_bundle=self)
+ _BundleEntity(
+ compile_state,
+ expr,
+ entities_collection,
+ parent_bundle=self,
+ )
elif isinstance(expr, Bundle):
- _BundleEntity(compile_state, expr, parent_bundle=self)
+ _BundleEntity(
+ compile_state,
+ expr,
+ entities_collection,
+ parent_bundle=self,
+ )
else:
_ORMColumnEntity._for_columns(
- compile_state, [expr], None, parent_bundle=self
+ compile_state,
+ [expr],
+ entities_collection,
+ None,
+ parent_bundle=self,
)
self.supports_single_entity = self.bundle.single_entity
@@ -2516,7 +2623,12 @@ class _ColumnEntity(_QueryEntity):
@classmethod
def _for_columns(
- cls, compile_state, columns, raw_column_index, parent_bundle=None
+ cls,
+ compile_state,
+ columns,
+ entities_collection,
+ raw_column_index,
+ parent_bundle=None,
):
for column in columns:
annotations = column._annotations
@@ -2532,6 +2644,7 @@ class _ColumnEntity(_QueryEntity):
_IdentityTokenEntity(
compile_state,
column,
+ entities_collection,
_entity,
raw_column_index,
parent_bundle=parent_bundle,
@@ -2540,6 +2653,7 @@ class _ColumnEntity(_QueryEntity):
_ORMColumnEntity(
compile_state,
column,
+ entities_collection,
_entity,
raw_column_index,
parent_bundle=parent_bundle,
@@ -2548,6 +2662,7 @@ class _ColumnEntity(_QueryEntity):
_RawColumnEntity(
compile_state,
column,
+ entities_collection,
raw_column_index,
parent_bundle=parent_bundle,
)
@@ -2630,7 +2745,12 @@ class _RawColumnEntity(_ColumnEntity):
)
def __init__(
- self, compile_state, column, raw_column_index, parent_bundle=None
+ self,
+ compile_state,
+ column,
+ entities_collection,
+ raw_column_index,
+ parent_bundle=None,
):
self.expr = column
self.raw_column_index = raw_column_index
@@ -2643,7 +2763,7 @@ class _RawColumnEntity(_ColumnEntity):
if parent_bundle:
parent_bundle._entities.append(self)
else:
- compile_state._entities.append(self)
+ entities_collection.append(self)
self.column = column
self.entity_zero_or_selectable = (
@@ -2690,6 +2810,7 @@ class _ORMColumnEntity(_ColumnEntity):
self,
compile_state,
column,
+ entities_collection,
parententity,
raw_column_index,
parent_bundle=None,
@@ -2729,7 +2850,7 @@ class _ORMColumnEntity(_ColumnEntity):
if parent_bundle:
parent_bundle._entities.append(self)
else:
- compile_state._entities.append(self)
+ entities_collection.append(self)
compile_state._has_orm_entities = True
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index c9a601f99..28b4bfb2d 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -750,6 +750,18 @@ class LoaderOption(ORMOption):
_is_compile_state = True
+ def process_compile_state_replaced_entities(
+ self, compile_state, mapper_entities
+ ):
+ """Apply a modification to a given :class:`.CompileState`,
+ given entities that were replaced by with_only_columns() or
+ with_entities().
+
+ .. versionadded:: 1.4.19
+
+ """
+ self.process_compile_state(compile_state)
+
def process_compile_state(self, compile_state):
"""Apply a modification to a given :class:`.CompileState`."""
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index cacfb8d84..7ba31fa7a 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -57,6 +57,7 @@ from ..sql.annotation import SupportsCloneAnnotations
from ..sql.base import _entity_namespace_key
from ..sql.base import _generative
from ..sql.base import Executable
+from ..sql.selectable import _MemoizedSelectEntities
from ..sql.selectable import _SelectFromElements
from ..sql.selectable import ForUpdateArg
from ..sql.selectable import GroupedElement
@@ -125,6 +126,8 @@ class Query(
_legacy_setup_joins = ()
_label_style = LABEL_STYLE_LEGACY_ORM
+ _memoized_select_entities = ()
+
_compile_options = ORMCompileState.default_compile_options
load_options = QueryContext.default_load_options
@@ -1433,6 +1436,7 @@ class Query(
limit(1)
"""
+ _MemoizedSelectEntities._generate_for_statement(self)
self._set_entities(entities)
@_generative
diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py
index e371442fd..91e627525 100644
--- a/lib/sqlalchemy/orm/strategy_options.py
+++ b/lib/sqlalchemy/orm/strategy_options.py
@@ -172,13 +172,32 @@ class Load(Generative, LoaderOption):
_of_type = None
_extra_criteria = ()
+ def process_compile_state_replaced_entities(
+ self, compile_state, mapper_entities
+ ):
+ if not compile_state.compile_options._enable_eagerloads:
+ return
+
+ # process is being run here so that the options given are validated
+ # against what the lead entities were, as well as to accommodate
+ # for the entities having been replaced with equivalents
+ self._process(
+ compile_state,
+ mapper_entities,
+ not bool(compile_state.current_path),
+ )
+
def process_compile_state(self, compile_state):
if not compile_state.compile_options._enable_eagerloads:
return
- self._process(compile_state, not bool(compile_state.current_path))
+ self._process(
+ compile_state,
+ compile_state._lead_mapper_entities,
+ not bool(compile_state.current_path),
+ )
- def _process(self, compile_state, raiseerr):
+ def _process(self, compile_state, mapper_entities, raiseerr):
is_refresh = compile_state.compile_options._for_refresh_state
current_path = compile_state.current_path
if current_path:
@@ -700,7 +719,7 @@ class _UnboundLoad(Load):
state["path"] = tuple(ret)
self.__dict__ = state
- def _process(self, compile_state, raiseerr):
+ def _process(self, compile_state, mapper_entities, raiseerr):
dedupes = compile_state.attributes["_unbound_load_dedupes"]
is_refresh = compile_state.compile_options._for_refresh_state
for val in self._to_bind:
@@ -709,10 +728,7 @@ class _UnboundLoad(Load):
if is_refresh and not val.propagate_to_loaders:
continue
val._bind_loader(
- [
- ent.entity_zero
- for ent in compile_state._mapper_entities
- ],
+ [ent.entity_zero for ent in mapper_entities],
compile_state.current_path,
compile_state.attributes,
raiseerr,
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 213f47c40..709106b6b 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -32,7 +32,6 @@ from .base import NO_ARG
from .base import PARSE_AUTOCOMMIT
from .base import SingletonConstant
from .coercions import _document_text_coercion
-from .traversals import _get_children
from .traversals import HasCopyInternals
from .traversals import MemoizedHasCacheKey
from .traversals import NO_CACHE
@@ -389,33 +388,6 @@ class ClauseElement(
"""
return traversals.compare(self, other, **kw)
- def get_children(self, omit_attrs=(), **kw):
- r"""Return immediate child :class:`.visitors.Traversible`
- elements of this :class:`.visitors.Traversible`.
-
- This is used for visit traversal.
-
- \**kw may contain flags that change the collection that is
- returned, for example to return a subset of items in order to
- cut down on larger traversals, or to return child items from a
- different context (such as schema-level collections instead of
- clause-level).
-
- """
- try:
- traverse_internals = self._traverse_internals
- except AttributeError:
- # user-defined classes may not have a _traverse_internals
- return []
-
- return itertools.chain.from_iterable(
- meth(obj, **kw)
- for attrname, obj, meth in _get_children.run_generated_dispatch(
- self, traverse_internals, "_generated_get_children_traversal"
- )
- if attrname not in omit_attrs and obj is not None
- )
-
def self_group(self, against=None):
"""Apply a 'grouping' to this :class:`_expression.ClauseElement`.
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 1610191d1..e1dee091b 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -18,7 +18,9 @@ from operator import attrgetter
from . import coercions
from . import operators
from . import roles
+from . import traversals
from . import type_api
+from . import visitors
from .annotation import Annotated
from .annotation import SupportsCloneAnnotations
from .base import _clone
@@ -4131,8 +4133,13 @@ class SelectState(util.MemoizedSlots, CompileState):
self.statement = statement
self.from_clauses = statement._from_obj
+ for memoized_entities in statement._memoized_select_entities:
+ self._setup_joins(
+ memoized_entities._setup_joins, memoized_entities._raw_columns
+ )
+
if statement._setup_joins:
- self._setup_joins(statement._setup_joins)
+ self._setup_joins(statement._setup_joins, statement._raw_columns)
self.froms = self._get_froms(statement)
@@ -4361,7 +4368,7 @@ class SelectState(util.MemoizedSlots, CompileState):
def all_selected_columns(cls, statement):
return [c for c in _select_iterables(statement._raw_columns)]
- def _setup_joins(self, args):
+ def _setup_joins(self, args, raw_columns):
for (right, onclause, left, flags) in args:
isouter = flags["isouter"]
full = flags["full"]
@@ -4371,7 +4378,7 @@ class SelectState(util.MemoizedSlots, CompileState):
left,
replace_from_obj_index,
) = self._join_determine_implicit_left_side(
- left, right, onclause
+ raw_columns, left, right, onclause
)
else:
(replace_from_obj_index) = self._join_place_explicit_left_side(
@@ -4403,7 +4410,9 @@ class SelectState(util.MemoizedSlots, CompileState):
)
@util.preload_module("sqlalchemy.sql.util")
- def _join_determine_implicit_left_side(self, left, right, onclause):
+ def _join_determine_implicit_left_side(
+ self, raw_columns, left, right, onclause
+ ):
"""When join conditions don't express the left side explicitly,
determine if an existing FROM or entity in this query
can serve as the left hand side.
@@ -4431,10 +4440,7 @@ class SelectState(util.MemoizedSlots, CompileState):
for from_clause in itertools.chain(
itertools.chain.from_iterable(
- [
- element._from_objects
- for element in statement._raw_columns
- ]
+ [element._from_objects for element in raw_columns]
),
itertools.chain.from_iterable(
[
@@ -4531,6 +4537,47 @@ class _SelectFromElements(object):
yield element
+class _MemoizedSelectEntities(
+ traversals.HasCacheKey, traversals.HasCopyInternals, visitors.Traversible
+):
+ __visit_name__ = "memoized_select_entities"
+
+ _traverse_internals = [
+ ("_raw_columns", InternalTraversal.dp_clauseelement_list),
+ ("_setup_joins", InternalTraversal.dp_setup_join_tuple),
+ ("_legacy_setup_joins", InternalTraversal.dp_setup_join_tuple),
+ ("_with_options", InternalTraversal.dp_executable_options),
+ ]
+
+ _annotations = util.EMPTY_DICT
+
+ def _clone(self, **kw):
+ c = self.__class__.__new__(self.__class__)
+ c.__dict__ = {k: v for k, v in self.__dict__.items()}
+ c._is_clone_of = self
+ return c
+
+ @classmethod
+ def _generate_for_statement(cls, select_stmt):
+ if (
+ select_stmt._setup_joins
+ or select_stmt._legacy_setup_joins
+ or select_stmt._with_options
+ ):
+ self = _MemoizedSelectEntities()
+ self._raw_columns = select_stmt._raw_columns
+ self._setup_joins = select_stmt._setup_joins
+ self._legacy_setup_joins = select_stmt._legacy_setup_joins
+ self._with_options = select_stmt._with_options
+
+ select_stmt._memoized_select_entities += (self,)
+ select_stmt._raw_columns = (
+ select_stmt._setup_joins
+ ) = (
+ select_stmt._legacy_setup_joins
+ ) = select_stmt._with_options = ()
+
+
class Select(
HasPrefixes,
HasSuffixes,
@@ -4559,6 +4606,7 @@ class Select(
_setup_joins = ()
_legacy_setup_joins = ()
+ _memoized_select_entities = ()
_distinct = False
_distinct_on = ()
@@ -4574,6 +4622,10 @@ class Select(
_traverse_internals = (
[
("_raw_columns", InternalTraversal.dp_clauseelement_list),
+ (
+ "_memoized_select_entities",
+ InternalTraversal.dp_memoized_select_entities,
+ ),
("_from_obj", InternalTraversal.dp_clauseelement_list),
("_where_criteria", InternalTraversal.dp_clauseelement_tuple),
("_having_criteria", InternalTraversal.dp_clauseelement_tuple),
@@ -5461,16 +5513,14 @@ class Select(
# is the case for now.
self._assert_no_memoizations()
- rc = []
- for c in coercions._expression_collection_was_a_list(
- "columns", "Select.with_only_columns", columns
- ):
- c = coercions.expect(roles.ColumnsClauseRole, c)
- # TODO: why are we doing this here?
- if isinstance(c, ScalarSelect):
- c = c.self_group(against=operators.comma_op)
- rc.append(c)
- self._raw_columns = rc
+ _MemoizedSelectEntities._generate_for_statement(self)
+
+ self._raw_columns = [
+ coercions.expect(roles.ColumnsClauseRole, c)
+ for c in coercions._expression_collection_was_a_list(
+ "columns", "Select.with_only_columns", columns
+ )
+ ]
@property
def whereclause(self):
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index 35f2bd62f..a86d16ef4 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -194,6 +194,8 @@ class HasCacheKey(object):
elif (
meth is InternalTraversal.dp_clauseelement_list
or meth is InternalTraversal.dp_clauseelement_tuple
+ or meth
+ is InternalTraversal.dp_memoized_select_entities
):
result += (
attrname,
@@ -409,6 +411,9 @@ class _CacheKey(ExtendedInternalTraversal):
visit_clauseelement_list = InternalTraversal.dp_clauseelement_list
visit_annotations_key = InternalTraversal.dp_annotations_key
visit_clauseelement_tuple = InternalTraversal.dp_clauseelement_tuple
+ visit_memoized_select_entities = (
+ InternalTraversal.dp_memoized_select_entities
+ )
visit_string = (
visit_boolean
@@ -799,6 +804,9 @@ class _CopyInternals(InternalTraversal):
for (target, onclause, from_, flags) in element
)
+ def visit_memoized_select_entities(self, attrname, parent, element, **kw):
+ return self.visit_clauseelement_tuple(attrname, parent, element, **kw)
+
def visit_dml_ordered_values(
self, attrname, parent, element, clone=_clone, **kw
):
@@ -919,6 +927,9 @@ class _GetChildren(InternalTraversal):
if onclause is not None and not isinstance(onclause, str):
yield _flatten_clauseelement(onclause)
+ def visit_memoized_select_entities(self, element, **kw):
+ return self.visit_clauseelement_tuple(element, **kw)
+
def visit_dml_ordered_values(self, element, **kw):
for k, v in element:
if hasattr(k, "__clause_element__"):
@@ -1265,6 +1276,13 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
self.stack.append((l_onclause, r_onclause))
self.stack.append((l_from, r_from))
+ def visit_memoized_select_entities(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return self.visit_clauseelement_tuple(
+ attrname, left_parent, left, right_parent, right, **kw
+ )
+
def visit_table_hint_list(
self, attrname, left_parent, left, right_parent, right, **kw
):
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 93ee8eb1c..c750c546a 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -24,6 +24,7 @@ http://techspot.zzzeek.org/2008/01/23/expression-transformations/ .
"""
from collections import deque
+import itertools
import operator
from .. import exc
@@ -119,6 +120,38 @@ class Traversible(util.with_metaclass(TraversibleType)):
"""
+ @util.preload_module("sqlalchemy.sql.traversals")
+ def get_children(self, omit_attrs=(), **kw):
+ r"""Return immediate child :class:`.visitors.Traversible`
+ elements of this :class:`.visitors.Traversible`.
+
+ This is used for visit traversal.
+
+ \**kw may contain flags that change the collection that is
+ returned, for example to return a subset of items in order to
+ cut down on larger traversals, or to return child items from a
+ different context (such as schema-level collections instead of
+ clause-level).
+
+ """
+
+ traversals = util.preloaded.sql_traversals
+
+ try:
+ traverse_internals = self._traverse_internals
+ except AttributeError:
+ # user-defined classes may not have a _traverse_internals
+ return []
+
+ dispatch = traversals._get_children.run_generated_dispatch
+ return itertools.chain.from_iterable(
+ meth(obj, **kw)
+ for attrname, obj, meth in dispatch(
+ self, traverse_internals, "_generated_get_children_traversal"
+ )
+ if attrname not in omit_attrs and obj is not None
+ )
+
class _InternalTraversalType(type):
def __init__(cls, clsname, bases, clsdict):
@@ -393,6 +426,8 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
dp_setup_join_tuple = symbol("SJ")
+ dp_memoized_select_entities = symbol("ME")
+
dp_statement_hint_list = symbol("SH")
"""Visit the ``_statement_hints`` collection of a
:class:`_expression.Select`
diff --git a/test/orm/test_cache_key.py b/test/orm/test_cache_key.py
index 67f2d0230..7b6feb96a 100644
--- a/test/orm/test_cache_key.py
+++ b/test/orm/test_cache_key.py
@@ -30,6 +30,7 @@ from sqlalchemy.sql.visitors import InternalTraversal
from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import eq_
from sqlalchemy.testing import mock
+from sqlalchemy.testing import ne_
from sqlalchemy.testing.fixtures import fixture_session
from test.orm import _fixtures
from .inheritance import _poly_fixtures
@@ -313,6 +314,111 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest):
compare_values=True,
)
+ def test_orm_query_using_with_entities(self):
+ """test issue #6503"""
+ User, Address, Keyword, Order, Item = self.classes(
+ "User", "Address", "Keyword", "Order", "Item"
+ )
+
+ self._run_cache_key_fixture(
+ lambda: stmt_20(
+ fixture_session()
+ .query(User)
+ .join(User.addresses)
+ .with_entities(Address.id),
+ #
+ fixture_session().query(Address.id).join(User.addresses),
+ #
+ fixture_session()
+ .query(User)
+ .options(selectinload(User.addresses))
+ .with_entities(User.id),
+ #
+ fixture_session()
+ .query(User)
+ .options(selectinload(User.addresses)),
+ #
+ fixture_session().query(User).with_entities(User.id),
+ #
+ # here, propagate_attr->orm is Address, entity is Address.id,
+ # but the join() + with_entities() will log a
+ # _MemoizedSelectEntities to differentiate
+ fixture_session()
+ .query(Address, Order)
+ .join(Address.dingaling)
+ .with_entities(Address.id),
+ #
+ # same, propagate_attr->orm is Address, entity is Address.id,
+ # but the join() + with_entities() will log a
+ # _MemoizedSelectEntities to differentiate
+ fixture_session()
+ .query(Address, User)
+ .join(Address.dingaling)
+ .with_entities(Address.id),
+ ),
+ compare_values=True,
+ )
+
+ def test_more_with_entities_sanity_checks(self):
+ """test issue #6503"""
+ User, Address, Keyword, Order, Item = self.classes(
+ "User", "Address", "Keyword", "Order", "Item"
+ )
+
+ sess = fixture_session()
+
+ q1 = (
+ sess.query(Address, Order)
+ .with_entities(Address.id)
+ ._statement_20()
+ )
+ q2 = (
+ sess.query(Address, User).with_entities(Address.id)._statement_20()
+ )
+
+ assert not q1._memoized_select_entities
+ assert not q2._memoized_select_entities
+
+ # no joins or options, so q1 and q2 have the same cache key as Order/
+ # User are discarded. Note Address is first so propagate_attrs->orm is
+ # Address.
+ eq_(q1._generate_cache_key(), q2._generate_cache_key())
+
+ q3 = sess.query(Order).with_entities(Address.id)._statement_20()
+ q4 = sess.query(User).with_entities(Address.id)._statement_20()
+
+ # with Order/User as lead entity, this affects propagate_attrs->orm
+ # so keys are different
+ ne_(q3._generate_cache_key(), q4._generate_cache_key())
+
+ # confirm by deleting propagate attrs and memoized key and
+ # running again
+ q3._propagate_attrs = None
+ q4._propagate_attrs = None
+ del q3.__dict__["_generate_cache_key"]
+ del q4.__dict__["_generate_cache_key"]
+ eq_(q3._generate_cache_key(), q4._generate_cache_key())
+
+ # once there's a join() or options() prior to with_entities, now they
+ # are not discarded from the key; Order and User are in the
+ # _MemoizedSelectEntities
+ q5 = (
+ sess.query(Address, Order)
+ .join(Address.dingaling)
+ .with_entities(Address.id)
+ ._statement_20()
+ )
+ q6 = (
+ sess.query(Address, User)
+ .join(Address.dingaling)
+ .with_entities(Address.id)
+ ._statement_20()
+ )
+
+ assert q5._memoized_select_entities
+ assert q6._memoized_select_entities
+ ne_(q5._generate_cache_key(), q6._generate_cache_key())
+
def test_orm_query_from_statement(self):
User, Address, Keyword, Order, Item = self.classes(
"User", "Address", "Keyword", "Order", "Item"
diff --git a/test/orm/test_joins.py b/test/orm/test_joins.py
index 7f6e1b72e..25fa7e661 100644
--- a/test/orm/test_joins.py
+++ b/test/orm/test_joins.py
@@ -327,6 +327,43 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
"JOIN addresses ON users.id = addresses.user_id",
)
+ @testing.combinations((True,), (False,), argnames="legacy")
+ @testing.combinations((True,), (False,), argnames="threelevel")
+ def test_join_with_entities(self, legacy, threelevel):
+ """test issue #6503"""
+
+ User, Address, Dingaling = self.classes("User", "Address", "Dingaling")
+
+ if legacy:
+ sess = fixture_session()
+ stmt = sess.query(User).join(Address).with_entities(Address.id)
+ else:
+ stmt = select(User).join(Address).with_only_columns(Address.id)
+
+ stmt = stmt.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
+
+ if threelevel:
+ if legacy:
+ stmt = stmt.join(Address.dingaling).with_entities(Dingaling.id)
+ else:
+ stmt = stmt.join(Address.dingaling).with_only_columns(
+ Dingaling.id
+ )
+
+ if threelevel:
+ self.assert_compile(
+ stmt,
+ "SELECT dingalings.id AS dingalings_id "
+ "FROM users JOIN addresses ON users.id = addresses.user_id "
+ "JOIN dingalings ON addresses.id = dingalings.address_id",
+ )
+ else:
+ self.assert_compile(
+ stmt,
+ "SELECT addresses.id AS addresses_id FROM users "
+ "JOIN addresses ON users.id = addresses.user_id",
+ )
+
def test_invalid_kwarg_join(self):
User = self.classes.User
sess = fixture_session()
diff --git a/test/orm/test_options.py b/test/orm/test_options.py
index 4bef121d9..31ab100fa 100644
--- a/test/orm/test_options.py
+++ b/test/orm/test_options.py
@@ -3,6 +3,7 @@ from sqlalchemy import Column
from sqlalchemy import ForeignKey
from sqlalchemy import inspect
from sqlalchemy import Integer
+from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy import testing
from sqlalchemy.orm import aliased
@@ -24,6 +25,7 @@ from sqlalchemy.orm import util as orm_util
from sqlalchemy.orm import with_polymorphic
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.assertions import assert_raises_message
+from sqlalchemy.testing.assertions import AssertsCompiledSQL
from sqlalchemy.testing.assertions import eq_
from sqlalchemy.testing.fixtures import fixture_session
from test.orm import _fixtures
@@ -95,7 +97,7 @@ class PathTest(object):
val._bind_loader(
[
ent.entity_zero
- for ent in q._compile_state()._mapper_entities
+ for ent in q._compile_state()._lead_mapper_entities
],
q._compile_options._current_path,
attr,
@@ -104,7 +106,7 @@ class PathTest(object):
else:
compile_state = q._compile_state()
compile_state.attributes = attr = {}
- opt._process(compile_state, True)
+ opt._process(compile_state, [], True)
assert_paths = [k[1] for k in attr]
eq_(
@@ -401,6 +403,92 @@ class OfTypePathingTest(PathTest, QueryTest):
)
+class WithEntitiesTest(QueryTest, AssertsCompiledSQL):
+ def test_options_legacy_with_entities_onelevel(self):
+ """test issue #6253 (part of #6503)"""
+
+ User = self.classes.User
+ sess = fixture_session()
+
+ q = (
+ sess.query(User)
+ .options(joinedload(User.addresses))
+ .with_entities(User.id)
+ )
+ self.assert_compile(q, "SELECT users.id AS users_id FROM users")
+
+ def test_options_with_only_cols_onelevel(self):
+ """test issue #6253 (part of #6503)"""
+
+ User = self.classes.User
+
+ q = (
+ select(User)
+ .options(joinedload(User.addresses))
+ .with_only_columns(User.id)
+ )
+ self.assert_compile(q, "SELECT users.id FROM users")
+
+ def test_options_entities_replaced_with_equivs_one(self):
+ User = self.classes.User
+ Address = self.classes.Address
+
+ q = (
+ select(User, Address)
+ .options(joinedload(User.addresses))
+ .with_only_columns(User)
+ )
+ self.assert_compile(
+ q,
+ "SELECT users.id, users.name, addresses_1.id AS id_1, "
+ "addresses_1.user_id, addresses_1.email_address FROM users "
+ "LEFT OUTER JOIN addresses AS addresses_1 "
+ "ON users.id = addresses_1.user_id ORDER BY addresses_1.id",
+ )
+
+ def test_options_entities_replaced_with_equivs_two(self):
+ User = self.classes.User
+ Address = self.classes.Address
+
+ q = (
+ select(User, Address)
+ .options(joinedload(User.addresses), joinedload(Address.dingaling))
+ .with_only_columns(User)
+ )
+ self.assert_compile(
+ q,
+ "SELECT users.id, users.name, addresses_1.id AS id_1, "
+ "addresses_1.user_id, addresses_1.email_address FROM users "
+ "LEFT OUTER JOIN addresses AS addresses_1 "
+ "ON users.id = addresses_1.user_id ORDER BY addresses_1.id",
+ )
+
+ def test_options_entities_replaced_with_equivs_three(self):
+ User = self.classes.User
+ Address = self.classes.Address
+
+ q = (
+ select(User)
+ .options(joinedload(User.addresses))
+ .with_only_columns(User, Address)
+ .options(joinedload(Address.dingaling))
+ )
+ self.assert_compile(
+ q,
+ "SELECT users.id, users.name, addresses.id AS id_1, "
+ "addresses.user_id, addresses.email_address, "
+ "addresses_1.id AS id_2, addresses_1.user_id AS user_id_1, "
+ "addresses_1.email_address AS email_address_1, "
+ "dingalings_1.id AS id_3, dingalings_1.address_id, "
+ "dingalings_1.data "
+ "FROM users LEFT OUTER JOIN addresses AS addresses_1 "
+ "ON users.id = addresses_1.user_id, addresses "
+ "LEFT OUTER JOIN dingalings AS dingalings_1 "
+ "ON addresses.id = dingalings_1.address_id "
+ "ORDER BY addresses_1.id",
+ )
+
+
class OptionsTest(PathTest, QueryTest):
def _option_fixture(self, *arg):
return strategy_options._UnboundLoad._from_keys(
@@ -1479,7 +1567,7 @@ class PickleTest(PathTest, QueryTest):
load = opt._bind_loader(
[
ent.entity_zero
- for ent in query._compile_state()._mapper_entities
+ for ent in query._compile_state()._lead_mapper_entities
],
query._compile_options._current_path,
attr,
@@ -1516,7 +1604,7 @@ class PickleTest(PathTest, QueryTest):
load = opt._bind_loader(
[
ent.entity_zero
- for ent in query._compile_state()._mapper_entities
+ for ent in query._compile_state()._lead_mapper_entities
],
query._compile_options._current_path,
attr,
@@ -1560,7 +1648,7 @@ class LocalOptsTest(PathTest, QueryTest):
ctx = query._compile_state()
for tb in opt._to_bind:
tb._bind_loader(
- [ent.entity_zero for ent in ctx._mapper_entities],
+ [ent.entity_zero for ent in ctx._lead_mapper_entities],
query._compile_options._current_path,
attr,
False,
@@ -1658,7 +1746,7 @@ class SubOptionsTest(PathTest, QueryTest):
val._bind_loader(
[
ent.entity_zero
- for ent in q._compile_state()._mapper_entities
+ for ent in q._compile_state()._lead_mapper_entities
],
q._compile_options._current_path,
attr_a,
@@ -1672,7 +1760,7 @@ class SubOptionsTest(PathTest, QueryTest):
val._bind_loader(
[
ent.entity_zero
- for ent in q._compile_state()._mapper_entities
+ for ent in q._compile_state()._lead_mapper_entities
],
q._compile_options._current_path,
attr_b,
diff --git a/test/profiles.txt b/test/profiles.txt
index 3b5b1aca3..6e6f430a3 100644
--- a/test/profiles.txt
+++ b/test/profiles.txt
@@ -1,15 +1,15 @@
# /home/classic/dev/sqlalchemy/test/profiles.txt
# This file is written out on a per-environment basis.
-# For each test in aaa_profiling, the corresponding function and
+# For each test in aaa_profiling, the corresponding function and
# environment is located within this file. If it doesn't exist,
# the test is skipped.
-# If a callcount does exist, it is compared to what we received.
+# If a callcount does exist, it is compared to what we received.
# assertions are raised if the counts do not match.
-#
-# To add a new callcount test, apply the function_call_count
-# decorator and re-run the tests using the --write-profiles
+#
+# To add a new callcount test, apply the function_call_count
+# decorator and re-run the tests using the --write-profiles
# option - this file will be rewritten including the new count.
-#
+#
# TEST: test.aaa_profiling.test_compiler.CompileTest.test_insert
@@ -240,10 +240,10 @@ test.aaa_profiling.test_orm.AttributeOverheadTest.test_collection_append_remove
# TEST: test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching
-test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching x86_64_linux_cpython_2.7_sqlite_pysqlite_dbapiunicode_cextensions 60
-test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching x86_64_linux_cpython_2.7_sqlite_pysqlite_dbapiunicode_nocextensions 60
-test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching x86_64_linux_cpython_3.9_sqlite_pysqlite_dbapiunicode_cextensions 61
-test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching x86_64_linux_cpython_3.9_sqlite_pysqlite_dbapiunicode_nocextensions 61
+test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching x86_64_linux_cpython_2.7_sqlite_pysqlite_dbapiunicode_cextensions 68
+test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching x86_64_linux_cpython_2.7_sqlite_pysqlite_dbapiunicode_nocextensions 68
+test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching x86_64_linux_cpython_3.9_sqlite_pysqlite_dbapiunicode_cextensions 73
+test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching x86_64_linux_cpython_3.9_sqlite_pysqlite_dbapiunicode_nocextensions 73
# TEST: test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_unbound_branching
diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py
index 257776c50..e96a47553 100644
--- a/test/sql/test_compare.py
+++ b/test/sql/test_compare.py
@@ -514,6 +514,14 @@ class CoreFixtures(object):
),
),
lambda: (
+ # test issue #6503
+ # join from table_a -> table_c, select table_b.c.a
+ select(table_a).join(table_c).with_only_columns(table_b.c.a),
+ # join from table_b -> table_c, select table_b.c.a
+ select(table_b.c.a).join(table_c).with_only_columns(table_b.c.a),
+ select(table_a).with_only_columns(table_b.c.a),
+ ),
+ lambda: (
table_a.insert(),
table_a.insert().values({})._annotate({"nocache": True}),
table_b.insert(),
diff --git a/test/sql/test_external_traversal.py b/test/sql/test_external_traversal.py
index 3469dcb37..c7e51c807 100644
--- a/test/sql/test_external_traversal.py
+++ b/test/sql/test_external_traversal.py
@@ -1747,6 +1747,29 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
"addresses.user_id",
)
+ def test_prev_entities_adapt(self):
+ """test #6503"""
+
+ m = MetaData()
+ users = Table("users", m, Column("id", Integer, primary_key=True))
+ addresses = Table(
+ "addresses",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("user_id", ForeignKey("users.id")),
+ )
+
+ ualias = users.alias()
+
+ s = select(users).join(addresses).with_only_columns(addresses.c.id)
+ s = sql_util.ClauseAdapter(ualias).traverse(s)
+
+ self.assert_compile(
+ s,
+ "SELECT addresses.id FROM users AS users_1 "
+ "JOIN addresses ON users_1.id = addresses.user_id",
+ )
+
@testing.combinations((True,), (False,), argnames="use_adapt_from")
def test_table_to_alias_1(self, use_adapt_from):
t1alias = t1.alias("t1alias")
diff --git a/test/sql/test_select.py b/test/sql/test_select.py
index f9f1acfa0..d1f9e381f 100644
--- a/test/sql/test_select.py
+++ b/test/sql/test_select.py
@@ -266,6 +266,33 @@ class FutureSelectTest(fixtures.TestBase, AssertsCompiledSQL):
"ON parent.id = child.parent_id",
)
+ def test_join_implicit_left_side_wo_cols_onelevel(self):
+ """test issue #6503"""
+ stmt = select(parent).join(child).with_only_columns(child.c.id)
+
+ self.assert_compile(
+ stmt,
+ "SELECT child.id FROM parent "
+ "JOIN child ON parent.id = child.parent_id",
+ )
+
+ def test_join_implicit_left_side_wo_cols_twolevel(self):
+ """test issue #6503"""
+ stmt = (
+ select(parent)
+ .join(child)
+ .with_only_columns(child.c.id)
+ .join(grandchild)
+ .with_only_columns(grandchild.c.id)
+ )
+
+ self.assert_compile(
+ stmt,
+ "SELECT grandchild.id FROM parent "
+ "JOIN child ON parent.id = child.parent_id "
+ "JOIN grandchild ON child.id = grandchild.child_id",
+ )
+
def test_right_nested_inner_join(self):
inner = child.join(grandchild)