summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-09-08 13:19:08 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-09-08 14:15:50 -0400
commitfcd298e1afe9b309de34d28b35e4debc3940d6b9 (patch)
treef65417b854556cf0ff1b199f38e0be942dc1e933
parent36803bc5674e30741021462f77cd91d42b717066 (diff)
downloadsqlalchemy-fcd298e1afe9b309de34d28b35e4debc3940d6b9.tar.gz
additional de-stringify pass for unions
the change in c3cfee5b00a40790c18d took out a pass for de-stringify that broke some un-tested cases for Optional with future annotations mode. Adding tests for this revealed that this was a subset of a more general case where Union is presented with ForwardRefs inside of it matching up within the type map, which wasn't working before either, fixed that as well with an additional de-stringify for elements within the Union. Fixes: #8478 Change-Id: I8804cf6c67f14d10804584e1cddd2cfaa2376654
-rw-r--r--lib/sqlalchemy/orm/properties.py6
-rw-r--r--lib/sqlalchemy/util/typing.py17
-rw-r--r--test/orm/declarative/test_tm_future_annotations.py82
3 files changed, 102 insertions, 3 deletions
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 7d7175678..3d9fe578d 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -51,9 +51,11 @@ from ..sql.schema import Column
from ..sql.schema import SchemaConst
from ..util.typing import de_optionalize_union_types
from ..util.typing import de_stringify_annotation
+from ..util.typing import de_stringify_union_elements
from ..util.typing import is_fwd_ref
from ..util.typing import is_optional_union
from ..util.typing import is_pep593
+from ..util.typing import is_union
from ..util.typing import Self
from ..util.typing import typing_get_args
@@ -655,6 +657,9 @@ class MappedColumn(
if is_fwd_ref(argument):
argument = de_stringify_annotation(cls, argument)
+ if is_union(argument):
+ argument = de_stringify_union_elements(cls, argument)
+
nullable = is_optional_union(argument)
if not self._has_nullable:
@@ -690,6 +695,7 @@ class MappedColumn(
checks = (our_type,)
for check_type in checks:
+
if registry.type_annotation_map:
new_sqltype = registry.type_annotation_map.get(check_type)
if new_sqltype is None:
diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py
index 85c1bae72..a0d59a630 100644
--- a/lib/sqlalchemy/util/typing.py
+++ b/lib/sqlalchemy/util/typing.py
@@ -120,6 +120,19 @@ def de_stringify_annotation(
return annotation # type: ignore
+def de_stringify_union_elements(
+ cls: Type[Any],
+ annotation: _AnnotationScanType,
+ str_cleanup_fn: Optional[Callable[[str], str]] = None,
+) -> Type[Any]:
+ return make_union_type(
+ *[
+ de_stringify_annotation(cls, anno, str_cleanup_fn)
+ for anno in annotation.__args__ # type: ignore
+ ]
+ )
+
+
def is_pep593(type_: Optional[_AnnotationScanType]) -> bool:
return type_ is not None and typing_get_origin(type_) is Annotated
@@ -186,7 +199,7 @@ def expand_unions(
return (type_,)
-def is_optional(type_):
+def is_optional(type_: Any) -> bool:
return is_origin_of(
type_,
"Optional",
@@ -199,7 +212,7 @@ def is_optional_union(type_: Any) -> bool:
return is_optional(type_) and NoneType in typing_get_args(type_)
-def is_union(type_):
+def is_union(type_: Any) -> bool:
return is_origin_of(type_, "Union")
diff --git a/test/orm/declarative/test_tm_future_annotations.py b/test/orm/declarative/test_tm_future_annotations.py
index 74cbebb4d..76ee464fa 100644
--- a/test/orm/declarative/test_tm_future_annotations.py
+++ b/test/orm/declarative/test_tm_future_annotations.py
@@ -1,13 +1,19 @@
from __future__ import annotations
+from decimal import Decimal
from typing import List
+from typing import Optional
from typing import Set
from typing import TypeVar
+from typing import Union
from sqlalchemy import exc
from sqlalchemy import ForeignKey
from sqlalchemy import Integer
+from sqlalchemy import Numeric
+from sqlalchemy import Table
from sqlalchemy.orm import attribute_mapped_collection
+from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import MappedCollection
@@ -16,7 +22,8 @@ from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import is_
from sqlalchemy.testing import is_false
from sqlalchemy.testing import is_true
-from .test_typed_mapping import MappedColumnTest # noqa
+from sqlalchemy.util import compat
+from .test_typed_mapping import MappedColumnTest as _MappedColumnTest
from .test_typed_mapping import RelationshipLHSTest as _RelationshipLHSTest
"""runs the annotation-sensitive tests from test_typed_mappings while
@@ -28,6 +35,79 @@ having ``from __future__ import annotations`` in effect.
_R = TypeVar("_R")
+class MappedColumnTest(_MappedColumnTest):
+ def test_unions(self):
+ our_type = Numeric(10, 2)
+
+ class Base(DeclarativeBase):
+ type_annotation_map = {Union[float, Decimal]: our_type}
+
+ class User(Base):
+ __tablename__ = "users"
+ __table__: Table
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+
+ data: Mapped[Union[float, Decimal]] = mapped_column()
+ reverse_data: Mapped[Union[Decimal, float]] = mapped_column()
+
+ optional_data: Mapped[
+ Optional[Union[float, Decimal]]
+ ] = mapped_column()
+
+ # use Optional directly
+ reverse_optional_data: Mapped[
+ Optional[Union[Decimal, float]]
+ ] = mapped_column()
+
+ # use Union with None, same as Optional but presents differently
+ # (Optional object with __origin__ Union vs. Union)
+ reverse_u_optional_data: Mapped[
+ Union[Decimal, float, None]
+ ] = mapped_column()
+
+ float_data: Mapped[float] = mapped_column()
+ decimal_data: Mapped[Decimal] = mapped_column()
+
+ if compat.py310:
+ pep604_data: Mapped[float | Decimal] = mapped_column()
+ pep604_reverse: Mapped[Decimal | float] = mapped_column()
+ pep604_optional: Mapped[
+ Decimal | float | None
+ ] = mapped_column()
+ pep604_data_fwd: Mapped["float | Decimal"] = mapped_column()
+ pep604_reverse_fwd: Mapped["Decimal | float"] = mapped_column()
+ pep604_optional_fwd: Mapped[
+ "Decimal | float | None"
+ ] = mapped_column()
+
+ is_(User.__table__.c.data.type, our_type)
+ is_false(User.__table__.c.data.nullable)
+ is_(User.__table__.c.reverse_data.type, our_type)
+ is_(User.__table__.c.optional_data.type, our_type)
+ is_true(User.__table__.c.optional_data.nullable)
+
+ is_(User.__table__.c.reverse_optional_data.type, our_type)
+ is_(User.__table__.c.reverse_u_optional_data.type, our_type)
+ is_true(User.__table__.c.reverse_optional_data.nullable)
+ is_true(User.__table__.c.reverse_u_optional_data.nullable)
+
+ is_(User.__table__.c.float_data.type, our_type)
+ is_(User.__table__.c.decimal_data.type, our_type)
+
+ if compat.py310:
+ for suffix in ("", "_fwd"):
+ data_col = User.__table__.c[f"pep604_data{suffix}"]
+ reverse_col = User.__table__.c[f"pep604_reverse{suffix}"]
+ optional_col = User.__table__.c[f"pep604_optional{suffix}"]
+ is_(data_col.type, our_type)
+ is_false(data_col.nullable)
+ is_(reverse_col.type, our_type)
+ is_false(reverse_col.nullable)
+ is_(optional_col.type, our_type)
+ is_true(optional_col.nullable)
+
+
class MappedOneArg(MappedCollection[str, _R]):
pass