summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author mike bayer <mike_mp@zzzcomputing.com> 2023-05-09 14:56:55 +0000
committer Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> 2023-05-09 14:56:55 +0000
commit 946e71efdfc93777027f4fd7360a524051be393d (patch)
tree d64871a5acca93a629c1f62ef40ab11dbbbc38f4
parent ddd25a03743543ed9a7f0a9516d3bfa2528b9fce (diff)
parent 4a62625d99470c8928422c4822df5234b93b6bb8 (diff)
download sqlalchemy-946e71efdfc93777027f4fd7360a524051be393d.tar.gz
Merge "implement FromLinter for UPDATE, DELETE statements" into main
-rw-r--r-- doc/build/changelog/unreleased_20/9721.rst 16
-rw-r--r-- lib/sqlalchemy/dialects/mssql/base.py 4
-rw-r--r-- lib/sqlalchemy/dialects/mysql/base.py 4
-rw-r--r-- lib/sqlalchemy/sql/compiler.py 72
-rw-r--r-- lib/sqlalchemy/sql/crud.py 2
-rw-r--r-- test/sql/test_from_linter.py 86
-rw-r--r-- test/sql/test_update.py 53
7 files changed, 218 insertions, 19 deletions
diff --git a/doc/build/changelog/unreleased_20/9721.rst b/doc/build/changelog/unreleased_20/9721.rst
new file mode 100644
index 000000000..2a2b29f84
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/9721.rst
@@ -0,0 +1,16 @@
+.. change::
+ :tags: usecase, sql
+ :tickets: 9721
+
+ Implemented the "cartesian product warning" for UPDATE and DELETE
+ statements, those which include multiple tables that are not correlated
+ together in some way.
+
+.. change::
+ :tags: bug, sql
+
+ Fixed issue where :func:`_dml.update` construct that included multiple
+ tables and no VALUES clause would raise with an internal error. Current
+ behavior for :class:`_dml.Update` with no values is to generate a SQL
+ UPDATE statement with an empty "set" clause, so this has been made
+ consistent for this specific sub-case.
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index b33ce4aec..aa319e239 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -2425,13 +2425,13 @@ class MSSQLCompiler(compiler.SQLCompiler):
for t in [from_table] + extra_froms
)
- def delete_table_clause(self, delete_stmt, from_table, extra_froms):
+ def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw):
"""If we have extra froms make sure we render any alias as hint."""
ashint = False
if extra_froms:
ashint = True
return from_table._compiler_dispatch(
- self, asfrom=True, iscrud=True, ashint=ashint
+ self, asfrom=True, iscrud=True, ashint=ashint, **kw
)
def delete_extra_from_clause(
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index 2ed2bbc7a..ae40fea99 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -1657,13 +1657,13 @@ class MySQLCompiler(compiler.SQLCompiler):
):
return None
- def delete_table_clause(self, delete_stmt, from_table, extra_froms):
+ def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw):
"""If we have extra froms make sure we render any alias as hint."""
ashint = False
if extra_froms:
ashint = True
return from_table._compiler_dispatch(
- self, asfrom=True, iscrud=True, ashint=ashint
+ self, asfrom=True, iscrud=True, ashint=ashint, **kw
)
def delete_extra_from_clause(
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 554a84112..619ff0848 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -710,7 +710,7 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])):
else:
return None, None
- def warn(self):
+ def warn(self, stmt_type="SELECT"):
the_rest, start_with = self.lint()
# FROMS left over? boom
@@ -719,7 +719,7 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])):
froms = the_rest
if froms:
template = (
- "SELECT statement has a cartesian product between "
+ "{stmt_type} statement has a cartesian product between "
"FROM element(s) {froms} and "
'FROM element "{start}". Apply join condition(s) '
"between each element to resolve."
@@ -728,7 +728,9 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])):
f'"{self.froms[from_]}"' for from_ in froms
)
message = template.format(
- froms=froms_str, start=self.froms[start_with]
+ stmt_type=stmt_type,
+ froms=froms_str,
+ start=self.froms[start_with],
)
util.warn(message)
@@ -5997,6 +5999,7 @@ class SQLCompiler(Compiled):
)
def visit_update(self, update_stmt, **kw):
+
compile_state = update_stmt._compile_state_factory(
update_stmt, self, **kw
)
@@ -6010,6 +6013,15 @@ class SQLCompiler(Compiled):
if not self.compile_state:
self.compile_state = compile_state
+ if self.linting & COLLECT_CARTESIAN_PRODUCTS:
+ from_linter = FromLinter({}, set())
+ warn_linting = self.linting & WARN_LINTING
+ if toplevel:
+ self.from_linter = from_linter
+ else:
+ from_linter = None
+ warn_linting = False
+
extra_froms = compile_state._extra_froms
is_multitable = bool(extra_froms)
@@ -6040,7 +6052,11 @@ class SQLCompiler(Compiled):
)
table_text = self.update_tables_clause(
- update_stmt, update_stmt.table, render_extra_froms, **kw
+ update_stmt,
+ update_stmt.table,
+ render_extra_froms,
+ from_linter=from_linter,
+ **kw,
)
crud_params_struct = crud._get_crud_params(
self, update_stmt, compile_state, toplevel, **kw
@@ -6081,6 +6097,7 @@ class SQLCompiler(Compiled):
update_stmt.table,
render_extra_froms,
dialect_hints,
+ from_linter=from_linter,
**kw,
)
if extra_from_text:
@@ -6088,7 +6105,7 @@ class SQLCompiler(Compiled):
if update_stmt._where_criteria:
t = self._generate_delimited_and_list(
- update_stmt._where_criteria, **kw
+ update_stmt._where_criteria, from_linter=from_linter, **kw
)
if t:
text += " WHERE " + t
@@ -6110,6 +6127,10 @@ class SQLCompiler(Compiled):
nesting_level = len(self.stack) if not toplevel else None
text = self._render_cte_clause(nesting_level=nesting_level) + text
+ if warn_linting:
+ assert from_linter is not None
+ from_linter.warn(stmt_type="UPDATE")
+
self.stack.pop(-1)
return text
@@ -6130,8 +6151,10 @@ class SQLCompiler(Compiled):
"criteria within DELETE"
)
- def delete_table_clause(self, delete_stmt, from_table, extra_froms):
- return from_table._compiler_dispatch(self, asfrom=True, iscrud=True)
+ def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw):
+ return from_table._compiler_dispatch(
+ self, asfrom=True, iscrud=True, **kw
+ )
def visit_delete(self, delete_stmt, **kw):
compile_state = delete_stmt._compile_state_factory(
@@ -6147,6 +6170,15 @@ class SQLCompiler(Compiled):
if not self.compile_state:
self.compile_state = compile_state
+ if self.linting & COLLECT_CARTESIAN_PRODUCTS:
+ from_linter = FromLinter({}, set())
+ warn_linting = self.linting & WARN_LINTING
+ if toplevel:
+ self.from_linter = from_linter
+ else:
+ from_linter = None
+ warn_linting = False
+
extra_froms = compile_state._extra_froms
correlate_froms = {delete_stmt.table}.union(extra_froms)
@@ -6166,9 +6198,22 @@ class SQLCompiler(Compiled):
)
text += "FROM "
- table_text = self.delete_table_clause(
- delete_stmt, delete_stmt.table, extra_froms
- )
+
+ try:
+ table_text = self.delete_table_clause(
+ delete_stmt,
+ delete_stmt.table,
+ extra_froms,
+ from_linter=from_linter,
+ )
+ except TypeError:
+ # anticipate 3rd party dialects that don't include **kw
+ # TODO: remove in 2.1
+ table_text = self.delete_table_clause(
+ delete_stmt, delete_stmt.table, extra_froms
+ )
+ if from_linter:
+ _ = self.process(delete_stmt.table, from_linter=from_linter)
crud._get_crud_params(self, delete_stmt, compile_state, toplevel, **kw)
@@ -6199,6 +6244,7 @@ class SQLCompiler(Compiled):
delete_stmt.table,
extra_froms,
dialect_hints,
+ from_linter=from_linter,
**kw,
)
if extra_from_text:
@@ -6206,7 +6252,7 @@ class SQLCompiler(Compiled):
if delete_stmt._where_criteria:
t = self._generate_delimited_and_list(
- delete_stmt._where_criteria, **kw
+ delete_stmt._where_criteria, from_linter=from_linter, **kw
)
if t:
text += " WHERE " + t
@@ -6224,6 +6270,10 @@ class SQLCompiler(Compiled):
nesting_level = len(self.stack) if not toplevel else None
text = self._render_cte_clause(nesting_level=nesting_level) + text
+ if warn_linting:
+ assert from_linter is not None
+ from_linter.warn(stmt_type="DELETE")
+
self.stack.pop(-1)
return text
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 563f61c04..16d5ce494 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -1344,7 +1344,7 @@ def _get_update_multitable_params(
):
normalized_params = {
coercions.expect(roles.DMLColumnRole, c): param
- for c, param in stmt_parameter_tuples
+ for c, param in stmt_parameter_tuples or ()
}
include_table = compile_state.include_table_with_column_exprs
diff --git a/test/sql/test_from_linter.py b/test/sql/test_from_linter.py
index 49370b1e6..9a471d571 100644
--- a/test/sql/test_from_linter.py
+++ b/test/sql/test_from_linter.py
@@ -1,4 +1,5 @@
from sqlalchemy import column
+from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import Integer
from sqlalchemy import JSON
@@ -7,6 +8,7 @@ from sqlalchemy import sql
from sqlalchemy import table
from sqlalchemy import testing
from sqlalchemy import true
+from sqlalchemy import update
from sqlalchemy.testing import config
from sqlalchemy.testing import engines
from sqlalchemy.testing import expect_warnings
@@ -382,18 +384,54 @@ class TestFindUnmatchingFroms(fixtures.TablesTest):
froms, start = find_unmatching_froms(query)
assert not froms
+ @testing.variation("dml", ["update", "delete"])
+ @testing.combinations(
+ (False, False), (True, False), (True, True), argnames="twotable,error"
+ )
+ def test_dml(self, dml, twotable, error):
+ if dml.update:
+ stmt = update(self.a)
+ elif dml.delete:
+ stmt = delete(self.a)
+ else:
+ dml.fail()
+
+ stmt = stmt.where(self.a.c.col_a == "a1")
+ if twotable:
+ stmt = stmt.where(self.b.c.col_b == "a1")
+
+ if not error:
+ stmt = stmt.where(self.b.c.col_b == self.a.c.col_a)
+
+ froms, _ = find_unmatching_froms(stmt)
+ if error:
+ assert froms
+ else:
+ assert not froms
+
+
+class TestLinterRoundTrip(fixtures.TablesTest):
+ __backend__ = True
-class TestLinter(fixtures.TablesTest):
@classmethod
def define_tables(cls, metadata):
- Table("table_a", metadata, Column("col_a", Integer, primary_key=True))
- Table("table_b", metadata, Column("col_b", Integer, primary_key=True))
+ Table(
+ "table_a",
+ metadata,
+ Column("col_a", Integer, primary_key=True, autoincrement=False),
+ )
+ Table(
+ "table_b",
+ metadata,
+ Column("col_b", Integer, primary_key=True, autoincrement=False),
+ )
@classmethod
def setup_bind(cls):
# from linting is enabled by default
return config.db
+ @testing.only_on("sqlite")
def test_noop_for_unhandled_objects(self):
with self.bind.connect() as conn:
conn.exec_driver_sql("SELECT 1;").fetchone()
@@ -429,6 +467,7 @@ class TestLinter(fixtures.TablesTest):
with self.bind.connect() as conn:
conn.execute(query)
+ @testing.requires.ctes
def test_warn_anon_cte(self):
a, b = self.tables("table_a", "table_b")
@@ -444,6 +483,47 @@ class TestLinter(fixtures.TablesTest):
with self.bind.connect() as conn:
conn.execute(query)
+ @testing.variation(
+ "dml",
+ [
+ ("update", testing.requires.update_from),
+ ("delete", testing.requires.delete_using),
+ ],
+ )
+ @testing.combinations(
+ (False, False), (True, False), (True, True), argnames="twotable,error"
+ )
+ def test_warn_dml(self, dml, twotable, error):
+ a, b = self.tables("table_a", "table_b")
+
+ if dml.update:
+ stmt = update(a).values(col_a=5)
+ elif dml.delete:
+ stmt = delete(a)
+ else:
+ dml.fail()
+
+ stmt = stmt.where(a.c.col_a == 1)
+ if twotable:
+ stmt = stmt.where(b.c.col_b == 1)
+
+ if not error:
+ stmt = stmt.where(b.c.col_b == a.c.col_a)
+
+ stmt_type = "UPDATE" if dml.update else "DELETE"
+
+ with self.bind.connect() as conn:
+ if error:
+ with expect_warnings(
+ rf"{stmt_type} statement has a cartesian product between "
+ rf'FROM element\(s\) "table_[ab]" and FROM '
+ rf'element "table_[ab]"'
+ ):
+ with self.bind.connect() as conn:
+ conn.execute(stmt)
+ else:
+ conn.execute(stmt)
+
def test_no_linting(self, metadata, connection):
eng = engines.testing_engine(
options={"enable_from_linting": False, "use_reaper": False}
diff --git a/test/sql/test_update.py b/test/sql/test_update.py
index ef8f117bc..d8de5c277 100644
--- a/test/sql/test_update.py
+++ b/test/sql/test_update.py
@@ -113,6 +113,59 @@ class _UpdateFromTestBase:
class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL):
__dialect__ = "default_enhanced"
+ @testing.variation("twotable", [True, False])
+ @testing.variation("values", ["none", "blank"])
+ def test_update_no_params(self, values, twotable):
+ """test issue identified while doing #9721
+
+ UPDATE with empty VALUES but multiple tables would raise a
+ NoneType error; fixed this to emit an empty "SET" the way a single
+ table UPDATE currently does.
+
+ both cases should probably raise CompileError, however this could
+ be backwards incompatible with current use cases (such as other test
+ suites)
+
+ """
+
+ table1 = self.tables.mytable
+ table2 = self.tables.myothertable
+
+ stmt = table1.update().where(table1.c.name == "jill")
+ if twotable:
+ stmt = stmt.where(table2.c.otherid == table1.c.myid)
+
+ if values.blank:
+ stmt = stmt.values()
+
+ if twotable:
+ if values.blank:
+ self.assert_compile(
+ stmt,
+ "UPDATE mytable SET FROM myothertable "
+ "WHERE mytable.name = :name_1 "
+ "AND myothertable.otherid = mytable.myid",
+ )
+ elif values.none:
+ self.assert_compile(
+ stmt,
+ "UPDATE mytable SET myid=:myid, name=:name, "
+ "description=:description FROM myothertable "
+ "WHERE mytable.name = :name_1 "
+ "AND myothertable.otherid = mytable.myid",
+ )
+ elif values.blank:
+ self.assert_compile(
+ stmt,
+ "UPDATE mytable SET WHERE mytable.name = :name_1",
+ )
+ elif values.none:
+ self.assert_compile(
+ stmt,
+ "UPDATE mytable SET myid=:myid, name=:name, "
+ "description=:description WHERE mytable.name = :name_1",
+ )
+
def test_update_literal_binds(self):
table1 = self.tables.mytable