diff options
author | Alex Grönholm <alex.gronholm@nextday.fi> | 2021-09-06 01:26:08 +0300 |
---|---|---|
committer | Alex Grönholm <alex.gronholm@nextday.fi> | 2021-09-06 01:39:07 +0300 |
commit | a3f75a8e4134cb2fd587423891b6814082dac83d (patch) | |
tree | 5586476b2c539e6d9164da002310d440bc3a00ea | |
parent | d5fbe437a4481bdd07085bc3658392a181d2c6a6 (diff) | |
download | apscheduler-a3f75a8e4134cb2fd587423891b6814082dac83d.tar.gz |
Migrated some more classes to attrs
-rw-r--r-- | src/apscheduler/datastores/async_/sqlalchemy.py | 39 | ||||
-rw-r--r-- | src/apscheduler/datastores/sync/mongodb.py | 44 | ||||
-rw-r--r-- | src/apscheduler/datastores/sync/sqlalchemy.py | 44 |
3 files changed, 58 insertions, 69 deletions
diff --git a/src/apscheduler/datastores/async_/sqlalchemy.py b/src/apscheduler/datastores/async_/sqlalchemy.py index b0c8afa..619b1c5 100644 --- a/src/apscheduler/datastores/async_/sqlalchemy.py +++ b/src/apscheduler/datastores/async_/sqlalchemy.py @@ -1,14 +1,15 @@ from __future__ import annotations import json -import logging from collections import defaultdict from contextlib import AsyncExitStack, closing from datetime import datetime, timedelta, timezone from json import JSONDecodeError +from logging import Logger, getLogger from typing import Any, Callable, Iterable, Optional, Tuple, Type from uuid import UUID +import attr import sniffio from anyio import TASK_STATUS_IGNORED, create_task_group, sleep from attr import asdict @@ -34,8 +35,6 @@ from ...serializers.pickle import PickleSerializer from ...structures import JobResult, Task from ...util import reentrant -logger = logging.getLogger(__name__) - def default_json_handler(obj: Any) -> Any: if isinstance(obj, datetime): @@ -61,23 +60,22 @@ def json_object_hook(obj: dict[str, Any]) -> Any: @reentrant +@attr.define(eq=False) class SQLAlchemyDataStore(AsyncDataStore): - def __init__(self, engine: AsyncEngine, *, schema: Optional[str] = None, - serializer: Optional[Serializer] = None, - lock_expiration_delay: float = 30, max_poll_time: Optional[float] = 1, - max_idle_time: float = 60, start_from_scratch: bool = False, - notify_channel: Optional[str] = 'apscheduler'): - self.engine = engine - self.schema = schema - self.serializer = serializer or PickleSerializer() - self.lock_expiration_delay = lock_expiration_delay - self.max_poll_time = max_poll_time - self.max_idle_time = max_idle_time - self.start_from_scratch = start_from_scratch - self._logger = logging.getLogger(__name__) - self._exit_stack = AsyncExitStack() - self._events = AsyncEventHub() - + engine: AsyncEngine + schema: Optional[str] = attr.field(default=None, kw_only=True) + serializer: Serializer = attr.field(factory=PickleSerializer, kw_only=True) + lock_expiration_delay: float = attr.field(default=30, kw_only=True) + max_poll_time: Optional[float] = attr.field(default=1, kw_only=True) + max_idle_time: float = attr.field(default=60, kw_only=True) + notify_channel: Optional[str] = attr.field(default='apscheduler', kw_only=True) + start_from_scratch: bool = attr.field(default=False, kw_only=True) + + _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__)) + _exit_stack: AsyncExitStack = attr.field(init=False, factory=AsyncExitStack) + _events: AsyncEventHub = attr.field(init=False, factory=AsyncEventHub) + + def __attrs_post_init__(self) -> None: # Generate the table definitions self._metadata = self.get_table_definitions() self.t_metadata = self._metadata.tables['metadata'] @@ -95,8 +93,7 @@ class SQLAlchemyDataStore(AsyncDataStore): else: self._supports_update_returning = True - self.notify_channel = notify_channel - if notify_channel: + if self.notify_channel: if self.engine.dialect.name != 'postgresql' or self.engine.dialect.driver != 'asyncpg': self.notify_channel = None diff --git a/src/apscheduler/datastores/sync/mongodb.py b/src/apscheduler/datastores/sync/mongodb.py index 50a5d15..1f1f72c 100644 --- a/src/apscheduler/datastores/sync/mongodb.py +++ b/src/apscheduler/datastores/sync/mongodb.py @@ -1,14 +1,15 @@ from __future__ import annotations -import logging from collections import defaultdict from contextlib import ExitStack from datetime import datetime, timezone +from logging import Logger, getLogger from typing import Any, Callable, ClassVar, Iterable, Optional, Tuple, Type from uuid import UUID import attr import pymongo +from attr.validators import instance_of from pymongo import ASCENDING, DeleteOne, MongoClient, UpdateOne from pymongo.collection import Collection from pymongo.errors import DuplicateKeyError @@ -27,33 +28,34 @@ from ...util import reentrant @reentrant +@attr.define(eq=False) class MongoDBDataStore(DataStore): + client: MongoClient = attr.field(validator=instance_of(MongoClient)) + serializer: Serializer = attr.field(factory=PickleSerializer, kw_only=True) + database: str = attr.field(default='apscheduler', kw_only=True) + lock_expiration_delay: float = attr.field(default=30, kw_only=True) + start_from_scratch: bool = attr.field(default=False, kw_only=True) + _task_attrs: ClassVar[list[str]] = [field.name for field in attr.fields(Task)] _schedule_attrs: ClassVar[list[str]] = [field.name for field in attr.fields(Schedule)] _job_attrs: ClassVar[list[str]] = [field.name for field in attr.fields(Job)] - def __init__(self, client: MongoClient, *, serializer: Optional[Serializer] = None, - database: str = 'apscheduler', tasks_collection: str = 'tasks', - schedules_collection: str = 'schedules', jobs_collection: str = 'jobs', - job_results_collection: str = 'job_results', - lock_expiration_delay: float = 30, start_from_scratch: bool = False): - super().__init__() - if not client.delegate.codec_options.tz_aware: + _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__)) + _exit_stack: ExitStack = attr.field(init=False, factory=ExitStack) + _events: EventHub = attr.field(init=False, factory=EventHub) + _local_tasks: dict[str, Task] = attr.field(init=False, factory=dict) + + @client.validator + def validate_client(self, attribute: attr.Attribute, value: MongoClient) -> None: + if not value.delegate.codec_options.tz_aware: raise ValueError('MongoDB client must have tz_aware set to True') - self.client = client - self.serializer = serializer or PickleSerializer() - self.lock_expiration_delay = lock_expiration_delay - self.start_from_scratch = start_from_scratch - self._local_tasks: dict[str, Task] = {} - self._database = client[database] - self._tasks: Collection = self._database[tasks_collection] - self._schedules: Collection = self._database[schedules_collection] - self._jobs: Collection = self._database[jobs_collection] - self._jobs_results: Collection = self._database[job_results_collection] - self._logger = logging.getLogger(__name__) - self._exit_stack = ExitStack() - self._events = EventHub() + def __attrs_post_init__(self) -> None: + database = self.client[self.database] + self._tasks: Collection = database['tasks'] + self._schedules: Collection = database['schedules'] + self._jobs: Collection = database['jobs'] + self._jobs_results: Collection = database['job_results'] @classmethod def from_url(cls, uri: str, **options) -> 'MongoDBDataStore': diff --git a/src/apscheduler/datastores/sync/sqlalchemy.py b/src/apscheduler/datastores/sync/sqlalchemy.py index 7c30e16..d4c1d3f 100644 --- a/src/apscheduler/datastores/sync/sqlalchemy.py +++ b/src/apscheduler/datastores/sync/sqlalchemy.py @@ -1,11 +1,12 @@ from __future__ import annotations -import logging from collections import defaultdict from datetime import datetime, timedelta, timezone +from logging import Logger, getLogger from typing import Any, Callable, Iterable, Optional, Tuple, Type from uuid import UUID +import attr from sqlalchemy import ( Column, Integer, LargeBinary, MetaData, Table, Unicode, and_, bindparam, or_, select) from sqlalchemy.engine import URL @@ -26,25 +27,23 @@ from ...serializers.pickle import PickleSerializer from ...structures import JobResult, Task from ...util import reentrant -logger = logging.getLogger(__name__) - @reentrant +@attr.define(eq=False) class SQLAlchemyDataStore(DataStore): - def __init__(self, engine: Engine, *, schema: Optional[str] = None, - serializer: Optional[Serializer] = None, - lock_expiration_delay: float = 30, max_poll_time: Optional[float] = 1, - max_idle_time: float = 60, start_from_scratch: bool = False): - self.engine = engine - self.schema = schema - self.serializer = serializer or PickleSerializer() - self.lock_expiration_delay = lock_expiration_delay - self.max_poll_time = max_poll_time - self.max_idle_time = max_idle_time - self.start_from_scratch = start_from_scratch - self._logger = logging.getLogger(__name__) - self._events = EventHub() - + engine: Engine + schema: Optional[str] = attr.field(default=None, kw_only=True) + serializer: Serializer = attr.field(factory=PickleSerializer, kw_only=True) + lock_expiration_delay: float = attr.field(default=30, kw_only=True) + max_poll_time: Optional[float] = attr.field(default=1, kw_only=True) + max_idle_time: float = attr.field(default=60, kw_only=True) + notify_channel: Optional[str] = attr.field(default='apscheduler', kw_only=True) + start_from_scratch: bool = attr.field(default=False, kw_only=True) + + _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__)) + _events: EventHub = attr.field(init=False, factory=EventHub) + + def __attrs_post_init__(self) -> None: # Generate the table definitions self._metadata = self.get_table_definitions() self.t_metadata = self._metadata.tables['metadata'] @@ -54,7 +53,7 @@ class SQLAlchemyDataStore(DataStore): self.t_job_results = self._metadata.tables['job_results'] # Find out if the dialect supports RETURNING - update = self.t_jobs.update().returning(self.t_schedules.c.id) + update = self.t_jobs.update().returning(self.t_jobs.c.id) try: update.compile(bind=self.engine) except CompileError: @@ -62,15 +61,6 @@ class SQLAlchemyDataStore(DataStore): else: self._supports_update_returning = True - # Find out if the dialect supports INSERT...ON DUPLICATE KEY UPDATE - insert = self.t_jobs.update().returning(self.t_schedules.c.id) - try: - insert.compile(bind=self.engine) - except CompileError: - self._supports_update_returning = False - else: - self._supports_update_returning = True - @classmethod def from_url(cls, url: str | URL, **options) -> 'SQLAlchemyDataStore': engine = create_engine(url) |