summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2019-12-01 17:24:27 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2020-05-24 11:54:08 -0400
commitdce8c7a125cb99fad62c76cd145752d5afefae36 (patch)
tree352dfa2c38005207ca64f45170bbba2c0f8c927e /lib/sqlalchemy/sql
parent1502b5b3e4e4b93021eb927a6623f288ef006ba6 (diff)
downloadsqlalchemy-dce8c7a125cb99fad62c76cd145752d5afefae36.tar.gz
Unify Query and select() , move all processing to compile phase
Convert Query to do virtually all compile state computation in the _compile_context() phase, and organize it all such that a plain select() construct may also be used as the source of information in order to generate ORM query state. This makes it such that Query is not needed except for its additional methods like from_self() which are all to be deprecated. The construction of ORM state will occur beyond the caching boundary when the new execution model is integrated. future select() gains a working join() and filter_by() method. as we continue to rebase and merge each commit in the steps, callcounts continue to bump around. will have to look at the final result when it's all in. References: #5159 References: #4705 References: #4639 References: #4871 References: #5010 Change-Id: I19e05b3424b07114cce6c439b05198ac47f7ac10
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/annotation.py18
-rw-r--r--lib/sqlalchemy/sql/base.py128
-rw-r--r--lib/sqlalchemy/sql/coercions.py82
-rw-r--r--lib/sqlalchemy/sql/compiler.py10
-rw-r--r--lib/sqlalchemy/sql/elements.py50
-rw-r--r--lib/sqlalchemy/sql/roles.py19
-rw-r--r--lib/sqlalchemy/sql/schema.py5
-rw-r--r--lib/sqlalchemy/sql/selectable.py453
-rw-r--r--lib/sqlalchemy/sql/traversals.py216
-rw-r--r--lib/sqlalchemy/sql/util.py31
-rw-r--r--lib/sqlalchemy/sql/visitors.py8
11 files changed, 837 insertions, 183 deletions
diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py
index 891b8ae09..71d05f38f 100644
--- a/lib/sqlalchemy/sql/annotation.py
+++ b/lib/sqlalchemy/sql/annotation.py
@@ -31,7 +31,10 @@ class SupportsAnnotations(object):
if isinstance(value, HasCacheKey)
else value,
)
- for key, value in self._annotations.items()
+ for key, value in [
+ (key, self._annotations[key])
+ for key in sorted(self._annotations)
+ ]
),
)
@@ -51,6 +54,7 @@ class SupportsCloneAnnotations(SupportsAnnotations):
new = self._clone()
new._annotations = new._annotations.union(values)
new.__dict__.pop("_annotations_cache_key", None)
+ new.__dict__.pop("_generate_cache_key", None)
return new
def _with_annotations(self, values):
@@ -61,6 +65,7 @@ class SupportsCloneAnnotations(SupportsAnnotations):
new = self._clone()
new._annotations = util.immutabledict(values)
new.__dict__.pop("_annotations_cache_key", None)
+ new.__dict__.pop("_generate_cache_key", None)
return new
def _deannotate(self, values=None, clone=False):
@@ -76,7 +81,7 @@ class SupportsCloneAnnotations(SupportsAnnotations):
# clone is used when we are also copying
# the expression for a deep deannotation
new = self._clone()
- new._annotations = {}
+ new._annotations = util.immutabledict()
new.__dict__.pop("_annotations_cache_key", None)
return new
else:
@@ -156,6 +161,7 @@ class Annotated(object):
def __init__(self, element, values):
self.__dict__ = element.__dict__.copy()
self.__dict__.pop("_annotations_cache_key", None)
+ self.__dict__.pop("_generate_cache_key", None)
self.__element = element
self._annotations = values
self._hash = hash(element)
@@ -169,6 +175,7 @@ class Annotated(object):
clone = self.__class__.__new__(self.__class__)
clone.__dict__ = self.__dict__.copy()
clone.__dict__.pop("_annotations_cache_key", None)
+ clone.__dict__.pop("_generate_cache_key", None)
clone._annotations = values
return clone
@@ -211,6 +218,13 @@ class Annotated(object):
else:
return hash(other) == hash(self)
+ @property
+ def entity_namespace(self):
+ if "entity_namespace" in self._annotations:
+ return self._annotations["entity_namespace"].entity_namespace
+ else:
+ return self.__element.entity_namespace
+
# hard-generate Annotated subclasses. this technique
# is used instead of on-the-fly types (i.e. type.__new__())
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 2d023c6a6..04cc34480 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -15,11 +15,14 @@ import operator
import re
from .traversals import HasCacheKey # noqa
+from .traversals import MemoizedHasCacheKey # noqa
from .visitors import ClauseVisitor
+from .visitors import ExtendedInternalTraversal
from .visitors import InternalTraversal
from .. import exc
from .. import util
from ..util import HasMemoized
+from ..util import hybridmethod
if util.TYPE_CHECKING:
from types import ModuleType
@@ -433,22 +436,52 @@ class CompileState(object):
__slots__ = ("statement",)
+ plugins = {}
+
@classmethod
def _create(cls, statement, compiler, **kw):
# factory construction.
- # specific CompileState classes here will look for
- # "plugins" in the given statement. From there they will invoke
- # the appropriate plugin constructor if one is found and return
- # the alternate CompileState object.
+ if statement._compile_state_plugin is not None:
+ constructor = cls.plugins.get(
+ (
+ statement._compile_state_plugin,
+ statement.__visit_name__,
+ None,
+ ),
+ cls,
+ )
+ else:
+ constructor = cls
- c = cls.__new__(cls)
- c.__init__(statement, compiler, **kw)
- return c
+ return constructor(statement, compiler, **kw)
def __init__(self, statement, compiler, **kw):
self.statement = statement
+ @classmethod
+ def get_plugin_classmethod(cls, statement, name):
+ if statement._compile_state_plugin is not None:
+ fn = cls.plugins.get(
+ (
+ statement._compile_state_plugin,
+ statement.__visit_name__,
+ name,
+ ),
+ None,
+ )
+ if fn is not None:
+ return fn
+ return getattr(cls, name)
+
+ @classmethod
+ def plugin_for(cls, plugin_name, visit_name, method_name=None):
+ def decorate(fn):
+ cls.plugins[(plugin_name, visit_name, method_name)] = fn
+ return fn
+
+ return decorate
+
class Generative(HasMemoized):
"""Provide a method-chaining pattern in conjunction with the
@@ -479,6 +512,57 @@ class HasCompileState(Generative):
_compile_state_plugin = None
+ _attributes = util.immutabledict()
+
+
+class _MetaOptions(type):
+ """metaclass for the Options class."""
+
+ def __init__(cls, classname, bases, dict_):
+ cls._cache_attrs = tuple(
+ sorted(d for d in dict_ if not d.startswith("__"))
+ )
+ type.__init__(cls, classname, bases, dict_)
+
+ def __add__(self, other):
+ o1 = self()
+ o1.__dict__.update(other)
+ return o1
+
+
+class Options(util.with_metaclass(_MetaOptions)):
+ """A cacheable option dictionary with defaults.
+
+
+ """
+
+ def __init__(self, **kw):
+ self.__dict__.update(kw)
+
+ def __add__(self, other):
+ o1 = self.__class__.__new__(self.__class__)
+ o1.__dict__.update(self.__dict__)
+ o1.__dict__.update(other)
+ return o1
+
+ @hybridmethod
+ def add_to_element(self, name, value):
+ return self + {name: getattr(self, name) + value}
+
+
+class CacheableOptions(Options, HasCacheKey):
+ @hybridmethod
+ def _gen_cache_key(self, anon_map, bindparams):
+ return HasCacheKey._gen_cache_key(self, anon_map, bindparams)
+
+ @_gen_cache_key.classlevel
+ def _gen_cache_key(cls, anon_map, bindparams):
+ return (cls, ())
+
+ @hybridmethod
+ def _generate_cache_key(self):
+ return HasCacheKey._generate_cache_key_for_object(self)
+
class Executable(Generative):
"""Mark a ClauseElement as supporting execution.
@@ -492,7 +576,21 @@ class Executable(Generative):
supports_execution = True
_execution_options = util.immutabledict()
_bind = None
+ _with_options = ()
+ _with_context_options = ()
+ _cache_enable = True
+
+ _executable_traverse_internals = [
+ ("_with_options", ExtendedInternalTraversal.dp_has_cache_key_list),
+ ("_with_context_options", ExtendedInternalTraversal.dp_plain_obj),
+ ("_cache_enable", ExtendedInternalTraversal.dp_plain_obj),
+ ]
+
+ @_generative
+ def _disable_caching(self):
+ self._cache_enable = HasCacheKey()
+ @_generative
def options(self, *options):
"""Apply options to this statement.
@@ -522,7 +620,21 @@ class Executable(Generative):
to the usage of ORM queries
"""
- self._options += options
+ self._with_options += options
+
+ @_generative
+ def _add_context_option(self, callable_, cache_args):
+ """Add a context option to this statement.
+
+ These are callable functions that will
+ be given the CompileState object upon compilation.
+
+ A second argument cache_args is required, which will be combined
+ with the identity of the function itself in order to produce a
+ cache key.
+
+ """
+ self._with_context_options += ((callable_, cache_args),)
@_generative
def execution_options(self, **kw):
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py
index c7b85c415..2fc63b82f 100644
--- a/lib/sqlalchemy/sql/coercions.py
+++ b/lib/sqlalchemy/sql/coercions.py
@@ -57,7 +57,7 @@ def expect(role, element, **kw):
if not isinstance(
element,
- (elements.ClauseElement, schema.SchemaItem, schema.FetchedValue),
+ (elements.ClauseElement, schema.SchemaItem, schema.FetchedValue,),
):
resolved = impl._resolve_for_clause_element(element, **kw)
else:
@@ -106,7 +106,9 @@ class RoleImpl(object):
self.name = role_class._role_name
self._use_inspection = issubclass(role_class, roles.UsesInspection)
- def _resolve_for_clause_element(self, element, argname=None, **kw):
+ def _resolve_for_clause_element(
+ self, element, argname=None, apply_plugins=None, **kw
+ ):
original_element = element
is_clause_element = False
@@ -115,18 +117,39 @@ class RoleImpl(object):
if not getattr(element, "is_clause_element", False):
element = element.__clause_element__()
else:
- return element
+ break
+
+ should_apply_plugins = (
+ apply_plugins is not None
+ and apply_plugins._compile_state_plugin is None
+ )
if is_clause_element:
+ if (
+ should_apply_plugins
+ and "compile_state_plugin" in element._annotations
+ ):
+ apply_plugins._compile_state_plugin = element._annotations[
+ "compile_state_plugin"
+ ]
return element
if self._use_inspection:
insp = inspection.inspect(element, raiseerr=False)
if insp is not None:
+ insp._post_inspect
try:
- return insp.__clause_element__()
+ element = insp.__clause_element__()
except AttributeError:
self._raise_for_expected(original_element, argname)
+ else:
+ if (
+ should_apply_plugins
+ and "compile_state_plugin" in element._annotations
+ ):
+ plugin = element._annotations["compile_state_plugin"]
+ apply_plugins._compile_state_plugin = plugin
+ return element
return self._literal_coercion(element, argname=argname, **kw)
@@ -287,7 +310,7 @@ class _SelectIsNotFrom(object):
advice = (
"To create a "
"FROM clause from a %s object, use the .subquery() method."
- % (element.__class__)
+ % (element.__class__,)
)
code = "89ve"
else:
@@ -453,6 +476,18 @@ class OrderByImpl(ByOfImpl, RoleImpl):
return resolved
+class GroupByImpl(ByOfImpl, RoleImpl):
+ __slots__ = ()
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, **kw
+ ):
+ if isinstance(resolved, roles.StrictFromClauseRole):
+ return elements.ClauseList(*resolved.c)
+ else:
+ return resolved
+
+
class DMLColumnImpl(_ReturnsStringKey, RoleImpl):
__slots__ = ()
@@ -618,6 +653,37 @@ class HasCTEImpl(ReturnsRowsImpl, roles.HasCTERole):
pass
+class JoinTargetImpl(RoleImpl):
+ __slots__ = ()
+
+ def _literal_coercion(self, element, legacy=False, **kw):
+ if isinstance(element, str):
+ return element
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, legacy=False, **kw
+ ):
+ if isinstance(original_element, roles.JoinTargetRole):
+ return original_element
+ elif legacy and isinstance(resolved, (str, roles.WhereHavingRole)):
+ return resolved
+ elif legacy and resolved._is_select_statement:
+ util.warn_deprecated(
+ "Implicit coercion of SELECT and textual SELECT "
+ "constructs into FROM clauses is deprecated; please call "
+ ".subquery() on any Core select or ORM Query object in "
+ "order to produce a subquery object.",
+ version="1.4",
+ )
+ # TODO: doing _implicit_subquery here causes tests to fail,
+ # how was this working before? probably that ORM
+ # join logic treated it as a select and subquery would happen
+ # in _ORMJoin->Join
+ return resolved
+ else:
+ self._raise_for_expected(original_element, argname, resolved)
+
+
class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl):
__slots__ = ()
@@ -647,6 +713,12 @@ class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl):
else:
self._raise_for_expected(original_element, argname, resolved)
+ def _post_coercion(self, element, deannotate=False, **kw):
+ if deannotate:
+ return element._deannotate()
+ else:
+ return element
+
class StrictFromClauseImpl(FromClauseImpl):
__slots__ = ()
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index ccc1b53fe..9a7646743 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -653,6 +653,12 @@ class SQLCompiler(Compiled):
"""
+ compile_state_factories = util.immutabledict()
+ """Dictionary of alternate :class:`.CompileState` factories for given
+ classes, identified by their visit_name.
+
+ """
+
def __init__(
self,
dialect,
@@ -661,6 +667,7 @@ class SQLCompiler(Compiled):
column_keys=None,
inline=False,
linting=NO_LINTING,
+ compile_state_factories=None,
**kwargs
):
"""Construct a new :class:`.SQLCompiler` object.
@@ -727,6 +734,9 @@ class SQLCompiler(Compiled):
# dialect.label_length or dialect.max_identifier_length
self.truncated_names = {}
+ if compile_state_factories:
+ self.compile_state_factories = compile_state_factories
+
Compiled.__init__(self, dialect, statement, **kwargs)
if (
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 7310edd3f..c1bc9edbc 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -26,7 +26,6 @@ from .annotation import SupportsWrappingAnnotations
from .base import _clone
from .base import _generative
from .base import Executable
-from .base import HasCacheKey
from .base import HasMemoized
from .base import Immutable
from .base import NO_ARG
@@ -35,6 +34,7 @@ from .base import SingletonConstant
from .coercions import _document_text_coercion
from .traversals import _copy_internals
from .traversals import _get_children
+from .traversals import MemoizedHasCacheKey
from .traversals import NO_CACHE
from .visitors import cloned_traverse
from .visitors import InternalTraversal
@@ -179,7 +179,10 @@ def not_(clause):
@inspection._self_inspects
class ClauseElement(
- roles.SQLRole, SupportsWrappingAnnotations, HasCacheKey, Traversible,
+ roles.SQLRole,
+ SupportsWrappingAnnotations,
+ MemoizedHasCacheKey,
+ Traversible,
):
"""Base class for elements of a programmatically constructed SQL
expression.
@@ -206,6 +209,7 @@ class ClauseElement(
_is_select_container = False
_is_select_statement = False
_is_bind_parameter = False
+ _is_clause_list = False
_order_by_label_element = None
@@ -300,7 +304,7 @@ class ClauseElement(
used.
"""
- return self._params(True, optionaldict, kwargs)
+ return self._replace_params(True, optionaldict, kwargs)
def params(self, *optionaldict, **kwargs):
"""Return a copy with :func:`bindparam()` elements replaced.
@@ -315,9 +319,9 @@ class ClauseElement(
{'foo':7}
"""
- return self._params(False, optionaldict, kwargs)
+ return self._replace_params(False, optionaldict, kwargs)
- def _params(self, unique, optionaldict, kwargs):
+ def _replace_params(self, unique, optionaldict, kwargs):
if len(optionaldict) == 1:
kwargs.update(optionaldict[0])
elif len(optionaldict) > 1:
@@ -371,7 +375,7 @@ class ClauseElement(
continue
if obj is not None:
- result = meth(self, obj, **kw)
+ result = meth(self, attrname, obj, **kw)
if result is not None:
setattr(self, attrname, result)
@@ -2070,6 +2074,8 @@ class ClauseList(
__visit_name__ = "clauselist"
+ _is_clause_list = True
+
_traverse_internals = [
("clauses", InternalTraversal.dp_clauseelement_list),
("operator", InternalTraversal.dp_operator),
@@ -2079,6 +2085,8 @@ class ClauseList(
self.operator = kwargs.pop("operator", operators.comma_op)
self.group = kwargs.pop("group", True)
self.group_contents = kwargs.pop("group_contents", True)
+ if kwargs.pop("_flatten_sub_clauses", False):
+ clauses = util.flatten_iterator(clauses)
self._tuple_values = kwargs.pop("_tuple_values", False)
self._text_converter_role = text_converter_role = kwargs.pop(
"_literal_as_text_role", roles.WhereHavingRole
@@ -2116,7 +2124,9 @@ class ClauseList(
@property
def _select_iterable(self):
- return iter(self)
+ return itertools.chain.from_iterable(
+ [elem._select_iterable for elem in self.clauses]
+ )
def append(self, clause):
if self.group_contents:
@@ -2224,6 +2234,32 @@ class BooleanClauseList(ClauseList, ColumnElement):
return cls._construct_raw(operator)
@classmethod
+ def _construct_for_whereclause(cls, clauses):
+ operator, continue_on, skip_on = (
+ operators.and_,
+ True_._singleton,
+ False_._singleton,
+ )
+
+ lcc, convert_clauses = cls._process_clauses_for_boolean(
+ operator,
+ continue_on,
+ skip_on,
+ clauses, # these are assumed to be coerced already
+ )
+
+ if lcc > 1:
+ # multiple elements. Return regular BooleanClauseList
+ # which will link elements against the operator.
+ return cls._construct_raw(operator, convert_clauses)
+ elif lcc == 1:
+ # just one element. return it as a single boolean element,
+ # not a list and discard the operator.
+ return convert_clauses[0]
+ else:
+ return None
+
+ @classmethod
def _construct_raw(cls, operator, clauses=None):
self = cls.__new__(cls)
self.clauses = clauses if clauses else []
diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py
index 72a0bdc95..b861f721b 100644
--- a/lib/sqlalchemy/sql/roles.py
+++ b/lib/sqlalchemy/sql/roles.py
@@ -19,7 +19,7 @@ class SQLRole(object):
class UsesInspection(object):
- pass
+ _post_inspect = None
class ColumnArgumentRole(SQLRole):
@@ -54,6 +54,14 @@ class ByOfRole(ColumnListRole):
_role_name = "GROUP BY / OF / etc. expression"
+class GroupByRole(UsesInspection, ByOfRole):
+ # note there's a special case right now where you can pass a whole
+ # ORM entity to group_by() and it splits out. we may not want to keep
+ # this around
+
+ _role_name = "GROUP BY expression"
+
+
class OrderByRole(ByOfRole):
_role_name = "ORDER BY expression"
@@ -92,7 +100,14 @@ class InElementRole(SQLRole):
)
-class FromClauseRole(ColumnsClauseRole):
+class JoinTargetRole(UsesInspection, StructuralRole):
+ _role_name = (
+ "Join target, typically a FROM expression, or ORM "
+ "relationship attribute"
+ )
+
+
+class FromClauseRole(ColumnsClauseRole, JoinTargetRole):
_role_name = "FROM expression, such as a Table or alias() object"
_is_subquery = False
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index 689eda11d..65f8bd81c 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -477,7 +477,10 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
]
def _gen_cache_key(self, anon_map, bindparams):
- return (self,) + self._annotations_cache_key
+ if self._annotations:
+ return (self,) + self._annotations_cache_key
+ else:
+ return (self,)
def __new__(cls, *args, **kw):
if not args:
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 85abbb5e0..6a552c18c 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -28,6 +28,7 @@ from .base import _expand_cloned
from .base import _from_objects
from .base import _generative
from .base import _select_iterables
+from .base import CacheableOptions
from .base import ColumnCollection
from .base import ColumnSet
from .base import CompileState
@@ -42,6 +43,7 @@ from .coercions import _document_text_coercion
from .elements import _anonymous_label
from .elements import and_
from .elements import BindParameter
+from .elements import BooleanClauseList
from .elements import ClauseElement
from .elements import ClauseList
from .elements import ColumnClause
@@ -339,6 +341,90 @@ class HasSuffixes(object):
)
+class HasHints(object):
+ _hints = util.immutabledict()
+ _statement_hints = ()
+
+ _has_hints_traverse_internals = [
+ ("_statement_hints", InternalTraversal.dp_statement_hint_list),
+ ("_hints", InternalTraversal.dp_table_hint_list),
+ ]
+
+ def with_statement_hint(self, text, dialect_name="*"):
+ """add a statement hint to this :class:`_expression.Select` or
+ other selectable object.
+
+ This method is similar to :meth:`_expression.Select.with_hint`
+ except that
+ it does not require an individual table, and instead applies to the
+ statement as a whole.
+
+ Hints here are specific to the backend database and may include
+ directives such as isolation levels, file directives, fetch directives,
+ etc.
+
+ .. versionadded:: 1.0.0
+
+ .. seealso::
+
+ :meth:`_expression.Select.with_hint`
+
+ :meth:.`.Select.prefix_with` - generic SELECT prefixing which also
+ can suit some database-specific HINT syntaxes such as MySQL
+ optimizer hints
+
+ """
+ return self.with_hint(None, text, dialect_name)
+
+ @_generative
+ def with_hint(self, selectable, text, dialect_name="*"):
+ r"""Add an indexing or other executional context hint for the given
+ selectable to this :class:`_expression.Select` or other selectable
+ object.
+
+ The text of the hint is rendered in the appropriate
+ location for the database backend in use, relative
+ to the given :class:`_schema.Table` or :class:`_expression.Alias`
+ passed as the
+ ``selectable`` argument. The dialect implementation
+ typically uses Python string substitution syntax
+ with the token ``%(name)s`` to render the name of
+ the table or alias. E.g. when using Oracle, the
+ following::
+
+ select([mytable]).\
+ with_hint(mytable, "index(%(name)s ix_mytable)")
+
+ Would render SQL as::
+
+ select /*+ index(mytable ix_mytable) */ ... from mytable
+
+ The ``dialect_name`` option will limit the rendering of a particular
+ hint to a particular backend. Such as, to add hints for both Oracle
+ and Sybase simultaneously::
+
+ select([mytable]).\
+ with_hint(mytable, "index(%(name)s ix_mytable)", 'oracle').\
+ with_hint(mytable, "WITH INDEX ix_mytable", 'sybase')
+
+ .. seealso::
+
+ :meth:`_expression.Select.with_statement_hint`
+
+ """
+ if selectable is None:
+ self._statement_hints += ((dialect_name, text),)
+ else:
+ self._hints = self._hints.union(
+ {
+ (
+ coercions.expect(roles.FromClauseRole, selectable),
+ dialect_name,
+ ): text
+ }
+ )
+
+
class FromClause(roles.AnonymizedFromClauseRole, Selectable):
"""Represent an element that can be used within the ``FROM``
clause of a ``SELECT`` statement.
@@ -597,6 +683,22 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
self._populate_column_collection()
return self._columns.as_immutable()
+ @property
+ def entity_namespace(self):
+ """Return a namespace used for name-based access in SQL expressions.
+
+ This is the namespace that is used to resolve "filter_by()" type
+ expressions, such as::
+
+ stmt.filter_by(address='some address')
+
+ It defaults to the .c collection, however internally it can
+ be overridden using the "entity_namespace" annotation to deliver
+ alternative results.
+
+ """
+ return self.columns
+
@util.memoized_property
def primary_key(self):
"""Return the collection of Column objects which comprise the
@@ -727,13 +829,21 @@ class Join(FromClause):
:class:`_expression.FromClause` object.
"""
- self.left = coercions.expect(roles.FromClauseRole, left)
- self.right = coercions.expect(roles.FromClauseRole, right).self_group()
+ self.left = coercions.expect(
+ roles.FromClauseRole, left, deannotate=True
+ )
+ self.right = coercions.expect(
+ roles.FromClauseRole, right, deannotate=True
+ ).self_group()
if onclause is None:
self.onclause = self._match_primaries(self.left, self.right)
else:
- self.onclause = onclause.self_group(against=operators._asbool)
+ # note: taken from If91f61527236fd4d7ae3cad1f24c38be921c90ba
+ # not merged yet
+ self.onclause = coercions.expect(
+ roles.WhereHavingRole, onclause
+ ).self_group(against=operators._asbool)
self.isouter = isouter
self.full = full
@@ -1963,6 +2073,12 @@ class TableClause(Immutable, FromClause):
if kw:
raise exc.ArgumentError("Unsupported argument(s): %s" % list(kw))
+ def __str__(self):
+ if self.schema is not None:
+ return self.schema + "." + self.name
+ else:
+ return self.name
+
def _refresh_for_new_column(self, column):
pass
@@ -2905,7 +3021,8 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase):
self._group_by_clauses = ()
else:
self._group_by_clauses += tuple(
- coercions.expect(roles.ByOfRole, clause) for clause in clauses
+ coercions.expect(roles.GroupByRole, clause)
+ for clause in clauses
)
@@ -3309,8 +3426,16 @@ class DeprecatedSelectGenerations(object):
class SelectState(CompileState):
+ class default_select_compile_options(CacheableOptions):
+ _cache_key_traversal = []
+
def __init__(self, statement, compiler, **kw):
self.statement = statement
+ self.from_clauses = statement._from_obj
+
+ if statement._setup_joins:
+ self._setup_joins(statement._setup_joins)
+
self.froms = self._get_froms(statement)
self.columns_plus_names = statement._generate_columns_plus_names(True)
@@ -3319,7 +3444,18 @@ class SelectState(CompileState):
froms = []
seen = set()
- for item in statement._iterate_from_elements():
+ for item in itertools.chain(
+ itertools.chain.from_iterable(
+ [element._from_objects for element in statement._raw_columns]
+ ),
+ itertools.chain.from_iterable(
+ [
+ element._from_objects
+ for element in statement._where_criteria
+ ]
+ ),
+ self.from_clauses,
+ ):
if item._is_subquery and item.element is statement:
raise exc.InvalidRequestError(
"select() construct refers to itself as a FROM"
@@ -3341,6 +3477,7 @@ class SelectState(CompileState):
correlating.
"""
+
froms = self.froms
toremove = set(
@@ -3425,10 +3562,162 @@ class SelectState(CompileState):
return with_cols, only_froms, only_cols
+ @classmethod
+ def determine_last_joined_entity(cls, stmt):
+ if stmt._setup_joins:
+ return stmt._setup_joins[-1][0]
+ else:
+ return None
+
+ def _setup_joins(self, args):
+ for (right, onclause, left, flags) in args:
+ isouter = flags["isouter"]
+ full = flags["full"]
+
+ if left is None:
+ (
+ left,
+ replace_from_obj_index,
+ ) = self._join_determine_implicit_left_side(
+ left, right, onclause
+ )
+ else:
+ (replace_from_obj_index) = self._join_place_explicit_left_side(
+ left
+ )
+
+ if replace_from_obj_index is not None:
+ # splice into an existing element in the
+ # self._from_obj list
+ left_clause = self.from_clauses[replace_from_obj_index]
+
+ self.from_clauses = (
+ self.from_clauses[:replace_from_obj_index]
+ + (
+ Join(
+ left_clause,
+ right,
+ onclause,
+ isouter=isouter,
+ full=full,
+ ),
+ )
+ + self.from_clauses[replace_from_obj_index + 1 :]
+ )
+ else:
+
+ self.from_clauses = self.from_clauses + (
+ Join(left, right, onclause, isouter=isouter, full=full,),
+ )
+
+ @util.preload_module("sqlalchemy.sql.util")
+ def _join_determine_implicit_left_side(self, left, right, onclause):
+ """When join conditions don't express the left side explicitly,
+ determine if an existing FROM or entity in this query
+ can serve as the left hand side.
+
+ """
+
+ sql_util = util.preloaded.sql_util
+
+ replace_from_obj_index = None
+
+ from_clauses = self.statement._from_obj
+
+ if from_clauses:
+
+ indexes = sql_util.find_left_clause_to_join_from(
+ from_clauses, right, onclause
+ )
+
+ if len(indexes) == 1:
+ replace_from_obj_index = indexes[0]
+ left = from_clauses[replace_from_obj_index]
+ else:
+ potential = {}
+ statement = self.statement
+
+ for from_clause in itertools.chain(
+ itertools.chain.from_iterable(
+ [
+ element._from_objects
+ for element in statement._raw_columns
+ ]
+ ),
+ itertools.chain.from_iterable(
+ [
+ element._from_objects
+ for element in statement._where_criteria
+ ]
+ ),
+ ):
+
+ potential[from_clause] = ()
+
+ all_clauses = list(potential.keys())
+ indexes = sql_util.find_left_clause_to_join_from(
+ all_clauses, right, onclause
+ )
+
+ if len(indexes) == 1:
+ left = all_clauses[indexes[0]]
+
+ if len(indexes) > 1:
+ raise exc.InvalidRequestError(
+ "Can't determine which FROM clause to join "
+ "from, there are multiple FROMS which can "
+ "join to this entity. Please use the .select_from() "
+ "method to establish an explicit left side, as well as "
+ "providing an explcit ON clause if not present already to "
+ "help resolve the ambiguity."
+ )
+ elif not indexes:
+ raise exc.InvalidRequestError(
+ "Don't know how to join to %r. "
+ "Please use the .select_from() "
+ "method to establish an explicit left side, as well as "
+ "providing an explcit ON clause if not present already to "
+ "help resolve the ambiguity." % (right,)
+ )
+ return left, replace_from_obj_index
+
+ @util.preload_module("sqlalchemy.sql.util")
+ def _join_place_explicit_left_side(self, left):
+ replace_from_obj_index = None
+
+ sql_util = util.preloaded.sql_util
+
+ from_clauses = list(self.statement._iterate_from_elements())
+
+ if from_clauses:
+ indexes = sql_util.find_left_clause_that_matches_given(
+ self.from_clauses, left
+ )
+ else:
+ indexes = []
+
+ if len(indexes) > 1:
+ raise exc.InvalidRequestError(
+ "Can't identify which entity in which to assign the "
+ "left side of this join. Please use a more specific "
+ "ON clause."
+ )
+
+ # have an index, means the left side is already present in
+ # an existing FROM in the self._from_obj tuple
+ if indexes:
+ replace_from_obj_index = indexes[0]
+
+ # no index, means we need to add a new element to the
+ # self._from_obj tuple
+
+ return replace_from_obj_index
+
class Select(
HasPrefixes,
HasSuffixes,
+ HasHints,
HasCompileState,
DeprecatedSelectGenerations,
GenerativeSelect,
@@ -3440,9 +3729,10 @@ class Select(
__visit_name__ = "select"
_compile_state_factory = SelectState._create
+ _is_future = False
+ _setup_joins = ()
+ _legacy_setup_joins = ()
- _hints = util.immutabledict()
- _statement_hints = ()
_distinct = False
_distinct_on = ()
_correlate = ()
@@ -3452,6 +3742,8 @@ class Select(
_from_obj = ()
_auto_correlate = True
+ compile_options = SelectState.default_select_compile_options
+
_traverse_internals = (
[
("_raw_columns", InternalTraversal.dp_clauseelement_list),
@@ -3460,58 +3752,33 @@ class Select(
("_having_criteria", InternalTraversal.dp_clauseelement_list),
("_order_by_clauses", InternalTraversal.dp_clauseelement_list,),
("_group_by_clauses", InternalTraversal.dp_clauseelement_list,),
+ ("_setup_joins", InternalTraversal.dp_setup_join_tuple,),
+ ("_legacy_setup_joins", InternalTraversal.dp_setup_join_tuple,),
("_correlate", InternalTraversal.dp_clauseelement_unordered_set),
(
"_correlate_except",
InternalTraversal.dp_clauseelement_unordered_set,
),
("_for_update_arg", InternalTraversal.dp_clauseelement),
- ("_statement_hints", InternalTraversal.dp_statement_hint_list),
- ("_hints", InternalTraversal.dp_table_hint_list),
("_distinct", InternalTraversal.dp_boolean),
("_distinct_on", InternalTraversal.dp_clauseelement_list),
("_label_style", InternalTraversal.dp_plain_obj),
]
+ HasPrefixes._has_prefixes_traverse_internals
+ HasSuffixes._has_suffixes_traverse_internals
+ + HasHints._has_hints_traverse_internals
+ SupportsCloneAnnotations._clone_annotations_traverse_internals
+ + Executable._executable_traverse_internals
)
+ _cache_key_traversal = _traverse_internals + [
+ ("compile_options", InternalTraversal.dp_has_cache_key)
+ ]
+
@classmethod
def _create_select(cls, *entities):
- r"""Construct a new :class:`_expression.Select` using the 2.
- x style API.
-
- .. versionadded:: 2.0 - the :func:`_future.select` construct is
- the same construct as the one returned by
- :func:`_expression.select`, except that the function only
- accepts the "columns clause" entities up front; the rest of the
- state of the SELECT should be built up using generative methods.
-
- Similar functionality is also available via the
- :meth:`_expression.FromClause.select` method on any
- :class:`_expression.FromClause`.
-
- .. seealso::
-
- :ref:`coretutorial_selecting` - Core Tutorial description of
- :func:`_expression.select`.
-
- :param \*entities:
- Entities to SELECT from. For Core usage, this is typically a series
- of :class:`_expression.ColumnElement` and / or
- :class:`_expression.FromClause`
- objects which will form the columns clause of the resulting
- statement. For those objects that are instances of
- :class:`_expression.FromClause` (typically :class:`_schema.Table`
- or :class:`_expression.Alias`
- objects), the :attr:`_expression.FromClause.c`
- collection is extracted
- to form a collection of :class:`_expression.ColumnElement` objects.
-
- This parameter will also accept :class:`_expression.TextClause`
- constructs as
- given, as well as ORM-mapped classes.
+ r"""Construct an old style :class:`_expression.Select` using the
+ the 2.x style constructor.
"""
@@ -3779,7 +4046,10 @@ class Select(
if cols_present:
self._raw_columns = [
- coercions.expect(roles.ColumnsClauseRole, c,) for c in columns
+ coercions.expect(
+ roles.ColumnsClauseRole, c, apply_plugins=self
+ )
+ for c in columns
]
else:
self._raw_columns = []
@@ -3820,71 +4090,6 @@ class Select(
return self._compile_state_factory(self, None)._get_display_froms()
- def with_statement_hint(self, text, dialect_name="*"):
- """add a statement hint to this :class:`_expression.Select`.
-
- This method is similar to :meth:`_expression.Select.with_hint`
- except that
- it does not require an individual table, and instead applies to the
- statement as a whole.
-
- Hints here are specific to the backend database and may include
- directives such as isolation levels, file directives, fetch directives,
- etc.
-
- .. versionadded:: 1.0.0
-
- .. seealso::
-
- :meth:`_expression.Select.with_hint`
-
- :meth:`.Select.prefix_with` - generic SELECT prefixing which also
- can suit some database-specific HINT syntaxes such as MySQL
- optimizer hints
-
- """
- return self.with_hint(None, text, dialect_name)
-
- @_generative
- def with_hint(self, selectable, text, dialect_name="*"):
- r"""Add an indexing or other executional context hint for the given
- selectable to this :class:`_expression.Select`.
-
- The text of the hint is rendered in the appropriate
- location for the database backend in use, relative
- to the given :class:`_schema.Table` or :class:`_expression.Alias`
- passed as the
- ``selectable`` argument. The dialect implementation
- typically uses Python string substitution syntax
- with the token ``%(name)s`` to render the name of
- the table or alias. E.g. when using Oracle, the
- following::
-
- select([mytable]).\
- with_hint(mytable, "index(%(name)s ix_mytable)")
-
- Would render SQL as::
-
- select /*+ index(mytable ix_mytable) */ ... from mytable
-
- The ``dialect_name`` option will limit the rendering of a particular
- hint to a particular backend. Such as, to add hints for both Oracle
- and Sybase simultaneously::
-
- select([mytable]).\
- with_hint(mytable, "index(%(name)s ix_mytable)", 'oracle').\
- with_hint(mytable, "WITH INDEX ix_mytable", 'sybase')
-
- .. seealso::
-
- :meth:`_expression.Select.with_statement_hint`
-
- """
- if selectable is None:
- self._statement_hints += ((dialect_name, text),)
- else:
- self._hints = self._hints.union({(selectable, dialect_name): text})
-
@property
def inner_columns(self):
"""an iterator of all ColumnElement expressions which would
@@ -3921,9 +4126,16 @@ class Select(
_from_objects(*self._where_criteria),
)
)
+
+ # do a clone for the froms we've gathered. what is important here
+ # is if any of the things we are selecting from, like tables,
+ # were converted into Join objects. if so, these need to be
+ # added to _from_obj explicitly, because otherwise they won't be
+ # part of the new state, as they don't associate themselves with
+ # their columns.
new_froms = {f: clone(f, **kw) for f in all_the_froms}
- # 2. copy FROM collections.
+ # 2. copy FROM collections, adding in joins that we've created.
self._from_obj = tuple(clone(f, **kw) for f in self._from_obj) + tuple(
f for f in new_froms.values() if isinstance(f, Join)
)
@@ -3937,6 +4149,10 @@ class Select(
kw["replace"] = replace
+ # copy everything else. for table-ish things like correlate,
+ # correlate_except, setup_joins, these clone normally. For
+ # column-expression oriented things like raw_columns, where_criteria,
+ # order by, we get this from the new froms.
super(Select, self)._copy_internals(
clone=clone, omit_attrs=("_from_obj",), **kw
)
@@ -3975,10 +4191,18 @@ class Select(
self._assert_no_memoizations()
self._raw_columns = self._raw_columns + [
- coercions.expect(roles.ColumnsClauseRole, column,)
+ coercions.expect(
+ roles.ColumnsClauseRole, column, apply_plugins=self
+ )
for column in columns
]
+ def _set_entities(self, entities):
+ self._raw_columns = [
+ coercions.expect(roles.ColumnsClauseRole, ent, apply_plugins=self)
+ for ent in util.to_list(entities)
+ ]
+
@util.deprecated(
"1.4",
"The :meth:`_expression.Select.column` method is deprecated and will "
@@ -4111,6 +4335,7 @@ class Select(
rc = []
for c in columns:
c = coercions.expect(roles.ColumnsClauseRole, c,)
+ # TODO: why are we doing this here?
if isinstance(c, ScalarSelect):
c = c.self_group(against=operators.comma_op)
rc.append(c)
@@ -4121,7 +4346,9 @@ class Select(
"""Legacy, return the WHERE clause as a """
""":class:`_expression.BooleanClauseList`"""
- return and_(*self._where_criteria)
+ return BooleanClauseList._construct_for_whereclause(
+ self._where_criteria
+ )
@_generative
def where(self, whereclause):
@@ -4202,7 +4429,9 @@ class Select(
"""
self._from_obj += tuple(
- coercions.expect(roles.FromClauseRole, fromclause)
+ coercions.expect(
+ roles.FromClauseRole, fromclause, apply_plugins=self
+ )
for fromclause in froms
)
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index 8c63fcba1..a308feb7c 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -29,9 +29,8 @@ def compare(obj1, obj2, **kw):
return strategy.compare(obj1, obj2, **kw)
-class HasCacheKey(HasMemoized):
+class HasCacheKey(object):
_cache_key_traversal = NO_CACHE
-
__slots__ = ()
def _gen_cache_key(self, anon_map, bindparams):
@@ -141,7 +140,6 @@ class HasCacheKey(HasMemoized):
return result
- @HasMemoized.memoized_instancemethod
def _generate_cache_key(self):
"""return a cache key.
@@ -183,6 +181,23 @@ class HasCacheKey(HasMemoized):
else:
return CacheKey(key, bindparams)
+ @classmethod
+ def _generate_cache_key_for_object(cls, obj):
+ bindparams = []
+
+ _anon_map = anon_map()
+ key = obj._gen_cache_key(_anon_map, bindparams)
+ if NO_CACHE in _anon_map:
+ return None
+ else:
+ return CacheKey(key, bindparams)
+
+
+class MemoizedHasCacheKey(HasCacheKey, HasMemoized):
+ @HasMemoized.memoized_instancemethod
+ def _generate_cache_key(self):
+ return HasCacheKey._generate_cache_key(self)
+
class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])):
def __hash__(self):
@@ -191,6 +206,40 @@ class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])):
def __eq__(self, other):
return self.key == other.key
+ def _whats_different(self, other):
+
+ k1 = self.key
+ k2 = other.key
+
+ stack = []
+ pickup_index = 0
+ while True:
+ s1, s2 = k1, k2
+ for idx in stack:
+ s1 = s1[idx]
+ s2 = s2[idx]
+
+ for idx, (e1, e2) in enumerate(util.zip_longest(s1, s2)):
+ if idx < pickup_index:
+ continue
+ if e1 != e2:
+ if isinstance(e1, tuple) and isinstance(e2, tuple):
+ stack.append(idx)
+ break
+ else:
+ yield "key%s[%d]: %s != %s" % (
+ "".join("[%d]" % id_ for id_ in stack),
+ idx,
+ e1,
+ e2,
+ )
+ else:
+ pickup_index = stack.pop(-1)
+ break
+
+ def _diff(self, other):
+ return ", ".join(self._whats_different(other))
+
def __str__(self):
stack = [self.key]
@@ -241,9 +290,7 @@ class _CacheKey(ExtendedInternalTraversal):
visit_type = STATIC_CACHE_KEY
def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams):
- return self.visit_has_cache_key(
- attrname, inspect(obj), parent, anon_map, bindparams
- )
+ return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams))
def visit_string_list(self, attrname, obj, parent, anon_map, bindparams):
return tuple(obj)
@@ -361,6 +408,24 @@ class _CacheKey(ExtendedInternalTraversal):
),
)
+ def visit_setup_join_tuple(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ # TODO: look at attrname for "legacy_join" and use different structure
+ return tuple(
+ (
+ target._gen_cache_key(anon_map, bindparams),
+ onclause._gen_cache_key(anon_map, bindparams)
+ if onclause is not None
+ else None,
+ from_._gen_cache_key(anon_map, bindparams)
+ if from_ is not None
+ else None,
+ tuple([(key, flags[key]) for key in sorted(flags)]),
+ )
+ for (target, onclause, from_, flags) in obj
+ )
+
def visit_table_hint_list(
self, attrname, obj, parent, anon_map, bindparams
):
@@ -498,31 +563,53 @@ class _CopyInternals(InternalTraversal):
"""Generate a _copy_internals internal traversal dispatch for classes
with a _traverse_internals collection."""
- def visit_clauseelement(self, parent, element, clone=_clone, **kw):
+ def visit_clauseelement(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
return clone(element, **kw)
- def visit_clauseelement_list(self, parent, element, clone=_clone, **kw):
+ def visit_clauseelement_list(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
return [clone(clause, **kw) for clause in element]
def visit_clauseelement_unordered_set(
- self, parent, element, clone=_clone, **kw
+ self, attrname, parent, element, clone=_clone, **kw
):
return {clone(clause, **kw) for clause in element}
- def visit_clauseelement_tuples(self, parent, element, clone=_clone, **kw):
+ def visit_clauseelement_tuples(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
return [
tuple(clone(tup_elem, **kw) for tup_elem in elem)
for elem in element
]
def visit_string_clauseelement_dict(
- self, parent, element, clone=_clone, **kw
+ self, attrname, parent, element, clone=_clone, **kw
):
return dict(
(key, clone(value, **kw)) for key, value in element.items()
)
- def visit_dml_ordered_values(self, parent, element, clone=_clone, **kw):
+ def visit_setup_join_tuple(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ # TODO: look at attrname for "legacy_join" and use different structure
+ return tuple(
+ (
+ clone(target, **kw) if target is not None else None,
+ clone(onclause, **kw) if onclause is not None else None,
+ clone(from_, **kw) if from_ is not None else None,
+ flags,
+ )
+ for (target, onclause, from_, flags) in element
+ )
+
+ def visit_dml_ordered_values(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
# sequence of 2-tuples
return [
(
@@ -534,7 +621,7 @@ class _CopyInternals(InternalTraversal):
for key, value in element
]
- def visit_dml_values(self, parent, element, clone=_clone, **kw):
+ def visit_dml_values(self, attrname, parent, element, clone=_clone, **kw):
return {
(
clone(key, **kw) if hasattr(key, "__clause_element__") else key
@@ -542,7 +629,9 @@ class _CopyInternals(InternalTraversal):
for key, value in element.items()
}
- def visit_dml_multi_values(self, parent, element, clone=_clone, **kw):
+ def visit_dml_multi_values(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
# sequence of sequences, each sequence contains a list/dict/tuple
def copy(elem):
@@ -741,7 +830,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
continue
comparison = dispatch(
- left, left_child, right, right_child, **kw
+ left_attrname, left, left_child, right, right_child, **kw
)
if comparison is COMPARE_FAILED:
return False
@@ -753,31 +842,40 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
return comparator.compare(obj1, obj2, **kw)
def visit_has_cache_key(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
if left._gen_cache_key(self.anon_map[0], []) != right._gen_cache_key(
self.anon_map[1], []
):
return COMPARE_FAILED
+ def visit_has_cache_key_list(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ for l, r in util.zip_longest(left, right, fillvalue=None):
+ if l._gen_cache_key(self.anon_map[0], []) != r._gen_cache_key(
+ self.anon_map[1], []
+ ):
+ return COMPARE_FAILED
+
def visit_clauseelement(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
self.stack.append((left, right))
def visit_fromclause_canonical_column_collection(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
for lcol, rcol in util.zip_longest(left, right, fillvalue=None):
self.stack.append((lcol, rcol))
def visit_fromclause_derived_column_collection(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
pass
def visit_string_clauseelement_dict(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
for lstr, rstr in util.zip_longest(
sorted(left), sorted(right), fillvalue=None
@@ -787,7 +885,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
self.stack.append((left[lstr], right[rstr]))
def visit_clauseelement_tuples(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
for ltup, rtup in util.zip_longest(left, right, fillvalue=None):
if ltup is None or rtup is None:
@@ -797,7 +895,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
self.stack.append((l, r))
def visit_clauseelement_list(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
for l, r in util.zip_longest(left, right, fillvalue=None):
self.stack.append((l, r))
@@ -815,48 +913,62 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
return len(completed) == len(seq1) == len(seq2)
def visit_clauseelement_unordered_set(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
return self._compare_unordered_sequences(left, right, **kw)
def visit_fromclause_ordered_set(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
for l, r in util.zip_longest(left, right, fillvalue=None):
self.stack.append((l, r))
- def visit_string(self, left_parent, left, right_parent, right, **kw):
+ def visit_string(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
return left == right
- def visit_string_list(self, left_parent, left, right_parent, right, **kw):
+ def visit_string_list(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
return left == right
- def visit_anon_name(self, left_parent, left, right_parent, right, **kw):
+ def visit_anon_name(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
return _resolve_name_for_compare(
left_parent, left, self.anon_map[0], **kw
) == _resolve_name_for_compare(
right_parent, right, self.anon_map[1], **kw
)
- def visit_boolean(self, left_parent, left, right_parent, right, **kw):
+ def visit_boolean(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
return left == right
- def visit_operator(self, left_parent, left, right_parent, right, **kw):
+ def visit_operator(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
return left is right
- def visit_type(self, left_parent, left, right_parent, right, **kw):
+ def visit_type(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
return left._compare_type_affinity(right)
- def visit_plain_dict(self, left_parent, left, right_parent, right, **kw):
+ def visit_plain_dict(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
return left == right
def visit_dialect_options(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
return left == right
def visit_annotations_key(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
if left and right:
return (
@@ -866,11 +978,13 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
else:
return left == right
- def visit_plain_obj(self, left_parent, left, right_parent, right, **kw):
+ def visit_plain_obj(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
return left == right
def visit_named_ddl_element(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
if left is None:
if right is not None:
@@ -879,7 +993,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
return left.name == right.name
def visit_prefix_sequence(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
for (l_clause, l_str), (r_clause, r_str) in util.zip_longest(
left, right, fillvalue=(None, None)
@@ -889,8 +1003,22 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
else:
self.stack.append((l_clause, r_clause))
+ def visit_setup_join_tuple(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ # TODO: look at attrname for "legacy_join" and use different structure
+ for (
+ (l_target, l_onclause, l_from, l_flags),
+ (r_target, r_onclause, r_from, r_flags),
+ ) in util.zip_longest(left, right, fillvalue=(None, None, None, None)):
+ if l_flags != r_flags:
+ return COMPARE_FAILED
+ self.stack.append((l_target, r_target))
+ self.stack.append((l_onclause, r_onclause))
+ self.stack.append((l_from, r_from))
+
def visit_table_hint_list(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
left_keys = sorted(left, key=lambda elem: (elem[0].fullname, elem[1]))
right_keys = sorted(
@@ -907,17 +1035,17 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
self.stack.append((ltable, rtable))
def visit_statement_hint_list(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
return left == right
def visit_unknown_structure(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
raise NotImplementedError()
def visit_dml_ordered_values(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
# sequence of tuple pairs
@@ -941,7 +1069,9 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
return True
- def visit_dml_values(self, left_parent, left, right_parent, right, **kw):
+ def visit_dml_values(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
if left is None or right is None or len(left) != len(right):
return COMPARE_FAILED
@@ -961,7 +1091,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
return COMPARE_FAILED
def visit_dml_multi_values(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
for lseq, rseq in util.zip_longest(left, right, fillvalue=None):
if lseq is None or rseq is None:
@@ -970,7 +1100,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
for ld, rd in util.zip_longest(lseq, rseq, fillvalue=None):
if (
self.visit_dml_values(
- left_parent, ld, right_parent, rd, **kw
+ attrname, left_parent, ld, right_parent, rd, **kw
)
is COMPARE_FAILED
):
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 0a67ff9bf..377aa4fe0 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -37,6 +37,7 @@ from .selectable import Join
from .selectable import ScalarSelect
from .selectable import SelectBase
from .selectable import TableClause
+from .traversals import HasCacheKey # noqa
from .. import exc
from .. import util
@@ -921,6 +922,14 @@ class ColumnAdapter(ClauseAdapter):
adapt_clause = traverse
adapt_list = ClauseAdapter.copy_and_process
+ def adapt_check_present(self, col):
+ newcol = self.columns[col]
+
+ if newcol is col and self._corresponding_column(col, True) is None:
+ return None
+
+ return newcol
+
def _locate_col(self, col):
c = ClauseAdapter.traverse(self, col)
@@ -945,3 +954,25 @@ class ColumnAdapter(ClauseAdapter):
def __setstate__(self, state):
self.__dict__.update(state)
self.columns = util.WeakPopulateDict(self._locate_col)
+
+
+def _entity_namespace_key(entity, key):
+ """Return an entry from an entity_namespace.
+
+
+ Raises :class:`_exc.InvalidRequestError` rather than attribute error
+ on not found.
+
+ """
+
+ ns = entity.entity_namespace
+ try:
+ return getattr(ns, key)
+ except AttributeError as err:
+ util.raise_(
+ exc.InvalidRequestError(
+ 'Entity namespace for "%s" has no property "%s"'
+ % (entity, key)
+ ),
+ replace_context=err,
+ )
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 574896cc7..030fd2fde 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -225,6 +225,9 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
dp_has_cache_key = symbol("HC")
"""Visit a :class:`.HasCacheKey` object."""
+ dp_has_cache_key_list = symbol("HL")
+ """Visit a list of :class:`.HasCacheKey` objects."""
+
dp_clauseelement = symbol("CE")
"""Visit a :class:`_expression.ClauseElement` object."""
@@ -372,6 +375,8 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
"""
+ dp_setup_join_tuple = symbol("SJ")
+
dp_statement_hint_list = symbol("SH")
"""Visit the ``_statement_hints`` collection of a
:class:`_expression.Select`
@@ -437,9 +442,6 @@ class ExtendedInternalTraversal(InternalTraversal):
"""
- dp_has_cache_key_list = symbol("HL")
- """Visit a list of :class:`.HasCacheKey` objects."""
-
dp_inspectable_list = symbol("IL")
"""Visit a list of inspectable objects which upon inspection are
HasCacheKey objects."""