diff options
author | Alex Grönholm <alex.gronholm@nextday.fi> | 2021-09-08 22:12:37 +0300 |
---|---|---|
committer | Alex Grönholm <alex.gronholm@nextday.fi> | 2021-09-08 22:12:37 +0300 |
commit | 7248a78e7e787b728b083aaa8199eeba3a3f3023 (patch) | |
tree | f3cd37b3809a6dd82ecc72a43c07e76d1b062257 | |
parent | 114e041fa434a36f27c80130b6c0667da5497047 (diff) | |
download | apscheduler-7248a78e7e787b728b083aaa8199eeba3a3f3023.tar.gz |
Deduplicated some SQLAlchemy store code
-rw-r--r-- | src/apscheduler/datastores/async_sqlalchemy.py | 115 | ||||
-rw-r--r-- | src/apscheduler/datastores/sqlalchemy.py | 83 |
2 files changed, 53 insertions, 145 deletions
diff --git a/src/apscheduler/datastores/async_sqlalchemy.py b/src/apscheduler/datastores/async_sqlalchemy.py index fad2cd3..c0d21cb 100644 --- a/src/apscheduler/datastores/async_sqlalchemy.py +++ b/src/apscheduler/datastores/async_sqlalchemy.py @@ -5,7 +5,6 @@ from collections import defaultdict from contextlib import AsyncExitStack, closing from datetime import datetime, timedelta, timezone from json import JSONDecodeError -from logging import Logger, getLogger from typing import Any, Callable, Iterable, Optional from uuid import UUID @@ -13,29 +12,26 @@ import attr import sniffio from anyio import TASK_STATUS_IGNORED, create_task_group, sleep from attr import asdict -from sqlalchemy import ( - JSON, TIMESTAMP, Column, Enum, Integer, LargeBinary, MetaData, Table, Unicode, and_, bindparam, - func, or_, select) +from sqlalchemy import and_, bindparam, func, or_, select from sqlalchemy.engine import URL, Result -from sqlalchemy.exc import CompileError, IntegrityError +from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine from sqlalchemy.ext.asyncio.engine import AsyncEngine from sqlalchemy.sql.ddl import DropTable -from sqlalchemy.sql.elements import BindParameter, literal +from sqlalchemy.sql.elements import BindParameter from .. import events as events_module -from ..abc import AsyncDataStore, Job, Schedule, Serializer -from ..enums import CoalescePolicy, ConflictPolicy, JobOutcome +from ..abc import AsyncDataStore, Job, Schedule +from ..enums import ConflictPolicy from ..events import ( AsyncEventHub, DataStoreEvent, Event, JobAdded, JobDeserializationFailed, ScheduleAdded, ScheduleDeserializationFailed, ScheduleRemoved, ScheduleUpdated, SubscriptionToken, TaskAdded, TaskRemoved, TaskUpdated) from ..exceptions import ConflictingIdError, SerializationError, TaskLookupError from ..marshalling import callable_to_ref -from ..serializers.pickle import PickleSerializer from ..structures import JobResult, Task from ..util import reentrant -from .sqlalchemy import EmulatedTimestampTZ, EmulatedUUID +from .sqlalchemy import _BaseSQLAlchemyDataStore def default_json_handler(obj: Any) -> Any: @@ -63,37 +59,14 @@ def json_object_hook(obj: dict[str, Any]) -> Any: @reentrant @attr.define(eq=False) -class AsyncSQLAlchemyDataStore(AsyncDataStore): +class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore): engine: AsyncEngine - schema: Optional[str] = attr.field(default=None, kw_only=True) - serializer: Serializer = attr.field(factory=PickleSerializer, kw_only=True) - lock_expiration_delay: float = attr.field(default=30, kw_only=True) - max_poll_time: Optional[float] = attr.field(default=1, kw_only=True) - max_idle_time: float = attr.field(default=60, kw_only=True) - notify_channel: Optional[str] = attr.field(default='apscheduler', kw_only=True) - start_from_scratch: bool = attr.field(default=False, kw_only=True) - - _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__)) + _exit_stack: AsyncExitStack = attr.field(init=False, factory=AsyncExitStack) _events: AsyncEventHub = attr.field(init=False, factory=AsyncEventHub) def __attrs_post_init__(self) -> None: - # Generate the table definitions - self._metadata = self.get_table_definitions() - self.t_metadata = self._metadata.tables['metadata'] - self.t_tasks = self._metadata.tables['tasks'] - self.t_schedules = self._metadata.tables['schedules'] - self.t_jobs = self._metadata.tables['jobs'] - self.t_job_results = self._metadata.tables['job_results'] - - # Find out if the dialect supports UPDATE...RETURNING - update = self.t_jobs.update().returning(self.t_jobs.c.id) - try: - update.compile(bind=self.engine) - except CompileError: - self._supports_update_returning = False - else: - self._supports_update_returning = True + super().__attrs_post_init__() if self.notify_channel: if self.engine.dialect.name != 'postgresql' or self.engine.dialect.driver != 'asyncpg': @@ -138,76 +111,6 @@ class AsyncSQLAlchemyDataStore(AsyncDataStore): async def __aexit__(self, exc_type, exc_val, exc_tb): await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) - def get_table_definitions(self) -> MetaData: - if self.engine.dialect.name == 'postgresql': - from sqlalchemy.dialects import postgresql - - timestamp_type = TIMESTAMP(timezone=True) - job_id_type = postgresql.UUID(as_uuid=True) - else: - timestamp_type = EmulatedTimestampTZ - job_id_type = EmulatedUUID - - metadata = MetaData() - Table( - 'metadata', - metadata, - Column('schema_version', Integer, nullable=False) - ) - Table( - 'tasks', - metadata, - Column('id', Unicode(500), primary_key=True), - Column('func', Unicode(500), nullable=False), - Column('state', LargeBinary), - Column('max_running_jobs', Integer), - Column('misfire_grace_time', Unicode(16)), - Column('running_jobs', Integer, nullable=False, server_default=literal(0)) - ) - Table( - 'schedules', - metadata, - Column('id', Unicode(500), primary_key=True), - Column('task_id', Unicode(500), nullable=False, index=True), - Column('trigger', LargeBinary), - Column('args', LargeBinary), - Column('kwargs', LargeBinary), - Column('coalesce', Enum(CoalescePolicy), nullable=False), - Column('misfire_grace_time', Unicode(16)), - # Column('max_jitter', Unicode(16)), - Column('tags', JSON, nullable=False), - Column('next_fire_time', timestamp_type, index=True), - Column('last_fire_time', timestamp_type), - Column('acquired_by', Unicode(500)), - Column('acquired_until', timestamp_type) - ) - Table( - 'jobs', - metadata, - Column('id', job_id_type, primary_key=True), - Column('task_id', Unicode(500), nullable=False, index=True), - Column('args', LargeBinary, nullable=False), - Column('kwargs', LargeBinary, nullable=False), - Column('schedule_id', Unicode(500)), - Column('scheduled_fire_time', timestamp_type), - Column('start_deadline', timestamp_type), - Column('tags', JSON, nullable=False), - Column('created_at', timestamp_type, nullable=False), - Column('started_at', timestamp_type), - Column('acquired_by', Unicode(500)), - Column('acquired_until', timestamp_type) - ) - Table( - 'job_results', - metadata, - Column('job_id', job_id_type, primary_key=True), - Column('outcome', Enum(JobOutcome), nullable=False), - Column('finished_at', timestamp_type, index=True), - Column('exception', LargeBinary), - Column('return_value', LargeBinary) - ) - return metadata - async def _publish(self, conn: AsyncConnection, event: DataStoreEvent) -> None: if self.notify_channel: event_type = event.__class__.__name__ diff --git a/src/apscheduler/datastores/sqlalchemy.py b/src/apscheduler/datastores/sqlalchemy.py index 11db258..d49ca61 100644 --- a/src/apscheduler/datastores/sqlalchemy.py +++ b/src/apscheduler/datastores/sqlalchemy.py @@ -51,20 +51,17 @@ class EmulatedTimestampTZ(TypeDecorator): return datetime.fromisoformat(value) if value is not None else None -@reentrant -@attr.define(eq=False) -class SQLAlchemyDataStore(DataStore): - engine: Engine - schema: Optional[str] = attr.field(default=None, kw_only=True) - serializer: Serializer = attr.field(factory=PickleSerializer, kw_only=True) - lock_expiration_delay: float = attr.field(default=30, kw_only=True) - max_poll_time: Optional[float] = attr.field(default=1, kw_only=True) - max_idle_time: float = attr.field(default=60, kw_only=True) - notify_channel: Optional[str] = attr.field(default='apscheduler', kw_only=True) - start_from_scratch: bool = attr.field(default=False, kw_only=True) +@attr.define(kw_only=True, eq=False) +class _BaseSQLAlchemyDataStore: + schema: Optional[str] = attr.field(default=None) + serializer: Serializer = attr.field(factory=PickleSerializer) + lock_expiration_delay: float = attr.field(default=30) + max_poll_time: Optional[float] = attr.field(default=1) + max_idle_time: float = attr.field(default=60) + notify_channel: Optional[str] = attr.field(default='apscheduler') + start_from_scratch: bool = attr.field(default=False) _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__)) - _events: EventHub = attr.field(init=False, factory=EventHub) def __attrs_post_init__(self) -> None: # Generate the table definitions @@ -84,33 +81,6 @@ class SQLAlchemyDataStore(DataStore): else: self._supports_update_returning = True - @classmethod - def from_url(cls, url: str | URL, **options) -> SQLAlchemyDataStore: - engine = create_engine(url) - return cls(engine, **options) - - def __enter__(self): - with self.engine.begin() as conn: - if self.start_from_scratch: - for table in self._metadata.sorted_tables: - conn.execute(DropTable(table, if_exists=True)) - - self._metadata.create_all(conn) - query = select(self.t_metadata.c.schema_version) - result = conn.execute(query) - version = result.scalar() - if version is None: - conn.execute(self.t_metadata.insert(values={'schema_version': 1})) - elif version > 1: - raise RuntimeError(f'Unexpected schema version ({version}); ' - f'only version 1 is supported by this version of APScheduler') - - self._events.__enter__() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self._events.__exit__(exc_type, exc_val, exc_tb) - def get_table_definitions(self) -> MetaData: if self.engine.dialect.name == 'postgresql': from sqlalchemy.dialects import postgresql @@ -203,6 +173,41 @@ class SQLAlchemyDataStore(DataStore): return jobs + +@reentrant +@attr.define(eq=False) +class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, DataStore): + engine: Engine + + _events: EventHub = attr.field(init=False, factory=EventHub) + + @classmethod + def from_url(cls, url: str | URL, **options) -> SQLAlchemyDataStore: + engine = create_engine(url) + return cls(engine, **options) + + def __enter__(self): + with self.engine.begin() as conn: + if self.start_from_scratch: + for table in self._metadata.sorted_tables: + conn.execute(DropTable(table, if_exists=True)) + + self._metadata.create_all(conn) + query = select(self.t_metadata.c.schema_version) + result = conn.execute(query) + version = result.scalar() + if version is None: + conn.execute(self.t_metadata.insert(values={'schema_version': 1})) + elif version > 1: + raise RuntimeError(f'Unexpected schema version ({version}); ' + f'only version 1 is supported by this version of APScheduler') + + self._events.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._events.__exit__(exc_type, exc_val, exc_tb) + def subscribe(self, callback: Callable[[Event], Any], event_types: Optional[Iterable[type[Event]]] = None) -> SubscriptionToken: return self._events.subscribe(callback, event_types) |