summaryrefslogtreecommitdiff
path: root/test/sql
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2023-01-28 19:50:25 -0500
committerFederico Caselli <cfederico87@gmail.com>2023-01-30 22:28:53 +0100
commitd23dcbaea2a8e000c5fa2ba443e1b683b3b79fa6 (patch)
tree6b89a07b8bda5a469bf6c8dde165101315f571ed /test/sql
parentb99b0c522ddb94468da27867ddfa1f7e2633c920 (diff)
downloadsqlalchemy-d23dcbaea2a8e000c5fa2ba443e1b683b3b79fa6.tar.gz
don't count / gather INSERT bind names inside of a CTE
Fixed regression related to the implementation for the new "insertmanyvalues" feature where an internal ``TypeError`` would occur in arrangements where a :func:`_sql.insert` would be referred towards inside of another :func:`_sql.insert` via a CTE; made additional repairs for this use case for positional dialects such as asyncpg when using "insertmanyvalues". at the core here is a change to positional insertmanyvalues where we now get exactly the positions for the "manyvalues" within the larger list, allowing non-"manyvalues" on the left and right sides at the same time, not assuming anything about how RETURNING renders etc., since CTEs are in the mix also. Fixes: #9173 Change-Id: I5ff071fbef0d92a2d6046b9c4e609bb008438afd
Diffstat (limited to 'test/sql')
-rw-r--r--test/sql/test_cte.py66
-rw-r--r--test/sql/test_insert_exec.py120
2 files changed, 185 insertions, 1 deletions
diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py
index 502104dae..4ba4eddfe 100644
--- a/test/sql/test_cte.py
+++ b/test/sql/test_cte.py
@@ -1320,6 +1320,72 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
@testing.combinations(
("default_enhanced",),
("postgresql",),
+ ("postgresql+asyncpg",),
+ )
+ def test_insert_w_cte_in_scalar_subquery(self, dialect):
+ """test #9173"""
+
+ customer = table(
+ "customer",
+ column("id"),
+ column("name"),
+ )
+ order = table(
+ "order",
+ column("id"),
+ column("price"),
+ column("customer_id"),
+ )
+
+ inst = (
+ customer.insert()
+ .values(name="John")
+ .returning(customer.c.id)
+ .cte("inst")
+ )
+
+ stmt = (
+ order.insert()
+ .values(
+ price=1,
+ customer_id=select(inst.c.id).scalar_subquery(),
+ )
+ .add_cte(inst)
+ )
+
+ if dialect == "default_enhanced":
+ self.assert_compile(
+ stmt,
+ "WITH inst AS (INSERT INTO customer (name) VALUES (:param_1) "
+ 'RETURNING customer.id) INSERT INTO "order" '
+ "(price, customer_id) VALUES "
+ "(:price, (SELECT inst.id FROM inst))",
+ dialect=dialect,
+ )
+ elif dialect == "postgresql":
+ self.assert_compile(
+ stmt,
+ "WITH inst AS (INSERT INTO customer (name) "
+ "VALUES (%(param_1)s) "
+ 'RETURNING customer.id) INSERT INTO "order" '
+ "(price, customer_id) "
+ "VALUES (%(price)s, (SELECT inst.id FROM inst))",
+ dialect=dialect,
+ )
+ elif dialect == "postgresql+asyncpg":
+ self.assert_compile(
+ stmt,
+ "WITH inst AS (INSERT INTO customer (name) VALUES ($2) "
+ 'RETURNING customer.id) INSERT INTO "order" '
+ "(price, customer_id) VALUES ($1, (SELECT inst.id FROM inst))",
+ dialect=dialect,
+ )
+ else:
+ assert False
+
+ @testing.combinations(
+ ("default_enhanced",),
+ ("postgresql",),
)
def test_select_from_delete_cte(self, dialect):
t1 = table("table_1", column("id"), column("val"))
diff --git a/test/sql/test_insert_exec.py b/test/sql/test_insert_exec.py
index d9dac75b3..3b5a1856c 100644
--- a/test/sql/test_insert_exec.py
+++ b/test/sql/test_insert_exec.py
@@ -23,6 +23,7 @@ from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
from sqlalchemy.testing import mock
+from sqlalchemy.testing import provision
from sqlalchemy.testing.provision import normalize_sequence
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
@@ -825,6 +826,119 @@ class InsertManyValuesTest(fixtures.RemovesEvents, fixtures.TablesTest):
eq_(result.inserted_primary_key_rows, [(1,), (2,), (3,)])
+ @testing.requires.ctes_on_dml
+ @testing.variation("add_expr_returning", [True, False])
+ def test_insert_w_bindparam_in_nested_insert(
+ self, connection, add_expr_returning
+ ):
+ """test related to #9173"""
+
+ data, extra_table = self.tables("data", "extra_table")
+
+ inst = (
+ extra_table.insert()
+ .values(x_value="x", y_value="y")
+ .returning(extra_table.c.id)
+ .cte("inst")
+ )
+
+ stmt = (
+ data.insert()
+ .values(x="the x", z=select(inst.c.id).scalar_subquery())
+ .add_cte(inst)
+ )
+
+ if add_expr_returning:
+ stmt = stmt.returning(data.c.id, data.c.y + " returned y")
+ else:
+ stmt = stmt.returning(data.c.id)
+
+ result = connection.execute(
+ stmt,
+ [
+ {"y": "y1"},
+ {"y": "y2"},
+ {"y": "y3"},
+ ],
+ )
+
+ result_rows = result.all()
+
+ ids = [row[0] for row in result_rows]
+
+ extra_row = connection.execute(
+ select(extra_table).order_by(extra_table.c.id)
+ ).one()
+ extra_row_id = extra_row[0]
+ eq_(extra_row, (extra_row_id, "x", "y"))
+ eq_(
+ connection.execute(select(data).order_by(data.c.id)).all(),
+ [
+ (ids[0], "the x", "y1", extra_row_id),
+ (ids[1], "the x", "y2", extra_row_id),
+ (ids[2], "the x", "y3", extra_row_id),
+ ],
+ )
+
+ @testing.requires.provisioned_upsert
+ def test_upsert_w_returning(self, connection):
+ """test cases that will execise SQL similar to that of
+ test/orm/dml/test_bulk_statements.py
+
+ """
+
+ data = self.tables.data
+
+ initial_data = [
+ {"x": "x1", "y": "y1", "z": 4},
+ {"x": "x2", "y": "y2", "z": 8},
+ ]
+ ids = connection.scalars(
+ data.insert().returning(data.c.id), initial_data
+ ).all()
+
+ upsert_data = [
+ {
+ "id": ids[0],
+ "x": "x1",
+ "y": "y1",
+ },
+ {
+ "id": 32,
+ "x": "x19",
+ "y": "y7",
+ },
+ {
+ "id": ids[1],
+ "x": "x5",
+ "y": "y6",
+ },
+ {
+ "id": 28,
+ "x": "x9",
+ "y": "y15",
+ },
+ ]
+
+ stmt = provision.upsert(
+ config,
+ data,
+ (data,),
+ lambda inserted: {"x": inserted.x + " upserted"},
+ )
+
+ result = connection.execute(stmt, upsert_data)
+
+ eq_(
+ result.all(),
+ [
+ (ids[0], "x1 upserted", "y1", 4),
+ (32, "x19", "y7", 5),
+ (ids[1], "x5 upserted", "y2", 8),
+ (28, "x9", "y15", 5),
+ ],
+ )
+
@testing.combinations(True, False, argnames="use_returning")
@testing.combinations(1, 2, argnames="num_embedded_params")
@testing.combinations(True, False, argnames="use_whereclause")
@@ -835,7 +949,11 @@ class InsertManyValuesTest(fixtures.RemovesEvents, fixtures.TablesTest):
def test_insert_w_bindparam_in_subq(
self, connection, use_returning, num_embedded_params, use_whereclause
):
- """test #8639"""
+ """test #8639
+
+ see also test_insert_w_bindparam_in_nested_insert
+
+ """
t = self.tables.data
extra = self.tables.extra_table