author     Alex Grönholm <alex.gronholm@nextday.fi>  2021-09-21 18:30:31 +0300
committer  Alex Grönholm <alex.gronholm@nextday.fi>  2021-09-22 23:55:40 +0300
commit     310652119957194d3c8cac91bf6bf171b647a103 (patch)
tree       045878b164e986f2deae3bda3982473d9d03ce3c
parent     191a9663c6fd2c65f7b524c59285dec5ac747ee7 (diff)
download   apscheduler-310652119957194d3c8cac91bf6bf171b647a103.tar.gz
Refactored scheduler and worker classes to use attrs
-rw-r--r--  src/apscheduler/converters.py          10
-rw-r--r--  src/apscheduler/schedulers/async_.py   48
-rw-r--r--  src/apscheduler/schedulers/sync.py     35
-rw-r--r--  src/apscheduler/validators.py           6
-rw-r--r--  src/apscheduler/workers/async_.py      47
-rw-r--r--  src/apscheduler/workers/sync.py        40
6 files changed, 101 insertions(+), 85 deletions(-)
diff --git a/src/apscheduler/converters.py b/src/apscheduler/converters.py
index c664bc0..103ab35 100644
--- a/src/apscheduler/converters.py
+++ b/src/apscheduler/converters.py
@@ -5,6 +5,8 @@ from enum import Enum
 from typing import Optional
 from uuid import UUID
 
+from . import abc
+
 
 def as_aware_datetime(value: datetime | str) -> Optional[datetime]:
     """Convert the value from a string to a timezone aware datetime."""
@@ -41,3 +43,11 @@ def as_enum(enum_class: type[Enum]):
         return value
 
     return converter
+
+
+def as_async_datastore(value: abc.DataStore | abc.AsyncDataStore) -> abc.AsyncDataStore:
+    if isinstance(value, abc.DataStore):
+        from apscheduler.datastores.async_adapter import AsyncDataStoreAdapter
+        return AsyncDataStoreAdapter(value)
+
+    return value
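
The converter above lets any attrs field accept either data store flavor and always end up holding an asynchronous one. A minimal sketch of the behavior, assuming the bundled MemoryDataStore (a synchronous DataStore implementation):

    from apscheduler.converters import as_async_datastore
    from apscheduler.datastores.memory import MemoryDataStore

    # a synchronous store gets wrapped in AsyncDataStoreAdapter
    store = as_async_datastore(MemoryDataStore())
    # an already-asynchronous store passes through unchanged
    assert as_async_datastore(store) is store
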
diff --git a/src/apscheduler/schedulers/async_.py b/src/apscheduler/schedulers/async_.py
index a10e772..751eb2c 100644
--- a/src/apscheduler/schedulers/async_.py
+++ b/src/apscheduler/schedulers/async_.py
@@ -10,12 +10,12 @@ from typing import Any, Callable, Iterable, Mapping, Optional
 from uuid import UUID, uuid4
 
 import anyio
+import attr
 from anyio import TASK_STATUS_IGNORED, create_task_group, get_cancelled_exc_class, move_on_after
-from anyio.abc import TaskGroup
 
-from ..abc import AsyncDataStore, DataStore, EventSource, Job, Schedule, Trigger
+from ..abc import AsyncDataStore, EventSource, Job, Schedule, Trigger
 from ..context import current_scheduler
-from ..datastores.async_adapter import AsyncDataStoreAdapter
+from ..converters import as_async_datastore
 from ..datastores.memory import MemoryDataStore
 from ..enums import CoalescePolicy, ConflictPolicy, JobOutcome, RunState
 from ..eventbrokers.async_local import LocalAsyncEventBroker
@@ -30,29 +30,24 @@ _microsecond_delta = timedelta(microseconds=1)
 _zero_timedelta = timedelta()
 
 
+@attr.define(eq=False)
 class AsyncScheduler:
     """An asynchronous (AnyIO based) scheduler implementation."""
 
-    data_store: AsyncDataStore
-    _state: RunState = RunState.stopped
-    _wakeup_event: anyio.Event
-    _worker: Optional[AsyncWorker] = None
-    _task_group: Optional[TaskGroup] = None
-
-    def __init__(self, data_store: DataStore | AsyncDataStore | None = None, *,
-                 identity: Optional[str] = None, logger: Optional[Logger] = None,
-                 start_worker: bool = True):
-        self.identity = identity or f'{platform.node()}-{os.getpid()}-{id(self)}'
-        self.logger = logger or getLogger(__name__)
-        self.start_worker = start_worker
-        self._exit_stack = AsyncExitStack()
-        self._events = LocalAsyncEventBroker()
+    data_store: AsyncDataStore = attr.field(converter=as_async_datastore, factory=MemoryDataStore)
+    identity: str = attr.field(kw_only=True, default=None)
+    start_worker: bool = attr.field(kw_only=True, default=True)
+    logger: Optional[Logger] = attr.field(kw_only=True, default=getLogger(__name__))
 
-        data_store = data_store or MemoryDataStore()
-        if isinstance(data_store, DataStore):
-            self.data_store = AsyncDataStoreAdapter(data_store)
-        else:
-            self.data_store = data_store
+    _state: RunState = attr.field(init=False, default=RunState.stopped)
+    _wakeup_event: anyio.Event = attr.field(init=False)
+    _worker: Optional[AsyncWorker] = attr.field(init=False, default=None)
+    _events: LocalAsyncEventBroker = attr.field(init=False, factory=LocalAsyncEventBroker)
+    _exit_stack: AsyncExitStack = attr.field(init=False)
+
+    def __attrs_post_init__(self) -> None:
+        if not self.identity:
+            self.identity = f'{platform.node()}-{os.getpid()}-{id(self)}'
 
     @property
     def events(self) -> EventSource:
@@ -65,6 +60,7 @@ class AsyncScheduler:
     async def __aenter__(self):
         self._state = RunState.starting
         self._wakeup_event = anyio.Event()
+        self._exit_stack = AsyncExitStack()
         await self._exit_stack.__aenter__()
         await self._exit_stack.enter_async_context(self._events)
@@ -89,16 +85,16 @@ class AsyncScheduler:
             current_scheduler.reset(token)
 
         # Start the worker and return when it has signalled readiness or raised an exception
-        self._task_group = create_task_group()
-        await self._exit_stack.enter_async_context(self._task_group)
-        await self._task_group.start(self.run)
+        task_group = create_task_group()
+        await self._exit_stack.enter_async_context(task_group)
+        await task_group.start(self.run)
         return self
 
     async def __aexit__(self, exc_type, exc_val, exc_tb):
         self._state = RunState.stopping
         self._wakeup_event.set()
         await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
-        del self._task_group
+        self._state = RunState.stopped
         del self._wakeup_event
 
     async def add_schedule(
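
With the attrs-generated __init__, the data store is the only positional parameter (defaulting to a fresh MemoryDataStore via the field factory) and identity, logger and start_worker become keyword-only. A hedged usage sketch of the resulting constructor:

    import anyio
    from apscheduler.schedulers.async_ import AsyncScheduler

    async def main() -> None:
        # keyword-only fields must be passed by name
        async with AsyncScheduler(identity='demo-scheduler', start_worker=True) as scheduler:
            print(scheduler.identity)

    anyio.run(main)
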
diff --git a/src/apscheduler/schedulers/sync.py b/src/apscheduler/schedulers/sync.py
index 2af4f8e..c49359c 100644
--- a/src/apscheduler/schedulers/sync.py
+++ b/src/apscheduler/schedulers/sync.py
@@ -11,6 +11,8 @@ from logging import Logger, getLogger
 from typing import Any, Callable, Iterable, Mapping, Optional
 from uuid import UUID, uuid4
 
+import attr
+
 from ..abc import DataStore, EventSource, Trigger
 from ..context import current_scheduler
 from ..datastores.memory import MemoryDataStore
@@ -27,22 +29,24 @@ _microsecond_delta = timedelta(microseconds=1)
 _zero_timedelta = timedelta()
 
 
+@attr.define(eq=False)
 class Scheduler:
     """A synchronous scheduler implementation."""
 
-    _state: RunState = RunState.stopped
-    _wakeup_event: threading.Event
-    _worker: Optional[Worker] = None
+    data_store: DataStore = attr.field(factory=MemoryDataStore)
+    identity: str = attr.field(kw_only=True, default=None)
+    start_worker: bool = attr.field(kw_only=True, default=True)
+    logger: Optional[Logger] = attr.field(kw_only=True, default=getLogger(__name__))
 
-    def __init__(self, data_store: Optional[DataStore] = None, *, identity: Optional[str] = None,
-                 logger: Optional[Logger] = None, start_worker: bool = True):
-        self.identity = identity or f'{platform.node()}-{os.getpid()}-{id(self)}'
-        self.logger = logger or getLogger(__name__)
-        self.start_worker = start_worker
-        self.data_store = data_store or MemoryDataStore()
-        self._exit_stack = ExitStack()
-        self._executor = ThreadPoolExecutor(max_workers=1)
-        self._events = LocalEventBroker()
+    _state: RunState = attr.field(init=False, default=RunState.stopped)
+    _wakeup_event: threading.Event = attr.field(init=False)
+    _worker: Optional[Worker] = attr.field(init=False, default=None)
+    _events: LocalEventBroker = attr.field(init=False, factory=LocalEventBroker)
+    _exit_stack: ExitStack = attr.field(init=False)
+
+    def __attrs_post_init__(self) -> None:
+        if not self.identity:
+            self.identity = f'{platform.node()}-{os.getpid()}-{id(self)}'
 
     @property
     def events(self) -> EventSource:
@@ -59,6 +63,7 @@ class Scheduler:
     def __enter__(self) -> Scheduler:
         self._state = RunState.starting
         self._wakeup_event = threading.Event()
+        self._exit_stack = ExitStack()
         self._exit_stack.__enter__()
         self._exit_stack.enter_context(self._events)
@@ -85,7 +90,9 @@ class Scheduler:
         # Start the scheduler and return when it has signalled readiness or raised an exception
         start_future: Future[Event] = Future()
         with self._events.subscribe(start_future.set_result, one_shot=True):
-            run_future = self._executor.submit(self.run)
+            executor = ThreadPoolExecutor(1)
+            self._exit_stack.push(lambda exc_type, *args: executor.shutdown(wait=exc_type is None))
+            run_future = executor.submit(self.run)
             wait([start_future, run_future], return_when=FIRST_COMPLETED)
 
         if run_future.done():
@@ -96,8 +103,8 @@ class Scheduler:
     def __exit__(self, exc_type, exc_val, exc_tb):
         self._state = RunState.stopping
         self._wakeup_event.set()
-        self._executor.shutdown(wait=exc_type is None)
         self._exit_stack.__exit__(exc_type, exc_val, exc_tb)
+        self._state = RunState.stopped
         del self._wakeup_event
 
     def add_schedule(
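
Note how the long-lived _executor attribute disappears: __enter__ now creates the ThreadPoolExecutor per start and defers its shutdown to the exit stack, so a clean exit waits for the run thread while an exceptional exit abandons it. A standalone sketch of that ExitStack.push pattern (the sleep stands in for real work):

    import time
    from concurrent.futures import ThreadPoolExecutor
    from contextlib import ExitStack

    with ExitStack() as stack:
        executor = ThreadPoolExecutor(1)
        # exit callback: wait for pending work only when no exception is in flight
        stack.push(lambda exc_type, *args: executor.shutdown(wait=exc_type is None))
        executor.submit(time.sleep, 0.1)
    # on this clean exit, shutdown(wait=True) has already joined the worker thread
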
diff --git a/src/apscheduler/validators.py b/src/apscheduler/validators.py
index c71f73a..95c3747 100644
--- a/src/apscheduler/validators.py
+++ b/src/apscheduler/validators.py
@@ -4,6 +4,7 @@ import sys
 from datetime import date, datetime, timedelta, timezone, tzinfo
 from typing import Any, Optional
 
+import attr
 from attr import Attribute
 from tzlocal import get_localzone
@@ -163,3 +164,8 @@ def require_state_version(trigger: Trigger, state: dict[str, Any], max_version:
         )
     except KeyError as exc:
         raise DeserializationError('Missing "version" key in the serialized state') from exc
+
+
+def positive_integer(inst, field: attr.Attribute, value) -> None:
+    if value <= 0:
+        raise ValueError(f'{field.name} must be a positive integer')
diff --git a/src/apscheduler/workers/async_.py b/src/apscheduler/workers/async_.py
index e037a05..723b004 100644
--- a/src/apscheduler/workers/async_.py
+++ b/src/apscheduler/workers/async_.py
@@ -10,44 +10,38 @@ from typing import Callable, Optional
 from uuid import UUID
 
 import anyio
+import attr
 from anyio import TASK_STATUS_IGNORED, create_task_group, get_cancelled_exc_class, move_on_after
 from anyio.abc import CancelScope
 
-from ..abc import AsyncDataStore, DataStore, EventSource, Job
+from ..abc import AsyncDataStore, EventSource, Job
 from ..context import current_worker, job_info
-from ..datastores.async_adapter import AsyncDataStoreAdapter
+from ..converters import as_async_datastore
 from ..enums import JobOutcome, RunState
 from ..eventbrokers.async_local import LocalAsyncEventBroker
 from ..events import JobAdded, WorkerStarted, WorkerStopped
 from ..structures import JobInfo, JobResult
+from ..validators import positive_integer
 
 
+@attr.define(eq=False)
 class AsyncWorker:
     """Runs jobs locally in a task group."""
-
-    _stop_event: Optional[anyio.Event] = None
-    _state: RunState = RunState.stopped
-    _acquire_cancel_scope: Optional[CancelScope] = None
-    _wakeup_event: anyio.Event
-
-    def __init__(self, data_store: DataStore | AsyncDataStore, *,
-                 max_concurrent_jobs: int = 100, identity: Optional[str] = None,
-                 logger: Optional[Logger] = None):
-        self.max_concurrent_jobs = max_concurrent_jobs
-        self.identity = identity or f'{platform.node()}-{os.getpid()}-{id(self)}'
-        self.logger = logger or getLogger(__name__)
-        self._acquired_jobs: set[Job] = set()
-        self._exit_stack = AsyncExitStack()
-        self._events = LocalAsyncEventBroker()
-        self._running_jobs: set[UUID] = set()
-
-        if self.max_concurrent_jobs < 1:
-            raise ValueError('max_concurrent_jobs must be at least 1')
-
-        if isinstance(data_store, DataStore):
-            self.data_store = AsyncDataStoreAdapter(data_store)
-        else:
-            self.data_store = data_store
+    data_store: AsyncDataStore = attr.field(converter=as_async_datastore)
+    max_concurrent_jobs: int = attr.field(kw_only=True, validator=positive_integer, default=100)
+    identity: str = attr.field(kw_only=True, default=None)
+    logger: Optional[Logger] = attr.field(kw_only=True, default=getLogger(__name__))
+
+    _state: RunState = attr.field(init=False, default=RunState.stopped)
+    _wakeup_event: anyio.Event = attr.field(init=False, factory=anyio.Event)
+    _acquired_jobs: set[Job] = attr.field(init=False, factory=set)
+    _events: LocalAsyncEventBroker = attr.field(init=False, factory=LocalAsyncEventBroker)
+    _running_jobs: set[UUID] = attr.field(init=False, factory=set)
+    _exit_stack: AsyncExitStack = attr.field(init=False)
+
+    def __attrs_post_init__(self) -> None:
+        if not self.identity:
+            self.identity = f'{platform.node()}-{os.getpid()}-{id(self)}'
 
     @property
     def events(self) -> EventSource:
@@ -60,6 +54,7 @@ class AsyncWorker:
     async def __aenter__(self):
         self._state = RunState.starting
         self._wakeup_event = anyio.Event()
+        self._exit_stack = AsyncExitStack()
         await self._exit_stack.__aenter__()
         await self._exit_stack.enter_async_context(self._events)
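
Construction-time checks on the worker now come from the field definitions themselves: the converter adapts synchronous stores and positive_integer rejects a nonsensical job limit. A hedged sketch, assuming the bundled MemoryDataStore:

    import anyio
    from apscheduler.datastores.memory import MemoryDataStore
    from apscheduler.workers.async_ import AsyncWorker

    async def main() -> None:
        # the sync store is wrapped via as_async_datastore;
        # max_concurrent_jobs=0 would raise ValueError from positive_integer
        async with AsyncWorker(MemoryDataStore(), max_concurrent_jobs=50) as worker:
            print(worker.identity)

    anyio.run(main)
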
diff --git a/src/apscheduler/workers/sync.py b/src/apscheduler/workers/sync.py
index 0afa8c7..4e3eef4 100644
--- a/src/apscheduler/workers/sync.py
+++ b/src/apscheduler/workers/sync.py
@@ -11,35 +11,36 @@ from logging import Logger, getLogger
 from typing import Callable, Optional
 from uuid import UUID
 
+import attr
+
 from ..abc import DataStore, EventSource
 from ..context import current_worker, job_info
 from ..enums import JobOutcome, RunState
 from ..eventbrokers.local import LocalEventBroker
 from ..events import JobAdded, WorkerStarted, WorkerStopped
 from ..structures import Job, JobInfo, JobResult
+from ..validators import positive_integer
 
 
+@attr.define(eq=False)
 class Worker:
     """Runs jobs locally in a thread pool."""
-
-    _executor: ThreadPoolExecutor
-    _state: RunState = RunState.stopped
-    _wakeup_event: threading.Event
-
-    def __init__(self, data_store: DataStore, *, max_concurrent_jobs: int = 20,
-                 identity: Optional[str] = None, logger: Optional[Logger] = None):
-        self.max_concurrent_jobs = max_concurrent_jobs
-        self.identity = identity or f'{platform.node()}-{os.getpid()}-{id(self)}'
-        self.logger = logger or getLogger(__name__)
-        self._acquired_jobs: set[Job] = set()
-        self._exit_stack = ExitStack()
-        self._events = LocalEventBroker()
-        self._running_jobs: set[UUID] = set()
-
-        if self.max_concurrent_jobs < 1:
-            raise ValueError('max_concurrent_jobs must be at least 1')
-
-        self.data_store = data_store
+    data_store: DataStore
+    max_concurrent_jobs: int = attr.field(kw_only=True, validator=positive_integer, default=20)
+    identity: str = attr.field(kw_only=True, default=None)
+    logger: Optional[Logger] = attr.field(kw_only=True, default=getLogger(__name__))
+
+    _state: RunState = attr.field(init=False, default=RunState.stopped)
+    _wakeup_event: threading.Event = attr.field(init=False)
+    _acquired_jobs: set[Job] = attr.field(init=False, factory=set)
+    _events: LocalEventBroker = attr.field(init=False, factory=LocalEventBroker)
+    _running_jobs: set[UUID] = attr.field(init=False, factory=set)
+    _exit_stack: ExitStack = attr.field(init=False)
+    _executor: ThreadPoolExecutor = attr.field(init=False)
+
+    def __attrs_post_init__(self) -> None:
+        if not self.identity:
+            self.identity = f'{platform.node()}-{os.getpid()}-{id(self)}'
 
     @property
     def events(self) -> EventSource:
@@ -52,6 +53,7 @@ class Worker:
     def __enter__(self) -> Worker:
         self._state = RunState.starting
         self._wakeup_event = threading.Event()
+        self._exit_stack = ExitStack()
         self._exit_stack.__enter__()
         self._exit_stack.enter_context(self._events)
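
All of the refactored classes are declared with @attr.define(eq=False). The flag matters: by default attrs would generate __eq__ and set __hash__ to None, making these mutable objects unusable in sets and as dict keys, while eq=False preserves identity-based equality and hashing. A minimal illustration (the Service class is hypothetical):

    import attr

    @attr.define(eq=False)
    class Service:
        name: str = 'worker'

    # identity semantics survive: equal-looking instances stay distinct and hashable
    a, b = Service(), Service()
    assert a != b
    assert len({a, b}) == 2
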