summaryrefslogtreecommitdiff
path: root/test/test_consumer_group.py
blob: 795e12739585dd5ae97e98b66e6405268167076a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import collections
import logging
import threading
import os
import time

import pytest
import six

from kafka import KafkaClient, SimpleProducer
from kafka.common import TopicPartition
from kafka.conn import BrokerConnection, ConnectionStates
from kafka.consumer.group import KafkaConsumer

from test.fixtures import KafkaFixture, ZookeeperFixture
from test.testutil import random_string


@pytest.fixture(scope="module")
def version():
    """Broker version under test, parsed from the KAFKA_VERSION env var.

    Returns a tuple of ints (e.g. '0.9.0.0' -> (0, 9, 0, 0)), or an empty
    tuple when KAFKA_VERSION is not set. The ``skipif`` markers in this
    module also call this function directly at import time.
    NOTE(review): pytest >= 4 rejects direct calls to fixture functions;
    confirm the pinned pytest version before upgrading.
    """
    raw_version = os.environ.get('KAFKA_VERSION')
    if raw_version is None:
        return ()
    return tuple(int(piece) for piece in raw_version.split('.'))


@pytest.fixture(scope="module")
def zookeeper(version, request):
    """Module-scoped ZookeeperFixture, closed by a finalizer at teardown.

    Asserts that a Kafka version is configured (``version`` non-empty).
    """
    assert version
    zk_fixture = ZookeeperFixture.instance()
    request.addfinalizer(zk_fixture.close)
    return zk_fixture


@pytest.fixture(scope="module")
def kafka_broker(version, zookeeper, request):
    """Module-scoped KafkaFixture (broker id 0) attached to ``zookeeper``.

    The broker is created with ``partitions=4`` and is closed by a
    finalizer at module teardown. Asserts that a Kafka version is set.
    """
    assert version
    broker = KafkaFixture.instance(0, zookeeper.host, zookeeper.port,
                                   partitions=4)
    request.addfinalizer(broker.close)
    return broker


@pytest.fixture
def simple_client(kafka_broker):
    """Return a KafkaClient bootstrapped against the local test broker."""
    bootstrap = 'localhost:%s' % (kafka_broker.port,)
    return KafkaClient(bootstrap)


@pytest.fixture
def topic(simple_client):
    """Create a topic named by ``random_string(5)`` and return that name."""
    name = random_string(5)
    simple_client.ensure_topic_exists(name)
    return name


@pytest.fixture
def topic_with_messages(simple_client, topic):
    """Produce 100 messages (b'msg_0' .. b'msg_99') to ``topic``; return it."""
    producer = SimpleProducer(simple_client)
    for i in six.moves.xrange(100):
        # SimpleProducer only accepts bytes payloads; encoding keeps this
        # working on Python 3 (on Python 2 it is still a byte str).
        producer.send_messages(topic, ('msg_%d' % i).encode('utf-8'))
    return topic


@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
def test_consumer(kafka_broker, version):
    """Check that KafkaConsumer.poll() establishes a broker connection.

    After one poll the consumer's client must hold at least one
    connection, and that connection must be in the CONNECTED state.
    The consumer (and any helper client) is always closed afterwards.
    """
    # 0.8.2 brokers need a topic to function well
    if (0, 8, 2) <= version < (0, 9):
        # NOTE(review): calls fixture functions directly; pytest >= 4
        # rejects this.
        client = simple_client(kafka_broker)
        try:
            topic(client)
        finally:
            client.close()

    connect_str = 'localhost:' + str(kafka_broker.port)
    consumer = KafkaConsumer(bootstrap_servers=connect_str)
    try:
        consumer.poll(500)
        conns = consumer._client._conns
        assert len(conns) > 0
        node_id = list(conns.keys())[0]
        assert conns[node_id].state is ConnectionStates.CONNECTED
    finally:
        consumer.close()


