From be0831fea83247451628bc6643d5b130c63f6011 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 19 Jan 2023 12:09:29 -0500 Subject: implement basic typing for lambda elements These weren't working at all, so fixed things up and added a test suite. Keeping things very basic with Any returns etc. as having more specific return types starts making it too cumbersome to write end-user code. Corrected the type passed for "lambda statements" so that a plain lambda is accepted by mypy, pyright, others without any errors about argument types. Additionally implemented typing for more of the public API for lambda statements and ensured :class:`.StatementLambdaElement` is part of the :class:`.Executable` hierarchy so it's typed as accepted by :meth:`_engine.Connection.execute`. Fixes: #9120 Change-Id: Ia7fa34e5b6e43fba02c8f94ccc256f3a68a1f445 --- lib/sqlalchemy/sql/elements.py | 6 +-- lib/sqlalchemy/sql/lambdas.py | 104 +++++++++++++++++++++++++++++------------ 2 files changed, 76 insertions(+), 34 deletions(-) (limited to 'lib/sqlalchemy/sql') diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 6d1949425..043fb7a03 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -106,9 +106,9 @@ if typing.TYPE_CHECKING: from ..engine import Dialect from ..engine import Engine from ..engine.interfaces import _CoreMultiExecuteParams - from ..engine.interfaces import _ExecuteOptions from ..engine.interfaces import CacheStats from ..engine.interfaces import CompiledCacheType + from ..engine.interfaces import CoreExecuteOptionsParameter from ..engine.interfaces import SchemaTranslateMapType from ..engine.result import Result @@ -481,7 +481,7 @@ class ClauseElement( self, connection: Connection, distilled_params: _CoreMultiExecuteParams, - execution_options: _ExecuteOptions, + execution_options: CoreExecuteOptionsParameter, ) -> Result[Any]: if self.supports_execution: if TYPE_CHECKING: @@ -496,7 +496,7 @@ class ClauseElement( self, connection: Connection, distilled_params: _CoreMultiExecuteParams, - execution_options: _ExecuteOptions, + execution_options: CoreExecuteOptionsParameter, ) -> Any: """an additional hook for subclasses to provide a different implementation for connection.scalar() vs. connection.execute(). diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py index b153ba999..d737b1bcb 100644 --- a/lib/sqlalchemy/sql/lambdas.py +++ b/lib/sqlalchemy/sql/lambdas.py @@ -18,13 +18,13 @@ from types import CodeType from typing import Any from typing import Callable from typing import cast -from typing import Iterable from typing import List from typing import MutableMapping from typing import Optional from typing import Tuple from typing import Type from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union import weakref @@ -43,7 +43,6 @@ from .. import exc from .. import inspection from .. import util from ..util.typing import Literal -from ..util.typing import Protocol from ..util.typing import Self if TYPE_CHECKING: @@ -60,12 +59,14 @@ _BoundParameterGetter = Callable[..., Any] _closure_per_cache_key: _LambdaCacheType = util.LRUCache(1000) -class _LambdaType(Protocol): - __code__: CodeType - __closure__: Iterable[Tuple[Any, Any]] +_LambdaType = Callable[[], Any] - def __call__(self, *arg: Any, **kw: Any) -> ClauseElement: - ... +_AnyLambdaType = Callable[..., Any] + +_StmtLambdaType = Callable[[], Any] + +_E = TypeVar("_E", bound=Executable) +_StmtLambdaElementType = Callable[[_E], Any] class LambdaOptions(Options): @@ -78,7 +79,7 @@ class LambdaOptions(Options): def lambda_stmt( - lmb: _LambdaType, + lmb: _StmtLambdaType, enable_tracking: bool = True, track_closure_variables: bool = True, track_on: Optional[object] = None, @@ -185,7 +186,7 @@ class LambdaElement(elements.ClauseElement): closure_cache_key: Union[Tuple[Any, ...], Literal[CacheConst.NO_CACHE]] role: Type[SQLRole] _rec: Union[AnalyzedFunction, NonAnalyzedFunction] - fn: _LambdaType + fn: _AnyLambdaType tracker_key: Tuple[CodeType, ...] def __repr__(self): @@ -416,8 +417,8 @@ class LambdaElement(elements.ClauseElement): bindparams.extend(self._resolved_bindparams) return cache_key - def _invoke_user_fn(self, fn: _LambdaType, *arg: Any) -> ClauseElement: - return fn() + def _invoke_user_fn(self, fn: _AnyLambdaType, *arg: Any) -> ClauseElement: + return fn() # type: ignore[no-any-return] class DeferredLambdaElement(LambdaElement): @@ -494,7 +495,9 @@ class DeferredLambdaElement(LambdaElement): self._transforms += (deferred_copy_internals,) -class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement): +class StatementLambdaElement( + roles.AllowsLambdaRole, LambdaElement, Executable +): """Represent a composable SQL statement as a :class:`_sql.LambdaElement`. The :class:`_sql.StatementLambdaElement` is constructed using the @@ -520,17 +523,30 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement): """ - def __add__(self, other): + if TYPE_CHECKING: + + def __init__( + self, + fn: _StmtLambdaType, + role: Type[SQLRole], + opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions, + apply_propagate_attrs: Optional[ClauseElement] = None, + ): + ... + + def __add__( + self, other: _StmtLambdaElementType[Any] + ) -> StatementLambdaElement: return self.add_criteria(other) def add_criteria( self, - other, - enable_tracking=True, - track_on=None, - track_closure_variables=True, - track_bound_values=True, - ): + other: _StmtLambdaElementType[Any], + enable_tracking: bool = True, + track_on: Optional[Any] = None, + track_closure_variables: bool = True, + track_bound_values: bool = True, + ) -> StatementLambdaElement: """Add new criteria to this :class:`_sql.StatementLambdaElement`. E.g.:: @@ -587,25 +603,51 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement): else: raise exc.ObjectNotExecutableError(self) + @property + def _proxied(self) -> Any: + return self._rec_expected_expr + @property def _with_options(self): - if TYPE_CHECKING: - assert isinstance(self._rec.expected_expr, Executable) - return self._rec.expected_expr._with_options + return self._proxied._with_options @property def _effective_plugin_target(self): - if TYPE_CHECKING: - assert isinstance(self._rec.expected_expr, Executable) - return self._rec.expected_expr._effective_plugin_target + return self._proxied._effective_plugin_target @property def _execution_options(self): - if TYPE_CHECKING: - assert isinstance(self._rec.expected_expr, Executable) - return self._rec.expected_expr._execution_options + return self._proxied._execution_options + + @property + def _all_selected_columns(self): + return self._proxied._all_selected_columns + + @property + def is_select(self): + return self._proxied.is_select + + @property + def is_update(self): + return self._proxied.is_update + + @property + def is_insert(self): + return self._proxied.is_insert - def spoil(self): + @property + def is_text(self): + return self._proxied.is_text + + @property + def is_delete(self): + return self._proxied.is_delete + + @property + def is_dml(self): + return self._proxied.is_dml + + def spoil(self) -> NullLambdaStatement: """Return a new :class:`.StatementLambdaElement` that will run all lambdas unconditionally each time. @@ -667,12 +709,12 @@ class LinkedLambdaElement(StatementLambdaElement): def __init__( self, - fn: _LambdaType, + fn: _StmtLambdaElementType[Any], parent_lambda: StatementLambdaElement, opts: Union[Type[LambdaOptions], LambdaOptions], ): self.opts = opts - self.fn = fn + self.fn = fn # type: ignore[assignment] self.parent_lambda = parent_lambda self.tracker_key = parent_lambda.tracker_key + (fn.__code__,) -- cgit v1.2.1