summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/elements.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/elements.py')
-rw-r--r--lib/sqlalchemy/sql/elements.py85
1 files changed, 53 insertions, 32 deletions
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index c735085f8..aec29d1b2 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -26,6 +26,7 @@ from typing import Dict
from typing import FrozenSet
from typing import Generic
from typing import Iterable
+from typing import Iterator
from typing import List
from typing import Mapping
from typing import Optional
@@ -77,8 +78,8 @@ from ..util.typing import Literal
if typing.TYPE_CHECKING:
from ._typing import _ColumnExpressionArgument
from ._typing import _PropagateAttrsType
- from ._typing import _SelectIterable
from ._typing import _TypeEngineArgument
+ from .cache_key import _CacheKeyTraversalType
from .cache_key import CacheKey
from .compiler import Compiled
from .compiler import SQLCompiler
@@ -88,6 +89,7 @@ if typing.TYPE_CHECKING:
from .schema import DefaultGenerator
from .schema import FetchedValue
from .schema import ForeignKey
+ from .selectable import _SelectIterable
from .selectable import FromClause
from .selectable import NamedFromClause
from .selectable import ReturnsRows
@@ -96,6 +98,7 @@ if typing.TYPE_CHECKING:
from .sqltypes import Boolean
from .sqltypes import TupleType
from .type_api import TypeEngine
+ from .visitors import _CloneCallableType
from .visitors import _TraverseInternalsType
from ..engine import Connection
from ..engine import Dialect
@@ -310,6 +313,7 @@ class ClauseElement(
_is_text_clause = False
_is_from_container = False
_is_select_container = False
+ _is_select_base = False
_is_select_statement = False
_is_bind_parameter = False
_is_clause_list = False
@@ -321,7 +325,7 @@ class ClauseElement(
def _order_by_label_element(self) -> Optional[Label[Any]]:
return None
- _cache_key_traversal = None
+ _cache_key_traversal: _CacheKeyTraversalType = None
negation_clause: ColumnElement[bool]
@@ -528,7 +532,7 @@ class ClauseElement(
"""
return traversals.compare(self, other, **kw)
- def self_group(self, against=None):
+ def self_group(self, against: Optional[OperatorType] = None) -> Any:
"""Apply a 'grouping' to this :class:`_expression.ClauseElement`.
This method is overridden by subclasses to return a "grouping"
@@ -637,9 +641,9 @@ class ClauseElement(
return self._negate()
def _negate(self) -> ClauseElement:
- return UnaryExpression(
- self.self_group(against=operators.inv), operator=operators.inv
- )
+ grouped = self.self_group(against=operators.inv)
+ assert isinstance(grouped, ColumnElement)
+ return UnaryExpression(grouped, operator=operators.inv)
def __bool__(self):
raise TypeError("Boolean value of this clause is not defined")
@@ -1290,12 +1294,6 @@ class ColumnElement(
@overload
def self_group(
- self: ColumnElement[bool], against: Optional[OperatorType] = None
- ) -> ColumnElement[bool]:
- ...
-
- @overload
- def self_group(
self: ColumnElement[Any], against: Optional[OperatorType] = None
) -> ColumnElement[Any]:
...
@@ -1764,6 +1762,7 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]):
key: str
type: TypeEngine[_T]
+ value: Optional[_T]
_is_crud = False
_is_bind_parameter = True
@@ -1883,7 +1882,7 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]):
return cloned
@property
- def effective_value(self):
+ def effective_value(self) -> Optional[_T]:
"""Return the value of this bound parameter,
taking into account if the ``callable`` parameter
was set.
@@ -1893,11 +1892,12 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]):
"""
if self.callable:
- return self.callable()
+ # TODO: set up protocol for bind parameter callable
+ return self.callable() # type: ignore
else:
return self.value
- def render_literal_execute(self):
+ def render_literal_execute(self) -> BindParameter[_T]:
"""Produce a copy of this bound parameter that will enable the
:paramref:`_sql.BindParameter.literal_execute` flag.
@@ -2513,8 +2513,10 @@ class ClauseList(
self.operator = operator
self.group = group
self.group_contents = group_contents
+ clauses_iterator: Iterable[_ColumnExpressionArgument[Any]] = clauses
if _flatten_sub_clauses:
- clauses = util.flatten_iterator(clauses)
+ clauses_iterator = util.flatten_iterator(clauses_iterator)
+
self._text_converter_role: Type[roles.SQLRole] = _literal_as_text_role
text_converter_role: Type[roles.SQLRole] = _literal_as_text_role
@@ -2523,31 +2525,35 @@ class ClauseList(
coercions.expect(
text_converter_role, clause, apply_propagate_attrs=self
).self_group(against=self.operator)
- for clause in clauses
+ for clause in clauses_iterator
]
else:
self.clauses = [
coercions.expect(
text_converter_role, clause, apply_propagate_attrs=self
)
- for clause in clauses
+ for clause in clauses_iterator
]
self._is_implicitly_boolean = operators.is_boolean(self.operator)
@classmethod
- def _construct_raw(cls, operator, clauses=None):
+ def _construct_raw(
+ cls,
+ operator: OperatorType,
+ clauses: Optional[Sequence[ColumnElement[Any]]] = None,
+ ) -> ClauseList:
self = cls.__new__(cls)
- self.clauses = clauses if clauses else []
+ self.clauses = list(clauses) if clauses else []
self.group = True
self.operator = operator
self.group_contents = True
self._is_implicitly_boolean = False
return self
- def __iter__(self):
+ def __iter__(self) -> Iterator[ColumnElement[Any]]:
return iter(self.clauses)
- def __len__(self):
+ def __len__(self) -> int:
return len(self.clauses)
@property
@@ -2708,10 +2714,10 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]):
def _construct_raw(
cls,
operator: OperatorType,
- clauses: Optional[List[ColumnElement[Any]]] = None,
+ clauses: Optional[Sequence[ColumnElement[Any]]] = None,
) -> BooleanClauseList:
self = cls.__new__(cls)
- self.clauses = clauses if clauses else []
+ self.clauses = list(clauses) if clauses else []
self.group = True
self.operator = operator
self.group_contents = True
@@ -2781,7 +2787,7 @@ class Tuple(ClauseList, ColumnElement[typing_Tuple[Any, ...]]):
sqltypes = util.preloaded.sql_sqltypes
if types is None:
- init_clauses = [
+ init_clauses: List[ColumnElement[Any]] = [
coercions.expect(roles.ExpressionElementRole, c)
for c in clauses
]
@@ -2908,7 +2914,7 @@ class Case(ColumnElement[_T]):
]
if whenlist:
- type_ = list(whenlist[-1])[-1].type
+ type_ = whenlist[-1][-1].type
else:
type_ = None
@@ -3098,6 +3104,8 @@ class _label_reference(ColumnElement[_T]):
("element", InternalTraversal.dp_clauseelement)
]
+ element: ColumnElement[_T]
+
def __init__(self, element: ColumnElement[_T]):
self.element = element
@@ -3212,7 +3220,9 @@ class UnaryExpression(ColumnElement[_T]):
cls,
expr: _ColumnExpressionArgument[_T],
) -> UnaryExpression[_T]:
- col_expr = coercions.expect(roles.ExpressionElementRole, expr)
+ col_expr: ColumnElement[_T] = coercions.expect(
+ roles.ExpressionElementRole, expr
+ )
return UnaryExpression(
col_expr,
operator=operators.distinct_op,
@@ -3265,7 +3275,7 @@ class CollectionAggregate(UnaryExpression[_T]):
def _create_any(
cls, expr: _ColumnExpressionArgument[_T]
) -> CollectionAggregate[bool]:
- col_expr = coercions.expect(
+ col_expr: ColumnElement[_T] = coercions.expect(
roles.ExpressionElementRole,
expr,
)
@@ -3281,7 +3291,7 @@ class CollectionAggregate(UnaryExpression[_T]):
def _create_all(
cls, expr: _ColumnExpressionArgument[_T]
) -> CollectionAggregate[bool]:
- col_expr = coercions.expect(
+ col_expr: ColumnElement[_T] = coercions.expect(
roles.ExpressionElementRole,
expr,
)
@@ -3374,6 +3384,9 @@ class BinaryExpression(ColumnElement[_T]):
modifiers: Optional[Mapping[str, Any]]
+ left: ColumnElement[Any]
+ right: Union[ColumnElement[Any], ClauseList]
+
def __init__(
self,
left: ColumnElement[Any],
@@ -4147,7 +4160,13 @@ class Label(roles.LabeledColumnExprRole[_T], NamedColumn[_T]):
def foreign_keys(self):
return self.element.foreign_keys
- def _copy_internals(self, clone=_clone, anonymize_labels=False, **kw):
+ def _copy_internals(
+ self,
+ *,
+ clone: _CloneCallableType = _clone,
+ anonymize_labels: bool = False,
+ **kw: Any,
+ ) -> None:
self._reset_memoizations()
self._element = clone(self._element, **kw)
if anonymize_labels:
@@ -4447,7 +4466,9 @@ class TableValuedColumn(NamedColumn[_T]):
self.key = self.name = scalar_alias.name
self.type = type_
- def _copy_internals(self, clone=_clone, **kw):
+ def _copy_internals(
+ self, clone: _CloneCallableType = _clone, **kw: Any
+ ) -> None:
self.scalar_alias = clone(self.scalar_alias, **kw)
self.key = self.name = self.scalar_alias.name
@@ -4467,7 +4488,7 @@ class CollationClause(ColumnElement[str]):
def _create_collation_expression(
cls, expression: _ColumnExpressionArgument[str], collation: str
) -> BinaryExpression[str]:
- expr = coercions.expect(roles.ExpressionElementRole, expression)
+ expr = coercions.expect(roles.ExpressionElementRole[str], expression)
return BinaryExpression(
expr,
CollationClause(collation),