summaryrefslogtreecommitdiff
path: root/tests/defer
diff options
context:
space:
mode:
authorAnssi Kääriäinen <akaariai@gmail.com>2014-07-05 09:03:52 +0300
committerTim Graham <timograham@gmail.com>2014-11-28 06:54:00 -0500
commitc7175fcdfe94be60c04f3b1ceb6d0b2def2b6f09 (patch)
tree409248caf9fe722d53eb1d7654176bb8a5f5c631 /tests/defer
parent912ad03226687dae91971ebd7e5cf87521f6b0de (diff)
downloaddjango-c7175fcdfe94be60c04f3b1ceb6d0b2def2b6f09.tar.gz
Fixed #901 -- Added Model.refresh_from_db() method
Thanks to github aliases dbrgn, carljm, slurms, dfunckt, and timgraham for reviews.
Diffstat (limited to 'tests/defer')
-rw-r--r--tests/defer/models.py14
-rw-r--r--tests/defer/tests.py28
2 files changed, 41 insertions, 1 deletions
diff --git a/tests/defer/models.py b/tests/defer/models.py
index ffc8a0c2c7..ecf69c0d7f 100644
--- a/tests/defer/models.py
+++ b/tests/defer/models.py
@@ -32,3 +32,17 @@ class BigChild(Primary):
class ChildProxy(Child):
class Meta:
proxy = True
+
+
+class RefreshPrimaryProxy(Primary):
+ class Meta:
+ proxy = True
+
+ def refresh_from_db(self, using=None, fields=None, **kwargs):
+ # Reloads all deferred fields if any of the fields is deferred.
+ if fields is not None:
+ fields = set(fields)
+ deferred_fields = self.get_deferred_fields()
+ if fields.intersection(deferred_fields):
+ fields = fields.union(deferred_fields)
+ super(RefreshPrimaryProxy, self).refresh_from_db(using, fields, **kwargs)
diff --git a/tests/defer/tests.py b/tests/defer/tests.py
index 43a088f3e2..597f871cc8 100644
--- a/tests/defer/tests.py
+++ b/tests/defer/tests.py
@@ -3,7 +3,7 @@ from __future__ import unicode_literals
from django.db.models.query_utils import DeferredAttribute, InvalidQuery
from django.test import TestCase
-from .models import Secondary, Primary, Child, BigChild, ChildProxy
+from .models import Secondary, Primary, Child, BigChild, ChildProxy, RefreshPrimaryProxy
class DeferTests(TestCase):
@@ -189,3 +189,29 @@ class DeferTests(TestCase):
s1_defer = Secondary.objects.only('pk').get(pk=s1.pk)
self.assertEqual(s1, s1_defer)
self.assertEqual(s1_defer, s1)
+
+ def test_refresh_not_loading_deferred_fields(self):
+ s = Secondary.objects.create()
+ rf = Primary.objects.create(name='foo', value='bar', related=s)
+ rf2 = Primary.objects.only('related', 'value').get()
+ rf.name = 'new foo'
+ rf.value = 'new bar'
+ rf.save()
+ with self.assertNumQueries(1):
+ rf2.refresh_from_db()
+ self.assertEqual(rf2.value, 'new bar')
+ with self.assertNumQueries(1):
+ self.assertEqual(rf2.name, 'new foo')
+
+ def test_custom_refresh_on_deferred_loading(self):
+ s = Secondary.objects.create()
+ rf = RefreshPrimaryProxy.objects.create(name='foo', value='bar', related=s)
+ rf2 = RefreshPrimaryProxy.objects.only('related').get()
+ rf.name = 'new foo'
+ rf.value = 'new bar'
+ rf.save()
+ with self.assertNumQueries(1):
+ # Customized refresh_from_db() reloads all deferred fields on
+ # access of any of them.
+ self.assertEqual(rf2.name, 'new foo')
+ self.assertEqual(rf2.value, 'new bar')