summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorFederico Caselli <cfederico87@gmail.com>2022-11-27 18:11:34 +0100
committerMike Bayer <mike_mp@zzzcomputing.com>2022-11-29 17:49:27 -0500
commit9c9fd31bcea3beaed6d14fde639e65f6b43bea09 (patch)
tree2eef4b31c1f89f364c9bf15fdf153a4aad0f98c6 /test
parent78833af4e650d37e6257cfbb541e4db56e2a285f (diff)
downloadsqlalchemy-9c9fd31bcea3beaed6d14fde639e65f6b43bea09.tar.gz
Improve support for enum in mapped classes
Add a new system by which TypeEngine objects have some say in how the declarative type registry interprets them. The Enum datatype is the primary target for this but it is hoped the system may be useful for other types as well. Fixes: #8859 Change-Id: I15ac3daee770408b5795746f47c1bbd931b7d26d
Diffstat (limited to 'test')
-rw-r--r--test/orm/declarative/test_tm_future_annotations_sync.py203
-rw-r--r--test/orm/declarative/test_typed_mapping.py203
2 files changed, 406 insertions, 0 deletions
diff --git a/test/orm/declarative/test_tm_future_annotations_sync.py b/test/orm/declarative/test_tm_future_annotations_sync.py
index 7358f385d..5d1b6b199 100644
--- a/test/orm/declarative/test_tm_future_annotations_sync.py
+++ b/test/orm/declarative/test_tm_future_annotations_sync.py
@@ -10,6 +10,7 @@ from __future__ import annotations
import dataclasses
import datetime
from decimal import Decimal
+import enum
from typing import Any
from typing import ClassVar
from typing import Dict
@@ -53,11 +54,13 @@ from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import MappedAsDataclass
from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Session
from sqlalchemy.orm import undefer
from sqlalchemy.orm import WriteOnlyMapped
from sqlalchemy.orm.collections import attribute_keyed_dict
from sqlalchemy.orm.collections import KeyFuncDict
from sqlalchemy.schema import CreateTable
+from sqlalchemy.sql.sqltypes import Enum
from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_raises
@@ -1134,6 +1137,158 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
id: Mapped[int] = mapped_column(primary_key=True)
data: Mapped["fake"] # noqa
+ @testing.variation("use_callable", [True, False])
+ @testing.variation("include_generic", [True, False])
+ def test_enum_explicit(self, use_callable, include_generic):
+ global FooEnum
+
+ class FooEnum(enum.Enum):
+ foo = enum.auto()
+ bar = enum.auto()
+
+ if use_callable:
+ tam = {FooEnum: Enum(FooEnum, length=500)}
+ else:
+ tam = {FooEnum: Enum(FooEnum, length=500)}
+ if include_generic:
+ tam[enum.Enum] = Enum(enum.Enum)
+ Base = declarative_base(type_annotation_map=tam)
+
+ class MyClass(Base):
+ __tablename__ = "mytable"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ data: Mapped[FooEnum]
+
+ is_true(isinstance(MyClass.__table__.c.data.type, Enum))
+ eq_(MyClass.__table__.c.data.type.length, 500)
+ is_(MyClass.__table__.c.data.type.enum_class, FooEnum)
+
+ def test_enum_generic(self):
+ """test for #8859"""
+ global FooEnum
+
+ class FooEnum(enum.Enum):
+ foo = enum.auto()
+ bar = enum.auto()
+
+ Base = declarative_base(
+ type_annotation_map={enum.Enum: Enum(enum.Enum, length=42)}
+ )
+
+ class MyClass(Base):
+ __tablename__ = "mytable"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ data: Mapped[FooEnum]
+
+ is_true(isinstance(MyClass.__table__.c.data.type, Enum))
+ eq_(MyClass.__table__.c.data.type.length, 42)
+ is_(MyClass.__table__.c.data.type.enum_class, FooEnum)
+
+ def test_enum_default(self, decl_base):
+ """test #8859.
+
+ We now have Enum in the default SQL lookup map, in conjunction with
+ a mechanism that will adapt it for a given enum type.
+
+ This relies on a search through __mro__ for the given type,
+ which in other tests we ensure does not actually function if
+ we aren't dealing with Enum (or some other type that allows for
+ __mro__ lookup)
+
+ """
+ global FooEnum
+
+ class FooEnum(enum.Enum):
+ foo = "foo"
+ bar_value = "bar"
+
+ class MyClass(decl_base):
+ __tablename__ = "mytable"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ data: Mapped[FooEnum]
+
+ is_true(isinstance(MyClass.__table__.c.data.type, Enum))
+ eq_(MyClass.__table__.c.data.type.length, 9)
+ is_(MyClass.__table__.c.data.type.enum_class, FooEnum)
+
+ def test_type_dont_mis_resolve_on_superclass(self):
+ """test for #8859.
+
+ For subclasses of a type that's in the map, don't resolve this
+ by default, even though we do a search through __mro__.
+
+ """
+ global int_sub
+
+ class int_sub(int):
+ pass
+
+ Base = declarative_base(
+ type_annotation_map={
+ int: Integer,
+ }
+ )
+
+ with expect_raises_message(
+ sa_exc.ArgumentError, "Could not locate SQLAlchemy Core type"
+ ):
+
+ class MyClass(Base):
+ __tablename__ = "mytable"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ data: Mapped[int_sub]
+
+ @testing.variation(
+ "dict_key", ["typing", ("plain", testing.requires.python310)]
+ )
+ def test_type_dont_mis_resolve_on_non_generic(self, dict_key):
+ """test for #8859.
+
+ For a specific generic type with arguments, don't do any MRO
+ lookup.
+
+ """
+
+ Base = declarative_base(
+ type_annotation_map={
+ dict: String,
+ }
+ )
+
+ with expect_raises_message(
+ sa_exc.ArgumentError, "Could not locate SQLAlchemy Core type"
+ ):
+
+ class MyClass(Base):
+ __tablename__ = "mytable"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+
+ if dict_key.plain:
+ data: Mapped[dict[str, str]]
+ elif dict_key.typing:
+ data: Mapped[Dict[str, str]]
+
+ def test_type_secondary_resolution(self):
+ class MyString(String):
+ def _resolve_for_python_type(self, python_type, matched_type):
+ return String(length=42)
+
+ Base = declarative_base(type_annotation_map={str: MyString})
+
+ class MyClass(Base):
+ __tablename__ = "mytable"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ data: Mapped[str]
+
+ is_true(isinstance(MyClass.__table__.c.data.type, String))
+ eq_(MyClass.__table__.c.data.type.length, 42)
+
class MixinTest(fixtures.TestBase, testing.AssertsCompiledSQL):
__dialect__ = "default"
@@ -2200,3 +2355,51 @@ class GenericMappingQueryTest(AssertsCompiledSQL, fixtures.TestBase):
select(typ).where(typ.key == "x"),
"SELECT xx.id, xx.key, xx.value FROM xx WHERE xx.key = :key_1",
)
+
+
+class BackendTests(fixtures.TestBase):
+ __backend__ = True
+
+ @testing.variation("native_enum", [True, False])
+ @testing.variation("include_column", [True, False])
+ def test_schema_type_actually_works(
+ self, connection, decl_base, include_column, native_enum
+ ):
+ """test that schema type bindings are set up correctly"""
+
+ global Status
+
+ class Status(enum.Enum):
+ PENDING = "pending"
+ RECEIVED = "received"
+ COMPLETED = "completed"
+
+ if not include_column and not native_enum:
+ decl_base.registry.update_type_annotation_map(
+ {enum.Enum: Enum(enum.Enum, native_enum=False)}
+ )
+
+ class SomeClass(decl_base):
+ __tablename__ = "some_table"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+
+ if include_column:
+ status: Mapped[Status] = mapped_column(
+ Enum(Status, native_enum=bool(native_enum))
+ )
+ else:
+ status: Mapped[Status]
+
+ decl_base.metadata.create_all(connection)
+
+ with Session(connection) as sess:
+ sess.add(SomeClass(id=1, status=Status.RECEIVED))
+ sess.commit()
+
+ eq_(
+ sess.scalars(
+ select(SomeClass.status).where(SomeClass.id == 1)
+ ).first(),
+ Status.RECEIVED,
+ )
diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py
index ba099412f..ffa640d4c 100644
--- a/test/orm/declarative/test_typed_mapping.py
+++ b/test/orm/declarative/test_typed_mapping.py
@@ -1,6 +1,7 @@
import dataclasses
import datetime
from decimal import Decimal
+import enum
from typing import Any
from typing import ClassVar
from typing import Dict
@@ -44,11 +45,13 @@ from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import MappedAsDataclass
from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Session
from sqlalchemy.orm import undefer
from sqlalchemy.orm import WriteOnlyMapped
from sqlalchemy.orm.collections import attribute_keyed_dict
from sqlalchemy.orm.collections import KeyFuncDict
from sqlalchemy.schema import CreateTable
+from sqlalchemy.sql.sqltypes import Enum
from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_raises
@@ -1125,6 +1128,158 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
id: Mapped[int] = mapped_column(primary_key=True)
data: Mapped["fake"] # noqa
+ @testing.variation("use_callable", [True, False])
+ @testing.variation("include_generic", [True, False])
+ def test_enum_explicit(self, use_callable, include_generic):
+ # anno only: global FooEnum
+
+ class FooEnum(enum.Enum):
+ foo = enum.auto()
+ bar = enum.auto()
+
+ if use_callable:
+ tam = {FooEnum: Enum(FooEnum, length=500)}
+ else:
+ tam = {FooEnum: Enum(FooEnum, length=500)}
+ if include_generic:
+ tam[enum.Enum] = Enum(enum.Enum)
+ Base = declarative_base(type_annotation_map=tam)
+
+ class MyClass(Base):
+ __tablename__ = "mytable"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ data: Mapped[FooEnum]
+
+ is_true(isinstance(MyClass.__table__.c.data.type, Enum))
+ eq_(MyClass.__table__.c.data.type.length, 500)
+ is_(MyClass.__table__.c.data.type.enum_class, FooEnum)
+
+ def test_enum_generic(self):
+ """test for #8859"""
+ # anno only: global FooEnum
+
+ class FooEnum(enum.Enum):
+ foo = enum.auto()
+ bar = enum.auto()
+
+ Base = declarative_base(
+ type_annotation_map={enum.Enum: Enum(enum.Enum, length=42)}
+ )
+
+ class MyClass(Base):
+ __tablename__ = "mytable"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ data: Mapped[FooEnum]
+
+ is_true(isinstance(MyClass.__table__.c.data.type, Enum))
+ eq_(MyClass.__table__.c.data.type.length, 42)
+ is_(MyClass.__table__.c.data.type.enum_class, FooEnum)
+
+ def test_enum_default(self, decl_base):
+ """test #8859.
+
+ We now have Enum in the default SQL lookup map, in conjunction with
+ a mechanism that will adapt it for a given enum type.
+
+ This relies on a search through __mro__ for the given type,
+ which in other tests we ensure does not actually function if
+ we aren't dealing with Enum (or some other type that allows for
+ __mro__ lookup)
+
+ """
+ # anno only: global FooEnum
+
+ class FooEnum(enum.Enum):
+ foo = "foo"
+ bar_value = "bar"
+
+ class MyClass(decl_base):
+ __tablename__ = "mytable"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ data: Mapped[FooEnum]
+
+ is_true(isinstance(MyClass.__table__.c.data.type, Enum))
+ eq_(MyClass.__table__.c.data.type.length, 9)
+ is_(MyClass.__table__.c.data.type.enum_class, FooEnum)
+
+ def test_type_dont_mis_resolve_on_superclass(self):
+ """test for #8859.
+
+ For subclasses of a type that's in the map, don't resolve this
+ by default, even though we do a search through __mro__.
+
+ """
+ # anno only: global int_sub
+
+ class int_sub(int):
+ pass
+
+ Base = declarative_base(
+ type_annotation_map={
+ int: Integer,
+ }
+ )
+
+ with expect_raises_message(
+ sa_exc.ArgumentError, "Could not locate SQLAlchemy Core type"
+ ):
+
+ class MyClass(Base):
+ __tablename__ = "mytable"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ data: Mapped[int_sub]
+
+ @testing.variation(
+ "dict_key", ["typing", ("plain", testing.requires.python310)]
+ )
+ def test_type_dont_mis_resolve_on_non_generic(self, dict_key):
+ """test for #8859.
+
+ For a specific generic type with arguments, don't do any MRO
+ lookup.
+
+ """
+
+ Base = declarative_base(
+ type_annotation_map={
+ dict: String,
+ }
+ )
+
+ with expect_raises_message(
+ sa_exc.ArgumentError, "Could not locate SQLAlchemy Core type"
+ ):
+
+ class MyClass(Base):
+ __tablename__ = "mytable"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+
+ if dict_key.plain:
+ data: Mapped[dict[str, str]]
+ elif dict_key.typing:
+ data: Mapped[Dict[str, str]]
+
+ def test_type_secondary_resolution(self):
+ class MyString(String):
+ def _resolve_for_python_type(self, python_type, matched_type):
+ return String(length=42)
+
+ Base = declarative_base(type_annotation_map={str: MyString})
+
+ class MyClass(Base):
+ __tablename__ = "mytable"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ data: Mapped[str]
+
+ is_true(isinstance(MyClass.__table__.c.data.type, String))
+ eq_(MyClass.__table__.c.data.type.length, 42)
+
class MixinTest(fixtures.TestBase, testing.AssertsCompiledSQL):
__dialect__ = "default"
@@ -2191,3 +2346,51 @@ class GenericMappingQueryTest(AssertsCompiledSQL, fixtures.TestBase):
select(typ).where(typ.key == "x"),
"SELECT xx.id, xx.key, xx.value FROM xx WHERE xx.key = :key_1",
)
+
+
+class BackendTests(fixtures.TestBase):
+ __backend__ = True
+
+ @testing.variation("native_enum", [True, False])
+ @testing.variation("include_column", [True, False])
+ def test_schema_type_actually_works(
+ self, connection, decl_base, include_column, native_enum
+ ):
+ """test that schema type bindings are set up correctly"""
+
+ # anno only: global Status
+
+ class Status(enum.Enum):
+ PENDING = "pending"
+ RECEIVED = "received"
+ COMPLETED = "completed"
+
+ if not include_column and not native_enum:
+ decl_base.registry.update_type_annotation_map(
+ {enum.Enum: Enum(enum.Enum, native_enum=False)}
+ )
+
+ class SomeClass(decl_base):
+ __tablename__ = "some_table"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+
+ if include_column:
+ status: Mapped[Status] = mapped_column(
+ Enum(Status, native_enum=bool(native_enum))
+ )
+ else:
+ status: Mapped[Status]
+
+ decl_base.metadata.create_all(connection)
+
+ with Session(connection) as sess:
+ sess.add(SomeClass(id=1, status=Status.RECEIVED))
+ sess.commit()
+
+ eq_(
+ sess.scalars(
+ select(SomeClass.status).where(SomeClass.id == 1)
+ ).first(),
+ Status.RECEIVED,
+ )