summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDana Powers <dana.powers@gmail.com>2016-07-14 22:22:52 -0700
committerGitHub <noreply@github.com>2016-07-14 22:22:52 -0700
commit916c25726f6238c5af92728aa8df8d8fddd809a7 (patch)
tree9cffe99e43c7df93e421611c539dfc7cdeedac67
parent1eb7e05c323322818fb60192f638d6b83f2fd1ef (diff)
parentca9d2fabc352f5b6f2709295df7382f5dd7bfc97 (diff)
downloadkafka-python-916c25726f6238c5af92728aa8df8d8fddd809a7.tar.gz
Merge pull request #757 from dpkp/double_compression
Fix bug causing KafkaProducer to double-compress message batches
-rw-r--r--kafka/producer/buffer.py39
-rw-r--r--test/test_buffer.py70
2 files changed, 93 insertions, 16 deletions
diff --git a/kafka/producer/buffer.py b/kafka/producer/buffer.py
index 5dc2e1f..5fcb35f 100644
--- a/kafka/producer/buffer.py
+++ b/kafka/producer/buffer.py
@@ -89,22 +89,29 @@ class MessageSetBuffer(object):
return self._buffer.tell() >= self._batch_size
def close(self):
- if self._compressor:
- # TODO: avoid copies with bytearray / memoryview
- self._buffer.seek(4)
- msg = Message(self._compressor(self._buffer.read()),
- attributes=self._compression_attributes,
- magic=self._message_version)
- encoded = msg.encode()
- self._buffer.seek(4)
- self._buffer.write(Int64.encode(0)) # offset 0 for wrapper msg
- self._buffer.write(Int32.encode(len(encoded)))
- self._buffer.write(encoded)
-
- # Update the message set size, and return ready for full read()
- size = self._buffer.tell() - 4
- self._buffer.seek(0)
- self._buffer.write(Int32.encode(size))
+ # This method may be called multiple times on the same batch
+ # i.e., on retries
+ # we need to make sure we only close it out once
+ # otherwise compressed messages may be double-compressed
+ # see Issue 718
+ if not self._closed:
+ if self._compressor:
+ # TODO: avoid copies with bytearray / memoryview
+ self._buffer.seek(4)
+ msg = Message(self._compressor(self._buffer.read()),
+ attributes=self._compression_attributes,
+ magic=self._message_version)
+ encoded = msg.encode()
+ self._buffer.seek(4)
+ self._buffer.write(Int64.encode(0)) # offset 0 for wrapper msg
+ self._buffer.write(Int32.encode(len(encoded)))
+ self._buffer.write(encoded)
+
+ # Update the message set size, and return ready for full read()
+ size = self._buffer.tell() - 4
+ self._buffer.seek(0)
+ self._buffer.write(Int32.encode(size))
+
self._buffer.seek(0)
self._closed = True
diff --git a/test/test_buffer.py b/test/test_buffer.py
new file mode 100644
index 0000000..c8e283d
--- /dev/null
+++ b/test/test_buffer.py
@@ -0,0 +1,70 @@
+# pylint: skip-file
+from __future__ import absolute_import
+
+import io
+
+import pytest
+
+from kafka.producer.buffer import MessageSetBuffer
+from kafka.protocol.message import Message, MessageSet
+
+
+def test_buffer_close():
+ records = MessageSetBuffer(io.BytesIO(), 100000)
+ orig_msg = Message(b'foobar')
+ records.append(1234, orig_msg)
+ records.close()
+
+ msgset = MessageSet.decode(records.buffer())
+ assert len(msgset) == 1
+ (offset, size, msg) = msgset[0]
+ assert offset == 1234
+ assert msg == orig_msg
+
+ # Closing again should work fine
+ records.close()
+
+ msgset = MessageSet.decode(records.buffer())
+ assert len(msgset) == 1
+ (offset, size, msg) = msgset[0]
+ assert offset == 1234
+ assert msg == orig_msg
+
+
+@pytest.mark.parametrize('compression', [
+ 'gzip',
+ 'snappy',
+ pytest.mark.skipif("sys.version_info < (2,7)")('lz4'), # lz4tools does not work on py26
+])
+def test_compressed_buffer_close(compression):
+ records = MessageSetBuffer(io.BytesIO(), 100000, compression_type=compression)
+ orig_msg = Message(b'foobar')
+ records.append(1234, orig_msg)
+ records.close()
+
+ msgset = MessageSet.decode(records.buffer())
+ assert len(msgset) == 1
+ (offset, size, msg) = msgset[0]
+ assert offset == 0
+ assert msg.is_compressed()
+
+ msgset = msg.decompress()
+ (offset, size, msg) = msgset[0]
+ assert not msg.is_compressed()
+ assert offset == 1234
+ assert msg == orig_msg
+
+ # Closing again should work fine
+ records.close()
+
+ msgset = MessageSet.decode(records.buffer())
+ assert len(msgset) == 1
+ (offset, size, msg) = msgset[0]
+ assert offset == 0
+ assert msg.is_compressed()
+
+ msgset = msg.decompress()
+ (offset, size, msg) = msgset[0]
+ assert not msg.is_compressed()
+ assert offset == 1234
+ assert msg == orig_msg