diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2023-01-30 15:12:52 -0500 |
|---|---|---|
| committer | mike bayer <mike_mp@zzzcomputing.com> | 2023-01-31 19:13:16 +0000 |
| commit | a21c715b7a89b0619db0d2d5b31617d17b25a27a (patch) | |
| tree | 01971eece9c687570137dd6e5d8e86b6040b697c /lib/sqlalchemy | |
| parent | 6d6a17240815b9090a2972607657f93d347167d6 (diff) | |
| download | sqlalchemy-a21c715b7a89b0619db0d2d5b31617d17b25a27a.tar.gz | |
support NewType in type_annotation_map
Added support for :pep:`484` ``NewType`` to be used in the
:paramref:`_orm.registry.type_annotation_map` as well as within
:class:`.Mapped` constructs. These types will behave in the same way as
custom subclasses of types right now; they must appear explicitly within
the :paramref:`_orm.registry.type_annotation_map` to be mapped.
Within this change, the lookup between decl_api._resolve_type
and TypeEngine._resolve_for_python_type is streamlined to not
inspect the given type multiple times, instead passing
in from decl_api to TypeEngine the already "flattened" version
of a Generic or NewType type.
Fixes: #9175
Change-Id: I227cf84b4b88e4567fa2d1d7da0c05b54e00c562
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/orm/decl_api.py | 31 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 8 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/type_api.py | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/typing.py | 30 |
4 files changed, 52 insertions, 28 deletions
diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index a46c1a7fb..4f8443833 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -19,6 +19,7 @@ from typing import ClassVar from typing import Dict from typing import FrozenSet from typing import Generic +from typing import Iterable from typing import Iterator from typing import Mapping from typing import Optional @@ -74,7 +75,9 @@ from ..util import hybridmethod from ..util import hybridproperty from ..util import typing as compat_typing from ..util.typing import CallableReference +from ..util.typing import flatten_newtype from ..util.typing import is_generic +from ..util.typing import is_newtype from ..util.typing import Literal if TYPE_CHECKING: @@ -85,7 +88,7 @@ if TYPE_CHECKING: from .interfaces import MapperProperty from .state import InstanceState # noqa from ..sql._typing import _TypeEngineArgument - from ..util.typing import GenericProtocol + from ..sql.type_api import _MatchedOnType _T = TypeVar("_T", bound=Any) @@ -1211,21 +1214,24 @@ class registry: ) def _resolve_type( - self, python_type: Union[GenericProtocol[Any], Type[Any]] + self, python_type: _MatchedOnType ) -> Optional[sqltypes.TypeEngine[Any]]: - search: Tuple[Union[GenericProtocol[Any], Type[Any]], ...] + search: Iterable[Tuple[_MatchedOnType, Type[Any]]] if is_generic(python_type): python_type_type: Type[Any] = python_type.__origin__ - search = (python_type,) + search = ((python_type, python_type_type),) + elif is_newtype(python_type): + python_type_type = flatten_newtype(python_type) + search = ((python_type, python_type_type),) else: - # don't know why is_generic() TypeGuard[GenericProtocol[Any]] - # check above is not sufficient here python_type_type = cast("Type[Any]", python_type) - search = python_type_type.__mro__ + flattened = None + search = ((pt, pt) for pt in python_type_type.__mro__) - for pt in search: + for pt, flattened in search: + # we search through full __mro__ for types. however... sql_type = self.type_annotation_map.get(pt) if sql_type is None: sql_type = sqltypes._type_map_get(pt) # type: ignore # noqa: E501 @@ -1233,8 +1239,15 @@ class registry: if sql_type is not None: sql_type_inst = sqltypes.to_instance(sql_type) # type: ignore + # ... this additional step will reject most + # type -> supertype matches, such as if we had + # a MyInt(int) subclass. note also we pass NewType() + # here directly; these always have to be in the + # type_annotation_map to be useful resolved_sql_type = sql_type_inst._resolve_for_python_type( - python_type_type, pt + python_type_type, + pt, + flattened, ) if resolved_sql_type is not None: return resolved_sql_type diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index b5c79b4b9..717e6c0b2 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -59,7 +59,6 @@ from .. import util from ..engine import processors from ..util import langhelpers from ..util import OrderedDict -from ..util.typing import GenericProtocol from ..util.typing import Literal if TYPE_CHECKING: @@ -69,6 +68,7 @@ if TYPE_CHECKING: from .schema import MetaData from .type_api import _BindProcessorType from .type_api import _ComparatorFactory + from .type_api import _MatchedOnType from .type_api import _ResultProcessorType from ..engine.interfaces import Dialect @@ -1493,14 +1493,16 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): return enums, enums def _resolve_for_literal(self, value: Any) -> Enum: - typ = self._resolve_for_python_type(type(value), type(value)) + tv = type(value) + typ = self._resolve_for_python_type(tv, tv, tv) assert typ is not None return typ def _resolve_for_python_type( self, python_type: Type[Any], - matched_on: Union[GenericProtocol[Any], Type[Any]], + matched_on: _MatchedOnType, + matched_on_flattened: Type[Any], ) -> Optional[Enum]: if not issubclass(python_type, enum.Enum): return None diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 79c889763..fefbf4997 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -19,6 +19,7 @@ from typing import cast from typing import Dict from typing import Generic from typing import Mapping +from typing import NewType from typing import Optional from typing import overload from typing import Sequence @@ -35,7 +36,6 @@ from .operators import ColumnOperators from .visitors import Visitable from .. import exc from .. import util -from ..util.typing import flatten_generic from ..util.typing import Protocol from ..util.typing import TypedDict from ..util.typing import TypeGuard @@ -65,6 +65,8 @@ _O = TypeVar("_O", bound=object) _TE = TypeVar("_TE", bound="TypeEngine[Any]") _CT = TypeVar("_CT", bound=Any) +_MatchedOnType = Union["GenericProtocol[Any]", NewType, Type[Any]] + # replace with pep-673 when applicable SelfTypeEngine = typing.TypeVar("SelfTypeEngine", bound="TypeEngine[Any]") @@ -731,7 +733,8 @@ class TypeEngine(Visitable, Generic[_T]): def _resolve_for_python_type( self: SelfTypeEngine, python_type: Type[Any], - matched_on: Union[GenericProtocol[Any], Type[Any]], + matched_on: _MatchedOnType, + matched_on_flattened: Type[Any], ) -> Optional[SelfTypeEngine]: """given a Python type (e.g. ``int``, ``str``, etc. ) return an instance of this :class:`.TypeEngine` that's appropriate for this type. @@ -772,9 +775,7 @@ class TypeEngine(Visitable, Generic[_T]): """ - matched_on = flatten_generic(matched_on) - - if python_type is not matched_on: + if python_type is not matched_on_flattened: return None return self diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index e1670ed21..51e95ecfa 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -18,6 +18,7 @@ from typing import Dict from typing import ForwardRef from typing import Generic from typing import Iterable +from typing import NewType from typing import NoReturn from typing import Optional from typing import overload @@ -71,10 +72,9 @@ typing_get_args = get_args typing_get_origin = get_origin -# copied from TypeShed, required in order to implement -# MutableMapping.update() - -_AnnotationScanType = Union[Type[Any], str, ForwardRef, "GenericProtocol[Any]"] +_AnnotationScanType = Union[ + Type[Any], str, ForwardRef, NewType, "GenericProtocol[Any]" +] class ArgsTypeProcotol(Protocol): @@ -105,6 +105,8 @@ class GenericProtocol(Protocol[_T]): # ... +# copied from TypeShed, required in order to implement +# MutableMapping.update() class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]): def keys(self) -> Iterable[_KT]: ... @@ -247,17 +249,23 @@ def is_pep593(type_: Optional[_AnnotationScanType]) -> bool: return type_ is not None and typing_get_origin(type_) is Annotated +def is_newtype(type_: Optional[_AnnotationScanType]) -> TypeGuard[NewType]: + return hasattr(type_, "__supertype__") + + # doesn't work in 3.8, 3.7 as it passes a closure, not an + # object instance + # return isinstance(type_, NewType) + + def is_generic(type_: _AnnotationScanType) -> TypeGuard[GenericProtocol[Any]]: return hasattr(type_, "__args__") and hasattr(type_, "__origin__") -def flatten_generic( - type_: Union[GenericProtocol[Any], Type[Any]] -) -> Type[Any]: - if is_generic(type_): - return type_.__origin__ - else: - return cast("Type[Any]", type_) +def flatten_newtype(type_: NewType) -> Type[Any]: + super_type = type_.__supertype__ + while is_newtype(super_type): + super_type = super_type.__supertype__ + return super_type def is_fwd_ref( |
