diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-03-16 12:07:25 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-03-17 09:42:29 -0400 |
| commit | 3b520e758a715cf817075e4a90ae1b5813ffadd3 (patch) | |
| tree | 260f9517af499e7fb789d188f1631cd823a59929 /lib/sqlalchemy/ext | |
| parent | 6acf5d2fca4a988a77481b82662174e8015a6b37 (diff) | |
| download | sqlalchemy-3b520e758a715cf817075e4a90ae1b5813ffadd3.tar.gz | |
pep484 for hybrid
Change-Id: I53274b13094d996e11b04acb03f9613edbddf87f
References: #6810
Diffstat (limited to 'lib/sqlalchemy/ext')
| -rw-r--r-- | lib/sqlalchemy/ext/hybrid.py | 224 |
1 files changed, 175 insertions, 49 deletions
diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py index dc34a2ef5..92b3ce54f 100644 --- a/lib/sqlalchemy/ext/hybrid.py +++ b/lib/sqlalchemy/ext/hybrid.py @@ -802,17 +802,41 @@ advanced and/or patient developers, there's probably a whole lot of amazing things it can be used for. """ # noqa + +from __future__ import annotations + from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import Generic +from typing import List +from typing import Optional +from typing import overload +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING from typing import TypeVar +from typing import Union from .. import util from ..orm import attributes from ..orm import InspectionAttrExtensionType from ..orm import interfaces from ..orm import ORMDescriptor +from ..sql._typing import is_has_column_element_clause_element +from ..sql.elements import ColumnElement +from ..sql.elements import SQLCoreOperations +from ..util.typing import Literal +from ..util.typing import Protocol +if TYPE_CHECKING: + from ..orm.util import AliasedInsp + from ..sql.operators import OperatorType _T = TypeVar("_T", bound=Any) +_T_co = TypeVar("_T_co", bound=Any, covariant=True) +_T_con = TypeVar("_T_con", bound=Any, contravariant=True) class HybridExtensionType(InspectionAttrExtensionType): @@ -844,7 +868,34 @@ class HybridExtensionType(InspectionAttrExtensionType): """ -class hybrid_method(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): +class _HybridGetterType(Protocol[_T_co]): + def __call__(s, self: Any) -> _T_co: + ... + + +class _HybridSetterType(Protocol[_T_con]): + def __call__(self, instance: Any, value: _T_con) -> None: + ... + + +class _HybridUpdaterType(Protocol[_T]): + def __call__( + self, cls: Type[Any], value: Union[_T, SQLCoreOperations[_T]] + ) -> List[Tuple[SQLCoreOperations[_T], Any]]: + ... + + +class _HybridDeleterType(Protocol[_T_co]): + def __call__(self, instance: Any) -> None: + ... + + +class _HybridExprCallableType(Protocol[_T]): + def __call__(self, cls: Any) -> SQLCoreOperations[_T]: + ... + + +class hybrid_method(interfaces.InspectionAttrInfo, Generic[_T]): """A decorator which allows definition of a Python object method with both instance-level and class-level behavior. @@ -853,7 +904,11 @@ class hybrid_method(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): is_attribute = True extension_type = HybridExtensionType.HYBRID_METHOD - def __init__(self, func, expr=None): + def __init__( + self, + func: Callable[..., _T], + expr: Optional[Callable[..., SQLCoreOperations[_T]]] = None, + ): """Create a new :class:`.hybrid_method`. Usage is typically via decorator:: @@ -873,13 +928,29 @@ class hybrid_method(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): self.func = func self.expression(expr or func) - def __get__(self, instance, owner): + @overload + def __get__( + self, instance: Literal[None], owner: Type[object] + ) -> Callable[[Any], SQLCoreOperations[_T]]: + ... + + @overload + def __get__( + self, instance: object, owner: Type[object] + ) -> Callable[[Any], _T]: + ... + + def __get__( + self, instance: Optional[object], owner: Type[object] + ) -> Union[Callable[[Any], _T], Callable[[Any], SQLCoreOperations[_T]]]: if instance is None: - return self.expr.__get__(owner, owner.__class__) + return self.expr.__get__(owner, owner) # type: ignore else: - return self.func.__get__(instance, owner) + return self.func.__get__(instance, owner) # type: ignore - def expression(self, expr): + def expression( + self, expr: Callable[..., SQLCoreOperations[_T]] + ) -> hybrid_method[_T]: """Provide a modifying decorator that defines a SQL-expression producing method.""" @@ -889,7 +960,12 @@ class hybrid_method(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): return self -class hybrid_property(interfaces.InspectionAttrInfo): +Selfhybrid_property = TypeVar( + "Selfhybrid_property", bound="hybrid_property[Any]" +) + + +class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): """A decorator which allows definition of a Python descriptor with both instance-level and class-level behavior. @@ -898,14 +974,16 @@ class hybrid_property(interfaces.InspectionAttrInfo): is_attribute = True extension_type = HybridExtensionType.HYBRID_PROPERTY + __name__: str + def __init__( self, - fget, - fset=None, - fdel=None, - expr=None, - custom_comparator=None, - update_expr=None, + fget: _HybridGetterType[_T], + fset: Optional[_HybridSetterType[_T]] = None, + fdel: Optional[_HybridDeleterType[_T]] = None, + expr: Optional[_HybridExprCallableType[_T]] = None, + custom_comparator: Optional[Comparator[_T]] = None, + update_expr: Optional[_HybridUpdaterType[_T]] = None, ): """Create a new :class:`.hybrid_property`. @@ -931,23 +1009,43 @@ class hybrid_property(interfaces.InspectionAttrInfo): self.update_expr = update_expr util.update_wrapper(self, fget) - def __get__(self, instance, owner): - if instance is None: + @overload + def __get__( + self: Selfhybrid_property, instance: Any, owner: Literal[None] + ) -> Selfhybrid_property: + ... + + @overload + def __get__( + self, instance: Literal[None], owner: Type[object] + ) -> SQLCoreOperations[_T]: + ... + + @overload + def __get__(self, instance: object, owner: Type[object]) -> _T: + ... + + def __get__( + self, instance: Optional[object], owner: Optional[Type[object]] + ) -> Union[hybrid_property[_T], SQLCoreOperations[_T], _T]: + if owner is None: + return self + elif instance is None: return self._expr_comparator(owner) else: return self.fget(instance) - def __set__(self, instance, value): + def __set__(self, instance: object, value: Any) -> None: if self.fset is None: raise AttributeError("can't set attribute") self.fset(instance, value) - def __delete__(self, instance): + def __delete__(self, instance: object) -> None: if self.fdel is None: raise AttributeError("can't delete attribute") self.fdel(instance) - def _copy(self, **kw): + def _copy(self, **kw: Any) -> hybrid_property[_T]: defaults = { key: value for key, value in self.__dict__.items() @@ -957,7 +1055,7 @@ class hybrid_property(interfaces.InspectionAttrInfo): return type(self)(**defaults) @property - def overrides(self): + def overrides(self: Selfhybrid_property) -> Selfhybrid_property: """Prefix for a method that is overriding an existing attribute. The :attr:`.hybrid_property.overrides` accessor just returns @@ -992,7 +1090,7 @@ class hybrid_property(interfaces.InspectionAttrInfo): """ return self - def getter(self, fget): + def getter(self, fget: _HybridGetterType[_T]) -> hybrid_property[_T]: """Provide a modifying decorator that defines a getter method. .. versionadded:: 1.2 @@ -1001,17 +1099,19 @@ class hybrid_property(interfaces.InspectionAttrInfo): return self._copy(fget=fget) - def setter(self, fset): + def setter(self, fset: _HybridSetterType[_T]) -> hybrid_property[_T]: """Provide a modifying decorator that defines a setter method.""" return self._copy(fset=fset) - def deleter(self, fdel): + def deleter(self, fdel: _HybridDeleterType[_T]) -> hybrid_property[_T]: """Provide a modifying decorator that defines a deletion method.""" return self._copy(fdel=fdel) - def expression(self, expr): + def expression( + self, expr: _HybridExprCallableType[_T] + ) -> hybrid_property[_T]: """Provide a modifying decorator that defines a SQL-expression producing method. @@ -1043,7 +1143,7 @@ class hybrid_property(interfaces.InspectionAttrInfo): return self._copy(expr=expr) - def comparator(self, comparator): + def comparator(self, comparator: Comparator[_T]) -> hybrid_property[_T]: """Provide a modifying decorator that defines a custom comparator producing method. @@ -1078,7 +1178,9 @@ class hybrid_property(interfaces.InspectionAttrInfo): """ return self._copy(custom_comparator=comparator) - def update_expression(self, meth): + def update_expression( + self, meth: _HybridUpdaterType[_T] + ) -> hybrid_property[_T]: """Provide a modifying decorator that defines an UPDATE tuple producing method. @@ -1115,27 +1217,35 @@ class hybrid_property(interfaces.InspectionAttrInfo): return self._copy(update_expr=meth) @util.memoized_property - def _expr_comparator(self): + def _expr_comparator( + self, + ) -> Callable[[Any], interfaces.PropComparator[_T]]: if self.custom_comparator is not None: return self._get_comparator(self.custom_comparator) elif self.expr is not None: return self._get_expr(self.expr) else: - return self._get_expr(self.fget) + return self._get_expr(cast(_HybridExprCallableType[_T], self.fget)) - def _get_expr(self, expr): - def _expr(cls): + def _get_expr( + self, expr: _HybridExprCallableType[_T] + ) -> Callable[[Any], interfaces.PropComparator[_T]]: + def _expr(cls: Any) -> ExprComparator[_T]: return ExprComparator(cls, expr(cls), self) util.update_wrapper(_expr, expr) return self._get_comparator(_expr) - def _get_comparator(self, comparator): + def _get_comparator( + self, comparator: Any + ) -> Callable[[Any], interfaces.PropComparator[_T]]: proxy_attr = attributes.create_proxied_attribute(self) - def expr_comparator(owner): + def expr_comparator( + owner: Type[object], + ) -> interfaces.PropComparator[_T]: # because this is the descriptor protocol, we don't really know # what our attribute name is. so search for it through the # MRO. @@ -1163,36 +1273,48 @@ class Comparator(interfaces.PropComparator[_T]): :class:`~.orm.interfaces.PropComparator` classes for usage with hybrids.""" - property = None - - def __init__(self, expression): + def __init__(self, expression: SQLCoreOperations[_T]): self.expression = expression - def __clause_element__(self): + def __clause_element__(self) -> ColumnElement[_T]: expr = self.expression - if hasattr(expr, "__clause_element__"): + if is_has_column_element_clause_element(expr): expr = expr.__clause_element__() + + elif TYPE_CHECKING: + assert isinstance(expr, ColumnElement) return expr - def adapt_to_entity(self, adapt_to_entity): + @util.non_memoized_property + def property(self) -> Any: + return None + + def adapt_to_entity(self, adapt_to_entity: AliasedInsp) -> Comparator[_T]: # interesting.... return self class ExprComparator(Comparator[_T]): - def __init__(self, cls, expression, hybrid): + def __init__( + self, + cls: Type[Any], + expression: SQLCoreOperations[_T], + hybrid: hybrid_property[_T], + ): self.cls = cls self.expression = expression self.hybrid = hybrid - def __getattr__(self, key): + def __getattr__(self, key: str) -> Any: return getattr(self.expression, key) - @property - def info(self): + @util.non_memoized_property + def info(self) -> Dict[Any, Any]: return self.hybrid.info - def _bulk_update_tuples(self, value): + def _bulk_update_tuples( + self, value: Any + ) -> List[Tuple[SQLCoreOperations[_T], Any]]: if isinstance(self.expression, attributes.QueryableAttribute): return self.expression._bulk_update_tuples(value) elif self.hybrid.update_expr is not None: @@ -1200,12 +1322,16 @@ class ExprComparator(Comparator[_T]): else: return [(self.expression, value)] - @property - def property(self): - return self.expression.property + @util.non_memoized_property + def property(self) -> Any: + return self.expression.property # type: ignore - def operate(self, op, *other, **kwargs): - return op(self.expression, *other, **kwargs) + def operate( + self, op: OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return op(self.expression, *other, **kwargs) # type: ignore - def reverse_operate(self, op, other, **kwargs): - return op(other, self.expression, **kwargs) + def reverse_operate( + self, op: OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return op(other, self.expression, **kwargs) # type: ignore |
