diff options
author | Alex Grönholm <alex.gronholm@nextday.fi> | 2021-09-12 01:47:31 +0300 |
---|---|---|
committer | Alex Grönholm <alex.gronholm@nextday.fi> | 2021-09-12 01:47:31 +0300 |
commit | a58fca290e0831d377d496a69101e5e3dc4c604e (patch) | |
tree | 8beb7504e7113ff1f01fb610513bb72745fa91ba /src | |
parent | 59ea7376985ef2c8b8b6b6d6df6b1b3be958480c (diff) | |
download | apscheduler-a58fca290e0831d377d496a69101e5e3dc4c604e.tar.gz |
Refactored event brokers to use exit stacks
Diffstat (limited to 'src')
-rw-r--r-- | src/apscheduler/abc.py | 8 | ||||
-rw-r--r-- | src/apscheduler/datastores/async_adapter.py | 17 | ||||
-rw-r--r-- | src/apscheduler/eventbrokers/async_adapter.py | 21 | ||||
-rw-r--r-- | src/apscheduler/eventbrokers/async_local.py | 18 | ||||
-rw-r--r-- | src/apscheduler/eventbrokers/asyncpg.py | 11 | ||||
-rw-r--r-- | src/apscheduler/eventbrokers/base.py | 10 | ||||
-rw-r--r-- | src/apscheduler/eventbrokers/local.py | 14 | ||||
-rw-r--r-- | src/apscheduler/eventbrokers/mqtt.py | 14 | ||||
-rw-r--r-- | src/apscheduler/eventbrokers/redis.py | 4 |
9 files changed, 62 insertions, 55 deletions
diff --git a/src/apscheduler/abc.py b/src/apscheduler/abc.py index 9751293..e267836 100644 --- a/src/apscheduler/abc.py +++ b/src/apscheduler/abc.py @@ -105,6 +105,10 @@ class EventBroker(EventSource): def publish(self, event: events.Event) -> None: """Publish an event.""" + @abstractmethod + def publish_local(self, event: events.Event) -> None: + """Publish an event, but only to local subscribers.""" + class AsyncEventBroker(EventSource): """ @@ -123,6 +127,10 @@ class AsyncEventBroker(EventSource): async def publish(self, event: events.Event) -> None: """Publish an event.""" + @abstractmethod + async def publish_local(self, event: events.Event) -> None: + """Publish an event, but only to local subscribers.""" + class DataStore: def __enter__(self): diff --git a/src/apscheduler/datastores/async_adapter.py b/src/apscheduler/datastores/async_adapter.py index a5cf6a2..736685c 100644 --- a/src/apscheduler/datastores/async_adapter.py +++ b/src/apscheduler/datastores/async_adapter.py @@ -1,6 +1,8 @@ from __future__ import annotations +from contextlib import AsyncExitStack from datetime import datetime +from functools import partial from typing import Iterable, Optional from uuid import UUID @@ -21,23 +23,28 @@ class AsyncDataStoreAdapter(AsyncDataStore): original: DataStore _portal: BlockingPortal = attr.field(init=False) _events: AsyncEventBroker = attr.field(init=False) + _exit_stack: AsyncExitStack = attr.field(init=False) @property def events(self) -> EventSource: return self._events async def __aenter__(self) -> AsyncDataStoreAdapter: + self._exit_stack = AsyncExitStack() + self._portal = BlockingPortal() - await self._portal.__aenter__() + await self._exit_stack.enter_async_context(self._portal) + self._events = AsyncEventBrokerAdapter(self.original.events, self._portal) - await self._events.__aenter__() + await self._exit_stack.enter_async_context(self._events) + await to_thread.run_sync(self.original.__enter__) + self._exit_stack.push_async_exit(partial(to_thread.run_sync, self.original.__exit__)) + return self async def __aexit__(self, exc_type, exc_val, exc_tb): - await to_thread.run_sync(self.original.__exit__, exc_type, exc_val, exc_tb) - await self._events.__aexit__(exc_type, exc_val, exc_tb) - await self._portal.__aexit__(exc_type, exc_val, exc_tb) + await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) async def add_task(self, task: Task) -> None: await to_thread.run_sync(self.original.add_task, task) diff --git a/src/apscheduler/eventbrokers/async_adapter.py b/src/apscheduler/eventbrokers/async_adapter.py index cb18386..1fb9177 100644 --- a/src/apscheduler/eventbrokers/async_adapter.py +++ b/src/apscheduler/eventbrokers/async_adapter.py @@ -1,8 +1,6 @@ from __future__ import annotations -from contextlib import AsyncExitStack from functools import partial -from typing import Any, Callable, Iterable, Optional import attr from anyio import to_thread @@ -10,7 +8,7 @@ from anyio.from_thread import BlockingPortal from apscheduler.abc import EventBroker from apscheduler.eventbrokers.async_local import LocalAsyncEventBroker -from apscheduler.events import Event, SubscriptionToken +from apscheduler.events import Event from apscheduler.util import reentrant @@ -19,26 +17,19 @@ from apscheduler.util import reentrant class AsyncEventBrokerAdapter(LocalAsyncEventBroker): original: EventBroker portal: BlockingPortal - _exit_stack: AsyncExitStack = attr.field(init=False) async def __aenter__(self): - self._exit_stack = AsyncExitStack() + await super().__aenter__() + if not self.portal: self.portal = BlockingPortal() self._exit_stack.enter_async_context(self.portal) await to_thread.run_sync(self.original.__enter__) - return await super().__aenter__() + self._exit_stack.push_async_exit(partial(to_thread.run_sync, self.original.__exit__)) - async def __aexit__(self, exc_type, exc_val, exc_tb): - await to_thread.run_sync(self.original.__exit__, exc_type, exc_val, exc_tb) - await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) - return await super().__aexit__(exc_type, exc_val, exc_tb) + token = self.original.subscribe(partial(self.portal.call, self.publish_local)) + self._exit_stack.callback(self.original.unsubscribe, token) async def publish(self, event: Event) -> None: await to_thread.run_sync(self.original.publish, event) - - def subscribe(self, callback: Callable[[Event], Any], - event_types: Optional[Iterable[type[Event]]] = None) -> SubscriptionToken: - token = self.original.subscribe(partial(self.portal.call, callback), event_types) - return token diff --git a/src/apscheduler/eventbrokers/async_local.py b/src/apscheduler/eventbrokers/async_local.py index e73b4b9..590f0cb 100644 --- a/src/apscheduler/eventbrokers/async_local.py +++ b/src/apscheduler/eventbrokers/async_local.py @@ -1,7 +1,7 @@ from __future__ import annotations -from inspect import isawaitable -from logging import Logger, getLogger +from asyncio import iscoroutine +from contextlib import AsyncExitStack from typing import Any, Callable import attr @@ -17,23 +17,29 @@ from .base import BaseEventBroker @reentrant @attr.define(eq=False) class LocalAsyncEventBroker(AsyncEventBroker, BaseEventBroker): - _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__)) _task_group: TaskGroup = attr.field(init=False) + _exit_stack: AsyncExitStack = attr.field(init=False) async def __aenter__(self) -> LocalAsyncEventBroker: + self._exit_stack = AsyncExitStack() + self._task_group = create_task_group() - await self._task_group.__aenter__() + await self._exit_stack.enter_async_context(self._task_group) + return self async def __aexit__(self, exc_type, exc_val, exc_tb): - await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) del self._task_group async def publish(self, event: Event) -> None: + await self.publish_local(event) + + async def publish_local(self, event: Event) -> None: async def deliver_event(func: Callable[[Event], Any]) -> None: try: retval = func(event) - if isawaitable(retval): + if iscoroutine(retval): await retval except BaseException: self._logger.exception('Error delivering %s event', event.__class__.__name__) diff --git a/src/apscheduler/eventbrokers/asyncpg.py b/src/apscheduler/eventbrokers/asyncpg.py index 447bee8..93bfd6a 100644 --- a/src/apscheduler/eventbrokers/asyncpg.py +++ b/src/apscheduler/eventbrokers/asyncpg.py @@ -1,7 +1,6 @@ from __future__ import annotations from contextlib import asynccontextmanager -from logging import Logger, getLogger from typing import TYPE_CHECKING, AsyncContextManager, AsyncGenerator, Callable import attr @@ -25,7 +24,6 @@ class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin): connection_factory: Callable[[], AsyncContextManager[Connection]] channel: str = attr.field(kw_only=True, default='apscheduler') max_idle_time: float = attr.field(kw_only=True, default=30) - _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__)) @classmethod def from_asyncpg_pool(cls, pool: Pool) -> AsyncpgEventBroker: @@ -50,19 +48,14 @@ class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin): async def __aenter__(self) -> LocalAsyncEventBroker: await super().__aenter__() await self._task_group.start(self._listen_notifications) + self._exit_stack.callback(self._task_group.cancel_scope.cancel) return self - async def __aexit__(self, exc_type, exc_val, exc_tb): - self._task_group.cancel_scope.cancel() - await super().__aexit__(exc_type, exc_val, exc_tb) - async def _listen_notifications(self, *, task_status=TASK_STATUS_IGNORED) -> None: - local_publish = super(AsyncpgEventBroker, self).publish - def callback(connection, pid, channel: str, payload: str) -> None: event = self.reconstitute_event_str(payload) if event is not None: - self._task_group.start_soon(local_publish, event) + self._task_group.start_soon(self.publish_local, event) task_started_sent = False while True: diff --git a/src/apscheduler/eventbrokers/base.py b/src/apscheduler/eventbrokers/base.py index da89dc5..ce12055 100644 --- a/src/apscheduler/eventbrokers/base.py +++ b/src/apscheduler/eventbrokers/base.py @@ -1,7 +1,7 @@ from __future__ import annotations from base64 import b64decode, b64encode -from logging import Logger +from logging import Logger, getLogger from typing import Any, Callable, Iterable, Optional import attr @@ -14,8 +14,12 @@ from ..exceptions import DeserializationError @attr.define(eq=False) class BaseEventBroker(EventBroker): + _logger: Logger = attr.field(init=False) _subscriptions: dict[SubscriptionToken, Subscription] = attr.field(init=False, factory=dict) + def __attrs_post_init__(self) -> None: + self._logger = getLogger(self.__class__.__module__) + def subscribe(self, callback: Callable[[Event], Any], event_types: Optional[Iterable[type[Event]]] = None) -> SubscriptionToken: types = set(event_types) if event_types else None @@ -32,10 +36,10 @@ class BaseEventBroker(EventBroker): class DistributedEventBrokerMixin: - _logger: Logger serializer: Serializer + _logger: Logger - def generate_notification(self, event: Event, use_base64: bool = False) -> bytes: + def generate_notification(self, event: Event) -> bytes: serialized = self.serializer.serialize(attr.asdict(event)) return event.__class__.__name__.encode('ascii') + b' ' + serialized diff --git a/src/apscheduler/eventbrokers/local.py b/src/apscheduler/eventbrokers/local.py index ab75575..a657f4e 100644 --- a/src/apscheduler/eventbrokers/local.py +++ b/src/apscheduler/eventbrokers/local.py @@ -2,7 +2,7 @@ from __future__ import annotations from asyncio import iscoroutinefunction from concurrent.futures import ThreadPoolExecutor -from logging import Logger, getLogger +from contextlib import ExitStack from typing import Any, Callable, Iterable, Optional import attr @@ -16,14 +16,15 @@ from .base import BaseEventBroker @attr.define(eq=False) class LocalEventBroker(BaseEventBroker): _executor: ThreadPoolExecutor = attr.field(init=False) - _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__)) + _exit_stack: ExitStack = attr.field(init=False) - def __enter__(self) -> LocalEventBroker: - self._executor = ThreadPoolExecutor(1) + def __enter__(self): + self._exit_stack = ExitStack() + self._executor = self._exit_stack.enter_context(ThreadPoolExecutor(1)) return self def __exit__(self, exc_type, exc_val, exc_tb): - self._executor.shutdown(wait=exc_type is None) + self._exit_stack.__exit__(exc_type, exc_val, exc_tb) del self._executor def subscribe(self, callback: Callable[[Event], Any], @@ -35,6 +36,9 @@ class LocalEventBroker(BaseEventBroker): return super().subscribe(callback, event_types) def publish(self, event: Event) -> None: + self.publish_local(event) + + def publish_local(self, event: Event) -> None: event_type = type(event) for subscription in list(self._subscriptions.values()): if subscription.event_types is None or event_type in subscription.event_types: diff --git a/src/apscheduler/eventbrokers/mqtt.py b/src/apscheduler/eventbrokers/mqtt.py index cfedb88..dbdffe4 100644 --- a/src/apscheduler/eventbrokers/mqtt.py +++ b/src/apscheduler/eventbrokers/mqtt.py @@ -1,7 +1,6 @@ from __future__ import annotations from concurrent.futures import Future -from logging import Logger, getLogger from typing import Any, Optional import attr @@ -27,10 +26,10 @@ class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin): topic: str = attr.field(kw_only=True, default='apscheduler') subscribe_qos: int = attr.field(kw_only=True, default=0) publish_qos: int = attr.field(kw_only=True, default=0) - _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__)) _ready_future: Future[None] = attr.field(init=False) def __enter__(self): + super().__enter__() self._ready_future = Future() self.client.enable_logger(self._logger) self.client.on_connect = self._on_connect @@ -39,12 +38,9 @@ class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin): self.client.connect(self.host, self.port) self.client.loop_start() self._ready_future.result(10) - return super().__enter__() - - def __exit__(self, exc_type, exc_val, exc_tb): - self.client.disconnect() - self.client.loop_stop(force=exc_type is not None) - return super().__exit__(exc_type, exc_val, exc_tb) + self._exit_stack.push(lambda exc_type, *_: self.client.loop_stop(force=bool(exc_type))) + self._exit_stack.callback(self.client.disconnect) + return self def _on_connect(self, client: Client, userdata: Any, flags: dict[str, Any], rc: ReasonCodes | int, properties: Optional[Properties] = None) -> None: @@ -60,7 +56,7 @@ class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin): def _on_message(self, client: Client, userdata: Any, msg: MQTTMessage) -> None: event = self.reconstitute_event(msg.payload) if event is not None: - super().publish(event) + self.publish_local(event) def publish(self, event: Event) -> None: notification = self.generate_notification(event) diff --git a/src/apscheduler/eventbrokers/redis.py b/src/apscheduler/eventbrokers/redis.py index 92c7f83..68b86e0 100644 --- a/src/apscheduler/eventbrokers/redis.py +++ b/src/apscheduler/eventbrokers/redis.py @@ -1,7 +1,6 @@ from __future__ import annotations from concurrent.futures import Future -from logging import Logger, getLogger from threading import Thread from typing import Optional @@ -23,7 +22,6 @@ class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin): serializer: Serializer = attr.field(factory=JSONSerializer) channel: str = attr.field(kw_only=True, default='apscheduler') message_poll_interval: float = attr.field(kw_only=True, default=0.05) - _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__)) _stopped: bool = attr.field(init=False, default=True) _ready_future: Future[None] = attr.field(init=False) @@ -69,7 +67,7 @@ class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin): if msg and isinstance(msg['data'], bytes): event = self.reconstitute_event(msg['data']) if event is not None: - super().publish(event) + self.publish_local(event) except BaseException: self._logger.exception('Subscriber crashed') raise |