diff options
author | Dana Powers <dana.powers@gmail.com> | 2017-03-13 16:41:38 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-03-13 16:41:38 -0700 |
commit | 47004bbd026fc9267f5cf15b96bb4b2d2bb1dc78 (patch) | |
tree | fbe0b05bab468e1f3868c739685ae7f7e0474bf3 | |
parent | 92a66e3009147a9909f32df2adedce831b7fc7fb (diff) | |
download | kafka-python-47004bbd026fc9267f5cf15b96bb4b2d2bb1dc78.tar.gz |
Avoid re-encoding for message crc check (#1027)
-rw-r--r-- | kafka/protocol/message.py | 18 | ||||
-rw-r--r-- | test/test_protocol.py | 24 |
2 files changed, 36 insertions, 6 deletions
diff --git a/kafka/protocol/message.py b/kafka/protocol/message.py
index ec5ee6c..efdf4fc 100644
--- a/kafka/protocol/message.py
+++ b/kafka/protocol/message.py
@@ -48,6 +48,7 @@ class Message(Struct):
             timestamp = int(time.time() * 1000)
         self.timestamp = timestamp
         self.crc = crc
+        self._validated_crc = None
         self.magic = magic
         self.attributes = attributes
         self.key = key
@@ -85,7 +86,9 @@ class Message(Struct):
 
     @classmethod
     def decode(cls, data):
+        _validated_crc = None
         if isinstance(data, bytes):
+            _validated_crc = crc32(data[4:])
             data = io.BytesIO(data)
         # Partial decode required to determine message version
         base_fields = cls.SCHEMAS[0].fields[0:3]
@@ -96,14 +99,17 @@ class Message(Struct):
             timestamp = fields[0]
         else:
             timestamp = None
-        return cls(fields[-1], key=fields[-2],
-                   magic=magic, attributes=attributes, crc=crc,
-                   timestamp=timestamp)
+        msg = cls(fields[-1], key=fields[-2],
+                  magic=magic, attributes=attributes, crc=crc,
+                  timestamp=timestamp)
+        msg._validated_crc = _validated_crc
+        return msg
 
     def validate_crc(self):
-        raw_msg = self._encode_self(recalc_crc=False)
-        crc = crc32(raw_msg[4:])
-        if crc == self.crc:
+        if self._validated_crc is None:
+            raw_msg = self._encode_self(recalc_crc=False)
+            self._validated_crc = crc32(raw_msg[4:])
+        if self.crc == self._validated_crc:
             return True
         return False
 
diff --git a/test/test_protocol.py b/test/test_protocol.py
index aa3dd17..0203614 100644
--- a/test/test_protocol.py
+++ b/test/test_protocol.py
@@ -67,6 +67,30 @@ def test_decode_message():
     assert decoded_message == msg
 
 
+def test_decode_message_validate_crc():
+    encoded = b''.join([
+        struct.pack('>i', -1427009701), # CRC
+        struct.pack('>bb', 0, 0), # Magic, flags
+        struct.pack('>i', 3), # Length of key
+        b'key', # key
+        struct.pack('>i', 4), # Length of value
+        b'test', # value
+    ])
+    decoded_message = Message.decode(encoded)
+    assert decoded_message.validate_crc() is True
+
+    encoded = b''.join([
+        struct.pack('>i', 1234), # Incorrect CRC
+        struct.pack('>bb', 0, 0), # Magic, flags
+        struct.pack('>i', 3), # Length of key
+        b'key', # key
+        struct.pack('>i', 4), # Length of value
+        b'test', # value
+    ])
+    decoded_message = Message.decode(encoded)
+    assert decoded_message.validate_crc() is False
+
+
 def test_encode_message_set():
     messages = [
         Message(b'v1', key=b'k1'),