summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFederico Caselli <cfederico87@gmail.com>2023-04-13 22:22:14 +0200
committerFederico Caselli <cfederico87@gmail.com>2023-04-13 22:23:03 +0200
commit4c5f80b0217c2d5b778ab2c5c34d431206d7a743 (patch)
tree0e28ea0b78d8fe7377fa78338847656bf0de9fcb
parent6d5ac4fdf0066fb073636233d8cd334f8f02b5a3 (diff)
downloadalembic-4c5f80b0217c2d5b778ab2c5c34d431206d7a743.tar.gz
Improve typing.
Correctly pass previously ignored arguments ``insert_before`` and ``insert_after`` in ``batch_alter_column`` Fixes: #1221 Change-Id: I79c9144f3e521fca00a0c32462ae2a69f9f7a032
-rw-r--r--alembic/autogenerate/api.py40
-rw-r--r--alembic/autogenerate/compare.py4
-rw-r--r--alembic/context.pyi26
-rw-r--r--alembic/ddl/impl.py6
-rw-r--r--alembic/op.pyi5
-rw-r--r--alembic/operations/base.py5
-rw-r--r--alembic/operations/batch.py2
-rw-r--r--alembic/operations/ops.py24
-rw-r--r--alembic/runtime/environment.py56
-rw-r--r--alembic/runtime/migration.py10
-rw-r--r--alembic/util/sqla_compat.py4
-rw-r--r--docs/build/unreleased/1221.rst6
-rw-r--r--tests/test_mysql.py2
-rw-r--r--tools/write_pyi.py3
14 files changed, 121 insertions, 72 deletions
diff --git a/alembic/autogenerate/api.py b/alembic/autogenerate/api.py
index d7a0913..9a3b003 100644
--- a/alembic/autogenerate/api.py
+++ b/alembic/autogenerate/api.py
@@ -9,7 +9,6 @@ from typing import Optional
from typing import Set
from typing import Tuple
from typing import TYPE_CHECKING
-from typing import Union
from sqlalchemy import inspect
@@ -25,19 +24,18 @@ if TYPE_CHECKING:
from sqlalchemy.engine import Connection
from sqlalchemy.engine import Dialect
from sqlalchemy.engine import Inspector
- from sqlalchemy.sql.schema import Column
- from sqlalchemy.sql.schema import ForeignKeyConstraint
- from sqlalchemy.sql.schema import Index
from sqlalchemy.sql.schema import MetaData
- from sqlalchemy.sql.schema import Table
- from sqlalchemy.sql.schema import UniqueConstraint
+ from sqlalchemy.sql.schema import SchemaItem
- from alembic.config import Config
- from alembic.operations.ops import MigrationScript
- from alembic.operations.ops import UpgradeOps
- from alembic.runtime.migration import MigrationContext
- from alembic.script.base import Script
- from alembic.script.base import ScriptDirectory
+ from ..config import Config
+ from ..operations.ops import MigrationScript
+ from ..operations.ops import UpgradeOps
+ from ..runtime.environment import NameFilterParentNames
+ from ..runtime.environment import NameFilterType
+ from ..runtime.environment import RenderItemFn
+ from ..runtime.migration import MigrationContext
+ from ..script.base import Script
+ from ..script.base import ScriptDirectory
def compare_metadata(context: MigrationContext, metadata: MetaData) -> Any:
@@ -172,7 +170,7 @@ def render_python_code(
alembic_module_prefix: str = "op.",
render_as_batch: bool = False,
imports: Tuple[str, ...] = (),
- render_item: None = None,
+ render_item: Optional[RenderItemFn] = None,
migration_context: Optional[MigrationContext] = None,
) -> str:
"""Render Python code given an :class:`.UpgradeOps` or
@@ -359,8 +357,8 @@ class AutogenContext:
def run_name_filters(
self,
name: Optional[str],
- type_: str,
- parent_names: Dict[str, Optional[str]],
+ type_: NameFilterType,
+ parent_names: NameFilterParentNames,
) -> bool:
"""Run the context's name filters and return True if the targets
should be part of the autogenerate operation.
@@ -396,17 +394,11 @@ class AutogenContext:
def run_object_filters(
self,
- object_: Union[
- Table,
- Index,
- Column,
- UniqueConstraint,
- ForeignKeyConstraint,
- ],
+ object_: SchemaItem,
name: Optional[str],
- type_: str,
+ type_: NameFilterType,
reflected: bool,
- compare_to: Optional[Union[Table, Index, Column, UniqueConstraint]],
+ compare_to: Optional[SchemaItem],
) -> bool:
"""Run the context's object filters and return True if the targets
should be part of the autogenerate operation.
diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py
index 85cb426..595631c 100644
--- a/alembic/autogenerate/compare.py
+++ b/alembic/autogenerate/compare.py
@@ -212,7 +212,7 @@ def _compare_tables(
(inspector),
# fmt: on
)
- sqla_compat._reflect_table(inspector, t, None)
+ sqla_compat._reflect_table(inspector, t)
if autogen_context.run_object_filters(t, tname, "table", True, None):
modify_table_ops = ops.ModifyTableOps(tname, [], schema=s)
@@ -243,7 +243,7 @@ def _compare_tables(
_compat_autogen_column_reflect(inspector),
# fmt: on
)
- sqla_compat._reflect_table(inspector, t, None)
+ sqla_compat._reflect_table(inspector, t)
conn_column_info[(s, tname)] = t
for s, tname in sorted(existing_tables, key=lambda x: (x[0] or "", x[1])):
diff --git a/alembic/context.pyi b/alembic/context.pyi
index a9f48b2..1007a5e 100644
--- a/alembic/context.pyi
+++ b/alembic/context.pyi
@@ -3,7 +3,6 @@
from __future__ import annotations
from typing import Any
-from typing import Callable
from typing import ContextManager
from typing import Dict
from typing import List
@@ -22,7 +21,11 @@ if TYPE_CHECKING:
from sqlalchemy.sql.schema import MetaData
from .config import Config
- from .operations import MigrateOperation
+ from .runtime.environment import IncludeNameFn
+ from .runtime.environment import IncludeObjectFn
+ from .runtime.environment import OnVersionApplyFn
+ from .runtime.environment import ProcessRevisionDirectiveFn
+ from .runtime.environment import RenderItemFn
from .runtime.migration import _ProxyTransaction
from .runtime.migration import MigrationContext
from .script import ScriptDirectory
@@ -76,7 +79,7 @@ config: Config
def configure(
connection: Optional[Connection] = None,
- url: Union[str, URL, None] = None,
+ url: Optional[Union[str, URL]] = None,
dialect_name: Optional[str] = None,
dialect_opts: Optional[Dict[str, Any]] = None,
transactional_ddl: Optional[bool] = None,
@@ -87,24 +90,20 @@ def configure(
template_args: Optional[Dict[str, Any]] = None,
render_as_batch: bool = False,
target_metadata: Optional[MetaData] = None,
- include_name: Optional[Callable[..., bool]] = None,
- include_object: Optional[Callable[..., bool]] = None,
+ include_name: Optional[IncludeNameFn] = None,
+ include_object: Optional[IncludeObjectFn] = None,
include_schemas: bool = False,
- process_revision_directives: Optional[
- Callable[
- [MigrationContext, Tuple[str, str], List[MigrateOperation]], None
- ]
- ] = None,
+ process_revision_directives: Optional[ProcessRevisionDirectiveFn] = None,
compare_type: bool = False,
compare_server_default: bool = False,
- render_item: Optional[Callable[..., bool]] = None,
+ render_item: Optional[RenderItemFn] = None,
literal_binds: bool = False,
upgrade_token: str = "upgrades",
downgrade_token: str = "downgrades",
alembic_module_prefix: str = "op.",
sqlalchemy_module_prefix: str = "sa.",
user_module_prefix: Optional[str] = None,
- on_version_apply: Optional[Callable[..., None]] = None,
+ on_version_apply: Optional[OnVersionApplyFn] = None,
**kw: Any,
) -> None:
"""Configure a :class:`.MigrationContext` within this
@@ -308,7 +307,8 @@ def configure(
``"unique_constraint"``, or ``"foreign_key_constraint"``
* ``parent_names``: a dictionary of "parent" object names, that are
relative to the name being given. Keys in this dictionary may
- include: ``"schema_name"``, ``"table_name"``.
+ include: ``"schema_name"``, ``"table_name"`` or
+ ``"schema_qualified_table_name"``.
E.g.::
diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py
index f11d1ed..84f5d86 100644
--- a/alembic/ddl/impl.py
+++ b/alembic/ddl/impl.py
@@ -155,9 +155,9 @@ class DefaultImpl(metaclass=ImplMeta):
def _exec(
self,
construct: Union[ClauseElement, str],
- execution_options: Optional[dict] = None,
+ execution_options: Optional[dict[str, Any]] = None,
multiparams: Sequence[dict] = (),
- params: Dict[str, int] = util.immutabledict(),
+ params: Dict[str, Any] = util.immutabledict(),
) -> Optional[CursorResult]:
if isinstance(construct, str):
construct = text(construct)
@@ -197,7 +197,7 @@ class DefaultImpl(metaclass=ImplMeta):
def execute(
self,
sql: Union[ClauseElement, str],
- execution_options: None = None,
+ execution_options: Optional[dict[str, Any]] = None,
) -> None:
self._exec(sql, execution_options)
diff --git a/alembic/op.pyi b/alembic/op.pyi
index dc94113..dab5856 100644
--- a/alembic/op.pyi
+++ b/alembic/op.pyi
@@ -951,7 +951,8 @@ def drop_table_comment(
"""
def execute(
- sqltext: Union[str, TextClause, Update], execution_options: None = None
+ sqltext: Union[str, TextClause, Update],
+ execution_options: Optional[dict[str, Any]] = None,
) -> Optional[Table]:
r"""Execute the given SQL using the current migration context.
@@ -1101,7 +1102,7 @@ def implementation_for(op_cls: Any) -> Callable[..., Any]:
"""
def inline_literal(
- value: Union[str, int], type_: None = None
+ value: Union[str, int], type_: Optional[TypeEngine] = None
) -> _literal_bindparam:
r"""Produce an 'inline literal' expression, suitable for
using in an INSERT, UPDATE, or DELETE statement.
diff --git a/alembic/operations/base.py b/alembic/operations/base.py
index 04b66b5..82d9779 100644
--- a/alembic/operations/base.py
+++ b/alembic/operations/base.py
@@ -33,8 +33,9 @@ NoneType = type(None)
if TYPE_CHECKING:
from typing import Literal
- from sqlalchemy import Table # noqa
+ from sqlalchemy import Table
from sqlalchemy.engine import Connection
+ from sqlalchemy.types import TypeEngine
from .batch import BatchOperationsImpl
from .ops import MigrateOperation
@@ -439,7 +440,7 @@ class Operations(util.ModuleClsProxy):
return conv(name)
def inline_literal(
- self, value: Union[str, int], type_: None = None
+ self, value: Union[str, int], type_: Optional[TypeEngine[Any]] = None
) -> _literal_bindparam:
r"""Produce an 'inline literal' expression, suitable for
using in an INSERT, UPDATE, or DELETE statement.
diff --git a/alembic/operations/batch.py b/alembic/operations/batch.py
index 00f13a1..f4a058b 100644
--- a/alembic/operations/batch.py
+++ b/alembic/operations/batch.py
@@ -487,7 +487,7 @@ class ApplyBatchImpl:
server_default: Optional[Union[Function[Any], str, bool]] = False,
name: Optional[str] = None,
type_: Optional[TypeEngine] = None,
- autoincrement: None = None,
+ autoincrement: Optional[Union[bool, Literal["auto"]]] = None,
comment: Union[str, Literal[False]] = False,
**kw,
) -> None:
diff --git a/alembic/operations/ops.py b/alembic/operations/ops.py
index b3ef5bb..7dd65a1 100644
--- a/alembic/operations/ops.py
+++ b/alembic/operations/ops.py
@@ -673,11 +673,11 @@ class CreateForeignKeyOp(AddConstraintOp):
local_cols: List[str],
remote_cols: List[str],
referent_schema: Optional[str] = None,
- onupdate: None = None,
- ondelete: None = None,
- deferrable: None = None,
- initially: None = None,
- match: None = None,
+ onupdate: Optional[str] = None,
+ ondelete: Optional[str] = None,
+ deferrable: Optional[bool] = None,
+ initially: Optional[str] = None,
+ match: Optional[str] = None,
**dialect_kw: Any,
) -> None:
"""Issue a "create foreign key" instruction using the
@@ -1890,10 +1890,10 @@ class AlterColumnOp(AlterTableOp):
type_: Optional[Union[TypeEngine, Type[TypeEngine]]] = None,
existing_type: Optional[Union[TypeEngine, Type[TypeEngine]]] = None,
existing_server_default: bool = False,
- existing_nullable: None = None,
- existing_comment: None = None,
- insert_before: None = None,
- insert_after: None = None,
+ existing_nullable: Optional[bool] = None,
+ existing_comment: Optional[str] = None,
+ insert_before: Optional[str] = None,
+ insert_after: Optional[str] = None,
**kw: Any,
) -> Optional[Table]:
"""Issue an "alter column" instruction using the current
@@ -1935,6 +1935,8 @@ class AlterColumnOp(AlterTableOp):
modify_server_default=server_default,
modify_nullable=nullable,
modify_comment=comment,
+ insert_before=insert_before,
+ insert_after=insert_after,
**kw,
)
@@ -2314,7 +2316,7 @@ class ExecuteSQLOp(MigrateOperation):
def __init__(
self,
sqltext: Union[Update, str, Insert, TextClause],
- execution_options: None = None,
+ execution_options: Optional[dict[str, Any]] = None,
) -> None:
self.sqltext = sqltext
self.execution_options = execution_options
@@ -2324,7 +2326,7 @@ class ExecuteSQLOp(MigrateOperation):
cls,
operations: Operations,
sqltext: Union[str, TextClause, Update],
- execution_options: None = None,
+ execution_options: Optional[dict[str, Any]] = None,
) -> Optional[Table]:
r"""Execute the given SQL using the current migration context.
diff --git a/alembic/runtime/environment.py b/alembic/runtime/environment.py
index c2fa11a..f5c177e 100644
--- a/alembic/runtime/environment.py
+++ b/alembic/runtime/environment.py
@@ -2,9 +2,12 @@ from __future__ import annotations
from typing import Any
from typing import Callable
+from typing import Collection
from typing import ContextManager
from typing import Dict
from typing import List
+from typing import Mapping
+from typing import MutableMapping
from typing import Optional
from typing import overload
from typing import TextIO
@@ -12,19 +15,23 @@ from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
+from typing_extensions import Literal
+
from .migration import _ProxyTransaction
from .migration import MigrationContext
from .. import util
from ..operations import Operations
if TYPE_CHECKING:
- from typing import Literal
from sqlalchemy.engine import URL
from sqlalchemy.engine.base import Connection
from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy.sql.schema import MetaData
+ from sqlalchemy.sql.schema import SchemaItem
+ from .migration import MigrationInfo
+ from ..autogenerate.api import AutogenContext
from ..config import Config
from ..ddl import DefaultImpl
from ..operations.ops import MigrateOperation
@@ -36,6 +43,42 @@ ProcessRevisionDirectiveFn = Callable[
[MigrationContext, Tuple[str, str], List["MigrateOperation"]], None
]
+RenderItemFn = Callable[
+ [str, Any, "AutogenContext"], Union[str, Literal[False]]
+]
+
+NameFilterType = Literal[
+ "schema",
+ "table",
+ "column",
+ "index",
+ "unique_constraint",
+ "foreign_key_constraint",
+]
+NameFilterParentNames = MutableMapping[
+ Literal["schema_name", "table_name", "schema_qualified_table_name"],
+ Optional[str],
+]
+IncludeNameFn = Callable[
+ [Optional[str], NameFilterType, NameFilterParentNames], bool
+]
+
+IncludeObjectFn = Callable[
+ [
+ "SchemaItem",
+ Optional[str],
+ NameFilterType,
+ bool,
+ Optional["SchemaItem"],
+ ],
+ bool,
+]
+
+OnVersionApplyFn = Callable[
+ [MigrationContext, "MigrationInfo", Collection[Any], Mapping[str, Any]],
+ None,
+]
+
class EnvironmentContext(util.ModuleClsProxy):
@@ -346,22 +389,22 @@ class EnvironmentContext(util.ModuleClsProxy):
template_args: Optional[Dict[str, Any]] = None,
render_as_batch: bool = False,
target_metadata: Optional[MetaData] = None,
- include_name: Optional[Callable[..., bool]] = None,
- include_object: Optional[Callable[..., bool]] = None,
+ include_name: Optional[IncludeNameFn] = None,
+ include_object: Optional[IncludeObjectFn] = None,
include_schemas: bool = False,
process_revision_directives: Optional[
ProcessRevisionDirectiveFn
] = None,
compare_type: bool = False,
compare_server_default: bool = False,
- render_item: Optional[Callable[..., bool]] = None,
+ render_item: Optional[RenderItemFn] = None,
literal_binds: bool = False,
upgrade_token: str = "upgrades",
downgrade_token: str = "downgrades",
alembic_module_prefix: str = "op.",
sqlalchemy_module_prefix: str = "sa.",
user_module_prefix: Optional[str] = None,
- on_version_apply: Optional[Callable[..., None]] = None,
+ on_version_apply: Optional[OnVersionApplyFn] = None,
**kw: Any,
) -> None:
"""Configure a :class:`.MigrationContext` within this
@@ -565,7 +608,8 @@ class EnvironmentContext(util.ModuleClsProxy):
``"unique_constraint"``, or ``"foreign_key_constraint"``
* ``parent_names``: a dictionary of "parent" object names, that are
relative to the name being given. Keys in this dictionary may
- include: ``"schema_name"``, ``"table_name"``.
+ include: ``"schema_name"``, ``"table_name"`` or
+ ``"schema_qualified_table_name"``.
E.g.::
diff --git a/alembic/runtime/migration.py b/alembic/runtime/migration.py
index 4e2d062..cfba0e3 100644
--- a/alembic/runtime/migration.py
+++ b/alembic/runtime/migration.py
@@ -5,10 +5,12 @@ from contextlib import nullcontext
import logging
import sys
from typing import Any
+from typing import Callable
from typing import cast
from typing import Collection
from typing import ContextManager
from typing import Dict
+from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
@@ -74,7 +76,7 @@ class _ProxyTransaction:
def __enter__(self) -> _ProxyTransaction:
return self
- def __exit__(self, type_: None, value: None, traceback: None) -> None:
+ def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
if self._proxied_transaction is not None:
self._proxied_transaction.__exit__(type_, value, traceback)
self.migration_context._transaction = None
@@ -158,7 +160,9 @@ class MigrationContext:
sqla_compat._get_connection_in_transaction(connection)
)
- self._migrations_fn = opts.get("fn")
+ self._migrations_fn: Optional[
+ Callable[..., Iterable[RevisionStep]]
+ ] = opts.get("fn")
self.as_sql = as_sql
self.purge = opts.get("purge", False)
@@ -1275,7 +1279,7 @@ class StampStep(MigrationStep):
self.migration_fn = self.stamp_revision
self.revision_map = revision_map
- doc: None = None
+ doc: Optional[str] = None
def stamp_revision(self, **kw: Any) -> None:
return None
diff --git a/alembic/util/sqla_compat.py b/alembic/util/sqla_compat.py
index cab9949..e2725d6 100644
--- a/alembic/util/sqla_compat.py
+++ b/alembic/util/sqla_compat.py
@@ -299,9 +299,7 @@ def _columns_for_constraint(constraint):
return list(constraint.columns)
-def _reflect_table(
- inspector: Inspector, table: Table, include_cols: None
-) -> None:
+def _reflect_table(inspector: Inspector, table: Table) -> None:
if sqla_14:
return inspector.reflect_table(table, None)
else:
diff --git a/docs/build/unreleased/1221.rst b/docs/build/unreleased/1221.rst
new file mode 100644
index 0000000..de14f15
--- /dev/null
+++ b/docs/build/unreleased/1221.rst
@@ -0,0 +1,6 @@
+.. change::
+ :tags: bug, batch
+ :tickets: 1221
+
+ Correctly pass previously ignored arguments ``insert_before`` and
+ ``insert_after`` in ``batch_alter_column``
diff --git a/tests/test_mysql.py b/tests/test_mysql.py
index 2145fd7..92c1819 100644
--- a/tests/test_mysql.py
+++ b/tests/test_mysql.py
@@ -627,7 +627,7 @@ class MySQLDefaultCompareTest(TestBase):
insp = inspect(self.bind)
cols = insp.get_columns(t1.name)
refl = Table(t1.name, MetaData())
- sqla_compat._reflect_table(insp, refl, None)
+ sqla_compat._reflect_table(insp, refl)
ctx = self.autogen_context["context"]
return ctx.impl.compare_server_default(
refl.c[cols[0]["name"]], col, rendered, cols[0]["default"]
diff --git a/tools/write_pyi.py b/tools/write_pyi.py
index 4fbf366..fa79c49 100644
--- a/tools/write_pyi.py
+++ b/tools/write_pyi.py
@@ -109,7 +109,8 @@ def generate_pyi_for_proxy(
# Do not generate the base implementation to avoid mypy errors
overloads = typing.get_overloads(meth)
if overloads:
- # use enumerate so we can generate docs on the last overload
+ # use enumerate so we can generate docs on the
+ # last overload
for i, ovl in enumerate(overloads, 1):
_generate_stub_for_meth(
ovl,