diff options
author | Alex Grönholm <alex.gronholm@nextday.fi> | 2021-09-21 00:21:25 +0300 |
---|---|---|
committer | Alex Grönholm <alex.gronholm@nextday.fi> | 2021-09-21 00:21:25 +0300 |
commit | 8326ac378e5b5f8e5cb2c45f20e0e1bdfa5075c0 (patch) | |
tree | ac96ee8a1cbc792cf56cf3534f5e3c65f0b5a9e7 | |
parent | 8b68b6c5d1c63faae1ba3769b6475b396328e3a3 (diff) | |
download | apscheduler-8326ac378e5b5f8e5cb2c45f20e0e1bdfa5075c0.tar.gz |
Implemented schedule-level jitter
Structures now keep enums, timedeltas and frozensets as-is. The MongoDB store was modified to use a custom type registry to handle this.
-rw-r--r-- | src/apscheduler/converters.py | 22 | ||||
-rw-r--r-- | src/apscheduler/datastores/mongodb.py | 36 | ||||
-rw-r--r-- | src/apscheduler/datastores/sqlalchemy.py | 24 | ||||
-rw-r--r-- | src/apscheduler/schedulers/async_.py | 40 | ||||
-rw-r--r-- | src/apscheduler/schedulers/sync.py | 39 | ||||
-rw-r--r-- | src/apscheduler/structures.py | 56 | ||||
-rw-r--r-- | tests/test_schedulers.py | 86 | ||||
-rw-r--r-- | tests/test_workers.py | 34 |
8 files changed, 267 insertions, 70 deletions
diff --git a/src/apscheduler/converters.py b/src/apscheduler/converters.py index 7e8e590..c664bc0 100644 --- a/src/apscheduler/converters.py +++ b/src/apscheduler/converters.py @@ -1,6 +1,7 @@ from __future__ import annotations -from datetime import datetime +from datetime import datetime, timedelta +from enum import Enum from typing import Optional from uuid import UUID @@ -18,8 +19,25 @@ def as_aware_datetime(value: datetime | str) -> Optional[datetime]: def as_uuid(value: UUID | str) -> UUID: - """Converts a string-formatted UUID to a UUID instance.""" + """Convert a string-formatted UUID to a UUID instance.""" if isinstance(value, str): return UUID(value) return value + + +def as_timedelta(value: timedelta | float | None) -> timedelta: + if isinstance(value, (float, int)): + return timedelta(seconds=value) + + return value + + +def as_enum(enum_class: type[Enum]): + def converter(value: enum_class | str): + if isinstance(value, str): + return enum_class.__members__[value] + + return value + + return converter diff --git a/src/apscheduler/datastores/mongodb.py b/src/apscheduler/datastores/mongodb.py index 5066c56..cef0f7e 100644 --- a/src/apscheduler/datastores/mongodb.py +++ b/src/apscheduler/datastores/mongodb.py @@ -1,21 +1,24 @@ from __future__ import annotations +import operator from collections import defaultdict from contextlib import ExitStack -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from logging import Logger, getLogger -from typing import ClassVar, Iterable, Optional +from typing import Any, Callable, ClassVar, Iterable, Optional from uuid import UUID import attr import pymongo from attr.validators import instance_of +from bson import CodecOptions +from bson.codec_options import TypeEncoder, TypeRegistry from pymongo import ASCENDING, DeleteOne, MongoClient, UpdateOne from pymongo.collection import Collection from pymongo.errors import DuplicateKeyError from ..abc import DataStore, EventBroker, EventSource, Job, Schedule, Serializer -from ..enums import ConflictPolicy +from ..enums import CoalescePolicy, ConflictPolicy, JobOutcome from ..eventbrokers.local import LocalEventBroker from ..events import ( DataStoreEvent, JobAcquired, JobAdded, JobReleased, ScheduleAdded, ScheduleRemoved, @@ -27,6 +30,19 @@ from ..structures import JobResult, Task from ..util import reentrant +class CustomEncoder(TypeEncoder): + def __init__(self, python_type: type, encoder: Callable): + self._python_type = python_type + self._encoder = encoder + + @property + def python_type(self) -> type: + return self._python_type + + def transform_python(self, value: Any) -> Any: + return self._encoder(value) + + @reentrant @attr.define(eq=False) class MongoDBDataStore(DataStore): @@ -45,13 +61,15 @@ class MongoDBDataStore(DataStore): _events: EventBroker = attr.field(init=False, factory=LocalEventBroker) _local_tasks: dict[str, Task] = attr.field(init=False, factory=dict) - @client.validator - def validate_client(self, attribute: attr.Attribute, value: MongoClient) -> None: - if not value.delegate.codec_options.tz_aware: - raise ValueError('MongoDB client must have tz_aware set to True') - def __attrs_post_init__(self) -> None: - database = self.client[self.database] + type_registry = TypeRegistry([ + CustomEncoder(timedelta, timedelta.total_seconds), + CustomEncoder(ConflictPolicy, operator.attrgetter('name')), + CustomEncoder(CoalescePolicy, operator.attrgetter('name')), + CustomEncoder(JobOutcome, operator.attrgetter('name')) + ]) + codec_options = CodecOptions(tz_aware=True, type_registry=type_registry) + database = self.client.get_database(self.database, codec_options=codec_options) self._tasks: Collection = database['tasks'] self._schedules: Collection = database['schedules'] self._jobs: Collection = database['jobs'] diff --git a/src/apscheduler/datastores/sqlalchemy.py b/src/apscheduler/datastores/sqlalchemy.py index 31e60cc..0bfc211 100644 --- a/src/apscheduler/datastores/sqlalchemy.py +++ b/src/apscheduler/datastores/sqlalchemy.py @@ -8,8 +8,8 @@ from uuid import UUID import attr from sqlalchemy import ( - JSON, TIMESTAMP, Column, Enum, Integer, LargeBinary, MetaData, Table, TypeDecorator, Unicode, - and_, bindparam, or_, select) + JSON, TIMESTAMP, BigInteger, Column, Enum, Integer, LargeBinary, MetaData, Table, + TypeDecorator, Unicode, and_, bindparam, or_, select) from sqlalchemy.engine import URL, Dialect, Result from sqlalchemy.exc import CompileError, IntegrityError from sqlalchemy.future import Engine, create_engine @@ -52,6 +52,17 @@ class EmulatedTimestampTZ(TypeDecorator): return datetime.fromisoformat(value) if value is not None else None +class EmulatedInterval(TypeDecorator): + impl = BigInteger() + cache_ok = True + + def process_bind_param(self, value, dialect: Dialect) -> Any: + return value.total_seconds() if value is not None else None + + def process_result_value(self, value: Any, dialect: Dialect): + return timedelta(seconds=value) if value is not None else None + + @attr.define(kw_only=True, eq=False) class _BaseSQLAlchemyDataStore: schema: Optional[str] = attr.field(default=None) @@ -88,9 +99,11 @@ class _BaseSQLAlchemyDataStore: timestamp_type = TIMESTAMP(timezone=True) job_id_type = postgresql.UUID(as_uuid=True) + interval_type = postgresql.INTERVAL(precision=6) else: timestamp_type = EmulatedTimestampTZ job_id_type = EmulatedUUID + interval_type = EmulatedInterval metadata = MetaData() Table( @@ -105,7 +118,7 @@ class _BaseSQLAlchemyDataStore: Column('func', Unicode(500), nullable=False), Column('state', LargeBinary), Column('max_running_jobs', Integer), - Column('misfire_grace_time', Unicode(16)), + Column('misfire_grace_time', interval_type), Column('running_jobs', Integer, nullable=False, server_default=literal(0)) ) Table( @@ -117,8 +130,8 @@ class _BaseSQLAlchemyDataStore: Column('args', LargeBinary), Column('kwargs', LargeBinary), Column('coalesce', Enum(CoalescePolicy), nullable=False), - Column('misfire_grace_time', Unicode(16)), - # Column('max_jitter', Unicode(16)), + Column('misfire_grace_time', interval_type), + Column('max_jitter', interval_type), Column('tags', JSON, nullable=False), Column('next_fire_time', timestamp_type, index=True), Column('last_fire_time', timestamp_type), @@ -134,6 +147,7 @@ class _BaseSQLAlchemyDataStore: Column('kwargs', LargeBinary, nullable=False), Column('schedule_id', Unicode(500)), Column('scheduled_fire_time', timestamp_type), + Column('jitter', interval_type), Column('start_deadline', timestamp_type), Column('tags', JSON, nullable=False), Column('created_at', timestamp_type, nullable=False), diff --git a/src/apscheduler/schedulers/async_.py b/src/apscheduler/schedulers/async_.py index 6900460..a08c090 100644 --- a/src/apscheduler/schedulers/async_.py +++ b/src/apscheduler/schedulers/async_.py @@ -2,6 +2,7 @@ from __future__ import annotations import os import platform +import random from contextlib import AsyncExitStack from datetime import datetime, timedelta, timezone from logging import Logger, getLogger @@ -24,6 +25,9 @@ from ..marshalling import callable_to_ref from ..structures import JobResult, Task from ..workers.async_ import AsyncWorker +_microsecond_delta = timedelta(microseconds=1) +_zero_timedelta = timedelta() + class AsyncScheduler: """An asynchronous (AnyIO based) scheduler implementation.""" @@ -96,7 +100,8 @@ class AsyncScheduler: self, func_or_task_id: str | Callable, trigger: Trigger, *, id: Optional[str] = None, args: Optional[Iterable] = None, kwargs: Optional[Mapping[str, Any]] = None, coalesce: CoalescePolicy = CoalescePolicy.latest, - misfire_grace_time: float | timedelta | None = None, tags: Optional[Iterable[str]] = None, + misfire_grace_time: float | timedelta | None = None, + max_jitter: float | timedelta | None = None, tags: Optional[Iterable[str]] = None, conflict_policy: ConflictPolicy = ConflictPolicy.do_nothing ) -> str: id = id or str(uuid4()) @@ -113,13 +118,18 @@ class AsyncScheduler: task = await self.data_store.get_task(func_or_task_id) schedule = Schedule(id=id, task_id=task.id, trigger=trigger, args=args, kwargs=kwargs, - coalesce=coalesce, misfire_grace_time=misfire_grace_time, tags=tags) + coalesce=coalesce, misfire_grace_time=misfire_grace_time, + max_jitter=max_jitter, tags=tags) schedule.next_fire_time = trigger.next() await self.data_store.add_schedule(schedule, conflict_policy) self.logger.info('Added new schedule (task=%r, trigger=%r); next run time at %s', task, trigger, schedule.next_fire_time) return schedule.id + async def get_schedule(self, id: str) -> Schedule: + schedules = await self.data_store.get_schedules({id}) + return schedules[0] + async def remove_schedule(self, schedule_id: str) -> None: await self.data_store.remove_schedules({schedule_id}) @@ -143,7 +153,7 @@ class AsyncScheduler: else: task = await self.data_store.get_task(func_or_task_id) - job = Job(task_id=task.id, args=args, kwargs=kwargs, tags=tags) + job = Job(task_id=task.id, args=args or (), kwargs=kwargs or {}, tags=tags or frozenset()) await self.data_store.add_job(job) return job.id @@ -248,11 +258,31 @@ class AsyncScheduler: fire_times[0] = fire_time # Add one or more jobs to the job queue - for fire_time in fire_times: + max_jitter = schedule.max_jitter.total_seconds() if schedule.max_jitter else 0 + for i, fire_time in enumerate(fire_times): + # Calculate a jitter if max_jitter > 0 + jitter = _zero_timedelta + if max_jitter: + if i + 1 < len(fire_times): + next_fire_time = fire_times[i + 1] + else: + next_fire_time = schedule.next_fire_time + + if next_fire_time is not None: + # Jitter must never be so high that it would cause a fire time to + # equal or exceed the next fire time + jitter_s = min([ + max_jitter, + (next_fire_time - fire_time + - _microsecond_delta).total_seconds() + ]) + jitter = timedelta(seconds=random.uniform(0, jitter_s)) + fire_time += jitter + schedule.last_fire_time = fire_time job = Job(task_id=schedule.task_id, args=schedule.args, kwargs=schedule.kwargs, schedule_id=schedule.id, - scheduled_fire_time=fire_time, + scheduled_fire_time=fire_time, jitter=jitter, start_deadline=schedule.next_deadline, tags=schedule.tags) await self.data_store.add_job(job) diff --git a/src/apscheduler/schedulers/sync.py b/src/apscheduler/schedulers/sync.py index 1525bea..905efac 100644 --- a/src/apscheduler/schedulers/sync.py +++ b/src/apscheduler/schedulers/sync.py @@ -2,6 +2,7 @@ from __future__ import annotations import os import platform +import random import threading from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait from contextlib import ExitStack @@ -21,6 +22,9 @@ from ..marshalling import callable_to_ref from ..structures import Job, JobResult, Schedule, Task from ..workers.sync import Worker +_microsecond_delta = timedelta(microseconds=1) +_zero_timedelta = timedelta() + class Scheduler: """A synchronous scheduler implementation.""" @@ -96,7 +100,7 @@ class Scheduler: args: Optional[Iterable] = None, kwargs: Optional[Mapping[str, Any]] = None, coalesce: CoalescePolicy = CoalescePolicy.latest, misfire_grace_time: float | timedelta | None = None, - tags: Optional[Iterable[str]] = None, + max_jitter: float | timedelta | None = None, tags: Optional[Iterable[str]] = None, conflict_policy: ConflictPolicy = ConflictPolicy.do_nothing ) -> str: id = id or str(uuid4()) @@ -113,13 +117,18 @@ class Scheduler: task = self.data_store.get_task(func_or_task_id) schedule = Schedule(id=id, task_id=task.id, trigger=trigger, args=args, kwargs=kwargs, - coalesce=coalesce, misfire_grace_time=misfire_grace_time, tags=tags) + coalesce=coalesce, misfire_grace_time=misfire_grace_time, + max_jitter=max_jitter, tags=tags) schedule.next_fire_time = trigger.next() self.data_store.add_schedule(schedule, conflict_policy) self.logger.info('Added new schedule (task=%r, trigger=%r); next run time at %s', task, trigger, schedule.next_fire_time) return schedule.id + def get_schedule(self, id: str) -> Schedule: + schedules = self.data_store.get_schedules({id}) + return schedules[0] + def remove_schedule(self, schedule_id: str) -> None: self.data_store.remove_schedules({schedule_id}) @@ -143,7 +152,7 @@ class Scheduler: else: task = self.data_store.get_task(func_or_task_id) - job = Job(task_id=task.id, args=args, kwargs=kwargs, tags=tags) + job = Job(task_id=task.id, args=args or (), kwargs=kwargs or {}, tags=tags or frozenset()) self.data_store.add_job(job) return job.id @@ -247,11 +256,31 @@ class Scheduler: fire_times[0] = fire_time # Add one or more jobs to the job queue - for fire_time in fire_times: + max_jitter = schedule.max_jitter.total_seconds() if schedule.max_jitter else 0 + for i, fire_time in enumerate(fire_times): + # Calculate a jitter if max_jitter > 0 + jitter = _zero_timedelta + if max_jitter: + if i + 1 < len(fire_times): + next_fire_time = fire_times[i + 1] + else: + next_fire_time = schedule.next_fire_time + + if next_fire_time is not None: + # Jitter must never be so high that it would cause a fire time to + # equal or exceed the next fire time + jitter_s = min([ + max_jitter, + (next_fire_time - fire_time + - _microsecond_delta).total_seconds() + ]) + jitter = timedelta(seconds=random.uniform(0, jitter_s)) + fire_time += jitter + schedule.last_fire_time = fire_time job = Job(task_id=schedule.task_id, args=schedule.args, kwargs=schedule.kwargs, schedule_id=schedule.id, - scheduled_fire_time=fire_time, + scheduled_fire_time=fire_time, jitter=jitter, start_deadline=schedule.next_deadline, tags=schedule.tags) self.data_store.add_job(job) diff --git a/src/apscheduler/structures.py b/src/apscheduler/structures.py index c805ef0..a1e653b 100644 --- a/src/apscheduler/structures.py +++ b/src/apscheduler/structures.py @@ -6,9 +6,9 @@ from typing import Any, Callable, Optional from uuid import UUID, uuid4 import attr -from attr.converters import default_if_none from . import abc +from .converters import as_enum, as_timedelta from .enums import CoalescePolicy, JobOutcome from .marshalling import callable_from_ref, callable_to_ref @@ -41,12 +41,15 @@ class Schedule: id: str task_id: str = attr.field(eq=False, order=False) trigger: abc.Trigger = attr.field(eq=False, order=False) - args: tuple = attr.field(eq=False, order=False, default=()) - kwargs: dict[str, Any] = attr.field(eq=False, order=False, factory=dict) - coalesce: CoalescePolicy = attr.field(eq=False, order=False, default=CoalescePolicy.latest) - misfire_grace_time: Optional[timedelta] = attr.field(eq=False, order=False, default=None) - # max_jitter: Optional[timedelta] = attr.field(eq=False, order=False, default=None) - tags: frozenset[str] = attr.field(eq=False, order=False, factory=frozenset) + args: tuple = attr.field(eq=False, order=False, converter=tuple, default=()) + kwargs: dict[str, Any] = attr.field(eq=False, order=False, converter=dict, default=()) + coalesce: CoalescePolicy = attr.field(eq=False, order=False, default=CoalescePolicy.latest, + converter=as_enum(CoalescePolicy)) + misfire_grace_time: Optional[timedelta] = attr.field(eq=False, order=False, default=None, + converter=as_timedelta) + max_jitter: Optional[timedelta] = attr.field(eq=False, order=False, converter=as_timedelta, + default=None) + tags: frozenset[str] = attr.field(eq=False, order=False, converter=frozenset, default=()) next_fire_time: Optional[datetime] = attr.field(eq=False, order=False, default=None) last_fire_time: Optional[datetime] = attr.field(eq=False, order=False, default=None) acquired_by: Optional[str] = attr.field(eq=False, order=False, default=None) @@ -57,10 +60,6 @@ class Schedule: marshalled['trigger'] = serializer.serialize(self.trigger) marshalled['args'] = serializer.serialize(self.args) marshalled['kwargs'] = serializer.serialize(self.kwargs) - marshalled['coalesce'] = self.coalesce.name - marshalled['tags'] = list(self.tags) - marshalled['misfire_grace_time'] = (self.misfire_grace_time.total_seconds() - if self.misfire_grace_time is not None else None) if not self.acquired_by: del marshalled['acquired_by'] del marshalled['acquired_until'] @@ -72,10 +71,6 @@ class Schedule: marshalled['trigger'] = serializer.deserialize(marshalled['trigger']) marshalled['args'] = serializer.deserialize(marshalled['args']) marshalled['kwargs'] = serializer.deserialize(marshalled['kwargs']) - marshalled['tags'] = frozenset(marshalled['tags']) - if isinstance(marshalled['coalesce'], str): - marshalled['coalesce'] = CoalescePolicy.__members__[marshalled['coalesce']] - return cls(**marshalled) @property @@ -90,25 +85,32 @@ class Schedule: class Job: id: UUID = attr.field(factory=uuid4) task_id: str = attr.field(eq=False, order=False) - args: tuple = attr.field(eq=False, order=False, converter=default_if_none(())) - kwargs: dict[str, Any] = attr.field( - eq=False, order=False, converter=default_if_none(factory=dict)) + args: tuple = attr.field(eq=False, order=False, converter=tuple, default=()) + kwargs: dict[str, Any] = attr.field(eq=False, order=False, converter=dict, default=()) schedule_id: Optional[str] = attr.field(eq=False, order=False, default=None) scheduled_fire_time: Optional[datetime] = attr.field(eq=False, order=False, default=None) + jitter: timedelta = attr.field(eq=False, order=False, converter=as_timedelta, + factory=timedelta) start_deadline: Optional[datetime] = attr.field(eq=False, order=False, default=None) - tags: frozenset[str] = attr.field( - eq=False, order=False, converter=default_if_none(factory=frozenset)) + tags: frozenset[str] = attr.field(eq=False, order=False, converter=frozenset, default=()) created_at: datetime = attr.field(eq=False, order=False, factory=partial(datetime.now, timezone.utc)) started_at: Optional[datetime] = attr.field(eq=False, order=False, default=None) acquired_by: Optional[str] = attr.field(eq=False, order=False, default=None) acquired_until: Optional[datetime] = attr.field(eq=False, order=False, default=None) + @property + def original_scheduled_time(self) -> Optional[datetime]: + """The scheduled time without any jitter included.""" + if self.scheduled_fire_time is None: + return None + + return self.scheduled_fire_time - self.jitter + def marshal(self, serializer: abc.Serializer) -> dict[str, Any]: marshalled = attr.asdict(self) marshalled['args'] = serializer.serialize(self.args) marshalled['kwargs'] = serializer.serialize(self.kwargs) - marshalled['tags'] = list(self.tags) if not self.acquired_by: del marshalled['acquired_by'] del marshalled['acquired_until'] @@ -117,17 +119,15 @@ class Job: @classmethod def unmarshal(cls, serializer: abc.Serializer, marshalled: dict[str, Any]) -> Job: - for key in ('args', 'kwargs'): - marshalled[key] = serializer.deserialize(marshalled[key]) - - marshalled['tags'] = frozenset(marshalled['tags']) + marshalled['args'] = serializer.deserialize(marshalled['args']) + marshalled['kwargs'] = serializer.deserialize(marshalled['kwargs']) return cls(**marshalled) @attr.define(kw_only=True, frozen=True) class JobResult: job_id: UUID - outcome: JobOutcome = attr.field(eq=False, order=False) + outcome: JobOutcome = attr.field(eq=False, order=False, converter=as_enum(JobOutcome)) finished_at: datetime = attr.field(eq=False, order=False, factory=partial(datetime.now, timezone.utc)) exception: Optional[BaseException] = attr.field(eq=False, order=False, default=None) @@ -135,7 +135,6 @@ class JobResult: def marshal(self, serializer: abc.Serializer) -> dict[str, Any]: marshalled = attr.asdict(self) - marshalled['outcome'] = self.outcome.name if self.outcome is JobOutcome.error: marshalled['exception'] = serializer.serialize(self.exception) else: @@ -150,9 +149,6 @@ class JobResult: @classmethod def unmarshal(cls, serializer: abc.Serializer, marshalled: dict[str, Any]) -> JobResult: - if isinstance(marshalled['outcome'], str): - marshalled['outcome'] = JobOutcome.__members__[marshalled['outcome']] - if marshalled.get('exception'): marshalled['exception'] = serializer.deserialize(marshalled['exception']) elif marshalled.get('return_value'): diff --git a/tests/test_schedulers.py b/tests/test_schedulers.py index 5f1a7df..240b7bf 100644 --- a/tests/test_schedulers.py +++ b/tests/test_schedulers.py @@ -1,10 +1,14 @@ +import sys import threading import time -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone +from typing import Optional +from uuid import UUID import anyio import pytest from anyio import fail_after +from pytest_mock import MockerFixture from apscheduler.enums import JobOutcome from apscheduler.events import ( @@ -13,6 +17,12 @@ from apscheduler.exceptions import JobLookupError from apscheduler.schedulers.async_ import AsyncScheduler from apscheduler.schedulers.sync import Scheduler from apscheduler.triggers.date import DateTrigger +from apscheduler.triggers.interval import IntervalTrigger + +if sys.version_info >= (3, 9): + from zoneinfo import ZoneInfo +else: + from backports.zoneinfo import ZoneInfo pytestmark = pytest.mark.anyio @@ -83,6 +93,44 @@ class TestAsyncScheduler: # There should be no more events on the list assert not received_events + @pytest.mark.parametrize('max_jitter, expected_upper_bound', [ + pytest.param(2, 2, id='within'), + pytest.param(4, 2.999999, id='exceed') + ]) + async def test_jitter(self, mocker: MockerFixture, timezone: ZoneInfo, max_jitter: float, + expected_upper_bound: float) -> None: + def job_added_listener(event: Event) -> None: + nonlocal job_id + assert isinstance(event, JobAdded) + job_id = event.job_id + job_added_event.set() + + jitter = 1.569374 + orig_start_time = datetime.now(timezone) - timedelta(seconds=1) + fake_uniform = mocker.patch('random.uniform') + fake_uniform.configure_mock(side_effect=lambda a, b: jitter) + async with AsyncScheduler(start_worker=False) as scheduler: + trigger = IntervalTrigger(seconds=3, start_time=orig_start_time) + job_added_event = anyio.Event() + job_id: Optional[UUID] = None + scheduler.events.subscribe(job_added_listener, {JobAdded}) + schedule_id = await scheduler.add_schedule(dummy_async_job, trigger, + max_jitter=max_jitter) + schedule = await scheduler.get_schedule(schedule_id) + assert schedule.max_jitter == timedelta(seconds=max_jitter) + + # Wait for the job to be added + with fail_after(3): + await job_added_event.wait() + + fake_uniform.assert_called_once_with(0, expected_upper_bound) + + # Check that the job was created with the proper amount of jitter in its scheduled time + jobs = await scheduler.data_store.get_jobs({job_id}) + assert jobs[0].jitter == timedelta(seconds=jitter) + assert jobs[0].scheduled_fire_time == orig_start_time + timedelta(seconds=jitter) + assert jobs[0].original_scheduled_time == orig_start_time + async def test_get_job_result_success(self) -> None: async with AsyncScheduler() as scheduler: job_id = await scheduler.add_job(dummy_async_job, kwargs={'delay': 0.2}) @@ -165,6 +213,42 @@ class TestSyncScheduler: # There should be no more events on the list assert not received_events + @pytest.mark.parametrize('max_jitter, expected_upper_bound', [ + pytest.param(2, 2, id='within'), + pytest.param(4, 2.999999, id='exceed') + ]) + def test_jitter(self, mocker: MockerFixture, timezone: ZoneInfo, max_jitter: float, + expected_upper_bound: float) -> None: + def job_added_listener(event: Event) -> None: + nonlocal job_id + assert isinstance(event, JobAdded) + job_id = event.job_id + job_added_event.set() + + jitter = 1.569374 + orig_start_time = datetime.now(timezone) - timedelta(seconds=1) + fake_uniform = mocker.patch('random.uniform') + fake_uniform.configure_mock(side_effect=lambda a, b: jitter) + with Scheduler(start_worker=False) as scheduler: + trigger = IntervalTrigger(seconds=3, start_time=orig_start_time) + job_added_event = threading.Event() + job_id: Optional[UUID] = None + scheduler.events.subscribe(job_added_listener, {JobAdded}) + schedule_id = scheduler.add_schedule(dummy_async_job, trigger, max_jitter=max_jitter) + schedule = scheduler.get_schedule(schedule_id) + assert schedule.max_jitter == timedelta(seconds=max_jitter) + + # Wait for the job to be added + job_added_event.wait(3) + + fake_uniform.assert_called_once_with(0, expected_upper_bound) + + # Check that the job was created with the proper amount of jitter in its scheduled time + jobs = scheduler.data_store.get_jobs({job_id}) + assert jobs[0].jitter == timedelta(seconds=jitter) + assert jobs[0].scheduled_fire_time == orig_start_time + timedelta(seconds=jitter) + assert jobs[0].original_scheduled_time == orig_start_time + def test_get_job_result(self) -> None: with Scheduler() as scheduler: job_id = scheduler.add_job(dummy_sync_job) diff --git a/tests/test_workers.py b/tests/test_workers.py index 872cf34..f1f020e 100644 --- a/tests/test_workers.py +++ b/tests/test_workers.py @@ -77,8 +77,7 @@ class TestAsyncWorker: received_event = received_events.pop(0) assert isinstance(received_event, JobAcquired) assert received_event.job_id == job.id - assert received_event.task_id == 'task_id' - assert received_event.schedule_id is None + assert received_event.worker_id == worker.identity received_event = received_events.pop(0) if fail: @@ -100,7 +99,7 @@ class TestAsyncWorker: async def test_run_deadline_missed(self) -> None: def listener(received_event: Event): received_events.append(received_event) - if len(received_events) == 4: + if len(received_events) == 5: event.set() scheduled_start_time = datetime(2020, 9, 14, tzinfo=timezone.utc) @@ -134,13 +133,18 @@ class TestAsyncWorker: assert received_event.task_id == 'task_id' assert received_event.schedule_id == 'foo' - # Then the deadline was missed + # The worker acquired the job + received_event = received_events.pop(0) + assert isinstance(received_event, JobAcquired) + assert received_event.job_id == job.id + assert received_event.worker_id == worker.identity + + # The worker determined that the deadline has been missed received_event = received_events.pop(0) assert isinstance(received_event, JobReleased) assert received_event.outcome is JobOutcome.missed_start_deadline assert received_event.job_id == job.id - assert received_event.task_id == 'task_id' - assert received_event.schedule_id == 'foo' + assert received_event.worker_id == worker.identity # Finally, the worker was stopped received_event = received_events.pop(0) @@ -189,8 +193,7 @@ class TestSyncWorker: received_event = received_events.pop(0) assert isinstance(received_event, JobAcquired) assert received_event.job_id == job.id - assert received_event.task_id == 'task_id' - assert received_event.schedule_id is None + assert received_event.worker_id == worker.identity received_event = received_events.pop(0) if fail: @@ -212,7 +215,7 @@ class TestSyncWorker: def test_run_deadline_missed(self) -> None: def listener(worker_event: Event): received_events.append(worker_event) - if len(received_events) == 4: + if len(received_events) == 5: event.set() scheduled_start_time = datetime(2020, 9, 14, tzinfo=timezone.utc) @@ -227,7 +230,7 @@ class TestSyncWorker: scheduled_fire_time=scheduled_start_time, start_deadline=datetime(2020, 9, 14, 1, tzinfo=timezone.utc)) worker.data_store.add_job(job) - event.wait(5) + event.wait(3) # The worker was first started received_event = received_events.pop(0) @@ -245,13 +248,18 @@ class TestSyncWorker: assert received_event.task_id == 'task_id' assert received_event.schedule_id == 'foo' - # Then the deadline was missed + # The worker acquired the job + received_event = received_events.pop(0) + assert isinstance(received_event, JobAcquired) + assert received_event.job_id == job.id + assert received_event.worker_id == worker.identity + + # The worker determined that the deadline has been missed received_event = received_events.pop(0) assert isinstance(received_event, JobReleased) assert received_event.outcome is JobOutcome.missed_start_deadline assert received_event.job_id == job.id - assert received_event.task_id == 'task_id' - assert received_event.schedule_id == 'foo' + assert received_event.worker_id == worker.identity # Finally, the worker was stopped received_event = received_events.pop(0) |