diff options
author | Dana Powers <dana.powers@gmail.com> | 2016-07-14 22:22:52 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-07-14 22:22:52 -0700 |
commit | 916c25726f6238c5af92728aa8df8d8fddd809a7 (patch) | |
tree | 9cffe99e43c7df93e421611c539dfc7cdeedac67 | |
parent | 1eb7e05c323322818fb60192f638d6b83f2fd1ef (diff) | |
parent | ca9d2fabc352f5b6f2709295df7382f5dd7bfc97 (diff) | |
download | kafka-python-916c25726f6238c5af92728aa8df8d8fddd809a7.tar.gz |
Merge pull request #757 from dpkp/double_compression
Fix bug causing KafkaProducer to double-compress message batches
-rw-r--r-- | kafka/producer/buffer.py | 39 | ||||
-rw-r--r-- | test/test_buffer.py | 70 |
2 files changed, 93 insertions, 16 deletions
```diff
diff --git a/kafka/producer/buffer.py b/kafka/producer/buffer.py
index 5dc2e1f..5fcb35f 100644
--- a/kafka/producer/buffer.py
+++ b/kafka/producer/buffer.py
@@ -89,22 +89,29 @@ class MessageSetBuffer(object):
         return self._buffer.tell() >= self._batch_size
 
     def close(self):
-        if self._compressor:
-            # TODO: avoid copies with bytearray / memoryview
-            self._buffer.seek(4)
-            msg = Message(self._compressor(self._buffer.read()),
-                          attributes=self._compression_attributes,
-                          magic=self._message_version)
-            encoded = msg.encode()
-            self._buffer.seek(4)
-            self._buffer.write(Int64.encode(0)) # offset 0 for wrapper msg
-            self._buffer.write(Int32.encode(len(encoded)))
-            self._buffer.write(encoded)
-
-        # Update the message set size, and return ready for full read()
-        size = self._buffer.tell() - 4
-        self._buffer.seek(0)
-        self._buffer.write(Int32.encode(size))
+        # This method may be called multiple times on the same batch
+        # i.e., on retries
+        # we need to make sure we only close it out once
+        # otherwise compressed messages may be double-compressed
+        # see Issue 718
+        if not self._closed:
+            if self._compressor:
+                # TODO: avoid copies with bytearray / memoryview
+                self._buffer.seek(4)
+                msg = Message(self._compressor(self._buffer.read()),
+                              attributes=self._compression_attributes,
+                              magic=self._message_version)
+                encoded = msg.encode()
+                self._buffer.seek(4)
+                self._buffer.write(Int64.encode(0)) # offset 0 for wrapper msg
+                self._buffer.write(Int32.encode(len(encoded)))
+                self._buffer.write(encoded)
+
+            # Update the message set size, and return ready for full read()
+            size = self._buffer.tell() - 4
+            self._buffer.seek(0)
+            self._buffer.write(Int32.encode(size))
+
         self._buffer.seek(0)
         self._closed = True
 
diff --git a/test/test_buffer.py b/test/test_buffer.py
new file mode 100644
index 0000000..c8e283d
--- /dev/null
+++ b/test/test_buffer.py
@@ -0,0 +1,70 @@
+# pylint: skip-file
+from __future__ import absolute_import
+
+import io
+
+import pytest
+
+from kafka.producer.buffer import MessageSetBuffer
+from kafka.protocol.message import Message, MessageSet
+
+
+def test_buffer_close():
+    records = MessageSetBuffer(io.BytesIO(), 100000)
+    orig_msg = Message(b'foobar')
+    records.append(1234, orig_msg)
+    records.close()
+
+    msgset = MessageSet.decode(records.buffer())
+    assert len(msgset) == 1
+    (offset, size, msg) = msgset[0]
+    assert offset == 1234
+    assert msg == orig_msg
+
+    # Closing again should work fine
+    records.close()
+
+    msgset = MessageSet.decode(records.buffer())
+    assert len(msgset) == 1
+    (offset, size, msg) = msgset[0]
+    assert offset == 1234
+    assert msg == orig_msg
+
+
+@pytest.mark.parametrize('compression', [
+    'gzip',
+    'snappy',
+    pytest.mark.skipif("sys.version_info < (2,7)")('lz4'), # lz4tools does not work on py26
+])
+def test_compressed_buffer_close(compression):
+    records = MessageSetBuffer(io.BytesIO(), 100000, compression_type=compression)
+    orig_msg = Message(b'foobar')
+    records.append(1234, orig_msg)
+    records.close()
+
+    msgset = MessageSet.decode(records.buffer())
+    assert len(msgset) == 1
+    (offset, size, msg) = msgset[0]
+    assert offset == 0
+    assert msg.is_compressed()
+
+    msgset = msg.decompress()
+    (offset, size, msg) = msgset[0]
+    assert not msg.is_compressed()
+    assert offset == 1234
+    assert msg == orig_msg
+
+    # Closing again should work fine
+    records.close()
+
+    msgset = MessageSet.decode(records.buffer())
+    assert len(msgset) == 1
+    (offset, size, msg) = msgset[0]
+    assert offset == 0
+    assert msg.is_compressed()
+
+    msgset = msg.decompress()
+    (offset, size, msg) = msgset[0]
+    assert not msg.is_compressed()
+    assert offset == 1234
+    assert msg == orig_msg
```