diff options
author | Alex Grönholm <alex.gronholm@nextday.fi> | 2021-09-05 23:12:34 +0300 |
---|---|---|
committer | Alex Grönholm <alex.gronholm@nextday.fi> | 2021-09-06 01:39:07 +0300 |
commit | 6fed43f29bfa7929fcaf5e549ce34819ba7e3702 (patch) | |
tree | f3c8e7675a03354cb30205bd6a21e7ae231c83a6 | |
parent | 7147cd422bfc6280a56bb4335d65935e942ab9e3 (diff) | |
download | apscheduler-6fed43f29bfa7929fcaf5e549ce34819ba7e3702.tar.gz |
Implemented task accounting
The maximum number of concurrent jobs for a given task is now enforced if set.
-rw-r--r-- | src/apscheduler/abc.py | 86 | ||||
-rw-r--r-- | src/apscheduler/datastores/async_/sqlalchemy.py | 132 | ||||
-rw-r--r-- | src/apscheduler/datastores/async_/sync_adapter.py | 18 | ||||
-rw-r--r-- | src/apscheduler/datastores/sync/memory.py | 81 | ||||
-rw-r--r-- | src/apscheduler/datastores/sync/mongodb.py | 125 | ||||
-rw-r--r-- | src/apscheduler/datastores/sync/sqlalchemy.py | 143 | ||||
-rw-r--r-- | src/apscheduler/enums.py | 1 | ||||
-rw-r--r-- | src/apscheduler/events.py | 10 | ||||
-rw-r--r-- | src/apscheduler/exceptions.py | 5 | ||||
-rw-r--r-- | src/apscheduler/schedulers/async_.py | 65 | ||||
-rw-r--r-- | src/apscheduler/schedulers/sync.py | 41 | ||||
-rw-r--r-- | src/apscheduler/structures.py | 46 | ||||
-rw-r--r-- | src/apscheduler/workers/async_.py | 40 | ||||
-rw-r--r-- | src/apscheduler/workers/sync.py | 17 | ||||
-rw-r--r-- | tests/test_datastores.py | 139 | ||||
-rw-r--r-- | tests/test_schedulers.py | 16 | ||||
-rw-r--r-- | tests/test_workers.py | 47 |
17 files changed, 791 insertions, 221 deletions
diff --git a/src/apscheduler/abc.py b/src/apscheduler/abc.py index ec77de5..19c5a02 100644 --- a/src/apscheduler/abc.py +++ b/src/apscheduler/abc.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, List, Optio from uuid import UUID from .policies import ConflictPolicy -from .structures import Job, JobResult, Schedule +from .structures import Job, JobResult, Schedule, Task if TYPE_CHECKING: from . import events @@ -94,6 +94,44 @@ class DataStore(EventSource): pass @abstractmethod + def add_task(self, task: Task) -> None: + """ + Add the given task to the store. + + If a task with the same ID already exists, it replaces the old one but does NOT affect + task accounting (# of running jobs). + + :param task: the task to be added + """ + + @abstractmethod + def remove_task(self, task_id: str) -> None: + """ + Remove the task with the given ID. + + :param task_id: ID of the task to be removed + :raises TaskLookupError: if no matching task was found + """ + + @abstractmethod + def get_task(self, task_id: str) -> Task: + """ + Get an existing task definition. + + :param task_id: ID of the task to be returned + :return: the matching task + :raises TaskLookupError: if no matching task was found + """ + + @abstractmethod + def get_tasks(self) -> List[Task]: + """ + Get all the tasks in this store. + + :return: a list of tasks, sorted by ID + """ + + @abstractmethod def get_schedules(self, ids: Optional[Set[str]] = None) -> List[Schedule]: """ Get schedules from the data store. @@ -180,12 +218,12 @@ class DataStore(EventSource): """ @abstractmethod - def release_job(self, worker_id: str, job_id: UUID, result: Optional[JobResult]) -> None: + def release_job(self, worker_id: str, job: Job, result: Optional[JobResult]) -> None: """ Release the claim on the given job and record the result. 
:param worker_id: unique identifier of the worker - :param job_id: identifier of the job + :param job: the job to be released :param result: the result of the job (or ``None`` to discard the job) """ @@ -209,6 +247,44 @@ class AsyncDataStore(EventSource): pass @abstractmethod + async def add_task(self, task: Task) -> None: + """ + Add the given task to the store. + + If a task with the same ID already exists, it replaces the old one but does NOT affect + task accounting (# of running jobs). + + :param task: the task to be added + """ + + @abstractmethod + async def remove_task(self, task_id: str) -> None: + """ + Remove the task with the given ID. + + :param task_id: ID of the task to be removed + :raises TaskLookupError: if no matching task was found + """ + + @abstractmethod + async def get_task(self, task_id: str) -> Task: + """ + Get an existing task definition. + + :param task_id: ID of the task to be returned + :return: the matching task + :raises TaskLookupError: if no matching task was found + """ + + @abstractmethod + async def get_tasks(self) -> List[Task]: + """ + Get all the tasks in this store. + + :return: a list of tasks, sorted by ID + """ + + @abstractmethod async def get_schedules(self, ids: Optional[Set[str]] = None) -> List[Schedule]: """ Get schedules from the data store. @@ -295,12 +371,12 @@ class AsyncDataStore(EventSource): """ @abstractmethod - async def release_job(self, worker_id: str, job_id: UUID, result: Optional[JobResult]) -> None: + async def release_job(self, worker_id: str, job: Job, result: Optional[JobResult]) -> None: """ Release the claim on the given job and record the result. 
:param worker_id: unique identifier of the worker - :param job_id: identifier of the job + :param job: the job to be released :param result: the result of the job (or ``None`` to discard the job) """ diff --git a/src/apscheduler/datastores/async_/sqlalchemy.py b/src/apscheduler/datastores/async_/sqlalchemy.py index caf832c..00e9efc 100644 --- a/src/apscheduler/datastores/async_/sqlalchemy.py +++ b/src/apscheduler/datastores/async_/sqlalchemy.py @@ -2,6 +2,7 @@ from __future__ import annotations import json import logging +from collections import defaultdict from contextlib import AsyncExitStack, closing from datetime import datetime, timedelta, timezone from json import JSONDecodeError @@ -18,17 +19,19 @@ from sqlalchemy.exc import CompileError, IntegrityError from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine from sqlalchemy.ext.asyncio.engine import AsyncEngine from sqlalchemy.sql.ddl import DropTable -from sqlalchemy.sql.elements import BindParameter +from sqlalchemy.sql.elements import BindParameter, literal from ... 
import events as events_module from ...abc import AsyncDataStore, Job, Schedule, Serializer from ...events import ( AsyncEventHub, DataStoreEvent, Event, JobAdded, JobDeserializationFailed, ScheduleAdded, - ScheduleDeserializationFailed, ScheduleRemoved, ScheduleUpdated, SubscriptionToken) -from ...exceptions import ConflictingIdError, SerializationError + ScheduleDeserializationFailed, ScheduleRemoved, ScheduleUpdated, SubscriptionToken, TaskAdded, + TaskRemoved) +from ...exceptions import ConflictingIdError, SerializationError, TaskLookupError +from ...marshalling import callable_to_ref from ...policies import ConflictPolicy from ...serializers.pickle import PickleSerializer -from ...structures import JobResult +from ...structures import JobResult, Task from ...util import reentrant logger = logging.getLogger(__name__) @@ -78,6 +81,7 @@ class SQLAlchemyDataStore(AsyncDataStore): # Generate the table definitions self._metadata = self.get_table_definitions() self.t_metadata = self._metadata.tables['metadata'] + self.t_tasks = self._metadata.tables['tasks'] self.t_schedules = self._metadata.tables['schedules'] self.t_jobs = self._metadata.tables['jobs'] self.t_job_results = self._metadata.tables['job_results'] @@ -153,7 +157,11 @@ class SQLAlchemyDataStore(AsyncDataStore): 'tasks', metadata, Column('id', Unicode(500), primary_key=True), - Column('serialized_data', LargeBinary, nullable=False) + Column('func', Unicode(500), nullable=False), + Column('state', LargeBinary), + Column('max_running_jobs', Integer), + Column('misfire_grace_time', Unicode(16)), + Column('running_jobs', Integer, nullable=False, server_default=literal(0)) ) Table( 'schedules', @@ -259,6 +267,55 @@ class SQLAlchemyDataStore(AsyncDataStore): def unsubscribe(self, token: SubscriptionToken) -> None: self._events.unsubscribe(token) + async def add_task(self, task: Task) -> None: + insert = self.t_tasks.insert().\ + values(id=task.id, func=callable_to_ref(task.func), + 
max_running_jobs=task.max_running_jobs, + misfire_grace_time=task.misfire_grace_time) + try: + async with self.engine.begin() as conn: + await conn.execute(insert) + except IntegrityError: + update = self.t_tasks.update().\ + values(func=callable_to_ref(task.func), max_running_jobs=task.max_running_jobs, + misfire_grace_time=task.misfire_grace_time).\ + where(self.t_tasks.c.id == task.id) + async with self.engine.begin() as conn: + await conn.execute(update) + + self._events.publish(TaskAdded(task_id=task.id)) + + async def remove_task(self, task_id: str) -> None: + delete = self.t_tasks.delete().where(self.t_tasks.c.id == task_id) + async with self.engine.begin() as conn: + result = await conn.execute(delete) + if result.rowcount == 0: + raise TaskLookupError(task_id) + else: + self._events.publish(TaskRemoved(task_id=task_id)) + + async def get_task(self, task_id: str) -> Task: + query = select([self.t_tasks.c.id, self.t_tasks.c.func, self.t_tasks.c.max_running_jobs, + self.t_tasks.c.state, self.t_tasks.c.misfire_grace_time]).\ + where(self.t_tasks.c.id == task_id) + async with self.engine.begin() as conn: + result = await conn.execute(query) + row = result.fetch_one() + + if row: + return Task.unmarshal(self.serializer, row._asdict()) + else: + raise TaskLookupError + + async def get_tasks(self) -> List[Task]: + query = select([self.t_tasks.c.id, self.t_tasks.c.func, self.t_tasks.c.max_running_jobs, + self.t_tasks.c.state, self.t_tasks.c.misfire_grace_time]).\ + order_by(self.t_tasks.c.id) + async with self.engine.begin() as conn: + result = await conn.execute(query) + tasks = [Task.unmarshal(self.serializer, row._asdict()) for row in result] + return tasks + async def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None: event: Event serialized_data = self.serializer.serialize(schedule) @@ -436,30 +493,79 @@ class SQLAlchemyDataStore(AsyncDataStore): now = datetime.now(timezone.utc) acquired_until = now + 
timedelta(seconds=self.lock_expiration_delay) query = select([self.t_jobs.c.id, self.t_jobs.c.serialized_data]).\ + join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id).\ where(or_(self.t_jobs.c.acquired_until.is_(None), self.t_jobs.c.acquired_until < now)).\ order_by(self.t_jobs.c.created_at).\ limit(limit) - serialized_jobs: Dict[str, bytes] = {row[0]: row[1] - for row in await conn.execute(query)} - if serialized_jobs: + result = await conn.execute(query) + if not result: + return [] + + # Mark the jobs as acquired by this worker + jobs = self._deserialize_jobs(result) + task_ids: Set[str] = {job.task_id for job in jobs} + + # Retrieve the limits + query = select([self.t_tasks.c.id, + self.t_tasks.c.max_running_jobs - self.t_tasks.c.running_jobs]).\ + where(self.t_tasks.c.max_running_jobs.isnot(None), + self.t_tasks.c.id.in_(task_ids)) + result = await conn.execute(query) + job_slots_left = dict(result.fetchall()) + + # Filter out jobs that don't have free slots + acquired_jobs: List[Job] = [] + increments: Dict[str, int] = defaultdict(lambda: 0) + for job in jobs: + # Don't acquire the job if there are no free slots left + slots_left = job_slots_left.get(job.task_id) + if slots_left == 0: + continue + elif slots_left is not None: + job_slots_left[job.task_id] -= 1 + + acquired_jobs.append(job) + increments[job.task_id] += 1 + + if acquired_jobs: + # Mark the acquired jobs as acquired by this worker + acquired_job_ids = [job.id.hex for job in acquired_jobs] update = self.t_jobs.update().\ values(acquired_by=worker_id, acquired_until=acquired_until).\ - where(self.t_jobs.c.id.in_(serialized_jobs)) + where(self.t_jobs.c.id.in_(acquired_job_ids)) await conn.execute(update) - return self._deserialize_jobs(serialized_jobs.items()) + # Increment the running job counters on each task + p_id = bindparam('p_id') + p_increment = bindparam('p_increment') + params = [{'p_id': task_id, 'p_increment': increment} + for task_id, increment in increments.items()] + 
update = self.t_tasks.update().\ + values(running_jobs=self.t_tasks.c.running_jobs + p_increment).\ + where(self.t_tasks.c.id == p_id) + await conn.execute(update, params) + + return acquired_jobs - async def release_job(self, worker_id: str, job_id: UUID, result: Optional[JobResult]) -> None: + async def release_job(self, worker_id: str, job: Job, result: Optional[JobResult]) -> None: async with self.engine.begin() as conn: + # Insert the job result now = datetime.now(timezone.utc) serialized_data = self.serializer.serialize(result) insert = self.t_job_results.insert().\ - values(job_id=job_id.hex, finished_at=now, serialized_data=serialized_data) + values(job_id=job.id.hex, finished_at=now, serialized_data=serialized_data) await conn.execute(insert) - delete = self.t_jobs.delete().where(self.t_jobs.c.id == job_id.hex) + # Decrement the running jobs counter + update = self.t_tasks.update().\ + values(running_jobs=self.t_tasks.c.running_jobs - 1).\ + where(self.t_tasks.c.id == job.task_id) + await conn.execute(update) + + # Delete the job + delete = self.t_jobs.delete().where(self.t_jobs.c.id == job.id.hex) await conn.execute(delete) async def get_job_result(self, job_id: UUID) -> Optional[JobResult]: diff --git a/src/apscheduler/datastores/async_/sync_adapter.py b/src/apscheduler/datastores/async_/sync_adapter.py index 2bb691e..4f81196 100644 --- a/src/apscheduler/datastores/async_/sync_adapter.py +++ b/src/apscheduler/datastores/async_/sync_adapter.py @@ -13,7 +13,7 @@ from ... 
import events from ...abc import AsyncDataStore, DataStore from ...events import Event, SubscriptionToken from ...policies import ConflictPolicy -from ...structures import Job, JobResult, Schedule +from ...structures import Job, JobResult, Schedule, Task from ...util import reentrant @@ -33,6 +33,18 @@ class AsyncDataStoreAdapter(AsyncDataStore): await to_thread.run_sync(self.original.__exit__, exc_type, exc_val, exc_tb) await self._portal.__aexit__(exc_type, exc_val, exc_tb) + async def add_task(self, task: Task) -> None: + await to_thread.run_sync(self.original.add_task, task) + + async def remove_task(self, task_id: str) -> None: + await to_thread.run_sync(self.original.remove_task, task_id) + + async def get_task(self, task_id: str) -> Task: + return await to_thread.run_sync(self.original.get_task, task_id) + + async def get_tasks(self) -> List[Task]: + return await to_thread.run_sync(self.original.get_tasks) + async def get_schedules(self, ids: Optional[Set[str]] = None) -> List[Schedule]: return await to_thread.run_sync(self.original.get_schedules, ids) @@ -60,8 +72,8 @@ class AsyncDataStoreAdapter(AsyncDataStore): async def acquire_jobs(self, worker_id: str, limit: Optional[int] = None) -> List[Job]: return await to_thread.run_sync(self.original.acquire_jobs, worker_id, limit) - async def release_job(self, worker_id: str, job_id: UUID, result: Optional[JobResult]) -> None: - await to_thread.run_sync(self.original.release_job, worker_id, job_id, result) + async def release_job(self, worker_id: str, job: Job, result: Optional[JobResult]) -> None: + await to_thread.run_sync(self.original.release_job, worker_id, job, result) async def get_job_result(self, job_id: UUID) -> Optional[JobResult]: return await to_thread.run_sync(self.original.get_job_result, job_id) diff --git a/src/apscheduler/datastores/sync/memory.py b/src/apscheduler/datastores/sync/memory.py index a108142..4239dc7 100644 --- a/src/apscheduler/datastores/sync/memory.py +++ 
b/src/apscheduler/datastores/sync/memory.py @@ -12,16 +12,27 @@ import attr from ... import events from ...abc import DataStore, Job, Schedule from ...events import ( - EventHub, JobAdded, ScheduleAdded, ScheduleRemoved, ScheduleUpdated, SubscriptionToken) -from ...exceptions import ConflictingIdError + EventHub, JobAdded, ScheduleAdded, ScheduleRemoved, ScheduleUpdated, SubscriptionToken, + TaskAdded, TaskRemoved) +from ...exceptions import ConflictingIdError, TaskLookupError from ...policies import ConflictPolicy -from ...structures import JobResult +from ...structures import JobResult, Task from ...util import reentrant max_datetime = datetime(MAXYEAR, 12, 31, 23, 59, 59, 999999, tzinfo=timezone.utc) @attr.define +class TaskState: + task: Task + running_jobs: int = 0 + saved_state: Any = None + + def __eq__(self, other): + return self.task.id == other.task.id + + +@attr.define class ScheduleState: schedule: Schedule next_fire_time: Optional[datetime] = attr.field(init=False, eq=False) @@ -65,6 +76,7 @@ class JobState: class MemoryDataStore(DataStore): lock_expiration_delay: float = 30 _events: EventHub = attr.Factory(EventHub) + _tasks: Dict[str, TaskState] = attr.Factory(dict) _schedules: List[ScheduleState] = attr.Factory(list) _schedules_by_id: Dict[str, ScheduleState] = attr.Factory(dict) _schedules_by_task_id: Dict[str, Set[ScheduleState]] = attr.Factory(partial(defaultdict, set)) @@ -101,6 +113,27 @@ class MemoryDataStore(DataStore): return [state.schedule for state in self._schedules if ids is None or state.schedule.id in ids] + def add_task(self, task: Task) -> None: + self._tasks[task.id] = TaskState(task) + self._events.publish(TaskAdded(task_id=task.id)) + + def remove_task(self, task_id: str) -> None: + try: + del self._tasks[task_id] + except KeyError: + raise TaskLookupError(task_id) from None + + self._events.publish(TaskRemoved(task_id=task_id)) + + def get_task(self, task_id: str) -> Task: + try: + return self._tasks[task_id].task + except 
KeyError: + raise TaskLookupError(task_id) from None + + def get_tasks(self) -> List[Task]: + return sorted((state.task for state in self._tasks.values()), key=lambda task: task.id) + def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None: old_state = self._schedules_by_id.get(schedule.id) if old_state is not None: @@ -201,27 +234,49 @@ class MemoryDataStore(DataStore): def acquire_jobs(self, worker_id: str, limit: Optional[int] = None) -> List[Job]: now = datetime.now(timezone.utc) jobs: List[Job] = [] - for state in self._jobs: - if state.acquired_by is not None and state.acquired_until >= now: + for index, job_state in enumerate(self._jobs): + task_state = self._tasks[job_state.job.task_id] + + # Skip already acquired jobs (unless the acquisition lock has expired) + if job_state.acquired_by is not None: + if job_state.acquired_until >= now: + continue + else: + task_state.running_jobs -= 1 + + # Check if the task allows one more job to be started + if (task_state.task.max_running_jobs is not None + and task_state.running_jobs >= task_state.task.max_running_jobs): continue - jobs.append(state.job) - state.acquired_by = worker_id - state.acquired_until = now + timedelta(seconds=self.lock_expiration_delay) + # Mark the job as acquired by this worker + jobs.append(job_state.job) + job_state.acquired_by = worker_id + job_state.acquired_until = now + timedelta(seconds=self.lock_expiration_delay) + + # Increment the number of running jobs for this task + task_state.running_jobs += 1 + + # Exit the loop if enough jobs have been acquired if len(jobs) == limit: break return jobs - def release_job(self, worker_id: str, job_id: UUID, result: Optional[JobResult]) -> None: + def release_job(self, worker_id: str, job: Job, result: Optional[JobResult]) -> None: # Delete the job - state = self._jobs_by_id.pop(job_id) - self._jobs_by_task_id[state.job.task_id].remove(state) - index = self._find_job_index(state) + job_state = 
self._jobs_by_id.pop(job.id) + self._jobs_by_task_id[job.task_id].remove(job_state) + index = self._find_job_index(job_state) del self._jobs[index] + # Decrement the number of running jobs for this task + task_state = self._tasks.get(job.task_id) + if task_state is not None: + task_state.running_jobs -= 1 + # Record the result - self._job_results[job_id] = result + self._job_results[job.id] = result def get_job_result(self, job_id: UUID) -> Optional[JobResult]: return self._job_results.pop(job_id, None) diff --git a/src/apscheduler/datastores/sync/mongodb.py b/src/apscheduler/datastores/sync/mongodb.py index 4d5c0f2..fcf7d4d 100644 --- a/src/apscheduler/datastores/sync/mongodb.py +++ b/src/apscheduler/datastores/sync/mongodb.py @@ -1,11 +1,14 @@ from __future__ import annotations import logging +from collections import defaultdict from contextlib import ExitStack from datetime import datetime, timezone -from typing import Any, Callable, Iterable, List, Optional, Set, Tuple, Type +from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Set, Tuple, Type from uuid import UUID +import attr +import pymongo from pymongo import ASCENDING, DeleteOne, MongoClient, UpdateOne from pymongo.collection import Collection from pymongo.errors import DuplicateKeyError @@ -14,19 +17,25 @@ from ... 
import events from ...abc import DataStore, Job, Schedule, Serializer from ...events import ( DataStoreEvent, EventHub, JobAdded, ScheduleAdded, ScheduleRemoved, ScheduleUpdated, - SubscriptionToken) -from ...exceptions import ConflictingIdError, DeserializationError, SerializationError + SubscriptionToken, TaskAdded, TaskRemoved) +from ...exceptions import ( + ConflictingIdError, DeserializationError, SerializationError, TaskLookupError) from ...policies import ConflictPolicy from ...serializers.pickle import PickleSerializer -from ...structures import JobResult +from ...structures import JobResult, Task from ...util import reentrant @reentrant class MongoDBDataStore(DataStore): + _task_attrs: ClassVar[List[str]] = [field.name for field in attr.fields(Task)] + _schedule_attrs: ClassVar[List[str]] = [field.name for field in attr.fields(Schedule)] + _job_attrs: ClassVar[List[str]] = [field.name for field in attr.fields(Job)] + def __init__(self, client: MongoClient, *, serializer: Optional[Serializer] = None, - database: str = 'apscheduler', schedules_collection: str = 'schedules', - jobs_collection: str = 'jobs', job_results_collection: str = 'job_results', + database: str = 'apscheduler', tasks_collection: str = 'tasks', + schedules_collection: str = 'schedules', jobs_collection: str = 'jobs', + job_results_collection: str = 'job_results', lock_expiration_delay: float = 30, start_from_scratch: bool = False): super().__init__() if not client.delegate.codec_options.tz_aware: @@ -36,7 +45,9 @@ class MongoDBDataStore(DataStore): self.serializer = serializer or PickleSerializer() self.lock_expiration_delay = lock_expiration_delay self.start_from_scratch = start_from_scratch + self._local_tasks: Dict[str, Task] = {} self._database = client[database] + self._tasks: Collection = self._database[tasks_collection] self._schedules: Collection = self._database[schedules_collection] self._jobs: Collection = self._database[jobs_collection] self._jobs_results: Collection = 
self._database[job_results_collection] @@ -59,6 +70,7 @@ class MongoDBDataStore(DataStore): self._exit_stack.enter_context(self._events) if self.start_from_scratch: + self._tasks.delete_many({}) self._schedules.delete_many({}) self._jobs.delete_many({}) self._jobs_results.delete_many({}) @@ -80,6 +92,44 @@ class MongoDBDataStore(DataStore): def unsubscribe(self, token: events.SubscriptionToken) -> None: self._events.unsubscribe(token) + def add_task(self, task: Task) -> None: + self._tasks.find_one_and_update( + {'_id': task.id}, + {'$set': task.marshal(self.serializer), + '$setOnInsert': {'running_jobs': 0}}, + upsert=True + ) + self._local_tasks[task.id] = task + self._events.publish(TaskAdded(task_id=task.id)) + + def remove_task(self, task_id: str) -> None: + if not self._tasks.find_one_and_delete({'_id': task_id}): + raise TaskLookupError(task_id) + + del self._local_tasks[task_id] + self._events.publish(TaskRemoved(task_id=task_id)) + + def get_task(self, task_id: str) -> Task: + try: + return self._local_tasks[task_id] + except KeyError: + document = self._tasks.find_one({'_id': task_id}, projection=self._task_attrs) + if not document: + raise TaskLookupError(task_id) + + document['id'] = document.pop('id') + task = self._local_tasks[task_id] = Task.unmarshal(self.serializer, document) + return task + + def get_tasks(self) -> List[Task]: + tasks: List[Task] = [] + for document in self._tasks.find(projection=self._task_attrs, + sort=[('_id', pymongo.ASCENDING)]): + document['id'] = document.pop('_id') + tasks.append(Task.unmarshal(self.serializer, document)) + + return tasks + def get_schedules(self, ids: Optional[Set[str]] = None) -> List[Schedule]: schedules: List[Schedule] = [] filters = {'_id': {'$in': list(ids)}} if ids is not None else {} @@ -238,45 +288,82 @@ class MongoDBDataStore(DataStore): return jobs def acquire_jobs(self, worker_id: str, limit: Optional[int] = None) -> List[Job]: - jobs: List[Job] = [] with self.client.start_session() as session: 
cursor = self._jobs.find( {'$or': [{'acquired_until': {'$exists': False}}, {'acquired_until': {'$lt': datetime.now(timezone.utc)}}] }, - projection=['serialized_data'], + projection=['task_id', 'serialized_data'], sort=[('created_at', ASCENDING)], limit=limit, session=session ) - for document in cursor: + documents = list(cursor) + + # Retrieve the limits + task_ids: Set[str] = {document['task_id'] for document in documents} + task_limits = self._tasks.find( + {'_id': {'$in': list(task_ids)}, 'max_running_jobs': {'$ne': None}}, + projection=['max_running_jobs', 'running_jobs'], + session=session + ) + job_slots_left = {doc['_id']: doc['max_running_jobs'] - doc['running_jobs'] + for doc in task_limits} + + # Filter out jobs that don't have free slots + acquired_jobs: List[Job] = [] + increments: Dict[str, int] = defaultdict(lambda: 0) + for document in documents: job = self.serializer.deserialize(document['serialized_data']) - jobs.append(job) - if jobs: + # Don't acquire the job if there are no free slots left + slots_left = job_slots_left.get(job.task_id) + if slots_left == 0: + continue + elif slots_left is not None: + job_slots_left[job.task_id] -= 1 + + acquired_jobs.append(job) + increments[job.task_id] += 1 + + if acquired_jobs: now = datetime.now(timezone.utc) acquired_until = datetime.fromtimestamp( now.timestamp() + self.lock_expiration_delay, timezone.utc) - filters = {'_id': {'$in': [job.id for job in jobs]}} + filters = {'_id': {'$in': [job.id for job in acquired_jobs]}} update = {'$set': {'acquired_by': worker_id, 'acquired_until': acquired_until}} self._jobs.update_many(filters, update, session=session) - return jobs + # Increment the running job counters on each task + for task_id, increment in increments.items(): + self._tasks.find_one_and_update( + {'_id': task_id}, + {'$inc': {'running_jobs': increment}}, + session=session + ) + + return acquired_jobs - def release_job(self, worker_id: str, job_id: UUID, result: Optional[JobResult]) -> None: + def 
release_job(self, worker_id: str, job: Job, result: Optional[JobResult]) -> None: with self.client.start_session() as session: + # Insert the job result now = datetime.now(timezone.utc) - serialized_result = self.serializer.serialize(result) document = { - '_id': job_id, + '_id': job.id, 'finished_at': now, - 'serialized_data': serialized_result + 'serialized_data': self.serializer.serialize(result) } self._jobs_results.insert_one(document, session=session) - filters = {'_id': job_id, 'acquired_by': worker_id} - self._jobs.delete_one(filters, session=session) + # Decrement the running jobs counter + self._tasks.find_one_and_update( + {'_id': job.task_id}, + {'$inc': {'running_jobs': -1}} + ) + + # Delete the job + self._jobs.delete_one({'_id': job.id}, session=session) def get_job_result(self, job_id: UUID) -> Optional[JobResult]: document = self._jobs_results.find_one_and_delete( diff --git a/src/apscheduler/datastores/sync/sqlalchemy.py b/src/apscheduler/datastores/sync/sqlalchemy.py index f1f541a..3b1d2bf 100644 --- a/src/apscheduler/datastores/sync/sqlalchemy.py +++ b/src/apscheduler/datastores/sync/sqlalchemy.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +from collections import defaultdict from datetime import datetime, timedelta, timezone from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union from uuid import UUID @@ -11,16 +12,18 @@ from sqlalchemy.engine import URL from sqlalchemy.exc import CompileError, IntegrityError from sqlalchemy.future import Engine, create_engine from sqlalchemy.sql.ddl import DropTable -from sqlalchemy.sql.elements import BindParameter +from sqlalchemy.sql.elements import BindParameter, literal from ...abc import DataStore, Job, Schedule, Serializer from ...events import ( Event, EventHub, JobAdded, JobDeserializationFailed, ScheduleAdded, - ScheduleDeserializationFailed, ScheduleRemoved, ScheduleUpdated, SubscriptionToken) -from ...exceptions import ConflictingIdError, 
SerializationError + ScheduleDeserializationFailed, ScheduleRemoved, ScheduleUpdated, SubscriptionToken, TaskAdded, + TaskRemoved) +from ...exceptions import ConflictingIdError, SerializationError, TaskLookupError +from ...marshalling import callable_to_ref from ...policies import ConflictPolicy from ...serializers.pickle import PickleSerializer -from ...structures import JobResult +from ...structures import JobResult, Task from ...util import reentrant logger = logging.getLogger(__name__) @@ -45,6 +48,7 @@ class SQLAlchemyDataStore(DataStore): # Generate the table definitions self._metadata = self.get_table_definitions() self.t_metadata = self._metadata.tables['metadata'] + self.t_tasks = self._metadata.tables['tasks'] self.t_schedules = self._metadata.tables['schedules'] self.t_jobs = self._metadata.tables['jobs'] self.t_job_results = self._metadata.tables['job_results'] @@ -58,6 +62,15 @@ class SQLAlchemyDataStore(DataStore): else: self._supports_update_returning = True + # Find out if the dialect supports INSERT...ON DUPLICATE KEY UPDATE + insert = self.t_jobs.update().returning(self.t_schedules.c.id) + try: + insert.compile(bind=self.engine) + except CompileError: + self._supports_update_returning = False + else: + self._supports_update_returning = True + @classmethod def from_url(cls, url: Union[str, URL], **options) -> 'SQLAlchemyDataStore': engine = create_engine(url) @@ -103,7 +116,11 @@ class SQLAlchemyDataStore(DataStore): 'tasks', metadata, Column('id', Unicode(500), primary_key=True), - Column('serialized_data', LargeBinary, nullable=False) + Column('func', Unicode(500), nullable=False), + Column('state', LargeBinary), + Column('max_running_jobs', Integer), + Column('misfire_grace_time', Unicode(16)), + Column('running_jobs', Integer, nullable=False, server_default=literal(0)) ) Table( 'schedules', @@ -163,6 +180,55 @@ class SQLAlchemyDataStore(DataStore): def unsubscribe(self, token: SubscriptionToken) -> None: self._events.unsubscribe(token) + def 
add_task(self, task: Task) -> None: + insert = self.t_tasks.insert().\ + values(id=task.id, func=callable_to_ref(task.func), + max_running_jobs=task.max_running_jobs, + misfire_grace_time=task.misfire_grace_time) + try: + with self.engine.begin() as conn: + conn.execute(insert) + except IntegrityError: + update = self.t_tasks.update().\ + values(func=callable_to_ref(task.func), max_running_jobs=task.max_running_jobs, + misfire_grace_time=task.misfire_grace_time).\ + where(self.t_tasks.c.id == task.id) + with self.engine.begin() as conn: + conn.execute(update) + + self._events.publish(TaskAdded(task_id=task.id)) + + def remove_task(self, task_id: str) -> None: + delete = self.t_tasks.delete().where(self.t_tasks.c.id == task_id) + with self.engine.begin() as conn: + result = conn.execute(delete) + if result.rowcount == 0: + raise TaskLookupError(task_id) + else: + self._events.publish(TaskRemoved(task_id=task_id)) + + def get_task(self, task_id: str) -> Task: + query = select([self.t_tasks.c.id, self.t_tasks.c.func, self.t_tasks.c.max_running_jobs, + self.t_tasks.c.state, self.t_tasks.c.misfire_grace_time]).\ + where(self.t_tasks.c.id == task_id) + with self.engine.begin() as conn: + result = conn.execute(query) + row = result.fetch_one() + + if row: + return Task.unmarshal(self.serializer, row._asdict()) + else: + raise TaskLookupError + + def get_tasks(self) -> List[Task]: + query = select([self.t_tasks.c.id, self.t_tasks.c.func, self.t_tasks.c.max_running_jobs, + self.t_tasks.c.state, self.t_tasks.c.misfire_grace_time]).\ + order_by(self.t_tasks.c.id) + with self.engine.begin() as conn: + result = conn.execute(query) + tasks = [Task.unmarshal(self.serializer, row._asdict()) for row in result] + return tasks + def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None: event: Event serialized_data = self.serializer.serialize(schedule) @@ -339,38 +405,89 @@ class SQLAlchemyDataStore(DataStore): now = datetime.now(timezone.utc) acquired_until 
= now + timedelta(seconds=self.lock_expiration_delay) query = select([self.t_jobs.c.id, self.t_jobs.c.serialized_data]).\ + join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id).\ where(or_(self.t_jobs.c.acquired_until.is_(None), self.t_jobs.c.acquired_until < now)).\ order_by(self.t_jobs.c.created_at).\ limit(limit) - serialized_jobs: Dict[str, bytes] = {row[0]: row[1] - for row in conn.execute(query)} - if serialized_jobs: + result = conn.execute(query) + if not result: + return [] + + # Mark the jobs as acquired by this worker + jobs = self._deserialize_jobs(result) + task_ids: Set[str] = {job.task_id for job in jobs} + + # Retrieve the limits + query = select([self.t_tasks.c.id, + self.t_tasks.c.max_running_jobs - self.t_tasks.c.running_jobs]).\ + where(self.t_tasks.c.max_running_jobs.isnot(None), + self.t_tasks.c.id.in_(task_ids)) + result = conn.execute(query) + job_slots_left = dict(result.fetchall()) + + # Filter out jobs that don't have free slots + acquired_jobs: List[Job] = [] + increments: Dict[str, int] = defaultdict(lambda: 0) + for job in jobs: + # Don't acquire the job if there are no free slots left + slots_left = job_slots_left.get(job.task_id) + if slots_left == 0: + continue + elif slots_left is not None: + job_slots_left[job.task_id] -= 1 + + acquired_jobs.append(job) + increments[job.task_id] += 1 + + if acquired_jobs: + # Mark the acquired jobs as acquired by this worker + acquired_job_ids = [job.id.hex for job in acquired_jobs] update = self.t_jobs.update().\ values(acquired_by=worker_id, acquired_until=acquired_until).\ - where(self.t_jobs.c.id.in_(serialized_jobs)) + where(self.t_jobs.c.id.in_(acquired_job_ids)) conn.execute(update) - return self._deserialize_jobs(serialized_jobs.items()) + # Increment the running job counters on each task + p_id = bindparam('p_id') + p_increment = bindparam('p_increment') + params = [{'p_id': task_id, 'p_increment': increment} + for task_id, increment in increments.items()] + update = 
self.t_tasks.update().\ + values(running_jobs=self.t_tasks.c.running_jobs + p_increment).\ + where(self.t_tasks.c.id == p_id) + conn.execute(update, params) + + return acquired_jobs - def release_job(self, worker_id: str, job_id: UUID, result: Optional[JobResult]) -> None: + def release_job(self, worker_id: str, job: Job, result: Optional[JobResult]) -> None: with self.engine.begin() as conn: + # Insert the job result now = datetime.now(timezone.utc) serialized_result = self.serializer.serialize(result) insert = self.t_job_results.insert().\ - values(job_id=job_id.hex, finished_at=now, serialized_data=serialized_result) + values(job_id=job.id.hex, finished_at=now, serialized_data=serialized_result) conn.execute(insert) - delete = self.t_jobs.delete().where(self.t_jobs.c.id == job_id.hex) + # Decrement the running jobs counter + update = self.t_tasks.update().\ + values(running_jobs=self.t_tasks.c.running_jobs - 1).\ + where(self.t_tasks.c.id == job.task_id) + conn.execute(update) + + # Delete the job + delete = self.t_jobs.delete().where(self.t_jobs.c.id == job.id.hex) conn.execute(delete) def get_job_result(self, job_id: UUID) -> Optional[JobResult]: with self.engine.begin() as conn: + # Retrieve the result query = select(self.t_job_results.c.serialized_data).\ where(self.t_job_results.c.job_id == job_id.hex) result = conn.execute(query) + # Delete the result delete = self.t_job_results.delete().\ where(self.t_job_results.c.job_id == job_id.hex) conn.execute(delete) diff --git a/src/apscheduler/enums.py b/src/apscheduler/enums.py index a7c0d69..941de3a 100644 --- a/src/apscheduler/enums.py +++ b/src/apscheduler/enums.py @@ -13,3 +13,4 @@ class JobOutcome(Enum): failure = auto() missed_start_deadline = auto() cancelled = auto() + expired = auto() diff --git a/src/apscheduler/events.py b/src/apscheduler/events.py index 552d30b..6c0d270 100644 --- a/src/apscheduler/events.py +++ b/src/apscheduler/events.py @@ -45,6 +45,16 @@ class DataStoreEvent(Event): 
@attr.define(kw_only=True, frozen=True) +class TaskAdded(DataStoreEvent): + task_id: str + + +@attr.define(kw_only=True, frozen=True) +class TaskRemoved(DataStoreEvent): + task_id: str + + +@attr.define(kw_only=True, frozen=True) class ScheduleAdded(DataStoreEvent): schedule_id: str next_fire_time: Optional[datetime] = attr.field(converter=timestamp_to_datetime) diff --git a/src/apscheduler/exceptions.py b/src/apscheduler/exceptions.py index ec04f2c..1f90ba0 100644 --- a/src/apscheduler/exceptions.py +++ b/src/apscheduler/exceptions.py @@ -1,3 +1,8 @@ +class TaskLookupError(LookupError): + """Raised by a data store when it cannot find the requested task.""" + + def __init__(self, task_id: str): + super().__init__(f'No task by the id of {task_id!r} was found') class JobLookupError(KeyError): diff --git a/src/apscheduler/schedulers/async_.py b/src/apscheduler/schedulers/async_.py index bb1760d..708d2ec 100644 --- a/src/apscheduler/schedulers/async_.py +++ b/src/apscheduler/schedulers/async_.py @@ -5,7 +5,7 @@ import platform from contextlib import AsyncExitStack from datetime import datetime, timedelta, timezone from logging import Logger, getLogger -from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Type, Union +from typing import Any, Callable, Iterable, Mapping, Optional, Type, Union from uuid import uuid4 import anyio @@ -41,7 +41,6 @@ class AsyncScheduler(EventSource): self.identity = identity or f'{platform.node()}-{os.getpid()}-{id(self)}' self.logger = logger or getLogger(__name__) self.start_worker = start_worker - self._tasks: Dict[str, Task] = {} self._exit_stack = AsyncExitStack() self._events = AsyncEventHub() @@ -96,27 +95,27 @@ class AsyncScheduler(EventSource): def unsubscribe(self, token: SubscriptionToken) -> None: self._events.unsubscribe(token) - def _get_taskdef(self, func_or_id: Union[str, Callable]) -> Task: - task_id = func_or_id if isinstance(func_or_id, str) else callable_to_ref(func_or_id) - taskdef = 
self._tasks.get(task_id) - if not taskdef: - if isinstance(func_or_id, str): - raise LookupError('no task found with ID {!r}'.format(func_or_id)) - else: - taskdef = self._tasks[task_id] = Task(id=task_id, func=func_or_id) - - return taskdef - - def define_task(self, func: Callable, task_id: Optional[str] = None, **kwargs): - if task_id is None: - task_id = callable_to_ref(func) - - task = Task(id=task_id, **kwargs) - if self._tasks.setdefault(task_id, task) is not task: - pass + # def _get_taskdef(self, func_or_id: Union[str, Callable]) -> Task: + # task_id = func_or_id if isinstance(func_or_id, str) else callable_to_ref(func_or_id) + # taskdef = self._tasks.get(task_id) + # if not taskdef: + # if isinstance(func_or_id, str): + # raise LookupError('no task found with ID {!r}'.format(func_or_id)) + # else: + # taskdef = self._tasks[task_id] = Task(id=task_id, func=func_or_id) + # + # return taskdef + # + # def define_task(self, func: Callable, task_id: Optional[str] = None, **kwargs): + # if task_id is None: + # task_id = callable_to_ref(func) + # + # task = Task(id=task_id, **kwargs) + # if self._tasks.setdefault(task_id, task) is not task: + # pass async def add_schedule( - self, task: Union[str, Callable], trigger: Trigger, *, id: Optional[str] = None, + self, func_or_task_id: Union[str, Callable], trigger: Trigger, *, id: Optional[str] = None, args: Optional[Iterable] = None, kwargs: Optional[Mapping[str, Any]] = None, coalesce: CoalescePolicy = CoalescePolicy.latest, misfire_grace_time: Union[float, timedelta, None] = None, @@ -130,12 +129,17 @@ class AsyncScheduler(EventSource): if isinstance(misfire_grace_time, (int, float)): misfire_grace_time = timedelta(seconds=misfire_grace_time) - taskdef = self._get_taskdef(task) - schedule = Schedule(id=id, task_id=taskdef.id, trigger=trigger, args=args, kwargs=kwargs, + if callable(func_or_task_id): + task = Task(id=callable_to_ref(func_or_task_id), func=func_or_task_id) + await self.data_store.add_task(task) + else: 
+ task = await self.data_store.get_task(func_or_task_id) + + schedule = Schedule(id=id, task_id=task.id, trigger=trigger, args=args, kwargs=kwargs, coalesce=coalesce, misfire_grace_time=misfire_grace_time, tags=tags) schedule.next_fire_time = trigger.next() await self.data_store.add_schedule(schedule, conflict_policy) - self.logger.info('Added new schedule (task=%r, trigger=%r); next run time at %s', taskdef, + self.logger.info('Added new schedule (task=%r, trigger=%r); next run time at %s', task, trigger, schedule.next_fire_time) return schedule.id @@ -157,15 +161,6 @@ class AsyncScheduler(EventSource): schedules = await self.data_store.acquire_schedules(self.identity, 100) now = datetime.now(timezone.utc) for schedule in schedules: - # Look up the task definition - try: - taskdef = self._get_taskdef(schedule.task_id) - except LookupError: - self.logger.error('Cannot locate task definition %r for schedule %r – ' - 'removing schedule', schedule.task_id, schedule.id) - schedule.next_fire_time = None - continue - # Calculate a next fire time for the schedule, if possible fire_times = [schedule.next_fire_time] calculate_next = schedule.trigger.next @@ -175,7 +170,7 @@ class AsyncScheduler(EventSource): except Exception: self.logger.exception( 'Error computing next fire time for schedule %r of task %r – ' - 'removing schedule', schedule.id, taskdef.id) + 'removing schedule', schedule.id, schedule.task_id) break # Stop if the calculated fire time is in the future @@ -192,7 +187,7 @@ class AsyncScheduler(EventSource): # Add one or more jobs to the job queue for fire_time in fire_times: schedule.last_fire_time = fire_time - job = Job(task_id=taskdef.id, func=taskdef.func, args=schedule.args, + job = Job(task_id=schedule.task_id, args=schedule.args, kwargs=schedule.kwargs, schedule_id=schedule.id, scheduled_fire_time=fire_time, start_deadline=schedule.next_deadline, tags=schedule.tags) diff --git a/src/apscheduler/schedulers/sync.py b/src/apscheduler/schedulers/sync.py 
index 3005d27..3d86b25 100644 --- a/src/apscheduler/schedulers/sync.py +++ b/src/apscheduler/schedulers/sync.py @@ -7,7 +7,7 @@ from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait from contextlib import ExitStack from datetime import datetime, timedelta, timezone from logging import Logger, getLogger -from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Type, Union +from typing import Any, Callable, Iterable, Mapping, Optional, Type, Union from uuid import uuid4 from ..abc import DataStore, EventSource, Trigger @@ -35,7 +35,6 @@ class Scheduler(EventSource): self.logger = logger or getLogger(__name__) self.start_worker = start_worker self.data_store = data_store or MemoryDataStore() - self._tasks: Dict[str, Task] = {} self._exit_stack = ExitStack() self._executor = ThreadPoolExecutor(max_workers=1) self._events = EventHub() @@ -97,19 +96,8 @@ class Scheduler(EventSource): def unsubscribe(self, token: SubscriptionToken) -> None: self._events.unsubscribe(token) - def _get_taskdef(self, func_or_id: Union[str, Callable]) -> Task: - task_id = func_or_id if isinstance(func_or_id, str) else callable_to_ref(func_or_id) - taskdef = self._tasks.get(task_id) - if not taskdef: - if isinstance(func_or_id, str): - raise LookupError('no task found with ID {!r}'.format(func_or_id)) - else: - taskdef = self._tasks[task_id] = Task(id=task_id, func=func_or_id) - - return taskdef - def add_schedule( - self, task: Union[str, Callable], trigger: Trigger, *, id: Optional[str] = None, + self, func_or_task_id: Union[str, Callable], trigger: Trigger, *, id: Optional[str] = None, args: Optional[Iterable] = None, kwargs: Optional[Mapping[str, Any]] = None, coalesce: CoalescePolicy = CoalescePolicy.latest, misfire_grace_time: Union[float, timedelta, None] = None, @@ -123,12 +111,17 @@ class Scheduler(EventSource): if isinstance(misfire_grace_time, (int, float)): misfire_grace_time = timedelta(seconds=misfire_grace_time) - taskdef = 
self._get_taskdef(task) - schedule = Schedule(id=id, task_id=taskdef.id, trigger=trigger, args=args, kwargs=kwargs, + if callable(func_or_task_id): + task = Task(id=callable_to_ref(func_or_task_id), func=func_or_task_id) + self.data_store.add_task(task) + else: + task = self.data_store.get_task(func_or_task_id) + + schedule = Schedule(id=id, task_id=task.id, trigger=trigger, args=args, kwargs=kwargs, coalesce=coalesce, misfire_grace_time=misfire_grace_time, tags=tags) schedule.next_fire_time = trigger.next() self.data_store.add_schedule(schedule, conflict_policy) - self.logger.info('Added new schedule (task=%r, trigger=%r); next run time at %s', taskdef, + self.logger.info('Added new schedule (task=%r, trigger=%r); next run time at %s', task, trigger, schedule.next_fire_time) return schedule.id @@ -149,16 +142,6 @@ class Scheduler(EventSource): schedules = self.data_store.acquire_schedules(self.identity, 100) now = datetime.now(timezone.utc) for schedule in schedules: - # Look up the task definition - try: - taskdef = self._get_taskdef(schedule.task_id) - except LookupError: - self.logger.error('Cannot locate task definition %r for schedule %r – ' - 'putting schedule on hold', schedule.task_id, - schedule.id) - schedule.next_fire_time = None - continue - # Calculate a next fire time for the schedule, if possible fire_times = [schedule.next_fire_time] calculate_next = schedule.trigger.next @@ -168,7 +151,7 @@ class Scheduler(EventSource): except Exception: self.logger.exception( 'Error computing next fire time for schedule %r of task %r – ' - 'removing schedule', schedule.id, taskdef.id) + 'removing schedule', schedule.id, schedule.task_id) break # Stop if the calculated fire time is in the future @@ -185,7 +168,7 @@ class Scheduler(EventSource): # Add one or more jobs to the job queue for fire_time in fire_times: schedule.last_fire_time = fire_time - job = Job(task_id=taskdef.id, func=taskdef.func, args=schedule.args, + job = Job(task_id=schedule.task_id, 
args=schedule.args, kwargs=schedule.kwargs, schedule_id=schedule.id, scheduled_fire_time=fire_time, start_deadline=schedule.next_deadline, tags=schedule.tags) diff --git a/src/apscheduler/structures.py b/src/apscheduler/structures.py index 4950818..2f97ca6 100644 --- a/src/apscheduler/structures.py +++ b/src/apscheduler/structures.py @@ -8,6 +8,7 @@ import attr from . import abc from .enums import JobOutcome +from .marshalling import callable_from_ref, callable_to_ref from .policies import CoalescePolicy @@ -15,11 +16,24 @@ from .policies import CoalescePolicy class Task: id: str func: Callable = attr.field(eq=False, order=False) - max_instances: Optional[int] = attr.field(eq=False, order=False, default=None) - metadata_arg: Optional[str] = attr.field(eq=False, order=False, default=None) - stateful: bool = attr.field(eq=False, order=False, default=False) + max_running_jobs: Optional[int] = attr.field(eq=False, order=False, default=None) + state: Any = None misfire_grace_time: Optional[timedelta] = attr.field(eq=False, order=False, default=None) + def marshal(self, serializer: abc.Serializer) -> Dict[str, Any]: + marshalled = attr.asdict(self) + marshalled['func'] = callable_to_ref(self.func) + marshalled['state'] = serializer.serialize(self.state) if self.state else None + return marshalled + + @classmethod + def unmarshal(cls, serializer: abc.Serializer, marshalled: Dict[str, Any]) -> Task: + marshalled['func'] = callable_from_ref(marshalled['func']) + if marshalled['state'] is not None: + marshalled['state'] = serializer.deserialize(marshalled['state']) + + return cls(**marshalled) + @attr.define(kw_only=True) class Schedule: @@ -36,6 +50,15 @@ class Schedule: last_fire_time: Optional[datetime] = attr.field(eq=False, order=False, init=False, default=None) + def marshal(self, serializer: abc.Serializer) -> Dict[str, Any]: + marshalled = attr.asdict(self) + marshalled['trigger_type'] = serializer.serialize(self.args) + marshalled['trigger_data'] = 
serializer.serialize(self.trigger) + marshalled['args'] = serializer.serialize(self.args) if self.args else None + marshalled['kwargs'] = serializer.serialize(self.kwargs) if self.kwargs else None + marshalled['tags'] = list(self.tags) + return marshalled + @property def next_deadline(self) -> Optional[datetime]: if self.next_fire_time and self.misfire_grace_time: @@ -48,7 +71,6 @@ class Schedule: class Job: id: UUID = attr.field(factory=uuid4) task_id: str = attr.field(eq=False, order=False) - func: Callable = attr.field(eq=False, order=False) args: tuple = attr.field(eq=False, order=False, default=()) kwargs: Dict[str, Any] = attr.field(eq=False, order=False, factory=dict) schedule_id: Optional[str] = attr.field(eq=False, order=False, default=None) @@ -57,6 +79,22 @@ class Job: tags: FrozenSet[str] = attr.field(eq=False, order=False, factory=frozenset) started_at: Optional[datetime] = attr.field(eq=False, order=False, init=False, default=None) + def marshal(self, serializer: abc.Serializer) -> Dict[str, Any]: + marshalled = attr.asdict(self) + marshalled['args'] = serializer.serialize(self.args) if self.args else None + marshalled['kwargs'] = serializer.serialize(self.kwargs) if self.kwargs else None + marshalled['tags'] = list(self.tags) + return marshalled + + @classmethod + def unmarshal(cls, serializer: abc.Serializer, marshalled: Dict[str, Any]) -> Task: + for key in ('args', 'kwargs'): + if marshalled[key] is not None: + marshalled[key] = serializer.deserialize(marshalled[key]) + + marshalled['tags'] = frozenset(marshalled['tags']) + return cls(**marshalled) + @attr.define(eq=False, order=False, frozen=True) class JobResult: diff --git a/src/apscheduler/workers/async_.py b/src/apscheduler/workers/async_.py index 79b0143..dbc72b3 100644 --- a/src/apscheduler/workers/async_.py +++ b/src/apscheduler/workers/async_.py @@ -11,7 +11,7 @@ from uuid import UUID import anyio from anyio import TASK_STATUS_IGNORED, create_task_group, get_cancelled_exc_class -from 
anyio.abc import CancelScope, TaskGroup +from anyio.abc import CancelScope from ..abc import AsyncDataStore, DataStore, EventSource, Job from ..datastores.async_.sync_adapter import AsyncDataStoreAdapter @@ -25,7 +25,6 @@ from ..structures import JobResult class AsyncWorker(EventSource): """Runs jobs locally in a task group.""" - _task_group: Optional[TaskGroup] = None _stop_event: Optional[anyio.Event] = None _state: RunState = RunState.stopped _acquire_cancel_scope: Optional[CancelScope] = None @@ -72,16 +71,15 @@ class AsyncWorker(EventSource): self._exit_stack.callback(self.data_store.unsubscribe, wakeup_token) # Start the actual worker - self._task_group = create_task_group() - await self._exit_stack.enter_async_context(self._task_group) - await self._task_group.start(self.run) + task_group = create_task_group() + await self._exit_stack.enter_async_context(task_group) + await task_group.start(self.run) return self async def __aexit__(self, exc_type, exc_val, exc_tb): self._state = RunState.stopping self._wakeup_event.set() await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) - del self._task_group del self._wakeup_event def subscribe(self, callback: Callable[[Event], Any], @@ -102,15 +100,17 @@ class AsyncWorker(EventSource): self._events.publish(WorkerStarted()) try: - while self._state is RunState.started: - limit = self.max_concurrent_jobs - len(self._running_jobs) - jobs = await self.data_store.acquire_jobs(self.identity, limit) - for job in jobs: - self._running_jobs.add(job.id) - self._task_group.start_soon(self._run_job, job) - - await self._wakeup_event.wait() - self._wakeup_event = anyio.Event() + async with create_task_group() as tg: + while self._state is RunState.started: + limit = self.max_concurrent_jobs - len(self._running_jobs) + jobs = await self.data_store.acquire_jobs(self.identity, limit) + for job in jobs: + task = await self.data_store.get_task(job.task_id) + self._running_jobs.add(job.id) + tg.start_soon(self._run_job, job, 
task.func) + + await self._wakeup_event.wait() + self._wakeup_event = anyio.Event() except get_cancelled_exc_class(): pass except BaseException as exc: @@ -121,7 +121,7 @@ class AsyncWorker(EventSource): self._state = RunState.stopped self._events.publish(WorkerStopped()) - async def _run_job(self, job: Job) -> None: + async def _run_job(self, job: Job, func: Callable) -> None: try: # Check if the job started before the deadline start_time = datetime.now(timezone.utc) @@ -131,24 +131,24 @@ class AsyncWorker(EventSource): self._events.publish(JobStarted.from_job(job, start_time)) try: - retval = job.func(*job.args, **job.kwargs) + retval = func(*job.args, **job.kwargs) if isawaitable(retval): retval = await retval except get_cancelled_exc_class(): with CancelScope(shield=True): result = JobResult(outcome=JobOutcome.cancelled) - await self.data_store.release_job(self.identity, job.id, result) + await self.data_store.release_job(self.identity, job, result) self._events.publish(JobCancelled.from_job(job, start_time)) except BaseException as exc: result = JobResult(outcome=JobOutcome.failure, exception=exc) - await self.data_store.release_job(self.identity, job.id, result) + await self.data_store.release_job(self.identity, job, result) self._events.publish(JobFailed.from_exception(job, start_time, exc)) if not isinstance(exc, Exception): raise else: result = JobResult(outcome=JobOutcome.success, return_value=retval) - await self.data_store.release_job(self.identity, job.id, result) + await self.data_store.release_job(self.identity, job, result) self._events.publish(JobCompleted.from_retval(job, start_time, retval)) finally: self._running_jobs.remove(job.id) diff --git a/src/apscheduler/workers/sync.py b/src/apscheduler/workers/sync.py index 9412b1c..1fb1cba 100644 --- a/src/apscheduler/workers/sync.py +++ b/src/apscheduler/workers/sync.py @@ -22,6 +22,7 @@ from ..structures import Job, JobResult class Worker(EventSource): """Runs jobs locally in a thread pool.""" + 
_executor: ThreadPoolExecutor _state: RunState = RunState.stopped _wakeup_event: threading.Event @@ -32,7 +33,6 @@ class Worker(EventSource): self.logger = logger or getLogger(__name__) self._acquired_jobs: Set[Job] = set() self._exit_stack = ExitStack() - self._executor = ThreadPoolExecutor(max_workers=max_concurrent_jobs + 1) self._events = EventHub() self._running_jobs: Set[UUID] = set() @@ -64,6 +64,7 @@ class Worker(EventSource): # Start the worker and return when it has signalled readiness or raised an exception start_future: Future[None] = Future() token = self._events.subscribe(start_future.set_result) + self._executor = ThreadPoolExecutor(1) run_future = self._executor.submit(self.run) try: wait([start_future, run_future], return_when=FIRST_COMPLETED) @@ -98,26 +99,30 @@ class Worker(EventSource): self._state = RunState.started self._events.publish(WorkerStarted()) + executor = ThreadPoolExecutor(max_workers=self.max_concurrent_jobs) try: while self._state is RunState.started: available_slots = self.max_concurrent_jobs - len(self._running_jobs) if available_slots: jobs = self.data_store.acquire_jobs(self.identity, available_slots) for job in jobs: + task = self.data_store.get_task(job.task_id) self._running_jobs.add(job.id) - self._executor.submit(self._run_job, job) + executor.submit(self._run_job, job, task.func) self._wakeup_event.wait() self._wakeup_event = threading.Event() except BaseException as exc: + executor.shutdown(wait=False) self._state = RunState.stopped self._events.publish(WorkerStopped(exception=exc)) raise + executor.shutdown() self._state = RunState.stopped self._events.publish(WorkerStopped()) - def _run_job(self, job: Job) -> None: + def _run_job(self, job: Job, func: Callable) -> None: try: # Check if the job started before the deadline start_time = datetime.now(timezone.utc) @@ -127,16 +132,16 @@ class Worker(EventSource): self._events.publish(JobStarted.from_job(job, start_time)) try: - retval = job.func(*job.args, **job.kwargs) + 
retval = func(*job.args, **job.kwargs) except BaseException as exc: result = JobResult(outcome=JobOutcome.failure, exception=exc) - self.data_store.release_job(self.identity, job.id, result) + self.data_store.release_job(self.identity, job, result) self._events.publish(JobFailed.from_exception(job, start_time, exc)) if not isinstance(exc, Exception): raise else: result = JobResult(outcome=JobOutcome.success, return_value=retval) - self.data_store.release_job(self.identity, job.id, result) + self.data_store.release_job(self.identity, job, result) self._events.publish(JobCompleted.from_retval(job, start_time, retval)) finally: self._running_jobs.remove(job.id) diff --git a/tests/test_datastores.py b/tests/test_datastores.py index 1a18c03..2751c83 100644 --- a/tests/test_datastores.py +++ b/tests/test_datastores.py @@ -10,37 +10,27 @@ from freezegun.api import FrozenDateTimeFactory from apscheduler.abc import AsyncDataStore, Job, Schedule from apscheduler.enums import JobOutcome -from apscheduler.events import Event, ScheduleAdded, ScheduleRemoved, ScheduleUpdated +from apscheduler.events import Event, ScheduleAdded, ScheduleRemoved, ScheduleUpdated, TaskAdded from apscheduler.policies import CoalescePolicy, ConflictPolicy -from apscheduler.structures import JobResult +from apscheduler.structures import JobResult, Task from apscheduler.triggers.date import DateTrigger @pytest.fixture def schedules() -> List[Schedule]: trigger = DateTrigger(datetime(2020, 9, 13, tzinfo=timezone.utc)) - schedule1 = Schedule(id='s1', task_id='bogus', trigger=trigger) + schedule1 = Schedule(id='s1', task_id='task1', trigger=trigger) schedule1.next_fire_time = trigger.next() trigger = DateTrigger(datetime(2020, 9, 14, tzinfo=timezone.utc)) - schedule2 = Schedule(id='s2', task_id='bogus', trigger=trigger) + schedule2 = Schedule(id='s2', task_id='task2', trigger=trigger) schedule2.next_fire_time = trigger.next() trigger = DateTrigger(datetime(2020, 9, 15, tzinfo=timezone.utc)) - schedule3 = 
Schedule(id='s3', task_id='bogus', trigger=trigger) + schedule3 = Schedule(id='s3', task_id='task1', trigger=trigger) return [schedule1, schedule2, schedule3] -@pytest.fixture -def jobs() -> List[Job]: - job1 = Job(task_id='task1', func=print, args=('hello',), kwargs={'arg2': 'world'}, - schedule_id='schedule1', - scheduled_fire_time=datetime(2020, 10, 10, tzinfo=timezone.utc), - start_deadline=datetime(2020, 10, 10, 1, tzinfo=timezone.utc)) - job2 = Job(task_id='task2', func=print, args=('hello',), kwargs={'arg2': 'world'}) - return [job1, job2] - - @asynccontextmanager async def capture_events( store: AsyncDataStore, limit: int, @@ -63,6 +53,33 @@ async def capture_events( @pytest.mark.anyio class TestAsyncStores: + async def test_add_replace_task( + self, datastore_cm: AsyncContextManager[AsyncDataStore]) -> None: + import math + + async with datastore_cm as store, capture_events(store, 3, {TaskAdded}) as events: + await store.add_task(Task(id='test_task', func=print)) + await store.add_task(Task(id='test_task2', func=math.ceil)) + await store.add_task(Task(id='test_task', func=repr)) + + tasks = await store.get_tasks() + assert len(tasks) == 2 + assert tasks[0].id == 'test_task' + assert tasks[0].func is repr + assert tasks[1].id == 'test_task2' + assert tasks[1].func is math.ceil + + received_event = events.pop(0) + assert received_event.task_id == 'test_task' + + received_event = events.pop(0) + assert received_event.task_id == 'test_task2' + + received_event = events.pop(0) + assert received_event.task_id == 'test_task' + + assert not events + async def test_add_schedules(self, datastore_cm: AsyncContextManager[AsyncDataStore], schedules: List[Schedule]) -> None: async with datastore_cm as store, capture_events(store, 3, {ScheduleAdded}) as events: @@ -200,35 +217,39 @@ class TestAsyncStores: assert acquired3[0].id == 's1' async def test_acquire_multiple_workers( - self, datastore_cm: AsyncContextManager[AsyncDataStore], jobs: List[Job]) -> None: + self, 
datastore_cm: AsyncContextManager[AsyncDataStore]) -> None: async with datastore_cm as store: + await store.add_task(Task(id='task1', func=asynccontextmanager)) + jobs = [Job(task_id='task1') for _ in range(2)] for job in jobs: await store.add_job(job) # The first worker gets the first job in the queue - jobs1 = await store.acquire_jobs('dummy-id1', 1) + jobs1 = await store.acquire_jobs('worker1', 1) assert len(jobs1) == 1 assert jobs1[0].id == jobs[0].id # The second worker gets the second job - jobs2 = await store.acquire_jobs('dummy-id2', 1) + jobs2 = await store.acquire_jobs('worker2', 1) assert len(jobs2) == 1 assert jobs2[0].id == jobs[1].id # The third worker gets nothing - jobs3 = await store.acquire_jobs('dummy-id3', 1) + jobs3 = await store.acquire_jobs('worker3', 1) assert not jobs3 - async def test_job_release_success(self, datastore_cm: AsyncContextManager[AsyncDataStore], - jobs: List[Job]): + async def test_job_release_success( + self, datastore_cm: AsyncContextManager[AsyncDataStore]) -> None: async with datastore_cm as store: - await store.add_job(jobs[0]) + await store.add_task(Task(id='task1', func=asynccontextmanager)) + job = Job(task_id='task1') + await store.add_job(job) acquired = await store.acquire_jobs('worker_id', 2) assert len(acquired) == 1 - assert acquired[0].id == jobs[0].id + assert acquired[0].id == job.id - await store.release_job('worker_id', acquired[0].id, + await store.release_job('worker_id', acquired[0], JobResult(JobOutcome.success, return_value='foo')) result = await store.get_job_result(acquired[0].id) assert result.outcome is JobOutcome.success @@ -239,16 +260,18 @@ class TestAsyncStores: assert not await store.get_jobs({acquired[0].id}) assert not await store.get_job_result(acquired[0].id) - async def test_job_release_failure(self, datastore_cm: AsyncContextManager[AsyncDataStore], - jobs: List[Job]): + async def test_job_release_failure( + self, datastore_cm: AsyncContextManager[AsyncDataStore]) -> None: async with 
datastore_cm as store: - await store.add_job(jobs[0]) + await store.add_task(Task(id='task1', func=asynccontextmanager)) + job = Job(task_id='task1') + await store.add_job(job) acquired = await store.acquire_jobs('worker_id', 2) assert len(acquired) == 1 - assert acquired[0].id == jobs[0].id + assert acquired[0].id == job.id - await store.release_job('worker_id', acquired[0].id, + await store.release_job('worker_id', acquired[0], JobResult(JobOutcome.failure, exception=ValueError('foo'))) result = await store.get_job_result(acquired[0].id) assert result.outcome is JobOutcome.failure @@ -261,15 +284,17 @@ class TestAsyncStores: assert not await store.get_job_result(acquired[0].id) async def test_job_release_missed_deadline( - self, datastore_cm: AsyncContextManager[AsyncDataStore], jobs: List[Job]): + self, datastore_cm: AsyncContextManager[AsyncDataStore]): async with datastore_cm as store: - await store.add_job(jobs[0]) + await store.add_task(Task(id='task1', func=asynccontextmanager)) + job = Job(task_id='task1') + await store.add_job(job) acquired = await store.acquire_jobs('worker_id', 2) assert len(acquired) == 1 - assert acquired[0].id == jobs[0].id + assert acquired[0].id == job.id - await store.release_job('worker_id', acquired[0].id, + await store.release_job('worker_id', acquired[0], JobResult(JobOutcome.missed_start_deadline)) result = await store.get_job_result(acquired[0].id) assert result.outcome is JobOutcome.missed_start_deadline @@ -281,16 +306,17 @@ class TestAsyncStores: assert not await store.get_job_result(acquired[0].id) async def test_job_release_cancelled( - self, datastore_cm: AsyncContextManager[AsyncDataStore], jobs: List[Job]): + self, datastore_cm: AsyncContextManager[AsyncDataStore]) -> None: async with datastore_cm as store: - await store.add_job(jobs[0]) + await store.add_task(Task(id='task1', func=asynccontextmanager)) + job = Job(task_id='task1') + await store.add_job(job) - acquired = await store.acquire_jobs('worker_id', 2) + 
acquired = await store.acquire_jobs('worker1', 2) assert len(acquired) == 1 - assert acquired[0].id == jobs[0].id + assert acquired[0].id == job.id - await store.release_job('worker_id', acquired[0].id, - JobResult(JobOutcome.cancelled)) + await store.release_job('worker1', acquired[0], JobResult(JobOutcome.cancelled)) result = await store.get_job_result(acquired[0].id) assert result.outcome is JobOutcome.cancelled assert result.exception is None @@ -301,7 +327,7 @@ class TestAsyncStores: assert not await store.get_job_result(acquired[0].id) async def test_acquire_jobs_lock_timeout( - self, datastore_cm: AsyncContextManager[AsyncDataStore], jobs: List[Job], + self, datastore_cm: AsyncContextManager[AsyncDataStore], freezer: FrozenDateTimeFactory) -> None: """ Test that a worker can acquire jobs that were acquired by another scheduler but not @@ -309,19 +335,40 @@ class TestAsyncStores: """ async with datastore_cm as store: + await store.add_task(Task(id='task1', func=asynccontextmanager)) + job = Job(task_id='task1') + await store.add_job(job) + # First, one worker acquires the first available job - await store.add_job(jobs[0]) - acquired = await store.acquire_jobs('dummy-id1', 1) + acquired = await store.acquire_jobs('worker1', 1) assert len(acquired) == 1 - assert acquired[0].id == jobs[0].id + assert acquired[0].id == job.id # Try to acquire the job just at the threshold (now == acquired_until). # This should not yield any jobs. 
freezer.tick(30) - assert not await store.acquire_jobs('dummy-id2', 1) + assert not await store.acquire_jobs('worker2', 1) # Right after that, the job should be available freezer.tick(1) - acquired = await store.acquire_jobs('dummy-id2', 1) + acquired = await store.acquire_jobs('worker2', 1) assert len(acquired) == 1 - assert acquired[0].id == jobs[0].id + assert acquired[0].id == job.id + + async def test_acquire_jobs_max_number_exceeded( + self, datastore_cm: AsyncContextManager[AsyncDataStore]) -> None: + async with datastore_cm as store: + await store.add_task(Task(id='task1', func=asynccontextmanager, max_running_jobs=2)) + jobs = [Job(task_id='task1'), Job(task_id='task1'), Job(task_id='task1')] + for job in jobs: + await store.add_job(job) + + # Check that only 2 jobs are returned from acquire_jobs() even though the limit was 3 + acquired_jobs = await store.acquire_jobs('worker1', 3) + assert [job.id for job in acquired_jobs] == [job.id for job in jobs[:2]] + + # Release one job, and the worker should be able to acquire the third job + await store.release_job('worker1', acquired_jobs[0], + JobResult(outcome=JobOutcome.success, return_value=None)) + acquired_jobs = await store.acquire_jobs('worker1', 3) + assert [job.id for job in acquired_jobs] == [jobs[2].id] diff --git a/tests/test_schedulers.py b/tests/test_schedulers.py index 3c0911d..990cfec 100644 --- a/tests/test_schedulers.py +++ b/tests/test_schedulers.py @@ -7,7 +7,7 @@ import pytest from anyio import fail_after from apscheduler.events import ( - Event, JobAdded, ScheduleAdded, ScheduleRemoved, SchedulerStarted, SchedulerStopped) + Event, JobAdded, ScheduleAdded, ScheduleRemoved, SchedulerStarted, SchedulerStopped, TaskAdded) from apscheduler.schedulers.async_ import AsyncScheduler from apscheduler.schedulers.sync import Scheduler from apscheduler.triggers.date import DateTrigger @@ -27,7 +27,7 @@ class TestAsyncScheduler: async def test_schedule_job(self) -> None: def listener(received_event: 
Event) -> None: received_events.append(received_event) - if len(received_events) == 4: + if len(received_events) == 5: event.set() received_events: List[Event] = [] @@ -44,6 +44,11 @@ class TestAsyncScheduler: received_event = received_events.pop(0) assert isinstance(received_event, SchedulerStarted) + # Then the task was added + received_event = received_events.pop(0) + assert isinstance(received_event, TaskAdded) + assert received_event.task_id == 'test_schedulers:dummy_async_job' + # Then a schedule was added received_event = received_events.pop(0) assert isinstance(received_event, ScheduleAdded) @@ -73,7 +78,7 @@ class TestSyncScheduler: def test_schedule_job(self): def listener(received_event: Event) -> None: received_events.append(received_event) - if len(received_events) == 4: + if len(received_events) == 5: event.set() received_events: List[Event] = [] @@ -89,6 +94,11 @@ class TestSyncScheduler: received_event = received_events.pop(0) assert isinstance(received_event, SchedulerStarted) + # Then the task was added + received_event = received_events.pop(0) + assert isinstance(received_event, TaskAdded) + assert received_event.task_id == 'test_schedulers:dummy_sync_job' + # Then a schedule was added received_event = received_events.pop(0) assert isinstance(received_event, ScheduleAdded) diff --git a/tests/test_workers.py b/tests/test_workers.py index a45398d..f949edf 100644 --- a/tests/test_workers.py +++ b/tests/test_workers.py @@ -9,8 +9,9 @@ from anyio import fail_after from apscheduler.abc import Job from apscheduler.datastores.sync.memory import MemoryDataStore from apscheduler.events import ( - Event, JobAdded, JobCompleted, JobDeadlineMissed, JobFailed, JobStarted, WorkerStarted, - WorkerStopped) + Event, JobAdded, JobCompleted, JobDeadlineMissed, JobFailed, JobStarted, TaskAdded, + WorkerStarted, WorkerStopped) +from apscheduler.structures import Task from apscheduler.workers.async_ import AsyncWorker from apscheduler.workers.sync import Worker @@ 
-41,7 +42,7 @@ class TestAsyncWorker: async def test_run_job_nonscheduled_success(self, target_func: Callable, fail: bool) -> None: def listener(received_event: Event): received_events.append(received_event) - if len(received_events) == 4: + if len(received_events) == 5: event.set() received_events: List[Event] = [] @@ -50,8 +51,8 @@ class TestAsyncWorker: worker = AsyncWorker(data_store) worker.subscribe(listener) async with worker: - job = Job(task_id='task_id', func=target_func, args=(1, 2), - kwargs={'x': 'foo', 'fail': fail}) + await worker.data_store.add_task(Task(id='task_id', func=target_func)) + job = Job(task_id='task_id', args=(1, 2), kwargs={'x': 'foo', 'fail': fail}) await worker.data_store.add_job(job) with fail_after(3): await event.wait() @@ -60,6 +61,11 @@ class TestAsyncWorker: received_event = received_events.pop(0) assert isinstance(received_event, WorkerStarted) + # Then the task was added + received_event = received_events.pop(0) + assert isinstance(received_event, TaskAdded) + assert received_event.task_id == 'task_id' + # Then a job was added received_event = received_events.pop(0) assert isinstance(received_event, JobAdded) @@ -96,7 +102,7 @@ class TestAsyncWorker: async def test_run_deadline_missed(self) -> None: def listener(received_event: Event): received_events.append(received_event) - if len(received_events) == 3: + if len(received_events) == 4: event.set() scheduled_start_time = datetime(2020, 9, 14, tzinfo=timezone.utc) @@ -106,7 +112,8 @@ class TestAsyncWorker: worker = AsyncWorker(data_store) worker.subscribe(listener) async with worker: - job = Job(task_id='task_id', func=fail_func, schedule_id='foo', + await worker.data_store.add_task(Task(id='task_id', func=fail_func)) + job = Job(task_id='task_id', schedule_id='foo', scheduled_fire_time=scheduled_start_time, start_deadline=datetime(2020, 9, 14, 1, tzinfo=timezone.utc)) await worker.data_store.add_job(job) @@ -117,6 +124,11 @@ class TestAsyncWorker: received_event = 
received_events.pop(0) assert isinstance(received_event, WorkerStarted) + # Then the task was added + received_event = received_events.pop(0) + assert isinstance(received_event, TaskAdded) + assert received_event.task_id == 'task_id' + # Then a job was added received_event = received_events.pop(0) assert isinstance(received_event, JobAdded) @@ -144,7 +156,7 @@ class TestSyncWorker: def test_run_job_nonscheduled(self, fail: bool) -> None: def listener(received_event: Event): received_events.append(received_event) - if len(received_events) == 4: + if len(received_events) == 5: event.set() received_events: List[Event] = [] @@ -153,8 +165,8 @@ class TestSyncWorker: worker = Worker(data_store) worker.subscribe(listener) with worker: - job = Job(task_id='task_id', func=sync_func, args=(1, 2), - kwargs={'x': 'foo', 'fail': fail}) + worker.data_store.add_task(Task(id='task_id', func=sync_func)) + job = Job(task_id='task_id', args=(1, 2), kwargs={'x': 'foo', 'fail': fail}) worker.data_store.add_job(job) event.wait(5) @@ -162,6 +174,11 @@ class TestSyncWorker: received_event = received_events.pop(0) assert isinstance(received_event, WorkerStarted) + # Then the task was added + received_event = received_events.pop(0) + assert isinstance(received_event, TaskAdded) + assert received_event.task_id == 'task_id' + # Then a job was added received_event = received_events.pop(0) assert isinstance(received_event, JobAdded) @@ -198,7 +215,7 @@ class TestSyncWorker: def test_run_deadline_missed(self) -> None: def listener(worker_event: Event): received_events.append(worker_event) - if len(received_events) == 3: + if len(received_events) == 4: event.set() scheduled_start_time = datetime(2020, 9, 14, tzinfo=timezone.utc) @@ -208,7 +225,8 @@ class TestSyncWorker: worker = Worker(data_store) worker.subscribe(listener) with worker: - job = Job(task_id='task_id', func=fail_func, schedule_id='foo', + worker.data_store.add_task(Task(id='task_id', func=fail_func)) + job = Job(task_id='task_id', 
schedule_id='foo', scheduled_fire_time=scheduled_start_time, start_deadline=datetime(2020, 9, 14, 1, tzinfo=timezone.utc)) worker.data_store.add_job(job) @@ -218,6 +236,11 @@ class TestSyncWorker: received_event = received_events.pop(0) assert isinstance(received_event, WorkerStarted) + # Then the task was added + received_event = received_events.pop(0) + assert isinstance(received_event, TaskAdded) + assert received_event.task_id == 'task_id' + # Then a job was added received_event = received_events.pop(0) assert isinstance(received_event, JobAdded) |