diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-07-18 15:08:37 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-09-24 11:15:32 -0400 |
| commit | 2bcc97da424eef7db9a5d02f81d02344925415ee (patch) | |
| tree | 13d4f04bc7dd40a0207f86aa2fc3a3b49e065674 /test/sql | |
| parent | 332188e5680574368001ded52eb0a9d259ecdef5 (diff) | |
| download | sqlalchemy-2bcc97da424eef7db9a5d02f81d02344925415ee.tar.gz | |
implement batched INSERT..VALUES () () for executemany
The feature is enabled for all built-in backends
when RETURNING is used,
except for Oracle, which doesn't need it; on
psycopg2 and mssql+pyodbc it is used for all INSERT statements,
not just those that use RETURNING.
Third-party dialects would need to opt in to the new feature
by setting use_insertmanyvalues to True.
Also adds dialect-level guards against using RETURNING
with executemany where we don't have an implementation to
suit it. Executing a single statement with RETURNING still
defers to the server without us checking.
Fixes: #6047
Fixes: #7907
Change-Id: I3936d3c00003f02e322f2e43fb949d0e6e568304
Diffstat (limited to 'test/sql')
| -rw-r--r-- | test/sql/test_insert.py | 59 | ||||
| -rw-r--r-- | test/sql/test_insert_exec.py | 239 | ||||
| -rw-r--r-- | test/sql/test_returning.py | 473 |
3 files changed, 732 insertions, 39 deletions
diff --git a/test/sql/test_insert.py b/test/sql/test_insert.py index 61e0783e4..23a850f08 100644 --- a/test/sql/test_insert.py +++ b/test/sql/test_insert.py @@ -469,6 +469,65 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): dialect=postgresql.dialect(), ) + def test_heterogeneous_multi_values(self): + """for #6047, originally I thought we'd take any insert().values() + and be able to convert it to a "many" style execution that we can + cache. + + however, this test shows that we cannot, at least not in the + general case, because SQL expressions are not guaranteed to be in + the same position each time, therefore each ``VALUES`` clause is not + of the same structure. + + """ + + m = MetaData() + + t1 = Table( + "t", + m, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + Column("z", Integer), + ) + + stmt = t1.insert().values( + [ + {"x": 1, "y": func.sum(1, 2), "z": 2}, + {"x": func.sum(1, 2), "y": 2, "z": 3}, + {"x": func.sum(1, 2), "y": 2, "z": func.foo(10)}, + ] + ) + + # SQL expressions in the params at arbitrary locations means + # we have to scan them at compile time, and the shape of the bound + # parameters is not predictable. so for #6047 where I originally + # thought all of values() could be rewritten, this makes it not + # really worth it. 
+ self.assert_compile( + stmt, + "INSERT INTO t (x, y, z) VALUES " + "(%(x_m0)s, sum(%(sum_1)s, %(sum_2)s), %(z_m0)s), " + "(sum(%(sum_3)s, %(sum_4)s), %(y_m1)s, %(z_m1)s), " + "(sum(%(sum_5)s, %(sum_6)s), %(y_m2)s, foo(%(foo_1)s))", + checkparams={ + "x_m0": 1, + "sum_1": 1, + "sum_2": 2, + "z_m0": 2, + "sum_3": 1, + "sum_4": 2, + "y_m1": 2, + "z_m1": 3, + "sum_5": 1, + "sum_6": 2, + "y_m2": 2, + "foo_1": 10, + }, + dialect=postgresql.dialect(), + ) + def test_insert_seq_pk_multi_values_seq_not_supported(self): m = MetaData() diff --git a/test/sql/test_insert_exec.py b/test/sql/test_insert_exec.py index b6945813e..4ce093156 100644 --- a/test/sql/test_insert_exec.py +++ b/test/sql/test_insert_exec.py @@ -1,4 +1,7 @@ +import itertools + from sqlalchemy import and_ +from sqlalchemy import event from sqlalchemy import exc from sqlalchemy import ForeignKey from sqlalchemy import func @@ -10,8 +13,10 @@ from sqlalchemy import sql from sqlalchemy import String from sqlalchemy import testing from sqlalchemy import VARCHAR +from sqlalchemy.engine import cursor as _cursor from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import mock @@ -712,3 +717,237 @@ class TableInsertTest(fixtures.TablesTest): table=t, parameters=dict(id=None, data="data", x=5), ) + + +class InsertManyValuesTest(fixtures.RemovesEvents, fixtures.TablesTest): + __backend__ = True + __requires__ = ("insertmanyvalues",) + + @classmethod + def define_tables(cls, metadata): + Table( + "data", + metadata, + Column("id", Integer, primary_key=True), + Column("x", String(50)), + Column("y", String(50)), + Column("z", Integer, server_default="5"), + ) + + Table( + "Unitéble2", + metadata, + Column("méil", Integer, primary_key=True), + Column("\u6e2c\u8a66", Integer), + ) + + def test_insert_unicode_keys(self, connection): + 
table = self.tables["Unitéble2"] + + stmt = table.insert().returning(table.c["méil"]) + + connection.execute( + stmt, + [ + {"méil": 1, "\u6e2c\u8a66": 1}, + {"méil": 2, "\u6e2c\u8a66": 2}, + {"méil": 3, "\u6e2c\u8a66": 3}, + ], + ) + + eq_(connection.execute(table.select()).all(), [(1, 1), (2, 2), (3, 3)]) + + def test_insert_returning_values(self, connection): + t = self.tables.data + + conn = connection + page_size = conn.dialect.insertmanyvalues_page_size or 100 + data = [ + {"x": "x%d" % i, "y": "y%d" % i} + for i in range(1, page_size * 2 + 27) + ] + result = conn.execute(t.insert().returning(t.c.x, t.c.y), data) + + eq_([tup[0] for tup in result.cursor.description], ["x", "y"]) + eq_(result.keys(), ["x", "y"]) + assert t.c.x in result.keys() + assert t.c.id not in result.keys() + assert not result._soft_closed + assert isinstance( + result.cursor_strategy, + _cursor.FullyBufferedCursorFetchStrategy, + ) + assert not result.closed + eq_(result.mappings().all(), data) + + assert result._soft_closed + # assert result.closed + assert result.cursor is None + + def test_insert_returning_preexecute_pk(self, metadata, connection): + counter = itertools.count(1) + + t = Table( + "t", + self.metadata, + Column( + "id", + Integer, + primary_key=True, + default=lambda: next(counter), + ), + Column("data", Integer), + ) + metadata.create_all(connection) + + result = connection.execute( + t.insert().return_defaults(), + [{"data": 1}, {"data": 2}, {"data": 3}], + ) + + eq_(result.inserted_primary_key_rows, [(1,), (2,), (3,)]) + + def test_insert_returning_defaults(self, connection): + t = self.tables.data + + conn = connection + + result = conn.execute(t.insert(), {"x": "x0", "y": "y0"}) + first_pk = result.inserted_primary_key[0] + + page_size = conn.dialect.insertmanyvalues_page_size or 100 + total_rows = page_size * 5 + 27 + data = [{"x": "x%d" % i, "y": "y%d" % i} for i in range(1, total_rows)] + result = conn.execute(t.insert().returning(t.c.id, t.c.z), data) + + eq_( 
+ result.all(), + [(pk, 5) for pk in range(1 + first_pk, total_rows + first_pk)], + ) + + def test_insert_return_pks_default_values(self, connection): + """test sending multiple, empty rows into an INSERT and getting primary + key values back. + + This has to use a format that indicates at least one DEFAULT in + multiple parameter sets, i.e. "INSERT INTO table (anycol) VALUES + (DEFAULT) (DEFAULT) (DEFAULT) ... RETURNING col" + + if the database doesnt support this (like SQLite, mssql), it + actually runs the statement that many times on the cursor. + This is much less efficient, but is still more efficient than + how it worked previously where we'd run the statement that many + times anyway. + + There's ways to make it work for those, such as on SQLite + we can use "INSERT INTO table (pk_col) VALUES (NULL) RETURNING pk_col", + but that assumes an autoincrement pk_col, not clear how this + could be produced generically. + + """ + t = self.tables.data + + conn = connection + + result = conn.execute(t.insert(), {"x": "x0", "y": "y0"}) + first_pk = result.inserted_primary_key[0] + + page_size = conn.dialect.insertmanyvalues_page_size or 100 + total_rows = page_size * 2 + 27 + data = [{} for i in range(1, total_rows)] + result = conn.execute(t.insert().returning(t.c.id), data) + + eq_( + result.all(), + [(pk,) for pk in range(1 + first_pk, total_rows + first_pk)], + ) + + @testing.combinations(None, 100, 329, argnames="batchsize") + @testing.combinations( + "engine", + "conn_execution_option", + "exec_execution_option", + "stmt_execution_option", + argnames="paramtype", + ) + def test_page_size_adjustment(self, testing_engine, batchsize, paramtype): + + t = self.tables.data + + if paramtype == "engine" and batchsize is not None: + e = testing_engine( + options={ + "insertmanyvalues_page_size": batchsize, + }, + ) + + # sqlite, since this is a new engine, re-create the table + if not testing.requires.independent_connections.enabled: + t.create(e, checkfirst=True) + 
else: + e = testing.db + + totalnum = 1275 + data = [{"x": "x%d" % i, "y": "y%d" % i} for i in range(1, totalnum)] + + insert_count = 0 + + with e.begin() as conn: + + @event.listens_for(conn, "before_cursor_execute") + def go(conn, cursor, statement, parameters, context, executemany): + nonlocal insert_count + if statement.startswith("INSERT"): + insert_count += 1 + + stmt = t.insert() + if batchsize is None or paramtype == "engine": + conn.execute(stmt.returning(t.c.id), data) + elif paramtype == "conn_execution_option": + conn = conn.execution_options( + insertmanyvalues_page_size=batchsize + ) + conn.execute(stmt.returning(t.c.id), data) + elif paramtype == "stmt_execution_option": + stmt = stmt.execution_options( + insertmanyvalues_page_size=batchsize + ) + conn.execute(stmt.returning(t.c.id), data) + elif paramtype == "exec_execution_option": + conn.execute( + stmt.returning(t.c.id), + data, + execution_options=dict( + insertmanyvalues_page_size=batchsize + ), + ) + else: + assert False + + assert_batchsize = batchsize or 1000 + eq_( + insert_count, + totalnum // assert_batchsize + + (1 if totalnum % assert_batchsize else 0), + ) + + def test_disabled(self, testing_engine): + + e = testing_engine( + options={"use_insertmanyvalues": False}, + share_pool=True, + transfer_staticpool=True, + ) + totalnum = 1275 + data = [{"x": "x%d" % i, "y": "y%d" % i} for i in range(1, totalnum)] + + t = self.tables.data + + with e.begin() as conn: + stmt = t.insert() + with expect_raises_message( + exc.StatementError, + "with current server capabilities does not support " + "INSERT..RETURNING when executemany", + ): + conn.execute(stmt.returning(t.c.id), data) diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py index c458e3262..f8cc32517 100644 --- a/test/sql/test_returning.py +++ b/test/sql/test_returning.py @@ -19,8 +19,12 @@ from sqlalchemy.sql.sqltypes import NullType from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import 
AssertsCompiledSQL from sqlalchemy.testing import AssertsExecutionResults +from sqlalchemy.testing import config from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures +from sqlalchemy.testing import mock +from sqlalchemy.testing import provision from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from sqlalchemy.types import TypeDecorator @@ -71,10 +75,12 @@ class ReturnCombinationTests(fixtures.TestBase, AssertsCompiledSQL): stmt = stmt.returning(t.c.x) + stmt = stmt.return_defaults() assert_raises_message( - sa_exc.InvalidRequestError, - "RETURNING is already configured on this statement", - stmt.return_defaults, + sa_exc.CompileError, + r"Can't compile statement that includes returning\(\) " + r"and return_defaults\(\) simultaneously", + stmt.compile, ) def test_return_defaults_no_returning(self, table_fixture): @@ -224,7 +230,7 @@ class InsertReturningTest(fixtures.TablesTest, AssertsExecutionResults): cls.GoofyType = GoofyType Table( - "tables", + "returning_tbl", metadata, Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -236,7 +242,7 @@ class InsertReturningTest(fixtures.TablesTest, AssertsExecutionResults): ) def test_column_targeting(self, connection): - table = self.tables.tables + table = self.tables.returning_tbl result = connection.execute( table.insert().returning(table.c.id, table.c.full), {"persons": 1, "full": False}, @@ -260,7 +266,7 @@ class InsertReturningTest(fixtures.TablesTest, AssertsExecutionResults): eq_(row["goofy"], "FOOsomegoofyBAR") def test_labeling(self, connection): - table = self.tables.tables + table = self.tables.returning_tbl result = connection.execute( table.insert() .values(persons=6) @@ -270,7 +276,7 @@ class InsertReturningTest(fixtures.TablesTest, AssertsExecutionResults): assert row["lala"] == 6 def test_anon_expressions(self, connection): - table = self.tables.tables + table = 
self.tables.returning_tbl GoofyType = self.GoofyType result = connection.execute( table.insert() @@ -286,27 +292,75 @@ class InsertReturningTest(fixtures.TablesTest, AssertsExecutionResults): row = result.first() eq_(row[0], 30) - @testing.fails_on( - "mssql", - "driver has unknown issue with string concatenation " - "in INSERT RETURNING", + @testing.combinations( + (lambda table: (table.c.strval + "hi",), ("str1hi",)), + ( + lambda table: ( + table.c.persons, + table.c.full, + table.c.strval + "hi", + ), + ( + 5, + False, + "str1hi", + ), + ), + ( + lambda table: ( + table.c.persons, + table.c.strval + "hi", + table.c.full, + ), + (5, "str1hi", False), + ), + ( + lambda table: ( + table.c.strval + "hi", + table.c.persons, + table.c.full, + ), + ("str1hi", 5, False), + ), + argnames="testcase, expected_row", ) - def test_insert_returning_w_expression_one(self, connection): - table = self.tables.tables + def test_insert_returning_w_expression( + self, connection, testcase, expected_row + ): + table = self.tables.returning_tbl + + exprs = testing.resolve_lambda(testcase, table=table) result = connection.execute( - table.insert().returning(table.c.strval + "hi"), + table.insert().returning(*exprs), {"persons": 5, "full": False, "strval": "str1"}, ) - eq_(result.fetchall(), [("str1hi",)]) + eq_(result.fetchall(), [expected_row]) result2 = connection.execute( select(table.c.id, table.c.strval).order_by(table.c.id) ) eq_(result2.fetchall(), [(1, "str1")]) + def test_insert_explicit_pk_col(self, connection): + table = self.tables.returning_tbl + result = connection.execute( + table.insert().returning(table.c.id, table.c.strval), + {"id": 1, "strval": "str1"}, + ) + + eq_( + result.fetchall(), + [ + ( + 1, + "str1", + ) + ], + ) + def test_insert_returning_w_type_coerce_expression(self, connection): - table = self.tables.tables + table = self.tables.returning_tbl result = connection.execute( table.insert().returning(type_coerce(table.c.goofy, String)), {"persons": 5, 
"goofy": "somegoofy"}, @@ -320,7 +374,7 @@ class InsertReturningTest(fixtures.TablesTest, AssertsExecutionResults): eq_(result2.fetchall(), [(1, "FOOsomegoofyBAR")]) def test_no_ipk_on_returning(self, connection, close_result_when_finished): - table = self.tables.tables + table = self.tables.returning_tbl result = connection.execute( table.insert().returning(table.c.id), {"persons": 1, "full": False} ) @@ -334,7 +388,7 @@ class InsertReturningTest(fixtures.TablesTest, AssertsExecutionResults): ) def test_insert_returning(self, connection): - table = self.tables.tables + table = self.tables.returning_tbl result = connection.execute( table.insert().returning(table.c.id), {"persons": 1, "full": False} ) @@ -342,8 +396,8 @@ class InsertReturningTest(fixtures.TablesTest, AssertsExecutionResults): eq_(result.fetchall(), [(1,)]) @testing.requires.multivalues_inserts - def test_multirow_returning(self, connection): - table = self.tables.tables + def test_multivalues_insert_returning(self, connection): + table = self.tables.returning_tbl ins = ( table.insert() .returning(table.c.id, table.c.persons) @@ -372,7 +426,7 @@ class InsertReturningTest(fixtures.TablesTest, AssertsExecutionResults): literal_true = "1" result4 = connection.exec_driver_sql( - "insert into tables (id, persons, %sfull%s) " + "insert into returning_tbl (id, persons, %sfull%s) " "values (5, 10, %s) returning persons" % (quote, quote, literal_true) ) @@ -388,7 +442,7 @@ class UpdateReturningTest(fixtures.TablesTest, AssertsExecutionResults): define_tables = InsertReturningTest.define_tables def test_update_returning(self, connection): - table = self.tables.tables + table = self.tables.returning_tbl connection.execute( table.insert(), [{"persons": 5, "full": False}, {"persons": 3, "full": False}], @@ -408,7 +462,7 @@ class UpdateReturningTest(fixtures.TablesTest, AssertsExecutionResults): eq_(result2.fetchall(), [(1, True), (2, False)]) def test_update_returning_w_expression_one(self, connection): - table = 
self.tables.tables + table = self.tables.returning_tbl connection.execute( table.insert(), [ @@ -431,7 +485,7 @@ class UpdateReturningTest(fixtures.TablesTest, AssertsExecutionResults): eq_(result2.fetchall(), [(1, "str1"), (2, "str2")]) def test_update_returning_w_type_coerce_expression(self, connection): - table = self.tables.tables + table = self.tables.returning_tbl connection.execute( table.insert(), [ @@ -457,7 +511,7 @@ class UpdateReturningTest(fixtures.TablesTest, AssertsExecutionResults): ) def test_update_full_returning(self, connection): - table = self.tables.tables + table = self.tables.returning_tbl connection.execute( table.insert(), [{"persons": 5, "full": False}, {"persons": 3, "full": False}], @@ -481,7 +535,7 @@ class DeleteReturningTest(fixtures.TablesTest, AssertsExecutionResults): define_tables = InsertReturningTest.define_tables def test_delete_returning(self, connection): - table = self.tables.tables + table = self.tables.returning_tbl connection.execute( table.insert(), [{"persons": 5, "full": False}, {"persons": 3, "full": False}], @@ -536,7 +590,7 @@ class SequenceReturningTest(fixtures.TablesTest): def define_tables(cls, metadata): seq = Sequence("tid_seq") Table( - "tables", + "returning_tbl", metadata, Column( "id", @@ -549,7 +603,7 @@ class SequenceReturningTest(fixtures.TablesTest): cls.sequences.tid_seq = seq def test_insert(self, connection): - table = self.tables.tables + table = self.tables.returning_tbl r = connection.execute( table.insert().values(data="hi").returning(table.c.id) ) @@ -570,7 +624,7 @@ class KeyReturningTest(fixtures.TablesTest, AssertsExecutionResults): @classmethod def define_tables(cls, metadata): Table( - "tables", + "returning_tbl", metadata, Column( "id", @@ -584,7 +638,7 @@ class KeyReturningTest(fixtures.TablesTest, AssertsExecutionResults): @testing.exclude("postgresql", "<", (8, 2), "8.2+ feature") def test_insert(self, connection): - table = self.tables.tables + table = self.tables.returning_tbl 
result = connection.execute( table.insert().returning(table.c.foo_id), dict(data="somedata") ) @@ -886,18 +940,359 @@ class InsertManyReturnDefaultsTest(fixtures.TablesTest): ], ) + if connection.dialect.insert_null_pk_still_autoincrements: + eq_( + [row._mapping for row in result.returned_defaults_rows], + [ + {"id": 10, "insdef": 0, "upddef": None}, + {"id": 11, "insdef": 0, "upddef": None}, + {"id": 12, "insdef": 0, "upddef": None}, + {"id": 13, "insdef": 0, "upddef": None}, + {"id": 14, "insdef": 0, "upddef": None}, + {"id": 15, "insdef": 0, "upddef": None}, + ], + ) + else: + eq_( + [row._mapping for row in result.returned_defaults_rows], + [ + {"insdef": 0, "upddef": None}, + {"insdef": 0, "upddef": None}, + {"insdef": 0, "upddef": None}, + {"insdef": 0, "upddef": None}, + {"insdef": 0, "upddef": None}, + {"insdef": 0, "upddef": None}, + ], + ) eq_( - [row._mapping for row in result.returned_defaults_rows], + result.inserted_primary_key_rows, + [(10,), (11,), (12,), (13,), (14,), (15,)], + ) + + +class InsertManyReturningTest(fixtures.TablesTest): + __requires__ = ("insert_executemany_returning",) + run_define_tables = "each" + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + from sqlalchemy.sql import ColumnElement + from sqlalchemy.ext.compiler import compiles + + counter = itertools.count() + + class IncDefault(ColumnElement): + pass + + @compiles(IncDefault) + def compile_(element, compiler, **kw): + return str(next(counter)) + + Table( + "default_cases", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + Column("insdef", Integer, default=IncDefault()), + Column("upddef", Integer, onupdate=IncDefault()), + ) + + class GoofyType(TypeDecorator): + impl = String + cache_ok = True + + def process_bind_param(self, value, dialect): + if value is None: + return None + return "FOO" + value + + def process_result_value(self, value, dialect): + if value is None: + 
return None + return value + "BAR" + + cls.GoofyType = GoofyType + + Table( + "type_cases", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("persons", Integer), + Column("full", Boolean), + Column("goofy", GoofyType(50)), + Column("strval", String(50)), + ) + + @testing.combinations( + ( + lambda table: (table.c.strval + "hi",), + [("str1hi",), ("str2hi",), ("str3hi",)], + ), + ( + lambda table: ( + table.c.persons, + table.c.full, + table.c.strval + "hi", + ), + [ + (5, False, "str1hi"), + (6, True, "str2hi"), + (7, False, "str3hi"), + ], + ), + ( + lambda table: ( + table.c.persons, + table.c.strval + "hi", + table.c.full, + ), [ - {"insdef": 0, "upddef": None}, - {"insdef": 0, "upddef": None}, - {"insdef": 0, "upddef": None}, - {"insdef": 0, "upddef": None}, - {"insdef": 0, "upddef": None}, - {"insdef": 0, "upddef": None}, + (5, "str1hi", False), + (6, "str2hi", True), + (7, "str3hi", False), + ], + ), + ( + lambda table: ( + table.c.strval + "hi", + table.c.persons, + table.c.full, + ), + [ + ("str1hi", 5, False), + ("str2hi", 6, True), + ("str3hi", 7, False), + ], + ), + argnames="testcase, expected_rows", + ) + def test_insert_returning_w_expression( + self, connection, testcase, expected_rows + ): + table = self.tables.type_cases + + exprs = testing.resolve_lambda(testcase, table=table) + result = connection.execute( + table.insert().returning(*exprs), + [ + {"persons": 5, "full": False, "strval": "str1"}, + {"persons": 6, "full": True, "strval": "str2"}, + {"persons": 7, "full": False, "strval": "str3"}, + ], + ) + + eq_(result.fetchall(), expected_rows) + + result2 = connection.execute( + select(table.c.id, table.c.strval).order_by(table.c.id) + ) + eq_(result2.fetchall(), [(1, "str1"), (2, "str2"), (3, "str3")]) + + @testing.fails_if( + # Oracle has native executemany() + returning and does not use + # insertmanyvalues to achieve this. 
so test that for + # that particular dialect, the exception expected is not raised + # in the case that the compiler vetoed insertmanyvalues ( + # since Oracle's compiler will always veto it) + lambda config: not config.db.dialect.use_insertmanyvalues + ) + def test_iie_supported_but_not_this_statement(self, connection): + """test the case where INSERT..RETURNING w/ executemany is used, + the dialect requires use_insertmanyreturning, but + the compiler vetoed the use of insertmanyvalues.""" + + t1 = self.tables.type_cases + + with mock.patch.object( + testing.db.dialect.statement_compiler, + "_insert_stmt_should_use_insertmanyvalues", + lambda *arg: False, + ): + with expect_raises_message( + sa_exc.StatementError, + r'Statement does not have "insertmanyvalues" enabled, ' + r"can\'t use INSERT..RETURNING with executemany in this case.", + ): + connection.execute( + t1.insert().returning(t1.c.id, t1.c.goofy, t1.c.full), + [ + {"persons": 5, "full": True}, + {"persons": 6, "full": True}, + {"persons": 7, "full": False}, + ], + ) + + def test_insert_executemany_type_test(self, connection): + t1 = self.tables.type_cases + result = connection.execute( + t1.insert().returning(t1.c.id, t1.c.goofy, t1.c.full), + [ + {"persons": 5, "full": True, "goofy": "row1", "strval": "s1"}, + {"persons": 6, "full": True, "goofy": "row2", "strval": "s2"}, + {"persons": 7, "full": False, "goofy": "row3", "strval": "s3"}, + {"persons": 8, "full": True, "goofy": "row4", "strval": "s4"}, ], ) eq_( - result.inserted_primary_key_rows, - [(10,), (11,), (12,), (13,), (14,), (15,)], + result.mappings().all(), + [ + {"id": 1, "goofy": "FOOrow1BAR", "full": True}, + {"id": 2, "goofy": "FOOrow2BAR", "full": True}, + {"id": 3, "goofy": "FOOrow3BAR", "full": False}, + {"id": 4, "goofy": "FOOrow4BAR", "full": True}, + ], ) + + def test_insert_executemany_default_generators(self, connection): + t1 = self.tables.default_cases + result = connection.execute( + t1.insert().returning(t1.c.id, t1.c.insdef, 
t1.c.upddef), + [ + {"data": "d1"}, + {"data": "d2"}, + {"data": "d3"}, + {"data": "d4"}, + {"data": "d5"}, + {"data": "d6"}, + ], + ) + + eq_( + result.mappings().all(), + [ + {"id": 1, "insdef": 0, "upddef": None}, + {"id": 2, "insdef": 0, "upddef": None}, + {"id": 3, "insdef": 0, "upddef": None}, + {"id": 4, "insdef": 0, "upddef": None}, + {"id": 5, "insdef": 0, "upddef": None}, + {"id": 6, "insdef": 0, "upddef": None}, + ], + ) + + @testing.combinations(True, False, argnames="update_cols") + @testing.requires.provisioned_upsert + def test_upsert_data_w_defaults(self, connection, update_cols): + t1 = self.tables.default_cases + + new_rows = connection.execute( + t1.insert().returning(t1.c.id, t1.c.insdef, t1.c.data), + [ + {"data": "d1"}, + {"data": "d2"}, + {"data": "d3"}, + {"data": "d4"}, + {"data": "d5"}, + {"data": "d6"}, + ], + ).all() + + eq_( + new_rows, + [ + (1, 0, "d1"), + (2, 0, "d2"), + (3, 0, "d3"), + (4, 0, "d4"), + (5, 0, "d5"), + (6, 0, "d6"), + ], + ) + + stmt = provision.upsert( + config, + t1, + (t1.c.id, t1.c.insdef, t1.c.data), + (lambda excluded: {"data": excluded.data + " excluded"}) + if update_cols + else None, + ) + + upserted_rows = connection.execute( + stmt, + [ + {"id": 1, "data": "d1 upserted"}, + {"id": 4, "data": "d4 upserted"}, + {"id": 5, "data": "d5 upserted"}, + {"id": 7, "data": "d7 upserted"}, + {"id": 8, "data": "d8 upserted"}, + {"id": 9, "data": "d9 upserted"}, + ], + ).all() + + if update_cols: + eq_( + upserted_rows, + [ + (1, 0, "d1 upserted excluded"), + (4, 0, "d4 upserted excluded"), + (5, 0, "d5 upserted excluded"), + (7, 1, "d7 upserted"), + (8, 1, "d8 upserted"), + (9, 1, "d9 upserted"), + ], + ) + else: + if testing.against("sqlite", "postgresql"): + eq_( + upserted_rows, + [ + (7, 1, "d7 upserted"), + (8, 1, "d8 upserted"), + (9, 1, "d9 upserted"), + ], + ) + elif testing.against("mariadb"): + # mariadb does not seem to have an "empty" upsert, + # so the provision.upsert() sets table.c.id to itself. 
+ # this means we get all the rows back + eq_( + upserted_rows, + [ + (1, 0, "d1"), + (4, 0, "d4"), + (5, 0, "d5"), + (7, 1, "d7 upserted"), + (8, 1, "d8 upserted"), + (9, 1, "d9 upserted"), + ], + ) + + resulting_data = connection.execute( + t1.select().order_by(t1.c.id) + ).all() + + if update_cols: + eq_( + resulting_data, + [ + (1, "d1 upserted excluded", 0, None), + (2, "d2", 0, None), + (3, "d3", 0, None), + (4, "d4 upserted excluded", 0, None), + (5, "d5 upserted excluded", 0, None), + (6, "d6", 0, None), + (7, "d7 upserted", 1, None), + (8, "d8 upserted", 1, None), + (9, "d9 upserted", 1, None), + ], + ) + else: + eq_( + resulting_data, + [ + (1, "d1", 0, None), + (2, "d2", 0, None), + (3, "d3", 0, None), + (4, "d4", 0, None), + (5, "d5", 0, None), + (6, "d6", 0, None), + (7, "d7 upserted", 1, None), + (8, "d8 upserted", 1, None), + (9, "d9 upserted", 1, None), + ], + ) |
