Diffstat (limited to 'src/apscheduler/datastores/async_/sqlalchemy.py')
-rw-r--r--  src/apscheduler/datastores/async_/sqlalchemy.py | 132
 1 file changed, 119 insertions(+), 13 deletions(-)
diff --git a/src/apscheduler/datastores/async_/sqlalchemy.py b/src/apscheduler/datastores/async_/sqlalchemy.py
index caf832c..00e9efc 100644
--- a/src/apscheduler/datastores/async_/sqlalchemy.py
+++ b/src/apscheduler/datastores/async_/sqlalchemy.py
@@ -2,6 +2,7 @@ from __future__ import annotations
import json
import logging
+from collections import defaultdict
from contextlib import AsyncExitStack, closing
from datetime import datetime, timedelta, timezone
from json import JSONDecodeError
@@ -18,17 +19,19 @@ from sqlalchemy.exc import CompileError, IntegrityError
from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine
from sqlalchemy.ext.asyncio.engine import AsyncEngine
from sqlalchemy.sql.ddl import DropTable
-from sqlalchemy.sql.elements import BindParameter
+from sqlalchemy.sql.elements import BindParameter, literal
from ... import events as events_module
from ...abc import AsyncDataStore, Job, Schedule, Serializer
from ...events import (
AsyncEventHub, DataStoreEvent, Event, JobAdded, JobDeserializationFailed, ScheduleAdded,
- ScheduleDeserializationFailed, ScheduleRemoved, ScheduleUpdated, SubscriptionToken)
-from ...exceptions import ConflictingIdError, SerializationError
+ ScheduleDeserializationFailed, ScheduleRemoved, ScheduleUpdated, SubscriptionToken, TaskAdded,
+ TaskRemoved)
+from ...exceptions import ConflictingIdError, SerializationError, TaskLookupError
+from ...marshalling import callable_to_ref
from ...policies import ConflictPolicy
from ...serializers.pickle import PickleSerializer
-from ...structures import JobResult
+from ...structures import JobResult, Task
from ...util import reentrant
logger = logging.getLogger(__name__)
@@ -78,6 +81,7 @@ class SQLAlchemyDataStore(AsyncDataStore):
# Generate the table definitions
self._metadata = self.get_table_definitions()
self.t_metadata = self._metadata.tables['metadata']
+ self.t_tasks = self._metadata.tables['tasks']
self.t_schedules = self._metadata.tables['schedules']
self.t_jobs = self._metadata.tables['jobs']
self.t_job_results = self._metadata.tables['job_results']
@@ -153,7 +157,11 @@ class SQLAlchemyDataStore(AsyncDataStore):
'tasks',
metadata,
Column('id', Unicode(500), primary_key=True),
- Column('serialized_data', LargeBinary, nullable=False)
+ Column('func', Unicode(500), nullable=False),
+ Column('state', LargeBinary),
+ Column('max_running_jobs', Integer),
+ Column('misfire_grace_time', Unicode(16)),
+ Column('running_jobs', Integer, nullable=False, server_default=literal(0))
)
Table(
'schedules',
@@ -259,6 +267,55 @@ class SQLAlchemyDataStore(AsyncDataStore):
def unsubscribe(self, token: SubscriptionToken) -> None:
self._events.unsubscribe(token)
+ async def add_task(self, task: Task) -> None:
+ insert = self.t_tasks.insert().\
+ values(id=task.id, func=callable_to_ref(task.func),
+ max_running_jobs=task.max_running_jobs,
+ misfire_grace_time=task.misfire_grace_time)
+ try:
+ async with self.engine.begin() as conn:
+ await conn.execute(insert)
+ except IntegrityError:
+ update = self.t_tasks.update().\
+ values(func=callable_to_ref(task.func), max_running_jobs=task.max_running_jobs,
+ misfire_grace_time=task.misfire_grace_time).\
+ where(self.t_tasks.c.id == task.id)
+ async with self.engine.begin() as conn:
+ await conn.execute(update)
+
+ self._events.publish(TaskAdded(task_id=task.id))
+
+ async def remove_task(self, task_id: str) -> None:
+ delete = self.t_tasks.delete().where(self.t_tasks.c.id == task_id)
+ async with self.engine.begin() as conn:
+ result = await conn.execute(delete)
+ if result.rowcount == 0:
+ raise TaskLookupError(task_id)
+ else:
+ self._events.publish(TaskRemoved(task_id=task_id))
+
+ async def get_task(self, task_id: str) -> Task:
+ query = select([self.t_tasks.c.id, self.t_tasks.c.func, self.t_tasks.c.max_running_jobs,
+ self.t_tasks.c.state, self.t_tasks.c.misfire_grace_time]).\
+ where(self.t_tasks.c.id == task_id)
+ async with self.engine.begin() as conn:
+ result = await conn.execute(query)
+            row = result.first()
+
+ if row:
+ return Task.unmarshal(self.serializer, row._asdict())
+ else:
+            raise TaskLookupError(task_id)
+
+ async def get_tasks(self) -> List[Task]:
+ query = select([self.t_tasks.c.id, self.t_tasks.c.func, self.t_tasks.c.max_running_jobs,
+ self.t_tasks.c.state, self.t_tasks.c.misfire_grace_time]).\
+ order_by(self.t_tasks.c.id)
+ async with self.engine.begin() as conn:
+ result = await conn.execute(query)
+ tasks = [Task.unmarshal(self.serializer, row._asdict()) for row in result]
+ return tasks
+
async def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None:
event: Event
serialized_data = self.serializer.serialize(schedule)
@@ -436,30 +493,79 @@ class SQLAlchemyDataStore(AsyncDataStore):
now = datetime.now(timezone.utc)
acquired_until = now + timedelta(seconds=self.lock_expiration_delay)
query = select([self.t_jobs.c.id, self.t_jobs.c.serialized_data]).\
+ join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id).\
where(or_(self.t_jobs.c.acquired_until.is_(None),
self.t_jobs.c.acquired_until < now)).\
order_by(self.t_jobs.c.created_at).\
limit(limit)
- serialized_jobs: Dict[str, bytes] = {row[0]: row[1]
- for row in await conn.execute(query)}
- if serialized_jobs:
+                result = await conn.execute(query)
+                rows = result.fetchall()
+                if not rows:
+                    return []
+
+                # Deserialize the jobs eligible for acquisition
+                jobs = self._deserialize_jobs(rows)
+ task_ids: Set[str] = {job.task_id for job in jobs}
+
+ # Retrieve the limits
+ query = select([self.t_tasks.c.id,
+ self.t_tasks.c.max_running_jobs - self.t_tasks.c.running_jobs]).\
+ where(self.t_tasks.c.max_running_jobs.isnot(None),
+ self.t_tasks.c.id.in_(task_ids))
+ result = await conn.execute(query)
+ job_slots_left = dict(result.fetchall())
+
+ # Filter out jobs that don't have free slots
+ acquired_jobs: List[Job] = []
+                increments: Dict[str, int] = defaultdict(int)
+ for job in jobs:
+ # Don't acquire the job if there are no free slots left
+ slots_left = job_slots_left.get(job.task_id)
+ if slots_left == 0:
+ continue
+ elif slots_left is not None:
+ job_slots_left[job.task_id] -= 1
+
+ acquired_jobs.append(job)
+ increments[job.task_id] += 1
+
+ if acquired_jobs:
+ # Mark the acquired jobs as acquired by this worker
+ acquired_job_ids = [job.id.hex for job in acquired_jobs]
update = self.t_jobs.update().\
values(acquired_by=worker_id, acquired_until=acquired_until).\
- where(self.t_jobs.c.id.in_(serialized_jobs))
+ where(self.t_jobs.c.id.in_(acquired_job_ids))
await conn.execute(update)
- return self._deserialize_jobs(serialized_jobs.items())
+ # Increment the running job counters on each task
+ p_id = bindparam('p_id')
+ p_increment = bindparam('p_increment')
+ params = [{'p_id': task_id, 'p_increment': increment}
+ for task_id, increment in increments.items()]
+ update = self.t_tasks.update().\
+ values(running_jobs=self.t_tasks.c.running_jobs + p_increment).\
+ where(self.t_tasks.c.id == p_id)
+ await conn.execute(update, params)
+
+ return acquired_jobs
- async def release_job(self, worker_id: str, job_id: UUID, result: Optional[JobResult]) -> None:
+ async def release_job(self, worker_id: str, job: Job, result: Optional[JobResult]) -> None:
async with self.engine.begin() as conn:
+ # Insert the job result
now = datetime.now(timezone.utc)
serialized_data = self.serializer.serialize(result)
insert = self.t_job_results.insert().\
- values(job_id=job_id.hex, finished_at=now, serialized_data=serialized_data)
+ values(job_id=job.id.hex, finished_at=now, serialized_data=serialized_data)
await conn.execute(insert)
- delete = self.t_jobs.delete().where(self.t_jobs.c.id == job_id.hex)
+ # Decrement the running jobs counter
+ update = self.t_tasks.update().\
+ values(running_jobs=self.t_tasks.c.running_jobs - 1).\
+ where(self.t_tasks.c.id == job.task_id)
+ await conn.execute(update)
+
+ # Delete the job
+ delete = self.t_jobs.delete().where(self.t_jobs.c.id == job.id.hex)
await conn.execute(delete)
async def get_job_result(self, job_id: UUID) -> Optional[JobResult]:
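For context, here is a minimal sketch of how a caller might exercise the task API this diff introduces. It is illustrative only: the demo_task function and the way the data store instance is obtained are assumptions, the Task constructor fields are inferred from the columns used above (id, func, max_running_jobs, misfire_grace_time), and the acquire_jobs(worker_id, limit) signature is read off the query code in this hunk rather than confirmed against the full file.

    from datetime import timedelta

    from apscheduler.structures import Task


    def demo_task() -> None:
        # Hypothetical job function; add_task() stores it as an importable
        # reference via callable_to_ref(), so it must live at module level.
        print("running demo_task")


    async def run_once(store, worker_id: str = 'worker-1') -> None:
        # `store` is assumed to be an initialized SQLAlchemyDataStore.
        # add_task() tries an INSERT and falls back to an UPDATE on
        # IntegrityError, so re-registering an existing task id is safe.
        await store.add_task(Task(id='demo', func=demo_task, max_running_jobs=2,
                                  misfire_grace_time=timedelta(seconds=30)))

        # acquire_jobs() now joins jobs against tasks, skips jobs whose task
        # has no free slots (max_running_jobs - running_jobs == 0), and
        # increments running_jobs for each job it hands out.
        jobs = await store.acquire_jobs(worker_id, limit=10)
        for job in jobs:
            ...  # execute the job here
            # release_job() records the result, decrements running_jobs and
            # deletes the job row.
            await store.release_job(worker_id, job, result=None)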