summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Grönholm <alex.gronholm@nextday.fi>2022-09-12 22:09:05 +0300
committerAlex Grönholm <alex.gronholm@nextday.fi>2022-09-21 02:40:02 +0300
commitc5727432736b55b7d76753307f14efdb962c2edf (patch)
tree005bd129694b56bd601d65c4cdf43828cfcd4381
parent26c4db062145fcb4f623ecfda96c42ce2414e8e1 (diff)
downloadapscheduler-c5727432736b55b7d76753307f14efdb962c2edf.tar.gz
Major refactoring
- Made SyncScheduler a synchronous wrapper for AsyncScheduler - Removed workers as a user interface - Removed synchronous interfaces for data stores and event brokers and refactored existing implementations to use the async interface - Added the current_async_scheduler contextvar - Added job executors
-rw-r--r--docs/versionhistory.rst24
-rw-r--r--examples/separate_worker/async_scheduler.py10
-rw-r--r--examples/separate_worker/async_worker.py14
-rw-r--r--examples/separate_worker/sync_worker.py12
-rw-r--r--src/apscheduler/__init__.py7
-rw-r--r--src/apscheduler/_context.py9
-rw-r--r--src/apscheduler/_converters.py22
-rw-r--r--src/apscheduler/_retry.py68
-rw-r--r--src/apscheduler/_structures.py23
-rw-r--r--src/apscheduler/_worker.py189
-rw-r--r--src/apscheduler/abc.py272
-rw-r--r--src/apscheduler/datastores/async_adapter.py101
-rw-r--r--src/apscheduler/datastores/async_sqlalchemy.py602
-rw-r--r--src/apscheduler/datastores/base.py58
-rw-r--r--src/apscheduler/datastores/memory.py62
-rw-r--r--src/apscheduler/datastores/mongodb.py109
-rw-r--r--src/apscheduler/datastores/sqlalchemy.py865
-rw-r--r--src/apscheduler/eventbrokers/async_adapter.py65
-rw-r--r--src/apscheduler/eventbrokers/async_local.py64
-rw-r--r--src/apscheduler/eventbrokers/async_redis.py124
-rw-r--r--src/apscheduler/eventbrokers/asyncpg.py52
-rw-r--r--src/apscheduler/eventbrokers/base.py64
-rw-r--r--src/apscheduler/eventbrokers/local.py72
-rw-r--r--src/apscheduler/eventbrokers/mqtt.py35
-rw-r--r--src/apscheduler/eventbrokers/redis.py79
-rw-r--r--src/apscheduler/executors/__init__.py (renamed from src/apscheduler/workers/__init__.py)0
-rw-r--r--src/apscheduler/executors/async_.py24
-rw-r--r--src/apscheduler/executors/subprocess.py33
-rw-r--r--src/apscheduler/executors/thread.py31
-rw-r--r--src/apscheduler/schedulers/async_.py165
-rw-r--r--src/apscheduler/schedulers/sync.py644
-rw-r--r--src/apscheduler/workers/async_.py251
-rw-r--r--src/apscheduler/workers/sync.py257
-rw-r--r--tests/conftest.py194
-rw-r--r--tests/test_datastores.py1435
-rw-r--r--tests/test_eventbrokers.py325
-rw-r--r--tests/test_schedulers.py36
-rw-r--r--tests/test_workers.py281
38 files changed, 2144 insertions, 4534 deletions
diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst
index efe69fd..559ee36 100644
--- a/docs/versionhistory.rst
+++ b/docs/versionhistory.rst
@@ -6,6 +6,30 @@ APScheduler, see the :doc:`migration section <migration>`.
**UNRELEASED**
+- **BREAKING** Workers can no longer be run independently. Instead, you can run a
+ scheduler that only starts a worker but does not process schedules by passing
+ ``process_schedules=False`` to the scheduler
+- **BREAKING** The synchronous interfaces for event brokers and data stores have been
+ removed. Synchronous libraries can still be used to implement these services through
+ the use of ``anyio.to_thread.run_sync()``.
+- **BREAKING** The ``current_worker`` context variable has been removed
+- **BREAKING** The ``current_scheduler`` context variable is now specified to only
+ contain the currently running instance of a **synchronous** scheduler
+ (``apscheduler.schedulers.sync.Scheduler``). The asynchronous scheduler instance can
+ be fetched from the new ``current_async_scheduler`` context variable, and will always
+ be available when a scheduler is running in the current context, while
+ ``current_scheduler`` is only available when the synchronous wrapper is being run.
+- **BREAKING** Changed the initialization of data stores and event brokers to use a
+ single ``start()`` method that accepts an ``AsyncExitStack`` (and, depending on the
+ interface, other arguments too)
+- **BREAKING** Added a concept of "job executors". This determines how the task function
+ is executed once picked up by a worker. Several data structures and scheduler methods
+ have a new field/parameter for this, ``job_executor``. This addition requires database
+ schema changes too.
+- Added the ability to run jobs in worker processes, courtesy of the ``processpool``
+ executor
+- The synchronous scheduler now runs an asyncio event loop in a thread, acting as a
+  façade for ``AsyncScheduler``
- Fixed the ``schema`` parameter in ``SQLAlchemyDataStore`` not being applied
**4.0.0a2**
diff --git a/examples/separate_worker/async_scheduler.py b/examples/separate_worker/async_scheduler.py
index 6ffdbcd..2ac53c5 100644
--- a/examples/separate_worker/async_scheduler.py
+++ b/examples/separate_worker/async_scheduler.py
@@ -19,7 +19,7 @@ import logging
from example_tasks import tick
from sqlalchemy.ext.asyncio import create_async_engine
-from apscheduler.datastores.async_sqlalchemy import AsyncSQLAlchemyDataStore
+from apscheduler.datastores.sqlalchemy import SQLAlchemyDataStore
from apscheduler.eventbrokers.asyncpg import AsyncpgEventBroker
from apscheduler.schedulers.async_ import AsyncScheduler
from apscheduler.triggers.interval import IntervalTrigger
@@ -29,15 +29,15 @@ async def main():
engine = create_async_engine(
"postgresql+asyncpg://postgres:secret@localhost/testdb"
)
- data_store = AsyncSQLAlchemyDataStore(engine)
+ data_store = SQLAlchemyDataStore(engine)
event_broker = AsyncpgEventBroker.from_async_sqla_engine(engine)
# Uncomment the next two lines to use the Redis event broker instead
- # from apscheduler.eventbrokers.async_redis import AsyncRedisEventBroker
- # event_broker = AsyncRedisEventBroker.from_url("redis://localhost")
+ # from apscheduler.eventbrokers.redis import RedisEventBroker
+ # event_broker = RedisEventBroker.from_url("redis://localhost")
async with AsyncScheduler(
- data_store, event_broker, start_worker=False
+ data_store, event_broker, process_jobs=False
) as scheduler:
await scheduler.add_schedule(tick, IntervalTrigger(seconds=1), id="tick")
await scheduler.run_until_stopped()
diff --git a/examples/separate_worker/async_worker.py b/examples/separate_worker/async_worker.py
index 700720e..51c51e9 100644
--- a/examples/separate_worker/async_worker.py
+++ b/examples/separate_worker/async_worker.py
@@ -18,24 +18,24 @@ import logging
from sqlalchemy.ext.asyncio import create_async_engine
-from apscheduler.datastores.async_sqlalchemy import AsyncSQLAlchemyDataStore
+from apscheduler.datastores.sqlalchemy import SQLAlchemyDataStore
from apscheduler.eventbrokers.asyncpg import AsyncpgEventBroker
-from apscheduler.workers.async_ import AsyncWorker
+from apscheduler.schedulers.async_ import AsyncScheduler
async def main():
engine = create_async_engine(
"postgresql+asyncpg://postgres:secret@localhost/testdb"
)
- data_store = AsyncSQLAlchemyDataStore(engine)
+ data_store = SQLAlchemyDataStore(engine)
event_broker = AsyncpgEventBroker.from_async_sqla_engine(engine)
# Uncomment the next two lines to use the Redis event broker instead
- # from apscheduler.eventbrokers.async_redis import AsyncRedisEventBroker
- # event_broker = AsyncRedisEventBroker.from_url("redis://localhost")
+ # from apscheduler.eventbrokers.redis import RedisEventBroker
+ # event_broker = RedisEventBroker.from_url("redis://localhost")
- worker = AsyncWorker(data_store, event_broker)
- await worker.run_until_stopped()
+ scheduler = AsyncScheduler(data_store, event_broker, process_schedules=False)
+ await scheduler.run_until_stopped()
logging.basicConfig(level=logging.INFO)
diff --git a/examples/separate_worker/sync_worker.py b/examples/separate_worker/sync_worker.py
index e57be64..4329d02 100644
--- a/examples/separate_worker/sync_worker.py
+++ b/examples/separate_worker/sync_worker.py
@@ -1,7 +1,7 @@
"""
-Example demonstrating the separation of scheduler and worker.
-This script runs the worker part. You need to be running both this and the scheduler
-script simultaneously in order for the scheduled task to be run.
+Example demonstrating a scheduler that only runs jobs but does not process schedules.
+You need to be running both this and the scheduler script simultaneously in order for
+the scheduled task to be run.
Requires the "postgresql" and "redis" services to be running.
To install prerequisites: pip install sqlalchemy psycopg2 redis
@@ -19,7 +19,7 @@ from sqlalchemy.future import create_engine
from apscheduler.datastores.sqlalchemy import SQLAlchemyDataStore
from apscheduler.eventbrokers.redis import RedisEventBroker
-from apscheduler.workers.sync import Worker
+from apscheduler.schedulers.sync import Scheduler
logging.basicConfig(level=logging.INFO)
engine = create_engine("postgresql+psycopg2://postgres:secret@localhost/testdb")
@@ -30,5 +30,5 @@ event_broker = RedisEventBroker.from_url("redis://localhost")
# from apscheduler.eventbrokers.mqtt import MQTTEventBroker
# event_broker = MQTTEventBroker()
-worker = Worker(data_store, event_broker)
-worker.run_until_stopped()
+with Scheduler(data_store, event_broker, process_schedules=False) as scheduler:
+ scheduler.run_until_stopped()
diff --git a/src/apscheduler/__init__.py b/src/apscheduler/__init__.py
index 6b5828d..844a056 100644
--- a/src/apscheduler/__init__.py
+++ b/src/apscheduler/__init__.py
@@ -41,14 +41,14 @@ __all__ = [
"WorkerEvent",
"WorkerStarted",
"WorkerStopped",
+ "current_async_scheduler",
"current_scheduler",
- "current_worker",
"current_job",
]
from typing import Any
-from ._context import current_job, current_scheduler, current_worker
+from ._context import current_async_scheduler, current_job, current_scheduler
from ._enums import CoalescePolicy, ConflictPolicy, JobOutcome, RunState
from ._events import (
DataStoreEvent,
@@ -84,7 +84,8 @@ from ._exceptions import (
SerializationError,
TaskLookupError,
)
-from ._structures import Job, JobInfo, JobResult, RetrySettings, Schedule, Task
+from ._retry import RetrySettings
+from ._structures import Job, JobInfo, JobResult, Schedule, Task
# Re-export imports, so they look like they live directly in this package
value: Any
diff --git a/src/apscheduler/_context.py b/src/apscheduler/_context.py
index cc5aff2..5edc310 100644
--- a/src/apscheduler/_context.py
+++ b/src/apscheduler/_context.py
@@ -7,16 +7,13 @@ if TYPE_CHECKING:
from ._structures import JobInfo
from .schedulers.async_ import AsyncScheduler
from .schedulers.sync import Scheduler
- from .workers.async_ import AsyncWorker
- from .workers.sync import Worker
#: The currently running (local) scheduler
-current_scheduler: ContextVar[Scheduler | AsyncScheduler | None] = ContextVar(
+current_scheduler: ContextVar[Scheduler | None] = ContextVar(
"current_scheduler", default=None
)
-#: The worker running the current job
-current_worker: ContextVar[Worker | AsyncWorker | None] = ContextVar(
- "current_worker", default=None
+current_async_scheduler: ContextVar[AsyncScheduler | None] = ContextVar(
+ "current_async_scheduler", default=None
)
#: Metadata about the current job
current_job: ContextVar[JobInfo] = ContextVar("job_info")
diff --git a/src/apscheduler/_converters.py b/src/apscheduler/_converters.py
index 3518d44..9e299dd 100644
--- a/src/apscheduler/_converters.py
+++ b/src/apscheduler/_converters.py
@@ -6,8 +6,6 @@ from enum import Enum
from typing import TypeVar
from uuid import UUID
-from . import abc
-
TEnum = TypeVar("TEnum", bound=Enum)
@@ -46,23 +44,3 @@ def as_enum(enum_class: type[TEnum]) -> Callable[[TEnum | str], TEnum]:
return value
return converter
-
-
-def as_async_eventbroker(
- value: abc.EventBroker | abc.AsyncEventBroker,
-) -> abc.AsyncEventBroker:
- if isinstance(value, abc.EventBroker):
- from apscheduler.eventbrokers.async_adapter import AsyncEventBrokerAdapter
-
- return AsyncEventBrokerAdapter(value)
-
- return value
-
-
-def as_async_datastore(value: abc.DataStore | abc.AsyncDataStore) -> abc.AsyncDataStore:
- if isinstance(value, abc.DataStore):
- from apscheduler.datastores.async_adapter import AsyncDataStoreAdapter
-
- return AsyncDataStoreAdapter(value)
-
- return value
diff --git a/src/apscheduler/_retry.py b/src/apscheduler/_retry.py
new file mode 100644
index 0000000..19cf51b
--- /dev/null
+++ b/src/apscheduler/_retry.py
@@ -0,0 +1,68 @@
+from __future__ import annotations
+
+import attrs
+from attr.validators import instance_of
+from tenacity import (
+ AsyncRetrying,
+ RetryCallState,
+ retry_if_exception_type,
+ stop_after_delay,
+ wait_exponential,
+)
+from tenacity.stop import stop_base
+from tenacity.wait import wait_base
+
+
+@attrs.define(kw_only=True, frozen=True)
+class RetrySettings:
+ """
+ Settings for retrying an operation with Tenacity.
+
+ :param stop: defines when to stop trying
+ :param wait: defines how long to wait between attempts
+ """
+
+ stop: stop_base = attrs.field(
+ validator=instance_of(stop_base),
+ default=stop_after_delay(60),
+ )
+ wait: wait_base = attrs.field(
+ validator=instance_of(wait_base),
+ default=wait_exponential(min=0.5, max=20),
+ )
+
+
+@attrs.define(kw_only=True, slots=False)
+class RetryMixin:
+ """
+ Mixin that provides support for retrying operations.
+
+ :param retry_settings: Tenacity settings for retrying operations in case of a
+        database connectivity problem
+ """
+
+ retry_settings: RetrySettings = attrs.field(default=RetrySettings())
+
+ @property
+ def _temporary_failure_exceptions(self) -> tuple[type[Exception]]:
+ """
+ Tuple of exception classes which indicate that the operation should be retried.
+
+ """
+ return ()
+
+ def _retry(self) -> AsyncRetrying:
+ def after_attempt(self, retry_state: RetryCallState) -> None:
+ self._logger.warning(
+ "Temporary data store error (attempt %d): %s",
+ retry_state.attempt_number,
+ retry_state.outcome.exception(),
+ )
+
+ return AsyncRetrying(
+ stop=self.retry_settings.stop,
+ wait=self.retry_settings.wait,
+ retry=retry_if_exception_type(self._temporary_failure_exceptions),
+ after=after_attempt,
+ reraise=True,
+ )
diff --git a/src/apscheduler/_structures.py b/src/apscheduler/_structures.py
index 0a959e4..eb26a3a 100644
--- a/src/apscheduler/_structures.py
+++ b/src/apscheduler/_structures.py
@@ -7,9 +7,6 @@ from typing import TYPE_CHECKING, Any
from uuid import UUID, uuid4
import attrs
-import tenacity.stop
-import tenacity.wait
-from attrs.validators import instance_of
from ._converters import as_enum, as_timedelta
from ._enums import CoalescePolicy, JobOutcome
@@ -43,6 +40,7 @@ class Task:
id: str
func: Callable = attrs.field(eq=False, order=False)
+ executor: str = attrs.field(eq=False)
max_running_jobs: int | None = attrs.field(eq=False, order=False, default=None)
misfire_grace_time: timedelta | None = attrs.field(
eq=False, order=False, default=None
@@ -339,22 +337,3 @@ class JobResult:
)
return cls(**marshalled)
-
-
-@attrs.define(kw_only=True, frozen=True)
-class RetrySettings:
- """
- Settings for retrying an operation with Tenacity.
-
- :param stop: defines when to stop trying
- :param wait: defines how long to wait between attempts
- """
-
- stop: tenacity.stop.stop_base = attrs.field(
- validator=instance_of(tenacity.stop.stop_base),
- default=tenacity.stop_after_delay(60),
- )
- wait: tenacity.wait.wait_base = attrs.field(
- validator=instance_of(tenacity.wait.wait_base),
- default=tenacity.wait_exponential(min=0.5, max=20),
- )
diff --git a/src/apscheduler/_worker.py b/src/apscheduler/_worker.py
new file mode 100644
index 0000000..8c95d16
--- /dev/null
+++ b/src/apscheduler/_worker.py
@@ -0,0 +1,189 @@
+from __future__ import annotations
+
+from collections.abc import Mapping
+from contextlib import AsyncExitStack
+from datetime import datetime, timezone
+from logging import Logger, getLogger
+from typing import Callable
+from uuid import UUID
+
+import anyio
+import attrs
+from anyio import create_task_group, get_cancelled_exc_class, move_on_after
+from anyio.abc import CancelScope
+
+from ._context import current_job
+from ._enums import JobOutcome, RunState
+from ._events import JobAdded, JobReleased, WorkerStarted, WorkerStopped
+from ._structures import Job, JobInfo, JobResult
+from ._validators import positive_integer
+from .abc import DataStore, EventBroker, JobExecutor
+
+
+@attrs.define(eq=False, kw_only=True)
+class Worker:
+ """
+ Runs jobs locally in a task group.
+
+ :param max_concurrent_jobs: Maximum number of jobs the worker will run at once
+ """
+
+ job_executors: Mapping[str, JobExecutor] = attrs.field(kw_only=True)
+ max_concurrent_jobs: int = attrs.field(
+ kw_only=True, validator=positive_integer, default=100
+ )
+ logger: Logger = attrs.field(kw_only=True, default=getLogger(__name__))
+
+ _data_store: DataStore = attrs.field(init=False)
+ _event_broker: EventBroker = attrs.field(init=False)
+ _identity: str = attrs.field(init=False)
+ _state: RunState = attrs.field(init=False, default=RunState.stopped)
+ _wakeup_event: anyio.Event = attrs.field(init=False)
+ _acquired_jobs: set[Job] = attrs.field(init=False, factory=set)
+ _running_jobs: set[UUID] = attrs.field(init=False, factory=set)
+
+ async def start(
+ self,
+ exit_stack: AsyncExitStack,
+ data_store: DataStore,
+ event_broker: EventBroker,
+ identity: str,
+ ) -> None:
+ self._data_store = data_store
+ self._event_broker = event_broker
+ self._identity = identity
+ self._state = RunState.started
+ self._wakeup_event = anyio.Event()
+
+ # Start the job executors
+ for job_executor in self.job_executors.values():
+ await job_executor.start(exit_stack)
+
+ # Start the worker in a background task
+ task_group = await exit_stack.enter_async_context(create_task_group())
+ task_group.start_soon(self._run)
+
+ # Stop the worker when the exit stack unwinds
+ exit_stack.callback(lambda: self._wakeup_event.set())
+ exit_stack.callback(setattr, self, "_state", RunState.stopped)
+
+ # Wake up the worker if the data store emits a significant job event
+ exit_stack.enter_context(
+ self._event_broker.subscribe(
+ lambda event: self._wakeup_event.set(), {JobAdded}
+ )
+ )
+
+ # Signal that the worker has started
+ await self._event_broker.publish_local(WorkerStarted())
+
+ async def _run(self) -> None:
+ """Run the worker until it is explicitly stopped."""
+ exception: BaseException | None = None
+ try:
+ async with create_task_group() as tg:
+ while self._state is RunState.started:
+ limit = self.max_concurrent_jobs - len(self._running_jobs)
+ jobs = await self._data_store.acquire_jobs(self._identity, limit)
+ for job in jobs:
+ task = await self._data_store.get_task(job.task_id)
+ self._running_jobs.add(job.id)
+ tg.start_soon(self._run_job, job, task.func, task.executor)
+
+ await self._wakeup_event.wait()
+ self._wakeup_event = anyio.Event()
+ except get_cancelled_exc_class():
+ pass
+ except BaseException as exc:
+ exception = exc
+ raise
+ finally:
+ if not exception:
+ self.logger.info("Worker stopped")
+ elif isinstance(exception, Exception):
+ self.logger.exception("Worker crashed")
+ elif exception:
+ self.logger.info(
+ f"Worker stopped due to {exception.__class__.__name__}"
+ )
+
+ with move_on_after(3, shield=True):
+ await self._event_broker.publish_local(
+ WorkerStopped(exception=exception)
+ )
+
+ async def _run_job(self, job: Job, func: Callable, executor: str) -> None:
+ try:
+ # Check if the job started before the deadline
+ start_time = datetime.now(timezone.utc)
+ if job.start_deadline is not None and start_time > job.start_deadline:
+ result = JobResult.from_job(
+ job,
+ outcome=JobOutcome.missed_start_deadline,
+ finished_at=start_time,
+ )
+ await self._data_store.release_job(self._identity, job.task_id, result)
+ await self._event_broker.publish(
+ JobReleased.from_result(result, self._identity)
+ )
+ return
+
+ try:
+ job_executor = self.job_executors[executor]
+ except KeyError:
+ return
+
+ token = current_job.set(JobInfo.from_job(job))
+ try:
+ retval = await job_executor.run_job(func, job)
+ except get_cancelled_exc_class():
+ self.logger.info("Job %s was cancelled", job.id)
+ with CancelScope(shield=True):
+ result = JobResult.from_job(
+ job,
+ outcome=JobOutcome.cancelled,
+ )
+ await self._data_store.release_job(
+ self._identity, job.task_id, result
+ )
+ await self._event_broker.publish(
+ JobReleased.from_result(result, self._identity)
+ )
+ except BaseException as exc:
+ if isinstance(exc, Exception):
+ self.logger.exception("Job %s raised an exception", job.id)
+ else:
+ self.logger.error(
+ "Job %s was aborted due to %s", job.id, exc.__class__.__name__
+ )
+
+ result = JobResult.from_job(
+ job,
+ JobOutcome.error,
+ exception=exc,
+ )
+ await self._data_store.release_job(
+ self._identity,
+ job.task_id,
+ result,
+ )
+ await self._event_broker.publish(
+ JobReleased.from_result(result, self._identity)
+ )
+ if not isinstance(exc, Exception):
+ raise
+ else:
+ self.logger.info("Job %s completed successfully", job.id)
+ result = JobResult.from_job(
+ job,
+ JobOutcome.success,
+ return_value=retval,
+ )
+ await self._data_store.release_job(self._identity, job.task_id, result)
+ await self._event_broker.publish(
+ JobReleased.from_result(result, self._identity)
+ )
+ finally:
+ current_job.reset(token)
+ finally:
+ self._running_jobs.remove(job.id)
diff --git a/src/apscheduler/abc.py b/src/apscheduler/abc.py
index 74920e5..bc30ddc 100644
--- a/src/apscheduler/abc.py
+++ b/src/apscheduler/abc.py
@@ -1,6 +1,7 @@
from __future__ import annotations
from abc import ABCMeta, abstractmethod
+from contextlib import AsyncExitStack
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator
from uuid import UUID
@@ -91,63 +92,20 @@ class Subscription(metaclass=ABCMeta):
"""
-class EventSource(metaclass=ABCMeta):
- """
- Interface for objects that can deliver notifications to interested subscribers.
- """
-
- @abstractmethod
- def subscribe(
- self,
- callback: Callable[[Event], Any],
- event_types: Iterable[type[Event]] | None = None,
- *,
- one_shot: bool = False,
- ) -> Subscription:
- """
- Subscribe to events from this event source.
-
- :param callback: callable to be called with the event object when an event is
- published
- :param event_types: an iterable of concrete Event classes to subscribe to
- :param one_shot: if ``True``, automatically unsubscribe after the first matching
- event
- """
-
-
-class EventBroker(EventSource):
+class EventBroker(metaclass=ABCMeta):
"""
Interface for objects that can be used to publish notifications to interested
subscribers.
"""
@abstractmethod
- def start(self) -> None:
- pass
-
- @abstractmethod
- def stop(self, *, force: bool = False) -> None:
- pass
-
- @abstractmethod
- def publish(self, event: Event) -> None:
- """Publish an event."""
-
- @abstractmethod
- def publish_local(self, event: Event) -> None:
- """Publish an event, but only to local subscribers."""
-
-
-class AsyncEventBroker(EventSource):
- """Asynchronous version of :class:`EventBroker`. Expected to work on asyncio."""
-
- @abstractmethod
- async def start(self) -> None:
- pass
+ async def start(self, exit_stack: AsyncExitStack) -> None:
+ """
+ Start the event broker.
- @abstractmethod
- async def stop(self, *, force: bool = False) -> None:
- pass
+ :param exit_stack: an asynchronous exit stack which will be processed when the
+ scheduler is shut down
+ """
@abstractmethod
async def publish(self, event: Event) -> None:
@@ -157,185 +115,45 @@ class AsyncEventBroker(EventSource):
async def publish_local(self, event: Event) -> None:
"""Publish an event, but only to local subscribers."""
-
-class DataStore(metaclass=ABCMeta):
- @abstractmethod
- def start(self, event_broker: EventBroker) -> None:
- pass
-
- @abstractmethod
- def stop(self, *, force: bool = False) -> None:
- pass
-
- @property
- @abstractmethod
- def events(self) -> EventSource:
- pass
-
- @abstractmethod
- def add_task(self, task: Task) -> None:
- """
- Add the given task to the store.
-
- If a task with the same ID already exists, it replaces the old one but does NOT
- affect task accounting (# of running jobs).
-
- :param task: the task to be added
- """
-
- @abstractmethod
- def remove_task(self, task_id: str) -> None:
- """
- Remove the task with the given ID.
-
- :param task_id: ID of the task to be removed
- :raises TaskLookupError: if no matching task was found
- """
-
- @abstractmethod
- def get_task(self, task_id: str) -> Task:
- """
- Get an existing task definition.
-
- :param task_id: ID of the task to be returned
- :return: the matching task
- :raises TaskLookupError: if no matching task was found
- """
-
- @abstractmethod
- def get_tasks(self) -> list[Task]:
- """
- Get all the tasks in this store.
-
- :return: a list of tasks, sorted by ID
- """
-
- @abstractmethod
- def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]:
- """
- Get schedules from the data store.
-
- :param ids: a specific set of schedule IDs to return, or ``None`` to return all
- schedules
- :return: the list of matching schedules, in unspecified order
- """
-
- @abstractmethod
- def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None:
- """
- Add or update the given schedule in the data store.
-
- :param schedule: schedule to be added
- :param conflict_policy: policy that determines what to do if there is an
- existing schedule with the same ID
- """
-
- @abstractmethod
- def remove_schedules(self, ids: Iterable[str]) -> None:
- """
- Remove schedules from the data store.
-
- :param ids: a specific set of schedule IDs to remove
- """
-
- @abstractmethod
- def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]:
- """
- Acquire unclaimed due schedules for processing.
-
- This method claims up to the requested number of schedules for the given
- scheduler and returns them.
-
- :param scheduler_id: unique identifier of the scheduler
- :param limit: maximum number of schedules to claim
- :return: the list of claimed schedules
- """
-
- @abstractmethod
- def release_schedules(self, scheduler_id: str, schedules: list[Schedule]) -> None:
- """
- Release the claims on the given schedules and update them on the store.
-
- :param scheduler_id: unique identifier of the scheduler
- :param schedules: the previously claimed schedules
- """
-
- @abstractmethod
- def get_next_schedule_run_time(self) -> datetime | None:
- """
- Return the earliest upcoming run time of all the schedules in the store, or
- ``None`` if there are no active schedules.
- """
-
@abstractmethod
- def add_job(self, job: Job) -> None:
- """
- Add a job to be executed by an eligible worker.
-
- :param job: the job object
- """
-
- @abstractmethod
- def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]:
- """
- Get the list of pending jobs.
-
- :param ids: a specific set of job IDs to return, or ``None`` to return all jobs
- :return: the list of matching pending jobs, in the order they will be given to
- workers
- """
-
- @abstractmethod
- def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]:
+ def subscribe(
+ self,
+ callback: Callable[[Event], Any],
+ event_types: Iterable[type[Event]] | None = None,
+ *,
+ is_async: bool = True,
+ one_shot: bool = False,
+ ) -> Subscription:
"""
- Acquire unclaimed jobs for execution.
-
- This method claims up to the requested number of jobs for the given worker and
- returns them.
+ Subscribe to events from this event broker.
- :param worker_id: unique identifier of the worker
- :param limit: maximum number of jobs to claim and return
- :return: the list of claimed jobs
+ :param callback: callable to be called with the event object when an event is
+ published
+ :param event_types: an iterable of concrete Event classes to subscribe to
+ :param is_async: ``True`` if the (synchronous) callback should be called on the
+ event loop thread, ``False`` if it should be called in a worker thread.
+ If the callback is a coroutine function, this flag is ignored.
+ :param one_shot: if ``True``, automatically unsubscribe after the first matching
+ event
"""
- @abstractmethod
- def release_job(self, worker_id: str, task_id: str, result: JobResult) -> None:
- """
- Release the claim on the given job and record the result.
- :param worker_id: unique identifier of the worker
- :param task_id: the job's task ID
- :param result: the result of the job
- """
+class DataStore(metaclass=ABCMeta):
+    """Asynchronous data store interface. Expected to work on asyncio."""
@abstractmethod
- def get_job_result(self, job_id: UUID) -> JobResult | None:
+ async def start(
+ self, exit_stack: AsyncExitStack, event_broker: EventBroker
+ ) -> None:
"""
- Retrieve the result of a job.
-
- The result is removed from the store after retrieval.
+        Start the data store.
- :param job_id: the identifier of the job
- :return: the result, or ``None`` if the result was not found
+ :param exit_stack: an asynchronous exit stack which will be processed when the
+ scheduler is shut down
+ :param event_broker: the event broker shared between the scheduler, worker (if
+ any) and this data store
"""
-
-class AsyncDataStore(metaclass=ABCMeta):
- """Asynchronous version of :class:`DataStore`. Expected to work on asyncio."""
-
- @abstractmethod
- async def start(self, event_broker: AsyncEventBroker) -> None:
- pass
-
- @abstractmethod
- async def stop(self, *, force: bool = False) -> None:
- pass
-
- @property
- @abstractmethod
- def events(self) -> EventSource:
- pass
-
@abstractmethod
async def add_task(self, task: Task) -> None:
"""
@@ -488,3 +306,23 @@ class AsyncDataStore(metaclass=ABCMeta):
:param job_id: the identifier of the job
:return: the result, or ``None`` if the result was not found
"""
+
+
+class JobExecutor(metaclass=ABCMeta):
+ async def start(self, exit_stack: AsyncExitStack) -> None:
+ """
+ Start the job executor.
+
+ :param exit_stack: an asynchronous exit stack which will be processed when the
+ scheduler is shut down
+ """
+
+ @abstractmethod
+ async def run_job(self, func: Callable[..., Any], job: Job) -> Any:
+ """
+
+ :param func:
+ :param job:
+ :return: the return value of ``func`` (potentially awaiting on the returned
+        awaitable, if any)
+ """
diff --git a/src/apscheduler/datastores/async_adapter.py b/src/apscheduler/datastores/async_adapter.py
deleted file mode 100644
index d16ae56..0000000
--- a/src/apscheduler/datastores/async_adapter.py
+++ /dev/null
@@ -1,101 +0,0 @@
-from __future__ import annotations
-
-import sys
-from datetime import datetime
-from typing import Iterable
-from uuid import UUID
-
-import attrs
-from anyio import to_thread
-from anyio.from_thread import BlockingPortal
-
-from .._enums import ConflictPolicy
-from .._structures import Job, JobResult, Schedule, Task
-from ..abc import AsyncEventBroker, DataStore
-from ..eventbrokers.async_adapter import AsyncEventBrokerAdapter, SyncEventBrokerAdapter
-from .base import BaseAsyncDataStore
-
-
-@attrs.define(eq=False)
-class AsyncDataStoreAdapter(BaseAsyncDataStore):
- original: DataStore
- _portal: BlockingPortal = attrs.field(init=False)
-
- async def start(self, event_broker: AsyncEventBroker) -> None:
- await super().start(event_broker)
-
- self._portal = BlockingPortal()
- await self._portal.__aenter__()
-
- if isinstance(event_broker, AsyncEventBrokerAdapter):
- sync_event_broker = event_broker.original
- else:
- sync_event_broker = SyncEventBrokerAdapter(event_broker, self._portal)
-
- try:
- await to_thread.run_sync(lambda: self.original.start(sync_event_broker))
- except BaseException:
- await self._portal.__aexit__(*sys.exc_info())
- raise
-
- async def stop(self, *, force: bool = False) -> None:
- try:
- await to_thread.run_sync(lambda: self.original.stop(force=force))
- finally:
- await self._portal.__aexit__(None, None, None)
- await super().stop(force=force)
-
- async def add_task(self, task: Task) -> None:
- await to_thread.run_sync(self.original.add_task, task)
-
- async def remove_task(self, task_id: str) -> None:
- await to_thread.run_sync(self.original.remove_task, task_id)
-
- async def get_task(self, task_id: str) -> Task:
- return await to_thread.run_sync(self.original.get_task, task_id)
-
- async def get_tasks(self) -> list[Task]:
- return await to_thread.run_sync(self.original.get_tasks)
-
- async def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]:
- return await to_thread.run_sync(self.original.get_schedules, ids)
-
- async def add_schedule(
- self, schedule: Schedule, conflict_policy: ConflictPolicy
- ) -> None:
- await to_thread.run_sync(self.original.add_schedule, schedule, conflict_policy)
-
- async def remove_schedules(self, ids: Iterable[str]) -> None:
- await to_thread.run_sync(self.original.remove_schedules, ids)
-
- async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]:
- return await to_thread.run_sync(
- self.original.acquire_schedules, scheduler_id, limit
- )
-
- async def release_schedules(
- self, scheduler_id: str, schedules: list[Schedule]
- ) -> None:
- await to_thread.run_sync(
- self.original.release_schedules, scheduler_id, schedules
- )
-
- async def get_next_schedule_run_time(self) -> datetime | None:
- return await to_thread.run_sync(self.original.get_next_schedule_run_time)
-
- async def add_job(self, job: Job) -> None:
- await to_thread.run_sync(self.original.add_job, job)
-
- async def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]:
- return await to_thread.run_sync(self.original.get_jobs, ids)
-
- async def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]:
- return await to_thread.run_sync(self.original.acquire_jobs, worker_id, limit)
-
- async def release_job(
- self, worker_id: str, task_id: str, result: JobResult
- ) -> None:
- await to_thread.run_sync(self.original.release_job, worker_id, task_id, result)
-
- async def get_job_result(self, job_id: UUID) -> JobResult | None:
- return await to_thread.run_sync(self.original.get_job_result, job_id)
diff --git a/src/apscheduler/datastores/async_sqlalchemy.py b/src/apscheduler/datastores/async_sqlalchemy.py
deleted file mode 100644
index 0c165b8..0000000
--- a/src/apscheduler/datastores/async_sqlalchemy.py
+++ /dev/null
@@ -1,602 +0,0 @@
-from __future__ import annotations
-
-from collections import defaultdict
-from datetime import datetime, timedelta, timezone
-from typing import Any, Iterable
-from uuid import UUID
-
-import anyio
-import attrs
-import sniffio
-import tenacity
-from sqlalchemy import and_, bindparam, or_, select
-from sqlalchemy.engine import URL, Result
-from sqlalchemy.exc import IntegrityError, InterfaceError
-from sqlalchemy.ext.asyncio import create_async_engine
-from sqlalchemy.ext.asyncio.engine import AsyncEngine
-from sqlalchemy.sql.ddl import DropTable
-from sqlalchemy.sql.elements import BindParameter
-
-from .._enums import ConflictPolicy
-from .._events import (
- DataStoreEvent,
- JobAcquired,
- JobAdded,
- JobDeserializationFailed,
- ScheduleAdded,
- ScheduleDeserializationFailed,
- ScheduleRemoved,
- ScheduleUpdated,
- TaskAdded,
- TaskRemoved,
- TaskUpdated,
-)
-from .._exceptions import ConflictingIdError, SerializationError, TaskLookupError
-from .._structures import Job, JobResult, Schedule, Task
-from ..abc import AsyncEventBroker
-from ..marshalling import callable_to_ref
-from .base import BaseAsyncDataStore
-from .sqlalchemy import _BaseSQLAlchemyDataStore
-
-
-@attrs.define(eq=False)
-class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, BaseAsyncDataStore):
- """
- Uses a relational database to store data.
-
- When started, this data store creates the appropriate tables on the given database
- if they're not already present.
-
- Operations are retried (in accordance to ``retry_settings``) when an operation
- raises :exc:`sqlalchemy.OperationalError`.
-
- This store has been tested to work with PostgreSQL (asyncpg driver) and MySQL
- (asyncmy driver).
-
- :param engine: an asynchronous SQLAlchemy engine
- :param schema: a database schema name to use, if not the default
- :param serializer: the serializer used to (de)serialize tasks, schedules and jobs
- for storage
- :param lock_expiration_delay: maximum amount of time (in seconds) that a scheduler
- or worker can keep a lock on a schedule or task
- :param retry_settings: Tenacity settings for retrying operations in case of a
- database connecitivty problem
- :param start_from_scratch: erase all existing data during startup (useful for test
- suites)
- """
-
- engine: AsyncEngine
-
- @classmethod
- def from_url(cls, url: str | URL, **options) -> AsyncSQLAlchemyDataStore:
- """
- Create a new asynchronous SQLAlchemy data store.
-
- :param url: an SQLAlchemy URL to pass to :func:`~sqlalchemy.create_engine`
- (must use an async dialect like ``asyncpg`` or ``asyncmy``)
- :param kwargs: keyword arguments to pass to the initializer of this class
- :return: the newly created data store
-
- """
- engine = create_async_engine(url, future=True)
- return cls(engine, **options)
-
- def _retry(self) -> tenacity.AsyncRetrying:
- # OSError is raised by asyncpg if it can't connect
- return tenacity.AsyncRetrying(
- stop=self.retry_settings.stop,
- wait=self.retry_settings.wait,
- retry=tenacity.retry_if_exception_type((InterfaceError, OSError)),
- after=self._after_attempt,
- sleep=anyio.sleep,
- reraise=True,
- )
-
- async def start(self, event_broker: AsyncEventBroker) -> None:
- await super().start(event_broker)
-
- asynclib = sniffio.current_async_library() or "(unknown)"
- if asynclib != "asyncio":
- raise RuntimeError(
- f"This data store requires asyncio; currently running: {asynclib}"
- )
-
- # Verify that the schema is in place
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- if self.start_from_scratch:
- for table in self._metadata.sorted_tables:
- await conn.execute(DropTable(table, if_exists=True))
-
- await conn.run_sync(self._metadata.create_all)
- query = select(self.t_metadata.c.schema_version)
- result = await conn.execute(query)
- version = result.scalar()
- if version is None:
- await conn.execute(
- self.t_metadata.insert(values={"schema_version": 1})
- )
- elif version > 1:
- raise RuntimeError(
- f"Unexpected schema version ({version}); "
- f"only version 1 is supported by this version of "
- f"APScheduler"
- )
-
- async def _deserialize_schedules(self, result: Result) -> list[Schedule]:
- schedules: list[Schedule] = []
- for row in result:
- try:
- schedules.append(Schedule.unmarshal(self.serializer, row._asdict()))
- except SerializationError as exc:
- await self._events.publish(
- ScheduleDeserializationFailed(schedule_id=row["id"], exception=exc)
- )
-
- return schedules
-
- async def _deserialize_jobs(self, result: Result) -> list[Job]:
- jobs: list[Job] = []
- for row in result:
- try:
- jobs.append(Job.unmarshal(self.serializer, row._asdict()))
- except SerializationError as exc:
- await self._events.publish(
- JobDeserializationFailed(job_id=row["id"], exception=exc)
- )
-
- return jobs
-
- async def add_task(self, task: Task) -> None:
- insert = self.t_tasks.insert().values(
- id=task.id,
- func=callable_to_ref(task.func),
- max_running_jobs=task.max_running_jobs,
- misfire_grace_time=task.misfire_grace_time,
- )
- try:
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- await conn.execute(insert)
- except IntegrityError:
- update = (
- self.t_tasks.update()
- .values(
- func=callable_to_ref(task.func),
- max_running_jobs=task.max_running_jobs,
- misfire_grace_time=task.misfire_grace_time,
- )
- .where(self.t_tasks.c.id == task.id)
- )
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- await conn.execute(update)
-
- await self._events.publish(TaskUpdated(task_id=task.id))
- else:
- await self._events.publish(TaskAdded(task_id=task.id))
-
- async def remove_task(self, task_id: str) -> None:
- delete = self.t_tasks.delete().where(self.t_tasks.c.id == task_id)
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- result = await conn.execute(delete)
- if result.rowcount == 0:
- raise TaskLookupError(task_id)
- else:
- await self._events.publish(TaskRemoved(task_id=task_id))
-
- async def get_task(self, task_id: str) -> Task:
- query = select(
- [
- self.t_tasks.c.id,
- self.t_tasks.c.func,
- self.t_tasks.c.max_running_jobs,
- self.t_tasks.c.state,
- self.t_tasks.c.misfire_grace_time,
- ]
- ).where(self.t_tasks.c.id == task_id)
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- result = await conn.execute(query)
- row = result.first()
-
- if row:
- return Task.unmarshal(self.serializer, row._asdict())
- else:
- raise TaskLookupError(task_id)
-
- async def get_tasks(self) -> list[Task]:
- query = select(
- [
- self.t_tasks.c.id,
- self.t_tasks.c.func,
- self.t_tasks.c.max_running_jobs,
- self.t_tasks.c.state,
- self.t_tasks.c.misfire_grace_time,
- ]
- ).order_by(self.t_tasks.c.id)
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- result = await conn.execute(query)
- tasks = [
- Task.unmarshal(self.serializer, row._asdict()) for row in result
- ]
- return tasks
-
- async def add_schedule(
- self, schedule: Schedule, conflict_policy: ConflictPolicy
- ) -> None:
- event: DataStoreEvent
- values = schedule.marshal(self.serializer)
- insert = self.t_schedules.insert().values(**values)
- try:
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- await conn.execute(insert)
- except IntegrityError:
- if conflict_policy is ConflictPolicy.exception:
- raise ConflictingIdError(schedule.id) from None
- elif conflict_policy is ConflictPolicy.replace:
- del values["id"]
- update = (
- self.t_schedules.update()
- .where(self.t_schedules.c.id == schedule.id)
- .values(**values)
- )
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- await conn.execute(update)
-
- event = ScheduleUpdated(
- schedule_id=schedule.id, next_fire_time=schedule.next_fire_time
- )
- await self._events.publish(event)
- else:
- event = ScheduleAdded(
- schedule_id=schedule.id, next_fire_time=schedule.next_fire_time
- )
- await self._events.publish(event)
-
- async def remove_schedules(self, ids: Iterable[str]) -> None:
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- delete = self.t_schedules.delete().where(
- self.t_schedules.c.id.in_(ids)
- )
- if self._supports_update_returning:
- delete = delete.returning(self.t_schedules.c.id)
- removed_ids: Iterable[str] = [
- row[0] for row in await conn.execute(delete)
- ]
- else:
- # TODO: actually check which rows were deleted?
- await conn.execute(delete)
- removed_ids = ids
-
- for schedule_id in removed_ids:
- await self._events.publish(ScheduleRemoved(schedule_id=schedule_id))
-
- async def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]:
- query = self.t_schedules.select().order_by(self.t_schedules.c.id)
- if ids:
- query = query.where(self.t_schedules.c.id.in_(ids))
-
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- result = await conn.execute(query)
- return await self._deserialize_schedules(result)
-
- async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]:
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- now = datetime.now(timezone.utc)
- acquired_until = now + timedelta(seconds=self.lock_expiration_delay)
- schedules_cte = (
- select(self.t_schedules.c.id)
- .where(
- and_(
- self.t_schedules.c.next_fire_time.isnot(None),
- self.t_schedules.c.next_fire_time <= now,
- or_(
- self.t_schedules.c.acquired_until.is_(None),
- self.t_schedules.c.acquired_until < now,
- ),
- )
- )
- .order_by(self.t_schedules.c.next_fire_time)
- .limit(limit)
- .with_for_update(skip_locked=True)
- .cte()
- )
- subselect = select([schedules_cte.c.id])
- update = (
- self.t_schedules.update()
- .where(self.t_schedules.c.id.in_(subselect))
- .values(acquired_by=scheduler_id, acquired_until=acquired_until)
- )
- if self._supports_update_returning:
- update = update.returning(*self.t_schedules.columns)
- result = await conn.execute(update)
- else:
- await conn.execute(update)
- query = self.t_schedules.select().where(
- and_(self.t_schedules.c.acquired_by == scheduler_id)
- )
- result = await conn.execute(query)
-
- schedules = await self._deserialize_schedules(result)
-
- return schedules
-
- async def release_schedules(
- self, scheduler_id: str, schedules: list[Schedule]
- ) -> None:
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- update_events: list[ScheduleUpdated] = []
- finished_schedule_ids: list[str] = []
- update_args: list[dict[str, Any]] = []
- for schedule in schedules:
- if schedule.next_fire_time is not None:
- try:
- serialized_trigger = self.serializer.serialize(
- schedule.trigger
- )
- except SerializationError:
- self._logger.exception(
- "Error serializing trigger for schedule %r – "
- "removing from data store",
- schedule.id,
- )
- finished_schedule_ids.append(schedule.id)
- continue
-
- update_args.append(
- {
- "p_id": schedule.id,
- "p_trigger": serialized_trigger,
- "p_next_fire_time": schedule.next_fire_time,
- }
- )
- else:
- finished_schedule_ids.append(schedule.id)
-
- # Update schedules that have a next fire time
- if update_args:
- p_id: BindParameter = bindparam("p_id")
- p_trigger: BindParameter = bindparam("p_trigger")
- p_next_fire_time: BindParameter = bindparam("p_next_fire_time")
- update = (
- self.t_schedules.update()
- .where(
- and_(
- self.t_schedules.c.id == p_id,
- self.t_schedules.c.acquired_by == scheduler_id,
- )
- )
- .values(
- trigger=p_trigger,
- next_fire_time=p_next_fire_time,
- acquired_by=None,
- acquired_until=None,
- )
- )
- next_fire_times = {
- arg["p_id"]: arg["p_next_fire_time"] for arg in update_args
- }
- # TODO: actually check which rows were updated?
- await conn.execute(update, update_args)
- updated_ids = list(next_fire_times)
-
- for schedule_id in updated_ids:
- event = ScheduleUpdated(
- schedule_id=schedule_id,
- next_fire_time=next_fire_times[schedule_id],
- )
- update_events.append(event)
-
- # Remove schedules that have no next fire time or failed to
- # serialize
- if finished_schedule_ids:
- delete = self.t_schedules.delete().where(
- self.t_schedules.c.id.in_(finished_schedule_ids)
- )
- await conn.execute(delete)
-
- for event in update_events:
- await self._events.publish(event)
-
- for schedule_id in finished_schedule_ids:
- await self._events.publish(ScheduleRemoved(schedule_id=schedule_id))
-
- async def get_next_schedule_run_time(self) -> datetime | None:
- statenent = (
- select(self.t_schedules.c.next_fire_time)
- .where(self.t_schedules.c.next_fire_time.isnot(None))
- .order_by(self.t_schedules.c.next_fire_time)
- .limit(1)
- )
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- result = await conn.execute(statenent)
- return result.scalar()
-
- async def add_job(self, job: Job) -> None:
- marshalled = job.marshal(self.serializer)
- insert = self.t_jobs.insert().values(**marshalled)
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- await conn.execute(insert)
-
- event = JobAdded(
- job_id=job.id,
- task_id=job.task_id,
- schedule_id=job.schedule_id,
- tags=job.tags,
- )
- await self._events.publish(event)
-
- async def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]:
- query = self.t_jobs.select().order_by(self.t_jobs.c.id)
- if ids:
- job_ids = [job_id for job_id in ids]
- query = query.where(self.t_jobs.c.id.in_(job_ids))
-
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- result = await conn.execute(query)
- return await self._deserialize_jobs(result)
-
- async def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]:
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- now = datetime.now(timezone.utc)
- acquired_until = now + timedelta(seconds=self.lock_expiration_delay)
- query = (
- self.t_jobs.select()
- .join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id)
- .where(
- or_(
- self.t_jobs.c.acquired_until.is_(None),
- self.t_jobs.c.acquired_until < now,
- )
- )
- .order_by(self.t_jobs.c.created_at)
- .with_for_update(skip_locked=True)
- .limit(limit)
- )
-
- result = await conn.execute(query)
- if not result:
- return []
-
- # Mark the jobs as acquired by this worker
- jobs = await self._deserialize_jobs(result)
- task_ids: set[str] = {job.task_id for job in jobs}
-
- # Retrieve the limits
- query = select(
- [
- self.t_tasks.c.id,
- self.t_tasks.c.max_running_jobs
- - self.t_tasks.c.running_jobs,
- ]
- ).where(
- self.t_tasks.c.max_running_jobs.isnot(None),
- self.t_tasks.c.id.in_(task_ids),
- )
- result = await conn.execute(query)
- job_slots_left: dict[str, int] = dict(result.fetchall())
-
- # Filter out jobs that don't have free slots
- acquired_jobs: list[Job] = []
- increments: dict[str, int] = defaultdict(lambda: 0)
- for job in jobs:
- # Don't acquire the job if there are no free slots left
- slots_left = job_slots_left.get(job.task_id)
- if slots_left == 0:
- continue
- elif slots_left is not None:
- job_slots_left[job.task_id] -= 1
-
- acquired_jobs.append(job)
- increments[job.task_id] += 1
-
- if acquired_jobs:
- # Mark the acquired jobs as acquired by this worker
- acquired_job_ids = [job.id for job in acquired_jobs]
- update = (
- self.t_jobs.update()
- .values(
- acquired_by=worker_id, acquired_until=acquired_until
- )
- .where(self.t_jobs.c.id.in_(acquired_job_ids))
- )
- await conn.execute(update)
-
- # Increment the running job counters on each task
- p_id: BindParameter = bindparam("p_id")
- p_increment: BindParameter = bindparam("p_increment")
- params = [
- {"p_id": task_id, "p_increment": increment}
- for task_id, increment in increments.items()
- ]
- update = (
- self.t_tasks.update()
- .values(
- running_jobs=self.t_tasks.c.running_jobs + p_increment
- )
- .where(self.t_tasks.c.id == p_id)
- )
- await conn.execute(update, params)
-
- # Publish the appropriate events
- for job in acquired_jobs:
- await self._events.publish(JobAcquired(job_id=job.id, worker_id=worker_id))
-
- return acquired_jobs
-
- async def release_job(
- self, worker_id: str, task_id: str, result: JobResult
- ) -> None:
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- # Record the job result
- if result.expires_at > result.finished_at:
- marshalled = result.marshal(self.serializer)
- insert = self.t_job_results.insert().values(**marshalled)
- await conn.execute(insert)
-
- # Decrement the number of running jobs for this task
- update = (
- self.t_tasks.update()
- .values(running_jobs=self.t_tasks.c.running_jobs - 1)
- .where(self.t_tasks.c.id == task_id)
- )
- await conn.execute(update)
-
- # Delete the job
- delete = self.t_jobs.delete().where(
- self.t_jobs.c.id == result.job_id
- )
- await conn.execute(delete)
-
- async def get_job_result(self, job_id: UUID) -> JobResult | None:
- async for attempt in self._retry():
- with attempt:
- async with self.engine.begin() as conn:
- # Retrieve the result
- query = self.t_job_results.select().where(
- self.t_job_results.c.job_id == job_id
- )
- row = (await conn.execute(query)).first()
-
- # Delete the result
- delete = self.t_job_results.delete().where(
- self.t_job_results.c.job_id == job_id
- )
- await conn.execute(delete)
-
- return (
- JobResult.unmarshal(self.serializer, row._asdict())
- if row
- else None
- )
diff --git a/src/apscheduler/datastores/base.py b/src/apscheduler/datastores/base.py
index c05d28c..5c7ef7d 100644
--- a/src/apscheduler/datastores/base.py
+++ b/src/apscheduler/datastores/base.py
@@ -1,37 +1,47 @@
from __future__ import annotations
-from apscheduler.abc import (
- AsyncDataStore,
- AsyncEventBroker,
- DataStore,
- EventBroker,
- EventSource,
-)
+from contextlib import AsyncExitStack
+from logging import Logger, getLogger
+import attrs
+from .._retry import RetryMixin
+from ..abc import DataStore, EventBroker, Serializer
+from ..serializers.pickle import PickleSerializer
+
+
+@attrs.define(kw_only=True)
class BaseDataStore(DataStore):
- _events: EventBroker
+ """
+ Base class for data stores.
- def start(self, event_broker: EventBroker) -> None:
- self._events = event_broker
+ :param lock_expiration_delay: maximum amount of time (in seconds) that a scheduler
+ or worker can keep a lock on a schedule or task
+ """
- def stop(self, *, force: bool = False) -> None:
- del self._events
+ lock_expiration_delay: float = 30
+ _event_broker: EventBroker = attrs.field(init=False)
+ _logger: Logger = attrs.field(init=False)
- @property
- def events(self) -> EventSource:
- return self._events
+ async def start(
+ self, exit_stack: AsyncExitStack, event_broker: EventBroker
+ ) -> None:
+ self._event_broker = event_broker
+ def __attrs_post_init__(self):
+ self._logger = getLogger(self.__class__.__name__)
-class BaseAsyncDataStore(AsyncDataStore):
- _events: AsyncEventBroker
- async def start(self, event_broker: AsyncEventBroker) -> None:
- self._events = event_broker
+@attrs.define(kw_only=True)
+class BaseExternalDataStore(BaseDataStore, RetryMixin):
+ """
+ Base class for data stores using an external service such as a database.
- async def stop(self, *, force: bool = False) -> None:
- del self._events
+ :param serializer: the serializer used to (de)serialize tasks, schedules and jobs
+ for storage
+ :param start_from_scratch: erase all existing data during startup (useful for test
+ suites)
+ """
- @property
- def events(self) -> EventSource:
- return self._events
+ serializer: Serializer = attrs.field(factory=PickleSerializer)
+ start_from_scratch: bool = attrs.field(default=False)
diff --git a/src/apscheduler/datastores/memory.py b/src/apscheduler/datastores/memory.py
index a9ff3cb..fd7e90e 100644
--- a/src/apscheduler/datastores/memory.py
+++ b/src/apscheduler/datastores/memory.py
@@ -85,13 +85,9 @@ class MemoryDataStore(BaseDataStore):
"""
Stores scheduler data in memory, without serializing it.
- Can be shared between multiple schedulers and workers within the same event loop.
-
- :param lock_expiration_delay: maximum amount of time (in seconds) that a scheduler
- or worker can keep a lock on a schedule or task
+ Can be shared between multiple schedulers within the same event loop.
"""
- lock_expiration_delay: float = 30
_tasks: dict[str, TaskState] = attrs.Factory(dict)
_schedules: list[ScheduleState] = attrs.Factory(list)
_schedules_by_id: dict[str, ScheduleState] = attrs.Factory(dict)
@@ -115,41 +111,43 @@ class MemoryDataStore(BaseDataStore):
right_index = bisect_left(self._jobs, state)
return self._jobs.index(state, left_index, right_index + 1)
- def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]:
+ async def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]:
return [
state.schedule
for state in self._schedules
if ids is None or state.schedule.id in ids
]
- def add_task(self, task: Task) -> None:
+ async def add_task(self, task: Task) -> None:
task_exists = task.id in self._tasks
self._tasks[task.id] = TaskState(task)
if task_exists:
- self._events.publish(TaskUpdated(task_id=task.id))
+ await self._event_broker.publish(TaskUpdated(task_id=task.id))
else:
- self._events.publish(TaskAdded(task_id=task.id))
+ await self._event_broker.publish(TaskAdded(task_id=task.id))
- def remove_task(self, task_id: str) -> None:
+ async def remove_task(self, task_id: str) -> None:
try:
del self._tasks[task_id]
except KeyError:
raise TaskLookupError(task_id) from None
- self._events.publish(TaskRemoved(task_id=task_id))
+ await self._event_broker.publish(TaskRemoved(task_id=task_id))
- def get_task(self, task_id: str) -> Task:
+ async def get_task(self, task_id: str) -> Task:
try:
return self._tasks[task_id].task
except KeyError:
raise TaskLookupError(task_id) from None
- def get_tasks(self) -> list[Task]:
+ async def get_tasks(self) -> list[Task]:
return sorted(
(state.task for state in self._tasks.values()), key=lambda task: task.id
)
- def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None:
+ async def add_schedule(
+ self, schedule: Schedule, conflict_policy: ConflictPolicy
+ ) -> None:
old_state = self._schedules_by_id.get(schedule.id)
if old_state is not None:
if conflict_policy is ConflictPolicy.do_nothing:
@@ -175,17 +173,17 @@ class MemoryDataStore(BaseDataStore):
schedule_id=schedule.id, next_fire_time=schedule.next_fire_time
)
- self._events.publish(event)
+ await self._event_broker.publish(event)
- def remove_schedules(self, ids: Iterable[str]) -> None:
+ async def remove_schedules(self, ids: Iterable[str]) -> None:
for schedule_id in ids:
state = self._schedules_by_id.pop(schedule_id, None)
if state:
self._schedules.remove(state)
event = ScheduleRemoved(schedule_id=state.schedule.id)
- self._events.publish(event)
+ await self._event_broker.publish(event)
- def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]:
+ async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]:
now = datetime.now(timezone.utc)
schedules: list[Schedule] = []
for state in self._schedules:
@@ -206,7 +204,9 @@ class MemoryDataStore(BaseDataStore):
return schedules
- def release_schedules(self, scheduler_id: str, schedules: list[Schedule]) -> None:
+ async def release_schedules(
+ self, scheduler_id: str, schedules: list[Schedule]
+ ) -> None:
# Send update events for schedules that have a next time
finished_schedule_ids: list[str] = []
for s in schedules:
@@ -224,17 +224,17 @@ class MemoryDataStore(BaseDataStore):
event = ScheduleUpdated(
schedule_id=s.id, next_fire_time=s.next_fire_time
)
- self._events.publish(event)
+ await self._event_broker.publish(event)
else:
finished_schedule_ids.append(s.id)
# Remove schedules that didn't get a new next fire time
- self.remove_schedules(finished_schedule_ids)
+ await self.remove_schedules(finished_schedule_ids)
- def get_next_schedule_run_time(self) -> datetime | None:
+ async def get_next_schedule_run_time(self) -> datetime | None:
return self._schedules[0].next_fire_time if self._schedules else None
- def add_job(self, job: Job) -> None:
+ async def add_job(self, job: Job) -> None:
state = JobState(job)
self._jobs.append(state)
self._jobs_by_id[job.id] = state
@@ -246,15 +246,15 @@ class MemoryDataStore(BaseDataStore):
schedule_id=job.schedule_id,
tags=job.tags,
)
- self._events.publish(event)
+ await self._event_broker.publish(event)
- def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]:
+ async def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]:
if ids is not None:
ids = frozenset(ids)
return [state.job for state in self._jobs if ids is None or state.job.id in ids]
- def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]:
+ async def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]:
now = datetime.now(timezone.utc)
jobs: list[Job] = []
for _index, job_state in enumerate(self._jobs):
@@ -290,11 +290,15 @@ class MemoryDataStore(BaseDataStore):
# Publish the appropriate events
for job in jobs:
- self._events.publish(JobAcquired(job_id=job.id, worker_id=worker_id))
+ await self._event_broker.publish(
+ JobAcquired(job_id=job.id, worker_id=worker_id)
+ )
return jobs
- def release_job(self, worker_id: str, task_id: str, result: JobResult) -> None:
+ async def release_job(
+ self, worker_id: str, task_id: str, result: JobResult
+ ) -> None:
# Record the job result
if result.expires_at > result.finished_at:
self._job_results[result.job_id] = result
@@ -310,5 +314,5 @@ class MemoryDataStore(BaseDataStore):
index = self._find_job_index(job_state)
del self._jobs[index]
- def get_job_result(self, job_id: UUID) -> JobResult | None:
+ async def get_job_result(self, job_id: UUID) -> JobResult | None:
return self._job_results.pop(job_id, None)
diff --git a/src/apscheduler/datastores/mongodb.py b/src/apscheduler/datastores/mongodb.py
index 13299bb..4899f54 100644
--- a/src/apscheduler/datastores/mongodb.py
+++ b/src/apscheduler/datastores/mongodb.py
@@ -2,14 +2,13 @@ from __future__ import annotations
import operator
from collections import defaultdict
+from contextlib import AsyncExitStack
from datetime import datetime, timedelta, timezone
-from logging import Logger, getLogger
from typing import Any, Callable, ClassVar, Iterable
from uuid import UUID
import attrs
import pymongo
-import tenacity
from attrs.validators import instance_of
from bson import CodecOptions, UuidRepresentation
from bson.codec_options import TypeEncoder, TypeRegistry
@@ -35,10 +34,9 @@ from .._exceptions import (
SerializationError,
TaskLookupError,
)
-from .._structures import Job, JobResult, RetrySettings, Schedule, Task
-from ..abc import EventBroker, Serializer
-from ..serializers.pickle import PickleSerializer
-from .base import BaseDataStore
+from .._structures import Job, JobResult, Schedule, Task
+from ..abc import EventBroker
+from .base import BaseExternalDataStore
class CustomEncoder(TypeEncoder):
@@ -55,7 +53,7 @@ class CustomEncoder(TypeEncoder):
@attrs.define(eq=False)
-class MongoDBDataStore(BaseDataStore):
+class MongoDBDataStore(BaseExternalDataStore):
"""
Uses a MongoDB server to store data.
@@ -66,23 +64,11 @@ class MongoDBDataStore(BaseDataStore):
raises :exc:`pymongo.errors.ConnectionFailure`.
:param client: a PyMongo client
- :param serializer: the serializer used to (de)serialize tasks, schedules and jobs
- for storage
:param database: name of the database to use
- :param lock_expiration_delay: maximum amount of time (in seconds) that a scheduler
- or worker can keep a lock on a schedule or task
- :param retry_settings: Tenacity settings for retrying operations in case of a
- database connecitivty problem
- :param start_from_scratch: erase all existing data during startup (useful for test
- suites)
"""
client: MongoClient = attrs.field(validator=instance_of(MongoClient))
- serializer: Serializer = attrs.field(factory=PickleSerializer, kw_only=True)
database: str = attrs.field(default="apscheduler", kw_only=True)
- lock_expiration_delay: float = attrs.field(default=30, kw_only=True)
- retry_settings: RetrySettings = attrs.field(default=RetrySettings(), kw_only=True)
- start_from_scratch: bool = attrs.field(default=False, kw_only=True)
_task_attrs: ClassVar[list[str]] = [field.name for field in attrs.fields(Task)]
_schedule_attrs: ClassVar[list[str]] = [
@@ -90,8 +76,8 @@ class MongoDBDataStore(BaseDataStore):
]
_job_attrs: ClassVar[list[str]] = [field.name for field in attrs.fields(Job)]
- _logger: Logger = attrs.field(init=False, factory=lambda: getLogger(__name__))
_local_tasks: dict[str, Task] = attrs.field(init=False, factory=dict)
+ _temporary_failure_exceptions = (ConnectionFailure,)
def __attrs_post_init__(self) -> None:
type_registry = TypeRegistry(
@@ -118,25 +104,10 @@ class MongoDBDataStore(BaseDataStore):
client = MongoClient(uri)
return cls(client, **options)
- def _retry(self) -> tenacity.Retrying:
- return tenacity.Retrying(
- stop=self.retry_settings.stop,
- wait=self.retry_settings.wait,
- retry=tenacity.retry_if_exception_type(ConnectionFailure),
- after=self._after_attempt,
- reraise=True,
- )
-
- def _after_attempt(self, retry_state: tenacity.RetryCallState) -> None:
- self._logger.warning(
- "Temporary data store error (attempt %d): %s",
- retry_state.attempt_number,
- retry_state.outcome.exception(),
- )
-
- def start(self, event_broker: EventBroker) -> None:
- super().start(event_broker)
-
+ async def start(
+ self, exit_stack: AsyncExitStack, event_broker: EventBroker
+ ) -> None:
+ await super().start(exit_stack, event_broker)
server_info = self.client.server_info()
if server_info["versionArray"] < [4, 0]:
raise RuntimeError(
@@ -144,7 +115,7 @@ class MongoDBDataStore(BaseDataStore):
f"{server_info['version']}"
)
- for attempt in self._retry():
+ async for attempt in self._retry():
with attempt, self.client.start_session() as session:
if self.start_from_scratch:
self._tasks.delete_many({}, session=session)
@@ -159,7 +130,7 @@ class MongoDBDataStore(BaseDataStore):
self._jobs_results.create_index("finished_at", session=session)
self._jobs_results.create_index("expires_at", session=session)
- def add_task(self, task: Task) -> None:
+ async def add_task(self, task: Task) -> None:
for attempt in self._retry():
with attempt:
previous = self._tasks.find_one_and_update(
@@ -173,20 +144,20 @@ class MongoDBDataStore(BaseDataStore):
self._local_tasks[task.id] = task
if previous:
- self._events.publish(TaskUpdated(task_id=task.id))
+ await self._event_broker.publish(TaskUpdated(task_id=task.id))
else:
- self._events.publish(TaskAdded(task_id=task.id))
+ await self._event_broker.publish(TaskAdded(task_id=task.id))
- def remove_task(self, task_id: str) -> None:
+ async def remove_task(self, task_id: str) -> None:
for attempt in self._retry():
with attempt:
if not self._tasks.find_one_and_delete({"_id": task_id}):
raise TaskLookupError(task_id)
del self._local_tasks[task_id]
- self._events.publish(TaskRemoved(task_id=task_id))
+ await self._event_broker.publish(TaskRemoved(task_id=task_id))
- def get_task(self, task_id: str) -> Task:
+ async def get_task(self, task_id: str) -> Task:
try:
return self._local_tasks[task_id]
except KeyError:
@@ -205,7 +176,7 @@ class MongoDBDataStore(BaseDataStore):
)
return task
- def get_tasks(self) -> list[Task]:
+ async def get_tasks(self) -> list[Task]:
for attempt in self._retry():
with attempt:
tasks: list[Task] = []
@@ -217,7 +188,7 @@ class MongoDBDataStore(BaseDataStore):
return tasks
- def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]:
+ async def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]:
filters = {"_id": {"$in": list(ids)}} if ids is not None else {}
for attempt in self._retry():
with attempt:
@@ -237,7 +208,9 @@ class MongoDBDataStore(BaseDataStore):
return schedules
- def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None:
+ async def add_schedule(
+ self, schedule: Schedule, conflict_policy: ConflictPolicy
+ ) -> None:
event: DataStoreEvent
document = schedule.marshal(self.serializer)
document["_id"] = document.pop("id")
@@ -258,14 +231,14 @@ class MongoDBDataStore(BaseDataStore):
event = ScheduleUpdated(
schedule_id=schedule.id, next_fire_time=schedule.next_fire_time
)
- self._events.publish(event)
+ await self._event_broker.publish(event)
else:
event = ScheduleAdded(
schedule_id=schedule.id, next_fire_time=schedule.next_fire_time
)
- self._events.publish(event)
+ await self._event_broker.publish(event)
- def remove_schedules(self, ids: Iterable[str]) -> None:
+ async def remove_schedules(self, ids: Iterable[str]) -> None:
filters = {"_id": {"$in": list(ids)}} if ids is not None else {}
for attempt in self._retry():
with attempt, self.client.start_session() as session:
@@ -277,9 +250,9 @@ class MongoDBDataStore(BaseDataStore):
self._schedules.delete_many(filters, session=session)
for schedule_id in ids:
- self._events.publish(ScheduleRemoved(schedule_id=schedule_id))
+ await self._event_broker.publish(ScheduleRemoved(schedule_id=schedule_id))
- def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]:
+ async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]:
for attempt in self._retry():
with attempt, self.client.start_session() as session:
schedules: list[Schedule] = []
@@ -318,7 +291,9 @@ class MongoDBDataStore(BaseDataStore):
return schedules
- def release_schedules(self, scheduler_id: str, schedules: list[Schedule]) -> None:
+ async def release_schedules(
+ self, scheduler_id: str, schedules: list[Schedule]
+ ) -> None:
updated_schedules: list[tuple[str, datetime]] = []
finished_schedule_ids: list[str] = []
@@ -365,12 +340,12 @@ class MongoDBDataStore(BaseDataStore):
event = ScheduleUpdated(
schedule_id=schedule_id, next_fire_time=next_fire_time
)
- self._events.publish(event)
+ await self._event_broker.publish(event)
for schedule_id in finished_schedule_ids:
- self._events.publish(ScheduleRemoved(schedule_id=schedule_id))
+ await self._event_broker.publish(ScheduleRemoved(schedule_id=schedule_id))
- def get_next_schedule_run_time(self) -> datetime | None:
+ async def get_next_schedule_run_time(self) -> datetime | None:
for attempt in self._retry():
with attempt:
document = self._schedules.find_one(
@@ -384,7 +359,7 @@ class MongoDBDataStore(BaseDataStore):
else:
return None
- def add_job(self, job: Job) -> None:
+ async def add_job(self, job: Job) -> None:
document = job.marshal(self.serializer)
document["_id"] = document.pop("id")
for attempt in self._retry():
@@ -397,9 +372,9 @@ class MongoDBDataStore(BaseDataStore):
schedule_id=job.schedule_id,
tags=job.tags,
)
- self._events.publish(event)
+ await self._event_broker.publish(event)
- def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]:
+ async def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]:
filters = {"_id": {"$in": list(ids)}} if ids is not None else {}
for attempt in self._retry():
with attempt:
@@ -419,7 +394,7 @@ class MongoDBDataStore(BaseDataStore):
return jobs
- def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]:
+ async def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]:
for attempt in self._retry():
with attempt, self.client.start_session() as session:
cursor = self._jobs.find(
@@ -488,11 +463,15 @@ class MongoDBDataStore(BaseDataStore):
# Publish the appropriate events
for job in acquired_jobs:
- self._events.publish(JobAcquired(job_id=job.id, worker_id=worker_id))
+ await self._event_broker.publish(
+ JobAcquired(job_id=job.id, worker_id=worker_id)
+ )
return acquired_jobs
- def release_job(self, worker_id: str, task_id: str, result: JobResult) -> None:
+ async def release_job(
+ self, worker_id: str, task_id: str, result: JobResult
+ ) -> None:
for attempt in self._retry():
with attempt, self.client.start_session() as session:
# Record the job result
@@ -509,7 +488,7 @@ class MongoDBDataStore(BaseDataStore):
# Delete the job
self._jobs.delete_one({"_id": result.job_id}, session=session)
- def get_job_result(self, job_id: UUID) -> JobResult | None:
+ async def get_job_result(self, job_id: UUID) -> JobResult | None:
for attempt in self._retry():
with attempt:
document = self._jobs_results.find_one_and_delete({"_id": job_id})
diff --git a/src/apscheduler/datastores/sqlalchemy.py b/src/apscheduler/datastores/sqlalchemy.py
index ebb076f..00d06aa 100644
--- a/src/apscheduler/datastores/sqlalchemy.py
+++ b/src/apscheduler/datastores/sqlalchemy.py
@@ -1,16 +1,20 @@
from __future__ import annotations
+import sys
from collections import defaultdict
+from collections.abc import AsyncGenerator, Sequence
+from contextlib import AsyncExitStack, asynccontextmanager
from datetime import datetime, timedelta, timezone
-from logging import Logger, getLogger
from typing import Any, Iterable
from uuid import UUID
+import anyio
import attrs
+import sniffio
import tenacity
+from anyio import to_thread
from sqlalchemy import (
JSON,
- TIMESTAMP,
BigInteger,
Column,
Enum,
@@ -26,14 +30,15 @@ from sqlalchemy import (
select,
)
from sqlalchemy.engine import URL, Dialect, Result
-from sqlalchemy.exc import CompileError, IntegrityError, OperationalError
-from sqlalchemy.future import Engine, create_engine
+from sqlalchemy.exc import CompileError, IntegrityError, InterfaceError
+from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
+from sqlalchemy.future import Connection, Engine
from sqlalchemy.sql.ddl import DropTable
from sqlalchemy.sql.elements import BindParameter, literal
from .._enums import CoalescePolicy, ConflictPolicy, JobOutcome
from .._events import (
- Event,
+ DataStoreEvent,
JobAcquired,
JobAdded,
JobDeserializationFailed,
@@ -46,11 +51,15 @@ from .._events import (
TaskUpdated,
)
from .._exceptions import ConflictingIdError, SerializationError, TaskLookupError
-from .._structures import Job, JobResult, RetrySettings, Schedule, Task
-from ..abc import EventBroker, Serializer
+from .._structures import Job, JobResult, Schedule, Task
+from ..abc import EventBroker
from ..marshalling import callable_to_ref
-from ..serializers.pickle import PickleSerializer
-from .base import BaseDataStore
+from .base import BaseExternalDataStore
+
+if sys.version_info >= (3, 11):
+ from typing import Self
+else:
+ from typing_extensions import Self
class EmulatedUUID(TypeDecorator):
@@ -86,17 +95,30 @@ class EmulatedInterval(TypeDecorator):
return timedelta(seconds=value) if value is not None else None
-@attrs.define(kw_only=True, eq=False)
-class _BaseSQLAlchemyDataStore:
+@attrs.define(eq=False)
+class SQLAlchemyDataStore(BaseExternalDataStore):
+ """
+ Uses a relational database to store data.
+
+ When started, this data store creates the appropriate tables on the given database
+ if they're not already present.
+
+ Operations are retried (in accordance with ``retry_settings``) when an
+ operation raises :exc:`sqlalchemy.exc.InterfaceError` or :exc:`OSError`.
+
+ This store has been tested to work with PostgreSQL (asyncpg driver) and MySQL
+ (asyncmy driver).
+
+ :param engine: a synchronous or asynchronous SQLAlchemy engine
+ :param schema: a database schema name to use, if not the default
+ """
+
+ engine: Engine | AsyncEngine
schema: str | None = attrs.field(default=None)
- serializer: Serializer = attrs.field(factory=PickleSerializer)
- lock_expiration_delay: float = attrs.field(default=30)
max_poll_time: float | None = attrs.field(default=1)
max_idle_time: float = attrs.field(default=60)
- retry_settings: RetrySettings = attrs.field(default=RetrySettings())
- start_from_scratch: bool = attrs.field(default=False)
- _logger: Logger = attrs.field(init=False, factory=lambda: getLogger(__name__))
+ _is_async: bool = attrs.field(init=False)
def __attrs_post_init__(self) -> None:
# Generate the table definitions
@@ -107,6 +129,7 @@ class _BaseSQLAlchemyDataStore:
self.t_schedules = self._metadata.tables[prefix + "schedules"]
self.t_jobs = self._metadata.tables[prefix + "jobs"]
self.t_job_results = self._metadata.tables[prefix + "job_results"]
+ self._is_async = isinstance(self.engine, AsyncEngine)
# Find out if the dialect supports UPDATE...RETURNING
update = self.t_jobs.update().returning(self.t_jobs.c.id)
@@ -117,18 +140,76 @@ class _BaseSQLAlchemyDataStore:
else:
self._supports_update_returning = True
- def _after_attempt(self, retry_state: tenacity.RetryCallState) -> None:
- self._logger.warning(
- "Temporary data store error (attempt %d): %s",
- retry_state.attempt_number,
- retry_state.outcome.exception(),
+ @classmethod
+ def from_url(cls: type[Self], url: str | URL, **options) -> Self:
+ """
+ Create a new asynchronous SQLAlchemy data store.
+
+ :param url: an SQLAlchemy URL to pass to :func:`~sqlalchemy.create_engine`
+ (must use an async dialect like ``asyncpg`` or ``asyncmy``)
+ :param options: keyword arguments to pass to the initializer of this class
+ :return: the newly created data store
+
+ """
+ engine = create_async_engine(url, future=True)
+ return cls(engine, **options)
+
+ def _retry(self) -> tenacity.AsyncRetrying:
+ def after_attempt(retry_state: tenacity.RetryCallState) -> None:
+ self._logger.warning(
+ "Temporary data store error (attempt %d): %s",
+ retry_state.attempt_number,
+ retry_state.outcome.exception(),
+ )
+
+ # OSError is raised by asyncpg if it can't connect
+ return tenacity.AsyncRetrying(
+ stop=self.retry_settings.stop,
+ wait=self.retry_settings.wait,
+ retry=tenacity.retry_if_exception_type((InterfaceError, OSError)),
+ after=after_attempt,
+ sleep=anyio.sleep,
+ reraise=True,
)
+ @asynccontextmanager
+ async def _begin_transaction(self) -> AsyncGenerator[Connection | AsyncConnection, None]:
+ if isinstance(self.engine, AsyncEngine):
+ async with self.engine.begin() as conn:
+ yield conn
+ else:
+ cm = self.engine.begin()
+ conn = await to_thread.run_sync(cm.__enter__)
+ try:
+ yield conn
+ except BaseException as exc:
+ await to_thread.run_sync(cm.__exit__, type(exc), exc, exc.__traceback__)
+ raise
+ else:
+ await to_thread.run_sync(cm.__exit__, None, None, None)
+
+ async def _create_metadata(self, conn: Connection | AsyncConnection) -> None:
+ if isinstance(conn, AsyncConnection):
+ await conn.run_sync(self._metadata.create_all)
+ else:
+ await to_thread.run_sync(self._metadata.create_all, conn)
+
+ async def _execute(
+ self,
+ conn: Connection | AsyncConnection,
+ statement,
+ parameters: Sequence | None = None,
+ ):
+ if isinstance(conn, AsyncConnection):
+ return await conn.execute(statement, parameters)
+ else:
+ return await to_thread.run_sync(conn.execute, statement, parameters)
+
def get_table_definitions(self) -> MetaData:
if self.engine.dialect.name == "postgresql":
from sqlalchemy.dialects import postgresql
- timestamp_type = TIMESTAMP(timezone=True)
+ timestamp_type = postgresql.TIMESTAMP(timezone=True)
job_id_type = postgresql.UUID(as_uuid=True)
interval_type = postgresql.INTERVAL(precision=6)
tags_type = postgresql.ARRAY(Unicode)
@@ -145,6 +226,7 @@ class _BaseSQLAlchemyDataStore:
metadata,
Column("id", Unicode(500), primary_key=True),
Column("func", Unicode(500), nullable=False),
+ Column("executor", Unicode(500), nullable=False),
Column("state", LargeBinary),
Column("max_running_jobs", Integer),
Column("misfire_grace_time", interval_type),
@@ -197,186 +279,160 @@ class _BaseSQLAlchemyDataStore:
)
return metadata
- def _deserialize_schedules(self, result: Result) -> list[Schedule]:
+ async def start(
+ self, exit_stack: AsyncExitStack, event_broker: EventBroker
+ ) -> None:
+ await super().start(exit_stack, event_broker)
+ asynclib = sniffio.current_async_library() or "(unknown)"
+ if asynclib != "asyncio":
+ raise RuntimeError(
+ f"This data store requires asyncio; currently running: {asynclib}"
+ )
+
+ # Verify that the schema is in place
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ if self.start_from_scratch:
+ for table in self._metadata.sorted_tables:
+ await self._execute(conn, DropTable(table, if_exists=True))
+
+ await self._create_metadata(conn)
+ query = select(self.t_metadata.c.schema_version)
+ result = await self._execute(conn, query)
+ version = result.scalar()
+ if version is None:
+ await self._execute(
+ conn, self.t_metadata.insert(values={"schema_version": 1})
+ )
+ elif version > 1:
+ raise RuntimeError(
+ f"Unexpected schema version ({version}); "
+ f"only version 1 is supported by this version of "
+ f"APScheduler"
+ )
+
+ async def _deserialize_schedules(self, result: Result) -> list[Schedule]:
schedules: list[Schedule] = []
for row in result:
try:
schedules.append(Schedule.unmarshal(self.serializer, row._asdict()))
except SerializationError as exc:
- self._events.publish(
+ await self._event_broker.publish(
ScheduleDeserializationFailed(schedule_id=row["id"], exception=exc)
)
return schedules
- def _deserialize_jobs(self, result: Result) -> list[Job]:
+ async def _deserialize_jobs(self, result: Result) -> list[Job]:
jobs: list[Job] = []
for row in result:
try:
jobs.append(Job.unmarshal(self.serializer, row._asdict()))
except SerializationError as exc:
- self._events.publish(
+ await self._event_broker.publish(
JobDeserializationFailed(job_id=row["id"], exception=exc)
)
return jobs
-
-@attrs.define(eq=False)
-class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, BaseDataStore):
- """
- Uses a relational database to store data.
-
- When started, this data store creates the appropriate tables on the given database
- if they're not already present.
-
- Operations are retried (in accordance to ``retry_settings``) when an operation
- raises :exc:`sqlalchemy.OperationalError`.
-
- This store has been tested to work with PostgreSQL (psycopg2 driver), MySQL
- (pymysql driver) and SQLite.
-
- :param engine: a (synchronous) SQLAlchemy engine
- :param schema: a database schema name to use, if not the default
- :param serializer: the serializer used to (de)serialize tasks, schedules and jobs
- for storage
- :param lock_expiration_delay: maximum amount of time (in seconds) that a scheduler
- or worker can keep a lock on a schedule or task
- :param retry_settings: Tenacity settings for retrying operations in case of a
- database connecitivty problem
- :param start_from_scratch: erase all existing data during startup (useful for test
- suites)
- """
-
- engine: Engine
-
- @classmethod
- def from_url(cls, url: str | URL, **kwargs) -> SQLAlchemyDataStore:
- """
- Create a new SQLAlchemy data store.
-
- :param url: an SQLAlchemy URL to pass to :func:`~sqlalchemy.create_engine`
- :param kwargs: keyword arguments to pass to the initializer of this class
- :return: the newly created data store
-
- """
- engine = create_engine(url)
- return cls(engine, **kwargs)
-
- def _retry(self) -> tenacity.Retrying:
- return tenacity.Retrying(
- stop=self.retry_settings.stop,
- wait=self.retry_settings.wait,
- retry=tenacity.retry_if_exception_type(OperationalError),
- after=self._after_attempt,
- reraise=True,
- )
-
- def start(self, event_broker: EventBroker) -> None:
- super().start(event_broker)
-
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- if self.start_from_scratch:
- for table in self._metadata.sorted_tables:
- conn.execute(DropTable(table, if_exists=True))
-
- self._metadata.create_all(conn)
- query = select(self.t_metadata.c.schema_version)
- result = conn.execute(query)
- version = result.scalar()
- if version is None:
- conn.execute(self.t_metadata.insert(values={"schema_version": 1}))
- elif version > 1:
- raise RuntimeError(
- f"Unexpected schema version ({version}); "
- f"only version 1 is supported by this version of APScheduler"
- )
-
- def add_task(self, task: Task) -> None:
+ async def add_task(self, task: Task) -> None:
insert = self.t_tasks.insert().values(
id=task.id,
func=callable_to_ref(task.func),
+ executor=task.executor,
max_running_jobs=task.max_running_jobs,
misfire_grace_time=task.misfire_grace_time,
)
try:
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- conn.execute(insert)
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ await self._execute(conn, insert)
except IntegrityError:
update = (
self.t_tasks.update()
.values(
func=callable_to_ref(task.func),
+ executor=task.executor,
max_running_jobs=task.max_running_jobs,
misfire_grace_time=task.misfire_grace_time,
)
.where(self.t_tasks.c.id == task.id)
)
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- conn.execute(update)
- self._events.publish(TaskUpdated(task_id=task.id))
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ await self._execute(conn, update)
+
+ await self._event_broker.publish(TaskUpdated(task_id=task.id))
else:
- self._events.publish(TaskAdded(task_id=task.id))
+ await self._event_broker.publish(TaskAdded(task_id=task.id))
- def remove_task(self, task_id: str) -> None:
+ async def remove_task(self, task_id: str) -> None:
delete = self.t_tasks.delete().where(self.t_tasks.c.id == task_id)
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- result = conn.execute(delete)
- if result.rowcount == 0:
- raise TaskLookupError(task_id)
- else:
- self._events.publish(TaskRemoved(task_id=task_id))
-
- def get_task(self, task_id: str) -> Task:
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ result = await self._execute(conn, delete)
+ if result.rowcount == 0:
+ raise TaskLookupError(task_id)
+ else:
+ await self._event_broker.publish(TaskRemoved(task_id=task_id))
+
+ async def get_task(self, task_id: str) -> Task:
query = select(
[
self.t_tasks.c.id,
self.t_tasks.c.func,
+ self.t_tasks.c.executor,
self.t_tasks.c.max_running_jobs,
self.t_tasks.c.state,
self.t_tasks.c.misfire_grace_time,
]
).where(self.t_tasks.c.id == task_id)
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- result = conn.execute(query)
- row = result.first()
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ result = await self._execute(conn, query)
+ row = result.first()
if row:
return Task.unmarshal(self.serializer, row._asdict())
else:
raise TaskLookupError(task_id)
- def get_tasks(self) -> list[Task]:
+ async def get_tasks(self) -> list[Task]:
query = select(
[
self.t_tasks.c.id,
self.t_tasks.c.func,
+ self.t_tasks.c.executor,
self.t_tasks.c.max_running_jobs,
self.t_tasks.c.state,
self.t_tasks.c.misfire_grace_time,
]
).order_by(self.t_tasks.c.id)
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- result = conn.execute(query)
- tasks = [
- Task.unmarshal(self.serializer, row._asdict()) for row in result
- ]
- return tasks
-
- def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None:
- event: Event
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ result = await self._execute(conn, query)
+ tasks = [
+ Task.unmarshal(self.serializer, row._asdict()) for row in result
+ ]
+ return tasks
+
+ async def add_schedule(
+ self, schedule: Schedule, conflict_policy: ConflictPolicy
+ ) -> None:
+ event: DataStoreEvent
values = schedule.marshal(self.serializer)
insert = self.t_schedules.insert().values(**values)
try:
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- conn.execute(insert)
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ await self._execute(conn, insert)
except IntegrityError:
if conflict_policy is ConflictPolicy.exception:
raise ConflictingIdError(schedule.id) from None
@@ -387,185 +443,197 @@ class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, BaseDataStore):
.where(self.t_schedules.c.id == schedule.id)
.values(**values)
)
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- conn.execute(update)
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ await self._execute(conn, update)
event = ScheduleUpdated(
schedule_id=schedule.id, next_fire_time=schedule.next_fire_time
)
- self._events.publish(event)
+ await self._event_broker.publish(event)
else:
event = ScheduleAdded(
schedule_id=schedule.id, next_fire_time=schedule.next_fire_time
)
- self._events.publish(event)
-
- def remove_schedules(self, ids: Iterable[str]) -> None:
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- delete = self.t_schedules.delete().where(self.t_schedules.c.id.in_(ids))
- if self._supports_update_returning:
- delete = delete.returning(self.t_schedules.c.id)
- removed_ids: Iterable[str] = [
- row[0] for row in conn.execute(delete)
- ]
- else:
- # TODO: actually check which rows were deleted?
- conn.execute(delete)
- removed_ids = ids
+ await self._event_broker.publish(event)
+
+ async def remove_schedules(self, ids: Iterable[str]) -> None:
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ delete = self.t_schedules.delete().where(
+ self.t_schedules.c.id.in_(ids)
+ )
+ if self._supports_update_returning:
+ delete = delete.returning(self.t_schedules.c.id)
+ removed_ids: Iterable[str] = [
+ row[0] for row in await self._execute(conn, delete)
+ ]
+ else:
+ # TODO: actually check which rows were deleted?
+ await self._execute(conn, delete)
+ removed_ids = ids
for schedule_id in removed_ids:
- self._events.publish(ScheduleRemoved(schedule_id=schedule_id))
+ await self._event_broker.publish(ScheduleRemoved(schedule_id=schedule_id))
- def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]:
+ async def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]:
query = self.t_schedules.select().order_by(self.t_schedules.c.id)
if ids:
query = query.where(self.t_schedules.c.id.in_(ids))
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- result = conn.execute(query)
- return self._deserialize_schedules(result)
-
- def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]:
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- now = datetime.now(timezone.utc)
- acquired_until = now + timedelta(seconds=self.lock_expiration_delay)
- schedules_cte = (
- select(self.t_schedules.c.id)
- .where(
- and_(
- self.t_schedules.c.next_fire_time.isnot(None),
- self.t_schedules.c.next_fire_time <= now,
- or_(
- self.t_schedules.c.acquired_until.is_(None),
- self.t_schedules.c.acquired_until < now,
- ),
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ result = await self._execute(conn, query)
+ return await self._deserialize_schedules(result)
+
+ async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]:
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ now = datetime.now(timezone.utc)
+ acquired_until = now + timedelta(seconds=self.lock_expiration_delay)
+ schedules_cte = (
+ select(self.t_schedules.c.id)
+ .where(
+ and_(
+ self.t_schedules.c.next_fire_time.isnot(None),
+ self.t_schedules.c.next_fire_time <= now,
+ or_(
+ self.t_schedules.c.acquired_until.is_(None),
+ self.t_schedules.c.acquired_until < now,
+ ),
+ )
)
+ .order_by(self.t_schedules.c.next_fire_time)
+ .limit(limit)
+ .with_for_update(skip_locked=True)
+ .cte()
)
- .order_by(self.t_schedules.c.next_fire_time)
- .limit(limit)
- .with_for_update(skip_locked=True)
- .cte()
- )
- subselect = select([schedules_cte.c.id])
- update = (
- self.t_schedules.update()
- .where(self.t_schedules.c.id.in_(subselect))
- .values(acquired_by=scheduler_id, acquired_until=acquired_until)
- )
- if self._supports_update_returning:
- update = update.returning(*self.t_schedules.columns)
- result = conn.execute(update)
- else:
- conn.execute(update)
- query = self.t_schedules.select().where(
- and_(self.t_schedules.c.acquired_by == scheduler_id)
+ subselect = select([schedules_cte.c.id])
+ update = (
+ self.t_schedules.update()
+ .where(self.t_schedules.c.id.in_(subselect))
+ .values(acquired_by=scheduler_id, acquired_until=acquired_until)
)
- result = conn.execute(query)
+ if self._supports_update_returning:
+ update = update.returning(*self.t_schedules.columns)
+ result = await self._execute(conn, update)
+ else:
+ await self._execute(conn, update)
+ query = self.t_schedules.select().where(
+ and_(self.t_schedules.c.acquired_by == scheduler_id)
+ )
+ result = await self._execute(conn, query)
- schedules = self._deserialize_schedules(result)
+ schedules = await self._deserialize_schedules(result)
return schedules
- def release_schedules(self, scheduler_id: str, schedules: list[Schedule]) -> None:
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- update_events: list[ScheduleUpdated] = []
- finished_schedule_ids: list[str] = []
- update_args: list[dict[str, Any]] = []
- for schedule in schedules:
- if schedule.next_fire_time is not None:
- try:
- serialized_trigger = self.serializer.serialize(
- schedule.trigger
- )
- except SerializationError:
- self._logger.exception(
- "Error serializing trigger for schedule %r – "
- "removing from data store",
- schedule.id,
+ async def release_schedules(
+ self, scheduler_id: str, schedules: list[Schedule]
+ ) -> None:
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ update_events: list[ScheduleUpdated] = []
+ finished_schedule_ids: list[str] = []
+ update_args: list[dict[str, Any]] = []
+ for schedule in schedules:
+ if schedule.next_fire_time is not None:
+ try:
+ serialized_trigger = self.serializer.serialize(
+ schedule.trigger
+ )
+ except SerializationError:
+ self._logger.exception(
+ "Error serializing trigger for schedule %r – "
+ "removing from data store",
+ schedule.id,
+ )
+ finished_schedule_ids.append(schedule.id)
+ continue
+
+ update_args.append(
+ {
+ "p_id": schedule.id,
+ "p_trigger": serialized_trigger,
+ "p_next_fire_time": schedule.next_fire_time,
+ }
)
+ else:
finished_schedule_ids.append(schedule.id)
- continue
-
- update_args.append(
- {
- "p_id": schedule.id,
- "p_trigger": serialized_trigger,
- "p_next_fire_time": schedule.next_fire_time,
- }
- )
- else:
- finished_schedule_ids.append(schedule.id)
- # Update schedules that have a next fire time
- if update_args:
- p_id: BindParameter = bindparam("p_id")
- p_trigger: BindParameter = bindparam("p_trigger")
- p_next_fire_time: BindParameter = bindparam("p_next_fire_time")
- update = (
- self.t_schedules.update()
- .where(
- and_(
- self.t_schedules.c.id == p_id,
- self.t_schedules.c.acquired_by == scheduler_id,
+ # Update schedules that have a next fire time
+ if update_args:
+ p_id: BindParameter = bindparam("p_id")
+ p_trigger: BindParameter = bindparam("p_trigger")
+ p_next_fire_time: BindParameter = bindparam("p_next_fire_time")
+ update = (
+ self.t_schedules.update()
+ .where(
+ and_(
+ self.t_schedules.c.id == p_id,
+ self.t_schedules.c.acquired_by == scheduler_id,
+ )
+ )
+ .values(
+ trigger=p_trigger,
+ next_fire_time=p_next_fire_time,
+ acquired_by=None,
+ acquired_until=None,
)
)
- .values(
- trigger=p_trigger,
- next_fire_time=p_next_fire_time,
- acquired_by=None,
- acquired_until=None,
- )
- )
- next_fire_times = {
- arg["p_id"]: arg["p_next_fire_time"] for arg in update_args
- }
- # TODO: actually check which rows were updated?
- conn.execute(update, update_args)
- updated_ids = list(next_fire_times)
-
- for schedule_id in updated_ids:
- event = ScheduleUpdated(
- schedule_id=schedule_id,
- next_fire_time=next_fire_times[schedule_id],
- )
- update_events.append(event)
+ next_fire_times = {
+ arg["p_id"]: arg["p_next_fire_time"] for arg in update_args
+ }
+ # TODO: actually check which rows were updated?
+ await self._execute(conn, update, update_args)
+ updated_ids = list(next_fire_times)
+
+ for schedule_id in updated_ids:
+ event = ScheduleUpdated(
+ schedule_id=schedule_id,
+ next_fire_time=next_fire_times[schedule_id],
+ )
+ update_events.append(event)
- # Remove schedules that have no next fire time or failed to serialize
- if finished_schedule_ids:
- delete = self.t_schedules.delete().where(
- self.t_schedules.c.id.in_(finished_schedule_ids)
- )
- conn.execute(delete)
+ # Remove schedules that have no next fire time or failed to
+ # serialize
+ if finished_schedule_ids:
+ delete = self.t_schedules.delete().where(
+ self.t_schedules.c.id.in_(finished_schedule_ids)
+ )
+ await self._execute(conn, delete)
for event in update_events:
- self._events.publish(event)
+ await self._event_broker.publish(event)
for schedule_id in finished_schedule_ids:
- self._events.publish(ScheduleRemoved(schedule_id=schedule_id))
+ await self._event_broker.publish(ScheduleRemoved(schedule_id=schedule_id))
- def get_next_schedule_run_time(self) -> datetime | None:
- query = (
+ async def get_next_schedule_run_time(self) -> datetime | None:
+ statement = (
select(self.t_schedules.c.next_fire_time)
.where(self.t_schedules.c.next_fire_time.isnot(None))
.order_by(self.t_schedules.c.next_fire_time)
.limit(1)
)
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- result = conn.execute(query)
- return result.scalar()
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ result = await self._execute(conn, statement)
+ return result.scalar()
- def add_job(self, job: Job) -> None:
+ async def add_job(self, job: Job) -> None:
marshalled = job.marshal(self.serializer)
insert = self.t_jobs.insert().values(**marshalled)
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- conn.execute(insert)
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ await self._execute(conn, insert)
event = JobAdded(
job_id=job.id,
@@ -573,139 +641,156 @@ class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, BaseDataStore):
schedule_id=job.schedule_id,
tags=job.tags,
)
- self._events.publish(event)
+ await self._event_broker.publish(event)
- def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]:
+ async def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]:
query = self.t_jobs.select().order_by(self.t_jobs.c.id)
if ids:
job_ids = [job_id for job_id in ids]
query = query.where(self.t_jobs.c.id.in_(job_ids))
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- result = conn.execute(query)
- return self._deserialize_jobs(result)
-
- def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]:
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- now = datetime.now(timezone.utc)
- acquired_until = now + timedelta(seconds=self.lock_expiration_delay)
- query = (
- self.t_jobs.select()
- .join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id)
- .where(
- or_(
- self.t_jobs.c.acquired_until.is_(None),
- self.t_jobs.c.acquired_until < now,
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ result = await self._execute(conn, query)
+ return await self._deserialize_jobs(result)
+
+ async def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[Job]:
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ now = datetime.now(timezone.utc)
+ acquired_until = now + timedelta(seconds=self.lock_expiration_delay)
+ query = (
+ self.t_jobs.select()
+ .join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id)
+ .where(
+ or_(
+ self.t_jobs.c.acquired_until.is_(None),
+ self.t_jobs.c.acquired_until < now,
+ )
)
+ .order_by(self.t_jobs.c.created_at)
+ .with_for_update(skip_locked=True)
+ .limit(limit)
)
- .order_by(self.t_jobs.c.created_at)
- .with_for_update(skip_locked=True)
- .limit(limit)
- )
-
- result = conn.execute(query)
- if not result:
- return []
- # Mark the jobs as acquired by this worker
- jobs = self._deserialize_jobs(result)
- task_ids: set[str] = {job.task_id for job in jobs}
-
- # Retrieve the limits
- query = select(
- [
- self.t_tasks.c.id,
- self.t_tasks.c.max_running_jobs - self.t_tasks.c.running_jobs,
- ]
- ).where(
- self.t_tasks.c.max_running_jobs.isnot(None),
- self.t_tasks.c.id.in_(task_ids),
- )
- result = conn.execute(query)
- job_slots_left = dict(result.fetchall())
-
- # Filter out jobs that don't have free slots
- acquired_jobs: list[Job] = []
- increments: dict[str, int] = defaultdict(lambda: 0)
- for job in jobs:
- # Don't acquire the job if there are no free slots left
- slots_left = job_slots_left.get(job.task_id)
- if slots_left == 0:
- continue
- elif slots_left is not None:
- job_slots_left[job.task_id] -= 1
-
- acquired_jobs.append(job)
- increments[job.task_id] += 1
-
- if acquired_jobs:
- # Mark the acquired jobs as acquired by this worker
- acquired_job_ids = [job.id for job in acquired_jobs]
- update = (
- self.t_jobs.update()
- .values(acquired_by=worker_id, acquired_until=acquired_until)
- .where(self.t_jobs.c.id.in_(acquired_job_ids))
+ result = await self._execute(conn, query)
+ if not result:
+ return []
+
+ # Mark the jobs as acquired by this worker
+ jobs = await self._deserialize_jobs(result)
+ task_ids: set[str] = {job.task_id for job in jobs}
+
+ # Retrieve the limits
+ query = select(
+ [
+ self.t_tasks.c.id,
+ self.t_tasks.c.max_running_jobs
+ - self.t_tasks.c.running_jobs,
+ ]
+ ).where(
+ self.t_tasks.c.max_running_jobs.isnot(None),
+ self.t_tasks.c.id.in_(task_ids),
)
- conn.execute(update)
-
- # Increment the running job counters on each task
- p_id: BindParameter = bindparam("p_id")
- p_increment: BindParameter = bindparam("p_increment")
- params = [
- {"p_id": task_id, "p_increment": increment}
- for task_id, increment in increments.items()
- ]
- update = (
- self.t_tasks.update()
- .values(running_jobs=self.t_tasks.c.running_jobs + p_increment)
- .where(self.t_tasks.c.id == p_id)
- )
- conn.execute(update, params)
+ result = await self._execute(conn, query)
+ job_slots_left: dict[str, int] = dict(result.fetchall())
+
+ # Filter out jobs that don't have free slots
+ acquired_jobs: list[Job] = []
+ increments: dict[str, int] = defaultdict(lambda: 0)
+ for job in jobs:
+ # Don't acquire the job if there are no free slots left
+ slots_left = job_slots_left.get(job.task_id)
+ if slots_left == 0:
+ continue
+ elif slots_left is not None:
+ job_slots_left[job.task_id] -= 1
+
+ acquired_jobs.append(job)
+ increments[job.task_id] += 1
+
+ if acquired_jobs:
+ # Mark the acquired jobs as acquired by this worker
+ acquired_job_ids = [job.id for job in acquired_jobs]
+ update = (
+ self.t_jobs.update()
+ .values(
+ acquired_by=worker_id, acquired_until=acquired_until
+ )
+ .where(self.t_jobs.c.id.in_(acquired_job_ids))
+ )
+ await self._execute(conn, update)
+
+ # Increment the running job counters on each task
+ p_id: BindParameter = bindparam("p_id")
+ p_increment: BindParameter = bindparam("p_increment")
+ params = [
+ {"p_id": task_id, "p_increment": increment}
+ for task_id, increment in increments.items()
+ ]
+ update = (
+ self.t_tasks.update()
+ .values(
+ running_jobs=self.t_tasks.c.running_jobs + p_increment
+ )
+ .where(self.t_tasks.c.id == p_id)
+ )
+ await self._execute(conn, update, params)
# Publish the appropriate events
for job in acquired_jobs:
- self._events.publish(JobAcquired(job_id=job.id, worker_id=worker_id))
+ await self._event_broker.publish(
+ JobAcquired(job_id=job.id, worker_id=worker_id)
+ )
return acquired_jobs
- def release_job(self, worker_id: str, task_id: str, result: JobResult) -> None:
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- # Insert the job result
- if result.expires_at > result.finished_at:
- marshalled = result.marshal(self.serializer)
- insert = self.t_job_results.insert().values(**marshalled)
- conn.execute(insert)
+ async def release_job(
+ self, worker_id: str, task_id: str, result: JobResult
+ ) -> None:
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ # Record the job result
+ if result.expires_at > result.finished_at:
+ marshalled = result.marshal(self.serializer)
+ insert = self.t_job_results.insert().values(**marshalled)
+ await self._execute(conn, insert)
+
+ # Decrement the number of running jobs for this task
+ update = (
+ self.t_tasks.update()
+ .values(running_jobs=self.t_tasks.c.running_jobs - 1)
+ .where(self.t_tasks.c.id == task_id)
+ )
+ await self._execute(conn, update)
- # Decrement the running jobs counter
- update = (
- self.t_tasks.update()
- .values(running_jobs=self.t_tasks.c.running_jobs - 1)
- .where(self.t_tasks.c.id == task_id)
- )
- conn.execute(update)
-
- # Delete the job
- delete = self.t_jobs.delete().where(self.t_jobs.c.id == result.job_id)
- conn.execute(delete)
-
- def get_job_result(self, job_id: UUID) -> JobResult | None:
- for attempt in self._retry():
- with attempt, self.engine.begin() as conn:
- # Retrieve the result
- query = self.t_job_results.select().where(
- self.t_job_results.c.job_id == job_id
- )
- row = conn.execute(query).first()
+ # Delete the job
+ delete = self.t_jobs.delete().where(
+ self.t_jobs.c.id == result.job_id
+ )
+ await self._execute(conn, delete)
+
+ async def get_job_result(self, job_id: UUID) -> JobResult | None:
+ async for attempt in self._retry():
+ with attempt:
+ async with self._begin_transaction() as conn:
+ # Retrieve the result
+ query = self.t_job_results.select().where(
+ self.t_job_results.c.job_id == job_id
+ )
+ row = (await self._execute(conn, query)).first()
- # Delete the result
- delete = self.t_job_results.delete().where(
- self.t_job_results.c.job_id == job_id
- )
- conn.execute(delete)
+ # Delete the result
+ delete = self.t_job_results.delete().where(
+ self.t_job_results.c.job_id == job_id
+ )
+ await self._execute(conn, delete)
- return (
- JobResult.unmarshal(self.serializer, row._asdict()) if row else None
- )
+ return (
+ JobResult.unmarshal(self.serializer, row._asdict())
+ if row
+ else None
+ )
diff --git a/src/apscheduler/eventbrokers/async_adapter.py b/src/apscheduler/eventbrokers/async_adapter.py
deleted file mode 100644
index 4ff08a5..0000000
--- a/src/apscheduler/eventbrokers/async_adapter.py
+++ /dev/null
@@ -1,65 +0,0 @@
-from __future__ import annotations
-
-from typing import Any, Callable, Iterable
-
-import attrs
-from anyio import to_thread
-from anyio.from_thread import BlockingPortal
-
-from .._events import Event
-from ..abc import AsyncEventBroker, EventBroker, Subscription
-
-
-@attrs.define(eq=False)
-class AsyncEventBrokerAdapter(AsyncEventBroker):
- original: EventBroker
-
- async def start(self) -> None:
- await to_thread.run_sync(self.original.start)
-
- async def stop(self, *, force: bool = False) -> None:
- await to_thread.run_sync(lambda: self.original.stop(force=force))
-
- async def publish_local(self, event: Event) -> None:
- await to_thread.run_sync(self.original.publish_local, event)
-
- async def publish(self, event: Event) -> None:
- await to_thread.run_sync(self.original.publish, event)
-
- def subscribe(
- self,
- callback: Callable[[Event], Any],
- event_types: Iterable[type[Event]] | None = None,
- *,
- one_shot: bool = False,
- ) -> Subscription:
- return self.original.subscribe(callback, event_types, one_shot=one_shot)
-
-
-@attrs.define(eq=False)
-class SyncEventBrokerAdapter(EventBroker):
- original: AsyncEventBroker
- portal: BlockingPortal
-
- def start(self) -> None:
- pass
-
- def stop(self, *, force: bool = False) -> None:
- pass
-
- def publish_local(self, event: Event) -> None:
- self.portal.call(self.original.publish_local, event)
-
- def publish(self, event: Event) -> None:
- self.portal.call(self.original.publish, event)
-
- def subscribe(
- self,
- callback: Callable[[Event], Any],
- event_types: Iterable[type[Event]] | None = None,
- *,
- one_shot: bool = False,
- ) -> Subscription:
- return self.portal.call(
- lambda: self.original.subscribe(callback, event_types, one_shot=one_shot)
- )
diff --git a/src/apscheduler/eventbrokers/async_local.py b/src/apscheduler/eventbrokers/async_local.py
deleted file mode 100644
index f69bfc7..0000000
--- a/src/apscheduler/eventbrokers/async_local.py
+++ /dev/null
@@ -1,64 +0,0 @@
-from __future__ import annotations
-
-from asyncio import iscoroutine
-from typing import Any, Callable
-
-import attrs
-from anyio import create_task_group
-from anyio.abc import TaskGroup
-
-from .._events import Event
-from ..abc import AsyncEventBroker
-from .base import BaseEventBroker
-
-
-@attrs.define(eq=False)
-class LocalAsyncEventBroker(AsyncEventBroker, BaseEventBroker):
- """
- Asynchronous, local event broker.
-
- This event broker only broadcasts within the process it runs in, and is therefore
- not suitable for multi-node or multiprocess use cases.
-
- Does not serialize events.
- """
-
- _task_group: TaskGroup = attrs.field(init=False)
-
- async def start(self) -> None:
- self._task_group = create_task_group()
- await self._task_group.__aenter__()
-
- async def stop(self, *, force: bool = False) -> None:
- await self._task_group.__aexit__(None, None, None)
- del self._task_group
-
- async def publish(self, event: Event) -> None:
- await self.publish_local(event)
-
- async def publish_local(self, event: Event) -> None:
- event_type = type(event)
- one_shot_tokens: list[object] = []
- for _token, subscription in self._subscriptions.items():
- if (
- subscription.event_types is None
- or event_type in subscription.event_types
- ):
- self._task_group.start_soon(
- self._deliver_event, subscription.callback, event
- )
- if subscription.one_shot:
- one_shot_tokens.append(subscription.token)
-
- for token in one_shot_tokens:
- super().unsubscribe(token)
-
- async def _deliver_event(self, func: Callable[[Event], Any], event: Event) -> None:
- try:
- retval = func(event)
- if iscoroutine(retval):
- await retval
- except BaseException:
- self._logger.exception(
- "Error delivering %s event", event.__class__.__name__
- )
diff --git a/src/apscheduler/eventbrokers/async_redis.py b/src/apscheduler/eventbrokers/async_redis.py
deleted file mode 100644
index 5e71621..0000000
--- a/src/apscheduler/eventbrokers/async_redis.py
+++ /dev/null
@@ -1,124 +0,0 @@
-from __future__ import annotations
-
-from asyncio import CancelledError
-
-import anyio
-import attrs
-import tenacity
-from redis import ConnectionError
-from redis.asyncio import Redis, RedisCluster
-from redis.asyncio.client import PubSub
-from redis.asyncio.connection import ConnectionPool
-
-from .. import RetrySettings
-from .._events import Event
-from ..abc import Serializer
-from ..serializers.json import JSONSerializer
-from .async_local import LocalAsyncEventBroker
-from .base import DistributedEventBrokerMixin
-
-
-@attrs.define(eq=False)
-class AsyncRedisEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin):
- """
- An event broker that uses a Redis server to broadcast events.
-
- Requires the redis_ library to be installed.
-
- .. _redis: https://pypi.org/project/redis/
-
- :param client: an asynchronous Redis client
- :param serializer: the serializer used to (de)serialize events for transport
- :param channel: channel on which to send the messages
- :param retry_settings: Tenacity settings for retrying operations in case of a
- broker connectivity problem
- :param stop_check_interval: interval (in seconds) on which the channel listener
- should check if it should stop (higher values mean slower reaction time but less
- CPU use)
- """
-
- client: Redis | RedisCluster
- serializer: Serializer = attrs.field(factory=JSONSerializer)
- channel: str = attrs.field(kw_only=True, default="apscheduler")
- retry_settings: RetrySettings = attrs.field(default=RetrySettings())
- stop_check_interval: float = attrs.field(kw_only=True, default=1)
- _stopped: bool = attrs.field(init=False, default=True)
-
- @classmethod
- def from_url(cls, url: str, **kwargs) -> AsyncRedisEventBroker:
- """
- Create a new event broker from a URL.
-
- :param url: a Redis URL (```redis://...```)
- :param kwargs: keyword arguments to pass to the initializer of this class
- :return: the newly created event broker
-
- """
- pool = ConnectionPool.from_url(url)
- client = Redis(connection_pool=pool)
- return cls(client, **kwargs)
-
- def _retry(self) -> tenacity.AsyncRetrying:
- def after_attempt(retry_state: tenacity.RetryCallState) -> None:
- self._logger.warning(
- f"{self.__class__.__name__}: connection failure "
- f"(attempt {retry_state.attempt_number}): "
- f"{retry_state.outcome.exception()}",
- )
-
- return tenacity.AsyncRetrying(
- stop=self.retry_settings.stop,
- wait=self.retry_settings.wait,
- retry=tenacity.retry_if_exception_type(ConnectionError),
- after=after_attempt,
- sleep=anyio.sleep,
- reraise=True,
- )
-
- async def start(self) -> None:
- await super().start()
- pubsub = self.client.pubsub()
- try:
- await pubsub.subscribe(self.channel)
- except BaseException:
- await self.stop(force=True)
- raise
-
- self._stopped = False
- self._task_group.start_soon(
- self._listen_messages, pubsub, name="Redis subscriber"
- )
-
- async def stop(self, *, force: bool = False) -> None:
- self._stopped = True
- await super().stop(force=force)
-
- async def _listen_messages(self, pubsub: PubSub) -> None:
- while not self._stopped:
- try:
- async for attempt in self._retry():
- with attempt:
- msg = await pubsub.get_message(
- ignore_subscribe_messages=True,
- timeout=self.stop_check_interval,
- )
-
- if msg and isinstance(msg["data"], bytes):
- event = self.reconstitute_event(msg["data"])
- if event is not None:
- await self.publish_local(event)
- except Exception as exc:
- # CancelledError is a subclass of Exception in Python 3.7
- if not isinstance(exc, CancelledError):
- self._logger.exception(
- f"{self.__class__.__name__} listener crashed"
- )
-
- await pubsub.close()
- raise
-
- async def publish(self, event: Event) -> None:
- notification = self.generate_notification(event)
- async for attempt in self._retry():
- with attempt:
- await self.client.publish(self.channel, notification)
diff --git a/src/apscheduler/eventbrokers/asyncpg.py b/src/apscheduler/eventbrokers/asyncpg.py
index 7e3045e..59f7292 100644
--- a/src/apscheduler/eventbrokers/asyncpg.py
+++ b/src/apscheduler/eventbrokers/asyncpg.py
@@ -5,10 +5,8 @@ from contextlib import AsyncExitStack
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, cast
-import anyio
import asyncpg
import attrs
-import tenacity
from anyio import (
TASK_STATUS_IGNORED,
EndOfStream,
@@ -18,20 +16,16 @@ from anyio import (
from anyio.streams.memory import MemoryObjectSendStream
from asyncpg import Connection, InterfaceError
-from .. import RetrySettings
from .._events import Event
from .._exceptions import SerializationError
-from ..abc import Serializer
-from ..serializers.json import JSONSerializer
-from .async_local import LocalAsyncEventBroker
-from .base import DistributedEventBrokerMixin
+from .base import BaseExternalEventBroker
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncEngine
@attrs.define(eq=False)
-class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin):
+class AsyncpgEventBroker(BaseExternalEventBroker):
"""
An asynchronous, asyncpg_ based event broker that uses a PostgreSQL server to
broadcast events using its ``NOTIFY`` mechanism.
@@ -39,16 +33,13 @@ class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin):
.. _asyncpg: https://pypi.org/project/asyncpg/
:param connection_factory: a callable that creates an asyncpg connection
- :param serializer: the serializer used to (de)serialize events for transport
:param channel: the ``NOTIFY`` channel to use
:param max_idle_time: maximum time to let the connection go idle, before sending a
``SELECT 1`` query to prevent a connection timeout
"""
connection_factory: Callable[[], Awaitable[Connection]]
- serializer: Serializer = attrs.field(kw_only=True, factory=JSONSerializer)
channel: str = attrs.field(kw_only=True, default="apscheduler")
- retry_settings: RetrySettings = attrs.field(default=RetrySettings())
max_idle_time: float = attrs.field(kw_only=True, default=10)
_send: MemoryObjectSendStream[str] = attrs.field(init=False)
@@ -111,38 +102,17 @@ class AsyncpgEventBroker(LocalAsyncEventBroker, DistributedEventBrokerMixin):
factory = partial(asyncpg.connect, **connect_args)
return cls(factory, **kwargs)
- def _retry(self) -> tenacity.AsyncRetrying:
- def after_attempt(retry_state: tenacity.RetryCallState) -> None:
- self._logger.warning(
- f"{self.__class__.__name__}: connection failure "
- f"(attempt {retry_state.attempt_number}): "
- f"{retry_state.outcome.exception()}",
- )
+ @property
+ def _temporary_failure_exceptions(self) -> tuple[type[Exception], ...]:
+ return OSError, InterfaceError
- return tenacity.AsyncRetrying(
- stop=self.retry_settings.stop,
- wait=self.retry_settings.wait,
- retry=tenacity.retry_if_exception_type((OSError, InterfaceError)),
- after=after_attempt,
- sleep=anyio.sleep,
- reraise=True,
+ async def start(self, exit_stack: AsyncExitStack) -> None:
+ await super().start(exit_stack)
+ self._send = cast(
+ MemoryObjectSendStream[str],
+ await self._task_group.start(self._listen_notifications),
)
-
- async def start(self) -> None:
- await super().start()
- try:
- self._send = cast(
- MemoryObjectSendStream[str],
- await self._task_group.start(self._listen_notifications),
- )
- except BaseException:
- await super().stop(force=True)
- raise
-
- async def stop(self, *, force: bool = False) -> None:
- self._send.close()
- await super().stop(force=force)
- self._logger.info("Stopped event broker")
+ await exit_stack.enter_async_context(self._send)
async def _listen_notifications(self, *, task_status=TASK_STATUS_IGNORED) -> None:
conn: Connection
diff --git a/src/apscheduler/eventbrokers/base.py b/src/apscheduler/eventbrokers/base.py
index 1373d54..5a3e718 100644
--- a/src/apscheduler/eventbrokers/base.py
+++ b/src/apscheduler/eventbrokers/base.py
@@ -1,15 +1,21 @@
from __future__ import annotations
from base64 import b64decode, b64encode
+from contextlib import AsyncExitStack
+from inspect import iscoroutine
from logging import Logger, getLogger
from typing import Any, Callable, Iterable
import attrs
+from anyio import create_task_group, to_thread
+from anyio.abc import TaskGroup
from .. import _events
from .._events import Event
from .._exceptions import DeserializationError
-from ..abc import EventSource, Serializer, Subscription
+from .._retry import RetryMixin
+from ..abc import EventBroker, Serializer, Subscription
+from ..serializers.json import JSONSerializer
@attrs.define(eq=False, frozen=True)
@@ -17,6 +23,7 @@ class LocalSubscription(Subscription):
callback: Callable[[Event], Any]
event_types: set[type[Event]] | None
one_shot: bool
+ is_async: bool
token: object
_source: BaseEventBroker
@@ -24,36 +31,79 @@ class LocalSubscription(Subscription):
self._source.unsubscribe(self.token)
-@attrs.define(eq=False)
-class BaseEventBroker(EventSource):
+@attrs.define(kw_only=True)
+class BaseEventBroker(EventBroker):
_logger: Logger = attrs.field(init=False)
_subscriptions: dict[object, LocalSubscription] = attrs.field(
init=False, factory=dict
)
+ _task_group: TaskGroup = attrs.field(init=False)
def __attrs_post_init__(self) -> None:
self._logger = getLogger(self.__class__.__module__)
+ async def start(self, exit_stack: AsyncExitStack) -> None:
+ self._task_group = await exit_stack.enter_async_context(create_task_group())
+
def subscribe(
self,
callback: Callable[[Event], Any],
event_types: Iterable[type[Event]] | None = None,
*,
+ is_async: bool = True,
one_shot: bool = False,
) -> Subscription:
types = set(event_types) if event_types else None
token = object()
- subscription = LocalSubscription(callback, types, one_shot, token, self)
+ subscription = LocalSubscription(
+ callback, types, one_shot, is_async, token, self
+ )
self._subscriptions[token] = subscription
return subscription
def unsubscribe(self, token: object) -> None:
self._subscriptions.pop(token, None)
+ async def publish_local(self, event: Event) -> None:
+ event_type = type(event)
+ one_shot_tokens: list[object] = []
+ for _token, subscription in self._subscriptions.items():
+ if (
+ subscription.event_types is None
+ or event_type in subscription.event_types
+ ):
+ self._task_group.start_soon(self._deliver_event, subscription, event)
+ if subscription.one_shot:
+ one_shot_tokens.append(subscription.token)
+
+ for token in one_shot_tokens:
+ self.unsubscribe(token)
+
+ async def _deliver_event(
+ self, subscription: LocalSubscription, event: Event
+ ) -> None:
+ try:
+ if subscription.is_async:
+ retval = subscription.callback(event)
+ if iscoroutine(retval):
+ await retval
+ else:
+ await to_thread.run_sync(subscription.callback, event)
+ except Exception:
+ self._logger.exception(
+ "Error delivering %s event", event.__class__.__name__
+ )
+
+
+@attrs.define(kw_only=True)
+class BaseExternalEventBroker(BaseEventBroker, RetryMixin):
+ """
+ Base class for event brokers that use an external service.
+
+ :param serializer: the serializer used to (de)serialize events for transport
+ """
-class DistributedEventBrokerMixin:
- serializer: Serializer
- _logger: Logger
+ serializer: Serializer = attrs.field(factory=JSONSerializer)
def generate_notification(self, event: Event) -> bytes:
serialized = self.serializer.serialize(event.marshal(self.serializer))
diff --git a/src/apscheduler/eventbrokers/local.py b/src/apscheduler/eventbrokers/local.py
index 25ff2dd..27a3cfd 100644
--- a/src/apscheduler/eventbrokers/local.py
+++ b/src/apscheduler/eventbrokers/local.py
@@ -1,22 +1,15 @@
from __future__ import annotations
-from asyncio import iscoroutinefunction
-from concurrent.futures import ThreadPoolExecutor
-from contextlib import ExitStack
-from threading import Lock
-from typing import Any, Callable, Iterable
-
import attrs
from .._events import Event
-from ..abc import EventBroker, Subscription
from .base import BaseEventBroker
@attrs.define(eq=False)
-class LocalEventBroker(EventBroker, BaseEventBroker):
+class LocalEventBroker(BaseEventBroker):
"""
- Synchronous, local event broker.
+ Asynchronous, local event broker.
This event broker only broadcasts within the process it runs in, and is therefore
not suitable for multi-node or multiprocess use cases.
@@ -24,62 +17,5 @@ class LocalEventBroker(EventBroker, BaseEventBroker):
Does not serialize events.
"""
- _executor: ThreadPoolExecutor = attrs.field(init=False)
- _exit_stack: ExitStack = attrs.field(init=False)
- _subscriptions_lock: Lock = attrs.field(init=False, factory=Lock)
-
- def start(self) -> None:
- self._executor = ThreadPoolExecutor(1)
-
- def stop(self, *, force: bool = False) -> None:
- self._executor.shutdown(wait=not force)
- del self._executor
-
- def subscribe(
- self,
- callback: Callable[[Event], Any],
- event_types: Iterable[type[Event]] | None = None,
- *,
- one_shot: bool = False,
- ) -> Subscription:
- if iscoroutinefunction(callback):
- raise ValueError(
- "Coroutine functions are not supported as callbacks on a synchronous "
- "event source"
- )
-
- with self._subscriptions_lock:
- return super().subscribe(callback, event_types, one_shot=one_shot)
-
- def unsubscribe(self, token: object) -> None:
- with self._subscriptions_lock:
- super().unsubscribe(token)
-
- def publish(self, event: Event) -> None:
- self.publish_local(event)
-
- def publish_local(self, event: Event) -> None:
- event_type = type(event)
- with self._subscriptions_lock:
- one_shot_tokens: list[object] = []
- for _token, subscription in self._subscriptions.items():
- if (
- subscription.event_types is None
- or event_type in subscription.event_types
- ):
- self._executor.submit(
- self._deliver_event, subscription.callback, event
- )
- if subscription.one_shot:
- one_shot_tokens.append(subscription.token)
-
- for token in one_shot_tokens:
- super().unsubscribe(token)
-
- def _deliver_event(self, func: Callable[[Event], Any], event: Event) -> None:
- try:
- func(event)
- except BaseException:
- self._logger.exception(
- "Error delivering %s event", event.__class__.__name__
- )
+ async def publish(self, event: Event) -> None:
+ await self.publish_local(event)
diff --git a/src/apscheduler/eventbrokers/mqtt.py b/src/apscheduler/eventbrokers/mqtt.py
index 10ac605..567418d 100644
--- a/src/apscheduler/eventbrokers/mqtt.py
+++ b/src/apscheduler/eventbrokers/mqtt.py
@@ -2,22 +2,22 @@ from __future__ import annotations
import sys
from concurrent.futures import Future
+from contextlib import AsyncExitStack
from typing import Any
import attrs
+from anyio import to_thread
+from anyio.from_thread import BlockingPortal
from paho.mqtt.client import Client, MQTTMessage
from paho.mqtt.properties import Properties
from paho.mqtt.reasoncodes import ReasonCodes
from .._events import Event
-from ..abc import Serializer
-from ..serializers.json import JSONSerializer
-from .base import DistributedEventBrokerMixin
-from .local import LocalEventBroker
+from .base import BaseExternalEventBroker
@attrs.define(eq=False)
-class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
+class MQTTEventBroker(BaseExternalEventBroker):
"""
An event broker that uses an MQTT (v3.1 or v5) broker to broadcast events.
@@ -26,7 +26,6 @@ class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
.. _paho-mqtt: https://pypi.org/project/paho-mqtt/
:param client: a paho-mqtt client
- :param serializer: the serializer used to (de)serialize events for transport
:param host: host name or IP address to connect to
:param port: TCP port number to connect to
:param topic: topic on which to send the messages
@@ -35,16 +34,17 @@ class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
"""
client: Client = attrs.field(factory=Client)
- serializer: Serializer = attrs.field(factory=JSONSerializer)
host: str = attrs.field(kw_only=True, default="localhost")
port: int = attrs.field(kw_only=True, default=1883)
topic: str = attrs.field(kw_only=True, default="apscheduler")
subscribe_qos: int = attrs.field(kw_only=True, default=0)
publish_qos: int = attrs.field(kw_only=True, default=0)
+ _portal: BlockingPortal = attrs.field(init=False)
_ready_future: Future[None] = attrs.field(init=False)
- def start(self) -> None:
- super().start()
+ async def start(self, exit_stack: AsyncExitStack) -> None:
+ await super().start(exit_stack)
+ self._portal = await exit_stack.enter_async_context(BlockingPortal())
self._ready_future = Future()
self.client.on_connect = self._on_connect
self.client.on_connect_fail = self._on_connect_fail
@@ -53,12 +53,9 @@ class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
self.client.on_subscribe = self._on_subscribe
self.client.connect(self.host, self.port)
self.client.loop_start()
- self._ready_future.result(10)
-
- def stop(self, *, force: bool = False) -> None:
- self.client.disconnect()
- self.client.loop_stop(force=force)
- super().stop()
+ exit_stack.push_async_callback(to_thread.run_sync, self.client.loop_stop)
+ await to_thread.run_sync(self._ready_future.result, 10)
+ exit_stack.push_async_callback(to_thread.run_sync, self.client.disconnect)
def _on_connect(
self,
@@ -97,8 +94,10 @@ class MQTTEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
def _on_message(self, client: Client, userdata: Any, msg: MQTTMessage) -> None:
event = self.reconstitute_event(msg.payload)
if event is not None:
- self.publish_local(event)
+ self._portal.call(self.publish_local, event)
- def publish(self, event: Event) -> None:
+ async def publish(self, event: Event) -> None:
notification = self.generate_notification(event)
- self.client.publish(self.topic, notification, qos=self.publish_qos)
+ await to_thread.run_sync(
+ lambda: self.client.publish(self.topic, notification, qos=self.publish_qos)
+ )
diff --git a/src/apscheduler/eventbrokers/redis.py b/src/apscheduler/eventbrokers/redis.py
index 6683276..10d2343 100644
--- a/src/apscheduler/eventbrokers/redis.py
+++ b/src/apscheduler/eventbrokers/redis.py
@@ -1,22 +1,22 @@
from __future__ import annotations
-from threading import Thread
+from asyncio import CancelledError
+from contextlib import AsyncExitStack
+import anyio
import attrs
import tenacity
-from redis import ConnectionError, ConnectionPool, Redis, RedisCluster
-from redis.client import PubSub
+from redis import ConnectionError
+from redis.asyncio import Redis, RedisCluster
+from redis.asyncio.client import PubSub
+from redis.asyncio.connection import ConnectionPool
-from .. import RetrySettings
from .._events import Event
-from ..abc import Serializer
-from ..serializers.json import JSONSerializer
-from .base import DistributedEventBrokerMixin
-from .local import LocalEventBroker
+from .base import BaseExternalEventBroker
@attrs.define(eq=False)
-class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
+class RedisEventBroker(BaseExternalEventBroker):
"""
An event broker that uses a Redis server to broadcast events.
@@ -24,8 +24,7 @@ class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
.. _redis: https://pypi.org/project/redis/
- :param client: a (synchronous) Redis client
- :param serializer: the serializer used to (de)serialize events for transport
+ :param client: an asynchronous Redis client
:param channel: channel on which to send the messages
:param stop_check_interval: interval (in seconds) on which the channel listener
should check if it should stop (higher values mean slower reaction time but less
@@ -33,12 +32,9 @@ class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
"""
client: Redis | RedisCluster
- serializer: Serializer = attrs.field(factory=JSONSerializer)
channel: str = attrs.field(kw_only=True, default="apscheduler")
stop_check_interval: float = attrs.field(kw_only=True, default=1)
- retry_settings: RetrySettings = attrs.field(default=RetrySettings())
_stopped: bool = attrs.field(init=False, default=True)
- _thread: Thread = attrs.field(init=False)
@classmethod
def from_url(cls, url: str, **kwargs) -> RedisEventBroker:
@@ -54,7 +50,7 @@ class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
client = Redis(connection_pool=pool)
return cls(client, **kwargs)
- def _retry(self) -> tenacity.Retrying:
+ def _retry(self) -> tenacity.AsyncRetrying:
def after_attempt(retry_state: tenacity.RetryCallState) -> None:
self._logger.warning(
f"{self.__class__.__name__}: connection failure "
@@ -62,40 +58,32 @@ class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
f"{retry_state.outcome.exception()}",
)
- return tenacity.Retrying(
+ return tenacity.AsyncRetrying(
stop=self.retry_settings.stop,
wait=self.retry_settings.wait,
retry=tenacity.retry_if_exception_type(ConnectionError),
after=after_attempt,
+ sleep=anyio.sleep,
reraise=True,
)
- def start(self) -> None:
+ async def start(self, exit_stack: AsyncExitStack) -> None:
+ await super().start(exit_stack)
pubsub = self.client.pubsub()
- pubsub.subscribe(self.channel)
+ await pubsub.subscribe(self.channel)
+
self._stopped = False
- super().start()
- self._thread = Thread(
- target=self._listen_messages,
- args=[pubsub],
- daemon=True,
- name="Redis subscriber",
+ exit_stack.callback(setattr, self, "_stopped", True)
+ self._task_group.start_soon(
+ self._listen_messages, pubsub, name="Redis subscriber"
)
- self._thread.start()
-
- def stop(self, *, force: bool = False) -> None:
- self._stopped = True
- if not force:
- self._thread.join(5)
- super().stop(force=force)
-
- def _listen_messages(self, pubsub: PubSub) -> None:
+ async def _listen_messages(self, pubsub: PubSub) -> None:
while not self._stopped:
try:
- for attempt in self._retry():
+ async for attempt in self._retry():
with attempt:
- msg = pubsub.get_message(
+ msg = await pubsub.get_message(
ignore_subscribe_messages=True,
timeout=self.stop_check_interval,
)
@@ -103,16 +91,19 @@ class RedisEventBroker(LocalEventBroker, DistributedEventBrokerMixin):
if msg and isinstance(msg["data"], bytes):
event = self.reconstitute_event(msg["data"])
if event is not None:
- self.publish_local(event)
- except Exception:
- self._logger.exception(f"{self.__class__.__name__} listener crashed")
- pubsub.close()
+ await self.publish_local(event)
+ except Exception as exc:
+ # CancelledError is a subclass of Exception in Python 3.7
+ if not isinstance(exc, CancelledError):
+ self._logger.exception(
+ f"{self.__class__.__name__} listener crashed"
+ )
+
+ await pubsub.close()
raise
- pubsub.close()
-
- def publish(self, event: Event) -> None:
+ async def publish(self, event: Event) -> None:
notification = self.generate_notification(event)
- for attempt in self._retry():
+ async for attempt in self._retry():
with attempt:
- self.client.publish(self.channel, notification)
+ await self.client.publish(self.channel, notification)
diff --git a/src/apscheduler/workers/__init__.py b/src/apscheduler/executors/__init__.py
index e69de29..e69de29 100644
--- a/src/apscheduler/workers/__init__.py
+++ b/src/apscheduler/executors/__init__.py
diff --git a/src/apscheduler/executors/async_.py b/src/apscheduler/executors/async_.py
new file mode 100644
index 0000000..c7924ad
--- /dev/null
+++ b/src/apscheduler/executors/async_.py
@@ -0,0 +1,24 @@
+from __future__ import annotations
+
+from collections.abc import Callable
+from inspect import isawaitable
+from typing import Any
+
+from .._structures import Job
+from ..abc import JobExecutor
+
+
+class AsyncJobExecutor(JobExecutor):
+ """
+ Executes functions directly on the event loop thread.
+
+ If the function returns a coroutine object (or another kind of awaitable), that is
+ awaited on and its return value is used as the job's return value.
+ """
+
+ async def run_job(self, func: Callable[..., Any], job: Job) -> Any:
+ retval = func(*job.args, **job.kwargs)
+ if isawaitable(retval):
+ retval = await retval
+
+ return retval
diff --git a/src/apscheduler/executors/subprocess.py b/src/apscheduler/executors/subprocess.py
new file mode 100644
index 0000000..e766e71
--- /dev/null
+++ b/src/apscheduler/executors/subprocess.py
@@ -0,0 +1,33 @@
+from __future__ import annotations
+
+from collections.abc import Callable
+from contextlib import AsyncExitStack
+from functools import partial
+from typing import Any
+
+import attrs
+from anyio import CapacityLimiter, to_process
+
+from .._structures import Job
+from ..abc import JobExecutor
+
+
+@attrs.define(eq=False, kw_only=True)
+class ProcessPoolJobExecutor(JobExecutor):
+ """
+ Executes functions in a process pool.
+
+ :param max_workers: the maximum number of worker processes to keep
+ """
+
+ max_workers: int = 40
+ _limiter: CapacityLimiter = attrs.field(init=False)
+
+ async def start(self, exit_stack: AsyncExitStack) -> None:
+ self._limiter = CapacityLimiter(self.max_workers)
+
+ async def run_job(self, func: Callable[..., Any], job: Job) -> Any:
+ wrapped = partial(func, *job.args, **job.kwargs)
+ return await to_process.run_sync(
+ wrapped, cancellable=True, limiter=self._limiter
+ )
diff --git a/src/apscheduler/executors/thread.py b/src/apscheduler/executors/thread.py
new file mode 100644
index 0000000..9774093
--- /dev/null
+++ b/src/apscheduler/executors/thread.py
@@ -0,0 +1,31 @@
+from __future__ import annotations
+
+from collections.abc import Callable
+from contextlib import AsyncExitStack
+from functools import partial
+from typing import Any
+
+import attrs
+from anyio import CapacityLimiter, to_thread
+
+from .._structures import Job
+from ..abc import JobExecutor
+
+
+@attrs.define(eq=False, kw_only=True)
+class ThreadPoolJobExecutor(JobExecutor):
+ """
+ Executes functions in a thread pool.
+
+ :param max_workers: the maximum number of worker threads to keep
+ """
+
+ max_workers: int = 40
+ _limiter: CapacityLimiter = attrs.field(init=False)
+
+ async def start(self, exit_stack: AsyncExitStack) -> None:
+ self._limiter = CapacityLimiter(self.max_workers)
+
+ async def run_job(self, func: Callable[..., Any], job: Job) -> Any:
+ wrapped = partial(func, *job.args, **job.kwargs)
+ return await to_thread.run_sync(wrapped, limiter=self._limiter)
diff --git a/src/apscheduler/schedulers/async_.py b/src/apscheduler/schedulers/async_.py
index 75972b5..44fee27 100644
--- a/src/apscheduler/schedulers/async_.py
+++ b/src/apscheduler/schedulers/async_.py
@@ -5,6 +5,7 @@ import platform
import random
import sys
from asyncio import CancelledError
+from collections.abc import MutableMapping
from contextlib import AsyncExitStack
from datetime import datetime, timedelta, timezone
from logging import Logger, getLogger
@@ -16,9 +17,9 @@ import anyio
import attrs
from anyio import TASK_STATUS_IGNORED, create_task_group, move_on_after
from anyio.abc import TaskGroup, TaskStatus
+from attr.validators import instance_of
-from .._context import current_scheduler
-from .._converters import as_async_datastore, as_async_eventbroker
+from .._context import current_async_scheduler
from .._enums import CoalescePolicy, ConflictPolicy, JobOutcome, RunState
from .._events import (
Event,
@@ -35,11 +36,14 @@ from .._exceptions import (
ScheduleLookupError,
)
from .._structures import Job, JobResult, Schedule, Task
-from ..abc import AsyncDataStore, AsyncEventBroker, Subscription, Trigger
+from .._worker import Worker
+from ..abc import DataStore, EventBroker, JobExecutor, Subscription, Trigger
from ..datastores.memory import MemoryDataStore
-from ..eventbrokers.async_local import LocalAsyncEventBroker
+from ..eventbrokers.local import LocalEventBroker
+from ..executors.async_ import AsyncJobExecutor
+from ..executors.subprocess import ProcessPoolJobExecutor
+from ..executors.thread import ThreadPoolJobExecutor
from ..marshalling import callable_to_ref
-from ..workers.async_ import AsyncWorker
if sys.version_info >= (3, 11):
from typing import Self
@@ -54,19 +58,24 @@ _zero_timedelta = timedelta()
class AsyncScheduler:
"""An asynchronous (AnyIO based) scheduler implementation."""
- data_store: AsyncDataStore = attrs.field(
- converter=as_async_datastore, factory=MemoryDataStore
+ data_store: DataStore = attrs.field(
+ validator=instance_of(DataStore), factory=MemoryDataStore
)
- event_broker: AsyncEventBroker = attrs.field(
- converter=as_async_eventbroker, factory=LocalAsyncEventBroker
+ event_broker: EventBroker = attrs.field(
+ validator=instance_of(EventBroker), factory=LocalEventBroker
)
identity: str = attrs.field(kw_only=True, default=None)
- start_worker: bool = attrs.field(kw_only=True, default=True)
+ process_jobs: bool = attrs.field(kw_only=True, default=True)
+ job_executors: MutableMapping[str, JobExecutor] | None = attrs.field(
+ kw_only=True, default=None
+ )
+ default_job_executor: str | None = attrs.field(kw_only=True, default=None)
+ process_schedules: bool = attrs.field(kw_only=True, default=True)
logger: Logger | None = attrs.field(kw_only=True, default=getLogger(__name__))
_state: RunState = attrs.field(init=False, default=RunState.stopped)
_task_group: TaskGroup | None = attrs.field(init=False, default=None)
- _exit_stack: AsyncExitStack | None = attrs.field(init=False, default=None)
+ _exit_stack: AsyncExitStack = attrs.field(init=False, factory=AsyncExitStack)
_services_initialized: bool = attrs.field(init=False, default=False)
_wakeup_event: anyio.Event = attrs.field(init=False)
_wakeup_deadline: datetime | None = attrs.field(init=False, default=None)
@@ -76,13 +85,32 @@ class AsyncScheduler:
if not self.identity:
self.identity = f"{platform.node()}-{os.getpid()}-{id(self)}"
+ if not self.job_executors:
+ self.job_executors = {
+ "async": AsyncJobExecutor(),
+ "threadpool": ThreadPoolJobExecutor(),
+ "processpool": ProcessPoolJobExecutor(),
+ }
+
+ if not self.default_job_executor:
+ self.default_job_executor = next(iter(self.job_executors))
+ elif self.default_job_executor not in self.job_executors:
+ raise ValueError(
+ "default_job_executor must be one of the given job executors"
+ )
+
async def __aenter__(self: Self) -> Self:
- self._exit_stack = AsyncExitStack()
await self._exit_stack.__aenter__()
- await self._ensure_services_ready(self._exit_stack)
- self._task_group = await self._exit_stack.enter_async_context(
- create_task_group()
- )
+ try:
+ await self._ensure_services_initialized(self._exit_stack)
+ self._task_group = await self._exit_stack.enter_async_context(
+ create_task_group()
+ )
+ self._exit_stack.callback(setattr, self, "_task_group", None)
+ except BaseException as exc:
+ await self._exit_stack.__aexit__(type(exc), exc, exc.__traceback__)
+ raise
+
return self
async def __aexit__(
@@ -93,27 +121,25 @@ class AsyncScheduler:
) -> None:
await self.stop()
await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
- self._task_group = None
- async def _ensure_services_ready(self, exit_stack: AsyncExitStack) -> None:
+ async def _ensure_services_initialized(self, exit_stack: AsyncExitStack) -> None:
"""Ensure that the data store and event broker have been initialized."""
if not self._services_initialized:
self._services_initialized = True
exit_stack.callback(setattr, self, "_services_initialized", False)
- # Initialize the event broker
- await self.event_broker.start()
- exit_stack.push_async_exit(
- lambda *exc_info: self.event_broker.stop(force=exc_info[0] is not None)
- )
+ await self.event_broker.start(exit_stack)
+ await self.data_store.start(exit_stack, self.event_broker)
- # Initialize the data store
- await self.data_store.start(self.event_broker)
- exit_stack.push_async_exit(
- lambda *exc_info: self.data_store.stop(force=exc_info[0] is not None)
+ def _check_initialized(self) -> None:
+ if not self._services_initialized:
+ raise RuntimeError(
+ "The scheduler has not been initialized yet. Use the scheduler as an "
+ "async context manager (async with ...) in order to call methods other "
+ "than run_until_complete()."
)
- def _schedule_added_or_modified(self, event: Event) -> None:
+ async def _schedule_added_or_modified(self, event: Event) -> None:
event_ = cast("ScheduleAdded | ScheduleUpdated", event)
if not self._wakeup_deadline or (
event_.next_fire_time and event_.next_fire_time < self._wakeup_deadline
@@ -128,6 +154,35 @@ class AsyncScheduler:
"""The current running state of the scheduler."""
return self._state
+ def subscribe(
+ self,
+ callback: Callable[[Event], Any],
+ event_types: Iterable[type[Event]] | None = None,
+ *,
+ one_shot: bool = False,
+ is_async: bool = True,
+ ) -> Subscription:
+ """
+ Subscribe to events.
+
+ To unsubscribe, call the :meth:`Subscription.unsubscribe` method on the returned
+ object.
+
+ :param callback: callable to be called with the event object when an event is
+ published
+ :param event_types: an iterable of concrete Event classes to subscribe to
+ :param one_shot: if ``True``, automatically unsubscribe after the first matching
+ event
+ :param is_async: ``True`` if the (synchronous) callback should be called on the
+ event loop thread, ``False`` if it should be called in a worker thread.
+ If the callback is a coroutine function, this flag is ignored.
+
+ """
+ self._check_initialized()
+ return self.event_broker.subscribe(
+ callback, event_types, is_async=is_async, one_shot=one_shot
+ )
+
async def add_schedule(
self,
func_or_task_id: str | Callable,
@@ -136,6 +191,7 @@ class AsyncScheduler:
id: str | None = None,
args: Iterable | None = None,
kwargs: Mapping[str, Any] | None = None,
+ job_executor: str | None = None,
coalesce: CoalescePolicy = CoalescePolicy.latest,
misfire_grace_time: float | timedelta | None = None,
max_jitter: float | timedelta | None = None,
@@ -152,6 +208,7 @@ class AsyncScheduler:
based ID will be assigned)
:param args: positional arguments to be passed to the task function
:param kwargs: keyword arguments to be passed to the task function
+ :param job_executor: name of the job executor to run the task with
:param coalesce: determines what to do when processing the schedule if multiple
fire times have become due for this schedule since the last processing
:param misfire_grace_time: maximum number of seconds the scheduled job's actual
@@ -165,6 +222,7 @@ class AsyncScheduler:
:return: the ID of the newly added schedule
"""
+ self._check_initialized()
id = id or str(uuid4())
args = tuple(args or ())
kwargs = dict(kwargs or {})
@@ -173,7 +231,11 @@ class AsyncScheduler:
misfire_grace_time = timedelta(seconds=misfire_grace_time)
if callable(func_or_task_id):
- task = Task(id=callable_to_ref(func_or_task_id), func=func_or_task_id)
+ task = Task(
+ id=callable_to_ref(func_or_task_id),
+ func=func_or_task_id,
+ executor=job_executor or self.default_job_executor,
+ )
await self.data_store.add_task(task)
else:
task = await self.data_store.get_task(func_or_task_id)
@@ -207,6 +269,7 @@ class AsyncScheduler:
:raises ScheduleLookupError: if the schedule could not be found
"""
+ self._check_initialized()
schedules = await self.data_store.get_schedules({id})
if schedules:
return schedules[0]
@@ -220,6 +283,7 @@ class AsyncScheduler:
:return: a list of schedules, in an unspecified order
"""
+ self._check_initialized()
return await self.data_store.get_schedules()
async def remove_schedule(self, id: str) -> None:
@@ -229,6 +293,7 @@ class AsyncScheduler:
:param id: the unique identifier of the schedule
"""
+ self._check_initialized()
await self.data_store.remove_schedules({id})
async def add_job(
@@ -237,6 +302,7 @@ class AsyncScheduler:
*,
args: Iterable | None = None,
kwargs: Mapping[str, Any] | None = None,
+ job_executor: str | None = None,
tags: Iterable[str] | None = None,
result_expiration_time: timedelta | float = 0,
) -> UUID:
@@ -244,8 +310,10 @@ class AsyncScheduler:
Add a job to the data store.
:param func_or_task_id:
+        either a callable or an ID of an existing task definition
:param args: positional arguments to call the target callable with
:param kwargs: keyword arguments to call the target callable with
+ :param job_executor: name of the job executor to run the task with
:param tags: strings that can be used to categorize and filter the job
:param result_expiration_time: the minimum time (as seconds, or timedelta) to
keep the result of the job available for fetching (the result won't be
@@ -253,8 +321,13 @@ class AsyncScheduler:
:return: the ID of the newly created job
"""
+ self._check_initialized()
if callable(func_or_task_id):
- task = Task(id=callable_to_ref(func_or_task_id), func=func_or_task_id)
+ task = Task(
+ id=callable_to_ref(func_or_task_id),
+ func=func_or_task_id,
+ executor=job_executor or self.default_job_executor,
+ )
await self.data_store.add_task(task)
else:
task = await self.data_store.get_task(func_or_task_id)
@@ -280,13 +353,14 @@ class AsyncScheduler:
the data store
"""
+ self._check_initialized()
wait_event = anyio.Event()
def listener(event: JobReleased) -> None:
if event.job_id == job_id:
wait_event.set()
- with self.data_store.events.subscribe(listener, {JobReleased}):
+ with self.event_broker.subscribe(listener, {JobReleased}):
result = await self.data_store.get_job_result(job_id)
if result:
return result
@@ -303,6 +377,7 @@ class AsyncScheduler:
*,
args: Iterable | None = None,
kwargs: Mapping[str, Any] | None = None,
+ job_executor: str | None = None,
tags: Iterable[str] | None = (),
) -> Any:
"""
@@ -314,10 +389,12 @@ class AsyncScheduler:
definition
:param args: positional arguments to be passed to the task function
:param kwargs: keyword arguments to be passed to the task function
+ :param job_executor: name of the job executor to run the task with
:param tags: strings that can be used to categorize and filter the job
:returns: the return value of the task function
"""
+ self._check_initialized()
job_complete_event = anyio.Event()
def listener(event: JobReleased) -> None:
@@ -325,11 +402,12 @@ class AsyncScheduler:
job_complete_event.set()
job_id: UUID | None = None
- with self.data_store.events.subscribe(listener, {JobReleased}):
+ with self.event_broker.subscribe(listener, {JobReleased}):
job_id = await self.add_job(
func_or_task_id,
args=args,
kwargs=kwargs,
+ job_executor=job_executor,
tags=tags,
result_expiration_time=timedelta(minutes=15),
)
@@ -378,12 +456,7 @@ class AsyncScheduler:
await event.wait()
async def start_in_background(self) -> None:
- if self._task_group is None:
- raise RuntimeError(
- "The scheduler must be used as an async context manager (async with "
- "...) in order to be startable in the background"
- )
-
+ self._check_initialized()
await self._task_group.start(self.run_until_stopped)
async def run_until_stopped(
@@ -398,7 +471,7 @@ class AsyncScheduler:
self._state = RunState.starting
async with AsyncExitStack() as exit_stack:
self._wakeup_event = anyio.Event()
- await self._ensure_services_ready(exit_stack)
+ await self._ensure_services_initialized(exit_stack)
# Wake up the scheduler if the data store emits a significant schedule event
exit_stack.enter_context(
@@ -407,14 +480,16 @@ class AsyncScheduler:
)
)
+ # Set this scheduler as the current scheduler
+ token = current_async_scheduler.set(self)
+ exit_stack.callback(current_async_scheduler.reset, token)
+
# Start the built-in worker, if configured to do so
- if self.start_worker:
- token = current_scheduler.set(self)
- exit_stack.callback(current_scheduler.reset, token)
- worker = AsyncWorker(
- self.data_store, self.event_broker, is_internal=True
+ if self.process_jobs:
+ worker = Worker(job_executors=self.job_executors)
+ await worker.start(
+ exit_stack, self.data_store, self.event_broker, self.identity
)
- await exit_stack.enter_async_context(worker)
# Signal that the scheduler has started
self._state = RunState.started
diff --git a/src/apscheduler/schedulers/sync.py b/src/apscheduler/schedulers/sync.py
index d98161a..3a812a4 100644
--- a/src/apscheduler/schedulers/sync.py
+++ b/src/apscheduler/schedulers/sync.py
@@ -1,77 +1,100 @@
from __future__ import annotations
import atexit
-import os
-import platform
-import random
+import logging
import sys
import threading
-from concurrent.futures import Future
+from collections.abc import MutableMapping
from contextlib import ExitStack
-from datetime import datetime, timedelta, timezone
-from logging import Logger, getLogger
+from datetime import timedelta
+from functools import partial
+from logging import Logger
from types import TracebackType
-from typing import Any, Callable, Iterable, Mapping, cast
-from uuid import UUID, uuid4
-
-import attrs
-
-from .._context import current_scheduler
-from .._enums import CoalescePolicy, ConflictPolicy, JobOutcome, RunState
-from .._events import (
- Event,
- JobReleased,
- ScheduleAdded,
- SchedulerStarted,
- SchedulerStopped,
- ScheduleUpdated,
-)
-from .._exceptions import (
- JobCancelled,
- JobDeadlineMissed,
- JobLookupError,
- ScheduleLookupError,
-)
-from .._structures import Job, JobResult, Schedule, Task
-from ..abc import DataStore, EventBroker, Trigger
-from ..datastores.memory import MemoryDataStore
-from ..eventbrokers.local import LocalEventBroker
-from ..marshalling import callable_to_ref
-from ..workers.sync import Worker
+from typing import Any, Callable, Iterable, Mapping
+from uuid import UUID
+
+from anyio import start_blocking_portal
+from anyio.from_thread import BlockingPortal
+
+from .. import Event, current_scheduler
+from .._enums import CoalescePolicy, ConflictPolicy, RunState
+from .._structures import JobResult, Schedule
+from ..abc import DataStore, EventBroker, JobExecutor, Subscription, Trigger
+from .async_ import AsyncScheduler
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
-_microsecond_delta = timedelta(microseconds=1)
-_zero_timedelta = timedelta()
-
-@attrs.define(eq=False)
class Scheduler:
"""A synchronous scheduler implementation."""
- data_store: DataStore = attrs.field(factory=MemoryDataStore)
- event_broker: EventBroker = attrs.field(factory=LocalEventBroker)
- identity: str = attrs.field(kw_only=True, default=None)
- start_worker: bool = attrs.field(kw_only=True, default=True)
- logger: Logger | None = attrs.field(kw_only=True, default=getLogger(__name__))
+ def __init__(
+ self,
+ data_store: DataStore | None = None,
+ event_broker: EventBroker | None = None,
+ *,
+ identity: str | None = None,
+ process_schedules: bool = True,
+ start_worker: bool = True,
+ job_executors: Mapping[str, JobExecutor] | None = None,
+ default_job_executor: str | None = None,
+ logger: Logger | None = None,
+ ):
+ kwargs: dict[str, Any] = {}
+ if data_store is not None:
+ kwargs["data_store"] = data_store
+ if event_broker is not None:
+ kwargs["event_broker"] = event_broker
+
+ if not default_job_executor and not job_executors:
+ default_job_executor = "threadpool"
+
+ self._async_scheduler = AsyncScheduler(
+ identity=identity,
+ process_schedules=process_schedules,
+ process_jobs=start_worker,
+ job_executors=job_executors,
+ default_job_executor=default_job_executor,
+ logger=logger or logging.getLogger(__name__),
+ **kwargs,
+ )
+ self._exit_stack = ExitStack()
+ self._portal: BlockingPortal | None = None
+ self._lock = threading.RLock()
- _state: RunState = attrs.field(init=False, default=RunState.stopped)
- _thread: threading.Thread | None = attrs.field(init=False, default=None)
- _wakeup_event: threading.Event = attrs.field(init=False, factory=threading.Event)
- _wakeup_deadline: datetime | None = attrs.field(init=False, default=None)
- _services_initialized: bool = attrs.field(init=False, default=False)
- _exit_stack: ExitStack | None = attrs.field(init=False, default=None)
- _lock: threading.RLock = attrs.field(init=False, factory=threading.RLock)
+ @property
+ def data_store(self) -> DataStore:
+ return self._async_scheduler.data_store
- def __attrs_post_init__(self) -> None:
- if not self.identity:
- self.identity = f"{platform.node()}-{os.getpid()}-{id(self)}"
+ @property
+ def event_broker(self) -> EventBroker:
+ return self._async_scheduler.event_broker
+
+ @property
+ def identity(self) -> str:
+ return self._async_scheduler.identity
+
+ @property
+ def process_schedules(self) -> bool:
+ return self._async_scheduler.process_schedules
+
+ @property
+ def start_worker(self) -> bool:
+ return self._async_scheduler.process_jobs
+
+ @property
+ def job_executors(self) -> MutableMapping[str, JobExecutor]:
+ return self._async_scheduler.job_executors
+
+ @property
+ def state(self) -> RunState:
+ """The current running state of the scheduler."""
+ return self._async_scheduler.state
def __enter__(self: Self) -> Self:
- self._exit_stack = ExitStack()
self._ensure_services_ready(self._exit_stack)
return self
@@ -81,13 +104,12 @@ class Scheduler:
exc_val: BaseException,
exc_tb: TracebackType,
) -> None:
- self.stop()
self._exit_stack.__exit__(exc_type, exc_val, exc_tb)
def _ensure_services_ready(self, exit_stack: ExitStack | None = None) -> None:
- """Ensure that the data store and event broker have been initialized."""
+ """Ensure that the underlying asynchronous scheduler has been initialized."""
with self._lock:
- if not self._services_initialized:
+ if self._portal is None:
if exit_stack is None:
if self._exit_stack is None:
self._exit_stack = exit_stack = ExitStack()
@@ -95,43 +117,39 @@ class Scheduler:
else:
exit_stack = self._exit_stack
- self._services_initialized = True
- exit_stack.callback(setattr, self, "_services_initialized", False)
+ # Set this scheduler as the current synchronous scheduler
+ token = current_scheduler.set(self)
+ exit_stack.callback(current_scheduler.reset, token)
- self.event_broker.start()
- exit_stack.push(
- lambda *exc_info: self.event_broker.stop(
- force=exc_info[0] is not None
- )
+ self._portal = exit_stack.enter_context(start_blocking_portal())
+ exit_stack.callback(setattr, self, "_portal", None)
+ exit_stack.enter_context(
+ self._portal.wrap_async_context_manager(self._async_scheduler)
)
- # Initialize the data store
- self.data_store.start(self.event_broker)
- exit_stack.push(
- lambda *exc_info: self.data_store.stop(
- force=exc_info[0] is not None
- )
- )
+ def subscribe(
+ self,
+ callback: Callable[[Event], Any],
+ event_types: Iterable[type[Event]] | None = None,
+ *,
+ one_shot: bool = False,
+ ) -> Subscription:
+ """
+ Subscribe to events.
- def _schedule_added_or_modified(self, event: Event) -> None:
- event_ = cast("ScheduleAdded | ScheduleUpdated", event)
- if not self._wakeup_deadline or (
- event_.next_fire_time and event_.next_fire_time < self._wakeup_deadline
- ):
- self.logger.debug(
- "Detected a %s event – waking up the scheduler", type(event).__name__
- )
- self._wakeup_event.set()
+ To unsubscribe, call the :meth:`Subscription.unsubscribe` method on the returned
+ object.
- def _join_thread(self) -> None:
- if self._thread:
- self._thread.join()
- self._thread = None
+ :param callback: callable to be called with the event object when an event is
+ published
+ :param event_types: an iterable of concrete Event classes to subscribe to
+ :param one_shot: if ``True``, automatically unsubscribe after the first matching
+ event
- @property
- def state(self) -> RunState:
- """The current running state of the scheduler."""
- return self._state
+ """
+ return self.data_store.event_broker.subscribe(
+ callback, event_types, is_async=False, one_shot=one_shot
+ )
def add_schedule(
self,
@@ -141,104 +159,42 @@ class Scheduler:
id: str | None = None,
args: Iterable | None = None,
kwargs: Mapping[str, Any] | None = None,
+ job_executor: str | None = None,
coalesce: CoalescePolicy = CoalescePolicy.latest,
misfire_grace_time: float | timedelta | None = None,
max_jitter: float | timedelta | None = None,
tags: Iterable[str] | None = None,
conflict_policy: ConflictPolicy = ConflictPolicy.do_nothing,
) -> str:
- """
- Schedule a task to be run one or more times in the future.
-
- :param func_or_task_id: either a callable or an ID of an existing task
- definition
- :param trigger: determines the times when the task should be run
- :param id: an explicit identifier for the schedule (if omitted, a random, UUID
- based ID will be assigned)
- :param args: positional arguments to be passed to the task function
- :param kwargs: keyword arguments to be passed to the task function
- :param coalesce: determines what to do when processing the schedule if multiple
- fire times have become due for this schedule since the last processing
- :param misfire_grace_time: maximum number of seconds the scheduled job's actual
- run time is allowed to be late, compared to the scheduled run time
- :param max_jitter: maximum number of seconds to randomly add to the scheduled
- time for each job created from this schedule
- :param tags: strings that can be used to categorize and filter the schedule and
- its derivative jobs
- :param conflict_policy: determines what to do if a schedule with the same ID
- already exists in the data store
- :return: the ID of the newly added schedule
-
- """
self._ensure_services_ready()
- id = id or str(uuid4())
- args = tuple(args or ())
- kwargs = dict(kwargs or {})
- tags = frozenset(tags or ())
- if isinstance(misfire_grace_time, (int, float)):
- misfire_grace_time = timedelta(seconds=misfire_grace_time)
-
- if callable(func_or_task_id):
- task = Task(id=callable_to_ref(func_or_task_id), func=func_or_task_id)
- self.data_store.add_task(task)
- else:
- task = self.data_store.get_task(func_or_task_id)
-
- schedule = Schedule(
- id=id,
- task_id=task.id,
- trigger=trigger,
- args=args,
- kwargs=kwargs,
- coalesce=coalesce,
- misfire_grace_time=misfire_grace_time,
- max_jitter=max_jitter,
- tags=tags,
- )
- schedule.next_fire_time = trigger.next()
- self.data_store.add_schedule(schedule, conflict_policy)
- self.logger.info(
- "Added new schedule (task=%r, trigger=%r); next run time at %s",
- task,
- trigger,
- schedule.next_fire_time,
+ return self._portal.call(
+ partial(
+ self._async_scheduler.add_schedule,
+ func_or_task_id,
+ trigger,
+ id=id,
+ args=args,
+ kwargs=kwargs,
+ job_executor=job_executor,
+ coalesce=coalesce,
+ misfire_grace_time=misfire_grace_time,
+ max_jitter=max_jitter,
+ tags=tags,
+ conflict_policy=conflict_policy,
+ )
)
- return schedule.id
def get_schedule(self, id: str) -> Schedule:
- """
- Retrieve a schedule from the data store.
-
- :param id: the unique identifier of the schedule
- :raises ScheduleLookupError: if the schedule could not be found
-
- """
self._ensure_services_ready()
- schedules = self.data_store.get_schedules({id})
- if schedules:
- return schedules[0]
- else:
- raise ScheduleLookupError(id)
+ return self._portal.call(self._async_scheduler.get_schedule, id)
def get_schedules(self) -> list[Schedule]:
- """
- Retrieve all schedules from the data store.
-
- :return: a list of schedules, in an unspecified order
-
- """
self._ensure_services_ready()
- return self.data_store.get_schedules()
+ return self._portal.call(self._async_scheduler.get_schedules)
def remove_schedule(self, id: str) -> None:
- """
- Remove the given schedule from the data store.
-
- :param id: the unique identifier of the schedule
-
- """
self._ensure_services_ready()
- self.data_store.remove_schedules({id})
+ self._portal.call(self._async_scheduler.remove_schedule, id)
def add_job(
self,
@@ -246,68 +202,28 @@ class Scheduler:
*,
args: Iterable | None = None,
kwargs: Mapping[str, Any] | None = None,
+ job_executor: str | None = None,
tags: Iterable[str] | None = None,
result_expiration_time: timedelta | float = 0,
) -> UUID:
- """
- Add a job to the data store.
-
- :param func_or_task_id: either a callable or an ID of an existing task
- definition
- :param args: positional arguments to be passed to the task function
- :param kwargs: keyword arguments to be passed to the task function
- :param tags: strings that can be used to categorize and filter the job
- :param result_expiration_time: the minimum time (as seconds, or timedelta) to
- keep the result of the job available for fetching (the result won't be
- saved at all if that time is 0)
- :return: the ID of the newly created job
-
- """
self._ensure_services_ready()
- if callable(func_or_task_id):
- task = Task(id=callable_to_ref(func_or_task_id), func=func_or_task_id)
- self.data_store.add_task(task)
- else:
- task = self.data_store.get_task(func_or_task_id)
-
- job = Job(
- task_id=task.id,
- args=args or (),
- kwargs=kwargs or {},
- tags=tags or frozenset(),
- result_expiration_time=result_expiration_time,
+ return self._portal.call(
+ partial(
+ self._async_scheduler.add_job,
+ func_or_task_id,
+ args=args,
+ kwargs=kwargs,
+ job_executor=job_executor,
+ tags=tags,
+ result_expiration_time=result_expiration_time,
+ )
)
- self.data_store.add_job(job)
- return job.id
def get_job_result(self, job_id: UUID, *, wait: bool = True) -> JobResult:
- """
- Retrieve the result of a job.
-
- :param job_id: the ID of the job
- :param wait: if ``True``, wait until the job has ended (one way or another),
- ``False`` to raise an exception if the result is not yet available
- :raises JobLookupError: if ``wait=False`` and the job result does not exist in
- the data store
-
- """
self._ensure_services_ready()
- wait_event = threading.Event()
-
- def listener(event: JobReleased) -> None:
- if event.job_id == job_id:
- wait_event.set()
-
- with self.data_store.events.subscribe(listener, {JobReleased}, one_shot=True):
- result = self.data_store.get_job_result(job_id)
- if result:
- return result
- elif not wait:
- raise JobLookupError(job_id)
-
- wait_event.wait()
-
- return self.data_store.get_job_result(job_id)
+ return self._portal.call(
+ partial(self._async_scheduler.get_job_result, job_id, wait=wait)
+ )
def run_job(
self,
@@ -315,50 +231,20 @@ class Scheduler:
*,
args: Iterable | None = None,
kwargs: Mapping[str, Any] | None = None,
+ job_executor: str | None = None,
tags: Iterable[str] | None = (),
) -> Any:
- """
- Convenience method to add a job and then return its result.
-
- If the job raised an exception, that exception will be reraised here.
-
- :param func_or_task_id: either a callable or an ID of an existing task
- definition
- :param args: positional arguments to be passed to the task function
- :param kwargs: keyword arguments to be passed to the task function
- :param tags: strings that can be used to categorize and filter the job
- :returns: the return value of the task function
-
- """
self._ensure_services_ready()
- job_complete_event = threading.Event()
-
- def listener(event: JobReleased) -> None:
- if event.job_id == job_id:
- job_complete_event.set()
-
- job_id: UUID | None = None
- with self.data_store.events.subscribe(listener, {JobReleased}):
- job_id = self.add_job(
+ return self._portal.call(
+ partial(
+ self._async_scheduler.run_job,
func_or_task_id,
args=args,
kwargs=kwargs,
+ job_executor=job_executor,
tags=tags,
- result_expiration_time=timedelta(minutes=15),
)
- job_complete_event.wait()
-
- result = self.get_job_result(job_id)
- if result.outcome is JobOutcome.success:
- return result.return_value
- elif result.outcome is JobOutcome.error:
- raise result.exception
- elif result.outcome is JobOutcome.missed_start_deadline:
- raise JobDeadlineMissed
- elif result.outcome is JobOutcome.cancelled:
- raise JobCancelled
- else:
- raise RuntimeError(f"Unknown job outcome: {result.outcome}")
+ )
def start_in_background(self) -> None:
"""
@@ -370,241 +256,19 @@ class Scheduler:
:raises RuntimeError: if the scheduler is not in the ``stopped`` state
"""
- with self._lock:
- if self._state is not RunState.stopped:
- raise RuntimeError(
- f'Cannot start the scheduler when it is in the "{self._state}" '
- f"state"
- )
-
- self._state = RunState.starting
-
- start_future: Future[None] = Future()
- self._thread = threading.Thread(
- target=self._run, args=[start_future], daemon=True
- )
- self._thread.start()
- try:
- start_future.result()
- except BaseException:
- self._thread = None
- raise
-
- self._exit_stack.callback(self._join_thread)
- self._exit_stack.callback(self.stop)
+ self._ensure_services_ready()
+ self._portal.call(self._async_scheduler.start_in_background)
def stop(self) -> None:
- """
- Signal the scheduler that it should stop processing schedules.
-
- This method does not wait for the scheduler to actually stop.
- For that, see :meth:`wait_until_stopped`.
-
- """
- with self._lock:
- if self._state is RunState.started:
- self._state = RunState.stopping
- self._wakeup_event.set()
+ if self._portal is not None:
+ self._portal.call(self._async_scheduler.stop)
def wait_until_stopped(self) -> None:
- """
- Wait until the scheduler is in the "stopped" or "stopping" state.
-
- If the scheduler is already stopped or in the process of stopping, this method
- returns immediately. Otherwise, it waits until the scheduler posts the
- ``SchedulerStopped`` event.
-
- """
- with self._lock:
- if self._state in (RunState.stopped, RunState.stopping):
- return
-
- event = threading.Event()
- sub = self.event_broker.subscribe(
- lambda ev: event.set(), {SchedulerStopped}, one_shot=True
- )
-
- with sub:
- event.wait()
+ if self._portal is not None:
+ self._portal.call(self._async_scheduler.wait_until_stopped)
def run_until_stopped(self) -> None:
- """
- Run the scheduler (and its internal worker) until it is explicitly stopped.
-
- This method will only return if :meth:`stop` is called.
-
- """
- with self._lock:
- if self._state is not RunState.stopped:
- raise RuntimeError(
- f'Cannot start the scheduler when it is in the "{self._state}" '
- f"state"
- )
-
- self._state = RunState.starting
-
- self._run(None)
-
- def _run(self, start_future: Future[None] | None) -> None:
- assert self._state is RunState.starting
- with self._exit_stack.pop_all() as exit_stack:
- try:
- self._ensure_services_ready(exit_stack)
-
- # Wake up the scheduler if the data store emits a significant schedule
- # event
- exit_stack.enter_context(
- self.data_store.events.subscribe(
- self._schedule_added_or_modified,
- {ScheduleAdded, ScheduleUpdated},
- )
- )
-
- # Start the built-in worker, if configured to do so
- if self.start_worker:
- token = current_scheduler.set(self)
- exit_stack.callback(current_scheduler.reset, token)
- worker = Worker(
- self.data_store, self.event_broker, is_internal=True
- )
- exit_stack.enter_context(worker)
-
- # Signal that the scheduler has started
- self._state = RunState.started
- self.event_broker.publish_local(SchedulerStarted())
- except BaseException as exc:
- if start_future:
- start_future.set_exception(exc)
- return
- else:
- raise
- else:
- if start_future:
- start_future.set_result(None)
-
- exception: BaseException | None = None
- try:
- while self._state is RunState.started:
- schedules = self.data_store.acquire_schedules(self.identity, 100)
- self.logger.debug(
- "Processing %d schedules retrieved from the data store",
- len(schedules),
- )
- now = datetime.now(timezone.utc)
- for schedule in schedules:
- # Calculate a next fire time for the schedule, if possible
- fire_times = [schedule.next_fire_time]
- calculate_next = schedule.trigger.next
- while True:
- try:
- fire_time = calculate_next()
- except Exception:
- self.logger.exception(
- "Error computing next fire time for schedule %r of "
- "task %r – removing schedule",
- schedule.id,
- schedule.task_id,
- )
- break
-
- # Stop if the calculated fire time is in the future
- if fire_time is None or fire_time > now:
- schedule.next_fire_time = fire_time
- break
-
- # Only keep all the fire times if coalesce policy = "all"
- if schedule.coalesce is CoalescePolicy.all:
- fire_times.append(fire_time)
- elif schedule.coalesce is CoalescePolicy.latest:
- fire_times[0] = fire_time
-
- # Add one or more jobs to the job queue
- max_jitter = (
- schedule.max_jitter.total_seconds()
- if schedule.max_jitter
- else 0
- )
- for i, fire_time in enumerate(fire_times):
- # Calculate a jitter if max_jitter > 0
- jitter = _zero_timedelta
- if max_jitter:
- if i + 1 < len(fire_times):
- next_fire_time = fire_times[i + 1]
- else:
- next_fire_time = schedule.next_fire_time
-
- if next_fire_time is not None:
- # Jitter must never be so high that it would cause
- # a fire time to equal or exceed the next fire time
- jitter_s = min(
- [
- max_jitter,
- (
- next_fire_time
- - fire_time
- - _microsecond_delta
- ).total_seconds(),
- ]
- )
- jitter = timedelta(
- seconds=random.uniform(0, jitter_s)
- )
- fire_time += jitter
-
- schedule.last_fire_time = fire_time
- job = Job(
- task_id=schedule.task_id,
- args=schedule.args,
- kwargs=schedule.kwargs,
- schedule_id=schedule.id,
- scheduled_fire_time=fire_time,
- jitter=jitter,
- start_deadline=schedule.next_deadline,
- tags=schedule.tags,
- )
- self.data_store.add_job(job)
-
- # Update the schedules (and release the scheduler's claim on them)
- self.data_store.release_schedules(self.identity, schedules)
-
- # If we received fewer schedules than the maximum amount, sleep
- # until the next schedule is due or the scheduler is explicitly
- # woken up
- wait_time = None
- if len(schedules) < 100:
- self._wakeup_deadline = (
- self.data_store.get_next_schedule_run_time()
- )
- if self._wakeup_deadline:
- wait_time = (
- self._wakeup_deadline - datetime.now(timezone.utc)
- ).total_seconds()
- self.logger.debug(
- "Sleeping %.3f seconds until the next fire time (%s)",
- wait_time,
- self._wakeup_deadline,
- )
- else:
- self.logger.debug("Waiting for any due schedules to appear")
-
- if self._wakeup_event.wait(wait_time):
- self._wakeup_event = threading.Event()
- else:
- self.logger.debug(
- "Processing more schedules on the next iteration"
- )
- except BaseException as exc:
- exception = exc
- raise
- finally:
- self._state = RunState.stopped
- if isinstance(exception, Exception):
- self.logger.exception("Scheduler crashed")
- elif exception:
- self.logger.info(
- f"Scheduler stopped due to {exception.__class__.__name__}"
- )
- else:
- self.logger.info("Scheduler stopped")
-
- self.event_broker.publish_local(SchedulerStopped(exception=exception))
+ with ExitStack() as exit_stack:
+ # Run the async scheduler
+ self._ensure_services_ready(exit_stack)
+ self._portal.call(self._async_scheduler.run_until_stopped)
diff --git a/src/apscheduler/workers/async_.py b/src/apscheduler/workers/async_.py
deleted file mode 100644
index 4ef1bff..0000000
--- a/src/apscheduler/workers/async_.py
+++ /dev/null
@@ -1,251 +0,0 @@
-from __future__ import annotations
-
-import os
-import platform
-from asyncio import CancelledError
-from contextlib import AsyncExitStack
-from datetime import datetime, timezone
-from inspect import isawaitable
-from logging import Logger, getLogger
-from types import TracebackType
-from typing import Callable
-from uuid import UUID
-
-import anyio
-import attrs
-from anyio import (
- TASK_STATUS_IGNORED,
- create_task_group,
- get_cancelled_exc_class,
- move_on_after,
-)
-from anyio.abc import CancelScope, TaskGroup
-
-from .._context import current_job, current_worker
-from .._converters import as_async_datastore, as_async_eventbroker
-from .._enums import JobOutcome, RunState
-from .._events import JobAdded, JobReleased, WorkerStarted, WorkerStopped
-from .._structures import Job, JobInfo, JobResult
-from .._validators import positive_integer
-from ..abc import AsyncDataStore, AsyncEventBroker
-from ..eventbrokers.async_local import LocalAsyncEventBroker
-
-
-@attrs.define(eq=False)
-class AsyncWorker:
- """Runs jobs locally in a task group."""
-
- data_store: AsyncDataStore = attrs.field(converter=as_async_datastore)
- event_broker: AsyncEventBroker = attrs.field(
- converter=as_async_eventbroker, factory=LocalAsyncEventBroker
- )
- max_concurrent_jobs: int = attrs.field(
- kw_only=True, validator=positive_integer, default=100
- )
- identity: str = attrs.field(kw_only=True, default=None)
- logger: Logger | None = attrs.field(kw_only=True, default=getLogger(__name__))
- # True if a scheduler owns this worker
- _is_internal: bool = attrs.field(kw_only=True, default=False)
-
- _state: RunState = attrs.field(init=False, default=RunState.stopped)
- _wakeup_event: anyio.Event = attrs.field(init=False)
- _task_group: TaskGroup = attrs.field(init=False)
- _acquired_jobs: set[Job] = attrs.field(init=False, factory=set)
- _running_jobs: set[UUID] = attrs.field(init=False, factory=set)
-
- def __attrs_post_init__(self) -> None:
- if not self.identity:
- self.identity = f"{platform.node()}-{os.getpid()}-{id(self)}"
-
- async def __aenter__(self) -> AsyncWorker:
- self._task_group = create_task_group()
- await self._task_group.__aenter__()
- await self._task_group.start(self.run_until_stopped)
- return self
-
- async def __aexit__(
- self,
- exc_type: type[BaseException],
- exc_val: BaseException,
- exc_tb: TracebackType,
- ) -> None:
- self._state = RunState.stopping
- self._wakeup_event.set()
- await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
- del self._task_group
- del self._wakeup_event
-
- @property
- def state(self) -> RunState:
- """The current running state of the worker."""
- return self._state
-
- async def run_until_stopped(self, *, task_status=TASK_STATUS_IGNORED) -> None:
- """Run the worker until it is explicitly stopped."""
- if self._state is not RunState.stopped:
- raise RuntimeError(
- f'Cannot start the worker when it is in the "{self._state}" ' f"state"
- )
-
- self._state = RunState.starting
- self._wakeup_event = anyio.Event()
- async with AsyncExitStack() as exit_stack:
- if not self._is_internal:
- # Initialize the event broker
- await self.event_broker.start()
- exit_stack.push_async_exit(
- lambda *exc_info: self.event_broker.stop(
- force=exc_info[0] is not None
- )
- )
-
- # Initialize the data store
- await self.data_store.start(self.event_broker)
- exit_stack.push_async_exit(
- lambda *exc_info: self.data_store.stop(
- force=exc_info[0] is not None
- )
- )
-
- # Set the current worker
- token = current_worker.set(self)
- exit_stack.callback(current_worker.reset, token)
-
- # Wake up the worker if the data store emits a significant job event
- self.event_broker.subscribe(
- lambda event: self._wakeup_event.set(), {JobAdded}
- )
-
- # Signal that the worker has started
- self._state = RunState.started
- task_status.started()
- exception: BaseException | None = None
- try:
- await self.event_broker.publish_local(WorkerStarted())
-
- async with create_task_group() as tg:
- while self._state is RunState.started:
- limit = self.max_concurrent_jobs - len(self._running_jobs)
- jobs = await self.data_store.acquire_jobs(self.identity, limit)
- for job in jobs:
- task = await self.data_store.get_task(job.task_id)
- self._running_jobs.add(job.id)
- tg.start_soon(self._run_job, job, task.func)
-
- await self._wakeup_event.wait()
- self._wakeup_event = anyio.Event()
- except get_cancelled_exc_class():
- pass
- except BaseException as exc:
- exception = exc
- raise
- finally:
- self._state = RunState.stopped
-
- # CancelledError is a subclass of Exception in Python 3.7
- if not exception or isinstance(exception, CancelledError):
- self.logger.info("Worker stopped")
- elif isinstance(exception, Exception):
- self.logger.exception("Worker crashed")
- elif exception:
- self.logger.info(
- f"Worker stopped due to {exception.__class__.__name__}"
- )
-
- with move_on_after(3, shield=True):
- await self.event_broker.publish_local(
- WorkerStopped(exception=exception)
- )
-
- async def stop(self, *, force: bool = False) -> None:
- """
- Signal the worker that it should stop running jobs.
-
- This method does not wait for the worker to actually stop.
-
- """
- if self._state in (RunState.starting, RunState.started):
- self._state = RunState.stopping
- event = anyio.Event()
- self.event_broker.subscribe(
- lambda ev: event.set(), {WorkerStopped}, one_shot=True
- )
- if force:
- self._task_group.cancel_scope.cancel()
- else:
- self._wakeup_event.set()
-
- await event.wait()
-
- async def _run_job(self, job: Job, func: Callable) -> None:
- try:
- # Check if the job started before the deadline
- start_time = datetime.now(timezone.utc)
- if job.start_deadline is not None and start_time > job.start_deadline:
- result = JobResult.from_job(
- job,
- outcome=JobOutcome.missed_start_deadline,
- finished_at=start_time,
- )
- await self.data_store.release_job(self.identity, job.task_id, result)
- await self.event_broker.publish(
- JobReleased.from_result(result, self.identity)
- )
- return
-
- token = current_job.set(JobInfo.from_job(job))
- try:
- retval = func(*job.args, **job.kwargs)
- if isawaitable(retval):
- retval = await retval
- except get_cancelled_exc_class():
- self.logger.info("Job %s was cancelled", job.id)
- with CancelScope(shield=True):
- result = JobResult.from_job(
- job,
- outcome=JobOutcome.cancelled,
- )
- await self.data_store.release_job(
- self.identity, job.task_id, result
- )
- await self.event_broker.publish(
- JobReleased.from_result(result, self.identity)
- )
- except BaseException as exc:
- if isinstance(exc, Exception):
- self.logger.exception("Job %s raised an exception", job.id)
- else:
- self.logger.error(
- "Job %s was aborted due to %s", job.id, exc.__class__.__name__
- )
-
- result = JobResult.from_job(
- job,
- JobOutcome.error,
- exception=exc,
- )
- await self.data_store.release_job(
- self.identity,
- job.task_id,
- result,
- )
- await self.event_broker.publish(
- JobReleased.from_result(result, self.identity)
- )
- if not isinstance(exc, Exception):
- raise
- else:
- self.logger.info("Job %s completed successfully", job.id)
- result = JobResult.from_job(
- job,
- JobOutcome.success,
- return_value=retval,
- )
- await self.data_store.release_job(self.identity, job.task_id, result)
- await self.event_broker.publish(
- JobReleased.from_result(result, self.identity)
- )
- finally:
- current_job.reset(token)
- finally:
- self._running_jobs.remove(job.id)
diff --git a/src/apscheduler/workers/sync.py b/src/apscheduler/workers/sync.py
deleted file mode 100644
index 41279d9..0000000
--- a/src/apscheduler/workers/sync.py
+++ /dev/null
@@ -1,257 +0,0 @@
-from __future__ import annotations
-
-import atexit
-import os
-import platform
-import threading
-from concurrent.futures import Future, ThreadPoolExecutor
-from contextlib import ExitStack
-from contextvars import copy_context
-from datetime import datetime, timezone
-from logging import Logger, getLogger
-from types import TracebackType
-from typing import Callable
-from uuid import UUID
-
-import attrs
-
-from .. import JobReleased
-from .._context import current_job, current_worker
-from .._enums import JobOutcome, RunState
-from .._events import JobAdded, WorkerStarted, WorkerStopped
-from .._structures import Job, JobInfo, JobResult
-from .._validators import positive_integer
-from ..abc import DataStore, EventBroker
-from ..eventbrokers.local import LocalEventBroker
-
-
-@attrs.define(eq=False)
-class Worker:
- """Runs jobs locally in a thread pool."""
-
- data_store: DataStore
- event_broker: EventBroker = attrs.field(factory=LocalEventBroker)
- max_concurrent_jobs: int = attrs.field(
- kw_only=True, validator=positive_integer, default=20
- )
- identity: str = attrs.field(kw_only=True, default=None)
- logger: Logger | None = attrs.field(kw_only=True, default=getLogger(__name__))
- # True if a scheduler owns this worker
- _is_internal: bool = attrs.field(kw_only=True, default=False)
-
- _state: RunState = attrs.field(init=False, default=RunState.stopped)
- _thread: threading.Thread | None = attrs.field(init=False, default=None)
- _wakeup_event: threading.Event = attrs.field(init=False, factory=threading.Event)
- _executor: ThreadPoolExecutor = attrs.field(init=False)
- _acquired_jobs: set[Job] = attrs.field(init=False, factory=set)
- _running_jobs: set[UUID] = attrs.field(init=False, factory=set)
-
- def __attrs_post_init__(self) -> None:
- if not self.identity:
- self.identity = f"{platform.node()}-{os.getpid()}-{id(self)}"
-
- def __enter__(self) -> Worker:
- self.start_in_background()
- return self
-
- def __exit__(
- self,
- exc_type: type[BaseException],
- exc_val: BaseException,
- exc_tb: TracebackType,
- ) -> None:
- self.stop()
-
- @property
- def state(self) -> RunState:
- """The current running state of the worker."""
- return self._state
-
- def start_in_background(self) -> None:
- """
- Launch the worker in a new thread.
-
- This method registers an :mod:`atexit` hook to shut down the worker and wait
- for the thread to finish.
-
- """
- start_future: Future[None] = Future()
- self._thread = threading.Thread(
- target=copy_context().run, args=[self._run, start_future], daemon=True
- )
- self._thread.start()
- try:
- start_future.result()
- except BaseException:
- self._thread = None
- raise
-
- atexit.register(self.stop)
-
- def stop(self) -> None:
- """
- Signal the worker that it should stop running jobs.
-
- This method does not wait for the worker to actually stop.
-
- """
- atexit.unregister(self.stop)
- if self._state is RunState.started:
- self._state = RunState.stopping
- self._wakeup_event.set()
-
- if threading.current_thread() != self._thread:
- self._thread.join()
- self._thread = None
-
- def run_until_stopped(self) -> None:
- """
- Run the worker until it is explicitly stopped.
-
- This method will only return if :meth:`stop` is called.
-
- """
- self._run(None)
-
- def _run(self, start_future: Future[None] | None) -> None:
- with ExitStack() as exit_stack:
- try:
- if self._state is not RunState.stopped:
- raise RuntimeError(
- f'Cannot start the worker when it is in the "{self._state}" '
- f"state"
- )
-
- if not self._is_internal:
- # Initialize the event broker
- self.event_broker.start()
- exit_stack.push(
- lambda *exc_info: self.event_broker.stop(
- force=exc_info[0] is not None
- )
- )
-
- # Initialize the data store
- self.data_store.start(self.event_broker)
- exit_stack.push(
- lambda *exc_info: self.data_store.stop(
- force=exc_info[0] is not None
- )
- )
-
- # Set the current worker
- token = current_worker.set(self)
- exit_stack.callback(current_worker.reset, token)
-
- # Wake up the worker if the data store emits a significant job event
- exit_stack.enter_context(
- self.event_broker.subscribe(
- lambda event: self._wakeup_event.set(), {JobAdded}
- )
- )
-
- # Initialize the thread pool
- executor = ThreadPoolExecutor(max_workers=self.max_concurrent_jobs)
- exit_stack.enter_context(executor)
-
- # Signal that the worker has started
- self._state = RunState.started
- self.event_broker.publish_local(WorkerStarted())
- except BaseException as exc:
- if start_future:
- start_future.set_exception(exc)
- return
- else:
- raise
- else:
- if start_future:
- start_future.set_result(None)
-
- exception: BaseException | None = None
- try:
- while self._state is RunState.started:
- available_slots = self.max_concurrent_jobs - len(self._running_jobs)
- if available_slots:
- jobs = self.data_store.acquire_jobs(
- self.identity, available_slots
- )
- for job in jobs:
- task = self.data_store.get_task(job.task_id)
- self._running_jobs.add(job.id)
- executor.submit(
- copy_context().run, self._run_job, job, task.func
- )
-
- self._wakeup_event.wait()
- self._wakeup_event = threading.Event()
- except BaseException as exc:
- exception = exc
- raise
- finally:
- self._state = RunState.stopped
- if not exception:
- self.logger.info("Worker stopped")
- elif isinstance(exception, Exception):
- self.logger.exception("Worker crashed")
- elif exception:
- self.logger.info(
- f"Worker stopped due to {exception.__class__.__name__}"
- )
-
- self.event_broker.publish_local(WorkerStopped(exception=exception))
-
- def _run_job(self, job: Job, func: Callable) -> None:
- try:
- # Check if the job started before the deadline
- start_time = datetime.now(timezone.utc)
- if job.start_deadline is not None and start_time > job.start_deadline:
- result = JobResult.from_job(
- job, JobOutcome.missed_start_deadline, finished_at=start_time
- )
- self.event_broker.publish(
- JobReleased.from_result(result, self.identity)
- )
- self.data_store.release_job(self.identity, job.task_id, result)
- return
-
- token = current_job.set(JobInfo.from_job(job))
- try:
- retval = func(*job.args, **job.kwargs)
- except BaseException as exc:
- if isinstance(exc, Exception):
- self.logger.exception("Job %s raised an exception", job.id)
- else:
- self.logger.error(
- "Job %s was aborted due to %s", job.id, exc.__class__.__name__
- )
-
- result = JobResult.from_job(
- job,
- JobOutcome.error,
- exception=exc,
- )
- self.data_store.release_job(
- self.identity,
- job.task_id,
- result,
- )
- self.event_broker.publish(
- JobReleased.from_result(result, self.identity)
- )
- if not isinstance(exc, Exception):
- raise
- else:
- self.logger.info("Job %s completed successfully", job.id)
- result = JobResult.from_job(
- job,
- JobOutcome.success,
- return_value=retval,
- )
- self.data_store.release_job(self.identity, job.task_id, result)
- self.event_broker.publish(
- JobReleased.from_result(result, self.identity)
- )
- finally:
- current_job.reset(token)
- finally:
- self._running_jobs.remove(job.id)
diff --git a/tests/conftest.py b/tests/conftest.py
index 31ea9b0..7497b52 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,10 +1,16 @@
from __future__ import annotations
import sys
+from contextlib import AsyncExitStack
+from tempfile import TemporaryDirectory
+from typing import Any, AsyncGenerator, cast
import pytest
+from _pytest.fixtures import SubRequest
+from pytest_lazyfixture import lazy_fixture
-from apscheduler.abc import Serializer
+from apscheduler.abc import DataStore, EventBroker, Serializer
+from apscheduler.datastores.memory import MemoryDataStore
from apscheduler.serializers.cbor import CBORSerializer
from apscheduler.serializers.json import JSONSerializer
from apscheduler.serializers.pickle import PickleSerializer
@@ -34,3 +40,189 @@ def serializer(request) -> Serializer | None:
@pytest.fixture
def anyio_backend() -> str:
return "asyncio"
+
+
+@pytest.fixture
+def local_broker() -> EventBroker:
+ from apscheduler.eventbrokers.local import LocalEventBroker
+
+ return LocalEventBroker()
+
+
+@pytest.fixture
+async def redis_broker(serializer: Serializer) -> EventBroker:
+ from apscheduler.eventbrokers.redis import RedisEventBroker
+
+ broker = RedisEventBroker.from_url(
+ "redis://localhost:6379", serializer=serializer, stop_check_interval=0.05
+ )
+ await broker.client.flushdb()
+ return broker
+
+
+@pytest.fixture
+def mqtt_broker(serializer: Serializer) -> EventBroker:
+ from paho.mqtt.client import Client
+
+ from apscheduler.eventbrokers.mqtt import MQTTEventBroker
+
+ return MQTTEventBroker(Client(), serializer=serializer)
+
+
+@pytest.fixture
+async def asyncpg_broker(serializer: Serializer) -> EventBroker:
+ from apscheduler.eventbrokers.asyncpg import AsyncpgEventBroker
+
+ broker = AsyncpgEventBroker.from_dsn(
+ "postgres://postgres:secret@localhost:5432/testdb", serializer=serializer
+ )
+ return broker
+
+
+@pytest.fixture(
+ params=[
+ pytest.param(lazy_fixture("local_broker"), id="local"),
+ pytest.param(
+ lazy_fixture("asyncpg_broker"),
+ id="asyncpg",
+ marks=[pytest.mark.external_service],
+ ),
+ pytest.param(
+ lazy_fixture("redis_broker"),
+ id="redis",
+ marks=[pytest.mark.external_service],
+ ),
+ pytest.param(
+ lazy_fixture("mqtt_broker"), id="mqtt", marks=[pytest.mark.external_service]
+ ),
+ ]
+)
+async def raw_event_broker(request: SubRequest) -> EventBroker:
+ return cast(EventBroker, request.param)
+
+
+@pytest.fixture
+async def event_broker(
+ raw_event_broker: EventBroker,
+) -> AsyncGenerator[EventBroker, Any]:
+ async with AsyncExitStack() as exit_stack:
+ await raw_event_broker.start(exit_stack)
+ yield raw_event_broker
+
+
+@pytest.fixture
+def memory_store() -> DataStore:
+ yield MemoryDataStore()
+
+
+@pytest.fixture
+def mongodb_store() -> DataStore:
+ from pymongo import MongoClient
+
+ from apscheduler.datastores.mongodb import MongoDBDataStore
+
+ with MongoClient(tz_aware=True, serverSelectionTimeoutMS=1000) as client:
+ yield MongoDBDataStore(client, start_from_scratch=True)
+
+
+@pytest.fixture
+def sqlite_store() -> DataStore:
+ from sqlalchemy.future import create_engine
+
+ from apscheduler.datastores.sqlalchemy import SQLAlchemyDataStore
+
+ with TemporaryDirectory("sqlite_") as tempdir:
+ engine = create_engine(f"sqlite:///{tempdir}/test.db")
+ try:
+ yield SQLAlchemyDataStore(engine)
+ finally:
+ engine.dispose()
+
+
+@pytest.fixture
+def psycopg2_store() -> DataStore:
+ from sqlalchemy.future import create_engine
+
+ from apscheduler.datastores.sqlalchemy import SQLAlchemyDataStore
+
+ engine = create_engine("postgresql+psycopg2://postgres:secret@localhost/testdb")
+ try:
+ yield SQLAlchemyDataStore(engine, schema="alter", start_from_scratch=True)
+ finally:
+ engine.dispose()
+
+
+@pytest.fixture
+def mysql_store() -> DataStore:
+ from sqlalchemy.future import create_engine
+
+ from apscheduler.datastores.sqlalchemy import SQLAlchemyDataStore
+
+ engine = create_engine("mysql+pymysql://root:secret@localhost/testdb")
+ try:
+ yield SQLAlchemyDataStore(engine, start_from_scratch=True)
+ finally:
+ engine.dispose()
+
+
+@pytest.fixture
+async def asyncpg_store() -> DataStore:
+ from sqlalchemy.ext.asyncio import create_async_engine
+
+ from apscheduler.datastores.sqlalchemy import SQLAlchemyDataStore
+
+ engine = create_async_engine(
+ "postgresql+asyncpg://postgres:secret@localhost/testdb", future=True
+ )
+ try:
+ yield SQLAlchemyDataStore(engine, start_from_scratch=True)
+ finally:
+ await engine.dispose()
+
+
+@pytest.fixture
+async def asyncmy_store() -> DataStore:
+ from sqlalchemy.ext.asyncio import create_async_engine
+
+ from apscheduler.datastores.sqlalchemy import SQLAlchemyDataStore
+
+ engine = create_async_engine(
+ "mysql+asyncmy://root:secret@localhost/testdb?charset=utf8mb4", future=True
+ )
+ try:
+ yield SQLAlchemyDataStore(engine, start_from_scratch=True)
+ finally:
+ await engine.dispose()
+
+
+@pytest.fixture(
+ params=[
+ pytest.param(
+ lazy_fixture("asyncpg_store"),
+ id="asyncpg",
+ marks=[pytest.mark.external_service],
+ ),
+ pytest.param(
+ lazy_fixture("asyncmy_store"),
+ id="asyncmy",
+ marks=[pytest.mark.external_service],
+ ),
+ pytest.param(
+ lazy_fixture("mongodb_store"),
+ id="mongodb",
+ marks=[pytest.mark.external_service],
+ ),
+ ]
+)
+async def raw_datastore(request: SubRequest) -> DataStore:
+ return cast(DataStore, request.param)
+
+
+@pytest.fixture
+async def datastore(
+ raw_datastore: DataStore, local_broker: EventBroker
+) -> AsyncGenerator[DataStore, Any]:
+ async with AsyncExitStack() as exit_stack:
+ await local_broker.start(exit_stack)
+ await raw_datastore.start(exit_stack, local_broker)
+ yield raw_datastore
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index 9c9a0cf..618369a 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -1,18 +1,13 @@
from __future__ import annotations
-import threading
-from collections.abc import Generator
-from contextlib import asynccontextmanager, contextmanager
+from contextlib import AsyncExitStack, asynccontextmanager
from datetime import datetime, timedelta, timezone
-from tempfile import TemporaryDirectory
-from typing import Any, AsyncGenerator, cast
+from typing import AsyncGenerator
import anyio
import pytest
-from _pytest.fixtures import SubRequest
from anyio import CancelScope
from freezegun.api import FrozenDateTimeFactory
-from pytest_lazyfixture import lazy_fixture
from apscheduler import (
CoalescePolicy,
@@ -30,1055 +25,483 @@ from apscheduler import (
TaskLookupError,
TaskUpdated,
)
-from apscheduler.abc import AsyncDataStore, AsyncEventBroker, DataStore, EventBroker
-from apscheduler.datastores.async_adapter import AsyncDataStoreAdapter
-from apscheduler.datastores.memory import MemoryDataStore
-from apscheduler.eventbrokers.async_local import LocalAsyncEventBroker
-from apscheduler.eventbrokers.local import LocalEventBroker
+from apscheduler.abc import DataStore, EventBroker
from apscheduler.triggers.date import DateTrigger
+pytestmark = pytest.mark.anyio
+
@pytest.fixture
-def memory_store() -> DataStore:
- yield MemoryDataStore()
+def schedules() -> list[Schedule]:
+ trigger = DateTrigger(datetime(2020, 9, 13, tzinfo=timezone.utc))
+ schedule1 = Schedule(id="s1", task_id="task1", trigger=trigger)
+ schedule1.next_fire_time = trigger.next()
+ trigger = DateTrigger(datetime(2020, 9, 14, tzinfo=timezone.utc))
+ schedule2 = Schedule(id="s2", task_id="task2", trigger=trigger)
+ schedule2.next_fire_time = trigger.next()
-@pytest.fixture
-def adapted_memory_store() -> AsyncDataStore:
- store = MemoryDataStore()
- return AsyncDataStoreAdapter(store)
+ trigger = DateTrigger(datetime(2020, 9, 15, tzinfo=timezone.utc))
+ schedule3 = Schedule(id="s3", task_id="task1", trigger=trigger)
+ return [schedule1, schedule2, schedule3]
-@pytest.fixture
-def mongodb_store() -> DataStore:
- from pymongo import MongoClient
+@asynccontextmanager
+async def capture_events(
+ datastore: DataStore,
+ limit: int,
+ event_types: set[type[Event]] | None = None,
+) -> AsyncGenerator[list[Event], None]:
+ def listener(event: Event) -> None:
+ events.append(event)
+ if len(events) == limit:
+ limit_event.set()
+ subscription.unsubscribe()
+
+ events: list[Event] = []
+ limit_event = anyio.Event()
+ subscription = datastore._event_broker.subscribe(listener, event_types)
+ yield events
+ if limit:
+ with anyio.fail_after(3):
+ await limit_event.wait()
+
+
+async def test_add_replace_task(datastore: DataStore) -> None:
+ import math
+
+ event_types = {TaskAdded, TaskUpdated}
+ async with capture_events(datastore, 3, event_types) as events:
+ await datastore.add_task(Task(id="test_task", func=print, executor="async"))
+ await datastore.add_task(
+ Task(id="test_task2", func=math.ceil, executor="async")
+ )
+ await datastore.add_task(Task(id="test_task", func=repr, executor="async"))
- from apscheduler.datastores.mongodb import MongoDBDataStore
+ tasks = await datastore.get_tasks()
+ assert len(tasks) == 2
+ assert tasks[0].id == "test_task"
+ assert tasks[0].func is repr
+ assert tasks[1].id == "test_task2"
+ assert tasks[1].func is math.ceil
- with MongoClient(tz_aware=True, serverSelectionTimeoutMS=1000) as client:
- yield MongoDBDataStore(client, start_from_scratch=True)
+ received_event = events.pop(0)
+ assert isinstance(received_event, TaskAdded)
+ assert received_event.task_id == "test_task"
+ received_event = events.pop(0)
+ assert isinstance(received_event, TaskAdded)
+ assert received_event.task_id == "test_task2"
-@pytest.fixture
-def sqlite_store() -> DataStore:
- from sqlalchemy.future import create_engine
+ received_event = events.pop(0)
+ assert isinstance(received_event, TaskUpdated)
+ assert received_event.task_id == "test_task"
- from apscheduler.datastores.sqlalchemy import SQLAlchemyDataStore
+ assert not events
- with TemporaryDirectory("sqlite_") as tempdir:
- engine = create_engine(f"sqlite:///{tempdir}/test.db")
- try:
- yield SQLAlchemyDataStore(engine)
- finally:
- engine.dispose()
+async def test_add_schedules(datastore: DataStore, schedules: list[Schedule]) -> None:
+ async with capture_events(datastore, 3, {ScheduleAdded}) as events:
+ for schedule in schedules:
+ await datastore.add_schedule(schedule, ConflictPolicy.exception)
-@pytest.fixture
-def psycopg2_store() -> DataStore:
- from sqlalchemy.future import create_engine
+ assert await datastore.get_schedules() == schedules
+ assert await datastore.get_schedules({"s1", "s2", "s3"}) == schedules
+ assert await datastore.get_schedules({"s1"}) == [schedules[0]]
+ assert await datastore.get_schedules({"s2"}) == [schedules[1]]
+ assert await datastore.get_schedules({"s3"}) == [schedules[2]]
- from apscheduler.datastores.sqlalchemy import SQLAlchemyDataStore
+ for event, schedule in zip(events, schedules):
+ assert event.schedule_id == schedule.id
+ assert event.next_fire_time == schedule.next_fire_time
- engine = create_engine("postgresql+psycopg2://postgres:secret@localhost/testdb")
- try:
- yield SQLAlchemyDataStore(engine, schema="alter", start_from_scratch=True)
- finally:
- engine.dispose()
+async def test_replace_schedules(
+ datastore: DataStore, schedules: list[Schedule]
+) -> None:
+ async with capture_events(datastore, 1, {ScheduleUpdated}) as events:
+ for schedule in schedules:
+ await datastore.add_schedule(schedule, ConflictPolicy.exception)
-@pytest.fixture
-def mysql_store() -> DataStore:
- from sqlalchemy.future import create_engine
+ next_fire_time = schedules[2].trigger.next()
+ schedule = Schedule(
+ id="s3",
+ task_id="foo",
+ trigger=schedules[2].trigger,
+ args=(),
+ kwargs={},
+ coalesce=CoalescePolicy.earliest,
+ misfire_grace_time=None,
+ tags=frozenset(),
+ )
+ schedule.next_fire_time = next_fire_time
+ await datastore.add_schedule(schedule, ConflictPolicy.replace)
+
+ schedules = await datastore.get_schedules({schedule.id})
+ assert schedules[0].task_id == "foo"
+ assert schedules[0].next_fire_time == next_fire_time
+ assert schedules[0].args == ()
+ assert schedules[0].kwargs == {}
+ assert schedules[0].coalesce is CoalescePolicy.earliest
+ assert schedules[0].misfire_grace_time is None
+ assert schedules[0].tags == frozenset()
+
+ received_event = events.pop(0)
+ assert received_event.schedule_id == "s3"
+ assert received_event.next_fire_time == datetime(2020, 9, 15, tzinfo=timezone.utc)
+ assert not events
+
+
+async def test_remove_schedules(
+ datastore: DataStore, schedules: list[Schedule]
+) -> None:
+ async with capture_events(datastore, 2, {ScheduleRemoved}) as events:
+ for schedule in schedules:
+ await datastore.add_schedule(schedule, ConflictPolicy.exception)
- from apscheduler.datastores.sqlalchemy import SQLAlchemyDataStore
+ await datastore.remove_schedules(["s1", "s2"])
+ assert await datastore.get_schedules() == [schedules[2]]
- engine = create_engine("mysql+pymysql://root:secret@localhost/testdb")
- try:
- yield SQLAlchemyDataStore(engine, start_from_scratch=True)
- finally:
- engine.dispose()
+ received_event = events.pop(0)
+ assert received_event.schedule_id == "s1"
+ received_event = events.pop(0)
+ assert received_event.schedule_id == "s2"
-@pytest.fixture
-async def asyncpg_store() -> AsyncDataStore:
- from sqlalchemy.ext.asyncio import create_async_engine
+ assert not events
- from apscheduler.datastores.async_sqlalchemy import AsyncSQLAlchemyDataStore
- engine = create_async_engine(
- "postgresql+asyncpg://postgres:secret@localhost/testdb", future=True
+@pytest.mark.freeze_time(datetime(2020, 9, 14, tzinfo=timezone.utc))
+async def test_acquire_release_schedules(
+ datastore: DataStore, schedules: list[Schedule]
+) -> None:
+ event_types = {ScheduleRemoved, ScheduleUpdated}
+ async with capture_events(datastore, 2, event_types) as events:
+ for schedule in schedules:
+ await datastore.add_schedule(schedule, ConflictPolicy.exception)
+
+ # The first scheduler gets the first due schedule
+ schedules1 = await datastore.acquire_schedules("dummy-id1", 1)
+ assert len(schedules1) == 1
+ assert schedules1[0].id == "s1"
+
+ # The second scheduler gets the second due schedule
+ schedules2 = await datastore.acquire_schedules("dummy-id2", 1)
+ assert len(schedules2) == 1
+ assert schedules2[0].id == "s2"
+
+ # The third scheduler gets nothing
+ schedules3 = await datastore.acquire_schedules("dummy-id3", 1)
+ assert not schedules3
+
+ # Update the schedules and check that the job store actually deletes the
+ # first one and updates the second one
+ schedules1[0].next_fire_time = None
+ schedules2[0].next_fire_time = datetime(2020, 9, 15, tzinfo=timezone.utc)
+
+ # Release all the schedules
+ await datastore.release_schedules("dummy-id1", schedules1)
+ await datastore.release_schedules("dummy-id2", schedules2)
+
+ # Check that the first schedule is gone
+ schedules = await datastore.get_schedules()
+ assert len(schedules) == 2
+ assert schedules[0].id == "s2"
+ assert schedules[1].id == "s3"
+
+ # Check for the appropriate update and delete events
+ received_event = events.pop(0)
+ assert isinstance(received_event, ScheduleRemoved)
+ assert received_event.schedule_id == "s1"
+
+ received_event = events.pop(0)
+ assert isinstance(received_event, ScheduleUpdated)
+ assert received_event.schedule_id == "s2"
+ assert received_event.next_fire_time == datetime(2020, 9, 15, tzinfo=timezone.utc)
+
+ assert not events
+
+
+async def test_release_schedule_two_identical_fire_times(datastore: DataStore) -> None:
+ """Regression test for #616."""
+ for i in range(1, 3):
+ trigger = DateTrigger(datetime(2020, 9, 13, tzinfo=timezone.utc))
+ schedule = Schedule(id=f"s{i}", task_id="task1", trigger=trigger)
+ schedule.next_fire_time = trigger.next()
+ await datastore.add_schedule(schedule, ConflictPolicy.exception)
+
+ schedules = await datastore.acquire_schedules("foo", 3)
+ schedules[0].next_fire_time = None
+ await datastore.release_schedules("foo", schedules)
+
+ remaining = await datastore.get_schedules({s.id for s in schedules})
+ assert len(remaining) == 1
+ assert remaining[0].id == schedules[1].id
+
+
+async def test_release_two_schedules_at_once(datastore: DataStore) -> None:
+ """Regression test for #621."""
+ for i in range(2):
+ trigger = DateTrigger(datetime(2020, 9, 13, tzinfo=timezone.utc))
+ schedule = Schedule(id=f"s{i}", task_id="task1", trigger=trigger)
+ schedule.next_fire_time = trigger.next()
+ await datastore.add_schedule(schedule, ConflictPolicy.exception)
+
+ schedules = await datastore.acquire_schedules("foo", 3)
+ await datastore.release_schedules("foo", schedules)
+
+ remaining = await datastore.get_schedules({s.id for s in schedules})
+ assert len(remaining) == 2
+
+
+async def test_acquire_schedules_lock_timeout(
+ datastore: DataStore, schedules: list[Schedule], freezer
+) -> None:
+ """
+ Test that a scheduler can acquire schedules that were acquired by another
+ scheduler but not released within the lock timeout period.
+
+ """
+ await datastore.add_schedule(schedules[0], ConflictPolicy.exception)
+
+ # First, one scheduler acquires the first available schedule
+ acquired1 = await datastore.acquire_schedules("dummy-id1", 1)
+ assert len(acquired1) == 1
+ assert acquired1[0].id == "s1"
+
+ # Try to acquire the schedule just at the threshold (now == acquired_until).
+ # This should not yield any schedules.
+ freezer.tick(30)
+ acquired2 = await datastore.acquire_schedules("dummy-id2", 1)
+ assert not acquired2
+
+ # Right after that, the schedule should be available
+ freezer.tick(1)
+ acquired3 = await datastore.acquire_schedules("dummy-id2", 1)
+ assert len(acquired3) == 1
+ assert acquired3[0].id == "s1"
+
+
+async def test_acquire_multiple_workers(datastore: DataStore) -> None:
+ await datastore.add_task(
+ Task(id="task1", func=asynccontextmanager, executor="async")
)
- try:
- yield AsyncSQLAlchemyDataStore(engine, start_from_scratch=True)
- finally:
- await engine.dispose()
+ jobs = [Job(task_id="task1") for _ in range(2)]
+ for job in jobs:
+ await datastore.add_job(job)
+ # The first worker gets the first job in the queue
+ jobs1 = await datastore.acquire_jobs("worker1", 1)
+ assert len(jobs1) == 1
+ assert jobs1[0].id == jobs[0].id
-@pytest.fixture
-async def asyncmy_store() -> AsyncDataStore:
- from sqlalchemy.ext.asyncio import create_async_engine
+ # The second worker gets the second job
+ jobs2 = await datastore.acquire_jobs("worker2", 1)
+ assert len(jobs2) == 1
+ assert jobs2[0].id == jobs[1].id
+
+ # The third worker gets nothing
+ jobs3 = await datastore.acquire_jobs("worker3", 1)
+ assert not jobs3
- from apscheduler.datastores.async_sqlalchemy import AsyncSQLAlchemyDataStore
- engine = create_async_engine(
- "mysql+asyncmy://root:secret@localhost/testdb?charset=utf8mb4", future=True
+async def test_job_release_success(datastore: DataStore) -> None:
+ await datastore.add_task(
+ Task(id="task1", func=asynccontextmanager, executor="async")
)
- try:
- yield AsyncSQLAlchemyDataStore(engine, start_from_scratch=True)
- finally:
- await engine.dispose()
+ job = Job(task_id="task1", result_expiration_time=timedelta(minutes=1))
+ await datastore.add_job(job)
+
+ acquired = await datastore.acquire_jobs("worker_id", 2)
+ assert len(acquired) == 1
+ assert acquired[0].id == job.id
+
+ await datastore.release_job(
+ "worker_id",
+ acquired[0].task_id,
+ JobResult.from_job(
+ acquired[0],
+ JobOutcome.success,
+ return_value="foo",
+ ),
+ )
+ result = await datastore.get_job_result(acquired[0].id)
+ assert result.outcome is JobOutcome.success
+ assert result.exception is None
+ assert result.return_value == "foo"
+ # Check that the job and its result are gone
+ assert not await datastore.get_jobs({acquired[0].id})
+ assert not await datastore.get_job_result(acquired[0].id)
-@pytest.fixture
-def schedules() -> list[Schedule]:
- trigger = DateTrigger(datetime(2020, 9, 13, tzinfo=timezone.utc))
- schedule1 = Schedule(id="s1", task_id="task1", trigger=trigger)
- schedule1.next_fire_time = trigger.next()
- trigger = DateTrigger(datetime(2020, 9, 14, tzinfo=timezone.utc))
- schedule2 = Schedule(id="s2", task_id="task2", trigger=trigger)
- schedule2.next_fire_time = trigger.next()
+async def test_job_release_failure(datastore: DataStore) -> None:
+ await datastore.add_task(
+ Task(id="task1", executor="async", func=asynccontextmanager)
+ )
+ job = Job(task_id="task1", result_expiration_time=timedelta(minutes=1))
+ await datastore.add_job(job)
+
+ acquired = await datastore.acquire_jobs("worker_id", 2)
+ assert len(acquired) == 1
+ assert acquired[0].id == job.id
+
+ await datastore.release_job(
+ "worker_id",
+ acquired[0].task_id,
+ JobResult.from_job(
+ acquired[0],
+ JobOutcome.error,
+ exception=ValueError("foo"),
+ ),
+ )
+ result = await datastore.get_job_result(acquired[0].id)
+ assert result.outcome is JobOutcome.error
+ assert isinstance(result.exception, ValueError)
+ assert result.exception.args == ("foo",)
+ assert result.return_value is None
- trigger = DateTrigger(datetime(2020, 9, 15, tzinfo=timezone.utc))
- schedule3 = Schedule(id="s3", task_id="task1", trigger=trigger)
- return [schedule1, schedule2, schedule3]
+ # Check that the job and its result are gone
+ assert not await datastore.get_jobs({acquired[0].id})
+ assert not await datastore.get_job_result(acquired[0].id)
-class TestDataStores:
- @contextmanager
- def capture_events(
- self,
- datastore: DataStore,
- limit: int,
- event_types: set[type[Event]] | None = None,
- ) -> Generator[list[Event], None, None]:
- def listener(event: Event) -> None:
- events.append(event)
- if len(events) == limit:
- limit_event.set()
- subscription.unsubscribe()
-
- events: list[Event] = []
- limit_event = threading.Event()
- subscription = datastore.events.subscribe(listener, event_types)
- yield events
- if limit:
- limit_event.wait(2)
-
- @pytest.fixture
- def event_broker(self) -> Generator[EventBroker, Any, None]:
- broker = LocalEventBroker()
- broker.start()
- yield broker
- broker.stop()
-
- @pytest.fixture(
- params=[
- pytest.param(lazy_fixture("memory_store"), id="memory"),
- pytest.param(lazy_fixture("sqlite_store"), id="sqlite"),
- pytest.param(
- lazy_fixture("mongodb_store"),
- id="mongodb",
- marks=[pytest.mark.external_service],
- ),
- pytest.param(
- lazy_fixture("psycopg2_store"),
- id="psycopg2",
- marks=[pytest.mark.external_service],
- ),
- pytest.param(
- lazy_fixture("mysql_store"),
- id="mysql",
- marks=[pytest.mark.external_service],
- ),
- ]
+async def test_job_release_missed_deadline(datastore: DataStore):
+ await datastore.add_task(
+ Task(id="task1", func=asynccontextmanager, executor="async")
)
- def datastore(
- self, request: SubRequest, event_broker: EventBroker
- ) -> Generator[DataStore, Any, None]:
- datastore = cast(DataStore, request.param)
- datastore.start(event_broker)
- yield datastore
- datastore.stop()
-
- def test_add_replace_task(self, datastore: DataStore) -> None:
- import math
-
- event_types = {TaskAdded, TaskUpdated}
- with self.capture_events(datastore, 3, event_types) as events:
- datastore.add_task(Task(id="test_task", func=print))
- datastore.add_task(Task(id="test_task2", func=math.ceil))
- datastore.add_task(Task(id="test_task", func=repr))
-
- tasks = datastore.get_tasks()
- assert len(tasks) == 2
- assert tasks[0].id == "test_task"
- assert tasks[0].func is repr
- assert tasks[1].id == "test_task2"
- assert tasks[1].func is math.ceil
-
- received_event = events.pop(0)
- assert isinstance(received_event, TaskAdded)
- assert received_event.task_id == "test_task"
-
- received_event = events.pop(0)
- assert isinstance(received_event, TaskAdded)
- assert received_event.task_id == "test_task2"
-
- received_event = events.pop(0)
- assert isinstance(received_event, TaskUpdated)
- assert received_event.task_id == "test_task"
-
- assert not events
-
- def test_add_schedules(
- self, datastore: DataStore, schedules: list[Schedule]
- ) -> None:
- with self.capture_events(datastore, 3, {ScheduleAdded}) as events:
- for schedule in schedules:
- datastore.add_schedule(schedule, ConflictPolicy.exception)
-
- assert datastore.get_schedules() == schedules
- assert datastore.get_schedules({"s1", "s2", "s3"}) == schedules
- assert datastore.get_schedules({"s1"}) == [schedules[0]]
- assert datastore.get_schedules({"s2"}) == [schedules[1]]
- assert datastore.get_schedules({"s3"}) == [schedules[2]]
-
- for event, schedule in zip(events, schedules):
- assert event.schedule_id == schedule.id
- assert event.next_fire_time == schedule.next_fire_time
-
- def test_replace_schedules(
- self, datastore: DataStore, schedules: list[Schedule]
- ) -> None:
- with self.capture_events(datastore, 1, {ScheduleUpdated}) as events:
- for schedule in schedules:
- datastore.add_schedule(schedule, ConflictPolicy.exception)
-
- next_fire_time = schedules[2].trigger.next()
- schedule = Schedule(
- id="s3",
- task_id="foo",
- trigger=schedules[2].trigger,
- args=(),
- kwargs={},
- coalesce=CoalescePolicy.earliest,
- misfire_grace_time=None,
- tags=frozenset(),
- )
- schedule.next_fire_time = next_fire_time
- datastore.add_schedule(schedule, ConflictPolicy.replace)
-
- schedules = datastore.get_schedules({schedule.id})
- assert schedules[0].task_id == "foo"
- assert schedules[0].next_fire_time == next_fire_time
- assert schedules[0].args == ()
- assert schedules[0].kwargs == {}
- assert schedules[0].coalesce is CoalescePolicy.earliest
- assert schedules[0].misfire_grace_time is None
- assert schedules[0].tags == frozenset()
-
- received_event = events.pop(0)
- assert received_event.schedule_id == "s3"
- assert received_event.next_fire_time == datetime(
- 2020, 9, 15, tzinfo=timezone.utc
- )
- assert not events
-
- def test_remove_schedules(
- self, datastore: DataStore, schedules: list[Schedule]
- ) -> None:
- with self.capture_events(datastore, 2, {ScheduleRemoved}) as events:
- for schedule in schedules:
- datastore.add_schedule(schedule, ConflictPolicy.exception)
-
- datastore.remove_schedules(["s1", "s2"])
- assert datastore.get_schedules() == [schedules[2]]
-
- received_event = events.pop(0)
- assert received_event.schedule_id == "s1"
-
- received_event = events.pop(0)
- assert received_event.schedule_id == "s2"
-
- assert not events
-
- @pytest.mark.freeze_time(datetime(2020, 9, 14, tzinfo=timezone.utc))
- def test_acquire_release_schedules(
- self, datastore: DataStore, schedules: list[Schedule]
- ) -> None:
- event_types = {ScheduleRemoved, ScheduleUpdated}
- with self.capture_events(datastore, 2, event_types) as events:
- for schedule in schedules:
- datastore.add_schedule(schedule, ConflictPolicy.exception)
-
- # The first scheduler gets the first due schedule
- schedules1 = datastore.acquire_schedules("dummy-id1", 1)
- assert len(schedules1) == 1
- assert schedules1[0].id == "s1"
-
- # The second scheduler gets the second due schedule
- schedules2 = datastore.acquire_schedules("dummy-id2", 1)
- assert len(schedules2) == 1
- assert schedules2[0].id == "s2"
-
- # The third scheduler gets nothing
- schedules3 = datastore.acquire_schedules("dummy-id3", 1)
- assert not schedules3
-
- # Update the schedules and check that the job store actually deletes the
- # first one and updates the second one
- schedules1[0].next_fire_time = None
- schedules2[0].next_fire_time = datetime(2020, 9, 15, tzinfo=timezone.utc)
-
- # Release all the schedules
- datastore.release_schedules("dummy-id1", schedules1)
- datastore.release_schedules("dummy-id2", schedules2)
-
- # Check that the first schedule is gone
- schedules = datastore.get_schedules()
- assert len(schedules) == 2
- assert schedules[0].id == "s2"
- assert schedules[1].id == "s3"
-
- # Check for the appropriate update and delete events
- received_event = events.pop(0)
- assert isinstance(received_event, ScheduleRemoved)
- assert received_event.schedule_id == "s1"
-
- received_event = events.pop(0)
- assert isinstance(received_event, ScheduleUpdated)
- assert received_event.schedule_id == "s2"
- assert received_event.next_fire_time == datetime(
- 2020, 9, 15, tzinfo=timezone.utc
- )
+ job = Job(task_id="task1", result_expiration_time=timedelta(minutes=1))
+ await datastore.add_job(job)
+
+ acquired = await datastore.acquire_jobs("worker_id", 2)
+ assert len(acquired) == 1
+ assert acquired[0].id == job.id
+
+ await datastore.release_job(
+ "worker_id",
+ acquired[0].task_id,
+ JobResult.from_job(
+ acquired[0],
+ JobOutcome.missed_start_deadline,
+ ),
+ )
+ result = await datastore.get_job_result(acquired[0].id)
+ assert result.outcome is JobOutcome.missed_start_deadline
+ assert result.exception is None
+ assert result.return_value is None
- assert not events
-
- def test_release_schedule_two_identical_fire_times(
- self, datastore: DataStore
- ) -> None:
- """Regression test for #616."""
- for i in range(1, 3):
- trigger = DateTrigger(datetime(2020, 9, 13, tzinfo=timezone.utc))
- schedule = Schedule(id=f"s{i}", task_id="task1", trigger=trigger)
- schedule.next_fire_time = trigger.next()
- datastore.add_schedule(schedule, ConflictPolicy.exception)
-
- schedules = datastore.acquire_schedules("foo", 3)
- schedules[0].next_fire_time = None
- datastore.release_schedules("foo", schedules)
-
- remaining = datastore.get_schedules({s.id for s in schedules})
- assert len(remaining) == 1
- assert remaining[0].id == schedules[1].id
-
- def test_release_two_schedules_at_once(self, datastore: DataStore) -> None:
- """Regression test for #621."""
- for i in range(2):
- trigger = DateTrigger(datetime(2020, 9, 13, tzinfo=timezone.utc))
- schedule = Schedule(id=f"s{i}", task_id="task1", trigger=trigger)
- schedule.next_fire_time = trigger.next()
- datastore.add_schedule(schedule, ConflictPolicy.exception)
-
- schedules = datastore.acquire_schedules("foo", 3)
- datastore.release_schedules("foo", schedules)
-
- remaining = datastore.get_schedules({s.id for s in schedules})
- assert len(remaining) == 2
-
- def test_acquire_schedules_lock_timeout(
- self, datastore: DataStore, schedules: list[Schedule], freezer
- ) -> None:
- """
- Test that a scheduler can acquire schedules that were acquired by another
- scheduler but not released within the lock timeout period.
-
- """
- datastore.add_schedule(schedules[0], ConflictPolicy.exception)
-
- # First, one scheduler acquires the first available schedule
- acquired1 = datastore.acquire_schedules("dummy-id1", 1)
- assert len(acquired1) == 1
- assert acquired1[0].id == "s1"
-
- # Try to acquire the schedule just at the threshold (now == acquired_until).
- # This should not yield any schedules.
- freezer.tick(30)
- acquired2 = datastore.acquire_schedules("dummy-id2", 1)
- assert not acquired2
-
- # Right after that, the schedule should be available
- freezer.tick(1)
- acquired3 = datastore.acquire_schedules("dummy-id2", 1)
- assert len(acquired3) == 1
- assert acquired3[0].id == "s1"
-
- def test_acquire_multiple_workers(self, datastore: DataStore) -> None:
- datastore.add_task(Task(id="task1", func=asynccontextmanager))
- jobs = [Job(task_id="task1") for _ in range(2)]
- for job in jobs:
- datastore.add_job(job)
-
- # The first worker gets the first job in the queue
- jobs1 = datastore.acquire_jobs("worker1", 1)
- assert len(jobs1) == 1
- assert jobs1[0].id == jobs[0].id
-
- # The second worker gets the second job
- jobs2 = datastore.acquire_jobs("worker2", 1)
- assert len(jobs2) == 1
- assert jobs2[0].id == jobs[1].id
-
- # The third worker gets nothing
- jobs3 = datastore.acquire_jobs("worker3", 1)
- assert not jobs3
-
- def test_job_release_success(self, datastore: DataStore) -> None:
- datastore.add_task(Task(id="task1", func=asynccontextmanager))
- job = Job(task_id="task1", result_expiration_time=timedelta(minutes=1))
- datastore.add_job(job)
-
- acquired = datastore.acquire_jobs("worker_id", 2)
- assert len(acquired) == 1
- assert acquired[0].id == job.id
-
- datastore.release_job(
- "worker_id",
- acquired[0].task_id,
- JobResult.from_job(
- acquired[0],
- JobOutcome.success,
- return_value="foo",
- ),
- )
- result = datastore.get_job_result(acquired[0].id)
- assert result.outcome is JobOutcome.success
- assert result.exception is None
- assert result.return_value == "foo"
-
- # Check that the job and its result are gone
- assert not datastore.get_jobs({acquired[0].id})
- assert not datastore.get_job_result(acquired[0].id)
-
- def test_job_release_failure(self, datastore: DataStore) -> None:
- datastore.add_task(Task(id="task1", func=asynccontextmanager))
- job = Job(task_id="task1", result_expiration_time=timedelta(minutes=1))
- datastore.add_job(job)
-
- acquired = datastore.acquire_jobs("worker_id", 2)
- assert len(acquired) == 1
- assert acquired[0].id == job.id
-
- datastore.release_job(
- "worker_id",
- acquired[0].task_id,
- JobResult.from_job(
- acquired[0],
- JobOutcome.error,
- exception=ValueError("foo"),
- ),
- )
- result = datastore.get_job_result(acquired[0].id)
- assert result.outcome is JobOutcome.error
- assert isinstance(result.exception, ValueError)
- assert result.exception.args == ("foo",)
- assert result.return_value is None
-
- # Check that the job and its result are gone
- assert not datastore.get_jobs({acquired[0].id})
- assert not datastore.get_job_result(acquired[0].id)
-
- def test_job_release_missed_deadline(self, datastore: DataStore):
- datastore.add_task(Task(id="task1", func=asynccontextmanager))
- job = Job(task_id="task1", result_expiration_time=timedelta(minutes=1))
- datastore.add_job(job)
-
- acquired = datastore.acquire_jobs("worker_id", 2)
- assert len(acquired) == 1
- assert acquired[0].id == job.id
-
- datastore.release_job(
- "worker_id",
- acquired[0].task_id,
- JobResult.from_job(
- acquired[0],
- JobOutcome.missed_start_deadline,
- ),
- )
- result = datastore.get_job_result(acquired[0].id)
- assert result.outcome is JobOutcome.missed_start_deadline
- assert result.exception is None
- assert result.return_value is None
-
- # Check that the job and its result are gone
- assert not datastore.get_jobs({acquired[0].id})
- assert not datastore.get_job_result(acquired[0].id)
-
- def test_job_release_cancelled(self, datastore: DataStore) -> None:
- datastore.add_task(Task(id="task1", func=asynccontextmanager))
- job = Job(task_id="task1", result_expiration_time=timedelta(minutes=1))
- datastore.add_job(job)
-
- acquired = datastore.acquire_jobs("worker1", 2)
- assert len(acquired) == 1
- assert acquired[0].id == job.id
-
- datastore.release_job(
- "worker1",
- acquired[0].task_id,
- JobResult.from_job(
- acquired[0],
- JobOutcome.cancelled,
- ),
- )
- result = datastore.get_job_result(acquired[0].id)
- assert result.outcome is JobOutcome.cancelled
- assert result.exception is None
- assert result.return_value is None
-
- # Check that the job and its result are gone
- assert not datastore.get_jobs({acquired[0].id})
- assert not datastore.get_job_result(acquired[0].id)
-
- def test_acquire_jobs_lock_timeout(
- self, datastore: DataStore, freezer: FrozenDateTimeFactory
- ) -> None:
- """
- Test that a worker can acquire jobs that were acquired by another scheduler but
- not released within the lock timeout period.
-
- """
- datastore.add_task(Task(id="task1", func=asynccontextmanager))
- job = Job(task_id="task1")
- datastore.add_job(job)
-
- # First, one worker acquires the first available job
- acquired = datastore.acquire_jobs("worker1", 1)
- assert len(acquired) == 1
- assert acquired[0].id == job.id
-
- # Try to acquire the job just at the threshold (now == acquired_until).
- # This should not yield any jobs.
- freezer.tick(30)
- assert not datastore.acquire_jobs("worker2", 1)
-
- # Right after that, the job should be available
- freezer.tick(1)
- acquired = datastore.acquire_jobs("worker2", 1)
- assert len(acquired) == 1
- assert acquired[0].id == job.id
-
- def test_acquire_jobs_max_number_exceeded(self, datastore: DataStore) -> None:
- datastore.add_task(
- Task(id="task1", func=asynccontextmanager, max_running_jobs=2)
- )
- jobs = [Job(task_id="task1"), Job(task_id="task1"), Job(task_id="task1")]
- for job in jobs:
- datastore.add_job(job)
-
- # Check that only 2 jobs are returned from acquire_jobs() even though the limit
- # wqas 3
- acquired_jobs = datastore.acquire_jobs("worker1", 3)
- assert [job.id for job in acquired_jobs] == [job.id for job in jobs[:2]]
-
- # Release one job, and the worker should be able to acquire the third job
- datastore.release_job(
- "worker1",
- acquired_jobs[0].task_id,
- JobResult.from_job(
- acquired_jobs[0],
- JobOutcome.success,
- return_value=None,
- ),
- )
- acquired_jobs = datastore.acquire_jobs("worker1", 3)
- assert [job.id for job in acquired_jobs] == [jobs[2].id]
-
- def test_add_get_task(self, datastore: DataStore) -> None:
- with pytest.raises(TaskLookupError):
- datastore.get_task("dummyid")
-
- datastore.add_task(Task(id="dummyid", func=asynccontextmanager))
- task = datastore.get_task("dummyid")
- assert task.id == "dummyid"
- assert task.func is asynccontextmanager
-
-
-@pytest.mark.anyio
-class TestAsyncDataStores:
- @asynccontextmanager
- async def capture_events(
- self,
- datastore: AsyncDataStore,
- limit: int,
- event_types: set[type[Event]] | None = None,
- ) -> AsyncGenerator[list[Event], None]:
- def listener(event: Event) -> None:
- events.append(event)
- if len(events) == limit:
- limit_event.set()
- subscription.unsubscribe()
-
- events: list[Event] = []
- limit_event = anyio.Event()
- subscription = datastore.events.subscribe(listener, event_types)
- yield events
- if limit:
- with anyio.fail_after(3):
- await limit_event.wait()
-
- @pytest.fixture
- async def event_broker(self) -> AsyncGenerator[AsyncEventBroker, Any]:
- broker = LocalAsyncEventBroker()
- await broker.start()
- yield broker
- await broker.stop()
-
- @pytest.fixture(
- params=[
- pytest.param(lazy_fixture("adapted_memory_store"), id="memory"),
- pytest.param(
- lazy_fixture("asyncpg_store"),
- id="asyncpg",
- marks=[pytest.mark.external_service],
- ),
- pytest.param(
- lazy_fixture("asyncmy_store"),
- id="asyncmy",
- marks=[pytest.mark.external_service],
- ),
- ]
+ # Check that the job and its result are gone
+ assert not await datastore.get_jobs({acquired[0].id})
+ assert not await datastore.get_job_result(acquired[0].id)
+
+
+async def test_job_release_cancelled(datastore: DataStore) -> None:
+ await datastore.add_task(
+ Task(id="task1", func=asynccontextmanager, executor="async")
)
- async def raw_datastore(
- self, request: SubRequest, event_broker: AsyncEventBroker
- ) -> AsyncDataStore:
- return cast(AsyncDataStore, request.param)
-
- @pytest.fixture
- async def datastore(
- self, raw_datastore: AsyncDataStore, event_broker: AsyncEventBroker
- ) -> AsyncGenerator[AsyncDataStore, Any]:
- await raw_datastore.start(event_broker)
- yield raw_datastore
- await raw_datastore.stop()
-
- async def test_add_replace_task(self, datastore: AsyncDataStore) -> None:
- import math
-
- event_types = {TaskAdded, TaskUpdated}
- async with self.capture_events(datastore, 3, event_types) as events:
- await datastore.add_task(Task(id="test_task", func=print))
- await datastore.add_task(Task(id="test_task2", func=math.ceil))
- await datastore.add_task(Task(id="test_task", func=repr))
-
- tasks = await datastore.get_tasks()
- assert len(tasks) == 2
- assert tasks[0].id == "test_task"
- assert tasks[0].func is repr
- assert tasks[1].id == "test_task2"
- assert tasks[1].func is math.ceil
-
- received_event = events.pop(0)
- assert isinstance(received_event, TaskAdded)
- assert received_event.task_id == "test_task"
-
- received_event = events.pop(0)
- assert isinstance(received_event, TaskAdded)
- assert received_event.task_id == "test_task2"
-
- received_event = events.pop(0)
- assert isinstance(received_event, TaskUpdated)
- assert received_event.task_id == "test_task"
-
- assert not events
-
- async def test_add_schedules(
- self, datastore: AsyncDataStore, schedules: list[Schedule]
- ) -> None:
- async with self.capture_events(datastore, 3, {ScheduleAdded}) as events:
- for schedule in schedules:
- await datastore.add_schedule(schedule, ConflictPolicy.exception)
-
- assert await datastore.get_schedules() == schedules
- assert await datastore.get_schedules({"s1", "s2", "s3"}) == schedules
- assert await datastore.get_schedules({"s1"}) == [schedules[0]]
- assert await datastore.get_schedules({"s2"}) == [schedules[1]]
- assert await datastore.get_schedules({"s3"}) == [schedules[2]]
-
- for event, schedule in zip(events, schedules):
- assert event.schedule_id == schedule.id
- assert event.next_fire_time == schedule.next_fire_time
-
- async def test_replace_schedules(
- self, datastore: AsyncDataStore, schedules: list[Schedule]
- ) -> None:
- async with self.capture_events(datastore, 1, {ScheduleUpdated}) as events:
- for schedule in schedules:
- await datastore.add_schedule(schedule, ConflictPolicy.exception)
-
- next_fire_time = schedules[2].trigger.next()
- schedule = Schedule(
- id="s3",
- task_id="foo",
- trigger=schedules[2].trigger,
- args=(),
- kwargs={},
- coalesce=CoalescePolicy.earliest,
- misfire_grace_time=None,
- tags=frozenset(),
- )
- schedule.next_fire_time = next_fire_time
- await datastore.add_schedule(schedule, ConflictPolicy.replace)
-
- schedules = await datastore.get_schedules({schedule.id})
- assert schedules[0].task_id == "foo"
- assert schedules[0].next_fire_time == next_fire_time
- assert schedules[0].args == ()
- assert schedules[0].kwargs == {}
- assert schedules[0].coalesce is CoalescePolicy.earliest
- assert schedules[0].misfire_grace_time is None
- assert schedules[0].tags == frozenset()
-
- received_event = events.pop(0)
- assert received_event.schedule_id == "s3"
- assert received_event.next_fire_time == datetime(
- 2020, 9, 15, tzinfo=timezone.utc
- )
- assert not events
-
- async def test_remove_schedules(
- self, datastore: AsyncDataStore, schedules: list[Schedule]
- ) -> None:
- async with self.capture_events(datastore, 2, {ScheduleRemoved}) as events:
- for schedule in schedules:
- await datastore.add_schedule(schedule, ConflictPolicy.exception)
-
- await datastore.remove_schedules(["s1", "s2"])
- assert await datastore.get_schedules() == [schedules[2]]
-
- received_event = events.pop(0)
- assert received_event.schedule_id == "s1"
-
- received_event = events.pop(0)
- assert received_event.schedule_id == "s2"
-
- assert not events
-
- @pytest.mark.freeze_time(datetime(2020, 9, 14, tzinfo=timezone.utc))
- async def test_acquire_release_schedules(
- self, datastore: AsyncDataStore, schedules: list[Schedule]
- ) -> None:
- event_types = {ScheduleRemoved, ScheduleUpdated}
- async with self.capture_events(datastore, 2, event_types) as events:
- for schedule in schedules:
- await datastore.add_schedule(schedule, ConflictPolicy.exception)
-
- # The first scheduler gets the first due schedule
- schedules1 = await datastore.acquire_schedules("dummy-id1", 1)
- assert len(schedules1) == 1
- assert schedules1[0].id == "s1"
-
- # The second scheduler gets the second due schedule
- schedules2 = await datastore.acquire_schedules("dummy-id2", 1)
- assert len(schedules2) == 1
- assert schedules2[0].id == "s2"
-
- # The third scheduler gets nothing
- schedules3 = await datastore.acquire_schedules("dummy-id3", 1)
- assert not schedules3
-
- # Update the schedules and check that the job store actually deletes the
- # first one and updates the second one
- schedules1[0].next_fire_time = None
- schedules2[0].next_fire_time = datetime(2020, 9, 15, tzinfo=timezone.utc)
-
- # Release all the schedules
- await datastore.release_schedules("dummy-id1", schedules1)
- await datastore.release_schedules("dummy-id2", schedules2)
-
- # Check that the first schedule is gone
- schedules = await datastore.get_schedules()
- assert len(schedules) == 2
- assert schedules[0].id == "s2"
- assert schedules[1].id == "s3"
-
- # Check for the appropriate update and delete events
- received_event = events.pop(0)
- assert isinstance(received_event, ScheduleRemoved)
- assert received_event.schedule_id == "s1"
-
- received_event = events.pop(0)
- assert isinstance(received_event, ScheduleUpdated)
- assert received_event.schedule_id == "s2"
- assert received_event.next_fire_time == datetime(
- 2020, 9, 15, tzinfo=timezone.utc
- )
+ job = Job(task_id="task1", result_expiration_time=timedelta(minutes=1))
+ await datastore.add_job(job)
- assert not events
+ acquired = await datastore.acquire_jobs("worker1", 2)
+ assert len(acquired) == 1
+ assert acquired[0].id == job.id
- async def test_release_schedule_two_identical_fire_times(
- self, datastore: AsyncDataStore
- ) -> None:
- """Regression test for #616."""
- for i in range(1, 3):
- trigger = DateTrigger(datetime(2020, 9, 13, tzinfo=timezone.utc))
- schedule = Schedule(id=f"s{i}", task_id="task1", trigger=trigger)
- schedule.next_fire_time = trigger.next()
- await datastore.add_schedule(schedule, ConflictPolicy.exception)
+ await datastore.release_job(
+ "worker1",
+ acquired[0].task_id,
+ JobResult.from_job(acquired[0], JobOutcome.cancelled),
+ )
+ result = await datastore.get_job_result(acquired[0].id)
+ assert result.outcome is JobOutcome.cancelled
+ assert result.exception is None
+ assert result.return_value is None
+
+ # Check that the job and its result are gone
+ assert not await datastore.get_jobs({acquired[0].id})
+ assert not await datastore.get_job_result(acquired[0].id)
+
+
+async def test_acquire_jobs_lock_timeout(
+ datastore: DataStore, freezer: FrozenDateTimeFactory
+) -> None:
+ """
+ Test that a worker can acquire jobs that were acquired by another scheduler but
+ not released within the lock timeout period.
+
+ """
+ await datastore.add_task(
+ Task(id="task1", func=asynccontextmanager, executor="async")
+ )
+ job = Job(task_id="task1", result_expiration_time=timedelta(minutes=1))
+ await datastore.add_job(job)
- schedules = await datastore.acquire_schedules("foo", 3)
- schedules[0].next_fire_time = None
- await datastore.release_schedules("foo", schedules)
-
- remaining = await datastore.get_schedules({s.id for s in schedules})
- assert len(remaining) == 1
- assert remaining[0].id == schedules[1].id
-
- async def test_release_two_schedules_at_once(
- self, datastore: AsyncDataStore
- ) -> None:
- """Regression test for #621."""
- for i in range(2):
- trigger = DateTrigger(datetime(2020, 9, 13, tzinfo=timezone.utc))
- schedule = Schedule(id=f"s{i}", task_id="task1", trigger=trigger)
- schedule.next_fire_time = trigger.next()
- await datastore.add_schedule(schedule, ConflictPolicy.exception)
+ # First, one worker acquires the first available job
+ acquired = await datastore.acquire_jobs("worker1", 1)
+ assert len(acquired) == 1
+ assert acquired[0].id == job.id
- schedules = await datastore.acquire_schedules("foo", 3)
- await datastore.release_schedules("foo", schedules)
-
- remaining = await datastore.get_schedules({s.id for s in schedules})
- assert len(remaining) == 2
-
- async def test_acquire_schedules_lock_timeout(
- self, datastore: AsyncDataStore, schedules: list[Schedule], freezer
- ) -> None:
- """
- Test that a scheduler can acquire schedules that were acquired by another
- scheduler but not released within the lock timeout period.
-
- """
- await datastore.add_schedule(schedules[0], ConflictPolicy.exception)
-
- # First, one scheduler acquires the first available schedule
- acquired1 = await datastore.acquire_schedules("dummy-id1", 1)
- assert len(acquired1) == 1
- assert acquired1[0].id == "s1"
-
- # Try to acquire the schedule just at the threshold (now == acquired_until).
- # This should not yield any schedules.
- freezer.tick(30)
- acquired2 = await datastore.acquire_schedules("dummy-id2", 1)
- assert not acquired2
-
- # Right after that, the schedule should be available
- freezer.tick(1)
- acquired3 = await datastore.acquire_schedules("dummy-id2", 1)
- assert len(acquired3) == 1
- assert acquired3[0].id == "s1"
-
- async def test_acquire_multiple_workers(self, datastore: AsyncDataStore) -> None:
- await datastore.add_task(Task(id="task1", func=asynccontextmanager))
- jobs = [Job(task_id="task1") for _ in range(2)]
- for job in jobs:
- await datastore.add_job(job)
-
- # The first worker gets the first job in the queue
- jobs1 = await datastore.acquire_jobs("worker1", 1)
- assert len(jobs1) == 1
- assert jobs1[0].id == jobs[0].id
-
- # The second worker gets the second job
- jobs2 = await datastore.acquire_jobs("worker2", 1)
- assert len(jobs2) == 1
- assert jobs2[0].id == jobs[1].id
-
- # The third worker gets nothing
- jobs3 = await datastore.acquire_jobs("worker3", 1)
- assert not jobs3
-
- async def test_job_release_success(self, datastore: AsyncDataStore) -> None:
- await datastore.add_task(Task(id="task1", func=asynccontextmanager))
- job = Job(task_id="task1", result_expiration_time=timedelta(minutes=1))
- await datastore.add_job(job)
+ # Try to acquire the job just at the threshold (now == acquired_until).
+ # This should not yield any jobs.
+ freezer.tick(30)
+ assert not await datastore.acquire_jobs("worker2", 1)
- acquired = await datastore.acquire_jobs("worker_id", 2)
- assert len(acquired) == 1
- assert acquired[0].id == job.id
-
- await datastore.release_job(
- "worker_id",
- acquired[0].task_id,
- JobResult.from_job(
- acquired[0],
- JobOutcome.success,
- return_value="foo",
- ),
- )
- result = await datastore.get_job_result(acquired[0].id)
- assert result.outcome is JobOutcome.success
- assert result.exception is None
- assert result.return_value == "foo"
-
- # Check that the job and its result are gone
- assert not await datastore.get_jobs({acquired[0].id})
- assert not await datastore.get_job_result(acquired[0].id)
-
- async def test_job_release_failure(self, datastore: AsyncDataStore) -> None:
- await datastore.add_task(Task(id="task1", func=asynccontextmanager))
- job = Job(task_id="task1", result_expiration_time=timedelta(minutes=1))
- await datastore.add_job(job)
+ # Right after that, the job should be available
+ freezer.tick(1)
+ acquired = await datastore.acquire_jobs("worker2", 1)
+ assert len(acquired) == 1
+ assert acquired[0].id == job.id
- acquired = await datastore.acquire_jobs("worker_id", 2)
- assert len(acquired) == 1
- assert acquired[0].id == job.id
-
- await datastore.release_job(
- "worker_id",
- acquired[0].task_id,
- JobResult.from_job(
- acquired[0],
- JobOutcome.error,
- exception=ValueError("foo"),
- ),
- )
- result = await datastore.get_job_result(acquired[0].id)
- assert result.outcome is JobOutcome.error
- assert isinstance(result.exception, ValueError)
- assert result.exception.args == ("foo",)
- assert result.return_value is None
-
- # Check that the job and its result are gone
- assert not await datastore.get_jobs({acquired[0].id})
- assert not await datastore.get_job_result(acquired[0].id)
-
- async def test_job_release_missed_deadline(self, datastore: AsyncDataStore):
- await datastore.add_task(Task(id="task1", func=asynccontextmanager))
- job = Job(task_id="task1", result_expiration_time=timedelta(minutes=1))
- await datastore.add_job(job)
- acquired = await datastore.acquire_jobs("worker_id", 2)
- assert len(acquired) == 1
- assert acquired[0].id == job.id
-
- await datastore.release_job(
- "worker_id",
- acquired[0].task_id,
- JobResult.from_job(
- acquired[0],
- JobOutcome.missed_start_deadline,
- ),
- )
- result = await datastore.get_job_result(acquired[0].id)
- assert result.outcome is JobOutcome.missed_start_deadline
- assert result.exception is None
- assert result.return_value is None
-
- # Check that the job and its result are gone
- assert not await datastore.get_jobs({acquired[0].id})
- assert not await datastore.get_job_result(acquired[0].id)
-
- async def test_job_release_cancelled(self, datastore: AsyncDataStore) -> None:
- await datastore.add_task(Task(id="task1", func=asynccontextmanager))
- job = Job(task_id="task1", result_expiration_time=timedelta(minutes=1))
+async def test_acquire_jobs_max_number_exceeded(datastore: DataStore) -> None:
+ await datastore.add_task(
+ Task(id="task1", func=asynccontextmanager, executor="async", max_running_jobs=2)
+ )
+ jobs = [Job(task_id="task1"), Job(task_id="task1"), Job(task_id="task1")]
+ for job in jobs:
await datastore.add_job(job)
- acquired = await datastore.acquire_jobs("worker1", 2)
- assert len(acquired) == 1
- assert acquired[0].id == job.id
+ # Check that only 2 jobs are returned from acquire_jobs() even though the limit
+    # was 3
+ acquired_jobs = await datastore.acquire_jobs("worker1", 3)
+ assert [job.id for job in acquired_jobs] == [job.id for job in jobs[:2]]
+
+ # Release one job, and the worker should be able to acquire the third job
+ await datastore.release_job(
+ "worker1",
+ acquired_jobs[0].task_id,
+ JobResult.from_job(
+ acquired_jobs[0],
+ JobOutcome.success,
+ return_value=None,
+ ),
+ )
+ acquired_jobs = await datastore.acquire_jobs("worker1", 3)
+ assert [job.id for job in acquired_jobs] == [jobs[2].id]
- await datastore.release_job(
- "worker1",
- acquired[0].task_id,
- JobResult.from_job(acquired[0], JobOutcome.cancelled),
- )
- result = await datastore.get_job_result(acquired[0].id)
- assert result.outcome is JobOutcome.cancelled
- assert result.exception is None
- assert result.return_value is None
-
- # Check that the job and its result are gone
- assert not await datastore.get_jobs({acquired[0].id})
- assert not await datastore.get_job_result(acquired[0].id)
-
- async def test_acquire_jobs_lock_timeout(
- self, datastore: AsyncDataStore, freezer: FrozenDateTimeFactory
- ) -> None:
- """
- Test that a worker can acquire jobs that were acquired by another scheduler but
- not released within the lock timeout period.
-
- """
- await datastore.add_task(Task(id="task1", func=asynccontextmanager))
- job = Job(task_id="task1", result_expiration_time=timedelta(minutes=1))
- await datastore.add_job(job)
- # First, one worker acquires the first available job
- acquired = await datastore.acquire_jobs("worker1", 1)
- assert len(acquired) == 1
- assert acquired[0].id == job.id
-
- # Try to acquire the job just at the threshold (now == acquired_until).
- # This should not yield any jobs.
- freezer.tick(30)
- assert not await datastore.acquire_jobs("worker2", 1)
-
- # Right after that, the job should be available
- freezer.tick(1)
- acquired = await datastore.acquire_jobs("worker2", 1)
- assert len(acquired) == 1
- assert acquired[0].id == job.id
-
- async def test_acquire_jobs_max_number_exceeded(
- self, datastore: AsyncDataStore
- ) -> None:
- await datastore.add_task(
- Task(id="task1", func=asynccontextmanager, max_running_jobs=2)
- )
- jobs = [Job(task_id="task1"), Job(task_id="task1"), Job(task_id="task1")]
- for job in jobs:
- await datastore.add_job(job)
-
- # Check that only 2 jobs are returned from acquire_jobs() even though the limit
- # wqas 3
- acquired_jobs = await datastore.acquire_jobs("worker1", 3)
- assert [job.id for job in acquired_jobs] == [job.id for job in jobs[:2]]
-
- # Release one job, and the worker should be able to acquire the third job
- await datastore.release_job(
- "worker1",
- acquired_jobs[0].task_id,
- JobResult.from_job(
- acquired_jobs[0],
- JobOutcome.success,
- return_value=None,
- ),
- )
- acquired_jobs = await datastore.acquire_jobs("worker1", 3)
- assert [job.id for job in acquired_jobs] == [jobs[2].id]
-
- async def test_add_get_task(self, datastore: DataStore) -> None:
- with pytest.raises(TaskLookupError):
- await datastore.get_task("dummyid")
-
- await datastore.add_task(Task(id="dummyid", func=asynccontextmanager))
- task = await datastore.get_task("dummyid")
- assert task.id == "dummyid"
- assert task.func is asynccontextmanager
-
- async def test_cancel_start(
- self, raw_datastore: AsyncDataStore, event_broker: AsyncEventBroker
- ) -> None:
- with CancelScope() as scope:
- scope.cancel()
- await raw_datastore.start(event_broker)
- await raw_datastore.stop()
-
- async def test_cancel_stop(
- self, raw_datastore: AsyncDataStore, event_broker: AsyncEventBroker
- ) -> None:
- with CancelScope() as scope:
- await raw_datastore.start(event_broker)
+async def test_add_get_task(datastore: DataStore) -> None:
+ with pytest.raises(TaskLookupError):
+ await datastore.get_task("dummyid")
+
+ await datastore.add_task(
+ Task(id="dummyid", func=asynccontextmanager, executor="async")
+ )
+ task = await datastore.get_task("dummyid")
+ assert task.id == "dummyid"
+ assert task.func is asynccontextmanager
+
+
+async def test_cancel_start(
+ raw_datastore: DataStore, local_broker: EventBroker
+) -> None:
+ with CancelScope() as scope:
+ scope.cancel()
+ async with AsyncExitStack() as exit_stack:
+ await raw_datastore.start(exit_stack, local_broker)
+
+
+async def test_cancel_stop(raw_datastore: DataStore, local_broker: EventBroker) -> None:
+ with CancelScope() as scope:
+ async with AsyncExitStack() as exit_stack:
+ await raw_datastore.start(exit_stack, local_broker)
scope.cancel()
- await raw_datastore.stop()
diff --git a/tests/test_eventbrokers.py b/tests/test_eventbrokers.py
index 09942f5..9b98f60 100644
--- a/tests/test_eventbrokers.py
+++ b/tests/test_eventbrokers.py
@@ -1,286 +1,107 @@
from __future__ import annotations
-from collections.abc import AsyncGenerator, Generator
-from concurrent.futures import Future
+from contextlib import AsyncExitStack
from datetime import datetime, timezone
-from queue import Empty, Queue
-from typing import Any, cast
import pytest
-from _pytest.fixtures import SubRequest
from _pytest.logging import LogCaptureFixture
from anyio import CancelScope, create_memory_object_stream, fail_after
-from pytest_lazyfixture import lazy_fixture
from apscheduler import Event, ScheduleAdded
-from apscheduler.abc import AsyncEventBroker, EventBroker, Serializer
+from apscheduler.abc import EventBroker
+pytestmark = pytest.mark.anyio
-@pytest.fixture
-def local_broker() -> EventBroker:
- from apscheduler.eventbrokers.local import LocalEventBroker
- return LocalEventBroker()
-
-
-@pytest.fixture
-def local_async_broker() -> AsyncEventBroker:
- from apscheduler.eventbrokers.async_local import LocalAsyncEventBroker
-
- return LocalAsyncEventBroker()
-
-
-@pytest.fixture
-def redis_broker(serializer: Serializer) -> EventBroker:
- from apscheduler.eventbrokers.redis import RedisEventBroker
-
- broker = RedisEventBroker.from_url(
- "redis://localhost:6379", serializer=serializer, stop_check_interval=0.05
+async def test_publish_subscribe(event_broker: EventBroker) -> None:
+ send, receive = create_memory_object_stream(2)
+ event_broker.subscribe(send.send)
+ event_broker.subscribe(send.send_nowait)
+ event = ScheduleAdded(
+ schedule_id="schedule1",
+ next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc),
)
- return broker
-
-
-@pytest.fixture
-async def async_redis_broker(serializer: Serializer) -> AsyncEventBroker:
- from apscheduler.eventbrokers.async_redis import AsyncRedisEventBroker
-
- broker = AsyncRedisEventBroker.from_url(
- "redis://localhost:6379", serializer=serializer, stop_check_interval=0.05
+ await event_broker.publish(event)
+
+ with fail_after(3):
+ event1 = await receive.receive()
+ event2 = await receive.receive()
+
+ assert event1 == event2
+ assert isinstance(event1, ScheduleAdded)
+ assert isinstance(event1.timestamp, datetime)
+ assert event1.schedule_id == "schedule1"
+ assert event1.next_fire_time == datetime(
+ 2021, 9, 11, 12, 31, 56, 254867, timezone.utc
)
- return broker
-@pytest.fixture
-def mqtt_broker(serializer: Serializer) -> EventBroker:
- from paho.mqtt.client import Client
-
- from apscheduler.eventbrokers.mqtt import MQTTEventBroker
-
- return MQTTEventBroker(Client(), serializer=serializer)
-
-
-@pytest.fixture
-async def asyncpg_broker(serializer: Serializer) -> AsyncEventBroker:
- from apscheduler.eventbrokers.asyncpg import AsyncpgEventBroker
-
- broker = AsyncpgEventBroker.from_dsn(
- "postgres://postgres:secret@localhost:5432/testdb", serializer=serializer
+async def test_subscribe_one_shot(event_broker: EventBroker) -> None:
+ send, receive = create_memory_object_stream(2)
+ event_broker.subscribe(send.send, one_shot=True)
+ event = ScheduleAdded(
+ schedule_id="schedule1",
+ next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc),
)
- yield broker
-
-
-@pytest.fixture(
- params=[
- pytest.param(lazy_fixture("local_broker"), id="local"),
- pytest.param(
- lazy_fixture("redis_broker"),
- id="redis",
- marks=[pytest.mark.external_service],
- ),
- pytest.param(
- lazy_fixture("mqtt_broker"), id="mqtt", marks=[pytest.mark.external_service]
- ),
- ]
-)
-def broker(request: SubRequest) -> Generator[EventBroker, Any, None]:
- request.param.start()
- yield request.param
- request.param.stop()
-
-
-@pytest.fixture(
- params=[
- pytest.param(lazy_fixture("local_async_broker"), id="local"),
- pytest.param(
- lazy_fixture("asyncpg_broker"),
- id="asyncpg",
- marks=[pytest.mark.external_service],
- ),
- pytest.param(
- lazy_fixture("async_redis_broker"),
- id="async_redis",
- marks=[pytest.mark.external_service],
- ),
- ]
-)
-async def raw_async_broker(request: SubRequest) -> AsyncEventBroker:
- return cast(AsyncEventBroker, request.param)
-
-
-@pytest.fixture
-async def async_broker(
- raw_async_broker: AsyncEventBroker,
-) -> AsyncGenerator[AsyncEventBroker, Any]:
- await raw_async_broker.start()
- yield raw_async_broker
- await raw_async_broker.stop()
-
-
-class TestEventBroker:
- def test_publish_subscribe(self, broker: EventBroker) -> None:
- queue: Queue[Event] = Queue()
- broker.subscribe(queue.put_nowait)
- broker.subscribe(queue.put_nowait)
- event = ScheduleAdded(
- schedule_id="schedule1",
- next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc),
- )
- broker.publish(event)
- event1 = queue.get(timeout=3)
- event2 = queue.get(timeout=1)
-
- assert event1 == event2
- assert isinstance(event1, ScheduleAdded)
- assert isinstance(event1.timestamp, datetime)
- assert event1.schedule_id == "schedule1"
- assert event1.next_fire_time == datetime(
- 2021, 9, 11, 12, 31, 56, 254867, timezone.utc
- )
-
- def test_subscribe_one_shot(self, broker: EventBroker) -> None:
- queue: Queue[Event] = Queue()
- broker.subscribe(queue.put_nowait, one_shot=True)
- event = ScheduleAdded(
- schedule_id="schedule1",
- next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc),
- )
- broker.publish(event)
- event = ScheduleAdded(
- schedule_id="schedule2",
- next_fire_time=datetime(2021, 9, 12, 8, 42, 11, 968481, timezone.utc),
- )
- broker.publish(event)
- received_event = queue.get(timeout=3)
- with pytest.raises(Empty):
- queue.get(timeout=0.1)
-
- assert isinstance(received_event, ScheduleAdded)
- assert received_event.schedule_id == "schedule1"
-
- def test_unsubscribe(self, broker: EventBroker, caplog) -> None:
- queue: Queue[Event] = Queue()
- subscription = broker.subscribe(queue.put_nowait)
- broker.publish(Event())
- queue.get(timeout=3)
-
- subscription.unsubscribe()
- broker.publish(Event())
- with pytest.raises(Empty):
- queue.get(timeout=0.1)
-
- def test_publish_no_subscribers(
- self, broker: EventBroker, caplog: LogCaptureFixture
- ) -> None:
- broker.publish(Event())
- assert not caplog.text
-
- def test_publish_exception(
- self, broker: EventBroker, caplog: LogCaptureFixture
- ) -> None:
- def bad_subscriber(event: Event) -> None:
- raise Exception("foo")
-
- timestamp = datetime.now(timezone.utc)
- event_future: Future[Event] = Future()
- broker.subscribe(bad_subscriber)
- broker.subscribe(event_future.set_result)
- broker.publish(Event(timestamp=timestamp))
-
- event = event_future.result(3)
- assert isinstance(event, Event)
- assert event.timestamp == timestamp
- assert "Error delivering Event" in caplog.text
+ await event_broker.publish(event)
+ event = ScheduleAdded(
+ schedule_id="schedule2",
+ next_fire_time=datetime(2021, 9, 12, 8, 42, 11, 968481, timezone.utc),
+ )
+ await event_broker.publish(event)
+ with fail_after(3):
+ received_event = await receive.receive()
-@pytest.mark.anyio
-class TestAsyncEventBroker:
- async def test_publish_subscribe(self, async_broker: AsyncEventBroker) -> None:
- send, receive = create_memory_object_stream(2)
- async_broker.subscribe(send.send)
- async_broker.subscribe(send.send_nowait)
- event = ScheduleAdded(
- schedule_id="schedule1",
- next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc),
- )
- await async_broker.publish(event)
+ with pytest.raises(TimeoutError), fail_after(0.1):
+ await receive.receive()
- with fail_after(3):
- event1 = await receive.receive()
- event2 = await receive.receive()
+ assert isinstance(received_event, ScheduleAdded)
+ assert received_event.schedule_id == "schedule1"
- assert event1 == event2
- assert isinstance(event1, ScheduleAdded)
- assert isinstance(event1.timestamp, datetime)
- assert event1.schedule_id == "schedule1"
- assert event1.next_fire_time == datetime(
- 2021, 9, 11, 12, 31, 56, 254867, timezone.utc
- )
- async def test_subscribe_one_shot(self, async_broker: AsyncEventBroker) -> None:
- send, receive = create_memory_object_stream(2)
- async_broker.subscribe(send.send, one_shot=True)
- event = ScheduleAdded(
- schedule_id="schedule1",
- next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc),
- )
- await async_broker.publish(event)
- event = ScheduleAdded(
- schedule_id="schedule2",
- next_fire_time=datetime(2021, 9, 12, 8, 42, 11, 968481, timezone.utc),
- )
- await async_broker.publish(event)
+async def test_unsubscribe(event_broker: EventBroker) -> None:
+ send, receive = create_memory_object_stream()
+ subscription = event_broker.subscribe(send.send)
+ await event_broker.publish(Event())
+ with fail_after(3):
+ await receive.receive()
- with fail_after(3):
- received_event = await receive.receive()
+ subscription.unsubscribe()
+ await event_broker.publish(Event())
+ with pytest.raises(TimeoutError), fail_after(0.1):
+ await receive.receive()
- with pytest.raises(TimeoutError), fail_after(0.1):
- await receive.receive()
- assert isinstance(received_event, ScheduleAdded)
- assert received_event.schedule_id == "schedule1"
+async def test_publish_no_subscribers(event_broker, caplog: LogCaptureFixture) -> None:
+ await event_broker.publish(Event())
+ assert not caplog.text
- async def test_unsubscribe(self, async_broker: AsyncEventBroker) -> None:
- send, receive = create_memory_object_stream()
- subscription = async_broker.subscribe(send.send)
- await async_broker.publish(Event())
- with fail_after(3):
- await receive.receive()
- subscription.unsubscribe()
- await async_broker.publish(Event())
- with pytest.raises(TimeoutError), fail_after(0.1):
- await receive.receive()
+async def test_publish_exception(event_broker, caplog: LogCaptureFixture) -> None:
+ def bad_subscriber(event: Event) -> None:
+ raise Exception("foo")
- async def test_publish_no_subscribers(
- self, async_broker: AsyncEventBroker, caplog: LogCaptureFixture
- ) -> None:
- await async_broker.publish(Event())
- assert not caplog.text
+ timestamp = datetime.now(timezone.utc)
+ send, receive = create_memory_object_stream()
+ event_broker.subscribe(bad_subscriber)
+ event_broker.subscribe(send.send)
+ await event_broker.publish(Event(timestamp=timestamp))
- async def test_publish_exception(
- self, async_broker: AsyncEventBroker, caplog: LogCaptureFixture
- ) -> None:
- def bad_subscriber(event: Event) -> None:
- raise Exception("foo")
+ received_event = await receive.receive()
+ assert received_event.timestamp == timestamp
+ assert "Error delivering Event" in caplog.text
- timestamp = datetime.now(timezone.utc)
- send, receive = create_memory_object_stream()
- async_broker.subscribe(bad_subscriber)
- async_broker.subscribe(send.send)
- await async_broker.publish(Event(timestamp=timestamp))
- received_event = await receive.receive()
- assert received_event.timestamp == timestamp
- assert "Error delivering Event" in caplog.text
+async def test_cancel_start(raw_event_broker: EventBroker) -> None:
+ with CancelScope() as scope:
+ scope.cancel()
+ async with AsyncExitStack() as exit_stack:
+ await raw_event_broker.start(exit_stack)
- async def test_cancel_start(self, raw_async_broker: AsyncEventBroker) -> None:
- with CancelScope() as scope:
- scope.cancel()
- await raw_async_broker.start()
- await raw_async_broker.stop()
- async def test_cancel_stop(self, raw_async_broker: AsyncEventBroker) -> None:
- with CancelScope() as scope:
- await raw_async_broker.start()
+async def test_cancel_stop(raw_event_broker: EventBroker) -> None:
+ with CancelScope() as scope:
+ async with AsyncExitStack() as exit_stack:
+ await raw_event_broker.start(exit_stack)
scope.cancel()
- await raw_async_broker.stop()
diff --git a/tests/test_schedulers.py b/tests/test_schedulers.py
index 32c1c29..30c8dc9 100644
--- a/tests/test_schedulers.py
+++ b/tests/test_schedulers.py
@@ -26,15 +26,14 @@ from apscheduler import (
SchedulerStopped,
Task,
TaskAdded,
+ current_async_scheduler,
current_job,
current_scheduler,
- current_worker,
)
from apscheduler.schedulers.async_ import AsyncScheduler
from apscheduler.schedulers.sync import Scheduler
from apscheduler.triggers.date import DateTrigger
from apscheduler.triggers.interval import IntervalTrigger
-from apscheduler.workers.async_ import AsyncWorker
if sys.version_info >= (3, 9):
from zoneinfo import ZoneInfo
@@ -70,7 +69,7 @@ class TestAsyncScheduler:
received_events: list[Event] = []
event = anyio.Event()
trigger = DateTrigger(datetime.now(timezone.utc))
- async with AsyncScheduler(start_worker=False) as scheduler:
+ async with AsyncScheduler(process_jobs=False) as scheduler:
scheduler.event_broker.subscribe(listener)
await scheduler.add_schedule(dummy_async_job, trigger, id="foo")
await scheduler.start_in_background()
@@ -111,7 +110,7 @@ class TestAsyncScheduler:
assert not received_events
async def test_add_get_schedule(self) -> None:
- async with AsyncScheduler(start_worker=False) as scheduler:
+ async with AsyncScheduler(process_jobs=False) as scheduler:
with pytest.raises(ScheduleLookupError):
await scheduler.get_schedule("dummyid")
@@ -121,7 +120,7 @@ class TestAsyncScheduler:
assert isinstance(schedule, Schedule)
async def test_add_get_schedules(self) -> None:
- async with AsyncScheduler(start_worker=False) as scheduler:
+ async with AsyncScheduler(process_jobs=False) as scheduler:
assert await scheduler.get_schedules() == []
schedule1_id = await scheduler.add_schedule(
@@ -161,7 +160,7 @@ class TestAsyncScheduler:
orig_start_time = datetime.now(timezone) - timedelta(seconds=1)
fake_uniform = mocker.patch("random.uniform")
fake_uniform.configure_mock(side_effect=lambda a, b: jitter)
- async with AsyncScheduler(start_worker=False) as scheduler:
+ async with AsyncScheduler(process_jobs=False) as scheduler:
trigger = IntervalTrigger(seconds=3, start_time=orig_start_time)
job_added_event = anyio.Event()
scheduler.event_broker.subscribe(job_added_listener, {JobAdded})
@@ -263,8 +262,7 @@ class TestAsyncScheduler:
async def test_contextvars(self) -> None:
def check_contextvars() -> None:
- assert current_scheduler.get() is scheduler
- assert isinstance(current_worker.get(), AsyncWorker)
+ assert current_async_scheduler.get() is scheduler
info = current_job.get()
assert info.task_id == "task_id"
assert info.schedule_id == "foo"
@@ -277,7 +275,7 @@ class TestAsyncScheduler:
start_deadline = datetime.now(timezone.utc) + timedelta(seconds=10)
async with AsyncScheduler() as scheduler:
await scheduler.data_store.add_task(
- Task(id="task_id", func=check_contextvars)
+ Task(id="task_id", func=check_contextvars, executor="async")
)
job = Job(
task_id="task_id",
@@ -300,10 +298,7 @@ class TestAsyncScheduler:
async def test_wait_until_stopped(self) -> None:
async with AsyncScheduler() as scheduler:
- trigger = DateTrigger(
- datetime.now(timezone.utc) + timedelta(milliseconds=100)
- )
- await scheduler.add_schedule(scheduler.stop, trigger)
+ await scheduler.add_job(scheduler.stop)
await scheduler.wait_until_stopped()
# This should be a no-op
@@ -422,7 +417,7 @@ class TestSyncScheduler:
# Check that the job was created with the proper amount of jitter in its
# scheduled time
- jobs = scheduler.data_store.get_jobs({job_id})
+ jobs = scheduler._portal.call(scheduler.data_store.get_jobs, {job_id})
assert jobs[0].jitter == timedelta(seconds=jitter)
assert jobs[0].scheduled_fire_time == orig_start_time + timedelta(
seconds=jitter
@@ -495,7 +490,6 @@ class TestSyncScheduler:
def test_contextvars(self) -> None:
def check_contextvars() -> None:
assert current_scheduler.get() is scheduler
- assert current_worker.get() is not None
info = current_job.get()
assert info.task_id == "task_id"
assert info.schedule_id == "foo"
@@ -507,7 +501,10 @@ class TestSyncScheduler:
scheduled_fire_time = datetime.now(timezone.utc)
start_deadline = datetime.now(timezone.utc) + timedelta(seconds=10)
with Scheduler() as scheduler:
- scheduler.data_store.add_task(Task(id="task_id", func=check_contextvars))
+ scheduler._portal.call(
+ scheduler.data_store.add_task,
+ Task(id="task_id", func=check_contextvars, executor="threadpool"),
+ )
job = Job(
task_id="task_id",
schedule_id="foo",
@@ -517,7 +514,7 @@ class TestSyncScheduler:
tags={"foo", "bar"},
result_expiration_time=timedelta(seconds=10),
)
- scheduler.data_store.add_job(job)
+ scheduler._portal.call(scheduler.data_store.add_job, job)
scheduler.start_in_background()
result = scheduler.get_job_result(job.id)
if result.outcome is JobOutcome.error:
@@ -527,10 +524,7 @@ class TestSyncScheduler:
def test_wait_until_stopped(self) -> None:
with Scheduler() as scheduler:
- trigger = DateTrigger(
- datetime.now(timezone.utc) + timedelta(milliseconds=100)
- )
- scheduler.add_schedule(scheduler.stop, trigger)
+ scheduler.add_job(scheduler.stop)
scheduler.start_in_background()
scheduler.wait_until_stopped()
diff --git a/tests/test_workers.py b/tests/test_workers.py
deleted file mode 100644
index aecc63b..0000000
--- a/tests/test_workers.py
+++ /dev/null
@@ -1,281 +0,0 @@
-from __future__ import annotations
-
-import threading
-from datetime import datetime, timezone
-from typing import Callable
-
-import anyio
-import pytest
-from anyio import fail_after
-
-from apscheduler import (
- Event,
- Job,
- JobAcquired,
- JobAdded,
- JobOutcome,
- JobReleased,
- Task,
- TaskAdded,
- WorkerStopped,
-)
-from apscheduler.datastores.memory import MemoryDataStore
-from apscheduler.workers.async_ import AsyncWorker
-from apscheduler.workers.sync import Worker
-
-pytestmark = pytest.mark.anyio
-
-
-def sync_func(*args, fail: bool, **kwargs):
- if fail:
- raise Exception("failing as requested")
- else:
- return args, kwargs
-
-
-async def async_func(*args, fail: bool, **kwargs):
- if fail:
- raise Exception("failing as requested")
- else:
- return args, kwargs
-
-
-def fail_func():
- pytest.fail("This function should never be run")
-
-
-class TestAsyncWorker:
- @pytest.mark.parametrize(
- "target_func", [sync_func, async_func], ids=["sync", "async"]
- )
- @pytest.mark.parametrize("fail", [False, True], ids=["success", "fail"])
- async def test_run_job_nonscheduled_success(
- self, target_func: Callable, fail: bool
- ) -> None:
- def listener(received_event: Event):
- received_events.append(received_event)
- if isinstance(received_event, JobReleased):
- event.set()
-
- received_events: list[Event] = []
- event = anyio.Event()
- async with AsyncWorker(MemoryDataStore()) as worker:
- worker.event_broker.subscribe(listener)
- await worker.data_store.add_task(Task(id="task_id", func=target_func))
- job = Job(task_id="task_id", args=(1, 2), kwargs={"x": "foo", "fail": fail})
- await worker.data_store.add_job(job)
- with fail_after(3):
- await event.wait()
-
- # First, a task was added
- received_event = received_events.pop(0)
- assert isinstance(received_event, TaskAdded)
- assert received_event.task_id == "task_id"
-
- # Then a job was added
- received_event = received_events.pop(0)
- assert isinstance(received_event, JobAdded)
- assert received_event.job_id == job.id
- assert received_event.task_id == "task_id"
- assert received_event.schedule_id is None
-
- # Then the job was started
- received_event = received_events.pop(0)
- assert isinstance(received_event, JobAcquired)
- assert received_event.job_id == job.id
- assert received_event.worker_id == worker.identity
-
- received_event = received_events.pop(0)
- if fail:
- # Then the job failed
- assert isinstance(received_event, JobReleased)
- assert received_event.outcome is JobOutcome.error
- assert received_event.exception_type == "Exception"
- assert received_event.exception_message == "failing as requested"
- assert isinstance(received_event.exception_traceback, list)
- assert all(
- isinstance(line, str) for line in received_event.exception_traceback
- )
- else:
- # Then the job finished successfully
- assert isinstance(received_event, JobReleased)
- assert received_event.outcome is JobOutcome.success
- assert received_event.exception_type is None
- assert received_event.exception_message is None
- assert received_event.exception_traceback is None
-
- # Finally, the worker was stopped
- received_event = received_events.pop(0)
- assert isinstance(received_event, WorkerStopped)
-
- # There should be no more events on the list
- assert not received_events
-
- async def test_run_deadline_missed(self) -> None:
- def listener(received_event: Event):
- received_events.append(received_event)
- if isinstance(received_event, JobReleased):
- event.set()
-
- scheduled_start_time = datetime(2020, 9, 14, tzinfo=timezone.utc)
- received_events: list[Event] = []
- event = anyio.Event()
- async with AsyncWorker(MemoryDataStore()) as worker:
- worker.event_broker.subscribe(listener)
- await worker.data_store.add_task(Task(id="task_id", func=fail_func))
- job = Job(
- task_id="task_id",
- schedule_id="foo",
- scheduled_fire_time=scheduled_start_time,
- start_deadline=datetime(2020, 9, 14, 1, tzinfo=timezone.utc),
- )
- await worker.data_store.add_job(job)
- with fail_after(3):
- await event.wait()
-
- # First, a task was added
- received_event = received_events.pop(0)
- assert isinstance(received_event, TaskAdded)
- assert received_event.task_id == "task_id"
-
- # Then a job was added
- received_event = received_events.pop(0)
- assert isinstance(received_event, JobAdded)
- assert received_event.job_id == job.id
- assert received_event.task_id == "task_id"
- assert received_event.schedule_id == "foo"
-
- # The worker acquired the job
- received_event = received_events.pop(0)
- assert isinstance(received_event, JobAcquired)
- assert received_event.job_id == job.id
- assert received_event.worker_id == worker.identity
-
- # The worker determined that the deadline has been missed
- received_event = received_events.pop(0)
- assert isinstance(received_event, JobReleased)
- assert received_event.outcome is JobOutcome.missed_start_deadline
- assert received_event.job_id == job.id
- assert received_event.worker_id == worker.identity
-
- # Finally, the worker was stopped
- received_event = received_events.pop(0)
- assert isinstance(received_event, WorkerStopped)
-
- # There should be no more events on the list
- assert not received_events
-
-
-class TestSyncWorker:
- @pytest.mark.parametrize("fail", [False, True], ids=["success", "fail"])
- def test_run_job_nonscheduled(self, fail: bool) -> None:
- def listener(received_event: Event):
- received_events.append(received_event)
- if isinstance(received_event, JobReleased):
- event.set()
-
- received_events: list[Event] = []
- event = threading.Event()
- with Worker(MemoryDataStore()) as worker:
- worker.event_broker.subscribe(listener)
- worker.data_store.add_task(Task(id="task_id", func=sync_func))
- job = Job(task_id="task_id", args=(1, 2), kwargs={"x": "foo", "fail": fail})
- worker.data_store.add_job(job)
- event.wait(3)
-
- # First, a task was added
- received_event = received_events.pop(0)
- assert isinstance(received_event, TaskAdded)
- assert received_event.task_id == "task_id"
-
- # Then a job was added
- received_event = received_events.pop(0)
- assert isinstance(received_event, JobAdded)
- assert received_event.job_id == job.id
- assert received_event.task_id == "task_id"
- assert received_event.schedule_id is None
-
- # Then the job was started
- received_event = received_events.pop(0)
- assert isinstance(received_event, JobAcquired)
- assert received_event.job_id == job.id
- assert received_event.worker_id == worker.identity
-
- received_event = received_events.pop(0)
- if fail:
- # Then the job failed
- assert isinstance(received_event, JobReleased)
- assert received_event.outcome is JobOutcome.error
- assert received_event.exception_type == "Exception"
- assert received_event.exception_message == "failing as requested"
- assert isinstance(received_event.exception_traceback, list)
- assert all(
- isinstance(line, str) for line in received_event.exception_traceback
- )
- else:
- # Then the job finished successfully
- assert isinstance(received_event, JobReleased)
- assert received_event.outcome is JobOutcome.success
- assert received_event.exception_type is None
- assert received_event.exception_message is None
- assert received_event.exception_traceback is None
-
- # Finally, the worker was stopped
- received_event = received_events.pop(0)
- assert isinstance(received_event, WorkerStopped)
-
- # There should be no more events on the list
- assert not received_events
-
- def test_run_deadline_missed(self) -> None:
- def listener(received_event: Event):
- received_events.append(received_event)
- if isinstance(received_event, JobReleased):
- event.set()
-
- scheduled_start_time = datetime(2020, 9, 14, tzinfo=timezone.utc)
- received_events: list[Event] = []
- event = threading.Event()
- with Worker(MemoryDataStore()) as worker:
- worker.event_broker.subscribe(listener)
- worker.data_store.add_task(Task(id="task_id", func=fail_func))
- job = Job(
- task_id="task_id",
- schedule_id="foo",
- scheduled_fire_time=scheduled_start_time,
- start_deadline=datetime(2020, 9, 14, 1, tzinfo=timezone.utc),
- )
- worker.data_store.add_job(job)
- event.wait(3)
-
- # First, a task was added
- received_event = received_events.pop(0)
- assert isinstance(received_event, TaskAdded)
- assert received_event.task_id == "task_id"
-
- # Then a job was added
- received_event = received_events.pop(0)
- assert isinstance(received_event, JobAdded)
- assert received_event.job_id == job.id
- assert received_event.task_id == "task_id"
- assert received_event.schedule_id == "foo"
-
- # The worker acquired the job
- received_event = received_events.pop(0)
- assert isinstance(received_event, JobAcquired)
- assert received_event.job_id == job.id
- assert received_event.worker_id == worker.identity
-
- # The worker determined that the deadline has been missed
- received_event = received_events.pop(0)
- assert isinstance(received_event, JobReleased)
- assert received_event.outcome is JobOutcome.missed_start_deadline
- assert received_event.job_id == job.id
- assert received_event.worker_id == worker.identity
-
- # Finally, the worker was stopped
- received_event = received_events.pop(0)
- assert isinstance(received_event, WorkerStopped)
-
- # There should be no more events on the list
- assert not received_events