summaryrefslogtreecommitdiff
path: root/test/test_consumer_group.py
blob: 03656fa6b7d63275c2b025a0d3f5cca67fb99ed2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import collections
import logging
import threading
import os
import time

import pytest
import six

from kafka import SimpleClient, SimpleProducer
from kafka.common import TopicPartition
from kafka.conn import BrokerConnection, ConnectionStates
from kafka.consumer.group import KafkaConsumer

from test.conftest import version
from test.testutil import random_string


@pytest.fixture
def simple_client(kafka_broker):
    """Return a SimpleClient pointed at the test broker's port on localhost."""
    port = str(kafka_broker.port)
    return SimpleClient('localhost:' + port)


@pytest.fixture
def topic(simple_client):
    """Ensure a topic named by ``random_string(5)`` exists and return its name."""
    name = random_string(5)
    simple_client.ensure_topic_exists(name)
    return name


@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
def test_consumer(kafka_broker, version):
    """Check that a KafkaConsumer has a connected broker connection after a poll.

    Args:
        kafka_broker: broker fixture; only its ``port`` attribute is used.
        version: broker version tuple, compared against ``(0, 8, 2)`` and
            ``(0, 9)`` to decide whether a topic must be created first.
    """
    connect_str = 'localhost:' + str(kafka_broker.port)

    # 0.8.2 brokers need a topic to function well
    if (0, 8, 2) <= version < (0, 9):
        # Create the topic inline rather than calling the `topic` and
        # `simple_client` fixture functions directly: pytest rejects direct
        # fixture calls (deprecated in 3.x, an error since 4.0).
        SimpleClient(connect_str).ensure_topic_exists(random_string(5))

    consumer = KafkaConsumer(bootstrap_servers=connect_str)
    try:
        consumer.poll(500)
        assert len(consumer._client._conns) > 0
        node_id = list(consumer._client._conns.keys())[0]
        assert consumer._client._conns[node_id].state is ConnectionStates.CONNECTED
    finally:
        # Always release the consumer so a failed assertion does not leave
        # it open for the rest of the test session.
        consumer.close()


@pytest.mark.skipif(version() < (0, 9), reason='Unsupported Kafka Version')
@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
def test_group(kafka_broker, topic):
    """Check that consumers subscribed to one topic split its partitions.

    Starts ``num_consumers`` KafkaConsumer instances, each polling in its own
    daemon thread, waits up to 35 seconds for every consumer to report a
    non-empty assignment, then asserts that the assignments are pairwise
    disjoint and together equal partitions ``0..num_partitions-1`` of
    ``topic``.

    Args:
        kafka_broker: broker fixture; only its ``port`` attribute is used.
        topic: name of the topic the consumers subscribe to.
    """
    num_partitions = 4
    num_consumers = 4
    connect_str = 'localhost:' + str(kafka_broker.port)
    consumers = {}
    # Stop events are created up front in the main thread so the cleanup in
    # ``finally`` can always signal every worker, even one that never got as
    # far as constructing its consumer.
    stop = dict((i, threading.Event()) for i in range(num_consumers))
    # messages[consumer_index][topic_partition] -> list of records
    messages = collections.defaultdict(lambda: collections.defaultdict(list))

    def consumer_thread(i):
        """Run consumer ``i`` until its stop event is set, then close it."""
        assert i not in consumers
        consumers[i] = KafkaConsumer(topic,
                                     bootstrap_servers=connect_str,
                                     heartbeat_interval_ms=500)
        try:
            while not stop[i].is_set():
                # poll() maps each partition to its records, so iterate the
                # items (not just the values); the 100 ms timeout keeps the
                # loop from spinning while staying responsive to stop[i].
                for tp, records in six.iteritems(consumers[i].poll(100)):
                    messages[i][tp].extend(records)
        finally:
            consumers[i].close()
            del consumers[i]

    threads = []
    try:
        for i in range(num_consumers):
            t = threading.Thread(target=consumer_thread, args=(i,))
            t.daemon = True
            t.start()
            threads.append(t)

        timeout = time.time() + 35
        while True:
            for c in range(num_consumers):
                consumer = consumers.get(c)
                # Not created yet, or created but still without partitions.
                if consumer is None or not consumer.assignment():
                    break
            else:
                # Every consumer exists and holds an assignment.
                for c in range(num_consumers):
                    logging.info("[%s] %s %s: %s", c,
                                 consumers[c]._coordinator.generation,
                                 consumers[c]._coordinator.member_id,
                                 consumers[c].assignment())
                break
            assert time.time() < timeout, "timeout waiting for assignments"
            # Back off briefly instead of busy-waiting on the main thread.
            time.sleep(0.1)

        group_assignment = set()
        for c in range(num_consumers):
            assert len(consumers[c].assignment()) != 0
            assert set.isdisjoint(consumers[c].assignment(), group_assignment)
            group_assignment.update(consumers[c].assignment())

        assert group_assignment == set([
            TopicPartition(topic, partition)
            for partition in range(num_partitions)])

    finally:
        for event in stop.values():
            event.set()
        # Give each worker a bounded amount of time to close its consumer so
        # that consumers from this test do not linger into later tests.
        for t in threads:
            t.join(10)