summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-08-02 16:18:18 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-08-05 10:07:15 -0400
commit82a1d4096fbfe94e2fa626d65d5c3beb2c6afa37 (patch)
treebaca62a1a0784f192e65402f824319b0403c6847
parent0027b3a4bc54599ac8102a4a3d81d8007738903e (diff)
downloadsqlalchemy-82a1d4096fbfe94e2fa626d65d5c3beb2c6afa37.tar.gz
include column.default, column.onupdate in eager_defaults
Fixed bug in the behavior of the :paramref:`_orm.Mapper.eager_defaults` parameter such that client-side SQL default or onupdate expressions in the table definition alone will trigger a fetch operation using RETURNING or SELECT when the ORM emits an INSERT or UPDATE for the row. Previously, only server side defaults established as part of table DDL and/or server-side onupdate expressions would trigger this fetch, even though client-side SQL expressions would be included when the fetch was rendered. Fixes: #7438 Change-Id: Iba719298ba4a26d185edec97ba77d2d54585e5a4
-rw-r--r--doc/build/changelog/unreleased_20/7438.rst11
-rw-r--r--lib/sqlalchemy/orm/mapper.py72
-rw-r--r--lib/sqlalchemy/orm/persistence.py26
-rw-r--r--lib/sqlalchemy/sql/dml.py22
-rw-r--r--test/orm/test_unitofworkv2.py213
-rw-r--r--test/sql/test_insert.py27
6 files changed, 331 insertions, 40 deletions
diff --git a/doc/build/changelog/unreleased_20/7438.rst b/doc/build/changelog/unreleased_20/7438.rst
new file mode 100644
index 000000000..9aca39171
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/7438.rst
@@ -0,0 +1,11 @@
+.. change::
+ :tags: bug, orm
+ :tickets: 7438
+
+ Fixed bug in the behavior of the :paramref:`_orm.Mapper.eager_defaults`
+ parameter such that client-side SQL default or onupdate expressions in the
+ table definition alone will trigger a fetch operation using RETURNING or
+ SELECT when the ORM emits an INSERT or UPDATE for the row. Previously, only
+ server side defaults established as part of table DDL and/or server-side
+ onupdate expressions would trigger this fetch, even though client-side SQL
+ expressions would be included when the fetch was rendered.
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 769b1b623..6a95030b5 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -28,6 +28,7 @@ from typing import cast
from typing import Collection
from typing import Deque
from typing import Dict
+from typing import FrozenSet
from typing import Generic
from typing import Iterable
from typing import Iterator
@@ -2397,15 +2398,21 @@ class Mapper(
)
@HasMemoized.memoized_attribute
- def _server_default_cols(self):
+ def _server_default_cols(
+ self,
+ ) -> Mapping[FromClause, FrozenSet[Column[Any]]]:
return dict(
(
table,
frozenset(
[
- col.key
- for col in columns
+ col
+ for col in cast("Iterable[Column[Any]]", columns)
if col.server_default is not None
+ or (
+ col.default is not None
+ and col.default.is_clause_element
+ )
]
),
)
@@ -2413,35 +2420,60 @@ class Mapper(
)
@HasMemoized.memoized_attribute
- def _server_default_plus_onupdate_propkeys(self):
- result = set()
-
- for table, columns in self._cols_by_table.items():
- for col in columns:
- if (
- col.server_default is not None
- or col.server_onupdate is not None
- ) and col in self._columntoproperty:
- result.add(self._columntoproperty[col].key)
-
- return result
-
- @HasMemoized.memoized_attribute
- def _server_onupdate_default_cols(self):
+ def _server_onupdate_default_cols(
+ self,
+ ) -> Mapping[FromClause, FrozenSet[Column[Any]]]:
return dict(
(
table,
frozenset(
[
- col.key
- for col in columns
+ col
+ for col in cast("Iterable[Column[Any]]", columns)
if col.server_onupdate is not None
+ or (
+ col.onupdate is not None
+ and col.onupdate.is_clause_element
+ )
]
),
)
for table, columns in self._cols_by_table.items()
)
+ @HasMemoized.memoized_attribute
+ def _server_default_col_keys(self) -> Mapping[FromClause, FrozenSet[str]]:
+ return {
+ table: frozenset(col.key for col in cols if col.key is not None)
+ for table, cols in self._server_default_cols.items()
+ }
+
+ @HasMemoized.memoized_attribute
+ def _server_onupdate_default_col_keys(
+ self,
+ ) -> Mapping[FromClause, FrozenSet[str]]:
+ return {
+ table: frozenset(col.key for col in cols if col.key is not None)
+ for table, cols in self._server_onupdate_default_cols.items()
+ }
+
+ @HasMemoized.memoized_attribute
+ def _server_default_plus_onupdate_propkeys(self) -> Set[str]:
+ result: Set[str] = set()
+
+ col_to_property = self._columntoproperty
+ for table, columns in self._server_default_cols.items():
+ result.update(
+ col_to_property[col].key
+ for col in columns.intersection(col_to_property)
+ )
+ for table, columns in self._server_onupdate_default_cols.items():
+ result.update(
+ col_to_property[col].key
+ for col in columns.intersection(col_to_property)
+ )
+ return result
+
@HasMemoized.memoized_instancemethod
def __clause_element__(self):
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index c10f4701e..7cd66513b 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -561,9 +561,9 @@ def _collect_insert_commands(
has_all_pks = mapper._pk_keys_by_table[table].issubset(params)
if mapper.base_mapper.eager_defaults:
- has_all_defaults = mapper._server_default_cols[table].issubset(
- params
- )
+ has_all_defaults = mapper._server_default_col_keys[
+ table
+ ].issubset(params)
else:
has_all_defaults = True
else:
@@ -659,7 +659,7 @@ def _collect_update_commands(
if mapper.base_mapper.eager_defaults:
has_all_defaults = (
- mapper._server_onupdate_default_cols[table]
+ mapper._server_onupdate_default_col_keys[table]
).issubset(params)
else:
has_all_defaults = True
@@ -930,16 +930,20 @@ def _emit_update_statements(
return_defaults = False
if not has_all_pks:
- statement = statement.return_defaults()
+ statement = statement.return_defaults(*mapper._pks_by_table[table])
return_defaults = True
- elif (
+
+ if (
bookkeeping
and not has_all_defaults
and mapper.base_mapper.eager_defaults
):
- statement = statement.return_defaults()
+ statement = statement.return_defaults(
+ *mapper._server_onupdate_default_cols[table]
+ )
return_defaults = True
- elif mapper.version_id_col is not None:
+
+ if mapper.version_id_col is not None:
statement = statement.return_defaults(mapper.version_id_col)
return_defaults = True
@@ -1171,8 +1175,10 @@ def _emit_insert_statements(
do_executemany = False
if not has_all_defaults and base_mapper.eager_defaults:
- statement = statement.return_defaults()
- elif mapper.version_id_col is not None:
+ statement = statement.return_defaults(
+ *mapper._server_default_cols[table]
+ )
+ if mapper.version_id_col is not None:
statement = statement.return_defaults(mapper.version_id_col)
elif do_executemany:
statement = statement.return_defaults(*table.primary_key)
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index 76a16eb1c..9d489ed98 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -989,10 +989,26 @@ class ValuesBase(UpdateBase):
:attr:`_engine.CursorResult.inserted_primary_key_rows`
"""
+
+ if self._return_defaults:
+ # note _return_defaults_columns = () means return all columns,
+ # so if we have been here before, only update collection if there
+ # are columns in the collection
+ if self._return_defaults_columns and cols:
+ self._return_defaults_columns = tuple(
+ set(self._return_defaults_columns).union(
+ coercions.expect(roles.ColumnsClauseRole, c)
+ for c in cols
+ )
+ )
+ else:
+ # set for all columns
+ self._return_defaults_columns = ()
+ else:
+ self._return_defaults_columns = tuple(
+ coercions.expect(roles.ColumnsClauseRole, c) for c in cols
+ )
self._return_defaults = True
- self._return_defaults_columns = tuple(
- coercions.expect(roles.ColumnsClauseRole, c) for c in cols
- )
return self
diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py
index 68099a7a0..dd3b88915 100644
--- a/test/orm/test_unitofworkv2.py
+++ b/test/orm/test_unitofworkv2.py
@@ -2353,6 +2353,21 @@ class EagerDefaultsTest(fixtures.MappedTest):
Column("bar", Integer, server_onupdate=FetchedValue()),
)
+ Table(
+ "test3",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("foo", String(50), default=func.lower("HI")),
+ )
+
+ Table(
+ "test4",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("foo", Integer),
+ Column("bar", Integer, onupdate=text("5 + 3")),
+ )
+
@classmethod
def setup_classes(cls):
class Thing(cls.Basic):
@@ -2361,6 +2376,12 @@ class EagerDefaultsTest(fixtures.MappedTest):
class Thing2(cls.Basic):
pass
+ class Thing3(cls.Basic):
+ pass
+
+ class Thing4(cls.Basic):
+ pass
+
@classmethod
def setup_mappers(cls):
Thing = cls.classes.Thing
@@ -2375,7 +2396,19 @@ class EagerDefaultsTest(fixtures.MappedTest):
Thing2, cls.tables.test2, eager_defaults=True
)
- def test_insert_defaults_present(self):
+ Thing3 = cls.classes.Thing3
+
+ cls.mapper_registry.map_imperatively(
+ Thing3, cls.tables.test3, eager_defaults=True
+ )
+
+ Thing4 = cls.classes.Thing4
+
+ cls.mapper_registry.map_imperatively(
+ Thing4, cls.tables.test4, eager_defaults=True
+ )
+
+ def test_server_insert_defaults_present(self):
Thing = self.classes.Thing
s = fixture_session()
@@ -2388,7 +2421,10 @@ class EagerDefaultsTest(fixtures.MappedTest):
s.flush,
CompiledSQL(
"INSERT INTO test (id, foo) VALUES (:id, :foo)",
- [{"foo": 5, "id": 1}, {"foo": 10, "id": 2}],
+ [
+ {"foo": 5, "id": 1},
+ {"foo": 10, "id": 2},
+ ],
),
)
@@ -2398,7 +2434,7 @@ class EagerDefaultsTest(fixtures.MappedTest):
self.assert_sql_count(testing.db, go, 0)
- def test_insert_defaults_present_as_expr(self):
+ def test_server_insert_defaults_present_as_expr(self):
Thing = self.classes.Thing
s = fixture_session()
@@ -2414,13 +2450,15 @@ class EagerDefaultsTest(fixtures.MappedTest):
testing.db,
s.flush,
CompiledSQL(
- "INSERT INTO test (id, foo) VALUES (%(id)s, 2 + 5) "
+ "INSERT INTO test (id, foo) "
+ "VALUES (%(id)s, 2 + 5) "
"RETURNING test.foo",
[{"id": 1}],
dialect="postgresql",
),
CompiledSQL(
- "INSERT INTO test (id, foo) VALUES (%(id)s, 5 + 5) "
+ "INSERT INTO test (id, foo) "
+ "VALUES (%(id)s, 5 + 5) "
"RETURNING test.foo",
[{"id": 2}],
dialect="postgresql",
@@ -2457,7 +2495,7 @@ class EagerDefaultsTest(fixtures.MappedTest):
self.assert_sql_count(testing.db, go, 0)
- def test_insert_defaults_nonpresent(self):
+ def test_server_insert_defaults_nonpresent(self):
Thing = self.classes.Thing
s = fixture_session()
@@ -2516,7 +2554,73 @@ class EagerDefaultsTest(fixtures.MappedTest):
),
)
- def test_update_defaults_nonpresent(self):
+ def test_clientsql_insert_defaults_nonpresent(self):
+ Thing3 = self.classes.Thing3
+ s = fixture_session()
+
+ t1, t2 = (Thing3(id=1), Thing3(id=2))
+
+ s.add_all([t1, t2])
+
+ self.assert_sql_execution(
+ testing.db,
+ s.commit,
+ Conditional(
+ testing.db.dialect.insert_returning,
+ [
+ Conditional(
+ testing.db.dialect.insert_executemany_returning,
+ [
+ CompiledSQL(
+ "INSERT INTO test3 (id, foo) "
+ "VALUES (%(id)s, lower(%(lower_1)s)) "
+ "RETURNING test3.foo",
+ [{"id": 1}, {"id": 2}],
+ dialect="postgresql",
+ ),
+ ],
+ [
+ CompiledSQL(
+ "INSERT INTO test3 (id, foo) "
+ "VALUES (%(id)s, lower(%(lower_1)s)) "
+ "RETURNING test3.foo",
+ [{"id": 1}],
+ dialect="postgresql",
+ ),
+ CompiledSQL(
+ "INSERT INTO test3 (id, foo) "
+ "VALUES (%(id)s, lower(%(lower_1)s)) "
+ "RETURNING test3.foo",
+ [{"id": 2}],
+ dialect="postgresql",
+ ),
+ ],
+ ),
+ ],
+ [
+ CompiledSQL(
+ "INSERT INTO test3 (id, foo) "
+ "VALUES (:id, lower(:lower_1))",
+ [
+ {"id": 1, "lower_1": "HI"},
+ {"id": 2, "lower_1": "HI"},
+ ],
+ ),
+ CompiledSQL(
+ "SELECT test3.foo AS test3_foo "
+ "FROM test3 WHERE test3.id = :pk_1",
+ [{"pk_1": 1}],
+ ),
+ CompiledSQL(
+ "SELECT test3.foo AS test3_foo "
+ "FROM test3 WHERE test3.id = :pk_1",
+ [{"pk_1": 2}],
+ ),
+ ],
+ ),
+ )
+
+ def test_server_update_defaults_nonpresent(self):
Thing2 = self.classes.Thing2
s = fixture_session()
@@ -2611,6 +2715,101 @@ class EagerDefaultsTest(fixtures.MappedTest):
self.assert_sql_count(testing.db, go, 0)
+ def test_clientsql_update_defaults_nonpresent(self):
+ Thing4 = self.classes.Thing4
+ s = fixture_session()
+
+ t1, t2, t3, t4 = (
+ Thing4(id=1, foo=1),
+ Thing4(id=2, foo=2),
+ Thing4(id=3, foo=3),
+ Thing4(id=4, foo=4),
+ )
+
+ s.add_all([t1, t2, t3, t4])
+ s.flush()
+
+ t1.foo = 5
+ t2.foo = 6
+ t2.bar = 10
+ t3.foo = 7
+ t4.foo = 8
+ t4.bar = 12
+
+ self.assert_sql_execution(
+ testing.db,
+ s.flush,
+ Conditional(
+ testing.db.dialect.update_returning,
+ [
+ CompiledSQL(
+ "UPDATE test4 SET foo=%(foo)s, bar=5 + 3 "
+ "WHERE test4.id = %(test4_id)s RETURNING test4.bar",
+ [{"foo": 5, "test4_id": 1}],
+ dialect="postgresql",
+ ),
+ CompiledSQL(
+ "UPDATE test4 SET foo=%(foo)s, bar=%(bar)s "
+ "WHERE test4.id = %(test4_id)s",
+ [{"foo": 6, "bar": 10, "test4_id": 2}],
+ dialect="postgresql",
+ ),
+ CompiledSQL(
+ "UPDATE test4 SET foo=%(foo)s, bar=5 + 3 WHERE "
+ "test4.id = %(test4_id)s RETURNING test4.bar",
+ [{"foo": 7, "test4_id": 3}],
+ dialect="postgresql",
+ ),
+ CompiledSQL(
+ "UPDATE test4 SET foo=%(foo)s, bar=%(bar)s WHERE "
+ "test4.id = %(test4_id)s",
+ [{"foo": 8, "bar": 12, "test4_id": 4}],
+ dialect="postgresql",
+ ),
+ ],
+ [
+ CompiledSQL(
+ "UPDATE test4 SET foo=:foo, bar=5 + 3 "
+ "WHERE test4.id = :test4_id",
+ [{"foo": 5, "test4_id": 1}],
+ ),
+ CompiledSQL(
+ "UPDATE test4 SET foo=:foo, bar=:bar "
+ "WHERE test4.id = :test4_id",
+ [{"foo": 6, "bar": 10, "test4_id": 2}],
+ ),
+ CompiledSQL(
+ "UPDATE test4 SET foo=:foo, bar=5 + 3 "
+ "WHERE test4.id = :test4_id",
+ [{"foo": 7, "test4_id": 3}],
+ ),
+ CompiledSQL(
+ "UPDATE test4 SET foo=:foo, bar=:bar "
+ "WHERE test4.id = :test4_id",
+ [{"foo": 8, "bar": 12, "test4_id": 4}],
+ ),
+ CompiledSQL(
+ "SELECT test4.bar AS test4_bar FROM test4 "
+ "WHERE test4.id = :pk_1",
+ [{"pk_1": 1}],
+ ),
+ CompiledSQL(
+ "SELECT test4.bar AS test4_bar FROM test4 "
+ "WHERE test4.id = :pk_1",
+ [{"pk_1": 3}],
+ ),
+ ],
+ ),
+ )
+
+ def go():
+ eq_(t1.bar, 8)
+ eq_(t2.bar, 10)
+ eq_(t3.bar, 8)
+ eq_(t4.bar, 12)
+
+ self.assert_sql_count(testing.db, go, 0)
+
def test_update_defaults_present_as_expr(self):
Thing2 = self.classes.Thing2
s = fixture_session()
diff --git a/test/sql/test_insert.py b/test/sql/test_insert.py
index 071f595f3..61e0783e4 100644
--- a/test/sql/test_insert.py
+++ b/test/sql/test_insert.py
@@ -1,4 +1,7 @@
#! coding:utf-8
+from __future__ import annotations
+
+from typing import Tuple
from sqlalchemy import bindparam
from sqlalchemy import Column
@@ -66,6 +69,30 @@ class _InsertTestBase:
class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL):
__dialect__ = "default"
+ @testing.combinations(
+ ((), ("z",), ()),
+ (("x",), (), ()),
+ (("x",), ("y",), ("x", "y")),
+ (("x", "y"), ("y",), ("x", "y")),
+ )
+ def test_return_defaults_generative(
+ self,
+ initial_keys: Tuple[str, ...],
+ second_keys: Tuple[str, ...],
+ expected_keys: Tuple[str, ...],
+ ):
+ t = table("foo", column("x"), column("y"), column("z"))
+
+ initial_cols = tuple(t.c[initial_keys])
+ second_cols = tuple(t.c[second_keys])
+ expected = set(t.c[expected_keys])
+
+ stmt = t.insert().return_defaults(*initial_cols)
+ eq_(stmt._return_defaults_columns, initial_cols)
+ stmt = stmt.return_defaults(*second_cols)
+ assert isinstance(stmt._return_defaults_columns, tuple)
+ eq_(set(stmt._return_defaults_columns), expected)
+
def test_binds_that_match_columns(self):
"""test bind params named after column names
replace the normal SET/VALUES generation."""