author     Alex Grönholm <alex.gronholm@nextday.fi>    2021-09-08 22:12:37 +0300
committer  Alex Grönholm <alex.gronholm@nextday.fi>    2021-09-08 22:12:37 +0300
commit     7248a78e7e787b728b083aaa8199eeba3a3f3023 (patch)
tree       f3cd37b3809a6dd82ecc72a43c07e76d1b062257
parent     114e041fa434a36f27c80130b6c0667da5497047 (diff)
download   apscheduler-7248a78e7e787b728b083aaa8199eeba3a3f3023.tar.gz
Deduplicated some SQLAlchemy store code
-rw-r--r--  src/apscheduler/datastores/async_sqlalchemy.py  115
-rw-r--r--  src/apscheduler/datastores/sqlalchemy.py          83
2 files changed, 53 insertions(+), 145 deletions(-)
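
In short, this commit hoists the settings shared by the sync and async stores into a
_BaseSQLAlchemyDataStore class declared with attr.define(kw_only=True, eq=False). The
kw_only flag is what makes the split work: attrs forbids a mandatory field after
defaulted ones in the collected field list, so marking the inherited, defaulted
settings keyword-only is what lets each subclass keep its own mandatory engine field.
A minimal sketch of the pattern (Base and Concrete are hypothetical names, not part
of this commit):

    import attr
    from typing import Optional

    @attr.define(kw_only=True)
    class Base:
        # Shared, defaulted settings; keyword-only, so subclasses may still
        # declare mandatory positional fields of their own.
        schema: Optional[str] = None
        lock_expiration_delay: float = 30

    @attr.define
    class Concrete(Base):
        engine: object  # mandatory; allowed only because Base is kw_only

    store = Concrete(object(), schema='public')
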
diff --git a/src/apscheduler/datastores/async_sqlalchemy.py b/src/apscheduler/datastores/async_sqlalchemy.py
index fad2cd3..c0d21cb 100644
--- a/src/apscheduler/datastores/async_sqlalchemy.py
+++ b/src/apscheduler/datastores/async_sqlalchemy.py
@@ -5,7 +5,6 @@ from collections import defaultdict
from contextlib import AsyncExitStack, closing
from datetime import datetime, timedelta, timezone
from json import JSONDecodeError
-from logging import Logger, getLogger
from typing import Any, Callable, Iterable, Optional
from uuid import UUID
@@ -13,29 +12,26 @@ import attr
import sniffio
from anyio import TASK_STATUS_IGNORED, create_task_group, sleep
from attr import asdict
-from sqlalchemy import (
- JSON, TIMESTAMP, Column, Enum, Integer, LargeBinary, MetaData, Table, Unicode, and_, bindparam,
- func, or_, select)
+from sqlalchemy import and_, bindparam, func, or_, select
from sqlalchemy.engine import URL, Result
-from sqlalchemy.exc import CompileError, IntegrityError
+from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine
from sqlalchemy.ext.asyncio.engine import AsyncEngine
from sqlalchemy.sql.ddl import DropTable
-from sqlalchemy.sql.elements import BindParameter, literal
+from sqlalchemy.sql.elements import BindParameter
from .. import events as events_module
-from ..abc import AsyncDataStore, Job, Schedule, Serializer
-from ..enums import CoalescePolicy, ConflictPolicy, JobOutcome
+from ..abc import AsyncDataStore, Job, Schedule
+from ..enums import ConflictPolicy
from ..events import (
AsyncEventHub, DataStoreEvent, Event, JobAdded, JobDeserializationFailed, ScheduleAdded,
ScheduleDeserializationFailed, ScheduleRemoved, ScheduleUpdated, SubscriptionToken, TaskAdded,
TaskRemoved, TaskUpdated)
from ..exceptions import ConflictingIdError, SerializationError, TaskLookupError
from ..marshalling import callable_to_ref
-from ..serializers.pickle import PickleSerializer
from ..structures import JobResult, Task
from ..util import reentrant
-from .sqlalchemy import EmulatedTimestampTZ, EmulatedUUID
+from .sqlalchemy import _BaseSQLAlchemyDataStore
def default_json_handler(obj: Any) -> Any:
@@ -63,37 +59,14 @@ def json_object_hook(obj: dict[str, Any]) -> Any:
@reentrant
@attr.define(eq=False)
-class AsyncSQLAlchemyDataStore(AsyncDataStore):
+class AsyncSQLAlchemyDataStore(_BaseSQLAlchemyDataStore, AsyncDataStore):
engine: AsyncEngine
- schema: Optional[str] = attr.field(default=None, kw_only=True)
- serializer: Serializer = attr.field(factory=PickleSerializer, kw_only=True)
- lock_expiration_delay: float = attr.field(default=30, kw_only=True)
- max_poll_time: Optional[float] = attr.field(default=1, kw_only=True)
- max_idle_time: float = attr.field(default=60, kw_only=True)
- notify_channel: Optional[str] = attr.field(default='apscheduler', kw_only=True)
- start_from_scratch: bool = attr.field(default=False, kw_only=True)
-
- _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__))
+
_exit_stack: AsyncExitStack = attr.field(init=False, factory=AsyncExitStack)
_events: AsyncEventHub = attr.field(init=False, factory=AsyncEventHub)
def __attrs_post_init__(self) -> None:
- # Generate the table definitions
- self._metadata = self.get_table_definitions()
- self.t_metadata = self._metadata.tables['metadata']
- self.t_tasks = self._metadata.tables['tasks']
- self.t_schedules = self._metadata.tables['schedules']
- self.t_jobs = self._metadata.tables['jobs']
- self.t_job_results = self._metadata.tables['job_results']
-
- # Find out if the dialect supports UPDATE...RETURNING
- update = self.t_jobs.update().returning(self.t_jobs.c.id)
- try:
- update.compile(bind=self.engine)
- except CompileError:
- self._supports_update_returning = False
- else:
- self._supports_update_returning = True
+ super().__attrs_post_init__()
if self.notify_channel:
if self.engine.dialect.name != 'postgresql' or self.engine.dialect.driver != 'asyncpg':
@@ -138,76 +111,6 @@ class AsyncSQLAlchemyDataStore(AsyncDataStore):
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
- def get_table_definitions(self) -> MetaData:
- if self.engine.dialect.name == 'postgresql':
- from sqlalchemy.dialects import postgresql
-
- timestamp_type = TIMESTAMP(timezone=True)
- job_id_type = postgresql.UUID(as_uuid=True)
- else:
- timestamp_type = EmulatedTimestampTZ
- job_id_type = EmulatedUUID
-
- metadata = MetaData()
- Table(
- 'metadata',
- metadata,
- Column('schema_version', Integer, nullable=False)
- )
- Table(
- 'tasks',
- metadata,
- Column('id', Unicode(500), primary_key=True),
- Column('func', Unicode(500), nullable=False),
- Column('state', LargeBinary),
- Column('max_running_jobs', Integer),
- Column('misfire_grace_time', Unicode(16)),
- Column('running_jobs', Integer, nullable=False, server_default=literal(0))
- )
- Table(
- 'schedules',
- metadata,
- Column('id', Unicode(500), primary_key=True),
- Column('task_id', Unicode(500), nullable=False, index=True),
- Column('trigger', LargeBinary),
- Column('args', LargeBinary),
- Column('kwargs', LargeBinary),
- Column('coalesce', Enum(CoalescePolicy), nullable=False),
- Column('misfire_grace_time', Unicode(16)),
- # Column('max_jitter', Unicode(16)),
- Column('tags', JSON, nullable=False),
- Column('next_fire_time', timestamp_type, index=True),
- Column('last_fire_time', timestamp_type),
- Column('acquired_by', Unicode(500)),
- Column('acquired_until', timestamp_type)
- )
- Table(
- 'jobs',
- metadata,
- Column('id', job_id_type, primary_key=True),
- Column('task_id', Unicode(500), nullable=False, index=True),
- Column('args', LargeBinary, nullable=False),
- Column('kwargs', LargeBinary, nullable=False),
- Column('schedule_id', Unicode(500)),
- Column('scheduled_fire_time', timestamp_type),
- Column('start_deadline', timestamp_type),
- Column('tags', JSON, nullable=False),
- Column('created_at', timestamp_type, nullable=False),
- Column('started_at', timestamp_type),
- Column('acquired_by', Unicode(500)),
- Column('acquired_until', timestamp_type)
- )
- Table(
- 'job_results',
- metadata,
- Column('job_id', job_id_type, primary_key=True),
- Column('outcome', Enum(JobOutcome), nullable=False),
- Column('finished_at', timestamp_type, index=True),
- Column('exception', LargeBinary),
- Column('return_value', LargeBinary)
- )
- return metadata
-
async def _publish(self, conn: AsyncConnection, event: DataStoreEvent) -> None:
if self.notify_channel:
event_type = event.__class__.__name__
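
Between the two files: the async store's __attrs_post_init__ now defers to the base
class, which builds the table definitions and probes whether the dialect can compile
UPDATE...RETURNING, exactly as the block removed above did. A standalone sketch of
that probe, using SQLite as an example of a dialect that, at the time of this commit,
could not render RETURNING on UPDATE (the jobs table here is a hypothetical stand-in
for t_jobs):

    from sqlalchemy import Column, Integer, MetaData, Table, create_engine
    from sqlalchemy.exc import CompileError

    engine = create_engine('sqlite://')
    jobs = Table('jobs', MetaData(), Column('id', Integer, primary_key=True))

    # Try to compile UPDATE...RETURNING; a CompileError means the dialect
    # cannot render it, so the store must fall back to separate queries.
    update = jobs.update().returning(jobs.c.id)
    try:
        update.compile(bind=engine)
    except CompileError:
        supports_update_returning = False
    else:
        supports_update_returning = True
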
diff --git a/src/apscheduler/datastores/sqlalchemy.py b/src/apscheduler/datastores/sqlalchemy.py
index 11db258..d49ca61 100644
--- a/src/apscheduler/datastores/sqlalchemy.py
+++ b/src/apscheduler/datastores/sqlalchemy.py
@@ -51,20 +51,17 @@ class EmulatedTimestampTZ(TypeDecorator):
return datetime.fromisoformat(value) if value is not None else None
-@reentrant
-@attr.define(eq=False)
-class SQLAlchemyDataStore(DataStore):
- engine: Engine
- schema: Optional[str] = attr.field(default=None, kw_only=True)
- serializer: Serializer = attr.field(factory=PickleSerializer, kw_only=True)
- lock_expiration_delay: float = attr.field(default=30, kw_only=True)
- max_poll_time: Optional[float] = attr.field(default=1, kw_only=True)
- max_idle_time: float = attr.field(default=60, kw_only=True)
- notify_channel: Optional[str] = attr.field(default='apscheduler', kw_only=True)
- start_from_scratch: bool = attr.field(default=False, kw_only=True)
+@attr.define(kw_only=True, eq=False)
+class _BaseSQLAlchemyDataStore:
+ schema: Optional[str] = attr.field(default=None)
+ serializer: Serializer = attr.field(factory=PickleSerializer)
+ lock_expiration_delay: float = attr.field(default=30)
+ max_poll_time: Optional[float] = attr.field(default=1)
+ max_idle_time: float = attr.field(default=60)
+ notify_channel: Optional[str] = attr.field(default='apscheduler')
+ start_from_scratch: bool = attr.field(default=False)
_logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__))
- _events: EventHub = attr.field(init=False, factory=EventHub)
def __attrs_post_init__(self) -> None:
# Generate the table definitions
@@ -84,33 +81,6 @@ class SQLAlchemyDataStore(DataStore):
else:
self._supports_update_returning = True
- @classmethod
- def from_url(cls, url: str | URL, **options) -> SQLAlchemyDataStore:
- engine = create_engine(url)
- return cls(engine, **options)
-
- def __enter__(self):
- with self.engine.begin() as conn:
- if self.start_from_scratch:
- for table in self._metadata.sorted_tables:
- conn.execute(DropTable(table, if_exists=True))
-
- self._metadata.create_all(conn)
- query = select(self.t_metadata.c.schema_version)
- result = conn.execute(query)
- version = result.scalar()
- if version is None:
- conn.execute(self.t_metadata.insert(values={'schema_version': 1}))
- elif version > 1:
- raise RuntimeError(f'Unexpected schema version ({version}); '
- f'only version 1 is supported by this version of APScheduler')
-
- self._events.__enter__()
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- self._events.__exit__(exc_type, exc_val, exc_tb)
-
def get_table_definitions(self) -> MetaData:
if self.engine.dialect.name == 'postgresql':
from sqlalchemy.dialects import postgresql
@@ -203,6 +173,41 @@ class SQLAlchemyDataStore(DataStore):
return jobs
+
+@reentrant
+@attr.define(eq=False)
+class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, DataStore):
+ engine: Engine
+
+ _events: EventHub = attr.field(init=False, factory=EventHub)
+
+ @classmethod
+ def from_url(cls, url: str | URL, **options) -> SQLAlchemyDataStore:
+ engine = create_engine(url)
+ return cls(engine, **options)
+
+ def __enter__(self):
+ with self.engine.begin() as conn:
+ if self.start_from_scratch:
+ for table in self._metadata.sorted_tables:
+ conn.execute(DropTable(table, if_exists=True))
+
+ self._metadata.create_all(conn)
+ query = select(self.t_metadata.c.schema_version)
+ result = conn.execute(query)
+ version = result.scalar()
+ if version is None:
+ conn.execute(self.t_metadata.insert(values={'schema_version': 1}))
+ elif version > 1:
+ raise RuntimeError(f'Unexpected schema version ({version}); '
+ f'only version 1 is supported by this version of APScheduler')
+
+ self._events.__enter__()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self._events.__exit__(exc_type, exc_val, exc_tb)
+
def subscribe(self, callback: Callable[[Event], Any],
event_types: Optional[Iterable[type[Event]]] = None) -> SubscriptionToken:
return self._events.subscribe(callback, event_types)
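
The from_url()/__enter__()/__exit__() trio is unchanged apart from its new home on the
concrete class. A hypothetical usage sketch, assuming only what the diff shows (the
URL and keyword options are illustrative, not from this commit):

    from apscheduler.datastores.sqlalchemy import SQLAlchemyDataStore

    # start_from_scratch drops and recreates the tables on entry; entering the
    # store also seeds the schema_version row and rejects versions above 1.
    store = SQLAlchemyDataStore.from_url('sqlite:///jobs.sqlite',
                                         start_from_scratch=True)
    with store:
        ...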