diff options
author | Alex Grönholm <alex.gronholm@nextday.fi> | 2021-09-21 18:30:31 +0300 |
---|---|---|
committer | Alex Grönholm <alex.gronholm@nextday.fi> | 2021-09-22 23:55:40 +0300 |
commit | 310652119957194d3c8cac91bf6bf171b647a103 (patch) | |
tree | 045878b164e986f2deae3bda3982473d9d03ce3c | |
parent | 191a9663c6fd2c65f7b524c59285dec5ac747ee7 (diff) | |
download | apscheduler-310652119957194d3c8cac91bf6bf171b647a103.tar.gz |
Refactored scheduler and worker classes to use attrs
-rw-r--r-- | src/apscheduler/converters.py | 10 | ||||
-rw-r--r-- | src/apscheduler/schedulers/async_.py | 48 | ||||
-rw-r--r-- | src/apscheduler/schedulers/sync.py | 35 | ||||
-rw-r--r-- | src/apscheduler/validators.py | 6 | ||||
-rw-r--r-- | src/apscheduler/workers/async_.py | 47 | ||||
-rw-r--r-- | src/apscheduler/workers/sync.py | 40 |
6 files changed, 101 insertions, 85 deletions
diff --git a/src/apscheduler/converters.py b/src/apscheduler/converters.py index c664bc0..103ab35 100644 --- a/src/apscheduler/converters.py +++ b/src/apscheduler/converters.py @@ -5,6 +5,8 @@ from enum import Enum from typing import Optional from uuid import UUID +from . import abc + def as_aware_datetime(value: datetime | str) -> Optional[datetime]: """Convert the value from a string to a timezone aware datetime.""" @@ -41,3 +43,11 @@ def as_enum(enum_class: type[Enum]): return value return converter + + +def as_async_datastore(value: abc.DataStore | abc.AsyncDataStore) -> abc.AsyncDataStore: + if isinstance(value, abc.DataStore): + from apscheduler.datastores.async_adapter import AsyncDataStoreAdapter + return AsyncDataStoreAdapter(value) + + return value diff --git a/src/apscheduler/schedulers/async_.py b/src/apscheduler/schedulers/async_.py index a10e772..751eb2c 100644 --- a/src/apscheduler/schedulers/async_.py +++ b/src/apscheduler/schedulers/async_.py @@ -10,12 +10,12 @@ from typing import Any, Callable, Iterable, Mapping, Optional from uuid import UUID, uuid4 import anyio +import attr from anyio import TASK_STATUS_IGNORED, create_task_group, get_cancelled_exc_class, move_on_after -from anyio.abc import TaskGroup -from ..abc import AsyncDataStore, DataStore, EventSource, Job, Schedule, Trigger +from ..abc import AsyncDataStore, EventSource, Job, Schedule, Trigger from ..context import current_scheduler -from ..datastores.async_adapter import AsyncDataStoreAdapter +from ..converters import as_async_datastore from ..datastores.memory import MemoryDataStore from ..enums import CoalescePolicy, ConflictPolicy, JobOutcome, RunState from ..eventbrokers.async_local import LocalAsyncEventBroker @@ -30,29 +30,24 @@ _microsecond_delta = timedelta(microseconds=1) _zero_timedelta = timedelta() +@attr.define(eq=False) class AsyncScheduler: """An asynchronous (AnyIO based) scheduler implementation.""" - data_store: AsyncDataStore - _state: RunState = RunState.stopped - _wakeup_event: anyio.Event - _worker: Optional[AsyncWorker] = None - _task_group: Optional[TaskGroup] = None - - def __init__(self, data_store: DataStore | AsyncDataStore | None = None, *, - identity: Optional[str] = None, logger: Optional[Logger] = None, - start_worker: bool = True): - self.identity = identity or f'{platform.node()}-{os.getpid()}-{id(self)}' - self.logger = logger or getLogger(__name__) - self.start_worker = start_worker - self._exit_stack = AsyncExitStack() - self._events = LocalAsyncEventBroker() + data_store: AsyncDataStore = attr.field(converter=as_async_datastore, factory=MemoryDataStore) + identity: str = attr.field(kw_only=True, default=None) + start_worker: bool = attr.field(kw_only=True, default=True) + logger: Optional[Logger] = attr.field(kw_only=True, default=getLogger(__name__)) - data_store = data_store or MemoryDataStore() - if isinstance(data_store, DataStore): - self.data_store = AsyncDataStoreAdapter(data_store) - else: - self.data_store = data_store + _state: RunState = attr.field(init=False, default=RunState.stopped) + _wakeup_event: anyio.Event = attr.field(init=False) + _worker: Optional[AsyncWorker] = attr.field(init=False, default=None) + _events: LocalAsyncEventBroker = attr.field(init=False, factory=LocalAsyncEventBroker) + _exit_stack: AsyncExitStack = attr.field(init=False) + + def __attrs_post_init__(self) -> None: + if not self.identity: + self.identity = f'{platform.node()}-{os.getpid()}-{id(self)}' @property def events(self) -> EventSource: @@ -65,6 +60,7 @@ class AsyncScheduler: async def __aenter__(self): self._state = RunState.starting self._wakeup_event = anyio.Event() + self._exit_stack = AsyncExitStack() await self._exit_stack.__aenter__() await self._exit_stack.enter_async_context(self._events) @@ -89,16 +85,16 @@ class AsyncScheduler: current_scheduler.reset(token) # Start the worker and return when it has signalled readiness or raised an exception - self._task_group = create_task_group() - await self._exit_stack.enter_async_context(self._task_group) - await self._task_group.start(self.run) + task_group = create_task_group() + await self._exit_stack.enter_async_context(task_group) + await task_group.start(self.run) return self async def __aexit__(self, exc_type, exc_val, exc_tb): self._state = RunState.stopping self._wakeup_event.set() await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) - del self._task_group + self._state = RunState.stopped del self._wakeup_event async def add_schedule( diff --git a/src/apscheduler/schedulers/sync.py b/src/apscheduler/schedulers/sync.py index 2af4f8e..c49359c 100644 --- a/src/apscheduler/schedulers/sync.py +++ b/src/apscheduler/schedulers/sync.py @@ -11,6 +11,8 @@ from logging import Logger, getLogger from typing import Any, Callable, Iterable, Mapping, Optional from uuid import UUID, uuid4 +import attr + from ..abc import DataStore, EventSource, Trigger from ..context import current_scheduler from ..datastores.memory import MemoryDataStore @@ -27,22 +29,24 @@ _microsecond_delta = timedelta(microseconds=1) _zero_timedelta = timedelta() +@attr.define(eq=False) class Scheduler: """A synchronous scheduler implementation.""" - _state: RunState = RunState.stopped - _wakeup_event: threading.Event - _worker: Optional[Worker] = None + data_store: DataStore = attr.field(factory=MemoryDataStore) + identity: str = attr.field(kw_only=True, default=None) + start_worker: bool = attr.field(kw_only=True, default=True) + logger: Optional[Logger] = attr.field(kw_only=True, default=getLogger(__name__)) - def __init__(self, data_store: Optional[DataStore] = None, *, identity: Optional[str] = None, - logger: Optional[Logger] = None, start_worker: bool = True): - self.identity = identity or f'{platform.node()}-{os.getpid()}-{id(self)}' - self.logger = logger or getLogger(__name__) - self.start_worker = start_worker - self.data_store = data_store or MemoryDataStore() - self._exit_stack = ExitStack() - self._executor = ThreadPoolExecutor(max_workers=1) - self._events = LocalEventBroker() + _state: RunState = attr.field(init=False, default=RunState.stopped) + _wakeup_event: threading.Event = attr.field(init=False) + _worker: Optional[Worker] = attr.field(init=False, default=None) + _events: LocalEventBroker = attr.field(init=False, factory=LocalEventBroker) + _exit_stack: ExitStack = attr.field(init=False) + + def __attrs_post_init__(self) -> None: + if not self.identity: + self.identity = f'{platform.node()}-{os.getpid()}-{id(self)}' @property def events(self) -> EventSource: @@ -59,6 +63,7 @@ class Scheduler: def __enter__(self) -> Scheduler: self._state = RunState.starting self._wakeup_event = threading.Event() + self._exit_stack = ExitStack() self._exit_stack.__enter__() self._exit_stack.enter_context(self._events) @@ -85,7 +90,9 @@ class Scheduler: # Start the scheduler and return when it has signalled readiness or raised an exception start_future: Future[Event] = Future() with self._events.subscribe(start_future.set_result, one_shot=True): - run_future = self._executor.submit(self.run) + executor = ThreadPoolExecutor(1) + self._exit_stack.push(lambda exc_type, *args: executor.shutdown(wait=exc_type is None)) + run_future = executor.submit(self.run) wait([start_future, run_future], return_when=FIRST_COMPLETED) if run_future.done(): @@ -96,8 +103,8 @@ class Scheduler: def __exit__(self, exc_type, exc_val, exc_tb): self._state = RunState.stopping self._wakeup_event.set() - self._executor.shutdown(wait=exc_type is None) self._exit_stack.__exit__(exc_type, exc_val, exc_tb) + self._state = RunState.stopped del self._wakeup_event def add_schedule( diff --git a/src/apscheduler/validators.py b/src/apscheduler/validators.py index c71f73a..95c3747 100644 --- a/src/apscheduler/validators.py +++ b/src/apscheduler/validators.py @@ -4,6 +4,7 @@ import sys from datetime import date, datetime, timedelta, timezone, tzinfo from typing import Any, Optional +import attr from attr import Attribute from tzlocal import get_localzone @@ -163,3 +164,8 @@ def require_state_version(trigger: Trigger, state: dict[str, Any], max_version: ) except KeyError as exc: raise DeserializationError('Missing "version" key in the serialized state') from exc + + +def positive_integer(inst, field: attr.Attribute, value) -> None: + if value <= 0: + raise ValueError(f'{field} must be a positive integer') diff --git a/src/apscheduler/workers/async_.py b/src/apscheduler/workers/async_.py index e037a05..723b004 100644 --- a/src/apscheduler/workers/async_.py +++ b/src/apscheduler/workers/async_.py @@ -10,44 +10,38 @@ from typing import Callable, Optional from uuid import UUID import anyio +import attr from anyio import TASK_STATUS_IGNORED, create_task_group, get_cancelled_exc_class, move_on_after from anyio.abc import CancelScope -from ..abc import AsyncDataStore, DataStore, EventSource, Job +from ..abc import AsyncDataStore, EventSource, Job from ..context import current_worker, job_info -from ..datastores.async_adapter import AsyncDataStoreAdapter +from ..converters import as_async_datastore from ..enums import JobOutcome, RunState from ..eventbrokers.async_local import LocalAsyncEventBroker from ..events import JobAdded, WorkerStarted, WorkerStopped from ..structures import JobInfo, JobResult +from ..validators import positive_integer +@attr.define(eq=False) class AsyncWorker: """Runs jobs locally in a task group.""" - - _stop_event: Optional[anyio.Event] = None - _state: RunState = RunState.stopped - _acquire_cancel_scope: Optional[CancelScope] = None - _wakeup_event: anyio.Event - - def __init__(self, data_store: DataStore | AsyncDataStore, *, - max_concurrent_jobs: int = 100, identity: Optional[str] = None, - logger: Optional[Logger] = None): - self.max_concurrent_jobs = max_concurrent_jobs - self.identity = identity or f'{platform.node()}-{os.getpid()}-{id(self)}' - self.logger = logger or getLogger(__name__) - self._acquired_jobs: set[Job] = set() - self._exit_stack = AsyncExitStack() - self._events = LocalAsyncEventBroker() - self._running_jobs: set[UUID] = set() - - if self.max_concurrent_jobs < 1: - raise ValueError('max_concurrent_jobs must be at least 1') - - if isinstance(data_store, DataStore): - self.data_store = AsyncDataStoreAdapter(data_store) - else: - self.data_store = data_store + data_store: AsyncDataStore = attr.field(converter=as_async_datastore) + max_concurrent_jobs: int = attr.field(kw_only=True, validator=positive_integer, default=100) + identity: str = attr.field(kw_only=True, default=None) + logger: Optional[Logger] = attr.field(kw_only=True, default=getLogger(__name__)) + + _state: RunState = attr.field(init=False, default=RunState.stopped) + _wakeup_event: anyio.Event = attr.field(init=False, factory=anyio.Event) + _acquired_jobs: set[Job] = attr.field(init=False, factory=set) + _events: LocalAsyncEventBroker = attr.field(init=False, factory=LocalAsyncEventBroker) + _running_jobs: set[UUID] = attr.field(init=False, factory=set) + _exit_stack: AsyncExitStack = attr.field(init=False) + + def __attrs_post_init__(self) -> None: + if not self.identity: + self.identity = f'{platform.node()}-{os.getpid()}-{id(self)}' @property def events(self) -> EventSource: @@ -60,6 +54,7 @@ class AsyncWorker: async def __aenter__(self): self._state = RunState.starting self._wakeup_event = anyio.Event() + self._exit_stack = AsyncExitStack() await self._exit_stack.__aenter__() await self._exit_stack.enter_async_context(self._events) diff --git a/src/apscheduler/workers/sync.py b/src/apscheduler/workers/sync.py index 0afa8c7..4e3eef4 100644 --- a/src/apscheduler/workers/sync.py +++ b/src/apscheduler/workers/sync.py @@ -11,35 +11,36 @@ from logging import Logger, getLogger from typing import Callable, Optional from uuid import UUID +import attr + from ..abc import DataStore, EventSource from ..context import current_worker, job_info from ..enums import JobOutcome, RunState from ..eventbrokers.local import LocalEventBroker from ..events import JobAdded, WorkerStarted, WorkerStopped from ..structures import Job, JobInfo, JobResult +from ..validators import positive_integer +@attr.define(eq=False) class Worker: """Runs jobs locally in a thread pool.""" - - _executor: ThreadPoolExecutor - _state: RunState = RunState.stopped - _wakeup_event: threading.Event - - def __init__(self, data_store: DataStore, *, max_concurrent_jobs: int = 20, - identity: Optional[str] = None, logger: Optional[Logger] = None): - self.max_concurrent_jobs = max_concurrent_jobs - self.identity = identity or f'{platform.node()}-{os.getpid()}-{id(self)}' - self.logger = logger or getLogger(__name__) - self._acquired_jobs: set[Job] = set() - self._exit_stack = ExitStack() - self._events = LocalEventBroker() - self._running_jobs: set[UUID] = set() - - if self.max_concurrent_jobs < 1: - raise ValueError('max_concurrent_jobs must be at least 1') - - self.data_store = data_store + data_store: DataStore + max_concurrent_jobs: int = attr.field(kw_only=True, validator=positive_integer, default=20) + identity: str = attr.field(kw_only=True, default=None) + logger: Optional[Logger] = attr.field(kw_only=True, default=getLogger(__name__)) + + _state: RunState = attr.field(init=False, default=RunState.stopped) + _wakeup_event: threading.Event = attr.field(init=False) + _acquired_jobs: set[Job] = attr.field(init=False, factory=set) + _events: LocalEventBroker = attr.field(init=False, factory=LocalEventBroker) + _running_jobs: set[UUID] = attr.field(init=False, factory=set) + _exit_stack: ExitStack = attr.field(init=False) + _executor: ThreadPoolExecutor = attr.field(init=False) + + def __attrs_post_init__(self) -> None: + if not self.identity: + self.identity = f'{platform.node()}-{os.getpid()}-{id(self)}' @property def events(self) -> EventSource: @@ -52,6 +53,7 @@ class Worker: def __enter__(self) -> Worker: self._state = RunState.starting self._wakeup_event = threading.Event() + self._exit_stack = ExitStack() self._exit_stack.__enter__() self._exit_stack.enter_context(self._events) |