summaryrefslogtreecommitdiff
path: root/funtests/transport.py
blob: f6ae1e17a2e49a96f55014389abef4d10a73f850 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
import random
import socket
import string
import sys
import time
import unittest2 as unittest
import warnings
import weakref

from nose import SkipTest

from kombu import Connection
from kombu import Exchange, Queue
from kombu.tests.utils import skip_if_quick

if sys.version_info >= (2, 5):
    from hashlib import sha256 as _digest
else:
    from sha import new as _digest  # noqa


def say(msg):
    """Write *msg* to standard error as one line of text."""
    line = "".join([unicode(msg), "\n"])
    sys.stderr.write(line)


def consumeN(conn, consumer, n=1, timeout=30):
    """Consume from *consumer* until at least *n* messages have arrived.

    The consumer's callbacks are temporarily replaced by one that collects
    and acks every message.  The consumer is cancelled and its original
    callbacks are restored before returning, also when an error is raised.

    :param conn: connection whose ``drain_events`` delivers the messages.
    :param consumer: consumer to collect messages from.
    :param n: minimum number of messages to wait for.
    :param timeout: number of one-second ``drain_events`` timeouts
        tolerated before giving up.
    :returns: list of message bodies (the callback's ``message_data``)
        in arrival order.
    :raises socket.timeout: if *timeout* drain timeouts occur before
        *n* messages have been received.
    """
    messages = []

    def callback(message_data, message):
        messages.append(message_data)
        message.ack()

    prev, consumer.callbacks = consumer.callbacks, [callback]
    try:
        consumer.consume()
        try:
            seconds = 0
            while len(messages) < n:
                try:
                    conn.drain_events(timeout=1)
                except socket.timeout:
                    seconds += 1
                    msg = "Received %s/%s messages. %s seconds passed." % (
                        len(messages), n, seconds)
                    if seconds >= timeout:
                        raise socket.timeout(msg)
                    # Only report progress once we have waited a while.
                    if seconds > 1:
                        say(msg)
        finally:
            consumer.cancel()
    finally:
        # Restore the attribute swapped above (``callbacks``, plural), so
        # the temporary collecting callback does not leak to later users.
        consumer.callbacks = prev
    return messages


