diff options
Diffstat (limited to 'test/sql')
| -rw-r--r-- | test/sql/test_from_linter.py | 22 | ||||
| -rw-r--r-- | test/sql/test_functions.py | 26 |
2 files changed, 43 insertions, 5 deletions
diff --git a/test/sql/test_from_linter.py b/test/sql/test_from_linter.py index 4a4d907f9..1fa3aff36 100644 --- a/test/sql/test_from_linter.py +++ b/test/sql/test_from_linter.py @@ -165,8 +165,15 @@ class TestFindUnmatchingFroms(fixtures.TablesTest): assert start is p3 assert froms == {p1} + @testing.combinations( + "render_derived", "alias", None, argnames="additional_transformation" + ) @testing.combinations(True, False, argnames="joins_implicitly") - def test_table_valued(self, joins_implicitly): + def test_table_valued( + self, + joins_implicitly, + additional_transformation, + ): """test #7845""" my_table = table( "tbl", @@ -175,9 +182,16 @@ class TestFindUnmatchingFroms(fixtures.TablesTest): ) sub_dict = my_table.c.data["d"] - tv = func.json_each(sub_dict).table_valued( - "key", joins_implicitly=joins_implicitly - ) + + tv = func.json_each(sub_dict) + + tv = tv.table_valued("key", joins_implicitly=joins_implicitly) + + if additional_transformation == "render_derived": + tv = tv.render_derived(name="tv", with_types=True) + elif additional_transformation == "alias": + tv = tv.alias() + has_key = tv.c.key == "f" stmt = select(my_table.c.id).where(has_key) froms, start = find_unmatching_froms(stmt, my_table) diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index e08526419..c055bc150 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -26,7 +26,6 @@ from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy import Text from sqlalchemy import true -from sqlalchemy import types as sqltypes from sqlalchemy.dialects import mysql from sqlalchemy.dialects import oracle from sqlalchemy.dialects import postgresql @@ -37,6 +36,7 @@ from sqlalchemy.sql import functions from sqlalchemy.sql import LABEL_STYLE_TABLENAME_PLUS_COL from sqlalchemy.sql import operators from sqlalchemy.sql import quoted_name +from sqlalchemy.sql import sqltypes from sqlalchemy.sql import table from sqlalchemy.sql.compiler import BIND_TEMPLATES from sqlalchemy.sql.functions import FunctionElement @@ -1425,6 +1425,30 @@ class TableValuedCompileTest(fixtures.TestBase, AssertsCompiledSQL): "LEFT OUTER JOIN b ON unnested.unnested = b.ref", ) + def test_render_derived_maintains_tableval_type(self): + fn = func.json_something() + + tv = fn.table_valued(column("x", String)) + + eq_(tv.column.type, testing.eq_type_affinity(sqltypes.TableValueType)) + eq_(tv.column.type._elements[0].type, testing.eq_type_affinity(String)) + + tv = tv.render_derived() + eq_(tv.column.type, testing.eq_type_affinity(sqltypes.TableValueType)) + eq_(tv.column.type._elements[0].type, testing.eq_type_affinity(String)) + + def test_alias_maintains_tableval_type(self): + fn = func.json_something() + + tv = fn.table_valued(column("x", String)) + + eq_(tv.column.type, testing.eq_type_affinity(sqltypes.TableValueType)) + eq_(tv.column.type._elements[0].type, testing.eq_type_affinity(String)) + + tv = tv.alias() + eq_(tv.column.type, testing.eq_type_affinity(sqltypes.TableValueType)) + eq_(tv.column.type._elements[0].type, testing.eq_type_affinity(String)) + def test_star_with_ordinality(self): """ SELECT * FROM generate_series(4,1,-1) WITH ORDINALITY; |
