diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2019-12-01 17:24:27 -0500 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-05-24 11:54:08 -0400 |
| commit | dce8c7a125cb99fad62c76cd145752d5afefae36 (patch) | |
| tree | 352dfa2c38005207ca64f45170bbba2c0f8c927e /lib/sqlalchemy | |
| parent | 1502b5b3e4e4b93021eb927a6623f288ef006ba6 (diff) | |
| download | sqlalchemy-dce8c7a125cb99fad62c76cd145752d5afefae36.tar.gz | |
Unify Query and select() , move all processing to compile phase
Convert Query to do virtually all compile state computation
in the _compile_context() phase, and organize it all
such that a plain select() construct may also be used as the
source of information in order to generate ORM query state.
This makes it such that Query is not needed except for
its additional methods like from_self() which are all to
be deprecated.
The construction of ORM state will occur beyond the
caching boundary when the new execution model is integrated.
future select() gains a working join() and filter_by() method.
as we continue to rebase and merge each commit in the steps,
callcounts continue to bump around. will have to look at
the final result when it's all in.
References: #5159
References: #4705
References: #4639
References: #4871
References: #5010
Change-Id: I19e05b3424b07114cce6c439b05198ac47f7ac10
Diffstat (limited to 'lib/sqlalchemy')
37 files changed, 4781 insertions, 2730 deletions
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index c477c4292..ee02899f6 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1149,6 +1149,7 @@ class Connection(Connectable): # ensure we don't retain a link to the view object for keys() # which links to the values, which we don't want to cache keys = list(distilled_params[0].keys()) + else: keys = [] @@ -1184,6 +1185,9 @@ class Connection(Connectable): schema_translate_map=schema_translate_map, linting=self.dialect.compiler_linting | compiler.WARN_LINTING, + compile_state_factories=exec_opts.get( + "compile_state_factories", None + ), ) cache[key] = compiled_sql @@ -1195,6 +1199,9 @@ class Connection(Connectable): inline=len(distilled_params) > 1, schema_translate_map=schema_translate_map, linting=self.dialect.compiler_linting | compiler.WARN_LINTING, + compile_state_factories=exec_opts.get( + "compile_state_factories", None + ), ) ret = self._execute_context( diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py index a9c79d6bd..24af454b6 100644 --- a/lib/sqlalchemy/ext/baked.py +++ b/lib/sqlalchemy/ext/baked.py @@ -13,19 +13,18 @@ compiled result to be fully cached. """ -import copy import logging from .. import exc as sa_exc from .. import util from ..orm import exc as orm_exc from ..orm import strategy_options +from ..orm.context import QueryContext from ..orm.query import Query from ..orm.session import Session from ..sql import func from ..sql import literal_column from ..sql import util as sql_util -from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from ..util import collections_abc @@ -209,9 +208,7 @@ class BakedQuery(object): key += cache_key self.add_criteria( - lambda q: q._with_current_path( - effective_path - )._conditional_options(*options), + lambda q: q._with_current_path(effective_path).options(*options), cache_path.path, key, ) @@ -228,14 +225,21 @@ class BakedQuery(object): def _bake(self, session): query = self._as_query(session) - context = query._compile_context() + compile_state = query._compile_state() - self._bake_subquery_loaders(session, context) - context.session = None - context.query = query = context.query.with_session(None) + self._bake_subquery_loaders(session, compile_state) + + # TODO: compile_state clearly needs to be simplified here. + # if the session remains, fails memusage test + compile_state.orm_query = ( + query + ) = ( + compile_state.select_statement + ) = compile_state.query = compile_state.orm_query.with_session(None) query._execution_options = query._execution_options.union( {"compiled_cache": self._bakery} ) + # we'll be holding onto the query for some of its state, # so delete some compilation-use-only attributes that can take up # space @@ -251,10 +255,10 @@ class BakedQuery(object): # if the query is not safe to cache, we still do everything as though # we did cache it, since the receiver of _bake() assumes subqueryload # context was set up, etc. - if context.query._bake_ok: - self._bakery[self._effective_key(session)] = context + if compile_state.compile_options._bake_ok: + self._bakery[self._effective_key(session)] = compile_state - return context + return compile_state def to_query(self, query_or_session): """Return the :class:`_query.Query` object for use as a subquery. @@ -314,9 +318,10 @@ class BakedQuery(object): for step in self.steps[1:]: query = step(query) + return query - def _bake_subquery_loaders(self, session, context): + def _bake_subquery_loaders(self, session, compile_state): """convert subquery eager loaders in the cache into baked queries. For subquery eager loading to work, all we need here is that the @@ -325,28 +330,30 @@ class BakedQuery(object): a "baked" query so that we save on performance too. """ - context.attributes["baked_queries"] = baked_queries = [] - for k, v in list(context.attributes.items()): - if isinstance(v, Query): - if "subquery" in k: - bk = BakedQuery(self._bakery, lambda *args: v) + compile_state.attributes["baked_queries"] = baked_queries = [] + for k, v in list(compile_state.attributes.items()): + if isinstance(v, dict) and "query" in v: + if "subqueryload_data" in k: + query = v["query"] + bk = BakedQuery(self._bakery, lambda *args: query) bk._cache_key = self._cache_key + k bk._bake(session) baked_queries.append((k, bk._cache_key, v)) - del context.attributes[k] + del compile_state.attributes[k] def _unbake_subquery_loaders( - self, session, context, params, post_criteria + self, session, compile_state, context, params, post_criteria ): """Retrieve subquery eager loaders stored by _bake_subquery_loaders and turn them back into Result objects that will iterate just like a Query object. """ - if "baked_queries" not in context.attributes: + if "baked_queries" not in compile_state.attributes: return - for k, cache_key, query in context.attributes["baked_queries"]: + for k, cache_key, v in compile_state.attributes["baked_queries"]: + query = v["query"] bk = BakedQuery( self._bakery, lambda sess, q=query: q.with_session(sess) ) @@ -354,7 +361,9 @@ class BakedQuery(object): q = bk.for_session(session) for fn in post_criteria: q = q.with_post_criteria(fn) - context.attributes[k] = q.params(**params) + v = dict(v) + v["query"] = q.params(**params) + context.attributes[k] = v class Result(object): @@ -432,26 +441,37 @@ class Result(object): if not self.session.enable_baked_queries or bq._spoiled: return self._as_query()._iter() - baked_context = bq._bakery.get(bq._effective_key(self.session), None) - if baked_context is None: - baked_context = bq._bake(self.session) + baked_compile_state = bq._bakery.get( + bq._effective_key(self.session), None + ) + if baked_compile_state is None: + baked_compile_state = bq._bake(self.session) - context = copy.copy(baked_context) + context = QueryContext(baked_compile_state, self.session) context.session = self.session - context.attributes = context.attributes.copy() bq._unbake_subquery_loaders( - self.session, context, self._params, self._post_criteria + self.session, + baked_compile_state, + context, + self._params, + self._post_criteria, ) - context.statement._label_style = LABEL_STYLE_TABLENAME_PLUS_COL + # asserts true + # if isinstance(baked_compile_state.statement, expression.Select): + # assert baked_compile_state.statement._label_style == \ + # LABEL_STYLE_TABLENAME_PLUS_COL + if context.autoflush and not context.populate_existing: self.session._autoflush() - q = context.query.params(self._params).with_session(self.session) + q = context.orm_query.params(self._params).with_session(self.session) for fn in self._post_criteria: q = fn(q) - return q._execute_and_instances(context) + params = q.load_options._params + + return q._execute_and_instances(context, params=params) def count(self): """return the 'count'. @@ -566,7 +586,7 @@ class Result(object): def _load_on_pk_identity(self, query, primary_key_identity): """Load the given primary key identity from the database.""" - mapper = query._mapper_zero() + mapper = query._only_full_mapper_zero("load_on_pk_identity") _get_clause, _get_params = mapper._get_clause @@ -592,8 +612,11 @@ class Result(object): _lcl_get_clause, nones ) - _lcl_get_clause = q._adapt_clause(_lcl_get_clause, True, False) - q._criterion = _lcl_get_clause + # TODO: can mapper._get_clause be pre-adapted? + q._where_criteria = ( + sql_util._deep_annotate(_lcl_get_clause, {"_orm_adapt": True}), + ) + for fn in self._post_criteria: q = fn(q) return q diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index 931f45699..919f4409a 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -43,7 +43,10 @@ class ShardedQuery(Query): q._shard_id = shard_id return q - def _execute_and_instances(self, context): + def _execute_and_instances(self, context, params=None): + if params is None: + params = self.load_options._params + def iter_for_shard(shard_id): # shallow copy, so that each context may be used by # ORM load events and similar. @@ -54,8 +57,11 @@ class ShardedQuery(Query): "shard_id" ] = copied_context.identity_token = shard_id result_ = self._connection_from_session( - mapper=self._bind_mapper(), shard_id=shard_id - ).execute(copied_context.statement, self._params) + mapper=context.compile_state._bind_mapper(), shard_id=shard_id + ).execute( + copied_context.compile_state.statement, + self.load_options._params, + ) return self.instances(result_, copied_context) if context.identity_token is not None: @@ -78,7 +84,7 @@ class ShardedQuery(Query): clause=stmt, close_with_result=True, ) - result = conn.execute(stmt, self._params) + result = conn.execute(stmt, self.load_options._params) return result if self._shard_id is not None: diff --git a/lib/sqlalchemy/ext/serializer.py b/lib/sqlalchemy/ext/serializer.py index ec5e8985c..afd44ca3d 100644 --- a/lib/sqlalchemy/ext/serializer.py +++ b/lib/sqlalchemy/ext/serializer.py @@ -59,7 +59,6 @@ from .. import Column from .. import Table from ..engine import Engine from ..orm import class_mapper -from ..orm.attributes import QueryableAttribute from ..orm.interfaces import MapperProperty from ..orm.mapper import Mapper from ..orm.session import Session @@ -78,11 +77,7 @@ def Serializer(*args, **kw): def persistent_id(obj): # print "serializing:", repr(obj) - if isinstance(obj, QueryableAttribute): - cls = obj.impl.class_ - key = obj.impl.key - id_ = "attribute:" + key + ":" + b64encode(pickle.dumps(cls)) - elif isinstance(obj, Mapper) and not obj.non_primary: + if isinstance(obj, Mapper) and not obj.non_primary: id_ = "mapper:" + b64encode(pickle.dumps(obj.class_)) elif isinstance(obj, MapperProperty) and not obj.parent.non_primary: id_ = ( @@ -92,7 +87,12 @@ def Serializer(*args, **kw): + obj.key ) elif isinstance(obj, Table): - id_ = "table:" + text_type(obj.key) + if "parententity" in obj._annotations: + id_ = "mapper_selectable:" + b64encode( + pickle.dumps(obj._annotations["parententity"].class_) + ) + else: + id_ = "table:" + text_type(obj.key) elif isinstance(obj, Column) and isinstance(obj.table, Table): id_ = ( "column:" + text_type(obj.table.key) + ":" + text_type(obj.key) @@ -110,7 +110,8 @@ def Serializer(*args, **kw): our_ids = re.compile( - r"(mapperprop|mapper|table|column|session|attribute|engine):(.*)" + r"(mapperprop|mapper|mapper_selectable|table|column|" + r"session|attribute|engine):(.*)" ) @@ -140,6 +141,9 @@ def Deserializer(file, metadata=None, scoped_session=None, engine=None): elif type_ == "mapper": cls = pickle.loads(b64decode(args)) return class_mapper(cls) + elif type_ == "mapper_selectable": + cls = pickle.loads(b64decode(args)) + return class_mapper(cls).__clause_element__() elif type_ == "mapperprop": mapper, keyname = args.split(":") cls = pickle.loads(b64decode(mapper)) diff --git a/lib/sqlalchemy/future/__init__.py b/lib/sqlalchemy/future/__init__.py index 635afa78c..6a3581599 100644 --- a/lib/sqlalchemy/future/__init__.py +++ b/lib/sqlalchemy/future/__init__.py @@ -11,7 +11,7 @@ from .engine import Connection # noqa from .engine import create_engine # noqa from .engine import Engine # noqa -from ..sql.selectable import Select +from .selectable import Select # noqa from ..util.langhelpers import public_factory -select = public_factory(Select._create_select, ".future.select") +select = public_factory(Select._create_future_select, ".future.select") diff --git a/lib/sqlalchemy/future/selectable.py b/lib/sqlalchemy/future/selectable.py new file mode 100644 index 000000000..2b76245e0 --- /dev/null +++ b/lib/sqlalchemy/future/selectable.py @@ -0,0 +1,144 @@ +from ..sql import coercions +from ..sql import roles +from ..sql.base import _generative +from ..sql.selectable import GenerativeSelect +from ..sql.selectable import Select as _LegacySelect +from ..sql.selectable import SelectState +from ..sql.util import _entity_namespace_key + + +class Select(_LegacySelect): + _is_future = True + _setup_joins = () + _legacy_setup_joins = () + + @classmethod + def _create_select(cls, *entities): + raise NotImplementedError("use _create_future_select") + + @classmethod + def _create_future_select(cls, *entities): + r"""Construct a new :class:`_expression.Select` using the 2. + x style API. + + .. versionadded:: 2.0 - the :func:`_future.select` construct is + the same construct as the one returned by + :func:`_expression.select`, except that the function only + accepts the "columns clause" entities up front; the rest of the + state of the SELECT should be built up using generative methods. + + Similar functionality is also available via the + :meth:`_expression.FromClause.select` method on any + :class:`_expression.FromClause`. + + .. seealso:: + + :ref:`coretutorial_selecting` - Core Tutorial description of + :func:`_expression.select`. + + :param \*entities: + Entities to SELECT from. For Core usage, this is typically a series + of :class:`_expression.ColumnElement` and / or + :class:`_expression.FromClause` + objects which will form the columns clause of the resulting + statement. For those objects that are instances of + :class:`_expression.FromClause` (typically :class:`_schema.Table` + or :class:`_expression.Alias` + objects), the :attr:`_expression.FromClause.c` + collection is extracted + to form a collection of :class:`_expression.ColumnElement` objects. + + This parameter will also accept :class:`_expression.TextClause` + constructs as + given, as well as ORM-mapped classes. + + """ + + self = cls.__new__(cls) + self._raw_columns = [ + coercions.expect(roles.ColumnsClauseRole, ent, apply_plugins=self) + for ent in entities + ] + + GenerativeSelect.__init__(self) + + return self + + def filter(self, *criteria): + """A synonym for the :meth:`_future.Select.where` method.""" + + return self.where(*criteria) + + def _filter_by_zero(self): + if self._setup_joins: + meth = SelectState.get_plugin_classmethod( + self, "determine_last_joined_entity" + ) + _last_joined_entity = meth(self) + if _last_joined_entity is not None: + return _last_joined_entity + + if self._from_obj: + return self._from_obj[0] + + return self._raw_columns[0] + + def filter_by(self, **kwargs): + r"""apply the given filtering criterion as a WHERE clause + to this select. + + """ + from_entity = self._filter_by_zero() + + clauses = [ + _entity_namespace_key(from_entity, key) == value + for key, value in kwargs.items() + ] + return self.filter(*clauses) + + @_generative + def join(self, target, onclause=None, isouter=False, full=False): + r"""Create a SQL JOIN against this :class:`_expresson.Select` + object's criterion + and apply generatively, returning the newly resulting + :class:`_expression.Select`. + + + """ + target = coercions.expect( + roles.JoinTargetRole, target, apply_plugins=self + ) + self._setup_joins += ( + (target, onclause, None, {"isouter": isouter, "full": full}), + ) + + @_generative + def join_from( + self, from_, target, onclause=None, isouter=False, full=False + ): + r"""Create a SQL JOIN against this :class:`_expresson.Select` + object's criterion + and apply generatively, returning the newly resulting + :class:`_expression.Select`. + + + """ + + target = coercions.expect( + roles.JoinTargetRole, target, apply_plugins=self + ) + from_ = coercions.expect( + roles.FromClauseRole, from_, apply_plugins=self + ) + + self._setup_joins += ( + (target, onclause, from_, {"isouter": isouter, "full": full}), + ) + + def outerjoin(self, target, onclause=None, full=False): + """Create a left outer join. + + + + """ + return self.join(target, onclause=onclause, isouter=True, full=full,) diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 24945ef52..0a353f81c 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -30,7 +30,6 @@ from .mapper import reconstructor # noqa from .mapper import validates # noqa from .properties import ColumnProperty # noqa from .query import AliasOption # noqa -from .query import Bundle # noqa from .query import Query # noqa from .relationships import foreign # noqa from .relationships import RelationshipProperty # noqa @@ -44,6 +43,7 @@ from .session import Session # noqa from .session import sessionmaker # noqa from .strategy_options import Load # noqa from .util import aliased # noqa +from .util import Bundle # noqa from .util import join # noqa from .util import object_mapper # noqa from .util import outerjoin # noqa diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index ec706d4d8..7b4415bfe 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -49,6 +49,7 @@ from .. import event from .. import inspection from .. import util from ..sql import base as sql_base +from ..sql import roles from ..sql import visitors @@ -57,7 +58,8 @@ class QueryableAttribute( interfaces._MappedAttribute, interfaces.InspectionAttr, interfaces.PropComparator, - sql_base.HasCacheKey, + roles.JoinTargetRole, + sql_base.MemoizedHasCacheKey, ): """Base class for :term:`descriptor` objects that intercept attribute events on behalf of a :class:`.MapperProperty` @@ -107,12 +109,24 @@ class QueryableAttribute( self.dispatch._active_history = True _cache_key_traversal = [ - # ("class_", visitors.ExtendedInternalTraversal.dp_plain_obj), ("key", visitors.ExtendedInternalTraversal.dp_string), ("_parententity", visitors.ExtendedInternalTraversal.dp_multi), ("_of_type", visitors.ExtendedInternalTraversal.dp_multi), ] + def __reduce__(self): + # this method is only used in terms of the + # sqlalchemy.ext.serializer extension + return ( + _queryable_attribute_unreduce, + ( + self.key, + self._parententity.mapper.class_, + self._parententity, + self._parententity.entity, + ), + ) + @util.memoized_property def _supports_population(self): return self.impl.supports_population @@ -208,14 +222,14 @@ class QueryableAttribute( parententity=adapt_to_entity, ) - def of_type(self, cls): + def of_type(self, entity): return QueryableAttribute( self.class_, self.key, self.impl, - self.comparator.of_type(cls), + self.comparator.of_type(entity), self._parententity, - of_type=cls, + of_type=inspection.inspect(entity), ) def label(self, name): @@ -265,6 +279,15 @@ class QueryableAttribute( return self.comparator.property +def _queryable_attribute_unreduce(key, mapped_class, parententity, entity): + # this method is only used in terms of the + # sqlalchemy.ext.serializer extension + if parententity.is_aliased_class: + return entity._get_from_serialized(key, mapped_class, parententity) + else: + return getattr(entity, key) + + class InstrumentedAttribute(QueryableAttribute): """Class bound instrumented attribute which adds basic :term:`descriptor` methods. diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index 07809282b..77a85425e 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -479,6 +479,9 @@ class InspectionAttr(object): is_mapper = False """True if this object is an instance of :class:`_orm.Mapper`.""" + is_bundle = False + """True if this object is an instance of :class:`.Bundle`.""" + is_property = False """True if this object is an instance of :class:`.MapperProperty`.""" diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py new file mode 100644 index 000000000..0a3701134 --- /dev/null +++ b/lib/sqlalchemy/orm/context.py @@ -0,0 +1,2349 @@ +# orm/context.py +# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + + +from . import attributes +from . import interfaces +from . import loading +from .base import _is_aliased_class +from .interfaces import ORMColumnsClauseRole +from .path_registry import PathRegistry +from .util import _entity_corresponds_to +from .util import aliased +from .util import Bundle +from .util import join as orm_join +from .util import ORMAdapter +from .. import exc as sa_exc +from .. import inspect +from .. import sql +from .. import util +from ..future.selectable import Select as FutureSelect +from ..sql import coercions +from ..sql import expression +from ..sql import roles +from ..sql import util as sql_util +from ..sql import visitors +from ..sql.base import CacheableOptions +from ..sql.base import Options +from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from ..sql.selectable import Select +from ..sql.selectable import SelectState +from ..sql.visitors import ExtendedInternalTraversal +from ..sql.visitors import InternalTraversal + +_path_registry = PathRegistry.root + + +class QueryContext(object): + __slots__ = ( + "compile_state", + "orm_query", + "query", + "load_options", + "session", + "autoflush", + "populate_existing", + "invoke_all_eagers", + "version_check", + "refresh_state", + "create_eager_joins", + "propagate_options", + "attributes", + "runid", + "partials", + "post_load_paths", + "identity_token", + "yield_per", + ) + + class default_load_options(Options): + _only_return_tuples = False + _populate_existing = False + _version_check = False + _invoke_all_eagers = True + _autoflush = True + _refresh_identity_token = None + _yield_per = None + _refresh_state = None + _lazy_loaded_from = None + _params = util.immutabledict() + + def __init__(self, compile_state, session): + query = compile_state.query + + self.compile_state = compile_state + self.orm_query = compile_state.orm_query + self.query = compile_state.query + self.session = session + self.load_options = load_options = query.load_options + + self.propagate_options = set( + o for o in query._with_options if o.propagate_to_loaders + ) + self.attributes = dict(compile_state.attributes) + + self.autoflush = load_options._autoflush + self.populate_existing = load_options._populate_existing + self.invoke_all_eagers = load_options._invoke_all_eagers + self.version_check = load_options._version_check + self.refresh_state = load_options._refresh_state + self.yield_per = load_options._yield_per + + if self.refresh_state is not None: + self.identity_token = load_options._refresh_identity_token + else: + self.identity_token = None + + if self.yield_per and compile_state._no_yield_pers: + raise sa_exc.InvalidRequestError( + "The yield_per Query option is currently not " + "compatible with %s eager loading. Please " + "specify lazyload('*') or query.enable_eagerloads(False) in " + "order to " + "proceed with query.yield_per()." + % ", ".join(compile_state._no_yield_pers) + ) + + @property + def is_single_entity(self): + # used for the check if we return a list of entities or tuples. + # this is gone in 2.0 when we no longer make this decision. + return ( + not self.load_options._only_return_tuples + and len(self.compile_state._entities) == 1 + and self.compile_state._entities[0].supports_single_entity + ) + + +class QueryCompileState(sql.base.CompileState): + _joinpath = _joinpoint = util.immutabledict() + _from_obj_alias = None + _has_mapper_entities = False + + _has_orm_entities = False + multi_row_eager_loaders = False + compound_eager_adapter = None + loaders_require_buffering = False + loaders_require_uniquing = False + + correlate = None + _where_criteria = () + _having_criteria = () + + orm_query = None + + class default_compile_options(CacheableOptions): + _cache_key_traversal = [ + ("_bake_ok", InternalTraversal.dp_boolean), + ( + "_with_polymorphic_adapt_map", + ExtendedInternalTraversal.dp_has_cache_key_tuples, + ), + ("_current_path", InternalTraversal.dp_has_cache_key), + ("_enable_single_crit", InternalTraversal.dp_boolean), + ("_statement", InternalTraversal.dp_clauseelement), + ("_enable_eagerloads", InternalTraversal.dp_boolean), + ("_orm_only_from_obj_alias", InternalTraversal.dp_boolean), + ("_only_load_props", InternalTraversal.dp_plain_obj), + ("_set_base_alias", InternalTraversal.dp_boolean), + ("_for_refresh_state", InternalTraversal.dp_boolean), + ] + + _bake_ok = True + _with_polymorphic_adapt_map = () + _current_path = _path_registry + _enable_single_crit = True + _statement = None + _enable_eagerloads = True + _orm_only_from_obj_alias = True + _only_load_props = None + _set_base_alias = False + _for_refresh_state = False + + def __init__(self, *arg, **kw): + raise NotImplementedError() + + @classmethod + def _create_for_select(cls, statement, compiler, **kw): + if not statement._is_future: + return SelectState(statement, compiler, **kw) + + self = cls.__new__(cls) + + if not isinstance( + statement.compile_options, cls.default_compile_options + ): + statement.compile_options = cls.default_compile_options + orm_state = self._create_for_legacy_query_via_either(statement) + compile_state = SelectState(orm_state.statement, compiler, **kw) + compile_state._orm_state = orm_state + return compile_state + + @classmethod + def _create_future_select_from_query(cls, query): + stmt = FutureSelect.__new__(FutureSelect) + + # the internal state of Query is now a mirror of that of + # Select which can be transferred directly. The Select + # supports compilation into its correct form taking all ORM + # features into account via the plugin and the compile options. + # however it does not export its columns or other attributes + # correctly if deprecated ORM features that adapt plain mapped + # elements are used; for this reason the Select() returned here + # can always support direct execution, but for composition in a larger + # select only works if it does not represent legacy ORM adaption + # features. + stmt.__dict__.update( + dict( + _raw_columns=query._raw_columns, + _compile_state_plugin="orm", # ;) + _where_criteria=query._where_criteria, + _from_obj=query._from_obj, + _legacy_setup_joins=query._legacy_setup_joins, + _order_by_clauses=query._order_by_clauses, + _group_by_clauses=query._group_by_clauses, + _having_criteria=query._having_criteria, + _distinct=query._distinct, + _distinct_on=query._distinct_on, + _with_options=query._with_options, + _with_context_options=query._with_context_options, + _hints=query._hints, + _statement_hints=query._statement_hints, + _correlate=query._correlate, + _auto_correlate=query._auto_correlate, + _limit_clause=query._limit_clause, + _offset_clause=query._offset_clause, + _for_update_arg=query._for_update_arg, + _prefixes=query._prefixes, + _suffixes=query._suffixes, + _label_style=query._label_style, + compile_options=query.compile_options, + # this will be moving but for now make it work like orm.Query + load_options=query.load_options, + ) + ) + + return stmt + + @classmethod + def _create_for_legacy_query( + cls, query, for_statement=False, entities_only=False + ): + # as we are seeking to use Select() with ORM state as the + # primary executable element, have all Query objects that are not + # from_statement() convert to a Select() first, then run on that. + + if query.compile_options._statement is not None: + return cls._create_for_legacy_query_via_either( + query, + for_statement=for_statement, + entities_only=entities_only, + orm_query=query, + ) + + else: + assert query.compile_options._statement is None + + stmt = cls._create_future_select_from_query(query) + + return cls._create_for_legacy_query_via_either( + stmt, + for_statement=for_statement, + entities_only=entities_only, + orm_query=query, + ) + + @classmethod + def _create_for_legacy_query_via_either( + cls, query, for_statement=False, entities_only=False, orm_query=None + ): + + self = cls.__new__(cls) + + self._primary_entity = None + + self.has_select = isinstance(query, Select) + + if orm_query: + self.orm_query = orm_query + self.query = query + self.has_orm_query = True + else: + self.query = query + if not self.has_select: + self.orm_query = query + self.has_orm_query = True + else: + self.orm_query = None + self.has_orm_query = False + + self.select_statement = select_statement = query + + self.query = query + + self._entities = [] + + self._aliased_generations = {} + self._polymorphic_adapters = {} + self._no_yield_pers = set() + + # legacy: only for query.with_polymorphic() + self._with_polymorphic_adapt_map = wpam = dict( + select_statement.compile_options._with_polymorphic_adapt_map + ) + if wpam: + self._setup_with_polymorphics() + + _QueryEntity.to_compile_state(self, select_statement._raw_columns) + + if entities_only: + return self + + self.compile_options = query.compile_options + self.for_statement = for_statement + + if self.has_orm_query and not for_statement: + self.label_style = LABEL_STYLE_TABLENAME_PLUS_COL + else: + self.label_style = self.select_statement._label_style + + self.labels = self.label_style is LABEL_STYLE_TABLENAME_PLUS_COL + + self.current_path = select_statement.compile_options._current_path + + self.eager_order_by = () + + if select_statement._with_options: + self.attributes = {"_unbound_load_dedupes": set()} + + for opt in self.select_statement._with_options: + if not opt._is_legacy_option: + opt.process_compile_state(self) + else: + self.attributes = {} + + if select_statement._with_context_options: + for fn, key in select_statement._with_context_options: + fn(self) + + self.primary_columns = [] + self.secondary_columns = [] + self.eager_joins = {} + self.single_inh_entities = {} + self.create_eager_joins = [] + self._fallback_from_clauses = [] + + self.from_clauses = [ + info.selectable for info in select_statement._from_obj + ] + + if self.compile_options._statement is not None: + self._setup_for_statement() + else: + self._setup_for_generate() + + return self + + def _setup_with_polymorphics(self): + # legacy: only for query.with_polymorphic() + for ext_info, wp in self._with_polymorphic_adapt_map.items(): + self._mapper_loads_polymorphically_with(ext_info, wp._adapter) + + def _set_select_from_alias(self): + + query = self.select_statement # query + + assert self.compile_options._set_base_alias + assert len(query._from_obj) == 1 + + adapter = self._get_select_from_alias_from_obj(query._from_obj[0]) + if adapter: + self.compile_options += {"_enable_single_crit": False} + self._from_obj_alias = adapter + + def _get_select_from_alias_from_obj(self, from_obj): + info = from_obj + + if "parententity" in info._annotations: + info = info._annotations["parententity"] + + if hasattr(info, "mapper"): + if not info.is_aliased_class: + raise sa_exc.ArgumentError( + "A selectable (FromClause) instance is " + "expected when the base alias is being set." + ) + else: + return info._adapter + + elif isinstance(info.selectable, sql.selectable.AliasedReturnsRows): + equivs = self._all_equivs() + return sql_util.ColumnAdapter(info, equivs) + else: + return None + + def _mapper_zero(self): + """return the Mapper associated with the first QueryEntity.""" + return self._entities[0].mapper + + def _entity_zero(self): + """Return the 'entity' (mapper or AliasedClass) associated + with the first QueryEntity, or alternatively the 'select from' + entity if specified.""" + + for ent in self.from_clauses: + if "parententity" in ent._annotations: + return ent._annotations["parententity"] + for qent in self._entities: + if qent.entity_zero: + return qent.entity_zero + + return None + + def _deep_entity_zero(self): + """Return a 'deep' entity; this is any entity we can find associated + with the first entity / column experssion. this is used only for + session.get_bind(). + + it is hoped this concept can be removed in an upcoming change + to the ORM execution model. + + """ + for ent in self.from_clauses: + if "parententity" in ent._annotations: + return ent._annotations["parententity"].mapper + for ent in self._entities: + ezero = ent._deep_entity_zero() + if ezero is not None: + return ezero.mapper + else: + return None + + @property + def _mapper_entities(self): + for ent in self._entities: + if isinstance(ent, _MapperEntity): + yield ent + + def _bind_mapper(self): + return self._deep_entity_zero() + + def _only_full_mapper_zero(self, methname): + if self._entities != [self._primary_entity]: + raise sa_exc.InvalidRequestError( + "%s() can only be used against " + "a single mapped class." % methname + ) + return self._primary_entity.entity_zero + + def _only_entity_zero(self, rationale=None): + if len(self._entities) > 1: + raise sa_exc.InvalidRequestError( + rationale + or "This operation requires a Query " + "against a single mapper." + ) + return self._entity_zero() + + def _all_equivs(self): + equivs = {} + for ent in self._mapper_entities: + equivs.update(ent.mapper._equivalent_columns) + return equivs + + def _setup_for_generate(self): + query = self.select_statement + + self.statement = None + self._join_entities = () + + if self.compile_options._set_base_alias: + self._set_select_from_alias() + + if query._setup_joins: + self._join(query._setup_joins) + + if query._legacy_setup_joins: + self._legacy_join(query._legacy_setup_joins) + + current_adapter = self._get_current_adapter() + + if query._where_criteria: + self._where_criteria = query._where_criteria + + if current_adapter: + self._where_criteria = tuple( + current_adapter(crit, True) + for crit in self._where_criteria + ) + + # TODO: some complexity with order_by here was due to mapper.order_by. + # now that this is removed we can hopefully make order_by / + # group_by act identically to how they are in Core select. + self.order_by = ( + self._adapt_col_list(query._order_by_clauses, current_adapter) + if current_adapter and query._order_by_clauses not in (None, False) + else query._order_by_clauses + ) + + if query._having_criteria is not None: + self._having_criteria = tuple( + current_adapter(crit, True, True) if current_adapter else crit + for crit in query._having_criteria + ) + + self.group_by = ( + self._adapt_col_list( + util.flatten_iterator(query._group_by_clauses), current_adapter + ) + if current_adapter and query._group_by_clauses not in (None, False) + else query._group_by_clauses or None + ) + + if self.eager_order_by: + adapter = self.from_clauses[0]._target_adapter + self.eager_order_by = adapter.copy_and_process(self.eager_order_by) + + if query._distinct_on: + self.distinct_on = self._adapt_col_list( + query._distinct_on, current_adapter + ) + else: + self.distinct_on = () + + self.distinct = query._distinct + + if query._correlate: + # ORM mapped entities that are mapped to joins can be passed + # to .correlate, so here they are broken into their component + # tables. + self.correlate = tuple( + util.flatten_iterator( + sql_util.surface_selectables(s) if s is not None else None + for s in query._correlate + ) + ) + elif self.has_select and not query._auto_correlate: + self.correlate = (None,) + + # PART II + + self.dedupe_cols = True + + self._for_update_arg = query._for_update_arg + + for entity in self._entities: + entity.setup_compile_state(self) + + for rec in self.create_eager_joins: + strategy = rec[0] + strategy(self, *rec[1:]) + + # else "load from discrete FROMs" mode, + # i.e. when each _MappedEntity has its own FROM + + if self.compile_options._enable_single_crit: + + self._adjust_for_single_inheritance() + + if not self.primary_columns: + if self.compile_options._only_load_props: + raise sa_exc.InvalidRequestError( + "No column-based properties specified for " + "refresh operation. Use session.expire() " + "to reload collections and related items." + ) + else: + raise sa_exc.InvalidRequestError( + "Query contains no columns with which to SELECT from." + ) + + if not self.from_clauses: + self.from_clauses = list(self._fallback_from_clauses) + + if self.order_by is False: + self.order_by = None + + if self.multi_row_eager_loaders and self._should_nest_selectable: + self.statement = self._compound_eager_statement() + else: + self.statement = self._simple_statement() + + if self.for_statement: + ezero = self._mapper_zero() + if ezero is not None: + # TODO: this goes away once we get rid of the deep entity + # thing + self.statement = self.statement._annotate( + {"deepentity": ezero} + ) + + def _setup_for_statement(self): + compile_options = self.compile_options + + if ( + isinstance(compile_options._statement, expression.SelectBase) + and not compile_options._statement._is_textual + and not compile_options._statement.use_labels + ): + self.statement = compile_options._statement.apply_labels() + else: + self.statement = compile_options._statement + self.order_by = None + + if isinstance(self.statement, expression.TextClause): + # setup for all entities, including contains_eager entities. + for entity in self._entities: + entity.setup_compile_state(self) + self.statement = expression.TextualSelect( + self.statement, self.primary_columns, positional=False + ) + else: + # allow TextualSelect with implicit columns as well + # as select() with ad-hoc columns, see test_query::TextTest + self._from_obj_alias = sql.util.ColumnAdapter( + self.statement, adapt_on_names=True + ) + + def _compound_eager_statement(self): + # for eager joins present and LIMIT/OFFSET/DISTINCT, + # wrap the query inside a select, + # then append eager joins onto that + + if self.order_by: + # the default coercion for ORDER BY is now the OrderByRole, + # which adds an additional post coercion to ByOfRole in that + # elements are converted into label refernences. For the + # eager load / subquery wrapping case, we need to un-coerce + # the original expressions outside of the label references + # in order to have them render. + unwrapped_order_by = [ + elem.element + if isinstance(elem, sql.elements._label_reference) + else elem + for elem in self.order_by + ] + + order_by_col_expr = sql_util.expand_column_list_from_order_by( + self.primary_columns, unwrapped_order_by + ) + else: + order_by_col_expr = [] + unwrapped_order_by = None + + # put FOR UPDATE on the inner query, where MySQL will honor it, + # as well as if it has an OF so PostgreSQL can use it. + inner = self._select_statement( + util.unique_list(self.primary_columns + order_by_col_expr) + if self.dedupe_cols + else (self.primary_columns + order_by_col_expr), + self.from_clauses, + self._where_criteria, + self._having_criteria, + self.label_style, + self.order_by, + for_update=self._for_update_arg, + hints=self.select_statement._hints, + statement_hints=self.select_statement._statement_hints, + correlate=self.correlate, + **self._select_args + ) + + inner = inner.alias() + + equivs = self._all_equivs() + + self.compound_eager_adapter = sql_util.ColumnAdapter(inner, equivs) + + statement = sql.select( + [inner] + self.secondary_columns, use_labels=self.labels + ) + + # Oracle however does not allow FOR UPDATE on the subquery, + # and the Oracle dialect ignores it, plus for PostgreSQL, MySQL + # we expect that all elements of the row are locked, so also put it + # on the outside (except in the case of PG when OF is used) + if ( + self._for_update_arg is not None + and self._for_update_arg.of is None + ): + statement._for_update_arg = self._for_update_arg + + from_clause = inner + for eager_join in self.eager_joins.values(): + # EagerLoader places a 'stop_on' attribute on the join, + # giving us a marker as to where the "splice point" of + # the join should be + from_clause = sql_util.splice_joins( + from_clause, eager_join, eager_join.stop_on + ) + + statement.select_from.non_generative(statement, from_clause) + + if unwrapped_order_by: + statement.order_by.non_generative( + statement, + *self.compound_eager_adapter.copy_and_process( + unwrapped_order_by + ) + ) + + statement.order_by.non_generative(statement, *self.eager_order_by) + return statement + + def _simple_statement(self): + + if (self.distinct and not self.distinct_on) and self.order_by: + to_add = sql_util.expand_column_list_from_order_by( + self.primary_columns, self.order_by + ) + if to_add: + util.warn_deprecated_20( + "ORDER BY columns added implicitly due to " + "DISTINCT is deprecated and will be removed in " + "SQLAlchemy 2.0. SELECT statements with DISTINCT " + "should be written to explicitly include the appropriate " + "columns in the columns clause" + ) + self.primary_columns += to_add + + statement = self._select_statement( + util.unique_list(self.primary_columns + self.secondary_columns) + if self.dedupe_cols + else (self.primary_columns + self.secondary_columns), + tuple(self.from_clauses) + tuple(self.eager_joins.values()), + self._where_criteria, + self._having_criteria, + self.label_style, + self.order_by, + for_update=self._for_update_arg, + hints=self.select_statement._hints, + statement_hints=self.select_statement._statement_hints, + correlate=self.correlate, + **self._select_args + ) + + if self.eager_order_by: + statement.order_by.non_generative(statement, *self.eager_order_by) + return statement + + def _select_statement( + self, + raw_columns, + from_obj, + where_criteria, + having_criteria, + label_style, + order_by, + for_update, + hints, + statement_hints, + correlate, + limit_clause, + offset_clause, + distinct, + distinct_on, + prefixes, + suffixes, + group_by, + ): + + statement = Select.__new__(Select) + statement._raw_columns = raw_columns + statement._from_obj = from_obj + statement._label_style = label_style + + if where_criteria: + statement._where_criteria = where_criteria + if having_criteria: + statement._having_criteria = having_criteria + + if order_by: + statement._order_by_clauses += tuple(order_by) + + if distinct_on: + statement.distinct.non_generative(statement, *distinct_on) + elif distinct: + statement.distinct.non_generative(statement) + + if group_by: + statement._group_by_clauses += tuple(group_by) + + statement._limit_clause = limit_clause + statement._offset_clause = offset_clause + + if prefixes: + statement._prefixes = prefixes + + if suffixes: + statement._suffixes = suffixes + + statement._for_update_arg = for_update + + if hints: + statement._hints = hints + if statement_hints: + statement._statement_hints = statement_hints + + if correlate: + statement.correlate.non_generative(statement, *correlate) + + return statement + + def _create_with_polymorphic_adapter(self, ext_info, selectable): + if ( + not ext_info.is_aliased_class + and ext_info.mapper.persist_selectable + not in self._polymorphic_adapters + ): + self._mapper_loads_polymorphically_with( + ext_info.mapper, + sql_util.ColumnAdapter( + selectable, ext_info.mapper._equivalent_columns + ), + ) + + def _mapper_loads_polymorphically_with(self, mapper, adapter): + for m2 in mapper._with_polymorphic_mappers or [mapper]: + self._polymorphic_adapters[m2] = adapter + for m in m2.iterate_to_root(): + self._polymorphic_adapters[m.local_table] = adapter + + def _adapt_polymorphic_element(self, element): + if "parententity" in element._annotations: + search = element._annotations["parententity"] + alias = self._polymorphic_adapters.get(search, None) + if alias: + return alias.adapt_clause(element) + + if isinstance(element, expression.FromClause): + search = element + elif hasattr(element, "table"): + search = element.table + else: + return None + + alias = self._polymorphic_adapters.get(search, None) + if alias: + return alias.adapt_clause(element) + + def _adapt_aliased_generation(self, element): + # this is crazy logic that I look forward to blowing away + # when aliased=True is gone :) + if "aliased_generation" in element._annotations: + for adapter in self._aliased_generations.get( + element._annotations["aliased_generation"], () + ): + replaced_elem = adapter.replace(element) + if replaced_elem is not None: + return replaced_elem + + return None + + def _adapt_col_list(self, cols, current_adapter): + if current_adapter: + return [current_adapter(o, True) for o in cols] + else: + return cols + + def _get_current_adapter(self): + + adapters = [] + + # vvvvvvvvvvvvvvv legacy vvvvvvvvvvvvvvvvvv + if self._from_obj_alias: + # for the "from obj" alias, apply extra rule to the + # 'ORM only' check, if this query were generated from a + # subquery of itself, i.e. _from_selectable(), apply adaption + # to all SQL constructs. + adapters.append( + ( + False + if self.compile_options._orm_only_from_obj_alias + else True, + self._from_obj_alias.replace, + ) + ) + + if self._aliased_generations: + adapters.append((False, self._adapt_aliased_generation)) + # ^^^^^^^^^^^^^ legacy ^^^^^^^^^^^^^^^^^^^^^ + + # this is the only adapter we would need going forward... + if self._polymorphic_adapters: + adapters.append((False, self._adapt_polymorphic_element)) + + if not adapters: + return None + + def _adapt_clause(clause, as_filter): + # do we adapt all expression elements or only those + # tagged as 'ORM' constructs ? + + def replace(elem): + is_orm_adapt = ( + "_orm_adapt" in elem._annotations + or "parententity" in elem._annotations + ) + for always_adapt, adapter in adapters: + if is_orm_adapt or always_adapt: + e = adapter(elem) + if e is not None: + return e + + return visitors.replacement_traverse(clause, {}, replace) + + return _adapt_clause + + def _join(self, args): + for (right, onclause, from_, flags) in args: + isouter = flags["isouter"] + full = flags["full"] + # maybe? + self._reset_joinpoint() + + if onclause is None and isinstance( + right, interfaces.PropComparator + ): + # determine onclause/right_entity. still need to think + # about how to best organize this since we are getting: + # + # + # q.join(Entity, Parent.property) + # q.join(Parent.property) + # q.join(Parent.property.of_type(Entity)) + # q.join(some_table) + # q.join(some_table, some_parent.c.id==some_table.c.parent_id) + # + # is this still too many choices? how do we handle this + # when sometimes "right" is implied and sometimes not? + # + onclause = right + right = None + + if onclause is None: + r_info = inspect(right) + if not r_info.is_selectable and not hasattr(r_info, "mapper"): + raise sa_exc.ArgumentError( + "Expected mapped entity or " + "selectable/table as join target" + ) + + if isinstance(onclause, interfaces.PropComparator): + of_type = getattr(onclause, "_of_type", None) + else: + of_type = None + + if isinstance(onclause, interfaces.PropComparator): + # descriptor/property given (or determined); this tells us + # explicitly what the expected "left" side of the join is. + if right is None: + if of_type: + right = of_type + else: + right = onclause.property.entity + + left = onclause._parententity + + alias = self._polymorphic_adapters.get(left, None) + + # could be None or could be ColumnAdapter also + if isinstance(alias, ORMAdapter) and alias.mapper.isa(left): + left = alias.aliased_class + onclause = getattr(left, onclause.key) + + prop = onclause.property + if not isinstance(onclause, attributes.QueryableAttribute): + onclause = prop + + # TODO: this is where "check for path already present" + # would occur. see if this still applies? + + if from_ is not None: + if ( + from_ is not left + and from_._annotations.get("parententity", None) + is not left + ): + raise sa_exc.InvalidRequestError( + "explicit from clause %s does not match left side " + "of relationship attribute %s" + % ( + from_._annotations.get("parententity", from_), + onclause, + ) + ) + elif from_ is not None: + prop = None + left = from_ + else: + # no descriptor/property given; we will need to figure out + # what the effective "left" side is + prop = left = None + + # figure out the final "left" and "right" sides and create an + # ORMJoin to add to our _from_obj tuple + self._join_left_to_right( + left, right, onclause, prop, False, False, isouter, full, + ) + + def _legacy_join(self, args): + """consumes arguments from join() or outerjoin(), places them into a + consistent format with which to form the actual JOIN constructs. + + """ + for (right, onclause, left, flags) in args: + + outerjoin = flags["isouter"] + create_aliases = flags["aliased"] + from_joinpoint = flags["from_joinpoint"] + full = flags["full"] + aliased_generation = flags["aliased_generation"] + + # legacy vvvvvvvvvvvvvvvvvvvvvvvvvv + if not from_joinpoint: + self._reset_joinpoint() + else: + prev_aliased_generation = self._joinpoint.get( + "aliased_generation", None + ) + if not aliased_generation: + aliased_generation = prev_aliased_generation + elif prev_aliased_generation: + self._aliased_generations[ + aliased_generation + ] = self._aliased_generations.get( + prev_aliased_generation, () + ) + # legacy ^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + if ( + isinstance( + right, (interfaces.PropComparator, util.string_types) + ) + and onclause is None + ): + onclause = right + right = None + elif "parententity" in right._annotations: + right = right._annotations["parententity"].entity + + if onclause is None: + r_info = inspect(right) + if not r_info.is_selectable and not hasattr(r_info, "mapper"): + raise sa_exc.ArgumentError( + "Expected mapped entity or " + "selectable/table as join target" + ) + + if isinstance(onclause, interfaces.PropComparator): + of_type = getattr(onclause, "_of_type", None) + else: + of_type = None + + if isinstance(onclause, util.string_types): + # string given, e.g. query(Foo).join("bar"). + # we look to the left entity or what we last joined + # towards + onclause = sql.util._entity_namespace_key( + inspect(self._joinpoint_zero()), onclause + ) + + # legacy vvvvvvvvvvvvvvvvvvvvvvvvvvvvvv + # check for q.join(Class.propname, from_joinpoint=True) + # and Class corresponds at the mapper level to the current + # joinpoint. this match intentionally looks for a non-aliased + # class-bound descriptor as the onclause and if it matches the + # current joinpoint at the mapper level, it's used. This + # is a very old use case that is intended to make it easier + # to work with the aliased=True flag, which is also something + # that probably shouldn't exist on join() due to its high + # complexity/usefulness ratio + elif from_joinpoint and isinstance( + onclause, interfaces.PropComparator + ): + jp0 = self._joinpoint_zero() + info = inspect(jp0) + + if getattr(info, "mapper", None) is onclause._parententity: + onclause = sql.util._entity_namespace_key( + info, onclause.key + ) + # legacy ^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + if isinstance(onclause, interfaces.PropComparator): + # descriptor/property given (or determined); this tells us + # explicitly what the expected "left" side of the join is. + if right is None: + if of_type: + right = of_type + else: + right = onclause.property.entity + + left = onclause._parententity + + alias = self._polymorphic_adapters.get(left, None) + + # could be None or could be ColumnAdapter also + if isinstance(alias, ORMAdapter) and alias.mapper.isa(left): + left = alias.aliased_class + onclause = getattr(left, onclause.key) + + prop = onclause.property + if not isinstance(onclause, attributes.QueryableAttribute): + onclause = prop + + if not create_aliases: + # check for this path already present. + # don't render in that case. + edge = (left, right, prop.key) + if edge in self._joinpoint: + # The child's prev reference might be stale -- + # it could point to a parent older than the + # current joinpoint. If this is the case, + # then we need to update it and then fix the + # tree's spine with _update_joinpoint. Copy + # and then mutate the child, which might be + # shared by a different query object. + jp = self._joinpoint[edge].copy() + jp["prev"] = (edge, self._joinpoint) + self._update_joinpoint(jp) + + continue + + else: + # no descriptor/property given; we will need to figure out + # what the effective "left" side is + prop = left = None + + # figure out the final "left" and "right" sides and create an + # ORMJoin to add to our _from_obj tuple + self._join_left_to_right( + left, + right, + onclause, + prop, + create_aliases, + aliased_generation, + outerjoin, + full, + ) + + def _joinpoint_zero(self): + return self._joinpoint.get("_joinpoint_entity", self._entity_zero()) + + def _join_left_to_right( + self, + left, + right, + onclause, + prop, + create_aliases, + aliased_generation, + outerjoin, + full, + ): + """given raw "left", "right", "onclause" parameters consumed from + a particular key within _join(), add a real ORMJoin object to + our _from_obj list (or augment an existing one) + + """ + + if left is None: + # left not given (e.g. no relationship object/name specified) + # figure out the best "left" side based on our existing froms / + # entities + assert prop is None + ( + left, + replace_from_obj_index, + use_entity_index, + ) = self._join_determine_implicit_left_side(left, right, onclause) + else: + # left is given via a relationship/name, or as explicit left side. + # Determine where in our + # "froms" list it should be spliced/appended as well as what + # existing entity it corresponds to. + ( + replace_from_obj_index, + use_entity_index, + ) = self._join_place_explicit_left_side(left) + + if left is right and not create_aliases: + raise sa_exc.InvalidRequestError( + "Can't construct a join from %s to %s, they " + "are the same entity" % (left, right) + ) + + # the right side as given often needs to be adapted. additionally + # a lot of things can be wrong with it. handle all that and + # get back the new effective "right" side + r_info, right, onclause = self._join_check_and_adapt_right_side( + left, right, onclause, prop, create_aliases, aliased_generation + ) + + if replace_from_obj_index is not None: + # splice into an existing element in the + # self._from_obj list + left_clause = self.from_clauses[replace_from_obj_index] + + self.from_clauses = ( + self.from_clauses[:replace_from_obj_index] + + [ + orm_join( + left_clause, + right, + onclause, + isouter=outerjoin, + full=full, + ) + ] + + self.from_clauses[replace_from_obj_index + 1 :] + ) + else: + # add a new element to the self._from_obj list + if use_entity_index is not None: + # make use of _MapperEntity selectable, which is usually + # entity_zero.selectable, but if with_polymorphic() were used + # might be distinct + assert isinstance( + self._entities[use_entity_index], _MapperEntity + ) + left_clause = self._entities[use_entity_index].selectable + else: + left_clause = left + + self.from_clauses = self.from_clauses + [ + orm_join( + left_clause, right, onclause, isouter=outerjoin, full=full + ) + ] + + def _join_determine_implicit_left_side(self, left, right, onclause): + """When join conditions don't express the left side explicitly, + determine if an existing FROM or entity in this query + can serve as the left hand side. + + """ + + # when we are here, it means join() was called without an ORM- + # specific way of telling us what the "left" side is, e.g.: + # + # join(RightEntity) + # + # or + # + # join(RightEntity, RightEntity.foo == LeftEntity.bar) + # + + r_info = inspect(right) + + replace_from_obj_index = use_entity_index = None + + if self.from_clauses: + # we have a list of FROMs already. So by definition this + # join has to connect to one of those FROMs. + + indexes = sql_util.find_left_clause_to_join_from( + self.from_clauses, r_info.selectable, onclause + ) + + if len(indexes) == 1: + replace_from_obj_index = indexes[0] + left = self.from_clauses[replace_from_obj_index] + elif len(indexes) > 1: + raise sa_exc.InvalidRequestError( + "Can't determine which FROM clause to join " + "from, there are multiple FROMS which can " + "join to this entity. Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explcit ON clause if not present already to " + "help resolve the ambiguity." + ) + else: + raise sa_exc.InvalidRequestError( + "Don't know how to join to %r. " + "Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explcit ON clause if not present already to " + "help resolve the ambiguity." % (right,) + ) + + elif self._entities: + # we have no explicit FROMs, so the implicit left has to + # come from our list of entities. + + potential = {} + for entity_index, ent in enumerate(self._entities): + entity = ent.entity_zero_or_selectable + if entity is None: + continue + ent_info = inspect(entity) + if ent_info is r_info: # left and right are the same, skip + continue + + # by using a dictionary with the selectables as keys this + # de-duplicates those selectables as occurs when the query is + # against a series of columns from the same selectable + if isinstance(ent, _MapperEntity): + potential[ent.selectable] = (entity_index, entity) + else: + potential[ent_info.selectable] = (None, entity) + + all_clauses = list(potential.keys()) + indexes = sql_util.find_left_clause_to_join_from( + all_clauses, r_info.selectable, onclause + ) + + if len(indexes) == 1: + use_entity_index, left = potential[all_clauses[indexes[0]]] + elif len(indexes) > 1: + raise sa_exc.InvalidRequestError( + "Can't determine which FROM clause to join " + "from, there are multiple FROMS which can " + "join to this entity. Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explcit ON clause if not present already to " + "help resolve the ambiguity." + ) + else: + raise sa_exc.InvalidRequestError( + "Don't know how to join to %r. " + "Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explcit ON clause if not present already to " + "help resolve the ambiguity." % (right,) + ) + else: + raise sa_exc.InvalidRequestError( + "No entities to join from; please use " + "select_from() to establish the left " + "entity/selectable of this join" + ) + + return left, replace_from_obj_index, use_entity_index + + def _join_place_explicit_left_side(self, left): + """When join conditions express a left side explicitly, determine + where in our existing list of FROM clauses we should join towards, + or if we need to make a new join, and if so is it from one of our + existing entities. + + """ + + # when we are here, it means join() was called with an indicator + # as to an exact left side, which means a path to a + # RelationshipProperty was given, e.g.: + # + # join(RightEntity, LeftEntity.right) + # + # or + # + # join(LeftEntity.right) + # + # as well as string forms: + # + # join(RightEntity, "right") + # + # etc. + # + + replace_from_obj_index = use_entity_index = None + + l_info = inspect(left) + if self.from_clauses: + indexes = sql_util.find_left_clause_that_matches_given( + self.from_clauses, l_info.selectable + ) + + if len(indexes) > 1: + raise sa_exc.InvalidRequestError( + "Can't identify which entity in which to assign the " + "left side of this join. Please use a more specific " + "ON clause." + ) + + # have an index, means the left side is already present in + # an existing FROM in the self._from_obj tuple + if indexes: + replace_from_obj_index = indexes[0] + + # no index, means we need to add a new element to the + # self._from_obj tuple + + # no from element present, so we will have to add to the + # self._from_obj tuple. Determine if this left side matches up + # with existing mapper entities, in which case we want to apply the + # aliasing / adaptation rules present on that entity if any + if ( + replace_from_obj_index is None + and self._entities + and hasattr(l_info, "mapper") + ): + for idx, ent in enumerate(self._entities): + # TODO: should we be checking for multiple mapper entities + # matching? + if isinstance(ent, _MapperEntity) and ent.corresponds_to(left): + use_entity_index = idx + break + + return replace_from_obj_index, use_entity_index + + def _join_check_and_adapt_right_side( + self, left, right, onclause, prop, create_aliases, aliased_generation + ): + """transform the "right" side of the join as well as the onclause + according to polymorphic mapping translations, aliasing on the query + or on the join, special cases where the right and left side have + overlapping tables. + + """ + + l_info = inspect(left) + r_info = inspect(right) + + overlap = False + if not create_aliases: + right_mapper = getattr(r_info, "mapper", None) + # if the target is a joined inheritance mapping, + # be more liberal about auto-aliasing. + if right_mapper and ( + right_mapper.with_polymorphic + or isinstance(right_mapper.persist_selectable, expression.Join) + ): + for from_obj in self.from_clauses or [l_info.selectable]: + if sql_util.selectables_overlap( + l_info.selectable, from_obj + ) and sql_util.selectables_overlap( + from_obj, r_info.selectable + ): + overlap = True + break + + if ( + overlap or not create_aliases + ) and l_info.selectable is r_info.selectable: + raise sa_exc.InvalidRequestError( + "Can't join table/selectable '%s' to itself" + % l_info.selectable + ) + + right_mapper, right_selectable, right_is_aliased = ( + getattr(r_info, "mapper", None), + r_info.selectable, + getattr(r_info, "is_aliased_class", False), + ) + + if ( + right_mapper + and prop + and not right_mapper.common_parent(prop.mapper) + ): + raise sa_exc.InvalidRequestError( + "Join target %s does not correspond to " + "the right side of join condition %s" % (right, onclause) + ) + + # _join_entities is used as a hint for single-table inheritance + # purposes at the moment + if hasattr(r_info, "mapper"): + self._join_entities += (r_info,) + + need_adapter = False + + # test for joining to an unmapped selectable as the target + if r_info.is_clause_element: + + if prop: + right_mapper = prop.mapper + + if right_selectable._is_lateral: + # orm_only is disabled to suit the case where we have to + # adapt an explicit correlate(Entity) - the select() loses + # the ORM-ness in this case right now, ideally it would not + current_adapter = self._get_current_adapter() + if current_adapter is not None: + # TODO: we had orm_only=False here before, removing + # it didn't break things. if we identify the rationale, + # may need to apply "_orm_only" annotation here. + right = current_adapter(right, True) + + elif prop: + # joining to selectable with a mapper property given + # as the ON clause + + if not right_selectable.is_derived_from( + right_mapper.persist_selectable + ): + raise sa_exc.InvalidRequestError( + "Selectable '%s' is not derived from '%s'" + % ( + right_selectable.description, + right_mapper.persist_selectable.description, + ) + ) + + # if the destination selectable is a plain select(), + # turn it into an alias(). + if isinstance(right_selectable, expression.SelectBase): + right_selectable = coercions.expect( + roles.FromClauseRole, right_selectable + ) + need_adapter = True + + # make the right hand side target into an ORM entity + right = aliased(right_mapper, right_selectable) + elif create_aliases: + # it *could* work, but it doesn't right now and I'd rather + # get rid of aliased=True completely + raise sa_exc.InvalidRequestError( + "The aliased=True parameter on query.join() only works " + "with an ORM entity, not a plain selectable, as the " + "target." + ) + + aliased_entity = ( + right_mapper + and not right_is_aliased + and ( + # TODO: there is a reliance here on aliasing occurring + # when we join to a polymorphic mapper that doesn't actually + # need aliasing. When this condition is present, we should + # be able to say mapper_loads_polymorphically_with() + # and render the straight polymorphic selectable. this + # does not appear to be possible at the moment as the + # adapter no longer takes place on the rest of the query + # and it's not clear where that's failing to happen. + ( + right_mapper.with_polymorphic + and isinstance( + right_mapper._with_polymorphic_selectable, + expression.AliasedReturnsRows, + ) + ) + or overlap + # test for overlap: + # orm/inheritance/relationships.py + # SelfReferentialM2MTest + ) + ) + + if not need_adapter and (create_aliases or aliased_entity): + # there are a few places in the ORM that automatic aliasing + # is still desirable, and can't be automatic with a Core + # only approach. For illustrations of "overlaps" see + # test/orm/inheritance/test_relationships.py. There are also + # general overlap cases with many-to-many tables where automatic + # aliasing is desirable. + right = aliased(right, flat=True) + need_adapter = True + + if need_adapter: + assert right_mapper + + adapter = ORMAdapter( + right, equivalents=right_mapper._equivalent_columns + ) + + # if an alias() on the right side was generated, + # which is intended to wrap a the right side in a subquery, + # ensure that columns retrieved from this target in the result + # set are also adapted. + if not create_aliases: + self._mapper_loads_polymorphically_with(right_mapper, adapter) + elif aliased_generation: + adapter._debug = True + self._aliased_generations[aliased_generation] = ( + adapter, + ) + self._aliased_generations.get(aliased_generation, ()) + + # if the onclause is a ClauseElement, adapt it with any + # adapters that are in place right now + if isinstance(onclause, expression.ClauseElement): + current_adapter = self._get_current_adapter() + if current_adapter: + onclause = current_adapter(onclause, True) + + # if joining on a MapperProperty path, + # track the path to prevent redundant joins + if not create_aliases and prop: + self._update_joinpoint( + { + "_joinpoint_entity": right, + "prev": ((left, right, prop.key), self._joinpoint), + "aliased_generation": aliased_generation, + } + ) + else: + self._joinpoint = { + "_joinpoint_entity": right, + "aliased_generation": aliased_generation, + } + + return right, inspect(right), onclause + + def _update_joinpoint(self, jp): + self._joinpoint = jp + # copy backwards to the root of the _joinpath + # dict, so that no existing dict in the path is mutated + while "prev" in jp: + f, prev = jp["prev"] + prev = dict(prev) + prev[f] = jp.copy() + jp["prev"] = (f, prev) + jp = prev + self._joinpath = jp + + def _reset_joinpoint(self): + self._joinpoint = self._joinpath + + @property + def _select_args(self): + return { + "limit_clause": self.select_statement._limit_clause, + "offset_clause": self.select_statement._offset_clause, + "distinct": self.distinct, + "distinct_on": self.distinct_on, + "prefixes": self.query._prefixes, + "suffixes": self.query._suffixes, + "group_by": self.group_by or None, + } + + @property + def _should_nest_selectable(self): + kwargs = self._select_args + return ( + kwargs.get("limit_clause") is not None + or kwargs.get("offset_clause") is not None + or kwargs.get("distinct", False) + or kwargs.get("distinct_on", ()) + or kwargs.get("group_by", False) + ) + + def _adjust_for_single_inheritance(self): + """Apply single-table-inheritance filtering. + + For all distinct single-table-inheritance mappers represented in + the columns clause of this query, as well as the "select from entity", + add criterion to the WHERE + clause of the given QueryContext such that only the appropriate + subtypes are selected from the total results. + + """ + + for fromclause in self.from_clauses: + ext_info = fromclause._annotations.get("parententity", None) + if ( + ext_info + and ext_info.mapper._single_table_criterion is not None + and ext_info not in self.single_inh_entities + ): + + self.single_inh_entities[ext_info] = ( + ext_info, + ext_info._adapter if ext_info.is_aliased_class else None, + ) + + search = set(self.single_inh_entities.values()) + + for (ext_info, adapter) in search: + if ext_info in self._join_entities: + continue + single_crit = ext_info.mapper._single_table_criterion + if single_crit is not None: + if adapter: + single_crit = adapter.traverse(single_crit) + + current_adapter = self._get_current_adapter() + if current_adapter: + single_crit = sql_util._deep_annotate( + single_crit, {"_orm_adapt": True} + ) + single_crit = current_adapter(single_crit, False) + self._where_criteria += (single_crit,) + + +def _column_descriptions(query_or_select_stmt): + # TODO: this is a hack for now, as it is a little bit non-performant + # to build up QueryEntity for every entity right now. + ctx = QueryCompileState._create_for_legacy_query_via_either( + query_or_select_stmt, + entities_only=True, + orm_query=query_or_select_stmt + if not isinstance(query_or_select_stmt, Select) + else None, + ) + return [ + { + "name": ent._label_name, + "type": ent.type, + "aliased": getattr(insp_ent, "is_aliased_class", False), + "expr": ent.expr, + "entity": getattr(insp_ent, "entity", None) + if ent.entity_zero is not None and not insp_ent.is_clause_element + else None, + } + for ent, insp_ent in [ + ( + _ent, + ( + inspect(_ent.entity_zero) + if _ent.entity_zero is not None + else None + ), + ) + for _ent in ctx._entities + ] + ] + + +def _legacy_filter_by_entity_zero(query_or_augmented_select): + self = query_or_augmented_select + if self._legacy_setup_joins: + _last_joined_entity = self._last_joined_entity + if _last_joined_entity is not None: + return _last_joined_entity + + if self._from_obj and "parententity" in self._from_obj[0]._annotations: + return self._from_obj[0]._annotations["parententity"] + + return _entity_from_pre_ent_zero(self) + + +def _entity_from_pre_ent_zero(query_or_augmented_select): + self = query_or_augmented_select + if not self._raw_columns: + return None + + ent = self._raw_columns[0] + + if "parententity" in ent._annotations: + return ent._annotations["parententity"] + elif isinstance(ent, ORMColumnsClauseRole): + return ent.entity + elif "bundle" in ent._annotations: + return ent._annotations["bundle"] + else: + return ent + + +@sql.base.CompileState.plugin_for( + "orm", "select", "determine_last_joined_entity" +) +def _determine_last_joined_entity(statement): + setup_joins = statement._setup_joins + + if not setup_joins: + return None + + (target, onclause, from_, flags) = setup_joins[-1] + + if isinstance(target, interfaces.PropComparator): + return target.entity + else: + return target + + +def _legacy_determine_last_joined_entity(setup_joins, entity_zero): + """given the legacy_setup_joins collection at a point in time, + figure out what the "filter by entity" would be in terms + of those joins. + + in 2.0 this logic should hopefully be much simpler as there will + be far fewer ways to specify joins with the ORM + + """ + + if not setup_joins: + return entity_zero + + # CAN BE REMOVED IN 2.0: + # 1. from_joinpoint + # 2. aliased_generation + # 3. aliased + # 4. any treating of prop as str + # 5. tuple madness + # 6. won't need recursive call anymore without #4 + # 7. therefore can pass in just the last setup_joins record, + # don't need entity_zero + + (right, onclause, left_, flags) = setup_joins[-1] + + from_joinpoint = flags["from_joinpoint"] + + if onclause is None and isinstance( + right, (str, interfaces.PropComparator) + ): + onclause = right + right = None + + if right is not None and "parententity" in right._annotations: + right = right._annotations["parententity"].entity + + if onclause is not None and right is not None: + last_entity = right + insp = inspect(last_entity) + if insp.is_clause_element or insp.is_aliased_class or insp.is_mapper: + return insp + + last_entity = onclause + if isinstance(last_entity, interfaces.PropComparator): + return last_entity.entity + + # legacy vvvvvvvvvvvvvvvvvvvvvvvvvvv + if isinstance(onclause, str): + if from_joinpoint: + prev = _legacy_determine_last_joined_entity( + setup_joins[0:-1], entity_zero + ) + else: + prev = entity_zero + + if prev is None: + return None + + prev = inspect(prev) + attr = getattr(prev.entity, onclause, None) + if attr is not None: + return attr.property.entity + # legacy ^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + return None + + +class _QueryEntity(object): + """represent an entity column returned within a Query result.""" + + __slots__ = () + + @classmethod + def to_compile_state(cls, compile_state, entities): + for entity in entities: + if entity.is_clause_element: + if entity.is_selectable: + if "parententity" in entity._annotations: + _MapperEntity(compile_state, entity) + else: + _ColumnEntity._for_columns( + compile_state, entity._select_iterable + ) + else: + if entity._annotations.get("bundle", False): + _BundleEntity(compile_state, entity) + elif entity._is_clause_list: + # this is legacy only - test_composites.py + # test_query_cols_legacy + _ColumnEntity._for_columns( + compile_state, entity._select_iterable + ) + else: + _ColumnEntity._for_columns(compile_state, [entity]) + elif entity.is_bundle: + _BundleEntity(compile_state, entity) + + +class _MapperEntity(_QueryEntity): + """mapper/class/AliasedClass entity""" + + __slots__ = ( + "expr", + "mapper", + "entity_zero", + "is_aliased_class", + "path", + "_extra_entities", + "_label_name", + "_with_polymorphic_mappers", + "selectable", + "_polymorphic_discriminator", + ) + + def __init__(self, compile_state, entity): + compile_state._entities.append(self) + if compile_state._primary_entity is None: + compile_state._primary_entity = self + compile_state._has_mapper_entities = True + compile_state._has_orm_entities = True + + entity = entity._annotations["parententity"] + entity._post_inspect + ext_info = self.entity_zero = entity + entity = ext_info.entity + + self.expr = entity + self.mapper = mapper = ext_info.mapper + + self._extra_entities = (self.expr,) + + if ext_info.is_aliased_class: + self._label_name = ext_info.name + else: + self._label_name = mapper.class_.__name__ + + self.is_aliased_class = ext_info.is_aliased_class + self.path = ext_info._path_registry + + if ext_info in compile_state._with_polymorphic_adapt_map: + # this codepath occurs only if query.with_polymorphic() were + # used + + wp = inspect(compile_state._with_polymorphic_adapt_map[ext_info]) + + if self.is_aliased_class: + # TODO: invalidrequest ? + raise NotImplementedError( + "Can't use with_polymorphic() against an Aliased object" + ) + + mappers, from_obj = mapper._with_polymorphic_args( + wp.with_polymorphic_mappers, wp.selectable + ) + + self._with_polymorphic_mappers = mappers + self.selectable = from_obj + self._polymorphic_discriminator = wp.polymorphic_on + + else: + self.selectable = ext_info.selectable + self._with_polymorphic_mappers = ext_info.with_polymorphic_mappers + self._polymorphic_discriminator = ext_info.polymorphic_on + + if mapper.with_polymorphic or mapper._requires_row_aliasing: + compile_state._create_with_polymorphic_adapter( + ext_info, self.selectable + ) + + supports_single_entity = True + + use_id_for_hash = True + + @property + def type(self): + return self.mapper.class_ + + @property + def entity_zero_or_selectable(self): + return self.entity_zero + + def _deep_entity_zero(self): + return self.entity_zero + + def corresponds_to(self, entity): + return _entity_corresponds_to(self.entity_zero, entity) + + def _get_entity_clauses(self, compile_state): + + adapter = None + + if not self.is_aliased_class: + if compile_state._polymorphic_adapters: + adapter = compile_state._polymorphic_adapters.get( + self.mapper, None + ) + else: + adapter = self.entity_zero._adapter + + if adapter: + if compile_state._from_obj_alias: + ret = adapter.wrap(compile_state._from_obj_alias) + else: + ret = adapter + else: + ret = compile_state._from_obj_alias + + return ret + + def row_processor(self, context, result): + compile_state = context.compile_state + adapter = self._get_entity_clauses(compile_state) + + if compile_state.compound_eager_adapter and adapter: + adapter = adapter.wrap(compile_state.compound_eager_adapter) + elif not adapter: + adapter = compile_state.compound_eager_adapter + + if compile_state._primary_entity is self: + only_load_props = compile_state.compile_options._only_load_props + refresh_state = context.refresh_state + else: + only_load_props = refresh_state = None + + _instance = loading._instance_processor( + self.mapper, + context, + result, + self.path, + adapter, + only_load_props=only_load_props, + refresh_state=refresh_state, + polymorphic_discriminator=self._polymorphic_discriminator, + ) + + return _instance, self._label_name, self._extra_entities + + def setup_compile_state(self, compile_state): + + adapter = self._get_entity_clauses(compile_state) + + single_table_crit = self.mapper._single_table_criterion + if single_table_crit is not None: + ext_info = self.entity_zero + compile_state.single_inh_entities[ext_info] = ( + ext_info, + ext_info._adapter if ext_info.is_aliased_class else None, + ) + + loading._setup_entity_query( + compile_state, + self.mapper, + self, + self.path, + adapter, + compile_state.primary_columns, + with_polymorphic=self._with_polymorphic_mappers, + only_load_props=compile_state.compile_options._only_load_props, + polymorphic_discriminator=self._polymorphic_discriminator, + ) + + compile_state._fallback_from_clauses.append(self.selectable) + + +class _BundleEntity(_QueryEntity): + use_id_for_hash = False + + _extra_entities = () + + __slots__ = ( + "bundle", + "expr", + "type", + "_label_name", + "_entities", + "supports_single_entity", + ) + + def __init__( + self, compile_state, expr, setup_entities=True, parent_bundle=None + ): + compile_state._has_orm_entities = True + + expr = expr._annotations["bundle"] + if parent_bundle: + parent_bundle._entities.append(self) + else: + compile_state._entities.append(self) + + if isinstance( + expr, (attributes.QueryableAttribute, interfaces.PropComparator) + ): + bundle = expr.__clause_element__() + else: + bundle = expr + + self.bundle = self.expr = bundle + self.type = type(bundle) + self._label_name = bundle.name + self._entities = [] + + if setup_entities: + for expr in bundle.exprs: + if "bundle" in expr._annotations: + _BundleEntity(compile_state, expr, parent_bundle=self) + elif isinstance(expr, Bundle): + _BundleEntity(compile_state, expr, parent_bundle=self) + else: + _ORMColumnEntity._for_columns( + compile_state, [expr], parent_bundle=self + ) + + self.supports_single_entity = self.bundle.single_entity + + @property + def mapper(self): + ezero = self.entity_zero + if ezero is not None: + return ezero.mapper + else: + return None + + @property + def entity_zero(self): + for ent in self._entities: + ezero = ent.entity_zero + if ezero is not None: + return ezero + else: + return None + + def corresponds_to(self, entity): + # TODO: we might be able to implement this but for now + # we are working around it + return False + + @property + def entity_zero_or_selectable(self): + for ent in self._entities: + ezero = ent.entity_zero_or_selectable + if ezero is not None: + return ezero + else: + return None + + def _deep_entity_zero(self): + for ent in self._entities: + ezero = ent._deep_entity_zero() + if ezero is not None: + return ezero + else: + return None + + def setup_compile_state(self, compile_state): + for ent in self._entities: + ent.setup_compile_state(compile_state) + + def row_processor(self, context, result): + procs, labels, extra = zip( + *[ent.row_processor(context, result) for ent in self._entities] + ) + + proc = self.bundle.create_row_processor(context.query, procs, labels) + + return proc, self._label_name, self._extra_entities + + +class _ColumnEntity(_QueryEntity): + __slots__ = () + + @classmethod + def _for_columns(cls, compile_state, columns, parent_bundle=None): + for column in columns: + annotations = column._annotations + if "parententity" in annotations: + _entity = annotations["parententity"] + else: + _entity = sql_util.extract_first_column_annotation( + column, "parententity" + ) + + if _entity: + _ORMColumnEntity( + compile_state, column, _entity, parent_bundle=parent_bundle + ) + else: + _RawColumnEntity( + compile_state, column, parent_bundle=parent_bundle + ) + + @property + def type(self): + return self.column.type + + @property + def use_id_for_hash(self): + return not self.column.type.hashable + + +class _RawColumnEntity(_ColumnEntity): + entity_zero = None + mapper = None + supports_single_entity = False + + __slots__ = ( + "expr", + "column", + "_label_name", + "entity_zero_or_selectable", + "_extra_entities", + ) + + def __init__(self, compile_state, column, parent_bundle=None): + self.expr = column + self._label_name = getattr(column, "key", None) + + if parent_bundle: + parent_bundle._entities.append(self) + else: + compile_state._entities.append(self) + + self.column = column + self.entity_zero_or_selectable = ( + self.column._from_objects[0] if self.column._from_objects else None + ) + self._extra_entities = (self.expr, self.column) + + def _deep_entity_zero(self): + for obj in visitors.iterate( + self.column, {"column_tables": True, "column_collections": False}, + ): + if "parententity" in obj._annotations: + return obj._annotations["parententity"] + elif "deepentity" in obj._annotations: + return obj._annotations["deepentity"] + else: + return None + + def corresponds_to(self, entity): + return False + + def row_processor(self, context, result): + if ("fetch_column", self) in context.attributes: + column = context.attributes[("fetch_column", self)] + else: + column = self.column + + if column._annotations: + # annotated columns perform more slowly in compiler and + # result due to the __eq__() method, so use deannotated + column = column._deannotate() + + compile_state = context.compile_state + if compile_state.compound_eager_adapter: + column = compile_state.compound_eager_adapter.columns[column] + + getter = result._getter(column) + return getter, self._label_name, self._extra_entities + + def setup_compile_state(self, compile_state): + current_adapter = compile_state._get_current_adapter() + if current_adapter: + column = current_adapter(self.column, False) + else: + column = self.column + + if column._annotations: + # annotated columns perform more slowly in compiler and + # result due to the __eq__() method, so use deannotated + column = column._deannotate() + + compile_state.primary_columns.append(column) + compile_state.attributes[("fetch_column", self)] = column + + +class _ORMColumnEntity(_ColumnEntity): + """Column/expression based entity.""" + + supports_single_entity = False + + __slots__ = ( + "expr", + "mapper", + "column", + "_label_name", + "entity_zero_or_selectable", + "entity_zero", + "_extra_entities", + ) + + def __init__( + self, compile_state, column, parententity, parent_bundle=None, + ): + + annotations = column._annotations + + _entity = parententity + + # an AliasedClass won't have orm_key in the annotations for + # a column if it was acquired using the class' adapter directly, + # such as using AliasedInsp._adapt_element(). this occurs + # within internal loaders. + self._label_name = _label_name = annotations.get("orm_key", None) + if _label_name: + self.expr = getattr(_entity.entity, _label_name) + else: + self._label_name = getattr(column, "key", None) + self.expr = column + + _entity._post_inspect + self.entity_zero = self.entity_zero_or_selectable = ezero = _entity + self.mapper = _entity.mapper + + if parent_bundle: + parent_bundle._entities.append(self) + else: + compile_state._entities.append(self) + + compile_state._has_orm_entities = True + self.column = column + + self._extra_entities = (self.expr, self.column) + + if self.mapper.with_polymorphic: + compile_state._create_with_polymorphic_adapter( + ezero, ezero.selectable + ) + + def _deep_entity_zero(self): + return self.mapper + + def corresponds_to(self, entity): + if _is_aliased_class(entity): + # TODO: polymorphic subclasses ? + return entity is self.entity_zero + else: + return not _is_aliased_class( + self.entity_zero + ) and entity.common_parent(self.entity_zero) + + def row_processor(self, context, result): + compile_state = context.compile_state + + if ("fetch_column", self) in context.attributes: + column = context.attributes[("fetch_column", self)] + else: + column = self.column + if compile_state._from_obj_alias: + column = compile_state._from_obj_alias.columns[column] + + if column._annotations: + # annotated columns perform more slowly in compiler and + # result due to the __eq__() method, so use deannotated + column = column._deannotate() + + if compile_state.compound_eager_adapter: + column = compile_state.compound_eager_adapter.columns[column] + + getter = result._getter(column) + return getter, self._label_name, self._extra_entities + + def setup_compile_state(self, compile_state): + current_adapter = compile_state._get_current_adapter() + if current_adapter: + column = current_adapter(self.column, False) + else: + column = self.column + ezero = self.entity_zero + + single_table_crit = self.mapper._single_table_criterion + if single_table_crit is not None: + compile_state.single_inh_entities[ezero] = ( + ezero, + ezero._adapter if ezero.is_aliased_class else None, + ) + + if column._annotations: + # annotated columns perform more slowly in compiler and + # result due to the __eq__() method, so use deannotated + column = column._deannotate() + + # use entity_zero as the from if we have it. this is necessary + # for polymorpic scenarios where our FROM is based on ORM entity, + # not the FROM of the column. but also, don't use it if our column + # doesn't actually have any FROMs that line up, such as when its + # a scalar subquery. + if set(self.column._from_objects).intersection( + ezero.selectable._from_objects + ): + compile_state._fallback_from_clauses.append(ezero.selectable) + + compile_state.primary_columns.append(column) + + compile_state.attributes[("fetch_column", self)] = column + + +sql.base.CompileState.plugin_for("orm", "select")( + QueryCompileState._create_for_select +) diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 7fff13101..6be4f0dff 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -12,7 +12,7 @@ as actively in the load/persist ORM loop. """ from . import attributes -from . import query +from . import util as orm_util from .interfaces import MapperProperty from .interfaces import PropComparator from .util import _none_set @@ -362,7 +362,7 @@ class CompositeProperty(DescriptorProperty): def _comparator_factory(self, mapper): return self.comparator_factory(self, mapper) - class CompositeBundle(query.Bundle): + class CompositeBundle(orm_util.Bundle): def __init__(self, property_, expr): self.property = property_ super(CompositeProperty.CompositeBundle, self).__init__( diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py index 2a3ef54dd..adc976e32 100644 --- a/lib/sqlalchemy/orm/dynamic.py +++ b/lib/sqlalchemy/orm/dynamic.py @@ -279,10 +279,21 @@ class AppenderMixin(object): # doesn't fail, and secondary is then in _from_obj[1]. self._from_obj = (prop.mapper.selectable, prop.secondary) - self._criterion = prop._with_parent(instance, alias_secondary=False) + self._where_criteria += ( + prop._with_parent(instance, alias_secondary=False), + ) if self.attr.order_by: - self._order_by = self.attr.order_by + + if ( + self._order_by_clauses is False + or self._order_by_clauses is None + ): + self._order_by_clauses = tuple(self.attr.order_by) + else: + self._order_by_clauses = self._order_by_clauses + tuple( + self.attr.order_by + ) def session(self): sess = object_session(self.instance) @@ -354,9 +365,9 @@ class AppenderMixin(object): else: query = sess.query(self.attr.target_mapper) - query._criterion = self._criterion + query._where_criteria = self._where_criteria query._from_obj = self._from_obj - query._order_by = self._order_by + query._order_by_clauses = self._order_by_clauses return query diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py index cc5e703dc..51bc8e426 100644 --- a/lib/sqlalchemy/orm/evaluator.py +++ b/lib/sqlalchemy/orm/evaluator.py @@ -9,6 +9,7 @@ import operator from .. import inspect from .. import util +from ..sql import and_ from ..sql import operators @@ -55,7 +56,12 @@ class EvaluatorCompiler(object): def __init__(self, target_cls=None): self.target_cls = target_cls - def process(self, clause): + def process(self, *clauses): + if len(clauses) > 1: + clause = and_(*clauses) + elif clauses: + clause = clauses[0] + meth = getattr(self, "visit_%s" % clause.__visit_name__, None) if not meth: raise UnevaluatableError( diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 612724357..313f2fda8 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -36,6 +36,7 @@ from .. import inspect from .. import inspection from .. import util from ..sql import operators +from ..sql import roles from ..sql import visitors from ..sql.traversals import HasCacheKey @@ -56,12 +57,25 @@ __all__ = ( "NOT_EXTENSION", "LoaderStrategy", "MapperOption", + "LoaderOption", "MapperProperty", "PropComparator", "StrategizedProperty", ) +class ORMColumnsClauseRole(roles.ColumnsClauseRole): + _role_name = "ORM mapped entity, aliased entity, or Column expression" + + +class ORMEntityColumnsClauseRole(ORMColumnsClauseRole): + _role_name = "ORM mapped or aliased entity" + + +class ORMFromClauseRole(roles.StrictFromClauseRole): + _role_name = "ORM mapped entity, aliased entity, or FROM expression" + + class MapperProperty( HasCacheKey, _MappedAttribute, InspectionAttr, util.MemoizedSlots ): @@ -620,6 +634,8 @@ class StrategizedProperty(MapperProperty): @classmethod def _strategy_lookup(cls, requesting_property, *key): + requesting_property.parent._with_polymorphic_mappers + for prop_cls in cls.__mro__: if prop_cls in cls._all_strategies: strategies = cls._all_strategies[prop_cls] @@ -646,8 +662,52 @@ class StrategizedProperty(MapperProperty): ) +class LoaderOption(HasCacheKey): + """Describe a modification to an ORM statement at compilation time. + + .. versionadded:: 1.4 + + """ + + __slots__ = () + + _is_legacy_option = False + + propagate_to_loaders = False + """if True, indicate this option should be carried along + to "secondary" Query objects produced during lazy loads + or refresh operations. + + """ + + def process_compile_state(self, compile_state): + """Apply a modification to a given :class:`.CompileState`.""" + + def _generate_path_cache_key(self, path): + """Used by the "baked lazy loader" to see if this option can be cached. + + .. deprecated:: 2.0 this method is to suit the baked extension which + is itself not part of 2.0. + + """ + return False + + +@util.deprecated_cls( + "1.4", + "The :class:`.MapperOption class is deprecated and will be removed " + "in a future release. ORM options now run within the compilation " + "phase and are based on the :class:`.LoaderOption` class which is " + "intended for internal consumption only. For " + "modifications to queries on a per-execution basis, the " + ":meth:`.before_execute` hook will now intercept ORM :class:`.Query` " + "objects before they are invoked", + constructor=None, +) class MapperOption(object): - """Describe a modification to a Query.""" + """Describe a modification to a Query""" + + _is_legacy_option = True propagate_to_loaders = False """if True, indicate this option should be carried along @@ -663,7 +723,7 @@ class MapperOption(object): """same as process_query(), except that this option may not apply to the given query. - This is typically used during a lazy load or scalar refresh + This is typically applied during a lazy load or scalar refresh operation to propagate options stated in the original Query to the new Query being used for the load. It occurs for those options that specify propagate_to_loaders=True. @@ -770,7 +830,7 @@ class LoaderStrategy(object): pass def setup_query( - self, context, query_entity, path, loadopt, adapter, **kwargs + self, compile_state, query_entity, path, loadopt, adapter, **kwargs ): """Establish column and other state for a given QueryContext. diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 0394d999c..48641685e 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -24,7 +24,6 @@ from .base import _DEFER_FOR_STATE from .base import _RAISE_FOR_STATE from .base import _SET_DEFERRED_EXPIRED from .util import _none_set -from .util import aliased from .util import state_str from .. import exc as sa_exc from .. import util @@ -43,21 +42,23 @@ def instances(query, cursor, context): context.runid = _new_runid() context.post_load_paths = {} + compile_state = context.compile_state + filtered = compile_state._has_mapper_entities single_entity = context.is_single_entity try: (process, labels, extra) = list( zip( *[ - query_entity.row_processor(query, context, cursor) - for query_entity in query._entities + query_entity.row_processor(context, cursor) + for query_entity in context.compile_state._entities ] ) ) - if query._yield_per and ( - context.loaders_require_buffering - or context.loaders_require_uniquing + if context.yield_per and ( + context.compile_state.loaders_require_buffering + or context.compile_state.loaders_require_uniquing ): raise sa_exc.InvalidRequestError( "Can't use yield_per with eager loaders that require uniquing " @@ -74,7 +75,8 @@ def instances(query, cursor, context): labels, extra, _unique_filters=[ - id if ent.use_id_for_hash else None for ent in query._entities + id if ent.use_id_for_hash else None + for ent in context.compile_state._entities ], ) @@ -86,6 +88,7 @@ def instances(query, cursor, context): if yield_per: fetch = cursor.fetchmany(yield_per) + if not fetch: break else: @@ -110,13 +113,13 @@ def instances(query, cursor, context): result = ChunkedIteratorResult( row_metadata, chunks, source_supports_scalars=single_entity ) - if query._yield_per: - result.yield_per(query._yield_per) + if context.yield_per: + result.yield_per(context.yield_per) if single_entity: result = result.scalars() - filtered = query._has_mapper_entities + filtered = context.compile_state._has_mapper_entities if filtered: result = result.unique() @@ -124,10 +127,10 @@ def instances(query, cursor, context): return result -@util.preload_module("sqlalchemy.orm.query") +@util.preload_module("sqlalchemy.orm.context") def merge_result(query, iterator, load=True): - """Merge a result into this :class:`_query.Query` object's Session.""" - querylib = util.preloaded.orm_query + """Merge a result into this :class:`.Query` object's Session.""" + querycontext = util.preloaded.orm_context session = query.session if load: @@ -142,12 +145,17 @@ def merge_result(query, iterator, load=True): else: frozen_result = None + ctx = querycontext.QueryCompileState._create_for_legacy_query( + query, entities_only=True + ) + autoflush = session.autoflush try: session.autoflush = False - single_entity = not frozen_result and len(query._entities) == 1 + single_entity = not frozen_result and len(ctx._entities) == 1 + if single_entity: - if isinstance(query._entities[0], querylib._MapperEntity): + if isinstance(ctx._entities[0], querycontext._MapperEntity): result = [ session._merge( attributes.instance_state(instance), @@ -163,14 +171,16 @@ def merge_result(query, iterator, load=True): else: mapped_entities = [ i - for i, e in enumerate(query._entities) - if isinstance(e, querylib._MapperEntity) + for i, e in enumerate(ctx._entities) + if isinstance(e, querycontext._MapperEntity) ] result = [] - keys = [ent._label_name for ent in query._entities] + keys = [ent._label_name for ent in ctx._entities] + keyed_tuple = result_tuple( - keys, [tuple(ent.entities) for ent in query._entities] + keys, [ent._extra_entities for ent in ctx._entities] ) + for row in iterator: newrow = list(row) for i in mapped_entities: @@ -270,7 +280,7 @@ def load_on_pk_identity( q = query._clone() if primary_key_identity is not None: - mapper = query._mapper_zero() + mapper = query._only_full_mapper_zero("load_on_pk_identity") (_get_clause, _get_params) = mapper._get_clause @@ -286,6 +296,7 @@ def load_on_pk_identity( if value is None ] ) + _get_clause = sql_util.adapt_criterion_to_null(_get_clause, nones) if len(nones) == len(primary_key_identity): @@ -294,8 +305,11 @@ def load_on_pk_identity( "object. This condition may raise an error in a future " "release." ) - _get_clause = q._adapt_clause(_get_clause, True, False) - q._criterion = _get_clause + + # TODO: can mapper._get_clause be pre-adapted? + q._where_criteria = ( + sql_util._deep_annotate(_get_clause, {"_orm_adapt": True}), + ) params = dict( [ @@ -306,7 +320,7 @@ def load_on_pk_identity( ] ) - q._params = params + q.load_options += {"_params": params} # with_for_update needs to be query.LockmodeArg() if with_for_update is not None: @@ -319,8 +333,9 @@ def load_on_pk_identity( version_check = False if refresh_state and refresh_state.load_options: + # if refresh_state.load_path.parent: q = q._with_current_path(refresh_state.load_path.parent) - q = q._conditional_options(refresh_state.load_options) + q = q.options(refresh_state.load_options) q._get_options( populate_existing=bool(refresh_state), @@ -338,7 +353,7 @@ def load_on_pk_identity( def _setup_entity_query( - context, + compile_state, mapper, query_entity, path, @@ -359,19 +374,27 @@ def _setup_entity_query( quick_populators = {} - path.set(context.attributes, "memoized_setups", quick_populators) + path.set(compile_state.attributes, "memoized_setups", quick_populators) + + # for the lead entities in the path, e.g. not eager loads, and + # assuming a user-passed aliased class, e.g. not a from_self() or any + # implicit aliasing, don't add columns to the SELECT that aren't + # in the thing that's aliased. + check_for_adapt = adapter and len(path) == 1 and path[-1].is_aliased_class for value in poly_properties: if only_load_props and value.key not in only_load_props: continue + value.setup( - context, + compile_state, query_entity, path, adapter, only_load_props=only_load_props, column_collection=column_collection, memoized_populators=quick_populators, + check_for_adapt=check_for_adapt, **kw ) @@ -448,21 +471,6 @@ def _instance_processor( populators["new"].append((prop.key, prop._raise_column_loader)) else: getter = None - # the "adapter" can be here via different paths, - # e.g. via adapter present at setup_query or adapter - # applied to the query afterwards via eager load subquery. - # If the column here - # were already a product of this adapter, sending it through - # the adapter again can return a totally new expression that - # won't be recognized in the result, and the ColumnAdapter - # currently does not accommodate for this. OTOH, if the - # column were never applied through this adapter, we may get - # None back, in which case we still won't get our "getter". - # so try both against result._getter(). See issue #4048 - if adapter: - adapted_col = adapter.columns[col] - if adapted_col is not None: - getter = result._getter(adapted_col, False) if not getter: getter = result._getter(col, False) if getter: @@ -481,8 +489,8 @@ def _instance_processor( propagate_options = context.propagate_options load_path = ( - context.query._current_path + path - if context.query._current_path.path + context.compile_state.current_path + path + if context.compile_state.current_path.path else path ) @@ -764,7 +772,7 @@ def _load_subclass_via_in(context, path, entity): cache_path=path, ) - if orig_query._populate_existing: + if context.populate_existing: q2.add_criteria(lambda q: q.populate_existing()) q2(context.session).params( @@ -1065,10 +1073,16 @@ def load_scalar_attributes(mapper, state, attribute_names, passive): # by default statement = mapper._optimized_get_statement(state, attribute_names) if statement is not None: - wp = aliased(mapper, statement) + # this was previously aliased(mapper, statement), however, + # statement is a select() and Query's coercion now raises for this + # since you can't "select" from a "SELECT" statement. only + # from_statement() allows this. + # note: using from_statement() here means there is an adaption + # with adapt_on_names set up. the other option is to make the + # aliased() against a subquery which affects the SQL. result = load_on_ident( - session.query(wp) - .options(strategy_options.Load(wp).undefer("*")) + session.query(mapper) + .options(strategy_options.Load(mapper).undefer("*")) .from_statement(statement), None, only_load_props=attribute_names, diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index c05705b67..a6fb1039f 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -37,6 +37,8 @@ from .interfaces import _MappedAttribute from .interfaces import EXT_SKIP from .interfaces import InspectionAttr from .interfaces import MapperProperty +from .interfaces import ORMEntityColumnsClauseRole +from .interfaces import ORMFromClauseRole from .path_registry import PathRegistry from .. import event from .. import exc as sa_exc @@ -70,7 +72,12 @@ _CONFIGURE_MUTEX = util.threading.RLock() @inspection._self_inspects @log.class_logger -class Mapper(sql_base.HasCacheKey, InspectionAttr): +class Mapper( + ORMFromClauseRole, + ORMEntityColumnsClauseRole, + sql_base.MemoizedHasCacheKey, + InspectionAttr, +): """Define the correlation of class attributes to database table columns. @@ -2085,6 +2092,20 @@ class Mapper(sql_base.HasCacheKey, InspectionAttr): return self._mappers_from_spec(*self.with_polymorphic) @HasMemoized.memoized_attribute + def _post_inspect(self): + """This hook is invoked by attribute inspection. + + E.g. when Query calls: + + coercions.expect(roles.ColumnsClauseRole, ent, keep_inspect=True) + + This allows the inspection process run a configure mappers hook. + + """ + if Mapper._new_mappers: + configure_mappers() + + @HasMemoized.memoized_attribute def _with_polymorphic_selectable(self): if not self.with_polymorphic: return self.persist_selectable @@ -2207,12 +2228,16 @@ class Mapper(sql_base.HasCacheKey, InspectionAttr): for table, columns in self._cols_by_table.items() ) - # temporarily commented out until we fix an issue in the serializer - # @_memoized_configured_property.method + @HasMemoized.memoized_instancemethod def __clause_element__(self): - return self.selectable # ._annotate( - # {"parententity": self, "parentmapper": self} - # ) + return self.selectable._annotate( + { + "entity_namespace": self, + "parententity": self, + "parentmapper": self, + "compile_state_plugin": "orm", + } + ) @property def selectable(self): @@ -2386,6 +2411,10 @@ class Mapper(sql_base.HasCacheKey, InspectionAttr): return self._filter_properties(descriptor_props.SynonymProperty) + @property + def entity_namespace(self): + return self.class_ + @HasMemoized.memoized_attribute def column_attrs(self): """Return a namespace of all :class:`.ColumnProperty` @@ -2961,18 +2990,24 @@ class Mapper(sql_base.HasCacheKey, InspectionAttr): (prop.key,), {"do_nothing": True} ) - if len(self.primary_key) > 1: - in_expr = sql.tuple_(*self.primary_key) + primary_key = [ + sql_util._deep_annotate(pk, {"_orm_adapt": True}) + for pk in self.primary_key + ] + + if len(primary_key) > 1: + in_expr = sql.tuple_(*primary_key) else: - in_expr = self.primary_key[0] + in_expr = primary_key[0] if entity.is_aliased_class: assert entity.mapper is self + q = baked.BakedQuery( self._compiled_cache, - lambda session: session.query(entity) - .select_entity_from(entity.selectable) - ._adapt_all_clauses(), + lambda session: session.query(entity).select_entity_from( + entity.selectable + ), (self,), ) q.spoil() @@ -2985,7 +3020,7 @@ class Mapper(sql_base.HasCacheKey, InspectionAttr): q += lambda q: q.filter( in_expr.in_(sql.bindparam("primary_keys", expanding=True)) - ).order_by(*self.primary_key) + ).order_by(*primary_key) return q, enable_opt, disable_opt diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 87bc8ea1d..d14f6c27b 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -23,7 +23,6 @@ from . import evaluator from . import exc as orm_exc from . import loading from . import sync -from .base import _entity_descriptor from .base import state_str from .. import exc as sa_exc from .. import sql @@ -1653,15 +1652,14 @@ class BulkUD(object): def __init__(self, query): self.query = query.enable_eagerloads(False) - self.mapper = self.query._bind_mapper() self._validate_query_state() def _validate_query_state(self): for attr, methname, notset, op in ( - ("_limit", "limit()", None, operator.is_), - ("_offset", "offset()", None, operator.is_), - ("_order_by", "order_by()", False, operator.is_), - ("_group_by", "group_by()", False, operator.is_), + ("_limit_clause", "limit()", None, operator.is_), + ("_offset_clause", "offset()", None, operator.is_), + ("_order_by_clauses", "order_by()", (), operator.eq), + ("_group_by_clauses", "group_by()", (), operator.eq), ("_distinct", "distinct()", False, operator.is_), ( "_from_obj", @@ -1669,6 +1667,12 @@ class BulkUD(object): (), operator.eq, ), + ( + "_legacy_setup_joins", + "join(), outerjoin(), select_from(), or from_self()", + (), + operator.eq, + ), ): if not op(getattr(self.query, attr), notset): raise sa_exc.InvalidRequestError( @@ -1710,18 +1714,24 @@ class BulkUD(object): def _do_before_compile(self): raise NotImplementedError() - @util.preload_module("sqlalchemy.orm.query") + @util.preload_module("sqlalchemy.orm.context") def _do_pre(self): - querylib = util.preloaded.orm_query + query_context = util.preloaded.orm_context query = self.query - self.context = querylib.QueryContext(query) + self.compile_state = ( + self.context + ) = compile_state = query._compile_state() + + self.mapper = compile_state._bind_mapper() - if isinstance(query._entities[0], querylib._ColumnEntity): + if isinstance( + compile_state._entities[0], query_context._RawColumnEntity, + ): # check for special case of query(table) tables = set() - for ent in query._entities: - if not isinstance(ent, querylib._ColumnEntity): + for ent in compile_state._entities: + if not isinstance(ent, query_context._RawColumnEntity,): tables.clear() break else: @@ -1736,14 +1746,14 @@ class BulkUD(object): self.primary_table = tables.pop() else: - self.primary_table = query._only_entity_zero( + self.primary_table = compile_state._only_entity_zero( "This operation requires only one Table or " "entity be specified as the target." ).mapper.local_table session = query.session - if query._autoflush: + if query.load_options._autoflush: session._autoflush() def _do_pre_synchronize(self): @@ -1761,12 +1771,14 @@ class BulkEvaluate(BulkUD): def _do_pre_synchronize(self): query = self.query - target_cls = query._mapper_zero().class_ + target_cls = self.compile_state._mapper_zero().class_ try: evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) - if query.whereclause is not None: - eval_condition = evaluator_compiler.process(query.whereclause) + if query._where_criteria: + eval_condition = evaluator_compiler.process( + *query._where_criteria + ) else: def eval_condition(obj): @@ -1802,12 +1814,11 @@ class BulkFetch(BulkUD): def _do_pre_synchronize(self): query = self.query session = query.session - context = query._compile_context() - select_stmt = context.statement.with_only_columns( + select_stmt = self.compile_state.statement.with_only_columns( self.primary_table.primary_key ) self.matched_rows = session.execute( - select_stmt, mapper=self.mapper, params=query._params + select_stmt, mapper=self.mapper, params=query.load_options._params ).fetchall() @@ -1850,7 +1861,7 @@ class BulkUpdate(BulkUD): ): if self.mapper: if isinstance(k, util.string_types): - desc = _entity_descriptor(self.mapper, k) + desc = sql.util._entity_namespace_key(self.mapper, k) values.extend(desc._bulk_update_tuples(v)) elif isinstance(k, attributes.QueryableAttribute): values.extend(k._bulk_update_tuples(v)) @@ -1890,11 +1901,10 @@ class BulkUpdate(BulkUD): values = dict(values) update_stmt = sql.update( - self.primary_table, - self.context.whereclause, - values, - **self.update_kwargs - ) + self.primary_table, **self.update_kwargs + ).values(values) + + update_stmt._where_criteria = self.compile_state._where_criteria self._execute_stmt(update_stmt) @@ -1929,7 +1939,8 @@ class BulkDelete(BulkUD): self.query = new_query def _do_exec(self): - delete_stmt = sql.delete(self.primary_table, self.context.whereclause) + delete_stmt = sql.delete(self.primary_table,) + delete_stmt._where_criteria = self.compile_state._where_criteria self._execute_stmt(delete_stmt) @@ -1994,7 +2005,7 @@ class BulkUpdateFetch(BulkFetch, BulkUpdate): def _do_post_synchronize(self): session = self.query.session - target_mapper = self.query._mapper_zero() + target_mapper = self.compile_state._mapper_zero() states = set( [ @@ -2024,7 +2035,7 @@ class BulkDeleteFetch(BulkFetch, BulkDelete): def _do_post_synchronize(self): session = self.query.session - target_mapper = self.query._mapper_zero() + target_mapper = self.compile_state._mapper_zero() for primary_key in self.matched_rows: # TODO: inline this and call remove_newly_deleted # once diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 4cf316ac7..027786c19 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -343,13 +343,16 @@ class ColumnProperty(StrategizedProperty): if self.adapter: return self.adapter(self.prop.columns[0], self.prop.key) else: + pe = self._parententity # no adapter, so we aren't aliased # assert self._parententity is self._parentmapper return self.prop.columns[0]._annotate( { - "parententity": self._parententity, - "parentmapper": self._parententity, + "entity_namespace": pe, + "parententity": pe, + "parentmapper": pe, "orm_key": self.prop.key, + "compile_state_plugin": "orm", } ) @@ -383,6 +386,7 @@ class ColumnProperty(StrategizedProperty): "parententity": self._parententity, "parentmapper": self._parententity, "orm_key": self.prop.key, + "compile_state_plugin": "orm", } ) for col in self.prop.columns diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 7e65f26e2..8a861c3dc 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -19,55 +19,50 @@ database to return iterable result sets. """ -from itertools import chain - from . import attributes from . import exc as orm_exc from . import interfaces from . import loading from . import persistence from .base import _assertions -from .base import _entity_descriptor -from .base import _is_aliased_class -from .base import _is_mapped_class -from .base import _orm_columns -from .base import InspectionAttr -from .path_registry import PathRegistry -from .util import _entity_corresponds_to +from .context import _column_descriptions +from .context import _legacy_determine_last_joined_entity +from .context import _legacy_filter_by_entity_zero +from .context import QueryCompileState +from .context import QueryContext +from .interfaces import ORMColumnsClauseRole from .util import aliased from .util import AliasedClass -from .util import join as orm_join from .util import object_mapper -from .util import ORMAdapter from .util import with_parent +from .util import with_polymorphic from .. import exc as sa_exc from .. import inspect from .. import inspection from .. import log from .. import sql from .. import util -from ..engine import result_tuple from ..sql import coercions from ..sql import expression from ..sql import roles from ..sql import util as sql_util -from ..sql import visitors from ..sql.base import _generative -from ..sql.base import ColumnCollection -from ..sql.base import Generative +from ..sql.base import Executable from ..sql.selectable import ForUpdateArg +from ..sql.selectable import HasHints +from ..sql.selectable import HasPrefixes +from ..sql.selectable import HasSuffixes +from ..sql.selectable import LABEL_STYLE_NONE from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from ..sql.util import _entity_namespace_key from ..util import collections_abc __all__ = ["Query", "QueryContext", "aliased"] -_path_registry = PathRegistry.root - - @inspection._self_inspects @log.class_logger -class Query(Generative): +class Query(HasPrefixes, HasSuffixes, HasHints, Executable): """ORM-level SQL construction object. :class:`_query.Query` @@ -90,85 +85,35 @@ class Query(Generative): """ - _only_return_tuples = False - _enable_eagerloads = True - _enable_assertions = True - _with_labels = False - _criterion = None - _yield_per = None - _order_by = False - _group_by = False - _having = None + # elements that are in Core and can be cached in the same way + _where_criteria = () + _having_criteria = () + + _order_by_clauses = () + _group_by_clauses = () + _limit_clause = None + _offset_clause = None + _distinct = False - _prefixes = None - _suffixes = None - _offset = None - _limit = None + _distinct_on = () + _for_update_arg = None - _statement = None - _correlate = frozenset() - _populate_existing = False - _invoke_all_eagers = True - _version_check = False - _autoflush = True - _only_load_props = None - _refresh_state = None - _refresh_identity_token = None + _correlate = () + _auto_correlate = True _from_obj = () - _join_entities = () - _select_from_entity = None - _filter_aliases = () - _from_obj_alias = None - _joinpath = _joinpoint = util.immutabledict() - _execution_options = util.immutabledict() - _params = util.immutabledict() - _attributes = util.immutabledict() - _with_options = () - _with_hints = () - _enable_single_crit = True - _orm_only_adapt = True - _orm_only_from_obj_alias = True - _current_path = _path_registry - _has_mapper_entities = False - _bake_ok = True - - lazy_loaded_from = None - """An :class:`.InstanceState` that is using this :class:`_query.Query` - for a - lazy load operation. - - The primary rationale for this attribute is to support the horizontal - sharding extension, where it is available within specific query - execution time hooks created by this extension. To that end, the - attribute is only intended to be meaningful at **query execution time**, - and importantly not any time prior to that, including query compilation - time. - - .. note:: - - Within the realm of regular :class:`_query.Query` usage, - this attribute is - set by the lazy loader strategy before the query is invoked. However - there is no established hook that is available to reliably intercept - this value programmatically. It is set by the lazy loading strategy - after any mapper option objects would have been applied, and now that - the lazy loading strategy in the ORM makes use of "baked" queries to - cache SQL compilation, the :meth:`.QueryEvents.before_compile` hook is - also not reliable. - - Currently, setting the :paramref:`_orm.relationship.bake_queries` to - ``False`` on the target :func:`_orm.relationship`, - and then making use of - the :meth:`.QueryEvents.before_compile` event hook, is the only - available programmatic path to intercepting this attribute. In future - releases, there will be new hooks available that allow interception of - the :class:`_query.Query` before it is executed, - rather than before it is - compiled. - - .. versionadded:: 1.2.9 + _setup_joins = () + _legacy_setup_joins = () + _label_style = LABEL_STYLE_NONE - """ + compile_options = QueryCompileState.default_compile_options + + load_options = QueryContext.default_load_options + + # local Query builder state, not needed for + # compilation or execution + _aliased_generation = None + _enable_assertions = True + _last_joined_entity = None def __init__(self, entities, session=None): """Construct a :class:`_query.Query` directly. @@ -197,243 +142,58 @@ class Query(Generative): :meth:`_query.Query.with_session` """ + self.session = session - self._polymorphic_adapters = {} self._set_entities(entities) - def _set_entities(self, entities, entity_wrapper=None): - if entity_wrapper is None: - entity_wrapper = _QueryEntity - self._entities = [] - self._primary_entity = None - self._has_mapper_entities = False - - if entities != (): - for ent in util.to_list(entities): - entity_wrapper(self, ent) - - def _setup_query_adapters(self, entity, ext_info): - if not ext_info.is_aliased_class and ext_info.mapper.with_polymorphic: - if ( - ext_info.mapper.persist_selectable - not in self._polymorphic_adapters - ): - self._mapper_loads_polymorphically_with( - ext_info.mapper, - sql_util.ColumnAdapter( - ext_info.selectable, - ext_info.mapper._equivalent_columns, - ), - ) - - def _mapper_loads_polymorphically_with(self, mapper, adapter): - for m2 in mapper._with_polymorphic_mappers or [mapper]: - self._polymorphic_adapters[m2] = adapter - for m in m2.iterate_to_root(): - self._polymorphic_adapters[m.local_table] = adapter - - def _set_select_from(self, obj, set_base_alias): - fa = [] - select_from_alias = None - - for from_obj in obj: - info = inspect(from_obj) - if hasattr(info, "mapper") and ( - info.is_mapper or info.is_aliased_class - ): - self._select_from_entity = info - if set_base_alias and not info.is_aliased_class: - raise sa_exc.ArgumentError( - "A selectable (FromClause) instance is " - "expected when the base alias is being set." - ) - fa.append(info.selectable) - else: - from_obj = coercions.expect( - roles.StrictFromClauseRole, from_obj, allow_select=True - ) - if set_base_alias: - select_from_alias = from_obj - fa.append(from_obj) - - self._from_obj = tuple(fa) - - if ( - set_base_alias - and len(self._from_obj) == 1 - and isinstance( - select_from_alias, sql.selectable.AliasedReturnsRows - ) - ): - equivs = self.__all_equivs() - self._from_obj_alias = sql_util.ColumnAdapter( - self._from_obj[0], equivs - ) - self._enable_single_crit = False - elif ( - set_base_alias - and len(self._from_obj) == 1 - and hasattr(info, "mapper") - and info.is_aliased_class - ): - self._from_obj_alias = info._adapter - self._enable_single_crit = False - - def _reset_polymorphic_adapter(self, mapper): - for m2 in mapper._with_polymorphic_mappers: - self._polymorphic_adapters.pop(m2, None) - for m in m2.iterate_to_root(): - self._polymorphic_adapters.pop(m.local_table, None) - - def _adapt_polymorphic_element(self, element): - if "parententity" in element._annotations: - search = element._annotations["parententity"] - alias = self._polymorphic_adapters.get(search, None) - if alias: - return alias.adapt_clause(element) - - if isinstance(element, expression.FromClause): - search = element - elif hasattr(element, "table"): - search = element.table - else: - return None - - alias = self._polymorphic_adapters.get(search, None) - if alias: - return alias.adapt_clause(element) - - def _adapt_col_list(self, cols): - return [ - self._adapt_clause(coercions.expect(roles.ByOfRole, o), True, True) - for o in cols + def _set_entities(self, entities): + self._raw_columns = [ + coercions.expect(roles.ColumnsClauseRole, ent) + for ent in util.to_list(entities) ] - @_generative - def _set_lazyload_from(self, state): - self.lazy_loaded_from = state - - @_generative - def _adapt_all_clauses(self): - self._orm_only_adapt = False - - def _adapt_clause(self, clause, as_filter, orm_only): - """Adapt incoming clauses to transformations which - have been applied within this query.""" - - adapters = [] - # do we adapt all expression elements or only those - # tagged as 'ORM' constructs ? - if not self._orm_only_adapt: - orm_only = False - - if as_filter and self._filter_aliases: - for fa in self._filter_aliases: - adapters.append((orm_only, fa.replace)) - - if self._from_obj_alias: - # for the "from obj" alias, apply extra rule to the - # 'ORM only' check, if this query were generated from a - # subquery of itself, i.e. _from_selectable(), apply adaption - # to all SQL constructs. - adapters.append( - ( - orm_only if self._orm_only_from_obj_alias else False, - self._from_obj_alias.replace, - ) - ) - - if self._polymorphic_adapters: - adapters.append((orm_only, self._adapt_polymorphic_element)) - - if not adapters: - return clause - - def replace(elem): - is_orm_adapt = ( - "_orm_adapt" in elem._annotations - or "parententity" in elem._annotations - ) - for _orm_only, adapter in adapters: - if not _orm_only or is_orm_adapt: - e = adapter(elem) - if e is not None: - return e - - return visitors.replacement_traverse(clause, {}, replace) - - def _query_entity_zero(self): - """Return the first QueryEntity.""" - return self._entities[0] - - def _mapper_zero(self): - """return the Mapper associated with the first QueryEntity.""" - return self._entities[0].mapper - - def _entity_zero(self): - """Return the 'entity' (mapper or AliasedClass) associated - with the first QueryEntity, or alternatively the 'select from' - entity if specified.""" - - return ( - self._select_from_entity - if self._select_from_entity is not None - else self._query_entity_zero().entity_zero - ) - - def _deep_entity_zero(self): - """Return a 'deep' entity; this is any entity we can find associated - with the first entity / column experssion. this is used only for - session.get_bind(). - - """ - - if ( - self._select_from_entity is not None - and not self._select_from_entity.is_clause_element - ): - return self._select_from_entity.mapper - for ent in self._entities: - ezero = ent._deep_entity_zero() - if ezero is not None: - return ezero.mapper - else: + def _entity_from_pre_ent_zero(self): + if not self._raw_columns: return None - @property - def _mapper_entities(self): - for ent in self._entities: - if isinstance(ent, _MapperEntity): - yield ent - - def _joinpoint_zero(self): - return self._joinpoint.get("_joinpoint_entity", self._entity_zero()) + ent = self._raw_columns[0] - def _bind_mapper(self): - return self._deep_entity_zero() + if "parententity" in ent._annotations: + return ent._annotations["parententity"] + elif isinstance(ent, ORMColumnsClauseRole): + return ent.entity + elif "bundle" in ent._annotations: + return ent._annotations["bundle"] + else: + return ent def _only_full_mapper_zero(self, methname): - if self._entities != [self._primary_entity]: + if ( + len(self._raw_columns) != 1 + or "parententity" not in self._raw_columns[0]._annotations + or not self._raw_columns[0].is_selectable + ): raise sa_exc.InvalidRequestError( "%s() can only be used against " "a single mapped class." % methname ) - return self._primary_entity.entity_zero - def _only_entity_zero(self, rationale=None): - if len(self._entities) > 1: - raise sa_exc.InvalidRequestError( - rationale - or "This operation requires a Query " - "against a single mapper." + return self._raw_columns[0]._annotations["parententity"] + + def _set_select_from(self, obj, set_base_alias): + fa = [ + coercions.expect( + roles.StrictFromClauseRole, elem, allow_select=True ) - return self._entity_zero() + for elem in obj + ] - def __all_equivs(self): - equivs = {} - for ent in self._mapper_entities: - equivs.update(ent.mapper._equivalent_columns) - return equivs + self.compile_options += {"_set_base_alias": set_base_alias} + self._from_obj = tuple(fa) + + @_generative + def _set_lazyload_from(self, state): + self.load_options += {"_lazy_loaded_from": state} def _get_condition(self): return self._no_criterion_condition( @@ -447,13 +207,14 @@ class Query(Generative): if not self._enable_assertions: return if ( - self._criterion is not None - or self._statement is not None + self._where_criteria + or self.compile_options._statement is not None or self._from_obj - or self._limit is not None - or self._offset is not None - or self._group_by - or (order_by and self._order_by) + or self._legacy_setup_joins + or self._limit_clause is not None + or self._offset_clause is not None + or self._group_by_clauses + or (order_by and self._order_by_clauses) or (distinct and self._distinct) ): raise sa_exc.InvalidRequestError( @@ -464,14 +225,18 @@ class Query(Generative): def _no_criterion_condition(self, meth, order_by=True, distinct=True): self._no_criterion_assertion(meth, order_by, distinct) - self._from_obj = () - self._statement = self._criterion = None - self._order_by = self._group_by = self._distinct = False + self._from_obj = self._legacy_setup_joins = () + if self.compile_options._statement is not None: + self.compile_options += {"_statement": None} + self._where_criteria = () + self._distinct = False + + self._order_by_clauses = self._group_by_clauses = () def _no_clauseelement_condition(self, meth): if not self._enable_assertions: return - if self._order_by: + if self._order_by_clauses: raise sa_exc.InvalidRequestError( "Query.%s() being called on a " "Query with existing criterion. " % meth @@ -481,7 +246,7 @@ class Query(Generative): def _no_statement_condition(self, meth): if not self._enable_assertions: return - if self._statement is not None: + if self.compile_options._statement is not None: raise sa_exc.InvalidRequestError( ( "Query.%s() being called on a Query with an existing full " @@ -493,7 +258,7 @@ class Query(Generative): def _no_limit_offset(self, meth): if not self._enable_assertions: return - if self._limit is not None or self._offset is not None: + if self._limit_clause is not None or self._offset_clause is not None: raise sa_exc.InvalidRequestError( "Query.%s() being called on a Query which already has LIMIT " "or OFFSET applied. To modify the row-limited results of a " @@ -510,16 +275,26 @@ class Query(Generative): refresh_state=None, identity_token=None, ): - if populate_existing: - self._populate_existing = populate_existing + load_options = {} + compile_options = {} + if version_check: - self._version_check = version_check + load_options["_version_check"] = version_check + if populate_existing: + load_options["_populate_existing"] = populate_existing if refresh_state: - self._refresh_state = refresh_state + load_options["_refresh_state"] = refresh_state + compile_options["_for_refresh_state"] = True if only_load_props: - self._only_load_props = set(only_load_props) + compile_options["_only_load_props"] = frozenset(only_load_props) if identity_token: - self._refresh_identity_token = identity_token + load_options["_refresh_identity_token"] = identity_token + + if load_options: + self.load_options += load_options + if compile_options: + self.compile_options += compile_options + return self def _clone(self): @@ -535,12 +310,48 @@ class Query(Generative): """ - stmt = self._compile_context(for_statement=True).statement - if self._params: - stmt = stmt.params(self._params) + # .statement can return the direct future.Select() construct here, as + # long as we are not using subsequent adaption features that + # are made against raw entities, e.g. from_self(), with_polymorphic(), + # select_entity_from(). If these features are being used, then + # the Select() we return will not have the correct .selected_columns + # collection and will not embed in subsequent queries correctly. + # We could find a way to make this collection "correct", however + # this would not be too different from doing the full compile as + # we are doing in any case, the Select() would still not have the + # proper state for other attributes like whereclause, order_by, + # and these features are all deprecated in any case. + # + # for these reasons, Query is not a Select, it remains an ORM + # object for which __clause_element__() must be called in order for + # it to provide a real expression object. + # + # from there, it starts to look much like Query itself won't be + # passed into the execute process and wont generate its own cache + # key; this will all occur in terms of the ORM-enabled Select. + if ( + not self.compile_options._set_base_alias + and not self.compile_options._with_polymorphic_adapt_map + and self.compile_options._statement is None + ): + # if we don't have legacy top level aliasing features in use + # then convert to a future select() directly + stmt = self._statement_20() + else: + stmt = QueryCompileState._create_for_legacy_query( + self, for_statement=True + ).statement + + if self.load_options._params: + # this is the search and replace thing. this is kind of nuts + # to be doing here. + stmt = stmt.params(self.load_options._params) return stmt + def _statement_20(self): + return QueryCompileState._create_future_select_from_query(self) + def subquery(self, name=None, with_labels=False, reduce_columns=False): """return the full SELECT statement represented by this :class:`_query.Query`, embedded within an @@ -686,7 +497,7 @@ class Query(Generative): :meth:`_query.Query.is_single_entity` """ - self._only_return_tuples = value + self.load_options += dict(_only_return_tuples=value) @property def is_single_entity(self): @@ -705,9 +516,13 @@ class Query(Generative): """ return ( - not self._only_return_tuples - and len(self._entities) == 1 - and self._entities[0].supports_single_entity + not self.load_options._only_return_tuples + and len(self._raw_columns) == 1 + and "parententity" in self._raw_columns[0]._annotations + and isinstance( + self._raw_columns[0]._annotations["parententity"], + ORMColumnsClauseRole, + ) ) @_generative @@ -726,7 +541,7 @@ class Query(Generative): selectable, or when using :meth:`_query.Query.yield_per`. """ - self._enable_eagerloads = value + self.compile_options += {"_enable_eagerloads": value} @_generative def with_labels(self): @@ -754,7 +569,13 @@ class Query(Generative): """ - self._with_labels = True + self._label_style = LABEL_STYLE_TABLENAME_PLUS_COL + + apply_labels = with_labels + + @property + def use_labels(self): + return self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL @_generative def enable_assertions(self, value): @@ -787,7 +608,9 @@ class Query(Generative): criterion has been established. """ - return self._criterion + return sql.elements.BooleanClauseList._construct_for_whereclause( + self._where_criteria + ) @_generative def _with_current_path(self, path): @@ -799,8 +622,9 @@ class Query(Generative): query intended for the deferred load. """ - self._current_path = path + self.compile_options += {"_current_path": path} + # TODO: removed in 2.0 @_generative @_assertions(_no_clauseelement_condition) def with_polymorphic( @@ -823,24 +647,19 @@ class Query(Generative): """ - if not self._primary_entity: - raise sa_exc.InvalidRequestError( - "No primary mapper set up for this Query." - ) - entity = self._entities[0]._clone() - self._entities = [entity] + self._entities[1:] - - # NOTE: we likely should set primary_entity here, however - # this hasn't been changed for many years and we'd like to - # deprecate this method. + entity = _legacy_filter_by_entity_zero(self) - entity.set_with_polymorphic( - self, + wp = with_polymorphic( + entity, cls_or_mappers, selectable=selectable, polymorphic_on=polymorphic_on, ) + self.compile_options = self.compile_options.add_to_element( + "_with_polymorphic_adapt_map", ((entity, inspect(wp)),) + ) + @_generative def yield_per(self, count): r"""Yield only ``count`` rows at a time. @@ -906,7 +725,7 @@ class Query(Generative): :meth:`_query.Query.enable_eagerloads` """ - self._yield_per = count + self.load_options += {"_yield_per": count} self._execution_options = self._execution_options.union( {"stream_results": True, "max_row_buffer": count} ) @@ -1041,7 +860,7 @@ class Query(Generative): ) if ( - not self._populate_existing + not self.load_options._populate_existing and not mapper.always_refresh and self._for_update_arg is None ): @@ -1062,12 +881,51 @@ class Query(Generative): return db_load_fn(self, primary_key_identity) + @property + def lazy_loaded_from(self): + """An :class:`.InstanceState` that is using this :class:`_query.Query` + for a lazy load operation. + + The primary rationale for this attribute is to support the horizontal + sharding extension, where it is available within specific query + execution time hooks created by this extension. To that end, the + attribute is only intended to be meaningful at **query execution + time**, and importantly not any time prior to that, including query + compilation time. + + .. note:: + + Within the realm of regular :class:`_query.Query` usage, this + attribute is set by the lazy loader strategy before the query is + invoked. However there is no established hook that is available to + reliably intercept this value programmatically. It is set by the + lazy loading strategy after any mapper option objects would have + been applied, and now that the lazy loading strategy in the ORM + makes use of "baked" queries to cache SQL compilation, the + :meth:`.QueryEvents.before_compile` hook is also not reliable. + + Currently, setting the :paramref:`_orm.relationship.bake_queries` + to ``False`` on the target :func:`_orm.relationship`, and then + making use of the :meth:`.QueryEvents.before_compile` event hook, + is the only available programmatic path to intercepting this + attribute. In future releases, there will be new hooks available + that allow interception of the :class:`_query.Query` before it is + executed, rather than before it is compiled. + + .. versionadded:: 1.2.9 + + """ + return self.load_options._lazy_loaded_from + + @property + def _current_path(self): + return self.compile_options._current_path + @_generative - def correlate(self, *args): - """Return a :class:`_query.Query` - construct which will correlate the given - FROM clauses to that of an enclosing :class:`_query.Query` or - :func:`_expression.select`. + def correlate(self, *fromclauses): + """Return a :class:`.Query` construct which will correlate the given + FROM clauses to that of an enclosing :class:`.Query` or + :func:`~.expression.select`. The method here accepts mapped classes, :func:`.aliased` constructs, and :func:`.mapper` constructs as arguments, which are resolved into @@ -1085,15 +943,13 @@ class Query(Generative): """ - for s in args: - if s is None: - self._correlate = self._correlate.union([None]) - else: - self._correlate = self._correlate.union( - sql_util.surface_selectables( - coercions.expect(roles.FromClauseRole, s) - ) - ) + self._auto_correlate = False + if fromclauses and fromclauses[0] is None: + self._correlate = () + else: + self._correlate = set(self._correlate).union( + coercions.expect(roles.FromClauseRole, f) for f in fromclauses + ) @_generative def autoflush(self, setting): @@ -1105,7 +961,7 @@ class Query(Generative): to disable autoflush for a specific Query. """ - self._autoflush = setting + self.load_options += {"_autoflush": setting} @_generative def populate_existing(self): @@ -1120,7 +976,7 @@ class Query(Generative): This method is not intended for general use. """ - self._populate_existing = True + self.load_options += {"_populate_existing": True} @_generative def _with_invoke_all_eagers(self, value): @@ -1131,8 +987,9 @@ class Query(Generative): Default is that of :attr:`_query.Query._invoke_all_eagers`. """ - self._invoke_all_eagers = value + self.load_options += {"_invoke_all_eagers": value} + # TODO: removed in 2.0, use with_parent standalone in filter @util.preload_module("sqlalchemy.orm.relationships") def with_parent(self, instance, property=None, from_entity=None): # noqa """Add filtering criterion that relates the given instance @@ -1166,9 +1023,9 @@ class Query(Generative): if from_entity: entity_zero = inspect(from_entity) else: - entity_zero = self._entity_zero() + entity_zero = _legacy_filter_by_entity_zero(self) if property is None: - + # TODO: deprecate, property has to be supplied mapper = object_mapper(instance) for prop in mapper.iterate_properties: @@ -1196,10 +1053,14 @@ class Query(Generative): to be returned.""" if alias is not None: + # TODO: deprecate entity = aliased(entity, alias) - self._entities = list(self._entities) - _MapperEntity(self, entity) + self._raw_columns = list(self._raw_columns) + + self._raw_columns.append( + coercions.expect(roles.ColumnsClauseRole, entity) + ) @_generative def with_session(self, session): @@ -1409,6 +1270,7 @@ class Query(Generative): those being selected. """ + fromclause = ( self.with_labels() .enable_eagerloads(False) @@ -1416,62 +1278,85 @@ class Query(Generative): .subquery() ._anonymous_fromclause() ) - q = self._from_selectable(fromclause) - q._select_from_entity = self._entity_zero() + + parententity = self._raw_columns[0]._annotations.get("parententity") + if parententity: + ac = aliased(parententity, alias=fromclause) + q = self._from_selectable(ac) + else: + q = self._from_selectable(fromclause) + if entities: q._set_entities(entities) return q @_generative def _set_enable_single_crit(self, val): - self._enable_single_crit = val + self.compile_options += {"_enable_single_crit": val} @_generative - def _from_selectable(self, fromclause): + def _from_selectable(self, fromclause, set_entity_from=True): for attr in ( - "_statement", - "_criterion", - "_order_by", - "_group_by", - "_limit", - "_offset", - "_joinpath", - "_joinpoint", + "_where_criteria", + "_order_by_clauses", + "_group_by_clauses", + "_limit_clause", + "_offset_clause", + "_last_joined_entity", + "_legacy_setup_joins", "_distinct", - "_having", + "_having_criteria", "_prefixes", "_suffixes", ): self.__dict__.pop(attr, None) - self._set_select_from([fromclause], True) - self._enable_single_crit = False + self._set_select_from([fromclause], set_entity_from) + self.compile_options += { + "_enable_single_crit": False, + "_statement": None, + } # this enables clause adaptation for non-ORM # expressions. - self._orm_only_from_obj_alias = False - - old_entities = self._entities - self._entities = [] - for e in old_entities: - e.adapt_to_selectable(self, self._from_obj[0]) + # legacy. see test/orm/test_froms.py for various + # "oldstyle" tests that rely on this and the correspoinding + # "newtyle" that do not. + self.compile_options += {"_orm_only_from_obj_alias": False} + @util.deprecated( + "1.4", + ":meth:`_query.Query.values` " + "is deprecated and will be removed in a " + "future release. Please use :meth:`_query.Query.with_entities`", + ) def values(self, *columns): """Return an iterator yielding result tuples corresponding - to the given list of columns""" + to the given list of columns + + """ if not columns: return iter(()) - q = self._clone() - q._set_entities(columns, entity_wrapper=_ColumnEntity) - if not q._yield_per: - q._yield_per = 10 + q = self._clone().enable_eagerloads(False) + q._set_entities(columns) + if not q.load_options._yield_per: + q.load_options += {"_yield_per": 10} return iter(q) _values = values + @util.deprecated( + "1.4", + ":meth:`_query.Query.value` " + "is deprecated and will be removed in a " + "future release. Please use :meth:`_query.Query.with_entities` " + "in combination with :meth:`_query.Query.scalar`", + ) def value(self, column): """Return a scalar result corresponding to the given - column expression.""" + column expression. + + """ try: return next(self.values(column))[0] except StopIteration: @@ -1509,10 +1394,11 @@ class Query(Generative): """Add one or more column expressions to the list of result columns to be returned.""" - self._entities = list(self._entities) + self._raw_columns = list(self._raw_columns) - for c in column: - _ColumnEntity(self, c) + self._raw_columns.extend( + coercions.expect(roles.ColumnsClauseRole, c) for c in column + ) @util.deprecated( "1.4", @@ -1527,6 +1413,7 @@ class Query(Generative): """ return self.add_columns(column) + @_generative def options(self, *args): """Return a new :class:`_query.Query` object, applying the given list of @@ -1542,26 +1429,18 @@ class Query(Generative): :ref:`relationship_loader_options` """ - return self._options(False, *args) - - def _conditional_options(self, *args): - return self._options(True, *args) - @_generative - def _options(self, conditional, *args): - # most MapperOptions write to the '_attributes' dictionary, - # so copy that as well - self._attributes = dict(self._attributes) - if "_unbound_load_dedupes" not in self._attributes: - self._attributes["_unbound_load_dedupes"] = set() opts = tuple(util.flatten_iterator(args)) - self._with_options = self._with_options + opts - if conditional: + if self.compile_options._current_path: for opt in opts: - opt.process_query_conditionally(self) + if opt._is_legacy_option: + opt.process_query_conditionally(self) else: for opt in opts: - opt.process_query(self) + if opt._is_legacy_option: + opt.process_query(self) + + self._with_options += opts def with_transformation(self, fn): """Return a new :class:`_query.Query` object transformed by @@ -1582,53 +1461,6 @@ class Query(Generative): """ return fn(self) - @_generative - def with_hint(self, selectable, text, dialect_name="*"): - """Add an indexing or other executional context - hint for the given entity or selectable to - this :class:`_query.Query`. - - Functionality is passed straight through to - :meth:`~sqlalchemy.sql.expression.Select.with_hint`, - with the addition that ``selectable`` can be a - :class:`_schema.Table`, :class:`_expression.Alias`, - or ORM entity / mapped class - /etc. - - .. seealso:: - - :meth:`_query.Query.with_statement_hint` - - :meth:`.Query.prefix_with` - generic SELECT prefixing which also - can suit some database-specific HINT syntaxes such as MySQL - optimizer hints - - """ - if selectable is not None: - selectable = inspect(selectable).selectable - - self._with_hints += ((selectable, text, dialect_name),) - - def with_statement_hint(self, text, dialect_name="*"): - """add a statement hint to this :class:`_expression.Select`. - - This method is similar to :meth:`_expression.Select.with_hint` - except that - it does not require an individual table, and instead applies to the - statement as a whole. - - This feature calls down into - :meth:`_expression.Select.with_statement_hint`. - - .. versionadded:: 1.0.0 - - .. seealso:: - - :meth:`_query.Query.with_hint` - - """ - return self.with_hint(None, text, dialect_name) - def get_execution_options(self): """ Get the non-SQL options which will take effect during execution. @@ -1720,8 +1552,9 @@ class Query(Generative): "params() takes zero or one positional argument, " "which is a dictionary." ) - self._params = dict(self._params) - self._params.update(kwargs) + params = dict(self.load_options._params) + params.update(kwargs) + self.load_options += {"_params": params} @_generative @_assertions(_no_statement_condition, _no_limit_offset) @@ -1752,12 +1585,35 @@ class Query(Generative): """ for criterion in list(criterion): criterion = coercions.expect(roles.WhereHavingRole, criterion) - criterion = self._adapt_clause(criterion, True, True) - if self._criterion is not None: - self._criterion = self._criterion & criterion - else: - self._criterion = criterion + # legacy vvvvvvvvvvvvvvvvvvvvvvvvvvv + if self._aliased_generation: + criterion = sql_util._deep_annotate( + criterion, {"aliased_generation": self._aliased_generation} + ) + # legacy ^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + self._where_criteria += (criterion,) + + @util.memoized_property + def _last_joined_entity(self): + if self._legacy_setup_joins: + return _legacy_determine_last_joined_entity( + self._legacy_setup_joins, self._entity_from_pre_ent_zero() + ) + else: + return None + + def _filter_by_zero(self): + if self._legacy_setup_joins: + _last_joined_entity = self._last_joined_entity + if _last_joined_entity is not None: + return _last_joined_entity + + if self._from_obj: + return self._from_obj[0] + + return self._raw_columns[0] def filter_by(self, **kwargs): r"""apply the given filtering criterion to a copy @@ -1783,9 +1639,9 @@ class Query(Generative): :meth:`_query.Query.filter` - filter on SQL expressions. """ + from_entity = self._filter_by_zero() - zero = self._joinpoint_zero() - if zero is None: + if from_entity is None: raise sa_exc.InvalidRequestError( "Can't use filter_by when the first entity '%s' of a query " "is not a mapped class. Please use the filter method instead, " @@ -1794,40 +1650,45 @@ class Query(Generative): ) clauses = [ - _entity_descriptor(zero, key) == value + _entity_namespace_key(from_entity, key) == value for key, value in kwargs.items() ] return self.filter(*clauses) @_generative @_assertions(_no_statement_condition, _no_limit_offset) - def order_by(self, *criterion): + def order_by(self, *clauses): """apply one or more ORDER BY criterion to the query and return the newly resulting ``Query`` - All existing ORDER BY settings can be suppressed by + All existing ORDER BY settings candef order_by be suppressed by passing ``None``. """ - if len(criterion) == 1: - if criterion[0] is False: - if "_order_by" in self.__dict__: - self._order_by = False - return - if criterion[0] is None: - self._order_by = None - return - - criterion = self._adapt_col_list(criterion) - - if self._order_by is False or self._order_by is None: - self._order_by = criterion + if len(clauses) == 1 and (clauses[0] is None or clauses[0] is False): + self._order_by_clauses = () else: - self._order_by = self._order_by + criterion + criterion = tuple( + coercions.expect(roles.OrderByRole, clause) + for clause in clauses + ) + # legacy vvvvvvvvvvvvvvvvvvvvvvvvvvv + if self._aliased_generation: + criterion = tuple( + [ + sql_util._deep_annotate( + o, {"aliased_generation": self._aliased_generation} + ) + for o in criterion + ] + ) + # legacy ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + self._order_by_clauses += criterion @_generative @_assertions(_no_statement_condition, _no_limit_offset) - def group_by(self, *criterion): + def group_by(self, *clauses): """apply one or more GROUP BY criterion to the query and return the newly resulting :class:`_query.Query` @@ -1840,18 +1701,26 @@ class Query(Generative): """ - if len(criterion) == 1: - if criterion[0] is None: - self._group_by = False - return - - criterion = list(chain(*[_orm_columns(c) for c in criterion])) - criterion = self._adapt_col_list(criterion) - - if self._group_by is False: - self._group_by = criterion + if len(clauses) == 1 and (clauses[0] is None or clauses[0] is False): + self._group_by_clauses = () else: - self._group_by = self._group_by + criterion + criterion = tuple( + coercions.expect(roles.GroupByRole, clause) + for clause in clauses + ) + # legacy vvvvvvvvvvvvvvvvvvvvvvvvvvv + if self._aliased_generation: + criterion = tuple( + [ + sql_util._deep_annotate( + o, {"aliased_generation": self._aliased_generation} + ) + for o in criterion + ] + ) + # legacy ^^^^^^^^^^^^^^^^^^^^^^^^^^ + + self._group_by_clauses += criterion @_generative @_assertions(_no_statement_condition, _no_limit_offset) @@ -1872,22 +1741,9 @@ class Query(Generative): """ - criterion = coercions.expect(roles.WhereHavingRole, criterion) - - if criterion is not None and not isinstance( - criterion, sql.ClauseElement - ): - raise sa_exc.ArgumentError( - "having() argument must be of type " - "sqlalchemy.sql.ClauseElement or string" - ) - - criterion = self._adapt_clause(criterion, True, True) - - if self._having is not None: - self._having = self._having & criterion - else: - self._having = criterion + self._having_criteria += ( + coercions.expect(roles.WhereHavingRole, criterion), + ) def _set_op(self, expr_fn, *q): return self._from_selectable(expr_fn(*([self] + list(q))).subquery()) @@ -1976,7 +1832,15 @@ class Query(Generative): """ return self._set_op(expression.except_all, *q) - def join(self, *props, **kwargs): + def _next_aliased_generation(self): + if "_aliased_generation_counter" not in self.__dict__: + self._aliased_generation_counter = 0 + self._aliased_generation_counter += 1 + return self._aliased_generation_counter + + @_generative + @_assertions(_no_statement_condition, _no_limit_offset) + def join(self, target, *props, **kwargs): r"""Create a SQL JOIN against this :class:`_query.Query` object's criterion and apply generatively, returning the newly resulting @@ -2248,649 +2112,125 @@ class Query(Generative): raise TypeError( "unknown arguments: %s" % ", ".join(sorted(kwargs)) ) - return self._join( - props, - outerjoin=isouter, - full=full, - create_aliases=aliased, - from_joinpoint=from_joinpoint, - ) - - def outerjoin(self, *props, **kwargs): - """Create a left outer join against this ``Query`` object's criterion - and apply generatively, returning the newly resulting ``Query``. - - Usage is the same as the ``join()`` method. - - """ - aliased, from_joinpoint, full = ( - kwargs.pop("aliased", False), - kwargs.pop("from_joinpoint", False), - kwargs.pop("full", False), - ) - if kwargs: - raise TypeError( - "unknown arguments: %s" % ", ".join(sorted(kwargs)) - ) - return self._join( - props, - outerjoin=True, - full=full, - create_aliases=aliased, - from_joinpoint=from_joinpoint, - ) - - def _update_joinpoint(self, jp): - self._joinpoint = jp - # copy backwards to the root of the _joinpath - # dict, so that no existing dict in the path is mutated - while "prev" in jp: - f, prev = jp["prev"] - prev = dict(prev) - prev[f] = jp.copy() - jp["prev"] = (f, prev) - jp = prev - self._joinpath = jp - - @_generative - @_assertions(_no_statement_condition, _no_limit_offset) - def _join(self, keys, outerjoin, full, create_aliases, from_joinpoint): - """consumes arguments from join() or outerjoin(), places them into a - consistent format with which to form the actual JOIN constructs. - - """ + # legacy vvvvvvvvvvvvvvvvvvvvvvvvvvv if not from_joinpoint: - self._reset_joinpoint() + self._last_joined_entity = None + self._aliased_generation = None + # legacy ^^^^^^^^^^^^^^^^^^^^^^^^^^^ - if ( - len(keys) == 2 - and isinstance( - keys[0], - ( - # note this would be FromClause once - # coercion of SELECT is removed - expression.Selectable, - type, - AliasedClass, - ), - ) - and isinstance( - keys[1], - (str, expression.ClauseElement, interfaces.PropComparator), - ) - ): - # detect 2-arg form of join and - # convert to a tuple. - keys = (keys,) - - # Query.join() accepts a list of join paths all at once. - # step one is to iterate through these paths and determine the - # intent of each path individually. as we encounter a path token, - # we add a new ORMJoin construct to the self._from_obj tuple, - # either by adding a new element to it, or by replacing an existing - # element with a new ORMJoin. - keylist = util.to_list(keys) - for idx, arg1 in enumerate(keylist): - if isinstance(arg1, tuple): - # "tuple" form of join, multiple - # tuples are accepted as well. The simpler - # "2-arg" form is preferred. - arg1, arg2 = arg1 - else: - arg2 = None - - # determine onclause/right_entity. there - # is a little bit of legacy behavior still at work here - # which means they might be in either order. - if isinstance( - arg1, (interfaces.PropComparator, util.string_types) - ): - right, onclause = arg2, arg1 - else: - right, onclause = arg1, arg2 - - if onclause is None: - r_info = inspect(right) - if not r_info.is_selectable and not hasattr(r_info, "mapper"): - raise sa_exc.ArgumentError( - "Expected mapped entity or " - "selectable/table as join target" - ) - - if isinstance(onclause, interfaces.PropComparator): - of_type = getattr(onclause, "_of_type", None) - else: - of_type = None - - if isinstance(onclause, util.string_types): - # string given, e.g. query(Foo).join("bar"). - # we look to the left entity or what we last joined - # towards - onclause = _entity_descriptor(self._joinpoint_zero(), onclause) - - # check for q.join(Class.propname, from_joinpoint=True) - # and Class corresponds at the mapper level to the current - # joinpoint. this match intentionally looks for a non-aliased - # class-bound descriptor as the onclause and if it matches the - # current joinpoint at the mapper level, it's used. This - # is a very old use case that is intended to make it easier - # to work with the aliased=True flag, which is also something - # that probably shouldn't exist on join() due to its high - # complexity/usefulness ratio - elif from_joinpoint and isinstance( - onclause, interfaces.PropComparator - ): - jp0 = self._joinpoint_zero() - info = inspect(jp0) - - if getattr(info, "mapper", None) is onclause._parententity: - onclause = _entity_descriptor(jp0, onclause.key) - - if isinstance(onclause, interfaces.PropComparator): - # descriptor/property given (or determined); this tells - # us explicitly what the expected "left" side of the join is. - if right is None: - if of_type: - right = of_type - else: - right = onclause.property.entity - - left = onclause._parententity - - alias = self._polymorphic_adapters.get(left, None) - - # could be None or could be ColumnAdapter also - if isinstance(alias, ORMAdapter) and alias.mapper.isa(left): - left = alias.aliased_class - onclause = getattr(left, onclause.key) - - prop = onclause.property - if not isinstance(onclause, attributes.QueryableAttribute): - onclause = prop - - if not create_aliases: - # check for this path already present. - # don't render in that case. - edge = (left, right, prop.key) - if edge in self._joinpoint: - # The child's prev reference might be stale -- - # it could point to a parent older than the - # current joinpoint. If this is the case, - # then we need to update it and then fix the - # tree's spine with _update_joinpoint. Copy - # and then mutate the child, which might be - # shared by a different query object. - jp = self._joinpoint[edge].copy() - jp["prev"] = (edge, self._joinpoint) - self._update_joinpoint(jp) - - # warn only on the last element of the list - if idx == len(keylist) - 1: - util.warn( - "Pathed join target %s has already " - "been joined to; skipping" % prop - ) - continue - else: - # no descriptor/property given; we will need to figure out - # what the effective "left" side is - prop = left = None - - # figure out the final "left" and "right" sides and create an - # ORMJoin to add to our _from_obj tuple - self._join_left_to_right( - left, right, onclause, prop, create_aliases, outerjoin, full - ) - - def _join_left_to_right( - self, left, right, onclause, prop, create_aliases, outerjoin, full - ): - """given raw "left", "right", "onclause" parameters consumed from - a particular key within _join(), add a real ORMJoin object to - our _from_obj list (or augment an existing one) - - """ - - self._polymorphic_adapters = self._polymorphic_adapters.copy() - - if left is None: - # left not given (e.g. no relationship object/name specified) - # figure out the best "left" side based on our existing froms / - # entities - assert prop is None - ( - left, - replace_from_obj_index, - use_entity_index, - ) = self._join_determine_implicit_left_side(left, right, onclause) + if props: + onclause, legacy = props[0], props[1:] else: - # left is given via a relationship/name. Determine where in our - # "froms" list it should be spliced/appended as well as what - # existing entity it corresponds to. - assert prop is not None - ( - replace_from_obj_index, - use_entity_index, - ) = self._join_place_explicit_left_side(left) + onclause = legacy = None - if left is right and not create_aliases: - raise sa_exc.InvalidRequestError( - "Can't construct a join from %s to %s, they " - "are the same entity" % (left, right) - ) - - # the right side as given often needs to be adapted. additionally - # a lot of things can be wrong with it. handle all that and - # get back the new effective "right" side - r_info, right, onclause = self._join_check_and_adapt_right_side( - left, right, onclause, prop, create_aliases - ) - - if replace_from_obj_index is not None: - # splice into an existing element in the - # self._from_obj list - left_clause = self._from_obj[replace_from_obj_index] - - self._from_obj = ( - self._from_obj[:replace_from_obj_index] - + ( - orm_join( - left_clause, - right, - onclause, - isouter=outerjoin, - full=full, - ), - ) - + self._from_obj[replace_from_obj_index + 1 :] - ) + if not legacy and onclause is None and not isinstance(target, tuple): + # non legacy argument form + _props = [(target,)] + elif not legacy and isinstance( + target, (expression.Selectable, type, AliasedClass,) + ): + # non legacy argument form + _props = [(target, onclause)] else: - # add a new element to the self._from_obj list - if use_entity_index is not None: - # make use of _MapperEntity selectable, which is usually - # entity_zero.selectable, but if with_polymorphic() were used - # might be distinct - assert isinstance( - self._entities[use_entity_index], _MapperEntity - ) - left_clause = self._entities[use_entity_index].selectable - else: - left_clause = left - - self._from_obj = self._from_obj + ( - orm_join( - left_clause, right, onclause, isouter=outerjoin, full=full - ), - ) - - def _join_determine_implicit_left_side(self, left, right, onclause): - """When join conditions don't express the left side explicitly, - determine if an existing FROM or entity in this query - can serve as the left hand side. - - """ - - # when we are here, it means join() was called without an ORM- - # specific way of telling us what the "left" side is, e.g.: - # - # join(RightEntity) - # - # or - # - # join(RightEntity, RightEntity.foo == LeftEntity.bar) - # - - r_info = inspect(right) - - replace_from_obj_index = use_entity_index = None - - if self._from_obj: - # we have a list of FROMs already. So by definition this - # join has to connect to one of those FROMs. - - indexes = sql_util.find_left_clause_to_join_from( - self._from_obj, r_info.selectable, onclause - ) - - if len(indexes) == 1: - replace_from_obj_index = indexes[0] - left = self._from_obj[replace_from_obj_index] - elif len(indexes) > 1: - raise sa_exc.InvalidRequestError( - "Can't determine which FROM clause to join " - "from, there are multiple FROMS which can " - "join to this entity. Please use the .select_from() " - "method to establish an explicit left side, as well as " - "providing an explcit ON clause if not present already to " - "help resolve the ambiguity." - ) - else: - raise sa_exc.InvalidRequestError( - "Don't know how to join to %r. " - "Please use the .select_from() " - "method to establish an explicit left side, as well as " - "providing an explcit ON clause if not present already to " - "help resolve the ambiguity." % (right,) - ) - - elif self._entities: - # we have no explicit FROMs, so the implicit left has to - # come from our list of entities. - - potential = {} - for entity_index, ent in enumerate(self._entities): - entity = ent.entity_zero_or_selectable - if entity is None: - continue - ent_info = inspect(entity) - if ent_info is r_info: # left and right are the same, skip - continue - - # by using a dictionary with the selectables as keys this - # de-duplicates those selectables as occurs when the query is - # against a series of columns from the same selectable - if isinstance(ent, _MapperEntity): - potential[ent.selectable] = (entity_index, entity) + # legacy forms. more time consuming :) + _props = [] + _single = [] + for prop in (target,) + props: + if isinstance(prop, tuple): + if _single: + _props.extend((_s,) for _s in _single) + _single = [] + + # this checks for an extremely ancient calling form of + # reversed tuples. + if isinstance(prop[0], (str, interfaces.PropComparator)): + prop = (prop[1], prop[0]) + + _props.append(prop) else: - potential[ent_info.selectable] = (None, entity) - - all_clauses = list(potential.keys()) - indexes = sql_util.find_left_clause_to_join_from( - all_clauses, r_info.selectable, onclause - ) - - if len(indexes) == 1: - use_entity_index, left = potential[all_clauses[indexes[0]]] - elif len(indexes) > 1: - raise sa_exc.InvalidRequestError( - "Can't determine which FROM clause to join " - "from, there are multiple FROMS which can " - "join to this entity. Please use the .select_from() " - "method to establish an explicit left side, as well as " - "providing an explcit ON clause if not present already to " - "help resolve the ambiguity." - ) - else: - raise sa_exc.InvalidRequestError( - "Don't know how to join to %r. " - "Please use the .select_from() " - "method to establish an explicit left side, as well as " - "providing an explcit ON clause if not present already to " - "help resolve the ambiguity." % (right,) - ) - else: - raise sa_exc.InvalidRequestError( - "No entities to join from; please use " - "select_from() to establish the left " - "entity/selectable of this join" - ) - - return left, replace_from_obj_index, use_entity_index - - def _join_place_explicit_left_side(self, left): - """When join conditions express a left side explicitly, determine - where in our existing list of FROM clauses we should join towards, - or if we need to make a new join, and if so is it from one of our - existing entities. - - """ - - # when we are here, it means join() was called with an indicator - # as to an exact left side, which means a path to a - # RelationshipProperty was given, e.g.: - # - # join(RightEntity, LeftEntity.right) - # - # or - # - # join(LeftEntity.right) - # - # as well as string forms: - # - # join(RightEntity, "right") - # - # etc. - # - - replace_from_obj_index = use_entity_index = None - - l_info = inspect(left) - if self._from_obj: - indexes = sql_util.find_left_clause_that_matches_given( - self._from_obj, l_info.selectable - ) - - if len(indexes) > 1: - raise sa_exc.InvalidRequestError( - "Can't identify which entity in which to assign the " - "left side of this join. Please use a more specific " - "ON clause." - ) - - # have an index, means the left side is already present in - # an existing FROM in the self._from_obj tuple - if indexes: - replace_from_obj_index = indexes[0] - - # no index, means we need to add a new element to the - # self._from_obj tuple - - # no from element present, so we will have to add to the - # self._from_obj tuple. Determine if this left side matches up - # with existing mapper entities, in which case we want to apply the - # aliasing / adaptation rules present on that entity if any - if ( - replace_from_obj_index is None - and self._entities - and hasattr(l_info, "mapper") - ): - for idx, ent in enumerate(self._entities): - # TODO: should we be checking for multiple mapper entities - # matching? - if isinstance(ent, _MapperEntity) and ent.corresponds_to(left): - use_entity_index = idx - break - - return replace_from_obj_index, use_entity_index + _single.append(prop) + if _single: + _props.extend((_s,) for _s in _single) - def _join_check_and_adapt_right_side( - self, left, right, onclause, prop, create_aliases - ): - """transform the "right" side of the join as well as the onclause - according to polymorphic mapping translations, aliasing on the query - or on the join, special cases where the right and left side have - overlapping tables. - - """ - - l_info = inspect(left) - r_info = inspect(right) - - overlap = False - if not create_aliases: - right_mapper = getattr(r_info, "mapper", None) - # if the target is a joined inheritance mapping, - # be more liberal about auto-aliasing. - if right_mapper and ( - right_mapper.with_polymorphic - or isinstance(right_mapper.persist_selectable, expression.Join) - ): - for from_obj in self._from_obj or [l_info.selectable]: - if sql_util.selectables_overlap( - l_info.selectable, from_obj - ) and sql_util.selectables_overlap( - from_obj, r_info.selectable - ): - overlap = True - break - - if ( - overlap or not create_aliases - ) and l_info.selectable is r_info.selectable: - raise sa_exc.InvalidRequestError( - "Can't join table/selectable '%s' to itself" - % l_info.selectable - ) + # legacy vvvvvvvvvvvvvvvvvvvvvvvvvvv + if aliased: + self._aliased_generation = self._next_aliased_generation() - right_mapper, right_selectable, right_is_aliased = ( - getattr(r_info, "mapper", None), - r_info.selectable, - getattr(r_info, "is_aliased_class", False), - ) - - if ( - right_mapper - and prop - and not right_mapper.common_parent(prop.mapper) - ): - raise sa_exc.InvalidRequestError( - "Join target %s does not correspond to " - "the right side of join condition %s" % (right, onclause) - ) - - # _join_entities is used as a hint for single-table inheritance - # purposes at the moment - if hasattr(r_info, "mapper"): - self._join_entities += (r_info,) - - need_adapter = False - - # test for joining to an unmapped selectable as the target - if r_info.is_clause_element: - - if prop: - right_mapper = prop.mapper - - if right_selectable._is_lateral: - # orm_only is disabled to suit the case where we have to - # adapt an explicit correlate(Entity) - the select() loses - # the ORM-ness in this case right now, ideally it would not - right = self._adapt_clause(right, True, False) - - elif prop: - # joining to selectable with a mapper property given - # as the ON clause - - if not right_selectable.is_derived_from( - right_mapper.persist_selectable - ): - raise sa_exc.InvalidRequestError( - "Selectable '%s' is not derived from '%s'" - % ( - right_selectable.description, - right_mapper.persist_selectable.description, - ) - ) - - # if the destination selectable is a plain select(), - # turn it into an alias(). - if isinstance(right_selectable, expression.SelectBase): - right_selectable = coercions.expect( - roles.FromClauseRole, right_selectable + if self._aliased_generation: + _props = [ + ( + prop[0], + sql_util._deep_annotate( + prop[1], + {"aliased_generation": self._aliased_generation}, ) - need_adapter = True - - # make the right hand side target into an ORM entity - right = aliased(right_mapper, right_selectable) - elif create_aliases: - # it *could* work, but it doesn't right now and I'd rather - # get rid of aliased=True completely - raise sa_exc.InvalidRequestError( - "The aliased=True parameter on query.join() only works " - "with an ORM entity, not a plain selectable, as the " - "target." + if isinstance(prop[1], expression.ClauseElement) + else prop[1], ) + if len(prop) == 2 + else prop + for prop in _props + ] - aliased_entity = ( - right_mapper - and not right_is_aliased - and ( - right_mapper.with_polymorphic - and isinstance( - right_mapper._with_polymorphic_selectable, - expression.AliasedReturnsRows, - ) - or overlap - # test for overlap: - # orm/inheritance/relationships.py - # SelfReferentialM2MTest + # legacy ^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + self._legacy_setup_joins += tuple( + ( + coercions.expect(roles.JoinTargetRole, prop[0], legacy=True), + prop[1] if len(prop) == 2 else None, + None, + { + "isouter": isouter, + "aliased": aliased, + "from_joinpoint": True if i > 0 else from_joinpoint, + "full": full, + "aliased_generation": self._aliased_generation, + }, ) + for i, prop in enumerate(_props) ) - if not need_adapter and (create_aliases or aliased_entity): - right = aliased(right, flat=True) - need_adapter = True - - if need_adapter: - assert right_mapper + self.__dict__.pop("_last_joined_entity", None) - # if an alias() of the right side was generated, - # apply an adapter to all subsequent filter() calls - # until reset_joinpoint() is called. - adapter = ORMAdapter( - right, equivalents=right_mapper._equivalent_columns - ) - # current adapter takes highest precedence - self._filter_aliases = (adapter,) + self._filter_aliases - - # if an alias() on the right side was generated, - # which is intended to wrap a the right side in a subquery, - # ensure that columns retrieved from this target in the result - # set are also adapted. - if not create_aliases: - self._mapper_loads_polymorphically_with(right_mapper, adapter) - - # if the onclause is a ClauseElement, adapt it with any - # adapters that are in place right now - if isinstance(onclause, expression.ClauseElement): - onclause = self._adapt_clause(onclause, True, True) - - # if joining on a MapperProperty path, - # track the path to prevent redundant joins - if not create_aliases and prop: - self._update_joinpoint( - { - "_joinpoint_entity": right, - "prev": ((left, right, prop.key), self._joinpoint), - } - ) - else: - self._joinpoint = {"_joinpoint_entity": right} + def outerjoin(self, target, *props, **kwargs): + """Create a left outer join against this ``Query`` object's criterion + and apply generatively, returning the newly resulting ``Query``. - return right, inspect(right), onclause + Usage is the same as the ``join()`` method. - def _reset_joinpoint(self): - self._joinpoint = self._joinpath - self._filter_aliases = () + """ + kwargs["isouter"] = True + return self.join(target, *props, **kwargs) @_generative @_assertions(_no_statement_condition) def reset_joinpoint(self): - """Return a new :class:`_query.Query`, where the "join point" has + """Return a new :class:`.Query`, where the "join point" has been reset back to the base FROM entities of the query. This method is usually used in conjunction with the - ``aliased=True`` feature of the :meth:`_query.Query.join` - method. See the example in :meth:`_query.Query.join` for how + ``aliased=True`` feature of the :meth:`~.Query.join` + method. See the example in :meth:`~.Query.join` for how this is used. """ - self._reset_joinpoint() + self._last_joined_entity = None + self._aliased_generation = None @_generative @_assertions(_no_clauseelement_condition) def select_from(self, *from_obj): - r"""Set the FROM clause of this :class:`_query.Query` explicitly. + r"""Set the FROM clause of this :class:`.Query` explicitly. - :meth:`_query.Query.select_from` is often used in conjunction with - :meth:`_query.Query.join` in order to control which entity is selected + :meth:`.Query.select_from` is often used in conjunction with + :meth:`.Query.join` in order to control which entity is selected from on the "left" side of the join. The entity or selectable object here effectively replaces the - "left edge" of any calls to :meth:`_query.Query.join`, when no + "left edge" of any calls to :meth:`~.Query.join`, when no joinpoint is otherwise established - usually, the default "join - point" is the leftmost entity in the :class:`_query.Query` object's + point" is the leftmost entity in the :class:`~.Query` object's list of entities to be selected. A typical example:: @@ -2907,9 +2247,8 @@ class Query(Generative): :param \*from_obj: collection of one or more entities to apply to the FROM clause. Entities can be mapped classes, - :class:`.AliasedClass` objects, :class:`_orm.Mapper` objects - as well as core :class:`_expression.FromClause` - elements like subqueries. + :class:`.AliasedClass` objects, :class:`.Mapper` objects + as well as core :class:`.FromClause` elements like subqueries. .. versionchanged:: 0.9 This method no longer applies the given FROM object @@ -2920,9 +2259,9 @@ class Query(Generative): .. seealso:: - :meth:`_query.Query.join` + :meth:`~.Query.join` - :meth:`_query.Query.select_entity_from` + :meth:`.Query.select_entity_from` """ @@ -3043,6 +2382,7 @@ class Query(Generative): """ self._set_select_from([from_obj], True) + self.compile_options += {"_enable_single_crit": False} def __getitem__(self, item): if isinstance(item, slice): @@ -3105,20 +2445,46 @@ class Query(Generative): :meth:`_query.Query.offset` """ + # for calculated limit/offset, try to do the addition of + # values to offset in Python, howver if a SQL clause is present + # then the addition has to be on the SQL side. if start is not None and stop is not None: - self._offset = self._offset if self._offset is not None else 0 + offset_clause = self._offset_or_limit_clause_asint_if_possible( + self._offset_clause + ) + if offset_clause is None: + offset_clause = 0 + if start != 0: - self._offset += start - self._limit = stop - start + offset_clause = offset_clause + start + + if offset_clause == 0: + self._offset_clause = None + else: + self._offset_clause = self._offset_or_limit_clause( + offset_clause + ) + + self._limit_clause = self._offset_or_limit_clause(stop - start) + elif start is None and stop is not None: - self._limit = stop + self._limit_clause = self._offset_or_limit_clause(stop) elif start is not None and stop is None: - self._offset = self._offset if self._offset is not None else 0 + offset_clause = self._offset_or_limit_clause_asint_if_possible( + self._offset_clause + ) + if offset_clause is None: + offset_clause = 0 + if start != 0: - self._offset += start + offset_clause = offset_clause + start - if isinstance(self._offset, int) and self._offset == 0: - self._offset = None + if offset_clause == 0: + self._offset_clause = None + else: + self._offset_clause = self._offset_or_limit_clause( + offset_clause + ) @_generative @_assertions(_no_statement_condition) @@ -3127,7 +2493,7 @@ class Query(Generative): ``Query``. """ - self._limit = limit + self._limit_clause = self._offset_or_limit_clause(limit) @_generative @_assertions(_no_statement_condition) @@ -3136,7 +2502,31 @@ class Query(Generative): ``Query``. """ - self._offset = offset + self._offset_clause = self._offset_or_limit_clause(offset) + + def _offset_or_limit_clause(self, element, name=None, type_=None): + """Convert the given value to an "offset or limit" clause. + + This handles incoming integers and converts to an expression; if + an expression is already given, it is passed through. + + """ + return coercions.expect( + roles.LimitOffsetRole, element, name=name, type_=type_ + ) + + def _offset_or_limit_clause_asint_if_possible(self, clause): + """Return the offset or limit clause as a simple integer if possible, + else return the clause. + + """ + if clause is None: + return None + if hasattr(clause, "_limit_offset_value"): + value = clause._limit_offset_value + return util.asint(value) + else: + return clause @_generative @_assertions(_no_statement_condition) @@ -3169,67 +2559,13 @@ class Query(Generative): and will raise :class:`_exc.CompileError` in a future version. """ - if not expr: + if expr: self._distinct = True + self._distinct_on = self._distinct_on + tuple( + coercions.expect(roles.ByOfRole, e) for e in expr + ) else: - expr = self._adapt_col_list(expr) - if isinstance(self._distinct, list): - self._distinct += expr - else: - self._distinct = expr - - @_generative - def prefix_with(self, *prefixes): - r"""Apply the prefixes to the query and return the newly resulting - ``Query``. - - :param \*prefixes: optional prefixes, typically strings, - not using any commas. In particular is useful for MySQL keywords - and optimizer hints: - - e.g.:: - - query = sess.query(User.name).\ - prefix_with('HIGH_PRIORITY').\ - prefix_with('SQL_SMALL_RESULT', 'ALL').\ - prefix_with('/*+ BKA(user) */') - - Would render:: - - SELECT HIGH_PRIORITY SQL_SMALL_RESULT ALL /*+ BKA(user) */ - users.name AS users_name FROM users - - .. seealso:: - - :meth:`_expression.HasPrefixes.prefix_with` - - """ - if self._prefixes: - self._prefixes += prefixes - else: - self._prefixes = prefixes - - @_generative - def suffix_with(self, *suffixes): - r"""Apply the suffix to the query and return the newly resulting - ``Query``. - - :param \*suffixes: optional suffixes, typically strings, - not using any commas. - - .. versionadded:: 1.0.0 - - .. seealso:: - - :meth:`_query.Query.prefix_with` - - :meth:`_expression.HasSuffixes.suffix_with` - - """ - if self._suffixes: - self._suffixes += suffixes - else: - self._suffixes = suffixes + self._distinct = True def all(self): """Return the results represented by this :class:`_query.Query` @@ -3270,7 +2606,7 @@ class Query(Generative): """ statement = coercions.expect(roles.SelectStatementRole, statement) - self._statement = statement + self.compile_options += {"_statement": statement} def first(self): """Return the first result of this ``Query`` or @@ -3293,7 +2629,7 @@ class Query(Generative): """ # replicates limit(1) behavior - if self._statement is not None: + if self.compile_options._statement is not None: return self._iter().first() else: return self.limit(1)._iter().first() @@ -3392,22 +2728,22 @@ class Query(Generative): def _iter(self): context = self._compile_context() - context.statement.label_style = LABEL_STYLE_TABLENAME_PLUS_COL - if self._autoflush: + + if self.load_options._autoflush: self.session._autoflush() return self._execute_and_instances(context) def __str__(self): - context = self._compile_context() + compile_state = self._compile_state() try: bind = ( - self._get_bind_args(context, self.session.get_bind) + self._get_bind_args(compile_state, self.session.get_bind) if self.session else None ) except sa_exc.UnboundExecutionError: bind = None - return str(context.statement.compile(bind)) + return str(compile_state.statement.compile(bind)) def _connection_from_session(self, **kw): conn = self.session.connection(**kw) @@ -3415,12 +2751,21 @@ class Query(Generative): conn = conn.execution_options(**self._execution_options) return conn - def _execute_and_instances(self, querycontext): + def _execute_and_instances(self, querycontext, params=None): conn = self._get_bind_args( - querycontext, self._connection_from_session, close_with_result=True + querycontext.compile_state, + self._connection_from_session, + close_with_result=True, ) - result = conn._execute_20(querycontext.statement, self._params) + if params is None: + params = querycontext.load_options._params + + result = conn._execute_20( + querycontext.compile_state.statement, + params, + # execution_options=self.session._orm_execution_options(), + ) return loading.instances(querycontext.query, result, querycontext) def _execute_crud(self, stmt, mapper): @@ -3428,11 +2773,13 @@ class Query(Generative): mapper=mapper, clause=stmt, close_with_result=True ) - return conn.execute(stmt, self._params) + return conn.execute(stmt, self.load_options._params) - def _get_bind_args(self, querycontext, fn, **kw): + def _get_bind_args(self, compile_state, fn, **kw): return fn( - mapper=self._bind_mapper(), clause=querycontext.statement, **kw + mapper=compile_state._bind_mapper(), + clause=compile_state.statement, + **kw ) @property @@ -3475,29 +2822,7 @@ class Query(Generative): """ - return [ - { - "name": ent._label_name, - "type": ent.type, - "aliased": getattr(insp_ent, "is_aliased_class", False), - "expr": ent.expr, - "entity": getattr(insp_ent, "entity", None) - if ent.entity_zero is not None - and not insp_ent.is_clause_element - else None, - } - for ent, insp_ent in [ - ( - _ent, - ( - inspect(_ent.entity_zero) - if _ent.entity_zero is not None - else None - ), - ) - for _ent in self._entities - ] - ] + return _column_descriptions(self) def instances(self, result_proxy, context=None): """Return an ORM result given a :class:`_engine.CursorResult` and @@ -3512,7 +2837,8 @@ class Query(Generative): "for linking ORM results to arbitrary select constructs.", version="1.4", ) - context = QueryContext(self) + compile_state = QueryCompileState._create_for_legacy_query(self) + context = QueryContext(compile_state, self.session) return loading.instances(self, result_proxy, context) @@ -3544,28 +2870,6 @@ class Query(Generative): return loading.merge_result(self, iterator, load) - @property - def _select_args(self): - return { - "limit": self._limit, - "offset": self._offset, - "distinct": self._distinct, - "prefixes": self._prefixes, - "suffixes": self._suffixes, - "group_by": self._group_by or None, - "having": self._having, - } - - @property - def _should_nest_selectable(self): - kwargs = self._select_args - return ( - kwargs.get("limit") is not None - or kwargs.get("offset") is not None - or kwargs.get("distinct", False) - or kwargs.get("group_by", False) - ) - def exists(self): """A convenience method that turns a query into an EXISTS subquery of the form EXISTS (SELECT 1 FROM ... WHERE ...). @@ -3601,13 +2905,20 @@ class Query(Generative): # omitting the FROM clause from a query(X) (#2818); # .with_only_columns() after we have a core select() so that # we get just "SELECT 1" without any entities. - return sql.exists( + + inner = ( self.enable_eagerloads(False) .add_columns(sql.literal_column("1")) .with_labels() .statement.with_only_columns([1]) ) + ezero = self._entity_from_pre_ent_zero() + if ezero is not None: + inner = inner.select_from(ezero) + + return sql.exists(inner) + def count(self): r"""Return a count of rows this the SQL formed by this :class:`Query` would return. @@ -3927,966 +3238,50 @@ class Query(Generative): update_op.exec_() return update_op.rowcount - def _compile_context(self, for_statement=False): + def _compile_state(self, for_statement=False, **kw): + # TODO: this needs to become a general event for all + # Executable objects as well (all ClauseElement?) + # but then how do we clarify that this event is only for + # *top level* compile, not as an embedded element is visted? + # how does that even work because right now a Query that does things + # like from_self() will in fact invoke before_compile for each + # inner element. + # OK perhaps with 2.0 style folks will continue using before_execute() + # as they can now, as a select() with ORM elements will be delivered + # there, OK. sort of fixes the "bake_ok" problem too. if self.dispatch.before_compile: for fn in self.dispatch.before_compile: new_query = fn(self) if new_query is not None and new_query is not self: self = new_query if not fn._bake_ok: - self._bake_ok = False - - context = QueryContext(self) - - if context.statement is not None: - if isinstance(context.statement, expression.TextClause): - # setup for all entities, including contains_eager entities. - for entity in self._entities: - entity.setup_context(self, context) - context.statement = expression.TextualSelect( - context.statement, - context.primary_columns, - positional=False, - ) - else: - # allow TextualSelect with implicit columns as well - # as select() with ad-hoc columns, see test_query::TextTest - self._from_obj_alias = sql.util.ColumnAdapter( - context.statement, adapt_on_names=True - ) - - return context - - context.labels = not for_statement or self._with_labels - context.dedupe_cols = True - - context._for_update_arg = self._for_update_arg - - for entity in self._entities: - entity.setup_context(self, context) - - for rec in context.create_eager_joins: - strategy = rec[0] - strategy(context, *rec[1:]) - - if context.from_clause: - # "load from explicit FROMs" mode, - # i.e. when select_from() or join() is used - context.froms = list(context.from_clause) - # else "load from discrete FROMs" mode, - # i.e. when each _MappedEntity has its own FROM - - if self._enable_single_crit: - self._adjust_for_single_inheritance(context) - - if not context.primary_columns: - if self._only_load_props: - raise sa_exc.InvalidRequestError( - "No column-based properties specified for " - "refresh operation. Use session.expire() " - "to reload collections and related items." - ) - else: - raise sa_exc.InvalidRequestError( - "Query contains no columns with which to " "SELECT from." - ) - - if context.multi_row_eager_loaders and self._should_nest_selectable: - context.statement = self._compound_eager_statement(context) - else: - context.statement = self._simple_statement(context) - - if for_statement: - ezero = self._mapper_zero() - if ezero is not None: - context.statement = context.statement._annotate( - {"deepentity": ezero} - ) - return context - - def _compound_eager_statement(self, context): - # for eager joins present and LIMIT/OFFSET/DISTINCT, - # wrap the query inside a select, - # then append eager joins onto that - - if context.order_by: - order_by_col_expr = sql_util.expand_column_list_from_order_by( - context.primary_columns, context.order_by - ) - else: - context.order_by = None - order_by_col_expr = [] - - inner = sql.select( - util.unique_list(context.primary_columns + order_by_col_expr) - if context.dedupe_cols - else (context.primary_columns + order_by_col_expr), - context.whereclause, - from_obj=context.froms, - use_labels=context.labels, - # TODO: this order_by is only needed if - # LIMIT/OFFSET is present in self._select_args, - # else the application on the outside is enough - order_by=context.order_by, - **self._select_args - ) - # put FOR UPDATE on the inner query, where MySQL will honor it, - # as well as if it has an OF so PostgreSQL can use it. - inner._for_update_arg = context._for_update_arg - - for hint in self._with_hints: - inner = inner.with_hint(*hint) - - if self._correlate: - inner = inner.correlate(*self._correlate) - - inner = inner.alias() - - equivs = self.__all_equivs() - - context.adapter = sql_util.ColumnAdapter(inner, equivs) - - statement = sql.select( - [inner] + context.secondary_columns, use_labels=context.labels - ) - - # Oracle however does not allow FOR UPDATE on the subquery, - # and the Oracle dialect ignores it, plus for PostgreSQL, MySQL - # we expect that all elements of the row are locked, so also put it - # on the outside (except in the case of PG when OF is used) - if ( - context._for_update_arg is not None - and context._for_update_arg.of is None - ): - statement._for_update_arg = context._for_update_arg - - from_clause = inner - for eager_join in context.eager_joins.values(): - # EagerLoader places a 'stop_on' attribute on the join, - # giving us a marker as to where the "splice point" of - # the join should be - from_clause = sql_util.splice_joins( - from_clause, eager_join, eager_join.stop_on - ) - - statement.select_from.non_generative(statement, from_clause) - - if context.order_by: - statement.order_by.non_generative( - statement, *context.adapter.copy_and_process(context.order_by) - ) - - statement.order_by.non_generative(statement, *context.eager_order_by) - return statement - - def _simple_statement(self, context): - if not context.order_by: - context.order_by = None - - if self._distinct is True and context.order_by: - to_add = sql_util.expand_column_list_from_order_by( - context.primary_columns, context.order_by - ) - if to_add: - util.warn_deprecated_20( - "ORDER BY columns added implicitly due to " - "DISTINCT is deprecated and will be removed in " - "SQLAlchemy 2.0. SELECT statements with DISTINCT " - "should be written to explicitly include the appropriate " - "columns in the columns clause" - ) - context.primary_columns += to_add - context.froms += tuple(context.eager_joins.values()) - - statement = sql.select( - util.unique_list( - context.primary_columns + context.secondary_columns - ) - if context.dedupe_cols - else (context.primary_columns + context.secondary_columns), - context.whereclause, - from_obj=context.froms, - use_labels=context.labels, - order_by=context.order_by, - **self._select_args - ) - statement._for_update_arg = context._for_update_arg - - for hint in self._with_hints: - statement = statement.with_hint(*hint) - - if self._correlate: - statement = statement.correlate(*self._correlate) - - if context.eager_order_by: - statement.order_by.non_generative( - statement, *context.eager_order_by - ) - return statement - - def _adjust_for_single_inheritance(self, context): - """Apply single-table-inheritance filtering. - - For all distinct single-table-inheritance mappers represented in - the columns clause of this query, as well as the "select from entity", - add criterion to the WHERE - clause of the given QueryContext such that only the appropriate - subtypes are selected from the total results. - - """ - - search = set(context.single_inh_entities.values()) - if ( - self._select_from_entity - and self._select_from_entity not in context.single_inh_entities - ): - insp = inspect(self._select_from_entity) - if insp.is_aliased_class: - adapter = insp._adapter - else: - adapter = None - search = search.union([(self._select_from_entity, adapter)]) - - for (ext_info, adapter) in search: - if ext_info in self._join_entities: - continue - single_crit = ext_info.mapper._single_table_criterion - if single_crit is not None: - if adapter: - single_crit = adapter.traverse(single_crit) - - single_crit = self._adapt_clause(single_crit, False, False) - context.whereclause = sql.and_( - sql.True_._ifnone(context.whereclause), single_crit - ) - - -class _QueryEntity(object): - """represent an entity column returned within a Query result.""" - - def __new__(cls, *args, **kwargs): - if cls is _QueryEntity: - entity = args[1] - if not isinstance(entity, util.string_types) and _is_mapped_class( - entity - ): - cls = _MapperEntity - elif isinstance(entity, Bundle): - cls = _BundleEntity - else: - cls = _ColumnEntity - return object.__new__(cls) - - def _clone(self): - q = self.__class__.__new__(self.__class__) - q.__dict__ = self.__dict__.copy() - return q - - -class _MapperEntity(_QueryEntity): - """mapper/class/AliasedClass entity""" - - def __init__(self, query, entity): - if not query._primary_entity: - query._primary_entity = self - query._entities.append(self) - query._has_mapper_entities = True - self.entities = [entity] - self.expr = entity - - ext_info = self.entity_zero = inspect(entity) - - self.mapper = ext_info.mapper - - if ext_info.is_aliased_class: - self._label_name = ext_info.name - else: - self._label_name = self.mapper.class_.__name__ - - self.selectable = ext_info.selectable - self.is_aliased_class = ext_info.is_aliased_class - self._with_polymorphic = ext_info.with_polymorphic_mappers - self._polymorphic_discriminator = ext_info.polymorphic_on - self.path = ext_info._path_registry - - if ext_info.mapper.with_polymorphic: - query._setup_query_adapters(entity, ext_info) - - supports_single_entity = True - - use_id_for_hash = True - - def set_with_polymorphic( - self, query, cls_or_mappers, selectable, polymorphic_on - ): - """Receive an update from a call to query.with_polymorphic(). - - Note the newer style of using a free standing with_polymporphic() - construct doesn't make use of this method. - - - """ - if self.is_aliased_class: - # TODO: invalidrequest ? - raise NotImplementedError( - "Can't use with_polymorphic() against " "an Aliased object" - ) - - if cls_or_mappers is None: - query._reset_polymorphic_adapter(self.mapper) - return - - mappers, from_obj = self.mapper._with_polymorphic_args( - cls_or_mappers, selectable - ) - self._with_polymorphic = mappers - self._polymorphic_discriminator = polymorphic_on - - self.selectable = from_obj - query._mapper_loads_polymorphically_with( - self.mapper, - sql_util.ColumnAdapter(from_obj, self.mapper._equivalent_columns), - ) - - @property - def type(self): - return self.mapper.class_ - - @property - def entity_zero_or_selectable(self): - return self.entity_zero - - def _deep_entity_zero(self): - return self.entity_zero - - def corresponds_to(self, entity): - return _entity_corresponds_to(self.entity_zero, entity) - - def adapt_to_selectable(self, query, sel): - query._entities.append(self) - - def _get_entity_clauses(self, query, context): - - adapter = None - - if not self.is_aliased_class: - if query._polymorphic_adapters: - adapter = query._polymorphic_adapters.get(self.mapper, None) - else: - adapter = self.entity_zero._adapter - - if adapter: - if query._from_obj_alias: - ret = adapter.wrap(query._from_obj_alias) - else: - ret = adapter - else: - ret = query._from_obj_alias - - return ret + self.compile_options += {"_bake_ok": False} - def row_processor(self, query, context, result): - adapter = self._get_entity_clauses(query, context) - - if context.adapter and adapter: - adapter = adapter.wrap(context.adapter) - elif not adapter: - adapter = context.adapter - - # polymorphic mappers which have concrete tables in - # their hierarchy usually - # require row aliasing unconditionally. - if not adapter and self.mapper._requires_row_aliasing: - adapter = sql_util.ColumnAdapter( - self.selectable, self.mapper._equivalent_columns - ) - - if query._primary_entity is self: - only_load_props = query._only_load_props - refresh_state = context.refresh_state - else: - only_load_props = refresh_state = None - - _instance = loading._instance_processor( - self.mapper, - context, - result, - self.path, - adapter, - only_load_props=only_load_props, - refresh_state=refresh_state, - polymorphic_discriminator=self._polymorphic_discriminator, - ) - - return _instance, self._label_name, tuple(self.entities) - - def setup_context(self, query, context): - adapter = self._get_entity_clauses(query, context) - - single_table_crit = self.mapper._single_table_criterion - if single_table_crit is not None: - ext_info = self.entity_zero - context.single_inh_entities[ext_info] = ( - ext_info, - ext_info._adapter if ext_info.is_aliased_class else None, - ) - - # if self._adapted_selectable is None: - context.froms += (self.selectable,) - - loading._setup_entity_query( - context, - self.mapper, - self, - self.path, - adapter, - context.primary_columns, - with_polymorphic=self._with_polymorphic, - only_load_props=query._only_load_props, - polymorphic_discriminator=self._polymorphic_discriminator, - ) - - def __str__(self): - return str(self.mapper) - - -@inspection._self_inspects -class Bundle(InspectionAttr): - """A grouping of SQL expressions that are returned by a - :class:`_query.Query` - under one namespace. - - The :class:`.Bundle` essentially allows nesting of the tuple-based - results returned by a column-oriented :class:`_query.Query` object. - It also - is extensible via simple subclassing, where the primary capability - to override is that of how the set of expressions should be returned, - allowing post-processing as well as custom return types, without - involving ORM identity-mapped classes. - - .. versionadded:: 0.9.0 - - .. seealso:: - - :ref:`bundles` - - """ - - single_entity = False - """If True, queries for a single Bundle will be returned as a single - entity, rather than an element within a keyed tuple.""" - - is_clause_element = False - - is_mapper = False - - is_aliased_class = False - - def __init__(self, name, *exprs, **kw): - r"""Construct a new :class:`.Bundle`. - - e.g.:: - - bn = Bundle("mybundle", MyClass.x, MyClass.y) - - for row in session.query(bn).filter( - bn.c.x == 5).filter(bn.c.y == 4): - print(row.mybundle.x, row.mybundle.y) - - :param name: name of the bundle. - :param \*exprs: columns or SQL expressions comprising the bundle. - :param single_entity=False: if True, rows for this :class:`.Bundle` - can be returned as a "single entity" outside of any enclosing tuple - in the same manner as a mapped entity. - - """ - self.name = self._label = name - self.exprs = exprs - self.c = self.columns = ColumnCollection( - (getattr(col, "key", col._label), col) for col in exprs + compile_state = QueryCompileState._create_for_legacy_query( + self, for_statement=for_statement, **kw ) - self.single_entity = kw.pop("single_entity", self.single_entity) - - columns = None - """A namespace of SQL expressions referred to by this :class:`.Bundle`. - - e.g.:: - - bn = Bundle("mybundle", MyClass.x, MyClass.y) - - q = sess.query(bn).filter(bn.c.x == 5) - - Nesting of bundles is also supported:: - - b1 = Bundle("b1", - Bundle('b2', MyClass.a, MyClass.b), - Bundle('b3', MyClass.x, MyClass.y) - ) - - q = sess.query(b1).filter( - b1.c.b2.c.a == 5).filter(b1.c.b3.c.y == 9) - - .. seealso:: - - :attr:`.Bundle.c` - - """ - - c = None - """An alias for :attr:`.Bundle.columns`.""" - - def _clone(self): - cloned = self.__class__.__new__(self.__class__) - cloned.__dict__.update(self.__dict__) - return cloned - - def __clause_element__(self): - return expression.ClauseList(group=False, *self.exprs)._annotate( - {"bundle": True} - ) - - @property - def clauses(self): - return self.__clause_element__().clauses - - def label(self, name): - """Provide a copy of this :class:`.Bundle` passing a new label.""" - - cloned = self._clone() - cloned.name = name - return cloned - - def create_row_processor(self, query, procs, labels): - """Produce the "row processing" function for this :class:`.Bundle`. - - May be overridden by subclasses. - - .. seealso:: - - :ref:`bundles` - includes an example of subclassing. - - """ - keyed_tuple = result_tuple(labels, [() for l in labels]) - - def proc(row): - return keyed_tuple([proc(row) for proc in procs]) - - return proc - - -class _BundleEntity(_QueryEntity): - use_id_for_hash = False - - def __init__(self, query, expr, setup_entities=True, parent_bundle=None): - if parent_bundle: - parent_bundle._entities.append(self) - else: - query._entities.append(self) - - if isinstance( - expr, (attributes.QueryableAttribute, interfaces.PropComparator) - ): - bundle = expr.__clause_element__() - else: - bundle = expr - - self.bundle = self.expr = bundle - self.type = type(bundle) - self._label_name = bundle.name - self._entities = [] - - if setup_entities: - for expr in bundle.exprs: - if isinstance(expr, Bundle): - _BundleEntity(query, expr, parent_bundle=self) - else: - _ColumnEntity(query, expr, parent_bundle=self) - - self.supports_single_entity = self.bundle.single_entity - - @property - def mapper(self): - ezero = self.entity_zero - if ezero is not None: - return ezero.mapper - else: - return None - - @property - def entities(self): - entities = [] - for ent in self._entities: - entities.extend(ent.entities) - return entities - - @property - def entity_zero(self): - for ent in self._entities: - ezero = ent.entity_zero - if ezero is not None: - return ezero - else: - return None + return compile_state - def corresponds_to(self, entity): - # TODO: we might be able to implement this but for now - # we are working around it - return False - - @property - def entity_zero_or_selectable(self): - for ent in self._entities: - ezero = ent.entity_zero_or_selectable - if ezero is not None: - return ezero - else: - return None - - def _deep_entity_zero(self): - for ent in self._entities: - ezero = ent._deep_entity_zero() - if ezero is not None: - return ezero - else: - return None - - def adapt_to_selectable(self, query, sel, parent_bundle=None): - c = _BundleEntity( - query, - self.bundle, - setup_entities=False, - parent_bundle=parent_bundle, - ) - # c._label_name = self._label_name - # c.entity_zero = self.entity_zero - # c.entities = self.entities - - for ent in self._entities: - ent.adapt_to_selectable(query, sel, parent_bundle=c) - - def setup_context(self, query, context): - for ent in self._entities: - ent.setup_context(query, context) - - def row_processor(self, query, context, result): - procs, labels, extra = zip( - *[ - ent.row_processor(query, context, result) - for ent in self._entities - ] - ) - - proc = self.bundle.create_row_processor(query, procs, labels) - - return proc, self._label_name, () - - -class _ColumnEntity(_QueryEntity): - """Column/expression based entity.""" - - froms = frozenset() - - def __init__(self, query, column, namespace=None, parent_bundle=None): - self.expr = expr = column - self.namespace = namespace - _label_name = None - - column = coercions.expect(roles.ColumnsClauseRole, column) - - annotations = column._annotations - - if annotations.get("bundle", False): - _BundleEntity(query, expr, parent_bundle=parent_bundle) - return - - orm_expr = False - - if "parententity" in annotations: - _entity = annotations["parententity"] - self._label_name = _label_name = annotations.get("orm_key", None) - orm_expr = True - - if hasattr(column, "_select_iterable"): - # break out an object like Table into - # individual columns - for c in column._select_iterable: - if c is column: - break - _ColumnEntity(query, c, namespace=column) - else: - return - - if _label_name is None: - self._label_name = getattr(column, "key", None) - - self.type = type_ = column.type - self.use_id_for_hash = not type_.hashable - - if parent_bundle: - parent_bundle._entities.append(self) - else: - query._entities.append(self) - - self.column = column - - if orm_expr: - self.entity_zero = _entity - if _entity: - self.entities = [_entity] - self.mapper = _entity.mapper - else: - self.entities = [] - self.mapper = None - else: - - entity = sql_util.extract_first_column_annotation( - column, "parententity" - ) - - if entity: - self.entities = [entity] - else: - self.entities = [] - - if self.entities: - self.entity_zero = self.entities[0] - self.mapper = self.entity_zero.mapper - - elif self.namespace is not None: - self.entity_zero = self.namespace - self.mapper = None - else: - self.entity_zero = None - self.mapper = None - - if self.entities and self.entity_zero.mapper.with_polymorphic: - query._setup_query_adapters(self.entity_zero, self.entity_zero) - - supports_single_entity = False - - def _deep_entity_zero(self): - if self.mapper is not None: - return self.mapper - - else: - for obj in visitors.iterate(self.column, {"column_tables": True},): - if "parententity" in obj._annotations: - return obj._annotations["parententity"] - elif "deepentity" in obj._annotations: - return obj._annotations["deepentity"] - else: - return None - - @property - def entity_zero_or_selectable(self): - if self.entity_zero is not None: - return self.entity_zero - elif self.column._from_objects: - return self.column._from_objects[0] - else: - return None - - def adapt_to_selectable(self, query, sel, parent_bundle=None): - c = _ColumnEntity( - query, - sel.corresponding_column(self.column), - parent_bundle=parent_bundle, - ) - c._label_name = self._label_name - c.entity_zero = self.entity_zero - c.entities = self.entities - - def corresponds_to(self, entity): - if self.entity_zero is None: - return False - elif _is_aliased_class(entity): - # TODO: polymorphic subclasses ? - return entity is self.entity_zero - else: - return not _is_aliased_class( - self.entity_zero - ) and entity.common_parent(self.entity_zero) - - def row_processor(self, query, context, result): - if ("fetch_column", self) in context.attributes: - column = context.attributes[("fetch_column", self)] - else: - column = self.column - if query._from_obj_alias: - column = query._from_obj_alias.columns[column] - - if column._annotations: - # annotated columns perform more slowly in compiler and - # result due to the __eq__() method, so use deannotated - column = column._deannotate() - - if context.adapter: - column = context.adapter.columns[column] - - getter = result._getter(column) - - return getter, self._label_name, (self.expr, self.column) - - def setup_context(self, query, context): - column = query._adapt_clause(self.column, False, True) - ezero = self.entity_zero - - if self.mapper: - single_table_crit = self.mapper._single_table_criterion - if single_table_crit is not None: - context.single_inh_entities[ezero] = ( - ezero, - ezero._adapter if ezero.is_aliased_class else None, - ) - - if column._annotations: - # annotated columns perform more slowly in compiler and - # result due to the __eq__() method, so use deannotated - column = column._deannotate() - - if ezero is not None: - # use entity_zero as the from if we have it. this is necessary - # for polymorpic scenarios where our FROM is based on ORM entity, - # not the FROM of the column. but also, don't use it if our column - # doesn't actually have any FROMs that line up, such as when its - # a scalar subquery. - if set(self.column._from_objects).intersection( - ezero.selectable._from_objects - ): - context.froms += (ezero.selectable,) + def _compile_context(self, for_statement=False): + compile_state = self._compile_state(for_statement=for_statement) + context = QueryContext(compile_state, self.session) - context.primary_columns.append(column) + return context - context.attributes[("fetch_column", self)] = column - def __str__(self): - return str(self.column) - - -class QueryContext(object): - __slots__ = ( - "multi_row_eager_loaders", - "adapter", - "froms", - "for_update", - "query", - "session", - "autoflush", - "populate_existing", - "invoke_all_eagers", - "version_check", - "refresh_state", - "primary_columns", - "secondary_columns", - "eager_order_by", - "eager_joins", - "create_eager_joins", - "propagate_options", - "attributes", - "statement", - "from_clause", - "whereclause", - "order_by", - "labels", - "dedupe_cols", - "_for_update_arg", - "runid", - "partials", - "post_load_paths", - "identity_token", - "single_inh_entities", - "is_single_entity", - "loaders_require_uniquing", - "loaders_require_buffering", +class AliasOption(interfaces.LoaderOption): + @util.deprecated( + "1.4", + "The :class:`.AliasOption` is not necessary " + "for entities to be matched up to a query that is established " + "via :meth:`.Query.from_statement` and now does nothing.", ) - - def __init__(self, query): - - if query._statement is not None: - if ( - isinstance(query._statement, expression.SelectBase) - and not query._statement._is_textual - and not query._statement.use_labels - ): - self.statement = query._statement.apply_labels() - else: - self.statement = query._statement - self.order_by = None - else: - self.statement = None - self.from_clause = query._from_obj - self.whereclause = query._criterion - self.order_by = query._order_by - - self.is_single_entity = query.is_single_entity - self.loaders_require_buffering = self.loaders_require_uniquing = False - self.multi_row_eager_loaders = False - self.adapter = None - self.froms = () - self.for_update = None - self.query = query - self.session = query.session - self.autoflush = query._autoflush - self.populate_existing = query._populate_existing - self.invoke_all_eagers = query._invoke_all_eagers - self.version_check = query._version_check - self.refresh_state = query._refresh_state - self.primary_columns = [] - self.secondary_columns = [] - self.eager_order_by = [] - self.eager_joins = {} - self.single_inh_entities = {} - self.create_eager_joins = [] - self.propagate_options = set( - o for o in query._with_options if o.propagate_to_loaders - ) - self.attributes = dict(query._attributes) - if self.refresh_state is not None: - self.identity_token = query._refresh_identity_token - else: - self.identity_token = None - - -class AliasOption(interfaces.MapperOption): def __init__(self, alias): r"""Return a :class:`.MapperOption` that will indicate to the :class:`_query.Query` that the main table has been aliased. - This is a seldom-used option to suit the - very rare case that :func:`.contains_eager` - is being used in conjunction with a user-defined SELECT - statement that aliases the parent table. E.g.:: - - # define an aliased UNION called 'ulist' - ulist = users.select(users.c.user_id==7).\ - union(users.select(users.c.user_id>7)).\ - alias('ulist') - - # add on an eager load of "addresses" - statement = ulist.outerjoin(addresses).\ - select().apply_labels() - - # create query, indicating "ulist" will be an - # alias for the main table, "addresses" - # property should be eager loaded - query = session.query(User).options( - contains_alias(ulist), - contains_eager(User.addresses)) - - # then get results via the statement - results = query.from_statement(statement).all() - - :param alias: is the string name of an alias, or a - :class:`_expression.Alias` object representing - the alias. - """ - self.alias = alias - def process_query(self, query): - if isinstance(self.alias, util.string_types): - alias = query._mapper_zero().persist_selectable.alias(self.alias) - else: - alias = self.alias - query._from_obj_alias = sql_util.ColumnAdapter(alias) + def process_compile_state(self, compile_state): + pass diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 450e5d023..6cb8a0062 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -1258,9 +1258,9 @@ class Session(_SessionClassMethods): if bind is None: bind = self.get_bind(mapper, clause=clause, **kw) - return self._connection_for_bind(bind, close_with_result=True).execute( - clause, params or {} - ) + return self._connection_for_bind( + bind, close_with_result=True + )._execute_20(clause, params,) def scalar(self, clause, params=None, mapper=None, bind=None, **kw): """Like :meth:`~.Session.execute` but return a scalar result.""" diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 7edac8990..c0c090b3d 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -12,18 +12,19 @@ from __future__ import absolute_import import collections import itertools +from sqlalchemy.orm import query from . import attributes from . import exc as orm_exc from . import interfaces from . import loading from . import properties -from . import query from . import relationships from . import unitofwork from . import util as orm_util from .base import _DEFER_FOR_STATE from .base import _RAISE_FOR_STATE from .base import _SET_DEFERRED_EXPIRED +from .context import _column_descriptions from .interfaces import LoaderStrategy from .interfaces import StrategizedProperty from .session import _state_session @@ -140,7 +141,7 @@ class UninstrumentedColumnLoader(LoaderStrategy): def setup_query( self, - context, + compile_state, query_entity, path, loadopt, @@ -173,18 +174,25 @@ class ColumnLoader(LoaderStrategy): def setup_query( self, - context, + compile_state, query_entity, path, loadopt, adapter, column_collection, memoized_populators, + check_for_adapt=False, **kwargs ): for c in self.columns: if adapter: - c = adapter.columns[c] + if check_for_adapt: + c = adapter.adapt_check_present(c) + if c is None: + return + else: + c = adapter.columns[c] + column_collection.append(c) fetch = self.columns[0] @@ -238,7 +246,7 @@ class ExpressionColumnLoader(ColumnLoader): def setup_query( self, - context, + compile_state, query_entity, path, loadopt, @@ -351,7 +359,7 @@ class DeferredColumnLoader(LoaderStrategy): def setup_query( self, - context, + compile_state, query_entity, path, loadopt, @@ -382,7 +390,7 @@ class DeferredColumnLoader(LoaderStrategy): self.parent_property._get_strategy( (("deferred", False), ("instrument", True)) ).setup_query( - context, + compile_state, query_entity, path, loadopt, @@ -546,6 +554,8 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): __slots__ = ( "_lazywhere", "_rev_lazywhere", + "_lazyload_reverse_option", + "_order_by", "use_get", "is_aliased_class", "_bind_to_col", @@ -578,6 +588,14 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): self._rev_equated_columns, ) = join_condition.create_lazy_clause(reverse_direction=True) + if self.parent_property.order_by: + self._order_by = [ + sql_util._deep_annotate(elem, {"_orm_adapt": True}) + for elem in util.to_list(self.parent_property.order_by) + ] + else: + self._order_by = None + self.logger.info("%s lazy loading clause %s", self, self._lazywhere) # determine if our "lazywhere" clause is the same as the mapper's @@ -632,7 +650,12 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): ) def _memoized_attr__simple_lazy_clause(self): - criterion, bind_to_col = (self._lazywhere, self._bind_to_col) + + lazywhere = sql_util._deep_annotate( + self._lazywhere, {"_orm_adapt": True} + ) + + criterion, bind_to_col = (lazywhere, self._bind_to_col) params = [] @@ -828,16 +851,16 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): # generation of a cache key that is including a throwaway object # in the key. + strategy_options = util.preloaded.orm_strategy_options + # note that "lazy='select'" and "lazy=True" make two separate # lazy loaders. Currently the LRU cache is local to the LazyLoader, # however add ourselves to the initial cache key just to future # proof in case it moves - strategy_options = util.preloaded.orm_strategy_options q = self._bakery(lambda session: session.query(self.entity), self) q.add_criteria( - lambda q: q._adapt_all_clauses()._with_invoke_all_eagers(False), - self.parent_property, + lambda q: q._with_invoke_all_eagers(False), self.parent_property, ) if not self.parent_property.bake_queries: @@ -878,29 +901,29 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): ) ) - if self.parent_property.order_by: - q.add_criteria( - lambda q: q.order_by( - *util.to_list(self.parent_property.order_by) - ) - ) + if self._order_by: + q.add_criteria(lambda q: q.order_by(*self._order_by)) - for rev in self.parent_property._reverse_property: - # reverse props that are MANYTOONE are loading *this* - # object from get(), so don't need to eager out to those. - if ( - rev.direction is interfaces.MANYTOONE - and rev._use_get - and not isinstance(rev.strategy, LazyLoader) - ): + def _lazyload_reverse(compile_context): + for rev in self.parent_property._reverse_property: + # reverse props that are MANYTOONE are loading *this* + # object from get(), so don't need to eager out to those. + if ( + rev.direction is interfaces.MANYTOONE + and rev._use_get + and not isinstance(rev.strategy, LazyLoader) + ): + strategy_options.Load.for_existing_path( + compile_context.compile_options._current_path[ + rev.parent + ] + ).lazyload(rev.key).process_compile_state(compile_context) - q.add_criteria( - lambda q: q.options( - strategy_options.Load.for_existing_path( - q._current_path[rev.parent] - ).lazyload(rev.key) - ) - ) + q.add_criteria( + lambda q: q._add_context_option( + _lazyload_reverse, self.parent_property + ) + ) lazy_clause, params = self._generate_lazy_clause(state, passive) if self.key in state.dict: @@ -921,8 +944,8 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): # set parameters in the query such that we don't overwrite # parameters that are already set within it def set_default_params(q): - params.update(q._params) - q._params = params + params.update(q.load_options._params) + q.load_options += {"_params": params} return q result = ( @@ -1022,7 +1045,7 @@ class ImmediateLoader(PostLoader): def setup_query( self, - context, + compile_state, entity, path, loadopt, @@ -1058,7 +1081,7 @@ class SubqueryLoader(PostLoader): def setup_query( self, - context, + compile_state, entity, path, loadopt, @@ -1068,24 +1091,27 @@ class SubqueryLoader(PostLoader): **kwargs ): - if not context.query._enable_eagerloads or context.refresh_state: + if ( + not compile_state.compile_options._enable_eagerloads + or compile_state.compile_options._for_refresh_state + ): return - context.loaders_require_buffering = True + compile_state.loaders_require_buffering = True path = path[self.parent_property] # build up a path indicating the path from the leftmost # entity to the thing we're subquery loading. with_poly_entity = path.get( - context.attributes, "path_with_polymorphic", None + compile_state.attributes, "path_with_polymorphic", None ) if with_poly_entity is not None: effective_entity = with_poly_entity else: effective_entity = self.entity - subq_path = context.attributes.get( + subq_path = compile_state.attributes.get( ("subquery_path", None), orm_util.PathRegistry.root ) @@ -1093,12 +1119,12 @@ class SubqueryLoader(PostLoader): # if not via query option, check for # a cycle - if not path.contains(context.attributes, "loader"): + if not path.contains(compile_state.attributes, "loader"): if self.join_depth: if ( ( - context.query._current_path.length - if context.query._current_path + compile_state.current_path.length + if compile_state.current_path else 0 ) + path.length @@ -1113,8 +1139,8 @@ class SubqueryLoader(PostLoader): leftmost_relationship, ) = self._get_leftmost(subq_path) - orig_query = context.attributes.get( - ("orig_query", SubqueryLoader), context.query + orig_query = compile_state.attributes.get( + ("orig_query", SubqueryLoader), compile_state.orm_query ) # generate a new Query from the original, then @@ -1132,11 +1158,18 @@ class SubqueryLoader(PostLoader): # basically doing a longhand # "from_self()". (from_self() itself not quite industrial # strength enough for all contingencies...but very close) - q = orig_query.session.query(effective_entity) - q._attributes = { - ("orig_query", SubqueryLoader): orig_query, - ("subquery_path", None): subq_path, - } + + q = query.Query(effective_entity) + + def set_state_options(compile_state): + compile_state.attributes.update( + { + ("orig_query", SubqueryLoader): orig_query, + ("subquery_path", None): subq_path, + } + ) + + q = q._add_context_option(set_state_options, None)._disable_caching() q = q._set_enable_single_crit(False) to_join, local_attr, parent_alias = self._prep_for_joins( @@ -1153,7 +1186,11 @@ class SubqueryLoader(PostLoader): # add new query to attributes to be picked up # by create_row_processor - path.set(context.attributes, "subquery", q) + # NOTE: be sure to consult baked.py for some hardcoded logic + # about this structure as well + path.set( + compile_state.attributes, "subqueryload_data", {"query": q}, + ) def _get_leftmost(self, subq_path): subq_path = subq_path.path @@ -1196,24 +1233,34 @@ class SubqueryLoader(PostLoader): # the columns in the SELECT list which may no longer include # all entities mentioned in things like WHERE, JOIN, etc. if not q._from_obj: - q._set_select_from( - list( - set( - [ - ent["entity"] - for ent in orig_query.column_descriptions - if ent["entity"] is not None - ] - ) - ), - False, + q._enable_assertions = False + q.select_from.non_generative( + q, + *{ + ent["entity"] + for ent in _column_descriptions(orig_query) + if ent["entity"] is not None + } ) + cs = q._clone() + + # using the _compile_state method so that the before_compile() + # event is hit here. keystone is testing for this. + compile_state = cs._compile_state(entities_only=True) + # select from the identity columns of the outer (specifically, these # are the 'local_cols' of the property). This will remove # other columns from the query that might suggest the right entity # which is why we do _set_select_from above. - target_cols = q._adapt_col_list(leftmost_attr) + target_cols = compile_state._adapt_col_list( + [ + sql.coercions.expect(sql.roles.ByOfRole, o) + for o in leftmost_attr + ], + compile_state._get_current_adapter(), + ) + # q.add_columns.non_generative(q, target_cols) q._set_entities(target_cols) distinct_target_key = leftmost_relationship.distinct_target_key @@ -1229,14 +1276,14 @@ class SubqueryLoader(PostLoader): break # don't need ORDER BY if no limit/offset - if q._limit is None and q._offset is None: - q._order_by = None + if q._limit_clause is None and q._offset_clause is None: + q._order_by_clauses = () - if q._distinct is True and q._order_by: + if q._distinct is True and q._order_by_clauses: # the logic to automatically add the order by columns to the query # when distinct is True is deprecated in the query to_add = sql_util.expand_column_list_from_order_by( - target_cols, q._order_by + target_cols, q._order_by_clauses ) if to_add: q._set_entities(target_cols + to_add) @@ -1244,7 +1291,7 @@ class SubqueryLoader(PostLoader): # the original query now becomes a subquery # which we'll join onto. - embed_q = q.with_labels().subquery() + embed_q = q.apply_labels().subquery() left_alias = orm_util.AliasedClass( leftmost_mapper, embed_q, use_mapper_path=True ) @@ -1346,31 +1393,32 @@ class SubqueryLoader(PostLoader): ) for attr in to_join: - q = q.join(attr, from_joinpoint=True) + q = q.join(attr) + return q def _setup_options(self, q, subq_path, orig_query, effective_entity): # propagate loader options etc. to the new query. # these will fire relative to subq_path. q = q._with_current_path(subq_path) - q = q._conditional_options(*orig_query._with_options) - if orig_query._populate_existing: - q._populate_existing = orig_query._populate_existing + q = q.options(*orig_query._with_options) + if orig_query.load_options._populate_existing: + q.load_options += {"_populate_existing": True} return q def _setup_outermost_orderby(self, q): if self.parent_property.order_by: - # if there's an ORDER BY, alias it the same - # way joinedloader does, but we have to pull out - # the "eagerjoin" from the query. - # this really only picks up the "secondary" table - # right now. - eagerjoin = q._from_obj[0] - eager_order_by = eagerjoin._target_adapter.copy_and_process( - util.to_list(self.parent_property.order_by) - ) - q = q.order_by(*eager_order_by) + + def _setup_outermost_orderby(compile_context): + compile_context.eager_order_by += tuple( + util.to_list(self.parent_property.order_by) + ) + + q = q._add_context_option( + _setup_outermost_orderby, self.parent_property + ) + return q class _SubqCollections(object): @@ -1380,10 +1428,12 @@ class SubqueryLoader(PostLoader): """ - _data = None + __slots__ = ("subq_info", "subq", "_data") - def __init__(self, subq): - self.subq = subq + def __init__(self, subq_info): + self.subq_info = subq_info + self.subq = subq_info["query"] + self._data = None def get(self, key, default): if self._data is None: @@ -1392,7 +1442,9 @@ class SubqueryLoader(PostLoader): def _load(self): self._data = collections.defaultdict(list) - for k, v in itertools.groupby(self.subq, lambda x: x[1:]): + + rows = list(self.subq) + for k, v in itertools.groupby(rows, lambda x: x[1:]): self._data[k].extend(vv[0] for vv in v) def loader(self, state, dict_, row): @@ -1415,11 +1467,15 @@ class SubqueryLoader(PostLoader): path = path[self.parent_property] - subq = path.get(context.attributes, "subquery") + subq_info = path.get(context.attributes, "subqueryload_data") - if subq is None: + if subq_info is None: return + subq = subq_info["query"] + + if subq.session is None: + subq.session = context.session assert subq.session is context.session, ( "Subquery session doesn't refer to that of " "our context. Are there broken context caching " @@ -1433,7 +1489,7 @@ class SubqueryLoader(PostLoader): # call upon create_row_processor again collections = path.get(context.attributes, "collections") if collections is None: - collections = self._SubqCollections(subq) + collections = self._SubqCollections(subq_info) path.set(context.attributes, "collections", collections) if adapter: @@ -1522,7 +1578,7 @@ class JoinedLoader(AbstractRelationshipLoader): def setup_query( self, - context, + compile_state, query_entity, path, loadopt, @@ -1534,17 +1590,20 @@ class JoinedLoader(AbstractRelationshipLoader): ): """Add a left outer join to the statement that's being constructed.""" - if not context.query._enable_eagerloads: + if not compile_state.compile_options._enable_eagerloads: return elif self.uselist: - context.loaders_require_uniquing = True + compile_state.loaders_require_uniquing = True + compile_state.multi_row_eager_loaders = True path = path[self.parent_property] with_polymorphic = None user_defined_adapter = ( - self._init_user_defined_eager_proc(loadopt, context) + self._init_user_defined_eager_proc( + loadopt, compile_state, compile_state.attributes + ) if loadopt else False ) @@ -1555,12 +1614,16 @@ class JoinedLoader(AbstractRelationshipLoader): adapter, add_to_collection, ) = self._setup_query_on_user_defined_adapter( - context, query_entity, path, adapter, user_defined_adapter + compile_state, + query_entity, + path, + adapter, + user_defined_adapter, ) else: # if not via query option, check for # a cycle - if not path.contains(context.attributes, "loader"): + if not path.contains(compile_state.attributes, "loader"): if self.join_depth: if path.length / 2 > self.join_depth: return @@ -1573,7 +1636,7 @@ class JoinedLoader(AbstractRelationshipLoader): add_to_collection, chained_from_outerjoin, ) = self._generate_row_adapter( - context, + compile_state, query_entity, path, loadopt, @@ -1584,7 +1647,7 @@ class JoinedLoader(AbstractRelationshipLoader): ) with_poly_entity = path.get( - context.attributes, "path_with_polymorphic", None + compile_state.attributes, "path_with_polymorphic", None ) if with_poly_entity is not None: with_polymorphic = inspect( @@ -1596,7 +1659,7 @@ class JoinedLoader(AbstractRelationshipLoader): path = path[self.entity] loading._setup_entity_query( - context, + compile_state, self.mapper, query_entity, path, @@ -1608,7 +1671,7 @@ class JoinedLoader(AbstractRelationshipLoader): ) if with_poly_entity is not None and None in set( - context.secondary_columns + compile_state.secondary_columns ): raise sa_exc.InvalidRequestError( "Detected unaliased columns when generating joined " @@ -1616,7 +1679,9 @@ class JoinedLoader(AbstractRelationshipLoader): "when using joined loading with with_polymorphic()." ) - def _init_user_defined_eager_proc(self, loadopt, context): + def _init_user_defined_eager_proc( + self, loadopt, compile_state, target_attributes + ): # check if the opt applies at all if "eager_from_alias" not in loadopt.local_opts: @@ -1628,7 +1693,7 @@ class JoinedLoader(AbstractRelationshipLoader): # the option applies. check if the "user_defined_eager_row_processor" # has been built up. adapter = path.get( - context.attributes, "user_defined_eager_row_processor", False + compile_state.attributes, "user_defined_eager_row_processor", False ) if adapter is not False: # just return it @@ -1645,20 +1710,22 @@ class JoinedLoader(AbstractRelationshipLoader): alias, equivalents=prop.mapper._equivalent_columns ) else: - if path.contains(context.attributes, "path_with_polymorphic"): + if path.contains( + compile_state.attributes, "path_with_polymorphic" + ): with_poly_entity = path.get( - context.attributes, "path_with_polymorphic" + compile_state.attributes, "path_with_polymorphic" ) adapter = orm_util.ORMAdapter( with_poly_entity, equivalents=prop.mapper._equivalent_columns, ) else: - adapter = context.query._polymorphic_adapters.get( + adapter = compile_state._polymorphic_adapters.get( prop.mapper, None ) path.set( - context.attributes, "user_defined_eager_row_processor", adapter + target_attributes, "user_defined_eager_row_processor", adapter, ) return adapter @@ -1669,7 +1736,7 @@ class JoinedLoader(AbstractRelationshipLoader): # apply some more wrapping to the "user defined adapter" # if we are setting up the query for SQL render. - adapter = entity._get_entity_clauses(context.query, context) + adapter = entity._get_entity_clauses(context) if adapter and user_defined_adapter: user_defined_adapter = user_defined_adapter.wrap(adapter) @@ -1725,7 +1792,7 @@ class JoinedLoader(AbstractRelationshipLoader): def _generate_row_adapter( self, - context, + compile_state, entity, path, loadopt, @@ -1735,12 +1802,12 @@ class JoinedLoader(AbstractRelationshipLoader): chained_from_outerjoin, ): with_poly_entity = path.get( - context.attributes, "path_with_polymorphic", None + compile_state.attributes, "path_with_polymorphic", None ) if with_poly_entity: to_adapt = with_poly_entity else: - to_adapt = self._gen_pooled_aliased_class(context) + to_adapt = self._gen_pooled_aliased_class(compile_state) clauses = inspect(to_adapt)._memo( ("joinedloader_ormadapter", self), @@ -1754,9 +1821,6 @@ class JoinedLoader(AbstractRelationshipLoader): assert clauses.aliased_class is not None - if self.parent_property.uselist: - context.multi_row_eager_loaders = True - innerjoin = ( loadopt.local_opts.get("innerjoin", self.parent_property.innerjoin) if loadopt is not None @@ -1768,7 +1832,7 @@ class JoinedLoader(AbstractRelationshipLoader): # this path must also be outer joins chained_from_outerjoin = True - context.create_eager_joins.append( + compile_state.create_eager_joins.append( ( self._create_eager_join, entity, @@ -1781,14 +1845,14 @@ class JoinedLoader(AbstractRelationshipLoader): ) ) - add_to_collection = context.secondary_columns - path.set(context.attributes, "eager_row_processor", clauses) + add_to_collection = compile_state.secondary_columns + path.set(compile_state.attributes, "eager_row_processor", clauses) return clauses, adapter, add_to_collection, chained_from_outerjoin def _create_eager_join( self, - context, + compile_state, query_entity, path, adapter, @@ -1797,7 +1861,6 @@ class JoinedLoader(AbstractRelationshipLoader): innerjoin, chained_from_outerjoin, ): - if parentmapper is None: localparent = query_entity.mapper else: @@ -1807,19 +1870,19 @@ class JoinedLoader(AbstractRelationshipLoader): # and then attach eager load joins to that (i.e., in the case of # LIMIT/OFFSET etc.) should_nest_selectable = ( - context.multi_row_eager_loaders - and context.query._should_nest_selectable + compile_state.multi_row_eager_loaders + and compile_state._should_nest_selectable ) query_entity_key = None if ( - query_entity not in context.eager_joins + query_entity not in compile_state.eager_joins and not should_nest_selectable - and context.from_clause + and compile_state.from_clauses ): indexes = sql_util.find_left_clause_that_matches_given( - context.from_clause, query_entity.selectable + compile_state.from_clauses, query_entity.selectable ) if len(indexes) > 1: @@ -1832,7 +1895,7 @@ class JoinedLoader(AbstractRelationshipLoader): ) if indexes: - clause = context.from_clause[indexes[0]] + clause = compile_state.from_clauses[indexes[0]] # join to an existing FROM clause on the query. # key it to its list index in the eager_joins dict. # Query._compile_context will adapt as needed and @@ -1845,7 +1908,7 @@ class JoinedLoader(AbstractRelationshipLoader): query_entity.selectable, ) - towrap = context.eager_joins.setdefault( + towrap = compile_state.eager_joins.setdefault( query_entity_key, default_towrap ) @@ -1903,7 +1966,7 @@ class JoinedLoader(AbstractRelationshipLoader): path, towrap, clauses, onclause ) - context.eager_joins[query_entity_key] = eagerjoin + compile_state.eager_joins[query_entity_key] = eagerjoin # send a hint to the Query as to where it may "splice" this join eagerjoin.stop_on = query_entity.selectable @@ -1922,12 +1985,14 @@ class JoinedLoader(AbstractRelationshipLoader): if localparent.persist_selectable.c.contains_column(col): if adapter: col = adapter.columns[col] - context.primary_columns.append(col) + compile_state.primary_columns.append(col) if self.parent_property.order_by: - context.eager_order_by += ( - eagerjoin._target_adapter.copy_and_process - )(util.to_list(self.parent_property.order_by)) + compile_state.eager_order_by += tuple( + (eagerjoin._target_adapter.copy_and_process)( + util.to_list(self.parent_property.order_by) + ) + ) def _splice_nested_inner_join( self, path, join_obj, clauses, onclause, splicing=False @@ -2000,8 +2065,12 @@ class JoinedLoader(AbstractRelationshipLoader): return eagerjoin def _create_eager_adapter(self, context, result, adapter, path, loadopt): + compile_state = context.compile_state + user_defined_adapter = ( - self._init_user_defined_eager_proc(loadopt, context) + self._init_user_defined_eager_proc( + loadopt, compile_state, context.attributes + ) if loadopt else False ) @@ -2011,12 +2080,16 @@ class JoinedLoader(AbstractRelationshipLoader): # user defined eagerloads are part of the "primary" # portion of the load. # the adapters applied to the Query should be honored. - if context.adapter and decorator: - decorator = decorator.wrap(context.adapter) - elif context.adapter: - decorator = context.adapter + if compile_state.compound_eager_adapter and decorator: + decorator = decorator.wrap( + compile_state.compound_eager_adapter + ) + elif compile_state.compound_eager_adapter: + decorator = compile_state.compound_eager_adapter else: - decorator = path.get(context.attributes, "eager_row_processor") + decorator = path.get( + compile_state.attributes, "eager_row_processor" + ) if decorator is None: return False @@ -2282,7 +2355,7 @@ class SelectInLoader(PostLoader, util.MemoizedSlots): ) selectin_path = ( - context.query._current_path or orm_util.PathRegistry.root + context.compile_state.current_path or orm_util.PathRegistry.root ) + path if not orm_util._entity_isa(path[-1], self.parent): @@ -2391,7 +2464,7 @@ class SelectInLoader(PostLoader, util.MemoizedSlots): q = self._bakery( lambda session: session.query( - query.Bundle("pk", *pk_cols), effective_entity + orm_util.Bundle("pk", *pk_cols), effective_entity ), self, ) @@ -2435,7 +2508,7 @@ class SelectInLoader(PostLoader, util.MemoizedSlots): orig_query._with_options, path[self.parent_property] ) - if orig_query._populate_existing: + if context.populate_existing: q.add_criteria(lambda q: q.populate_existing()) if self.parent_property.order_by: @@ -2448,18 +2521,16 @@ class SelectInLoader(PostLoader, util.MemoizedSlots): q.add_criteria(lambda q: q.order_by(*eager_order_by)) else: - def _setup_outermost_orderby(q): - # imitate the same method that subquery eager loading uses, - # looking for the adapted "secondary" table - eagerjoin = q._from_obj[0] - - return q.order_by( - *eagerjoin._target_adapter.copy_and_process( - util.to_list(self.parent_property.order_by) - ) + def _setup_outermost_orderby(compile_context): + compile_context.eager_order_by += tuple( + util.to_list(self.parent_property.order_by) ) - q.add_criteria(_setup_outermost_orderby) + q.add_criteria( + lambda q: q._add_context_option( + _setup_outermost_orderby, self.parent_property + ) + ) if query_info.load_only_child: self._load_via_child( diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index 2fd628d0b..e0ba3050c 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -14,7 +14,7 @@ from .base import _class_to_mapper from .base import _is_aliased_class from .base import _is_mapped_class from .base import InspectionAttr -from .interfaces import MapperOption +from .interfaces import LoaderOption from .interfaces import PropComparator from .path_registry import _DEFAULT_TOKEN from .path_registry import _WILDCARD_TOKEN @@ -29,10 +29,9 @@ from ..sql import roles from ..sql import visitors from ..sql.base import _generative from ..sql.base import Generative -from ..sql.traversals import HasCacheKey -class Load(HasCacheKey, Generative, MapperOption): +class Load(Generative, LoaderOption): """Represents loader options which modify the state of a :class:`_query.Query` in order to affect how various mapped attributes are loaded. @@ -196,21 +195,23 @@ class Load(HasCacheKey, Generative, MapperOption): propagate_to_loaders = False _of_type = None - def process_query(self, query): - self._process(query, True) + def process_compile_state(self, compile_state): + if not compile_state.compile_options._enable_eagerloads: + return - def process_query_conditionally(self, query): - self._process(query, False) + self._process(compile_state, not bool(compile_state.current_path)) - def _process(self, query, raiseerr): - current_path = query._current_path + def _process(self, compile_state, raiseerr): + current_path = compile_state.current_path if current_path: for (token, start_path), loader in self.context.items(): chopped_start_path = self._chop_path(start_path, current_path) if chopped_start_path is not None: - query._attributes[(token, chopped_start_path)] = loader + compile_state.attributes[ + (token, chopped_start_path) + ] = loader else: - query._attributes.update(self.context) + compile_state.attributes.update(self.context) def _generate_path( self, path, attr, for_strategy, wildcard_key, raiseerr=True @@ -423,7 +424,6 @@ class Load(HasCacheKey, Generative, MapperOption): @_generative def set_column_strategy(self, attrs, strategy, opts=None, opts_only=False): strategy = self._coerce_strat(strategy) - self.is_class_strategy = False for attr in attrs: cloned = self._clone_for_bind_strategy( @@ -434,7 +434,6 @@ class Load(HasCacheKey, Generative, MapperOption): @_generative def set_generic_strategy(self, attrs, strategy): strategy = self._coerce_strat(strategy) - for attr in attrs: cloned = self._clone_for_bind_strategy(attr, strategy, None) cloned.propagate_to_loaders = True @@ -685,15 +684,18 @@ class _UnboundLoad(Load): state["path"] = tuple(ret) self.__dict__ = state - def _process(self, query, raiseerr): - dedupes = query._attributes["_unbound_load_dedupes"] + def _process(self, compile_state, raiseerr): + dedupes = compile_state.attributes["_unbound_load_dedupes"] for val in self._to_bind: if val not in dedupes: dedupes.add(val) val._bind_loader( - [ent.entity_zero for ent in query._mapper_entities], - query._current_path, - query._attributes, + [ + ent.entity_zero + for ent in compile_state._mapper_entities + ], + compile_state.current_path, + compile_state.attributes, raiseerr, ) @@ -767,7 +769,11 @@ class _UnboundLoad(Load): ret.append((token._parentmapper.class_, token.key, None)) else: ret.append( - (token._parentmapper.class_, token.key, token._of_type) + ( + token._parentmapper.class_, + token.key, + token._of_type.entity if token._of_type else None, + ) ) elif isinstance(token, PropComparator): ret.append((token._parentmapper.class_, token.key, None)) diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index b78f8824e..1e415e49c 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -24,6 +24,9 @@ from .base import state_attribute_str # noqa from .base import state_class_str # noqa from .base import state_str # noqa from .interfaces import MapperProperty # noqa +from .interfaces import ORMColumnsClauseRole +from .interfaces import ORMEntityColumnsClauseRole +from .interfaces import ORMFromClauseRole from .interfaces import PropComparator # noqa from .path_registry import PathRegistry # noqa from .. import event @@ -31,12 +34,14 @@ from .. import exc as sa_exc from .. import inspection from .. import sql from .. import util +from ..engine.result import result_tuple from ..sql import base as sql_base from ..sql import coercions from ..sql import expression from ..sql import roles from ..sql import util as sql_util from ..sql import visitors +from ..sql.base import ColumnCollection all_cascades = frozenset( @@ -497,6 +502,13 @@ class AliasedClass(object): self.__name__ = "AliasedClass_%s" % mapper.class_.__name__ + @classmethod + def _reconstitute_from_aliased_insp(cls, aliased_insp): + obj = cls.__new__(cls) + obj.__name__ = "AliasedClass_%s" % aliased_insp.mapper.class_.__name__ + obj._aliased_insp = aliased_insp + return obj + def __getattr__(self, key): try: _aliased_insp = self.__dict__["_aliased_insp"] @@ -526,6 +538,27 @@ class AliasedClass(object): return attr + def _get_from_serialized(self, key, mapped_class, aliased_insp): + # this method is only used in terms of the + # sqlalchemy.ext.serializer extension + attr = getattr(mapped_class, key) + if hasattr(attr, "__call__") and hasattr(attr, "__self__"): + return types.MethodType(attr.__func__, self) + + # attribute is a descriptor, that will be invoked against a + # "self"; so invoke the descriptor against this self + if hasattr(attr, "__get__"): + attr = attr.__get__(None, self) + + # attributes within the QueryableAttribute system will want this + # to be invoked so the object can be adapted + if hasattr(attr, "adapt_to_entity"): + aliased_insp._weak_entity = weakref.ref(self) + attr = attr.adapt_to_entity(aliased_insp) + setattr(self, key, attr) + + return attr + def __repr__(self): return "<AliasedClass at 0x%x; %s>" % ( id(self), @@ -536,7 +569,12 @@ class AliasedClass(object): return str(self._aliased_insp) -class AliasedInsp(sql_base.HasCacheKey, InspectionAttr): +class AliasedInsp( + ORMEntityColumnsClauseRole, + ORMFromClauseRole, + sql_base.MemoizedHasCacheKey, + InspectionAttr, +): """Provide an inspection interface for an :class:`.AliasedClass` object. @@ -632,13 +670,35 @@ class AliasedInsp(sql_base.HasCacheKey, InspectionAttr): @property def entity(self): - return self._weak_entity() + # to eliminate reference cycles, the AliasedClass is held weakly. + # this produces some situations where the AliasedClass gets lost, + # particularly when one is created internally and only the AliasedInsp + # is passed around. + # to work around this case, we just generate a new one when we need + # it, as it is a simple class with very little initial state on it. + ent = self._weak_entity() + if ent is None: + ent = AliasedClass._reconstitute_from_aliased_insp(self) + self._weak_entity = weakref.ref(ent) + return ent is_aliased_class = True "always returns True" + @util.memoized_instancemethod def __clause_element__(self): - return self.selectable + return self.selectable._annotate( + { + "parentmapper": self.mapper, + "parententity": self, + "entity_namespace": self, + "compile_state_plugin": "orm", + } + ) + + @property + def entity_namespace(self): + return self.entity _cache_key_traversal = [ ("name", visitors.ExtendedInternalTraversal.dp_string), @@ -976,6 +1036,150 @@ def with_polymorphic( ) +@inspection._self_inspects +class Bundle(ORMColumnsClauseRole, InspectionAttr): + """A grouping of SQL expressions that are returned by a :class:`.Query` + under one namespace. + + The :class:`.Bundle` essentially allows nesting of the tuple-based + results returned by a column-oriented :class:`_query.Query` object. + It also + is extensible via simple subclassing, where the primary capability + to override is that of how the set of expressions should be returned, + allowing post-processing as well as custom return types, without + involving ORM identity-mapped classes. + + .. versionadded:: 0.9.0 + + .. seealso:: + + :ref:`bundles` + + + """ + + single_entity = False + """If True, queries for a single Bundle will be returned as a single + entity, rather than an element within a keyed tuple.""" + + is_clause_element = False + + is_mapper = False + + is_aliased_class = False + + is_bundle = True + + def __init__(self, name, *exprs, **kw): + r"""Construct a new :class:`.Bundle`. + + e.g.:: + + bn = Bundle("mybundle", MyClass.x, MyClass.y) + + for row in session.query(bn).filter( + bn.c.x == 5).filter(bn.c.y == 4): + print(row.mybundle.x, row.mybundle.y) + + :param name: name of the bundle. + :param \*exprs: columns or SQL expressions comprising the bundle. + :param single_entity=False: if True, rows for this :class:`.Bundle` + can be returned as a "single entity" outside of any enclosing tuple + in the same manner as a mapped entity. + + """ + self.name = self._label = name + self.exprs = exprs = [ + coercions.expect(roles.ColumnsClauseRole, expr) for expr in exprs + ] + + self.c = self.columns = ColumnCollection( + (getattr(col, "key", col._label), col) + for col in [e._annotations.get("bundle", e) for e in exprs] + ) + self.single_entity = kw.pop("single_entity", self.single_entity) + + @property + def mapper(self): + return self.exprs[0]._annotations.get("parentmapper", None) + + @property + def entity(self): + return self.exprs[0]._annotations.get("parententity", None) + + @property + def entity_namespace(self): + return self.c + + columns = None + """A namespace of SQL expressions referred to by this :class:`.Bundle`. + + e.g.:: + + bn = Bundle("mybundle", MyClass.x, MyClass.y) + + q = sess.query(bn).filter(bn.c.x == 5) + + Nesting of bundles is also supported:: + + b1 = Bundle("b1", + Bundle('b2', MyClass.a, MyClass.b), + Bundle('b3', MyClass.x, MyClass.y) + ) + + q = sess.query(b1).filter( + b1.c.b2.c.a == 5).filter(b1.c.b3.c.y == 9) + + .. seealso:: + + :attr:`.Bundle.c` + + """ + + c = None + """An alias for :attr:`.Bundle.columns`.""" + + def _clone(self): + cloned = self.__class__.__new__(self.__class__) + cloned.__dict__.update(self.__dict__) + return cloned + + def __clause_element__(self): + return expression.ClauseList( + _literal_as_text_role=roles.ColumnsClauseRole, + group=False, + *[e._annotations.get("bundle", e) for e in self.exprs] + )._annotate({"bundle": self, "entity_namespace": self}) + + @property + def clauses(self): + return self.__clause_element__().clauses + + def label(self, name): + """Provide a copy of this :class:`.Bundle` passing a new label.""" + + cloned = self._clone() + cloned.name = name + return cloned + + def create_row_processor(self, query, procs, labels): + """Produce the "row processing" function for this :class:`.Bundle`. + + May be overridden by subclasses. + + .. seealso:: + + :ref:`bundles` - includes an example of subclassing. + + """ + keyed_tuple = result_tuple(labels, [() for l in labels]) + + def proc(row): + return keyed_tuple([proc(row) for proc in procs]) + + return proc + + def _orm_annotate(element, exclude=None): """Deep copy the given ClauseElement, annotating each element with the "_orm_adapt" flag. @@ -1020,33 +1224,39 @@ class _ORMJoin(expression.Join): _right_memo=None, ): left_info = inspection.inspect(left) - left_orm_info = getattr(left, "_joined_from_info", left_info) right_info = inspection.inspect(right) adapt_to = right_info.selectable - self._joined_from_info = right_info - + # used by joined eager loader self._left_memo = _left_memo self._right_memo = _right_memo + # legacy, for string attr name ON clause. if that's removed + # then the "_joined_from_info" concept can go + left_orm_info = getattr(left, "_joined_from_info", left_info) + self._joined_from_info = right_info if isinstance(onclause, util.string_types): onclause = getattr(left_orm_info.entity, onclause) + # #### if isinstance(onclause, attributes.QueryableAttribute): on_selectable = onclause.comparator._source_selectable() prop = onclause.property elif isinstance(onclause, MapperProperty): + # used internally by joined eager loader...possibly not ideal prop = onclause on_selectable = prop.parent.selectable else: prop = None if prop: - if sql_util.clause_is_present(on_selectable, left_info.selectable): + left_selectable = left_info.selectable + + if sql_util.clause_is_present(on_selectable, left_selectable): adapt_from = on_selectable else: - adapt_from = left_info.selectable + adapt_from = left_selectable ( pj, diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index 891b8ae09..71d05f38f 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -31,7 +31,10 @@ class SupportsAnnotations(object): if isinstance(value, HasCacheKey) else value, ) - for key, value in self._annotations.items() + for key, value in [ + (key, self._annotations[key]) + for key in sorted(self._annotations) + ] ), ) @@ -51,6 +54,7 @@ class SupportsCloneAnnotations(SupportsAnnotations): new = self._clone() new._annotations = new._annotations.union(values) new.__dict__.pop("_annotations_cache_key", None) + new.__dict__.pop("_generate_cache_key", None) return new def _with_annotations(self, values): @@ -61,6 +65,7 @@ class SupportsCloneAnnotations(SupportsAnnotations): new = self._clone() new._annotations = util.immutabledict(values) new.__dict__.pop("_annotations_cache_key", None) + new.__dict__.pop("_generate_cache_key", None) return new def _deannotate(self, values=None, clone=False): @@ -76,7 +81,7 @@ class SupportsCloneAnnotations(SupportsAnnotations): # clone is used when we are also copying # the expression for a deep deannotation new = self._clone() - new._annotations = {} + new._annotations = util.immutabledict() new.__dict__.pop("_annotations_cache_key", None) return new else: @@ -156,6 +161,7 @@ class Annotated(object): def __init__(self, element, values): self.__dict__ = element.__dict__.copy() self.__dict__.pop("_annotations_cache_key", None) + self.__dict__.pop("_generate_cache_key", None) self.__element = element self._annotations = values self._hash = hash(element) @@ -169,6 +175,7 @@ class Annotated(object): clone = self.__class__.__new__(self.__class__) clone.__dict__ = self.__dict__.copy() clone.__dict__.pop("_annotations_cache_key", None) + clone.__dict__.pop("_generate_cache_key", None) clone._annotations = values return clone @@ -211,6 +218,13 @@ class Annotated(object): else: return hash(other) == hash(self) + @property + def entity_namespace(self): + if "entity_namespace" in self._annotations: + return self._annotations["entity_namespace"].entity_namespace + else: + return self.__element.entity_namespace + # hard-generate Annotated subclasses. this technique # is used instead of on-the-fly types (i.e. type.__new__()) diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 2d023c6a6..04cc34480 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -15,11 +15,14 @@ import operator import re from .traversals import HasCacheKey # noqa +from .traversals import MemoizedHasCacheKey # noqa from .visitors import ClauseVisitor +from .visitors import ExtendedInternalTraversal from .visitors import InternalTraversal from .. import exc from .. import util from ..util import HasMemoized +from ..util import hybridmethod if util.TYPE_CHECKING: from types import ModuleType @@ -433,22 +436,52 @@ class CompileState(object): __slots__ = ("statement",) + plugins = {} + @classmethod def _create(cls, statement, compiler, **kw): # factory construction. - # specific CompileState classes here will look for - # "plugins" in the given statement. From there they will invoke - # the appropriate plugin constructor if one is found and return - # the alternate CompileState object. + if statement._compile_state_plugin is not None: + constructor = cls.plugins.get( + ( + statement._compile_state_plugin, + statement.__visit_name__, + None, + ), + cls, + ) + else: + constructor = cls - c = cls.__new__(cls) - c.__init__(statement, compiler, **kw) - return c + return constructor(statement, compiler, **kw) def __init__(self, statement, compiler, **kw): self.statement = statement + @classmethod + def get_plugin_classmethod(cls, statement, name): + if statement._compile_state_plugin is not None: + fn = cls.plugins.get( + ( + statement._compile_state_plugin, + statement.__visit_name__, + name, + ), + None, + ) + if fn is not None: + return fn + return getattr(cls, name) + + @classmethod + def plugin_for(cls, plugin_name, visit_name, method_name=None): + def decorate(fn): + cls.plugins[(plugin_name, visit_name, method_name)] = fn + return fn + + return decorate + class Generative(HasMemoized): """Provide a method-chaining pattern in conjunction with the @@ -479,6 +512,57 @@ class HasCompileState(Generative): _compile_state_plugin = None + _attributes = util.immutabledict() + + +class _MetaOptions(type): + """metaclass for the Options class.""" + + def __init__(cls, classname, bases, dict_): + cls._cache_attrs = tuple( + sorted(d for d in dict_ if not d.startswith("__")) + ) + type.__init__(cls, classname, bases, dict_) + + def __add__(self, other): + o1 = self() + o1.__dict__.update(other) + return o1 + + +class Options(util.with_metaclass(_MetaOptions)): + """A cacheable option dictionary with defaults. + + + """ + + def __init__(self, **kw): + self.__dict__.update(kw) + + def __add__(self, other): + o1 = self.__class__.__new__(self.__class__) + o1.__dict__.update(self.__dict__) + o1.__dict__.update(other) + return o1 + + @hybridmethod + def add_to_element(self, name, value): + return self + {name: getattr(self, name) + value} + + +class CacheableOptions(Options, HasCacheKey): + @hybridmethod + def _gen_cache_key(self, anon_map, bindparams): + return HasCacheKey._gen_cache_key(self, anon_map, bindparams) + + @_gen_cache_key.classlevel + def _gen_cache_key(cls, anon_map, bindparams): + return (cls, ()) + + @hybridmethod + def _generate_cache_key(self): + return HasCacheKey._generate_cache_key_for_object(self) + class Executable(Generative): """Mark a ClauseElement as supporting execution. @@ -492,7 +576,21 @@ class Executable(Generative): supports_execution = True _execution_options = util.immutabledict() _bind = None + _with_options = () + _with_context_options = () + _cache_enable = True + + _executable_traverse_internals = [ + ("_with_options", ExtendedInternalTraversal.dp_has_cache_key_list), + ("_with_context_options", ExtendedInternalTraversal.dp_plain_obj), + ("_cache_enable", ExtendedInternalTraversal.dp_plain_obj), + ] + + @_generative + def _disable_caching(self): + self._cache_enable = HasCacheKey() + @_generative def options(self, *options): """Apply options to this statement. @@ -522,7 +620,21 @@ class Executable(Generative): to the usage of ORM queries """ - self._options += options + self._with_options += options + + @_generative + def _add_context_option(self, callable_, cache_args): + """Add a context option to this statement. + + These are callable functions that will + be given the CompileState object upon compilation. + + A second argument cache_args is required, which will be combined + with the identity of the function itself in order to produce a + cache key. + + """ + self._with_context_options += ((callable_, cache_args),) @_generative def execution_options(self, **kw): diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index c7b85c415..2fc63b82f 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -57,7 +57,7 @@ def expect(role, element, **kw): if not isinstance( element, - (elements.ClauseElement, schema.SchemaItem, schema.FetchedValue), + (elements.ClauseElement, schema.SchemaItem, schema.FetchedValue,), ): resolved = impl._resolve_for_clause_element(element, **kw) else: @@ -106,7 +106,9 @@ class RoleImpl(object): self.name = role_class._role_name self._use_inspection = issubclass(role_class, roles.UsesInspection) - def _resolve_for_clause_element(self, element, argname=None, **kw): + def _resolve_for_clause_element( + self, element, argname=None, apply_plugins=None, **kw + ): original_element = element is_clause_element = False @@ -115,18 +117,39 @@ class RoleImpl(object): if not getattr(element, "is_clause_element", False): element = element.__clause_element__() else: - return element + break + + should_apply_plugins = ( + apply_plugins is not None + and apply_plugins._compile_state_plugin is None + ) if is_clause_element: + if ( + should_apply_plugins + and "compile_state_plugin" in element._annotations + ): + apply_plugins._compile_state_plugin = element._annotations[ + "compile_state_plugin" + ] return element if self._use_inspection: insp = inspection.inspect(element, raiseerr=False) if insp is not None: + insp._post_inspect try: - return insp.__clause_element__() + element = insp.__clause_element__() except AttributeError: self._raise_for_expected(original_element, argname) + else: + if ( + should_apply_plugins + and "compile_state_plugin" in element._annotations + ): + plugin = element._annotations["compile_state_plugin"] + apply_plugins._compile_state_plugin = plugin + return element return self._literal_coercion(element, argname=argname, **kw) @@ -287,7 +310,7 @@ class _SelectIsNotFrom(object): advice = ( "To create a " "FROM clause from a %s object, use the .subquery() method." - % (element.__class__) + % (element.__class__,) ) code = "89ve" else: @@ -453,6 +476,18 @@ class OrderByImpl(ByOfImpl, RoleImpl): return resolved +class GroupByImpl(ByOfImpl, RoleImpl): + __slots__ = () + + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if isinstance(resolved, roles.StrictFromClauseRole): + return elements.ClauseList(*resolved.c) + else: + return resolved + + class DMLColumnImpl(_ReturnsStringKey, RoleImpl): __slots__ = () @@ -618,6 +653,37 @@ class HasCTEImpl(ReturnsRowsImpl, roles.HasCTERole): pass +class JoinTargetImpl(RoleImpl): + __slots__ = () + + def _literal_coercion(self, element, legacy=False, **kw): + if isinstance(element, str): + return element + + def _implicit_coercions( + self, original_element, resolved, argname=None, legacy=False, **kw + ): + if isinstance(original_element, roles.JoinTargetRole): + return original_element + elif legacy and isinstance(resolved, (str, roles.WhereHavingRole)): + return resolved + elif legacy and resolved._is_select_statement: + util.warn_deprecated( + "Implicit coercion of SELECT and textual SELECT " + "constructs into FROM clauses is deprecated; please call " + ".subquery() on any Core select or ORM Query object in " + "order to produce a subquery object.", + version="1.4", + ) + # TODO: doing _implicit_subquery here causes tests to fail, + # how was this working before? probably that ORM + # join logic treated it as a select and subquery would happen + # in _ORMJoin->Join + return resolved + else: + self._raise_for_expected(original_element, argname, resolved) + + class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl): __slots__ = () @@ -647,6 +713,12 @@ class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl): else: self._raise_for_expected(original_element, argname, resolved) + def _post_coercion(self, element, deannotate=False, **kw): + if deannotate: + return element._deannotate() + else: + return element + class StrictFromClauseImpl(FromClauseImpl): __slots__ = () diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index ccc1b53fe..9a7646743 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -653,6 +653,12 @@ class SQLCompiler(Compiled): """ + compile_state_factories = util.immutabledict() + """Dictionary of alternate :class:`.CompileState` factories for given + classes, identified by their visit_name. + + """ + def __init__( self, dialect, @@ -661,6 +667,7 @@ class SQLCompiler(Compiled): column_keys=None, inline=False, linting=NO_LINTING, + compile_state_factories=None, **kwargs ): """Construct a new :class:`.SQLCompiler` object. @@ -727,6 +734,9 @@ class SQLCompiler(Compiled): # dialect.label_length or dialect.max_identifier_length self.truncated_names = {} + if compile_state_factories: + self.compile_state_factories = compile_state_factories + Compiled.__init__(self, dialect, statement, **kwargs) if ( diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 7310edd3f..c1bc9edbc 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -26,7 +26,6 @@ from .annotation import SupportsWrappingAnnotations from .base import _clone from .base import _generative from .base import Executable -from .base import HasCacheKey from .base import HasMemoized from .base import Immutable from .base import NO_ARG @@ -35,6 +34,7 @@ from .base import SingletonConstant from .coercions import _document_text_coercion from .traversals import _copy_internals from .traversals import _get_children +from .traversals import MemoizedHasCacheKey from .traversals import NO_CACHE from .visitors import cloned_traverse from .visitors import InternalTraversal @@ -179,7 +179,10 @@ def not_(clause): @inspection._self_inspects class ClauseElement( - roles.SQLRole, SupportsWrappingAnnotations, HasCacheKey, Traversible, + roles.SQLRole, + SupportsWrappingAnnotations, + MemoizedHasCacheKey, + Traversible, ): """Base class for elements of a programmatically constructed SQL expression. @@ -206,6 +209,7 @@ class ClauseElement( _is_select_container = False _is_select_statement = False _is_bind_parameter = False + _is_clause_list = False _order_by_label_element = None @@ -300,7 +304,7 @@ class ClauseElement( used. """ - return self._params(True, optionaldict, kwargs) + return self._replace_params(True, optionaldict, kwargs) def params(self, *optionaldict, **kwargs): """Return a copy with :func:`bindparam()` elements replaced. @@ -315,9 +319,9 @@ class ClauseElement( {'foo':7} """ - return self._params(False, optionaldict, kwargs) + return self._replace_params(False, optionaldict, kwargs) - def _params(self, unique, optionaldict, kwargs): + def _replace_params(self, unique, optionaldict, kwargs): if len(optionaldict) == 1: kwargs.update(optionaldict[0]) elif len(optionaldict) > 1: @@ -371,7 +375,7 @@ class ClauseElement( continue if obj is not None: - result = meth(self, obj, **kw) + result = meth(self, attrname, obj, **kw) if result is not None: setattr(self, attrname, result) @@ -2070,6 +2074,8 @@ class ClauseList( __visit_name__ = "clauselist" + _is_clause_list = True + _traverse_internals = [ ("clauses", InternalTraversal.dp_clauseelement_list), ("operator", InternalTraversal.dp_operator), @@ -2079,6 +2085,8 @@ class ClauseList( self.operator = kwargs.pop("operator", operators.comma_op) self.group = kwargs.pop("group", True) self.group_contents = kwargs.pop("group_contents", True) + if kwargs.pop("_flatten_sub_clauses", False): + clauses = util.flatten_iterator(clauses) self._tuple_values = kwargs.pop("_tuple_values", False) self._text_converter_role = text_converter_role = kwargs.pop( "_literal_as_text_role", roles.WhereHavingRole @@ -2116,7 +2124,9 @@ class ClauseList( @property def _select_iterable(self): - return iter(self) + return itertools.chain.from_iterable( + [elem._select_iterable for elem in self.clauses] + ) def append(self, clause): if self.group_contents: @@ -2224,6 +2234,32 @@ class BooleanClauseList(ClauseList, ColumnElement): return cls._construct_raw(operator) @classmethod + def _construct_for_whereclause(cls, clauses): + operator, continue_on, skip_on = ( + operators.and_, + True_._singleton, + False_._singleton, + ) + + lcc, convert_clauses = cls._process_clauses_for_boolean( + operator, + continue_on, + skip_on, + clauses, # these are assumed to be coerced already + ) + + if lcc > 1: + # multiple elements. Return regular BooleanClauseList + # which will link elements against the operator. + return cls._construct_raw(operator, convert_clauses) + elif lcc == 1: + # just one element. return it as a single boolean element, + # not a list and discard the operator. + return convert_clauses[0] + else: + return None + + @classmethod def _construct_raw(cls, operator, clauses=None): self = cls.__new__(cls) self.clauses = clauses if clauses else [] diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index 72a0bdc95..b861f721b 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -19,7 +19,7 @@ class SQLRole(object): class UsesInspection(object): - pass + _post_inspect = None class ColumnArgumentRole(SQLRole): @@ -54,6 +54,14 @@ class ByOfRole(ColumnListRole): _role_name = "GROUP BY / OF / etc. expression" +class GroupByRole(UsesInspection, ByOfRole): + # note there's a special case right now where you can pass a whole + # ORM entity to group_by() and it splits out. we may not want to keep + # this around + + _role_name = "GROUP BY expression" + + class OrderByRole(ByOfRole): _role_name = "ORDER BY expression" @@ -92,7 +100,14 @@ class InElementRole(SQLRole): ) -class FromClauseRole(ColumnsClauseRole): +class JoinTargetRole(UsesInspection, StructuralRole): + _role_name = ( + "Join target, typically a FROM expression, or ORM " + "relationship attribute" + ) + + +class FromClauseRole(ColumnsClauseRole, JoinTargetRole): _role_name = "FROM expression, such as a Table or alias() object" _is_subquery = False diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 689eda11d..65f8bd81c 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -477,7 +477,10 @@ class Table(DialectKWArgs, SchemaItem, TableClause): ] def _gen_cache_key(self, anon_map, bindparams): - return (self,) + self._annotations_cache_key + if self._annotations: + return (self,) + self._annotations_cache_key + else: + return (self,) def __new__(cls, *args, **kw): if not args: diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 85abbb5e0..6a552c18c 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -28,6 +28,7 @@ from .base import _expand_cloned from .base import _from_objects from .base import _generative from .base import _select_iterables +from .base import CacheableOptions from .base import ColumnCollection from .base import ColumnSet from .base import CompileState @@ -42,6 +43,7 @@ from .coercions import _document_text_coercion from .elements import _anonymous_label from .elements import and_ from .elements import BindParameter +from .elements import BooleanClauseList from .elements import ClauseElement from .elements import ClauseList from .elements import ColumnClause @@ -339,6 +341,90 @@ class HasSuffixes(object): ) +class HasHints(object): + _hints = util.immutabledict() + _statement_hints = () + + _has_hints_traverse_internals = [ + ("_statement_hints", InternalTraversal.dp_statement_hint_list), + ("_hints", InternalTraversal.dp_table_hint_list), + ] + + def with_statement_hint(self, text, dialect_name="*"): + """add a statement hint to this :class:`_expression.Select` or + other selectable object. + + This method is similar to :meth:`_expression.Select.with_hint` + except that + it does not require an individual table, and instead applies to the + statement as a whole. + + Hints here are specific to the backend database and may include + directives such as isolation levels, file directives, fetch directives, + etc. + + .. versionadded:: 1.0.0 + + .. seealso:: + + :meth:`_expression.Select.with_hint` + + :meth:.`.Select.prefix_with` - generic SELECT prefixing which also + can suit some database-specific HINT syntaxes such as MySQL + optimizer hints + + """ + return self.with_hint(None, text, dialect_name) + + @_generative + def with_hint(self, selectable, text, dialect_name="*"): + r"""Add an indexing or other executional context hint for the given + selectable to this :class:`_expression.Select` or other selectable + object. + + The text of the hint is rendered in the appropriate + location for the database backend in use, relative + to the given :class:`_schema.Table` or :class:`_expression.Alias` + passed as the + ``selectable`` argument. The dialect implementation + typically uses Python string substitution syntax + with the token ``%(name)s`` to render the name of + the table or alias. E.g. when using Oracle, the + following:: + + select([mytable]).\ + with_hint(mytable, "index(%(name)s ix_mytable)") + + Would render SQL as:: + + select /*+ index(mytable ix_mytable) */ ... from mytable + + The ``dialect_name`` option will limit the rendering of a particular + hint to a particular backend. Such as, to add hints for both Oracle + and Sybase simultaneously:: + + select([mytable]).\ + with_hint(mytable, "index(%(name)s ix_mytable)", 'oracle').\ + with_hint(mytable, "WITH INDEX ix_mytable", 'sybase') + + .. seealso:: + + :meth:`_expression.Select.with_statement_hint` + + """ + if selectable is None: + self._statement_hints += ((dialect_name, text),) + else: + self._hints = self._hints.union( + { + ( + coercions.expect(roles.FromClauseRole, selectable), + dialect_name, + ): text + } + ) + + class FromClause(roles.AnonymizedFromClauseRole, Selectable): """Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement. @@ -597,6 +683,22 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): self._populate_column_collection() return self._columns.as_immutable() + @property + def entity_namespace(self): + """Return a namespace used for name-based access in SQL expressions. + + This is the namespace that is used to resolve "filter_by()" type + expressions, such as:: + + stmt.filter_by(address='some address') + + It defaults to the .c collection, however internally it can + be overridden using the "entity_namespace" annotation to deliver + alternative results. + + """ + return self.columns + @util.memoized_property def primary_key(self): """Return the collection of Column objects which comprise the @@ -727,13 +829,21 @@ class Join(FromClause): :class:`_expression.FromClause` object. """ - self.left = coercions.expect(roles.FromClauseRole, left) - self.right = coercions.expect(roles.FromClauseRole, right).self_group() + self.left = coercions.expect( + roles.FromClauseRole, left, deannotate=True + ) + self.right = coercions.expect( + roles.FromClauseRole, right, deannotate=True + ).self_group() if onclause is None: self.onclause = self._match_primaries(self.left, self.right) else: - self.onclause = onclause.self_group(against=operators._asbool) + # note: taken from If91f61527236fd4d7ae3cad1f24c38be921c90ba + # not merged yet + self.onclause = coercions.expect( + roles.WhereHavingRole, onclause + ).self_group(against=operators._asbool) self.isouter = isouter self.full = full @@ -1963,6 +2073,12 @@ class TableClause(Immutable, FromClause): if kw: raise exc.ArgumentError("Unsupported argument(s): %s" % list(kw)) + def __str__(self): + if self.schema is not None: + return self.schema + "." + self.name + else: + return self.name + def _refresh_for_new_column(self, column): pass @@ -2905,7 +3021,8 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase): self._group_by_clauses = () else: self._group_by_clauses += tuple( - coercions.expect(roles.ByOfRole, clause) for clause in clauses + coercions.expect(roles.GroupByRole, clause) + for clause in clauses ) @@ -3309,8 +3426,16 @@ class DeprecatedSelectGenerations(object): class SelectState(CompileState): + class default_select_compile_options(CacheableOptions): + _cache_key_traversal = [] + def __init__(self, statement, compiler, **kw): self.statement = statement + self.from_clauses = statement._from_obj + + if statement._setup_joins: + self._setup_joins(statement._setup_joins) + self.froms = self._get_froms(statement) self.columns_plus_names = statement._generate_columns_plus_names(True) @@ -3319,7 +3444,18 @@ class SelectState(CompileState): froms = [] seen = set() - for item in statement._iterate_from_elements(): + for item in itertools.chain( + itertools.chain.from_iterable( + [element._from_objects for element in statement._raw_columns] + ), + itertools.chain.from_iterable( + [ + element._from_objects + for element in statement._where_criteria + ] + ), + self.from_clauses, + ): if item._is_subquery and item.element is statement: raise exc.InvalidRequestError( "select() construct refers to itself as a FROM" @@ -3341,6 +3477,7 @@ class SelectState(CompileState): correlating. """ + froms = self.froms toremove = set( @@ -3425,10 +3562,162 @@ class SelectState(CompileState): return with_cols, only_froms, only_cols + @classmethod + def determine_last_joined_entity(cls, stmt): + if stmt._setup_joins: + return stmt._setup_joins[-1][0] + else: + return None + + def _setup_joins(self, args): + for (right, onclause, left, flags) in args: + isouter = flags["isouter"] + full = flags["full"] + + if left is None: + ( + left, + replace_from_obj_index, + ) = self._join_determine_implicit_left_side( + left, right, onclause + ) + else: + (replace_from_obj_index) = self._join_place_explicit_left_side( + left + ) + + if replace_from_obj_index is not None: + # splice into an existing element in the + # self._from_obj list + left_clause = self.from_clauses[replace_from_obj_index] + + self.from_clauses = ( + self.from_clauses[:replace_from_obj_index] + + ( + Join( + left_clause, + right, + onclause, + isouter=isouter, + full=full, + ), + ) + + self.from_clauses[replace_from_obj_index + 1 :] + ) + else: + + self.from_clauses = self.from_clauses + ( + Join(left, right, onclause, isouter=isouter, full=full,), + ) + + @util.preload_module("sqlalchemy.sql.util") + def _join_determine_implicit_left_side(self, left, right, onclause): + """When join conditions don't express the left side explicitly, + determine if an existing FROM or entity in this query + can serve as the left hand side. + + """ + + sql_util = util.preloaded.sql_util + + replace_from_obj_index = None + + from_clauses = self.statement._from_obj + + if from_clauses: + + indexes = sql_util.find_left_clause_to_join_from( + from_clauses, right, onclause + ) + + if len(indexes) == 1: + replace_from_obj_index = indexes[0] + left = from_clauses[replace_from_obj_index] + else: + potential = {} + statement = self.statement + + for from_clause in itertools.chain( + itertools.chain.from_iterable( + [ + element._from_objects + for element in statement._raw_columns + ] + ), + itertools.chain.from_iterable( + [ + element._from_objects + for element in statement._where_criteria + ] + ), + ): + + potential[from_clause] = () + + all_clauses = list(potential.keys()) + indexes = sql_util.find_left_clause_to_join_from( + all_clauses, right, onclause + ) + + if len(indexes) == 1: + left = all_clauses[indexes[0]] + + if len(indexes) > 1: + raise exc.InvalidRequestError( + "Can't determine which FROM clause to join " + "from, there are multiple FROMS which can " + "join to this entity. Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explcit ON clause if not present already to " + "help resolve the ambiguity." + ) + elif not indexes: + raise exc.InvalidRequestError( + "Don't know how to join to %r. " + "Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explcit ON clause if not present already to " + "help resolve the ambiguity." % (right,) + ) + return left, replace_from_obj_index + + @util.preload_module("sqlalchemy.sql.util") + def _join_place_explicit_left_side(self, left): + replace_from_obj_index = None + + sql_util = util.preloaded.sql_util + + from_clauses = list(self.statement._iterate_from_elements()) + + if from_clauses: + indexes = sql_util.find_left_clause_that_matches_given( + self.from_clauses, left + ) + else: + indexes = [] + + if len(indexes) > 1: + raise exc.InvalidRequestError( + "Can't identify which entity in which to assign the " + "left side of this join. Please use a more specific " + "ON clause." + ) + + # have an index, means the left side is already present in + # an existing FROM in the self._from_obj tuple + if indexes: + replace_from_obj_index = indexes[0] + + # no index, means we need to add a new element to the + # self._from_obj tuple + + return replace_from_obj_index + class Select( HasPrefixes, HasSuffixes, + HasHints, HasCompileState, DeprecatedSelectGenerations, GenerativeSelect, @@ -3440,9 +3729,10 @@ class Select( __visit_name__ = "select" _compile_state_factory = SelectState._create + _is_future = False + _setup_joins = () + _legacy_setup_joins = () - _hints = util.immutabledict() - _statement_hints = () _distinct = False _distinct_on = () _correlate = () @@ -3452,6 +3742,8 @@ class Select( _from_obj = () _auto_correlate = True + compile_options = SelectState.default_select_compile_options + _traverse_internals = ( [ ("_raw_columns", InternalTraversal.dp_clauseelement_list), @@ -3460,58 +3752,33 @@ class Select( ("_having_criteria", InternalTraversal.dp_clauseelement_list), ("_order_by_clauses", InternalTraversal.dp_clauseelement_list,), ("_group_by_clauses", InternalTraversal.dp_clauseelement_list,), + ("_setup_joins", InternalTraversal.dp_setup_join_tuple,), + ("_legacy_setup_joins", InternalTraversal.dp_setup_join_tuple,), ("_correlate", InternalTraversal.dp_clauseelement_unordered_set), ( "_correlate_except", InternalTraversal.dp_clauseelement_unordered_set, ), ("_for_update_arg", InternalTraversal.dp_clauseelement), - ("_statement_hints", InternalTraversal.dp_statement_hint_list), - ("_hints", InternalTraversal.dp_table_hint_list), ("_distinct", InternalTraversal.dp_boolean), ("_distinct_on", InternalTraversal.dp_clauseelement_list), ("_label_style", InternalTraversal.dp_plain_obj), ] + HasPrefixes._has_prefixes_traverse_internals + HasSuffixes._has_suffixes_traverse_internals + + HasHints._has_hints_traverse_internals + SupportsCloneAnnotations._clone_annotations_traverse_internals + + Executable._executable_traverse_internals ) + _cache_key_traversal = _traverse_internals + [ + ("compile_options", InternalTraversal.dp_has_cache_key) + ] + @classmethod def _create_select(cls, *entities): - r"""Construct a new :class:`_expression.Select` using the 2. - x style API. - - .. versionadded:: 2.0 - the :func:`_future.select` construct is - the same construct as the one returned by - :func:`_expression.select`, except that the function only - accepts the "columns clause" entities up front; the rest of the - state of the SELECT should be built up using generative methods. - - Similar functionality is also available via the - :meth:`_expression.FromClause.select` method on any - :class:`_expression.FromClause`. - - .. seealso:: - - :ref:`coretutorial_selecting` - Core Tutorial description of - :func:`_expression.select`. - - :param \*entities: - Entities to SELECT from. For Core usage, this is typically a series - of :class:`_expression.ColumnElement` and / or - :class:`_expression.FromClause` - objects which will form the columns clause of the resulting - statement. For those objects that are instances of - :class:`_expression.FromClause` (typically :class:`_schema.Table` - or :class:`_expression.Alias` - objects), the :attr:`_expression.FromClause.c` - collection is extracted - to form a collection of :class:`_expression.ColumnElement` objects. - - This parameter will also accept :class:`_expression.TextClause` - constructs as - given, as well as ORM-mapped classes. + r"""Construct an old style :class:`_expression.Select` using the + the 2.x style constructor. """ @@ -3779,7 +4046,10 @@ class Select( if cols_present: self._raw_columns = [ - coercions.expect(roles.ColumnsClauseRole, c,) for c in columns + coercions.expect( + roles.ColumnsClauseRole, c, apply_plugins=self + ) + for c in columns ] else: self._raw_columns = [] @@ -3820,71 +4090,6 @@ class Select( return self._compile_state_factory(self, None)._get_display_froms() - def with_statement_hint(self, text, dialect_name="*"): - """add a statement hint to this :class:`_expression.Select`. - - This method is similar to :meth:`_expression.Select.with_hint` - except that - it does not require an individual table, and instead applies to the - statement as a whole. - - Hints here are specific to the backend database and may include - directives such as isolation levels, file directives, fetch directives, - etc. - - .. versionadded:: 1.0.0 - - .. seealso:: - - :meth:`_expression.Select.with_hint` - - :meth:`.Select.prefix_with` - generic SELECT prefixing which also - can suit some database-specific HINT syntaxes such as MySQL - optimizer hints - - """ - return self.with_hint(None, text, dialect_name) - - @_generative - def with_hint(self, selectable, text, dialect_name="*"): - r"""Add an indexing or other executional context hint for the given - selectable to this :class:`_expression.Select`. - - The text of the hint is rendered in the appropriate - location for the database backend in use, relative - to the given :class:`_schema.Table` or :class:`_expression.Alias` - passed as the - ``selectable`` argument. The dialect implementation - typically uses Python string substitution syntax - with the token ``%(name)s`` to render the name of - the table or alias. E.g. when using Oracle, the - following:: - - select([mytable]).\ - with_hint(mytable, "index(%(name)s ix_mytable)") - - Would render SQL as:: - - select /*+ index(mytable ix_mytable) */ ... from mytable - - The ``dialect_name`` option will limit the rendering of a particular - hint to a particular backend. Such as, to add hints for both Oracle - and Sybase simultaneously:: - - select([mytable]).\ - with_hint(mytable, "index(%(name)s ix_mytable)", 'oracle').\ - with_hint(mytable, "WITH INDEX ix_mytable", 'sybase') - - .. seealso:: - - :meth:`_expression.Select.with_statement_hint` - - """ - if selectable is None: - self._statement_hints += ((dialect_name, text),) - else: - self._hints = self._hints.union({(selectable, dialect_name): text}) - @property def inner_columns(self): """an iterator of all ColumnElement expressions which would @@ -3921,9 +4126,16 @@ class Select( _from_objects(*self._where_criteria), ) ) + + # do a clone for the froms we've gathered. what is important here + # is if any of the things we are selecting from, like tables, + # were converted into Join objects. if so, these need to be + # added to _from_obj explicitly, because otherwise they won't be + # part of the new state, as they don't associate themselves with + # their columns. new_froms = {f: clone(f, **kw) for f in all_the_froms} - # 2. copy FROM collections. + # 2. copy FROM collections, adding in joins that we've created. self._from_obj = tuple(clone(f, **kw) for f in self._from_obj) + tuple( f for f in new_froms.values() if isinstance(f, Join) ) @@ -3937,6 +4149,10 @@ class Select( kw["replace"] = replace + # copy everything else. for table-ish things like correlate, + # correlate_except, setup_joins, these clone normally. For + # column-expression oriented things like raw_columns, where_criteria, + # order by, we get this from the new froms. super(Select, self)._copy_internals( clone=clone, omit_attrs=("_from_obj",), **kw ) @@ -3975,10 +4191,18 @@ class Select( self._assert_no_memoizations() self._raw_columns = self._raw_columns + [ - coercions.expect(roles.ColumnsClauseRole, column,) + coercions.expect( + roles.ColumnsClauseRole, column, apply_plugins=self + ) for column in columns ] + def _set_entities(self, entities): + self._raw_columns = [ + coercions.expect(roles.ColumnsClauseRole, ent, apply_plugins=self) + for ent in util.to_list(entities) + ] + @util.deprecated( "1.4", "The :meth:`_expression.Select.column` method is deprecated and will " @@ -4111,6 +4335,7 @@ class Select( rc = [] for c in columns: c = coercions.expect(roles.ColumnsClauseRole, c,) + # TODO: why are we doing this here? if isinstance(c, ScalarSelect): c = c.self_group(against=operators.comma_op) rc.append(c) @@ -4121,7 +4346,9 @@ class Select( """Legacy, return the WHERE clause as a """ """:class:`_expression.BooleanClauseList`""" - return and_(*self._where_criteria) + return BooleanClauseList._construct_for_whereclause( + self._where_criteria + ) @_generative def where(self, whereclause): @@ -4202,7 +4429,9 @@ class Select( """ self._from_obj += tuple( - coercions.expect(roles.FromClauseRole, fromclause) + coercions.expect( + roles.FromClauseRole, fromclause, apply_plugins=self + ) for fromclause in froms ) diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 8c63fcba1..a308feb7c 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -29,9 +29,8 @@ def compare(obj1, obj2, **kw): return strategy.compare(obj1, obj2, **kw) -class HasCacheKey(HasMemoized): +class HasCacheKey(object): _cache_key_traversal = NO_CACHE - __slots__ = () def _gen_cache_key(self, anon_map, bindparams): @@ -141,7 +140,6 @@ class HasCacheKey(HasMemoized): return result - @HasMemoized.memoized_instancemethod def _generate_cache_key(self): """return a cache key. @@ -183,6 +181,23 @@ class HasCacheKey(HasMemoized): else: return CacheKey(key, bindparams) + @classmethod + def _generate_cache_key_for_object(cls, obj): + bindparams = [] + + _anon_map = anon_map() + key = obj._gen_cache_key(_anon_map, bindparams) + if NO_CACHE in _anon_map: + return None + else: + return CacheKey(key, bindparams) + + +class MemoizedHasCacheKey(HasCacheKey, HasMemoized): + @HasMemoized.memoized_instancemethod + def _generate_cache_key(self): + return HasCacheKey._generate_cache_key(self) + class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])): def __hash__(self): @@ -191,6 +206,40 @@ class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])): def __eq__(self, other): return self.key == other.key + def _whats_different(self, other): + + k1 = self.key + k2 = other.key + + stack = [] + pickup_index = 0 + while True: + s1, s2 = k1, k2 + for idx in stack: + s1 = s1[idx] + s2 = s2[idx] + + for idx, (e1, e2) in enumerate(util.zip_longest(s1, s2)): + if idx < pickup_index: + continue + if e1 != e2: + if isinstance(e1, tuple) and isinstance(e2, tuple): + stack.append(idx) + break + else: + yield "key%s[%d]: %s != %s" % ( + "".join("[%d]" % id_ for id_ in stack), + idx, + e1, + e2, + ) + else: + pickup_index = stack.pop(-1) + break + + def _diff(self, other): + return ", ".join(self._whats_different(other)) + def __str__(self): stack = [self.key] @@ -241,9 +290,7 @@ class _CacheKey(ExtendedInternalTraversal): visit_type = STATIC_CACHE_KEY def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams): - return self.visit_has_cache_key( - attrname, inspect(obj), parent, anon_map, bindparams - ) + return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams)) def visit_string_list(self, attrname, obj, parent, anon_map, bindparams): return tuple(obj) @@ -361,6 +408,24 @@ class _CacheKey(ExtendedInternalTraversal): ), ) + def visit_setup_join_tuple( + self, attrname, obj, parent, anon_map, bindparams + ): + # TODO: look at attrname for "legacy_join" and use different structure + return tuple( + ( + target._gen_cache_key(anon_map, bindparams), + onclause._gen_cache_key(anon_map, bindparams) + if onclause is not None + else None, + from_._gen_cache_key(anon_map, bindparams) + if from_ is not None + else None, + tuple([(key, flags[key]) for key in sorted(flags)]), + ) + for (target, onclause, from_, flags) in obj + ) + def visit_table_hint_list( self, attrname, obj, parent, anon_map, bindparams ): @@ -498,31 +563,53 @@ class _CopyInternals(InternalTraversal): """Generate a _copy_internals internal traversal dispatch for classes with a _traverse_internals collection.""" - def visit_clauseelement(self, parent, element, clone=_clone, **kw): + def visit_clauseelement( + self, attrname, parent, element, clone=_clone, **kw + ): return clone(element, **kw) - def visit_clauseelement_list(self, parent, element, clone=_clone, **kw): + def visit_clauseelement_list( + self, attrname, parent, element, clone=_clone, **kw + ): return [clone(clause, **kw) for clause in element] def visit_clauseelement_unordered_set( - self, parent, element, clone=_clone, **kw + self, attrname, parent, element, clone=_clone, **kw ): return {clone(clause, **kw) for clause in element} - def visit_clauseelement_tuples(self, parent, element, clone=_clone, **kw): + def visit_clauseelement_tuples( + self, attrname, parent, element, clone=_clone, **kw + ): return [ tuple(clone(tup_elem, **kw) for tup_elem in elem) for elem in element ] def visit_string_clauseelement_dict( - self, parent, element, clone=_clone, **kw + self, attrname, parent, element, clone=_clone, **kw ): return dict( (key, clone(value, **kw)) for key, value in element.items() ) - def visit_dml_ordered_values(self, parent, element, clone=_clone, **kw): + def visit_setup_join_tuple( + self, attrname, parent, element, clone=_clone, **kw + ): + # TODO: look at attrname for "legacy_join" and use different structure + return tuple( + ( + clone(target, **kw) if target is not None else None, + clone(onclause, **kw) if onclause is not None else None, + clone(from_, **kw) if from_ is not None else None, + flags, + ) + for (target, onclause, from_, flags) in element + ) + + def visit_dml_ordered_values( + self, attrname, parent, element, clone=_clone, **kw + ): # sequence of 2-tuples return [ ( @@ -534,7 +621,7 @@ class _CopyInternals(InternalTraversal): for key, value in element ] - def visit_dml_values(self, parent, element, clone=_clone, **kw): + def visit_dml_values(self, attrname, parent, element, clone=_clone, **kw): return { ( clone(key, **kw) if hasattr(key, "__clause_element__") else key @@ -542,7 +629,9 @@ class _CopyInternals(InternalTraversal): for key, value in element.items() } - def visit_dml_multi_values(self, parent, element, clone=_clone, **kw): + def visit_dml_multi_values( + self, attrname, parent, element, clone=_clone, **kw + ): # sequence of sequences, each sequence contains a list/dict/tuple def copy(elem): @@ -741,7 +830,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): continue comparison = dispatch( - left, left_child, right, right_child, **kw + left_attrname, left, left_child, right, right_child, **kw ) if comparison is COMPARE_FAILED: return False @@ -753,31 +842,40 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): return comparator.compare(obj1, obj2, **kw) def visit_has_cache_key( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): if left._gen_cache_key(self.anon_map[0], []) != right._gen_cache_key( self.anon_map[1], [] ): return COMPARE_FAILED + def visit_has_cache_key_list( + self, attrname, left_parent, left, right_parent, right, **kw + ): + for l, r in util.zip_longest(left, right, fillvalue=None): + if l._gen_cache_key(self.anon_map[0], []) != r._gen_cache_key( + self.anon_map[1], [] + ): + return COMPARE_FAILED + def visit_clauseelement( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): self.stack.append((left, right)) def visit_fromclause_canonical_column_collection( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): for lcol, rcol in util.zip_longest(left, right, fillvalue=None): self.stack.append((lcol, rcol)) def visit_fromclause_derived_column_collection( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): pass def visit_string_clauseelement_dict( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): for lstr, rstr in util.zip_longest( sorted(left), sorted(right), fillvalue=None @@ -787,7 +885,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): self.stack.append((left[lstr], right[rstr])) def visit_clauseelement_tuples( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): for ltup, rtup in util.zip_longest(left, right, fillvalue=None): if ltup is None or rtup is None: @@ -797,7 +895,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): self.stack.append((l, r)) def visit_clauseelement_list( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): for l, r in util.zip_longest(left, right, fillvalue=None): self.stack.append((l, r)) @@ -815,48 +913,62 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): return len(completed) == len(seq1) == len(seq2) def visit_clauseelement_unordered_set( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): return self._compare_unordered_sequences(left, right, **kw) def visit_fromclause_ordered_set( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): for l, r in util.zip_longest(left, right, fillvalue=None): self.stack.append((l, r)) - def visit_string(self, left_parent, left, right_parent, right, **kw): + def visit_string( + self, attrname, left_parent, left, right_parent, right, **kw + ): return left == right - def visit_string_list(self, left_parent, left, right_parent, right, **kw): + def visit_string_list( + self, attrname, left_parent, left, right_parent, right, **kw + ): return left == right - def visit_anon_name(self, left_parent, left, right_parent, right, **kw): + def visit_anon_name( + self, attrname, left_parent, left, right_parent, right, **kw + ): return _resolve_name_for_compare( left_parent, left, self.anon_map[0], **kw ) == _resolve_name_for_compare( right_parent, right, self.anon_map[1], **kw ) - def visit_boolean(self, left_parent, left, right_parent, right, **kw): + def visit_boolean( + self, attrname, left_parent, left, right_parent, right, **kw + ): return left == right - def visit_operator(self, left_parent, left, right_parent, right, **kw): + def visit_operator( + self, attrname, left_parent, left, right_parent, right, **kw + ): return left is right - def visit_type(self, left_parent, left, right_parent, right, **kw): + def visit_type( + self, attrname, left_parent, left, right_parent, right, **kw + ): return left._compare_type_affinity(right) - def visit_plain_dict(self, left_parent, left, right_parent, right, **kw): + def visit_plain_dict( + self, attrname, left_parent, left, right_parent, right, **kw + ): return left == right def visit_dialect_options( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): return left == right def visit_annotations_key( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): if left and right: return ( @@ -866,11 +978,13 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): else: return left == right - def visit_plain_obj(self, left_parent, left, right_parent, right, **kw): + def visit_plain_obj( + self, attrname, left_parent, left, right_parent, right, **kw + ): return left == right def visit_named_ddl_element( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): if left is None: if right is not None: @@ -879,7 +993,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): return left.name == right.name def visit_prefix_sequence( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): for (l_clause, l_str), (r_clause, r_str) in util.zip_longest( left, right, fillvalue=(None, None) @@ -889,8 +1003,22 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): else: self.stack.append((l_clause, r_clause)) + def visit_setup_join_tuple( + self, attrname, left_parent, left, right_parent, right, **kw + ): + # TODO: look at attrname for "legacy_join" and use different structure + for ( + (l_target, l_onclause, l_from, l_flags), + (r_target, r_onclause, r_from, r_flags), + ) in util.zip_longest(left, right, fillvalue=(None, None, None, None)): + if l_flags != r_flags: + return COMPARE_FAILED + self.stack.append((l_target, r_target)) + self.stack.append((l_onclause, r_onclause)) + self.stack.append((l_from, r_from)) + def visit_table_hint_list( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): left_keys = sorted(left, key=lambda elem: (elem[0].fullname, elem[1])) right_keys = sorted( @@ -907,17 +1035,17 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): self.stack.append((ltable, rtable)) def visit_statement_hint_list( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): return left == right def visit_unknown_structure( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): raise NotImplementedError() def visit_dml_ordered_values( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): # sequence of tuple pairs @@ -941,7 +1069,9 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): return True - def visit_dml_values(self, left_parent, left, right_parent, right, **kw): + def visit_dml_values( + self, attrname, left_parent, left, right_parent, right, **kw + ): if left is None or right is None or len(left) != len(right): return COMPARE_FAILED @@ -961,7 +1091,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): return COMPARE_FAILED def visit_dml_multi_values( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): for lseq, rseq in util.zip_longest(left, right, fillvalue=None): if lseq is None or rseq is None: @@ -970,7 +1100,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): for ld, rd in util.zip_longest(lseq, rseq, fillvalue=None): if ( self.visit_dml_values( - left_parent, ld, right_parent, rd, **kw + attrname, left_parent, ld, right_parent, rd, **kw ) is COMPARE_FAILED ): diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 0a67ff9bf..377aa4fe0 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -37,6 +37,7 @@ from .selectable import Join from .selectable import ScalarSelect from .selectable import SelectBase from .selectable import TableClause +from .traversals import HasCacheKey # noqa from .. import exc from .. import util @@ -921,6 +922,14 @@ class ColumnAdapter(ClauseAdapter): adapt_clause = traverse adapt_list = ClauseAdapter.copy_and_process + def adapt_check_present(self, col): + newcol = self.columns[col] + + if newcol is col and self._corresponding_column(col, True) is None: + return None + + return newcol + def _locate_col(self, col): c = ClauseAdapter.traverse(self, col) @@ -945,3 +954,25 @@ class ColumnAdapter(ClauseAdapter): def __setstate__(self, state): self.__dict__.update(state) self.columns = util.WeakPopulateDict(self._locate_col) + + +def _entity_namespace_key(entity, key): + """Return an entry from an entity_namespace. + + + Raises :class:`_exc.InvalidRequestError` rather than attribute error + on not found. + + """ + + ns = entity.entity_namespace + try: + return getattr(ns, key) + except AttributeError as err: + util.raise_( + exc.InvalidRequestError( + 'Entity namespace for "%s" has no property "%s"' + % (entity, key) + ), + replace_context=err, + ) diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 574896cc7..030fd2fde 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -225,6 +225,9 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)): dp_has_cache_key = symbol("HC") """Visit a :class:`.HasCacheKey` object.""" + dp_has_cache_key_list = symbol("HL") + """Visit a list of :class:`.HasCacheKey` objects.""" + dp_clauseelement = symbol("CE") """Visit a :class:`_expression.ClauseElement` object.""" @@ -372,6 +375,8 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)): """ + dp_setup_join_tuple = symbol("SJ") + dp_statement_hint_list = symbol("SH") """Visit the ``_statement_hints`` collection of a :class:`_expression.Select` @@ -437,9 +442,6 @@ class ExtendedInternalTraversal(InternalTraversal): """ - dp_has_cache_key_list = symbol("HL") - """Visit a list of :class:`.HasCacheKey` objects.""" - dp_inspectable_list = symbol("IL") """Visit a list of inspectable objects which upon inspection are HasCacheKey objects.""" diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index ba4a2de72..0ea9f067e 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -398,9 +398,11 @@ class AssertsCompiledSQL(object): from sqlalchemy import orm if isinstance(clause, orm.Query): - context = clause._compile_context() - context.statement._label_style = LABEL_STYLE_TABLENAME_PLUS_COL - clause = context.statement + compile_state = clause._compile_state() + compile_state.statement._label_style = ( + LABEL_STYLE_TABLENAME_PLUS_COL + ) + clause = compile_state.statement elif isinstance(clause, orm.persistence.BulkUD): with mock.patch.object(clause, "_execute_stmt") as stmt_mock: clause.exec_() diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index 065935c48..9a832ba1b 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -103,6 +103,12 @@ class FacadeDict(ImmutableContainer, dict): def __init__(self, *args): pass + # note that currently, "copy()" is used as a way to get a plain dict + # from an immutabledict, while also allowing the method to work if the + # dictionary is already a plain dict. + # def copy(self): + # return immutabledict.__new__(immutabledict, self) + def __reduce__(self): return FacadeDict, (dict(self),) diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index bd670f2cc..f6fefc244 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -930,6 +930,8 @@ class HasMemoized(object): """ + __slots__ = () + _memoized_keys = frozenset() def _reset_memoizations(self): @@ -1273,13 +1275,18 @@ class hybridmethod(object): def __init__(self, func): self.func = func + self.clslevel = func def __get__(self, instance, owner): if instance is None: - return self.func.__get__(owner, owner.__class__) + return self.clslevel.__get__(owner, owner.__class__) else: return self.func.__get__(instance, owner) + def classlevel(self, func): + self.clslevel = func + return self + class _symbol(int): def __new__(self, name, doc=None, canonical=None): |
