diff options
author | Alex Grönholm <alex.gronholm@nextday.fi> | 2021-09-08 22:12:37 +0300 |
---|---|---|
committer | Alex Grönholm <alex.gronholm@nextday.fi> | 2021-09-08 22:12:37 +0300 |
commit | 7248a78e7e787b728b083aaa8199eeba3a3f3023 (patch) | |
tree | f3cd37b3809a6dd82ecc72a43c07e76d1b062257 /src/apscheduler/datastores/sqlalchemy.py | |
parent | 114e041fa434a36f27c80130b6c0667da5497047 (diff) | |
download | apscheduler-7248a78e7e787b728b083aaa8199eeba3a3f3023.tar.gz |
Deduplicated some SQLAlchemy store code
Diffstat (limited to 'src/apscheduler/datastores/sqlalchemy.py')
-rw-r--r-- | src/apscheduler/datastores/sqlalchemy.py | 83 |
1 files changed, 44 insertions, 39 deletions
diff --git a/src/apscheduler/datastores/sqlalchemy.py b/src/apscheduler/datastores/sqlalchemy.py index 11db258..d49ca61 100644 --- a/src/apscheduler/datastores/sqlalchemy.py +++ b/src/apscheduler/datastores/sqlalchemy.py @@ -51,20 +51,17 @@ class EmulatedTimestampTZ(TypeDecorator): return datetime.fromisoformat(value) if value is not None else None -@reentrant -@attr.define(eq=False) -class SQLAlchemyDataStore(DataStore): - engine: Engine - schema: Optional[str] = attr.field(default=None, kw_only=True) - serializer: Serializer = attr.field(factory=PickleSerializer, kw_only=True) - lock_expiration_delay: float = attr.field(default=30, kw_only=True) - max_poll_time: Optional[float] = attr.field(default=1, kw_only=True) - max_idle_time: float = attr.field(default=60, kw_only=True) - notify_channel: Optional[str] = attr.field(default='apscheduler', kw_only=True) - start_from_scratch: bool = attr.field(default=False, kw_only=True) +@attr.define(kw_only=True, eq=False) +class _BaseSQLAlchemyDataStore: + schema: Optional[str] = attr.field(default=None) + serializer: Serializer = attr.field(factory=PickleSerializer) + lock_expiration_delay: float = attr.field(default=30) + max_poll_time: Optional[float] = attr.field(default=1) + max_idle_time: float = attr.field(default=60) + notify_channel: Optional[str] = attr.field(default='apscheduler') + start_from_scratch: bool = attr.field(default=False) _logger: Logger = attr.field(init=False, factory=lambda: getLogger(__name__)) - _events: EventHub = attr.field(init=False, factory=EventHub) def __attrs_post_init__(self) -> None: # Generate the table definitions @@ -84,33 +81,6 @@ class SQLAlchemyDataStore(DataStore): else: self._supports_update_returning = True - @classmethod - def from_url(cls, url: str | URL, **options) -> SQLAlchemyDataStore: - engine = create_engine(url) - return cls(engine, **options) - - def __enter__(self): - with self.engine.begin() as conn: - if self.start_from_scratch: - for table in self._metadata.sorted_tables: - conn.execute(DropTable(table, if_exists=True)) - - self._metadata.create_all(conn) - query = select(self.t_metadata.c.schema_version) - result = conn.execute(query) - version = result.scalar() - if version is None: - conn.execute(self.t_metadata.insert(values={'schema_version': 1})) - elif version > 1: - raise RuntimeError(f'Unexpected schema version ({version}); ' - f'only version 1 is supported by this version of APScheduler') - - self._events.__enter__() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self._events.__exit__(exc_type, exc_val, exc_tb) - def get_table_definitions(self) -> MetaData: if self.engine.dialect.name == 'postgresql': from sqlalchemy.dialects import postgresql @@ -203,6 +173,41 @@ class SQLAlchemyDataStore(DataStore): return jobs + +@reentrant +@attr.define(eq=False) +class SQLAlchemyDataStore(_BaseSQLAlchemyDataStore, DataStore): + engine: Engine + + _events: EventHub = attr.field(init=False, factory=EventHub) + + @classmethod + def from_url(cls, url: str | URL, **options) -> SQLAlchemyDataStore: + engine = create_engine(url) + return cls(engine, **options) + + def __enter__(self): + with self.engine.begin() as conn: + if self.start_from_scratch: + for table in self._metadata.sorted_tables: + conn.execute(DropTable(table, if_exists=True)) + + self._metadata.create_all(conn) + query = select(self.t_metadata.c.schema_version) + result = conn.execute(query) + version = result.scalar() + if version is None: + conn.execute(self.t_metadata.insert(values={'schema_version': 1})) + elif version > 1: + raise RuntimeError(f'Unexpected schema version ({version}); ' + f'only version 1 is supported by this version of APScheduler') + + self._events.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._events.__exit__(exc_type, exc_val, exc_tb) + def subscribe(self, callback: Callable[[Event], Any], event_types: Optional[Iterable[type[Event]]] = None) -> SubscriptionToken: return self._events.subscribe(callback, event_types) |