diff options
author | j-a-n <oss@janschneider.net> | 2023-01-26 10:13:20 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-01-26 09:13:20 +0000 |
commit | 8af1cbb17f5f12c9999fc56adf7b0187a1714174 (patch) | |
tree | 1087efd7fd1c1779d2b17d8bdc04b67f3f05849a | |
parent | f0d2d0e8bae6cee1f57c00e9ccb295c85643731d (diff) | |
download | websocket-client-8af1cbb17f5f12c9999fc56adf7b0187a1714174.tar.gz |
Refactor and fix reconnect and ping (#862)
* Refactor and fix reconnect and ping
- Dispatcher: do not reconnect in loop
- Avoid reconnect in recursion and hit maximum recursion depth
- Close websocket before reconnect to avoid unclosed ssl.SSLSocket warnings
- Start ping thread on socket connect
- Stop ping thread on socket disconnect
- Add reconnect test
- Add callback exception test
* Remove duplicate function
* Also check passed app in testCallbackException
* Test exception in callback method
Co-authored-by: Jan Schneider <j.schneider@uib.de>
Co-authored-by: engn33r <engn33r@users.noreply.github.com>
-rw-r--r-- | websocket/_app.py | 108 | ||||
-rw-r--r-- | websocket/tests/test_app.py | 85 |
2 files changed, 156 insertions, 37 deletions
diff --git a/websocket/_app.py b/websocket/_app.py index e5a4484..0a16ddb 100644 --- a/websocket/_app.py +++ b/websocket/_app.py @@ -4,11 +4,12 @@ import sys import threading import time import traceback + +from . import _logging from ._abnf import ABNF from ._url import parse_url from ._core import WebSocket, getdefaulttimeout from ._exceptions import * -from . import _logging """ _app.py @@ -53,10 +54,9 @@ class DispatcherBase: def reconnect(self, seconds, reconnector): try: - while True: - _logging.info("reconnect() - retrying in %s seconds [%s frames in stack]" % (seconds, len(inspect.stack()))) - time.sleep(seconds) - reconnector(reconnecting=True) + _logging.info("reconnect() - retrying in %s seconds [%s frames in stack]" % (seconds, len(inspect.stack()))) + time.sleep(seconds) + reconnector(reconnecting=True) except KeyboardInterrupt as e: _logging.info("User exited %s" % (e,)) @@ -214,6 +214,11 @@ class WebSocketApp: self.sock = None self.last_ping_tm = 0 self.last_pong_tm = 0 + self.ping_thread = None + self.stop_ping = None + self.ping_interval = 0 + self.ping_timeout = None + self.ping_payload = "" self.subprotocols = subprotocols self.prepared_socket = socket self.has_errored = False @@ -244,15 +249,31 @@ class WebSocketApp: self.sock.close(**kwargs) self.sock = None - def _send_ping(self, interval, event, payload): - while not event.wait(interval): - self.last_ping_tm = time.time() + def _start_ping_thread(self): + self.last_ping_tm = self.last_pong_tm = 0 + self.stop_ping = threading.Event() + self.ping_thread = threading.Thread(target=self._send_ping) + self.ping_thread.daemon = True + self.ping_thread.start() + + def _stop_ping_thread(self): + if self.stop_ping: + self.stop_ping.set() + if self.ping_thread and self.ping_thread.is_alive(): + self.ping_thread.join(3) + self.last_ping_tm = self.last_pong_tm = 0 + + def _send_ping(self): + if self.stop_ping.wait(self.ping_interval): + return + while not self.stop_ping.wait(self.ping_interval): if self.sock: + self.last_ping_tm = time.time() try: - self.sock.ping(payload) + _logging.debug("Sending ping") + self.sock.ping(self.ping_payload) except Exception as ex: - _logging.warning("send_ping routine terminated: {}".format(ex)) - break + _logging.debug("Failed to send ping: %s", ex) def run_forever(self, sockopt=None, sslopt=None, ping_interval=0, ping_timeout=None, @@ -331,10 +352,11 @@ class WebSocketApp: sslopt = {} if self.sock: raise WebSocketException("socket is already opened") - thread = None + + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout + self.ping_payload = ping_payload self.keep_running = True - self.last_ping_tm = 0 - self.last_pong_tm = 0 def teardown(close_frame=None): """ @@ -347,9 +369,7 @@ class WebSocketApp: with the statusCode and reason from the provided frame. """ - if thread and thread.is_alive(): - event.set() - thread.join() + self._stop_ping_thread() self.keep_running = False if self.sock: self.sock.close() @@ -361,11 +381,15 @@ class WebSocketApp: self._callback(self.on_close, close_status_code, close_reason) def setSock(reconnecting=False): + if reconnecting and self.sock: + self.sock.shutdown() + self.sock = WebSocket( self.get_mask_key, sockopt=sockopt, sslopt=sslopt, fire_cont_frame=self.on_cont_message is not None, skip_utf8_validation=skip_utf8_validation, enable_multithread=True) + self.sock.settimeout(getdefaulttimeout()) try: self.sock.connect( @@ -377,13 +401,16 @@ class WebSocketApp: host=host, origin=origin, suppress_origin=suppress_origin, proxy_type=proxy_type, socket=self.prepared_socket) + _logging.info("Websocket connected") + + if self.ping_interval: + self._start_ping_thread() + self._callback(self.on_open) - _logging.info("websocket connected") dispatcher.read(self.sock.sock, read, check) except (WebSocketConnectionClosedException, ConnectionRefusedError, KeyboardInterrupt, SystemExit, Exception) as e: - _logging.error("%s - %s" % (e, reconnect and "reconnecting" or "goodbye")) - reconnecting or handleDisconnect(e) + handleDisconnect(e, reconnecting) def read(): if not self.keep_running: @@ -396,6 +423,7 @@ class WebSocketApp: return handleDisconnect(e) else: raise e + if op_code == ABNF.OPCODE_CLOSE: return teardown(frame) elif op_code == ABNF.OPCODE_PING: @@ -418,10 +446,10 @@ class WebSocketApp: return True def check(): - if (ping_timeout): - has_timeout_expired = time.time() - self.last_ping_tm > ping_timeout + if (self.ping_timeout): + has_timeout_expired = time.time() - self.last_ping_tm > self.ping_timeout has_pong_not_arrived_after_last_ping = self.last_pong_tm - self.last_ping_tm < 0 - has_pong_arrived_too_late = self.last_pong_tm - self.last_ping_tm > ping_timeout + has_pong_arrived_too_late = self.last_pong_tm - self.last_ping_tm > self.ping_timeout if (self.last_ping_tm and has_timeout_expired and @@ -429,29 +457,35 @@ class WebSocketApp: raise WebSocketTimeoutException("ping/pong timed out") return True - def handleDisconnect(e): + def handleDisconnect(e, reconnecting=False): self.has_errored = True - self._callback(self.on_error, e) - if isinstance(e, SystemExit): - # propagate SystemExit further + self._stop_ping_thread() + if not reconnecting: + self._callback(self.on_error, e) + + if isinstance(e, (KeyboardInterrupt, SystemExit)): + teardown() + # Propagate further raise - if reconnect and not isinstance(e, KeyboardInterrupt): - _logging.info("websocket disconnected (retrying in %s seconds) [%s frames in stack]" % (reconnect, len(inspect.stack()))) - dispatcher.reconnect(reconnect, setSock) + + if reconnect: + _logging.info("%s - reconnect" % e) + if custom_dispatcher: + _logging.debug("Calling custom dispatcher reconnect [%s frames in stack]" % len(inspect.stack())) + dispatcher.reconnect(reconnect, setSock) else: + _logging.error("%s - goodbye" % e) teardown() custom_dispatcher = bool(dispatcher) dispatcher = self.create_dispatcher(ping_timeout, dispatcher, parse_url(self.url)[3]) - if ping_interval: - event = threading.Event() - thread = threading.Thread( - target=self._send_ping, args=(ping_interval, event, ping_payload)) - thread.daemon = True - thread.start() - setSock() + if not custom_dispatcher and reconnect: + while self.keep_running: + _logging.debug("Calling dispatcher reconnect [%s frames in stack]" % len(inspect.stack())) + dispatcher.reconnect(reconnect, setSock) + return self.has_errored def create_dispatcher(self, ping_timeout, dispatcher=None, is_ssl=False): diff --git a/websocket/tests/test_app.py b/websocket/tests/test_app.py index 063023f..09ba348 100644 --- a/websocket/tests/test_app.py +++ b/websocket/tests/test_app.py @@ -229,6 +229,91 @@ class WebSocketAppTest(unittest.TestCase): self.assertRaises(ws.WebSocketConnectionClosedException, app.send, data="test if connection is closed") + @unittest.skipUnless(TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled") + def testCallbackFunctionException(self): + """ Test callback function exception handling """ + + exc = None + passed_app = None + + def on_open(app): + raise RuntimeError("Callback failed") + + def on_error(app, err): + nonlocal passed_app + passed_app = app + nonlocal exc + exc = err + + def on_pong(app, msg): + app.close() + + app = ws.WebSocketApp('ws://127.0.0.1:' + LOCAL_WS_SERVER_PORT, on_open=on_open, on_error=on_error, on_pong=on_pong) + app.run_forever(ping_interval=2, ping_timeout=1) + + self.assertEqual(passed_app, app) + self.assertIsInstance(exc, RuntimeError) + self.assertEqual(str(exc), "Callback failed") + + @unittest.skipUnless(TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled") + def testCallbackMethodException(self): + """ Test callback method exception handling """ + + class Callbacks: + def __init__(self): + self.exc = None + self.passed_app = None + self.app = ws.WebSocketApp( + 'ws://127.0.0.1:' + LOCAL_WS_SERVER_PORT, + on_open=self.on_open, + on_error=self.on_error, + on_pong=self.on_pong + ) + self.app.run_forever(ping_interval=2, ping_timeout=1) + + def on_open(self, app): + raise RuntimeError("Callback failed") + + def on_error(self, app, err): + self.passed_app = app + self.exc = err + + def on_pong(self, app, msg): + app.close() + + callbacks = Callbacks() + + self.assertEqual(callbacks.passed_app, callbacks.app) + self.assertIsInstance(callbacks.exc, RuntimeError) + self.assertEqual(str(callbacks.exc), "Callback failed") + + @unittest.skipUnless(TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled") + def testReconnect(self): + """ Test reconnect """ + pong_count = 0 + exc = None + + def on_error(app, err): + nonlocal exc + exc = err + + def on_pong(app, msg): + nonlocal pong_count + pong_count += 1 + if pong_count == 1: + # First pong, shutdown socket, enforce read error + app.sock.shutdown() + if pong_count >= 2: + # Got second pong after reconnect + app.close() + + app = ws.WebSocketApp('ws://127.0.0.1:' + LOCAL_WS_SERVER_PORT, on_pong=on_pong, on_error=on_error) + app.run_forever(ping_interval=2, ping_timeout=1, reconnect=3) + + self.assertEqual(pong_count, 2) + self.assertIsInstance(exc, ValueError) + self.assertEqual(str(exc), "Invalid file object: None") + if __name__ == "__main__": unittest.main() |