diff options
Diffstat (limited to 'lib')
| -rw-r--r-- | lib/sqlalchemy/orm/decl_api.py | 14 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/decl_base.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 72 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/typing.py | 10 |
4 files changed, 84 insertions, 15 deletions
diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index c36089fde..e3e2611da 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -77,6 +77,7 @@ from ..util import typing as compat_typing from ..util.typing import CallableReference from ..util.typing import flatten_newtype from ..util.typing import is_generic +from ..util.typing import is_literal from ..util.typing import is_newtype from ..util.typing import Literal @@ -1218,10 +1219,19 @@ class registry: ) -> Optional[sqltypes.TypeEngine[Any]]: search: Iterable[Tuple[_MatchedOnType, Type[Any]]] + python_type_type: Type[Any] if is_generic(python_type): - python_type_type: Type[Any] = python_type.__origin__ - search = ((python_type, python_type_type),) + if is_literal(python_type): + python_type_type = cast("Type[Any]", python_type) + + search = ( # type: ignore[assignment] + (python_type, python_type_type), + (Literal, python_type_type), + ) + else: + python_type_type = python_type.__origin__ + search = ((python_type, python_type_type),) elif is_newtype(python_type): python_type_type = flatten_newtype(python_type) search = ((python_type, python_type_type),) diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index 0462a8945..a858f12cb 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -66,6 +66,7 @@ from ..util import topological from ..util.typing import _AnnotationScanType from ..util.typing import de_stringify_annotation from ..util.typing import is_fwd_ref +from ..util.typing import is_literal from ..util.typing import Protocol from ..util.typing import TypedDict from ..util.typing import typing_get_args @@ -1165,7 +1166,7 @@ class _ClassScanMapperConfig(_MapperConfig): extracted_mapped_annotation, mapped_container = extracted - if attr_value is None: + if attr_value is None and not is_literal(extracted_mapped_annotation): for elem in typing_get_args(extracted_mapped_annotation): if isinstance(elem, str) or is_fwd_ref( elem, check_generic=True diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 717e6c0b2..b2dcc9b8a 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -59,7 +59,9 @@ from .. import util from ..engine import processors from ..util import langhelpers from ..util import OrderedDict +from ..util.typing import is_literal from ..util.typing import Literal +from ..util.typing import typing_get_args if TYPE_CHECKING: from ._typing import _ColumnExpressionArgument @@ -1263,6 +1265,11 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): .. seealso:: + :ref:`orm_declarative_mapped_column_enums` - background on using + the :class:`_sqltypes.Enum` datatype with the ORM's + :ref:`ORM Annotated Declarative <orm_declarative_mapped_column>` + feature. + :class:`_postgresql.ENUM` - PostgreSQL-specific type, which has additional functionality. @@ -1504,16 +1511,54 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): matched_on: _MatchedOnType, matched_on_flattened: Type[Any], ) -> Optional[Enum]: - if not issubclass(python_type, enum.Enum): - return None + + # "generic form" indicates we were placed in a type map + # as ``sqlalchemy.Enum(enum.Enum)`` which indicates we need to + # get enumerated values from the datatype + we_are_generic_form = self._enums_argument == [enum.Enum] + + native_enum = None + + if not we_are_generic_form and python_type is matched_on: + # if we have enumerated values, and the incoming python + # type is exactly the one that matched in the type map, + # then we use these enumerated values and dont try to parse + # what's incoming + enum_args = self._enums_argument + + elif is_literal(python_type): + # for a literal, where we need to get its contents, parse it out. + enum_args = typing_get_args(python_type) + bad_args = [arg for arg in enum_args if not isinstance(arg, str)] + if bad_args: + raise exc.ArgumentError( + f"Can't create string-based Enum datatype from non-string " + f"values: {', '.join(repr(x) for x in bad_args)}. Please " + f"provide an explicit Enum datatype for this Python type" + ) + native_enum = False + elif isinstance(python_type, type) and issubclass( + python_type, enum.Enum + ): + # same for an enum.Enum + enum_args = [python_type] + + else: + enum_args = self._enums_argument + + # make a new Enum that looks like this one. + # pop the "name" so that it gets generated based on the enum + # arguments or other rules + kw = self._make_enum_kw({}) + + kw.pop("name", None) + if native_enum is False: + kw["native_enum"] = False + + kw["length"] = NO_ARG if self.length == 0 else self.length return cast( Enum, - util.constructor_copy( - self, - self._generic_type_affinity, - python_type, - length=NO_ARG if self.length == 0 else self.length, - ), + self._generic_type_affinity(_enums=enum_args, **kw), # type: ignore # noqa: E501 ) def _setup_for_values(self, values, objects, kw): @@ -1622,19 +1667,23 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): self, self._generic_type_affinity, *args, _disable_warnings=True ) - def adapt_to_emulated(self, impltype, **kw): + def _make_enum_kw(self, kw): kw.setdefault("validate_strings", self.validate_strings) kw.setdefault("name", self.name) - kw["_disable_warnings"] = True kw.setdefault("schema", self.schema) kw.setdefault("inherit_schema", self.inherit_schema) kw.setdefault("metadata", self.metadata) - kw.setdefault("_create_events", False) kw.setdefault("native_enum", self.native_enum) kw.setdefault("values_callable", self.values_callable) kw.setdefault("create_constraint", self.create_constraint) kw.setdefault("length", self.length) kw.setdefault("omit_aliases", self._omit_aliases) + return kw + + def adapt_to_emulated(self, impltype, **kw): + self._make_enum_kw(kw) + kw["_disable_warnings"] = True + kw.setdefault("_create_events", False) assert "_enums" in kw return impltype(**kw) @@ -3702,6 +3751,7 @@ _type_map: Dict[Type[Any], TypeEngine[Any]] = { bytes: LargeBinary(), str: _STRING, enum.Enum: Enum(enum.Enum), + Literal: Enum(enum.Enum), # type: ignore[dict-item] } diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 51e95ecfa..755185c9b 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -152,7 +152,11 @@ def de_stringify_annotation( annotation = eval_expression(annotation, originating_module) - if include_generic and is_generic(annotation): + if ( + include_generic + and is_generic(annotation) + and not is_literal(annotation) + ): elements = tuple( de_stringify_annotation( cls, @@ -249,6 +253,10 @@ def is_pep593(type_: Optional[_AnnotationScanType]) -> bool: return type_ is not None and typing_get_origin(type_) is Annotated +def is_literal(type_: _AnnotationScanType) -> bool: + return get_origin(type_) is Literal + + def is_newtype(type_: Optional[_AnnotationScanType]) -> TypeGuard[NewType]: return hasattr(type_, "__supertype__") |
