summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCorentin Henry <corentinhenry@gmail.com>2018-11-27 17:01:06 -0800
committerCorentin Henry <corentinhenry@gmail.com>2018-11-28 13:37:10 -0800
commit6540900dae21571a60fd92037e926cd6599a52eb (patch)
tree6508711d2174ac9b26c445ca3064257be3d781da
parent5f157bbaca5ae62a5bb71547106beb6ef02bc485 (diff)
downloaddocker-py-6540900dae21571a60fd92037e926cd6599a52eb.tar.gz
add tests for _read_from_socket
Check that the return value against the various combination of parameters this function can take (tty, stream, and demux). This commit also fixes a bug that the tests uncovered a bug in consume_socket_output. Signed-off-by: Corentin Henry <corentinhenry@gmail.com>
-rw-r--r--docker/utils/socket.py16
-rw-r--r--tests/unit/api_test.py98
2 files changed, 91 insertions, 23 deletions
diff --git a/docker/utils/socket.py b/docker/utils/socket.py
index fe4a332..4b32853 100644
--- a/docker/utils/socket.py
+++ b/docker/utils/socket.py
@@ -136,15 +136,17 @@ def consume_socket_output(frames, demux=False):
# we just need to concatenate.
return six.binary_type().join(frames)
- # If the streams are demultiplexed, the generator returns tuples
- # (stdin, stdout, stderr)
+ # If the streams are demultiplexed, the generator yields tuples
+ # (stdout, stderr)
out = [six.binary_type(), six.binary_type()]
for frame in frames:
- for stream_id in [STDOUT, STDERR]:
- # It is guaranteed that for each frame, one and only one stream
- # is not None.
- if frame[stream_id] is not None:
- out[stream_id] += frame[stream_id]
+ # It is guaranteed that for each frame, one and only one stream
+ # is not None.
+ assert frame != (None, None)
+ if frame[0] is not None:
+ out[0] += frame[0]
+ else:
+ out[1] += frame[1]
return tuple(out)
diff --git a/tests/unit/api_test.py b/tests/unit/api_test.py
index ccddbb1..0f5ad7c 100644
--- a/tests/unit/api_test.py
+++ b/tests/unit/api_test.py
@@ -15,6 +15,7 @@ from docker.api import APIClient
import requests
from requests.packages import urllib3
import six
+import struct
from . import fake_api
@@ -467,24 +468,25 @@ class UnixSocketStreamTest(unittest.TestCase):
class TCPSocketStreamTest(unittest.TestCase):
- text_data = b'''
+ stdout_data = b'''
Now, those children out there, they're jumping through the
flames in the hope that the god of the fire will make them fruitful.
Really, you can't blame them. After all, what girl would not prefer the
child of a god to that of some acne-scarred artisan?
'''
+ stderr_data = b'''
+ And what of the true God? To whose glory churches and monasteries have been
+ built on these islands for generations past? Now shall what of Him?
+ '''
def setUp(self):
-
self.server = six.moves.socketserver.ThreadingTCPServer(
- ('', 0), self.get_handler_class()
- )
+ ('', 0), self.get_handler_class())
self.thread = threading.Thread(target=self.server.serve_forever)
self.thread.setDaemon(True)
self.thread.start()
self.address = 'http://{}:{}'.format(
- socket.gethostname(), self.server.server_address[1]
- )
+ socket.gethostname(), self.server.server_address[1])
def tearDown(self):
self.server.shutdown()
@@ -492,31 +494,95 @@ class TCPSocketStreamTest(unittest.TestCase):
self.thread.join()
def get_handler_class(self):
- text_data = self.text_data
+ stdout_data = self.stdout_data
+ stderr_data = self.stderr_data
class Handler(six.moves.BaseHTTPServer.BaseHTTPRequestHandler, object):
def do_POST(self):
+ resp_data = self.get_resp_data()
self.send_response(101)
self.send_header(
- 'Content-Type', 'application/vnd.docker.raw-stream'
- )
+ 'Content-Type', 'application/vnd.docker.raw-stream')
self.send_header('Connection', 'Upgrade')
self.send_header('Upgrade', 'tcp')
self.end_headers()
self.wfile.flush()
time.sleep(0.2)
- self.wfile.write(text_data)
+ self.wfile.write(resp_data)
self.wfile.flush()
+ def get_resp_data(self):
+ path = self.path.split('/')[-1]
+ if path == 'tty':
+ return stdout_data + stderr_data
+ elif path == 'no-tty':
+ data = b''
+ data += self.frame_header(1, stdout_data)
+ data += stdout_data
+ data += self.frame_header(2, stderr_data)
+ data += stderr_data
+ return data
+ else:
+ raise Exception('Unknown path {0}'.format(path))
+
+ @staticmethod
+ def frame_header(stream, data):
+ return struct.pack('>BxxxL', stream, len(data))
+
return Handler
- def test_read_from_socket(self):
+ def request(self, stream=None, tty=None, demux=None):
+ assert stream is not None and tty is not None and demux is not None
with APIClient(base_url=self.address) as client:
- resp = client._post(client._url('/dummy'), stream=True)
- data = client._read_from_socket(resp, stream=True, tty=True)
- results = b''.join(data)
-
- assert results == self.text_data
+ if tty:
+ url = client._url('/tty')
+ else:
+ url = client._url('/no-tty')
+ resp = client._post(url, stream=True)
+ return client._read_from_socket(
+ resp, stream=stream, tty=tty, demux=demux)
+
+ def test_read_from_socket_1(self):
+ res = self.request(stream=True, tty=True, demux=False)
+ assert next(res) == self.stdout_data + self.stderr_data
+ with self.assertRaises(StopIteration):
+ next(res)
+
+ def test_read_from_socket_2(self):
+ res = self.request(stream=True, tty=True, demux=True)
+ assert next(res) == (self.stdout_data + self.stderr_data, None)
+ with self.assertRaises(StopIteration):
+ next(res)
+
+ def test_read_from_socket_3(self):
+ res = self.request(stream=True, tty=False, demux=False)
+ assert next(res) == self.stdout_data
+ assert next(res) == self.stderr_data
+ with self.assertRaises(StopIteration):
+ next(res)
+
+ def test_read_from_socket_4(self):
+ res = self.request(stream=True, tty=False, demux=True)
+ assert (self.stdout_data, None) == next(res)
+ assert (None, self.stderr_data) == next(res)
+ with self.assertRaises(StopIteration):
+ next(res)
+
+ def test_read_from_socket_5(self):
+ res = self.request(stream=False, tty=True, demux=False)
+ assert res == self.stdout_data + self.stderr_data
+
+ def test_read_from_socket_6(self):
+ res = self.request(stream=False, tty=True, demux=True)
+ assert res == (self.stdout_data + self.stderr_data, b'')
+
+ def test_read_from_socket_7(self):
+ res = self.request(stream=False, tty=False, demux=False)
+ res == self.stdout_data + self.stderr_data
+
+ def test_read_from_socket_8(self):
+ res = self.request(stream=False, tty=False, demux=True)
+ assert res == (self.stdout_data, self.stderr_data)
class UserAgentTest(unittest.TestCase):