diff options
author | Dana Powers <dana.powers@gmail.com> | 2016-06-01 16:49:17 -0700 |
---|---|---|
committer | Dana Powers <dana.powers@gmail.com> | 2016-06-01 16:49:17 -0700 |
commit | 8805d30b781b95786e8f6fc2fa0a24e6e2bd270d (patch) | |
tree | 00e11588d3c66dace75e19b9d2788f98c4e92de3 | |
parent | 644a1141b0dd22e618277afe7b171b2f3fb8ca2d (diff) | |
download | kafka-python-8805d30b781b95786e8f6fc2fa0a24e6e2bd270d.tar.gz |
Fix regression in MessageSet decoding wrt PartialMessages (#716)
-rw-r--r-- | kafka/protocol/message.py | 9 | ||||
-rw-r--r-- | test/test_protocol.py | 102 |
2 files changed, 107 insertions, 4 deletions
diff --git a/kafka/protocol/message.py b/kafka/protocol/message.py index 78840fc..656c131 100644 --- a/kafka/protocol/message.py +++ b/kafka/protocol/message.py @@ -169,14 +169,17 @@ class MessageSet(AbstractType): data = io.BytesIO(data) if bytes_to_read is None: bytes_to_read = Int32.decode(data) - items = [] # if FetchRequest max_bytes is smaller than the available message set # the server returns partial data for the final message + # So create an internal buffer to avoid over-reading + raw = io.BytesIO(data.read(bytes_to_read)) + + items = [] while bytes_to_read: try: - offset = Int64.decode(data) - msg_bytes = Bytes.decode(data) + offset = Int64.decode(raw) + msg_bytes = Bytes.decode(raw) bytes_to_read -= 8 + 4 + len(msg_bytes) items.append((offset, len(msg_bytes), Message.decode(msg_bytes))) except ValueError: diff --git a/test/test_protocol.py b/test/test_protocol.py index 247fcc3..2b52f48 100644 --- a/test/test_protocol.py +++ b/test/test_protocol.py @@ -1,4 +1,5 @@ #pylint: skip-file +import io import struct import pytest @@ -6,7 +7,9 @@ import six from kafka.protocol.api import RequestHeader from kafka.protocol.commit import GroupCoordinatorRequest -from kafka.protocol.message import Message, MessageSet +from kafka.protocol.fetch import FetchResponse +from kafka.protocol.message import Message, MessageSet, PartialMessage +from kafka.protocol.types import Int16, Int32, Int64, String def test_create_message(): @@ -144,3 +147,100 @@ def test_encode_message_header(): req = GroupCoordinatorRequest[0]('foo') header = RequestHeader(req, correlation_id=4, client_id='client3') assert header.encode() == expect + + +def test_decode_message_set_partial(): + encoded = b''.join([ + struct.pack('>q', 0), # Msg Offset + struct.pack('>i', 18), # Msg Size + struct.pack('>i', 1474775406), # CRC + struct.pack('>bb', 0, 0), # Magic, flags + struct.pack('>i', 2), # Length of key + b'k1', # Key + struct.pack('>i', 2), # Length of value + b'v1', # Value + + struct.pack('>q', 1), # Msg Offset + struct.pack('>i', 24), # Msg Size (larger than remaining MsgSet size) + struct.pack('>i', -16383415), # CRC + struct.pack('>bb', 0, 0), # Magic, flags + struct.pack('>i', 2), # Length of key + b'k2', # Key + struct.pack('>i', 8), # Length of value + b'ar', # Value (truncated) + ]) + + msgs = MessageSet.decode(encoded, bytes_to_read=len(encoded)) + assert len(msgs) == 2 + msg1, msg2 = msgs + + returned_offset1, message1_size, decoded_message1 = msg1 + returned_offset2, message2_size, decoded_message2 = msg2 + + assert returned_offset1 == 0 + message1 = Message(b'v1', key=b'k1') + message1.encode() + assert decoded_message1 == message1 + + assert returned_offset2 is None + assert message2_size is None + assert decoded_message2 == PartialMessage() + + +def test_decode_fetch_response_partial(): + encoded = b''.join([ + Int32.encode(1), # Num Topics (Array) + String('utf-8').encode('foobar'), + Int32.encode(2), # Num Partitions (Array) + Int32.encode(0), # Partition id + Int16.encode(0), # Error Code + Int64.encode(1234), # Highwater offset + Int32.encode(52), # MessageSet size + Int64.encode(0), # Msg Offset + Int32.encode(18), # Msg Size + struct.pack('>i', 1474775406), # CRC + struct.pack('>bb', 0, 0), # Magic, flags + struct.pack('>i', 2), # Length of key + b'k1', # Key + struct.pack('>i', 2), # Length of value + b'v1', # Value + + Int64.encode(1), # Msg Offset + struct.pack('>i', 24), # Msg Size (larger than remaining MsgSet size) + struct.pack('>i', -16383415), # CRC + struct.pack('>bb', 0, 0), # Magic, flags + struct.pack('>i', 2), # Length of key + b'k2', # Key + struct.pack('>i', 8), # Length of value + b'ar', # Value (truncated) + Int32.encode(1), + Int16.encode(0), + Int64.encode(2345), + Int32.encode(52), # MessageSet size + Int64.encode(0), # Msg Offset + Int32.encode(18), # Msg Size + struct.pack('>i', 1474775406), # CRC + struct.pack('>bb', 0, 0), # Magic, flags + struct.pack('>i', 2), # Length of key + b'k1', # Key + struct.pack('>i', 2), # Length of value + b'v1', # Value + + Int64.encode(1), # Msg Offset + struct.pack('>i', 24), # Msg Size (larger than remaining MsgSet size) + struct.pack('>i', -16383415), # CRC + struct.pack('>bb', 0, 0), # Magic, flags + struct.pack('>i', 2), # Length of key + b'k2', # Key + struct.pack('>i', 8), # Length of value + b'ar', # Value (truncated) + ]) + + resp = FetchResponse[0].decode(io.BytesIO(encoded)) + assert len(resp.topics) == 1 + topic, partitions = resp.topics[0] + assert topic == 'foobar' + assert len(partitions) == 2 + m1 = partitions[0][3] + assert len(m1) == 2 + assert m1[1] == (None, None, PartialMessage()) |