diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2022-02-25 17:48:30 +0000 |
|---|---|---|
| committer | Gerrit Code Review <gerrit@ci3.zzzcomputing.com> | 2022-02-25 17:48:30 +0000 |
| commit | 60fca2ac8cf44bdaf68552ab5c69854a6776c73c (patch) | |
| tree | 2b9bc005223c6c58009762dc120fccf309c1ba92 /test/sql | |
| parent | 2d97c388eae4345840f745337ec033045651b36d (diff) | |
| parent | 0fe8f4a3e79c8fc805e7a84849920c7258177f41 (diff) | |
| download | sqlalchemy-60fca2ac8cf44bdaf68552ab5c69854a6776c73c.tar.gz | |
Merge "Add more nesting features to add_cte()" into main
Diffstat (limited to 'test/sql')
| -rw-r--r-- | test/sql/test_cte.py | 293 |
1 files changed, 293 insertions, 0 deletions
diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index b05692504..2ee6fa9f3 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -1,11 +1,13 @@ from sqlalchemy import Column from sqlalchemy import delete +from sqlalchemy import exc from sqlalchemy import Integer from sqlalchemy import LABEL_STYLE_TABLENAME_PLUS_COL from sqlalchemy import MetaData from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy import text +from sqlalchemy import true from sqlalchemy import update from sqlalchemy.dialects import mssql from sqlalchemy.engine import default @@ -25,6 +27,7 @@ from sqlalchemy.sql.visitors import cloned_traverse from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures @@ -1869,6 +1872,21 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): "SELECT cte.outer_cte FROM cte", ) + def test_select_with_nesting_cte_in_cte_w_add_cte(self): + nesting_cte = select(literal(1).label("inner_cte")).cte("nesting") + stmt = select( + select(nesting_cte.c.inner_cte.label("outer_cte")) + .add_cte(nesting_cte, nest_here=True) + .cte("cte") + ) + + self.assert_compile( + stmt, + "WITH cte AS (WITH nesting AS (SELECT :param_1 AS inner_cte) " + "SELECT nesting.inner_cte AS outer_cte FROM nesting) " + "SELECT cte.outer_cte FROM cte", + ) + def test_select_with_aliased_nesting_cte_in_cte(self): nesting_cte = ( select(literal(1).label("inner_cte")) @@ -1887,6 +1905,25 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): "SELECT cte.outer_cte FROM cte", ) + def test_select_with_aliased_nesting_cte_in_cte_w_add_cte(self): + inner_nesting_cte = select(literal(1).label("inner_cte")).cte( + "nesting" + ) + outer_cte = select().add_cte(inner_nesting_cte, nest_here=True) + nesting_cte = inner_nesting_cte.alias("aliased_nested") + outer_cte = outer_cte.add_columns( + nesting_cte.c.inner_cte.label("outer_cte") + ).cte("cte") + stmt = select(outer_cte) + + self.assert_compile( + stmt, + "WITH cte AS (WITH nesting AS (SELECT :param_1 AS inner_cte) " + "SELECT aliased_nested.inner_cte AS outer_cte " + "FROM nesting AS aliased_nested) " + "SELECT cte.outer_cte FROM cte", + ) + def test_nesting_cte_in_cte_with_same_name(self): nesting_cte = select(literal(1).label("inner_cte")).cte( "some_cte", nesting=True @@ -1904,6 +1941,23 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): "SELECT some_cte.outer_cte FROM some_cte", ) + def test_nesting_cte_in_cte_with_same_name_w_add_cte(self): + nesting_cte = select(literal(1).label("inner_cte")).cte("some_cte") + stmt = select( + select(nesting_cte.c.inner_cte.label("outer_cte")) + .add_cte(nesting_cte, nest_here=True) + .cte("some_cte") + ) + + self.assert_compile( + stmt, + "WITH some_cte AS (WITH some_cte AS " + "(SELECT :param_1 AS inner_cte) " + "SELECT some_cte.inner_cte AS outer_cte " + "FROM some_cte) " + "SELECT some_cte.outer_cte FROM some_cte", + ) + def test_nesting_cte_at_top_level(self): nesting_cte = select(literal(1).label("val")).cte( "nesting_cte", nesting=True @@ -1918,6 +1972,20 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): " SELECT nesting_cte.val, cte.val AS val_1 FROM nesting_cte, cte", ) + def test_nesting_cte_at_top_level_w_add_cte(self): + nesting_cte = select(literal(1).label("val")).cte("nesting_cte") + cte = select(literal(2).label("val")).cte("cte") + stmt = select(nesting_cte.c.val, cte.c.val).add_cte( + nesting_cte, nest_here=True + ) + + self.assert_compile( + stmt, + "WITH nesting_cte AS (SELECT :param_1 AS val)" + ", cte AS (SELECT :param_2 AS val)" + " SELECT nesting_cte.val, cte.val AS val_1 FROM nesting_cte, cte", + ) + def test_double_nesting_cte_in_cte(self): """ Validate that the SELECT in the 2nd nesting CTE does not render @@ -1950,6 +2018,36 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): ") SELECT cte.outer_1, cte.outer_2 FROM cte", ) + def test_double_nesting_cte_in_cte_w_add_cte(self): + """ + Validate that the SELECT in the 2nd nesting CTE does not render + the 1st CTE. + + It implies that nesting CTE level is taken in account. + """ + select_1_cte = select(literal(1).label("inner_cte")).cte("nesting_1") + select_2_cte = select(literal(2).label("inner_cte")).cte("nesting_2") + + stmt = select( + select( + select_1_cte.c.inner_cte.label("outer_1"), + select_2_cte.c.inner_cte.label("outer_2"), + ) + .add_cte(select_1_cte, select_2_cte, nest_here=True) + .cte("cte") + ) + + self.assert_compile( + stmt, + "WITH cte AS (" + "WITH nesting_1 AS (SELECT :param_1 AS inner_cte)" + ", nesting_2 AS (SELECT :param_2 AS inner_cte)" + " SELECT nesting_1.inner_cte AS outer_1" + ", nesting_2.inner_cte AS outer_2" + " FROM nesting_1, nesting_2" + ") SELECT cte.outer_1, cte.outer_2 FROM cte", + ) + def test_double_nesting_cte_with_cross_reference_in_cte(self): select_1_cte = select(literal(1).label("inner_cte_1")).cte( "nesting_1", nesting=True @@ -1993,6 +2091,32 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): ") SELECT cte.inner_cte_2, cte.inner_cte_1 FROM cte", ) + def test_double_nesting_cte_with_cross_reference_in_cte_w_add_cte(self): + select_1_cte = select(literal(1).label("inner_cte_1")).cte("nesting_1") + select_2_cte = select( + (select_1_cte.c.inner_cte_1 + 1).label("inner_cte_2") + ).cte("nesting_2") + + # 1 next 2 + + nesting_cte_1_2 = ( + select(select_1_cte, select_2_cte) + .add_cte(select_1_cte, select_2_cte, nest_here=True) + .cte("cte") + ) + stmt_1_2 = select(nesting_cte_1_2) + self.assert_compile( + stmt_1_2, + "WITH cte AS (" + "WITH nesting_1 AS (SELECT :param_1 AS inner_cte_1)" + ", nesting_2 AS (SELECT nesting_1.inner_cte_1 + :inner_cte_1_1" + " AS inner_cte_2 FROM nesting_1)" + " SELECT nesting_1.inner_cte_1 AS inner_cte_1" + ", nesting_2.inner_cte_2 AS inner_cte_2" + " FROM nesting_1, nesting_2" + ") SELECT cte.inner_cte_1, cte.inner_cte_2 FROM cte", + ) + def test_nesting_cte_in_nesting_cte_in_cte(self): select_1_cte = select(literal(1).label("inner_cte")).cte( "nesting_1", nesting=True @@ -2069,6 +2193,31 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): "SELECT rec_cte.outer_cte FROM rec_cte", ) + def test_nesting_cte_in_recursive_cte_w_add_cte(self): + nesting_cte = select(literal(1).label("inner_cte")).cte( + "nesting", nesting=True + ) + + rec_cte = select(nesting_cte.c.inner_cte.label("outer_cte")).cte( + "rec_cte", recursive=True + ) + rec_part = select(rec_cte.c.outer_cte).where( + rec_cte.c.outer_cte == literal(1) + ) + rec_cte = rec_cte.union(rec_part) + + stmt = select(rec_cte) + + self.assert_compile( + stmt, + "WITH RECURSIVE rec_cte(outer_cte) AS (WITH nesting AS " + "(SELECT :param_1 AS inner_cte) " + "SELECT nesting.inner_cte AS outer_cte FROM nesting UNION " + "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte " + "WHERE rec_cte.outer_cte = :param_2) " + "SELECT rec_cte.outer_cte FROM rec_cte", + ) + def test_recursive_nesting_cte_in_cte(self): rec_root = select(literal(1).label("inner_cte")).cte( "nesting", recursive=True, nesting=True @@ -2209,6 +2358,80 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): "FROM nesting_cte", ) + def test_add_cte_dont_nest_in_two_places(self): + nesting_cte_used_twice = select(literal(1).label("inner_cte_1")).cte( + "nesting_cte" + ) + select_add_cte = select( + (nesting_cte_used_twice.c.inner_cte_1 + 1).label("next_value") + ).cte("nesting_2") + + union_cte = ( + select( + (nesting_cte_used_twice.c.inner_cte_1 - 1).label("next_value") + ) + .add_cte(nesting_cte_used_twice, nest_here=True) + .union( + select(select_add_cte).add_cte(select_add_cte, nest_here=True) + ) + .cte("wrapper") + ) + + stmt = ( + select(union_cte) + .add_cte(nesting_cte_used_twice, nest_here=True) + .union(select(nesting_cte_used_twice)) + ) + with expect_raises_message( + exc.CompileError, + "CTE is stated as 'nest_here' in more than one location", + ): + stmt.compile() + + def test_same_nested_cte_is_not_generated_twice_w_add_cte(self): + # Same = name and query + nesting_cte_used_twice = select(literal(1).label("inner_cte_1")).cte( + "nesting_cte" + ) + select_add_cte = select( + (nesting_cte_used_twice.c.inner_cte_1 + 1).label("next_value") + ).cte("nesting_2") + + union_cte = ( + select( + (nesting_cte_used_twice.c.inner_cte_1 - 1).label("next_value") + ) + .add_cte(nesting_cte_used_twice) + .union( + select(select_add_cte).add_cte(select_add_cte, nest_here=True) + ) + .cte("wrapper") + ) + + stmt = ( + select(union_cte) + .add_cte(nesting_cte_used_twice, nest_here=True) + .union(select(nesting_cte_used_twice)) + ) + + self.assert_compile( + stmt, + "WITH nesting_cte AS " + "(SELECT :param_1 AS inner_cte_1)" + ", wrapper AS " + "(WITH nesting_2 AS " + "(SELECT nesting_cte.inner_cte_1 + :inner_cte_1_2 " + "AS next_value " + "FROM nesting_cte)" + " SELECT nesting_cte.inner_cte_1 - :inner_cte_1_1 " + "AS next_value " + "FROM nesting_cte UNION SELECT nesting_2.next_value AS next_value " + "FROM nesting_2)" + " SELECT wrapper.next_value " + "FROM wrapper UNION SELECT nesting_cte.inner_cte_1 " + "FROM nesting_cte", + ) + def test_recursive_nesting_cte_in_recursive_cte(self): nesting_cte = select(literal(1).label("inner_cte")).cte( "nesting", nesting=True, recursive=True @@ -2363,6 +2586,36 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): ") SELECT cte.outer_cte FROM cte", ) + def test_compound_select_with_nesting_cte_in_custom_order_w_add_cte(self): + select_1_cte = select(literal(1).label("inner_cte")).cte("nesting_1") + select_2_cte = select(literal(2).label("inner_cte")).cte("nesting_2") + + nesting_cte = ( + select(select_1_cte) + .add_cte(select_1_cte, nest_here=True) + .union(select(select_2_cte)) + # Generate "select_2_cte" first + .add_cte(select_2_cte, nest_here=True) + .subquery() + ) + + stmt = select( + select(nesting_cte.c.inner_cte.label("outer_cte")).cte("cte") + ) + + self.assert_compile( + stmt, + "WITH cte AS (" + "SELECT anon_1.inner_cte AS outer_cte FROM (" + "WITH nesting_2 AS (SELECT :param_1 AS inner_cte)" + ", nesting_1 AS (SELECT :param_2 AS inner_cte)" + " SELECT nesting_1.inner_cte AS inner_cte FROM nesting_1" + " UNION" + " SELECT nesting_2.inner_cte AS inner_cte FROM nesting_2" + ") AS anon_1" + ") SELECT cte.outer_cte FROM cte", + ) + def test_recursive_cte_referenced_multiple_times_with_nesting_cte(self): rec_root = select(literal(1).label("the_value")).cte( "recursive_cte", recursive=True @@ -2411,3 +2664,43 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): " WHERE should_continue.val != true))" " SELECT recursive_cte.the_value FROM recursive_cte", ) + + @testing.combinations(True, False) + def test_correlated_cte_in_lateral_w_add_cte(self, reverse_direction): + """this is the original use case that led to #7759""" + contracts = table("contracts", column("id")) + + invoices = table("invoices", column("id"), column("contract_id")) + + contracts_alias = contracts.alias() + cte1 = ( + select(contracts_alias) + .where(contracts_alias.c.id == contracts.c.id) + .correlate(contracts) + .cte(name="cte1") + ) + cte2 = ( + select(invoices) + .join(cte1, invoices.c.contract_id == cte1.c.id) + .cte(name="cte2") + ) + + if reverse_direction: + subq = select(cte1, cte2).add_cte(cte2, cte1, nest_here=True) + else: + subq = select(cte1, cte2).add_cte(cte1, cte2, nest_here=True) + stmt = select(contracts).outerjoin(subq.lateral(), true()) + + self.assert_compile( + stmt, + "SELECT contracts.id FROM contracts LEFT OUTER JOIN LATERAL " + "(WITH cte1 AS (SELECT contracts_1.id AS id " + "FROM contracts AS contracts_1 " + "WHERE contracts_1.id = contracts.id), " + "cte2 AS (SELECT invoices.id AS id, " + "invoices.contract_id AS contract_id FROM invoices " + "JOIN cte1 ON invoices.contract_id = cte1.id) " + "SELECT cte1.id AS id, cte2.id AS id_1, " + "cte2.contract_id AS contract_id " + "FROM cte1, cte2) AS anon_1 ON true", + ) |
