diff options
author | mike bayer <mike_mp@zzzcomputing.com> | 2022-12-17 02:02:33 +0000 |
---|---|---|
committer | Gerrit Code Review <gerrit@ci3.zzzcomputing.com> | 2022-12-17 02:02:33 +0000 |
commit | e7e51af5b61c49d5198e31dfd0ef04e8941551eb (patch) | |
tree | a47665bbcda2a450aec8c3a8ff30a5d7873f5988 | |
parent | e84cc158c469f17c90f2e058ed72595bc3be5cdb (diff) | |
parent | f8fd9ce23350c1f8fad13ff78506b100670a5896 (diff) | |
download | sqlalchemy-e7e51af5b61c49d5198e31dfd0ef04e8941551eb.tar.gz |
Merge "ensure all visit methods accept **kw" into main
-rw-r--r-- | doc/build/changelog/unreleased_20/8988.rst | 8 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/mssql/base.py | 24 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/mysql/base.py | 14 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/oracle/base.py | 8 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 20 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/sqlite/base.py | 16 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 14 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/test_dialect.py | 51 |
8 files changed, 107 insertions, 48 deletions
diff --git a/doc/build/changelog/unreleased_20/8988.rst b/doc/build/changelog/unreleased_20/8988.rst new file mode 100644 index 000000000..b5300c1b4 --- /dev/null +++ b/doc/build/changelog/unreleased_20/8988.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, sql + :tickets: 8988 + + Added test support to ensure that all compiler ``visit_xyz()`` methods + across all :class:`.Compiler` implementations in SQLAlchemy accept a + ``**kw`` parameter, so that all compilers accept additional keyword + arguments under all circumstances. diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index aa640727f..08b76206a 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -2237,12 +2237,12 @@ class MSSQLCompiler(compiler.SQLCompiler): field = self.extract_map.get(extract.field, extract.field) return "DATEPART(%s, %s)" % (field, self.process(extract.expr, **kw)) - def visit_savepoint(self, savepoint_stmt): + def visit_savepoint(self, savepoint_stmt, **kw): return "SAVE TRANSACTION %s" % self.preparer.format_savepoint( savepoint_stmt ) - def visit_rollback_to_savepoint(self, savepoint_stmt): + def visit_rollback_to_savepoint(self, savepoint_stmt, **kw): return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint( savepoint_stmt ) @@ -2393,7 +2393,7 @@ class MSSQLCompiler(compiler.SQLCompiler): for t in [from_table] + extra_froms ) - def visit_empty_set_expr(self, type_): + def visit_empty_set_expr(self, type_, **kw): return "SELECT 1 WHERE 1!=1" def visit_is_distinct_from_binary(self, binary, operator, **kw): @@ -2581,7 +2581,7 @@ class MSDDLCompiler(compiler.DDLCompiler): return colspec - def visit_create_index(self, create, include_schema=False): + def visit_create_index(self, create, include_schema=False, **kw): index = create.element self._verify_index_table(index) preparer = self.preparer @@ -2633,13 +2633,13 @@ class MSDDLCompiler(compiler.DDLCompiler): return text - def visit_drop_index(self, drop): + def visit_drop_index(self, drop, **kw): return "\nDROP INDEX %s ON %s" % ( self._prepared_index_name(drop.element, include_schema=False), self.preparer.format_table(drop.element.table), ) - def visit_primary_key_constraint(self, constraint): + def visit_primary_key_constraint(self, constraint, **kw): if len(constraint) == 0: return "" text = "" @@ -2662,7 +2662,7 @@ class MSDDLCompiler(compiler.DDLCompiler): text += self.define_constraint_deferrability(constraint) return text - def visit_unique_constraint(self, constraint): + def visit_unique_constraint(self, constraint, **kw): if len(constraint) == 0: return "" text = "" @@ -2685,7 +2685,7 @@ class MSDDLCompiler(compiler.DDLCompiler): text += self.define_constraint_deferrability(constraint) return text - def visit_computed_column(self, generated): + def visit_computed_column(self, generated, **kw): text = "AS (%s)" % self.sql_compiler.process( generated.sqltext, include_table=False, literal_binds=True ) @@ -2694,7 +2694,7 @@ class MSDDLCompiler(compiler.DDLCompiler): text += " PERSISTED" return text - def visit_set_table_comment(self, create): + def visit_set_table_comment(self, create, **kw): schema = self.preparer.schema_for_object(create.element) schema_name = schema if schema else self.dialect.default_schema_name return ( @@ -2708,7 +2708,7 @@ class MSDDLCompiler(compiler.DDLCompiler): ) ) - def visit_drop_table_comment(self, drop): + def visit_drop_table_comment(self, drop, **kw): schema = self.preparer.schema_for_object(drop.element) schema_name = schema if schema else self.dialect.default_schema_name return ( @@ -2719,7 +2719,7 @@ class MSDDLCompiler(compiler.DDLCompiler): ) ) - def visit_set_column_comment(self, create): + def visit_set_column_comment(self, create, **kw): schema = self.preparer.schema_for_object(create.element.table) schema_name = schema if schema else self.dialect.default_schema_name return ( @@ -2736,7 +2736,7 @@ class MSDDLCompiler(compiler.DDLCompiler): ) ) - def visit_drop_column_comment(self, drop): + def visit_drop_column_comment(self, drop, **kw): schema = self.preparer.schema_for_object(drop.element.table) schema_name = schema if schema else self.dialect.default_schema_name return ( diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 2525c6c32..f965eac15 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1663,7 +1663,7 @@ class MySQLCompiler(compiler.SQLCompiler): for t in [from_table] + extra_froms ) - def visit_empty_set_expr(self, element_types): + def visit_empty_set_expr(self, element_types, **kw): return ( "SELECT %(outer)s FROM (SELECT %(inner)s) " "as _empty_set WHERE 1!=1" @@ -1962,14 +1962,14 @@ class MySQLDDLCompiler(compiler.DDLCompiler): return text - def visit_primary_key_constraint(self, constraint): + def visit_primary_key_constraint(self, constraint, **kw): text = super().visit_primary_key_constraint(constraint) using = constraint.dialect_options["mysql"]["using"] if using: text += " USING %s" % (self.preparer.quote(using)) return text - def visit_drop_index(self, drop): + def visit_drop_index(self, drop, **kw): index = drop.element text = "\nDROP INDEX " if drop.if_exists: @@ -1980,7 +1980,7 @@ class MySQLDDLCompiler(compiler.DDLCompiler): self.preparer.format_table(index.table), ) - def visit_drop_constraint(self, drop): + def visit_drop_constraint(self, drop, **kw): constraint = drop.element if isinstance(constraint, sa_schema.ForeignKeyConstraint): qual = "FOREIGN KEY " @@ -2014,7 +2014,7 @@ class MySQLDDLCompiler(compiler.DDLCompiler): ) return "" - def visit_set_table_comment(self, create): + def visit_set_table_comment(self, create, **kw): return "ALTER TABLE %s COMMENT %s" % ( self.preparer.format_table(create.element), self.sql_compiler.render_literal_value( @@ -2022,12 +2022,12 @@ class MySQLDDLCompiler(compiler.DDLCompiler): ), ) - def visit_drop_table_comment(self, create): + def visit_drop_table_comment(self, create, **kw): return "ALTER TABLE %s COMMENT ''" % ( self.preparer.format_table(create.element) ) - def visit_set_column_comment(self, create): + def visit_set_column_comment(self, create, **kw): return "ALTER TABLE %s CHANGE %s %s" % ( self.preparer.format_table(create.element.table), self.preparer.format_column(create.element), diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index dc2b011af..d6f65e5ed 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -1189,7 +1189,7 @@ class OracleCompiler(compiler.SQLCompiler): def limit_clause(self, select, **kw): return "" - def visit_empty_set_expr(self, type_): + def visit_empty_set_expr(self, type_, **kw): return "SELECT 1 FROM DUAL WHERE 1!=1" def for_update_clause(self, select, **kw): @@ -1279,12 +1279,12 @@ class OracleDDLCompiler(compiler.DDLCompiler): return text - def visit_drop_table_comment(self, drop): + def visit_drop_table_comment(self, drop, **kw): return "COMMENT ON TABLE %s IS ''" % self.preparer.format_table( drop.element ) - def visit_create_index(self, create): + def visit_create_index(self, create, **kw): index = create.element self._verify_index_table(index) preparer = self.preparer @@ -1336,7 +1336,7 @@ class OracleDDLCompiler(compiler.DDLCompiler): text = text.replace("NO ORDER", "NOORDER") return text - def visit_computed_column(self, generated): + def visit_computed_column(self, generated, **kw): text = "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process( generated.sqltext, include_table=False, literal_binds=True ) diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 8287e828a..3fb29812b 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1836,7 +1836,7 @@ class PGCompiler(compiler.SQLCompiler): self.process(flags, **kw), ) - def visit_empty_set_expr(self, element_types): + def visit_empty_set_expr(self, element_types, **kw): # cast the empty set to the type we are comparing against. if # we are comparing against the null type, pick an arbitrary # datatype for the empty set @@ -2144,7 +2144,7 @@ class PGDDLCompiler(compiler.DDLCompiler): not_valid = constraint.dialect_options["postgresql"]["not_valid"] return " NOT VALID" if not_valid else "" - def visit_check_constraint(self, constraint): + def visit_check_constraint(self, constraint, **kw): if constraint._type_bound: typ = list(constraint.columns)[0].type if ( @@ -2162,12 +2162,12 @@ class PGDDLCompiler(compiler.DDLCompiler): text += self._define_constraint_validity(constraint) return text - def visit_foreign_key_constraint(self, constraint): + def visit_foreign_key_constraint(self, constraint, **kw): text = super().visit_foreign_key_constraint(constraint) text += self._define_constraint_validity(constraint) return text - def visit_create_enum_type(self, create): + def visit_create_enum_type(self, create, **kw): type_ = create.element return "CREATE TYPE %s AS ENUM (%s)" % ( @@ -2178,12 +2178,12 @@ class PGDDLCompiler(compiler.DDLCompiler): ), ) - def visit_drop_enum_type(self, drop): + def visit_drop_enum_type(self, drop, **kw): type_ = drop.element return "DROP TYPE %s" % (self.preparer.format_type(type_)) - def visit_create_domain_type(self, create): + def visit_create_domain_type(self, create, **kw): domain: DOMAIN = create.element options = [] @@ -2211,11 +2211,11 @@ class PGDDLCompiler(compiler.DDLCompiler): f"{' '.join(options)}" ) - def visit_drop_domain_type(self, drop): + def visit_drop_domain_type(self, drop, **kw): domain = drop.element return f"DROP DOMAIN {self.preparer.format_type(domain)}" - def visit_create_index(self, create): + def visit_create_index(self, create, **kw): preparer = self.preparer index = create.element self._verify_index_table(index) @@ -2303,7 +2303,7 @@ class PGDDLCompiler(compiler.DDLCompiler): return text - def visit_drop_index(self, drop): + def visit_drop_index(self, drop, **kw): index = drop.element text = "\nDROP INDEX " @@ -2382,7 +2382,7 @@ class PGDDLCompiler(compiler.DDLCompiler): return "".join(table_opts) - def visit_computed_column(self, generated): + def visit_computed_column(self, generated, **kw): if generated.persisted is False: raise exc.CompileError( "PostrgreSQL computed columns do not support 'virtual' " diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 5d8b3fbad..5a0761e5f 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -1423,12 +1423,12 @@ class SQLiteCompiler(compiler.SQLCompiler): self.process(binary.right, **kw), ) - def visit_empty_set_op_expr(self, type_, expand_op): + def visit_empty_set_op_expr(self, type_, expand_op, **kw): # slightly old SQLite versions don't seem to be able to handle # the empty set impl return self.visit_empty_set_expr(type_) - def visit_empty_set_expr(self, element_types): + def visit_empty_set_expr(self, element_types, **kw): return "SELECT %s FROM (SELECT %s) WHERE 1!=1" % ( ", ".join("1" for type_ in element_types or [INTEGER()]), ", ".join("1" for type_ in element_types or [INTEGER()]), @@ -1595,7 +1595,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): return colspec - def visit_primary_key_constraint(self, constraint): + def visit_primary_key_constraint(self, constraint, **kw): # for columns with sqlite_autoincrement=True, # the PRIMARY KEY constraint can only be inline # with the column itself. @@ -1624,7 +1624,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): return text - def visit_unique_constraint(self, constraint): + def visit_unique_constraint(self, constraint, **kw): text = super().visit_unique_constraint(constraint) on_conflict_clause = constraint.dialect_options["sqlite"][ @@ -1642,7 +1642,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): return text - def visit_check_constraint(self, constraint): + def visit_check_constraint(self, constraint, **kw): text = super().visit_check_constraint(constraint) on_conflict_clause = constraint.dialect_options["sqlite"][ @@ -1654,7 +1654,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): return text - def visit_column_check_constraint(self, constraint): + def visit_column_check_constraint(self, constraint, **kw): text = super().visit_column_check_constraint(constraint) if constraint.dialect_options["sqlite"]["on_conflict"] is not None: @@ -1665,7 +1665,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): return text - def visit_foreign_key_constraint(self, constraint): + def visit_foreign_key_constraint(self, constraint, **kw): local_table = constraint.elements[0].parent.table remote_table = constraint.elements[0].column.table @@ -1681,7 +1681,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): return preparer.format_table(table, use_schema=False) def visit_create_index( - self, create, include_schema=False, include_table_schema=True + self, create, include_schema=False, include_table_schema=True, **kw ): index = create.element self._verify_index_table(index) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 596ca986f..895e9724c 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -731,7 +731,7 @@ class Compiled: else: raise exc.ObjectNotExecutableError(self.statement) - def visit_unsupported_compilation(self, element, err): + def visit_unsupported_compilation(self, element, err, **kw): raise exc.UnsupportedCompilationError(self, type(element)) from err @property @@ -2909,7 +2909,7 @@ class SQLCompiler(Compiled): binary, OPERATORS[operator], **kw ) - def visit_empty_set_op_expr(self, type_, expand_op): + def visit_empty_set_op_expr(self, type_, expand_op, **kw): if expand_op is operators.not_in_op: if len(type_) > 1: return "(%s)) OR (1 = 1" % ( @@ -2927,7 +2927,7 @@ class SQLCompiler(Compiled): else: return self.visit_empty_set_expr(type_) - def visit_empty_set_expr(self, element_types): + def visit_empty_set_expr(self, element_types, **kw): raise NotImplementedError( "Dialect '%s' does not support empty set expression." % self.dialect.name @@ -5687,15 +5687,15 @@ class SQLCompiler(Compiled): return text - def visit_savepoint(self, savepoint_stmt): + def visit_savepoint(self, savepoint_stmt, **kw): return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) - def visit_rollback_to_savepoint(self, savepoint_stmt): + def visit_rollback_to_savepoint(self, savepoint_stmt, **kw): return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint( savepoint_stmt ) - def visit_release_savepoint(self, savepoint_stmt): + def visit_release_savepoint(self, savepoint_stmt, **kw): return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint( savepoint_stmt ) @@ -5783,7 +5783,7 @@ class StrSQLCompiler(SQLCompiler): for t in extra_froms ) - def visit_empty_set_expr(self, type_): + def visit_empty_set_expr(self, type_, **kw): return "SELECT 1 WHERE 1!=1" def get_from_hint_text(self, table, text): diff --git a/lib/sqlalchemy/testing/suite/test_dialect.py b/lib/sqlalchemy/testing/suite/test_dialect.py index 945edef85..38fe8f9c4 100644 --- a/lib/sqlalchemy/testing/suite/test_dialect.py +++ b/lib/sqlalchemy/testing/suite/test_dialect.py @@ -1,12 +1,15 @@ # mypy: ignore-errors +import importlib + from . import testing from .. import assert_raises from .. import config from .. import engines from .. import eq_ from .. import fixtures +from .. import is_not_none from .. import is_true from .. import ne_ from .. import provide_metadata @@ -17,12 +20,15 @@ from ..provision import set_default_schema_on_connection from ..schema import Column from ..schema import Table from ... import bindparam +from ... import dialects from ... import event from ... import exc from ... import Integer from ... import literal_column from ... import select from ... import String +from ...sql.compiler import Compiled +from ...util import inspect_getfullargspec class PingTest(fixtures.TestBase): @@ -35,6 +41,51 @@ class PingTest(fixtures.TestBase): ) +class ArgSignatureTest(fixtures.TestBase): + """test that all visit_XYZ() in :class:`_sql.Compiler` subclasses have + ``**kw``, for #8988. + + This test uses runtime code inspection. Does not need to be a + ``__backend__`` test as it only needs to run once provided all target + dialects have been imported. + + For third party dialects, the suite would be run with that third + party as a "--dburi", which means its compiler classes will have been + imported by the time this test runs. + + """ + + def _all_subclasses(): # type: ignore # noqa + for d in dialects.__all__: + if not d.startswith("_"): + importlib.import_module("sqlalchemy.dialects.%s" % d) + + stack = [Compiled] + + while stack: + cls = stack.pop(0) + stack.extend(cls.__subclasses__()) + yield cls + + @testing.fixture(params=list(_all_subclasses())) + def all_subclasses(self, request): + yield request.param + + def test_all_visit_methods_accept_kw(self, all_subclasses): + cls = all_subclasses + + for k in cls.__dict__: + if k.startswith("visit_"): + meth = getattr(cls, k) + + insp = inspect_getfullargspec(meth) + is_not_none( + insp.varkw, + f"Compiler visit method {cls.__name__}.{k}() does " + "not accommodate for **kw in its argument signature", + ) + + class ExceptionTest(fixtures.TablesTest): """Test basic exception wrapping. |