@pytest.mark.skipif(version() < (0, 9), reason='Unsupported Kafka Version')
@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
def test_group(kafka_broker, topic):
    """Check that a consumer group splits a topic's partitions among members.

    Starts ``num_consumers`` KafkaConsumer instances, each on its own thread
    and subscribed to ``topic``. The test then waits up to 35 seconds for
    every consumer to get a non-empty assignment. It asserts that the
    assignments are pairwise disjoint and together cover all
    ``num_partitions`` partitions. Every consumer thread is stopped and
    joined before the test returns.
    """
    num_partitions = 4
    num_consumers = 4
    connect_str = 'localhost:' + str(kafka_broker.port)
    consumers = {}
    # Stop events are created before any thread starts. This lets the
    # cleanup below always signal every thread, even one not yet scheduled.
    stop = dict((i, threading.Event()) for i in range(num_consumers))
    threads = {}
    # consumer index -> TopicPartition -> records fetched by that consumer
    messages = collections.defaultdict(
        lambda: collections.defaultdict(list))

    def consumer_thread(i):
        assert i not in consumers
        consumers[i] = KafkaConsumer(topic,
                                     bootstrap_servers=connect_str,
                                     heartbeat_interval_ms=500,
                                     request_timeout_ms=1000)
        try:
            while not stop[i].is_set():
                # poll() returns {TopicPartition: [records]}, so iterate the
                # items. A short timeout keeps the loop from busy-spinning.
                for tp, records in six.iteritems(consumers[i].poll(100)):
                    messages[i][tp].extend(records)
        finally:
            consumers[i].close()
            del consumers[i]

    try:
        for i in range(num_consumers):
            threads[i] = threading.Thread(target=consumer_thread, args=(i,))
            threads[i].start()

        deadline = time.time() + 35
        while True:
            # Snapshot with .get(). A consumer thread that dies removes its
            # own entry, so direct indexing could race into a KeyError.
            members = [consumers.get(c) for c in range(num_consumers)]
            if all(m is not None and m.assignment() for m in members):
                for c, member in enumerate(members):
                    logging.info("%s: %s", c, member.assignment())
                break
            assert time.time() < deadline, "timeout waiting for assignments"
            time.sleep(0.1)

        group_assignment = set()
        for member in members:
            # Read each assignment once so all checks see the same value.
            assignment = member.assignment()
            assert len(assignment) != 0
            assert assignment.isdisjoint(group_assignment)
            group_assignment.update(assignment)

        assert group_assignment == set(
            TopicPartition(topic, partition)
            for partition in range(num_partitions))

    finally:
        for event in stop.values():
            event.set()
        for thread in threads.values():
            thread.join()


@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
def test_correlation_id_rollover(kafka_broker):
    """Pipeline 2**13 MetadataRequests over a single BrokerConnection.

    Up to 100 requests are kept in flight
    (``max_in_flight_requests_per_connection``). The first failure reported
    by any response future is re-raised. Per its name, the test presumably
    targets correlation-id wraparound; TODO confirm that 2**13 requests
    are enough to trigger it.

    The connection is always closed, and the ``kafka.conn`` logger level is
    restored so the suppression does not leak into other tests.
    """
    conn_logger = logging.getLogger('kafka.conn')
    saved_level = conn_logger.level
    conn_logger.setLevel(logging.ERROR)
    from kafka.protocol.metadata import MetadataRequest
    conn = BrokerConnection('localhost', kafka_broker.port,
                            receive_buffer_bytes=131072,
                            max_in_flight_requests_per_connection=100)
    try:
        req = MetadataRequest([])
        # Bound the connect loop so an unreachable broker fails the test
        # instead of hanging it forever.
        connect_deadline = time.time() + 30
        while not conn.connected():
            assert time.time() < connect_deadline, \
                "timeout connecting to broker"
            conn.connect()

        futures = collections.deque()
        start = time.time()
        done = 0
        for _ in six.moves.xrange(2**13):
            if not conn.can_send_more():
                # Read responses before sending more; timeout=None
                # presumably blocks until one arrives (confirm in
                # BrokerConnection.recv).
                conn.recv(timeout=None)
            futures.append(conn.send(req))
            conn.recv()
            # Futures complete in order, so pop every finished one.
            while futures and futures[0].is_done:
                f = futures.popleft()
                if not f.succeeded():
                    raise f.exception
                done += 1
            if time.time() > start + 10:
                logging.info("%d done", done)
                start = time.time()

        # Drain whatever is still in flight.
        while futures:
            conn.recv()
            while futures and futures[0].is_done:
                f = futures.popleft()
                if not f.succeeded():
                    raise f.exception
    finally:
        conn.close()
        conn_logger.setLevel(saved_level)