diff options
Diffstat (limited to 'lib/sqlalchemy/sql')
| -rw-r--r-- | lib/sqlalchemy/sql/annotation.py | 18 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/base.py | 128 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/coercions.py | 82 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 10 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 50 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/roles.py | 19 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/schema.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 453 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/traversals.py | 216 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/util.py | 31 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 8 |
11 files changed, 837 insertions, 183 deletions
diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index 891b8ae09..71d05f38f 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -31,7 +31,10 @@ class SupportsAnnotations(object): if isinstance(value, HasCacheKey) else value, ) - for key, value in self._annotations.items() + for key, value in [ + (key, self._annotations[key]) + for key in sorted(self._annotations) + ] ), ) @@ -51,6 +54,7 @@ class SupportsCloneAnnotations(SupportsAnnotations): new = self._clone() new._annotations = new._annotations.union(values) new.__dict__.pop("_annotations_cache_key", None) + new.__dict__.pop("_generate_cache_key", None) return new def _with_annotations(self, values): @@ -61,6 +65,7 @@ class SupportsCloneAnnotations(SupportsAnnotations): new = self._clone() new._annotations = util.immutabledict(values) new.__dict__.pop("_annotations_cache_key", None) + new.__dict__.pop("_generate_cache_key", None) return new def _deannotate(self, values=None, clone=False): @@ -76,7 +81,7 @@ class SupportsCloneAnnotations(SupportsAnnotations): # clone is used when we are also copying # the expression for a deep deannotation new = self._clone() - new._annotations = {} + new._annotations = util.immutabledict() new.__dict__.pop("_annotations_cache_key", None) return new else: @@ -156,6 +161,7 @@ class Annotated(object): def __init__(self, element, values): self.__dict__ = element.__dict__.copy() self.__dict__.pop("_annotations_cache_key", None) + self.__dict__.pop("_generate_cache_key", None) self.__element = element self._annotations = values self._hash = hash(element) @@ -169,6 +175,7 @@ class Annotated(object): clone = self.__class__.__new__(self.__class__) clone.__dict__ = self.__dict__.copy() clone.__dict__.pop("_annotations_cache_key", None) + clone.__dict__.pop("_generate_cache_key", None) clone._annotations = values return clone @@ -211,6 +218,13 @@ class Annotated(object): else: return hash(other) == hash(self) + @property + def entity_namespace(self): + if "entity_namespace" in self._annotations: + return self._annotations["entity_namespace"].entity_namespace + else: + return self.__element.entity_namespace + # hard-generate Annotated subclasses. this technique # is used instead of on-the-fly types (i.e. type.__new__()) diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 2d023c6a6..04cc34480 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -15,11 +15,14 @@ import operator import re from .traversals import HasCacheKey # noqa +from .traversals import MemoizedHasCacheKey # noqa from .visitors import ClauseVisitor +from .visitors import ExtendedInternalTraversal from .visitors import InternalTraversal from .. import exc from .. import util from ..util import HasMemoized +from ..util import hybridmethod if util.TYPE_CHECKING: from types import ModuleType @@ -433,22 +436,52 @@ class CompileState(object): __slots__ = ("statement",) + plugins = {} + @classmethod def _create(cls, statement, compiler, **kw): # factory construction. - # specific CompileState classes here will look for - # "plugins" in the given statement. From there they will invoke - # the appropriate plugin constructor if one is found and return - # the alternate CompileState object. + if statement._compile_state_plugin is not None: + constructor = cls.plugins.get( + ( + statement._compile_state_plugin, + statement.__visit_name__, + None, + ), + cls, + ) + else: + constructor = cls - c = cls.__new__(cls) - c.__init__(statement, compiler, **kw) - return c + return constructor(statement, compiler, **kw) def __init__(self, statement, compiler, **kw): self.statement = statement + @classmethod + def get_plugin_classmethod(cls, statement, name): + if statement._compile_state_plugin is not None: + fn = cls.plugins.get( + ( + statement._compile_state_plugin, + statement.__visit_name__, + name, + ), + None, + ) + if fn is not None: + return fn + return getattr(cls, name) + + @classmethod + def plugin_for(cls, plugin_name, visit_name, method_name=None): + def decorate(fn): + cls.plugins[(plugin_name, visit_name, method_name)] = fn + return fn + + return decorate + class Generative(HasMemoized): """Provide a method-chaining pattern in conjunction with the @@ -479,6 +512,57 @@ class HasCompileState(Generative): _compile_state_plugin = None + _attributes = util.immutabledict() + + +class _MetaOptions(type): + """metaclass for the Options class.""" + + def __init__(cls, classname, bases, dict_): + cls._cache_attrs = tuple( + sorted(d for d in dict_ if not d.startswith("__")) + ) + type.__init__(cls, classname, bases, dict_) + + def __add__(self, other): + o1 = self() + o1.__dict__.update(other) + return o1 + + +class Options(util.with_metaclass(_MetaOptions)): + """A cacheable option dictionary with defaults. + + + """ + + def __init__(self, **kw): + self.__dict__.update(kw) + + def __add__(self, other): + o1 = self.__class__.__new__(self.__class__) + o1.__dict__.update(self.__dict__) + o1.__dict__.update(other) + return o1 + + @hybridmethod + def add_to_element(self, name, value): + return self + {name: getattr(self, name) + value} + + +class CacheableOptions(Options, HasCacheKey): + @hybridmethod + def _gen_cache_key(self, anon_map, bindparams): + return HasCacheKey._gen_cache_key(self, anon_map, bindparams) + + @_gen_cache_key.classlevel + def _gen_cache_key(cls, anon_map, bindparams): + return (cls, ()) + + @hybridmethod + def _generate_cache_key(self): + return HasCacheKey._generate_cache_key_for_object(self) + class Executable(Generative): """Mark a ClauseElement as supporting execution. @@ -492,7 +576,21 @@ class Executable(Generative): supports_execution = True _execution_options = util.immutabledict() _bind = None + _with_options = () + _with_context_options = () + _cache_enable = True + + _executable_traverse_internals = [ + ("_with_options", ExtendedInternalTraversal.dp_has_cache_key_list), + ("_with_context_options", ExtendedInternalTraversal.dp_plain_obj), + ("_cache_enable", ExtendedInternalTraversal.dp_plain_obj), + ] + + @_generative + def _disable_caching(self): + self._cache_enable = HasCacheKey() + @_generative def options(self, *options): """Apply options to this statement. @@ -522,7 +620,21 @@ class Executable(Generative): to the usage of ORM queries """ - self._options += options + self._with_options += options + + @_generative + def _add_context_option(self, callable_, cache_args): + """Add a context option to this statement. + + These are callable functions that will + be given the CompileState object upon compilation. + + A second argument cache_args is required, which will be combined + with the identity of the function itself in order to produce a + cache key. + + """ + self._with_context_options += ((callable_, cache_args),) @_generative def execution_options(self, **kw): diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index c7b85c415..2fc63b82f 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -57,7 +57,7 @@ def expect(role, element, **kw): if not isinstance( element, - (elements.ClauseElement, schema.SchemaItem, schema.FetchedValue), + (elements.ClauseElement, schema.SchemaItem, schema.FetchedValue,), ): resolved = impl._resolve_for_clause_element(element, **kw) else: @@ -106,7 +106,9 @@ class RoleImpl(object): self.name = role_class._role_name self._use_inspection = issubclass(role_class, roles.UsesInspection) - def _resolve_for_clause_element(self, element, argname=None, **kw): + def _resolve_for_clause_element( + self, element, argname=None, apply_plugins=None, **kw + ): original_element = element is_clause_element = False @@ -115,18 +117,39 @@ class RoleImpl(object): if not getattr(element, "is_clause_element", False): element = element.__clause_element__() else: - return element + break + + should_apply_plugins = ( + apply_plugins is not None + and apply_plugins._compile_state_plugin is None + ) if is_clause_element: + if ( + should_apply_plugins + and "compile_state_plugin" in element._annotations + ): + apply_plugins._compile_state_plugin = element._annotations[ + "compile_state_plugin" + ] return element if self._use_inspection: insp = inspection.inspect(element, raiseerr=False) if insp is not None: + insp._post_inspect try: - return insp.__clause_element__() + element = insp.__clause_element__() except AttributeError: self._raise_for_expected(original_element, argname) + else: + if ( + should_apply_plugins + and "compile_state_plugin" in element._annotations + ): + plugin = element._annotations["compile_state_plugin"] + apply_plugins._compile_state_plugin = plugin + return element return self._literal_coercion(element, argname=argname, **kw) @@ -287,7 +310,7 @@ class _SelectIsNotFrom(object): advice = ( "To create a " "FROM clause from a %s object, use the .subquery() method." - % (element.__class__) + % (element.__class__,) ) code = "89ve" else: @@ -453,6 +476,18 @@ class OrderByImpl(ByOfImpl, RoleImpl): return resolved +class GroupByImpl(ByOfImpl, RoleImpl): + __slots__ = () + + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if isinstance(resolved, roles.StrictFromClauseRole): + return elements.ClauseList(*resolved.c) + else: + return resolved + + class DMLColumnImpl(_ReturnsStringKey, RoleImpl): __slots__ = () @@ -618,6 +653,37 @@ class HasCTEImpl(ReturnsRowsImpl, roles.HasCTERole): pass +class JoinTargetImpl(RoleImpl): + __slots__ = () + + def _literal_coercion(self, element, legacy=False, **kw): + if isinstance(element, str): + return element + + def _implicit_coercions( + self, original_element, resolved, argname=None, legacy=False, **kw + ): + if isinstance(original_element, roles.JoinTargetRole): + return original_element + elif legacy and isinstance(resolved, (str, roles.WhereHavingRole)): + return resolved + elif legacy and resolved._is_select_statement: + util.warn_deprecated( + "Implicit coercion of SELECT and textual SELECT " + "constructs into FROM clauses is deprecated; please call " + ".subquery() on any Core select or ORM Query object in " + "order to produce a subquery object.", + version="1.4", + ) + # TODO: doing _implicit_subquery here causes tests to fail, + # how was this working before? probably that ORM + # join logic treated it as a select and subquery would happen + # in _ORMJoin->Join + return resolved + else: + self._raise_for_expected(original_element, argname, resolved) + + class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl): __slots__ = () @@ -647,6 +713,12 @@ class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl): else: self._raise_for_expected(original_element, argname, resolved) + def _post_coercion(self, element, deannotate=False, **kw): + if deannotate: + return element._deannotate() + else: + return element + class StrictFromClauseImpl(FromClauseImpl): __slots__ = () diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index ccc1b53fe..9a7646743 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -653,6 +653,12 @@ class SQLCompiler(Compiled): """ + compile_state_factories = util.immutabledict() + """Dictionary of alternate :class:`.CompileState` factories for given + classes, identified by their visit_name. + + """ + def __init__( self, dialect, @@ -661,6 +667,7 @@ class SQLCompiler(Compiled): column_keys=None, inline=False, linting=NO_LINTING, + compile_state_factories=None, **kwargs ): """Construct a new :class:`.SQLCompiler` object. @@ -727,6 +734,9 @@ class SQLCompiler(Compiled): # dialect.label_length or dialect.max_identifier_length self.truncated_names = {} + if compile_state_factories: + self.compile_state_factories = compile_state_factories + Compiled.__init__(self, dialect, statement, **kwargs) if ( diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 7310edd3f..c1bc9edbc 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -26,7 +26,6 @@ from .annotation import SupportsWrappingAnnotations from .base import _clone from .base import _generative from .base import Executable -from .base import HasCacheKey from .base import HasMemoized from .base import Immutable from .base import NO_ARG @@ -35,6 +34,7 @@ from .base import SingletonConstant from .coercions import _document_text_coercion from .traversals import _copy_internals from .traversals import _get_children +from .traversals import MemoizedHasCacheKey from .traversals import NO_CACHE from .visitors import cloned_traverse from .visitors import InternalTraversal @@ -179,7 +179,10 @@ def not_(clause): @inspection._self_inspects class ClauseElement( - roles.SQLRole, SupportsWrappingAnnotations, HasCacheKey, Traversible, + roles.SQLRole, + SupportsWrappingAnnotations, + MemoizedHasCacheKey, + Traversible, ): """Base class for elements of a programmatically constructed SQL expression. @@ -206,6 +209,7 @@ class ClauseElement( _is_select_container = False _is_select_statement = False _is_bind_parameter = False + _is_clause_list = False _order_by_label_element = None @@ -300,7 +304,7 @@ class ClauseElement( used. """ - return self._params(True, optionaldict, kwargs) + return self._replace_params(True, optionaldict, kwargs) def params(self, *optionaldict, **kwargs): """Return a copy with :func:`bindparam()` elements replaced. @@ -315,9 +319,9 @@ class ClauseElement( {'foo':7} """ - return self._params(False, optionaldict, kwargs) + return self._replace_params(False, optionaldict, kwargs) - def _params(self, unique, optionaldict, kwargs): + def _replace_params(self, unique, optionaldict, kwargs): if len(optionaldict) == 1: kwargs.update(optionaldict[0]) elif len(optionaldict) > 1: @@ -371,7 +375,7 @@ class ClauseElement( continue if obj is not None: - result = meth(self, obj, **kw) + result = meth(self, attrname, obj, **kw) if result is not None: setattr(self, attrname, result) @@ -2070,6 +2074,8 @@ class ClauseList( __visit_name__ = "clauselist" + _is_clause_list = True + _traverse_internals = [ ("clauses", InternalTraversal.dp_clauseelement_list), ("operator", InternalTraversal.dp_operator), @@ -2079,6 +2085,8 @@ class ClauseList( self.operator = kwargs.pop("operator", operators.comma_op) self.group = kwargs.pop("group", True) self.group_contents = kwargs.pop("group_contents", True) + if kwargs.pop("_flatten_sub_clauses", False): + clauses = util.flatten_iterator(clauses) self._tuple_values = kwargs.pop("_tuple_values", False) self._text_converter_role = text_converter_role = kwargs.pop( "_literal_as_text_role", roles.WhereHavingRole @@ -2116,7 +2124,9 @@ class ClauseList( @property def _select_iterable(self): - return iter(self) + return itertools.chain.from_iterable( + [elem._select_iterable for elem in self.clauses] + ) def append(self, clause): if self.group_contents: @@ -2224,6 +2234,32 @@ class BooleanClauseList(ClauseList, ColumnElement): return cls._construct_raw(operator) @classmethod + def _construct_for_whereclause(cls, clauses): + operator, continue_on, skip_on = ( + operators.and_, + True_._singleton, + False_._singleton, + ) + + lcc, convert_clauses = cls._process_clauses_for_boolean( + operator, + continue_on, + skip_on, + clauses, # these are assumed to be coerced already + ) + + if lcc > 1: + # multiple elements. Return regular BooleanClauseList + # which will link elements against the operator. + return cls._construct_raw(operator, convert_clauses) + elif lcc == 1: + # just one element. return it as a single boolean element, + # not a list and discard the operator. + return convert_clauses[0] + else: + return None + + @classmethod def _construct_raw(cls, operator, clauses=None): self = cls.__new__(cls) self.clauses = clauses if clauses else [] diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index 72a0bdc95..b861f721b 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -19,7 +19,7 @@ class SQLRole(object): class UsesInspection(object): - pass + _post_inspect = None class ColumnArgumentRole(SQLRole): @@ -54,6 +54,14 @@ class ByOfRole(ColumnListRole): _role_name = "GROUP BY / OF / etc. expression" +class GroupByRole(UsesInspection, ByOfRole): + # note there's a special case right now where you can pass a whole + # ORM entity to group_by() and it splits out. we may not want to keep + # this around + + _role_name = "GROUP BY expression" + + class OrderByRole(ByOfRole): _role_name = "ORDER BY expression" @@ -92,7 +100,14 @@ class InElementRole(SQLRole): ) -class FromClauseRole(ColumnsClauseRole): +class JoinTargetRole(UsesInspection, StructuralRole): + _role_name = ( + "Join target, typically a FROM expression, or ORM " + "relationship attribute" + ) + + +class FromClauseRole(ColumnsClauseRole, JoinTargetRole): _role_name = "FROM expression, such as a Table or alias() object" _is_subquery = False diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 689eda11d..65f8bd81c 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -477,7 +477,10 @@ class Table(DialectKWArgs, SchemaItem, TableClause): ] def _gen_cache_key(self, anon_map, bindparams): - return (self,) + self._annotations_cache_key + if self._annotations: + return (self,) + self._annotations_cache_key + else: + return (self,) def __new__(cls, *args, **kw): if not args: diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 85abbb5e0..6a552c18c 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -28,6 +28,7 @@ from .base import _expand_cloned from .base import _from_objects from .base import _generative from .base import _select_iterables +from .base import CacheableOptions from .base import ColumnCollection from .base import ColumnSet from .base import CompileState @@ -42,6 +43,7 @@ from .coercions import _document_text_coercion from .elements import _anonymous_label from .elements import and_ from .elements import BindParameter +from .elements import BooleanClauseList from .elements import ClauseElement from .elements import ClauseList from .elements import ColumnClause @@ -339,6 +341,90 @@ class HasSuffixes(object): ) +class HasHints(object): + _hints = util.immutabledict() + _statement_hints = () + + _has_hints_traverse_internals = [ + ("_statement_hints", InternalTraversal.dp_statement_hint_list), + ("_hints", InternalTraversal.dp_table_hint_list), + ] + + def with_statement_hint(self, text, dialect_name="*"): + """add a statement hint to this :class:`_expression.Select` or + other selectable object. + + This method is similar to :meth:`_expression.Select.with_hint` + except that + it does not require an individual table, and instead applies to the + statement as a whole. + + Hints here are specific to the backend database and may include + directives such as isolation levels, file directives, fetch directives, + etc. + + .. versionadded:: 1.0.0 + + .. seealso:: + + :meth:`_expression.Select.with_hint` + + :meth:.`.Select.prefix_with` - generic SELECT prefixing which also + can suit some database-specific HINT syntaxes such as MySQL + optimizer hints + + """ + return self.with_hint(None, text, dialect_name) + + @_generative + def with_hint(self, selectable, text, dialect_name="*"): + r"""Add an indexing or other executional context hint for the given + selectable to this :class:`_expression.Select` or other selectable + object. + + The text of the hint is rendered in the appropriate + location for the database backend in use, relative + to the given :class:`_schema.Table` or :class:`_expression.Alias` + passed as the + ``selectable`` argument. The dialect implementation + typically uses Python string substitution syntax + with the token ``%(name)s`` to render the name of + the table or alias. E.g. when using Oracle, the + following:: + + select([mytable]).\ + with_hint(mytable, "index(%(name)s ix_mytable)") + + Would render SQL as:: + + select /*+ index(mytable ix_mytable) */ ... from mytable + + The ``dialect_name`` option will limit the rendering of a particular + hint to a particular backend. Such as, to add hints for both Oracle + and Sybase simultaneously:: + + select([mytable]).\ + with_hint(mytable, "index(%(name)s ix_mytable)", 'oracle').\ + with_hint(mytable, "WITH INDEX ix_mytable", 'sybase') + + .. seealso:: + + :meth:`_expression.Select.with_statement_hint` + + """ + if selectable is None: + self._statement_hints += ((dialect_name, text),) + else: + self._hints = self._hints.union( + { + ( + coercions.expect(roles.FromClauseRole, selectable), + dialect_name, + ): text + } + ) + + class FromClause(roles.AnonymizedFromClauseRole, Selectable): """Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement. @@ -597,6 +683,22 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): self._populate_column_collection() return self._columns.as_immutable() + @property + def entity_namespace(self): + """Return a namespace used for name-based access in SQL expressions. + + This is the namespace that is used to resolve "filter_by()" type + expressions, such as:: + + stmt.filter_by(address='some address') + + It defaults to the .c collection, however internally it can + be overridden using the "entity_namespace" annotation to deliver + alternative results. + + """ + return self.columns + @util.memoized_property def primary_key(self): """Return the collection of Column objects which comprise the @@ -727,13 +829,21 @@ class Join(FromClause): :class:`_expression.FromClause` object. """ - self.left = coercions.expect(roles.FromClauseRole, left) - self.right = coercions.expect(roles.FromClauseRole, right).self_group() + self.left = coercions.expect( + roles.FromClauseRole, left, deannotate=True + ) + self.right = coercions.expect( + roles.FromClauseRole, right, deannotate=True + ).self_group() if onclause is None: self.onclause = self._match_primaries(self.left, self.right) else: - self.onclause = onclause.self_group(against=operators._asbool) + # note: taken from If91f61527236fd4d7ae3cad1f24c38be921c90ba + # not merged yet + self.onclause = coercions.expect( + roles.WhereHavingRole, onclause + ).self_group(against=operators._asbool) self.isouter = isouter self.full = full @@ -1963,6 +2073,12 @@ class TableClause(Immutable, FromClause): if kw: raise exc.ArgumentError("Unsupported argument(s): %s" % list(kw)) + def __str__(self): + if self.schema is not None: + return self.schema + "." + self.name + else: + return self.name + def _refresh_for_new_column(self, column): pass @@ -2905,7 +3021,8 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase): self._group_by_clauses = () else: self._group_by_clauses += tuple( - coercions.expect(roles.ByOfRole, clause) for clause in clauses + coercions.expect(roles.GroupByRole, clause) + for clause in clauses ) @@ -3309,8 +3426,16 @@ class DeprecatedSelectGenerations(object): class SelectState(CompileState): + class default_select_compile_options(CacheableOptions): + _cache_key_traversal = [] + def __init__(self, statement, compiler, **kw): self.statement = statement + self.from_clauses = statement._from_obj + + if statement._setup_joins: + self._setup_joins(statement._setup_joins) + self.froms = self._get_froms(statement) self.columns_plus_names = statement._generate_columns_plus_names(True) @@ -3319,7 +3444,18 @@ class SelectState(CompileState): froms = [] seen = set() - for item in statement._iterate_from_elements(): + for item in itertools.chain( + itertools.chain.from_iterable( + [element._from_objects for element in statement._raw_columns] + ), + itertools.chain.from_iterable( + [ + element._from_objects + for element in statement._where_criteria + ] + ), + self.from_clauses, + ): if item._is_subquery and item.element is statement: raise exc.InvalidRequestError( "select() construct refers to itself as a FROM" @@ -3341,6 +3477,7 @@ class SelectState(CompileState): correlating. """ + froms = self.froms toremove = set( @@ -3425,10 +3562,162 @@ class SelectState(CompileState): return with_cols, only_froms, only_cols + @classmethod + def determine_last_joined_entity(cls, stmt): + if stmt._setup_joins: + return stmt._setup_joins[-1][0] + else: + return None + + def _setup_joins(self, args): + for (right, onclause, left, flags) in args: + isouter = flags["isouter"] + full = flags["full"] + + if left is None: + ( + left, + replace_from_obj_index, + ) = self._join_determine_implicit_left_side( + left, right, onclause + ) + else: + (replace_from_obj_index) = self._join_place_explicit_left_side( + left + ) + + if replace_from_obj_index is not None: + # splice into an existing element in the + # self._from_obj list + left_clause = self.from_clauses[replace_from_obj_index] + + self.from_clauses = ( + self.from_clauses[:replace_from_obj_index] + + ( + Join( + left_clause, + right, + onclause, + isouter=isouter, + full=full, + ), + ) + + self.from_clauses[replace_from_obj_index + 1 :] + ) + else: + + self.from_clauses = self.from_clauses + ( + Join(left, right, onclause, isouter=isouter, full=full,), + ) + + @util.preload_module("sqlalchemy.sql.util") + def _join_determine_implicit_left_side(self, left, right, onclause): + """When join conditions don't express the left side explicitly, + determine if an existing FROM or entity in this query + can serve as the left hand side. + + """ + + sql_util = util.preloaded.sql_util + + replace_from_obj_index = None + + from_clauses = self.statement._from_obj + + if from_clauses: + + indexes = sql_util.find_left_clause_to_join_from( + from_clauses, right, onclause + ) + + if len(indexes) == 1: + replace_from_obj_index = indexes[0] + left = from_clauses[replace_from_obj_index] + else: + potential = {} + statement = self.statement + + for from_clause in itertools.chain( + itertools.chain.from_iterable( + [ + element._from_objects + for element in statement._raw_columns + ] + ), + itertools.chain.from_iterable( + [ + element._from_objects + for element in statement._where_criteria + ] + ), + ): + + potential[from_clause] = () + + all_clauses = list(potential.keys()) + indexes = sql_util.find_left_clause_to_join_from( + all_clauses, right, onclause + ) + + if len(indexes) == 1: + left = all_clauses[indexes[0]] + + if len(indexes) > 1: + raise exc.InvalidRequestError( + "Can't determine which FROM clause to join " + "from, there are multiple FROMS which can " + "join to this entity. Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explcit ON clause if not present already to " + "help resolve the ambiguity." + ) + elif not indexes: + raise exc.InvalidRequestError( + "Don't know how to join to %r. " + "Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explcit ON clause if not present already to " + "help resolve the ambiguity." % (right,) + ) + return left, replace_from_obj_index + + @util.preload_module("sqlalchemy.sql.util") + def _join_place_explicit_left_side(self, left): + replace_from_obj_index = None + + sql_util = util.preloaded.sql_util + + from_clauses = list(self.statement._iterate_from_elements()) + + if from_clauses: + indexes = sql_util.find_left_clause_that_matches_given( + self.from_clauses, left + ) + else: + indexes = [] + + if len(indexes) > 1: + raise exc.InvalidRequestError( + "Can't identify which entity in which to assign the " + "left side of this join. Please use a more specific " + "ON clause." + ) + + # have an index, means the left side is already present in + # an existing FROM in the self._from_obj tuple + if indexes: + replace_from_obj_index = indexes[0] + + # no index, means we need to add a new element to the + # self._from_obj tuple + + return replace_from_obj_index + class Select( HasPrefixes, HasSuffixes, + HasHints, HasCompileState, DeprecatedSelectGenerations, GenerativeSelect, @@ -3440,9 +3729,10 @@ class Select( __visit_name__ = "select" _compile_state_factory = SelectState._create + _is_future = False + _setup_joins = () + _legacy_setup_joins = () - _hints = util.immutabledict() - _statement_hints = () _distinct = False _distinct_on = () _correlate = () @@ -3452,6 +3742,8 @@ class Select( _from_obj = () _auto_correlate = True + compile_options = SelectState.default_select_compile_options + _traverse_internals = ( [ ("_raw_columns", InternalTraversal.dp_clauseelement_list), @@ -3460,58 +3752,33 @@ class Select( ("_having_criteria", InternalTraversal.dp_clauseelement_list), ("_order_by_clauses", InternalTraversal.dp_clauseelement_list,), ("_group_by_clauses", InternalTraversal.dp_clauseelement_list,), + ("_setup_joins", InternalTraversal.dp_setup_join_tuple,), + ("_legacy_setup_joins", InternalTraversal.dp_setup_join_tuple,), ("_correlate", InternalTraversal.dp_clauseelement_unordered_set), ( "_correlate_except", InternalTraversal.dp_clauseelement_unordered_set, ), ("_for_update_arg", InternalTraversal.dp_clauseelement), - ("_statement_hints", InternalTraversal.dp_statement_hint_list), - ("_hints", InternalTraversal.dp_table_hint_list), ("_distinct", InternalTraversal.dp_boolean), ("_distinct_on", InternalTraversal.dp_clauseelement_list), ("_label_style", InternalTraversal.dp_plain_obj), ] + HasPrefixes._has_prefixes_traverse_internals + HasSuffixes._has_suffixes_traverse_internals + + HasHints._has_hints_traverse_internals + SupportsCloneAnnotations._clone_annotations_traverse_internals + + Executable._executable_traverse_internals ) + _cache_key_traversal = _traverse_internals + [ + ("compile_options", InternalTraversal.dp_has_cache_key) + ] + @classmethod def _create_select(cls, *entities): - r"""Construct a new :class:`_expression.Select` using the 2. - x style API. - - .. versionadded:: 2.0 - the :func:`_future.select` construct is - the same construct as the one returned by - :func:`_expression.select`, except that the function only - accepts the "columns clause" entities up front; the rest of the - state of the SELECT should be built up using generative methods. - - Similar functionality is also available via the - :meth:`_expression.FromClause.select` method on any - :class:`_expression.FromClause`. - - .. seealso:: - - :ref:`coretutorial_selecting` - Core Tutorial description of - :func:`_expression.select`. - - :param \*entities: - Entities to SELECT from. For Core usage, this is typically a series - of :class:`_expression.ColumnElement` and / or - :class:`_expression.FromClause` - objects which will form the columns clause of the resulting - statement. For those objects that are instances of - :class:`_expression.FromClause` (typically :class:`_schema.Table` - or :class:`_expression.Alias` - objects), the :attr:`_expression.FromClause.c` - collection is extracted - to form a collection of :class:`_expression.ColumnElement` objects. - - This parameter will also accept :class:`_expression.TextClause` - constructs as - given, as well as ORM-mapped classes. + r"""Construct an old style :class:`_expression.Select` using the + the 2.x style constructor. """ @@ -3779,7 +4046,10 @@ class Select( if cols_present: self._raw_columns = [ - coercions.expect(roles.ColumnsClauseRole, c,) for c in columns + coercions.expect( + roles.ColumnsClauseRole, c, apply_plugins=self + ) + for c in columns ] else: self._raw_columns = [] @@ -3820,71 +4090,6 @@ class Select( return self._compile_state_factory(self, None)._get_display_froms() - def with_statement_hint(self, text, dialect_name="*"): - """add a statement hint to this :class:`_expression.Select`. - - This method is similar to :meth:`_expression.Select.with_hint` - except that - it does not require an individual table, and instead applies to the - statement as a whole. - - Hints here are specific to the backend database and may include - directives such as isolation levels, file directives, fetch directives, - etc. - - .. versionadded:: 1.0.0 - - .. seealso:: - - :meth:`_expression.Select.with_hint` - - :meth:`.Select.prefix_with` - generic SELECT prefixing which also - can suit some database-specific HINT syntaxes such as MySQL - optimizer hints - - """ - return self.with_hint(None, text, dialect_name) - - @_generative - def with_hint(self, selectable, text, dialect_name="*"): - r"""Add an indexing or other executional context hint for the given - selectable to this :class:`_expression.Select`. - - The text of the hint is rendered in the appropriate - location for the database backend in use, relative - to the given :class:`_schema.Table` or :class:`_expression.Alias` - passed as the - ``selectable`` argument. The dialect implementation - typically uses Python string substitution syntax - with the token ``%(name)s`` to render the name of - the table or alias. E.g. when using Oracle, the - following:: - - select([mytable]).\ - with_hint(mytable, "index(%(name)s ix_mytable)") - - Would render SQL as:: - - select /*+ index(mytable ix_mytable) */ ... from mytable - - The ``dialect_name`` option will limit the rendering of a particular - hint to a particular backend. Such as, to add hints for both Oracle - and Sybase simultaneously:: - - select([mytable]).\ - with_hint(mytable, "index(%(name)s ix_mytable)", 'oracle').\ - with_hint(mytable, "WITH INDEX ix_mytable", 'sybase') - - .. seealso:: - - :meth:`_expression.Select.with_statement_hint` - - """ - if selectable is None: - self._statement_hints += ((dialect_name, text),) - else: - self._hints = self._hints.union({(selectable, dialect_name): text}) - @property def inner_columns(self): """an iterator of all ColumnElement expressions which would @@ -3921,9 +4126,16 @@ class Select( _from_objects(*self._where_criteria), ) ) + + # do a clone for the froms we've gathered. what is important here + # is if any of the things we are selecting from, like tables, + # were converted into Join objects. if so, these need to be + # added to _from_obj explicitly, because otherwise they won't be + # part of the new state, as they don't associate themselves with + # their columns. new_froms = {f: clone(f, **kw) for f in all_the_froms} - # 2. copy FROM collections. + # 2. copy FROM collections, adding in joins that we've created. self._from_obj = tuple(clone(f, **kw) for f in self._from_obj) + tuple( f for f in new_froms.values() if isinstance(f, Join) ) @@ -3937,6 +4149,10 @@ class Select( kw["replace"] = replace + # copy everything else. for table-ish things like correlate, + # correlate_except, setup_joins, these clone normally. For + # column-expression oriented things like raw_columns, where_criteria, + # order by, we get this from the new froms. super(Select, self)._copy_internals( clone=clone, omit_attrs=("_from_obj",), **kw ) @@ -3975,10 +4191,18 @@ class Select( self._assert_no_memoizations() self._raw_columns = self._raw_columns + [ - coercions.expect(roles.ColumnsClauseRole, column,) + coercions.expect( + roles.ColumnsClauseRole, column, apply_plugins=self + ) for column in columns ] + def _set_entities(self, entities): + self._raw_columns = [ + coercions.expect(roles.ColumnsClauseRole, ent, apply_plugins=self) + for ent in util.to_list(entities) + ] + @util.deprecated( "1.4", "The :meth:`_expression.Select.column` method is deprecated and will " @@ -4111,6 +4335,7 @@ class Select( rc = [] for c in columns: c = coercions.expect(roles.ColumnsClauseRole, c,) + # TODO: why are we doing this here? if isinstance(c, ScalarSelect): c = c.self_group(against=operators.comma_op) rc.append(c) @@ -4121,7 +4346,9 @@ class Select( """Legacy, return the WHERE clause as a """ """:class:`_expression.BooleanClauseList`""" - return and_(*self._where_criteria) + return BooleanClauseList._construct_for_whereclause( + self._where_criteria + ) @_generative def where(self, whereclause): @@ -4202,7 +4429,9 @@ class Select( """ self._from_obj += tuple( - coercions.expect(roles.FromClauseRole, fromclause) + coercions.expect( + roles.FromClauseRole, fromclause, apply_plugins=self + ) for fromclause in froms ) diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 8c63fcba1..a308feb7c 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -29,9 +29,8 @@ def compare(obj1, obj2, **kw): return strategy.compare(obj1, obj2, **kw) -class HasCacheKey(HasMemoized): +class HasCacheKey(object): _cache_key_traversal = NO_CACHE - __slots__ = () def _gen_cache_key(self, anon_map, bindparams): @@ -141,7 +140,6 @@ class HasCacheKey(HasMemoized): return result - @HasMemoized.memoized_instancemethod def _generate_cache_key(self): """return a cache key. @@ -183,6 +181,23 @@ class HasCacheKey(HasMemoized): else: return CacheKey(key, bindparams) + @classmethod + def _generate_cache_key_for_object(cls, obj): + bindparams = [] + + _anon_map = anon_map() + key = obj._gen_cache_key(_anon_map, bindparams) + if NO_CACHE in _anon_map: + return None + else: + return CacheKey(key, bindparams) + + +class MemoizedHasCacheKey(HasCacheKey, HasMemoized): + @HasMemoized.memoized_instancemethod + def _generate_cache_key(self): + return HasCacheKey._generate_cache_key(self) + class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])): def __hash__(self): @@ -191,6 +206,40 @@ class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])): def __eq__(self, other): return self.key == other.key + def _whats_different(self, other): + + k1 = self.key + k2 = other.key + + stack = [] + pickup_index = 0 + while True: + s1, s2 = k1, k2 + for idx in stack: + s1 = s1[idx] + s2 = s2[idx] + + for idx, (e1, e2) in enumerate(util.zip_longest(s1, s2)): + if idx < pickup_index: + continue + if e1 != e2: + if isinstance(e1, tuple) and isinstance(e2, tuple): + stack.append(idx) + break + else: + yield "key%s[%d]: %s != %s" % ( + "".join("[%d]" % id_ for id_ in stack), + idx, + e1, + e2, + ) + else: + pickup_index = stack.pop(-1) + break + + def _diff(self, other): + return ", ".join(self._whats_different(other)) + def __str__(self): stack = [self.key] @@ -241,9 +290,7 @@ class _CacheKey(ExtendedInternalTraversal): visit_type = STATIC_CACHE_KEY def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams): - return self.visit_has_cache_key( - attrname, inspect(obj), parent, anon_map, bindparams - ) + return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams)) def visit_string_list(self, attrname, obj, parent, anon_map, bindparams): return tuple(obj) @@ -361,6 +408,24 @@ class _CacheKey(ExtendedInternalTraversal): ), ) + def visit_setup_join_tuple( + self, attrname, obj, parent, anon_map, bindparams + ): + # TODO: look at attrname for "legacy_join" and use different structure + return tuple( + ( + target._gen_cache_key(anon_map, bindparams), + onclause._gen_cache_key(anon_map, bindparams) + if onclause is not None + else None, + from_._gen_cache_key(anon_map, bindparams) + if from_ is not None + else None, + tuple([(key, flags[key]) for key in sorted(flags)]), + ) + for (target, onclause, from_, flags) in obj + ) + def visit_table_hint_list( self, attrname, obj, parent, anon_map, bindparams ): @@ -498,31 +563,53 @@ class _CopyInternals(InternalTraversal): """Generate a _copy_internals internal traversal dispatch for classes with a _traverse_internals collection.""" - def visit_clauseelement(self, parent, element, clone=_clone, **kw): + def visit_clauseelement( + self, attrname, parent, element, clone=_clone, **kw + ): return clone(element, **kw) - def visit_clauseelement_list(self, parent, element, clone=_clone, **kw): + def visit_clauseelement_list( + self, attrname, parent, element, clone=_clone, **kw + ): return [clone(clause, **kw) for clause in element] def visit_clauseelement_unordered_set( - self, parent, element, clone=_clone, **kw + self, attrname, parent, element, clone=_clone, **kw ): return {clone(clause, **kw) for clause in element} - def visit_clauseelement_tuples(self, parent, element, clone=_clone, **kw): + def visit_clauseelement_tuples( + self, attrname, parent, element, clone=_clone, **kw + ): return [ tuple(clone(tup_elem, **kw) for tup_elem in elem) for elem in element ] def visit_string_clauseelement_dict( - self, parent, element, clone=_clone, **kw + self, attrname, parent, element, clone=_clone, **kw ): return dict( (key, clone(value, **kw)) for key, value in element.items() ) - def visit_dml_ordered_values(self, parent, element, clone=_clone, **kw): + def visit_setup_join_tuple( + self, attrname, parent, element, clone=_clone, **kw + ): + # TODO: look at attrname for "legacy_join" and use different structure + return tuple( + ( + clone(target, **kw) if target is not None else None, + clone(onclause, **kw) if onclause is not None else None, + clone(from_, **kw) if from_ is not None else None, + flags, + ) + for (target, onclause, from_, flags) in element + ) + + def visit_dml_ordered_values( + self, attrname, parent, element, clone=_clone, **kw + ): # sequence of 2-tuples return [ ( @@ -534,7 +621,7 @@ class _CopyInternals(InternalTraversal): for key, value in element ] - def visit_dml_values(self, parent, element, clone=_clone, **kw): + def visit_dml_values(self, attrname, parent, element, clone=_clone, **kw): return { ( clone(key, **kw) if hasattr(key, "__clause_element__") else key @@ -542,7 +629,9 @@ class _CopyInternals(InternalTraversal): for key, value in element.items() } - def visit_dml_multi_values(self, parent, element, clone=_clone, **kw): + def visit_dml_multi_values( + self, attrname, parent, element, clone=_clone, **kw + ): # sequence of sequences, each sequence contains a list/dict/tuple def copy(elem): @@ -741,7 +830,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): continue comparison = dispatch( - left, left_child, right, right_child, **kw + left_attrname, left, left_child, right, right_child, **kw ) if comparison is COMPARE_FAILED: return False @@ -753,31 +842,40 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): return comparator.compare(obj1, obj2, **kw) def visit_has_cache_key( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): if left._gen_cache_key(self.anon_map[0], []) != right._gen_cache_key( self.anon_map[1], [] ): return COMPARE_FAILED + def visit_has_cache_key_list( + self, attrname, left_parent, left, right_parent, right, **kw + ): + for l, r in util.zip_longest(left, right, fillvalue=None): + if l._gen_cache_key(self.anon_map[0], []) != r._gen_cache_key( + self.anon_map[1], [] + ): + return COMPARE_FAILED + def visit_clauseelement( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): self.stack.append((left, right)) def visit_fromclause_canonical_column_collection( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): for lcol, rcol in util.zip_longest(left, right, fillvalue=None): self.stack.append((lcol, rcol)) def visit_fromclause_derived_column_collection( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): pass def visit_string_clauseelement_dict( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): for lstr, rstr in util.zip_longest( sorted(left), sorted(right), fillvalue=None @@ -787,7 +885,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): self.stack.append((left[lstr], right[rstr])) def visit_clauseelement_tuples( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): for ltup, rtup in util.zip_longest(left, right, fillvalue=None): if ltup is None or rtup is None: @@ -797,7 +895,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): self.stack.append((l, r)) def visit_clauseelement_list( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): for l, r in util.zip_longest(left, right, fillvalue=None): self.stack.append((l, r)) @@ -815,48 +913,62 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): return len(completed) == len(seq1) == len(seq2) def visit_clauseelement_unordered_set( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): return self._compare_unordered_sequences(left, right, **kw) def visit_fromclause_ordered_set( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): for l, r in util.zip_longest(left, right, fillvalue=None): self.stack.append((l, r)) - def visit_string(self, left_parent, left, right_parent, right, **kw): + def visit_string( + self, attrname, left_parent, left, right_parent, right, **kw + ): return left == right - def visit_string_list(self, left_parent, left, right_parent, right, **kw): + def visit_string_list( + self, attrname, left_parent, left, right_parent, right, **kw + ): return left == right - def visit_anon_name(self, left_parent, left, right_parent, right, **kw): + def visit_anon_name( + self, attrname, left_parent, left, right_parent, right, **kw + ): return _resolve_name_for_compare( left_parent, left, self.anon_map[0], **kw ) == _resolve_name_for_compare( right_parent, right, self.anon_map[1], **kw ) - def visit_boolean(self, left_parent, left, right_parent, right, **kw): + def visit_boolean( + self, attrname, left_parent, left, right_parent, right, **kw + ): return left == right - def visit_operator(self, left_parent, left, right_parent, right, **kw): + def visit_operator( + self, attrname, left_parent, left, right_parent, right, **kw + ): return left is right - def visit_type(self, left_parent, left, right_parent, right, **kw): + def visit_type( + self, attrname, left_parent, left, right_parent, right, **kw + ): return left._compare_type_affinity(right) - def visit_plain_dict(self, left_parent, left, right_parent, right, **kw): + def visit_plain_dict( + self, attrname, left_parent, left, right_parent, right, **kw + ): return left == right def visit_dialect_options( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): return left == right def visit_annotations_key( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): if left and right: return ( @@ -866,11 +978,13 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): else: return left == right - def visit_plain_obj(self, left_parent, left, right_parent, right, **kw): + def visit_plain_obj( + self, attrname, left_parent, left, right_parent, right, **kw + ): return left == right def visit_named_ddl_element( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): if left is None: if right is not None: @@ -879,7 +993,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): return left.name == right.name def visit_prefix_sequence( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): for (l_clause, l_str), (r_clause, r_str) in util.zip_longest( left, right, fillvalue=(None, None) @@ -889,8 +1003,22 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): else: self.stack.append((l_clause, r_clause)) + def visit_setup_join_tuple( + self, attrname, left_parent, left, right_parent, right, **kw + ): + # TODO: look at attrname for "legacy_join" and use different structure + for ( + (l_target, l_onclause, l_from, l_flags), + (r_target, r_onclause, r_from, r_flags), + ) in util.zip_longest(left, right, fillvalue=(None, None, None, None)): + if l_flags != r_flags: + return COMPARE_FAILED + self.stack.append((l_target, r_target)) + self.stack.append((l_onclause, r_onclause)) + self.stack.append((l_from, r_from)) + def visit_table_hint_list( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): left_keys = sorted(left, key=lambda elem: (elem[0].fullname, elem[1])) right_keys = sorted( @@ -907,17 +1035,17 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): self.stack.append((ltable, rtable)) def visit_statement_hint_list( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): return left == right def visit_unknown_structure( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): raise NotImplementedError() def visit_dml_ordered_values( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): # sequence of tuple pairs @@ -941,7 +1069,9 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): return True - def visit_dml_values(self, left_parent, left, right_parent, right, **kw): + def visit_dml_values( + self, attrname, left_parent, left, right_parent, right, **kw + ): if left is None or right is None or len(left) != len(right): return COMPARE_FAILED @@ -961,7 +1091,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): return COMPARE_FAILED def visit_dml_multi_values( - self, left_parent, left, right_parent, right, **kw + self, attrname, left_parent, left, right_parent, right, **kw ): for lseq, rseq in util.zip_longest(left, right, fillvalue=None): if lseq is None or rseq is None: @@ -970,7 +1100,7 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): for ld, rd in util.zip_longest(lseq, rseq, fillvalue=None): if ( self.visit_dml_values( - left_parent, ld, right_parent, rd, **kw + attrname, left_parent, ld, right_parent, rd, **kw ) is COMPARE_FAILED ): diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 0a67ff9bf..377aa4fe0 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -37,6 +37,7 @@ from .selectable import Join from .selectable import ScalarSelect from .selectable import SelectBase from .selectable import TableClause +from .traversals import HasCacheKey # noqa from .. import exc from .. import util @@ -921,6 +922,14 @@ class ColumnAdapter(ClauseAdapter): adapt_clause = traverse adapt_list = ClauseAdapter.copy_and_process + def adapt_check_present(self, col): + newcol = self.columns[col] + + if newcol is col and self._corresponding_column(col, True) is None: + return None + + return newcol + def _locate_col(self, col): c = ClauseAdapter.traverse(self, col) @@ -945,3 +954,25 @@ class ColumnAdapter(ClauseAdapter): def __setstate__(self, state): self.__dict__.update(state) self.columns = util.WeakPopulateDict(self._locate_col) + + +def _entity_namespace_key(entity, key): + """Return an entry from an entity_namespace. + + + Raises :class:`_exc.InvalidRequestError` rather than attribute error + on not found. + + """ + + ns = entity.entity_namespace + try: + return getattr(ns, key) + except AttributeError as err: + util.raise_( + exc.InvalidRequestError( + 'Entity namespace for "%s" has no property "%s"' + % (entity, key) + ), + replace_context=err, + ) diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 574896cc7..030fd2fde 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -225,6 +225,9 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)): dp_has_cache_key = symbol("HC") """Visit a :class:`.HasCacheKey` object.""" + dp_has_cache_key_list = symbol("HL") + """Visit a list of :class:`.HasCacheKey` objects.""" + dp_clauseelement = symbol("CE") """Visit a :class:`_expression.ClauseElement` object.""" @@ -372,6 +375,8 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)): """ + dp_setup_join_tuple = symbol("SJ") + dp_statement_hint_list = symbol("SH") """Visit the ``_statement_hints`` collection of a :class:`_expression.Select` @@ -437,9 +442,6 @@ class ExtendedInternalTraversal(InternalTraversal): """ - dp_has_cache_key_list = symbol("HL") - """Visit a list of :class:`.HasCacheKey` objects.""" - dp_inspectable_list = symbol("IL") """Visit a list of inspectable objects which upon inspection are HasCacheKey objects.""" |
