summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorj-a-n <oss@janschneider.net>2023-01-26 10:13:20 +0100
committerGitHub <noreply@github.com>2023-01-26 09:13:20 +0000
commit8af1cbb17f5f12c9999fc56adf7b0187a1714174 (patch)
tree1087efd7fd1c1779d2b17d8bdc04b67f3f05849a
parentf0d2d0e8bae6cee1f57c00e9ccb295c85643731d (diff)
downloadwebsocket-client-8af1cbb17f5f12c9999fc56adf7b0187a1714174.tar.gz
Refactor and fix reconnect and ping (#862)
* Refactor and fix reconnect and ping - Dispatcher: do not reconnect in loop - Avoid reconnect in recursion and hit maximum recursion depth - Close websocket before reconnect to avoid unclosed ssl.SSLSocket warnings - Start ping thread on socket connect - Stop ping thread on socket disconnect - Add reconnect test - Add callback exception test * Remove duplicate function * Also check passed app in testCallbackException * Test exception in callback method Co-authored-by: Jan Schneider <j.schneider@uib.de> Co-authored-by: engn33r <engn33r@users.noreply.github.com>
-rw-r--r--websocket/_app.py108
-rw-r--r--websocket/tests/test_app.py85
2 files changed, 156 insertions, 37 deletions
diff --git a/websocket/_app.py b/websocket/_app.py
index e5a4484..0a16ddb 100644
--- a/websocket/_app.py
+++ b/websocket/_app.py
@@ -4,11 +4,12 @@ import sys
import threading
import time
import traceback
+
+from . import _logging
from ._abnf import ABNF
from ._url import parse_url
from ._core import WebSocket, getdefaulttimeout
from ._exceptions import *
-from . import _logging
"""
_app.py
@@ -53,10 +54,9 @@ class DispatcherBase:
def reconnect(self, seconds, reconnector):
try:
- while True:
- _logging.info("reconnect() - retrying in %s seconds [%s frames in stack]" % (seconds, len(inspect.stack())))
- time.sleep(seconds)
- reconnector(reconnecting=True)
+ _logging.info("reconnect() - retrying in %s seconds [%s frames in stack]" % (seconds, len(inspect.stack())))
+ time.sleep(seconds)
+ reconnector(reconnecting=True)
except KeyboardInterrupt as e:
_logging.info("User exited %s" % (e,))
@@ -214,6 +214,11 @@ class WebSocketApp:
self.sock = None
self.last_ping_tm = 0
self.last_pong_tm = 0
+ self.ping_thread = None
+ self.stop_ping = None
+ self.ping_interval = 0
+ self.ping_timeout = None
+ self.ping_payload = ""
self.subprotocols = subprotocols
self.prepared_socket = socket
self.has_errored = False
@@ -244,15 +249,31 @@ class WebSocketApp:
self.sock.close(**kwargs)
self.sock = None
- def _send_ping(self, interval, event, payload):
- while not event.wait(interval):
- self.last_ping_tm = time.time()
+ def _start_ping_thread(self):
+ self.last_ping_tm = self.last_pong_tm = 0
+ self.stop_ping = threading.Event()
+ self.ping_thread = threading.Thread(target=self._send_ping)
+ self.ping_thread.daemon = True
+ self.ping_thread.start()
+
+ def _stop_ping_thread(self):
+ if self.stop_ping:
+ self.stop_ping.set()
+ if self.ping_thread and self.ping_thread.is_alive():
+ self.ping_thread.join(3)
+ self.last_ping_tm = self.last_pong_tm = 0
+
+ def _send_ping(self):
+ if self.stop_ping.wait(self.ping_interval):
+ return
+ while not self.stop_ping.wait(self.ping_interval):
if self.sock:
+ self.last_ping_tm = time.time()
try:
- self.sock.ping(payload)
+ _logging.debug("Sending ping")
+ self.sock.ping(self.ping_payload)
except Exception as ex:
- _logging.warning("send_ping routine terminated: {}".format(ex))
- break
+ _logging.debug("Failed to send ping: %s", ex)
def run_forever(self, sockopt=None, sslopt=None,
ping_interval=0, ping_timeout=None,
@@ -331,10 +352,11 @@ class WebSocketApp:
sslopt = {}
if self.sock:
raise WebSocketException("socket is already opened")
- thread = None
+
+ self.ping_interval = ping_interval
+ self.ping_timeout = ping_timeout
+ self.ping_payload = ping_payload
self.keep_running = True
- self.last_ping_tm = 0
- self.last_pong_tm = 0
def teardown(close_frame=None):
"""
@@ -347,9 +369,7 @@ class WebSocketApp:
with the statusCode and reason from the provided frame.
"""
- if thread and thread.is_alive():
- event.set()
- thread.join()
+ self._stop_ping_thread()
self.keep_running = False
if self.sock:
self.sock.close()
@@ -361,11 +381,15 @@ class WebSocketApp:
self._callback(self.on_close, close_status_code, close_reason)
def setSock(reconnecting=False):
+ if reconnecting and self.sock:
+ self.sock.shutdown()
+
self.sock = WebSocket(
self.get_mask_key, sockopt=sockopt, sslopt=sslopt,
fire_cont_frame=self.on_cont_message is not None,
skip_utf8_validation=skip_utf8_validation,
enable_multithread=True)
+
self.sock.settimeout(getdefaulttimeout())
try:
self.sock.connect(
@@ -377,13 +401,16 @@ class WebSocketApp:
host=host, origin=origin, suppress_origin=suppress_origin,
proxy_type=proxy_type, socket=self.prepared_socket)
+ _logging.info("Websocket connected")
+
+ if self.ping_interval:
+ self._start_ping_thread()
+
self._callback(self.on_open)
- _logging.info("websocket connected")
dispatcher.read(self.sock.sock, read, check)
except (WebSocketConnectionClosedException, ConnectionRefusedError, KeyboardInterrupt, SystemExit, Exception) as e:
- _logging.error("%s - %s" % (e, reconnect and "reconnecting" or "goodbye"))
- reconnecting or handleDisconnect(e)
+ handleDisconnect(e, reconnecting)
def read():
if not self.keep_running:
@@ -396,6 +423,7 @@ class WebSocketApp:
return handleDisconnect(e)
else:
raise e
+
if op_code == ABNF.OPCODE_CLOSE:
return teardown(frame)
elif op_code == ABNF.OPCODE_PING:
@@ -418,10 +446,10 @@ class WebSocketApp:
return True
def check():
- if (ping_timeout):
- has_timeout_expired = time.time() - self.last_ping_tm > ping_timeout
+ if (self.ping_timeout):
+ has_timeout_expired = time.time() - self.last_ping_tm > self.ping_timeout
has_pong_not_arrived_after_last_ping = self.last_pong_tm - self.last_ping_tm < 0
- has_pong_arrived_too_late = self.last_pong_tm - self.last_ping_tm > ping_timeout
+ has_pong_arrived_too_late = self.last_pong_tm - self.last_ping_tm > self.ping_timeout
if (self.last_ping_tm and
has_timeout_expired and
@@ -429,29 +457,35 @@ class WebSocketApp:
raise WebSocketTimeoutException("ping/pong timed out")
return True
- def handleDisconnect(e):
+ def handleDisconnect(e, reconnecting=False):
self.has_errored = True
- self._callback(self.on_error, e)
- if isinstance(e, SystemExit):
- # propagate SystemExit further
+ self._stop_ping_thread()
+ if not reconnecting:
+ self._callback(self.on_error, e)
+
+ if isinstance(e, (KeyboardInterrupt, SystemExit)):
+ teardown()
+ # Propagate further
raise
- if reconnect and not isinstance(e, KeyboardInterrupt):
- _logging.info("websocket disconnected (retrying in %s seconds) [%s frames in stack]" % (reconnect, len(inspect.stack())))
- dispatcher.reconnect(reconnect, setSock)
+
+ if reconnect:
+ _logging.info("%s - reconnect" % e)
+ if custom_dispatcher:
+ _logging.debug("Calling custom dispatcher reconnect [%s frames in stack]" % len(inspect.stack()))
+ dispatcher.reconnect(reconnect, setSock)
else:
+ _logging.error("%s - goodbye" % e)
teardown()
custom_dispatcher = bool(dispatcher)
dispatcher = self.create_dispatcher(ping_timeout, dispatcher, parse_url(self.url)[3])
- if ping_interval:
- event = threading.Event()
- thread = threading.Thread(
- target=self._send_ping, args=(ping_interval, event, ping_payload))
- thread.daemon = True
- thread.start()
-
setSock()
+ if not custom_dispatcher and reconnect:
+ while self.keep_running:
+ _logging.debug("Calling dispatcher reconnect [%s frames in stack]" % len(inspect.stack()))
+ dispatcher.reconnect(reconnect, setSock)
+
return self.has_errored
def create_dispatcher(self, ping_timeout, dispatcher=None, is_ssl=False):
diff --git a/websocket/tests/test_app.py b/websocket/tests/test_app.py
index 063023f..09ba348 100644
--- a/websocket/tests/test_app.py
+++ b/websocket/tests/test_app.py
@@ -229,6 +229,91 @@ class WebSocketAppTest(unittest.TestCase):
self.assertRaises(ws.WebSocketConnectionClosedException, app.send, data="test if connection is closed")
+ @unittest.skipUnless(TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled")
+ def testCallbackFunctionException(self):
+ """ Test callback function exception handling """
+
+ exc = None
+ passed_app = None
+
+ def on_open(app):
+ raise RuntimeError("Callback failed")
+
+ def on_error(app, err):
+ nonlocal passed_app
+ passed_app = app
+ nonlocal exc
+ exc = err
+
+ def on_pong(app, msg):
+ app.close()
+
+ app = ws.WebSocketApp('ws://127.0.0.1:' + LOCAL_WS_SERVER_PORT, on_open=on_open, on_error=on_error, on_pong=on_pong)
+ app.run_forever(ping_interval=2, ping_timeout=1)
+
+ self.assertEqual(passed_app, app)
+ self.assertIsInstance(exc, RuntimeError)
+ self.assertEqual(str(exc), "Callback failed")
+
+ @unittest.skipUnless(TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled")
+ def testCallbackMethodException(self):
+ """ Test callback method exception handling """
+
+ class Callbacks:
+ def __init__(self):
+ self.exc = None
+ self.passed_app = None
+ self.app = ws.WebSocketApp(
+ 'ws://127.0.0.1:' + LOCAL_WS_SERVER_PORT,
+ on_open=self.on_open,
+ on_error=self.on_error,
+ on_pong=self.on_pong
+ )
+ self.app.run_forever(ping_interval=2, ping_timeout=1)
+
+ def on_open(self, app):
+ raise RuntimeError("Callback failed")
+
+ def on_error(self, app, err):
+ self.passed_app = app
+ self.exc = err
+
+ def on_pong(self, app, msg):
+ app.close()
+
+ callbacks = Callbacks()
+
+ self.assertEqual(callbacks.passed_app, callbacks.app)
+ self.assertIsInstance(callbacks.exc, RuntimeError)
+ self.assertEqual(str(callbacks.exc), "Callback failed")
+
+ @unittest.skipUnless(TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled")
+ def testReconnect(self):
+ """ Test reconnect """
+ pong_count = 0
+ exc = None
+
+ def on_error(app, err):
+ nonlocal exc
+ exc = err
+
+ def on_pong(app, msg):
+ nonlocal pong_count
+ pong_count += 1
+ if pong_count == 1:
+ # First pong, shutdown socket, enforce read error
+ app.sock.shutdown()
+ if pong_count >= 2:
+ # Got second pong after reconnect
+ app.close()
+
+ app = ws.WebSocketApp('ws://127.0.0.1:' + LOCAL_WS_SERVER_PORT, on_pong=on_pong, on_error=on_error)
+ app.run_forever(ping_interval=2, ping_timeout=1, reconnect=3)
+
+ self.assertEqual(pong_count, 2)
+ self.assertIsInstance(exc, ValueError)
+ self.assertEqual(str(exc), "Invalid file object: None")
+
if __name__ == "__main__":
unittest.main()