summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/annotation.py18
-rw-r--r--lib/sqlalchemy/sql/base.py128
-rw-r--r--lib/sqlalchemy/sql/coercions.py82
-rw-r--r--lib/sqlalchemy/sql/compiler.py10
-rw-r--r--lib/sqlalchemy/sql/elements.py50
-rw-r--r--lib/sqlalchemy/sql/roles.py19
-rw-r--r--lib/sqlalchemy/sql/schema.py5
-rw-r--r--lib/sqlalchemy/sql/selectable.py453
-rw-r--r--lib/sqlalchemy/sql/traversals.py216
-rw-r--r--lib/sqlalchemy/sql/util.py31
-rw-r--r--lib/sqlalchemy/sql/visitors.py8
11 files changed, 837 insertions, 183 deletions
diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py
index 891b8ae09..71d05f38f 100644
--- a/lib/sqlalchemy/sql/annotation.py
+++ b/lib/sqlalchemy/sql/annotation.py
@@ -31,7 +31,10 @@ class SupportsAnnotations(object):
if isinstance(value, HasCacheKey)
else value,
)
- for key, value in self._annotations.items()
+ for key, value in [
+ (key, self._annotations[key])
+ for key in sorted(self._annotations)
+ ]
),
)
@@ -51,6 +54,7 @@ class SupportsCloneAnnotations(SupportsAnnotations):
new = self._clone()
new._annotations = new._annotations.union(values)
new.__dict__.pop("_annotations_cache_key", None)
+ new.__dict__.pop("_generate_cache_key", None)
return new
def _with_annotations(self, values):
@@ -61,6 +65,7 @@ class SupportsCloneAnnotations(SupportsAnnotations):
new = self._clone()
new._annotations = util.immutabledict(values)
new.__dict__.pop("_annotations_cache_key", None)
+ new.__dict__.pop("_generate_cache_key", None)
return new
def _deannotate(self, values=None, clone=False):
@@ -76,7 +81,7 @@ class SupportsCloneAnnotations(SupportsAnnotations):
# clone is used when we are also copying
# the expression for a deep deannotation
new = self._clone()
- new._annotations = {}
+ new._annotations = util.immutabledict()
new.__dict__.pop("_annotations_cache_key", None)
return new
else:
@@ -156,6 +161,7 @@ class Annotated(object):
def __init__(self, element, values):
self.__dict__ = element.__dict__.copy()
self.__dict__.pop("_annotations_cache_key", None)
+ self.__dict__.pop("_generate_cache_key", None)
self.__element = element
self._annotations = values
self._hash = hash(element)
@@ -169,6 +175,7 @@ class Annotated(object):
clone = self.__class__.__new__(self.__class__)
clone.__dict__ = self.__dict__.copy()
clone.__dict__.pop("_annotations_cache_key", None)
+ clone.__dict__.pop("_generate_cache_key", None)
clone._annotations = values
return clone
@@ -211,6 +218,13 @@ class Annotated(object):
else:
return hash(other) == hash(self)
+ @property
+ def entity_namespace(self):
+ if "entity_namespace" in self._annotations:
+ return self._annotations["entity_namespace"].entity_namespace
+ else:
+ return self.__element.entity_namespace
+
# hard-generate Annotated subclasses. this technique
# is used instead of on-the-fly types (i.e. type.__new__())
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 2d023c6a6..04cc34480 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -15,11 +15,14 @@ import operator
import re
from .traversals import HasCacheKey # noqa
+from .traversals import MemoizedHasCacheKey # noqa
from .visitors import ClauseVisitor
+from .visitors import ExtendedInternalTraversal
from .visitors import InternalTraversal
from .. import exc
from .. import util
from ..util import HasMemoized
+from ..util import hybridmethod
if util.TYPE_CHECKING:
from types import ModuleType
@@ -433,22 +436,52 @@ class CompileState(object):
__slots__ = ("statement",)
+ plugins = {}
+
@classmethod
def _create(cls, statement, compiler, **kw):
# factory construction.
- # specific CompileState classes here will look for
- # "plugins" in the given statement. From there they will invoke
- # the appropriate plugin constructor if one is found and return
- # the alternate CompileState object.
+ if statement._compile_state_plugin is not None:
+ constructor = cls.plugins.get(
+ (
+ statement._compile_state_plugin,
+ statement.__visit_name__,
+ None,
+ ),
+ cls,
+ )
+ else:
+ constructor = cls
- c = cls.__new__(cls)
- c.__init__(statement, compiler, **kw)
- return c
+ return constructor(statement, compiler, **kw)
def __init__(self, statement, compiler, **kw):
self.statement = statement
+ @classmethod
+ def get_plugin_classmethod(cls, statement, name):
+ if statement._compile_state_plugin is not None:
+ fn = cls.plugins.get(
+ (
+ statement._compile_state_plugin,
+ statement.__visit_name__,
+ name,
+ ),
+ None,
+ )
+ if fn is not None:
+ return fn
+ return getattr(cls, name)
+
+ @classmethod
+ def plugin_for(cls, plugin_name, visit_name, method_name=None):
+ def decorate(fn):
+ cls.plugins[(plugin_name, visit_name, method_name)] = fn
+ return fn
+
+ return decorate
+
class Generative(HasMemoized):
"""Provide a method-chaining pattern in conjunction with the
@@ -479,6 +512,57 @@ class HasCompileState(Generative):
_compile_state_plugin = None
+ _attributes = util.immutabledict()
+
+
+class _MetaOptions(type):
+ """metaclass for the Options class."""
+
+ def __init__(cls, classname, bases, dict_):
+ cls._cache_attrs = tuple(
+ sorted(d for d in dict_ if not d.startswith("__"))
+ )
+ type.__init__(cls, classname, bases, dict_)
+
+ def __add__(self, other):
+ o1 = self()
+ o1.__dict__.update(other)
+ return o1
+
+
+class Options(util.with_metaclass(_MetaOptions)):
+ """A cacheable option dictionary with defaults.
+
+
+ """
+
+ def __init__(self, **kw):
+ self.__dict__.update(kw)
+
+ def __add__(self, other):
+ o1 = self.__class__.__new__(self.__class__)
+ o1.__dict__.update(self.__dict__)
+ o1.__dict__.update(other)
+ return o1
+
+ @hybridmethod
+ def add_to_element(self, name, value):
+ return self + {name: getattr(self, name) + value}
+
+
+class CacheableOptions(Options, HasCacheKey):
+ @hybridmethod
+ def _gen_cache_key(self, anon_map, bindparams):
+ return HasCacheKey._gen_cache_key(self, anon_map, bindparams)
+
+ @_gen_cache_key.classlevel
+ def _gen_cache_key(cls, anon_map, bindparams):
+ return (cls, ())
+
+ @hybridmethod
+ def _generate_cache_key(self):
+ return HasCacheKey._generate_cache_key_for_object(self)
+
class Executable(Generative):
"""Mark a ClauseElement as supporting execution.
@@ -492,7 +576,21 @@ class Executable(Generative):
supports_execution = True
_execution_options = util.immutabledict()
_bind = None
+ _with_options = ()
+ _with_context_options = ()
+ _cache_enable = True
+
+ _executable_traverse_internals = [
+ ("_with_options", ExtendedInternalTraversal.dp_has_cache_key_list),
+ ("_with_context_options", ExtendedInternalTraversal.dp_plain_obj),
+ ("_cache_enable", ExtendedInternalTraversal.dp_plain_obj),
+ ]
+
+ @_generative
+ def _disable_caching(self):
+ self._cache_enable = HasCacheKey()
+ @_generative
def options(self, *options):
"""Apply options to this statement.
@@ -522,7 +620,21 @@ class Executable(Generative):
to the usage of ORM queries
"""
- self._options += options
+ self._with_options += options
+
+ @_generative
+ def _add_context_option(self, callable_, cache_args):
+ """Add a context option to this statement.
+
+ These are callable functions that will
+ be given the CompileState object upon compilation.
+
+ A second argument cache_args is required, which will be combined
+ with the identity of the function itself in order to produce a
+ cache key.
+
+ """
+ self._with_context_options += ((callable_, cache_args),)
@_generative
def execution_options(self, **kw):
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py
index c7b85c415..2fc63b82f 100644
--- a/lib/sqlalchemy/sql/coercions.py
+++ b/lib/sqlalchemy/sql/coercions.py
@@ -57,7 +57,7 @@ def expect(role, element, **kw):
if not isinstance(
element,
- (elements.ClauseElement, schema.SchemaItem, schema.FetchedValue),
+ (elements.ClauseElement, schema.SchemaItem, schema.FetchedValue,),
):
resolved = impl._resolve_for_clause_element(element, **kw)
else:
@@ -106,7 +106,9 @@ class RoleImpl(object):
self.name = role_class._role_name
self._use_inspection = issubclass(role_class, roles.UsesInspection)
- def _resolve_for_clause_element(self, element, argname=None, **kw):
+ def _resolve_for_clause_element(
+ self, element, argname=None, apply_plugins=None, **kw
+ ):
original_element = element
is_clause_element = False
@@ -115,18 +117,39 @@ class RoleImpl(object):
if not getattr(element, "is_clause_element", False):
element = element.__clause_element__()
else:
- return element
+ break
+
+ should_apply_plugins = (
+ apply_plugins is not None
+ and apply_plugins._compile_state_plugin is None
+ )
if is_clause_element:
+ if (
+ should_apply_plugins
+ and "compile_state_plugin" in element._annotations
+ ):
+ apply_plugins._compile_state_plugin = element._annotations[
+ "compile_state_plugin"
+ ]
return element
if self._use_inspection:
insp = inspection.inspect(element, raiseerr=False)
if insp is not None:
+ insp._post_inspect
try:
- return insp.__clause_element__()
+ element = insp.__clause_element__()
except AttributeError:
self._raise_for_expected(original_element, argname)
+ else:
+ if (
+ should_apply_plugins
+ and "compile_state_plugin" in element._annotations
+ ):
+ plugin = element._annotations["compile_state_plugin"]
+ apply_plugins._compile_state_plugin = plugin
+ return element
return self._literal_coercion(element, argname=argname, **kw)
@@ -287,7 +310,7 @@ class _SelectIsNotFrom(object):
advice = (
"To create a "
"FROM clause from a %s object, use the .subquery() method."
- % (element.__class__)
+ % (element.__class__,)
)
code = "89ve"
else:
@@ -453,6 +476,18 @@ class OrderByImpl(ByOfImpl, RoleImpl):
return resolved
+class GroupByImpl(ByOfImpl, RoleImpl):
+ __slots__ = ()
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, **kw
+ ):
+ if isinstance(resolved, roles.StrictFromClauseRole):
+ return elements.ClauseList(*resolved.c)
+ else:
+ return resolved
+
+
class DMLColumnImpl(_ReturnsStringKey, RoleImpl):
__slots__ = ()
@@ -618,6 +653,37 @@ class HasCTEImpl(ReturnsRowsImpl, roles.HasCTERole):
pass
+class JoinTargetImpl(RoleImpl):
+ __slots__ = ()
+
+ def _literal_coercion(self, element, legacy=False, **kw):
+ if isinstance(element, str):
+ return element
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, legacy=False, **kw
+ ):
+ if isinstance(original_element, roles.JoinTargetRole):
+ return original_element
+ elif legacy and isinstance(resolved, (str, roles.WhereHavingRole)):
+ return resolved
+ elif legacy and resolved._is_select_statement:
+ util.warn_deprecated(
+ "Implicit coercion of SELECT and textual SELECT "
+ "constructs into FROM clauses is deprecated; please call "
+ ".subquery() on any Core select or ORM Query object in "
+ "order to produce a subquery object.",
+ version="1.4",
+ )
+ # TODO: doing _implicit_subquery here causes tests to fail,
+ # how was this working before? probably that ORM
+ # join logic treated it as a select and subquery would happen
+ # in _ORMJoin->Join
+ return resolved
+ else:
+ self._raise_for_expected(original_element, argname, resolved)
+
+
class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl):
__slots__ = ()
@@ -647,6 +713,12 @@ class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl):
else:
self._raise_for_expected(original_element, argname, resolved)
+ def _post_coercion(self, element, deannotate=False, **kw):
+ if deannotate:
+ return element._deannotate()
+ else:
+ return element
+
class StrictFromClauseImpl(FromClauseImpl):
__slots__ = ()
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index ccc1b53fe..9a7646743 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -653,6 +653,12 @@ class SQLCompiler(Compiled):
"""
+ compile_state_factories = util.immutabledict()
+ """Dictionary of alternate :class:`.CompileState` factories for given
+ classes, identified by their visit_name.
+
+ """
+
def __init__(
self,
dialect,
@@ -661,6 +667,7 @@ class SQLCompiler(Compiled):
column_keys=None,
inline=False,
linting=NO_LINTING,
+ compile_state_factories=None,
**kwargs
):
"""Construct a new :class:`.SQLCompiler` object.
@@ -727,6 +734,9 @@ class SQLCompiler(Compiled):
# dialect.label_length or dialect.max_identifier_length
self.truncated_names = {}
+ if compile_state_factories:
+ self.compile_state_factories = compile_state_factories
+
Compiled.__init__(self, dialect, statement, **kwargs)
if (
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 7310edd3f..c1bc9edbc 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -26,7 +26,6 @@ from .annotation import SupportsWrappingAnnotations
from .base import _clone
from .base import _generative
from .base import Executable
-from .base import HasCacheKey
from .base import HasMemoized
from .base import Immutable
from .base import NO_ARG
@@ -35,6 +34,7 @@ from .base import SingletonConstant
from .coercions import _document_text_coercion
from .traversals import _copy_internals
from .traversals import _get_children
+from .traversals import MemoizedHasCacheKey
from .traversals import NO_CACHE
from .visitors import cloned_traverse
from .visitors import InternalTraversal
@@ -179,7 +179,10 @@ def not_(clause):
@inspection._self_inspects
class ClauseElement(
- roles.SQLRole, SupportsWrappingAnnotations, HasCacheKey, Traversible,
+ roles.SQLRole,
+ SupportsWrappingAnnotations,
+ MemoizedHasCacheKey,
+ Traversible,
):
"""Base class for elements of a programmatically constructed SQL
expression.
@@ -206,6 +209,7 @@ class ClauseElement(
_is_select_container = False
_is_select_statement = False
_is_bind_parameter = False
+ _is_clause_list = False
_order_by_label_element = None
@@ -300,7 +304,7 @@ class ClauseElement(
used.
"""
- return self._params(True, optionaldict, kwargs)
+ return self._replace_params(True, optionaldict, kwargs)
def params(self, *optionaldict, **kwargs):
"""Return a copy with :func:`bindparam()` elements replaced.
@@ -315,9 +319,9 @@ class ClauseElement(
{'foo':7}
"""
- return self._params(False, optionaldict, kwargs)
+ return self._replace_params(False, optionaldict, kwargs)
- def _params(self, unique, optionaldict, kwargs):
+ def _replace_params(self, unique, optionaldict, kwargs):
if len(optionaldict) == 1:
kwargs.update(optionaldict[0])
elif len(optionaldict) > 1:
@@ -371,7 +375,7 @@ class ClauseElement(
continue
if obj is not None:
- result = meth(self, obj, **kw)
+ result = meth(self, attrname, obj, **kw)
if result is not None:
setattr(self, attrname, result)
@@ -2070,6 +2074,8 @@ class ClauseList(
__visit_name__ = "clauselist"
+ _is_clause_list = True
+
_traverse_internals = [
("clauses", InternalTraversal.dp_clauseelement_list),
("operator", InternalTraversal.dp_operator),
@@ -2079,6 +2085,8 @@ class ClauseList(
self.operator = kwargs.pop("operator", operators.comma_op)
self.group = kwargs.pop("group", True)
self.group_contents = kwargs.pop("group_contents", True)
+ if kwargs.pop("_flatten_sub_clauses", False):
+ clauses = util.flatten_iterator(clauses)
self._tuple_values = kwargs.pop("_tuple_values", False)
self._text_converter_role = text_converter_role = kwargs.pop(
"_literal_as_text_role", roles.WhereHavingRole
@@ -2116,7 +2124,9 @@ class ClauseList(
@property
def _select_iterable(self):
- return iter(self)
+ return itertools.chain.from_iterable(
+ [elem._select_iterable for elem in self.clauses]
+ )
def append(self, clause):
if self.group_contents:
@@ -2224,6 +2234,32 @@ class BooleanClauseList(ClauseList, ColumnElement):
return cls._construct_raw(operator)
@classmethod
+ def _construct_for_whereclause(cls, clauses):
+ operator, continue_on, skip_on = (
+ operators.and_,
+ True_._singleton,
+ False_._singleton,
+ )
+
+ lcc, convert_clauses = cls._process_clauses_for_boolean(
+ operator,
+ continue_on,
+ skip_on,
+ clauses, # these are assumed to be coerced already
+ )
+
+ if lcc > 1:
+ # multiple elements. Return regular BooleanClauseList
+ # which will link elements against the operator.
+ return cls._construct_raw(operator, convert_clauses)
+ elif lcc == 1:
+ # just one element. return it as a single boolean element,
+ # not a list and discard the operator.
+ return convert_clauses[0]
+ else:
+ return None
+
+ @classmethod
def _construct_raw(cls, operator, clauses=None):
self = cls.__new__(cls)
self.clauses = clauses if clauses else []
diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py
index 72a0bdc95..b861f721b 100644
--- a/lib/sqlalchemy/sql/roles.py
+++ b/lib/sqlalchemy/sql/roles.py
@@ -19,7 +19,7 @@ class SQLRole(object):
class UsesInspection(object):
- pass
+ _post_inspect = None
class ColumnArgumentRole(SQLRole):
@@ -54,6 +54,14 @@ class ByOfRole(ColumnListRole):
_role_name = "GROUP BY / OF / etc. expression"
+class GroupByRole(UsesInspection, ByOfRole):
+ # note there's a special case right now where you can pass a whole
+ # ORM entity to group_by() and it splits out. we may not want to keep
+ # this around
+
+ _role_name = "GROUP BY expression"
+
+
class OrderByRole(ByOfRole):
_role_name = "ORDER BY expression"
@@ -92,7 +100,14 @@ class InElementRole(SQLRole):
)
-class FromClauseRole(ColumnsClauseRole):
+class JoinTargetRole(UsesInspection, StructuralRole):
+ _role_name = (
+ "Join target, typically a FROM expression, or ORM "
+ "relationship attribute"
+ )
+
+
+class FromClauseRole(ColumnsClauseRole, JoinTargetRole):
_role_name = "FROM expression, such as a Table or alias() object"
_is_subquery = False
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index 689eda11d..65f8bd81c 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -477,7 +477,10 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
]
def _gen_cache_key(self, anon_map, bindparams):
- return (self,) + self._annotations_cache_key
+ if self._annotations:
+ return (self,) + self._annotations_cache_key
+ else:
+ return (self,)
def __new__(cls, *args, **kw):
if not args:
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 85abbb5e0..6a552c18c 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -28,6 +28,7 @@ from .base import _expand_cloned
from .base import _from_objects
from .base import _generative
from .base import _select_iterables
+from .base import CacheableOptions
from .base import ColumnCollection
from .base import ColumnSet
from .base import CompileState
@@ -42,6 +43,7 @@ from .coercions import _document_text_coercion
from .elements import _anonymous_label
from .elements import and_
from .elements import BindParameter
+from .elements import BooleanClauseList
from .elements import ClauseElement
from .elements import ClauseList
from .elements import ColumnClause
@@ -339,6 +341,90 @@ class HasSuffixes(object):
)
+class HasHints(object):
+ _hints = util.immutabledict()
+ _statement_hints = ()
+
+ _has_hints_traverse_internals = [
+ ("_statement_hints", InternalTraversal.dp_statement_hint_list),
+ ("_hints", InternalTraversal.dp_table_hint_list),
+ ]
+
+ def with_statement_hint(self, text, dialect_name="*"):
+ """add a statement hint to this :class:`_expression.Select` or
+ other selectable object.
+
+ This method is similar to :meth:`_expression.Select.with_hint`
+ except that
+ it does not require an individual table, and instead applies to the
+ statement as a whole.
+
+ Hints here are specific to the backend database and may include
+ directives such as isolation levels, file directives, fetch directives,
+ etc.
+
+ .. versionadded:: 1.0.0
+
+ .. seealso::
+
+ :meth:`_expression.Select.with_hint`
+
+ :meth:.`.Select.prefix_with` - generic SELECT prefixing which also
+ can suit some database-specific HINT syntaxes such as MySQL
+ optimizer hints
+
+ """
+ return self.with_hint(None, text, dialect_name)
+
+ @_generative
+ def with_hint(self, selectable, text, dialect_name="*"):
+ r"""Add an indexing or other executional context hint for the given
+ selectable to this :class:`_expression.Select` or other selectable
+ object.
+
+ The text of the hint is rendered in the appropriate
+ location for the database backend in use, relative
+ to the given :class:`_schema.Table` or :class:`_expression.Alias`
+ passed as the
+ ``selectable`` argument. The dialect implementation
+ typically uses Python string substitution syntax
+ with the token ``%(name)s`` to render the name of
+ the table or alias. E.g. when using Oracle, the
+ following::
+
+ select([mytable]).\
+ with_hint(mytable, "index(%(name)s ix_mytable)")
+
+ Would render SQL as::
+
+ select /*+ index(mytable ix_mytable) */ ... from mytable
+
+ The ``dialect_name`` option will limit the rendering of a particular
+ hint to a particular backend. Such as, to add hints for both Oracle
+ and Sybase simultaneously::
+
+ select([mytable]).\
+ with_hint(mytable, "index(%(name)s ix_mytable)", 'oracle').\
+ with_hint(mytable, "WITH INDEX ix_mytable", 'sybase')
+
+ .. seealso::
+
+ :meth:`_expression.Select.with_statement_hint`
+
+ """
+ if selectable is None:
+ self._statement_hints += ((dialect_name, text),)
+ else:
+ self._hints = self._hints.union(
+ {
+ (
+ coercions.expect(roles.FromClauseRole, selectable),
+ dialect_name,
+ ): text
+ }
+ )
+
+
class FromClause(roles.AnonymizedFromClauseRole, Selectable):
"""Represent an element that can be used within the ``FROM``
clause of a ``SELECT`` statement.
@@ -597,6 +683,22 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
self._populate_column_collection()
return self._columns.as_immutable()
+ @property
+ def entity_namespace(self):
+ """Return a namespace used for name-based access in SQL expressions.
+
+ This is the namespace that is used to resolve "filter_by()" type
+ expressions, such as::
+
+ stmt.filter_by(address='some address')
+
+ It defaults to the .c collection, however internally it can
+ be overridden using the "entity_namespace" annotation to deliver
+ alternative results.
+
+ """
+ return self.columns
+
@util.memoized_property
def primary_key(self):
"""Return the collection of Column objects which comprise the
@@ -727,13 +829,21 @@ class Join(FromClause):
:class:`_expression.FromClause` object.
"""
- self.left = coercions.expect(roles.FromClauseRole, left)
- self.right = coercions.expect(roles.FromClauseRole, right).self_group()
+ self.left = coercions.expect(
+ roles.FromClauseRole, left, deannotate=True
+ )
+ self.right = coercions.expect(
+ roles.FromClauseRole, right, deannotate=True
+ ).self_group()
if onclause is None:
self.onclause = self._match_primaries(self.left, self.right)
else:
- self.onclause = onclause.self_group(against=operators._asbool)
+ # note: taken from If91f61527236fd4d7ae3cad1f24c38be921c90ba
+ # not merged yet
+ self.onclause = coercions.expect(
+ roles.WhereHavingRole, onclause
+ ).self_group(against=operators._asbool)
self.isouter = isouter
self.full = full
@@ -1963,6 +2073,12 @@ class TableClause(Immutable, FromClause):
if kw:
raise exc.ArgumentError("Unsupported argument(s): %s" % list(kw))
+ def __str__(self):
+ if self.schema is not None:
+ return self.schema + "." + self.name
+ else:
+ return self.name
+
def _refresh_for_new_column(self, column):
pass
@@ -2905,7 +3021,8 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase):
self._group_by_clauses = ()
else:
self._group_by_clauses += tuple(
- coercions.expect(roles.ByOfRole, clause) for clause in clauses
+ coercions.expect(roles.GroupByRole, clause)
+ for clause in clauses
)
@@ -3309,8 +3426,16 @@ class DeprecatedSelectGenerations(object):
class SelectState(CompileState):
+ class default_select_compile_options(CacheableOptions):
+ _cache_key_traversal = []
+
def __init__(self, statement, compiler, **kw):
self.statement = statement
+ self.from_clauses = statement._from_obj
+
+ if statement._setup_joins:
+ self._setup_joins(statement._setup_joins)
+
self.froms = self._get_froms(statement)
self.columns_plus_names = statement._generate_columns_plus_names(True)
@@ -3319,7 +3444,18 @@ class SelectState(CompileState):
froms = []
seen = set()
- for item in statement._iterate_from_elements():
+ for item in itertools.chain(
+ itertools.chain.from_iterable(
+ [element._from_objects for element in statement._raw_columns]
+ ),
+ itertools.chain.from_iterable(
+ [
+ element._from_objects
+ for element in statement._where_criteria
+ ]
+ ),
+ self.from_clauses,
+ ):
if item._is_subquery and item.element is statement:
raise exc.InvalidRequestError(
"select() construct refers to itself as a FROM"
@@ -3341,6 +3477,7 @@ class SelectState(CompileState):
correlating.
"""
+
froms = self.froms
toremove = set(
@@ -3425,10 +3562,162 @@ class SelectState(CompileState):
return with_cols, only_froms, only_cols
+ @classmethod
+ def determine_last_joined_entity(cls, stmt):
+ if stmt._setup_joins:
+ return stmt._setup_joins[-1][0]
+ else:
+ return None
+
+ def _setup_joins(self, args):
+ for (right, onclause, left, flags) in args:
+ isouter = flags["isouter"]
+ full = flags["full"]
+
+ if left is None:
+ (
+ left,
+ replace_from_obj_index,
+ ) = self._join_determine_implicit_left_side(
+ left, right, onclause
+ )
+ else:
+ (replace_from_obj_index) = self._join_place_explicit_left_side(
+ left
+ )
+
+ if replace_from_obj_index is not None:
+ # splice into an existing element in the
+ # self._from_obj list
+ left_clause = self.from_clauses[replace_from_obj_index]
+
+ self.from_clauses = (
+ self.from_clauses[:replace_from_obj_index]
+ + (
+ Join(
+ left_clause,
+ right,
+ onclause,
+ isouter=isouter,
+ full=full,
+ ),
+ )
+ + self.from_clauses[replace_from_obj_index + 1 :]
+ )
+ else:
+
+ self.from_clauses = self.from_clauses + (
+ Join(left, right, onclause, isouter=isouter, full=full,),
+ )
+
+ @util.preload_module("sqlalchemy.sql.util")
+ def _join_determine_implicit_left_side(self, left, right, onclause):
+ """When join conditions don't express the left side explicitly,
+ determine if an existing FROM or entity in this query
+ can serve as the left hand side.
+
+ """
+
+ sql_util = util.preloaded.sql_util
+
+ replace_from_obj_index = None
+
+ from_clauses = self.statement._from_obj
+
+ if from_clauses:
+
+ indexes = sql_util.find_left_clause_to_join_from(
+ from_clauses, right, onclause
+ )
+
+ if len(indexes) == 1:
+ replace_from_obj_index = indexes[0]
+ left = from_clauses[replace_from_obj_index]
+ else:
+ potential = {}
+ statement = self.statement
+
+ for from_clause in itertools.chain(
+ itertools.chain.from_iterable(
+ [
+ element._from_objects
+ for element in statement._raw_columns
+ ]
+ ),
+ itertools.chain.from_iterable(
+ [
+ element._from_objects
+ for element in statement._where_criteria
+ ]
+ ),
+ ):
+
+ potential[from_clause] = ()
+
+ all_clauses = list(potential.keys())
+ indexes = sql_util.find_left_clause_to_join_from(
+ all_clauses, right, onclause
+ )
+
+ if len(indexes) == 1:
+ left = all_clauses[indexes[0]]
+
+ if len(indexes) > 1:
+ raise exc.InvalidRequestError(
+ "Can't determine which FROM clause to join "
+ "from, there are multiple FROMS which can "
+ "join to this entity. Please use the .select_from() "
+ "method to establish an explicit left side, as well as "
+ "providing an explcit ON clause if not present already to "
+ "help resolve the ambiguity."
+ )
+ elif not indexes:
+ raise exc.InvalidRequestError(
+ "Don't know how to join to %r. "
+ "Please use the .select_from() "
+ "method to establish an explicit left side, as well as "
+ "providing an explcit ON clause if not present already to "
+ "help resolve the ambiguity." % (right,)
+ )
+ return left, replace_from_obj_index
+
+ @util.preload_module("sqlalchemy.sql.util")
+ def _join_place_explicit_left_side(self, left):
+ replace_from_obj_index = None
+
+ sql_util = util.preloaded.sql_util
+
+ from_clauses = list(self.statement._iterate_from_elements())
+
+ if from_clauses:
+ indexes = sql_util.find_left_clause_that_matches_given(
+ self.from_clauses, left
+ )
+ else:
+ indexes = []
+
+ if len(indexes) > 1:
+ raise exc.InvalidRequestError(
+ "Can't identify which entity in which to assign the "
+ "left side of this join. Please use a more specific "
+ "ON clause."
+ )
+
+ # have an index, means the left side is already present in
+ # an existing FROM in the self._from_obj tuple
+ if indexes:
+ replace_from_obj_index = indexes[0]
+
+ # no index, means we need to add a new element to the
+ # self._from_obj tuple
+
+ return replace_from_obj_index
+
class Select(
HasPrefixes,
HasSuffixes,
+ HasHints,
HasCompileState,
DeprecatedSelectGenerations,
GenerativeSelect,
@@ -3440,9 +3729,10 @@ class Select(
__visit_name__ = "select"
_compile_state_factory = SelectState._create
+ _is_future = False
+ _setup_joins = ()
+ _legacy_setup_joins = ()
- _hints = util.immutabledict()
- _statement_hints = ()
_distinct = False
_distinct_on = ()
_correlate = ()
@@ -3452,6 +3742,8 @@ class Select(
_from_obj = ()
_auto_correlate = True
+ compile_options = SelectState.default_select_compile_options
+
_traverse_internals = (
[
("_raw_columns", InternalTraversal.dp_clauseelement_list),
@@ -3460,58 +3752,33 @@ class Select(
("_having_criteria", InternalTraversal.dp_clauseelement_list),
("_order_by_clauses", InternalTraversal.dp_clauseelement_list,),
("_group_by_clauses", InternalTraversal.dp_clauseelement_list,),
+ ("_setup_joins", InternalTraversal.dp_setup_join_tuple,),
+ ("_legacy_setup_joins", InternalTraversal.dp_setup_join_tuple,),
("_correlate", InternalTraversal.dp_clauseelement_unordered_set),
(
"_correlate_except",
InternalTraversal.dp_clauseelement_unordered_set,
),
("_for_update_arg", InternalTraversal.dp_clauseelement),
- ("_statement_hints", InternalTraversal.dp_statement_hint_list),
- ("_hints", InternalTraversal.dp_table_hint_list),
("_distinct", InternalTraversal.dp_boolean),
("_distinct_on", InternalTraversal.dp_clauseelement_list),
("_label_style", InternalTraversal.dp_plain_obj),
]
+ HasPrefixes._has_prefixes_traverse_internals
+ HasSuffixes._has_suffixes_traverse_internals
+ + HasHints._has_hints_traverse_internals
+ SupportsCloneAnnotations._clone_annotations_traverse_internals
+ + Executable._executable_traverse_internals
)
+ _cache_key_traversal = _traverse_internals + [
+ ("compile_options", InternalTraversal.dp_has_cache_key)
+ ]
+
@classmethod
def _create_select(cls, *entities):
- r"""Construct a new :class:`_expression.Select` using the 2.
- x style API.
-
- .. versionadded:: 2.0 - the :func:`_future.select` construct is
- the same construct as the one returned by
- :func:`_expression.select`, except that the function only
- accepts the "columns clause" entities up front; the rest of the
- state of the SELECT should be built up using generative methods.
-
- Similar functionality is also available via the
- :meth:`_expression.FromClause.select` method on any
- :class:`_expression.FromClause`.
-
- .. seealso::
-
- :ref:`coretutorial_selecting` - Core Tutorial description of
- :func:`_expression.select`.
-
- :param \*entities:
- Entities to SELECT from. For Core usage, this is typically a series
- of :class:`_expression.ColumnElement` and / or
- :class:`_expression.FromClause`
- objects which will form the columns clause of the resulting
- statement. For those objects that are instances of
- :class:`_expression.FromClause` (typically :class:`_schema.Table`
- or :class:`_expression.Alias`
- objects), the :attr:`_expression.FromClause.c`
- collection is extracted
- to form a collection of :class:`_expression.ColumnElement` objects.
-
- This parameter will also accept :class:`_expression.TextClause`
- constructs as
- given, as well as ORM-mapped classes.
+ r"""Construct an old style :class:`_expression.Select` using the
+ the 2.x style constructor.
"""
@@ -3779,7 +4046,10 @@ class Select(
if cols_present:
self._raw_columns = [
- coercions.expect(roles.ColumnsClauseRole, c,) for c in columns
+ coercions.expect(
+ roles.ColumnsClauseRole, c, apply_plugins=self
+ )
+ for c in columns
]
else:
self._raw_columns = []
@@ -3820,71 +4090,6 @@ class Select(
return self._compile_state_factory(self, None)._get_display_froms()
- def with_statement_hint(self, text, dialect_name="*"):
- """add a statement hint to this :class:`_expression.Select`.
-
- This method is similar to :meth:`_expression.Select.with_hint`
- except that
- it does not require an individual table, and instead applies to the
- statement as a whole.
-
- Hints here are specific to the backend database and may include
- directives such as isolation levels, file directives, fetch directives,
- etc.
-
- .. versionadded:: 1.0.0
-
- .. seealso::
-
- :meth:`_expression.Select.with_hint`
-
- :meth:`.Select.prefix_with` - generic SELECT prefixing which also
- can suit some database-specific HINT syntaxes such as MySQL
- optimizer hints
-
- """
- return self.with_hint(None, text, dialect_name)
-
- @_generative
- def with_hint(self, selectable, text, dialect_name="*"):
- r"""Add an indexing or other executional context hint for the given
- selectable to this :class:`_expression.Select`.
-
- The text of the hint is rendered in the appropriate
- location for the database backend in use, relative
- to the given :class:`_schema.Table` or :class:`_expression.Alias`
- passed as the
- ``selectable`` argument. The dialect implementation
- typically uses Python string substitution syntax
- with the token ``%(name)s`` to render the name of
- the table or alias. E.g. when using Oracle, the
- following::
-
- select([mytable]).\
- with_hint(mytable, "index(%(name)s ix_mytable)")
-
- Would render SQL as::
-
- select /*+ index(mytable ix_mytable) */ ... from mytable
-
- The ``dialect_name`` option will limit the rendering of a particular
- hint to a particular backend. Such as, to add hints for both Oracle
- and Sybase simultaneously::
-
- select([mytable]).\
- with_hint(mytable, "index(%(name)s ix_mytable)", 'oracle').\
- with_hint(mytable, "WITH INDEX ix_mytable", 'sybase')
-
- .. seealso::
-
- :meth:`_expression.Select.with_statement_hint`
-
- """
- if selectable is None:
- self._statement_hints += ((dialect_name, text),)
- else:
- self._hints = self._hints.union({(selectable, dialect_name): text})
-
@property
def inner_columns(self):
"""an iterator of all ColumnElement expressions which would
@@ -3921,9 +4126,16 @@ class Select(
_from_objects(*self._where_criteria),
)
)
+
+ # do a clone for the froms we've gathered. what is important here
+ # is if any of the things we are selecting from, like tables,
+ # were converted into Join objects. if so, these need to be
+ # added to _from_obj explicitly, because otherwise they won't be
+ # part of the new state, as they don't associate themselves with
+ # their columns.
new_froms = {f: clone(f, **kw) for f in all_the_froms}
- # 2. copy FROM collections.
+ # 2. copy FROM collections, adding in joins that we've created.
self._from_obj = tuple(clone(f, **kw) for f in self._from_obj) + tuple(
f for f in new_froms.values() if isinstance(f, Join)
)
@@ -3937,6 +4149,10 @@ class Select(
kw["replace"] = replace
+ # copy everything else. for table-ish things like correlate,
+ # correlate_except, setup_joins, these clone normally. For
+ # column-expression oriented things like raw_columns, where_criteria,
+ # order by, we get this from the new froms.
super(Select, self)._copy_internals(
clone=clone, omit_attrs=("_from_obj",), **kw
)
@@ -3975,10 +4191,18 @@ class Select(
self._assert_no_memoizations()
self._raw_columns = self._raw_columns + [
- coercions.expect(roles.ColumnsClauseRole, column,)
+ coercions.expect(
+ roles.ColumnsClauseRole, column, apply_plugins=self
+ )
for column in columns
]
+ def _set_entities(self, entities):
+ self._raw_columns = [
+ coercions.expect(roles.ColumnsClauseRole, ent, apply_plugins=self)
+ for ent in util.to_list(entities)
+ ]
+
@util.deprecated(
"1.4",
"The :meth:`_expression.Select.column` method is deprecated and will "
@@ -4111,6 +4335,7 @@ class Select(
rc = []
for c in columns:
c = coercions.expect(roles.ColumnsClauseRole, c,)
+ # TODO: why are we doing this here?
if isinstance(c, ScalarSelect):
c = c.self_group(against=operators.comma_op)
rc.append(c)
@@ -4121,7 +4346,9 @@ class Select(
"""Legacy, return the WHERE clause as a """
""":class:`_expression.BooleanClauseList`"""
- return and_(*self._where_criteria)
+ return BooleanClauseList._construct_for_whereclause(
+ self._where_criteria
+ )
@_generative
def where(self, whereclause):
@@ -4202,7 +4429,9 @@ class Select(
"""
self._from_obj += tuple(
- coercions.expect(roles.FromClauseRole, fromclause)
+ coercions.expect(
+ roles.FromClauseRole, fromclause, apply_plugins=self
+ )
for fromclause in froms
)
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index 8c63fcba1..a308feb7c 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -29,9 +29,8 @@ def compare(obj1, obj2, **kw):
return strategy.compare(obj1, obj2, **kw)
-class HasCacheKey(HasMemoized):
+class HasCacheKey(object):
_cache_key_traversal = NO_CACHE
-
__slots__ = ()
def _gen_cache_key(self, anon_map, bindparams):
@@ -141,7 +140,6 @@ class HasCacheKey(HasMemoized):
return result
- @HasMemoized.memoized_instancemethod
def _generate_cache_key(self):
"""return a cache key.
@@ -183,6 +181,23 @@ class HasCacheKey(HasMemoized):
else:
return CacheKey(key, bindparams)
+ @classmethod
+ def _generate_cache_key_for_object(cls, obj):
+ bindparams = []
+
+ _anon_map = anon_map()
+ key = obj._gen_cache_key(_anon_map, bindparams)
+ if NO_CACHE in _anon_map:
+ return None
+ else:
+ return CacheKey(key, bindparams)
+
+
+class MemoizedHasCacheKey(HasCacheKey, HasMemoized):
+ @HasMemoized.memoized_instancemethod
+ def _generate_cache_key(self):
+ return HasCacheKey._generate_cache_key(self)
+
class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])):
def __hash__(self):
@@ -191,6 +206,40 @@ class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])):
def __eq__(self, other):
return self.key == other.key
+ def _whats_different(self, other):
+
+ k1 = self.key
+ k2 = other.key
+
+ stack = []
+ pickup_index = 0
+ while True:
+ s1, s2 = k1, k2
+ for idx in stack:
+ s1 = s1[idx]
+ s2 = s2[idx]
+
+ for idx, (e1, e2) in enumerate(util.zip_longest(s1, s2)):
+ if idx < pickup_index:
+ continue
+ if e1 != e2:
+ if isinstance(e1, tuple) and isinstance(e2, tuple):
+ stack.append(idx)
+ break
+ else:
+ yield "key%s[%d]: %s != %s" % (
+ "".join("[%d]" % id_ for id_ in stack),
+ idx,
+ e1,
+ e2,
+ )
+ else:
+ pickup_index = stack.pop(-1)
+ break
+
+ def _diff(self, other):
+ return ", ".join(self._whats_different(other))
+
def __str__(self):
stack = [self.key]
@@ -241,9 +290,7 @@ class _CacheKey(ExtendedInternalTraversal):
visit_type = STATIC_CACHE_KEY
def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams):
- return self.visit_has_cache_key(
- attrname, inspect(obj), parent, anon_map, bindparams
- )
+ return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams))
def visit_string_list(self, attrname, obj, parent, anon_map, bindparams):
return tuple(obj)
@@ -361,6 +408,24 @@ class _CacheKey(ExtendedInternalTraversal):
),
)
+ def visit_setup_join_tuple(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ # TODO: look at attrname for "legacy_join" and use different structure
+ return tuple(
+ (
+ target._gen_cache_key(anon_map, bindparams),
+ onclause._gen_cache_key(anon_map, bindparams)
+ if onclause is not None
+ else None,
+ from_._gen_cache_key(anon_map, bindparams)
+ if from_ is not None
+ else None,
+ tuple([(key, flags[key]) for key in sorted(flags)]),
+ )
+ for (target, onclause, from_, flags) in obj
+ )
+
def visit_table_hint_list(
self, attrname, obj, parent, anon_map, bindparams
):
@@ -498,31 +563,53 @@ class _CopyInternals(InternalTraversal):
"""Generate a _copy_internals internal traversal dispatch for classes
with a _traverse_internals collection."""
- def visit_clauseelement(self, parent, element, clone=_clone, **kw):
+ def visit_clauseelement(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
return clone(element, **kw)
- def visit_clauseelement_list(self, parent, element, clone=_clone, **kw):
+ def visit_clauseelement_list(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
return [clone(clause, **kw) for clause in element]
def visit_clauseelement_unordered_set(
- self, parent, element, clone=_clone, **kw
+ self, attrname, parent, element, clone=_clone, **kw
):
return {clone(clause, **kw) for clause in element}
- def visit_clauseelement_tuples(self, parent, element, clone=_clone, **kw):
+ def visit_clauseelement_tuples(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
return [
tuple(clone(tup_elem, **kw) for tup_elem in elem)
for elem in element
]
def visit_string_clauseelement_dict(
- self, parent, element, clone=_clone, **kw
+ self, attrname, parent, element, clone=_clone, **kw
):
return dict(
(key, clone(value, **kw)) for key, value in element.items()
)
- def visit_dml_ordered_values(self, parent, element, clone=_clone, **kw):
+ def visit_setup_join_tuple(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ # TODO: look at attrname for "legacy_join" and use different structure
+ return tuple(
+ (
+ clone(target, **kw) if target is not None else None,
+ clone(onclause, **kw) if onclause is not None else None,
+ clone(from_, **kw) if from_ is not None else None,
+ flags,
+ )
+ for (target, onclause, from_, flags) in element
+ )
+
+ def visit_dml_ordered_values(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
# sequence of 2-tuples
return [
(
@@ -534,7 +621,7 @@ class _CopyInternals(InternalTraversal):
for key, value in element
]
- def visit_dml_values(self, parent, element, clone=_clone, **kw):
+ def visit_dml_values(self, attrname, parent, element, clone=_clone, **kw):
return {
(
clone(key, **kw) if hasattr(key, "__clause_element__") else key
@@ -542,7 +629,9 @@ class _CopyInternals(InternalTraversal):
for key, value in element.items()
}
- def visit_dml_multi_values(self, parent, element, clone=_clone, **kw):
+ def visit_dml_multi_values(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
# sequence of sequences, each sequence contains a list/dict/tuple
def copy(elem):
@@ -741,7 +830,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
continue
comparison = dispatch(
- left, left_child, right, right_child, **kw
+ left_attrname, left, left_child, right, right_child, **kw
)
if comparison is COMPARE_FAILED:
return False
@@ -753,31 +842,40 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
return comparator.compare(obj1, obj2, **kw)
def visit_has_cache_key(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
if left._gen_cache_key(self.anon_map[0], []) != right._gen_cache_key(
self.anon_map[1], []
):
return COMPARE_FAILED
+ def visit_has_cache_key_list(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ for l, r in util.zip_longest(left, right, fillvalue=None):
+ if l._gen_cache_key(self.anon_map[0], []) != r._gen_cache_key(
+ self.anon_map[1], []
+ ):
+ return COMPARE_FAILED
+
def visit_clauseelement(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
self.stack.append((left, right))
def visit_fromclause_canonical_column_collection(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
for lcol, rcol in util.zip_longest(left, right, fillvalue=None):
self.stack.append((lcol, rcol))
def visit_fromclause_derived_column_collection(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
pass
def visit_string_clauseelement_dict(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
for lstr, rstr in util.zip_longest(
sorted(left), sorted(right), fillvalue=None
@@ -787,7 +885,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
self.stack.append((left[lstr], right[rstr]))
def visit_clauseelement_tuples(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
for ltup, rtup in util.zip_longest(left, right, fillvalue=None):
if ltup is None or rtup is None:
@@ -797,7 +895,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
self.stack.append((l, r))
def visit_clauseelement_list(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
for l, r in util.zip_longest(left, right, fillvalue=None):
self.stack.append((l, r))
@@ -815,48 +913,62 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
return len(completed) == len(seq1) == len(seq2)
def visit_clauseelement_unordered_set(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
return self._compare_unordered_sequences(left, right, **kw)
def visit_fromclause_ordered_set(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
for l, r in util.zip_longest(left, right, fillvalue=None):
self.stack.append((l, r))
- def visit_string(self, left_parent, left, right_parent, right, **kw):
+ def visit_string(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
return left == right
- def visit_string_list(self, left_parent, left, right_parent, right, **kw):
+ def visit_string_list(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
return left == right
- def visit_anon_name(self, left_parent, left, right_parent, right, **kw):
+ def visit_anon_name(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
return _resolve_name_for_compare(
left_parent, left, self.anon_map[0], **kw
) == _resolve_name_for_compare(
right_parent, right, self.anon_map[1], **kw
)
- def visit_boolean(self, left_parent, left, right_parent, right, **kw):
+ def visit_boolean(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
return left == right
- def visit_operator(self, left_parent, left, right_parent, right, **kw):
+ def visit_operator(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
return left is right
- def visit_type(self, left_parent, left, right_parent, right, **kw):
+ def visit_type(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
return left._compare_type_affinity(right)
- def visit_plain_dict(self, left_parent, left, right_parent, right, **kw):
+ def visit_plain_dict(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
return left == right
def visit_dialect_options(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
return left == right
def visit_annotations_key(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
if left and right:
return (
@@ -866,11 +978,13 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
else:
return left == right
- def visit_plain_obj(self, left_parent, left, right_parent, right, **kw):
+ def visit_plain_obj(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
return left == right
def visit_named_ddl_element(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
if left is None:
if right is not None:
@@ -879,7 +993,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
return left.name == right.name
def visit_prefix_sequence(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
for (l_clause, l_str), (r_clause, r_str) in util.zip_longest(
left, right, fillvalue=(None, None)
@@ -889,8 +1003,22 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
else:
self.stack.append((l_clause, r_clause))
+ def visit_setup_join_tuple(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ # TODO: look at attrname for "legacy_join" and use different structure
+ for (
+ (l_target, l_onclause, l_from, l_flags),
+ (r_target, r_onclause, r_from, r_flags),
+ ) in util.zip_longest(left, right, fillvalue=(None, None, None, None)):
+ if l_flags != r_flags:
+ return COMPARE_FAILED
+ self.stack.append((l_target, r_target))
+ self.stack.append((l_onclause, r_onclause))
+ self.stack.append((l_from, r_from))
+
def visit_table_hint_list(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
left_keys = sorted(left, key=lambda elem: (elem[0].fullname, elem[1]))
right_keys = sorted(
@@ -907,17 +1035,17 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
self.stack.append((ltable, rtable))
def visit_statement_hint_list(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
return left == right
def visit_unknown_structure(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
raise NotImplementedError()
def visit_dml_ordered_values(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
# sequence of tuple pairs
@@ -941,7 +1069,9 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
return True
- def visit_dml_values(self, left_parent, left, right_parent, right, **kw):
+ def visit_dml_values(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
if left is None or right is None or len(left) != len(right):
return COMPARE_FAILED
@@ -961,7 +1091,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
return COMPARE_FAILED
def visit_dml_multi_values(
- self, left_parent, left, right_parent, right, **kw
+ self, attrname, left_parent, left, right_parent, right, **kw
):
for lseq, rseq in util.zip_longest(left, right, fillvalue=None):
if lseq is None or rseq is None:
@@ -970,7 +1100,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
for ld, rd in util.zip_longest(lseq, rseq, fillvalue=None):
if (
self.visit_dml_values(
- left_parent, ld, right_parent, rd, **kw
+ attrname, left_parent, ld, right_parent, rd, **kw
)
is COMPARE_FAILED
):
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 0a67ff9bf..377aa4fe0 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -37,6 +37,7 @@ from .selectable import Join
from .selectable import ScalarSelect
from .selectable import SelectBase
from .selectable import TableClause
+from .traversals import HasCacheKey # noqa
from .. import exc
from .. import util
@@ -921,6 +922,14 @@ class ColumnAdapter(ClauseAdapter):
adapt_clause = traverse
adapt_list = ClauseAdapter.copy_and_process
+ def adapt_check_present(self, col):
+ newcol = self.columns[col]
+
+ if newcol is col and self._corresponding_column(col, True) is None:
+ return None
+
+ return newcol
+
def _locate_col(self, col):
c = ClauseAdapter.traverse(self, col)
@@ -945,3 +954,25 @@ class ColumnAdapter(ClauseAdapter):
def __setstate__(self, state):
self.__dict__.update(state)
self.columns = util.WeakPopulateDict(self._locate_col)
+
+
+def _entity_namespace_key(entity, key):
+ """Return an entry from an entity_namespace.
+
+
+ Raises :class:`_exc.InvalidRequestError` rather than attribute error
+ on not found.
+
+ """
+
+ ns = entity.entity_namespace
+ try:
+ return getattr(ns, key)
+ except AttributeError as err:
+ util.raise_(
+ exc.InvalidRequestError(
+ 'Entity namespace for "%s" has no property "%s"'
+ % (entity, key)
+ ),
+ replace_context=err,
+ )
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 574896cc7..030fd2fde 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -225,6 +225,9 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
dp_has_cache_key = symbol("HC")
"""Visit a :class:`.HasCacheKey` object."""
+ dp_has_cache_key_list = symbol("HL")
+ """Visit a list of :class:`.HasCacheKey` objects."""
+
dp_clauseelement = symbol("CE")
"""Visit a :class:`_expression.ClauseElement` object."""
@@ -372,6 +375,8 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
"""
+ dp_setup_join_tuple = symbol("SJ")
+
dp_statement_hint_list = symbol("SH")
"""Visit the ``_statement_hints`` collection of a
:class:`_expression.Select`
@@ -437,9 +442,6 @@ class ExtendedInternalTraversal(InternalTraversal):
"""
- dp_has_cache_key_list = symbol("HL")
- """Visit a list of :class:`.HasCacheKey` objects."""
-
dp_inspectable_list = symbol("IL")
"""Visit a list of inspectable objects which upon inspection are
HasCacheKey objects."""