diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-03-10 11:57:00 -0500 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-03-12 11:42:50 -0500 |
| commit | 4c28867f944637ef313f98d5f09da05255418c6d (patch) | |
| tree | f68776450fc91df8085446d517020603b879d0f8 /lib/sqlalchemy/engine/default.py | |
| parent | 03989d1dce80999bb9ea1a7d36df3285e5ce4c3b (diff) | |
| download | sqlalchemy-4c28867f944637ef313f98d5f09da05255418c6d.tar.gz | |
additional mypy strictness
enable type checking within untyped defs. This allowed
some more internals to be fixed up with assertions etc.
some internals that were unnecessary or not even used
at all were removed. BaseCursorResult was no longer
necessary since we only have one kind of CursorResult
now. The different ResultProxy subclasses that had
alternate "strategies" dont appear to be used at all
even in 1.4.x, as there's no code that accesses the
_cursor_strategy_cls attribute, which is also removed.
As these were mostly private constructs that weren't
even functioning correctly in any case,
it's fine to remove these over the 2.0 boundary.
Change-Id: Ifd536987d104b1cd8b546cefdbd5c1e5d1801082
Diffstat (limited to 'lib/sqlalchemy/engine/default.py')
| -rw-r--r-- | lib/sqlalchemy/engine/default.py | 109 |
1 files changed, 74 insertions, 35 deletions
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 0e0c76389..2579f573c 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -55,6 +55,9 @@ from ..sql.compiler import SQLCompiler from ..sql.elements import quoted_name if typing.TYPE_CHECKING: + from .base import Connection + from .base import Engine + from .characteristics import ConnectionCharacteristic from .interfaces import _AnyMultiExecuteParams from .interfaces import _CoreMultiExecuteParams from .interfaces import _CoreSingleExecuteParams @@ -62,6 +65,7 @@ if typing.TYPE_CHECKING: from .interfaces import _DBAPIMultiExecuteParams from .interfaces import _DBAPISingleExecuteParams from .interfaces import _ExecuteOptions + from .interfaces import _MutableCoreSingleExecuteParams from .result import _ProcessorType from .row import Row from .url import URL @@ -71,6 +75,7 @@ if typing.TYPE_CHECKING: from ..sql import Executable from ..sql.compiler import Compiled from ..sql.compiler import ResultColumnsEntry + from ..sql.compiler import TypeCompiler from ..sql.schema import Column from ..sql.type_api import TypeEngine @@ -92,7 +97,11 @@ class DefaultDialect(Dialect): statement_compiler = compiler.SQLCompiler ddl_compiler = compiler.DDLCompiler - type_compiler = compiler.GenericTypeCompiler # type: ignore + if typing.TYPE_CHECKING: + type_compiler: TypeCompiler + else: + type_compiler = compiler.GenericTypeCompiler + preparer = compiler.IdentifierPreparer supports_alter = True supports_comments = False @@ -202,7 +211,7 @@ class DefaultDialect(Dialect): server_version_info = None - default_schema_name = None + default_schema_name: Optional[str] = None # indicates symbol names are # UPPERCASEd if they are case insensitive @@ -290,7 +299,12 @@ class DefaultDialect(Dialect): self.positional = self.paramstyle in ("qmark", "format", "numeric") self.identifier_preparer = self.preparer(self) self._on_connect_isolation_level = isolation_level - self.type_compiler = self.type_compiler(self) + + tt_callable = cast( + Type[compiler.GenericTypeCompiler], + self.type_compiler, + ) + self.type_compiler = tt_callable(self) if supports_native_boolean is not None: self.supports_native_boolean = supports_native_boolean @@ -490,12 +504,14 @@ class DefaultDialect(Dialect): opts.update(url.query) return [[], opts] - def set_engine_execution_options(self, engine, opts): + def set_engine_execution_options( + self, engine: Engine, opts: Mapping[str, str] + ) -> None: supported_names = set(self.connection_characteristics).intersection( opts ) if supported_names: - characteristics = util.immutabledict( + characteristics: Mapping[str, str] = util.immutabledict( (name, opts[name]) for name in supported_names ) @@ -505,12 +521,14 @@ class DefaultDialect(Dialect): connection, characteristics ) - def set_connection_execution_options(self, connection, opts): + def set_connection_execution_options( + self, connection: Connection, opts: Mapping[str, str] + ) -> None: supported_names = set(self.connection_characteristics).intersection( opts ) if supported_names: - characteristics = util.immutabledict( + characteristics: Mapping[str, str] = util.immutabledict( (name, opts[name]) for name in supported_names ) self._set_connection_characteristics(connection, characteristics) @@ -800,7 +818,7 @@ class DefaultExecutionContext(ExecutionContext): dialect: Dialect unicode_statement: str cursor: DBAPICursor - compiled_parameters: _CoreMultiExecuteParams + compiled_parameters: List[_MutableCoreSingleExecuteParams] parameters: _DBAPIMultiExecuteParams extracted_parameters: _CoreSingleExecuteParams @@ -1157,7 +1175,11 @@ class DefaultExecutionContext(ExecutionContext): parameters = {} conn._cursor_execute(self.cursor, stmt, parameters, context=self) - r = self.cursor.fetchone()[0] + row = self.cursor.fetchone() + if row is not None: + r = row[0] + else: + r = None if type_ is not None: # apply type post processors to the result proc = type_._cached_result_processor( @@ -1299,10 +1321,11 @@ class DefaultExecutionContext(ExecutionContext): result = _cursor.CursorResult(self, strategy, cursor_description) + compiled = self.compiled if ( - self.compiled + compiled and not self.isddl - and self.compiled.has_out_parameters + and cast(SQLCompiler, compiled).has_out_parameters ): self._setup_out_parameters(result) @@ -1311,10 +1334,11 @@ class DefaultExecutionContext(ExecutionContext): return result def _setup_out_parameters(self, result): + compiled = cast(SQLCompiler, self.compiled) out_bindparams = [ (param, name) - for param, name in self.compiled.bind_names.items() + for param, name in compiled.bind_names.items() if param.isoutparam ] out_parameters = {} @@ -1339,9 +1363,10 @@ class DefaultExecutionContext(ExecutionContext): result.out_parameters = out_parameters def _setup_dml_or_text_result(self): + compiled = cast(SQLCompiler, self.compiled) if self.isinsert: - if self.compiled.postfetch_lastrowid: + if compiled.postfetch_lastrowid: self.inserted_primary_key_rows = ( self._setup_ins_pk_from_lastrowid() ) @@ -1397,7 +1422,8 @@ class DefaultExecutionContext(ExecutionContext): result.rowcount row = result.fetchone() - self.returned_default_rows = [row] + if row is not None: + self.returned_default_rows = [row] result._soft_close() @@ -1420,13 +1446,17 @@ class DefaultExecutionContext(ExecutionContext): return self._setup_ins_pk_from_empty() def _setup_ins_pk_from_lastrowid(self): - getter = self.compiled._inserted_primary_key_from_lastrowid_getter + getter = cast( + SQLCompiler, self.compiled + )._inserted_primary_key_from_lastrowid_getter lastrowid = self.get_lastrowid() return [getter(lastrowid, self.compiled_parameters[0])] def _setup_ins_pk_from_empty(self): - getter = self.compiled._inserted_primary_key_from_lastrowid_getter + getter = cast( + SQLCompiler, self.compiled + )._inserted_primary_key_from_lastrowid_getter return [getter(None, param) for param in self.compiled_parameters] def _setup_ins_pk_from_implicit_returning(self, result, rows): @@ -1434,7 +1464,9 @@ class DefaultExecutionContext(ExecutionContext): if not rows: return [] - getter = self.compiled._inserted_primary_key_from_returning_getter + getter = cast( + SQLCompiler, self.compiled + )._inserted_primary_key_from_returning_getter compiled_params = self.compiled_parameters return [ @@ -1443,7 +1475,7 @@ class DefaultExecutionContext(ExecutionContext): def lastrow_has_defaults(self): return (self.isinsert or self.isupdate) and bool( - self.compiled.postfetch + cast(SQLCompiler, self.compiled).postfetch ) def _set_input_sizes(self): @@ -1464,7 +1496,7 @@ class DefaultExecutionContext(ExecutionContext): if self.isddl or self.is_text: return - compiled = self.compiled + compiled = cast(SQLCompiler, self.compiled) inputsizes = compiled._get_set_input_sizes_lookup() @@ -1487,7 +1519,8 @@ class DefaultExecutionContext(ExecutionContext): if dialect.positional: items = [ - (key, compiled.binds[key]) for key in compiled.positiontup + (key, compiled.binds[key]) + for key in compiled.positiontup or () ] else: items = [ @@ -1495,7 +1528,7 @@ class DefaultExecutionContext(ExecutionContext): for bindparam, key in compiled.bind_names.items() ] - generic_inputsizes = [] + generic_inputsizes: List[Tuple[str, Any, TypeEngine[Any]]] = [] for key, bindparam in items: if bindparam in compiled.literal_execute_params: continue @@ -1578,20 +1611,19 @@ class DefaultExecutionContext(ExecutionContext): compiled_params = compiled.construct_params() processors = compiled._bind_processors if compiled.positional: - positiontup = compiled.positiontup parameters = self.dialect.execute_sequence_format( [ - processors[key](compiled_params[key]) + processors[key](compiled_params[key]) # type: ignore if key in processors else compiled_params[key] - for key in positiontup + for key in compiled.positiontup or () ] ) else: parameters = dict( ( key, - processors[key](compiled_params[key]) + processors[key](compiled_params[key]) # type: ignore if key in processors else compiled_params[key], ) @@ -1667,15 +1699,18 @@ class DefaultExecutionContext(ExecutionContext): "get_current_parameters() can only be invoked in the " "context of a Python side column default function" ) - - compile_state = self.compiled.compile_state + else: + assert column is not None + assert parameters is not None + compile_state = cast(SQLCompiler, self.compiled).compile_state + assert compile_state is not None if ( isolate_multiinsert_groups and self.isinsert and compile_state._has_multi_parameters ): if column._is_multiparam_column: - index = column.index + 1 + index = column.index + 1 # type: ignore d = {column.original.key: parameters[column.key]} else: d = {column.key: parameters[column.key]} @@ -1701,12 +1736,14 @@ class DefaultExecutionContext(ExecutionContext): return self._exec_default(column, column.onupdate, column.type) def _process_executemany_defaults(self): - key_getter = self.compiled._within_exec_param_key_getter + compiled = cast(SQLCompiler, self.compiled) - scalar_defaults = {} + key_getter = compiled._within_exec_param_key_getter - insert_prefetch = self.compiled.insert_prefetch - update_prefetch = self.compiled.update_prefetch + scalar_defaults: Dict[Column[Any], Any] = {} + + insert_prefetch = compiled.insert_prefetch + update_prefetch = compiled.update_prefetch # pre-determine scalar Python-side defaults # to avoid many calls of get_insert_default()/ @@ -1739,12 +1776,14 @@ class DefaultExecutionContext(ExecutionContext): del self.current_parameters def _process_executesingle_defaults(self): - key_getter = self.compiled._within_exec_param_key_getter + compiled = cast(SQLCompiler, self.compiled) + + key_getter = compiled._within_exec_param_key_getter self.current_parameters = ( compiled_parameters ) = self.compiled_parameters[0] - for c in self.compiled.insert_prefetch: + for c in compiled.insert_prefetch: if c.default and not c.default.is_sequence and c.default.is_scalar: val = c.default.arg else: @@ -1753,7 +1792,7 @@ class DefaultExecutionContext(ExecutionContext): if val is not None: compiled_parameters[key_getter(c)] = val - for c in self.compiled.update_prefetch: + for c in compiled.update_prefetch: val = self.get_update_default(c) if val is not None: |
