summaryrefslogtreecommitdiff
path: root/kafka/conn.py
diff options
context:
space:
mode:
Diffstat (limited to 'kafka/conn.py')
-rw-r--r--kafka/conn.py57
1 files changed, 36 insertions, 21 deletions
diff --git a/kafka/conn.py b/kafka/conn.py
index 6b5aff9..c273765 100644
--- a/kafka/conn.py
+++ b/kafka/conn.py
@@ -17,6 +17,7 @@ except ImportError:
import socket
import struct
import sys
+import threading
import time
from kafka.vendor import six
@@ -220,7 +221,6 @@ class BrokerConnection(object):
self.afi = afi
self._sock_afi = afi
self._sock_addr = None
- self.in_flight_requests = collections.deque()
self._api_versions = None
self.config = copy.copy(self.DEFAULT_CONFIG)
@@ -255,6 +255,20 @@ class BrokerConnection(object):
assert gssapi is not None, 'GSSAPI lib not available'
assert self.config['sasl_kerberos_service_name'] is not None, 'sasl_kerberos_service_name required for GSSAPI sasl'
+ # This is not a general lock / this class is not generally thread-safe yet
+ # However, to avoid pushing responsibility for maintaining
+ # per-connection locks to the upstream client, we will use this lock to
+ # make sure that access to the protocol buffer is synchronized
+ # when sends happen on multiple threads
+ self._lock = threading.Lock()
+
+ # the protocol parser instance manages actual tracking of the
+ # sequence of in-flight requests to responses, which should
+ # function like a FIFO queue. For additional request data,
+ # including tracking request futures and timestamps, we
+ # can use a simple dictionary of correlation_id => request data
+ self.in_flight_requests = dict()
+
self._protocol = KafkaProtocol(
client_id=self.config['client_id'],
api_version=self.config['api_version'])
@@ -729,7 +743,7 @@ class BrokerConnection(object):
if error is None:
error = Errors.Cancelled(str(self))
while self.in_flight_requests:
- (_, future, _) = self.in_flight_requests.popleft()
+ (_correlation_id, (future, _timestamp)) = self.in_flight_requests.popitem()
future.failure(error)
self.config['state_change_callback'](self)
@@ -747,23 +761,22 @@ class BrokerConnection(object):
def _send(self, request, blocking=True):
assert self.state in (ConnectionStates.AUTHENTICATING, ConnectionStates.CONNECTED)
future = Future()
- correlation_id = self._protocol.send_request(request)
-
- # Attempt to replicate behavior from prior to introduction of
- # send_pending_requests() / async sends
- if blocking:
- error = self.send_pending_requests()
- if isinstance(error, Exception):
- future.failure(error)
- return future
+ with self._lock:
+ correlation_id = self._protocol.send_request(request)
log.debug('%s Request %d: %s', self, correlation_id, request)
if request.expect_response():
sent_time = time.time()
- ifr = (correlation_id, future, sent_time)
- self.in_flight_requests.append(ifr)
+ assert correlation_id not in self.in_flight_requests, 'Correlation ID already in-flight!'
+ self.in_flight_requests[correlation_id] = (future, sent_time)
else:
future.success(None)
+
+ # Attempt to replicate behavior from prior to introduction of
+ # send_pending_requests() / async sends
+ if blocking:
+ self.send_pending_requests()
+
return future
def send_pending_requests(self):
@@ -818,8 +831,12 @@ class BrokerConnection(object):
return ()
# augment respones w/ correlation_id, future, and timestamp
- for i, response in enumerate(responses):
- (correlation_id, future, timestamp) = self.in_flight_requests.popleft()
+ for i, (correlation_id, response) in enumerate(responses):
+ try:
+ (future, timestamp) = self.in_flight_requests.pop(correlation_id)
+ except KeyError:
+ self.close(Errors.KafkaConnectionError('Received unrecognized correlation id'))
+ return ()
latency_ms = (time.time() - timestamp) * 1000
if self._sensors:
self._sensors.request_time.record(latency_ms)
@@ -870,20 +887,18 @@ class BrokerConnection(object):
self.close(e)
return []
else:
- return [resp for (_, resp) in responses] # drop correlation id
+ return responses
def requests_timed_out(self):
if self.in_flight_requests:
- (_, _, oldest_at) = self.in_flight_requests[0]
+ get_timestamp = lambda v: v[1]
+ oldest_at = min(map(get_timestamp,
+ self.in_flight_requests.values()))
timeout = self.config['request_timeout_ms'] / 1000.0
if time.time() >= oldest_at + timeout:
return True
return False
- def _next_correlation_id(self):
- self._correlation_id = (self._correlation_id + 1) % 2**31
- return self._correlation_id
-
def _handle_api_version_response(self, response):
error_type = Errors.for_code(response.error_code)
assert error_type is Errors.NoError, "API version check failed"