summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/py/Makefile.am1
-rw-r--r--lib/py/src/transport/TSocket.py26
-rw-r--r--lib/py/test/test_socket.py57
3 files changed, 83 insertions, 1 deletions
diff --git a/lib/py/Makefile.am b/lib/py/Makefile.am
index c6c5ff38c..b16305790 100644
--- a/lib/py/Makefile.am
+++ b/lib/py/Makefile.am
@@ -50,6 +50,7 @@ check-local: all py3-test
$(PYTHON) test/thrift_json.py
$(PYTHON) test/thrift_transport.py
$(PYTHON) test/test_sslsocket.py
+ $(PYTHON) test/test_socket.py
$(PYTHON) test/thrift_TBinaryProtocol.py
$(PYTHON) test/thrift_TZlibTransport.py
$(PYTHON) test/thrift_TCompactProtocol.py
diff --git a/lib/py/src/transport/TSocket.py b/lib/py/src/transport/TSocket.py
index 9886fa2f8..3c7a3ca7d 100644
--- a/lib/py/src/transport/TSocket.py
+++ b/lib/py/src/transport/TSocket.py
@@ -74,7 +74,31 @@ class TSocket(TSocketBase):
self.handle = h
def isOpen(self):
- return self.handle is not None
+ if self.handle is None:
+ return False
+
+ # this lets us cheaply see if the other end of the socket is still
+ # connected. if disconnected, we'll get EOF back (expressed as zero
+ # bytes of data) otherwise we'll get one byte or an error indicating
+ # we'd have to block for data.
+ #
+ # note that we're not doing this with socket.MSG_DONTWAIT because 1)
+ # it's linux-specific and 2) gevent-patched sockets hide EAGAIN from us
+ # when timeout is non-zero.
+ original_timeout = self.handle.gettimeout()
+ try:
+ self.handle.settimeout(0)
+ try:
+ peeked_bytes = self.handle.recv(1, socket.MSG_PEEK)
+ except (socket.error, OSError) as exc: # on modern python this is just BlockingIOError
+ if exc.errno in (errno.EWOULDBLOCK, errno.EAGAIN):
+ return True
+ return False
+ finally:
+ self.handle.settimeout(original_timeout)
+
+ # the length will be zero if we got EOF (indicating connection closed)
+ return len(peeked_bytes) == 1
def setTimeout(self, ms):
if ms is None:
diff --git a/lib/py/test/test_socket.py b/lib/py/test/test_socket.py
new file mode 100644
index 000000000..95124dcbe
--- /dev/null
+++ b/lib/py/test/test_socket.py
@@ -0,0 +1,57 @@
+import errno
+import unittest
+
+from test_sslsocket import ServerAcceptor
+
+import _import_local_thrift # noqa
+
+from thrift.transport.TSocket import TServerSocket
+from thrift.transport.TSocket import TSocket
+from thrift.transport.TTransport import TTransportException
+
+
+class TSocketTest(unittest.TestCase):
+ def test_isOpen_checks_for_readability(self):
+ # https://docs.python.org/3/library/socket.html#notes-on-socket-timeouts
+ # https://docs.python.org/3/library/socket.html#socket.socket.settimeout
+ timeouts = [
+ None, # blocking mode
+ 0, # non-blocking mode
+ 1.0, # timeout mode
+ ]
+
+ for timeout in timeouts:
+ acc = ServerAcceptor(TServerSocket(port=0))
+ acc.start()
+
+ sock = TSocket(host="localhost", port=acc.port)
+ sock.open()
+ sock.setTimeout(timeout)
+
+ # the socket shows as open immediately after connecting
+ self.assertTrue(sock.isOpen())
+
+ # and remains open during usage
+ sock.write(b"hello")
+ self.assertTrue(sock.isOpen())
+ while True:
+ try:
+ sock.read(5)
+ except TTransportException as exc:
+ if exc.inner.errno == errno.EAGAIN:
+ # try again when we're in non-blocking mode
+ continue
+ raise
+ break
+ self.assertTrue(sock.isOpen())
+
+ # once the server side closes, it no longer shows open
+ acc.client.close() # this also blocks until the other thread is done
+ acc.close()
+ self.assertFalse(sock.isOpen())
+
+ sock.close()
+
+
+if __name__ == "__main__":
+ unittest.main()