From 59400bdad759fddc1d58bdae64de953968316b61 Mon Sep 17 00:00:00 2001
From: Dana Powers
Date: Mon, 11 Mar 2019 21:48:31 -0700
Subject: Synchronize puts to KafkaConsumer protocol buffer during async sends

---
 kafka/conn.py     | 57 ++++++++++++++++++++++++++++++++++---------------------
 test/test_conn.py | 28 ++++++++++++++++++++++++----
 2 files changed, 60 insertions(+), 25 deletions(-)

diff --git a/kafka/conn.py b/kafka/conn.py
index 6b5aff9..c273765 100644
--- a/kafka/conn.py
+++ b/kafka/conn.py
@@ -17,6 +17,7 @@ except ImportError:
 import socket
 import struct
 import sys
+import threading
 import time
 
 from kafka.vendor import six
@@ -220,7 +221,6 @@ class BrokerConnection(object):
         self.afi = afi
         self._sock_afi = afi
         self._sock_addr = None
-        self.in_flight_requests = collections.deque()
         self._api_versions = None
 
         self.config = copy.copy(self.DEFAULT_CONFIG)
@@ -255,6 +255,20 @@ class BrokerConnection(object):
             assert gssapi is not None, 'GSSAPI lib not available'
             assert self.config['sasl_kerberos_service_name'] is not None, 'sasl_kerberos_service_name required for GSSAPI sasl'
 
+        # This is not a general lock / this class is not generally thread-safe yet
+        # However, to avoid pushing responsibility for maintaining
+        # per-connection locks to the upstream client, we will use this lock to
+        # make sure that access to the protocol buffer is synchronized
+        # when sends happen on multiple threads
+        self._lock = threading.Lock()
+
+        # the protocol parser instance manages actual tracking of the
+        # sequence of in-flight requests to responses, which should
+        # function like a FIFO queue. For additional request data,
+        # including tracking request futures and timestamps, we
+        # can use a simple dictionary of correlation_id => request data
+        self.in_flight_requests = dict()
+
         self._protocol = KafkaProtocol(
             client_id=self.config['client_id'],
             api_version=self.config['api_version'])
@@ -729,7 +743,7 @@ class BrokerConnection(object):
         if error is None:
             error = Errors.Cancelled(str(self))
         while self.in_flight_requests:
-            (_, future, _) = self.in_flight_requests.popleft()
+            (_correlation_id, (future, _timestamp)) = self.in_flight_requests.popitem()
             future.failure(error)
         self.config['state_change_callback'](self)
 
@@ -747,23 +761,22 @@ class BrokerConnection(object):
     def _send(self, request, blocking=True):
         assert self.state in (ConnectionStates.AUTHENTICATING, ConnectionStates.CONNECTED)
         future = Future()
-        correlation_id = self._protocol.send_request(request)
-
-        # Attempt to replicate behavior from prior to introduction of
-        # send_pending_requests() / async sends
-        if blocking:
-            error = self.send_pending_requests()
-            if isinstance(error, Exception):
-                future.failure(error)
-                return future
+        with self._lock:
+            correlation_id = self._protocol.send_request(request)
 
         log.debug('%s Request %d: %s', self, correlation_id, request)
         if request.expect_response():
             sent_time = time.time()
-            ifr = (correlation_id, future, sent_time)
-            self.in_flight_requests.append(ifr)
+            assert correlation_id not in self.in_flight_requests, 'Correlation ID already in-flight!'
+            self.in_flight_requests[correlation_id] = (future, sent_time)
         else:
             future.success(None)
+
+        # Attempt to replicate behavior from prior to introduction of
+        # send_pending_requests() / async sends
+        if blocking:
+            self.send_pending_requests()
+
         return future
 
     def send_pending_requests(self):
@@ -818,8 +831,12 @@ class BrokerConnection(object):
             return ()
 
         # augment respones w/ correlation_id, future, and timestamp
-        for i, response in enumerate(responses):
-            (correlation_id, future, timestamp) = self.in_flight_requests.popleft()
+        for i, (correlation_id, response) in enumerate(responses):
+            try:
+                (future, timestamp) = self.in_flight_requests.pop(correlation_id)
+            except KeyError:
+                self.close(Errors.KafkaConnectionError('Received unrecognized correlation id'))
+                return ()
             latency_ms = (time.time() - timestamp) * 1000
             if self._sensors:
                 self._sensors.request_time.record(latency_ms)
@@ -870,20 +887,18 @@ class BrokerConnection(object):
             self.close(e)
             return []
         else:
-            return [resp for (_, resp) in responses]  # drop correlation id
+            return responses
 
     def requests_timed_out(self):
         if self.in_flight_requests:
-            (_, _, oldest_at) = self.in_flight_requests[0]
+            get_timestamp = lambda v: v[1]
+            oldest_at = min(map(get_timestamp,
+                                self.in_flight_requests.values()))
             timeout = self.config['request_timeout_ms'] / 1000.0
             if time.time() >= oldest_at + timeout:
                 return True
         return False
 
-    def _next_correlation_id(self):
-        self._correlation_id = (self._correlation_id + 1) % 2**31
-        return self._correlation_id
-
     def _handle_api_version_response(self, response):
         error_type = Errors.for_code(response.error_code)
         assert error_type is Errors.NoError, "API version check failed"
diff --git a/test/test_conn.py b/test/test_conn.py
index 27d77be..953c112 100644
--- a/test/test_conn.py
+++ b/test/test_conn.py
@@ -112,8 +112,8 @@ def test_send_connecting(conn):
 def test_send_max_ifr(conn):
     conn.state = ConnectionStates.CONNECTED
     max_ifrs = conn.config['max_in_flight_requests_per_connection']
-    for _ in range(max_ifrs):
-        conn.in_flight_requests.append('foo')
+    for i in range(max_ifrs):
+        conn.in_flight_requests[i] = 'foo'
     f = conn.send('foobar')
     assert f.failed() is True
     assert isinstance(f.exception, Errors.TooManyInFlightRequests)
@@ -170,9 +170,9 @@ def test_send_error(_socket, conn):
 def test_can_send_more(conn):
     assert conn.can_send_more() is True
     max_ifrs = conn.config['max_in_flight_requests_per_connection']
-    for _ in range(max_ifrs):
+    for i in range(max_ifrs):
         assert conn.can_send_more() is True
-        conn.in_flight_requests.append('foo')
+        conn.in_flight_requests[i] = 'foo'
     assert conn.can_send_more() is False
 
 
@@ -311,3 +311,23 @@ def test_relookup_on_failure():
     assert conn._sock_afi == afi2
     assert conn._sock_addr == sockaddr2
     conn.close()
+
+
+def test_requests_timed_out(conn):
+    with mock.patch("time.time", return_value=0):
+        # No in-flight requests, not timed out
+        assert not conn.requests_timed_out()
+
+        # Single request, timestamp = now (0)
+        conn.in_flight_requests[0] = ('foo', 0)
+        assert not conn.requests_timed_out()
+
+        # Add another request w/ timestamp > request_timeout ago
+        request_timeout = conn.config['request_timeout_ms']
+        expired_timestamp = 0 - request_timeout - 1
+        conn.in_flight_requests[1] = ('bar', expired_timestamp)
+        assert conn.requests_timed_out()
+
+        # Drop the expired request and we should be good to go again
+        conn.in_flight_requests.pop(1)
+        assert not conn.requests_timed_out()
--
cgit v1.2.1
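A minimal, self-contained sketch of the synchronization pattern this patch introduces. It is not kafka-python's actual BrokerConnection; _SketchFuture, _SketchProtocol, and SketchConnection are hypothetical stand-ins invented for illustration. The idea: a single threading.Lock guards the shared protocol/send buffer so concurrent senders get unique correlation ids and non-interleaved buffered bytes, while per-request bookkeeping (future, send timestamp) lives in a plain dict keyed by correlation id rather than in a FIFO deque.

import threading
import time


class _SketchFuture(object):
    """Stand-in for kafka.future.Future: just records the outcome."""
    def __init__(self):
        self.result = None

    def success(self, value):
        self.result = ('success', value)

    def failure(self, error):
        self.result = ('failure', error)


class _SketchProtocol(object):
    """Stand-in for the protocol parser: assigns correlation ids and buffers bytes."""
    def __init__(self):
        self._correlation_id = 0
        self.buffer = bytearray()

    def send_request(self, request):
        self._correlation_id += 1
        self.buffer.extend(('%d:%s\n' % (self._correlation_id, request)).encode('utf-8'))
        return self._correlation_id


class SketchConnection(object):
    def __init__(self):
        self._protocol = _SketchProtocol()
        # guards only the put into the shared protocol buffer
        self._lock = threading.Lock()
        # correlation_id -> (future, sent_time), replacing the old FIFO deque
        self.in_flight_requests = {}

    def send(self, request):
        future = _SketchFuture()
        with self._lock:
            correlation_id = self._protocol.send_request(request)
        # outside the lock, mirroring the patch: only buffer access is synchronized
        assert correlation_id not in self.in_flight_requests
        self.in_flight_requests[correlation_id] = (future, time.time())
        return future

    def handle_response(self, correlation_id, response):
        # dict pop replaces popleft(): a response is matched by its
        # correlation id rather than by arrival order
        future, sent_time = self.in_flight_requests.pop(correlation_id)
        future.success(response)
        return (time.time() - sent_time) * 1000.0  # latency in ms


if __name__ == '__main__':
    conn = SketchConnection()
    threads = [threading.Thread(target=conn.send, args=('req-%d' % i,))
               for i in range(10)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    # each concurrent send got its own correlation id and bookkeeping entry
    assert len(conn.in_flight_requests) == 10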
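The change to requests_timed_out() follows from the data-structure swap: a deque kept the oldest request at the head, while a dict has no positional order, so the patch scans all values for the minimum send timestamp. Below is a rough standalone equivalent of that check under the same assumptions; the function name and the in_flight / request_timeout_ms parameters are illustrative, not part of the patch.

import time


def requests_timed_out(in_flight, request_timeout_ms):
    """Return True if the oldest in-flight request exceeded the timeout.

    in_flight maps correlation_id -> (future, sent_time), as in the patch.
    """
    if not in_flight:
        return False
    # no head-of-queue to peek at, so take the minimum send timestamp
    oldest_at = min(sent_time for _future, sent_time in in_flight.values())
    return time.time() >= oldest_at + request_timeout_ms / 1000.0


# a request sent 40 seconds ago times out against a 30 second budget
stale = {17: (None, time.time() - 40)}
assert requests_timed_out(stale, request_timeout_ms=30000)
assert not requests_timed_out({}, request_timeout_ms=30000)

The trade-off is that the timeout check becomes O(n) over in-flight requests instead of an O(1) peek at the deque head, in exchange for O(1) removal by correlation id when a response arrives.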