summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2023-01-30 15:12:52 -0500
committermike bayer <mike_mp@zzzcomputing.com>2023-01-31 19:13:16 +0000
commita21c715b7a89b0619db0d2d5b31617d17b25a27a (patch)
tree01971eece9c687570137dd6e5d8e86b6040b697c /lib/sqlalchemy
parent6d6a17240815b9090a2972607657f93d347167d6 (diff)
downloadsqlalchemy-a21c715b7a89b0619db0d2d5b31617d17b25a27a.tar.gz
support NewType in type_annotation_map
Added support for :pep:`484` ``NewType`` to be used in the :paramref:`_orm.registry.type_annotation_map` as well as within :class:`.Mapped` constructs. These types will behave in the same way as custom subclasses of types right now; they must appear explicitly within the :paramref:`_orm.registry.type_annotation_map` to be mapped. Within this change, the lookup between decl_api._resolve_type and TypeEngine._resolve_for_python_type is streamlined to not inspect the given type multiple times, instead passing in from decl_api to TypeEngine the already "flattened" version of a Generic or NewType type. Fixes: #9175 Change-Id: I227cf84b4b88e4567fa2d1d7da0c05b54e00c562
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/orm/decl_api.py31
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py8
-rw-r--r--lib/sqlalchemy/sql/type_api.py11
-rw-r--r--lib/sqlalchemy/util/typing.py30
4 files changed, 52 insertions, 28 deletions
diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py
index a46c1a7fb..4f8443833 100644
--- a/lib/sqlalchemy/orm/decl_api.py
+++ b/lib/sqlalchemy/orm/decl_api.py
@@ -19,6 +19,7 @@ from typing import ClassVar
from typing import Dict
from typing import FrozenSet
from typing import Generic
+from typing import Iterable
from typing import Iterator
from typing import Mapping
from typing import Optional
@@ -74,7 +75,9 @@ from ..util import hybridmethod
from ..util import hybridproperty
from ..util import typing as compat_typing
from ..util.typing import CallableReference
+from ..util.typing import flatten_newtype
from ..util.typing import is_generic
+from ..util.typing import is_newtype
from ..util.typing import Literal
if TYPE_CHECKING:
@@ -85,7 +88,7 @@ if TYPE_CHECKING:
from .interfaces import MapperProperty
from .state import InstanceState # noqa
from ..sql._typing import _TypeEngineArgument
- from ..util.typing import GenericProtocol
+ from ..sql.type_api import _MatchedOnType
_T = TypeVar("_T", bound=Any)
@@ -1211,21 +1214,24 @@ class registry:
)
def _resolve_type(
- self, python_type: Union[GenericProtocol[Any], Type[Any]]
+ self, python_type: _MatchedOnType
) -> Optional[sqltypes.TypeEngine[Any]]:
- search: Tuple[Union[GenericProtocol[Any], Type[Any]], ...]
+ search: Iterable[Tuple[_MatchedOnType, Type[Any]]]
if is_generic(python_type):
python_type_type: Type[Any] = python_type.__origin__
- search = (python_type,)
+ search = ((python_type, python_type_type),)
+ elif is_newtype(python_type):
+ python_type_type = flatten_newtype(python_type)
+ search = ((python_type, python_type_type),)
else:
- # don't know why is_generic() TypeGuard[GenericProtocol[Any]]
- # check above is not sufficient here
python_type_type = cast("Type[Any]", python_type)
- search = python_type_type.__mro__
+ flattened = None
+ search = ((pt, pt) for pt in python_type_type.__mro__)
- for pt in search:
+ for pt, flattened in search:
+ # we search through full __mro__ for types. however...
sql_type = self.type_annotation_map.get(pt)
if sql_type is None:
sql_type = sqltypes._type_map_get(pt) # type: ignore # noqa: E501
@@ -1233,8 +1239,15 @@ class registry:
if sql_type is not None:
sql_type_inst = sqltypes.to_instance(sql_type) # type: ignore
+ # ... this additional step will reject most
+ # type -> supertype matches, such as if we had
+ # a MyInt(int) subclass. note also we pass NewType()
+ # here directly; these always have to be in the
+ # type_annotation_map to be useful
resolved_sql_type = sql_type_inst._resolve_for_python_type(
- python_type_type, pt
+ python_type_type,
+ pt,
+ flattened,
)
if resolved_sql_type is not None:
return resolved_sql_type
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index b5c79b4b9..717e6c0b2 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -59,7 +59,6 @@ from .. import util
from ..engine import processors
from ..util import langhelpers
from ..util import OrderedDict
-from ..util.typing import GenericProtocol
from ..util.typing import Literal
if TYPE_CHECKING:
@@ -69,6 +68,7 @@ if TYPE_CHECKING:
from .schema import MetaData
from .type_api import _BindProcessorType
from .type_api import _ComparatorFactory
+ from .type_api import _MatchedOnType
from .type_api import _ResultProcessorType
from ..engine.interfaces import Dialect
@@ -1493,14 +1493,16 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
return enums, enums
def _resolve_for_literal(self, value: Any) -> Enum:
- typ = self._resolve_for_python_type(type(value), type(value))
+ tv = type(value)
+ typ = self._resolve_for_python_type(tv, tv, tv)
assert typ is not None
return typ
def _resolve_for_python_type(
self,
python_type: Type[Any],
- matched_on: Union[GenericProtocol[Any], Type[Any]],
+ matched_on: _MatchedOnType,
+ matched_on_flattened: Type[Any],
) -> Optional[Enum]:
if not issubclass(python_type, enum.Enum):
return None
diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py
index 79c889763..fefbf4997 100644
--- a/lib/sqlalchemy/sql/type_api.py
+++ b/lib/sqlalchemy/sql/type_api.py
@@ -19,6 +19,7 @@ from typing import cast
from typing import Dict
from typing import Generic
from typing import Mapping
+from typing import NewType
from typing import Optional
from typing import overload
from typing import Sequence
@@ -35,7 +36,6 @@ from .operators import ColumnOperators
from .visitors import Visitable
from .. import exc
from .. import util
-from ..util.typing import flatten_generic
from ..util.typing import Protocol
from ..util.typing import TypedDict
from ..util.typing import TypeGuard
@@ -65,6 +65,8 @@ _O = TypeVar("_O", bound=object)
_TE = TypeVar("_TE", bound="TypeEngine[Any]")
_CT = TypeVar("_CT", bound=Any)
+_MatchedOnType = Union["GenericProtocol[Any]", NewType, Type[Any]]
+
# replace with pep-673 when applicable
SelfTypeEngine = typing.TypeVar("SelfTypeEngine", bound="TypeEngine[Any]")
@@ -731,7 +733,8 @@ class TypeEngine(Visitable, Generic[_T]):
def _resolve_for_python_type(
self: SelfTypeEngine,
python_type: Type[Any],
- matched_on: Union[GenericProtocol[Any], Type[Any]],
+ matched_on: _MatchedOnType,
+ matched_on_flattened: Type[Any],
) -> Optional[SelfTypeEngine]:
"""given a Python type (e.g. ``int``, ``str``, etc. ) return an
instance of this :class:`.TypeEngine` that's appropriate for this type.
@@ -772,9 +775,7 @@ class TypeEngine(Visitable, Generic[_T]):
"""
- matched_on = flatten_generic(matched_on)
-
- if python_type is not matched_on:
+ if python_type is not matched_on_flattened:
return None
return self
diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py
index e1670ed21..51e95ecfa 100644
--- a/lib/sqlalchemy/util/typing.py
+++ b/lib/sqlalchemy/util/typing.py
@@ -18,6 +18,7 @@ from typing import Dict
from typing import ForwardRef
from typing import Generic
from typing import Iterable
+from typing import NewType
from typing import NoReturn
from typing import Optional
from typing import overload
@@ -71,10 +72,9 @@ typing_get_args = get_args
typing_get_origin = get_origin
-# copied from TypeShed, required in order to implement
-# MutableMapping.update()
-
-_AnnotationScanType = Union[Type[Any], str, ForwardRef, "GenericProtocol[Any]"]
+_AnnotationScanType = Union[
+ Type[Any], str, ForwardRef, NewType, "GenericProtocol[Any]"
+]
class ArgsTypeProcotol(Protocol):
@@ -105,6 +105,8 @@ class GenericProtocol(Protocol[_T]):
# ...
+# copied from TypeShed, required in order to implement
+# MutableMapping.update()
class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]):
def keys(self) -> Iterable[_KT]:
...
@@ -247,17 +249,23 @@ def is_pep593(type_: Optional[_AnnotationScanType]) -> bool:
return type_ is not None and typing_get_origin(type_) is Annotated
+def is_newtype(type_: Optional[_AnnotationScanType]) -> TypeGuard[NewType]:
+ return hasattr(type_, "__supertype__")
+
+ # doesn't work in 3.8, 3.7 as it passes a closure, not an
+ # object instance
+ # return isinstance(type_, NewType)
+
+
def is_generic(type_: _AnnotationScanType) -> TypeGuard[GenericProtocol[Any]]:
return hasattr(type_, "__args__") and hasattr(type_, "__origin__")
-def flatten_generic(
- type_: Union[GenericProtocol[Any], Type[Any]]
-) -> Type[Any]:
- if is_generic(type_):
- return type_.__origin__
- else:
- return cast("Type[Any]", type_)
+def flatten_newtype(type_: NewType) -> Type[Any]:
+ super_type = type_.__supertype__
+ while is_newtype(super_type):
+ super_type = super_type.__supertype__
+ return super_type
def is_fwd_ref(