diff options
| author | Federico Caselli <cfederico87@gmail.com> | 2022-11-27 18:11:34 +0100 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-11-29 17:49:27 -0500 |
| commit | 9c9fd31bcea3beaed6d14fde639e65f6b43bea09 (patch) | |
| tree | 2eef4b31c1f89f364c9bf15fdf153a4aad0f98c6 /test | |
| parent | 78833af4e650d37e6257cfbb541e4db56e2a285f (diff) | |
| download | sqlalchemy-9c9fd31bcea3beaed6d14fde639e65f6b43bea09.tar.gz | |
Improve support for enum in mapped classes
Add a new system by which TypeEngine objects have some
say in how the declarative type registry interprets them.
The Enum datatype is the primary target for this but it is
hoped the system may be useful for other types as well.
Fixes: #8859
Change-Id: I15ac3daee770408b5795746f47c1bbd931b7d26d
Diffstat (limited to 'test')
| -rw-r--r-- | test/orm/declarative/test_tm_future_annotations_sync.py | 203 | ||||
| -rw-r--r-- | test/orm/declarative/test_typed_mapping.py | 203 |
2 files changed, 406 insertions, 0 deletions
diff --git a/test/orm/declarative/test_tm_future_annotations_sync.py b/test/orm/declarative/test_tm_future_annotations_sync.py index 7358f385d..5d1b6b199 100644 --- a/test/orm/declarative/test_tm_future_annotations_sync.py +++ b/test/orm/declarative/test_tm_future_annotations_sync.py @@ -10,6 +10,7 @@ from __future__ import annotations import dataclasses import datetime from decimal import Decimal +import enum from typing import Any from typing import ClassVar from typing import Dict @@ -53,11 +54,13 @@ from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import MappedAsDataclass from sqlalchemy.orm import relationship +from sqlalchemy.orm import Session from sqlalchemy.orm import undefer from sqlalchemy.orm import WriteOnlyMapped from sqlalchemy.orm.collections import attribute_keyed_dict from sqlalchemy.orm.collections import KeyFuncDict from sqlalchemy.schema import CreateTable +from sqlalchemy.sql.sqltypes import Enum from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises @@ -1134,6 +1137,158 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): id: Mapped[int] = mapped_column(primary_key=True) data: Mapped["fake"] # noqa + @testing.variation("use_callable", [True, False]) + @testing.variation("include_generic", [True, False]) + def test_enum_explicit(self, use_callable, include_generic): + global FooEnum + + class FooEnum(enum.Enum): + foo = enum.auto() + bar = enum.auto() + + if use_callable: + tam = {FooEnum: Enum(FooEnum, length=500)} + else: + tam = {FooEnum: Enum(FooEnum, length=500)} + if include_generic: + tam[enum.Enum] = Enum(enum.Enum) + Base = declarative_base(type_annotation_map=tam) + + class MyClass(Base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[FooEnum] + + is_true(isinstance(MyClass.__table__.c.data.type, Enum)) + eq_(MyClass.__table__.c.data.type.length, 500) + is_(MyClass.__table__.c.data.type.enum_class, FooEnum) + + def test_enum_generic(self): + """test for #8859""" + global FooEnum + + class FooEnum(enum.Enum): + foo = enum.auto() + bar = enum.auto() + + Base = declarative_base( + type_annotation_map={enum.Enum: Enum(enum.Enum, length=42)} + ) + + class MyClass(Base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[FooEnum] + + is_true(isinstance(MyClass.__table__.c.data.type, Enum)) + eq_(MyClass.__table__.c.data.type.length, 42) + is_(MyClass.__table__.c.data.type.enum_class, FooEnum) + + def test_enum_default(self, decl_base): + """test #8859. + + We now have Enum in the default SQL lookup map, in conjunction with + a mechanism that will adapt it for a given enum type. + + This relies on a search through __mro__ for the given type, + which in other tests we ensure does not actually function if + we aren't dealing with Enum (or some other type that allows for + __mro__ lookup) + + """ + global FooEnum + + class FooEnum(enum.Enum): + foo = "foo" + bar_value = "bar" + + class MyClass(decl_base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[FooEnum] + + is_true(isinstance(MyClass.__table__.c.data.type, Enum)) + eq_(MyClass.__table__.c.data.type.length, 9) + is_(MyClass.__table__.c.data.type.enum_class, FooEnum) + + def test_type_dont_mis_resolve_on_superclass(self): + """test for #8859. + + For subclasses of a type that's in the map, don't resolve this + by default, even though we do a search through __mro__. + + """ + global int_sub + + class int_sub(int): + pass + + Base = declarative_base( + type_annotation_map={ + int: Integer, + } + ) + + with expect_raises_message( + sa_exc.ArgumentError, "Could not locate SQLAlchemy Core type" + ): + + class MyClass(Base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[int_sub] + + @testing.variation( + "dict_key", ["typing", ("plain", testing.requires.python310)] + ) + def test_type_dont_mis_resolve_on_non_generic(self, dict_key): + """test for #8859. + + For a specific generic type with arguments, don't do any MRO + lookup. + + """ + + Base = declarative_base( + type_annotation_map={ + dict: String, + } + ) + + with expect_raises_message( + sa_exc.ArgumentError, "Could not locate SQLAlchemy Core type" + ): + + class MyClass(Base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + + if dict_key.plain: + data: Mapped[dict[str, str]] + elif dict_key.typing: + data: Mapped[Dict[str, str]] + + def test_type_secondary_resolution(self): + class MyString(String): + def _resolve_for_python_type(self, python_type, matched_type): + return String(length=42) + + Base = declarative_base(type_annotation_map={str: MyString}) + + class MyClass(Base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] + + is_true(isinstance(MyClass.__table__.c.data.type, String)) + eq_(MyClass.__table__.c.data.type.length, 42) + class MixinTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "default" @@ -2200,3 +2355,51 @@ class GenericMappingQueryTest(AssertsCompiledSQL, fixtures.TestBase): select(typ).where(typ.key == "x"), "SELECT xx.id, xx.key, xx.value FROM xx WHERE xx.key = :key_1", ) + + +class BackendTests(fixtures.TestBase): + __backend__ = True + + @testing.variation("native_enum", [True, False]) + @testing.variation("include_column", [True, False]) + def test_schema_type_actually_works( + self, connection, decl_base, include_column, native_enum + ): + """test that schema type bindings are set up correctly""" + + global Status + + class Status(enum.Enum): + PENDING = "pending" + RECEIVED = "received" + COMPLETED = "completed" + + if not include_column and not native_enum: + decl_base.registry.update_type_annotation_map( + {enum.Enum: Enum(enum.Enum, native_enum=False)} + ) + + class SomeClass(decl_base): + __tablename__ = "some_table" + + id: Mapped[int] = mapped_column(primary_key=True) + + if include_column: + status: Mapped[Status] = mapped_column( + Enum(Status, native_enum=bool(native_enum)) + ) + else: + status: Mapped[Status] + + decl_base.metadata.create_all(connection) + + with Session(connection) as sess: + sess.add(SomeClass(id=1, status=Status.RECEIVED)) + sess.commit() + + eq_( + sess.scalars( + select(SomeClass.status).where(SomeClass.id == 1) + ).first(), + Status.RECEIVED, + ) diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index ba099412f..ffa640d4c 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -1,6 +1,7 @@ import dataclasses import datetime from decimal import Decimal +import enum from typing import Any from typing import ClassVar from typing import Dict @@ -44,11 +45,13 @@ from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import MappedAsDataclass from sqlalchemy.orm import relationship +from sqlalchemy.orm import Session from sqlalchemy.orm import undefer from sqlalchemy.orm import WriteOnlyMapped from sqlalchemy.orm.collections import attribute_keyed_dict from sqlalchemy.orm.collections import KeyFuncDict from sqlalchemy.schema import CreateTable +from sqlalchemy.sql.sqltypes import Enum from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises @@ -1125,6 +1128,158 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): id: Mapped[int] = mapped_column(primary_key=True) data: Mapped["fake"] # noqa + @testing.variation("use_callable", [True, False]) + @testing.variation("include_generic", [True, False]) + def test_enum_explicit(self, use_callable, include_generic): + # anno only: global FooEnum + + class FooEnum(enum.Enum): + foo = enum.auto() + bar = enum.auto() + + if use_callable: + tam = {FooEnum: Enum(FooEnum, length=500)} + else: + tam = {FooEnum: Enum(FooEnum, length=500)} + if include_generic: + tam[enum.Enum] = Enum(enum.Enum) + Base = declarative_base(type_annotation_map=tam) + + class MyClass(Base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[FooEnum] + + is_true(isinstance(MyClass.__table__.c.data.type, Enum)) + eq_(MyClass.__table__.c.data.type.length, 500) + is_(MyClass.__table__.c.data.type.enum_class, FooEnum) + + def test_enum_generic(self): + """test for #8859""" + # anno only: global FooEnum + + class FooEnum(enum.Enum): + foo = enum.auto() + bar = enum.auto() + + Base = declarative_base( + type_annotation_map={enum.Enum: Enum(enum.Enum, length=42)} + ) + + class MyClass(Base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[FooEnum] + + is_true(isinstance(MyClass.__table__.c.data.type, Enum)) + eq_(MyClass.__table__.c.data.type.length, 42) + is_(MyClass.__table__.c.data.type.enum_class, FooEnum) + + def test_enum_default(self, decl_base): + """test #8859. + + We now have Enum in the default SQL lookup map, in conjunction with + a mechanism that will adapt it for a given enum type. + + This relies on a search through __mro__ for the given type, + which in other tests we ensure does not actually function if + we aren't dealing with Enum (or some other type that allows for + __mro__ lookup) + + """ + # anno only: global FooEnum + + class FooEnum(enum.Enum): + foo = "foo" + bar_value = "bar" + + class MyClass(decl_base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[FooEnum] + + is_true(isinstance(MyClass.__table__.c.data.type, Enum)) + eq_(MyClass.__table__.c.data.type.length, 9) + is_(MyClass.__table__.c.data.type.enum_class, FooEnum) + + def test_type_dont_mis_resolve_on_superclass(self): + """test for #8859. + + For subclasses of a type that's in the map, don't resolve this + by default, even though we do a search through __mro__. + + """ + # anno only: global int_sub + + class int_sub(int): + pass + + Base = declarative_base( + type_annotation_map={ + int: Integer, + } + ) + + with expect_raises_message( + sa_exc.ArgumentError, "Could not locate SQLAlchemy Core type" + ): + + class MyClass(Base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[int_sub] + + @testing.variation( + "dict_key", ["typing", ("plain", testing.requires.python310)] + ) + def test_type_dont_mis_resolve_on_non_generic(self, dict_key): + """test for #8859. + + For a specific generic type with arguments, don't do any MRO + lookup. + + """ + + Base = declarative_base( + type_annotation_map={ + dict: String, + } + ) + + with expect_raises_message( + sa_exc.ArgumentError, "Could not locate SQLAlchemy Core type" + ): + + class MyClass(Base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + + if dict_key.plain: + data: Mapped[dict[str, str]] + elif dict_key.typing: + data: Mapped[Dict[str, str]] + + def test_type_secondary_resolution(self): + class MyString(String): + def _resolve_for_python_type(self, python_type, matched_type): + return String(length=42) + + Base = declarative_base(type_annotation_map={str: MyString}) + + class MyClass(Base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] + + is_true(isinstance(MyClass.__table__.c.data.type, String)) + eq_(MyClass.__table__.c.data.type.length, 42) + class MixinTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "default" @@ -2191,3 +2346,51 @@ class GenericMappingQueryTest(AssertsCompiledSQL, fixtures.TestBase): select(typ).where(typ.key == "x"), "SELECT xx.id, xx.key, xx.value FROM xx WHERE xx.key = :key_1", ) + + +class BackendTests(fixtures.TestBase): + __backend__ = True + + @testing.variation("native_enum", [True, False]) + @testing.variation("include_column", [True, False]) + def test_schema_type_actually_works( + self, connection, decl_base, include_column, native_enum + ): + """test that schema type bindings are set up correctly""" + + # anno only: global Status + + class Status(enum.Enum): + PENDING = "pending" + RECEIVED = "received" + COMPLETED = "completed" + + if not include_column and not native_enum: + decl_base.registry.update_type_annotation_map( + {enum.Enum: Enum(enum.Enum, native_enum=False)} + ) + + class SomeClass(decl_base): + __tablename__ = "some_table" + + id: Mapped[int] = mapped_column(primary_key=True) + + if include_column: + status: Mapped[Status] = mapped_column( + Enum(Status, native_enum=bool(native_enum)) + ) + else: + status: Mapped[Status] + + decl_base.metadata.create_all(connection) + + with Session(connection) as sess: + sess.add(SomeClass(id=1, status=Status.RECEIVED)) + sess.commit() + + eq_( + sess.scalars( + select(SomeClass.status).where(SomeClass.id == 1) + ).first(), + Status.RECEIVED, + ) |
