diff options
author | Alex Grönholm <alex.gronholm@nextday.fi> | 2021-09-26 19:51:02 +0300 |
---|---|---|
committer | Alex Grönholm <alex.gronholm@nextday.fi> | 2021-09-26 19:51:02 +0300 |
commit | 410a6667beeea383ae174583193bc8c2a9c21f84 (patch) | |
tree | 3f7fd1fbafc837e5754718ddcf57d2afecd6dbc8 /src/apscheduler/datastores | |
parent | e207204d97c1cb89eb76f21759fa53fa0ec94b5d (diff) | |
download | apscheduler-410a6667beeea383ae174583193bc8c2a9c21f84.tar.gz |
Added the ability to retry operations to all persistent data stores
Diffstat (limited to 'src/apscheduler/datastores')
-rw-r--r-- | src/apscheduler/datastores/async_sqlalchemy.py | 502 | ||||
-rw-r--r-- | src/apscheduler/datastores/mongodb.py | 427 | ||||
-rw-r--r-- | src/apscheduler/datastores/sqlalchemy.py | 467 |
3 files changed, 769 insertions, 627 deletions
diff --git a/src/apscheduler/datastores/async_sqlalchemy.py b/src/apscheduler/datastores/async_sqlalchemy.py index 8f1632e..3c6d963 100644 --- a/src/apscheduler/datastores/async_sqlalchemy.py +++ b/src/apscheduler/datastores/async_sqlalchemy.py @@ -5,11 +5,13 @@ from datetime import datetime, timedelta, timezone from typing import Any, Iterable, Optional from uuid import UUID +import anyio import attr import sniffio +import tenacity from sqlalchemy import and_, bindparam, or_, select from sqlalchemy.engine import URL, Result -from sqlalchemy.exc import IntegrityError +from sqlalchemy.exc import IntegrityError, InterfaceError from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.ext.asyncio.engine import AsyncEngine from sqlalchemy.sql.ddl import DropTable @@ -33,39 +35,54 @@ from .sqlalchemy import _BaseSQLAlchemyDataStore @attr.define(eq=False) class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore): engine: AsyncEngine + _events: AsyncEventBroker = attr.field(factory=LocalAsyncEventBroker) + _retrying: tenacity.AsyncRetrying = attr.field(init=False) @classmethod def from_url(cls, url: str | URL, **options) -> AsyncSQLAlchemyDataStore: engine = create_async_engine(url, future=True) return cls(engine, **options) + def __attrs_post_init__(self) -> None: + super().__attrs_post_init__() + + # Construct the Tenacity retry controller + # OSError is raised by asyncpg if it can't connect + self._retrying = tenacity.AsyncRetrying( + stop=self.retry_settings.stop, wait=self.retry_settings.wait, + retry=tenacity.retry_if_exception_type((InterfaceError, OSError)), + after=self._after_attempt, sleep=anyio.sleep, reraise=True) + async def __aenter__(self): asynclib = sniffio.current_async_library() or '(unknown)' if asynclib != 'asyncio': raise RuntimeError(f'This data store requires asyncio; currently running: {asynclib}') # Verify that the schema is in place - async with self.engine.begin() as conn: - if self.start_from_scratch: - for table in self._metadata.sorted_tables: - await conn.execute(DropTable(table, if_exists=True)) - - await conn.run_sync(self._metadata.create_all) - query = select(self.t_metadata.c.schema_version) - result = await conn.execute(query) - version = result.scalar() - if version is None: - await conn.execute(self.t_metadata.insert(values={'schema_version': 1})) - elif version > 1: - raise RuntimeError(f'Unexpected schema version ({version}); ' - f'only version 1 is supported by this version of APScheduler') - - await self.events.__aenter__() + async for attempt in self._retrying: + with attempt: + async with self.engine.begin() as conn: + if self.start_from_scratch: + for table in self._metadata.sorted_tables: + await conn.execute(DropTable(table, if_exists=True)) + + await conn.run_sync(self._metadata.create_all) + query = select(self.t_metadata.c.schema_version) + result = await conn.execute(query) + version = result.scalar() + if version is None: + await conn.execute(self.t_metadata.insert(values={'schema_version': 1})) + elif version > 1: + raise RuntimeError( + f'Unexpected schema version ({version}); ' + f'only version 1 is supported by this version of APScheduler') + + await self._events.__aenter__() return self async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.events.__aexit__(exc_type, exc_val, exc_tb) + await self._events.__aexit__(exc_type, exc_val, exc_tb) @property def events(self) -> EventSource: @@ -99,15 +116,19 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore): max_running_jobs=task.max_running_jobs, misfire_grace_time=task.misfire_grace_time) try: - async with self.engine.begin() as conn: - await conn.execute(insert) + async for attempt in self._retrying: + with attempt: + async with self.engine.begin() as conn: + await conn.execute(insert) except IntegrityError: update = self.t_tasks.update().\ values(func=callable_to_ref(task.func), max_running_jobs=task.max_running_jobs, misfire_grace_time=task.misfire_grace_time).\ where(self.t_tasks.c.id == task.id) - async with self.engine.begin() as conn: - await conn.execute(update) + async for attempt in self._retrying: + with attempt: + async with self.engine.begin() as conn: + await conn.execute(update) await self._events.publish(TaskUpdated(task_id=task.id)) else: @@ -115,20 +136,24 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore): async def remove_task(self, task_id: str) -> None: delete = self.t_tasks.delete().where(self.t_tasks.c.id == task_id) - async with self.engine.begin() as conn: - result = await conn.execute(delete) - if result.rowcount == 0: - raise TaskLookupError(task_id) - else: - await self._events.publish(TaskRemoved(task_id=task_id)) + async for attempt in self._retrying: + with attempt: + async with self.engine.begin() as conn: + result = await conn.execute(delete) + if result.rowcount == 0: + raise TaskLookupError(task_id) + else: + await self._events.publish(TaskRemoved(task_id=task_id)) async def get_task(self, task_id: str) -> Task: query = select([self.t_tasks.c.id, self.t_tasks.c.func, self.t_tasks.c.max_running_jobs, self.t_tasks.c.state, self.t_tasks.c.misfire_grace_time]).\ where(self.t_tasks.c.id == task_id) - async with self.engine.begin() as conn: - result = await conn.execute(query) - row = result.fetch_one() + async for attempt in self._retrying: + with attempt: + async with self.engine.begin() as conn: + result = await conn.execute(query) + row = result.one() if row: return Task.unmarshal(self.serializer, row._asdict()) @@ -139,18 +164,22 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore): query = select([self.t_tasks.c.id, self.t_tasks.c.func, self.t_tasks.c.max_running_jobs, self.t_tasks.c.state, self.t_tasks.c.misfire_grace_time]).\ order_by(self.t_tasks.c.id) - async with self.engine.begin() as conn: - result = await conn.execute(query) - tasks = [Task.unmarshal(self.serializer, row._asdict()) for row in result] - return tasks + async for attempt in self._retrying: + with attempt: + async with self.engine.begin() as conn: + result = await conn.execute(query) + tasks = [Task.unmarshal(self.serializer, row._asdict()) for row in result] + return tasks async def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None: event: DataStoreEvent values = schedule.marshal(self.serializer) insert = self.t_schedules.insert().values(**values) try: - async with self.engine.begin() as conn: - await conn.execute(insert) + async for attempt in self._retrying: + with attempt: + async with self.engine.begin() as conn: + await conn.execute(insert) except IntegrityError: if conflict_policy is ConflictPolicy.exception: raise ConflictingIdError(schedule.id) from None @@ -159,8 +188,9 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore): update = self.t_schedules.update().\ where(self.t_schedules.c.id == schedule.id).\ values(**values) - async with self.engine.begin() as conn: - await conn.execute(update) + async for attempt in self._retrying: + async with attempt, self.engine.begin() as conn: + await conn.execute(update) event = ScheduleUpdated(schedule_id=schedule.id, next_fire_time=schedule.next_fire_time) @@ -171,15 +201,17 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore): await self._events.publish(event) async def remove_schedules(self, ids: Iterable[str]) -> None: - async with self.engine.begin() as conn: - delete = self.t_schedules.delete().where(self.t_schedules.c.id.in_(ids)) - if self._supports_update_returning: - delete = delete.returning(self.t_schedules.c.id) - removed_ids: Iterable[str] = [row[0] for row in await conn.execute(delete)] - else: - # TODO: actually check which rows were deleted? - await conn.execute(delete) - removed_ids = ids + async for attempt in self._retrying: + with attempt: + async with self.engine.begin() as conn: + delete = self.t_schedules.delete().where(self.t_schedules.c.id.in_(ids)) + if self._supports_update_returning: + delete = delete.returning(self.t_schedules.c.id) + removed_ids: Iterable[str] = [row[0] for row in await conn.execute(delete)] + else: + # TODO: actually check which rows were deleted? + await conn.execute(delete) + removed_ids = ids for schedule_id in removed_ids: await self._events.publish(ScheduleRemoved(schedule_id=schedule_id)) @@ -189,90 +221,99 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore): if ids: query = query.where(self.t_schedules.c.id.in_(ids)) - async with self.engine.begin() as conn: - result = await conn.execute(query) - return await self._deserialize_schedules(result) + async for attempt in self._retrying: + with attempt: + async with self.engine.begin() as conn: + result = await conn.execute(query) + return await self._deserialize_schedules(result) async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]: - async with self.engine.begin() as conn: - now = datetime.now(timezone.utc) - acquired_until = now + timedelta(seconds=self.lock_expiration_delay) - schedules_cte = select(self.t_schedules.c.id).\ - where(and_(self.t_schedules.c.next_fire_time.isnot(None), - self.t_schedules.c.next_fire_time <= now, - or_(self.t_schedules.c.acquired_until.is_(None), - self.t_schedules.c.acquired_until < now))).\ - order_by(self.t_schedules.c.next_fire_time).\ - limit(limit).with_for_update(skip_locked=True).cte() - subselect = select([schedules_cte.c.id]) - update = self.t_schedules.update().\ - where(self.t_schedules.c.id.in_(subselect)).\ - values(acquired_by=scheduler_id, acquired_until=acquired_until) - if self._supports_update_returning: - update = update.returning(*self.t_schedules.columns) - result = await conn.execute(update) - else: - await conn.execute(update) - query = self.t_schedules.select().\ - where(and_(self.t_schedules.c.acquired_by == scheduler_id)) - result = conn.execute(query) - - schedules = await self._deserialize_schedules(result) + async for attempt in self._retrying: + with attempt: + async with self.engine.begin() as conn: + now = datetime.now(timezone.utc) + acquired_until = now + timedelta(seconds=self.lock_expiration_delay) + schedules_cte = select(self.t_schedules.c.id).\ + where(and_(self.t_schedules.c.next_fire_time.isnot(None), + self.t_schedules.c.next_fire_time <= now, + or_(self.t_schedules.c.acquired_until.is_(None), + self.t_schedules.c.acquired_until < now))).\ + order_by(self.t_schedules.c.next_fire_time).\ + limit(limit).with_for_update(skip_locked=True).cte() + subselect = select([schedules_cte.c.id]) + update = self.t_schedules.update().\ + where(self.t_schedules.c.id.in_(subselect)).\ + values(acquired_by=scheduler_id, acquired_until=acquired_until) + if self._supports_update_returning: + update = update.returning(*self.t_schedules.columns) + result = await conn.execute(update) + else: + await conn.execute(update) + query = self.t_schedules.select().\ + where(and_(self.t_schedules.c.acquired_by == scheduler_id)) + result = conn.execute(query) + + schedules = await self._deserialize_schedules(result) return schedules async def release_schedules(self, scheduler_id: str, schedules: list[Schedule]) -> None: - async with self.engine.begin() as conn: - update_events: list[ScheduleUpdated] = [] - finished_schedule_ids: list[str] = [] - update_args: list[dict[str, Any]] = [] - for schedule in schedules: - if schedule.next_fire_time is not None: - try: - serialized_trigger = self.serializer.serialize(schedule.trigger) - except SerializationError: - self._logger.exception('Error serializing trigger for schedule %r – ' - 'removing from data store', schedule.id) - finished_schedule_ids.append(schedule.id) - continue - - update_args.append({ - 'p_id': schedule.id, - 'p_trigger': serialized_trigger, - 'p_next_fire_time': schedule.next_fire_time - }) - else: - finished_schedule_ids.append(schedule.id) - - # Update schedules that have a next fire time - if update_args: - p_id: BindParameter = bindparam('p_id') - p_trigger: BindParameter = bindparam('p_trigger') - p_next_fire_time: BindParameter = bindparam('p_next_fire_time') - update = self.t_schedules.update().\ - where(and_(self.t_schedules.c.id == p_id, - self.t_schedules.c.acquired_by == scheduler_id)).\ - values(trigger=p_trigger, next_fire_time=p_next_fire_time, - acquired_by=None, acquired_until=None) - next_fire_times = {arg['p_id']: arg['p_next_fire_time'] for arg in update_args} - if self._supports_update_returning: - update = update.returning(self.t_schedules.c.id) - updated_ids = [row[0] for row in await conn.execute(update, update_args)] - else: - # TODO: actually check which rows were updated? - await conn.execute(update, update_args) - updated_ids = list(next_fire_times) - - for schedule_id in updated_ids: - event = ScheduleUpdated(schedule_id=schedule_id, - next_fire_time=next_fire_times[schedule_id]) - update_events.append(event) - - # Remove schedules that have no next fire time or failed to serialize - if finished_schedule_ids: - delete = self.t_schedules.delete().\ - where(self.t_schedules.c.id.in_(finished_schedule_ids)) - await conn.execute(delete) + async for attempt in self._retrying: + with attempt: + async with self.engine.begin() as conn: + update_events: list[ScheduleUpdated] = [] + finished_schedule_ids: list[str] = [] + update_args: list[dict[str, Any]] = [] + for schedule in schedules: + if schedule.next_fire_time is not None: + try: + serialized_trigger = self.serializer.serialize(schedule.trigger) + except SerializationError: + self._logger.exception( + 'Error serializing trigger for schedule %r – ' + 'removing from data store', schedule.id) + finished_schedule_ids.append(schedule.id) + continue + + update_args.append({ + 'p_id': schedule.id, + 'p_trigger': serialized_trigger, + 'p_next_fire_time': schedule.next_fire_time + }) + else: + finished_schedule_ids.append(schedule.id) + + # Update schedules that have a next fire time + if update_args: + p_id: BindParameter = bindparam('p_id') + p_trigger: BindParameter = bindparam('p_trigger') + p_next_fire_time: BindParameter = bindparam('p_next_fire_time') + update = self.t_schedules.update().\ + where(and_(self.t_schedules.c.id == p_id, + self.t_schedules.c.acquired_by == scheduler_id)).\ + values(trigger=p_trigger, next_fire_time=p_next_fire_time, + acquired_by=None, acquired_until=None) + next_fire_times = {arg['p_id']: arg['p_next_fire_time'] + for arg in update_args} + if self._supports_update_returning: + update = update.returning(self.t_schedules.c.id) + updated_ids = [ + row[0] for row in await conn.execute(update, update_args)] + else: + # TODO: actually check which rows were updated? + await conn.execute(update, update_args) + updated_ids = list(next_fire_times) + + for schedule_id in updated_ids: + event = ScheduleUpdated(schedule_id=schedule_id, + next_fire_time=next_fire_times[schedule_id]) + update_events.append(event) + + # Remove schedules that have no next fire time or failed to serialize + if finished_schedule_ids: + delete = self.t_schedules.delete().\ + where(self.t_schedules.c.id.in_(finished_schedule_ids)) + await conn.execute(delete) for event in update_events: await self._events.publish(event) @@ -285,15 +326,19 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore): where(self.t_schedules.c.next_fire_time.isnot(None)).\ order_by(self.t_schedules.c.next_fire_time).\ limit(1) - async with self.engine.begin() as conn: - result = await conn.execute(statenent) - return result.scalar() + async for attempt in self._retrying: + with attempt: + async with self.engine.begin() as conn: + result = await conn.execute(statenent) + return result.scalar() async def add_job(self, job: Job) -> None: marshalled = job.marshal(self.serializer) insert = self.t_jobs.insert().values(**marshalled) - async with self.engine.begin() as conn: - await conn.execute(insert) + async for attempt in self._retrying: + with attempt: + async with self.engine.begin() as conn: + await conn.execute(insert) event = JobAdded(job_id=job.id, task_id=job.task_id, schedule_id=job.schedule_id, tags=job.tags) @@ -305,103 +350,112 @@ class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore): job_ids = [job_id for job_id in ids] query = query.where(self.t_jobs.c.id.in_(job_ids)) - async with self.engine.begin() as conn: - result = await conn.execute(query) - return await self._deserialize_jobs(result) + async for attempt in self._retrying: + with attempt: + async with self.engine.begin() as conn: + result = await conn.execute(query) + return await self._deserialize_jobs(result) async def acquire_jobs(self, worker_id: str, limit: Optional[int] = None) -> list[Job]: - async with self.engine.begin() as conn: - now = datetime.now(timezone.utc) - acquired_until = now + timedelta(seconds=self.lock_expiration_delay) - query = self.t_jobs.select().\ - join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id).\ - where(or_(self.t_jobs.c.acquired_until.is_(None), - self.t_jobs.c.acquired_until < now)).\ - order_by(self.t_jobs.c.created_at).\ - with_for_update(skip_locked=True).\ - limit(limit) - - result = await conn.execute(query) - if not result: - return [] - - # Mark the jobs as acquired by this worker - jobs = await self._deserialize_jobs(result) - task_ids: set[str] = {job.task_id for job in jobs} - - # Retrieve the limits - query = select([self.t_tasks.c.id, - self.t_tasks.c.max_running_jobs - self.t_tasks.c.running_jobs]).\ - where(self.t_tasks.c.max_running_jobs.isnot(None), - self.t_tasks.c.id.in_(task_ids)) - result = await conn.execute(query) - job_slots_left: dict[str, int] = dict(result.fetchall()) - - # Filter out jobs that don't have free slots - acquired_jobs: list[Job] = [] - increments: dict[str, int] = defaultdict(lambda: 0) - for job in jobs: - # Don't acquire the job if there are no free slots left - slots_left = job_slots_left.get(job.task_id) - if slots_left == 0: - continue - elif slots_left is not None: - job_slots_left[job.task_id] -= 1 - - acquired_jobs.append(job) - increments[job.task_id] += 1 - - if acquired_jobs: - # Mark the acquired jobs as acquired by this worker - acquired_job_ids = [job.id for job in acquired_jobs] - update = self.t_jobs.update().\ - values(acquired_by=worker_id, acquired_until=acquired_until).\ - where(self.t_jobs.c.id.in_(acquired_job_ids)) - await conn.execute(update) - - # Increment the running job counters on each task - p_id: BindParameter = bindparam('p_id') - p_increment: BindParameter = bindparam('p_increment') - params = [{'p_id': task_id, 'p_increment': increment} - for task_id, increment in increments.items()] - update = self.t_tasks.update().\ - values(running_jobs=self.t_tasks.c.running_jobs + p_increment).\ - where(self.t_tasks.c.id == p_id) - await conn.execute(update, params) - - # Publish the appropriate events - for job in acquired_jobs: - await self._events.publish(JobAcquired(job_id=job.id, worker_id=worker_id)) - - return acquired_jobs + async for attempt in self._retrying: + with attempt: + async with self.engine.begin() as conn: + now = datetime.now(timezone.utc) + acquired_until = now + timedelta(seconds=self.lock_expiration_delay) + query = self.t_jobs.select().\ + join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id).\ + where(or_(self.t_jobs.c.acquired_until.is_(None), + self.t_jobs.c.acquired_until < now)).\ + order_by(self.t_jobs.c.created_at).\ + with_for_update(skip_locked=True).\ + limit(limit) + + result = await conn.execute(query) + if not result: + return [] + + # Mark the jobs as acquired by this worker + jobs = await self._deserialize_jobs(result) + task_ids: set[str] = {job.task_id for job in jobs} + + # Retrieve the limits + query = select([ + self.t_tasks.c.id, + self.t_tasks.c.max_running_jobs - self.t_tasks.c.running_jobs]).\ + where(self.t_tasks.c.max_running_jobs.isnot(None), + self.t_tasks.c.id.in_(task_ids)) + result = await conn.execute(query) + job_slots_left: dict[str, int] = dict(result.fetchall()) + + # Filter out jobs that don't have free slots + acquired_jobs: list[Job] = [] + increments: dict[str, int] = defaultdict(lambda: 0) + for job in jobs: + # Don't acquire the job if there are no free slots left + slots_left = job_slots_left.get(job.task_id) + if slots_left == 0: + continue + elif slots_left is not None: + job_slots_left[job.task_id] -= 1 + + acquired_jobs.append(job) + increments[job.task_id] += 1 + + if acquired_jobs: + # Mark the acquired jobs as acquired by this worker + acquired_job_ids = [job.id for job in acquired_jobs] + update = self.t_jobs.update().\ + values(acquired_by=worker_id, acquired_until=acquired_until).\ + where(self.t_jobs.c.id.in_(acquired_job_ids)) + await conn.execute(update) + + # Increment the running job counters on each task + p_id: BindParameter = bindparam('p_id') + p_increment: BindParameter = bindparam('p_increment') + params = [{'p_id': task_id, 'p_increment': increment} + for task_id, increment in increments.items()] + update = self.t_tasks.update().\ + values(running_jobs=self.t_tasks.c.running_jobs + p_increment).\ + where(self.t_tasks.c.id == p_id) + await conn.execute(update, params) + + # Publish the appropriate events + for job in acquired_jobs: + await self._events.publish(JobAcquired(job_id=job.id, worker_id=worker_id)) + + return acquired_jobs async def release_job(self, worker_id: str, task_id: str, result: JobResult) -> None: - async with self.engine.begin() as conn: - # Insert the job result - marshalled = result.marshal(self.serializer) - insert = self.t_job_results.insert().values(**marshalled) - await conn.execute(insert) - - # Decrement the running jobs counter - update = self.t_tasks.update().\ - values(running_jobs=self.t_tasks.c.running_jobs - 1).\ - where(self.t_tasks.c.id == task_id) - await conn.execute(update) + async for attempt in self._retrying: + with attempt: + async with self.engine.begin() as conn: + # Insert the job result + marshalled = result.marshal(self.serializer) + insert = self.t_job_results.insert().values(**marshalled) + await conn.execute(insert) + + # Decrement the running jobs counter + update = self.t_tasks.update().\ + values(running_jobs=self.t_tasks.c.running_jobs - 1).\ + where(self.t_tasks.c.id == task_id) + await conn.execute(update) - # Delete the job - delete = self.t_jobs.delete().where(self.t_jobs.c.id == result.job_id) - await conn.execute(delete) + # Delete the job + delete = self.t_jobs.delete().where(self.t_jobs.c.id == result.job_id) + await conn.execute(delete) async def get_job_result(self, job_id: UUID) -> Optional[JobResult]: - async with self.engine.begin() as conn: - # Retrieve the result - query = self.t_job_results.select().\ - where(self.t_job_results.c.job_id == job_id) - row = (await conn.execute(query)).fetchone() - - # Delete the result - delete = self.t_job_results.delete().\ - where(self.t_job_results.c.job_id == job_id) - await conn.execute(delete) - - return JobResult.unmarshal(self.serializer, row._asdict()) if row else None + async for attempt in self._retrying: + with attempt: + async with self.engine.begin() as conn: + # Retrieve the result + query = self.t_job_results.select().\ + where(self.t_job_results.c.job_id == job_id) + row = (await conn.execute(query)).fetchone() + + # Delete the result + delete = self.t_job_results.delete().\ + where(self.t_job_results.c.job_id == job_id) + await conn.execute(delete) + + return JobResult.unmarshal(self.serializer, row._asdict()) if row else None diff --git a/src/apscheduler/datastores/mongodb.py b/src/apscheduler/datastores/mongodb.py index cef0f7e..4e514e2 100644 --- a/src/apscheduler/datastores/mongodb.py +++ b/src/apscheduler/datastores/mongodb.py @@ -10,12 +10,14 @@ from uuid import UUID import attr import pymongo +import tenacity from attr.validators import instance_of from bson import CodecOptions from bson.codec_options import TypeEncoder, TypeRegistry from pymongo import ASCENDING, DeleteOne, MongoClient, UpdateOne from pymongo.collection import Collection -from pymongo.errors import DuplicateKeyError +from pymongo.errors import ConnectionFailure, DuplicateKeyError +from tenacity import Retrying from ..abc import DataStore, EventBroker, EventSource, Job, Schedule, Serializer from ..enums import CoalescePolicy, ConflictPolicy, JobOutcome @@ -26,7 +28,7 @@ from ..events import ( from ..exceptions import ( ConflictingIdError, DeserializationError, SerializationError, TaskLookupError) from ..serializers.pickle import PickleSerializer -from ..structures import JobResult, Task +from ..structures import JobResult, RetrySettings, Task from ..util import reentrant @@ -50,6 +52,7 @@ class MongoDBDataStore(DataStore): serializer: Serializer = attr.field(factory=PickleSerializer, kw_only=True) database: str = attr.field(default='apscheduler', kw_only=True) lock_expiration_delay: float = attr.field(default=30, kw_only=True) + retry_settings: RetrySettings = attr.field(default=RetrySettings()) start_from_scratch: bool = attr.field(default=False, kw_only=True) _task_attrs: ClassVar[list[str]] = [field.name for field in attr.fields(Task)] @@ -57,11 +60,17 @@ class MongoDBDataStore(DataStore): _job_attrs: ClassVar[list[str]] = [field.name for field in attr.fields(Job)] _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__)) + _retrying: Retrying = attr.field(init=False) _exit_stack: ExitStack = attr.field(init=False, factory=ExitStack) _events: EventBroker = attr.field(init=False, factory=LocalEventBroker) _local_tasks: dict[str, Task] = attr.field(init=False, factory=dict) def __attrs_post_init__(self) -> None: + # Construct the Tenacity retry controller + self._retrying = Retrying(stop=self.retry_settings.stop, wait=self.retry_settings.wait, + retry=tenacity.retry_if_exception_type(ConnectionFailure), + after=self._after_attempt, reraise=True) + type_registry = TypeRegistry([ CustomEncoder(timedelta, timedelta.total_seconds), CustomEncoder(ConflictPolicy, operator.attrgetter('name')), @@ -84,6 +93,10 @@ class MongoDBDataStore(DataStore): def events(self) -> EventSource: return self._events + def _after_attempt(self, retry_state: tenacity.RetryCallState) -> None: + self._logger.warning('Temporary data store error (attempt %d): %s', + retry_state.attempt_number, retry_state.outcome.exception()) + def __enter__(self): server_info = self.client.server_info() if server_info['versionArray'] < [4, 0]: @@ -93,29 +106,35 @@ class MongoDBDataStore(DataStore): self._exit_stack.__enter__() self._exit_stack.enter_context(self._events) - if self.start_from_scratch: - self._tasks.delete_many({}) - self._schedules.delete_many({}) - self._jobs.delete_many({}) - self._jobs_results.delete_many({}) - - self._schedules.create_index('next_fire_time') - self._jobs.create_index('task_id') - self._jobs.create_index('created_at') - self._jobs.create_index('tags') - self._jobs_results.create_index('finished_at') + for attempt in self._retrying: + with attempt, self.client.start_session() as session: + if self.start_from_scratch: + self._tasks.delete_many({}, session=session) + self._schedules.delete_many({}, session=session) + self._jobs.delete_many({}, session=session) + self._jobs_results.delete_many({}, session=session) + + self._schedules.create_index('next_fire_time', session=session) + self._jobs.create_index('task_id', session=session) + self._jobs.create_index('created_at', session=session) + self._jobs.create_index('tags', session=session) + self._jobs_results.create_index('finished_at', session=session) + return self def __exit__(self, exc_type, exc_val, exc_tb): self._exit_stack.__exit__(exc_type, exc_val, exc_tb) def add_task(self, task: Task) -> None: - previous = self._tasks.find_one_and_update( - {'_id': task.id}, - {'$set': task.marshal(self.serializer), - '$setOnInsert': {'running_jobs': 0}}, - upsert=True - ) + for attempt in self._retrying: + with attempt: + previous = self._tasks.find_one_and_update( + {'_id': task.id}, + {'$set': task.marshal(self.serializer), + '$setOnInsert': {'running_jobs': 0}}, + upsert=True + ) + self._local_tasks[task.id] = task if previous: self._events.publish(TaskUpdated(task_id=task.id)) @@ -123,8 +142,10 @@ class MongoDBDataStore(DataStore): self._events.publish(TaskAdded(task_id=task.id)) def remove_task(self, task_id: str) -> None: - if not self._tasks.find_one_and_delete({'_id': task_id}): - raise TaskLookupError(task_id) + for attempt in self._retrying: + with attempt: + if not self._tasks.find_one_and_delete({'_id': task_id}): + raise TaskLookupError(task_id) del self._local_tasks[task_id] self._events.publish(TaskRemoved(task_id=task_id)) @@ -133,7 +154,10 @@ class MongoDBDataStore(DataStore): try: return self._local_tasks[task_id] except KeyError: - document = self._tasks.find_one({'_id': task_id}, projection=self._task_attrs) + for attempt in self._retrying: + with attempt: + document = self._tasks.find_one({'_id': task_id}, projection=self._task_attrs) + if not document: raise TaskLookupError(task_id) @@ -142,27 +166,31 @@ class MongoDBDataStore(DataStore): return task def get_tasks(self) -> list[Task]: - tasks: list[Task] = [] - for document in self._tasks.find(projection=self._task_attrs, - sort=[('_id', pymongo.ASCENDING)]): - document['id'] = document.pop('_id') - tasks.append(Task.unmarshal(self.serializer, document)) + for attempt in self._retrying: + with attempt: + tasks: list[Task] = [] + for document in self._tasks.find(projection=self._task_attrs, + sort=[('_id', pymongo.ASCENDING)]): + document['id'] = document.pop('_id') + tasks.append(Task.unmarshal(self.serializer, document)) return tasks def get_schedules(self, ids: Optional[set[str]] = None) -> list[Schedule]: - schedules: list[Schedule] = [] filters = {'_id': {'$in': list(ids)}} if ids is not None else {} - cursor = self._schedules.find(filters).sort('_id') - for document in cursor: - document['id'] = document.pop('_id') - try: - schedule = Schedule.unmarshal(self.serializer, document) - except DeserializationError: - self._logger.warning('Failed to deserialize schedule %r', document['_id']) - continue + for attempt in self._retrying: + with attempt: + schedules: list[Schedule] = [] + cursor = self._schedules.find(filters).sort('_id') + for document in cursor: + document['id'] = document.pop('_id') + try: + schedule = Schedule.unmarshal(self.serializer, document) + except DeserializationError: + self._logger.warning('Failed to deserialize schedule %r', document['_id']) + continue - schedules.append(schedule) + schedules.append(schedule) return schedules @@ -171,12 +199,17 @@ class MongoDBDataStore(DataStore): document = schedule.marshal(self.serializer) document['_id'] = document.pop('id') try: - self._schedules.insert_one(document) + for attempt in self._retrying: + with attempt: + self._schedules.insert_one(document) except DuplicateKeyError: if conflict_policy is ConflictPolicy.exception: raise ConflictingIdError(schedule.id) from None elif conflict_policy is ConflictPolicy.replace: - self._schedules.replace_one({'_id': schedule.id}, document, True) + for attempt in self._retrying: + with attempt: + self._schedules.replace_one({'_id': schedule.id}, document, True) + event = ScheduleUpdated( schedule_id=schedule.id, next_fire_time=schedule.next_fire_time) @@ -187,88 +220,98 @@ class MongoDBDataStore(DataStore): self._events.publish(event) def remove_schedules(self, ids: Iterable[str]) -> None: - with self.client.start_session() as s, s.start_transaction(): - filters = {'_id': {'$in': list(ids)}} if ids is not None else {} - cursor = self._schedules.find(filters, projection=['_id']) - ids = [doc['_id'] for doc in cursor] - if ids: - self._schedules.delete_many(filters) + filters = {'_id': {'$in': list(ids)}} if ids is not None else {} + for attempt in self._retrying: + with attempt, self.client.start_session() as session: + cursor = self._schedules.find(filters, projection=['_id'], session=session) + ids = [doc['_id'] for doc in cursor] + if ids: + self._schedules.delete_many(filters, session=session) for schedule_id in ids: self._events.publish(ScheduleRemoved(schedule_id=schedule_id)) def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]: - schedules: list[Schedule] = [] - with self.client.start_session() as s, s.start_transaction(): - cursor = self._schedules.find( - {'next_fire_time': {'$ne': None}, - '$or': [{'acquired_until': {'$exists': False}}, - {'acquired_until': {'$lt': datetime.now(timezone.utc)}}] - } - ).sort('next_fire_time').limit(limit) - for document in cursor: - document['id'] = document.pop('_id') - schedule = Schedule.unmarshal(self.serializer, document) - schedules.append(schedule) - - if schedules: - now = datetime.now(timezone.utc) - acquired_until = datetime.fromtimestamp( - now.timestamp() + self.lock_expiration_delay, now.tzinfo) - filters = {'_id': {'$in': [schedule.id for schedule in schedules]}} - update = {'$set': {'acquired_by': scheduler_id, - 'acquired_until': acquired_until}} - self._schedules.update_many(filters, update) + for attempt in self._retrying: + with attempt, self.client.start_session() as session: + schedules: list[Schedule] = [] + cursor = self._schedules.find( + {'next_fire_time': {'$ne': None}, + '$or': [{'acquired_until': {'$exists': False}}, + {'acquired_until': {'$lt': datetime.now(timezone.utc)}}] + }, + session=session + ).sort('next_fire_time').limit(limit) + for document in cursor: + document['id'] = document.pop('_id') + schedule = Schedule.unmarshal(self.serializer, document) + schedules.append(schedule) + + if schedules: + now = datetime.now(timezone.utc) + acquired_until = datetime.fromtimestamp( + now.timestamp() + self.lock_expiration_delay, now.tzinfo) + filters = {'_id': {'$in': [schedule.id for schedule in schedules]}} + update = {'$set': {'acquired_by': scheduler_id, + 'acquired_until': acquired_until}} + self._schedules.update_many(filters, update, session=session) return schedules def release_schedules(self, scheduler_id: str, schedules: list[Schedule]) -> None: updated_schedules: list[tuple[str, datetime]] = [] finished_schedule_ids: list[str] = [] - with self.client.start_session() as s, s.start_transaction(): - # Update schedules that have a next fire time - requests = [] - for schedule in schedules: - filters = {'_id': schedule.id, 'acquired_by': scheduler_id} - if schedule.next_fire_time is not None: - try: - serialized_trigger = self.serializer.serialize(schedule.trigger) - except SerializationError: - self._logger.exception('Error serializing schedule %r – ' - 'removing from data store', schedule.id) - requests.append(DeleteOne(filters)) - finished_schedule_ids.append(schedule.id) - continue - update = { - '$unset': { - 'acquired_by': True, - 'acquired_until': True, - }, - '$set': { - 'trigger': serialized_trigger, - 'next_fire_time': schedule.next_fire_time - } - } - requests.append(UpdateOne(filters, update)) - updated_schedules.append((schedule.id, schedule.next_fire_time)) - else: + # Update schedules that have a next fire time + requests = [] + for schedule in schedules: + filters = {'_id': schedule.id, 'acquired_by': scheduler_id} + if schedule.next_fire_time is not None: + try: + serialized_trigger = self.serializer.serialize(schedule.trigger) + except SerializationError: + self._logger.exception('Error serializing schedule %r – ' + 'removing from data store', schedule.id) requests.append(DeleteOne(filters)) finished_schedule_ids.append(schedule.id) + continue + + update = { + '$unset': { + 'acquired_by': True, + 'acquired_until': True, + }, + '$set': { + 'trigger': serialized_trigger, + 'next_fire_time': schedule.next_fire_time + } + } + requests.append(UpdateOne(filters, update)) + updated_schedules.append((schedule.id, schedule.next_fire_time)) + else: + requests.append(DeleteOne(filters)) + finished_schedule_ids.append(schedule.id) if requests: - self._schedules.bulk_write(requests, ordered=False) - for schedule_id, next_fire_time in updated_schedules: - event = ScheduleUpdated(schedule_id=schedule_id, next_fire_time=next_fire_time) - self._events.publish(event) + for attempt in self._retrying: + with attempt, self.client.start_session() as session: + self._schedules.bulk_write(requests, ordered=False, session=session) + + for schedule_id, next_fire_time in updated_schedules: + event = ScheduleUpdated(schedule_id=schedule_id, + next_fire_time=next_fire_time) + self._events.publish(event) for schedule_id in finished_schedule_ids: self._events.publish(ScheduleRemoved(schedule_id=schedule_id)) def get_next_schedule_run_time(self) -> Optional[datetime]: - document = self._schedules.find_one({'next_run_time': {'$ne': None}}, - projection=['next_run_time'], - sort=[('next_run_time', ASCENDING)]) + for attempt in self._retrying: + with attempt: + document = self._schedules.find_one({'next_run_time': {'$ne': None}}, + projection=['next_run_time'], + sort=[('next_run_time', ASCENDING)]) + if document: return document['next_run_time'] else: @@ -277,105 +320,112 @@ class MongoDBDataStore(DataStore): def add_job(self, job: Job) -> None: document = job.marshal(self.serializer) document['_id'] = document.pop('id') - self._jobs.insert_one(document) + for attempt in self._retrying: + with attempt: + self._jobs.insert_one(document) + event = JobAdded(job_id=job.id, task_id=job.task_id, schedule_id=job.schedule_id, tags=job.tags) self._events.publish(event) def get_jobs(self, ids: Optional[Iterable[UUID]] = None) -> list[Job]: - jobs: list[Job] = [] filters = {'_id': {'$in': list(ids)}} if ids is not None else {} - cursor = self._jobs.find(filters).sort('_id') - for document in cursor: - document['id'] = document.pop('_id') - try: - job = Job.unmarshal(self.serializer, document) - except DeserializationError: - self._logger.warning('Failed to deserialize job %r', document['id']) - continue + for attempt in self._retrying: + with attempt: + jobs: list[Job] = [] + cursor = self._jobs.find(filters).sort('_id') + for document in cursor: + document['id'] = document.pop('_id') + try: + job = Job.unmarshal(self.serializer, document) + except DeserializationError: + self._logger.warning('Failed to deserialize job %r', document['id']) + continue - jobs.append(job) + jobs.append(job) return jobs def acquire_jobs(self, worker_id: str, limit: Optional[int] = None) -> list[Job]: - with self.client.start_session() as session: - cursor = self._jobs.find( - {'$or': [{'acquired_until': {'$exists': False}}, - {'acquired_until': {'$lt': datetime.now(timezone.utc)}}] - }, - sort=[('created_at', ASCENDING)], - limit=limit, - session=session - ) - documents = list(cursor) - - # Retrieve the limits - task_ids: set[str] = {document['task_id'] for document in documents} - task_limits = self._tasks.find( - {'_id': {'$in': list(task_ids)}, 'max_running_jobs': {'$ne': None}}, - projection=['max_running_jobs', 'running_jobs'], - session=session - ) - job_slots_left = {doc['_id']: doc['max_running_jobs'] - doc['running_jobs'] - for doc in task_limits} - - # Filter out jobs that don't have free slots - acquired_jobs: list[Job] = [] - increments: dict[str, int] = defaultdict(lambda: 0) - for document in documents: - document['id'] = document.pop('_id') - job = Job.unmarshal(self.serializer, document) - - # Don't acquire the job if there are no free slots left - slots_left = job_slots_left.get(job.task_id) - if slots_left == 0: - continue - elif slots_left is not None: - job_slots_left[job.task_id] -= 1 - - acquired_jobs.append(job) - increments[job.task_id] += 1 - - if acquired_jobs: - now = datetime.now(timezone.utc) - acquired_until = datetime.fromtimestamp( - now.timestamp() + self.lock_expiration_delay, timezone.utc) - filters = {'_id': {'$in': [job.id for job in acquired_jobs]}} - update = {'$set': {'acquired_by': worker_id, - 'acquired_until': acquired_until}} - self._jobs.update_many(filters, update, session=session) - - # Increment the running job counters on each task - for task_id, increment in increments.items(): - self._tasks.find_one_and_update( - {'_id': task_id}, - {'$inc': {'running_jobs': increment}}, - session=session - ) - - # Publish the appropriate events - for job in acquired_jobs: - self._events.publish(JobAcquired(job_id=job.id, worker_id=worker_id)) - - return acquired_jobs + for attempt in self._retrying: + with attempt, self.client.start_session() as session: + cursor = self._jobs.find( + {'$or': [{'acquired_until': {'$exists': False}}, + {'acquired_until': {'$lt': datetime.now(timezone.utc)}}] + }, + sort=[('created_at', ASCENDING)], + limit=limit, + session=session + ) + documents = list(cursor) + + # Retrieve the limits + task_ids: set[str] = {document['task_id'] for document in documents} + task_limits = self._tasks.find( + {'_id': {'$in': list(task_ids)}, 'max_running_jobs': {'$ne': None}}, + projection=['max_running_jobs', 'running_jobs'], + session=session + ) + job_slots_left = {doc['_id']: doc['max_running_jobs'] - doc['running_jobs'] + for doc in task_limits} + + # Filter out jobs that don't have free slots + acquired_jobs: list[Job] = [] + increments: dict[str, int] = defaultdict(lambda: 0) + for document in documents: + document['id'] = document.pop('_id') + job = Job.unmarshal(self.serializer, document) + + # Don't acquire the job if there are no free slots left + slots_left = job_slots_left.get(job.task_id) + if slots_left == 0: + continue + elif slots_left is not None: + job_slots_left[job.task_id] -= 1 + + acquired_jobs.append(job) + increments[job.task_id] += 1 + + if acquired_jobs: + now = datetime.now(timezone.utc) + acquired_until = datetime.fromtimestamp( + now.timestamp() + self.lock_expiration_delay, timezone.utc) + filters = {'_id': {'$in': [job.id for job in acquired_jobs]}} + update = {'$set': {'acquired_by': worker_id, + 'acquired_until': acquired_until}} + self._jobs.update_many(filters, update, session=session) + + # Increment the running job counters on each task + for task_id, increment in increments.items(): + self._tasks.find_one_and_update( + {'_id': task_id}, + {'$inc': {'running_jobs': increment}}, + session=session + ) + + # Publish the appropriate events + for job in acquired_jobs: + self._events.publish(JobAcquired(job_id=job.id, worker_id=worker_id)) + + return acquired_jobs def release_job(self, worker_id: str, task_id: str, result: JobResult) -> None: - with self.client.start_session() as session: - # Insert the job result - document = result.marshal(self.serializer) - document['_id'] = document.pop('job_id') - self._jobs_results.insert_one(document, session=session) - - # Decrement the running jobs counter - self._tasks.find_one_and_update( - {'_id': task_id}, - {'$inc': {'running_jobs': -1}}, - session=session - ) - - # Delete the job - self._jobs.delete_one({'_id': result.job_id}, session=session) + for attempt in self._retrying: + with attempt, self.client.start_session() as session: + # Insert the job result + document = result.marshal(self.serializer) + document['_id'] = document.pop('job_id') + self._jobs_results.insert_one(document, session=session) + + # Decrement the running jobs counter + self._tasks.find_one_and_update( + {'_id': task_id}, + {'$inc': {'running_jobs': -1}}, + session=session + ) + + # Delete the job + self._jobs.delete_one({'_id': result.job_id}, session=session) # Publish the event self._events.publish( @@ -383,7 +433,10 @@ class MongoDBDataStore(DataStore): ) def get_job_result(self, job_id: UUID) -> Optional[JobResult]: - document = self._jobs_results.find_one_and_delete({'_id': job_id}) + for attempt in self._retrying: + with attempt: + document = self._jobs_results.find_one_and_delete({'_id': job_id}) + if document: document['job_id'] = document.pop('_id') return JobResult.unmarshal(self.serializer, document) diff --git a/src/apscheduler/datastores/sqlalchemy.py b/src/apscheduler/datastores/sqlalchemy.py index 06dfe9e..a52d6c1 100644 --- a/src/apscheduler/datastores/sqlalchemy.py +++ b/src/apscheduler/datastores/sqlalchemy.py @@ -7,11 +7,12 @@ from typing import Any, Iterable, Optional from uuid import UUID import attr +import tenacity from sqlalchemy import ( JSON, TIMESTAMP, BigInteger, Column, Enum, Integer, LargeBinary, MetaData, Table, TypeDecorator, Unicode, and_, bindparam, or_, select) from sqlalchemy.engine import URL, Dialect, Result -from sqlalchemy.exc import CompileError, IntegrityError +from sqlalchemy.exc import CompileError, IntegrityError, OperationalError from sqlalchemy.future import Engine, create_engine from sqlalchemy.sql.ddl import DropTable from sqlalchemy.sql.elements import BindParameter, literal @@ -26,7 +27,7 @@ from ..events import ( from ..exceptions import ConflictingIdError, SerializationError, TaskLookupError from ..marshalling import callable_to_ref from ..serializers.pickle import PickleSerializer -from ..structures import JobResult, Task +from ..structures import JobResult, RetrySettings, Task from ..util import reentrant @@ -70,6 +71,7 @@ class _BaseSQLAlchemyDataStore: lock_expiration_delay: float = attr.field(default=30) max_poll_time: Optional[float] = attr.field(default=1) max_idle_time: float = attr.field(default=60) + retry_settings: RetrySettings = attr.field(default=RetrySettings()) start_from_scratch: bool = attr.field(default=False) _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__)) @@ -92,6 +94,10 @@ class _BaseSQLAlchemyDataStore: else: self._supports_update_returning = True + def _after_attempt(self, retry_state: tenacity.RetryCallState) -> None: + self._logger.warning('Temporary data store error (attempt %d): %s', + retry_state.attempt_number, retry_state.outcome.exception()) + def get_table_definitions(self) -> MetaData: if self.engine.dialect.name == 'postgresql': from sqlalchemy.dialects import postgresql @@ -194,27 +200,39 @@ class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, DataStore): engine: Engine _events: EventBroker = attr.field(init=False, factory=LocalEventBroker) + _retrying: tenacity.Retrying = attr.field(init=False) @classmethod def from_url(cls, url: str | URL, **options) -> SQLAlchemyDataStore: engine = create_engine(url) return cls(engine, **options) + def __attrs_post_init__(self) -> None: + super().__attrs_post_init__() + + # Construct the Tenacity retry controller + self._retrying = tenacity.Retrying( + stop=self.retry_settings.stop, wait=self.retry_settings.wait, + retry=tenacity.retry_if_exception_type(OperationalError), after=self._after_attempt, + reraise=True) + def __enter__(self): - with self.engine.begin() as conn: - if self.start_from_scratch: - for table in self._metadata.sorted_tables: - conn.execute(DropTable(table, if_exists=True)) - - self._metadata.create_all(conn) - query = select(self.t_metadata.c.schema_version) - result = conn.execute(query) - version = result.scalar() - if version is None: - conn.execute(self.t_metadata.insert(values={'schema_version': 1})) - elif version > 1: - raise RuntimeError(f'Unexpected schema version ({version}); ' - f'only version 1 is supported by this version of APScheduler') + for attempt in self._retrying: + with attempt, self.engine.begin() as conn: + if self.start_from_scratch: + for table in self._metadata.sorted_tables: + conn.execute(DropTable(table, if_exists=True)) + + self._metadata.create_all(conn) + query = select(self.t_metadata.c.schema_version) + result = conn.execute(query) + version = result.scalar() + if version is None: + conn.execute(self.t_metadata.insert(values={'schema_version': 1})) + elif version > 1: + raise RuntimeError( + f'Unexpected schema version ({version}); ' + f'only version 1 is supported by this version of APScheduler') self._events.__enter__() return self @@ -232,35 +250,39 @@ class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, DataStore): max_running_jobs=task.max_running_jobs, misfire_grace_time=task.misfire_grace_time) try: - with self.engine.begin() as conn: - conn.execute(insert) + for attempt in self._retrying: + with attempt, self.engine.begin() as conn: + conn.execute(insert) except IntegrityError: update = self.t_tasks.update().\ values(func=callable_to_ref(task.func), max_running_jobs=task.max_running_jobs, misfire_grace_time=task.misfire_grace_time).\ where(self.t_tasks.c.id == task.id) - with self.engine.begin() as conn: - conn.execute(update) - self._events.publish(TaskUpdated(task_id=task.id)) + for attempt in self._retrying: + with attempt, self.engine.begin() as conn: + conn.execute(update) + self._events.publish(TaskUpdated(task_id=task.id)) else: self._events.publish(TaskAdded(task_id=task.id)) def remove_task(self, task_id: str) -> None: delete = self.t_tasks.delete().where(self.t_tasks.c.id == task_id) - with self.engine.begin() as conn: - result = conn.execute(delete) - if result.rowcount == 0: - raise TaskLookupError(task_id) - else: - self._events.publish(TaskRemoved(task_id=task_id)) + for attempt in self._retrying: + with attempt, self.engine.begin() as conn: + result = conn.execute(delete) + if result.rowcount == 0: + raise TaskLookupError(task_id) + else: + self._events.publish(TaskRemoved(task_id=task_id)) def get_task(self, task_id: str) -> Task: query = select([self.t_tasks.c.id, self.t_tasks.c.func, self.t_tasks.c.max_running_jobs, self.t_tasks.c.state, self.t_tasks.c.misfire_grace_time]).\ where(self.t_tasks.c.id == task_id) - with self.engine.begin() as conn: - result = conn.execute(query) - row = result.fetch_one() + for attempt in self._retrying: + with attempt, self.engine.begin() as conn: + result = conn.execute(query) + row = result.fetch_one() if row: return Task.unmarshal(self.serializer, row._asdict()) @@ -271,18 +293,20 @@ class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, DataStore): query = select([self.t_tasks.c.id, self.t_tasks.c.func, self.t_tasks.c.max_running_jobs, self.t_tasks.c.state, self.t_tasks.c.misfire_grace_time]).\ order_by(self.t_tasks.c.id) - with self.engine.begin() as conn: - result = conn.execute(query) - tasks = [Task.unmarshal(self.serializer, row._asdict()) for row in result] - return tasks + for attempt in self._retrying: + with attempt, self.engine.begin() as conn: + result = conn.execute(query) + tasks = [Task.unmarshal(self.serializer, row._asdict()) for row in result] + return tasks def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None: event: Event values = schedule.marshal(self.serializer) insert = self.t_schedules.insert().values(**values) try: - with self.engine.begin() as conn: - conn.execute(insert) + for attempt in self._retrying: + with attempt, self.engine.begin() as conn: + conn.execute(insert) except IntegrityError: if conflict_policy is ConflictPolicy.exception: raise ConflictingIdError(schedule.id) from None @@ -291,8 +315,9 @@ class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, DataStore): update = self.t_schedules.update().\ where(self.t_schedules.c.id == schedule.id).\ values(**values) - with self.engine.begin() as conn: - conn.execute(update) + for attempt in self._retrying: + with attempt, self.engine.begin() as conn: + conn.execute(update) event = ScheduleUpdated(schedule_id=schedule.id, next_fire_time=schedule.next_fire_time) @@ -303,15 +328,16 @@ class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, DataStore): self._events.publish(event) def remove_schedules(self, ids: Iterable[str]) -> None: - with self.engine.begin() as conn: - delete = self.t_schedules.delete().where(self.t_schedules.c.id.in_(ids)) - if self._supports_update_returning: - delete = delete.returning(self.t_schedules.c.id) - removed_ids: Iterable[str] = [row[0] for row in conn.execute(delete)] - else: - # TODO: actually check which rows were deleted? - conn.execute(delete) - removed_ids = ids + for attempt in self._retrying: + with attempt, self.engine.begin() as conn: + delete = self.t_schedules.delete().where(self.t_schedules.c.id.in_(ids)) + if self._supports_update_returning: + delete = delete.returning(self.t_schedules.c.id) + removed_ids: Iterable[str] = [row[0] for row in conn.execute(delete)] + else: + # TODO: actually check which rows were deleted? + conn.execute(delete) + removed_ids = ids for schedule_id in removed_ids: self._events.publish(ScheduleRemoved(schedule_id=schedule_id)) @@ -321,90 +347,93 @@ class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, DataStore): if ids: query = query.where(self.t_schedules.c.id.in_(ids)) - with self.engine.begin() as conn: - result = conn.execute(query) - return self._deserialize_schedules(result) + for attempt in self._retrying: + with attempt, self.engine.begin() as conn: + result = conn.execute(query) + return self._deserialize_schedules(result) def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedule]: - with self.engine.begin() as conn: - now = datetime.now(timezone.utc) - acquired_until = now + timedelta(seconds=self.lock_expiration_delay) - schedules_cte = select(self.t_schedules.c.id).\ - where(and_(self.t_schedules.c.next_fire_time.isnot(None), - self.t_schedules.c.next_fire_time <= now, - or_(self.t_schedules.c.acquired_until.is_(None), - self.t_schedules.c.acquired_until < now))).\ - order_by(self.t_schedules.c.next_fire_time).\ - limit(limit).with_for_update(skip_locked=True).cte() - subselect = select([schedules_cte.c.id]) - update = self.t_schedules.update().\ - where(self.t_schedules.c.id.in_(subselect)).\ - values(acquired_by=scheduler_id, acquired_until=acquired_until) - if self._supports_update_returning: - update = update.returning(*self.t_schedules.columns) - result = conn.execute(update) - else: - conn.execute(update) - query = self.t_schedules.select().\ - where(and_(self.t_schedules.c.acquired_by == scheduler_id)) - result = conn.execute(query) + for attempt in self._retrying: + with attempt, self.engine.begin() as conn: + now = datetime.now(timezone.utc) + acquired_until = now + timedelta(seconds=self.lock_expiration_delay) + schedules_cte = select(self.t_schedules.c.id).\ + where(and_(self.t_schedules.c.next_fire_time.isnot(None), + self.t_schedules.c.next_fire_time <= now, + or_(self.t_schedules.c.acquired_until.is_(None), + self.t_schedules.c.acquired_until < now))).\ + order_by(self.t_schedules.c.next_fire_time).\ + limit(limit).with_for_update(skip_locked=True).cte() + subselect = select([schedules_cte.c.id]) + update = self.t_schedules.update().\ + where(self.t_schedules.c.id.in_(subselect)).\ + values(acquired_by=scheduler_id, acquired_until=acquired_until) + if self._supports_update_returning: + update = update.returning(*self.t_schedules.columns) + result = conn.execute(update) + else: + conn.execute(update) + query = self.t_schedules.select().\ + where(and_(self.t_schedules.c.acquired_by == scheduler_id)) + result = conn.execute(query) - schedules = self._deserialize_schedules(result) + schedules = self._deserialize_schedules(result) return schedules def release_schedules(self, scheduler_id: str, schedules: list[Schedule]) -> None: - with self.engine.begin() as conn: - update_events: list[ScheduleUpdated] = [] - finished_schedule_ids: list[str] = [] - update_args: list[dict[str, Any]] = [] - for schedule in schedules: - if schedule.next_fire_time is not None: - try: - serialized_trigger = self.serializer.serialize(schedule.trigger) - except SerializationError: - self._logger.exception('Error serializing trigger for schedule %r – ' - 'removing from data store', schedule.id) + for attempt in self._retrying: + with attempt, self.engine.begin() as conn: + update_events: list[ScheduleUpdated] = [] + finished_schedule_ids: list[str] = [] + update_args: list[dict[str, Any]] = [] + for schedule in schedules: + if schedule.next_fire_time is not None: + try: + serialized_trigger = self.serializer.serialize(schedule.trigger) + except SerializationError: + self._logger.exception('Error serializing trigger for schedule %r – ' + 'removing from data store', schedule.id) + finished_schedule_ids.append(schedule.id) + continue + + update_args.append({ + 'p_id': schedule.id, + 'p_trigger': serialized_trigger, + 'p_next_fire_time': schedule.next_fire_time + }) + else: finished_schedule_ids.append(schedule.id) - continue - update_args.append({ - 'p_id': schedule.id, - 'p_trigger': serialized_trigger, - 'p_next_fire_time': schedule.next_fire_time - }) - else: - finished_schedule_ids.append(schedule.id) - - # Update schedules that have a next fire time - if update_args: - p_id: BindParameter = bindparam('p_id') - p_trigger: BindParameter = bindparam('p_trigger') - p_next_fire_time: BindParameter = bindparam('p_next_fire_time') - update = self.t_schedules.update().\ - where(and_(self.t_schedules.c.id == p_id, - self.t_schedules.c.acquired_by == scheduler_id)).\ - values(trigger=p_trigger, next_fire_time=p_next_fire_time, - acquired_by=None, acquired_until=None) - next_fire_times = {arg['p_id']: arg['p_next_fire_time'] for arg in update_args} - if self._supports_update_returning: - update = update.returning(self.t_schedules.c.id) - updated_ids = [row[0] for row in conn.execute(update, update_args)] - else: - # TODO: actually check which rows were updated? - conn.execute(update, update_args) - updated_ids = list(next_fire_times) - - for schedule_id in updated_ids: - event = ScheduleUpdated(schedule_id=schedule_id, - next_fire_time=next_fire_times[schedule_id]) - update_events.append(event) - - # Remove schedules that have no next fire time or failed to serialize - if finished_schedule_ids: - delete = self.t_schedules.delete().\ - where(self.t_schedules.c.id.in_(finished_schedule_ids)) - conn.execute(delete) + # Update schedules that have a next fire time + if update_args: + p_id: BindParameter = bindparam('p_id') + p_trigger: BindParameter = bindparam('p_trigger') + p_next_fire_time: BindParameter = bindparam('p_next_fire_time') + update = self.t_schedules.update().\ + where(and_(self.t_schedules.c.id == p_id, + self.t_schedules.c.acquired_by == scheduler_id)).\ + values(trigger=p_trigger, next_fire_time=p_next_fire_time, + acquired_by=None, acquired_until=None) + next_fire_times = {arg['p_id']: arg['p_next_fire_time'] for arg in update_args} + if self._supports_update_returning: + update = update.returning(self.t_schedules.c.id) + updated_ids = [row[0] for row in conn.execute(update, update_args)] + else: + # TODO: actually check which rows were updated? + conn.execute(update, update_args) + updated_ids = list(next_fire_times) + + for schedule_id in updated_ids: + event = ScheduleUpdated(schedule_id=schedule_id, + next_fire_time=next_fire_times[schedule_id]) + update_events.append(event) + + # Remove schedules that have no next fire time or failed to serialize + if finished_schedule_ids: + delete = self.t_schedules.delete().\ + where(self.t_schedules.c.id.in_(finished_schedule_ids)) + conn.execute(delete) for event in update_events: self._events.publish(event) @@ -417,15 +446,17 @@ class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, DataStore): where(self.t_schedules.c.next_fire_time.isnot(None)).\ order_by(self.t_schedules.c.next_fire_time).\ limit(1) - with self.engine.begin() as conn: - result = conn.execute(query) - return result.scalar() + for attempt in self._retrying: + with attempt, self.engine.begin() as conn: + result = conn.execute(query) + return result.scalar() def add_job(self, job: Job) -> None: marshalled = job.marshal(self.serializer) insert = self.t_jobs.insert().values(**marshalled) - with self.engine.begin() as conn: - conn.execute(insert) + for attempt in self._retrying: + with attempt, self.engine.begin() as conn: + conn.execute(insert) event = JobAdded(job_id=job.id, task_id=job.task_id, schedule_id=job.schedule_id, tags=job.tags) @@ -437,92 +468,95 @@ class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, DataStore): job_ids = [job_id for job_id in ids] query = query.where(self.t_jobs.c.id.in_(job_ids)) - with self.engine.begin() as conn: - result = conn.execute(query) - return self._deserialize_jobs(result) + for attempt in self._retrying: + with attempt, self.engine.begin() as conn: + result = conn.execute(query) + return self._deserialize_jobs(result) def acquire_jobs(self, worker_id: str, limit: Optional[int] = None) -> list[Job]: - with self.engine.begin() as conn: - now = datetime.now(timezone.utc) - acquired_until = now + timedelta(seconds=self.lock_expiration_delay) - query = self.t_jobs.select().\ - join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id).\ - where(or_(self.t_jobs.c.acquired_until.is_(None), - self.t_jobs.c.acquired_until < now)).\ - order_by(self.t_jobs.c.created_at).\ - with_for_update(skip_locked=True).\ - limit(limit) - - result = conn.execute(query) - if not result: - return [] - - # Mark the jobs as acquired by this worker - jobs = self._deserialize_jobs(result) - task_ids: set[str] = {job.task_id for job in jobs} - - # Retrieve the limits - query = select([self.t_tasks.c.id, - self.t_tasks.c.max_running_jobs - self.t_tasks.c.running_jobs]).\ - where(self.t_tasks.c.max_running_jobs.isnot(None), - self.t_tasks.c.id.in_(task_ids)) - result = conn.execute(query) - job_slots_left = dict(result.fetchall()) - - # Filter out jobs that don't have free slots - acquired_jobs: list[Job] = [] - increments: dict[str, int] = defaultdict(lambda: 0) - for job in jobs: - # Don't acquire the job if there are no free slots left - slots_left = job_slots_left.get(job.task_id) - if slots_left == 0: - continue - elif slots_left is not None: - job_slots_left[job.task_id] -= 1 - - acquired_jobs.append(job) - increments[job.task_id] += 1 - - if acquired_jobs: - # Mark the acquired jobs as acquired by this worker - acquired_job_ids = [job.id for job in acquired_jobs] - update = self.t_jobs.update().\ - values(acquired_by=worker_id, acquired_until=acquired_until).\ - where(self.t_jobs.c.id.in_(acquired_job_ids)) - conn.execute(update) + for attempt in self._retrying: + with attempt, self.engine.begin() as conn: + now = datetime.now(timezone.utc) + acquired_until = now + timedelta(seconds=self.lock_expiration_delay) + query = self.t_jobs.select().\ + join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id).\ + where(or_(self.t_jobs.c.acquired_until.is_(None), + self.t_jobs.c.acquired_until < now)).\ + order_by(self.t_jobs.c.created_at).\ + with_for_update(skip_locked=True).\ + limit(limit) - # Increment the running job counters on each task - p_id: BindParameter = bindparam('p_id') - p_increment: BindParameter = bindparam('p_increment') - params = [{'p_id': task_id, 'p_increment': increment} - for task_id, increment in increments.items()] - update = self.t_tasks.update().\ - values(running_jobs=self.t_tasks.c.running_jobs + p_increment).\ - where(self.t_tasks.c.id == p_id) - conn.execute(update, params) + result = conn.execute(query) + if not result: + return [] + + # Mark the jobs as acquired by this worker + jobs = self._deserialize_jobs(result) + task_ids: set[str] = {job.task_id for job in jobs} + + # Retrieve the limits + query = select([self.t_tasks.c.id, + self.t_tasks.c.max_running_jobs - self.t_tasks.c.running_jobs]).\ + where(self.t_tasks.c.max_running_jobs.isnot(None), + self.t_tasks.c.id.in_(task_ids)) + result = conn.execute(query) + job_slots_left = dict(result.fetchall()) + + # Filter out jobs that don't have free slots + acquired_jobs: list[Job] = [] + increments: dict[str, int] = defaultdict(lambda: 0) + for job in jobs: + # Don't acquire the job if there are no free slots left + slots_left = job_slots_left.get(job.task_id) + if slots_left == 0: + continue + elif slots_left is not None: + job_slots_left[job.task_id] -= 1 + + acquired_jobs.append(job) + increments[job.task_id] += 1 + + if acquired_jobs: + # Mark the acquired jobs as acquired by this worker + acquired_job_ids = [job.id for job in acquired_jobs] + update = self.t_jobs.update().\ + values(acquired_by=worker_id, acquired_until=acquired_until).\ + where(self.t_jobs.c.id.in_(acquired_job_ids)) + conn.execute(update) - # Publish the appropriate events - for job in acquired_jobs: - self._events.publish(JobAcquired(job_id=job.id, worker_id=worker_id)) + # Increment the running job counters on each task + p_id: BindParameter = bindparam('p_id') + p_increment: BindParameter = bindparam('p_increment') + params = [{'p_id': task_id, 'p_increment': increment} + for task_id, increment in increments.items()] + update = self.t_tasks.update().\ + values(running_jobs=self.t_tasks.c.running_jobs + p_increment).\ + where(self.t_tasks.c.id == p_id) + conn.execute(update, params) - return acquired_jobs + # Publish the appropriate events + for job in acquired_jobs: + self._events.publish(JobAcquired(job_id=job.id, worker_id=worker_id)) + + return acquired_jobs def release_job(self, worker_id: str, task_id: str, result: JobResult) -> None: - with self.engine.begin() as conn: - # Insert the job result - marshalled = result.marshal(self.serializer) - insert = self.t_job_results.insert().values(**marshalled) - conn.execute(insert) + for attempt in self._retrying: + with attempt, self.engine.begin() as conn: + # Insert the job result + marshalled = result.marshal(self.serializer) + insert = self.t_job_results.insert().values(**marshalled) + conn.execute(insert) - # Decrement the running jobs counter - update = self.t_tasks.update().\ - values(running_jobs=self.t_tasks.c.running_jobs - 1).\ - where(self.t_tasks.c.id == task_id) - conn.execute(update) + # Decrement the running jobs counter + update = self.t_tasks.update().\ + values(running_jobs=self.t_tasks.c.running_jobs - 1).\ + where(self.t_tasks.c.id == task_id) + conn.execute(update) - # Delete the job - delete = self.t_jobs.delete().where(self.t_jobs.c.id == result.job_id) - conn.execute(delete) + # Delete the job + delete = self.t_jobs.delete().where(self.t_jobs.c.id == result.job_id) + conn.execute(delete) # Publish the event self._events.publish( @@ -530,15 +564,16 @@ class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, DataStore): ) def get_job_result(self, job_id: UUID) -> Optional[JobResult]: - with self.engine.begin() as conn: - # Retrieve the result - query = self.t_job_results.select().\ - where(self.t_job_results.c.job_id == job_id) - row = conn.execute(query).fetchone() - - # Delete the result - delete = self.t_job_results.delete().\ - where(self.t_job_results.c.job_id == job_id) - conn.execute(delete) - - return JobResult.unmarshal(self.serializer, row._asdict()) if row else None + for attempt in self._retrying: + with attempt, self.engine.begin() as conn: + # Retrieve the result + query = self.t_job_results.select().\ + where(self.t_job_results.c.job_id == job_id) + row = conn.execute(query).fetchone() + + # Delete the result + delete = self.t_job_results.delete().\ + where(self.t_job_results.c.job_id == job_id) + conn.execute(delete) + + return JobResult.unmarshal(self.serializer, row._asdict()) if row else None |