summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDana Powers <dana.powers@gmail.com>2017-12-21 14:46:10 -0800
committerGitHub <noreply@github.com>2017-12-21 14:46:10 -0800
commitad024d1e897dbf16bd629fa63895bd7af4a8d959 (patch)
treef1993351b2c6487e8e623cefabf42ddf7477f666
parent995664c7d407009a0a1030c7541848eb5ad51c97 (diff)
downloadkafka-python-ad024d1e897dbf16bd629fa63895bd7af4a8d959.tar.gz
KAFKA-3888 Use background thread to process consumer heartbeats (#1266)
-rw-r--r--kafka/client_async.py465
-rw-r--r--kafka/conn.py2
-rw-r--r--kafka/consumer/fetcher.py3
-rw-r--r--kafka/consumer/group.py102
-rw-r--r--kafka/coordinator/base.py674
-rw-r--r--kafka/coordinator/consumer.py242
-rw-r--r--kafka/coordinator/heartbeat.py49
-rw-r--r--kafka/errors.py13
-rw-r--r--kafka/protocol/group.py2
-rw-r--r--test/test_client_async.py11
-rw-r--r--test/test_consumer.py4
-rw-r--r--test/test_consumer_group.py45
-rw-r--r--test/test_consumer_integration.py3
-rw-r--r--test/test_coordinator.py100
14 files changed, 977 insertions, 738 deletions
diff --git a/kafka/client_async.py b/kafka/client_async.py
index 1350503..24162ad 100644
--- a/kafka/client_async.py
+++ b/kafka/client_async.py
@@ -3,8 +3,6 @@ from __future__ import absolute_import, division
import collections
import copy
import functools
-import heapq
-import itertools
import logging
import random
import threading
@@ -202,15 +200,17 @@ class KafkaClient(object):
self._conns = {}
self._connecting = set()
self._refresh_on_disconnects = True
- self._delayed_tasks = DelayedTaskQueue()
self._last_bootstrap = 0
self._bootstrap_fails = 0
self._wake_r, self._wake_w = socket.socketpair()
self._wake_r.setblocking(False)
self._wake_lock = threading.Lock()
+ self._lock = threading.RLock()
+
# when requests complete, they are transferred to this queue prior to
- # invocation.
+ # invocation. The purpose is to avoid invoking them while holding the
+ # lock above.
self._pending_completion = collections.deque()
self._selector.register(self._wake_r, selectors.EVENT_READ)
@@ -296,90 +296,92 @@ class KafkaClient(object):
return conn.disconnected() and not conn.blacked_out()
def _conn_state_change(self, node_id, conn):
- if conn.connecting():
- # SSL connections can enter this state 2x (second during Handshake)
- if node_id not in self._connecting:
- self._connecting.add(node_id)
- self._selector.register(conn._sock, selectors.EVENT_WRITE)
-
- elif conn.connected():
- log.debug("Node %s connected", node_id)
- if node_id in self._connecting:
- self._connecting.remove(node_id)
-
- try:
- self._selector.unregister(conn._sock)
- except KeyError:
- pass
- self._selector.register(conn._sock, selectors.EVENT_READ, conn)
- if self._sensors:
- self._sensors.connection_created.record()
-
- self._idle_expiry_manager.update(node_id)
-
- if 'bootstrap' in self._conns and node_id != 'bootstrap':
- bootstrap = self._conns.pop('bootstrap')
- # XXX: make conn.close() require error to cause refresh
- self._refresh_on_disconnects = False
- bootstrap.close()
- self._refresh_on_disconnects = True
+ with self._lock:
+ if conn.connecting():
+ # SSL connections can enter this state 2x (second during Handshake)
+ if node_id not in self._connecting:
+ self._connecting.add(node_id)
+ self._selector.register(conn._sock, selectors.EVENT_WRITE)
+
+ elif conn.connected():
+ log.debug("Node %s connected", node_id)
+ if node_id in self._connecting:
+ self._connecting.remove(node_id)
- # Connection failures imply that our metadata is stale, so let's refresh
- elif conn.state is ConnectionStates.DISCONNECTING:
- if node_id in self._connecting:
- self._connecting.remove(node_id)
- try:
- self._selector.unregister(conn._sock)
- except KeyError:
- pass
- if self._sensors:
- self._sensors.connection_closed.record()
+ try:
+ self._selector.unregister(conn._sock)
+ except KeyError:
+ pass
+ self._selector.register(conn._sock, selectors.EVENT_READ, conn)
+ if self._sensors:
+ self._sensors.connection_created.record()
+
+ self._idle_expiry_manager.update(node_id)
+
+ if 'bootstrap' in self._conns and node_id != 'bootstrap':
+ bootstrap = self._conns.pop('bootstrap')
+ # XXX: make conn.close() require error to cause refresh
+ self._refresh_on_disconnects = False
+ bootstrap.close()
+ self._refresh_on_disconnects = True
+
+ # Connection failures imply that our metadata is stale, so let's refresh
+ elif conn.state is ConnectionStates.DISCONNECTING:
+ if node_id in self._connecting:
+ self._connecting.remove(node_id)
+ try:
+ self._selector.unregister(conn._sock)
+ except KeyError:
+ pass
+ if self._sensors:
+ self._sensors.connection_closed.record()
- idle_disconnect = False
- if self._idle_expiry_manager.is_expired(node_id):
- idle_disconnect = True
- self._idle_expiry_manager.remove(node_id)
+ idle_disconnect = False
+ if self._idle_expiry_manager.is_expired(node_id):
+ idle_disconnect = True
+ self._idle_expiry_manager.remove(node_id)
- if self._refresh_on_disconnects and not self._closed and not idle_disconnect:
- log.warning("Node %s connection failed -- refreshing metadata", node_id)
- self.cluster.request_update()
+ if self._refresh_on_disconnects and not self._closed and not idle_disconnect:
+ log.warning("Node %s connection failed -- refreshing metadata", node_id)
+ self.cluster.request_update()
def _maybe_connect(self, node_id):
"""Idempotent non-blocking connection attempt to the given node id."""
- broker = self.cluster.broker_metadata(node_id)
- conn = self._conns.get(node_id)
-
- if conn is None:
- assert broker, 'Broker id %s not in current metadata' % node_id
-
- log.debug("Initiating connection to node %s at %s:%s",
- node_id, broker.host, broker.port)
- host, port, afi = get_ip_port_afi(broker.host)
- cb = functools.partial(self._conn_state_change, node_id)
- conn = BrokerConnection(host, broker.port, afi,
- state_change_callback=cb,
- node_id=node_id,
- **self.config)
- self._conns[node_id] = conn
-
- # Check if existing connection should be recreated because host/port changed
- elif conn.disconnected() and broker is not None:
- host, _, __ = get_ip_port_afi(broker.host)
- if conn.host != host or conn.port != broker.port:
- log.info("Broker metadata change detected for node %s"
- " from %s:%s to %s:%s", node_id, conn.host, conn.port,
- broker.host, broker.port)
-
- # Drop old connection object.
- # It will be recreated on next _maybe_connect
- self._conns.pop(node_id)
- return False
+ with self._lock:
+ broker = self.cluster.broker_metadata(node_id)
+ conn = self._conns.get(node_id)
- elif conn.connected():
- return True
+ if conn is None:
+ assert broker, 'Broker id %s not in current metadata' % node_id
+
+ log.debug("Initiating connection to node %s at %s:%s",
+ node_id, broker.host, broker.port)
+ host, port, afi = get_ip_port_afi(broker.host)
+ cb = functools.partial(self._conn_state_change, node_id)
+ conn = BrokerConnection(host, broker.port, afi,
+ state_change_callback=cb,
+ node_id=node_id,
+ **self.config)
+ self._conns[node_id] = conn
+
+ # Check if existing connection should be recreated because host/port changed
+ elif conn.disconnected() and broker is not None:
+ host, _, __ = get_ip_port_afi(broker.host)
+ if conn.host != host or conn.port != broker.port:
+ log.info("Broker metadata change detected for node %s"
+ " from %s:%s to %s:%s", node_id, conn.host, conn.port,
+ broker.host, broker.port)
+
+ # Drop old connection object.
+ # It will be recreated on next _maybe_connect
+ self._conns.pop(node_id)
+ return False
+
+ elif conn.connected():
+ return True
- conn.connect()
- return conn.connected()
+ conn.connect()
+ return conn.connected()
def ready(self, node_id, metadata_priority=True):
"""Check whether a node is connected and ok to send more requests.
@@ -397,9 +399,10 @@ class KafkaClient(object):
def connected(self, node_id):
"""Return True iff the node_id is connected."""
- if node_id not in self._conns:
- return False
- return self._conns[node_id].connected()
+ with self._lock:
+ if node_id not in self._conns:
+ return False
+ return self._conns[node_id].connected()
def close(self, node_id=None):
"""Close one or all broker connections.
@@ -407,18 +410,19 @@ class KafkaClient(object):
Arguments:
node_id (int, optional): the id of the node to close
"""
- if node_id is None:
- self._closed = True
- for conn in self._conns.values():
- conn.close()
- self._wake_r.close()
- self._wake_w.close()
- self._selector.close()
- elif node_id in self._conns:
- self._conns[node_id].close()
- else:
- log.warning("Node %s not found in current connection list; skipping", node_id)
- return
+ with self._lock:
+ if node_id is None:
+ self._closed = True
+ for conn in self._conns.values():
+ conn.close()
+ self._wake_r.close()
+ self._wake_w.close()
+ self._selector.close()
+ elif node_id in self._conns:
+ self._conns[node_id].close()
+ else:
+ log.warning("Node %s not found in current connection list; skipping", node_id)
+ return
def is_disconnected(self, node_id):
"""Check whether the node connection has been disconnected or failed.
@@ -434,9 +438,10 @@ class KafkaClient(object):
Returns:
bool: True iff the node exists and is disconnected
"""
- if node_id not in self._conns:
- return False
- return self._conns[node_id].disconnected()
+ with self._lock:
+ if node_id not in self._conns:
+ return False
+ return self._conns[node_id].disconnected()
def connection_delay(self, node_id):
"""
@@ -452,9 +457,10 @@ class KafkaClient(object):
Returns:
int: The number of milliseconds to wait.
"""
- if node_id not in self._conns:
- return 0
- return self._conns[node_id].connection_delay()
+ with self._lock:
+ if node_id not in self._conns:
+ return 0
+ return self._conns[node_id].connection_delay()
def is_ready(self, node_id, metadata_priority=True):
"""Check whether a node is ready to send more requests.
@@ -483,10 +489,11 @@ class KafkaClient(object):
return True
def _can_send_request(self, node_id):
- if node_id not in self._conns:
- return False
- conn = self._conns[node_id]
- return conn.connected() and conn.can_send_more()
+ with self._lock:
+ if node_id not in self._conns:
+ return False
+ conn = self._conns[node_id]
+ return conn.connected() and conn.can_send_more()
def send(self, node_id, request):
"""Send a request to a specific node.
@@ -501,12 +508,13 @@ class KafkaClient(object):
Returns:
Future: resolves to Response struct or Error
"""
- if not self._maybe_connect(node_id):
- return Future().failure(Errors.NodeNotReadyError(node_id))
+ with self._lock:
+ if not self._maybe_connect(node_id):
+ return Future().failure(Errors.NodeNotReadyError(node_id))
- return self._conns[node_id].send(request)
+ return self._conns[node_id].send(request)
- def poll(self, timeout_ms=None, future=None, delayed_tasks=True):
+ def poll(self, timeout_ms=None, future=None):
"""Try to read and write to sockets.
This method will also attempt to complete node connections, refresh
@@ -527,44 +535,34 @@ class KafkaClient(object):
elif timeout_ms is None:
timeout_ms = self.config['request_timeout_ms']
- responses = []
-
# Loop for futures, break after first loop if None
+ responses = []
while True:
-
- # Attempt to complete pending connections
- for node_id in list(self._connecting):
- self._maybe_connect(node_id)
-
- # Send a metadata request if needed
- metadata_timeout_ms = self._maybe_refresh_metadata()
-
- # Send scheduled tasks
- if delayed_tasks:
- for task, task_future in self._delayed_tasks.pop_ready():
- try:
- result = task()
- except Exception as e:
- log.error("Task %s failed: %s", task, e)
- task_future.failure(e)
- else:
- task_future.success(result)
-
- # If we got a future that is already done, don't block in _poll
- if future is not None and future.is_done:
- timeout = 0
- else:
- idle_connection_timeout_ms = self._idle_expiry_manager.next_check_ms()
- timeout = min(
- timeout_ms,
- metadata_timeout_ms,
- self._delayed_tasks.next_at() * 1000,
- idle_connection_timeout_ms,
- self.config['request_timeout_ms'])
- timeout = max(0, timeout / 1000.0) # avoid negative timeouts
-
- self._poll(timeout)
-
+ with self._lock:
+
+ # Attempt to complete pending connections
+ for node_id in list(self._connecting):
+ self._maybe_connect(node_id)
+
+ # Send a metadata request if needed
+ metadata_timeout_ms = self._maybe_refresh_metadata()
+
+ # If we got a future that is already done, don't block in _poll
+ if future is not None and future.is_done:
+ timeout = 0
+ else:
+ idle_connection_timeout_ms = self._idle_expiry_manager.next_check_ms()
+ timeout = min(
+ timeout_ms,
+ metadata_timeout_ms,
+ idle_connection_timeout_ms,
+ self.config['request_timeout_ms'])
+ timeout = max(0, timeout / 1000) # avoid negative timeouts
+
+ self._poll(timeout)
+
+ # called without the lock to avoid deadlock potential
+ # if handlers need to acquire locks
responses.extend(self._fire_pending_completed_requests())
# If all we had was a timeout (future is None) - only do one poll
@@ -646,12 +644,13 @@ class KafkaClient(object):
Returns:
int: pending in-flight requests for the node, or all nodes if None
"""
- if node_id is not None:
- if node_id not in self._conns:
- return 0
- return len(self._conns[node_id].in_flight_requests)
- else:
- return sum([len(conn.in_flight_requests) for conn in self._conns.values()])
+ with self._lock:
+ if node_id is not None:
+ if node_id not in self._conns:
+ return 0
+ return len(self._conns[node_id].in_flight_requests)
+ else:
+ return sum([len(conn.in_flight_requests) for conn in self._conns.values()])
def _fire_pending_completed_requests(self):
responses = []
@@ -672,37 +671,38 @@ class KafkaClient(object):
Returns:
node_id or None if no suitable node was found
"""
- nodes = [broker.nodeId for broker in self.cluster.brokers()]
- random.shuffle(nodes)
+ with self._lock:
+ nodes = [broker.nodeId for broker in self.cluster.brokers()]
+ random.shuffle(nodes)
+
+ inflight = float('inf')
+ found = None
+ for node_id in nodes:
+ conn = self._conns.get(node_id)
+ connected = conn is not None and conn.connected()
+ blacked_out = conn is not None and conn.blacked_out()
+ curr_inflight = len(conn.in_flight_requests) if conn is not None else 0
+ if connected and curr_inflight == 0:
+ # if we find an established connection
+ # with no in-flight requests, we can stop right away
+ return node_id
+ elif not blacked_out and curr_inflight < inflight:
+ # otherwise if this is the best we have found so far, record that
+ inflight = curr_inflight
+ found = node_id
+
+ if found is not None:
+ return found
+
+ # some broker versions return an empty list of broker metadata
+ # if there are no topics created yet. the bootstrap process
+ # should detect this and keep a 'bootstrap' node alive until
+ # a non-bootstrap node is connected and non-empty broker
+ # metadata is available
+ elif 'bootstrap' in self._conns:
+ return 'bootstrap'
- inflight = float('inf')
- found = None
- for node_id in nodes:
- conn = self._conns.get(node_id)
- connected = conn is not None and conn.connected()
- blacked_out = conn is not None and conn.blacked_out()
- curr_inflight = len(conn.in_flight_requests) if conn is not None else 0
- if connected and curr_inflight == 0:
- # if we find an established connection
- # with no in-flight requests, we can stop right away
- return node_id
- elif not blacked_out and curr_inflight < inflight:
- # otherwise if this is the best we have found so far, record that
- inflight = curr_inflight
- found = node_id
-
- if found is not None:
- return found
-
- # some broker versions return an empty list of broker metadata
- # if there are no topics created yet. the bootstrap process
- # should detect this and keep a 'bootstrap' node alive until
- # a non-bootstrap node is connected and non-empty broker
- # metadata is available
- elif 'bootstrap' in self._conns:
- return 'bootstrap'
-
- return None
+ return None
def set_topics(self, topics):
"""Set specific topics to track for metadata.
@@ -735,7 +735,7 @@ class KafkaClient(object):
self._topics.add(topic)
return self.cluster.request_update()
- # request metadata update on disconnect and timedout
+ # This method should be locked when running multi-threaded
def _maybe_refresh_metadata(self):
"""Send a metadata request if needed.
@@ -793,34 +793,6 @@ class KafkaClient(object):
# to let us know the selected connection might be usable again.
return float('inf')
- def schedule(self, task, at):
- """Schedule a new task to be executed at the given time.
-
- This is "best-effort" scheduling and should only be used for coarse
- synchronization. A task cannot be scheduled for multiple times
- simultaneously; any previously scheduled instance of the same task
- will be cancelled.
-
- Arguments:
- task (callable): task to be scheduled
- at (float or int): epoch seconds when task should run
-
- Returns:
- Future: resolves to result of task call, or exception if raised
- """
- return self._delayed_tasks.add(task, at)
-
- def unschedule(self, task):
- """Unschedule a task.
-
- This will remove all instances of the task from the task queue.
- This is a no-op if the task is not scheduled.
-
- Arguments:
- task (callable): task to be unscheduled
- """
- self._delayed_tasks.remove(task)
-
def check_version(self, node_id=None, timeout=2, strict=False):
"""Attempt to guess the version of a Kafka broker.
@@ -890,79 +862,6 @@ class KafkaClient(object):
self.close(node_id=conn_id)
-class DelayedTaskQueue(object):
- # see https://docs.python.org/2/library/heapq.html
- def __init__(self):
- self._tasks = [] # list of entries arranged in a heap
- self._task_map = {} # mapping of tasks to entries
- self._counter = itertools.count() # unique sequence count
-
- def add(self, task, at):
- """Add a task to run at a later time.
-
- Arguments:
- task: can be anything, but generally a callable
- at (float or int): epoch seconds to schedule task
-
- Returns:
- Future: a future that will be returned with the task when ready
- """
- if task in self._task_map:
- self.remove(task)
- count = next(self._counter)
- future = Future()
- entry = [at, count, (task, future)]
- self._task_map[task] = entry
- heapq.heappush(self._tasks, entry)
- return future
-
- def remove(self, task):
- """Remove a previously scheduled task.
-
- Raises:
- KeyError: if task is not found
- """
- entry = self._task_map.pop(task)
- task, future = entry[-1]
- future.failure(Errors.Cancelled)
- entry[-1] = 'REMOVED'
-
- def _drop_removed(self):
- while self._tasks and self._tasks[0][-1] is 'REMOVED':
- at, count, task = heapq.heappop(self._tasks)
-
- def _pop_next(self):
- self._drop_removed()
- if not self._tasks:
- raise KeyError('pop from an empty DelayedTaskQueue')
- _, _, maybe_task = heapq.heappop(self._tasks)
- if maybe_task is 'REMOVED':
- raise ValueError('popped a removed tasks from queue - bug')
- else:
- task, future = maybe_task
- del self._task_map[task]
- return (task, future)
-
- def next_at(self):
- """Number of seconds until next task is ready."""
- self._drop_removed()
- if not self._tasks:
- return float('inf')
- else:
- return max(self._tasks[0][0] - time.time(), 0)
-
- def pop_ready(self):
- """Pop and return a list of all ready (task, future) tuples"""
- ready_tasks = []
- while self._tasks and self._tasks[0][0] < time.time():
- try:
- task = self._pop_next()
- except KeyError:
- break
- ready_tasks.append(task)
- return ready_tasks
-
-
# OrderedDict requires python2.7+
try:
from collections import OrderedDict
diff --git a/kafka/conn.py b/kafka/conn.py
index 68f2659..2b1008b 100644
--- a/kafka/conn.py
+++ b/kafka/conn.py
@@ -685,7 +685,7 @@ class BrokerConnection(object):
def recv(self):
"""Non-blocking network receive.
- Return list of (response, future)
+ Return list of (response, future) tuples
"""
if not self.connected() and not self.state is ConnectionStates.AUTHENTICATING:
log.warning('%s cannot recv: socket not connected', self)
diff --git a/kafka/consumer/fetcher.py b/kafka/consumer/fetcher.py
index f9251fd..debe86b 100644
--- a/kafka/consumer/fetcher.py
+++ b/kafka/consumer/fetcher.py
@@ -674,6 +674,9 @@ class Fetcher(six.Iterator):
fetchable[node_id][partition.topic].append(partition_info)
log.debug("Adding fetch request for partition %s at offset %d",
partition, position)
+ else:
+ log.log(0, "Skipping fetch for partition %s because there is an inflight request to node %s",
+ partition, node_id)
if self.config['api_version'] >= (0, 11, 0):
version = 4
diff --git a/kafka/consumer/group.py b/kafka/consumer/group.py
index 78686a4..7c345e7 100644
--- a/kafka/consumer/group.py
+++ b/kafka/consumer/group.py
@@ -1,4 +1,4 @@
-from __future__ import absolute_import
+from __future__ import absolute_import, division
import copy
import logging
@@ -125,19 +125,34 @@ class KafkaConsumer(six.Iterator):
distribute partition ownership amongst consumer instances when
group management is used.
Default: [RangePartitionAssignor, RoundRobinPartitionAssignor]
+ max_poll_records (int): The maximum number of records returned in a
+ single call to :meth:`~kafka.KafkaConsumer.poll`. Default: 500
+ max_poll_interval_ms (int): The maximum delay between invocations of
+ :meth:`~kafka.KafkaConsumer.poll` when using consumer group
+ management. This places an upper bound on the amount of time that
+ the consumer can be idle before fetching more records. If
+ :meth:`~kafka.KafkaConsumer.poll` is not called before expiration
+ of this timeout, then the consumer is considered failed and the
+ group will rebalance in order to reassign the partitions to another
+ member. Default 300000
+ session_timeout_ms (int): The timeout used to detect failures when
+ using Kafka's group management facilities. The consumer sends
+ periodic heartbeats to indicate its liveness to the broker. If
+ no heartbeats are received by the broker before the expiration of
+ this session timeout, then the broker will remove this consumer
+ from the group and initiate a rebalance. Note that the value must
+ be in the allowable range as configured in the broker configuration
+ by group.min.session.timeout.ms and group.max.session.timeout.ms.
+ Default: 10000
heartbeat_interval_ms (int): The expected time in milliseconds
between heartbeats to the consumer coordinator when using
- Kafka's group management feature. Heartbeats are used to ensure
+ Kafka's group management facilities. Heartbeats are used to ensure
that the consumer's session stays active and to facilitate
rebalancing when new consumers join or leave the group. The
value must be set lower than session_timeout_ms, but typically
should be set no higher than 1/3 of that value. It can be
adjusted even lower to control the expected time for normal
rebalances. Default: 3000
- session_timeout_ms (int): The timeout used to detect failures when
- using Kafka's group management facilities. Default: 30000
- max_poll_records (int): The maximum number of records returned in a
- single call to :meth:`~kafka.KafkaConsumer.poll`. Default: 500
receive_buffer_bytes (int): The size of the TCP receive buffer
(SO_RCVBUF) to use when reading data. Default: None (relies on
system defaults). The java client defaults to 32768.
@@ -236,7 +251,7 @@ class KafkaConsumer(six.Iterator):
'fetch_min_bytes': 1,
'fetch_max_bytes': 52428800,
'max_partition_fetch_bytes': 1 * 1024 * 1024,
- 'request_timeout_ms': 40 * 1000,
+ 'request_timeout_ms': 305000, # chosen to be higher than the default of max_poll_interval_ms
'retry_backoff_ms': 100,
'reconnect_backoff_ms': 50,
'reconnect_backoff_max_ms': 1000,
@@ -248,9 +263,10 @@ class KafkaConsumer(six.Iterator):
'check_crcs': True,
'metadata_max_age_ms': 5 * 60 * 1000,
'partition_assignment_strategy': (RangePartitionAssignor, RoundRobinPartitionAssignor),
- 'heartbeat_interval_ms': 3000,
- 'session_timeout_ms': 30000,
'max_poll_records': 500,
+ 'max_poll_interval_ms': 300000,
+ 'session_timeout_ms': 10000,
+ 'heartbeat_interval_ms': 3000,
'receive_buffer_bytes': None,
'send_buffer_bytes': None,
'socket_options': [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)],
@@ -278,15 +294,16 @@ class KafkaConsumer(six.Iterator):
'sasl_plain_password': None,
'sasl_kerberos_service_name': 'kafka'
}
+ DEFAULT_SESSION_TIMEOUT_MS_0_9 = 30000
def __init__(self, *topics, **configs):
- self.config = copy.copy(self.DEFAULT_CONFIG)
- for key in self.config:
- if key in configs:
- self.config[key] = configs.pop(key)
-
# Only check for extra config keys in top-level class
- assert not configs, 'Unrecognized configs: %s' % configs
+ extra_configs = set(configs).difference(self.DEFAULT_CONFIG)
+ if extra_configs:
+ raise KafkaConfigurationError("Unrecognized configs: %s" % extra_configs)
+
+ self.config = copy.copy(self.DEFAULT_CONFIG)
+ self.config.update(configs)
deprecated = {'smallest': 'earliest', 'largest': 'latest'}
if self.config['auto_offset_reset'] in deprecated:
@@ -296,12 +313,7 @@ class KafkaConsumer(six.Iterator):
self.config['auto_offset_reset'] = new_config
request_timeout_ms = self.config['request_timeout_ms']
- session_timeout_ms = self.config['session_timeout_ms']
fetch_max_wait_ms = self.config['fetch_max_wait_ms']
- if request_timeout_ms <= session_timeout_ms:
- raise KafkaConfigurationError(
- "Request timeout (%s) must be larger than session timeout (%s)" %
- (request_timeout_ms, session_timeout_ms))
if request_timeout_ms <= fetch_max_wait_ms:
raise KafkaConfigurationError("Request timeout (%s) must be larger than fetch-max-wait-ms (%s)" %
(request_timeout_ms, fetch_max_wait_ms))
@@ -330,6 +342,25 @@ class KafkaConsumer(six.Iterator):
if self.config['api_version'] is None:
self.config['api_version'] = self._client.config['api_version']
+ # Coordinator configurations are different for older brokers
+ # max_poll_interval_ms is not supported directly -- it must the be
+ # the same as session_timeout_ms. If the user provides one of them,
+ # use it for both. Otherwise use the old default of 30secs
+ if self.config['api_version'] < (0, 10, 1):
+ if 'session_timeout_ms' not in configs:
+ if 'max_poll_interval_ms' in configs:
+ self.config['session_timeout_ms'] = configs['max_poll_interval_ms']
+ else:
+ self.config['session_timeout_ms'] = self.DEFAULT_SESSION_TIMEOUT_MS_0_9
+ if 'max_poll_interval_ms' not in configs:
+ self.config['max_poll_interval_ms'] = self.config['session_timeout_ms']
+
+ if self.config['group_id'] is not None:
+ if self.config['request_timeout_ms'] <= self.config['session_timeout_ms']:
+ raise KafkaConfigurationError(
+ "Request timeout (%s) must be larger than session timeout (%s)" %
+ (self.config['request_timeout_ms'], self.config['session_timeout_ms']))
+
self._subscription = SubscriptionState(self.config['auto_offset_reset'])
self._fetcher = Fetcher(
self._client, self._subscription, self._metrics, **self.config)
@@ -587,12 +618,7 @@ class KafkaConsumer(six.Iterator):
Returns:
dict: Map of topic to list of records (may be empty).
"""
- if self._use_consumer_group():
- self._coordinator.ensure_active_group()
-
- # 0.8.2 brokers support kafka-backed offset storage via group coordinator
- elif self.config['group_id'] is not None and self.config['api_version'] >= (0, 8, 2):
- self._coordinator.ensure_coordinator_ready()
+ self._coordinator.poll()
# Fetch positions if we have partitions we're subscribed to that we
# don't know the offset for
@@ -614,6 +640,7 @@ class KafkaConsumer(six.Iterator):
# Send any new fetches (won't resend pending fetches)
self._fetcher.send_fetches()
+ timeout_ms = min(timeout_ms, self._coordinator.time_to_next_poll())
self._client.poll(timeout_ms=timeout_ms)
records, _ = self._fetcher.fetched_records(max_records)
return records
@@ -1014,13 +1041,7 @@ class KafkaConsumer(six.Iterator):
assert self.assignment() or self.subscription() is not None, 'No topic subscription or manual partition assignment'
while time.time() < self._consumer_timeout:
- if self._use_consumer_group():
- self._coordinator.ensure_coordinator_ready()
- self._coordinator.ensure_active_group()
-
- # 0.8.2 brokers support kafka-backed offset storage via group coordinator
- elif self.config['group_id'] is not None and self.config['api_version'] >= (0, 8, 2):
- self._coordinator.ensure_coordinator_ready()
+ self._coordinator.poll()
# Fetch offsets for any subscribed partitions that we arent tracking yet
if not self._subscription.has_all_fetch_positions():
@@ -1068,19 +1089,8 @@ class KafkaConsumer(six.Iterator):
def _next_timeout(self):
timeout = min(self._consumer_timeout,
- self._client._delayed_tasks.next_at() + time.time(),
- self._client.cluster.ttl() / 1000.0 + time.time())
-
- # Although the delayed_tasks timeout above should cover processing
- # HeartbeatRequests, it is still possible that HeartbeatResponses
- # are left unprocessed during a long _fetcher iteration without
- # an intermediate poll(). And because tasks are responsible for
- # rescheduling themselves, an unprocessed response will prevent
- # the next heartbeat from being sent. This check should help
- # avoid that.
- if self._use_consumer_group():
- heartbeat = time.time() + self._coordinator.heartbeat.ttl()
- timeout = min(timeout, heartbeat)
+ self._client.cluster.ttl() / 1000.0 + time.time(),
+ self._coordinator.time_to_next_poll() + time.time())
return timeout
def __iter__(self): # pylint: disable=non-iterator-returned
diff --git a/kafka/coordinator/base.py b/kafka/coordinator/base.py
index a3055da..b16c1e1 100644
--- a/kafka/coordinator/base.py
+++ b/kafka/coordinator/base.py
@@ -3,6 +3,8 @@ from __future__ import absolute_import, division
import abc
import copy
import logging
+import sys
+import threading
import time
import weakref
@@ -20,6 +22,28 @@ from ..protocol.group import (HeartbeatRequest, JoinGroupRequest,
log = logging.getLogger('kafka.coordinator')
+class MemberState(object):
+ UNJOINED = '<unjoined>' # the client is not part of a group
+ REBALANCING = '<rebalancing>' # the client has begun rebalancing
+ STABLE = '<stable>' # the client has joined and is sending heartbeats
+
+
+class Generation(object):
+ def __init__(self, generation_id, member_id, protocol):
+ self.generation_id = generation_id
+ self.member_id = member_id
+ self.protocol = protocol
+
+Generation.NO_GENERATION = Generation(
+ OffsetCommitRequest[2].DEFAULT_GENERATION_ID,
+ JoinGroupRequest[0].UNKNOWN_MEMBER_ID,
+ None)
+
+
+class UnjoinedGroupException(Errors.KafkaError):
+ retriable = True
+
+
class BaseCoordinator(object):
"""
BaseCoordinator implements group management for a single group member
@@ -47,14 +71,23 @@ class BaseCoordinator(object):
:meth:`.group_protocols` and the format of the state assignment provided by
the leader in :meth:`._perform_assignment` and which becomes available to
members in :meth:`._on_join_complete`.
+
+ Note on locking: this class shares state between the caller and a background
+ thread which is used for sending heartbeats after the client has joined the
+ group. All mutable state as well as state transitions are protected with the
+ class's monitor. Generally this means acquiring the lock before reading or
+ writing the state of the group (e.g. generation, member_id) and holding the
+ lock when sending a request that affects the state of the group
+ (e.g. JoinGroup, LeaveGroup).
"""
DEFAULT_CONFIG = {
'group_id': 'kafka-python-default-group',
- 'session_timeout_ms': 30000,
+ 'session_timeout_ms': 10000,
'heartbeat_interval_ms': 3000,
+ 'max_poll_interval_ms': 300000,
'retry_backoff_ms': 100,
- 'api_version': (0, 9),
+ 'api_version': (0, 10, 1),
'metric_group_prefix': '',
}
@@ -83,27 +116,31 @@ class BaseCoordinator(object):
if key in configs:
self.config[key] = configs[key]
+ if self.config['api_version'] < (0, 10, 1):
+ if self.config['max_poll_interval_ms'] != self.config['session_timeout_ms']:
+ raise Errors.KafkaConfigurationError("Broker version %s does not support "
+ "different values for max_poll_interval_ms "
+ "and session_timeout_ms")
+
self._client = client
- self.generation = OffsetCommitRequest[2].DEFAULT_GENERATION_ID
- self.member_id = JoinGroupRequest[0].UNKNOWN_MEMBER_ID
self.group_id = self.config['group_id']
+ self.heartbeat = Heartbeat(**self.config)
+ self._heartbeat_thread = None
+ self._lock = threading.Condition()
+ self.rejoin_needed = True
+ self.rejoining = False # renamed / complement of java needsJoinPrepare
+ self.state = MemberState.UNJOINED
+ self.join_future = None
self.coordinator_id = None
self._find_coordinator_future = None
- self.rejoin_needed = True
- self.rejoining = False
- self.heartbeat = Heartbeat(**self.config)
- self.heartbeat_task = HeartbeatTask(weakref.proxy(self))
+ self._generation = Generation.NO_GENERATION
self.sensors = GroupCoordinatorMetrics(self.heartbeat, metrics,
self.config['metric_group_prefix'])
- def __del__(self):
- if hasattr(self, 'heartbeat_task') and self.heartbeat_task:
- self.heartbeat_task.disable()
-
@abc.abstractmethod
def protocol_type(self):
"""
- Unique identifier for the class of protocols implements
+ Unique identifier for the class of supported protocols
(e.g. "consumer" or "connect").
Returns:
@@ -187,42 +224,51 @@ class BaseCoordinator(object):
Returns:
bool: True if the coordinator is unknown
"""
- if self.coordinator_id is None:
- return True
+ return self.coordinator() is None
- if self._client.is_disconnected(self.coordinator_id):
- self.coordinator_dead('Node Disconnected')
- return True
+ def coordinator(self):
+ """Get the current coordinator
- return False
+ Returns: the current coordinator id or None if it is unknown
+ """
+ with self._lock:
+ if self.coordinator_id is None:
+ return None
+ elif self._client.is_disconnected(self.coordinator_id):
+ self.coordinator_dead('Node Disconnected')
+ return None
+ else:
+ return self.coordinator_id
def ensure_coordinator_ready(self):
"""Block until the coordinator for this group is known
(and we have an active connection -- java client uses unsent queue).
"""
- while self.coordinator_unknown():
- # Prior to 0.8.2 there was no group coordinator
- # so we will just pick a node at random and treat
- # it as the "coordinator"
- if self.config['api_version'] < (0, 8, 2):
- self.coordinator_id = self._client.least_loaded_node()
- if self.coordinator_id is not None:
- self._client.ready(self.coordinator_id)
- continue
-
- future = self.lookup_coordinator()
- self._client.poll(future=future)
-
- if future.failed():
- if future.retriable():
- if getattr(future.exception, 'invalid_metadata', False):
- log.debug('Requesting metadata for group coordinator request: %s', future.exception)
- metadata_update = self._client.cluster.request_update()
- self._client.poll(future=metadata_update)
+ with self._lock:
+ while self.coordinator_unknown():
+
+ # Prior to 0.8.2 there was no group coordinator
+ # so we will just pick a node at random and treat
+ # it as the "coordinator"
+ if self.config['api_version'] < (0, 8, 2):
+ self.coordinator_id = self._client.least_loaded_node()
+ if self.coordinator_id is not None:
+ self._client.ready(self.coordinator_id)
+ continue
+
+ future = self.lookup_coordinator()
+ self._client.poll(future=future)
+
+ if future.failed():
+ if future.retriable():
+ if getattr(future.exception, 'invalid_metadata', False):
+ log.debug('Requesting metadata for group coordinator request: %s', future.exception)
+ metadata_update = self._client.cluster.request_update()
+ self._client.poll(future=metadata_update)
+ else:
+ time.sleep(self.config['retry_backoff_ms'] / 1000)
else:
- time.sleep(self.config['retry_backoff_ms'] / 1000)
- else:
- raise future.exception # pylint: disable-msg=raising-bad-type
+ raise future.exception # pylint: disable-msg=raising-bad-type
def _reset_find_coordinator_future(self, result):
self._find_coordinator_future = None
@@ -248,52 +294,116 @@ class BaseCoordinator(object):
"""
return self.rejoin_needed
+ def poll_heartbeat(self):
+ """
+ Check the status of the heartbeat thread (if it is active) and indicate
+ the liveness of the client. This must be called periodically after
+ joining with :meth:`.ensure_active_group` to ensure that the member stays
+ in the group. If an interval of time longer than the provided rebalance
+ timeout (max_poll_interval_ms) expires without calling this method, then
+ the client will proactively leave the group.
+
+ Raises: RuntimeError for unexpected errors raised from the heartbeat thread
+ """
+ with self._lock:
+ if self._heartbeat_thread is not None:
+ if self._heartbeat_thread.failed:
+ # set the heartbeat thread to None and raise an exception.
+ # If the user catches it, the next call to ensure_active_group()
+ # will spawn a new heartbeat thread.
+ cause = self._heartbeat_thread.failed
+ self._heartbeat_thread = None
+ raise cause # pylint: disable-msg=raising-bad-type
+ self.heartbeat.poll()
+
+ def time_to_next_heartbeat(self):
+ with self._lock:
+ # if we have not joined the group, we don't need to send heartbeats
+ if self.state is MemberState.UNJOINED:
+ return sys.maxsize
+ return self.heartbeat.time_to_next_heartbeat()
+
+ def _handle_join_success(self, member_assignment_bytes):
+ with self._lock:
+ log.info("Successfully joined group %s with generation %s",
+ self.group_id, self._generation.generation_id)
+ self.join_future = None
+ self.state = MemberState.STABLE
+ self.rejoining = False
+ self._heartbeat_thread.enable()
+ self._on_join_complete(self._generation.generation_id,
+ self._generation.member_id,
+ self._generation.protocol,
+ member_assignment_bytes)
+
+ def _handle_join_failure(self, _):
+ with self._lock:
+ self.join_future = None
+ self.state = MemberState.UNJOINED
+
def ensure_active_group(self):
"""Ensure that the group is active (i.e. joined and synced)"""
- # always ensure that the coordinator is ready because we may have been
- # disconnected when sending heartbeats and does not necessarily require
- # us to rejoin the group.
- self.ensure_coordinator_ready()
-
- if not self.need_rejoin():
- return
-
- if not self.rejoining:
- self._on_join_prepare(self.generation, self.member_id)
- self.rejoining = True
-
- while self.need_rejoin():
- self.ensure_coordinator_ready()
-
- # ensure that there are no pending requests to the coordinator.
- # This is important in particular to avoid resending a pending
- # JoinGroup request.
- while not self.coordinator_unknown():
- if not self._client.in_flight_request_count(self.coordinator_id):
- break
- self._client.poll(delayed_tasks=False)
- else:
- continue
-
- future = self._send_join_group_request()
- self._client.poll(future=future)
+ with self._lock:
+ if not self.need_rejoin():
+ return
- if future.succeeded():
- member_assignment_bytes = future.value
- self._on_join_complete(self.generation, self.member_id,
- self.protocol, member_assignment_bytes)
- self.rejoining = False
- self.heartbeat_task.reset()
- else:
- assert future.failed()
- exception = future.exception
- if isinstance(exception, (Errors.UnknownMemberIdError,
- Errors.RebalanceInProgressError,
- Errors.IllegalGenerationError)):
+ # call on_join_prepare if needed. We set a flag to make sure that
+ # we do not call it a second time if the client is woken up before
+ # a pending rebalance completes.
+ if not self.rejoining:
+ self._on_join_prepare(self._generation.generation_id,
+ self._generation.member_id)
+ self.rejoining = True
+
+ if self._heartbeat_thread is None:
+ log.debug('Starting new heartbeat thread')
+ self._heartbeat_thread = HeartbeatThread(weakref.proxy(self))
+ self._heartbeat_thread.daemon = True
+ self._heartbeat_thread.start()
+
+ while self.need_rejoin():
+ self.ensure_coordinator_ready()
+
+ # ensure that there are no pending requests to the coordinator.
+ # This is important in particular to avoid resending a pending
+ # JoinGroup request.
+ while not self.coordinator_unknown():
+ if not self._client.in_flight_request_count(self.coordinator_id):
+ break
+ self._client.poll()
+ else:
continue
- elif not future.retriable():
- raise exception # pylint: disable-msg=raising-bad-type
- time.sleep(self.config['retry_backoff_ms'] / 1000)
+
+ # we store the join future in case we are woken up by the user
+ # after beginning the rebalance in the call to poll below.
+ # This ensures that we do not mistakenly attempt to rejoin
+ # before the pending rebalance has completed.
+ if self.join_future is None:
+ self.state = MemberState.REBALANCING
+ self.join_future = self._send_join_group_request()
+
+ # handle join completion in the callback so that the
+ # callback will be invoked even if the consumer is woken up
+ # before finishing the rebalance
+ self.join_future.add_callback(self._handle_join_success)
+
+ # we handle failures below after the request finishes.
+ # If the join completes after having been woken up, the
+ # exception is ignored and we will rejoin
+ self.join_future.add_errback(self._handle_join_failure)
+
+ future = self.join_future
+ self._client.poll(future=future)
+
+ if future.failed():
+ exception = future.exception
+ if isinstance(exception, (Errors.UnknownMemberIdError,
+ Errors.RebalanceInProgressError,
+ Errors.IllegalGenerationError)):
+ continue
+ elif not future.retriable():
+ raise exception # pylint: disable-msg=raising-bad-type
+ time.sleep(self.config['retry_backoff_ms'] / 1000)
def _send_join_group_request(self):
"""Join the group and return the assignment for the next generation.
@@ -315,14 +425,35 @@ class BaseCoordinator(object):
# send a join group request to the coordinator
log.info("(Re-)joining group %s", self.group_id)
- request = JoinGroupRequest[0](
- self.group_id,
- self.config['session_timeout_ms'],
- self.member_id,
- self.protocol_type(),
- [(protocol,
- metadata if isinstance(metadata, bytes) else metadata.encode())
- for protocol, metadata in self.group_protocols()])
+ member_metadata = [
+ (protocol, metadata if isinstance(metadata, bytes) else metadata.encode())
+ for protocol, metadata in self.group_protocols()
+ ]
+ if self.config['api_version'] < (0, 9):
+ raise Errors.KafkaError('JoinGroupRequest api requires 0.9+ brokers')
+ elif (0, 9) <= self.config['api_version'] < (0, 10, 1):
+ request = JoinGroupRequest[0](
+ self.group_id,
+ self.config['session_timeout_ms'],
+ self._generation.member_id,
+ self.protocol_type(),
+ member_metadata)
+ elif (0, 10, 1) <= self.config['api_version'] < (0, 11, 0):
+ request = JoinGroupRequest[1](
+ self.group_id,
+ self.config['session_timeout_ms'],
+ self.config['max_poll_interval_ms'],
+ self._generation.member_id,
+ self.protocol_type(),
+ member_metadata)
+ else:
+ request = JoinGroupRequest[2](
+ self.group_id,
+ self.config['session_timeout_ms'],
+ self.config['max_poll_interval_ms'],
+ self._generation.member_id,
+ self.protocol_type(),
+ member_metadata)
# create the request for the coordinator
log.debug("Sending JoinGroup (%s) to coordinator %s", request, self.coordinator_id)
@@ -348,19 +479,25 @@ class BaseCoordinator(object):
if error_type is Errors.NoError:
log.debug("Received successful JoinGroup response for group %s: %s",
self.group_id, response)
- self.member_id = response.member_id
- self.generation = response.generation_id
- self.rejoin_needed = False
- self.protocol = response.group_protocol
- log.info("Joined group '%s' (generation %s) with member_id %s",
- self.group_id, self.generation, self.member_id)
self.sensors.join_latency.record((time.time() - send_time) * 1000)
- if response.leader_id == response.member_id:
- log.info("Elected group leader -- performing partition"
- " assignments using %s", self.protocol)
- self._on_join_leader(response).chain(future)
- else:
- self._on_join_follower().chain(future)
+ with self._lock:
+ if self.state is not MemberState.REBALANCING:
+ # if the consumer was woken up before a rebalance completes,
+ # we may have already left the group. In this case, we do
+ # not want to continue with the sync group.
+ future.failure(UnjoinedGroupException())
+ else:
+ self._generation = Generation(response.generation_id,
+ response.member_id,
+ response.group_protocol)
+ self.rejoin_needed = False
+
+ if response.leader_id == response.member_id:
+ log.info("Elected group leader -- performing partition"
+ " assignments using %s", self._generation.protocol)
+ self._on_join_leader(response).chain(future)
+ else:
+ self._on_join_follower().chain(future)
elif error_type is Errors.GroupLoadInProgressError:
log.debug("Attempt to join group %s rejected since coordinator %s"
@@ -369,8 +506,8 @@ class BaseCoordinator(object):
future.failure(error_type(response))
elif error_type is Errors.UnknownMemberIdError:
# reset the member id and retry immediately
- error = error_type(self.member_id)
- self.member_id = JoinGroupRequest[0].UNKNOWN_MEMBER_ID
+ error = error_type(self._generation.member_id)
+ self.reset_generation()
log.debug("Attempt to join group %s failed due to unknown member id",
self.group_id)
future.failure(error)
@@ -400,10 +537,11 @@ class BaseCoordinator(object):
def _on_join_follower(self):
# send follower's sync group with an empty assignment
- request = SyncGroupRequest[0](
+ version = 0 if self.config['api_version'] < (0, 11, 0) else 1
+ request = SyncGroupRequest[version](
self.group_id,
- self.generation,
- self.member_id,
+ self._generation.generation_id,
+ self._generation.member_id,
{})
log.debug("Sending follower SyncGroup for group %s to coordinator %s: %s",
self.group_id, self.coordinator_id, request)
@@ -427,10 +565,11 @@ class BaseCoordinator(object):
except Exception as e:
return Future().failure(e)
- request = SyncGroupRequest[0](
+ version = 0 if self.config['api_version'] < (0, 11, 0) else 1
+ request = SyncGroupRequest[version](
self.group_id,
- self.generation,
- self.member_id,
+ self._generation.generation_id,
+ self._generation.member_id,
[(member_id,
assignment if isinstance(assignment, bytes) else assignment.encode())
for member_id, assignment in six.iteritems(group_assignment)])
@@ -460,14 +599,12 @@ class BaseCoordinator(object):
def _handle_sync_group_response(self, future, send_time, response):
error_type = Errors.for_code(response.error_code)
if error_type is Errors.NoError:
- log.info("Successfully joined group %s with generation %s",
- self.group_id, self.generation)
self.sensors.sync_latency.record((time.time() - send_time) * 1000)
future.success(response.member_assignment)
return
# Always rejoin on error
- self.rejoin_needed = True
+ self.request_rejoin()
if error_type is Errors.GroupAuthorizationFailedError:
future.failure(error_type(self.group_id))
elif error_type is Errors.RebalanceInProgressError:
@@ -478,7 +615,7 @@ class BaseCoordinator(object):
Errors.IllegalGenerationError):
error = error_type()
log.debug("SyncGroup for group %s failed due to %s", self.group_id, error)
- self.member_id = JoinGroupRequest[0].UNKNOWN_MEMBER_ID
+ self.reset_generation()
future.failure(error)
elif error_type in (Errors.GroupCoordinatorNotAvailableError,
Errors.NotCoordinatorForGroupError):
@@ -516,30 +653,24 @@ class BaseCoordinator(object):
def _handle_group_coordinator_response(self, future, response):
log.debug("Received group coordinator response %s", response)
- if not self.coordinator_unknown():
- # We already found the coordinator, so ignore the request
- log.debug("Coordinator already known -- ignoring metadata response")
- future.success(self.coordinator_id)
- return
error_type = Errors.for_code(response.error_code)
if error_type is Errors.NoError:
- ok = self._client.cluster.add_group_coordinator(self.group_id, response)
- if not ok:
- # This could happen if coordinator metadata is different
- # than broker metadata
- future.failure(Errors.IllegalStateError())
- return
-
- self.coordinator_id = response.coordinator_id
- log.info("Discovered coordinator %s for group %s",
- self.coordinator_id, self.group_id)
- self._client.ready(self.coordinator_id)
-
- # start sending heartbeats only if we have a valid generation
- if self.generation > 0:
- self.heartbeat_task.reset()
+ with self._lock:
+ ok = self._client.cluster.add_group_coordinator(self.group_id, response)
+ if not ok:
+ # This could happen if coordinator metadata is different
+ # than broker metadata
+ future.failure(Errors.IllegalStateError())
+ return
+
+ self.coordinator_id = response.coordinator_id
+ log.info("Discovered coordinator %s for group %s",
+ self.coordinator_id, self.group_id)
+ self._client.ready(self.coordinator_id)
+ self.heartbeat.reset_timeouts()
future.success(self.coordinator_id)
+
elif error_type is Errors.GroupCoordinatorNotAvailableError:
log.debug("Group Coordinator Not Available; retry")
future.failure(error_type())
@@ -549,45 +680,74 @@ class BaseCoordinator(object):
future.failure(error)
else:
error = error_type()
- log.error("Unrecognized failure in Group Coordinator Request: %s",
- error)
+ log.error("Group coordinator lookup for group %s failed: %s",
+ self.group_id, error)
future.failure(error)
def coordinator_dead(self, error):
"""Mark the current coordinator as dead."""
- if self.coordinator_id is not None:
- log.warning("Marking the coordinator dead (node %s) for group %s: %s.",
- self.coordinator_id, self.group_id, error)
- self.coordinator_id = None
+ with self._lock:
+ if self.coordinator_id is not None:
+ log.warning("Marking the coordinator dead (node %s) for group %s: %s.",
+ self.coordinator_id, self.group_id, error)
+ self.coordinator_id = None
+
+ def generation(self):
+ """Get the current generation state if the group is stable.
+
+ Returns: the current generation or None if the group is unjoined/rebalancing
+ """
+ with self._lock:
+ if self.state is not MemberState.STABLE:
+ return None
+ return self._generation
+
+ def reset_generation(self):
+ """Reset the generation and memberId because we have fallen out of the group."""
+ with self._lock:
+ self._generation = Generation.NO_GENERATION
+ self.rejoin_needed = True
+ self.state = MemberState.UNJOINED
+
+ def request_rejoin(self):
+ self.rejoin_needed = True
def close(self):
"""Close the coordinator, leave the current group,
and reset local generation / member_id"""
- try:
- self._client.unschedule(self.heartbeat_task)
- except KeyError:
- pass
-
- if not self.coordinator_unknown() and self.generation > 0:
- # this is a minimal effort attempt to leave the group. we do not
- # attempt any resending if the request fails or times out.
- log.info('Leaving consumer group (%s).', self.group_id)
- request = LeaveGroupRequest[0](self.group_id, self.member_id)
- future = self._client.send(self.coordinator_id, request)
- future.add_callback(self._handle_leave_group_response)
- future.add_errback(log.error, "LeaveGroup request failed: %s")
- self._client.poll(future=future)
-
- self.generation = OffsetCommitRequest[2].DEFAULT_GENERATION_ID
- self.member_id = JoinGroupRequest[0].UNKNOWN_MEMBER_ID
- self.rejoin_needed = True
+ with self._lock:
+ if self._heartbeat_thread is not None:
+ self._heartbeat_thread.close()
+ self._heartbeat_thread = None
+ self.maybe_leave_group()
+
+ def maybe_leave_group(self):
+ """Leave the current group and reset local generation/memberId."""
+ with self._lock:
+ if (not self.coordinator_unknown()
+ and self.state is not MemberState.UNJOINED
+ and self._generation is not Generation.NO_GENERATION):
+
+ # this is a minimal effort attempt to leave the group. we do not
+ # attempt any resending if the request fails or times out.
+ log.info('Leaving consumer group (%s).', self.group_id)
+ version = 0 if self.config['api_version'] < (0, 11, 0) else 1
+ request = LeaveGroupRequest[version](self.group_id, self._generation.member_id)
+ future = self._client.send(self.coordinator_id, request)
+ future.add_callback(self._handle_leave_group_response)
+ future.add_errback(log.error, "LeaveGroup request failed: %s")
+ self._client.poll(future=future)
+
+ self.reset_generation()
def _handle_leave_group_response(self, response):
error_type = Errors.for_code(response.error_code)
if error_type is Errors.NoError:
- log.info("LeaveGroup request succeeded")
+ log.debug("LeaveGroup request for group %s returned successfully",
+ self.group_id)
else:
- log.error("LeaveGroup request failed: %s", error_type())
+ log.error("LeaveGroup request for group %s failed with error: %s",
+ self.group_id, error_type())
def _send_heartbeat_request(self):
"""Send a heartbeat request"""
@@ -599,7 +759,10 @@ class BaseCoordinator(object):
e = Errors.NodeNotReadyError(self.coordinator_id)
return Future().failure(e)
- request = HeartbeatRequest[0](self.group_id, self.generation, self.member_id)
+ version = 0 if self.config['api_version'] < (0, 11, 0) else 1
+ request = HeartbeatRequest[version](self.group_id,
+ self._generation.generation_id,
+ self._generation.member_id)
log.debug("Heartbeat: %s[%s] %s", request.group, request.generation_id, request.member_id) # pylint: disable-msg=no-member
future = Future()
_f = self._client.send(self.coordinator_id, request)
@@ -619,24 +782,23 @@ class BaseCoordinator(object):
Errors.NotCoordinatorForGroupError):
log.warning("Heartbeat failed for group %s: coordinator (node %s)"
" is either not started or not valid", self.group_id,
- self.coordinator_id)
+ self.coordinator())
self.coordinator_dead(error_type())
future.failure(error_type())
elif error_type is Errors.RebalanceInProgressError:
log.warning("Heartbeat failed for group %s because it is"
" rebalancing", self.group_id)
- self.rejoin_needed = True
+ self.request_rejoin()
future.failure(error_type())
elif error_type is Errors.IllegalGenerationError:
log.warning("Heartbeat failed for group %s: generation id is not "
" current.", self.group_id)
- self.rejoin_needed = True
+ self.reset_generation()
future.failure(error_type())
elif error_type is Errors.UnknownMemberIdError:
log.warning("Heartbeat: local member_id was not recognized;"
" this consumer needs to re-join")
- self.member_id = JoinGroupRequest[0].UNKNOWN_MEMBER_ID
- self.rejoin_needed = True
+ self.reset_generation()
future.failure(error_type)
elif error_type is Errors.GroupAuthorizationFailedError:
error = error_type(self.group_id)
@@ -648,76 +810,6 @@ class BaseCoordinator(object):
future.failure(error)
-class HeartbeatTask(object):
- def __init__(self, coordinator):
- self._coordinator = coordinator
- self._heartbeat = coordinator.heartbeat
- self._client = coordinator._client
- self._request_in_flight = False
-
- def disable(self):
- try:
- self._client.unschedule(self)
- except KeyError:
- pass
-
- def reset(self):
- # start or restart the heartbeat task to be executed at the next chance
- self._heartbeat.reset_session_timeout()
- try:
- self._client.unschedule(self)
- except KeyError:
- pass
- if not self._request_in_flight:
- self._client.schedule(self, time.time())
-
- def __call__(self):
- if (self._coordinator.generation < 0 or
- self._coordinator.need_rejoin()):
- # no need to send the heartbeat we're not using auto-assignment
- # or if we are awaiting a rebalance
- log.info("Skipping heartbeat: no auto-assignment"
- " or waiting on rebalance")
- return
-
- if self._coordinator.coordinator_unknown():
- log.warning("Coordinator unknown during heartbeat -- will retry")
- self._handle_heartbeat_failure(Errors.GroupCoordinatorNotAvailableError())
- return
-
- if self._heartbeat.session_expired():
- # we haven't received a successful heartbeat in one session interval
- # so mark the coordinator dead
- log.error("Heartbeat session expired - marking coordinator dead")
- self._coordinator.coordinator_dead('Heartbeat session expired')
- return
-
- if not self._heartbeat.should_heartbeat():
- # we don't need to heartbeat now, so reschedule for when we do
- ttl = self._heartbeat.ttl()
- log.debug("Heartbeat task unneeded now, retrying in %s", ttl)
- self._client.schedule(self, time.time() + ttl)
- else:
- self._heartbeat.sent_heartbeat()
- self._request_in_flight = True
- future = self._coordinator._send_heartbeat_request()
- future.add_callback(self._handle_heartbeat_success)
- future.add_errback(self._handle_heartbeat_failure)
-
- def _handle_heartbeat_success(self, v):
- log.debug("Received successful heartbeat")
- self._request_in_flight = False
- self._heartbeat.received_heartbeat()
- ttl = self._heartbeat.ttl()
- self._client.schedule(self, time.time() + ttl)
-
- def _handle_heartbeat_failure(self, e):
- log.warning("Heartbeat failed (%s); retrying", e)
- self._request_in_flight = False
- etd = time.time() + self._coordinator.config['retry_backoff_ms'] / 1000
- self._client.schedule(self, etd)
-
-
class GroupCoordinatorMetrics(object):
def __init__(self, heartbeat, metrics, prefix, tags=None):
self.heartbeat = heartbeat
@@ -764,6 +856,112 @@ class GroupCoordinatorMetrics(object):
metrics.add_metric(metrics.metric_name(
'last-heartbeat-seconds-ago', self.metric_group_name,
- 'The number of seconds since the last controller heartbeat',
+ 'The number of seconds since the last controller heartbeat was sent',
tags), AnonMeasurable(
lambda _, now: (now / 1000) - self.heartbeat.last_send))
+
+
+class HeartbeatThread(threading.Thread):
+ def __init__(self, coordinator):
+ super(HeartbeatThread, self).__init__()
+ self.name = threading.current_thread().name + '-heartbeat'
+ self.coordinator = coordinator
+ self.enabled = False
+ self.closed = False
+ self.failed = None
+
+ def enable(self):
+ with self.coordinator._lock:
+ self.enabled = True
+ self.coordinator.heartbeat.reset_timeouts()
+ self.coordinator._lock.notify()
+
+ def disable(self):
+ with self.coordinator._lock:
+ self.enabled = False
+
+ def close(self):
+ with self.coordinator._lock:
+ self.closed = True
+ self.coordinator._lock.notify()
+
+ def run(self):
+ try:
+ while not self.closed:
+ self._run_once()
+
+ log.debug('Heartbeat closed!')
+
+ except RuntimeError as e:
+ log.error("Heartbeat thread for group %s failed due to unexpected error: %s",
+ self.coordinator.group_id, e)
+ self.failed = e
+
+ def _run_once(self):
+ with self.coordinator._lock:
+ if not self.enabled:
+ log.debug('Heartbeat disabled. Waiting')
+ self.coordinator._lock.wait()
+ log.debug('Heartbeat re-enabled.')
+ return
+
+ if self.coordinator.state is not MemberState.STABLE:
+ # the group is not stable (perhaps because we left the
+ # group or because the coordinator kicked us out), so
+ # disable heartbeats and wait for the main thread to rejoin.
+ log.debug('Group state is not stable, disabling heartbeats')
+ self.disable()
+ return
+
+ # TODO: When consumer.wakeup() is implemented, we need to
+ # disable here to prevent propagating an exception to this
+ # heartbeat thread
+ self.coordinator._client.poll(timeout_ms=0)
+
+ if self.coordinator.coordinator_unknown():
+ if not self.coordinator.lookup_coordinator().is_done:
+ self.coordinator._lock.wait(self.coordinator.config['retry_backoff_ms'] / 1000)
+
+ elif self.coordinator.heartbeat.session_timeout_expired():
+ # the session timeout has expired without seeing a
+ # successful heartbeat, so we should probably make sure
+ # the coordinator is still healthy.
+ log.debug('Heartbeat session expired, marking coordinator dead')
+ self.coordinator.coordinator_dead('Heartbeat session expired')
+
+ elif self.coordinator.heartbeat.poll_timeout_expired():
+ # the poll timeout has expired, which means that the
+ # foreground thread has stalled in between calls to
+ # poll(), so we explicitly leave the group.
+ log.debug('Heartbeat poll expired, leaving group')
+ self.coordinator.maybe_leave_group()
+
+ elif not self.coordinator.heartbeat.should_heartbeat():
+ # poll again after waiting for the retry backoff in case
+ # the heartbeat failed or the coordinator disconnected
+ log.debug('Not ready to heartbeat, waiting')
+ self.coordinator._lock.wait(self.coordinator.config['retry_backoff_ms'] / 1000)
+
+ else:
+ self.coordinator.heartbeat.sent_heartbeat()
+ future = self.coordinator._send_heartbeat_request()
+ future.add_callback(self._handle_heartbeat_success)
+ future.add_errback(self._handle_heartbeat_failure)
+
+ def _handle_heartbeat_success(self, result):
+ with self.coordinator._lock:
+ self.coordinator.heartbeat.received_heartbeat()
+
+ def _handle_heartbeat_failure(self, exception):
+ with self.coordinator._lock:
+ if isinstance(exception, Errors.RebalanceInProgressError):
+ # it is valid to continue heartbeating while the group is
+ # rebalancing. This ensures that the coordinator keeps the
+ # member in the group for as long as the duration of the
+ # rebalance timeout. If we stop sending heartbeats, however,
+ # then the session timeout may expire before we can rejoin.
+ self.coordinator.heartbeat.received_heartbeat()
+ else:
+ self.coordinator.heartbeat.fail_heartbeat()
+ # wake up the thread if it's sleeping to reschedule the heartbeat
+ self.coordinator._lock.notify()
diff --git a/kafka/coordinator/consumer.py b/kafka/coordinator/consumer.py
index dee70f0..48dcad4 100644
--- a/kafka/coordinator/consumer.py
+++ b/kafka/coordinator/consumer.py
@@ -1,14 +1,13 @@
-from __future__ import absolute_import
+from __future__ import absolute_import, division
-import copy
import collections
+import copy
import logging
import time
-import weakref
from kafka.vendor import six
-from .base import BaseCoordinator
+from .base import BaseCoordinator, Generation
from .assignors.range import RangePartitionAssignor
from .assignors.roundrobin import RoundRobinPartitionAssignor
from .protocol import ConsumerProtocol
@@ -30,12 +29,13 @@ class ConsumerCoordinator(BaseCoordinator):
'group_id': 'kafka-python-default-group',
'enable_auto_commit': True,
'auto_commit_interval_ms': 5000,
- 'default_offset_commit_callback': lambda offsets, response: True,
+ 'default_offset_commit_callback': None,
'assignors': (RangePartitionAssignor, RoundRobinPartitionAssignor),
- 'session_timeout_ms': 30000,
+ 'session_timeout_ms': 10000,
'heartbeat_interval_ms': 3000,
+ 'max_poll_interval_ms': 300000,
'retry_backoff_ms': 100,
- 'api_version': (0, 9),
+ 'api_version': (0, 10, 1),
'exclude_internal_topics': True,
'metric_group_prefix': 'consumer'
}
@@ -52,9 +52,9 @@ class ConsumerCoordinator(BaseCoordinator):
auto_commit_interval_ms (int): milliseconds between automatic
offset commits, if enable_auto_commit is True. Default: 5000.
default_offset_commit_callback (callable): called as
- callback(offsets, response) response will be either an Exception
- or a OffsetCommitResponse struct. This callback can be used to
- trigger custom actions when a commit request completes.
+ callback(offsets, exception) response will be either an Exception
+ or None. This callback can be used to trigger custom actions when
+ a commit request completes.
assignors (list): List of objects to use to distribute partition
ownership amongst consumer instances when group management is
used. Default: [RangePartitionAssignor, RoundRobinPartitionAssignor]
@@ -83,17 +83,27 @@ class ConsumerCoordinator(BaseCoordinator):
if key in configs:
self.config[key] = configs[key]
- if self.config['api_version'] >= (0, 9) and self.config['group_id'] is not None:
- assert self.config['assignors'], 'Coordinator requires assignors'
-
self._subscription = subscription
self._metadata_snapshot = self._build_metadata_snapshot(subscription, client.cluster)
self._assignment_snapshot = None
self._cluster = client.cluster
- self._cluster.request_update()
- self._cluster.add_listener(WeakMethod(self._handle_metadata_update))
+ self.auto_commit_interval = self.config['auto_commit_interval_ms'] / 1000
+ self.next_auto_commit_deadline = None
+ self.completed_offset_commits = collections.deque()
+
+ if self.config['default_offset_commit_callback'] is None:
+ self.config['default_offset_commit_callback'] = self._default_offset_commit_callback
+
+ if self.config['group_id'] is not None:
+ if self.config['api_version'] >= (0, 9):
+ if not self.config['assignors']:
+ raise Errors.KafkaConfigurationError('Coordinator requires assignors')
+ if self.config['api_version'] < (0, 10, 1):
+ if self.config['max_poll_interval_ms'] != self.config['session_timeout_ms']:
+ raise Errors.KafkaConfigurationError("Broker version %s does not support "
+ "different values for max_poll_interval_ms "
+ "and session_timeout_ms")
- self._auto_commit_task = None
if self.config['enable_auto_commit']:
if self.config['api_version'] < (0, 8, 1):
log.warning('Broker version (%s) does not support offset'
@@ -104,13 +114,14 @@ class ConsumerCoordinator(BaseCoordinator):
log.warning('group_id is None: disabling auto-commit.')
self.config['enable_auto_commit'] = False
else:
- interval = self.config['auto_commit_interval_ms'] / 1000.0
- self._auto_commit_task = AutoCommitTask(weakref.proxy(self), interval)
- self._auto_commit_task.reschedule()
+ self.next_auto_commit_deadline = time.time() + self.auto_commit_interval
self.consumer_sensors = ConsumerCoordinatorMetrics(
metrics, self.config['metric_group_prefix'], self._subscription)
+ self._cluster.request_update()
+ self._cluster.add_listener(WeakMethod(self._handle_metadata_update))
+
def __del__(self):
if hasattr(self, '_cluster') and self._cluster:
self._cluster.remove_listener(WeakMethod(self._handle_metadata_update))
@@ -210,8 +221,7 @@ class ConsumerCoordinator(BaseCoordinator):
assignor.on_assignment(assignment)
# reschedule the auto commit starting from now
- if self._auto_commit_task:
- self._auto_commit_task.reschedule()
+ self.next_auto_commit_deadline = time.time() + self.auto_commit_interval
assigned = set(self._subscription.assigned_partitions())
log.info("Setting newly assigned partitions %s for group %s",
@@ -227,6 +237,54 @@ class ConsumerCoordinator(BaseCoordinator):
self._subscription.listener, self.group_id,
assigned)
+ def poll(self):
+ """
+ Poll for coordinator events. Only applicable if group_id is set, and
+ broker version supports GroupCoordinators. This ensures that the
+ coordinator is known, and if using automatic partition assignment,
+ ensures that the consumer has joined the group. This also handles
+ periodic offset commits if they are enabled.
+ """
+ if self.group_id is None or self.config['api_version'] < (0, 8, 2):
+ return
+
+ self._invoke_completed_offset_commit_callbacks()
+ self.ensure_coordinator_ready()
+
+ if self.config['api_version'] >= (0, 9) and self._subscription.partitions_auto_assigned():
+ if self.need_rejoin():
+ # due to a race condition between the initial metadata fetch and the
+ # initial rebalance, we need to ensure that the metadata is fresh
+ # before joining initially, and then request the metadata update. If
+ # metadata update arrives while the rebalance is still pending (for
+ # example, when the join group is still inflight), then we will lose
+ # track of the fact that we need to rebalance again to reflect the
+ # change to the topic subscription. Without ensuring that the
+ # metadata is fresh, any metadata update that changes the topic
+ # subscriptions and arrives while a rebalance is in progress will
+ # essentially be ignored. See KAFKA-3949 for the complete
+ # description of the problem.
+ if self._subscription.subscribed_pattern:
+ metadata_update = self._client.cluster.request_update()
+ self._client.poll(future=metadata_update)
+
+ self.ensure_active_group()
+
+ self.poll_heartbeat()
+
+ self._maybe_auto_commit_offsets_async()
+
+ def time_to_next_poll(self):
+ """Return seconds (float) remaining until :meth:`.poll` should be called again"""
+ if not self.config['enable_auto_commit']:
+ return self.time_to_next_heartbeat()
+
+ if time.time() > self.next_auto_commit_deadline:
+ return 0
+
+ return min(self.next_auto_commit_deadline - time.time(),
+ self.time_to_next_heartbeat())
+
def _perform_assignment(self, leader_id, assignment_strategy, members):
assignor = self._lookup_assignor(assignment_strategy)
assert assignor, 'Invalid assignment protocol: %s' % assignment_strategy
@@ -327,7 +385,7 @@ class ConsumerCoordinator(BaseCoordinator):
if not future.retriable():
raise future.exception # pylint: disable-msg=raising-bad-type
- time.sleep(self.config['retry_backoff_ms'] / 1000.0)
+ time.sleep(self.config['retry_backoff_ms'] / 1000)
def close(self, autocommit=True):
"""Close the coordinator, leave the current group,
@@ -344,6 +402,11 @@ class ConsumerCoordinator(BaseCoordinator):
finally:
super(ConsumerCoordinator, self).close()
+ def _invoke_completed_offset_commit_callbacks(self):
+ while self.completed_offset_commits:
+ callback, offsets, exception = self.completed_offset_commits.popleft()
+ callback(offsets, exception)
+
def commit_offsets_async(self, offsets, callback=None):
"""Commit specific offsets asynchronously.
@@ -354,6 +417,7 @@ class ConsumerCoordinator(BaseCoordinator):
struct. This callback can be used to trigger custom actions when
a commit request completes.
"""
+ self._invoke_completed_offset_commit_callbacks()
if not self.coordinator_unknown():
self._do_commit_offsets_async(offsets, callback)
else:
@@ -367,7 +431,7 @@ class ConsumerCoordinator(BaseCoordinator):
future = self.lookup_coordinator()
future.add_callback(self._do_commit_offsets_async, offsets, callback)
if callback:
- future.add_errback(callback)
+ future.add_errback(lambda e: self.completed_offset_commits.appendleft((callback, offsets, e)))
# ensure the commit has a chance to be transmitted (without blocking on
# its completion). Note that commits are treated as heartbeats by the
@@ -384,7 +448,7 @@ class ConsumerCoordinator(BaseCoordinator):
callback = self.config['default_offset_commit_callback']
self._subscription.needs_fetch_committed_offsets = True
future = self._send_offset_commit_request(offsets)
- future.add_both(callback, offsets)
+ future.add_both(lambda res: self.completed_offset_commits.appendleft((callback, offsets, res)))
return future
def commit_offsets_sync(self, offsets):
@@ -402,6 +466,7 @@ class ConsumerCoordinator(BaseCoordinator):
assert all(map(lambda k: isinstance(k, TopicPartition), offsets))
assert all(map(lambda v: isinstance(v, OffsetAndMetadata),
offsets.values()))
+ self._invoke_completed_offset_commit_callbacks()
if not offsets:
return
@@ -417,26 +482,24 @@ class ConsumerCoordinator(BaseCoordinator):
if not future.retriable():
raise future.exception # pylint: disable-msg=raising-bad-type
- time.sleep(self.config['retry_backoff_ms'] / 1000.0)
+ time.sleep(self.config['retry_backoff_ms'] / 1000)
def _maybe_auto_commit_offsets_sync(self):
- if self._auto_commit_task is None:
- return
-
- try:
- self.commit_offsets_sync(self._subscription.all_consumed_offsets())
-
- # The three main group membership errors are known and should not
- # require a stacktrace -- just a warning
- except (Errors.UnknownMemberIdError,
- Errors.IllegalGenerationError,
- Errors.RebalanceInProgressError):
- log.warning("Offset commit failed: group membership out of date"
- " This is likely to cause duplicate message"
- " delivery.")
- except Exception:
- log.exception("Offset commit failed: This is likely to cause"
- " duplicate message delivery")
+ if self.config['enable_auto_commit']:
+ try:
+ self.commit_offsets_sync(self._subscription.all_consumed_offsets())
+
+ # The three main group membership errors are known and should not
+ # require a stacktrace -- just a warning
+ except (Errors.UnknownMemberIdError,
+ Errors.IllegalGenerationError,
+ Errors.RebalanceInProgressError):
+ log.warning("Offset commit failed: group membership out of date"
+ " This is likely to cause duplicate message"
+ " delivery.")
+ except Exception:
+ log.exception("Offset commit failed: This is likely to cause"
+ " duplicate message delivery")
def _send_offset_commit_request(self, offsets):
"""Commit offsets for the specified list of topics and partitions.
@@ -458,23 +521,34 @@ class ConsumerCoordinator(BaseCoordinator):
offsets.values()))
if not offsets:
log.debug('No offsets to commit')
- return Future().success(True)
+ return Future().success(None)
- elif self.coordinator_unknown():
+ node_id = self.coordinator()
+ if node_id is None:
return Future().failure(Errors.GroupCoordinatorNotAvailableError)
- node_id = self.coordinator_id
# create the offset commit request
offset_data = collections.defaultdict(dict)
for tp, offset in six.iteritems(offsets):
offset_data[tp.topic][tp.partition] = offset
+ if self._subscription.partitions_auto_assigned():
+ generation = self.generation()
+ else:
+ generation = Generation.NO_GENERATION
+
+ # if the generation is None, we are not part of an active group
+ # (and we expect to be). The only thing we can do is fail the commit
+ # and let the user rejoin the group in poll()
+ if self.config['api_version'] >= (0, 9) and generation is None:
+ return Future().failure(Errors.CommitFailedError())
+
if self.config['api_version'] >= (0, 9):
request = OffsetCommitRequest[2](
self.group_id,
- self.generation,
- self.member_id,
+ generation.generation_id,
+ generation.member_id,
OffsetCommitRequest[2].DEFAULT_RETENTION_TIME,
[(
topic, [(
@@ -568,7 +642,7 @@ class ConsumerCoordinator(BaseCoordinator):
error = error_type(self.group_id)
log.debug("OffsetCommit for group %s failed: %s",
self.group_id, error)
- self._subscription.mark_for_reassignment()
+ self.reset_generation()
future.failure(Errors.CommitFailedError(
"Commit cannot be completed since the group has"
" already rebalanced and assigned the partitions to"
@@ -593,7 +667,7 @@ class ConsumerCoordinator(BaseCoordinator):
unauthorized_topics, self.group_id)
future.failure(Errors.TopicAuthorizationFailedError(unauthorized_topics))
else:
- future.success(True)
+ future.success(None)
def _send_offset_fetch_request(self, partitions):
"""Fetch the committed offsets for a set of partitions.
@@ -612,11 +686,10 @@ class ConsumerCoordinator(BaseCoordinator):
if not partitions:
return Future().success({})
- elif self.coordinator_unknown():
+ node_id = self.coordinator()
+ if node_id is None:
return Future().failure(Errors.GroupCoordinatorNotAvailableError)
- node_id = self.coordinator_id
-
# Verify node is ready
if not self._client.ready(node_id):
log.debug("Node %s not ready -- failing offset fetch request",
@@ -665,11 +738,6 @@ class ConsumerCoordinator(BaseCoordinator):
# re-discover the coordinator and retry
self.coordinator_dead(error_type())
future.failure(error)
- elif error_type in (Errors.UnknownMemberIdError,
- Errors.IllegalGenerationError):
- # need to re-join group
- self._subscription.mark_for_reassignment()
- future.failure(error)
elif error_type is Errors.UnknownTopicOrPartitionError:
log.warning("OffsetFetchRequest -- unknown topic %s"
" (have you committed any offsets yet?)",
@@ -689,50 +757,28 @@ class ConsumerCoordinator(BaseCoordinator):
" %s", self.group_id, tp)
future.success(offsets)
+ def _default_offset_commit_callback(self, offsets, exception):
+ if exception is not None:
+ log.error("Offset commit failed: %s", exception)
-class AutoCommitTask(object):
- def __init__(self, coordinator, interval):
- self._coordinator = coordinator
- self._client = coordinator._client
- self._interval = interval
-
- def reschedule(self, at=None):
- if at is None:
- at = time.time() + self._interval
- self._client.schedule(self, at)
-
- def __call__(self):
- if self._coordinator.coordinator_unknown():
- log.debug("Cannot auto-commit offsets for group %s because the"
- " coordinator is unknown", self._coordinator.group_id)
- backoff = self._coordinator.config['retry_backoff_ms'] / 1000.0
- self.reschedule(time.time() + backoff)
- return
-
- self._coordinator.commit_offsets_async(
- self._coordinator._subscription.all_consumed_offsets(),
- self._handle_commit_response)
-
- def _handle_commit_response(self, offsets, result):
- if result is True:
- log.debug("Successfully auto-committed offsets for group %s",
- self._coordinator.group_id)
- next_at = time.time() + self._interval
- elif not isinstance(result, BaseException):
- raise Errors.IllegalStateError(
- 'Unrecognized result in _handle_commit_response: %s'
- % result)
- elif hasattr(result, 'retriable') and result.retriable:
- log.debug("Failed to auto-commit offsets for group %s: %s,"
- " will retry immediately", self._coordinator.group_id,
- result)
- next_at = time.time()
- else:
+ def _commit_offsets_async_on_complete(self, offsets, exception):
+ if exception is not None:
log.warning("Auto offset commit failed for group %s: %s",
- self._coordinator.group_id, result)
- next_at = time.time() + self._interval
+ self.group_id, exception)
+ if getattr(exception, 'retriable', False):
+ self.next_auto_commit_deadline = min(time.time() + self.config['retry_backoff_ms'] / 1000, self.next_auto_commit_deadline)
+ else:
+ log.debug("Completed autocommit of offsets %s for group %s",
+ offsets, self.group_id)
- self.reschedule(next_at)
+ def _maybe_auto_commit_offsets_async(self):
+ if self.config['enable_auto_commit']:
+ if self.coordinator_unknown():
+ self.next_auto_commit_deadline = time.time() + self.config['retry_backoff_ms'] / 1000
+ elif time.time() > self.next_auto_commit_deadline:
+ self.next_auto_commit_deadline = time.time() + self.auto_commit_interval
+ self.commit_offsets_async(self._subscription.all_consumed_offsets(),
+ self._commit_offsets_async_on_complete)
class ConsumerCoordinatorMetrics(object):
diff --git a/kafka/coordinator/heartbeat.py b/kafka/coordinator/heartbeat.py
index fddf298..2f5930b 100644
--- a/kafka/coordinator/heartbeat.py
+++ b/kafka/coordinator/heartbeat.py
@@ -1,4 +1,4 @@
-from __future__ import absolute_import
+from __future__ import absolute_import, division
import copy
import time
@@ -6,8 +6,11 @@ import time
class Heartbeat(object):
DEFAULT_CONFIG = {
+ 'group_id': None,
'heartbeat_interval_ms': 3000,
- 'session_timeout_ms': 30000,
+ 'session_timeout_ms': 10000,
+ 'max_poll_interval_ms': 300000,
+ 'retry_backoff_ms': 100,
}
def __init__(self, **configs):
@@ -16,32 +19,50 @@ class Heartbeat(object):
if key in configs:
self.config[key] = configs[key]
- assert (self.config['heartbeat_interval_ms']
- <= self.config['session_timeout_ms']), (
- 'Heartbeat interval must be lower than the session timeout')
+ if self.config['group_id'] is not None:
+ assert (self.config['heartbeat_interval_ms']
+ <= self.config['session_timeout_ms']), (
+ 'Heartbeat interval must be lower than the session timeout')
- self.interval = self.config['heartbeat_interval_ms'] / 1000.0
- self.timeout = self.config['session_timeout_ms'] / 1000.0
self.last_send = -1 * float('inf')
self.last_receive = -1 * float('inf')
+ self.last_poll = -1 * float('inf')
self.last_reset = time.time()
+ self.heartbeat_failed = None
+
+ def poll(self):
+ self.last_poll = time.time()
def sent_heartbeat(self):
self.last_send = time.time()
+ self.heartbeat_failed = False
+
+ def fail_heartbeat(self):
+ self.heartbeat_failed = True
def received_heartbeat(self):
self.last_receive = time.time()
- def ttl(self):
- last_beat = max(self.last_send, self.last_reset)
- return max(0, last_beat + self.interval - time.time())
+ def time_to_next_heartbeat(self):
+ """Returns seconds (float) remaining before next heartbeat should be sent"""
+ time_since_last_heartbeat = time.time() - max(self.last_send, self.last_reset)
+ if self.heartbeat_failed:
+ delay_to_next_heartbeat = self.config['retry_backoff_ms'] / 1000
+ else:
+ delay_to_next_heartbeat = self.config['heartbeat_interval_ms'] / 1000
+ return max(0, delay_to_next_heartbeat - time_since_last_heartbeat)
def should_heartbeat(self):
- return self.ttl() == 0
+ return self.time_to_next_heartbeat() == 0
- def session_expired(self):
+ def session_timeout_expired(self):
last_recv = max(self.last_receive, self.last_reset)
- return (time.time() - last_recv) > self.timeout
+ return (time.time() - last_recv) > (self.config['session_timeout_ms'] / 1000)
- def reset_session_timeout(self):
+ def reset_timeouts(self):
self.last_reset = time.time()
+ self.last_poll = time.time()
+ self.heartbeat_failed = False
+
+ def poll_timeout_expired(self):
+ return (time.time() - self.last_poll) > (self.config['max_poll_interval_ms'] / 1000)
diff --git a/kafka/errors.py b/kafka/errors.py
index 4a409db..c70853c 100644
--- a/kafka/errors.py
+++ b/kafka/errors.py
@@ -59,7 +59,18 @@ class UnrecognizedBrokerVersion(KafkaError):
class CommitFailedError(KafkaError):
- pass
+ def __init__(self, *args, **kwargs):
+ super(CommitFailedError, self).__init__(
+ """Commit cannot be completed since the group has already
+ rebalanced and assigned the partitions to another member.
+ This means that the time between subsequent calls to poll()
+ was longer than the configured max_poll_interval_ms, which
+ typically implies that the poll loop is spending too much
+ time message processing. You can address this either by
+ increasing the rebalance timeout with max_poll_interval_ms,
+ or by reducing the maximum size of batches returned in poll()
+ with max_poll_records.
+ """, *args, **kwargs)
class AuthenticationMethodNotSupported(KafkaError):
diff --git a/kafka/protocol/group.py b/kafka/protocol/group.py
index ce75a5f..c6acca8 100644
--- a/kafka/protocol/group.py
+++ b/kafka/protocol/group.py
@@ -185,7 +185,7 @@ class HeartbeatRequest_v1(Request):
API_KEY = 12
API_VERSION = 1
RESPONSE_TYPE = HeartbeatResponse_v1
- SCHEMA = HeartbeatRequest_v0
+ SCHEMA = HeartbeatRequest_v0.SCHEMA
HeartbeatRequest = [HeartbeatRequest_v0, HeartbeatRequest_v1]
diff --git a/test/test_client_async.py b/test/test_client_async.py
index ec45543..eece139 100644
--- a/test/test_client_async.py
+++ b/test/test_client_async.py
@@ -253,11 +253,9 @@ def test_poll(mocker):
metadata = mocker.patch.object(KafkaClient, '_maybe_refresh_metadata')
_poll = mocker.patch.object(KafkaClient, '_poll')
cli = KafkaClient(api_version=(0, 9))
- tasks = mocker.patch.object(cli._delayed_tasks, 'next_at')
# metadata timeout wins
metadata.return_value = 1000
- tasks.return_value = 2
cli.poll()
_poll.assert_called_with(1.0)
@@ -265,14 +263,8 @@ def test_poll(mocker):
cli.poll(250)
_poll.assert_called_with(0.25)
- # tasks timeout wins
- tasks.return_value = 0
- cli.poll(250)
- _poll.assert_called_with(0)
-
# default is request_timeout_ms
metadata.return_value = 1000000
- tasks.return_value = 10000
cli.poll()
_poll.assert_called_with(cli.config['request_timeout_ms'] / 1000.0)
@@ -325,9 +317,6 @@ def client(mocker):
connections_max_idle_ms=float('inf'),
api_version=(0, 9))
- tasks = mocker.patch.object(cli._delayed_tasks, 'next_at')
- tasks.return_value = 9999999
-
ttl = mocker.patch.object(cli.cluster, 'ttl')
ttl.return_value = 0
return cli
diff --git a/test/test_consumer.py b/test/test_consumer.py
index e5dd946..013529f 100644
--- a/test/test_consumer.py
+++ b/test/test_consumer.py
@@ -14,11 +14,11 @@ from kafka.structs import (
class TestKafkaConsumer(unittest.TestCase):
def test_non_integer_partitions(self):
with self.assertRaises(AssertionError):
- SimpleConsumer(MagicMock(), 'group', 'topic', partitions = [ '0' ])
+ SimpleConsumer(MagicMock(), 'group', 'topic', partitions=['0'])
def test_session_timeout_larger_than_request_timeout_raises(self):
with self.assertRaises(KafkaConfigurationError):
- KafkaConsumer(bootstrap_servers='localhost:9092', session_timeout_ms=60000, request_timeout_ms=40000)
+ KafkaConsumer(bootstrap_servers='localhost:9092', api_version=(0,9), group_id='foo', session_timeout_ms=60000, request_timeout_ms=40000)
def test_fetch_max_wait_larger_than_request_timeout_raises(self):
with self.assertRaises(KafkaConfigurationError):
diff --git a/test/test_consumer_group.py b/test/test_consumer_group.py
index 8f25e9f..690d45a 100644
--- a/test/test_consumer_group.py
+++ b/test/test_consumer_group.py
@@ -9,6 +9,7 @@ import six
from kafka import SimpleClient
from kafka.conn import ConnectionStates
from kafka.consumer.group import KafkaConsumer
+from kafka.coordinator.base import MemberState, Generation
from kafka.structs import TopicPartition
from test.conftest import version
@@ -92,9 +93,10 @@ def test_group(kafka_broker, topic):
# If all consumers exist and have an assignment
else:
+ logging.info('All consumers have assignment... checking for stable group')
# Verify all consumers are in the same generation
# then log state and break while loop
- generations = set([consumer._coordinator.generation
+ generations = set([consumer._coordinator._generation.generation_id
for consumer in list(consumers.values())])
# New generation assignment is not complete until
@@ -105,12 +107,16 @@ def test_group(kafka_broker, topic):
if not rejoining and len(generations) == 1:
for c, consumer in list(consumers.items()):
logging.info("[%s] %s %s: %s", c,
- consumer._coordinator.generation,
- consumer._coordinator.member_id,
+ consumer._coordinator._generation.generation_id,
+ consumer._coordinator._generation.member_id,
consumer.assignment())
break
+ else:
+ logging.info('Rejoining: %s, generations: %s', rejoining, generations)
+ time.sleep(1)
assert time.time() < timeout, "timeout waiting for assignments"
+ logging.info('Group stabilized; verifying assignment')
group_assignment = set()
for c in range(num_consumers):
assert len(consumers[c].assignment()) != 0
@@ -120,9 +126,12 @@ def test_group(kafka_broker, topic):
assert group_assignment == set([
TopicPartition(topic, partition)
for partition in range(num_partitions)])
+ logging.info('Assignment looks good!')
finally:
+ logging.info('Shutting down %s consumers', num_consumers)
for c in range(num_consumers):
+ logging.info('Stopping consumer %s', c)
stop[c].set()
threads[c].join()
@@ -143,3 +152,33 @@ def test_paused(kafka_broker, topic):
consumer.unsubscribe()
assert set() == consumer.paused()
+
+
+@pytest.mark.skipif(version() < (0, 9), reason='Unsupported Kafka Version')
+@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
+def test_heartbeat_thread(kafka_broker, topic):
+ group_id = 'test-group-' + random_string(6)
+ consumer = KafkaConsumer(topic,
+ bootstrap_servers=get_connect_str(kafka_broker),
+ group_id=group_id,
+ heartbeat_interval_ms=500)
+
+ # poll until we have joined group / have assignment
+ while not consumer.assignment():
+ consumer.poll(timeout_ms=100)
+
+ assert consumer._coordinator.state is MemberState.STABLE
+ last_poll = consumer._coordinator.heartbeat.last_poll
+ last_beat = consumer._coordinator.heartbeat.last_send
+
+ timeout = time.time() + 30
+ while True:
+ if time.time() > timeout:
+ raise RuntimeError('timeout waiting for heartbeat')
+ if consumer._coordinator.heartbeat.last_send > last_beat:
+ break
+ time.sleep(0.5)
+
+ assert consumer._coordinator.heartbeat.last_poll == last_poll
+ consumer.poll(timeout_ms=100)
+ assert consumer._coordinator.heartbeat.last_poll > last_poll
diff --git a/test/test_consumer_integration.py b/test/test_consumer_integration.py
index d1843b3..ded2314 100644
--- a/test/test_consumer_integration.py
+++ b/test/test_consumer_integration.py
@@ -739,7 +739,8 @@ class TestConsumerIntegration(KafkaIntegrationTestCase):
@kafka_versions('>=0.10.1')
def test_kafka_consumer_offsets_for_times_errors(self):
- consumer = self.kafka_consumer()
+ consumer = self.kafka_consumer(fetch_max_wait_ms=200,
+ request_timeout_ms=500)
tp = TopicPartition(self.topic, 0)
bad_tp = TopicPartition(self.topic, 100)
diff --git a/test/test_coordinator.py b/test/test_coordinator.py
index 0e96110..7dc0e04 100644
--- a/test/test_coordinator.py
+++ b/test/test_coordinator.py
@@ -10,6 +10,7 @@ from kafka.consumer.subscription_state import (
SubscriptionState, ConsumerRebalanceListener)
from kafka.coordinator.assignors.range import RangePartitionAssignor
from kafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor
+from kafka.coordinator.base import Generation, MemberState, HeartbeatThread
from kafka.coordinator.consumer import ConsumerCoordinator
from kafka.coordinator.protocol import (
ConsumerProtocolMemberMetadata, ConsumerProtocolMemberAssignment)
@@ -43,13 +44,13 @@ def test_autocommit_enable_api_version(client, api_version):
coordinator = ConsumerCoordinator(client, SubscriptionState(),
Metrics(),
enable_auto_commit=True,
+ session_timeout_ms=30000, # session_timeout_ms and max_poll_interval_ms
+ max_poll_interval_ms=30000, # should be the same to avoid KafkaConfigurationError
group_id='foobar',
api_version=api_version)
if api_version < (0, 8, 1):
- assert coordinator._auto_commit_task is None
assert coordinator.config['enable_auto_commit'] is False
else:
- assert coordinator._auto_commit_task is not None
assert coordinator.config['enable_auto_commit'] is True
@@ -269,19 +270,19 @@ def test_close(mocker, coordinator):
mocker.patch.object(coordinator, '_handle_leave_group_response')
mocker.patch.object(coordinator, 'coordinator_unknown', return_value=False)
coordinator.coordinator_id = 0
- coordinator.generation = 1
+ coordinator._generation = Generation(1, 'foobar', b'')
+ coordinator.state = MemberState.STABLE
cli = coordinator._client
- mocker.patch.object(cli, 'unschedule')
mocker.patch.object(cli, 'send', return_value=Future().success('foobar'))
mocker.patch.object(cli, 'poll')
coordinator.close()
assert coordinator._maybe_auto_commit_offsets_sync.call_count == 1
- cli.unschedule.assert_called_with(coordinator.heartbeat_task)
coordinator._handle_leave_group_response.assert_called_with('foobar')
- assert coordinator.generation == -1
- assert coordinator.member_id == ''
+ assert coordinator.generation() is None
+ assert coordinator._generation is Generation.NO_GENERATION
+ assert coordinator.state is MemberState.UNJOINED
assert coordinator.rejoin_needed is True
@@ -296,6 +297,7 @@ def offsets():
def test_commit_offsets_async(mocker, coordinator, offsets):
mocker.patch.object(coordinator._client, 'poll')
mocker.patch.object(coordinator, 'coordinator_unknown', return_value=False)
+ mocker.patch.object(coordinator, 'ensure_coordinator_ready')
mocker.patch.object(coordinator, '_send_offset_commit_request',
return_value=Future().success('fizzbuzz'))
coordinator.commit_offsets_async(offsets)
@@ -362,19 +364,21 @@ def test_maybe_auto_commit_offsets_sync(mocker, api_version, group_id, enable,
coordinator = ConsumerCoordinator(client, SubscriptionState(),
Metrics(),
api_version=api_version,
+ session_timeout_ms=30000,
+ max_poll_interval_ms=30000,
enable_auto_commit=enable,
group_id=group_id)
commit_sync = mocker.patch.object(coordinator, 'commit_offsets_sync',
side_effect=error)
if has_auto_commit:
- assert coordinator._auto_commit_task is not None
+ assert coordinator.next_auto_commit_deadline is not None
else:
- assert coordinator._auto_commit_task is None
+ assert coordinator.next_auto_commit_deadline is None
assert coordinator._maybe_auto_commit_offsets_sync() is None
if has_auto_commit:
- assert coordinator._auto_commit_task is not None
+ assert coordinator.next_auto_commit_deadline is not None
assert commit_sync.call_count == (1 if commit_offsets else 0)
assert mock_warn.call_count == (1 if warn else 0)
@@ -387,24 +391,25 @@ def patched_coord(mocker, coordinator):
coordinator._subscription.needs_partition_assignment = False
mocker.patch.object(coordinator, 'coordinator_unknown', return_value=False)
coordinator.coordinator_id = 0
- coordinator.generation = 0
+ mocker.patch.object(coordinator, 'coordinator', return_value=0)
+ coordinator._generation = Generation(0, 'foobar', b'')
+ coordinator.state = MemberState.STABLE
+ coordinator.rejoin_needed = False
mocker.patch.object(coordinator, 'need_rejoin', return_value=False)
mocker.patch.object(coordinator._client, 'least_loaded_node',
return_value=1)
mocker.patch.object(coordinator._client, 'ready', return_value=True)
mocker.patch.object(coordinator._client, 'send')
- mocker.patch.object(coordinator._client, 'schedule')
mocker.spy(coordinator, '_failed_request')
mocker.spy(coordinator, '_handle_offset_commit_response')
mocker.spy(coordinator, '_handle_offset_fetch_response')
- mocker.spy(coordinator.heartbeat_task, '_handle_heartbeat_success')
- mocker.spy(coordinator.heartbeat_task, '_handle_heartbeat_failure')
return coordinator
-def test_send_offset_commit_request_fail(patched_coord, offsets):
+def test_send_offset_commit_request_fail(mocker, patched_coord, offsets):
patched_coord.coordinator_unknown.return_value = True
patched_coord.coordinator_id = None
+ patched_coord.coordinator.return_value = None
# No offsets
ret = patched_coord._send_offset_commit_request({})
@@ -488,7 +493,14 @@ def test_handle_offset_commit_response(mocker, patched_coord, offsets,
response)
assert isinstance(future.exception, error)
assert patched_coord.coordinator_id is (None if dead else 0)
- assert patched_coord._subscription.needs_partition_assignment is reassign
+ if reassign:
+ assert patched_coord._generation is Generation.NO_GENERATION
+ assert patched_coord.rejoin_needed is True
+ assert patched_coord.state is MemberState.UNJOINED
+ else:
+ assert patched_coord._generation is not Generation.NO_GENERATION
+ assert patched_coord.rejoin_needed is False
+ assert patched_coord.state is MemberState.STABLE
@pytest.fixture
@@ -496,9 +508,10 @@ def partitions():
return [TopicPartition('foobar', 0), TopicPartition('foobar', 1)]
-def test_send_offset_fetch_request_fail(patched_coord, partitions):
+def test_send_offset_fetch_request_fail(mocker, patched_coord, partitions):
patched_coord.coordinator_unknown.return_value = True
patched_coord.coordinator_id = None
+ patched_coord.coordinator.return_value = None
# No partitions
ret = patched_coord._send_offset_fetch_request([])
@@ -551,28 +564,18 @@ def test_send_offset_fetch_request_success(patched_coord, partitions):
future, response)
-@pytest.mark.parametrize('response,error,dead,reassign', [
- #(OffsetFetchResponse[0]([('foobar', [(0, 123, b'', 30), (1, 234, b'', 30)])]),
- # Errors.GroupAuthorizationFailedError, False, False),
- #(OffsetFetchResponse[0]([('foobar', [(0, 123, b'', 7), (1, 234, b'', 7)])]),
- # Errors.RequestTimedOutError, True, False),
- #(OffsetFetchResponse[0]([('foobar', [(0, 123, b'', 27), (1, 234, b'', 27)])]),
- # Errors.RebalanceInProgressError, False, True),
+@pytest.mark.parametrize('response,error,dead', [
(OffsetFetchResponse[0]([('foobar', [(0, 123, b'', 14), (1, 234, b'', 14)])]),
- Errors.GroupLoadInProgressError, False, False),
+ Errors.GroupLoadInProgressError, False),
(OffsetFetchResponse[0]([('foobar', [(0, 123, b'', 16), (1, 234, b'', 16)])]),
- Errors.NotCoordinatorForGroupError, True, False),
- (OffsetFetchResponse[0]([('foobar', [(0, 123, b'', 25), (1, 234, b'', 25)])]),
- Errors.UnknownMemberIdError, False, True),
- (OffsetFetchResponse[0]([('foobar', [(0, 123, b'', 22), (1, 234, b'', 22)])]),
- Errors.IllegalGenerationError, False, True),
+ Errors.NotCoordinatorForGroupError, True),
(OffsetFetchResponse[0]([('foobar', [(0, 123, b'', 29), (1, 234, b'', 29)])]),
- Errors.TopicAuthorizationFailedError, False, False),
+ Errors.TopicAuthorizationFailedError, False),
(OffsetFetchResponse[0]([('foobar', [(0, 123, b'', 0), (1, 234, b'', 0)])]),
- None, False, False),
+ None, False),
])
def test_handle_offset_fetch_response(patched_coord, offsets,
- response, error, dead, reassign):
+ response, error, dead):
future = Future()
patched_coord._handle_offset_fetch_response(future, response)
if error is not None:
@@ -581,15 +584,34 @@ def test_handle_offset_fetch_response(patched_coord, offsets,
assert future.succeeded()
assert future.value == offsets
assert patched_coord.coordinator_id is (None if dead else 0)
- assert patched_coord._subscription.needs_partition_assignment is reassign
-def test_heartbeat(patched_coord):
- patched_coord.coordinator_unknown.return_value = True
+def test_heartbeat(mocker, patched_coord):
+ heartbeat = HeartbeatThread(patched_coord)
+
+ assert not heartbeat.enabled and not heartbeat.closed
+
+ heartbeat.enable()
+ assert heartbeat.enabled
+
+ heartbeat.disable()
+ assert not heartbeat.enabled
+
+ # heartbeat disables when un-joined
+ heartbeat.enable()
+ patched_coord.state = MemberState.UNJOINED
+ heartbeat._run_once()
+ assert not heartbeat.enabled
+
+ heartbeat.enable()
+ patched_coord.state = MemberState.STABLE
+ mocker.spy(patched_coord, '_send_heartbeat_request')
+ mocker.patch.object(patched_coord.heartbeat, 'should_heartbeat', return_value=True)
+ heartbeat._run_once()
+ assert patched_coord._send_heartbeat_request.call_count == 1
- patched_coord.heartbeat_task()
- assert patched_coord._client.schedule.call_count == 1
- assert patched_coord.heartbeat_task._handle_heartbeat_failure.call_count == 1
+ heartbeat.close()
+ assert heartbeat.closed
def test_lookup_coordinator_failure(mocker, coordinator):