summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-04-19 21:06:41 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-04-27 14:46:36 -0400
commitad11c482e2233f44e8747d4d5a2b17a995fff1fa (patch)
tree57f8ddd30928951519fd6ac0f418e9cbf8e65610 /lib/sqlalchemy
parent033d1a16e7a220555d7611a5b8cacb1bd83822ae (diff)
downloadsqlalchemy-ad11c482e2233f44e8747d4d5a2b17a995fff1fa.tar.gz
pep484 ORM / SQL result support
after some experimentation it seems mypy is more amenable to the generic types being fully integrated rather than having separate spin-off types. so key structures like Result, Row, Select become generic. For DML Insert, Update, Delete, these are spun into type-specific subclasses ReturningInsert, ReturningUpdate, ReturningDelete, which is fine since the "row-ness" of these constructs doesn't happen until returning() is called in any case. a Tuple based model is then integrated so that these objects can carry along information about their return types. Overloads at the .execute() level carry through the Tuple from the invoked object to the result. To suit the issue of AliasedClass generating attributes that are dynamic, experimented with a custom subclass AsAliased, but then just settled on having aliased() lie to the type checker and return `Type[_O]`, essentially. will need some type-related accessors for with_polymorphic() also. Additionally, identified an issue in Update when used "mysql style" against a join(), it basically doesn't work if asked to UPDATE two tables on the same column name. added an error message to the specific condition where it happens with a very non-specific error message that we hit a thing we can't do right now, suggest multi-table update as a possible cause. Change-Id: I5eff7eefe1d6166ee74160b2785c5e6a81fa8b95
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/engine/__init__.py1
-rw-r--r--lib/sqlalchemy/engine/base.py89
-rw-r--r--lib/sqlalchemy/engine/cursor.py24
-rw-r--r--lib/sqlalchemy/engine/default.py14
-rw-r--r--lib/sqlalchemy/engine/events.py3
-rw-r--r--lib/sqlalchemy/engine/interfaces.py2
-rw-r--r--lib/sqlalchemy/engine/result.py422
-rw-r--r--lib/sqlalchemy/engine/row.py51
-rw-r--r--lib/sqlalchemy/ext/asyncio/engine.py132
-rw-r--r--lib/sqlalchemy/ext/asyncio/result.py339
-rw-r--r--lib/sqlalchemy/ext/asyncio/scoping.py135
-rw-r--r--lib/sqlalchemy/ext/asyncio/session.py136
-rw-r--r--lib/sqlalchemy/ext/instrumentation.py3
-rw-r--r--lib/sqlalchemy/orm/_orm_constructors.py30
-rw-r--r--lib/sqlalchemy/orm/attributes.py5
-rw-r--r--lib/sqlalchemy/orm/base.py16
-rw-r--r--lib/sqlalchemy/orm/context.py7
-rw-r--r--lib/sqlalchemy/orm/interfaces.py13
-rw-r--r--lib/sqlalchemy/orm/mapper.py13
-rw-r--r--lib/sqlalchemy/orm/properties.py5
-rw-r--r--lib/sqlalchemy/orm/query.py296
-rw-r--r--lib/sqlalchemy/orm/scoping.py219
-rw-r--r--lib/sqlalchemy/orm/session.py211
-rw-r--r--lib/sqlalchemy/orm/state.py6
-rw-r--r--lib/sqlalchemy/orm/util.py53
-rw-r--r--lib/sqlalchemy/sql/__init__.py1
-rw-r--r--lib/sqlalchemy/sql/_selectable_constructors.py166
-rw-r--r--lib/sqlalchemy/sql/_typing.py78
-rw-r--r--lib/sqlalchemy/sql/base.py15
-rw-r--r--lib/sqlalchemy/sql/coercions.py25
-rw-r--r--lib/sqlalchemy/sql/compiler.py40
-rw-r--r--lib/sqlalchemy/sql/crud.py32
-rw-r--r--lib/sqlalchemy/sql/dml.py376
-rw-r--r--lib/sqlalchemy/sql/elements.py36
-rw-r--r--lib/sqlalchemy/sql/functions.py8
-rw-r--r--lib/sqlalchemy/sql/roles.py57
-rw-r--r--lib/sqlalchemy/sql/schema.py9
-rw-r--r--lib/sqlalchemy/sql/selectable.py287
-rw-r--r--lib/sqlalchemy/sql/util.py9
-rw-r--r--lib/sqlalchemy/sql/visitors.py4
-rw-r--r--lib/sqlalchemy/util/langhelpers.py44
-rw-r--r--lib/sqlalchemy/util/typing.py2
42 files changed, 2939 insertions, 475 deletions
diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py
index 29dd6aff9..afba17075 100644
--- a/lib/sqlalchemy/engine/__init__.py
+++ b/lib/sqlalchemy/engine/__init__.py
@@ -46,6 +46,7 @@ from .result import MergedResult as MergedResult
from .result import Result as Result
from .result import result_tuple as result_tuple
from .result import ScalarResult as ScalarResult
+from .result import TupleResult as TupleResult
from .row import BaseRow as BaseRow
from .row import Row as Row
from .row import RowMapping as RowMapping
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index a325da929..fe3bfa1ad 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -18,8 +18,10 @@ from typing import Mapping
from typing import MutableMapping
from typing import NoReturn
from typing import Optional
+from typing import overload
from typing import Tuple
from typing import Type
+from typing import TypeVar
from typing import Union
from .interfaces import _IsolationLevel
@@ -45,12 +47,10 @@ if typing.TYPE_CHECKING:
from . import ScalarResult
from .interfaces import _AnyExecuteParams
from .interfaces import _AnyMultiExecuteParams
- from .interfaces import _AnySingleExecuteParams
from .interfaces import _CoreAnyExecuteParams
from .interfaces import _CoreMultiExecuteParams
from .interfaces import _CoreSingleExecuteParams
from .interfaces import _DBAPIAnyExecuteParams
- from .interfaces import _DBAPIMultiExecuteParams
from .interfaces import _DBAPISingleExecuteParams
from .interfaces import _ExecuteOptions
from .interfaces import _ExecuteOptionsParameter
@@ -65,21 +65,21 @@ if typing.TYPE_CHECKING:
from ..pool import PoolProxiedConnection
from ..sql import Executable
from ..sql._typing import _InfoType
- from ..sql.base import SchemaVisitor
from ..sql.compiler import Compiled
from ..sql.ddl import ExecutableDDLElement
from ..sql.ddl import SchemaDropper
from ..sql.ddl import SchemaGenerator
from ..sql.functions import FunctionElement
- from ..sql.schema import ColumnDefault
from ..sql.schema import DefaultGenerator
from ..sql.schema import HasSchemaAttr
from ..sql.schema import SchemaItem
+ from ..sql.selectable import TypedReturnsRows
"""Defines :class:`_engine.Connection` and :class:`_engine.Engine`.
"""
+_T = TypeVar("_T", bound=Any)
_EMPTY_EXECUTION_OPTS: _ExecuteOptions = util.EMPTY_DICT
NO_OPTIONS: Mapping[str, Any] = util.EMPTY_DICT
@@ -1142,10 +1142,31 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
self._dbapi_connection = None
self.__can_reconnect = False
+ @overload
+ def scalar(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ parameters: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> Optional[_T]:
+ ...
+
+ @overload
+ def scalar(
+ self,
+ statement: Executable,
+ parameters: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> Any:
+ ...
+
def scalar(
self,
statement: Executable,
parameters: Optional[_CoreSingleExecuteParams] = None,
+ *,
execution_options: Optional[_ExecuteOptionsParameter] = None,
) -> Any:
r"""Executes a SQL statement construct and returns a scalar object.
@@ -1170,10 +1191,31 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
execution_options or NO_OPTIONS,
)
+ @overload
+ def scalars(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ parameters: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> ScalarResult[_T]:
+ ...
+
+ @overload
def scalars(
self,
statement: Executable,
parameters: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> ScalarResult[Any]:
+ ...
+
+ def scalars(
+ self,
+ statement: Executable,
+ parameters: Optional[_CoreSingleExecuteParams] = None,
+ *,
execution_options: Optional[_ExecuteOptionsParameter] = None,
) -> ScalarResult[Any]:
"""Executes and returns a scalar result set, which yields scalar values
@@ -1190,14 +1232,37 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
"""
- return self.execute(statement, parameters, execution_options).scalars()
+ return self.execute(
+ statement, parameters, execution_options=execution_options
+ ).scalars()
+
+ @overload
+ def execute(
+ self,
+ statement: TypedReturnsRows[_T],
+ parameters: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> CursorResult[_T]:
+ ...
+
+ @overload
+ def execute(
+ self,
+ statement: Executable,
+ parameters: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> CursorResult[Any]:
+ ...
def execute(
self,
statement: Executable,
parameters: Optional[_CoreAnyExecuteParams] = None,
+ *,
execution_options: Optional[_ExecuteOptionsParameter] = None,
- ) -> CursorResult:
+ ) -> CursorResult[Any]:
r"""Executes a SQL statement construct and returns a
:class:`_engine.CursorResult`.
@@ -1246,7 +1311,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
func: FunctionElement[Any],
distilled_parameters: _CoreMultiExecuteParams,
execution_options: _ExecuteOptionsParameter,
- ) -> CursorResult:
+ ) -> CursorResult[Any]:
"""Execute a sql.FunctionElement object."""
return self._execute_clauseelement(
@@ -1317,7 +1382,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
ddl: ExecutableDDLElement,
distilled_parameters: _CoreMultiExecuteParams,
execution_options: _ExecuteOptionsParameter,
- ) -> CursorResult:
+ ) -> CursorResult[Any]:
"""Execute a schema.DDL object."""
execution_options = ddl._execution_options.merge_with(
@@ -1414,7 +1479,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
elem: Executable,
distilled_parameters: _CoreMultiExecuteParams,
execution_options: _ExecuteOptionsParameter,
- ) -> CursorResult:
+ ) -> CursorResult[Any]:
"""Execute a sql.ClauseElement object."""
execution_options = elem._execution_options.merge_with(
@@ -1487,7 +1552,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
compiled: Compiled,
distilled_parameters: _CoreMultiExecuteParams,
execution_options: _ExecuteOptionsParameter = _EMPTY_EXECUTION_OPTS,
- ) -> CursorResult:
+ ) -> CursorResult[Any]:
"""Execute a sql.Compiled object.
TODO: why do we have this? likely deprecate or remove
@@ -1537,7 +1602,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
statement: str,
parameters: Optional[_DBAPIAnyExecuteParams] = None,
execution_options: Optional[_ExecuteOptionsParameter] = None,
- ) -> CursorResult:
+ ) -> CursorResult[Any]:
r"""Executes a SQL statement construct and returns a
:class:`_engine.CursorResult`.
@@ -1614,7 +1679,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
execution_options: _ExecuteOptions,
*args: Any,
**kw: Any,
- ) -> CursorResult:
+ ) -> CursorResult[Any]:
"""Create an :class:`.ExecutionContext` and execute, returning
a :class:`_engine.CursorResult`."""
diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py
index ccf573675..ff69666b7 100644
--- a/lib/sqlalchemy/engine/cursor.py
+++ b/lib/sqlalchemy/engine/cursor.py
@@ -24,6 +24,7 @@ from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
from .result import MergedResult
@@ -55,11 +56,12 @@ if typing.TYPE_CHECKING:
from .interfaces import ExecutionContext
from .result import _KeyIndexType
from .result import _KeyMapRecType
- from .result import _KeyMapType
from .result import _KeyType
from .result import _ProcessorsType
from ..sql.type_api import _ResultProcessorType
+_T = TypeVar("_T", bound=Any)
+
# metadata entry tuple indexes.
# using raw tuple is faster than namedtuple.
MD_INDEX: Literal[0] = 0 # integer index in cursor.description
@@ -214,7 +216,9 @@ class CursorResultMetaData(ResultMetaData):
return md
def __init__(
- self, parent: CursorResult, cursor_description: _DBAPICursorDescription
+ self,
+ parent: CursorResult[Any],
+ cursor_description: _DBAPICursorDescription,
):
context = parent.context
self._tuplefilter = None
@@ -1158,7 +1162,7 @@ class _NoResultMetaData(ResultMetaData):
_NO_RESULT_METADATA = _NoResultMetaData()
-class CursorResult(Result):
+class CursorResult(Result[_T]):
"""A Result that is representing state from a DBAPI cursor.
.. versionchanged:: 1.4 The :class:`.CursorResult``
@@ -1179,6 +1183,15 @@ class CursorResult(Result):
"""
+ __slots__ = (
+ "context",
+ "dialect",
+ "cursor",
+ "cursor_strategy",
+ "_echo",
+ "connection",
+ )
+
_metadata: Union[CursorResultMetaData, _NoResultMetaData]
_no_result_metadata = _NO_RESULT_METADATA
_soft_closed: bool = False
@@ -1231,7 +1244,6 @@ class CursorResult(Result):
make_row = _make_row_2
else:
make_row = _make_row
-
self._set_memoized_attribute("_row_getter", make_row)
else:
@@ -1726,12 +1738,12 @@ class CursorResult(Result):
def _raw_row_iterator(self):
return self._fetchiter_impl()
- def merge(self, *others: Result) -> MergedResult:
+ def merge(self, *others: Result[Any]) -> MergedResult[Any]:
merged_result = super().merge(*others)
setup_rowcounts = not self._metadata.returns_rows
if setup_rowcounts:
merged_result.rowcount = sum(
- cast(CursorResult, result).rowcount
+ cast("CursorResult[Any]", result).rowcount
for result in (self,) + others
)
return merged_result
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index c6571f68b..9c6ff758f 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -62,13 +62,9 @@ if typing.TYPE_CHECKING:
from .base import Connection
from .base import Engine
- from .characteristics import ConnectionCharacteristic
- from .interfaces import _AnyMultiExecuteParams
from .interfaces import _CoreMultiExecuteParams
from .interfaces import _CoreSingleExecuteParams
- from .interfaces import _DBAPIAnyExecuteParams
from .interfaces import _DBAPIMultiExecuteParams
- from .interfaces import _DBAPISingleExecuteParams
from .interfaces import _ExecuteOptions
from .interfaces import _IsolationLevel
from .interfaces import _MutableCoreSingleExecuteParams
@@ -83,15 +79,11 @@ if typing.TYPE_CHECKING:
from ..sql.compiler import Compiled
from ..sql.compiler import Linting
from ..sql.compiler import ResultColumnsEntry
- from ..sql.compiler import TypeCompiler
from ..sql.dml import DMLState
from ..sql.dml import UpdateBase
from ..sql.elements import BindParameter
- from ..sql.roles import ColumnsClauseRole
from ..sql.schema import Column
- from ..sql.schema import ColumnDefault
from ..sql.type_api import _BindProcessorType
- from ..sql.type_api import _ResultProcessorType
from ..sql.type_api import TypeEngine
# When we're handed literal SQL, ensure it's a SELECT query
@@ -781,7 +773,7 @@ class DefaultExecutionContext(ExecutionContext):
result_column_struct: Optional[
Tuple[List[ResultColumnsEntry], bool, bool, bool]
] = None
- returned_default_rows: Optional[List[Row]] = None
+ returned_default_rows: Optional[Sequence[Row[Any]]] = None
execution_options: _ExecuteOptions = util.EMPTY_DICT
@@ -1385,7 +1377,9 @@ class DefaultExecutionContext(ExecutionContext):
if cursor_description is None:
strategy = _cursor._NO_CURSOR_DML
- result = _cursor.CursorResult(self, strategy, cursor_description)
+ result: _cursor.CursorResult[Any] = _cursor.CursorResult(
+ self, strategy, cursor_description
+ )
if self.isinsert:
if self._is_implicit_returning:
diff --git a/lib/sqlalchemy/engine/events.py b/lib/sqlalchemy/engine/events.py
index ef10946a8..4093d3e0e 100644
--- a/lib/sqlalchemy/engine/events.py
+++ b/lib/sqlalchemy/engine/events.py
@@ -28,7 +28,6 @@ from ..util.typing import Literal
if typing.TYPE_CHECKING:
from .base import Connection
- from .interfaces import _CoreAnyExecuteParams
from .interfaces import _CoreMultiExecuteParams
from .interfaces import _CoreSingleExecuteParams
from .interfaces import _DBAPIAnyExecuteParams
@@ -273,7 +272,7 @@ class ConnectionEvents(event.Events[ConnectionEventsTarget]):
multiparams: _CoreMultiExecuteParams,
params: _CoreSingleExecuteParams,
execution_options: _ExecuteOptions,
- result: Result,
+ result: Result[Any],
) -> None:
"""Intercept high level execute() events after execute.
diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py
index 54fe21d74..641024603 100644
--- a/lib/sqlalchemy/engine/interfaces.py
+++ b/lib/sqlalchemy/engine/interfaces.py
@@ -2422,7 +2422,7 @@ class ExecutionContext:
def _get_cache_stats(self) -> str:
raise NotImplementedError()
- def _setup_result_proxy(self) -> CursorResult:
+ def _setup_result_proxy(self) -> CursorResult[Any]:
raise NotImplementedError()
def fire_sequence(self, seq: Sequence_SchemaItem, type_: Integer) -> int:
diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py
index 71320a583..55d36a1d5 100644
--- a/lib/sqlalchemy/engine/result.py
+++ b/lib/sqlalchemy/engine/result.py
@@ -28,6 +28,7 @@ from typing import overload
from typing import Sequence
from typing import Set
from typing import Tuple
+from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
@@ -70,6 +71,8 @@ _RawRowType = Tuple[Any, ...]
"""represents the kind of row we get from a DBAPI cursor"""
_R = TypeVar("_R", bound=_RowData)
+_T = TypeVar("_T", bound=Any)
+_TP = TypeVar("_TP", bound=Tuple[Any, ...])
_InterimRowType = Union[_R, _RawRowType]
"""a catchall "anything" kind of return type that can be applied
@@ -141,7 +144,7 @@ class ResultMetaData:
def _getter(
self, key: Any, raiseerr: bool = True
- ) -> Optional[Callable[[Row], Any]]:
+ ) -> Optional[Callable[[Row[Any]], Any]]:
index = self._index_for_key(key, raiseerr)
@@ -270,7 +273,7 @@ class SimpleResultMetaData(ResultMetaData):
_tuplefilter=_tuplefilter,
)
- def _contains(self, value: Any, row: Row) -> bool:
+ def _contains(self, value: Any, row: Row[Any]) -> bool:
return value in row._data
def _index_for_key(self, key: Any, raiseerr: bool = True) -> int:
@@ -335,7 +338,7 @@ class SimpleResultMetaData(ResultMetaData):
def result_tuple(
fields: Sequence[str], extra: Optional[Any] = None
-) -> Callable[[Iterable[Any]], Row]:
+) -> Callable[[Iterable[Any]], Row[Any]]:
parent = SimpleResultMetaData(fields, extra)
return functools.partial(
Row, parent, parent._processors, parent._keymap, Row._default_key_style
@@ -355,7 +358,9 @@ SelfResultInternal = TypeVar("SelfResultInternal", bound="ResultInternal[Any]")
class ResultInternal(InPlaceGenerative, Generic[_R]):
- _real_result: Optional[Result] = None
+ __slots__ = ()
+
+ _real_result: Optional[Result[Any]] = None
_generate_rows: bool = True
_row_logging_fn: Optional[Callable[[Any], Any]]
@@ -367,20 +372,20 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
_source_supports_scalars: bool
- def _fetchiter_impl(self) -> Iterator[_InterimRowType[Row]]:
+ def _fetchiter_impl(self) -> Iterator[_InterimRowType[Row[Any]]]:
raise NotImplementedError()
def _fetchone_impl(
self, hard_close: bool = False
- ) -> Optional[_InterimRowType[Row]]:
+ ) -> Optional[_InterimRowType[Row[Any]]]:
raise NotImplementedError()
def _fetchmany_impl(
self, size: Optional[int] = None
- ) -> List[_InterimRowType[Row]]:
+ ) -> List[_InterimRowType[Row[Any]]]:
raise NotImplementedError()
- def _fetchall_impl(self) -> List[_InterimRowType[Row]]:
+ def _fetchall_impl(self) -> List[_InterimRowType[Row[Any]]]:
raise NotImplementedError()
def _soft_close(self, hard: bool = False) -> None:
@@ -388,8 +393,10 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
@HasMemoized_ro_memoized_attribute
def _row_getter(self) -> Optional[Callable[..., _R]]:
- real_result: Result = (
- self._real_result if self._real_result else cast(Result, self)
+ real_result: Result[Any] = (
+ self._real_result
+ if self._real_result
+ else cast("Result[Any]", self)
)
if real_result._source_supports_scalars:
@@ -404,7 +411,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
keymap: _KeyMapType,
key_style: Any,
scalar_obj: Any,
- ) -> Row:
+ ) -> Row[Any]:
return _proc(
metadata, processors, keymap, key_style, (scalar_obj,)
)
@@ -429,7 +436,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
fixed_tf = tf
- def make_row(row: _InterimRowType[Row]) -> _R:
+ def make_row(row: _InterimRowType[Row[Any]]) -> _R:
return _make_row_orig(fixed_tf(row))
else:
@@ -447,7 +454,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
if fns:
_make_row = make_row
- def make_row(row: _InterimRowType[Row]) -> _R:
+ def make_row(row: _InterimRowType[Row[Any]]) -> _R:
interim_row = _make_row(row)
for fn in fns:
interim_row = fn(interim_row)
@@ -465,7 +472,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
if self._unique_filter_state:
uniques, strategy = self._unique_strategy
- def iterrows(self: Result) -> Iterator[_R]:
+ def iterrows(self: Result[Any]) -> Iterator[_R]:
for raw_row in self._fetchiter_impl():
obj: _InterimRowType[Any] = (
make_row(raw_row) if make_row else raw_row
@@ -480,7 +487,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
else:
- def iterrows(self: Result) -> Iterator[_R]:
+ def iterrows(self: Result[Any]) -> Iterator[_R]:
for raw_row in self._fetchiter_impl():
row: _InterimRowType[Any] = (
make_row(raw_row) if make_row else raw_row
@@ -546,7 +553,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
if self._unique_filter_state:
uniques, strategy = self._unique_strategy
- def onerow(self: Result) -> Union[_NoRow, _R]:
+ def onerow(self: Result[Any]) -> Union[_NoRow, _R]:
_onerow = self._fetchone_impl
while True:
row = _onerow()
@@ -567,7 +574,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
else:
- def onerow(self: Result) -> Union[_NoRow, _R]:
+ def onerow(self: Result[Any]) -> Union[_NoRow, _R]:
row = self._fetchone_impl()
if row is None:
return _NO_ROW
@@ -627,7 +634,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
real_result = (
self._real_result
if self._real_result
- else cast(Result, self)
+ else cast("Result[Any]", self)
)
if real_result._yield_per:
num_required = num = real_result._yield_per
@@ -667,7 +674,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
real_result = (
self._real_result
if self._real_result
- else cast(Result, self)
+ else cast("Result[Any]", self)
)
num = real_result._yield_per
@@ -799,7 +806,9 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
self: SelfResultInternal, indexes: Sequence[_KeyIndexType]
) -> SelfResultInternal:
real_result = (
- self._real_result if self._real_result else cast(Result, self)
+ self._real_result
+ if self._real_result
+ else cast("Result[Any]", self)
)
if not real_result._source_supports_scalars or len(indexes) != 1:
@@ -817,7 +826,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
real_result = (
self._real_result
if self._real_result is not None
- else cast(Result, self)
+ else cast("Result[Any]", self)
)
if not strategy and self._metadata._unique_filters:
@@ -836,6 +845,8 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
class _WithKeys:
+ __slots__ = ()
+
_metadata: ResultMetaData
# used mainly to share documentation on the keys method.
@@ -859,10 +870,10 @@ class _WithKeys:
return self._metadata.keys
-SelfResult = TypeVar("SelfResult", bound="Result")
+SelfResult = TypeVar("SelfResult", bound="Result[Any]")
-class Result(_WithKeys, ResultInternal[Row]):
+class Result(_WithKeys, ResultInternal[Row[_TP]]):
"""Represent a set of database results.
.. versionadded:: 1.4 The :class:`.Result` object provides a completely
@@ -887,7 +898,9 @@ class Result(_WithKeys, ResultInternal[Row]):
"""
- _row_logging_fn: Optional[Callable[[Row], Row]] = None
+ __slots__ = ("_metadata", "__dict__")
+
+ _row_logging_fn: Optional[Callable[[Row[Any]], Row[Any]]] = None
_source_supports_scalars: bool = False
@@ -1011,6 +1024,15 @@ class Result(_WithKeys, ResultInternal[Row]):
appropriate :class:`.ColumnElement` objects which correspond to
a given statement construct.
+ .. versionchanged:: 2.0 Due to a bug in 1.4, the
+ :meth:`.Result.columns` method had an incorrect behavior where
+ calling upon the method with just one index would cause the
+ :class:`.Result` object to yield scalar values rather than
+ :class:`.Row` objects. In version 2.0, this behavior has been
+ corrected such that calling upon :meth:`.Result.columns` with
+ a single index will produce a :class:`.Result` object that continues
+ to yield :class:`.Row` objects, which include only a single column.
+
E.g.::
statement = select(table.c.x, table.c.y, table.c.z)
@@ -1040,6 +1062,20 @@ class Result(_WithKeys, ResultInternal[Row]):
"""
return self._column_slices(col_expressions)
+ @overload
+ def scalars(self: Result[Tuple[_T]]) -> ScalarResult[_T]:
+ ...
+
+ @overload
+ def scalars(
+ self: Result[Tuple[_T]], index: Literal[0]
+ ) -> ScalarResult[_T]:
+ ...
+
+ @overload
+ def scalars(self, index: _KeyIndexType = 0) -> ScalarResult[Any]:
+ ...
+
def scalars(self, index: _KeyIndexType = 0) -> ScalarResult[Any]:
"""Return a :class:`_result.ScalarResult` filtering object which
will return single elements rather than :class:`_row.Row` objects.
@@ -1067,7 +1103,7 @@ class Result(_WithKeys, ResultInternal[Row]):
def _getter(
self, key: _KeyIndexType, raiseerr: bool = True
- ) -> Optional[Callable[[Row], Any]]:
+ ) -> Optional[Callable[[Row[Any]], Any]]:
"""return a callable that will retrieve the given key from a
:class:`.Row`.
@@ -1105,6 +1141,43 @@ class Result(_WithKeys, ResultInternal[Row]):
return MappingResult(self)
+ @property
+ def t(self) -> TupleResult[_TP]:
+ """Apply a "typed tuple" typing filter to returned rows.
+
+ The :attr:`.Result.t` attribute is a synonym for calling the
+ :meth:`.Result.tuples` method.
+
+ .. versionadded:: 2.0
+
+ """
+ return self # type: ignore
+
+ def tuples(self) -> TupleResult[_TP]:
+ """Apply a "typed tuple" typing filter to returned rows.
+
+ This method returns the same :class:`.Result` object at runtime,
+ however annotates as returning a :class:`.TupleResult` object
+ that will indicate to :pep:`484` typing tools that plain typed
+ ``Tuple`` instances are returned rather than rows. This allows
+ tuple unpacking and ``__getitem__`` access of :class:`.Row` objects
+ to by typed, for those cases where the statement invoked itself
+ included typing information.
+
+ .. versionadded:: 2.0
+
+ :return: the :class:`_result.TupleResult` type at typing time.
+
+ .. seealso::
+
+ :attr:`.Result.t` - shorter synonym
+
+ :attr:`.Row.t` - :class:`.Row` version
+
+ """
+
+ return self # type: ignore
+
def _raw_row_iterator(self) -> Iterator[_RowData]:
"""Return a safe iterator that yields raw row data.
@@ -1114,13 +1187,15 @@ class Result(_WithKeys, ResultInternal[Row]):
"""
raise NotImplementedError()
- def __iter__(self) -> Iterator[Row]:
+ def __iter__(self) -> Iterator[Row[_TP]]:
return self._iter_impl()
- def __next__(self) -> Row:
+ def __next__(self) -> Row[_TP]:
return self._next_impl()
- def partitions(self, size: Optional[int] = None) -> Iterator[List[Row]]:
+ def partitions(
+ self, size: Optional[int] = None
+ ) -> Iterator[Sequence[Row[_TP]]]:
"""Iterate through sub-lists of rows of the size given.
Each list will be of the size given, excluding the last list to
@@ -1158,12 +1233,12 @@ class Result(_WithKeys, ResultInternal[Row]):
else:
break
- def fetchall(self) -> List[Row]:
+ def fetchall(self) -> Sequence[Row[_TP]]:
"""A synonym for the :meth:`_engine.Result.all` method."""
return self._allrows()
- def fetchone(self) -> Optional[Row]:
+ def fetchone(self) -> Optional[Row[_TP]]:
"""Fetch one row.
When all rows are exhausted, returns None.
@@ -1185,7 +1260,7 @@ class Result(_WithKeys, ResultInternal[Row]):
else:
return row
- def fetchmany(self, size: Optional[int] = None) -> List[Row]:
+ def fetchmany(self, size: Optional[int] = None) -> Sequence[Row[_TP]]:
"""Fetch many rows.
When all rows are exhausted, returns an empty list.
@@ -1202,7 +1277,7 @@ class Result(_WithKeys, ResultInternal[Row]):
return self._manyrow_getter(self, size)
- def all(self) -> List[Row]:
+ def all(self) -> Sequence[Row[_TP]]:
"""Return all rows in a list.
Closes the result set after invocation. Subsequent invocations
@@ -1216,7 +1291,7 @@ class Result(_WithKeys, ResultInternal[Row]):
return self._allrows()
- def first(self) -> Optional[Row]:
+ def first(self) -> Optional[Row[_TP]]:
"""Fetch the first row or None if no row is present.
Closes the result set and discards remaining rows.
@@ -1252,7 +1327,7 @@ class Result(_WithKeys, ResultInternal[Row]):
raise_for_second_row=False, raise_for_none=False, scalar=False
)
- def one_or_none(self) -> Optional[Row]:
+ def one_or_none(self) -> Optional[Row[_TP]]:
"""Return at most one result or raise an exception.
Returns ``None`` if the result has no rows.
@@ -1276,6 +1351,14 @@ class Result(_WithKeys, ResultInternal[Row]):
raise_for_second_row=True, raise_for_none=False, scalar=False
)
+ @overload
+ def scalar_one(self: Result[Tuple[_T]]) -> _T:
+ ...
+
+ @overload
+ def scalar_one(self) -> Any:
+ ...
+
def scalar_one(self) -> Any:
"""Return exactly one scalar result or raise an exception.
@@ -1293,6 +1376,14 @@ class Result(_WithKeys, ResultInternal[Row]):
raise_for_second_row=True, raise_for_none=True, scalar=True
)
+ @overload
+ def scalar_one_or_none(self: Result[Tuple[_T]]) -> Optional[_T]:
+ ...
+
+ @overload
+ def scalar_one_or_none(self) -> Optional[Any]:
+ ...
+
def scalar_one_or_none(self) -> Optional[Any]:
"""Return exactly one or no scalar result.
@@ -1310,7 +1401,7 @@ class Result(_WithKeys, ResultInternal[Row]):
raise_for_second_row=True, raise_for_none=False, scalar=True
)
- def one(self) -> Row:
+ def one(self) -> Row[_TP]:
"""Return exactly one row or raise an exception.
Raises :class:`.NoResultFound` if the result returns no
@@ -1341,6 +1432,14 @@ class Result(_WithKeys, ResultInternal[Row]):
raise_for_second_row=True, raise_for_none=True, scalar=False
)
+ @overload
+ def scalar(self: Result[Tuple[_T]]) -> Optional[_T]:
+ ...
+
+ @overload
+ def scalar(self) -> Any:
+ ...
+
def scalar(self) -> Any:
"""Fetch the first column of the first row, and close the result set.
@@ -1359,7 +1458,7 @@ class Result(_WithKeys, ResultInternal[Row]):
raise_for_second_row=False, raise_for_none=False, scalar=True
)
- def freeze(self) -> FrozenResult:
+ def freeze(self) -> FrozenResult[_TP]:
"""Return a callable object that will produce copies of this
:class:`.Result` when invoked.
@@ -1382,7 +1481,7 @@ class Result(_WithKeys, ResultInternal[Row]):
return FrozenResult(self)
- def merge(self, *others: Result) -> MergedResult:
+ def merge(self, *others: Result[Any]) -> MergedResult[_TP]:
"""Merge this :class:`.Result` with other compatible result
objects.
@@ -1405,9 +1504,17 @@ class FilterResult(ResultInternal[_R]):
"""
- _post_creational_filter: Optional[Callable[[Any], Any]] = None
+ __slots__ = (
+ "_real_result",
+ "_post_creational_filter",
+ "_metadata",
+ "_unique_filter_state",
+ "__dict__",
+ )
+
+ _post_creational_filter: Optional[Callable[[Any], Any]]
- _real_result: Result
+ _real_result: Result[Any]
def _soft_close(self, hard: bool = False) -> None:
self._real_result._soft_close(hard=hard)
@@ -1416,20 +1523,20 @@ class FilterResult(ResultInternal[_R]):
def _attributes(self) -> Dict[Any, Any]:
return self._real_result._attributes
- def _fetchiter_impl(self) -> Iterator[_InterimRowType[Row]]:
+ def _fetchiter_impl(self) -> Iterator[_InterimRowType[Row[Any]]]:
return self._real_result._fetchiter_impl()
def _fetchone_impl(
self, hard_close: bool = False
- ) -> Optional[_InterimRowType[Row]]:
+ ) -> Optional[_InterimRowType[Row[Any]]]:
return self._real_result._fetchone_impl(hard_close=hard_close)
- def _fetchall_impl(self) -> List[_InterimRowType[Row]]:
+ def _fetchall_impl(self) -> List[_InterimRowType[Row[Any]]]:
return self._real_result._fetchall_impl()
def _fetchmany_impl(
self, size: Optional[int] = None
- ) -> List[_InterimRowType[Row]]:
+ ) -> List[_InterimRowType[Row[Any]]]:
return self._real_result._fetchmany_impl(size=size)
@@ -1452,11 +1559,13 @@ class ScalarResult(FilterResult[_R]):
"""
+ __slots__ = ()
+
_generate_rows = False
_post_creational_filter: Optional[Callable[[Any], Any]]
- def __init__(self, real_result: Result, index: _KeyIndexType):
+ def __init__(self, real_result: Result[Any], index: _KeyIndexType):
self._real_result = real_result
if real_result._source_supports_scalars:
@@ -1480,7 +1589,7 @@ class ScalarResult(FilterResult[_R]):
self._unique_filter_state = (set(), strategy)
return self
- def partitions(self, size: Optional[int] = None) -> Iterator[List[_R]]:
+ def partitions(self, size: Optional[int] = None) -> Iterator[Sequence[_R]]:
"""Iterate through sub-lists of elements of the size given.
Equivalent to :meth:`_result.Result.partitions` except that
@@ -1498,12 +1607,12 @@ class ScalarResult(FilterResult[_R]):
else:
break
- def fetchall(self) -> List[_R]:
+ def fetchall(self) -> Sequence[_R]:
"""A synonym for the :meth:`_engine.ScalarResult.all` method."""
return self._allrows()
- def fetchmany(self, size: Optional[int] = None) -> List[_R]:
+ def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]:
"""Fetch many objects.
Equivalent to :meth:`_result.Result.fetchmany` except that
@@ -1513,7 +1622,7 @@ class ScalarResult(FilterResult[_R]):
"""
return self._manyrow_getter(self, size)
- def all(self) -> List[_R]:
+ def all(self) -> Sequence[_R]:
"""Return all scalar values in a list.
Equivalent to :meth:`_result.Result.all` except that
@@ -1567,6 +1676,177 @@ class ScalarResult(FilterResult[_R]):
)
+SelfTupleResult = TypeVar("SelfTupleResult", bound="TupleResult[Any]")
+
+
+class TupleResult(FilterResult[_R], util.TypingOnly):
+ """a :class:`.Result` that's typed as returning plain Python tuples
+ instead of rows.
+
+ Since :class:`.Row` acts like a tuple in every way already,
+ this class is a typing only class, regular :class:`.Result` is still
+ used at runtime.
+
+ """
+
+ __slots__ = ()
+
+ if TYPE_CHECKING:
+
+ def partitions(
+ self, size: Optional[int] = None
+ ) -> Iterator[Sequence[_R]]:
+ """Iterate through sub-lists of elements of the size given.
+
+ Equivalent to :meth:`_result.Result.partitions` except that
+ tuple values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ ...
+
+ def fetchone(self) -> Optional[_R]:
+ """Fetch one tuple.
+
+ Equivalent to :meth:`_result.Result.fetchone` except that
+ tuple values, rather than :class:`_result.Row`
+ objects, are returned.
+
+ """
+ ...
+
+ def fetchall(self) -> Sequence[_R]:
+ """A synonym for the :meth:`_engine.ScalarResult.all` method."""
+ ...
+
+ def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]:
+ """Fetch many objects.
+
+ Equivalent to :meth:`_result.Result.fetchmany` except that
+ tuple values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ ...
+
+ def all(self) -> Sequence[_R]: # noqa: A001
+ """Return all scalar values in a list.
+
+ Equivalent to :meth:`_result.Result.all` except that
+ tuple values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ ...
+
+ def __iter__(self) -> Iterator[_R]:
+ ...
+
+ def __next__(self) -> _R:
+ ...
+
+ def first(self) -> Optional[_R]:
+ """Fetch the first object or None if no object is present.
+
+ Equivalent to :meth:`_result.Result.first` except that
+ tuple values, rather than :class:`_result.Row` objects,
+ are returned.
+
+
+ """
+ ...
+
+ def one_or_none(self) -> Optional[_R]:
+ """Return at most one object or raise an exception.
+
+ Equivalent to :meth:`_result.Result.one_or_none` except that
+ tuple values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ ...
+
+ def one(self) -> _R:
+ """Return exactly one object or raise an exception.
+
+ Equivalent to :meth:`_result.Result.one` except that
+ tuple values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ ...
+
+ @overload
+ def scalar_one(self: TupleResult[Tuple[_T]]) -> _T:
+ ...
+
+ @overload
+ def scalar_one(self) -> Any:
+ ...
+
+ def scalar_one(self) -> Any:
+ """Return exactly one scalar result or raise an exception.
+
+ This is equivalent to calling :meth:`.Result.scalars` and then
+ :meth:`.Result.one`.
+
+ .. seealso::
+
+ :meth:`.Result.one`
+
+ :meth:`.Result.scalars`
+
+ """
+ ...
+
+ @overload
+ def scalar_one_or_none(self: TupleResult[Tuple[_T]]) -> Optional[_T]:
+ ...
+
+ @overload
+ def scalar_one_or_none(self) -> Optional[Any]:
+ ...
+
+ def scalar_one_or_none(self) -> Optional[Any]:
+ """Return exactly one or no scalar result.
+
+ This is equivalent to calling :meth:`.Result.scalars` and then
+ :meth:`.Result.one_or_none`.
+
+ .. seealso::
+
+ :meth:`.Result.one_or_none`
+
+ :meth:`.Result.scalars`
+
+ """
+ ...
+
+ @overload
+ def scalar(self: TupleResult[Tuple[_T]]) -> Optional[_T]:
+ ...
+
+ @overload
+ def scalar(self) -> Any:
+ ...
+
+ def scalar(self) -> Any:
+ """Fetch the first column of the first row, and close the result set.
+
+ Returns None if there are no rows to fetch.
+
+ No validation is performed to test if additional rows remain.
+
+ After calling this method, the object is fully closed,
+ e.g. the :meth:`_engine.CursorResult.close`
+ method will have been called.
+
+ :return: a Python scalar value , or None if no rows remain.
+
+ """
+ ...
+
+
SelfMappingResult = TypeVar("SelfMappingResult", bound="MappingResult")
@@ -1579,11 +1859,13 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]):
"""
+ __slots__ = ()
+
_generate_rows = True
_post_creational_filter = operator.attrgetter("_mapping")
- def __init__(self, result: Result):
+ def __init__(self, result: Result[Any]):
self._real_result = result
self._unique_filter_state = result._unique_filter_state
self._metadata = result._metadata
@@ -1610,7 +1892,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]):
def partitions(
self, size: Optional[int] = None
- ) -> Iterator[List[RowMapping]]:
+ ) -> Iterator[Sequence[RowMapping]]:
"""Iterate through sub-lists of elements of the size given.
Equivalent to :meth:`_result.Result.partitions` except that
@@ -1628,7 +1910,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]):
else:
break
- def fetchall(self) -> List[RowMapping]:
+ def fetchall(self) -> Sequence[RowMapping]:
"""A synonym for the :meth:`_engine.MappingResult.all` method."""
return self._allrows()
@@ -1648,7 +1930,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]):
else:
return row
- def fetchmany(self, size: Optional[int] = None) -> List[RowMapping]:
+ def fetchmany(self, size: Optional[int] = None) -> Sequence[RowMapping]:
"""Fetch many objects.
Equivalent to :meth:`_result.Result.fetchmany` except that
@@ -1659,7 +1941,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]):
return self._manyrow_getter(self, size)
- def all(self) -> List[RowMapping]:
+ def all(self) -> Sequence[RowMapping]:
"""Return all scalar values in a list.
Equivalent to :meth:`_result.Result.all` except that
@@ -1714,7 +1996,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]):
)
-class FrozenResult:
+class FrozenResult(Generic[_TP]):
"""Represents a :class:`.Result` object in a "frozen" state suitable
for caching.
@@ -1755,7 +2037,7 @@ class FrozenResult:
data: Sequence[Any]
- def __init__(self, result: Result):
+ def __init__(self, result: Result[_TP]):
self.metadata = result._metadata._for_freeze()
self._source_supports_scalars = result._source_supports_scalars
self._attributes = result._attributes
@@ -1771,7 +2053,9 @@ class FrozenResult:
else:
return [list(row) for row in self.data]
- def with_new_rows(self, tuple_data: Sequence[Row]) -> FrozenResult:
+ def with_new_rows(
+ self, tuple_data: Sequence[Row[_TP]]
+ ) -> FrozenResult[_TP]:
fr = FrozenResult.__new__(FrozenResult)
fr.metadata = self.metadata
fr._attributes = self._attributes
@@ -1783,14 +2067,16 @@ class FrozenResult:
fr.data = tuple_data
return fr
- def __call__(self) -> Result:
- result = IteratorResult(self.metadata, iter(self.data))
+ def __call__(self) -> Result[_TP]:
+ result: IteratorResult[_TP] = IteratorResult(
+ self.metadata, iter(self.data)
+ )
result._attributes = self._attributes
result._source_supports_scalars = self._source_supports_scalars
return result
-class IteratorResult(Result):
+class IteratorResult(Result[_TP]):
"""A :class:`.Result` that gets data from a Python iterator of
:class:`.Row` objects or similar row-like data.
@@ -1833,7 +2119,7 @@ class IteratorResult(Result):
def _fetchone_impl(
self, hard_close: bool = False
- ) -> Optional[_InterimRowType[Row]]:
+ ) -> Optional[_InterimRowType[Row[Any]]]:
if self._hard_closed:
self._raise_hard_closed()
@@ -1844,7 +2130,7 @@ class IteratorResult(Result):
else:
return row
- def _fetchall_impl(self) -> List[_InterimRowType[Row]]:
+ def _fetchall_impl(self) -> List[_InterimRowType[Row[Any]]]:
if self._hard_closed:
self._raise_hard_closed()
try:
@@ -1854,23 +2140,23 @@ class IteratorResult(Result):
def _fetchmany_impl(
self, size: Optional[int] = None
- ) -> List[_InterimRowType[Row]]:
+ ) -> List[_InterimRowType[Row[Any]]]:
if self._hard_closed:
self._raise_hard_closed()
return list(itertools.islice(self.iterator, 0, size))
-def null_result() -> IteratorResult:
+def null_result() -> IteratorResult[Any]:
return IteratorResult(SimpleResultMetaData([]), iter([]))
SelfChunkedIteratorResult = TypeVar(
- "SelfChunkedIteratorResult", bound="ChunkedIteratorResult"
+ "SelfChunkedIteratorResult", bound="ChunkedIteratorResult[Any]"
)
-class ChunkedIteratorResult(IteratorResult):
+class ChunkedIteratorResult(IteratorResult[_TP]):
"""An :class:`.IteratorResult` that works from an iterator-producing callable.
The given ``chunks`` argument is a function that is given a number of rows
@@ -1922,13 +2208,13 @@ class ChunkedIteratorResult(IteratorResult):
def _fetchmany_impl(
self, size: Optional[int] = None
- ) -> List[_InterimRowType[Row]]:
+ ) -> List[_InterimRowType[Row[Any]]]:
if self.dynamic_yield_per:
self.iterator = itertools.chain.from_iterable(self.chunks(size))
return super()._fetchmany_impl(size=size)
-class MergedResult(IteratorResult):
+class MergedResult(IteratorResult[_TP]):
"""A :class:`_engine.Result` that is merged from any number of
:class:`_engine.Result` objects.
@@ -1942,7 +2228,7 @@ class MergedResult(IteratorResult):
rowcount: Optional[int]
def __init__(
- self, cursor_metadata: ResultMetaData, results: Sequence[Result]
+ self, cursor_metadata: ResultMetaData, results: Sequence[Result[_TP]]
):
self._results = results
super(MergedResult, self).__init__(
diff --git a/lib/sqlalchemy/engine/row.py b/lib/sqlalchemy/engine/row.py
index 4ba39b55d..7c9eacb78 100644
--- a/lib/sqlalchemy/engine/row.py
+++ b/lib/sqlalchemy/engine/row.py
@@ -16,6 +16,7 @@ import typing
from typing import Any
from typing import Callable
from typing import Dict
+from typing import Generic
from typing import Iterator
from typing import List
from typing import Mapping
@@ -24,12 +25,14 @@ from typing import Optional
from typing import overload
from typing import Sequence
from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
from ..sql import util as sql_util
from ..util._has_cy import HAS_CYEXTENSION
-if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
+if TYPE_CHECKING or not HAS_CYEXTENSION:
from ._py_row import BaseRow as BaseRow
from ._py_row import KEY_INTEGER_ONLY
from ._py_row import KEY_OBJECTS_ONLY
@@ -38,13 +41,16 @@ else:
from sqlalchemy.cyextension.resultproxy import KEY_INTEGER_ONLY
from sqlalchemy.cyextension.resultproxy import KEY_OBJECTS_ONLY
-if typing.TYPE_CHECKING:
+if TYPE_CHECKING:
from .result import _KeyType
from .result import RMKeyView
from ..sql.type_api import _ResultProcessorType
+_T = TypeVar("_T", bound=Any)
+_TP = TypeVar("_TP", bound=Tuple[Any, ...])
-class Row(BaseRow, typing.Sequence[Any]):
+
+class Row(BaseRow, Sequence[Any], Generic[_TP]):
"""Represent a single result row.
The :class:`.Row` object represents a row of a database result. It is
@@ -82,6 +88,37 @@ class Row(BaseRow, typing.Sequence[Any]):
def __delattr__(self, name: str) -> NoReturn:
raise AttributeError("can't delete attribute")
+ def tuple(self) -> _TP:
+ """Return a 'tuple' form of this :class:`.Row`.
+
+ At runtime, this method returns "self"; the :class:`.Row` object is
+ already a named tuple. However, at the typing level, if this
+ :class:`.Row` is typed, the "tuple" return type will be a :pep:`484`
+ ``Tuple`` datatype that contains typing information about individual
+ elements, supporting typed unpacking and attribute access.
+
+ .. versionadded:: 2.0
+
+ .. seealso::
+
+ :meth:`.Result.tuples`
+
+ """
+ return self # type: ignore
+
+ @property
+ def t(self) -> _TP:
+ """a synonym for :attr:`.Row.tuple`
+
+ .. versionadded:: 2.0
+
+ .. seealso::
+
+ :meth:`.Result.t`
+
+ """
+ return self # type: ignore
+
@property
def _mapping(self) -> RowMapping:
"""Return a :class:`.RowMapping` for this :class:`.Row`.
@@ -107,7 +144,7 @@ class Row(BaseRow, typing.Sequence[Any]):
def _filter_on_values(
self, filters: Optional[Sequence[Optional[_ResultProcessorType[Any]]]]
- ) -> Row:
+ ) -> Row[Any]:
return Row(
self._parent,
filters,
@@ -116,7 +153,7 @@ class Row(BaseRow, typing.Sequence[Any]):
self._data,
)
- if not typing.TYPE_CHECKING:
+ if not TYPE_CHECKING:
def _special_name_accessor(name: str) -> Any:
"""Handle ambiguous names such as "count" and "index" """
@@ -151,7 +188,7 @@ class Row(BaseRow, typing.Sequence[Any]):
__hash__ = BaseRow.__hash__
- if typing.TYPE_CHECKING:
+ if TYPE_CHECKING:
@overload
def __getitem__(self, index: int) -> Any:
@@ -299,7 +336,7 @@ class RowMapping(BaseRow, typing.Mapping[str, Any]):
_default_key_style = KEY_OBJECTS_ONLY
- if typing.TYPE_CHECKING:
+ if TYPE_CHECKING:
def __getitem__(self, key: _KeyType) -> Any:
...
diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py
index fb05f512e..95549ada6 100644
--- a/lib/sqlalchemy/ext/asyncio/engine.py
+++ b/lib/sqlalchemy/ext/asyncio/engine.py
@@ -12,8 +12,10 @@ from typing import Generator
from typing import NoReturn
from typing import Optional
from typing import overload
+from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
from . import exc as async_exc
@@ -50,6 +52,9 @@ if TYPE_CHECKING:
from ...pool import PoolProxiedConnection
from ...sql._typing import _InfoType
from ...sql.base import Executable
+ from ...sql.selectable import TypedReturnsRows
+
+_T = TypeVar("_T", bound=Any)
class _SyncConnectionCallable(Protocol):
@@ -407,7 +412,7 @@ class AsyncConnection(
statement: str,
parameters: Optional[_DBAPIAnyExecuteParams] = None,
execution_options: Optional[_ExecuteOptionsParameter] = None,
- ) -> CursorResult:
+ ) -> CursorResult[Any]:
r"""Executes a driver-level SQL string and return buffered
:class:`_engine.Result`.
@@ -423,12 +428,33 @@ class AsyncConnection(
return await _ensure_sync_result(result, self.exec_driver_sql)
+ @overload
+ async def stream(
+ self,
+ statement: TypedReturnsRows[_T],
+ parameters: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> AsyncResult[_T]:
+ ...
+
+ @overload
async def stream(
self,
statement: Executable,
parameters: Optional[_CoreAnyExecuteParams] = None,
+ *,
execution_options: Optional[_ExecuteOptionsParameter] = None,
- ) -> AsyncResult:
+ ) -> AsyncResult[Any]:
+ ...
+
+ async def stream(
+ self,
+ statement: Executable,
+ parameters: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> AsyncResult[Any]:
"""Execute a statement and return a streaming
:class:`_asyncio.AsyncResult` object."""
@@ -436,7 +462,7 @@ class AsyncConnection(
self._proxied.execute,
statement,
parameters,
- util.EMPTY_DICT.merge_with(
+ execution_options=util.EMPTY_DICT.merge_with(
execution_options, {"stream_results": True}
),
_require_await=True,
@@ -446,12 +472,33 @@ class AsyncConnection(
assert False, "server side result expected"
return AsyncResult(result)
+ @overload
+ async def execute(
+ self,
+ statement: TypedReturnsRows[_T],
+ parameters: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> CursorResult[_T]:
+ ...
+
+ @overload
async def execute(
self,
statement: Executable,
parameters: Optional[_CoreAnyExecuteParams] = None,
+ *,
execution_options: Optional[_ExecuteOptionsParameter] = None,
- ) -> CursorResult:
+ ) -> CursorResult[Any]:
+ ...
+
+ async def execute(
+ self,
+ statement: Executable,
+ parameters: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> CursorResult[Any]:
r"""Executes a SQL statement construct and return a buffered
:class:`_engine.Result`.
@@ -487,15 +534,36 @@ class AsyncConnection(
self._proxied.execute,
statement,
parameters,
- execution_options,
+ execution_options=execution_options,
_require_await=True,
)
return await _ensure_sync_result(result, self.execute)
+ @overload
+ async def scalar(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ parameters: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> Optional[_T]:
+ ...
+
+ @overload
async def scalar(
self,
statement: Executable,
parameters: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> Any:
+ ...
+
+ async def scalar(
+ self,
+ statement: Executable,
+ parameters: Optional[_CoreSingleExecuteParams] = None,
+ *,
execution_options: Optional[_ExecuteOptionsParameter] = None,
) -> Any:
r"""Executes a SQL statement construct and returns a scalar object.
@@ -508,13 +576,36 @@ class AsyncConnection(
first row returned.
"""
- result = await self.execute(statement, parameters, execution_options)
+ result = await self.execute(
+ statement, parameters, execution_options=execution_options
+ )
return result.scalar()
+ @overload
+ async def scalars(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ parameters: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> ScalarResult[_T]:
+ ...
+
+ @overload
+ async def scalars(
+ self,
+ statement: Executable,
+ parameters: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> ScalarResult[Any]:
+ ...
+
async def scalars(
self,
statement: Executable,
parameters: Optional[_CoreSingleExecuteParams] = None,
+ *,
execution_options: Optional[_ExecuteOptionsParameter] = None,
) -> ScalarResult[Any]:
r"""Executes a SQL statement construct and returns a scalar objects.
@@ -528,13 +619,36 @@ class AsyncConnection(
.. versionadded:: 1.4.24
"""
- result = await self.execute(statement, parameters, execution_options)
+ result = await self.execute(
+ statement, parameters, execution_options=execution_options
+ )
return result.scalars()
+ @overload
+ async def stream_scalars(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ parameters: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> AsyncScalarResult[_T]:
+ ...
+
+ @overload
async def stream_scalars(
self,
statement: Executable,
parameters: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> AsyncScalarResult[Any]:
+ ...
+
+ async def stream_scalars(
+ self,
+ statement: Executable,
+ parameters: Optional[_CoreSingleExecuteParams] = None,
+ *,
execution_options: Optional[_ExecuteOptionsParameter] = None,
) -> AsyncScalarResult[Any]:
r"""Executes a SQL statement and returns a streaming scalar result
@@ -549,7 +663,9 @@ class AsyncConnection(
.. versionadded:: 1.4.24
"""
- result = await self.stream(statement, parameters, execution_options)
+ result = await self.stream(
+ statement, parameters, execution_options=execution_options
+ )
return result.scalars()
async def run_sync(
diff --git a/lib/sqlalchemy/ext/asyncio/result.py b/lib/sqlalchemy/ext/asyncio/result.py
index d0337554c..ff3dcf417 100644
--- a/lib/sqlalchemy/ext/asyncio/result.py
+++ b/lib/sqlalchemy/ext/asyncio/result.py
@@ -9,12 +9,15 @@ from __future__ import annotations
import operator
from typing import Any
from typing import AsyncIterator
-from typing import List
from typing import Optional
+from typing import overload
+from typing import Sequence
+from typing import Tuple
from typing import TYPE_CHECKING
from typing import TypeVar
from . import exc as async_exc
+from ... import util
from ...engine.result import _NO_ROW
from ...engine.result import _R
from ...engine.result import FilterResult
@@ -24,6 +27,7 @@ from ...engine.result import ResultMetaData
from ...engine.row import Row
from ...engine.row import RowMapping
from ...util.concurrency import greenlet_spawn
+from ...util.typing import Literal
if TYPE_CHECKING:
from ...engine import CursorResult
@@ -32,9 +36,14 @@ if TYPE_CHECKING:
from ...engine.result import _UniqueFilterType
from ...engine.result import RMKeyView
+_T = TypeVar("_T", bound=Any)
+_TP = TypeVar("_TP", bound=Tuple[Any, ...])
+
class AsyncCommon(FilterResult[_R]):
- _real_result: Result
+ __slots__ = ()
+
+ _real_result: Result[Any]
_metadata: ResultMetaData
async def close(self) -> None:
@@ -43,10 +52,10 @@ class AsyncCommon(FilterResult[_R]):
await greenlet_spawn(self._real_result.close)
-SelfAsyncResult = TypeVar("SelfAsyncResult", bound="AsyncResult")
+SelfAsyncResult = TypeVar("SelfAsyncResult", bound="AsyncResult[Any]")
-class AsyncResult(AsyncCommon[Row]):
+class AsyncResult(AsyncCommon[Row[_TP]]):
"""An asyncio wrapper around a :class:`_result.Result` object.
The :class:`_asyncio.AsyncResult` only applies to statement executions that
@@ -67,11 +76,16 @@ class AsyncResult(AsyncCommon[Row]):
"""
- def __init__(self, real_result: Result):
+ __slots__ = ()
+
+ _real_result: Result[_TP]
+
+ def __init__(self, real_result: Result[_TP]):
self._real_result = real_result
self._metadata = real_result._metadata
self._unique_filter_state = real_result._unique_filter_state
+ self._post_creational_filter = None
# BaseCursorResult pre-generates the "_row_getter". Use that
# if available rather than building a second one
@@ -80,6 +94,43 @@ class AsyncResult(AsyncCommon[Row]):
"_row_getter", real_result.__dict__["_row_getter"]
)
+ @property
+ def t(self) -> AsyncTupleResult[_TP]:
+ """Apply a "typed tuple" typing filter to returned rows.
+
+ The :attr:`.AsyncResult.t` attribute is a synonym for calling the
+ :meth:`.AsyncResult.tuples` method.
+
+ .. versionadded:: 2.0
+
+ """
+ return self # type: ignore
+
+ def tuples(self) -> AsyncTupleResult[_TP]:
+ """Apply a "typed tuple" typing filter to returned rows.
+
+ This method returns the same :class:`.AsyncResult` object at runtime,
+ however annotates as returning a :class:`.AsyncTupleResult` object
+ that will indicate to :pep:`484` typing tools that plain typed
+ ``Tuple`` instances are returned rather than rows. This allows
+ tuple unpacking and ``__getitem__`` access of :class:`.Row` objects
+ to by typed, for those cases where the statement invoked itself
+ included typing information.
+
+ .. versionadded:: 2.0
+
+ :return: the :class:`_result.AsyncTupleResult` type at typing time.
+
+ .. seealso::
+
+ :attr:`.AsyncResult.t` - shorter synonym
+
+ :attr:`.Row.t` - :class:`.Row` version
+
+ """
+
+ return self # type: ignore
+
def keys(self) -> RMKeyView:
"""Return the :meth:`_engine.Result.keys` collection from the
underlying :class:`_engine.Result`.
@@ -115,7 +166,7 @@ class AsyncResult(AsyncCommon[Row]):
async def partitions(
self, size: Optional[int] = None
- ) -> AsyncIterator[List[Row]]:
+ ) -> AsyncIterator[Sequence[Row[_TP]]]:
"""Iterate through sub-lists of rows of the size given.
An async iterator is returned::
@@ -141,7 +192,16 @@ class AsyncResult(AsyncCommon[Row]):
else:
break
- async def fetchone(self) -> Optional[Row]:
+ async def fetchall(self) -> Sequence[Row[_TP]]:
+ """A synonym for the :meth:`.AsyncResult.all` method.
+
+ .. versionadded:: 2.0
+
+ """
+
+ return await greenlet_spawn(self._allrows)
+
+ async def fetchone(self) -> Optional[Row[_TP]]:
"""Fetch one row.
When all rows are exhausted, returns None.
@@ -163,7 +223,9 @@ class AsyncResult(AsyncCommon[Row]):
else:
return row
- async def fetchmany(self, size: Optional[int] = None) -> List[Row]:
+ async def fetchmany(
+ self, size: Optional[int] = None
+ ) -> Sequence[Row[_TP]]:
"""Fetch many rows.
When all rows are exhausted, returns an empty list.
@@ -184,7 +246,7 @@ class AsyncResult(AsyncCommon[Row]):
return await greenlet_spawn(self._manyrow_getter, self, size)
- async def all(self) -> List[Row]:
+ async def all(self) -> Sequence[Row[_TP]]:
"""Return all rows in a list.
Closes the result set after invocation. Subsequent invocations
@@ -196,17 +258,17 @@ class AsyncResult(AsyncCommon[Row]):
return await greenlet_spawn(self._allrows)
- def __aiter__(self) -> AsyncResult:
+ def __aiter__(self) -> AsyncResult[_TP]:
return self
- async def __anext__(self) -> Row:
+ async def __anext__(self) -> Row[_TP]:
row = await greenlet_spawn(self._onerow_getter, self)
if row is _NO_ROW:
raise StopAsyncIteration()
else:
return row
- async def first(self) -> Optional[Row]:
+ async def first(self) -> Optional[Row[_TP]]:
"""Fetch the first row or None if no row is present.
Closes the result set and discards remaining rows.
@@ -229,7 +291,7 @@ class AsyncResult(AsyncCommon[Row]):
"""
return await greenlet_spawn(self._only_one_row, False, False, False)
- async def one_or_none(self) -> Optional[Row]:
+ async def one_or_none(self) -> Optional[Row[_TP]]:
"""Return at most one result or raise an exception.
Returns ``None`` if the result has no rows.
@@ -251,6 +313,14 @@ class AsyncResult(AsyncCommon[Row]):
"""
return await greenlet_spawn(self._only_one_row, True, False, False)
+ @overload
+ async def scalar_one(self: AsyncResult[Tuple[_T]]) -> _T:
+ ...
+
+ @overload
+ async def scalar_one(self) -> Any:
+ ...
+
async def scalar_one(self) -> Any:
"""Return exactly one scalar result or raise an exception.
@@ -266,6 +336,16 @@ class AsyncResult(AsyncCommon[Row]):
"""
return await greenlet_spawn(self._only_one_row, True, True, True)
+ @overload
+ async def scalar_one_or_none(
+ self: AsyncResult[Tuple[_T]],
+ ) -> Optional[_T]:
+ ...
+
+ @overload
+ async def scalar_one_or_none(self) -> Optional[Any]:
+ ...
+
async def scalar_one_or_none(self) -> Optional[Any]:
"""Return exactly one or no scalar result.
@@ -281,7 +361,7 @@ class AsyncResult(AsyncCommon[Row]):
"""
return await greenlet_spawn(self._only_one_row, True, False, True)
- async def one(self) -> Row:
+ async def one(self) -> Row[_TP]:
"""Return exactly one row or raise an exception.
Raises :class:`.NoResultFound` if the result returns no
@@ -312,6 +392,14 @@ class AsyncResult(AsyncCommon[Row]):
"""
return await greenlet_spawn(self._only_one_row, True, True, False)
+ @overload
+ async def scalar(self: AsyncResult[Tuple[_T]]) -> Optional[_T]:
+ ...
+
+ @overload
+ async def scalar(self) -> Any:
+ ...
+
async def scalar(self) -> Any:
"""Fetch the first column of the first row, and close the result set.
@@ -328,7 +416,7 @@ class AsyncResult(AsyncCommon[Row]):
"""
return await greenlet_spawn(self._only_one_row, False, False, True)
- async def freeze(self) -> FrozenResult:
+ async def freeze(self) -> FrozenResult[_TP]:
"""Return a callable object that will produce copies of this
:class:`_asyncio.AsyncResult` when invoked.
@@ -351,7 +439,7 @@ class AsyncResult(AsyncCommon[Row]):
return await greenlet_spawn(FrozenResult, self)
- def merge(self, *others: AsyncResult) -> MergedResult:
+ def merge(self, *others: AsyncResult[_TP]) -> MergedResult[_TP]:
"""Merge this :class:`_asyncio.AsyncResult` with other compatible result
objects.
@@ -370,6 +458,20 @@ class AsyncResult(AsyncCommon[Row]):
(self._real_result,) + tuple(o._real_result for o in others),
)
+ @overload
+ def scalars(
+ self: AsyncResult[Tuple[_T]], index: Literal[0]
+ ) -> AsyncScalarResult[_T]:
+ ...
+
+ @overload
+ def scalars(self: AsyncResult[Tuple[_T]]) -> AsyncScalarResult[_T]:
+ ...
+
+ @overload
+ def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]:
+ ...
+
def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]:
"""Return an :class:`_asyncio.AsyncScalarResult` filtering object which
will return single elements rather than :class:`_row.Row` objects.
@@ -423,9 +525,11 @@ class AsyncScalarResult(AsyncCommon[_R]):
"""
+ __slots__ = ()
+
_generate_rows = False
- def __init__(self, real_result: Result, index: _KeyIndexType):
+ def __init__(self, real_result: Result[Any], index: _KeyIndexType):
self._real_result = real_result
if real_result._source_supports_scalars:
@@ -452,7 +556,7 @@ class AsyncScalarResult(AsyncCommon[_R]):
async def partitions(
self, size: Optional[int] = None
- ) -> AsyncIterator[List[_R]]:
+ ) -> AsyncIterator[Sequence[_R]]:
"""Iterate through sub-lists of elements of the size given.
Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that
@@ -470,12 +574,12 @@ class AsyncScalarResult(AsyncCommon[_R]):
else:
break
- async def fetchall(self) -> List[_R]:
+ async def fetchall(self) -> Sequence[_R]:
"""A synonym for the :meth:`_asyncio.AsyncScalarResult.all` method."""
return await greenlet_spawn(self._allrows)
- async def fetchmany(self, size: Optional[int] = None) -> List[_R]:
+ async def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]:
"""Fetch many objects.
Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that
@@ -485,7 +589,7 @@ class AsyncScalarResult(AsyncCommon[_R]):
"""
return await greenlet_spawn(self._manyrow_getter, self, size)
- async def all(self) -> List[_R]:
+ async def all(self) -> Sequence[_R]:
"""Return all scalar values in a list.
Equivalent to :meth:`_asyncio.AsyncResult.all` except that
@@ -555,11 +659,13 @@ class AsyncMappingResult(AsyncCommon[RowMapping]):
"""
+ __slots__ = ()
+
_generate_rows = True
_post_creational_filter = operator.attrgetter("_mapping")
- def __init__(self, result: Result):
+ def __init__(self, result: Result[Any]):
self._real_result = result
self._unique_filter_state = result._unique_filter_state
self._metadata = result._metadata
@@ -602,7 +708,7 @@ class AsyncMappingResult(AsyncCommon[RowMapping]):
async def partitions(
self, size: Optional[int] = None
- ) -> AsyncIterator[List[RowMapping]]:
+ ) -> AsyncIterator[Sequence[RowMapping]]:
"""Iterate through sub-lists of elements of the size given.
@@ -621,7 +727,7 @@ class AsyncMappingResult(AsyncCommon[RowMapping]):
else:
break
- async def fetchall(self) -> List[RowMapping]:
+ async def fetchall(self) -> Sequence[RowMapping]:
"""A synonym for the :meth:`_asyncio.AsyncMappingResult.all` method."""
return await greenlet_spawn(self._allrows)
@@ -641,7 +747,9 @@ class AsyncMappingResult(AsyncCommon[RowMapping]):
else:
return row
- async def fetchmany(self, size: Optional[int] = None) -> List[RowMapping]:
+ async def fetchmany(
+ self, size: Optional[int] = None
+ ) -> Sequence[RowMapping]:
"""Fetch many rows.
Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that
@@ -652,7 +760,7 @@ class AsyncMappingResult(AsyncCommon[RowMapping]):
return await greenlet_spawn(self._manyrow_getter, self, size)
- async def all(self) -> List[RowMapping]:
+ async def all(self) -> Sequence[RowMapping]:
"""Return all rows in a list.
Equivalent to :meth:`_asyncio.AsyncResult.all` except that
@@ -705,11 +813,186 @@ class AsyncMappingResult(AsyncCommon[RowMapping]):
return await greenlet_spawn(self._only_one_row, True, True, False)
-_RT = TypeVar("_RT", bound="Result")
+SelfAsyncTupleResult = TypeVar(
+ "SelfAsyncTupleResult", bound="AsyncTupleResult[Any]"
+)
+
+
+class AsyncTupleResult(AsyncCommon[_R], util.TypingOnly):
+ """a :class:`.AsyncResult` that's typed as returning plain Python tuples
+ instead of rows.
+
+ Since :class:`.Row` acts like a tuple in every way already,
+ this class is a typing only class, regular :class:`.AsyncResult` is
+ still used at runtime.
+
+ """
+
+ __slots__ = ()
+
+ if TYPE_CHECKING:
+
+ async def partitions(
+ self, size: Optional[int] = None
+ ) -> AsyncIterator[Sequence[_R]]:
+ """Iterate through sub-lists of elements of the size given.
+
+ Equivalent to :meth:`_result.Result.partitions` except that
+ tuple values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ ...
+
+ async def fetchone(self) -> Optional[_R]:
+ """Fetch one tuple.
+
+ Equivalent to :meth:`_result.Result.fetchone` except that
+ tuple values, rather than :class:`_result.Row`
+ objects, are returned.
+
+ """
+ ...
+
+ async def fetchall(self) -> Sequence[_R]:
+ """A synonym for the :meth:`_engine.ScalarResult.all` method."""
+ ...
+
+ async def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]:
+ """Fetch many objects.
+
+ Equivalent to :meth:`_result.Result.fetchmany` except that
+ tuple values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ ...
+
+ async def all(self) -> Sequence[_R]: # noqa: A001
+ """Return all scalar values in a list.
+
+ Equivalent to :meth:`_result.Result.all` except that
+ tuple values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ ...
+
+ async def __aiter__(self) -> AsyncIterator[_R]:
+ ...
+
+ async def __anext__(self) -> _R:
+ ...
+
+ async def first(self) -> Optional[_R]:
+ """Fetch the first object or None if no object is present.
+
+ Equivalent to :meth:`_result.Result.first` except that
+ tuple values, rather than :class:`_result.Row` objects,
+ are returned.
+
+
+ """
+ ...
+
+ async def one_or_none(self) -> Optional[_R]:
+ """Return at most one object or raise an exception.
+
+ Equivalent to :meth:`_result.Result.one_or_none` except that
+ tuple values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ ...
+
+ async def one(self) -> _R:
+ """Return exactly one object or raise an exception.
+
+ Equivalent to :meth:`_result.Result.one` except that
+ tuple values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ ...
+
+ @overload
+ async def scalar_one(self: AsyncTupleResult[Tuple[_T]]) -> _T:
+ ...
+
+ @overload
+ async def scalar_one(self) -> Any:
+ ...
+
+ async def scalar_one(self) -> Any:
+ """Return exactly one scalar result or raise an exception.
+
+ This is equivalent to calling :meth:`.Result.scalars` and then
+ :meth:`.Result.one`.
+
+ .. seealso::
+
+ :meth:`.Result.one`
+
+ :meth:`.Result.scalars`
+
+ """
+ ...
+
+ @overload
+ async def scalar_one_or_none(
+ self: AsyncTupleResult[Tuple[_T]],
+ ) -> Optional[_T]:
+ ...
+
+ @overload
+ async def scalar_one_or_none(self) -> Optional[Any]:
+ ...
+
+ async def scalar_one_or_none(self) -> Optional[Any]:
+ """Return exactly one or no scalar result.
+
+ This is equivalent to calling :meth:`.Result.scalars` and then
+ :meth:`.Result.one_or_none`.
+
+ .. seealso::
+
+ :meth:`.Result.one_or_none`
+
+ :meth:`.Result.scalars`
+
+ """
+ ...
+
+ @overload
+ async def scalar(self: AsyncTupleResult[Tuple[_T]]) -> Optional[_T]:
+ ...
+
+ @overload
+ async def scalar(self) -> Any:
+ ...
+
+ async def scalar(self) -> Any:
+ """Fetch the first column of the first row, and close the result set.
+
+ Returns None if there are no rows to fetch.
+
+ No validation is performed to test if additional rows remain.
+
+ After calling this method, the object is fully closed,
+ e.g. the :meth:`_engine.CursorResult.close`
+ method will have been called.
+
+ :return: a Python scalar value , or None if no rows remain.
+
+ """
+ ...
+
+
+_RT = TypeVar("_RT", bound="Result[Any]")
async def _ensure_sync_result(result: _RT, calling_method: Any) -> _RT:
- cursor_result: CursorResult
+ cursor_result: CursorResult[Any]
try:
is_cursor = result._is_cursor
diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py
index c7a6e2ca0..22a060a0d 100644
--- a/lib/sqlalchemy/ext/asyncio/scoping.py
+++ b/lib/sqlalchemy/ext/asyncio/scoping.py
@@ -12,10 +12,12 @@ from typing import Callable
from typing import Iterable
from typing import Iterator
from typing import Optional
+from typing import overload
from typing import Sequence
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
from .session import async_sessionmaker
@@ -37,9 +39,9 @@ if TYPE_CHECKING:
from ...engine import Engine
from ...engine import Result
from ...engine import Row
+ from ...engine import RowMapping
from ...engine.interfaces import _CoreAnyExecuteParams
from ...engine.interfaces import _CoreSingleExecuteParams
- from ...engine.interfaces import _ExecuteOptions
from ...engine.interfaces import _ExecuteOptionsParameter
from ...engine.result import ScalarResult
from ...orm._typing import _IdentityKeyType
@@ -52,6 +54,9 @@ if TYPE_CHECKING:
from ...sql.base import Executable
from ...sql.elements import ClauseElement
from ...sql.selectable import ForUpdateArg
+ from ...sql.selectable import TypedReturnsRows
+
+_T = TypeVar("_T", bound=Any)
@create_proxy_methods(
@@ -480,6 +485,32 @@ class async_scoped_session:
return await self._proxied.delete(instance)
+ @overload
+ async def execute(
+ self,
+ statement: TypedReturnsRows[_T],
+ params: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ _parent_execute_state: Optional[Any] = None,
+ _add_event: Optional[Any] = None,
+ ) -> Result[_T]:
+ ...
+
+ @overload
+ async def execute(
+ self,
+ statement: Executable,
+ params: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ _parent_execute_state: Optional[Any] = None,
+ _add_event: Optional[Any] = None,
+ ) -> Result[Any]:
+ ...
+
async def execute(
self,
statement: Executable,
@@ -488,7 +519,7 @@ class async_scoped_session:
execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
- ) -> Result:
+ ) -> Result[Any]:
r"""Execute a statement and return a buffered
:class:`_engine.Result` object.
@@ -916,6 +947,30 @@ class async_scoped_session:
return await self._proxied.rollback()
+ @overload
+ async def scalar(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> Optional[_T]:
+ ...
+
+ @overload
+ async def scalar(
+ self,
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> Any:
+ ...
+
async def scalar(
self,
statement: Executable,
@@ -947,6 +1002,30 @@ class async_scoped_session:
**kw,
)
+ @overload
+ async def scalars(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> ScalarResult[_T]:
+ ...
+
+ @overload
+ async def scalars(
+ self,
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> ScalarResult[Any]:
+ ...
+
async def scalars(
self,
statement: Executable,
@@ -984,6 +1063,19 @@ class async_scoped_session:
**kw,
)
+ @overload
+ async def stream(
+ self,
+ statement: TypedReturnsRows[_T],
+ params: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> AsyncResult[_T]:
+ ...
+
+ @overload
async def stream(
self,
statement: Executable,
@@ -992,7 +1084,18 @@ class async_scoped_session:
execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
- ) -> AsyncResult:
+ ) -> AsyncResult[Any]:
+ ...
+
+ async def stream(
+ self,
+ statement: Executable,
+ params: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> AsyncResult[Any]:
r"""Execute a statement and return a streaming
:class:`_asyncio.AsyncResult` object.
@@ -1012,6 +1115,30 @@ class async_scoped_session:
**kw,
)
+ @overload
+ async def stream_scalars(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> AsyncScalarResult[_T]:
+ ...
+
+ @overload
+ async def stream_scalars(
+ self,
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> AsyncScalarResult[Any]:
+ ...
+
async def stream_scalars(
self,
statement: Executable,
@@ -1323,7 +1450,7 @@ class async_scoped_session:
ident: Union[Any, Tuple[Any, ...]] = None,
*,
instance: Optional[Any] = None,
- row: Optional[Row] = None,
+ row: Optional[Union[Row[Any], RowMapping]] = None,
identity_token: Optional[Any] = None,
) -> _IdentityKeyType[Any]:
r"""Return an identity key.
diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py
index 1422f99a3..f2a69e9cd 100644
--- a/lib/sqlalchemy/ext/asyncio/session.py
+++ b/lib/sqlalchemy/ext/asyncio/session.py
@@ -12,10 +12,12 @@ from typing import Iterable
from typing import Iterator
from typing import NoReturn
from typing import Optional
+from typing import overload
from typing import Sequence
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
from . import engine
@@ -39,11 +41,10 @@ if TYPE_CHECKING:
from ...engine import Engine
from ...engine import Result
from ...engine import Row
+ from ...engine import RowMapping
from ...engine import ScalarResult
- from ...engine import Transaction
from ...engine.interfaces import _CoreAnyExecuteParams
from ...engine.interfaces import _CoreSingleExecuteParams
- from ...engine.interfaces import _ExecuteOptions
from ...engine.interfaces import _ExecuteOptionsParameter
from ...event import dispatcher
from ...orm._typing import _IdentityKeyType
@@ -59,9 +60,12 @@ if TYPE_CHECKING:
from ...sql.base import Executable
from ...sql.elements import ClauseElement
from ...sql.selectable import ForUpdateArg
+ from ...sql.selectable import TypedReturnsRows
_AsyncSessionBind = Union["AsyncEngine", "AsyncConnection"]
+_T = TypeVar("_T", bound=Any)
+
class _SyncSessionCallable(Protocol):
def __call__(self, session: Session, *arg: Any, **kw: Any) -> Any:
@@ -257,6 +261,32 @@ class AsyncSession(ReversibleProxy[Session]):
return await greenlet_spawn(fn, self.sync_session, *arg, **kw)
+ @overload
+ async def execute(
+ self,
+ statement: TypedReturnsRows[_T],
+ params: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ _parent_execute_state: Optional[Any] = None,
+ _add_event: Optional[Any] = None,
+ ) -> Result[_T]:
+ ...
+
+ @overload
+ async def execute(
+ self,
+ statement: Executable,
+ params: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ _parent_execute_state: Optional[Any] = None,
+ _add_event: Optional[Any] = None,
+ ) -> Result[Any]:
+ ...
+
async def execute(
self,
statement: Executable,
@@ -265,7 +295,7 @@ class AsyncSession(ReversibleProxy[Session]):
execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
- ) -> Result:
+ ) -> Result[Any]:
"""Execute a statement and return a buffered
:class:`_engine.Result` object.
@@ -292,6 +322,30 @@ class AsyncSession(ReversibleProxy[Session]):
)
return await _ensure_sync_result(result, self.execute)
+ @overload
+ async def scalar(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> Optional[_T]:
+ ...
+
+ @overload
+ async def scalar(
+ self,
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> Any:
+ ...
+
async def scalar(
self,
statement: Executable,
@@ -326,6 +380,30 @@ class AsyncSession(ReversibleProxy[Session]):
)
return result
+ @overload
+ async def scalars(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> ScalarResult[_T]:
+ ...
+
+ @overload
+ async def scalars(
+ self,
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> ScalarResult[Any]:
+ ...
+
async def scalars(
self,
statement: Executable,
@@ -391,6 +469,30 @@ class AsyncSession(ReversibleProxy[Session]):
)
return result_obj
+ @overload
+ async def stream(
+ self,
+ statement: TypedReturnsRows[_T],
+ params: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> AsyncResult[_T]:
+ ...
+
+ @overload
+ async def stream(
+ self,
+ statement: Executable,
+ params: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> AsyncResult[Any]:
+ ...
+
async def stream(
self,
statement: Executable,
@@ -399,7 +501,7 @@ class AsyncSession(ReversibleProxy[Session]):
execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
- ) -> AsyncResult:
+ ) -> AsyncResult[Any]:
"""Execute a statement and return a streaming
:class:`_asyncio.AsyncResult` object.
@@ -423,6 +525,30 @@ class AsyncSession(ReversibleProxy[Session]):
)
return AsyncResult(result)
+ @overload
+ async def stream_scalars(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> AsyncScalarResult[_T]:
+ ...
+
+ @overload
+ async def stream_scalars(
+ self,
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> AsyncScalarResult[Any]:
+ ...
+
async def stream_scalars(
self,
statement: Executable,
@@ -1215,7 +1341,7 @@ class AsyncSession(ReversibleProxy[Session]):
ident: Union[Any, Tuple[Any, ...]] = None,
*,
instance: Optional[Any] = None,
- row: Optional[Row] = None,
+ row: Optional[Union[Row[Any], RowMapping]] = None,
identity_token: Optional[Any] = None,
) -> _IdentityKeyType[Any]:
r"""Return an identity key.
diff --git a/lib/sqlalchemy/ext/instrumentation.py b/lib/sqlalchemy/ext/instrumentation.py
index b1138a4ad..c14b466eb 100644
--- a/lib/sqlalchemy/ext/instrumentation.py
+++ b/lib/sqlalchemy/ext/instrumentation.py
@@ -23,6 +23,7 @@ from ..orm import base as orm_base
from ..orm import collections
from ..orm import exc as orm_exc
from ..orm import instrumentation as orm_instrumentation
+from ..orm import util as orm_util
from ..orm.instrumentation import _default_dict_getter
from ..orm.instrumentation import _default_manager_getter
from ..orm.instrumentation import _default_opt_manager_getter
@@ -437,5 +438,7 @@ def _install_lookups(lookups):
attributes.manager_of_class
) = orm_instrumentation.manager_of_class = manager_of_class
orm_base.opt_manager_of_class = (
+ orm_util.opt_manager_of_class
+ ) = (
attributes.opt_manager_of_class
) = orm_instrumentation.opt_manager_of_class = opt_manager_of_class
diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py
index 457ad5c5a..48615b174 100644
--- a/lib/sqlalchemy/orm/_orm_constructors.py
+++ b/lib/sqlalchemy/orm/_orm_constructors.py
@@ -38,6 +38,7 @@ from ..exc import InvalidRequestError
from ..sql.base import SchemaEventTarget
from ..sql.schema import SchemaConst
from ..sql.selectable import FromClause
+from ..util.typing import Annotated
from ..util.typing import Literal
if TYPE_CHECKING:
@@ -45,6 +46,7 @@ if TYPE_CHECKING:
from ._typing import _ORMColumnExprArgument
from .descriptor_props import _CompositeAttrType
from .interfaces import PropComparator
+ from .mapper import Mapper
from .query import Query
from .relationships import _LazyLoadArgumentType
from .relationships import _ORMBackrefArgument
@@ -1849,9 +1851,27 @@ def clear_mappers():
mapperlib._dispose_registries(mapperlib._all_registries(), False)
+# I would really like a way to get the Type[] here that shows up
+# in a different way in typing tools, however there is no current method
+# that is accepted by mypy (subclass of Type[_O] works in pylance, rejected
+# by mypy).
+AliasedType = Annotated[Type[_O], "aliased"]
+
+
+@overload
+def aliased(
+ element: Type[_O],
+ alias: Optional[Union[Alias, Subquery]] = None,
+ name: Optional[str] = None,
+ flat: bool = False,
+ adapt_on_names: bool = False,
+) -> AliasedType[_O]:
+ ...
+
+
@overload
def aliased(
- element: _EntityType[_O],
+ element: Union[AliasedClass[_O], Mapper[_O], AliasedInsp[_O]],
alias: Optional[Union[Alias, Subquery]] = None,
name: Optional[str] = None,
flat: bool = False,
@@ -1877,7 +1897,7 @@ def aliased(
name: Optional[str] = None,
flat: bool = False,
adapt_on_names: bool = False,
-) -> Union[AliasedClass[_O], FromClause]:
+) -> Union[AliasedClass[_O], FromClause, AliasedType[_O]]:
"""Produce an alias of the given element, usually an :class:`.AliasedClass`
instance.
@@ -1885,7 +1905,8 @@ def aliased(
my_alias = aliased(MyClass)
- session.query(MyClass, my_alias).filter(MyClass.id > my_alias.id)
+ stmt = select(MyClass, my_alias).filter(MyClass.id > my_alias.id)
+ result = session.execute(stmt)
The :func:`.aliased` function is used to create an ad-hoc mapping of a
mapped class to a new selectable. By default, a selectable is generated
@@ -1911,6 +1932,9 @@ def aliased(
.. seealso::
+ :class:`.AsAliased` - a :pep:`484` typed version of
+ :func:`_orm.aliased`
+
:ref:`tutorial_orm_entity_aliases` - in the :ref:`unified_tutorial`
:ref:`orm_queryguide_orm_aliases` - in the :ref:`queryguide_toplevel`
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index 41d944c57..619af6510 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -70,6 +70,7 @@ from .. import exc
from .. import inspection
from .. import util
from ..sql import base as sql_base
+from ..sql import cache_key
from ..sql import roles
from ..sql import traversals
from ..sql import visitors
@@ -99,10 +100,8 @@ class QueryableAttribute(
traversals.HasCopyInternals,
roles.JoinTargetRole,
roles.OnClauseRole,
- roles.ColumnsClauseRole,
- roles.ExpressionElementRole[_T],
sql_base.Immutable,
- sql_base.MemoizedHasCacheKey,
+ cache_key.MemoizedHasCacheKey,
):
"""Base class for :term:`descriptor` objects that intercept
attribute events on behalf of a :class:`.MapperProperty`
diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py
index 054d52d83..367a5332d 100644
--- a/lib/sqlalchemy/orm/base.py
+++ b/lib/sqlalchemy/orm/base.py
@@ -30,6 +30,7 @@ from ._typing import insp_is_mapper
from .. import exc as sa_exc
from .. import inspection
from .. import util
+from ..sql import roles
from ..sql.elements import SQLCoreOperations
from ..util import FastIntFlag
from ..util.langhelpers import TypingOnly
@@ -483,19 +484,6 @@ def _inspect_mapped_class(
return mapper
-@inspection._inspects(type)
-def _inspect_mc(class_: Type[_O]) -> Optional[Mapper[_O]]:
- try:
- class_manager = opt_manager_of_class(class_)
- if class_manager is None or not class_manager.is_mapped:
- return None
- mapper = class_manager.mapper
- except exc.NO_STATE:
- return None
- else:
- return mapper
-
-
def _parse_mapper_argument(arg: Union[Mapper[_O], Type[_O]]) -> Mapper[_O]:
insp = inspection.inspect(arg, raiseerr=False)
if insp_is_mapper(insp):
@@ -691,7 +679,7 @@ class ORMDescriptor(Generic[_T], TypingOnly):
...
-class Mapped(ORMDescriptor[_T], TypingOnly):
+class Mapped(ORMDescriptor[_T], roles.TypedColumnsClauseRole[_T], TypingOnly):
"""Represent an ORM mapped attribute on a mapped class.
This class represents the complete descriptor interface for any class
diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py
index 4fee2d383..05287cbcf 100644
--- a/lib/sqlalchemy/orm/context.py
+++ b/lib/sqlalchemy/orm/context.py
@@ -17,6 +17,7 @@ from typing import Set
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
from . import attributes
@@ -48,14 +49,15 @@ from ..sql.base import _select_iterables
from ..sql.base import CacheableOptions
from ..sql.base import CompileState
from ..sql.base import Executable
+from ..sql.base import Generative
from ..sql.base import Options
from ..sql.dml import UpdateBase
from ..sql.elements import GroupedElement
from ..sql.elements import TextClause
+from ..sql.selectable import ExecutableReturnsRows
from ..sql.selectable import LABEL_STYLE_DISAMBIGUATE_ONLY
from ..sql.selectable import LABEL_STYLE_NONE
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
-from ..sql.selectable import ReturnsRows
from ..sql.selectable import Select
from ..sql.selectable import SelectLabelStyle
from ..sql.selectable import SelectState
@@ -72,6 +74,7 @@ if TYPE_CHECKING:
from ..sql.selectable import SelectBase
from ..sql.type_api import TypeEngine
+_T = TypeVar("_T", bound=Any)
_path_registry = PathRegistry.root
_EMPTY_DICT = util.immutabledict()
@@ -574,7 +577,7 @@ class ORMFromStatementCompileState(ORMCompileState):
return None
-class FromStatement(GroupedElement, ReturnsRows, Executable):
+class FromStatement(GroupedElement, Generative, ExecutableReturnsRows):
"""Core construct that represents a load of ORM objects from various
:class:`.ReturnsRows` and other classes including:
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index 0ca62b7e3..6a5690be2 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -61,7 +61,6 @@ from ..sql.schema import Column
from ..sql.type_api import TypeEngine
from ..util.typing import TypedDict
-
if typing.TYPE_CHECKING:
from ._typing import _EntityType
from ._typing import _IdentityKeyType
@@ -106,12 +105,12 @@ class ORMStatementRole(roles.StatementRole):
)
-class ORMColumnsClauseRole(roles.ColumnsClauseRole):
+class ORMColumnsClauseRole(roles.TypedColumnsClauseRole[_T]):
__slots__ = ()
_role_name = "ORM mapped entity, aliased entity, or Column expression"
-class ORMEntityColumnsClauseRole(ORMColumnsClauseRole):
+class ORMEntityColumnsClauseRole(ORMColumnsClauseRole[_T]):
__slots__ = ()
_role_name = "ORM mapped or aliased entity"
@@ -127,8 +126,8 @@ class ORMColumnDescription(TypedDict):
# into "type" is a bad idea
type: Union[Type[Any], TypeEngine[Any]]
aliased: bool
- expr: _ColumnsClauseArgument
- entity: Optional[_ColumnsClauseArgument]
+ expr: _ColumnsClauseArgument[Any]
+ entity: Optional[_ColumnsClauseArgument[Any]]
class _IntrospectsAnnotations:
@@ -282,7 +281,7 @@ class MapperProperty(
query_entity: _MapperEntity,
path: PathRegistry,
mapper: Mapper[Any],
- result: Result,
+ result: Result[Any],
adapter: Optional[ColumnAdapter],
populators: _PopulatorDict,
) -> None:
@@ -1170,7 +1169,7 @@ class LoaderStrategy:
path: AbstractEntityRegistry,
loadopt: Optional[_LoadElement],
mapper: Mapper[Any],
- result: Result,
+ result: Result[Any],
adapter: Optional[ORMAdapter],
populators: _PopulatorDict,
) -> None:
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index b37c080ea..083035093 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -96,7 +96,6 @@ if TYPE_CHECKING:
from .descriptor_props import Synonym
from .events import MapperEvents
from .instrumentation import ClassManager
- from .path_registry import AbstractEntityRegistry
from .path_registry import CachingEntityRegistry
from .properties import ColumnProperty
from .relationships import Relationship
@@ -108,10 +107,10 @@ if TYPE_CHECKING:
from ..sql.base import ReadOnlyColumnCollection
from ..sql.elements import ColumnClause
from ..sql.elements import ColumnElement
+ from ..sql.elements import KeyedColumnElement
from ..sql.schema import Column
from ..sql.schema import Table
from ..sql.selectable import FromClause
- from ..sql.selectable import TableClause
from ..sql.util import ColumnAdapter
from ..util import OrderedSet
@@ -161,7 +160,7 @@ _CONFIGURE_MUTEX = threading.RLock()
@log.class_logger
class Mapper(
ORMFromClauseRole,
- ORMEntityColumnsClauseRole,
+ ORMEntityColumnsClauseRole[_O],
MemoizedHasCacheKey,
InspectionAttr,
log.Identified,
@@ -1006,7 +1005,7 @@ class Mapper(
"""
- polymorphic_on: Optional[ColumnElement[Any]]
+ polymorphic_on: Optional[KeyedColumnElement[Any]]
"""The :class:`_schema.Column` or SQL expression specified as the
``polymorphic_on`` argument
for this :class:`_orm.Mapper`, within an inheritance scenario.
@@ -1699,10 +1698,10 @@ class Mapper(
instrument = True
key = getattr(col, "key", None)
if key:
- if self._should_exclude(col.key, col.key, False, col):
+ if self._should_exclude(key, key, False, col):
raise sa_exc.InvalidRequestError(
"Cannot exclude or override the "
- "discriminator column %r" % col.key
+ "discriminator column %r" % key
)
else:
self.polymorphic_on = col = col.label("_sa_polymorphic_on")
@@ -2948,7 +2947,7 @@ class Mapper(
def identity_key_from_row(
self,
- row: Optional[Union[Row, RowMapping]],
+ row: Optional[Union[Row[Any], RowMapping]],
identity_token: Optional[Any] = None,
adapter: Optional[ColumnAdapter] = None,
) -> _IdentityKeyType[_O]:
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 9f37e8457..0ca0559b4 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -54,7 +54,7 @@ from ..util.typing import NoneType
if TYPE_CHECKING:
from ._typing import _ORMColumnExprArgument
from ..sql._typing import _InfoType
- from ..sql.elements import ColumnElement
+ from ..sql.elements import KeyedColumnElement
_T = TypeVar("_T", bound=Any)
_PT = TypeVar("_PT", bound=Any)
@@ -85,7 +85,8 @@ class ColumnProperty(
inherit_cache = True
_links_to_entity = False
- columns: List[ColumnElement[Any]]
+ columns: List[KeyedColumnElement[Any]]
+ _orig_columns: List[KeyedColumnElement[Any]]
_is_polymorphic_discriminator: bool
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 395d01a1e..5bd302b21 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -27,6 +27,8 @@ from typing import Generic
from typing import Iterable
from typing import List
from typing import Optional
+from typing import overload
+from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import TypeVar
@@ -36,6 +38,7 @@ from . import exc as orm_exc
from . import interfaces
from . import loading
from . import util as orm_util
+from ._typing import _O
from .base import _assertions
from .context import _column_descriptions
from .context import _determine_last_joined_entity
@@ -56,6 +59,7 @@ from .. import log
from .. import sql
from .. import util
from ..engine import Result
+from ..engine import Row
from ..sql import coercions
from ..sql import expression
from ..sql import roles
@@ -63,10 +67,12 @@ from ..sql import Select
from ..sql import util as sql_util
from ..sql import visitors
from ..sql._typing import _FromClauseArgument
+from ..sql._typing import _TP
from ..sql.annotation import SupportsCloneAnnotations
from ..sql.base import _entity_namespace_key
from ..sql.base import _generative
from ..sql.base import Executable
+from ..sql.base import Generative
from ..sql.expression import Exists
from ..sql.selectable import _MemoizedSelectEntities
from ..sql.selectable import _SelectFromElements
@@ -75,10 +81,33 @@ from ..sql.selectable import HasHints
from ..sql.selectable import HasPrefixes
from ..sql.selectable import HasSuffixes
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+from ..util.typing import Literal
if TYPE_CHECKING:
+ from ._typing import _EntityType
+ from .session import Session
+ from ..engine.result import ScalarResult
+ from ..engine.row import Row
+ from ..sql._typing import _ColumnExpressionArgument
+ from ..sql._typing import _ColumnsClauseArgument
+ from ..sql._typing import _MAYBE_ENTITY
+ from ..sql._typing import _no_kw
+ from ..sql._typing import _NOT_ENTITY
+ from ..sql._typing import _PropagateAttrsType
+ from ..sql._typing import _T0
+ from ..sql._typing import _T1
+ from ..sql._typing import _T2
+ from ..sql._typing import _T3
+ from ..sql._typing import _T4
+ from ..sql._typing import _T5
+ from ..sql._typing import _T6
+ from ..sql._typing import _T7
+ from ..sql._typing import _TypedColumnClauseArgument as _TCCA
+ from ..sql.roles import TypedColumnsClauseRole
from ..sql.selectable import _SetupJoinsElement
from ..sql.selectable import Alias
+ from ..sql.selectable import ExecutableReturnsRows
+ from ..sql.selectable import ScalarSelect
from ..sql.selectable import Subquery
__all__ = ["Query", "QueryContext"]
@@ -97,6 +126,7 @@ class Query(
HasSuffixes,
HasHints,
log.Identified,
+ Generative,
Executable,
Generic[_T],
):
@@ -159,9 +189,15 @@ class Query(
# mirrors that of ClauseElement, used to propagate the "orm"
# plugin as well as the "subject" of the plugin, e.g. the mapper
# we are querying against.
- _propagate_attrs = util.immutabledict()
+ @util.memoized_property
+ def _propagate_attrs(self) -> _PropagateAttrsType:
+ return util.EMPTY_DICT
- def __init__(self, entities, session=None):
+ def __init__(
+ self,
+ entities: Sequence[_ColumnsClauseArgument[Any]],
+ session: Optional[Session] = None,
+ ):
"""Construct a :class:`_query.Query` directly.
E.g.::
@@ -207,6 +243,36 @@ class Query(
for ent in util.to_list(entities)
]
+ @overload
+ def tuples(self: Query[Row[_TP]]) -> Query[_TP]:
+ ...
+
+ @overload
+ def tuples(self: Query[_O]) -> Query[Tuple[_O]]:
+ ...
+
+ def tuples(self) -> Query[Any]:
+ """return a tuple-typed form of this :class:`.Query`.
+
+ This method invokes the :meth:`.Query.only_return_tuples`
+ method with a value of ``True``, which by itself ensures that this
+ :class:`.Query` will always return :class:`.Row` objects, even
+ if the query is made against a single entity. It then also
+ at the typing level will return a "typed" query, if possible,
+ that will type result rows as ``Tuple`` objects with typed
+ elements.
+
+ This method can be compared to the :meth:`.Result.tuples` method,
+ which returns "self", but from a typing perspective returns an object
+ that will yield typed ``Tuple`` objects for results. Typing
+ takes effect only if this :class:`.Query` object is a typed
+ query object already.
+
+ .. versionadded:: 2.0
+
+ """
+ return self.only_return_tuples(True)
+
def _entity_from_pre_ent_zero(self):
if not self._raw_columns:
return None
@@ -582,20 +648,52 @@ class Query(
return self.enable_eagerloads(False).statement.label(name)
+ @overload
+ def as_scalar(
+ self: Query[Tuple[_MAYBE_ENTITY]],
+ ) -> ScalarSelect[_MAYBE_ENTITY]:
+ ...
+
+ @overload
+ def as_scalar(
+ self: Query[Tuple[_NOT_ENTITY]],
+ ) -> ScalarSelect[_NOT_ENTITY]:
+ ...
+
+ @overload
+ def as_scalar(self) -> ScalarSelect[Any]:
+ ...
+
@util.deprecated(
"1.4",
"The :meth:`_query.Query.as_scalar` method is deprecated and will be "
"removed in a future release. Please refer to "
":meth:`_query.Query.scalar_subquery`.",
)
- def as_scalar(self):
+ def as_scalar(self) -> ScalarSelect[Any]:
"""Return the full SELECT statement represented by this
:class:`_query.Query`, converted to a scalar subquery.
"""
return self.scalar_subquery()
- def scalar_subquery(self):
+ @overload
+ def scalar_subquery(
+ self: Query[Tuple[_MAYBE_ENTITY]],
+ ) -> ScalarSelect[Any]:
+ ...
+
+ @overload
+ def scalar_subquery(
+ self: Query[Tuple[_NOT_ENTITY]],
+ ) -> ScalarSelect[_NOT_ENTITY]:
+ ...
+
+ @overload
+ def scalar_subquery(self) -> ScalarSelect[Any]:
+ ...
+
+ def scalar_subquery(self) -> ScalarSelect[Any]:
"""Return the full SELECT statement represented by this
:class:`_query.Query`, converted to a scalar subquery.
@@ -630,16 +728,31 @@ class Query(
.statement
)
- @_generative
- def only_return_tuples(self: SelfQuery, value) -> SelfQuery:
- """When set to True, the query results will always be a tuple.
+ @overload
+ def only_return_tuples(
+ self: Query[_O], value: Literal[True]
+ ) -> RowReturningQuery[Tuple[_O]]:
+ ...
- This is specifically for single element queries. The default is False.
+ @overload
+ def only_return_tuples(
+ self: Query[_O], value: Literal[False]
+ ) -> Query[_O]:
+ ...
- .. versionadded:: 1.2.5
+ @_generative
+ def only_return_tuples(self, value: bool) -> Query[Any]:
+ """When set to True, the query results will always be a
+ :class:`.Row` object.
+
+ This can change a query that normally returns a single entity
+ as a scalar to return a :class:`.Row` result in all cases.
.. seealso::
+ :meth:`.Query.tuples` - returns tuples, but also at the typing
+ level will type results as ``Tuple``.
+
:meth:`_query.Query.is_single_entity`
"""
@@ -1077,7 +1190,11 @@ class Query(
return self.filter(with_parent(instance, property, entity_zero.entity))
@_generative
- def add_entity(self: SelfQuery, entity, alias=None) -> SelfQuery:
+ def add_entity(
+ self,
+ entity: _EntityType[Any],
+ alias: Optional[Union[Alias, Subquery]] = None,
+ ) -> Query[Any]:
"""add a mapped entity to the list of result columns
to be returned."""
@@ -1209,8 +1326,107 @@ class Query(
except StopIteration:
return None
+ @overload
+ def with_entities(
+ self, _entity: _EntityType[_O], **kwargs: Any
+ ) -> ScalarInstanceQuery[_O]:
+ ...
+
+ @overload
+ def with_entities(
+ self, _colexpr: TypedColumnsClauseRole[_T]
+ ) -> RowReturningQuery[Tuple[_T]]:
+ ...
+
+ # START OVERLOADED FUNCTIONS self.with_entities RowReturningQuery 2-8
+
+ # code within this block is **programmatically,
+ # statically generated** by tools/generate_tuple_map_overloads.py
+
+ @overload
+ def with_entities(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1]
+ ) -> RowReturningQuery[Tuple[_T0, _T1]]:
+ ...
+
+ @overload
+ def with_entities(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2]
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]:
+ ...
+
+ @overload
+ def with_entities(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]:
+ ...
+
+ @overload
+ def with_entities(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]:
+ ...
+
+ @overload
+ def with_entities(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]:
+ ...
+
+ @overload
+ def with_entities(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]:
+ ...
+
+ @overload
+ def with_entities(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ __ent7: _TCCA[_T7],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]:
+ ...
+
+ # END OVERLOADED FUNCTIONS self.with_entities
+
+ @overload
+ def with_entities(
+ self: SelfQuery, *entities: _ColumnsClauseArgument[Any]
+ ) -> SelfQuery:
+ ...
+
@_generative
- def with_entities(self: SelfQuery, *entities) -> SelfQuery:
+ def with_entities(
+ self: SelfQuery, *entities: _ColumnsClauseArgument[Any], **__kw: Any
+ ) -> SelfQuery:
r"""Return a new :class:`_query.Query`
replacing the SELECT list with the
given entities.
@@ -1234,12 +1450,14 @@ class Query(
limit(1)
"""
+ if __kw:
+ raise _no_kw()
_MemoizedSelectEntities._generate_for_statement(self)
self._set_entities(entities)
return self
@_generative
- def add_columns(self: SelfQuery, *column) -> SelfQuery:
+ def add_columns(self, *column: _ColumnExpressionArgument) -> Query[Any]:
"""Add one or more column expressions to the list
of result columns to be returned."""
@@ -1262,7 +1480,7 @@ class Query(
"is deprecated and will be removed in a "
"future release. Please use :meth:`_query.Query.add_columns`",
)
- def add_column(self, column):
+ def add_column(self, column) -> Query[Any]:
"""Add a column expression to the list of result columns to be
returned.
@@ -1472,7 +1690,9 @@ class Query(
@_generative
@_assertions(_no_statement_condition, _no_limit_offset)
- def filter(self: SelfQuery, *criterion) -> SelfQuery:
+ def filter(
+ self: SelfQuery, *criterion: _ColumnExpressionArgument[bool]
+ ) -> SelfQuery:
r"""Apply the given filtering criterion to a copy
of this :class:`_query.Query`, using SQL expressions.
@@ -1556,7 +1776,7 @@ class Query(
return self._raw_columns[0]
- def filter_by(self, **kwargs):
+ def filter_by(self: SelfQuery, **kwargs: Any) -> SelfQuery:
r"""Apply the given filtering criterion to a copy
of this :class:`_query.Query`, using keyword expressions.
@@ -1597,7 +1817,9 @@ class Query(
@_generative
@_assertions(_no_statement_condition, _no_limit_offset)
- def order_by(self: SelfQuery, *clauses) -> SelfQuery:
+ def order_by(
+ self: SelfQuery, *clauses: _ColumnExpressionArgument[Any]
+ ) -> SelfQuery:
"""Apply one or more ORDER BY criteria to the query and return
the newly resulting :class:`_query.Query`.
@@ -1635,7 +1857,9 @@ class Query(
@_generative
@_assertions(_no_statement_condition, _no_limit_offset)
- def group_by(self: SelfQuery, *clauses) -> SelfQuery:
+ def group_by(
+ self: SelfQuery, *clauses: _ColumnExpressionArgument[Any]
+ ) -> SelfQuery:
"""Apply one or more GROUP BY criterion to the query and return
the newly resulting :class:`_query.Query`.
@@ -1667,7 +1891,9 @@ class Query(
@_generative
@_assertions(_no_statement_condition, _no_limit_offset)
- def having(self: SelfQuery, criterion) -> SelfQuery:
+ def having(
+ self: SelfQuery, *having: _ColumnExpressionArgument[bool]
+ ) -> SelfQuery:
r"""Apply a HAVING criterion to the query and return the
newly resulting :class:`_query.Query`.
@@ -1684,17 +1910,17 @@ class Query(
"""
- self._having_criteria += (
- coercions.expect(
- roles.WhereHavingRole, criterion, apply_propagate_attrs=self
- ),
- )
+ for criterion in having:
+ having_criteria = coercions.expect(
+ roles.WhereHavingRole, criterion
+ )
+ self._having_criteria += (having_criteria,)
return self
def _set_op(self, expr_fn, *q):
return self._from_selectable(expr_fn(*([self] + list(q))).subquery())
- def union(self, *q):
+ def union(self: SelfQuery, *q: Query[Any]) -> SelfQuery:
"""Produce a UNION of this Query against one or more queries.
e.g.::
@@ -1733,7 +1959,7 @@ class Query(
"""
return self._set_op(expression.union, *q)
- def union_all(self, *q):
+ def union_all(self: SelfQuery, *q: Query[Any]) -> SelfQuery:
"""Produce a UNION ALL of this Query against one or more queries.
Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See
@@ -1742,7 +1968,7 @@ class Query(
"""
return self._set_op(expression.union_all, *q)
- def intersect(self, *q):
+ def intersect(self: SelfQuery, *q: Query[Any]) -> SelfQuery:
"""Produce an INTERSECT of this Query against one or more queries.
Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See
@@ -1751,7 +1977,7 @@ class Query(
"""
return self._set_op(expression.intersect, *q)
- def intersect_all(self, *q):
+ def intersect_all(self: SelfQuery, *q: Query[Any]) -> SelfQuery:
"""Produce an INTERSECT ALL of this Query against one or more queries.
Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See
@@ -1760,7 +1986,7 @@ class Query(
"""
return self._set_op(expression.intersect_all, *q)
- def except_(self, *q):
+ def except_(self: SelfQuery, *q: Query[Any]) -> SelfQuery:
"""Produce an EXCEPT of this Query against one or more queries.
Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See
@@ -1769,7 +1995,7 @@ class Query(
"""
return self._set_op(expression.except_, *q)
- def except_all(self, *q):
+ def except_all(self: SelfQuery, *q: Query[Any]) -> SelfQuery:
"""Produce an EXCEPT ALL of this Query against one or more queries.
Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See
@@ -2194,7 +2420,9 @@ class Query(
@_generative
@_assertions(_no_clauseelement_condition)
- def from_statement(self: SelfQuery, statement) -> SelfQuery:
+ def from_statement(
+ self: SelfQuery, statement: ExecutableReturnsRows
+ ) -> SelfQuery:
"""Execute the given SELECT statement and return results.
This method bypasses all internal statement compilation, and the
@@ -2283,7 +2511,7 @@ class Query(
:meth:`_query.Query.one_or_none`
"""
- return self._iter().one()
+ return self._iter().one() # type: ignore
def scalar(self) -> Any:
"""Return the first element of the first result or None
@@ -2316,7 +2544,7 @@ class Query(
def __iter__(self) -> Iterable[_T]:
return self._iter().__iter__()
- def _iter(self):
+ def _iter(self) -> Union[ScalarResult[_T], Result[_T]]:
# new style execution.
params = self._params
@@ -2837,3 +3065,7 @@ class BulkUpdate(BulkUD):
class BulkDelete(BulkUD):
"""BulkUD which handles DELETEs."""
+
+
+class RowReturningQuery(Query[Row[_TP]]):
+ pass
diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py
index 93d18b8d7..9220c44c7 100644
--- a/lib/sqlalchemy/orm/scoping.py
+++ b/lib/sqlalchemy/orm/scoping.py
@@ -13,6 +13,7 @@ from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import Optional
+from typing import overload
from typing import Sequence
from typing import Tuple
from typing import Type
@@ -20,8 +21,6 @@ from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
-from . import exc as orm_exc
-from .base import class_mapper
from .session import Session
from .. import exc as sa_exc
from .. import util
@@ -33,11 +32,13 @@ from ..util import warn_deprecated
from ..util.typing import Protocol
if TYPE_CHECKING:
+ from ._typing import _EntityType
from ._typing import _IdentityKeyType
from .identity import IdentityMap
from .interfaces import ORMOption
from .mapper import Mapper
from .query import Query
+ from .query import RowReturningQuery
from .session import _BindArguments
from .session import _EntityBindKey
from .session import _PKIdentityArgument
@@ -48,19 +49,33 @@ if TYPE_CHECKING:
from ..engine import Engine
from ..engine import Result
from ..engine import Row
+ from ..engine import RowMapping
from ..engine.interfaces import _CoreAnyExecuteParams
from ..engine.interfaces import _CoreSingleExecuteParams
from ..engine.interfaces import _ExecuteOptions
from ..engine.interfaces import _ExecuteOptionsParameter
from ..engine.result import ScalarResult
from ..sql._typing import _ColumnsClauseArgument
+ from ..sql._typing import _T0
+ from ..sql._typing import _T1
+ from ..sql._typing import _T2
+ from ..sql._typing import _T3
+ from ..sql._typing import _T4
+ from ..sql._typing import _T5
+ from ..sql._typing import _T6
+ from ..sql._typing import _T7
+ from ..sql._typing import _TypedColumnClauseArgument as _TCCA
from ..sql.base import Executable
from ..sql.elements import ClauseElement
+ from ..sql.roles import TypedColumnsClauseRole
from ..sql.selectable import ForUpdateArg
+ from ..sql.selectable import TypedReturnsRows
+
+_T = TypeVar("_T", bound=Any)
class _QueryDescriptorType(Protocol):
- def __get__(self, instance: Any, owner: Type[Any]) -> Optional[Query[Any]]:
+ def __get__(self, instance: Any, owner: Type[_T]) -> Query[_T]:
...
@@ -236,7 +251,7 @@ class scoped_session:
self.registry.clear()
def query_property(
- self, query_cls: Optional[Type[Query[Any]]] = None
+ self, query_cls: Optional[Type[Query[_T]]] = None
) -> _QueryDescriptorType:
"""return a class property which produces a :class:`_query.Query`
object
@@ -264,20 +279,13 @@ class scoped_session:
"""
class query:
- def __get__(
- s, instance: Any, owner: Type[Any]
- ) -> Optional[Query[Any]]:
- try:
- mapper = class_mapper(owner)
- assert mapper is not None
- if query_cls:
- # custom query class
- return query_cls(mapper, session=self.registry())
- else:
- # session's configured query class
- return self.registry().query(mapper)
- except orm_exc.UnmappedClassError:
- return None
+ def __get__(s, instance: Any, owner: Type[_O]) -> Query[_O]:
+ if query_cls:
+ # custom query class
+ return query_cls(owner, session=self.registry()) # type: ignore # noqa: E501
+ else:
+ # session's configured query class
+ return self.registry().query(owner)
return query()
@@ -548,6 +556,32 @@ class scoped_session:
return self._proxied.delete(instance)
+ @overload
+ def execute(
+ self,
+ statement: TypedReturnsRows[_T],
+ params: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ _parent_execute_state: Optional[Any] = None,
+ _add_event: Optional[Any] = None,
+ ) -> Result[_T]:
+ ...
+
+ @overload
+ def execute(
+ self,
+ statement: Executable,
+ params: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ _parent_execute_state: Optional[Any] = None,
+ _add_event: Optional[Any] = None,
+ ) -> Result[Any]:
+ ...
+
def execute(
self,
statement: Executable,
@@ -557,7 +591,7 @@ class scoped_session:
bind_arguments: Optional[_BindArguments] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
- ) -> Result:
+ ) -> Result[Any]:
r"""Execute a SQL expression construct.
.. container:: class_bases
@@ -1430,8 +1464,103 @@ class scoped_session:
return self._proxied.merge(instance, load=load, options=options)
+ @overload
+ def query(self, _entity: _EntityType[_O]) -> Query[_O]:
+ ...
+
+ @overload
def query(
- self, *entities: _ColumnsClauseArgument, **kwargs: Any
+ self, _colexpr: TypedColumnsClauseRole[_T]
+ ) -> RowReturningQuery[Tuple[_T]]:
+ ...
+
+ # START OVERLOADED FUNCTIONS self.query RowReturningQuery 2-8
+
+ # code within this block is **programmatically,
+ # statically generated** by tools/generate_tuple_map_overloads.py
+
+ @overload
+ def query(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1]
+ ) -> RowReturningQuery[Tuple[_T0, _T1]]:
+ ...
+
+ @overload
+ def query(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2]
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]:
+ ...
+
+ @overload
+ def query(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]:
+ ...
+
+ @overload
+ def query(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]:
+ ...
+
+ @overload
+ def query(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]:
+ ...
+
+ @overload
+ def query(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]:
+ ...
+
+ @overload
+ def query(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ __ent7: _TCCA[_T7],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]:
+ ...
+
+ # END OVERLOADED FUNCTIONS self.query
+
+ @overload
+ def query(
+ self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any
+ ) -> Query[Any]:
+ ...
+
+ def query(
+ self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any
) -> Query[Any]:
r"""Return a new :class:`_query.Query` object corresponding to this
:class:`_orm.Session`.
@@ -1559,6 +1688,30 @@ class scoped_session:
return self._proxied.rollback()
+ @overload
+ def scalar(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> Optional[_T]:
+ ...
+
+ @overload
+ def scalar(
+ self,
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> Any:
+ ...
+
def scalar(
self,
statement: Executable,
@@ -1590,6 +1743,30 @@ class scoped_session:
**kw,
)
+ @overload
+ def scalars(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> ScalarResult[_T]:
+ ...
+
+ @overload
+ def scalars(
+ self,
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> ScalarResult[Any]:
+ ...
+
def scalars(
self,
statement: Executable,
@@ -1848,7 +2025,7 @@ class scoped_session:
ident: Union[Any, Tuple[Any, ...]] = None,
*,
instance: Optional[Any] = None,
- row: Optional[Row] = None,
+ row: Optional[Union[Row[Any], RowMapping]] = None,
identity_token: Optional[Any] = None,
) -> _IdentityKeyType[Any]:
r"""Return an identity key.
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 74035ec0a..263d56101 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -27,6 +27,7 @@ from typing import Set
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
import weakref
@@ -85,12 +86,16 @@ from ..util.typing import Literal
from ..util.typing import Protocol
if typing.TYPE_CHECKING:
+ from ._typing import _EntityType
from ._typing import _IdentityKeyType
from ._typing import _InstanceDict
+ from ._typing import _O
+ from .context import FromStatement
from .interfaces import ORMOption
from .interfaces import UserDefinedOption
from .mapper import Mapper
from .path_registry import PathRegistry
+ from .query import RowReturningQuery
from ..engine import Result
from ..engine import Row
from ..engine import RowMapping
@@ -104,10 +109,23 @@ if typing.TYPE_CHECKING:
from ..event import _InstanceLevelDispatch
from ..sql._typing import _ColumnsClauseArgument
from ..sql._typing import _InfoType
+ from ..sql._typing import _T0
+ from ..sql._typing import _T1
+ from ..sql._typing import _T2
+ from ..sql._typing import _T3
+ from ..sql._typing import _T4
+ from ..sql._typing import _T5
+ from ..sql._typing import _T6
+ from ..sql._typing import _T7
+ from ..sql._typing import _TypedColumnClauseArgument as _TCCA
from ..sql.base import Executable
from ..sql.elements import ClauseElement
+ from ..sql.roles import TypedColumnsClauseRole
from ..sql.schema import Table
- from ..sql.selectable import TableClause
+ from ..sql.selectable import Select
+ from ..sql.selectable import TypedReturnsRows
+
+_T = TypeVar("_T", bound=Any)
__all__ = [
"Session",
@@ -189,7 +207,7 @@ class _SessionClassMethods:
ident: Union[Any, Tuple[Any, ...]] = None,
*,
instance: Optional[Any] = None,
- row: Optional[Union[Row, RowMapping]] = None,
+ row: Optional[Union[Row[Any], RowMapping]] = None,
identity_token: Optional[Any] = None,
) -> _IdentityKeyType[Any]:
"""Return an identity key.
@@ -295,7 +313,7 @@ class ORMExecuteState(util.MemoizedSlots):
params: Optional[_CoreAnyExecuteParams] = None,
execution_options: Optional[_ExecuteOptionsParameter] = None,
bind_arguments: Optional[_BindArguments] = None,
- ) -> Result:
+ ) -> Result[Any]:
"""Execute the statement represented by this
:class:`.ORMExecuteState`, without re-invoking events that have
already proceeded.
@@ -1718,7 +1736,7 @@ class Session(_SessionClassMethods, EventTarget):
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
_scalar_result: bool = ...,
- ) -> Result:
+ ) -> Result[Any]:
...
def _execute_internal(
@@ -1789,7 +1807,7 @@ class Session(_SessionClassMethods, EventTarget):
)
for idx, fn in enumerate(events_todo):
orm_exec_state._starting_event_idx = idx
- fn_result: Optional[Result] = fn(orm_exec_state)
+ fn_result: Optional[Result[Any]] = fn(orm_exec_state)
if fn_result:
if _scalar_result:
return fn_result.scalar()
@@ -1806,10 +1824,12 @@ class Session(_SessionClassMethods, EventTarget):
if _scalar_result and not compile_state_cls:
if TYPE_CHECKING:
params = cast(_CoreSingleExecuteParams, params)
- return conn.scalar(statement, params or {}, execution_options)
+ return conn.scalar(
+ statement, params or {}, execution_options=execution_options
+ )
- result: Result = conn.execute(
- statement, params or {}, execution_options
+ result: Result[Any] = conn.execute(
+ statement, params or {}, execution_options=execution_options
)
if compile_state_cls:
@@ -1827,6 +1847,32 @@ class Session(_SessionClassMethods, EventTarget):
else:
return result
+ @overload
+ def execute(
+ self,
+ statement: TypedReturnsRows[_T],
+ params: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ _parent_execute_state: Optional[Any] = None,
+ _add_event: Optional[Any] = None,
+ ) -> Result[_T]:
+ ...
+
+ @overload
+ def execute(
+ self,
+ statement: Executable,
+ params: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ _parent_execute_state: Optional[Any] = None,
+ _add_event: Optional[Any] = None,
+ ) -> Result[Any]:
+ ...
+
def execute(
self,
statement: Executable,
@@ -1836,7 +1882,7 @@ class Session(_SessionClassMethods, EventTarget):
bind_arguments: Optional[_BindArguments] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
- ) -> Result:
+ ) -> Result[Any]:
r"""Execute a SQL expression construct.
Returns a :class:`_engine.Result` object representing
@@ -1897,6 +1943,30 @@ class Session(_SessionClassMethods, EventTarget):
_add_event=_add_event,
)
+ @overload
+ def scalar(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> Optional[_T]:
+ ...
+
+ @overload
+ def scalar(
+ self,
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> Any:
+ ...
+
def scalar(
self,
statement: Executable,
@@ -1923,6 +1993,30 @@ class Session(_SessionClassMethods, EventTarget):
**kw,
)
+ @overload
+ def scalars(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> ScalarResult[_T]:
+ ...
+
+ @overload
+ def scalars(
+ self,
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> ScalarResult[Any]:
+ ...
+
def scalars(
self,
statement: Executable,
@@ -2284,8 +2378,103 @@ class Session(_SessionClassMethods, EventTarget):
f'{", ".join(context)} or this Session.'
)
+ @overload
+ def query(self, _entity: _EntityType[_O]) -> Query[_O]:
+ ...
+
+ @overload
+ def query(
+ self, _colexpr: TypedColumnsClauseRole[_T]
+ ) -> RowReturningQuery[Tuple[_T]]:
+ ...
+
+ # START OVERLOADED FUNCTIONS self.query RowReturningQuery 2-8
+
+ # code within this block is **programmatically,
+ # statically generated** by tools/generate_tuple_map_overloads.py
+
+ @overload
+ def query(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1]
+ ) -> RowReturningQuery[Tuple[_T0, _T1]]:
+ ...
+
+ @overload
+ def query(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2]
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]:
+ ...
+
+ @overload
+ def query(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]:
+ ...
+
+ @overload
+ def query(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]:
+ ...
+
+ @overload
+ def query(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]:
+ ...
+
+ @overload
+ def query(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]:
+ ...
+
+ @overload
+ def query(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ __ent7: _TCCA[_T7],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]:
+ ...
+
+ # END OVERLOADED FUNCTIONS self.query
+
+ @overload
+ def query(
+ self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any
+ ) -> Query[Any]:
+ ...
+
def query(
- self, *entities: _ColumnsClauseArgument, **kwargs: Any
+ self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any
) -> Query[Any]:
"""Return a new :class:`_query.Query` object corresponding to this
:class:`_orm.Session`.
@@ -2486,7 +2675,7 @@ class Session(_SessionClassMethods, EventTarget):
with_for_update = ForUpdateArg._from_argument(with_for_update)
- stmt = sql.select(object_mapper(instance))
+ stmt: Select[Any] = sql.select(object_mapper(instance))
if (
loading.load_on_ident(
self,
diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py
index 58f141997..ab32a3981 100644
--- a/lib/sqlalchemy/orm/state.py
+++ b/lib/sqlalchemy/orm/state.py
@@ -656,13 +656,13 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]):
@classmethod
def _instance_level_callable_processor(
cls, manager: ClassManager[_O], fn: _LoaderCallable, key: Any
- ) -> Callable[[InstanceState[_O], _InstanceDict, Row], None]:
+ ) -> Callable[[InstanceState[_O], _InstanceDict, Row[Any]], None]:
impl = manager[key].impl
if is_collection_impl(impl):
fixed_impl = impl
def _set_callable(
- state: InstanceState[_O], dict_: _InstanceDict, row: Row
+ state: InstanceState[_O], dict_: _InstanceDict, row: Row[Any]
) -> None:
if "callables" not in state.__dict__:
state.callables = {}
@@ -674,7 +674,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]):
else:
def _set_callable(
- state: InstanceState[_O], dict_: _InstanceDict, row: Row
+ state: InstanceState[_O], dict_: _InstanceDict, row: Row[Any]
) -> None:
if "callables" not in state.__dict__:
state.callables = {}
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 3934de535..8148793b1 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -28,6 +28,7 @@ from typing import Union
import weakref
from . import attributes # noqa
+from . import exc
from ._typing import _O
from ._typing import insp_is_aliased_class
from ._typing import insp_is_mapper
@@ -41,6 +42,7 @@ from .base import InspectionAttr as InspectionAttr
from .base import instance_str as instance_str
from .base import object_mapper as object_mapper
from .base import object_state as object_state
+from .base import opt_manager_of_class
from .base import state_attribute_str as state_attribute_str
from .base import state_class_str as state_class_str
from .base import state_str as state_str
@@ -68,6 +70,7 @@ from ..sql.base import ColumnCollection
from ..sql.cache_key import HasCacheKey
from ..sql.cache_key import MemoizedHasCacheKey
from ..sql.elements import ColumnElement
+from ..sql.elements import KeyedColumnElement
from ..sql.selectable import FromClause
from ..util.langhelpers import MemoizedSlots
from ..util.typing import de_stringify_annotation
@@ -95,9 +98,7 @@ if typing.TYPE_CHECKING:
from ..sql.selectable import _ColumnsClauseElement
from ..sql.selectable import Alias
from ..sql.selectable import Subquery
- from ..sql.visitors import _ET
from ..sql.visitors import anon_map
- from ..sql.visitors import ExternallyTraversible
_T = TypeVar("_T", bound=Any)
@@ -341,7 +342,7 @@ def identity_key(
ident: Union[Any, Tuple[Any, ...]] = None,
*,
instance: Optional[_T] = None,
- row: Optional[Union[Row, RowMapping]] = None,
+ row: Optional[Union[Row[Any], RowMapping]] = None,
identity_token: Optional[Any] = None,
) -> _IdentityKeyType[_T]:
r"""Generate "identity key" tuples, as are used as keys in the
@@ -468,7 +469,9 @@ class ORMAdapter(sql_util.ColumnAdapter):
return not entity or entity.isa(self.mapper)
-class AliasedClass(inspection.Inspectable["AliasedInsp[_O]"], Generic[_O]):
+class AliasedClass(
+ inspection.Inspectable["AliasedInsp[_O]"], ORMColumnsClauseRole[_O]
+):
r"""Represents an "aliased" form of a mapped class for usage with Query.
The ORM equivalent of a :func:`~sqlalchemy.sql.expression.alias`
@@ -663,7 +666,7 @@ class AliasedClass(inspection.Inspectable["AliasedInsp[_O]"], Generic[_O]):
@inspection._self_inspects
class AliasedInsp(
- ORMEntityColumnsClauseRole,
+ ORMEntityColumnsClauseRole[_O],
ORMFromClauseRole,
HasCacheKey,
InspectionAttr,
@@ -1276,12 +1279,29 @@ class LoaderCriteriaOption(CriteriaOption):
inspection._inspects(AliasedClass)(lambda target: target._aliased_insp)
+@inspection._inspects(type)
+def _inspect_mc(
+ class_: Type[_O],
+) -> Optional[Mapper[_O]]:
+
+ try:
+ class_manager = opt_manager_of_class(class_)
+ if class_manager is None or not class_manager.is_mapped:
+ return None
+ mapper = class_manager.mapper
+ except exc.NO_STATE:
+
+ return None
+ else:
+ return mapper
+
+
@inspection._self_inspects
class Bundle(
- ORMColumnsClauseRole,
+ ORMColumnsClauseRole[_T],
SupportsCloneAnnotations,
MemoizedHasCacheKey,
- inspection.Inspectable["Bundle"],
+ inspection.Inspectable["Bundle[_T]"],
InspectionAttr,
):
"""A grouping of SQL expressions that are returned by a :class:`.Query`
@@ -1373,10 +1393,10 @@ class Bundle(
@property
def entity_namespace(
self,
- ) -> ReadOnlyColumnCollection[str, ColumnElement[Any]]:
+ ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]:
return self.c
- columns: ReadOnlyColumnCollection[str, ColumnElement[Any]]
+ columns: ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]
"""A namespace of SQL expressions referred to by this :class:`.Bundle`.
@@ -1402,7 +1422,7 @@ class Bundle(
"""
- c: ReadOnlyColumnCollection[str, ColumnElement[Any]]
+ c: ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]
"""An alias for :attr:`.Bundle.columns`."""
def _clone(self):
@@ -1908,9 +1928,10 @@ def _extract_mapped_subtype(
raw_annotation: Union[type, str],
cls: type,
key: str,
- attr_cls: type,
+ attr_cls: Type[Any],
required: bool,
is_dataclass_field: bool,
+ superclasses: Optional[Tuple[Type[Any], ...]] = None,
) -> Optional[Union[type, str]]:
if raw_annotation is None:
@@ -1930,9 +1951,13 @@ def _extract_mapped_subtype(
if is_dataclass_field:
return annotated
else:
- if (
- not hasattr(annotated, "__origin__")
- or not issubclass(annotated.__origin__, attr_cls) # type: ignore
+ # TODO: there don't seem to be tests for the failure
+ # conditions here
+ if not hasattr(annotated, "__origin__") or (
+ not issubclass(
+ annotated.__origin__, # type: ignore
+ superclasses if superclasses else attr_cls,
+ )
and not issubclass(attr_cls, annotated.__origin__) # type: ignore
):
our_annotated_str = (
diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py
index 84913225d..c3ebb4596 100644
--- a/lib/sqlalchemy/sql/__init__.py
+++ b/lib/sqlalchemy/sql/__init__.py
@@ -121,7 +121,6 @@ def __go(lcls: Any) -> None:
coercions.lambdas = lambdas
coercions.schema = schema
coercions.selectable = selectable
- coercions.traversals = traversals
from .annotation import _prepare_annotations
from .annotation import Annotated
diff --git a/lib/sqlalchemy/sql/_selectable_constructors.py b/lib/sqlalchemy/sql/_selectable_constructors.py
index 37d44976a..f89e8f578 100644
--- a/lib/sqlalchemy/sql/_selectable_constructors.py
+++ b/lib/sqlalchemy/sql/_selectable_constructors.py
@@ -9,12 +9,16 @@ from __future__ import annotations
from typing import Any
from typing import Optional
+from typing import overload
+from typing import Tuple
from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
from . import coercions
from . import roles
from ._typing import _ColumnsClauseArgument
+from ._typing import _no_kw
from .elements import ColumnClause
from .selectable import Alias
from .selectable import CompoundSelect
@@ -34,6 +38,17 @@ if TYPE_CHECKING:
from ._typing import _FromClauseArgument
from ._typing import _OnClauseArgument
from ._typing import _SelectStatementForCompoundArgument
+ from ._typing import _T0
+ from ._typing import _T1
+ from ._typing import _T2
+ from ._typing import _T3
+ from ._typing import _T4
+ from ._typing import _T5
+ from ._typing import _T6
+ from ._typing import _T7
+ from ._typing import _T8
+ from ._typing import _T9
+ from ._typing import _TypedColumnClauseArgument as _TCCA
from .functions import Function
from .selectable import CTE
from .selectable import HasCTE
@@ -41,6 +56,9 @@ if TYPE_CHECKING:
from .selectable import SelectBase
+_T = TypeVar("_T", bound=Any)
+
+
def alias(
selectable: FromClause, name: Optional[str] = None, flat: bool = False
) -> NamedFromClause:
@@ -89,7 +107,9 @@ def cte(
)
-def except_(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect:
+def except_(
+ *selects: _SelectStatementForCompoundArgument,
+) -> CompoundSelect:
r"""Return an ``EXCEPT`` of multiple selectables.
The returned object is an instance of
@@ -119,7 +139,7 @@ def except_all(
def exists(
__argument: Optional[
- Union[_ColumnsClauseArgument, SelectBase, ScalarSelect[bool]]
+ Union[_ColumnsClauseArgument[Any], SelectBase, ScalarSelect[Any]]
] = None,
) -> Exists:
"""Construct a new :class:`_expression.Exists` construct.
@@ -162,7 +182,9 @@ def exists(
return Exists(__argument)
-def intersect(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect:
+def intersect(
+ *selects: _SelectStatementForCompoundArgument,
+) -> CompoundSelect:
r"""Return an ``INTERSECT`` of multiple selectables.
The returned object is an instance of
@@ -306,7 +328,129 @@ def outerjoin(
return Join(left, right, onclause, isouter=True, full=full)
-def select(*entities: _ColumnsClauseArgument) -> Select:
+# START OVERLOADED FUNCTIONS select Select 1-10
+
+# code within this block is **programmatically,
+# statically generated** by tools/generate_tuple_map_overloads.py
+
+
+@overload
+def select(__ent0: _TCCA[_T0]) -> Select[Tuple[_T0]]:
+ ...
+
+
+@overload
+def select(__ent0: _TCCA[_T0], __ent1: _TCCA[_T1]) -> Select[Tuple[_T0, _T1]]:
+ ...
+
+
+@overload
+def select(
+ __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2]
+) -> Select[Tuple[_T0, _T1, _T2]]:
+ ...
+
+
+@overload
+def select(
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+) -> Select[Tuple[_T0, _T1, _T2, _T3]]:
+ ...
+
+
+@overload
+def select(
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4]]:
+ ...
+
+
+@overload
+def select(
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]:
+ ...
+
+
+@overload
+def select(
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]:
+ ...
+
+
+@overload
+def select(
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ __ent7: _TCCA[_T7],
+) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]:
+ ...
+
+
+@overload
+def select(
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ __ent7: _TCCA[_T7],
+ __ent8: _TCCA[_T8],
+) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8]]:
+ ...
+
+
+@overload
+def select(
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ __ent7: _TCCA[_T7],
+ __ent8: _TCCA[_T8],
+ __ent9: _TCCA[_T9],
+) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8, _T9]]:
+ ...
+
+
+# END OVERLOADED FUNCTIONS select
+
+
+@overload
+def select(*entities: _ColumnsClauseArgument[Any], **__kw: Any) -> Select[Any]:
+ ...
+
+
+def select(*entities: _ColumnsClauseArgument[Any], **__kw: Any) -> Select[Any]:
r"""Construct a new :class:`_expression.Select`.
@@ -343,7 +487,11 @@ def select(*entities: _ColumnsClauseArgument) -> Select:
given, as well as ORM-mapped classes.
"""
-
+ # the keyword args are a necessary element in order for the typing
+ # to work out w/ the varargs vs. having named "keyword" arguments that
+ # aren't always present.
+ if __kw:
+ raise _no_kw()
return Select(*entities)
@@ -425,7 +573,9 @@ def tablesample(
return TableSample._factory(selectable, sampling, name=name, seed=seed)
-def union(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect:
+def union(
+ *selects: _SelectStatementForCompoundArgument,
+) -> CompoundSelect:
r"""Return a ``UNION`` of multiple selectables.
The returned object is an instance of
@@ -445,7 +595,9 @@ def union(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect:
return CompoundSelect._create_union(*selects)
-def union_all(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect:
+def union_all(
+ *selects: _SelectStatementForCompoundArgument,
+) -> CompoundSelect:
r"""Return a ``UNION ALL`` of multiple selectables.
The returned object is an instance of
diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py
index 53d29b628..1df530dbd 100644
--- a/lib/sqlalchemy/sql/_typing.py
+++ b/lib/sqlalchemy/sql/_typing.py
@@ -5,18 +5,27 @@ from typing import Any
from typing import Callable
from typing import Dict
from typing import Set
+from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from . import roles
+from .. import exc
from .. import util
from ..inspection import Inspectable
from ..util.typing import Literal
from ..util.typing import Protocol
if TYPE_CHECKING:
+ from datetime import date
+ from datetime import datetime
+ from datetime import time
+ from datetime import timedelta
+ from decimal import Decimal
+ from uuid import UUID
+
from .base import Executable
from .compiler import Compiled
from .compiler import DDLCompiler
@@ -26,17 +35,15 @@ if TYPE_CHECKING:
from .elements import ClauseElement
from .elements import ColumnClause
from .elements import ColumnElement
+ from .elements import KeyedColumnElement
from .elements import quoted_name
- from .elements import SQLCoreOperations
from .elements import TextClause
from .lambdas import LambdaElement
from .roles import ColumnsClauseRole
from .roles import FromClauseRole
from .schema import Column
- from .schema import DefaultGenerator
- from .schema import Sequence
- from .schema import Table
from .selectable import Alias
+ from .selectable import CTE
from .selectable import FromClause
from .selectable import Join
from .selectable import NamedFromClause
@@ -61,6 +68,30 @@ class _HasClauseElement(Protocol):
...
+# match column types that are not ORM entities
+_NOT_ENTITY = TypeVar(
+ "_NOT_ENTITY",
+ int,
+ str,
+ "datetime",
+ "date",
+ "time",
+ "timedelta",
+ "UUID",
+ float,
+ "Decimal",
+)
+
+_MAYBE_ENTITY = TypeVar(
+ "_MAYBE_ENTITY",
+ roles.ColumnsClauseRole,
+ Literal["*", 1],
+ Type[Any],
+ Inspectable[_HasClauseElement],
+ _HasClauseElement,
+)
+
+
# convention:
# XYZArgument - something that the end user is passing to a public API method
# XYZElement - the internal representation that we use for the thing.
@@ -76,9 +107,10 @@ _TextCoercedExpressionArgument = Union[
]
_ColumnsClauseArgument = Union[
- Literal["*", 1],
+ roles.TypedColumnsClauseRole[_T],
roles.ColumnsClauseRole,
- Type[Any],
+ Literal["*", 1],
+ Type[_T],
Inspectable[_HasClauseElement],
_HasClauseElement,
]
@@ -92,6 +124,24 @@ sets; select(...), insert().returning(...), etc.
"""
+_TypedColumnClauseArgument = Union[
+ roles.TypedColumnsClauseRole[_T], roles.ExpressionElementRole[_T], Type[_T]
+]
+
+_TP = TypeVar("_TP", bound=Tuple[Any, ...])
+
+_T0 = TypeVar("_T0", bound=Any)
+_T1 = TypeVar("_T1", bound=Any)
+_T2 = TypeVar("_T2", bound=Any)
+_T3 = TypeVar("_T3", bound=Any)
+_T4 = TypeVar("_T4", bound=Any)
+_T5 = TypeVar("_T5", bound=Any)
+_T6 = TypeVar("_T6", bound=Any)
+_T7 = TypeVar("_T7", bound=Any)
+_T8 = TypeVar("_T8", bound=Any)
+_T9 = TypeVar("_T9", bound=Any)
+
+
_ColumnExpressionArgument = Union[
"ColumnElement[_T]",
_HasClauseElement,
@@ -169,6 +219,7 @@ _DMLTableArgument = Union[
"TableClause",
"Join",
"Alias",
+ "CTE",
Type[Any],
Inspectable[_HasClauseElement],
_HasClauseElement,
@@ -194,6 +245,11 @@ if TYPE_CHECKING:
def is_column_element(c: ClauseElement) -> TypeGuard[ColumnElement[Any]]:
...
+ def is_keyed_column_element(
+ c: ClauseElement,
+ ) -> TypeGuard[KeyedColumnElement[Any]]:
+ ...
+
def is_text_clause(c: ClauseElement) -> TypeGuard[TextClause]:
...
@@ -216,7 +272,7 @@ if TYPE_CHECKING:
def is_select_statement(
t: Union[Executable, ReturnsRows]
- ) -> TypeGuard[Select]:
+ ) -> TypeGuard[Select[Any]]:
...
def is_table(t: FromClause) -> TypeGuard[TableClause]:
@@ -234,6 +290,7 @@ else:
is_ddl_compiler = operator.attrgetter("is_ddl")
is_named_from_clause = operator.attrgetter("named_with_column")
is_column_element = operator.attrgetter("_is_column_element")
+ is_keyed_column_element = operator.attrgetter("_is_keyed_column_element")
is_text_clause = operator.attrgetter("_is_text_clause")
is_from_clause = operator.attrgetter("_is_from_clause")
is_tuple_type = operator.attrgetter("_is_tuple_type")
@@ -260,3 +317,10 @@ def is_has_clause_element(s: object) -> TypeGuard[_HasClauseElement]:
def is_insert_update(c: ClauseElement) -> TypeGuard[ValuesBase]:
return c.is_dml and (c.is_insert or c.is_update) # type: ignore
+
+
+def _no_kw() -> exc.ArgumentError:
+ return exc.ArgumentError(
+ "Additional keyword arguments are not accepted by this "
+ "function/method. The presence of **kw is for pep-484 typing purposes"
+ )
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index f81878d55..790edefc6 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -62,10 +62,10 @@ if TYPE_CHECKING:
from . import coercions
from . import elements
from . import type_api
- from ._typing import _ColumnsClauseArgument
from .elements import BindParameter
- from .elements import ColumnClause
+ from .elements import ColumnClause # noqa
from .elements import ColumnElement
+ from .elements import KeyedColumnElement
from .elements import NamedColumn
from .elements import SQLCoreOperations
from .elements import TextClause
@@ -74,7 +74,6 @@ if TYPE_CHECKING:
from .selectable import FromClause
from ..engine import Connection
from ..engine import CursorResult
- from ..engine import Result
from ..engine.base import _CompiledCacheType
from ..engine.interfaces import _CoreMultiExecuteParams
from ..engine.interfaces import _ExecuteOptions
@@ -704,8 +703,11 @@ class InPlaceGenerative(HasMemoized):
"""Provide a method-chaining pattern in conjunction with the
@_generative decorator that mutates in place."""
+ __slots__ = ()
+
def _generate(self):
skip = self._memoized_keys
+ # note __dict__ needs to be in __slots__ if this is used
for k in skip:
self.__dict__.pop(k, None)
return self
@@ -937,7 +939,7 @@ class ExecutableOption(HasCopyInternals):
SelfExecutable = TypeVar("SelfExecutable", bound="Executable")
-class Executable(roles.StatementRole, Generative):
+class Executable(roles.StatementRole):
"""Mark a :class:`_expression.ClauseElement` as supporting execution.
:class:`.Executable` is a superclass for all "statement" types
@@ -994,7 +996,7 @@ class Executable(roles.StatementRole, Generative):
connection: Connection,
distilled_params: _CoreMultiExecuteParams,
execution_options: _ExecuteOptionsParameter,
- ) -> CursorResult:
+ ) -> CursorResult[Any]:
...
def _execute_on_scalar(
@@ -1253,7 +1255,7 @@ class SchemaVisitor(ClauseVisitor):
_COLKEY = TypeVar("_COLKEY", Union[None, str], str)
_COL_co = TypeVar("_COL_co", bound="ColumnElement[Any]", covariant=True)
-_COL = TypeVar("_COL", bound="ColumnElement[Any]")
+_COL = TypeVar("_COL", bound="KeyedColumnElement[Any]")
class ColumnCollection(Generic[_COLKEY, _COL_co]):
@@ -1505,6 +1507,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
) -> None:
"""populate from an iterator of (key, column)"""
cols = list(iter_)
+
self._collection[:] = cols
self._colset.update(c for k, c in self._collection)
self._index.update(
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py
index 0659709ab..9b7231360 100644
--- a/lib/sqlalchemy/sql/coercions.py
+++ b/lib/sqlalchemy/sql/coercions.py
@@ -29,6 +29,7 @@ from typing import Union
from . import operators
from . import roles
from . import visitors
+from ._typing import is_from_clause
from .base import ExecutableOption
from .base import Options
from .cache_key import HasCacheKey
@@ -38,25 +39,18 @@ from .. import inspection
from .. import util
from ..util.typing import Literal
-if not typing.TYPE_CHECKING:
- elements = None
- lambdas = None
- schema = None
- selectable = None
- traversals = None
-
if typing.TYPE_CHECKING:
from . import elements
from . import lambdas
from . import schema
from . import selectable
- from . import traversals
from ._typing import _ColumnExpressionArgument
from ._typing import _ColumnsClauseArgument
from ._typing import _DDLColumnArgument
from ._typing import _DMLTableArgument
from ._typing import _FromClauseArgument
from .dml import _DMLTableElement
+ from .elements import BindParameter
from .elements import ClauseElement
from .elements import ColumnClause
from .elements import ColumnElement
@@ -64,9 +58,7 @@ if typing.TYPE_CHECKING:
from .elements import SQLCoreOperations
from .schema import Column
from .selectable import _ColumnsClauseElement
- from .selectable import _JoinTargetElement
from .selectable import _JoinTargetProtocol
- from .selectable import _OnClauseElement
from .selectable import FromClause
from .selectable import HasCTE
from .selectable import SelectBase
@@ -170,6 +162,15 @@ def expect(
@overload
def expect(
+ role: Type[roles.LiteralValueRole],
+ element: Any,
+ **kw: Any,
+) -> BindParameter[Any]:
+ ...
+
+
+@overload
+def expect(
role: Type[roles.DDLReferredColumnRole],
element: Any,
**kw: Any,
@@ -272,7 +273,7 @@ def expect(
@overload
def expect(
role: Type[roles.ColumnsClauseRole],
- element: _ColumnsClauseArgument,
+ element: _ColumnsClauseArgument[Any],
**kw: Any,
) -> _ColumnsClauseElement:
...
@@ -933,7 +934,7 @@ class GroupByImpl(ByOfImpl, RoleImpl):
argname: Optional[str] = None,
**kw: Any,
) -> Any:
- if isinstance(resolved, roles.StrictFromClauseRole):
+ if is_from_clause(resolved):
return elements.ClauseList(*resolved.c)
else:
return resolved
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index c524a2602..a1b25b8a6 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -80,7 +80,6 @@ from ..util.typing import Protocol
from ..util.typing import TypedDict
if typing.TYPE_CHECKING:
- from . import roles
from .annotation import _AnnotationDict
from .base import _AmbiguousTableNameMap
from .base import CompileState
@@ -95,7 +94,6 @@ if typing.TYPE_CHECKING:
from .elements import ColumnElement
from .elements import Label
from .functions import Function
- from .selectable import Alias
from .selectable import AliasedReturnsRows
from .selectable import CompoundSelectState
from .selectable import CTE
@@ -386,7 +384,7 @@ class _CompilerStackEntry(_BaseCompilerStackEntry, total=False):
need_result_map_for_nested: bool
need_result_map_for_compound: bool
select_0: ReturnsRows
- insert_from_select: Select
+ insert_from_select: Select[Any]
class ExpandedState(NamedTuple):
@@ -2834,15 +2832,31 @@ class SQLCompiler(Compiled):
"unique bind parameter of the same name" % name
)
elif existing._is_crud or bindparam._is_crud:
- raise exc.CompileError(
- "bindparam() name '%s' is reserved "
- "for automatic usage in the VALUES or SET "
- "clause of this "
- "insert/update statement. Please use a "
- "name other than column name when using bindparam() "
- "with insert() or update() (for example, 'b_%s')."
- % (bindparam.key, bindparam.key)
- )
+ if existing._is_crud and bindparam._is_crud:
+ # TODO: this condition is not well understood.
+ # see tests in test/sql/test_update.py
+ raise exc.CompileError(
+ "Encountered unsupported case when compiling an "
+ "INSERT or UPDATE statement. If this is a "
+ "multi-table "
+ "UPDATE statement, please provide string-named "
+ "arguments to the "
+ "values() method with distinct names; support for "
+ "multi-table UPDATE statements that "
+ "target multiple tables for UPDATE is very "
+ "limited",
+ )
+ else:
+ raise exc.CompileError(
+ f"bindparam() name '{bindparam.key}' is reserved "
+ "for automatic usage in the VALUES or SET "
+ "clause of this "
+ "insert/update statement. Please use a "
+ "name other than column name when using "
+ "bindparam() "
+ "with insert() or update() (for example, "
+ f"'b_{bindparam.key}')."
+ )
self.binds[bindparam.key] = self.binds[name] = bindparam
@@ -3881,7 +3895,7 @@ class SQLCompiler(Compiled):
return text
def _setup_select_hints(
- self, select: Select
+ self, select: Select[Any]
) -> Tuple[str, _FromHintsType]:
byfrom = dict(
[
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index e4408cd31..29d7b45d7 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -22,6 +22,7 @@ from typing import MutableMapping
from typing import NamedTuple
from typing import Optional
from typing import overload
+from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
@@ -30,8 +31,10 @@ from . import coercions
from . import dml
from . import elements
from . import roles
+from .elements import ColumnClause
from .schema import default_is_clause_element
from .schema import default_is_sequence
+from .selectable import TableClause
from .. import exc
from .. import util
from ..util.typing import Literal
@@ -41,16 +44,9 @@ if TYPE_CHECKING:
from .compiler import SQLCompiler
from .dml import _DMLColumnElement
from .dml import DMLState
- from .dml import Insert
- from .dml import Update
- from .dml import UpdateDMLState
from .dml import ValuesBase
- from .elements import ClauseElement
- from .elements import ColumnClause
from .elements import ColumnElement
- from .elements import TextClause
from .schema import _SQLExprDefault
- from .schema import Column
from .selectable import TableClause
REQUIRED = util.symbol(
@@ -68,12 +64,20 @@ values present.
)
+def _as_dml_column(c: ColumnElement[Any]) -> ColumnClause[Any]:
+ if not isinstance(c, ColumnClause):
+ raise exc.CompileError(
+ f"Can't create DML statement against column expression {c!r}"
+ )
+ return c
+
+
class _CrudParams(NamedTuple):
- single_params: List[
- Tuple[ColumnClause[Any], str, Optional[Union[str, _SQLExprDefault]]]
+ single_params: Sequence[
+ Tuple[ColumnElement[Any], str, Optional[Union[str, _SQLExprDefault]]]
]
all_multi_params: List[
- List[
+ Sequence[
Tuple[
ColumnClause[Any],
str,
@@ -274,7 +278,7 @@ def _get_crud_params(
compiler,
stmt,
compile_state,
- cast("List[Tuple[ColumnClause[Any], str, str]]", values),
+ cast("Sequence[Tuple[ColumnClause[Any], str, str]]", values),
cast("Callable[..., str]", _column_as_key),
kw,
)
@@ -290,7 +294,7 @@ def _get_crud_params(
# insert_executemany_returning mode :)
values = [
(
- stmt.table.columns[0],
+ _as_dml_column(stmt.table.columns[0]),
compiler.preparer.format_column(stmt.table.columns[0]),
"DEFAULT",
)
@@ -1135,10 +1139,10 @@ def _extend_values_for_multiparams(
compiler: SQLCompiler,
stmt: ValuesBase,
compile_state: DMLState,
- initial_values: List[Tuple[ColumnClause[Any], str, str]],
+ initial_values: Sequence[Tuple[ColumnClause[Any], str, str]],
_column_as_key: Callable[..., str],
kw: Dict[str, Any],
-) -> List[List[Tuple[ColumnClause[Any], str, str]]]:
+) -> List[Sequence[Tuple[ColumnClause[Any], str, str]]]:
values_0 = initial_values
values = [initial_values]
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index 8307f6400..e0f162fc8 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -22,15 +22,19 @@ from typing import List
from typing import MutableMapping
from typing import NoReturn
from typing import Optional
+from typing import overload
from typing import Sequence
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
from . import coercions
from . import roles
from . import util as sql_util
+from ._typing import _no_kw
+from ._typing import _TP
from ._typing import is_column_element
from ._typing import is_named_from_clause
from .base import _entity_namespace_key
@@ -42,6 +46,7 @@ from .base import ColumnCollection
from .base import CompileState
from .base import DialectKWArgs
from .base import Executable
+from .base import Generative
from .base import HasCompileState
from .elements import BooleanClauseList
from .elements import ClauseElement
@@ -49,12 +54,13 @@ from .elements import ColumnClause
from .elements import ColumnElement
from .elements import Null
from .selectable import Alias
+from .selectable import ExecutableReturnsRows
from .selectable import FromClause
from .selectable import HasCTE
from .selectable import HasPrefixes
from .selectable import Join
-from .selectable import ReturnsRows
from .selectable import TableClause
+from .selectable import TypedReturnsRows
from .sqltypes import NullType
from .visitors import InternalTraversal
from .. import exc
@@ -66,9 +72,19 @@ if TYPE_CHECKING:
from ._typing import _ColumnsClauseArgument
from ._typing import _DMLColumnArgument
from ._typing import _DMLTableArgument
- from ._typing import _FromClauseArgument
+ from ._typing import _T0 # noqa
+ from ._typing import _T1 # noqa
+ from ._typing import _T2 # noqa
+ from ._typing import _T3 # noqa
+ from ._typing import _T4 # noqa
+ from ._typing import _T5 # noqa
+ from ._typing import _T6 # noqa
+ from ._typing import _T7 # noqa
+ from ._typing import _TypedColumnClauseArgument as _TCCA # noqa
from .base import ReadOnlyColumnCollection
from .compiler import SQLCompiler
+ from .elements import ColumnElement
+ from .elements import KeyedColumnElement
from .selectable import _ColumnsClauseElement
from .selectable import _SelectIterable
from .selectable import Select
@@ -88,6 +104,8 @@ else:
isinsert = operator.attrgetter("isinsert")
+_T = TypeVar("_T", bound=Any)
+
_DMLColumnElement = Union[str, ColumnClause[Any]]
_DMLTableElement = Union[TableClause, Alias, Join]
@@ -185,6 +203,11 @@ class DMLState(CompileState):
"%s construct does not support "
"multiple parameter sets." % statement.__visit_name__.upper()
)
+ else:
+ assert isinstance(statement, Insert)
+
+ # which implies...
+ # assert isinstance(statement.table, TableClause)
for parameters in statement._multi_values:
multi_parameters: List[MutableMapping[_DMLColumnElement, Any]] = [
@@ -291,7 +314,9 @@ class UpdateDMLState(DMLState):
elif statement._multi_values:
self._process_multi_values(statement)
self._extra_froms = ef = self._make_extra_froms(statement)
- self.is_multitable = mt = ef and self._dict_parameters
+
+ self.is_multitable = mt = ef
+
self.include_table_with_column_exprs = bool(
mt and compiler.render_table_with_column_in_update_from
)
@@ -317,8 +342,8 @@ class UpdateBase(
HasCompileState,
DialectKWArgs,
HasPrefixes,
- ReturnsRows,
- Executable,
+ Generative,
+ ExecutableReturnsRows,
ClauseElement,
):
"""Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements."""
@@ -383,8 +408,8 @@ class UpdateBase(
@_generative
def returning(
- self: SelfUpdateBase, *cols: _ColumnsClauseArgument
- ) -> SelfUpdateBase:
+ self, *cols: _ColumnsClauseArgument[Any], **__kw: Any
+ ) -> UpdateBase:
r"""Add a :term:`RETURNING` or equivalent clause to this statement.
e.g.:
@@ -454,6 +479,8 @@ class UpdateBase(
:ref:`tutorial_insert_returning` - in the :ref:`unified_tutorial`
""" # noqa: E501
+ if __kw:
+ raise _no_kw()
if self._return_defaults:
raise exc.InvalidRequestError(
"return_defaults() is already configured on this statement"
@@ -464,7 +491,7 @@ class UpdateBase(
return self
def corresponding_column(
- self, column: ColumnElement[Any], require_embedded: bool = False
+ self, column: KeyedColumnElement[Any], require_embedded: bool = False
) -> Optional[ColumnElement[Any]]:
return self.exported_columns.corresponding_column(
column, require_embedded=require_embedded
@@ -628,7 +655,7 @@ class ValuesBase(UpdateBase):
_supports_multi_parameters = False
- select: Optional[Select] = None
+ select: Optional[Select[Any]] = None
"""SELECT statement for INSERT .. FROM SELECT"""
_post_values_clause: Optional[ClauseElement] = None
@@ -804,11 +831,15 @@ class ValuesBase(UpdateBase):
)
elif isinstance(arg, collections_abc.Sequence):
-
if arg and isinstance(arg[0], (list, dict, tuple)):
self._multi_values += (arg,)
return self
+ if TYPE_CHECKING:
+ # crud.py raises during compilation if this is not the
+ # case
+ assert isinstance(self, Insert)
+
# tuple values
arg = {c.key: value for c, value in zip(self.table.c, arg)}
@@ -1010,7 +1041,7 @@ class Insert(ValuesBase):
def from_select(
self: SelfInsert,
names: List[str],
- select: Select,
+ select: Select[Any],
include_defaults: bool = True,
) -> SelfInsert:
"""Return a new :class:`_expression.Insert` construct which represents
@@ -1073,6 +1104,114 @@ class Insert(ValuesBase):
self.select = coercions.expect(roles.DMLSelectRole, select)
return self
+ if TYPE_CHECKING:
+
+ # START OVERLOADED FUNCTIONS self.returning ReturningInsert 1-8
+
+ # code within this block is **programmatically,
+ # statically generated** by tools/generate_tuple_map_overloads.py
+
+ @overload
+ def returning(self, __ent0: _TCCA[_T0]) -> ReturningInsert[Tuple[_T0]]:
+ ...
+
+ @overload
+ def returning(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1]
+ ) -> ReturningInsert[Tuple[_T0, _T1]]:
+ ...
+
+ @overload
+ def returning(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2]
+ ) -> ReturningInsert[Tuple[_T0, _T1, _T2]]:
+ ...
+
+ @overload
+ def returning(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3]]:
+ ...
+
+ @overload
+ def returning(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4]]:
+ ...
+
+ @overload
+ def returning(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]:
+ ...
+
+ @overload
+ def returning(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]:
+ ...
+
+ @overload
+ def returning(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ __ent7: _TCCA[_T7],
+ ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]:
+ ...
+
+ # END OVERLOADED FUNCTIONS self.returning
+
+ @overload
+ def returning(
+ self, *cols: _ColumnsClauseArgument[Any], **__kw: Any
+ ) -> ReturningInsert[Any]:
+ ...
+
+ def returning(
+ self, *cols: _ColumnsClauseArgument[Any], **__kw: Any
+ ) -> ReturningInsert[Any]:
+ ...
+
+
+class ReturningInsert(Insert, TypedReturnsRows[_TP]):
+ """Typing-only class that establishes a generic type form of
+ :class:`.Insert` which tracks returned column types.
+
+ This datatype is delivered when calling the
+ :meth:`.Insert.returning` method.
+
+ .. versionadded:: 2.0
+
+ """
+
SelfDMLWhereBase = typing.TypeVar("SelfDMLWhereBase", bound="DMLWhereBase")
@@ -1264,6 +1403,113 @@ class Update(DMLWhereBase, ValuesBase):
self._inline = True
return self
+ if TYPE_CHECKING:
+ # START OVERLOADED FUNCTIONS self.returning ReturningUpdate 1-8
+
+ # code within this block is **programmatically,
+ # statically generated** by tools/generate_tuple_map_overloads.py
+
+ @overload
+ def returning(self, __ent0: _TCCA[_T0]) -> ReturningUpdate[Tuple[_T0]]:
+ ...
+
+ @overload
+ def returning(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1]
+ ) -> ReturningUpdate[Tuple[_T0, _T1]]:
+ ...
+
+ @overload
+ def returning(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2]
+ ) -> ReturningUpdate[Tuple[_T0, _T1, _T2]]:
+ ...
+
+ @overload
+ def returning(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3]]:
+ ...
+
+ @overload
+ def returning(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4]]:
+ ...
+
+ @overload
+ def returning(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]:
+ ...
+
+ @overload
+ def returning(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]:
+ ...
+
+ @overload
+ def returning(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ __ent7: _TCCA[_T7],
+ ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]:
+ ...
+
+ # END OVERLOADED FUNCTIONS self.returning
+
+ @overload
+ def returning(
+ self, *cols: _ColumnsClauseArgument[Any], **__kw: Any
+ ) -> ReturningUpdate[Any]:
+ ...
+
+ def returning(
+ self, *cols: _ColumnsClauseArgument[Any], **__kw: Any
+ ) -> ReturningUpdate[Any]:
+ ...
+
+
+class ReturningUpdate(Update, TypedReturnsRows[_TP]):
+ """Typing-only class that establishes a generic type form of
+ :class:`.Update` which tracks returned column types.
+
+ This datatype is delivered when calling the
+ :meth:`.Update.returning` method.
+
+ .. versionadded:: 2.0
+
+ """
+
SelfDelete = typing.TypeVar("SelfDelete", bound="Delete")
@@ -1297,3 +1543,111 @@ class Delete(DMLWhereBase, UpdateBase):
self.table = coercions.expect(
roles.DMLTableRole, table, apply_propagate_attrs=self
)
+
+ if TYPE_CHECKING:
+
+ # START OVERLOADED FUNCTIONS self.returning ReturningDelete 1-8
+
+ # code within this block is **programmatically,
+ # statically generated** by tools/generate_tuple_map_overloads.py
+
+ @overload
+ def returning(self, __ent0: _TCCA[_T0]) -> ReturningDelete[Tuple[_T0]]:
+ ...
+
+ @overload
+ def returning(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1]
+ ) -> ReturningDelete[Tuple[_T0, _T1]]:
+ ...
+
+ @overload
+ def returning(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2]
+ ) -> ReturningDelete[Tuple[_T0, _T1, _T2]]:
+ ...
+
+ @overload
+ def returning(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3]]:
+ ...
+
+ @overload
+ def returning(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4]]:
+ ...
+
+ @overload
+ def returning(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]:
+ ...
+
+ @overload
+ def returning(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]:
+ ...
+
+ @overload
+ def returning(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ __ent7: _TCCA[_T7],
+ ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]:
+ ...
+
+ # END OVERLOADED FUNCTIONS self.returning
+
+ @overload
+ def returning(
+ self, *cols: _ColumnsClauseArgument[Any], **__kw: Any
+ ) -> ReturningDelete[Any]:
+ ...
+
+ def returning(
+ self, *cols: _ColumnsClauseArgument[Any], **__kw: Any
+ ) -> ReturningDelete[Any]:
+ ...
+
+
+class ReturningDelete(Update, TypedReturnsRows[_TP]):
+ """Typing-only class that establishes a generic type form of
+ :class:`.Delete` which tracks returned column types.
+
+ This datatype is delivered when calling the
+ :meth:`.Delete.returning` method.
+
+ .. versionadded:: 2.0
+
+ """
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 34d5127ab..a29561291 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -54,6 +54,7 @@ from .base import _clone
from .base import _generative
from .base import _NoArg
from .base import Executable
+from .base import Generative
from .base import HasMemoized
from .base import Immutable
from .base import NO_ARG
@@ -94,10 +95,7 @@ if typing.TYPE_CHECKING:
from .selectable import _SelectIterable
from .selectable import FromClause
from .selectable import NamedFromClause
- from .selectable import ReturnsRows
from .selectable import Select
- from .selectable import TableClause
- from .sqltypes import Boolean
from .sqltypes import TupleType
from .type_api import TypeEngine
from .visitors import _CloneCallableType
@@ -122,7 +120,9 @@ _NT = TypeVar("_NT", bound="_NUMERIC")
_NMT = TypeVar("_NMT", bound="_NUMBER")
-def literal(value, type_=None):
+def literal(
+ value: Any, type_: Optional[_TypeEngineArgument[_T]] = None
+) -> BindParameter[_T]:
r"""Return a literal clause, bound to a bind parameter.
Literal clauses are created automatically when non-
@@ -144,7 +144,9 @@ def literal(value, type_=None):
return coercions.expect(roles.LiteralValueRole, value, type_=type_)
-def literal_column(text, type_=None):
+def literal_column(
+ text: str, type_: Optional[_TypeEngineArgument[_T]] = None
+) -> ColumnClause[_T]:
r"""Produce a :class:`.ColumnClause` object that has the
:paramref:`_expression.column.is_literal` flag set to True.
@@ -316,6 +318,7 @@ class ClauseElement(
is_selectable = False
is_dml = False
_is_column_element = False
+ _is_keyed_column_element = False
_is_table = False
_is_textual = False
_is_from_clause = False
@@ -342,7 +345,7 @@ class ClauseElement(
if typing.TYPE_CHECKING:
def get_children(
- self, omit_attrs: typing_Tuple[str, ...] = ..., **kw: Any
+ self, *, omit_attrs: typing_Tuple[str, ...] = ..., **kw: Any
) -> Iterable[ClauseElement]:
...
@@ -455,7 +458,7 @@ class ClauseElement(
connection: Connection,
distilled_params: _CoreMultiExecuteParams,
execution_options: _ExecuteOptions,
- ) -> Result:
+ ) -> Result[Any]:
if self.supports_execution:
if TYPE_CHECKING:
assert isinstance(self, Executable)
@@ -833,13 +836,13 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
def in_(
self,
- other: Union[Sequence[Any], BindParameter[Any], Select],
+ other: Union[Sequence[Any], BindParameter[Any], Select[Any]],
) -> BinaryExpression[bool]:
...
def not_in(
self,
- other: Union[Sequence[Any], BindParameter[Any], Select],
+ other: Union[Sequence[Any], BindParameter[Any], Select[Any]],
) -> BinaryExpression[bool]:
...
@@ -1699,6 +1702,14 @@ class ColumnElement(
return self._anon_label(label, add_hash=idx)
+class KeyedColumnElement(ColumnElement[_T]):
+ """ColumnElement where ``.key`` is non-None."""
+
+ _is_keyed_column_element = True
+
+ key: str
+
+
class WrapsColumnExpression(ColumnElement[_T]):
"""Mixin that defines a :class:`_expression.ColumnElement`
as a wrapper with special
@@ -1760,7 +1771,7 @@ class WrapsColumnExpression(ColumnElement[_T]):
SelfBindParameter = TypeVar("SelfBindParameter", bound="BindParameter[Any]")
-class BindParameter(roles.InElementRole, ColumnElement[_T]):
+class BindParameter(roles.InElementRole, KeyedColumnElement[_T]):
r"""Represent a "bound expression".
:class:`.BindParameter` is invoked explicitly using the
@@ -2073,6 +2084,7 @@ class TextClause(
roles.FromClauseRole,
roles.SelectStatementRole,
roles.InElementRole,
+ Generative,
Executable,
DQLDMLClauseElement,
roles.BinaryElementRole[Any],
@@ -4160,7 +4172,7 @@ class FunctionFilter(ColumnElement[_T]):
)
-class NamedColumn(ColumnElement[_T]):
+class NamedColumn(KeyedColumnElement[_T]):
is_literal = False
table: Optional[FromClause] = None
name: str
@@ -4502,7 +4514,7 @@ class ColumnClause(
self.is_literal = is_literal
- def get_children(self, column_tables=False, **kw):
+ def get_children(self, *, column_tables=False, **kw):
# override base get_children() to not return the Table
# or selectable that is parent to this column. Traversals
# expect the columns of tables and subqueries to be leaf nodes.
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index 648168235..b827df3df 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -175,7 +175,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
connection: Connection,
distilled_params: _CoreMultiExecuteParams,
execution_options: _ExecuteOptionsParameter,
- ) -> CursorResult:
+ ) -> CursorResult[Any]:
return connection._execute_function(
self, distilled_params, execution_options
)
@@ -623,7 +623,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
joins_implicitly=joins_implicitly,
)
- def select(self) -> "Select":
+ def select(self) -> Select[Any]:
"""Produce a :func:`_expression.select` construct
against this :class:`.FunctionElement`.
@@ -632,7 +632,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
s = select(function_element)
"""
- s = Select(self)
+ s: Select[Any] = Select(self)
if self._execution_options:
s = s.execution_options(**self._execution_options)
return s
@@ -846,7 +846,7 @@ class _FunctionGenerator:
@overload
def __call__(
- self, *c: Any, type_: TypeEngine[_T], **kwargs: Any
+ self, *c: Any, type_: _TypeEngineArgument[_T], **kwargs: Any
) -> Function[_T]:
...
diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py
index 231c70a5b..09d4b35ad 100644
--- a/lib/sqlalchemy/sql/roles.py
+++ b/lib/sqlalchemy/sql/roles.py
@@ -8,8 +8,6 @@ from __future__ import annotations
from typing import Any
from typing import Generic
-from typing import Iterable
-from typing import List
from typing import Optional
from typing import TYPE_CHECKING
from typing import TypeVar
@@ -19,12 +17,7 @@ from ..util.typing import Literal
if TYPE_CHECKING:
from ._typing import _PropagateAttrsType
- from .base import _EntityNamespace
- from .base import ColumnCollection
- from .base import ReadOnlyColumnCollection
- from .elements import ColumnClause
from .elements import Label
- from .elements import NamedColumn
from .selectable import _SelectIterable
from .selectable import FromClause
from .selectable import Subquery
@@ -108,13 +101,21 @@ class TruncatedLabelRole(StringRole, SQLRole):
class ColumnsClauseRole(AllowsLambdaRole, UsesInspection, ColumnListRole):
__slots__ = ()
- _role_name = "Column expression or FROM clause"
+ _role_name = (
+ "Column expression, FROM clause, or other columns clause element"
+ )
@property
def _select_iterable(self) -> _SelectIterable:
raise NotImplementedError()
+class TypedColumnsClauseRole(Generic[_T], SQLRole):
+ """element-typed form of ColumnsClauseRole"""
+
+ __slots__ = ()
+
+
class LimitOffsetRole(SQLRole):
__slots__ = ()
_role_name = "LIMIT / OFFSET expression"
@@ -161,7 +162,7 @@ class WhereHavingRole(OnClauseRole):
_role_name = "SQL expression for WHERE/HAVING role"
-class ExpressionElementRole(Generic[_T], SQLRole):
+class ExpressionElementRole(TypedColumnsClauseRole[_T]):
# note when using generics for ExpressionElementRole,
# the generic type needs to be in
# sqlalchemy.sql.coercions._impl_lookup mapping also.
@@ -212,39 +213,11 @@ class FromClauseRole(ColumnsClauseRole, JoinTargetRole):
named_with_column: bool
- if TYPE_CHECKING:
-
- @util.ro_non_memoized_property
- def c(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]:
- ...
-
- @util.ro_non_memoized_property
- def columns(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]:
- ...
-
- @util.ro_non_memoized_property
- def entity_namespace(self) -> _EntityNamespace:
- ...
-
- @util.ro_non_memoized_property
- def _hide_froms(self) -> Iterable[FromClause]:
- ...
-
- @util.ro_non_memoized_property
- def _from_objects(self) -> List[FromClause]:
- ...
-
class StrictFromClauseRole(FromClauseRole):
__slots__ = ()
# does not allow text() or select() objects
- if TYPE_CHECKING:
-
- @util.ro_non_memoized_property
- def description(self) -> str:
- ...
-
class AnonymizedFromClauseRole(StrictFromClauseRole):
__slots__ = ()
@@ -317,16 +290,6 @@ class DMLTableRole(FromClauseRole):
__slots__ = ()
_role_name = "subject table for an INSERT, UPDATE or DELETE"
- if TYPE_CHECKING:
-
- @util.ro_non_memoized_property
- def primary_key(self) -> Iterable[NamedColumn[Any]]:
- ...
-
- @util.ro_non_memoized_property
- def columns(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]:
- ...
-
class DMLColumnRole(SQLRole):
__slots__ = ()
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index 52ba60a62..27456d2be 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -86,7 +86,6 @@ if typing.TYPE_CHECKING:
from ._typing import _InfoType
from ._typing import _TextCoercedExpressionArgument
from ._typing import _TypeEngineArgument
- from .base import ColumnCollection
from .base import DedupeColumnCollection
from .base import ReadOnlyColumnCollection
from .compiler import DDLCompiler
@@ -97,9 +96,7 @@ if typing.TYPE_CHECKING:
from .visitors import anon_map
from ..engine import Connection
from ..engine import Engine
- from ..engine.cursor import CursorResult
from ..engine.interfaces import _CoreMultiExecuteParams
- from ..engine.interfaces import _CoreSingleExecuteParams
from ..engine.interfaces import _ExecuteOptionsParameter
from ..engine.interfaces import ExecutionContext
from ..engine.mock import MockConnection
@@ -2609,8 +2606,10 @@ class ForeignKey(DialectKWArgs, SchemaItem):
:class:`_schema.Table`.
"""
-
- return table.columns.corresponding_column(self.column)
+ # our column is a Column, and any subquery etc. proxying us
+ # would be doing so via another Column, so that's what would
+ # be returned here
+ return table.columns.corresponding_column(self.column) # type: ignore
@util.memoized_property
def _column_tokens(self) -> Tuple[Optional[str], str, Optional[str]]:
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 9d4d1d6c7..b08f13f99 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -23,6 +23,7 @@ from typing import Any
from typing import Callable
from typing import cast
from typing import Dict
+from typing import Generic
from typing import Iterable
from typing import Iterator
from typing import List
@@ -46,6 +47,8 @@ from . import traversals
from . import type_api
from . import visitors
from ._typing import _ColumnsClauseArgument
+from ._typing import _no_kw
+from ._typing import _TP
from ._typing import is_column_element
from ._typing import is_select_statement
from ._typing import is_subquery
@@ -103,9 +106,20 @@ if TYPE_CHECKING:
from ._typing import _ColumnExpressionArgument
from ._typing import _FromClauseArgument
from ._typing import _JoinTargetArgument
+ from ._typing import _MAYBE_ENTITY
+ from ._typing import _NOT_ENTITY
from ._typing import _OnClauseArgument
from ._typing import _SelectStatementForCompoundArgument
+ from ._typing import _T0
+ from ._typing import _T1
+ from ._typing import _T2
+ from ._typing import _T3
+ from ._typing import _T4
+ from ._typing import _T5
+ from ._typing import _T6
+ from ._typing import _T7
from ._typing import _TextCoercedExpressionArgument
+ from ._typing import _TypedColumnClauseArgument as _TCCA
from ._typing import _TypeEngineArgument
from .base import _AmbiguousTableNameMap
from .base import ExecutableOption
@@ -115,14 +129,13 @@ if TYPE_CHECKING:
from .dml import Delete
from .dml import Insert
from .dml import Update
+ from .elements import KeyedColumnElement
from .elements import NamedColumn
from .elements import TextClause
from .functions import Function
- from .schema import Column
from .schema import ForeignKey
from .schema import ForeignKeyConstraint
from .type_api import TypeEngine
- from .util import ClauseAdapter
from .visitors import _CloneCallableType
@@ -245,6 +258,14 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement):
raise NotImplementedError()
+class ExecutableReturnsRows(Executable, ReturnsRows):
+ """base for executable statements that return rows."""
+
+
+class TypedReturnsRows(ExecutableReturnsRows, Generic[_TP]):
+ """base for executable statements that return rows."""
+
+
SelfSelectable = TypeVar("SelfSelectable", bound="Selectable")
@@ -293,8 +314,8 @@ class Selectable(ReturnsRows):
)
def corresponding_column(
- self, column: ColumnElement[Any], require_embedded: bool = False
- ) -> Optional[ColumnElement[Any]]:
+ self, column: KeyedColumnElement[Any], require_embedded: bool = False
+ ) -> Optional[KeyedColumnElement[Any]]:
"""Given a :class:`_expression.ColumnElement`, return the exported
:class:`_expression.ColumnElement` object from the
:attr:`_expression.Selectable.exported_columns`
@@ -593,7 +614,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
_use_schema_map = False
- def select(self) -> Select:
+ def select(self) -> Select[Any]:
r"""Return a SELECT of this :class:`_expression.FromClause`.
@@ -795,7 +816,9 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
)
@util.ro_non_memoized_property
- def exported_columns(self) -> ReadOnlyColumnCollection[str, Any]:
+ def exported_columns(
+ self,
+ ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]:
"""A :class:`_expression.ColumnCollection`
that represents the "exported"
columns of this :class:`_expression.Selectable`.
@@ -817,7 +840,9 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
return self.c
@util.ro_non_memoized_property
- def columns(self) -> ReadOnlyColumnCollection[str, Any]:
+ def columns(
+ self,
+ ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]:
"""A named-based collection of :class:`_expression.ColumnElement`
objects maintained by this :class:`_expression.FromClause`.
@@ -833,7 +858,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
return self.c
@util.ro_memoized_property
- def c(self) -> ReadOnlyColumnCollection[str, Any]:
+ def c(self) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]:
"""
A synonym for :attr:`.FromClause.columns`
@@ -1223,7 +1248,7 @@ class Join(roles.DMLTableRole, FromClause):
@util.preload_module("sqlalchemy.sql.util")
def _populate_column_collection(self):
sqlutil = util.preloaded.sql_util
- columns: List[ColumnClause[Any]] = [c for c in self.left.c] + [
+ columns: List[KeyedColumnElement[Any]] = [c for c in self.left.c] + [
c for c in self.right.c
]
@@ -1458,7 +1483,7 @@ class Join(roles.DMLTableRole, FromClause):
"join explicitly." % (a.description, b.description)
)
- def select(self) -> "Select":
+ def select(self) -> Select[Any]:
r"""Create a :class:`_expression.Select` from this
:class:`_expression.Join`.
@@ -2764,6 +2789,7 @@ class Subquery(AliasedReturnsRows):
cls, selectable: SelectBase, name: Optional[str] = None
) -> Subquery:
"""Return a :class:`.Subquery` object."""
+
return coercions.expect(
roles.SelectStatementRole, selectable
).subquery(name=name)
@@ -3216,7 +3242,6 @@ class SelectBase(
roles.CompoundElementRole,
roles.InElementRole,
HasCTE,
- Executable,
SupportsCloneAnnotations,
Selectable,
):
@@ -3239,7 +3264,9 @@ class SelectBase(
self._reset_memoizations()
@util.ro_non_memoized_property
- def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]:
+ def selected_columns(
+ self,
+ ) -> ColumnCollection[str, ColumnElement[Any]]:
"""A :class:`_expression.ColumnCollection`
representing the columns that
this SELECT statement or similar construct returns in its result set.
@@ -3284,7 +3311,9 @@ class SelectBase(
raise NotImplementedError()
@property
- def exported_columns(self) -> ReadOnlyColumnCollection[str, Any]:
+ def exported_columns(
+ self,
+ ) -> ReadOnlyColumnCollection[str, ColumnElement[Any]]:
"""A :class:`_expression.ColumnCollection`
that represents the "exported"
columns of this :class:`_expression.Selectable`, not including
@@ -3377,7 +3406,7 @@ class SelectBase(
def as_scalar(self):
return self.scalar_subquery()
- def exists(self):
+ def exists(self) -> Exists:
"""Return an :class:`_sql.Exists` representation of this selectable,
which can be used as a column expression.
@@ -3394,7 +3423,7 @@ class SelectBase(
"""
return Exists(self)
- def scalar_subquery(self):
+ def scalar_subquery(self) -> ScalarSelect[Any]:
"""Return a 'scalar' representation of this selectable, which can be
used as a column expression.
@@ -3607,7 +3636,7 @@ SelfGenerativeSelect = typing.TypeVar(
)
-class GenerativeSelect(SelectBase):
+class GenerativeSelect(SelectBase, Generative):
"""Base class for SELECT statements where additional elements can be
added.
@@ -4128,7 +4157,7 @@ class _CompoundSelectKeyword(Enum):
INTERSECT_ALL = "INTERSECT ALL"
-class CompoundSelect(HasCompileState, GenerativeSelect):
+class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows):
"""Forms the basis of ``UNION``, ``UNION ALL``, and other
SELECT-based set operations.
@@ -4293,7 +4322,9 @@ class CompoundSelect(HasCompileState, GenerativeSelect):
return self.selects[0]._all_selected_columns
@util.ro_non_memoized_property
- def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]:
+ def selected_columns(
+ self,
+ ) -> ColumnCollection[str, ColumnElement[Any]]:
"""A :class:`_expression.ColumnCollection`
representing the columns that
this SELECT statement or similar construct returns in its result set,
@@ -4343,7 +4374,10 @@ class SelectState(util.MemoizedSlots, CompileState):
...
def __init__(
- self, statement: Select, compiler: Optional[SQLCompiler], **kw: Any
+ self,
+ statement: Select[Any],
+ compiler: Optional[SQLCompiler],
+ **kw: Any,
):
self.statement = statement
self.from_clauses = statement._from_obj
@@ -4369,7 +4403,7 @@ class SelectState(util.MemoizedSlots, CompileState):
@classmethod
def get_column_descriptions(
- cls, statement: Select
+ cls, statement: Select[Any]
) -> List[Dict[str, Any]]:
return [
{
@@ -4384,12 +4418,14 @@ class SelectState(util.MemoizedSlots, CompileState):
@classmethod
def from_statement(
- cls, statement: Select, from_statement: ReturnsRows
- ) -> Any:
+ cls, statement: Select[Any], from_statement: ExecutableReturnsRows
+ ) -> ExecutableReturnsRows:
cls._plugin_not_implemented()
@classmethod
- def get_columns_clause_froms(cls, statement: Select) -> List[FromClause]:
+ def get_columns_clause_froms(
+ cls, statement: Select[Any]
+ ) -> List[FromClause]:
return cls._normalize_froms(
itertools.chain.from_iterable(
element._from_objects for element in statement._raw_columns
@@ -4439,7 +4475,7 @@ class SelectState(util.MemoizedSlots, CompileState):
return go
- def _get_froms(self, statement: Select) -> List[FromClause]:
+ def _get_froms(self, statement: Select[Any]) -> List[FromClause]:
ambiguous_table_name_map: _AmbiguousTableNameMap
self._ambiguous_table_name_map = ambiguous_table_name_map = {}
@@ -4467,7 +4503,7 @@ class SelectState(util.MemoizedSlots, CompileState):
def _normalize_froms(
cls,
iterable_of_froms: Iterable[FromClause],
- check_statement: Optional[Select] = None,
+ check_statement: Optional[Select[Any]] = None,
ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] = None,
) -> List[FromClause]:
"""given an iterable of things to select FROM, reduce them to what
@@ -4615,7 +4651,7 @@ class SelectState(util.MemoizedSlots, CompileState):
@classmethod
def determine_last_joined_entity(
- cls, stmt: Select
+ cls, stmt: Select[Any]
) -> Optional[_JoinTargetElement]:
if stmt._setup_joins:
return stmt._setup_joins[-1][0]
@@ -4623,7 +4659,7 @@ class SelectState(util.MemoizedSlots, CompileState):
return None
@classmethod
- def all_selected_columns(cls, statement: Select) -> _SelectIterable:
+ def all_selected_columns(cls, statement: Select[Any]) -> _SelectIterable:
return [c for c in _select_iterables(statement._raw_columns)]
def _setup_joins(
@@ -4876,7 +4912,7 @@ class _MemoizedSelectEntities(
return c # type: ignore
@classmethod
- def _generate_for_statement(cls, select_stmt: Select) -> None:
+ def _generate_for_statement(cls, select_stmt: Select[Any]) -> None:
if select_stmt._setup_joins or select_stmt._with_options:
self = _MemoizedSelectEntities()
self._raw_columns = select_stmt._raw_columns
@@ -4888,7 +4924,7 @@ class _MemoizedSelectEntities(
select_stmt._setup_joins = select_stmt._with_options = ()
-SelfSelect = typing.TypeVar("SelfSelect", bound="Select")
+SelfSelect = typing.TypeVar("SelfSelect", bound="Select[Any]")
class Select(
@@ -4898,6 +4934,7 @@ class Select(
HasCompileState,
_SelectFromElements,
GenerativeSelect,
+ TypedReturnsRows[_TP],
):
"""Represents a ``SELECT`` statement.
@@ -4973,7 +5010,7 @@ class Select(
_compile_state_factory: Type[SelectState]
@classmethod
- def _create_raw_select(cls, **kw: Any) -> Select:
+ def _create_raw_select(cls, **kw: Any) -> Select[Any]:
"""Create a :class:`.Select` using raw ``__new__`` with no coercions.
Used internally to build up :class:`.Select` constructs with
@@ -4985,7 +5022,7 @@ class Select(
stmt.__dict__.update(kw)
return stmt
- def __init__(self, *entities: _ColumnsClauseArgument):
+ def __init__(self, *entities: _ColumnsClauseArgument[Any]):
r"""Construct a new :class:`_expression.Select`.
The public constructor for :class:`_expression.Select` is the
@@ -5013,7 +5050,9 @@ class Select(
cols = list(elem._select_iterable)
return cols[0].type
- def filter(self: SelfSelect, *criteria: ColumnElement[Any]) -> SelfSelect:
+ def filter(
+ self: SelfSelect, *criteria: _ColumnExpressionArgument[bool]
+ ) -> SelfSelect:
"""A synonym for the :meth:`_future.Select.where` method."""
return self.where(*criteria)
@@ -5032,7 +5071,28 @@ class Select(
return self._raw_columns[0]
- def filter_by(self, **kwargs):
+ if TYPE_CHECKING:
+
+ @overload
+ def scalar_subquery(
+ self: Select[Tuple[_MAYBE_ENTITY]],
+ ) -> ScalarSelect[Any]:
+ ...
+
+ @overload
+ def scalar_subquery(
+ self: Select[Tuple[_NOT_ENTITY]],
+ ) -> ScalarSelect[_NOT_ENTITY]:
+ ...
+
+ @overload
+ def scalar_subquery(self) -> ScalarSelect[Any]:
+ ...
+
+ def scalar_subquery(self) -> ScalarSelect[Any]:
+ ...
+
+ def filter_by(self: SelfSelect, **kwargs: Any) -> SelfSelect:
r"""apply the given filtering criterion as a WHERE clause
to this select.
@@ -5046,7 +5106,7 @@ class Select(
return self.filter(*clauses)
@property
- def column_descriptions(self):
+ def column_descriptions(self) -> Any:
"""Return a :term:`plugin-enabled` 'column descriptions' structure
referring to the columns which are SELECTed by this statement.
@@ -5089,7 +5149,9 @@ class Select(
meth = SelectState.get_plugin_class(self).get_column_descriptions
return meth(self)
- def from_statement(self, statement):
+ def from_statement(
+ self, statement: ExecutableReturnsRows
+ ) -> ExecutableReturnsRows:
"""Apply the columns which this :class:`.Select` would select
onto another statement.
@@ -5410,7 +5472,7 @@ class Select(
)
@property
- def inner_columns(self):
+ def inner_columns(self) -> _SelectIterable:
"""An iterator of all :class:`_expression.ColumnElement`
expressions which would
be rendered into the columns clause of the resulting SELECT statement.
@@ -5487,18 +5549,19 @@ class Select(
self._reset_memoizations()
- def get_children(self, **kwargs):
+ def get_children(self, **kw: Any) -> Iterable[ClauseElement]:
return itertools.chain(
super(Select, self).get_children(
- omit_attrs=("_from_obj", "_correlate", "_correlate_except")
+ omit_attrs=("_from_obj", "_correlate", "_correlate_except"),
+ **kw,
),
self._iterate_from_elements(),
)
@_generative
def add_columns(
- self: SelfSelect, *columns: _ColumnsClauseArgument
- ) -> SelfSelect:
+ self, *columns: _ColumnsClauseArgument[Any]
+ ) -> Select[Any]:
"""Return a new :func:`_expression.select` construct with
the given column expressions added to its columns clause.
@@ -5523,7 +5586,7 @@ class Select(
return self
def _set_entities(
- self, entities: Iterable[_ColumnsClauseArgument]
+ self, entities: Iterable[_ColumnsClauseArgument[Any]]
) -> None:
self._raw_columns = [
coercions.expect(
@@ -5538,7 +5601,7 @@ class Select(
"be removed in a future release. Please use "
":meth:`_expression.Select.add_columns`",
)
- def column(self: SelfSelect, column: _ColumnsClauseArgument) -> SelfSelect:
+ def column(self, column: _ColumnsClauseArgument[Any]) -> Select[Any]:
"""Return a new :func:`_expression.select` construct with
the given column expression added to its columns clause.
@@ -5555,9 +5618,7 @@ class Select(
return self.add_columns(column)
@util.preload_module("sqlalchemy.sql.util")
- def reduce_columns(
- self: SelfSelect, only_synonyms: bool = True
- ) -> SelfSelect:
+ def reduce_columns(self, only_synonyms: bool = True) -> Select[Any]:
"""Return a new :func:`_expression.select` construct with redundantly
named, equivalently-valued columns removed from the columns clause.
@@ -5580,20 +5641,115 @@ class Select(
all columns that are equivalent to another are removed.
"""
- return self.with_only_columns(
+ woc: Select[Any]
+ woc = self.with_only_columns(
*util.preloaded.sql_util.reduce_columns(
self._all_selected_columns,
only_synonyms=only_synonyms,
*(self._where_criteria + self._from_obj),
)
)
+ return woc
+
+ # START OVERLOADED FUNCTIONS self.with_only_columns Select 8
+
+ # code within this block is **programmatically,
+ # statically generated** by tools/generate_sel_v1_overloads.py
+
+ @overload
+ def with_only_columns(self, __ent0: _TCCA[_T0]) -> Select[Tuple[_T0]]:
+ ...
+
+ @overload
+ def with_only_columns(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1]
+ ) -> Select[Tuple[_T0, _T1]]:
+ ...
+
+ @overload
+ def with_only_columns(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2]
+ ) -> Select[Tuple[_T0, _T1, _T2]]:
+ ...
+
+ @overload
+ def with_only_columns(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ ) -> Select[Tuple[_T0, _T1, _T2, _T3]]:
+ ...
+
+ @overload
+ def with_only_columns(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4]]:
+ ...
+
+ @overload
+ def with_only_columns(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]:
+ ...
+
+ @overload
+ def with_only_columns(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]:
+ ...
+
+ @overload
+ def with_only_columns(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ __ent7: _TCCA[_T7],
+ ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]:
+ ...
+
+ # END OVERLOADED FUNCTIONS self.with_only_columns
+
+ @overload
+ def with_only_columns(
+ self,
+ *columns: _ColumnsClauseArgument[Any],
+ maintain_column_froms: bool = False,
+ **__kw: Any,
+ ) -> Select[Any]:
+ ...
@_generative
def with_only_columns(
- self: SelfSelect,
- *columns: _ColumnsClauseArgument,
+ self,
+ *columns: _ColumnsClauseArgument[Any],
maintain_column_froms: bool = False,
- ) -> SelfSelect:
+ **__kw: Any,
+ ) -> Select[Any]:
r"""Return a new :func:`_expression.select` construct with its columns
clause replaced with the given columns.
@@ -5647,6 +5803,9 @@ class Select(
""" # noqa: E501
+ if __kw:
+ raise _no_kw()
+
# memoizations should be cleared here as of
# I95c560ffcbfa30b26644999412fb6a385125f663 , asserting this
# is the case for now.
@@ -5915,7 +6074,9 @@ class Select(
return self
@HasMemoized_ro_memoized_attribute
- def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]:
+ def selected_columns(
+ self,
+ ) -> ColumnCollection[str, ColumnElement[Any]]:
"""A :class:`_expression.ColumnCollection`
representing the columns that
this SELECT statement or similar construct returns in its result set,
@@ -6215,7 +6376,7 @@ class ScalarSelect(
by this :class:`_expression.ScalarSelect`.
"""
- self.element = cast(Select, self.element).where(crit)
+ self.element = cast("Select[Any]", self.element).where(crit)
return self
@overload
@@ -6269,7 +6430,9 @@ class ScalarSelect(
"""
- self.element = cast(Select, self.element).correlate(*fromclauses)
+ self.element = cast("Select[Any]", self.element).correlate(
+ *fromclauses
+ )
return self
@_generative
@@ -6307,7 +6470,7 @@ class ScalarSelect(
"""
- self.element = cast(Select, self.element).correlate_except(
+ self.element = cast("Select[Any]", self.element).correlate_except(
*fromclauses
)
return self
@@ -6331,12 +6494,18 @@ class Exists(UnaryExpression[bool]):
def __init__(
self,
__argument: Optional[
- Union[_ColumnsClauseArgument, SelectBase, ScalarSelect[bool]]
+ Union[_ColumnsClauseArgument[Any], SelectBase, ScalarSelect[Any]]
] = None,
):
+ s: ScalarSelect[Any]
+
+ # TODO: this seems like we should be using coercions for this
if __argument is None:
s = Select(literal_column("*")).scalar_subquery()
- elif isinstance(__argument, (SelectBase, ScalarSelect)):
+ elif isinstance(__argument, SelectBase):
+ s = __argument.scalar_subquery()
+ s._propagate_attrs = __argument._propagate_attrs
+ elif isinstance(__argument, ScalarSelect):
s = __argument
else:
s = Select(__argument).scalar_subquery()
@@ -6358,7 +6527,7 @@ class Exists(UnaryExpression[bool]):
element = fn(element)
return element.self_group(against=operators.exists)
- def select(self) -> Select:
+ def select(self) -> Select[Any]:
r"""Return a SELECT of this :class:`_expression.Exists`.
e.g.::
@@ -6452,7 +6621,7 @@ class Exists(UnaryExpression[bool]):
SelfTextualSelect = typing.TypeVar("SelfTextualSelect", bound="TextualSelect")
-class TextualSelect(SelectBase):
+class TextualSelect(SelectBase, Executable, Generative):
"""Wrap a :class:`_expression.TextClause` construct within a
:class:`_expression.SelectBase`
interface.
@@ -6503,7 +6672,9 @@ class TextualSelect(SelectBase):
self.positional = positional
@HasMemoized_ro_memoized_attribute
- def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]:
+ def selected_columns(
+ self,
+ ) -> ColumnCollection[str, KeyedColumnElement[Any]]:
"""A :class:`_expression.ColumnCollection`
representing the columns that
this SELECT statement or similar construct returns in its result set,
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index d08fef60a..8c45ba410 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -50,6 +50,7 @@ from .elements import ClauseElement
from .elements import ColumnClause
from .elements import ColumnElement
from .elements import Grouping
+from .elements import KeyedColumnElement
from .elements import Label
from .elements import Null
from .elements import UnaryExpression
@@ -72,9 +73,7 @@ if typing.TYPE_CHECKING:
from ._typing import _EquivalentColumnMap
from ._typing import _TypeEngineArgument
from .elements import TextClause
- from .roles import FromClauseRole
from .selectable import _JoinTargetElement
- from .selectable import _OnClauseElement
from .selectable import _SelectIterable
from .selectable import Selectable
from .visitors import _TraverseCallableType
@@ -569,7 +568,7 @@ class _repr_row(_repr_base):
__slots__ = ("row",)
- def __init__(self, row: "Row", max_chars: int = 300):
+ def __init__(self, row: "Row[Any]", max_chars: int = 300):
self.row = row
self.max_chars = max_chars
@@ -1068,7 +1067,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
col = col._annotations["adapt_column"]
if TYPE_CHECKING:
- assert isinstance(col, ColumnElement)
+ assert isinstance(col, KeyedColumnElement)
if self.adapt_from_selectables and col not in self.equivalents:
for adp in self.adapt_from_selectables:
@@ -1078,7 +1077,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
return None
if TYPE_CHECKING:
- assert isinstance(col, ColumnElement)
+ assert isinstance(col, KeyedColumnElement)
if self.include_fn and not self.include_fn(col):
return None
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index e0a66fbcf..88586d834 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -450,7 +450,7 @@ class HasTraverseInternals:
@util.preload_module("sqlalchemy.sql.traversals")
def get_children(
- self, omit_attrs: Tuple[str, ...] = (), **kw: Any
+ self, *, omit_attrs: Tuple[str, ...] = (), **kw: Any
) -> Iterable[HasTraverseInternals]:
r"""Return immediate child :class:`.visitors.HasTraverseInternals`
elements of this :class:`.visitors.HasTraverseInternals`.
@@ -594,7 +594,7 @@ class ExternallyTraversible(HasTraverseInternals, Visitable):
if typing.TYPE_CHECKING:
def get_children(
- self, omit_attrs: Tuple[str, ...] = (), **kw: Any
+ self, *, omit_attrs: Tuple[str, ...] = (), **kw: Any
) -> Iterable[ExternallyTraversible]:
...
diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py
index 49c5d693a..da3fbc718 100644
--- a/lib/sqlalchemy/util/langhelpers.py
+++ b/lib/sqlalchemy/util/langhelpers.py
@@ -18,6 +18,7 @@ import hashlib
import inspect
import itertools
import operator
+import os
import re
import sys
import textwrap
@@ -32,6 +33,7 @@ from typing import Generic
from typing import Iterator
from typing import List
from typing import Mapping
+from typing import no_type_check
from typing import NoReturn
from typing import Optional
from typing import overload
@@ -2106,3 +2108,45 @@ def has_compiled_ext(raise_=False):
)
else:
return False
+
+
+@no_type_check
+def console_scripts(
+ path: str, options: dict, ignore_output: bool = False
+) -> None:
+
+ import subprocess
+ import shlex
+ from pathlib import Path
+
+ is_posix = os.name == "posix"
+
+ entrypoint_name = options["entrypoint"]
+
+ for entry in compat.importlib_metadata_get("console_scripts"):
+ if entry.name == entrypoint_name:
+ impl = entry
+ break
+ else:
+ raise Exception(
+ f"Could not find entrypoint console_scripts.{entrypoint_name}"
+ )
+ cmdline_options_str = options.get("options", "")
+ cmdline_options_list = shlex.split(cmdline_options_str, posix=is_posix) + [
+ path
+ ]
+
+ kw = {}
+ if ignore_output:
+ kw["stdout"] = kw["stderr"] = subprocess.DEVNULL
+
+ subprocess.run(
+ [
+ sys.executable,
+ "-c",
+ "import %s; %s.%s()" % (impl.module, impl.module, impl.attr),
+ ]
+ + cmdline_options_list,
+ cwd=Path(__file__).parent.parent,
+ **kw,
+ )
diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py
index d192dc06b..2a215c4f1 100644
--- a/lib/sqlalchemy/util/typing.py
+++ b/lib/sqlalchemy/util/typing.py
@@ -14,7 +14,7 @@ from typing import Type
from typing import TypeVar
from typing import Union
-from typing_extensions import NotRequired as NotRequired # noqa
+from typing_extensions import NotRequired as NotRequired
from . import compat