summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Dana Powers <dana.powers@gmail.com> 2019-03-11 21:48:31 -0700
committer: Dana Powers <dana.powers@gmail.com> 2019-03-11 21:48:31 -0700
commit: 59400bdad759fddc1d58bdae64de953968316b61 (patch)
tree: 0a7df1ada8e4e502c90215d9b54bf16684b3b8e1
parent: 8c0792581d8a38822c01b40f5d3926c659b0c439 (diff)
download: kafka-python-conn_lock_async_send.tar.gz
Synchronize puts to KafkaConsumer protocol buffer during async sends (conn_lock_async_send)
-rw-r--r-- kafka/conn.py 57
-rw-r--r-- test/test_conn.py 28
2 files changed, 60 insertions, 25 deletions
diff --git a/kafka/conn.py b/kafka/conn.py
index 6b5aff9..c273765 100644
--- a/kafka/conn.py
+++ b/kafka/conn.py
@@ -17,6 +17,7 @@ except ImportError:
import socket
import struct
import sys
+import threading
import time
from kafka.vendor import six
@@ -220,7 +221,6 @@ class BrokerConnection(object):
self.afi = afi
self._sock_afi = afi
self._sock_addr = None
- self.in_flight_requests = collections.deque()
self._api_versions = None
self.config = copy.copy(self.DEFAULT_CONFIG)
@@ -255,6 +255,20 @@ class BrokerConnection(object):
assert gssapi is not None, 'GSSAPI lib not available'
assert self.config['sasl_kerberos_service_name'] is not None, 'sasl_kerberos_service_name required for GSSAPI sasl'
+ # This is not a general lock / this class is not generally thread-safe yet
+ # However, to avoid pushing responsibility for maintaining
+ # per-connection locks to the upstream client, we will use this lock to
+ # make sure that access to the protocol buffer is synchronized
+ # when sends happen on multiple threads
+ self._lock = threading.Lock()
+
+ # the protocol parser instance manages actual tracking of the
+ # sequence of in-flight requests to responses, which should
+ # function like a FIFO queue. For additional request data,
+ # including tracking request futures and timestamps, we
+ # can use a simple dictionary of correlation_id => request data
+ self.in_flight_requests = dict()
+
self._protocol = KafkaProtocol(
client_id=self.config['client_id'],
api_version=self.config['api_version'])
@@ -729,7 +743,7 @@ class BrokerConnection(object):
if error is None:
error = Errors.Cancelled(str(self))
while self.in_flight_requests:
- (_, future, _) = self.in_flight_requests.popleft()
+ (_correlation_id, (future, _timestamp)) = self.in_flight_requests.popitem()
future.failure(error)
self.config['state_change_callback'](self)
@@ -747,23 +761,22 @@ class BrokerConnection(object):
def _send(self, request, blocking=True):
assert self.state in (ConnectionStates.AUTHENTICATING, ConnectionStates.CONNECTED)
future = Future()
- correlation_id = self._protocol.send_request(request)
-
- # Attempt to replicate behavior from prior to introduction of
- # send_pending_requests() / async sends
- if blocking:
- error = self.send_pending_requests()
- if isinstance(error, Exception):
- future.failure(error)
- return future
+ with self._lock:
+ correlation_id = self._protocol.send_request(request)
log.debug('%s Request %d: %s', self, correlation_id, request)
if request.expect_response():
sent_time = time.time()
- ifr = (correlation_id, future, sent_time)
- self.in_flight_requests.append(ifr)
+ assert correlation_id not in self.in_flight_requests, 'Correlation ID already in-flight!'
+ self.in_flight_requests[correlation_id] = (future, sent_time)
else:
future.success(None)
+
+ # Attempt to replicate behavior from prior to introduction of
+ # send_pending_requests() / async sends
+ if blocking:
+ self.send_pending_requests()
+
return future
def send_pending_requests(self):
@@ -818,8 +831,12 @@ class BrokerConnection(object):
return ()
# augment respones w/ correlation_id, future, and timestamp
- for i, response in enumerate(responses):
- (correlation_id, future, timestamp) = self.in_flight_requests.popleft()
+ for i, (correlation_id, response) in enumerate(responses):
+ try:
+ (future, timestamp) = self.in_flight_requests.pop(correlation_id)
+ except KeyError:
+ self.close(Errors.KafkaConnectionError('Received unrecognized correlation id'))
+ return ()
latency_ms = (time.time() - timestamp) * 1000
if self._sensors:
self._sensors.request_time.record(latency_ms)
@@ -870,20 +887,18 @@ class BrokerConnection(object):
self.close(e)
return []
else:
- return [resp for (_, resp) in responses] # drop correlation id
+ return responses
def requests_timed_out(self):
if self.in_flight_requests:
- (_, _, oldest_at) = self.in_flight_requests[0]
+ get_timestamp = lambda v: v[1]
+ oldest_at = min(map(get_timestamp,
+ self.in_flight_requests.values()))
timeout = self.config['request_timeout_ms'] / 1000.0
if time.time() >= oldest_at + timeout:
return True
return False
- def _next_correlation_id(self):
- self._correlation_id = (self._correlation_id + 1) % 2**31
- return self._correlation_id
-
def _handle_api_version_response(self, response):
error_type = Errors.for_code(response.error_code)
assert error_type is Errors.NoError, "API version check failed"
diff --git a/test/test_conn.py b/test/test_conn.py
index 27d77be..953c112 100644
--- a/test/test_conn.py
+++ b/test/test_conn.py
@@ -112,8 +112,8 @@ def test_send_connecting(conn):
def test_send_max_ifr(conn):
conn.state = ConnectionStates.CONNECTED
max_ifrs = conn.config['max_in_flight_requests_per_connection']
- for _ in range(max_ifrs):
- conn.in_flight_requests.append('foo')
+ for i in range(max_ifrs):
+ conn.in_flight_requests[i] = 'foo'
f = conn.send('foobar')
assert f.failed() is True
assert isinstance(f.exception, Errors.TooManyInFlightRequests)
@@ -170,9 +170,9 @@ def test_send_error(_socket, conn):
def test_can_send_more(conn):
assert conn.can_send_more() is True
max_ifrs = conn.config['max_in_flight_requests_per_connection']
- for _ in range(max_ifrs):
+ for i in range(max_ifrs):
assert conn.can_send_more() is True
- conn.in_flight_requests.append('foo')
+ conn.in_flight_requests[i] = 'foo'
assert conn.can_send_more() is False
@@ -311,3 +311,23 @@ def test_relookup_on_failure():
assert conn._sock_afi == afi2
assert conn._sock_addr == sockaddr2
conn.close()
+
+
+def test_requests_timed_out(conn):
+ with mock.patch("time.time", return_value=0):
+ # No in-flight requests, not timed out
+ assert not conn.requests_timed_out()
+
+ # Single request, timestamp = now (0)
+ conn.in_flight_requests[0] = ('foo', 0)
+ assert not conn.requests_timed_out()
+
+ # Add another request w/ timestamp > request_timeout ago
+ request_timeout = conn.config['request_timeout_ms']
+ expired_timestamp = 0 - request_timeout - 1
+ conn.in_flight_requests[1] = ('bar', expired_timestamp)
+ assert conn.requests_timed_out()
+
+ # Drop the expired request and we should be good to go again
+ conn.in_flight_requests.pop(1)
+ assert not conn.requests_timed_out()