diff options
Diffstat (limited to 'lib/sqlalchemy/sql/elements.py')
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 85 |
1 files changed, 53 insertions, 32 deletions
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index c735085f8..aec29d1b2 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -26,6 +26,7 @@ from typing import Dict from typing import FrozenSet from typing import Generic from typing import Iterable +from typing import Iterator from typing import List from typing import Mapping from typing import Optional @@ -77,8 +78,8 @@ from ..util.typing import Literal if typing.TYPE_CHECKING: from ._typing import _ColumnExpressionArgument from ._typing import _PropagateAttrsType - from ._typing import _SelectIterable from ._typing import _TypeEngineArgument + from .cache_key import _CacheKeyTraversalType from .cache_key import CacheKey from .compiler import Compiled from .compiler import SQLCompiler @@ -88,6 +89,7 @@ if typing.TYPE_CHECKING: from .schema import DefaultGenerator from .schema import FetchedValue from .schema import ForeignKey + from .selectable import _SelectIterable from .selectable import FromClause from .selectable import NamedFromClause from .selectable import ReturnsRows @@ -96,6 +98,7 @@ if typing.TYPE_CHECKING: from .sqltypes import Boolean from .sqltypes import TupleType from .type_api import TypeEngine + from .visitors import _CloneCallableType from .visitors import _TraverseInternalsType from ..engine import Connection from ..engine import Dialect @@ -310,6 +313,7 @@ class ClauseElement( _is_text_clause = False _is_from_container = False _is_select_container = False + _is_select_base = False _is_select_statement = False _is_bind_parameter = False _is_clause_list = False @@ -321,7 +325,7 @@ class ClauseElement( def _order_by_label_element(self) -> Optional[Label[Any]]: return None - _cache_key_traversal = None + _cache_key_traversal: _CacheKeyTraversalType = None negation_clause: ColumnElement[bool] @@ -528,7 +532,7 @@ class ClauseElement( """ return traversals.compare(self, other, **kw) - def self_group(self, against=None): + def self_group(self, against: Optional[OperatorType] = None) -> Any: """Apply a 'grouping' to this :class:`_expression.ClauseElement`. This method is overridden by subclasses to return a "grouping" @@ -637,9 +641,9 @@ class ClauseElement( return self._negate() def _negate(self) -> ClauseElement: - return UnaryExpression( - self.self_group(against=operators.inv), operator=operators.inv - ) + grouped = self.self_group(against=operators.inv) + assert isinstance(grouped, ColumnElement) + return UnaryExpression(grouped, operator=operators.inv) def __bool__(self): raise TypeError("Boolean value of this clause is not defined") @@ -1290,12 +1294,6 @@ class ColumnElement( @overload def self_group( - self: ColumnElement[bool], against: Optional[OperatorType] = None - ) -> ColumnElement[bool]: - ... - - @overload - def self_group( self: ColumnElement[Any], against: Optional[OperatorType] = None ) -> ColumnElement[Any]: ... @@ -1764,6 +1762,7 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]): key: str type: TypeEngine[_T] + value: Optional[_T] _is_crud = False _is_bind_parameter = True @@ -1883,7 +1882,7 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]): return cloned @property - def effective_value(self): + def effective_value(self) -> Optional[_T]: """Return the value of this bound parameter, taking into account if the ``callable`` parameter was set. @@ -1893,11 +1892,12 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]): """ if self.callable: - return self.callable() + # TODO: set up protocol for bind parameter callable + return self.callable() # type: ignore else: return self.value - def render_literal_execute(self): + def render_literal_execute(self) -> BindParameter[_T]: """Produce a copy of this bound parameter that will enable the :paramref:`_sql.BindParameter.literal_execute` flag. @@ -2513,8 +2513,10 @@ class ClauseList( self.operator = operator self.group = group self.group_contents = group_contents + clauses_iterator: Iterable[_ColumnExpressionArgument[Any]] = clauses if _flatten_sub_clauses: - clauses = util.flatten_iterator(clauses) + clauses_iterator = util.flatten_iterator(clauses_iterator) + self._text_converter_role: Type[roles.SQLRole] = _literal_as_text_role text_converter_role: Type[roles.SQLRole] = _literal_as_text_role @@ -2523,31 +2525,35 @@ class ClauseList( coercions.expect( text_converter_role, clause, apply_propagate_attrs=self ).self_group(against=self.operator) - for clause in clauses + for clause in clauses_iterator ] else: self.clauses = [ coercions.expect( text_converter_role, clause, apply_propagate_attrs=self ) - for clause in clauses + for clause in clauses_iterator ] self._is_implicitly_boolean = operators.is_boolean(self.operator) @classmethod - def _construct_raw(cls, operator, clauses=None): + def _construct_raw( + cls, + operator: OperatorType, + clauses: Optional[Sequence[ColumnElement[Any]]] = None, + ) -> ClauseList: self = cls.__new__(cls) - self.clauses = clauses if clauses else [] + self.clauses = list(clauses) if clauses else [] self.group = True self.operator = operator self.group_contents = True self._is_implicitly_boolean = False return self - def __iter__(self): + def __iter__(self) -> Iterator[ColumnElement[Any]]: return iter(self.clauses) - def __len__(self): + def __len__(self) -> int: return len(self.clauses) @property @@ -2708,10 +2714,10 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]): def _construct_raw( cls, operator: OperatorType, - clauses: Optional[List[ColumnElement[Any]]] = None, + clauses: Optional[Sequence[ColumnElement[Any]]] = None, ) -> BooleanClauseList: self = cls.__new__(cls) - self.clauses = clauses if clauses else [] + self.clauses = list(clauses) if clauses else [] self.group = True self.operator = operator self.group_contents = True @@ -2781,7 +2787,7 @@ class Tuple(ClauseList, ColumnElement[typing_Tuple[Any, ...]]): sqltypes = util.preloaded.sql_sqltypes if types is None: - init_clauses = [ + init_clauses: List[ColumnElement[Any]] = [ coercions.expect(roles.ExpressionElementRole, c) for c in clauses ] @@ -2908,7 +2914,7 @@ class Case(ColumnElement[_T]): ] if whenlist: - type_ = list(whenlist[-1])[-1].type + type_ = whenlist[-1][-1].type else: type_ = None @@ -3098,6 +3104,8 @@ class _label_reference(ColumnElement[_T]): ("element", InternalTraversal.dp_clauseelement) ] + element: ColumnElement[_T] + def __init__(self, element: ColumnElement[_T]): self.element = element @@ -3212,7 +3220,9 @@ class UnaryExpression(ColumnElement[_T]): cls, expr: _ColumnExpressionArgument[_T], ) -> UnaryExpression[_T]: - col_expr = coercions.expect(roles.ExpressionElementRole, expr) + col_expr: ColumnElement[_T] = coercions.expect( + roles.ExpressionElementRole, expr + ) return UnaryExpression( col_expr, operator=operators.distinct_op, @@ -3265,7 +3275,7 @@ class CollectionAggregate(UnaryExpression[_T]): def _create_any( cls, expr: _ColumnExpressionArgument[_T] ) -> CollectionAggregate[bool]: - col_expr = coercions.expect( + col_expr: ColumnElement[_T] = coercions.expect( roles.ExpressionElementRole, expr, ) @@ -3281,7 +3291,7 @@ class CollectionAggregate(UnaryExpression[_T]): def _create_all( cls, expr: _ColumnExpressionArgument[_T] ) -> CollectionAggregate[bool]: - col_expr = coercions.expect( + col_expr: ColumnElement[_T] = coercions.expect( roles.ExpressionElementRole, expr, ) @@ -3374,6 +3384,9 @@ class BinaryExpression(ColumnElement[_T]): modifiers: Optional[Mapping[str, Any]] + left: ColumnElement[Any] + right: Union[ColumnElement[Any], ClauseList] + def __init__( self, left: ColumnElement[Any], @@ -4147,7 +4160,13 @@ class Label(roles.LabeledColumnExprRole[_T], NamedColumn[_T]): def foreign_keys(self): return self.element.foreign_keys - def _copy_internals(self, clone=_clone, anonymize_labels=False, **kw): + def _copy_internals( + self, + *, + clone: _CloneCallableType = _clone, + anonymize_labels: bool = False, + **kw: Any, + ) -> None: self._reset_memoizations() self._element = clone(self._element, **kw) if anonymize_labels: @@ -4447,7 +4466,9 @@ class TableValuedColumn(NamedColumn[_T]): self.key = self.name = scalar_alias.name self.type = type_ - def _copy_internals(self, clone=_clone, **kw): + def _copy_internals( + self, clone: _CloneCallableType = _clone, **kw: Any + ) -> None: self.scalar_alias = clone(self.scalar_alias, **kw) self.key = self.name = self.scalar_alias.name @@ -4467,7 +4488,7 @@ class CollationClause(ColumnElement[str]): def _create_collation_expression( cls, expression: _ColumnExpressionArgument[str], collation: str ) -> BinaryExpression[str]: - expr = coercions.expect(roles.ExpressionElementRole, expression) + expr = coercions.expect(roles.ExpressionElementRole[str], expression) return BinaryExpression( expr, CollationClause(collation), |
