summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMartin Packman <gzlist@googlemail.com>2017-05-25 20:37:14 +0100
committerMartin Packman <gzlist@googlemail.com>2017-05-25 20:57:21 +0100
commit2f74a9f80787e47cf5b8275f83196694a37cc69d (patch)
tree384dbb4779b77860ac5d45bdee6631ccc436bbfb
parentd4c3426d7b08f51f15239dbcb8a3b8f4c5685f35 (diff)
downloadparamiko-2f74a9f80787e47cf5b8275f83196694a37cc69d.tar.gz
Allow any buffer type to written to BufferedFile
Fixes #967 Also adds test coverage for writing various types to BufferedFile which required some small changes to the test LoopbackFile subclass. Change against the 1.17 branch.
-rw-r--r--paramiko/file.py7
-rwxr-xr-xtests/test_file.py57
2 files changed, 52 insertions, 12 deletions
diff --git a/paramiko/file.py b/paramiko/file.py
index 5b57dfd6..bdaa7ed8 100644
--- a/paramiko/file.py
+++ b/paramiko/file.py
@@ -17,7 +17,7 @@
# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
from paramiko.common import linefeed_byte_value, crlf, cr_byte, linefeed_byte, \
cr_byte_value
-from paramiko.py3compat import BytesIO, PY2, u, b, bytes_types
+from paramiko.py3compat import BytesIO, PY2, u, bytes_types, text_type
from paramiko.util import ClosingContextManager
@@ -372,7 +372,10 @@ class BufferedFile (ClosingContextManager):
:param str/bytes data: data to write
"""
- data = b(data)
+ if isinstance(data, text_type):
+ # GZ 2017-05-25: Accepting text on a binary stream unconditionally
+ # cooercing to utf-8 seems questionable, but compatibility reasons?
+ data = data.encode('utf-8')
if self._closed:
raise IOError('File is closed')
if not (self._flags & self.FLAG_WRITE):
diff --git a/tests/test_file.py b/tests/test_file.py
index 7fab6985..b33ecd51 100755
--- a/tests/test_file.py
+++ b/tests/test_file.py
@@ -21,10 +21,14 @@ Some unit tests for the BufferedFile abstraction.
"""
import unittest
-from paramiko.file import BufferedFile
-from paramiko.common import linefeed_byte, crlf, cr_byte
import sys
+from paramiko.common import linefeed_byte, crlf, cr_byte
+from paramiko.file import BufferedFile
+from paramiko.py3compat import BytesIO
+
+from tests import skipUnlessBuiltin
+
class LoopbackFile (BufferedFile):
"""
@@ -33,19 +37,16 @@ class LoopbackFile (BufferedFile):
def __init__(self, mode='r', bufsize=-1):
BufferedFile.__init__(self)
self._set_mode(mode, bufsize)
- self.buffer = bytes()
+ self.buffer = BytesIO()
+ self.offset = 0
def _read(self, size):
- if len(self.buffer) == 0:
- return None
- if size > len(self.buffer):
- size = len(self.buffer)
- data = self.buffer[:size]
- self.buffer = self.buffer[size:]
+ data = self.buffer.getvalue()[self.offset:self.offset+size]
+ self.offset += len(data)
return data
def _write(self, data):
- self.buffer += data
+ self.buffer.write(data)
return len(data)
@@ -187,6 +188,42 @@ class BufferedFileTest (unittest.TestCase):
self.assertEqual(data, b'hello')
f.close()
+ def test_write_bad_type(self):
+ with LoopbackFile('wb') as f:
+ self.assertRaises(TypeError, f.write, object())
+
+ def test_write_unicode_as_binary(self):
+ text = u"\xa7 why is writing text to a binary file allowed?\n"
+ with LoopbackFile('rb+') as f:
+ f.write(text)
+ self.assertEqual(f.read(), text.encode("utf-8"))
+
+ @skipUnlessBuiltin('memoryview')
+ def test_write_bytearray(self):
+ with LoopbackFile('rb+') as f:
+ f.write(bytearray(12))
+ self.assertEqual(f.read(), 12 * b"\0")
+
+ @skipUnlessBuiltin('buffer')
+ def test_write_buffer(self):
+ data = 3 * b"pretend giant block of data\n"
+ offsets = range(0, len(data), 8)
+ with LoopbackFile('rb+') as f:
+ for offset in offsets:
+ f.write(buffer(data, offset, 8))
+ self.assertEqual(f.read(), data)
+
+ @skipUnlessBuiltin('memoryview')
+ def test_write_memoryview(self):
+ data = 3 * b"pretend giant block of data\n"
+ offsets = range(0, len(data), 8)
+ with LoopbackFile('rb+') as f:
+ view = memoryview(data)
+ for offset in offsets:
+ f.write(view[offset:offset+8])
+ self.assertEqual(f.read(), data)
+
+
if __name__ == '__main__':
from unittest import main
main()