diff options
Diffstat (limited to 'test')
| -rw-r--r-- | test/ext/test_extendedattr.py | 2 | ||||
| -rw-r--r-- | test/orm/test_attributes.py | 6 | ||||
| -rw-r--r-- | test/orm/test_cascade.py | 11 | ||||
| -rw-r--r-- | test/orm/test_deferred.py | 45 | ||||
| -rw-r--r-- | test/orm/test_dynamic.py | 27 | ||||
| -rw-r--r-- | test/orm/test_expire.py | 36 | ||||
| -rw-r--r-- | test/orm/test_load_on_fks.py | 10 | ||||
| -rw-r--r-- | test/orm/test_versioning.py | 7 |
8 files changed, 127 insertions, 17 deletions
diff --git a/test/ext/test_extendedattr.py b/test/ext/test_extendedattr.py index df9d2f9d5..ad9bf0bc0 100644 --- a/test/ext/test_extendedattr.py +++ b/test/ext/test_extendedattr.py @@ -223,7 +223,7 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest): data = {"a": "this is a", "b": 12} - def loader(state, keys): + def loader(state, keys, passive): for k in keys: state.dict[k] = data[k] return attributes.ATTR_WAS_SET diff --git a/test/orm/test_attributes.py b/test/orm/test_attributes.py index 347dd4e46..bb3399a5a 100644 --- a/test/orm/test_attributes.py +++ b/test/orm/test_attributes.py @@ -442,7 +442,7 @@ class AttributesTest(fixtures.ORMTest): data = {"a": "this is a", "b": 12} - def loader(state, keys): + def loader(state, keys, passive): for k in keys: state.dict[k] = data[k] return attributes.ATTR_WAS_SET @@ -488,7 +488,7 @@ class AttributesTest(fixtures.ORMTest): def test_deferred_pickleable(self): data = {"a": "this is a", "b": 12} - def loader(state, keys): + def loader(state, keys, passive): for k in keys: state.dict[k] = data[k] return attributes.ATTR_WAS_SET @@ -2242,7 +2242,7 @@ class HistoryTest(fixtures.TestBase): state.dict.pop("someattr", None) state.expired_attributes.add("someattr") - def scalar_loader(state, toload): + def scalar_loader(state, toload, passive): state.dict["someattr"] = "one" state.manager.expired_attribute_loader = scalar_loader diff --git a/test/orm/test_cascade.py b/test/orm/test_cascade.py index cfc3ad38f..815df3620 100644 --- a/test/orm/test_cascade.py +++ b/test/orm/test_cascade.py @@ -842,6 +842,17 @@ class O2OSingleParentNoFlushTest(fixtures.MappedTest): sess.add(u1) sess.commit() + # in this case, u1.address has active history set, because + # this operation necessarily replaces the old object which must be + # loaded. + # the set operation requires that "u1" is unexpired, because the + # replace operation wants to load the + # previous value. The original test case for #2921 only included + # that the lazyload operation passed a no autoflush flag through + # to the operation, however in #5226 this has been enhanced to pass + # the no autoflush flag down through to the unexpire of the attributes + # as well, so that attribute unexpire can otherwise invoke autoflush. + assert "id" not in u1.__dict__ a2 = Address(email_address="asdf") sess.add(a2) u1.address = a2 diff --git a/test/orm/test_deferred.py b/test/orm/test_deferred.py index a7957ec28..a02ee250c 100644 --- a/test/orm/test_deferred.py +++ b/test/orm/test_deferred.py @@ -1,6 +1,8 @@ import sqlalchemy as sa from sqlalchemy import ForeignKey +from sqlalchemy import func from sqlalchemy import Integer +from sqlalchemy import select from sqlalchemy import testing from sqlalchemy import util from sqlalchemy.orm import aliased @@ -2023,3 +2025,46 @@ class RaiseLoadTest(fixtures.DeclarativeMappedTest): eq_(a1.id, 1) assert "x" in a1.__dict__ + + +class AutoflushTest(fixtures.DeclarativeMappedTest): + @classmethod + def setup_classes(cls): + Base = cls.DeclarativeBasic + + class A(Base): + __tablename__ = "a" + + id = Column(Integer, primary_key=True) + bs = relationship("B") + + class B(Base): + __tablename__ = "b" + id = Column(Integer, primary_key=True) + a_id = Column(ForeignKey("a.id")) + + A.b_count = deferred( + select([func.count(1)]).where(A.id == B.a_id).scalar_subquery() + ) + + def test_deferred_autoflushes(self): + A, B = self.classes("A", "B") + + s = Session() + + a1 = A(id=1, bs=[B()]) + s.add(a1) + s.commit() + + eq_(a1.b_count, 1) + s.close() + + a1 = s.query(A).first() + assert "b_count" not in a1.__dict__ + + b1 = B(a_id=1) + s.add(b1) + + eq_(a1.b_count, 2) + + assert b1 in s diff --git a/test/orm/test_dynamic.py b/test/orm/test_dynamic.py index 94ecf4ee2..1ca1bec03 100644 --- a/test/orm/test_dynamic.py +++ b/test/orm/test_dynamic.py @@ -2,6 +2,7 @@ from sqlalchemy import cast from sqlalchemy import desc from sqlalchemy import exc from sqlalchemy import func +from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import select from sqlalchemy import testing @@ -969,17 +970,25 @@ class HistoryTest(_DynamicFixture, _fixtures.FixtureTest): elif isinstance(obj, self.classes.Order): attrname = "items" - eq_(attributes.get_history(obj, attrname), compare) + sess = inspect(obj).session - if compare_passive is None: - compare_passive = compare + if sess: + sess.autoflush = False + try: + eq_(attributes.get_history(obj, attrname), compare) - eq_( - attributes.get_history( - obj, attrname, attributes.LOAD_AGAINST_COMMITTED - ), - compare_passive, - ) + if compare_passive is None: + compare_passive = compare + + eq_( + attributes.get_history( + obj, attrname, attributes.LOAD_AGAINST_COMMITTED + ), + compare_passive, + ) + finally: + if sess: + sess.autoflush = True def test_append_transient(self): u1, a1 = self._transient_fixture() diff --git a/test/orm/test_expire.py b/test/orm/test_expire.py index 083c2f465..127380fad 100644 --- a/test/orm/test_expire.py +++ b/test/orm/test_expire.py @@ -82,6 +82,24 @@ class ExpireTest(_fixtures.FixtureTest): self.assert_sql_count(testing.db, go, 0) + def test_expire_autoflush(self): + User, users = self.classes.User, self.tables.users + Address, addresses = self.classes.Address, self.tables.addresses + + mapper(User, users) + mapper(Address, addresses, properties={"user": relationship(User)}) + + s = Session() + + a1 = s.query(Address).get(2) + u1 = s.query(User).get(7) + a1.user = u1 + + s.expire(a1, ["user_id"]) + + # autoflushes + eq_(a1.user_id, 7) + def test_persistence_check(self): users, User = self.tables.users, self.classes.User @@ -1748,6 +1766,24 @@ class RefreshTest(_fixtures.FixtureTest): lambda: s.refresh(u), ) + def test_refresh_autoflush(self): + User, users = self.classes.User, self.tables.users + Address, addresses = self.classes.Address, self.tables.addresses + + mapper(User, users) + mapper(Address, addresses, properties={"user": relationship(User)}) + + s = Session() + + a1 = s.query(Address).get(2) + u1 = s.query(User).get(7) + a1.user = u1 + + s.refresh(a1, ["user_id"]) + + # autoflushes + eq_(a1.user_id, 7) + def test_refresh_expired(self): User, users = self.classes.User, self.tables.users diff --git a/test/orm/test_load_on_fks.py b/test/orm/test_load_on_fks.py index 6e9dde16b..0e8ac97e3 100644 --- a/test/orm/test_load_on_fks.py +++ b/test/orm/test_load_on_fks.py @@ -176,6 +176,9 @@ class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase): assert c3 in p1.children def test_autoflush_on_pending(self): + # ensure p1.id is not expired + p1.id + c3 = Child() sess.add(c3) c3.parent_id = p1.id @@ -184,6 +187,9 @@ class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase): assert c3.parent is None def test_autoflush_load_on_pending_on_pending(self): + # ensure p1.id is not expired + p1.id + Child.parent.property.load_on_pending = True c3 = Child() sess.add(c3) @@ -305,6 +311,10 @@ class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase): for manualflush in (False, True): Child.parent.property.load_on_pending = loadonpending sess.autoflush = autoflush + + # ensure p2.id not expired + p2.id + c2 = Child() sess.add(c2) c2.parent_id = p2.id diff --git a/test/orm/test_versioning.py b/test/orm/test_versioning.py index c6418745d..1c540145b 100644 --- a/test/orm/test_versioning.py +++ b/test/orm/test_versioning.py @@ -139,11 +139,11 @@ class NullVersionIdTest(fixtures.MappedTest): # you should get a FlushError on update. f1.value = "f1rev2" - f1.version_id = None with conditional_sane_rowcount_warnings( update=True, only_returning=True ): + f1.version_id = None assert_raises_message( sa.orm.exc.FlushError, "Instance does not contain a non-NULL version value", @@ -1973,10 +1973,9 @@ class VersioningMappedSelectTest(fixtures.MappedTest): s1.expire_all() - f1.value = "f2" - f1.version_id = 2 - with conditional_sane_rowcount_warnings( update=True, only_returning=True ): + f1.value = "f2" + f1.version_id = 2 s1.flush() |
