summaryrefslogtreecommitdiff
path: root/test/engine
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2022-03-25 19:35:32 +0000
committerGerrit Code Review <gerrit@ci3.zzzcomputing.com>2022-03-25 19:35:32 +0000
commit59308fd60fd8f7750205063134df4b944d12704b (patch)
treee51596b2f6772322acffb8ab266dd6ca0e394f78 /test/engine
parent221aff778e1eb3c3aa8f8a1f72629177442694bc (diff)
parent57877461c1bd3b43a9d833fbca873d59db36b6f7 (diff)
downloadsqlalchemy-59308fd60fd8f7750205063134df4b944d12704b.tar.gz
Merge "generalize conditional DDL throughout schema / DDL" into main
Diffstat (limited to 'test/engine')
-rw-r--r--test/engine/test_ddlevents.py121
1 files changed, 120 insertions, 1 deletions
diff --git a/test/engine/test_ddlevents.py b/test/engine/test_ddlevents.py
index f339ef171..0c72e32c7 100644
--- a/test/engine/test_ddlevents.py
+++ b/test/engine/test_ddlevents.py
@@ -1,7 +1,11 @@
+from unittest import mock
+from unittest.mock import Mock
+
import sqlalchemy as tsa
from sqlalchemy import create_engine
from sqlalchemy import create_mock_engine
from sqlalchemy import event
+from sqlalchemy import Index
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import String
@@ -373,7 +377,7 @@ class DDLEventTest(fixtures.TestBase):
eq_(metadata_canary.mock_calls, [])
-class DDLExecutionTest(fixtures.TestBase):
+class DDLExecutionTest(AssertsCompiledSQL, fixtures.TestBase):
def setup_test(self):
self.engine = engines.mock_engine()
self.metadata = MetaData()
@@ -485,6 +489,121 @@ class DDLExecutionTest(fixtures.TestBase):
strings = " ".join(str(x) for x in pg_mock.mock)
assert "my_test_constraint" in strings
+ @testing.combinations(("dialect",), ("callable",), ("callable_w_state",))
+ def test_inline_ddl_if_dialect_name(self, ddl_if_type):
+ nonpg_mock = engines.mock_engine(dialect_name="sqlite")
+ pg_mock = engines.mock_engine(dialect_name="postgresql")
+
+ metadata = MetaData()
+
+ capture_mock = Mock()
+ state = object()
+
+ if ddl_if_type == "dialect":
+ ddl_kwargs = dict(dialect="postgresql")
+ elif ddl_if_type == "callable":
+
+ def is_pg(ddl, target, bind, **kw):
+ capture_mock.is_pg(ddl, target, bind, **kw)
+ return kw["dialect"].name == "postgresql"
+
+ ddl_kwargs = dict(callable_=is_pg)
+ elif ddl_if_type == "callable_w_state":
+
+ def is_pg(ddl, target, bind, **kw):
+ capture_mock.is_pg(ddl, target, bind, **kw)
+ return kw["dialect"].name == "postgresql"
+
+ ddl_kwargs = dict(callable_=is_pg, state=state)
+ else:
+ assert False
+
+ data_col = Column("data", String)
+ t = Table(
+ "a",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("num", Integer),
+ data_col,
+ Index("my_pg_index", data_col).ddl_if(**ddl_kwargs),
+ CheckConstraint("num > 5").ddl_if(**ddl_kwargs),
+ )
+
+ metadata.create_all(nonpg_mock)
+ eq_(len(nonpg_mock.mock), 1)
+ self.assert_compile(
+ nonpg_mock.mock[0],
+ "CREATE TABLE a (id INTEGER NOT NULL, num INTEGER, "
+ "data VARCHAR, PRIMARY KEY (id))",
+ dialect=nonpg_mock.dialect,
+ )
+
+ metadata.create_all(pg_mock)
+
+ eq_(len(pg_mock.mock), 2)
+
+ self.assert_compile(
+ pg_mock.mock[0],
+ "CREATE TABLE a (id SERIAL NOT NULL, num INTEGER, "
+ "data VARCHAR, PRIMARY KEY (id), CHECK (num > 5))",
+ dialect=pg_mock.dialect,
+ )
+ self.assert_compile(
+ pg_mock.mock[1],
+ "CREATE INDEX my_pg_index ON a (data)",
+ dialect="postgresql",
+ )
+
+ the_index = list(t.indexes)[0]
+ the_constraint = list(
+ c for c in t.constraints if isinstance(c, CheckConstraint)
+ )[0]
+
+ if ddl_if_type in ("callable", "callable_w_state"):
+
+ if ddl_if_type == "callable":
+ check_state = None
+ else:
+ check_state = state
+
+ eq_(
+ capture_mock.mock_calls,
+ [
+ mock.call.is_pg(
+ mock.ANY,
+ the_index,
+ mock.ANY,
+ state=check_state,
+ dialect=nonpg_mock.dialect,
+ compiler=None,
+ ),
+ mock.call.is_pg(
+ mock.ANY,
+ the_constraint,
+ None,
+ state=check_state,
+ dialect=nonpg_mock.dialect,
+ compiler=mock.ANY,
+ ),
+ mock.call.is_pg(
+ mock.ANY,
+ the_index,
+ mock.ANY,
+ state=check_state,
+ dialect=pg_mock.dialect,
+ compiler=None,
+ ),
+ mock.call.is_pg(
+ mock.ANY,
+ the_constraint,
+ None,
+ state=check_state,
+ dialect=pg_mock.dialect,
+ compiler=mock.ANY,
+ ),
+ ],
+ )
+
@testing.requires.sqlite
def test_ddl_execute(self):
engine = create_engine("sqlite:///")