summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/engine/default.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-03-10 11:57:00 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2022-03-12 11:42:50 -0500
commit4c28867f944637ef313f98d5f09da05255418c6d (patch)
treef68776450fc91df8085446d517020603b879d0f8 /lib/sqlalchemy/engine/default.py
parent03989d1dce80999bb9ea1a7d36df3285e5ce4c3b (diff)
downloadsqlalchemy-4c28867f944637ef313f98d5f09da05255418c6d.tar.gz
additional mypy strictness
enable type checking within untyped defs. This allowed some more internals to be fixed up with assertions etc. some internals that were unnecessary or not even used at all were removed. BaseCursorResult was no longer necessary since we only have one kind of CursorResult now. The different ResultProxy subclasses that had alternate "strategies" dont appear to be used at all even in 1.4.x, as there's no code that accesses the _cursor_strategy_cls attribute, which is also removed. As these were mostly private constructs that weren't even functioning correctly in any case, it's fine to remove these over the 2.0 boundary. Change-Id: Ifd536987d104b1cd8b546cefdbd5c1e5d1801082
Diffstat (limited to 'lib/sqlalchemy/engine/default.py')
-rw-r--r--lib/sqlalchemy/engine/default.py109
1 files changed, 74 insertions, 35 deletions
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 0e0c76389..2579f573c 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -55,6 +55,9 @@ from ..sql.compiler import SQLCompiler
from ..sql.elements import quoted_name
if typing.TYPE_CHECKING:
+ from .base import Connection
+ from .base import Engine
+ from .characteristics import ConnectionCharacteristic
from .interfaces import _AnyMultiExecuteParams
from .interfaces import _CoreMultiExecuteParams
from .interfaces import _CoreSingleExecuteParams
@@ -62,6 +65,7 @@ if typing.TYPE_CHECKING:
from .interfaces import _DBAPIMultiExecuteParams
from .interfaces import _DBAPISingleExecuteParams
from .interfaces import _ExecuteOptions
+ from .interfaces import _MutableCoreSingleExecuteParams
from .result import _ProcessorType
from .row import Row
from .url import URL
@@ -71,6 +75,7 @@ if typing.TYPE_CHECKING:
from ..sql import Executable
from ..sql.compiler import Compiled
from ..sql.compiler import ResultColumnsEntry
+ from ..sql.compiler import TypeCompiler
from ..sql.schema import Column
from ..sql.type_api import TypeEngine
@@ -92,7 +97,11 @@ class DefaultDialect(Dialect):
statement_compiler = compiler.SQLCompiler
ddl_compiler = compiler.DDLCompiler
- type_compiler = compiler.GenericTypeCompiler # type: ignore
+ if typing.TYPE_CHECKING:
+ type_compiler: TypeCompiler
+ else:
+ type_compiler = compiler.GenericTypeCompiler
+
preparer = compiler.IdentifierPreparer
supports_alter = True
supports_comments = False
@@ -202,7 +211,7 @@ class DefaultDialect(Dialect):
server_version_info = None
- default_schema_name = None
+ default_schema_name: Optional[str] = None
# indicates symbol names are
# UPPERCASEd if they are case insensitive
@@ -290,7 +299,12 @@ class DefaultDialect(Dialect):
self.positional = self.paramstyle in ("qmark", "format", "numeric")
self.identifier_preparer = self.preparer(self)
self._on_connect_isolation_level = isolation_level
- self.type_compiler = self.type_compiler(self)
+
+ tt_callable = cast(
+ Type[compiler.GenericTypeCompiler],
+ self.type_compiler,
+ )
+ self.type_compiler = tt_callable(self)
if supports_native_boolean is not None:
self.supports_native_boolean = supports_native_boolean
@@ -490,12 +504,14 @@ class DefaultDialect(Dialect):
opts.update(url.query)
return [[], opts]
- def set_engine_execution_options(self, engine, opts):
+ def set_engine_execution_options(
+ self, engine: Engine, opts: Mapping[str, str]
+ ) -> None:
supported_names = set(self.connection_characteristics).intersection(
opts
)
if supported_names:
- characteristics = util.immutabledict(
+ characteristics: Mapping[str, str] = util.immutabledict(
(name, opts[name]) for name in supported_names
)
@@ -505,12 +521,14 @@ class DefaultDialect(Dialect):
connection, characteristics
)
- def set_connection_execution_options(self, connection, opts):
+ def set_connection_execution_options(
+ self, connection: Connection, opts: Mapping[str, str]
+ ) -> None:
supported_names = set(self.connection_characteristics).intersection(
opts
)
if supported_names:
- characteristics = util.immutabledict(
+ characteristics: Mapping[str, str] = util.immutabledict(
(name, opts[name]) for name in supported_names
)
self._set_connection_characteristics(connection, characteristics)
@@ -800,7 +818,7 @@ class DefaultExecutionContext(ExecutionContext):
dialect: Dialect
unicode_statement: str
cursor: DBAPICursor
- compiled_parameters: _CoreMultiExecuteParams
+ compiled_parameters: List[_MutableCoreSingleExecuteParams]
parameters: _DBAPIMultiExecuteParams
extracted_parameters: _CoreSingleExecuteParams
@@ -1157,7 +1175,11 @@ class DefaultExecutionContext(ExecutionContext):
parameters = {}
conn._cursor_execute(self.cursor, stmt, parameters, context=self)
- r = self.cursor.fetchone()[0]
+ row = self.cursor.fetchone()
+ if row is not None:
+ r = row[0]
+ else:
+ r = None
if type_ is not None:
# apply type post processors to the result
proc = type_._cached_result_processor(
@@ -1299,10 +1321,11 @@ class DefaultExecutionContext(ExecutionContext):
result = _cursor.CursorResult(self, strategy, cursor_description)
+ compiled = self.compiled
if (
- self.compiled
+ compiled
and not self.isddl
- and self.compiled.has_out_parameters
+ and cast(SQLCompiler, compiled).has_out_parameters
):
self._setup_out_parameters(result)
@@ -1311,10 +1334,11 @@ class DefaultExecutionContext(ExecutionContext):
return result
def _setup_out_parameters(self, result):
+ compiled = cast(SQLCompiler, self.compiled)
out_bindparams = [
(param, name)
- for param, name in self.compiled.bind_names.items()
+ for param, name in compiled.bind_names.items()
if param.isoutparam
]
out_parameters = {}
@@ -1339,9 +1363,10 @@ class DefaultExecutionContext(ExecutionContext):
result.out_parameters = out_parameters
def _setup_dml_or_text_result(self):
+ compiled = cast(SQLCompiler, self.compiled)
if self.isinsert:
- if self.compiled.postfetch_lastrowid:
+ if compiled.postfetch_lastrowid:
self.inserted_primary_key_rows = (
self._setup_ins_pk_from_lastrowid()
)
@@ -1397,7 +1422,8 @@ class DefaultExecutionContext(ExecutionContext):
result.rowcount
row = result.fetchone()
- self.returned_default_rows = [row]
+ if row is not None:
+ self.returned_default_rows = [row]
result._soft_close()
@@ -1420,13 +1446,17 @@ class DefaultExecutionContext(ExecutionContext):
return self._setup_ins_pk_from_empty()
def _setup_ins_pk_from_lastrowid(self):
- getter = self.compiled._inserted_primary_key_from_lastrowid_getter
+ getter = cast(
+ SQLCompiler, self.compiled
+ )._inserted_primary_key_from_lastrowid_getter
lastrowid = self.get_lastrowid()
return [getter(lastrowid, self.compiled_parameters[0])]
def _setup_ins_pk_from_empty(self):
- getter = self.compiled._inserted_primary_key_from_lastrowid_getter
+ getter = cast(
+ SQLCompiler, self.compiled
+ )._inserted_primary_key_from_lastrowid_getter
return [getter(None, param) for param in self.compiled_parameters]
def _setup_ins_pk_from_implicit_returning(self, result, rows):
@@ -1434,7 +1464,9 @@ class DefaultExecutionContext(ExecutionContext):
if not rows:
return []
- getter = self.compiled._inserted_primary_key_from_returning_getter
+ getter = cast(
+ SQLCompiler, self.compiled
+ )._inserted_primary_key_from_returning_getter
compiled_params = self.compiled_parameters
return [
@@ -1443,7 +1475,7 @@ class DefaultExecutionContext(ExecutionContext):
def lastrow_has_defaults(self):
return (self.isinsert or self.isupdate) and bool(
- self.compiled.postfetch
+ cast(SQLCompiler, self.compiled).postfetch
)
def _set_input_sizes(self):
@@ -1464,7 +1496,7 @@ class DefaultExecutionContext(ExecutionContext):
if self.isddl or self.is_text:
return
- compiled = self.compiled
+ compiled = cast(SQLCompiler, self.compiled)
inputsizes = compiled._get_set_input_sizes_lookup()
@@ -1487,7 +1519,8 @@ class DefaultExecutionContext(ExecutionContext):
if dialect.positional:
items = [
- (key, compiled.binds[key]) for key in compiled.positiontup
+ (key, compiled.binds[key])
+ for key in compiled.positiontup or ()
]
else:
items = [
@@ -1495,7 +1528,7 @@ class DefaultExecutionContext(ExecutionContext):
for bindparam, key in compiled.bind_names.items()
]
- generic_inputsizes = []
+ generic_inputsizes: List[Tuple[str, Any, TypeEngine[Any]]] = []
for key, bindparam in items:
if bindparam in compiled.literal_execute_params:
continue
@@ -1578,20 +1611,19 @@ class DefaultExecutionContext(ExecutionContext):
compiled_params = compiled.construct_params()
processors = compiled._bind_processors
if compiled.positional:
- positiontup = compiled.positiontup
parameters = self.dialect.execute_sequence_format(
[
- processors[key](compiled_params[key])
+ processors[key](compiled_params[key]) # type: ignore
if key in processors
else compiled_params[key]
- for key in positiontup
+ for key in compiled.positiontup or ()
]
)
else:
parameters = dict(
(
key,
- processors[key](compiled_params[key])
+ processors[key](compiled_params[key]) # type: ignore
if key in processors
else compiled_params[key],
)
@@ -1667,15 +1699,18 @@ class DefaultExecutionContext(ExecutionContext):
"get_current_parameters() can only be invoked in the "
"context of a Python side column default function"
)
-
- compile_state = self.compiled.compile_state
+ else:
+ assert column is not None
+ assert parameters is not None
+ compile_state = cast(SQLCompiler, self.compiled).compile_state
+ assert compile_state is not None
if (
isolate_multiinsert_groups
and self.isinsert
and compile_state._has_multi_parameters
):
if column._is_multiparam_column:
- index = column.index + 1
+ index = column.index + 1 # type: ignore
d = {column.original.key: parameters[column.key]}
else:
d = {column.key: parameters[column.key]}
@@ -1701,12 +1736,14 @@ class DefaultExecutionContext(ExecutionContext):
return self._exec_default(column, column.onupdate, column.type)
def _process_executemany_defaults(self):
- key_getter = self.compiled._within_exec_param_key_getter
+ compiled = cast(SQLCompiler, self.compiled)
- scalar_defaults = {}
+ key_getter = compiled._within_exec_param_key_getter
- insert_prefetch = self.compiled.insert_prefetch
- update_prefetch = self.compiled.update_prefetch
+ scalar_defaults: Dict[Column[Any], Any] = {}
+
+ insert_prefetch = compiled.insert_prefetch
+ update_prefetch = compiled.update_prefetch
# pre-determine scalar Python-side defaults
# to avoid many calls of get_insert_default()/
@@ -1739,12 +1776,14 @@ class DefaultExecutionContext(ExecutionContext):
del self.current_parameters
def _process_executesingle_defaults(self):
- key_getter = self.compiled._within_exec_param_key_getter
+ compiled = cast(SQLCompiler, self.compiled)
+
+ key_getter = compiled._within_exec_param_key_getter
self.current_parameters = (
compiled_parameters
) = self.compiled_parameters[0]
- for c in self.compiled.insert_prefetch:
+ for c in compiled.insert_prefetch:
if c.default and not c.default.is_sequence and c.default.is_scalar:
val = c.default.arg
else:
@@ -1753,7 +1792,7 @@ class DefaultExecutionContext(ExecutionContext):
if val is not None:
compiled_parameters[key_getter(c)] = val
- for c in self.compiled.update_prefetch:
+ for c in compiled.update_prefetch:
val = self.get_update_default(c)
if val is not None: