diff options
author | Alex Grönholm <alex.gronholm@nextday.fi> | 2021-09-05 23:12:34 +0300 |
---|---|---|
committer | Alex Grönholm <alex.gronholm@nextday.fi> | 2021-09-06 01:39:07 +0300 |
commit | 6fed43f29bfa7929fcaf5e549ce34819ba7e3702 (patch) | |
tree | f3c8e7675a03354cb30205bd6a21e7ae231c83a6 /src/apscheduler/datastores/sync/memory.py | |
parent | 7147cd422bfc6280a56bb4335d65935e942ab9e3 (diff) | |
download | apscheduler-6fed43f29bfa7929fcaf5e549ce34819ba7e3702.tar.gz |
Implemented task accounting
The maximum number of concurrent jobs for a given task is now enforced if set.
Diffstat (limited to 'src/apscheduler/datastores/sync/memory.py')
-rw-r--r-- | src/apscheduler/datastores/sync/memory.py | 81 |
1 files changed, 68 insertions, 13 deletions
diff --git a/src/apscheduler/datastores/sync/memory.py b/src/apscheduler/datastores/sync/memory.py index a108142..4239dc7 100644 --- a/src/apscheduler/datastores/sync/memory.py +++ b/src/apscheduler/datastores/sync/memory.py @@ -12,16 +12,27 @@ import attr from ... import events from ...abc import DataStore, Job, Schedule from ...events import ( - EventHub, JobAdded, ScheduleAdded, ScheduleRemoved, ScheduleUpdated, SubscriptionToken) -from ...exceptions import ConflictingIdError + EventHub, JobAdded, ScheduleAdded, ScheduleRemoved, ScheduleUpdated, SubscriptionToken, + TaskAdded, TaskRemoved) +from ...exceptions import ConflictingIdError, TaskLookupError from ...policies import ConflictPolicy -from ...structures import JobResult +from ...structures import JobResult, Task from ...util import reentrant max_datetime = datetime(MAXYEAR, 12, 31, 23, 59, 59, 999999, tzinfo=timezone.utc) @attr.define +class TaskState: + task: Task + running_jobs: int = 0 + saved_state: Any = None + + def __eq__(self, other): + return self.task.id == other.task.id + + +@attr.define class ScheduleState: schedule: Schedule next_fire_time: Optional[datetime] = attr.field(init=False, eq=False) @@ -65,6 +76,7 @@ class JobState: class MemoryDataStore(DataStore): lock_expiration_delay: float = 30 _events: EventHub = attr.Factory(EventHub) + _tasks: Dict[str, TaskState] = attr.Factory(dict) _schedules: List[ScheduleState] = attr.Factory(list) _schedules_by_id: Dict[str, ScheduleState] = attr.Factory(dict) _schedules_by_task_id: Dict[str, Set[ScheduleState]] = attr.Factory(partial(defaultdict, set)) @@ -101,6 +113,27 @@ class MemoryDataStore(DataStore): return [state.schedule for state in self._schedules if ids is None or state.schedule.id in ids] + def add_task(self, task: Task) -> None: + self._tasks[task.id] = TaskState(task) + self._events.publish(TaskAdded(task_id=task.id)) + + def remove_task(self, task_id: str) -> None: + try: + del self._tasks[task_id] + except KeyError: + raise TaskLookupError(task_id) from None + + self._events.publish(TaskRemoved(task_id=task_id)) + + def get_task(self, task_id: str) -> Task: + try: + return self._tasks[task_id].task + except KeyError: + raise TaskLookupError(task_id) from None + + def get_tasks(self) -> List[Task]: + return sorted((state.task for state in self._tasks.values()), key=lambda task: task.id) + def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None: old_state = self._schedules_by_id.get(schedule.id) if old_state is not None: @@ -201,27 +234,49 @@ class MemoryDataStore(DataStore): def acquire_jobs(self, worker_id: str, limit: Optional[int] = None) -> List[Job]: now = datetime.now(timezone.utc) jobs: List[Job] = [] - for state in self._jobs: - if state.acquired_by is not None and state.acquired_until >= now: + for index, job_state in enumerate(self._jobs): + task_state = self._tasks[job_state.job.task_id] + + # Skip already acquired jobs (unless the acquisition lock has expired) + if job_state.acquired_by is not None: + if job_state.acquired_until >= now: + continue + else: + task_state.running_jobs -= 1 + + # Check if the task allows one more job to be started + if (task_state.task.max_running_jobs is not None + and task_state.running_jobs >= task_state.task.max_running_jobs): continue - jobs.append(state.job) - state.acquired_by = worker_id - state.acquired_until = now + timedelta(seconds=self.lock_expiration_delay) + # Mark the job as acquired by this worker + jobs.append(job_state.job) + job_state.acquired_by = worker_id + job_state.acquired_until = now + timedelta(seconds=self.lock_expiration_delay) + + # Increment the number of running jobs for this task + task_state.running_jobs += 1 + + # Exit the loop if enough jobs have been acquired if len(jobs) == limit: break return jobs - def release_job(self, worker_id: str, job_id: UUID, result: Optional[JobResult]) -> None: + def release_job(self, worker_id: str, job: Job, result: Optional[JobResult]) -> None: # Delete the job - state = self._jobs_by_id.pop(job_id) - self._jobs_by_task_id[state.job.task_id].remove(state) - index = self._find_job_index(state) + job_state = self._jobs_by_id.pop(job.id) + self._jobs_by_task_id[job.task_id].remove(job_state) + index = self._find_job_index(job_state) del self._jobs[index] + # Decrement the number of running jobs for this task + task_state = self._tasks.get(job.task_id) + if task_state is not None: + task_state.running_jobs -= 1 + # Record the result - self._job_results[job_id] = result + self._job_results[job.id] = result def get_job_result(self, job_id: UUID) -> Optional[JobResult]: return self._job_results.pop(job_id, None) |