summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Grönholm <alex.gronholm@nextday.fi>2021-09-12 01:47:31 +0300
committerAlex Grönholm <alex.gronholm@nextday.fi>2021-09-12 01:47:31 +0300
commita58fca290e0831d377d496a69101e5e3dc4c604e (patch)
tree8beb7504e7113ff1f01fb610513bb72745fa91ba
parent59ea7376985ef2c8b8b6b6d6df6b1b3be958480c (diff)
downloadapscheduler-a58fca290e0831d377d496a69101e5e3dc4c604e.tar.gz
Refactored event brokers to use exit stacks
-rw-r--r--src/apscheduler/abc.py8
-rw-r--r--src/apscheduler/datastores/async_adapter.py17
-rw-r--r--src/apscheduler/eventbrokers/async_adapter.py21
-rw-r--r--src/apscheduler/eventbrokers/async_local.py18
-rw-r--r--src/apscheduler/eventbrokers/asyncpg.py11
-rw-r--r--src/apscheduler/eventbrokers/base.py10
-rw-r--r--src/apscheduler/eventbrokers/local.py14
-rw-r--r--src/apscheduler/eventbrokers/mqtt.py14
-rw-r--r--src/apscheduler/eventbrokers/redis.py4
9 files changed, 62 insertions, 55 deletions
diff --git a/src/apscheduler/abc.py b/src/apscheduler/abc.py
index 9751293..e267836 100644
--- a/src/apscheduler/abc.py
+++ b/src/apscheduler/abc.py
@@ -105,6 +105,10 @@ class EventBroker(EventSource):
def publish(self, event: events.Event) -> None:
"""Publish an event."""
+ @abstractmethod
+ def publish_local(self, event: events.Event) -> None:
+ """Publish an event, but only to local subscribers."""
+
class AsyncEventBroker(EventSource):
"""
@@ -123,6 +127,10 @@ class AsyncEventBroker(EventSource):
async def publish(self, event: events.Event) -> None:
"""Publish an event."""
+ @abstractmethod
+ async def publish_local(self, event: events.Event) -> None:
+ """Publish an event, but only to local subscribers."""
+
class DataStore:
def __enter__(self):
diff --git a/src/apscheduler/datastores/async_adapter.py b/src/apscheduler/datastores/async_adapter.py
index a5cf6a2..736685c 100644
--- a/src/apscheduler/datastores/async_adapter.py
+++ b/src/apscheduler/datastores/async_adapter.py
@@ -1,6 +1,8 @@
from __future__ import annotations
+from contextlib import AsyncExitStack
from datetime import datetime
+from functools import partial
from typing import Iterable, Optional
from uuid import UUID
@@ -21,23 +23,28 @@ class AsyncDataStoreAdapter(AsyncDataStore):
original: DataStore
_portal: BlockingPortal = attr.field(init=False)
_events: AsyncEventBroker = attr.field(init=False)
+ _exit_stack: AsyncExitStack = attr.field(init=False)
@property
def events(self) -> EventSource:
return self._events
async def __aenter__(self) -> AsyncDataStoreAdapter:
+ self._exit_stack = AsyncExitStack()
+
self._portal = BlockingPortal()
- await self._portal.__aenter__()
+ await self._exit_stack.enter_async_context(self._portal)
+
self._events = AsyncEventBrokerAdapter(self.original.events, self._portal)
- await self._events.__aenter__()
+ await self._exit_stack.enter_async_context(self._events)
+
await to_thread.run_sync(self.original.__enter__)
+ self._exit_stack.push_async_exit(partial(to_thread.run_sync, self.original.__exit__))
+
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
- await to_thread.run_sync(self.original.__exit__, exc_type, exc_val, exc_tb)
- await self._events.__aexit__(exc_type, exc_val, exc_tb)
- await self._portal.__aexit__(exc_type, exc_val, exc_tb)
+ await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
async def add_task(self, task: Task) -> None:
await to_thread.run_sync(self.original.add_task, task)
diff --git a/src/apscheduler/eventbrokers/async_adapter.py b/src/apscheduler/eventbrokers/async_adapter.py
index cb18386..1fb9177 100644
--- a/src/apscheduler/eventbrokers/async_adapter.py
+++ b/src/apscheduler/eventbrokers/async_adapter.py
@@ -1,8 +1,6 @@
from __future__ import annotations
-from contextlib import AsyncExitStack
from functools import partial
-from typing import Any, Callable, Iterable, Optional
import attr
from anyio import to_thread
@@ -10,7 +8,7 @@ from anyio.from_thread import BlockingPortal
from apscheduler.abc import EventBroker
from apscheduler.eventbrokers.async_local import LocalAsyncEventBroker
-from apscheduler.events import Event, SubscriptionToken
+from apscheduler.events import Event
from apscheduler.util import reentrant
@@ -19,26 +17,19 @@ from apscheduler.util import reentrant
class AsyncEventBrokerAdapter(LocalAsyncEventBroker):
original: EventBroker
portal: BlockingPortal
- _exit_stack: AsyncExitStack = attr.field(init=False)
async def __aenter__(self):
- self._exit_stack = AsyncExitStack()
+ await super().__aenter__()
+
if not self.portal:
self.portal = BlockingPortal()
self._exit_stack.enter_async_context(self.portal)
await to_thread.run_sync(self.original.__enter__)
- return await super().__aenter__()
+ self._exit_stack.push_async_exit(partial(to_thread.run_sync, self.original.__exit__))
- async def __aexit__(self, exc_type, exc_val, exc_tb):
- await to_thread.run_sync(self.original.__exit__, exc_type, exc_val, exc_tb)
- await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
- return await super().__aexit__(exc_type, exc_val, exc_tb)
+ token = self.original.subscribe(partial(self.portal.call, self.publish_local))
+ self._exit_stack.callback(self.original.unsubscribe, token)
async def publish(self, event: Event) -> None:
await to_thread.run_sync(self.original.publish, event)
-
- def subscribe(self, callback: Callable[[Event], Any],
- event_types: Optional[Iterable[type[Event]]] = None) -> SubscriptionToken:
- token = self.original.subscribe(partial(self.portal.call, callback), event_types)
- return token
diff --git a/src/apscheduler/eventbrokers/async_local.py b/src/apscheduler/eventbrokers/async_local.py
index e73b4b9..590f0cb 100644
--- a/src/apscheduler/eventbrokers/async_local.py
+++ b/src/apscheduler/eventbrokers/async_local.py
@@ -1,7 +1,7 @@
from __future__ import annotations
-from inspect import isawaitable
-from logging import Logger, getLogger
+from asyncio import iscoroutine
+from contextlib import AsyncExitStack
from typing import Any, Callable
import attr
@@ -17,23 +17,29 @@ from .base import BaseEventBroker
@reentrant
@attr.define(eq=False)
class LocalAsyncEventBroker(AsyncEventBroker, BaseEventBroker):
- _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__))
_task_group: TaskGroup = attr.field(init=False)
+ _exit_stack: AsyncExitStack = attr.field(init=False)
async def __aenter__(self) -> LocalAsyncEventBroker:
+ self._exit_stack = AsyncExitStack()
+
self._task_group = create_task_group()
- await self._task_group.__aenter__()
+ await self._exit_stack.enter_async_context(self._task_group)
+
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
- await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
+ await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
del self._task_group
async def publish(self, event: Event) -> None:
+ await self.publish_local(event)
+
+ async def publish_local(self, event: Event) -> None:
async def deliver_event(func: Callable[[Event], Any]) -> None:
try:
retval = func(event)
- if isawaitable(retval):
+ if iscoroutine(retval):
await retval
except BaseException:
self._logger.exception('Error delivering %s event', event.__class__.__name__)
diff --git a/src/apscheduler/eventbrokers/asyncpg.py b/src/apscheduler/eventbrokers/asyncpg.py
index 447bee8..93bfd6a 100644
--- a/src/apscheduler/eventbrokers/asyncpg.py
+++ b/src/apscheduler/eventbrokers/asyncpg.py
@@ -1,7 +1,6 @@
from __future__ import annotations
from contextlib import asynccontextmanager
-from logging import Logger, getLogger
from typing import TYPE_CHECKING, AsyncContextManager, AsyncGenerator, Callable
import attr
@@ -25,7 +24,6 @@ class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin):
connection_factory: Callable[[], AsyncContextManager[Connection]]
channel: str = attr.field(kw_only=True, default='apscheduler')
max_idle_time: float = attr.field(kw_only=True, default=30)
- _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__))
@classmethod
def from_asyncpg_pool(cls, pool: Pool) -> AsyncpgEventBroker:
@@ -50,19 +48,14 @@ class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin):
async def __aenter__(self) -> LocalAsyncEventBroker:
await super().__aenter__()
await self._task_group.start(self._listen_notifications)
+ self._exit_stack.callback(self._task_group.cancel_scope.cancel)
return self
- async def __aexit__(self, exc_type, exc_val, exc_tb):
- self._task_group.cancel_scope.cancel()
- await super().__aexit__(exc_type, exc_val, exc_tb)
-
async def _listen_notifications(self, *, task_status=TASK_STATUS_IGNORED) -> None:
- local_publish = super(AsyncpgEventBroker, self).publish
-
def callback(connection, pid, channel: str, payload: str) -> None:
event = self.reconstitute_event_str(payload)
if event is not None:
- self._task_group.start_soon(local_publish, event)
+ self._task_group.start_soon(self.publish_local, event)
task_started_sent = False
while True:
diff --git a/src/apscheduler/eventbrokers/base.py b/src/apscheduler/eventbrokers/base.py
index da89dc5..ce12055 100644
--- a/src/apscheduler/eventbrokers/base.py
+++ b/src/apscheduler/eventbrokers/base.py
@@ -1,7 +1,7 @@
from __future__ import annotations
from base64 import b64decode, b64encode
-from logging import Logger
+from logging import Logger, getLogger
from typing import Any, Callable, Iterable, Optional
import attr
@@ -14,8 +14,12 @@ from ..exceptions import DeserializationError
@attr.define(eq=False)
class BaseEventBroker(EventBroker):
+ _logger: Logger = attr.field(init=False)
_subscriptions: dict[SubscriptionToken, Subscription] = attr.field(init=False, factory=dict)
+ def __attrs_post_init__(self) -> None:
+ self._logger = getLogger(self.__class__.__module__)
+
def subscribe(self, callback: Callable[[Event], Any],
event_types: Optional[Iterable[type[Event]]] = None) -> SubscriptionToken:
types = set(event_types) if event_types else None
@@ -32,10 +36,10 @@ class BaseEventBroker(EventBroker):
class DistributedEventBrokerMixin:
- _logger: Logger
serializer: Serializer
+ _logger: Logger
- def generate_notification(self, event: Event, use_base64: bool = False) -> bytes:
+ def generate_notification(self, event: Event) -> bytes:
serialized = self.serializer.serialize(attr.asdict(event))
return event.__class__.__name__.encode('ascii') + b' ' + serialized
diff --git a/src/apscheduler/eventbrokers/local.py b/src/apscheduler/eventbrokers/local.py
index ab75575..a657f4e 100644
--- a/src/apscheduler/eventbrokers/local.py
+++ b/src/apscheduler/eventbrokers/local.py
@@ -2,7 +2,7 @@ from __future__ import annotations
from asyncio import iscoroutinefunction
from concurrent.futures import ThreadPoolExecutor
-from logging import Logger, getLogger
+from contextlib import ExitStack
from typing import Any, Callable, Iterable, Optional
import attr
@@ -16,14 +16,15 @@ from .base import BaseEventBroker
@attr.define(eq=False)
class LocalEventBroker(BaseEventBroker):
_executor: ThreadPoolExecutor = attr.field(init=False)
- _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__))
+ _exit_stack: ExitStack = attr.field(init=False)
- def __enter__(self) -> LocalEventBroker:
- self._executor = ThreadPoolExecutor(1)
+ def __enter__(self):
+ self._exit_stack = ExitStack()
+ self._executor = self._exit_stack.enter_context(ThreadPoolExecutor(1))
return self
def __exit__(self, exc_type, exc_val, exc_tb):
- self._executor.shutdown(wait=exc_type is None)
+ self._exit_stack.__exit__(exc_type, exc_val, exc_tb)
del self._executor
def subscribe(self, callback: Callable[[Event], Any],
@@ -35,6 +36,9 @@ class LocalEventBroker(BaseEventBroker):
return super().subscribe(callback, event_types)
def publish(self, event: Event) -> None:
+ self.publish_local(event)
+
+ def publish_local(self, event: Event) -> None:
event_type = type(event)
for subscription in list(self._subscriptions.values()):
if subscription.event_types is None or event_type in subscription.event_types:
diff --git a/src/apscheduler/eventbrokers/mqtt.py b/src/apscheduler/eventbrokers/mqtt.py
index cfedb88..dbdffe4 100644
--- a/src/apscheduler/eventbrokers/mqtt.py
+++ b/src/apscheduler/eventbrokers/mqtt.py
@@ -1,7 +1,6 @@
from __future__ import annotations
from concurrent.futures import Future
-from logging import Logger, getLogger
from typing import Any, Optional
import attr
@@ -27,10 +26,10 @@ class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
topic: str = attr.field(kw_only=True, default='apscheduler')
subscribe_qos: int = attr.field(kw_only=True, default=0)
publish_qos: int = attr.field(kw_only=True, default=0)
- _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__))
_ready_future: Future[None] = attr.field(init=False)
def __enter__(self):
+ super().__enter__()
self._ready_future = Future()
self.client.enable_logger(self._logger)
self.client.on_connect = self._on_connect
@@ -39,12 +38,9 @@ class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
self.client.connect(self.host, self.port)
self.client.loop_start()
self._ready_future.result(10)
- return super().__enter__()
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- self.client.disconnect()
- self.client.loop_stop(force=exc_type is not None)
- return super().__exit__(exc_type, exc_val, exc_tb)
+ self._exit_stack.push(lambda exc_type, *_: self.client.loop_stop(force=bool(exc_type)))
+ self._exit_stack.callback(self.client.disconnect)
+ return self
def _on_connect(self, client: Client, userdata: Any, flags: dict[str, Any],
rc: ReasonCodes | int, properties: Optional[Properties] = None) -> None:
@@ -60,7 +56,7 @@ class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
def _on_message(self, client: Client, userdata: Any, msg: MQTTMessage) -> None:
event = self.reconstitute_event(msg.payload)
if event is not None:
- super().publish(event)
+ self.publish_local(event)
def publish(self, event: Event) -> None:
notification = self.generate_notification(event)
diff --git a/src/apscheduler/eventbrokers/redis.py b/src/apscheduler/eventbrokers/redis.py
index 92c7f83..68b86e0 100644
--- a/src/apscheduler/eventbrokers/redis.py
+++ b/src/apscheduler/eventbrokers/redis.py
@@ -1,7 +1,6 @@
from __future__ import annotations
from concurrent.futures import Future
-from logging import Logger, getLogger
from threading import Thread
from typing import Optional
@@ -23,7 +22,6 @@ class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
serializer: Serializer = attr.field(factory=JSONSerializer)
channel: str = attr.field(kw_only=True, default='apscheduler')
message_poll_interval: float = attr.field(kw_only=True, default=0.05)
- _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__))
_stopped: bool = attr.field(init=False, default=True)
_ready_future: Future[None] = attr.field(init=False)
@@ -69,7 +67,7 @@ class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
if msg and isinstance(msg['data'], bytes):
event = self.reconstitute_event(msg['data'])
if event is not None:
- super().publish(event)
+ self.publish_local(event)
except BaseException:
self._logger.exception('Subscriber crashed')
raise