author:    Alex Grönholm <alex.gronholm@nextday.fi>  2021-09-08 00:06:14 +0300
committer: Alex Grönholm <alex.gronholm@nextday.fi>  2021-09-08 00:06:14 +0300
commit:    48a5b0eea05f21b4cd21e9305e5c4ab755c88a94
tree:      60c93cad87086881b76b1c6dc0b3d25f251d6fc1
parent:    4e2585a6f613905135164d3f6a5c6adf752ba441
download:  apscheduler-48a5b0eea05f21b4cd21e9305e5c4ab755c88a94.tar.gz
Serialize top level attributes into individual fields/columns
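
In short: instead of pickling each whole `Schedule`/`Job`/`JobResult` object into a single `serialized_data` blob, every top-level attribute now maps to its own field or column, and only opaque values (the trigger, args, kwargs, exception, return value) still pass through the serializer. A toy sketch of the `marshal()`/`unmarshal()` pattern follows; `DemoSchedule` and its fields are illustrative stand-ins, not apscheduler's real classes:

```python
# Hypothetical illustration of the commit's storage model, not apscheduler code.
import pickle
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Optional


@dataclass
class DemoSchedule:
    id: str
    task_id: str
    trigger: Any                                # opaque object -> serialized blob column
    next_fire_time: Optional[datetime] = None   # plain value -> native, indexable column

    def marshal(self) -> dict[str, Any]:
        # One dict key per database column; only the opaque attribute is pickled,
        # so the store can filter/sort on id, task_id and next_fire_time directly.
        return {'id': self.id, 'task_id': self.task_id,
                'trigger': pickle.dumps(self.trigger),
                'next_fire_time': self.next_fire_time}

    @classmethod
    def unmarshal(cls, row: dict[str, Any]) -> 'DemoSchedule':
        row = dict(row)
        row['trigger'] = pickle.loads(row['trigger'])
        return cls(**row)


sched = DemoSchedule('nightly', 'cleanup', trigger=('interval', 3600),
                     next_fire_time=datetime.now(timezone.utc))
assert DemoSchedule.unmarshal(sched.marshal()) == sched
```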
 src/apscheduler/abc.py                            |  12
 src/apscheduler/datastores/async_/sqlalchemy.py   | 143
 src/apscheduler/datastores/async_/sync_adapter.py |   4
 src/apscheduler/datastores/sync/memory.py         |  10
 src/apscheduler/datastores/sync/mongodb.py        |  74
 src/apscheduler/datastores/sync/sqlalchemy.py     | 136
 src/apscheduler/structures.py                     |  96
 src/apscheduler/workers/async_.py                 |  12
 src/apscheduler/workers/sync.py                   |   8
 tests/test_datastores.py                          |  27
10 files changed, 301 insertions, 221 deletions
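
One notable addition in the diff below is the `EmulatedTimestampTZ` type decorator, which stores timezone-aware datetimes as ISO 8601 strings on backends without a native timezone-aware timestamp type. The decorator body here is copied from the patch; the surrounding engine and table setup are illustrative only, a minimal standalone sketch against in-memory SQLite:

```python
from datetime import datetime, timezone
from typing import Any

from sqlalchemy import Column, MetaData, Table, TypeDecorator, Unicode, create_engine
from sqlalchemy.engine import Dialect


class EmulatedTimestampTZ(TypeDecorator):
    """As introduced in this commit: tz-aware datetimes round-trip as ISO 8601 text."""
    impl = Unicode(32)
    cache_ok = True

    def process_bind_param(self, value, dialect: Dialect) -> Any:
        return value.isoformat() if value is not None else None

    def process_result_value(self, value: Any, dialect: Dialect):
        return datetime.fromisoformat(value) if value is not None else None


# Illustrative usage: the UTC offset survives a write/read cycle on SQLite.
engine = create_engine('sqlite://')
metadata = MetaData()
events = Table('events', metadata, Column('happened_at', EmulatedTimestampTZ))
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(events.insert().values(happened_at=datetime.now(timezone.utc)))
    stored = conn.execute(events.select()).scalar()
    print(stored.tzinfo)  # timezone info preserved by the ISO string round trip
```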
```diff
diff --git a/src/apscheduler/abc.py b/src/apscheduler/abc.py
index 1038904..8cf8dc1 100644
--- a/src/apscheduler/abc.py
+++ b/src/apscheduler/abc.py
@@ -218,13 +218,13 @@ class DataStore(EventSource):
         """
 
     @abstractmethod
-    def release_job(self, worker_id: str, job: Job, result: Optional[JobResult]) -> None:
+    def release_job(self, worker_id: str, task_id: str, result: JobResult) -> None:
         """
         Release the claim on the given job and record the result.
 
         :param worker_id: unique identifier of the worker
-        :param job: the job to be released
-        :param result: the result of the job (or ``None`` to discard the job)
+        :param task_id: the job's task ID
+        :param result: the result of the job
         """
 
     @abstractmethod
@@ -371,13 +371,13 @@ class AsyncDataStore(EventSource):
         """
 
     @abstractmethod
-    async def release_job(self, worker_id: str, job: Job, result: Optional[JobResult]) -> None:
+    async def release_job(self, worker_id: str, task_id: str, result: JobResult) -> None:
         """
         Release the claim on the given job and record the result.
 
         :param worker_id: unique identifier of the worker
-        :param job: the job to be released
-        :param result: the result of the job (or ``None`` to discard the job)
+        :param task_id: the job's task ID
+        :param result: the result of the job
         """
 
     @abstractmethod
diff --git a/src/apscheduler/datastores/async_/sqlalchemy.py b/src/apscheduler/datastores/async_/sqlalchemy.py
index beaac4f..bb07c0b 100644
--- a/src/apscheduler/datastores/async_/sqlalchemy.py
+++ b/src/apscheduler/datastores/async_/sqlalchemy.py
@@ -6,7 +6,7 @@ from contextlib import AsyncExitStack, closing
 from datetime import datetime, timedelta, timezone
 from json import JSONDecodeError
 from logging import Logger, getLogger
-from typing import Any, Callable, Iterable, Optional, Tuple, Type
+from typing import Any, Callable, Iterable, Optional, Type
 from uuid import UUID
 
 import attr
@@ -14,9 +14,9 @@ import sniffio
 from anyio import TASK_STATUS_IGNORED, create_task_group, sleep
 from attr import asdict
 from sqlalchemy import (
-    TIMESTAMP, Column, Integer, LargeBinary, MetaData, Table, TypeDecorator, Unicode, and_,
-    bindparam, func, or_, select)
-from sqlalchemy.engine import URL, Dialect
+    JSON, TIMESTAMP, Column, Enum, Integer, LargeBinary, MetaData, Table, TypeDecorator, Unicode,
+    and_, bindparam, func, or_, select)
+from sqlalchemy.engine import URL, Dialect, Result
 from sqlalchemy.exc import CompileError, IntegrityError
 from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine
 from sqlalchemy.ext.asyncio.engine import AsyncEngine
@@ -25,7 +25,7 @@ from sqlalchemy.sql.elements import BindParameter, literal
 
 from ... import events as events_module
 from ...abc import AsyncDataStore, Job, Schedule, Serializer
-from ...enums import ConflictPolicy
+from ...enums import CoalescePolicy, ConflictPolicy, JobOutcome
 from ...events import (
     AsyncEventHub, DataStoreEvent, Event, JobAdded, JobDeserializationFailed, ScheduleAdded,
     ScheduleDeserializationFailed, ScheduleRemoved, ScheduleUpdated, SubscriptionToken, TaskAdded,
@@ -71,6 +71,17 @@ class EmulatedUUID(TypeDecorator):
         return UUID(value) if value else None
 
 
+class EmulatedTimestampTZ(TypeDecorator):
+    impl = Unicode(32)
+    cache_ok = True
+
+    def process_bind_param(self, value, dialect: Dialect) -> Any:
+        return value.isoformat() if value is not None else None
+
+    def process_result_value(self, value: Any, dialect: Dialect):
+        return datetime.fromisoformat(value) if value is not None else None
+
+
 @reentrant
 @attr.define(eq=False)
 class SQLAlchemyDataStore(AsyncDataStore):
@@ -96,7 +107,7 @@ class SQLAlchemyDataStore(AsyncDataStore):
         self.t_jobs = self._metadata.tables['jobs']
         self.t_job_results = self._metadata.tables['job_results']
 
-        # Find out if the dialect supports RETURNING
+        # Find out if the dialect supports UPDATE...RETURNING
         update = self.t_jobs.update().returning(self.t_jobs.c.id)
         try:
             update.compile(bind=self.engine)
@@ -149,17 +160,13 @@ class SQLAlchemyDataStore(AsyncDataStore):
         await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
 
     def get_table_definitions(self) -> MetaData:
-        if self.engine.dialect.name in ('mysql', 'mariadb'):
-            from sqlalchemy.dialects import mysql
-            timestamp_type = mysql.TIMESTAMP(fsp=6)
-        else:
-            timestamp_type = TIMESTAMP(timezone=True)
-
         if self.engine.dialect.name == 'postgresql':
             from sqlalchemy.dialects import postgresql
+            timestamp_type = TIMESTAMP(timezone=True)
             job_id_type = postgresql.UUID(as_uuid=True)
         else:
+            timestamp_type = EmulatedTimestampTZ
             job_id_type = EmulatedUUID
 
         metadata = MetaData()
@@ -183,8 +190,15 @@ class SQLAlchemyDataStore(AsyncDataStore):
             metadata,
             Column('id', Unicode(500), primary_key=True),
             Column('task_id', Unicode(500), nullable=False, index=True),
-            Column('serialized_data', LargeBinary, nullable=False),
+            Column('trigger', LargeBinary),
+            Column('args', LargeBinary),
+            Column('kwargs', LargeBinary),
+            Column('coalesce', Enum(CoalescePolicy), nullable=False),
+            Column('misfire_grace_time', Unicode(16)),
+            # Column('max_jitter', Unicode(16)),
+            Column('tags', JSON, nullable=False),
             Column('next_fire_time', timestamp_type, index=True),
+            Column('last_fire_time', timestamp_type),
             Column('acquired_by', Unicode(500)),
             Column('acquired_until', timestamp_type)
         )
@@ -193,8 +207,14 @@ class SQLAlchemyDataStore(AsyncDataStore):
             metadata,
             Column('id', job_id_type, primary_key=True),
             Column('task_id', Unicode(500), nullable=False, index=True),
-            Column('serialized_data', LargeBinary, nullable=False),
+            Column('args', LargeBinary, nullable=False),
+            Column('kwargs', LargeBinary, nullable=False),
+            Column('schedule_id', Unicode(500)),
+            Column('scheduled_fire_time', timestamp_type),
+            Column('start_deadline', timestamp_type),
+            Column('tags', JSON, nullable=False),
             Column('created_at', timestamp_type, nullable=False),
+            Column('started_at', timestamp_type),
             Column('acquired_by', Unicode(500)),
             Column('acquired_until', timestamp_type)
         )
@@ -202,8 +222,10 @@ class SQLAlchemyDataStore(AsyncDataStore):
             'job_results',
             metadata,
             Column('job_id', job_id_type, primary_key=True),
+            Column('outcome', Enum(JobOutcome), nullable=False),
             Column('finished_at', timestamp_type, index=True),
-            Column('serialized_data', LargeBinary, nullable=False)
+            Column('exception', LargeBinary),
+            Column('return_value', LargeBinary)
         )
         return metadata
 
@@ -253,25 +275,25 @@ class SQLAlchemyDataStore(AsyncDataStore):
         finally:
             await asyncpg_conn.remove_listener(self.notify_channel, callback)
 
-    def _deserialize_jobs(self, serialized_jobs: Iterable[Tuple[UUID, bytes]]) -> list[Job]:
-        jobs: list[Job] = []
-        for job_id, serialized_data in serialized_jobs:
+    def _deserialize_schedules(self, result: Result) -> list[Schedule]:
+        schedules: list[Schedule] = []
+        for row in result:
             try:
-                jobs.append(self.serializer.deserialize(serialized_data))
+                schedules.append(Schedule.unmarshal(self.serializer, row._asdict()))
             except SerializationError as exc:
-                self._events.publish(JobDeserializationFailed(job_id=job_id, exception=exc))
+                self._events.publish(
+                    ScheduleDeserializationFailed(schedule_id=row['id'], exception=exc))
 
-        return jobs
+        return schedules
 
-    def _deserialize_schedules(
-            self, serialized_schedules: Iterable[Tuple[str, bytes]]) -> list[Schedule]:
-        jobs: list[Schedule] = []
-        for schedule_id, serialized_data in serialized_schedules:
+    def _deserialize_jobs(self, result: Result) -> list[Job]:
+        jobs: list[Job] = []
+        for row in result:
             try:
-                jobs.append(self.serializer.deserialize(serialized_data))
+                jobs.append(Job.unmarshal(self.serializer, row._asdict()))
             except SerializationError as exc:
                 self._events.publish(
-                    ScheduleDeserializationFailed(schedule_id=schedule_id, exception=exc))
+                    JobDeserializationFailed(job_id=row['id'], exception=exc))
 
         return jobs
 
@@ -333,11 +355,9 @@ class SQLAlchemyDataStore(AsyncDataStore):
         return tasks
 
     async def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None:
-        event: Event
-        serialized_data = self.serializer.serialize(schedule)
-        insert = self.t_schedules.insert().\
-            values(id=schedule.id, task_id=schedule.task_id, serialized_data=serialized_data,
-                   next_fire_time=schedule.next_fire_time)
+        event: DataStoreEvent
+        values = schedule.marshal(self.serializer)
+        insert = self.t_schedules.insert().values(**values)
         try:
             async with self.engine.begin() as conn:
                 await conn.execute(insert)
@@ -348,10 +368,10 @@ class SQLAlchemyDataStore(AsyncDataStore):
             if conflict_policy is ConflictPolicy.exception:
                 raise ConflictingIdError(schedule.id) from None
             elif conflict_policy is ConflictPolicy.replace:
+                del values['id']
                 update = self.t_schedules.update().\
                     where(self.t_schedules.c.id == schedule.id).\
-                    values(serialized_data=serialized_data,
-                           next_fire_time=schedule.next_fire_time)
+                    values(**values)
                 async with self.engine.begin() as conn:
                     await conn.execute(update)
 
@@ -374,8 +394,7 @@ class SQLAlchemyDataStore(AsyncDataStore):
                 await self._publish(conn, ScheduleRemoved(schedule_id=schedule_id))
 
     async def get_schedules(self, ids: Optional[set[str]] = None) -> list[Schedule]:
-        query = select([self.t_schedules.c.id, self.t_schedules.c.serialized_data]).\
-            order_by(self.t_schedules.c.id)
+        query = self.t_schedules.select().order_by(self.t_schedules.c.id)
         if ids:
             query = query.where(self.t_schedules.c.id.in_(ids))
 
@@ -399,14 +418,13 @@ class SQLAlchemyDataStore(AsyncDataStore):
                 where(self.t_schedules.c.id.in_(subselect)).\
                 values(acquired_by=scheduler_id, acquired_until=acquired_until)
             if self._supports_update_returning:
-                update = update.returning(self.t_schedules.c.id,
-                                          self.t_schedules.c.serialized_data)
+                update = update.returning(*self.t_schedules.columns)
                 result = await conn.execute(update)
             else:
                 await conn.execute(update)
-                query = select([self.t_schedules.c.id, self.t_schedules.c.serialized_data]).\
+                query = self.t_schedules.select().\
                     where(and_(self.t_schedules.c.acquired_by == scheduler_id))
-                result = await conn.execute(query)
+                result = conn.execute(query)
 
             schedules = self._deserialize_schedules(result)
 
@@ -420,16 +438,16 @@ class SQLAlchemyDataStore(AsyncDataStore):
             for schedule in schedules:
                 if schedule.next_fire_time is not None:
                     try:
-                        serialized_data = self.serializer.serialize(schedule)
+                        serialized_trigger = self.serializer.serialize(schedule.trigger)
                     except SerializationError:
-                        self._logger.exception('Error serializing schedule %r – '
+                        self._logger.exception('Error serializing trigger for schedule %r – '
                                                'removing from data store', schedule.id)
                         finished_schedule_ids.append(schedule.id)
                         continue
 
                     update_args.append({
                         'p_id': schedule.id,
-                        'p_serialized_data': serialized_data,
+                        'p_trigger': serialized_trigger,
                         'p_next_fire_time': schedule.next_fire_time
                     })
                 else:
@@ -438,12 +456,12 @@ class SQLAlchemyDataStore(AsyncDataStore):
             # Update schedules that have a next fire time
             if update_args:
                 p_id: BindParameter = bindparam('p_id')
-                p_serialized: BindParameter = bindparam('p_serialized_data')
+                p_trigger: BindParameter = bindparam('p_trigger')
                 p_next_fire_time: BindParameter = bindparam('p_next_fire_time')
                 update = self.t_schedules.update().\
                     where(and_(self.t_schedules.c.id == p_id,
                                self.t_schedules.c.acquired_by == scheduler_id)).\
-                    values(serialized_data=p_serialized, next_fire_time=p_next_fire_time,
+                    values(trigger=p_trigger, next_fire_time=p_next_fire_time,
                            acquired_by=None, acquired_until=None)
                 next_fire_times = {arg['p_id']: arg['p_next_fire_time'] for arg in update_args}
                 if self._supports_update_returning:
@@ -481,10 +499,8 @@ class SQLAlchemyDataStore(AsyncDataStore):
             return result.scalar()
 
     async def add_job(self, job: Job) -> None:
-        now = datetime.now(timezone.utc)
-        serialized_data = self.serializer.serialize(job)
-        insert = self.t_jobs.insert().values(id=job.id, task_id=job.task_id,
-                                             created_at=now, serialized_data=serialized_data)
+        marshalled = job.marshal(self.serializer)
+        insert = self.t_jobs.insert().values(**marshalled)
         async with self.engine.begin() as conn:
             await conn.execute(insert)
 
@@ -493,22 +509,20 @@ class SQLAlchemyDataStore(AsyncDataStore):
             await self._publish(conn, event)
 
     async def get_jobs(self, ids: Optional[Iterable[UUID]] = None) -> list[Job]:
-        query = select([self.t_jobs.c.id, self.t_jobs.c.serialized_data]).\
-            order_by(self.t_jobs.c.id)
+        query = self.t_jobs.select().order_by(self.t_jobs.c.id)
         if ids:
             job_ids = [job_id for job_id in ids]
             query = query.where(self.t_jobs.c.id.in_(job_ids))
 
         async with self.engine.begin() as conn:
             result = await conn.execute(query)
-
-        return self._deserialize_jobs(result)
+            return self._deserialize_jobs(result)
 
     async def acquire_jobs(self, worker_id: str, limit: Optional[int] = None) -> list[Job]:
         async with self.engine.begin() as conn:
             now = datetime.now(timezone.utc)
             acquired_until = now + timedelta(seconds=self.lock_expiration_delay)
-            query = select([self.t_jobs.c.id, self.t_jobs.c.serialized_data]).\
+            query = self.t_jobs.select().\
                 join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id).\
                 where(or_(self.t_jobs.c.acquired_until.is_(None),
                           self.t_jobs.c.acquired_until < now)).\
@@ -565,34 +579,33 @@ class SQLAlchemyDataStore(AsyncDataStore):
 
             return acquired_jobs
 
-    async def release_job(self, worker_id: str, job: Job, result: Optional[JobResult]) -> None:
+    async def release_job(self, worker_id: str, task_id: str, result: JobResult) -> None:
         async with self.engine.begin() as conn:
             # Insert the job result
-            now = datetime.now(timezone.utc)
-            serialized_data = self.serializer.serialize(result)
-            insert = self.t_job_results.insert().\
-                values(job_id=job.id, finished_at=now, serialized_data=serialized_data)
+            marshalled = result.marshal(self.serializer)
+            insert = self.t_job_results.insert().values(**marshalled)
             await conn.execute(insert)
 
             # Decrement the running jobs counter
             update = self.t_tasks.update().\
                 values(running_jobs=self.t_tasks.c.running_jobs - 1).\
-                where(self.t_tasks.c.id == job.task_id)
+                where(self.t_tasks.c.id == task_id)
             await conn.execute(update)
 
             # Delete the job
-            delete = self.t_jobs.delete().where(self.t_jobs.c.id == job.id)
+            delete = self.t_jobs.delete().where(self.t_jobs.c.id == result.job_id)
             await conn.execute(delete)
 
     async def get_job_result(self, job_id: UUID) -> Optional[JobResult]:
         async with self.engine.begin() as conn:
-            query = select(self.t_job_results.c.serialized_data).\
+            # Retrieve the result
+            query = self.t_job_results.select().\
                 where(self.t_job_results.c.job_id == job_id)
-            result = await conn.execute(query)
+            row = (await conn.execute(query)).fetchone()
 
+            # Delete the result
             delete = self.t_job_results.delete().\
                 where(self.t_job_results.c.job_id == job_id)
             await conn.execute(delete)
 
-            serialized_data = result.scalar()
-            return self.serializer.deserialize(serialized_data) if serialized_data else None
+            return JobResult.unmarshal(self.serializer, row._asdict()) if row else None
diff --git a/src/apscheduler/datastores/async_/sync_adapter.py b/src/apscheduler/datastores/async_/sync_adapter.py
index 51b15d5..96fbd5c 100644
--- a/src/apscheduler/datastores/async_/sync_adapter.py
+++ b/src/apscheduler/datastores/async_/sync_adapter.py
@@ -72,8 +72,8 @@ class AsyncDataStoreAdapter(AsyncDataStore):
     async def acquire_jobs(self, worker_id: str, limit: Optional[int] = None) -> list[Job]:
         return await to_thread.run_sync(self.original.acquire_jobs, worker_id, limit)
 
-    async def release_job(self, worker_id: str, job: Job, result: Optional[JobResult]) -> None:
-        await to_thread.run_sync(self.original.release_job, worker_id, job, result)
+    async def release_job(self, worker_id: str, task_id: str, result: JobResult) -> None:
+        await to_thread.run_sync(self.original.release_job, worker_id, task_id, result)
 
     async def get_job_result(self, job_id: UUID) -> Optional[JobResult]:
         return await to_thread.run_sync(self.original.get_job_result, job_id)
diff --git a/src/apscheduler/datastores/sync/memory.py b/src/apscheduler/datastores/sync/memory.py
index c2a033e..7bb42d3 100644
--- a/src/apscheduler/datastores/sync/memory.py
+++ b/src/apscheduler/datastores/sync/memory.py
@@ -267,20 +267,20 @@ class MemoryDataStore(DataStore):
 
         return jobs
 
-    def release_job(self, worker_id: str, job: Job, result: Optional[JobResult]) -> None:
+    def release_job(self, worker_id: str, task_id: str, result: JobResult) -> None:
         # Delete the job
-        job_state = self._jobs_by_id.pop(job.id)
-        self._jobs_by_task_id[job.task_id].remove(job_state)
+        job_state = self._jobs_by_id.pop(result.job_id)
+        self._jobs_by_task_id[task_id].remove(job_state)
         index = self._find_job_index(job_state)
         del self._jobs[index]
 
         # Decrement the number of running jobs for this task
-        task_state = self._tasks.get(job.task_id)
+        task_state = self._tasks.get(task_id)
         if task_state is not None:
             task_state.running_jobs -= 1
 
         # Record the result
-        self._job_results[job.id] = result
+        self._job_results[result.job_id] = result
 
     def get_job_result(self, job_id: UUID) -> Optional[JobResult]:
         return self._job_results.pop(job_id, None)
diff --git a/src/apscheduler/datastores/sync/mongodb.py b/src/apscheduler/datastores/sync/mongodb.py
index 1f1f72c..7252c0f 100644
--- a/src/apscheduler/datastores/sync/mongodb.py
+++ b/src/apscheduler/datastores/sync/mongodb.py
@@ -138,10 +138,11 @@ class MongoDBDataStore(DataStore):
     def get_schedules(self, ids: Optional[set[str]] = None) -> list[Schedule]:
         schedules: list[Schedule] = []
         filters = {'_id': {'$in': list(ids)}} if ids is not None else {}
-        cursor = self._schedules.find(filters, projection=['_id', 'serialized_data']).sort('_id')
+        cursor = self._schedules.find(filters).sort('_id')
         for document in cursor:
+            document['id'] = document.pop('_id')
             try:
-                schedule = self.serializer.deserialize(document['serialized_data'])
+                schedule = Schedule.unmarshal(self.serializer, document)
             except DeserializationError:
                 self._logger.warning('Failed to deserialize schedule %r', document['_id'])
                 continue
@@ -152,15 +153,11 @@ class MongoDBDataStore(DataStore):
 
     def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None:
         event: DataStoreEvent
-        serialized_data = self.serializer.serialize(schedule)
-        document = {
-            '_id': schedule.id,
-            'task_id': schedule.task_id,
-            'serialized_data': serialized_data,
-            'next_fire_time': schedule.next_fire_time
-        }
+        document = schedule.marshal(self.serializer)
+        document['_id'] = document.pop('id')
         try:
             self._schedules.insert_one(document)
+            print(document)
         except DuplicateKeyError:
             if conflict_policy is ConflictPolicy.exception:
                 raise ConflictingIdError(schedule.id) from None
@@ -193,11 +190,11 @@ class MongoDBDataStore(DataStore):
                 {'next_fire_time': {'$ne': None},
                  '$or': [{'acquired_until': {'$exists': False}},
                          {'acquired_until': {'$lt': datetime.now(timezone.utc)}}]
-                 },
-                projection=['serialized_data']
+                 }
             ).sort('next_fire_time').limit(limit)
             for document in cursor:
-                schedule = self.serializer.deserialize(document['serialized_data'])
+                document['id'] = document.pop('_id')
+                schedule = Schedule.unmarshal(self.serializer, document)
                 schedules.append(schedule)
 
             if schedules:
@@ -221,7 +218,7 @@ class MongoDBDataStore(DataStore):
                 filters = {'_id': schedule.id, 'acquired_by': scheduler_id}
                 if schedule.next_fire_time is not None:
                     try:
-                        serialized_data = self.serializer.serialize(schedule)
+                        serialized_trigger = self.serializer.serialize(schedule.trigger)
                     except SerializationError:
                         self._logger.exception('Error serializing schedule %r – '
                                                'removing from data store', schedule.id)
@@ -235,8 +232,8 @@ class MongoDBDataStore(DataStore):
                             'acquired_until': True,
                         },
                         '$set': {
-                            'next_fire_time': schedule.next_fire_time,
-                            'serialized_data': serialized_data
+                            'trigger': serialized_trigger,
+                            'next_fire_time': schedule.next_fire_time
                         }
                     }
                     requests.append(UpdateOne(filters, update))
@@ -264,14 +261,8 @@ class MongoDBDataStore(DataStore):
         return None
 
     def add_job(self, job: Job) -> None:
-        serialized_data = self.serializer.serialize(job)
-        document = {
-            '_id': job.id,
-            'serialized_data': serialized_data,
-            'task_id': job.task_id,
-            'created_at': datetime.now(timezone.utc),
-            'tags': list(job.tags)
-        }
+        document = job.marshal(self.serializer)
+        document['_id'] = document.pop('id')
         self._jobs.insert_one(document)
         event = JobAdded(job_id=job.id, task_id=job.task_id, schedule_id=job.schedule_id,
                          tags=job.tags)
@@ -280,12 +271,13 @@ class MongoDBDataStore(DataStore):
     def get_jobs(self, ids: Optional[Iterable[UUID]] = None) -> list[Job]:
         jobs: list[Job] = []
         filters = {'_id': {'$in': list(ids)}} if ids is not None else {}
-        cursor = self._jobs.find(filters, projection=['_id', 'serialized_data']).sort('_id')
+        cursor = self._jobs.find(filters).sort('_id')
         for document in cursor:
+            document['id'] = document.pop('_id')
             try:
-                job = self.serializer.deserialize(document['serialized_data'])
+                job = Job.unmarshal(self.serializer, document)
             except DeserializationError:
-                self._logger.warning('Failed to deserialize job %r', document['_id'])
+                self._logger.warning('Failed to deserialize job %r', document['id'])
                 continue
 
             jobs.append(job)
@@ -298,7 +290,6 @@ class MongoDBDataStore(DataStore):
                 {'$or': [{'acquired_until': {'$exists': False}},
                          {'acquired_until': {'$lt': datetime.now(timezone.utc)}}]
                  },
-                projection=['task_id', 'serialized_data'],
                 sort=[('created_at', ASCENDING)],
                 limit=limit,
                 session=session
@@ -319,7 +310,8 @@ class MongoDBDataStore(DataStore):
             acquired_jobs: list[Job] = []
             increments: dict[str, int] = defaultdict(lambda: 0)
             for document in documents:
-                job = self.serializer.deserialize(document['serialized_data'])
+                document['id'] = document.pop('_id')
+                job = Job.unmarshal(self.serializer, document)
 
                 # Don't acquire the job if there are no free slots left
                 slots_left = job_slots_left.get(job.task_id)
@@ -350,27 +342,27 @@ class MongoDBDataStore(DataStore):
 
             return acquired_jobs
 
-    def release_job(self, worker_id: str, job: Job, result: Optional[JobResult]) -> None:
+    def release_job(self, worker_id: str, task_id: str, result: JobResult) -> None:
        with self.client.start_session() as session:
             # Insert the job result
-            now = datetime.now(timezone.utc)
-            document = {
-                '_id': job.id,
-                'finished_at': now,
-                'serialized_data': self.serializer.serialize(result)
-            }
+            document = result.marshal(self.serializer)
+            document['_id'] = document.pop('job_id')
             self._jobs_results.insert_one(document, session=session)
 
             # Decrement the running jobs counter
             self._tasks.find_one_and_update(
-                {'_id': job.task_id},
-                {'$inc': {'running_jobs': -1}}
+                {'_id': task_id},
+                {'$inc': {'running_jobs': -1}},
+                session=session
             )
 
             # Delete the job
-            self._jobs.delete_one({'_id': job.id}, session=session)
+            self._jobs.delete_one({'_id': result.job_id}, session=session)
 
     def get_job_result(self, job_id: UUID) -> Optional[JobResult]:
-        document = self._jobs_results.find_one_and_delete(
-            filter={'_id': job_id}, projection=['serialized_data'])
-        return self.serializer.deserialize(document['serialized_data']) if document else None
+        document = self._jobs_results.find_one_and_delete({'_id': job_id})
+        if document:
+            document['job_id'] = document.pop('_id')
+            return JobResult.unmarshal(self.serializer, document)
+        else:
+            return None
diff --git a/src/apscheduler/datastores/sync/sqlalchemy.py b/src/apscheduler/datastores/sync/sqlalchemy.py
index 4914c1b..3b2c970 100644
--- a/src/apscheduler/datastores/sync/sqlalchemy.py
+++ b/src/apscheduler/datastores/sync/sqlalchemy.py
@@ -3,21 +3,21 @@ from __future__ import annotations
 from collections import defaultdict
 from datetime import datetime, timedelta, timezone
 from logging import Logger, getLogger
-from typing import Any, Callable, Iterable, Optional, Tuple, Type
+from typing import Any, Callable, Iterable, Optional, Type
 from uuid import UUID
 
 import attr
 from sqlalchemy import (
-    TIMESTAMP, Column, Integer, LargeBinary, MetaData, Table, TypeDecorator, Unicode, and_,
-    bindparam, or_, select)
-from sqlalchemy.engine import URL, Dialect
+    JSON, TIMESTAMP, Column, Enum, Integer, LargeBinary, MetaData, Table, TypeDecorator, Unicode,
+    and_, bindparam, or_, select)
+from sqlalchemy.engine import URL, Dialect, Result
 from sqlalchemy.exc import CompileError, IntegrityError
 from sqlalchemy.future import Engine, create_engine
 from sqlalchemy.sql.ddl import DropTable
 from sqlalchemy.sql.elements import BindParameter, literal
 
 from ...abc import DataStore, Job, Schedule, Serializer
-from ...enums import ConflictPolicy
+from ...enums import CoalescePolicy, ConflictPolicy, JobOutcome
 from ...events import (
     Event, EventHub, JobAdded, JobDeserializationFailed, ScheduleAdded,
     ScheduleDeserializationFailed, ScheduleRemoved, ScheduleUpdated, SubscriptionToken, TaskAdded,
@@ -34,12 +34,23 @@ class EmulatedUUID(TypeDecorator):
     cache_ok = True
 
     def process_bind_param(self, value, dialect: Dialect) -> Any:
-        return value.hex
+        return value.hex if value is not None else None
 
     def process_result_value(self, value: Any, dialect: Dialect):
         return UUID(value) if value else None
 
 
+class EmulatedTimestampTZ(TypeDecorator):
+    impl = Unicode(32)
+    cache_ok = True
+
+    def process_bind_param(self, value, dialect: Dialect) -> Any:
+        return value.isoformat() if value is not None else None
+
+    def process_result_value(self, value: Any, dialect: Dialect):
+        return datetime.fromisoformat(value) if value is not None else None
+
+
 @reentrant
 @attr.define(eq=False)
 class SQLAlchemyDataStore(DataStore):
@@ -64,7 +75,7 @@ class SQLAlchemyDataStore(DataStore):
         self.t_jobs = self._metadata.tables['jobs']
         self.t_job_results = self._metadata.tables['job_results']
 
-        # Find out if the dialect supports RETURNING
+        # Find out if the dialect supports UPDATE...RETURNING
         update = self.t_jobs.update().returning(self.t_jobs.c.id)
         try:
             update.compile(bind=self.engine)
@@ -101,17 +112,13 @@ class SQLAlchemyDataStore(DataStore):
         self._events.__exit__(exc_type, exc_val, exc_tb)
 
     def get_table_definitions(self) -> MetaData:
-        if self.engine.dialect.name in ('mysql', 'mariadb'):
-            from sqlalchemy.dialects import mysql
-            timestamp_type = mysql.TIMESTAMP(fsp=6)
-        else:
-            timestamp_type = TIMESTAMP(timezone=True)
-
         if self.engine.dialect.name == 'postgresql':
             from sqlalchemy.dialects import postgresql
+            timestamp_type = TIMESTAMP(timezone=True)
             job_id_type = postgresql.UUID(as_uuid=True)
         else:
+            timestamp_type = EmulatedTimestampTZ
             job_id_type = EmulatedUUID
 
         metadata = MetaData()
@@ -135,8 +142,15 @@ class SQLAlchemyDataStore(DataStore):
             metadata,
             Column('id', Unicode(500), primary_key=True),
             Column('task_id', Unicode(500), nullable=False, index=True),
-            Column('serialized_data', LargeBinary, nullable=False),
+            Column('trigger', LargeBinary),
+            Column('args', LargeBinary),
+            Column('kwargs', LargeBinary),
+            Column('coalesce', Enum(CoalescePolicy), nullable=False),
+            Column('misfire_grace_time', Unicode(16)),
+            # Column('max_jitter', Unicode(16)),
+            Column('tags', JSON, nullable=False),
             Column('next_fire_time', timestamp_type, index=True),
+            Column('last_fire_time', timestamp_type),
             Column('acquired_by', Unicode(500)),
             Column('acquired_until', timestamp_type)
         )
@@ -145,8 +159,14 @@ class SQLAlchemyDataStore(DataStore):
             metadata,
             Column('id', job_id_type, primary_key=True),
             Column('task_id', Unicode(500), nullable=False, index=True),
-            Column('serialized_data', LargeBinary, nullable=False),
+            Column('args', LargeBinary, nullable=False),
+            Column('kwargs', LargeBinary, nullable=False),
+            Column('schedule_id', Unicode(500)),
+            Column('scheduled_fire_time', timestamp_type),
+            Column('start_deadline', timestamp_type),
+            Column('tags', JSON, nullable=False),
             Column('created_at', timestamp_type, nullable=False),
+            Column('started_at', timestamp_type),
             Column('acquired_by', Unicode(500)),
             Column('acquired_until', timestamp_type)
         )
@@ -154,30 +174,32 @@ class SQLAlchemyDataStore(DataStore):
             'job_results',
             metadata,
             Column('job_id', job_id_type, primary_key=True),
+            Column('outcome', Enum(JobOutcome), nullable=False),
             Column('finished_at', timestamp_type, index=True),
-            Column('serialized_data', LargeBinary, nullable=False)
+            Column('exception', LargeBinary),
+            Column('return_value', LargeBinary)
         )
         return metadata
 
-    def _deserialize_jobs(self, serialized_jobs: Iterable[Tuple[UUID, bytes]]) -> list[Job]:
-        jobs: list[Job] = []
-        for job_id, serialized_data in serialized_jobs:
+    def _deserialize_schedules(self, result: Result) -> list[Schedule]:
+        schedules: list[Schedule] = []
+        for row in result:
             try:
-                jobs.append(self.serializer.deserialize(serialized_data))
+                schedules.append(Schedule.unmarshal(self.serializer, row._asdict()))
             except SerializationError as exc:
-                self._events.publish(JobDeserializationFailed(job_id=job_id, exception=exc))
+                self._events.publish(
+                    ScheduleDeserializationFailed(schedule_id=row['id'], exception=exc))
 
-        return jobs
+        return schedules
 
-    def _deserialize_schedules(
-            self, serialized_schedules: Iterable[Tuple[str, bytes]]) -> list[Schedule]:
-        jobs: list[Schedule] = []
-        for schedule_id, serialized_data in serialized_schedules:
+    def _deserialize_jobs(self, result: Result) -> list[Job]:
+        jobs: list[Job] = []
+        for row in result:
             try:
-                jobs.append(self.serializer.deserialize(serialized_data))
+                jobs.append(Job.unmarshal(self.serializer, row._asdict()))
             except SerializationError as exc:
                 self._events.publish(
-                    ScheduleDeserializationFailed(schedule_id=schedule_id, exception=exc))
+                    JobDeserializationFailed(job_id=row['id'], exception=exc))
 
         return jobs
 
@@ -240,10 +262,8 @@ class SQLAlchemyDataStore(DataStore):
 
     def add_schedule(self, schedule: Schedule, conflict_policy: ConflictPolicy) -> None:
         event: Event
-        serialized_data = self.serializer.serialize(schedule)
-        insert = self.t_schedules.insert().\
-            values(id=schedule.id, task_id=schedule.task_id, serialized_data=serialized_data,
-                   next_fire_time=schedule.next_fire_time)
+        values = schedule.marshal(self.serializer)
+        insert = self.t_schedules.insert().values(**values)
         try:
             with self.engine.begin() as conn:
                 conn.execute(insert)
@@ -254,10 +274,10 @@ class SQLAlchemyDataStore(DataStore):
             if conflict_policy is ConflictPolicy.exception:
                 raise ConflictingIdError(schedule.id) from None
             elif conflict_policy is ConflictPolicy.replace:
+                del values['id']
                 update = self.t_schedules.update().\
                     where(self.t_schedules.c.id == schedule.id).\
-                    values(serialized_data=serialized_data,
-                           next_fire_time=schedule.next_fire_time)
+                    values(**values)
                 with self.engine.begin() as conn:
                     conn.execute(update)
 
@@ -280,8 +300,7 @@ class SQLAlchemyDataStore(DataStore):
             self._events.publish(ScheduleRemoved(schedule_id=schedule_id))
 
     def get_schedules(self, ids: Optional[set[str]] = None) -> list[Schedule]:
-        query = select([self.t_schedules.c.id, self.t_schedules.c.serialized_data]).\
-            order_by(self.t_schedules.c.id)
+        query = self.t_schedules.select().order_by(self.t_schedules.c.id)
         if ids:
             query = query.where(self.t_schedules.c.id.in_(ids))
 
@@ -305,12 +324,11 @@ class SQLAlchemyDataStore(DataStore):
                 where(self.t_schedules.c.id.in_(subselect)).\
                 values(acquired_by=scheduler_id, acquired_until=acquired_until)
             if self._supports_update_returning:
-                update = update.returning(self.t_schedules.c.id,
-                                          self.t_schedules.c.serialized_data)
+                update = update.returning(*self.t_schedules.columns)
                 result = conn.execute(update)
             else:
                 conn.execute(update)
-                query = select([self.t_schedules.c.id, self.t_schedules.c.serialized_data]).\
+                query = self.t_schedules.select().\
                     where(and_(self.t_schedules.c.acquired_by == scheduler_id))
                 result = conn.execute(query)
 
@@ -326,16 +344,16 @@ class SQLAlchemyDataStore(DataStore):
             for schedule in schedules:
                 if schedule.next_fire_time is not None:
                     try:
-                        serialized_data = self.serializer.serialize(schedule)
+                        serialized_trigger = self.serializer.serialize(schedule.trigger)
                     except SerializationError:
-                        self._logger.exception('Error serializing schedule %r – '
+                        self._logger.exception('Error serializing trigger for schedule %r – '
                                                'removing from data store', schedule.id)
                         finished_schedule_ids.append(schedule.id)
                         continue
 
                     update_args.append({
                         'p_id': schedule.id,
-                        'p_serialized_data': serialized_data,
+                        'p_trigger': serialized_trigger,
                         'p_next_fire_time': schedule.next_fire_time
                     })
                 else:
@@ -344,12 +362,12 @@ class SQLAlchemyDataStore(DataStore):
             # Update schedules that have a next fire time
             if update_args:
                 p_id: BindParameter = bindparam('p_id')
-                p_serialized: BindParameter = bindparam('p_serialized_data')
+                p_trigger: BindParameter = bindparam('p_trigger')
                 p_next_fire_time: BindParameter = bindparam('p_next_fire_time')
                 update = self.t_schedules.update().\
                     where(and_(self.t_schedules.c.id == p_id,
                                self.t_schedules.c.acquired_by == scheduler_id)).\
-                    values(serialized_data=p_serialized, next_fire_time=p_next_fire_time,
+                    values(trigger=p_trigger, next_fire_time=p_next_fire_time,
                            acquired_by=None, acquired_until=None)
                 next_fire_times = {arg['p_id']: arg['p_next_fire_time'] for arg in update_args}
                 if self._supports_update_returning:
@@ -387,10 +405,8 @@ class SQLAlchemyDataStore(DataStore):
             return result.scalar()
 
     def add_job(self, job: Job) -> None:
-        now = datetime.now(timezone.utc)
-        serialized_data = self.serializer.serialize(job)
-        insert = self.t_jobs.insert().values(id=job.id, task_id=job.task_id,
-                                             created_at=now, serialized_data=serialized_data)
+        marshalled = job.marshal(self.serializer)
+        insert = self.t_jobs.insert().values(**marshalled)
         with self.engine.begin() as conn:
             conn.execute(insert)
 
@@ -399,8 +415,7 @@ class SQLAlchemyDataStore(DataStore):
         self._events.publish(event)
 
     def get_jobs(self, ids: Optional[Iterable[UUID]] = None) -> list[Job]:
-        query = select([self.t_jobs.c.id, self.t_jobs.c.serialized_data]).\
-            order_by(self.t_jobs.c.id)
+        query = self.t_jobs.select().order_by(self.t_jobs.c.id)
         if ids:
             job_ids = [job_id for job_id in ids]
             query = query.where(self.t_jobs.c.id.in_(job_ids))
@@ -413,7 +428,7 @@ class SQLAlchemyDataStore(DataStore):
         with self.engine.begin() as conn:
             now = datetime.now(timezone.utc)
             acquired_until = now + timedelta(seconds=self.lock_expiration_delay)
-            query = select([self.t_jobs.c.id, self.t_jobs.c.serialized_data]).\
+            query = self.t_jobs.select().\
                 join(self.t_tasks, self.t_tasks.c.id == self.t_jobs.c.task_id).\
                 where(or_(self.t_jobs.c.acquired_until.is_(None),
                           self.t_jobs.c.acquired_until < now)).\
@@ -470,36 +485,33 @@ class SQLAlchemyDataStore(DataStore):
 
             return acquired_jobs
 
-    def release_job(self, worker_id: str, job: Job, result: Optional[JobResult]) -> None:
+    def release_job(self, worker_id: str, task_id: str, result: JobResult) -> None:
         with self.engine.begin() as conn:
             # Insert the job result
-            now = datetime.now(timezone.utc)
-            serialized_result = self.serializer.serialize(result)
-            insert = self.t_job_results.insert().\
-                values(job_id=job.id, finished_at=now, serialized_data=serialized_result)
+            marshalled = result.marshal(self.serializer)
+            insert = self.t_job_results.insert().values(**marshalled)
             conn.execute(insert)
 
             # Decrement the running jobs counter
             update = self.t_tasks.update().\
                 values(running_jobs=self.t_tasks.c.running_jobs - 1).\
-                where(self.t_tasks.c.id == job.task_id)
+                where(self.t_tasks.c.id == task_id)
             conn.execute(update)
 
             # Delete the job
-            delete = self.t_jobs.delete().where(self.t_jobs.c.id == job.id)
+            delete = self.t_jobs.delete().where(self.t_jobs.c.id == result.job_id)
             conn.execute(delete)
 
     def get_job_result(self, job_id: UUID) -> Optional[JobResult]:
         with self.engine.begin() as conn:
             # Retrieve the result
-            query = select(self.t_job_results.c.serialized_data).\
+            query = self.t_job_results.select().\
                 where(self.t_job_results.c.job_id == job_id)
-            result = conn.execute(query)
+            row = conn.execute(query).fetchone()
 
             # Delete the result
             delete = self.t_job_results.delete().\
                 where(self.t_job_results.c.job_id == job_id)
             conn.execute(delete)
 
-            serialized_result = result.scalar()
-            return self.serializer.deserialize(serialized_result) if serialized_result else None
+            return JobResult.unmarshal(self.serializer, row._asdict()) if row else None
diff --git a/src/apscheduler/structures.py b/src/apscheduler/structures.py
index a25d411..bafb61c 100644
--- a/src/apscheduler/structures.py
+++ b/src/apscheduler/structures.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
-from datetime import datetime, timedelta
+from datetime import datetime, timedelta, timezone
+from functools import partial
 from typing import Any, Callable, Optional
 from uuid import UUID, uuid4
 
@@ -16,8 +17,8 @@ class Task:
     id: str
     func: Callable = attr.field(eq=False, order=False)
     max_running_jobs: Optional[int] = attr.field(eq=False, order=False, default=None)
-    state: Any = None
     misfire_grace_time: Optional[timedelta] = attr.field(eq=False, order=False, default=None)
+    state: Any = None
 
     def marshal(self, serializer: abc.Serializer) -> dict[str, Any]:
         marshalled = attr.asdict(self)
@@ -43,21 +44,39 @@ class Schedule:
     kwargs: dict[str, Any] = attr.field(eq=False, order=False, factory=dict)
     coalesce: CoalescePolicy = attr.field(eq=False, order=False, default=CoalescePolicy.latest)
     misfire_grace_time: Optional[timedelta] = attr.field(eq=False, order=False, default=None)
+    # max_jitter: Optional[timedelta] = attr.field(eq=False, order=False, default=None)
     tags: frozenset[str] = attr.field(eq=False, order=False, factory=frozenset)
-    next_fire_time: Optional[datetime] = attr.field(eq=False, order=False, init=False,
-                                                    default=None)
-    last_fire_time: Optional[datetime] = attr.field(eq=False, order=False, init=False,
-                                                    default=None)
+    next_fire_time: Optional[datetime] = attr.field(eq=False, order=False, default=None)
+    last_fire_time: Optional[datetime] = attr.field(eq=False, order=False, default=None)
+    acquired_by: Optional[str] = attr.field(eq=False, order=False, default=None)
+    acquired_until: Optional[datetime] = attr.field(eq=False, order=False, default=None)
 
     def marshal(self, serializer: abc.Serializer) -> dict[str, Any]:
         marshalled = attr.asdict(self)
-        marshalled['trigger_type'] = serializer.serialize(self.args)
-        marshalled['trigger_data'] = serializer.serialize(self.trigger)
-        marshalled['args'] = serializer.serialize(self.args) if self.args else None
-        marshalled['kwargs'] = serializer.serialize(self.kwargs) if self.kwargs else None
+        marshalled['trigger'] = serializer.serialize(self.trigger)
+        marshalled['args'] = serializer.serialize(self.args)
+        marshalled['kwargs'] = serializer.serialize(self.kwargs)
+        marshalled['coalesce'] = self.coalesce.name
         marshalled['tags'] = list(self.tags)
+        marshalled['misfire_grace_time'] = (self.misfire_grace_time.total_seconds()
+                                            if self.misfire_grace_time is not None else None)
+        if not self.acquired_by:
+            del marshalled['acquired_by']
+            del marshalled['acquired_until']
+
         return marshalled
 
+    @classmethod
+    def unmarshal(cls, serializer: abc.Serializer, marshalled: dict[str, Any]) -> Schedule:
+        marshalled['trigger'] = serializer.deserialize(marshalled['trigger'])
+        marshalled['args'] = serializer.deserialize(marshalled['args'])
+        marshalled['kwargs'] = serializer.deserialize(marshalled['kwargs'])
+        marshalled['tags'] = frozenset(marshalled['tags'])
+        if isinstance(marshalled['coalesce'], str):
+            marshalled['coalesce'] = CoalescePolicy.__members__[marshalled['coalesce']]
+
+        return cls(**marshalled)
+
     @property
     def next_deadline(self) -> Optional[datetime]:
         if self.next_fire_time and self.misfire_grace_time:
@@ -76,27 +95,64 @@ class Job:
     scheduled_fire_time: Optional[datetime] = attr.field(eq=False, order=False, default=None)
     start_deadline: Optional[datetime] = attr.field(eq=False, order=False, default=None)
     tags: frozenset[str] = attr.field(eq=False, order=False, factory=frozenset)
-    started_at: Optional[datetime] = attr.field(eq=False, order=False, init=False, default=None)
+    created_at: datetime = attr.field(eq=False, order=False,
+                                      factory=partial(datetime.now, timezone.utc))
+    started_at: Optional[datetime] = attr.field(eq=False, order=False, default=None)
+    acquired_by: Optional[str] = attr.field(eq=False, order=False, default=None)
+    acquired_until: Optional[datetime] = attr.field(eq=False, order=False, default=None)
 
     def marshal(self, serializer: abc.Serializer) -> dict[str, Any]:
         marshalled = attr.asdict(self)
-        marshalled['args'] = serializer.serialize(self.args) if self.args else None
-        marshalled['kwargs'] = serializer.serialize(self.kwargs) if self.kwargs else None
+        marshalled['args'] = serializer.serialize(self.args)
+        marshalled['kwargs'] = serializer.serialize(self.kwargs)
         marshalled['tags'] = list(self.tags)
+        if not self.acquired_by:
+            del marshalled['acquired_by']
+            del marshalled['acquired_until']
+
         return marshalled
 
     @classmethod
-    def unmarshal(cls, serializer: abc.Serializer, marshalled: dict[str, Any]) -> Task:
+    def unmarshal(cls, serializer: abc.Serializer, marshalled: dict[str, Any]) -> Job:
         for key in ('args', 'kwargs'):
-            if marshalled[key] is not None:
-                marshalled[key] = serializer.deserialize(marshalled[key])
+            marshalled[key] = serializer.deserialize(marshalled[key])
 
         marshalled['tags'] = frozenset(marshalled['tags'])
         return cls(**marshalled)
 
 
-@attr.define(eq=False, order=False, frozen=True)
+@attr.define(kw_only=True, frozen=True)
 class JobResult:
-    outcome: JobOutcome
-    exception: Optional[BaseException] = None
-    return_value: Any = None
+    job_id: UUID
+    outcome: JobOutcome = attr.field(eq=False, order=False)
+    finished_at: datetime = attr.field(eq=False, order=False,
+                                       factory=partial(datetime.now, timezone.utc))
+    exception: Optional[BaseException] = attr.field(eq=False, order=False, default=None)
+    return_value: Any = attr.field(eq=False, order=False, default=None)
+
+    def marshal(self, serializer: abc.Serializer) -> dict[str, Any]:
+        marshalled = attr.asdict(self)
+        marshalled['outcome'] = self.outcome.name
+        if self.outcome is JobOutcome.failure:
+            marshalled['exception'] = serializer.serialize(self.exception)
+        else:
+            del marshalled['exception']
+
+        if self.outcome is JobOutcome.success:
+            marshalled['return_value'] = serializer.serialize(self.return_value)
+        else:
+            del marshalled['return_value']
+
+        return marshalled
+
+    @classmethod
+    def unmarshal(cls, serializer: abc.Serializer, marshalled: dict[str, Any]) -> JobResult:
+        if isinstance(marshalled['outcome'], str):
+            marshalled['outcome'] = JobOutcome.__members__[marshalled['outcome']]
+
+        if marshalled.get('exception'):
+            marshalled['exception'] = serializer.deserialize(marshalled['exception'])
+        elif marshalled.get('return_value'):
+            marshalled['return_value'] = serializer.deserialize(marshalled['return_value'])
+
+        return cls(**marshalled)
diff --git a/src/apscheduler/workers/async_.py b/src/apscheduler/workers/async_.py
index c64d14c..2cea795 100644
--- a/src/apscheduler/workers/async_.py
+++ b/src/apscheduler/workers/async_.py
@@ -136,19 +136,19 @@ class AsyncWorker(EventSource):
                 retval = await retval
         except get_cancelled_exc_class():
             with CancelScope(shield=True):
-                result = JobResult(outcome=JobOutcome.cancelled)
-                await self.data_store.release_job(self.identity, job, result)
+                result = JobResult(job_id=job.id, outcome=JobOutcome.cancelled)
+                await self.data_store.release_job(self.identity, job.task_id, result)
                 self._events.publish(JobCancelled.from_job(job, start_time))
         except BaseException as exc:
-            result = JobResult(outcome=JobOutcome.failure, exception=exc)
-            await self.data_store.release_job(self.identity, job, result)
+            result = JobResult(job_id=job.id, outcome=JobOutcome.failure, exception=exc)
+            await self.data_store.release_job(self.identity, job.task_id, result)
             self._events.publish(JobFailed.from_exception(job, start_time, exc))
             if not isinstance(exc, Exception):
                 raise
         else:
-            result = JobResult(outcome=JobOutcome.success, return_value=retval)
-            await self.data_store.release_job(self.identity, job, result)
+            result = JobResult(job_id=job.id, outcome=JobOutcome.success, return_value=retval)
+            await self.data_store.release_job(self.identity, job.task_id, result)
             self._events.publish(JobCompleted.from_retval(job, start_time, retval))
         finally:
             self._running_jobs.remove(job.id)
diff --git a/src/apscheduler/workers/sync.py b/src/apscheduler/workers/sync.py
index 8e023f6..efb52af 100644
--- a/src/apscheduler/workers/sync.py
+++ b/src/apscheduler/workers/sync.py
@@ -134,14 +134,14 @@ class Worker(EventSource):
         try:
             retval = func(*job.args, **job.kwargs)
         except BaseException as exc:
-            result = JobResult(outcome=JobOutcome.failure, exception=exc)
-            self.data_store.release_job(self.identity, job, result)
+            result = JobResult(job_id=job.id, outcome=JobOutcome.failure, exception=exc)
+            self.data_store.release_job(self.identity, job.task_id, result)
             self._events.publish(JobFailed.from_exception(job, start_time, exc))
             if not isinstance(exc, Exception):
                 raise
         else:
-            result = JobResult(outcome=JobOutcome.success, return_value=retval)
-            self.data_store.release_job(self.identity, job, result)
+            result = JobResult(job_id=job.id, outcome=JobOutcome.success, return_value=retval)
+            self.data_store.release_job(self.identity, job.task_id, result)
             self._events.publish(JobCompleted.from_retval(job, start_time, retval))
         finally:
             self._running_jobs.remove(job.id)
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index 5ce7bcd..52f8349 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -35,7 +35,7 @@ def schedules() -> List[Schedule]:
 
 async def capture_events(
     store: AsyncDataStore, limit: int, event_types: Optional[Set[Type[Event]]] = None
-) -> AsyncGenerator[List[Event], None, None]:
+) -> AsyncGenerator[List[Event], None]:
     def listener(event: Event) -> None:
         events.append(event)
         if len(events) == limit:
@@ -253,8 +253,9 @@ class TestAsyncStores:
             assert len(acquired) == 1
             assert acquired[0].id == job.id
 
-            await store.release_job('worker_id', acquired[0],
-                                    JobResult(JobOutcome.success, return_value='foo'))
+            await store.release_job(
+                'worker_id', acquired[0].task_id,
+                JobResult(job_id=acquired[0].id, outcome=JobOutcome.success, return_value='foo'))
             result = await store.get_job_result(acquired[0].id)
             assert result.outcome is JobOutcome.success
             assert result.exception is None
@@ -275,8 +276,10 @@ class TestAsyncStores:
             assert len(acquired) == 1
             assert acquired[0].id == job.id
 
-            await store.release_job('worker_id', acquired[0],
-                                    JobResult(JobOutcome.failure, exception=ValueError('foo')))
+            await store.release_job(
+                'worker_id', acquired[0].task_id,
+                JobResult(job_id=acquired[0].id, outcome=JobOutcome.failure,
+                          exception=ValueError('foo')))
             result = await store.get_job_result(acquired[0].id)
             assert result.outcome is JobOutcome.failure
             assert isinstance(result.exception, ValueError)
@@ -298,8 +301,9 @@ class TestAsyncStores:
             assert len(acquired) == 1
             assert acquired[0].id == job.id
 
-            await store.release_job('worker_id', acquired[0],
-                                    JobResult(JobOutcome.missed_start_deadline))
+            await store.release_job(
+                'worker_id', acquired[0].task_id,
+                JobResult(job_id=acquired[0].id, outcome=JobOutcome.missed_start_deadline))
             result = await store.get_job_result(acquired[0].id)
             assert result.outcome is JobOutcome.missed_start_deadline
             assert result.exception is None
@@ -320,7 +324,8 @@ class TestAsyncStores:
             assert len(acquired) == 1
             assert acquired[0].id == job.id
 
-            await store.release_job('worker1', acquired[0], JobResult(JobOutcome.cancelled))
+            await store.release_job('worker1', acquired[0].task_id,
+                                    JobResult(job_id=acquired[0].id, outcome=JobOutcome.cancelled))
             result = await store.get_job_result(acquired[0].id)
             assert result.outcome is JobOutcome.cancelled
             assert result.exception is None
@@ -372,7 +377,9 @@ class TestAsyncStores:
             assert [job.id for job in acquired_jobs] == [job.id for job in jobs[:2]]
 
             # Release one job, and the worker should be able to acquire the third job
-            await store.release_job('worker1', acquired_jobs[0],
-                                    JobResult(outcome=JobOutcome.success, return_value=None))
+            await store.release_job(
+                'worker1', acquired_jobs[0].task_id,
+                JobResult(job_id=acquired_jobs[0].id, outcome=JobOutcome.success,
+                          return_value=None))
             acquired_jobs = await store.acquire_jobs('worker1', 3)
             assert [job.id for job in acquired_jobs] == [jobs[2].id]
```
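
As the test changes above show, `release_job()` now takes the task ID plus a `JobResult` that carries its own `job_id`, so stores no longer need the full `Job` object to record an outcome. A minimal sketch of that contract follows; `DummyStore` and the simplified `JobResult`/`JobOutcome` here are hypothetical stand-ins, not apscheduler's real classes:

```python
# Hypothetical illustration of the reworked release_job() contract.
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum, auto
from typing import Any, Optional
from uuid import UUID, uuid4


class JobOutcome(Enum):
    success = auto()
    failure = auto()


@dataclass(frozen=True)
class JobResult:  # mirrors the shape of the reworked structure
    job_id: UUID
    outcome: JobOutcome
    finished_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    exception: Optional[BaseException] = None
    return_value: Any = None


class DummyStore:
    def __init__(self) -> None:
        self.results: dict[UUID, JobResult] = {}

    def release_job(self, worker_id: str, task_id: str, result: JobResult) -> None:
        # A real store would also delete the job row and decrement the task's
        # running-jobs counter, as the data store implementations above do.
        self.results[result.job_id] = result


store = DummyStore()
job_id = uuid4()
store.release_job('worker1', 'my_task',
                  JobResult(job_id=job_id, outcome=JobOutcome.success, return_value='foo'))
assert store.results[job_id].outcome is JobOutcome.success
```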