summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/orm/context.py3
-rw-r--r--lib/sqlalchemy/orm/util.py17
-rw-r--r--lib/sqlalchemy/sql/annotation.py12
3 files changed, 26 insertions, 6 deletions
diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py
index 61b957280..4e2586203 100644
--- a/lib/sqlalchemy/orm/context.py
+++ b/lib/sqlalchemy/orm/context.py
@@ -2137,7 +2137,8 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
for ae in self.global_attributes[
("additional_entity_criteria", ext_info.mapper)
]
- if ae.include_aliases or ae.entity is ext_info
+ if (ae.include_aliases or ae.entity is ext_info)
+ and ae._should_include(self)
)
else:
return ()
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 140464b87..fef65f73c 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -1149,11 +1149,24 @@ class LoaderCriteriaOption(CriteriaOption):
else:
stack.extend(subclass.__subclasses__())
+ def _should_include(self, compile_state):
+ if (
+ compile_state.select_statement._annotations.get(
+ "for_loader_criteria", None
+ )
+ is self
+ ):
+ return False
+ return True
+
def _resolve_where_criteria(self, ext_info):
if self.deferred_where_criteria:
- return self.where_criteria._resolve_with_args(ext_info.entity)
+ crit = self.where_criteria._resolve_with_args(ext_info.entity)
else:
- return self.where_criteria
+ crit = self.where_criteria
+ return sql_util._deep_annotate(
+ crit, {"for_loader_criteria": self}, detect_subquery_cols=True
+ )
def process_compile_state_replaced_entities(
self, compile_state, mapper_entities
diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py
index 519a3103b..1706da44e 100644
--- a/lib/sqlalchemy/sql/annotation.py
+++ b/lib/sqlalchemy/sql/annotation.py
@@ -243,7 +243,9 @@ class Annotated:
annotated_classes = {}
-def _deep_annotate(element, annotations, exclude=None):
+def _deep_annotate(
+ element, annotations, exclude=None, detect_subquery_cols=False
+):
"""Deep copy the given ClauseElement, annotating each element
with the given annotations dictionary.
@@ -257,6 +259,7 @@ def _deep_annotate(element, annotations, exclude=None):
cloned_ids = {}
def clone(elem, **kw):
+ kw["detect_subquery_cols"] = detect_subquery_cols
id_ = id(elem)
if id_ in cloned_ids:
@@ -267,9 +270,12 @@ def _deep_annotate(element, annotations, exclude=None):
and hasattr(elem, "proxy_set")
and elem.proxy_set.intersection(exclude)
):
- newelem = elem._clone(**kw)
+ newelem = elem._clone(clone=clone, **kw)
elif annotations != elem._annotations:
- newelem = elem._annotate(annotations)
+ if detect_subquery_cols and elem._is_immutable:
+ newelem = elem._clone(clone=clone, **kw)._annotate(annotations)
+ else:
+ newelem = elem._annotate(annotations)
else:
newelem = elem
newelem._copy_internals(clone=clone)