summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Grönholm <alex.gronholm@nextday.fi>2021-09-06 22:46:49 +0300
committerAlex Grönholm <alex.gronholm@nextday.fi>2021-09-06 22:46:49 +0300
commit4e2585a6f613905135164d3f6a5c6adf752ba441 (patch)
tree772163753df70b4a540d7c7678d05effbd3110b2
parent279d9c42059a29619c89553d5468b4b6ca43dd6d (diff)
downloadapscheduler-4e2585a6f613905135164d3f6a5c6adf752ba441.tar.gz
Use the real UUID column type where supported
-rw-r--r--src/apscheduler/datastores/async_/sqlalchemy.py46
-rw-r--r--src/apscheduler/datastores/sync/sqlalchemy.py46
2 files changed, 64 insertions, 28 deletions
diff --git a/src/apscheduler/datastores/async_/sqlalchemy.py b/src/apscheduler/datastores/async_/sqlalchemy.py
index 619b1c5..beaac4f 100644
--- a/src/apscheduler/datastores/async_/sqlalchemy.py
+++ b/src/apscheduler/datastores/async_/sqlalchemy.py
@@ -14,8 +14,9 @@ import sniffio
from anyio import TASK_STATUS_IGNORED, create_task_group, sleep
from attr import asdict
from sqlalchemy import (
- Column, Integer, LargeBinary, MetaData, Table, Unicode, and_, bindparam, func, or_, select)
-from sqlalchemy.engine import URL
+ TIMESTAMP, Column, Integer, LargeBinary, MetaData, Table, TypeDecorator, Unicode, and_,
+ bindparam, func, or_, select)
+from sqlalchemy.engine import URL, Dialect
from sqlalchemy.exc import CompileError, IntegrityError
from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine
from sqlalchemy.ext.asyncio.engine import AsyncEngine
@@ -59,6 +60,17 @@ def json_object_hook(obj: dict[str, Any]) -> Any:
return obj
+class EmulatedUUID(TypeDecorator):
+ impl = Unicode(32)
+ cache_ok = True
+
+ def process_bind_param(self, value, dialect: Dialect) -> Any:
+ return value.hex
+
+ def process_result_value(self, value: Any, dialect: Dialect):
+ return UUID(value) if value else None
+
+
@reentrant
@attr.define(eq=False)
class SQLAlchemyDataStore(AsyncDataStore):
@@ -138,12 +150,18 @@ class SQLAlchemyDataStore(AsyncDataStore):
def get_table_definitions(self) -> MetaData:
if self.engine.dialect.name in ('mysql', 'mariadb'):
- from sqlalchemy.dialects.mysql import TIMESTAMP
- timestamp_type = TIMESTAMP(fsp=6)
+ from sqlalchemy.dialects import mysql
+ timestamp_type = mysql.TIMESTAMP(fsp=6)
else:
- from sqlalchemy.types import TIMESTAMP
timestamp_type = TIMESTAMP(timezone=True)
+ if self.engine.dialect.name == 'postgresql':
+ from sqlalchemy.dialects import postgresql
+
+ job_id_type = postgresql.UUID(as_uuid=True)
+ else:
+ job_id_type = EmulatedUUID
+
metadata = MetaData()
Table(
'metadata',
@@ -173,7 +191,7 @@ class SQLAlchemyDataStore(AsyncDataStore):
Table(
'jobs',
metadata,
- Column('id', Unicode(32), primary_key=True),
+ Column('id', job_id_type, primary_key=True),
Column('task_id', Unicode(500), nullable=False, index=True),
Column('serialized_data', LargeBinary, nullable=False),
Column('created_at', timestamp_type, nullable=False),
@@ -183,7 +201,7 @@ class SQLAlchemyDataStore(AsyncDataStore):
Table(
'job_results',
metadata,
- Column('job_id', Unicode(32), primary_key=True),
+ Column('job_id', job_id_type, primary_key=True),
Column('finished_at', timestamp_type, index=True),
Column('serialized_data', LargeBinary, nullable=False)
)
@@ -465,7 +483,7 @@ class SQLAlchemyDataStore(AsyncDataStore):
async def add_job(self, job: Job) -> None:
now = datetime.now(timezone.utc)
serialized_data = self.serializer.serialize(job)
- insert = self.t_jobs.insert().values(id=job.id.hex, task_id=job.task_id,
+ insert = self.t_jobs.insert().values(id=job.id, task_id=job.task_id,
created_at=now, serialized_data=serialized_data)
async with self.engine.begin() as conn:
await conn.execute(insert)
@@ -478,7 +496,7 @@ class SQLAlchemyDataStore(AsyncDataStore):
query = select([self.t_jobs.c.id, self.t_jobs.c.serialized_data]).\
order_by(self.t_jobs.c.id)
if ids:
- job_ids = [job_id.hex for job_id in ids]
+ job_ids = [job_id for job_id in ids]
query = query.where(self.t_jobs.c.id.in_(job_ids))
async with self.engine.begin() as conn:
@@ -529,7 +547,7 @@ class SQLAlchemyDataStore(AsyncDataStore):
if acquired_jobs:
# Mark the acquired jobs as acquired by this worker
- acquired_job_ids = [job.id.hex for job in acquired_jobs]
+ acquired_job_ids = [job.id for job in acquired_jobs]
update = self.t_jobs.update().\
values(acquired_by=worker_id, acquired_until=acquired_until).\
where(self.t_jobs.c.id.in_(acquired_job_ids))
@@ -553,7 +571,7 @@ class SQLAlchemyDataStore(AsyncDataStore):
now = datetime.now(timezone.utc)
serialized_data = self.serializer.serialize(result)
insert = self.t_job_results.insert().\
- values(job_id=job.id.hex, finished_at=now, serialized_data=serialized_data)
+ values(job_id=job.id, finished_at=now, serialized_data=serialized_data)
await conn.execute(insert)
# Decrement the running jobs counter
@@ -563,17 +581,17 @@ class SQLAlchemyDataStore(AsyncDataStore):
await conn.execute(update)
# Delete the job
- delete = self.t_jobs.delete().where(self.t_jobs.c.id == job.id.hex)
+ delete = self.t_jobs.delete().where(self.t_jobs.c.id == job.id)
await conn.execute(delete)
async def get_job_result(self, job_id: UUID) -> Optional[JobResult]:
async with self.engine.begin() as conn:
query = select(self.t_job_results.c.serialized_data).\
- where(self.t_job_results.c.job_id == job_id.hex)
+ where(self.t_job_results.c.job_id == job_id)
result = await conn.execute(query)
delete = self.t_job_results.delete().\
- where(self.t_job_results.c.job_id == job_id.hex)
+ where(self.t_job_results.c.job_id == job_id)
await conn.execute(delete)
serialized_data = result.scalar()
diff --git a/src/apscheduler/datastores/sync/sqlalchemy.py b/src/apscheduler/datastores/sync/sqlalchemy.py
index d4c1d3f..4914c1b 100644
--- a/src/apscheduler/datastores/sync/sqlalchemy.py
+++ b/src/apscheduler/datastores/sync/sqlalchemy.py
@@ -8,8 +8,9 @@ from uuid import UUID
import attr
from sqlalchemy import (
- Column, Integer, LargeBinary, MetaData, Table, Unicode, and_, bindparam, or_, select)
-from sqlalchemy.engine import URL
+ TIMESTAMP, Column, Integer, LargeBinary, MetaData, Table, TypeDecorator, Unicode, and_,
+ bindparam, or_, select)
+from sqlalchemy.engine import URL, Dialect
from sqlalchemy.exc import CompileError, IntegrityError
from sqlalchemy.future import Engine, create_engine
from sqlalchemy.sql.ddl import DropTable
@@ -28,6 +29,17 @@ from ...structures import JobResult, Task
from ...util import reentrant
+class EmulatedUUID(TypeDecorator):
+ impl = Unicode(32)
+ cache_ok = True
+
+ def process_bind_param(self, value, dialect: Dialect) -> Any:
+ return value.hex
+
+ def process_result_value(self, value: Any, dialect: Dialect):
+ return UUID(value) if value else None
+
+
@reentrant
@attr.define(eq=False)
class SQLAlchemyDataStore(DataStore):
@@ -90,12 +102,18 @@ class SQLAlchemyDataStore(DataStore):
def get_table_definitions(self) -> MetaData:
if self.engine.dialect.name in ('mysql', 'mariadb'):
- from sqlalchemy.dialects.mysql import TIMESTAMP
- timestamp_type = TIMESTAMP(fsp=6)
+ from sqlalchemy.dialects import mysql
+ timestamp_type = mysql.TIMESTAMP(fsp=6)
else:
- from sqlalchemy.types import TIMESTAMP
timestamp_type = TIMESTAMP(timezone=True)
+ if self.engine.dialect.name == 'postgresql':
+ from sqlalchemy.dialects import postgresql
+
+ job_id_type = postgresql.UUID(as_uuid=True)
+ else:
+ job_id_type = EmulatedUUID
+
metadata = MetaData()
Table(
'metadata',
@@ -125,7 +143,7 @@ class SQLAlchemyDataStore(DataStore):
Table(
'jobs',
metadata,
- Column('id', Unicode(32), primary_key=True),
+ Column('id', job_id_type, primary_key=True),
Column('task_id', Unicode(500), nullable=False, index=True),
Column('serialized_data', LargeBinary, nullable=False),
Column('created_at', timestamp_type, nullable=False),
@@ -135,7 +153,7 @@ class SQLAlchemyDataStore(DataStore):
Table(
'job_results',
metadata,
- Column('job_id', Unicode(32), primary_key=True),
+ Column('job_id', job_id_type, primary_key=True),
Column('finished_at', timestamp_type, index=True),
Column('serialized_data', LargeBinary, nullable=False)
)
@@ -371,7 +389,7 @@ class SQLAlchemyDataStore(DataStore):
def add_job(self, job: Job) -> None:
now = datetime.now(timezone.utc)
serialized_data = self.serializer.serialize(job)
- insert = self.t_jobs.insert().values(id=job.id.hex, task_id=job.task_id,
+ insert = self.t_jobs.insert().values(id=job.id, task_id=job.task_id,
created_at=now, serialized_data=serialized_data)
with self.engine.begin() as conn:
conn.execute(insert)
@@ -384,7 +402,7 @@ class SQLAlchemyDataStore(DataStore):
query = select([self.t_jobs.c.id, self.t_jobs.c.serialized_data]).\
order_by(self.t_jobs.c.id)
if ids:
- job_ids = [job_id.hex for job_id in ids]
+ job_ids = [job_id for job_id in ids]
query = query.where(self.t_jobs.c.id.in_(job_ids))
with self.engine.begin() as conn:
@@ -434,7 +452,7 @@ class SQLAlchemyDataStore(DataStore):
if acquired_jobs:
# Mark the acquired jobs as acquired by this worker
- acquired_job_ids = [job.id.hex for job in acquired_jobs]
+ acquired_job_ids = [job.id for job in acquired_jobs]
update = self.t_jobs.update().\
values(acquired_by=worker_id, acquired_until=acquired_until).\
where(self.t_jobs.c.id.in_(acquired_job_ids))
@@ -458,7 +476,7 @@ class SQLAlchemyDataStore(DataStore):
now = datetime.now(timezone.utc)
serialized_result = self.serializer.serialize(result)
insert = self.t_job_results.insert().\
- values(job_id=job.id.hex, finished_at=now, serialized_data=serialized_result)
+ values(job_id=job.id, finished_at=now, serialized_data=serialized_result)
conn.execute(insert)
# Decrement the running jobs counter
@@ -468,19 +486,19 @@ class SQLAlchemyDataStore(DataStore):
conn.execute(update)
# Delete the job
- delete = self.t_jobs.delete().where(self.t_jobs.c.id == job.id.hex)
+ delete = self.t_jobs.delete().where(self.t_jobs.c.id == job.id)
conn.execute(delete)
def get_job_result(self, job_id: UUID) -> Optional[JobResult]:
with self.engine.begin() as conn:
# Retrieve the result
query = select(self.t_job_results.c.serialized_data).\
- where(self.t_job_results.c.job_id == job_id.hex)
+ where(self.t_job_results.c.job_id == job_id)
result = conn.execute(query)
# Delete the result
delete = self.t_job_results.delete().\
- where(self.t_job_results.c.job_id == job_id.hex)
+ where(self.t_job_results.c.job_id == job_id)
conn.execute(delete)
serialized_result = result.scalar()