summaryrefslogtreecommitdiff
path: root/tests/signals_regress/tests.py
blob: 8fb3ad5a48e8cedd5fdbdd4d0760b488760c4ecc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from __future__ import absolute_import

from django.db import models
from django.test import TestCase

from .models import Author, Book


class SignalsRegressTests(TestCase):
    """
    Testing signals before/after saving and deleting.

    setUp() connects one receiver to each of ``pre_save``, ``post_save``,
    ``pre_delete`` and ``post_delete``. The receivers append human-readable
    messages to ``self.signal_output`` so that tests can assert exactly which
    signals fired, and in what order.
    """

    def get_signal_output(self, fn, *args, **kwargs):
        """
        Call ``fn(*args, **kwargs)`` and return the messages it caused the
        receivers to record.

        Previously recorded output is discarded before the call, so the
        returned list reflects only this one call.
        """
        # Flush any existing signal output
        self.signal_output = []
        fn(*args, **kwargs)
        return self.signal_output

    def pre_save_test(self, signal, sender, instance, **kwargs):
        """Record a ``pre_save`` signal, plus "Is raw" for a truthy ``raw``."""
        self.signal_output.append('pre_save signal, %s' % instance)
        if kwargs.get('raw'):
            self.signal_output.append('Is raw')

    def post_save_test(self, signal, sender, instance, **kwargs):
        """
        Record a ``post_save`` signal.

        If the signal carries a ``created`` argument, also record whether the
        save created ("Is created") or updated ("Is updated") the row; a
        truthy ``raw`` additionally records "Is raw".
        """
        self.signal_output.append('post_save signal, %s' % instance)
        if 'created' in kwargs:
            if kwargs['created']:
                self.signal_output.append('Is created')
            else:
                self.signal_output.append('Is updated')
        if kwargs.get('raw'):
            self.signal_output.append('Is raw')

    def pre_delete_test(self, signal, sender, instance, **kwargs):
        """Record a ``pre_delete`` signal and whether the instance has an id."""
        # The label must name the signal actually received; otherwise delete
        # output is indistinguishable from pre_save output.
        self.signal_output.append('pre_delete signal, %s' % instance)
        self.signal_output.append('instance.id is not None: %s' % (instance.id is not None))

    def post_delete_test(self, signal, sender, instance, **kwargs):
        """Record a ``post_delete`` signal and whether the instance has an id."""
        self.signal_output.append('post_delete signal, %s' % instance)
        self.signal_output.append('instance.id is not None: %s' % (instance.id is not None))

    @staticmethod
    def _receiver_counts():
        """Return the receiver counts of (pre_save, post_save, pre_delete, post_delete)."""
        return (
            len(models.signals.pre_save.receivers),
            len(models.signals.post_save.receivers),
            len(models.signals.pre_delete.receivers),
            len(models.signals.post_delete.receivers),
        )

    def setUp(self):
        """Snapshot the receiver counts, then connect this test's receivers."""
        self.signal_output = []
        # Save up the number of connected signals so that we can check at the end
        # that all the signals we register get properly unregistered (#9989)
        self.pre_signals = self._receiver_counts()

        models.signals.pre_save.connect(self.pre_save_test)
        models.signals.post_save.connect(self.post_save_test)
        models.signals.pre_delete.connect(self.pre_delete_test)
        models.signals.post_delete.connect(self.post_delete_test)

    def tearDown(self):
        """Disconnect this test's receivers and verify none were left behind."""
        # Disconnect in the reverse order of setUp's connect() calls.
        models.signals.post_delete.disconnect(self.post_delete_test)
        models.signals.pre_delete.disconnect(self.pre_delete_test)
        models.signals.post_save.disconnect(self.post_save_test)
        models.signals.pre_save.disconnect(self.pre_save_test)

        # Check that all our signals got disconnected properly.
        self.assertEqual(self.pre_signals, self._receiver_counts())

    def test_model_signals(self):
        """ Saving new model instances sends pre_save, then post_save with created=True. """
        a1 = Author(name='Neal Stephenson')
        self.assertEqual(self.get_signal_output(a1.save), [
            "pre_save signal, Neal Stephenson",
            "post_save signal, Neal Stephenson",
            "Is created"
        ])

        b1 = Book(name='Snow Crash')
        self.assertEqual(self.get_signal_output(b1.save), [
            "pre_save signal, Snow Crash",
            "post_save signal, Snow Crash",
            "Is created"
        ])

    def test_m2m_signals(self):
        """ Assigning to and clearing an m2m relation shouldn't send any save/delete signal. """

        b1 = Book(name='Snow Crash')
        self.get_signal_output(b1.save)
        a1 = Author(name='Neal Stephenson')
        self.get_signal_output(a1.save)
        # Only the save/delete receivers are connected here, so an empty list
        # means none of those four signals fired during the m2m assignment.
        self.assertEqual(self.get_signal_output(setattr, b1, 'authors', [a1]), [])
        self.assertEqual(self.get_signal_output(setattr, b1, 'authors', []), [])