diff options
74 files changed, 2963 insertions, 1530 deletions
diff --git a/doc/build/changelog/unreleased_20/eng_ex_opt.rst b/doc/build/changelog/unreleased_20/eng_ex_opt.rst new file mode 100644 index 000000000..00947f3de --- /dev/null +++ b/doc/build/changelog/unreleased_20/eng_ex_opt.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: engine, feature + + The :meth:`.ConnectionEvents.set_connection_execution_options` + and :meth:`.ConnectionEvents.set_engine_execution_options` + event hooks now allow the given options dictionary to be modified + in-place, where the new contents will be received as the ultimate + execution options to be acted upon. Previously, in-place modifications to + the dictionary were not supported. diff --git a/doc/build/tutorial/data_insert.rst b/doc/build/tutorial/data_insert.rst index 74b0aff56..a8b1a49a2 100644 --- a/doc/build/tutorial/data_insert.rst +++ b/doc/build/tutorial/data_insert.rst @@ -127,7 +127,7 @@ illustrate this: ... conn.commit() {opensql}BEGIN (implicit) INSERT INTO user_account (name, fullname) VALUES (?, ?) - [...] (('sandy', 'Sandy Cheeks'), ('patrick', 'Patrick Star')) + [...] [('sandy', 'Sandy Cheeks'), ('patrick', 'Patrick Star')] COMMIT{stop} The execution above features "executemany" form first illustrated at @@ -185,8 +185,8 @@ construct automatically. INSERT INTO address (user_id, email_address) VALUES ((SELECT user_account.id FROM user_account WHERE user_account.name = ?), ?) - [...] (('spongebob', 'spongebob@sqlalchemy.org'), ('sandy', 'sandy@sqlalchemy.org'), - ('sandy', 'sandy@squirrelpower.org')) + [...] [('spongebob', 'spongebob@sqlalchemy.org'), ('sandy', 'sandy@sqlalchemy.org'), + ('sandy', 'sandy@squirrelpower.org')] COMMIT{stop} .. _tutorial_insert_from_select: diff --git a/doc/build/tutorial/data_update.rst b/doc/build/tutorial/data_update.rst index 8813dda98..8e88eb2f7 100644 --- a/doc/build/tutorial/data_update.rst +++ b/doc/build/tutorial/data_update.rst @@ -101,7 +101,7 @@ that literal values would normally go: ... ) {opensql}BEGIN (implicit) UPDATE user_account SET name=? WHERE user_account.name = ? - [...] (('ed', 'jack'), ('mary', 'wendy'), ('jake', 'jim')) + [...] [('ed', 'jack'), ('mary', 'wendy'), ('jake', 'jim')] <sqlalchemy.engine.cursor.CursorResult object at 0x...> COMMIT{stop} diff --git a/doc/build/tutorial/dbapi_transactions.rst b/doc/build/tutorial/dbapi_transactions.rst index 16768da2b..f4d2ad8e0 100644 --- a/doc/build/tutorial/dbapi_transactions.rst +++ b/doc/build/tutorial/dbapi_transactions.rst @@ -115,7 +115,7 @@ where we acquired the :class:`_future.Connection` object: [...] () <sqlalchemy.engine.cursor.CursorResult object at 0x...> INSERT INTO some_table (x, y) VALUES (?, ?) - [...] ((1, 1), (2, 4)) + [...] [(1, 1), (2, 4)] <sqlalchemy.engine.cursor.CursorResult object at 0x...> COMMIT @@ -149,7 +149,7 @@ may be referred towards as **begin once**: ... ) {opensql}BEGIN (implicit) INSERT INTO some_table (x, y) VALUES (?, ?) - [...] ((6, 8), (9, 10)) + [...] [(6, 8), (9, 10)] <sqlalchemy.engine.cursor.CursorResult object at 0x...> COMMIT @@ -374,7 +374,7 @@ be invoked against each parameter set individually: ... conn.commit() {opensql}BEGIN (implicit) INSERT INTO some_table (x, y) VALUES (?, ?) - [...] ((11, 12), (13, 14)) + [...] [(11, 12), (13, 14)] <sqlalchemy.engine.cursor.CursorResult object at 0x...> COMMIT @@ -508,7 +508,7 @@ our data: ... session.commit() {opensql}BEGIN (implicit) UPDATE some_table SET y=? WHERE x=? - [...] ((11, 9), (15, 13)) + [...] [(11, 9), (15, 13)] COMMIT{stop} Above, we invoked an UPDATE statement using the bound-parameter, "executemany" diff --git a/lib/sqlalchemy/connectors/__init__.py b/lib/sqlalchemy/connectors/__init__.py index 132a0a4de..f4fa5b66b 100644 --- a/lib/sqlalchemy/connectors/__init__.py +++ b/lib/sqlalchemy/connectors/__init__.py @@ -6,5 +6,13 @@ # the MIT License: https://www.opensource.org/licenses/mit-license.php -class Connector: - pass +from ..engine.interfaces import Dialect + + +class Connector(Dialect): + """Base class for dialect mixins, for DBAPIs that work + across entirely different database backends. + + Currently the only such mixin is pyodbc. + + """ diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py index f7d01ce43..c5f07de07 100644 --- a/lib/sqlalchemy/connectors/pyodbc.py +++ b/lib/sqlalchemy/connectors/pyodbc.py @@ -5,12 +5,27 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + import re +from types import ModuleType +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union from urllib.parse import unquote_plus from . import Connector +from .. import ExecutionContext +from .. import pool from .. import util +from ..engine import ConnectArgsType +from ..engine import Connection from ..engine import interfaces +from ..engine import URL +from ..sql.type_api import TypeEngine class PyODBCConnector(Connector): @@ -25,18 +40,20 @@ class PyODBCConnector(Connector): # for non-DSN connections, this *may* be used to # hold the desired driver name - pyodbc_driver_name = None + pyodbc_driver_name: Optional[str] = None + + dbapi: ModuleType - def __init__(self, use_setinputsizes=False, **kw): + def __init__(self, use_setinputsizes: bool = False, **kw: Any): super(PyODBCConnector, self).__init__(**kw) if use_setinputsizes: self.bind_typing = interfaces.BindTyping.SETINPUTSIZES @classmethod - def dbapi(cls): + def import_dbapi(cls) -> ModuleType: return __import__("pyodbc") - def create_connect_args(self, url): + def create_connect_args(self, url: URL) -> ConnectArgsType: opts = url.translate_connect_args(username="user") opts.update(url.query) @@ -44,7 +61,9 @@ class PyODBCConnector(Connector): query = url.query - connect_args = {} + connect_args: Dict[str, Any] = {} + connectors: List[str] + for param in ("ansi", "unicode_results", "autocommit"): if param in keys: connect_args[param] = util.asbool(keys.pop(param)) @@ -53,7 +72,7 @@ class PyODBCConnector(Connector): connectors = [unquote_plus(keys.pop("odbc_connect"))] else: - def check_quote(token): + def check_quote(token: str) -> str: if ";" in str(token): token = "{%s}" % token.replace("}", "}}") return token @@ -115,9 +134,14 @@ class PyODBCConnector(Connector): connectors.extend(["%s=%s" % (k, v) for k, v in keys.items()]) - return [[";".join(connectors)], connect_args] + return ((";".join(connectors),), connect_args) - def is_disconnect(self, e, connection, cursor): + def is_disconnect( + self, + e: Exception, + connection: Optional[pool.PoolProxiedConnection], + cursor: Optional[interfaces.DBAPICursor], + ) -> bool: if isinstance(e, self.dbapi.ProgrammingError): return "The cursor's connection has been closed." in str( e @@ -125,36 +149,44 @@ class PyODBCConnector(Connector): else: return False - def _dbapi_version(self): + def _dbapi_version(self) -> interfaces.VersionInfoType: if not self.dbapi: return () return self._parse_dbapi_version(self.dbapi.version) - def _parse_dbapi_version(self, vers): + def _parse_dbapi_version(self, vers: str) -> interfaces.VersionInfoType: m = re.match(r"(?:py.*-)?([\d\.]+)(?:-(\w+))?", vers) if not m: return () - vers = tuple([int(x) for x in m.group(1).split(".")]) + vers_tuple: interfaces.VersionInfoType = tuple( + [int(x) for x in m.group(1).split(".")] + ) if m.group(2): - vers += (m.group(2),) - return vers + vers_tuple += (m.group(2),) + return vers_tuple - def _get_server_version_info(self, connection, allow_chars=True): + def _get_server_version_info( + self, connection: Connection + ) -> interfaces.VersionInfoType: # NOTE: this function is not reliable, particularly when # freetds is in use. Implement database-specific server version # queries. - dbapi_con = connection.connection - version = [] + dbapi_con = connection.connection.dbapi_connection + version: Tuple[Union[int, str], ...] = () r = re.compile(r"[.\-]") - for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)): + for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)): # type: ignore[union-attr] # noqa E501 try: - version.append(int(n)) + version += (int(n),) except ValueError: - if allow_chars: - version.append(n) + pass return tuple(version) - def do_set_input_sizes(self, cursor, list_of_tuples, context): + def do_set_input_sizes( + self, + cursor: interfaces.DBAPICursor, + list_of_tuples: List[Tuple[str, Any, TypeEngine[Any]]], + context: ExecutionContext, + ) -> None: # the rules for these types seems a little strange, as you can pass # non-tuples as well as tuples, however it seems to assume "0" # for the subsequent values if you don't pass a tuple which fails @@ -174,12 +206,16 @@ class PyODBCConnector(Connector): ] ) - def get_isolation_level_values(self, dbapi_connection): - return super().get_isolation_level_values(dbapi_connection) + [ + def get_isolation_level_values( + self, dbapi_connection: interfaces.DBAPIConnection + ) -> List[str]: + return super().get_isolation_level_values(dbapi_connection) + [ # type: ignore # noqa E501 "AUTOCOMMIT" ] - def set_isolation_level(self, dbapi_connection, level): + def set_isolation_level( + self, dbapi_connection: interfaces.DBAPIConnection, level: str + ) -> None: # adjust for ConnectionFairy being present # allows attribute set e.g. "connection.autocommit = True" # to work properly @@ -188,6 +224,4 @@ class PyODBCConnector(Connector): dbapi_connection.autocommit = True else: dbapi_connection.autocommit = False - super(PyODBCConnector, self).set_isolation_level( - dbapi_connection, level - ) + super().set_isolation_level(dbapi_connection, level) diff --git a/lib/sqlalchemy/cyextension/resultproxy.pyx b/lib/sqlalchemy/cyextension/resultproxy.pyx index daf5cc940..e88c8ec0b 100644 --- a/lib/sqlalchemy/cyextension/resultproxy.pyx +++ b/lib/sqlalchemy/cyextension/resultproxy.pyx @@ -7,8 +7,6 @@ cdef int MD_INDEX = 0 # integer index in cursor.description KEY_INTEGER_ONLY = 0 KEY_OBJECTS_ONLY = 1 -sqlalchemy_engine_row = None - cdef class BaseRow: cdef readonly object _parent cdef readonly tuple _data @@ -53,19 +51,6 @@ cdef class BaseRow: self._keymap = self._parent._keymap self._key_style = state["_key_style"] - def _filter_on_values(self, filters): - global sqlalchemy_engine_row - if sqlalchemy_engine_row is None: - from sqlalchemy.engine.row import Row as sqlalchemy_engine_row - - return sqlalchemy_engine_row( - self._parent, - filters, - self._keymap, - self._key_style, - self._data, - ) - def _values_impl(self): return list(self) @@ -78,18 +63,8 @@ cdef class BaseRow: def __hash__(self): return hash(self._data) - def _get_by_int_impl(self, key): - return self._data[key] - - cpdef _get_by_key_impl(self, key): - # keep two isinstance since it's noticeably faster in the int case - if isinstance(key, int) or isinstance(key, slice): - return self._data[key] - - self._parent._raise_for_nonint(key) - - def __getitem__(self, key): - return self._get_by_key_impl(key) + def __getitem__(self, index): + return self._data[index] cpdef _get_by_key_impl_mapping(self, key): try: diff --git a/lib/sqlalchemy/dialects/mssql/pymssql.py b/lib/sqlalchemy/dialects/mssql/pymssql.py index 20d2b09db..8d654b72d 100644 --- a/lib/sqlalchemy/dialects/mssql/pymssql.py +++ b/lib/sqlalchemy/dialects/mssql/pymssql.py @@ -77,7 +77,7 @@ class MSDialect_pymssql(MSDialect): ) @classmethod - def dbapi(cls): + def import_dbapi(cls): module = __import__("pymssql") # pymmsql < 2.1.1 doesn't have a Binary method. we use string client_ver = tuple(int(x) for x in module.__version__.split(".")) @@ -106,7 +106,7 @@ class MSDialect_pymssql(MSDialect): port = opts.pop("port", None) if port and "host" in opts: opts["host"] = "%s:%s" % (opts["host"], port) - return [[], opts] + return ([], opts) def is_disconnect(self, e, connection, cursor): for msg in ( diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py index 3af083a6a..0951f219b 100644 --- a/lib/sqlalchemy/dialects/mssql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py @@ -538,7 +538,7 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect): # 2008. Before we had the VARCHAR cast above, pyodbc would also # fail on this query. return super(MSDialect_pyodbc, self)._get_server_version_info( - connection, allow_chars=False + connection ) else: version = [] diff --git a/lib/sqlalchemy/dialects/mysql/aiomysql.py b/lib/sqlalchemy/dialects/mysql/aiomysql.py index df716346e..d685b7ea1 100644 --- a/lib/sqlalchemy/dialects/mysql/aiomysql.py +++ b/lib/sqlalchemy/dialects/mysql/aiomysql.py @@ -276,7 +276,7 @@ class MySQLDialect_aiomysql(MySQLDialect_pymysql): is_async = True @classmethod - def dbapi(cls): + def import_dbapi(cls): return AsyncAdapt_aiomysql_dbapi( __import__("aiomysql"), __import__("pymysql") ) diff --git a/lib/sqlalchemy/dialects/mysql/asyncmy.py b/lib/sqlalchemy/dialects/mysql/asyncmy.py index 915b666bb..7d5b1bf86 100644 --- a/lib/sqlalchemy/dialects/mysql/asyncmy.py +++ b/lib/sqlalchemy/dialects/mysql/asyncmy.py @@ -288,7 +288,7 @@ class MySQLDialect_asyncmy(MySQLDialect_pymysql): is_async = True @classmethod - def dbapi(cls): + def import_dbapi(cls): return AsyncAdapt_asyncmy_dbapi(__import__("asyncmy")) @classmethod diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 5c2de0911..b2ccfc90f 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -2434,7 +2434,7 @@ class MySQLDialect(default.DefaultDialect): @classmethod def _is_mariadb_from_url(cls, url): - dbapi = cls.dbapi() + dbapi = cls.import_dbapi() dialect = cls(dbapi=dbapi) cargs, cparams = dialect.create_connect_args(url) diff --git a/lib/sqlalchemy/dialects/mysql/cymysql.py b/lib/sqlalchemy/dialects/mysql/cymysql.py index 9fd0b4a09..281c509b7 100644 --- a/lib/sqlalchemy/dialects/mysql/cymysql.py +++ b/lib/sqlalchemy/dialects/mysql/cymysql.py @@ -53,7 +53,7 @@ class MySQLDialect_cymysql(MySQLDialect_mysqldb): colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _cymysqlBIT}) @classmethod - def dbapi(cls): + def import_dbapi(cls): return __import__("cymysql") def _detect_charset(self, connection): diff --git a/lib/sqlalchemy/dialects/mysql/mariadbconnector.py b/lib/sqlalchemy/dialects/mysql/mariadbconnector.py index fca91204f..bf2b04251 100644 --- a/lib/sqlalchemy/dialects/mysql/mariadbconnector.py +++ b/lib/sqlalchemy/dialects/mysql/mariadbconnector.py @@ -111,7 +111,7 @@ class MySQLDialect_mariadbconnector(MySQLDialect): ) @classmethod - def dbapi(cls): + def import_dbapi(cls): return __import__("mariadb") def is_disconnect(self, e, connection, cursor): diff --git a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py index c96b739dc..a69dac9a5 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py +++ b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py @@ -77,7 +77,7 @@ class MySQLDialect_mysqlconnector(MySQLDialect): colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _myconnpyBIT}) @classmethod - def dbapi(cls): + def import_dbapi(cls): from mysql import connector return connector diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py index b4f071de0..6d66f88b4 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqldb.py +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -159,7 +159,7 @@ class MySQLDialect_mysqldb(MySQLDialect): return False @classmethod - def dbapi(cls): + def import_dbapi(cls): return __import__("MySQLdb") def on_connect(self): diff --git a/lib/sqlalchemy/dialects/mysql/pymysql.py b/lib/sqlalchemy/dialects/mysql/pymysql.py index eddb9c921..9a240da61 100644 --- a/lib/sqlalchemy/dialects/mysql/pymysql.py +++ b/lib/sqlalchemy/dialects/mysql/pymysql.py @@ -57,7 +57,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb): return False @classmethod - def dbapi(cls): + def import_dbapi(cls): return __import__("pymysql") def create_connect_args(self, url, _translate_args=None): diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index a390099ae..98181051e 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -963,7 +963,7 @@ class OracleDialect_cx_oracle(OracleDialect): return (0, 0, 0) @classmethod - def dbapi(cls): + def import_dbapi(cls): import cx_Oracle return cx_Oracle diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 4c3c47ba6..75f6c2704 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -878,7 +878,7 @@ class PGDialect_asyncpg(PGDialect): return (99, 99, 99) @classmethod - def dbapi(cls): + def import_dbapi(cls): return AsyncAdapt_asyncpg_dbapi(__import__("asyncpg")) @util.memoized_property diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py index c23da93bb..372b8639e 100644 --- a/lib/sqlalchemy/dialects/postgresql/pg8000.py +++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py @@ -426,7 +426,7 @@ class PGDialect_pg8000(PGDialect): return (99, 99, 99) @classmethod - def dbapi(cls): + def import_dbapi(cls): return __import__("pg8000") def create_connect_args(self, url): diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg.py b/lib/sqlalchemy/dialects/postgresql/psycopg.py index 3ba535d6c..33dc65afc 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg.py @@ -281,7 +281,7 @@ class PGDialect_psycopg(_PGDialect_common_psycopg): register_hstore(info, connection.connection) @classmethod - def dbapi(cls): + def import_dbapi(cls): import psycopg return psycopg @@ -592,7 +592,7 @@ class PGDialectAsync_psycopg(PGDialect_psycopg): supports_statement_cache = True @classmethod - def dbapi(cls): + def import_dbapi(cls): import psycopg from psycopg.pq import ExecStatus diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index a08c5e5b0..dddce5a62 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -612,7 +612,7 @@ class PGDialect_psycopg2(_PGDialect_common_psycopg): ) @classmethod - def dbapi(cls): + def import_dbapi(cls): import psycopg2 return psycopg2 diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py b/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py index 5a4dcb2e6..0943613a2 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py @@ -44,7 +44,7 @@ class PGDialect_psycopg2cffi(PGDialect_psycopg2): ) @classmethod - def dbapi(cls): + def import_dbapi(cls): return __import__("psycopg2cffi") @util.memoized_property diff --git a/lib/sqlalchemy/dialects/sqlite/aiosqlite.py b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py index e88ab1a0f..dd0499975 100644 --- a/lib/sqlalchemy/dialects/sqlite/aiosqlite.py +++ b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py @@ -308,7 +308,7 @@ class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite): execution_ctx_cls = SQLiteExecutionContext_aiosqlite @classmethod - def dbapi(cls): + def import_dbapi(cls): return AsyncAdapt_aiosqlite_dbapi( __import__("aiosqlite"), __import__("sqlite3") ) diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py index 28f795298..b67eed974 100644 --- a/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py +++ b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py @@ -105,7 +105,7 @@ class SQLiteDialect_pysqlcipher(SQLiteDialect_pysqlite): pragmas = ("kdf_iter", "cipher", "cipher_page_size", "cipher_use_hmac") @classmethod - def dbapi(cls): + def import_dbapi(cls): try: import sqlcipher3 as sqlcipher except ImportError: diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py index 8476e6834..2aa7149a6 100644 --- a/lib/sqlalchemy/dialects/sqlite/pysqlite.py +++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py @@ -465,7 +465,7 @@ class SQLiteDialect_pysqlite(SQLiteDialect): driver = "pysqlite" @classmethod - def dbapi(cls): + def import_dbapi(cls): from sqlite3 import dbapi2 as sqlite return sqlite diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py index c6bc4b6aa..32f3f2ecc 100644 --- a/lib/sqlalchemy/engine/__init__.py +++ b/lib/sqlalchemy/engine/__init__.py @@ -35,6 +35,7 @@ from .cursor import ResultProxy as ResultProxy from .interfaces import AdaptedConnection as AdaptedConnection from .interfaces import BindTyping as BindTyping from .interfaces import Compiled as Compiled +from .interfaces import ConnectArgsType as ConnectArgsType from .interfaces import CreateEnginePlugin as CreateEnginePlugin from .interfaces import Dialect as Dialect from .interfaces import ExceptionContext as ExceptionContext diff --git a/lib/sqlalchemy/engine/_py_processors.py b/lib/sqlalchemy/engine/_py_processors.py index e3024471a..27cb9e939 100644 --- a/lib/sqlalchemy/engine/_py_processors.py +++ b/lib/sqlalchemy/engine/_py_processors.py @@ -16,16 +16,30 @@ They all share one common characteristic: None is passed through unchanged. from __future__ import annotations import datetime +from decimal import Decimal import re +import typing +from typing import Any +from typing import Callable +from typing import Optional +from typing import Type +from typing import TypeVar +from typing import Union + +_DT = TypeVar( + "_DT", bound=Union[datetime.datetime, datetime.time, datetime.date] +) -def str_to_datetime_processor_factory(regexp, type_): +def str_to_datetime_processor_factory( + regexp: typing.Pattern[str], type_: Callable[..., _DT] +) -> Callable[[Optional[str]], Optional[_DT]]: rmatch = regexp.match # Even on python2.6 datetime.strptime is both slower than this code # and it does not support microseconds. has_named_groups = bool(regexp.groupindex) - def process(value): + def process(value: Optional[str]) -> Optional[_DT]: if value is None: return None else: @@ -59,10 +73,12 @@ def str_to_datetime_processor_factory(regexp, type_): return process -def to_decimal_processor_factory(target_class, scale): +def to_decimal_processor_factory( + target_class: Type[Decimal], scale: int +) -> Callable[[Optional[float]], Optional[Decimal]]: fstring = "%%.%df" % scale - def process(value): + def process(value: Optional[float]) -> Optional[Decimal]: if value is None: return None else: @@ -71,21 +87,21 @@ def to_decimal_processor_factory(target_class, scale): return process -def to_float(value): +def to_float(value: Optional[Union[int, float]]) -> Optional[float]: if value is None: return None else: return float(value) -def to_str(value): +def to_str(value: Optional[Any]) -> Optional[str]: if value is None: return None else: return str(value) -def int_to_boolean(value): +def int_to_boolean(value: Optional[int]) -> Optional[bool]: if value is None: return None else: diff --git a/lib/sqlalchemy/engine/_py_row.py b/lib/sqlalchemy/engine/_py_row.py index a6d5b79d5..7cbac552f 100644 --- a/lib/sqlalchemy/engine/_py_row.py +++ b/lib/sqlalchemy/engine/_py_row.py @@ -1,26 +1,59 @@ from __future__ import annotations +import enum import operator +import typing +from typing import Any +from typing import Callable +from typing import Dict +from typing import Iterator +from typing import List +from typing import Optional +from typing import Tuple +from typing import Type +from typing import Union + +if typing.TYPE_CHECKING: + from .result import _KeyMapType + from .result import _KeyType + from .result import _ProcessorsType + from .result import _RawRowType + from .result import _TupleGetterType + from .result import ResultMetaData MD_INDEX = 0 # integer index in cursor.description -KEY_INTEGER_ONLY = 0 -"""__getitem__ only allows integer values and slices, raises TypeError - otherwise""" -KEY_OBJECTS_ONLY = 1 -"""__getitem__ only allows string/object values, raises TypeError otherwise""" +class _KeyStyle(enum.Enum): + KEY_INTEGER_ONLY = 0 + """__getitem__ only allows integer values and slices, raises TypeError + otherwise""" -sqlalchemy_engine_row = None + KEY_OBJECTS_ONLY = 1 + """__getitem__ only allows string/object values, raises TypeError + otherwise""" + + +KEY_INTEGER_ONLY, KEY_OBJECTS_ONLY = list(_KeyStyle) class BaseRow: - Row = None __slots__ = ("_parent", "_data", "_keymap", "_key_style") - def __init__(self, parent, processors, keymap, key_style, data): + _parent: ResultMetaData + _data: _RawRowType + _keymap: _KeyMapType + _key_style: _KeyStyle + + def __init__( + self, + parent: ResultMetaData, + processors: Optional[_ProcessorsType], + keymap: _KeyMapType, + key_style: _KeyStyle, + data: _RawRowType, + ): """Row objects are constructed by CursorResult objects.""" - object.__setattr__(self, "_parent", parent) if processors: @@ -41,68 +74,45 @@ class BaseRow: object.__setattr__(self, "_key_style", key_style) - def __reduce__(self): + def __reduce__(self) -> Tuple[Callable[..., BaseRow], Tuple[Any, ...]]: return ( rowproxy_reconstructor, (self.__class__, self.__getstate__()), ) - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: return { "_parent": self._parent, "_data": self._data, "_key_style": self._key_style, } - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, Any]) -> None: parent = state["_parent"] object.__setattr__(self, "_parent", parent) object.__setattr__(self, "_data", state["_data"]) object.__setattr__(self, "_keymap", parent._keymap) object.__setattr__(self, "_key_style", state["_key_style"]) - def _filter_on_values(self, filters): - global sqlalchemy_engine_row - if sqlalchemy_engine_row is None: - from sqlalchemy.engine.row import Row as sqlalchemy_engine_row - - return sqlalchemy_engine_row( - self._parent, - filters, - self._keymap, - self._key_style, - self._data, - ) - - def _values_impl(self): + def _values_impl(self) -> List[Any]: return list(self) - def __iter__(self): + def __iter__(self) -> Iterator[Any]: return iter(self._data) - def __len__(self): + def __len__(self) -> int: return len(self._data) - def __hash__(self): + def __hash__(self) -> int: return hash(self._data) - def _get_by_int_impl(self, key): + def _get_by_int_impl(self, key: Union[int, slice]) -> Any: return self._data[key] - def _get_by_key_impl(self, key): - # keep two isinstance since it's noticeably faster in the int case - if isinstance(key, int) or isinstance(key, slice): - return self._data[key] - - self._parent._raise_for_nonint(key) - - # The original 1.4 plan was that Row would not allow row["str"] - # access, however as the C extensions were inadvertently allowing - # this coupled with the fact that orm Session sets future=True, - # this allows a softer upgrade path. see #6218 - __getitem__ = _get_by_key_impl + if not typing.TYPE_CHECKING: + __getitem__ = _get_by_int_impl - def _get_by_key_impl_mapping(self, key): + def _get_by_key_impl_mapping(self, key: _KeyType) -> Any: try: rec = self._keymap[key] except KeyError as ke: @@ -116,7 +126,7 @@ class BaseRow: return self._data[mdindex] - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: try: return self._get_by_key_impl_mapping(name) except KeyError as e: @@ -125,13 +135,15 @@ class BaseRow: # This reconstructor is necessary so that pickles with the Cy extension or # without use the same Binary format. -def rowproxy_reconstructor(cls, state): +def rowproxy_reconstructor( + cls: Type[BaseRow], state: Dict[str, Any] +) -> BaseRow: obj = cls.__new__(cls) obj.__setstate__(state) return obj -def tuplegetter(*indexes): +def tuplegetter(*indexes: int) -> _TupleGetterType: it = operator.itemgetter(*indexes) if len(indexes) > 1: diff --git a/lib/sqlalchemy/engine/_py_util.py b/lib/sqlalchemy/engine/_py_util.py index ff03a4761..538c075a2 100644 --- a/lib/sqlalchemy/engine/_py_util.py +++ b/lib/sqlalchemy/engine/_py_util.py @@ -1,21 +1,32 @@ from __future__ import annotations -from collections import abc as collections_abc +import typing +from typing import Any +from typing import Mapping +from typing import Optional +from typing import Tuple from .. import exc -_no_tuple = () +if typing.TYPE_CHECKING: + from .interfaces import _CoreAnyExecuteParams + from .interfaces import _CoreMultiExecuteParams + from .interfaces import _DBAPIAnyExecuteParams + from .interfaces import _DBAPIMultiExecuteParams -def _distill_params_20(params): +_no_tuple: Tuple[Any, ...] = () + + +def _distill_params_20( + params: Optional[_CoreAnyExecuteParams], +) -> _CoreMultiExecuteParams: if params is None: return _no_tuple # Assume list is more likely than tuple elif isinstance(params, list) or isinstance(params, tuple): # collections_abc.MutableSequence): # avoid abc.__instancecheck__ - if params and not isinstance( - params[0], (tuple, collections_abc.Mapping) - ): + if params and not isinstance(params[0], (tuple, Mapping)): raise exc.ArgumentError( "List argument must consist only of tuples or dictionaries" ) @@ -25,21 +36,21 @@ def _distill_params_20(params): # only do immutabledict or abc.__instancecheck__ for Mapping after # we've checked for plain dictionaries and would otherwise raise params, - collections_abc.Mapping, + Mapping, ): return [params] else: raise exc.ArgumentError("mapping or list expected for parameters") -def _distill_raw_params(params): +def _distill_raw_params( + params: Optional[_DBAPIAnyExecuteParams], +) -> _DBAPIMultiExecuteParams: if params is None: return _no_tuple elif isinstance(params, list): # collections_abc.MutableSequence): # avoid abc.__instancecheck__ - if params and not isinstance( - params[0], (tuple, collections_abc.Mapping) - ): + if params and not isinstance(params[0], (tuple, Mapping)): raise exc.ArgumentError( "List argument must consist only of tuples or dictionaries" ) @@ -49,8 +60,9 @@ def _distill_raw_params(params): # only do abc.__instancecheck__ for Mapping after we've checked # for plain dictionaries and would otherwise raise params, - collections_abc.Mapping, + Mapping, ): - return [params] + # cast("Union[List[Mapping[str, Any]], Tuple[Any, ...]]", [params]) + return [params] # type: ignore else: raise exc.ArgumentError("mapping or sequence expected for parameters") diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 8c99f6309..5ce531338 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -10,13 +10,24 @@ import contextlib import sys import typing from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import Iterator +from typing import List from typing import Mapping +from typing import MutableMapping +from typing import NoReturn from typing import Optional +from typing import Tuple +from typing import Type from typing import Union from .interfaces import BindTyping from .interfaces import ConnectionEventsTarget +from .interfaces import DBAPICursor from .interfaces import ExceptionContext +from .interfaces import ExecutionContext from .util import _distill_params_20 from .util import _distill_raw_params from .util import TransactionalContext @@ -26,22 +37,48 @@ from .. import log from .. import util from ..sql import compiler from ..sql import util as sql_util -from ..sql._typing import _ExecuteOptions -from ..sql._typing import _ExecuteParams + +_CompiledCacheType = MutableMapping[Any, Any] if typing.TYPE_CHECKING: + from . import Result + from . import ScalarResult + from .interfaces import _AnyExecuteParams + from .interfaces import _AnyMultiExecuteParams + from .interfaces import _AnySingleExecuteParams + from .interfaces import _CoreAnyExecuteParams + from .interfaces import _CoreMultiExecuteParams + from .interfaces import _CoreSingleExecuteParams + from .interfaces import _DBAPIAnyExecuteParams + from .interfaces import _DBAPIMultiExecuteParams + from .interfaces import _DBAPISingleExecuteParams + from .interfaces import _ExecuteOptions + from .interfaces import _ExecuteOptionsParameter + from .interfaces import _SchemaTranslateMapType from .interfaces import Dialect from .reflection import Inspector # noqa from .url import URL + from ..event import dispatcher + from ..log import _EchoFlagType + from ..pool import _ConnectionFairy from ..pool import Pool from ..pool import PoolProxiedConnection + from ..sql import Executable + from ..sql.base import SchemaVisitor + from ..sql.compiler import Compiled + from ..sql.ddl import DDLElement + from ..sql.ddl import SchemaDropper + from ..sql.ddl import SchemaGenerator + from ..sql.functions import FunctionElement + from ..sql.schema import ColumnDefault + from ..sql.schema import HasSchemaAttr """Defines :class:`_engine.Connection` and :class:`_engine.Engine`. """ -_EMPTY_EXECUTION_OPTS = util.immutabledict() -NO_OPTIONS = util.immutabledict() +_EMPTY_EXECUTION_OPTS: _ExecuteOptions = util.immutabledict() +NO_OPTIONS: Mapping[str, Any] = util.immutabledict() class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): @@ -69,23 +106,32 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): """ + dispatch: dispatcher[ConnectionEventsTarget] + _sqla_logger_namespace = "sqlalchemy.engine.Connection" # used by sqlalchemy.engine.util.TransactionalContext - _trans_context_manager = None + _trans_context_manager: Optional[TransactionalContext] = None # legacy as of 2.0, should be eventually deprecated and # removed. was used in the "pre_ping" recipe that's been in the docs # a long time should_close_with_result = False + _dbapi_connection: Optional[PoolProxiedConnection] + + _execution_options: _ExecuteOptions + + _transaction: Optional[RootTransaction] + _nested_transaction: Optional[NestedTransaction] + def __init__( self, - engine, - connection=None, - _has_events=None, - _allow_revalidate=True, - _allow_autobegin=True, + engine: Engine, + connection: Optional[PoolProxiedConnection] = None, + _has_events: Optional[bool] = None, + _allow_revalidate: bool = True, + _allow_autobegin: bool = True, ): """Construct a new Connection.""" self.engine = engine @@ -125,14 +171,14 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): self.dispatch.engine_connect(self) @util.memoized_property - def _message_formatter(self): + def _message_formatter(self) -> Any: if "logging_token" in self._execution_options: token = self._execution_options["logging_token"] return lambda msg: "[%s] %s" % (token, msg) else: return None - def _log_info(self, message, *arg, **kw): + def _log_info(self, message: str, *arg: Any, **kw: Any) -> None: fmt = self._message_formatter if fmt: @@ -143,7 +189,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): self.engine.logger.info(message, *arg, **kw) - def _log_debug(self, message, *arg, **kw): + def _log_debug(self, message: str, *arg: Any, **kw: Any) -> None: fmt = self._message_formatter if fmt: @@ -155,19 +201,19 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): self.engine.logger.debug(message, *arg, **kw) @property - def _schema_translate_map(self): + def _schema_translate_map(self) -> Optional[_SchemaTranslateMapType]: return self._execution_options.get("schema_translate_map", None) - def schema_for_object(self, obj): + def schema_for_object(self, obj: HasSchemaAttr) -> Optional[str]: """Return the schema name for the given schema item taking into account current schema translate map. """ name = obj.schema - schema_translate_map = self._execution_options.get( - "schema_translate_map", None - ) + schema_translate_map: Optional[ + Mapping[Optional[str], str] + ] = self._execution_options.get("schema_translate_map", None) if ( schema_translate_map @@ -178,13 +224,13 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): else: return name - def __enter__(self): + def __enter__(self) -> Connection: return self - def __exit__(self, type_, value, traceback): + def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: self.close() - def execution_options(self, **opt): + def execution_options(self, **opt: Any) -> Connection: r"""Set non-SQL options for the connection which take effect during execution. @@ -346,13 +392,13 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): ORM-specific execution options """ # noqa - self._execution_options = self._execution_options.union(opt) if self._has_events or self.engine._has_events: self.dispatch.set_connection_execution_options(self, opt) + self._execution_options = self._execution_options.union(opt) self.dialect.set_connection_execution_options(self, opt) return self - def get_execution_options(self): + def get_execution_options(self) -> _ExecuteOptions: """Get the non-SQL options which will take effect during execution. .. versionadded:: 1.3 @@ -364,14 +410,27 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): return self._execution_options @property - def closed(self): + def _still_open_and_dbapi_connection_is_valid(self) -> bool: + pool_proxied_connection = self._dbapi_connection + return ( + pool_proxied_connection is not None + and pool_proxied_connection.is_valid + ) + + @property + def closed(self) -> bool: """Return True if this connection is closed.""" return self._dbapi_connection is None and not self.__can_reconnect @property - def invalidated(self): - """Return True if this connection was invalidated.""" + def invalidated(self) -> bool: + """Return True if this connection was invalidated. + + This does not indicate whether or not the connection was + invalidated at the pool level, however + + """ # prior to 1.4, "invalid" was stored as a state independent of # "closed", meaning an invalidated connection could be "closed", @@ -382,10 +441,11 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): # "closed" does not need to be "invalid". So the state is now # represented by the two facts alone. - return self._dbapi_connection is None and not self.closed + pool_proxied_connection = self._dbapi_connection + return pool_proxied_connection is None and self.__can_reconnect @property - def connection(self) -> "PoolProxiedConnection": + def connection(self) -> PoolProxiedConnection: """The underlying DB-API connection managed by this Connection. This is a SQLAlchemy connection-pool proxied connection @@ -410,7 +470,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): else: return self._dbapi_connection - def get_isolation_level(self): + def get_isolation_level(self) -> str: """Return the current isolation level assigned to this :class:`_engine.Connection`. @@ -442,15 +502,15 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): - set per :class:`_engine.Connection` isolation level """ + dbapi_connection = self.connection.dbapi_connection + assert dbapi_connection is not None try: - return self.dialect.get_isolation_level( - self.connection.dbapi_connection - ) + return self.dialect.get_isolation_level(dbapi_connection) except BaseException as e: self._handle_dbapi_exception(e, None, None, None, None) @property - def default_isolation_level(self): + def default_isolation_level(self) -> str: """The default isolation level assigned to this :class:`_engine.Connection`. @@ -482,7 +542,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): """ return self.dialect.default_isolation_level - def _invalid_transaction(self): + def _invalid_transaction(self) -> NoReturn: raise exc.PendingRollbackError( "Can't reconnect until invalid %stransaction is rolled " "back. Please rollback() fully before proceeding" @@ -490,7 +550,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): code="8s2b", ) - def _revalidate_connection(self): + def _revalidate_connection(self) -> PoolProxiedConnection: if self.__can_reconnect and self.invalidated: if self._transaction is not None: self._invalid_transaction() @@ -499,13 +559,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): raise exc.ResourceClosedError("This Connection is closed") @property - def _still_open_and_dbapi_connection_is_valid(self): - return self._dbapi_connection is not None and getattr( - self._dbapi_connection, "is_valid", False - ) - - @property - def info(self): + def info(self) -> Dict[str, Any]: """Info dictionary associated with the underlying DBAPI connection referred to by this :class:`_engine.Connection`, allowing user-defined data to be associated with the connection. @@ -518,7 +572,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): return self.connection.info - def invalidate(self, exception=None): + def invalidate(self, exception: Optional[BaseException] = None) -> None: """Invalidate the underlying DBAPI connection associated with this :class:`_engine.Connection`. @@ -567,14 +621,18 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): if self.invalidated: return + # MARKMARK if self.closed: raise exc.ResourceClosedError("This Connection is closed") if self._still_open_and_dbapi_connection_is_valid: - self._dbapi_connection.invalidate(exception) + pool_proxied_connection = self._dbapi_connection + assert pool_proxied_connection is not None + pool_proxied_connection.invalidate(exception) + self._dbapi_connection = None - def detach(self): + def detach(self) -> None: """Detach the underlying DB-API connection from its connection pool. E.g.:: @@ -600,13 +658,21 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): """ - self._dbapi_connection.detach() + if self.closed: + raise exc.ResourceClosedError("This Connection is closed") - def _autobegin(self): - if self._allow_autobegin: + pool_proxied_connection = self._dbapi_connection + if pool_proxied_connection is None: + raise exc.InvalidRequestError( + "Can't detach an invalidated Connection" + ) + pool_proxied_connection.detach() + + def _autobegin(self) -> None: + if self._allow_autobegin and not self.__in_begin: self.begin() - def begin(self): + def begin(self) -> RootTransaction: """Begin a transaction prior to autobegin occurring. E.g.:: @@ -671,14 +737,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): :class:`_engine.Engine` """ - if self.__in_begin: - # for dialects that emit SQL within the process of - # dialect.do_begin() or dialect.do_begin_twophase(), this - # flag prevents "autobegin" from being emitted within that - # process, while allowing self._transaction to remain at None - # until it's complete. - return - elif self._transaction is None: + if self._transaction is None: self._transaction = RootTransaction(self) return self._transaction else: @@ -689,7 +748,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): "is called first." ) - def begin_nested(self): + def begin_nested(self) -> NestedTransaction: """Begin a nested transaction (i.e. SAVEPOINT) and return a transaction handle that controls the scope of the SAVEPOINT. @@ -765,7 +824,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): return NestedTransaction(self) - def begin_twophase(self, xid=None): + def begin_twophase(self, xid: Optional[Any] = None) -> TwoPhaseTransaction: """Begin a two-phase or XA transaction and return a transaction handle. @@ -794,7 +853,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): xid = self.engine.dialect.create_xid() return TwoPhaseTransaction(self, xid) - def commit(self): + def commit(self) -> None: """Commit the transaction that is currently in progress. This method commits the current transaction if one has been started. @@ -819,7 +878,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): if self._transaction: self._transaction.commit() - def rollback(self): + def rollback(self) -> None: """Roll back the transaction that is currently in progress. This method rolls back the current transaction if one has been started. @@ -845,33 +904,33 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): if self._transaction: self._transaction.rollback() - def recover_twophase(self): + def recover_twophase(self) -> List[Any]: return self.engine.dialect.do_recover_twophase(self) - def rollback_prepared(self, xid, recover=False): + def rollback_prepared(self, xid: Any, recover: bool = False) -> None: self.engine.dialect.do_rollback_twophase(self, xid, recover=recover) - def commit_prepared(self, xid, recover=False): + def commit_prepared(self, xid: Any, recover: bool = False) -> None: self.engine.dialect.do_commit_twophase(self, xid, recover=recover) - def in_transaction(self): + def in_transaction(self) -> bool: """Return True if a transaction is in progress.""" return self._transaction is not None and self._transaction.is_active - def in_nested_transaction(self): + def in_nested_transaction(self) -> bool: """Return True if a transaction is in progress.""" return ( self._nested_transaction is not None and self._nested_transaction.is_active ) - def _is_autocommit(self): - return ( + def _is_autocommit_isolation(self) -> bool: + return bool( self._execution_options.get("isolation_level", None) == "AUTOCOMMIT" ) - def get_transaction(self): + def get_transaction(self) -> Optional[RootTransaction]: """Return the current root transaction in progress, if any. .. versionadded:: 1.4 @@ -880,7 +939,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): return self._transaction - def get_nested_transaction(self): + def get_nested_transaction(self) -> Optional[NestedTransaction]: """Return the current nested transaction in progress, if any. .. versionadded:: 1.4 @@ -888,7 +947,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): """ return self._nested_transaction - def _begin_impl(self, transaction): + def _begin_impl(self, transaction: RootTransaction) -> None: if self._echo: self._log_info("BEGIN (implicit)") @@ -904,13 +963,13 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): finally: self.__in_begin = False - def _rollback_impl(self): + def _rollback_impl(self) -> None: if self._has_events or self.engine._has_events: self.dispatch.rollback(self) if self._still_open_and_dbapi_connection_is_valid: if self._echo: - if self._is_autocommit(): + if self._is_autocommit_isolation(): self._log_info( "ROLLBACK using DBAPI connection.rollback(), " "DBAPI should ignore due to autocommit mode" @@ -922,13 +981,13 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): except BaseException as e: self._handle_dbapi_exception(e, None, None, None, None) - def _commit_impl(self): + def _commit_impl(self) -> None: if self._has_events or self.engine._has_events: self.dispatch.commit(self) if self._echo: - if self._is_autocommit(): + if self._is_autocommit_isolation(): self._log_info( "COMMIT using DBAPI connection.commit(), " "DBAPI should ignore due to autocommit mode" @@ -940,58 +999,54 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): except BaseException as e: self._handle_dbapi_exception(e, None, None, None, None) - def _savepoint_impl(self, name=None): + def _savepoint_impl(self, name: Optional[str] = None) -> str: if self._has_events or self.engine._has_events: self.dispatch.savepoint(self, name) if name is None: self.__savepoint_seq += 1 name = "sa_savepoint_%s" % self.__savepoint_seq - if self._still_open_and_dbapi_connection_is_valid: - self.engine.dialect.do_savepoint(self, name) - return name + self.engine.dialect.do_savepoint(self, name) + return name - def _rollback_to_savepoint_impl(self, name): + def _rollback_to_savepoint_impl(self, name: str) -> None: if self._has_events or self.engine._has_events: self.dispatch.rollback_savepoint(self, name, None) if self._still_open_and_dbapi_connection_is_valid: self.engine.dialect.do_rollback_to_savepoint(self, name) - def _release_savepoint_impl(self, name): + def _release_savepoint_impl(self, name: str) -> None: if self._has_events or self.engine._has_events: self.dispatch.release_savepoint(self, name, None) - if self._still_open_and_dbapi_connection_is_valid: - self.engine.dialect.do_release_savepoint(self, name) + self.engine.dialect.do_release_savepoint(self, name) - def _begin_twophase_impl(self, transaction): + def _begin_twophase_impl(self, transaction: TwoPhaseTransaction) -> None: if self._echo: self._log_info("BEGIN TWOPHASE (implicit)") if self._has_events or self.engine._has_events: self.dispatch.begin_twophase(self, transaction.xid) - if self._still_open_and_dbapi_connection_is_valid: - self.__in_begin = True - try: - self.engine.dialect.do_begin_twophase(self, transaction.xid) - except BaseException as e: - self._handle_dbapi_exception(e, None, None, None, None) - finally: - self.__in_begin = False + self.__in_begin = True + try: + self.engine.dialect.do_begin_twophase(self, transaction.xid) + except BaseException as e: + self._handle_dbapi_exception(e, None, None, None, None) + finally: + self.__in_begin = False - def _prepare_twophase_impl(self, xid): + def _prepare_twophase_impl(self, xid: Any) -> None: if self._has_events or self.engine._has_events: self.dispatch.prepare_twophase(self, xid) - if self._still_open_and_dbapi_connection_is_valid: - assert isinstance(self._transaction, TwoPhaseTransaction) - try: - self.engine.dialect.do_prepare_twophase(self, xid) - except BaseException as e: - self._handle_dbapi_exception(e, None, None, None, None) + assert isinstance(self._transaction, TwoPhaseTransaction) + try: + self.engine.dialect.do_prepare_twophase(self, xid) + except BaseException as e: + self._handle_dbapi_exception(e, None, None, None, None) - def _rollback_twophase_impl(self, xid, is_prepared): + def _rollback_twophase_impl(self, xid: Any, is_prepared: bool) -> None: if self._has_events or self.engine._has_events: self.dispatch.rollback_twophase(self, xid, is_prepared) @@ -1004,18 +1059,17 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): except BaseException as e: self._handle_dbapi_exception(e, None, None, None, None) - def _commit_twophase_impl(self, xid, is_prepared): + def _commit_twophase_impl(self, xid: Any, is_prepared: bool) -> None: if self._has_events or self.engine._has_events: self.dispatch.commit_twophase(self, xid, is_prepared) - if self._still_open_and_dbapi_connection_is_valid: - assert isinstance(self._transaction, TwoPhaseTransaction) - try: - self.engine.dialect.do_commit_twophase(self, xid, is_prepared) - except BaseException as e: - self._handle_dbapi_exception(e, None, None, None, None) + assert isinstance(self._transaction, TwoPhaseTransaction) + try: + self.engine.dialect.do_commit_twophase(self, xid, is_prepared) + except BaseException as e: + self._handle_dbapi_exception(e, None, None, None, None) - def close(self): + def close(self) -> None: """Close this :class:`_engine.Connection`. This results in a release of the underlying database @@ -1050,7 +1104,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): # as we just closed the transaction, close the connection # pool connection without doing an additional reset if skip_reset: - conn._close_no_reset() + cast("_ConnectionFairy", conn)._close_no_reset() else: conn.close() @@ -1061,7 +1115,12 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): self._dbapi_connection = None self.__can_reconnect = False - def scalar(self, statement, parameters=None, execution_options=None): + def scalar( + self, + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> Any: r"""Executes a SQL statement construct and returns a scalar object. This method is shorthand for invoking the @@ -1074,7 +1133,12 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): """ return self.execute(statement, parameters, execution_options).scalar() - def scalars(self, statement, parameters=None, execution_options=None): + def scalars( + self, + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> ScalarResult: """Executes and returns a scalar result set, which yields scalar values from the first column of each row. @@ -1093,10 +1157,10 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): def execute( self, - statement, - parameters: Optional[_ExecuteParams] = None, - execution_options: Optional[_ExecuteOptions] = None, - ): + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> Result: r"""Executes a SQL statement construct and returns a :class:`_engine.Result`. @@ -1140,7 +1204,12 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): execution_options or NO_OPTIONS, ) - def _execute_function(self, func, distilled_parameters, execution_options): + def _execute_function( + self, + func: FunctionElement[Any], + distilled_parameters: _CoreMultiExecuteParams, + execution_options: _ExecuteOptions, + ) -> Result: """Execute a sql.FunctionElement object.""" return self._execute_clauseelement( @@ -1148,14 +1217,20 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): ) def _execute_default( - self, default, distilled_parameters, execution_options - ): + self, + default: ColumnDefault, + distilled_parameters: _CoreMultiExecuteParams, + execution_options: _ExecuteOptions, + ) -> Any: """Execute a schema.ColumnDefault object.""" execution_options = self._execution_options.merge_with( execution_options ) + event_multiparams: Optional[_CoreMultiExecuteParams] + event_params: Optional[_CoreAnyExecuteParams] + # note for event handlers, the "distilled parameters" which is always # a list of dicts is broken out into separate "multiparams" and # "params" collections, which allows the handler to distinguish @@ -1169,6 +1244,8 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): ) = self._invoke_before_exec_event( default, distilled_parameters, execution_options ) + else: + event_multiparams = event_params = None try: conn = self._dbapi_connection @@ -1198,13 +1275,21 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): return ret - def _execute_ddl(self, ddl, distilled_parameters, execution_options): + def _execute_ddl( + self, + ddl: DDLElement, + distilled_parameters: _CoreMultiExecuteParams, + execution_options: _ExecuteOptions, + ) -> Result: """Execute a schema.DDL object.""" execution_options = ddl._execution_options.merge_with( self._execution_options, execution_options ) + event_multiparams: Optional[_CoreMultiExecuteParams] + event_params: Optional[_CoreSingleExecuteParams] + if self._has_events or self.engine._has_events: ( ddl, @@ -1214,6 +1299,8 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): ) = self._invoke_before_exec_event( ddl, distilled_parameters, execution_options ) + else: + event_multiparams = event_params = None exec_opts = self._execution_options.merge_with(execution_options) schema_translate_map = exec_opts.get("schema_translate_map", None) @@ -1243,8 +1330,19 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): return ret def _invoke_before_exec_event( - self, elem, distilled_params, execution_options - ): + self, + elem: Any, + distilled_params: _CoreMultiExecuteParams, + execution_options: _ExecuteOptions, + ) -> Tuple[ + Any, + _CoreMultiExecuteParams, + _CoreMultiExecuteParams, + _CoreSingleExecuteParams, + ]: + + event_multiparams: _CoreMultiExecuteParams + event_params: _CoreSingleExecuteParams if len(distilled_params) == 1: event_multiparams, event_params = [], distilled_params[0] @@ -1275,8 +1373,11 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): return elem, distilled_params, event_multiparams, event_params def _execute_clauseelement( - self, elem, distilled_parameters, execution_options - ): + self, + elem: Executable, + distilled_parameters: _CoreMultiExecuteParams, + execution_options: _ExecuteOptions, + ) -> Result: """Execute a sql.ClauseElement object.""" execution_options = elem._execution_options.merge_with( @@ -1309,7 +1410,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): "schema_translate_map", None ) - compiled_cache = execution_options.get( + compiled_cache: _CompiledCacheType = execution_options.get( "compiled_cache", self.engine._compiled_cache ) @@ -1346,10 +1447,10 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): def _execute_compiled( self, - compiled, - distilled_parameters, - execution_options=_EMPTY_EXECUTION_OPTS, - ): + compiled: Compiled, + distilled_parameters: _CoreMultiExecuteParams, + execution_options: _ExecuteOptionsParameter = _EMPTY_EXECUTION_OPTS, + ) -> Result: """Execute a sql.Compiled object. TODO: why do we have this? likely deprecate or remove @@ -1395,8 +1496,11 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): return ret def exec_driver_sql( - self, statement, parameters=None, execution_options=None - ): + self, + statement: str, + parameters: Optional[_DBAPIAnyExecuteParams] = None, + execution_options: Optional[_ExecuteOptions] = None, + ) -> Result: r"""Executes a SQL statement construct and returns a :class:`_engine.CursorResult`. @@ -1456,7 +1560,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): dialect, dialect.execution_ctx_cls._init_statement, statement, - distilled_parameters, + None, execution_options, statement, distilled_parameters, @@ -1466,14 +1570,14 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): def _execute_context( self, - dialect, - constructor, - statement, - parameters, - execution_options, - *args, - **kw, - ): + dialect: Dialect, + constructor: Callable[..., ExecutionContext], + statement: Union[str, Compiled], + parameters: Optional[_AnyMultiExecuteParams], + execution_options: _ExecuteOptions, + *args: Any, + **kw: Any, + ) -> Result: """Create an :class:`.ExecutionContext` and execute, returning a :class:`_engine.CursorResult`.""" @@ -1491,7 +1595,6 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): self._handle_dbapi_exception( e, str(statement), parameters, None, None ) - return # not reached if ( self._transaction @@ -1514,29 +1617,33 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): if dialect.bind_typing is BindTyping.SETINPUTSIZES: context._set_input_sizes() - cursor, statement, parameters = ( + cursor, str_statement, parameters = ( context.cursor, context.statement, context.parameters, ) + effective_parameters: Optional[_AnyExecuteParams] + if not context.executemany: - parameters = parameters[0] + effective_parameters = parameters[0] + else: + effective_parameters = parameters if self._has_events or self.engine._has_events: for fn in self.dispatch.before_cursor_execute: - statement, parameters = fn( + str_statement, effective_parameters = fn( self, cursor, - statement, - parameters, + str_statement, + effective_parameters, context, context.executemany, ) if self._echo: - self._log_info(statement) + self._log_info(str_statement) stats = context._get_cache_stats() @@ -1545,7 +1652,9 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): "[%s] %r", stats, sql_util._repr_params( - parameters, batches=10, ismulti=context.executemany + effective_parameters, + batches=10, + ismulti=context.executemany, ), ) else: @@ -1554,45 +1663,61 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): % (stats,) ) - evt_handled = False + evt_handled: bool = False try: if context.executemany: + effective_parameters = cast( + "_CoreMultiExecuteParams", effective_parameters + ) if self.dialect._has_events: for fn in self.dialect.dispatch.do_executemany: - if fn(cursor, statement, parameters, context): + if fn( + cursor, + str_statement, + effective_parameters, + context, + ): evt_handled = True break if not evt_handled: self.dialect.do_executemany( - cursor, statement, parameters, context + cursor, str_statement, effective_parameters, context ) - elif not parameters and context.no_parameters: + elif not effective_parameters and context.no_parameters: if self.dialect._has_events: for fn in self.dialect.dispatch.do_execute_no_params: - if fn(cursor, statement, context): + if fn(cursor, str_statement, context): evt_handled = True break if not evt_handled: self.dialect.do_execute_no_params( - cursor, statement, context + cursor, str_statement, context ) else: + effective_parameters = cast( + "_CoreSingleExecuteParams", effective_parameters + ) if self.dialect._has_events: for fn in self.dialect.dispatch.do_execute: - if fn(cursor, statement, parameters, context): + if fn( + cursor, + str_statement, + effective_parameters, + context, + ): evt_handled = True break if not evt_handled: self.dialect.do_execute( - cursor, statement, parameters, context + cursor, str_statement, effective_parameters, context ) if self._has_events or self.engine._has_events: self.dispatch.after_cursor_execute( self, cursor, - statement, - parameters, + str_statement, + effective_parameters, context, context.executemany, ) @@ -1603,12 +1728,18 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): except BaseException as e: self._handle_dbapi_exception( - e, statement, parameters, cursor, context + e, str_statement, effective_parameters, cursor, context ) return result - def _cursor_execute(self, cursor, statement, parameters, context=None): + def _cursor_execute( + self, + cursor: DBAPICursor, + statement: str, + parameters: _DBAPISingleExecuteParams, + context: Optional[ExecutionContext] = None, + ) -> None: """Execute a statement + params on the given cursor. Adds appropriate logging and exception handling. @@ -1648,7 +1779,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): self, cursor, statement, parameters, context, False ) - def _safe_close_cursor(self, cursor): + def _safe_close_cursor(self, cursor: DBAPICursor) -> None: """Close the given cursor, catching exceptions and turning into log warnings. @@ -1665,8 +1796,13 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): _is_disconnect = False def _handle_dbapi_exception( - self, e, statement, parameters, cursor, context - ): + self, + e: BaseException, + statement: Optional[str], + parameters: Optional[_AnyExecuteParams], + cursor: Optional[DBAPICursor], + context: Optional[ExecutionContext], + ) -> NoReturn: exc_info = sys.exc_info() is_exit_exception = util.is_exit_exception(e) @@ -1708,7 +1844,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): sqlalchemy_exception = exc.DBAPIError.instance( statement, parameters, - e, + cast(Exception, e), self.dialect.dbapi.Error, hide_parameters=self.engine.hide_parameters, connection_invalidated=self._is_disconnect, @@ -1784,8 +1920,10 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): if newraise: raise newraise.with_traceback(exc_info[2]) from e elif should_wrap: + assert sqlalchemy_exception is not None raise sqlalchemy_exception.with_traceback(exc_info[2]) from e else: + assert exc_info[1] is not None raise exc_info[1].with_traceback(exc_info[2]) finally: del self._reentrant_error @@ -1793,15 +1931,20 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): del self._is_disconnect if not self.invalidated: dbapi_conn_wrapper = self._dbapi_connection + assert dbapi_conn_wrapper is not None if invalidate_pool_on_disconnect: self.engine.pool._invalidate(dbapi_conn_wrapper, e) self.invalidate(e) @classmethod - def _handle_dbapi_exception_noconnection(cls, e, dialect, engine): + def _handle_dbapi_exception_noconnection( + cls, e: BaseException, dialect: Dialect, engine: Engine + ) -> NoReturn: exc_info = sys.exc_info() - is_disconnect = dialect.is_disconnect(e, None, None) + is_disconnect = isinstance( + e, dialect.dbapi.Error + ) and dialect.is_disconnect(e, None, None) should_wrap = isinstance(e, dialect.dbapi.Error) @@ -1809,7 +1952,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): sqlalchemy_exception = exc.DBAPIError.instance( None, None, - e, + cast(Exception, e), dialect.dbapi.Error, hide_parameters=engine.hide_parameters, connection_invalidated=is_disconnect, @@ -1852,11 +1995,18 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): if newraise: raise newraise.with_traceback(exc_info[2]) from e elif should_wrap: + assert sqlalchemy_exception is not None raise sqlalchemy_exception.with_traceback(exc_info[2]) from e else: + assert exc_info[1] is not None raise exc_info[1].with_traceback(exc_info[2]) - def _run_ddl_visitor(self, visitorcallable, element, **kwargs): + def _run_ddl_visitor( + self, + visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]], + element: DDLElement, + **kwargs: Any, + ) -> None: """run a DDL visitor. This method is only here so that the MockConnection can change the @@ -1871,16 +2021,16 @@ class ExceptionContextImpl(ExceptionContext): def __init__( self, - exception, - sqlalchemy_exception, - engine, - connection, - cursor, - statement, - parameters, - context, - is_disconnect, - invalidate_pool_on_disconnect, + exception: BaseException, + sqlalchemy_exception: Optional[exc.StatementError], + engine: Optional[Engine], + connection: Optional[Connection], + cursor: Optional[DBAPICursor], + statement: Optional[str], + parameters: Optional[_DBAPIAnyExecuteParams], + context: Optional[ExecutionContext], + is_disconnect: bool, + invalidate_pool_on_disconnect: bool, ): self.engine = engine self.connection = connection @@ -1932,33 +2082,35 @@ class Transaction(TransactionalContext): __slots__ = () - _is_root = False + _is_root: bool = False + is_active: bool + connection: Connection - def __init__(self, connection): + def __init__(self, connection: Connection): raise NotImplementedError() @property - def _deactivated_from_connection(self): + def _deactivated_from_connection(self) -> bool: """True if this transaction is totally deactivated from the connection and therefore can no longer affect its state. """ raise NotImplementedError() - def _do_close(self): + def _do_close(self) -> None: raise NotImplementedError() - def _do_rollback(self): + def _do_rollback(self) -> None: raise NotImplementedError() - def _do_commit(self): + def _do_commit(self) -> None: raise NotImplementedError() @property - def is_valid(self): + def is_valid(self) -> bool: return self.is_active and not self.connection.invalidated - def close(self): + def close(self) -> None: """Close this :class:`.Transaction`. If this transaction is the base transaction in a begin/commit @@ -1974,7 +2126,7 @@ class Transaction(TransactionalContext): finally: assert not self.is_active - def rollback(self): + def rollback(self) -> None: """Roll back this :class:`.Transaction`. The implementation of this may vary based on the type of transaction in @@ -1996,7 +2148,7 @@ class Transaction(TransactionalContext): finally: assert not self.is_active - def commit(self): + def commit(self) -> None: """Commit this :class:`.Transaction`. The implementation of this may vary based on the type of transaction in @@ -2017,16 +2169,16 @@ class Transaction(TransactionalContext): finally: assert not self.is_active - def _get_subject(self): + def _get_subject(self) -> Connection: return self.connection - def _transaction_is_active(self): + def _transaction_is_active(self) -> bool: return self.is_active - def _transaction_is_closed(self): + def _transaction_is_closed(self) -> bool: return not self._deactivated_from_connection - def _rollback_can_be_called(self): + def _rollback_can_be_called(self) -> bool: # for RootTransaction / NestedTransaction, it's safe to call # rollback() even if the transaction is deactive and no warnings # will be emitted. tested in @@ -2060,7 +2212,7 @@ class RootTransaction(Transaction): __slots__ = ("connection", "is_active") - def __init__(self, connection): + def __init__(self, connection: Connection): assert connection._transaction is None if connection._trans_context_manager: TransactionalContext._trans_ctx_check(connection) @@ -2070,7 +2222,7 @@ class RootTransaction(Transaction): self.is_active = True - def _deactivate_from_connection(self): + def _deactivate_from_connection(self) -> None: if self.is_active: assert self.connection._transaction is self self.is_active = False @@ -2079,19 +2231,19 @@ class RootTransaction(Transaction): util.warn("transaction already deassociated from connection") @property - def _deactivated_from_connection(self): + def _deactivated_from_connection(self) -> bool: return self.connection._transaction is not self - def _connection_begin_impl(self): + def _connection_begin_impl(self) -> None: self.connection._begin_impl(self) - def _connection_rollback_impl(self): + def _connection_rollback_impl(self) -> None: self.connection._rollback_impl() - def _connection_commit_impl(self): + def _connection_commit_impl(self) -> None: self.connection._commit_impl() - def _close_impl(self, try_deactivate=False): + def _close_impl(self, try_deactivate: bool = False) -> None: try: if self.is_active: self._connection_rollback_impl() @@ -2107,13 +2259,13 @@ class RootTransaction(Transaction): assert not self.is_active assert self.connection._transaction is not self - def _do_close(self): + def _do_close(self) -> None: self._close_impl() - def _do_rollback(self): + def _do_rollback(self) -> None: self._close_impl(try_deactivate=True) - def _do_commit(self): + def _do_commit(self) -> None: if self.is_active: assert self.connection._transaction is self @@ -2176,7 +2328,9 @@ class NestedTransaction(Transaction): __slots__ = ("connection", "is_active", "_savepoint", "_previous_nested") - def __init__(self, connection): + _savepoint: str + + def __init__(self, connection: Connection): assert connection._transaction is not None if connection._trans_context_manager: TransactionalContext._trans_ctx_check(connection) @@ -2186,7 +2340,7 @@ class NestedTransaction(Transaction): self._previous_nested = connection._nested_transaction connection._nested_transaction = self - def _deactivate_from_connection(self, warn=True): + def _deactivate_from_connection(self, warn: bool = True) -> None: if self.connection._nested_transaction is self: self.connection._nested_transaction = self._previous_nested elif warn: @@ -2195,10 +2349,10 @@ class NestedTransaction(Transaction): ) @property - def _deactivated_from_connection(self): + def _deactivated_from_connection(self) -> bool: return self.connection._nested_transaction is not self - def _cancel(self): + def _cancel(self) -> None: # called by RootTransaction when the outer transaction is # committed, rolled back, or closed to cancel all savepoints # without any action being taken @@ -2207,9 +2361,15 @@ class NestedTransaction(Transaction): if self._previous_nested: self._previous_nested._cancel() - def _close_impl(self, deactivate_from_connection, warn_already_deactive): + def _close_impl( + self, deactivate_from_connection: bool, warn_already_deactive: bool + ) -> None: try: - if self.is_active and self.connection._transaction.is_active: + if ( + self.is_active + and self.connection._transaction + and self.connection._transaction.is_active + ): self.connection._rollback_to_savepoint_impl(self._savepoint) finally: self.is_active = False @@ -2221,13 +2381,13 @@ class NestedTransaction(Transaction): if deactivate_from_connection: assert self.connection._nested_transaction is not self - def _do_close(self): + def _do_close(self) -> None: self._close_impl(True, False) - def _do_rollback(self): + def _do_rollback(self) -> None: self._close_impl(True, True) - def _do_commit(self): + def _do_commit(self) -> None: if self.is_active: try: self.connection._release_savepoint_impl(self._savepoint) @@ -2261,12 +2421,14 @@ class TwoPhaseTransaction(RootTransaction): __slots__ = ("xid", "_is_prepared") - def __init__(self, connection, xid): + xid: Any + + def __init__(self, connection: Connection, xid: Any): self._is_prepared = False self.xid = xid super(TwoPhaseTransaction, self).__init__(connection) - def prepare(self): + def prepare(self) -> None: """Prepare this :class:`.TwoPhaseTransaction`. After a PREPARE, the transaction can be committed. @@ -2277,13 +2439,13 @@ class TwoPhaseTransaction(RootTransaction): self.connection._prepare_twophase_impl(self.xid) self._is_prepared = True - def _connection_begin_impl(self): + def _connection_begin_impl(self) -> None: self.connection._begin_twophase_impl(self) - def _connection_rollback_impl(self): + def _connection_rollback_impl(self) -> None: self.connection._rollback_twophase_impl(self.xid, self._is_prepared) - def _connection_commit_impl(self): + def _connection_commit_impl(self) -> None: self.connection._commit_twophase_impl(self.xid, self._is_prepared) @@ -2310,17 +2472,23 @@ class Engine( """ - _execution_options = _EMPTY_EXECUTION_OPTS - _has_events = False - _connection_cls = Connection - _sqla_logger_namespace = "sqlalchemy.engine.Engine" - _is_future = False + dispatch: dispatcher[ConnectionEventsTarget] - _schema_translate_map = None + _compiled_cache: Optional[_CompiledCacheType] + + _execution_options: _ExecuteOptions = _EMPTY_EXECUTION_OPTS + _has_events: bool = False + _connection_cls: Type[Connection] = Connection + _sqla_logger_namespace: str = "sqlalchemy.engine.Engine" + _is_future: bool = False + + _schema_translate_map: Optional[_SchemaTranslateMapType] = None + _option_cls: Type[OptionEngine] dialect: Dialect pool: Pool url: URL + hide_parameters: bool def __init__( self, @@ -2328,7 +2496,7 @@ class Engine( dialect: Dialect, url: URL, logging_name: Optional[str] = None, - echo: Union[None, str, bool] = None, + echo: Optional[_EchoFlagType] = None, query_cache_size: int = 500, execution_options: Optional[Mapping[str, Any]] = None, hide_parameters: bool = False, @@ -2350,7 +2518,7 @@ class Engine( if execution_options: self.update_execution_options(**execution_options) - def _lru_size_alert(self, cache): + def _lru_size_alert(self, cache: util.LRUCache[Any, Any]) -> None: if self._should_log_info: self.logger.info( "Compiled cache size pruning from %d items to %d. " @@ -2360,10 +2528,10 @@ class Engine( ) @property - def engine(self): + def engine(self) -> Engine: return self - def clear_compiled_cache(self): + def clear_compiled_cache(self) -> None: """Clear the compiled cache associated with the dialect. This applies **only** to the built-in cache that is established @@ -2377,7 +2545,7 @@ class Engine( if self._compiled_cache: self._compiled_cache.clear() - def update_execution_options(self, **opt): + def update_execution_options(self, **opt: Any) -> None: r"""Update the default execution_options dictionary of this :class:`_engine.Engine`. @@ -2394,11 +2562,11 @@ class Engine( :meth:`_engine.Engine.execution_options` """ - self._execution_options = self._execution_options.union(opt) self.dispatch.set_engine_execution_options(self, opt) + self._execution_options = self._execution_options.union(opt) self.dialect.set_engine_execution_options(self, opt) - def execution_options(self, **opt): + def execution_options(self, **opt: Any) -> OptionEngine: """Return a new :class:`_engine.Engine` that will provide :class:`_engine.Connection` objects with the given execution options. @@ -2478,7 +2646,7 @@ class Engine( """ # noqa E501 return self._option_cls(self, opt) - def get_execution_options(self): + def get_execution_options(self) -> _ExecuteOptions: """Get the non-SQL options which will take effect during execution. .. versionadded: 1.3 @@ -2490,14 +2658,14 @@ class Engine( return self._execution_options @property - def name(self): + def name(self) -> str: """String name of the :class:`~sqlalchemy.engine.interfaces.Dialect` in use by this :class:`Engine`.""" return self.dialect.name @property - def driver(self): + def driver(self) -> str: """Driver name of the :class:`~sqlalchemy.engine.interfaces.Dialect` in use by this :class:`Engine`.""" @@ -2505,10 +2673,10 @@ class Engine( echo = log.echo_property() - def __repr__(self): + def __repr__(self) -> str: return "Engine(%r)" % (self.url,) - def dispose(self): + def dispose(self) -> None: """Dispose of the connection pool used by this :class:`_engine.Engine`. @@ -2538,7 +2706,9 @@ class Engine( self.dispatch.engine_disposed(self) @contextlib.contextmanager - def _optional_conn_ctx_manager(self, connection=None): + def _optional_conn_ctx_manager( + self, connection: Optional[Connection] = None + ) -> Iterator[Connection]: if connection is None: with self.connect() as conn: yield conn @@ -2546,7 +2716,7 @@ class Engine( yield connection @contextlib.contextmanager - def begin(self): + def begin(self) -> Iterator[Connection]: """Return a context manager delivering a :class:`_engine.Connection` with a :class:`.Transaction` established. @@ -2576,11 +2746,16 @@ class Engine( with conn.begin(): yield conn - def _run_ddl_visitor(self, visitorcallable, element, **kwargs): + def _run_ddl_visitor( + self, + visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]], + element: DDLElement, + **kwargs: Any, + ) -> None: with self.begin() as conn: conn._run_ddl_visitor(visitorcallable, element, **kwargs) - def connect(self): + def connect(self) -> Connection: """Return a new :class:`_engine.Connection` object. The :class:`_engine.Connection` acts as a Python context manager, so @@ -2605,7 +2780,7 @@ class Engine( return self._connection_cls(self) - def raw_connection(self): + def raw_connection(self) -> PoolProxiedConnection: """Return a "raw" DBAPI connection from the connection pool. The returned object is a proxied version of the DBAPI @@ -2630,10 +2805,20 @@ class Engine( return self.pool.connect() -class OptionEngineMixin: +class OptionEngineMixin(log.Identified): _sa_propagate_class_events = False - def __init__(self, proxied, execution_options): + dispatch: dispatcher[ConnectionEventsTarget] + _compiled_cache: Optional[_CompiledCacheType] + dialect: Dialect + pool: Pool + url: URL + hide_parameters: bool + echo: log.echo_property + + def __init__( + self, proxied: Engine, execution_options: _ExecuteOptionsParameter + ): self._proxied = proxied self.url = proxied.url self.dialect = proxied.dialect @@ -2660,27 +2845,34 @@ class OptionEngineMixin: self._execution_options = proxied._execution_options self.update_execution_options(**execution_options) - def _get_pool(self): - return self._proxied.pool + def update_execution_options(self, **opt: Any) -> None: + raise NotImplementedError() - def _set_pool(self, pool): - self._proxied.pool = pool + if not typing.TYPE_CHECKING: + # https://github.com/python/typing/discussions/1095 - pool = property(_get_pool, _set_pool) + @property + def pool(self) -> Pool: + return self._proxied.pool - def _get_has_events(self): - return self._proxied._has_events or self.__dict__.get( - "_has_events", False - ) + @pool.setter + def pool(self, pool: Pool) -> None: + self._proxied.pool = pool - def _set_has_events(self, value): - self.__dict__["_has_events"] = value + @property + def _has_events(self) -> bool: + return self._proxied._has_events or self.__dict__.get( + "_has_events", False + ) - _has_events = property(_get_has_events, _set_has_events) + @_has_events.setter + def _has_events(self, value: bool) -> None: + self.__dict__["_has_events"] = value class OptionEngine(OptionEngineMixin, Engine): - pass + def update_execution_options(self, **opt: Any) -> None: + Engine.update_execution_options(self, **opt) Engine._option_cls = OptionEngine diff --git a/lib/sqlalchemy/engine/characteristics.py b/lib/sqlalchemy/engine/characteristics.py index c3674c931..c0feb000b 100644 --- a/lib/sqlalchemy/engine/characteristics.py +++ b/lib/sqlalchemy/engine/characteristics.py @@ -1,6 +1,13 @@ from __future__ import annotations import abc +import typing +from typing import Any +from typing import ClassVar + +if typing.TYPE_CHECKING: + from .interfaces import DBAPIConnection + from .interfaces import Dialect class ConnectionCharacteristic(abc.ABC): @@ -25,18 +32,24 @@ class ConnectionCharacteristic(abc.ABC): __slots__ = () - transactional = False + transactional: ClassVar[bool] = False @abc.abstractmethod - def reset_characteristic(self, dialect, dbapi_conn): + def reset_characteristic( + self, dialect: Dialect, dbapi_conn: DBAPIConnection + ) -> None: """Reset the characteristic on the connection to its default value.""" @abc.abstractmethod - def set_characteristic(self, dialect, dbapi_conn, value): + def set_characteristic( + self, dialect: Dialect, dbapi_conn: DBAPIConnection, value: Any + ) -> None: """set characteristic on the connection to a given value.""" @abc.abstractmethod - def get_characteristic(self, dialect, dbapi_conn): + def get_characteristic( + self, dialect: Dialect, dbapi_conn: DBAPIConnection + ) -> Any: """Given a DBAPI connection, get the current value of the characteristic. @@ -44,13 +57,19 @@ class ConnectionCharacteristic(abc.ABC): class IsolationLevelCharacteristic(ConnectionCharacteristic): - transactional = True + transactional: ClassVar[bool] = True - def reset_characteristic(self, dialect, dbapi_conn): + def reset_characteristic( + self, dialect: Dialect, dbapi_conn: DBAPIConnection + ) -> None: dialect.reset_isolation_level(dbapi_conn) - def set_characteristic(self, dialect, dbapi_conn, value): + def set_characteristic( + self, dialect: Dialect, dbapi_conn: DBAPIConnection, value: Any + ) -> None: dialect._assert_and_set_isolation_level(dbapi_conn, value) - def get_characteristic(self, dialect, dbapi_conn): + def get_characteristic( + self, dialect: Dialect, dbapi_conn: DBAPIConnection + ) -> Any: return dialect.get_isolation_level(dbapi_conn) diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py index ac3d6a2d8..cb5219396 100644 --- a/lib/sqlalchemy/engine/create.py +++ b/lib/sqlalchemy/engine/create.py @@ -7,7 +7,12 @@ from __future__ import annotations +import inspect +import typing from typing import Any +from typing import cast +from typing import Dict +from typing import Optional from typing import Union from . import base @@ -21,6 +26,9 @@ from ..pool import _AdhocProxiedConnection from ..pool import ConnectionPoolEntry from ..sql import compiler +if typing.TYPE_CHECKING: + from .base import Engine + @util.deprecated_params( strategy=( @@ -46,7 +54,7 @@ from ..sql import compiler "is deprecated and will be removed in a future release. ", ), ) -def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> "base.Engine": +def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> Engine: """Create a new :class:`_engine.Engine` instance. The standard calling form is to send the :ref:`URL <database_urls>` as the @@ -452,7 +460,8 @@ def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> "base.Engine": if "strategy" in kwargs: strat = kwargs.pop("strategy") if strat == "mock": - return create_mock_engine(url, **kwargs) + # this case is deprecated + return create_mock_engine(url, **kwargs) # type: ignore else: raise exc.ArgumentError("unknown strategy: %r" % strat) @@ -472,14 +481,14 @@ def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> "base.Engine": if kwargs.pop("_coerce_config", False): - def pop_kwarg(key, default=None): + def pop_kwarg(key: str, default: Optional[Any] = None) -> Any: value = kwargs.pop(key, default) if key in dialect_cls.engine_config_types: value = dialect_cls.engine_config_types[key](value) return value else: - pop_kwarg = kwargs.pop + pop_kwarg = kwargs.pop # type: ignore dialect_args = {} # consume dialect arguments from kwargs @@ -490,10 +499,29 @@ def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> "base.Engine": dbapi = kwargs.pop("module", None) if dbapi is None: dbapi_args = {} - for k in util.get_func_kwargs(dialect_cls.dbapi): + + if "import_dbapi" in dialect_cls.__dict__: + dbapi_meth = dialect_cls.import_dbapi + + elif hasattr(dialect_cls, "dbapi") and inspect.ismethod( + dialect_cls.dbapi + ): + util.warn_deprecated( + "The dbapi() classmethod on dialect classes has been " + "renamed to import_dbapi(). Implement an import_dbapi() " + f"classmethod directly on class {dialect_cls} to remove this " + "warning; the old .dbapi() classmethod may be maintained for " + "backwards compatibility.", + "2.0", + ) + dbapi_meth = dialect_cls.dbapi + else: + dbapi_meth = dialect_cls.import_dbapi + + for k in util.get_func_kwargs(dbapi_meth): if k in kwargs: dbapi_args[k] = pop_kwarg(k) - dbapi = dialect_cls.dbapi(**dbapi_args) + dbapi = dbapi_meth(**dbapi_args) dialect_args["dbapi"] = dbapi @@ -509,18 +537,23 @@ def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> "base.Engine": dialect = dialect_cls(**dialect_args) # assemble connection arguments - (cargs, cparams) = dialect.create_connect_args(u) + (cargs_tup, cparams) = dialect.create_connect_args(u) cparams.update(pop_kwarg("connect_args", {})) - cargs = list(cargs) # allow mutability + cargs = list(cargs_tup) # allow mutability # look for existing pool or create pool = pop_kwarg("pool", None) if pool is None: - def connect(connection_record=None): + def connect( + connection_record: Optional[ConnectionPoolEntry] = None, + ) -> DBAPIConnection: if dialect._has_events: for fn in dialect.dispatch.do_connect: - connection = fn(dialect, connection_record, cargs, cparams) + connection = cast( + DBAPIConnection, + fn(dialect, connection_record, cargs, cparams), + ) if connection is not None: return connection return dialect.connect(*cargs, **cparams) @@ -596,7 +629,11 @@ def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> "base.Engine": do_on_connect = dialect.on_connect_url(u) if do_on_connect: - def on_connect(dbapi_connection, connection_record): + def on_connect( + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + ) -> None: + assert do_on_connect is not None do_on_connect(dbapi_connection) event.listen(pool, "connect", on_connect) @@ -608,7 +645,7 @@ def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> "base.Engine": def first_connect( dbapi_connection: DBAPIConnection, connection_record: ConnectionPoolEntry, - ): + ) -> None: c = base.Connection( engine, connection=_AdhocProxiedConnection( @@ -654,7 +691,9 @@ def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> "base.Engine": return engine -def engine_from_config(configuration, prefix="sqlalchemy.", **kwargs): +def engine_from_config( + configuration: Dict[str, Any], prefix: str = "sqlalchemy.", **kwargs: Any +) -> Engine: """Create a new Engine instance using a configuration dictionary. The dictionary is typically produced from a config file. diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index 2b077056f..78805bac1 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -13,6 +13,17 @@ from __future__ import annotations import collections import functools +import typing +from typing import Any +from typing import cast +from typing import ClassVar +from typing import Dict +from typing import Iterator +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type from .result import Result from .result import ResultMetaData @@ -30,19 +41,43 @@ from ..sql.compiler import RM_OBJECTS from ..sql.compiler import RM_RENDERED_NAME from ..sql.compiler import RM_TYPE from ..util import compat +from ..util.typing import Literal _UNPICKLED = util.symbol("unpickled") +if typing.TYPE_CHECKING: + from .interfaces import _DBAPICursorDescription + from .interfaces import ExecutionContext + from .result import _KeyIndexType + from .result import _KeyMapRecType + from .result import _KeyMapType + from .result import _KeyType + from .result import _ProcessorsType + from .result import _ProcessorType + # metadata entry tuple indexes. # using raw tuple is faster than namedtuple. -MD_INDEX = 0 # integer index in cursor.description -MD_RESULT_MAP_INDEX = 1 # integer index in compiled._result_columns -MD_OBJECTS = 2 # other string keys and ColumnElement obj that can match -MD_LOOKUP_KEY = 3 # string key we usually expect for key-based lookup -MD_RENDERED_NAME = 4 # name that is usually in cursor.description -MD_PROCESSOR = 5 # callable to process a result value into a row -MD_UNTRANSLATED = 6 # raw name from cursor.description +MD_INDEX: Literal[0] = 0 # integer index in cursor.description +MD_RESULT_MAP_INDEX: Literal[ + 1 +] = 1 # integer index in compiled._result_columns +MD_OBJECTS: Literal[ + 2 +] = 2 # other string keys and ColumnElement obj that can match +MD_LOOKUP_KEY: Literal[ + 3 +] = 3 # string key we usually expect for key-based lookup +MD_RENDERED_NAME: Literal[4] = 4 # name that is usually in cursor.description +MD_PROCESSOR: Literal[5] = 5 # callable to process a result value into a row +MD_UNTRANSLATED: Literal[6] = 6 # raw name from cursor.description + + +_CursorKeyMapRecType = Tuple[ + int, int, List[Any], str, str, Optional["_ProcessorType"], str +] + +_CursorKeyMapType = Dict["_KeyType", _CursorKeyMapRecType] class CursorResultMetaData(ResultMetaData): @@ -61,22 +96,30 @@ class CursorResultMetaData(ResultMetaData): # if a need arises. ) - returns_rows = True + _keymap: _CursorKeyMapType + _processors: _ProcessorsType + _keymap_by_result_column_idx: Optional[Dict[int, _KeyMapRecType]] + _unpickled: bool + _safe_for_cache: bool + + returns_rows: ClassVar[bool] = True - def _has_key(self, key): + def _has_key(self, key: Any) -> bool: return key in self._keymap - def _for_freeze(self): + def _for_freeze(self) -> ResultMetaData: return SimpleResultMetaData( self._keys, extra=[self._keymap[key][MD_OBJECTS] for key in self._keys], ) - def _reduce(self, keys): - recs = list(self._metadata_for_keys(keys)) + def _reduce(self, keys: Sequence[_KeyIndexType]) -> ResultMetaData: + recs = cast( + "List[_CursorKeyMapRecType]", list(self._metadata_for_keys(keys)) + ) indexes = [rec[MD_INDEX] for rec in recs] - new_keys = [rec[MD_LOOKUP_KEY] for rec in recs] + new_keys: List[str] = [rec[MD_LOOKUP_KEY] for rec in recs] if self._translated_indexes: indexes = [self._translated_indexes[idx] for idx in indexes] @@ -104,7 +147,7 @@ class CursorResultMetaData(ResultMetaData): return new_metadata - def _adapt_to_context(self, context): + def _adapt_to_context(self, context: ExecutionContext) -> ResultMetaData: """When using a cached Compiled construct that has a _result_map, for a new statement that used the cached Compiled, we need to ensure the keymap has the Column objects from our new statement as keys. @@ -112,8 +155,7 @@ class CursorResultMetaData(ResultMetaData): as matched to those of the cached statement. """ - - if not context.compiled._result_columns: + if not context.compiled or not context.compiled._result_columns: return self compiled_statement = context.compiled.statement @@ -122,6 +164,8 @@ class CursorResultMetaData(ResultMetaData): if compiled_statement is invoked_statement: return self + assert invoked_statement is not None + # this is the most common path for Core statements when # caching is used. In ORM use, this codepath is not really used # as the _result_disable_adapt_to_context execution option is @@ -162,7 +206,9 @@ class CursorResultMetaData(ResultMetaData): md._safe_for_cache = self._safe_for_cache return md - def __init__(self, parent, cursor_description): + def __init__( + self, parent: CursorResult, cursor_description: _DBAPICursorDescription + ): context = parent.context self._tuplefilter = None self._translated_indexes = None @@ -229,7 +275,7 @@ class CursorResultMetaData(ResultMetaData): # new in 1.4: get the complete set of all possible keys, # strings, objects, whatever, that are dupes across two # different records, first. - index_by_key = {} + index_by_key: Dict[Any, Any] = {} dupes = set() for metadata_entry in raw: for key in (metadata_entry[MD_RENDERED_NAME],) + ( @@ -626,7 +672,7 @@ class CursorResultMetaData(ResultMetaData): "result set column descriptions" % rec[MD_LOOKUP_KEY] ) - def _index_for_key(self, key, raiseerr=True): + def _index_for_key(self, key: Any, raiseerr: bool = True) -> Optional[int]: # TODO: can consider pre-loading ints and negative ints # into _keymap - also no coverage here if isinstance(key, int): @@ -653,7 +699,9 @@ class CursorResultMetaData(ResultMetaData): # ensure it raises CursorResultMetaData._key_fallback(self, ke.args[0], ke) - def _metadata_for_keys(self, keys): + def _metadata_for_keys( + self, keys: Sequence[Any] + ) -> Iterator[_CursorKeyMapRecType]: for key in keys: if int in key.__class__.__mro__: key = self._keys[key] @@ -707,7 +755,7 @@ class ResultFetchStrategy: __slots__ = () - alternate_cursor_description = None + alternate_cursor_description: Optional[_DBAPICursorDescription] = None def soft_close(self, result, dbapi_cursor): raise NotImplementedError() @@ -1099,10 +1147,9 @@ _NO_RESULT_METADATA = _NoResultMetaData() class BaseCursorResult: """Base class for database result objects.""" - out_parameters = None - _metadata = None - _soft_closed = False - closed = False + _metadata: ResultMetaData + _soft_closed: bool = False + closed: bool = False def __init__(self, context, cursor_strategy, cursor_description): self.context = context @@ -1134,7 +1181,7 @@ class BaseCursorResult: keymap = metadata._keymap processors = metadata._processors - process_row = self._process_row + process_row = Row key_style = process_row._default_key_style _make_row = functools.partial( process_row, metadata, processors, keymap, key_style @@ -1644,7 +1691,7 @@ class CursorResult(BaseCursorResult, Result): """ - _cursor_metadata = CursorResultMetaData + _cursor_metadata: Type[ResultMetaData] = CursorResultMetaData _cursor_strategy_cls = CursorFetchStrategy _no_result_metadata = _NO_RESULT_METADATA _is_cursor = True @@ -1719,7 +1766,9 @@ class BufferedRowResultProxy(ResultProxy): """ - _cursor_strategy_cls = BufferedRowCursorFetchStrategy + _cursor_strategy_cls: Type[ + CursorFetchStrategy + ] = BufferedRowCursorFetchStrategy class FullyBufferedResultProxy(ResultProxy): @@ -1744,5 +1793,3 @@ class BufferedColumnResultProxy(ResultProxy): and this class does not change behavior in any way. """ - - _process_row = BufferedColumnRow diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index a4dbf2361..0e0c76389 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -19,12 +19,30 @@ import functools import random import re from time import perf_counter +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import List +from typing import Mapping +from typing import MutableMapping +from typing import MutableSequence +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type import weakref from . import characteristics from . import cursor as _cursor from . import interfaces from .base import Connection +from .interfaces import CacheStats +from .interfaces import DBAPICursor +from .interfaces import Dialect +from .interfaces import ExecutionContext from .. import event from .. import exc from .. import pool @@ -32,25 +50,49 @@ from .. import types as sqltypes from .. import util from ..sql import compiler from ..sql import expression +from ..sql.compiler import DDLCompiler +from ..sql.compiler import SQLCompiler from ..sql.elements import quoted_name +if typing.TYPE_CHECKING: + from .interfaces import _AnyMultiExecuteParams + from .interfaces import _CoreMultiExecuteParams + from .interfaces import _CoreSingleExecuteParams + from .interfaces import _DBAPIAnyExecuteParams + from .interfaces import _DBAPIMultiExecuteParams + from .interfaces import _DBAPISingleExecuteParams + from .interfaces import _ExecuteOptions + from .result import _ProcessorType + from .row import Row + from .url import URL + from ..event import _ListenerFnType + from ..pool import Pool + from ..pool import PoolProxiedConnection + from ..sql import Executable + from ..sql.compiler import Compiled + from ..sql.compiler import ResultColumnsEntry + from ..sql.schema import Column + from ..sql.type_api import TypeEngine + # When we're handed literal SQL, ensure it's a SELECT query SERVER_SIDE_CURSOR_RE = re.compile(r"\s*SELECT", re.I | re.UNICODE) -CACHE_HIT = util.symbol("CACHE_HIT") -CACHE_MISS = util.symbol("CACHE_MISS") -CACHING_DISABLED = util.symbol("CACHING_DISABLED") -NO_CACHE_KEY = util.symbol("NO_CACHE_KEY") -NO_DIALECT_SUPPORT = util.symbol("NO_DIALECT_SUPPORT") +( + CACHE_HIT, + CACHE_MISS, + CACHING_DISABLED, + NO_CACHE_KEY, + NO_DIALECT_SUPPORT, +) = list(CacheStats) -class DefaultDialect(interfaces.Dialect): +class DefaultDialect(Dialect): """Default implementation of Dialect""" statement_compiler = compiler.SQLCompiler ddl_compiler = compiler.DDLCompiler - type_compiler = compiler.GenericTypeCompiler + type_compiler = compiler.GenericTypeCompiler # type: ignore preparer = compiler.IdentifierPreparer supports_alter = True supports_comments = False @@ -61,8 +103,8 @@ class DefaultDialect(interfaces.Dialect): bind_typing = interfaces.BindTyping.NONE - include_set_input_sizes = None - exclude_set_input_sizes = None + include_set_input_sizes: Optional[Set[Any]] = None + exclude_set_input_sizes: Optional[Set[Any]] = None # the first value we'd get for an autoincrement # column. @@ -70,7 +112,7 @@ class DefaultDialect(interfaces.Dialect): # most DBAPIs happy with this for execute(). # not cx_oracle. - execute_sequence_format = tuple + execute_sequence_format = tuple # type: ignore supports_schemas = True supports_views = True @@ -97,16 +139,16 @@ class DefaultDialect(interfaces.Dialect): {"isolation_level": characteristics.IsolationLevelCharacteristic()} ) - engine_config_types = util.immutabledict( - [ - ("pool_timeout", util.asint), - ("echo", util.bool_or_str("debug")), - ("echo_pool", util.bool_or_str("debug")), - ("pool_recycle", util.asint), - ("pool_size", util.asint), - ("max_overflow", util.asint), - ("future", util.asbool), - ] + engine_config_types: Mapping[str, Any] = util.immutabledict( + { + "pool_timeout": util.asint, + "echo": util.bool_or_str("debug"), + "echo_pool": util.bool_or_str("debug"), + "pool_recycle": util.asint, + "pool_size": util.asint, + "max_overflow": util.asint, + "future": util.asbool, + } ) # if the NUMERIC type @@ -119,19 +161,21 @@ class DefaultDialect(interfaces.Dialect): # length at which to truncate # any identifier. max_identifier_length = 9999 - _user_defined_max_identifier_length = None + _user_defined_max_identifier_length: Optional[int] = None - isolation_level = None + isolation_level: Optional[str] = None # sub-categories of max_identifier_length. # currently these accommodate for MySQL which allows alias names # of 255 but DDL names only of 64. - max_index_name_length = None - max_constraint_name_length = None + max_index_name_length: Optional[int] = None + max_constraint_name_length: Optional[int] = None supports_sane_rowcount = True supports_sane_multi_rowcount = True - colspecs = {} + colspecs: MutableMapping[ + Type["TypeEngine[Any]"], Type["TypeEngine[Any]"] + ] = {} default_paramstyle = "named" supports_default_values = False @@ -160,43 +204,6 @@ class DefaultDialect(interfaces.Dialect): default_schema_name = None - construct_arguments = None - """Optional set of argument specifiers for various SQLAlchemy - constructs, typically schema items. - - To implement, establish as a series of tuples, as in:: - - construct_arguments = [ - (schema.Index, { - "using": False, - "where": None, - "ops": None - }) - ] - - If the above construct is established on the PostgreSQL dialect, - the :class:`.Index` construct will now accept the keyword arguments - ``postgresql_using``, ``postgresql_where``, nad ``postgresql_ops``. - Any other argument specified to the constructor of :class:`.Index` - which is prefixed with ``postgresql_`` will raise :class:`.ArgumentError`. - - A dialect which does not include a ``construct_arguments`` member will - not participate in the argument validation system. For such a dialect, - any argument name is accepted by all participating constructs, within - the namespace of arguments prefixed with that dialect name. The rationale - here is so that third-party dialects that haven't yet implemented this - feature continue to function in the old way. - - .. versionadded:: 0.9.2 - - .. seealso:: - - :class:`.DialectKWArgs` - implementing base class which consumes - :attr:`.DefaultDialect.construct_arguments` - - - """ - # indicates symbol names are # UPPERCASEd if they are case insensitive # within the database. @@ -204,17 +211,6 @@ class DefaultDialect(interfaces.Dialect): # and denormalize_name() must be provided. requires_name_normalize = False - reflection_options = () - - dbapi_exception_translation_map = util.immutabledict() - """mapping used in the extremely unusual case that a DBAPI's - published exceptions don't actually have the __name__ that they - are linked towards. - - .. versionadded:: 1.0.5 - - """ - is_async = False CACHE_HIT = CACHE_HIT @@ -363,10 +359,10 @@ class DefaultDialect(interfaces.Dialect): return self.supports_sane_rowcount @classmethod - def get_pool_class(cls, url): + def get_pool_class(cls, url: URL) -> Type[Pool]: return getattr(cls, "poolclass", pool.QueuePool) - def get_dialect_pool_class(self, url): + def get_dialect_pool_class(self, url: URL) -> Type[Pool]: return self.get_pool_class(url) @classmethod @@ -377,7 +373,7 @@ class DefaultDialect(interfaces.Dialect): except ImportError: pass - def _builtin_onconnect(self): + def _builtin_onconnect(self) -> Optional[_ListenerFnType]: if self._on_connect_isolation_level is not None: def builtin_connect(dbapi_conn, conn_rec): @@ -734,7 +730,7 @@ class StrCompileDialect(DefaultDialect): statement_compiler = compiler.StrSQLCompiler ddl_compiler = compiler.DDLCompiler - type_compiler = compiler.StrSQLTypeCompiler + type_compiler = compiler.StrSQLTypeCompiler # type: ignore preparer = compiler.IdentifierPreparer supports_statement_cache = True @@ -758,24 +754,26 @@ class StrCompileDialect(DefaultDialect): } -class DefaultExecutionContext(interfaces.ExecutionContext): +class DefaultExecutionContext(ExecutionContext): isinsert = False isupdate = False isdelete = False is_crud = False is_text = False isddl = False + executemany = False - compiled = None - statement = None - result_column_struct = None - returned_default_rows = None - execution_options = util.immutabledict() + compiled: Optional[Compiled] = None + result_column_struct: Optional[ + Tuple[List[ResultColumnsEntry], bool, bool, bool] + ] = None + returned_default_rows: Optional[List[Row]] = None + + execution_options: _ExecuteOptions = util.EMPTY_DICT cursor_fetch_strategy = _cursor._DEFAULT_FETCH - cache_stats = None - invoked_statement = None + invoked_statement: Optional[Executable] = None _is_implicit_returning = False _is_explicit_returning = False @@ -786,21 +784,37 @@ class DefaultExecutionContext(interfaces.ExecutionContext): # a hook for SQLite's translation of # result column names # NOTE: pyhive is using this hook, can't remove it :( - _translate_colname = None + _translate_colname: Optional[Callable[[str], str]] = None + + _expanded_parameters: Mapping[str, List[str]] = util.immutabledict() + """used by set_input_sizes(). + + This collection comes from ``ExpandedState.parameter_expansion``. - _expanded_parameters = util.immutabledict() + """ cache_hit = NO_CACHE_KEY + root_connection: Connection + _dbapi_connection: PoolProxiedConnection + dialect: Dialect + unicode_statement: str + cursor: DBAPICursor + compiled_parameters: _CoreMultiExecuteParams + parameters: _DBAPIMultiExecuteParams + extracted_parameters: _CoreSingleExecuteParams + + _empty_dict_params = cast("Mapping[str, Any]", util.EMPTY_DICT) + @classmethod def _init_ddl( cls, - dialect, - connection, - dbapi_connection, - execution_options, - compiled_ddl, - ): + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + compiled_ddl: DDLCompiler, + ) -> ExecutionContext: """Initialize execution context for a DDLElement construct.""" self = cls.__new__(cls) @@ -832,23 +846,23 @@ class DefaultExecutionContext(interfaces.ExecutionContext): if dialect.positional: self.parameters = [dialect.execute_sequence_format()] else: - self.parameters = [{}] + self.parameters = [self._empty_dict_params] return self @classmethod def _init_compiled( cls, - dialect, - connection, - dbapi_connection, - execution_options, - compiled, - parameters, - invoked_statement, - extracted_parameters, - cache_hit=CACHING_DISABLED, - ): + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + compiled: SQLCompiler, + parameters: _CoreMultiExecuteParams, + invoked_statement: Executable, + extracted_parameters: _CoreSingleExecuteParams, + cache_hit: CacheStats = CacheStats.CACHING_DISABLED, + ) -> ExecutionContext: """Initialize execution context for a Compiled construct.""" self = cls.__new__(cls) @@ -868,6 +882,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): compiled._textual_ordered_columns, compiled._loose_column_name_matching, ) + self.isinsert = compiled.isinsert self.isupdate = compiled.isupdate self.isdelete = compiled.isdelete @@ -910,6 +925,10 @@ class DefaultExecutionContext(interfaces.ExecutionContext): processors = compiled._bind_processors + flattened_processors: Mapping[ + str, _ProcessorType + ] = processors # type: ignore[assignment] + if compiled.literal_execute_params or compiled.post_compile_params: if self.executemany: raise exc.InvalidRequestError( @@ -924,14 +943,15 @@ class DefaultExecutionContext(interfaces.ExecutionContext): # re-assign self.unicode_statement self.unicode_statement = expanded_state.statement - # used by set_input_sizes() which is needed for Oracle self._expanded_parameters = expanded_state.parameter_expansion - processors = dict(processors) - processors.update(expanded_state.processors) + flattened_processors = dict(processors) # type: ignore + flattened_processors.update(expanded_state.processors) positiontup = expanded_state.positiontup elif compiled.positional: positiontup = self.compiled.positiontup + else: + positiontup = None if compiled.schema_translate_map: schema_translate_map = self.execution_options.get( @@ -949,42 +969,49 @@ class DefaultExecutionContext(interfaces.ExecutionContext): # Convert the dictionary of bind parameter values # into a dict or list to be sent to the DBAPI's # execute() or executemany() method. - parameters = [] + if compiled.positional: + core_positional_parameters: MutableSequence[Sequence[Any]] = [] + assert positiontup is not None for compiled_params in self.compiled_parameters: - param = [ - processors[key](compiled_params[key]) - if key in processors + l_param: List[Any] = [ + flattened_processors[key](compiled_params[key]) + if key in flattened_processors else compiled_params[key] for key in positiontup ] - parameters.append(dialect.execute_sequence_format(param)) + core_positional_parameters.append( + dialect.execute_sequence_format(l_param) + ) + + self.parameters = core_positional_parameters else: + core_dict_parameters: MutableSequence[Dict[str, Any]] = [] for compiled_params in self.compiled_parameters: - param = { - key: processors[key](compiled_params[key]) - if key in processors + d_param: Dict[str, Any] = { + key: flattened_processors[key](compiled_params[key]) + if key in flattened_processors else compiled_params[key] for key in compiled_params } - parameters.append(param) + core_dict_parameters.append(d_param) - self.parameters = dialect.execute_sequence_format(parameters) + self.parameters = core_dict_parameters return self @classmethod def _init_statement( cls, - dialect, - connection, - dbapi_connection, - execution_options, - statement, - parameters, - ): + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + statement: str, + parameters: _DBAPIMultiExecuteParams, + ) -> ExecutionContext: """Initialize execution context for a string SQL statement.""" self = cls.__new__(cls) @@ -999,7 +1026,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): if self.dialect.positional: self.parameters = [dialect.execute_sequence_format()] else: - self.parameters = [{}] + self.parameters = [self._empty_dict_params] elif isinstance(parameters[0], dialect.execute_sequence_format): self.parameters = parameters elif isinstance(parameters[0], dict): @@ -1018,8 +1045,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext): @classmethod def _init_default( - cls, dialect, connection, dbapi_connection, execution_options - ): + cls, + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + ) -> ExecutionContext: """Initialize execution context for a ColumnDefault construct.""" self = cls.__new__(cls) @@ -1032,7 +1063,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): self.cursor = self.create_cursor() return self - def _get_cache_stats(self): + def _get_cache_stats(self) -> str: if self.compiled is None: return "raw sql" @@ -1040,19 +1071,22 @@ class DefaultExecutionContext(interfaces.ExecutionContext): ch = self.cache_hit + gen_time = self.compiled._gen_time + assert gen_time is not None + if ch is NO_CACHE_KEY: - return "no key %.5fs" % (now - self.compiled._gen_time,) + return "no key %.5fs" % (now - gen_time,) elif ch is CACHE_HIT: - return "cached since %.4gs ago" % (now - self.compiled._gen_time,) + return "cached since %.4gs ago" % (now - gen_time,) elif ch is CACHE_MISS: - return "generated in %.5fs" % (now - self.compiled._gen_time,) + return "generated in %.5fs" % (now - gen_time,) elif ch is CACHING_DISABLED: - return "caching disabled %.5fs" % (now - self.compiled._gen_time,) + return "caching disabled %.5fs" % (now - gen_time,) elif ch is NO_DIALECT_SUPPORT: return "dialect %s+%s does not support caching %.5fs" % ( self.dialect.name, self.dialect.driver, - now - self.compiled._gen_time, + now - gen_time, ) else: return "unknown" @@ -1073,11 +1107,13 @@ class DefaultExecutionContext(interfaces.ExecutionContext): return self.root_connection.engine @util.memoized_property - def postfetch_cols(self): + def postfetch_cols(self) -> Optional[Sequence[Column[Any]]]: # type: ignore[override] # mypy#4125 # noqa E501 + assert isinstance(self.compiled, SQLCompiler) return self.compiled.postfetch @util.memoized_property - def prefetch_cols(self): + def prefetch_cols(self) -> Optional[Sequence[Column[Any]]]: # type: ignore[override] # mypy#4125 # noqa E501 + assert isinstance(self.compiled, SQLCompiler) if self.isinsert: return self.compiled.insert_prefetch elif self.isupdate: @@ -1086,8 +1122,9 @@ class DefaultExecutionContext(interfaces.ExecutionContext): return () @util.memoized_property - def returning_cols(self): - self.compiled.returning + def returning_cols(self) -> Optional[Sequence[Column[Any]]]: + assert isinstance(self.compiled, SQLCompiler) + return self.compiled.returning @util.memoized_property def no_parameters(self): @@ -1564,7 +1601,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): str(compiled), type_, parameters=parameters ) - current_parameters = None + current_parameters: Optional[_CoreSingleExecuteParams] = None """A dictionary of parameters applied to the current row. This attribute is only available in the context of a user-defined default diff --git a/lib/sqlalchemy/engine/events.py b/lib/sqlalchemy/engine/events.py index ab462bbe1..0cbf56a6d 100644 --- a/lib/sqlalchemy/engine/events.py +++ b/lib/sqlalchemy/engine/events.py @@ -8,14 +8,41 @@ from __future__ import annotations +import typing +from typing import Any +from typing import Dict +from typing import Optional +from typing import Tuple +from typing import Type +from typing import Union + from .base import Engine from .interfaces import ConnectionEventsTarget +from .interfaces import DBAPIConnection +from .interfaces import DBAPICursor from .interfaces import Dialect from .. import event from .. import exc - - -class ConnectionEvents(event.Events): +from ..util.typing import Literal + +if typing.TYPE_CHECKING: + from .base import Connection + from .interfaces import _CoreAnyExecuteParams + from .interfaces import _CoreMultiExecuteParams + from .interfaces import _CoreSingleExecuteParams + from .interfaces import _DBAPIAnyExecuteParams + from .interfaces import _DBAPIMultiExecuteParams + from .interfaces import _DBAPISingleExecuteParams + from .interfaces import _ExecuteOptions + from .interfaces import ExceptionContext + from .interfaces import ExecutionContext + from .result import Result + from ..pool import ConnectionPoolEntry + from ..sql import Executable + from ..sql.elements import BindParameter + + +class ConnectionEvents(event.Events[ConnectionEventsTarget]): """Available events for :class:`_engine.Connection` and :class:`_engine.Engine`. @@ -96,7 +123,12 @@ class ConnectionEvents(event.Events): _dispatch_target = ConnectionEventsTarget @classmethod - def _listen(cls, event_key, retval=False): + def _listen( # type: ignore[override] + cls, + event_key: event._EventKey[ConnectionEventsTarget], + retval: bool = False, + **kw: Any, + ) -> None: target, identifier, fn = ( event_key.dispatch_target, event_key.identifier, @@ -109,7 +141,7 @@ class ConnectionEvents(event.Events): if identifier == "before_execute": orig_fn = fn - def wrap_before_execute( + def wrap_before_execute( # type: ignore conn, clauseelement, multiparams, params, execution_options ): orig_fn( @@ -125,7 +157,7 @@ class ConnectionEvents(event.Events): elif identifier == "before_cursor_execute": orig_fn = fn - def wrap_before_cursor_execute( + def wrap_before_cursor_execute( # type: ignore conn, cursor, statement, parameters, context, executemany ): orig_fn( @@ -163,8 +195,15 @@ class ConnectionEvents(event.Events): ), ) def before_execute( - self, conn, clauseelement, multiparams, params, execution_options - ): + self, + conn: Connection, + clauseelement: Executable, + multiparams: _CoreMultiExecuteParams, + params: _CoreSingleExecuteParams, + execution_options: _ExecuteOptions, + ) -> Optional[ + Tuple[Executable, _CoreMultiExecuteParams, _CoreSingleExecuteParams] + ]: """Intercept high level execute() events, receiving uncompiled SQL constructs and other objects prior to rendering into SQL. @@ -214,13 +253,13 @@ class ConnectionEvents(event.Events): ) def after_execute( self, - conn, - clauseelement, - multiparams, - params, - execution_options, - result, - ): + conn: Connection, + clauseelement: Executable, + multiparams: _CoreMultiExecuteParams, + params: _CoreSingleExecuteParams, + execution_options: _ExecuteOptions, + result: Result, + ) -> None: """Intercept high level execute() events after execute. @@ -244,8 +283,14 @@ class ConnectionEvents(event.Events): """ def before_cursor_execute( - self, conn, cursor, statement, parameters, context, executemany - ): + self, + conn: Connection, + cursor: DBAPICursor, + statement: str, + parameters: _DBAPIAnyExecuteParams, + context: Optional[ExecutionContext], + executemany: bool, + ) -> Optional[Tuple[str, _DBAPIAnyExecuteParams]]: """Intercept low-level cursor execute() events before execution, receiving the string SQL statement and DBAPI-specific parameter list to be invoked against a cursor. @@ -286,8 +331,14 @@ class ConnectionEvents(event.Events): """ def after_cursor_execute( - self, conn, cursor, statement, parameters, context, executemany - ): + self, + conn: Connection, + cursor: DBAPICursor, + statement: str, + parameters: _DBAPIAnyExecuteParams, + context: Optional[ExecutionContext], + executemany: bool, + ) -> None: """Intercept low-level cursor execute() events after execution. :param conn: :class:`_engine.Connection` object @@ -305,7 +356,9 @@ class ConnectionEvents(event.Events): """ - def handle_error(self, exception_context): + def handle_error( + self, exception_context: ExceptionContext + ) -> Optional[BaseException]: r"""Intercept all exceptions processed by the :class:`_engine.Connection`. @@ -439,7 +492,7 @@ class ConnectionEvents(event.Events): @event._legacy_signature( "2.0", ["conn", "branch"], converter=lambda conn: (conn, False) ) - def engine_connect(self, conn): + def engine_connect(self, conn: Connection) -> None: """Intercept the creation of a new :class:`_engine.Connection`. This event is called typically as the direct result of calling @@ -475,7 +528,9 @@ class ConnectionEvents(event.Events): """ - def set_connection_execution_options(self, conn, opts): + def set_connection_execution_options( + self, conn: Connection, opts: Dict[str, Any] + ) -> None: """Intercept when the :meth:`_engine.Connection.execution_options` method is called. @@ -494,8 +549,12 @@ class ConnectionEvents(event.Events): :param opts: dictionary of options that were passed to the :meth:`_engine.Connection.execution_options` method. + This dictionary may be modified in place to affect the ultimate + options which take effect. + + .. versionadded:: 2.0 the ``opts`` dictionary may be modified + in place. - .. versionadded:: 0.9.0 .. seealso:: @@ -507,7 +566,9 @@ class ConnectionEvents(event.Events): """ - def set_engine_execution_options(self, engine, opts): + def set_engine_execution_options( + self, engine: Engine, opts: Dict[str, Any] + ) -> None: """Intercept when the :meth:`_engine.Engine.execution_options` method is called. @@ -526,8 +587,11 @@ class ConnectionEvents(event.Events): :param opts: dictionary of options that were passed to the :meth:`_engine.Connection.execution_options` method. + This dictionary may be modified in place to affect the ultimate + options which take effect. - .. versionadded:: 0.9.0 + .. versionadded:: 2.0 the ``opts`` dictionary may be modified + in place. .. seealso:: @@ -539,7 +603,7 @@ class ConnectionEvents(event.Events): """ - def engine_disposed(self, engine): + def engine_disposed(self, engine: Engine) -> None: """Intercept when the :meth:`_engine.Engine.dispose` method is called. The :meth:`_engine.Engine.dispose` method instructs the engine to @@ -559,14 +623,14 @@ class ConnectionEvents(event.Events): """ - def begin(self, conn): + def begin(self, conn: Connection) -> None: """Intercept begin() events. :param conn: :class:`_engine.Connection` object """ - def rollback(self, conn): + def rollback(self, conn: Connection) -> None: """Intercept rollback() events, as initiated by a :class:`.Transaction`. @@ -584,7 +648,7 @@ class ConnectionEvents(event.Events): """ - def commit(self, conn): + def commit(self, conn: Connection) -> None: """Intercept commit() events, as initiated by a :class:`.Transaction`. @@ -596,7 +660,7 @@ class ConnectionEvents(event.Events): :param conn: :class:`_engine.Connection` object """ - def savepoint(self, conn, name): + def savepoint(self, conn: Connection, name: str) -> None: """Intercept savepoint() events. :param conn: :class:`_engine.Connection` object @@ -604,7 +668,9 @@ class ConnectionEvents(event.Events): """ - def rollback_savepoint(self, conn, name, context): + def rollback_savepoint( + self, conn: Connection, name: str, context: None + ) -> None: """Intercept rollback_savepoint() events. :param conn: :class:`_engine.Connection` object @@ -614,7 +680,9 @@ class ConnectionEvents(event.Events): """ # TODO: deprecate "context" - def release_savepoint(self, conn, name, context): + def release_savepoint( + self, conn: Connection, name: str, context: None + ) -> None: """Intercept release_savepoint() events. :param conn: :class:`_engine.Connection` object @@ -624,7 +692,7 @@ class ConnectionEvents(event.Events): """ # TODO: deprecate "context" - def begin_twophase(self, conn, xid): + def begin_twophase(self, conn: Connection, xid: Any) -> None: """Intercept begin_twophase() events. :param conn: :class:`_engine.Connection` object @@ -632,14 +700,16 @@ class ConnectionEvents(event.Events): """ - def prepare_twophase(self, conn, xid): + def prepare_twophase(self, conn: Connection, xid: Any) -> None: """Intercept prepare_twophase() events. :param conn: :class:`_engine.Connection` object :param xid: two-phase XID identifier """ - def rollback_twophase(self, conn, xid, is_prepared): + def rollback_twophase( + self, conn: Connection, xid: Any, is_prepared: bool + ) -> None: """Intercept rollback_twophase() events. :param conn: :class:`_engine.Connection` object @@ -649,7 +719,9 @@ class ConnectionEvents(event.Events): """ - def commit_twophase(self, conn, xid, is_prepared): + def commit_twophase( + self, conn: Connection, xid: Any, is_prepared: bool + ) -> None: """Intercept commit_twophase() events. :param conn: :class:`_engine.Connection` object @@ -660,7 +732,7 @@ class ConnectionEvents(event.Events): """ -class DialectEvents(event.Events): +class DialectEvents(event.Events[Dialect]): """event interface for execution-replacement functions. These events allow direct instrumentation and replacement @@ -694,14 +766,20 @@ class DialectEvents(event.Events): _dispatch_target = Dialect @classmethod - def _listen(cls, event_key, retval=False): + def _listen( # type: ignore + cls, + event_key: event._EventKey[Dialect], + retval: bool = False, + ) -> None: target = event_key.dispatch_target target._has_events = True event_key.base_listen() @classmethod - def _accept_with(cls, target): + def _accept_with( + cls, target: Union[Engine, Type[Engine], Dialect, Type[Dialect]] + ) -> Union[Dialect, Type[Dialect]]: if isinstance(target, type): if issubclass(target, Engine): return Dialect @@ -712,7 +790,13 @@ class DialectEvents(event.Events): else: return target - def do_connect(self, dialect, conn_rec, cargs, cparams): + def do_connect( + self, + dialect: Dialect, + conn_rec: ConnectionPoolEntry, + cargs: Tuple[Any, ...], + cparams: Dict[str, Any], + ) -> Optional[DBAPIConnection]: """Receive connection arguments before a connection is made. This event is useful in that it allows the handler to manipulate the @@ -745,7 +829,13 @@ class DialectEvents(event.Events): """ - def do_executemany(self, cursor, statement, parameters, context): + def do_executemany( + self, + cursor: DBAPICursor, + statement: str, + parameters: _DBAPIMultiExecuteParams, + context: ExecutionContext, + ) -> Optional[Literal[True]]: """Receive a cursor to have executemany() called. Return the value True to halt further events from invoking, @@ -754,7 +844,9 @@ class DialectEvents(event.Events): """ - def do_execute_no_params(self, cursor, statement, context): + def do_execute_no_params( + self, cursor: DBAPICursor, statement: str, context: ExecutionContext + ) -> Optional[Literal[True]]: """Receive a cursor to have execute() with no parameters called. Return the value True to halt further events from invoking, @@ -763,7 +855,13 @@ class DialectEvents(event.Events): """ - def do_execute(self, cursor, statement, parameters, context): + def do_execute( + self, + cursor: DBAPICursor, + statement: str, + parameters: _DBAPISingleExecuteParams, + context: ExecutionContext, + ) -> Optional[Literal[True]]: """Receive a cursor to have execute() called. Return the value True to halt further events from invoking, @@ -773,8 +871,13 @@ class DialectEvents(event.Events): """ def do_setinputsizes( - self, inputsizes, cursor, statement, parameters, context - ): + self, + inputsizes: Dict[BindParameter[Any], Any], + cursor: DBAPICursor, + statement: str, + parameters: _DBAPIAnyExecuteParams, + context: ExecutionContext, + ) -> None: """Receive the setinputsizes dictionary for possible modification. This event is emitted in the case where the dialect makes use of the diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 860c1faf9..545dd0ddc 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -10,21 +10,31 @@ from __future__ import annotations from enum import Enum +from types import ModuleType from typing import Any +from typing import Awaitable from typing import Callable from typing import Dict from typing import List from typing import Mapping +from typing import MutableMapping from typing import Optional from typing import Sequence from typing import Tuple from typing import Type from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union +from .. import util +from ..event import EventTarget +from ..pool import Pool from ..pool import PoolProxiedConnection +from ..sql.compiler import Compiled as Compiled from ..sql.compiler import Compiled # noqa +from ..sql.compiler import TypeCompiler as TypeCompiler from ..sql.compiler import TypeCompiler # noqa +from ..util import immutabledict from ..util.concurrency import await_only from ..util.typing import _TypeToInstance from ..util.typing import NotRequired @@ -34,12 +44,33 @@ from ..util.typing import TypedDict if TYPE_CHECKING: from .base import Connection from .base import Engine + from .result import Result from .url import URL + from ..event import _ListenerFnType + from ..event import dispatcher + from ..exc import StatementError + from ..sql import Executable from ..sql.compiler import DDLCompiler from ..sql.compiler import IdentifierPreparer + from ..sql.compiler import Linting from ..sql.compiler import SQLCompiler + from ..sql.elements import ClauseElement + from ..sql.schema import Column + from ..sql.schema import ColumnDefault from ..sql.type_api import TypeEngine +ConnectArgsType = Tuple[Tuple[str], MutableMapping[str, Any]] + +_T = TypeVar("_T", bound="Any") + + +class CacheStats(Enum): + CACHE_HIT = 0 + CACHE_MISS = 1 + CACHING_DISABLED = 2 + NO_CACHE_KEY = 3 + NO_DIALECT_SUPPORT = 4 + class DBAPIConnection(Protocol): """protocol representing a :pep:`249` database connection. @@ -65,6 +96,8 @@ class DBAPIConnection(Protocol): def rollback(self) -> None: ... + autocommit: bool + class DBAPIType(Protocol): """protocol representing a :pep:`249` database type. @@ -128,14 +161,14 @@ class DBAPICursor(Protocol): def execute( self, operation: Any, - parameters: Optional[Union[Sequence[Any], Mapping[str, Any]]], + parameters: Optional[_DBAPISingleExecuteParams], ) -> Any: ... def executemany( self, operation: Any, - parameters: Sequence[Union[Sequence[Any], Mapping[str, Any]]], + parameters: Sequence[_DBAPIMultiExecuteParams], ) -> Any: ... @@ -161,6 +194,34 @@ class DBAPICursor(Protocol): ... +_CoreSingleExecuteParams = Mapping[str, Any] +_CoreMultiExecuteParams = Sequence[_CoreSingleExecuteParams] +_CoreAnyExecuteParams = Union[ + _CoreMultiExecuteParams, _CoreSingleExecuteParams +] + +_DBAPISingleExecuteParams = Union[Sequence[Any], _CoreSingleExecuteParams] + +_DBAPIMultiExecuteParams = Union[ + Sequence[Sequence[Any]], _CoreMultiExecuteParams +] +_DBAPIAnyExecuteParams = Union[ + _DBAPIMultiExecuteParams, _DBAPISingleExecuteParams +] +_DBAPICursorDescription = Tuple[str, Any, Any, Any, Any, Any, Any] + +_AnySingleExecuteParams = _DBAPISingleExecuteParams +_AnyMultiExecuteParams = _DBAPIMultiExecuteParams +_AnyExecuteParams = _DBAPIAnyExecuteParams + + +_ExecuteOptions = immutabledict[str, Any] +_ExecuteOptionsParameter = Mapping[str, Any] +_SchemaTranslateMapType = Mapping[str, str] + +_ImmutableExecuteOptions = immutabledict[str, Any] + + class ReflectedIdentity(TypedDict): """represent the reflected IDENTITY structure of a column, corresponding to the :class:`_schema.Identity` construct. @@ -237,7 +298,7 @@ class ReflectedColumn(TypedDict): name: str """column name""" - type: "TypeEngine" + type: TypeEngine[Any] """column type represented as a :class:`.TypeEngine` instance.""" nullable: bool @@ -465,7 +526,10 @@ class BindTyping(Enum): """ -class Dialect: +VersionInfoType = Tuple[Union[int, str], ...] + + +class Dialect(EventTarget): """Define the behavior of a specific database and DB-API combination. Any aspect of metadata definition, SQL query generation, @@ -481,6 +545,8 @@ class Dialect: """ + dispatch: dispatcher[Dialect] + name: str """identifying name for the dialect from a DBAPI-neutral point of view (i.e. 'sqlite') @@ -489,6 +555,29 @@ class Dialect: driver: str """identifying name for the dialect's DBAPI""" + dbapi: ModuleType + """A reference to the DBAPI module object itself. + + SQLAlchemy dialects import DBAPI modules using the classmethod + :meth:`.Dialect.import_dbapi`. The rationale is so that any dialect + module can be imported and used to generate SQL statements without the + need for the actual DBAPI driver to be installed. Only when an + :class:`.Engine` is constructed using :func:`.create_engine` does the + DBAPI get imported; at that point, the creation process will assign + the DBAPI module to this attribute. + + Dialects should therefore implement :meth:`.Dialect.import_dbapi` + which will import the necessary module and return it, and then refer + to ``self.dbapi`` in dialect code in order to refer to the DBAPI module + contents. + + .. versionchanged:: The :attr:`.Dialect.dbapi` attribute is exclusively + used as the per-:class:`.Dialect`-instance reference to the DBAPI + module. The previous not-fully-documented ``.Dialect.dbapi()`` + classmethod is deprecated and replaced by :meth:`.Dialect.import_dbapi`. + + """ + positional: bool """True if the paramstyle for this Dialect is positional.""" @@ -497,21 +586,23 @@ class Dialect: paramstyles). """ - statement_compiler: Type["SQLCompiler"] + compiler_linting: Linting + + statement_compiler: Type[SQLCompiler] """a :class:`.Compiled` class used to compile SQL statements""" - ddl_compiler: Type["DDLCompiler"] + ddl_compiler: Type[DDLCompiler] """a :class:`.Compiled` class used to compile DDL statements""" - type_compiler: _TypeToInstance["TypeCompiler"] + type_compiler: _TypeToInstance[TypeCompiler] """a :class:`.Compiled` class used to compile SQL type objects""" - preparer: Type["IdentifierPreparer"] + preparer: Type[IdentifierPreparer] """a :class:`.IdentifierPreparer` class used to quote identifiers. """ - identifier_preparer: "IdentifierPreparer" + identifier_preparer: IdentifierPreparer """This element will refer to an instance of :class:`.IdentifierPreparer` once a :class:`.DefaultDialect` has been constructed. @@ -531,10 +622,15 @@ class Dialect: """ + default_isolation_level: str + """the isolation that is implicitly present on new connections""" + execution_ctx_cls: Type["ExecutionContext"] """a :class:`.ExecutionContext` class used to handle statement execution""" - execute_sequence_format: Union[Type[Tuple[Any, ...]], Type[List[Any]]] + execute_sequence_format: Union[ + Type[Tuple[Any, ...]], Type[Tuple[List[Any]]] + ] """either the 'tuple' or 'list' type, depending on what cursor.execute() accepts for the second argument (they vary).""" @@ -579,7 +675,7 @@ class Dialect: """ - colspecs: Dict[Type["TypeEngine[Any]"], Type["TypeEngine[Any]"]] + colspecs: MutableMapping[Type["TypeEngine[Any]"], Type["TypeEngine[Any]"]] """A dictionary of TypeEngine classes from sqlalchemy.types mapped to subclasses that are specific to the dialect class. This dictionary is class-level only and is not accessed from the @@ -610,7 +706,55 @@ class Dialect: constraint when that type is used. """ - dbapi_exception_translation_map: Dict[str, str] + construct_arguments: Optional[ + List[Tuple[Type[ClauseElement], Mapping[str, Any]]] + ] = None + """Optional set of argument specifiers for various SQLAlchemy + constructs, typically schema items. + + To implement, establish as a series of tuples, as in:: + + construct_arguments = [ + (schema.Index, { + "using": False, + "where": None, + "ops": None + }) + ] + + If the above construct is established on the PostgreSQL dialect, + the :class:`.Index` construct will now accept the keyword arguments + ``postgresql_using``, ``postgresql_where``, nad ``postgresql_ops``. + Any other argument specified to the constructor of :class:`.Index` + which is prefixed with ``postgresql_`` will raise :class:`.ArgumentError`. + + A dialect which does not include a ``construct_arguments`` member will + not participate in the argument validation system. For such a dialect, + any argument name is accepted by all participating constructs, within + the namespace of arguments prefixed with that dialect name. The rationale + here is so that third-party dialects that haven't yet implemented this + feature continue to function in the old way. + + .. versionadded:: 0.9.2 + + .. seealso:: + + :class:`.DialectKWArgs` - implementing base class which consumes + :attr:`.DefaultDialect.construct_arguments` + + + """ + + reflection_options: Sequence[str] = () + """Sequence of string names indicating keyword arguments that can be + established on a :class:`.Table` object which will be passed as + "reflection options" when using :paramref:`.Table.autoload_with`. + + Current example is "oracle_resolve_synonyms" in the Oracle dialect. + + """ + + dbapi_exception_translation_map: Mapping[str, str] = util.EMPTY_DICT """A dictionary of names that will contain as values the names of pep-249 exceptions ("IntegrityError", "OperationalError", etc) keyed to alternate class names, to support the case where a @@ -660,9 +804,16 @@ class Dialect: is_async: bool """Whether or not this dialect is intended for asyncio use.""" - def create_connect_args( - self, url: "URL" - ) -> Tuple[Tuple[str], Mapping[str, Any]]: + engine_config_types: Mapping[str, Any] + """a mapping of string keys that can be in an engine config linked to + type conversion functions. + + """ + + def _builtin_onconnect(self) -> Optional[_ListenerFnType]: + raise NotImplementedError() + + def create_connect_args(self, url: "URL") -> ConnectArgsType: """Build DB-API compatible connection arguments. Given a :class:`.URL` object, returns a tuple @@ -696,7 +847,25 @@ class Dialect: raise NotImplementedError() @classmethod - def type_descriptor(cls, typeobj: "TypeEngine") -> "TypeEngine": + def import_dbapi(cls) -> ModuleType: + """Import the DBAPI module that is used by this dialect. + + The Python module object returned here will be assigned as an + instance variable to a constructed dialect under the name + ``.dbapi``. + + .. versionchanged:: 2.0 The :meth:`.Dialect.import_dbapi` class + method is renamed from the previous method ``.Dialect.dbapi()``, + which would be replaced at dialect instantiation time by the + DBAPI module itself, thus using the same name in two different ways. + If a ``.Dialect.dbapi()`` classmethod is present on a third-party + dialect, it will be used and a deprecation warning will be emitted. + + """ + raise NotImplementedError() + + @classmethod + def type_descriptor(cls, typeobj: "TypeEngine[_T]") -> "TypeEngine[_T]": """Transform a generic type to a dialect-specific type. Dialect classes will usually use the @@ -735,7 +904,7 @@ class Dialect: connection: "Connection", table_name: str, schema: Optional[str] = None, - **kw, + **kw: Any, ) -> List[ReflectedColumn]: """Return information about columns in ``table_name``. @@ -908,11 +1077,12 @@ class Dialect: table_name: str, schema: Optional[str] = None, **kw: Any, - ) -> Dict[str, Any]: + ) -> Optional[Dict[str, Any]]: r"""Return the "options" for the table identified by ``table_name`` as a dictionary. """ + return None def get_table_comment( self, @@ -1115,7 +1285,7 @@ class Dialect: def do_set_input_sizes( self, cursor: DBAPICursor, - list_of_tuples: List[Tuple[str, Any, "TypeEngine"]], + list_of_tuples: List[Tuple[str, Any, TypeEngine[Any]]], context: "ExecutionContext", ) -> Any: """invoke the cursor.setinputsizes() method with appropriate arguments @@ -1242,7 +1412,7 @@ class Dialect: raise NotImplementedError() - def do_recover_twophase(self, connection: "Connection") -> None: + def do_recover_twophase(self, connection: "Connection") -> List[Any]: """Recover list of uncommitted prepared two phase transaction identifiers on the given connection. @@ -1256,7 +1426,7 @@ class Dialect: self, cursor: DBAPICursor, statement: str, - parameters: List[Union[Dict[str, Any], Tuple[Any]]], + parameters: _DBAPIMultiExecuteParams, context: Optional["ExecutionContext"] = None, ) -> None: """Provide an implementation of ``cursor.executemany(statement, @@ -1268,9 +1438,9 @@ class Dialect: self, cursor: DBAPICursor, statement: str, - parameters: Union[Mapping[str, Any], Tuple[Any]], - context: Optional["ExecutionContext"] = None, - ): + parameters: Optional[_DBAPISingleExecuteParams], + context: Optional[ExecutionContext] = None, + ) -> None: """Provide an implementation of ``cursor.execute(statement, parameters)``.""" @@ -1281,7 +1451,7 @@ class Dialect: cursor: DBAPICursor, statement: str, context: Optional["ExecutionContext"] = None, - ): + ) -> None: """Provide an implementation of ``cursor.execute(statement)``. The parameter collection should not be sent. @@ -1294,14 +1464,14 @@ class Dialect: self, e: Exception, connection: Optional[PoolProxiedConnection], - cursor: DBAPICursor, + cursor: Optional[DBAPICursor], ) -> bool: """Return True if the given DB-API error indicates an invalid connection""" raise NotImplementedError() - def connect(self, *cargs: Any, **cparams: Any) -> Any: + def connect(self, *cargs: Any, **cparams: Any) -> DBAPIConnection: r"""Establish a connection using this dialect's DBAPI. The default implementation of this method is:: @@ -1333,6 +1503,7 @@ class Dialect: :meth:`.Dialect.on_connect` """ + raise NotImplementedError() def on_connect_url(self, url: "URL") -> Optional[Callable[[Any], Any]]: """return a callable which sets up a newly created DBAPI connection. @@ -1542,7 +1713,7 @@ class Dialect: raise NotImplementedError() - def get_default_isolation_level(self, dbapi_conn: Any) -> str: + def get_default_isolation_level(self, dbapi_conn: DBAPIConnection) -> str: """Given a DBAPI connection, return its isolation level, or a default isolation level if one cannot be retrieved. @@ -1562,7 +1733,9 @@ class Dialect: """ raise NotImplementedError() - def get_isolation_level_values(self, dbapi_conn: Any) -> List[str]: + def get_isolation_level_values( + self, dbapi_conn: DBAPIConnection + ) -> List[str]: """return a sequence of string isolation level names that are accepted by this dialect. @@ -1604,8 +1777,13 @@ class Dialect: """ raise NotImplementedError() + def _assert_and_set_isolation_level( + self, dbapi_conn: DBAPIConnection, level: str + ) -> None: + raise NotImplementedError() + @classmethod - def get_dialect_cls(cls, url: "URL") -> Type: + def get_dialect_cls(cls, url: URL) -> Type[Dialect]: """Given a URL, return the :class:`.Dialect` that will be used. This is a hook that allows an external plugin to provide functionality @@ -1621,7 +1799,7 @@ class Dialect: return cls @classmethod - def get_async_dialect_cls(cls, url: "URL") -> None: + def get_async_dialect_cls(cls, url: URL) -> Type[Dialect]: """Given a URL, return the :class:`.Dialect` that will be used by an async engine. @@ -1702,6 +1880,39 @@ class Dialect: """ raise NotImplementedError() + def set_engine_execution_options( + self, engine: Engine, opt: _ExecuteOptionsParameter + ) -> None: + """Establish execution options for a given engine. + + This is implemented by :class:`.DefaultDialect` to establish + event hooks for new :class:`.Connection` instances created + by the given :class:`.Engine` which will then invoke the + :meth:`.Dialect.set_connection_execution_options` method for that + connection. + + """ + raise NotImplementedError() + + def set_connection_execution_options( + self, connection: Connection, opt: _ExecuteOptionsParameter + ) -> None: + """Establish execution options for a given connection. + + This is implemented by :class:`.DefaultDialect` in order to implement + the :paramref:`_engine.Connection.execution_options.isolation_level` + execution option. Dialects can intercept various execution options + which may need to modify state on a particular DBAPI connection. + + .. versionadded:: 1.4 + + """ + raise NotImplementedError() + + def get_dialect_pool_class(self, url: URL) -> Type[Pool]: + """return a Pool class to use for a given URL""" + raise NotImplementedError() + class CreateEnginePlugin: """A set of hooks intended to augment the construction of an @@ -1878,7 +2089,7 @@ class CreateEnginePlugin: """ # noqa: E501 - def __init__(self, url, kwargs): + def __init__(self, url: URL, kwargs: Dict[str, Any]): """Construct a new :class:`.CreateEnginePlugin`. The plugin object is instantiated individually for each call @@ -1905,7 +2116,7 @@ class CreateEnginePlugin: """ self.url = url - def update_url(self, url): + def update_url(self, url: URL) -> URL: """Update the :class:`_engine.URL`. A new :class:`_engine.URL` should be returned. This method is @@ -1920,14 +2131,19 @@ class CreateEnginePlugin: .. versionadded:: 1.4 """ + raise NotImplementedError() - def handle_dialect_kwargs(self, dialect_cls, dialect_args): + def handle_dialect_kwargs( + self, dialect_cls: Type[Dialect], dialect_args: Dict[str, Any] + ) -> None: """parse and modify dialect kwargs""" - def handle_pool_kwargs(self, pool_cls, pool_args): + def handle_pool_kwargs( + self, pool_cls: Type[Pool], pool_args: Dict[str, Any] + ) -> None: """parse and modify pool kwargs""" - def engine_created(self, engine): + def engine_created(self, engine: Engine) -> None: """Receive the :class:`_engine.Engine` object when it is fully constructed. @@ -1941,56 +2157,137 @@ class ExecutionContext: """A messenger object for a Dialect that corresponds to a single execution. - ExecutionContext should have these data members: + """ - connection - Connection object which can be freely used by default value + connection: Connection + """Connection object which can be freely used by default value generators to execute SQL. This Connection should reference the same underlying connection/transactional resources of - root_connection. + root_connection.""" - root_connection - Connection object which is the source of this ExecutionContext. + root_connection: Connection + """Connection object which is the source of this ExecutionContext.""" - dialect - dialect which created this ExecutionContext. + dialect: Dialect + """dialect which created this ExecutionContext.""" - cursor - DB-API cursor procured from the connection, + cursor: DBAPICursor + """DB-API cursor procured from the connection""" - compiled - if passed to constructor, sqlalchemy.engine.base.Compiled object - being executed, + compiled: Optional[Compiled] + """if passed to constructor, sqlalchemy.engine.base.Compiled object + being executed""" - statement - string version of the statement to be executed. Is either + statement: str + """string version of the statement to be executed. Is either passed to the constructor, or must be created from the - sql.Compiled object by the time pre_exec() has completed. + sql.Compiled object by the time pre_exec() has completed.""" - parameters - bind parameters passed to the execute() method. For compiled - statements, this is a dictionary or list of dictionaries. For - textual statements, it should be in a format suitable for the - dialect's paramstyle (i.e. dict or list of dicts for non - positional, list or list of lists/tuples for positional). + invoked_statement: Optional[Executable] + """The Executable statement object that was given in the first place. - isinsert - True if the statement is an INSERT. + This should be structurally equivalent to compiled.statement, but not + necessarily the same object as in a caching scenario the compiled form + will have been extracted from the cache. - isupdate - True if the statement is an UPDATE. + """ - prefetch_cols - a list of Column objects for which a client-side default - was fired off. Applies to inserts and updates. + parameters: _AnyMultiExecuteParams + """bind parameters passed to the execute() or exec_driver_sql() methods. + + These are always stored as a list of parameter entries. A single-element + list corresponds to a ``cursor.execute()`` call and a multiple-element + list corresponds to ``cursor.executemany()``. - postfetch_cols - a list of Column objects for which a server-side default or - inline SQL expression value was fired off. Applies to inserts - and updates. """ - def create_cursor(self): + no_parameters: bool + """True if the execution style does not use parameters""" + + isinsert: bool + """True if the statement is an INSERT.""" + + isupdate: bool + """True if the statement is an UPDATE.""" + + executemany: bool + """True if the parameters have determined this to be an executemany""" + + prefetch_cols: Optional[Sequence[Column[Any]]] + """a list of Column objects for which a client-side default + was fired off. Applies to inserts and updates.""" + + postfetch_cols: Optional[Sequence[Column[Any]]] + """a list of Column objects for which a server-side default or + inline SQL expression value was fired off. Applies to inserts + and updates.""" + + @classmethod + def _init_ddl( + cls, + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + compiled_ddl: DDLCompiler, + ) -> ExecutionContext: + raise NotImplementedError() + + @classmethod + def _init_compiled( + cls, + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + compiled: SQLCompiler, + parameters: _CoreMultiExecuteParams, + invoked_statement: Executable, + extracted_parameters: _CoreSingleExecuteParams, + cache_hit: CacheStats = CacheStats.CACHING_DISABLED, + ) -> ExecutionContext: + raise NotImplementedError() + + @classmethod + def _init_statement( + cls, + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + statement: str, + parameters: _DBAPIMultiExecuteParams, + ) -> ExecutionContext: + raise NotImplementedError() + + @classmethod + def _init_default( + cls, + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + ) -> ExecutionContext: + raise NotImplementedError() + + def _exec_default( + self, + column: Optional[Column[Any]], + default: ColumnDefault, + type_: Optional[TypeEngine[Any]], + ) -> Any: + raise NotImplementedError() + + def _set_input_sizes(self) -> None: + raise NotImplementedError() + + def _get_cache_stats(self) -> str: + raise NotImplementedError() + + def _setup_result_proxy(self) -> Result: + raise NotImplementedError() + + def create_cursor(self) -> DBAPICursor: """Return a new cursor generated from this ExecutionContext's connection. @@ -2001,7 +2298,7 @@ class ExecutionContext: raise NotImplementedError() - def pre_exec(self): + def pre_exec(self) -> None: """Called before an execution of a compiled statement. If a compiled statement was passed to this ExecutionContext, @@ -2011,7 +2308,9 @@ class ExecutionContext: raise NotImplementedError() - def get_out_parameter_values(self, out_param_names): + def get_out_parameter_values( + self, out_param_names: Sequence[str] + ) -> Sequence[Any]: """Return a sequence of OUT parameter values from a cursor. For dialects that support OUT parameters, this method will be called @@ -2045,7 +2344,7 @@ class ExecutionContext: """ raise NotImplementedError() - def post_exec(self): + def post_exec(self) -> None: """Called after the execution of a compiled statement. If a compiled statement was passed to this ExecutionContext, @@ -2055,20 +2354,20 @@ class ExecutionContext: raise NotImplementedError() - def handle_dbapi_exception(self, e): + def handle_dbapi_exception(self, e: BaseException) -> None: """Receive a DBAPI exception which occurred upon execute, result fetch, etc.""" raise NotImplementedError() - def lastrow_has_defaults(self): + def lastrow_has_defaults(self) -> bool: """Return True if the last INSERT or UPDATE row contained inlined or database-side defaults. """ raise NotImplementedError() - def get_rowcount(self): + def get_rowcount(self) -> Optional[int]: """Return the DBAPI ``cursor.rowcount`` value, or in some cases an interpreted value. @@ -2079,7 +2378,7 @@ class ExecutionContext: raise NotImplementedError() -class ConnectionEventsTarget: +class ConnectionEventsTarget(EventTarget): """An object which can accept events from :class:`.ConnectionEvents`. Includes :class:`_engine.Connection` and :class:`_engine.Engine`. @@ -2088,6 +2387,11 @@ class ConnectionEventsTarget: """ + dispatch: dispatcher[ConnectionEventsTarget] + + +Connectable = ConnectionEventsTarget + class ExceptionContext: """Encapsulate information about an error condition in progress. @@ -2101,7 +2405,7 @@ class ExceptionContext: """ - connection = None + connection: Optional[Connection] """The :class:`_engine.Connection` in use during the exception. This member is present, except in the case of a failure when @@ -2114,7 +2418,7 @@ class ExceptionContext: """ - engine = None + engine: Optional[Engine] """The :class:`_engine.Engine` in use during the exception. This member should always be present, even in the case of a failure @@ -2124,35 +2428,35 @@ class ExceptionContext: """ - cursor = None + cursor: Optional[DBAPICursor] """The DBAPI cursor object. May be None. """ - statement = None + statement: Optional[str] """String SQL statement that was emitted directly to the DBAPI. May be None. """ - parameters = None + parameters: Optional[_DBAPIAnyExecuteParams] """Parameter collection that was emitted directly to the DBAPI. May be None. """ - original_exception = None + original_exception: BaseException """The exception object which was caught. This member is always present. """ - sqlalchemy_exception = None + sqlalchemy_exception: Optional[StatementError] """The :class:`sqlalchemy.exc.StatementError` which wraps the original, and will be raised if exception handling is not circumvented by the event. @@ -2162,7 +2466,7 @@ class ExceptionContext: """ - chained_exception = None + chained_exception: Optional[BaseException] """The exception that was returned by the previous handler in the exception chain, if any. @@ -2173,7 +2477,7 @@ class ExceptionContext: """ - execution_context = None + execution_context: Optional[ExecutionContext] """The :class:`.ExecutionContext` corresponding to the execution operation in progress. @@ -2193,7 +2497,7 @@ class ExceptionContext: """ - is_disconnect = None + is_disconnect: bool """Represent whether the exception as occurred represents a "disconnect" condition. @@ -2218,7 +2522,7 @@ class ExceptionContext: """ - invalidate_pool_on_disconnect = True + invalidate_pool_on_disconnect: bool """Represent whether all connections in the pool should be invalidated when a "disconnect" condition is in effect. @@ -2250,12 +2554,14 @@ class AdaptedConnection: __slots__ = ("_connection",) + _connection: Any + @property - def driver_connection(self): + def driver_connection(self) -> Any: """The connection object as returned by the driver after a connect.""" return self._connection - def run_async(self, fn): + def run_async(self, fn: Callable[[Any], Awaitable[_T]]) -> _T: """Run the awaitable returned by the given function, which is passed the raw asyncio driver connection. @@ -2284,5 +2590,5 @@ class AdaptedConnection: """ return await_only(fn(self._connection)) - def __repr__(self): + def __repr__(self) -> str: return "<AdaptedConnection %s>" % self._connection diff --git a/lib/sqlalchemy/engine/mock.py b/lib/sqlalchemy/engine/mock.py index 76e77a3f3..a0ba96603 100644 --- a/lib/sqlalchemy/engine/mock.py +++ b/lib/sqlalchemy/engine/mock.py @@ -8,40 +8,69 @@ from __future__ import annotations from operator import attrgetter +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Optional +from typing import Type +from typing import Union from . import url as _url from .. import util +if typing.TYPE_CHECKING: + from .base import Connection + from .base import Engine + from .interfaces import _CoreAnyExecuteParams + from .interfaces import _ExecuteOptionsParameter + from .interfaces import Dialect + from .url import URL + from ..sql.base import Executable + from ..sql.ddl import DDLElement + from ..sql.ddl import SchemaDropper + from ..sql.ddl import SchemaGenerator + from ..sql.schema import HasSchemaAttr + + class MockConnection: - def __init__(self, dialect, execute): + def __init__(self, dialect: Dialect, execute: Callable[..., Any]): self._dialect = dialect - self.execute = execute + self._execute_impl = execute - engine = property(lambda s: s) - dialect = property(attrgetter("_dialect")) - name = property(lambda s: s._dialect.name) + engine: Engine = cast(Any, property(lambda s: s)) + dialect: Dialect = cast(Any, property(attrgetter("_dialect"))) + name: str = cast(Any, property(lambda s: s._dialect.name)) - def connect(self, **kwargs): + def connect(self, **kwargs: Any) -> MockConnection: return self - def schema_for_object(self, obj): + def schema_for_object(self, obj: HasSchemaAttr) -> Optional[str]: return obj.schema - def execution_options(self, **kw): + def execution_options(self, **kw: Any) -> MockConnection: return self def _run_ddl_visitor( - self, visitorcallable, element, connection=None, **kwargs - ): + self, + visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]], + element: DDLElement, + **kwargs: Any, + ) -> None: kwargs["checkfirst"] = False visitorcallable(self.dialect, self, **kwargs).traverse_single(element) - def execute(self, object_, *multiparams, **params): - raise NotImplementedError() + def execute( + self, + obj: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> Any: + return self._execute_impl(obj, parameters) -def create_mock_engine(url, executor, **kw): +def create_mock_engine(url: URL, executor: Any, **kw: Any) -> MockConnection: """Create a "mock" engine used for echoing DDL. This is a utility function used for debugging or storing the output of DDL @@ -96,6 +125,6 @@ def create_mock_engine(url, executor, **kw): dialect_args[k] = kw.pop(k) # create dialect - dialect = dialect_cls(**dialect_args) + dialect = dialect_cls(**dialect_args) # type: ignore return MockConnection(dialect, executor) diff --git a/lib/sqlalchemy/engine/processors.py b/lib/sqlalchemy/engine/processors.py index 398c1fa36..7a6a57c03 100644 --- a/lib/sqlalchemy/engine/processors.py +++ b/lib/sqlalchemy/engine/processors.py @@ -14,9 +14,20 @@ They all share one common characteristic: None is passed through unchanged. """ from __future__ import annotations +import typing + from ._py_processors import str_to_datetime_processor_factory # noqa +from ..util._has_cy import HAS_CYEXTENSION -try: +if typing.TYPE_CHECKING or not HAS_CYEXTENSION: + from ._py_processors import int_to_boolean # noqa + from ._py_processors import str_to_date # noqa + from ._py_processors import str_to_datetime # noqa + from ._py_processors import str_to_time # noqa + from ._py_processors import to_decimal_processor_factory # noqa + from ._py_processors import to_float # noqa + from ._py_processors import to_str # noqa +else: from sqlalchemy.cyextension.processors import ( DecimalResultProcessor, ) # noqa @@ -34,12 +45,3 @@ try: # Decimal('5.00000') whereas the C implementation will # return Decimal('5'). These are equivalent of course. return DecimalResultProcessor(target_class, "%%.%df" % scale).process - -except ImportError: - from ._py_processors import int_to_boolean # noqa - from ._py_processors import str_to_date # noqa - from ._py_processors import str_to_datetime # noqa - from ._py_processors import str_to_time # noqa - from ._py_processors import to_decimal_processor_factory # noqa - from ._py_processors import to_float # noqa - from ._py_processors import to_str # noqa diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index 3ba1ae519..0951d5770 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -9,13 +9,26 @@ from __future__ import annotations -import collections.abc as collections_abc +from enum import Enum import functools import itertools import operator import typing +from typing import Any +from typing import Callable +from typing import Dict +from typing import Iterator +from typing import List +from typing import NoReturn +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import TypeVar +from typing import Union from .row import Row +from .row import RowMapping from .. import exc from .. import util from ..sql.base import _generative @@ -25,9 +38,42 @@ from ..util._has_cy import HAS_CYEXTENSION if typing.TYPE_CHECKING or not HAS_CYEXTENSION: - from ._py_row import tuplegetter + from ._py_row import tuplegetter as tuplegetter else: - from sqlalchemy.cyextension.resultproxy import tuplegetter + from sqlalchemy.cyextension.resultproxy import tuplegetter as tuplegetter + +if typing.TYPE_CHECKING: + from .row import RowMapping + from ..sql.schema import Column + +_KeyType = Union[str, "Column[Any]"] +_KeyIndexType = Union[str, "Column[Any]", int] + +# is overridden in cursor using _CursorKeyMapRecType +_KeyMapRecType = Any + +_KeyMapType = Dict[_KeyType, _KeyMapRecType] + + +_RowData = Union[Row, RowMapping, Any] +"""A generic form of "row" that accommodates for the different kinds of +"rows" that different result objects return, including row, row mapping, and +scalar values""" + +_RawRowType = Tuple[Any, ...] +"""represents the kind of row we get from a DBAPI cursor""" + +_InterimRowType = Union[Row, RowMapping, Any, _RawRowType] +"""a catchall "anything" kind of return type that can be applied +across all the result types + +""" + +_ProcessorType = Callable[[Any], Any] +_ProcessorsType = Sequence[Optional[_ProcessorType]] +_TupleGetterType = Callable[[Sequence[Any]], Tuple[Any, ...]] +_UniqueFilterType = Callable[[Any], Any] +_UniqueFilterStateType = Tuple[Set[Any], Optional[_UniqueFilterType]] class ResultMetaData: @@ -35,40 +81,58 @@ class ResultMetaData: __slots__ = () - _tuplefilter = None - _translated_indexes = None - _unique_filters = None + _tuplefilter: Optional[_TupleGetterType] = None + _translated_indexes: Optional[Sequence[int]] = None + _unique_filters: Optional[Sequence[Callable[[Any], Any]]] = None + _keymap: _KeyMapType + _keys: Sequence[str] + _processors: Optional[_ProcessorsType] @property - def keys(self): + def keys(self) -> RMKeyView: return RMKeyView(self) - def _has_key(self, key): + def _has_key(self, key: object) -> bool: raise NotImplementedError() - def _for_freeze(self): + def _for_freeze(self) -> ResultMetaData: raise NotImplementedError() - def _key_fallback(self, key, err, raiseerr=True): + def _key_fallback( + self, key: _KeyType, err: Exception, raiseerr: bool = True + ) -> NoReturn: assert raiseerr raise KeyError(key) from err - def _raise_for_nonint(self, key): - raise TypeError( - "TypeError: tuple indices must be integers or slices, not %s" - % type(key).__name__ + def _raise_for_ambiguous_column_name( + self, rec: _KeyMapRecType + ) -> NoReturn: + raise NotImplementedError( + "ambiguous column name logic is implemented for " + "CursorResultMetaData" ) - def _index_for_key(self, keys, raiseerr): + def _index_for_key( + self, key: _KeyIndexType, raiseerr: bool + ) -> Optional[int]: raise NotImplementedError() - def _metadata_for_keys(self, key): + def _indexes_for_keys( + self, keys: Sequence[_KeyIndexType] + ) -> Sequence[int]: raise NotImplementedError() - def _reduce(self, keys): + def _metadata_for_keys( + self, keys: Sequence[_KeyIndexType] + ) -> Iterator[_KeyMapRecType]: raise NotImplementedError() - def _getter(self, key, raiseerr=True): + def _reduce(self, keys: Sequence[_KeyIndexType]) -> ResultMetaData: + raise NotImplementedError() + + def _getter( + self, key: Any, raiseerr: bool = True + ) -> Optional[Callable[[Sequence[_RowData]], _RowData]]: index = self._index_for_key(key, raiseerr) @@ -77,28 +141,33 @@ class ResultMetaData: else: return None - def _row_as_tuple_getter(self, keys): + def _row_as_tuple_getter( + self, keys: Sequence[_KeyIndexType] + ) -> _TupleGetterType: indexes = self._indexes_for_keys(keys) return tuplegetter(*indexes) -class RMKeyView(collections_abc.KeysView): +class RMKeyView(typing.KeysView[Any]): __slots__ = ("_parent", "_keys") - def __init__(self, parent): + _parent: ResultMetaData + _keys: Sequence[str] + + def __init__(self, parent: ResultMetaData): self._parent = parent self._keys = [k for k in parent._keys if k is not None] - def __len__(self): + def __len__(self) -> int: return len(self._keys) - def __repr__(self): + def __repr__(self) -> str: return "{0.__class__.__name__}({0._keys!r})".format(self) - def __iter__(self): + def __iter__(self) -> Iterator[str]: return iter(self._keys) - def __contains__(self, item): + def __contains__(self, item: Any) -> bool: if isinstance(item, int): return False @@ -106,10 +175,10 @@ class RMKeyView(collections_abc.KeysView): # which also don't seem to be tested in test_resultset right now return self._parent._has_key(item) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return list(other) == list(self) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return list(other) != list(self) @@ -125,20 +194,21 @@ class SimpleResultMetaData(ResultMetaData): "_unique_filters", ) + _keys: Sequence[str] + def __init__( self, - keys, - extra=None, - _processors=None, - _tuplefilter=None, - _translated_indexes=None, - _unique_filters=None, + keys: Sequence[str], + extra: Optional[Sequence[Any]] = None, + _processors: Optional[_ProcessorsType] = None, + _tuplefilter: Optional[_TupleGetterType] = None, + _translated_indexes: Optional[Sequence[int]] = None, + _unique_filters: Optional[Sequence[Callable[[Any], Any]]] = None, ): self._keys = list(keys) self._tuplefilter = _tuplefilter self._translated_indexes = _translated_indexes self._unique_filters = _unique_filters - if extra: recs_names = [ ( @@ -157,10 +227,10 @@ class SimpleResultMetaData(ResultMetaData): self._processors = _processors - def _has_key(self, key): + def _has_key(self, key: object) -> bool: return key in self._keymap - def _for_freeze(self): + def _for_freeze(self) -> ResultMetaData: unique_filters = self._unique_filters if unique_filters and self._tuplefilter: unique_filters = self._tuplefilter(unique_filters) @@ -173,28 +243,28 @@ class SimpleResultMetaData(ResultMetaData): _unique_filters=unique_filters, ) - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: return { "_keys": self._keys, "_translated_indexes": self._translated_indexes, } - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, Any]) -> None: if state["_translated_indexes"]: _translated_indexes = state["_translated_indexes"] _tuplefilter = tuplegetter(*_translated_indexes) else: _translated_indexes = _tuplefilter = None - self.__init__( + self.__init__( # type: ignore state["_keys"], _translated_indexes=_translated_indexes, _tuplefilter=_tuplefilter, ) - def _contains(self, value, row): + def _contains(self, value: Any, row: Row) -> bool: return value in row._data - def _index_for_key(self, key, raiseerr=True): + def _index_for_key(self, key: Any, raiseerr: bool = True) -> int: if int in key.__class__.__mro__: key = self._keys[key] try: @@ -202,12 +272,14 @@ class SimpleResultMetaData(ResultMetaData): except KeyError as ke: rec = self._key_fallback(key, ke, raiseerr) - return rec[0] + return rec[0] # type: ignore[no-any-return] - def _indexes_for_keys(self, keys): + def _indexes_for_keys(self, keys: Sequence[Any]) -> Sequence[int]: return [self._keymap[key][0] for key in keys] - def _metadata_for_keys(self, keys): + def _metadata_for_keys( + self, keys: Sequence[Any] + ) -> Iterator[_KeyMapRecType]: for key in keys: if int in key.__class__.__mro__: key = self._keys[key] @@ -219,7 +291,7 @@ class SimpleResultMetaData(ResultMetaData): yield rec - def _reduce(self, keys): + def _reduce(self, keys: Sequence[Any]) -> ResultMetaData: try: metadata_for_keys = [ self._keymap[ @@ -230,7 +302,10 @@ class SimpleResultMetaData(ResultMetaData): except KeyError as ke: self._key_fallback(ke.args[0], ke, True) - indexes, new_keys, extra = zip(*metadata_for_keys) + indexes: Sequence[int] + new_keys: Sequence[str] + extra: Sequence[Any] + indexes, new_keys, extra = zip(*metadata_for_keys) # type: ignore if self._translated_indexes: indexes = [self._translated_indexes[idx] for idx in indexes] @@ -249,7 +324,9 @@ class SimpleResultMetaData(ResultMetaData): return new_metadata -def result_tuple(fields, extra=None): +def result_tuple( + fields: Sequence[str], extra: Optional[Any] = None +) -> Callable[[_RawRowType], Row]: parent = SimpleResultMetaData(fields, extra) return functools.partial( Row, parent, parent._processors, parent._keymap, Row._default_key_style @@ -259,31 +336,58 @@ def result_tuple(fields, extra=None): # a symbol that indicates to internal Result methods that # "no row is returned". We can't use None for those cases where a scalar # filter is applied to rows. -_NO_ROW = util.symbol("NO_ROW") +class _NoRow(Enum): + _NO_ROW = 0 -SelfResultInternal = typing.TypeVar( - "SelfResultInternal", bound="ResultInternal" -) + +_NO_ROW = _NoRow._NO_ROW + +SelfResultInternal = TypeVar("SelfResultInternal", bound="ResultInternal") class ResultInternal(InPlaceGenerative): - _real_result = None - _generate_rows = True - _unique_filter_state = None - _post_creational_filter = None + _real_result: Optional[Result] = None + _generate_rows: bool = True + _row_logging_fn: Optional[Callable[[Any], Any]] + + _unique_filter_state: Optional[_UniqueFilterStateType] = None + _post_creational_filter: Optional[Callable[[Any], Any]] = None _is_cursor = False + _metadata: ResultMetaData + + _source_supports_scalars: bool + + def _fetchiter_impl(self) -> Iterator[_InterimRowType]: + raise NotImplementedError() + + def _fetchone_impl( + self, hard_close: bool = False + ) -> Optional[_InterimRowType]: + raise NotImplementedError() + + def _fetchmany_impl( + self, size: Optional[int] = None + ) -> List[_InterimRowType]: + raise NotImplementedError() + + def _fetchall_impl(self) -> List[_InterimRowType]: + raise NotImplementedError() + + def _soft_close(self, hard: bool = False) -> None: + raise NotImplementedError() + @HasMemoized.memoized_attribute - def _row_getter(self): + def _row_getter(self) -> Optional[Callable[..., _RowData]]: real_result = self._real_result if self._real_result else self if real_result._source_supports_scalars: if not self._generate_rows: return None else: - _proc = real_result._process_row + _proc = Row - def process_row( + def process_row( # type: ignore metadata, processors, keymap, key_style, scalar_obj ): return _proc( @@ -291,9 +395,9 @@ class ResultInternal(InPlaceGenerative): ) else: - process_row = real_result._process_row + process_row = Row # type: ignore - key_style = real_result._process_row._default_key_style + key_style = Row._default_key_style metadata = self._metadata keymap = metadata._keymap @@ -304,19 +408,19 @@ class ResultInternal(InPlaceGenerative): if processors: processors = tf(processors) - _make_row_orig = functools.partial( + _make_row_orig: Callable[..., Any] = functools.partial( process_row, metadata, processors, keymap, key_style ) - def make_row(row): - return _make_row_orig(tf(row)) + def make_row(row: _InterimRowType) -> _InterimRowType: + return _make_row_orig(tf(row)) # type: ignore else: - make_row = functools.partial( + make_row = functools.partial( # type: ignore process_row, metadata, processors, keymap, key_style ) - fns = () + fns: Tuple[Any, ...] = () if real_result._row_logging_fn: fns = (real_result._row_logging_fn,) @@ -326,16 +430,16 @@ class ResultInternal(InPlaceGenerative): if fns: _make_row = make_row - def make_row(row): - row = _make_row(row) + def make_row(row: _InterimRowType) -> _InterimRowType: + interim_row = _make_row(row) for fn in fns: - row = fn(row) - return row + interim_row = fn(interim_row) + return interim_row return make_row @HasMemoized.memoized_attribute - def _iterator_getter(self): + def _iterator_getter(self) -> Callable[..., Iterator[_RowData]]: make_row = self._row_getter @@ -344,9 +448,9 @@ class ResultInternal(InPlaceGenerative): if self._unique_filter_state: uniques, strategy = self._unique_strategy - def iterrows(self): - for row in self._fetchiter_impl(): - obj = make_row(row) if make_row else row + def iterrows(self: Result) -> Iterator[_RowData]: + for raw_row in self._fetchiter_impl(): + obj = make_row(raw_row) if make_row else raw_row hashed = strategy(obj) if strategy else obj if hashed in uniques: continue @@ -357,27 +461,29 @@ class ResultInternal(InPlaceGenerative): else: - def iterrows(self): + def iterrows(self: Result) -> Iterator[_RowData]: for row in self._fetchiter_impl(): - row = make_row(row) if make_row else row + row = make_row(row) if make_row else row # type: ignore if post_creational_filter: row = post_creational_filter(row) yield row return iterrows - def _raw_all_rows(self): + def _raw_all_rows(self) -> List[_RowData]: make_row = self._row_getter + assert make_row is not None rows = self._fetchall_impl() return [make_row(row) for row in rows] - def _allrows(self): + def _allrows(self) -> List[_RowData]: post_creational_filter = self._post_creational_filter make_row = self._row_getter rows = self._fetchall_impl() + made_rows: List[_InterimRowType] if make_row: made_rows = [make_row(row) for row in rows] else: @@ -386,7 +492,7 @@ class ResultInternal(InPlaceGenerative): if self._unique_filter_state: uniques, strategy = self._unique_strategy - rows = [ + interim_rows = [ made_row for made_row, sig_row in [ ( @@ -395,17 +501,19 @@ class ResultInternal(InPlaceGenerative): ) for made_row in made_rows ] - if sig_row not in uniques and not uniques.add(sig_row) + if sig_row not in uniques and not uniques.add(sig_row) # type: ignore # noqa E501 ] else: - rows = made_rows + interim_rows = made_rows if post_creational_filter: - rows = [post_creational_filter(row) for row in rows] - return rows + interim_rows = [ + post_creational_filter(row) for row in interim_rows + ] + return interim_rows @HasMemoized.memoized_attribute - def _onerow_getter(self): + def _onerow_getter(self) -> Callable[..., Union[_NoRow, _RowData]]: make_row = self._row_getter post_creational_filter = self._post_creational_filter @@ -413,7 +521,7 @@ class ResultInternal(InPlaceGenerative): if self._unique_filter_state: uniques, strategy = self._unique_strategy - def onerow(self): + def onerow(self: Result) -> Union[_NoRow, _RowData]: _onerow = self._fetchone_impl while True: row = _onerow() @@ -432,20 +540,22 @@ class ResultInternal(InPlaceGenerative): else: - def onerow(self): + def onerow(self: Result) -> Union[_NoRow, _RowData]: row = self._fetchone_impl() if row is None: return _NO_ROW else: - row = make_row(row) if make_row else row + interim_row: _InterimRowType = ( + make_row(row) if make_row else row + ) if post_creational_filter: - row = post_creational_filter(row) - return row + interim_row = post_creational_filter(interim_row) + return interim_row return onerow @HasMemoized.memoized_attribute - def _manyrow_getter(self): + def _manyrow_getter(self) -> Callable[..., List[_RowData]]: make_row = self._row_getter post_creational_filter = self._post_creational_filter @@ -453,7 +563,12 @@ class ResultInternal(InPlaceGenerative): if self._unique_filter_state: uniques, strategy = self._unique_strategy - def filterrows(make_row, rows, strategy, uniques): + def filterrows( + make_row: Optional[Callable[..., _RowData]], + rows: List[Any], + strategy: Optional[Callable[[Sequence[Any]], Any]], + uniques: Set[Any], + ) -> List[Row]: if make_row: rows = [make_row(row) for row in rows] @@ -466,11 +581,11 @@ class ResultInternal(InPlaceGenerative): return [ made_row for made_row, sig_row in made_rows - if sig_row not in uniques and not uniques.add(sig_row) + if sig_row not in uniques and not uniques.add(sig_row) # type: ignore # noqa: E501 ] - def manyrows(self, num): - collect = [] + def manyrows(self: Result, num: Optional[int]) -> List[_RowData]: + collect: List[_RowData] = [] _manyrows = self._fetchmany_impl @@ -488,6 +603,7 @@ class ResultInternal(InPlaceGenerative): else: rows = _manyrows(num) num = len(rows) + assert make_row is not None collect.extend( filterrows(make_row, rows, strategy, uniques) ) @@ -495,6 +611,8 @@ class ResultInternal(InPlaceGenerative): else: num_required = num + assert num is not None + while num_required: rows = _manyrows(num_required) if not rows: @@ -511,14 +629,14 @@ class ResultInternal(InPlaceGenerative): else: - def manyrows(self, num): + def manyrows(self: Result, num: Optional[int]) -> List[_RowData]: if num is None: real_result = ( self._real_result if self._real_result else self ) num = real_result._yield_per - rows = self._fetchmany_impl(num) + rows: List[_InterimRowType] = self._fetchmany_impl(num) if make_row: rows = [make_row(row) for row in rows] if post_creational_filter: @@ -529,13 +647,13 @@ class ResultInternal(InPlaceGenerative): def _only_one_row( self, - raise_for_second_row, - raise_for_none, - scalar, - ): + raise_for_second_row: bool, + raise_for_none: bool, + scalar: bool, + ) -> Optional[_RowData]: onerow = self._fetchone_impl - row = onerow(hard_close=True) + row: _InterimRowType = onerow(hard_close=True) if row is None: if raise_for_none: raise exc.NoResultFound( @@ -565,7 +683,7 @@ class ResultInternal(InPlaceGenerative): existing_row_hash = strategy(row) if strategy else row while True: - next_row = onerow(hard_close=True) + next_row: Any = onerow(hard_close=True) if next_row is None: next_row = _NO_ROW break @@ -574,6 +692,7 @@ class ResultInternal(InPlaceGenerative): next_row = make_row(next_row) if make_row else next_row if strategy: + assert next_row is not _NO_ROW if existing_row_hash == strategy(next_row): continue elif row == next_row: @@ -608,14 +727,14 @@ class ResultInternal(InPlaceGenerative): row = post_creational_filter(row) if scalar and make_row: - return row[0] + return row[0] # type: ignore else: return row - def _iter_impl(self): + def _iter_impl(self) -> Iterator[_RowData]: return self._iterator_getter(self) - def _next_impl(self): + def _next_impl(self) -> _RowData: row = self._onerow_getter(self) if row is _NO_ROW: raise StopIteration() @@ -624,11 +743,14 @@ class ResultInternal(InPlaceGenerative): @_generative def _column_slices( - self: SelfResultInternal, indexes + self: SelfResultInternal, indexes: Sequence[_KeyIndexType] ) -> SelfResultInternal: real_result = self._real_result if self._real_result else self - if real_result._source_supports_scalars and len(indexes) == 1: + if ( + real_result._source_supports_scalars # type: ignore[attr-defined] # noqa E501 + and len(indexes) == 1 + ): self._generate_rows = False else: self._generate_rows = True @@ -637,7 +759,8 @@ class ResultInternal(InPlaceGenerative): return self @HasMemoized.memoized_attribute - def _unique_strategy(self): + def _unique_strategy(self) -> _UniqueFilterStateType: + assert self._unique_filter_state is not None uniques, strategy = self._unique_filter_state real_result = ( @@ -660,8 +783,10 @@ class ResultInternal(InPlaceGenerative): class _WithKeys: + _metadata: ResultMetaData + # used mainly to share documentation on the keys method. - def keys(self): + def keys(self) -> RMKeyView: """Return an iterable view which yields the string keys that would be represented by each :class:`.Row`. @@ -681,7 +806,7 @@ class _WithKeys: return self._metadata.keys -SelfResult = typing.TypeVar("SelfResult", bound="Result") +SelfResult = TypeVar("SelfResult", bound="Result") class Result(_WithKeys, ResultInternal): @@ -709,23 +834,18 @@ class Result(_WithKeys, ResultInternal): """ - _process_row = Row - - _row_logging_fn = None + _row_logging_fn: Optional[Callable[[Row], Row]] = None - _source_supports_scalars = False + _source_supports_scalars: bool = False - _yield_per = None + _yield_per: Optional[int] = None - _attributes = util.immutabledict() + _attributes: util.immutabledict[Any, Any] = util.immutabledict() - def __init__(self, cursor_metadata): + def __init__(self, cursor_metadata: ResultMetaData): self._metadata = cursor_metadata - def _soft_close(self, hard=False): - raise NotImplementedError() - - def close(self): + def close(self) -> None: """close this :class:`_result.Result`. The behavior of this method is implementation specific, and is @@ -748,7 +868,7 @@ class Result(_WithKeys, ResultInternal): self._soft_close(hard=True) @_generative - def yield_per(self: SelfResult, num) -> SelfResult: + def yield_per(self: SelfResult, num: int) -> SelfResult: """Configure the row-fetching strategy to fetch num rows at a time. This impacts the underlying behavior of the result when iterating over @@ -785,7 +905,9 @@ class Result(_WithKeys, ResultInternal): return self @_generative - def unique(self: SelfResult, strategy=None) -> SelfResult: + def unique( + self: SelfResult, strategy: Optional[_UniqueFilterType] = None + ) -> SelfResult: """Apply unique filtering to the objects returned by this :class:`_engine.Result`. @@ -826,7 +948,7 @@ class Result(_WithKeys, ResultInternal): return self def columns( - self: SelfResultInternal, *col_expressions + self: SelfResultInternal, *col_expressions: _KeyIndexType ) -> SelfResultInternal: r"""Establish the columns that should be returned in each row. @@ -865,7 +987,7 @@ class Result(_WithKeys, ResultInternal): """ return self._column_slices(col_expressions) - def scalars(self, index=0) -> "ScalarResult": + def scalars(self, index: _KeyIndexType = 0) -> ScalarResult: """Return a :class:`_result.ScalarResult` filtering object which will return single elements rather than :class:`_row.Row` objects. @@ -890,7 +1012,9 @@ class Result(_WithKeys, ResultInternal): """ return ScalarResult(self, index) - def _getter(self, key, raiseerr=True): + def _getter( + self, key: _KeyIndexType, raiseerr: bool = True + ) -> Optional[Callable[[Sequence[Any]], _RowData]]: """return a callable that will retrieve the given key from a :class:`.Row`. @@ -901,7 +1025,7 @@ class Result(_WithKeys, ResultInternal): ) return self._metadata._getter(key, raiseerr) - def _tuple_getter(self, keys): + def _tuple_getter(self, keys: Sequence[_KeyIndexType]) -> _TupleGetterType: """return a callable that will retrieve the given keys from a :class:`.Row`. @@ -912,7 +1036,7 @@ class Result(_WithKeys, ResultInternal): ) return self._metadata._row_as_tuple_getter(keys) - def mappings(self) -> "MappingResult": + def mappings(self) -> MappingResult: """Apply a mappings filter to returned rows, returning an instance of :class:`_result.MappingResult`. @@ -928,7 +1052,7 @@ class Result(_WithKeys, ResultInternal): return MappingResult(self) - def _raw_row_iterator(self): + def _raw_row_iterator(self) -> Iterator[_RowData]: """Return a safe iterator that yields raw row data. This is used by the :meth:`._engine.Result.merge` method @@ -937,25 +1061,13 @@ class Result(_WithKeys, ResultInternal): """ raise NotImplementedError() - def _fetchiter_impl(self): - raise NotImplementedError() - - def _fetchone_impl(self, hard_close=False): - raise NotImplementedError() - - def _fetchall_impl(self): - raise NotImplementedError() - - def _fetchmany_impl(self, size=None): - raise NotImplementedError() - - def __iter__(self): + def __iter__(self) -> Iterator[_RowData]: return self._iter_impl() - def __next__(self): + def __next__(self) -> _RowData: return self._next_impl() - def partitions(self, size=None): + def partitions(self, size: Optional[int] = None) -> Iterator[List[Row]]: """Iterate through sub-lists of rows of the size given. Each list will be of the size given, excluding the last list to @@ -989,16 +1101,16 @@ class Result(_WithKeys, ResultInternal): while True: partition = getter(self, size) if partition: - yield partition + yield partition # type: ignore else: break - def fetchall(self): + def fetchall(self) -> List[Row]: """A synonym for the :meth:`_engine.Result.all` method.""" - return self._allrows() + return self._allrows() # type: ignore[return-value] - def fetchone(self): + def fetchone(self) -> Optional[Row]: """Fetch one row. When all rows are exhausted, returns None. @@ -1018,9 +1130,9 @@ class Result(_WithKeys, ResultInternal): if row is _NO_ROW: return None else: - return row + return row # type: ignore[return-value] - def fetchmany(self, size=None): + def fetchmany(self, size: Optional[int] = None) -> List[Row]: """Fetch many rows. When all rows are exhausted, returns an empty list. @@ -1035,9 +1147,9 @@ class Result(_WithKeys, ResultInternal): """ - return self._manyrow_getter(self, size) + return self._manyrow_getter(self, size) # type: ignore[return-value] - def all(self): + def all(self) -> List[Row]: """Return all rows in a list. Closes the result set after invocation. Subsequent invocations @@ -1049,9 +1161,9 @@ class Result(_WithKeys, ResultInternal): """ - return self._allrows() + return self._allrows() # type: ignore[return-value] - def first(self): + def first(self) -> Optional[Row]: """Fetch the first row or None if no row is present. Closes the result set and discards remaining rows. @@ -1083,11 +1195,11 @@ class Result(_WithKeys, ResultInternal): """ - return self._only_one_row( + return self._only_one_row( # type: ignore[return-value] raise_for_second_row=False, raise_for_none=False, scalar=False ) - def one_or_none(self): + def one_or_none(self) -> Optional[Row]: """Return at most one result or raise an exception. Returns ``None`` if the result has no rows. @@ -1107,11 +1219,11 @@ class Result(_WithKeys, ResultInternal): :meth:`_result.Result.one` """ - return self._only_one_row( + return self._only_one_row( # type: ignore[return-value] raise_for_second_row=True, raise_for_none=False, scalar=False ) - def scalar_one(self): + def scalar_one(self) -> Any: """Return exactly one scalar result or raise an exception. This is equivalent to calling :meth:`.Result.scalars` and then @@ -1128,7 +1240,7 @@ class Result(_WithKeys, ResultInternal): raise_for_second_row=True, raise_for_none=True, scalar=True ) - def scalar_one_or_none(self): + def scalar_one_or_none(self) -> Optional[Any]: """Return exactly one or no scalar result. This is equivalent to calling :meth:`.Result.scalars` and then @@ -1145,7 +1257,7 @@ class Result(_WithKeys, ResultInternal): raise_for_second_row=True, raise_for_none=False, scalar=True ) - def one(self): + def one(self) -> Row: """Return exactly one row or raise an exception. Raises :class:`.NoResultFound` if the result returns no @@ -1172,11 +1284,11 @@ class Result(_WithKeys, ResultInternal): :meth:`_result.Result.scalar_one` """ - return self._only_one_row( + return self._only_one_row( # type: ignore[return-value] raise_for_second_row=True, raise_for_none=True, scalar=False ) - def scalar(self): + def scalar(self) -> Any: """Fetch the first column of the first row, and close the result set. Returns None if there are no rows to fetch. @@ -1194,7 +1306,7 @@ class Result(_WithKeys, ResultInternal): raise_for_second_row=False, raise_for_none=False, scalar=True ) - def freeze(self): + def freeze(self) -> FrozenResult: """Return a callable object that will produce copies of this :class:`.Result` when invoked. @@ -1217,7 +1329,7 @@ class Result(_WithKeys, ResultInternal): return FrozenResult(self) - def merge(self, *others): + def merge(self, *others: Result) -> MergedResult: """Merge this :class:`.Result` with other compatible result objects. @@ -1240,28 +1352,37 @@ class FilterResult(ResultInternal): """ - _post_creational_filter = None + _post_creational_filter: Optional[Callable[[Any], Any]] = None - def _soft_close(self, hard=False): + _real_result: Result + + def _soft_close(self, hard: bool = False) -> None: self._real_result._soft_close(hard=hard) @property - def _attributes(self): + def _attributes(self) -> Dict[Any, Any]: return self._real_result._attributes - def _fetchiter_impl(self): + def _fetchiter_impl(self) -> Iterator[_InterimRowType]: return self._real_result._fetchiter_impl() - def _fetchone_impl(self, hard_close=False): + def _fetchone_impl( + self, hard_close: bool = False + ) -> Optional[_InterimRowType]: return self._real_result._fetchone_impl(hard_close=hard_close) - def _fetchall_impl(self): + def _fetchall_impl(self) -> List[_InterimRowType]: return self._real_result._fetchall_impl() - def _fetchmany_impl(self, size=None): + def _fetchmany_impl( + self, size: Optional[int] = None + ) -> List[_InterimRowType]: return self._real_result._fetchmany_impl(size=size) +SelfScalarResult = TypeVar("SelfScalarResult", bound="ScalarResult") + + class ScalarResult(FilterResult): """A wrapper for a :class:`_result.Result` that returns scalar values rather than :class:`_row.Row` values. @@ -1280,7 +1401,9 @@ class ScalarResult(FilterResult): _generate_rows = False - def __init__(self, real_result, index): + _post_creational_filter: Optional[Callable[[Any], Any]] + + def __init__(self, real_result: Result, index: _KeyIndexType): self._real_result = real_result if real_result._source_supports_scalars: @@ -1292,7 +1415,9 @@ class ScalarResult(FilterResult): self._unique_filter_state = real_result._unique_filter_state - def unique(self, strategy=None): + def unique( + self: SelfScalarResult, strategy: Optional[_UniqueFilterType] = None + ) -> SelfScalarResult: """Apply unique filtering to the objects returned by this :class:`_engine.ScalarResult`. @@ -1302,7 +1427,7 @@ class ScalarResult(FilterResult): self._unique_filter_state = (set(), strategy) return self - def partitions(self, size=None): + def partitions(self, size: Optional[int] = None) -> Iterator[List[Any]]: """Iterate through sub-lists of elements of the size given. Equivalent to :meth:`_result.Result.partitions` except that @@ -1320,12 +1445,12 @@ class ScalarResult(FilterResult): else: break - def fetchall(self): + def fetchall(self) -> List[Any]: """A synonym for the :meth:`_engine.ScalarResult.all` method.""" return self._allrows() - def fetchmany(self, size=None): + def fetchmany(self, size: Optional[int] = None) -> List[Any]: """Fetch many objects. Equivalent to :meth:`_result.Result.fetchmany` except that @@ -1335,7 +1460,7 @@ class ScalarResult(FilterResult): """ return self._manyrow_getter(self, size) - def all(self): + def all(self) -> List[Any]: """Return all scalar values in a list. Equivalent to :meth:`_result.Result.all` except that @@ -1345,13 +1470,13 @@ class ScalarResult(FilterResult): """ return self._allrows() - def __iter__(self): + def __iter__(self) -> Iterator[Any]: return self._iter_impl() - def __next__(self): + def __next__(self) -> Any: return self._next_impl() - def first(self): + def first(self) -> Optional[Any]: """Fetch the first object or None if no object is present. Equivalent to :meth:`_result.Result.first` except that @@ -1364,7 +1489,7 @@ class ScalarResult(FilterResult): raise_for_second_row=False, raise_for_none=False, scalar=False ) - def one_or_none(self): + def one_or_none(self) -> Optional[Any]: """Return at most one object or raise an exception. Equivalent to :meth:`_result.Result.one_or_none` except that @@ -1376,7 +1501,7 @@ class ScalarResult(FilterResult): raise_for_second_row=True, raise_for_none=False, scalar=False ) - def one(self): + def one(self) -> Any: """Return exactly one object or raise an exception. Equivalent to :meth:`_result.Result.one` except that @@ -1389,6 +1514,9 @@ class ScalarResult(FilterResult): ) +SelfMappingResult = TypeVar("SelfMappingResult", bound="MappingResult") + + class MappingResult(_WithKeys, FilterResult): """A wrapper for a :class:`_engine.Result` that returns dictionary values rather than :class:`_engine.Row` values. @@ -1402,14 +1530,16 @@ class MappingResult(_WithKeys, FilterResult): _post_creational_filter = operator.attrgetter("_mapping") - def __init__(self, result): + def __init__(self, result: Result): self._real_result = result self._unique_filter_state = result._unique_filter_state self._metadata = result._metadata if result._source_supports_scalars: self._metadata = self._metadata._reduce([0]) - def unique(self, strategy=None): + def unique( + self: SelfMappingResult, strategy: Optional[_UniqueFilterType] = None + ) -> SelfMappingResult: """Apply unique filtering to the objects returned by this :class:`_engine.MappingResult`. @@ -1419,11 +1549,15 @@ class MappingResult(_WithKeys, FilterResult): self._unique_filter_state = (set(), strategy) return self - def columns(self, *col_expressions): + def columns( + self: SelfMappingResult, *col_expressions: _KeyIndexType + ) -> SelfMappingResult: r"""Establish the columns that should be returned in each row.""" return self._column_slices(col_expressions) - def partitions(self, size=None): + def partitions( + self, size: Optional[int] = None + ) -> Iterator[List[RowMapping]]: """Iterate through sub-lists of elements of the size given. Equivalent to :meth:`_result.Result.partitions` except that @@ -1437,16 +1571,16 @@ class MappingResult(_WithKeys, FilterResult): while True: partition = getter(self, size) if partition: - yield partition + yield partition # type: ignore else: break - def fetchall(self): + def fetchall(self) -> List[RowMapping]: """A synonym for the :meth:`_engine.MappingResult.all` method.""" - return self._allrows() + return self._allrows() # type: ignore[return-value] - def fetchone(self): + def fetchone(self) -> Optional[RowMapping]: """Fetch one object. Equivalent to :meth:`_result.Result.fetchone` except that @@ -1459,9 +1593,9 @@ class MappingResult(_WithKeys, FilterResult): if row is _NO_ROW: return None else: - return row + return row # type: ignore[return-value] - def fetchmany(self, size=None): + def fetchmany(self, size: Optional[int] = None) -> List[RowMapping]: """Fetch many objects. Equivalent to :meth:`_result.Result.fetchmany` except that @@ -1470,9 +1604,9 @@ class MappingResult(_WithKeys, FilterResult): """ - return self._manyrow_getter(self, size) + return self._manyrow_getter(self, size) # type: ignore[return-value] - def all(self): + def all(self) -> List[RowMapping]: """Return all scalar values in a list. Equivalent to :meth:`_result.Result.all` except that @@ -1481,15 +1615,15 @@ class MappingResult(_WithKeys, FilterResult): """ - return self._allrows() + return self._allrows() # type: ignore[return-value] - def __iter__(self): - return self._iter_impl() + def __iter__(self) -> Iterator[RowMapping]: + return self._iter_impl() # type: ignore[return-value] - def __next__(self): - return self._next_impl() + def __next__(self) -> RowMapping: + return self._next_impl() # type: ignore[return-value] - def first(self): + def first(self) -> Optional[RowMapping]: """Fetch the first object or None if no object is present. Equivalent to :meth:`_result.Result.first` except that @@ -1498,11 +1632,11 @@ class MappingResult(_WithKeys, FilterResult): """ - return self._only_one_row( + return self._only_one_row( # type: ignore[return-value] raise_for_second_row=False, raise_for_none=False, scalar=False ) - def one_or_none(self): + def one_or_none(self) -> Optional[RowMapping]: """Return at most one object or raise an exception. Equivalent to :meth:`_result.Result.one_or_none` except that @@ -1510,11 +1644,11 @@ class MappingResult(_WithKeys, FilterResult): are returned. """ - return self._only_one_row( + return self._only_one_row( # type: ignore[return-value] raise_for_second_row=True, raise_for_none=False, scalar=False ) - def one(self): + def one(self) -> RowMapping: """Return exactly one object or raise an exception. Equivalent to :meth:`_result.Result.one` except that @@ -1522,7 +1656,7 @@ class MappingResult(_WithKeys, FilterResult): are returned. """ - return self._only_one_row( + return self._only_one_row( # type: ignore[return-value] raise_for_second_row=True, raise_for_none=True, scalar=False ) @@ -1566,7 +1700,9 @@ class FrozenResult: """ - def __init__(self, result): + data: Sequence[Any] + + def __init__(self, result: Result): self.metadata = result._metadata._for_freeze() self._source_supports_scalars = result._source_supports_scalars self._attributes = result._attributes @@ -1576,13 +1712,13 @@ class FrozenResult: else: self.data = result.fetchall() - def rewrite_rows(self): + def rewrite_rows(self) -> List[List[Any]]: if self._source_supports_scalars: return [[elem] for elem in self.data] else: return [list(row) for row in self.data] - def with_new_rows(self, tuple_data): + def with_new_rows(self, tuple_data: Sequence[Row]) -> FrozenResult: fr = FrozenResult.__new__(FrozenResult) fr.metadata = self.metadata fr._attributes = self._attributes @@ -1594,7 +1730,7 @@ class FrozenResult: fr.data = tuple_data return fr - def __call__(self): + def __call__(self) -> Result: result = IteratorResult(self.metadata, iter(self.data)) result._attributes = self._attributes result._source_supports_scalars = self._source_supports_scalars @@ -1603,7 +1739,7 @@ class FrozenResult: class IteratorResult(Result): """A :class:`.Result` that gets data from a Python iterator of - :class:`.Row` objects. + :class:`.Row` objects or similar row-like data. .. versionadded:: 1.4 @@ -1613,17 +1749,17 @@ class IteratorResult(Result): def __init__( self, - cursor_metadata, - iterator, - raw=None, - _source_supports_scalars=False, + cursor_metadata: ResultMetaData, + iterator: Iterator[_RowData], + raw: Optional[Any] = None, + _source_supports_scalars: bool = False, ): self._metadata = cursor_metadata self.iterator = iterator self.raw = raw self._source_supports_scalars = _source_supports_scalars - def _soft_close(self, hard=False, **kw): + def _soft_close(self, hard: bool = False, **kw: Any) -> None: if hard: self._hard_closed = True if self.raw is not None: @@ -1631,18 +1767,18 @@ class IteratorResult(Result): self.iterator = iter([]) self._reset_memoizations() - def _raise_hard_closed(self): + def _raise_hard_closed(self) -> NoReturn: raise exc.ResourceClosedError("This result object is closed.") - def _raw_row_iterator(self): + def _raw_row_iterator(self) -> Iterator[_RowData]: return self.iterator - def _fetchiter_impl(self): + def _fetchiter_impl(self) -> Iterator[_RowData]: if self._hard_closed: self._raise_hard_closed() return self.iterator - def _fetchone_impl(self, hard_close=False): + def _fetchone_impl(self, hard_close: bool = False) -> Optional[_RowData]: if self._hard_closed: self._raise_hard_closed() @@ -1653,27 +1789,26 @@ class IteratorResult(Result): else: return row - def _fetchall_impl(self): + def _fetchall_impl(self) -> List[_RowData]: if self._hard_closed: self._raise_hard_closed() - try: return list(self.iterator) finally: self._soft_close() - def _fetchmany_impl(self, size=None): + def _fetchmany_impl(self, size: Optional[int] = None) -> List[_RowData]: if self._hard_closed: self._raise_hard_closed() return list(itertools.islice(self.iterator, 0, size)) -def null_result(): +def null_result() -> IteratorResult: return IteratorResult(SimpleResultMetaData([]), iter([])) -SelfChunkedIteratorResult = typing.TypeVar( +SelfChunkedIteratorResult = TypeVar( "SelfChunkedIteratorResult", bound="ChunkedIteratorResult" ) @@ -1695,11 +1830,11 @@ class ChunkedIteratorResult(IteratorResult): def __init__( self, - cursor_metadata, - chunks, - source_supports_scalars=False, - raw=None, - dynamic_yield_per=False, + cursor_metadata: ResultMetaData, + chunks: Callable[[Optional[int]], Iterator[List[_InterimRowType]]], + source_supports_scalars: bool = False, + raw: Optional[Any] = None, + dynamic_yield_per: bool = False, ): self._metadata = cursor_metadata self.chunks = chunks @@ -1710,7 +1845,7 @@ class ChunkedIteratorResult(IteratorResult): @_generative def yield_per( - self: SelfChunkedIteratorResult, num + self: SelfChunkedIteratorResult, num: int ) -> SelfChunkedIteratorResult: # TODO: this throws away the iterator which may be holding # onto a chunk. the yield_per cannot be changed once any @@ -1722,11 +1857,13 @@ class ChunkedIteratorResult(IteratorResult): self.iterator = itertools.chain.from_iterable(self.chunks(num)) return self - def _soft_close(self, **kw): - super(ChunkedIteratorResult, self)._soft_close(**kw) - self.chunks = lambda size: [] + def _soft_close(self, hard: bool = False, **kw: Any) -> None: + super(ChunkedIteratorResult, self)._soft_close(hard=hard, **kw) + self.chunks = lambda size: [] # type: ignore - def _fetchmany_impl(self, size=None): + def _fetchmany_impl( + self, size: Optional[int] = None + ) -> List[_InterimRowType]: if self.dynamic_yield_per: self.iterator = itertools.chain.from_iterable(self.chunks(size)) return super(ChunkedIteratorResult, self)._fetchmany_impl(size=size) @@ -1744,7 +1881,9 @@ class MergedResult(IteratorResult): closed = False - def __init__(self, cursor_metadata, results): + def __init__( + self, cursor_metadata: ResultMetaData, results: Sequence[Result] + ): self._results = results super(MergedResult, self).__init__( cursor_metadata, @@ -1763,7 +1902,7 @@ class MergedResult(IteratorResult): *[r._attributes for r in results] ) - def _soft_close(self, hard=False, **kw): + def _soft_close(self, hard: bool = False, **kw: Any) -> None: for r in self._results: r._soft_close(hard=hard, **kw) if hard: diff --git a/lib/sqlalchemy/engine/row.py b/lib/sqlalchemy/engine/row.py index 29b2f338b..ff63199d4 100644 --- a/lib/sqlalchemy/engine/row.py +++ b/lib/sqlalchemy/engine/row.py @@ -9,24 +9,41 @@ from __future__ import annotations +from abc import ABC import collections.abc as collections_abc import operator import typing +from typing import Any +from typing import Callable +from typing import Dict +from typing import Iterator +from typing import List +from typing import Mapping +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import Union from ..sql import util as sql_util from ..util._has_cy import HAS_CYEXTENSION if typing.TYPE_CHECKING or not HAS_CYEXTENSION: - from ._py_row import BaseRow + from ._py_row import BaseRow as BaseRow from ._py_row import KEY_INTEGER_ONLY from ._py_row import KEY_OBJECTS_ONLY else: - from sqlalchemy.cyextension.resultproxy import BaseRow + from sqlalchemy.cyextension.resultproxy import BaseRow as BaseRow from sqlalchemy.cyextension.resultproxy import KEY_INTEGER_ONLY from sqlalchemy.cyextension.resultproxy import KEY_OBJECTS_ONLY +if typing.TYPE_CHECKING: + from .result import _KeyType + from .result import RMKeyView -class Row(BaseRow, collections_abc.Sequence): + +class Row(BaseRow, typing.Sequence[Any]): """Represent a single result row. The :class:`.Row` object represents a row of a database result. It is @@ -58,14 +75,14 @@ class Row(BaseRow, collections_abc.Sequence): _default_key_style = KEY_INTEGER_ONLY - def __setattr__(self, name, value): + def __setattr__(self, name: str, value: Any) -> NoReturn: raise AttributeError("can't set attribute") - def __delattr__(self, name): + def __delattr__(self, name: str) -> NoReturn: raise AttributeError("can't delete attribute") @property - def _mapping(self): + def _mapping(self) -> RowMapping: """Return a :class:`.RowMapping` for this :class:`.Row`. This object provides a consistent Python mapping (i.e. dictionary) @@ -87,31 +104,44 @@ class Row(BaseRow, collections_abc.Sequence): self._data, ) - def _special_name_accessor(name): - """Handle ambiguous names such as "count" and "index" """ + def _filter_on_values( + self, filters: Optional[Sequence[Optional[Callable[[Any], Any]]]] + ) -> Row: + return Row( + self._parent, + filters, + self._keymap, + self._key_style, + self._data, + ) + + if not typing.TYPE_CHECKING: + + def _special_name_accessor(name: str) -> Any: + """Handle ambiguous names such as "count" and "index" """ - @property - def go(self): - if self._parent._has_key(name): - return self.__getattr__(name) - else: + @property + def go(self: Row) -> Any: + if self._parent._has_key(name): + return self.__getattr__(name) + else: - def meth(*arg, **kw): - return getattr(collections_abc.Sequence, name)( - self, *arg, **kw - ) + def meth(*arg: Any, **kw: Any) -> Any: + return getattr(collections_abc.Sequence, name)( + self, *arg, **kw + ) - return meth + return meth - return go + return go - count = _special_name_accessor("count") - index = _special_name_accessor("index") + count = _special_name_accessor("count") + index = _special_name_accessor("index") - def __contains__(self, key): + def __contains__(self, key: Any) -> bool: return key in self._data - def _op(self, other, op): + def _op(self, other: Any, op: Callable[[Any, Any], bool]) -> bool: return ( op(tuple(self), tuple(other)) if isinstance(other, Row) @@ -120,29 +150,44 @@ class Row(BaseRow, collections_abc.Sequence): __hash__ = BaseRow.__hash__ - def __lt__(self, other): + if typing.TYPE_CHECKING: + + @overload + def __getitem__(self, index: int) -> Any: + ... + + @overload + def __getitem__(self, index: slice) -> Sequence[Any]: + ... + + def __getitem__( + self, index: Union[int, slice] + ) -> Union[Any, Sequence[Any]]: + ... + + def __lt__(self, other: Any) -> bool: return self._op(other, operator.lt) - def __le__(self, other): + def __le__(self, other: Any) -> bool: return self._op(other, operator.le) - def __ge__(self, other): + def __ge__(self, other: Any) -> bool: return self._op(other, operator.ge) - def __gt__(self, other): + def __gt__(self, other: Any) -> bool: return self._op(other, operator.gt) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return self._op(other, operator.eq) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return self._op(other, operator.ne) - def __repr__(self): + def __repr__(self) -> str: return repr(sql_util._repr_row(self)) @property - def _fields(self): + def _fields(self) -> Tuple[str, ...]: """Return a tuple of string keys as represented by this :class:`.Row`. @@ -162,7 +207,7 @@ class Row(BaseRow, collections_abc.Sequence): """ return tuple([k for k in self._parent.keys if k is not None]) - def _asdict(self): + def _asdict(self) -> Dict[str, Any]: """Return a new dict which maps field names to their corresponding values. @@ -179,49 +224,51 @@ class Row(BaseRow, collections_abc.Sequence): """ return dict(self._mapping) - def _replace(self): - raise NotImplementedError() - - @property - def _field_defaults(self): - raise NotImplementedError() - BaseRowProxy = BaseRow RowProxy = Row -class ROMappingView( - collections_abc.KeysView, - collections_abc.ValuesView, - collections_abc.ItemsView, -): - __slots__ = ("_items",) +class ROMappingView(ABC): + __slots__ = () + + _items: Sequence[Any] + _mapping: Mapping[str, Any] - def __init__(self, mapping, items): + def __init__(self, mapping: Mapping[str, Any], items: Sequence[Any]): self._mapping = mapping self._items = items - def __len__(self): + def __len__(self) -> int: return len(self._items) - def __repr__(self): + def __repr__(self) -> str: return "{0.__class__.__name__}({0._mapping!r})".format(self) - def __iter__(self): + def __iter__(self) -> Iterator[Any]: return iter(self._items) - def __contains__(self, item): + def __contains__(self, item: Any) -> bool: return item in self._items - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return list(other) == list(self) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return list(other) != list(self) -class RowMapping(BaseRow, collections_abc.Mapping): +class ROMappingKeysValuesView( + ROMappingView, typing.KeysView[str], typing.ValuesView[Any] +): + __slots__ = ("_items",) + + +class ROMappingItemsView(ROMappingView, typing.ItemsView[str, Any]): + __slots__ = ("_items",) + + +class RowMapping(BaseRow, typing.Mapping[str, Any]): """A ``Mapping`` that maps column names and objects to :class:`.Row` values. The :class:`.RowMapping` is available from a :class:`.Row` via the @@ -251,31 +298,39 @@ class RowMapping(BaseRow, collections_abc.Mapping): _default_key_style = KEY_OBJECTS_ONLY - __getitem__ = BaseRow._get_by_key_impl_mapping + if typing.TYPE_CHECKING: - def _values_impl(self): + def __getitem__(self, key: _KeyType) -> Any: + ... + + else: + __getitem__ = BaseRow._get_by_key_impl_mapping + + def _values_impl(self) -> List[Any]: return list(self._data) - def __iter__(self): + def __iter__(self) -> Iterator[str]: return (k for k in self._parent.keys if k is not None) - def __len__(self): + def __len__(self) -> int: return len(self._data) - def __contains__(self, key): + def __contains__(self, key: object) -> bool: return self._parent._has_key(key) - def __repr__(self): + def __repr__(self) -> str: return repr(dict(self)) - def items(self): + def items(self) -> ROMappingItemsView: """Return a view of key/value tuples for the elements in the underlying :class:`.Row`. """ - return ROMappingView(self, [(key, self[key]) for key in self.keys()]) + return ROMappingItemsView( + self, [(key, self[key]) for key in self.keys()] + ) - def keys(self): + def keys(self) -> RMKeyView: """Return a view of 'keys' for string column names represented by the underlying :class:`.Row`. @@ -283,9 +338,9 @@ class RowMapping(BaseRow, collections_abc.Mapping): return self._parent.keys - def values(self): + def values(self) -> ROMappingKeysValuesView: """Return a view of values for the values represented in the underlying :class:`.Row`. """ - return ROMappingView(self, self._values_impl()) + return ROMappingKeysValuesView(self, self._values_impl()) diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py index a55233397..306989e0b 100644 --- a/lib/sqlalchemy/engine/url.py +++ b/lib/sqlalchemy/engine/url.py @@ -18,10 +18,18 @@ from __future__ import annotations import collections.abc as collections_abc import re +from typing import Any +from typing import cast from typing import Dict +from typing import Iterable +from typing import List +from typing import Mapping from typing import NamedTuple from typing import Optional +from typing import overload +from typing import Sequence from typing import Tuple +from typing import Type from typing import Union from urllib.parse import parse_qsl from urllib.parse import quote_plus @@ -86,19 +94,19 @@ class URL(NamedTuple): host: Optional[str] port: Optional[int] database: Optional[str] - query: Dict[str, Union[str, Tuple[str]]] + query: util.immutabledict[str, Union[Tuple[str, ...], str]] @classmethod def create( cls, - drivername, - username=None, - password=None, - host=None, - port=None, - database=None, - query=util.EMPTY_DICT, - ): + drivername: str, + username: Optional[str] = None, + password: Optional[str] = None, + host: Optional[str] = None, + port: Optional[int] = None, + database: Optional[str] = None, + query: Mapping[str, Union[Sequence[str], str]] = util.EMPTY_DICT, + ) -> URL: """Create a new :class:`_engine.URL` object. :param drivername: the name of the database backend. This name will @@ -146,7 +154,7 @@ class URL(NamedTuple): ) @classmethod - def _assert_port(cls, port): + def _assert_port(cls, port: Optional[int]) -> Optional[int]: if port is None: return None try: @@ -155,24 +163,48 @@ class URL(NamedTuple): raise TypeError("Port argument must be an integer or None") @classmethod - def _assert_str(cls, v, paramname): + def _assert_str(cls, v: str, paramname: str) -> str: if not isinstance(v, str): raise TypeError("%s must be a string" % paramname) return v @classmethod - def _assert_none_str(cls, v, paramname): + def _assert_none_str( + cls, v: Optional[str], paramname: str + ) -> Optional[str]: if v is None: return v return cls._assert_str(v, paramname) @classmethod - def _str_dict(cls, dict_): + def _str_dict( + cls, + dict_: Optional[ + Union[ + Sequence[Tuple[str, Union[Sequence[str], str]]], + Mapping[str, Union[Sequence[str], str]], + ] + ], + ) -> util.immutabledict[str, Union[Tuple[str, ...], str]]: if dict_ is None: return util.EMPTY_DICT - def _assert_value(val): + @overload + def _assert_value( + val: str, + ) -> str: + ... + + @overload + def _assert_value( + val: Sequence[str], + ) -> Union[str, Tuple[str, ...]]: + ... + + def _assert_value( + val: Union[str, Sequence[str]], + ) -> Union[str, Tuple[str, ...]]: if isinstance(val, str): return val elif isinstance(val, collections_abc.Sequence): @@ -183,11 +215,12 @@ class URL(NamedTuple): "sequences of strings" ) - def _assert_str(v): + def _assert_str(v: str) -> str: if not isinstance(v, str): raise TypeError("Query dictionary keys must be strings") return v + dict_items: Iterable[Tuple[str, Union[Sequence[str], str]]] if isinstance(dict_, collections_abc.Sequence): dict_items = dict_ else: @@ -204,14 +237,14 @@ class URL(NamedTuple): def set( self, - drivername=None, - username=None, - password=None, - host=None, - port=None, - database=None, - query=None, - ): + drivername: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + host: Optional[str] = None, + port: Optional[int] = None, + database: Optional[str] = None, + query: Optional[Mapping[str, Union[Sequence[str], str]]] = None, + ) -> URL: """return a new :class:`_engine.URL` object with modifications. Values are used if they are non-None. To set a value to ``None`` @@ -237,7 +270,7 @@ class URL(NamedTuple): """ - kw = {} + kw: Dict[str, Any] = {} if drivername is not None: kw["drivername"] = drivername if username is not None: @@ -255,7 +288,7 @@ class URL(NamedTuple): return self._assert_replace(**kw) - def _assert_replace(self, **kw): + def _assert_replace(self, **kw: Any) -> URL: """argument checks before calling _replace()""" if "drivername" in kw: @@ -270,7 +303,9 @@ class URL(NamedTuple): return self._replace(**kw) - def update_query_string(self, query_string, append=False): + def update_query_string( + self, query_string: str, append: bool = False + ) -> URL: """Return a new :class:`_engine.URL` object with the :attr:`_engine.URL.query` parameter dictionary updated by the given query string. @@ -301,7 +336,11 @@ class URL(NamedTuple): """ # noqa: E501 return self.update_query_pairs(parse_qsl(query_string), append=append) - def update_query_pairs(self, key_value_pairs, append=False): + def update_query_pairs( + self, + key_value_pairs: Iterable[Tuple[str, Union[str, List[str]]]], + append: bool = False, + ) -> URL: """Return a new :class:`_engine.URL` object with the :attr:`_engine.URL.query` parameter dictionary updated by the given sequence of key/value pairs @@ -335,23 +374,27 @@ class URL(NamedTuple): """ # noqa: E501 existing_query = self.query - new_keys = {} + new_keys: Dict[str, Union[str, List[str]]] = {} for key, value in key_value_pairs: if key in new_keys: new_keys[key] = util.to_list(new_keys[key]) - new_keys[key].append(value) + cast("List[str]", new_keys[key]).append(cast(str, value)) else: - new_keys[key] = value + new_keys[key] = ( + list(value) if isinstance(value, (list, tuple)) else value + ) + new_query: Mapping[str, Union[str, Sequence[str]]] if append: new_query = {} for k in new_keys: if k in existing_query: - new_query[k] = util.to_list( - existing_query[k] - ) + util.to_list(new_keys[k]) + new_query[k] = tuple( + util.to_list(existing_query[k]) + + util.to_list(new_keys[k]) + ) else: new_query[k] = new_keys[k] @@ -362,10 +405,19 @@ class URL(NamedTuple): } ) else: - new_query = self.query.union(new_keys) + new_query = self.query.union( + { + k: tuple(v) if isinstance(v, list) else v + for k, v in new_keys.items() + } + ) return self.set(query=new_query) - def update_query_dict(self, query_parameters, append=False): + def update_query_dict( + self, + query_parameters: Mapping[str, Union[str, List[str]]], + append: bool = False, + ) -> URL: """Return a new :class:`_engine.URL` object with the :attr:`_engine.URL.query` parameter dictionary updated by the given dictionary. @@ -410,7 +462,7 @@ class URL(NamedTuple): """ # noqa: E501 return self.update_query_pairs(query_parameters.items(), append=append) - def difference_update_query(self, names): + def difference_update_query(self, names: Iterable[str]) -> URL: """ Remove the given names from the :attr:`_engine.URL.query` dictionary, returning the new :class:`_engine.URL`. @@ -459,7 +511,7 @@ class URL(NamedTuple): ) @util.memoized_property - def normalized_query(self): + def normalized_query(self) -> Mapping[str, Sequence[str]]: """Return the :attr:`_engine.URL.query` dictionary with values normalized into sequences. @@ -494,7 +546,7 @@ class URL(NamedTuple): "be removed in a future release. Please use the " ":meth:`_engine.URL.render_as_string` method.", ) - def __to_string__(self, hide_password=True): + def __to_string__(self, hide_password: bool = True) -> str: """Render this :class:`_engine.URL` object as a string. :param hide_password: Defaults to True. The password is not shown @@ -503,7 +555,7 @@ class URL(NamedTuple): """ return self.render_as_string(hide_password=hide_password) - def render_as_string(self, hide_password=True): + def render_as_string(self, hide_password: bool = True) -> str: """Render this :class:`_engine.URL` object as a string. This method is used when the ``__str__()`` or ``__repr__()`` @@ -542,13 +594,13 @@ class URL(NamedTuple): ) return s - def __str__(self): + def __str__(self) -> str: return self.render_as_string(hide_password=False) - def __repr__(self): + def __repr__(self) -> str: return self.render_as_string() - def __copy__(self): + def __copy__(self) -> URL: return self.__class__.create( self.drivername, self.username, @@ -561,13 +613,13 @@ class URL(NamedTuple): self.query, ) - def __deepcopy__(self, memo): + def __deepcopy__(self, memo: Any) -> URL: return self.__copy__() - def __hash__(self): + def __hash__(self) -> int: return hash(str(self)) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return ( isinstance(other, URL) and self.drivername == other.drivername @@ -579,10 +631,10 @@ class URL(NamedTuple): and self.port == other.port ) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other - def get_backend_name(self): + def get_backend_name(self) -> str: """Return the backend name. This is the name that corresponds to the database backend in @@ -595,7 +647,7 @@ class URL(NamedTuple): else: return self.drivername.split("+")[0] - def get_driver_name(self): + def get_driver_name(self) -> str: """Return the backend name. This is the name that corresponds to the DBAPI driver in @@ -613,7 +665,9 @@ class URL(NamedTuple): else: return self.drivername.split("+")[1] - def _instantiate_plugins(self, kwargs): + def _instantiate_plugins( + self, kwargs: Mapping[str, Any] + ) -> Tuple[URL, List[Any], Dict[str, Any]]: plugin_names = util.to_list(self.query.get("plugin", ())) plugin_names += kwargs.get("plugins", []) @@ -635,7 +689,7 @@ class URL(NamedTuple): return u, loaded_plugins, kwargs - def _get_entrypoint(self): + def _get_entrypoint(self) -> Type[Dialect]: """Return the "entry point" dialect class. This is normally the dialect itself except in the case when the @@ -657,9 +711,9 @@ class URL(NamedTuple): ): return cls.dialect else: - return cls + return cast("Type[Dialect]", cls) - def get_dialect(self, _is_async=False): + def get_dialect(self, _is_async: bool = False) -> Type[Dialect]: """Return the SQLAlchemy :class:`_engine.Dialect` class corresponding to this URL's driver name. @@ -671,7 +725,9 @@ class URL(NamedTuple): dialect_cls = entrypoint.get_dialect_cls(self) return dialect_cls - def translate_connect_args(self, names=None, **kw): + def translate_connect_args( + self, names: Optional[List[str]] = None, **kw: Any + ) -> Dict[str, Any]: r"""Translate url attributes into a dictionary of connection arguments. Returns attributes of this url (`host`, `database`, `username`, @@ -711,11 +767,12 @@ class URL(NamedTuple): return translated -def make_url(name_or_url): - """Given a string or unicode instance, produce a new URL instance. +def make_url(name_or_url: Union[str, URL]) -> URL: + """Given a string, produce a new URL instance. The given string is parsed according to the RFC 1738 spec. If an existing URL object is passed, just returns the object. + """ if isinstance(name_or_url, str): @@ -724,7 +781,7 @@ def make_url(name_or_url): return name_or_url -def _parse_rfc1738_args(name): +def _parse_rfc1738_args(name: str) -> URL: pattern = re.compile( r""" (?P<name>[\w\+]+):// @@ -748,13 +805,14 @@ def _parse_rfc1738_args(name): m = pattern.match(name) if m is not None: components = m.groupdict() + query: Optional[Dict[str, Union[str, List[str]]]] if components["query"] is not None: query = {} for key, value in parse_qsl(components["query"]): if key in query: query[key] = util.to_list(query[key]) - query[key].append(value) + cast("List[str]", query[key]).append(value) else: query[key] = value else: @@ -775,7 +833,7 @@ def _parse_rfc1738_args(name): if components["port"]: components["port"] = int(components["port"]) - return URL.create(name, **components) + return URL.create(name, **components) # type: ignore else: raise exc.ArgumentError( @@ -783,18 +841,8 @@ def _parse_rfc1738_args(name): ) -def _rfc_1738_quote(text): +def _rfc_1738_quote(text: str) -> str: return re.sub(r"[:@/]", lambda m: "%%%X" % ord(m.group(0)), text) _rfc_1738_unquote = unquote - - -def _parse_keyvalue_args(name): - m = re.match(r"(\w+)://(.*)", name) - if m is not None: - (name, args) = m.group(1, 2) - opts = dict(parse_qsl(args)) - return URL(name, *opts) - else: - return None diff --git a/lib/sqlalchemy/engine/util.py b/lib/sqlalchemy/engine/util.py index f9ee65bef..213485cc9 100644 --- a/lib/sqlalchemy/engine/util.py +++ b/lib/sqlalchemy/engine/util.py @@ -7,18 +7,30 @@ from __future__ import annotations +import typing +from typing import Any +from typing import Callable +from typing import TypeVar + from .. import exc from .. import util +from ..util._has_cy import HAS_CYEXTENSION + +if typing.TYPE_CHECKING or not HAS_CYEXTENSION: + from ._py_util import _distill_params_20 as _distill_params_20 + from ._py_util import _distill_raw_params as _distill_raw_params +else: + from sqlalchemy.cyextension.util import ( + _distill_params_20 as _distill_params_20, + ) + from sqlalchemy.cyextension.util import ( + _distill_raw_params as _distill_raw_params, + ) -try: - from sqlalchemy.cyextension.util import _distill_params_20 # noqa - from sqlalchemy.cyextension.util import _distill_raw_params # noqa -except ImportError: - from ._py_util import _distill_params_20 # noqa - from ._py_util import _distill_raw_params # noqa +_C = TypeVar("_C", bound=Callable[[], Any]) -def connection_memoize(key): +def connection_memoize(key: str) -> Callable[[_C], _C]: """Decorator, memoize a function in a connection.info stash. Only applicable to functions which take no arguments other than a @@ -26,7 +38,7 @@ def connection_memoize(key): """ @util.decorator - def decorated(fn, self, connection): + def decorated(fn, self, connection): # type: ignore connection = connection.connect() try: return connection.info[key] @@ -34,7 +46,7 @@ def connection_memoize(key): connection.info[key] = val = fn(self, connection) return val - return decorated + return decorated # type: ignore[return-value] class TransactionalContext: @@ -47,13 +59,13 @@ class TransactionalContext: __slots__ = ("_outer_trans_ctx", "_trans_subject", "__weakref__") - def _transaction_is_active(self): + def _transaction_is_active(self) -> bool: raise NotImplementedError() - def _transaction_is_closed(self): + def _transaction_is_closed(self) -> bool: raise NotImplementedError() - def _rollback_can_be_called(self): + def _rollback_can_be_called(self) -> bool: """indicates the object is in a state that is known to be acceptable for rollback() to be called. @@ -70,11 +82,20 @@ class TransactionalContext: """ raise NotImplementedError() - def _get_subject(self): + def _get_subject(self) -> Any: + raise NotImplementedError() + + def commit(self) -> None: + raise NotImplementedError() + + def rollback(self) -> None: + raise NotImplementedError() + + def close(self) -> None: raise NotImplementedError() @classmethod - def _trans_ctx_check(cls, subject): + def _trans_ctx_check(cls, subject: Any) -> None: trans_context = subject._trans_context_manager if trans_context: if not trans_context._transaction_is_active(): @@ -84,7 +105,7 @@ class TransactionalContext: "before emitting further commands." ) - def __enter__(self): + def __enter__(self) -> TransactionalContext: subject = self._get_subject() # none for outer transaction, may be non-None for nested @@ -96,7 +117,7 @@ class TransactionalContext: subject._trans_context_manager = self return self - def __exit__(self, type_, value, traceback): + def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: subject = getattr(self, "_trans_subject", None) # simplistically we could assume that @@ -119,6 +140,7 @@ class TransactionalContext: self.rollback() finally: if not out_of_band_exit: + assert subject is not None subject._trans_context_manager = self._outer_trans_ctx self._trans_subject = self._outer_trans_ctx = None else: @@ -131,5 +153,6 @@ class TransactionalContext: self.rollback() finally: if not out_of_band_exit: + assert subject is not None subject._trans_context_manager = self._outer_trans_ctx self._trans_subject = self._outer_trans_ctx = None diff --git a/lib/sqlalchemy/event/__init__.py b/lib/sqlalchemy/event/__init__.py index 0dfb39e1a..e1c949681 100644 --- a/lib/sqlalchemy/event/__init__.py +++ b/lib/sqlalchemy/event/__init__.py @@ -15,6 +15,7 @@ from .api import NO_RETVAL as NO_RETVAL from .api import remove as remove from .attr import RefCollection as RefCollection from .base import _Dispatch as _Dispatch +from .base import _DispatchCommon as _DispatchCommon from .base import dispatcher as dispatcher from .base import Events as Events from .legacy import _legacy_signature as _legacy_signature diff --git a/lib/sqlalchemy/event/attr.py b/lib/sqlalchemy/event/attr.py index 9692894fe..afae8a59a 100644 --- a/lib/sqlalchemy/event/attr.py +++ b/lib/sqlalchemy/event/attr.py @@ -605,14 +605,14 @@ class _ListenerCollection(_CompoundListener[_ET]): class _JoinedListener(_CompoundListener[_ET]): __slots__ = "parent_dispatch", "name", "local", "parent_listeners" - parent_dispatch: _Dispatch[_ET] + parent_dispatch: _DispatchCommon[_ET] name: str local: _InstanceLevelDispatch[_ET] parent_listeners: Collection[_ListenerFnType] def __init__( self, - parent_dispatch: _Dispatch[_ET], + parent_dispatch: _DispatchCommon[_ET], name: str, local: _EmptyListener[_ET], ): diff --git a/lib/sqlalchemy/event/base.py b/lib/sqlalchemy/event/base.py index ef3ff9dab..4174b1dbe 100644 --- a/lib/sqlalchemy/event/base.py +++ b/lib/sqlalchemy/event/base.py @@ -75,6 +75,18 @@ class _UnpickleDispatch: class _DispatchCommon(Generic[_ET]): __slots__ = () + _instance_cls: Optional[Type[_ET]] + + def _join(self, other: _DispatchCommon[_ET]) -> _JoinedDispatcher[_ET]: + raise NotImplementedError() + + def __getattr__(self, name: str) -> _InstanceLevelDispatch[_ET]: + raise NotImplementedError() + + @property + def _events(self) -> Type[_HasEventsDispatch[_ET]]: + raise NotImplementedError() + class _Dispatch(_DispatchCommon[_ET]): """Mirror the event listening definitions of an Events class with @@ -169,7 +181,7 @@ class _Dispatch(_DispatchCommon[_ET]): instance_cls = instance.__class__ return self._for_class(instance_cls) - def _join(self, other: _Dispatch[_ET]) -> _JoinedDispatcher[_ET]: + def _join(self, other: _DispatchCommon[_ET]) -> _JoinedDispatcher[_ET]: """Create a 'join' of this :class:`._Dispatch` and another. This new dispatcher will dispatch events to both @@ -372,11 +384,13 @@ class _JoinedDispatcher(_DispatchCommon[_ET]): __slots__ = "local", "parent", "_instance_cls" - local: _Dispatch[_ET] - parent: _Dispatch[_ET] + local: _DispatchCommon[_ET] + parent: _DispatchCommon[_ET] _instance_cls: Optional[Type[_ET]] - def __init__(self, local: _Dispatch[_ET], parent: _Dispatch[_ET]): + def __init__( + self, local: _DispatchCommon[_ET], parent: _DispatchCommon[_ET] + ): self.local = local self.parent = parent self._instance_cls = self.local._instance_cls @@ -416,7 +430,7 @@ class dispatcher(Generic[_ET]): ... @overload - def __get__(self, obj: Any, cls: Type[Any]) -> _Dispatch[_ET]: + def __get__(self, obj: Any, cls: Type[Any]) -> _DispatchCommon[_ET]: ... def __get__(self, obj: Any, cls: Type[Any]) -> Any: diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py index 1383e024a..cc78e0971 100644 --- a/lib/sqlalchemy/exc.py +++ b/lib/sqlalchemy/exc.py @@ -27,8 +27,11 @@ from .util import _preloaded from .util import compat if typing.TYPE_CHECKING: + from .engine.interfaces import _AnyExecuteParams + from .engine.interfaces import _CoreAnyExecuteParams + from .engine.interfaces import _CoreMultiExecuteParams + from .engine.interfaces import _DBAPIAnyExecuteParams from .engine.interfaces import Dialect - from .sql._typing import _ExecuteParams from .sql.compiler import Compiled from .sql.elements import ClauseElement @@ -446,7 +449,7 @@ class StatementError(SQLAlchemyError): statement: Optional[str] = None """The string SQL statement being invoked when this exception occurred.""" - params: Optional["_ExecuteParams"] = None + params: Optional[_AnyExecuteParams] = None """The parameter list being used when this exception occurred.""" orig: Optional[BaseException] = None @@ -457,11 +460,13 @@ class StatementError(SQLAlchemyError): ismulti: Optional[bool] = None """multi parameter passed to repr_params(). None is meaningful.""" + connection_invalidated: bool = False + def __init__( self, message: str, statement: Optional[str], - params: Optional["_ExecuteParams"], + params: Optional[_AnyExecuteParams], orig: Optional[BaseException], hide_parameters: bool = False, code: Optional[str] = None, @@ -553,8 +558,8 @@ class DBAPIError(StatementError): @classmethod def instance( cls, - statement: str, - params: "_ExecuteParams", + statement: Optional[str], + params: Optional[_AnyExecuteParams], orig: DontWrapMixin, dbapi_base_err: Type[Exception], hide_parameters: bool = False, @@ -568,8 +573,8 @@ class DBAPIError(StatementError): @classmethod def instance( cls, - statement: str, - params: "_ExecuteParams", + statement: Optional[str], + params: Optional[_AnyExecuteParams], orig: Exception, dbapi_base_err: Type[Exception], hide_parameters: bool = False, @@ -583,8 +588,8 @@ class DBAPIError(StatementError): @classmethod def instance( cls, - statement: str, - params: "_ExecuteParams", + statement: Optional[str], + params: Optional[_AnyExecuteParams], orig: BaseException, dbapi_base_err: Type[Exception], hide_parameters: bool = False, @@ -597,8 +602,8 @@ class DBAPIError(StatementError): @classmethod def instance( cls, - statement: str, - params: "_ExecuteParams", + statement: Optional[str], + params: Optional[_AnyExecuteParams], orig: Union[BaseException, DontWrapMixin], dbapi_base_err: Type[Exception], hide_parameters: bool = False, @@ -684,8 +689,8 @@ class DBAPIError(StatementError): def __init__( self, - statement: str, - params: "_ExecuteParams", + statement: Optional[str], + params: Optional[_AnyExecuteParams], orig: BaseException, hide_parameters: bool = False, connection_invalidated: bool = False, diff --git a/lib/sqlalchemy/log.py b/lib/sqlalchemy/log.py index 8da45ed0d..9a23d89d3 100644 --- a/lib/sqlalchemy/log.py +++ b/lib/sqlalchemy/log.py @@ -255,19 +255,19 @@ class echo_property: @overload def __get__( - self, instance: "Literal[None]", owner: "echo_property" - ) -> "echo_property": + self, instance: Literal[None], owner: Type[Identified] + ) -> echo_property: ... @overload def __get__( - self, instance: Identified, owner: "echo_property" + self, instance: Identified, owner: Type[Identified] ) -> _EchoFlagType: ... def __get__( - self, instance: Optional[Identified], owner: "echo_property" - ) -> Union["echo_property", _EchoFlagType]: + self, instance: Optional[Identified], owner: Type[Identified] + ) -> Union[echo_property, _EchoFlagType]: if instance is None: return self else: diff --git a/lib/sqlalchemy/pool/__init__.py b/lib/sqlalchemy/pool/__init__.py index 2c52a7065..1fc77243a 100644 --- a/lib/sqlalchemy/pool/__init__.py +++ b/lib/sqlalchemy/pool/__init__.py @@ -18,8 +18,8 @@ SQLAlchemy connection pool. """ from . import events -from .base import _AdhocProxiedConnection -from .base import _ConnectionFairy +from .base import _AdhocProxiedConnection as _AdhocProxiedConnection +from .base import _ConnectionFairy as _ConnectionFairy from .base import _ConnectionRecord from .base import _finalize_fairy from .base import ConnectionPoolEntry as ConnectionPoolEntry diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py index 18d268182..c1008de5f 100644 --- a/lib/sqlalchemy/pool/base.py +++ b/lib/sqlalchemy/pool/base.py @@ -41,6 +41,7 @@ if TYPE_CHECKING: from ..engine.interfaces import DBAPICursor from ..engine.interfaces import Dialect from ..event import _Dispatch + from ..event import _DispatchCommon from ..event import _ListenerFnType from ..event import dispatcher @@ -132,7 +133,7 @@ class Pool(log.Identified, event.EventTarget): events: Optional[List[Tuple[_ListenerFnType, str]]] = None, dialect: Optional[Union[_ConnDialect, Dialect]] = None, pre_ping: bool = False, - _dispatch: Optional[_Dispatch[Pool]] = None, + _dispatch: Optional[_DispatchCommon[Pool]] = None, ): """ Construct a Pool. @@ -443,78 +444,72 @@ class ManagesConnection: """ - @property - def driver_connection(self) -> Optional[Any]: - """The "driver level" connection object as used by the Python - DBAPI or database driver. + driver_connection: Optional[Any] + """The "driver level" connection object as used by the Python + DBAPI or database driver. - For traditional :pep:`249` DBAPI implementations, this object will - be the same object as that of - :attr:`.ManagesConnection.dbapi_connection`. For an asyncio database - driver, this will be the ultimate "connection" object used by that - driver, such as the ``asyncpg.Connection`` object which will not have - standard pep-249 methods. + For traditional :pep:`249` DBAPI implementations, this object will + be the same object as that of + :attr:`.ManagesConnection.dbapi_connection`. For an asyncio database + driver, this will be the ultimate "connection" object used by that + driver, such as the ``asyncpg.Connection`` object which will not have + standard pep-249 methods. - .. versionadded:: 1.4.24 + .. versionadded:: 1.4.24 - .. seealso:: + .. seealso:: - :attr:`.ManagesConnection.dbapi_connection` + :attr:`.ManagesConnection.dbapi_connection` - :ref:`faq_dbapi_connection` + :ref:`faq_dbapi_connection` - """ - raise NotImplementedError() + """ - @util.dynamic_property - def info(self) -> Dict[str, Any]: - """Info dictionary associated with the underlying DBAPI connection - referred to by this :class:`.ManagesConnection` instance, allowing - user-defined data to be associated with the connection. + info: Dict[str, Any] + """Info dictionary associated with the underlying DBAPI connection + referred to by this :class:`.ManagesConnection` instance, allowing + user-defined data to be associated with the connection. - The data in this dictionary is persistent for the lifespan - of the DBAPI connection itself, including across pool checkins - and checkouts. When the connection is invalidated - and replaced with a new one, this dictionary is cleared. + The data in this dictionary is persistent for the lifespan + of the DBAPI connection itself, including across pool checkins + and checkouts. When the connection is invalidated + and replaced with a new one, this dictionary is cleared. - For a :class:`.PoolProxiedConnection` instance that's not associated - with a :class:`.ConnectionPoolEntry`, such as if it were detached, the - attribute returns a dictionary that is local to that - :class:`.ConnectionPoolEntry`. Therefore the - :attr:`.ManagesConnection.info` attribute will always provide a Python - dictionary. + For a :class:`.PoolProxiedConnection` instance that's not associated + with a :class:`.ConnectionPoolEntry`, such as if it were detached, the + attribute returns a dictionary that is local to that + :class:`.ConnectionPoolEntry`. Therefore the + :attr:`.ManagesConnection.info` attribute will always provide a Python + dictionary. - .. seealso:: + .. seealso:: - :attr:`.ManagesConnection.record_info` + :attr:`.ManagesConnection.record_info` - """ - raise NotImplementedError() + """ - @util.dynamic_property - def record_info(self) -> Optional[Dict[str, Any]]: - """Persistent info dictionary associated with this - :class:`.ManagesConnection`. + record_info: Optional[Dict[str, Any]] + """Persistent info dictionary associated with this + :class:`.ManagesConnection`. - Unlike the :attr:`.ManagesConnection.info` dictionary, the lifespan - of this dictionary is that of the :class:`.ConnectionPoolEntry` - which owns it; therefore this dictionary will persist across - reconnects and connection invalidation for a particular entry - in the connection pool. + Unlike the :attr:`.ManagesConnection.info` dictionary, the lifespan + of this dictionary is that of the :class:`.ConnectionPoolEntry` + which owns it; therefore this dictionary will persist across + reconnects and connection invalidation for a particular entry + in the connection pool. - For a :class:`.PoolProxiedConnection` instance that's not associated - with a :class:`.ConnectionPoolEntry`, such as if it were detached, the - attribute returns None. Contrast to the :attr:`.ManagesConnection.info` - dictionary which is never None. + For a :class:`.PoolProxiedConnection` instance that's not associated + with a :class:`.ConnectionPoolEntry`, such as if it were detached, the + attribute returns None. Contrast to the :attr:`.ManagesConnection.info` + dictionary which is never None. - .. seealso:: + .. seealso:: - :attr:`.ManagesConnection.info` + :attr:`.ManagesConnection.info` - """ - raise NotImplementedError() + """ def invalidate( self, e: Optional[BaseException] = None, soft: bool = False @@ -618,7 +613,7 @@ class _ConnectionRecord(ConnectionPoolEntry): dbapi_connection: Optional[DBAPIConnection] @property - def driver_connection(self) -> Optional[Any]: + def driver_connection(self) -> Optional[Any]: # type: ignore[override] # mypy#4125 # noqa E501 if self.dbapi_connection is None: return None else: @@ -637,11 +632,11 @@ class _ConnectionRecord(ConnectionPoolEntry): _soft_invalidate_time: float = 0 @util.memoized_property - def info(self) -> Dict[str, Any]: + def info(self) -> Dict[str, Any]: # type: ignore[override] # mypy#4125 return {} @util.memoized_property - def record_info(self) -> Optional[Dict[str, Any]]: + def record_info(self) -> Optional[Dict[str, Any]]: # type: ignore[override] # mypy#4125 # noqa E501 return {} @classmethod @@ -1048,7 +1043,7 @@ class _AdhocProxiedConnection(PoolProxiedConnection): """ - __slots__ = ("dbapi_connection", "_connection_record") + __slots__ = ("dbapi_connection", "_connection_record", "_is_valid") dbapi_connection: DBAPIConnection _connection_record: ConnectionPoolEntry @@ -1060,9 +1055,10 @@ class _AdhocProxiedConnection(PoolProxiedConnection): ): self.dbapi_connection = dbapi_connection self._connection_record = connection_record + self._is_valid = True @property - def driver_connection(self) -> Any: + def driver_connection(self) -> Any: # type: ignore[override] # mypy#4125 return self._connection_record.driver_connection @property @@ -1071,10 +1067,21 @@ class _AdhocProxiedConnection(PoolProxiedConnection): @property def is_valid(self) -> bool: - raise AttributeError("is_valid not implemented by this proxy") + """Implement is_valid state attribute. + + for the adhoc proxied connection it's assumed the connection is valid + as there is no "invalidate" routine. + + """ + return self._is_valid - @util.dynamic_property - def record_info(self) -> Optional[Dict[str, Any]]: + def invalidate( + self, e: Optional[BaseException] = None, soft: bool = False + ) -> None: + self._is_valid = False + + @property + def record_info(self) -> Optional[Dict[str, Any]]: # type: ignore[override] # mypy#4125 # noqa E501 return self._connection_record.record_info def cursor(self, *args: Any, **kwargs: Any) -> DBAPICursor: @@ -1140,7 +1147,7 @@ class _ConnectionFairy(PoolProxiedConnection): _connection_record: Optional[_ConnectionRecord] @property - def driver_connection(self) -> Optional[Any]: + def driver_connection(self) -> Optional[Any]: # type: ignore[override] # mypy#4125 # noqa E501 if self._connection_record is None: return None return self._connection_record.driver_connection @@ -1305,17 +1312,17 @@ class _ConnectionFairy(PoolProxiedConnection): @property def is_detached(self) -> bool: - return self._connection_record is not None + return self._connection_record is None @util.memoized_property - def info(self) -> Dict[str, Any]: + def info(self) -> Dict[str, Any]: # type: ignore[override] # mypy#4125 if self._connection_record is None: return {} else: return self._connection_record.info - @util.dynamic_property - def record_info(self) -> Optional[Dict[str, Any]]: + @property + def record_info(self) -> Optional[Dict[str, Any]]: # type: ignore[override] # mypy#4125 # noqa E501 if self._connection_record is None: return None else: diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index 7d8b9ee5c..69e4645fa 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -1,20 +1,11 @@ from __future__ import annotations -from typing import Any -from typing import Mapping -from typing import Sequence from typing import Type from typing import Union from . import roles from ..inspection import Inspectable -from ..util import immutabledict -_SingleExecuteParams = Mapping[str, Any] -_MultiExecuteParams = Sequence[_SingleExecuteParams] -_ExecuteParams = Union[_SingleExecuteParams, _MultiExecuteParams] -_ExecuteOptions = Mapping[str, Any] -_ImmutableExecuteOptions = immutabledict[str, Any] _ColumnsClauseElement = Union[ roles.ColumnsClauseRole, Type, Inspectable[roles.HasClauseElement] ] diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 3936ed9c6..a94590da1 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -19,11 +19,12 @@ from itertools import zip_longest import operator import re import typing +from typing import Optional +from typing import Sequence from typing import TypeVar from . import roles from . import visitors -from ._typing import _ImmutableExecuteOptions from .cache_key import HasCacheKey # noqa from .cache_key import MemoizedHasCacheKey # noqa from .traversals import HasCopyInternals # noqa @@ -32,7 +33,7 @@ from .visitors import ExtendedInternalTraversal from .visitors import InternalTraversal from .. import exc from .. import util -from ..util import HasMemoized +from ..util import HasMemoized as HasMemoized from ..util import hybridmethod from ..util import typing as compat_typing from ..util._has_cy import HAS_CYEXTENSION @@ -42,6 +43,16 @@ if typing.TYPE_CHECKING or not HAS_CYEXTENSION: else: from sqlalchemy.cyextension.util import prefix_anon_map # noqa +if typing.TYPE_CHECKING: + from ..engine import Connection + from ..engine import Result + from ..engine.interfaces import _CoreMultiExecuteParams + from ..engine.interfaces import _ExecuteOptions + from ..engine.interfaces import _ExecuteOptionsParameter + from ..engine.interfaces import _ImmutableExecuteOptions + from ..engine.interfaces import CacheStats + + coercions = None elements = None type_api = None @@ -856,6 +867,32 @@ class Executable(roles.StatementRole, Generative): is_delete = False is_dml = False + if typing.TYPE_CHECKING: + + def _compile_w_cache( + self, + dialect: Dialect, + compiled_cache: Optional[_CompiledCacheType] = None, + column_keys: Optional[Sequence[str]] = None, + for_executemany: bool = False, + schema_translate_map: Optional[_SchemaTranslateMapType] = None, + **kw: Any, + ) -> Tuple[Compiled, _SingleExecuteParams, CacheStats]: + ... + + def _execute_on_connection( + self, + connection: Connection, + distilled_params: _CoreMultiExecuteParams, + execution_options: _ExecuteOptionsParameter, + _force: bool = False, + ) -> Result: + ... + + @property + def _all_selected_columns(self): + raise NotImplementedError() + @property def _effective_plugin_target(self): return self.__visit_name__ diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py index 49f1899d5..ff659b77d 100644 --- a/lib/sqlalchemy/sql/cache_key.py +++ b/lib/sqlalchemy/sql/cache_key.py @@ -7,10 +7,12 @@ from __future__ import annotations -from collections import namedtuple import enum from itertools import zip_longest +import typing +from typing import Any from typing import Callable +from typing import NamedTuple from typing import Union from .visitors import anon_map @@ -22,6 +24,10 @@ from ..util import HasMemoized from ..util.typing import Literal +if typing.TYPE_CHECKING: + from .elements import BindParameter + + class CacheConst(enum.Enum): NO_CACHE = 0 @@ -345,7 +351,7 @@ class MemoizedHasCacheKey(HasCacheKey, HasMemoized): return HasCacheKey._generate_cache_key(self) -class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])): +class CacheKey(NamedTuple): """The key used to identify a SQL statement construct in the SQL compilation cache. @@ -355,6 +361,9 @@ class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])): """ + key: Tuple[Any, ...] + bindparams: Sequence[BindParameter] + def __hash__(self): """CacheKey itself is not hashable - hash the .key portion""" diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index d0f114d6c..712d31462 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -27,6 +27,7 @@ from __future__ import annotations import collections import collections.abc as collections_abc import contextlib +from enum import IntEnum import itertools import operator import re @@ -35,9 +36,13 @@ import typing from typing import Any from typing import Dict from typing import List +from typing import Mapping from typing import MutableMapping +from typing import NamedTuple from typing import Optional +from typing import Sequence from typing import Tuple +from typing import Union from . import base from . import coercions @@ -51,12 +56,17 @@ from . import sqltypes from .base import NO_ARG from .base import prefix_anon_map from .elements import quoted_name +from .schema import Column +from .type_api import TypeEngine from .. import exc from .. import util +from ..util.typing import Literal if typing.TYPE_CHECKING: from .selectable import CTE from .selectable import FromClause + from ..engine.interfaces import _CoreSingleExecuteParams + from ..engine.result import _ProcessorType _FromHintsType = Dict["FromClause", str] @@ -271,42 +281,71 @@ COMPOUND_KEYWORDS = { } -RM_RENDERED_NAME = 0 -RM_NAME = 1 -RM_OBJECTS = 2 -RM_TYPE = 3 +class ResultColumnsEntry(NamedTuple): + """Tracks a column expression that is expected to be represented + in the result rows for this statement. + This normally refers to the columns clause of a SELECT statement + but may also refer to a RETURNING clause, as well as for dialect-specific + emulations. -ExpandedState = collections.namedtuple( - "ExpandedState", - [ - "statement", - "additional_parameters", - "processors", - "positiontup", - "parameter_expansion", - ], -) + """ + keyname: str + """string name that's expected in cursor.description""" -NO_LINTING = util.symbol("NO_LINTING", "Disable all linting.", canonical=0) + name: str + """column name, may be labeled""" -COLLECT_CARTESIAN_PRODUCTS = util.symbol( - "COLLECT_CARTESIAN_PRODUCTS", - "Collect data on FROMs and cartesian products and gather " - "into 'self.from_linter'", - canonical=1, -) + objects: List[Any] + """list of objects that should be able to locate this column + in a RowMapping. This is typically string names and aliases + as well as Column objects. -WARN_LINTING = util.symbol( - "WARN_LINTING", "Emit warnings for linters that find problems", canonical=2 -) + """ + + type: TypeEngine[Any] + """Datatype to be associated with this column. This is where + the "result processing" logic directly links the compiled statement + to the rows that come back from the cursor. + + """ + + +# integer indexes into ResultColumnsEntry used by cursor.py. +# some profiling showed integer access faster than named tuple +RM_RENDERED_NAME: Literal[0] = 0 +RM_NAME: Literal[1] = 1 +RM_OBJECTS: Literal[2] = 2 +RM_TYPE: Literal[3] = 3 + + +class ExpandedState(NamedTuple): + statement: str + additional_parameters: _CoreSingleExecuteParams + processors: Mapping[str, _ProcessorType] + positiontup: Optional[Sequence[str]] + parameter_expansion: Mapping[str, List[str]] + + +class Linting(IntEnum): + NO_LINTING = 0 + "Disable all linting." + + COLLECT_CARTESIAN_PRODUCTS = 1 + """Collect data on FROMs and cartesian products and gather into + 'self.from_linter'""" + + WARN_LINTING = 2 + "Emit warnings for linters that find problems" -FROM_LINTING = util.symbol( - "FROM_LINTING", - "Warn for cartesian products; " - "combines COLLECT_CARTESIAN_PRODUCTS and WARN_LINTING", - canonical=COLLECT_CARTESIAN_PRODUCTS | WARN_LINTING, + FROM_LINTING = COLLECT_CARTESIAN_PRODUCTS | WARN_LINTING + """Warn for cartesian products; combines COLLECT_CARTESIAN_PRODUCTS + and WARN_LINTING""" + + +NO_LINTING, COLLECT_CARTESIAN_PRODUCTS, WARN_LINTING, FROM_LINTING = tuple( + Linting ) @@ -389,7 +428,7 @@ class Compiled: _cached_metadata = None - _result_columns = None + _result_columns: Optional[List[ResultColumnsEntry]] = None schema_translate_map = None @@ -418,7 +457,8 @@ class Compiled: """ cache_key = None - _gen_time = None + + _gen_time: float def __init__( self, @@ -573,15 +613,43 @@ class SQLCompiler(Compiled): extract_map = EXTRACT_MAP + _result_columns: List[ResultColumnsEntry] + compound_keywords = COMPOUND_KEYWORDS - isdelete = isinsert = isupdate = False + isdelete: bool = False + isinsert: bool = False + isupdate: bool = False """class-level defaults which can be set at the instance level to define if this Compiled instance represents INSERT/UPDATE/DELETE """ - isplaintext = False + postfetch: Optional[List[Column[Any]]] + """list of columns that can be post-fetched after INSERT or UPDATE to + receive server-updated values""" + + insert_prefetch: Optional[List[Column[Any]]] + """list of columns for which default values should be evaluated before + an INSERT takes place""" + + update_prefetch: Optional[List[Column[Any]]] + """list of columns for which onupdate default values should be evaluated + before an UPDATE takes place""" + + returning: Optional[List[Column[Any]]] + """list of columns that will be delivered to cursor.description or + dialect equivalent via the RETURNING clause on an INSERT, UPDATE, or DELETE + + """ + + isplaintext: bool = False + + result_columns: List[ResultColumnsEntry] + """relates label names in the final SQL to a tuple of local + column/label name, ColumnElement object (if any) and + TypeEngine. CursorResult uses this for type processing and + column targeting""" returning = None """holds the "returning" collection of columns if @@ -589,18 +657,18 @@ class SQLCompiler(Compiled): either implicitly or explicitly """ - returning_precedes_values = False + returning_precedes_values: bool = False """set to True classwide to generate RETURNING clauses before the VALUES or WHERE clause (i.e. MSSQL) """ - render_table_with_column_in_update_from = False + render_table_with_column_in_update_from: bool = False """set to True classwide to indicate the SET clause in a multi-table UPDATE statement should qualify columns with the table name (i.e. MySQL only) """ - ansi_bind_rules = False + ansi_bind_rules: bool = False """SQL 92 doesn't allow bind parameters to be used in the columns clause of a SELECT, nor does it allow ambiguous expressions like "? = ?". A compiler @@ -608,33 +676,33 @@ class SQLCompiler(Compiled): driver/DB enforces this """ - _textual_ordered_columns = False + _textual_ordered_columns: bool = False """tell the result object that the column names as rendered are important, but they are also "ordered" vs. what is in the compiled object here. """ - _ordered_columns = True + _ordered_columns: bool = True """ if False, means we can't be sure the list of entries in _result_columns is actually the rendered order. Usually True unless using an unordered TextualSelect. """ - _loose_column_name_matching = False + _loose_column_name_matching: bool = False """tell the result object that the SQL statement is textual, wants to match up to Column objects, and may be using the ._tq_label in the SELECT rather than the base name. """ - _numeric_binds = False + _numeric_binds: bool = False """ True if paramstyle is "numeric". This paramstyle is trickier than all the others. """ - _render_postcompile = False + _render_postcompile: bool = False """ whether to render out POSTCOMPILE params during the compile phase. @@ -684,7 +752,7 @@ class SQLCompiler(Compiled): """ - positiontup = None + positiontup: Optional[Sequence[str]] = None """for a compiled construct that uses a positional paramstyle, will be a sequence of strings, indicating the names of bound parameters in order. @@ -699,7 +767,7 @@ class SQLCompiler(Compiled): """ - inline = False + inline: bool = False def __init__( self, @@ -760,10 +828,6 @@ class SQLCompiler(Compiled): # stack which keeps track of nested SELECT statements self.stack = [] - # relates label names in the final SQL to a tuple of local - # column/label name, ColumnElement object (if any) and - # TypeEngine. CursorResult uses this for type processing and - # column targeting self._result_columns = [] # true if the paramstyle is positional @@ -910,7 +974,9 @@ class SQLCompiler(Compiled): ) @util.memoized_property - def _bind_processors(self): + def _bind_processors( + self, + ) -> MutableMapping[str, Union[_ProcessorType, Sequence[_ProcessorType]]]: return dict( (key, value) for key, value in ( @@ -1098,8 +1164,10 @@ class SQLCompiler(Compiled): return self.construct_params(_check=False) def _process_parameters_for_postcompile( - self, parameters=None, _populate_self=False - ): + self, + parameters: Optional[_CoreSingleExecuteParams] = None, + _populate_self: bool = False, + ) -> ExpandedState: """handle special post compile parameters. These include: @@ -3070,7 +3138,13 @@ class SQLCompiler(Compiled): def get_render_as_alias_suffix(self, alias_name_text): return " AS " + alias_name_text - def _add_to_result_map(self, keyname, name, objects, type_): + def _add_to_result_map( + self, + keyname: str, + name: str, + objects: List[Any], + type_: TypeEngine[Any], + ) -> None: if keyname is None or keyname == "*": self._ordered_columns = False self._textual_ordered_columns = True @@ -3080,7 +3154,9 @@ class SQLCompiler(Compiled): "from a tuple() object. If this is an ORM query, " "consider using the Bundle object." ) - self._result_columns.append((keyname, name, objects, type_)) + self._result_columns.append( + ResultColumnsEntry(keyname, name, objects, type_) + ) def _label_returning_column(self, stmt, column, column_clause_args=None): """Render a column with necessary labels inside of a RETURNING clause. diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 0c532a135..ac5dc46db 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -61,6 +61,11 @@ if typing.TYPE_CHECKING: from .selectable import Select from .sqltypes import Boolean # noqa from .type_api import TypeEngine + from ..engine import Compiled + from ..engine import Connection + from ..engine import Dialect + from ..engine import Engine + _NUMERIC = Union[complex, "Decimal"] @@ -145,7 +150,12 @@ class CompilerElement(Visitable): @util.preload_module("sqlalchemy.engine.default") @util.preload_module("sqlalchemy.engine.url") - def compile(self, bind=None, dialect=None, **kw): + def compile( + self, + bind: Optional[Union[Engine, Connection]] = None, + dialect: Optional[Dialect] = None, + **kw: Any, + ) -> Compiled: """Compile this SQL expression. The return value is a :class:`~.Compiled` object. diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 528691795..fdae4d7b0 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -174,7 +174,13 @@ class SchemaItem(SchemaEventTarget, visitors.Visitable): _use_schema_map = True -class Table(DialectKWArgs, SchemaItem, TableClause): +class HasSchemaAttr(SchemaItem): + """schema item that includes a top-level schema name""" + + schema: Optional[str] + + +class Table(DialectKWArgs, HasSchemaAttr, TableClause): r"""Represent a table in a database. e.g.:: @@ -2850,7 +2856,7 @@ class IdentityOptions: self.order = order -class Sequence(IdentityOptions, DefaultGenerator): +class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator): """Represents a named database sequence. The :class:`.Sequence` object represents the name and configurational @@ -4330,7 +4336,7 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem): DEFAULT_NAMING_CONVENTION = util.immutabledict({"ix": "ix_%(column_0_label)s"}) -class MetaData(SchemaItem): +class MetaData(HasSchemaAttr): """A collection of :class:`_schema.Table` objects and their associated schema constructs. diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index e3e358cdb..e0248adf0 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -21,9 +21,6 @@ from . import coercions from . import operators from . import roles from . import visitors -from ._typing import _ExecuteParams -from ._typing import _MultiExecuteParams -from ._typing import _SingleExecuteParams from .annotation import _deep_annotate # noqa from .annotation import _deep_deannotate # noqa from .annotation import _shallow_annotate # noqa @@ -54,6 +51,10 @@ from .. import exc from .. import util if typing.TYPE_CHECKING: + from ..engine.interfaces import _AnyExecuteParams + from ..engine.interfaces import _AnyMultiExecuteParams + from ..engine.interfaces import _AnySingleExecuteParams + from ..engine.interfaces import _CoreSingleExecuteParams from ..engine.row import Row @@ -550,12 +551,12 @@ class _repr_params(_repr_base): def __init__( self, - params: _ExecuteParams, + params: Optional[_AnyExecuteParams], batches: int, max_chars: int = 300, ismulti: Optional[bool] = None, ): - self.params: _ExecuteParams = params + self.params = params self.ismulti = ismulti self.batches = batches self.max_chars = max_chars @@ -575,7 +576,10 @@ class _repr_params(_repr_base): return self.trunc(self.params) if self.ismulti: - multi_params = cast(_MultiExecuteParams, self.params) + multi_params = cast( + "_AnyMultiExecuteParams", + self.params, + ) if len(self.params) > self.batches: msg = ( @@ -595,10 +599,18 @@ class _repr_params(_repr_base): return self._repr_multi(multi_params, typ) else: return self._repr_params( - cast(_SingleExecuteParams, self.params), typ + cast( + "_AnySingleExecuteParams", + self.params, + ), + typ, ) - def _repr_multi(self, multi_params: _MultiExecuteParams, typ) -> str: + def _repr_multi( + self, + multi_params: _AnyMultiExecuteParams, + typ, + ) -> str: if multi_params: if isinstance(multi_params[0], list): elem_type = self._LIST @@ -622,13 +634,19 @@ class _repr_params(_repr_base): else: return "(%s)" % elements - def _repr_params(self, params: _SingleExecuteParams, typ: int) -> str: + def _repr_params( + self, + params: Optional[_AnySingleExecuteParams], + typ: int, + ) -> str: trunc = self.trunc if typ is self._DICT: return "{%s}" % ( ", ".join( "%r: %s" % (key, trunc(value)) - for key, value in params.items() + for key, value in cast( + "_CoreSingleExecuteParams", params + ).items() ) ) elif typ is self._TUPLE: diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 523426d09..111ecd32e 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -28,6 +28,8 @@ from __future__ import annotations from collections import deque import itertools import operator +import typing +from typing import Any from typing import List from typing import Tuple @@ -35,12 +37,13 @@ from .. import exc from .. import util from ..util import langhelpers from ..util import symbol +from ..util._has_cy import HAS_CYEXTENSION from ..util.langhelpers import _symbol -try: - from sqlalchemy.cyextension.util import cache_anon_map as anon_map # noqa -except ImportError: +if typing.TYPE_CHECKING or not HAS_CYEXTENSION: from ._py_util import cache_anon_map as anon_map # noqa +else: + from sqlalchemy.cyextension.util import cache_anon_map as anon_map # noqa __all__ = [ "iterate", @@ -554,7 +557,7 @@ class ExternalTraversal: __traverse_options__ = {} - def traverse_single(self, obj, **kw): + def traverse_single(self, obj: Visitable, **kw: Any) -> Any: for v in self.visitor_iterator: meth = getattr(v, "visit_%s" % obj.__visit_name__, None) if meth: diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index a41420504..e5cf9b92e 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -94,7 +94,6 @@ from .langhelpers import decode_slice as decode_slice from .langhelpers import decorator as decorator from .langhelpers import dictlike_iteritems as dictlike_iteritems from .langhelpers import duck_type_collection as duck_type_collection -from .langhelpers import dynamic_property as dynamic_property from .langhelpers import ellipses_string as ellipses_string from .langhelpers import EnsureKWArg as EnsureKWArg from .langhelpers import format_argspec_init as format_argspec_init diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index 84735316d..e0b53b445 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -513,7 +513,12 @@ class LRUCache(typing.MutableMapping[_KT, _VT]): threshold: float size_alert: Optional[Callable[["LRUCache[_KT, _VT]"], None]] - def __init__(self, capacity=100, threshold=0.5, size_alert=None): + def __init__( + self, + capacity: int = 100, + threshold: float = 0.5, + size_alert: Optional[Callable[..., None]] = None, + ): self.capacity = capacity self.threshold = threshold self.size_alert = size_alert diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py index ee54180ac..771e974e9 100644 --- a/lib/sqlalchemy/util/_py_collections.py +++ b/lib/sqlalchemy/util/_py_collections.py @@ -15,9 +15,11 @@ from typing import Dict from typing import Iterable from typing import Iterator from typing import List +from typing import Mapping from typing import NoReturn from typing import Optional from typing import Set +from typing import Tuple from typing import TypeVar from typing import Union @@ -65,13 +67,15 @@ class immutabledict(ImmutableDictBase[_KT, _VT]): dict.__init__(new, *args) return new - def __init__(self, *args): + def __init__(self, *args: Union[Mapping[_KT, _VT], Tuple[_KT, _VT]]): pass def __reduce__(self): return immutabledict, (dict(self),) - def union(self, __d=None): + def union( + self, __d: Optional[Mapping[_KT, _VT]] = None + ) -> immutabledict[_KT, _VT]: if not __d: return self @@ -80,7 +84,9 @@ class immutabledict(ImmutableDictBase[_KT, _VT]): dict.update(new, __d) return new - def _union_w_kw(self, __d=None, **kw): + def _union_w_kw( + self, __d: Optional[Mapping[_KT, _VT]] = None, **kw: _VT + ) -> immutabledict[_KT, _VT]: # not sure if C version works correctly w/ this yet if not __d and not kw: return self @@ -92,7 +98,9 @@ class immutabledict(ImmutableDictBase[_KT, _VT]): dict.update(new, kw) # type: ignore return new - def merge_with(self, *dicts): + def merge_with( + self, *dicts: Optional[Mapping[_KT, _VT]] + ) -> immutabledict[_KT, _VT]: new = None for d in dicts: if d: @@ -105,7 +113,7 @@ class immutabledict(ImmutableDictBase[_KT, _VT]): return new - def __repr__(self): + def __repr__(self) -> str: return "immutabledict(%s)" % dict.__repr__(self) diff --git a/lib/sqlalchemy/util/deprecations.py b/lib/sqlalchemy/util/deprecations.py index 7e1d3213a..a8e58a8bf 100644 --- a/lib/sqlalchemy/util/deprecations.py +++ b/lib/sqlalchemy/util/deprecations.py @@ -28,7 +28,6 @@ from . import compat from .langhelpers import _hash_limit_string from .langhelpers import _warnings_warn from .langhelpers import decorator -from .langhelpers import dynamic_property from .langhelpers import inject_docstring_text from .langhelpers import inject_param_text from .. import exc @@ -103,7 +102,7 @@ def deprecated_property( add_deprecation_to_docstring: bool = True, warning: Optional[Type[exc.SADeprecationWarning]] = None, enable_warnings: bool = True, -) -> Callable[[Callable[..., _T]], dynamic_property[_T]]: +) -> Callable[[Callable[..., Any]], property]: """the @deprecated decorator with a @property. E.g.:: @@ -131,8 +130,8 @@ def deprecated_property( """ - def decorate(fn: Callable[..., _T]) -> dynamic_property[_T]: - return dynamic_property( + def decorate(fn: Callable[..., Any]) -> property: + return property( deprecated( version, message=message, diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 43f9d5c73..1e79fd547 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -317,15 +317,17 @@ _P = compat_typing.ParamSpec("_P") class PluginLoader: - def __init__(self, group, auto_fn=None): + def __init__( + self, group: str, auto_fn: Optional[Callable[..., Any]] = None + ): self.group = group - self.impls = {} + self.impls: Dict[str, Any] = {} self.auto_fn = auto_fn def clear(self): self.impls.clear() - def load(self, name): + def load(self, name: str) -> Any: if name in self.impls: return self.impls[name]() @@ -344,7 +346,7 @@ class PluginLoader: "Can't load plugin: %s:%s" % (self.group, name) ) - def register(self, name, modulepath, objname): + def register(self, name: str, modulepath: str, objname: str) -> None: def load(): mod = __import__(modulepath) for token in modulepath.split(".")[1:]: @@ -444,7 +446,7 @@ def get_cls_kwargs( return _set -def get_func_kwargs(func): +def get_func_kwargs(func: Callable[..., Any]) -> List[str]: """Return the set of legal kwargs for the given `func`. Uses getargspec so is safe to call for methods, functions, @@ -1125,22 +1127,13 @@ def as_interface(obj, cls=None, methods=None, required=None): ) -Selfdynamic_property = TypeVar( - "Selfdynamic_property", bound="dynamic_property[Any]" -) - Selfmemoized_property = TypeVar( "Selfmemoized_property", bound="memoized_property[Any]" ) -class dynamic_property(Generic[_T]): - """A read-only @property that is evaluated each time. - - This is mostly the same as @property except we can type it - alongside memoized_property - - """ +class memoized_property(Generic[_T]): + """A read-only @property that is only evaluated once.""" fget: Callable[..., _T] __doc__: Optional[str] @@ -1153,27 +1146,6 @@ class dynamic_property(Generic[_T]): @overload def __get__( - self: Selfdynamic_property, obj: None, cls: Any - ) -> Selfdynamic_property: - ... - - @overload - def __get__(self, obj: Any, cls: Any) -> _T: - ... - - def __get__( - self: Selfdynamic_property, obj: Any, cls: Any - ) -> Union[Selfdynamic_property, _T]: - if obj is None: - return self - return self.fget(obj) # type: ignore[no-any-return] - - -class memoized_property(dynamic_property[_T]): - """A read-only @property that is only evaluated once.""" - - @overload - def __get__( self: Selfmemoized_property, obj: None, cls: Any ) -> Selfmemoized_property: ... @@ -1231,24 +1203,27 @@ def memoized_instancemethod(fn): class HasMemoized: - """A class that maintains the names of memoized elements in a + """A mixin class that maintains the names of memoized elements in a collection for easy cache clearing, generative, etc. """ - __slots__ = () + if not typing.TYPE_CHECKING: + # support classes that want to have __slots__ with an explicit + # slot for __dict__. not sure if that requires base __slots__ here. + __slots__ = () _memoized_keys: FrozenSet[str] = frozenset() - def _reset_memoizations(self): + def _reset_memoizations(self) -> None: for elem in self._memoized_keys: self.__dict__.pop(elem, None) - def _assert_no_memoizations(self): + def _assert_no_memoizations(self) -> None: for elem in self._memoized_keys: assert elem not in self.__dict__ - def _set_memoized_attribute(self, key, value): + def _set_memoized_attribute(self, key: str, value: Any) -> None: self.__dict__[key] = value self._memoized_keys |= {key} @@ -1342,7 +1317,7 @@ class MemoizedSlots: # from paste.deploy.converters -def asbool(obj): +def asbool(obj: Any) -> bool: if isinstance(obj, str): obj = obj.strip().lower() if obj in ["true", "yes", "on", "y", "t", "1"]: @@ -1354,13 +1329,13 @@ def asbool(obj): return bool(obj) -def bool_or_str(*text): +def bool_or_str(*text: str) -> Callable[[str], Union[str, bool]]: """Return a callable that will evaluate a string as boolean, or one of a set of "alternate" string values. """ - def bool_or_value(obj): + def bool_or_value(obj: str) -> Union[str, bool]: if obj in text: return obj else: @@ -1369,7 +1344,7 @@ def bool_or_str(*text): return bool_or_value -def asint(value): +def asint(value: Any) -> Optional[int]: """Coerce to integer.""" if value is None: @@ -1377,7 +1352,13 @@ def asint(value): return int(value) -def coerce_kw_type(kw, key, type_, flexi_bool=True, dest=None): +def coerce_kw_type( + kw: Dict[str, Any], + key: str, + type_: Type[Any], + flexi_bool: bool = True, + dest: Optional[Dict[str, Any]] = None, +) -> None: r"""If 'key' is present in dict 'kw', coerce its value to type 'type\_' if necessary. If 'flexi_bool' is True, the string '0' is considered false when coercing to boolean. @@ -1397,7 +1378,7 @@ def coerce_kw_type(kw, key, type_, flexi_bool=True, dest=None): dest[key] = type_(kw[key]) -def constructor_key(obj, cls): +def constructor_key(obj: Any, cls: Type[Any]) -> Tuple[Any, ...]: """Produce a tuple structure that is cacheable using the __dict__ of obj to retrieve values @@ -1408,7 +1389,7 @@ def constructor_key(obj, cls): ) -def constructor_copy(obj, cls, *args, **kw): +def constructor_copy(obj: _T, cls: Type[_T], *args: Any, **kw: Any) -> _T: """Instantiate cls using the __dict__ of obj as constructor arguments. Uses inspect to match the named arguments of ``cls``. @@ -1422,7 +1403,7 @@ def constructor_copy(obj, cls, *args, **kw): return cls(*args, **kw) -def counter(): +def counter() -> Callable[[], int]: """Return a threadsafe counter function.""" lock = threading.Lock() @@ -1436,47 +1417,51 @@ def counter(): return _next -def duck_type_collection(specimen, default=None): +def duck_type_collection( + specimen: Union[object, Type[Any]], default: Optional[Type[Any]] = None +) -> Type[Any]: """Given an instance or class, guess if it is or is acting as one of the basic collection types: list, set and dict. If the __emulates__ property is present, return that preferentially. """ + if typing.TYPE_CHECKING: + return object + else: + if hasattr(specimen, "__emulates__"): + # canonicalize set vs sets.Set to a standard: the builtin set + if specimen.__emulates__ is not None and issubclass( + specimen.__emulates__, set + ): + return set + else: + return specimen.__emulates__ - if hasattr(specimen, "__emulates__"): - # canonicalize set vs sets.Set to a standard: the builtin set - if specimen.__emulates__ is not None and issubclass( - specimen.__emulates__, set - ): + isa = isinstance(specimen, type) and issubclass or isinstance + if isa(specimen, list): + return list + elif isa(specimen, set): return set + elif isa(specimen, dict): + return dict + + if hasattr(specimen, "append"): + return list + elif hasattr(specimen, "add"): + return set + elif hasattr(specimen, "set"): + return dict else: - return specimen.__emulates__ - - isa = isinstance(specimen, type) and issubclass or isinstance - if isa(specimen, list): - return list - elif isa(specimen, set): - return set - elif isa(specimen, dict): - return dict - - if hasattr(specimen, "append"): - return list - elif hasattr(specimen, "add"): - return set - elif hasattr(specimen, "set"): - return dict - else: - return default + return default -def assert_arg_type(arg, argtype, name): +def assert_arg_type(arg: Any, argtype: Type[Any], name: str) -> Any: if isinstance(arg, argtype): return arg else: if isinstance(argtype, tuple): raise exc.ArgumentError( "Argument '%s' is expected to be one of type %s, got '%s'" - % (name, " or ".join("'%s'" % a for a in argtype), type(arg)) + % (name, " or ".join("'%s'" % a for a in argtype), type(arg)) # type: ignore # noqa E501 ) else: raise exc.ArgumentError( diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index ddda420db..ad9c8e531 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -13,7 +13,7 @@ from typing import Type from typing import TypeVar from typing import Union -from typing_extensions import NotRequired # noqa +from typing_extensions import NotRequired as NotRequired # noqa from . import compat diff --git a/pyproject.toml b/pyproject.toml index e79c7292d..b2754b193 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,14 @@ markers = [ [tool.pyright] include = [ + "lib/sqlalchemy/engine/base.py", + "lib/sqlalchemy/engine/events.py", + "lib/sqlalchemy/engine/interfaces.py", + "lib/sqlalchemy/engine/_py_row.py", + "lib/sqlalchemy/engine/result.py", + "lib/sqlalchemy/engine/row.py", + "lib/sqlalchemy/engine/util.py", + "lib/sqlalchemy/engine/url.py", "lib/sqlalchemy/pool/", "lib/sqlalchemy/event/", "lib/sqlalchemy/events.py", @@ -79,9 +87,19 @@ strict = true # the whole library 100% strictly typed, so we have to tune this based on # the type of module or package we are dealing with +[[tool.mypy.overrides]] +# ad-hoc ignores +module = [ + "sqlalchemy.engine.reflection", # interim, should be strict +] + +ignore_errors = true + # strict checking [[tool.mypy.overrides]] module = [ + "sqlalchemy.connectors.*", + "sqlalchemy.engine.*", "sqlalchemy.pool.*", "sqlalchemy.event.*", "sqlalchemy.events", @@ -95,11 +113,16 @@ strict = true # partial checking, internals can be untyped [[tool.mypy.overrides]] -module="sqlalchemy.util.*" + +module = [ + "sqlalchemy.util.*", + "sqlalchemy.engine.cursor", + "sqlalchemy.engine.default", +] + + ignore_errors = false -# util is for internal use so we can get by without everything -# being typed allow_untyped_defs = true check_untyped_defs = false allow_untyped_calls = true diff --git a/test/base/test_result.py b/test/base/test_result.py index 8818ccb14..7a696d352 100644 --- a/test/base/test_result.py +++ b/test/base/test_result.py @@ -1,7 +1,6 @@ from sqlalchemy import exc from sqlalchemy import testing from sqlalchemy.engine import result -from sqlalchemy.engine.row import Row from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ @@ -728,27 +727,6 @@ class ResultTest(fixtures.TestBase): # still slices eq_(m1.fetchone(), {"b": 1, "c": 2}) - def test_alt_row_fetch(self): - class AppleRow(Row): - def apple(self): - return "apple" - - result = self._fixture(alt_row=AppleRow) - - row = result.all()[0] - eq_(row.apple(), "apple") - - def test_alt_row_transform(self): - class AppleRow(Row): - def apple(self): - return "apple" - - result = self._fixture(alt_row=AppleRow) - - row = result.columns("c", "a").all()[2] - eq_(row.apple(), "apple") - eq_(row, (2, 1)) - def test_scalar_none_iterate(self): result = self._fixture( data=[ diff --git a/test/dialect/mssql/test_engine.py b/test/dialect/mssql/test_engine.py index b40981a99..d54a37ceb 100644 --- a/test/dialect/mssql/test_engine.py +++ b/test/dialect/mssql/test_engine.py @@ -34,19 +34,19 @@ class ParseConnectTest(fixtures.TestBase): dialect = pyodbc.dialect() u = url.make_url("mssql+pyodbc://mydsn") connection = dialect.create_connect_args(u) - eq_([["dsn=mydsn;Trusted_Connection=Yes"], {}], connection) + eq_((("dsn=mydsn;Trusted_Connection=Yes",), {}), connection) def test_pyodbc_connect_old_style_dsn_trusted(self): dialect = pyodbc.dialect() u = url.make_url("mssql+pyodbc:///?dsn=mydsn") connection = dialect.create_connect_args(u) - eq_([["dsn=mydsn;Trusted_Connection=Yes"], {}], connection) + eq_((("dsn=mydsn;Trusted_Connection=Yes",), {}), connection) def test_pyodbc_connect_dsn_non_trusted(self): dialect = pyodbc.dialect() u = url.make_url("mssql+pyodbc://username:password@mydsn") connection = dialect.create_connect_args(u) - eq_([["dsn=mydsn;UID=username;PWD=password"], {}], connection) + eq_((("dsn=mydsn;UID=username;PWD=password",), {}), connection) def test_pyodbc_connect_dsn_extra(self): dialect = pyodbc.dialect() @@ -66,13 +66,13 @@ class ParseConnectTest(fixtures.TestBase): ) connection = dialect.create_connect_args(u) eq_( - [ - [ + ( + ( "DRIVER={SQL Server};Server=hostspec;Database=database;UI" - "D=username;PWD=password" - ], + "D=username;PWD=password", + ), {}, - ], + ), connection, ) @@ -99,13 +99,13 @@ class ParseConnectTest(fixtures.TestBase): ) eq_( - [ - [ + ( + ( "Server=hostspec;Database=database;UI" - "D=username;PWD=password" - ], + "D=username;PWD=password", + ), {}, - ], + ), connection, ) @@ -117,13 +117,13 @@ class ParseConnectTest(fixtures.TestBase): ) connection = dialect.create_connect_args(u) eq_( - [ - [ + ( + ( "DRIVER={SQL Server};Server=hostspec,12345;Database=datab" - "ase;UID=username;PWD=password" - ], + "ase;UID=username;PWD=password", + ), {}, - ], + ), connection, ) @@ -135,13 +135,13 @@ class ParseConnectTest(fixtures.TestBase): ) connection = dialect.create_connect_args(u) eq_( - [ - [ + ( + ( "DRIVER={SQL Server};Server=hostspec;Database=database;UI" - "D=username;PWD=password;port=12345" - ], + "D=username;PWD=password;port=12345", + ), {}, - ], + ), connection, ) @@ -193,13 +193,13 @@ class ParseConnectTest(fixtures.TestBase): ) connection = dialect.create_connect_args(u) eq_( - [ - [ + ( + ( "DRIVER={SQL Server};Server=hostspec;Database=database;UI" - "D=username;PWD=password" - ], + "D=username;PWD=password", + ), {}, - ], + ), connection, ) @@ -211,7 +211,7 @@ class ParseConnectTest(fixtures.TestBase): ) connection = dialect.create_connect_args(u) eq_( - [["dsn=mydsn;Database=database;UID=username;PWD=password"], {}], + (("dsn=mydsn;Database=database;UID=username;PWD=password",), {}), connection, ) @@ -225,13 +225,13 @@ class ParseConnectTest(fixtures.TestBase): ) connection = dialect.create_connect_args(u) eq_( - [ - [ + ( + ( "DRIVER={SQL Server};Server=hostspec;Database=database;UI" - "D=username;PWD=password" - ], + "D=username;PWD=password", + ), {}, - ], + ), connection, ) @@ -248,14 +248,14 @@ class ParseConnectTest(fixtures.TestBase): dialect = pyodbc.dialect() connection = dialect.create_connect_args(u) eq_( - [ - [ + ( + ( "DRIVER={foob};Server=somehost%3BPORT%3D50001;" "Database=somedb%3BPORT%3D50001;UID={someuser;PORT=50001};" - "PWD={some{strange}}pw;PORT=50001}" - ], + "PWD={some{strange}}pw;PORT=50001}", + ), {}, - ], + ), connection, ) @@ -265,7 +265,7 @@ class ParseConnectTest(fixtures.TestBase): u = url.make_url("mssql+pymssql://scott:tiger@somehost/test") connection = dialect.create_connect_args(u) eq_( - [ + ( [], { "host": "somehost", @@ -273,14 +273,14 @@ class ParseConnectTest(fixtures.TestBase): "user": "scott", "database": "test", }, - ], + ), connection, ) u = url.make_url("mssql+pymssql://scott:tiger@somehost:5000/test") connection = dialect.create_connect_args(u) eq_( - [ + ( [], { "host": "somehost:5000", @@ -288,7 +288,7 @@ class ParseConnectTest(fixtures.TestBase): "user": "scott", "database": "test", }, - ], + ), connection, ) @@ -584,7 +584,9 @@ class VersionDetectionTest(fixtures.TestBase): ) ) ), - connection=Mock(getinfo=Mock(return_value=vers)), + connection=Mock( + dbapi_connection=Mock(getinfo=Mock(return_value=vers)), + ), ) eq_(dialect._get_server_version_info(conn), expected) diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py index 613fc80a5..5a92ae6fe 100644 --- a/test/dialect/postgresql/test_dialect.py +++ b/test/dialect/postgresql/test_dialect.py @@ -368,11 +368,11 @@ class ExecuteManyMode: mock.call( mock.ANY, stmt, - ( + [ {"x": "x1", "y": "y1"}, {"x": "x2", "y": "y2"}, {"x": "x3", "y": "y3"}, - ), + ], **expected_kwargs, ) ], @@ -417,11 +417,11 @@ class ExecuteManyMode: mock.call( mock.ANY, stmt, - ( + [ {"x": "x1", "y": "y1"}, {"x": "x2", "y": "y2"}, {"x": "x3", "y": "y3"}, - ), + ], **expected_kwargs, ) ], @@ -470,11 +470,11 @@ class ExecuteManyMode: mock.call( mock.ANY, stmt, - ( + [ {"x": "x1", "y": "y1"}, {"x": "x2", "y": "y2"}, {"x": "x3", "y": "y3"}, - ), + ], **expected_kwargs, ) ], @@ -524,10 +524,10 @@ class ExecuteManyMode: mock.call( mock.ANY, stmt, - ( + [ {"xval": "x1", "yval": "y5"}, {"xval": "x3", "yval": "y6"}, - ), + ], **expected_kwargs, ) ], @@ -714,11 +714,11 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest): mock.call( mock.ANY, "INSERT INTO data (id, x, y, z) VALUES %s", - ( + [ {"id": 1, "y": "y1", "z": 1}, {"id": 2, "y": "y2", "z": 2}, {"id": 3, "y": "y3", "z": 3}, - ), + ], template="(%(id)s, (SELECT 5 \nFROM data), %(y)s, %(z)s)", fetch=False, page_size=connection.dialect.executemany_values_page_size, diff --git a/test/engine/test_deprecations.py b/test/engine/test_deprecations.py index 8dc1f0f48..e1c610701 100644 --- a/test/engine/test_deprecations.py +++ b/test/engine/test_deprecations.py @@ -87,6 +87,93 @@ class ConnectionlessDeprecationTest(fixtures.TestBase): class CreateEngineTest(fixtures.TestBase): + @testing.requires.sqlite + def test_dbapi_clsmethod_renamed(self): + """The dbapi() class method is renamed to import_dbapi(), + so that the .dbapi attribute can be exclusively an instance + attribute. + + """ + + from sqlalchemy.dialects.sqlite import pysqlite + from sqlalchemy.dialects import registry + + canary = mock.Mock() + + class MyDialect(pysqlite.SQLiteDialect_pysqlite): + @classmethod + def dbapi(cls): + canary() + return __import__("sqlite3") + + tokens = __name__.split(".") + + global dialect + dialect = MyDialect + + registry.register( + "mockdialect1.sqlite", ".".join(tokens[0:-1]), tokens[-1] + ) + + with expect_deprecated( + r"The dbapi\(\) classmethod on dialect classes has " + r"been renamed to import_dbapi\(\). Implement an " + r"import_dbapi\(\) classmethod directly on class " + r".*MyDialect.* to remove this warning; the old " + r".dbapi\(\) classmethod may be maintained for backwards " + r"compatibility." + ): + e = create_engine("mockdialect1+sqlite://") + + eq_(canary.mock_calls, [mock.call()]) + sqlite3 = __import__("sqlite3") + is_(e.dialect.dbapi, sqlite3) + + @testing.requires.sqlite + def test_no_warning_for_dual_dbapi_clsmethod(self): + """The dbapi() class method is renamed to import_dbapi(), + so that the .dbapi attribute can be exclusively an instance + attribute. + + Dialect classes will likely have both a dbapi() classmethod + as well as an import_dbapi() class method to maintain + cross-compatibility. Make sure these updated classes don't get a + warning and that the new method is used. + + """ + + from sqlalchemy.dialects.sqlite import pysqlite + from sqlalchemy.dialects import registry + + canary = mock.Mock() + + class MyDialect(pysqlite.SQLiteDialect_pysqlite): + @classmethod + def dbapi(cls): + canary.dbapi() + return __import__("sqlite3") + + @classmethod + def import_dbapi(cls): + canary.import_dbapi() + return __import__("sqlite3") + + tokens = __name__.split(".") + + global dialect + dialect = MyDialect + + registry.register( + "mockdialect2.sqlite", ".".join(tokens[0:-1]), tokens[-1] + ) + + # no warning + e = create_engine("mockdialect2+sqlite://") + + eq_(canary.mock_calls, [mock.call.import_dbapi()]) + sqlite3 = __import__("sqlite3") + is_(e.dialect.dbapi, sqlite3) + def test_strategy_keyword_mock(self): def executor(x, y): pass diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index 59bc4863f..dbd957703 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -2285,6 +2285,40 @@ class EngineEventsTest(fixtures.TestBase): [call(c2, {"c1": "opt_c1"}), call(c4, {"c3": "opt_c3"})], ) + def test_execution_options_modify_inplace(self): + engine = engines.testing_engine() + + @event.listens_for(engine, "set_engine_execution_options") + def engine_tracker(conn, opt): + opt["engine_tracked"] = True + + @event.listens_for(engine, "set_connection_execution_options") + def conn_tracker(conn, opt): + opt["conn_tracked"] = True + + with mock.patch.object( + engine.dialect, "set_connection_execution_options" + ) as conn_opt, mock.patch.object( + engine.dialect, "set_engine_execution_options" + ) as engine_opt: + e2 = engine.execution_options(e1="opt_e1") + c1 = engine.connect() + c2 = c1.execution_options(c1="opt_c1") + + is_not(e2, engine) + is_(c1, c2) + + eq_(e2._execution_options, {"e1": "opt_e1", "engine_tracked": True}) + eq_(c2._execution_options, {"c1": "opt_c1", "conn_tracked": True}) + eq_( + engine_opt.mock_calls, + [mock.call(e2, {"e1": "opt_e1", "engine_tracked": True})], + ) + eq_( + conn_opt.mock_calls, + [mock.call(c1, {"c1": "opt_c1", "conn_tracked": True})], + ) + @testing.requires.sequences @testing.provide_metadata def test_cursor_execute(self): diff --git a/test/engine/test_parseconnect.py b/test/engine/test_parseconnect.py index 4c378dda1..23be61aaf 100644 --- a/test/engine/test_parseconnect.py +++ b/test/engine/test_parseconnect.py @@ -1111,7 +1111,7 @@ class TestGetDialect(fixtures.TestBase): class MockDialect(DefaultDialect): @classmethod - def dbapi(cls, **kw): + def import_dbapi(cls, **kw): return MockDBAPI() diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py index c1f0639bb..b8d9a3618 100644 --- a/test/engine/test_reconnect.py +++ b/test/engine/test_reconnect.py @@ -1023,6 +1023,23 @@ class RealReconnectTest(fixtures.TestBase): eq_(conn.execute(select(1)).scalar(), 1) assert not conn.invalidated + def test_detach_invalidated(self): + with self.engine.connect() as conn: + conn.invalidate() + with expect_raises_message( + exc.InvalidRequestError, + "Can't detach an invalidated Connection", + ): + conn.detach() + + def test_detach_closed(self): + with self.engine.connect() as conn: + pass + with expect_raises_message( + exc.ResourceClosedError, "This Connection is closed" + ): + conn.detach() + @testing.requires.independent_connections def test_multiple_invalidate(self): c1 = self.engine.connect() @@ -1078,8 +1095,23 @@ class RealReconnectTest(fixtures.TestBase): conn.begin() trans2 = conn.begin_nested() conn.invalidate() + + # this passes silently, as it will often be involved + # in error catching schemes trans2.rollback() + # still invalid though + with expect_raises(exc.PendingRollbackError): + conn.begin_nested() + + def test_no_begin_on_invalid(self): + with self.engine.connect() as conn: + conn.begin() + conn.invalidate() + + with expect_raises(exc.PendingRollbackError): + conn.commit() + def test_invalidate_twice(self): with self.engine.connect() as conn: conn.invalidate() diff --git a/test/profiles.txt b/test/profiles.txt index 67f155eba..750b57780 100644 --- a/test/profiles.txt +++ b/test/profiles.txt @@ -98,48 +98,48 @@ test.aaa_profiling.test_misc.EnumTest.test_create_enum_from_pep_435_w_expensive_ # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_w_annotation -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_w_annotation x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 50235 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_w_annotation x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 62055 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_w_annotation x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 50735 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_w_annotation x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 61045 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_wo_annotation -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_wo_annotation x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 49335 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_wo_annotation x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 61155 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_wo_annotation x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 49435 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_wo_annotation x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 59745 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_w_annotations -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 53235 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 63155 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 53335 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 61745 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_wo_annotations -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 52335 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 62255 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 52435 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 60845 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 45435 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 49255 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 45035 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 48345 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_w_annotations -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 48235 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 57555 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 48335 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 56145 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_wo_annotations -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 47335 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 56655 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 47435 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 55245 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_w_annotations -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 34205 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 37905 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 33805 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 37005 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_wo_annotations -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 33305 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 37005 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 32905 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 36105 # TEST: test.aaa_profiling.test_orm.AttributeOverheadTest.test_attribute_set @@ -163,13 +163,13 @@ test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_unbound_branching # TEST: test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline -test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 15227 -test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 34246 +test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 15313 +test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 26332 # TEST: test.aaa_profiling.test_orm.DeferOptionsTest.test_defer_many_cols -test.aaa_profiling.test_orm.DeferOptionsTest.test_defer_many_cols x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 21393 -test.aaa_profiling.test_orm.DeferOptionsTest.test_defer_many_cols x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 28412 +test.aaa_profiling.test_orm.DeferOptionsTest.test_defer_many_cols x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 21377 +test.aaa_profiling.test_orm.DeferOptionsTest.test_defer_many_cols x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 26396 # TEST: test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased @@ -188,23 +188,23 @@ test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_plain x86_64_linux_cpy # TEST: test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 98439 -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 103939 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 98506 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 104006 # TEST: test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 96819 -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 102304 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 96844 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 102344 # TEST: test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query -test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 520593 -test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 522453 +test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 520615 +test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 522475 # TEST: test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results -test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 431505 -test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 464305 +test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 431905 +test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 450605 # TEST: test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_identity @@ -213,18 +213,18 @@ test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_ # TEST: test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity -test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 106782 -test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 115789 +test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 106870 +test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 115127 # TEST: test.aaa_profiling.test_orm.MergeBackrefsTest.test_merge_pending_with_all_pks -test.aaa_profiling.test_orm.MergeBackrefsTest.test_merge_pending_with_all_pks x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 19931 -test.aaa_profiling.test_orm.MergeBackrefsTest.test_merge_pending_with_all_pks x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 21463 +test.aaa_profiling.test_orm.MergeBackrefsTest.test_merge_pending_with_all_pks x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 20030 +test.aaa_profiling.test_orm.MergeBackrefsTest.test_merge_pending_with_all_pks x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 21434 # TEST: test.aaa_profiling.test_orm.MergeTest.test_merge_load -test.aaa_profiling.test_orm.MergeTest.test_merge_load x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 1362 -test.aaa_profiling.test_orm.MergeTest.test_merge_load x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 1458 +test.aaa_profiling.test_orm.MergeTest.test_merge_load x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 1366 +test.aaa_profiling.test_orm.MergeTest.test_merge_load x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 1455 # TEST: test.aaa_profiling.test_orm.MergeTest.test_merge_no_load @@ -233,18 +233,18 @@ test.aaa_profiling.test_orm.MergeTest.test_merge_no_load x86_64_linux_cpython_3. # TEST: test.aaa_profiling.test_orm.QueryTest.test_query_cols -test.aaa_profiling.test_orm.QueryTest.test_query_cols x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 6109 -test.aaa_profiling.test_orm.QueryTest.test_query_cols x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 7329 +test.aaa_profiling.test_orm.QueryTest.test_query_cols x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 6167 +test.aaa_profiling.test_orm.QueryTest.test_query_cols x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 6987 # TEST: test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results -test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 258605 -test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 281805 +test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 259205 +test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 278405 # TEST: test.aaa_profiling.test_orm.SessionTest.test_expire_lots -test.aaa_profiling.test_orm.SessionTest.test_expire_lots x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 1276 -test.aaa_profiling.test_orm.SessionTest.test_expire_lots x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 1268 +test.aaa_profiling.test_orm.SessionTest.test_expire_lots x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 1252 +test.aaa_profiling.test_orm.SessionTest.test_expire_lots x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 1260 # TEST: test.aaa_profiling.test_pool.QueuePoolTest.test_first_connect diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py index e5b1a0a26..ff70fc184 100644 --- a/test/sql/test_resultset.py +++ b/test/sql/test_resultset.py @@ -1612,15 +1612,15 @@ class CursorResultTest(fixtures.TablesTest): eq_(dict(row._mapping), {"a": "av", "b": "bv", "count": "cv"}) - with assertions.expect_raises_message( + with assertions.expect_raises( TypeError, - "TypeError: tuple indices must be integers or slices, not str", + "tuple indices must be integers or slices, not str", ): eq_(row["a"], "av") with assertions.expect_raises_message( TypeError, - "TypeError: tuple indices must be integers or slices, not str", + "tuple indices must be integers or slices, not str", ): eq_(row["count"], "cv") @@ -3197,8 +3197,7 @@ class GenerativeResultTest(fixtures.TablesTest): all_ = result.columns(*columns).all() eq_(all_, expected) - # ensure Row / LegacyRow comes out with .columns - assert type(all_[0]) is result._process_row + assert type(all_[0]) is Row def test_columns_twice(self, connection): users = self.tables.users @@ -3216,8 +3215,7 @@ class GenerativeResultTest(fixtures.TablesTest): ) eq_(all_, [("jack", 1)]) - # ensure Row / LegacyRow comes out with .columns - assert type(all_[0]) is result._process_row + assert type(all_[0]) is Row def test_columns_plus_getter(self, connection): users = self.tables.users |
