summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaniel Holth <dholth@fastmail.fm>2019-11-17 22:45:52 -0500
committerPaul Kehrer <paul.l.kehrer@gmail.com>2019-11-18 11:45:52 +0800
commit079c963ddd4ebfd13a905829bc341dce85d94fbd (patch)
tree5ca14d29d4e46fd7328d3372e1bd4b0c8db76ae5
parent8543286168ed3bb234395d66fb401714495ff198 (diff)
downloadpyopenssl-git-079c963ddd4ebfd13a905829bc341dce85d94fbd.tar.gz
use _ffi.from_buffer() to support bytearray (#852)
* use _ffi.from_buffer(buf) in send, to support bytearray * add bytearray test * update CHANGELOG.rst * move from_buffer before 'buffer too long' check * context-managed from_buffer + black * don't shadow buf in send() * test return count for sendall * test sending an array * fix test * also use from_buffer in bio_write * de-format _util.py * formatting * add simple bio_write tests * wrap line
-rw-r--r--.gitignore1
-rw-r--r--CHANGELOG.rst3
-rw-r--r--src/OpenSSL/SSL.py72
-rw-r--r--src/OpenSSL/_util.py14
-rw-r--r--tests/test_ssl.py42
5 files changed, 92 insertions, 40 deletions
diff --git a/.gitignore b/.gitignore
index 2040b80..e3057ad 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,3 +11,4 @@ doc/_build/
examples/simple/*.cert
examples/simple/*.pkey
.cache
+.mypy_cache \ No newline at end of file
diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 2bf74f5..e0c034d 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -28,7 +28,8 @@ Deprecations:
Changes:
^^^^^^^^
-*none*
+- Support ``bytearray`` in ``SSL.Connection.send()`` by using cffi's from_buffer.
+ `#852 <https://github.com/pyca/pyopenssl/pull/852>`_
----
diff --git a/src/OpenSSL/SSL.py b/src/OpenSSL/SSL.py
index 5521151..adcfd8f 100644
--- a/src/OpenSSL/SSL.py
+++ b/src/OpenSSL/SSL.py
@@ -15,6 +15,7 @@ from OpenSSL._util import (
UNSPECIFIED as _UNSPECIFIED,
exception_from_error_queue as _exception_from_error_queue,
ffi as _ffi,
+ from_buffer as _from_buffer,
lib as _lib,
make_assert as _make_assert,
native as _native,
@@ -1730,18 +1731,18 @@ class Connection(object):
# Backward compatibility
buf = _text_to_bytes_and_warn("buf", buf)
- if isinstance(buf, memoryview):
- buf = buf.tobytes()
- if isinstance(buf, _buffer):
- buf = str(buf)
- if not isinstance(buf, bytes):
- raise TypeError("data must be a memoryview, buffer or byte string")
- if len(buf) > 2147483647:
- raise ValueError("Cannot send more than 2**31-1 bytes at once.")
+ with _from_buffer(buf) as data:
+ # check len(buf) instead of len(data) for testability
+ if len(buf) > 2147483647:
+ raise ValueError(
+ "Cannot send more than 2**31-1 bytes at once."
+ )
+
+ result = _lib.SSL_write(self._ssl, data, len(data))
+ self._raise_ssl_error(self._ssl, result)
+
+ return result
- result = _lib.SSL_write(self._ssl, buf, len(buf))
- self._raise_ssl_error(self._ssl, result)
- return result
write = send
def sendall(self, buf, flags=0):
@@ -1757,28 +1758,24 @@ class Connection(object):
"""
buf = _text_to_bytes_and_warn("buf", buf)
- if isinstance(buf, memoryview):
- buf = buf.tobytes()
- if isinstance(buf, _buffer):
- buf = str(buf)
- if not isinstance(buf, bytes):
- raise TypeError("buf must be a memoryview, buffer or byte string")
-
- left_to_send = len(buf)
- total_sent = 0
- data = _ffi.new("char[]", buf)
-
- while left_to_send:
- # SSL_write's num arg is an int,
- # so we cannot send more than 2**31-1 bytes at once.
- result = _lib.SSL_write(
- self._ssl,
- data + total_sent,
- min(left_to_send, 2147483647)
- )
- self._raise_ssl_error(self._ssl, result)
- total_sent += result
- left_to_send -= result
+ with _from_buffer(buf) as data:
+
+ left_to_send = len(buf)
+ total_sent = 0
+
+ while left_to_send:
+ # SSL_write's num arg is an int,
+ # so we cannot send more than 2**31-1 bytes at once.
+ result = _lib.SSL_write(
+ self._ssl,
+ data + total_sent,
+ min(left_to_send, 2147483647)
+ )
+ self._raise_ssl_error(self._ssl, result)
+ total_sent += result
+ left_to_send -= result
+
+ return total_sent
def recv(self, bufsiz, flags=None):
"""
@@ -1892,10 +1889,11 @@ class Connection(object):
if self._into_ssl is None:
raise TypeError("Connection sock was not None")
- result = _lib.BIO_write(self._into_ssl, buf, len(buf))
- if result <= 0:
- self._handle_bio_errors(self._into_ssl, result)
- return result
+ with _from_buffer(buf) as data:
+ result = _lib.BIO_write(self._into_ssl, data, len(data))
+ if result <= 0:
+ self._handle_bio_errors(self._into_ssl, result)
+ return result
def renegotiate(self):
"""
diff --git a/src/OpenSSL/_util.py b/src/OpenSSL/_util.py
index cdcacc8..d8e3f66 100644
--- a/src/OpenSSL/_util.py
+++ b/src/OpenSSL/_util.py
@@ -145,3 +145,17 @@ def text_to_bytes_and_warn(label, obj):
)
return obj.encode('utf-8')
return obj
+
+
+try:
+ # newer versions of cffi free the buffer deterministically
+ with ffi.from_buffer(b""):
+ pass
+ from_buffer = ffi.from_buffer
+except AttributeError:
+ # cffi < 0.12 frees the buffer with refcounting gc
+ from contextlib import contextmanager
+
+ @contextmanager
+ def from_buffer(*args):
+ yield ffi.from_buffer(*args)
diff --git a/tests/test_ssl.py b/tests/test_ssl.py
index 6b9422c..16767e9 100644
--- a/tests/test_ssl.py
+++ b/tests/test_ssl.py
@@ -2087,6 +2087,29 @@ class TestConnection(object):
with pytest.raises(TypeError):
Connection(bad_context)
+ @pytest.mark.parametrize('bad_bio', [object(), None, 1, [1, 2, 3]])
+ def test_bio_write_wrong_args(self, bad_bio):
+ """
+ `Connection.bio_write` raises `TypeError` if called with a non-bytes
+ (or text) argument.
+ """
+ context = Context(TLSv1_METHOD)
+ connection = Connection(context, None)
+ with pytest.raises(TypeError):
+ connection.bio_write(bad_bio)
+
+ def test_bio_write(self):
+ """
+ `Connection.bio_write` does not raise if called with bytes or
+ bytearray, warns if called with text.
+ """
+ context = Context(TLSv1_METHOD)
+ connection = Connection(context, None)
+ connection.bio_write(b'xy')
+ connection.bio_write(bytearray(b'za'))
+ with pytest.warns(DeprecationWarning):
+ connection.bio_write(u'deprecated')
+
def test_get_context(self):
"""
`Connection.get_context` returns the `Context` instance used to
@@ -2807,6 +2830,8 @@ class TestConnectionSend(object):
connection = Connection(Context(TLSv1_METHOD), None)
with pytest.raises(TypeError):
connection.send(object())
+ with pytest.raises(TypeError):
+ connection.send([1, 2, 3])
def test_short_bytes(self):
"""
@@ -2845,6 +2870,16 @@ class TestConnectionSend(object):
assert count == 2
assert client.recv(2) == b'xy'
+ def test_short_bytearray(self):
+ """
+ When passed a short bytearray, `Connection.send` transmits all of
+ it and returns the number of bytes sent.
+ """
+ server, client = loopback()
+ count = server.send(bytearray(b'xy'))
+ assert count == 2
+ assert client.recv(2) == b'xy'
+
@skip_if_py3
def test_short_buffer(self):
"""
@@ -3015,6 +3050,8 @@ class TestConnectionSendall(object):
connection = Connection(Context(TLSv1_METHOD), None)
with pytest.raises(TypeError):
connection.sendall(object())
+ with pytest.raises(TypeError):
+ connection.sendall([1, 2, 3])
def test_short(self):
"""
@@ -3056,8 +3093,9 @@ class TestConnectionSendall(object):
`Connection.sendall` transmits all of them.
"""
server, client = loopback()
- server.sendall(buffer(b'x'))
- assert client.recv(1) == b'x'
+ count = server.sendall(buffer(b'xy'))
+ assert count == 2
+ assert client.recv(2) == b'xy'
def test_long(self):
"""