summaryrefslogtreecommitdiff
path: root/test/sql
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2022-02-25 17:48:30 +0000
committerGerrit Code Review <gerrit@ci3.zzzcomputing.com>2022-02-25 17:48:30 +0000
commit60fca2ac8cf44bdaf68552ab5c69854a6776c73c (patch)
tree2b9bc005223c6c58009762dc120fccf309c1ba92 /test/sql
parent2d97c388eae4345840f745337ec033045651b36d (diff)
parent0fe8f4a3e79c8fc805e7a84849920c7258177f41 (diff)
downloadsqlalchemy-60fca2ac8cf44bdaf68552ab5c69854a6776c73c.tar.gz
Merge "Add more nesting features to add_cte()" into main
Diffstat (limited to 'test/sql')
-rw-r--r--test/sql/test_cte.py293
1 files changed, 293 insertions, 0 deletions
diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py
index b05692504..2ee6fa9f3 100644
--- a/test/sql/test_cte.py
+++ b/test/sql/test_cte.py
@@ -1,11 +1,13 @@
from sqlalchemy import Column
from sqlalchemy import delete
+from sqlalchemy import exc
from sqlalchemy import Integer
from sqlalchemy import LABEL_STYLE_TABLENAME_PLUS_COL
from sqlalchemy import MetaData
from sqlalchemy import Table
from sqlalchemy import testing
from sqlalchemy import text
+from sqlalchemy import true
from sqlalchemy import update
from sqlalchemy.dialects import mssql
from sqlalchemy.engine import default
@@ -25,6 +27,7 @@ from sqlalchemy.sql.visitors import cloned_traverse
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import fixtures
@@ -1869,6 +1872,21 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
"SELECT cte.outer_cte FROM cte",
)
+ def test_select_with_nesting_cte_in_cte_w_add_cte(self):
+ nesting_cte = select(literal(1).label("inner_cte")).cte("nesting")
+ stmt = select(
+ select(nesting_cte.c.inner_cte.label("outer_cte"))
+ .add_cte(nesting_cte, nest_here=True)
+ .cte("cte")
+ )
+
+ self.assert_compile(
+ stmt,
+ "WITH cte AS (WITH nesting AS (SELECT :param_1 AS inner_cte) "
+ "SELECT nesting.inner_cte AS outer_cte FROM nesting) "
+ "SELECT cte.outer_cte FROM cte",
+ )
+
def test_select_with_aliased_nesting_cte_in_cte(self):
nesting_cte = (
select(literal(1).label("inner_cte"))
@@ -1887,6 +1905,25 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
"SELECT cte.outer_cte FROM cte",
)
+ def test_select_with_aliased_nesting_cte_in_cte_w_add_cte(self):
+ inner_nesting_cte = select(literal(1).label("inner_cte")).cte(
+ "nesting"
+ )
+ outer_cte = select().add_cte(inner_nesting_cte, nest_here=True)
+ nesting_cte = inner_nesting_cte.alias("aliased_nested")
+ outer_cte = outer_cte.add_columns(
+ nesting_cte.c.inner_cte.label("outer_cte")
+ ).cte("cte")
+ stmt = select(outer_cte)
+
+ self.assert_compile(
+ stmt,
+ "WITH cte AS (WITH nesting AS (SELECT :param_1 AS inner_cte) "
+ "SELECT aliased_nested.inner_cte AS outer_cte "
+ "FROM nesting AS aliased_nested) "
+ "SELECT cte.outer_cte FROM cte",
+ )
+
def test_nesting_cte_in_cte_with_same_name(self):
nesting_cte = select(literal(1).label("inner_cte")).cte(
"some_cte", nesting=True
@@ -1904,6 +1941,23 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
"SELECT some_cte.outer_cte FROM some_cte",
)
+ def test_nesting_cte_in_cte_with_same_name_w_add_cte(self):
+ nesting_cte = select(literal(1).label("inner_cte")).cte("some_cte")
+ stmt = select(
+ select(nesting_cte.c.inner_cte.label("outer_cte"))
+ .add_cte(nesting_cte, nest_here=True)
+ .cte("some_cte")
+ )
+
+ self.assert_compile(
+ stmt,
+ "WITH some_cte AS (WITH some_cte AS "
+ "(SELECT :param_1 AS inner_cte) "
+ "SELECT some_cte.inner_cte AS outer_cte "
+ "FROM some_cte) "
+ "SELECT some_cte.outer_cte FROM some_cte",
+ )
+
def test_nesting_cte_at_top_level(self):
nesting_cte = select(literal(1).label("val")).cte(
"nesting_cte", nesting=True
@@ -1918,6 +1972,20 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
" SELECT nesting_cte.val, cte.val AS val_1 FROM nesting_cte, cte",
)
+ def test_nesting_cte_at_top_level_w_add_cte(self):
+ nesting_cte = select(literal(1).label("val")).cte("nesting_cte")
+ cte = select(literal(2).label("val")).cte("cte")
+ stmt = select(nesting_cte.c.val, cte.c.val).add_cte(
+ nesting_cte, nest_here=True
+ )
+
+ self.assert_compile(
+ stmt,
+ "WITH nesting_cte AS (SELECT :param_1 AS val)"
+ ", cte AS (SELECT :param_2 AS val)"
+ " SELECT nesting_cte.val, cte.val AS val_1 FROM nesting_cte, cte",
+ )
+
def test_double_nesting_cte_in_cte(self):
"""
Validate that the SELECT in the 2nd nesting CTE does not render
@@ -1950,6 +2018,36 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
") SELECT cte.outer_1, cte.outer_2 FROM cte",
)
+ def test_double_nesting_cte_in_cte_w_add_cte(self):
+ """
+ Validate that the SELECT in the 2nd nesting CTE does not render
+ the 1st CTE.
+
+ It implies that nesting CTE level is taken in account.
+ """
+ select_1_cte = select(literal(1).label("inner_cte")).cte("nesting_1")
+ select_2_cte = select(literal(2).label("inner_cte")).cte("nesting_2")
+
+ stmt = select(
+ select(
+ select_1_cte.c.inner_cte.label("outer_1"),
+ select_2_cte.c.inner_cte.label("outer_2"),
+ )
+ .add_cte(select_1_cte, select_2_cte, nest_here=True)
+ .cte("cte")
+ )
+
+ self.assert_compile(
+ stmt,
+ "WITH cte AS ("
+ "WITH nesting_1 AS (SELECT :param_1 AS inner_cte)"
+ ", nesting_2 AS (SELECT :param_2 AS inner_cte)"
+ " SELECT nesting_1.inner_cte AS outer_1"
+ ", nesting_2.inner_cte AS outer_2"
+ " FROM nesting_1, nesting_2"
+ ") SELECT cte.outer_1, cte.outer_2 FROM cte",
+ )
+
def test_double_nesting_cte_with_cross_reference_in_cte(self):
select_1_cte = select(literal(1).label("inner_cte_1")).cte(
"nesting_1", nesting=True
@@ -1993,6 +2091,32 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
") SELECT cte.inner_cte_2, cte.inner_cte_1 FROM cte",
)
+ def test_double_nesting_cte_with_cross_reference_in_cte_w_add_cte(self):
+ select_1_cte = select(literal(1).label("inner_cte_1")).cte("nesting_1")
+ select_2_cte = select(
+ (select_1_cte.c.inner_cte_1 + 1).label("inner_cte_2")
+ ).cte("nesting_2")
+
+ # 1 next 2
+
+ nesting_cte_1_2 = (
+ select(select_1_cte, select_2_cte)
+ .add_cte(select_1_cte, select_2_cte, nest_here=True)
+ .cte("cte")
+ )
+ stmt_1_2 = select(nesting_cte_1_2)
+ self.assert_compile(
+ stmt_1_2,
+ "WITH cte AS ("
+ "WITH nesting_1 AS (SELECT :param_1 AS inner_cte_1)"
+ ", nesting_2 AS (SELECT nesting_1.inner_cte_1 + :inner_cte_1_1"
+ " AS inner_cte_2 FROM nesting_1)"
+ " SELECT nesting_1.inner_cte_1 AS inner_cte_1"
+ ", nesting_2.inner_cte_2 AS inner_cte_2"
+ " FROM nesting_1, nesting_2"
+ ") SELECT cte.inner_cte_1, cte.inner_cte_2 FROM cte",
+ )
+
def test_nesting_cte_in_nesting_cte_in_cte(self):
select_1_cte = select(literal(1).label("inner_cte")).cte(
"nesting_1", nesting=True
@@ -2069,6 +2193,31 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
"SELECT rec_cte.outer_cte FROM rec_cte",
)
+ def test_nesting_cte_in_recursive_cte_w_add_cte(self):
+ nesting_cte = select(literal(1).label("inner_cte")).cte(
+ "nesting", nesting=True
+ )
+
+ rec_cte = select(nesting_cte.c.inner_cte.label("outer_cte")).cte(
+ "rec_cte", recursive=True
+ )
+ rec_part = select(rec_cte.c.outer_cte).where(
+ rec_cte.c.outer_cte == literal(1)
+ )
+ rec_cte = rec_cte.union(rec_part)
+
+ stmt = select(rec_cte)
+
+ self.assert_compile(
+ stmt,
+ "WITH RECURSIVE rec_cte(outer_cte) AS (WITH nesting AS "
+ "(SELECT :param_1 AS inner_cte) "
+ "SELECT nesting.inner_cte AS outer_cte FROM nesting UNION "
+ "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte "
+ "WHERE rec_cte.outer_cte = :param_2) "
+ "SELECT rec_cte.outer_cte FROM rec_cte",
+ )
+
def test_recursive_nesting_cte_in_cte(self):
rec_root = select(literal(1).label("inner_cte")).cte(
"nesting", recursive=True, nesting=True
@@ -2209,6 +2358,80 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
"FROM nesting_cte",
)
+ def test_add_cte_dont_nest_in_two_places(self):
+ nesting_cte_used_twice = select(literal(1).label("inner_cte_1")).cte(
+ "nesting_cte"
+ )
+ select_add_cte = select(
+ (nesting_cte_used_twice.c.inner_cte_1 + 1).label("next_value")
+ ).cte("nesting_2")
+
+ union_cte = (
+ select(
+ (nesting_cte_used_twice.c.inner_cte_1 - 1).label("next_value")
+ )
+ .add_cte(nesting_cte_used_twice, nest_here=True)
+ .union(
+ select(select_add_cte).add_cte(select_add_cte, nest_here=True)
+ )
+ .cte("wrapper")
+ )
+
+ stmt = (
+ select(union_cte)
+ .add_cte(nesting_cte_used_twice, nest_here=True)
+ .union(select(nesting_cte_used_twice))
+ )
+ with expect_raises_message(
+ exc.CompileError,
+ "CTE is stated as 'nest_here' in more than one location",
+ ):
+ stmt.compile()
+
+ def test_same_nested_cte_is_not_generated_twice_w_add_cte(self):
+ # Same = name and query
+ nesting_cte_used_twice = select(literal(1).label("inner_cte_1")).cte(
+ "nesting_cte"
+ )
+ select_add_cte = select(
+ (nesting_cte_used_twice.c.inner_cte_1 + 1).label("next_value")
+ ).cte("nesting_2")
+
+ union_cte = (
+ select(
+ (nesting_cte_used_twice.c.inner_cte_1 - 1).label("next_value")
+ )
+ .add_cte(nesting_cte_used_twice)
+ .union(
+ select(select_add_cte).add_cte(select_add_cte, nest_here=True)
+ )
+ .cte("wrapper")
+ )
+
+ stmt = (
+ select(union_cte)
+ .add_cte(nesting_cte_used_twice, nest_here=True)
+ .union(select(nesting_cte_used_twice))
+ )
+
+ self.assert_compile(
+ stmt,
+ "WITH nesting_cte AS "
+ "(SELECT :param_1 AS inner_cte_1)"
+ ", wrapper AS "
+ "(WITH nesting_2 AS "
+ "(SELECT nesting_cte.inner_cte_1 + :inner_cte_1_2 "
+ "AS next_value "
+ "FROM nesting_cte)"
+ " SELECT nesting_cte.inner_cte_1 - :inner_cte_1_1 "
+ "AS next_value "
+ "FROM nesting_cte UNION SELECT nesting_2.next_value AS next_value "
+ "FROM nesting_2)"
+ " SELECT wrapper.next_value "
+ "FROM wrapper UNION SELECT nesting_cte.inner_cte_1 "
+ "FROM nesting_cte",
+ )
+
def test_recursive_nesting_cte_in_recursive_cte(self):
nesting_cte = select(literal(1).label("inner_cte")).cte(
"nesting", nesting=True, recursive=True
@@ -2363,6 +2586,36 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
") SELECT cte.outer_cte FROM cte",
)
+ def test_compound_select_with_nesting_cte_in_custom_order_w_add_cte(self):
+ select_1_cte = select(literal(1).label("inner_cte")).cte("nesting_1")
+ select_2_cte = select(literal(2).label("inner_cte")).cte("nesting_2")
+
+ nesting_cte = (
+ select(select_1_cte)
+ .add_cte(select_1_cte, nest_here=True)
+ .union(select(select_2_cte))
+ # Generate "select_2_cte" first
+ .add_cte(select_2_cte, nest_here=True)
+ .subquery()
+ )
+
+ stmt = select(
+ select(nesting_cte.c.inner_cte.label("outer_cte")).cte("cte")
+ )
+
+ self.assert_compile(
+ stmt,
+ "WITH cte AS ("
+ "SELECT anon_1.inner_cte AS outer_cte FROM ("
+ "WITH nesting_2 AS (SELECT :param_1 AS inner_cte)"
+ ", nesting_1 AS (SELECT :param_2 AS inner_cte)"
+ " SELECT nesting_1.inner_cte AS inner_cte FROM nesting_1"
+ " UNION"
+ " SELECT nesting_2.inner_cte AS inner_cte FROM nesting_2"
+ ") AS anon_1"
+ ") SELECT cte.outer_cte FROM cte",
+ )
+
def test_recursive_cte_referenced_multiple_times_with_nesting_cte(self):
rec_root = select(literal(1).label("the_value")).cte(
"recursive_cte", recursive=True
@@ -2411,3 +2664,43 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
" WHERE should_continue.val != true))"
" SELECT recursive_cte.the_value FROM recursive_cte",
)
+
+ @testing.combinations(True, False)
+ def test_correlated_cte_in_lateral_w_add_cte(self, reverse_direction):
+ """this is the original use case that led to #7759"""
+ contracts = table("contracts", column("id"))
+
+ invoices = table("invoices", column("id"), column("contract_id"))
+
+ contracts_alias = contracts.alias()
+ cte1 = (
+ select(contracts_alias)
+ .where(contracts_alias.c.id == contracts.c.id)
+ .correlate(contracts)
+ .cte(name="cte1")
+ )
+ cte2 = (
+ select(invoices)
+ .join(cte1, invoices.c.contract_id == cte1.c.id)
+ .cte(name="cte2")
+ )
+
+ if reverse_direction:
+ subq = select(cte1, cte2).add_cte(cte2, cte1, nest_here=True)
+ else:
+ subq = select(cte1, cte2).add_cte(cte1, cte2, nest_here=True)
+ stmt = select(contracts).outerjoin(subq.lateral(), true())
+
+ self.assert_compile(
+ stmt,
+ "SELECT contracts.id FROM contracts LEFT OUTER JOIN LATERAL "
+ "(WITH cte1 AS (SELECT contracts_1.id AS id "
+ "FROM contracts AS contracts_1 "
+ "WHERE contracts_1.id = contracts.id), "
+ "cte2 AS (SELECT invoices.id AS id, "
+ "invoices.contract_id AS contract_id FROM invoices "
+ "JOIN cte1 ON invoices.contract_id = cte1.id) "
+ "SELECT cte1.id AS id, cte2.id AS id_1, "
+ "cte2.contract_id AS contract_id "
+ "FROM cte1, cte2) AS anon_1 ON true",
+ )