summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author    Dana Powers <dana.powers@gmail.com>  2017-01-07 10:52:01 -0800
committer Dana Powers <dana.powers@gmail.com>  2017-08-13 19:47:36 -0700
commit    df227a6015992d8ddb79f5faa3f782d0042edd6b (patch)
tree      064e31324df666ecc7575d5033948600f5df52fd
parent    f13ce1d87919ab763b02e38c17080580e199b4af (diff)
download  kafka-python-receive_bytes_pipe.tar.gz
BrokerConnection.receive_bytes(data) -> response events (branch: receive_bytes_pipe)
-rw-r--r--  kafka/client_async.py       16
-rw-r--r--  kafka/conn.py              161
-rw-r--r--  kafka/protocol/message.py    7
3 files changed, 91 insertions, 93 deletions
diff --git a/kafka/client_async.py b/kafka/client_async.py
index ecd2cea..75b169e 100644
--- a/kafka/client_async.py
+++ b/kafka/client_async.py
@@ -605,25 +605,14 @@ class KafkaClient(object):
continue
self._idle_expiry_manager.update(conn.node_id)
-
- # Accumulate as many responses as the connection has pending
- while conn.in_flight_requests:
- response = conn.recv() # Note: conn.recv runs callbacks / errbacks
-
- # Incomplete responses are buffered internally
- # while conn.in_flight_requests retains the request
- if not response:
- break
- responses.append(response)
+ responses.extend(conn.recv()) # Note: conn.recv runs callbacks / errbacks
# Check for additional pending SSL bytes
if self.config['security_protocol'] in ('SSL', 'SASL_SSL'):
# TODO: optimize
for conn in self._conns.values():
if conn not in processed and conn.connected() and conn._sock.pending():
- response = conn.recv()
- if response:
- responses.append(response)
+ responses.extend(conn.recv())
for conn in six.itervalues(self._conns):
if conn.requests_timed_out():
@@ -635,6 +624,7 @@ class KafkaClient(object):
if self._sensors:
self._sensors.io_time.record((time.time() - end_select) * 1000000000)
+
self._maybe_close_oldest_connection()
return responses
diff --git a/kafka/conn.py b/kafka/conn.py
index 61d63bf..949fca5 100644
--- a/kafka/conn.py
+++ b/kafka/conn.py
@@ -4,7 +4,6 @@ import collections
import copy
import errno
import logging
-import io
from random import shuffle, uniform
import socket
import time
@@ -18,6 +17,7 @@ from kafka.metrics.stats import Avg, Count, Max, Rate
from kafka.protocol.api import RequestHeader
from kafka.protocol.admin import SaslHandShakeRequest
from kafka.protocol.commit import GroupCoordinatorResponse, OffsetFetchRequest
+from kafka.protocol.frame import KafkaBytes
from kafka.protocol.metadata import MetadataRequest
from kafka.protocol.fetch import FetchRequest
from kafka.protocol.types import Int32
@@ -234,9 +234,9 @@ class BrokerConnection(object):
if self.config['ssl_context'] is not None:
self._ssl_context = self.config['ssl_context']
self._sasl_auth_future = None
- self._rbuffer = io.BytesIO()
+ self._header = KafkaBytes(4)
+ self._rbuffer = None
self._receiving = False
- self._next_payload_bytes = 0
self.last_attempt = 0
self._processing = False
self._correlation_id = 0
@@ -629,10 +629,7 @@ class BrokerConnection(object):
self.state = ConnectionStates.DISCONNECTED
self.last_attempt = time.time()
self._sasl_auth_future = None
- self._receiving = False
- self._next_payload_bytes = 0
- self._rbuffer.seek(0)
- self._rbuffer.truncate()
+ self._reset_buffer()
if error is None:
error = Errors.Cancelled(str(self))
while self.in_flight_requests:
@@ -640,6 +637,11 @@ class BrokerConnection(object):
ifr.future.failure(error)
self.config['state_change_callback'](self)
+ def _reset_buffer(self):
+ self._receiving = False
+ self._header.seek(0)
+ self._rbuffer = None
+
def send(self, request):
"""send request, return Future()
@@ -713,11 +715,11 @@ class BrokerConnection(object):
# fail all the pending request futures
if self.in_flight_requests:
self.close(Errors.ConnectionError('Socket not connected during recv with in-flight-requests'))
- return None
+ return ()
elif not self.in_flight_requests:
log.warning('%s: No in-flight-requests to recv', self)
- return None
+ return ()
response = self._recv()
if not response and self.requests_timed_out():
@@ -726,15 +728,15 @@ class BrokerConnection(object):
self.close(error=Errors.RequestTimedOutError(
'Request timed out after %s ms' %
self.config['request_timeout_ms']))
- return None
+ return ()
return response
def _recv(self):
- # Not receiving is the state of reading the payload header
- if not self._receiving:
+ responses = []
+ SOCK_CHUNK_BYTES = 4096
+ while True:
try:
- bytes_to_read = 4 - self._rbuffer.tell()
- data = self._sock.recv(bytes_to_read)
+ data = self._sock.recv(SOCK_CHUNK_BYTES)
# We expect socket.recv to raise an exception if there is not
# enough data to read the full bytes_to_read
# but if the socket is disconnected, we will get empty data
@@ -742,87 +744,92 @@ class BrokerConnection(object):
if not data:
log.error('%s: socket disconnected', self)
self.close(error=Errors.ConnectionError('socket disconnected'))
- return None
- self._rbuffer.write(data)
+ break
+ else:
+ responses.extend(self.receive_bytes(data))
+ if len(data) < SOCK_CHUNK_BYTES:
+ break
except SSLWantReadError:
- return None
+ break
except ConnectionError as e:
if six.PY2 and e.errno == errno.EWOULDBLOCK:
- return None
- log.exception('%s: Error receiving 4-byte payload header -'
+ break
+ log.exception('%s: Error receiving network data'
' closing socket', self)
self.close(error=Errors.ConnectionError(e))
- return None
- except BlockingIOError:
- if six.PY3:
- return None
- raise
-
- if self._rbuffer.tell() == 4:
- self._rbuffer.seek(0)
- self._next_payload_bytes = Int32.decode(self._rbuffer)
- # reset buffer and switch state to receiving payload bytes
- self._rbuffer.seek(0)
- self._rbuffer.truncate()
- self._receiving = True
- elif self._rbuffer.tell() > 4:
- raise Errors.KafkaError('this should not happen - are you threading?')
-
- if self._receiving:
- staged_bytes = self._rbuffer.tell()
- try:
- bytes_to_read = self._next_payload_bytes - staged_bytes
- data = self._sock.recv(bytes_to_read)
- # We expect socket.recv to raise an exception if there is not
- # enough data to read the full bytes_to_read
- # but if the socket is disconnected, we will get empty data
- # without an exception raised
- if bytes_to_read and not data:
- log.error('%s: socket disconnected', self)
- self.close(error=Errors.ConnectionError('socket disconnected'))
- return None
- self._rbuffer.write(data)
- except SSLWantReadError:
- return None
- except ConnectionError as e:
- # Extremely small chance that we have exactly 4 bytes for a
- # header, but nothing to read in the body yet
- if six.PY2 and e.errno == errno.EWOULDBLOCK:
- return None
- log.exception('%s: Error in recv', self)
- self.close(error=Errors.ConnectionError(e))
- return None
+ break
except BlockingIOError:
if six.PY3:
- return None
+ break
raise
+ return responses
- staged_bytes = self._rbuffer.tell()
- if staged_bytes > self._next_payload_bytes:
- self.close(error=Errors.KafkaError('Receive buffer has more bytes than expected?'))
-
- if staged_bytes != self._next_payload_bytes:
- return None
+ def receive_bytes(self, data):
+ i = 0
+ n = len(data)
+ responses = []
+ if self._sensors:
+ self._sensors.bytes_received.record(n)
+ while i < n:
+
+ # Not receiving is the state of reading the payload header
+ if not self._receiving:
+ bytes_to_read = min(4 - self._header.tell(), n - i)
+ self._header.write(data[i:i+bytes_to_read])
+ i += bytes_to_read
+
+ if self._header.tell() == 4:
+ self._header.seek(0)
+ nbytes = Int32.decode(self._header)
+ # reset buffer and switch state to receiving payload bytes
+ self._rbuffer = KafkaBytes(nbytes)
+ self._receiving = True
+ elif self._header.tell() > 4:
+ raise Errors.KafkaError('this should not happen - are you threading?')
+
+
+ if self._receiving:
+ total_bytes = len(self._rbuffer)
+ staged_bytes = self._rbuffer.tell()
+ bytes_to_read = min(total_bytes - staged_bytes, n - i)
+ self._rbuffer.write(data[i:i+bytes_to_read])
+ i += bytes_to_read
+
+ staged_bytes = self._rbuffer.tell()
+ if staged_bytes > total_bytes:
+ self.close(error=Errors.KafkaError('Receive buffer has more bytes than expected?'))
+
+ if staged_bytes != total_bytes:
+ break
- self._receiving = False
- self._next_payload_bytes = 0
- if self._sensors:
- self._sensors.bytes_received.record(4 + self._rbuffer.tell())
- self._rbuffer.seek(0)
- response = self._process_response(self._rbuffer)
- self._rbuffer.seek(0)
- self._rbuffer.truncate()
- return response
+ self._receiving = False
+ self._rbuffer.seek(0)
+ resp = self._process_response(self._rbuffer)
+ if resp is not None:
+ responses.append(resp)
+ self._reset_buffer()
+ return responses
def _process_response(self, read_buffer):
assert not self._processing, 'Recursion not supported'
self._processing = True
- ifr = self.in_flight_requests.popleft()
+ recv_correlation_id = Int32.decode(read_buffer)
+
+ if not self.in_flight_requests:
+ error = Errors.CorrelationIdError(
+ '%s: No in-flight-request found for server response'
+ ' with correlation ID %d'
+ % (self, recv_correlation_id))
+ self.close(error)
+ self._processing = False
+ return None
+ else:
+ ifr = self.in_flight_requests.popleft()
+
if self._sensors:
self._sensors.request_time.record((time.time() - ifr.timestamp) * 1000)
# verify send/recv correlation ids match
- recv_correlation_id = Int32.decode(read_buffer)
# 0.8.2 quirk
if (self.config['api_version'] == (0, 8, 2) and
diff --git a/kafka/protocol/message.py b/kafka/protocol/message.py
index efdf4fc..70d5b36 100644
--- a/kafka/protocol/message.py
+++ b/kafka/protocol/message.py
@@ -6,6 +6,7 @@ import time
from ..codec import (has_gzip, has_snappy, has_lz4,
gzip_decode, snappy_decode,
lz4_decode, lz4_decode_old_kafka)
+from .frame import KafkaBytes
from .struct import Struct
from .types import (
Int8, Int32, Int64, Bytes, Schema, AbstractType
@@ -155,10 +156,10 @@ class MessageSet(AbstractType):
@classmethod
def encode(cls, items):
# RecordAccumulator encodes messagesets internally
- if isinstance(items, io.BytesIO):
+ if isinstance(items, (io.BytesIO, KafkaBytes)):
size = Int32.decode(items)
# rewind and return all the bytes
- items.seek(-4, 1)
+ items.seek(items.tell() - 4)
return items.read(size + 4)
encoded_values = []
@@ -198,7 +199,7 @@ class MessageSet(AbstractType):
@classmethod
def repr(cls, messages):
- if isinstance(messages, io.BytesIO):
+ if isinstance(messages, (KafkaBytes, io.BytesIO)):
offset = messages.tell()
decoded = cls.decode(messages)
messages.seek(offset)