summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/build/changelog/unreleased_20/type_scalars.rst6
-rw-r--r--lib/sqlalchemy/engine/base.py6
-rw-r--r--lib/sqlalchemy/ext/asyncio/engine.py6
-rw-r--r--test/ext/mypy/plain_files/typed_results.py69
4 files changed, 81 insertions, 6 deletions
diff --git a/doc/build/changelog/unreleased_20/type_scalars.rst b/doc/build/changelog/unreleased_20/type_scalars.rst
new file mode 100644
index 000000000..d983e1580
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/type_scalars.rst
@@ -0,0 +1,6 @@
+.. change::
+ :tags: bug, typing
+
+ Fixed bug where the :meth:`_engine.Connection.scalars` method was not typed
+ as allowing a multiple-parameters list, which is now supported using
+ insertmanyvalues operations.
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index f6c637aa8..926a08b76 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -1306,7 +1306,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
def scalars(
self,
statement: TypedReturnsRows[Tuple[_T]],
- parameters: Optional[_CoreSingleExecuteParams] = None,
+ parameters: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> ScalarResult[_T]:
@@ -1316,7 +1316,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
def scalars(
self,
statement: Executable,
- parameters: Optional[_CoreSingleExecuteParams] = None,
+ parameters: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> ScalarResult[Any]:
@@ -1325,7 +1325,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
def scalars(
self,
statement: Executable,
- parameters: Optional[_CoreSingleExecuteParams] = None,
+ parameters: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> ScalarResult[Any]:
diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py
index 86e257bdd..325c58bda 100644
--- a/lib/sqlalchemy/ext/asyncio/engine.py
+++ b/lib/sqlalchemy/ext/asyncio/engine.py
@@ -646,7 +646,7 @@ class AsyncConnection(
async def scalars(
self,
statement: TypedReturnsRows[Tuple[_T]],
- parameters: Optional[_CoreSingleExecuteParams] = None,
+ parameters: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> ScalarResult[_T]:
@@ -656,7 +656,7 @@ class AsyncConnection(
async def scalars(
self,
statement: Executable,
- parameters: Optional[_CoreSingleExecuteParams] = None,
+ parameters: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> ScalarResult[Any]:
@@ -665,7 +665,7 @@ class AsyncConnection(
async def scalars(
self,
statement: Executable,
- parameters: Optional[_CoreSingleExecuteParams] = None,
+ parameters: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> ScalarResult[Any]:
diff --git a/test/ext/mypy/plain_files/typed_results.py b/test/ext/mypy/plain_files/typed_results.py
index 8fd9e5cd1..2e42bb655 100644
--- a/test/ext/mypy/plain_files/typed_results.py
+++ b/test/ext/mypy/plain_files/typed_results.py
@@ -6,6 +6,7 @@ from typing import cast
from sqlalchemy import Column
from sqlalchemy import column
from sqlalchemy import create_engine
+from sqlalchemy import insert
from sqlalchemy import Integer
from sqlalchemy import select
from sqlalchemy import table
@@ -249,6 +250,74 @@ async def t_async_result_scalar_accessors() -> None:
reveal_type(r5)
+def t_result_insertmanyvalues_scalars() -> None:
+ stmt = insert(User).returning(User.id)
+
+ uids1 = connection.scalars(
+ stmt,
+ [
+ {"name": "n1"},
+ {"name": "n2"},
+ {"name": "n3"},
+ ],
+ ).all()
+
+ # EXPECTED_TYPE: Sequence[int]
+ reveal_type(uids1)
+
+ uids2 = (
+ connection.execute(
+ stmt,
+ [
+ {"name": "n1"},
+ {"name": "n2"},
+ {"name": "n3"},
+ ],
+ )
+ .scalars()
+ .all()
+ )
+
+ # EXPECTED_TYPE: Sequence[int]
+ reveal_type(uids2)
+
+
+async def t_async_result_insertmanyvalues_scalars() -> None:
+ stmt = insert(User).returning(User.id)
+
+ uids1 = (
+ await async_connection.scalars(
+ stmt,
+ [
+ {"name": "n1"},
+ {"name": "n2"},
+ {"name": "n3"},
+ ],
+ )
+ ).all()
+
+ # EXPECTED_TYPE: Sequence[int]
+ reveal_type(uids1)
+
+ uids2 = (
+ (
+ await async_connection.execute(
+ stmt,
+ [
+ {"name": "n1"},
+ {"name": "n2"},
+ {"name": "n3"},
+ ],
+ )
+ )
+ .scalars()
+ .all()
+ )
+
+ # EXPECTED_TYPE: Sequence[int]
+ reveal_type(uids2)
+
+
def t_connection_execute_multi_row_t() -> None:
result = connection.execute(multi_stmt)