diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2019-08-30 22:23:44 +0000 |
|---|---|---|
| committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2019-08-30 22:23:44 +0000 |
| commit | b83c41c44bad0b166ad9a2355d10641b0310e2fe (patch) | |
| tree | 9870ea0f1195da751a2c08d33288f74cd3c663e8 /lib/sqlalchemy | |
| parent | 520f8579d1785e6f906947ff103aaa8db8330621 (diff) | |
| parent | f6c9b20a04d183d86078252048563b14e27fb6d2 (diff) | |
| download | sqlalchemy-b83c41c44bad0b166ad9a2355d10641b0310e2fe.tar.gz | |
Merge "Annotate session-bind-lookup entity in Query-produced selectables"
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/orm/query.py | 77 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/annotation.py | 74 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 40 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 2 |
4 files changed, 138 insertions, 55 deletions
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 936929703..d4ff35d2e 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -384,6 +384,25 @@ class Query(object): else self._query_entity_zero().entity_zero ) + def _deep_entity_zero(self): + """Return a 'deep' entity; this is any entity we can find associated + with the first entity / column experssion. this is used only for + session.get_bind(). + + """ + + if ( + self._select_from_entity is not None + and not self._select_from_entity.is_clause_element + ): + return self._select_from_entity.mapper + for ent in self._entities: + ezero = ent._deep_entity_zero() + if ezero is not None: + return ezero.mapper + else: + return None + @property def _mapper_entities(self): for ent in self._entities: @@ -394,13 +413,7 @@ class Query(object): return self._joinpoint.get("_joinpoint_entity", self._entity_zero()) def _bind_mapper(self): - ezero = self._entity_zero() - if ezero is not None: - insp = inspect(ezero) - if not insp.is_clause_element: - return insp.mapper - - return None + return self._deep_entity_zero() def _only_full_mapper_zero(self, methname): if self._entities != [self._primary_entity]: @@ -3900,6 +3913,12 @@ class Query(object): else: context.statement = self._simple_statement(context) + if for_statement: + ezero = self._mapper_zero() + if ezero is not None: + context.statement = context.statement._annotate( + {"deepentity": ezero} + ) return context def _compound_eager_statement(self, context): @@ -4161,6 +4180,9 @@ class _MapperEntity(_QueryEntity): def entity_zero_or_selectable(self): return self.entity_zero + def _deep_entity_zero(self): + return self.entity_zero + def corresponds_to(self, entity): return _entity_corresponds_to(self.entity_zero, entity) @@ -4430,6 +4452,14 @@ class _BundleEntity(_QueryEntity): else: return None + def _deep_entity_zero(self): + for ent in self._entities: + ezero = ent._deep_entity_zero() + if ezero is not None: + return ezero + else: + return None + def adapt_to_selectable(self, query, sel): c = _BundleEntity(query, self.bundle, setup_entities=False) # c._label_name = self._label_name @@ -4530,7 +4560,7 @@ class _ColumnEntity(_QueryEntity): # of FROMs for the overall expression - this helps # subqueries which were built from ORM constructs from # leaking out their entities into the main select construct - self.actual_froms = actual_froms = set(column._from_objects) + self.actual_froms = set(column._from_objects) if not search_entities: self.entity_zero = _entity @@ -4540,7 +4570,6 @@ class _ColumnEntity(_QueryEntity): else: self.entities = [] self.mapper = None - self._from_entities = set(self.entities) else: all_elements = [ elem @@ -4551,21 +4580,9 @@ class _ColumnEntity(_QueryEntity): ] self.entities = util.unique_list( - [ - elem._annotations["parententity"] - for elem in all_elements - if "parententity" in elem._annotations - ] + [elem._annotations["parententity"] for elem in all_elements] ) - self._from_entities = set( - [ - elem._annotations["parententity"] - for elem in all_elements - if "parententity" in elem._annotations - and actual_froms.intersection(elem._from_objects) - ] - ) if self.entities: self.entity_zero = self.entities[0] self.mapper = self.entity_zero.mapper @@ -4578,6 +4595,22 @@ class _ColumnEntity(_QueryEntity): supports_single_entity = False + def _deep_entity_zero(self): + if self.mapper is not None: + return self.mapper + + else: + for obj in visitors.iterate( + self.column, + {"column_tables": True, "column_collections": False}, + ): + if "parententity" in obj._annotations: + return obj._annotations["parententity"] + elif "deepentity" in obj._annotations: + return obj._annotations["deepentity"] + else: + return None + @property def entity_zero_or_selectable(self): if self.entity_zero is not None: diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index 7fc9245ab..a0264845e 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -15,8 +15,80 @@ from . import operators from .. import util +class SupportsCloneAnnotations(object): + _annotations = util.immutabledict() + + def _annotate(self, values): + """return a copy of this ClauseElement with annotations + updated by the given dictionary. + + """ + new = self._clone() + new._annotations = new._annotations.union(values) + return new + + def _with_annotations(self, values): + """return a copy of this ClauseElement with annotations + replaced by the given dictionary. + + """ + new = self._clone() + new._annotations = util.immutabledict(values) + return new + + def _deannotate(self, values=None, clone=False): + """return a copy of this :class:`.ClauseElement` with annotations + removed. + + :param values: optional tuple of individual values + to remove. + + """ + if clone or self._annotations: + # clone is used when we are also copying + # the expression for a deep deannotation + new = self._clone() + new._annotations = {} + return new + else: + return self + + +class SupportsWrappingAnnotations(object): + def _annotate(self, values): + """return a copy of this ClauseElement with annotations + updated by the given dictionary. + + """ + return Annotated(self, values) + + def _with_annotations(self, values): + """return a copy of this ClauseElement with annotations + replaced by the given dictionary. + + """ + return Annotated(self, values) + + def _deannotate(self, values=None, clone=False): + """return a copy of this :class:`.ClauseElement` with annotations + removed. + + :param values: optional tuple of individual values + to remove. + + """ + if clone: + # clone is used when we are also copying + # the expression for a deep deannotation + return self._clone() + else: + # if no clone, since we have no annotations we return + # self + return self + + class Annotated(object): - """clones a ClauseElement and applies an 'annotations' dictionary. + """clones a SupportsAnnotated and applies an 'annotations' dictionary. Unlike regular clones, this clone also mimics __hash__() and __cmp__() of the original element so that it takes its place diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 42e7522ae..bc6f51b8c 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -22,6 +22,7 @@ from . import operators from . import roles from . import type_api from .annotation import Annotated +from .annotation import SupportsWrappingAnnotations from .base import _clone from .base import _generative from .base import Executable @@ -161,7 +162,7 @@ def not_(clause): @inspection._self_inspects -class ClauseElement(roles.SQLRole, Visitable): +class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable): """Base class for elements of a programmatically constructed SQL expression. @@ -276,37 +277,6 @@ class ClauseElement(roles.SQLRole, Visitable): d.pop("_is_clone_of", None) return d - def _annotate(self, values): - """return a copy of this ClauseElement with annotations - updated by the given dictionary. - - """ - return Annotated(self, values) - - def _with_annotations(self, values): - """return a copy of this ClauseElement with annotations - replaced by the given dictionary. - - """ - return Annotated(self, values) - - def _deannotate(self, values=None, clone=False): - """return a copy of this :class:`.ClauseElement` with annotations - removed. - - :param values: optional tuple of individual values - to remove. - - """ - if clone: - # clone is used when we are also copying - # the expression for a deep deannotation - return self._clone() - else: - # if no clone, since we have no annotations we return - # self - return self - def _execute_on_connection(self, connection, multiparams, params): if self.supports_execution: return connection._execute_clauseelement(self, multiparams, params) @@ -4230,6 +4200,12 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement): self._memoized_property.expire_instance(self) self.__dict__["table"] = table + def get_children(self, column_tables=False, **kw): + if column_tables and self.table is not None: + return [self.table] + else: + return [] + table = property(_get_table, _set_table) def _cache_key(self, **kw): diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 03dbcd449..97c49f8fc 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -19,6 +19,7 @@ from . import operators from . import roles from . import type_api from .annotation import Annotated +from .annotation import SupportsCloneAnnotations from .base import _clone from .base import _cloned_difference from .base import _cloned_intersection @@ -2068,6 +2069,7 @@ class SelectBase( roles.InElementRole, HasCTE, Executable, + SupportsCloneAnnotations, Selectable, ): """Base class for SELECT statements. |
