Diffstat (limited to 'src/apscheduler/schedulers/async_.py')
-rw-r--r--  src/apscheduler/schedulers/async_.py | 165
1 file changed, 120 insertions(+), 45 deletions(-)
diff --git a/src/apscheduler/schedulers/async_.py b/src/apscheduler/schedulers/async_.py
index 75972b5..44fee27 100644
--- a/src/apscheduler/schedulers/async_.py
+++ b/src/apscheduler/schedulers/async_.py
@@ -5,6 +5,7 @@ import platform
import random
import sys
from asyncio import CancelledError
+from collections.abc import MutableMapping
from contextlib import AsyncExitStack
from datetime import datetime, timedelta, timezone
from logging import Logger, getLogger
@@ -16,9 +17,9 @@ import anyio
import attrs
from anyio import TASK_STATUS_IGNORED, create_task_group, move_on_after
from anyio.abc import TaskGroup, TaskStatus
+from attr.validators import instance_of
-from .._context import current_scheduler
-from .._converters import as_async_datastore, as_async_eventbroker
+from .._context import current_async_scheduler
from .._enums import CoalescePolicy, ConflictPolicy, JobOutcome, RunState
from .._events import (
Event,
@@ -35,11 +36,14 @@ from .._exceptions import (
ScheduleLookupError,
)
from .._structures import Job, JobResult, Schedule, Task
-from ..abc import AsyncDataStore, AsyncEventBroker, Subscription, Trigger
+from .._worker import Worker
+from ..abc import DataStore, EventBroker, JobExecutor, Subscription, Trigger
from ..datastores.memory import MemoryDataStore
-from ..eventbrokers.async_local import LocalAsyncEventBroker
+from ..eventbrokers.local import LocalEventBroker
+from ..executors.async_ import AsyncJobExecutor
+from ..executors.subprocess import ProcessPoolJobExecutor
+from ..executors.thread import ThreadPoolJobExecutor
from ..marshalling import callable_to_ref
-from ..workers.async_ import AsyncWorker
if sys.version_info >= (3, 11):
from typing import Self
@@ -54,19 +58,24 @@ _zero_timedelta = timedelta()
class AsyncScheduler:
"""An asynchronous (AnyIO based) scheduler implementation."""
- data_store: AsyncDataStore = attrs.field(
- converter=as_async_datastore, factory=MemoryDataStore
+ data_store: DataStore = attrs.field(
+ validator=instance_of(DataStore), factory=MemoryDataStore
)
- event_broker: AsyncEventBroker = attrs.field(
- converter=as_async_eventbroker, factory=LocalAsyncEventBroker
+ event_broker: EventBroker = attrs.field(
+ validator=instance_of(EventBroker), factory=LocalEventBroker
)
identity: str = attrs.field(kw_only=True, default=None)
- start_worker: bool = attrs.field(kw_only=True, default=True)
+ process_jobs: bool = attrs.field(kw_only=True, default=True)
+ job_executors: MutableMapping[str, JobExecutor] | None = attrs.field(
+ kw_only=True, default=None
+ )
+ default_job_executor: str | None = attrs.field(kw_only=True, default=None)
+ process_schedules: bool = attrs.field(kw_only=True, default=True)
logger: Logger | None = attrs.field(kw_only=True, default=getLogger(__name__))
_state: RunState = attrs.field(init=False, default=RunState.stopped)
_task_group: TaskGroup | None = attrs.field(init=False, default=None)
- _exit_stack: AsyncExitStack | None = attrs.field(init=False, default=None)
+ _exit_stack: AsyncExitStack = attrs.field(init=False, factory=AsyncExitStack)
_services_initialized: bool = attrs.field(init=False, default=False)
_wakeup_event: anyio.Event = attrs.field(init=False)
_wakeup_deadline: datetime | None = attrs.field(init=False, default=None)
@@ -76,13 +85,32 @@ class AsyncScheduler:
if not self.identity:
self.identity = f"{platform.node()}-{os.getpid()}-{id(self)}"
+ if not self.job_executors:
+ self.job_executors = {
+ "async": AsyncJobExecutor(),
+ "threadpool": ThreadPoolJobExecutor(),
+ "processpool": ProcessPoolJobExecutor(),
+ }
+
+ if not self.default_job_executor:
+ self.default_job_executor = next(iter(self.job_executors))
+ elif self.default_job_executor not in self.job_executors:
+ raise ValueError(
+ "default_job_executor must be one of the given job executors"
+ )
+
async def __aenter__(self: Self) -> Self:
- self._exit_stack = AsyncExitStack()
await self._exit_stack.__aenter__()
- await self._ensure_services_ready(self._exit_stack)
- self._task_group = await self._exit_stack.enter_async_context(
- create_task_group()
- )
+ try:
+ await self._ensure_services_initialized(self._exit_stack)
+ self._task_group = await self._exit_stack.enter_async_context(
+ create_task_group()
+ )
+ self._exit_stack.callback(setattr, self, "_task_group", None)
+ except BaseException as exc:
+ await self._exit_stack.__aexit__(type(exc), exc, exc.__traceback__)
+ raise
+
return self
async def __aexit__(
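A usage sketch (illustrative, not part of the patch) of constructing the scheduler with the new executor fields; import paths follow this diff, the executor choice is arbitrary:

    from apscheduler.executors.async_ import AsyncJobExecutor
    from apscheduler.executors.thread import ThreadPoolJobExecutor
    from apscheduler.schedulers.async_ import AsyncScheduler

    scheduler = AsyncScheduler(
        job_executors={
            "async": AsyncJobExecutor(),
            "threadpool": ThreadPoolJobExecutor(),
        },
        # Must name a key of job_executors; __attrs_post_init__ raises
        # ValueError otherwise.
        default_job_executor="threadpool",
    )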
@@ -93,27 +121,25 @@ class AsyncScheduler:
) -> None:
await self.stop()
await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
- self._task_group = None
- async def _ensure_services_ready(self, exit_stack: AsyncExitStack) -> None:
+ async def _ensure_services_initialized(self, exit_stack: AsyncExitStack) -> None:
"""Ensure that the data store and event broker have been initialized."""
if not self._services_initialized:
self._services_initialized = True
exit_stack.callback(setattr, self, "_services_initialized", False)
- # Initialize the event broker
- await self.event_broker.start()
- exit_stack.push_async_exit(
- lambda *exc_info: self.event_broker.stop(force=exc_info[0] is not None)
- )
+ await self.event_broker.start(exit_stack)
+ await self.data_store.start(exit_stack, self.event_broker)
- # Initialize the data store
- await self.data_store.start(self.event_broker)
- exit_stack.push_async_exit(
- lambda *exc_info: self.data_store.stop(force=exc_info[0] is not None)
+ def _check_initialized(self) -> None:
+ if not self._services_initialized:
+ raise RuntimeError(
+ "The scheduler has not been initialized yet. Use the scheduler as an "
+ "async context manager (async with ...) in order to call methods other "
+ "than run_until_stopped()."
)
- def _schedule_added_or_modified(self, event: Event) -> None:
+ async def _schedule_added_or_modified(self, event: Event) -> None:
event_ = cast("ScheduleAdded | ScheduleUpdated", event)
if not self._wakeup_deadline or (
event_.next_fire_time and event_.next_fire_time < self._wakeup_deadline
@@ -128,6 +154,35 @@ class AsyncScheduler:
"""The current running state of the scheduler."""
return self._state
+ def subscribe(
+ self,
+ callback: Callable[[Event], Any],
+ event_types: Iterable[type[Event]] | None = None,
+ *,
+ one_shot: bool = False,
+ is_async: bool = True,
+ ) -> Subscription:
+ """
+ Subscribe to events.
+
+ To unsubscribe, call the :meth:`Subscription.unsubscribe` method on the returned
+ object.
+
+ :param callback: callable to be called with the event object when an event is
+ published
+ :param event_types: an iterable of concrete Event classes to subscribe to
+ :param one_shot: if ``True``, automatically unsubscribe after the first matching
+ event
+ :param is_async: ``True`` if the (synchronous) callback should be called on the
+ event loop thread, ``False`` if it should be called in a worker thread.
+ If the callback is a coroutine function, this flag is ignored.
+
+ """
+ self._check_initialized()
+ return self.event_broker.subscribe(
+ callback, event_types, is_async=is_async, one_shot=one_shot
+ )
+
async def add_schedule(
self,
func_or_task_id: str | Callable,
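A sketch of wiring a callback through the new subscribe() wrapper. The JobReleased import path is assumed from this module's ".._events" import; the callback is hypothetical, and the scheduler must already be initialized (entered via "async with") or _check_initialized() raises RuntimeError:

    from apscheduler._events import JobReleased  # path assumed from this diff

    def on_released(event: JobReleased) -> None:
        # JobReleased carries the ID of the job that was just released
        print(f"job {event.job_id} was released")

    subscription = scheduler.subscribe(on_released, {JobReleased})
    # ... later, when no longer interested:
    subscription.unsubscribe()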
@@ -136,6 +191,7 @@ class AsyncScheduler:
id: str | None = None,
args: Iterable | None = None,
kwargs: Mapping[str, Any] | None = None,
+ job_executor: str | None = None,
coalesce: CoalescePolicy = CoalescePolicy.latest,
misfire_grace_time: float | timedelta | None = None,
max_jitter: float | timedelta | None = None,
@@ -152,6 +208,7 @@ class AsyncScheduler:
based ID will be assigned)
:param args: positional arguments to be passed to the task function
:param kwargs: keyword arguments to be passed to the task function
+ :param job_executor: name of the job executor to run the task with
:param coalesce: determines what to do when processing the schedule if multiple
fire times have become due for this schedule since the last processing
:param misfire_grace_time: maximum number of seconds the scheduled job's actual
@@ -165,6 +222,7 @@ class AsyncScheduler:
:return: the ID of the newly added schedule
"""
+ self._check_initialized()
id = id or str(uuid4())
args = tuple(args or ())
kwargs = dict(kwargs or {})
@@ -173,7 +231,11 @@ class AsyncScheduler:
misfire_grace_time = timedelta(seconds=misfire_grace_time)
if callable(func_or_task_id):
- task = Task(id=callable_to_ref(func_or_task_id), func=func_or_task_id)
+ task = Task(
+ id=callable_to_ref(func_or_task_id),
+ func=func_or_task_id,
+ executor=job_executor or self.default_job_executor,
+ )
await self.data_store.add_task(task)
else:
task = await self.data_store.get_task(func_or_task_id)
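A sketch of add_schedule() with the new job_executor parameter; IntervalTrigger is APScheduler's stock interval trigger, while tick() and schedule_tick() are hypothetical:

    from apscheduler.schedulers.async_ import AsyncScheduler
    from apscheduler.triggers.interval import IntervalTrigger

    async def tick() -> None:
        print("tick")

    async def schedule_tick(scheduler: AsyncScheduler) -> str:
        # Run tick() every five seconds on the "async" executor; the
        # return value is the generated schedule ID.
        return await scheduler.add_schedule(
            tick, IntervalTrigger(seconds=5), job_executor="async"
        )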
@@ -207,6 +269,7 @@ class AsyncScheduler:
:raises ScheduleLookupError: if the schedule could not be found
"""
+ self._check_initialized()
schedules = await self.data_store.get_schedules({id})
if schedules:
return schedules[0]
@@ -220,6 +283,7 @@ class AsyncScheduler:
:return: a list of schedules, in an unspecified order
"""
+ self._check_initialized()
return await self.data_store.get_schedules()
async def remove_schedule(self, id: str) -> None:
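The read side, sketched together with the symmetrical removal shown in the next hunk; the Schedule attributes used (id, next_fire_time) follow the structures imported at the top of the file:

    from apscheduler.schedulers.async_ import AsyncScheduler

    async def show_and_remove(scheduler: AsyncScheduler, schedule_id: str) -> None:
        schedule = await scheduler.get_schedule(schedule_id)
        print(schedule.id, schedule.next_fire_time)
        for sched in await scheduler.get_schedules():
            print(sched.id)
        await scheduler.remove_schedule(schedule_id)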
@@ -229,6 +293,7 @@ class AsyncScheduler:
:param id: the unique identifier of the schedule
"""
+ self._check_initialized()
await self.data_store.remove_schedules({id})
async def add_job(
@@ -237,6 +302,7 @@ class AsyncScheduler:
*,
args: Iterable | None = None,
kwargs: Mapping[str, Any] | None = None,
+ job_executor: str | None = None,
tags: Iterable[str] | None = None,
result_expiration_time: timedelta | float = 0,
) -> UUID:
@@ -244,8 +310,9 @@
Add a job to the data store.
:param func_or_task_id:
:param args: positional arguments to call the target callable with
:param kwargs: keyword arguments to call the target callable with
+ :param job_executor: name of the job executor to run the task with
:param tags: strings that can be used to categorize and filter the job
:param result_expiration_time: the minimum time (as seconds, or timedelta) to
keep the result of the job available for fetching (the result won't be
@@ -253,8 +321,13 @@ class AsyncScheduler:
:return: the ID of the newly created job
"""
+ self._check_initialized()
if callable(func_or_task_id):
- task = Task(id=callable_to_ref(func_or_task_id), func=func_or_task_id)
+ task = Task(
+ id=callable_to_ref(func_or_task_id),
+ func=func_or_task_id,
+ executor=job_executor or self.default_job_executor,
+ )
await self.data_store.add_task(task)
else:
task = await self.data_store.get_task(func_or_task_id)
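A sketch of add_job() with the new job_executor parameter; blocking_work() is hypothetical, and "threadpool" follows the default executor names set up in __attrs_post_init__:

    from datetime import timedelta
    from uuid import UUID

    from apscheduler.schedulers.async_ import AsyncScheduler

    def blocking_work() -> str:
        return "done"

    async def enqueue(scheduler: AsyncScheduler) -> UUID:
        # One-off job on the thread pool; keep the result around for a
        # minute so it can still be fetched after the job is released.
        return await scheduler.add_job(
            blocking_work,
            job_executor="threadpool",
            result_expiration_time=timedelta(minutes=1),
        )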
@@ -280,13 +353,14 @@ class AsyncScheduler:
the data store
"""
+ self._check_initialized()
wait_event = anyio.Event()
def listener(event: JobReleased) -> None:
if event.job_id == job_id:
wait_event.set()
- with self.data_store.events.subscribe(listener, {JobReleased}):
+ with self.event_broker.subscribe(listener, {JobReleased}):
result = await self.data_store.get_job_result(job_id)
if result:
return result
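Fetching the result of the job enqueued above, as a sketch; the JobOutcome import path is assumed from this module's ".._enums" import, and outcome/return_value follow the JobResult structure imported at the top of the file:

    from uuid import UUID

    from apscheduler._enums import JobOutcome  # path assumed from this diff
    from apscheduler.schedulers.async_ import AsyncScheduler

    async def fetch_result(scheduler: AsyncScheduler, job_id: UUID) -> None:
        # Waits, via the JobReleased subscription above, until the job finishes
        result = await scheduler.get_job_result(job_id)
        if result.outcome is JobOutcome.success:
            print(result.return_value)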
@@ -303,6 +377,7 @@ class AsyncScheduler:
*,
args: Iterable | None = None,
kwargs: Mapping[str, Any] | None = None,
+ job_executor: str | None = None,
tags: Iterable[str] | None = (),
) -> Any:
"""
@@ -314,10 +389,12 @@ class AsyncScheduler:
definition
:param args: positional arguments to be passed to the task function
:param kwargs: keyword arguments to be passed to the task function
+ :param job_executor: name of the job executor to run the task with
:param tags: strings that can be used to categorize and filter the job
:returns: the return value of the task function
"""
+ self._check_initialized()
job_complete_event = anyio.Event()
def listener(event: JobReleased) -> None:
@@ -325,11 +402,12 @@ class AsyncScheduler:
job_complete_event.set()
job_id: UUID | None = None
- with self.data_store.events.subscribe(listener, {JobReleased}):
+ with self.event_broker.subscribe(listener, {JobReleased}):
job_id = await self.add_job(
func_or_task_id,
args=args,
kwargs=kwargs,
+ job_executor=job_executor,
tags=tags,
result_expiration_time=timedelta(minutes=15),
)
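As the hunk above shows, run_job() enqueues the job, waits for its JobReleased event and hands back the return value. A sketch with hypothetical add() and compute():

    from apscheduler.schedulers.async_ import AsyncScheduler

    def add(x: int, y: int) -> int:
        return x + y

    async def compute(scheduler: AsyncScheduler) -> None:
        value = await scheduler.run_job(add, args=(2, 3), job_executor="threadpool")
        assert value == 5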
@@ -378,12 +456,7 @@ class AsyncScheduler:
await event.wait()
async def start_in_background(self) -> None:
- if self._task_group is None:
- raise RuntimeError(
- "The scheduler must be used as an async context manager (async with "
- "...) in order to be startable in the background"
- )
-
+ self._check_initialized()
await self._task_group.start(self.run_until_stopped)
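The context-manager lifecycle that _check_initialized() now enforces, sketched end to end; anyio.run() drives the coroutine, in keeping with the AnyIO base noted in the class docstring:

    import anyio

    from apscheduler.schedulers.async_ import AsyncScheduler

    async def main() -> None:
        async with AsyncScheduler() as scheduler:
            await scheduler.start_in_background()
            await anyio.sleep(10)  # schedules and jobs are processed meanwhile
            # stop() is optional here; __aexit__ calls it as well
            await scheduler.stop()

    anyio.run(main)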
async def run_until_stopped(
@@ -398,7 +471,7 @@ class AsyncScheduler:
self._state = RunState.starting
async with AsyncExitStack() as exit_stack:
self._wakeup_event = anyio.Event()
- await self._ensure_services_ready(exit_stack)
+ await self._ensure_services_initialized(exit_stack)
# Wake up the scheduler if the data store emits a significant schedule event
exit_stack.enter_context(
@@ -407,14 +480,16 @@ class AsyncScheduler:
)
)
+ # Set this scheduler as the current scheduler
+ token = current_async_scheduler.set(self)
+ exit_stack.callback(current_async_scheduler.reset, token)
+
# Start the built-in worker, if configured to do so
- if self.start_worker:
- token = current_scheduler.set(self)
- exit_stack.callback(current_scheduler.reset, token)
- worker = AsyncWorker(
- self.data_store, self.event_broker, is_internal=True
+ if self.process_jobs:
+ worker = Worker(job_executors=self.job_executors)
+ await worker.start(
+ exit_stack, self.data_store, self.event_broker, self.identity
)
- await exit_stack.enter_async_context(worker)
# Signal that the scheduler has started
self._state = RunState.started
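Finally, a foreground sketch: per the run_until_stopped() hunk above, the method initializes its own services on an internal exit stack, so it is the one entry point that needs no prior "async with":

    import anyio

    from apscheduler.schedulers.async_ import AsyncScheduler

    # Blocks until something stops the scheduler (e.g. another task or a
    # signal handler calling stop())
    anyio.run(AsyncScheduler().run_until_stopped)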