summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2023-01-14 22:24:36 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2023-01-14 22:36:22 -0500
commit67c1c018f571fbbcf070c4e0637f36d9533c86d7 (patch)
tree61ef98ec5b20447713ec36a740fb0f5f0d7c0351 /lib/sqlalchemy
parente07130c597422d5f9a5d734e1411d8fef0c2deff (diff)
downloadsqlalchemy-67c1c018f571fbbcf070c4e0637f36d9533c86d7.tar.gz
apply pep-612 to hybrid_method; accept SQLCoreOperations
Fixes to the annotations within the ``sqlalchemy.ext.hybrid`` extension for more effective typing of user-defined methods. The typing now uses :pep:`612` features, now supported by recent versions of Mypy, to maintain argument signatures for :class:`.hybrid_method`. Return values for hybrid methods are accepted as SQL expressions in contexts such as :meth:`_sql.Select.where` while still supporting SQL methods. Fixes: #9096 Change-Id: Id4e3a38ec50e415220dfc5f022281b11bb262469
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/ext/hybrid.py27
-rw-r--r--lib/sqlalchemy/sql/_typing.py8
2 files changed, 25 insertions, 10 deletions
diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py
index 657bc8c6e..115c1cb85 100644
--- a/lib/sqlalchemy/ext/hybrid.py
+++ b/lib/sqlalchemy/ext/hybrid.py
@@ -707,7 +707,9 @@ from ..sql import roles
from ..sql._typing import is_has_clause_element
from ..sql.elements import ColumnElement
from ..sql.elements import SQLCoreOperations
+from ..util.typing import Concatenate
from ..util.typing import Literal
+from ..util.typing import ParamSpec
from ..util.typing import Protocol
if TYPE_CHECKING:
@@ -719,6 +721,8 @@ if TYPE_CHECKING:
from ..sql._typing import _InfoType
from ..sql.operators import OperatorType
+_P = ParamSpec("_P")
+_R = TypeVar("_R")
_T = TypeVar("_T", bound=Any)
_T_co = TypeVar("_T_co", bound=Any, covariant=True)
_T_con = TypeVar("_T_con", bound=Any, contravariant=True)
@@ -784,7 +788,7 @@ class _HybridExprCallableType(Protocol[_T_co]):
...
-class hybrid_method(interfaces.InspectionAttrInfo, Generic[_T]):
+class hybrid_method(interfaces.InspectionAttrInfo, Generic[_P, _R]):
"""A decorator which allows definition of a Python object method with both
instance-level and class-level behavior.
@@ -795,8 +799,10 @@ class hybrid_method(interfaces.InspectionAttrInfo, Generic[_T]):
def __init__(
self,
- func: Callable[..., _T],
- expr: Optional[Callable[..., SQLCoreOperations[_T]]] = None,
+ func: Callable[Concatenate[Any, _P], _R],
+ expr: Optional[
+ Callable[Concatenate[Any, _P], SQLCoreOperations[_R]]
+ ] = None,
):
"""Create a new :class:`.hybrid_method`.
@@ -815,31 +821,34 @@ class hybrid_method(interfaces.InspectionAttrInfo, Generic[_T]):
"""
self.func = func
- self.expression(expr or func)
+ if expr is not None:
+ self.expression(expr)
+ else:
+ self.expression(func) # type: ignore
@overload
def __get__(
self, instance: Literal[None], owner: Type[object]
- ) -> Callable[[Any], SQLCoreOperations[_T]]:
+ ) -> Callable[_P, SQLCoreOperations[_R]]:
...
@overload
def __get__(
self, instance: object, owner: Type[object]
- ) -> Callable[[Any], _T]:
+ ) -> Callable[_P, _R]:
...
def __get__(
self, instance: Optional[object], owner: Type[object]
- ) -> Union[Callable[[Any], _T], Callable[[Any], SQLCoreOperations[_T]]]:
+ ) -> Union[Callable[_P, _R], Callable[_P, SQLCoreOperations[_R]]]:
if instance is None:
return self.expr.__get__(owner, owner) # type: ignore
else:
return self.func.__get__(instance, owner) # type: ignore
def expression(
- self, expr: Callable[..., SQLCoreOperations[_T]]
- ) -> hybrid_method[_T]:
+ self, expr: Callable[Concatenate[Any, _P], SQLCoreOperations[_R]]
+ ) -> hybrid_method[_P, _R]:
"""Provide a modifying decorator that defines a
SQL-expression producing method."""
diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py
index a120629ca..da3a9ad4e 100644
--- a/lib/sqlalchemy/sql/_typing.py
+++ b/lib/sqlalchemy/sql/_typing.py
@@ -44,6 +44,7 @@ if TYPE_CHECKING:
from .elements import ColumnElement
from .elements import KeyedColumnElement
from .elements import quoted_name
+ from .elements import SQLCoreOperations
from .elements import TextClause
from .lambdas import LambdaElement
from .roles import ColumnsClauseRole
@@ -128,6 +129,7 @@ _TextCoercedExpressionArgument = Union[
_ColumnsClauseArgument = Union[
roles.TypedColumnsClauseRole[_T],
roles.ColumnsClauseRole,
+ "SQLCoreOperations[_T]",
Literal["*", 1],
Type[_T],
Inspectable[_HasClauseElement],
@@ -144,7 +146,10 @@ sets; select(...), insert().returning(...), etc.
"""
_TypedColumnClauseArgument = Union[
- roles.TypedColumnsClauseRole[_T], roles.ExpressionElementRole[_T], Type[_T]
+ roles.TypedColumnsClauseRole[_T],
+ "SQLCoreOperations[_T]",
+ roles.ExpressionElementRole[_T],
+ Type[_T],
]
_TP = TypeVar("_TP", bound=Tuple[Any, ...])
@@ -164,6 +169,7 @@ _T9 = TypeVar("_T9", bound=Any)
_ColumnExpressionArgument = Union[
"ColumnElement[_T]",
_HasClauseElement,
+ "SQLCoreOperations[_T]",
roles.ExpressionElementRole[_T],
Callable[[], "ColumnElement[_T]"],
"LambdaElement",