diff options
author | Federico Caselli <cfederico87@gmail.com> | 2023-04-29 23:25:21 +0200 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2023-05-12 12:06:14 -0400 |
commit | 6d2df6c92c981d220bd1481ebdc7d86274ccf845 (patch) | |
tree | be009a39f769dd4cff23d1fd917f651adeccee72 | |
parent | 92e54a0e1c96cecd99397cb1aee9c3bb28f780c6 (diff) | |
download | alembic-6d2df6c92c981d220bd1481ebdc7d86274ccf845.tar.gz |
Added ``op.run_async``.
Added :meth:`.Operations.run_async` to the operation module to allow
running async functions in the ``upgrade`` or ``downgrade`` migration
function when running alembic using an async dialect.
This function will receive as first argument an
class:`~sqlalchemy.ext.asyncio.AsyncConnection` sharing the transaction
used in the migration context.
also restore the .execute() method to BatchOperations
Fixes: #1231
Change-Id: I3c3237d570be3c9bd9834e4c61bb3231bfb82765
-rw-r--r-- | alembic/op.pyi | 29 | ||||
-rw-r--r-- | alembic/operations/base.py | 45 | ||||
-rw-r--r-- | alembic/operations/ops.py | 1 | ||||
-rw-r--r-- | alembic/util/sqla_compat.py | 2 | ||||
-rw-r--r-- | docs/build/unreleased/1231.rst | 11 | ||||
-rw-r--r-- | tests/test_op.py | 44 | ||||
-rw-r--r-- | tools/write_pyi.py | 5 |
7 files changed, 134 insertions, 3 deletions
diff --git a/alembic/op.pyi b/alembic/op.pyi index aa3ad2d..4395f77 100644 --- a/alembic/op.pyi +++ b/alembic/op.pyi @@ -4,6 +4,7 @@ from __future__ import annotations from contextlib import contextmanager from typing import Any +from typing import Awaitable from typing import Callable from typing import Dict from typing import Iterator @@ -15,6 +16,7 @@ from typing import Sequence from typing import Tuple from typing import Type from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union from sqlalchemy.sql.expression import TableClause @@ -38,6 +40,8 @@ if TYPE_CHECKING: from .operations.ops import MigrateOperation from .runtime.migration import MigrationContext from .util.sqla_compat import _literal_bindparam + +_T = TypeVar("_T") ### end imports ### def add_column( @@ -1238,3 +1242,28 @@ def rename_table( :class:`~sqlalchemy.sql.elements.quoted_name`. """ + +def run_async( + async_function: Callable[..., Awaitable[_T]], *args: Any, **kw_args: Any +) -> _T: + """Invoke the given asynchronous callable, passing an asynchronous + :class:`~sqlalchemy.ext.asyncio.AsyncConnection` as the first + argument. + + This method allows calling async functions from within the + synchronous ``upgrade()`` or ``downgrade()`` alembic migration + method. + + The async connection passed to the callable shares the same + transaction as the connection running in the migration context. + + Any additional arg or kw_arg passed to this function are passed + to the provided async function. + + .. versionadded: 1.11 + + .. note:: + + This method can be called only when alembic is called using + an async dialect. + """ diff --git a/alembic/operations/base.py b/alembic/operations/base.py index 6e45a11..b4190dc 100644 --- a/alembic/operations/base.py +++ b/alembic/operations/base.py @@ -4,6 +4,7 @@ from contextlib import contextmanager import re import textwrap from typing import Any +from typing import Awaitable from typing import Callable from typing import Dict from typing import Iterator @@ -14,6 +15,7 @@ from typing import Sequence # noqa from typing import Tuple from typing import Type # noqa from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union from sqlalchemy.sql.elements import conv @@ -28,8 +30,6 @@ from ..util.compat import inspect_getfullargspec from ..util.sqla_compat import _literal_bindparam -NoneType = type(None) - if TYPE_CHECKING: from typing import Literal @@ -51,6 +51,7 @@ if TYPE_CHECKING: from ..ddl import DefaultImpl from ..runtime.migration import MigrationContext __all__ = ("Operations", "BatchOperations") +_T = TypeVar("_T") class AbstractOperations(util.ModuleClsProxy): @@ -483,6 +484,46 @@ class AbstractOperations(util.ModuleClsProxy): """ return self.migration_context.impl.bind # type: ignore[return-value] + def run_async( + self, + async_function: Callable[..., Awaitable[_T]], + *args: Any, + **kw_args: Any, + ) -> _T: + """Invoke the given asynchronous callable, passing an asynchronous + :class:`~sqlalchemy.ext.asyncio.AsyncConnection` as the first + argument. + + This method allows calling async functions from within the + synchronous ``upgrade()`` or ``downgrade()`` alembic migration + method. + + The async connection passed to the callable shares the same + transaction as the connection running in the migration context. + + Any additional arg or kw_arg passed to this function are passed + to the provided async function. + + .. versionadded: 1.11 + + .. note:: + + This method can be called only when alembic is called using + an async dialect. + """ + if not sqla_compat.sqla_14_18: + raise NotImplementedError("SQLAlchemy 1.4.18+ required") + sync_conn = self.get_bind() + if sync_conn is None: + raise NotImplementedError("Cannot call run_async in SQL mode") + if not sync_conn.dialect.is_async: + raise ValueError("Cannot call run_async with a sync engine") + from sqlalchemy.ext.asyncio import AsyncConnection + from sqlalchemy.util import await_only + + async_conn = AsyncConnection._retrieve_proxy_for_target(sync_conn) + return await_only(async_function(async_conn, *args, **kw_args)) + class Operations(AbstractOperations): """Define high level migration operations. diff --git a/alembic/operations/ops.py b/alembic/operations/ops.py index 99d21d9..3a002c1 100644 --- a/alembic/operations/ops.py +++ b/alembic/operations/ops.py @@ -2375,6 +2375,7 @@ class BulkInsertOp(MigrateOperation): @Operations.register_operation("execute") +@BatchOperations.register_operation("execute") class ExecuteSQLOp(MigrateOperation): """Represent an execute SQL operation.""" diff --git a/alembic/util/sqla_compat.py b/alembic/util/sqla_compat.py index 0070337..37e1ee1 100644 --- a/alembic/util/sqla_compat.py +++ b/alembic/util/sqla_compat.py @@ -61,6 +61,8 @@ _vers = tuple( ) sqla_13 = _vers >= (1, 3) sqla_14 = _vers >= (1, 4) +# https://docs.sqlalchemy.org/en/latest/changelog/changelog_14.html#change-0c6e0cc67dfe6fac5164720e57ef307d +sqla_14_18 = _vers >= (1, 4, 18) sqla_14_26 = _vers >= (1, 4, 26) sqla_2 = _vers >= (2,) sqlalchemy_version = __version__ diff --git a/docs/build/unreleased/1231.rst b/docs/build/unreleased/1231.rst new file mode 100644 index 0000000..37678ca --- /dev/null +++ b/docs/build/unreleased/1231.rst @@ -0,0 +1,11 @@ + +.. change:: + :tags: usecase, asyncio + :tickets: 1231 + + Added :meth:`.Operations.run_async` to the operation module to allow + running async functions in the ``upgrade`` or ``downgrade`` migration + function when running alembic using an async dialect. + This function will receive as first argument an + :class:`~sqlalchemy.ext.asyncio.AsyncConnection` sharing the transaction + used in the migration context. diff --git a/tests/test_op.py b/tests/test_op.py index 8ae22a0..35adeaf 100644 --- a/tests/test_op.py +++ b/tests/test_op.py @@ -1,5 +1,8 @@ """Test against the builders in the op.* module.""" +from unittest.mock import MagicMock +from unittest.mock import patch + from sqlalchemy import Boolean from sqlalchemy import CheckConstraint from sqlalchemy import Column @@ -30,6 +33,7 @@ from alembic.testing import eq_ from alembic.testing import expect_warnings from alembic.testing import is_not_ from alembic.testing import mock +from alembic.testing.assertions import expect_raises_message from alembic.testing.fixtures import op_fixture from alembic.testing.fixtures import TestBase from alembic.util import sqla_compat @@ -1156,6 +1160,46 @@ class OpTest(TestBase): ("after_drop", "tb_test"), ] + @config.requirements.sqlalchemy_14 + def test_run_async_error(self): + op_fixture() + + async def go(conn): + pass + + with expect_raises_message( + NotImplementedError, "SQLAlchemy 1.4.18. required" + ): + with patch.object(sqla_compat, "sqla_14_18", False): + op.run_async(go) + with expect_raises_message( + NotImplementedError, "Cannot call run_async in SQL mode" + ): + with patch.object(op._proxy, "get_bind", lambda: None): + op.run_async(go) + with expect_raises_message( + ValueError, "Cannot call run_async with a sync engine" + ): + op.run_async(go) + + @config.requirements.sqlalchemy_14 + def test_run_async_ok(self): + from sqlalchemy.ext.asyncio import AsyncConnection + + op_fixture() + conn = op.get_bind() + mock_conn = MagicMock() + mock_fn = MagicMock() + with patch.object(conn.dialect, "is_async", True), patch.object( + AsyncConnection, "_retrieve_proxy_for_target", mock_conn + ), patch("sqlalchemy.util.await_only") as mock_await: + res = op.run_async(mock_fn, 99, foo=42) + + eq_(res, mock_await.return_value) + mock_conn.assert_called_once_with(conn) + mock_await.assert_called_once_with(mock_fn.return_value) + mock_fn.assert_called_once_with(mock_conn.return_value, 99, foo=42) + class SQLModeOpTest(TestBase): def test_auto_literals(self): diff --git a/tools/write_pyi.py b/tools/write_pyi.py index 7d24870..82ceead 100644 --- a/tools/write_pyi.py +++ b/tools/write_pyi.py @@ -28,7 +28,6 @@ if True: # avoid flake/zimports messing with the order from alembic.operations import ops import sqlalchemy as sa - TRIM_MODULE = [ "alembic.runtime.migration.", "alembic.operations.base.", @@ -179,9 +178,12 @@ def _generate_stub_for_meth( retval = repr(annotation).replace("typing.", "") elif isinstance(annotation, type): retval = annotation.__qualname__ + elif isinstance(annotation, typing.TypeVar): + retval = annotation.__name__ else: retval = annotation + retval = retval.replace("~", "") # typevar repr as "~T" for trim in TRIM_MODULE: retval = retval.replace(trim, "") @@ -371,6 +373,7 @@ cls_ignore = { "inline_literal", "invoke", "register_operation", + "run_async", } cases = [ |