author     Alex Grönholm <alex.gronholm@nextday.fi>  2021-09-06 01:26:08 +0300
committer  Alex Grönholm <alex.gronholm@nextday.fi>  2021-09-06 01:39:07 +0300
commit     a3f75a8e4134cb2fd587423891b6814082dac83d (patch)
tree       5586476b2c539e6d9164da002310d440bc3a00ea
parent     d5fbe437a4481bdd07085bc3658392a181d2c6a6 (diff)
download   apscheduler-a3f75a8e4134cb2fd587423891b6814082dac83d.tar.gz
Migrated some more classes to attrs
-rw-r--r--  src/apscheduler/datastores/async_/sqlalchemy.py  39
-rw-r--r--  src/apscheduler/datastores/sync/mongodb.py        44
-rw-r--r--  src/apscheduler/datastores/sync/sqlalchemy.py     44
3 files changed, 58 insertions, 69 deletions
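
The diffs below apply one pattern to all three data stores: the hand-written __init__ is dropped, constructor options become attrs fields (keyword-only, with defaults or factories), private state becomes init=False fields, argument checks move to @<field>.validator methods, and the remaining setup runs in __attrs_post_init__. As a rough, self-contained sketch of that pattern (the class and field names here are illustrative only, not part of APScheduler):

    from logging import Logger, getLogger
    from typing import Optional

    import attr


    @attr.define(eq=False)
    class ExampleStore:
        # required positional argument, validated below
        url: str = attr.field()
        # configuration options: keyword-only, with defaults
        schema: Optional[str] = attr.field(default=None, kw_only=True)
        lock_expiration_delay: float = attr.field(default=30, kw_only=True)

        # private state: never accepted by __init__, built by a factory per instance
        _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__))

        @url.validator
        def _validate_url(self, attribute: attr.Attribute, value: str) -> None:
            # replaces an argument check that used to sit at the top of __init__
            if not value:
                raise ValueError('url must not be empty')

        def __attrs_post_init__(self) -> None:
            # derived setup that used to follow the assignments in __init__
            self._logger.debug('store configured with schema %r', self.schema)

Note that eq=False keeps the default identity-based __eq__ and __hash__ rather than the value comparison attrs would otherwise generate.
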
diff --git a/src/apscheduler/datastores/async_/sqlalchemy.py b/src/apscheduler/datastores/async_/sqlalchemy.py
index b0c8afa..619b1c5 100644
--- a/src/apscheduler/datastores/async_/sqlalchemy.py
+++ b/src/apscheduler/datastores/async_/sqlalchemy.py
@@ -1,14 +1,15 @@
from __future__ import annotations
import json
-import logging
from collections import defaultdict
from contextlib import AsyncExitStack, closing
from datetime import datetime, timedelta, timezone
from json import JSONDecodeError
+from logging import Logger, getLogger
from typing import Any, Callable, Iterable, Optional, Tuple, Type
from uuid import UUID
+import attr
import sniffio
from anyio import TASK_STATUS_IGNORED, create_task_group, sleep
from attr import asdict
@@ -34,8 +35,6 @@ from ...serializers.pickle import PickleSerializer
from ...structures import JobResult, Task
from ...util import reentrant
-logger = logging.getLogger(__name__)
-
def default_json_handler(obj: Any) -> Any:
    if isinstance(obj, datetime):
@@ -61,23 +60,22 @@ def json_object_hook(obj: dict[str, Any]) -> Any:
@reentrant
+@attr.define(eq=False)
class SQLAlchemyDataStore(AsyncDataStore):
-    def __init__(self, engine: AsyncEngine, *, schema: Optional[str] = None,
-                 serializer: Optional[Serializer] = None,
-                 lock_expiration_delay: float = 30, max_poll_time: Optional[float] = 1,
-                 max_idle_time: float = 60, start_from_scratch: bool = False,
-                 notify_channel: Optional[str] = 'apscheduler'):
-        self.engine = engine
-        self.schema = schema
-        self.serializer = serializer or PickleSerializer()
-        self.lock_expiration_delay = lock_expiration_delay
-        self.max_poll_time = max_poll_time
-        self.max_idle_time = max_idle_time
-        self.start_from_scratch = start_from_scratch
-        self._logger = logging.getLogger(__name__)
-        self._exit_stack = AsyncExitStack()
-        self._events = AsyncEventHub()
-
+    engine: AsyncEngine
+    schema: Optional[str] = attr.field(default=None, kw_only=True)
+    serializer: Serializer = attr.field(factory=PickleSerializer, kw_only=True)
+    lock_expiration_delay: float = attr.field(default=30, kw_only=True)
+    max_poll_time: Optional[float] = attr.field(default=1, kw_only=True)
+    max_idle_time: float = attr.field(default=60, kw_only=True)
+    notify_channel: Optional[str] = attr.field(default='apscheduler', kw_only=True)
+    start_from_scratch: bool = attr.field(default=False, kw_only=True)
+
+    _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__))
+    _exit_stack: AsyncExitStack = attr.field(init=False, factory=AsyncExitStack)
+    _events: AsyncEventHub = attr.field(init=False, factory=AsyncEventHub)
+
+    def __attrs_post_init__(self) -> None:
        # Generate the table definitions
        self._metadata = self.get_table_definitions()
        self.t_metadata = self._metadata.tables['metadata']
@@ -95,8 +93,7 @@ class SQLAlchemyDataStore(AsyncDataStore):
        else:
            self._supports_update_returning = True
-        self.notify_channel = notify_channel
-        if notify_channel:
+        if self.notify_channel:
            if self.engine.dialect.name != 'postgresql' or self.engine.dialect.driver != 'asyncpg':
                self.notify_channel = None
diff --git a/src/apscheduler/datastores/sync/mongodb.py b/src/apscheduler/datastores/sync/mongodb.py
index 50a5d15..1f1f72c 100644
--- a/src/apscheduler/datastores/sync/mongodb.py
+++ b/src/apscheduler/datastores/sync/mongodb.py
@@ -1,14 +1,15 @@
from __future__ import annotations
-import logging
from collections import defaultdict
from contextlib import ExitStack
from datetime import datetime, timezone
+from logging import Logger, getLogger
from typing import Any, Callable, ClassVar, Iterable, Optional, Tuple, Type
from uuid import UUID
import attr
import pymongo
+from attr.validators import instance_of
from pymongo import ASCENDING, DeleteOne, MongoClient, UpdateOne
from pymongo.collection import Collection
from pymongo.errors import DuplicateKeyError
@@ -27,33 +28,34 @@ from ...util import reentrant
@reentrant
+@attr.define(eq=False)
class MongoDBDataStore(DataStore):
+    client: MongoClient = attr.field(validator=instance_of(MongoClient))
+    serializer: Serializer = attr.field(factory=PickleSerializer, kw_only=True)
+    database: str = attr.field(default='apscheduler', kw_only=True)
+    lock_expiration_delay: float = attr.field(default=30, kw_only=True)
+    start_from_scratch: bool = attr.field(default=False, kw_only=True)
+
    _task_attrs: ClassVar[list[str]] = [field.name for field in attr.fields(Task)]
    _schedule_attrs: ClassVar[list[str]] = [field.name for field in attr.fields(Schedule)]
    _job_attrs: ClassVar[list[str]] = [field.name for field in attr.fields(Job)]
-    def __init__(self, client: MongoClient, *, serializer: Optional[Serializer] = None,
-                 database: str = 'apscheduler', tasks_collection: str = 'tasks',
-                 schedules_collection: str = 'schedules', jobs_collection: str = 'jobs',
-                 job_results_collection: str = 'job_results',
-                 lock_expiration_delay: float = 30, start_from_scratch: bool = False):
-        super().__init__()
-        if not client.delegate.codec_options.tz_aware:
+    _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__))
+    _exit_stack: ExitStack = attr.field(init=False, factory=ExitStack)
+    _events: EventHub = attr.field(init=False, factory=EventHub)
+    _local_tasks: dict[str, Task] = attr.field(init=False, factory=dict)
+
+    @client.validator
+    def validate_client(self, attribute: attr.Attribute, value: MongoClient) -> None:
+        if not value.delegate.codec_options.tz_aware:
            raise ValueError('MongoDB client must have tz_aware set to True')
-        self.client = client
-        self.serializer = serializer or PickleSerializer()
-        self.lock_expiration_delay = lock_expiration_delay
-        self.start_from_scratch = start_from_scratch
-        self._local_tasks: dict[str, Task] = {}
-        self._database = client[database]
-        self._tasks: Collection = self._database[tasks_collection]
-        self._schedules: Collection = self._database[schedules_collection]
-        self._jobs: Collection = self._database[jobs_collection]
-        self._jobs_results: Collection = self._database[job_results_collection]
-        self._logger = logging.getLogger(__name__)
-        self._exit_stack = ExitStack()
-        self._events = EventHub()
+    def __attrs_post_init__(self) -> None:
+        database = self.client[self.database]
+        self._tasks: Collection = database['tasks']
+        self._schedules: Collection = database['schedules']
+        self._jobs: Collection = database['jobs']
+        self._jobs_results: Collection = database['job_results']
    @classmethod
    def from_url(cls, uri: str, **options) -> 'MongoDBDataStore':
diff --git a/src/apscheduler/datastores/sync/sqlalchemy.py b/src/apscheduler/datastores/sync/sqlalchemy.py
index 7c30e16..d4c1d3f 100644
--- a/src/apscheduler/datastores/sync/sqlalchemy.py
+++ b/src/apscheduler/datastores/sync/sqlalchemy.py
@@ -1,11 +1,12 @@
from __future__ import annotations
-import logging
from collections import defaultdict
from datetime import datetime, timedelta, timezone
+from logging import Logger, getLogger
from typing import Any, Callable, Iterable, Optional, Tuple, Type
from uuid import UUID
+import attr
from sqlalchemy import (
Column, Integer, LargeBinary, MetaData, Table, Unicode, and_, bindparam, or_, select)
from sqlalchemy.engine import URL
@@ -26,25 +27,23 @@ from ...serializers.pickle import PickleSerializer
from ...structures import JobResult, Task
from ...util import reentrant
-logger = logging.getLogger(__name__)
-
@reentrant
+@attr.define(eq=False)
class SQLAlchemyDataStore(DataStore):
-    def __init__(self, engine: Engine, *, schema: Optional[str] = None,
-                 serializer: Optional[Serializer] = None,
-                 lock_expiration_delay: float = 30, max_poll_time: Optional[float] = 1,
-                 max_idle_time: float = 60, start_from_scratch: bool = False):
-        self.engine = engine
-        self.schema = schema
-        self.serializer = serializer or PickleSerializer()
-        self.lock_expiration_delay = lock_expiration_delay
-        self.max_poll_time = max_poll_time
-        self.max_idle_time = max_idle_time
-        self.start_from_scratch = start_from_scratch
-        self._logger = logging.getLogger(__name__)
-        self._events = EventHub()
-
+    engine: Engine
+    schema: Optional[str] = attr.field(default=None, kw_only=True)
+    serializer: Serializer = attr.field(factory=PickleSerializer, kw_only=True)
+    lock_expiration_delay: float = attr.field(default=30, kw_only=True)
+    max_poll_time: Optional[float] = attr.field(default=1, kw_only=True)
+    max_idle_time: float = attr.field(default=60, kw_only=True)
+    notify_channel: Optional[str] = attr.field(default='apscheduler', kw_only=True)
+    start_from_scratch: bool = attr.field(default=False, kw_only=True)
+
+    _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__))
+    _events: EventHub = attr.field(init=False, factory=EventHub)
+
+    def __attrs_post_init__(self) -> None:
        # Generate the table definitions
        self._metadata = self.get_table_definitions()
        self.t_metadata = self._metadata.tables['metadata']
@@ -54,7 +53,7 @@ class SQLAlchemyDataStore(DataStore):
        self.t_job_results = self._metadata.tables['job_results']
        # Find out if the dialect supports RETURNING
-        update = self.t_jobs.update().returning(self.t_schedules.c.id)
+        update = self.t_jobs.update().returning(self.t_jobs.c.id)
        try:
            update.compile(bind=self.engine)
        except CompileError:
@@ -62,15 +61,6 @@ class SQLAlchemyDataStore(DataStore):
        else:
            self._supports_update_returning = True
-        # Find out if the dialect supports INSERT...ON DUPLICATE KEY UPDATE
-        insert = self.t_jobs.update().returning(self.t_schedules.c.id)
-        try:
-            insert.compile(bind=self.engine)
-        except CompileError:
-            self._supports_update_returning = False
-        else:
-            self._supports_update_returning = True
-
    @classmethod
    def from_url(cls, url: str | URL, **options) -> 'SQLAlchemyDataStore':
        engine = create_engine(url)
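
For orientation, constructing the migrated sync store now looks roughly like the sketch below: the engine stays positional while every option must be passed by keyword, since the attrs fields are kw_only, and serializer still defaults to PickleSerializer via the field factory. The SQLite URL and option values are purely illustrative.

    from sqlalchemy import create_engine

    from apscheduler.datastores.sync.sqlalchemy import SQLAlchemyDataStore

    # illustrative engine URL; any SQLAlchemy-supported database would do
    engine = create_engine('sqlite:///example.sqlite')
    store = SQLAlchemyDataStore(engine, lock_expiration_delay=60, start_from_scratch=True)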