summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Dana Powers <dana.powers@gmail.com> 2016-04-25 22:33:19 -0700
committer Dana Powers <dana.powers@gmail.com> 2016-04-25 22:33:19 -0700
commit 5b393ac2b51b9100e43299a16d11f70fe117da5c (patch)
tree 7e198a0f5be06d60db76d12d44f1cb81d4a8115f
parent 161fa6d76b8220954eb52554e4bebc470308172d (diff)
parent fa59d4da590e851a137cb0cf4c93f0089cae6890 (diff)
download kafka-python-5b393ac2b51b9100e43299a16d11f70fe117da5c.tar.gz
Merge pull request #671 from dpkp/disconnects
Improve socket disconnect handling
-rw-r--r-- kafka/client_async.py 4
-rw-r--r-- kafka/conn.py 26
-rw-r--r-- test/test_conn.py 81
3 files changed, 87 insertions, 24 deletions
diff --git a/kafka/client_async.py b/kafka/client_async.py
index 6f5d1fe..7719426 100644
--- a/kafka/client_async.py
+++ b/kafka/client_async.py
@@ -142,6 +142,7 @@ class KafkaClient(object):
# Exponential backoff if bootstrap fails
backoff_ms = self.config['reconnect_backoff_ms'] * 2 ** self._bootstrap_fails
next_at = self._last_bootstrap + backoff_ms / 1000.0
+ self._refresh_on_disconnects = False
now = time.time()
if next_at > now:
log.debug("Sleeping %0.4f before bootstrapping again", next_at - now)
@@ -180,6 +181,7 @@ class KafkaClient(object):
log.error('Unable to bootstrap from %s', hosts)
# Max exponential backoff is 2^12, x4000 (50ms -> 200s)
self._bootstrap_fails = min(self._bootstrap_fails + 1, 12)
+ self._refresh_on_disconnects = True
def _can_connect(self, node_id):
if node_id not in self._conns:
@@ -223,7 +225,7 @@ class KafkaClient(object):
except KeyError:
pass
if self._refresh_on_disconnects:
- log.warning("Node %s connect failed -- refreshing metadata", node_id)
+ log.warning("Node %s connection failed -- refreshing metadata", node_id)
self.cluster.request_update()
def _maybe_connect(self, node_id):
diff --git a/kafka/conn.py b/kafka/conn.py
index 3571e90..b5c7ba0 100644
--- a/kafka/conn.py
+++ b/kafka/conn.py
@@ -381,9 +381,17 @@ class BrokerConnection(object):
# Not receiving is the state of reading the payload header
if not self._receiving:
try:
- # An extremely small, but non-zero, probability that there are
- # more than 0 but not yet 4 bytes available to read
- self._rbuffer.write(self._sock.recv(4 - self._rbuffer.tell()))
+ bytes_to_read = 4 - self._rbuffer.tell()
+ data = self._sock.recv(bytes_to_read)
+ # We expect socket.recv to raise an exception if there is not
+ # enough data to read the full bytes_to_read
+ # but if the socket is disconnected, we will get empty data
+ # without an exception raised
+ if not data:
+ log.error('%s: socket disconnected', self)
+ self.close(error=Errors.ConnectionError('socket disconnected'))
+ return None
+ self._rbuffer.write(data)
except ssl.SSLWantReadError:
return None
except ConnectionError as e:
@@ -411,7 +419,17 @@ class BrokerConnection(object):
if self._receiving:
staged_bytes = self._rbuffer.tell()
try:
- self._rbuffer.write(self._sock.recv(self._next_payload_bytes - staged_bytes))
+ bytes_to_read = self._next_payload_bytes - staged_bytes
+ data = self._sock.recv(bytes_to_read)
+ # We expect socket.recv to raise an exception if there is not
+ # enough data to read the full bytes_to_read
+ # but if the socket is disconnected, we will get empty data
+ # without an exception raised
+ if not data:
+ log.error('%s: socket disconnected', self)
+ self.close(error=Errors.ConnectionError('socket disconnected'))
+ return None
+ self._rbuffer.write(data)
except ssl.SSLWantReadError:
return None
except ConnectionError as e:
diff --git a/test/test_conn.py b/test/test_conn.py
index f0ca2cf..6a3b154 100644
--- a/test/test_conn.py
+++ b/test/test_conn.py
@@ -2,6 +2,7 @@
from __future__ import absolute_import
from errno import EALREADY, EINPROGRESS, EISCONN, ECONNRESET
+import socket
import time
import pytest
@@ -14,7 +15,7 @@ import kafka.common as Errors
@pytest.fixture
-def socket(mocker):
+def _socket(mocker):
socket = mocker.MagicMock()
socket.connect_ex.return_value = 0
mocker.patch('socket.socket', return_value=socket)
@@ -22,9 +23,8 @@ def socket(mocker):
@pytest.fixture
-def conn(socket):
- from socket import AF_INET
- conn = BrokerConnection('localhost', 9092, AF_INET)
+def conn(_socket):
+ conn = BrokerConnection('localhost', 9092, socket.AF_INET)
return conn
@@ -38,23 +38,23 @@ def conn(socket):
([EALREADY], ConnectionStates.CONNECTING),
([EISCONN], ConnectionStates.CONNECTED)),
])
-def test_connect(socket, conn, states):
+def test_connect(_socket, conn, states):
assert conn.state is ConnectionStates.DISCONNECTED
for errno, state in states:
- socket.connect_ex.side_effect = errno
+ _socket.connect_ex.side_effect = errno
conn.connect()
assert conn.state is state
-def test_connect_timeout(socket, conn):
+def test_connect_timeout(_socket, conn):
assert conn.state is ConnectionStates.DISCONNECTED
# Initial connect returns EINPROGRESS
# immediate inline connect returns EALREADY
# second explicit connect returns EALREADY
# third explicit connect returns EALREADY and times out via last_attempt
- socket.connect_ex.side_effect = [EINPROGRESS, EALREADY, EALREADY, EALREADY]
+ _socket.connect_ex.side_effect = [EINPROGRESS, EALREADY, EALREADY, EALREADY]
conn.connect()
assert conn.state is ConnectionStates.CONNECTING
conn.connect()
@@ -108,7 +108,7 @@ def test_send_max_ifr(conn):
assert isinstance(f.exception, Errors.TooManyInFlightRequests)
-def test_send_no_response(socket, conn):
+def test_send_no_response(_socket, conn):
conn.connect()
assert conn.state is ConnectionStates.CONNECTED
req = MetadataRequest[0]([])
@@ -116,7 +116,7 @@ def test_send_no_response(socket, conn):
payload_bytes = len(header.encode()) + len(req.encode())
third = payload_bytes // 3
remainder = payload_bytes % 3
- socket.send.side_effect = [4, third, third, third, remainder]
+ _socket.send.side_effect = [4, third, third, third, remainder]
assert len(conn.in_flight_requests) == 0
f = conn.send(req, expect_response=False)
@@ -125,7 +125,7 @@ def test_send_no_response(socket, conn):
assert len(conn.in_flight_requests) == 0
-def test_send_response(socket, conn):
+def test_send_response(_socket, conn):
conn.connect()
assert conn.state is ConnectionStates.CONNECTED
req = MetadataRequest[0]([])
@@ -133,7 +133,7 @@ def test_send_response(socket, conn):
payload_bytes = len(header.encode()) + len(req.encode())
third = payload_bytes // 3
remainder = payload_bytes % 3
- socket.send.side_effect = [4, third, third, third, remainder]
+ _socket.send.side_effect = [4, third, third, third, remainder]
assert len(conn.in_flight_requests) == 0
f = conn.send(req)
@@ -141,20 +141,18 @@ def test_send_response(socket, conn):
assert len(conn.in_flight_requests) == 1
-def test_send_error(socket, conn):
+def test_send_error(_socket, conn):
conn.connect()
assert conn.state is ConnectionStates.CONNECTED
req = MetadataRequest[0]([])
- header = RequestHeader(req, client_id=conn.config['client_id'])
try:
- error = ConnectionError
+ _socket.send.side_effect = ConnectionError
except NameError:
- from socket import error
- socket.send.side_effect = error
+ _socket.send.side_effect = socket.error
f = conn.send(req)
assert f.failed() is True
assert isinstance(f.exception, Errors.ConnectionError)
- assert socket.close.call_count == 1
+ assert _socket.close.call_count == 1
assert conn.state is ConnectionStates.DISCONNECTED
@@ -167,7 +165,52 @@ def test_can_send_more(conn):
assert conn.can_send_more() is False
-def test_recv(socket, conn):
+def test_recv_disconnected():
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.bind(('127.0.0.1', 0))
+ port = sock.getsockname()[1]
+ sock.listen(5)
+
+ conn = BrokerConnection('127.0.0.1', port, socket.AF_INET)
+ timeout = time.time() + 1
+ while time.time() < timeout:
+ conn.connect()
+ if conn.connected():
+ break
+ else:
+ assert False, 'Connection attempt to local socket timed-out ?'
+
+ conn.send(MetadataRequest[0]([]))
+
+ # Disconnect server socket
+ sock.close()
+
+ # Attempt to receive should mark connection as disconnected
+ assert conn.connected()
+ conn.recv()
+ assert conn.disconnected()
+
+
+def test_recv_disconnected_too(_socket, conn):
+ conn.connect()
+ assert conn.connected()
+
+ req = MetadataRequest[0]([])
+ header = RequestHeader(req, client_id=conn.config['client_id'])
+ payload_bytes = len(header.encode()) + len(req.encode())
+ _socket.send.side_effect = [4, payload_bytes]
+ conn.send(req)
+
+ # Empty data on recv means the socket is disconnected
+ _socket.recv.return_value = b''
+
+ # Attempt to receive should mark connection as disconnected
+ assert conn.connected()
+ conn.recv()
+ assert conn.disconnected()
+
+
+def test_recv(_socket, conn):
pass # TODO