diff options
author | mike bayer <mike_mp@zzzcomputing.com> | 2023-05-09 14:56:55 +0000 |
---|---|---|
committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2023-05-09 14:56:55 +0000 |
commit | 946e71efdfc93777027f4fd7360a524051be393d (patch) | |
tree | d64871a5acca93a629c1f62ef40ab11dbbbc38f4 | |
parent | ddd25a03743543ed9a7f0a9516d3bfa2528b9fce (diff) | |
parent | 4a62625d99470c8928422c4822df5234b93b6bb8 (diff) | |
download | sqlalchemy-946e71efdfc93777027f4fd7360a524051be393d.tar.gz |
Merge "implement FromLinter for UPDATE, DELETE statements" into main
-rw-r--r-- | doc/build/changelog/unreleased_20/9721.rst | 16 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/mssql/base.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/mysql/base.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 72 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/crud.py | 2 | ||||
-rw-r--r-- | test/sql/test_from_linter.py | 86 | ||||
-rw-r--r-- | test/sql/test_update.py | 53 |
7 files changed, 218 insertions, 19 deletions
diff --git a/doc/build/changelog/unreleased_20/9721.rst b/doc/build/changelog/unreleased_20/9721.rst new file mode 100644 index 000000000..2a2b29f84 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9721.rst @@ -0,0 +1,16 @@ +.. change:: + :tags: usecase, sql + :tickets: 9721 + + Implemented the "cartesian product warning" for UPDATE and DELETE + statements, those which include multiple tables that are not correlated + together in some way. + +.. change:: + :tags: bug, sql + + Fixed issue where :func:`_dml.update` construct that included multiple + tables and no VALUES clause would raise with an internal error. Current + behavior for :class:`_dml.Update` with no values is to generate a SQL + UPDATE statement with an empty "set" clause, so this has been made + consistent for this specific sub-case. diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index b33ce4aec..aa319e239 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -2425,13 +2425,13 @@ class MSSQLCompiler(compiler.SQLCompiler): for t in [from_table] + extra_froms ) - def delete_table_clause(self, delete_stmt, from_table, extra_froms): + def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw): """If we have extra froms make sure we render any alias as hint.""" ashint = False if extra_froms: ashint = True return from_table._compiler_dispatch( - self, asfrom=True, iscrud=True, ashint=ashint + self, asfrom=True, iscrud=True, ashint=ashint, **kw ) def delete_extra_from_clause( diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 2ed2bbc7a..ae40fea99 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1657,13 +1657,13 @@ class MySQLCompiler(compiler.SQLCompiler): ): return None - def delete_table_clause(self, delete_stmt, from_table, extra_froms): + def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw): """If we have extra froms make sure we render any alias as hint.""" ashint = False if extra_froms: ashint = True return from_table._compiler_dispatch( - self, asfrom=True, iscrud=True, ashint=ashint + self, asfrom=True, iscrud=True, ashint=ashint, **kw ) def delete_extra_from_clause( diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 554a84112..619ff0848 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -710,7 +710,7 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])): else: return None, None - def warn(self): + def warn(self, stmt_type="SELECT"): the_rest, start_with = self.lint() # FROMS left over? boom @@ -719,7 +719,7 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])): froms = the_rest if froms: template = ( - "SELECT statement has a cartesian product between " + "{stmt_type} statement has a cartesian product between " "FROM element(s) {froms} and " 'FROM element "{start}". Apply join condition(s) ' "between each element to resolve." @@ -728,7 +728,9 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])): f'"{self.froms[from_]}"' for from_ in froms ) message = template.format( - froms=froms_str, start=self.froms[start_with] + stmt_type=stmt_type, + froms=froms_str, + start=self.froms[start_with], ) util.warn(message) @@ -5997,6 +5999,7 @@ class SQLCompiler(Compiled): ) def visit_update(self, update_stmt, **kw): + compile_state = update_stmt._compile_state_factory( update_stmt, self, **kw ) @@ -6010,6 +6013,15 @@ class SQLCompiler(Compiled): if not self.compile_state: self.compile_state = compile_state + if self.linting & COLLECT_CARTESIAN_PRODUCTS: + from_linter = FromLinter({}, set()) + warn_linting = self.linting & WARN_LINTING + if toplevel: + self.from_linter = from_linter + else: + from_linter = None + warn_linting = False + extra_froms = compile_state._extra_froms is_multitable = bool(extra_froms) @@ -6040,7 +6052,11 @@ class SQLCompiler(Compiled): ) table_text = self.update_tables_clause( - update_stmt, update_stmt.table, render_extra_froms, **kw + update_stmt, + update_stmt.table, + render_extra_froms, + from_linter=from_linter, + **kw, ) crud_params_struct = crud._get_crud_params( self, update_stmt, compile_state, toplevel, **kw @@ -6081,6 +6097,7 @@ class SQLCompiler(Compiled): update_stmt.table, render_extra_froms, dialect_hints, + from_linter=from_linter, **kw, ) if extra_from_text: @@ -6088,7 +6105,7 @@ class SQLCompiler(Compiled): if update_stmt._where_criteria: t = self._generate_delimited_and_list( - update_stmt._where_criteria, **kw + update_stmt._where_criteria, from_linter=from_linter, **kw ) if t: text += " WHERE " + t @@ -6110,6 +6127,10 @@ class SQLCompiler(Compiled): nesting_level = len(self.stack) if not toplevel else None text = self._render_cte_clause(nesting_level=nesting_level) + text + if warn_linting: + assert from_linter is not None + from_linter.warn(stmt_type="UPDATE") + self.stack.pop(-1) return text @@ -6130,8 +6151,10 @@ class SQLCompiler(Compiled): "criteria within DELETE" ) - def delete_table_clause(self, delete_stmt, from_table, extra_froms): - return from_table._compiler_dispatch(self, asfrom=True, iscrud=True) + def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw): + return from_table._compiler_dispatch( + self, asfrom=True, iscrud=True, **kw + ) def visit_delete(self, delete_stmt, **kw): compile_state = delete_stmt._compile_state_factory( @@ -6147,6 +6170,15 @@ class SQLCompiler(Compiled): if not self.compile_state: self.compile_state = compile_state + if self.linting & COLLECT_CARTESIAN_PRODUCTS: + from_linter = FromLinter({}, set()) + warn_linting = self.linting & WARN_LINTING + if toplevel: + self.from_linter = from_linter + else: + from_linter = None + warn_linting = False + extra_froms = compile_state._extra_froms correlate_froms = {delete_stmt.table}.union(extra_froms) @@ -6166,9 +6198,22 @@ class SQLCompiler(Compiled): ) text += "FROM " - table_text = self.delete_table_clause( - delete_stmt, delete_stmt.table, extra_froms - ) + + try: + table_text = self.delete_table_clause( + delete_stmt, + delete_stmt.table, + extra_froms, + from_linter=from_linter, + ) + except TypeError: + # anticipate 3rd party dialects that don't include **kw + # TODO: remove in 2.1 + table_text = self.delete_table_clause( + delete_stmt, delete_stmt.table, extra_froms + ) + if from_linter: + _ = self.process(delete_stmt.table, from_linter=from_linter) crud._get_crud_params(self, delete_stmt, compile_state, toplevel, **kw) @@ -6199,6 +6244,7 @@ class SQLCompiler(Compiled): delete_stmt.table, extra_froms, dialect_hints, + from_linter=from_linter, **kw, ) if extra_from_text: @@ -6206,7 +6252,7 @@ class SQLCompiler(Compiled): if delete_stmt._where_criteria: t = self._generate_delimited_and_list( - delete_stmt._where_criteria, **kw + delete_stmt._where_criteria, from_linter=from_linter, **kw ) if t: text += " WHERE " + t @@ -6224,6 +6270,10 @@ class SQLCompiler(Compiled): nesting_level = len(self.stack) if not toplevel else None text = self._render_cte_clause(nesting_level=nesting_level) + text + if warn_linting: + assert from_linter is not None + from_linter.warn(stmt_type="DELETE") + self.stack.pop(-1) return text diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 563f61c04..16d5ce494 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -1344,7 +1344,7 @@ def _get_update_multitable_params( ): normalized_params = { coercions.expect(roles.DMLColumnRole, c): param - for c, param in stmt_parameter_tuples + for c, param in stmt_parameter_tuples or () } include_table = compile_state.include_table_with_column_exprs diff --git a/test/sql/test_from_linter.py b/test/sql/test_from_linter.py index 49370b1e6..9a471d571 100644 --- a/test/sql/test_from_linter.py +++ b/test/sql/test_from_linter.py @@ -1,4 +1,5 @@ from sqlalchemy import column +from sqlalchemy import delete from sqlalchemy import func from sqlalchemy import Integer from sqlalchemy import JSON @@ -7,6 +8,7 @@ from sqlalchemy import sql from sqlalchemy import table from sqlalchemy import testing from sqlalchemy import true +from sqlalchemy import update from sqlalchemy.testing import config from sqlalchemy.testing import engines from sqlalchemy.testing import expect_warnings @@ -382,18 +384,54 @@ class TestFindUnmatchingFroms(fixtures.TablesTest): froms, start = find_unmatching_froms(query) assert not froms + @testing.variation("dml", ["update", "delete"]) + @testing.combinations( + (False, False), (True, False), (True, True), argnames="twotable,error" + ) + def test_dml(self, dml, twotable, error): + if dml.update: + stmt = update(self.a) + elif dml.delete: + stmt = delete(self.a) + else: + dml.fail() + + stmt = stmt.where(self.a.c.col_a == "a1") + if twotable: + stmt = stmt.where(self.b.c.col_b == "a1") + + if not error: + stmt = stmt.where(self.b.c.col_b == self.a.c.col_a) + + froms, _ = find_unmatching_froms(stmt) + if error: + assert froms + else: + assert not froms + + +class TestLinterRoundTrip(fixtures.TablesTest): + __backend__ = True -class TestLinter(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): - Table("table_a", metadata, Column("col_a", Integer, primary_key=True)) - Table("table_b", metadata, Column("col_b", Integer, primary_key=True)) + Table( + "table_a", + metadata, + Column("col_a", Integer, primary_key=True, autoincrement=False), + ) + Table( + "table_b", + metadata, + Column("col_b", Integer, primary_key=True, autoincrement=False), + ) @classmethod def setup_bind(cls): # from linting is enabled by default return config.db + @testing.only_on("sqlite") def test_noop_for_unhandled_objects(self): with self.bind.connect() as conn: conn.exec_driver_sql("SELECT 1;").fetchone() @@ -429,6 +467,7 @@ class TestLinter(fixtures.TablesTest): with self.bind.connect() as conn: conn.execute(query) + @testing.requires.ctes def test_warn_anon_cte(self): a, b = self.tables("table_a", "table_b") @@ -444,6 +483,47 @@ class TestLinter(fixtures.TablesTest): with self.bind.connect() as conn: conn.execute(query) + @testing.variation( + "dml", + [ + ("update", testing.requires.update_from), + ("delete", testing.requires.delete_using), + ], + ) + @testing.combinations( + (False, False), (True, False), (True, True), argnames="twotable,error" + ) + def test_warn_dml(self, dml, twotable, error): + a, b = self.tables("table_a", "table_b") + + if dml.update: + stmt = update(a).values(col_a=5) + elif dml.delete: + stmt = delete(a) + else: + dml.fail() + + stmt = stmt.where(a.c.col_a == 1) + if twotable: + stmt = stmt.where(b.c.col_b == 1) + + if not error: + stmt = stmt.where(b.c.col_b == a.c.col_a) + + stmt_type = "UPDATE" if dml.update else "DELETE" + + with self.bind.connect() as conn: + if error: + with expect_warnings( + rf"{stmt_type} statement has a cartesian product between " + rf'FROM element\(s\) "table_[ab]" and FROM ' + rf'element "table_[ab]"' + ): + with self.bind.connect() as conn: + conn.execute(stmt) + else: + conn.execute(stmt) + def test_no_linting(self, metadata, connection): eng = engines.testing_engine( options={"enable_from_linting": False, "use_reaper": False} diff --git a/test/sql/test_update.py b/test/sql/test_update.py index ef8f117bc..d8de5c277 100644 --- a/test/sql/test_update.py +++ b/test/sql/test_update.py @@ -113,6 +113,59 @@ class _UpdateFromTestBase: class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): __dialect__ = "default_enhanced" + @testing.variation("twotable", [True, False]) + @testing.variation("values", ["none", "blank"]) + def test_update_no_params(self, values, twotable): + """test issue identified while doing #9721 + + UPDATE with empty VALUES but multiple tables would raise a + NoneType error; fixed this to emit an empty "SET" the way a single + table UPDATE currently does. + + both cases should probably raise CompileError, however this could + be backwards incompatible with current use cases (such as other test + suites) + + """ + + table1 = self.tables.mytable + table2 = self.tables.myothertable + + stmt = table1.update().where(table1.c.name == "jill") + if twotable: + stmt = stmt.where(table2.c.otherid == table1.c.myid) + + if values.blank: + stmt = stmt.values() + + if twotable: + if values.blank: + self.assert_compile( + stmt, + "UPDATE mytable SET FROM myothertable " + "WHERE mytable.name = :name_1 " + "AND myothertable.otherid = mytable.myid", + ) + elif values.none: + self.assert_compile( + stmt, + "UPDATE mytable SET myid=:myid, name=:name, " + "description=:description FROM myothertable " + "WHERE mytable.name = :name_1 " + "AND myothertable.otherid = mytable.myid", + ) + elif values.blank: + self.assert_compile( + stmt, + "UPDATE mytable SET WHERE mytable.name = :name_1", + ) + elif values.none: + self.assert_compile( + stmt, + "UPDATE mytable SET myid=:myid, name=:name, " + "description=:description WHERE mytable.name = :name_1", + ) + def test_update_literal_binds(self): table1 = self.tables.mytable |