class TransportCase(unittest.TestCase):
    """Functional tests that exercise a real kombu transport.

    Subclasses configure :attr:`transport` (and usually :attr:`prefix`).
    When ``transport`` is unset, every test returns without doing anything.
    When the transport cannot connect, the tests raise ``SkipTest``.
    """
    # Transport name passed to ``Connection(transport=...)``.
    transport = None
    # Used as exchange name, queue name and routing key (see setUp).
    prefix = None
    # Separator inserted between ``prefix`` and a suffix by ``P()``.
    sep = '.'
    # Optional credentials forwarded to Connection by get_connection().
    userid = None
    password = None
    # Maximum number of ``queue.get()`` polls in test_basic_get.
    event_loop_max = 100
    # Extra keyword arguments for the connection made in do_connect().
    # NOTE: class-level dict shared by subclasses; only unpacked, never
    # mutated here.
    connection_options = {}
    # Silence the warning emitted when messages arrive out of FIFO order.
    suppress_disorder_warning = False
    # When False, test_purge keeps purging until at least 9 messages are
    # reported instead of expecting exact purge counts.
    reliable_purge = True

    # Set to True by do_connect() after a successful connect.
    connected = False
    # Message used for SkipTest when the transport is unavailable.
    skip_test_reason = None

    # Optional cap on the payload size of the large-message test.
    message_size_limit = None

    def before_connect(self):
        """Hook run before connecting; may raise SkipTest to skip tests."""
        pass

    def after_connect(self, connection):
        """Hook run right after ``connection.connect()`` succeeded.

        Connection errors raised here are treated like a failed connect.
        """
        pass

    def setUp(self):
        if not self.transport:
            return
        try:
            self.before_connect()
        except SkipTest:
            # sys.exc_info() (rather than ``except X, e`` / ``except X as e``)
            # keeps this valid on every Python version the file targets.
            self.skip_test_reason = str(sys.exc_info()[1])
        else:
            self.do_connect()
        self.exchange = Exchange(self.prefix, "direct")
        self.queue = Queue(self.prefix, self.exchange, self.prefix)

    def purge(self, names):
        """Purge every queue in *names*; return the total purged count.

        Each queue is purged repeatedly until ``queue_purge`` reports
        nothing removed, in case one call does not empty it completely.
        """
        chan = self.connection.channel()
        total = 0
        try:
            for queue in names:
                while True:
                    purged = chan.queue_purge(queue=queue)
                    if not purged:
                        break
                    total += purged
        finally:
            # Release the channel even if a purge raises.
            chan.close()
        return total

    def get_connection(self, **options):
        """Return a new (unconnected) Connection for :attr:`transport`.

        :attr:`userid` / :attr:`password` are applied unless already given
        in *options*.
        """
        if self.userid:
            options.setdefault("userid", self.userid)
        if self.password:
            options.setdefault("password", self.password)
        return Connection(transport=self.transport, **options)

    def do_connect(self):
        """Connect ``self.connection``; record a skip reason on failure."""
        self.connection = self.get_connection(**self.connection_options)
        try:
            self.connection.connect()
            self.after_connect(self.connection)
        except self.connection.connection_errors:
            self.skip_test_reason = "%s transport can't connect" % (
                self.transport, )
        else:
            self.connected = True

    def verify_alive(self):
        """Return True if the test should run.

        Returns None (tests then return early) when no transport is set,
        and raises SkipTest when the transport failed to connect.
        """
        if self.transport:
            if not self.connected:
                raise SkipTest(self.skip_test_reason)
            return True

    def purge_consumer(self, consumer):
        """Purge all queues consumed by *consumer*; return the count."""
        return self.purge([queue.name for queue in consumer.queues])

    def test_produce__consume(self):
        # One message published to the default queue is consumed back.
        if not self.verify_alive():
            return
        chan1 = self.connection.channel()
        consumer = chan1.Consumer(self.queue)
        self.purge_consumer(consumer)
        producer = chan1.Producer(self.exchange)
        producer.publish({"foo": "bar"}, routing_key=self.prefix)
        message = consumeN(self.connection, consumer)
        self.assertDictEqual(message[0], {"foo": "bar"})
        chan1.close()
        self.purge([self.queue.name])

    def test_purge(self):
        # Purging a queue holding 10 messages reports them as removed.
        if not self.verify_alive():
            return
        chan1 = self.connection.channel()
        consumer = chan1.Consumer(self.queue)
        self.purge_consumer(consumer)

        producer = chan1.Producer(self.exchange)
        for i in xrange(10):
            producer.publish({"foo": "bar"}, routing_key=self.prefix)
        if self.reliable_purge:
            self.assertEqual(consumer.purge(), 10)
            self.assertEqual(consumer.purge(), 0)
        else:
            purged = 0
            while purged < 9:
                purged += self.purge_consumer(consumer)
        chan1.close()

    def _digest(self, data):
        """Return the hex digest of *data* (module-level ``_digest``)."""
        return _digest(data).hexdigest()

    @skip_if_quick
    def test_produce__consume_large_messages(
            self, bytes=1048576, n=10,
            charset=string.punctuation + string.letters + string.digits):
        # Large random payloads survive a publish/consume round trip intact.
        if not self.verify_alive():
            return
        # Cap the payload size at the transport's limit, if it has one.
        bytes = min(filter(None, [bytes, self.message_size_limit]))
        # NOTE(review): every message gets the same "--<n>" suffix; "--<i>"
        # may have been intended -- confirm.
        messages = ["".join(random.choice(charset)
                    for j in xrange(bytes)) + "--%s" % n
                    for i in xrange(n)]
        digests = []
        chan1 = self.connection.channel()
        consumer = chan1.Consumer(self.queue)
        self.purge_consumer(consumer)
        producer = chan1.Producer(self.exchange)
        for i, message in enumerate(messages):
            producer.publish({"text": message,
                              "i": i}, routing_key=self.prefix)
            digests.append(self._digest(message))

        received = [(msg["i"], msg["text"])
                    for msg in consumeN(self.connection, consumer, n)]
        self.assertEqual(len(received), n)
        ordering = [i for i, _ in received]
        if ordering != range(n) and not self.suppress_disorder_warning:
            warnings.warn(
                "%s did not deliver messages in FIFO order: %r" % (
                    self.transport, ordering))

        for i, text in received:
            if text != messages[i]:
                raise AssertionError("%i: %r is not %r" % (
                    i, text[-100:], messages[i][-100:]))
            self.assertEqual(self._digest(text), digests[i])

        chan1.close()
        self.purge([self.queue.name])

    def P(self, rest):
        """Return ``prefix + sep + rest`` (a prefixed queue name)."""
        return "%s%s%s" % (self.prefix, self.sep, rest)

    def test_produce__consume_multiple(self):
        # One consumer receives messages from three different queues.
        if not self.verify_alive():
            return
        chan1 = self.connection.channel()
        producer = chan1.Producer(self.exchange)
        b1 = Queue(self.P("b1"), self.exchange, "b1")(chan1)
        b2 = Queue(self.P("b2"), self.exchange, "b2")(chan1)
        b3 = Queue(self.P("b3"), self.exchange, "b3")(chan1)
        for q in (b1, b2, b3):
            q.declare()
        self.purge([b1.name, b2.name, b3.name])

        producer.publish("b1", routing_key="b1")
        producer.publish("b2", routing_key="b2")
        producer.publish("b3", routing_key="b3")
        chan1.close()

        chan2 = self.connection.channel()
        consumer = chan2.Consumer([b1, b2, b3])
        messages = consumeN(self.connection, consumer, 3)
        self.assertItemsEqual(messages, ["b1", "b2", "b3"])
        chan2.close()
        self.purge([self.P("b1"), self.P("b2"), self.P("b3")])

    def test_timeout(self):
        # Draining an empty queue raises socket.timeout.
        if not self.verify_alive():
            return
        chan = self.connection.channel()
        self.purge([self.queue.name])
        consumer = chan.Consumer(self.queue)
        self.assertRaises(
            socket.timeout, self.connection.drain_events, timeout=0.3,
        )
        consumer.cancel()
        chan.close()

    def test_basic_get(self):
        # A published message can be fetched by polling with queue.get().
        if not self.verify_alive():
            return
        chan1 = self.connection.channel()
        producer = chan1.Producer(self.exchange)
        chan2 = self.connection.channel()
        queue = Queue(self.P("basic_get"), self.exchange, "basic_get")
        queue = queue(chan2)
        queue.declare()
        producer.publish({"basic.get": "this"}, routing_key="basic_get")
        chan1.close()

        # Pre-bind so a zero-iteration loop or no delivery gives a clear
        # assertion failure rather than a NameError/AttributeError.
        m = None
        for i in range(self.event_loop_max):
            m = queue.get()
            if m:
                break
            time.sleep(0.1)
        self.assertTrue(m, "no message received via basic_get")
        self.assertEqual(m.payload, {"basic.get": "this"})
        self.purge([queue.name])
        chan2.close()

    def test_cyclic_reference_transport(self):
        # Accessing .transport must not create a reference cycle.
        if not self.verify_alive():
            return

        def _createref():
            conn = self.get_connection()
            conn.transport  # attribute access is the point of this test
            conn.close()
            return weakref.ref(conn)

        self.assertIsNone(_createref()())

    def test_cyclic_reference_connection(self):
        # A connected-then-closed connection must be collectable.
        if not self.verify_alive():
            return

        def _createref():
            conn = self.get_connection()
            conn.connect()
            conn.close()
            return weakref.ref(conn)

        self.assertIsNone(_createref()())

    def test_cyclic_reference_channel(self):
        # Closed channels must not be kept alive after the connection closes.
        if not self.verify_alive():
            return

        def _createref():
            conn = self.get_connection()
            conn.connect()
            chanrefs = []
            try:
                for i in xrange(100):
                    channel = conn.channel()
                    chanrefs.append(weakref.ref(channel))
                    channel.close()
            finally:
                conn.close()
            return chanrefs

        for chanref in _createref():
            self.assertIsNone(chanref())

    def tearDown(self):
        """Close the shared connection if setUp managed to connect."""
        if self.transport and self.connected:
            self.connection.close()