summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2019-08-30 22:23:44 +0000
committerGerrit Code Review <gerrit@bbpush.zzzcomputing.com>2019-08-30 22:23:44 +0000
commitb83c41c44bad0b166ad9a2355d10641b0310e2fe (patch)
tree9870ea0f1195da751a2c08d33288f74cd3c663e8 /lib/sqlalchemy
parent520f8579d1785e6f906947ff103aaa8db8330621 (diff)
parentf6c9b20a04d183d86078252048563b14e27fb6d2 (diff)
downloadsqlalchemy-b83c41c44bad0b166ad9a2355d10641b0310e2fe.tar.gz
Merge "Annotate session-bind-lookup entity in Query-produced selectables"
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/orm/query.py77
-rw-r--r--lib/sqlalchemy/sql/annotation.py74
-rw-r--r--lib/sqlalchemy/sql/elements.py40
-rw-r--r--lib/sqlalchemy/sql/selectable.py2
4 files changed, 138 insertions, 55 deletions
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 936929703..d4ff35d2e 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -384,6 +384,25 @@ class Query(object):
else self._query_entity_zero().entity_zero
)
+ def _deep_entity_zero(self):
+ """Return a 'deep' entity; this is any entity we can find associated
+ with the first entity / column experssion. this is used only for
+ session.get_bind().
+
+ """
+
+ if (
+ self._select_from_entity is not None
+ and not self._select_from_entity.is_clause_element
+ ):
+ return self._select_from_entity.mapper
+ for ent in self._entities:
+ ezero = ent._deep_entity_zero()
+ if ezero is not None:
+ return ezero.mapper
+ else:
+ return None
+
@property
def _mapper_entities(self):
for ent in self._entities:
@@ -394,13 +413,7 @@ class Query(object):
return self._joinpoint.get("_joinpoint_entity", self._entity_zero())
def _bind_mapper(self):
- ezero = self._entity_zero()
- if ezero is not None:
- insp = inspect(ezero)
- if not insp.is_clause_element:
- return insp.mapper
-
- return None
+ return self._deep_entity_zero()
def _only_full_mapper_zero(self, methname):
if self._entities != [self._primary_entity]:
@@ -3900,6 +3913,12 @@ class Query(object):
else:
context.statement = self._simple_statement(context)
+ if for_statement:
+ ezero = self._mapper_zero()
+ if ezero is not None:
+ context.statement = context.statement._annotate(
+ {"deepentity": ezero}
+ )
return context
def _compound_eager_statement(self, context):
@@ -4161,6 +4180,9 @@ class _MapperEntity(_QueryEntity):
def entity_zero_or_selectable(self):
return self.entity_zero
+ def _deep_entity_zero(self):
+ return self.entity_zero
+
def corresponds_to(self, entity):
return _entity_corresponds_to(self.entity_zero, entity)
@@ -4430,6 +4452,14 @@ class _BundleEntity(_QueryEntity):
else:
return None
+ def _deep_entity_zero(self):
+ for ent in self._entities:
+ ezero = ent._deep_entity_zero()
+ if ezero is not None:
+ return ezero
+ else:
+ return None
+
def adapt_to_selectable(self, query, sel):
c = _BundleEntity(query, self.bundle, setup_entities=False)
# c._label_name = self._label_name
@@ -4530,7 +4560,7 @@ class _ColumnEntity(_QueryEntity):
# of FROMs for the overall expression - this helps
# subqueries which were built from ORM constructs from
# leaking out their entities into the main select construct
- self.actual_froms = actual_froms = set(column._from_objects)
+ self.actual_froms = set(column._from_objects)
if not search_entities:
self.entity_zero = _entity
@@ -4540,7 +4570,6 @@ class _ColumnEntity(_QueryEntity):
else:
self.entities = []
self.mapper = None
- self._from_entities = set(self.entities)
else:
all_elements = [
elem
@@ -4551,21 +4580,9 @@ class _ColumnEntity(_QueryEntity):
]
self.entities = util.unique_list(
- [
- elem._annotations["parententity"]
- for elem in all_elements
- if "parententity" in elem._annotations
- ]
+ [elem._annotations["parententity"] for elem in all_elements]
)
- self._from_entities = set(
- [
- elem._annotations["parententity"]
- for elem in all_elements
- if "parententity" in elem._annotations
- and actual_froms.intersection(elem._from_objects)
- ]
- )
if self.entities:
self.entity_zero = self.entities[0]
self.mapper = self.entity_zero.mapper
@@ -4578,6 +4595,22 @@ class _ColumnEntity(_QueryEntity):
supports_single_entity = False
+ def _deep_entity_zero(self):
+ if self.mapper is not None:
+ return self.mapper
+
+ else:
+ for obj in visitors.iterate(
+ self.column,
+ {"column_tables": True, "column_collections": False},
+ ):
+ if "parententity" in obj._annotations:
+ return obj._annotations["parententity"]
+ elif "deepentity" in obj._annotations:
+ return obj._annotations["deepentity"]
+ else:
+ return None
+
@property
def entity_zero_or_selectable(self):
if self.entity_zero is not None:
diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py
index 7fc9245ab..a0264845e 100644
--- a/lib/sqlalchemy/sql/annotation.py
+++ b/lib/sqlalchemy/sql/annotation.py
@@ -15,8 +15,80 @@ from . import operators
from .. import util
+class SupportsCloneAnnotations(object):
+ _annotations = util.immutabledict()
+
+ def _annotate(self, values):
+ """return a copy of this ClauseElement with annotations
+ updated by the given dictionary.
+
+ """
+ new = self._clone()
+ new._annotations = new._annotations.union(values)
+ return new
+
+ def _with_annotations(self, values):
+ """return a copy of this ClauseElement with annotations
+ replaced by the given dictionary.
+
+ """
+ new = self._clone()
+ new._annotations = util.immutabledict(values)
+ return new
+
+ def _deannotate(self, values=None, clone=False):
+ """return a copy of this :class:`.ClauseElement` with annotations
+ removed.
+
+ :param values: optional tuple of individual values
+ to remove.
+
+ """
+ if clone or self._annotations:
+ # clone is used when we are also copying
+ # the expression for a deep deannotation
+ new = self._clone()
+ new._annotations = {}
+ return new
+ else:
+ return self
+
+
+class SupportsWrappingAnnotations(object):
+ def _annotate(self, values):
+ """return a copy of this ClauseElement with annotations
+ updated by the given dictionary.
+
+ """
+ return Annotated(self, values)
+
+ def _with_annotations(self, values):
+ """return a copy of this ClauseElement with annotations
+ replaced by the given dictionary.
+
+ """
+ return Annotated(self, values)
+
+ def _deannotate(self, values=None, clone=False):
+ """return a copy of this :class:`.ClauseElement` with annotations
+ removed.
+
+ :param values: optional tuple of individual values
+ to remove.
+
+ """
+ if clone:
+ # clone is used when we are also copying
+ # the expression for a deep deannotation
+ return self._clone()
+ else:
+ # if no clone, since we have no annotations we return
+ # self
+ return self
+
+
class Annotated(object):
- """clones a ClauseElement and applies an 'annotations' dictionary.
+ """clones a SupportsAnnotated and applies an 'annotations' dictionary.
Unlike regular clones, this clone also mimics __hash__() and
__cmp__() of the original element so that it takes its place
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 42e7522ae..bc6f51b8c 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -22,6 +22,7 @@ from . import operators
from . import roles
from . import type_api
from .annotation import Annotated
+from .annotation import SupportsWrappingAnnotations
from .base import _clone
from .base import _generative
from .base import Executable
@@ -161,7 +162,7 @@ def not_(clause):
@inspection._self_inspects
-class ClauseElement(roles.SQLRole, Visitable):
+class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
"""Base class for elements of a programmatically constructed SQL
expression.
@@ -276,37 +277,6 @@ class ClauseElement(roles.SQLRole, Visitable):
d.pop("_is_clone_of", None)
return d
- def _annotate(self, values):
- """return a copy of this ClauseElement with annotations
- updated by the given dictionary.
-
- """
- return Annotated(self, values)
-
- def _with_annotations(self, values):
- """return a copy of this ClauseElement with annotations
- replaced by the given dictionary.
-
- """
- return Annotated(self, values)
-
- def _deannotate(self, values=None, clone=False):
- """return a copy of this :class:`.ClauseElement` with annotations
- removed.
-
- :param values: optional tuple of individual values
- to remove.
-
- """
- if clone:
- # clone is used when we are also copying
- # the expression for a deep deannotation
- return self._clone()
- else:
- # if no clone, since we have no annotations we return
- # self
- return self
-
def _execute_on_connection(self, connection, multiparams, params):
if self.supports_execution:
return connection._execute_clauseelement(self, multiparams, params)
@@ -4230,6 +4200,12 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement):
self._memoized_property.expire_instance(self)
self.__dict__["table"] = table
+ def get_children(self, column_tables=False, **kw):
+ if column_tables and self.table is not None:
+ return [self.table]
+ else:
+ return []
+
table = property(_get_table, _set_table)
def _cache_key(self, **kw):
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 03dbcd449..97c49f8fc 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -19,6 +19,7 @@ from . import operators
from . import roles
from . import type_api
from .annotation import Annotated
+from .annotation import SupportsCloneAnnotations
from .base import _clone
from .base import _cloned_difference
from .base import _cloned_intersection
@@ -2068,6 +2069,7 @@ class SelectBase(
roles.InElementRole,
HasCTE,
Executable,
+ SupportsCloneAnnotations,
Selectable,
):
"""Base class for SELECT statements.