diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2022-03-12 19:59:29 +0000 |
|---|---|---|
| committer | Gerrit Code Review <gerrit@ci3.zzzcomputing.com> | 2022-03-12 19:59:29 +0000 |
| commit | 77fc8216a74e6b2d0efc6591c6c735687bd10002 (patch) | |
| tree | e338f22897ce9bcc994d625ad71aeb4b8ca7b446 /lib | |
| parent | df056af49c51dcbcd70eb13ead5c3d8588c08235 (diff) | |
| parent | 4c28867f944637ef313f98d5f09da05255418c6d (diff) | |
| download | sqlalchemy-77fc8216a74e6b2d0efc6591c6c735687bd10002.tar.gz | |
Merge "additional mypy strictness" into main
Diffstat (limited to 'lib')
| -rw-r--r-- | lib/sqlalchemy/__init__.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/__init__.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/base.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/cursor.py | 136 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/default.py | 109 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/interfaces.py | 22 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/result.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 32 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/schema.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/types.py | 174 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/_collections.py | 49 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/_py_collections.py | 12 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/langhelpers.py | 23 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/typing.py | 31 |
16 files changed, 296 insertions, 319 deletions
diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index 7ceb33c7c..de01a1b46 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -9,12 +9,8 @@ from __future__ import annotations from . import util as _util from .engine import AdaptedConnection as AdaptedConnection -from .engine import BaseCursorResult as BaseCursorResult from .engine import BaseRow as BaseRow from .engine import BindTyping as BindTyping -from .engine import BufferedColumnResultProxy as BufferedColumnResultProxy -from .engine import BufferedColumnRow as BufferedColumnRow -from .engine import BufferedRowResultProxy as BufferedRowResultProxy from .engine import ChunkedIteratorResult as ChunkedIteratorResult from .engine import Compiled as Compiled from .engine import Connection as Connection @@ -28,7 +24,6 @@ from .engine import engine_from_config as engine_from_config from .engine import ExceptionContext as ExceptionContext from .engine import ExecutionContext as ExecutionContext from .engine import FrozenResult as FrozenResult -from .engine import FullyBufferedResultProxy as FullyBufferedResultProxy from .engine import Inspector as Inspector from .engine import IteratorResult as IteratorResult from .engine import make_url as make_url diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py index 32f3f2ecc..29dd6aff9 100644 --- a/lib/sqlalchemy/engine/__init__.py +++ b/lib/sqlalchemy/engine/__init__.py @@ -25,12 +25,7 @@ from .base import Transaction as Transaction from .base import TwoPhaseTransaction as TwoPhaseTransaction from .create import create_engine as create_engine from .create import engine_from_config as engine_from_config -from .cursor import BaseCursorResult as BaseCursorResult -from .cursor import BufferedColumnResultProxy as BufferedColumnResultProxy -from .cursor import BufferedColumnRow as BufferedColumnRow -from .cursor import BufferedRowResultProxy as BufferedRowResultProxy from .cursor import CursorResult as CursorResult -from .cursor import FullyBufferedResultProxy as FullyBufferedResultProxy from .cursor import ResultProxy as ResultProxy from .interfaces import AdaptedConnection as AdaptedConnection from .interfaces import BindTyping as BindTyping diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 37faa880e..d8009e26c 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -23,6 +23,7 @@ from typing import Tuple from typing import Type from typing import Union +from .interfaces import _IsolationLevel from .interfaces import BindTyping from .interfaces import ConnectionEventsTarget from .interfaces import DBAPICursor @@ -510,7 +511,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): self._handle_dbapi_exception(e, None, None, None, None) @property - def default_isolation_level(self) -> str: + def default_isolation_level(self) -> Optional[_IsolationLevel]: """The default isolation level assigned to this :class:`_engine.Connection`. diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index 78805bac1..821c0cb8e 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -23,8 +23,9 @@ from typing import List from typing import Optional from typing import Sequence from typing import Tuple -from typing import Type +from typing import Union +from .result import MergedResult from .result import Result from .result import ResultMetaData from .result import SimpleResultMetaData @@ -36,10 +37,12 @@ from ..sql import elements from ..sql import sqltypes from ..sql import util as sql_util from ..sql.base import _generative +from ..sql.compiler import ResultColumnsEntry from ..sql.compiler import RM_NAME from ..sql.compiler import RM_OBJECTS from ..sql.compiler import RM_RENDERED_NAME from ..sql.compiler import RM_TYPE +from ..sql.type_api import TypeEngine from ..util import compat from ..util.typing import Literal @@ -101,6 +104,7 @@ class CursorResultMetaData(ResultMetaData): _keymap_by_result_column_idx: Optional[Dict[int, _KeyMapRecType]] _unpickled: bool _safe_for_cache: bool + _translated_indexes: Optional[List[int]] returns_rows: ClassVar[bool] = True @@ -123,7 +127,6 @@ class CursorResultMetaData(ResultMetaData): if self._translated_indexes: indexes = [self._translated_indexes[idx] for idx in indexes] - tup = tuplegetter(*indexes) new_metadata = self.__class__.__new__(self.__class__) @@ -526,7 +529,7 @@ class CursorResultMetaData(ResultMetaData): def _merge_textual_cols_by_position( self, context, cursor_description, result_columns ): - num_ctx_cols = len(result_columns) if result_columns else None + num_ctx_cols = len(result_columns) if num_ctx_cols > len(cursor_description): util.warn( @@ -568,6 +571,8 @@ class CursorResultMetaData(ResultMetaData): match_map = self._create_description_match_map( result_columns, loose_column_name_matching ) + mapped_type: TypeEngine[Any] + for ( idx, colname, @@ -597,15 +602,17 @@ class CursorResultMetaData(ResultMetaData): @classmethod def _create_description_match_map( cls, - result_columns, - loose_column_name_matching=False, - ): + result_columns: List[ResultColumnsEntry], + loose_column_name_matching: bool = False, + ) -> Dict[Union[str, object], Tuple[str, List[Any], TypeEngine[Any], int]]: """when matching cursor.description to a set of names that are present in a Compiled object, as is the case with TextualSelect, get all the names we expect might match those in cursor.description. """ - d = {} + d: Dict[ + Union[str, object], Tuple[str, List[Any], TypeEngine[Any], int] + ] = {} for ridx, elem in enumerate(result_columns): key = elem[RM_RENDERED_NAME] @@ -630,7 +637,6 @@ class CursorResultMetaData(ResultMetaData): r_key, (elem[RM_NAME], elem[RM_OBJECTS], elem[RM_TYPE], ridx), ) - return d def _merge_cols_by_none(self, context, cursor_description): @@ -739,7 +745,9 @@ class CursorResultMetaData(ResultMetaData): self._keys = state["_keys"] self._unpickled = True if state["_translated_indexes"]: - self._translated_indexes = state["_translated_indexes"] + self._translated_indexes = cast( + "List[int]", state["_translated_indexes"] + ) self._tuplefilter = tuplegetter(*self._translated_indexes) else: self._translated_indexes = self._tuplefilter = None @@ -1144,12 +1152,32 @@ class _NoResultMetaData(ResultMetaData): _NO_RESULT_METADATA = _NoResultMetaData() -class BaseCursorResult: - """Base class for database result objects.""" +class CursorResult(Result): + """A Result that is representing state from a DBAPI cursor. + + .. versionchanged:: 1.4 The :class:`.CursorResult`` + class replaces the previous :class:`.ResultProxy` interface. + This classes are based on the :class:`.Result` calling API + which provides an updated usage model and calling facade for + SQLAlchemy Core and SQLAlchemy ORM. + + Returns database rows via the :class:`.Row` class, which provides + additional API features and behaviors on top of the raw data returned by + the DBAPI. Through the use of filters such as the :meth:`.Result.scalars` + method, other kinds of objects may also be returned. + + .. seealso:: + + :ref:`coretutorial_selecting` - introductory material for accessing + :class:`_engine.CursorResult` and :class:`.Row` objects. - _metadata: ResultMetaData + """ + + _metadata: Union[CursorResultMetaData, _NoResultMetaData] + _no_result_metadata = _NO_RESULT_METADATA _soft_closed: bool = False closed: bool = False + _is_cursor = True def __init__(self, context, cursor_strategy, cursor_description): self.context = context @@ -1169,11 +1197,11 @@ class BaseCursorResult: if echo: log = self.context.connection._log_debug - def log_row(row): + def _log_row(row): log("Row %r", sql_util._repr_row(row)) return row - self._row_logging_fn = log_row + self._row_logging_fn = log_row = _log_row else: log_row = None @@ -1188,13 +1216,16 @@ class BaseCursorResult: ) if log_row: - def make_row(row): + def _make_row_2(row): made_row = _make_row(row) + assert log_row is not None log_row(made_row) return made_row + make_row = _make_row_2 else: make_row = _make_row + self._set_memoized_attribute("_row_getter", make_row) else: @@ -1208,7 +1239,7 @@ class BaseCursorResult: if compiled._cached_metadata: metadata = compiled._cached_metadata else: - metadata = self._cursor_metadata(self, cursor_description) + metadata = CursorResultMetaData(self, cursor_description) if metadata._safe_for_cache: compiled._cached_metadata = metadata @@ -1239,7 +1270,7 @@ class BaseCursorResult: self._metadata = metadata else: - self._metadata = metadata = self._cursor_metadata( + self._metadata = metadata = CursorResultMetaData( self, cursor_description ) if self._echo: @@ -1669,33 +1700,6 @@ class BaseCursorResult: """ return self.context.isinsert - -class CursorResult(BaseCursorResult, Result): - """A Result that is representing state from a DBAPI cursor. - - .. versionchanged:: 1.4 The :class:`.CursorResult`` - class replaces the previous :class:`.ResultProxy` interface. - This classes are based on the :class:`.Result` calling API - which provides an updated usage model and calling facade for - SQLAlchemy Core and SQLAlchemy ORM. - - Returns database rows via the :class:`.Row` class, which provides - additional API features and behaviors on top of the raw data returned by - the DBAPI. Through the use of filters such as the :meth:`.Result.scalars` - method, other kinds of objects may also be returned. - - .. seealso:: - - :ref:`coretutorial_selecting` - introductory material for accessing - :class:`_engine.CursorResult` and :class:`.Row` objects. - - """ - - _cursor_metadata: Type[ResultMetaData] = CursorResultMetaData - _cursor_strategy_cls = CursorFetchStrategy - _no_result_metadata = _NO_RESULT_METADATA - _is_cursor = True - def _fetchiter_impl(self): fetchone = self.cursor_strategy.fetchone @@ -1717,12 +1721,13 @@ class CursorResult(BaseCursorResult, Result): def _raw_row_iterator(self): return self._fetchiter_impl() - def merge(self, *others): - merged_result = super(CursorResult, self).merge(*others) + def merge(self, *others: Result) -> MergedResult: + merged_result = super().merge(*others) setup_rowcounts = not self._metadata.returns_rows if setup_rowcounts: merged_result.rowcount = sum( - result.rowcount for result in (self,) + others + cast(CursorResult, result).rowcount + for result in (self,) + others ) return merged_result @@ -1756,40 +1761,3 @@ class CursorResult(BaseCursorResult, Result): ResultProxy = CursorResult - - -class BufferedRowResultProxy(ResultProxy): - """A ResultProxy with row buffering behavior. - - .. deprecated:: 1.4 this class is now supplied using a strategy object. - See :class:`.BufferedRowCursorFetchStrategy`. - - """ - - _cursor_strategy_cls: Type[ - CursorFetchStrategy - ] = BufferedRowCursorFetchStrategy - - -class FullyBufferedResultProxy(ResultProxy): - """A result proxy that buffers rows fully upon creation. - - .. deprecated:: 1.4 this class is now supplied using a strategy object. - See :class:`.FullyBufferedCursorFetchStrategy`. - - """ - - _cursor_strategy_cls = FullyBufferedCursorFetchStrategy - - -class BufferedColumnRow(Row): - """Row is now BufferedColumn in all cases""" - - -class BufferedColumnResultProxy(ResultProxy): - """A ResultProxy with column buffering behavior. - - .. versionchanged:: 1.4 This is now the default behavior of the Row - and this class does not change behavior in any way. - - """ diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 0e0c76389..2579f573c 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -55,6 +55,9 @@ from ..sql.compiler import SQLCompiler from ..sql.elements import quoted_name if typing.TYPE_CHECKING: + from .base import Connection + from .base import Engine + from .characteristics import ConnectionCharacteristic from .interfaces import _AnyMultiExecuteParams from .interfaces import _CoreMultiExecuteParams from .interfaces import _CoreSingleExecuteParams @@ -62,6 +65,7 @@ if typing.TYPE_CHECKING: from .interfaces import _DBAPIMultiExecuteParams from .interfaces import _DBAPISingleExecuteParams from .interfaces import _ExecuteOptions + from .interfaces import _MutableCoreSingleExecuteParams from .result import _ProcessorType from .row import Row from .url import URL @@ -71,6 +75,7 @@ if typing.TYPE_CHECKING: from ..sql import Executable from ..sql.compiler import Compiled from ..sql.compiler import ResultColumnsEntry + from ..sql.compiler import TypeCompiler from ..sql.schema import Column from ..sql.type_api import TypeEngine @@ -92,7 +97,11 @@ class DefaultDialect(Dialect): statement_compiler = compiler.SQLCompiler ddl_compiler = compiler.DDLCompiler - type_compiler = compiler.GenericTypeCompiler # type: ignore + if typing.TYPE_CHECKING: + type_compiler: TypeCompiler + else: + type_compiler = compiler.GenericTypeCompiler + preparer = compiler.IdentifierPreparer supports_alter = True supports_comments = False @@ -202,7 +211,7 @@ class DefaultDialect(Dialect): server_version_info = None - default_schema_name = None + default_schema_name: Optional[str] = None # indicates symbol names are # UPPERCASEd if they are case insensitive @@ -290,7 +299,12 @@ class DefaultDialect(Dialect): self.positional = self.paramstyle in ("qmark", "format", "numeric") self.identifier_preparer = self.preparer(self) self._on_connect_isolation_level = isolation_level - self.type_compiler = self.type_compiler(self) + + tt_callable = cast( + Type[compiler.GenericTypeCompiler], + self.type_compiler, + ) + self.type_compiler = tt_callable(self) if supports_native_boolean is not None: self.supports_native_boolean = supports_native_boolean @@ -490,12 +504,14 @@ class DefaultDialect(Dialect): opts.update(url.query) return [[], opts] - def set_engine_execution_options(self, engine, opts): + def set_engine_execution_options( + self, engine: Engine, opts: Mapping[str, str] + ) -> None: supported_names = set(self.connection_characteristics).intersection( opts ) if supported_names: - characteristics = util.immutabledict( + characteristics: Mapping[str, str] = util.immutabledict( (name, opts[name]) for name in supported_names ) @@ -505,12 +521,14 @@ class DefaultDialect(Dialect): connection, characteristics ) - def set_connection_execution_options(self, connection, opts): + def set_connection_execution_options( + self, connection: Connection, opts: Mapping[str, str] + ) -> None: supported_names = set(self.connection_characteristics).intersection( opts ) if supported_names: - characteristics = util.immutabledict( + characteristics: Mapping[str, str] = util.immutabledict( (name, opts[name]) for name in supported_names ) self._set_connection_characteristics(connection, characteristics) @@ -800,7 +818,7 @@ class DefaultExecutionContext(ExecutionContext): dialect: Dialect unicode_statement: str cursor: DBAPICursor - compiled_parameters: _CoreMultiExecuteParams + compiled_parameters: List[_MutableCoreSingleExecuteParams] parameters: _DBAPIMultiExecuteParams extracted_parameters: _CoreSingleExecuteParams @@ -1157,7 +1175,11 @@ class DefaultExecutionContext(ExecutionContext): parameters = {} conn._cursor_execute(self.cursor, stmt, parameters, context=self) - r = self.cursor.fetchone()[0] + row = self.cursor.fetchone() + if row is not None: + r = row[0] + else: + r = None if type_ is not None: # apply type post processors to the result proc = type_._cached_result_processor( @@ -1299,10 +1321,11 @@ class DefaultExecutionContext(ExecutionContext): result = _cursor.CursorResult(self, strategy, cursor_description) + compiled = self.compiled if ( - self.compiled + compiled and not self.isddl - and self.compiled.has_out_parameters + and cast(SQLCompiler, compiled).has_out_parameters ): self._setup_out_parameters(result) @@ -1311,10 +1334,11 @@ class DefaultExecutionContext(ExecutionContext): return result def _setup_out_parameters(self, result): + compiled = cast(SQLCompiler, self.compiled) out_bindparams = [ (param, name) - for param, name in self.compiled.bind_names.items() + for param, name in compiled.bind_names.items() if param.isoutparam ] out_parameters = {} @@ -1339,9 +1363,10 @@ class DefaultExecutionContext(ExecutionContext): result.out_parameters = out_parameters def _setup_dml_or_text_result(self): + compiled = cast(SQLCompiler, self.compiled) if self.isinsert: - if self.compiled.postfetch_lastrowid: + if compiled.postfetch_lastrowid: self.inserted_primary_key_rows = ( self._setup_ins_pk_from_lastrowid() ) @@ -1397,7 +1422,8 @@ class DefaultExecutionContext(ExecutionContext): result.rowcount row = result.fetchone() - self.returned_default_rows = [row] + if row is not None: + self.returned_default_rows = [row] result._soft_close() @@ -1420,13 +1446,17 @@ class DefaultExecutionContext(ExecutionContext): return self._setup_ins_pk_from_empty() def _setup_ins_pk_from_lastrowid(self): - getter = self.compiled._inserted_primary_key_from_lastrowid_getter + getter = cast( + SQLCompiler, self.compiled + )._inserted_primary_key_from_lastrowid_getter lastrowid = self.get_lastrowid() return [getter(lastrowid, self.compiled_parameters[0])] def _setup_ins_pk_from_empty(self): - getter = self.compiled._inserted_primary_key_from_lastrowid_getter + getter = cast( + SQLCompiler, self.compiled + )._inserted_primary_key_from_lastrowid_getter return [getter(None, param) for param in self.compiled_parameters] def _setup_ins_pk_from_implicit_returning(self, result, rows): @@ -1434,7 +1464,9 @@ class DefaultExecutionContext(ExecutionContext): if not rows: return [] - getter = self.compiled._inserted_primary_key_from_returning_getter + getter = cast( + SQLCompiler, self.compiled + )._inserted_primary_key_from_returning_getter compiled_params = self.compiled_parameters return [ @@ -1443,7 +1475,7 @@ class DefaultExecutionContext(ExecutionContext): def lastrow_has_defaults(self): return (self.isinsert or self.isupdate) and bool( - self.compiled.postfetch + cast(SQLCompiler, self.compiled).postfetch ) def _set_input_sizes(self): @@ -1464,7 +1496,7 @@ class DefaultExecutionContext(ExecutionContext): if self.isddl or self.is_text: return - compiled = self.compiled + compiled = cast(SQLCompiler, self.compiled) inputsizes = compiled._get_set_input_sizes_lookup() @@ -1487,7 +1519,8 @@ class DefaultExecutionContext(ExecutionContext): if dialect.positional: items = [ - (key, compiled.binds[key]) for key in compiled.positiontup + (key, compiled.binds[key]) + for key in compiled.positiontup or () ] else: items = [ @@ -1495,7 +1528,7 @@ class DefaultExecutionContext(ExecutionContext): for bindparam, key in compiled.bind_names.items() ] - generic_inputsizes = [] + generic_inputsizes: List[Tuple[str, Any, TypeEngine[Any]]] = [] for key, bindparam in items: if bindparam in compiled.literal_execute_params: continue @@ -1578,20 +1611,19 @@ class DefaultExecutionContext(ExecutionContext): compiled_params = compiled.construct_params() processors = compiled._bind_processors if compiled.positional: - positiontup = compiled.positiontup parameters = self.dialect.execute_sequence_format( [ - processors[key](compiled_params[key]) + processors[key](compiled_params[key]) # type: ignore if key in processors else compiled_params[key] - for key in positiontup + for key in compiled.positiontup or () ] ) else: parameters = dict( ( key, - processors[key](compiled_params[key]) + processors[key](compiled_params[key]) # type: ignore if key in processors else compiled_params[key], ) @@ -1667,15 +1699,18 @@ class DefaultExecutionContext(ExecutionContext): "get_current_parameters() can only be invoked in the " "context of a Python side column default function" ) - - compile_state = self.compiled.compile_state + else: + assert column is not None + assert parameters is not None + compile_state = cast(SQLCompiler, self.compiled).compile_state + assert compile_state is not None if ( isolate_multiinsert_groups and self.isinsert and compile_state._has_multi_parameters ): if column._is_multiparam_column: - index = column.index + 1 + index = column.index + 1 # type: ignore d = {column.original.key: parameters[column.key]} else: d = {column.key: parameters[column.key]} @@ -1701,12 +1736,14 @@ class DefaultExecutionContext(ExecutionContext): return self._exec_default(column, column.onupdate, column.type) def _process_executemany_defaults(self): - key_getter = self.compiled._within_exec_param_key_getter + compiled = cast(SQLCompiler, self.compiled) - scalar_defaults = {} + key_getter = compiled._within_exec_param_key_getter - insert_prefetch = self.compiled.insert_prefetch - update_prefetch = self.compiled.update_prefetch + scalar_defaults: Dict[Column[Any], Any] = {} + + insert_prefetch = compiled.insert_prefetch + update_prefetch = compiled.update_prefetch # pre-determine scalar Python-side defaults # to avoid many calls of get_insert_default()/ @@ -1739,12 +1776,14 @@ class DefaultExecutionContext(ExecutionContext): del self.current_parameters def _process_executesingle_defaults(self): - key_getter = self.compiled._within_exec_param_key_getter + compiled = cast(SQLCompiler, self.compiled) + + key_getter = compiled._within_exec_param_key_getter self.current_parameters = ( compiled_parameters ) = self.compiled_parameters[0] - for c in self.compiled.insert_prefetch: + for c in compiled.insert_prefetch: if c.default and not c.default.is_sequence and c.default.is_scalar: val = c.default.arg else: @@ -1753,7 +1792,7 @@ class DefaultExecutionContext(ExecutionContext): if val is not None: compiled_parameters[key_getter(c)] = val - for c in self.compiled.update_prefetch: + for c in compiled.update_prefetch: val = self.get_update_default(c) if val is not None: diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 5aefcf5b5..e65546eb7 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -36,7 +36,6 @@ from ..sql.compiler import TypeCompiler as TypeCompiler from ..sql.compiler import TypeCompiler # noqa from ..util import immutabledict from ..util.concurrency import await_only -from ..util.typing import _TypeToInstance from ..util.typing import Literal from ..util.typing import NotRequired from ..util.typing import Protocol @@ -58,6 +57,8 @@ if TYPE_CHECKING: from ..sql.elements import ClauseElement from ..sql.schema import Column from ..sql.schema import ColumnDefault + from ..sql.schema import Sequence as Sequence_SchemaItem + from ..sql.sqltypes import Integer from ..sql.type_api import TypeEngine ConnectArgsType = Tuple[Tuple[str], MutableMapping[str, Any]] @@ -156,6 +157,8 @@ class DBAPICursor(Protocol): arraysize: int + lastrowid: int + def close(self) -> None: ... @@ -196,6 +199,7 @@ class DBAPICursor(Protocol): _CoreSingleExecuteParams = Mapping[str, Any] +_MutableCoreSingleExecuteParams = MutableMapping[str, Any] _CoreMultiExecuteParams = Sequence[_CoreSingleExecuteParams] _CoreAnyExecuteParams = Union[ _CoreMultiExecuteParams, _CoreSingleExecuteParams @@ -605,7 +609,7 @@ class Dialect(EventTarget): ddl_compiler: Type[DDLCompiler] """a :class:`.Compiled` class used to compile DDL statements""" - type_compiler: _TypeToInstance[TypeCompiler] + type_compiler: Union[Type[TypeCompiler], TypeCompiler] """a :class:`.Compiled` class used to compile SQL type objects""" preparer: Type[IdentifierPreparer] @@ -633,7 +637,7 @@ class Dialect(EventTarget): """ - default_isolation_level: _IsolationLevel + default_isolation_level: Optional[_IsolationLevel] """the isolation that is implicitly present on new connections""" execution_ctx_cls: Type["ExecutionContext"] @@ -653,6 +657,13 @@ class Dialect(EventTarget): max_identifier_length: int """The maximum length of identifier names.""" + supports_server_side_cursors: bool + """indicates if the dialect supports server side cursors""" + + server_side_cursors: bool + """deprecated; indicates if the dialect should attempt to use server + side cursors by default""" + supports_sane_rowcount: bool """Indicate whether the dialect properly implements rowcount for ``UPDATE`` and ``DELETE`` statements. @@ -2302,6 +2313,11 @@ class ExecutionContext: def _setup_result_proxy(self) -> Result: raise NotImplementedError() + def fire_sequence(self, seq: Sequence_SchemaItem, type_: Integer) -> int: + """given a :class:`.Sequence`, invoke it and return the next int + value""" + raise NotImplementedError() + def create_cursor(self) -> DBAPICursor: """Return a new cursor generated from this ExecutionContext's connection. diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index 0951d5770..87d3cac1c 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -1880,6 +1880,7 @@ class MergedResult(IteratorResult): """ closed = False + rowcount: Optional[int] def __init__( self, cursor_metadata: ResultMetaData, results: Sequence[Result] diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 09e38a5ab..423c3d446 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -34,6 +34,7 @@ import re from time import perf_counter import typing from typing import Any +from typing import Callable from typing import Dict from typing import List from typing import Mapping @@ -629,11 +630,11 @@ class SQLCompiler(Compiled): """list of columns that can be post-fetched after INSERT or UPDATE to receive server-updated values""" - insert_prefetch: Optional[List[Column[Any]]] + insert_prefetch: Sequence[Column[Any]] = () """list of columns for which default values should be evaluated before an INSERT takes place""" - update_prefetch: Optional[List[Column[Any]]] + update_prefetch: Sequence[Column[Any]] = () """list of columns for which onupdate default values should be evaluated before an UPDATE takes place""" @@ -739,8 +740,6 @@ class SQLCompiler(Compiled): """if True, there are bindparam() objects that have the isoutparam flag set.""" - insert_prefetch = update_prefetch = () - postfetch_lastrowid = False """if True, and this in insert, use cursor.lastrowid to populate result.inserted_primary_key. """ @@ -1340,7 +1339,7 @@ class SQLCompiler(Compiled): ) @util.memoized_property - def _within_exec_param_key_getter(self): + def _within_exec_param_key_getter(self) -> Callable[[Any], str]: getter = self._key_getters_for_crud_column[2] if self.escaped_bind_names: diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 4c38c4efa..168da17cc 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -58,12 +58,13 @@ from ..util.langhelpers import TypingOnly if typing.TYPE_CHECKING: from decimal import Decimal + from .compiler import Compiled + from .compiler import SQLCompiler from .operators import OperatorType from .selectable import FromClause from .selectable import Select from .sqltypes import Boolean # noqa from .type_api import TypeEngine - from ..engine import Compiled from ..engine import Connection from ..engine import Dialect from ..engine import Engine @@ -573,6 +574,25 @@ class ClauseElement( ) +class DQLDMLClauseElement(ClauseElement): + """represents a :class:`.ClauseElement` that compiles to a DQL or DML + expression, not DDL. + + .. versionadded:: 2.0 + + """ + + if typing.TYPE_CHECKING: + + def compile( # noqa: A001 + self, + bind: Optional[Union[Engine, Connection]] = None, + dialect: Optional[Dialect] = None, + **kw: Any, + ) -> SQLCompiler: + ... + + class CompilerColumnElement( roles.DMLColumnRole, roles.DDLConstraintColumnRole, @@ -955,7 +975,7 @@ class ColumnElement( roles.DDLExpressionRole, SQLCoreOperations[_T], operators.ColumnOperators[SQLCoreOperations], - ClauseElement, + DQLDMLClauseElement, ): """Represent a column-oriented SQL expression suitable for usage in the "columns" clause, WHERE clause etc. of a statement. @@ -1820,7 +1840,7 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]): ) -class TypeClause(ClauseElement): +class TypeClause(DQLDMLClauseElement): """Handle a type keyword in a SQL statement. Used by the ``Case`` statement. @@ -1849,7 +1869,7 @@ class TextClause( roles.BinaryElementRole, roles.InElementRole, Executable, - ClauseElement, + DQLDMLClauseElement, ): """Represent a literal SQL text fragment. @@ -2285,7 +2305,7 @@ class ClauseList( roles.OrderByRole, roles.ColumnsClauseRole, roles.DMLColumnRole, - ClauseElement, + DQLDMLClauseElement, ): """Describe a list of clauses, separated by an operator. @@ -3205,7 +3225,7 @@ class IndexExpression(BinaryExpression): inherit_cache = True -class GroupedElement(ClauseElement): +class GroupedElement(DQLDMLClauseElement): """Represent any parenthesized expression""" __visit_name__ = "grouping" diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index fdae4d7b0..c270e1564 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -1131,6 +1131,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): __visit_name__ = "column" inherit_cache = True + key: str @overload def __init__( diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index a5cbffb5e..e5c2bef68 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -62,6 +62,7 @@ from .elements import ClauseElement from .elements import ClauseList from .elements import ColumnClause from .elements import ColumnElement +from .elements import DQLDMLClauseElement from .elements import GroupedElement from .elements import Grouping from .elements import literal_column @@ -85,7 +86,7 @@ class _OffsetLimitParam(BindParameter): return self.effective_value -class ReturnsRows(roles.ReturnsRowsRole, ClauseElement): +class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement): """The base-most class for Core constructs that have some concept of columns that can represent rows. diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index 45e31aaf7..b0df99c41 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -12,117 +12,63 @@ from __future__ import annotations -from .sql.sqltypes import _Binary -from .sql.sqltypes import ARRAY -from .sql.sqltypes import BIGINT -from .sql.sqltypes import BigInteger -from .sql.sqltypes import BINARY -from .sql.sqltypes import BLOB -from .sql.sqltypes import BOOLEAN -from .sql.sqltypes import Boolean -from .sql.sqltypes import CHAR -from .sql.sqltypes import CLOB -from .sql.sqltypes import Concatenable -from .sql.sqltypes import DATE -from .sql.sqltypes import Date -from .sql.sqltypes import DATETIME -from .sql.sqltypes import DateTime -from .sql.sqltypes import DECIMAL -from .sql.sqltypes import DOUBLE -from .sql.sqltypes import Double -from .sql.sqltypes import DOUBLE_PRECISION -from .sql.sqltypes import Enum -from .sql.sqltypes import FLOAT -from .sql.sqltypes import Float -from .sql.sqltypes import Indexable -from .sql.sqltypes import INT -from .sql.sqltypes import INTEGER -from .sql.sqltypes import Integer -from .sql.sqltypes import Interval -from .sql.sqltypes import JSON -from .sql.sqltypes import LargeBinary -from .sql.sqltypes import MatchType -from .sql.sqltypes import NCHAR -from .sql.sqltypes import NULLTYPE -from .sql.sqltypes import NullType -from .sql.sqltypes import NUMERIC -from .sql.sqltypes import Numeric -from .sql.sqltypes import NVARCHAR -from .sql.sqltypes import PickleType -from .sql.sqltypes import REAL -from .sql.sqltypes import SchemaType -from .sql.sqltypes import SMALLINT -from .sql.sqltypes import SmallInteger -from .sql.sqltypes import String -from .sql.sqltypes import STRINGTYPE -from .sql.sqltypes import TEXT -from .sql.sqltypes import Text -from .sql.sqltypes import TIME -from .sql.sqltypes import Time -from .sql.sqltypes import TIMESTAMP -from .sql.sqltypes import TupleType -from .sql.sqltypes import Unicode -from .sql.sqltypes import UnicodeText -from .sql.sqltypes import VARBINARY -from .sql.sqltypes import VARCHAR -from .sql.type_api import adapt_type -from .sql.type_api import ExternalType -from .sql.type_api import to_instance -from .sql.type_api import TypeDecorator -from .sql.type_api import TypeEngine -from .sql.type_api import UserDefinedType -from .sql.type_api import Variant - -__all__ = [ - "TypeEngine", - "TypeDecorator", - "UserDefinedType", - "ExternalType", - "INT", - "CHAR", - "VARCHAR", - "NCHAR", - "NVARCHAR", - "TEXT", - "Text", - "FLOAT", - "NUMERIC", - "REAL", - "DECIMAL", - "TIMESTAMP", - "DATETIME", - "CLOB", - "BLOB", - "BINARY", - "VARBINARY", - "BOOLEAN", - "BIGINT", - "SMALLINT", - "INTEGER", - "DATE", - "TIME", - "TupleType", - "String", - "Integer", - "SmallInteger", - "BigInteger", - "Numeric", - "Float", - "Double", - "DOUBLE", - "DOUBLE_PRECISION", - "DateTime", - "Date", - "Time", - "LargeBinary", - "Boolean", - "Unicode", - "Concatenable", - "UnicodeText", - "PickleType", - "Interval", - "Enum", - "Indexable", - "ARRAY", - "JSON", -] +from .sql.sqltypes import _Binary as _Binary +from .sql.sqltypes import ARRAY as ARRAY +from .sql.sqltypes import BIGINT as BIGINT +from .sql.sqltypes import BigInteger as BigInteger +from .sql.sqltypes import BINARY as BINARY +from .sql.sqltypes import BLOB as BLOB +from .sql.sqltypes import BOOLEAN as BOOLEAN +from .sql.sqltypes import Boolean as Boolean +from .sql.sqltypes import CHAR as CHAR +from .sql.sqltypes import CLOB as CLOB +from .sql.sqltypes import Concatenable as Concatenable +from .sql.sqltypes import DATE as DATE +from .sql.sqltypes import Date as Date +from .sql.sqltypes import DATETIME as DATETIME +from .sql.sqltypes import DateTime as DateTime +from .sql.sqltypes import DECIMAL as DECIMAL +from .sql.sqltypes import DOUBLE as DOUBLE +from .sql.sqltypes import Double as Double +from .sql.sqltypes import DOUBLE_PRECISION as DOUBLE_PRECISION +from .sql.sqltypes import Enum as Enum +from .sql.sqltypes import FLOAT as FLOAT +from .sql.sqltypes import Float as Float +from .sql.sqltypes import Indexable as Indexable +from .sql.sqltypes import INT as INT +from .sql.sqltypes import INTEGER as INTEGER +from .sql.sqltypes import Integer as Integer +from .sql.sqltypes import Interval as Interval +from .sql.sqltypes import JSON as JSON +from .sql.sqltypes import LargeBinary as LargeBinary +from .sql.sqltypes import MatchType as MatchType +from .sql.sqltypes import NCHAR as NCHAR +from .sql.sqltypes import NULLTYPE as NULLTYPE +from .sql.sqltypes import NullType as NullType +from .sql.sqltypes import NUMERIC as NUMERIC +from .sql.sqltypes import Numeric as Numeric +from .sql.sqltypes import NVARCHAR as NVARCHAR +from .sql.sqltypes import PickleType as PickleType +from .sql.sqltypes import REAL as REAL +from .sql.sqltypes import SchemaType as SchemaType +from .sql.sqltypes import SMALLINT as SMALLINT +from .sql.sqltypes import SmallInteger as SmallInteger +from .sql.sqltypes import String as String +from .sql.sqltypes import STRINGTYPE as STRINGTYPE +from .sql.sqltypes import TEXT as TEXT +from .sql.sqltypes import Text as Text +from .sql.sqltypes import TIME as TIME +from .sql.sqltypes import Time as Time +from .sql.sqltypes import TIMESTAMP as TIMESTAMP +from .sql.sqltypes import TupleType as TupleType +from .sql.sqltypes import Unicode as Unicode +from .sql.sqltypes import UnicodeText as UnicodeText +from .sql.sqltypes import VARBINARY as VARBINARY +from .sql.sqltypes import VARCHAR as VARCHAR +from .sql.type_api import adapt_type as adapt_type +from .sql.type_api import ExternalType as ExternalType +from .sql.type_api import to_instance as to_instance +from .sql.type_api import TypeDecorator as TypeDecorator +from .sql.type_api import TypeEngine as TypeEngine +from .sql.type_api import UserDefinedType as UserDefinedType +from .sql.type_api import Variant as Variant diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index e0b53b445..06a009c5b 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -34,6 +34,7 @@ import weakref from ._has_cy import HAS_CYEXTENSION from .typing import Literal +from .typing import Protocol if typing.TYPE_CHECKING or not HAS_CYEXTENSION: from ._py_collections import immutabledict as immutabledict @@ -62,7 +63,7 @@ else: _T = TypeVar("_T", bound=Any) _KT = TypeVar("_KT", bound=Any) _VT = TypeVar("_VT", bound=Any) - +_T_co = TypeVar("_T_co", covariant=True) EMPTY_SET: FrozenSet[Any] = frozenset() @@ -597,7 +598,17 @@ class LRUCache(typing.MutableMapping[_KT, _VT]): self._mutex.release() -class ScopedRegistry: +class _CreateFuncType(Protocol[_T_co]): + def __call__(self) -> _T_co: + ... + + +class _ScopeFuncType(Protocol): + def __call__(self) -> Any: + ... + + +class ScopedRegistry(Generic[_T]): """A Registry that can store one or multiple instances of a single class on the basis of a "scope" function. @@ -614,6 +625,10 @@ class ScopedRegistry: __slots__ = "createfunc", "scopefunc", "registry" + createfunc: _CreateFuncType[_T] + scopefunc: _ScopeFuncType + registry: Any + def __init__(self, createfunc, scopefunc): """Construct a new :class:`.ScopedRegistry`. @@ -629,24 +644,24 @@ class ScopedRegistry: self.scopefunc = scopefunc self.registry = {} - def __call__(self): + def __call__(self) -> _T: key = self.scopefunc() try: - return self.registry[key] + return self.registry[key] # type: ignore[no-any-return] except KeyError: - return self.registry.setdefault(key, self.createfunc()) + return self.registry.setdefault(key, self.createfunc()) # type: ignore[no-any-return] # noqa: E501 - def has(self): + def has(self) -> bool: """Return True if an object is present in the current scope.""" return self.scopefunc() in self.registry - def set(self, obj): + def set(self, obj: _T) -> None: """Set the value for the current scope.""" self.registry[self.scopefunc()] = obj - def clear(self): + def clear(self) -> None: """Clear the current scope, if any.""" try: @@ -655,32 +670,32 @@ class ScopedRegistry: pass -class ThreadLocalRegistry(ScopedRegistry): +class ThreadLocalRegistry(ScopedRegistry[_T]): """A :class:`.ScopedRegistry` that uses a ``threading.local()`` variable for storage. """ - def __init__(self, createfunc): + def __init__(self, createfunc: Callable[[], _T]): self.createfunc = createfunc self.registry = threading.local() - def __call__(self): + def __call__(self) -> _T: try: - return self.registry.value + return self.registry.value # type: ignore[no-any-return] except AttributeError: val = self.registry.value = self.createfunc() - return val + return val # type: ignore[no-any-return] - def has(self): + def has(self) -> bool: return hasattr(self.registry, "value") - def set(self, obj): + def set(self, obj: _T) -> None: self.registry.value = obj - def clear(self): + def clear(self) -> None: try: - del self.registry.value # type: ignore + del self.registry.value except AttributeError: pass diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py index 771e974e9..d50352930 100644 --- a/lib/sqlalchemy/util/_py_collections.py +++ b/lib/sqlalchemy/util/_py_collections.py @@ -11,6 +11,7 @@ from itertools import filterfalse from typing import AbstractSet from typing import Any from typing import cast +from typing import Collection from typing import Dict from typing import Iterable from typing import Iterator @@ -67,7 +68,9 @@ class immutabledict(ImmutableDictBase[_KT, _VT]): dict.__init__(new, *args) return new - def __init__(self, *args: Union[Mapping[_KT, _VT], Tuple[_KT, _VT]]): + def __init__( + self, *args: Union[Mapping[_KT, _VT], Iterable[Tuple[_KT, _VT]]] + ): pass def __reduce__(self): @@ -369,6 +372,8 @@ class IdentitySet: def difference(self, iterable): result = self.__new__(self.__class__) + other: Collection[Any] + if isinstance(iterable, self.__class__): other = iterable._members else: @@ -394,6 +399,9 @@ class IdentitySet: def intersection(self, iterable): result = self.__new__(self.__class__) + + other: Collection[Any] + if isinstance(iterable, self.__class__): other = iterable._members else: @@ -466,7 +474,7 @@ class IdentitySet: def unique_list(seq, hashfunc=None): - seen = set() + seen: Set[Any] = set() seen_add = seen.add if not hashfunc: return [x for x in seq if x not in seen and not seen_add(x)] diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 5674e19af..8cb84f73f 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -679,7 +679,7 @@ def create_proxy_methods( def decorate(cls): def instrument(name, clslevel=False): - fn = cast(Callable[..., Any], getattr(target_cls, name)) + fn = cast(types.FunctionType, getattr(target_cls, name)) spec = compat.inspect_getfullargspec(fn) env = {"__name__": fn.__module__} @@ -709,7 +709,7 @@ def create_proxy_methods( ) proxy_fn = cast( - Callable[..., Any], _exec_code_in_env(code, env, fn.__name__) + types.FunctionType, _exec_code_in_env(code, env, fn.__name__) ) proxy_fn.__defaults__ = getattr(fn, "__func__", fn).__defaults__ proxy_fn.__doc__ = inject_docstring_text( @@ -721,9 +721,9 @@ def create_proxy_methods( ) if clslevel: - proxy_fn = classmethod(proxy_fn) - - return proxy_fn + return classmethod(proxy_fn) + else: + return proxy_fn def makeprop(name): attr = target_cls.__dict__.get(name, None) @@ -824,7 +824,7 @@ def generic_repr(obj, additional_kw=(), to_inspect=None, omit_kwarg=()): missing = object() pos_args = [] - kw_args = _collections.OrderedDict() + kw_args: _collections.OrderedDict[str, Any] = _collections.OrderedDict() vargs = None for i, insp in enumerate(to_inspect): try: @@ -855,7 +855,7 @@ def generic_repr(obj, additional_kw=(), to_inspect=None, omit_kwarg=()): ) ] ) - output = [] + output: List[str] = [] output.extend(repr(getattr(obj, arg, None)) for arg in pos_args) @@ -1007,7 +1007,7 @@ def monkeypatch_proxied_specials( if not hasattr(maybe_fn, "__call__"): continue maybe_fn = getattr(maybe_fn, "__func__", maybe_fn) - fn = cast(Callable[..., Any], maybe_fn) + fn = cast(types.FunctionType, maybe_fn) except AttributeError: continue @@ -1024,7 +1024,9 @@ def monkeypatch_proxied_specials( "return %(name)s.%(method)s%(d_args)s" % locals() ) - env = from_instance is not None and {name: from_instance} or {} + env: Dict[str, types.FunctionType] = ( + from_instance is not None and {name: from_instance} or {} + ) exec(py, env) try: env[method].__defaults__ = fn.__defaults__ @@ -1482,6 +1484,7 @@ def dictlike_iteritems(dictlike): def iterator(): for key in dictlike.iterkeys(): + assert getter is not None yield key, getter(key) return iterator() @@ -1989,7 +1992,7 @@ def quoted_token_parser(value): # 0 = outside of quotes # 1 = inside of quotes state = 0 - result = [[]] + result: List[List[str]] = [[]] idx = 0 lv = len(value) while idx < lv: diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 291061561..160eabd85 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -7,8 +7,6 @@ from typing import Callable # noqa from typing import cast from typing import Dict from typing import ForwardRef -from typing import Generic -from typing import overload from typing import Type from typing import TypeVar from typing import Union @@ -58,35 +56,6 @@ else: from typing import ParamSpec as ParamSpec # noqa F401 -class _TypeToInstance(Generic[_T]): - """describe a variable that moves between a class and an instance of - that class. - - """ - - @overload - def __get__(self, instance: None, owner: Any) -> Type[_T]: - ... - - @overload - def __get__(self, instance: object, owner: Any) -> _T: - ... - - def __get__(self, instance: object, owner: Any) -> Union[Type[_T], _T]: - ... - - @overload - def __set__(self, instance: None, value: Type[_T]) -> None: - ... - - @overload - def __set__(self, instance: object, value: _T) -> None: - ... - - def __set__(self, instance: object, value: Union[Type[_T], _T]) -> None: - ... - - def de_stringify_annotation( cls: Type[Any], annotation: Union[str, Type[Any]] ) -> Union[str, Type[Any]]: |
