summaryrefslogtreecommitdiff
path: root/src/apscheduler/datastores/sync/memory.py
diff options
context:
space:
mode:
authorAlex Grönholm <alex.gronholm@nextday.fi>2021-09-05 23:12:34 +0300
committerAlex Grönholm <alex.gronholm@nextday.fi>2021-09-06 01:39:07 +0300
commit6fed43f29bfa7929fcaf5e549ce34819ba7e3702 (patch)
treef3c8e7675a03354cb30205bd6a21e7ae231c83a6 /src/apscheduler/datastores/sync/memory.py
parent7147cd422bfc6280a56bb4335d65935e942ab9e3 (diff)
downloadapscheduler-6fed43f29bfa7929fcaf5e549ce34819ba7e3702.tar.gz
Implemented task accounting
The maximum number of concurrent jobs for a given task is now enforced if set.
Diffstat (limited to 'src/apscheduler/datastores/sync/memory.py')
-rw-r--r--src/apscheduler/datastores/sync/memory.py81
1 files changed, 68 insertions, 13 deletions
diff --git a/src/apscheduler/datastores/sync/memory.py b/src/apscheduler/datastores/sync/memory.py
index a108142..4239dc7 100644
--- a/src/apscheduler/datastores/sync/memory.py
+++ b/src/apscheduler/datastores/sync/memory.py
@@ -12,16 +12,27 @@ import attr
from ... import events
from ...abc import DataStore, Job, Schedule
from ...events import (
- EventHub, JobAdded, ScheduleAdded, ScheduleRemoved, ScheduleUpdated, SubscriptionToken)
-from ...exceptions import ConflictingIdError
+ EventHub, JobAdded, ScheduleAdded, ScheduleRemoved, ScheduleUpdated, SubscriptionToken,
+ TaskAdded, TaskRemoved)
+from ...exceptions import ConflictingIdError, TaskLookupError
from ...policies import ConflictPolicy
-from ...structures import JobResult
+from ...structures import JobResult, Task
from ...util import reentrant
max_datetime = datetime(MAXYEAR, 12, 31, 23, 59, 59, 999999, tzinfo=timezone.utc)
@attr.define
class TaskState:
    """Bookkeeping wrapper around a :class:`Task`.

    Tracks how many jobs of the task are currently running so that
    ``max_running_jobs`` can be enforced, plus any state the task wants
    persisted between runs.
    """

    task: Task
    # Number of jobs of this task currently acquired by workers
    running_jobs: int = 0
    # Opaque state saved by the task between runs (semantics defined by callers)
    saved_state: Any = None

    def __eq__(self, other):
        # Task states are identified solely by their task ID.
        # Return NotImplemented (not an AttributeError) for foreign types so
        # Python can fall back to the other operand's comparison.
        if isinstance(other, TaskState):
            return self.task.id == other.task.id

        return NotImplemented

    def __hash__(self):
        # Defining __eq__ alone would set __hash__ to None; keep the class
        # hashable with a hash consistent with the id-based equality above.
        return hash(self.task.id)
+
+
+@attr.define
class ScheduleState:
schedule: Schedule
next_fire_time: Optional[datetime] = attr.field(init=False, eq=False)
@@ -65,6 +76,7 @@ class JobState:
class MemoryDataStore(DataStore):
lock_expiration_delay: float = 30
_events: EventHub = attr.Factory(EventHub)
+ _tasks: Dict[str, TaskState] = attr.Factory(dict)
_schedules: List[ScheduleState] = attr.Factory(list)
_schedules_by_id: Dict[str, ScheduleState] = attr.Factory(dict)
_schedules_by_task_id: Dict[str, Set[ScheduleState]] = attr.Factory(partial(defaultdict, set))
@@ -101,6 +113,27 @@ class MemoryDataStore(DataStore):
return [state.schedule for state in self._schedules
if ids is None or state.schedule.id in ids]
+ def add_task(self, task: Task) -> None:
+ self._tasks[task.id] = TaskState(task)
+ self._events.publish(TaskAdded(task_id=task.id))
+
+ def remove_task(self, task_id: str) -> None:
+ try:
+ del self._tasks[task_id]
+ except KeyError:
+ raise TaskLookupError(task_id) from None
+
+ self._events.publish(TaskRemoved(task_id=task_id))
+
+ def get_task(self, task_id: str) -> Task:
+ try:
+ return self._tasks[task_id].task
+ except KeyError:
+ raise TaskLookupError(task_id) from None
+
+ def get_tasks(self) -> List[Task]:
+ return sorted((state.task for state in self._tasks.values()), key=lambda task: task.id)
+
def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None:
old_state = self._schedules_by_id.get(schedule.id)
if old_state is not None:
@@ -201,27 +234,49 @@ class MemoryDataStore(DataStore):
def acquire_jobs(self, worker_id: str, limit: Optional[int] = None) -> List[Job]:
now = datetime.now(timezone.utc)
jobs: List[Job] = []
- for state in self._jobs:
- if state.acquired_by is not None and state.acquired_until >= now:
+ for index, job_state in enumerate(self._jobs):
+ task_state = self._tasks[job_state.job.task_id]
+
+ # Skip already acquired jobs (unless the acquisition lock has expired)
+ if job_state.acquired_by is not None:
+ if job_state.acquired_until >= now:
+ continue
+ else:
+ task_state.running_jobs -= 1
+
+ # Check if the task allows one more job to be started
+ if (task_state.task.max_running_jobs is not None
+ and task_state.running_jobs >= task_state.task.max_running_jobs):
continue
- jobs.append(state.job)
- state.acquired_by = worker_id
- state.acquired_until = now + timedelta(seconds=self.lock_expiration_delay)
+ # Mark the job as acquired by this worker
+ jobs.append(job_state.job)
+ job_state.acquired_by = worker_id
+ job_state.acquired_until = now + timedelta(seconds=self.lock_expiration_delay)
+
+ # Increment the number of running jobs for this task
+ task_state.running_jobs += 1
+
+ # Exit the loop if enough jobs have been acquired
if len(jobs) == limit:
break
return jobs
- def release_job(self, worker_id: str, job_id: UUID, result: Optional[JobResult]) -> None:
+ def release_job(self, worker_id: str, job: Job, result: Optional[JobResult]) -> None:
# Delete the job
- state = self._jobs_by_id.pop(job_id)
- self._jobs_by_task_id[state.job.task_id].remove(state)
- index = self._find_job_index(state)
+ job_state = self._jobs_by_id.pop(job.id)
+ self._jobs_by_task_id[job.task_id].remove(job_state)
+ index = self._find_job_index(job_state)
del self._jobs[index]
+ # Decrement the number of running jobs for this task
+ task_state = self._tasks.get(job.task_id)
+ if task_state is not None:
+ task_state.running_jobs -= 1
+
# Record the result
- self._job_results[job_id] = result
+ self._job_results[job.id] = result
def get_job_result(self, job_id: UUID) -> Optional[JobResult]:
return self._job_results.pop(job_id, None)