Diffstat (limited to 'src/apscheduler/datastores/sync/mongodb.py')
 src/apscheduler/datastores/sync/mongodb.py | 125 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------------------
 1 file changed, 106 insertions(+), 19 deletions(-)
diff --git a/src/apscheduler/datastores/sync/mongodb.py b/src/apscheduler/datastores/sync/mongodb.py
index 4d5c0f2..fcf7d4d 100644
--- a/src/apscheduler/datastores/sync/mongodb.py
+++ b/src/apscheduler/datastores/sync/mongodb.py
@@ -1,11 +1,14 @@
 from __future__ import annotations
 
 import logging
+from collections import defaultdict
 from contextlib import ExitStack
 from datetime import datetime, timezone
-from typing import Any, Callable, Iterable, List, Optional, Set, Tuple, Type
+from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Set, Tuple, Type
 from uuid import UUID
 
+import attr
+import pymongo
 from pymongo import ASCENDING, DeleteOne, MongoClient, UpdateOne
 from pymongo.collection import Collection
 from pymongo.errors import DuplicateKeyError
@@ -14,19 +17,25 @@ from ... import events
 from ...abc import DataStore, Job, Schedule, Serializer
 from ...events import (
     DataStoreEvent, EventHub, JobAdded, ScheduleAdded, ScheduleRemoved, ScheduleUpdated,
-    SubscriptionToken)
-from ...exceptions import ConflictingIdError, DeserializationError, SerializationError
+    SubscriptionToken, TaskAdded, TaskRemoved)
+from ...exceptions import (
+    ConflictingIdError, DeserializationError, SerializationError, TaskLookupError)
 from ...policies import ConflictPolicy
 from ...serializers.pickle import PickleSerializer
-from ...structures import JobResult
+from ...structures import JobResult, Task
 from ...util import reentrant
 
 
 @reentrant
 class MongoDBDataStore(DataStore):
+    _task_attrs: ClassVar[List[str]] = [field.name for field in attr.fields(Task)]
+    _schedule_attrs: ClassVar[List[str]] = [field.name for field in attr.fields(Schedule)]
+    _job_attrs: ClassVar[List[str]] = [field.name for field in attr.fields(Job)]
+
     def __init__(self, client: MongoClient, *, serializer: Optional[Serializer] = None,
-                 database: str = 'apscheduler', schedules_collection: str = 'schedules',
-                 jobs_collection: str = 'jobs', job_results_collection: str = 'job_results',
+                 database: str = 'apscheduler', tasks_collection: str = 'tasks',
+                 schedules_collection: str = 'schedules', jobs_collection: str = 'jobs',
+                 job_results_collection: str = 'job_results',
                  lock_expiration_delay: float = 30, start_from_scratch: bool = False):
         super().__init__()
         if not client.delegate.codec_options.tz_aware:
@@ -36,7 +45,9 @@ class MongoDBDataStore(DataStore):
         self.serializer = serializer or PickleSerializer()
         self.lock_expiration_delay = lock_expiration_delay
         self.start_from_scratch = start_from_scratch
+        self._local_tasks: Dict[str, Task] = {}
         self._database = client[database]
+        self._tasks: Collection = self._database[tasks_collection]
         self._schedules: Collection = self._database[schedules_collection]
         self._jobs: Collection = self._database[jobs_collection]
         self._jobs_results: Collection = self._database[job_results_collection]
@@ -59,6 +70,7 @@ class MongoDBDataStore(DataStore):
         self._exit_stack.enter_context(self._events)
 
         if self.start_from_scratch:
+            self._tasks.delete_many({})
             self._schedules.delete_many({})
             self._jobs.delete_many({})
             self._jobs_results.delete_many({})
@@ -80,6 +92,44 @@ class MongoDBDataStore(DataStore):
     def unsubscribe(self, token: events.SubscriptionToken) -> None:
         self._events.unsubscribe(token)
 
+    def add_task(self, task: Task) -> None:
+        self._tasks.find_one_and_update(
+            {'_id': task.id},
+            {'$set': task.marshal(self.serializer),
+             '$setOnInsert': {'running_jobs': 0}},
+            upsert=True
+        )
+        self._local_tasks[task.id] = task
+        self._events.publish(TaskAdded(task_id=task.id))
+
+    def remove_task(self, task_id: str) -> None:
+        if not self._tasks.find_one_and_delete({'_id': task_id}):
+            raise TaskLookupError(task_id)
+
+        del self._local_tasks[task_id]
+        self._events.publish(TaskRemoved(task_id=task_id))
+
+    def get_task(self, task_id: str) -> Task:
+        try:
+            return self._local_tasks[task_id]
+        except KeyError:
+            document = self._tasks.find_one({'_id': task_id}, projection=self._task_attrs)
+            if not document:
+                raise TaskLookupError(task_id)
+
+            document['id'] = document.pop('_id')
+            task = self._local_tasks[task_id] = Task.unmarshal(self.serializer, document)
+            return task
+
+    def get_tasks(self) -> List[Task]:
+        tasks: List[Task] = []
+        for document in self._tasks.find(projection=self._task_attrs,
+                                         sort=[('_id', pymongo.ASCENDING)]):
+            document['id'] = document.pop('_id')
+            tasks.append(Task.unmarshal(self.serializer, document))
+
+        return tasks
+
     def get_schedules(self, ids: Optional[Set[str]] = None) -> List[Schedule]:
         schedules: List[Schedule] = []
         filters = {'_id': {'$in': list(ids)}} if ids is not None else {}
@@ -238,45 +288,82 @@ class MongoDBDataStore(DataStore):
         return jobs
 
     def acquire_jobs(self, worker_id: str, limit: Optional[int] = None) -> List[Job]:
-        jobs: List[Job] = []
         with self.client.start_session() as session:
             cursor = self._jobs.find(
                 {'$or': [{'acquired_until': {'$exists': False}},
                          {'acquired_until': {'$lt': datetime.now(timezone.utc)}}]
                  },
-                projection=['serialized_data'],
+                projection=['task_id', 'serialized_data'],
                 sort=[('created_at', ASCENDING)],
                 limit=limit,
                 session=session
             )
-            for document in cursor:
+            documents = list(cursor)
+
+            # Retrieve the limits
+            task_ids: Set[str] = {document['task_id'] for document in documents}
+            task_limits = self._tasks.find(
+                {'_id': {'$in': list(task_ids)}, 'max_running_jobs': {'$ne': None}},
+                projection=['max_running_jobs', 'running_jobs'],
+                session=session
+            )
+            job_slots_left = {doc['_id']: doc['max_running_jobs'] - doc['running_jobs']
+                              for doc in task_limits}
+
+            # Filter out jobs that don't have free slots
+            acquired_jobs: List[Job] = []
+            increments: Dict[str, int] = defaultdict(lambda: 0)
+            for document in documents:
                 job = self.serializer.deserialize(document['serialized_data'])
-                jobs.append(job)
 
-            if jobs:
+                # Don't acquire the job if there are no free slots left
+                slots_left = job_slots_left.get(job.task_id)
+                if slots_left == 0:
+                    continue
+                elif slots_left is not None:
+                    job_slots_left[job.task_id] -= 1
+
+                acquired_jobs.append(job)
+                increments[job.task_id] += 1
+
+            if acquired_jobs:
                 now = datetime.now(timezone.utc)
                 acquired_until = datetime.fromtimestamp(
                     now.timestamp() + self.lock_expiration_delay, timezone.utc)
-                filters = {'_id': {'$in': [job.id for job in jobs]}}
+                filters = {'_id': {'$in': [job.id for job in acquired_jobs]}}
                 update = {'$set': {'acquired_by': worker_id,
                                    'acquired_until': acquired_until}}
                 self._jobs.update_many(filters, update, session=session)
 
-            return jobs
+                # Increment the running job counters on each task
+                for task_id, increment in increments.items():
+                    self._tasks.find_one_and_update(
+                        {'_id': task_id},
+                        {'$inc': {'running_jobs': increment}},
+                        session=session
+                    )
+
+            return acquired_jobs
 
-    def release_job(self, worker_id: str, job_id: UUID, result: Optional[JobResult]) -> None:
+    def release_job(self, worker_id: str, job: Job, result: Optional[JobResult]) -> None:
         with self.client.start_session() as session:
+            # Insert the job result
             now = datetime.now(timezone.utc)
-            serialized_result = self.serializer.serialize(result)
             document = {
-                '_id': job_id,
+                '_id': job.id,
                 'finished_at': now,
-                'serialized_data': serialized_result
+                'serialized_data': self.serializer.serialize(result)
            }
             self._jobs_results.insert_one(document, session=session)
 
-            filters = {'_id': job_id, 'acquired_by': worker_id}
-            self._jobs.delete_one(filters, session=session)
+            # Decrement the running jobs counter
+            self._tasks.find_one_and_update(
+                {'_id': job.task_id},
+                {'$inc': {'running_jobs': -1}}
+            )
+
+            # Delete the job
+            self._jobs.delete_one({'_id': job.id}, session=session)
 
     def get_job_result(self, job_id: UUID) -> Optional[JobResult]:
         document = self._jobs_results.find_one_and_delete(
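The heart of the change is the slot accounting in acquire_jobs(): each task document carries a max_running_jobs cap and a running_jobs counter, and a worker only takes a job when its task still has a free slot. Below is a minimal standalone sketch of that scheme, not part of the commit, assuming a MongoDB server on localhost; the database and collection names mirror the patch defaults, while the 'capped'/'uncapped' task IDs and the candidates list are invented for illustration.

from collections import defaultdict
from typing import Dict

from pymongo import MongoClient

client = MongoClient()
tasks = client['apscheduler']['tasks']
tasks.delete_many({})  # start clean so the counters below are deterministic

# Mimic add_task(): upsert the task document, initializing the running-jobs
# counter only on insert, as the patched method does
for task_id, cap in [('capped', 2), ('uncapped', None)]:
    tasks.find_one_and_update(
        {'_id': task_id},
        {'$set': {'max_running_jobs': cap}, '$setOnInsert': {'running_jobs': 0}},
        upsert=True
    )

# Candidate jobs in creation order, reduced here to their task IDs
candidates = ['capped', 'capped', 'capped', 'uncapped']

# Same query shape as acquire_jobs(): only tasks with a non-null cap are
# fetched, so uncapped tasks never appear in job_slots_left and never block
task_limits = tasks.find(
    {'_id': {'$in': candidates}, 'max_running_jobs': {'$ne': None}},
    projection=['max_running_jobs', 'running_jobs']
)
job_slots_left: Dict[str, int] = {
    doc['_id']: doc['max_running_jobs'] - doc['running_jobs']
    for doc in task_limits
}

acquired = []
increments: Dict[str, int] = defaultdict(lambda: 0)
for task_id in candidates:
    slots_left = job_slots_left.get(task_id)
    if slots_left == 0:           # cap reached: leave the job for later
        continue
    elif slots_left is not None:  # capped task with room: consume a slot
        job_slots_left[task_id] -= 1

    acquired.append(task_id)
    increments[task_id] += 1

# Persist the counters in one $inc per task, as the patch does after
# marking the jobs acquired
for task_id, increment in increments.items():
    tasks.find_one_and_update({'_id': task_id}, {'$inc': {'running_jobs': increment}})

print(acquired)  # ['capped', 'capped', 'uncapped']: the third capped job stays queued

Running the sketch twice also shows why release_job() must decrement running_jobs: without that decrement, the persisted counters alone would eventually exhaust every capped task's slots.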