From 26c0e8e1846a4e6ac05c15a1ad188a5655b72edb Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 16 Jul 2022 16:19:15 -0400 Subject: implement column._merge() this takes the user-defined args of one Column and merges them into the not-user-defined args of another Column. Implemented within the pep-593 column transfer operation to begin to make this new feature more robust. work may still be needed for constraints etc. but in theory everything from the left side annotated column should take effect for the right side if not otherwise specified on the right. Change-Id: I57eb37ed6ceb4b60979a35cfc4b63731d990911d --- test/sql/test_metadata.py | 128 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 128 insertions(+) (limited to 'test/sql') diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py index 33b6e130f..b7913e606 100644 --- a/test/sql/test_metadata.py +++ b/test/sql/test_metadata.py @@ -3,6 +3,7 @@ import pickle import sqlalchemy as tsa from sqlalchemy import ARRAY +from sqlalchemy import BigInteger from sqlalchemy import bindparam from sqlalchemy import BLANK_SCHEMA from sqlalchemy import Boolean @@ -10,6 +11,7 @@ from sqlalchemy import CheckConstraint from sqlalchemy import Column from sqlalchemy import column from sqlalchemy import ColumnDefault +from sqlalchemy import Computed from sqlalchemy import desc from sqlalchemy import Enum from sqlalchemy import event @@ -17,9 +19,11 @@ from sqlalchemy import exc from sqlalchemy import ForeignKey from sqlalchemy import ForeignKeyConstraint from sqlalchemy import func +from sqlalchemy import Identity from sqlalchemy import Index from sqlalchemy import Integer from sqlalchemy import MetaData +from sqlalchemy import Numeric from sqlalchemy import PrimaryKeyConstraint from sqlalchemy import schema from sqlalchemy import select @@ -4182,6 +4186,130 @@ class ColumnDefinitionTest(AssertsCompiledSQL, fixtures.TestBase): deregister(schema.CreateColumn) + @testing.combinations( + ("default", lambda ctx: 10), + ("default", func.foo()), + ("identity_gen", Identity()), + ("identity_gen", Sequence("some_seq")), + ("identity_gen", Computed("side * side")), + ("onupdate", lambda ctx: 10), + ("onupdate", func.foo()), + ("server_onupdate", func.foo()), + ("server_default", func.foo()), + ("nullable", True), + ("nullable", False), + ("type", BigInteger()), + ("type", Enum("one", "two", "three", create_constraint=True)), + argnames="paramname, value", + ) + def test_merge_column( + self, + paramname, + value, + ): + + args = [] + params = {} + if paramname == "type" or isinstance( + value, (Computed, Sequence, Identity) + ): + args.append(value) + else: + params[paramname] = value + + source = Column(*args, **params) + + target = Column() + + source._merge(target) + + if isinstance(value, (Computed, Identity)): + default = target.server_default + assert isinstance(default, type(value)) + elif isinstance(value, Sequence): + default = target.default + assert isinstance(default, type(value)) + + elif paramname in ( + "default", + "onupdate", + "server_default", + "server_onupdate", + ): + default = getattr(target, paramname) + is_(default.arg, value) + is_(default.column, target) + elif paramname == "type": + assert type(target.type) is type(value) + + if isinstance(target.type, Enum): + target.name = "data" + t = Table("t", MetaData(), target) + assert CheckConstraint in [type(c) for c in t.constraints] + else: + is_(getattr(target, paramname), value) + + @testing.combinations( + ("default", lambda ctx: 10, lambda ctx: 15), + ("default", func.foo(), func.bar()), + ("identity_gen", Identity(), Identity()), + ("identity_gen", Sequence("some_seq"), Sequence("some_other_seq")), + ("identity_gen", Computed("side * side"), Computed("top / top")), + ("onupdate", lambda ctx: 10, lambda ctx: 15), + ("onupdate", func.foo(), func.bar()), + ("server_onupdate", func.foo(), func.bar()), + ("server_default", func.foo(), func.bar()), + ("nullable", True, False), + ("nullable", False, True), + ("type", BigInteger(), Numeric()), + argnames="paramname, value, override_value", + ) + def test_dont_merge_column( + self, + paramname, + value, + override_value, + ): + + args = [] + params = {} + override_args = [] + override_params = {} + if paramname == "type" or isinstance( + value, (Computed, Sequence, Identity) + ): + args.append(value) + override_args.append(override_value) + else: + params[paramname] = value + override_params[paramname] = override_value + + source = Column(*args, **params) + + target = Column(*override_args, **override_params) + + source._merge(target) + + if isinstance(value, Sequence): + default = target.default + assert default is override_value + elif isinstance(value, (Computed, Identity)): + default = target.server_default + assert default is override_value + elif paramname in ( + "default", + "onupdate", + "server_default", + "server_onupdate", + ): + default = getattr(target, paramname) + is_(default.arg, override_value) + is_(default.column, target) + elif paramname == "type": + assert type(target.type) is type(override_value) + else: + is_(getattr(target, paramname), override_value) + class ColumnDefaultsTest(fixtures.TestBase): -- cgit v1.2.1