diff options
-rw-r--r-- | lib/sqlalchemy/dialects/mysql/base.py | 17 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 1 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/default.py | 9 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/crud.py | 277 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/requirements.py | 4 | ||||
-rw-r--r-- | test/dialect/mysql/test_on_duplicate.py | 34 | ||||
-rw-r--r-- | test/engine/test_execute.py | 15 | ||||
-rw-r--r-- | test/requirements.py | 6 | ||||
-rw-r--r-- | test/sql/test_insert_exec.py | 72 | ||||
-rw-r--r-- | test/sql/test_returning.py | 24 | ||||
-rw-r--r-- | test/sql/test_sequences.py | 31 |
12 files changed, 313 insertions, 181 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 1d032b600..46529636d 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1782,15 +1782,10 @@ class MySQLDDLCompiler(compiler.DDLCompiler): if not column.nullable: colspec.append("NOT NULL") - # see: http://docs.sqlalchemy.org/en/latest/dialects/ - # mysql.html#mysql_timestamp_null + # see: http://docs.sqlalchemy.org/en/latest/dialects/mysql.html#mysql_timestamp_null # noqa elif column.nullable and is_timestamp: colspec.append("NULL") - default = self.get_column_default_string(column) - if default is not None: - colspec.append("DEFAULT " + default) - comment = column.comment if comment is not None: literal = self.sql_compiler.render_literal_value( @@ -1802,9 +1797,17 @@ class MySQLDDLCompiler(compiler.DDLCompiler): column.table is not None and column is column.table._autoincrement_column and column.server_default is None + and not ( + self.dialect.supports_sequences + and isinstance(column.default, sa_schema.Sequence) + and not column.default.optional + ) ): colspec.append("AUTO_INCREMENT") - + else: + default = self.get_column_default_string(column) + if default is not None: + colspec.append("DEFAULT " + default) return " ".join(colspec) def post_create_table(self, table): diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 53c54ab65..07405e6d1 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -2526,7 +2526,6 @@ class PGExecutionContext(default.DefaultExecutionContext): elif column.default is None or ( column.default.is_sequence and column.default.optional ): - # execute the sequence associated with a SERIAL primary # key column. for non-primary-key SERIAL, the ID just # generates server side. diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index ec0f2ed9f..564258a28 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -1315,15 +1315,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext): def _setup_dml_or_text_result(self): if self.isinsert: - if ( - not self._is_implicit_returning - and not self.compiled.inline - and self.dialect.postfetch_lastrowid - and not self.executemany - ): - + if self.compiled.postfetch_lastrowid: self._setup_ins_pk_from_lastrowid() - elif not self._is_implicit_returning: self._setup_ins_pk_from_empty() diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 17cacc981..7b917e661 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -669,6 +669,10 @@ class SQLCompiler(Compiled): insert_prefetch = update_prefetch = () + postfetch_lastrowid = False + """if True, and this in insert, use cursor.lastrowid to populate + result.inserted_primary_key. """ + _cache_key_bind_match = None """a mapping that will relate the BindParameter object we compile to those that are part of the extracted collection of parameters diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 3bf8a7c62..986f63aad 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -335,7 +335,6 @@ def _scan_cols( values, kw, ): - ( need_pks, implicit_returning, @@ -358,9 +357,12 @@ def _scan_cols( cols = stmt.table.columns for c in cols: + # scan through every column in the target table + col_key = _getattr_col_key(c) if col_key in parameters and col_key not in check_columns: + # parameter is present for the column. use that. _append_param_parameter( compiler, @@ -377,30 +379,43 @@ def _scan_cols( ) elif compile_state.isinsert: - if ( - c.primary_key - and need_pks - and ( - implicit_returning - or not postfetch_lastrowid - or c is not stmt.table._autoincrement_column - ) - ): + # no parameter is present and it's an insert. + + if c.primary_key and need_pks: + # it's a primary key column, it will need to be generated by a + # default generator of some kind, and the statement expects + # inserted_primary_key to be available. if implicit_returning: + # we can use RETURNING, find out how to invoke this + # column and get the value where RETURNING is an option. + # we can inline server-side functions in this case. + _append_param_insert_pk_returning( compiler, stmt, c, values, kw ) else: - _append_param_insert_pk(compiler, stmt, c, values, kw) + # otherwise, find out how to invoke this column + # and get its value where RETURNING is not an option. + # if we have to invoke a server-side function, we need + # to pre-execute it. or if this is a straight + # autoincrement column and the dialect supports it + # we can use curosr.lastrowid. + + _append_param_insert_pk_no_returning( + compiler, stmt, c, values, kw + ) elif c.default is not None: - + # column has a default, but it's not a pk column, or it is but + # we don't need to get the pk back. _append_param_insert_hasdefault( compiler, stmt, c, implicit_return_defaults, values, kw ) elif c.server_default is not None: + # column has a DDL-level default, and is either not a pk + # column or we don't need the pk. if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) elif not c.primary_key: @@ -415,6 +430,8 @@ def _scan_cols( _warn_pk_with_no_anticipated_value(c) elif compile_state.isupdate: + # no parameter is present and it's an insert. + _append_param_update( compiler, compile_state, @@ -468,38 +485,42 @@ def _append_param_parameter( **kw ) else: - if c.primary_key and implicit_returning: - compiler.returning.append(c) - value = compiler.process(value.self_group(), **kw) - elif implicit_return_defaults and c in implicit_return_defaults: - compiler.returning.append(c) - value = compiler.process(value.self_group(), **kw) + # value is a SQL expression + value = compiler.process(value.self_group(), **kw) + + if compile_state.isupdate: + if implicit_return_defaults and c in implicit_return_defaults: + compiler.returning.append(c) + + else: + compiler.postfetch.append(c) else: - # postfetch specifically means, "we can SELECT the row we just - # inserted by primary key to get back the server generated - # defaults". so by definition this can't be used to get the primary - # key value back, because we need to have it ahead of time. - if not c.primary_key: + if c.primary_key: + + if implicit_returning: + compiler.returning.append(c) + elif compiler.dialect.postfetch_lastrowid: + compiler.postfetch_lastrowid = True + + elif implicit_return_defaults and c in implicit_return_defaults: + compiler.returning.append(c) + + else: + # postfetch specifically means, "we can SELECT the row we just + # inserted by primary key to get back the server generated + # defaults". so by definition this can't be used to get the + # primary key value back, because we need to have it ahead of + # time. + compiler.postfetch.append(c) - value = compiler.process(value.self_group(), **kw) + values.append((c, col_value, value)) def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): - """Create a primary key expression in the INSERT statement and - possibly a RETURNING clause for it. - - If the column has a Python-side default, we will create a bound - parameter for it and "pre-execute" the Python function. If - the column has a SQL expression default, or is a sequence, - we will add it directly into the INSERT statement and add a - RETURNING element to get the new value. If the column has a - server side default or is marked as the "autoincrement" column, - we will add a RETRUNING element to get at the value. - - If all the above tests fail, that indicates a primary key column with no - noted default generation capabilities that has no parameter passed; - raise an exception. + """Create a primary key expression in the INSERT statement where + we want to populate result.inserted_primary_key and RETURNING + is available. """ if c.default is not None: @@ -526,6 +547,9 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): ) compiler.returning.append(c) else: + # client side default. OK we can't use RETURNING, need to + # do a "prefetch", which in fact fetches the default value + # on the Python side values.append( ( c, @@ -541,78 +565,15 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): _warn_pk_with_no_anticipated_value(c) -def _create_insert_prefetch_bind_param( - compiler, c, process=True, name=None, **kw -): - param = _create_bind_param( - compiler, c, None, process=process, name=name, **kw - ) - compiler.insert_prefetch.append(c) - return param +def _append_param_insert_pk_no_returning(compiler, stmt, c, values, kw): + """Create a primary key expression in the INSERT statement where + we want to populate result.inserted_primary_key and we cannot use + RETURNING. + Depending on the kind of default here we may create a bound parameter + in the INSERT statement and pre-execute a default generation function, + or we may use cursor.lastrowid if supported by the dialect. -def _create_update_prefetch_bind_param( - compiler, c, process=True, name=None, **kw -): - param = _create_bind_param( - compiler, c, None, process=process, name=name, **kw - ) - compiler.update_prefetch.append(c) - return param - - -class _multiparam_column(elements.ColumnElement): - _is_multiparam_column = True - - def __init__(self, original, index): - self.index = index - self.key = "%s_m%d" % (original.key, index + 1) - self.original = original - self.default = original.default - self.type = original.type - - def compare(self, other, **kw): - raise NotImplementedError() - - def _copy_internals(self, other, **kw): - raise NotImplementedError() - - def __eq__(self, other): - return ( - isinstance(other, _multiparam_column) - and other.key == self.key - and other.original == self.original - ) - - -def _process_multiparam_default_bind(compiler, stmt, c, index, kw): - - if not c.default: - raise exc.CompileError( - "INSERT value for column %s is explicitly rendered as a bound" - "parameter in the VALUES clause; " - "a Python-side value or SQL expression is required" % c - ) - elif c.default.is_clause_element: - return compiler.process(c.default.arg.self_group(), **kw) - else: - col = _multiparam_column(c, index) - if isinstance(stmt, dml.Insert): - return _create_insert_prefetch_bind_param(compiler, col, **kw) - else: - return _create_update_prefetch_bind_param(compiler, col, **kw) - - -def _append_param_insert_pk(compiler, stmt, c, values, kw): - """Create a bound parameter in the INSERT statement to receive a - 'prefetched' default value. - - The 'prefetched' value indicates that we are to invoke a Python-side - default function or expliclt SQL expression before the INSERT statement - proceeds, so that we have a primary key value available. - - if the column has no noted default generation capabilities, it has - no value passed in either; raise an exception. """ @@ -635,12 +596,27 @@ def _append_param_insert_pk(compiler, stmt, c, values, kw): # column is the "autoincrement column" c is stmt.table._autoincrement_column and ( - # and it's either a "sequence" or a - # pre-executable "autoincrement" sequence - compiler.dialect.supports_sequences - or compiler.dialect.preexecute_autoincrement_sequences + # dialect can't use cursor.lastrowid + not compiler.dialect.postfetch_lastrowid + and ( + # column has a Sequence and we support those + ( + c.default is not None + and c.default.is_sequence + and compiler.dialect.supports_sequences + ) + or + # column has no default on it, but dialect can run the + # "autoincrement" mechanism explictly, e.g. PostrgreSQL + # SERIAL we know the sequence name + ( + c.default is None + and compiler.dialect.preexecute_autoincrement_sequences + ) + ) ) ): + # do a pre-execute of the default values.append( ( c, @@ -648,16 +624,26 @@ def _append_param_insert_pk(compiler, stmt, c, values, kw): _create_insert_prefetch_bind_param(compiler, c, **kw), ) ) - elif c.default is None and c.server_default is None and not c.nullable: + elif ( + c.default is None + and c.server_default is None + and not c.nullable + and c is not stmt.table._autoincrement_column + ): # no .default, no .server_default, not autoincrement, we have # no indication this primary key column will have any value _warn_pk_with_no_anticipated_value(c) + elif compiler.dialect.postfetch_lastrowid: + # finally, where it seems like there will be a generated primary key + # value and we haven't set up any other way to fetch it, and the + # dialect supports cursor.lastrowid, switch on the lastrowid flag so + # that the DefaultExecutionContext calls upon cursor.lastrowid + compiler.postfetch_lastrowid = True def _append_param_insert_hasdefault( compiler, stmt, c, implicit_return_defaults, values, kw ): - if c.default.is_sequence: if compiler.dialect.supports_sequences and ( not c.default.optional or not compiler.dialect.sequences_optional @@ -765,6 +751,69 @@ def _append_param_update( compiler.returning.append(c) +def _create_insert_prefetch_bind_param( + compiler, c, process=True, name=None, **kw +): + + param = _create_bind_param( + compiler, c, None, process=process, name=name, **kw + ) + compiler.insert_prefetch.append(c) + return param + + +def _create_update_prefetch_bind_param( + compiler, c, process=True, name=None, **kw +): + param = _create_bind_param( + compiler, c, None, process=process, name=name, **kw + ) + compiler.update_prefetch.append(c) + return param + + +class _multiparam_column(elements.ColumnElement): + _is_multiparam_column = True + + def __init__(self, original, index): + self.index = index + self.key = "%s_m%d" % (original.key, index + 1) + self.original = original + self.default = original.default + self.type = original.type + + def compare(self, other, **kw): + raise NotImplementedError() + + def _copy_internals(self, other, **kw): + raise NotImplementedError() + + def __eq__(self, other): + return ( + isinstance(other, _multiparam_column) + and other.key == self.key + and other.original == self.original + ) + + +def _process_multiparam_default_bind(compiler, stmt, c, index, kw): + + if not c.default: + raise exc.CompileError( + "INSERT value for column %s is explicitly rendered as a bound" + "parameter in the VALUES clause; " + "a Python-side value or SQL expression is required" % c + ) + elif c.default.is_clause_element: + return compiler.process(c.default.arg.self_group(), **kw) + else: + col = _multiparam_column(c, index) + if isinstance(stmt, dml.Insert): + return _create_insert_prefetch_bind_param(compiler, col, **kw) + else: + return _create_update_prefetch_bind_param(compiler, col, **kw) + + def _get_multitable_params( compiler, stmt, diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index cc40022e9..7b0ddafe3 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -507,7 +507,9 @@ class SuiteRequirements(Requirements): @property def no_lastrowid_support(self): """the opposite of supports_lastrowid""" - return exclusions.NotPredicate(self.supports_lastrowid) + return exclusions.only_if( + [lambda config: not config.db.dialect.postfetch_lastrowid] + ) @property def reflects_pk_names(self): diff --git a/test/dialect/mysql/test_on_duplicate.py b/test/dialect/mysql/test_on_duplicate.py index 95aabc776..ed88121a5 100644 --- a/test/dialect/mysql/test_on_duplicate.py +++ b/test/dialect/mysql/test_on_duplicate.py @@ -47,7 +47,7 @@ class OnDuplicateTest(fixtures.TablesTest): {"id": 2, "bar": "baz"}, ) - def test_on_duplicate_key_update(self): + def test_on_duplicate_key_update_multirow(self): foos = self.tables.foos with testing.db.connect() as conn: conn.execute(insert(foos, dict(id=1, bar="b", baz="bz"))) @@ -55,14 +55,34 @@ class OnDuplicateTest(fixtures.TablesTest): [dict(id=1, bar="ab"), dict(id=2, bar="b")] ) stmt = stmt.on_duplicate_key_update(bar=stmt.inserted.bar) + result = conn.execute(stmt) - eq_(result.inserted_primary_key, (2,)) + + # multirow, so its ambiguous. this is a behavioral change + # in 1.4 + eq_(result.inserted_primary_key, (None,)) eq_( conn.execute(foos.select().where(foos.c.id == 1)).fetchall(), [(1, "ab", "bz", False)], ) - def test_on_duplicate_key_update_null(self): + def test_on_duplicate_key_update_singlerow(self): + foos = self.tables.foos + with testing.db.connect() as conn: + conn.execute(insert(foos, dict(id=1, bar="b", baz="bz"))) + stmt = insert(foos).values(dict(id=2, bar="b")) + stmt = stmt.on_duplicate_key_update(bar=stmt.inserted.bar) + + result = conn.execute(stmt) + + # only one row in the INSERT so we do inserted_primary_key + eq_(result.inserted_primary_key, (2,)) + eq_( + conn.execute(foos.select().where(foos.c.id == 1)).fetchall(), + [(1, "b", "bz", False)], + ) + + def test_on_duplicate_key_update_null_multirow(self): foos = self.tables.foos with testing.db.connect() as conn: conn.execute(insert(foos, dict(id=1, bar="b", baz="bz"))) @@ -71,13 +91,15 @@ class OnDuplicateTest(fixtures.TablesTest): ) stmt = stmt.on_duplicate_key_update(updated_once=None) result = conn.execute(stmt) - eq_(result.inserted_primary_key, (2,)) + + # ambiguous + eq_(result.inserted_primary_key, (None,)) eq_( conn.execute(foos.select().where(foos.c.id == 1)).fetchall(), [(1, "b", "bz", None)], ) - def test_on_duplicate_key_update_expression(self): + def test_on_duplicate_key_update_expression_multirow(self): foos = self.tables.foos with testing.db.connect() as conn: conn.execute(insert(foos, dict(id=1, bar="b", baz="bz"))) @@ -88,7 +110,7 @@ class OnDuplicateTest(fixtures.TablesTest): bar=func.concat(stmt.inserted.bar, "_foo") ) result = conn.execute(stmt) - eq_(result.inserted_primary_key, (2,)) + eq_(result.inserted_primary_key, (None,)) eq_( conn.execute(foos.select().where(foos.c.id == 1)).fetchall(), [(1, "ab_foo", "bz", False)], diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index 5b922a97d..4b527f54c 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -1947,16 +1947,11 @@ class EngineEventsTest(fixtures.TestBase): ) conn.execute(t.insert()) - if testing.requires.supports_lastrowid.enabled: - # new MariaDB 10.3 supports sequences + lastrowid; only - # one statement - assert "INSERT" in canary[0][0] - else: - # we see the sequence pre-executed in the first call - assert "t_id_seq" in canary[0][0] - assert "INSERT" in canary[1][0] - # same context - is_(canary[0][1], canary[1][1]) + # we see the sequence pre-executed in the first call + assert "t_id_seq" in canary[0][0] + assert "INSERT" in canary[1][0] + # same context + is_(canary[0][1], canary[1][1]) def test_transactional(self): canary = [] diff --git a/test/requirements.py b/test/requirements.py index 35d76bde9..c2fd65b05 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -884,7 +884,7 @@ class DefaultRequirements(SuiteRequirements): def emulated_lastrowid_even_with_sequences(self): """"target dialect retrieves cursor.lastrowid or an equivalent after an insert() construct executes, even if the table has a - Sequence on it.. + Sequence on it. """ return fails_on_everything_except( "mysql", @@ -1666,8 +1666,8 @@ class DefaultRequirements(SuiteRequirements): @property def supports_lastrowid_for_expressions(self): - """sequences allowed in WHERE, GROUP BY, HAVING, etc.""" - return skip_if("mssql") + """cursor.lastrowid works if an explicit SQL expression was used.""" + return only_on(["sqlite", "mysql", "mariadb"]) @property def supports_sequence_for_autoincrement_column(self): diff --git a/test/sql/test_insert_exec.py b/test/sql/test_insert_exec.py index e27decd6f..16b27aeaa 100644 --- a/test/sql/test_insert_exec.py +++ b/test/sql/test_insert_exec.py @@ -4,6 +4,7 @@ from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import INT from sqlalchemy import Integer +from sqlalchemy import literal from sqlalchemy import MetaData from sqlalchemy import Sequence from sqlalchemy import sql @@ -207,6 +208,22 @@ class InsertExecTest(fixtures.TablesTest): {"id": 1, "foo": "hi", "bar": "hi"}, ) + @testing.requires.sequences + def test_lastrow_accessor_four_a(self): + metadata = MetaData() + self._test_lastrow_accessor( + Table( + "t4", + metadata, + Column( + "id", Integer, Sequence("t4_id_seq"), primary_key=True, + ), + Column("foo", String(30)), + ), + {"foo": "hi"}, + {"id": 1, "foo": "hi"}, + ) + def test_lastrow_accessor_five(self): metadata = MetaData() self._test_lastrow_accessor( @@ -362,6 +379,16 @@ class TableInsertTest(fixtures.TablesTest): Column("x", Integer), ) + Table( + "foo_no_seq", + metadata, + # note this will have full AUTO INCREMENT on MariaDB + # whereas "foo" will not due to sequence support + Column("id", Integer, primary_key=True,), + Column("data", String(50)), + Column("x", Integer), + ) + def _fixture(self, types=True): if types: t = sql.table( @@ -376,16 +403,22 @@ class TableInsertTest(fixtures.TablesTest): ) return t - def _test(self, stmt, row, returning=None, inserted_primary_key=False): - r = testing.db.execute(stmt) + def _test( + self, stmt, row, returning=None, inserted_primary_key=False, table=None + ): + with testing.db.connect() as conn: + r = conn.execute(stmt) + + if returning: + returned = r.first() + eq_(returned, returning) + elif inserted_primary_key is not False: + eq_(r.inserted_primary_key, inserted_primary_key) - if returning: - returned = r.first() - eq_(returned, returning) - elif inserted_primary_key is not False: - eq_(r.inserted_primary_key, inserted_primary_key) + if table is None: + table = self.tables.foo - eq_(testing.db.execute(self.tables.foo.select()).first(), row) + eq_(conn.execute(table.select()).first(), row) def _test_multi(self, stmt, rows, data): testing.db.execute(stmt, rows) @@ -459,6 +492,19 @@ class TableInsertTest(fixtures.TablesTest): returning=(1, 5), ) + @testing.requires.sql_expressions_inserted_as_primary_key + def test_sql_expr_lastrowid(self): + + # see also test.orm.test_unitofwork.py + # ClauseAttributesTest.test_insert_pk_expression + t = self.tables.foo_no_seq + self._test( + t.insert().values(id=literal(5) + 10, data="data", x=5), + (15, "data", 5), + inserted_primary_key=(15,), + table=self.tables.foo_no_seq, + ) + def test_direct_params(self): t = self._fixture() self._test( @@ -476,7 +522,11 @@ class TableInsertTest(fixtures.TablesTest): returning=(testing.db.dialect.default_sequence_base, 5), ) - @testing.requires.emulated_lastrowid_even_with_sequences + # there's a non optional Sequence in the metadata, which if the dialect + # supports sequences, it means the CREATE TABLE should *not* have + # autoincrement, so the INSERT below would fail because the "t" fixture + # does not indicate the Sequence + @testing.fails_if(testing.requires.sequences) @testing.requires.emulated_lastrowid def test_implicit_pk(self): t = self._fixture() @@ -486,7 +536,7 @@ class TableInsertTest(fixtures.TablesTest): inserted_primary_key=(), ) - @testing.requires.emulated_lastrowid_even_with_sequences + @testing.fails_if(testing.requires.sequences) @testing.requires.emulated_lastrowid def test_implicit_pk_multi_rows(self): t = self._fixture() @@ -500,7 +550,7 @@ class TableInsertTest(fixtures.TablesTest): [(1, "d1", 5), (2, "d2", 6), (3, "d3", 7)], ) - @testing.requires.emulated_lastrowid_even_with_sequences + @testing.fails_if(testing.requires.sequences) @testing.requires.emulated_lastrowid def test_implicit_pk_inline(self): t = self._fixture() diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py index 7d60dd475..20aa1fb3a 100644 --- a/test/sql/test_returning.py +++ b/test/sql/test_returning.py @@ -430,6 +430,30 @@ class ReturnDefaultsTest(fixtures.TablesTest): [None], ) + def test_insert_sql_expr(self, connection): + from sqlalchemy import literal + + t1 = self.tables.t1 + result = connection.execute( + t1.insert().return_defaults().values(insdef=literal(10) + 5) + ) + + eq_( + result.returned_defaults._mapping, + {"id": 1, "data": None, "insdef": 15, "upddef": None}, + ) + + def test_update_sql_expr(self, connection): + from sqlalchemy import literal + + t1 = self.tables.t1 + connection.execute(t1.insert().values(upddef=1)) + result = connection.execute( + t1.update().values(upddef=literal(10) + 5).return_defaults() + ) + + eq_(result.returned_defaults._mapping, {"upddef": 15}) + def test_insert_non_default_plus_default(self, connection): t1 = self.tables.t1 result = connection.execute( diff --git a/test/sql/test_sequences.py b/test/sql/test_sequences.py index 8d894f9f3..ee7c77a93 100644 --- a/test/sql/test_sequences.py +++ b/test/sql/test_sequences.py @@ -231,38 +231,29 @@ class SequenceExecTest(fixtures.TestBase): connection.execute(t1.insert().values(x=s.next_value())) self._assert_seq_result(connection.scalar(t1.select())) - @testing.requires.no_lastrowid_support @testing.provide_metadata - def test_inserted_pk_no_returning_no_lastrowid(self): + def test_inserted_pk_no_returning(self): """test inserted_primary_key contains [None] when pk_col=next_value(), implicit returning is not used.""" + # I'm not really sure what this test wants to accomlish. + metadata = self.metadata t1 = Table("t", metadata, Column("x", Integer, primary_key=True)) - t1.create(testing.db) + s = Sequence("my_sequence_here", metadata=metadata) e = engines.testing_engine(options={"implicit_returning": False}) - s = Sequence("my_sequence") with e.connect() as conn: - r = conn.execute(t1.insert().values(x=s.next_value())) - eq_(r.inserted_primary_key, [None]) - @testing.requires.supports_lastrowid - @testing.requires.supports_lastrowid_for_expressions - @testing.provide_metadata - def test_inserted_pk_no_returning_w_lastrowid(self): - """test inserted_primary_key contains the pk when - pk_col=next_value(), lastrowid is supported.""" - - metadata = self.metadata - t1 = Table("t", metadata, Column("x", Integer, primary_key=True,),) - t1.create(testing.db) - e = engines.testing_engine(options={"implicit_returning": False}) - s = Sequence("my_sequence") + t1.create(conn) + s.create(conn) - with e.connect() as conn: r = conn.execute(t1.insert().values(x=s.next_value())) - self._assert_seq_result(r.inserted_primary_key[0]) + + if testing.requires.emulated_lastrowid_even_with_sequences.enabled: + eq_(r.inserted_primary_key, (1,)) + else: + eq_(r.inserted_primary_key, (None,)) @testing.requires.returning @testing.provide_metadata |