summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2023-04-28 12:07:09 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2023-05-09 10:08:52 -0400
commit4a62625d99470c8928422c4822df5234b93b6bb8 (patch)
tree280182818aea6846f1294705357b6a0754d51df4
parent39c8e95b1f50190ff30a836b2bcf13ba2cacc052 (diff)
downloadsqlalchemy-4a62625d99470c8928422c4822df5234b93b6bb8.tar.gz
implement FromLinter for UPDATE, DELETE statements
Implemented the "cartesian product warning" for UPDATE and DELETE statements, those which include multiple tables that are not correlated together in some way. Fixed issue where :func:`_dml.update` construct that included multiple tables and no VALUES clause would raise with an internal error. Current behavior for :class:`_dml.Update` with no values is to generate a SQL UPDATE statement with an empty "set" clause, so this has been made consistent for this specific sub-case. Fixes: #9721 Change-Id: I556639811cc930d2e37532965d2ae751882af921
-rw-r--r--doc/build/changelog/unreleased_20/9721.rst16
-rw-r--r--lib/sqlalchemy/dialects/mssql/base.py4
-rw-r--r--lib/sqlalchemy/dialects/mysql/base.py4
-rw-r--r--lib/sqlalchemy/sql/compiler.py72
-rw-r--r--lib/sqlalchemy/sql/crud.py2
-rw-r--r--test/sql/test_from_linter.py86
-rw-r--r--test/sql/test_update.py53
7 files changed, 218 insertions, 19 deletions
diff --git a/doc/build/changelog/unreleased_20/9721.rst b/doc/build/changelog/unreleased_20/9721.rst
new file mode 100644
index 000000000..2a2b29f84
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/9721.rst
@@ -0,0 +1,16 @@
+.. change::
+ :tags: usecase, sql
+ :tickets: 9721
+
+ Implemented the "cartesian product warning" for UPDATE and DELETE
+ statements, those which include multiple tables that are not correlated
+ together in some way.
+
+.. change::
+ :tags: bug, sql
+
+ Fixed issue where :func:`_dml.update` construct that included multiple
+ tables and no VALUES clause would raise with an internal error. Current
+ behavior for :class:`_dml.Update` with no values is to generate a SQL
+ UPDATE statement with an empty "set" clause, so this has been made
+ consistent for this specific sub-case.
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index b33ce4aec..aa319e239 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -2425,13 +2425,13 @@ class MSSQLCompiler(compiler.SQLCompiler):
for t in [from_table] + extra_froms
)
- def delete_table_clause(self, delete_stmt, from_table, extra_froms):
+ def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw):
"""If we have extra froms make sure we render any alias as hint."""
ashint = False
if extra_froms:
ashint = True
return from_table._compiler_dispatch(
- self, asfrom=True, iscrud=True, ashint=ashint
+ self, asfrom=True, iscrud=True, ashint=ashint, **kw
)
def delete_extra_from_clause(
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index 2ed2bbc7a..ae40fea99 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -1657,13 +1657,13 @@ class MySQLCompiler(compiler.SQLCompiler):
):
return None
- def delete_table_clause(self, delete_stmt, from_table, extra_froms):
+ def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw):
"""If we have extra froms make sure we render any alias as hint."""
ashint = False
if extra_froms:
ashint = True
return from_table._compiler_dispatch(
- self, asfrom=True, iscrud=True, ashint=ashint
+ self, asfrom=True, iscrud=True, ashint=ashint, **kw
)
def delete_extra_from_clause(
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 554a84112..619ff0848 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -710,7 +710,7 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])):
else:
return None, None
- def warn(self):
+ def warn(self, stmt_type="SELECT"):
the_rest, start_with = self.lint()
# FROMS left over? boom
@@ -719,7 +719,7 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])):
froms = the_rest
if froms:
template = (
- "SELECT statement has a cartesian product between "
+ "{stmt_type} statement has a cartesian product between "
"FROM element(s) {froms} and "
'FROM element "{start}". Apply join condition(s) '
"between each element to resolve."
@@ -728,7 +728,9 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])):
f'"{self.froms[from_]}"' for from_ in froms
)
message = template.format(
- froms=froms_str, start=self.froms[start_with]
+ stmt_type=stmt_type,
+ froms=froms_str,
+ start=self.froms[start_with],
)
util.warn(message)
@@ -5997,6 +5999,7 @@ class SQLCompiler(Compiled):
)
def visit_update(self, update_stmt, **kw):
+
compile_state = update_stmt._compile_state_factory(
update_stmt, self, **kw
)
@@ -6010,6 +6013,15 @@ class SQLCompiler(Compiled):
if not self.compile_state:
self.compile_state = compile_state
+ if self.linting & COLLECT_CARTESIAN_PRODUCTS:
+ from_linter = FromLinter({}, set())
+ warn_linting = self.linting & WARN_LINTING
+ if toplevel:
+ self.from_linter = from_linter
+ else:
+ from_linter = None
+ warn_linting = False
+
extra_froms = compile_state._extra_froms
is_multitable = bool(extra_froms)
@@ -6040,7 +6052,11 @@ class SQLCompiler(Compiled):
)
table_text = self.update_tables_clause(
- update_stmt, update_stmt.table, render_extra_froms, **kw
+ update_stmt,
+ update_stmt.table,
+ render_extra_froms,
+ from_linter=from_linter,
+ **kw,
)
crud_params_struct = crud._get_crud_params(
self, update_stmt, compile_state, toplevel, **kw
@@ -6081,6 +6097,7 @@ class SQLCompiler(Compiled):
update_stmt.table,
render_extra_froms,
dialect_hints,
+ from_linter=from_linter,
**kw,
)
if extra_from_text:
@@ -6088,7 +6105,7 @@ class SQLCompiler(Compiled):
if update_stmt._where_criteria:
t = self._generate_delimited_and_list(
- update_stmt._where_criteria, **kw
+ update_stmt._where_criteria, from_linter=from_linter, **kw
)
if t:
text += " WHERE " + t
@@ -6110,6 +6127,10 @@ class SQLCompiler(Compiled):
nesting_level = len(self.stack) if not toplevel else None
text = self._render_cte_clause(nesting_level=nesting_level) + text
+ if warn_linting:
+ assert from_linter is not None
+ from_linter.warn(stmt_type="UPDATE")
+
self.stack.pop(-1)
return text
@@ -6130,8 +6151,10 @@ class SQLCompiler(Compiled):
"criteria within DELETE"
)
- def delete_table_clause(self, delete_stmt, from_table, extra_froms):
- return from_table._compiler_dispatch(self, asfrom=True, iscrud=True)
+ def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw):
+ return from_table._compiler_dispatch(
+ self, asfrom=True, iscrud=True, **kw
+ )
def visit_delete(self, delete_stmt, **kw):
compile_state = delete_stmt._compile_state_factory(
@@ -6147,6 +6170,15 @@ class SQLCompiler(Compiled):
if not self.compile_state:
self.compile_state = compile_state
+ if self.linting & COLLECT_CARTESIAN_PRODUCTS:
+ from_linter = FromLinter({}, set())
+ warn_linting = self.linting & WARN_LINTING
+ if toplevel:
+ self.from_linter = from_linter
+ else:
+ from_linter = None
+ warn_linting = False
+
extra_froms = compile_state._extra_froms
correlate_froms = {delete_stmt.table}.union(extra_froms)
@@ -6166,9 +6198,22 @@ class SQLCompiler(Compiled):
)
text += "FROM "
- table_text = self.delete_table_clause(
- delete_stmt, delete_stmt.table, extra_froms
- )
+
+ try:
+ table_text = self.delete_table_clause(
+ delete_stmt,
+ delete_stmt.table,
+ extra_froms,
+ from_linter=from_linter,
+ )
+ except TypeError:
+ # anticipate 3rd party dialects that don't include **kw
+ # TODO: remove in 2.1
+ table_text = self.delete_table_clause(
+ delete_stmt, delete_stmt.table, extra_froms
+ )
+ if from_linter:
+ _ = self.process(delete_stmt.table, from_linter=from_linter)
crud._get_crud_params(self, delete_stmt, compile_state, toplevel, **kw)
@@ -6199,6 +6244,7 @@ class SQLCompiler(Compiled):
delete_stmt.table,
extra_froms,
dialect_hints,
+ from_linter=from_linter,
**kw,
)
if extra_from_text:
@@ -6206,7 +6252,7 @@ class SQLCompiler(Compiled):
if delete_stmt._where_criteria:
t = self._generate_delimited_and_list(
- delete_stmt._where_criteria, **kw
+ delete_stmt._where_criteria, from_linter=from_linter, **kw
)
if t:
text += " WHERE " + t
@@ -6224,6 +6270,10 @@ class SQLCompiler(Compiled):
nesting_level = len(self.stack) if not toplevel else None
text = self._render_cte_clause(nesting_level=nesting_level) + text
+ if warn_linting:
+ assert from_linter is not None
+ from_linter.warn(stmt_type="DELETE")
+
self.stack.pop(-1)
return text
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 563f61c04..16d5ce494 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -1344,7 +1344,7 @@ def _get_update_multitable_params(
):
normalized_params = {
coercions.expect(roles.DMLColumnRole, c): param
- for c, param in stmt_parameter_tuples
+ for c, param in stmt_parameter_tuples or ()
}
include_table = compile_state.include_table_with_column_exprs
diff --git a/test/sql/test_from_linter.py b/test/sql/test_from_linter.py
index 49370b1e6..9a471d571 100644
--- a/test/sql/test_from_linter.py
+++ b/test/sql/test_from_linter.py
@@ -1,4 +1,5 @@
from sqlalchemy import column
+from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import Integer
from sqlalchemy import JSON
@@ -7,6 +8,7 @@ from sqlalchemy import sql
from sqlalchemy import table
from sqlalchemy import testing
from sqlalchemy import true
+from sqlalchemy import update
from sqlalchemy.testing import config
from sqlalchemy.testing import engines
from sqlalchemy.testing import expect_warnings
@@ -382,18 +384,54 @@ class TestFindUnmatchingFroms(fixtures.TablesTest):
froms, start = find_unmatching_froms(query)
assert not froms
+ @testing.variation("dml", ["update", "delete"])
+ @testing.combinations(
+ (False, False), (True, False), (True, True), argnames="twotable,error"
+ )
+ def test_dml(self, dml, twotable, error):
+ if dml.update:
+ stmt = update(self.a)
+ elif dml.delete:
+ stmt = delete(self.a)
+ else:
+ dml.fail()
+
+ stmt = stmt.where(self.a.c.col_a == "a1")
+ if twotable:
+ stmt = stmt.where(self.b.c.col_b == "a1")
+
+ if not error:
+ stmt = stmt.where(self.b.c.col_b == self.a.c.col_a)
+
+ froms, _ = find_unmatching_froms(stmt)
+ if error:
+ assert froms
+ else:
+ assert not froms
+
+
+class TestLinterRoundTrip(fixtures.TablesTest):
+ __backend__ = True
-class TestLinter(fixtures.TablesTest):
@classmethod
def define_tables(cls, metadata):
- Table("table_a", metadata, Column("col_a", Integer, primary_key=True))
- Table("table_b", metadata, Column("col_b", Integer, primary_key=True))
+ Table(
+ "table_a",
+ metadata,
+ Column("col_a", Integer, primary_key=True, autoincrement=False),
+ )
+ Table(
+ "table_b",
+ metadata,
+ Column("col_b", Integer, primary_key=True, autoincrement=False),
+ )
@classmethod
def setup_bind(cls):
# from linting is enabled by default
return config.db
+ @testing.only_on("sqlite")
def test_noop_for_unhandled_objects(self):
with self.bind.connect() as conn:
conn.exec_driver_sql("SELECT 1;").fetchone()
@@ -429,6 +467,7 @@ class TestLinter(fixtures.TablesTest):
with self.bind.connect() as conn:
conn.execute(query)
+ @testing.requires.ctes
def test_warn_anon_cte(self):
a, b = self.tables("table_a", "table_b")
@@ -444,6 +483,47 @@ class TestLinter(fixtures.TablesTest):
with self.bind.connect() as conn:
conn.execute(query)
+ @testing.variation(
+ "dml",
+ [
+ ("update", testing.requires.update_from),
+ ("delete", testing.requires.delete_using),
+ ],
+ )
+ @testing.combinations(
+ (False, False), (True, False), (True, True), argnames="twotable,error"
+ )
+ def test_warn_dml(self, dml, twotable, error):
+ a, b = self.tables("table_a", "table_b")
+
+ if dml.update:
+ stmt = update(a).values(col_a=5)
+ elif dml.delete:
+ stmt = delete(a)
+ else:
+ dml.fail()
+
+ stmt = stmt.where(a.c.col_a == 1)
+ if twotable:
+ stmt = stmt.where(b.c.col_b == 1)
+
+ if not error:
+ stmt = stmt.where(b.c.col_b == a.c.col_a)
+
+ stmt_type = "UPDATE" if dml.update else "DELETE"
+
+ with self.bind.connect() as conn:
+ if error:
+ with expect_warnings(
+ rf"{stmt_type} statement has a cartesian product between "
+ rf'FROM element\(s\) "table_[ab]" and FROM '
+ rf'element "table_[ab]"'
+ ):
+ with self.bind.connect() as conn:
+ conn.execute(stmt)
+ else:
+ conn.execute(stmt)
+
def test_no_linting(self, metadata, connection):
eng = engines.testing_engine(
options={"enable_from_linting": False, "use_reaper": False}
diff --git a/test/sql/test_update.py b/test/sql/test_update.py
index ef8f117bc..d8de5c277 100644
--- a/test/sql/test_update.py
+++ b/test/sql/test_update.py
@@ -113,6 +113,59 @@ class _UpdateFromTestBase:
class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL):
__dialect__ = "default_enhanced"
+ @testing.variation("twotable", [True, False])
+ @testing.variation("values", ["none", "blank"])
+ def test_update_no_params(self, values, twotable):
+ """test issue identified while doing #9721
+
+ UPDATE with empty VALUES but multiple tables would raise a
+ NoneType error; fixed this to emit an empty "SET" the way a single
+ table UPDATE currently does.
+
+ both cases should probably raise CompileError, however this could
+ be backwards incompatible with current use cases (such as other test
+ suites)
+
+ """
+
+ table1 = self.tables.mytable
+ table2 = self.tables.myothertable
+
+ stmt = table1.update().where(table1.c.name == "jill")
+ if twotable:
+ stmt = stmt.where(table2.c.otherid == table1.c.myid)
+
+ if values.blank:
+ stmt = stmt.values()
+
+ if twotable:
+ if values.blank:
+ self.assert_compile(
+ stmt,
+ "UPDATE mytable SET FROM myothertable "
+ "WHERE mytable.name = :name_1 "
+ "AND myothertable.otherid = mytable.myid",
+ )
+ elif values.none:
+ self.assert_compile(
+ stmt,
+ "UPDATE mytable SET myid=:myid, name=:name, "
+ "description=:description FROM myothertable "
+ "WHERE mytable.name = :name_1 "
+ "AND myothertable.otherid = mytable.myid",
+ )
+ elif values.blank:
+ self.assert_compile(
+ stmt,
+ "UPDATE mytable SET WHERE mytable.name = :name_1",
+ )
+ elif values.none:
+ self.assert_compile(
+ stmt,
+ "UPDATE mytable SET myid=:myid, name=:name, "
+ "description=:description WHERE mytable.name = :name_1",
+ )
+
def test_update_literal_binds(self):
table1 = self.tables.mytable