Diffstat (limited to 'src/apscheduler/datastores/sync/mongodb.py')
-rw-r--r--  src/apscheduler/datastores/sync/mongodb.py  125
1 file changed, 106 insertions(+), 19 deletions(-)
diff --git a/src/apscheduler/datastores/sync/mongodb.py b/src/apscheduler/datastores/sync/mongodb.py
index 4d5c0f2..fcf7d4d 100644
--- a/src/apscheduler/datastores/sync/mongodb.py
+++ b/src/apscheduler/datastores/sync/mongodb.py
@@ -1,11 +1,14 @@
from __future__ import annotations
import logging
+from collections import defaultdict
from contextlib import ExitStack
from datetime import datetime, timezone
-from typing import Any, Callable, Iterable, List, Optional, Set, Tuple, Type
+from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Set, Tuple, Type
from uuid import UUID
+import attr
+import pymongo
from pymongo import ASCENDING, DeleteOne, MongoClient, UpdateOne
from pymongo.collection import Collection
from pymongo.errors import DuplicateKeyError
@@ -14,19 +17,25 @@ from ... import events
from ...abc import DataStore, Job, Schedule, Serializer
from ...events import (
DataStoreEvent, EventHub, JobAdded, ScheduleAdded, ScheduleRemoved, ScheduleUpdated,
- SubscriptionToken)
-from ...exceptions import ConflictingIdError, DeserializationError, SerializationError
+ SubscriptionToken, TaskAdded, TaskRemoved)
+from ...exceptions import (
+ ConflictingIdError, DeserializationError, SerializationError, TaskLookupError)
from ...policies import ConflictPolicy
from ...serializers.pickle import PickleSerializer
-from ...structures import JobResult
+from ...structures import JobResult, Task
from ...util import reentrant
@reentrant
class MongoDBDataStore(DataStore):
+ _task_attrs: ClassVar[List[str]] = [field.name for field in attr.fields(Task)]
+ _schedule_attrs: ClassVar[List[str]] = [field.name for field in attr.fields(Schedule)]
+ _job_attrs: ClassVar[List[str]] = [field.name for field in attr.fields(Job)]
+
def __init__(self, client: MongoClient, *, serializer: Optional[Serializer] = None,
- database: str = 'apscheduler', schedules_collection: str = 'schedules',
- jobs_collection: str = 'jobs', job_results_collection: str = 'job_results',
+ database: str = 'apscheduler', tasks_collection: str = 'tasks',
+ schedules_collection: str = 'schedules', jobs_collection: str = 'jobs',
+ job_results_collection: str = 'job_results',
lock_expiration_delay: float = 30, start_from_scratch: bool = False):
super().__init__()
if not client.delegate.codec_options.tz_aware:
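A quick illustration of the ClassVar projections introduced in this hunk: attr.fields() lists an attrs class's fields in definition order, so each list doubles as a pymongo projection naming exactly the columns that unmarshal() expects. The Demo class below is a hypothetical stand-in, not part of APScheduler:

import attr

@attr.define
class Demo:  # hypothetical stand-in for Task/Schedule/Job
    id: str
    max_running_jobs: int = 0

# Same construction as _task_attrs above
print([field.name for field in attr.fields(Demo)])  # ['id', 'max_running_jobs']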
@@ -36,7 +45,9 @@ class MongoDBDataStore(DataStore):
self.serializer = serializer or PickleSerializer()
self.lock_expiration_delay = lock_expiration_delay
self.start_from_scratch = start_from_scratch
+ self._local_tasks: Dict[str, Task] = {}
self._database = client[database]
+ self._tasks: Collection = self._database[tasks_collection]
self._schedules: Collection = self._database[schedules_collection]
self._jobs: Collection = self._database[jobs_collection]
self._jobs_results: Collection = self._database[job_results_collection]
@@ -59,6 +70,7 @@ class MongoDBDataStore(DataStore):
self._exit_stack.enter_context(self._events)
if self.start_from_scratch:
+ self._tasks.delete_many({})
self._schedules.delete_many({})
self._jobs.delete_many({})
self._jobs_results.delete_many({})
@@ -80,6 +92,44 @@ class MongoDBDataStore(DataStore):
def unsubscribe(self, token: events.SubscriptionToken) -> None:
self._events.unsubscribe(token)
+ def add_task(self, task: Task) -> None:
+ self._tasks.find_one_and_update(
+ {'_id': task.id},
+ {'$set': task.marshal(self.serializer),
+ '$setOnInsert': {'running_jobs': 0}},
+ upsert=True
+ )
+ self._local_tasks[task.id] = task
+ self._events.publish(TaskAdded(task_id=task.id))
+
+ def remove_task(self, task_id: str) -> None:
+ if not self._tasks.find_one_and_delete({'_id': task_id}):
+ raise TaskLookupError(task_id)
+
+ self._local_tasks.pop(task_id, None)
+ self._events.publish(TaskRemoved(task_id=task_id))
+
+ def get_task(self, task_id: str) -> Task:
+ try:
+ return self._local_tasks[task_id]
+ except KeyError:
+ document = self._tasks.find_one({'_id': task_id}, projection=self._task_attrs)
+ if not document:
+ raise TaskLookupError(task_id)
+
+ document['id'] = document.pop('_id')
+ task = self._local_tasks[task_id] = Task.unmarshal(self.serializer, document)
+ return task
+
+ def get_tasks(self) -> List[Task]:
+ tasks: List[Task] = []
+ for document in self._tasks.find(projection=self._task_attrs,
+ sort=[('_id', pymongo.ASCENDING)]):
+ document['id'] = document.pop('_id')
+ tasks.append(Task.unmarshal(self.serializer, document))
+
+ return tasks
+
def get_schedules(self, ids: Optional[Set[str]] = None) -> List[Schedule]:
schedules: List[Schedule] = []
filters = {'_id': {'$in': list(ids)}} if ids is not None else {}
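A minimal usage sketch of the task API added in this hunk. It is a sketch under assumptions, not upstream code: the Task constructor arguments are illustrative (the diff only shows that a task has an id, a max_running_jobs limit and a running_jobs counter), and the context-manager usage is inferred from the @reentrant wrapper and _exit_stack seen above.

from pymongo import MongoClient

from apscheduler.datastores.sync.mongodb import MongoDBDataStore
from apscheduler.structures import Task

client = MongoClient(tz_aware=True)  # the store rejects tz-naive clients
with MongoDBDataStore(client) as store:
    # max_running_jobs feeds the slot accounting in acquire_jobs() below
    task = Task(id='example', func=print, max_running_jobs=2)  # illustrative args
    store.add_task(task)                      # upsert; running_jobs seeded to 0
    assert store.get_task('example') is task  # served from the _local_tasks cache
    print([t.id for t in store.get_tasks()])  # ['example']
    store.remove_task('example')              # raises TaskLookupError if missing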
@@ -238,45 +288,82 @@ class MongoDBDataStore(DataStore):
return jobs
def acquire_jobs(self, worker_id: str, limit: Optional[int] = None) -> List[Job]:
- jobs: List[Job] = []
with self.client.start_session() as session:
cursor = self._jobs.find(
{'$or': [{'acquired_until': {'$exists': False}},
{'acquired_until': {'$lt': datetime.now(timezone.utc)}}]
},
- projection=['serialized_data'],
+ projection=['task_id', 'serialized_data'],
sort=[('created_at', ASCENDING)],
limit=limit,
session=session
)
- for document in cursor:
+ documents = list(cursor)
+
+ # Retrieve the limits
+ task_ids: Set[str] = {document['task_id'] for document in documents}
+ task_limits = self._tasks.find(
+ {'_id': {'$in': list(task_ids)}, 'max_running_jobs': {'$ne': None}},
+ projection=['max_running_jobs', 'running_jobs'],
+ session=session
+ )
+ job_slots_left = {doc['_id']: doc['max_running_jobs'] - doc['running_jobs']
+ for doc in task_limits}
+
+ # Filter out jobs that don't have free slots
+ acquired_jobs: List[Job] = []
+ increments: Dict[str, int] = defaultdict(lambda: 0)
+ for document in documents:
job = self.serializer.deserialize(document['serialized_data'])
- jobs.append(job)
- if jobs:
+ # Don't acquire the job if there are no free slots left
+ slots_left = job_slots_left.get(job.task_id)
+ if slots_left == 0:
+ continue
+ elif slots_left is not None:
+ job_slots_left[job.task_id] -= 1
+
+ acquired_jobs.append(job)
+ increments[job.task_id] += 1
+
+ if acquired_jobs:
now = datetime.now(timezone.utc)
acquired_until = datetime.fromtimestamp(
now.timestamp() + self.lock_expiration_delay, timezone.utc)
- filters = {'_id': {'$in': [job.id for job in jobs]}}
+ filters = {'_id': {'$in': [job.id for job in acquired_jobs]}}
update = {'$set': {'acquired_by': worker_id,
'acquired_until': acquired_until}}
self._jobs.update_many(filters, update, session=session)
- return jobs
+ # Increment the running job counters on each task
+ for task_id, increment in increments.items():
+ self._tasks.find_one_and_update(
+ {'_id': task_id},
+ {'$inc': {'running_jobs': increment}},
+ session=session
+ )
+
+ return acquired_jobs
- def release_job(self, worker_id: str, job_id: UUID, result: Optional[JobResult]) -> None:
+ def release_job(self, worker_id: str, job: Job, result: Optional[JobResult]) -> None:
with self.client.start_session() as session:
+ # Insert the job result
now = datetime.now(timezone.utc)
- serialized_result = self.serializer.serialize(result)
document = {
- '_id': job_id,
+ '_id': job.id,
'finished_at': now,
- 'serialized_data': serialized_result
+ 'serialized_data': self.serializer.serialize(result)
}
self._jobs_results.insert_one(document, session=session)
- filters = {'_id': job_id, 'acquired_by': worker_id}
- self._jobs.delete_one(filters, session=session)
+ # Decrement the running jobs counter
+ self._tasks.find_one_and_update(
+ {'_id': job.task_id},
+ {'$inc': {'running_jobs': -1}},
+ session=session
+ )
+
+ # Delete the job
+ self._jobs.delete_one({'_id': job.id}, session=session)
def get_job_result(self, job_id: UUID) -> Optional[JobResult]:
document = self._jobs_results.find_one_and_delete(
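To make the acquire_jobs() changes above easier to follow, here is a standalone sketch of just the slot-accounting loop with MongoDB stripped out; the function and variable names are illustrative. Tasks with max_running_jobs set appear in job_slots_left and stop yielding jobs once their free slots are spent, while unlimited tasks are never filtered:

from collections import defaultdict
from typing import Dict, List, Optional, Tuple

def filter_by_slots(jobs: List[Tuple[str, str]],
                    job_slots_left: Dict[str, int]) -> List[str]:
    """jobs holds (job_id, task_id) pairs in creation order."""
    acquired: List[str] = []
    increments: Dict[str, int] = defaultdict(int)  # per-task running_jobs delta
    for job_id, task_id in jobs:
        slots_left: Optional[int] = job_slots_left.get(task_id)
        if slots_left == 0:
            continue                      # task at its cap: leave the job queued
        elif slots_left is not None:
            job_slots_left[task_id] -= 1  # consume one slot
        acquired.append(job_id)
        increments[task_id] += 1          # applied to the task via $inc above
    return acquired

# Task 'a' has one free slot, task 'b' is unlimited: the second 'a' job waits.
print(filter_by_slots([('j1', 'a'), ('j2', 'a'), ('j3', 'b')], {'a': 1}))
# -> ['j1', 'j3']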