summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2022-12-17 02:02:33 +0000
committerGerrit Code Review <gerrit@ci3.zzzcomputing.com>2022-12-17 02:02:33 +0000
commite7e51af5b61c49d5198e31dfd0ef04e8941551eb (patch)
treea47665bbcda2a450aec8c3a8ff30a5d7873f5988
parente84cc158c469f17c90f2e058ed72595bc3be5cdb (diff)
parentf8fd9ce23350c1f8fad13ff78506b100670a5896 (diff)
downloadsqlalchemy-e7e51af5b61c49d5198e31dfd0ef04e8941551eb.tar.gz
Merge "ensure all visit methods accept **kw" into main
-rw-r--r--doc/build/changelog/unreleased_20/8988.rst8
-rw-r--r--lib/sqlalchemy/dialects/mssql/base.py24
-rw-r--r--lib/sqlalchemy/dialects/mysql/base.py14
-rw-r--r--lib/sqlalchemy/dialects/oracle/base.py8
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py20
-rw-r--r--lib/sqlalchemy/dialects/sqlite/base.py16
-rw-r--r--lib/sqlalchemy/sql/compiler.py14
-rw-r--r--lib/sqlalchemy/testing/suite/test_dialect.py51
8 files changed, 107 insertions, 48 deletions
diff --git a/doc/build/changelog/unreleased_20/8988.rst b/doc/build/changelog/unreleased_20/8988.rst
new file mode 100644
index 000000000..b5300c1b4
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/8988.rst
@@ -0,0 +1,8 @@
+.. change::
+ :tags: bug, sql
+ :tickets: 8988
+
+ Added test support to ensure that all compiler ``visit_xyz()`` methods
+ across all :class:`.Compiler` implementations in SQLAlchemy accept a
+ ``**kw`` parameter, so that all compilers accept additional keyword
+ arguments under all circumstances.
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index aa640727f..08b76206a 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -2237,12 +2237,12 @@ class MSSQLCompiler(compiler.SQLCompiler):
field = self.extract_map.get(extract.field, extract.field)
return "DATEPART(%s, %s)" % (field, self.process(extract.expr, **kw))
- def visit_savepoint(self, savepoint_stmt):
+ def visit_savepoint(self, savepoint_stmt, **kw):
return "SAVE TRANSACTION %s" % self.preparer.format_savepoint(
savepoint_stmt
)
- def visit_rollback_to_savepoint(self, savepoint_stmt):
+ def visit_rollback_to_savepoint(self, savepoint_stmt, **kw):
return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint(
savepoint_stmt
)
@@ -2393,7 +2393,7 @@ class MSSQLCompiler(compiler.SQLCompiler):
for t in [from_table] + extra_froms
)
- def visit_empty_set_expr(self, type_):
+ def visit_empty_set_expr(self, type_, **kw):
return "SELECT 1 WHERE 1!=1"
def visit_is_distinct_from_binary(self, binary, operator, **kw):
@@ -2581,7 +2581,7 @@ class MSDDLCompiler(compiler.DDLCompiler):
return colspec
- def visit_create_index(self, create, include_schema=False):
+ def visit_create_index(self, create, include_schema=False, **kw):
index = create.element
self._verify_index_table(index)
preparer = self.preparer
@@ -2633,13 +2633,13 @@ class MSDDLCompiler(compiler.DDLCompiler):
return text
- def visit_drop_index(self, drop):
+ def visit_drop_index(self, drop, **kw):
return "\nDROP INDEX %s ON %s" % (
self._prepared_index_name(drop.element, include_schema=False),
self.preparer.format_table(drop.element.table),
)
- def visit_primary_key_constraint(self, constraint):
+ def visit_primary_key_constraint(self, constraint, **kw):
if len(constraint) == 0:
return ""
text = ""
@@ -2662,7 +2662,7 @@ class MSDDLCompiler(compiler.DDLCompiler):
text += self.define_constraint_deferrability(constraint)
return text
- def visit_unique_constraint(self, constraint):
+ def visit_unique_constraint(self, constraint, **kw):
if len(constraint) == 0:
return ""
text = ""
@@ -2685,7 +2685,7 @@ class MSDDLCompiler(compiler.DDLCompiler):
text += self.define_constraint_deferrability(constraint)
return text
- def visit_computed_column(self, generated):
+ def visit_computed_column(self, generated, **kw):
text = "AS (%s)" % self.sql_compiler.process(
generated.sqltext, include_table=False, literal_binds=True
)
@@ -2694,7 +2694,7 @@ class MSDDLCompiler(compiler.DDLCompiler):
text += " PERSISTED"
return text
- def visit_set_table_comment(self, create):
+ def visit_set_table_comment(self, create, **kw):
schema = self.preparer.schema_for_object(create.element)
schema_name = schema if schema else self.dialect.default_schema_name
return (
@@ -2708,7 +2708,7 @@ class MSDDLCompiler(compiler.DDLCompiler):
)
)
- def visit_drop_table_comment(self, drop):
+ def visit_drop_table_comment(self, drop, **kw):
schema = self.preparer.schema_for_object(drop.element)
schema_name = schema if schema else self.dialect.default_schema_name
return (
@@ -2719,7 +2719,7 @@ class MSDDLCompiler(compiler.DDLCompiler):
)
)
- def visit_set_column_comment(self, create):
+ def visit_set_column_comment(self, create, **kw):
schema = self.preparer.schema_for_object(create.element.table)
schema_name = schema if schema else self.dialect.default_schema_name
return (
@@ -2736,7 +2736,7 @@ class MSDDLCompiler(compiler.DDLCompiler):
)
)
- def visit_drop_column_comment(self, drop):
+ def visit_drop_column_comment(self, drop, **kw):
schema = self.preparer.schema_for_object(drop.element.table)
schema_name = schema if schema else self.dialect.default_schema_name
return (
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index 2525c6c32..f965eac15 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -1663,7 +1663,7 @@ class MySQLCompiler(compiler.SQLCompiler):
for t in [from_table] + extra_froms
)
- def visit_empty_set_expr(self, element_types):
+ def visit_empty_set_expr(self, element_types, **kw):
return (
"SELECT %(outer)s FROM (SELECT %(inner)s) "
"as _empty_set WHERE 1!=1"
@@ -1962,14 +1962,14 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
return text
- def visit_primary_key_constraint(self, constraint):
+ def visit_primary_key_constraint(self, constraint, **kw):
text = super().visit_primary_key_constraint(constraint)
using = constraint.dialect_options["mysql"]["using"]
if using:
text += " USING %s" % (self.preparer.quote(using))
return text
- def visit_drop_index(self, drop):
+ def visit_drop_index(self, drop, **kw):
index = drop.element
text = "\nDROP INDEX "
if drop.if_exists:
@@ -1980,7 +1980,7 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
self.preparer.format_table(index.table),
)
- def visit_drop_constraint(self, drop):
+ def visit_drop_constraint(self, drop, **kw):
constraint = drop.element
if isinstance(constraint, sa_schema.ForeignKeyConstraint):
qual = "FOREIGN KEY "
@@ -2014,7 +2014,7 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
)
return ""
- def visit_set_table_comment(self, create):
+ def visit_set_table_comment(self, create, **kw):
return "ALTER TABLE %s COMMENT %s" % (
self.preparer.format_table(create.element),
self.sql_compiler.render_literal_value(
@@ -2022,12 +2022,12 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
),
)
- def visit_drop_table_comment(self, create):
+ def visit_drop_table_comment(self, create, **kw):
return "ALTER TABLE %s COMMENT ''" % (
self.preparer.format_table(create.element)
)
- def visit_set_column_comment(self, create):
+ def visit_set_column_comment(self, create, **kw):
return "ALTER TABLE %s CHANGE %s %s" % (
self.preparer.format_table(create.element.table),
self.preparer.format_column(create.element),
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py
index dc2b011af..d6f65e5ed 100644
--- a/lib/sqlalchemy/dialects/oracle/base.py
+++ b/lib/sqlalchemy/dialects/oracle/base.py
@@ -1189,7 +1189,7 @@ class OracleCompiler(compiler.SQLCompiler):
def limit_clause(self, select, **kw):
return ""
- def visit_empty_set_expr(self, type_):
+ def visit_empty_set_expr(self, type_, **kw):
return "SELECT 1 FROM DUAL WHERE 1!=1"
def for_update_clause(self, select, **kw):
@@ -1279,12 +1279,12 @@ class OracleDDLCompiler(compiler.DDLCompiler):
return text
- def visit_drop_table_comment(self, drop):
+ def visit_drop_table_comment(self, drop, **kw):
return "COMMENT ON TABLE %s IS ''" % self.preparer.format_table(
drop.element
)
- def visit_create_index(self, create):
+ def visit_create_index(self, create, **kw):
index = create.element
self._verify_index_table(index)
preparer = self.preparer
@@ -1336,7 +1336,7 @@ class OracleDDLCompiler(compiler.DDLCompiler):
text = text.replace("NO ORDER", "NOORDER")
return text
- def visit_computed_column(self, generated):
+ def visit_computed_column(self, generated, **kw):
text = "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process(
generated.sqltext, include_table=False, literal_binds=True
)
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 8287e828a..3fb29812b 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -1836,7 +1836,7 @@ class PGCompiler(compiler.SQLCompiler):
self.process(flags, **kw),
)
- def visit_empty_set_expr(self, element_types):
+ def visit_empty_set_expr(self, element_types, **kw):
# cast the empty set to the type we are comparing against. if
# we are comparing against the null type, pick an arbitrary
# datatype for the empty set
@@ -2144,7 +2144,7 @@ class PGDDLCompiler(compiler.DDLCompiler):
not_valid = constraint.dialect_options["postgresql"]["not_valid"]
return " NOT VALID" if not_valid else ""
- def visit_check_constraint(self, constraint):
+ def visit_check_constraint(self, constraint, **kw):
if constraint._type_bound:
typ = list(constraint.columns)[0].type
if (
@@ -2162,12 +2162,12 @@ class PGDDLCompiler(compiler.DDLCompiler):
text += self._define_constraint_validity(constraint)
return text
- def visit_foreign_key_constraint(self, constraint):
+ def visit_foreign_key_constraint(self, constraint, **kw):
text = super().visit_foreign_key_constraint(constraint)
text += self._define_constraint_validity(constraint)
return text
- def visit_create_enum_type(self, create):
+ def visit_create_enum_type(self, create, **kw):
type_ = create.element
return "CREATE TYPE %s AS ENUM (%s)" % (
@@ -2178,12 +2178,12 @@ class PGDDLCompiler(compiler.DDLCompiler):
),
)
- def visit_drop_enum_type(self, drop):
+ def visit_drop_enum_type(self, drop, **kw):
type_ = drop.element
return "DROP TYPE %s" % (self.preparer.format_type(type_))
- def visit_create_domain_type(self, create):
+ def visit_create_domain_type(self, create, **kw):
domain: DOMAIN = create.element
options = []
@@ -2211,11 +2211,11 @@ class PGDDLCompiler(compiler.DDLCompiler):
f"{' '.join(options)}"
)
- def visit_drop_domain_type(self, drop):
+ def visit_drop_domain_type(self, drop, **kw):
domain = drop.element
return f"DROP DOMAIN {self.preparer.format_type(domain)}"
- def visit_create_index(self, create):
+ def visit_create_index(self, create, **kw):
preparer = self.preparer
index = create.element
self._verify_index_table(index)
@@ -2303,7 +2303,7 @@ class PGDDLCompiler(compiler.DDLCompiler):
return text
- def visit_drop_index(self, drop):
+ def visit_drop_index(self, drop, **kw):
index = drop.element
text = "\nDROP INDEX "
@@ -2382,7 +2382,7 @@ class PGDDLCompiler(compiler.DDLCompiler):
return "".join(table_opts)
- def visit_computed_column(self, generated):
+ def visit_computed_column(self, generated, **kw):
if generated.persisted is False:
raise exc.CompileError(
"PostrgreSQL computed columns do not support 'virtual' "
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py
index 5d8b3fbad..5a0761e5f 100644
--- a/lib/sqlalchemy/dialects/sqlite/base.py
+++ b/lib/sqlalchemy/dialects/sqlite/base.py
@@ -1423,12 +1423,12 @@ class SQLiteCompiler(compiler.SQLCompiler):
self.process(binary.right, **kw),
)
- def visit_empty_set_op_expr(self, type_, expand_op):
+ def visit_empty_set_op_expr(self, type_, expand_op, **kw):
# slightly old SQLite versions don't seem to be able to handle
# the empty set impl
return self.visit_empty_set_expr(type_)
- def visit_empty_set_expr(self, element_types):
+ def visit_empty_set_expr(self, element_types, **kw):
return "SELECT %s FROM (SELECT %s) WHERE 1!=1" % (
", ".join("1" for type_ in element_types or [INTEGER()]),
", ".join("1" for type_ in element_types or [INTEGER()]),
@@ -1595,7 +1595,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler):
return colspec
- def visit_primary_key_constraint(self, constraint):
+ def visit_primary_key_constraint(self, constraint, **kw):
# for columns with sqlite_autoincrement=True,
# the PRIMARY KEY constraint can only be inline
# with the column itself.
@@ -1624,7 +1624,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler):
return text
- def visit_unique_constraint(self, constraint):
+ def visit_unique_constraint(self, constraint, **kw):
text = super().visit_unique_constraint(constraint)
on_conflict_clause = constraint.dialect_options["sqlite"][
@@ -1642,7 +1642,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler):
return text
- def visit_check_constraint(self, constraint):
+ def visit_check_constraint(self, constraint, **kw):
text = super().visit_check_constraint(constraint)
on_conflict_clause = constraint.dialect_options["sqlite"][
@@ -1654,7 +1654,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler):
return text
- def visit_column_check_constraint(self, constraint):
+ def visit_column_check_constraint(self, constraint, **kw):
text = super().visit_column_check_constraint(constraint)
if constraint.dialect_options["sqlite"]["on_conflict"] is not None:
@@ -1665,7 +1665,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler):
return text
- def visit_foreign_key_constraint(self, constraint):
+ def visit_foreign_key_constraint(self, constraint, **kw):
local_table = constraint.elements[0].parent.table
remote_table = constraint.elements[0].column.table
@@ -1681,7 +1681,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler):
return preparer.format_table(table, use_schema=False)
def visit_create_index(
- self, create, include_schema=False, include_table_schema=True
+ self, create, include_schema=False, include_table_schema=True, **kw
):
index = create.element
self._verify_index_table(index)
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 596ca986f..895e9724c 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -731,7 +731,7 @@ class Compiled:
else:
raise exc.ObjectNotExecutableError(self.statement)
- def visit_unsupported_compilation(self, element, err):
+ def visit_unsupported_compilation(self, element, err, **kw):
raise exc.UnsupportedCompilationError(self, type(element)) from err
@property
@@ -2909,7 +2909,7 @@ class SQLCompiler(Compiled):
binary, OPERATORS[operator], **kw
)
- def visit_empty_set_op_expr(self, type_, expand_op):
+ def visit_empty_set_op_expr(self, type_, expand_op, **kw):
if expand_op is operators.not_in_op:
if len(type_) > 1:
return "(%s)) OR (1 = 1" % (
@@ -2927,7 +2927,7 @@ class SQLCompiler(Compiled):
else:
return self.visit_empty_set_expr(type_)
- def visit_empty_set_expr(self, element_types):
+ def visit_empty_set_expr(self, element_types, **kw):
raise NotImplementedError(
"Dialect '%s' does not support empty set expression."
% self.dialect.name
@@ -5687,15 +5687,15 @@ class SQLCompiler(Compiled):
return text
- def visit_savepoint(self, savepoint_stmt):
+ def visit_savepoint(self, savepoint_stmt, **kw):
return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
- def visit_rollback_to_savepoint(self, savepoint_stmt):
+ def visit_rollback_to_savepoint(self, savepoint_stmt, **kw):
return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(
savepoint_stmt
)
- def visit_release_savepoint(self, savepoint_stmt):
+ def visit_release_savepoint(self, savepoint_stmt, **kw):
return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(
savepoint_stmt
)
@@ -5783,7 +5783,7 @@ class StrSQLCompiler(SQLCompiler):
for t in extra_froms
)
- def visit_empty_set_expr(self, type_):
+ def visit_empty_set_expr(self, type_, **kw):
return "SELECT 1 WHERE 1!=1"
def get_from_hint_text(self, table, text):
diff --git a/lib/sqlalchemy/testing/suite/test_dialect.py b/lib/sqlalchemy/testing/suite/test_dialect.py
index 945edef85..38fe8f9c4 100644
--- a/lib/sqlalchemy/testing/suite/test_dialect.py
+++ b/lib/sqlalchemy/testing/suite/test_dialect.py
@@ -1,12 +1,15 @@
# mypy: ignore-errors
+import importlib
+
from . import testing
from .. import assert_raises
from .. import config
from .. import engines
from .. import eq_
from .. import fixtures
+from .. import is_not_none
from .. import is_true
from .. import ne_
from .. import provide_metadata
@@ -17,12 +20,15 @@ from ..provision import set_default_schema_on_connection
from ..schema import Column
from ..schema import Table
from ... import bindparam
+from ... import dialects
from ... import event
from ... import exc
from ... import Integer
from ... import literal_column
from ... import select
from ... import String
+from ...sql.compiler import Compiled
+from ...util import inspect_getfullargspec
class PingTest(fixtures.TestBase):
@@ -35,6 +41,51 @@ class PingTest(fixtures.TestBase):
)
+class ArgSignatureTest(fixtures.TestBase):
+ """test that all visit_XYZ() in :class:`_sql.Compiler` subclasses have
+ ``**kw``, for #8988.
+
+ This test uses runtime code inspection. Does not need to be a
+ ``__backend__`` test as it only needs to run once provided all target
+ dialects have been imported.
+
+ For third party dialects, the suite would be run with that third
+ party as a "--dburi", which means its compiler classes will have been
+ imported by the time this test runs.
+
+ """
+
+ def _all_subclasses(): # type: ignore # noqa
+ for d in dialects.__all__:
+ if not d.startswith("_"):
+ importlib.import_module("sqlalchemy.dialects.%s" % d)
+
+ stack = [Compiled]
+
+ while stack:
+ cls = stack.pop(0)
+ stack.extend(cls.__subclasses__())
+ yield cls
+
+ @testing.fixture(params=list(_all_subclasses()))
+ def all_subclasses(self, request):
+ yield request.param
+
+ def test_all_visit_methods_accept_kw(self, all_subclasses):
+ cls = all_subclasses
+
+ for k in cls.__dict__:
+ if k.startswith("visit_"):
+ meth = getattr(cls, k)
+
+ insp = inspect_getfullargspec(meth)
+ is_not_none(
+ insp.varkw,
+ f"Compiler visit method {cls.__name__}.{k}() does "
+ "not accommodate for **kw in its argument signature",
+ )
+
+
class ExceptionTest(fixtures.TablesTest):
"""Test basic exception wrapping.