summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/ext/test_extendedattr.py2
-rw-r--r--test/orm/test_attributes.py6
-rw-r--r--test/orm/test_cascade.py11
-rw-r--r--test/orm/test_deferred.py45
-rw-r--r--test/orm/test_dynamic.py27
-rw-r--r--test/orm/test_expire.py36
-rw-r--r--test/orm/test_load_on_fks.py10
-rw-r--r--test/orm/test_versioning.py7
8 files changed, 127 insertions, 17 deletions
diff --git a/test/ext/test_extendedattr.py b/test/ext/test_extendedattr.py
index df9d2f9d5..ad9bf0bc0 100644
--- a/test/ext/test_extendedattr.py
+++ b/test/ext/test_extendedattr.py
@@ -223,7 +223,7 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
data = {"a": "this is a", "b": 12}
- def loader(state, keys):
+ def loader(state, keys, passive):
for k in keys:
state.dict[k] = data[k]
return attributes.ATTR_WAS_SET
diff --git a/test/orm/test_attributes.py b/test/orm/test_attributes.py
index 347dd4e46..bb3399a5a 100644
--- a/test/orm/test_attributes.py
+++ b/test/orm/test_attributes.py
@@ -442,7 +442,7 @@ class AttributesTest(fixtures.ORMTest):
data = {"a": "this is a", "b": 12}
- def loader(state, keys):
+ def loader(state, keys, passive):
for k in keys:
state.dict[k] = data[k]
return attributes.ATTR_WAS_SET
@@ -488,7 +488,7 @@ class AttributesTest(fixtures.ORMTest):
def test_deferred_pickleable(self):
data = {"a": "this is a", "b": 12}
- def loader(state, keys):
+ def loader(state, keys, passive):
for k in keys:
state.dict[k] = data[k]
return attributes.ATTR_WAS_SET
@@ -2242,7 +2242,7 @@ class HistoryTest(fixtures.TestBase):
state.dict.pop("someattr", None)
state.expired_attributes.add("someattr")
- def scalar_loader(state, toload):
+ def scalar_loader(state, toload, passive):
state.dict["someattr"] = "one"
state.manager.expired_attribute_loader = scalar_loader
diff --git a/test/orm/test_cascade.py b/test/orm/test_cascade.py
index cfc3ad38f..815df3620 100644
--- a/test/orm/test_cascade.py
+++ b/test/orm/test_cascade.py
@@ -842,6 +842,17 @@ class O2OSingleParentNoFlushTest(fixtures.MappedTest):
sess.add(u1)
sess.commit()
+ # in this case, u1.address has active history set, because
+ # this operation necessarily replaces the old object which must be
+ # loaded.
+ # the set operation requires that "u1" is unexpired, because the
+ # replace operation wants to load the
+ # previous value. The original test case for #2921 only included
+ # that the lazyload operation passed a no autoflush flag through
+ # to the operation, however in #5226 this has been enhanced to pass
+ # the no autoflush flag down through to the unexpire of the attributes
+ # as well, so that attribute unexpire can otherwise invoke autoflush.
+ assert "id" not in u1.__dict__
a2 = Address(email_address="asdf")
sess.add(a2)
u1.address = a2
diff --git a/test/orm/test_deferred.py b/test/orm/test_deferred.py
index a7957ec28..a02ee250c 100644
--- a/test/orm/test_deferred.py
+++ b/test/orm/test_deferred.py
@@ -1,6 +1,8 @@
import sqlalchemy as sa
from sqlalchemy import ForeignKey
+from sqlalchemy import func
from sqlalchemy import Integer
+from sqlalchemy import select
from sqlalchemy import testing
from sqlalchemy import util
from sqlalchemy.orm import aliased
@@ -2023,3 +2025,46 @@ class RaiseLoadTest(fixtures.DeclarativeMappedTest):
eq_(a1.id, 1)
assert "x" in a1.__dict__
+
+
+class AutoflushTest(fixtures.DeclarativeMappedTest):
+ @classmethod
+ def setup_classes(cls):
+ Base = cls.DeclarativeBasic
+
+ class A(Base):
+ __tablename__ = "a"
+
+ id = Column(Integer, primary_key=True)
+ bs = relationship("B")
+
+ class B(Base):
+ __tablename__ = "b"
+ id = Column(Integer, primary_key=True)
+ a_id = Column(ForeignKey("a.id"))
+
+ A.b_count = deferred(
+ select([func.count(1)]).where(A.id == B.a_id).scalar_subquery()
+ )
+
+ def test_deferred_autoflushes(self):
+ A, B = self.classes("A", "B")
+
+ s = Session()
+
+ a1 = A(id=1, bs=[B()])
+ s.add(a1)
+ s.commit()
+
+ eq_(a1.b_count, 1)
+ s.close()
+
+ a1 = s.query(A).first()
+ assert "b_count" not in a1.__dict__
+
+ b1 = B(a_id=1)
+ s.add(b1)
+
+ eq_(a1.b_count, 2)
+
+ assert b1 in s
diff --git a/test/orm/test_dynamic.py b/test/orm/test_dynamic.py
index 94ecf4ee2..1ca1bec03 100644
--- a/test/orm/test_dynamic.py
+++ b/test/orm/test_dynamic.py
@@ -2,6 +2,7 @@ from sqlalchemy import cast
from sqlalchemy import desc
from sqlalchemy import exc
from sqlalchemy import func
+from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import select
from sqlalchemy import testing
@@ -969,17 +970,25 @@ class HistoryTest(_DynamicFixture, _fixtures.FixtureTest):
elif isinstance(obj, self.classes.Order):
attrname = "items"
- eq_(attributes.get_history(obj, attrname), compare)
+ sess = inspect(obj).session
- if compare_passive is None:
- compare_passive = compare
+ if sess:
+ sess.autoflush = False
+ try:
+ eq_(attributes.get_history(obj, attrname), compare)
- eq_(
- attributes.get_history(
- obj, attrname, attributes.LOAD_AGAINST_COMMITTED
- ),
- compare_passive,
- )
+ if compare_passive is None:
+ compare_passive = compare
+
+ eq_(
+ attributes.get_history(
+ obj, attrname, attributes.LOAD_AGAINST_COMMITTED
+ ),
+ compare_passive,
+ )
+ finally:
+ if sess:
+ sess.autoflush = True
def test_append_transient(self):
u1, a1 = self._transient_fixture()
diff --git a/test/orm/test_expire.py b/test/orm/test_expire.py
index 083c2f465..127380fad 100644
--- a/test/orm/test_expire.py
+++ b/test/orm/test_expire.py
@@ -82,6 +82,24 @@ class ExpireTest(_fixtures.FixtureTest):
self.assert_sql_count(testing.db, go, 0)
+ def test_expire_autoflush(self):
+ User, users = self.classes.User, self.tables.users
+ Address, addresses = self.classes.Address, self.tables.addresses
+
+ mapper(User, users)
+ mapper(Address, addresses, properties={"user": relationship(User)})
+
+ s = Session()
+
+ a1 = s.query(Address).get(2)
+ u1 = s.query(User).get(7)
+ a1.user = u1
+
+ s.expire(a1, ["user_id"])
+
+ # autoflushes
+ eq_(a1.user_id, 7)
+
def test_persistence_check(self):
users, User = self.tables.users, self.classes.User
@@ -1748,6 +1766,24 @@ class RefreshTest(_fixtures.FixtureTest):
lambda: s.refresh(u),
)
+ def test_refresh_autoflush(self):
+ User, users = self.classes.User, self.tables.users
+ Address, addresses = self.classes.Address, self.tables.addresses
+
+ mapper(User, users)
+ mapper(Address, addresses, properties={"user": relationship(User)})
+
+ s = Session()
+
+ a1 = s.query(Address).get(2)
+ u1 = s.query(User).get(7)
+ a1.user = u1
+
+ s.refresh(a1, ["user_id"])
+
+ # autoflushes
+ eq_(a1.user_id, 7)
+
def test_refresh_expired(self):
User, users = self.classes.User, self.tables.users
diff --git a/test/orm/test_load_on_fks.py b/test/orm/test_load_on_fks.py
index 6e9dde16b..0e8ac97e3 100644
--- a/test/orm/test_load_on_fks.py
+++ b/test/orm/test_load_on_fks.py
@@ -176,6 +176,9 @@ class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase):
assert c3 in p1.children
def test_autoflush_on_pending(self):
+ # ensure p1.id is not expired
+ p1.id
+
c3 = Child()
sess.add(c3)
c3.parent_id = p1.id
@@ -184,6 +187,9 @@ class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase):
assert c3.parent is None
def test_autoflush_load_on_pending_on_pending(self):
+ # ensure p1.id is not expired
+ p1.id
+
Child.parent.property.load_on_pending = True
c3 = Child()
sess.add(c3)
@@ -305,6 +311,10 @@ class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase):
for manualflush in (False, True):
Child.parent.property.load_on_pending = loadonpending
sess.autoflush = autoflush
+
+ # ensure p2.id not expired
+ p2.id
+
c2 = Child()
sess.add(c2)
c2.parent_id = p2.id
diff --git a/test/orm/test_versioning.py b/test/orm/test_versioning.py
index c6418745d..1c540145b 100644
--- a/test/orm/test_versioning.py
+++ b/test/orm/test_versioning.py
@@ -139,11 +139,11 @@ class NullVersionIdTest(fixtures.MappedTest):
# you should get a FlushError on update.
f1.value = "f1rev2"
- f1.version_id = None
with conditional_sane_rowcount_warnings(
update=True, only_returning=True
):
+ f1.version_id = None
assert_raises_message(
sa.orm.exc.FlushError,
"Instance does not contain a non-NULL version value",
@@ -1973,10 +1973,9 @@ class VersioningMappedSelectTest(fixtures.MappedTest):
s1.expire_all()
- f1.value = "f2"
- f1.version_id = 2
-
with conditional_sane_rowcount_warnings(
update=True, only_returning=True
):
+ f1.value = "f2"
+ f1.version_id = 2
s1.flush()