summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/build/changelog/unreleased_20/more_typing.rst11
-rw-r--r--lib/sqlalchemy/engine/result.py2
-rw-r--r--test/ext/mypy/plain_files/typed_results.py22
3 files changed, 33 insertions, 2 deletions
diff --git a/doc/build/changelog/unreleased_20/more_typing.rst b/doc/build/changelog/unreleased_20/more_typing.rst
index b958d0d91..62cd04e8b 100644
--- a/doc/build/changelog/unreleased_20/more_typing.rst
+++ b/doc/build/changelog/unreleased_20/more_typing.rst
@@ -27,4 +27,13 @@
:tickets: 9125
Fixed typing issue where iterating over a :class:`_orm.Query` object
- was not correctly typed.
+ was not correctly typed.
+
+.. change::
+ :tags: typing, bug
+ :tickets: 9136
+
+ Fixed typing issue where the object type when using :class:`_engine.Result`
+ as a context manager were not preserved, indicating :class:`_engine.Result`
+ in all cases rather than the specific :class:`_engine.Result` sub-type.
+ Pull request courtesy Martin Baláž.
diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py
index 4bf03ae69..67151913e 100644
--- a/lib/sqlalchemy/engine/result.py
+++ b/lib/sqlalchemy/engine/result.py
@@ -929,7 +929,7 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]):
def __init__(self, cursor_metadata: ResultMetaData):
self._metadata = cursor_metadata
- def __enter__(self) -> Result[_TP]:
+ def __enter__(self: SelfResult) -> SelfResult:
return self
def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
diff --git a/test/ext/mypy/plain_files/typed_results.py b/test/ext/mypy/plain_files/typed_results.py
index 262e5b5ff..8fd9e5cd1 100644
--- a/test/ext/mypy/plain_files/typed_results.py
+++ b/test/ext/mypy/plain_files/typed_results.py
@@ -77,6 +77,28 @@ multi_stmt = select(User.id, User.name).where(User.name == "foo")
reveal_type(multi_stmt)
+def t_result_ctxmanager() -> None:
+ with connection.execute(select(column("q", Integer))) as r1:
+ # EXPECTED_TYPE: CursorResult[Tuple[int]]
+ reveal_type(r1)
+
+ with r1.mappings() as r1m:
+ # EXPECTED_TYPE: MappingResult
+ reveal_type(r1m)
+
+ with connection.scalars(select(column("q", Integer))) as r2:
+ # EXPECTED_TYPE: ScalarResult[int]
+ reveal_type(r2)
+
+ with session.execute(select(User.id)) as r3:
+ # EXPECTED_TYPE: Result[Tuple[int]]
+ reveal_type(r3)
+
+ with session.scalars(select(User.id)) as r4:
+ # EXPECTED_TYPE: ScalarResult[int]
+ reveal_type(r4)
+
+
def t_entity_varieties() -> None:
a1 = aliased(User)