summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/orm/properties.py10
-rw-r--r--lib/sqlalchemy/util/typing.py9
2 files changed, 11 insertions, 8 deletions
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 6213cfef8..7d7175678 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -52,8 +52,8 @@ from ..sql.schema import SchemaConst
from ..util.typing import de_optionalize_union_types
from ..util.typing import de_stringify_annotation
from ..util.typing import is_fwd_ref
+from ..util.typing import is_optional_union
from ..util.typing import is_pep593
-from ..util.typing import NoneType
from ..util.typing import Self
from ..util.typing import typing_get_args
@@ -652,17 +652,15 @@ class MappedColumn(
) -> None:
sqltype = self.column.type
- nullable = False
+ if is_fwd_ref(argument):
+ argument = de_stringify_annotation(cls, argument)
- if hasattr(argument, "__origin__"):
- nullable = NoneType in argument.__args__ # type: ignore
+ nullable = is_optional_union(argument)
if not self._has_nullable:
self.column.nullable = nullable
our_type = de_optionalize_union_types(argument)
- if is_fwd_ref(our_type):
- our_type = de_stringify_annotation(cls, our_type)
use_args_from = None
if is_pep593(our_type):
diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py
index 45fe63765..85c1bae72 100644
--- a/lib/sqlalchemy/util/typing.py
+++ b/lib/sqlalchemy/util/typing.py
@@ -169,7 +169,7 @@ def make_union_type(*types: _AnnotationScanType) -> Type[Any]:
def expand_unions(
type_: Type[Any], include_union: bool = False, discard_none: bool = False
) -> Tuple[Type[Any], ...]:
- """Return a type as as a tuple of individual types, expanding for
+ """Return a type as a tuple of individual types, expanding for
``Union`` types."""
if is_union(type_):
@@ -191,9 +191,14 @@ def is_optional(type_):
type_,
"Optional",
"Union",
+ "UnionType",
)
+def is_optional_union(type_: Any) -> bool:
+ return is_optional(type_) and NoneType in typing_get_args(type_)
+
+
def is_union(type_):
return is_origin_of(type_, "Union")
@@ -204,7 +209,7 @@ def is_origin_of(
"""return True if the given type has an __origin__ with the given name
and optional module."""
- origin = getattr(type_, "__origin__", None)
+ origin = typing_get_origin(type_)
if origin is None:
return False