summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Grönholm <alex.gronholm@nextday.fi>2021-09-21 00:21:25 +0300
committerAlex Grönholm <alex.gronholm@nextday.fi>2021-09-21 00:21:25 +0300
commit8326ac378e5b5f8e5cb2c45f20e0e1bdfa5075c0 (patch)
treeac96ee8a1cbc792cf56cf3534f5e3c65f0b5a9e7
parent8b68b6c5d1c63faae1ba3769b6475b396328e3a3 (diff)
downloadapscheduler-8326ac378e5b5f8e5cb2c45f20e0e1bdfa5075c0.tar.gz
Implemented schedule-level jitter
Structures now keep enums, timedeltas and frozensets as-is. The MongoDB store was modified to use a custom type registry to handle this.
-rw-r--r--src/apscheduler/converters.py22
-rw-r--r--src/apscheduler/datastores/mongodb.py36
-rw-r--r--src/apscheduler/datastores/sqlalchemy.py24
-rw-r--r--src/apscheduler/schedulers/async_.py40
-rw-r--r--src/apscheduler/schedulers/sync.py39
-rw-r--r--src/apscheduler/structures.py56
-rw-r--r--tests/test_schedulers.py86
-rw-r--r--tests/test_workers.py34
8 files changed, 267 insertions, 70 deletions
diff --git a/src/apscheduler/converters.py b/src/apscheduler/converters.py
index 7e8e590..c664bc0 100644
--- a/src/apscheduler/converters.py
+++ b/src/apscheduler/converters.py
@@ -1,6 +1,7 @@
from __future__ import annotations
-from datetime import datetime
+from datetime import datetime, timedelta
+from enum import Enum
from typing import Optional
from uuid import UUID
@@ -18,8 +19,25 @@ def as_aware_datetime(value: datetime | str) -> Optional[datetime]:
def as_uuid(value: UUID | str) -> UUID:
- """Converts a string-formatted UUID to a UUID instance."""
+ """Convert a string-formatted UUID to a UUID instance."""
if isinstance(value, str):
return UUID(value)
return value
+
+
+def as_timedelta(value: timedelta | float | None) -> timedelta:
+ if isinstance(value, (float, int)):
+ return timedelta(seconds=value)
+
+ return value
+
+
+def as_enum(enum_class: type[Enum]):
+ def converter(value: enum_class | str):
+ if isinstance(value, str):
+ return enum_class.__members__[value]
+
+ return value
+
+ return converter
diff --git a/src/apscheduler/datastores/mongodb.py b/src/apscheduler/datastores/mongodb.py
index 5066c56..cef0f7e 100644
--- a/src/apscheduler/datastores/mongodb.py
+++ b/src/apscheduler/datastores/mongodb.py
@@ -1,21 +1,24 @@
from __future__ import annotations
+import operator
from collections import defaultdict
from contextlib import ExitStack
-from datetime import datetime, timezone
+from datetime import datetime, timedelta, timezone
from logging import Logger, getLogger
-from typing import ClassVar, Iterable, Optional
+from typing import Any, Callable, ClassVar, Iterable, Optional
from uuid import UUID
import attr
import pymongo
from attr.validators import instance_of
+from bson import CodecOptions
+from bson.codec_options import TypeEncoder, TypeRegistry
from pymongo import ASCENDING, DeleteOne, MongoClient, UpdateOne
from pymongo.collection import Collection
from pymongo.errors import DuplicateKeyError
from ..abc import DataStore, EventBroker, EventSource, Job, Schedule, Serializer
-from ..enums import ConflictPolicy
+from ..enums import CoalescePolicy, ConflictPolicy, JobOutcome
from ..eventbrokers.local import LocalEventBroker
from ..events import (
DataStoreEvent, JobAcquired, JobAdded, JobReleased, ScheduleAdded, ScheduleRemoved,
@@ -27,6 +30,19 @@ from ..structures import JobResult, Task
from ..util import reentrant
+class CustomEncoder(TypeEncoder):
+ def __init__(self, python_type: type, encoder: Callable):
+ self._python_type = python_type
+ self._encoder = encoder
+
+ @property
+ def python_type(self) -> type:
+ return self._python_type
+
+ def transform_python(self, value: Any) -> Any:
+ return self._encoder(value)
+
+
@reentrant
@attr.define(eq=False)
class MongoDBDataStore(DataStore):
@@ -45,13 +61,15 @@ class MongoDBDataStore(DataStore):
_events: EventBroker = attr.field(init=False, factory=LocalEventBroker)
_local_tasks: dict[str, Task] = attr.field(init=False, factory=dict)
- @client.validator
- def validate_client(self, attribute: attr.Attribute, value: MongoClient) -> None:
- if not value.delegate.codec_options.tz_aware:
- raise ValueError('MongoDB client must have tz_aware set to True')
-
def __attrs_post_init__(self) -> None:
- database = self.client[self.database]
+ type_registry = TypeRegistry([
+ CustomEncoder(timedelta, timedelta.total_seconds),
+ CustomEncoder(ConflictPolicy, operator.attrgetter('name')),
+ CustomEncoder(CoalescePolicy, operator.attrgetter('name')),
+ CustomEncoder(JobOutcome, operator.attrgetter('name'))
+ ])
+ codec_options = CodecOptions(tz_aware=True, type_registry=type_registry)
+ database = self.client.get_database(self.database, codec_options=codec_options)
self._tasks: Collection = database['tasks']
self._schedules: Collection = database['schedules']
self._jobs: Collection = database['jobs']
diff --git a/src/apscheduler/datastores/sqlalchemy.py b/src/apscheduler/datastores/sqlalchemy.py
index 31e60cc..0bfc211 100644
--- a/src/apscheduler/datastores/sqlalchemy.py
+++ b/src/apscheduler/datastores/sqlalchemy.py
@@ -8,8 +8,8 @@ from uuid import UUID
import attr
from sqlalchemy import (
- JSON, TIMESTAMP, Column, Enum, Integer, LargeBinary, MetaData, Table, TypeDecorator, Unicode,
- and_, bindparam, or_, select)
+ JSON, TIMESTAMP, BigInteger, Column, Enum, Integer, LargeBinary, MetaData, Table,
+ TypeDecorator, Unicode, and_, bindparam, or_, select)
from sqlalchemy.engine import URL, Dialect, Result
from sqlalchemy.exc import CompileError, IntegrityError
from sqlalchemy.future import Engine, create_engine
@@ -52,6 +52,17 @@ class EmulatedTimestampTZ(TypeDecorator):
return datetime.fromisoformat(value) if value is not None else None
+class EmulatedInterval(TypeDecorator):
+ impl = BigInteger()
+ cache_ok = True
+
+ def process_bind_param(self, value, dialect: Dialect) -> Any:
+ return value.total_seconds() if value is not None else None
+
+ def process_result_value(self, value: Any, dialect: Dialect):
+ return timedelta(seconds=value) if value is not None else None
+
+
@attr.define(kw_only=True, eq=False)
class _BaseSQLAlchemyDataStore:
schema: Optional[str] = attr.field(default=None)
@@ -88,9 +99,11 @@ class _BaseSQLAlchemyDataStore:
timestamp_type = TIMESTAMP(timezone=True)
job_id_type = postgresql.UUID(as_uuid=True)
+ interval_type = postgresql.INTERVAL(precision=6)
else:
timestamp_type = EmulatedTimestampTZ
job_id_type = EmulatedUUID
+ interval_type = EmulatedInterval
metadata = MetaData()
Table(
@@ -105,7 +118,7 @@ class _BaseSQLAlchemyDataStore:
Column('func', Unicode(500), nullable=False),
Column('state', LargeBinary),
Column('max_running_jobs', Integer),
- Column('misfire_grace_time', Unicode(16)),
+ Column('misfire_grace_time', interval_type),
Column('running_jobs', Integer, nullable=False, server_default=literal(0))
)
Table(
@@ -117,8 +130,8 @@ class _BaseSQLAlchemyDataStore:
Column('args', LargeBinary),
Column('kwargs', LargeBinary),
Column('coalesce', Enum(CoalescePolicy), nullable=False),
- Column('misfire_grace_time', Unicode(16)),
- # Column('max_jitter', Unicode(16)),
+ Column('misfire_grace_time', interval_type),
+ Column('max_jitter', interval_type),
Column('tags', JSON, nullable=False),
Column('next_fire_time', timestamp_type, index=True),
Column('last_fire_time', timestamp_type),
@@ -134,6 +147,7 @@ class _BaseSQLAlchemyDataStore:
Column('kwargs', LargeBinary, nullable=False),
Column('schedule_id', Unicode(500)),
Column('scheduled_fire_time', timestamp_type),
+ Column('jitter', interval_type),
Column('start_deadline', timestamp_type),
Column('tags', JSON, nullable=False),
Column('created_at', timestamp_type, nullable=False),
diff --git a/src/apscheduler/schedulers/async_.py b/src/apscheduler/schedulers/async_.py
index 6900460..a08c090 100644
--- a/src/apscheduler/schedulers/async_.py
+++ b/src/apscheduler/schedulers/async_.py
@@ -2,6 +2,7 @@ from __future__ import annotations
import os
import platform
+import random
from contextlib import AsyncExitStack
from datetime import datetime, timedelta, timezone
from logging import Logger, getLogger
@@ -24,6 +25,9 @@ from ..marshalling import callable_to_ref
from ..structures import JobResult, Task
from ..workers.async_ import AsyncWorker
+_microsecond_delta = timedelta(microseconds=1)
+_zero_timedelta = timedelta()
+
class AsyncScheduler:
"""An asynchronous (AnyIO based) scheduler implementation."""
@@ -96,7 +100,8 @@ class AsyncScheduler:
self, func_or_task_id: str | Callable, trigger: Trigger, *, id: Optional[str] = None,
args: Optional[Iterable] = None, kwargs: Optional[Mapping[str, Any]] = None,
coalesce: CoalescePolicy = CoalescePolicy.latest,
- misfire_grace_time: float | timedelta | None = None, tags: Optional[Iterable[str]] = None,
+ misfire_grace_time: float | timedelta | None = None,
+ max_jitter: float | timedelta | None = None, tags: Optional[Iterable[str]] = None,
conflict_policy: ConflictPolicy = ConflictPolicy.do_nothing
) -> str:
id = id or str(uuid4())
@@ -113,13 +118,18 @@ class AsyncScheduler:
task = await self.data_store.get_task(func_or_task_id)
schedule = Schedule(id=id, task_id=task.id, trigger=trigger, args=args, kwargs=kwargs,
- coalesce=coalesce, misfire_grace_time=misfire_grace_time, tags=tags)
+ coalesce=coalesce, misfire_grace_time=misfire_grace_time,
+ max_jitter=max_jitter, tags=tags)
schedule.next_fire_time = trigger.next()
await self.data_store.add_schedule(schedule, conflict_policy)
self.logger.info('Added new schedule (task=%r, trigger=%r); next run time at %s', task,
trigger, schedule.next_fire_time)
return schedule.id
+ async def get_schedule(self, id: str) -> Schedule:
+ schedules = await self.data_store.get_schedules({id})
+ return schedules[0]
+
async def remove_schedule(self, schedule_id: str) -> None:
await self.data_store.remove_schedules({schedule_id})
@@ -143,7 +153,7 @@ class AsyncScheduler:
else:
task = await self.data_store.get_task(func_or_task_id)
- job = Job(task_id=task.id, args=args, kwargs=kwargs, tags=tags)
+ job = Job(task_id=task.id, args=args or (), kwargs=kwargs or {}, tags=tags or frozenset())
await self.data_store.add_job(job)
return job.id
@@ -248,11 +258,31 @@ class AsyncScheduler:
fire_times[0] = fire_time
# Add one or more jobs to the job queue
- for fire_time in fire_times:
+ max_jitter = schedule.max_jitter.total_seconds() if schedule.max_jitter else 0
+ for i, fire_time in enumerate(fire_times):
+ # Calculate a jitter if max_jitter > 0
+ jitter = _zero_timedelta
+ if max_jitter:
+ if i + 1 < len(fire_times):
+ next_fire_time = fire_times[i + 1]
+ else:
+ next_fire_time = schedule.next_fire_time
+
+ if next_fire_time is not None:
+ # Jitter must never be so high that it would cause a fire time to
+ # equal or exceed the next fire time
+ jitter_s = min([
+ max_jitter,
+ (next_fire_time - fire_time
+ - _microsecond_delta).total_seconds()
+ ])
+ jitter = timedelta(seconds=random.uniform(0, jitter_s))
+ fire_time += jitter
+
schedule.last_fire_time = fire_time
job = Job(task_id=schedule.task_id, args=schedule.args,
kwargs=schedule.kwargs, schedule_id=schedule.id,
- scheduled_fire_time=fire_time,
+ scheduled_fire_time=fire_time, jitter=jitter,
start_deadline=schedule.next_deadline, tags=schedule.tags)
await self.data_store.add_job(job)
diff --git a/src/apscheduler/schedulers/sync.py b/src/apscheduler/schedulers/sync.py
index 1525bea..905efac 100644
--- a/src/apscheduler/schedulers/sync.py
+++ b/src/apscheduler/schedulers/sync.py
@@ -2,6 +2,7 @@ from __future__ import annotations
import os
import platform
+import random
import threading
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
from contextlib import ExitStack
@@ -21,6 +22,9 @@ from ..marshalling import callable_to_ref
from ..structures import Job, JobResult, Schedule, Task
from ..workers.sync import Worker
+_microsecond_delta = timedelta(microseconds=1)
+_zero_timedelta = timedelta()
+
class Scheduler:
"""A synchronous scheduler implementation."""
@@ -96,7 +100,7 @@ class Scheduler:
args: Optional[Iterable] = None, kwargs: Optional[Mapping[str, Any]] = None,
coalesce: CoalescePolicy = CoalescePolicy.latest,
misfire_grace_time: float | timedelta | None = None,
- tags: Optional[Iterable[str]] = None,
+ max_jitter: float | timedelta | None = None, tags: Optional[Iterable[str]] = None,
conflict_policy: ConflictPolicy = ConflictPolicy.do_nothing
) -> str:
id = id or str(uuid4())
@@ -113,13 +117,18 @@ class Scheduler:
task = self.data_store.get_task(func_or_task_id)
schedule = Schedule(id=id, task_id=task.id, trigger=trigger, args=args, kwargs=kwargs,
- coalesce=coalesce, misfire_grace_time=misfire_grace_time, tags=tags)
+ coalesce=coalesce, misfire_grace_time=misfire_grace_time,
+ max_jitter=max_jitter, tags=tags)
schedule.next_fire_time = trigger.next()
self.data_store.add_schedule(schedule, conflict_policy)
self.logger.info('Added new schedule (task=%r, trigger=%r); next run time at %s', task,
trigger, schedule.next_fire_time)
return schedule.id
+ def get_schedule(self, id: str) -> Schedule:
+ schedules = self.data_store.get_schedules({id})
+ return schedules[0]
+
def remove_schedule(self, schedule_id: str) -> None:
self.data_store.remove_schedules({schedule_id})
@@ -143,7 +152,7 @@ class Scheduler:
else:
task = self.data_store.get_task(func_or_task_id)
- job = Job(task_id=task.id, args=args, kwargs=kwargs, tags=tags)
+ job = Job(task_id=task.id, args=args or (), kwargs=kwargs or {}, tags=tags or frozenset())
self.data_store.add_job(job)
return job.id
@@ -247,11 +256,31 @@ class Scheduler:
fire_times[0] = fire_time
# Add one or more jobs to the job queue
- for fire_time in fire_times:
+ max_jitter = schedule.max_jitter.total_seconds() if schedule.max_jitter else 0
+ for i, fire_time in enumerate(fire_times):
+ # Calculate a jitter if max_jitter > 0
+ jitter = _zero_timedelta
+ if max_jitter:
+ if i + 1 < len(fire_times):
+ next_fire_time = fire_times[i + 1]
+ else:
+ next_fire_time = schedule.next_fire_time
+
+ if next_fire_time is not None:
+ # Jitter must never be so high that it would cause a fire time to
+ # equal or exceed the next fire time
+ jitter_s = min([
+ max_jitter,
+ (next_fire_time - fire_time
+ - _microsecond_delta).total_seconds()
+ ])
+ jitter = timedelta(seconds=random.uniform(0, jitter_s))
+ fire_time += jitter
+
schedule.last_fire_time = fire_time
job = Job(task_id=schedule.task_id, args=schedule.args,
kwargs=schedule.kwargs, schedule_id=schedule.id,
- scheduled_fire_time=fire_time,
+ scheduled_fire_time=fire_time, jitter=jitter,
start_deadline=schedule.next_deadline, tags=schedule.tags)
self.data_store.add_job(job)
diff --git a/src/apscheduler/structures.py b/src/apscheduler/structures.py
index c805ef0..a1e653b 100644
--- a/src/apscheduler/structures.py
+++ b/src/apscheduler/structures.py
@@ -6,9 +6,9 @@ from typing import Any, Callable, Optional
from uuid import UUID, uuid4
import attr
-from attr.converters import default_if_none
from . import abc
+from .converters import as_enum, as_timedelta
from .enums import CoalescePolicy, JobOutcome
from .marshalling import callable_from_ref, callable_to_ref
@@ -41,12 +41,15 @@ class Schedule:
id: str
task_id: str = attr.field(eq=False, order=False)
trigger: abc.Trigger = attr.field(eq=False, order=False)
- args: tuple = attr.field(eq=False, order=False, default=())
- kwargs: dict[str, Any] = attr.field(eq=False, order=False, factory=dict)
- coalesce: CoalescePolicy = attr.field(eq=False, order=False, default=CoalescePolicy.latest)
- misfire_grace_time: Optional[timedelta] = attr.field(eq=False, order=False, default=None)
- # max_jitter: Optional[timedelta] = attr.field(eq=False, order=False, default=None)
- tags: frozenset[str] = attr.field(eq=False, order=False, factory=frozenset)
+ args: tuple = attr.field(eq=False, order=False, converter=tuple, default=())
+ kwargs: dict[str, Any] = attr.field(eq=False, order=False, converter=dict, default=())
+ coalesce: CoalescePolicy = attr.field(eq=False, order=False, default=CoalescePolicy.latest,
+ converter=as_enum(CoalescePolicy))
+ misfire_grace_time: Optional[timedelta] = attr.field(eq=False, order=False, default=None,
+ converter=as_timedelta)
+ max_jitter: Optional[timedelta] = attr.field(eq=False, order=False, converter=as_timedelta,
+ default=None)
+ tags: frozenset[str] = attr.field(eq=False, order=False, converter=frozenset, default=())
next_fire_time: Optional[datetime] = attr.field(eq=False, order=False, default=None)
last_fire_time: Optional[datetime] = attr.field(eq=False, order=False, default=None)
acquired_by: Optional[str] = attr.field(eq=False, order=False, default=None)
@@ -57,10 +60,6 @@ class Schedule:
marshalled['trigger'] = serializer.serialize(self.trigger)
marshalled['args'] = serializer.serialize(self.args)
marshalled['kwargs'] = serializer.serialize(self.kwargs)
- marshalled['coalesce'] = self.coalesce.name
- marshalled['tags'] = list(self.tags)
- marshalled['misfire_grace_time'] = (self.misfire_grace_time.total_seconds()
- if self.misfire_grace_time is not None else None)
if not self.acquired_by:
del marshalled['acquired_by']
del marshalled['acquired_until']
@@ -72,10 +71,6 @@ class Schedule:
marshalled['trigger'] = serializer.deserialize(marshalled['trigger'])
marshalled['args'] = serializer.deserialize(marshalled['args'])
marshalled['kwargs'] = serializer.deserialize(marshalled['kwargs'])
- marshalled['tags'] = frozenset(marshalled['tags'])
- if isinstance(marshalled['coalesce'], str):
- marshalled['coalesce'] = CoalescePolicy.__members__[marshalled['coalesce']]
-
return cls(**marshalled)
@property
@@ -90,25 +85,32 @@ class Schedule:
class Job:
id: UUID = attr.field(factory=uuid4)
task_id: str = attr.field(eq=False, order=False)
- args: tuple = attr.field(eq=False, order=False, converter=default_if_none(()))
- kwargs: dict[str, Any] = attr.field(
- eq=False, order=False, converter=default_if_none(factory=dict))
+ args: tuple = attr.field(eq=False, order=False, converter=tuple, default=())
+ kwargs: dict[str, Any] = attr.field(eq=False, order=False, converter=dict, default=())
schedule_id: Optional[str] = attr.field(eq=False, order=False, default=None)
scheduled_fire_time: Optional[datetime] = attr.field(eq=False, order=False, default=None)
+ jitter: timedelta = attr.field(eq=False, order=False, converter=as_timedelta,
+ factory=timedelta)
start_deadline: Optional[datetime] = attr.field(eq=False, order=False, default=None)
- tags: frozenset[str] = attr.field(
- eq=False, order=False, converter=default_if_none(factory=frozenset))
+ tags: frozenset[str] = attr.field(eq=False, order=False, converter=frozenset, default=())
created_at: datetime = attr.field(eq=False, order=False,
factory=partial(datetime.now, timezone.utc))
started_at: Optional[datetime] = attr.field(eq=False, order=False, default=None)
acquired_by: Optional[str] = attr.field(eq=False, order=False, default=None)
acquired_until: Optional[datetime] = attr.field(eq=False, order=False, default=None)
+ @property
+ def original_scheduled_time(self) -> Optional[datetime]:
+ """The scheduled time without any jitter included."""
+ if self.scheduled_fire_time is None:
+ return None
+
+ return self.scheduled_fire_time - self.jitter
+
def marshal(self, serializer: abc.Serializer) -> dict[str, Any]:
marshalled = attr.asdict(self)
marshalled['args'] = serializer.serialize(self.args)
marshalled['kwargs'] = serializer.serialize(self.kwargs)
- marshalled['tags'] = list(self.tags)
if not self.acquired_by:
del marshalled['acquired_by']
del marshalled['acquired_until']
@@ -117,17 +119,15 @@ class Job:
@classmethod
def unmarshal(cls, serializer: abc.Serializer, marshalled: dict[str, Any]) -> Job:
- for key in ('args', 'kwargs'):
- marshalled[key] = serializer.deserialize(marshalled[key])
-
- marshalled['tags'] = frozenset(marshalled['tags'])
+ marshalled['args'] = serializer.deserialize(marshalled['args'])
+ marshalled['kwargs'] = serializer.deserialize(marshalled['kwargs'])
return cls(**marshalled)
@attr.define(kw_only=True, frozen=True)
class JobResult:
job_id: UUID
- outcome: JobOutcome = attr.field(eq=False, order=False)
+ outcome: JobOutcome = attr.field(eq=False, order=False, converter=as_enum(JobOutcome))
finished_at: datetime = attr.field(eq=False, order=False,
factory=partial(datetime.now, timezone.utc))
exception: Optional[BaseException] = attr.field(eq=False, order=False, default=None)
@@ -135,7 +135,6 @@ class JobResult:
def marshal(self, serializer: abc.Serializer) -> dict[str, Any]:
marshalled = attr.asdict(self)
- marshalled['outcome'] = self.outcome.name
if self.outcome is JobOutcome.error:
marshalled['exception'] = serializer.serialize(self.exception)
else:
@@ -150,9 +149,6 @@ class JobResult:
@classmethod
def unmarshal(cls, serializer: abc.Serializer, marshalled: dict[str, Any]) -> JobResult:
- if isinstance(marshalled['outcome'], str):
- marshalled['outcome'] = JobOutcome.__members__[marshalled['outcome']]
-
if marshalled.get('exception'):
marshalled['exception'] = serializer.deserialize(marshalled['exception'])
elif marshalled.get('return_value'):
diff --git a/tests/test_schedulers.py b/tests/test_schedulers.py
index 5f1a7df..240b7bf 100644
--- a/tests/test_schedulers.py
+++ b/tests/test_schedulers.py
@@ -1,10 +1,14 @@
+import sys
import threading
import time
-from datetime import datetime, timezone
+from datetime import datetime, timedelta, timezone
+from typing import Optional
+from uuid import UUID
import anyio
import pytest
from anyio import fail_after
+from pytest_mock import MockerFixture
from apscheduler.enums import JobOutcome
from apscheduler.events import (
@@ -13,6 +17,12 @@ from apscheduler.exceptions import JobLookupError
from apscheduler.schedulers.async_ import AsyncScheduler
from apscheduler.schedulers.sync import Scheduler
from apscheduler.triggers.date import DateTrigger
+from apscheduler.triggers.interval import IntervalTrigger
+
+if sys.version_info >= (3, 9):
+ from zoneinfo import ZoneInfo
+else:
+ from backports.zoneinfo import ZoneInfo
pytestmark = pytest.mark.anyio
@@ -83,6 +93,44 @@ class TestAsyncScheduler:
# There should be no more events on the list
assert not received_events
+ @pytest.mark.parametrize('max_jitter, expected_upper_bound', [
+ pytest.param(2, 2, id='within'),
+ pytest.param(4, 2.999999, id='exceed')
+ ])
+ async def test_jitter(self, mocker: MockerFixture, timezone: ZoneInfo, max_jitter: float,
+ expected_upper_bound: float) -> None:
+ def job_added_listener(event: Event) -> None:
+ nonlocal job_id
+ assert isinstance(event, JobAdded)
+ job_id = event.job_id
+ job_added_event.set()
+
+ jitter = 1.569374
+ orig_start_time = datetime.now(timezone) - timedelta(seconds=1)
+ fake_uniform = mocker.patch('random.uniform')
+ fake_uniform.configure_mock(side_effect=lambda a, b: jitter)
+ async with AsyncScheduler(start_worker=False) as scheduler:
+ trigger = IntervalTrigger(seconds=3, start_time=orig_start_time)
+ job_added_event = anyio.Event()
+ job_id: Optional[UUID] = None
+ scheduler.events.subscribe(job_added_listener, {JobAdded})
+ schedule_id = await scheduler.add_schedule(dummy_async_job, trigger,
+ max_jitter=max_jitter)
+ schedule = await scheduler.get_schedule(schedule_id)
+ assert schedule.max_jitter == timedelta(seconds=max_jitter)
+
+ # Wait for the job to be added
+ with fail_after(3):
+ await job_added_event.wait()
+
+ fake_uniform.assert_called_once_with(0, expected_upper_bound)
+
+ # Check that the job was created with the proper amount of jitter in its scheduled time
+ jobs = await scheduler.data_store.get_jobs({job_id})
+ assert jobs[0].jitter == timedelta(seconds=jitter)
+ assert jobs[0].scheduled_fire_time == orig_start_time + timedelta(seconds=jitter)
+ assert jobs[0].original_scheduled_time == orig_start_time
+
async def test_get_job_result_success(self) -> None:
async with AsyncScheduler() as scheduler:
job_id = await scheduler.add_job(dummy_async_job, kwargs={'delay': 0.2})
@@ -165,6 +213,42 @@ class TestSyncScheduler:
# There should be no more events on the list
assert not received_events
+ @pytest.mark.parametrize('max_jitter, expected_upper_bound', [
+ pytest.param(2, 2, id='within'),
+ pytest.param(4, 2.999999, id='exceed')
+ ])
+ def test_jitter(self, mocker: MockerFixture, timezone: ZoneInfo, max_jitter: float,
+ expected_upper_bound: float) -> None:
+ def job_added_listener(event: Event) -> None:
+ nonlocal job_id
+ assert isinstance(event, JobAdded)
+ job_id = event.job_id
+ job_added_event.set()
+
+ jitter = 1.569374
+ orig_start_time = datetime.now(timezone) - timedelta(seconds=1)
+ fake_uniform = mocker.patch('random.uniform')
+ fake_uniform.configure_mock(side_effect=lambda a, b: jitter)
+ with Scheduler(start_worker=False) as scheduler:
+ trigger = IntervalTrigger(seconds=3, start_time=orig_start_time)
+ job_added_event = threading.Event()
+ job_id: Optional[UUID] = None
+ scheduler.events.subscribe(job_added_listener, {JobAdded})
+ schedule_id = scheduler.add_schedule(dummy_async_job, trigger, max_jitter=max_jitter)
+ schedule = scheduler.get_schedule(schedule_id)
+ assert schedule.max_jitter == timedelta(seconds=max_jitter)
+
+ # Wait for the job to be added
+ job_added_event.wait(3)
+
+ fake_uniform.assert_called_once_with(0, expected_upper_bound)
+
+ # Check that the job was created with the proper amount of jitter in its scheduled time
+ jobs = scheduler.data_store.get_jobs({job_id})
+ assert jobs[0].jitter == timedelta(seconds=jitter)
+ assert jobs[0].scheduled_fire_time == orig_start_time + timedelta(seconds=jitter)
+ assert jobs[0].original_scheduled_time == orig_start_time
+
def test_get_job_result(self) -> None:
with Scheduler() as scheduler:
job_id = scheduler.add_job(dummy_sync_job)
diff --git a/tests/test_workers.py b/tests/test_workers.py
index 872cf34..f1f020e 100644
--- a/tests/test_workers.py
+++ b/tests/test_workers.py
@@ -77,8 +77,7 @@ class TestAsyncWorker:
received_event = received_events.pop(0)
assert isinstance(received_event, JobAcquired)
assert received_event.job_id == job.id
- assert received_event.task_id == 'task_id'
- assert received_event.schedule_id is None
+ assert received_event.worker_id == worker.identity
received_event = received_events.pop(0)
if fail:
@@ -100,7 +99,7 @@ class TestAsyncWorker:
async def test_run_deadline_missed(self) -> None:
def listener(received_event: Event):
received_events.append(received_event)
- if len(received_events) == 4:
+ if len(received_events) == 5:
event.set()
scheduled_start_time = datetime(2020, 9, 14, tzinfo=timezone.utc)
@@ -134,13 +133,18 @@ class TestAsyncWorker:
assert received_event.task_id == 'task_id'
assert received_event.schedule_id == 'foo'
- # Then the deadline was missed
+ # The worker acquired the job
+ received_event = received_events.pop(0)
+ assert isinstance(received_event, JobAcquired)
+ assert received_event.job_id == job.id
+ assert received_event.worker_id == worker.identity
+
+ # The worker determined that the deadline has been missed
received_event = received_events.pop(0)
assert isinstance(received_event, JobReleased)
assert received_event.outcome is JobOutcome.missed_start_deadline
assert received_event.job_id == job.id
- assert received_event.task_id == 'task_id'
- assert received_event.schedule_id == 'foo'
+ assert received_event.worker_id == worker.identity
# Finally, the worker was stopped
received_event = received_events.pop(0)
@@ -189,8 +193,7 @@ class TestSyncWorker:
received_event = received_events.pop(0)
assert isinstance(received_event, JobAcquired)
assert received_event.job_id == job.id
- assert received_event.task_id == 'task_id'
- assert received_event.schedule_id is None
+ assert received_event.worker_id == worker.identity
received_event = received_events.pop(0)
if fail:
@@ -212,7 +215,7 @@ class TestSyncWorker:
def test_run_deadline_missed(self) -> None:
def listener(worker_event: Event):
received_events.append(worker_event)
- if len(received_events) == 4:
+ if len(received_events) == 5:
event.set()
scheduled_start_time = datetime(2020, 9, 14, tzinfo=timezone.utc)
@@ -227,7 +230,7 @@ class TestSyncWorker:
scheduled_fire_time=scheduled_start_time,
start_deadline=datetime(2020, 9, 14, 1, tzinfo=timezone.utc))
worker.data_store.add_job(job)
- event.wait(5)
+ event.wait(3)
# The worker was first started
received_event = received_events.pop(0)
@@ -245,13 +248,18 @@ class TestSyncWorker:
assert received_event.task_id == 'task_id'
assert received_event.schedule_id == 'foo'
- # Then the deadline was missed
+ # The worker acquired the job
+ received_event = received_events.pop(0)
+ assert isinstance(received_event, JobAcquired)
+ assert received_event.job_id == job.id
+ assert received_event.worker_id == worker.identity
+
+ # The worker determined that the deadline has been missed
received_event = received_events.pop(0)
assert isinstance(received_event, JobReleased)
assert received_event.outcome is JobOutcome.missed_start_deadline
assert received_event.job_id == job.id
- assert received_event.task_id == 'task_id'
- assert received_event.schedule_id == 'foo'
+ assert received_event.worker_id == worker.identity
# Finally, the worker was stopped
received_event = received_events.pop(0)