diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2015-12-14 17:24:47 -0500 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2015-12-14 17:30:21 -0500 |
| commit | 0e4c4d7efc08d04c3c0ae960428b08ada37e4a91 (patch) | |
| tree | 4421c6681b9bc6025c5baccffbe5d61b901c48da /test/orm | |
| parent | 7d96ad4d535dc02a8ab1384df1db94dea2a045b5 (diff) | |
| download | sqlalchemy-0e4c4d7efc08d04c3c0ae960428b08ada37e4a91.tar.gz | |
- Fixed bug in :meth:`.Update.return_defaults` which would cause all
insert-default holding columns not otherwise included in the SET
clause (such as primary key cols) to get rendered into the RETURNING
even though this is an UPDATE.
- Major fixes to the :paramref:`.Mapper.eager_defaults` flag, this
flag would not be honored correctly in the case that multiple
UPDATE statements were to be emitted, either as part of a flush
or a bulk update operation. Additionally, RETURNING
would be emitted unnecessarily within update statements.
fixes #3609
Diffstat (limited to 'test/orm')
| -rw-r--r-- | test/orm/test_unitofworkv2.py | 447 | ||||
| -rw-r--r-- | test/orm/test_versioning.py | 19 |
2 files changed, 462 insertions, 4 deletions
diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py index 09240dfdb..c8ce13c91 100644 --- a/test/orm/test_unitofworkv2.py +++ b/test/orm/test_unitofworkv2.py @@ -5,7 +5,8 @@ from sqlalchemy.testing.schema import Table, Column from test.orm import _fixtures from sqlalchemy import exc, util from sqlalchemy.testing import fixtures, config -from sqlalchemy import Integer, String, ForeignKey, func, literal +from sqlalchemy import Integer, String, ForeignKey, func, \ + literal, FetchedValue, text from sqlalchemy.orm import mapper, relationship, backref, \ create_session, unitofwork, attributes,\ Session, exc as orm_exc @@ -1848,6 +1849,450 @@ class NoAttrEventInFlushTest(fixtures.MappedTest): eq_(t1.returning_val, 5) +class EagerDefaultsTest(fixtures.MappedTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + 'test', metadata, + Column('id', Integer, primary_key=True), + Column('foo', Integer, server_default="3") + ) + + Table( + 'test2', metadata, + Column('id', Integer, primary_key=True), + Column('foo', Integer), + Column('bar', Integer, server_onupdate=FetchedValue()) + ) + + @classmethod + def setup_classes(cls): + class Thing(cls.Basic): + pass + + class Thing2(cls.Basic): + pass + + @classmethod + def setup_mappers(cls): + Thing = cls.classes.Thing + + mapper(Thing, cls.tables.test, eager_defaults=True) + + Thing2 = cls.classes.Thing2 + + mapper(Thing2, cls.tables.test2, eager_defaults=True) + + def test_insert_defaults_present(self): + Thing = self.classes.Thing + s = Session() + + t1, t2 = ( + Thing(id=1, foo=5), + Thing(id=2, foo=10) + ) + + s.add_all([t1, t2]) + + self.assert_sql_execution( + testing.db, + s.flush, + CompiledSQL( + "INSERT INTO test (id, foo) VALUES (:id, :foo)", + [{'foo': 5, 'id': 1}, {'foo': 10, 'id': 2}] + ), + ) + + def go(): + eq_(t1.foo, 5) + eq_(t2.foo, 10) + + self.assert_sql_count(testing.db, go, 0) + + def test_insert_defaults_present_as_expr(self): + Thing = self.classes.Thing + s = Session() + + t1, t2 = ( + Thing(id=1, foo=text("2 + 5")), + Thing(id=2, foo=text("5 + 5")) + ) + + s.add_all([t1, t2]) + + if testing.db.dialect.implicit_returning: + self.assert_sql_execution( + testing.db, + s.flush, + CompiledSQL( + "INSERT INTO test (id, foo) VALUES (%(id)s, 2 + 5) " + "RETURNING test.foo", + [{'id': 1}], + dialect='postgresql' + ), + CompiledSQL( + "INSERT INTO test (id, foo) VALUES (%(id)s, 5 + 5) " + "RETURNING test.foo", + [{'id': 2}], + dialect='postgresql' + ) + ) + + else: + self.assert_sql_execution( + testing.db, + s.flush, + CompiledSQL( + "INSERT INTO test (id, foo) VALUES (:id, 2 + 5)", + [{'id': 1}] + ), + CompiledSQL( + "INSERT INTO test (id, foo) VALUES (:id, 5 + 5)", + [{'id': 2}] + ), + CompiledSQL( + "SELECT test.foo AS test_foo FROM test " + "WHERE test.id = :param_1", + [{'param_1': 1}] + ), + CompiledSQL( + "SELECT test.foo AS test_foo FROM test " + "WHERE test.id = :param_1", + [{'param_1': 2}] + ), + ) + + def go(): + eq_(t1.foo, 7) + eq_(t2.foo, 10) + + self.assert_sql_count(testing.db, go, 0) + + def test_insert_defaults_nonpresent(self): + Thing = self.classes.Thing + s = Session() + + t1, t2 = ( + Thing(id=1), + Thing(id=2) + ) + + s.add_all([t1, t2]) + + if testing.db.dialect.implicit_returning: + self.assert_sql_execution( + testing.db, + s.commit, + CompiledSQL( + "INSERT INTO test (id) VALUES (%(id)s) RETURNING test.foo", + [{'id': 1}], + dialect='postgresql' + ), + CompiledSQL( + "INSERT INTO test (id) VALUES (%(id)s) RETURNING test.foo", + [{'id': 2}], + dialect='postgresql' + ), + ) + else: + self.assert_sql_execution( + testing.db, + s.commit, + CompiledSQL( + "INSERT INTO test (id) VALUES (:id)", + [{'id': 1}, {'id': 2}] + ), + CompiledSQL( + "SELECT test.foo AS test_foo FROM test " + "WHERE test.id = :param_1", + [{'param_1': 1}] + ), + CompiledSQL( + "SELECT test.foo AS test_foo FROM test " + "WHERE test.id = :param_1", + [{'param_1': 2}] + ) + ) + + def test_update_defaults_nonpresent(self): + Thing2 = self.classes.Thing2 + s = Session() + + t1, t2, t3, t4 = ( + Thing2(id=1, foo=1, bar=2), + Thing2(id=2, foo=2, bar=3), + Thing2(id=3, foo=3, bar=4), + Thing2(id=4, foo=4, bar=5) + ) + + s.add_all([t1, t2, t3, t4]) + s.flush() + + t1.foo = 5 + t2.foo = 6 + t2.bar = 10 + t3.foo = 7 + t4.foo = 8 + t4.bar = 12 + + if testing.db.dialect.implicit_returning: + self.assert_sql_execution( + testing.db, + s.flush, + CompiledSQL( + "UPDATE test2 SET foo=%(foo)s " + "WHERE test2.id = %(test2_id)s " + "RETURNING test2.bar", + [{'foo': 5, 'test2_id': 1}], + dialect='postgresql' + ), + CompiledSQL( + "UPDATE test2 SET foo=%(foo)s, bar=%(bar)s " + "WHERE test2.id = %(test2_id)s", + [{'foo': 6, 'bar': 10, 'test2_id': 2}], + dialect='postgresql' + ), + CompiledSQL( + "UPDATE test2 SET foo=%(foo)s " + "WHERE test2.id = %(test2_id)s " + "RETURNING test2.bar", + [{'foo': 7, 'test2_id': 3}], + dialect='postgresql' + ), + CompiledSQL( + "UPDATE test2 SET foo=%(foo)s, bar=%(bar)s " + "WHERE test2.id = %(test2_id)s", + [{'foo': 8, 'bar': 12, 'test2_id': 4}], + dialect='postgresql' + ), + ) + else: + self.assert_sql_execution( + testing.db, + s.flush, + CompiledSQL( + "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id", + [{'foo': 5, 'test2_id': 1}] + ), + CompiledSQL( + "UPDATE test2 SET foo=:foo, bar=:bar " + "WHERE test2.id = :test2_id", + [{'foo': 6, 'bar': 10, 'test2_id': 2}], + ), + CompiledSQL( + "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id", + [{'foo': 7, 'test2_id': 3}] + ), + CompiledSQL( + "UPDATE test2 SET foo=:foo, bar=:bar " + "WHERE test2.id = :test2_id", + [{'foo': 8, 'bar': 12, 'test2_id': 4}], + ), + CompiledSQL( + "SELECT test2.bar AS test2_bar FROM test2 " + "WHERE test2.id = :param_1", + [{'param_1': 1}] + ), + CompiledSQL( + "SELECT test2.bar AS test2_bar FROM test2 " + "WHERE test2.id = :param_1", + [{'param_1': 3}] + ) + ) + + def go(): + eq_(t1.bar, 2) + eq_(t2.bar, 10) + eq_(t3.bar, 4) + eq_(t4.bar, 12) + + self.assert_sql_count(testing.db, go, 0) + + def test_update_defaults_present_as_expr(self): + Thing2 = self.classes.Thing2 + s = Session() + + t1, t2, t3, t4 = ( + Thing2(id=1, foo=1, bar=2), + Thing2(id=2, foo=2, bar=3), + Thing2(id=3, foo=3, bar=4), + Thing2(id=4, foo=4, bar=5) + ) + + s.add_all([t1, t2, t3, t4]) + s.flush() + + t1.foo = 5 + t1.bar = text("1 + 1") + t2.foo = 6 + t2.bar = 10 + t3.foo = 7 + t4.foo = 8 + t4.bar = text("5 + 7") + + if testing.db.dialect.implicit_returning: + self.assert_sql_execution( + testing.db, + s.flush, + CompiledSQL( + "UPDATE test2 SET foo=%(foo)s, bar=1 + 1 " + "WHERE test2.id = %(test2_id)s " + "RETURNING test2.bar", + [{'foo': 5, 'test2_id': 1}], + dialect='postgresql' + ), + CompiledSQL( + "UPDATE test2 SET foo=%(foo)s, bar=%(bar)s " + "WHERE test2.id = %(test2_id)s", + [{'foo': 6, 'bar': 10, 'test2_id': 2}], + dialect='postgresql' + ), + CompiledSQL( + "UPDATE test2 SET foo=%(foo)s " + "WHERE test2.id = %(test2_id)s " + "RETURNING test2.bar", + [{'foo': 7, 'test2_id': 3}], + dialect='postgresql' + ), + CompiledSQL( + "UPDATE test2 SET foo=%(foo)s, bar=5 + 7 " + "WHERE test2.id = %(test2_id)s RETURNING test2.bar", + [{'foo': 8, 'test2_id': 4}], + dialect='postgresql' + ), + ) + else: + self.assert_sql_execution( + testing.db, + s.flush, + CompiledSQL( + "UPDATE test2 SET foo=:foo, bar=1 + 1 " + "WHERE test2.id = :test2_id", + [{'foo': 5, 'test2_id': 1}] + ), + CompiledSQL( + "UPDATE test2 SET foo=:foo, bar=:bar " + "WHERE test2.id = :test2_id", + [{'foo': 6, 'bar': 10, 'test2_id': 2}], + ), + CompiledSQL( + "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id", + [{'foo': 7, 'test2_id': 3}] + ), + CompiledSQL( + "UPDATE test2 SET foo=:foo, bar=5 + 7 " + "WHERE test2.id = :test2_id", + [{'foo': 8, 'test2_id': 4}], + ), + CompiledSQL( + "SELECT test2.bar AS test2_bar FROM test2 " + "WHERE test2.id = :param_1", + [{'param_1': 1}] + ), + CompiledSQL( + "SELECT test2.bar AS test2_bar FROM test2 " + "WHERE test2.id = :param_1", + [{'param_1': 3}] + ), + CompiledSQL( + "SELECT test2.bar AS test2_bar FROM test2 " + "WHERE test2.id = :param_1", + [{'param_1': 4}] + ) + ) + + def go(): + eq_(t1.bar, 2) + eq_(t2.bar, 10) + eq_(t3.bar, 4) + eq_(t4.bar, 12) + + self.assert_sql_count(testing.db, go, 0) + + def test_insert_defaults_bulk_insert(self): + Thing = self.classes.Thing + s = Session() + + mappings = [ + {"id": 1}, + {"id": 2} + ] + + self.assert_sql_execution( + testing.db, + lambda: s.bulk_insert_mappings(Thing, mappings), + CompiledSQL( + "INSERT INTO test (id) VALUES (:id)", + [{'id': 1}, {'id': 2}] + ) + ) + + def test_update_defaults_bulk_update(self): + Thing2 = self.classes.Thing2 + s = Session() + + t1, t2, t3, t4 = ( + Thing2(id=1, foo=1, bar=2), + Thing2(id=2, foo=2, bar=3), + Thing2(id=3, foo=3, bar=4), + Thing2(id=4, foo=4, bar=5) + ) + + s.add_all([t1, t2, t3, t4]) + s.flush() + + mappings = [ + {"id": 1, "foo": 5}, + {"id": 2, "foo": 6, "bar": 10}, + {"id": 3, "foo": 7}, + {"id": 4, "foo": 8} + ] + + self.assert_sql_execution( + testing.db, + lambda: s.bulk_update_mappings(Thing2, mappings), + CompiledSQL( + "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id", + [{'foo': 5, 'test2_id': 1}] + ), + CompiledSQL( + "UPDATE test2 SET foo=:foo, bar=:bar " + "WHERE test2.id = :test2_id", + [{'foo': 6, 'bar': 10, 'test2_id': 2}] + ), + CompiledSQL( + "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id", + [{'foo': 7, 'test2_id': 3}, {'foo': 8, 'test2_id': 4}] + ) + ) + + def test_update_defaults_present(self): + Thing2 = self.classes.Thing2 + s = Session() + + t1, t2 = ( + Thing2(id=1, foo=1, bar=2), + Thing2(id=2, foo=2, bar=3) + ) + + s.add_all([t1, t2]) + s.flush() + + t1.bar = 5 + t2.bar = 10 + + self.assert_sql_execution( + testing.db, + s.commit, + CompiledSQL( + "UPDATE test2 SET bar=%(bar)s WHERE test2.id = %(test2_id)s", + [{'bar': 5, 'test2_id': 1}, {'bar': 10, 'test2_id': 2}], + dialect='postgresql' + ) + ) + class TypeWoBoolTest(fixtures.MappedTest, testing.AssertsExecutionResults): """test support for custom datatypes that return a non-__bool__ value when compared via __eq__(), eg. ticket 3469""" diff --git a/test/orm/test_versioning.py b/test/orm/test_versioning.py index f42069230..124053d47 100644 --- a/test/orm/test_versioning.py +++ b/test/orm/test_versioning.py @@ -894,19 +894,26 @@ class ServerVersioningTest(fixtures.MappedTest): class Bar(cls.Basic): pass - def _fixture(self, expire_on_commit=True): + def _fixture(self, expire_on_commit=True, eager_defaults=False): Foo, version_table = self.classes.Foo, self.tables.version_table mapper( Foo, version_table, version_id_col=version_table.c.version_id, version_id_generator=False, + eager_defaults=eager_defaults ) s1 = Session(expire_on_commit=expire_on_commit) return s1 def test_insert_col(self): - sess = self._fixture() + self._test_insert_col() + + def test_insert_col_eager_defaults(self): + self._test_insert_col(eager_defaults=True) + + def _test_insert_col(self, **kw): + sess = self._fixture(**kw) f1 = self.classes.Foo(value='f1') sess.add(f1) @@ -935,7 +942,13 @@ class ServerVersioningTest(fixtures.MappedTest): self.assert_sql_execution(testing.db, sess.flush, *statements) def test_update_col(self): - sess = self._fixture() + self._test_update_col() + + def test_update_col_eager_defaults(self): + self._test_update_col(eager_defaults=True) + + def _test_update_col(self, **kw): + sess = self._fixture(**kw) f1 = self.classes.Foo(value='f1') sess.add(f1) |
