author     Alex Grönholm <alex.gronholm@nextday.fi>  2021-09-11 21:14:14 +0300
committer  Alex Grönholm <alex.gronholm@nextday.fi>  2021-09-11 21:14:14 +0300
commit     56afe91d5dc338db3440b2e9ecdea3e522dba30f (patch)
tree       311380b0d953f09919d7e8c4c0a340507e5d0dc5
parent     7248a78e7e787b728b083aaa8199eeba3a3f3023 (diff)
download   apscheduler-56afe91d5dc338db3440b2e9ecdea3e522dba30f.tar.gz
Implemented a pluggable event broker system
29 files changed, 983 insertions, 456 deletions
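The heart of this change: data stores no longer implement subscribe()/unsubscribe() themselves. Each store now exposes an `events` property backed by a pluggable broker, and only brokers know how to publish. A minimal sketch of the new local broker (callback and IDs invented for illustration):

```python
from apscheduler.eventbrokers.local import LocalEventBroker
from apscheduler.events import ScheduleAdded

def on_schedule_added(event: ScheduleAdded) -> None:
    # Delivered on the broker's single worker thread
    print('schedule added:', event.schedule_id)

with LocalEventBroker() as broker:
    token = broker.subscribe(on_schedule_added, {ScheduleAdded})
    broker.publish(ScheduleAdded(schedule_id='demo', next_fire_time=None))
    broker.unsubscribe(token)
```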
diff --git a/docker-compose.yml b/docker-compose.yml
index 89fd92b..066a411 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -20,3 +20,15 @@ services:
     image: mongo
     ports:
       - 127.0.0.1:27017:27017
+
+  mosquitto:
+    image: eclipse-mosquitto:2
+    volumes:
+      - ./mosquitto.conf:/mosquitto/config/mosquitto.conf:ro
+    ports:
+      - 127.0.0.1:1883:1883
+
+  redis:
+    image: redis:6
+    ports:
+      - 127.0.0.1:6379:6379
diff --git a/mosquitto.conf b/mosquitto.conf
new file mode 100644
index 0000000..7c6f5be
--- /dev/null
+++ b/mosquitto.conf
@@ -0,0 +1,3 @@
+# Required configuration for the mosquitto server in docker-compose.yml
+listener 1883
+allow_anonymous true
diff --git a/setup.cfg b/setup.cfg
@@ -26,30 +26,33 @@ python_requires = >= 3.7
 install_requires =
     anyio ~= 3.0
     attrs >= 20.1
-    backports.zoneinfo; python_version < '3.9'
-    tzdata; platform_system == "Windows"
     tzlocal >= 3.0
 
 [options.packages.find]
 where = src
 
 [options.extras_require]
+asyncpg = asyncpg >= 0.20
 cbor = cbor2 >= 5.0
 mongodb = pymongo >= 3.12
-postgresql = asyncpg >= 0.20
+mqtt = paho-mqtt >= 1.5
+redis = redis >= 3.5
 sqlalchemy = sqlalchemy >= 1.4.22
 test =
     asyncpg >= 0.20
     cbor2 >= 5.0
     coverage
     freezegun
+    paho-mqtt >= 1.5
    psycopg2
     pymongo >= 3.12
     pymysql[rsa]
     pytest >= 5.0
     pytest-cov
     pytest-freezegun
+    pytest-lazy-fixture
     pytest-mock
+    redis >= 3.5
     sqlalchemy >= 1.4.22
     trio
 doc =
diff --git a/src/apscheduler/abc.py b/src/apscheduler/abc.py
index 97fec87..9751293 100644
--- a/src/apscheduler/abc.py
+++ b/src/apscheduler/abc.py
@@ -65,6 +65,8 @@ class Serializer(metaclass=ABCMeta):
 
 
 class EventSource(metaclass=ABCMeta):
+    """Interface for objects that can deliver notifications to interested subscribers."""
+
     @abstractmethod
     def subscribe(
         self, callback: Callable[[events.Event], Any],
@@ -86,7 +88,13 @@ class EventSource(metaclass=ABCMeta):
         """
 
 
-class DataStore(EventSource):
+class EventBroker(EventSource):
+    """
+    Interface for objects that can be used to publish notifications to interested subscribers.
+
+    Can be used as a context manager.
+    """
+
     def __enter__(self):
         return self
 
@@ -94,6 +102,41 @@ class DataStore(EventSource):
         pass
 
     @abstractmethod
+    def publish(self, event: events.Event) -> None:
+        """Publish an event."""
+
+
+class AsyncEventBroker(EventSource):
+    """
+    Asynchronous version of :class:`EventBroker`.
+
+    Can be used as an asynchronous context manager.
+    """
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        pass
+
+    @abstractmethod
+    async def publish(self, event: events.Event) -> None:
+        """Publish an event."""
+
+
+class DataStore:
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        pass
+
+    @property
+    @abstractmethod
+    def events(self) -> EventSource:
+        pass
+
+    @abstractmethod
     def add_task(self, task: Task) -> None:
         """
         Add the given task to the store.
@@ -239,13 +282,18 @@ class DataStore(EventSource):
         """
 
 
-class AsyncDataStore(EventSource):
+class AsyncDataStore:
     async def __aenter__(self):
         return self
 
     async def __aexit__(self, exc_type, exc_val, exc_tb):
         pass
 
+    @property
+    @abstractmethod
+    def events(self) -> EventSource:
+        pass
+
     @abstractmethod
     async def add_task(self, task: Task) -> None:
         """
diff --git a/src/apscheduler/converters.py b/src/apscheduler/converters.py
new file mode 100644
index 0000000..7e8e590
--- /dev/null
+++ b/src/apscheduler/converters.py
@@ -0,0 +1,25 @@
+from __future__ import annotations
+
+from datetime import datetime
+from typing import Optional
+from uuid import UUID
+
+
+def as_aware_datetime(value: datetime | str) -> Optional[datetime]:
+    """Convert the value from a string to a timezone aware datetime."""
+    if isinstance(value, str):
+        # fromisoformat() does not handle the "Z" suffix
+        if value.upper().endswith('Z'):
+            value = value[:-1] + '+00:00'
+
+        value = datetime.fromisoformat(value)
+
+    return value
+
+
+def as_uuid(value: UUID | str) -> UUID:
+    """Converts a string-formatted UUID to a UUID instance."""
+    if isinstance(value, str):
+        return UUID(value)
+
+    return value
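A quick check of the converter semantics added above (values invented):

```python
from datetime import datetime, timezone
from uuid import UUID

from apscheduler.converters import as_aware_datetime, as_uuid

# fromisoformat() cannot parse a trailing 'Z', so it is rewritten to '+00:00' first
assert as_aware_datetime('2021-09-11T21:14:14Z') == datetime(2021, 9, 11, 21, 14, 14,
                                                             tzinfo=timezone.utc)
# Strings become UUIDs; anything else passes through untouched
assert as_uuid('00000000-0000-0000-0000-000000000001') == UUID(int=1)
```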
diff --git a/src/apscheduler/datastores/async_adapter.py b/src/apscheduler/datastores/async_adapter.py
index 89f268a..a5cf6a2 100644
--- a/src/apscheduler/datastores/async_adapter.py
+++ b/src/apscheduler/datastores/async_adapter.py
@@ -1,18 +1,16 @@
 from __future__ import annotations
 
 from datetime import datetime
-from functools import partial
-from typing import Any, Callable, Iterable, Optional
+from typing import Iterable, Optional
 from uuid import UUID
 
 import attr
 from anyio import to_thread
 from anyio.from_thread import BlockingPortal
 
-from .. import events
-from ..abc import AsyncDataStore, DataStore
+from ..abc import AsyncDataStore, AsyncEventBroker, DataStore, EventSource
 from ..enums import ConflictPolicy
-from ..events import Event, SubscriptionToken
+from ..eventbrokers.async_adapter import AsyncEventBrokerAdapter
 from ..structures import Job, JobResult, Schedule, Task
 from ..util import reentrant
 
@@ -21,16 +19,24 @@ from ..util import reentrant
 @attr.define(eq=False)
 class AsyncDataStoreAdapter(AsyncDataStore):
     original: DataStore
-    _portal: BlockingPortal = attr.field(init=False, eq=False)
+    _portal: BlockingPortal = attr.field(init=False)
+    _events: AsyncEventBroker = attr.field(init=False)
+
+    @property
+    def events(self) -> EventSource:
+        return self._events
 
     async def __aenter__(self) -> AsyncDataStoreAdapter:
         self._portal = BlockingPortal()
         await self._portal.__aenter__()
+        self._events = AsyncEventBrokerAdapter(self.original.events, self._portal)
+        await self._events.__aenter__()
         await to_thread.run_sync(self.original.__enter__)
         return self
 
     async def __aexit__(self, exc_type, exc_val, exc_tb):
         await to_thread.run_sync(self.original.__exit__, exc_type, exc_val, exc_tb)
+        await self._events.__aexit__(exc_type, exc_val, exc_tb)
         await self._portal.__aexit__(exc_type, exc_val, exc_tb)
 
     async def add_task(self, task: Task) -> None:
@@ -77,10 +83,3 @@ class AsyncDataStoreAdapter(AsyncDataStore):
 
     async def get_job_result(self, job_id: UUID) -> Optional[JobResult]:
         return await to_thread.run_sync(self.original.get_job_result, job_id)
-
-    def subscribe(self, callback: Callable[[Event], Any],
-                  event_types: Optional[Iterable[type[Event]]] = None) -> SubscriptionToken:
-        return self.original.subscribe(partial(self._portal.call, callback), event_types)
-
-    def unsubscribe(self, token: events.SubscriptionToken) -> None:
-        self.original.unsubscribe(token)
diff --git a/src/apscheduler/datastores/async_sqlalchemy.py b/src/apscheduler/datastores/async_sqlalchemy.py
index c0d21cb..a215d68 100644
--- a/src/apscheduler/datastores/async_sqlalchemy.py
+++ b/src/apscheduler/datastores/async_sqlalchemy.py
@@ -1,30 +1,25 @@
 from __future__ import annotations
 
-import json
 from collections import defaultdict
-from contextlib import AsyncExitStack, closing
 from datetime import datetime, timedelta, timezone
-from json import JSONDecodeError
 from typing import Any, Callable, Iterable, Optional
 from uuid import UUID
 
 import attr
 import sniffio
-from anyio import TASK_STATUS_IGNORED, create_task_group, sleep
-from attr import asdict
-from sqlalchemy import and_, bindparam, func, or_, select
+from sqlalchemy import and_, bindparam, or_, select
 from sqlalchemy.engine import URL, Result
 from sqlalchemy.exc import IntegrityError
-from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine
+from sqlalchemy.ext.asyncio import create_async_engine
 from sqlalchemy.ext.asyncio.engine import AsyncEngine
 from sqlalchemy.sql.ddl import DropTable
 from sqlalchemy.sql.elements import BindParameter
 
-from .. import events as events_module
-from ..abc import AsyncDataStore, Job, Schedule
+from ..abc import AsyncDataStore, AsyncEventBroker, EventSource, Job, Schedule
 from ..enums import ConflictPolicy
+from ..eventbrokers.async_local import LocalAsyncEventBroker
 from ..events import (
-    AsyncEventHub, DataStoreEvent, Event, JobAdded, JobDeserializationFailed, ScheduleAdded,
+    DataStoreEvent, Event, JobAdded, JobDeserializationFailed, ScheduleAdded,
     ScheduleDeserializationFailed, ScheduleRemoved, ScheduleUpdated, SubscriptionToken,
     TaskAdded, TaskRemoved, TaskUpdated)
 from ..exceptions import ConflictingIdError, SerializationError, TaskLookupError
@@ -34,43 +29,11 @@ from ..util import reentrant
 from .sqlalchemy import _BaseSQLAlchemyDataStore
 
 
-def default_json_handler(obj: Any) -> Any:
-    if isinstance(obj, datetime):
-        return obj.timestamp()
-    elif isinstance(obj, UUID):
-        return obj.hex
-    elif isinstance(obj, frozenset):
-        return list(obj)
-
-    raise TypeError(f'Cannot JSON encode type {type(obj)}')
-
-
-def json_object_hook(obj: dict[str, Any]) -> Any:
-    for key, value in obj.items():
-        if key == 'timestamp':
-            obj[key] = datetime.fromtimestamp(value, timezone.utc)
-        elif key == 'job_id':
-            obj[key] = UUID(value)
-        elif key == 'tags':
-            obj[key] = frozenset(value)
-
-    return obj
-
-
 @reentrant
 @attr.define(eq=False)
 class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
     engine: AsyncEngine
-
-    _exit_stack: AsyncExitStack = attr.field(init=False, factory=AsyncExitStack)
-    _events: AsyncEventHub = attr.field(init=False, factory=AsyncEventHub)
-
-    def __attrs_post_init__(self) -> None:
-        super().__attrs_post_init__()
-
-        if self.notify_channel:
-            if self.engine.dialect.name != 'postgresql' or self.engine.dialect.driver != 'asyncpg':
-                self.notify_channel = None
+    _events: AsyncEventBroker = attr.field(factory=LocalAsyncEventBroker)
 
     @classmethod
     def from_url(cls, url: str | URL, **options) -> AsyncSQLAlchemyDataStore:
@@ -98,93 +61,44 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
                 raise RuntimeError(f'Unexpected schema version ({version}); '
                                    f'only version 1 is supported by this version of APScheduler')
 
-        await self._exit_stack.enter_async_context(self._events)
-
-        if self.notify_channel:
-            task_group = create_task_group()
-            await self._exit_stack.enter_async_context(task_group)
-            await task_group.start(self._listen_notifications)
-            self._exit_stack.callback(task_group.cancel_scope.cancel)
-
+        await self.events.__aenter__()
         return self
 
     async def __aexit__(self, exc_type, exc_val, exc_tb):
-        await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
-
-    async def _publish(self, conn: AsyncConnection, event: DataStoreEvent) -> None:
-        if self.notify_channel:
-            event_type = event.__class__.__name__
-            event_data = json.dumps(asdict(event), ensure_ascii=False,
-                                    default=default_json_handler)
-            notification = event_type + ' ' + event_data
-            if len(notification) < 8000:
-                await conn.execute(func.pg_notify(self.notify_channel, notification))
-                return
-
-            self._logger.warning(
-                'Could not send %s notification because it is too long (%d >= 8000)',
-                event_type, len(notification))
-
-        self._events.publish(event)
-
-    async def _listen_notifications(self, *, task_status=TASK_STATUS_IGNORED) -> None:
-        def callback(connection, pid, channel: str, payload: str) -> None:
-            self._logger.debug('Received notification on channel %s: %s', channel, payload)
-            event_type, _, json_data = payload.partition(' ')
-            try:
-                event_data = json.loads(json_data, object_hook=json_object_hook)
-            except JSONDecodeError:
-                self._logger.exception('Failed decoding JSON payload of notification: %s', payload)
-                return
-
-            event_class = getattr(events_module, event_type)
-            event = event_class(**event_data)
-            self._events.publish(event)
-
-        task_started_sent = False
-        while True:
-            with closing(await self.engine.raw_connection()) as conn:
-                asyncpg_conn = conn.connection._connection
-                await asyncpg_conn.add_listener(self.notify_channel, callback)
-                if not task_started_sent:
-                    task_status.started()
-                    task_started_sent = True
-
-                try:
-                    while True:
-                        await sleep(self.max_idle_time)
-                        await asyncpg_conn.execute('SELECT 1')
-                finally:
-                    await asyncpg_conn.remove_listener(self.notify_channel, callback)
-
-    def _deserialize_schedules(self, result: Result) -> list[Schedule]:
+        await self.events.__aexit__(exc_type, exc_val, exc_tb)
+
+    @property
+    def events(self) -> EventSource:
+        return self._events
+
+    async def _deserialize_schedules(self, result: Result) -> list[Schedule]:
         schedules: list[Schedule] = []
         for row in result:
             try:
                 schedules.append(Schedule.unmarshal(self.serializer, row._asdict()))
             except SerializationError as exc:
-                self._events.publish(
+                await self._events.publish(
                     ScheduleDeserializationFailed(schedule_id=row['id'], exception=exc))
 
         return schedules
 
-    def _deserialize_jobs(self, result: Result) -> list[Job]:
+    async def _deserialize_jobs(self, result: Result) -> list[Job]:
         jobs: list[Job] = []
         for row in result:
             try:
                 jobs.append(Job.unmarshal(self.serializer, row._asdict()))
             except SerializationError as exc:
-                self._events.publish(
+                await self._events.publish(
                     JobDeserializationFailed(job_id=row['id'], exception=exc))
 
         return jobs
 
     def subscribe(self, callback: Callable[[Event], Any],
                   event_types: Optional[Iterable[type[Event]]] = None) -> SubscriptionToken:
-        return self._events.subscribe(callback, event_types)
+        return self.events.subscribe(callback, event_types)
 
     def unsubscribe(self, token: SubscriptionToken) -> None:
-        self._events.unsubscribe(token)
+        self.events.unsubscribe(token)
 
     async def add_task(self, task: Task) -> None:
         insert = self.t_tasks.insert().\
@@ -201,9 +115,10 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
                 where(self.t_tasks.c.id == task.id)
             async with self.engine.begin() as conn:
                 await conn.execute(update)
-            self._events.publish(TaskUpdated(task_id=task.id))
+
+            await self._events.publish(TaskUpdated(task_id=task.id))
         else:
-            self._events.publish(TaskAdded(task_id=task.id))
+            await self._events.publish(TaskAdded(task_id=task.id))
 
     async def remove_task(self, task_id: str) -> None:
         delete = self.t_tasks.delete().where(self.t_tasks.c.id == task_id)
@@ -212,7 +127,7 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
         if result.rowcount == 0:
             raise TaskLookupError(task_id)
         else:
-            self._events.publish(TaskRemoved(task_id=task_id))
+            await self._events.publish(TaskRemoved(task_id=task_id))
 
     async def get_task(self, task_id: str) -> Task:
         query = select([self.t_tasks.c.id, self.t_tasks.c.func, self.t_tasks.c.max_running_jobs,
@@ -243,9 +158,6 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
         try:
             async with self.engine.begin() as conn:
                 await conn.execute(insert)
-                event = ScheduleAdded(schedule_id=schedule.id,
-                                      next_fire_time=schedule.next_fire_time)
-                await self._publish(conn, event)
         except IntegrityError:
             if conflict_policy is ConflictPolicy.exception:
                 raise ConflictingIdError(schedule.id) from None
@@ -257,9 +169,13 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
                 async with self.engine.begin() as conn:
                     await conn.execute(update)
 
-                    event = ScheduleUpdated(schedule_id=schedule.id,
-                                            next_fire_time=schedule.next_fire_time)
-                    await self._publish(conn, event)
+                event = ScheduleUpdated(schedule_id=schedule.id,
+                                        next_fire_time=schedule.next_fire_time)
+                await self._events.publish(event)
+        else:
+            event = ScheduleAdded(schedule_id=schedule.id,
+                                  next_fire_time=schedule.next_fire_time)
+            await self._events.publish(event)
 
     async def remove_schedules(self, ids: Iterable[str]) -> None:
         async with self.engine.begin() as conn:
@@ -272,8 +188,8 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
             await conn.execute(delete)
 
             removed_ids = ids
-            for schedule_id in removed_ids:
-                await self._publish(conn, ScheduleRemoved(schedule_id=schedule_id))
+        for schedule_id in removed_ids:
+            await self._events.publish(ScheduleRemoved(schedule_id=schedule_id))
 
     async def get_schedules(self, ids: Optional[set[str]] = None) -> list[Schedule]:
         query = self.t_schedules.select().order_by(self.t_schedules.c.id)
@@ -282,7 +198,7 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
 
         async with self.engine.begin() as conn:
             result = await conn.execute(query)
-            return self._deserialize_schedules(result)
+            return await self._deserialize_schedules(result)
 
     async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]:
         async with self.engine.begin() as conn:
@@ -308,7 +224,7 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
                 where(and_(self.t_schedules.c.acquired_by == scheduler_id))
             result = conn.execute(query)
-            schedules = self._deserialize_schedules(result)
+            schedules = await self._deserialize_schedules(result)
 
             return schedules
 
@@ -365,11 +281,11 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
                     where(self.t_schedules.c.id.in_(finished_schedule_ids))
                 await conn.execute(delete)
 
-            for event in update_events:
-                await self._publish(conn, event)
+        for event in update_events:
+            await self._events.publish(event)
 
-            for schedule_id in finished_schedule_ids:
-                await self._publish(conn, ScheduleRemoved(schedule_id=schedule_id))
+        for schedule_id in finished_schedule_ids:
+            await self._events.publish(ScheduleRemoved(schedule_id=schedule_id))
 
     async def get_next_schedule_run_time(self) -> Optional[datetime]:
         statement = select(self.t_schedules.c.id).\
@@ -386,9 +302,9 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
         async with self.engine.begin() as conn:
             await conn.execute(insert)
 
-            event = JobAdded(job_id=job.id, task_id=job.task_id, schedule_id=job.schedule_id,
-                             tags=job.tags)
-            await self._publish(conn, event)
+        event = JobAdded(job_id=job.id, task_id=job.task_id, schedule_id=job.schedule_id,
+                         tags=job.tags)
+        await self._events.publish(event)
 
     async def get_jobs(self, ids: Optional[Iterable[UUID]] = None) -> list[Job]:
         query = self.t_jobs.select().order_by(self.t_jobs.c.id)
@@ -398,7 +314,7 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
 
         async with self.engine.begin() as conn:
             result = await conn.execute(query)
-            return self._deserialize_jobs(result)
+            return await self._deserialize_jobs(result)
 
     async def acquire_jobs(self, worker_id: str, limit: Optional[int] = None) -> list[Job]:
         async with self.engine.begin() as conn:
@@ -416,7 +332,7 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
                 return []
 
             # Mark the jobs as acquired by this worker
-            jobs = self._deserialize_jobs(result)
+            jobs = await self._deserialize_jobs(result)
             task_ids: set[str] = {job.task_id for job in jobs}
 
             # Retrieve the limits
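With the NOTIFY/LISTEN plumbing gone from this data store, cross-process events now come from an explicitly chosen broker. A hypothetical wiring (the `events` keyword is inferred from the attrs field `_events`; treat this as a sketch, not a documented API):

```python
from sqlalchemy.ext.asyncio import create_async_engine

from apscheduler.datastores.async_sqlalchemy import AsyncSQLAlchemyDataStore
from apscheduler.eventbrokers.asyncpg import AsyncpgEventBroker
from apscheduler.serializers.json import JSONSerializer

engine = create_async_engine('postgresql+asyncpg://postgres:secret@localhost/testdb')
broker = AsyncpgEventBroker.from_async_sqla_engine(engine)  # borrows the engine's asyncpg connections
broker.serializer = JSONSerializer()  # the distributed mixin needs a serializer
store = AsyncSQLAlchemyDataStore(engine, events=broker)
```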
diff --git a/src/apscheduler/datastores/memory.py b/src/apscheduler/datastores/memory.py
index df4e0c2..98306a2 100644
--- a/src/apscheduler/datastores/memory.py
+++ b/src/apscheduler/datastores/memory.py
@@ -4,17 +4,16 @@ from bisect import bisect_left, insort_right
 from collections import defaultdict
 from datetime import MAXYEAR, datetime, timedelta, timezone
 from functools import partial
-from typing import Any, Callable, Iterable, Optional
+from typing import Any, Iterable, Optional
 from uuid import UUID
 
 import attr
 
-from .. import events
-from ..abc import DataStore, Job, Schedule
+from ..abc import DataStore, EventBroker, EventSource, Job, Schedule
 from ..enums import ConflictPolicy
+from ..eventbrokers.local import LocalEventBroker
 from ..events import (
-    EventHub, JobAdded, ScheduleAdded, ScheduleRemoved, ScheduleUpdated, SubscriptionToken,
-    TaskAdded, TaskRemoved, TaskUpdated)
+    JobAdded, ScheduleAdded, ScheduleRemoved, ScheduleUpdated, TaskAdded, TaskRemoved, TaskUpdated)
 from ..exceptions import ConflictingIdError, TaskLookupError
 from ..structures import JobResult, Task
 from ..util import reentrant
@@ -75,7 +74,7 @@ class JobState:
 @attr.define(eq=False)
 class MemoryDataStore(DataStore):
     lock_expiration_delay: float = 30
-    _events: EventHub = attr.Factory(EventHub)
+    _events: EventBroker = attr.Factory(LocalEventBroker)
     _tasks: dict[str, TaskState] = attr.Factory(dict)
     _schedules: list[ScheduleState] = attr.Factory(list)
     _schedules_by_id: dict[str, ScheduleState] = attr.Factory(dict)
@@ -102,12 +101,9 @@ class MemoryDataStore(DataStore):
     def __exit__(self, exc_type, exc_val, exc_tb):
         self._events.__exit__(exc_type, exc_val, exc_tb)
 
-    def subscribe(self, callback: Callable[[events.Event], Any],
-                  event_types: Optional[Iterable[type[events.Event]]] = None) -> SubscriptionToken:
-        return self._events.subscribe(callback, event_types)
-
-    def unsubscribe(self, token: events.SubscriptionToken) -> None:
-        self._events.unsubscribe(token)
+    @property
+    def events(self) -> EventSource:
+        return self._events
 
     def get_schedules(self, ids: Optional[set[str]] = None) -> list[Schedule]:
         return [state.schedule for state in self._schedules
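For every store the migration is mechanical: subscription moves behind the new `events` property (sketch; the callback is invented):

```python
from apscheduler.datastores.memory import MemoryDataStore
from apscheduler.events import ScheduleAdded

store = MemoryDataStore()
with store:
    # previously: store.subscribe(...) / store.unsubscribe(token)
    token = store.events.subscribe(lambda event: print(event), {ScheduleAdded})
    store.events.unsubscribe(token)
```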
diff --git a/src/apscheduler/datastores/mongodb.py b/src/apscheduler/datastores/mongodb.py
index 749381e..6e7d0aa 100644
--- a/src/apscheduler/datastores/mongodb.py
+++ b/src/apscheduler/datastores/mongodb.py
@@ -14,12 +14,12 @@ from pymongo import ASCENDING, DeleteOne, MongoClient, UpdateOne
 from pymongo.collection import Collection
 from pymongo.errors import DuplicateKeyError
 
-from .. import events
-from ..abc import DataStore, Job, Schedule, Serializer
+from ..abc import DataStore, EventBroker, EventSource, Job, Schedule, Serializer
 from ..enums import ConflictPolicy
+from ..eventbrokers.local import LocalEventBroker
 from ..events import (
-    DataStoreEvent, EventHub, JobAdded, ScheduleAdded, ScheduleRemoved, ScheduleUpdated,
-    SubscriptionToken, TaskAdded, TaskRemoved, TaskUpdated)
+    DataStoreEvent, JobAdded, ScheduleAdded, ScheduleRemoved, ScheduleUpdated, SubscriptionToken,
+    TaskAdded, TaskRemoved, TaskUpdated)
 from ..exceptions import (
     ConflictingIdError, DeserializationError, SerializationError, TaskLookupError)
 from ..serializers.pickle import PickleSerializer
@@ -42,7 +42,7 @@ class MongoDBDataStore(DataStore):
 
     _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__))
     _exit_stack: ExitStack = attr.field(init=False, factory=ExitStack)
-    _events: EventHub = attr.field(init=False, factory=EventHub)
+    _events: EventBroker = attr.field(init=False, factory=LocalEventBroker)
     _local_tasks: dict[str, Task] = attr.field(init=False, factory=dict)
 
     @client.validator
@@ -62,6 +62,10 @@ class MongoDBDataStore(DataStore):
         client = MongoClient(uri)
         return cls(client, **options)
 
+    @property
+    def events(self) -> EventSource:
+        return self._events
+
     def __enter__(self):
         server_info = self.client.server_info()
         if server_info['versionArray'] < [4, 0]:
diff --git a/src/apscheduler/datastores/sqlalchemy.py b/src/apscheduler/datastores/sqlalchemy.py
index d49ca61..3040ae4 100644
--- a/src/apscheduler/datastores/sqlalchemy.py
+++ b/src/apscheduler/datastores/sqlalchemy.py
@@ -16,12 +16,12 @@ from sqlalchemy.future import Engine, create_engine
 from sqlalchemy.sql.ddl import DropTable
 from sqlalchemy.sql.elements import BindParameter, literal
 
-from ..abc import DataStore, Job, Schedule, Serializer
+from ..abc import DataStore, EventBroker, EventSource, Job, Schedule, Serializer
 from ..enums import CoalescePolicy, ConflictPolicy, JobOutcome
+from ..eventbrokers.local import LocalEventBroker
 from ..events import (
-    Event, EventHub, JobAdded, JobDeserializationFailed, ScheduleAdded,
-    ScheduleDeserializationFailed, ScheduleRemoved, ScheduleUpdated, SubscriptionToken, TaskAdded,
-    TaskRemoved, TaskUpdated)
+    Event, JobAdded, JobDeserializationFailed, ScheduleAdded, ScheduleDeserializationFailed,
+    ScheduleRemoved, ScheduleUpdated, SubscriptionToken, TaskAdded, TaskRemoved, TaskUpdated)
 from ..exceptions import ConflictingIdError, SerializationError, TaskLookupError
 from ..marshalling import callable_to_ref
 from ..serializers.pickle import PickleSerializer
@@ -179,7 +179,7 @@ class _BaseSQLAlchemyDataStore:
 class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, DataStore):
     engine: Engine
 
-    _events: EventHub = attr.field(init=False, factory=EventHub)
+    _events: EventBroker = attr.field(init=False, factory=LocalEventBroker)
 
     @classmethod
     def from_url(cls, url: str | URL, **options) -> SQLAlchemyDataStore:
@@ -208,6 +208,10 @@ class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, DataStore):
     def __exit__(self, exc_type, exc_val, exc_tb):
         self._events.__exit__(exc_type, exc_val, exc_tb)
 
+    @property
+    def events(self) -> EventSource:
+        return self._events
+
     def subscribe(self, callback: Callable[[Event], Any],
                   event_types: Optional[Iterable[type[Event]]] = None) -> SubscriptionToken:
         return self._events.subscribe(callback, event_types)
diff --git a/src/apscheduler/eventbrokers/__init__.py b/src/apscheduler/eventbrokers/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/apscheduler/eventbrokers/__init__.py
diff --git a/src/apscheduler/eventbrokers/async_adapter.py b/src/apscheduler/eventbrokers/async_adapter.py
new file mode 100644
index 0000000..cb18386
--- /dev/null
+++ b/src/apscheduler/eventbrokers/async_adapter.py
@@ -0,0 +1,44 @@
+from __future__ import annotations
+
+from contextlib import AsyncExitStack
+from functools import partial
+from typing import Any, Callable, Iterable, Optional
+
+import attr
+from anyio import to_thread
+from anyio.from_thread import BlockingPortal
+
+from apscheduler.abc import EventBroker
+from apscheduler.eventbrokers.async_local import LocalAsyncEventBroker
+from apscheduler.events import Event, SubscriptionToken
+from apscheduler.util import reentrant
+
+
+@reentrant
+@attr.define(eq=False)
+class AsyncEventBrokerAdapter(LocalAsyncEventBroker):
+    original: EventBroker
+    portal: BlockingPortal
+    _exit_stack: AsyncExitStack = attr.field(init=False)
+
+    async def __aenter__(self):
+        self._exit_stack = AsyncExitStack()
+        if not self.portal:
+            self.portal = BlockingPortal()
+            self._exit_stack.enter_async_context(self.portal)
+
+        await to_thread.run_sync(self.original.__enter__)
+        return await super().__aenter__()
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        await to_thread.run_sync(self.original.__exit__, exc_type, exc_val, exc_tb)
+        await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
+        return await super().__aexit__(exc_type, exc_val, exc_tb)
+
+    async def publish(self, event: Event) -> None:
+        await to_thread.run_sync(self.original.publish, event)
+
+    def subscribe(self, callback: Callable[[Event], Any],
+                  event_types: Optional[Iterable[type[Event]]] = None) -> SubscriptionToken:
+        token = self.original.subscribe(partial(self.portal.call, callback), event_types)
+        return token
diff --git a/src/apscheduler/eventbrokers/async_local.py b/src/apscheduler/eventbrokers/async_local.py
new file mode 100644
index 0000000..e73b4b9
--- /dev/null
+++ b/src/apscheduler/eventbrokers/async_local.py
@@ -0,0 +1,44 @@
+from __future__ import annotations
+
+from inspect import isawaitable
+from logging import Logger, getLogger
+from typing import Any, Callable
+
+import attr
+from anyio import create_task_group
+from anyio.abc import TaskGroup
+
+from ..abc import AsyncEventBroker
+from ..events import Event
+from ..util import reentrant
+from .base import BaseEventBroker
+
+
+@reentrant
+@attr.define(eq=False)
+class LocalAsyncEventBroker(AsyncEventBroker, BaseEventBroker):
+    _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__))
+    _task_group: TaskGroup = attr.field(init=False)
+
+    async def __aenter__(self) -> LocalAsyncEventBroker:
+        self._task_group = create_task_group()
+        await self._task_group.__aenter__()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
+        del self._task_group
+
+    async def publish(self, event: Event) -> None:
+        async def deliver_event(func: Callable[[Event], Any]) -> None:
+            try:
+                retval = func(event)
+                if isawaitable(retval):
+                    await retval
+            except BaseException:
+                self._logger.exception('Error delivering %s event', event.__class__.__name__)
+
+        event_type = type(event)
+        for subscription in self._subscriptions.values():
+            if subscription.event_types is None or event_type in subscription.event_types:
+                self._task_group.start_soon(deliver_event, subscription.callback)
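The async local broker awaits coroutine callbacks and fans deliveries out through an AnyIO task group; anything still in flight is awaited when the broker exits. A minimal sketch:

```python
import anyio

from apscheduler.eventbrokers.async_local import LocalAsyncEventBroker
from apscheduler.events import Event

async def main() -> None:
    async with LocalAsyncEventBroker() as broker:
        async def on_event(event: Event) -> None:  # coroutine callbacks are awaited
            print('got', type(event).__name__)

        broker.subscribe(on_event)
        await broker.publish(Event())
        # deliveries still in flight are awaited when the task group closes

anyio.run(main)
```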
diff --git a/src/apscheduler/eventbrokers/asyncpg.py b/src/apscheduler/eventbrokers/asyncpg.py
new file mode 100644
index 0000000..447bee8
--- /dev/null
+++ b/src/apscheduler/eventbrokers/asyncpg.py
@@ -0,0 +1,89 @@
+from __future__ import annotations
+
+from contextlib import asynccontextmanager
+from logging import Logger, getLogger
+from typing import TYPE_CHECKING, AsyncContextManager, AsyncGenerator, Callable
+
+import attr
+from anyio import TASK_STATUS_IGNORED, sleep
+from asyncpg import Connection
+from asyncpg.pool import Pool
+
+from ..events import Event
+from ..exceptions import SerializationError
+from ..util import reentrant
+from .async_local import LocalAsyncEventBroker
+from .base import DistributedEventBrokerMixin
+
+if TYPE_CHECKING:
+    from sqlalchemy.ext.asyncio import AsyncEngine
+
+
+@reentrant
+@attr.define(eq=False)
+class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin):
+    connection_factory: Callable[[], AsyncContextManager[Connection]]
+    channel: str = attr.field(kw_only=True, default='apscheduler')
+    max_idle_time: float = attr.field(kw_only=True, default=30)
+    _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__))
+
+    @classmethod
+    def from_asyncpg_pool(cls, pool: Pool) -> AsyncpgEventBroker:
+        return cls(pool.acquire)
+
+    @classmethod
+    def from_async_sqla_engine(cls, engine: AsyncEngine) -> AsyncpgEventBroker:
+        if engine.dialect.driver != 'asyncpg':
+            raise ValueError(f'The driver in the engine must be "asyncpg" (current: '
+                             f'{engine.dialect.driver})')
+
+        @asynccontextmanager
+        async def connection_factory() -> AsyncGenerator[Connection, None]:
+            conn = await engine.raw_connection()
+            try:
+                yield conn.connection._connection
+            finally:
+                conn.close()
+
+        return cls(connection_factory)
+
+    async def __aenter__(self) -> LocalAsyncEventBroker:
+        await super().__aenter__()
+        await self._task_group.start(self._listen_notifications)
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        self._task_group.cancel_scope.cancel()
+        await super().__aexit__(exc_type, exc_val, exc_tb)
+
+    async def _listen_notifications(self, *, task_status=TASK_STATUS_IGNORED) -> None:
+        local_publish = super(AsyncpgEventBroker, self).publish
+
+        def callback(connection, pid, channel: str, payload: str) -> None:
+            event = self.reconstitute_event_str(payload)
+            if event is not None:
+                self._task_group.start_soon(local_publish, event)
+
+        task_started_sent = False
+        while True:
+            async with self.connection_factory() as conn:
+                await conn.add_listener(self.channel, callback)
+                if not task_started_sent:
+                    task_status.started()
+                    task_started_sent = True
+
+                try:
+                    while True:
+                        await sleep(self.max_idle_time)
+                        await conn.execute('SELECT 1')
+                finally:
+                    await conn.remove_listener(self.channel, callback)
+
+    async def publish(self, event: Event) -> None:
+        notification = self.generate_notification_str(event)
+        if len(notification) > 7999:
+            raise SerializationError('Serialized event object exceeds 7999 bytes in size')
+
+        async with self.connection_factory() as conn:
+            await conn.execute("SELECT pg_notify($1, $2)", self.channel, notification)
+            return
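Standalone use of the asyncpg broker might look like this (the DSN is copied from the test suite; the serializer assignment mirrors the fixtures):

```python
import anyio
from asyncpg import create_pool

from apscheduler.eventbrokers.asyncpg import AsyncpgEventBroker
from apscheduler.serializers.json import JSONSerializer

async def main() -> None:
    pool = await create_pool('postgres://postgres:secret@localhost:5432/testdb')
    broker = AsyncpgEventBroker.from_asyncpg_pool(pool)
    broker.serializer = JSONSerializer()
    async with broker:
        broker.subscribe(print)   # NOTIFY payloads from other processes arrive here
        await anyio.sleep(10)
    await pool.close()

anyio.run(main)
```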
diff --git a/src/apscheduler/eventbrokers/base.py b/src/apscheduler/eventbrokers/base.py
new file mode 100644
index 0000000..da89dc5
--- /dev/null
+++ b/src/apscheduler/eventbrokers/base.py
@@ -0,0 +1,84 @@
+from __future__ import annotations
+
+from base64 import b64decode, b64encode
+from logging import Logger
+from typing import Any, Callable, Iterable, Optional
+
+import attr
+
+from .. import abc, events
+from ..abc import EventBroker, Serializer
+from ..events import Event, Subscription, SubscriptionToken
+from ..exceptions import DeserializationError
+
+
+@attr.define(eq=False)
+class BaseEventBroker(EventBroker):
+    _subscriptions: dict[SubscriptionToken, Subscription] = attr.field(init=False, factory=dict)
+
+    def subscribe(self, callback: Callable[[Event], Any],
+                  event_types: Optional[Iterable[type[Event]]] = None) -> SubscriptionToken:
+        types = set(event_types) if event_types else None
+        token = SubscriptionToken(object())
+        subscription = Subscription(callback, types)
+        self._subscriptions[token] = subscription
+        return token
+
+    def unsubscribe(self, token: SubscriptionToken) -> None:
+        self._subscriptions.pop(token, None)
+
+    def relay_events_from(self, source: abc.EventSource) -> SubscriptionToken:
+        return source.subscribe(self.publish)
+
+
+class DistributedEventBrokerMixin:
+    _logger: Logger
+    serializer: Serializer
+
+    def generate_notification(self, event: Event, use_base64: bool = False) -> bytes:
+        serialized = self.serializer.serialize(attr.asdict(event))
+        return event.__class__.__name__.encode('ascii') + b' ' + serialized
+
+    def generate_notification_str(self, event: Event) -> str:
+        serialized = self.serializer.serialize(attr.asdict(event))
+        return event.__class__.__name__ + ' ' + b64encode(serialized).decode('ascii')
+
+    def _reconstitute_event(self, event_type: str, serialized: bytes) -> Optional[Event]:
+        try:
+            kwargs = self.serializer.deserialize(serialized)
+        except DeserializationError:
+            self._logger.exception('Failed to deserialize an event of type %s', event_type,
+                                   serialized=serialized)
+            return None
+
+        try:
+            event_class = getattr(events, event_type)
+        except AttributeError:
+            self._logger.error('Received notification for a nonexistent event type: %s',
+                               event_type, serialized=serialized)
+            return None
+
+        try:
+            return event_class(**kwargs)
+        except Exception:
+            self._logger.exception('Error reconstituting event of type %s', event_type)
+            return None
+
+    def reconstitute_event(self, payload: bytes) -> Optional[Event]:
+        try:
+            event_type_bytes, serialized = payload.split(b' ', 1)
+        except ValueError:
+            self._logger.error('Received malformatted notification', payload=payload)
+            return None
+
+        event_type = event_type_bytes.decode('ascii', errors='replace')
+        return self._reconstitute_event(event_type, serialized)
+
+    def reconstitute_event_str(self, payload: str) -> Optional[Event]:
+        try:
+            event_type, b64_serialized = payload.split(' ', 1)
+        except ValueError:
+            self._logger.error('Received malformatted notification', payload=payload)
+            return None
+
+        return self._reconstitute_event(event_type, b64decode(b64_serialized))
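The wire format is simply the event class name, a space, and the serialized payload (base64-encoded in the string variant). A self-contained round trip through that framing (payload bytes invented):

```python
from base64 import b64decode, b64encode

serialized = b'{"schedule_id": "demo"}'  # what Serializer.serialize() might return

# generate_notification() / generate_notification_str() equivalents
notification = b'ScheduleAdded' + b' ' + serialized
text = 'ScheduleAdded' + ' ' + b64encode(serialized).decode('ascii')

# Receiving side: split on the first space, then decode
event_type, _, b64_payload = text.partition(' ')
assert event_type == 'ScheduleAdded'
assert b64decode(b64_payload) == serialized
```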
diff --git a/src/apscheduler/eventbrokers/local.py b/src/apscheduler/eventbrokers/local.py
new file mode 100644
index 0000000..ab75575
--- /dev/null
+++ b/src/apscheduler/eventbrokers/local.py
@@ -0,0 +1,47 @@
+from __future__ import annotations
+
+from asyncio import iscoroutinefunction
+from concurrent.futures import ThreadPoolExecutor
+from logging import Logger, getLogger
+from typing import Any, Callable, Iterable, Optional
+
+import attr
+
+from ..events import Event, SubscriptionToken
+from ..util import reentrant
+from .base import BaseEventBroker
+
+
+@reentrant
+@attr.define(eq=False)
+class LocalEventBroker(BaseEventBroker):
+    _executor: ThreadPoolExecutor = attr.field(init=False)
+    _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__))
+
+    def __enter__(self) -> LocalEventBroker:
+        self._executor = ThreadPoolExecutor(1)
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self._executor.shutdown(wait=exc_type is None)
+        del self._executor
+
+    def subscribe(self, callback: Callable[[Event], Any],
+                  event_types: Optional[Iterable[type[Event]]] = None) -> SubscriptionToken:
+        if iscoroutinefunction(callback):
+            raise ValueError('Coroutine functions are not supported as callbacks on a synchronous '
+                             'event source')
+
+        return super().subscribe(callback, event_types)
+
+    def publish(self, event: Event) -> None:
+        event_type = type(event)
+        for subscription in list(self._subscriptions.values()):
+            if subscription.event_types is None or event_type in subscription.event_types:
+                self._executor.submit(self._deliver_event, subscription.callback, event)
+
+    def _deliver_event(self, func: Callable[[Event], Any], event: Event) -> None:
+        try:
+            func(event)
+        except BaseException:
+            self._logger.exception('Error delivering %s event', event.__class__.__name__)
diff --git a/src/apscheduler/eventbrokers/mqtt.py b/src/apscheduler/eventbrokers/mqtt.py
new file mode 100644
index 0000000..cfedb88
--- /dev/null
+++ b/src/apscheduler/eventbrokers/mqtt.py
@@ -0,0 +1,67 @@
+from __future__ import annotations
+
+from concurrent.futures import Future
+from logging import Logger, getLogger
+from typing import Any, Optional
+
+import attr
+from paho.mqtt.client import Client, MQTTMessage
+from paho.mqtt.properties import Properties
+from paho.mqtt.reasoncodes import ReasonCodes
+
+from ..abc import Serializer
+from ..events import Event
+from ..serializers.json import JSONSerializer
+from ..util import reentrant
+from .base import DistributedEventBrokerMixin
+from .local import LocalEventBroker
+
+
+@reentrant
+@attr.define(eq=False)
+class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
+    client: Client
+    serializer: Serializer = attr.field(factory=JSONSerializer)
+    host: str = attr.field(kw_only=True, default='localhost')
+    port: int = attr.field(kw_only=True, default=1883)
+    topic: str = attr.field(kw_only=True, default='apscheduler')
+    subscribe_qos: int = attr.field(kw_only=True, default=0)
+    publish_qos: int = attr.field(kw_only=True, default=0)
+    _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__))
+    _ready_future: Future[None] = attr.field(init=False)
+
+    def __enter__(self):
+        self._ready_future = Future()
+        self.client.enable_logger(self._logger)
+        self.client.on_connect = self._on_connect
+        self.client.on_message = self._on_message
+        self.client.on_subscribe = self._on_subscribe
+        self.client.connect(self.host, self.port)
+        self.client.loop_start()
+        self._ready_future.result(10)
+        return super().__enter__()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.client.disconnect()
+        self.client.loop_stop(force=exc_type is not None)
+        return super().__exit__(exc_type, exc_val, exc_tb)
+
+    def _on_connect(self, client: Client, userdata: Any, flags: dict[str, Any],
+                    rc: ReasonCodes | int, properties: Optional[Properties] = None) -> None:
+        try:
+            client.subscribe(self.topic, qos=self.subscribe_qos)
+        except Exception as exc:
+            self._ready_future.set_exception(exc)
+            raise
+
+    def _on_subscribe(self, client: Client, userdata: Any, mid, granted_qos: list[int]) -> None:
+        self._ready_future.set_result(None)
+
+    def _on_message(self, client: Client, userdata: Any, msg: MQTTMessage) -> None:
+        event = self.reconstitute_event(msg.payload)
+        if event is not None:
+            super().publish(event)
+
+    def publish(self, event: Event) -> None:
+        notification = self.generate_notification(event)
+        self.client.publish(self.topic, notification, qos=self.publish_qos)
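Hypothetical MQTT wiring against the mosquitto service added to docker-compose.yml above:

```python
from paho.mqtt.client import Client

from apscheduler.eventbrokers.mqtt import MQTTEventBroker
from apscheduler.events import Event
from apscheduler.serializers.json import JSONSerializer

broker = MQTTEventBroker(Client(), JSONSerializer(), host='localhost', port=1883)
with broker:  # connects, subscribes to the 'apscheduler' topic, waits up to 10 s for readiness
    broker.publish(Event())  # fans out to every process subscribed to the topic
```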
diff --git a/src/apscheduler/eventbrokers/redis.py b/src/apscheduler/eventbrokers/redis.py
new file mode 100644
index 0000000..92c7f83
--- /dev/null
+++ b/src/apscheduler/eventbrokers/redis.py
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+from concurrent.futures import Future
+from logging import Logger, getLogger
+from threading import Thread
+from typing import Optional
+
+import attr
+from redis import ConnectionPool, Redis
+
+from ..abc import Serializer
+from ..events import Event
+from ..serializers.json import JSONSerializer
+from ..util import reentrant
+from .base import DistributedEventBrokerMixin
+from .local import LocalEventBroker
+
+
+@reentrant
+@attr.define(eq=False)
+class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
+    client: Redis
+    serializer: Serializer = attr.field(factory=JSONSerializer)
+    channel: str = attr.field(kw_only=True, default='apscheduler')
+    message_poll_interval: float = attr.field(kw_only=True, default=0.05)
+    _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__))
+    _stopped: bool = attr.field(init=False, default=True)
+    _ready_future: Future[None] = attr.field(init=False)
+
+    @classmethod
+    def from_url(cls, url: str, db: Optional[str] = None, decode_components: bool = False,
+                 **kwargs) -> RedisEventBroker:
+        pool = ConnectionPool.from_url(url, db, decode_components, **kwargs)
+        client = Redis(connection_pool=pool)
+        return cls(client)
+
+    def __enter__(self):
+        self._stopped = False
+        self._ready_future = Future()
+        self._thread = Thread(target=self._listen_messages, daemon=True, name='Redis subscriber')
+        self._thread.start()
+        self._ready_future.result(10)
+        return super().__enter__()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self._stopped = True
+        if not exc_type:
+            self._thread.join(5)
+
+        super().__exit__(exc_type, exc_val, exc_tb)
+
+    def _listen_messages(self) -> None:
+        while not self._stopped:
+            try:
+                pubsub = self.client.pubsub()
+                pubsub.subscribe(self.channel)
+            except BaseException as exc:
+                if not self._ready_future.done():
+                    self._ready_future.set_exception(exc)
+
+                raise
+            else:
+                if not self._ready_future.done():
+                    self._ready_future.set_result(None)
+
+            try:
+                while not self._stopped:
+                    msg = pubsub.get_message(timeout=self.message_poll_interval)
+                    if msg and isinstance(msg['data'], bytes):
+                        event = self.reconstitute_event(msg['data'])
+                        if event is not None:
+                            super().publish(event)
+            except BaseException:
+                self._logger.exception('Subscriber crashed')
+                raise
+            finally:
+                pubsub.close()
+
+    def publish(self, event: Event) -> None:
+        notification = self.generate_notification(event)
+        self.client.publish(self.channel, notification)
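And the Redis equivalent (the JSON serializer is already the default here):

```python
from apscheduler.eventbrokers.redis import RedisEventBroker
from apscheduler.events import Event

broker = RedisEventBroker.from_url('redis://localhost:6379')
with broker:  # starts the subscriber thread, waits up to 10 s for the subscription
    broker.publish(Event())
```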
diff --git a/src/apscheduler/events.py b/src/apscheduler/events.py
index 571aead..49703d9 100644
--- a/src/apscheduler/events.py
+++ b/src/apscheduler/events.py
@@ -1,38 +1,24 @@
 from __future__ import annotations
 
-import logging
-from abc import abstractmethod
-from asyncio import iscoroutinefunction
-from concurrent.futures.thread import ThreadPoolExecutor
 from datetime import datetime, timezone
 from functools import partial
-from inspect import isawaitable
-from logging import Logger
 from traceback import format_tb
-from typing import Any, Callable, Iterable, NewType, Optional
+from typing import Any, Callable, NewType, Optional
 from uuid import UUID
 
 import attr
-from anyio import create_task_group
-from anyio.abc import TaskGroup
+from attr.converters import optional
 
-from . import abc
+from .converters import as_aware_datetime, as_uuid
 from .structures import Job
 
 SubscriptionToken = NewType('SubscriptionToken', object)
 
 
-def timestamp_to_datetime(value: datetime | float | None) -> Optional[datetime]:
-    if isinstance(value, float):
-        return datetime.fromtimestamp(value, timezone.utc)
-
-    return value
-
-
 @attr.define(kw_only=True, frozen=True)
 class Event:
     timestamp: datetime = attr.field(factory=partial(datetime.now, timezone.utc),
-                                     converter=timestamp_to_datetime)
+                                     converter=as_aware_datetime)
 
 
 #
@@ -62,13 +48,13 @@ class TaskRemoved(DataStoreEvent):
 @attr.define(kw_only=True, frozen=True)
 class ScheduleAdded(DataStoreEvent):
     schedule_id: str
-    next_fire_time: Optional[datetime] = attr.field(converter=timestamp_to_datetime)
+    next_fire_time: Optional[datetime] = attr.field(converter=optional(as_aware_datetime))
 
 
 @attr.define(kw_only=True, frozen=True)
 class ScheduleUpdated(DataStoreEvent):
     schedule_id: str
-    next_fire_time: Optional[datetime] = attr.field(converter=timestamp_to_datetime)
+    next_fire_time: Optional[datetime] = attr.field(converter=optional(as_aware_datetime))
 
 
 @attr.define(kw_only=True, frozen=True)
@@ -78,15 +64,15 @@ class ScheduleRemoved(DataStoreEvent):
 
 @attr.define(kw_only=True, frozen=True)
 class JobAdded(DataStoreEvent):
-    job_id: UUID
+    job_id: UUID = attr.field(converter=as_uuid)
     task_id: str
     schedule_id: Optional[str]
-    tags: frozenset[str]
+    tags: frozenset[str] = attr.field(converter=frozenset)
 
 
 @attr.define(kw_only=True, frozen=True)
 class JobRemoved(DataStoreEvent):
-    job_id: UUID
+    job_id: UUID = attr.field(converter=as_uuid)
 
 
 @attr.define(kw_only=True, frozen=True)
@@ -97,7 +83,7 @@ class ScheduleDeserializationFailed(DataStoreEvent):
 
 @attr.define(kw_only=True, frozen=True)
 class JobDeserializationFailed(DataStoreEvent):
-    job_id: UUID
+    job_id: UUID = attr.field(converter=as_uuid)
     exception: BaseException
 
 
@@ -141,11 +127,11 @@ class WorkerStopped(WorkerEvent):
 
 @attr.define(kw_only=True, frozen=True)
 class JobExecutionEvent(WorkerEvent):
-    job_id: UUID
+    job_id: UUID = attr.field(converter=as_uuid)
     task_id: str
     schedule_id: Optional[str]
-    scheduled_fire_time: Optional[datetime]
-    start_deadline: Optional[datetime]
+    scheduled_fire_time: Optional[datetime] = attr.field(converter=optional(as_aware_datetime))
+    start_deadline: Optional[datetime] = attr.field(converter=optional(as_aware_datetime))
 
 
 @attr.define(kw_only=True, frozen=True)
@@ -175,7 +161,7 @@ class JobDeadlineMissed(JobExecutionEvent):
 @attr.define(kw_only=True, frozen=True)
 class JobCompleted(JobExecutionEvent):
     """Signals that a worker has successfully run a job."""
-    start_time: datetime
+    start_time: datetime = attr.field(converter=optional(as_aware_datetime))
     return_value: str
 
     @classmethod
@@ -188,7 +174,7 @@ class JobCompleted(JobExecutionEvent):
 @attr.define(kw_only=True, frozen=True)
 class JobCancelled(JobExecutionEvent):
     """Signals that a job was cancelled."""
-    start_time: datetime
+    start_time: datetime = attr.field(converter=optional(as_aware_datetime))
 
     @classmethod
     def from_job(cls, job: Job, start_time: datetime) -> JobCancelled:
@@ -227,85 +213,3 @@ class JobFailed(JobExecutionEvent):
 class Subscription:
     callback: Callable[[Event], Any]
     event_types: Optional[set[type[Event]]]
-
-
-@attr.define
-class _BaseEventHub(abc.EventSource):
-    _logger: Logger = attr.field(init=False, factory=lambda: logging.getLogger(__name__))
-    _subscriptions: dict[SubscriptionToken, Subscription] = attr.field(init=False, factory=dict)
-
-    def subscribe(self, callback: Callable[[Event], Any],
-                  event_types: Optional[Iterable[type[Event]]] = None) -> SubscriptionToken:
-        types = set(event_types) if event_types else None
-        token = SubscriptionToken(object())
-        subscription = Subscription(callback, types)
-        self._subscriptions[token] = subscription
-        return token
-
-    def unsubscribe(self, token: SubscriptionToken) -> None:
-        self._subscriptions.pop(token, None)
-
-    @abstractmethod
-    def publish(self, event: Event) -> None:
-        """Publish an event to all subscribers."""
-
-    def relay_events_from(self, source: abc.EventSource) -> SubscriptionToken:
-        return source.subscribe(self.publish)
-
-
-class EventHub(_BaseEventHub):
-    _executor: ThreadPoolExecutor
-
-    def __enter__(self) -> EventHub:
-        self._executor = ThreadPoolExecutor(1)
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        self._executor.shutdown(wait=exc_type is None)
-
-    def subscribe(self, callback: Callable[[Event], Any],
-                  event_types: Optional[Iterable[type[Event]]] = None) -> SubscriptionToken:
-        if iscoroutinefunction(callback):
-            raise ValueError('Coroutine functions are not supported as callbacks on a synchronous '
-                             'event source')
-
-        return super().subscribe(callback, event_types)
-
-    def publish(self, event: Event) -> None:
-        def deliver_event(func: Callable[[Event], Any]) -> None:
-            try:
-                func(event)
-            except BaseException:
-                self._logger.exception('Error delivering %s event', event.__class__.__name__)
-
-        event_type = type(event)
-        for subscription in list(self._subscriptions.values()):
-            if subscription.event_types is None or event_type in subscription.event_types:
-                self._executor.submit(deliver_event, subscription.callback)
-
-
-class AsyncEventHub(_BaseEventHub):
-    _task_group: TaskGroup
-
-    async def __aenter__(self) -> AsyncEventHub:
-        self._task_group = create_task_group()
-        await self._task_group.__aenter__()
-        return self
-
-    async def __aexit__(self, exc_type, exc_val, exc_tb):
-        await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
-        del self._task_group
-
-    def publish(self, event: Event) -> None:
-        async def deliver_event(func: Callable[[Event], Any]) -> None:
-            try:
-                retval = func(event)
-                if isawaitable(retval):
-                    await retval
-            except BaseException:
-                self._logger.exception('Error delivering %s event', event.__class__.__name__)
-
-        event_type = type(event)
-        for subscription in self._subscriptions.values():
-            if subscription.event_types is None or event_type in subscription.event_types:
-                self._task_group.start_soon(deliver_event, subscription.callback)
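The converters attached to the event fields are what let deserialized wire data (strings, lists) rebuild rich event objects without the hand-written object hook that was just deleted. For example (values invented):

```python
from uuid import UUID

from apscheduler.events import JobAdded

event = JobAdded(job_id='12345678-1234-5678-1234-567812345678', task_id='mytask',
                 schedule_id='sched1', tags=['demo'])
assert isinstance(event.job_id, UUID)     # as_uuid turned the string into a UUID
assert event.tags == frozenset({'demo'})  # lists off the wire become frozensets
```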
diff --git a/src/apscheduler/schedulers/async_.py b/src/apscheduler/schedulers/async_.py
index b60aed6..790baf4 100644
--- a/src/apscheduler/schedulers/async_.py
+++ b/src/apscheduler/schedulers/async_.py
@@ -16,9 +16,9 @@ from ..abc import AsyncDataStore, DataStore, EventSource, Job, Schedule, Trigger
 from ..datastores.async_adapter import AsyncDataStoreAdapter
 from ..datastores.memory import MemoryDataStore
 from ..enums import CoalescePolicy, ConflictPolicy, RunState
+from ..eventbrokers.async_local import LocalAsyncEventBroker
 from ..events import (
-    AsyncEventHub, Event, ScheduleAdded, SchedulerStarted, SchedulerStopped, ScheduleUpdated,
-    SubscriptionToken)
+    Event, ScheduleAdded, SchedulerStarted, SchedulerStopped, ScheduleUpdated, SubscriptionToken)
 from ..marshalling import callable_to_ref
 from ..structures import Task
 from ..workers.async_ import AsyncWorker
@@ -40,7 +40,7 @@ class AsyncScheduler(EventSource):
         self.logger = logger or getLogger(__name__)
         self.start_worker = start_worker
         self._exit_stack = AsyncExitStack()
-        self._events = AsyncEventHub()
+        self._events = LocalAsyncEventBroker()
 
         data_store = data_store or MemoryDataStore()
         if isinstance(data_store, DataStore):
@@ -60,13 +60,13 @@ class AsyncScheduler(EventSource):
 
         # Initialize the data store
         await self._exit_stack.enter_async_context(self.data_store)
-        relay_token = self._events.relay_events_from(self.data_store)
-        self._exit_stack.callback(self.data_store.unsubscribe, relay_token)
+        relay_token = self._events.relay_events_from(self.data_store.events)
+        self._exit_stack.callback(self.data_store.events.unsubscribe, relay_token)
 
         # Wake up the scheduler if the data store emits a significant schedule event
-        wakeup_token = self.data_store.subscribe(
+        wakeup_token = self.data_store.events.subscribe(
             lambda event: self._wakeup_event.set(), {ScheduleAdded, ScheduleUpdated})
-        self._exit_stack.callback(self.data_store.unsubscribe, wakeup_token)
+        self._exit_stack.callback(self.data_store.events.unsubscribe, wakeup_token)
 
         # Start the built-in worker, if configured to do so
         if self.start_worker:
@@ -132,7 +132,7 @@ class AsyncScheduler(EventSource):
         # Signal that the scheduler has started
         self._state = RunState.started
         task_status.started()
-        self._events.publish(SchedulerStarted())
+        await self._events.publish(SchedulerStarted())
 
         try:
             while self._state is RunState.started:
@@ -190,11 +190,11 @@ class AsyncScheduler(EventSource):
             pass
         except BaseException as exc:
             self._state = RunState.stopped
-            self._events.publish(SchedulerStopped(exception=exc))
+            await self._events.publish(SchedulerStopped(exception=exc))
             raise
 
         self._state = RunState.stopped
-        self._events.publish(SchedulerStopped())
+        await self._events.publish(SchedulerStopped())
 
 #     async def stop(self, force: bool = False) -> None:
 #         self._running = False
diff --git a/src/apscheduler/schedulers/sync.py b/src/apscheduler/schedulers/sync.py
index 1e5f133..5875ddc 100644
--- a/src/apscheduler/schedulers/sync.py
+++ b/src/apscheduler/schedulers/sync.py
@@ -13,9 +13,9 @@ from uuid import uuid4
 from ..abc import DataStore, EventSource, Trigger
 from ..datastores.memory import MemoryDataStore
 from ..enums import CoalescePolicy, ConflictPolicy, RunState
+from ..eventbrokers.local import LocalEventBroker
 from ..events import (
-    Event, EventHub, ScheduleAdded, SchedulerStarted, SchedulerStopped, ScheduleUpdated,
-    SubscriptionToken)
+    Event, ScheduleAdded, SchedulerStarted, SchedulerStopped, ScheduleUpdated, SubscriptionToken)
 from ..marshalling import callable_to_ref
 from ..structures import Job, Schedule, Task
 from ..workers.sync import Worker
@@ -36,7 +36,7 @@ class Scheduler(EventSource):
         self.data_store = data_store or MemoryDataStore()
         self._exit_stack = ExitStack()
         self._executor = ThreadPoolExecutor(max_workers=1)
-        self._events = EventHub()
+        self._events = LocalEventBroker()
 
     @property
     def state(self) -> RunState:
@@ -54,13 +54,13 @@ class Scheduler(EventSource):
 
         # Initialize the data store
         self._exit_stack.enter_context(self.data_store)
-        relay_token = self._events.relay_events_from(self.data_store)
-        self._exit_stack.callback(self.data_store.unsubscribe, relay_token)
+        relay_token = self._events.relay_events_from(self.data_store.events)
+        self._exit_stack.callback(self.data_store.events.unsubscribe, relay_token)
 
         # Wake up the scheduler if the data store emits a significant schedule event
-        wakeup_token = self.data_store.subscribe(
+        wakeup_token = self.data_store.events.subscribe(
             lambda event: self._wakeup_event.set(), {ScheduleAdded, ScheduleUpdated})
-        self._exit_stack.callback(self.data_store.unsubscribe, wakeup_token)
+        self._exit_stack.callback(self.data_store.events.unsubscribe, wakeup_token)
 
         # Start the built-in worker, if configured to do so
         if self.start_worker:
diff --git a/src/apscheduler/serializers/json.py b/src/apscheduler/serializers/json.py
index f7ef307..8bfe6d7 100644
--- a/src/apscheduler/serializers/json.py
+++ b/src/apscheduler/serializers/json.py
@@ -1,12 +1,13 @@
 from __future__ import annotations
 
+from datetime import datetime
 from json import dumps, loads
 from typing import Any
 
 import attr
 
 from ..abc import Serializer
-from ..marshalling import marshal_object, unmarshal_object
+from ..marshalling import marshal_date, marshal_object, unmarshal_object
 
 
 @attr.define(kw_only=True, eq=False)
@@ -23,6 +24,8 @@ class JSONSerializer(Serializer):
         if hasattr(obj, '__getstate__'):
             cls_ref, state = marshal_object(obj)
             return {self.magic_key: [cls_ref, state]}
+        elif isinstance(obj, datetime):
+            return marshal_date(obj)
 
         raise TypeError(f'Object of type {obj.__class__.__name__!r} is not JSON serializable')
diff --git a/src/apscheduler/validators.py b/src/apscheduler/validators.py
index ca12d84..c71f73a 100644
--- a/src/apscheduler/validators.py
+++ b/src/apscheduler/validators.py
@@ -4,6 +4,7 @@ import sys
 from datetime import date, datetime, timedelta, timezone, tzinfo
 from typing import Any, Optional
 
+from attr import Attribute
 from tzlocal import get_localzone
 
 from .abc import Trigger
@@ -146,6 +147,11 @@ def as_list(value, element_type: type, name: str) -> list:
     return value
 
 
+def aware_datetime(instance: Any, attribute: Attribute, value: datetime) -> None:
+    if not value.tzinfo:
+        raise ValueError(f'{attribute.name} must be a timezone aware datetime')
+
+
 def require_state_version(trigger: Trigger, state: dict[str, Any], max_version: int) -> None:
     try:
         if state['version'] > max_version:
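The new aware_datetime validator in use, on a hypothetical attrs model (only the validator itself comes from this commit):

```python
from datetime import datetime, timezone

import attr

from apscheduler.validators import aware_datetime

@attr.define
class Appointment:
    when: datetime = attr.field(validator=aware_datetime)

Appointment(when=datetime.now(timezone.utc))   # passes
# Appointment(when=datetime.now()) would raise ValueError
```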
diff --git a/src/apscheduler/workers/async_.py b/src/apscheduler/workers/async_.py
index e9bcc5f..0893a98 100644
--- a/src/apscheduler/workers/async_.py
+++ b/src/apscheduler/workers/async_.py
@@ -10,15 +10,16 @@ from typing import Any, Callable, Iterable, Optional
 from uuid import UUID
 
 import anyio
-from anyio import TASK_STATUS_IGNORED, create_task_group, get_cancelled_exc_class
+from anyio import TASK_STATUS_IGNORED, create_task_group, get_cancelled_exc_class, move_on_after
 from anyio.abc import CancelScope
 
 from ..abc import AsyncDataStore, DataStore, EventSource, Job
 from ..datastores.async_adapter import AsyncDataStoreAdapter
 from ..enums import JobOutcome, RunState
+from ..eventbrokers.async_local import LocalAsyncEventBroker
 from ..events import (
-    AsyncEventHub, Event, JobAdded, JobCancelled, JobCompleted, JobDeadlineMissed, JobFailed,
-    JobStarted, SubscriptionToken, WorkerStarted, WorkerStopped)
+    Event, JobAdded, JobCancelled, JobCompleted, JobDeadlineMissed, JobFailed, JobStarted,
+    SubscriptionToken, WorkerStarted, WorkerStopped)
 from ..structures import JobResult
 
 
@@ -39,7 +40,7 @@ class AsyncWorker(EventSource):
         self.logger = logger or getLogger(__name__)
         self._acquired_jobs: set[Job] = set()
         self._exit_stack = AsyncExitStack()
-        self._events = AsyncEventHub()
+        self._events = LocalAsyncEventBroker()
         self._running_jobs: set[UUID] = set()
 
         if self.max_concurrent_jobs < 1:
@@ -62,13 +63,13 @@ class AsyncWorker(EventSource):
 
         # Initialize the data store
         await self._exit_stack.enter_async_context(self.data_store)
-        relay_token = self._events.relay_events_from(self.data_store)
-        self._exit_stack.callback(self.data_store.unsubscribe, relay_token)
+        relay_token = self._events.relay_events_from(self.data_store.events)
+        self._exit_stack.callback(self.data_store.events.unsubscribe, relay_token)
 
         # Wake up the worker if the data store emits a significant job event
-        wakeup_token = self.data_store.subscribe(
+        wakeup_token = self.data_store.events.subscribe(
             lambda event: self._wakeup_event.set(), {JobAdded})
-        self._exit_stack.callback(self.data_store.unsubscribe, wakeup_token)
+        self._exit_stack.callback(self.data_store.events.unsubscribe, wakeup_token)
 
         # Start the actual worker
         task_group = create_task_group()
@@ -97,7 +98,7 @@ class AsyncWorker(EventSource):
         # Signal that the worker has started
         self._state = RunState.started
         task_status.started()
-        self._events.publish(WorkerStarted())
+        await self._events.publish(WorkerStarted())
 
         try:
             async with create_task_group() as tg:
@@ -115,21 +116,23 @@ class AsyncWorker(EventSource):
             pass
         except BaseException as exc:
             self._state = RunState.stopped
-            self._events.publish(WorkerStopped(exception=exc))
+            with move_on_after(1, shield=True):
+                await self._events.publish(WorkerStopped(exception=exc))
+
             raise
 
         self._state = RunState.stopped
-        self._events.publish(WorkerStopped())
+        await self._events.publish(WorkerStopped())
 
     async def _run_job(self, job: Job, func: Callable) -> None:
         try:
             # Check if the job started before the deadline
             start_time = datetime.now(timezone.utc)
             if job.start_deadline is not None and start_time > job.start_deadline:
-                self._events.publish(JobDeadlineMissed.from_job(job, start_time))
+                await self._events.publish(JobDeadlineMissed.from_job(job, start_time))
                 return
 
-            self._events.publish(JobStarted.from_job(job, start_time))
+            await self._events.publish(JobStarted.from_job(job, start_time))
             try:
                 retval = func(*job.args, **job.kwargs)
                 if isawaitable(retval):
@@ -139,17 +142,18 @@ class AsyncWorker(EventSource):
                 result = JobResult(job_id=job.id, outcome=JobOutcome.cancelled)
                 await self.data_store.release_job(self.identity, job.task_id, result)
-                self._events.publish(JobCancelled.from_job(job, start_time))
+                with move_on_after(1, shield=True):
+                    await self._events.publish(JobCancelled.from_job(job, start_time))
             except BaseException as exc:
                 result = JobResult(job_id=job.id, outcome=JobOutcome.failure, exception=exc)
                 await self.data_store.release_job(self.identity, job.task_id, result)
-                self._events.publish(JobFailed.from_exception(job, start_time, exc))
+                await self._events.publish(JobFailed.from_exception(job, start_time, exc))
                 if not isinstance(exc, Exception):
                     raise
             else:
                 result = JobResult(job_id=job.id, outcome=JobOutcome.success, return_value=retval)
                 await self.data_store.release_job(self.identity, job.task_id, result)
-                self._events.publish(JobCompleted.from_retval(job, start_time, retval))
+                await self._events.publish(JobCompleted.from_retval(job, start_time, retval))
         finally:
             self._running_jobs.remove(job.id)
diff --git a/src/apscheduler/workers/sync.py b/src/apscheduler/workers/sync.py
index b7803c7..30db6bf 100644
--- a/src/apscheduler/workers/sync.py
+++ b/src/apscheduler/workers/sync.py
@@ -13,8 +13,9 @@ from uuid import UUID
 
diff --git a/src/apscheduler/workers/sync.py b/src/apscheduler/workers/sync.py
index b7803c7..30db6bf 100644
--- a/src/apscheduler/workers/sync.py
+++ b/src/apscheduler/workers/sync.py
@@ -13,8 +13,9 @@ from uuid import UUID
 
 from .. import events
 from ..abc import DataStore, EventSource
 from ..enums import JobOutcome, RunState
+from ..eventbrokers.local import LocalEventBroker
 from ..events import (
-    EventHub, JobAdded, JobCompleted, JobDeadlineMissed, JobFailed, JobStarted, SubscriptionToken,
+    JobAdded, JobCompleted, JobDeadlineMissed, JobFailed, JobStarted, SubscriptionToken,
     WorkerStarted, WorkerStopped)
 from ..structures import Job, JobResult
 
 
@@ -33,7 +34,7 @@
         self.logger = logger or getLogger(__name__)
         self._acquired_jobs: set[Job] = set()
         self._exit_stack = ExitStack()
-        self._events = EventHub()
+        self._events = LocalEventBroker()
         self._running_jobs: set[UUID] = set()
 
         if self.max_concurrent_jobs < 1:
@@ -53,13 +54,13 @@
 
         # Initialize the data store
        self._exit_stack.enter_context(self.data_store)
-        relay_token = self._events.relay_events_from(self.data_store)
-        self._exit_stack.callback(self.data_store.unsubscribe, relay_token)
+        relay_token = self._events.relay_events_from(self.data_store.events)
+        self._exit_stack.callback(self.data_store.events.unsubscribe, relay_token)
 
         # Wake up the worker if the data store emits a significant job event
-        wakeup_token = self.data_store.subscribe(
+        wakeup_token = self.data_store.events.subscribe(
             lambda event: self._wakeup_event.set(), {JobAdded})
-        self._exit_stack.callback(self.data_store.unsubscribe, wakeup_token)
+        self._exit_stack.callback(self.data_store.events.unsubscribe, wakeup_token)
 
         # Start the worker and return when it has signalled readiness or raised an exception
         start_future: Future[None] = Future()
diff --git a/tests/conftest.py b/tests/conftest.py
index bc5bab2..135c18d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -24,7 +24,6 @@ def timezone() -> ZoneInfo:
 
 
 @pytest.fixture(params=[
-    pytest.param(None, id='none'),
     pytest.param(PickleSerializer, id='pickle'),
     pytest.param(CBORSerializer, id='cbor'),
     pytest.param(JSONSerializer, id='json')
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index 52f8349..74db6e7 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -40,11 +40,11 @@ async def capture_events(
             events.append(event)
             if len(events) == limit:
                 limit_event.set()
-                store.unsubscribe(token)
+                store.events.unsubscribe(token)
 
     events: List[Event] = []
     limit_event = anyio.Event()
-    token = store.subscribe(listener, event_types)
+    token = store.events.subscribe(listener, event_types)
     yield events
     if limit:
         with anyio.fail_after(3):
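The test change above exercises the key interface shift: subscriptions now go through the data store's events property instead of the store itself. In application code the pattern looks roughly like this (a sketch only; the MemoryDataStore import path is an assumption, as that class does not appear in this diff):

    from apscheduler.datastores.memory import MemoryDataStore  # assumed import path
    from apscheduler.events import ScheduleAdded


    def on_schedule_added(event: ScheduleAdded) -> None:
        print('schedule added:', event.schedule_id)


    with MemoryDataStore() as store:
        # per this commit, the store exposes its event source as a property
        token = store.events.subscribe(on_schedule_added, {ScheduleAdded})
        # ... add schedules, run a scheduler or worker against the store ...
        store.events.unsubscribe(token)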
diff --git a/tests/test_eventbrokers.py b/tests/test_eventbrokers.py
new file mode 100644
index 0000000..6a0f45d
--- /dev/null
+++ b/tests/test_eventbrokers.py
@@ -0,0 +1,279 @@
+from concurrent.futures import Future
+from datetime import datetime, timezone
+from queue import Empty, Queue
+from typing import Callable
+
+import pytest
+from _pytest.fixtures import FixtureRequest
+from _pytest.logging import LogCaptureFixture
+from anyio import create_memory_object_stream, fail_after
+
+from apscheduler.abc import AsyncEventBroker, EventBroker, Serializer
+from apscheduler.events import Event, ScheduleAdded
+
+
+@pytest.fixture
+def local_broker() -> EventBroker:
+    from apscheduler.eventbrokers.local import LocalEventBroker
+
+    return LocalEventBroker()
+
+
+@pytest.fixture
+def local_async_broker() -> AsyncEventBroker:
+    from apscheduler.eventbrokers.async_local import LocalAsyncEventBroker
+
+    return LocalAsyncEventBroker()
+
+
+@pytest.fixture
+def redis_broker(serializer: Serializer) -> EventBroker:
+    from apscheduler.eventbrokers.redis import RedisEventBroker
+
+    broker = RedisEventBroker.from_url('redis://localhost:6379')
+    broker.serializer = serializer
+    return broker
+
+
+@pytest.fixture
+def mqtt_broker(serializer: Serializer) -> EventBroker:
+    from paho.mqtt.client import Client
+
+    from apscheduler.eventbrokers.mqtt import MQTTEventBroker
+
+    return MQTTEventBroker(Client(), serializer=serializer)
+
+
+@pytest.fixture
+async def asyncpg_broker(serializer: Serializer) -> AsyncEventBroker:
+    from asyncpg import create_pool
+
+    from apscheduler.eventbrokers.asyncpg import AsyncpgEventBroker
+
+    pool = await create_pool('postgres://postgres:secret@localhost:5432/testdb')
+    broker = AsyncpgEventBroker.from_asyncpg_pool(pool)
+    broker.serializer = serializer
+    yield broker
+    await pool.close()
+
+
+@pytest.fixture(params=[
+    pytest.param(pytest.lazy_fixture('local_broker'), id='local'),
+    pytest.param(pytest.lazy_fixture('redis_broker'), id='redis'),
+    pytest.param(pytest.lazy_fixture('mqtt_broker'), id='mqtt')
+])
+def broker(request: FixtureRequest) -> EventBroker:
+    return request.param
+
+
+@pytest.fixture(params=[
+    pytest.param(pytest.lazy_fixture('local_async_broker'), id='local'),
+    pytest.param(pytest.lazy_fixture('asyncpg_broker'), id='asyncpg')
+])
+def async_broker(request: FixtureRequest) -> AsyncEventBroker:
+    return request.param
+
+
+class TestEventBroker:
+    def test_publish_subscribe(self, broker: EventBroker) -> None:
+        def subscriber1(event) -> None:
+            queue.put_nowait(event)
+
+        def subscriber2(event) -> None:
+            queue.put_nowait(event)
+
+        queue = Queue()
+        with broker:
+            broker.subscribe(subscriber1)
+            broker.subscribe(subscriber2)
+            event = ScheduleAdded(
+                schedule_id='schedule1',
+                next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc))
+            broker.publish(event)
+            event1 = queue.get(timeout=3)
+            event2 = queue.get(timeout=1)
+
+        assert event1 == event2
+        assert isinstance(event1, ScheduleAdded)
+        assert isinstance(event1.timestamp, datetime)
+        assert event1.schedule_id == 'schedule1'
+        assert event1.next_fire_time == datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc)
+
+    def test_unsubscribe(self, broker: EventBroker, caplog) -> None:
+        queue = Queue()
+        with broker:
+            token = broker.subscribe(queue.put_nowait)
+            broker.publish(Event())
+            queue.get(timeout=3)
+
+            broker.unsubscribe(token)
+            broker.publish(Event())
+            with pytest.raises(Empty):
+                queue.get(timeout=0.1)
+
+    def test_publish_no_subscribers(self, broker: EventBroker, caplog: LogCaptureFixture) -> None:
+        with broker:
+            broker.publish(Event())
+
+        assert not caplog.text
+
+    def test_publish_exception(self, broker: EventBroker, caplog: LogCaptureFixture) -> None:
+        def bad_subscriber(event: Event) -> None:
+            raise Exception('foo')
+
+        timestamp = datetime.now(timezone.utc)
+        event_future: Future[Event] = Future()
+        with broker:
+            broker.subscribe(bad_subscriber)
+            broker.subscribe(event_future.set_result)
+            broker.publish(Event(timestamp=timestamp))
+
+        event = event_future.result(3)
+        assert isinstance(event, Event)
+        assert event.timestamp == timestamp
+        assert 'Error delivering Event' in caplog.text
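Outside the test suite, the synchronous broker API reads the same way as a plain context manager; a minimal sketch using only calls exercised by the tests above:

    from apscheduler.eventbrokers.local import LocalEventBroker
    from apscheduler.events import Event

    with LocalEventBroker() as broker:
        token = broker.subscribe(lambda event: print('received:', event))
        broker.publish(Event())  # delivered to the subscriber, as in the test above
        broker.unsubscribe(token)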
+@pytest.mark.anyio
+class TestAsyncEventBroker:
+    async def test_publish_subscribe(self, async_broker: AsyncEventBroker) -> None:
+        def subscriber1(event) -> None:
+            send.send_nowait(event)
+
+        async def subscriber2(event) -> None:
+            await send.send(event)
+
+        send, receive = create_memory_object_stream(2)
+        async with async_broker:
+            async_broker.subscribe(subscriber1)
+            async_broker.subscribe(subscriber2)
+            event = ScheduleAdded(
+                schedule_id='schedule1',
+                next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc))
+            await async_broker.publish(event)
+
+            with fail_after(3):
+                event1 = await receive.receive()
+                event2 = await receive.receive()
+
+        assert event1 == event2
+        assert isinstance(event1, ScheduleAdded)
+        assert isinstance(event1.timestamp, datetime)
+        assert event1.schedule_id == 'schedule1'
+        assert event1.next_fire_time == datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc)
+
+    async def test_unsubscribe(self, async_broker: AsyncEventBroker) -> None:
+        send, receive = create_memory_object_stream()
+        async with async_broker:
+            token = async_broker.subscribe(send.send)
+            await async_broker.publish(Event())
+            with fail_after(3):
+                await receive.receive()
+
+            async_broker.unsubscribe(token)
+            await async_broker.publish(Event())
+            with pytest.raises(TimeoutError), fail_after(0.1):
+                await receive.receive()
+
+    async def test_publish_no_subscribers(self, async_broker: AsyncEventBroker,
+                                          caplog: LogCaptureFixture) -> None:
+        async with async_broker:
+            await async_broker.publish(Event())
+
+        assert not caplog.text
+
+    async def test_publish_exception(self, async_broker: AsyncEventBroker,
+                                     caplog: LogCaptureFixture) -> None:
+        def bad_subscriber(event: Event) -> None:
+            raise Exception('foo')
+
+        timestamp = datetime.now(timezone.utc)
+        events = []
+        async with async_broker:
+            async_broker.subscribe(bad_subscriber)
+            async_broker.subscribe(events.append)
+            await async_broker.publish(Event(timestamp=timestamp))
+
+        assert isinstance(events[0], Event)
+        assert events[0].timestamp == timestamp
+        assert 'Error delivering Event' in caplog.text
+#
+# def test_subscribe_coroutine_callback(self) -> None:
+#     async def callback(event: Event) -> None:
+#         pass
+#
+#     with EventBroker() as eventhub:
+#         with pytest.raises(ValueError, match='Coroutine functions are not supported'):
+#             eventhub.subscribe(callback)
+#
+# def test_relay_events(self) -> None:
+#     timestamp = datetime.now(timezone.utc)
+#     events = []
+#     with EventBroker() as eventhub1, EventBroker() as eventhub2:
+#         eventhub2.relay_events_from(eventhub1)
+#         eventhub2.subscribe(events.append)
+#         eventhub1.publish(Event(timestamp=timestamp))
+#
+#     assert isinstance(events[0], Event)
+#     assert events[0].timestamp == timestamp
+#
+#
+# @pytest.mark.anyio
+# class TestAsyncEventHub:
+#     async def test_publish(self) -> None:
+#         async def async_setitem(event: Event) -> None:
+#             events[1] = event
+#
+#         timestamp = datetime.now(timezone.utc)
+#         events: List[Optional[Event]] = [None, None]
+#         async with AsyncEventBroker() as eventhub:
+#             eventhub.subscribe(partial(setitem, events, 0))
+#             eventhub.subscribe(async_setitem)
+#             eventhub.publish(Event(timestamp=timestamp))
+#
+#         assert events[0] is events[1]
+#         assert isinstance(events[0], Event)
+#         assert events[0].timestamp == timestamp
+#
+#     async def test_unsubscribe(self) -> None:
+#         timestamp = datetime.now(timezone.utc)
+#         events = []
+#         async with AsyncEventBroker() as eventhub:
+#             token = eventhub.subscribe(events.append)
+#             eventhub.publish(Event(timestamp=timestamp))
+#             eventhub.unsubscribe(token)
+#             eventhub.publish(Event(timestamp=timestamp))
+#
+#         assert len(events) == 1
+#
+#     async def test_publish_no_subscribers(self, caplog: LogCaptureFixture) -> None:
+#         async with AsyncEventBroker() as eventhub:
+#             eventhub.publish(Event(timestamp=datetime.now(timezone.utc)))
+#
+#         assert not caplog.text
+#
+#     async def test_publish_exception(self, caplog: LogCaptureFixture) -> None:
+#         def bad_subscriber(event: Event) -> None:
+#             raise Exception('foo')
+#
+#         timestamp = datetime.now(timezone.utc)
+#         events = []
+#         async with AsyncEventBroker() as eventhub:
+#             eventhub.subscribe(bad_subscriber)
+#             eventhub.subscribe(events.append)
+#             eventhub.publish(Event(timestamp=timestamp))
+#
+#         assert isinstance(events[0], Event)
+#         assert events[0].timestamp == timestamp
+#         assert 'Error delivering Event' in caplog.text
+#
+#     async def test_relay_events(self) -> None:
+#         timestamp = datetime.now(timezone.utc)
+#         events = []
+#         async with AsyncEventBroker() as eventhub1, AsyncEventBroker() as eventhub2:
+#             eventhub1.relay_events_from(eventhub2)
+#             eventhub1.subscribe(events.append)
+#             eventhub2.publish(Event(timestamp=timestamp))
+#
+#         assert isinstance(events[0], Event)
+#         assert events[0].timestamp == timestamp
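The asynchronous counterpart is used the same way under an async context manager; a minimal sketch with anyio, mirroring the tests above:

    import anyio

    from apscheduler.eventbrokers.async_local import LocalAsyncEventBroker
    from apscheduler.events import Event


    async def main() -> None:
        async with LocalAsyncEventBroker() as broker:
            broker.subscribe(lambda event: print('received:', event))
            await broker.publish(Event())  # publish is a coroutine on the async API


    anyio.run(main)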
diff --git a/tests/test_events.py b/tests/test_events.py
deleted file mode 100644
index bbe344f..0000000
--- a/tests/test_events.py
+++ /dev/null
@@ -1,135 +0,0 @@
-from datetime import datetime, timezone
-from functools import partial
-from operator import setitem
-from typing import List, Optional
-
-import pytest
-from _pytest.logging import LogCaptureFixture
-
-from apscheduler.events import AsyncEventHub, Event, EventHub
-
-
-class TestEventHub:
-    def test_publish(self) -> None:
-        timestamp = datetime.now(timezone.utc)
-        events: List[Optional[Event]] = [None, None]
-        with EventHub() as eventhub:
-            eventhub.subscribe(partial(setitem, events, 0))
-            eventhub.subscribe(partial(setitem, events, 1))
-            eventhub.publish(Event(timestamp=timestamp))
-
-        assert events[0] is events[1]
-        assert isinstance(events[0], Event)
-        assert events[0].timestamp == timestamp
-
-    def test_unsubscribe(self) -> None:
-        timestamp = datetime.now(timezone.utc)
-        events = []
-        with EventHub() as eventhub:
-            token = eventhub.subscribe(events.append)
-            eventhub.publish(Event(timestamp=timestamp))
-            eventhub.unsubscribe(token)
-            eventhub.publish(Event(timestamp=timestamp))
-
-        assert len(events) == 1
-
-    def test_publish_no_subscribers(self, caplog: LogCaptureFixture) -> None:
-        with EventHub() as eventhub:
-            eventhub.publish(Event(timestamp=datetime.now(timezone.utc)))
-
-        assert not caplog.text
-
-    def test_publish_exception(self, caplog: LogCaptureFixture) -> None:
-        def bad_subscriber(event: Event) -> None:
-            raise Exception('foo')
-
-        timestamp = datetime.now(timezone.utc)
-        events = []
-        with EventHub() as eventhub:
-            eventhub.subscribe(bad_subscriber)
-            eventhub.subscribe(events.append)
-            eventhub.publish(Event(timestamp=timestamp))
-
-        assert isinstance(events[0], Event)
-        assert events[0].timestamp == timestamp
-        assert 'Error delivering Event' in caplog.text
-
-    def test_subscribe_coroutine_callback(self) -> None:
-        async def callback(event: Event) -> None:
-            pass
-
-        with EventHub() as eventhub:
-            with pytest.raises(ValueError, match='Coroutine functions are not supported'):
-                eventhub.subscribe(callback)
-
-    def test_relay_events(self) -> None:
-        timestamp = datetime.now(timezone.utc)
-        events = []
-        with EventHub() as eventhub1, EventHub() as eventhub2:
-            eventhub2.relay_events_from(eventhub1)
-            eventhub2.subscribe(events.append)
-            eventhub1.publish(Event(timestamp=timestamp))
-
-        assert isinstance(events[0], Event)
-        assert events[0].timestamp == timestamp
-
-
-@pytest.mark.anyio
-class TestAsyncEventHub:
-    async def test_publish(self) -> None:
-        async def async_setitem(event: Event) -> None:
-            events[1] = event
-
-        timestamp = datetime.now(timezone.utc)
-        events: List[Optional[Event]] = [None, None]
-        async with AsyncEventHub() as eventhub:
-            eventhub.subscribe(partial(setitem, events, 0))
-            eventhub.subscribe(async_setitem)
-            eventhub.publish(Event(timestamp=timestamp))
-
-        assert events[0] is events[1]
-        assert isinstance(events[0], Event)
-        assert events[0].timestamp == timestamp
-
-    async def test_unsubscribe(self) -> None:
-        timestamp = datetime.now(timezone.utc)
-        events = []
-        async with AsyncEventHub() as eventhub:
-            token = eventhub.subscribe(events.append)
-            eventhub.publish(Event(timestamp=timestamp))
-            eventhub.unsubscribe(token)
-            eventhub.publish(Event(timestamp=timestamp))
-
-        assert len(events) == 1
-
-    async def test_publish_no_subscribers(self, caplog: LogCaptureFixture) -> None:
-        async with AsyncEventHub() as eventhub:
-            eventhub.publish(Event(timestamp=datetime.now(timezone.utc)))
-
-        assert not caplog.text
-
-    async def test_publish_exception(self, caplog: LogCaptureFixture) -> None:
-        def bad_subscriber(event: Event) -> None:
-            raise Exception('foo')
-
-        timestamp = datetime.now(timezone.utc)
-        events = []
-        async with AsyncEventHub() as eventhub:
-            eventhub.subscribe(bad_subscriber)
-            eventhub.subscribe(events.append)
-            eventhub.publish(Event(timestamp=timestamp))
-
-        assert isinstance(events[0], Event)
-        assert events[0].timestamp == timestamp
-        assert 'Error delivering Event' in caplog.text
-
-    async def test_relay_events(self) -> None:
-        timestamp = datetime.now(timezone.utc)
-        events = []
-        async with AsyncEventHub() as eventhub1, AsyncEventHub() as eventhub2:
-            eventhub1.relay_events_from(eventhub2)
-            eventhub1.subscribe(events.append)
-            eventhub2.publish(Event(timestamp=timestamp))
-
-        assert isinstance(events[0], Event)
-        assert events[0].timestamp == timestamp
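Taken together, the pluggable brokers introduced here share a single publish/subscribe surface, so swapping the local broker for a distributed one is a configuration change. A closing sketch using the Redis broker, with the connection URL taken from the test fixture and assuming (as the parametrized tests above do) that locally published events are relayed back to local subscribers:

    from apscheduler.eventbrokers.redis import RedisEventBroker
    from apscheduler.events import Event
    from apscheduler.serializers.json import JSONSerializer

    broker = RedisEventBroker.from_url('redis://localhost:6379')  # URL as in the test fixture
    broker.serializer = JSONSerializer()
    with broker:
        broker.subscribe(lambda event: print('received:', event))
        broker.publish(Event())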