summaryrefslogtreecommitdiff
path: root/kafka/partitioner/roundrobin.py
blob: e68c372425c87677d12e1691d06eff6fae02c74c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from __future__ import absolute_import

from kafka.partitioner.base import Partitioner


class RoundRobinPartitioner(Partitioner):
    """Assign messages to partitions in rotating order, ignoring the key.

    Rotation state is delegated to a ``CachedPartitionCycler``. On every
    call the partition list supplied by the caller is compared with the one
    the cycler holds, and the cycler is re-seeded when they differ.
    """

    def __init__(self, partitions=None):
        # The cycler always exists, even before any partitions are known;
        # it is seeded later by the first call that supplies a list.
        self.partitions_iterable = CachedPartitionCycler(partitions)
        if not partitions:
            self.partitions = None
            return
        self._set_partitions(partitions)

    def __call__(self, key, all_partitions=None, available_partitions=None):
        """Return the next partition in the rotation.

        Arguments:
            key: unused; accepted for the partitioner call signature.
            all_partitions (list): the full partition list.
            available_partitions (list): partitions reported as available;
                used instead of ``all_partitions`` whenever non-empty.

        Returns:
            The partition produced by the cycler.
        """
        candidates = available_partitions if available_partitions else all_partitions
        if not self.partitions:
            # Nothing cached yet: seed from whatever the caller gave us.
            needs_reseed = True
        else:
            # Re-seed only on a real change; a None list keeps the cached one.
            needs_reseed = (candidates != self.partitions_iterable.partitions
                            and candidates is not None)
        if needs_reseed:
            self._set_partitions(candidates)
        return next(self.partitions_iterable)

    def _set_partitions(self, available_partitions):
        """Cache ``available_partitions`` and push them into the cycler."""
        self.partitions = available_partitions
        self.partitions_iterable.set_partitions(available_partitions)

    def partition(self, key, all_partitions=None, available_partitions=None):
        """Method-style alias for calling the partitioner directly."""
        return self.__call__(key, all_partitions, available_partitions)


class CachedPartitionCycler(object):
    """Iterator that cycles endlessly over a list of partitions.

    ``cur_pos`` is the index in ``partitions`` of the item the next call to
    ``next()`` will return; ``None`` means ``next()`` has not succeeded yet.
    ``set_partitions`` swaps in a new list while trying to continue with the
    partition that was about to be returned, so a list change does not skip
    or repeat a partition when that can be avoided.
    """

    def __init__(self, partitions=None):
        self.partitions = partitions
        if partitions:
            assert isinstance(partitions, list)
        self.cur_pos = None

    def __iter__(self):
        # Makes the cycler a proper iterator (usable with iter() / for).
        return self

    def __next__(self):
        return self.next()

    @staticmethod
    def _index_available(cur_pos, partitions):
        """Return True if ``cur_pos`` is a valid index into ``partitions``."""
        return cur_pos < len(partitions)

    def set_partitions(self, partitions):
        """Replace the partition list, preserving the rotation where possible.

        The partition ``next()`` would have returned from the old list is
        looked up in ``partitions`` and the cycle resumes there. If it is no
        longer present, the same index is kept when still valid; otherwise
        the cycle restarts at index 0.

        Arguments:
            partitions (list): the new partition list (may be empty or None).
        """
        old_partitions = self.partitions
        self.partitions = partitions
        if not self.cur_pos:
            # None (never advanced) or 0 (already rewound): the next call
            # starts at index 0 of the new list, so nothing to remap.
            return None
        if not partitions:
            self.cur_pos = 0
            return None
        if old_partitions:
            # The upcoming item must be read from the OLD list; reading it
            # from the new one would always "find" it at the same index.
            if self._index_available(self.cur_pos, old_partitions):
                upcoming = old_partitions[self.cur_pos]
            else:
                # cur_pos is past the end: next() would have wrapped to 0.
                upcoming = old_partitions[0]
            try:
                self.cur_pos = partitions.index(upcoming)
            except ValueError:
                pass  # upcoming partition was removed; fall back below
            else:
                return None
        if not self._index_available(self.cur_pos, partitions):
            self.cur_pos = 0
        return None

    def next(self):
        """Return the next partition, wrapping to the start after the last.

        Raises:
            AssertionError: if no partition list is set (assertions enabled).
            IndexError: if the partition list is empty.
        """
        assert self.partitions is not None
        if self.cur_pos is None or not self._index_available(self.cur_pos, self.partitions):
            # Read before touching cur_pos so an empty list (IndexError)
            # does not leave the cycler in a half-advanced state.
            first = self.partitions[0]
            self.cur_pos = 1
            return first
        cur_item = self.partitions[self.cur_pos]
        self.cur_pos += 1
        return cur_item