diff options
author | Gao <mishanyo1001@gmail.com> | 2022-09-07 23:54:28 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-09-07 21:54:28 +0600 |
commit | 8699920e050727d385a6d5a19c939e55a86688d6 (patch) | |
tree | 3447468c8590798aba3861121e47ac94fba81eac | |
parent | ec533af9c1c6e156a1fe754fddc2095ebdba8554 (diff) | |
download | kombu-8699920e050727d385a6d5a19c939e55a86688d6.tar.gz |
Solve Kombu filesystem transport not thread safe (#1593)
* Solve Kombu filesystem transport not thread safe
fix: #398
Currently only write lock used in msg/exchange file written. Cause
reading in other thread got some incomplete result.
1. Add timeout for the lock acquire.
2. Add Share locks when reading message from filesystem.
3. Add a unit test for the `lock` and `unlock`
4. Add a unit test to test the lock during message processing.
* Replace deprecated function.
-rw-r--r-- | kombu/transport/filesystem.py | 87 | ||||
-rw-r--r-- | t/unit/transport/test_filesystem.py | 112 |
2 files changed, 170 insertions, 29 deletions
diff --git a/kombu/transport/filesystem.py b/kombu/transport/filesystem.py index 92d8c4e4..eb07bba6 100644 --- a/kombu/transport/filesystem.py +++ b/kombu/transport/filesystem.py @@ -93,6 +93,7 @@ from __future__ import annotations import os import shutil +import signal import tempfile import uuid from collections import namedtuple @@ -111,6 +112,26 @@ from . import virtual VERSION = (1, 0, 0) __version__ = '.'.join(map(str, VERSION)) + +@contextmanager +def timeout_manager(seconds: int): + def timeout_handler(signum, frame): + # Now that flock retries automatically when interrupted, we need + # an exception to stop it + # This exception will propagate on the main thread, + # make sure you're calling flock there + raise InterruptedError + + original_handler = signal.signal(signal.SIGALRM, timeout_handler) + + try: + signal.alarm(seconds) + yield + finally: + signal.alarm(0) + signal.signal(signal.SIGALRM, original_handler) + + # needs win32all to work on Windows if os.name == 'nt': @@ -138,7 +159,7 @@ if os.name == 'nt': elif os.name == 'posix': import fcntl - from fcntl import LOCK_EX, LOCK_NB, LOCK_SH # noqa + from fcntl import LOCK_EX, LOCK_SH def lock(file, flags): """Create file lock.""" @@ -154,6 +175,21 @@ else: 'Filesystem plugin only defined for NT and POSIX platforms') +@contextmanager +def lock_with_timeout(file, flags, timeout: int = 1): + with timeout_manager(timeout): + try: + lock(file, flags) + yield + except InterruptedError: + # Catch the exception raised by the handler + # If we weren't raising an exception, + # flock would automatically retry on signals + raise BlockingIOError("Lock timed out") + finally: + unlock(file) + + exchange_queue_t = namedtuple("exchange_queue_t", ["routing_key", "pattern", "queue"]) @@ -168,18 +204,14 @@ class Channel(virtual.Channel): file = self.control_folder / f"{exchange}.exchange" if "w" in mode: self.control_folder.mkdir(exist_ok=True) - f_obj = file.open(mode) + lock_mode = LOCK_EX if "w" in mode else LOCK_SH - try: - if "w" in mode: - lock(f_obj, LOCK_EX) - yield f_obj - except OSError: - raise ChannelError(f"Cannot open {file}") - finally: - if "w" in mode: - unlock(f_obj) - f_obj.close() + with file.open(mode) as f_obj: + try: + with lock_with_timeout(f_obj, lock_mode): + yield f_obj + except OSError as err: + raise ChannelError(f"Cannot open {file}") from err def get_table(self, exchange): try: @@ -209,15 +241,12 @@ class Channel(virtual.Channel): filename = os.path.join(self.data_folder_out, filename) try: - f = open(filename, 'wb') - lock(f, LOCK_EX) - f.write(str_to_bytes(dumps(payload))) - except OSError: + with open(filename, 'wb') as f: + with lock_with_timeout(f, LOCK_EX): + f.write(str_to_bytes(dumps(payload))) + except OSError as err: raise ChannelError( - f'Cannot add file {filename!r} to directory') - finally: - unlock(f) - f.close() + f'Cannot add file {filename!r} to directory') from err def _get(self, queue): """Get next message from `queue`.""" @@ -245,14 +274,14 @@ class Channel(virtual.Channel): filename = os.path.join(processed_folder, filename) try: - f = open(filename, 'rb') - payload = f.read() - f.close() - if not self.store_processed: - os.remove(filename) - except OSError: + with open(filename, 'rb') as f: + with lock_with_timeout(f, LOCK_SH): + payload = f.read() + if not self.store_processed: + os.remove(filename) + except OSError as err: raise ChannelError( - f'Cannot read file {filename!r} from queue.') + f'Cannot read file {filename!r} from queue.') from err return loads(bytes_to_str(payload)) @@ -272,7 +301,9 @@ class Channel(virtual.Channel): continue filename = os.path.join(self.data_folder_in, filename) - os.remove(filename) + with open(filename, 'wb') as f: + with lock_with_timeout(f, LOCK_EX): + os.remove(filename) count += 1 diff --git a/t/unit/transport/test_filesystem.py b/t/unit/transport/test_filesystem.py index b22e3b8d..c452e7ca 100644 --- a/t/unit/transport/test_filesystem.py +++ b/t/unit/transport/test_filesystem.py @@ -1,17 +1,19 @@ from __future__ import annotations import tempfile +from fcntl import LOCK_EX, LOCK_NB, LOCK_SH from queue import Empty +from unittest.mock import call, patch import pytest import t.skip from kombu import Connection, Consumer, Exchange, Producer, Queue +from kombu.transport.filesystem import lock, unlock @t.skip.if_win32 class test_FilesystemTransport: - def setup(self): self.channels = set() try: @@ -145,6 +147,7 @@ class test_FilesystemTransport: @t.skip.if_win32 class test_FilesystemFanout: + def setup(self): try: data_folder_in = tempfile.mkdtemp() @@ -234,3 +237,110 @@ class test_FilesystemFanout: assert self.q2(self.consume_channel).get() self.q2(self.consume_channel).purge() assert self.q2(self.consume_channel).get() is None + + +@t.skip.if_win32 +class test_FilesystemLock: + def test_lock(self): + file_obj1 = tempfile.NamedTemporaryFile() + with open(file_obj1.name) as file_obj2: + lock(file_obj1, LOCK_SH) + with pytest.raises(BlockingIOError): + lock(file_obj2, LOCK_EX | LOCK_NB) + + lock(file_obj2, LOCK_SH) + unlock(file_obj2) + + unlock(file_obj1) + lock(file_obj2, LOCK_EX) + unlock(file_obj2) + file_obj1.close() + + +@t.skip.if_win32 +class test_FilesystemLockDuringProcess: + def setup(self): + try: + data_folder_in = tempfile.mkdtemp() + data_folder_out = tempfile.mkdtemp() + control_folder = tempfile.mkdtemp() + except Exception: + pytest.skip("filesystem transport: cannot create tempfiles") + + self.consumer_connection = Connection( + transport="filesystem", + transport_options={ + "data_folder_in": data_folder_in, + "data_folder_out": data_folder_out, + "control_folder": control_folder, + }, + ) + self.consume_channel = self.consumer_connection.channel() + self.produce_connection = Connection( + transport="filesystem", + transport_options={ + "data_folder_in": data_folder_out, + "data_folder_out": data_folder_in, + "control_folder": control_folder, + }, + ) + self.producer_channel = self.produce_connection.channel() + self.exchange = Exchange("filesystem_exchange_lock", type="fanout") + self.q = Queue("queue1", exchange=self.exchange) + + def teardown(self): + # make sure we don't attempt to restore messages at shutdown. + for channel in [self.producer_channel, self.consumer_connection]: + try: + channel._qos._dirty.clear() + except AttributeError: + pass + try: + channel._qos._delivered.clear() + except AttributeError: + pass + + def test_lock_during_process(self): + producer = Producer(self.producer_channel, self.exchange) + + with patch("kombu.transport.filesystem.lock") as lock_m, patch( + "kombu.transport.filesystem.unlock" + ) as unlock_m: + consumer = Consumer(self.consume_channel, self.q) + assert unlock_m.call_count == 1 + lock_m.assert_called_once_with(unlock_m.call_args[0][0], LOCK_EX) + + self.q(self.consume_channel).declare() + with patch("kombu.transport.filesystem.lock") as lock_m, patch( + "kombu.transport.filesystem.unlock" + ) as unlock_m: + producer.publish({"foo": 1}) + assert unlock_m.call_count == 2 + assert lock_m.call_count == 2 + exchange_file_obj = unlock_m.call_args_list[0][0][0] + msg_file_obj = unlock_m.call_args_list[1][0][0] + assert lock_m.call_args_list == [call(exchange_file_obj, LOCK_SH), + call(msg_file_obj, LOCK_EX)] + + def callback(_, message): + message.ack() + + consumer.register_callback(callback) + consumer.consume() + + with patch("kombu.transport.filesystem.lock") as lock_m, patch( + "kombu.transport.filesystem.unlock" + ) as unlock_m: + self.consume_channel.drain_events() + assert lock_m.call_count == 1 + assert unlock_m.call_count == 1 + lock_m.assert_called_once_with(unlock_m.call_args[0][0], LOCK_SH) + + producer.publish({"foo": 0}) + with patch("kombu.transport.filesystem.lock") as lock_m, patch( + "kombu.transport.filesystem.unlock" + ) as unlock_m: + self.q(self.consume_channel).purge() + assert lock_m.call_count == 1 + assert unlock_m.call_count == 1 + lock_m.assert_called_once_with(unlock_m.call_args[0][0], LOCK_EX) |