summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/sqlalchemy/dialects/mysql/base.py17
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py1
-rw-r--r--lib/sqlalchemy/engine/default.py9
-rw-r--r--lib/sqlalchemy/sql/compiler.py4
-rw-r--r--lib/sqlalchemy/sql/crud.py277
-rw-r--r--lib/sqlalchemy/testing/requirements.py4
-rw-r--r--test/dialect/mysql/test_on_duplicate.py34
-rw-r--r--test/engine/test_execute.py15
-rw-r--r--test/requirements.py6
-rw-r--r--test/sql/test_insert_exec.py72
-rw-r--r--test/sql/test_returning.py24
-rw-r--r--test/sql/test_sequences.py31
12 files changed, 313 insertions, 181 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index 1d032b600..46529636d 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -1782,15 +1782,10 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
if not column.nullable:
colspec.append("NOT NULL")
- # see: http://docs.sqlalchemy.org/en/latest/dialects/
- # mysql.html#mysql_timestamp_null
+ # see: http://docs.sqlalchemy.org/en/latest/dialects/mysql.html#mysql_timestamp_null # noqa
elif column.nullable and is_timestamp:
colspec.append("NULL")
- default = self.get_column_default_string(column)
- if default is not None:
- colspec.append("DEFAULT " + default)
-
comment = column.comment
if comment is not None:
literal = self.sql_compiler.render_literal_value(
@@ -1802,9 +1797,17 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
column.table is not None
and column is column.table._autoincrement_column
and column.server_default is None
+ and not (
+ self.dialect.supports_sequences
+ and isinstance(column.default, sa_schema.Sequence)
+ and not column.default.optional
+ )
):
colspec.append("AUTO_INCREMENT")
-
+ else:
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec.append("DEFAULT " + default)
return " ".join(colspec)
def post_create_table(self, table):
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 53c54ab65..07405e6d1 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -2526,7 +2526,6 @@ class PGExecutionContext(default.DefaultExecutionContext):
elif column.default is None or (
column.default.is_sequence and column.default.optional
):
-
# execute the sequence associated with a SERIAL primary
# key column. for non-primary-key SERIAL, the ID just
# generates server side.
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index ec0f2ed9f..564258a28 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -1315,15 +1315,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
def _setup_dml_or_text_result(self):
if self.isinsert:
- if (
- not self._is_implicit_returning
- and not self.compiled.inline
- and self.dialect.postfetch_lastrowid
- and not self.executemany
- ):
-
+ if self.compiled.postfetch_lastrowid:
self._setup_ins_pk_from_lastrowid()
-
elif not self._is_implicit_returning:
self._setup_ins_pk_from_empty()
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 17cacc981..7b917e661 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -669,6 +669,10 @@ class SQLCompiler(Compiled):
insert_prefetch = update_prefetch = ()
+ postfetch_lastrowid = False
+ """if True, and this in insert, use cursor.lastrowid to populate
+ result.inserted_primary_key. """
+
_cache_key_bind_match = None
"""a mapping that will relate the BindParameter object we compile
to those that are part of the extracted collection of parameters
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 3bf8a7c62..986f63aad 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -335,7 +335,6 @@ def _scan_cols(
values,
kw,
):
-
(
need_pks,
implicit_returning,
@@ -358,9 +357,12 @@ def _scan_cols(
cols = stmt.table.columns
for c in cols:
+ # scan through every column in the target table
+
col_key = _getattr_col_key(c)
if col_key in parameters and col_key not in check_columns:
+ # parameter is present for the column. use that.
_append_param_parameter(
compiler,
@@ -377,30 +379,43 @@ def _scan_cols(
)
elif compile_state.isinsert:
- if (
- c.primary_key
- and need_pks
- and (
- implicit_returning
- or not postfetch_lastrowid
- or c is not stmt.table._autoincrement_column
- )
- ):
+ # no parameter is present and it's an insert.
+
+ if c.primary_key and need_pks:
+ # it's a primary key column, it will need to be generated by a
+ # default generator of some kind, and the statement expects
+ # inserted_primary_key to be available.
if implicit_returning:
+ # we can use RETURNING, find out how to invoke this
+ # column and get the value where RETURNING is an option.
+ # we can inline server-side functions in this case.
+
_append_param_insert_pk_returning(
compiler, stmt, c, values, kw
)
else:
- _append_param_insert_pk(compiler, stmt, c, values, kw)
+ # otherwise, find out how to invoke this column
+ # and get its value where RETURNING is not an option.
+ # if we have to invoke a server-side function, we need
+ # to pre-execute it. or if this is a straight
+ # autoincrement column and the dialect supports it
+ # we can use curosr.lastrowid.
+
+ _append_param_insert_pk_no_returning(
+ compiler, stmt, c, values, kw
+ )
elif c.default is not None:
-
+ # column has a default, but it's not a pk column, or it is but
+ # we don't need to get the pk back.
_append_param_insert_hasdefault(
compiler, stmt, c, implicit_return_defaults, values, kw
)
elif c.server_default is not None:
+ # column has a DDL-level default, and is either not a pk
+ # column or we don't need the pk.
if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
elif not c.primary_key:
@@ -415,6 +430,8 @@ def _scan_cols(
_warn_pk_with_no_anticipated_value(c)
elif compile_state.isupdate:
+ # no parameter is present and it's an insert.
+
_append_param_update(
compiler,
compile_state,
@@ -468,38 +485,42 @@ def _append_param_parameter(
**kw
)
else:
- if c.primary_key and implicit_returning:
- compiler.returning.append(c)
- value = compiler.process(value.self_group(), **kw)
- elif implicit_return_defaults and c in implicit_return_defaults:
- compiler.returning.append(c)
- value = compiler.process(value.self_group(), **kw)
+ # value is a SQL expression
+ value = compiler.process(value.self_group(), **kw)
+
+ if compile_state.isupdate:
+ if implicit_return_defaults and c in implicit_return_defaults:
+ compiler.returning.append(c)
+
+ else:
+ compiler.postfetch.append(c)
else:
- # postfetch specifically means, "we can SELECT the row we just
- # inserted by primary key to get back the server generated
- # defaults". so by definition this can't be used to get the primary
- # key value back, because we need to have it ahead of time.
- if not c.primary_key:
+ if c.primary_key:
+
+ if implicit_returning:
+ compiler.returning.append(c)
+ elif compiler.dialect.postfetch_lastrowid:
+ compiler.postfetch_lastrowid = True
+
+ elif implicit_return_defaults and c in implicit_return_defaults:
+ compiler.returning.append(c)
+
+ else:
+ # postfetch specifically means, "we can SELECT the row we just
+ # inserted by primary key to get back the server generated
+ # defaults". so by definition this can't be used to get the
+ # primary key value back, because we need to have it ahead of
+ # time.
+
compiler.postfetch.append(c)
- value = compiler.process(value.self_group(), **kw)
+
values.append((c, col_value, value))
def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
- """Create a primary key expression in the INSERT statement and
- possibly a RETURNING clause for it.
-
- If the column has a Python-side default, we will create a bound
- parameter for it and "pre-execute" the Python function. If
- the column has a SQL expression default, or is a sequence,
- we will add it directly into the INSERT statement and add a
- RETURNING element to get the new value. If the column has a
- server side default or is marked as the "autoincrement" column,
- we will add a RETRUNING element to get at the value.
-
- If all the above tests fail, that indicates a primary key column with no
- noted default generation capabilities that has no parameter passed;
- raise an exception.
+ """Create a primary key expression in the INSERT statement where
+ we want to populate result.inserted_primary_key and RETURNING
+ is available.
"""
if c.default is not None:
@@ -526,6 +547,9 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
)
compiler.returning.append(c)
else:
+ # client side default. OK we can't use RETURNING, need to
+ # do a "prefetch", which in fact fetches the default value
+ # on the Python side
values.append(
(
c,
@@ -541,78 +565,15 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
_warn_pk_with_no_anticipated_value(c)
-def _create_insert_prefetch_bind_param(
- compiler, c, process=True, name=None, **kw
-):
- param = _create_bind_param(
- compiler, c, None, process=process, name=name, **kw
- )
- compiler.insert_prefetch.append(c)
- return param
+def _append_param_insert_pk_no_returning(compiler, stmt, c, values, kw):
+ """Create a primary key expression in the INSERT statement where
+ we want to populate result.inserted_primary_key and we cannot use
+ RETURNING.
+ Depending on the kind of default here we may create a bound parameter
+ in the INSERT statement and pre-execute a default generation function,
+ or we may use cursor.lastrowid if supported by the dialect.
-def _create_update_prefetch_bind_param(
- compiler, c, process=True, name=None, **kw
-):
- param = _create_bind_param(
- compiler, c, None, process=process, name=name, **kw
- )
- compiler.update_prefetch.append(c)
- return param
-
-
-class _multiparam_column(elements.ColumnElement):
- _is_multiparam_column = True
-
- def __init__(self, original, index):
- self.index = index
- self.key = "%s_m%d" % (original.key, index + 1)
- self.original = original
- self.default = original.default
- self.type = original.type
-
- def compare(self, other, **kw):
- raise NotImplementedError()
-
- def _copy_internals(self, other, **kw):
- raise NotImplementedError()
-
- def __eq__(self, other):
- return (
- isinstance(other, _multiparam_column)
- and other.key == self.key
- and other.original == self.original
- )
-
-
-def _process_multiparam_default_bind(compiler, stmt, c, index, kw):
-
- if not c.default:
- raise exc.CompileError(
- "INSERT value for column %s is explicitly rendered as a bound"
- "parameter in the VALUES clause; "
- "a Python-side value or SQL expression is required" % c
- )
- elif c.default.is_clause_element:
- return compiler.process(c.default.arg.self_group(), **kw)
- else:
- col = _multiparam_column(c, index)
- if isinstance(stmt, dml.Insert):
- return _create_insert_prefetch_bind_param(compiler, col, **kw)
- else:
- return _create_update_prefetch_bind_param(compiler, col, **kw)
-
-
-def _append_param_insert_pk(compiler, stmt, c, values, kw):
- """Create a bound parameter in the INSERT statement to receive a
- 'prefetched' default value.
-
- The 'prefetched' value indicates that we are to invoke a Python-side
- default function or expliclt SQL expression before the INSERT statement
- proceeds, so that we have a primary key value available.
-
- if the column has no noted default generation capabilities, it has
- no value passed in either; raise an exception.
"""
@@ -635,12 +596,27 @@ def _append_param_insert_pk(compiler, stmt, c, values, kw):
# column is the "autoincrement column"
c is stmt.table._autoincrement_column
and (
- # and it's either a "sequence" or a
- # pre-executable "autoincrement" sequence
- compiler.dialect.supports_sequences
- or compiler.dialect.preexecute_autoincrement_sequences
+ # dialect can't use cursor.lastrowid
+ not compiler.dialect.postfetch_lastrowid
+ and (
+ # column has a Sequence and we support those
+ (
+ c.default is not None
+ and c.default.is_sequence
+ and compiler.dialect.supports_sequences
+ )
+ or
+ # column has no default on it, but dialect can run the
+ # "autoincrement" mechanism explictly, e.g. PostrgreSQL
+ # SERIAL we know the sequence name
+ (
+ c.default is None
+ and compiler.dialect.preexecute_autoincrement_sequences
+ )
+ )
)
):
+ # do a pre-execute of the default
values.append(
(
c,
@@ -648,16 +624,26 @@ def _append_param_insert_pk(compiler, stmt, c, values, kw):
_create_insert_prefetch_bind_param(compiler, c, **kw),
)
)
- elif c.default is None and c.server_default is None and not c.nullable:
+ elif (
+ c.default is None
+ and c.server_default is None
+ and not c.nullable
+ and c is not stmt.table._autoincrement_column
+ ):
# no .default, no .server_default, not autoincrement, we have
# no indication this primary key column will have any value
_warn_pk_with_no_anticipated_value(c)
+ elif compiler.dialect.postfetch_lastrowid:
+ # finally, where it seems like there will be a generated primary key
+ # value and we haven't set up any other way to fetch it, and the
+ # dialect supports cursor.lastrowid, switch on the lastrowid flag so
+ # that the DefaultExecutionContext calls upon cursor.lastrowid
+ compiler.postfetch_lastrowid = True
def _append_param_insert_hasdefault(
compiler, stmt, c, implicit_return_defaults, values, kw
):
-
if c.default.is_sequence:
if compiler.dialect.supports_sequences and (
not c.default.optional or not compiler.dialect.sequences_optional
@@ -765,6 +751,69 @@ def _append_param_update(
compiler.returning.append(c)
+def _create_insert_prefetch_bind_param(
+ compiler, c, process=True, name=None, **kw
+):
+
+ param = _create_bind_param(
+ compiler, c, None, process=process, name=name, **kw
+ )
+ compiler.insert_prefetch.append(c)
+ return param
+
+
+def _create_update_prefetch_bind_param(
+ compiler, c, process=True, name=None, **kw
+):
+ param = _create_bind_param(
+ compiler, c, None, process=process, name=name, **kw
+ )
+ compiler.update_prefetch.append(c)
+ return param
+
+
+class _multiparam_column(elements.ColumnElement):
+ _is_multiparam_column = True
+
+ def __init__(self, original, index):
+ self.index = index
+ self.key = "%s_m%d" % (original.key, index + 1)
+ self.original = original
+ self.default = original.default
+ self.type = original.type
+
+ def compare(self, other, **kw):
+ raise NotImplementedError()
+
+ def _copy_internals(self, other, **kw):
+ raise NotImplementedError()
+
+ def __eq__(self, other):
+ return (
+ isinstance(other, _multiparam_column)
+ and other.key == self.key
+ and other.original == self.original
+ )
+
+
+def _process_multiparam_default_bind(compiler, stmt, c, index, kw):
+
+ if not c.default:
+ raise exc.CompileError(
+ "INSERT value for column %s is explicitly rendered as a bound"
+ "parameter in the VALUES clause; "
+ "a Python-side value or SQL expression is required" % c
+ )
+ elif c.default.is_clause_element:
+ return compiler.process(c.default.arg.self_group(), **kw)
+ else:
+ col = _multiparam_column(c, index)
+ if isinstance(stmt, dml.Insert):
+ return _create_insert_prefetch_bind_param(compiler, col, **kw)
+ else:
+ return _create_update_prefetch_bind_param(compiler, col, **kw)
+
+
def _get_multitable_params(
compiler,
stmt,
diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py
index cc40022e9..7b0ddafe3 100644
--- a/lib/sqlalchemy/testing/requirements.py
+++ b/lib/sqlalchemy/testing/requirements.py
@@ -507,7 +507,9 @@ class SuiteRequirements(Requirements):
@property
def no_lastrowid_support(self):
"""the opposite of supports_lastrowid"""
- return exclusions.NotPredicate(self.supports_lastrowid)
+ return exclusions.only_if(
+ [lambda config: not config.db.dialect.postfetch_lastrowid]
+ )
@property
def reflects_pk_names(self):
diff --git a/test/dialect/mysql/test_on_duplicate.py b/test/dialect/mysql/test_on_duplicate.py
index 95aabc776..ed88121a5 100644
--- a/test/dialect/mysql/test_on_duplicate.py
+++ b/test/dialect/mysql/test_on_duplicate.py
@@ -47,7 +47,7 @@ class OnDuplicateTest(fixtures.TablesTest):
{"id": 2, "bar": "baz"},
)
- def test_on_duplicate_key_update(self):
+ def test_on_duplicate_key_update_multirow(self):
foos = self.tables.foos
with testing.db.connect() as conn:
conn.execute(insert(foos, dict(id=1, bar="b", baz="bz")))
@@ -55,14 +55,34 @@ class OnDuplicateTest(fixtures.TablesTest):
[dict(id=1, bar="ab"), dict(id=2, bar="b")]
)
stmt = stmt.on_duplicate_key_update(bar=stmt.inserted.bar)
+
result = conn.execute(stmt)
- eq_(result.inserted_primary_key, (2,))
+
+ # multirow, so its ambiguous. this is a behavioral change
+ # in 1.4
+ eq_(result.inserted_primary_key, (None,))
eq_(
conn.execute(foos.select().where(foos.c.id == 1)).fetchall(),
[(1, "ab", "bz", False)],
)
- def test_on_duplicate_key_update_null(self):
+ def test_on_duplicate_key_update_singlerow(self):
+ foos = self.tables.foos
+ with testing.db.connect() as conn:
+ conn.execute(insert(foos, dict(id=1, bar="b", baz="bz")))
+ stmt = insert(foos).values(dict(id=2, bar="b"))
+ stmt = stmt.on_duplicate_key_update(bar=stmt.inserted.bar)
+
+ result = conn.execute(stmt)
+
+ # only one row in the INSERT so we do inserted_primary_key
+ eq_(result.inserted_primary_key, (2,))
+ eq_(
+ conn.execute(foos.select().where(foos.c.id == 1)).fetchall(),
+ [(1, "b", "bz", False)],
+ )
+
+ def test_on_duplicate_key_update_null_multirow(self):
foos = self.tables.foos
with testing.db.connect() as conn:
conn.execute(insert(foos, dict(id=1, bar="b", baz="bz")))
@@ -71,13 +91,15 @@ class OnDuplicateTest(fixtures.TablesTest):
)
stmt = stmt.on_duplicate_key_update(updated_once=None)
result = conn.execute(stmt)
- eq_(result.inserted_primary_key, (2,))
+
+ # ambiguous
+ eq_(result.inserted_primary_key, (None,))
eq_(
conn.execute(foos.select().where(foos.c.id == 1)).fetchall(),
[(1, "b", "bz", None)],
)
- def test_on_duplicate_key_update_expression(self):
+ def test_on_duplicate_key_update_expression_multirow(self):
foos = self.tables.foos
with testing.db.connect() as conn:
conn.execute(insert(foos, dict(id=1, bar="b", baz="bz")))
@@ -88,7 +110,7 @@ class OnDuplicateTest(fixtures.TablesTest):
bar=func.concat(stmt.inserted.bar, "_foo")
)
result = conn.execute(stmt)
- eq_(result.inserted_primary_key, (2,))
+ eq_(result.inserted_primary_key, (None,))
eq_(
conn.execute(foos.select().where(foos.c.id == 1)).fetchall(),
[(1, "ab_foo", "bz", False)],
diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py
index 5b922a97d..4b527f54c 100644
--- a/test/engine/test_execute.py
+++ b/test/engine/test_execute.py
@@ -1947,16 +1947,11 @@ class EngineEventsTest(fixtures.TestBase):
)
conn.execute(t.insert())
- if testing.requires.supports_lastrowid.enabled:
- # new MariaDB 10.3 supports sequences + lastrowid; only
- # one statement
- assert "INSERT" in canary[0][0]
- else:
- # we see the sequence pre-executed in the first call
- assert "t_id_seq" in canary[0][0]
- assert "INSERT" in canary[1][0]
- # same context
- is_(canary[0][1], canary[1][1])
+ # we see the sequence pre-executed in the first call
+ assert "t_id_seq" in canary[0][0]
+ assert "INSERT" in canary[1][0]
+ # same context
+ is_(canary[0][1], canary[1][1])
def test_transactional(self):
canary = []
diff --git a/test/requirements.py b/test/requirements.py
index 35d76bde9..c2fd65b05 100644
--- a/test/requirements.py
+++ b/test/requirements.py
@@ -884,7 +884,7 @@ class DefaultRequirements(SuiteRequirements):
def emulated_lastrowid_even_with_sequences(self):
""""target dialect retrieves cursor.lastrowid or an equivalent
after an insert() construct executes, even if the table has a
- Sequence on it..
+ Sequence on it.
"""
return fails_on_everything_except(
"mysql",
@@ -1666,8 +1666,8 @@ class DefaultRequirements(SuiteRequirements):
@property
def supports_lastrowid_for_expressions(self):
- """sequences allowed in WHERE, GROUP BY, HAVING, etc."""
- return skip_if("mssql")
+ """cursor.lastrowid works if an explicit SQL expression was used."""
+ return only_on(["sqlite", "mysql", "mariadb"])
@property
def supports_sequence_for_autoincrement_column(self):
diff --git a/test/sql/test_insert_exec.py b/test/sql/test_insert_exec.py
index e27decd6f..16b27aeaa 100644
--- a/test/sql/test_insert_exec.py
+++ b/test/sql/test_insert_exec.py
@@ -4,6 +4,7 @@ from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import INT
from sqlalchemy import Integer
+from sqlalchemy import literal
from sqlalchemy import MetaData
from sqlalchemy import Sequence
from sqlalchemy import sql
@@ -207,6 +208,22 @@ class InsertExecTest(fixtures.TablesTest):
{"id": 1, "foo": "hi", "bar": "hi"},
)
+ @testing.requires.sequences
+ def test_lastrow_accessor_four_a(self):
+ metadata = MetaData()
+ self._test_lastrow_accessor(
+ Table(
+ "t4",
+ metadata,
+ Column(
+ "id", Integer, Sequence("t4_id_seq"), primary_key=True,
+ ),
+ Column("foo", String(30)),
+ ),
+ {"foo": "hi"},
+ {"id": 1, "foo": "hi"},
+ )
+
def test_lastrow_accessor_five(self):
metadata = MetaData()
self._test_lastrow_accessor(
@@ -362,6 +379,16 @@ class TableInsertTest(fixtures.TablesTest):
Column("x", Integer),
)
+ Table(
+ "foo_no_seq",
+ metadata,
+ # note this will have full AUTO INCREMENT on MariaDB
+ # whereas "foo" will not due to sequence support
+ Column("id", Integer, primary_key=True,),
+ Column("data", String(50)),
+ Column("x", Integer),
+ )
+
def _fixture(self, types=True):
if types:
t = sql.table(
@@ -376,16 +403,22 @@ class TableInsertTest(fixtures.TablesTest):
)
return t
- def _test(self, stmt, row, returning=None, inserted_primary_key=False):
- r = testing.db.execute(stmt)
+ def _test(
+ self, stmt, row, returning=None, inserted_primary_key=False, table=None
+ ):
+ with testing.db.connect() as conn:
+ r = conn.execute(stmt)
+
+ if returning:
+ returned = r.first()
+ eq_(returned, returning)
+ elif inserted_primary_key is not False:
+ eq_(r.inserted_primary_key, inserted_primary_key)
- if returning:
- returned = r.first()
- eq_(returned, returning)
- elif inserted_primary_key is not False:
- eq_(r.inserted_primary_key, inserted_primary_key)
+ if table is None:
+ table = self.tables.foo
- eq_(testing.db.execute(self.tables.foo.select()).first(), row)
+ eq_(conn.execute(table.select()).first(), row)
def _test_multi(self, stmt, rows, data):
testing.db.execute(stmt, rows)
@@ -459,6 +492,19 @@ class TableInsertTest(fixtures.TablesTest):
returning=(1, 5),
)
+ @testing.requires.sql_expressions_inserted_as_primary_key
+ def test_sql_expr_lastrowid(self):
+
+ # see also test.orm.test_unitofwork.py
+ # ClauseAttributesTest.test_insert_pk_expression
+ t = self.tables.foo_no_seq
+ self._test(
+ t.insert().values(id=literal(5) + 10, data="data", x=5),
+ (15, "data", 5),
+ inserted_primary_key=(15,),
+ table=self.tables.foo_no_seq,
+ )
+
def test_direct_params(self):
t = self._fixture()
self._test(
@@ -476,7 +522,11 @@ class TableInsertTest(fixtures.TablesTest):
returning=(testing.db.dialect.default_sequence_base, 5),
)
- @testing.requires.emulated_lastrowid_even_with_sequences
+ # there's a non optional Sequence in the metadata, which if the dialect
+ # supports sequences, it means the CREATE TABLE should *not* have
+ # autoincrement, so the INSERT below would fail because the "t" fixture
+ # does not indicate the Sequence
+ @testing.fails_if(testing.requires.sequences)
@testing.requires.emulated_lastrowid
def test_implicit_pk(self):
t = self._fixture()
@@ -486,7 +536,7 @@ class TableInsertTest(fixtures.TablesTest):
inserted_primary_key=(),
)
- @testing.requires.emulated_lastrowid_even_with_sequences
+ @testing.fails_if(testing.requires.sequences)
@testing.requires.emulated_lastrowid
def test_implicit_pk_multi_rows(self):
t = self._fixture()
@@ -500,7 +550,7 @@ class TableInsertTest(fixtures.TablesTest):
[(1, "d1", 5), (2, "d2", 6), (3, "d3", 7)],
)
- @testing.requires.emulated_lastrowid_even_with_sequences
+ @testing.fails_if(testing.requires.sequences)
@testing.requires.emulated_lastrowid
def test_implicit_pk_inline(self):
t = self._fixture()
diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py
index 7d60dd475..20aa1fb3a 100644
--- a/test/sql/test_returning.py
+++ b/test/sql/test_returning.py
@@ -430,6 +430,30 @@ class ReturnDefaultsTest(fixtures.TablesTest):
[None],
)
+ def test_insert_sql_expr(self, connection):
+ from sqlalchemy import literal
+
+ t1 = self.tables.t1
+ result = connection.execute(
+ t1.insert().return_defaults().values(insdef=literal(10) + 5)
+ )
+
+ eq_(
+ result.returned_defaults._mapping,
+ {"id": 1, "data": None, "insdef": 15, "upddef": None},
+ )
+
+ def test_update_sql_expr(self, connection):
+ from sqlalchemy import literal
+
+ t1 = self.tables.t1
+ connection.execute(t1.insert().values(upddef=1))
+ result = connection.execute(
+ t1.update().values(upddef=literal(10) + 5).return_defaults()
+ )
+
+ eq_(result.returned_defaults._mapping, {"upddef": 15})
+
def test_insert_non_default_plus_default(self, connection):
t1 = self.tables.t1
result = connection.execute(
diff --git a/test/sql/test_sequences.py b/test/sql/test_sequences.py
index 8d894f9f3..ee7c77a93 100644
--- a/test/sql/test_sequences.py
+++ b/test/sql/test_sequences.py
@@ -231,38 +231,29 @@ class SequenceExecTest(fixtures.TestBase):
connection.execute(t1.insert().values(x=s.next_value()))
self._assert_seq_result(connection.scalar(t1.select()))
- @testing.requires.no_lastrowid_support
@testing.provide_metadata
- def test_inserted_pk_no_returning_no_lastrowid(self):
+ def test_inserted_pk_no_returning(self):
"""test inserted_primary_key contains [None] when
pk_col=next_value(), implicit returning is not used."""
+ # I'm not really sure what this test wants to accomlish.
+
metadata = self.metadata
t1 = Table("t", metadata, Column("x", Integer, primary_key=True))
- t1.create(testing.db)
+ s = Sequence("my_sequence_here", metadata=metadata)
e = engines.testing_engine(options={"implicit_returning": False})
- s = Sequence("my_sequence")
with e.connect() as conn:
- r = conn.execute(t1.insert().values(x=s.next_value()))
- eq_(r.inserted_primary_key, [None])
- @testing.requires.supports_lastrowid
- @testing.requires.supports_lastrowid_for_expressions
- @testing.provide_metadata
- def test_inserted_pk_no_returning_w_lastrowid(self):
- """test inserted_primary_key contains the pk when
- pk_col=next_value(), lastrowid is supported."""
-
- metadata = self.metadata
- t1 = Table("t", metadata, Column("x", Integer, primary_key=True,),)
- t1.create(testing.db)
- e = engines.testing_engine(options={"implicit_returning": False})
- s = Sequence("my_sequence")
+ t1.create(conn)
+ s.create(conn)
- with e.connect() as conn:
r = conn.execute(t1.insert().values(x=s.next_value()))
- self._assert_seq_result(r.inserted_primary_key[0])
+
+ if testing.requires.emulated_lastrowid_even_with_sequences.enabled:
+ eq_(r.inserted_primary_key, (1,))
+ else:
+ eq_(r.inserted_primary_key, (None,))
@testing.requires.returning
@testing.provide_metadata