diff options
-rw-r--r-- | msgpack/_packer.pyx | 14 | ||||
-rw-r--r-- | msgpack/fallback.py | 6 | ||||
-rw-r--r-- | test/test_pack.py | 7 |
3 files changed, 23 insertions, 4 deletions
diff --git a/msgpack/_packer.pyx b/msgpack/_packer.pyx index f24aa70..5a81709 100644 --- a/msgpack/_packer.pyx +++ b/msgpack/_packer.pyx @@ -10,6 +10,8 @@ from msgpack import ExtType cdef extern from "Python.h": int PyMemoryView_Check(object obj) + int PyByteArray_Check(object obj) + int PyByteArray_CheckExact(object obj) cdef extern from "pack.h": @@ -39,6 +41,14 @@ cdef int DEFAULT_RECURSE_LIMIT=511 cdef size_t ITEM_LIMIT = (2**32)-1 +cdef inline int PyBytesLike_Check(object o): + return PyBytes_Check(o) or PyByteArray_Check(o) + + +cdef inline int PyBytesLike_CheckExact(object o): + return PyBytes_CheckExact(o) or PyByteArray_CheckExact(o) + + cdef class Packer(object): """ MessagePack Packer @@ -174,10 +184,10 @@ cdef class Packer(object): else: dval = o ret = msgpack_pack_double(&self.pk, dval) - elif PyBytes_CheckExact(o) if strict_types else PyBytes_Check(o): + elif PyBytesLike_CheckExact(o) if strict_types else PyBytesLike_Check(o): L = len(o) if L > ITEM_LIMIT: - raise PackValueError("bytes is too large") + raise PackValueError("%s is too large" % type(o).__name__) rawval = o ret = msgpack_pack_bin(&self.pk, L) if ret == 0: diff --git a/msgpack/fallback.py b/msgpack/fallback.py index 508fd06..a02cbe1 100644 --- a/msgpack/fallback.py +++ b/msgpack/fallback.py @@ -38,6 +38,8 @@ if hasattr(sys, 'pypy_version_info'): def write(self, s): if isinstance(s, memoryview): s = s.tobytes() + elif isinstance(s, bytearray): + s = bytes(s) self.builder.append(s) def getvalue(self): return self.builder.build() @@ -728,10 +730,10 @@ class Packer(object): default_used = True continue raise PackOverflowError("Integer value out of range") - if check(obj, bytes): + if check(obj, (bytes, bytearray)): n = len(obj) if n >= 2**32: - raise PackValueError("Bytes is too large") + raise PackValueError("%s is too large" % type(obj).__name__) self._pack_bin_header(n) return self._buffer.write(obj) if check(obj, Unicode): diff --git a/test/test_pack.py b/test/test_pack.py index e945902..a704fdb 100644 --- a/test/test_pack.py +++ b/test/test_pack.py @@ -58,6 +58,13 @@ def testPackBytes(): for td in test_data: check(td) +def testPackByteArrays(): + test_data = [ + bytearray(b""), bytearray(b"abcd"), (bytearray(b"defgh"),), + ] + for td in test_data: + check(td) + def testIgnoreUnicodeErrors(): re = unpackb(packb(b'abc\xeddef'), encoding='utf-8', unicode_errors='ignore', use_list=1) assert re == "abcdef" |