summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDana Powers <dana.powers@gmail.com>2017-03-13 16:41:38 -0700
committerGitHub <noreply@github.com>2017-03-13 16:41:38 -0700
commit47004bbd026fc9267f5cf15b96bb4b2d2bb1dc78 (patch)
treefbe0b05bab468e1f3868c739685ae7f7e0474bf3
parent92a66e3009147a9909f32df2adedce831b7fc7fb (diff)
downloadkafka-python-47004bbd026fc9267f5cf15b96bb4b2d2bb1dc78.tar.gz
Avoid re-encoding for message crc check (#1027)
-rw-r--r--kafka/protocol/message.py18
-rw-r--r--test/test_protocol.py24
2 files changed, 36 insertions, 6 deletions
diff --git a/kafka/protocol/message.py b/kafka/protocol/message.py
index ec5ee6c..efdf4fc 100644
--- a/kafka/protocol/message.py
+++ b/kafka/protocol/message.py
@@ -48,6 +48,7 @@ class Message(Struct):
timestamp = int(time.time() * 1000)
self.timestamp = timestamp
self.crc = crc
+ self._validated_crc = None
self.magic = magic
self.attributes = attributes
self.key = key
@@ -85,7 +86,9 @@ class Message(Struct):
@classmethod
def decode(cls, data):
+ _validated_crc = None
if isinstance(data, bytes):
+ _validated_crc = crc32(data[4:])
data = io.BytesIO(data)
# Partial decode required to determine message version
base_fields = cls.SCHEMAS[0].fields[0:3]
@@ -96,14 +99,17 @@ class Message(Struct):
timestamp = fields[0]
else:
timestamp = None
- return cls(fields[-1], key=fields[-2],
- magic=magic, attributes=attributes, crc=crc,
- timestamp=timestamp)
+ msg = cls(fields[-1], key=fields[-2],
+ magic=magic, attributes=attributes, crc=crc,
+ timestamp=timestamp)
+ msg._validated_crc = _validated_crc
+ return msg
def validate_crc(self):
- raw_msg = self._encode_self(recalc_crc=False)
- crc = crc32(raw_msg[4:])
- if crc == self.crc:
+ if self._validated_crc is None:
+ raw_msg = self._encode_self(recalc_crc=False)
+ self._validated_crc = crc32(raw_msg[4:])
+ if self.crc == self._validated_crc:
return True
return False
diff --git a/test/test_protocol.py b/test/test_protocol.py
index aa3dd17..0203614 100644
--- a/test/test_protocol.py
+++ b/test/test_protocol.py
@@ -67,6 +67,30 @@ def test_decode_message():
assert decoded_message == msg
+def test_decode_message_validate_crc():
+ encoded = b''.join([
+ struct.pack('>i', -1427009701), # CRC
+ struct.pack('>bb', 0, 0), # Magic, flags
+ struct.pack('>i', 3), # Length of key
+ b'key', # key
+ struct.pack('>i', 4), # Length of value
+ b'test', # value
+ ])
+ decoded_message = Message.decode(encoded)
+ assert decoded_message.validate_crc() is True
+
+ encoded = b''.join([
+ struct.pack('>i', 1234), # Incorrect CRC
+ struct.pack('>bb', 0, 0), # Magic, flags
+ struct.pack('>i', 3), # Length of key
+ b'key', # key
+ struct.pack('>i', 4), # Length of value
+ b'test', # value
+ ])
+ decoded_message = Message.decode(encoded)
+ assert decoded_message.validate_crc() is False
+
+
def test_encode_message_set():
messages = [
Message(b'v1', key=b'k1'),