summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Grönholm <alex.gronholm@nextday.fi>2021-09-05 23:12:34 +0300
committerAlex Grönholm <alex.gronholm@nextday.fi>2021-09-06 01:39:07 +0300
commit6fed43f29bfa7929fcaf5e549ce34819ba7e3702 (patch)
treef3c8e7675a03354cb30205bd6a21e7ae231c83a6
parent7147cd422bfc6280a56bb4335d65935e942ab9e3 (diff)
downloadapscheduler-6fed43f29bfa7929fcaf5e549ce34819ba7e3702.tar.gz
Implemented task accounting
The maximum number of concurrent jobs for a given task is now enforced if set.
-rw-r--r--src/apscheduler/abc.py86
-rw-r--r--src/apscheduler/datastores/async_/sqlalchemy.py132
-rw-r--r--src/apscheduler/datastores/async_/sync_adapter.py18
-rw-r--r--src/apscheduler/datastores/sync/memory.py81
-rw-r--r--src/apscheduler/datastores/sync/mongodb.py125
-rw-r--r--src/apscheduler/datastores/sync/sqlalchemy.py143
-rw-r--r--src/apscheduler/enums.py1
-rw-r--r--src/apscheduler/events.py10
-rw-r--r--src/apscheduler/exceptions.py5
-rw-r--r--src/apscheduler/schedulers/async_.py65
-rw-r--r--src/apscheduler/schedulers/sync.py41
-rw-r--r--src/apscheduler/structures.py46
-rw-r--r--src/apscheduler/workers/async_.py40
-rw-r--r--src/apscheduler/workers/sync.py17
-rw-r--r--tests/test_datastores.py139
-rw-r--r--tests/test_schedulers.py16
-rw-r--r--tests/test_workers.py47
17 files changed, 791 insertions, 221 deletions
diff --git a/src/apscheduler/abc.py b/src/apscheduler/abc.py
index ec77de5..19c5a02 100644
--- a/src/apscheduler/abc.py
+++ b/src/apscheduler/abc.py
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, List, Optio
from uuid import UUID
from .policies import ConflictPolicy
-from .structures import Job, JobResult, Schedule
+from .structures import Job, JobResult, Schedule, Task
if TYPE_CHECKING:
from . import events
@@ -94,6 +94,44 @@ class DataStore(EventSource):
pass
@abstractmethod
+ def add_task(self, task: Task) -> None:
+ """
+ Add the given task to the store.
+
+ If a task with the same ID already exists, it replaces the old one but does NOT affect
+ task accounting (# of running jobs).
+
+ :param task: the task to be added
+ """
+
+ @abstractmethod
+ def remove_task(self, task_id: str) -> None:
+ """
+ Remove the task with the given ID.
+
+ :param task_id: ID of the task to be removed
+ :raises TaskLookupError: if no matching task was found
+ """
+
+ @abstractmethod
+ def get_task(self, task_id: str) -> Task:
+ """
+ Get an existing task definition.
+
+ :param task_id: ID of the task to be returned
+ :return: the matching task
+ :raises TaskLookupError: if no matching task was found
+ """
+
+ @abstractmethod
+ def get_tasks(self) -> List[Task]:
+ """
+ Get all the tasks in this store.
+
+ :return: a list of tasks, sorted by ID
+ """
+
+ @abstractmethod
def get_schedules(self, ids: Optional[Set[str]] = None) -> List[Schedule]:
"""
Get schedules from the data store.
@@ -180,12 +218,12 @@ class DataStore(EventSource):
"""
@abstractmethod
- def release_job(self, worker_id: str, job_id: UUID, result: Optional[JobResult]) -> None:
+ def release_job(self, worker_id: str, job: Job, result: Optional[JobResult]) -> None:
"""
Release the claim on the given job and record the result.
:param worker_id: unique identifier of the worker
- :param job_id: identifier of the job
+ :param job: the job to be released
:param result: the result of the job (or ``None`` to discard the job)
"""
@@ -209,6 +247,44 @@ class AsyncDataStore(EventSource):
pass
@abstractmethod
+ async def add_task(self, task: Task) -> None:
+ """
+ Add the given task to the store.
+
+ If a task with the same ID already exists, it replaces the old one but does NOT affect
+ task accounting (# of running jobs).
+
+ :param task: the task to be added
+ """
+
+ @abstractmethod
+ async def remove_task(self, task_id: str) -> None:
+ """
+ Remove the task with the given ID.
+
+ :param task_id: ID of the task to be removed
+ :raises TaskLookupError: if no matching task was found
+ """
+
+ @abstractmethod
+ async def get_task(self, task_id: str) -> Task:
+ """
+ Get an existing task definition.
+
+ :param task_id: ID of the task to be returned
+ :return: the matching task
+ :raises TaskLookupError: if no matching task was found
+ """
+
+ @abstractmethod
+ async def get_tasks(self) -> List[Task]:
+ """
+ Get all the tasks in this store.
+
+ :return: a list of tasks, sorted by ID
+ """
+
+ @abstractmethod
async def get_schedules(self, ids: Optional[Set[str]] = None) -> List[Schedule]:
"""
Get schedules from the data store.
@@ -295,12 +371,12 @@ class AsyncDataStore(EventSource):
"""
@abstractmethod
- async def release_job(self, worker_id: str, job_id: UUID, result: Optional[JobResult]) -> None:
+ async def release_job(self, worker_id: str, job: Job, result: Optional[JobResult]) -> None:
"""
Release the claim on the given job and record the result.
:param worker_id: unique identifier of the worker
- :param job_id: identifier of the job
+ :param job: the job to be released
:param result: the result of the job (or ``None`` to discard the job)
"""
diff --git a/src/apscheduler/datastores/async_/sqlalchemy.py b/src/apscheduler/datastores/async_/sqlalchemy.py
index caf832c..00e9efc 100644
--- a/src/apscheduler/datastores/async_/sqlalchemy.py
+++ b/src/apscheduler/datastores/async_/sqlalchemy.py
@@ -2,6 +2,7 @@ from __future__ import annotations
import json
import logging
+from collections import defaultdict
from contextlib import AsyncExitStack, closing
from datetime import datetime, timedelta, timezone
from json import JSONDecodeError
@@ -18,17 +19,19 @@ from sqlalchemy.exc import CompileError, IntegrityError
from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine
from sqlalchemy.ext.asyncio.engine import AsyncEngine
from sqlalchemy.sql.ddl import DropTable
-from sqlalchemy.sql.elements import BindParameter
+from sqlalchemy.sql.elements import BindParameter, literal
from ... import events as events_module
from ...abc import AsyncDataStore, Job, Schedule, Serializer
from ...events import (
AsyncEventHub, DataStoreEvent, Event, JobAdded, JobDeserializationFailed, ScheduleAdded,
- ScheduleDeserializationFailed, ScheduleRemoved, ScheduleUpdated, SubscriptionToken)
-from ...exceptions import ConflictingIdError, SerializationError
+ ScheduleDeserializationFailed, ScheduleRemoved, ScheduleUpdated, SubscriptionToken, TaskAdded,
+ TaskRemoved)
+from ...exceptions import ConflictingIdError, SerializationError, TaskLookupError
+from ...marshalling import callable_to_ref
from ...policies import ConflictPolicy
from ...serializers.pickle import PickleSerializer
-from ...structures import JobResult
+from ...structures import JobResult, Task
from ...util import reentrant
logger = logging.getLogger(__name__)
@@ -78,6 +81,7 @@ class SQLAlchemyDataStore(AsyncDataStore):
# Generate the table definitions
self._metadata = self.get_table_definitions()
self.t_metadata = self._metadata.tables['metadata']
+ self.t_tasks = self._metadata.tables['tasks']
self.t_schedules = self._metadata.tables['schedules']
self.t_jobs = self._metadata.tables['jobs']
self.t_job_results = self._metadata.tables['job_results']
@@ -153,7 +157,11 @@ class SQLAlchemyDataStore(AsyncDataStore):
'tasks',
metadata,
Column('id', Unicode(500), primary_key=True),
- Column('serialized_data', LargeBinary, nullable=False)
+ Column('func', Unicode(500), nullable=False),
+ Column('state', LargeBinary),
+ Column('max_running_jobs', Integer),
+ Column('misfire_grace_time', Unicode(16)),
+ Column('running_jobs', Integer, nullable=False, server_default=literal(0))
)
Table(
'schedules',
@@ -259,6 +267,55 @@ class SQLAlchemyDataStore(AsyncDataStore):
def unsubscribe(self, token: SubscriptionToken) -> None:
self._events.unsubscribe(token)
+ async def add_task(self, task: Task) -> None:
+ insert = self.t_tasks.insert().\
+ values(id=task.id, func=callable_to_ref(task.func),
+ max_running_jobs=task.max_running_jobs,
+ misfire_grace_time=task.misfire_grace_time)
+ try:
+ async with self.engine.begin() as conn:
+ await conn.execute(insert)
+ except IntegrityError:
+ update = self.t_tasks.update().\
+ values(func=callable_to_ref(task.func), max_running_jobs=task.max_running_jobs,
+ misfire_grace_time=task.misfire_grace_time).\
+ where(self.t_tasks.c.id == task.id)
+ async with self.engine.begin() as conn:
+ await conn.execute(update)
+
+ self._events.publish(TaskAdded(task_id=task.id))
+
+ async def remove_task(self, task_id: str) -> None:
+ delete = self.t_tasks.delete().where(self.t_tasks.c.id == task_id)
+ async with self.engine.begin() as conn:
+ result = await conn.execute(delete)
+ if result.rowcount == 0:
+ raise TaskLookupError(task_id)
+ else:
+ self._events.publish(TaskRemoved(task_id=task_id))
+
+ async def get_task(self, task_id: str) -> Task:
+ query = select([self.t_tasks.c.id, self.t_tasks.c.func, self.t_tasks.c.max_running_jobs,
+ self.t_tasks.c.state, self.t_tasks.c.misfire_grace_time]).\
+ where(self.t_tasks.c.id == task_id)
+ async with self.engine.begin() as conn:
+ result = await conn.execute(query)
+ row = result.fetchone()
+
+ if row:
+ return Task.unmarshal(self.serializer, row._asdict())
+ else:
+ raise TaskLookupError(task_id)
+
+ async def get_tasks(self) -> List[Task]:
+ query = select([self.t_tasks.c.id, self.t_tasks.c.func, self.t_tasks.c.max_running_jobs,
+ self.t_tasks.c.state, self.t_tasks.c.misfire_grace_time]).\
+ order_by(self.t_tasks.c.id)
+ async with self.engine.begin() as conn:
+ result = await conn.execute(query)
+ tasks = [Task.unmarshal(self.serializer, row._asdict()) for row in result]
+ return tasks
+
async def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None:
event: Event
serialized_data = self.serializer.serialize(schedule)
@@ -436,30 +493,79 @@ class SQLAlchemyDataStore(AsyncDataStore):
now = datetime.now(timezone.utc)
acquired_until = now + timedelta(seconds=self.lock_expiration_delay)
query = select([self.t_jobs.c.id, self.t_jobs.c.serialized_data]).\
+ join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id).\
where(or_(self.t_jobs.c.acquired_until.is_(None),
self.t_jobs.c.acquired_until < now)).\
order_by(self.t_jobs.c.created_at).\
limit(limit)
- serialized_jobs: Dict[str, bytes] = {row[0]: row[1]
- for row in await conn.execute(query)}
- if serialized_jobs:
+ result = await conn.execute(query)
+ if not result:
+ return []
+
+ # Mark the jobs as acquired by this worker
+ jobs = self._deserialize_jobs(result)
+ task_ids: Set[str] = {job.task_id for job in jobs}
+
+ # Retrieve the limits
+ query = select([self.t_tasks.c.id,
+ self.t_tasks.c.max_running_jobs - self.t_tasks.c.running_jobs]).\
+ where(self.t_tasks.c.max_running_jobs.isnot(None),
+ self.t_tasks.c.id.in_(task_ids))
+ result = await conn.execute(query)
+ job_slots_left = dict(result.fetchall())
+
+ # Filter out jobs that don't have free slots
+ acquired_jobs: List[Job] = []
+ increments: Dict[str, int] = defaultdict(lambda: 0)
+ for job in jobs:
+ # Don't acquire the job if there are no free slots left
+ slots_left = job_slots_left.get(job.task_id)
+ if slots_left == 0:
+ continue
+ elif slots_left is not None:
+ job_slots_left[job.task_id] -= 1
+
+ acquired_jobs.append(job)
+ increments[job.task_id] += 1
+
+ if acquired_jobs:
+ # Mark the acquired jobs as acquired by this worker
+ acquired_job_ids = [job.id.hex for job in acquired_jobs]
update = self.t_jobs.update().\
values(acquired_by=worker_id, acquired_until=acquired_until).\
- where(self.t_jobs.c.id.in_(serialized_jobs))
+ where(self.t_jobs.c.id.in_(acquired_job_ids))
await conn.execute(update)
- return self._deserialize_jobs(serialized_jobs.items())
+ # Increment the running job counters on each task
+ p_id = bindparam('p_id')
+ p_increment = bindparam('p_increment')
+ params = [{'p_id': task_id, 'p_increment': increment}
+ for task_id, increment in increments.items()]
+ update = self.t_tasks.update().\
+ values(running_jobs=self.t_tasks.c.running_jobs + p_increment).\
+ where(self.t_tasks.c.id == p_id)
+ await conn.execute(update, params)
+
+ return acquired_jobs
- async def release_job(self, worker_id: str, job_id: UUID, result: Optional[JobResult]) -> None:
+ async def release_job(self, worker_id: str, job: Job, result: Optional[JobResult]) -> None:
async with self.engine.begin() as conn:
+ # Insert the job result
now = datetime.now(timezone.utc)
serialized_data = self.serializer.serialize(result)
insert = self.t_job_results.insert().\
- values(job_id=job_id.hex, finished_at=now, serialized_data=serialized_data)
+ values(job_id=job.id.hex, finished_at=now, serialized_data=serialized_data)
await conn.execute(insert)
- delete = self.t_jobs.delete().where(self.t_jobs.c.id == job_id.hex)
+ # Decrement the running jobs counter
+ update = self.t_tasks.update().\
+ values(running_jobs=self.t_tasks.c.running_jobs - 1).\
+ where(self.t_tasks.c.id == job.task_id)
+ await conn.execute(update)
+
+ # Delete the job
+ delete = self.t_jobs.delete().where(self.t_jobs.c.id == job.id.hex)
await conn.execute(delete)
async def get_job_result(self, job_id: UUID) -> Optional[JobResult]:
diff --git a/src/apscheduler/datastores/async_/sync_adapter.py b/src/apscheduler/datastores/async_/sync_adapter.py
index 2bb691e..4f81196 100644
--- a/src/apscheduler/datastores/async_/sync_adapter.py
+++ b/src/apscheduler/datastores/async_/sync_adapter.py
@@ -13,7 +13,7 @@ from ... import events
from ...abc import AsyncDataStore, DataStore
from ...events import Event, SubscriptionToken
from ...policies import ConflictPolicy
-from ...structures import Job, JobResult, Schedule
+from ...structures import Job, JobResult, Schedule, Task
from ...util import reentrant
@@ -33,6 +33,18 @@ class AsyncDataStoreAdapter(AsyncDataStore):
await to_thread.run_sync(self.original.__exit__, exc_type, exc_val, exc_tb)
await self._portal.__aexit__(exc_type, exc_val, exc_tb)
+ async def add_task(self, task: Task) -> None:
+ await to_thread.run_sync(self.original.add_task, task)
+
+ async def remove_task(self, task_id: str) -> None:
+ await to_thread.run_sync(self.original.remove_task, task_id)
+
+ async def get_task(self, task_id: str) -> Task:
+ return await to_thread.run_sync(self.original.get_task, task_id)
+
+ async def get_tasks(self) -> List[Task]:
+ return await to_thread.run_sync(self.original.get_tasks)
+
async def get_schedules(self, ids: Optional[Set[str]] = None) -> List[Schedule]:
return await to_thread.run_sync(self.original.get_schedules, ids)
@@ -60,8 +72,8 @@ class AsyncDataStoreAdapter(AsyncDataStore):
async def acquire_jobs(self, worker_id: str, limit: Optional[int] = None) -> List[Job]:
return await to_thread.run_sync(self.original.acquire_jobs, worker_id, limit)
- async def release_job(self, worker_id: str, job_id: UUID, result: Optional[JobResult]) -> None:
- await to_thread.run_sync(self.original.release_job, worker_id, job_id, result)
+ async def release_job(self, worker_id: str, job: Job, result: Optional[JobResult]) -> None:
+ await to_thread.run_sync(self.original.release_job, worker_id, job, result)
async def get_job_result(self, job_id: UUID) -> Optional[JobResult]:
return await to_thread.run_sync(self.original.get_job_result, job_id)
diff --git a/src/apscheduler/datastores/sync/memory.py b/src/apscheduler/datastores/sync/memory.py
index a108142..4239dc7 100644
--- a/src/apscheduler/datastores/sync/memory.py
+++ b/src/apscheduler/datastores/sync/memory.py
@@ -12,16 +12,27 @@ import attr
from ... import events
from ...abc import DataStore, Job, Schedule
from ...events import (
- EventHub, JobAdded, ScheduleAdded, ScheduleRemoved, ScheduleUpdated, SubscriptionToken)
-from ...exceptions import ConflictingIdError
+ EventHub, JobAdded, ScheduleAdded, ScheduleRemoved, ScheduleUpdated, SubscriptionToken,
+ TaskAdded, TaskRemoved)
+from ...exceptions import ConflictingIdError, TaskLookupError
from ...policies import ConflictPolicy
-from ...structures import JobResult
+from ...structures import JobResult, Task
from ...util import reentrant
max_datetime = datetime(MAXYEAR, 12, 31, 23, 59, 59, 999999, tzinfo=timezone.utc)
@attr.define
+class TaskState:
+ task: Task
+ running_jobs: int = 0
+ saved_state: Any = None
+
+ def __eq__(self, other):
+ return self.task.id == other.task.id
+
+
+@attr.define
class ScheduleState:
schedule: Schedule
next_fire_time: Optional[datetime] = attr.field(init=False, eq=False)
@@ -65,6 +76,7 @@ class JobState:
class MemoryDataStore(DataStore):
lock_expiration_delay: float = 30
_events: EventHub = attr.Factory(EventHub)
+ _tasks: Dict[str, TaskState] = attr.Factory(dict)
_schedules: List[ScheduleState] = attr.Factory(list)
_schedules_by_id: Dict[str, ScheduleState] = attr.Factory(dict)
_schedules_by_task_id: Dict[str, Set[ScheduleState]] = attr.Factory(partial(defaultdict, set))
@@ -101,6 +113,27 @@ class MemoryDataStore(DataStore):
return [state.schedule for state in self._schedules
if ids is None or state.schedule.id in ids]
+ def add_task(self, task: Task) -> None:
+ self._tasks[task.id] = TaskState(task)
+ self._events.publish(TaskAdded(task_id=task.id))
+
+ def remove_task(self, task_id: str) -> None:
+ try:
+ del self._tasks[task_id]
+ except KeyError:
+ raise TaskLookupError(task_id) from None
+
+ self._events.publish(TaskRemoved(task_id=task_id))
+
+ def get_task(self, task_id: str) -> Task:
+ try:
+ return self._tasks[task_id].task
+ except KeyError:
+ raise TaskLookupError(task_id) from None
+
+ def get_tasks(self) -> List[Task]:
+ return sorted((state.task for state in self._tasks.values()), key=lambda task: task.id)
+
def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None:
old_state = self._schedules_by_id.get(schedule.id)
if old_state is not None:
@@ -201,27 +234,49 @@ class MemoryDataStore(DataStore):
def acquire_jobs(self, worker_id: str, limit: Optional[int] = None) -> List[Job]:
now = datetime.now(timezone.utc)
jobs: List[Job] = []
- for state in self._jobs:
- if state.acquired_by is not None and state.acquired_until >= now:
+ for index, job_state in enumerate(self._jobs):
+ task_state = self._tasks[job_state.job.task_id]
+
+ # Skip already acquired jobs (unless the acquisition lock has expired)
+ if job_state.acquired_by is not None:
+ if job_state.acquired_until >= now:
+ continue
+ else:
+ task_state.running_jobs -= 1
+
+ # Check if the task allows one more job to be started
+ if (task_state.task.max_running_jobs is not None
+ and task_state.running_jobs >= task_state.task.max_running_jobs):
continue
- jobs.append(state.job)
- state.acquired_by = worker_id
- state.acquired_until = now + timedelta(seconds=self.lock_expiration_delay)
+ # Mark the job as acquired by this worker
+ jobs.append(job_state.job)
+ job_state.acquired_by = worker_id
+ job_state.acquired_until = now + timedelta(seconds=self.lock_expiration_delay)
+
+ # Increment the number of running jobs for this task
+ task_state.running_jobs += 1
+
+ # Exit the loop if enough jobs have been acquired
if len(jobs) == limit:
break
return jobs
- def release_job(self, worker_id: str, job_id: UUID, result: Optional[JobResult]) -> None:
+ def release_job(self, worker_id: str, job: Job, result: Optional[JobResult]) -> None:
# Delete the job
- state = self._jobs_by_id.pop(job_id)
- self._jobs_by_task_id[state.job.task_id].remove(state)
- index = self._find_job_index(state)
+ job_state = self._jobs_by_id.pop(job.id)
+ self._jobs_by_task_id[job.task_id].remove(job_state)
+ index = self._find_job_index(job_state)
del self._jobs[index]
+ # Decrement the number of running jobs for this task
+ task_state = self._tasks.get(job.task_id)
+ if task_state is not None:
+ task_state.running_jobs -= 1
+
# Record the result
- self._job_results[job_id] = result
+ self._job_results[job.id] = result
def get_job_result(self, job_id: UUID) -> Optional[JobResult]:
return self._job_results.pop(job_id, None)
diff --git a/src/apscheduler/datastores/sync/mongodb.py b/src/apscheduler/datastores/sync/mongodb.py
index 4d5c0f2..fcf7d4d 100644
--- a/src/apscheduler/datastores/sync/mongodb.py
+++ b/src/apscheduler/datastores/sync/mongodb.py
@@ -1,11 +1,14 @@
from __future__ import annotations
import logging
+from collections import defaultdict
from contextlib import ExitStack
from datetime import datetime, timezone
-from typing import Any, Callable, Iterable, List, Optional, Set, Tuple, Type
+from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Set, Tuple, Type
from uuid import UUID
+import attr
+import pymongo
from pymongo import ASCENDING, DeleteOne, MongoClient, UpdateOne
from pymongo.collection import Collection
from pymongo.errors import DuplicateKeyError
@@ -14,19 +17,25 @@ from ... import events
from ...abc import DataStore, Job, Schedule, Serializer
from ...events import (
DataStoreEvent, EventHub, JobAdded, ScheduleAdded, ScheduleRemoved, ScheduleUpdated,
- SubscriptionToken)
-from ...exceptions import ConflictingIdError, DeserializationError, SerializationError
+ SubscriptionToken, TaskAdded, TaskRemoved)
+from ...exceptions import (
+ ConflictingIdError, DeserializationError, SerializationError, TaskLookupError)
from ...policies import ConflictPolicy
from ...serializers.pickle import PickleSerializer
-from ...structures import JobResult
+from ...structures import JobResult, Task
from ...util import reentrant
@reentrant
class MongoDBDataStore(DataStore):
+ _task_attrs: ClassVar[List[str]] = [field.name for field in attr.fields(Task)]
+ _schedule_attrs: ClassVar[List[str]] = [field.name for field in attr.fields(Schedule)]
+ _job_attrs: ClassVar[List[str]] = [field.name for field in attr.fields(Job)]
+
def __init__(self, client: MongoClient, *, serializer: Optional[Serializer] = None,
- database: str = 'apscheduler', schedules_collection: str = 'schedules',
- jobs_collection: str = 'jobs', job_results_collection: str = 'job_results',
+ database: str = 'apscheduler', tasks_collection: str = 'tasks',
+ schedules_collection: str = 'schedules', jobs_collection: str = 'jobs',
+ job_results_collection: str = 'job_results',
lock_expiration_delay: float = 30, start_from_scratch: bool = False):
super().__init__()
if not client.delegate.codec_options.tz_aware:
@@ -36,7 +45,9 @@ class MongoDBDataStore(DataStore):
self.serializer = serializer or PickleSerializer()
self.lock_expiration_delay = lock_expiration_delay
self.start_from_scratch = start_from_scratch
+ self._local_tasks: Dict[str, Task] = {}
self._database = client[database]
+ self._tasks: Collection = self._database[tasks_collection]
self._schedules: Collection = self._database[schedules_collection]
self._jobs: Collection = self._database[jobs_collection]
self._jobs_results: Collection = self._database[job_results_collection]
@@ -59,6 +70,7 @@ class MongoDBDataStore(DataStore):
self._exit_stack.enter_context(self._events)
if self.start_from_scratch:
+ self._tasks.delete_many({})
self._schedules.delete_many({})
self._jobs.delete_many({})
self._jobs_results.delete_many({})
@@ -80,6 +92,44 @@ class MongoDBDataStore(DataStore):
def unsubscribe(self, token: events.SubscriptionToken) -> None:
self._events.unsubscribe(token)
+ def add_task(self, task: Task) -> None:
+ self._tasks.find_one_and_update(
+ {'_id': task.id},
+ {'$set': task.marshal(self.serializer),
+ '$setOnInsert': {'running_jobs': 0}},
+ upsert=True
+ )
+ self._local_tasks[task.id] = task
+ self._events.publish(TaskAdded(task_id=task.id))
+
+ def remove_task(self, task_id: str) -> None:
+ if not self._tasks.find_one_and_delete({'_id': task_id}):
+ raise TaskLookupError(task_id)
+
+ self._local_tasks.pop(task_id, None)
+ self._events.publish(TaskRemoved(task_id=task_id))
+
+ def get_task(self, task_id: str) -> Task:
+ try:
+ return self._local_tasks[task_id]
+ except KeyError:
+ document = self._tasks.find_one({'_id': task_id}, projection=self._task_attrs)
+ if not document:
+ raise TaskLookupError(task_id)
+
+ document['id'] = document.pop('_id')
+ task = self._local_tasks[task_id] = Task.unmarshal(self.serializer, document)
+ return task
+
+ def get_tasks(self) -> List[Task]:
+ tasks: List[Task] = []
+ for document in self._tasks.find(projection=self._task_attrs,
+ sort=[('_id', pymongo.ASCENDING)]):
+ document['id'] = document.pop('_id')
+ tasks.append(Task.unmarshal(self.serializer, document))
+
+ return tasks
+
def get_schedules(self, ids: Optional[Set[str]] = None) -> List[Schedule]:
schedules: List[Schedule] = []
filters = {'_id': {'$in': list(ids)}} if ids is not None else {}
@@ -238,45 +288,82 @@ class MongoDBDataStore(DataStore):
return jobs
def acquire_jobs(self, worker_id: str, limit: Optional[int] = None) -> List[Job]:
- jobs: List[Job] = []
with self.client.start_session() as session:
cursor = self._jobs.find(
{'$or': [{'acquired_until': {'$exists': False}},
{'acquired_until': {'$lt': datetime.now(timezone.utc)}}]
},
- projection=['serialized_data'],
+ projection=['task_id', 'serialized_data'],
sort=[('created_at', ASCENDING)],
limit=limit,
session=session
)
- for document in cursor:
+ documents = list(cursor)
+
+ # Retrieve the limits
+ task_ids: Set[str] = {document['task_id'] for document in documents}
+ task_limits = self._tasks.find(
+ {'_id': {'$in': list(task_ids)}, 'max_running_jobs': {'$ne': None}},
+ projection=['max_running_jobs', 'running_jobs'],
+ session=session
+ )
+ job_slots_left = {doc['_id']: doc['max_running_jobs'] - doc['running_jobs']
+ for doc in task_limits}
+
+ # Filter out jobs that don't have free slots
+ acquired_jobs: List[Job] = []
+ increments: Dict[str, int] = defaultdict(lambda: 0)
+ for document in documents:
job = self.serializer.deserialize(document['serialized_data'])
- jobs.append(job)
- if jobs:
+ # Don't acquire the job if there are no free slots left
+ slots_left = job_slots_left.get(job.task_id)
+ if slots_left == 0:
+ continue
+ elif slots_left is not None:
+ job_slots_left[job.task_id] -= 1
+
+ acquired_jobs.append(job)
+ increments[job.task_id] += 1
+
+ if acquired_jobs:
now = datetime.now(timezone.utc)
acquired_until = datetime.fromtimestamp(
now.timestamp() + self.lock_expiration_delay, timezone.utc)
- filters = {'_id': {'$in': [job.id for job in jobs]}}
+ filters = {'_id': {'$in': [job.id for job in acquired_jobs]}}
update = {'$set': {'acquired_by': worker_id,
'acquired_until': acquired_until}}
self._jobs.update_many(filters, update, session=session)
- return jobs
+ # Increment the running job counters on each task
+ for task_id, increment in increments.items():
+ self._tasks.find_one_and_update(
+ {'_id': task_id},
+ {'$inc': {'running_jobs': increment}},
+ session=session
+ )
+
+ return acquired_jobs
- def release_job(self, worker_id: str, job_id: UUID, result: Optional[JobResult]) -> None:
+ def release_job(self, worker_id: str, job: Job, result: Optional[JobResult]) -> None:
with self.client.start_session() as session:
+ # Insert the job result
now = datetime.now(timezone.utc)
- serialized_result = self.serializer.serialize(result)
document = {
- '_id': job_id,
+ '_id': job.id,
'finished_at': now,
- 'serialized_data': serialized_result
+ 'serialized_data': self.serializer.serialize(result)
}
self._jobs_results.insert_one(document, session=session)
- filters = {'_id': job_id, 'acquired_by': worker_id}
- self._jobs.delete_one(filters, session=session)
+ # Decrement the running jobs counter
+ self._tasks.find_one_and_update(
+ {'_id': job.task_id},
+ {'$inc': {'running_jobs': -1}},
+ session=session
+ )
+
+ # Delete the job
+ self._jobs.delete_one({'_id': job.id}, session=session)
def get_job_result(self, job_id: UUID) -> Optional[JobResult]:
document = self._jobs_results.find_one_and_delete(
diff --git a/src/apscheduler/datastores/sync/sqlalchemy.py b/src/apscheduler/datastores/sync/sqlalchemy.py
index f1f541a..3b1d2bf 100644
--- a/src/apscheduler/datastores/sync/sqlalchemy.py
+++ b/src/apscheduler/datastores/sync/sqlalchemy.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import logging
+from collections import defaultdict
from datetime import datetime, timedelta, timezone
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
from uuid import UUID
@@ -11,16 +12,18 @@ from sqlalchemy.engine import URL
from sqlalchemy.exc import CompileError, IntegrityError
from sqlalchemy.future import Engine, create_engine
from sqlalchemy.sql.ddl import DropTable
-from sqlalchemy.sql.elements import BindParameter
+from sqlalchemy.sql.elements import BindParameter, literal
from ...abc import DataStore, Job, Schedule, Serializer
from ...events import (
Event, EventHub, JobAdded, JobDeserializationFailed, ScheduleAdded,
- ScheduleDeserializationFailed, ScheduleRemoved, ScheduleUpdated, SubscriptionToken)
-from ...exceptions import ConflictingIdError, SerializationError
+ ScheduleDeserializationFailed, ScheduleRemoved, ScheduleUpdated, SubscriptionToken, TaskAdded,
+ TaskRemoved)
+from ...exceptions import ConflictingIdError, SerializationError, TaskLookupError
+from ...marshalling import callable_to_ref
from ...policies import ConflictPolicy
from ...serializers.pickle import PickleSerializer
-from ...structures import JobResult
+from ...structures import JobResult, Task
from ...util import reentrant
logger = logging.getLogger(__name__)
@@ -45,6 +48,7 @@ class SQLAlchemyDataStore(DataStore):
# Generate the table definitions
self._metadata = self.get_table_definitions()
self.t_metadata = self._metadata.tables['metadata']
+ self.t_tasks = self._metadata.tables['tasks']
self.t_schedules = self._metadata.tables['schedules']
self.t_jobs = self._metadata.tables['jobs']
self.t_job_results = self._metadata.tables['job_results']
@@ -58,6 +62,15 @@ class SQLAlchemyDataStore(DataStore):
else:
self._supports_update_returning = True
+ # Find out if the dialect supports UPDATE...RETURNING
+ update = self.t_schedules.update().returning(self.t_schedules.c.id)
+ try:
+ update.compile(bind=self.engine)
+ except CompileError:
+ self._supports_update_returning = False
+ else:
+ self._supports_update_returning = True
+
@classmethod
def from_url(cls, url: Union[str, URL], **options) -> 'SQLAlchemyDataStore':
engine = create_engine(url)
@@ -103,7 +116,11 @@ class SQLAlchemyDataStore(DataStore):
'tasks',
metadata,
Column('id', Unicode(500), primary_key=True),
- Column('serialized_data', LargeBinary, nullable=False)
+ Column('func', Unicode(500), nullable=False),
+ Column('state', LargeBinary),
+ Column('max_running_jobs', Integer),
+ Column('misfire_grace_time', Unicode(16)),
+ Column('running_jobs', Integer, nullable=False, server_default=literal(0))
)
Table(
'schedules',
@@ -163,6 +180,55 @@ class SQLAlchemyDataStore(DataStore):
def unsubscribe(self, token: SubscriptionToken) -> None:
self._events.unsubscribe(token)
+ def add_task(self, task: Task) -> None:
+ insert = self.t_tasks.insert().\
+ values(id=task.id, func=callable_to_ref(task.func),
+ max_running_jobs=task.max_running_jobs,
+ misfire_grace_time=task.misfire_grace_time)
+ try:
+ with self.engine.begin() as conn:
+ conn.execute(insert)
+ except IntegrityError:
+ update = self.t_tasks.update().\
+ values(func=callable_to_ref(task.func), max_running_jobs=task.max_running_jobs,
+ misfire_grace_time=task.misfire_grace_time).\
+ where(self.t_tasks.c.id == task.id)
+ with self.engine.begin() as conn:
+ conn.execute(update)
+
+ self._events.publish(TaskAdded(task_id=task.id))
+
+ def remove_task(self, task_id: str) -> None:
+ delete = self.t_tasks.delete().where(self.t_tasks.c.id == task_id)
+ with self.engine.begin() as conn:
+ result = conn.execute(delete)
+ if result.rowcount == 0:
+ raise TaskLookupError(task_id)
+ else:
+ self._events.publish(TaskRemoved(task_id=task_id))
+
+ def get_task(self, task_id: str) -> Task:
+ query = select([self.t_tasks.c.id, self.t_tasks.c.func, self.t_tasks.c.max_running_jobs,
+ self.t_tasks.c.state, self.t_tasks.c.misfire_grace_time]).\
+ where(self.t_tasks.c.id == task_id)
+ with self.engine.begin() as conn:
+ result = conn.execute(query)
+ row = result.fetchone()
+
+ if row:
+ return Task.unmarshal(self.serializer, row._asdict())
+ else:
+ raise TaskLookupError(task_id)
+
+ def get_tasks(self) -> List[Task]:
+ query = select([self.t_tasks.c.id, self.t_tasks.c.func, self.t_tasks.c.max_running_jobs,
+ self.t_tasks.c.state, self.t_tasks.c.misfire_grace_time]).\
+ order_by(self.t_tasks.c.id)
+ with self.engine.begin() as conn:
+ result = conn.execute(query)
+ tasks = [Task.unmarshal(self.serializer, row._asdict()) for row in result]
+ return tasks
+
def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None:
event: Event
serialized_data = self.serializer.serialize(schedule)
@@ -339,38 +405,89 @@ class SQLAlchemyDataStore(DataStore):
now = datetime.now(timezone.utc)
acquired_until = now + timedelta(seconds=self.lock_expiration_delay)
query = select([self.t_jobs.c.id, self.t_jobs.c.serialized_data]).\
+ join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id).\
where(or_(self.t_jobs.c.acquired_until.is_(None),
self.t_jobs.c.acquired_until < now)).\
order_by(self.t_jobs.c.created_at).\
limit(limit)
- serialized_jobs: Dict[str, bytes] = {row[0]: row[1]
- for row in conn.execute(query)}
- if serialized_jobs:
+ result = conn.execute(query)
+ if not result:
+ return []
+
+ # Mark the jobs as acquired by this worker
+ jobs = self._deserialize_jobs(result)
+ task_ids: Set[str] = {job.task_id for job in jobs}
+
+ # Retrieve the limits
+ query = select([self.t_tasks.c.id,
+ self.t_tasks.c.max_running_jobs - self.t_tasks.c.running_jobs]).\
+ where(self.t_tasks.c.max_running_jobs.isnot(None),
+ self.t_tasks.c.id.in_(task_ids))
+ result = conn.execute(query)
+ job_slots_left = dict(result.fetchall())
+
+ # Filter out jobs that don't have free slots
+ acquired_jobs: List[Job] = []
+ increments: Dict[str, int] = defaultdict(lambda: 0)
+ for job in jobs:
+ # Don't acquire the job if there are no free slots left
+ slots_left = job_slots_left.get(job.task_id)
+ if slots_left == 0:
+ continue
+ elif slots_left is not None:
+ job_slots_left[job.task_id] -= 1
+
+ acquired_jobs.append(job)
+ increments[job.task_id] += 1
+
+ if acquired_jobs:
+ # Mark the acquired jobs as acquired by this worker
+ acquired_job_ids = [job.id.hex for job in acquired_jobs]
update = self.t_jobs.update().\
values(acquired_by=worker_id, acquired_until=acquired_until).\
- where(self.t_jobs.c.id.in_(serialized_jobs))
+ where(self.t_jobs.c.id.in_(acquired_job_ids))
conn.execute(update)
- return self._deserialize_jobs(serialized_jobs.items())
+ # Increment the running job counters on each task
+ p_id = bindparam('p_id')
+ p_increment = bindparam('p_increment')
+ params = [{'p_id': task_id, 'p_increment': increment}
+ for task_id, increment in increments.items()]
+ update = self.t_tasks.update().\
+ values(running_jobs=self.t_tasks.c.running_jobs + p_increment).\
+ where(self.t_tasks.c.id == p_id)
+ conn.execute(update, params)
+
+ return acquired_jobs
- def release_job(self, worker_id: str, job_id: UUID, result: Optional[JobResult]) -> None:
+ def release_job(self, worker_id: str, job: Job, result: Optional[JobResult]) -> None:
with self.engine.begin() as conn:
+ # Insert the job result
now = datetime.now(timezone.utc)
serialized_result = self.serializer.serialize(result)
insert = self.t_job_results.insert().\
- values(job_id=job_id.hex, finished_at=now, serialized_data=serialized_result)
+ values(job_id=job.id.hex, finished_at=now, serialized_data=serialized_result)
conn.execute(insert)
- delete = self.t_jobs.delete().where(self.t_jobs.c.id == job_id.hex)
+ # Decrement the running jobs counter
+ update = self.t_tasks.update().\
+ values(running_jobs=self.t_tasks.c.running_jobs - 1).\
+ where(self.t_tasks.c.id == job.task_id)
+ conn.execute(update)
+
+ # Delete the job
+ delete = self.t_jobs.delete().where(self.t_jobs.c.id == job.id.hex)
conn.execute(delete)
def get_job_result(self, job_id: UUID) -> Optional[JobResult]:
with self.engine.begin() as conn:
+ # Retrieve the result
query = select(self.t_job_results.c.serialized_data).\
where(self.t_job_results.c.job_id == job_id.hex)
result = conn.execute(query)
+ # Delete the result
delete = self.t_job_results.delete().\
where(self.t_job_results.c.job_id == job_id.hex)
conn.execute(delete)
diff --git a/src/apscheduler/enums.py b/src/apscheduler/enums.py
index a7c0d69..941de3a 100644
--- a/src/apscheduler/enums.py
+++ b/src/apscheduler/enums.py
@@ -13,3 +13,4 @@ class JobOutcome(Enum):
failure = auto()
missed_start_deadline = auto()
cancelled = auto()
+ expired = auto()
diff --git a/src/apscheduler/events.py b/src/apscheduler/events.py
index 552d30b..6c0d270 100644
--- a/src/apscheduler/events.py
+++ b/src/apscheduler/events.py
@@ -45,6 +45,16 @@ class DataStoreEvent(Event):
@attr.define(kw_only=True, frozen=True)
+class TaskAdded(DataStoreEvent):
+ task_id: str
+
+
+@attr.define(kw_only=True, frozen=True)
+class TaskRemoved(DataStoreEvent):
+ task_id: str
+
+
+@attr.define(kw_only=True, frozen=True)
class ScheduleAdded(DataStoreEvent):
schedule_id: str
next_fire_time: Optional[datetime] = attr.field(converter=timestamp_to_datetime)
diff --git a/src/apscheduler/exceptions.py b/src/apscheduler/exceptions.py
index ec04f2c..1f90ba0 100644
--- a/src/apscheduler/exceptions.py
+++ b/src/apscheduler/exceptions.py
@@ -1,3 +1,8 @@
+class TaskLookupError(LookupError):
+ """Raised by a data store when it cannot find the requested task."""
+
+ def __init__(self, task_id: str):
+ super().__init__(f'No task by the id of {task_id!r} was found')
class JobLookupError(KeyError):
diff --git a/src/apscheduler/schedulers/async_.py b/src/apscheduler/schedulers/async_.py
index bb1760d..708d2ec 100644
--- a/src/apscheduler/schedulers/async_.py
+++ b/src/apscheduler/schedulers/async_.py
@@ -5,7 +5,7 @@ import platform
from contextlib import AsyncExitStack
from datetime import datetime, timedelta, timezone
from logging import Logger, getLogger
-from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Type, Union
+from typing import Any, Callable, Iterable, Mapping, Optional, Type, Union
from uuid import uuid4
import anyio
@@ -41,7 +41,6 @@ class AsyncScheduler(EventSource):
self.identity = identity or f'{platform.node()}-{os.getpid()}-{id(self)}'
self.logger = logger or getLogger(__name__)
self.start_worker = start_worker
- self._tasks: Dict[str, Task] = {}
self._exit_stack = AsyncExitStack()
self._events = AsyncEventHub()
@@ -96,27 +95,27 @@ class AsyncScheduler(EventSource):
def unsubscribe(self, token: SubscriptionToken) -> None:
self._events.unsubscribe(token)
- def _get_taskdef(self, func_or_id: Union[str, Callable]) -> Task:
- task_id = func_or_id if isinstance(func_or_id, str) else callable_to_ref(func_or_id)
- taskdef = self._tasks.get(task_id)
- if not taskdef:
- if isinstance(func_or_id, str):
- raise LookupError('no task found with ID {!r}'.format(func_or_id))
- else:
- taskdef = self._tasks[task_id] = Task(id=task_id, func=func_or_id)
-
- return taskdef
-
- def define_task(self, func: Callable, task_id: Optional[str] = None, **kwargs):
- if task_id is None:
- task_id = callable_to_ref(func)
-
- task = Task(id=task_id, **kwargs)
- if self._tasks.setdefault(task_id, task) is not task:
- pass
+ # def _get_taskdef(self, func_or_id: Union[str, Callable]) -> Task:
+ # task_id = func_or_id if isinstance(func_or_id, str) else callable_to_ref(func_or_id)
+ # taskdef = self._tasks.get(task_id)
+ # if not taskdef:
+ # if isinstance(func_or_id, str):
+ # raise LookupError('no task found with ID {!r}'.format(func_or_id))
+ # else:
+ # taskdef = self._tasks[task_id] = Task(id=task_id, func=func_or_id)
+ #
+ # return taskdef
+ #
+ # def define_task(self, func: Callable, task_id: Optional[str] = None, **kwargs):
+ # if task_id is None:
+ # task_id = callable_to_ref(func)
+ #
+ # task = Task(id=task_id, **kwargs)
+ # if self._tasks.setdefault(task_id, task) is not task:
+ # pass
async def add_schedule(
- self, task: Union[str, Callable], trigger: Trigger, *, id: Optional[str] = None,
+ self, func_or_task_id: Union[str, Callable], trigger: Trigger, *, id: Optional[str] = None,
args: Optional[Iterable] = None, kwargs: Optional[Mapping[str, Any]] = None,
coalesce: CoalescePolicy = CoalescePolicy.latest,
misfire_grace_time: Union[float, timedelta, None] = None,
@@ -130,12 +129,17 @@ class AsyncScheduler(EventSource):
if isinstance(misfire_grace_time, (int, float)):
misfire_grace_time = timedelta(seconds=misfire_grace_time)
- taskdef = self._get_taskdef(task)
- schedule = Schedule(id=id, task_id=taskdef.id, trigger=trigger, args=args, kwargs=kwargs,
+ if callable(func_or_task_id):
+ task = Task(id=callable_to_ref(func_or_task_id), func=func_or_task_id)
+ await self.data_store.add_task(task)
+ else:
+ task = await self.data_store.get_task(func_or_task_id)
+
+ schedule = Schedule(id=id, task_id=task.id, trigger=trigger, args=args, kwargs=kwargs,
coalesce=coalesce, misfire_grace_time=misfire_grace_time, tags=tags)
schedule.next_fire_time = trigger.next()
await self.data_store.add_schedule(schedule, conflict_policy)
- self.logger.info('Added new schedule (task=%r, trigger=%r); next run time at %s', taskdef,
+ self.logger.info('Added new schedule (task=%r, trigger=%r); next run time at %s', task,
trigger, schedule.next_fire_time)
return schedule.id
@@ -157,15 +161,6 @@ class AsyncScheduler(EventSource):
schedules = await self.data_store.acquire_schedules(self.identity, 100)
now = datetime.now(timezone.utc)
for schedule in schedules:
- # Look up the task definition
- try:
- taskdef = self._get_taskdef(schedule.task_id)
- except LookupError:
- self.logger.error('Cannot locate task definition %r for schedule %r – '
- 'removing schedule', schedule.task_id, schedule.id)
- schedule.next_fire_time = None
- continue
-
# Calculate a next fire time for the schedule, if possible
fire_times = [schedule.next_fire_time]
calculate_next = schedule.trigger.next
@@ -175,7 +170,7 @@ class AsyncScheduler(EventSource):
except Exception:
self.logger.exception(
'Error computing next fire time for schedule %r of task %r – '
- 'removing schedule', schedule.id, taskdef.id)
+ 'removing schedule', schedule.id, schedule.task_id)
break
# Stop if the calculated fire time is in the future
@@ -192,7 +187,7 @@ class AsyncScheduler(EventSource):
# Add one or more jobs to the job queue
for fire_time in fire_times:
schedule.last_fire_time = fire_time
- job = Job(task_id=taskdef.id, func=taskdef.func, args=schedule.args,
+ job = Job(task_id=schedule.task_id, args=schedule.args,
kwargs=schedule.kwargs, schedule_id=schedule.id,
scheduled_fire_time=fire_time,
start_deadline=schedule.next_deadline, tags=schedule.tags)
diff --git a/src/apscheduler/schedulers/sync.py b/src/apscheduler/schedulers/sync.py
index 3005d27..3d86b25 100644
--- a/src/apscheduler/schedulers/sync.py
+++ b/src/apscheduler/schedulers/sync.py
@@ -7,7 +7,7 @@ from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
from contextlib import ExitStack
from datetime import datetime, timedelta, timezone
from logging import Logger, getLogger
-from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Type, Union
+from typing import Any, Callable, Iterable, Mapping, Optional, Type, Union
from uuid import uuid4
from ..abc import DataStore, EventSource, Trigger
@@ -35,7 +35,6 @@ class Scheduler(EventSource):
self.logger = logger or getLogger(__name__)
self.start_worker = start_worker
self.data_store = data_store or MemoryDataStore()
- self._tasks: Dict[str, Task] = {}
self._exit_stack = ExitStack()
self._executor = ThreadPoolExecutor(max_workers=1)
self._events = EventHub()
@@ -97,19 +96,8 @@ class Scheduler(EventSource):
def unsubscribe(self, token: SubscriptionToken) -> None:
self._events.unsubscribe(token)
- def _get_taskdef(self, func_or_id: Union[str, Callable]) -> Task:
- task_id = func_or_id if isinstance(func_or_id, str) else callable_to_ref(func_or_id)
- taskdef = self._tasks.get(task_id)
- if not taskdef:
- if isinstance(func_or_id, str):
- raise LookupError('no task found with ID {!r}'.format(func_or_id))
- else:
- taskdef = self._tasks[task_id] = Task(id=task_id, func=func_or_id)
-
- return taskdef
-
def add_schedule(
- self, task: Union[str, Callable], trigger: Trigger, *, id: Optional[str] = None,
+ self, func_or_task_id: Union[str, Callable], trigger: Trigger, *, id: Optional[str] = None,
args: Optional[Iterable] = None, kwargs: Optional[Mapping[str, Any]] = None,
coalesce: CoalescePolicy = CoalescePolicy.latest,
misfire_grace_time: Union[float, timedelta, None] = None,
@@ -123,12 +111,17 @@ class Scheduler(EventSource):
if isinstance(misfire_grace_time, (int, float)):
misfire_grace_time = timedelta(seconds=misfire_grace_time)
- taskdef = self._get_taskdef(task)
- schedule = Schedule(id=id, task_id=taskdef.id, trigger=trigger, args=args, kwargs=kwargs,
+ if callable(func_or_task_id):
+ task = Task(id=callable_to_ref(func_or_task_id), func=func_or_task_id)
+ self.data_store.add_task(task)
+ else:
+ task = self.data_store.get_task(func_or_task_id)
+
+ schedule = Schedule(id=id, task_id=task.id, trigger=trigger, args=args, kwargs=kwargs,
coalesce=coalesce, misfire_grace_time=misfire_grace_time, tags=tags)
schedule.next_fire_time = trigger.next()
self.data_store.add_schedule(schedule, conflict_policy)
- self.logger.info('Added new schedule (task=%r, trigger=%r); next run time at %s', taskdef,
+ self.logger.info('Added new schedule (task=%r, trigger=%r); next run time at %s', task,
trigger, schedule.next_fire_time)
return schedule.id
@@ -149,16 +142,6 @@ class Scheduler(EventSource):
schedules = self.data_store.acquire_schedules(self.identity, 100)
now = datetime.now(timezone.utc)
for schedule in schedules:
- # Look up the task definition
- try:
- taskdef = self._get_taskdef(schedule.task_id)
- except LookupError:
- self.logger.error('Cannot locate task definition %r for schedule %r – '
- 'putting schedule on hold', schedule.task_id,
- schedule.id)
- schedule.next_fire_time = None
- continue
-
# Calculate a next fire time for the schedule, if possible
fire_times = [schedule.next_fire_time]
calculate_next = schedule.trigger.next
@@ -168,7 +151,7 @@ class Scheduler(EventSource):
except Exception:
self.logger.exception(
'Error computing next fire time for schedule %r of task %r – '
- 'removing schedule', schedule.id, taskdef.id)
+ 'removing schedule', schedule.id, schedule.task_id)
break
# Stop if the calculated fire time is in the future
@@ -185,7 +168,7 @@ class Scheduler(EventSource):
# Add one or more jobs to the job queue
for fire_time in fire_times:
schedule.last_fire_time = fire_time
- job = Job(task_id=taskdef.id, func=taskdef.func, args=schedule.args,
+ job = Job(task_id=schedule.task_id, args=schedule.args,
kwargs=schedule.kwargs, schedule_id=schedule.id,
scheduled_fire_time=fire_time,
start_deadline=schedule.next_deadline, tags=schedule.tags)
diff --git a/src/apscheduler/structures.py b/src/apscheduler/structures.py
index 4950818..2f97ca6 100644
--- a/src/apscheduler/structures.py
+++ b/src/apscheduler/structures.py
@@ -8,6 +8,7 @@ import attr
from . import abc
from .enums import JobOutcome
+from .marshalling import callable_from_ref, callable_to_ref
from .policies import CoalescePolicy
@@ -15,11 +16,24 @@ from .policies import CoalescePolicy
class Task:
id: str
func: Callable = attr.field(eq=False, order=False)
- max_instances: Optional[int] = attr.field(eq=False, order=False, default=None)
- metadata_arg: Optional[str] = attr.field(eq=False, order=False, default=None)
- stateful: bool = attr.field(eq=False, order=False, default=False)
+ max_running_jobs: Optional[int] = attr.field(eq=False, order=False, default=None)
+ state: Any = None
misfire_grace_time: Optional[timedelta] = attr.field(eq=False, order=False, default=None)
+ def marshal(self, serializer: abc.Serializer) -> Dict[str, Any]:
+ marshalled = attr.asdict(self)
+ marshalled['func'] = callable_to_ref(self.func)
+ marshalled['state'] = serializer.serialize(self.state) if self.state else None
+ return marshalled
+
+ @classmethod
+ def unmarshal(cls, serializer: abc.Serializer, marshalled: Dict[str, Any]) -> Task:
+ marshalled['func'] = callable_from_ref(marshalled['func'])
+ if marshalled['state'] is not None:
+ marshalled['state'] = serializer.deserialize(marshalled['state'])
+
+ return cls(**marshalled)
+
@attr.define(kw_only=True)
class Schedule:
@@ -36,6 +50,15 @@ class Schedule:
last_fire_time: Optional[datetime] = attr.field(eq=False, order=False, init=False,
default=None)
+ def marshal(self, serializer: abc.Serializer) -> Dict[str, Any]:
+ marshalled = attr.asdict(self)
+ marshalled['trigger_type'] = serializer.serialize(self.args)
+ marshalled['trigger_data'] = serializer.serialize(self.trigger)
+ marshalled['args'] = serializer.serialize(self.args) if self.args else None
+ marshalled['kwargs'] = serializer.serialize(self.kwargs) if self.kwargs else None
+ marshalled['tags'] = list(self.tags)
+ return marshalled
+
@property
def next_deadline(self) -> Optional[datetime]:
if self.next_fire_time and self.misfire_grace_time:
@@ -48,7 +71,6 @@ class Schedule:
class Job:
id: UUID = attr.field(factory=uuid4)
task_id: str = attr.field(eq=False, order=False)
- func: Callable = attr.field(eq=False, order=False)
args: tuple = attr.field(eq=False, order=False, default=())
kwargs: Dict[str, Any] = attr.field(eq=False, order=False, factory=dict)
schedule_id: Optional[str] = attr.field(eq=False, order=False, default=None)
@@ -57,6 +79,22 @@ class Job:
tags: FrozenSet[str] = attr.field(eq=False, order=False, factory=frozenset)
started_at: Optional[datetime] = attr.field(eq=False, order=False, init=False, default=None)
+ def marshal(self, serializer: abc.Serializer) -> Dict[str, Any]:
+ marshalled = attr.asdict(self)
+ marshalled['args'] = serializer.serialize(self.args) if self.args else None
+ marshalled['kwargs'] = serializer.serialize(self.kwargs) if self.kwargs else None
+ marshalled['tags'] = list(self.tags)
+ return marshalled
+
+ @classmethod
+    def unmarshal(cls, serializer: abc.Serializer, marshalled: Dict[str, Any]) -> Job:
+ for key in ('args', 'kwargs'):
+ if marshalled[key] is not None:
+ marshalled[key] = serializer.deserialize(marshalled[key])
+
+ marshalled['tags'] = frozenset(marshalled['tags'])
+ return cls(**marshalled)
+
@attr.define(eq=False, order=False, frozen=True)
class JobResult:
diff --git a/src/apscheduler/workers/async_.py b/src/apscheduler/workers/async_.py
index 79b0143..dbc72b3 100644
--- a/src/apscheduler/workers/async_.py
+++ b/src/apscheduler/workers/async_.py
@@ -11,7 +11,7 @@ from uuid import UUID
import anyio
from anyio import TASK_STATUS_IGNORED, create_task_group, get_cancelled_exc_class
-from anyio.abc import CancelScope, TaskGroup
+from anyio.abc import CancelScope
from ..abc import AsyncDataStore, DataStore, EventSource, Job
from ..datastores.async_.sync_adapter import AsyncDataStoreAdapter
@@ -25,7 +25,6 @@ from ..structures import JobResult
class AsyncWorker(EventSource):
"""Runs jobs locally in a task group."""
- _task_group: Optional[TaskGroup] = None
_stop_event: Optional[anyio.Event] = None
_state: RunState = RunState.stopped
_acquire_cancel_scope: Optional[CancelScope] = None
@@ -72,16 +71,15 @@ class AsyncWorker(EventSource):
self._exit_stack.callback(self.data_store.unsubscribe, wakeup_token)
# Start the actual worker
- self._task_group = create_task_group()
- await self._exit_stack.enter_async_context(self._task_group)
- await self._task_group.start(self.run)
+ task_group = create_task_group()
+ await self._exit_stack.enter_async_context(task_group)
+ await task_group.start(self.run)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
self._state = RunState.stopping
self._wakeup_event.set()
await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
- del self._task_group
del self._wakeup_event
def subscribe(self, callback: Callable[[Event], Any],
@@ -102,15 +100,17 @@ class AsyncWorker(EventSource):
self._events.publish(WorkerStarted())
try:
- while self._state is RunState.started:
- limit = self.max_concurrent_jobs - len(self._running_jobs)
- jobs = await self.data_store.acquire_jobs(self.identity, limit)
- for job in jobs:
- self._running_jobs.add(job.id)
- self._task_group.start_soon(self._run_job, job)
-
- await self._wakeup_event.wait()
- self._wakeup_event = anyio.Event()
+ async with create_task_group() as tg:
+ while self._state is RunState.started:
+ limit = self.max_concurrent_jobs - len(self._running_jobs)
+ jobs = await self.data_store.acquire_jobs(self.identity, limit)
+ for job in jobs:
+ task = await self.data_store.get_task(job.task_id)
+ self._running_jobs.add(job.id)
+ tg.start_soon(self._run_job, job, task.func)
+
+ await self._wakeup_event.wait()
+ self._wakeup_event = anyio.Event()
except get_cancelled_exc_class():
pass
except BaseException as exc:
@@ -121,7 +121,7 @@ class AsyncWorker(EventSource):
self._state = RunState.stopped
self._events.publish(WorkerStopped())
- async def _run_job(self, job: Job) -> None:
+ async def _run_job(self, job: Job, func: Callable) -> None:
try:
# Check if the job started before the deadline
start_time = datetime.now(timezone.utc)
@@ -131,24 +131,24 @@ class AsyncWorker(EventSource):
self._events.publish(JobStarted.from_job(job, start_time))
try:
- retval = job.func(*job.args, **job.kwargs)
+ retval = func(*job.args, **job.kwargs)
if isawaitable(retval):
retval = await retval
except get_cancelled_exc_class():
with CancelScope(shield=True):
result = JobResult(outcome=JobOutcome.cancelled)
- await self.data_store.release_job(self.identity, job.id, result)
+ await self.data_store.release_job(self.identity, job, result)
self._events.publish(JobCancelled.from_job(job, start_time))
except BaseException as exc:
result = JobResult(outcome=JobOutcome.failure, exception=exc)
- await self.data_store.release_job(self.identity, job.id, result)
+ await self.data_store.release_job(self.identity, job, result)
self._events.publish(JobFailed.from_exception(job, start_time, exc))
if not isinstance(exc, Exception):
raise
else:
result = JobResult(outcome=JobOutcome.success, return_value=retval)
- await self.data_store.release_job(self.identity, job.id, result)
+ await self.data_store.release_job(self.identity, job, result)
self._events.publish(JobCompleted.from_retval(job, start_time, retval))
finally:
self._running_jobs.remove(job.id)
diff --git a/src/apscheduler/workers/sync.py b/src/apscheduler/workers/sync.py
index 9412b1c..1fb1cba 100644
--- a/src/apscheduler/workers/sync.py
+++ b/src/apscheduler/workers/sync.py
@@ -22,6 +22,7 @@ from ..structures import Job, JobResult
class Worker(EventSource):
"""Runs jobs locally in a thread pool."""
+ _executor: ThreadPoolExecutor
_state: RunState = RunState.stopped
_wakeup_event: threading.Event
@@ -32,7 +33,6 @@ class Worker(EventSource):
self.logger = logger or getLogger(__name__)
self._acquired_jobs: Set[Job] = set()
self._exit_stack = ExitStack()
- self._executor = ThreadPoolExecutor(max_workers=max_concurrent_jobs + 1)
self._events = EventHub()
self._running_jobs: Set[UUID] = set()
@@ -64,6 +64,7 @@ class Worker(EventSource):
# Start the worker and return when it has signalled readiness or raised an exception
start_future: Future[None] = Future()
token = self._events.subscribe(start_future.set_result)
+ self._executor = ThreadPoolExecutor(1)
run_future = self._executor.submit(self.run)
try:
wait([start_future, run_future], return_when=FIRST_COMPLETED)
@@ -98,26 +99,30 @@ class Worker(EventSource):
self._state = RunState.started
self._events.publish(WorkerStarted())
+ executor = ThreadPoolExecutor(max_workers=self.max_concurrent_jobs)
try:
while self._state is RunState.started:
available_slots = self.max_concurrent_jobs - len(self._running_jobs)
if available_slots:
jobs = self.data_store.acquire_jobs(self.identity, available_slots)
for job in jobs:
+ task = self.data_store.get_task(job.task_id)
self._running_jobs.add(job.id)
- self._executor.submit(self._run_job, job)
+ executor.submit(self._run_job, job, task.func)
self._wakeup_event.wait()
self._wakeup_event = threading.Event()
except BaseException as exc:
+ executor.shutdown(wait=False)
self._state = RunState.stopped
self._events.publish(WorkerStopped(exception=exc))
raise
+ executor.shutdown()
self._state = RunState.stopped
self._events.publish(WorkerStopped())
- def _run_job(self, job: Job) -> None:
+ def _run_job(self, job: Job, func: Callable) -> None:
try:
# Check if the job started before the deadline
start_time = datetime.now(timezone.utc)
@@ -127,16 +132,16 @@ class Worker(EventSource):
self._events.publish(JobStarted.from_job(job, start_time))
try:
- retval = job.func(*job.args, **job.kwargs)
+ retval = func(*job.args, **job.kwargs)
except BaseException as exc:
result = JobResult(outcome=JobOutcome.failure, exception=exc)
- self.data_store.release_job(self.identity, job.id, result)
+ self.data_store.release_job(self.identity, job, result)
self._events.publish(JobFailed.from_exception(job, start_time, exc))
if not isinstance(exc, Exception):
raise
else:
result = JobResult(outcome=JobOutcome.success, return_value=retval)
- self.data_store.release_job(self.identity, job.id, result)
+ self.data_store.release_job(self.identity, job, result)
self._events.publish(JobCompleted.from_retval(job, start_time, retval))
finally:
self._running_jobs.remove(job.id)
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index 1a18c03..2751c83 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -10,37 +10,27 @@ from freezegun.api import FrozenDateTimeFactory
from apscheduler.abc import AsyncDataStore, Job, Schedule
from apscheduler.enums import JobOutcome
-from apscheduler.events import Event, ScheduleAdded, ScheduleRemoved, ScheduleUpdated
+from apscheduler.events import Event, ScheduleAdded, ScheduleRemoved, ScheduleUpdated, TaskAdded
from apscheduler.policies import CoalescePolicy, ConflictPolicy
-from apscheduler.structures import JobResult
+from apscheduler.structures import JobResult, Task
from apscheduler.triggers.date import DateTrigger
@pytest.fixture
def schedules() -> List[Schedule]:
trigger = DateTrigger(datetime(2020, 9, 13, tzinfo=timezone.utc))
- schedule1 = Schedule(id='s1', task_id='bogus', trigger=trigger)
+ schedule1 = Schedule(id='s1', task_id='task1', trigger=trigger)
schedule1.next_fire_time = trigger.next()
trigger = DateTrigger(datetime(2020, 9, 14, tzinfo=timezone.utc))
- schedule2 = Schedule(id='s2', task_id='bogus', trigger=trigger)
+ schedule2 = Schedule(id='s2', task_id='task2', trigger=trigger)
schedule2.next_fire_time = trigger.next()
trigger = DateTrigger(datetime(2020, 9, 15, tzinfo=timezone.utc))
- schedule3 = Schedule(id='s3', task_id='bogus', trigger=trigger)
+ schedule3 = Schedule(id='s3', task_id='task1', trigger=trigger)
return [schedule1, schedule2, schedule3]
-@pytest.fixture
-def jobs() -> List[Job]:
- job1 = Job(task_id='task1', func=print, args=('hello',), kwargs={'arg2': 'world'},
- schedule_id='schedule1',
- scheduled_fire_time=datetime(2020, 10, 10, tzinfo=timezone.utc),
- start_deadline=datetime(2020, 10, 10, 1, tzinfo=timezone.utc))
- job2 = Job(task_id='task2', func=print, args=('hello',), kwargs={'arg2': 'world'})
- return [job1, job2]
-
-
@asynccontextmanager
async def capture_events(
store: AsyncDataStore, limit: int,
@@ -63,6 +53,33 @@ async def capture_events(
@pytest.mark.anyio
class TestAsyncStores:
+ async def test_add_replace_task(
+ self, datastore_cm: AsyncContextManager[AsyncDataStore]) -> None:
+ import math
+
+ async with datastore_cm as store, capture_events(store, 3, {TaskAdded}) as events:
+ await store.add_task(Task(id='test_task', func=print))
+ await store.add_task(Task(id='test_task2', func=math.ceil))
+ await store.add_task(Task(id='test_task', func=repr))
+
+ tasks = await store.get_tasks()
+ assert len(tasks) == 2
+ assert tasks[0].id == 'test_task'
+ assert tasks[0].func is repr
+ assert tasks[1].id == 'test_task2'
+ assert tasks[1].func is math.ceil
+
+ received_event = events.pop(0)
+ assert received_event.task_id == 'test_task'
+
+ received_event = events.pop(0)
+ assert received_event.task_id == 'test_task2'
+
+ received_event = events.pop(0)
+ assert received_event.task_id == 'test_task'
+
+ assert not events
+
async def test_add_schedules(self, datastore_cm: AsyncContextManager[AsyncDataStore],
schedules: List[Schedule]) -> None:
async with datastore_cm as store, capture_events(store, 3, {ScheduleAdded}) as events:
@@ -200,35 +217,39 @@ class TestAsyncStores:
assert acquired3[0].id == 's1'
async def test_acquire_multiple_workers(
- self, datastore_cm: AsyncContextManager[AsyncDataStore], jobs: List[Job]) -> None:
+ self, datastore_cm: AsyncContextManager[AsyncDataStore]) -> None:
async with datastore_cm as store:
+ await store.add_task(Task(id='task1', func=asynccontextmanager))
+ jobs = [Job(task_id='task1') for _ in range(2)]
for job in jobs:
await store.add_job(job)
# The first worker gets the first job in the queue
- jobs1 = await store.acquire_jobs('dummy-id1', 1)
+ jobs1 = await store.acquire_jobs('worker1', 1)
assert len(jobs1) == 1
assert jobs1[0].id == jobs[0].id
# The second worker gets the second job
- jobs2 = await store.acquire_jobs('dummy-id2', 1)
+ jobs2 = await store.acquire_jobs('worker2', 1)
assert len(jobs2) == 1
assert jobs2[0].id == jobs[1].id
# The third worker gets nothing
- jobs3 = await store.acquire_jobs('dummy-id3', 1)
+ jobs3 = await store.acquire_jobs('worker3', 1)
assert not jobs3
- async def test_job_release_success(self, datastore_cm: AsyncContextManager[AsyncDataStore],
- jobs: List[Job]):
+ async def test_job_release_success(
+ self, datastore_cm: AsyncContextManager[AsyncDataStore]) -> None:
async with datastore_cm as store:
- await store.add_job(jobs[0])
+ await store.add_task(Task(id='task1', func=asynccontextmanager))
+ job = Job(task_id='task1')
+ await store.add_job(job)
acquired = await store.acquire_jobs('worker_id', 2)
assert len(acquired) == 1
- assert acquired[0].id == jobs[0].id
+ assert acquired[0].id == job.id
- await store.release_job('worker_id', acquired[0].id,
+ await store.release_job('worker_id', acquired[0],
JobResult(JobOutcome.success, return_value='foo'))
result = await store.get_job_result(acquired[0].id)
assert result.outcome is JobOutcome.success
@@ -239,16 +260,18 @@ class TestAsyncStores:
assert not await store.get_jobs({acquired[0].id})
assert not await store.get_job_result(acquired[0].id)
- async def test_job_release_failure(self, datastore_cm: AsyncContextManager[AsyncDataStore],
- jobs: List[Job]):
+ async def test_job_release_failure(
+ self, datastore_cm: AsyncContextManager[AsyncDataStore]) -> None:
async with datastore_cm as store:
- await store.add_job(jobs[0])
+ await store.add_task(Task(id='task1', func=asynccontextmanager))
+ job = Job(task_id='task1')
+ await store.add_job(job)
acquired = await store.acquire_jobs('worker_id', 2)
assert len(acquired) == 1
- assert acquired[0].id == jobs[0].id
+ assert acquired[0].id == job.id
- await store.release_job('worker_id', acquired[0].id,
+ await store.release_job('worker_id', acquired[0],
JobResult(JobOutcome.failure, exception=ValueError('foo')))
result = await store.get_job_result(acquired[0].id)
assert result.outcome is JobOutcome.failure
@@ -261,15 +284,17 @@ class TestAsyncStores:
assert not await store.get_job_result(acquired[0].id)
async def test_job_release_missed_deadline(
- self, datastore_cm: AsyncContextManager[AsyncDataStore], jobs: List[Job]):
+ self, datastore_cm: AsyncContextManager[AsyncDataStore]):
async with datastore_cm as store:
- await store.add_job(jobs[0])
+ await store.add_task(Task(id='task1', func=asynccontextmanager))
+ job = Job(task_id='task1')
+ await store.add_job(job)
acquired = await store.acquire_jobs('worker_id', 2)
assert len(acquired) == 1
- assert acquired[0].id == jobs[0].id
+ assert acquired[0].id == job.id
- await store.release_job('worker_id', acquired[0].id,
+ await store.release_job('worker_id', acquired[0],
JobResult(JobOutcome.missed_start_deadline))
result = await store.get_job_result(acquired[0].id)
assert result.outcome is JobOutcome.missed_start_deadline
@@ -281,16 +306,17 @@ class TestAsyncStores:
assert not await store.get_job_result(acquired[0].id)
async def test_job_release_cancelled(
- self, datastore_cm: AsyncContextManager[AsyncDataStore], jobs: List[Job]):
+ self, datastore_cm: AsyncContextManager[AsyncDataStore]) -> None:
async with datastore_cm as store:
- await store.add_job(jobs[0])
+ await store.add_task(Task(id='task1', func=asynccontextmanager))
+ job = Job(task_id='task1')
+ await store.add_job(job)
- acquired = await store.acquire_jobs('worker_id', 2)
+ acquired = await store.acquire_jobs('worker1', 2)
assert len(acquired) == 1
- assert acquired[0].id == jobs[0].id
+ assert acquired[0].id == job.id
- await store.release_job('worker_id', acquired[0].id,
- JobResult(JobOutcome.cancelled))
+ await store.release_job('worker1', acquired[0], JobResult(JobOutcome.cancelled))
result = await store.get_job_result(acquired[0].id)
assert result.outcome is JobOutcome.cancelled
assert result.exception is None
@@ -301,7 +327,7 @@ class TestAsyncStores:
assert not await store.get_job_result(acquired[0].id)
async def test_acquire_jobs_lock_timeout(
- self, datastore_cm: AsyncContextManager[AsyncDataStore], jobs: List[Job],
+ self, datastore_cm: AsyncContextManager[AsyncDataStore],
freezer: FrozenDateTimeFactory) -> None:
"""
Test that a worker can acquire jobs that were acquired by another scheduler but not
@@ -309,19 +335,40 @@ class TestAsyncStores:
"""
async with datastore_cm as store:
+ await store.add_task(Task(id='task1', func=asynccontextmanager))
+ job = Job(task_id='task1')
+ await store.add_job(job)
+
# First, one worker acquires the first available job
- await store.add_job(jobs[0])
- acquired = await store.acquire_jobs('dummy-id1', 1)
+ acquired = await store.acquire_jobs('worker1', 1)
assert len(acquired) == 1
- assert acquired[0].id == jobs[0].id
+ assert acquired[0].id == job.id
# Try to acquire the job just at the threshold (now == acquired_until).
# This should not yield any jobs.
freezer.tick(30)
- assert not await store.acquire_jobs('dummy-id2', 1)
+ assert not await store.acquire_jobs('worker2', 1)
# Right after that, the job should be available
freezer.tick(1)
- acquired = await store.acquire_jobs('dummy-id2', 1)
+ acquired = await store.acquire_jobs('worker2', 1)
assert len(acquired) == 1
- assert acquired[0].id == jobs[0].id
+ assert acquired[0].id == job.id
+
+ async def test_acquire_jobs_max_number_exceeded(
+ self, datastore_cm: AsyncContextManager[AsyncDataStore]) -> None:
+ async with datastore_cm as store:
+ await store.add_task(Task(id='task1', func=asynccontextmanager, max_running_jobs=2))
+ jobs = [Job(task_id='task1'), Job(task_id='task1'), Job(task_id='task1')]
+ for job in jobs:
+ await store.add_job(job)
+
+            # Check that only 2 jobs are returned from acquire_jobs() even though the limit was 3
+ acquired_jobs = await store.acquire_jobs('worker1', 3)
+ assert [job.id for job in acquired_jobs] == [job.id for job in jobs[:2]]
+
+ # Release one job, and the worker should be able to acquire the third job
+ await store.release_job('worker1', acquired_jobs[0],
+ JobResult(outcome=JobOutcome.success, return_value=None))
+ acquired_jobs = await store.acquire_jobs('worker1', 3)
+ assert [job.id for job in acquired_jobs] == [jobs[2].id]
diff --git a/tests/test_schedulers.py b/tests/test_schedulers.py
index 3c0911d..990cfec 100644
--- a/tests/test_schedulers.py
+++ b/tests/test_schedulers.py
@@ -7,7 +7,7 @@ import pytest
from anyio import fail_after
from apscheduler.events import (
- Event, JobAdded, ScheduleAdded, ScheduleRemoved, SchedulerStarted, SchedulerStopped)
+ Event, JobAdded, ScheduleAdded, ScheduleRemoved, SchedulerStarted, SchedulerStopped, TaskAdded)
from apscheduler.schedulers.async_ import AsyncScheduler
from apscheduler.schedulers.sync import Scheduler
from apscheduler.triggers.date import DateTrigger
@@ -27,7 +27,7 @@ class TestAsyncScheduler:
async def test_schedule_job(self) -> None:
def listener(received_event: Event) -> None:
received_events.append(received_event)
- if len(received_events) == 4:
+ if len(received_events) == 5:
event.set()
received_events: List[Event] = []
@@ -44,6 +44,11 @@ class TestAsyncScheduler:
received_event = received_events.pop(0)
assert isinstance(received_event, SchedulerStarted)
+ # Then the task was added
+ received_event = received_events.pop(0)
+ assert isinstance(received_event, TaskAdded)
+ assert received_event.task_id == 'test_schedulers:dummy_async_job'
+
# Then a schedule was added
received_event = received_events.pop(0)
assert isinstance(received_event, ScheduleAdded)
@@ -73,7 +78,7 @@ class TestSyncScheduler:
def test_schedule_job(self):
def listener(received_event: Event) -> None:
received_events.append(received_event)
- if len(received_events) == 4:
+ if len(received_events) == 5:
event.set()
received_events: List[Event] = []
@@ -89,6 +94,11 @@ class TestSyncScheduler:
received_event = received_events.pop(0)
assert isinstance(received_event, SchedulerStarted)
+ # Then the task was added
+ received_event = received_events.pop(0)
+ assert isinstance(received_event, TaskAdded)
+ assert received_event.task_id == 'test_schedulers:dummy_sync_job'
+
# Then a schedule was added
received_event = received_events.pop(0)
assert isinstance(received_event, ScheduleAdded)
diff --git a/tests/test_workers.py b/tests/test_workers.py
index a45398d..f949edf 100644
--- a/tests/test_workers.py
+++ b/tests/test_workers.py
@@ -9,8 +9,9 @@ from anyio import fail_after
from apscheduler.abc import Job
from apscheduler.datastores.sync.memory import MemoryDataStore
from apscheduler.events import (
- Event, JobAdded, JobCompleted, JobDeadlineMissed, JobFailed, JobStarted, WorkerStarted,
- WorkerStopped)
+ Event, JobAdded, JobCompleted, JobDeadlineMissed, JobFailed, JobStarted, TaskAdded,
+ WorkerStarted, WorkerStopped)
+from apscheduler.structures import Task
from apscheduler.workers.async_ import AsyncWorker
from apscheduler.workers.sync import Worker
@@ -41,7 +42,7 @@ class TestAsyncWorker:
async def test_run_job_nonscheduled_success(self, target_func: Callable, fail: bool) -> None:
def listener(received_event: Event):
received_events.append(received_event)
- if len(received_events) == 4:
+ if len(received_events) == 5:
event.set()
received_events: List[Event] = []
@@ -50,8 +51,8 @@ class TestAsyncWorker:
worker = AsyncWorker(data_store)
worker.subscribe(listener)
async with worker:
- job = Job(task_id='task_id', func=target_func, args=(1, 2),
- kwargs={'x': 'foo', 'fail': fail})
+ await worker.data_store.add_task(Task(id='task_id', func=target_func))
+ job = Job(task_id='task_id', args=(1, 2), kwargs={'x': 'foo', 'fail': fail})
await worker.data_store.add_job(job)
with fail_after(3):
await event.wait()
@@ -60,6 +61,11 @@ class TestAsyncWorker:
received_event = received_events.pop(0)
assert isinstance(received_event, WorkerStarted)
+ # Then the task was added
+ received_event = received_events.pop(0)
+ assert isinstance(received_event, TaskAdded)
+ assert received_event.task_id == 'task_id'
+
# Then a job was added
received_event = received_events.pop(0)
assert isinstance(received_event, JobAdded)
@@ -96,7 +102,7 @@ class TestAsyncWorker:
async def test_run_deadline_missed(self) -> None:
def listener(received_event: Event):
received_events.append(received_event)
- if len(received_events) == 3:
+ if len(received_events) == 4:
event.set()
scheduled_start_time = datetime(2020, 9, 14, tzinfo=timezone.utc)
@@ -106,7 +112,8 @@ class TestAsyncWorker:
worker = AsyncWorker(data_store)
worker.subscribe(listener)
async with worker:
- job = Job(task_id='task_id', func=fail_func, schedule_id='foo',
+ await worker.data_store.add_task(Task(id='task_id', func=fail_func))
+ job = Job(task_id='task_id', schedule_id='foo',
scheduled_fire_time=scheduled_start_time,
start_deadline=datetime(2020, 9, 14, 1, tzinfo=timezone.utc))
await worker.data_store.add_job(job)
@@ -117,6 +124,11 @@ class TestAsyncWorker:
received_event = received_events.pop(0)
assert isinstance(received_event, WorkerStarted)
+ # Then the task was added
+ received_event = received_events.pop(0)
+ assert isinstance(received_event, TaskAdded)
+ assert received_event.task_id == 'task_id'
+
# Then a job was added
received_event = received_events.pop(0)
assert isinstance(received_event, JobAdded)
@@ -144,7 +156,7 @@ class TestSyncWorker:
def test_run_job_nonscheduled(self, fail: bool) -> None:
def listener(received_event: Event):
received_events.append(received_event)
- if len(received_events) == 4:
+ if len(received_events) == 5:
event.set()
received_events: List[Event] = []
@@ -153,8 +165,8 @@ class TestSyncWorker:
worker = Worker(data_store)
worker.subscribe(listener)
with worker:
- job = Job(task_id='task_id', func=sync_func, args=(1, 2),
- kwargs={'x': 'foo', 'fail': fail})
+ worker.data_store.add_task(Task(id='task_id', func=sync_func))
+ job = Job(task_id='task_id', args=(1, 2), kwargs={'x': 'foo', 'fail': fail})
worker.data_store.add_job(job)
event.wait(5)
@@ -162,6 +174,11 @@ class TestSyncWorker:
received_event = received_events.pop(0)
assert isinstance(received_event, WorkerStarted)
+ # Then the task was added
+ received_event = received_events.pop(0)
+ assert isinstance(received_event, TaskAdded)
+ assert received_event.task_id == 'task_id'
+
# Then a job was added
received_event = received_events.pop(0)
assert isinstance(received_event, JobAdded)
@@ -198,7 +215,7 @@ class TestSyncWorker:
def test_run_deadline_missed(self) -> None:
def listener(worker_event: Event):
received_events.append(worker_event)
- if len(received_events) == 3:
+ if len(received_events) == 4:
event.set()
scheduled_start_time = datetime(2020, 9, 14, tzinfo=timezone.utc)
@@ -208,7 +225,8 @@ class TestSyncWorker:
worker = Worker(data_store)
worker.subscribe(listener)
with worker:
- job = Job(task_id='task_id', func=fail_func, schedule_id='foo',
+ worker.data_store.add_task(Task(id='task_id', func=fail_func))
+ job = Job(task_id='task_id', schedule_id='foo',
scheduled_fire_time=scheduled_start_time,
start_deadline=datetime(2020, 9, 14, 1, tzinfo=timezone.utc))
worker.data_store.add_job(job)
@@ -218,6 +236,11 @@ class TestSyncWorker:
received_event = received_events.pop(0)
assert isinstance(received_event, WorkerStarted)
+ # Then the task was added
+ received_event = received_events.pop(0)
+ assert isinstance(received_event, TaskAdded)
+ assert received_event.task_id == 'task_id'
+
# Then a job was added
received_event = received_events.pop(0)
assert isinstance(received_event, JobAdded)