summaryrefslogtreecommitdiff
path: root/test/sql
diff options
context:
space:
mode:
Diffstat (limited to 'test/sql')
-rw-r--r--test/sql/test_from_linter.py22
-rw-r--r--test/sql/test_functions.py26
2 files changed, 43 insertions, 5 deletions
diff --git a/test/sql/test_from_linter.py b/test/sql/test_from_linter.py
index 4a4d907f9..1fa3aff36 100644
--- a/test/sql/test_from_linter.py
+++ b/test/sql/test_from_linter.py
@@ -165,8 +165,15 @@ class TestFindUnmatchingFroms(fixtures.TablesTest):
assert start is p3
assert froms == {p1}
+ @testing.combinations(
+ "render_derived", "alias", None, argnames="additional_transformation"
+ )
@testing.combinations(True, False, argnames="joins_implicitly")
- def test_table_valued(self, joins_implicitly):
+ def test_table_valued(
+ self,
+ joins_implicitly,
+ additional_transformation,
+ ):
"""test #7845"""
my_table = table(
"tbl",
@@ -175,9 +182,16 @@ class TestFindUnmatchingFroms(fixtures.TablesTest):
)
sub_dict = my_table.c.data["d"]
- tv = func.json_each(sub_dict).table_valued(
- "key", joins_implicitly=joins_implicitly
- )
+
+ tv = func.json_each(sub_dict)
+
+ tv = tv.table_valued("key", joins_implicitly=joins_implicitly)
+
+ if additional_transformation == "render_derived":
+ tv = tv.render_derived(name="tv", with_types=True)
+ elif additional_transformation == "alias":
+ tv = tv.alias()
+
has_key = tv.c.key == "f"
stmt = select(my_table.c.id).where(has_key)
froms, start = find_unmatching_froms(stmt, my_table)
diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py
index e08526419..c055bc150 100644
--- a/test/sql/test_functions.py
+++ b/test/sql/test_functions.py
@@ -26,7 +26,6 @@ from sqlalchemy import Table
from sqlalchemy import testing
from sqlalchemy import Text
from sqlalchemy import true
-from sqlalchemy import types as sqltypes
from sqlalchemy.dialects import mysql
from sqlalchemy.dialects import oracle
from sqlalchemy.dialects import postgresql
@@ -37,6 +36,7 @@ from sqlalchemy.sql import functions
from sqlalchemy.sql import LABEL_STYLE_TABLENAME_PLUS_COL
from sqlalchemy.sql import operators
from sqlalchemy.sql import quoted_name
+from sqlalchemy.sql import sqltypes
from sqlalchemy.sql import table
from sqlalchemy.sql.compiler import BIND_TEMPLATES
from sqlalchemy.sql.functions import FunctionElement
@@ -1425,6 +1425,30 @@ class TableValuedCompileTest(fixtures.TestBase, AssertsCompiledSQL):
"LEFT OUTER JOIN b ON unnested.unnested = b.ref",
)
+ def test_render_derived_maintains_tableval_type(self):
+ fn = func.json_something()
+
+ tv = fn.table_valued(column("x", String))
+
+ eq_(tv.column.type, testing.eq_type_affinity(sqltypes.TableValueType))
+ eq_(tv.column.type._elements[0].type, testing.eq_type_affinity(String))
+
+ tv = tv.render_derived()
+ eq_(tv.column.type, testing.eq_type_affinity(sqltypes.TableValueType))
+ eq_(tv.column.type._elements[0].type, testing.eq_type_affinity(String))
+
+ def test_alias_maintains_tableval_type(self):
+ fn = func.json_something()
+
+ tv = fn.table_valued(column("x", String))
+
+ eq_(tv.column.type, testing.eq_type_affinity(sqltypes.TableValueType))
+ eq_(tv.column.type._elements[0].type, testing.eq_type_affinity(String))
+
+ tv = tv.alias()
+ eq_(tv.column.type, testing.eq_type_affinity(sqltypes.TableValueType))
+ eq_(tv.column.type._elements[0].type, testing.eq_type_affinity(String))
+
def test_star_with_ordinality(self):
"""
SELECT * FROM generate_series(4,1,-1) WITH ORDINALITY;