summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-03-16 12:07:25 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-03-17 09:42:29 -0400
commit3b520e758a715cf817075e4a90ae1b5813ffadd3 (patch)
tree260f9517af499e7fb789d188f1631cd823a59929 /lib/sqlalchemy/ext
parent6acf5d2fca4a988a77481b82662174e8015a6b37 (diff)
downloadsqlalchemy-3b520e758a715cf817075e4a90ae1b5813ffadd3.tar.gz
pep484 for hybrid
Change-Id: I53274b13094d996e11b04acb03f9613edbddf87f References: #6810
Diffstat (limited to 'lib/sqlalchemy/ext')
-rw-r--r--lib/sqlalchemy/ext/hybrid.py224
1 files changed, 175 insertions, 49 deletions
diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py
index dc34a2ef5..92b3ce54f 100644
--- a/lib/sqlalchemy/ext/hybrid.py
+++ b/lib/sqlalchemy/ext/hybrid.py
@@ -802,17 +802,41 @@ advanced and/or patient developers, there's probably a whole lot of amazing
things it can be used for.
""" # noqa
+
+from __future__ import annotations
+
from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import Generic
+from typing import List
+from typing import Optional
+from typing import overload
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
+from typing import Union
from .. import util
from ..orm import attributes
from ..orm import InspectionAttrExtensionType
from ..orm import interfaces
from ..orm import ORMDescriptor
+from ..sql._typing import is_has_column_element_clause_element
+from ..sql.elements import ColumnElement
+from ..sql.elements import SQLCoreOperations
+from ..util.typing import Literal
+from ..util.typing import Protocol
+if TYPE_CHECKING:
+ from ..orm.util import AliasedInsp
+ from ..sql.operators import OperatorType
_T = TypeVar("_T", bound=Any)
+_T_co = TypeVar("_T_co", bound=Any, covariant=True)
+_T_con = TypeVar("_T_con", bound=Any, contravariant=True)
class HybridExtensionType(InspectionAttrExtensionType):
@@ -844,7 +868,34 @@ class HybridExtensionType(InspectionAttrExtensionType):
"""
-class hybrid_method(interfaces.InspectionAttrInfo, ORMDescriptor[_T]):
+class _HybridGetterType(Protocol[_T_co]):
+ def __call__(s, self: Any) -> _T_co:
+ ...
+
+
+class _HybridSetterType(Protocol[_T_con]):
+ def __call__(self, instance: Any, value: _T_con) -> None:
+ ...
+
+
+class _HybridUpdaterType(Protocol[_T]):
+ def __call__(
+ self, cls: Type[Any], value: Union[_T, SQLCoreOperations[_T]]
+ ) -> List[Tuple[SQLCoreOperations[_T], Any]]:
+ ...
+
+
+class _HybridDeleterType(Protocol[_T_co]):
+ def __call__(self, instance: Any) -> None:
+ ...
+
+
+class _HybridExprCallableType(Protocol[_T]):
+ def __call__(self, cls: Any) -> SQLCoreOperations[_T]:
+ ...
+
+
+class hybrid_method(interfaces.InspectionAttrInfo, Generic[_T]):
"""A decorator which allows definition of a Python object method with both
instance-level and class-level behavior.
@@ -853,7 +904,11 @@ class hybrid_method(interfaces.InspectionAttrInfo, ORMDescriptor[_T]):
is_attribute = True
extension_type = HybridExtensionType.HYBRID_METHOD
- def __init__(self, func, expr=None):
+ def __init__(
+ self,
+ func: Callable[..., _T],
+ expr: Optional[Callable[..., SQLCoreOperations[_T]]] = None,
+ ):
"""Create a new :class:`.hybrid_method`.
Usage is typically via decorator::
@@ -873,13 +928,29 @@ class hybrid_method(interfaces.InspectionAttrInfo, ORMDescriptor[_T]):
self.func = func
self.expression(expr or func)
- def __get__(self, instance, owner):
+ @overload
+ def __get__(
+ self, instance: Literal[None], owner: Type[object]
+ ) -> Callable[[Any], SQLCoreOperations[_T]]:
+ ...
+
+ @overload
+ def __get__(
+ self, instance: object, owner: Type[object]
+ ) -> Callable[[Any], _T]:
+ ...
+
+ def __get__(
+ self, instance: Optional[object], owner: Type[object]
+ ) -> Union[Callable[[Any], _T], Callable[[Any], SQLCoreOperations[_T]]]:
if instance is None:
- return self.expr.__get__(owner, owner.__class__)
+ return self.expr.__get__(owner, owner) # type: ignore
else:
- return self.func.__get__(instance, owner)
+ return self.func.__get__(instance, owner) # type: ignore
- def expression(self, expr):
+ def expression(
+ self, expr: Callable[..., SQLCoreOperations[_T]]
+ ) -> hybrid_method[_T]:
"""Provide a modifying decorator that defines a
SQL-expression producing method."""
@@ -889,7 +960,12 @@ class hybrid_method(interfaces.InspectionAttrInfo, ORMDescriptor[_T]):
return self
-class hybrid_property(interfaces.InspectionAttrInfo):
+Selfhybrid_property = TypeVar(
+ "Selfhybrid_property", bound="hybrid_property[Any]"
+)
+
+
+class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]):
"""A decorator which allows definition of a Python descriptor with both
instance-level and class-level behavior.
@@ -898,14 +974,16 @@ class hybrid_property(interfaces.InspectionAttrInfo):
is_attribute = True
extension_type = HybridExtensionType.HYBRID_PROPERTY
+ __name__: str
+
def __init__(
self,
- fget,
- fset=None,
- fdel=None,
- expr=None,
- custom_comparator=None,
- update_expr=None,
+ fget: _HybridGetterType[_T],
+ fset: Optional[_HybridSetterType[_T]] = None,
+ fdel: Optional[_HybridDeleterType[_T]] = None,
+ expr: Optional[_HybridExprCallableType[_T]] = None,
+ custom_comparator: Optional[Comparator[_T]] = None,
+ update_expr: Optional[_HybridUpdaterType[_T]] = None,
):
"""Create a new :class:`.hybrid_property`.
@@ -931,23 +1009,43 @@ class hybrid_property(interfaces.InspectionAttrInfo):
self.update_expr = update_expr
util.update_wrapper(self, fget)
- def __get__(self, instance, owner):
- if instance is None:
+ @overload
+ def __get__(
+ self: Selfhybrid_property, instance: Any, owner: Literal[None]
+ ) -> Selfhybrid_property:
+ ...
+
+ @overload
+ def __get__(
+ self, instance: Literal[None], owner: Type[object]
+ ) -> SQLCoreOperations[_T]:
+ ...
+
+ @overload
+ def __get__(self, instance: object, owner: Type[object]) -> _T:
+ ...
+
+ def __get__(
+ self, instance: Optional[object], owner: Optional[Type[object]]
+ ) -> Union[hybrid_property[_T], SQLCoreOperations[_T], _T]:
+ if owner is None:
+ return self
+ elif instance is None:
return self._expr_comparator(owner)
else:
return self.fget(instance)
- def __set__(self, instance, value):
+ def __set__(self, instance: object, value: Any) -> None:
if self.fset is None:
raise AttributeError("can't set attribute")
self.fset(instance, value)
- def __delete__(self, instance):
+ def __delete__(self, instance: object) -> None:
if self.fdel is None:
raise AttributeError("can't delete attribute")
self.fdel(instance)
- def _copy(self, **kw):
+ def _copy(self, **kw: Any) -> hybrid_property[_T]:
defaults = {
key: value
for key, value in self.__dict__.items()
@@ -957,7 +1055,7 @@ class hybrid_property(interfaces.InspectionAttrInfo):
return type(self)(**defaults)
@property
- def overrides(self):
+ def overrides(self: Selfhybrid_property) -> Selfhybrid_property:
"""Prefix for a method that is overriding an existing attribute.
The :attr:`.hybrid_property.overrides` accessor just returns
@@ -992,7 +1090,7 @@ class hybrid_property(interfaces.InspectionAttrInfo):
"""
return self
- def getter(self, fget):
+ def getter(self, fget: _HybridGetterType[_T]) -> hybrid_property[_T]:
"""Provide a modifying decorator that defines a getter method.
.. versionadded:: 1.2
@@ -1001,17 +1099,19 @@ class hybrid_property(interfaces.InspectionAttrInfo):
return self._copy(fget=fget)
- def setter(self, fset):
+ def setter(self, fset: _HybridSetterType[_T]) -> hybrid_property[_T]:
"""Provide a modifying decorator that defines a setter method."""
return self._copy(fset=fset)
- def deleter(self, fdel):
+ def deleter(self, fdel: _HybridDeleterType[_T]) -> hybrid_property[_T]:
"""Provide a modifying decorator that defines a deletion method."""
return self._copy(fdel=fdel)
- def expression(self, expr):
+ def expression(
+ self, expr: _HybridExprCallableType[_T]
+ ) -> hybrid_property[_T]:
"""Provide a modifying decorator that defines a SQL-expression
producing method.
@@ -1043,7 +1143,7 @@ class hybrid_property(interfaces.InspectionAttrInfo):
return self._copy(expr=expr)
- def comparator(self, comparator):
+ def comparator(self, comparator: Comparator[_T]) -> hybrid_property[_T]:
"""Provide a modifying decorator that defines a custom
comparator producing method.
@@ -1078,7 +1178,9 @@ class hybrid_property(interfaces.InspectionAttrInfo):
"""
return self._copy(custom_comparator=comparator)
- def update_expression(self, meth):
+ def update_expression(
+ self, meth: _HybridUpdaterType[_T]
+ ) -> hybrid_property[_T]:
"""Provide a modifying decorator that defines an UPDATE tuple
producing method.
@@ -1115,27 +1217,35 @@ class hybrid_property(interfaces.InspectionAttrInfo):
return self._copy(update_expr=meth)
@util.memoized_property
- def _expr_comparator(self):
+ def _expr_comparator(
+ self,
+ ) -> Callable[[Any], interfaces.PropComparator[_T]]:
if self.custom_comparator is not None:
return self._get_comparator(self.custom_comparator)
elif self.expr is not None:
return self._get_expr(self.expr)
else:
- return self._get_expr(self.fget)
+ return self._get_expr(cast(_HybridExprCallableType[_T], self.fget))
- def _get_expr(self, expr):
- def _expr(cls):
+ def _get_expr(
+ self, expr: _HybridExprCallableType[_T]
+ ) -> Callable[[Any], interfaces.PropComparator[_T]]:
+ def _expr(cls: Any) -> ExprComparator[_T]:
return ExprComparator(cls, expr(cls), self)
util.update_wrapper(_expr, expr)
return self._get_comparator(_expr)
- def _get_comparator(self, comparator):
+ def _get_comparator(
+ self, comparator: Any
+ ) -> Callable[[Any], interfaces.PropComparator[_T]]:
proxy_attr = attributes.create_proxied_attribute(self)
- def expr_comparator(owner):
+ def expr_comparator(
+ owner: Type[object],
+ ) -> interfaces.PropComparator[_T]:
# because this is the descriptor protocol, we don't really know
# what our attribute name is. so search for it through the
# MRO.
@@ -1163,36 +1273,48 @@ class Comparator(interfaces.PropComparator[_T]):
:class:`~.orm.interfaces.PropComparator`
classes for usage with hybrids."""
- property = None
-
- def __init__(self, expression):
+ def __init__(self, expression: SQLCoreOperations[_T]):
self.expression = expression
- def __clause_element__(self):
+ def __clause_element__(self) -> ColumnElement[_T]:
expr = self.expression
- if hasattr(expr, "__clause_element__"):
+ if is_has_column_element_clause_element(expr):
expr = expr.__clause_element__()
+
+ elif TYPE_CHECKING:
+ assert isinstance(expr, ColumnElement)
return expr
- def adapt_to_entity(self, adapt_to_entity):
+ @util.non_memoized_property
+ def property(self) -> Any:
+ return None
+
+ def adapt_to_entity(self, adapt_to_entity: AliasedInsp) -> Comparator[_T]:
# interesting....
return self
class ExprComparator(Comparator[_T]):
- def __init__(self, cls, expression, hybrid):
+ def __init__(
+ self,
+ cls: Type[Any],
+ expression: SQLCoreOperations[_T],
+ hybrid: hybrid_property[_T],
+ ):
self.cls = cls
self.expression = expression
self.hybrid = hybrid
- def __getattr__(self, key):
+ def __getattr__(self, key: str) -> Any:
return getattr(self.expression, key)
- @property
- def info(self):
+ @util.non_memoized_property
+ def info(self) -> Dict[Any, Any]:
return self.hybrid.info
- def _bulk_update_tuples(self, value):
+ def _bulk_update_tuples(
+ self, value: Any
+ ) -> List[Tuple[SQLCoreOperations[_T], Any]]:
if isinstance(self.expression, attributes.QueryableAttribute):
return self.expression._bulk_update_tuples(value)
elif self.hybrid.update_expr is not None:
@@ -1200,12 +1322,16 @@ class ExprComparator(Comparator[_T]):
else:
return [(self.expression, value)]
- @property
- def property(self):
- return self.expression.property
+ @util.non_memoized_property
+ def property(self) -> Any:
+ return self.expression.property # type: ignore
- def operate(self, op, *other, **kwargs):
- return op(self.expression, *other, **kwargs)
+ def operate(
+ self, op: OperatorType, *other: Any, **kwargs: Any
+ ) -> ColumnElement[Any]:
+ return op(self.expression, *other, **kwargs) # type: ignore
- def reverse_operate(self, op, other, **kwargs):
- return op(other, self.expression, **kwargs)
+ def reverse_operate(
+ self, op: OperatorType, other: Any, **kwargs: Any
+ ) -> ColumnElement[Any]:
+ return op(other, self.expression, **kwargs) # type: ignore