summaryrefslogtreecommitdiff
path: root/test/orm
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2022-12-16 16:56:56 +0000
committerGerrit Code Review <gerrit@ci3.zzzcomputing.com>2022-12-16 16:56:56 +0000
commitbd5a4611c34d25cf21607544c01ce7fcb886e0a9 (patch)
tree76ef80b434565b6a9ba3e5ddc113f6581a2d3ede /test/orm
parent5bb48511a126b66ed06abf76d706ab707afafbf1 (diff)
parent8e4e325319eaadb81cc1b6e8c8db7cc1a6b920bd (diff)
downloadsqlalchemy-bd5a4611c34d25cf21607544c01ce7fcb886e0a9.tar.gz
Merge "add eager_defaults="auto" for inserts" into main
Diffstat (limited to 'test/orm')
-rw-r--r--test/orm/inheritance/test_basic.py64
-rw-r--r--test/orm/test_defaults.py9
-rw-r--r--test/orm/test_expire.py5
-rw-r--r--test/orm/test_unitofwork.py21
-rw-r--r--test/orm/test_unitofworkv2.py269
5 files changed, 321 insertions, 47 deletions
diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py
index 0ba900798..5803d51bc 100644
--- a/test/orm/inheritance/test_basic.py
+++ b/test/orm/inheritance/test_basic.py
@@ -3074,9 +3074,15 @@ class OptimizedLoadTest(fixtures.MappedTest):
eq_(s1test.comp, Comp("ham", "cheese"))
eq_(s2test.comp, Comp("bacon", "eggs"))
- def test_load_expired_on_pending(self):
+ @testing.variation("eager_defaults", [True, False])
+ def test_load_expired_on_pending(self, eager_defaults):
base, sub = self.tables.base, self.tables.sub
+ expected_eager_defaults = bool(eager_defaults)
+ expect_returning = (
+ expected_eager_defaults and testing.db.dialect.insert_returning
+ )
+
class Base(fixtures.BasicEntity):
pass
@@ -3084,7 +3090,11 @@ class OptimizedLoadTest(fixtures.MappedTest):
pass
self.mapper_registry.map_imperatively(
- Base, base, polymorphic_on=base.c.type, polymorphic_identity="base"
+ Base,
+ base,
+ polymorphic_on=base.c.type,
+ polymorphic_identity="base",
+ eager_defaults=bool(eager_defaults),
)
self.mapper_registry.map_imperatively(
Sub, sub, inherits=Base, polymorphic_identity="sub"
@@ -3095,13 +3105,30 @@ class OptimizedLoadTest(fixtures.MappedTest):
self.assert_sql_execution(
testing.db,
sess.flush,
- CompiledSQL(
- "INSERT INTO base (data, type) VALUES (:data, :type)",
- [{"data": "s1", "type": "sub"}],
- ),
- CompiledSQL(
- "INSERT INTO sub (id, sub) VALUES (:id, :sub)",
- lambda ctx: {"id": s1.id, "sub": None},
+ Conditional(
+ expect_returning,
+ [
+ CompiledSQL(
+ "INSERT INTO base (data, type) VALUES (:data, :type) "
+ "RETURNING base.id, base.counter",
+ [{"data": "s1", "type": "sub"}],
+ ),
+ CompiledSQL(
+ "INSERT INTO sub (id, sub) VALUES (:id, :sub) "
+ "RETURNING sub.subcounter, sub.subcounter2",
+ lambda ctx: {"id": s1.id, "sub": None},
+ ),
+ ],
+ [
+ CompiledSQL(
+ "INSERT INTO base (data, type) VALUES (:data, :type)",
+ [{"data": "s1", "type": "sub"}],
+ ),
+ CompiledSQL(
+ "INSERT INTO sub (id, sub) VALUES (:id, :sub)",
+ lambda ctx: {"id": s1.id, "sub": None},
+ ),
+ ],
),
)
@@ -3111,12 +3138,19 @@ class OptimizedLoadTest(fixtures.MappedTest):
self.assert_sql_execution(
testing.db,
go,
- CompiledSQL(
- "SELECT base.counter AS base_counter, "
- "sub.subcounter AS sub_subcounter, "
- "sub.subcounter2 AS sub_subcounter2 FROM base JOIN sub "
- "ON base.id = sub.id WHERE base.id = :pk_1",
- lambda ctx: {"pk_1": s1.id},
+ Conditional(
+ expect_returning,
+ [],
+ [
+ CompiledSQL(
+ "SELECT base.counter AS base_counter, "
+ "sub.subcounter AS sub_subcounter, "
+ "sub.subcounter2 AS sub_subcounter2 "
+ "FROM base JOIN sub "
+ "ON base.id = sub.id WHERE base.id = :pk_1",
+ lambda ctx: {"pk_1": s1.id},
+ ),
+ ],
),
)
diff --git a/test/orm/test_defaults.py b/test/orm/test_defaults.py
index e738689b8..fb6fba704 100644
--- a/test/orm/test_defaults.py
+++ b/test/orm/test_defaults.py
@@ -17,6 +17,7 @@ from sqlalchemy.testing.schema import Table
class TriggerDefaultsTest(fixtures.MappedTest):
__requires__ = ("row_triggers",)
+ __backend__ = True
@classmethod
def define_tables(cls, metadata):
@@ -39,6 +40,7 @@ class TriggerDefaultsTest(fixtures.MappedTest):
sa.schema.FetchedValue(),
sa.schema.FetchedValue(for_update=True),
),
+ implicit_returning=False,
)
dialect_name = testing.db.dialect.name
@@ -382,12 +384,7 @@ class ComputedDefaultsOnUpdateTest(fixtures.MappedTest):
asserter.assert_(
CompiledSQL(
"UPDATE test SET foo=:foo WHERE test.id = :test_id",
- [{"foo": 5, "test_id": 1}],
- enable_returning=False,
- ),
- CompiledSQL(
- "UPDATE test SET foo=:foo WHERE test.id = :test_id",
- [{"foo": 6, "test_id": 2}],
+ [{"foo": 5, "test_id": 1}, {"foo": 6, "test_id": 2}],
enable_returning=False,
),
CompiledSQL(
diff --git a/test/orm/test_expire.py b/test/orm/test_expire.py
index a5fd7533e..f851f3698 100644
--- a/test/orm/test_expire.py
+++ b/test/orm/test_expire.py
@@ -1826,7 +1826,9 @@ class LifecycleTest(fixtures.MappedTest):
def setup_mappers(cls):
cls.mapper_registry.map_imperatively(cls.classes.Data, cls.tables.data)
cls.mapper_registry.map_imperatively(
- cls.classes.DataFetched, cls.tables.data_fetched
+ cls.classes.DataFetched,
+ cls.tables.data_fetched,
+ eager_defaults=False,
)
cls.mapper_registry.map_imperatively(
cls.classes.DataDefer,
@@ -1886,7 +1888,6 @@ class LifecycleTest(fixtures.MappedTest):
def go():
eq_(d1.data, None)
- # this one is marked as "fetch" so we emit SQL
self.assert_sql_count(testing.db, go, 1)
def test_cols_missing_in_load(self):
diff --git a/test/orm/test_unitofwork.py b/test/orm/test_unitofwork.py
index 79d4adacf..5835ef65a 100644
--- a/test/orm/test_unitofwork.py
+++ b/test/orm/test_unitofwork.py
@@ -1135,7 +1135,8 @@ class DefaultTest(fixtures.MappedTest):
class Secondary(cls.Comparable):
pass
- def test_insert(self):
+ @testing.variation("eager_defaults", ["auto", True, False])
+ def test_insert(self, eager_defaults):
althohoval, hohoval, default_t, Hoho = (
self.other.althohoval,
self.other.hohoval,
@@ -1143,7 +1144,13 @@ class DefaultTest(fixtures.MappedTest):
self.classes.Hoho,
)
- self.mapper_registry.map_imperatively(Hoho, default_t)
+ mp = self.mapper_registry.map_imperatively(
+ Hoho,
+ default_t,
+ eager_defaults="auto"
+ if eager_defaults.auto
+ else bool(eager_defaults),
+ )
h1 = Hoho(hoho=althohoval)
h2 = Hoho(counter=12)
@@ -1162,12 +1169,18 @@ class DefaultTest(fixtures.MappedTest):
# test deferred load of attributes, one select per instance
self.assert_(h2.hoho == h4.hoho == h5.hoho == hohoval)
- self.sql_count_(3, go)
+ if mp._prefer_eager_defaults(testing.db.dialect, default_t):
+ self.sql_count_(0, go)
+ else:
+ self.sql_count_(3, go)
def go():
self.assert_(h1.counter == h4.counter == h5.counter == 7)
- self.sql_count_(1, go)
+ if mp._prefer_eager_defaults(testing.db.dialect, default_t):
+ self.sql_count_(0, go)
+ else:
+ self.sql_count_(1, go)
def go():
self.assert_(h3.counter == h2.counter == 12)
diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py
index 468d43063..ae47dfa4f 100644
--- a/test/orm/test_unitofworkv2.py
+++ b/test/orm/test_unitofworkv2.py
@@ -34,6 +34,7 @@ from sqlalchemy.testing import engines
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
+from sqlalchemy.testing import variation_fixture
from sqlalchemy.testing.assertsql import AllOf
from sqlalchemy.testing.assertsql import CompiledSQL
from sqlalchemy.testing.assertsql import Conditional
@@ -2077,7 +2078,7 @@ class BatchInsertsTest(fixtures.MappedTest, testing.AssertsExecutionResults):
class T(fixtures.ComparableEntity):
pass
- self.mapper_registry.map_imperatively(T, t)
+ mp = self.mapper_registry.map_imperatively(T, t)
sess = fixture_session()
sess.add_all(
[
@@ -2095,6 +2096,17 @@ class BatchInsertsTest(fixtures.MappedTest, testing.AssertsExecutionResults):
]
)
+ eager_defaults = mp._prefer_eager_defaults(
+ testing.db.dialect, mp.local_table
+ )
+
+ if eager_defaults:
+ tdef_col = ", t.def_"
+ tdef_returning = " RETURNING t.def_"
+ else:
+ tdef_col = ""
+ tdef_returning = ""
+
self.assert_sql_execution(
testing.db,
sess.flush,
@@ -2102,7 +2114,8 @@ class BatchInsertsTest(fixtures.MappedTest, testing.AssertsExecutionResults):
testing.db.dialect.insert_executemany_returning,
[
CompiledSQL(
- "INSERT INTO t (data) VALUES (:data) RETURNING t.id",
+ f"INSERT INTO t (data) VALUES (:data) "
+ f"RETURNING t.id{tdef_col}",
[{"data": "t1"}, {"data": "t2"}],
),
],
@@ -2116,7 +2129,8 @@ class BatchInsertsTest(fixtures.MappedTest, testing.AssertsExecutionResults):
],
),
CompiledSQL(
- "INSERT INTO t (id, data) VALUES (:id, :data)",
+ f"INSERT INTO t (id, data) "
+ f"VALUES (:id, :data){tdef_returning}",
[
{"data": "t3", "id": 3},
{"data": "t4", "id": 4},
@@ -2124,11 +2138,13 @@ class BatchInsertsTest(fixtures.MappedTest, testing.AssertsExecutionResults):
],
),
CompiledSQL(
- "INSERT INTO t (id, data) VALUES (:id, lower(:lower_1))",
+ f"INSERT INTO t (id, data) "
+ f"VALUES (:id, lower(:lower_1)){tdef_returning}",
{"lower_1": "t6", "id": 6},
),
CompiledSQL(
- "INSERT INTO t (id, data) VALUES (:id, :data)",
+ f"INSERT INTO t (id, data) "
+ f"VALUES (:id, :data){tdef_returning}",
[{"data": "t7", "id": 7}, {"data": "t8", "id": 8}],
),
CompiledSQL(
@@ -2139,7 +2155,8 @@ class BatchInsertsTest(fixtures.MappedTest, testing.AssertsExecutionResults):
],
),
CompiledSQL(
- "INSERT INTO t (id, data) VALUES (:id, :data)",
+ f"INSERT INTO t (id, data) "
+ f"VALUES (:id, :data){tdef_returning}",
{"data": "t11", "id": 11},
),
)
@@ -2385,30 +2402,30 @@ class EagerDefaultsTest(fixtures.MappedTest):
class Thing4(cls.Basic):
pass
- @classmethod
- def setup_mappers(cls):
- Thing = cls.classes.Thing
+ def setup_mappers(self):
+ eager_defaults = True
+ Thing = self.classes.Thing
- cls.mapper_registry.map_imperatively(
- Thing, cls.tables.test, eager_defaults=True
+ self.mapper_registry.map_imperatively(
+ Thing, self.tables.test, eager_defaults=eager_defaults
)
- Thing2 = cls.classes.Thing2
+ Thing2 = self.classes.Thing2
- cls.mapper_registry.map_imperatively(
- Thing2, cls.tables.test2, eager_defaults=True
+ self.mapper_registry.map_imperatively(
+ Thing2, self.tables.test2, eager_defaults=eager_defaults
)
- Thing3 = cls.classes.Thing3
+ Thing3 = self.classes.Thing3
- cls.mapper_registry.map_imperatively(
- Thing3, cls.tables.test3, eager_defaults=True
+ self.mapper_registry.map_imperatively(
+ Thing3, self.tables.test3, eager_defaults=eager_defaults
)
- Thing4 = cls.classes.Thing4
+ Thing4 = self.classes.Thing4
- cls.mapper_registry.map_imperatively(
- Thing4, cls.tables.test4, eager_defaults=True
+ self.mapper_registry.map_imperatively(
+ Thing4, self.tables.test4, eager_defaults=eager_defaults
)
def test_server_insert_defaults_present(self):
@@ -3111,6 +3128,218 @@ class EagerDefaultsTest(fixtures.MappedTest):
)
+class EagerDefaultsSettingTest(
+ testing.AssertsExecutionResults, fixtures.TestBase
+):
+ __backend__ = True
+
+ @variation_fixture("eager_defaults", ["unspecified", "auto", True, False])
+ def eager_defaults_variations(self, request):
+ yield request.param
+
+ @variation_fixture("implicit_returning", [True, False])
+ def implicit_returning_variations(self, request):
+ yield request.param
+
+ @testing.fixture
+ def define_tables(
+ self, metadata, connection, implicit_returning_variations
+ ):
+ implicit_returning = bool(implicit_returning_variations)
+
+ t = Table(
+ "test",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column(
+ "foo",
+ Integer,
+ server_default="3",
+ ),
+ Column("bar", Integer, server_onupdate=FetchedValue()),
+ implicit_returning=implicit_returning,
+ )
+ metadata.create_all(connection)
+ return t
+
+ @testing.fixture
+ def setup_mappers(
+ self, define_tables, eager_defaults_variations, registry
+ ):
+ class Thing:
+ pass
+
+ if eager_defaults_variations.unspecified:
+ registry.map_imperatively(Thing, define_tables)
+ else:
+ eager_defaults = (
+ "auto"
+ if eager_defaults_variations.auto
+ else bool(eager_defaults_variations)
+ )
+ registry.map_imperatively(
+ Thing, define_tables, eager_defaults=eager_defaults
+ )
+ return Thing
+
+ def test_eager_default_setting_inserts(
+ self,
+ setup_mappers,
+ eager_defaults_variations,
+ implicit_returning_variations,
+ connection,
+ ):
+ Thing = setup_mappers
+ s = Session(connection)
+
+ t1, t2 = (Thing(id=1, bar=6), Thing(id=2, bar=6))
+
+ s.add_all([t1, t2])
+
+ expected_eager_defaults = eager_defaults_variations.eager_defaults or (
+ (
+ eager_defaults_variations.auto
+ or eager_defaults_variations.unspecified
+ )
+ and connection.dialect.insert_executemany_returning
+ and bool(implicit_returning_variations)
+ )
+ expect_returning = (
+ expected_eager_defaults
+ and connection.dialect.insert_returning
+ and bool(implicit_returning_variations)
+ )
+
+ with self.sql_execution_asserter(connection) as asserter:
+ s.flush()
+
+ asserter.assert_(
+ Conditional(
+ expect_returning,
+ [
+ CompiledSQL(
+ "INSERT INTO test (id, bar) VALUES (:id, :bar) "
+ "RETURNING test.foo",
+ [
+ {"id": 1, "bar": 6},
+ {"id": 2, "bar": 6},
+ ],
+ )
+ ],
+ [
+ CompiledSQL(
+ "INSERT INTO test (id, bar) VALUES (:id, :bar)",
+ [
+ {"id": 1, "bar": 6},
+ {"id": 2, "bar": 6},
+ ],
+ ),
+ Conditional(
+ expected_eager_defaults and not expect_returning,
+ [
+ CompiledSQL(
+ "SELECT test.foo AS test_foo "
+ "FROM test WHERE test.id = :pk_1",
+ [{"pk_1": 1}],
+ ),
+ CompiledSQL(
+ "SELECT test.foo AS test_foo "
+ "FROM test WHERE test.id = :pk_1",
+ [{"pk_1": 2}],
+ ),
+ ],
+ [],
+ ),
+ ],
+ )
+ )
+
+ def test_eager_default_setting_updates(
+ self,
+ setup_mappers,
+ eager_defaults_variations,
+ implicit_returning_variations,
+ connection,
+ ):
+ Thing = setup_mappers
+ s = Session(connection)
+
+ t1, t2 = (Thing(id=1, foo=5), Thing(id=2, foo=5))
+
+ s.add_all([t1, t2])
+ s.flush()
+
+ expected_eager_defaults = eager_defaults_variations.eager_defaults
+ expect_returning = (
+ expected_eager_defaults
+ and connection.dialect.update_returning
+ and bool(implicit_returning_variations)
+ )
+
+ t1.foo = 7
+ t2.foo = 12
+
+ with self.sql_execution_asserter(connection) as asserter:
+ s.flush()
+
+ asserter.assert_(
+ Conditional(
+ expect_returning,
+ [
+ CompiledSQL(
+ "UPDATE test SET foo=:foo WHERE test.id = :test_id "
+ "RETURNING test.bar",
+ [
+ {"test_id": 1, "foo": 7},
+ ],
+ ),
+ CompiledSQL(
+ "UPDATE test SET foo=:foo WHERE test.id = :test_id "
+ "RETURNING test.bar",
+ [
+ {"test_id": 2, "foo": 12},
+ ],
+ ),
+ ],
+ [
+ Conditional(
+ expected_eager_defaults and not expect_returning,
+ [
+ CompiledSQL(
+ "UPDATE test SET foo=:foo "
+ "WHERE test.id = :test_id",
+ [
+ {"test_id": 1, "foo": 7},
+ {"test_id": 2, "foo": 12},
+ ],
+ ),
+ CompiledSQL(
+ "SELECT test.bar AS test_bar "
+ "FROM test WHERE test.id = :pk_1",
+ [{"pk_1": 1}],
+ ),
+ CompiledSQL(
+ "SELECT test.bar AS test_bar "
+ "FROM test WHERE test.id = :pk_1",
+ [{"pk_1": 2}],
+ ),
+ ],
+ [
+ CompiledSQL(
+ "UPDATE test SET foo=:foo "
+ "WHERE test.id = :test_id",
+ [
+ {"test_id": 1, "foo": 7},
+ {"test_id": 2, "foo": 12},
+ ],
+ ),
+ ],
+ ),
+ ],
+ )
+ )
+
+
class TypeWoBoolTest(fixtures.MappedTest, testing.AssertsExecutionResults):
"""test support for custom datatypes that return a non-__bool__ value
when compared via __eq__(), eg. ticket 3469"""