Diffstat (limited to 'src/apscheduler/datastores/async_/sqlalchemy.py')
-rw-r--r--  src/apscheduler/datastores/async_/sqlalchemy.py  132
1 file changed, 119 insertions, 13 deletions
diff --git a/src/apscheduler/datastores/async_/sqlalchemy.py b/src/apscheduler/datastores/async_/sqlalchemy.py
index caf832c..00e9efc 100644
--- a/src/apscheduler/datastores/async_/sqlalchemy.py
+++ b/src/apscheduler/datastores/async_/sqlalchemy.py
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import json
 import logging
+from collections import defaultdict
 from contextlib import AsyncExitStack, closing
 from datetime import datetime, timedelta, timezone
 from json import JSONDecodeError
@@ -18,17 +19,19 @@ from sqlalchemy.exc import CompileError, IntegrityError
 from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine
 from sqlalchemy.ext.asyncio.engine import AsyncEngine
 from sqlalchemy.sql.ddl import DropTable
-from sqlalchemy.sql.elements import BindParameter
+from sqlalchemy.sql.elements import BindParameter, literal
 
 from ... import events as events_module
 from ...abc import AsyncDataStore, Job, Schedule, Serializer
 from ...events import (
     AsyncEventHub, DataStoreEvent, Event, JobAdded, JobDeserializationFailed, ScheduleAdded,
-    ScheduleDeserializationFailed, ScheduleRemoved, ScheduleUpdated, SubscriptionToken)
-from ...exceptions import ConflictingIdError, SerializationError
+    ScheduleDeserializationFailed, ScheduleRemoved, ScheduleUpdated, SubscriptionToken, TaskAdded,
+    TaskRemoved)
+from ...exceptions import ConflictingIdError, SerializationError, TaskLookupError
+from ...marshalling import callable_to_ref
 from ...policies import ConflictPolicy
 from ...serializers.pickle import PickleSerializer
-from ...structures import JobResult
+from ...structures import JobResult, Task
 from ...util import reentrant
 
 logger = logging.getLogger(__name__)
@@ -78,6 +81,7 @@ class SQLAlchemyDataStore(AsyncDataStore):
         # Generate the table definitions
         self._metadata = self.get_table_definitions()
         self.t_metadata = self._metadata.tables['metadata']
+        self.t_tasks = self._metadata.tables['tasks']
         self.t_schedules = self._metadata.tables['schedules']
         self.t_jobs = self._metadata.tables['jobs']
         self.t_job_results = self._metadata.tables['job_results']
@@ -153,7 +157,11 @@ class SQLAlchemyDataStore(AsyncDataStore):
             'tasks',
             metadata,
             Column('id', Unicode(500), primary_key=True),
-            Column('serialized_data', LargeBinary, nullable=False)
+            Column('func', Unicode(500), nullable=False),
+            Column('state', LargeBinary),
+            Column('max_running_jobs', Integer),
+            Column('misfire_grace_time', Unicode(16)),
+            Column('running_jobs', Integer, nullable=False, server_default=literal(0))
         )
         Table(
             'schedules',
@@ -259,6 +267,55 @@ class SQLAlchemyDataStore(AsyncDataStore):
     def unsubscribe(self, token: SubscriptionToken) -> None:
         self._events.unsubscribe(token)
 
+    async def add_task(self, task: Task) -> None:
+        insert = self.t_tasks.insert().\
+            values(id=task.id, func=callable_to_ref(task.func),
+                   max_running_jobs=task.max_running_jobs,
+                   misfire_grace_time=task.misfire_grace_time)
+        try:
+            async with self.engine.begin() as conn:
+                await conn.execute(insert)
+        except IntegrityError:
+            update = self.t_tasks.update().\
+                values(func=callable_to_ref(task.func), max_running_jobs=task.max_running_jobs,
+                       misfire_grace_time=task.misfire_grace_time).\
+                where(self.t_tasks.c.id == task.id)
+            async with self.engine.begin() as conn:
+                await conn.execute(update)
+
+        self._events.publish(TaskAdded(task_id=task.id))
+
+    async def remove_task(self, task_id: str) -> None:
+        delete = self.t_tasks.delete().where(self.t_tasks.c.id == task_id)
+        async with self.engine.begin() as conn:
+            result = await conn.execute(delete)
+            if result.rowcount == 0:
+                raise TaskLookupError(task_id)
+            else:
+                self._events.publish(TaskRemoved(task_id=task_id))
+
+    async def get_task(self, task_id: str) -> Task:
+        query = select([self.t_tasks.c.id, self.t_tasks.c.func, self.t_tasks.c.max_running_jobs,
+                        self.t_tasks.c.state, self.t_tasks.c.misfire_grace_time]).\
+            where(self.t_tasks.c.id == task_id)
+        async with self.engine.begin() as conn:
+            result = await conn.execute(query)
+            row = result.fetchone()
+
+        if row:
+            return Task.unmarshal(self.serializer, row._asdict())
+        else:
+            raise TaskLookupError(task_id)
+
+    async def get_tasks(self) -> List[Task]:
+        query = select([self.t_tasks.c.id, self.t_tasks.c.func, self.t_tasks.c.max_running_jobs,
+                        self.t_tasks.c.state, self.t_tasks.c.misfire_grace_time]).\
+            order_by(self.t_tasks.c.id)
+        async with self.engine.begin() as conn:
+            result = await conn.execute(query)
+            tasks = [Task.unmarshal(self.serializer, row._asdict()) for row in result]
+            return tasks
+
     async def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None:
         event: Event
         serialized_data = self.serializer.serialize(schedule)
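Note: add_task() above turns a primary-key collision into an update by catching IntegrityError. A minimal standalone sketch of that insert-or-update pattern, assuming SQLAlchemy 1.4 with the aiosqlite driver; the upsert_task function and the cut-down tasks table here are illustrative, not part of APScheduler:

import asyncio

from sqlalchemy import Column, Integer, MetaData, Table, Unicode
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.pool import StaticPool

metadata = MetaData()
tasks = Table(
    'tasks', metadata,
    Column('id', Unicode(500), primary_key=True),
    Column('max_running_jobs', Integer))

async def upsert_task(engine, task_id: str, max_running_jobs: int) -> None:
    insert = tasks.insert().values(id=task_id, max_running_jobs=max_running_jobs)
    try:
        async with engine.begin() as conn:
            await conn.execute(insert)
    except IntegrityError:
        # Primary key collision: the task already exists, so update it instead
        update = tasks.update().\
            values(max_running_jobs=max_running_jobs).\
            where(tasks.c.id == task_id)
        async with engine.begin() as conn:
            await conn.execute(update)

async def main() -> None:
    # StaticPool keeps one shared in-memory database across pooled connections
    engine = create_async_engine('sqlite+aiosqlite://', poolclass=StaticPool)
    async with engine.begin() as conn:
        await conn.run_sync(metadata.create_all)
    await upsert_task(engine, 'my_task', 2)
    await upsert_task(engine, 'my_task', 5)  # second call takes the update path

asyncio.run(main())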
@@ -436,30 +493,79 @@ class SQLAlchemyDataStore(AsyncDataStore):
             now = datetime.now(timezone.utc)
             acquired_until = now + timedelta(seconds=self.lock_expiration_delay)
             query = select([self.t_jobs.c.id, self.t_jobs.c.serialized_data]).\
+                join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id).\
                 where(or_(self.t_jobs.c.acquired_until.is_(None),
                           self.t_jobs.c.acquired_until < now)).\
                 order_by(self.t_jobs.c.created_at).\
                 limit(limit)
-            serialized_jobs: Dict[str, bytes] = {row[0]: row[1]
-                                                 for row in await conn.execute(query)}
-            if serialized_jobs:
+            result = await conn.execute(query)
+            if not result:
+                return []
+
+            # Deserialize the jobs and note which tasks they belong to
+            jobs = self._deserialize_jobs(result)
+            task_ids: Set[str] = {job.task_id for job in jobs}
+
+            # Retrieve the limits
+            query = select([self.t_tasks.c.id,
+                            self.t_tasks.c.max_running_jobs - self.t_tasks.c.running_jobs]).\
+                where(self.t_tasks.c.max_running_jobs.isnot(None),
+                      self.t_tasks.c.id.in_(task_ids))
+            result = await conn.execute(query)
+            job_slots_left = dict(result.fetchall())
+
+            # Filter out jobs that don't have free slots
+            acquired_jobs: List[Job] = []
+            increments: Dict[str, int] = defaultdict(lambda: 0)
+            for job in jobs:
+                # Don't acquire the job if there are no free slots left
+                slots_left = job_slots_left.get(job.task_id)
+                if slots_left == 0:
+                    continue
+                elif slots_left is not None:
+                    job_slots_left[job.task_id] -= 1
+
+                acquired_jobs.append(job)
+                increments[job.task_id] += 1
+
+            if acquired_jobs:
+                # Mark the acquired jobs as acquired by this worker
+                acquired_job_ids = [job.id.hex for job in acquired_jobs]
                 update = self.t_jobs.update().\
                     values(acquired_by=worker_id, acquired_until=acquired_until).\
-                    where(self.t_jobs.c.id.in_(serialized_jobs))
+                    where(self.t_jobs.c.id.in_(acquired_job_ids))
                 await conn.execute(update)
 
-            return self._deserialize_jobs(serialized_jobs.items())
+                # Increment the running job counters on each task
+                p_id = bindparam('p_id')
+                p_increment = bindparam('p_increment')
+                params = [{'p_id': task_id, 'p_increment': increment}
+                          for task_id, increment in increments.items()]
+                update = self.t_tasks.update().\
+                    values(running_jobs=self.t_tasks.c.running_jobs + p_increment).\
+                    where(self.t_tasks.c.id == p_id)
+                await conn.execute(update, params)
+
+            return acquired_jobs
 
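Note: the filtering loop in acquire_jobs() above is the core of the new per-task concurrency cap. Tasks with max_running_jobs set appear in job_slots_left; a job is skipped once its task's slots reach zero, while tasks with no cap (absent from the mapping) are always acquired. A self-contained sketch of just that logic; filter_by_slots is an illustrative standalone function, not part of the diff:

from collections import defaultdict
from typing import Dict, List, Optional, Tuple

def filter_by_slots(jobs: List[Tuple[str, str]],
                    job_slots_left: Dict[str, int]
                    ) -> Tuple[List[Tuple[str, str]], Dict[str, int]]:
    """jobs is a list of (job_id, task_id) pairs; job_slots_left maps
    task_id -> remaining slots, only for tasks with max_running_jobs set."""
    acquired: List[Tuple[str, str]] = []
    increments: Dict[str, int] = defaultdict(lambda: 0)
    for job_id, task_id in jobs:
        slots_left: Optional[int] = job_slots_left.get(task_id)
        if slots_left == 0:
            continue  # capped task with no capacity left; skip this job
        elif slots_left is not None:
            job_slots_left[task_id] -= 1  # consume one slot

        acquired.append((job_id, task_id))
        increments[task_id] += 1

    return acquired, dict(increments)

# 'capped' allows one more running job; 'free' has no limit at all
acquired, increments = filter_by_slots(
    [('j1', 'capped'), ('j2', 'capped'), ('j3', 'free')],
    {'capped': 1})
assert [job_id for job_id, _ in acquired] == ['j1', 'j3']
assert increments == {'capped': 1, 'free': 1}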
-    async def release_job(self, worker_id: str, job_id: UUID, result: Optional[JobResult]) -> None:
+    async def release_job(self, worker_id: str, job: Job, result: Optional[JobResult]) -> None:
         async with self.engine.begin() as conn:
+            # Insert the job result
             now = datetime.now(timezone.utc)
             serialized_data = self.serializer.serialize(result)
             insert = self.t_job_results.insert().\
-                values(job_id=job_id.hex, finished_at=now, serialized_data=serialized_data)
+                values(job_id=job.id.hex, finished_at=now, serialized_data=serialized_data)
             await conn.execute(insert)
 
-            delete = self.t_jobs.delete().where(self.t_jobs.c.id == job_id.hex)
+            # Decrement the running jobs counter
+            update = self.t_tasks.update().\
+                values(running_jobs=self.t_tasks.c.running_jobs - 1).\
+                where(self.t_tasks.c.id == job.task_id)
+            await conn.execute(update)
+
+            # Delete the job
+            delete = self.t_jobs.delete().where(self.t_jobs.c.id == job.id.hex)
             await conn.execute(delete)
 
     async def get_job_result(self, job_id: UUID) -> Optional[JobResult]:
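Note: the counter increments in acquire_jobs() use SQLAlchemy's "executemany" form: one UPDATE compiled with explicit bindparam() placeholders and executed once per parameter dictionary, with p_-prefixed names chosen so they cannot clash with the bind names SQLAlchemy generates for the columns themselves. A runnable sketch of the pattern, assuming SQLAlchemy 1.4 with the aiosqlite driver; the cut-down tasks table and values here are illustrative:

import asyncio

from sqlalchemy import Column, Integer, MetaData, Table, Unicode, bindparam, select
from sqlalchemy.ext.asyncio import create_async_engine

metadata = MetaData()
tasks = Table(
    'tasks', metadata,
    Column('id', Unicode(500), primary_key=True),
    Column('running_jobs', Integer, nullable=False, default=0))

async def main() -> None:
    engine = create_async_engine('sqlite+aiosqlite://')
    async with engine.begin() as conn:
        await conn.run_sync(metadata.create_all)
        await conn.execute(tasks.insert(),
                           [{'id': 'task_a'}, {'id': 'task_b'}])

        # One UPDATE statement, executed once per parameter dictionary
        update = tasks.update().\
            values(running_jobs=tasks.c.running_jobs + bindparam('p_increment')).\
            where(tasks.c.id == bindparam('p_id'))
        await conn.execute(update, [{'p_id': 'task_a', 'p_increment': 2},
                                    {'p_id': 'task_b', 'p_increment': 1}])

        result = await conn.execute(select([tasks.c.id, tasks.c.running_jobs]))
        print(result.fetchall())  # [('task_a', 2), ('task_b', 1)]

asyncio.run(main())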