summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Grönholm <alex.gronholm@nextday.fi>2021-09-11 21:14:14 +0300
committerAlex Grönholm <alex.gronholm@nextday.fi>2021-09-11 21:14:14 +0300
commit56afe91d5dc338db3440b2e9ecdea3e522dba30f (patch)
tree311380b0d953f09919d7e8c4c0a340507e5d0dc5
parent7248a78e7e787b728b083aaa8199eeba3a3f3023 (diff)
downloadapscheduler-56afe91d5dc338db3440b2e9ecdea3e522dba30f.tar.gz
Implemented a pluggable event broker system
-rw-r--r--docker-compose.yml12
-rw-r--r--mosquitto.conf3
-rw-r--r--setup.cfg9
-rw-r--r--src/apscheduler/abc.py52
-rw-r--r--src/apscheduler/converters.py25
-rw-r--r--src/apscheduler/datastores/async_adapter.py25
-rw-r--r--src/apscheduler/datastores/async_sqlalchemy.py170
-rw-r--r--src/apscheduler/datastores/memory.py20
-rw-r--r--src/apscheduler/datastores/mongodb.py14
-rw-r--r--src/apscheduler/datastores/sqlalchemy.py14
-rw-r--r--src/apscheduler/eventbrokers/__init__.py0
-rw-r--r--src/apscheduler/eventbrokers/async_adapter.py44
-rw-r--r--src/apscheduler/eventbrokers/async_local.py44
-rw-r--r--src/apscheduler/eventbrokers/asyncpg.py89
-rw-r--r--src/apscheduler/eventbrokers/base.py84
-rw-r--r--src/apscheduler/eventbrokers/local.py47
-rw-r--r--src/apscheduler/eventbrokers/mqtt.py67
-rw-r--r--src/apscheduler/eventbrokers/redis.py81
-rw-r--r--src/apscheduler/events.py126
-rw-r--r--src/apscheduler/schedulers/async_.py20
-rw-r--r--src/apscheduler/schedulers/sync.py14
-rw-r--r--src/apscheduler/serializers/json.py5
-rw-r--r--src/apscheduler/validators.py6
-rw-r--r--src/apscheduler/workers/async_.py36
-rw-r--r--src/apscheduler/workers/sync.py13
-rw-r--r--tests/conftest.py1
-rw-r--r--tests/test_datastores.py4
-rw-r--r--tests/test_eventbrokers.py279
-rw-r--r--tests/test_events.py135
29 files changed, 983 insertions, 456 deletions
diff --git a/docker-compose.yml b/docker-compose.yml
index 89fd92b..066a411 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -20,3 +20,15 @@ services:
image: mongo
ports:
- 127.0.0.1:27017:27017
+
+ mosquitto:
+ image: eclipse-mosquitto:2
+ volumes:
+ - ./mosquitto.conf:/mosquitto/config/mosquitto.conf:ro
+ ports:
+ - 127.0.0.1:1883:1883
+
+ redis:
+ image: redis:6
+ ports:
+ - 127.0.0.1:6379:6379
diff --git a/mosquitto.conf b/mosquitto.conf
new file mode 100644
index 0000000..7c6f5be
--- /dev/null
+++ b/mosquitto.conf
@@ -0,0 +1,3 @@
+# Required configuration for the mosquitto server in docker-compose.yml
+listener 1883
+allow_anonymous true
diff --git a/setup.cfg b/setup.cfg
index 335ed42..9341115 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -26,30 +26,33 @@ python_requires = >= 3.7
install_requires =
anyio ~= 3.0
attrs >= 20.1
- backports.zoneinfo; python_version < '3.9'
- tzdata; platform_system == "Windows"
tzlocal >= 3.0
[options.packages.find]
where = src
[options.extras_require]
+asyncpg = asyncpg >= 0.20
cbor = cbor2 >= 5.0
mongodb = pymongo >= 3.12
-postgresql = asyncpg >= 0.20
+mqtt = paho-mqtt >= 1.5
+redis = redis >= 3.5
sqlalchemy = sqlalchemy >= 1.4.22
test =
asyncpg >= 0.20
cbor2 >= 5.0
coverage
freezegun
+ paho-mqtt >= 1.5
psycopg2
pymongo >= 3.12
pymysql[rsa]
pytest >= 5.0
pytest-cov
pytest-freezegun
+ pytest-lazy-fixture
pytest-mock
+ redis >= 3.5
sqlalchemy >= 1.4.22
trio
doc =
diff --git a/src/apscheduler/abc.py b/src/apscheduler/abc.py
index 97fec87..9751293 100644
--- a/src/apscheduler/abc.py
+++ b/src/apscheduler/abc.py
@@ -65,6 +65,8 @@ class Serializer(metaclass=ABCMeta):
class EventSource(metaclass=ABCMeta):
+ """Interface for objects that can deliver notifications to interested subscribers."""
+
@abstractmethod
def subscribe(
self, callback: Callable[[events.Event], Any],
@@ -86,7 +88,13 @@ class EventSource(metaclass=ABCMeta):
"""
-class DataStore(EventSource):
+class EventBroker(EventSource):
+ """
+ Interface for objects that can be used to publish notifications to interested subscribers.
+
+ Can be used as a context manager.
+ """
+
def __enter__(self):
return self
@@ -94,6 +102,41 @@ class DataStore(EventSource):
pass
@abstractmethod
+ def publish(self, event: events.Event) -> None:
+ """Publish an event."""
+
+
+class AsyncEventBroker(EventSource):
+ """
+ Asynchronous version of :class:`EventBroker`.
+
+ Can be used as an asynchronous context manager.
+ """
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ pass
+
+ @abstractmethod
+ async def publish(self, event: events.Event) -> None:
+ """Publish an event."""
+
+
+class DataStore:
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ pass
+
+ @property
+ @abstractmethod
+ def events(self) -> EventSource:
+ pass
+
+ @abstractmethod
def add_task(self, task: Task) -> None:
"""
Add the given task to the store.
@@ -239,13 +282,18 @@ class DataStore(EventSource):
"""
-class AsyncDataStore(EventSource):
+class AsyncDataStore:
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
pass
+ @property
+ @abstractmethod
+ def events(self) -> EventSource:
+ pass
+
@abstractmethod
async def add_task(self, task: Task) -> None:
"""
diff --git a/src/apscheduler/converters.py b/src/apscheduler/converters.py
new file mode 100644
index 0000000..7e8e590
--- /dev/null
+++ b/src/apscheduler/converters.py
@@ -0,0 +1,25 @@
+from __future__ import annotations
+
+from datetime import datetime
+from typing import Optional
+from uuid import UUID
+
+
+def as_aware_datetime(value: datetime | str) -> Optional[datetime]:
+ """Convert the value from a string to a timezone aware datetime."""
+ if isinstance(value, str):
+ # fromisoformat() does not handle the "Z" suffix
+ if value.upper().endswith('Z'):
+ value = value[:-1] + '+00:00'
+
+ value = datetime.fromisoformat(value)
+
+ return value
+
+
+def as_uuid(value: UUID | str) -> UUID:
+ """Converts a string-formatted UUID to a UUID instance."""
+ if isinstance(value, str):
+ return UUID(value)
+
+ return value
diff --git a/src/apscheduler/datastores/async_adapter.py b/src/apscheduler/datastores/async_adapter.py
index 89f268a..a5cf6a2 100644
--- a/src/apscheduler/datastores/async_adapter.py
+++ b/src/apscheduler/datastores/async_adapter.py
@@ -1,18 +1,16 @@
from __future__ import annotations
from datetime import datetime
-from functools import partial
-from typing import Any, Callable, Iterable, Optional
+from typing import Iterable, Optional
from uuid import UUID
import attr
from anyio import to_thread
from anyio.from_thread import BlockingPortal
-from .. import events
-from ..abc import AsyncDataStore, DataStore
+from ..abc import AsyncDataStore, AsyncEventBroker, DataStore, EventSource
from ..enums import ConflictPolicy
-from ..events import Event, SubscriptionToken
+from ..eventbrokers.async_adapter import AsyncEventBrokerAdapter
from ..structures import Job, JobResult, Schedule, Task
from ..util import reentrant
@@ -21,16 +19,24 @@ from ..util import reentrant
@attr.define(eq=False)
class AsyncDataStoreAdapter(AsyncDataStore):
original: DataStore
- _portal: BlockingPortal = attr.field(init=False, eq=False)
+ _portal: BlockingPortal = attr.field(init=False)
+ _events: AsyncEventBroker = attr.field(init=False)
+
+ @property
+ def events(self) -> EventSource:
+ return self._events
async def __aenter__(self) -> AsyncDataStoreAdapter:
self._portal = BlockingPortal()
await self._portal.__aenter__()
+ self._events = AsyncEventBrokerAdapter(self.original.events, self._portal)
+ await self._events.__aenter__()
await to_thread.run_sync(self.original.__enter__)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await to_thread.run_sync(self.original.__exit__, exc_type, exc_val, exc_tb)
+ await self._events.__aexit__(exc_type, exc_val, exc_tb)
await self._portal.__aexit__(exc_type, exc_val, exc_tb)
async def add_task(self, task: Task) -> None:
@@ -77,10 +83,3 @@ class AsyncDataStoreAdapter(AsyncDataStore):
async def get_job_result(self, job_id: UUID) -> Optional[JobResult]:
return await to_thread.run_sync(self.original.get_job_result, job_id)
-
- def subscribe(self, callback: Callable[[Event], Any],
- event_types: Optional[Iterable[type[Event]]] = None) -> SubscriptionToken:
- return self.original.subscribe(partial(self._portal.call, callback), event_types)
-
- def unsubscribe(self, token: events.SubscriptionToken) -> None:
- self.original.unsubscribe(token)
diff --git a/src/apscheduler/datastores/async_sqlalchemy.py b/src/apscheduler/datastores/async_sqlalchemy.py
index c0d21cb..a215d68 100644
--- a/src/apscheduler/datastores/async_sqlalchemy.py
+++ b/src/apscheduler/datastores/async_sqlalchemy.py
@@ -1,30 +1,25 @@
from __future__ import annotations
-import json
from collections import defaultdict
-from contextlib import AsyncExitStack, closing
from datetime import datetime, timedelta, timezone
-from json import JSONDecodeError
from typing import Any, Callable, Iterable, Optional
from uuid import UUID
import attr
import sniffio
-from anyio import TASK_STATUS_IGNORED, create_task_group, sleep
-from attr import asdict
-from sqlalchemy import and_, bindparam, func, or_, select
+from sqlalchemy import and_, bindparam, or_, select
from sqlalchemy.engine import URL, Result
from sqlalchemy.exc import IntegrityError
-from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine
+from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.ext.asyncio.engine import AsyncEngine
from sqlalchemy.sql.ddl import DropTable
from sqlalchemy.sql.elements import BindParameter
-from .. import events as events_module
-from ..abc import AsyncDataStore, Job, Schedule
+from ..abc import AsyncDataStore, AsyncEventBroker, EventSource, Job, Schedule
from ..enums import ConflictPolicy
+from ..eventbrokers.async_local import LocalAsyncEventBroker
from ..events import (
- AsyncEventHub, DataStoreEvent, Event, JobAdded, JobDeserializationFailed, ScheduleAdded,
+ DataStoreEvent, Event, JobAdded, JobDeserializationFailed, ScheduleAdded,
ScheduleDeserializationFailed, ScheduleRemoved, ScheduleUpdated, SubscriptionToken, TaskAdded,
TaskRemoved, TaskUpdated)
from ..exceptions import ConflictingIdError, SerializationError, TaskLookupError
@@ -34,43 +29,11 @@ from ..util import reentrant
from .sqlalchemy import _BaseSQLAlchemyDataStore
-def default_json_handler(obj: Any) -> Any:
- if isinstance(obj, datetime):
- return obj.timestamp()
- elif isinstance(obj, UUID):
- return obj.hex
- elif isinstance(obj, frozenset):
- return list(obj)
-
- raise TypeError(f'Cannot JSON encode type {type(obj)}')
-
-
-def json_object_hook(obj: dict[str, Any]) -> Any:
- for key, value in obj.items():
- if key == 'timestamp':
- obj[key] = datetime.fromtimestamp(value, timezone.utc)
- elif key == 'job_id':
- obj[key] = UUID(value)
- elif key == 'tags':
- obj[key] = frozenset(value)
-
- return obj
-
-
@reentrant
@attr.define(eq=False)
class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
engine: AsyncEngine
-
- _exit_stack: AsyncExitStack = attr.field(init=False, factory=AsyncExitStack)
- _events: AsyncEventHub = attr.field(init=False, factory=AsyncEventHub)
-
- def __attrs_post_init__(self) -> None:
- super().__attrs_post_init__()
-
- if self.notify_channel:
- if self.engine.dialect.name != 'postgresql' or self.engine.dialect.driver != 'asyncpg':
- self.notify_channel = None
+ _events: AsyncEventBroker = attr.field(factory=LocalAsyncEventBroker)
@classmethod
def from_url(cls, url: str | URL, **options) -> AsyncSQLAlchemyDataStore:
@@ -98,93 +61,44 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
raise RuntimeError(f'Unexpected schema version ({version}); '
f'only version 1 is supported by this version of APScheduler')
- await self._exit_stack.enter_async_context(self._events)
-
- if self.notify_channel:
- task_group = create_task_group()
- await self._exit_stack.enter_async_context(task_group)
- await task_group.start(self._listen_notifications)
- self._exit_stack.callback(task_group.cancel_scope.cancel)
-
+ await self.events.__aenter__()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
- await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
-
- async def _publish(self, conn: AsyncConnection, event: DataStoreEvent) -> None:
- if self.notify_channel:
- event_type = event.__class__.__name__
- event_data = json.dumps(asdict(event), ensure_ascii=False,
- default=default_json_handler)
- notification = event_type + ' ' + event_data
- if len(notification) < 8000:
- await conn.execute(func.pg_notify(self.notify_channel, notification))
- return
-
- self._logger.warning(
- 'Could not send %s notification because it is too long (%d >= 8000)',
- event_type, len(notification))
-
- self._events.publish(event)
-
- async def _listen_notifications(self, *, task_status=TASK_STATUS_IGNORED) -> None:
- def callback(connection, pid, channel: str, payload: str) -> None:
- self._logger.debug('Received notification on channel %s: %s', channel, payload)
- event_type, _, json_data = payload.partition(' ')
- try:
- event_data = json.loads(json_data, object_hook=json_object_hook)
- except JSONDecodeError:
- self._logger.exception('Failed decoding JSON payload of notification: %s', payload)
- return
-
- event_class = getattr(events_module, event_type)
- event = event_class(**event_data)
- self._events.publish(event)
-
- task_started_sent = False
- while True:
- with closing(await self.engine.raw_connection()) as conn:
- asyncpg_conn = conn.connection._connection
- await asyncpg_conn.add_listener(self.notify_channel, callback)
- if not task_started_sent:
- task_status.started()
- task_started_sent = True
-
- try:
- while True:
- await sleep(self.max_idle_time)
- await asyncpg_conn.execute('SELECT 1')
- finally:
- await asyncpg_conn.remove_listener(self.notify_channel, callback)
-
- def _deserialize_schedules(self, result: Result) -> list[Schedule]:
+ await self.events.__aexit__(exc_type, exc_val, exc_tb)
+
+ @property
+ def events(self) -> EventSource:
+ return self._events
+
+ async def _deserialize_schedules(self, result: Result) -> list[Schedule]:
schedules: list[Schedule] = []
for row in result:
try:
schedules.append(Schedule.unmarshal(self.serializer, row._asdict()))
except SerializationError as exc:
- self._events.publish(
+ await self._events.publish(
ScheduleDeserializationFailed(schedule_id=row['id'], exception=exc))
return schedules
- def _deserialize_jobs(self, result: Result) -> list[Job]:
+ async def _deserialize_jobs(self, result: Result) -> list[Job]:
jobs: list[Job] = []
for row in result:
try:
jobs.append(Job.unmarshal(self.serializer, row._asdict()))
except SerializationError as exc:
- self._events.publish(
+ await self._events.publish(
JobDeserializationFailed(job_id=row['id'], exception=exc))
return jobs
def subscribe(self, callback: Callable[[Event], Any],
event_types: Optional[Iterable[type[Event]]] = None) -> SubscriptionToken:
- return self._events.subscribe(callback, event_types)
+ return self.events.subscribe(callback, event_types)
def unsubscribe(self, token: SubscriptionToken) -> None:
- self._events.unsubscribe(token)
+ self.events.unsubscribe(token)
async def add_task(self, task: Task) -> None:
insert = self.t_tasks.insert().\
@@ -201,9 +115,10 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
where(self.t_tasks.c.id == task.id)
async with self.engine.begin() as conn:
await conn.execute(update)
- self._events.publish(TaskUpdated(task_id=task.id))
+
+ await self._events.publish(TaskUpdated(task_id=task.id))
else:
- self._events.publish(TaskAdded(task_id=task.id))
+ await self._events.publish(TaskAdded(task_id=task.id))
async def remove_task(self, task_id: str) -> None:
delete = self.t_tasks.delete().where(self.t_tasks.c.id == task_id)
@@ -212,7 +127,7 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
if result.rowcount == 0:
raise TaskLookupError(task_id)
else:
- self._events.publish(TaskRemoved(task_id=task_id))
+ await self._events.publish(TaskRemoved(task_id=task_id))
async def get_task(self, task_id: str) -> Task:
query = select([self.t_tasks.c.id, self.t_tasks.c.func, self.t_tasks.c.max_running_jobs,
@@ -243,9 +158,6 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
try:
async with self.engine.begin() as conn:
await conn.execute(insert)
- event = ScheduleAdded(schedule_id=schedule.id,
- next_fire_time=schedule.next_fire_time)
- await self._publish(conn, event)
except IntegrityError:
if conflict_policy is ConflictPolicy.exception:
raise ConflictingIdError(schedule.id) from None
@@ -257,9 +169,13 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
async with self.engine.begin() as conn:
await conn.execute(update)
- event = ScheduleUpdated(schedule_id=schedule.id,
- next_fire_time=schedule.next_fire_time)
- await self._publish(conn, event)
+ event = ScheduleUpdated(schedule_id=schedule.id,
+ next_fire_time=schedule.next_fire_time)
+ await self._events.publish(event)
+ else:
+ event = ScheduleAdded(schedule_id=schedule.id,
+ next_fire_time=schedule.next_fire_time)
+ await self._events.publish(event)
async def remove_schedules(self, ids: Iterable[str]) -> None:
async with self.engine.begin() as conn:
@@ -272,8 +188,8 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
await conn.execute(delete)
removed_ids = ids
- for schedule_id in removed_ids:
- await self._publish(conn, ScheduleRemoved(schedule_id=schedule_id))
+ for schedule_id in removed_ids:
+ await self._events.publish(ScheduleRemoved(schedule_id=schedule_id))
async def get_schedules(self, ids: Optional[set[str]] = None) -> list[Schedule]:
query = self.t_schedules.select().order_by(self.t_schedules.c.id)
@@ -282,7 +198,7 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
async with self.engine.begin() as conn:
result = await conn.execute(query)
- return self._deserialize_schedules(result)
+ return await self._deserialize_schedules(result)
async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]:
async with self.engine.begin() as conn:
@@ -308,7 +224,7 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
where(and_(self.t_schedules.c.acquired_by == scheduler_id))
result = conn.execute(query)
- schedules = self._deserialize_schedules(result)
+ schedules = await self._deserialize_schedules(result)
return schedules
@@ -365,11 +281,11 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
where(self.t_schedules.c.id.in_(finished_schedule_ids))
await conn.execute(delete)
- for event in update_events:
- await self._publish(conn, event)
+ for event in update_events:
+ await self._events.publish(event)
- for schedule_id in finished_schedule_ids:
- await self._publish(conn, ScheduleRemoved(schedule_id=schedule_id))
+ for schedule_id in finished_schedule_ids:
+ await self._events.publish(ScheduleRemoved(schedule_id=schedule_id))
async def get_next_schedule_run_time(self) -> Optional[datetime]:
statenent = select(self.t_schedules.c.id).\
@@ -386,9 +302,9 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
async with self.engine.begin() as conn:
await conn.execute(insert)
- event = JobAdded(job_id=job.id, task_id=job.task_id, schedule_id=job.schedule_id,
- tags=job.tags)
- await self._publish(conn, event)
+ event = JobAdded(job_id=job.id, task_id=job.task_id, schedule_id=job.schedule_id,
+ tags=job.tags)
+ await self._events.publish(event)
async def get_jobs(self, ids: Optional[Iterable[UUID]] = None) -> list[Job]:
query = self.t_jobs.select().order_by(self.t_jobs.c.id)
@@ -398,7 +314,7 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
async with self.engine.begin() as conn:
result = await conn.execute(query)
- return self._deserialize_jobs(result)
+ return await self._deserialize_jobs(result)
async def acquire_jobs(self, worker_id: str, limit: Optional[int] = None) -> list[Job]:
async with self.engine.begin() as conn:
@@ -416,7 +332,7 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
return []
# Mark the jobs as acquired by this worker
- jobs = self._deserialize_jobs(result)
+ jobs = await self._deserialize_jobs(result)
task_ids: set[str] = {job.task_id for job in jobs}
# Retrieve the limits
diff --git a/src/apscheduler/datastores/memory.py b/src/apscheduler/datastores/memory.py
index df4e0c2..98306a2 100644
--- a/src/apscheduler/datastores/memory.py
+++ b/src/apscheduler/datastores/memory.py
@@ -4,17 +4,16 @@ from bisect import bisect_left, insort_right
from collections import defaultdict
from datetime import MAXYEAR, datetime, timedelta, timezone
from functools import partial
-from typing import Any, Callable, Iterable, Optional
+from typing import Any, Iterable, Optional
from uuid import UUID
import attr
-from .. import events
-from ..abc import DataStore, Job, Schedule
+from ..abc import DataStore, EventBroker, EventSource, Job, Schedule
from ..enums import ConflictPolicy
+from ..eventbrokers.local import LocalEventBroker
from ..events import (
- EventHub, JobAdded, ScheduleAdded, ScheduleRemoved, ScheduleUpdated, SubscriptionToken,
- TaskAdded, TaskRemoved, TaskUpdated)
+ JobAdded, ScheduleAdded, ScheduleRemoved, ScheduleUpdated, TaskAdded, TaskRemoved, TaskUpdated)
from ..exceptions import ConflictingIdError, TaskLookupError
from ..structures import JobResult, Task
from ..util import reentrant
@@ -75,7 +74,7 @@ class JobState:
@attr.define(eq=False)
class MemoryDataStore(DataStore):
lock_expiration_delay: float = 30
- _events: EventHub = attr.Factory(EventHub)
+ _events: EventBroker = attr.Factory(LocalEventBroker)
_tasks: dict[str, TaskState] = attr.Factory(dict)
_schedules: list[ScheduleState] = attr.Factory(list)
_schedules_by_id: dict[str, ScheduleState] = attr.Factory(dict)
@@ -102,12 +101,9 @@ class MemoryDataStore(DataStore):
def __exit__(self, exc_type, exc_val, exc_tb):
self._events.__exit__(exc_type, exc_val, exc_tb)
- def subscribe(self, callback: Callable[[events.Event], Any],
- event_types: Optional[Iterable[type[events.Event]]] = None) -> SubscriptionToken:
- return self._events.subscribe(callback, event_types)
-
- def unsubscribe(self, token: events.SubscriptionToken) -> None:
- self._events.unsubscribe(token)
+ @property
+ def events(self) -> EventSource:
+ return self._events
def get_schedules(self, ids: Optional[set[str]] = None) -> list[Schedule]:
return [state.schedule for state in self._schedules
diff --git a/src/apscheduler/datastores/mongodb.py b/src/apscheduler/datastores/mongodb.py
index 749381e..6e7d0aa 100644
--- a/src/apscheduler/datastores/mongodb.py
+++ b/src/apscheduler/datastores/mongodb.py
@@ -14,12 +14,12 @@ from pymongo import ASCENDING, DeleteOne, MongoClient, UpdateOne
from pymongo.collection import Collection
from pymongo.errors import DuplicateKeyError
-from .. import events
-from ..abc import DataStore, Job, Schedule, Serializer
+from ..abc import DataStore, EventBroker, EventSource, Job, Schedule, Serializer
from ..enums import ConflictPolicy
+from ..eventbrokers.local import LocalEventBroker
from ..events import (
- DataStoreEvent, EventHub, JobAdded, ScheduleAdded, ScheduleRemoved, ScheduleUpdated,
- SubscriptionToken, TaskAdded, TaskRemoved, TaskUpdated)
+ DataStoreEvent, JobAdded, ScheduleAdded, ScheduleRemoved, ScheduleUpdated, SubscriptionToken,
+ TaskAdded, TaskRemoved, TaskUpdated)
from ..exceptions import (
ConflictingIdError, DeserializationError, SerializationError, TaskLookupError)
from ..serializers.pickle import PickleSerializer
@@ -42,7 +42,7 @@ class MongoDBDataStore(DataStore):
_logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__))
_exit_stack: ExitStack = attr.field(init=False, factory=ExitStack)
- _events: EventHub = attr.field(init=False, factory=EventHub)
+ _events: EventBroker = attr.field(init=False, factory=LocalEventBroker)
_local_tasks: dict[str, Task] = attr.field(init=False, factory=dict)
@client.validator
@@ -62,6 +62,10 @@ class MongoDBDataStore(DataStore):
client = MongoClient(uri)
return cls(client, **options)
+ @property
+ def events(self) -> EventSource:
+ return self._events
+
def __enter__(self):
server_info = self.client.server_info()
if server_info['versionArray'] < [4, 0]:
diff --git a/src/apscheduler/datastores/sqlalchemy.py b/src/apscheduler/datastores/sqlalchemy.py
index d49ca61..3040ae4 100644
--- a/src/apscheduler/datastores/sqlalchemy.py
+++ b/src/apscheduler/datastores/sqlalchemy.py
@@ -16,12 +16,12 @@ from sqlalchemy.future import Engine, create_engine
from sqlalchemy.sql.ddl import DropTable
from sqlalchemy.sql.elements import BindParameter, literal
-from ..abc import DataStore, Job, Schedule, Serializer
+from ..abc import DataStore, EventBroker, EventSource, Job, Schedule, Serializer
from ..enums import CoalescePolicy, ConflictPolicy, JobOutcome
+from ..eventbrokers.local import LocalEventBroker
from ..events import (
- Event, EventHub, JobAdded, JobDeserializationFailed, ScheduleAdded,
- ScheduleDeserializationFailed, ScheduleRemoved, ScheduleUpdated, SubscriptionToken, TaskAdded,
- TaskRemoved, TaskUpdated)
+ Event, JobAdded, JobDeserializationFailed, ScheduleAdded, ScheduleDeserializationFailed,
+ ScheduleRemoved, ScheduleUpdated, SubscriptionToken, TaskAdded, TaskRemoved, TaskUpdated)
from ..exceptions import ConflictingIdError, SerializationError, TaskLookupError
from ..marshalling import callable_to_ref
from ..serializers.pickle import PickleSerializer
@@ -179,7 +179,7 @@ class _BaseSQLAlchemyDataStore:
class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, DataStore):
engine: Engine
- _events: EventHub = attr.field(init=False, factory=EventHub)
+ _events: EventBroker = attr.field(init=False, factory=LocalEventBroker)
@classmethod
def from_url(cls, url: str | URL, **options) -> SQLAlchemyDataStore:
@@ -208,6 +208,10 @@ class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, DataStore):
def __exit__(self, exc_type, exc_val, exc_tb):
self._events.__exit__(exc_type, exc_val, exc_tb)
+ @property
+ def events(self) -> EventSource:
+ return self._events
+
def subscribe(self, callback: Callable[[Event], Any],
event_types: Optional[Iterable[type[Event]]] = None) -> SubscriptionToken:
return self._events.subscribe(callback, event_types)
diff --git a/src/apscheduler/eventbrokers/__init__.py b/src/apscheduler/eventbrokers/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/apscheduler/eventbrokers/__init__.py
diff --git a/src/apscheduler/eventbrokers/async_adapter.py b/src/apscheduler/eventbrokers/async_adapter.py
new file mode 100644
index 0000000..cb18386
--- /dev/null
+++ b/src/apscheduler/eventbrokers/async_adapter.py
@@ -0,0 +1,44 @@
+from __future__ import annotations
+
+from contextlib import AsyncExitStack
+from functools import partial
+from typing import Any, Callable, Iterable, Optional
+
+import attr
+from anyio import to_thread
+from anyio.from_thread import BlockingPortal
+
+from apscheduler.abc import EventBroker
+from apscheduler.eventbrokers.async_local import LocalAsyncEventBroker
+from apscheduler.events import Event, SubscriptionToken
+from apscheduler.util import reentrant
+
+
+@reentrant
+@attr.define(eq=False)
+class AsyncEventBrokerAdapter(LocalAsyncEventBroker):
+ original: EventBroker
+ portal: BlockingPortal
+ _exit_stack: AsyncExitStack = attr.field(init=False)
+
+ async def __aenter__(self):
+ self._exit_stack = AsyncExitStack()
+ if not self.portal:
+ self.portal = BlockingPortal()
+ self._exit_stack.enter_async_context(self.portal)
+
+ await to_thread.run_sync(self.original.__enter__)
+ return await super().__aenter__()
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ await to_thread.run_sync(self.original.__exit__, exc_type, exc_val, exc_tb)
+ await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
+ return await super().__aexit__(exc_type, exc_val, exc_tb)
+
+ async def publish(self, event: Event) -> None:
+ await to_thread.run_sync(self.original.publish, event)
+
+ def subscribe(self, callback: Callable[[Event], Any],
+ event_types: Optional[Iterable[type[Event]]] = None) -> SubscriptionToken:
+ token = self.original.subscribe(partial(self.portal.call, callback), event_types)
+ return token
diff --git a/src/apscheduler/eventbrokers/async_local.py b/src/apscheduler/eventbrokers/async_local.py
new file mode 100644
index 0000000..e73b4b9
--- /dev/null
+++ b/src/apscheduler/eventbrokers/async_local.py
@@ -0,0 +1,44 @@
+from __future__ import annotations
+
+from inspect import isawaitable
+from logging import Logger, getLogger
+from typing import Any, Callable
+
+import attr
+from anyio import create_task_group
+from anyio.abc import TaskGroup
+
+from ..abc import AsyncEventBroker
+from ..events import Event
+from ..util import reentrant
+from .base import BaseEventBroker
+
+
+@reentrant
+@attr.define(eq=False)
+class LocalAsyncEventBroker(AsyncEventBroker, BaseEventBroker):
+ _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__))
+ _task_group: TaskGroup = attr.field(init=False)
+
+ async def __aenter__(self) -> LocalAsyncEventBroker:
+ self._task_group = create_task_group()
+ await self._task_group.__aenter__()
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
+ del self._task_group
+
+ async def publish(self, event: Event) -> None:
+ async def deliver_event(func: Callable[[Event], Any]) -> None:
+ try:
+ retval = func(event)
+ if isawaitable(retval):
+ await retval
+ except BaseException:
+ self._logger.exception('Error delivering %s event', event.__class__.__name__)
+
+ event_type = type(event)
+ for subscription in self._subscriptions.values():
+ if subscription.event_types is None or event_type in subscription.event_types:
+ self._task_group.start_soon(deliver_event, subscription.callback)
diff --git a/src/apscheduler/eventbrokers/asyncpg.py b/src/apscheduler/eventbrokers/asyncpg.py
new file mode 100644
index 0000000..447bee8
--- /dev/null
+++ b/src/apscheduler/eventbrokers/asyncpg.py
@@ -0,0 +1,89 @@
+from __future__ import annotations
+
+from contextlib import asynccontextmanager
+from logging import Logger, getLogger
+from typing import TYPE_CHECKING, AsyncContextManager, AsyncGenerator, Callable
+
+import attr
+from anyio import TASK_STATUS_IGNORED, sleep
+from asyncpg import Connection
+from asyncpg.pool import Pool
+
+from ..events import Event
+from ..exceptions import SerializationError
+from ..util import reentrant
+from .async_local import LocalAsyncEventBroker
+from .base import DistributedEventBrokerMixin
+
+if TYPE_CHECKING:
+ from sqlalchemy.ext.asyncio import AsyncEngine
+
+
+@reentrant
+@attr.define(eq=False)
+class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin):
+ connection_factory: Callable[[], AsyncContextManager[Connection]]
+ channel: str = attr.field(kw_only=True, default='apscheduler')
+ max_idle_time: float = attr.field(kw_only=True, default=30)
+ _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__))
+
+ @classmethod
+ def from_asyncpg_pool(cls, pool: Pool) -> AsyncpgEventBroker:
+ return cls(pool.acquire)
+
+ @classmethod
+ def from_async_sqla_engine(cls, engine: AsyncEngine) -> AsyncpgEventBroker:
+ if engine.dialect.driver != 'asyncpg':
+ raise ValueError(f'The driver in the engine must be "asyncpg" (current: '
+ f'{engine.dialect.driver})')
+
+ @asynccontextmanager
+ async def connection_factory() -> AsyncGenerator[Connection, None]:
+ conn = await engine.raw_connection()
+ try:
+ yield conn.connection._connection
+ finally:
+ conn.close()
+
+ return cls(connection_factory)
+
+ async def __aenter__(self) -> LocalAsyncEventBroker:
+ await super().__aenter__()
+ await self._task_group.start(self._listen_notifications)
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ self._task_group.cancel_scope.cancel()
+ await super().__aexit__(exc_type, exc_val, exc_tb)
+
+ async def _listen_notifications(self, *, task_status=TASK_STATUS_IGNORED) -> None:
+ local_publish = super(AsyncpgEventBroker, self).publish
+
+ def callback(connection, pid, channel: str, payload: str) -> None:
+ event = self.reconstitute_event_str(payload)
+ if event is not None:
+ self._task_group.start_soon(local_publish, event)
+
+ task_started_sent = False
+ while True:
+ async with self.connection_factory() as conn:
+ await conn.add_listener(self.channel, callback)
+ if not task_started_sent:
+ task_status.started()
+ task_started_sent = True
+
+ try:
+ while True:
+ await sleep(self.max_idle_time)
+ await conn.execute('SELECT 1')
+ finally:
+ await conn.remove_listener(self.channel, callback)
+
+ async def publish(self, event: Event) -> None:
+ notification = self.generate_notification_str(event)
+ if len(notification) > 7999:
+ raise SerializationError('Serialized event object exceeds 7999 bytes in size')
+
+ async with self.connection_factory() as conn:
+ await conn.execute("SELECT pg_notify($1, $2)", self.channel, notification)
+ return
diff --git a/src/apscheduler/eventbrokers/base.py b/src/apscheduler/eventbrokers/base.py
new file mode 100644
index 0000000..da89dc5
--- /dev/null
+++ b/src/apscheduler/eventbrokers/base.py
@@ -0,0 +1,84 @@
+from __future__ import annotations
+
+from base64 import b64decode, b64encode
+from logging import Logger
+from typing import Any, Callable, Iterable, Optional
+
+import attr
+
+from .. import abc, events
+from ..abc import EventBroker, Serializer
+from ..events import Event, Subscription, SubscriptionToken
+from ..exceptions import DeserializationError
+
+
+@attr.define(eq=False)
+class BaseEventBroker(EventBroker):
+ _subscriptions: dict[SubscriptionToken, Subscription] = attr.field(init=False, factory=dict)
+
+ def subscribe(self, callback: Callable[[Event], Any],
+ event_types: Optional[Iterable[type[Event]]] = None) -> SubscriptionToken:
+ types = set(event_types) if event_types else None
+ token = SubscriptionToken(object())
+ subscription = Subscription(callback, types)
+ self._subscriptions[token] = subscription
+ return token
+
+ def unsubscribe(self, token: SubscriptionToken) -> None:
+ self._subscriptions.pop(token, None)
+
+ def relay_events_from(self, source: abc.EventSource) -> SubscriptionToken:
+ return source.subscribe(self.publish)
+
+
+class DistributedEventBrokerMixin:
+ _logger: Logger
+ serializer: Serializer
+
+ def generate_notification(self, event: Event, use_base64: bool = False) -> bytes:
+ serialized = self.serializer.serialize(attr.asdict(event))
+ return event.__class__.__name__.encode('ascii') + b' ' + serialized
+
+ def generate_notification_str(self, event: Event) -> str:
+ serialized = self.serializer.serialize(attr.asdict(event))
+ return event.__class__.__name__ + ' ' + b64encode(serialized).decode('ascii')
+
+ def _reconstitute_event(self, event_type: str, serialized: bytes) -> Optional[Event]:
+ try:
+ kwargs = self.serializer.deserialize(serialized)
+ except DeserializationError:
+ self._logger.exception('Failed to deserialize an event of type %s', event_type,
+ serialized=serialized)
+ return None
+
+ try:
+ event_class = getattr(events, event_type)
+ except AttributeError:
+ self._logger.error('Receive notification for a nonexistent event type: %s',
+ event_type, serialized=serialized)
+ return None
+
+ try:
+ return event_class(**kwargs)
+ except Exception:
+ self._logger.exception('Error reconstituting event of type %s', event_type)
+ return None
+
+ def reconstitute_event(self, payload: bytes) -> Optional[Event]:
+ try:
+ event_type_bytes, serialized = payload.split(b' ', 1)
+ except ValueError:
+ self._logger.error('Received malformatted notification', payload=payload)
+ return None
+
+ event_type = event_type_bytes.decode('ascii', errors='replace')
+ return self._reconstitute_event(event_type, serialized)
+
+ def reconstitute_event_str(self, payload: str) -> Optional[Event]:
+ try:
+ event_type, b64_serialized = payload.split(' ', 1)
+ except ValueError:
+ self._logger.error('Received malformatted notification', payload=payload)
+ return None
+
+ return self._reconstitute_event(event_type, b64decode(b64_serialized))
diff --git a/src/apscheduler/eventbrokers/local.py b/src/apscheduler/eventbrokers/local.py
new file mode 100644
index 0000000..ab75575
--- /dev/null
+++ b/src/apscheduler/eventbrokers/local.py
@@ -0,0 +1,47 @@
+from __future__ import annotations
+
+from asyncio import iscoroutinefunction
+from concurrent.futures import ThreadPoolExecutor
+from logging import Logger, getLogger
+from typing import Any, Callable, Iterable, Optional
+
+import attr
+
+from ..events import Event, SubscriptionToken
+from ..util import reentrant
+from .base import BaseEventBroker
+
+
+@reentrant
+@attr.define(eq=False)
+class LocalEventBroker(BaseEventBroker):
+ _executor: ThreadPoolExecutor = attr.field(init=False)
+ _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__))
+
+ def __enter__(self) -> LocalEventBroker:
+ self._executor = ThreadPoolExecutor(1)
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self._executor.shutdown(wait=exc_type is None)
+ del self._executor
+
+ def subscribe(self, callback: Callable[[Event], Any],
+ event_types: Optional[Iterable[type[Event]]] = None) -> SubscriptionToken:
+ if iscoroutinefunction(callback):
+ raise ValueError('Coroutine functions are not supported as callbacks on a synchronous '
+ 'event source')
+
+ return super().subscribe(callback, event_types)
+
+ def publish(self, event: Event) -> None:
+ event_type = type(event)
+ for subscription in list(self._subscriptions.values()):
+ if subscription.event_types is None or event_type in subscription.event_types:
+ self._executor.submit(self._deliver_event, subscription.callback, event)
+
+ def _deliver_event(self, func: Callable[[Event], Any], event: Event) -> None:
+ try:
+ func(event)
+ except BaseException:
+ self._logger.exception('Error delivering %s event', event.__class__.__name__)
diff --git a/src/apscheduler/eventbrokers/mqtt.py b/src/apscheduler/eventbrokers/mqtt.py
new file mode 100644
index 0000000..cfedb88
--- /dev/null
+++ b/src/apscheduler/eventbrokers/mqtt.py
@@ -0,0 +1,67 @@
+from __future__ import annotations
+
+from concurrent.futures import Future
+from logging import Logger, getLogger
+from typing import Any, Optional
+
+import attr
+from paho.mqtt.client import Client, MQTTMessage
+from paho.mqtt.properties import Properties
+from paho.mqtt.reasoncodes import ReasonCodes
+
+from ..abc import Serializer
+from ..events import Event
+from ..serializers.json import JSONSerializer
+from ..util import reentrant
+from .base import DistributedEventBrokerMixin
+from .local import LocalEventBroker
+
+
+@reentrant
+@attr.define(eq=False)
+class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
+ client: Client
+ serializer: Serializer = attr.field(factory=JSONSerializer)
+ host: str = attr.field(kw_only=True, default='localhost')
+ port: int = attr.field(kw_only=True, default=1883)
+ topic: str = attr.field(kw_only=True, default='apscheduler')
+ subscribe_qos: int = attr.field(kw_only=True, default=0)
+ publish_qos: int = attr.field(kw_only=True, default=0)
+ _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__))
+ _ready_future: Future[None] = attr.field(init=False)
+
+ def __enter__(self):
+ self._ready_future = Future()
+ self.client.enable_logger(self._logger)
+ self.client.on_connect = self._on_connect
+ self.client.on_message = self._on_message
+ self.client.on_subscribe = self._on_subscribe
+ self.client.connect(self.host, self.port)
+ self.client.loop_start()
+ self._ready_future.result(10)
+ return super().__enter__()
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.client.disconnect()
+ self.client.loop_stop(force=exc_type is not None)
+ return super().__exit__(exc_type, exc_val, exc_tb)
+
+ def _on_connect(self, client: Client, userdata: Any, flags: dict[str, Any],
+ rc: ReasonCodes | int, properties: Optional[Properties] = None) -> None:
+ try:
+ client.subscribe(self.topic, qos=self.subscribe_qos)
+ except Exception as exc:
+ self._ready_future.set_exception(exc)
+ raise
+
+ def _on_subscribe(self, client: Client, userdata: Any, mid, granted_qos: list[int]) -> None:
+ self._ready_future.set_result(None)
+
+ def _on_message(self, client: Client, userdata: Any, msg: MQTTMessage) -> None:
+ event = self.reconstitute_event(msg.payload)
+ if event is not None:
+ super().publish(event)
+
+ def publish(self, event: Event) -> None:
+ notification = self.generate_notification(event)
+ self.client.publish(self.topic, notification, qos=self.publish_qos)
diff --git a/src/apscheduler/eventbrokers/redis.py b/src/apscheduler/eventbrokers/redis.py
new file mode 100644
index 0000000..92c7f83
--- /dev/null
+++ b/src/apscheduler/eventbrokers/redis.py
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+from concurrent.futures import Future
+from logging import Logger, getLogger
+from threading import Thread
+from typing import Optional
+
+import attr
+from redis import ConnectionPool, Redis
+
+from ..abc import Serializer
+from ..events import Event
+from ..serializers.json import JSONSerializer
+from ..util import reentrant
+from .base import DistributedEventBrokerMixin
+from .local import LocalEventBroker
+
+
+@reentrant
+@attr.define(eq=False)
+class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
+ client: Redis
+ serializer: Serializer = attr.field(factory=JSONSerializer)
+ channel: str = attr.field(kw_only=True, default='apscheduler')
+ message_poll_interval: float = attr.field(kw_only=True, default=0.05)
+ _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__))
+ _stopped: bool = attr.field(init=False, default=True)
+ _ready_future: Future[None] = attr.field(init=False)
+
+ @classmethod
+ def from_url(cls, url: str, db: Optional[str] = None, decode_components: bool = False,
+ **kwargs) -> RedisEventBroker:
+ pool = ConnectionPool.from_url(url, db, decode_components, **kwargs)
+ client = Redis(connection_pool=pool)
+ return cls(client)
+
+ def __enter__(self):
+ self._stopped = False
+ self._ready_future = Future()
+ self._thread = Thread(target=self._listen_messages, daemon=True, name='Redis subscriber')
+ self._thread.start()
+ self._ready_future.result(10)
+ return super().__enter__()
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self._stopped = True
+ if not exc_type:
+ self._thread.join(5)
+
+ super().__exit__(exc_type, exc_val, exc_tb)
+
+ def _listen_messages(self) -> None:
+ while not self._stopped:
+ try:
+ pubsub = self.client.pubsub()
+ pubsub.subscribe(self.channel)
+ except BaseException as exc:
+ if not self._ready_future.done():
+ self._ready_future.set_exception(exc)
+
+ raise
+ else:
+ if not self._ready_future.done():
+ self._ready_future.set_result(None)
+
+ try:
+ while not self._stopped:
+ msg = pubsub.get_message(timeout=self.message_poll_interval)
+ if msg and isinstance(msg['data'], bytes):
+ event = self.reconstitute_event(msg['data'])
+ if event is not None:
+ super().publish(event)
+ except BaseException:
+ self._logger.exception('Subscriber crashed')
+ raise
+ finally:
+ pubsub.close()
+
+ def publish(self, event: Event) -> None:
+ notification = self.generate_notification(event)
+ self.client.publish(self.channel, notification)
diff --git a/src/apscheduler/events.py b/src/apscheduler/events.py
index 571aead..49703d9 100644
--- a/src/apscheduler/events.py
+++ b/src/apscheduler/events.py
@@ -1,38 +1,24 @@
from __future__ import annotations
-import logging
-from abc import abstractmethod
-from asyncio import iscoroutinefunction
-from concurrent.futures.thread import ThreadPoolExecutor
from datetime import datetime, timezone
from functools import partial
-from inspect import isawaitable
-from logging import Logger
from traceback import format_tb
-from typing import Any, Callable, Iterable, NewType, Optional
+from typing import Any, Callable, NewType, Optional
from uuid import UUID
import attr
-from anyio import create_task_group
-from anyio.abc import TaskGroup
+from attr.converters import optional
-from . import abc
+from .converters import as_aware_datetime, as_uuid
from .structures import Job
SubscriptionToken = NewType('SubscriptionToken', object)
-def timestamp_to_datetime(value: datetime | float | None) -> Optional[datetime]:
- if isinstance(value, float):
- return datetime.fromtimestamp(value, timezone.utc)
-
- return value
-
-
@attr.define(kw_only=True, frozen=True)
class Event:
timestamp: datetime = attr.field(factory=partial(datetime.now, timezone.utc),
- converter=timestamp_to_datetime)
+ converter=as_aware_datetime)
#
@@ -62,13 +48,13 @@ class TaskRemoved(DataStoreEvent):
@attr.define(kw_only=True, frozen=True)
class ScheduleAdded(DataStoreEvent):
schedule_id: str
- next_fire_time: Optional[datetime] = attr.field(converter=timestamp_to_datetime)
+ next_fire_time: Optional[datetime] = attr.field(converter=optional(as_aware_datetime))
@attr.define(kw_only=True, frozen=True)
class ScheduleUpdated(DataStoreEvent):
schedule_id: str
- next_fire_time: Optional[datetime] = attr.field(converter=timestamp_to_datetime)
+ next_fire_time: Optional[datetime] = attr.field(converter=optional(as_aware_datetime))
@attr.define(kw_only=True, frozen=True)
@@ -78,15 +64,15 @@ class ScheduleRemoved(DataStoreEvent):
@attr.define(kw_only=True, frozen=True)
class JobAdded(DataStoreEvent):
- job_id: UUID
+ job_id: UUID = attr.field(converter=as_uuid)
task_id: str
schedule_id: Optional[str]
- tags: frozenset[str]
+ tags: frozenset[str] = attr.field(converter=frozenset)
@attr.define(kw_only=True, frozen=True)
class JobRemoved(DataStoreEvent):
- job_id: UUID
+ job_id: UUID = attr.field(converter=as_uuid)
@attr.define(kw_only=True, frozen=True)
@@ -97,7 +83,7 @@ class ScheduleDeserializationFailed(DataStoreEvent):
@attr.define(kw_only=True, frozen=True)
class JobDeserializationFailed(DataStoreEvent):
- job_id: UUID
+ job_id: UUID = attr.field(converter=as_uuid)
exception: BaseException
@@ -141,11 +127,11 @@ class WorkerStopped(WorkerEvent):
@attr.define(kw_only=True, frozen=True)
class JobExecutionEvent(WorkerEvent):
- job_id: UUID
+ job_id: UUID = attr.field(converter=as_uuid)
task_id: str
schedule_id: Optional[str]
- scheduled_fire_time: Optional[datetime]
- start_deadline: Optional[datetime]
+ scheduled_fire_time: Optional[datetime] = attr.field(converter=optional(as_aware_datetime))
+ start_deadline: Optional[datetime] = attr.field(converter=optional(as_aware_datetime))
@attr.define(kw_only=True, frozen=True)
@@ -175,7 +161,7 @@ class JobDeadlineMissed(JobExecutionEvent):
@attr.define(kw_only=True, frozen=True)
class JobCompleted(JobExecutionEvent):
"""Signals that a worker has successfully run a job."""
- start_time: datetime
+ start_time: datetime = attr.field(converter=optional(as_aware_datetime))
return_value: str
@classmethod
@@ -188,7 +174,7 @@ class JobCompleted(JobExecutionEvent):
@attr.define(kw_only=True, frozen=True)
class JobCancelled(JobExecutionEvent):
"""Signals that a job was cancelled."""
- start_time: datetime
+ start_time: datetime = attr.field(converter=optional(as_aware_datetime))
@classmethod
def from_job(cls, job: Job, start_time: datetime) -> JobCancelled:
@@ -227,85 +213,3 @@ class JobFailed(JobExecutionEvent):
class Subscription:
callback: Callable[[Event], Any]
event_types: Optional[set[type[Event]]]
-
-
-@attr.define
-class _BaseEventHub(abc.EventSource):
- _logger: Logger = attr.field(init=False, factory=lambda: logging.getLogger(__name__))
- _subscriptions: dict[SubscriptionToken, Subscription] = attr.field(init=False, factory=dict)
-
- def subscribe(self, callback: Callable[[Event], Any],
- event_types: Optional[Iterable[type[Event]]] = None) -> SubscriptionToken:
- types = set(event_types) if event_types else None
- token = SubscriptionToken(object())
- subscription = Subscription(callback, types)
- self._subscriptions[token] = subscription
- return token
-
- def unsubscribe(self, token: SubscriptionToken) -> None:
- self._subscriptions.pop(token, None)
-
- @abstractmethod
- def publish(self, event: Event) -> None:
- """Publish an event to all subscribers."""
-
- def relay_events_from(self, source: abc.EventSource) -> SubscriptionToken:
- return source.subscribe(self.publish)
-
-
-class EventHub(_BaseEventHub):
- _executor: ThreadPoolExecutor
-
- def __enter__(self) -> EventHub:
- self._executor = ThreadPoolExecutor(1)
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- self._executor.shutdown(wait=exc_type is None)
-
- def subscribe(self, callback: Callable[[Event], Any],
- event_types: Optional[Iterable[type[Event]]] = None) -> SubscriptionToken:
- if iscoroutinefunction(callback):
- raise ValueError('Coroutine functions are not supported as callbacks on a synchronous '
- 'event source')
-
- return super().subscribe(callback, event_types)
-
- def publish(self, event: Event) -> None:
- def deliver_event(func: Callable[[Event], Any]) -> None:
- try:
- func(event)
- except BaseException:
- self._logger.exception('Error delivering %s event', event.__class__.__name__)
-
- event_type = type(event)
- for subscription in list(self._subscriptions.values()):
- if subscription.event_types is None or event_type in subscription.event_types:
- self._executor.submit(deliver_event, subscription.callback)
-
-
-class AsyncEventHub(_BaseEventHub):
- _task_group: TaskGroup
-
- async def __aenter__(self) -> AsyncEventHub:
- self._task_group = create_task_group()
- await self._task_group.__aenter__()
- return self
-
- async def __aexit__(self, exc_type, exc_val, exc_tb):
- await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
- del self._task_group
-
- def publish(self, event: Event) -> None:
- async def deliver_event(func: Callable[[Event], Any]) -> None:
- try:
- retval = func(event)
- if isawaitable(retval):
- await retval
- except BaseException:
- self._logger.exception('Error delivering %s event', event.__class__.__name__)
-
- event_type = type(event)
- for subscription in self._subscriptions.values():
- if subscription.event_types is None or event_type in subscription.event_types:
- self._task_group.start_soon(deliver_event, subscription.callback)
diff --git a/src/apscheduler/schedulers/async_.py b/src/apscheduler/schedulers/async_.py
index b60aed6..790baf4 100644
--- a/src/apscheduler/schedulers/async_.py
+++ b/src/apscheduler/schedulers/async_.py
@@ -16,9 +16,9 @@ from ..abc import AsyncDataStore, DataStore, EventSource, Job, Schedule, Trigger
from ..datastores.async_adapter import AsyncDataStoreAdapter
from ..datastores.memory import MemoryDataStore
from ..enums import CoalescePolicy, ConflictPolicy, RunState
+from ..eventbrokers.async_local import LocalAsyncEventBroker
from ..events import (
- AsyncEventHub, Event, ScheduleAdded, SchedulerStarted, SchedulerStopped, ScheduleUpdated,
- SubscriptionToken)
+ Event, ScheduleAdded, SchedulerStarted, SchedulerStopped, ScheduleUpdated, SubscriptionToken)
from ..marshalling import callable_to_ref
from ..structures import Task
from ..workers.async_ import AsyncWorker
@@ -40,7 +40,7 @@ class AsyncScheduler(EventSource):
self.logger = logger or getLogger(__name__)
self.start_worker = start_worker
self._exit_stack = AsyncExitStack()
- self._events = AsyncEventHub()
+ self._events = LocalAsyncEventBroker()
data_store = data_store or MemoryDataStore()
if isinstance(data_store, DataStore):
@@ -60,13 +60,13 @@ class AsyncScheduler(EventSource):
# Initialize the data store
await self._exit_stack.enter_async_context(self.data_store)
- relay_token = self._events.relay_events_from(self.data_store)
- self._exit_stack.callback(self.data_store.unsubscribe, relay_token)
+ relay_token = self._events.relay_events_from(self.data_store.events)
+ self._exit_stack.callback(self.data_store.events.unsubscribe, relay_token)
# Wake up the scheduler if the data store emits a significant schedule event
- wakeup_token = self.data_store.subscribe(
+ wakeup_token = self.data_store.events.subscribe(
lambda event: self._wakeup_event.set(), {ScheduleAdded, ScheduleUpdated})
- self._exit_stack.callback(self.data_store.unsubscribe, wakeup_token)
+ self._exit_stack.callback(self.data_store.events.unsubscribe, wakeup_token)
# Start the built-in worker, if configured to do so
if self.start_worker:
@@ -132,7 +132,7 @@ class AsyncScheduler(EventSource):
# Signal that the scheduler has started
self._state = RunState.started
task_status.started()
- self._events.publish(SchedulerStarted())
+ await self._events.publish(SchedulerStarted())
try:
while self._state is RunState.started:
@@ -190,11 +190,11 @@ class AsyncScheduler(EventSource):
pass
except BaseException as exc:
self._state = RunState.stopped
- self._events.publish(SchedulerStopped(exception=exc))
+ await self._events.publish(SchedulerStopped(exception=exc))
raise
self._state = RunState.stopped
- self._events.publish(SchedulerStopped())
+ await self._events.publish(SchedulerStopped())
# async def stop(self, force: bool = False) -> None:
# self._running = False
diff --git a/src/apscheduler/schedulers/sync.py b/src/apscheduler/schedulers/sync.py
index 1e5f133..5875ddc 100644
--- a/src/apscheduler/schedulers/sync.py
+++ b/src/apscheduler/schedulers/sync.py
@@ -13,9 +13,9 @@ from uuid import uuid4
from ..abc import DataStore, EventSource, Trigger
from ..datastores.memory import MemoryDataStore
from ..enums import CoalescePolicy, ConflictPolicy, RunState
+from ..eventbrokers.local import LocalEventBroker
from ..events import (
- Event, EventHub, ScheduleAdded, SchedulerStarted, SchedulerStopped, ScheduleUpdated,
- SubscriptionToken)
+ Event, ScheduleAdded, SchedulerStarted, SchedulerStopped, ScheduleUpdated, SubscriptionToken)
from ..marshalling import callable_to_ref
from ..structures import Job, Schedule, Task
from ..workers.sync import Worker
@@ -36,7 +36,7 @@ class Scheduler(EventSource):
self.data_store = data_store or MemoryDataStore()
self._exit_stack = ExitStack()
self._executor = ThreadPoolExecutor(max_workers=1)
- self._events = EventHub()
+ self._events = LocalEventBroker()
@property
def state(self) -> RunState:
@@ -54,13 +54,13 @@ class Scheduler(EventSource):
# Initialize the data store
self._exit_stack.enter_context(self.data_store)
- relay_token = self._events.relay_events_from(self.data_store)
- self._exit_stack.callback(self.data_store.unsubscribe, relay_token)
+ relay_token = self._events.relay_events_from(self.data_store.events)
+ self._exit_stack.callback(self.data_store.events.unsubscribe, relay_token)
# Wake up the scheduler if the data store emits a significant schedule event
- wakeup_token = self.data_store.subscribe(
+ wakeup_token = self.data_store.events.subscribe(
lambda event: self._wakeup_event.set(), {ScheduleAdded, ScheduleUpdated})
- self._exit_stack.callback(self.data_store.unsubscribe, wakeup_token)
+ self._exit_stack.callback(self.data_store.events.unsubscribe, wakeup_token)
# Start the built-in worker, if configured to do so
if self.start_worker:
diff --git a/src/apscheduler/serializers/json.py b/src/apscheduler/serializers/json.py
index f7ef307..8bfe6d7 100644
--- a/src/apscheduler/serializers/json.py
+++ b/src/apscheduler/serializers/json.py
@@ -1,12 +1,13 @@
from __future__ import annotations
+from datetime import datetime
from json import dumps, loads
from typing import Any
import attr
from ..abc import Serializer
-from ..marshalling import marshal_object, unmarshal_object
+from ..marshalling import marshal_date, marshal_object, unmarshal_object
@attr.define(kw_only=True, eq=False)
@@ -23,6 +24,8 @@ class JSONSerializer(Serializer):
if hasattr(obj, '__getstate__'):
cls_ref, state = marshal_object(obj)
return {self.magic_key: [cls_ref, state]}
+ elif isinstance(obj, datetime):
+ return marshal_date(obj)
raise TypeError(f'Object of type {obj.__class__.__name__!r} is not JSON serializable')
diff --git a/src/apscheduler/validators.py b/src/apscheduler/validators.py
index ca12d84..c71f73a 100644
--- a/src/apscheduler/validators.py
+++ b/src/apscheduler/validators.py
@@ -4,6 +4,7 @@ import sys
from datetime import date, datetime, timedelta, timezone, tzinfo
from typing import Any, Optional
+from attr import Attribute
from tzlocal import get_localzone
from .abc import Trigger
@@ -146,6 +147,11 @@ def as_list(value, element_type: type, name: str) -> list:
return value
+def aware_datetime(instance: Any, attribute: Attribute, value: datetime) -> None:
+ if not value.tzinfo:
+ raise ValueError(f'{attribute.name} must be a timezone aware datetime')
+
+
def require_state_version(trigger: Trigger, state: dict[str, Any], max_version: int) -> None:
try:
if state['version'] > max_version:
diff --git a/src/apscheduler/workers/async_.py b/src/apscheduler/workers/async_.py
index e9bcc5f..0893a98 100644
--- a/src/apscheduler/workers/async_.py
+++ b/src/apscheduler/workers/async_.py
@@ -10,15 +10,16 @@ from typing import Any, Callable, Iterable, Optional
from uuid import UUID
import anyio
-from anyio import TASK_STATUS_IGNORED, create_task_group, get_cancelled_exc_class
+from anyio import TASK_STATUS_IGNORED, create_task_group, get_cancelled_exc_class, move_on_after
from anyio.abc import CancelScope
from ..abc import AsyncDataStore, DataStore, EventSource, Job
from ..datastores.async_adapter import AsyncDataStoreAdapter
from ..enums import JobOutcome, RunState
+from ..eventbrokers.async_local import LocalAsyncEventBroker
from ..events import (
- AsyncEventHub, Event, JobAdded, JobCancelled, JobCompleted, JobDeadlineMissed, JobFailed,
- JobStarted, SubscriptionToken, WorkerStarted, WorkerStopped)
+ Event, JobAdded, JobCancelled, JobCompleted, JobDeadlineMissed, JobFailed, JobStarted,
+ SubscriptionToken, WorkerStarted, WorkerStopped)
from ..structures import JobResult
@@ -39,7 +40,7 @@ class AsyncWorker(EventSource):
self.logger = logger or getLogger(__name__)
self._acquired_jobs: set[Job] = set()
self._exit_stack = AsyncExitStack()
- self._events = AsyncEventHub()
+ self._events = LocalAsyncEventBroker()
self._running_jobs: set[UUID] = set()
if self.max_concurrent_jobs < 1:
@@ -62,13 +63,13 @@ class AsyncWorker(EventSource):
# Initialize the data store
await self._exit_stack.enter_async_context(self.data_store)
- relay_token = self._events.relay_events_from(self.data_store)
- self._exit_stack.callback(self.data_store.unsubscribe, relay_token)
+ relay_token = self._events.relay_events_from(self.data_store.events)
+ self._exit_stack.callback(self.data_store.events.unsubscribe, relay_token)
# Wake up the worker if the data store emits a significant job event
- wakeup_token = self.data_store.subscribe(
+ wakeup_token = self.data_store.events.subscribe(
lambda event: self._wakeup_event.set(), {JobAdded})
- self._exit_stack.callback(self.data_store.unsubscribe, wakeup_token)
+ self._exit_stack.callback(self.data_store.events.unsubscribe, wakeup_token)
# Start the actual worker
task_group = create_task_group()
@@ -97,7 +98,7 @@ class AsyncWorker(EventSource):
# Signal that the worker has started
self._state = RunState.started
task_status.started()
- self._events.publish(WorkerStarted())
+ await self._events.publish(WorkerStarted())
try:
async with create_task_group() as tg:
@@ -115,21 +116,23 @@ class AsyncWorker(EventSource):
pass
except BaseException as exc:
self._state = RunState.stopped
- self._events.publish(WorkerStopped(exception=exc))
+ with move_on_after(1, shield=True):
+ await self._events.publish(WorkerStopped(exception=exc))
+
raise
self._state = RunState.stopped
- self._events.publish(WorkerStopped())
+ await self._events.publish(WorkerStopped())
async def _run_job(self, job: Job, func: Callable) -> None:
try:
# Check if the job started before the deadline
start_time = datetime.now(timezone.utc)
if job.start_deadline is not None and start_time > job.start_deadline:
- self._events.publish(JobDeadlineMissed.from_job(job, start_time))
+ await self._events.publish(JobDeadlineMissed.from_job(job, start_time))
return
- self._events.publish(JobStarted.from_job(job, start_time))
+ await self._events.publish(JobStarted.from_job(job, start_time))
try:
retval = func(*job.args, **job.kwargs)
if isawaitable(retval):
@@ -139,17 +142,18 @@ class AsyncWorker(EventSource):
result = JobResult(job_id=job.id, outcome=JobOutcome.cancelled)
await self.data_store.release_job(self.identity, job.task_id, result)
- self._events.publish(JobCancelled.from_job(job, start_time))
+ with move_on_after(1, shield=True):
+ await self._events.publish(JobCancelled.from_job(job, start_time))
except BaseException as exc:
result = JobResult(job_id=job.id, outcome=JobOutcome.failure, exception=exc)
await self.data_store.release_job(self.identity, job.task_id, result)
- self._events.publish(JobFailed.from_exception(job, start_time, exc))
+ await self._events.publish(JobFailed.from_exception(job, start_time, exc))
if not isinstance(exc, Exception):
raise
else:
result = JobResult(job_id=job.id, outcome=JobOutcome.success, return_value=retval)
await self.data_store.release_job(self.identity, job.task_id, result)
- self._events.publish(JobCompleted.from_retval(job, start_time, retval))
+ await self._events.publish(JobCompleted.from_retval(job, start_time, retval))
finally:
self._running_jobs.remove(job.id)
diff --git a/src/apscheduler/workers/sync.py b/src/apscheduler/workers/sync.py
index b7803c7..30db6bf 100644
--- a/src/apscheduler/workers/sync.py
+++ b/src/apscheduler/workers/sync.py
@@ -13,8 +13,9 @@ from uuid import UUID
from .. import events
from ..abc import DataStore, EventSource
from ..enums import JobOutcome, RunState
+from ..eventbrokers.local import LocalEventBroker
from ..events import (
- EventHub, JobAdded, JobCompleted, JobDeadlineMissed, JobFailed, JobStarted, SubscriptionToken,
+ JobAdded, JobCompleted, JobDeadlineMissed, JobFailed, JobStarted, SubscriptionToken,
WorkerStarted, WorkerStopped)
from ..structures import Job, JobResult
@@ -33,7 +34,7 @@ class Worker(EventSource):
self.logger = logger or getLogger(__name__)
self._acquired_jobs: set[Job] = set()
self._exit_stack = ExitStack()
- self._events = EventHub()
+ self._events = LocalEventBroker()
self._running_jobs: set[UUID] = set()
if self.max_concurrent_jobs < 1:
@@ -53,13 +54,13 @@ class Worker(EventSource):
# Initialize the data store
self._exit_stack.enter_context(self.data_store)
- relay_token = self._events.relay_events_from(self.data_store)
- self._exit_stack.callback(self.data_store.unsubscribe, relay_token)
+ relay_token = self._events.relay_events_from(self.data_store.events)
+ self._exit_stack.callback(self.data_store.events.unsubscribe, relay_token)
# Wake up the worker if the data store emits a significant job event
- wakeup_token = self.data_store.subscribe(
+ wakeup_token = self.data_store.events.subscribe(
lambda event: self._wakeup_event.set(), {JobAdded})
- self._exit_stack.callback(self.data_store.unsubscribe, wakeup_token)
+ self._exit_stack.callback(self.data_store.events.unsubscribe, wakeup_token)
# Start the worker and return when it has signalled readiness or raised an exception
start_future: Future[None] = Future()
diff --git a/tests/conftest.py b/tests/conftest.py
index bc5bab2..135c18d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -24,7 +24,6 @@ def timezone() -> ZoneInfo:
@pytest.fixture(params=[
- pytest.param(None, id='none'),
pytest.param(PickleSerializer, id='pickle'),
pytest.param(CBORSerializer, id='cbor'),
pytest.param(JSONSerializer, id='json')
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index 52f8349..74db6e7 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -40,11 +40,11 @@ async def capture_events(
events.append(event)
if len(events) == limit:
limit_event.set()
- store.unsubscribe(token)
+ store.events.unsubscribe(token)
events: List[Event] = []
limit_event = anyio.Event()
- token = store.subscribe(listener, event_types)
+ token = store.events.subscribe(listener, event_types)
yield events
if limit:
with anyio.fail_after(3):
diff --git a/tests/test_eventbrokers.py b/tests/test_eventbrokers.py
new file mode 100644
index 0000000..6a0f45d
--- /dev/null
+++ b/tests/test_eventbrokers.py
@@ -0,0 +1,279 @@
+from concurrent.futures import Future
+from datetime import datetime, timezone
+from queue import Empty, Queue
+from typing import Callable
+
+import pytest
+from _pytest.fixtures import FixtureRequest
+from _pytest.logging import LogCaptureFixture
+from anyio import create_memory_object_stream, fail_after
+
+from apscheduler.abc import AsyncEventBroker, EventBroker, Serializer
+from apscheduler.events import Event, ScheduleAdded
+
+
+@pytest.fixture
+def local_broker() -> EventBroker:
+ from apscheduler.eventbrokers.local import LocalEventBroker
+
+ return LocalEventBroker()
+
+
+@pytest.fixture
+def local_async_broker() -> AsyncEventBroker:
+ from apscheduler.eventbrokers.async_local import LocalAsyncEventBroker
+
+ return LocalAsyncEventBroker()
+
+
+@pytest.fixture
+def redis_broker(serializer: Serializer) -> EventBroker:
+ from apscheduler.eventbrokers.redis import RedisEventBroker
+
+ broker = RedisEventBroker.from_url('redis://localhost:6379')
+ broker.serializer = serializer
+ return broker
+
+
+@pytest.fixture
+def mqtt_broker(serializer: Serializer) -> EventBroker:
+ from paho.mqtt.client import Client
+
+ from apscheduler.eventbrokers.mqtt import MQTTEventBroker
+
+ return MQTTEventBroker(Client(), serializer=serializer)
+
+
+@pytest.fixture
+async def asyncpg_broker(serializer: Serializer) -> AsyncEventBroker:
+ from asyncpg import create_pool
+
+ from apscheduler.eventbrokers.asyncpg import AsyncpgEventBroker
+
+ pool = await create_pool('postgres://postgres:secret@localhost:5432/testdb')
+ broker = AsyncpgEventBroker.from_asyncpg_pool(pool)
+ broker.serializer = serializer
+ yield broker
+ await pool.close()
+
+
+@pytest.fixture(params=[
+ pytest.param(pytest.lazy_fixture('local_broker'), id='local'),
+ pytest.param(pytest.lazy_fixture('redis_broker'), id='redis'),
+ pytest.param(pytest.lazy_fixture('mqtt_broker'), id='mqtt')
+])
+def broker(request: FixtureRequest) -> Callable[[], EventBroker]:
+ return request.param
+
+
+@pytest.fixture(params=[
+ pytest.param(pytest.lazy_fixture('local_async_broker'), id='local'),
+ pytest.param(pytest.lazy_fixture('asyncpg_broker'), id='asyncpg')
+])
+def async_broker(request: FixtureRequest) -> Callable[[], AsyncEventBroker]:
+ return request.param
+
+
+class TestEventBroker:
+ def test_publish_subscribe(self, broker: EventBroker) -> None:
+ def subscriber1(event) -> None:
+ queue.put_nowait(event)
+
+ def subscriber2(event) -> None:
+ queue.put_nowait(event)
+
+ queue = Queue()
+ with broker:
+ broker.subscribe(subscriber1)
+ broker.subscribe(subscriber2)
+ event = ScheduleAdded(
+ schedule_id='schedule1',
+ next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc))
+ broker.publish(event)
+ event1 = queue.get(timeout=3)
+ event2 = queue.get(timeout=1)
+
+ assert event1 == event2
+ assert isinstance(event1, ScheduleAdded)
+ assert isinstance(event1.timestamp, datetime)
+ assert event1.schedule_id == 'schedule1'
+ assert event1.next_fire_time == datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc)
+
+ def test_unsubscribe(self, broker: EventBroker, caplog) -> None:
+ queue = Queue()
+ with broker:
+ token = broker.subscribe(queue.put_nowait)
+ broker.publish(Event())
+ queue.get(timeout=3)
+
+ broker.unsubscribe(token)
+ broker.publish(Event())
+ with pytest.raises(Empty):
+ queue.get(timeout=0.1)
+
+ def test_publish_no_subscribers(self, broker: EventBroker, caplog: LogCaptureFixture) -> None:
+ with broker:
+ broker.publish(Event())
+
+ assert not caplog.text
+
+ def test_publish_exception(self, broker: EventBroker, caplog: LogCaptureFixture) -> None:
+ def bad_subscriber(event: Event) -> None:
+ raise Exception('foo')
+
+ timestamp = datetime.now(timezone.utc)
+ event_future: Future[Event] = Future()
+ with broker:
+ broker.subscribe(bad_subscriber)
+ broker.subscribe(event_future.set_result)
+ broker.publish(Event(timestamp=timestamp))
+
+ event = event_future.result(3)
+ assert isinstance(event, Event)
+ assert event.timestamp == timestamp
+ assert 'Error delivering Event' in caplog.text
+
+
+@pytest.mark.anyio
+class TestAsyncEventBroker:
+ async def test_publish_subscribe(self, async_broker: AsyncEventBroker) -> None:
+ def subscriber1(event) -> None:
+ send.send_nowait(event)
+
+ async def subscriber2(event) -> None:
+ await send.send(event)
+
+ send, receive = create_memory_object_stream(2)
+ async with async_broker:
+ async_broker.subscribe(subscriber1)
+ async_broker.subscribe(subscriber2)
+ event = ScheduleAdded(
+ schedule_id='schedule1',
+ next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc))
+ await async_broker.publish(event)
+
+ with fail_after(3):
+ event1 = await receive.receive()
+ event2 = await receive.receive()
+
+ assert event1 == event2
+ assert isinstance(event1, ScheduleAdded)
+ assert isinstance(event1.timestamp, datetime)
+ assert event1.schedule_id == 'schedule1'
+ assert event1.next_fire_time == datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc)
+
+ async def test_unsubscribe(self, async_broker: AsyncEventBroker) -> None:
+ send, receive = create_memory_object_stream()
+ async with async_broker:
+ token = async_broker.subscribe(send.send)
+ await async_broker.publish(Event())
+ with fail_after(3):
+ await receive.receive()
+
+ async_broker.unsubscribe(token)
+ await async_broker.publish(Event())
+ with pytest.raises(TimeoutError), fail_after(0.1):
+ await receive.receive()
+
+ async def test_publish_no_subscribers(self, async_broker: AsyncEventBroker,
+ caplog: LogCaptureFixture) -> None:
+ async with async_broker:
+ await async_broker.publish(Event())
+
+ assert not caplog.text
+
+ async def test_publish_exception(self, async_broker: AsyncEventBroker,
+ caplog: LogCaptureFixture) -> None:
+ def bad_subscriber(event: Event) -> None:
+ raise Exception('foo')
+
+ timestamp = datetime.now(timezone.utc)
+ events = []
+ async with async_broker:
+ async_broker.subscribe(bad_subscriber)
+ async_broker.subscribe(events.append)
+ await async_broker.publish(Event(timestamp=timestamp))
+
+ assert isinstance(events[0], Event)
+ assert events[0].timestamp == timestamp
+ assert 'Error delivering Event' in caplog.text
+#
+# def test_subscribe_coroutine_callback(self) -> None:
+# async def callback(event: Event) -> None:
+# pass
+#
+# with EventBroker() as eventhub:
+# with pytest.raises(ValueError, match='Coroutine functions are not supported'):
+# eventhub.subscribe(callback)
+#
+# def test_relay_events(self) -> None:
+# timestamp = datetime.now(timezone.utc)
+# events = []
+# with EventBroker() as eventhub1, EventBroker() as eventhub2:
+# eventhub2.relay_events_from(eventhub1)
+# eventhub2.subscribe(events.append)
+# eventhub1.publish(Event(timestamp=timestamp))
+#
+# assert isinstance(events[0], Event)
+# assert events[0].timestamp == timestamp
+#
+#
+# @pytest.mark.anyio
+# class TestAsyncEventHub:
+# async def test_publish(self) -> None:
+# async def async_setitem(event: Event) -> None:
+# events[1] = event
+#
+# timestamp = datetime.now(timezone.utc)
+# events: List[Optional[Event]] = [None, None]
+# async with AsyncEventBroker() as eventhub:
+# eventhub.subscribe(partial(setitem, events, 0))
+# eventhub.subscribe(async_setitem)
+# eventhub.publish(Event(timestamp=timestamp))
+#
+# assert events[0] is events[1]
+# assert isinstance(events[0], Event)
+# assert events[0].timestamp == timestamp
+#
+# async def test_unsubscribe(self) -> None:
+# timestamp = datetime.now(timezone.utc)
+# events = []
+# async with AsyncEventBroker() as eventhub:
+# token = eventhub.subscribe(events.append)
+# eventhub.publish(Event(timestamp=timestamp))
+# eventhub.unsubscribe(token)
+# eventhub.publish(Event(timestamp=timestamp))
+#
+# assert len(events) == 1
+#
+# async def test_publish_no_subscribers(self, caplog: LogCaptureFixture) -> None:
+# async with AsyncEventBroker() as eventhub:
+# eventhub.publish(Event(timestamp=datetime.now(timezone.utc)))
+#
+# assert not caplog.text
+#
+# async def test_publish_exception(self, caplog: LogCaptureFixture) -> None:
+# def bad_subscriber(event: Event) -> None:
+# raise Exception('foo')
+#
+# timestamp = datetime.now(timezone.utc)
+# events = []
+# async with AsyncEventBroker() as eventhub:
+# eventhub.subscribe(bad_subscriber)
+# eventhub.subscribe(events.append)
+# eventhub.publish(Event(timestamp=timestamp))
+#
+# assert isinstance(events[0], Event)
+# assert events[0].timestamp == timestamp
+# assert 'Error delivering Event' in caplog.text
+#
+# async def test_relay_events(self) -> None:
+# timestamp = datetime.now(timezone.utc)
+# events = []
+# async with AsyncEventBroker() as eventhub1, AsyncEventBroker() as eventhub2:
+# eventhub1.relay_events_from(eventhub2)
+# eventhub1.subscribe(events.append)
+# eventhub2.publish(Event(timestamp=timestamp))
+#
+# assert isinstance(events[0], Event)
+# assert events[0].timestamp == timestamp
diff --git a/tests/test_events.py b/tests/test_events.py
deleted file mode 100644
index bbe344f..0000000
--- a/tests/test_events.py
+++ /dev/null
@@ -1,135 +0,0 @@
-from datetime import datetime, timezone
-from functools import partial
-from operator import setitem
-from typing import List, Optional
-
-import pytest
-from _pytest.logging import LogCaptureFixture
-
-from apscheduler.events import AsyncEventHub, Event, EventHub
-
-
-class TestEventHub:
- def test_publish(self) -> None:
- timestamp = datetime.now(timezone.utc)
- events: List[Optional[Event]] = [None, None]
- with EventHub() as eventhub:
- eventhub.subscribe(partial(setitem, events, 0))
- eventhub.subscribe(partial(setitem, events, 1))
- eventhub.publish(Event(timestamp=timestamp))
-
- assert events[0] is events[1]
- assert isinstance(events[0], Event)
- assert events[0].timestamp == timestamp
-
- def test_unsubscribe(self) -> None:
- timestamp = datetime.now(timezone.utc)
- events = []
- with EventHub() as eventhub:
- token = eventhub.subscribe(events.append)
- eventhub.publish(Event(timestamp=timestamp))
- eventhub.unsubscribe(token)
- eventhub.publish(Event(timestamp=timestamp))
-
- assert len(events) == 1
-
- def test_publish_no_subscribers(self, caplog: LogCaptureFixture) -> None:
- with EventHub() as eventhub:
- eventhub.publish(Event(timestamp=datetime.now(timezone.utc)))
-
- assert not caplog.text
-
- def test_publish_exception(self, caplog: LogCaptureFixture) -> None:
- def bad_subscriber(event: Event) -> None:
- raise Exception('foo')
-
- timestamp = datetime.now(timezone.utc)
- events = []
- with EventHub() as eventhub:
- eventhub.subscribe(bad_subscriber)
- eventhub.subscribe(events.append)
- eventhub.publish(Event(timestamp=timestamp))
-
- assert isinstance(events[0], Event)
- assert events[0].timestamp == timestamp
- assert 'Error delivering Event' in caplog.text
-
- def test_subscribe_coroutine_callback(self) -> None:
- async def callback(event: Event) -> None:
- pass
-
- with EventHub() as eventhub:
- with pytest.raises(ValueError, match='Coroutine functions are not supported'):
- eventhub.subscribe(callback)
-
- def test_relay_events(self) -> None:
- timestamp = datetime.now(timezone.utc)
- events = []
- with EventHub() as eventhub1, EventHub() as eventhub2:
- eventhub2.relay_events_from(eventhub1)
- eventhub2.subscribe(events.append)
- eventhub1.publish(Event(timestamp=timestamp))
-
- assert isinstance(events[0], Event)
- assert events[0].timestamp == timestamp
-
-
-@pytest.mark.anyio
-class TestAsyncEventHub:
- async def test_publish(self) -> None:
- async def async_setitem(event: Event) -> None:
- events[1] = event
-
- timestamp = datetime.now(timezone.utc)
- events: List[Optional[Event]] = [None, None]
- async with AsyncEventHub() as eventhub:
- eventhub.subscribe(partial(setitem, events, 0))
- eventhub.subscribe(async_setitem)
- eventhub.publish(Event(timestamp=timestamp))
-
- assert events[0] is events[1]
- assert isinstance(events[0], Event)
- assert events[0].timestamp == timestamp
-
- async def test_unsubscribe(self) -> None:
- timestamp = datetime.now(timezone.utc)
- events = []
- async with AsyncEventHub() as eventhub:
- token = eventhub.subscribe(events.append)
- eventhub.publish(Event(timestamp=timestamp))
- eventhub.unsubscribe(token)
- eventhub.publish(Event(timestamp=timestamp))
-
- assert len(events) == 1
-
- async def test_publish_no_subscribers(self, caplog: LogCaptureFixture) -> None:
- async with AsyncEventHub() as eventhub:
- eventhub.publish(Event(timestamp=datetime.now(timezone.utc)))
-
- assert not caplog.text
-
- async def test_publish_exception(self, caplog: LogCaptureFixture) -> None:
- def bad_subscriber(event: Event) -> None:
- raise Exception('foo')
-
- timestamp = datetime.now(timezone.utc)
- events = []
- async with AsyncEventHub() as eventhub:
- eventhub.subscribe(bad_subscriber)
- eventhub.subscribe(events.append)
- eventhub.publish(Event(timestamp=timestamp))
-
- assert isinstance(events[0], Event)
- assert events[0].timestamp == timestamp
- assert 'Error delivering Event' in caplog.text
-
- async def test_relay_events(self) -> None:
- timestamp = datetime.now(timezone.utc)
- events = []
- async with AsyncEventHub() as eventhub1, AsyncEventHub() as eventhub2:
- eventhub1.relay_events_from(eventhub2)
- eventhub1.subscribe(events.append)
- eventhub2.publish(Event(timestamp=timestamp))
-
- assert isinstance(events[0], Event)
- assert events[0].timestamp == timestamp