diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/orm/properties.py | 10 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/typing.py | 9 |
2 files changed, 11 insertions, 8 deletions
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 6213cfef8..7d7175678 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -52,8 +52,8 @@ from ..sql.schema import SchemaConst from ..util.typing import de_optionalize_union_types from ..util.typing import de_stringify_annotation from ..util.typing import is_fwd_ref +from ..util.typing import is_optional_union from ..util.typing import is_pep593 -from ..util.typing import NoneType from ..util.typing import Self from ..util.typing import typing_get_args @@ -652,17 +652,15 @@ class MappedColumn( ) -> None: sqltype = self.column.type - nullable = False + if is_fwd_ref(argument): + argument = de_stringify_annotation(cls, argument) - if hasattr(argument, "__origin__"): - nullable = NoneType in argument.__args__ # type: ignore + nullable = is_optional_union(argument) if not self._has_nullable: self.column.nullable = nullable our_type = de_optionalize_union_types(argument) - if is_fwd_ref(our_type): - our_type = de_stringify_annotation(cls, our_type) use_args_from = None if is_pep593(our_type): diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 45fe63765..85c1bae72 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -169,7 +169,7 @@ def make_union_type(*types: _AnnotationScanType) -> Type[Any]: def expand_unions( type_: Type[Any], include_union: bool = False, discard_none: bool = False ) -> Tuple[Type[Any], ...]: - """Return a type as as a tuple of individual types, expanding for + """Return a type as a tuple of individual types, expanding for ``Union`` types.""" if is_union(type_): @@ -191,9 +191,14 @@ def is_optional(type_): type_, "Optional", "Union", + "UnionType", ) +def is_optional_union(type_: Any) -> bool: + return is_optional(type_) and NoneType in typing_get_args(type_) + + def is_union(type_): return is_origin_of(type_, "Union") @@ -204,7 +209,7 @@ def is_origin_of( """return True if the given type has an __origin__ with the given name and optional module.""" - origin = getattr(type_, "__origin__", None) + origin = typing_get_origin(type_) if origin is None: return False |
