diff options
author | Daniel Holth <dholth@fastmail.fm> | 2019-11-17 22:45:52 -0500 |
---|---|---|
committer | Paul Kehrer <paul.l.kehrer@gmail.com> | 2019-11-18 11:45:52 +0800 |
commit | 079c963ddd4ebfd13a905829bc341dce85d94fbd (patch) | |
tree | 5ca14d29d4e46fd7328d3372e1bd4b0c8db76ae5 /src | |
parent | 8543286168ed3bb234395d66fb401714495ff198 (diff) | |
download | pyopenssl-git-079c963ddd4ebfd13a905829bc341dce85d94fbd.tar.gz |
use _ffi.from_buffer() to support bytearray (#852)
* use _ffi.from_buffer(buf) in send, to support bytearray
* add bytearray test
* update CHANGELOG.rst
* move from_buffer before 'buffer too long' check
* context-managed from_buffer + black
* don't shadow buf in send()
* test return count for sendall
* test sending an array
* fix test
* also use from_buffer in bio_write
* de-format _util.py
* formatting
* add simple bio_write tests
* wrap line
Diffstat (limited to 'src')
-rw-r--r-- | src/OpenSSL/SSL.py | 72 | ||||
-rw-r--r-- | src/OpenSSL/_util.py | 14 |
2 files changed, 49 insertions, 37 deletions
diff --git a/src/OpenSSL/SSL.py b/src/OpenSSL/SSL.py index 5521151..adcfd8f 100644 --- a/src/OpenSSL/SSL.py +++ b/src/OpenSSL/SSL.py @@ -15,6 +15,7 @@ from OpenSSL._util import ( UNSPECIFIED as _UNSPECIFIED, exception_from_error_queue as _exception_from_error_queue, ffi as _ffi, + from_buffer as _from_buffer, lib as _lib, make_assert as _make_assert, native as _native, @@ -1730,18 +1731,18 @@ class Connection(object): # Backward compatibility buf = _text_to_bytes_and_warn("buf", buf) - if isinstance(buf, memoryview): - buf = buf.tobytes() - if isinstance(buf, _buffer): - buf = str(buf) - if not isinstance(buf, bytes): - raise TypeError("data must be a memoryview, buffer or byte string") - if len(buf) > 2147483647: - raise ValueError("Cannot send more than 2**31-1 bytes at once.") + with _from_buffer(buf) as data: + # check len(buf) instead of len(data) for testability + if len(buf) > 2147483647: + raise ValueError( + "Cannot send more than 2**31-1 bytes at once." + ) + + result = _lib.SSL_write(self._ssl, data, len(data)) + self._raise_ssl_error(self._ssl, result) + + return result - result = _lib.SSL_write(self._ssl, buf, len(buf)) - self._raise_ssl_error(self._ssl, result) - return result write = send def sendall(self, buf, flags=0): @@ -1757,28 +1758,24 @@ class Connection(object): """ buf = _text_to_bytes_and_warn("buf", buf) - if isinstance(buf, memoryview): - buf = buf.tobytes() - if isinstance(buf, _buffer): - buf = str(buf) - if not isinstance(buf, bytes): - raise TypeError("buf must be a memoryview, buffer or byte string") - - left_to_send = len(buf) - total_sent = 0 - data = _ffi.new("char[]", buf) - - while left_to_send: - # SSL_write's num arg is an int, - # so we cannot send more than 2**31-1 bytes at once. - result = _lib.SSL_write( - self._ssl, - data + total_sent, - min(left_to_send, 2147483647) - ) - self._raise_ssl_error(self._ssl, result) - total_sent += result - left_to_send -= result + with _from_buffer(buf) as data: + + left_to_send = len(buf) + total_sent = 0 + + while left_to_send: + # SSL_write's num arg is an int, + # so we cannot send more than 2**31-1 bytes at once. + result = _lib.SSL_write( + self._ssl, + data + total_sent, + min(left_to_send, 2147483647) + ) + self._raise_ssl_error(self._ssl, result) + total_sent += result + left_to_send -= result + + return total_sent def recv(self, bufsiz, flags=None): """ @@ -1892,10 +1889,11 @@ class Connection(object): if self._into_ssl is None: raise TypeError("Connection sock was not None") - result = _lib.BIO_write(self._into_ssl, buf, len(buf)) - if result <= 0: - self._handle_bio_errors(self._into_ssl, result) - return result + with _from_buffer(buf) as data: + result = _lib.BIO_write(self._into_ssl, data, len(data)) + if result <= 0: + self._handle_bio_errors(self._into_ssl, result) + return result def renegotiate(self): """ diff --git a/src/OpenSSL/_util.py b/src/OpenSSL/_util.py index cdcacc8..d8e3f66 100644 --- a/src/OpenSSL/_util.py +++ b/src/OpenSSL/_util.py @@ -145,3 +145,17 @@ def text_to_bytes_and_warn(label, obj): ) return obj.encode('utf-8') return obj + + +try: + # newer versions of cffi free the buffer deterministically + with ffi.from_buffer(b""): + pass + from_buffer = ffi.from_buffer +except AttributeError: + # cffi < 0.12 frees the buffer with refcounting gc + from contextlib import contextmanager + + @contextmanager + def from_buffer(*args): + yield ffi.from_buffer(*args) |