diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2023-04-05 11:58:52 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2023-04-21 11:30:40 -0400 |
| commit | cf6872d3bdf1a8a9613e853694acc2b1e6f06f51 (patch) | |
| tree | 3a4ee41ab8b48aea7ac1e275c2f553763ec28dad /test | |
| parent | 63f51491c5f0cb22883c800a065d7c4b4c54774e (diff) | |
| download | sqlalchemy-cf6872d3bdf1a8a9613e853694acc2b1e6f06f51.tar.gz | |
add deterministic imv returning ordering using sentinel columns
Repaired a major shortcoming which was identified in the
:ref:`engine_insertmanyvalues` performance optimization feature first
introduced in the 2.0 series. This was a continuation of the change in
2.0.9 which disabled the SQL Server version of the feature due to a
reliance in the ORM on apparent row ordering that is not guaranteed to take
place. The fix applies new logic to all "insertmanyvalues" operations,
which takes effect when a new parameter
:paramref:`_dml.Insert.returning.sort_by_parameter_order` on the
:meth:`_dml.Insert.returning` or :meth:`_dml.UpdateBase.return_defaults`
methods, that through a combination of alternate SQL forms, direct
correspondence of client side parameters, and in some cases downgrading to
running row-at-a-time, will apply sorting to each batch of returned rows
using correspondence to primary key or other unique values in each row
which can be correlated to the input data.
Performance impact is expected to be minimal as nearly all common primary
key scenarios are suitable for parameter-ordered batching to be
achieved for all backends other than SQLite, while "row-at-a-time"
mode operates with a bare minimum of Python overhead compared to the very
heavyweight approaches used in the 1.x series. For SQLite, there is no
difference in performance when "row-at-a-time" mode is used.
It's anticipated that with an efficient "row-at-a-time" INSERT with
RETURNING batching capability, the "insertmanyvalues" feature can be later
be more easily generalized to third party backends that include RETURNING
support but not necessarily easy ways to guarantee a correspondence
with parameter order.
Fixes: #9618
References: #9603
Change-Id: I1d79353f5f19638f752936ba1c35e4dc235a8b7c
Diffstat (limited to 'test')
| -rw-r--r-- | test/dialect/mssql/test_engine.py | 9 | ||||
| -rw-r--r-- | test/engine/test_logging.py | 7 | ||||
| -rw-r--r-- | test/orm/declarative/test_basic.py | 22 | ||||
| -rw-r--r-- | test/orm/declarative/test_inheritance.py | 96 | ||||
| -rw-r--r-- | test/orm/dml/test_bulk_statements.py | 520 | ||||
| -rw-r--r-- | test/orm/test_ac_relationships.py | 2 | ||||
| -rw-r--r-- | test/orm/test_defaults.py | 43 | ||||
| -rw-r--r-- | test/orm/test_expire.py | 4 | ||||
| -rw-r--r-- | test/orm/test_unitofwork.py | 44 | ||||
| -rw-r--r-- | test/orm/test_unitofworkv2.py | 301 | ||||
| -rw-r--r-- | test/requirements.py | 4 | ||||
| -rw-r--r-- | test/sql/test_compiler.py | 241 | ||||
| -rw-r--r-- | test/sql/test_defaults.py | 1 | ||||
| -rw-r--r-- | test/sql/test_insert.py | 47 | ||||
| -rw-r--r-- | test/sql/test_insert_exec.py | 1625 | ||||
| -rw-r--r-- | test/sql/test_metadata.py | 55 | ||||
| -rw-r--r-- | test/sql/test_returning.py | 60 |
17 files changed, 2808 insertions, 273 deletions
diff --git a/test/dialect/mssql/test_engine.py b/test/dialect/mssql/test_engine.py index 095df2eaf..799452ade 100644 --- a/test/dialect/mssql/test_engine.py +++ b/test/dialect/mssql/test_engine.py @@ -3,7 +3,6 @@ import re from unittest.mock import Mock from sqlalchemy import Column -from sqlalchemy import create_engine from sqlalchemy import event from sqlalchemy import exc from sqlalchemy import inspect @@ -629,14 +628,6 @@ class MiscTest(fixtures.TestBase): __only_on__ = "mssql" __backend__ = True - def test_no_insertmanyvalues(self): - with expect_raises_message( - exc.ArgumentError, - "The use_insertmanyvalues feature on SQL Server is " - "currently not safe to use", - ): - create_engine("mssql+pyodbc://", use_insertmanyvalues=True) - @testing.variation("enable_comments", [True, False]) def test_comments_enabled_disabled( self, testing_engine, metadata, enable_comments diff --git a/test/engine/test_logging.py b/test/engine/test_logging.py index 19c26f43c..a498ec85c 100644 --- a/test/engine/test_logging.py +++ b/test/engine/test_logging.py @@ -283,7 +283,8 @@ class LogParamsTest(fixtures.TestBase): eq_regex( self.buf.buffer[4].message, - r"\[generated in .* \(insertmanyvalues\)\] \('d0', 'd1', " + r"\[generated in .* \(insertmanyvalues\) 1/3 " + r"\(unordered\)\] \('d0', 'd1', " r"'d2', 'd3', 'd4', 'd5', 'd6', 'd7', " r"'d8', 'd9', 'd10', 'd11', 'd12', 'd13', 'd14', 'd15', " r"'d16', 'd17', 'd18', 'd19', 'd20', 'd21', 'd22', 'd23', " @@ -304,7 +305,7 @@ class LogParamsTest(fixtures.TestBase): eq_(self.buf.buffer[5].message, full_insert) eq_( self.buf.buffer[6].message, - "[insertmanyvalues batch 2 of 3] ('d150', 'd151', 'd152', " + "[insertmanyvalues 2/3 (unordered)] ('d150', 'd151', 'd152', " "'d153', 'd154', 'd155', 'd156', 'd157', 'd158', 'd159', " "'d160', 'd161', 'd162', 'd163', 'd164', 'd165', 'd166', " "'d167', 'd168', 'd169', 'd170', 'd171', 'd172', 'd173', " @@ -330,7 +331,7 @@ class LogParamsTest(fixtures.TestBase): ) eq_( self.buf.buffer[8].message, - "[insertmanyvalues batch 3 of 3] ('d300', 'd301', 'd302', " + "[insertmanyvalues 3/3 (unordered)] ('d300', 'd301', 'd302', " "'d303', 'd304', 'd305', 'd306', 'd307', 'd308', 'd309', " "'d310', 'd311', 'd312', 'd313', 'd314', 'd315', 'd316', " "'d317', 'd318', 'd319', 'd320', 'd321', 'd322', 'd323', " diff --git a/test/orm/declarative/test_basic.py b/test/orm/declarative/test_basic.py index 698b66db1..d0e56819c 100644 --- a/test/orm/declarative/test_basic.py +++ b/test/orm/declarative/test_basic.py @@ -1,4 +1,5 @@ import random +import uuid import sqlalchemy as sa from sqlalchemy import CheckConstraint @@ -13,6 +14,7 @@ from sqlalchemy import select from sqlalchemy import String from sqlalchemy import testing from sqlalchemy import UniqueConstraint +from sqlalchemy import Uuid from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import as_declarative from sqlalchemy.orm import backref @@ -209,6 +211,26 @@ class DeclarativeBaseSetupsTest(fixtures.TestBase): ): Base.__init__(fs, x=5) + def test_insert_sentinel_param_custom_type_maintained(self, decl_base): + class A(decl_base): + __tablename__ = "a" + id: Mapped[uuid.UUID] = mapped_column( + default=uuid.uuid4, primary_key=True, insert_sentinel=True + ) + data: Mapped[str] + + is_(A.id.expression.type._type_affinity, Uuid) + + def test_insert_sentinel_param_default_type(self, decl_base): + class A(decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column( + primary_key=True, insert_sentinel=True + ) + data: Mapped[str] + + is_(A.id.expression.type._type_affinity, Integer) + @testing.variation("argument", ["version_id_col", "polymorphic_on"]) @testing.variation("column_type", ["anno", "non_anno", "plain_column"]) def test_mapped_column_version_poly_arg( diff --git a/test/orm/declarative/test_inheritance.py b/test/orm/declarative/test_inheritance.py index e8658926b..333d24230 100644 --- a/test/orm/declarative/test_inheritance.py +++ b/test/orm/declarative/test_inheritance.py @@ -1,3 +1,5 @@ +import contextlib + import sqlalchemy as sa from sqlalchemy import ForeignKey from sqlalchemy import Identity @@ -7,6 +9,7 @@ from sqlalchemy import String from sqlalchemy import testing from sqlalchemy.orm import class_mapper from sqlalchemy.orm import close_all_sessions +from sqlalchemy.orm import column_property from sqlalchemy.orm import configure_mappers from sqlalchemy.orm import declared_attr from sqlalchemy.orm import deferred @@ -20,6 +23,7 @@ from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises_message +from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import is_false @@ -987,6 +991,98 @@ class DeclarativeInheritanceTest( session.commit() eq_(session.query(Engineer).first().target, o1) + @testing.variation("omit_from_statements", [True, False]) + @testing.variation("combine_on_b", [True, False]) + @testing.variation("c_first", [True, False]) + def test_use_existing_column_other_inh_types( + self, decl_base, omit_from_statements, combine_on_b, c_first + ): + """test additional fixes to use_existing_column, adding + some new use cases with "omit_from_statements" which in this case + is essentially the same as adding it to the mapper exclude_cols + list. + + """ + + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] + extra: Mapped[int] = mapped_column( + use_existing_column=True, + _omit_from_statements=bool(omit_from_statements), + ) + + if c_first: + + class C(A): + foo: Mapped[str] + extra: Mapped[int] = mapped_column( + use_existing_column=True, + _omit_from_statements=bool(omit_from_statements), + ) + + if not combine_on_b and not omit_from_statements: + ctx = expect_warnings( + "Implicitly combining column a.extra with column b.extra", + raise_on_any_unexpected=True, + ) + else: + ctx = contextlib.nullcontext() + + with ctx: + + class B(A): + __tablename__ = "b" + id: Mapped[int] = mapped_column( + ForeignKey("a.id"), primary_key=True + ) + if combine_on_b: + extra: Mapped[int] = column_property( + mapped_column( + _omit_from_statements=bool(omit_from_statements) + ), + A.extra, + ) + else: + extra: Mapped[int] = mapped_column( + use_existing_column=True, + _omit_from_statements=bool(omit_from_statements), + ) + + if not c_first: + + class C(A): # noqa: F811 + foo: Mapped[str] + extra: Mapped[int] = mapped_column( + use_existing_column=True, + _omit_from_statements=bool(omit_from_statements), + ) + + if bool(omit_from_statements): + self.assert_compile(select(A), "SELECT a.id, a.data FROM a") + else: + self.assert_compile( + select(A), "SELECT a.id, a.data, a.extra FROM a" + ) + + if bool(omit_from_statements) and not combine_on_b: + self.assert_compile( + select(B), + "SELECT b.id, a.id AS id_1, a.data " + "FROM a JOIN b ON a.id = b.id", + ) + else: + # if we combine_on_b we made a column_property, which brought + # out "extra" even if it was omit_from_statements. this should be + # expected + self.assert_compile( + select(B), + "SELECT b.id, a.id AS id_1, a.data, b.extra, " + "a.extra AS extra_1 FROM a JOIN b ON a.id = b.id", + ) + @testing.variation("decl_type", ["legacy", "use_existing_column"]) def test_columns_single_inheritance_conflict_resolution_pk( self, decl_base, decl_type diff --git a/test/orm/dml/test_bulk_statements.py b/test/orm/dml/test_bulk_statements.py index 7a9f3324f..84ea7c82c 100644 --- a/test/orm/dml/test_bulk_statements.py +++ b/test/orm/dml/test_bulk_statements.py @@ -1,10 +1,13 @@ from __future__ import annotations +import contextlib from typing import Any from typing import List from typing import Optional +from typing import Set import uuid +from sqlalchemy import event from sqlalchemy import exc from sqlalchemy import ForeignKey from sqlalchemy import func @@ -22,45 +25,101 @@ from sqlalchemy.orm import column_property from sqlalchemy.orm import load_only from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import orm_insert_sentinel +from sqlalchemy.orm import Session from sqlalchemy.testing import config from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises_message +from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import mock from sqlalchemy.testing import provision from sqlalchemy.testing.assertsql import CompiledSQL +from sqlalchemy.testing.assertsql import Conditional from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session class InsertStmtTest(testing.AssertsExecutionResults, fixtures.TestBase): - def test_no_returning_error(self, decl_base): + __backend__ = True + + @testing.variation( + "style", + [ + "no_executemany", + ("no_sort_by", testing.requires.insert_returning), + ("all_enabled", testing.requires.insert_returning), + ], + ) + @testing.variation("sort_by_parameter_order", [True, False]) + def test_no_returning_error( + self, + decl_base, + testing_engine, + style: testing.Variation, + sort_by_parameter_order, + ): class A(fixtures.ComparableEntity, decl_base): __tablename__ = "a" id: Mapped[int] = mapped_column(Identity(), primary_key=True) data: Mapped[str] x: Mapped[Optional[int]] = mapped_column("xcol") - decl_base.metadata.create_all(testing.db) - s = fixture_session() + engine = testing_engine() + + if style.no_executemany: + engine.dialect.use_insertmanyvalues = False + engine.dialect.insert_executemany_returning = False + engine.dialect.insert_executemany_returning_sort_by_parameter_order = ( # noqa: E501 + False + ) + elif style.no_sort_by: + engine.dialect.use_insertmanyvalues = True + engine.dialect.insert_executemany_returning = True + engine.dialect.insert_executemany_returning_sort_by_parameter_order = ( # noqa: E501 + False + ) + elif style.all_enabled: + engine.dialect.use_insertmanyvalues = True + engine.dialect.insert_executemany_returning = True + engine.dialect.insert_executemany_returning_sort_by_parameter_order = ( # noqa: E501 + True + ) + else: + style.fail() + + decl_base.metadata.create_all(engine) + s = Session(engine) - if testing.requires.insert_executemany_returning.enabled: + if style.all_enabled or ( + style.no_sort_by and not sort_by_parameter_order + ): result = s.scalars( - insert(A).returning(A), + insert(A).returning( + A, sort_by_parameter_order=bool(sort_by_parameter_order) + ), [ {"data": "d3", "x": 5}, {"data": "d4", "x": 6}, ], ) - eq_(result.all(), [A(data="d3", x=5), A(data="d4", x=6)]) + eq_(set(result.all()), {A(data="d3", x=5), A(data="d4", x=6)}) else: with expect_raises_message( exc.InvalidRequestError, - "Can't use explicit RETURNING for bulk INSERT operation", + r"Can't use explicit RETURNING for bulk INSERT operation.*" + rf"""executemany with RETURNING{ + ' and sort by parameter order' + if sort_by_parameter_order else '' + } is """ + r"not enabled for this dialect", ): s.scalars( - insert(A).returning(A), + insert(A).returning( + A, + sort_by_parameter_order=bool(sort_by_parameter_order), + ), [ {"data": "d3", "x": 5}, {"data": "d4", "x": 6}, @@ -132,6 +191,9 @@ class InsertStmtTest(testing.AssertsExecutionResults, fixtures.TestBase): ) @testing.requires.insert_returning + @testing.skip_if( + "oracle", "oracle doesn't like the no-FROM SELECT inside of an INSERT" + ) def test_insert_from_select_col_property(self, decl_base): """test #9273""" @@ -166,6 +228,40 @@ class InsertStmtTest(testing.AssertsExecutionResults, fixtures.TestBase): class BulkDMLReturningInhTest: + use_sentinel = False + randomize_returning = False + + def assert_for_downgrade(self, *, sort_by_parameter_order): + if ( + not sort_by_parameter_order + or not self.randomize_returning + or not testing.against(["postgresql", "mssql", "mariadb"]) + ): + return contextlib.nullcontext() + else: + return expect_warnings("Batches were downgraded") + + @classmethod + def setup_bind(cls): + if cls.randomize_returning: + new_eng = config.db.execution_options() + + @event.listens_for(new_eng, "engine_connect") + def eng_connect(connection): + fixtures.insertmanyvalues_fixture( + connection, + randomize_rows=True, + # there should be no sentinel downgrades for any of + # these three dbs. sqlite has downgrades + warn_on_downgraded=testing.against( + ["postgresql", "mssql", "mariadb"] + ), + ) + + return new_eng + else: + return config.db + def test_insert_col_key_also_works_currently(self): """using the column key, not mapped attr key. @@ -178,7 +274,7 @@ class BulkDMLReturningInhTest: """ A, B = self.classes("A", "B") - s = fixture_session() + s = fixture_session(bind=self.bind) s.execute(insert(A).values(type="a", data="d", xcol=10)) eq_(s.scalars(select(A.x)).all(), [10]) @@ -186,7 +282,7 @@ class BulkDMLReturningInhTest: def test_autoflush(self, autoflush_option): A = self.classes.A - s = fixture_session() + s = fixture_session(bind=self.bind) a1 = A(data="x1") s.add(a1) @@ -211,8 +307,9 @@ class BulkDMLReturningInhTest: else: assert False - @testing.combinations(True, False, argnames="use_returning") - def test_heterogeneous_keys(self, use_returning): + @testing.variation("use_returning", [True, False]) + @testing.variation("sort_by_parameter_order", [True, False]) + def test_heterogeneous_keys(self, use_returning, sort_by_parameter_order): A, B = self.classes("A", "B") values = [ @@ -224,21 +321,31 @@ class BulkDMLReturningInhTest: {"data": "d8", "x": 7, "type": "a"}, ] - s = fixture_session() + s = fixture_session(bind=self.bind) stmt = insert(A) if use_returning: - stmt = stmt.returning(A) + stmt = stmt.returning( + A, sort_by_parameter_order=bool(sort_by_parameter_order) + ) with self.sql_execution_asserter() as asserter: result = s.execute(stmt, values) if use_returning: + if self.use_sentinel and sort_by_parameter_order: + _sentinel_col = ", _sentinel" + _sentinel_returning = ", a._sentinel" + _sentinel_param = ", :_sentinel" + else: + _sentinel_col = _sentinel_param = _sentinel_returning = "" + # note no sentinel col is used when there is only one row asserter.assert_( CompiledSQL( - "INSERT INTO a (type, data, xcol) VALUES " - "(:type, :data, :xcol) " - "RETURNING a.id, a.type, a.data, a.xcol, a.y", + f"INSERT INTO a (type, data, xcol{_sentinel_col}) VALUES " + f"(:type, :data, :xcol{_sentinel_param}) " + f"RETURNING a.id, a.type, a.data, a.xcol, a.y" + f"{_sentinel_returning}", [ {"type": "a", "data": "d3", "xcol": 5}, {"type": "a", "data": "d4", "xcol": 6}, @@ -250,9 +357,10 @@ class BulkDMLReturningInhTest: [{"type": "a", "data": "d5"}], ), CompiledSQL( - "INSERT INTO a (type, data, xcol, y) " - "VALUES (:type, :data, :xcol, :y) " - "RETURNING a.id, a.type, a.data, a.xcol, a.y", + f"INSERT INTO a (type, data, xcol, y{_sentinel_col}) " + f"VALUES (:type, :data, :xcol, :y{_sentinel_param}) " + f"RETURNING a.id, a.type, a.data, a.xcol, a.y" + f"{_sentinel_returning}", [ {"type": "a", "data": "d6", "xcol": 8, "y": 9}, {"type": "a", "data": "d7", "xcol": 12, "y": 12}, @@ -297,15 +405,15 @@ class BulkDMLReturningInhTest: if use_returning: with self.assert_statement_count(testing.db, 0): eq_( - result.scalars().all(), - [ + set(result.scalars().all()), + { A(data="d3", id=mock.ANY, type="a", x=5, y=None), A(data="d4", id=mock.ANY, type="a", x=6, y=None), A(data="d5", id=mock.ANY, type="a", x=None, y=None), A(data="d6", id=mock.ANY, type="a", x=8, y=9), A(data="d7", id=mock.ANY, type="a", x=12, y=12), A(data="d8", id=mock.ANY, type="a", x=7, y=None), - ], + }, ) @testing.combinations( @@ -315,10 +423,8 @@ class BulkDMLReturningInhTest: "cols_w_exprs", argnames="paramstyle", ) - @testing.combinations( - True, - (False, testing.requires.multivalues_inserts), - argnames="single_element", + @testing.variation( + "single_element", [True, (False, testing.requires.multivalues_inserts)] ) def test_single_values_returning_fn(self, paramstyle, single_element): """test using insert().values(). @@ -364,7 +470,7 @@ class BulkDMLReturningInhTest: else: assert False - s = fixture_session() + s = fixture_session(bind=self.bind) if single_element: if paramstyle.startswith("strings"): @@ -405,7 +511,7 @@ class BulkDMLReturningInhTest: }, ] - s = fixture_session() + s = fixture_session(bind=self.bind) stmt = ( insert(A) @@ -415,11 +521,11 @@ class BulkDMLReturningInhTest: for i in range(3): result = s.execute(stmt, data) - expected: List[Any] = [ + expected: Set[Any] = { (A(data="dd", x=5, y=9), "DD"), (A(data="dd", x=10, y=8), "DD"), - ] - eq_(result.all(), expected) + } + eq_(set(result.all()), expected) def test_bulk_w_sql_expressions_subclass(self): A, B = self.classes("A", "B") @@ -429,7 +535,7 @@ class BulkDMLReturningInhTest: {"bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8}, ] - s = fixture_session() + s = fixture_session(bind=self.bind) stmt = ( insert(B) @@ -439,17 +545,17 @@ class BulkDMLReturningInhTest: for i in range(3): result = s.execute(stmt, data) - expected: List[Any] = [ + expected: Set[Any] = { (B(bd="bd1", data="dd", q=4, type="b", x=1, y=2, z=3), "DD"), (B(bd="bd2", data="dd", q=8, type="b", x=5, y=6, z=7), "DD"), - ] - eq_(result.all(), expected) + } + eq_(set(result), expected) @testing.combinations(True, False, argnames="use_ordered") def test_bulk_upd_w_sql_expressions_no_ordered_values(self, use_ordered): A, B = self.classes("A", "B") - s = fixture_session() + s = fixture_session(bind=self.bind) stmt = update(B).ordered_values( ("data", func.lower("DD_UPDATE")), @@ -471,13 +577,16 @@ class BulkDMLReturningInhTest: def test_bulk_upd_w_sql_expressions_subclass(self): A, B = self.classes("A", "B") - s = fixture_session() + s = fixture_session(bind=self.bind) data = [ {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}, {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8}, ] - ids = s.scalars(insert(B).returning(B.id), data).all() + ids = { + row.data: row.id + for row in s.execute(insert(B).returning(B.id, B.data), data) + } stmt = update(B).values( data=func.lower("DD_UPDATE"), z=literal_column("3 + 12") @@ -486,8 +595,8 @@ class BulkDMLReturningInhTest: result = s.execute( stmt, [ - {"id": ids[0], "bd": "bd1_updated"}, - {"id": ids[1], "bd": "bd2_updated"}, + {"id": ids["d3"], "bd": "bd1_updated"}, + {"id": ids["d4"], "bd": "bd2_updated"}, ], ) @@ -495,12 +604,12 @@ class BulkDMLReturningInhTest: assert result is not None eq_( - s.scalars(select(B)).all(), - [ + set(s.scalars(select(B))), + { B( bd="bd1_updated", data="dd_update", - id=ids[0], + id=ids["d3"], q=4, type="b", x=1, @@ -510,36 +619,32 @@ class BulkDMLReturningInhTest: B( bd="bd2_updated", data="dd_update", - id=ids[1], + id=ids["d4"], q=8, type="b", x=5, y=6, z=15, ), - ], + }, ) def test_single_returning_fn(self): A, B = self.classes("A", "B") - s = fixture_session() + s = fixture_session(bind=self.bind) for i in range(3): result = s.execute( insert(A).returning(A, func.upper(A.data, type_=String)), [{"data": "d3"}, {"data": "d4"}], ) - eq_(result.all(), [(A(data="d3"), "D3"), (A(data="d4"), "D4")]) + eq_(set(result), {(A(data="d3"), "D3"), (A(data="d4"), "D4")}) - @testing.combinations( - True, - False, - argnames="single_element", - ) + @testing.variation("single_element", [True, False]) def test_subclass_no_returning(self, single_element): A, B = self.classes("A", "B") - s = fixture_session() + s = fixture_session(bind=self.bind) if single_element: data = {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4} @@ -552,19 +657,16 @@ class BulkDMLReturningInhTest: result = s.execute(insert(B), data) assert result._soft_closed - @testing.combinations( - True, - False, - argnames="single_element", - ) - def test_subclass_load_only(self, single_element): + @testing.variation("sort_by_parameter_order", [True, False]) + @testing.variation("single_element", [True, False]) + def test_subclass_load_only(self, single_element, sort_by_parameter_order): """test that load_only() prevents additional attributes from being populated. """ A, B = self.classes("A", "B") - s = fixture_session() + s = fixture_session(bind=self.bind) if single_element: data = {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4} @@ -578,7 +680,12 @@ class BulkDMLReturningInhTest: # tests both caching and that the data dictionaries aren't # mutated... result = s.execute( - insert(B).returning(B).options(load_only(B.data, B.y, B.q)), + insert(B) + .returning( + B, + sort_by_parameter_order=bool(sort_by_parameter_order), + ) + .options(load_only(B.data, B.y, B.q)), data, ) objects = result.scalars().all() @@ -593,13 +700,14 @@ class BulkDMLReturningInhTest: ] if not single_element: expected.append(B(data="d4", bd="bd2", x=5, y=6, z=7, q=8)) - eq_(objects, expected) - @testing.combinations( - True, - False, - argnames="single_element", - ) + if sort_by_parameter_order: + coll = list + else: + coll = set + eq_(coll(objects), coll(expected)) + + @testing.variation("single_element", [True, False]) def test_subclass_load_only_doesnt_fetch_cols(self, single_element): """test that when using load_only(), the actual INSERT statement does not include the deferred columns @@ -607,7 +715,7 @@ class BulkDMLReturningInhTest: """ A, B = self.classes("A", "B") - s = fixture_session() + s = fixture_session(bind=self.bind) data = [ {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}, @@ -699,30 +807,60 @@ class BulkDMLReturningInhTest: # RETURNING only includes PK, discriminator, then the cols # we asked for data, y, q. xcol, z, bd are omitted. plus they # are broken out correctly in the two statements. + asserter.assert_( - CompiledSQL( - "INSERT INTO a (type, data, xcol, y) VALUES " - "(:type, :data, :xcol, :y) " - "RETURNING a.id, a.type, a.data, a.y", - a_data, - ), - CompiledSQL( - "INSERT INTO b (id, bd, zcol, q) " - "VALUES (:id, :bd, :zcol, :q) " - "RETURNING b.id, b.q", - b_data, - ), + Conditional( + self.use_sentinel and not single_element, + [ + CompiledSQL( + "INSERT INTO a (type, data, xcol, y, _sentinel) " + "VALUES " + "(:type, :data, :xcol, :y, :_sentinel) " + "RETURNING a.id, a.type, a.data, a.y, a._sentinel", + a_data, + ), + CompiledSQL( + "INSERT INTO b (id, bd, zcol, q, _sentinel) " + "VALUES (:id, :bd, :zcol, :q, :_sentinel) " + "RETURNING b.id, b.q, b._sentinel", + b_data, + ), + ], + [ + CompiledSQL( + "INSERT INTO a (type, data, xcol, y) VALUES " + "(:type, :data, :xcol, :y) " + "RETURNING a.id, a.type, a.data, a.y", + a_data, + ), + Conditional( + single_element, + [ + CompiledSQL( + "INSERT INTO b (id, bd, zcol, q) " + "VALUES (:id, :bd, :zcol, :q) " + "RETURNING b.id, b.q", + b_data, + ), + ], + [ + CompiledSQL( + "INSERT INTO b (id, bd, zcol, q) " + "VALUES (:id, :bd, :zcol, :q) " + "RETURNING b.id, b.q, b.id AS id__1", + b_data, + ), + ], + ), + ], + ) ) - @testing.combinations( - True, - False, - argnames="single_element", - ) + @testing.variation("single_element", [True, False]) def test_subclass_returning_bind_expr(self, single_element): A, B = self.classes("A", "B") - s = fixture_session() + s = fixture_session(bind=self.bind) if single_element: data = {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4} @@ -740,24 +878,27 @@ class BulkDMLReturningInhTest: if single_element: eq_(result.all(), [("d3", 2, 9)]) else: - eq_(result.all(), [("d3", 2, 9), ("d4", 6, 13)]) + eq_(set(result), {("d3", 2, 9), ("d4", 6, 13)}) def test_subclass_bulk_update(self): A, B = self.classes("A", "B") - s = fixture_session() + s = fixture_session(bind=self.bind) data = [ {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}, {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8}, ] - ids = s.scalars(insert(B).returning(B.id), data).all() + ids = { + row.data: row.id + for row in s.execute(insert(B).returning(B.id, B.data), data).all() + } result = s.execute( update(B), [ - {"id": ids[0], "data": "d3_updated", "bd": "bd1_updated"}, - {"id": ids[1], "data": "d4_updated", "bd": "bd2_updated"}, + {"id": ids["d3"], "data": "d3_updated", "bd": "bd1_updated"}, + {"id": ids["d4"], "data": "d4_updated", "bd": "bd2_updated"}, ], ) @@ -765,12 +906,12 @@ class BulkDMLReturningInhTest: assert result is not None eq_( - s.scalars(select(B)).all(), - [ + set(s.scalars(select(B))), + { B( bd="bd1_updated", data="d3_updated", - id=ids[0], + id=ids["d3"], q=4, type="b", x=1, @@ -780,21 +921,24 @@ class BulkDMLReturningInhTest: B( bd="bd2_updated", data="d4_updated", - id=ids[1], + id=ids["d4"], q=8, type="b", x=5, y=6, z=7, ), - ], + }, ) - @testing.combinations(True, False, argnames="single_element") - def test_subclass_return_just_subclass_ids(self, single_element): + @testing.variation("single_element", [True, False]) + @testing.variation("sort_by_parameter_order", [True, False]) + def test_subclass_return_just_subclass_ids( + self, single_element, sort_by_parameter_order + ): A, B = self.classes("A", "B") - s = fixture_session() + s = fixture_session(bind=self.bind) if single_element: data = {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4} @@ -804,16 +948,24 @@ class BulkDMLReturningInhTest: {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8}, ] - ids = s.scalars(insert(B).returning(B.id), data).all() - actual_ids = s.scalars(select(B.id).order_by(B.data)).all() + ids = s.execute( + insert(B).returning( + B.id, + B.data, + sort_by_parameter_order=bool(sort_by_parameter_order), + ), + data, + ) + actual_ids = s.execute(select(B.id, B.data).order_by(B.id)) - eq_(ids, actual_ids) + if sort_by_parameter_order: + coll = list + else: + coll = set - @testing.combinations( - "orm", - "bulk", - argnames="insert_strategy", - ) + eq_(coll(ids), coll(actual_ids)) + + @testing.variation("insert_strategy", ["orm", "bulk", "bulk_ordered"]) @testing.requires.provisioned_upsert def test_base_class_upsert(self, insert_strategy): """upsert is really tricky. if you dont have any data updated, @@ -825,17 +977,22 @@ class BulkDMLReturningInhTest: """ A = self.classes.A - s = fixture_session() + s = fixture_session(bind=self.bind) initial_data = [ {"data": "d3", "x": 1, "y": 2, "q": 4}, {"data": "d4", "x": 5, "y": 6, "q": 8}, ] - ids = s.scalars(insert(A).returning(A.id), initial_data).all() + ids = { + row.data: row.id + for row in s.execute( + insert(A).returning(A.id, A.data), initial_data + ) + } upsert_data = [ { - "id": ids[0], + "id": ids["d3"], "type": "a", "data": "d3", "x": 1, @@ -849,7 +1006,7 @@ class BulkDMLReturningInhTest: "y": 5, }, { - "id": ids[1], + "id": ids["d4"], "type": "a", "data": "d4", "x": 5, @@ -868,24 +1025,28 @@ class BulkDMLReturningInhTest: config, A, (A,), - lambda inserted: {"data": inserted.data + " upserted"}, + set_lambda=lambda inserted: {"data": inserted.data + " upserted"}, + sort_by_parameter_order=insert_strategy.bulk_ordered, ) - if insert_strategy == "orm": + if insert_strategy.orm: result = s.scalars(stmt.values(upsert_data)) - elif insert_strategy == "bulk": - result = s.scalars(stmt, upsert_data) + elif insert_strategy.bulk or insert_strategy.bulk_ordered: + with self.assert_for_downgrade( + sort_by_parameter_order=insert_strategy.bulk_ordered + ): + result = s.scalars(stmt, upsert_data) else: - assert False + insert_strategy.fail() eq_( - result.all(), - [ - A(data="d3 upserted", id=ids[0], type="a", x=1, y=2), + set(result.all()), + { + A(data="d3 upserted", id=ids["d3"], type="a", x=1, y=2), A(data="d32", id=32, type="a", x=19, y=5), - A(data="d4 upserted", id=ids[1], type="a", x=5, y=6), + A(data="d4 upserted", id=ids["d4"], type="a", x=5, y=6), A(data="d28", id=28, type="a", x=9, y=15), - ], + }, ) @testing.combinations( @@ -893,13 +1054,14 @@ class BulkDMLReturningInhTest: "bulk", argnames="insert_strategy", ) + @testing.variation("sort_by_parameter_order", [True, False]) @testing.requires.provisioned_upsert - def test_subclass_upsert(self, insert_strategy): + def test_subclass_upsert(self, insert_strategy, sort_by_parameter_order): """note this is overridden in the joined version to expect failure""" A, B = self.classes("A", "B") - s = fixture_session() + s = fixture_session(bind=self.bind) idd3 = 1 idd4 = 2 @@ -926,11 +1088,19 @@ class BulkDMLReturningInhTest: "q": 8, }, ] - ids = s.scalars(insert(B).returning(B.id), initial_data).all() + ids = { + row.data: row.id + for row in s.execute( + insert(B).returning( + B.id, B.data, sort_by_parameter_order=True + ), + initial_data, + ) + } upsert_data = [ { - "id": ids[0], + "id": ids["d3"], "type": "b", "data": "d3", "bd": "bd1_upserted", @@ -950,7 +1120,7 @@ class BulkDMLReturningInhTest: "q": 21, }, { - "id": ids[1], + "id": ids["d4"], "type": "b", "bd": "bd2_upserted", "data": "d4", @@ -975,19 +1145,24 @@ class BulkDMLReturningInhTest: config, B, (B,), - lambda inserted: { + set_lambda=lambda inserted: { "data": inserted.data + " upserted", "bd": inserted.bd + " upserted", }, + sort_by_parameter_order=bool(sort_by_parameter_order), ) - result = s.scalars(stmt, upsert_data) + + with self.assert_for_downgrade( + sort_by_parameter_order=bool(sort_by_parameter_order) + ): + result = s.scalars(stmt, upsert_data) eq_( - result.all(), - [ + set(result), + { B( bd="bd1_upserted upserted", data="d3 upserted", - id=ids[0], + id=ids["d3"], q=4, type="b", x=1, @@ -1007,7 +1182,7 @@ class BulkDMLReturningInhTest: B( bd="bd2_upserted upserted", data="d4 upserted", - id=ids[1], + id=ids["d4"], q=8, type="b", x=5, @@ -1024,10 +1199,34 @@ class BulkDMLReturningInhTest: y=15, z=10, ), - ], + }, ) +@testing.combinations( + ( + "no_sentinel", + False, + ), + ( + "w_sentinel", + True, + ), + argnames="use_sentinel", + id_="ia", +) +@testing.combinations( + ( + "nonrandom", + False, + ), + ( + "random", + True, + ), + argnames="randomize_returning", + id_="ia", +) class BulkDMLReturningJoinedInhTest( BulkDMLReturningInhTest, fixtures.DeclarativeMappedTest ): @@ -1035,6 +1234,9 @@ class BulkDMLReturningJoinedInhTest( __requires__ = ("insert_returning", "insert_executemany_returning") __backend__ = True + use_sentinel = False + randomize_returning = False + @classmethod def setup_classes(cls): decl_base = cls.DeclarativeBasic @@ -1047,6 +1249,9 @@ class BulkDMLReturningJoinedInhTest( x: Mapped[Optional[int]] = mapped_column("xcol") y: Mapped[Optional[int]] + if cls.use_sentinel: + _sentinel: Mapped[int] = orm_insert_sentinel() + __mapper_args__ = { "polymorphic_identity": "a", "polymorphic_on": "type", @@ -1061,6 +1266,9 @@ class BulkDMLReturningJoinedInhTest( z: Mapped[Optional[int]] = mapped_column("zcol") q: Mapped[Optional[int]] + if cls.use_sentinel: + _sentinel: Mapped[int] = orm_insert_sentinel() + __mapper_args__ = {"polymorphic_identity": "b"} @testing.combinations( @@ -1073,17 +1281,26 @@ class BulkDMLReturningJoinedInhTest( False, argnames="single_param", ) + @testing.variation("sort_by_parameter_order", [True, False]) @testing.requires.provisioned_upsert - def test_subclass_upsert(self, insert_strategy, single_param): + def test_subclass_upsert( + self, + insert_strategy, + single_param, + sort_by_parameter_order, + ): A, B = self.classes("A", "B") - s = fixture_session() + s = fixture_session(bind=self.bind) initial_data = [ {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}, {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8}, ] - ids = s.scalars(insert(B).returning(B.id), initial_data).all() + ids = s.scalars( + insert(B).returning(B.id, sort_by_parameter_order=True), + initial_data, + ).all() upsert_data = [ { @@ -1102,9 +1319,10 @@ class BulkDMLReturningJoinedInhTest( config, B, (B,), - lambda inserted: { + set_lambda=lambda inserted: { "bd": inserted.bd + " upserted", }, + sort_by_parameter_order=bool(sort_by_parameter_order), ) with expect_raises_message( @@ -1115,6 +1333,18 @@ class BulkDMLReturningJoinedInhTest( s.scalars(stmt, upsert_data) +@testing.combinations( + ( + "nonrandom", + False, + ), + ( + "random", + True, + ), + argnames="randomize_returning", + id_="ia", +) class BulkDMLReturningSingleInhTest( BulkDMLReturningInhTest, fixtures.DeclarativeMappedTest ): @@ -1146,6 +1376,18 @@ class BulkDMLReturningSingleInhTest( __mapper_args__ = {"polymorphic_identity": "b"} +@testing.combinations( + ( + "nonrandom", + False, + ), + ( + "random", + True, + ), + argnames="randomize_returning", + id_="ia", +) class BulkDMLReturningConcreteInhTest( BulkDMLReturningInhTest, fixtures.DeclarativeMappedTest ): @@ -1253,7 +1495,7 @@ class CTETest(fixtures.DeclarativeMappedTest): else: assert False - sess = fixture_session() + sess = fixture_session(bind=self.bind) with self.sql_execution_asserter() as asserter: if not expect_entity: diff --git a/test/orm/test_ac_relationships.py b/test/orm/test_ac_relationships.py index a5efd9930..b500c1e1a 100644 --- a/test/orm/test_ac_relationships.py +++ b/test/orm/test_ac_relationships.py @@ -3,6 +3,7 @@ from sqlalchemy import Column from sqlalchemy import exc from sqlalchemy import ForeignKey from sqlalchemy import func +from sqlalchemy import insert_sentinel from sqlalchemy import Integer from sqlalchemy import join from sqlalchemy import select @@ -42,6 +43,7 @@ class PartitionByFixture(fixtures.DeclarativeMappedTest): __tablename__ = "c" id = Column(Integer, primary_key=True) b_id = Column(ForeignKey("b.id")) + _sentinel = insert_sentinel() partition = select( B, diff --git a/test/orm/test_defaults.py b/test/orm/test_defaults.py index fb6fba704..562d9b9dc 100644 --- a/test/orm/test_defaults.py +++ b/test/orm/test_defaults.py @@ -10,6 +10,7 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing.assertsql import assert_engine from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing.assertsql import Conditional +from sqlalchemy.testing.assertsql import RegexSQL from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -285,26 +286,26 @@ class ComputedDefaultsOnUpdateTest(fixtures.MappedTest): Conditional( testing.db.dialect.insert_executemany_returning, [ - CompiledSQL( - "INSERT INTO test (id, foo) " - "VALUES (%(id)s, %(foo)s) " - "RETURNING test.bar", + RegexSQL( + r"INSERT INTO test \(id, foo\) .*" + r"VALUES \(.*\) .*" + r"RETURNING test.bar, test.id", [{"foo": 5, "id": 1}, {"foo": 10, "id": 2}], dialect="postgresql", ), ], [ - CompiledSQL( - "INSERT INTO test (id, foo) " - "VALUES (%(id)s, %(foo)s) " - "RETURNING test.bar", + RegexSQL( + r"INSERT INTO test \(id, foo\) .*" + r"VALUES \(.*\) .*" + r"RETURNING test.bar, test.id", [{"foo": 5, "id": 1}], dialect="postgresql", ), - CompiledSQL( - "INSERT INTO test (id, foo) " - "VALUES (%(id)s, %(foo)s) " - "RETURNING test.bar", + RegexSQL( + r"INSERT INTO test \(id, foo\) .*" + r"VALUES \(.*\) .*" + r"RETURNING test.bar, test.id", [{"foo": 10, "id": 2}], dialect="postgresql", ), @@ -468,23 +469,23 @@ class IdentityDefaultsOnUpdateTest(fixtures.MappedTest): Conditional( testing.db.dialect.insert_executemany_returning, [ - CompiledSQL( - "INSERT INTO test (foo) VALUES (%(foo)s) " - "RETURNING test.id", + RegexSQL( + r"INSERT INTO test \(foo\).*VALUES (.*).* " + r"RETURNING test.id, test.id AS id__1", [{"foo": 5}, {"foo": 10}], dialect="postgresql", ), ], [ - CompiledSQL( - "INSERT INTO test (foo) VALUES (%(foo)s) " - "RETURNING test.id", + RegexSQL( + r"INSERT INTO test \(foo\).*VALUES (.*).* " + r"RETURNING test.id, test.id AS id__1", [{"foo": 5}], dialect="postgresql", ), - CompiledSQL( - "INSERT INTO test (foo) VALUES (%(foo)s) " - "RETURNING test.id", + RegexSQL( + r"INSERT INTO test \(foo\).*VALUES (.*).* " + r"RETURNING test.id, test.id AS id__1", [{"foo": 10}], dialect="postgresql", ), diff --git a/test/orm/test_expire.py b/test/orm/test_expire.py index f98cae922..906771f16 100644 --- a/test/orm/test_expire.py +++ b/test/orm/test_expire.py @@ -506,7 +506,9 @@ class ExpireTest(_fixtures.FixtureTest): users, properties={ "addresses": relationship( - Address, cascade="all, refresh-expire" + Address, + cascade="all, refresh-expire", + order_by=addresses.c.id, ) }, ) diff --git a/test/orm/test_unitofwork.py b/test/orm/test_unitofwork.py index 5835ef65a..f9c565c86 100644 --- a/test/orm/test_unitofwork.py +++ b/test/orm/test_unitofwork.py @@ -1,6 +1,7 @@ """Tests unitofwork operations.""" import datetime +import re import sqlalchemy as sa from sqlalchemy import Boolean @@ -3513,14 +3514,15 @@ class PartialNullPKTest(fixtures.MappedTest): class NoRowInsertedTest(fixtures.TestBase): """test #7594. - failure modes when INSERT doesnt actually insert a row. + failure modes when INSERT doesn't actually insert a row. + s """ - __backend__ = True - # the test manipulates INSERTS to become UPDATES to simulate - # "INSERT that returns no row" so both are needed - __requires__ = ("insert_returning", "update_returning") + # "INSERT that returns no row" so both are needed; the manipulations + # are currently postgresql or SQLite specific + __backend__ = True + __only_on__ = ("postgresql", "sqlite") @testing.fixture def null_server_default_fixture(self, registry, connection): @@ -3537,30 +3539,26 @@ class NoRowInsertedTest(fixtures.TestBase): def revert_insert( conn, cursor, statement, parameters, context, executemany ): - if statement.startswith("INSERT"): - if statement.endswith("RETURNING my_table.id"): - if executemany and isinstance(parameters, list): - # remove some rows, so the count is wrong - parameters = parameters[0:1] - else: - # statement should return no rows - statement = ( - "UPDATE my_table SET id=NULL WHERE 1!=1 " - "RETURNING my_table.id" - ) - parameters = {} + if re.match(r"INSERT.* RETURNING (?:my_table.)?id", statement): + if executemany and isinstance(parameters, list): + # remove some rows, so the count is wrong + parameters = parameters[0:1] else: - assert not testing.against( - "postgresql" - ), "this test has to at least run on PostgreSQL" - testing.config.skip_test( - "backend doesn't support the expected form of " - "RETURNING for this test to work" + # statement should return no rows + statement = ( + "UPDATE my_table SET id=NULL WHERE 1!=1 " + "RETURNING my_table.id" ) + parameters = {} return statement, parameters return MyClass + @testing.only_on( + "postgresql", + "only postgresql uses RETURNING for a single-row " + "INSERT among the DBs we are using in this test", + ) def test_insert_single_no_pk_correct_exception( self, null_server_default_fixture, connection ): diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py index 968285148..4d04ce0a6 100644 --- a/test/orm/test_unitofworkv2.py +++ b/test/orm/test_unitofworkv2.py @@ -1,5 +1,6 @@ from unittest.mock import Mock from unittest.mock import patch +import uuid from sqlalchemy import cast from sqlalchemy import DateTime @@ -9,6 +10,8 @@ from sqlalchemy import FetchedValue from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Identity +from sqlalchemy import insert +from sqlalchemy import insert_sentinel from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import JSON @@ -19,6 +22,7 @@ from sqlalchemy import String from sqlalchemy import testing from sqlalchemy import text from sqlalchemy import util +from sqlalchemy import Uuid from sqlalchemy.orm import attributes from sqlalchemy.orm import backref from sqlalchemy.orm import clear_mappers @@ -32,12 +36,14 @@ from sqlalchemy.testing import assert_warns_message from sqlalchemy.testing import config from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import variation_fixture from sqlalchemy.testing.assertsql import AllOf from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing.assertsql import Conditional +from sqlalchemy.testing.assertsql import RegexSQL from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.provision import normalize_sequence from sqlalchemy.testing.schema import Column @@ -2536,23 +2542,26 @@ class EagerDefaultsTest(fixtures.MappedTest): Conditional( testing.db.dialect.insert_executemany_returning, [ - CompiledSQL( - "INSERT INTO test (id) VALUES (%(id)s) " - "RETURNING test.foo", + RegexSQL( + r"INSERT INTO test \(id\) .*" + r"VALUES \(.*\) .*" + r"RETURNING test.foo, test.id", [{"id": 1}, {"id": 2}], dialect="postgresql", ), ], [ - CompiledSQL( - "INSERT INTO test (id) VALUES (%(id)s) " - "RETURNING test.foo", + RegexSQL( + r"INSERT INTO test \(id\) .*" + r"VALUES \(.*\) .*" + r"RETURNING test.foo, test.id", [{"id": 1}], dialect="postgresql", ), - CompiledSQL( - "INSERT INTO test (id) VALUES (%(id)s) " - "RETURNING test.foo", + RegexSQL( + r"INSERT INTO test \(id\) .*" + r"VALUES \(.*\) .*" + r"RETURNING test.foo, test.id", [{"id": 2}], dialect="postgresql", ), @@ -2595,26 +2604,26 @@ class EagerDefaultsTest(fixtures.MappedTest): Conditional( testing.db.dialect.insert_executemany_returning, [ - CompiledSQL( - "INSERT INTO test3 (id, foo) " - "VALUES (%(id)s, lower(%(lower_1)s)) " - "RETURNING test3.foo", + RegexSQL( + r"INSERT INTO test3 \(id, foo\) .*" + r"VALUES \(.*\) .*" + r"RETURNING test3.foo, test3.id", [{"id": 1}, {"id": 2}], dialect="postgresql", ), ], [ - CompiledSQL( - "INSERT INTO test3 (id, foo) " - "VALUES (%(id)s, lower(%(lower_1)s)) " - "RETURNING test3.foo", + RegexSQL( + r"INSERT INTO test3 \(id, foo\) .*" + r"VALUES \(.*\) .*" + r"RETURNING test3.foo, test3.id", [{"id": 1}], dialect="postgresql", ), - CompiledSQL( - "INSERT INTO test3 (id, foo) " - "VALUES (%(id)s, lower(%(lower_1)s)) " - "RETURNING test3.foo", + RegexSQL( + r"INSERT INTO test3 \(id, foo\) .*" + r"VALUES \(.*\) .*" + r"RETURNING test3.foo, test3.id", [{"id": 2}], dialect="postgresql", ), @@ -3830,9 +3839,19 @@ class TryToFoolInsertManyValuesTest(fixtures.TestBase): ("identity", testing.requires.identity_columns), ], ) - def test_bulk_insert_maintains_correct_pks( - self, decl_base, connection, pk_type - ): + @testing.variation( + "sentinel", + [ + "none", # passes because we automatically downgrade + # for no sentinel col + "implicit_not_omitted", + "implicit_omitted", + "explicit", + "default_uuid", + "default_string_uuid", + ], + ) + def test_original_use_case(self, decl_base, connection, pk_type, sentinel): """test #9603. this uses the ORM to ensure the ORM is not using any kind of @@ -3840,75 +3859,221 @@ class TryToFoolInsertManyValuesTest(fixtures.TestBase): specific to SQL Server, however if we identify any other similar issues in other DBs we should add tests to this suite. + NOTE: Assuming the code is not doing the correct kind of INSERT + for SQL Server, the SQL Server failure here is still extremely + difficult to trip; any changes to the table structure and it no longer + fails, and it's likely this version of the test might not fail on SQL + Server in any case. The test_this_really_fails_on_mssql_wo_full_fix is + more optimized to producing the SQL Server failure as reliably as + possible, however this can change at any time as SQL Server's decisions + here are completely opaque. + """ class Datum(decl_base): __tablename__ = "datum" - id = Column(Integer, autoincrement=False, primary_key=True) - data = Column(String(10)) + datum_id = Column(Integer, Identity(), primary_key=True) class Result(decl_base): __tablename__ = "result" if pk_type.plain_autoinc: - id = Column(Integer, primary_key=True) # noqa: A001 + result_id = Column(Integer, primary_key=True) elif pk_type.sequence: - id = Column( # noqa: A001 - Integer, Sequence("rid_seq", start=1), primary_key=True + result_id = Column( + Integer, + Sequence("result_id_seq", start=1), + primary_key=True, ) elif pk_type.identity: - id = Column( # noqa: A001 - Integer, Identity(), primary_key=True - ) + result_id = Column(Integer, Identity(), primary_key=True) else: pk_type.fail() - thing = Column(Integer) - lft_datum_id = Column(Integer, ForeignKey(Datum.id)) + lft_datum_id = Column(ForeignKey(Datum.datum_id)) - decl_base.metadata.create_all(connection) - with Session(connection) as sess: + lft_datum = relationship(Datum) - size = 15 - datum_ids = list(range(1, size + 1)) + if sentinel.implicit_not_omitted or sentinel.implicit_omitted: + _sentinel = insert_sentinel( + omit_from_statements=bool(sentinel.implicit_omitted), + ) + elif sentinel.explicit: + some_uuid = Column( + Uuid(), insert_sentinel=True, nullable=False + ) + elif sentinel.default_uuid or sentinel.default_string_uuid: + _sentinel = Column( + Uuid(native_uuid=bool(sentinel.default_uuid)), + insert_sentinel=True, + default=uuid.uuid4, + ) - sess.add_all([Datum(id=id_, data=f"d{id_}") for id_ in datum_ids]) - sess.flush() + class ResultDatum(decl_base): - result_data = [ - Result(thing=num, lft_datum_id=datum_ids[num % size]) - for num in range(size * size) - ] - sess.add_all(result_data) + __tablename__ = "result_datum" + + result_id = Column(ForeignKey(Result.result_id), primary_key=True) + lft_datum_id = Column(ForeignKey(Datum.datum_id)) + + lft_datum = relationship(Datum) + result = relationship(Result) + + if sentinel.implicit_not_omitted or sentinel.implicit_omitted: + _sentinel = insert_sentinel( + omit_from_statements=bool(sentinel.implicit_omitted), + ) + elif sentinel.explicit: + some_uuid = Column( + Uuid(native_uuid=False), + insert_sentinel=True, + nullable=False, + ) + elif sentinel.default_uuid or sentinel.default_string_uuid: + _sentinel = Column( + Uuid(native_uuid=bool(sentinel.default_uuid)), + insert_sentinel=True, + default=uuid.uuid4, + ) + + decl_base.metadata.create_all(connection) + N = 13 + with Session(connection) as sess: + full_range = [num for num in range(N * N)] + + datum_idx = [Datum() for num in range(N)] + sess.add_all(datum_idx) sess.flush() - # this is what we expected we put in - the_data_in_order_should_be = [ - (num + 1, num, datum_ids[num % size]) - for num in range(size * size) - ] + if sentinel.explicit: + result_idx = [ + Result( + lft_datum=datum_idx[n % N], + some_uuid=uuid.uuid4(), + ) + for n in full_range + ] + else: + result_idx = [ + Result( + lft_datum=datum_idx[n % N], + ) + for n in full_range + ] + + sess.add_all(result_idx) + + if sentinel.explicit: + sess.add_all( + ResultDatum( + lft_datum=datum_idx[n % N], + result=result_idx[n], + some_uuid=uuid.uuid4(), + ) + for n in full_range + ) + else: + sess.add_all( + ResultDatum( + lft_datum=datum_idx[n % N], + result=result_idx[n], + ) + for n in full_range + ) - # and yes, that's what went in - eq_( - sess.execute( - select( - Result.id, Result.thing, Result.lft_datum_id - ).order_by(Result.id) - ).all(), - the_data_in_order_should_be, + fixtures.insertmanyvalues_fixture( + sess.connection(), warn_on_downgraded=True ) + if ( + sentinel.none + and testing.db.dialect.insert_returning + and testing.db.dialect.use_insertmanyvalues + and select() + .compile(dialect=testing.db.dialect) + ._get_sentinel_column_for_table(Result.__table__) + is None + ): + with expect_warnings( + "Batches were downgraded for sorted INSERT" + ): + sess.flush() + else: + sess.flush() - # however, if insertmanyvalues is turned on, OUTPUT inserted - # did not give us the rows in the order we sent, so ids were - # mis-applied. even if we sort the original records by the - # ids that were given - eq_( - [ - (r.id, r.thing, r.lft_datum_id) - for r in sorted(result_data, key=lambda r: r.id) - ], - the_data_in_order_should_be, + num_bad = ( + sess.query(ResultDatum) + .join(Result) + .filter( + Result.lft_datum_id != ResultDatum.lft_datum_id, + ) + .count() ) + + eq_(num_bad, 0) + + @testing.only_on("mssql") + def test_this_really_fails_on_mssql_wo_full_fix( + self, decl_base, connection + ): + """this test tries as hard as possible to simulate the SQL server + failure. + + """ + + class Datum(decl_base): + + __tablename__ = "datum" + + datum_id = Column(Integer, primary_key=True) + data = Column(String(10)) + + class Result(decl_base): + + __tablename__ = "result" + + result_id = Column(Integer, primary_key=True) + + lft_datum_id = Column(Integer, ForeignKey(Datum.datum_id)) + + # use this instead to resolve; FK constraint is what affects + # SQL server + # lft_datum_id = Column(Integer) + + decl_base.metadata.create_all(connection) + + size = 13 + + result = connection.execute( + insert(Datum).returning(Datum.datum_id), + [{"data": f"d{i}"} for i in range(size)], + ) + + datum_ids = [row[0] for row in result] + assert datum_ids == list(range(1, size + 1)) + + # the rows are not inserted in the order that the table valued + # expressions are given. SQL Server organizes the rows so that the + # "datum_id" values are grouped + result = connection.execute( + insert(Result).returning( + Result.result_id, + Result.lft_datum_id, + sort_by_parameter_order=True, + ), + [ + {"lft_datum_id": datum_ids[num % size]} + for num in range(size * size) + ], + ) + + we_expect_returning_is = [ + {"result_id": num + 1, "lft_datum_id": datum_ids[num % size]} + for num in range(size * size) + ] + what_we_got_is = [ + {"result_id": row[0], "lft_datum_id": row[1]} for row in result + ] + eq_(we_expect_returning_is, what_we_got_is) diff --git a/test/requirements.py b/test/requirements.py index b76e671e9..9a8500ac3 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -1904,6 +1904,10 @@ class DefaultRequirements(SuiteRequirements): return only_if(["postgresql >= 10", "oracle >= 12", "mssql"]) @property + def multiple_identity_columns(self): + return only_if(["postgresql >= 10"]) + + @property def identity_columns_standard(self): return self.identity_columns + skip_if("mssql") diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index e05fafbdf..4d0864af9 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -35,6 +35,7 @@ from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Index from sqlalchemy import insert +from sqlalchemy import insert_sentinel from sqlalchemy import Integer from sqlalchemy import intersect from sqlalchemy import join @@ -7705,3 +7706,243 @@ class ResultMapTest(fixtures.TestBase): for orig_obj, proxied_obj in zip(orig, proxied): is_(orig_obj, proxied_obj) + + +class OmitFromStatementsTest(fixtures.TestBase, AssertsCompiledSQL): + """test the _omit_from_statements parameter. + + this somewhat awkward parameter was added to suit the case of + "insert_sentinel" columns that would try very hard not to be noticed + when not needed, by being omitted from any SQL statement that does not + refer to them explicitly. If they are referred to explicitly or + are in a context where their client side default has to be fired off, + then they are present. + + If marked public, the feature could be used as a general "I don't want to + see this column unless I asked it to" use case. + + """ + + __dialect__ = "default_enhanced" + + @testing.fixture + def t1(self): + m1 = MetaData() + + t1 = Table( + "t1", + m1, + Column("id", Integer, primary_key=True), + Column("a", Integer), + Column( + "b", Integer, _omit_from_statements=True, insert_sentinel=True + ), + Column("c", Integer), + Column("d", Integer, _omit_from_statements=True), + Column("e", Integer), + ) + return t1 + + @testing.fixture + def t2(self): + m1 = MetaData() + + t2 = Table( + "t2", + m1, + Column("id", Integer, primary_key=True), + Column("a", Integer), + Column( + "b", + Integer, + _omit_from_statements=True, + insert_sentinel=True, + default="10", + onupdate="20", + ), + Column("c", Integer, default="14", onupdate="19"), + Column( + "d", + Integer, + _omit_from_statements=True, + default="5", + onupdate="15", + ), + Column("e", Integer), + ) + return t2 + + @testing.fixture + def t3(self): + m1 = MetaData() + + t3 = Table( + "t3", + m1, + Column("id", Integer, primary_key=True), + Column("a", Integer), + insert_sentinel("b"), + Column("c", Integer, default="14", onupdate="19"), + ) + return t3 + + def test_select_omitted(self, t1): + self.assert_compile( + select(t1), "SELECT t1.id, t1.a, t1.c, t1.e FROM t1" + ) + + def test_select_from_subquery_includes_hidden(self, t1): + s1 = select(t1.c.a, t1.c.b, t1.c.c, t1.c.d, t1.c.e).subquery() + eq_(s1.c.keys(), ["a", "b", "c", "d", "e"]) + + self.assert_compile( + select(s1), + "SELECT anon_1.a, anon_1.b, anon_1.c, anon_1.d, anon_1.e " + "FROM (SELECT t1.a AS a, t1.b AS b, t1.c AS c, t1.d AS d, " + "t1.e AS e FROM t1) AS anon_1", + ) + + def test_select_from_subquery_omitted(self, t1): + s1 = select(t1).subquery() + + eq_(s1.c.keys(), ["id", "a", "c", "e"]) + self.assert_compile( + select(s1), + "SELECT anon_1.id, anon_1.a, anon_1.c, anon_1.e FROM " + "(SELECT t1.id AS id, t1.a AS a, t1.c AS c, t1.e AS e FROM t1) " + "AS anon_1", + ) + + def test_insert_omitted(self, t1): + self.assert_compile( + insert(t1), "INSERT INTO t1 (id, a, c, e) VALUES (:id, :a, :c, :e)" + ) + + def test_insert_from_select_omitted(self, t1): + self.assert_compile( + insert(t1).from_select(["a", "c", "e"], select(t1)), + "INSERT INTO t1 (a, c, e) SELECT t1.id, t1.a, t1.c, t1.e FROM t1", + ) + + def test_insert_from_select_included(self, t1): + self.assert_compile( + insert(t1).from_select(["a", "b", "c", "d", "e"], select(t1)), + "INSERT INTO t1 (a, b, c, d, e) SELECT t1.id, t1.a, t1.c, t1.e " + "FROM t1", + ) + + def test_insert_from_select_defaults_included(self, t2): + self.assert_compile( + insert(t2).from_select(["a", "c", "e"], select(t2)), + "INSERT INTO t2 (a, c, e, b, d) SELECT t2.id, t2.a, t2.c, t2.e, " + ":b AS anon_1, :d AS anon_2 FROM t2", + # TODO: do we have a test in test_defaults for this, that the + # default values get set up as expected? + ) + + def test_insert_from_select_sentinel_defaults_omitted(self, t3): + self.assert_compile( + # a pure SentinelDefault not included here, so there is no 'b' + insert(t3).from_select(["a", "c"], select(t3)), + "INSERT INTO t3 (a, c) SELECT t3.id, t3.a, t3.c FROM t3", + ) + + def test_insert_omitted_return_col_nonspecified(self, t1): + self.assert_compile( + insert(t1).returning(t1), + "INSERT INTO t1 (id, a, c, e) VALUES (:id, :a, :c, :e) " + "RETURNING t1.id, t1.a, t1.c, t1.e", + ) + + def test_insert_omitted_return_col_specified(self, t1): + self.assert_compile( + insert(t1).returning(t1.c.a, t1.c.b, t1.c.c, t1.c.d, t1.c.e), + "INSERT INTO t1 (id, a, c, e) VALUES (:id, :a, :c, :e) " + "RETURNING t1.a, t1.b, t1.c, t1.d, t1.e", + ) + + def test_insert_omitted_no_params(self, t1): + self.assert_compile( + insert(t1), "INSERT INTO t1 () VALUES ()", params={} + ) + + def test_insert_omitted_no_params_defaults(self, t2): + # omit columns that nonetheless have client-side defaults + # are included + self.assert_compile( + insert(t2), + "INSERT INTO t2 (b, c, d) VALUES (:b, :c, :d)", + params={}, + ) + + def test_insert_omitted_no_params_defaults_no_sentinel(self, t3): + # omit columns that nonetheless have client-side defaults + # are included + self.assert_compile( + insert(t3), + "INSERT INTO t3 (c) VALUES (:c)", + params={}, + ) + + def test_insert_omitted_defaults(self, t2): + self.assert_compile( + insert(t2), "INSERT INTO t2 (id, a, c, e) VALUES (:id, :a, :c, :e)" + ) + + def test_update_omitted(self, t1): + self.assert_compile( + update(t1), "UPDATE t1 SET id=:id, a=:a, c=:c, e=:e" + ) + + def test_update_omitted_defaults(self, t2): + self.assert_compile( + update(t2), "UPDATE t2 SET id=:id, a=:a, c=:c, e=:e" + ) + + def test_update_omitted_no_params_defaults(self, t2): + # omit columns that nonetheless have client-side defaults + # are included + self.assert_compile( + update(t2), "UPDATE t2 SET b=:b, c=:c, d=:d", params={} + ) + + def test_select_include_col(self, t1): + self.assert_compile( + select(t1, t1.c.b, t1.c.d), + "SELECT t1.id, t1.a, t1.c, t1.e, t1.b, t1.d FROM t1", + ) + + def test_update_include_col(self, t1): + self.assert_compile( + update(t1).values(a=5, b=10, c=15, d=20, e=25), + "UPDATE t1 SET a=:a, b=:b, c=:c, d=:d, e=:e", + checkparams={"a": 5, "b": 10, "c": 15, "d": 20, "e": 25}, + ) + + def test_insert_include_col(self, t1): + self.assert_compile( + insert(t1).values(a=5, b=10, c=15, d=20, e=25), + "INSERT INTO t1 (a, b, c, d, e) VALUES (:a, :b, :c, :d, :e)", + checkparams={"a": 5, "b": 10, "c": 15, "d": 20, "e": 25}, + ) + + def test_insert_include_col_via_keys(self, t1): + self.assert_compile( + insert(t1), + "INSERT INTO t1 (a, b, c, d, e) VALUES (:a, :b, :c, :d, :e)", + params={"a": 5, "b": 10, "c": 15, "d": 20, "e": 25}, + checkparams={"a": 5, "b": 10, "c": 15, "d": 20, "e": 25}, + ) + + def test_select_omitted_incl_whereclause(self, t1): + self.assert_compile( + select(t1).where(t1.c.d == 5), + "SELECT t1.id, t1.a, t1.c, t1.e FROM t1 WHERE t1.d = :d_1", + checkparams={"d_1": 5}, + ) + + def test_select_omitted_incl_order_by(self, t1): + self.assert_compile( + select(t1).order_by(t1.c.d), + "SELECT t1.id, t1.a, t1.c, t1.e FROM t1 ORDER BY t1.d", + ) diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py index 633972b45..01f6b290c 100644 --- a/test/sql/test_defaults.py +++ b/test/sql/test_defaults.py @@ -1567,6 +1567,7 @@ class CurrentParametersTest(fixtures.TablesTest): some_table = self.tables.some_table some_table.c.x.default.arg = gen_default + some_table.c.x._reset_memoizations() return fn @testing.combinations( diff --git a/test/sql/test_insert.py b/test/sql/test_insert.py index 308f654f7..904271fcb 100644 --- a/test/sql/test_insert.py +++ b/test/sql/test_insert.py @@ -68,7 +68,7 @@ class _InsertTestBase: class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): - __dialect__ = "default" + __dialect__ = "default_enhanced" @testing.combinations( ((), ("z",), ()), @@ -94,6 +94,51 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): assert isinstance(stmt._return_defaults_columns, tuple) eq_(set(stmt._return_defaults_columns), expected) + @testing.variation("add_values", ["before", "after"]) + @testing.variation("multi_values", [True, False]) + @testing.variation("sort_by_parameter_order", [True, False]) + def test_sort_by_parameter_ordering_parameter_no_multi_values( + self, add_values, multi_values, sort_by_parameter_order + ): + t = table("foo", column("x"), column("y"), column("z")) + stmt = insert(t) + + if add_values.before: + if multi_values: + stmt = stmt.values([{"y": 6}, {"y": 7}]) + else: + stmt = stmt.values(y=6) + + stmt = stmt.returning( + t.c.x, sort_by_parameter_order=bool(sort_by_parameter_order) + ) + + if add_values.after: + if multi_values: + stmt = stmt.values([{"y": 6}, {"y": 7}]) + else: + stmt = stmt.values(y=6) + + if multi_values: + if sort_by_parameter_order: + with expect_raises_message( + exc.CompileError, + "RETURNING cannot be determinstically sorted " + "when using an INSERT", + ): + stmt.compile() + else: + self.assert_compile( + stmt, + "INSERT INTO foo (y) VALUES (:y_m0), (:y_m1) " + "RETURNING foo.x", + ) + else: + self.assert_compile( + stmt, + "INSERT INTO foo (y) VALUES (:y) RETURNING foo.x", + ) + def test_binds_that_match_columns(self): """test bind params named after column names replace the normal SET/VALUES generation. diff --git a/test/sql/test_insert_exec.py b/test/sql/test_insert_exec.py index 3b5a1856c..f545671e7 100644 --- a/test/sql/test_insert_exec.py +++ b/test/sql/test_insert_exec.py @@ -1,11 +1,19 @@ +import contextlib +import functools import itertools +import uuid from sqlalchemy import and_ +from sqlalchemy import ARRAY from sqlalchemy import bindparam +from sqlalchemy import DateTime from sqlalchemy import event from sqlalchemy import exc from sqlalchemy import ForeignKey from sqlalchemy import func +from sqlalchemy import Identity +from sqlalchemy import insert +from sqlalchemy import insert_sentinel from sqlalchemy import INT from sqlalchemy import Integer from sqlalchemy import literal @@ -14,16 +22,22 @@ from sqlalchemy import Sequence from sqlalchemy import sql from sqlalchemy import String from sqlalchemy import testing +from sqlalchemy import TypeDecorator +from sqlalchemy import Uuid from sqlalchemy import VARCHAR from sqlalchemy.engine import cursor as _cursor +from sqlalchemy.sql.compiler import InsertmanyvaluesSentinelOpts from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import config from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises from sqlalchemy.testing import expect_raises_message +from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import mock from sqlalchemy.testing import provision +from sqlalchemy.testing.fixtures import insertmanyvalues_fixture from sqlalchemy.testing.provision import normalize_sequence from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -924,7 +938,7 @@ class InsertManyValuesTest(fixtures.RemovesEvents, fixtures.TablesTest): config, data, (data,), - lambda inserted: {"x": inserted.x + " upserted"}, + set_lambda=lambda inserted: {"x": inserted.x + " upserted"}, ) result = connection.execute(stmt, upsert_data) @@ -1169,3 +1183,1612 @@ class InsertManyValuesTest(fixtures.RemovesEvents, fixtures.TablesTest): "INSERT..RETURNING when executemany", ): conn.execute(stmt.returning(t.c.id), data) + + +class IMVSentinelTest(fixtures.TestBase): + __backend__ = True + + __requires__ = ("insert_returning",) + + def _expect_downgrade_warnings( + self, + *, + warn_for_downgrades, + sort_by_parameter_order, + separate_sentinel=False, + server_autoincrement=False, + client_side_pk=False, + autoincrement_is_sequence=False, + connection=None, + ): + + if connection: + dialect = connection.dialect + else: + dialect = testing.db.dialect + + if ( + sort_by_parameter_order + and warn_for_downgrades + and dialect.use_insertmanyvalues + ): + + if ( + not separate_sentinel + and ( + server_autoincrement + and ( + not ( + dialect.insertmanyvalues_implicit_sentinel # noqa: E501 + & InsertmanyvaluesSentinelOpts.ANY_AUTOINCREMENT + ) + or ( + autoincrement_is_sequence + and not ( + dialect.insertmanyvalues_implicit_sentinel # noqa: E501 + & InsertmanyvaluesSentinelOpts.SEQUENCE + ) + ) + ) + ) + or ( + not separate_sentinel + and not server_autoincrement + and not client_side_pk + ) + ): + return expect_warnings( + "Batches were downgraded", + raise_on_any_unexpected=True, + ) + + return contextlib.nullcontext() + + @testing.variation + def sort_by_parameter_order(self): + return [True, False] + + @testing.variation + def warn_for_downgrades(self): + return [True, False] + + @testing.variation + def randomize_returning(self): + return [True, False] + + @testing.requires.insertmanyvalues + def test_fixture_randomizing(self, connection, metadata): + t = Table( + "t", + metadata, + Column("id", Integer, Identity(), primary_key=True), + Column("data", String(50)), + ) + metadata.create_all(connection) + + insertmanyvalues_fixture(connection, randomize_rows=True) + + results = set() + + for i in range(15): + result = connection.execute( + insert(t).returning(t.c.data, sort_by_parameter_order=False), + [{"data": "d1"}, {"data": "d2"}, {"data": "d3"}], + ) + + hashed_result = tuple(result.all()) + results.add(hashed_result) + if len(results) > 1: + return + else: + assert False, "got same order every time for 15 tries" + + @testing.only_on("postgresql>=13") + @testing.variation("downgrade", [True, False]) + def test_fixture_downgraded(self, connection, metadata, downgrade): + t = Table( + "t", + metadata, + Column( + "id", + Uuid(), + server_default=func.gen_random_uuid(), + primary_key=True, + ), + Column("data", String(50)), + ) + metadata.create_all(connection) + + r1 = connection.execute( + insert(t).returning(t.c.data, sort_by_parameter_order=True), + [{"data": "d1"}, {"data": "d2"}, {"data": "d3"}], + ) + eq_(r1.all(), [("d1",), ("d2",), ("d3",)]) + + if downgrade: + insertmanyvalues_fixture(connection, warn_on_downgraded=True) + + with self._expect_downgrade_warnings( + warn_for_downgrades=True, + sort_by_parameter_order=True, + ): + connection.execute( + insert(t).returning( + t.c.data, sort_by_parameter_order=True + ), + [{"data": "d1"}, {"data": "d2"}, {"data": "d3"}], + ) + else: + # run a plain test to help ensure the fixture doesn't leak to + # other tests + r1 = connection.execute( + insert(t).returning(t.c.data, sort_by_parameter_order=True), + [{"data": "d1"}, {"data": "d2"}, {"data": "d3"}], + ) + eq_(r1.all(), [("d1",), ("d2",), ("d3",)]) + + @testing.variation( + "sequence_type", + [ + ("sequence", testing.requires.sequences), + ("identity", testing.requires.identity_columns), + ], + ) + @testing.variation("increment", ["positive", "negative", "implicit"]) + @testing.variation("explicit_sentinel", [True, False]) + def test_invalid_identities( + self, + metadata, + connection, + warn_for_downgrades, + randomize_returning, + sort_by_parameter_order, + sequence_type: testing.Variation, + increment: testing.Variation, + explicit_sentinel, + ): + if sequence_type.sequence: + seq_cls = functools.partial(Sequence, name="t1_id_seq") + elif sequence_type.identity: + seq_cls = Identity + else: + sequence_type.fail() + + if increment.implicit: + sequence = seq_cls(start=1) + elif increment.positive: + sequence = seq_cls(start=1, increment=1) + elif increment.negative: + sequence = seq_cls(start=-1, increment=-1) + else: + increment.fail() + + t1 = Table( + "t1", + metadata, + Column( + "id", + Integer, + sequence, + primary_key=True, + insert_sentinel=bool(explicit_sentinel), + ), + Column("data", String(50)), + ) + metadata.create_all(connection) + + fixtures.insertmanyvalues_fixture( + connection, + randomize_rows=bool(randomize_returning), + warn_on_downgraded=bool(warn_for_downgrades), + ) + + stmt = insert(t1).returning( + t1.c.id, + t1.c.data, + sort_by_parameter_order=bool(sort_by_parameter_order), + ) + data = [{"data": f"d{i}"} for i in range(10)] + + use_imv = testing.db.dialect.use_insertmanyvalues + if ( + use_imv + and increment.negative + and explicit_sentinel + and sort_by_parameter_order + ): + with expect_raises_message( + exc.InvalidRequestError, + rf"Can't use " + rf"{'SEQUENCE' if sequence_type.sequence else 'IDENTITY'} " + rf"default with negative increment", + ): + connection.execute(stmt, data) + return + elif ( + use_imv + and explicit_sentinel + and sort_by_parameter_order + and sequence_type.sequence + and not ( + testing.db.dialect.insertmanyvalues_implicit_sentinel + & InsertmanyvaluesSentinelOpts.SEQUENCE + ) + ): + with expect_raises_message( + exc.InvalidRequestError, + r"Column t1.id can't be explicitly marked as a sentinel " + r"column .* as the particular type of default generation", + ): + connection.execute(stmt, data) + return + + with self._expect_downgrade_warnings( + warn_for_downgrades=warn_for_downgrades, + sort_by_parameter_order=sort_by_parameter_order, + server_autoincrement=not increment.negative, + autoincrement_is_sequence=sequence_type.sequence, + ): + result = connection.execute(stmt, data) + + if sort_by_parameter_order: + coll = list + else: + coll = set + + if increment.negative: + expected_data = [(-1 - i, f"d{i}") for i in range(10)] + else: + expected_data = [(i + 1, f"d{i}") for i in range(10)] + + eq_( + coll(result), + coll(expected_data), + ) + + @testing.combinations( + Integer(), + String(50), + (ARRAY(Integer()), testing.requires.array_type), + DateTime(), + Uuid(), + argnames="datatype", + ) + def test_inserts_w_all_nulls( + self, connection, metadata, sort_by_parameter_order, datatype + ): + """this test is geared towards the INSERT..SELECT VALUES case, + where if the VALUES have all NULL for some column, PostgreSQL assumes + the datatype must be TEXT and throws for other table datatypes. So an + additional layer of casts is applied to the SELECT p0,p1, p2... part of + the statement for all datatypes unconditionally. Even though the VALUES + clause also has bind casts for selected datatypes, this NULL handling + is needed even for simple datatypes. We'd prefer not to render bind + casts for all possible datatypes as that affects other kinds of + statements as well and also is very verbose for insertmanyvalues. + + + """ + t = Table( + "t", + metadata, + Column("id", Integer, Identity(), primary_key=True), + Column("data", datatype), + ) + metadata.create_all(connection) + result = connection.execute( + insert(t).returning( + t.c.id, + sort_by_parameter_order=bool(sort_by_parameter_order), + ), + [{"data": None}, {"data": None}, {"data": None}], + ) + eq_(set(result), {(1,), (2,), (3,)}) + + @testing.variation("pk_type", ["autoinc", "clientside"]) + @testing.variation("add_sentinel", ["none", "clientside", "sentinel"]) + def test_imv_w_additional_values( + self, + metadata, + connection, + sort_by_parameter_order, + pk_type: testing.Variation, + randomize_returning, + warn_for_downgrades, + add_sentinel, + ): + if pk_type.autoinc: + pk_col = Column("id", Integer(), Identity(), primary_key=True) + elif pk_type.clientside: + pk_col = Column("id", Uuid(), default=uuid.uuid4, primary_key=True) + else: + pk_type.fail() + + if add_sentinel.clientside: + extra_col = insert_sentinel( + "sentinel", type_=Uuid(), default=uuid.uuid4 + ) + elif add_sentinel.sentinel: + extra_col = insert_sentinel("sentinel") + else: + extra_col = Column("sentinel", Integer()) + + t1 = Table( + "t1", + metadata, + pk_col, + Column("data", String(30)), + Column("moredata", String(30)), + extra_col, + Column( + "has_server_default", + String(50), + server_default="some_server_default", + ), + ) + metadata.create_all(connection) + + fixtures.insertmanyvalues_fixture( + connection, + randomize_rows=bool(randomize_returning), + warn_on_downgraded=bool(warn_for_downgrades), + ) + + stmt = ( + insert(t1) + .values(moredata="more data") + .returning( + t1.c.data, + t1.c.moredata, + t1.c.has_server_default, + sort_by_parameter_order=bool(sort_by_parameter_order), + ) + ) + data = [{"data": f"d{i}"} for i in range(10)] + + with self._expect_downgrade_warnings( + warn_for_downgrades=warn_for_downgrades, + sort_by_parameter_order=sort_by_parameter_order, + separate_sentinel=not add_sentinel.none, + server_autoincrement=pk_type.autoinc, + client_side_pk=pk_type.clientside, + ): + result = connection.execute(stmt, data) + + if sort_by_parameter_order: + coll = list + else: + coll = set + + eq_( + coll(result), + coll( + [ + (f"d{i}", "more data", "some_server_default") + for i in range(10) + ] + ), + ) + + def test_sentinel_incorrect_rowcount( + self, metadata, connection, sort_by_parameter_order + ): + """test assertions to ensure sentinel values don't have duplicates""" + + uuids = [uuid.uuid4() for i in range(10)] + + # make some dupes + uuids[3] = uuids[5] + uuids[9] = uuids[5] + + t1 = Table( + "data", + metadata, + Column("id", Integer, Identity(), primary_key=True), + Column("data", String(50)), + insert_sentinel( + "uuids", + Uuid(), + default=functools.partial(next, iter(uuids)), + ), + ) + + metadata.create_all(connection) + + stmt = insert(t1).returning( + t1.c.data, + t1.c.uuids, + sort_by_parameter_order=bool(sort_by_parameter_order), + ) + data = [{"data": f"d{i}"} for i in range(10)] + + if testing.db.dialect.use_insertmanyvalues and sort_by_parameter_order: + with expect_raises_message( + exc.InvalidRequestError, + "Sentinel-keyed result set did not produce correct " + "number of rows 10; produced 8.", + ): + connection.execute(stmt, data) + else: + result = connection.execute(stmt, data) + eq_( + set(result.all()), + {(f"d{i}", uuids[i]) for i in range(10)}, + ) + + @testing.variation("resolve_sentinel_values", [True, False]) + def test_sentinel_cant_match_keys( + self, + metadata, + connection, + sort_by_parameter_order, + resolve_sentinel_values, + ): + """test assertions to ensure sentinel values passed in parameter + structures can be identified when they come back in cursor.fetchall(). + + Values that are further modified by the database driver or by + SQL expressions (as in the case below) before being INSERTed + won't match coming back out, so datatypes need to implement + _sentinel_value_resolver() if this is the case. + + """ + + class UnsymmetricDataType(TypeDecorator): + cache_ok = True + impl = String + + def bind_expression(self, bindparam): + return func.lower(bindparam) + + if resolve_sentinel_values: + + def _sentinel_value_resolver(self, dialect): + def fix_sentinels(value): + return value.lower() + + return fix_sentinels + + t1 = Table( + "data", + metadata, + Column("id", Integer, Identity(), primary_key=True), + Column("data", String(50)), + insert_sentinel("unsym", UnsymmetricDataType(10)), + ) + + metadata.create_all(connection) + + stmt = insert(t1).returning( + t1.c.data, + t1.c.unsym, + sort_by_parameter_order=bool(sort_by_parameter_order), + ) + data = [{"data": f"d{i}", "unsym": f"UPPER_d{i}"} for i in range(10)] + + if ( + testing.db.dialect.use_insertmanyvalues + and sort_by_parameter_order + and not resolve_sentinel_values + ): + with expect_raises_message( + exc.InvalidRequestError, + r"Can't match sentinel values in result set to parameter " + r"sets; key 'UPPER_d.' was not found.", + ): + connection.execute(stmt, data) + else: + result = connection.execute(stmt, data) + eq_( + set(result.all()), + {(f"d{i}", f"upper_d{i}") for i in range(10)}, + ) + + @testing.variation("add_insert_sentinel", [True, False]) + def test_sentinel_insert_default_pk_only( + self, + metadata, + connection, + sort_by_parameter_order, + add_insert_sentinel, + ): + t1 = Table( + "data", + metadata, + Column( + "id", + Integer, + Identity(), + insert_sentinel=bool(add_insert_sentinel), + primary_key=True, + ), + Column("data", String(50)), + ) + + metadata.create_all(connection) + + fixtures.insertmanyvalues_fixture( + connection, randomize_rows=True, warn_on_downgraded=False + ) + + stmt = insert(t1).returning( + t1.c.id, + sort_by_parameter_order=bool(sort_by_parameter_order), + ) + data = [{} for i in range(3)] + + if ( + testing.db.dialect.use_insertmanyvalues + and add_insert_sentinel + and sort_by_parameter_order + and not ( + testing.db.dialect.insertmanyvalues_implicit_sentinel + & InsertmanyvaluesSentinelOpts.ANY_AUTOINCREMENT + ) + ): + with expect_raises_message( + exc.InvalidRequestError, + "Column data.id can't be explicitly marked as a " + f"sentinel column when using the {testing.db.dialect.name} " + "dialect", + ): + connection.execute(stmt, data) + return + else: + result = connection.execute(stmt, data) + + if sort_by_parameter_order: + # if we used a client side default function, or we had no sentinel + # at all, we're sorted + coll = list + else: + # otherwise we are not, we randomized the order in any case + coll = set + + eq_( + coll(result), + coll( + [ + (1,), + (2,), + (3,), + ] + ), + ) + + @testing.only_on("postgresql>=13") + @testing.variation("default_type", ["server_side", "client_side"]) + @testing.variation("add_insert_sentinel", [True, False]) + def test_no_sentinel_on_non_int_ss_function( + self, + metadata, + connection, + add_insert_sentinel, + default_type, + sort_by_parameter_order, + ): + + t1 = Table( + "data", + metadata, + Column( + "id", + Uuid(), + server_default=func.gen_random_uuid() + if default_type.server_side + else None, + default=uuid.uuid4 if default_type.client_side else None, + primary_key=True, + insert_sentinel=bool(add_insert_sentinel), + ), + Column("data", String(50)), + ) + + metadata.create_all(connection) + + fixtures.insertmanyvalues_fixture( + connection, randomize_rows=True, warn_on_downgraded=False + ) + + stmt = insert(t1).returning( + t1.c.data, + sort_by_parameter_order=bool(sort_by_parameter_order), + ) + data = [ + {"data": "d1"}, + {"data": "d2"}, + {"data": "d3"}, + ] + + if ( + default_type.server_side + and add_insert_sentinel + and sort_by_parameter_order + ): + with expect_raises_message( + exc.InvalidRequestError, + r"Column data.id can't be a sentinel column because it uses " + r"an explicit server side default that's not the Identity\(\)", + ): + connection.execute(stmt, data) + return + else: + result = connection.execute(stmt, data) + + if sort_by_parameter_order: + # if we used a client side default function, or we had no sentinel + # at all, we're sorted + coll = list + else: + # otherwise we are not, we randomized the order in any case + coll = set + + eq_( + coll(result), + coll( + [ + ("d1",), + ("d2",), + ("d3",), + ] + ), + ) + + @testing.variation( + "pk_type", + [ + ("plain_autoinc", testing.requires.autoincrement_without_sequence), + ("sequence", testing.requires.sequences), + ("identity", testing.requires.identity_columns), + ], + ) + @testing.variation( + "sentinel", + [ + "none", # passes because we automatically downgrade + # for no sentinel col + "implicit_not_omitted", + "implicit_omitted", + "explicit", + "explicit_but_nullable", + "default_uuid", + "default_string_uuid", + ("identity", testing.requires.multiple_identity_columns), + ("sequence", testing.requires.sequences), + ], + ) + def test_sentinel_col_configurations( + self, + pk_type: testing.Variation, + sentinel: testing.Variation, + sort_by_parameter_order, + randomize_returning, + metadata, + connection, + ): + + if pk_type.plain_autoinc: + pk_col = Column("id", Integer, primary_key=True) + elif pk_type.sequence: + pk_col = Column( + "id", + Integer, + Sequence("result_id_seq", start=1), + primary_key=True, + ) + elif pk_type.identity: + pk_col = Column("id", Integer, Identity(), primary_key=True) + else: + pk_type.fail() + + if sentinel.implicit_not_omitted or sentinel.implicit_omitted: + _sentinel = insert_sentinel( + "sentinel", + omit_from_statements=bool(sentinel.implicit_omitted), + ) + elif sentinel.explicit: + _sentinel = Column( + "some_uuid", Uuid(), nullable=False, insert_sentinel=True + ) + elif sentinel.explicit_but_nullable: + _sentinel = Column("some_uuid", Uuid(), insert_sentinel=True) + elif sentinel.default_uuid or sentinel.default_string_uuid: + _sentinel = Column( + "some_uuid", + Uuid(native_uuid=bool(sentinel.default_uuid)), + insert_sentinel=True, + default=uuid.uuid4, + ) + elif sentinel.identity: + _sentinel = Column( + "some_identity", + Integer, + Identity(), + insert_sentinel=True, + ) + elif sentinel.sequence: + _sentinel = Column( + "some_identity", + Integer, + Sequence("some_id_seq", start=1), + insert_sentinel=True, + ) + else: + _sentinel = Column("some_uuid", Uuid()) + + t = Table("t", metadata, pk_col, Column("data", String(50)), _sentinel) + + metadata.create_all(connection) + + fixtures.insertmanyvalues_fixture( + connection, + randomize_rows=bool(randomize_returning), + warn_on_downgraded=True, + ) + + stmt = insert(t).returning( + pk_col, + t.c.data, + sort_by_parameter_order=bool(sort_by_parameter_order), + ) + if sentinel.explicit: + data = [ + {"data": f"d{i}", "some_uuid": uuid.uuid4()} + for i in range(150) + ] + else: + data = [{"data": f"d{i}"} for i in range(150)] + + expect_sentinel_use = ( + sort_by_parameter_order + and testing.db.dialect.insert_returning + and testing.db.dialect.use_insertmanyvalues + ) + + if sentinel.explicit_but_nullable and expect_sentinel_use: + with expect_raises_message( + exc.InvalidRequestError, + "Column t.some_uuid has been marked as a sentinel column " + "with no default generation function; it at least needs to " + "be marked nullable=False", + ): + connection.execute(stmt, data) + return + + elif ( + expect_sentinel_use + and sentinel.sequence + and not ( + testing.db.dialect.insertmanyvalues_implicit_sentinel + & InsertmanyvaluesSentinelOpts.SEQUENCE + ) + ): + with expect_raises_message( + exc.InvalidRequestError, + "Column t.some_identity can't be explicitly marked as a " + f"sentinel column when using the {testing.db.dialect.name} " + "dialect", + ): + connection.execute(stmt, data) + return + + elif ( + sentinel.none + and expect_sentinel_use + and stmt.compile( + dialect=testing.db.dialect + )._get_sentinel_column_for_table(t) + is None + ): + with expect_warnings( + "Batches were downgraded for sorted INSERT", + raise_on_any_unexpected=True, + ): + result = connection.execute(stmt, data) + else: + result = connection.execute(stmt, data) + + if sort_by_parameter_order: + eq_(list(result), [(i + 1, f"d{i}") for i in range(150)]) + else: + eq_(set(result), {(i + 1, f"d{i}") for i in range(150)}) + + @testing.variation( + "return_type", ["include_sentinel", "default_only", "return_defaults"] + ) + @testing.variation("add_sentinel_flag_to_col", [True, False]) + def test_sentinel_on_non_autoinc_primary_key( + self, + metadata, + connection, + return_type: testing.Variation, + sort_by_parameter_order, + randomize_returning, + add_sentinel_flag_to_col, + ): + uuids = [uuid.uuid4() for i in range(10)] + _some_uuids = iter(uuids) + + t1 = Table( + "data", + metadata, + Column( + "id", + Uuid(), + default=functools.partial(next, _some_uuids), + primary_key=True, + insert_sentinel=bool(add_sentinel_flag_to_col), + ), + Column("data", String(50)), + Column( + "has_server_default", + String(30), + server_default="some_server_default", + ), + ) + + fixtures.insertmanyvalues_fixture( + connection, + randomize_rows=bool(randomize_returning), + warn_on_downgraded=True, + ) + + if sort_by_parameter_order: + collection_cls = list + else: + collection_cls = set + + metadata.create_all(connection) + + if sort_by_parameter_order: + kw = {"sort_by_parameter_order": True} + else: + kw = {} + + if return_type.include_sentinel: + stmt = t1.insert().returning( + t1.c.id, t1.c.data, t1.c.has_server_default, **kw + ) + elif return_type.default_only: + stmt = t1.insert().returning( + t1.c.data, t1.c.has_server_default, **kw + ) + elif return_type.return_defaults: + stmt = t1.insert().return_defaults(**kw) + + else: + return_type.fail() + + r = connection.execute( + stmt, + [{"data": f"d{i}"} for i in range(1, 6)], + ) + + if return_type.include_sentinel: + eq_(r.keys(), ["id", "data", "has_server_default"]) + eq_( + collection_cls(r), + collection_cls( + [ + (uuids[i], f"d{i+1}", "some_server_default") + for i in range(5) + ] + ), + ) + elif return_type.default_only: + eq_(r.keys(), ["data", "has_server_default"]) + eq_( + collection_cls(r), + collection_cls( + [ + ( + f"d{i+1}", + "some_server_default", + ) + for i in range(5) + ] + ), + ) + elif return_type.return_defaults: + eq_(r.keys(), ["has_server_default"]) + eq_(r.inserted_primary_key_rows, [(uuids[i],) for i in range(5)]) + eq_( + r.returned_defaults_rows, + [ + ("some_server_default",), + ("some_server_default",), + ("some_server_default",), + ("some_server_default",), + ("some_server_default",), + ], + ) + eq_(r.all(), []) + else: + return_type.fail() + + def test_client_composite_pk( + self, + metadata, + connection, + randomize_returning, + sort_by_parameter_order, + warn_for_downgrades, + ): + uuids = [uuid.uuid4() for i in range(10)] + + t1 = Table( + "data", + metadata, + Column( + "id1", + Uuid(), + default=functools.partial(next, iter(uuids)), + primary_key=True, + ), + Column( + "id2", + # note this is testing that plain populated PK cols + # also qualify as sentinels since they have to be there + String(30), + primary_key=True, + ), + Column("data", String(50)), + Column( + "has_server_default", + String(30), + server_default="some_server_default", + ), + ) + metadata.create_all(connection) + + fixtures.insertmanyvalues_fixture( + connection, + randomize_rows=bool(randomize_returning), + warn_on_downgraded=bool(warn_for_downgrades), + ) + + result = connection.execute( + insert(t1).returning( + t1.c.id1, + t1.c.id2, + t1.c.data, + t1.c.has_server_default, + sort_by_parameter_order=bool(sort_by_parameter_order), + ), + [{"id2": f"id{i}", "data": f"d{i}"} for i in range(10)], + ) + + if sort_by_parameter_order: + coll = list + else: + coll = set + + eq_( + coll(result), + coll( + [ + (uuids[i], f"id{i}", f"d{i}", "some_server_default") + for i in range(10) + ] + ), + ) + + @testing.variation("add_sentinel", [True, False]) + @testing.variation( + "set_identity", [(True, testing.requires.identity_columns), False] + ) + def test_no_pk( + self, + metadata, + connection, + randomize_returning, + sort_by_parameter_order, + warn_for_downgrades, + add_sentinel, + set_identity, + ): + if set_identity: + id_col = Column("id", Integer(), Identity()) + else: + id_col = Column("id", Integer()) + + uuids = [uuid.uuid4() for i in range(10)] + + sentinel_col = Column( + "unique_id", + Uuid, + default=functools.partial(next, iter(uuids)), + insert_sentinel=bool(add_sentinel), + ) + t1 = Table( + "nopk", + metadata, + id_col, + Column("data", String(50)), + sentinel_col, + Column( + "has_server_default", + String(30), + server_default="some_server_default", + ), + ) + metadata.create_all(connection) + + fixtures.insertmanyvalues_fixture( + connection, + randomize_rows=bool(randomize_returning), + warn_on_downgraded=bool(warn_for_downgrades), + ) + + stmt = insert(t1).returning( + t1.c.id, + t1.c.data, + t1.c.has_server_default, + sort_by_parameter_order=bool(sort_by_parameter_order), + ) + if not set_identity: + data = [{"id": i + 1, "data": f"d{i}"} for i in range(10)] + else: + data = [{"data": f"d{i}"} for i in range(10)] + + with self._expect_downgrade_warnings( + warn_for_downgrades=warn_for_downgrades, + sort_by_parameter_order=sort_by_parameter_order, + separate_sentinel=add_sentinel, + ): + result = connection.execute(stmt, data) + + if sort_by_parameter_order: + coll = list + else: + coll = set + + eq_( + coll(result), + coll([(i + 1, f"d{i}", "some_server_default") for i in range(10)]), + ) + + @testing.variation("add_sentinel_to_col", [True, False]) + @testing.variation( + "set_autoincrement", [True, (False, testing.skip_if("mariadb"))] + ) + def test_hybrid_client_composite_pk( + self, + metadata, + connection, + randomize_returning, + sort_by_parameter_order, + warn_for_downgrades, + add_sentinel_to_col, + set_autoincrement, + ): + """test a pk that is part server generated part client generated. + + The server generated col by itself can be the sentinel. if it's + part of the PK and is autoincrement=True then it is automatically + used as such. if not, there's a graceful downgrade. + + """ + + t1 = Table( + "data", + metadata, + Column( + "idint", + Integer, + Identity(), + autoincrement=True if set_autoincrement else "auto", + primary_key=True, + insert_sentinel=bool(add_sentinel_to_col), + ), + Column( + "idstr", + String(30), + primary_key=True, + ), + Column("data", String(50)), + Column( + "has_server_default", + String(30), + server_default="some_server_default", + ), + ) + + no_autoincrement = ( + not testing.requires.supports_autoincrement_w_composite_pk.enabled # noqa: E501 + ) + if set_autoincrement and no_autoincrement: + with expect_raises_message( + exc.CompileError, + r".*SQLite does not support autoincrement for " + "composite primary keys", + ): + metadata.create_all(connection) + return + else: + + metadata.create_all(connection) + + fixtures.insertmanyvalues_fixture( + connection, + randomize_rows=bool(randomize_returning), + warn_on_downgraded=bool(warn_for_downgrades), + ) + + stmt = insert(t1).returning( + t1.c.idint, + t1.c.idstr, + t1.c.data, + t1.c.has_server_default, + sort_by_parameter_order=bool(sort_by_parameter_order), + ) + + if no_autoincrement: + data = [ + {"idint": i + 1, "idstr": f"id{i}", "data": f"d{i}"} + for i in range(10) + ] + else: + data = [{"idstr": f"id{i}", "data": f"d{i}"} for i in range(10)] + + if ( + testing.db.dialect.use_insertmanyvalues + and add_sentinel_to_col + and sort_by_parameter_order + and not ( + testing.db.dialect.insertmanyvalues_implicit_sentinel + & InsertmanyvaluesSentinelOpts.ANY_AUTOINCREMENT + ) + ): + with expect_raises_message( + exc.InvalidRequestError, + "Column data.idint can't be explicitly marked as a sentinel " + "column when using the sqlite dialect", + ): + result = connection.execute(stmt, data) + return + + with self._expect_downgrade_warnings( + warn_for_downgrades=warn_for_downgrades, + sort_by_parameter_order=sort_by_parameter_order, + separate_sentinel=not set_autoincrement and add_sentinel_to_col, + server_autoincrement=set_autoincrement, + ): + result = connection.execute(stmt, data) + + if sort_by_parameter_order: + coll = list + else: + coll = set + + eq_( + coll(result), + coll( + [ + (i + 1, f"id{i}", f"d{i}", "some_server_default") + for i in range(10) + ] + ), + ) + + @testing.variation("composite_pk", [True, False]) + @testing.only_on( + [ + "+psycopg", + "+psycopg2", + "+pysqlite", + "+mysqlclient", + "+cx_oracle", + "+oracledb", + ] + ) + def test_failure_mode_if_i_dont_send_value( + self, metadata, connection, sort_by_parameter_order, composite_pk + ): + """test that we get a regular integrity error if a required + PK value was not sent, that is, imv does not get in the way + + """ + t1 = Table( + "data", + metadata, + Column("id", String(30), primary_key=True), + Column("data", String(50)), + Column( + "has_server_default", + String(30), + server_default="some_server_default", + ), + ) + if composite_pk: + t1.append_column(Column("uid", Uuid(), default=uuid.uuid4)) + + metadata.create_all(connection) + + with expect_warnings( + r".*but has no Python-side or server-side default ", + raise_on_any_unexpected=True, + ): + with expect_raises(exc.IntegrityError): + connection.execute( + insert(t1).returning( + t1.c.id, + t1.c.data, + t1.c.has_server_default, + sort_by_parameter_order=bool(sort_by_parameter_order), + ), + [{"data": f"d{i}"} for i in range(10)], + ) + + @testing.variation("add_sentinel_flag_to_col", [True, False]) + @testing.variation( + "return_type", ["include_sentinel", "default_only", "return_defaults"] + ) + @testing.variation( + "sentinel_type", + [ + ("autoincrement", testing.requires.autoincrement_without_sequence), + "identity", + "sequence", + ], + ) + def test_implicit_autoincrement_sentinel( + self, + metadata, + connection, + return_type: testing.Variation, + sort_by_parameter_order, + randomize_returning, + sentinel_type, + add_sentinel_flag_to_col, + ): + + if sentinel_type.identity: + sentinel_args = [Identity()] + elif sentinel_type.sequence: + sentinel_args = [Sequence("id_seq", start=1)] + else: + sentinel_args = [] + t1 = Table( + "data", + metadata, + Column( + "id", + Integer, + *sentinel_args, + primary_key=True, + insert_sentinel=bool(add_sentinel_flag_to_col), + ), + Column("data", String(50)), + Column( + "has_server_default", + String(30), + server_default="some_server_default", + ), + ) + + fixtures.insertmanyvalues_fixture( + connection, + randomize_rows=bool(randomize_returning), + warn_on_downgraded=False, + ) + + if sort_by_parameter_order: + collection_cls = list + else: + collection_cls = set + + metadata.create_all(connection) + + if sort_by_parameter_order: + kw = {"sort_by_parameter_order": True} + else: + kw = {} + + if return_type.include_sentinel: + stmt = t1.insert().returning( + t1.c.id, t1.c.data, t1.c.has_server_default, **kw + ) + elif return_type.default_only: + stmt = t1.insert().returning( + t1.c.data, t1.c.has_server_default, **kw + ) + elif return_type.return_defaults: + stmt = t1.insert().return_defaults(**kw) + + else: + return_type.fail() + + if ( + testing.db.dialect.use_insertmanyvalues + and add_sentinel_flag_to_col + and sort_by_parameter_order + and ( + not ( + testing.db.dialect.insertmanyvalues_implicit_sentinel + & InsertmanyvaluesSentinelOpts.ANY_AUTOINCREMENT + ) + or ( + # currently a SQL Server case, we dont yet render a + # syntax for SQL Server sequence w/ deterministic + # ordering. The INSERT..SELECT could be restructured + # further to support this at a later time however + # sequences with SQL Server are very unusual. + sentinel_type.sequence + and not ( + testing.db.dialect.insertmanyvalues_implicit_sentinel + & InsertmanyvaluesSentinelOpts.SEQUENCE + ) + ) + ) + ): + with expect_raises_message( + exc.InvalidRequestError, + "Column data.id can't be explicitly marked as a " + f"sentinel column when using the {testing.db.dialect.name} " + "dialect", + ): + connection.execute( + stmt, + [{"data": f"d{i}"} for i in range(1, 6)], + ) + return + else: + r = connection.execute( + stmt, + [{"data": f"d{i}"} for i in range(1, 6)], + ) + + if return_type.include_sentinel: + eq_(r.keys(), ["id", "data", "has_server_default"]) + eq_( + collection_cls(r), + collection_cls( + [(i, f"d{i}", "some_server_default") for i in range(1, 6)] + ), + ) + elif return_type.default_only: + eq_(r.keys(), ["data", "has_server_default"]) + eq_( + collection_cls(r), + collection_cls( + [(f"d{i}", "some_server_default") for i in range(1, 6)] + ), + ) + elif return_type.return_defaults: + eq_(r.keys(), ["id", "has_server_default"]) + eq_( + collection_cls(r.inserted_primary_key_rows), + collection_cls([(i + 1,) for i in range(5)]), + ) + eq_( + collection_cls(r.returned_defaults_rows), + collection_cls( + [ + ( + 1, + "some_server_default", + ), + ( + 2, + "some_server_default", + ), + ( + 3, + "some_server_default", + ), + ( + 4, + "some_server_default", + ), + ( + 5, + "some_server_default", + ), + ] + ), + ) + eq_(r.all(), []) + else: + return_type.fail() + + @testing.variation("pk_type", ["serverside", "clientside"]) + @testing.variation( + "sentinel_type", + [ + "use_pk", + ("use_pk_explicit", testing.skip_if("sqlite")), + "separate_uuid", + "separate_sentinel", + ], + ) + @testing.requires.provisioned_upsert + def test_upsert_downgrades( + self, + metadata, + connection, + pk_type: testing.Variation, + sort_by_parameter_order, + randomize_returning, + sentinel_type, + warn_for_downgrades, + ): + if pk_type.serverside: + pk_col = Column( + "id", + Integer(), + primary_key=True, + insert_sentinel=bool(sentinel_type.use_pk_explicit), + ) + elif pk_type.clientside: + pk_col = Column( + "id", + Uuid(), + default=uuid.uuid4, + primary_key=True, + insert_sentinel=bool(sentinel_type.use_pk_explicit), + ) + else: + pk_type.fail() + + if sentinel_type.separate_uuid: + extra_col = Column( + "sent_col", + Uuid(), + default=uuid.uuid4, + insert_sentinel=True, + nullable=False, + ) + elif sentinel_type.separate_sentinel: + extra_col = insert_sentinel("sent_col") + else: + extra_col = Column("sent_col", Integer) + + t1 = Table( + "upsert_table", + metadata, + pk_col, + Column("data", String(50)), + extra_col, + Column( + "has_server_default", + String(30), + server_default="some_server_default", + ), + ) + metadata.create_all(connection) + + result = connection.execute( + insert(t1).returning( + t1.c.id, t1.c.data, sort_by_parameter_order=True + ), + [{"data": "d1"}, {"data": "d2"}], + ) + d1d2 = list(result) + + if pk_type.serverside: + new_ids = [10, 15, 3] + elif pk_type.clientside: + new_ids = [uuid.uuid4() for i in range(3)] + else: + pk_type.fail() + + upsert_data = [ + {"id": d1d2[0][0], "data": "d1 new"}, + {"id": new_ids[0], "data": "d10"}, + {"id": new_ids[1], "data": "d15"}, + {"id": d1d2[1][0], "data": "d2 new"}, + {"id": new_ids[2], "data": "d3"}, + ] + + fixtures.insertmanyvalues_fixture( + connection, + randomize_rows=bool(randomize_returning), + warn_on_downgraded=bool(warn_for_downgrades), + ) + + stmt = provision.upsert( + config, + t1, + (t1.c.data, t1.c.has_server_default), + set_lambda=lambda inserted: { + "data": inserted.data + " upserted", + }, + sort_by_parameter_order=bool(sort_by_parameter_order), + ) + + with self._expect_downgrade_warnings( + warn_for_downgrades=warn_for_downgrades, + sort_by_parameter_order=sort_by_parameter_order, + ): + result = connection.execute(stmt, upsert_data) + + expected_data = [ + ("d1 new upserted", "some_server_default"), + ("d10", "some_server_default"), + ("d15", "some_server_default"), + ("d2 new upserted", "some_server_default"), + ("d3", "some_server_default"), + ] + if sort_by_parameter_order: + coll = list + else: + coll = set + + eq_(coll(result), coll(expected_data)) + + def test_auto_downgraded_non_mvi_dialect( + self, + metadata, + testing_engine, + randomize_returning, + warn_for_downgrades, + sort_by_parameter_order, + ): + """Accommodate the case of the dialect that supports RETURNING, but + does not support "multi values INSERT" syntax. + + These dialects should still provide insertmanyvalues/returning + support, using downgraded batching. + + For now, we are still keeping this entire thing "opt in" by requiring + that use_insertmanyvalues=True, which means we can't simplify the + ORM by not worrying about dialects where ordering is available or + not. + + However, dialects that use RETURNING, but don't support INSERT VALUES + (..., ..., ...) can set themselves up like this:: + + class MyDialect(DefaultDialect): + use_insertmanyvalues = True + supports_multivalues_insert = False + + This test runs for everyone **including** Oracle, where we + exercise Oracle using "insertmanyvalues" without "multivalues_insert". + + """ + engine = testing_engine() + engine.connect().close() + + engine.dialect.supports_multivalues_insert = False + engine.dialect.use_insertmanyvalues = True + + uuids = [uuid.uuid4() for i in range(10)] + + t1 = Table( + "t1", + metadata, + Column("id", Uuid(), default=functools.partial(next, iter(uuids))), + Column("data", String(50)), + ) + metadata.create_all(engine) + + with engine.connect() as conn: + + fixtures.insertmanyvalues_fixture( + conn, + randomize_rows=bool(randomize_returning), + warn_on_downgraded=bool(warn_for_downgrades), + ) + + stmt = insert(t1).returning( + t1.c.id, + t1.c.data, + sort_by_parameter_order=bool(sort_by_parameter_order), + ) + data = [{"data": f"d{i}"} for i in range(10)] + + with self._expect_downgrade_warnings( + warn_for_downgrades=warn_for_downgrades, + sort_by_parameter_order=True, # will warn even if not sorted + connection=conn, + ): + result = conn.execute(stmt, data) + + expected_data = [(uuids[i], f"d{i}") for i in range(10)] + if sort_by_parameter_order: + coll = list + else: + coll = set + + eq_(coll(result), coll(expected_data)) diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py index 8f6c81f15..a8c6bdbbe 100644 --- a/test/sql/test_metadata.py +++ b/test/sql/test_metadata.py @@ -21,6 +21,7 @@ from sqlalchemy import ForeignKeyConstraint from sqlalchemy import func from sqlalchemy import Identity from sqlalchemy import Index +from sqlalchemy import insert_sentinel from sqlalchemy import Integer from sqlalchemy import MetaData from sqlalchemy import Numeric @@ -48,6 +49,7 @@ from sqlalchemy.sql import naming from sqlalchemy.sql import operators from sqlalchemy.sql.base import _NONE_NAME from sqlalchemy.sql.elements import literal_column +from sqlalchemy.sql.schema import _InsertSentinelColumnDefault from sqlalchemy.sql.schema import RETAIN_SCHEMA from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message @@ -621,7 +623,7 @@ class MetaDataTest(fixtures.TestBase, ComparesTables): c.add_is_dependent_on(a) eq_(meta.sorted_tables, [d, b, a, c, e]) - def test_deterministic_order(self): + def test_sort_by_parameter_order(self): meta = MetaData() a = Table("a", meta, Column("foo", Integer)) b = Table("b", meta, Column("foo", Integer)) @@ -633,7 +635,7 @@ class MetaDataTest(fixtures.TestBase, ComparesTables): a.add_is_dependent_on(b) eq_(meta.sorted_tables, [b, c, d, a, e]) - def test_fks_deterministic_order(self): + def test_fks_sort_by_parameter_order(self): meta = MetaData() a = Table("a", meta, Column("foo", Integer, ForeignKey("b.foo"))) b = Table("b", meta, Column("foo", Integer)) @@ -6079,3 +6081,52 @@ class CopyDialectOptionsTest(fixtures.TestBase): m2 = MetaData() t2 = t1.to_metadata(m2) # make a copy self.check_dialect_options_(t2) + + +class SentinelColTest(fixtures.TestBase): + def make_table_w_sentinel_col(self, *arg, **kw): + return Table( + "t", + MetaData(), + Column("id", Integer, primary_key=True), + Column(*arg, **kw), + ) + + def test_only_one_sentinel(self): + with expect_raises_message( + exc.ArgumentError, + "a Table may have only one explicit sentinel column", + ): + Table( + "t", + MetaData(), + Column("id", Integer, primary_key=True, insert_sentinel=True), + Column("ASdf", String(50)), + insert_sentinel("sentinel"), + ) + + def test_no_sentinel_default_on_notnull(self): + with expect_raises_message( + exc.ArgumentError, + "The _InsertSentinelColumnDefault may only be applied to a " + "Column that is nullable", + ): + self.make_table_w_sentinel_col( + "sentinel", + Integer, + nullable=False, + insert_sentinel=True, + default=_InsertSentinelColumnDefault(), + ) + + def test_no_sentinel_default_on_non_sentinel(self): + with expect_raises_message( + exc.ArgumentError, + "The _InsertSentinelColumnDefault may only be applied to a " + "Column marked as insert_sentinel=True", + ): + self.make_table_w_sentinel_col( + "sentinel", + Integer, + default=_InsertSentinelColumnDefault(), + ) diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py index e0299e334..7d40fa76f 100644 --- a/test/sql/test_returning.py +++ b/test/sql/test_returning.py @@ -15,6 +15,7 @@ from sqlalchemy import table from sqlalchemy import testing from sqlalchemy import type_coerce from sqlalchemy import update +from sqlalchemy.sql import crud from sqlalchemy.sql.sqltypes import NullType from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL @@ -24,6 +25,8 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ +from sqlalchemy.testing import is_false +from sqlalchemy.testing import is_true from sqlalchemy.testing import mock from sqlalchemy.testing import provision from sqlalchemy.testing.schema import Column @@ -85,6 +88,31 @@ class ReturnCombinationTests(fixtures.TestBase, AssertsCompiledSQL): stmt.compile, ) + @testing.combinations("return_defaults", "returning", argnames="methname") + @testing.combinations(insert, update, delete, argnames="construct") + def test_sort_by_parameter_ordering_param( + self, methname, construct, table_fixture + ): + t = table_fixture + + stmt = construct(t) + + if construct is insert: + is_false(stmt._sort_by_parameter_order) + + meth = getattr(stmt, methname) + + if construct in (update, delete): + with expect_raises_message( + sa_exc.ArgumentError, + rf"The 'sort_by_parameter_order' argument to " + rf"{methname}\(\) only applies to INSERT statements", + ): + meth(t.c.id, sort_by_parameter_order=True) + else: + new = meth(t.c.id, sort_by_parameter_order=True) + is_true(new._sort_by_parameter_order) + def test_return_defaults_no_returning(self, table_fixture): t = table_fixture @@ -1347,15 +1375,37 @@ class InsertManyReturningTest(fixtures.TablesTest): t1 = self.tables.type_cases + grm = crud._get_returning_modifiers + + def _grm(*arg, **kw): + ( + need_pks, + implicit_returning, + implicit_return_defaults, + postfetch_lastrowid, + _, + _, + ) = grm(*arg, **kw) + + return ( + need_pks, + implicit_returning, + implicit_return_defaults, + postfetch_lastrowid, + False, + None, + ) + with mock.patch.object( - testing.db.dialect.statement_compiler, - "_insert_stmt_should_use_insertmanyvalues", - lambda *arg: False, + crud, + "_get_returning_modifiers", + new=_grm, ): with expect_raises_message( sa_exc.StatementError, r'Statement does not have "insertmanyvalues" enabled, ' - r"can\'t use INSERT..RETURNING with executemany in this case.", + r"can\'t use INSERT..RETURNING with executemany in this " + "case.", ): connection.execute( t1.insert().returning(t1.c.id, t1.c.goofy, t1.c.full), @@ -1446,7 +1496,7 @@ class InsertManyReturningTest(fixtures.TablesTest): config, t1, (t1.c.id, t1.c.insdef, t1.c.data), - (lambda excluded: {"data": excluded.data + " excluded"}) + set_lambda=(lambda excluded: {"data": excluded.data + " excluded"}) if update_cols else None, ) |
