summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDana Powers <dana.powers@gmail.com>2019-04-02 09:23:44 -0700
committerGitHub <noreply@github.com>2019-04-02 09:23:44 -0700
commit27cd93be3e7f2e3f3baca04d2126cf3bb6374668 (patch)
treea2ad25e9144b7760a72b5313ff9853957684db7d
parented4cab65704fb5c1c5f0c1071590ca0a7b3fbf4e (diff)
downloadkafka-python-27cd93be3e7f2e3f3baca04d2126cf3bb6374668.tar.gz
Additional BrokerConnection locks to synchronize protocol/IFR state (#1768)
-rw-r--r--kafka/conn.py146
1 files changed, 85 insertions, 61 deletions
diff --git a/kafka/conn.py b/kafka/conn.py
index 94cf584..a00206f 100644
--- a/kafka/conn.py
+++ b/kafka/conn.py
@@ -589,11 +589,14 @@ class BrokerConnection(object):
self.config['sasl_plain_password']]).encode('utf-8'))
size = Int32.encode(len(msg))
try:
- self._send_bytes_blocking(size + msg)
+ with self._lock:
+ if not self._can_send_recv():
+ return future.failure(Errors.NodeNotReadyError(str(self)))
+ self._send_bytes_blocking(size + msg)
- # The server will send a zero sized message (that is Int32(0)) on success.
- # The connection is closed on failure
- data = self._recv_bytes_blocking(4)
+ # The server will send a zero sized message (that is Int32(0)) on success.
+ # The connection is closed on failure
+ data = self._recv_bytes_blocking(4)
except ConnectionError as e:
log.exception("%s: Error receiving reply from server", self)
@@ -617,6 +620,9 @@ class BrokerConnection(object):
).canonicalize(gssapi.MechType.kerberos)
log.debug('%s: GSSAPI name: %s', self, gssapi_name)
+ self._lock.acquire()
+ if not self._can_send_recv():
+ return future.failure(Errors.NodeNotReadyError(str(self)))
# Establish security context and negotiate protection level
# For reference RFC 2222, section 7.2.1
try:
@@ -659,13 +665,16 @@ class BrokerConnection(object):
self._send_bytes_blocking(size + msg)
except ConnectionError as e:
+ self._lock.release()
log.exception("%s: Error receiving reply from server", self)
error = Errors.KafkaConnectionError("%s: %s" % (self, e))
self.close(error=error)
return future.failure(error)
except Exception as e:
+ self._lock.release()
return future.failure(e)
+ self._lock.release()
log.info('%s: Authenticated as %s via GSSAPI', self, gssapi_name)
return future.success(True)
@@ -674,6 +683,9 @@ class BrokerConnection(object):
msg = bytes(self._build_oauth_client_request().encode("utf-8"))
size = Int32.encode(len(msg))
+ self._lock.acquire()
+ if not self._can_send_recv():
+ return future.failure(Errors.NodeNotReadyError(str(self)))
try:
# Send SASL OAuthBearer request with OAuth token
self._send_bytes_blocking(size + msg)
@@ -683,11 +695,14 @@ class BrokerConnection(object):
data = self._recv_bytes_blocking(4)
except ConnectionError as e:
+ self._lock.release()
log.exception("%s: Error receiving reply from server", self)
error = Errors.KafkaConnectionError("%s: %s" % (self, e))
self.close(error=error)
return future.failure(error)
+ self._lock.release()
+
if data != b'\x00\x00\x00\x00':
error = Errors.AuthenticationFailedError('Unrecognized response during authentication')
return future.failure(error)
@@ -787,26 +802,33 @@ class BrokerConnection(object):
will be failed with this exception.
Default: kafka.errors.KafkaConnectionError.
"""
- if self.state is ConnectionStates.DISCONNECTED:
- if error is not None:
- log.warning('%s: Duplicate close() with error: %s', self, error)
- return
- log.info('%s: Closing connection. %s', self, error or '')
- self.state = ConnectionStates.DISCONNECTING
- self.config['state_change_callback'](self)
- self._update_reconnect_backoff()
- self._close_socket()
- self.state = ConnectionStates.DISCONNECTED
- self._sasl_auth_future = None
- self._protocol = KafkaProtocol(
- client_id=self.config['client_id'],
- api_version=self.config['api_version'])
- if error is None:
- error = Errors.Cancelled(str(self))
- while self.in_flight_requests:
- (_correlation_id, (future, _timestamp)) = self.in_flight_requests.popitem()
+ with self._lock:
+ if self.state is ConnectionStates.DISCONNECTED:
+ return
+ log.info('%s: Closing connection. %s', self, error or '')
+ self.state = ConnectionStates.DISCONNECTING
+ self.config['state_change_callback'](self)
+ self._update_reconnect_backoff()
+ self._close_socket()
+ self.state = ConnectionStates.DISCONNECTED
+ self._sasl_auth_future = None
+ self._protocol = KafkaProtocol(
+ client_id=self.config['client_id'],
+ api_version=self.config['api_version'])
+ if error is None:
+ error = Errors.Cancelled(str(self))
+ ifrs = list(self.in_flight_requests.items())
+ self.in_flight_requests.clear()
+ self.config['state_change_callback'](self)
+
+ # drop lock before processing futures
+ for (_correlation_id, (future, _timestamp)) in ifrs:
future.failure(error)
- self.config['state_change_callback'](self)
+
+ def _can_send_recv(self):
+ """Return True iff socket is ready for requests / responses"""
+ return self.state in (ConnectionStates.AUTHENTICATING,
+ ConnectionStates.CONNECTED)
def send(self, request, blocking=True):
"""Queue request for async network send, return Future()"""
@@ -820,18 +842,20 @@ class BrokerConnection(object):
return self._send(request, blocking=blocking)
def _send(self, request, blocking=True):
- assert self.state in (ConnectionStates.AUTHENTICATING, ConnectionStates.CONNECTED)
future = Future()
with self._lock:
+ if not self._can_send_recv():
+ return future.failure(Errors.NodeNotReadyError(str(self)))
+
correlation_id = self._protocol.send_request(request)
- log.debug('%s Request %d: %s', self, correlation_id, request)
- if request.expect_response():
- sent_time = time.time()
- assert correlation_id not in self.in_flight_requests, 'Correlation ID already in-flight!'
- self.in_flight_requests[correlation_id] = (future, sent_time)
- else:
- future.success(None)
+ log.debug('%s Request %d: %s', self, correlation_id, request)
+ if request.expect_response():
+ sent_time = time.time()
+ assert correlation_id not in self.in_flight_requests, 'Correlation ID already in-flight!'
+ self.in_flight_requests[correlation_id] = (future, sent_time)
+ else:
+ future.success(None)
# Attempt to replicate behavior from prior to introduction of
# send_pending_requests() / async sends
@@ -842,16 +866,15 @@ class BrokerConnection(object):
def send_pending_requests(self):
"""Can block on network if request is larger than send_buffer_bytes"""
- if self.state not in (ConnectionStates.AUTHENTICATING,
- ConnectionStates.CONNECTED):
- return Errors.NodeNotReadyError(str(self))
- with self._lock:
- data = self._protocol.send_bytes()
try:
- # In the future we might manage an internal write buffer
- # and send bytes asynchronously. For now, just block
- # sending each request payload
- total_bytes = self._send_bytes_blocking(data)
+ with self._lock:
+ if not self._can_send_recv():
+ return Errors.NodeNotReadyError(str(self))
+ # In the future we might manage an internal write buffer
+ # and send bytes asynchronously. For now, just block
+ # sending each request payload
+ data = self._protocol.send_bytes()
+ total_bytes = self._send_bytes_blocking(data)
if self._sensors:
self._sensors.bytes_sent.record(total_bytes)
return total_bytes
@@ -871,18 +894,6 @@ class BrokerConnection(object):
Return list of (response, future) tuples
"""
- if not self.connected() and not self.state is ConnectionStates.AUTHENTICATING:
- log.warning('%s cannot recv: socket not connected', self)
- # If requests are pending, we should close the socket and
- # fail all the pending request futures
- if self.in_flight_requests:
- self.close(Errors.KafkaConnectionError('Socket not connected during recv with in-flight-requests'))
- return ()
-
- elif not self.in_flight_requests:
- log.warning('%s: No in-flight-requests to recv', self)
- return ()
-
responses = self._recv()
if not responses and self.requests_timed_out():
log.warning('%s timed out after %s ms. Closing connection.',
@@ -895,7 +906,8 @@ class BrokerConnection(object):
# augment respones w/ correlation_id, future, and timestamp
for i, (correlation_id, response) in enumerate(responses):
try:
- (future, timestamp) = self.in_flight_requests.pop(correlation_id)
+ with self._lock:
+ (future, timestamp) = self.in_flight_requests.pop(correlation_id)
except KeyError:
self.close(Errors.KafkaConnectionError('Received unrecognized correlation id'))
return ()
@@ -911,6 +923,12 @@ class BrokerConnection(object):
def _recv(self):
"""Take all available bytes from socket, return list of any responses from parser"""
recvd = []
+ self._lock.acquire()
+ if not self._can_send_recv():
+ log.warning('%s cannot recv: socket not connected', self)
+ self._lock.release()
+ return ()
+
while len(recvd) < self.config['sock_chunk_buffer_count']:
try:
data = self._sock.recv(self.config['sock_chunk_bytes'])
@@ -920,6 +938,7 @@ class BrokerConnection(object):
# without an exception raised
if not data:
log.error('%s: socket disconnected', self)
+ self._lock.release()
self.close(error=Errors.KafkaConnectionError('socket disconnected'))
return []
else:
@@ -932,11 +951,13 @@ class BrokerConnection(object):
break
log.exception('%s: Error receiving network data'
' closing socket', self)
+ self._lock.release()
self.close(error=Errors.KafkaConnectionError(e))
return []
except BlockingIOError:
if six.PY3:
break
+ self._lock.release()
raise
recvd_data = b''.join(recvd)
@@ -946,20 +967,23 @@ class BrokerConnection(object):
try:
responses = self._protocol.receive_bytes(recvd_data)
except Errors.KafkaProtocolError as e:
+ self._lock.release()
self.close(e)
return []
else:
+ self._lock.release()
return responses
def requests_timed_out(self):
- if self.in_flight_requests:
- get_timestamp = lambda v: v[1]
- oldest_at = min(map(get_timestamp,
- self.in_flight_requests.values()))
- timeout = self.config['request_timeout_ms'] / 1000.0
- if time.time() >= oldest_at + timeout:
- return True
- return False
+ with self._lock:
+ if self.in_flight_requests:
+ get_timestamp = lambda v: v[1]
+ oldest_at = min(map(get_timestamp,
+ self.in_flight_requests.values()))
+ timeout = self.config['request_timeout_ms'] / 1000.0
+ if time.time() >= oldest_at + timeout:
+ return True
+ return False
def _handle_api_version_response(self, response):
error_type = Errors.for_code(response.error_code)