diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2019-04-29 23:26:36 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2019-05-18 17:46:10 -0400 |
| commit | f07e050c9ce4afdeb9c0c136dbcc547f7e5ac7b8 (patch) | |
| tree | 1b3cd7409ae2eddef635960126551d74f469acc1 /lib/sqlalchemy/sql | |
| parent | 614dfb5f5b5a2427d5d6ce0bc5f34bf0581bf698 (diff) | |
| download | sqlalchemy-f07e050c9ce4afdeb9c0c136dbcc547f7e5ac7b8.tar.gz | |
Implement new ClauseElement role and coercion system
A major refactoring of all the functions handle all detection of
Core argument types as well as perform coercions into a new class hierarchy
based on "roles", each of which identify a syntactical location within a
SQL statement. In contrast to the ClauseElement hierarchy that identifies
"what" each object is syntactically, the SQLRole hierarchy identifies
the "where does it go" of each object syntactically. From this we define
a consistent type checking and coercion system that establishes well
defined behviors.
This is a breakout of the patch that is reorganizing select()
constructs to no longer be in the FromClause hierarchy.
Also includes a rename of as_scalar() into scalar_subquery(); deprecates
automatic coercion to scalar_subquery().
Partially-fixes: #4617
Change-Id: I26f1e78898693c6b99ef7ea2f4e7dfd0e8e1a1bd
Diffstat (limited to 'lib/sqlalchemy/sql')
| -rw-r--r-- | lib/sqlalchemy/sql/__init__.py | 16 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/base.py | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/coercions.py | 580 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 31 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/crud.py | 27 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/ddl.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/default_comparator.py | 146 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/dml.py | 52 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 678 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/expression.py | 19 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/functions.py | 24 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/operators.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/roles.py | 157 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/schema.py | 45 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 304 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 30 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/type_api.py | 3 |
17 files changed, 1330 insertions, 798 deletions
diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index fb5639ef3..00cafd8ff 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -94,6 +94,22 @@ def __go(lcls): from .elements import ClauseList # noqa from .selectable import AnnotatedFromClause # noqa + from . import base + from . import coercions + from . import elements + from . import selectable + from . import schema + from . import sqltypes + from . import type_api + + base.coercions = elements.coercions = coercions + base.elements = elements + base.type_api = type_api + coercions.elements = elements + coercions.schema = schema + coercions.selectable = selectable + coercions.sqltypes = sqltypes + _prepare_annotations(ColumnElement, AnnotatedColumnElement) _prepare_annotations(FromClause, AnnotatedFromClause) _prepare_annotations(ClauseList, Annotated) diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index c5e5fd8a1..9df0c932f 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -17,6 +17,9 @@ from .visitors import ClauseVisitor from .. import exc from .. import util +coercions = None # type: types.ModuleType +elements = None # type: types.ModuleType +type_api = None # type: types.ModuleType PARSE_AUTOCOMMIT = util.symbol("PARSE_AUTOCOMMIT") NO_ARG = util.symbol("NO_ARG") @@ -589,8 +592,7 @@ class ColumnCollection(util.OrderedProperties): __hash__ = None - @util.dependencies("sqlalchemy.sql.elements") - def __eq__(self, elements, other): + def __eq__(self, other): l = [] for c in getattr(other, "_all_columns", other): for local in self._all_columns: @@ -636,8 +638,7 @@ class ColumnSet(util.ordered_column_set): def __add__(self, other): return list(self) + list(other) - @util.dependencies("sqlalchemy.sql.elements") - def __eq__(self, elements, other): + def __eq__(self, other): l = [] for c in other: for local in self: diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py new file mode 100644 index 000000000..7c7222f9f --- /dev/null +++ b/lib/sqlalchemy/sql/coercions.py @@ -0,0 +1,580 @@ +# sql/coercions.py +# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +import numbers +import re + +from . import operators +from . import roles +from . import visitors +from .visitors import Visitable +from .. import exc +from .. import inspection +from .. import util +from ..util import collections_abc + +elements = None # type: types.ModuleType +schema = None # type: types.ModuleType +selectable = None # type: types.ModuleType +sqltypes = None # type: types.ModuleType + + +def _is_literal(element): + """Return whether or not the element is a "literal" in the context + of a SQL expression construct. + + """ + return not isinstance( + element, (Visitable, schema.SchemaEventTarget) + ) and not hasattr(element, "__clause_element__") + + +def _document_text_coercion(paramname, meth_rst, param_rst): + return util.add_parameter_text( + paramname, + ( + ".. warning:: " + "The %s argument to %s can be passed as a Python string argument, " + "which will be treated " + "as **trusted SQL text** and rendered as given. **DO NOT PASS " + "UNTRUSTED INPUT TO THIS PARAMETER**." + ) + % (param_rst, meth_rst), + ) + + +def expect(role, element, **kw): + # major case is that we are given a ClauseElement already, skip more + # elaborate logic up front if possible + impl = _impl_lookup[role] + + if not isinstance(element, (elements.ClauseElement, schema.SchemaItem)): + resolved = impl._resolve_for_clause_element(element, **kw) + else: + resolved = element + + if issubclass(resolved.__class__, impl._role_class): + if impl._post_coercion: + resolved = impl._post_coercion(resolved, **kw) + return resolved + else: + return impl._implicit_coercions(element, resolved, **kw) + + +def expect_as_key(role, element, **kw): + kw["as_key"] = True + return expect(role, element, **kw) + + +def expect_col_expression_collection(role, expressions): + for expr in expressions: + strname = None + column = None + + resolved = expect(role, expr) + if isinstance(resolved, util.string_types): + strname = resolved = expr + else: + cols = [] + visitors.traverse(resolved, {}, {"column": cols.append}) + if cols: + column = cols[0] + add_element = column if column is not None else strname + yield resolved, column, strname, add_element + + +class RoleImpl(object): + __slots__ = ("_role_class", "name", "_use_inspection") + + def _literal_coercion(self, element, **kw): + raise NotImplementedError() + + _post_coercion = None + + def __init__(self, role_class): + self._role_class = role_class + self.name = role_class._role_name + self._use_inspection = issubclass(role_class, roles.UsesInspection) + + def _resolve_for_clause_element(self, element, argname=None, **kw): + literal_coercion = self._literal_coercion + original_element = element + is_clause_element = False + + while hasattr(element, "__clause_element__") and not isinstance( + element, (elements.ClauseElement, schema.SchemaItem) + ): + element = element.__clause_element__() + is_clause_element = True + + if not is_clause_element: + if self._use_inspection: + insp = inspection.inspect(element, raiseerr=False) + if insp is not None: + try: + return insp.__clause_element__() + except AttributeError: + self._raise_for_expected(original_element, argname) + + return self._literal_coercion(element, argname=argname, **kw) + else: + return element + + def _implicit_coercions(self, element, resolved, argname=None, **kw): + self._raise_for_expected(element, argname) + + def _raise_for_expected(self, element, argname=None): + if argname: + raise exc.ArgumentError( + "%s expected for argument %r; got %r." + % (self.name, argname, element) + ) + else: + raise exc.ArgumentError( + "%s expected, got %r." % (self.name, element) + ) + + +class _StringOnly(object): + def _resolve_for_clause_element(self, element, argname=None, **kw): + return self._literal_coercion(element, **kw) + + +class _ReturnsStringKey(object): + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if isinstance(original_element, util.string_types): + return original_element + else: + self._raise_for_expected(original_element, argname) + + def _literal_coercion(self, element, **kw): + return element + + +class _ColumnCoercions(object): + def _warn_for_scalar_subquery_coercion(self): + util.warn_deprecated( + "coercing SELECT object to scalar subquery in a " + "column-expression context is deprecated in version 1.4; " + "please use the .scalar_subquery() method to produce a scalar " + "subquery. This automatic coercion will be removed in a " + "future release." + ) + + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if resolved._is_select_statement: + self._warn_for_scalar_subquery_coercion() + return resolved.scalar_subquery() + elif ( + resolved._is_from_clause + and isinstance(resolved, selectable.Alias) + and resolved.original._is_select_statement + ): + self._warn_for_scalar_subquery_coercion() + return resolved.original.scalar_subquery() + else: + self._raise_for_expected(original_element, argname) + + +def _no_text_coercion( + element, argname=None, exc_cls=exc.ArgumentError, extra=None +): + raise exc_cls( + "%(extra)sTextual SQL expression %(expr)r %(argname)sshould be " + "explicitly declared as text(%(expr)r)" + % { + "expr": util.ellipses_string(element), + "argname": "for argument %s" % (argname,) if argname else "", + "extra": "%s " % extra if extra else "", + } + ) + + +class _NoTextCoercion(object): + def _literal_coercion(self, element, argname=None): + if isinstance(element, util.string_types) and issubclass( + elements.TextClause, self._role_class + ): + _no_text_coercion(element, argname) + else: + self._raise_for_expected(element, argname) + + +class _CoerceLiterals(object): + _coerce_consts = False + _coerce_star = False + _coerce_numerics = False + + def _text_coercion(self, element, argname=None): + return _no_text_coercion(element, argname) + + def _literal_coercion(self, element, argname=None): + if isinstance(element, util.string_types): + if self._coerce_star and element == "*": + return elements.ColumnClause("*", is_literal=True) + else: + return self._text_coercion(element, argname) + + if self._coerce_consts: + if element is None: + return elements.Null() + elif element is False: + return elements.False_() + elif element is True: + return elements.True_() + + if self._coerce_numerics and isinstance(element, (numbers.Number)): + return elements.ColumnClause(str(element), is_literal=True) + + self._raise_for_expected(element, argname) + + +class ExpressionElementImpl( + _ColumnCoercions, RoleImpl, roles.ExpressionElementRole +): + def _literal_coercion(self, element, name=None, type_=None, argname=None): + if element is None: + return elements.Null() + else: + try: + return elements.BindParameter( + name, element, type_, unique=True + ) + except exc.ArgumentError: + self._raise_for_expected(element) + + +class BinaryElementImpl( + ExpressionElementImpl, RoleImpl, roles.BinaryElementRole +): + def _literal_coercion( + self, element, expr, operator, bindparam_type=None, argname=None + ): + try: + return expr._bind_param(operator, element, type_=bindparam_type) + except exc.ArgumentError: + self._raise_for_expected(element) + + def _post_coercion(self, resolved, expr, **kw): + if ( + isinstance(resolved, elements.BindParameter) + and resolved.type._isnull + ): + resolved = resolved._clone() + resolved.type = expr.type + return resolved + + +class InElementImpl(RoleImpl, roles.InElementRole): + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if resolved._is_from_clause: + if ( + isinstance(resolved, selectable.Alias) + and resolved.original._is_select_statement + ): + return resolved.original + else: + return resolved.select() + else: + self._raise_for_expected(original_element, argname) + + def _literal_coercion(self, element, expr, operator, **kw): + if isinstance(element, collections_abc.Iterable) and not isinstance( + element, util.string_types + ): + args = [] + for o in element: + if not _is_literal(o): + if not isinstance(o, operators.ColumnOperators): + self._raise_for_expected(element, **kw) + elif o is None: + o = elements.Null() + else: + o = expr._bind_param(operator, o) + args.append(o) + + return elements.ClauseList(*args) + + else: + self._raise_for_expected(element, **kw) + + def _post_coercion(self, element, expr, operator, **kw): + if element._is_select_statement: + return element.scalar_subquery() + elif isinstance(element, elements.ClauseList): + if len(element.clauses) == 0: + op, negate_op = ( + (operators.empty_in_op, operators.empty_notin_op) + if operator is operators.in_op + else (operators.empty_notin_op, operators.empty_in_op) + ) + return element.self_group(against=op)._annotate( + dict(in_ops=(op, negate_op)) + ) + else: + return element.self_group(against=operator) + + elif isinstance(element, elements.BindParameter) and element.expanding: + + if isinstance(expr, elements.Tuple): + element = element._with_expanding_in_types( + [elem.type for elem in expr] + ) + return element + else: + return element + + +class WhereHavingImpl( + _CoerceLiterals, _ColumnCoercions, RoleImpl, roles.WhereHavingRole +): + + _coerce_consts = True + + def _text_coercion(self, element, argname=None): + return _no_text_coercion(element, argname) + + +class StatementOptionImpl( + _CoerceLiterals, RoleImpl, roles.StatementOptionRole +): + + _coerce_consts = True + + def _text_coercion(self, element, argname=None): + return elements.TextClause(element) + + +class ColumnArgumentImpl(_NoTextCoercion, RoleImpl, roles.ColumnArgumentRole): + pass + + +class ColumnArgumentOrKeyImpl( + _ReturnsStringKey, RoleImpl, roles.ColumnArgumentOrKeyRole +): + pass + + +class ByOfImpl(_CoerceLiterals, _ColumnCoercions, RoleImpl, roles.ByOfRole): + + _coerce_consts = True + + def _text_coercion(self, element, argname=None): + return elements._textual_label_reference(element) + + +class OrderByImpl(ByOfImpl, RoleImpl, roles.OrderByRole): + def _post_coercion(self, resolved): + if ( + isinstance(resolved, self._role_class) + and resolved._order_by_label_element is not None + ): + return elements._label_reference(resolved) + else: + return resolved + + +class DMLColumnImpl(_ReturnsStringKey, RoleImpl, roles.DMLColumnRole): + def _post_coercion(self, element, as_key=False): + if as_key: + return element.key + else: + return element + + +class ConstExprImpl(RoleImpl, roles.ConstExprRole): + def _literal_coercion(self, element, argname=None): + if element is None: + return elements.Null() + elif element is False: + return elements.False_() + elif element is True: + return elements.True_() + else: + self._raise_for_expected(element, argname) + + +class TruncatedLabelImpl(_StringOnly, RoleImpl, roles.TruncatedLabelRole): + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if isinstance(original_element, util.string_types): + return resolved + else: + self._raise_for_expected(original_element, argname) + + def _literal_coercion(self, element, argname=None): + """coerce the given value to :class:`._truncated_label`. + + Existing :class:`._truncated_label` and + :class:`._anonymous_label` objects are passed + unchanged. + """ + + if isinstance(element, elements._truncated_label): + return element + else: + return elements._truncated_label(element) + + +class DDLExpressionImpl(_CoerceLiterals, RoleImpl, roles.DDLExpressionRole): + + _coerce_consts = True + + def _text_coercion(self, element, argname=None): + return elements.TextClause(element) + + +class DDLConstraintColumnImpl( + _ReturnsStringKey, RoleImpl, roles.DDLConstraintColumnRole +): + pass + + +class LimitOffsetImpl(RoleImpl, roles.LimitOffsetRole): + def _implicit_coercions(self, element, resolved, argname=None, **kw): + if resolved is None: + return None + else: + self._raise_for_expected(element, argname) + + def _literal_coercion(self, element, name, type_, **kw): + if element is None: + return None + else: + value = util.asint(element) + return selectable._OffsetLimitParam( + name, value, type_=type_, unique=True + ) + + +class LabeledColumnExprImpl( + ExpressionElementImpl, roles.LabeledColumnExprRole +): + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if isinstance(resolved, roles.ExpressionElementRole): + return resolved.label(None) + else: + new = super(LabeledColumnExprImpl, self)._implicit_coercions( + original_element, resolved, argname=argname, **kw + ) + if isinstance(new, roles.ExpressionElementRole): + return new.label(None) + else: + self._raise_for_expected(original_element, argname) + + +class ColumnsClauseImpl(_CoerceLiterals, RoleImpl, roles.ColumnsClauseRole): + + _coerce_consts = True + _coerce_numerics = True + _coerce_star = True + + _guess_straight_column = re.compile(r"^\w\S*$", re.I) + + def _text_coercion(self, element, argname=None): + element = str(element) + + guess_is_literal = not self._guess_straight_column.match(element) + raise exc.ArgumentError( + "Textual column expression %(column)r %(argname)sshould be " + "explicitly declared with text(%(column)r), " + "or use %(literal_column)s(%(column)r) " + "for more specificity" + % { + "column": util.ellipses_string(element), + "argname": "for argument %s" % (argname,) if argname else "", + "literal_column": "literal_column" + if guess_is_literal + else "column", + } + ) + + +class ReturnsRowsImpl(RoleImpl, roles.ReturnsRowsRole): + pass + + +class StatementImpl(_NoTextCoercion, RoleImpl, roles.StatementRole): + pass + + +class CoerceTextStatementImpl(_CoerceLiterals, RoleImpl, roles.StatementRole): + def _text_coercion(self, element, argname=None): + return elements.TextClause(element) + + +class SelectStatementImpl( + _NoTextCoercion, RoleImpl, roles.SelectStatementRole +): + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if resolved._is_text_clause: + return resolved.columns() + else: + self._raise_for_expected(original_element, argname) + + +class HasCTEImpl(ReturnsRowsImpl, roles.HasCTERole): + pass + + +class FromClauseImpl(_NoTextCoercion, RoleImpl, roles.FromClauseRole): + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if resolved._is_text_clause: + return resolved + else: + self._raise_for_expected(original_element, argname) + + +class DMLSelectImpl(_NoTextCoercion, RoleImpl, roles.DMLSelectRole): + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if resolved._is_from_clause: + if ( + isinstance(resolved, selectable.Alias) + and resolved.original._is_select_statement + ): + return resolved.original + else: + return resolved.select() + else: + self._raise_for_expected(original_element, argname) + + +class CompoundElementImpl( + _NoTextCoercion, RoleImpl, roles.CompoundElementRole +): + def _implicit_coercions(self, original_element, resolved, argname=None): + if resolved._is_from_clause: + return resolved + else: + self._raise_for_expected(original_element, argname) + + +_impl_lookup = {} + + +for name in dir(roles): + cls = getattr(roles, name) + if name.endswith("Role"): + name = name.replace("Role", "Impl") + if name in globals(): + impl = globals()[name](cls) + _impl_lookup[cls] = impl diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index c7fe3dc50..8080d2cc6 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -27,14 +27,15 @@ import contextlib import itertools import re +from . import coercions from . import crud from . import elements from . import functions from . import operators +from . import roles from . import schema from . import selectable from . import sqltypes -from . import visitors from .. import exc from .. import util @@ -400,7 +401,9 @@ class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)): return type_._compiler_dispatch(self, **kw) -class _CompileLabel(visitors.Visitable): +# this was a Visitable, but to allow accurate detection of +# column elements this is actually a column element +class _CompileLabel(elements.ColumnElement): """lightweight label object which acts as an expression.Label.""" @@ -766,10 +769,10 @@ class SQLCompiler(Compiled): else: col = with_cols[element.element] except KeyError: - elements._no_text_coercion( + coercions._no_text_coercion( element.element, - exc.CompileError, - "Can't resolve label reference for ORDER BY / GROUP BY.", + extra="Can't resolve label reference for ORDER BY / GROUP BY.", + exc_cls=exc.CompileError, ) else: kwargs["render_label_as_label"] = col @@ -1635,7 +1638,6 @@ class SQLCompiler(Compiled): if is_new_cte: self.ctes_by_name[cte_name] = cte - # look for embedded DML ctes and propagate autocommit if ( "autocommit" in cte.element._execution_options and "autocommit" not in self.execution_options @@ -1656,10 +1658,10 @@ class SQLCompiler(Compiled): self.ctes_recursive = True text = self.preparer.format_alias(cte, cte_name) if cte.recursive: - if isinstance(cte.original, selectable.Select): - col_source = cte.original - elif isinstance(cte.original, selectable.CompoundSelect): - col_source = cte.original.selects[0] + if isinstance(cte.element, selectable.Select): + col_source = cte.element + elif isinstance(cte.element, selectable.CompoundSelect): + col_source = cte.element.selects[0] else: assert False recur_cols = [ @@ -1810,7 +1812,7 @@ class SQLCompiler(Compiled): ): result_expr = _CompileLabel( col_expr, - elements._as_truncated(column.name), + coercions.expect(roles.TruncatedLabelRole, column.name), alt_names=(column.key,), ) elif ( @@ -1830,7 +1832,7 @@ class SQLCompiler(Compiled): # assert isinstance(column, elements.ColumnClause) result_expr = _CompileLabel( col_expr, - elements._as_truncated(column.name), + coercions.expect(roles.TruncatedLabelRole, column.name), alt_names=(column.key,), ) else: @@ -1880,7 +1882,7 @@ class SQLCompiler(Compiled): newelem = cloned[element] = element._clone() if ( - newelem.is_selectable + newelem._is_from_clause and newelem._is_join and isinstance(newelem.right, selectable.FromGrouping) ): @@ -1933,7 +1935,7 @@ class SQLCompiler(Compiled): # marker in the stack. kw["transform_clue"] = "select_container" newelem._copy_internals(clone=visit, **kw) - elif newelem.is_selectable and newelem._is_select: + elif newelem._is_returns_rows and newelem._is_select_statement: barrier_select = ( kw.get("transform_clue", None) == "select_container" ) @@ -2349,6 +2351,7 @@ class SQLCompiler(Compiled): + join_type + join.right._compiler_dispatch(self, asfrom=True, **kwargs) + " ON " + # TODO: likely need asfrom=True here? + join.onclause._compiler_dispatch(self, **kwargs) ) diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 552f61b4a..881ea9fcd 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -9,14 +9,16 @@ within INSERT and UPDATE statements. """ +import functools import operator +from . import coercions from . import dml from . import elements +from . import roles from .. import exc from .. import util - REQUIRED = util.symbol( "REQUIRED", """ @@ -174,7 +176,7 @@ def _get_crud_params(compiler, stmt, **kw): if check: raise exc.CompileError( "Unconsumed column names: %s" - % (", ".join("%s" % c for c in check)) + % (", ".join("%s" % (c,) for c in check)) ) if stmt._has_multi_parameters: @@ -207,8 +209,12 @@ def _key_getters_for_crud_column(compiler, stmt): # statement. _et = set(stmt._extra_froms) + c_key_role = functools.partial( + coercions.expect_as_key, roles.DMLColumnRole + ) + def _column_as_key(key): - str_key = elements._column_as_key(key) + str_key = c_key_role(key) if hasattr(key, "table") and key.table in _et: return (key.table.name, str_key) else: @@ -227,7 +233,9 @@ def _key_getters_for_crud_column(compiler, stmt): return col.key else: - _column_as_key = elements._column_as_key + _column_as_key = functools.partial( + coercions.expect_as_key, roles.DMLColumnRole + ) _getattr_col_key = _col_bind_name = operator.attrgetter("key") return _column_as_key, _getattr_col_key, _col_bind_name @@ -386,7 +394,7 @@ def _append_param_parameter( kw, ): value = parameters.pop(col_key) - if elements._is_literal(value): + if coercions._is_literal(value): value = _create_bind_param( compiler, c, @@ -633,9 +641,8 @@ def _get_multitable_params( values, kw, ): - normalized_params = dict( - (elements._clause_element_as_expr(c), param) + (coercions.expect(roles.DMLColumnRole, c), param) for c, param in stmt_parameters.items() ) affected_tables = set() @@ -645,7 +652,7 @@ def _get_multitable_params( affected_tables.add(t) check_columns[_getattr_col_key(c)] = c value = normalized_params[c] - if elements._is_literal(value): + if coercions._is_literal(value): value = _create_bind_param( compiler, c, @@ -697,7 +704,7 @@ def _extend_values_for_multiparams(compiler, stmt, values, kw): if col in row or col.key in row: key = col if col in row else col.key - if elements._is_literal(row[key]): + if coercions._is_literal(row[key]): new_param = _create_bind_param( compiler, col, @@ -730,7 +737,7 @@ def _get_stmt_parameters_params( # a non-Column expression on the left side; # add it to values() in an "as-is" state, # coercing right side to bound param - if elements._is_literal(v): + if coercions._is_literal(v): v = compiler.process( elements.BindParameter(None, v, type_=k.type), **kw ) diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index d87a6a1b0..ff36a68e4 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -10,6 +10,7 @@ to invoke them for a create/drop call. """ +from . import roles from .base import _bind_or_error from .base import _generative from .base import Executable @@ -29,7 +30,7 @@ class _DDLCompiles(ClauseElement): return dialect.ddl_compiler(dialect, self, **kw) -class DDLElement(Executable, _DDLCompiles): +class DDLElement(roles.DDLRole, Executable, _DDLCompiles): """Base class for DDL expression constructs. This class is the base for the general purpose :class:`.DDL` class, diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index 9a12b84cd..918f7524e 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -8,32 +8,21 @@ """Default implementation of SQL comparison operations. """ + +from . import coercions from . import operators +from . import roles from . import type_api -from .elements import _clause_element_as_expr -from .elements import _const_expr -from .elements import _is_literal -from .elements import _literal_as_text from .elements import and_ from .elements import BinaryExpression -from .elements import BindParameter -from .elements import ClauseElement from .elements import ClauseList from .elements import collate from .elements import CollectionAggregate -from .elements import ColumnElement from .elements import False_ from .elements import Null from .elements import or_ -from .elements import TextClause from .elements import True_ -from .elements import Tuple from .elements import UnaryExpression -from .elements import Visitable -from .selectable import Alias -from .selectable import ScalarSelect -from .selectable import Selectable -from .selectable import SelectBase from .. import exc from .. import util @@ -62,7 +51,7 @@ def _boolean_compare( ): return BinaryExpression( expr, - _literal_as_text(obj), + coercions.expect(roles.ConstExprRole, obj), op, type_=result_type, negate=negate, @@ -71,7 +60,7 @@ def _boolean_compare( elif op in (operators.is_distinct_from, operators.isnot_distinct_from): return BinaryExpression( expr, - _literal_as_text(obj), + coercions.expect(roles.ConstExprRole, obj), op, type_=result_type, negate=negate, @@ -82,7 +71,7 @@ def _boolean_compare( if op in (operators.eq, operators.is_): return BinaryExpression( expr, - _const_expr(obj), + coercions.expect(roles.ConstExprRole, obj), operators.is_, negate=operators.isnot, type_=result_type, @@ -90,7 +79,7 @@ def _boolean_compare( elif op in (operators.ne, operators.isnot): return BinaryExpression( expr, - _const_expr(obj), + coercions.expect(roles.ConstExprRole, obj), operators.isnot, negate=operators.is_, type_=result_type, @@ -102,7 +91,9 @@ def _boolean_compare( "operators can be used with None/True/False" ) else: - obj = _check_literal(expr, op, obj) + obj = coercions.expect( + roles.BinaryElementRole, element=obj, operator=op, expr=expr + ) if reverse: return BinaryExpression( @@ -127,7 +118,9 @@ def _custom_op_operate(expr, op, obj, reverse=False, result_type=None, **kw): def _binary_operate(expr, op, obj, reverse=False, result_type=None, **kw): - obj = _check_literal(expr, op, obj) + obj = coercions.expect( + roles.BinaryElementRole, obj, expr=expr, operator=op + ) if reverse: left, right = obj, expr @@ -156,77 +149,22 @@ def _scalar(expr, op, fn, **kw): def _in_impl(expr, op, seq_or_selectable, negate_op, **kw): - seq_or_selectable = _clause_element_as_expr(seq_or_selectable) - - if isinstance(seq_or_selectable, ScalarSelect): - return _boolean_compare(expr, op, seq_or_selectable, negate=negate_op) - elif isinstance(seq_or_selectable, SelectBase): - - # TODO: if we ever want to support (x, y, z) IN (select x, - # y, z from table), we would need a multi-column version of - # as_scalar() to produce a multi- column selectable that - # does not export itself as a FROM clause - - return _boolean_compare( - expr, op, seq_or_selectable.as_scalar(), negate=negate_op, **kw - ) - elif isinstance(seq_or_selectable, (Selectable, TextClause)): - return _boolean_compare( - expr, op, seq_or_selectable, negate=negate_op, **kw - ) - elif isinstance(seq_or_selectable, ClauseElement): - if ( - isinstance(seq_or_selectable, BindParameter) - and seq_or_selectable.expanding - ): - - if isinstance(expr, Tuple): - seq_or_selectable = seq_or_selectable._with_expanding_in_types( - [elem.type for elem in expr] - ) - - return _boolean_compare( - expr, op, seq_or_selectable, negate=negate_op - ) - else: - raise exc.InvalidRequestError( - "in_() accepts" - " either a list of expressions, " - 'a selectable, or an "expanding" bound parameter: %r' - % seq_or_selectable - ) - - # Handle non selectable arguments as sequences - args = [] - for o in seq_or_selectable: - if not _is_literal(o): - if not isinstance(o, operators.ColumnOperators): - raise exc.InvalidRequestError( - "in_() accepts" - " either a list of expressions, " - 'a selectable, or an "expanding" bound parameter: %r' % o - ) - elif o is None: - o = Null() - else: - o = expr._bind_param(op, o) - args.append(o) - - if len(args) == 0: - op, negate_op = ( - (operators.empty_in_op, operators.empty_notin_op) - if op is operators.in_op - else (operators.empty_notin_op, operators.empty_in_op) - ) + seq_or_selectable = coercions.expect( + roles.InElementRole, seq_or_selectable, expr=expr, operator=op + ) + if "in_ops" in seq_or_selectable._annotations: + op, negate_op = seq_or_selectable._annotations["in_ops"] return _boolean_compare( - expr, op, ClauseList(*args).self_group(against=op), negate=negate_op + expr, op, seq_or_selectable, negate=negate_op, **kw ) def _getitem_impl(expr, op, other, **kw): if isinstance(expr.type, type_api.INDEXABLE): - other = _check_literal(expr, op, other) + other = coercions.expect( + roles.BinaryElementRole, other, expr=expr, operator=op + ) return _binary_operate(expr, op, other, **kw) else: _unsupported_impl(expr, op, other, **kw) @@ -257,7 +195,12 @@ def _match_impl(expr, op, other, **kw): return _boolean_compare( expr, operators.match_op, - _check_literal(expr, operators.match_op, other), + coercions.expect( + roles.BinaryElementRole, + other, + expr=expr, + operator=operators.match_op, + ), result_type=type_api.MATCHTYPE, negate=operators.notmatch_op if op is operators.match_op @@ -278,8 +221,18 @@ def _between_impl(expr, op, cleft, cright, **kw): return BinaryExpression( expr, ClauseList( - _check_literal(expr, operators.and_, cleft), - _check_literal(expr, operators.and_, cright), + coercions.expect( + roles.BinaryElementRole, + cleft, + expr=expr, + operator=operators.and_, + ), + coercions.expect( + roles.BinaryElementRole, + cright, + expr=expr, + operator=operators.and_, + ), operator=operators.and_, group=False, group_contents=False, @@ -349,22 +302,3 @@ operator_lookup = { "rshift": (_unsupported_impl,), "contains": (_unsupported_impl,), } - - -def _check_literal(expr, operator, other, bindparam_type=None): - if isinstance(other, (ColumnElement, TextClause)): - if isinstance(other, BindParameter) and other.type._isnull: - other = other._clone() - other.type = expr.type - return other - elif hasattr(other, "__clause_element__"): - other = other.__clause_element__() - elif isinstance(other, type_api.TypeEngine.Comparator): - other = other.expr - - if isinstance(other, (SelectBase, Alias)): - return other.as_scalar() - elif not isinstance(other, Visitable): - return expr._bind_param(operator, other, type_=bindparam_type) - else: - return other diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 3c40e7914..c7d83fc12 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -9,18 +9,16 @@ Provide :class:`.Insert`, :class:`.Update` and :class:`.Delete`. """ +from . import coercions +from . import roles from .base import _from_objects from .base import _generative from .base import DialectKWArgs from .base import Executable from .elements import _clone -from .elements import _column_as_key -from .elements import _literal_as_text from .elements import and_ from .elements import ClauseElement from .elements import Null -from .selectable import _interpret_as_from -from .selectable import _interpret_as_select from .selectable import HasCTE from .selectable import HasPrefixes from .. import exc @@ -28,7 +26,12 @@ from .. import util class UpdateBase( - HasCTE, DialectKWArgs, HasPrefixes, Executable, ClauseElement + roles.DMLRole, + HasCTE, + DialectKWArgs, + HasPrefixes, + Executable, + ClauseElement, ): """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements. @@ -210,7 +213,7 @@ class ValuesBase(UpdateBase): _post_values_clause = None def __init__(self, table, values, prefixes): - self.table = _interpret_as_from(table) + self.table = coercions.expect(roles.FromClauseRole, table) self.parameters, self._has_multi_parameters = self._process_colparams( values ) @@ -604,13 +607,16 @@ class Insert(ValuesBase): ) self.parameters, self._has_multi_parameters = self._process_colparams( - {_column_as_key(n): Null() for n in names} + { + coercions.expect(roles.DMLColumnRole, n, as_key=True): Null() + for n in names + } ) self.select_names = names self.inline = True self.include_insert_from_select_defaults = include_defaults - self.select = _interpret_as_select(select) + self.select = coercions.expect(roles.DMLSelectRole, select) def _copy_internals(self, clone=_clone, **kw): # TODO: coverage @@ -678,7 +684,7 @@ class Update(ValuesBase): users.update().values(name='ed').where( users.c.name==select([addresses.c.email_address]).\ where(addresses.c.user_id==users.c.id).\ - as_scalar() + scalar_subquery() ) :param values: @@ -744,7 +750,7 @@ class Update(ValuesBase): users.update().values( name=select([addresses.c.email_address]).\ where(addresses.c.user_id==users.c.id).\ - as_scalar() + scalar_subquery() ) .. seealso:: @@ -759,7 +765,9 @@ class Update(ValuesBase): self._bind = bind self._returning = returning if whereclause is not None: - self._whereclause = _literal_as_text(whereclause) + self._whereclause = coercions.expect( + roles.WhereHavingRole, whereclause + ) else: self._whereclause = None self.inline = inline @@ -785,10 +793,13 @@ class Update(ValuesBase): """ if self._whereclause is not None: self._whereclause = and_( - self._whereclause, _literal_as_text(whereclause) + self._whereclause, + coercions.expect(roles.WhereHavingRole, whereclause), ) else: - self._whereclause = _literal_as_text(whereclause) + self._whereclause = coercions.expect( + roles.WhereHavingRole, whereclause + ) @property def _extra_froms(self): @@ -846,7 +857,7 @@ class Delete(UpdateBase): users.delete().where( users.c.name==select([addresses.c.email_address]).\ where(addresses.c.user_id==users.c.id).\ - as_scalar() + scalar_subquery() ) .. versionchanged:: 1.2.0 @@ -858,14 +869,16 @@ class Delete(UpdateBase): """ self._bind = bind - self.table = _interpret_as_from(table) + self.table = coercions.expect(roles.FromClauseRole, table) self._returning = returning if prefixes: self._setup_prefixes(prefixes) if whereclause is not None: - self._whereclause = _literal_as_text(whereclause) + self._whereclause = coercions.expect( + roles.WhereHavingRole, whereclause + ) else: self._whereclause = None @@ -883,10 +896,13 @@ class Delete(UpdateBase): if self._whereclause is not None: self._whereclause = and_( - self._whereclause, _literal_as_text(whereclause) + self._whereclause, + coercions.expect(roles.WhereHavingRole, whereclause), ) else: - self._whereclause = _literal_as_text(whereclause) + self._whereclause = coercions.expect( + roles.WhereHavingRole, whereclause + ) @property def _extra_froms(self): diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index e634e5a36..a333303ec 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -13,12 +13,13 @@ from __future__ import unicode_literals import itertools -import numbers import operator import re from . import clause_compare +from . import coercions from . import operators +from . import roles from . import type_api from .annotation import Annotated from .base import _generative @@ -26,6 +27,7 @@ from .base import Executable from .base import Immutable from .base import NO_ARG from .base import PARSE_AUTOCOMMIT +from .coercions import _document_text_coercion from .visitors import cloned_traverse from .visitors import traverse from .visitors import Visitable @@ -38,20 +40,6 @@ def _clone(element, **kw): return element._clone() -def _document_text_coercion(paramname, meth_rst, param_rst): - return util.add_parameter_text( - paramname, - ( - ".. warning:: " - "The %s argument to %s can be passed as a Python string argument, " - "which will be treated " - "as **trusted SQL text** and rendered as given. **DO NOT PASS " - "UNTRUSTED INPUT TO THIS PARAMETER**." - ) - % (param_rst, meth_rst), - ) - - def collate(expression, collation): """Return the clause ``expression COLLATE collation``. @@ -71,7 +59,7 @@ def collate(expression, collation): """ - expr = _literal_as_binds(expression) + expr = coercions.expect(roles.ExpressionElementRole, expression) return BinaryExpression( expr, CollationClause(collation), operators.collate, type_=expr.type ) @@ -127,7 +115,7 @@ def between(expr, lower_bound, upper_bound, symmetric=False): :meth:`.ColumnElement.between` """ - expr = _literal_as_binds(expr) + expr = coercions.expect(roles.ExpressionElementRole, expr) return expr.between(lower_bound, upper_bound, symmetric=symmetric) @@ -172,11 +160,11 @@ def not_(clause): same result. """ - return operators.inv(_literal_as_binds(clause)) + return operators.inv(coercions.expect(roles.ExpressionElementRole, clause)) @inspection._self_inspects -class ClauseElement(Visitable): +class ClauseElement(roles.SQLRole, Visitable): """Base class for elements of a programmatically constructed SQL expression. @@ -188,13 +176,20 @@ class ClauseElement(Visitable): supports_execution = False _from_objects = [] bind = None + description = None _is_clone_of = None - is_selectable = False + is_clause_element = True + is_selectable = False - description = None - _order_by_label_element = None + _is_textual = False + _is_from_clause = False + _is_returns_rows = False + _is_text_clause = False _is_from_container = False + _is_select_statement = False + + _order_by_label_element = None def _clone(self): """Create a shallow copy of this ClauseElement. @@ -238,7 +233,7 @@ class ClauseElement(Visitable): """ - raise NotImplementedError(self.__class__) + raise NotImplementedError() @property def _constructor(self): @@ -394,6 +389,7 @@ class ClauseElement(Visitable): return [] def self_group(self, against=None): + # type: (Optional[Any]) -> ClauseElement """Apply a 'grouping' to this :class:`.ClauseElement`. This method is overridden by subclasses to return a @@ -553,7 +549,20 @@ class ClauseElement(Visitable): ) -class ColumnElement(operators.ColumnOperators, ClauseElement): +class ColumnElement( + roles.ColumnArgumentOrKeyRole, + roles.StatementOptionRole, + roles.WhereHavingRole, + roles.BinaryElementRole, + roles.OrderByRole, + roles.ColumnsClauseRole, + roles.LimitOffsetRole, + roles.DMLColumnRole, + roles.DDLConstraintColumnRole, + roles.DDLExpressionRole, + operators.ColumnOperators, + ClauseElement, +): """Represent a column-oriented SQL expression suitable for usage in the "columns" clause, WHERE clause etc. of a statement. @@ -586,17 +595,13 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): :class:`.TypeEngine` objects) are applied to the value. * any special object value, typically ORM-level constructs, which - feature a method called ``__clause_element__()``. The Core + feature an accessor called ``__clause_element__()``. The Core expression system looks for this method when an object of otherwise unknown type is passed to a function that is looking to coerce the - argument into a :class:`.ColumnElement` expression. The - ``__clause_element__()`` method, if present, should return a - :class:`.ColumnElement` instance. The primary use of - ``__clause_element__()`` within SQLAlchemy is that of class-bound - attributes on ORM-mapped classes; a ``User`` class which contains a - mapped attribute named ``.name`` will have a method - ``User.name.__clause_element__()`` which when invoked returns the - :class:`.Column` called ``name`` associated with the mapped table. + argument into a :class:`.ColumnElement` and sometimes a + :class:`.SelectBase` expression. It is used within the ORM to + convert from ORM-specific objects like mapped classes and + mapped attributes into Core expression objects. * The Python ``None`` value is typically interpreted as ``NULL``, which in SQLAlchemy Core produces an instance of :func:`.null`. @@ -702,6 +707,7 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): _alt_names = () def self_group(self, against=None): + # type: (Module, Module, Optional[Any]) -> ClauseEleent if ( against in (operators.and_, operators.or_, operators._asbool) and self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity @@ -826,7 +832,9 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): else: key = name co = ColumnClause( - _as_truncated(name) if name_is_truncatable else name, + coercions.expect(roles.TruncatedLabelRole, name) + if name_is_truncatable + else name, type_=getattr(self, "type", None), _selectable=selectable, ) @@ -878,7 +886,7 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): ) -class BindParameter(ColumnElement): +class BindParameter(roles.InElementRole, ColumnElement): r"""Represent a "bound expression". :class:`.BindParameter` is invoked explicitly using the @@ -1235,7 +1243,8 @@ class BindParameter(ColumnElement): "bindparams collection argument required for _cache_key " "implementation. Bound parameter cache keys are not safe " "to use without accommodating for the value or callable " - "within the parameter itself.") + "within the parameter itself." + ) else: bindparams.append(self) return (BindParameter, self.type._cache_key, self._orig_key) @@ -1282,7 +1291,20 @@ class TypeClause(ClauseElement): return (TypeClause, self.type._cache_key) -class TextClause(Executable, ClauseElement): +class TextClause( + roles.DDLConstraintColumnRole, + roles.DDLExpressionRole, + roles.StatementOptionRole, + roles.WhereHavingRole, + roles.OrderByRole, + roles.FromClauseRole, + roles.SelectStatementRole, + roles.CoerceTextStatementRole, + roles.BinaryElementRole, + roles.InElementRole, + Executable, + ClauseElement, +): """Represent a literal SQL text fragment. E.g.:: @@ -1304,6 +1326,10 @@ class TextClause(Executable, ClauseElement): __visit_name__ = "textclause" + _is_text_clause = True + + _is_textual = True + _bind_params_regex = re.compile(r"(?<![:\w\x5c]):(\w+)(?!:)", re.UNICODE) _execution_options = Executable._execution_options.union( {"autocommit": PARSE_AUTOCOMMIT} @@ -1318,20 +1344,16 @@ class TextClause(Executable, ClauseElement): def _select_iterable(self): return (self,) - @property - def selectable(self): - # allows text() to be considered by - # _interpret_as_from - return self - - _hide_froms = [] - # help in those cases where text() is # interpreted in a column expression situation key = _label = _resolve_label = None _allow_label_resolve = False + @property + def _hide_froms(self): + return [] + def __init__(self, text, bind=None): self._bind = bind self._bindparams = {} @@ -1670,7 +1692,6 @@ class TextClause(Executable, ClauseElement): """ - positional_input_cols = [ ColumnClause(col.key, types.pop(col.key)) if col.key in types @@ -1696,6 +1717,7 @@ class TextClause(Executable, ClauseElement): return self.type.comparator_factory(self) def self_group(self, against=None): + # type: (Optional[Any]) -> Union[Grouping, TextClause] if against is operators.in_op: return Grouping(self) else: @@ -1715,7 +1737,7 @@ class TextClause(Executable, ClauseElement): ) -class Null(ColumnElement): +class Null(roles.ConstExprRole, ColumnElement): """Represent the NULL keyword in a SQL statement. :class:`.Null` is accessed as a constant via the @@ -1739,7 +1761,7 @@ class Null(ColumnElement): return (Null,) -class False_(ColumnElement): +class False_(roles.ConstExprRole, ColumnElement): """Represent the ``false`` keyword, or equivalent, in a SQL statement. :class:`.False_` is accessed as a constant via the @@ -1798,7 +1820,7 @@ class False_(ColumnElement): return (False_,) -class True_(ColumnElement): +class True_(roles.ConstExprRole, ColumnElement): """Represent the ``true`` keyword, or equivalent, in a SQL statement. :class:`.True_` is accessed as a constant via the @@ -1864,7 +1886,12 @@ class True_(ColumnElement): return (True_,) -class ClauseList(ClauseElement): +class ClauseList( + roles.InElementRole, + roles.OrderByRole, + roles.ColumnsClauseRole, + ClauseElement, +): """Describe a list of clauses, separated by an operator. By default, is comma-separated, such as a column listing. @@ -1877,16 +1904,22 @@ class ClauseList(ClauseElement): self.operator = kwargs.pop("operator", operators.comma_op) self.group = kwargs.pop("group", True) self.group_contents = kwargs.pop("group_contents", True) - text_converter = kwargs.pop( - "_literal_as_text", _expression_literal_as_text + + self._text_converter_role = text_converter_role = kwargs.pop( + "_literal_as_text_role", roles.WhereHavingRole ) if self.group_contents: self.clauses = [ - text_converter(clause).self_group(against=self.operator) + coercions.expect(text_converter_role, clause).self_group( + against=self.operator + ) for clause in clauses ] else: - self.clauses = [text_converter(clause) for clause in clauses] + self.clauses = [ + coercions.expect(text_converter_role, clause) + for clause in clauses + ] self._is_implicitly_boolean = operators.is_boolean(self.operator) def __iter__(self): @@ -1902,10 +1935,14 @@ class ClauseList(ClauseElement): def append(self, clause): if self.group_contents: self.clauses.append( - _literal_as_text(clause).self_group(against=self.operator) + coercions.expect(self._text_converter_role, clause).self_group( + against=self.operator + ) ) else: - self.clauses.append(_literal_as_text(clause)) + self.clauses.append( + coercions.expect(self._text_converter_role, clause) + ) def _copy_internals(self, clone=_clone, **kw): self.clauses = [clone(clause, **kw) for clause in self.clauses] @@ -1923,6 +1960,7 @@ class ClauseList(ClauseElement): return list(itertools.chain(*[c._from_objects for c in self.clauses])) def self_group(self, against=None): + # type: (Optional[Any]) -> ClauseElement if self.group and operators.is_precedent(self.operator, against): return Grouping(self) else: @@ -1947,7 +1985,7 @@ class BooleanClauseList(ClauseList, ColumnElement): convert_clauses = [] clauses = [ - _expression_literal_as_text(clause) + coercions.expect(roles.WhereHavingRole, clause) for clause in util.coerce_generator_arg(clauses) ] for clause in clauses: @@ -2055,6 +2093,7 @@ class BooleanClauseList(ClauseList, ColumnElement): return (self,) def self_group(self, against=None): + # type: (Optional[Any]) -> ClauseElement if not self.clauses: return self else: @@ -2092,7 +2131,9 @@ class Tuple(ClauseList, ColumnElement): """ - clauses = [_literal_as_binds(c) for c in clauses] + clauses = [ + coercions.expect(roles.ExpressionElementRole, c) for c in clauses + ] self._type_tuple = [arg.type for arg in clauses] self.type = kw.pop( "type_", @@ -2283,12 +2324,20 @@ class Case(ColumnElement): if value is not None: whenlist = [ - (_literal_as_binds(c).self_group(), _literal_as_binds(r)) + ( + coercions.expect( + roles.ExpressionElementRole, c + ).self_group(), + coercions.expect(roles.ExpressionElementRole, r), + ) for (c, r) in whens ] else: whenlist = [ - (_no_literals(c).self_group(), _literal_as_binds(r)) + ( + coercions.expect(roles.ColumnArgumentRole, c).self_group(), + coercions.expect(roles.ExpressionElementRole, r), + ) for (c, r) in whens ] @@ -2300,12 +2349,12 @@ class Case(ColumnElement): if value is None: self.value = None else: - self.value = _literal_as_binds(value) + self.value = coercions.expect(roles.ExpressionElementRole, value) self.type = type_ self.whens = whenlist if else_ is not None: - self.else_ = _literal_as_binds(else_) + self.else_ = coercions.expect(roles.ExpressionElementRole, else_) else: self.else_ = None @@ -2455,7 +2504,9 @@ class Cast(ColumnElement): """ self.type = type_api.to_instance(type_) - self.clause = _literal_as_binds(expression, type_=self.type) + self.clause = coercions.expect( + roles.ExpressionElementRole, expression, type_=self.type + ) self.typeclause = TypeClause(self.type) def _copy_internals(self, clone=_clone, **kw): @@ -2557,7 +2608,9 @@ class TypeCoerce(ColumnElement): """ self.type = type_api.to_instance(type_) - self.clause = _literal_as_binds(expression, type_=self.type) + self.clause = coercions.expect( + roles.ExpressionElementRole, expression, type_=self.type + ) def _copy_internals(self, clone=_clone, **kw): self.clause = clone(self.clause, **kw) @@ -2598,7 +2651,7 @@ class Extract(ColumnElement): """ self.type = type_api.INTEGERTYPE self.field = field - self.expr = _literal_as_binds(expr, None) + self.expr = coercions.expect(roles.ExpressionElementRole, expr) def _copy_internals(self, clone=_clone, **kw): self.expr = clone(self.expr, **kw) @@ -2733,7 +2786,7 @@ class UnaryExpression(ColumnElement): """ return UnaryExpression( - _literal_as_label_reference(column), + coercions.expect(roles.ByOfRole, column), modifier=operators.nullsfirst_op, wraps_column_expression=False, ) @@ -2776,7 +2829,7 @@ class UnaryExpression(ColumnElement): """ return UnaryExpression( - _literal_as_label_reference(column), + coercions.expect(roles.ByOfRole, column), modifier=operators.nullslast_op, wraps_column_expression=False, ) @@ -2817,7 +2870,7 @@ class UnaryExpression(ColumnElement): """ return UnaryExpression( - _literal_as_label_reference(column), + coercions.expect(roles.ByOfRole, column), modifier=operators.desc_op, wraps_column_expression=False, ) @@ -2857,7 +2910,7 @@ class UnaryExpression(ColumnElement): """ return UnaryExpression( - _literal_as_label_reference(column), + coercions.expect(roles.ByOfRole, column), modifier=operators.asc_op, wraps_column_expression=False, ) @@ -2898,7 +2951,7 @@ class UnaryExpression(ColumnElement): :data:`.func` """ - expr = _literal_as_binds(expr) + expr = coercions.expect(roles.ExpressionElementRole, expr) return UnaryExpression( expr, operator=operators.distinct_op, @@ -2953,6 +3006,7 @@ class UnaryExpression(ColumnElement): return ClauseElement._negate(self) def self_group(self, against=None): + # type: (Optional[Any]) -> ClauseElement if self.operator and operators.is_precedent(self.operator, against): return Grouping(self) else: @@ -2990,10 +3044,8 @@ class CollectionAggregate(UnaryExpression): """ - expr = _literal_as_binds(expr) + expr = coercions.expect(roles.ExpressionElementRole, expr) - if expr.is_selectable and hasattr(expr, "as_scalar"): - expr = expr.as_scalar() expr = expr.self_group() return CollectionAggregate( expr, @@ -3023,9 +3075,7 @@ class CollectionAggregate(UnaryExpression): """ - expr = _literal_as_binds(expr) - if expr.is_selectable and hasattr(expr, "as_scalar"): - expr = expr.as_scalar() + expr = coercions.expect(roles.ExpressionElementRole, expr) expr = expr.self_group() return CollectionAggregate( expr, @@ -3064,6 +3114,7 @@ class AsBoolean(UnaryExpression): self._is_implicitly_boolean = element._is_implicitly_boolean def self_group(self, against=None): + # type: (Optional[Any]) -> ClauseElement return self def _cache_key(self, **kw): @@ -3155,6 +3206,8 @@ class BinaryExpression(ColumnElement): ) def self_group(self, against=None): + # type: (Optional[Any]) -> ClauseElement + if operators.is_precedent(self.operator, against): return Grouping(self) else: @@ -3191,6 +3244,7 @@ class Slice(ColumnElement): self.type = type_api.NULLTYPE def self_group(self, against=None): + # type: (Optional[Any]) -> ClauseElement assert against is operator.getitem return self @@ -3215,6 +3269,7 @@ class Grouping(ColumnElement): self.type = getattr(element, "type", type_api.NULLTYPE) def self_group(self, against=None): + # type: (Optional[Any]) -> ClauseElement return self @util.memoized_property @@ -3363,13 +3418,12 @@ class Over(ColumnElement): self.element = element if order_by is not None: self.order_by = ClauseList( - *util.to_list(order_by), - _literal_as_text=_literal_as_label_reference + *util.to_list(order_by), _literal_as_text_role=roles.ByOfRole ) if partition_by is not None: self.partition_by = ClauseList( *util.to_list(partition_by), - _literal_as_text=_literal_as_label_reference + _literal_as_text_role=roles.ByOfRole ) if range_: @@ -3534,8 +3588,7 @@ class WithinGroup(ColumnElement): self.element = element if order_by is not None: self.order_by = ClauseList( - *util.to_list(order_by), - _literal_as_text=_literal_as_label_reference + *util.to_list(order_by), _literal_as_text_role=roles.ByOfRole ) def over(self, partition_by=None, order_by=None, range_=None, rows=None): @@ -3658,7 +3711,7 @@ class FunctionFilter(ColumnElement): """ for criterion in list(criterion): - criterion = _expression_literal_as_text(criterion) + criterion = coercions.expect(roles.WhereHavingRole, criterion) if self.criterion is not None: self.criterion = self.criterion & criterion @@ -3727,7 +3780,7 @@ class FunctionFilter(ColumnElement): ) -class Label(ColumnElement): +class Label(roles.LabeledColumnExprRole, ColumnElement): """Represents a column label (AS). Represent a label, as typically applied to any column-level @@ -3801,6 +3854,7 @@ class Label(ColumnElement): return self._element.self_group(against=operators.as_) def self_group(self, against=None): + # type: (Optional[Any]) -> ClauseElement return self._apply_to_inner(self._element.self_group, against=against) def _negate(self): @@ -3849,7 +3903,7 @@ class Label(ColumnElement): return e -class ColumnClause(Immutable, ColumnElement): +class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement): """Represents a column expression from any textual string. The :class:`.ColumnClause`, a lightweight analogue to the @@ -3985,14 +4039,14 @@ class ColumnClause(Immutable, ColumnElement): if ( self.is_literal or self.table is None - or self.table._textual + or self.table._is_textual or not hasattr(other, "proxy_set") or ( isinstance(other, ColumnClause) and ( other.is_literal or other.table is None - or other.table._textual + or other.table._is_textual ) ) ): @@ -4083,7 +4137,7 @@ class ColumnClause(Immutable, ColumnElement): counter += 1 label = _label - return _as_truncated(label) + return coercions.expect(roles.TruncatedLabelRole, label) else: return name @@ -4110,7 +4164,7 @@ class ColumnClause(Immutable, ColumnElement): # otherwise its considered to be a label is_literal = self.is_literal and (name is None or name == self.name) c = self._constructor( - _as_truncated(name or self.name) + coercions.expect(roles.TruncatedLabelRole, name or self.name) if name_is_truncatable else (name or self.name), type_=self.type, @@ -4250,6 +4304,108 @@ class quoted_name(util.MemoizedSlots, util.text_type): return "'%s'" % backslashed +def _expand_cloned(elements): + """expand the given set of ClauseElements to be the set of all 'cloned' + predecessors. + + """ + return itertools.chain(*[x._cloned_set for x in elements]) + + +def _select_iterables(elements): + """expand tables into individual columns in the + given list of column expressions. + + """ + return itertools.chain(*[c._select_iterable for c in elements]) + + +def _cloned_intersection(a, b): + """return the intersection of sets a and b, counting + any overlap between 'cloned' predecessors. + + The returned set is in terms of the entities present within 'a'. + + """ + all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) + return set( + elem for elem in a if all_overlap.intersection(elem._cloned_set) + ) + + +def _cloned_difference(a, b): + all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) + return set( + elem for elem in a if not all_overlap.intersection(elem._cloned_set) + ) + + +def _find_columns(clause): + """locate Column objects within the given expression.""" + + cols = util.column_set() + traverse(clause, {}, {"column": cols.add}) + return cols + + +def _type_from_args(args): + for a in args: + if not a.type._isnull: + return a.type + else: + return type_api.NULLTYPE + + +def _corresponding_column_or_error(fromclause, column, require_embedded=False): + c = fromclause.corresponding_column( + column, require_embedded=require_embedded + ) + if c is None: + raise exc.InvalidRequestError( + "Given column '%s', attached to table '%s', " + "failed to locate a corresponding column from table '%s'" + % (column, getattr(column, "table", None), fromclause.description) + ) + return c + + +class AnnotatedColumnElement(Annotated): + def __init__(self, element, values): + Annotated.__init__(self, element, values) + ColumnElement.comparator._reset(self) + for attr in ("name", "key", "table"): + if self.__dict__.get(attr, False) is None: + self.__dict__.pop(attr) + + def _with_annotations(self, values): + clone = super(AnnotatedColumnElement, self)._with_annotations(values) + ColumnElement.comparator._reset(clone) + return clone + + @util.memoized_property + def name(self): + """pull 'name' from parent, if not present""" + return self._Annotated__element.name + + @util.memoized_property + def table(self): + """pull 'table' from parent, if not present""" + return self._Annotated__element.table + + @util.memoized_property + def key(self): + """pull 'key' from parent, if not present""" + return self._Annotated__element.key + + @util.memoized_property + def info(self): + return self._Annotated__element.info + + @util.memoized_property + def anon_label(self): + return self._Annotated__element.anon_label + + class _truncated_label(quoted_name): """A unicode subclass used to identify symbolic " "names that may require truncation.""" @@ -4378,349 +4534,3 @@ class _anonymous_label(_truncated_label): else: # else skip the constructor call return self % map_ - - -def _as_truncated(value): - """coerce the given value to :class:`._truncated_label`. - - Existing :class:`._truncated_label` and - :class:`._anonymous_label` objects are passed - unchanged. - """ - - if isinstance(value, _truncated_label): - return value - else: - return _truncated_label(value) - - -def _string_or_unprintable(element): - if isinstance(element, util.string_types): - return element - else: - try: - return str(element) - except Exception: - return "unprintable element %r" % element - - -def _expand_cloned(elements): - """expand the given set of ClauseElements to be the set of all 'cloned' - predecessors. - - """ - return itertools.chain(*[x._cloned_set for x in elements]) - - -def _select_iterables(elements): - """expand tables into individual columns in the - given list of column expressions. - - """ - return itertools.chain(*[c._select_iterable for c in elements]) - - -def _cloned_intersection(a, b): - """return the intersection of sets a and b, counting - any overlap between 'cloned' predecessors. - - The returned set is in terms of the entities present within 'a'. - - """ - all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) - return set( - elem for elem in a if all_overlap.intersection(elem._cloned_set) - ) - - -def _cloned_difference(a, b): - all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) - return set( - elem for elem in a if not all_overlap.intersection(elem._cloned_set) - ) - - -@util.dependencies("sqlalchemy.sql.functions") -def _labeled(functions, element): - if not hasattr(element, "name") or isinstance( - element, functions.FunctionElement - ): - return element.label(None) - else: - return element - - -def _is_column(col): - """True if ``col`` is an instance of :class:`.ColumnElement`.""" - - return isinstance(col, ColumnElement) - - -def _find_columns(clause): - """locate Column objects within the given expression.""" - - cols = util.column_set() - traverse(clause, {}, {"column": cols.add}) - return cols - - -# there is some inconsistency here between the usage of -# inspect() vs. checking for Visitable and __clause_element__. -# Ideally all functions here would derive from inspect(), -# however the inspect() versions add significant callcount -# overhead for critical functions like _interpret_as_column_or_from(). -# Generally, the column-based functions are more performance critical -# and are fine just checking for __clause_element__(). It is only -# _interpret_as_from() where we'd like to be able to receive ORM entities -# that have no defined namespace, hence inspect() is needed there. - - -def _column_as_key(element): - if isinstance(element, util.string_types): - return element - if hasattr(element, "__clause_element__"): - element = element.__clause_element__() - try: - return element.key - except AttributeError: - return None - - -def _clause_element_as_expr(element): - if hasattr(element, "__clause_element__"): - return element.__clause_element__() - else: - return element - - -def _literal_as_label_reference(element): - if isinstance(element, util.string_types): - return _textual_label_reference(element) - - elif hasattr(element, "__clause_element__"): - element = element.__clause_element__() - - return _literal_as_text(element) - - -def _literal_and_labels_as_label_reference(element): - if isinstance(element, util.string_types): - return _textual_label_reference(element) - - elif hasattr(element, "__clause_element__"): - element = element.__clause_element__() - - if ( - isinstance(element, ColumnElement) - and element._order_by_label_element is not None - ): - return _label_reference(element) - else: - return _literal_as_text(element) - - -def _expression_literal_as_text(element): - return _literal_as_text(element) - - -def _literal_as(element, text_fallback): - if isinstance(element, Visitable): - return element - elif hasattr(element, "__clause_element__"): - return element.__clause_element__() - elif isinstance(element, util.string_types): - return text_fallback(element) - elif isinstance(element, (util.NoneType, bool)): - return _const_expr(element) - else: - raise exc.ArgumentError( - "SQL expression object expected, got object of type %r " - "instead" % type(element) - ) - - -def _literal_as_text(element, allow_coercion_to_text=False): - if allow_coercion_to_text: - return _literal_as(element, TextClause) - else: - return _literal_as(element, _no_text_coercion) - - -def _literal_as_column(element): - return _literal_as(element, ColumnClause) - - -def _no_column_coercion(element): - element = str(element) - guess_is_literal = not _guess_straight_column.match(element) - raise exc.ArgumentError( - "Textual column expression %(column)r should be " - "explicitly declared with text(%(column)r), " - "or use %(literal_column)s(%(column)r) " - "for more specificity" - % { - "column": util.ellipses_string(element), - "literal_column": "literal_column" - if guess_is_literal - else "column", - } - ) - - -def _no_text_coercion(element, exc_cls=exc.ArgumentError, extra=None): - raise exc_cls( - "%(extra)sTextual SQL expression %(expr)r should be " - "explicitly declared as text(%(expr)r)" - % { - "expr": util.ellipses_string(element), - "extra": "%s " % extra if extra else "", - } - ) - - -def _no_literals(element): - if hasattr(element, "__clause_element__"): - return element.__clause_element__() - elif not isinstance(element, Visitable): - raise exc.ArgumentError( - "Ambiguous literal: %r. Use the 'text()' " - "function to indicate a SQL expression " - "literal, or 'literal()' to indicate a " - "bound value." % (element,) - ) - else: - return element - - -def _is_literal(element): - return not isinstance(element, Visitable) and not hasattr( - element, "__clause_element__" - ) - - -def _only_column_elements_or_none(element, name): - if element is None: - return None - else: - return _only_column_elements(element, name) - - -def _only_column_elements(element, name): - if hasattr(element, "__clause_element__"): - element = element.__clause_element__() - if not isinstance(element, ColumnElement): - raise exc.ArgumentError( - "Column-based expression object expected for argument " - "'%s'; got: '%s', type %s" % (name, element, type(element)) - ) - return element - - -def _literal_as_binds(element, name=None, type_=None): - if hasattr(element, "__clause_element__"): - return element.__clause_element__() - elif not isinstance(element, Visitable): - if element is None: - return Null() - else: - return BindParameter(name, element, type_=type_, unique=True) - else: - return element - - -_guess_straight_column = re.compile(r"^\w\S*$", re.I) - - -def _interpret_as_column_or_from(element): - if isinstance(element, Visitable): - return element - elif hasattr(element, "__clause_element__"): - return element.__clause_element__() - - insp = inspection.inspect(element, raiseerr=False) - if insp is None: - if isinstance(element, (util.NoneType, bool)): - return _const_expr(element) - elif hasattr(insp, "selectable"): - return insp.selectable - - # be forgiving as this is an extremely common - # and known expression - if element == "*": - guess_is_literal = True - elif isinstance(element, (numbers.Number)): - return ColumnClause(str(element), is_literal=True) - else: - _no_column_coercion(element) - return ColumnClause(element, is_literal=guess_is_literal) - - -def _const_expr(element): - if isinstance(element, (Null, False_, True_)): - return element - elif element is None: - return Null() - elif element is False: - return False_() - elif element is True: - return True_() - else: - raise exc.ArgumentError("Expected None, False, or True") - - -def _type_from_args(args): - for a in args: - if not a.type._isnull: - return a.type - else: - return type_api.NULLTYPE - - -def _corresponding_column_or_error(fromclause, column, require_embedded=False): - c = fromclause.corresponding_column( - column, require_embedded=require_embedded - ) - if c is None: - raise exc.InvalidRequestError( - "Given column '%s', attached to table '%s', " - "failed to locate a corresponding column from table '%s'" - % (column, getattr(column, "table", None), fromclause.description) - ) - return c - - -class AnnotatedColumnElement(Annotated): - def __init__(self, element, values): - Annotated.__init__(self, element, values) - ColumnElement.comparator._reset(self) - for attr in ("name", "key", "table"): - if self.__dict__.get(attr, False) is None: - self.__dict__.pop(attr) - - def _with_annotations(self, values): - clone = super(AnnotatedColumnElement, self)._with_annotations(values) - ColumnElement.comparator._reset(clone) - return clone - - @util.memoized_property - def name(self): - """pull 'name' from parent, if not present""" - return self._Annotated__element.name - - @util.memoized_property - def table(self): - """pull 'table' from parent, if not present""" - return self._Annotated__element.table - - @util.memoized_property - def key(self): - """pull 'key' from parent, if not present""" - return self._Annotated__element.key - - @util.memoized_property - def info(self): - return self._Annotated__element.info - - @util.memoized_property - def anon_label(self): - return self._Annotated__element.anon_label diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index f381879ce..b04355cf5 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -67,7 +67,6 @@ __all__ = [ "outerjoin", "over", "select", - "subquery", "table", "text", "tuple_", @@ -92,22 +91,7 @@ from .dml import Insert # noqa from .dml import Update # noqa from .dml import UpdateBase # noqa from .dml import ValuesBase # noqa -from .elements import _clause_element_as_expr # noqa -from .elements import _clone # noqa -from .elements import _cloned_difference # noqa -from .elements import _cloned_intersection # noqa -from .elements import _column_as_key # noqa -from .elements import _corresponding_column_or_error # noqa -from .elements import _expression_literal_as_text # noqa -from .elements import _is_column # noqa -from .elements import _labeled # noqa -from .elements import _literal_as_binds # noqa -from .elements import _literal_as_column # noqa -from .elements import _literal_as_label_reference # noqa -from .elements import _literal_as_text # noqa -from .elements import _only_column_elements # noqa from .elements import _select_iterables # noqa -from .elements import _string_or_unprintable # noqa from .elements import _truncated_label # noqa from .elements import between # noqa from .elements import BinaryExpression # noqa @@ -147,7 +131,6 @@ from .functions import func # noqa from .functions import Function # noqa from .functions import FunctionElement # noqa from .functions import modifier # noqa -from .selectable import _interpret_as_from # noqa from .selectable import Alias # noqa from .selectable import CompoundSelect # noqa from .selectable import CTE # noqa @@ -160,6 +143,7 @@ from .selectable import HasPrefixes # noqa from .selectable import HasSuffixes # noqa from .selectable import Join # noqa from .selectable import Lateral # noqa +from .selectable import ReturnsRows # noqa from .selectable import ScalarSelect # noqa from .selectable import Select # noqa from .selectable import Selectable # noqa @@ -171,7 +155,6 @@ from .selectable import TextAsFrom # noqa from .visitors import Visitable # noqa from ..util.langhelpers import public_factory # noqa - # factory functions - these pull class-bound constructors and classmethods # from SQL elements and selectables into public functions. This allows # the functions to be available in the sqlalchemy.sql.* namespace and diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index d0aa23988..173789998 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -9,14 +9,15 @@ """ from . import annotation +from . import coercions from . import operators +from . import roles from . import schema from . import sqltypes from . import util as sqlutil from .base import ColumnCollection from .base import Executable from .elements import _clone -from .elements import _literal_as_binds from .elements import _type_from_args from .elements import BinaryExpression from .elements import BindParameter @@ -83,7 +84,12 @@ class FunctionElement(Executable, ColumnElement, FromClause): """Construct a :class:`.FunctionElement`. """ args = [ - _literal_as_binds(c, getattr(self, "name", None)) for c in clauses + coercions.expect( + roles.ExpressionElementRole, + c, + name=getattr(self, "name", None), + ) + for c in clauses ] self._has_args = self._has_args or bool(args) self.clause_expr = ClauseList( @@ -686,7 +692,12 @@ class GenericFunction(util.with_metaclass(_GenericMeta, Function)): def __init__(self, *args, **kwargs): parsed_args = kwargs.pop("_parsed_args", None) if parsed_args is None: - parsed_args = [_literal_as_binds(c, self.name) for c in args] + parsed_args = [ + coercions.expect( + roles.ExpressionElementRole, c, name=self.name + ) + for c in args + ] self._has_args = self._has_args or bool(parsed_args) self.packagenames = [] self._bind = kwargs.get("bind", None) @@ -751,7 +762,10 @@ class ReturnTypeFromArgs(GenericFunction): """Define a function whose return type is the same as its arguments.""" def __init__(self, *args, **kwargs): - args = [_literal_as_binds(c, self.name) for c in args] + args = [ + coercions.expect(roles.ExpressionElementRole, c, name=self.name) + for c in args + ] kwargs.setdefault("type_", _type_from_args(args)) kwargs["_parsed_args"] = args super(ReturnTypeFromArgs, self).__init__(*args, **kwargs) @@ -880,7 +894,7 @@ class array_agg(GenericFunction): type = sqltypes.ARRAY def __init__(self, *args, **kwargs): - args = [_literal_as_binds(c) for c in args] + args = [coercions.expect(roles.ExpressionElementRole, c) for c in args] default_array_type = kwargs.pop("_default_array_type", sqltypes.ARRAY) if "type_" not in kwargs: diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 8479c1d59..b8bbb4525 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -1053,7 +1053,7 @@ class ColumnOperators(Operators): expr = 5 == mytable.c.somearray.any_() # mysql '5 = ANY (SELECT value FROM table)' - expr = 5 == select([table.c.value]).as_scalar().any_() + expr = 5 == select([table.c.value]).scalar_subquery().any_() .. seealso:: @@ -1078,7 +1078,7 @@ class ColumnOperators(Operators): expr = 5 == mytable.c.somearray.all_() # mysql '5 = ALL (SELECT value FROM table)' - expr = 5 == select([table.c.value]).as_scalar().all_() + expr = 5 == select([table.c.value]).scalar_subquery().all_() .. seealso:: diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py new file mode 100644 index 000000000..2d3aaf903 --- /dev/null +++ b/lib/sqlalchemy/sql/roles.py @@ -0,0 +1,157 @@ +# sql/roles.py +# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + + +class SQLRole(object): + """Define a "role" within a SQL statement structure. + + Classes within SQL Core participate within SQLRole hierarchies in order + to more accurately indicate where they may be used within SQL statements + of all types. + + .. versionadded:: 1.4 + + """ + + +class UsesInspection(object): + pass + + +class ColumnArgumentRole(SQLRole): + _role_name = "Column expression" + + +class ColumnArgumentOrKeyRole(ColumnArgumentRole): + _role_name = "Column expression or string key" + + +class ColumnListRole(SQLRole): + """Elements suitable for forming comma separated lists of expressions.""" + + +class TruncatedLabelRole(SQLRole): + _role_name = "String SQL identifier" + + +class ColumnsClauseRole(UsesInspection, ColumnListRole): + _role_name = "Column expression or FROM clause" + + @property + def _select_iterable(self): + raise NotImplementedError() + + +class LimitOffsetRole(SQLRole): + _role_name = "LIMIT / OFFSET expression" + + +class ByOfRole(ColumnListRole): + _role_name = "GROUP BY / OF / etc. expression" + + +class OrderByRole(ByOfRole): + _role_name = "ORDER BY expression" + + +class StructuralRole(SQLRole): + pass + + +class StatementOptionRole(StructuralRole): + _role_name = "statement sub-expression element" + + +class WhereHavingRole(StructuralRole): + _role_name = "SQL expression for WHERE/HAVING role" + + +class ExpressionElementRole(SQLRole): + _role_name = "SQL expression element" + + +class ConstExprRole(ExpressionElementRole): + _role_name = "Constant True/False/None expression" + + +class LabeledColumnExprRole(ExpressionElementRole): + pass + + +class BinaryElementRole(ExpressionElementRole): + _role_name = "SQL expression element or literal value" + + +class InElementRole(SQLRole): + _role_name = ( + "IN expression list, SELECT construct, or bound parameter object" + ) + + +class FromClauseRole(ColumnsClauseRole): + _role_name = "FROM expression, such as a Table or alias() object" + + @property + def _hide_froms(self): + raise NotImplementedError() + + +class CoerceTextStatementRole(SQLRole): + _role_name = "Executable SQL, text() construct, or string statement" + + +class StatementRole(CoerceTextStatementRole): + _role_name = "Executable SQL or text() construct" + + +class ReturnsRowsRole(StatementRole): + _role_name = ( + "Row returning expression such as a SELECT, or an " + "INSERT/UPDATE/DELETE with RETURNING" + ) + + +class SelectStatementRole(ReturnsRowsRole): + _role_name = "SELECT construct or equivalent text() construct" + + +class HasCTERole(ReturnsRowsRole): + pass + + +class CompoundElementRole(SQLRole): + """SELECT statements inside a CompoundSelect, e.g. UNION, EXTRACT, etc.""" + + _role_name = ( + "SELECT construct for inclusion in a UNION or other set construct" + ) + + +class DMLRole(StatementRole): + pass + + +class DMLColumnRole(SQLRole): + _role_name = "SET/VALUES column expression or string key" + + +class DMLSelectRole(SQLRole): + """A SELECT statement embedded in DML, typically INSERT from SELECT """ + + _role_name = "SELECT statement or equivalent textual object" + + +class DDLRole(StatementRole): + pass + + +class DDLExpressionRole(StructuralRole): + _role_name = "SQL expression element for DDL constraint" + + +class DDLConstraintColumnRole(SQLRole): + _role_name = "String column name or column object for DDL constraint" diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index b045e006e..62ff25a64 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -34,16 +34,16 @@ import collections import operator import sqlalchemy +from . import coercions from . import ddl +from . import roles from . import type_api from . import visitors from .base import _bind_or_error from .base import ColumnCollection from .base import DialectKWArgs from .base import SchemaEventTarget -from .elements import _as_truncated -from .elements import _document_text_coercion -from .elements import _literal_as_text +from .coercions import _document_text_coercion from .elements import ClauseElement from .elements import ColumnClause from .elements import ColumnElement @@ -1583,7 +1583,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): ) try: c = self._constructor( - _as_truncated(name or self.name) + coercions.expect( + roles.TruncatedLabelRole, name if name else self.name + ) if name_is_truncatable else (name or self.name), self.type, @@ -2109,13 +2111,19 @@ class ForeignKey(DialectKWArgs, SchemaItem): class _NotAColumnExpr(object): + # the coercions system is not used in crud.py for the values passed in + # the insert().values() and update().values() methods, so the usual + # pathways to rejecting a coercion in the unlikely case of adding defaut + # generator objects to insert() or update() constructs aren't available; + # create a quick coercion rejection here that is specific to what crud.py + # calls on value objects. def _not_a_column_expr(self): raise exc.InvalidRequestError( "This %s cannot be used directly " "as a column expression." % self.__class__.__name__ ) - __clause_element__ = self_group = lambda self: self._not_a_column_expr() + self_group = lambda self: self._not_a_column_expr() # noqa _from_objects = property(lambda self: self._not_a_column_expr()) @@ -2274,7 +2282,7 @@ class ColumnDefault(DefaultGenerator): return "ColumnDefault(%r)" % (self.arg,) -class Sequence(DefaultGenerator): +class Sequence(roles.StatementRole, DefaultGenerator): """Represents a named database sequence. The :class:`.Sequence` object represents the name and configurational @@ -2759,25 +2767,6 @@ class ColumnCollectionMixin(object): if _autoattach and self._pending_colargs: self._check_attach() - @classmethod - def _extract_col_expression_collection(cls, expressions): - for expr in expressions: - strname = None - column = None - if hasattr(expr, "__clause_element__"): - expr = expr.__clause_element__() - - if not isinstance(expr, (ColumnElement, TextClause)): - # this assumes a string - strname = expr - else: - cols = [] - visitors.traverse(expr, {}, {"column": cols.append}) - if cols: - column = cols[0] - add_element = column if column is not None else strname - yield expr, column, strname, add_element - def _check_attach(self, evt=False): col_objs = [c for c in self._pending_colargs if isinstance(c, Column)] @@ -2960,7 +2949,7 @@ class CheckConstraint(ColumnCollectionConstraint): """ - self.sqltext = _literal_as_text(sqltext, allow_coercion_to_text=True) + self.sqltext = coercions.expect(roles.DDLExpressionRole, sqltext) columns = [] visitors.traverse(self.sqltext, {}, {"column": columns.append}) @@ -3630,7 +3619,9 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem): column, strname, add_element, - ) in self._extract_col_expression_collection(expressions): + ) in coercions.expect_col_expression_collection( + roles.DDLConstraintColumnRole, expressions + ): if add_element is not None: columns.append(add_element) processed_expressions.append(expr) diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 5167182fe..41be9fc5a 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -15,8 +15,9 @@ import itertools import operator from operator import attrgetter -from sqlalchemy.sql.visitors import Visitable +from . import coercions from . import operators +from . import roles from . import type_api from .annotation import Annotated from .base import _from_objects @@ -26,18 +27,12 @@ from .base import ColumnSet from .base import Executable from .base import Generative from .base import Immutable +from .coercions import _document_text_coercion from .elements import _anonymous_label -from .elements import _clause_element_as_expr from .elements import _clone from .elements import _cloned_difference from .elements import _cloned_intersection -from .elements import _document_text_coercion from .elements import _expand_cloned -from .elements import _interpret_as_column_or_from -from .elements import _literal_and_labels_as_label_reference -from .elements import _literal_as_label_reference -from .elements import _literal_as_text -from .elements import _no_text_coercion from .elements import _select_iterables from .elements import and_ from .elements import BindParameter @@ -48,75 +43,15 @@ from .elements import literal_column from .elements import True_ from .elements import UnaryExpression from .. import exc -from .. import inspection from .. import util -def _interpret_as_from(element): - insp = inspection.inspect(element, raiseerr=False) - if insp is None: - if isinstance(element, util.string_types): - _no_text_coercion(element) - try: - return insp.selectable - except AttributeError: - raise exc.ArgumentError("FROM expression expected") - - -def _interpret_as_select(element): - element = _interpret_as_from(element) - if isinstance(element, Alias): - element = element.original - if not isinstance(element, SelectBase): - element = element.select() - return element - - class _OffsetLimitParam(BindParameter): @property def _limit_offset_value(self): return self.effective_value -def _offset_or_limit_clause(element, name=None, type_=None): - """Convert the given value to an "offset or limit" clause. - - This handles incoming integers and converts to an expression; if - an expression is already given, it is passed through. - - """ - if element is None: - return None - elif hasattr(element, "__clause_element__"): - return element.__clause_element__() - elif isinstance(element, Visitable): - return element - else: - value = util.asint(element) - return _OffsetLimitParam(name, value, type_=type_, unique=True) - - -def _offset_or_limit_clause_asint(clause, attrname): - """Convert the "offset or limit" clause of a select construct to an - integer. - - This is only possible if the value is stored as a simple bound parameter. - Otherwise, a compilation error is raised. - - """ - if clause is None: - return None - try: - value = clause._limit_offset_value - except AttributeError: - raise exc.CompileError( - "This SELECT structure does not use a simple " - "integer value for %s" % attrname - ) - else: - return util.asint(value) - - def subquery(alias, *args, **kwargs): r"""Return an :class:`.Alias` object derived from a :class:`.Select`. @@ -133,8 +68,42 @@ def subquery(alias, *args, **kwargs): return Select(*args, **kwargs).alias(alias) -class Selectable(ClauseElement): - """mark a class as being selectable""" +class ReturnsRows(roles.ReturnsRowsRole, ClauseElement): + """The basemost class for Core contructs that have some concept of + columns that can represent rows. + + While the SELECT statement and TABLE are the primary things we think + of in this category, DML like INSERT, UPDATE and DELETE can also specify + RETURNING which means they can be used in CTEs and other forms, and + PostgreSQL has functions that return rows also. + + .. versionadded:: 1.4 + + """ + + _is_returns_rows = True + + # sub-elements of returns_rows + _is_from_clause = False + _is_select_statement = False + _is_lateral = False + + @property + def selectable(self): + raise NotImplementedError( + "This object is a base ReturnsRows object, but is not a " + "FromClause so has no .c. collection." + ) + + +class Selectable(ReturnsRows): + """mark a class as being selectable. + + This class is legacy as of 1.4 as the concept of a SQL construct which + "returns rows" is more generalized than one which can be the subject + of a SELECT. + + """ __visit_name__ = "selectable" @@ -190,7 +159,7 @@ class HasPrefixes(object): def _setup_prefixes(self, prefixes, dialect=None): self._prefixes = self._prefixes + tuple( [ - (_literal_as_text(p, allow_coercion_to_text=True), dialect) + (coercions.expect(roles.StatementOptionRole, p), dialect) for p in prefixes ] ) @@ -236,13 +205,13 @@ class HasSuffixes(object): def _setup_suffixes(self, suffixes, dialect=None): self._suffixes = self._suffixes + tuple( [ - (_literal_as_text(p, allow_coercion_to_text=True), dialect) + (coercions.expect(roles.StatementOptionRole, p), dialect) for p in suffixes ] ) -class FromClause(Selectable): +class FromClause(roles.FromClauseRole, Selectable): """Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement. @@ -265,16 +234,6 @@ class FromClause(Selectable): named_with_column = False _hide_froms = [] - _is_join = False - _is_select = False - _is_from_container = False - - _is_lateral = False - - _textual = False - """a marker that allows us to easily distinguish a :class:`.TextAsFrom` - or similar object from other kinds of :class:`.FromClause` objects.""" - schema = None """Define the 'schema' attribute for this :class:`.FromClause`. @@ -284,6 +243,11 @@ class FromClause(Selectable): """ + is_selectable = has_selectable = True + _is_from_clause = True + _is_text_as_from = False + _is_join = False + def _translate_schema(self, effective_schema, map_): return effective_schema @@ -726,8 +690,8 @@ class Join(FromClause): :class:`.FromClause` object. """ - self.left = _interpret_as_from(left) - self.right = _interpret_as_from(right).self_group() + self.left = coercions.expect(roles.FromClauseRole, left) + self.right = coercions.expect(roles.FromClauseRole, right).self_group() if onclause is None: self.onclause = self._match_primaries(self.left, self.right) @@ -1292,7 +1256,9 @@ class Alias(FromClause): .. versionadded:: 0.9.0 """ - return _interpret_as_from(selectable).alias(name=name, flat=flat) + return coercions.expect(roles.FromClauseRole, selectable).alias( + name=name, flat=flat + ) def _init(self, selectable, name=None): baseselectable = selectable @@ -1327,14 +1293,6 @@ class Alias(FromClause): else: return self.name.encode("ascii", "backslashreplace") - def as_scalar(self): - try: - return self.element.as_scalar() - except AttributeError: - raise AttributeError( - "Element %s does not support " "'as_scalar()'" % self.element - ) - def is_derived_from(self, fromclause): if fromclause in self._cloned_set: return True @@ -1426,7 +1384,9 @@ class Lateral(Alias): :ref:`lateral_selects` - overview of usage. """ - return _interpret_as_from(selectable).lateral(name=name) + return coercions.expect(roles.FromClauseRole, selectable).lateral( + name=name + ) class TableSample(Alias): @@ -1488,7 +1448,7 @@ class TableSample(Alias): REPEATABLE sub-clause is also rendered. """ - return _interpret_as_from(selectable).tablesample( + return coercions.expect(roles.FromClauseRole, selectable).tablesample( sampling, name=name, seed=seed ) @@ -1523,7 +1483,7 @@ class CTE(Generative, HasSuffixes, Alias): Please see :meth:`.HasCte.cte` for detail on CTE usage. """ - return _interpret_as_from(selectable).cte( + return coercions.expect(roles.HasCTERole, selectable).cte( name=name, recursive=recursive ) @@ -1588,7 +1548,7 @@ class CTE(Generative, HasSuffixes, Alias): ) -class HasCTE(object): +class HasCTE(roles.HasCTERole): """Mixin that declares a class to include CTE support. .. versionadded:: 1.1 @@ -2059,13 +2019,22 @@ class ForUpdateArg(ClauseElement): self.key_share = key_share if of is not None: self.of = [ - _interpret_as_column_or_from(elem) for elem in util.to_list(of) + coercions.expect(roles.ColumnsClauseRole, elem) + for elem in util.to_list(of) ] else: self.of = None -class SelectBase(HasCTE, Executable, FromClause): +class SelectBase( + roles.SelectStatementRole, + roles.DMLSelectRole, + roles.CompoundElementRole, + roles.InElementRole, + HasCTE, + Executable, + FromClause, +): """Base class for SELECT statements. @@ -2075,15 +2044,32 @@ class SelectBase(HasCTE, Executable, FromClause): """ + _is_select_statement = True + + @util.deprecated( + "1.4", + "The :meth:`.SelectBase.as_scalar` method is deprecated and will be " + "removed in a future release. Please refer to " + ":meth:`.SelectBase.scalar_subquery`.", + ) def as_scalar(self): + return self.scalar_subquery() + + def scalar_subquery(self): """return a 'scalar' representation of this selectable, which can be used as a column expression. Typically, a select statement which has only one column in its columns - clause is eligible to be used as a scalar expression. + clause is eligible to be used as a scalar expression. The scalar + subquery can then be used in the WHERE clause or columns clause of + an enclosing SELECT. - The returned object is an instance of - :class:`ScalarSelect`. + Note that the scalar subquery differentiates from the FROM-level + subquery that can be produced using the :meth:`.SelectBase.subquery` + method. + + .. versionchanged: 1.4 - the ``.as_scalar()`` method was renamed to + :meth:`.SelectBase.scalar_subquery`. """ return ScalarSelect(self) @@ -2097,7 +2083,7 @@ class SelectBase(HasCTE, Executable, FromClause): :meth:`~.SelectBase.as_scalar`. """ - return self.as_scalar().label(name) + return self.scalar_subquery().label(name) @_generative @util.deprecated( @@ -2181,20 +2167,19 @@ class GenerativeSelect(SelectBase): {"autocommit": autocommit} ) if limit is not None: - self._limit_clause = _offset_or_limit_clause(limit) + self._limit_clause = self._offset_or_limit_clause(limit) if offset is not None: - self._offset_clause = _offset_or_limit_clause(offset) + self._offset_clause = self._offset_or_limit_clause(offset) self._bind = bind if order_by is not None: self._order_by_clause = ClauseList( *util.to_list(order_by), - _literal_as_text=_literal_and_labels_as_label_reference + _literal_as_text_role=roles.OrderByRole ) if group_by is not None: self._group_by_clause = ClauseList( - *util.to_list(group_by), - _literal_as_text=_literal_as_label_reference + *util.to_list(group_by), _literal_as_text_role=roles.ByOfRole ) @property @@ -2287,6 +2272,37 @@ class GenerativeSelect(SelectBase): """ self.use_labels = True + def _offset_or_limit_clause(self, element, name=None, type_=None): + """Convert the given value to an "offset or limit" clause. + + This handles incoming integers and converts to an expression; if + an expression is already given, it is passed through. + + """ + return coercions.expect( + roles.LimitOffsetRole, element, name=name, type_=type_ + ) + + def _offset_or_limit_clause_asint(self, clause, attrname): + """Convert the "offset or limit" clause of a select construct to an + integer. + + This is only possible if the value is stored as a simple bound + parameter. Otherwise, a compilation error is raised. + + """ + if clause is None: + return None + try: + value = clause._limit_offset_value + except AttributeError: + raise exc.CompileError( + "This SELECT structure does not use a simple " + "integer value for %s" % attrname + ) + else: + return util.asint(value) + @property def _limit(self): """Get an integer value for the limit. This should only be used @@ -2295,7 +2311,7 @@ class GenerativeSelect(SelectBase): isn't currently set to an integer. """ - return _offset_or_limit_clause_asint(self._limit_clause, "limit") + return self._offset_or_limit_clause_asint(self._limit_clause, "limit") @property def _simple_int_limit(self): @@ -2319,7 +2335,9 @@ class GenerativeSelect(SelectBase): offset isn't currently set to an integer. """ - return _offset_or_limit_clause_asint(self._offset_clause, "offset") + return self._offset_or_limit_clause_asint( + self._offset_clause, "offset" + ) @_generative def limit(self, limit): @@ -2339,7 +2357,7 @@ class GenerativeSelect(SelectBase): """ - self._limit_clause = _offset_or_limit_clause(limit) + self._limit_clause = self._offset_or_limit_clause(limit) @_generative def offset(self, offset): @@ -2361,7 +2379,7 @@ class GenerativeSelect(SelectBase): """ - self._offset_clause = _offset_or_limit_clause(offset) + self._offset_clause = self._offset_or_limit_clause(offset) @_generative def order_by(self, *clauses): @@ -2403,8 +2421,7 @@ class GenerativeSelect(SelectBase): if getattr(self, "_order_by_clause", None) is not None: clauses = list(self._order_by_clause) + list(clauses) self._order_by_clause = ClauseList( - *clauses, - _literal_as_text=_literal_and_labels_as_label_reference + *clauses, _literal_as_text_role=roles.OrderByRole ) def append_group_by(self, *clauses): @@ -2423,7 +2440,7 @@ class GenerativeSelect(SelectBase): if getattr(self, "_group_by_clause", None) is not None: clauses = list(self._group_by_clause) + list(clauses) self._group_by_clause = ClauseList( - *clauses, _literal_as_text=_literal_as_label_reference + *clauses, _literal_as_text_role=roles.ByOfRole ) @property @@ -2478,7 +2495,7 @@ class CompoundSelect(GenerativeSelect): # some DBs do not like ORDER BY in the inner queries of a UNION, etc. for n, s in enumerate(selects): - s = _clause_element_as_expr(s) + s = coercions.expect(roles.CompoundElementRole, s) if not numcols: numcols = len(s.c._all_columns) @@ -2741,7 +2758,6 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): _correlate = () _correlate_except = None _memoized_property = SelectBase._memoized_property - _is_select = True @util.deprecated_params( autocommit=( @@ -2965,12 +2981,14 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): self._distinct = True else: self._distinct = [ - _literal_as_text(e) for e in util.to_list(distinct) + coercions.expect(roles.WhereHavingRole, e) + for e in util.to_list(distinct) ] if from_obj is not None: self._from_obj = util.OrderedSet( - _interpret_as_from(f) for f in util.to_list(from_obj) + coercions.expect(roles.FromClauseRole, f) + for f in util.to_list(from_obj) ) else: self._from_obj = util.OrderedSet() @@ -2986,7 +3004,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): if cols_present: self._raw_columns = [] for c in columns: - c = _interpret_as_column_or_from(c) + c = coercions.expect(roles.ColumnsClauseRole, c) if isinstance(c, ScalarSelect): c = c.self_group(against=operators.comma_op) self._raw_columns.append(c) @@ -2994,16 +3012,16 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): self._raw_columns = [] if whereclause is not None: - self._whereclause = _literal_as_text(whereclause).self_group( - against=operators._asbool - ) + self._whereclause = coercions.expect( + roles.WhereHavingRole, whereclause + ).self_group(against=operators._asbool) else: self._whereclause = None if having is not None: - self._having = _literal_as_text(having).self_group( - against=operators._asbool - ) + self._having = coercions.expect( + roles.WhereHavingRole, having + ).self_group(against=operators._asbool) else: self._having = None @@ -3202,15 +3220,6 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): else: self._hints = self._hints.union({(selectable, dialect_name): text}) - @property - def type(self): - raise exc.InvalidRequestError( - "Select objects don't have a type. " - "Call as_scalar() on this Select " - "object to return a 'scalar' version " - "of this Select." - ) - @_memoized_property.method def locate_all_froms(self): """return a Set of all FromClause elements referenced by this Select. @@ -3496,7 +3505,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): self._reset_exported() rc = [] for c in columns: - c = _interpret_as_column_or_from(c) + c = coercions.expect(roles.ColumnsClauseRole, c) if isinstance(c, ScalarSelect): c = c.self_group(against=operators.comma_op) rc.append(c) @@ -3530,7 +3539,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): """ if expr: - expr = [_literal_as_label_reference(e) for e in expr] + expr = [coercions.expect(roles.ByOfRole, e) for e in expr] if isinstance(self._distinct, list): self._distinct = self._distinct + expr else: @@ -3618,7 +3627,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): self._correlate = () else: self._correlate = set(self._correlate).union( - _interpret_as_from(f) for f in fromclauses + coercions.expect(roles.FromClauseRole, f) for f in fromclauses ) @_generative @@ -3653,7 +3662,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): self._correlate_except = () else: self._correlate_except = set(self._correlate_except or ()).union( - _interpret_as_from(f) for f in fromclauses + coercions.expect(roles.FromClauseRole, f) for f in fromclauses ) def append_correlation(self, fromclause): @@ -3668,7 +3677,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): self._auto_correlate = False self._correlate = set(self._correlate).union( - _interpret_as_from(f) for f in fromclause + coercions.expect(roles.FromClauseRole, f) for f in fromclause ) def append_column(self, column): @@ -3689,7 +3698,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): """ self._reset_exported() - column = _interpret_as_column_or_from(column) + column = coercions.expect(roles.ColumnsClauseRole, column) if isinstance(column, ScalarSelect): column = column.self_group(against=operators.comma_op) @@ -3705,7 +3714,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): standard :term:`method chaining`. """ - clause = _literal_as_text(clause) + clause = coercions.expect(roles.WhereHavingRole, clause) self._prefixes = self._prefixes + (clause,) def append_whereclause(self, whereclause): @@ -3747,7 +3756,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): """ self._reset_exported() - fromclause = _interpret_as_from(fromclause) + fromclause = coercions.expect(roles.FromClauseRole, fromclause) self._from_obj = self._from_obj.union([fromclause]) @_memoized_property @@ -3894,7 +3903,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): bind = property(bind, _set_bind) -class ScalarSelect(Generative, Grouping): +class ScalarSelect(roles.InElementRole, Generative, Grouping): _from_objects = [] _is_from_container = True _is_implicitly_boolean = False @@ -3956,7 +3965,7 @@ class Exists(UnaryExpression): else: if not args: args = ([literal_column("*")],) - s = Select(*args, **kwargs).as_scalar().self_group() + s = Select(*args, **kwargs).scalar_subquery().self_group() UnaryExpression.__init__( self, @@ -3999,6 +4008,7 @@ class Exists(UnaryExpression): return e +# TODO: rename to TextualSelect, this is not a FROM clause class TextAsFrom(SelectBase): """Wrap a :class:`.TextClause` construct within a :class:`.SelectBase` interface. @@ -4022,7 +4032,7 @@ class TextAsFrom(SelectBase): __visit_name__ = "text_as_from" - _textual = True + _is_textual = True def __init__(self, text, columns, positional=False): self.element = text diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 0d3944552..6a520a2d5 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -14,13 +14,14 @@ import datetime as dt import decimal import json +from . import coercions from . import elements from . import operators +from . import roles from . import type_api from .base import _bind_or_error from .base import SchemaEventTarget from .elements import _defer_name -from .elements import _literal_as_binds from .elements import quoted_name from .elements import Slice from .elements import TypeCoerce as type_coerce # noqa @@ -2187,19 +2188,21 @@ class JSON(Indexable, TypeEngine): if not isinstance(index, util.string_types) and isinstance( index, compat.collections_abc.Sequence ): - index = default_comparator._check_literal( - self.expr, - operators.json_path_getitem_op, + index = coercions.expect( + roles.BinaryElementRole, index, + expr=self.expr, + operator=operators.json_path_getitem_op, bindparam_type=JSON.JSONPathType, ) operator = operators.json_path_getitem_op else: - index = default_comparator._check_literal( - self.expr, - operators.json_getitem_op, + index = coercions.expect( + roles.BinaryElementRole, index, + expr=self.expr, + operator=operators.json_getitem_op, bindparam_type=JSON.JSONIndexType, ) operator = operators.json_getitem_op @@ -2372,17 +2375,20 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): if self.type.zero_indexes: index = slice(index.start + 1, index.stop + 1, index.step) index = Slice( - _literal_as_binds( + coercions.expect( + roles.ExpressionElementRole, index.start, name=self.expr.key, type_=type_api.INTEGERTYPE, ), - _literal_as_binds( + coercions.expect( + roles.ExpressionElementRole, index.stop, name=self.expr.key, type_=type_api.INTEGERTYPE, ), - _literal_as_binds( + coercions.expect( + roles.ExpressionElementRole, index.step, name=self.expr.key, type_=type_api.INTEGERTYPE, @@ -2438,7 +2444,7 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): """ operator = operator if operator else operators.eq return operator( - elements._literal_as_binds(other), + coercions.expect(roles.ExpressionElementRole, other), elements.CollectionAggregate._create_any(self.expr), ) @@ -2473,7 +2479,7 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): """ operator = operator if operator else operators.eq return operator( - elements._literal_as_binds(other), + coercions.expect(roles.ExpressionElementRole, other), elements.CollectionAggregate._create_all(self.expr), ) diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index bdeae9613..5eea27e08 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -57,6 +57,9 @@ class TypeEngine(Visitable): default_comparator = None + def __clause_element__(self): + return self.expr + def __init__(self, expr): self.expr = expr self.type = expr.type |
