summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Dana Powers <dana.powers@gmail.com> 2016-06-01 16:49:17 -0700
committer: Dana Powers <dana.powers@gmail.com> 2016-06-01 16:49:17 -0700
commit: 8805d30b781b95786e8f6fc2fa0a24e6e2bd270d (patch)
tree: 00e11588d3c66dace75e19b9d2788f98c4e92de3
parent: 644a1141b0dd22e618277afe7b171b2f3fb8ca2d (diff)
download: kafka-python-8805d30b781b95786e8f6fc2fa0a24e6e2bd270d.tar.gz
Fix regression in MessageSet decoding wrt PartialMessages (#716)
-rw-r--r-- kafka/protocol/message.py 9
-rw-r--r-- test/test_protocol.py 102
2 files changed, 107 insertions, 4 deletions
diff --git a/kafka/protocol/message.py b/kafka/protocol/message.py
index 78840fc..656c131 100644
--- a/kafka/protocol/message.py
+++ b/kafka/protocol/message.py
@@ -169,14 +169,17 @@ class MessageSet(AbstractType):
data = io.BytesIO(data)
if bytes_to_read is None:
bytes_to_read = Int32.decode(data)
- items = []
# if FetchRequest max_bytes is smaller than the available message set
# the server returns partial data for the final message
+ # So create an internal buffer to avoid over-reading
+ raw = io.BytesIO(data.read(bytes_to_read))
+
+ items = []
while bytes_to_read:
try:
- offset = Int64.decode(data)
- msg_bytes = Bytes.decode(data)
+ offset = Int64.decode(raw)
+ msg_bytes = Bytes.decode(raw)
bytes_to_read -= 8 + 4 + len(msg_bytes)
items.append((offset, len(msg_bytes), Message.decode(msg_bytes)))
except ValueError:
diff --git a/test/test_protocol.py b/test/test_protocol.py
index 247fcc3..2b52f48 100644
--- a/test/test_protocol.py
+++ b/test/test_protocol.py
@@ -1,4 +1,5 @@
#pylint: skip-file
+import io
import struct
import pytest
@@ -6,7 +7,9 @@ import six
from kafka.protocol.api import RequestHeader
from kafka.protocol.commit import GroupCoordinatorRequest
-from kafka.protocol.message import Message, MessageSet
+from kafka.protocol.fetch import FetchResponse
+from kafka.protocol.message import Message, MessageSet, PartialMessage
+from kafka.protocol.types import Int16, Int32, Int64, String
def test_create_message():
@@ -144,3 +147,100 @@ def test_encode_message_header():
req = GroupCoordinatorRequest[0]('foo')
header = RequestHeader(req, correlation_id=4, client_id='client3')
assert header.encode() == expect
+
+
+def test_decode_message_set_partial():  # a truncated trailing message must decode to a PartialMessage placeholder
+ encoded = b''.join([
+ struct.pack('>q', 0), # Msg Offset
+ struct.pack('>i', 18), # Msg Size
+ struct.pack('>i', 1474775406), # CRC
+ struct.pack('>bb', 0, 0), # Magic, flags
+ struct.pack('>i', 2), # Length of key
+ b'k1', # Key
+ struct.pack('>i', 2), # Length of value
+ b'v1', # Value
+
+ struct.pack('>q', 1), # Msg Offset
+ struct.pack('>i', 24), # Msg Size (larger than remaining MsgSet size)
+ struct.pack('>i', -16383415), # CRC
+ struct.pack('>bb', 0, 0), # Magic, flags
+ struct.pack('>i', 2), # Length of key
+ b'k2', # Key
+ struct.pack('>i', 8), # Length of value
+ b'ar', # Value (truncated)
+ ])
+
+ msgs = MessageSet.decode(encoded, bytes_to_read=len(encoded))  # explicit size: the buffer carries no Int32 length prefix
+ assert len(msgs) == 2  # one complete message plus one placeholder for the truncated one
+ msg1, msg2 = msgs
+
+ returned_offset1, message1_size, decoded_message1 = msg1  # entries are (offset, size, message) tuples
+ returned_offset2, message2_size, decoded_message2 = msg2
+
+ assert returned_offset1 == 0
+ message1 = Message(b'v1', key=b'k1')
+ message1.encode()  # NOTE(review): presumably fills in derived fields (e.g. crc) so equality holds - confirm
+ assert decoded_message1 == message1
+
+ assert returned_offset2 is None  # the placeholder entry carries neither offset nor size
+ assert message2_size is None
+ assert decoded_message2 == PartialMessage()
+
+
+def test_decode_fetch_response_partial():  # each partition's message set ends in a truncated message; decoding must stay aligned across partitions
+ encoded = b''.join([
+ Int32.encode(1), # Num Topics (Array)
+ String('utf-8').encode('foobar'),
+ Int32.encode(2), # Num Partitions (Array)
+ Int32.encode(0), # Partition id
+ Int16.encode(0), # Error Code
+ Int64.encode(1234), # Highwater offset
+ Int32.encode(60), # MessageSet size (30-byte complete msg + 30 bytes of the truncated msg)
+ Int64.encode(0), # Msg Offset
+ Int32.encode(18), # Msg Size
+ struct.pack('>i', 1474775406), # CRC
+ struct.pack('>bb', 0, 0), # Magic, flags
+ struct.pack('>i', 2), # Length of key
+ b'k1', # Key
+ struct.pack('>i', 2), # Length of value
+ b'v1', # Value
+
+ Int64.encode(1), # Msg Offset
+ struct.pack('>i', 24), # Msg Size (larger than remaining MsgSet size)
+ struct.pack('>i', -16383415), # CRC
+ struct.pack('>bb', 0, 0), # Magic, flags
+ struct.pack('>i', 2), # Length of key
+ b'k2', # Key
+ struct.pack('>i', 8), # Length of value
+ b'ar', # Value (truncated)
+ Int32.encode(1), # Partition id
+ Int16.encode(0), # Error Code
+ Int64.encode(2345), # Highwater offset
+ Int32.encode(60), # MessageSet size (must match the 60 bytes that follow, as above)
+ Int64.encode(0), # Msg Offset
+ Int32.encode(18), # Msg Size
+ struct.pack('>i', 1474775406), # CRC
+ struct.pack('>bb', 0, 0), # Magic, flags
+ struct.pack('>i', 2), # Length of key
+ b'k1', # Key
+ struct.pack('>i', 2), # Length of value
+ b'v1', # Value
+
+ Int64.encode(1), # Msg Offset
+ struct.pack('>i', 24), # Msg Size (larger than remaining MsgSet size)
+ struct.pack('>i', -16383415), # CRC
+ struct.pack('>bb', 0, 0), # Magic, flags
+ struct.pack('>i', 2), # Length of key
+ b'k2', # Key
+ struct.pack('>i', 8), # Length of value
+ b'ar', # Value (truncated)
+ ])
+
+ resp = FetchResponse[0].decode(io.BytesIO(encoded))
+ assert len(resp.topics) == 1
+ topic, partitions = resp.topics[0]
+ assert topic == 'foobar'
+ assert len(partitions) == 2
+ m1 = partitions[0][3]  # index 3 = message set (after partition id, error code, highwater offset)
+ assert len(m1) == 2
+ assert m1[1] == (None, None, PartialMessage())  # truncated trailing message is reported as a placeholder