-rw-r--r--  doc/build/changelog/unreleased_20/eng_ex_opt.rst    |   9
-rw-r--r--  doc/build/tutorial/data_insert.rst                  |   6
-rw-r--r--  doc/build/tutorial/data_update.rst                  |   2
-rw-r--r--  doc/build/tutorial/dbapi_transactions.rst           |   8
-rw-r--r--  lib/sqlalchemy/connectors/__init__.py               |  12
-rw-r--r--  lib/sqlalchemy/connectors/pyodbc.py                 |  88
-rw-r--r--  lib/sqlalchemy/cyextension/resultproxy.pyx          |  29
-rw-r--r--  lib/sqlalchemy/dialects/mssql/pymssql.py            |   4
-rw-r--r--  lib/sqlalchemy/dialects/mssql/pyodbc.py             |   2
-rw-r--r--  lib/sqlalchemy/dialects/mysql/aiomysql.py           |   2
-rw-r--r--  lib/sqlalchemy/dialects/mysql/asyncmy.py            |   2
-rw-r--r--  lib/sqlalchemy/dialects/mysql/base.py               |   2
-rw-r--r--  lib/sqlalchemy/dialects/mysql/cymysql.py            |   2
-rw-r--r--  lib/sqlalchemy/dialects/mysql/mariadbconnector.py   |   2
-rw-r--r--  lib/sqlalchemy/dialects/mysql/mysqlconnector.py     |   2
-rw-r--r--  lib/sqlalchemy/dialects/mysql/mysqldb.py            |   2
-rw-r--r--  lib/sqlalchemy/dialects/mysql/pymysql.py            |   2
-rw-r--r--  lib/sqlalchemy/dialects/oracle/cx_oracle.py         |   2
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/asyncpg.py       |   2
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/pg8000.py        |   2
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/psycopg.py       |   4
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/psycopg2.py      |   2
-rw-r--r--  lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py  |   2
-rw-r--r--  lib/sqlalchemy/dialects/sqlite/aiosqlite.py         |   2
-rw-r--r--  lib/sqlalchemy/dialects/sqlite/pysqlcipher.py       |   2
-rw-r--r--  lib/sqlalchemy/dialects/sqlite/pysqlite.py          |   2
-rw-r--r--  lib/sqlalchemy/engine/__init__.py                   |   1
-rw-r--r--  lib/sqlalchemy/engine/_py_processors.py             |  30
-rw-r--r--  lib/sqlalchemy/engine/_py_row.py                    | 104
-rw-r--r--  lib/sqlalchemy/engine/_py_util.py                   |  38
-rw-r--r--  lib/sqlalchemy/engine/base.py                       | 698
-rw-r--r--  lib/sqlalchemy/engine/characteristics.py            |  35
-rw-r--r--  lib/sqlalchemy/engine/create.py                     |  65
-rw-r--r--  lib/sqlalchemy/engine/cursor.py                     | 107
-rw-r--r--  lib/sqlalchemy/engine/default.py                    | 313
-rw-r--r--  lib/sqlalchemy/engine/events.py                     | 193
-rw-r--r--  lib/sqlalchemy/engine/interfaces.py                 | 488
-rw-r--r--  lib/sqlalchemy/engine/mock.py                       |  57
-rw-r--r--  lib/sqlalchemy/engine/processors.py                 |  22
-rw-r--r--  lib/sqlalchemy/engine/result.py                     | 609
-rw-r--r--  lib/sqlalchemy/engine/row.py                        | 183
-rw-r--r--  lib/sqlalchemy/engine/url.py                        | 186
-rw-r--r--  lib/sqlalchemy/engine/util.py                       |  55
-rw-r--r--  lib/sqlalchemy/event/__init__.py                    |   1
-rw-r--r--  lib/sqlalchemy/event/attr.py                        |   4
-rw-r--r--  lib/sqlalchemy/event/base.py                        |  24
-rw-r--r--  lib/sqlalchemy/exc.py                               |  31
-rw-r--r--  lib/sqlalchemy/log.py                               |  10
-rw-r--r--  lib/sqlalchemy/pool/__init__.py                     |   4
-rw-r--r--  lib/sqlalchemy/pool/base.py                         | 139
-rw-r--r--  lib/sqlalchemy/sql/_typing.py                       |   9
-rw-r--r--  lib/sqlalchemy/sql/base.py                          |  41
-rw-r--r--  lib/sqlalchemy/sql/cache_key.py                     |  13
-rw-r--r--  lib/sqlalchemy/sql/compiler.py                      | 180
-rw-r--r--  lib/sqlalchemy/sql/elements.py                      |  12
-rw-r--r--  lib/sqlalchemy/sql/schema.py                        |  12
-rw-r--r--  lib/sqlalchemy/sql/util.py                          |  38
-rw-r--r--  lib/sqlalchemy/sql/visitors.py                      |  11
-rw-r--r--  lib/sqlalchemy/util/__init__.py                     |   1
-rw-r--r--  lib/sqlalchemy/util/_collections.py                 |   7
-rw-r--r--  lib/sqlalchemy/util/_py_collections.py              |  18
-rw-r--r--  lib/sqlalchemy/util/deprecations.py                 |   7
-rw-r--r--  lib/sqlalchemy/util/langhelpers.py                  | 137
-rw-r--r--  lib/sqlalchemy/util/typing.py                       |   2
-rw-r--r--  pyproject.toml                                      |  29
-rw-r--r--  test/base/test_result.py                            |  22
-rw-r--r--  test/dialect/mssql/test_engine.py                   |  90
-rw-r--r--  test/dialect/postgresql/test_dialect.py             |  20
-rw-r--r--  test/engine/test_deprecations.py                    |  87
-rw-r--r--  test/engine/test_execute.py                         |  34
-rw-r--r--  test/engine/test_parseconnect.py                    |   2
-rw-r--r--  test/engine/test_reconnect.py                       |  32
-rw-r--r--  test/profiles.txt                                   |  84
-rw-r--r--  test/sql/test_resultset.py                          |  12
74 files changed, 2963 insertions, 1530 deletions
diff --git a/doc/build/changelog/unreleased_20/eng_ex_opt.rst b/doc/build/changelog/unreleased_20/eng_ex_opt.rst
new file mode 100644
index 000000000..00947f3de
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/eng_ex_opt.rst
@@ -0,0 +1,9 @@
+.. change::
+ :tags: engine, feature
+
+ The :meth:`.ConnectionEvents.set_connection_execution_options`
+ and :meth:`.ConnectionEvents.set_engine_execution_options`
+ event hooks now allow the given options dictionary to be modified
+ in-place, where the new contents will be received as the ultimate
+ execution options to be acted upon. Previously, in-place modifications to
+ the dictionary were not supported.
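
A minimal sketch of what this change permits, assuming an ``engine`` created
with ``create_engine()``: the hook mutates the ``opts`` dictionary it is
given, and under the new behavior the mutated contents become the effective
execution options::

    from sqlalchemy import create_engine, event

    engine = create_engine("sqlite://")

    @event.listens_for(engine, "set_connection_execution_options")
    def _apply_default_opts(conn, opts):
        # in-place modification is now honored; previously such
        # changes were silently left out of the applied options
        opts.setdefault("logging_token", "myapp")

    with engine.connect() as conn:
        conn = conn.execution_options(stream_results=True)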
diff --git a/doc/build/tutorial/data_insert.rst b/doc/build/tutorial/data_insert.rst
index 74b0aff56..a8b1a49a2 100644
--- a/doc/build/tutorial/data_insert.rst
+++ b/doc/build/tutorial/data_insert.rst
@@ -127,7 +127,7 @@ illustrate this:
... conn.commit()
{opensql}BEGIN (implicit)
INSERT INTO user_account (name, fullname) VALUES (?, ?)
- [...] (('sandy', 'Sandy Cheeks'), ('patrick', 'Patrick Star'))
+ [...] [('sandy', 'Sandy Cheeks'), ('patrick', 'Patrick Star')]
COMMIT{stop}
The execution above features "executemany" form first illustrated at
@@ -185,8 +185,8 @@ construct automatically.
INSERT INTO address (user_id, email_address) VALUES ((SELECT user_account.id
FROM user_account
WHERE user_account.name = ?), ?)
- [...] (('spongebob', 'spongebob@sqlalchemy.org'), ('sandy', 'sandy@sqlalchemy.org'),
- ('sandy', 'sandy@squirrelpower.org'))
+ [...] [('spongebob', 'spongebob@sqlalchemy.org'), ('sandy', 'sandy@sqlalchemy.org'),
+ ('sandy', 'sandy@squirrelpower.org')]
COMMIT{stop}
.. _tutorial_insert_from_select:
diff --git a/doc/build/tutorial/data_update.rst b/doc/build/tutorial/data_update.rst
index 8813dda98..8e88eb2f7 100644
--- a/doc/build/tutorial/data_update.rst
+++ b/doc/build/tutorial/data_update.rst
@@ -101,7 +101,7 @@ that literal values would normally go:
... )
{opensql}BEGIN (implicit)
UPDATE user_account SET name=? WHERE user_account.name = ?
- [...] (('ed', 'jack'), ('mary', 'wendy'), ('jake', 'jim'))
+ [...] [('ed', 'jack'), ('mary', 'wendy'), ('jake', 'jim')]
<sqlalchemy.engine.cursor.CursorResult object at 0x...>
COMMIT{stop}
diff --git a/doc/build/tutorial/dbapi_transactions.rst b/doc/build/tutorial/dbapi_transactions.rst
index 16768da2b..f4d2ad8e0 100644
--- a/doc/build/tutorial/dbapi_transactions.rst
+++ b/doc/build/tutorial/dbapi_transactions.rst
@@ -115,7 +115,7 @@ where we acquired the :class:`_future.Connection` object:
[...] ()
<sqlalchemy.engine.cursor.CursorResult object at 0x...>
INSERT INTO some_table (x, y) VALUES (?, ?)
- [...] ((1, 1), (2, 4))
+ [...] [(1, 1), (2, 4)]
<sqlalchemy.engine.cursor.CursorResult object at 0x...>
COMMIT
@@ -149,7 +149,7 @@ may be referred towards as **begin once**:
... )
{opensql}BEGIN (implicit)
INSERT INTO some_table (x, y) VALUES (?, ?)
- [...] ((6, 8), (9, 10))
+ [...] [(6, 8), (9, 10)]
<sqlalchemy.engine.cursor.CursorResult object at 0x...>
COMMIT
@@ -374,7 +374,7 @@ be invoked against each parameter set individually:
... conn.commit()
{opensql}BEGIN (implicit)
INSERT INTO some_table (x, y) VALUES (?, ?)
- [...] ((11, 12), (13, 14))
+ [...] [(11, 12), (13, 14)]
<sqlalchemy.engine.cursor.CursorResult object at 0x...>
COMMIT
@@ -508,7 +508,7 @@ our data:
... session.commit()
{opensql}BEGIN (implicit)
UPDATE some_table SET y=? WHERE x=?
- [...] ((11, 9), (15, 13))
+ [...] [(11, 9), (15, 13)]
COMMIT{stop}
Above, we invoked an UPDATE statement using the bound-parameter, "executemany"
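
For context on the logging change running through these tutorial hunks
(parameter sets for "executemany" executions are now echoed as a list of
tuples rather than a tuple of tuples), a minimal sketch of the kind of call
that produces such output, assuming an ``engine`` and a ``some_table`` with
``x`` and ``y`` columns::

    from sqlalchemy import text

    with engine.connect() as conn:
        conn.execute(
            text("INSERT INTO some_table (x, y) VALUES (:x, :y)"),
            [{"x": 11, "y": 12}, {"x": 13, "y": 14}],  # list of parameter sets
        )
        conn.commit()
        # echoed as: INSERT INTO some_table (x, y) VALUES (?, ?)
        #            [...] [(11, 12), (13, 14)]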
diff --git a/lib/sqlalchemy/connectors/__init__.py b/lib/sqlalchemy/connectors/__init__.py
index 132a0a4de..f4fa5b66b 100644
--- a/lib/sqlalchemy/connectors/__init__.py
+++ b/lib/sqlalchemy/connectors/__init__.py
@@ -6,5 +6,13 @@
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-class Connector:
- pass
+from ..engine.interfaces import Dialect
+
+
+class Connector(Dialect):
+ """Base class for dialect mixins, for DBAPIs that work
+ across entirely different database backends.
+
+ Currently the only such mixin is pyodbc.
+
+ """
diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py
index f7d01ce43..c5f07de07 100644
--- a/lib/sqlalchemy/connectors/pyodbc.py
+++ b/lib/sqlalchemy/connectors/pyodbc.py
@@ -5,12 +5,27 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
+from __future__ import annotations
+
import re
+from types import ModuleType
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
from urllib.parse import unquote_plus
from . import Connector
+from .. import ExecutionContext
+from .. import pool
from .. import util
+from ..engine import ConnectArgsType
+from ..engine import Connection
from ..engine import interfaces
+from ..engine import URL
+from ..sql.type_api import TypeEngine
class PyODBCConnector(Connector):
@@ -25,18 +40,20 @@ class PyODBCConnector(Connector):
# for non-DSN connections, this *may* be used to
# hold the desired driver name
- pyodbc_driver_name = None
+ pyodbc_driver_name: Optional[str] = None
+
+ dbapi: ModuleType
- def __init__(self, use_setinputsizes=False, **kw):
+ def __init__(self, use_setinputsizes: bool = False, **kw: Any):
super(PyODBCConnector, self).__init__(**kw)
if use_setinputsizes:
self.bind_typing = interfaces.BindTyping.SETINPUTSIZES
@classmethod
- def dbapi(cls):
+ def import_dbapi(cls) -> ModuleType:
return __import__("pyodbc")
- def create_connect_args(self, url):
+ def create_connect_args(self, url: URL) -> ConnectArgsType:
opts = url.translate_connect_args(username="user")
opts.update(url.query)
@@ -44,7 +61,9 @@ class PyODBCConnector(Connector):
query = url.query
- connect_args = {}
+ connect_args: Dict[str, Any] = {}
+ connectors: List[str]
+
for param in ("ansi", "unicode_results", "autocommit"):
if param in keys:
connect_args[param] = util.asbool(keys.pop(param))
@@ -53,7 +72,7 @@ class PyODBCConnector(Connector):
connectors = [unquote_plus(keys.pop("odbc_connect"))]
else:
- def check_quote(token):
+ def check_quote(token: str) -> str:
if ";" in str(token):
token = "{%s}" % token.replace("}", "}}")
return token
@@ -115,9 +134,14 @@ class PyODBCConnector(Connector):
connectors.extend(["%s=%s" % (k, v) for k, v in keys.items()])
- return [[";".join(connectors)], connect_args]
+ return ((";".join(connectors),), connect_args)
- def is_disconnect(self, e, connection, cursor):
+ def is_disconnect(
+ self,
+ e: Exception,
+ connection: Optional[pool.PoolProxiedConnection],
+ cursor: Optional[interfaces.DBAPICursor],
+ ) -> bool:
if isinstance(e, self.dbapi.ProgrammingError):
return "The cursor's connection has been closed." in str(
e
@@ -125,36 +149,44 @@ class PyODBCConnector(Connector):
else:
return False
- def _dbapi_version(self):
+ def _dbapi_version(self) -> interfaces.VersionInfoType:
if not self.dbapi:
return ()
return self._parse_dbapi_version(self.dbapi.version)
- def _parse_dbapi_version(self, vers):
+ def _parse_dbapi_version(self, vers: str) -> interfaces.VersionInfoType:
m = re.match(r"(?:py.*-)?([\d\.]+)(?:-(\w+))?", vers)
if not m:
return ()
- vers = tuple([int(x) for x in m.group(1).split(".")])
+ vers_tuple: interfaces.VersionInfoType = tuple(
+ [int(x) for x in m.group(1).split(".")]
+ )
if m.group(2):
- vers += (m.group(2),)
- return vers
+ vers_tuple += (m.group(2),)
+ return vers_tuple
- def _get_server_version_info(self, connection, allow_chars=True):
+ def _get_server_version_info(
+ self, connection: Connection
+ ) -> interfaces.VersionInfoType:
# NOTE: this function is not reliable, particularly when
# freetds is in use. Implement database-specific server version
# queries.
- dbapi_con = connection.connection
- version = []
+ dbapi_con = connection.connection.dbapi_connection
+ version: Tuple[Union[int, str], ...] = ()
r = re.compile(r"[.\-]")
- for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)):
+ for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)): # type: ignore[union-attr] # noqa E501
try:
- version.append(int(n))
+ version += (int(n),)
except ValueError:
- if allow_chars:
- version.append(n)
+ pass
return tuple(version)
- def do_set_input_sizes(self, cursor, list_of_tuples, context):
+ def do_set_input_sizes(
+ self,
+ cursor: interfaces.DBAPICursor,
+ list_of_tuples: List[Tuple[str, Any, TypeEngine[Any]]],
+ context: ExecutionContext,
+ ) -> None:
        # the rules for these types seem a little strange, as you can pass
# non-tuples as well as tuples, however it seems to assume "0"
# for the subsequent values if you don't pass a tuple which fails
@@ -174,12 +206,16 @@ class PyODBCConnector(Connector):
]
)
- def get_isolation_level_values(self, dbapi_connection):
- return super().get_isolation_level_values(dbapi_connection) + [
+ def get_isolation_level_values(
+ self, dbapi_connection: interfaces.DBAPIConnection
+ ) -> List[str]:
+ return super().get_isolation_level_values(dbapi_connection) + [ # type: ignore # noqa E501
"AUTOCOMMIT"
]
- def set_isolation_level(self, dbapi_connection, level):
+ def set_isolation_level(
+ self, dbapi_connection: interfaces.DBAPIConnection, level: str
+ ) -> None:
# adjust for ConnectionFairy being present
# allows attribute set e.g. "connection.autocommit = True"
# to work properly
@@ -188,6 +224,4 @@ class PyODBCConnector(Connector):
dbapi_connection.autocommit = True
else:
dbapi_connection.autocommit = False
- super(PyODBCConnector, self).set_isolation_level(
- dbapi_connection, level
- )
+ super().set_isolation_level(dbapi_connection, level)
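
The ``create_connect_args`` change above swaps the historical two-element
list for a tuple, matching the ``ConnectArgsType`` alias this commit exports
from ``sqlalchemy.engine``. A hedged sketch of the shape, using a
hypothetical dialect method; the alias is essentially a
``(positional_args, keyword_args)`` pair handed to the DBAPI ``connect()``::

    from typing import Any, Dict, Tuple

    def create_connect_args(self, url) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
        # *args and **kwargs eventually passed to dbapi.connect()
        opts: Dict[str, Any] = dict(url.query)
        return ("dsn=%s" % url.host,), opts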
diff --git a/lib/sqlalchemy/cyextension/resultproxy.pyx b/lib/sqlalchemy/cyextension/resultproxy.pyx
index daf5cc940..e88c8ec0b 100644
--- a/lib/sqlalchemy/cyextension/resultproxy.pyx
+++ b/lib/sqlalchemy/cyextension/resultproxy.pyx
@@ -7,8 +7,6 @@ cdef int MD_INDEX = 0 # integer index in cursor.description
KEY_INTEGER_ONLY = 0
KEY_OBJECTS_ONLY = 1
-sqlalchemy_engine_row = None
-
cdef class BaseRow:
cdef readonly object _parent
cdef readonly tuple _data
@@ -53,19 +51,6 @@ cdef class BaseRow:
self._keymap = self._parent._keymap
self._key_style = state["_key_style"]
- def _filter_on_values(self, filters):
- global sqlalchemy_engine_row
- if sqlalchemy_engine_row is None:
- from sqlalchemy.engine.row import Row as sqlalchemy_engine_row
-
- return sqlalchemy_engine_row(
- self._parent,
- filters,
- self._keymap,
- self._key_style,
- self._data,
- )
-
def _values_impl(self):
return list(self)
@@ -78,18 +63,8 @@ cdef class BaseRow:
def __hash__(self):
return hash(self._data)
- def _get_by_int_impl(self, key):
- return self._data[key]
-
- cpdef _get_by_key_impl(self, key):
- # keep two isinstance since it's noticeably faster in the int case
- if isinstance(key, int) or isinstance(key, slice):
- return self._data[key]
-
- self._parent._raise_for_nonint(key)
-
- def __getitem__(self, key):
- return self._get_by_key_impl(key)
+ def __getitem__(self, index):
+ return self._data[index]
cpdef _get_by_key_impl_mapping(self, key):
try:
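
With ``_filter_on_values`` and the string-key code paths removed,
``__getitem__`` on both the Cython and pure-Python ``BaseRow`` becomes a
straight pass-through to the underlying data tuple. A sketch of the
resulting access patterns, assuming a ``Connection`` named ``conn`` and
2.0-style ``Row`` behavior::

    from sqlalchemy import text

    row = conn.execute(text("SELECT 1 AS x, 2 AS y")).first()
    row[0]             # integer and slice access index the raw tuple
    row.x              # attribute access resolves through the keymap
    row._mapping["x"]  # string keys belong to the mapping view,
                       # not to Row.__getitem__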
diff --git a/lib/sqlalchemy/dialects/mssql/pymssql.py b/lib/sqlalchemy/dialects/mssql/pymssql.py
index 20d2b09db..8d654b72d 100644
--- a/lib/sqlalchemy/dialects/mssql/pymssql.py
+++ b/lib/sqlalchemy/dialects/mssql/pymssql.py
@@ -77,7 +77,7 @@ class MSDialect_pymssql(MSDialect):
)
@classmethod
- def dbapi(cls):
+ def import_dbapi(cls):
module = __import__("pymssql")
# pymssql < 2.1.1 doesn't have a Binary method. we use string
client_ver = tuple(int(x) for x in module.__version__.split("."))
@@ -106,7 +106,7 @@ class MSDialect_pymssql(MSDialect):
port = opts.pop("port", None)
if port and "host" in opts:
opts["host"] = "%s:%s" % (opts["host"], port)
- return [[], opts]
+ return ([], opts)
def is_disconnect(self, e, connection, cursor):
for msg in (
diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py
index 3af083a6a..0951f219b 100644
--- a/lib/sqlalchemy/dialects/mssql/pyodbc.py
+++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py
@@ -538,7 +538,7 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect):
# 2008. Before we had the VARCHAR cast above, pyodbc would also
# fail on this query.
return super(MSDialect_pyodbc, self)._get_server_version_info(
- connection, allow_chars=False
+ connection
)
else:
version = []
diff --git a/lib/sqlalchemy/dialects/mysql/aiomysql.py b/lib/sqlalchemy/dialects/mysql/aiomysql.py
index df716346e..d685b7ea1 100644
--- a/lib/sqlalchemy/dialects/mysql/aiomysql.py
+++ b/lib/sqlalchemy/dialects/mysql/aiomysql.py
@@ -276,7 +276,7 @@ class MySQLDialect_aiomysql(MySQLDialect_pymysql):
is_async = True
@classmethod
- def dbapi(cls):
+ def import_dbapi(cls):
return AsyncAdapt_aiomysql_dbapi(
__import__("aiomysql"), __import__("pymysql")
)
diff --git a/lib/sqlalchemy/dialects/mysql/asyncmy.py b/lib/sqlalchemy/dialects/mysql/asyncmy.py
index 915b666bb..7d5b1bf86 100644
--- a/lib/sqlalchemy/dialects/mysql/asyncmy.py
+++ b/lib/sqlalchemy/dialects/mysql/asyncmy.py
@@ -288,7 +288,7 @@ class MySQLDialect_asyncmy(MySQLDialect_pymysql):
is_async = True
@classmethod
- def dbapi(cls):
+ def import_dbapi(cls):
return AsyncAdapt_asyncmy_dbapi(__import__("asyncmy"))
@classmethod
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index 5c2de0911..b2ccfc90f 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -2434,7 +2434,7 @@ class MySQLDialect(default.DefaultDialect):
@classmethod
def _is_mariadb_from_url(cls, url):
- dbapi = cls.dbapi()
+ dbapi = cls.import_dbapi()
dialect = cls(dbapi=dbapi)
cargs, cparams = dialect.create_connect_args(url)
diff --git a/lib/sqlalchemy/dialects/mysql/cymysql.py b/lib/sqlalchemy/dialects/mysql/cymysql.py
index 9fd0b4a09..281c509b7 100644
--- a/lib/sqlalchemy/dialects/mysql/cymysql.py
+++ b/lib/sqlalchemy/dialects/mysql/cymysql.py
@@ -53,7 +53,7 @@ class MySQLDialect_cymysql(MySQLDialect_mysqldb):
colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _cymysqlBIT})
@classmethod
- def dbapi(cls):
+ def import_dbapi(cls):
return __import__("cymysql")
def _detect_charset(self, connection):
diff --git a/lib/sqlalchemy/dialects/mysql/mariadbconnector.py b/lib/sqlalchemy/dialects/mysql/mariadbconnector.py
index fca91204f..bf2b04251 100644
--- a/lib/sqlalchemy/dialects/mysql/mariadbconnector.py
+++ b/lib/sqlalchemy/dialects/mysql/mariadbconnector.py
@@ -111,7 +111,7 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
)
@classmethod
- def dbapi(cls):
+ def import_dbapi(cls):
return __import__("mariadb")
def is_disconnect(self, e, connection, cursor):
diff --git a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py
index c96b739dc..a69dac9a5 100644
--- a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py
+++ b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py
@@ -77,7 +77,7 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _myconnpyBIT})
@classmethod
- def dbapi(cls):
+ def import_dbapi(cls):
from mysql import connector
return connector
diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py
index b4f071de0..6d66f88b4 100644
--- a/lib/sqlalchemy/dialects/mysql/mysqldb.py
+++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py
@@ -159,7 +159,7 @@ class MySQLDialect_mysqldb(MySQLDialect):
return False
@classmethod
- def dbapi(cls):
+ def import_dbapi(cls):
return __import__("MySQLdb")
def on_connect(self):
diff --git a/lib/sqlalchemy/dialects/mysql/pymysql.py b/lib/sqlalchemy/dialects/mysql/pymysql.py
index eddb9c921..9a240da61 100644
--- a/lib/sqlalchemy/dialects/mysql/pymysql.py
+++ b/lib/sqlalchemy/dialects/mysql/pymysql.py
@@ -57,7 +57,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
return False
@classmethod
- def dbapi(cls):
+ def import_dbapi(cls):
return __import__("pymysql")
def create_connect_args(self, url, _translate_args=None):
diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
index a390099ae..98181051e 100644
--- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py
+++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
@@ -963,7 +963,7 @@ class OracleDialect_cx_oracle(OracleDialect):
return (0, 0, 0)
@classmethod
- def dbapi(cls):
+ def import_dbapi(cls):
import cx_Oracle
return cx_Oracle
diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
index 4c3c47ba6..75f6c2704 100644
--- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py
+++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
@@ -878,7 +878,7 @@ class PGDialect_asyncpg(PGDialect):
return (99, 99, 99)
@classmethod
- def dbapi(cls):
+ def import_dbapi(cls):
return AsyncAdapt_asyncpg_dbapi(__import__("asyncpg"))
@util.memoized_property
diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py
index c23da93bb..372b8639e 100644
--- a/lib/sqlalchemy/dialects/postgresql/pg8000.py
+++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py
@@ -426,7 +426,7 @@ class PGDialect_pg8000(PGDialect):
return (99, 99, 99)
@classmethod
- def dbapi(cls):
+ def import_dbapi(cls):
return __import__("pg8000")
def create_connect_args(self, url):
diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg.py b/lib/sqlalchemy/dialects/postgresql/psycopg.py
index 3ba535d6c..33dc65afc 100644
--- a/lib/sqlalchemy/dialects/postgresql/psycopg.py
+++ b/lib/sqlalchemy/dialects/postgresql/psycopg.py
@@ -281,7 +281,7 @@ class PGDialect_psycopg(_PGDialect_common_psycopg):
register_hstore(info, connection.connection)
@classmethod
- def dbapi(cls):
+ def import_dbapi(cls):
import psycopg
return psycopg
@@ -592,7 +592,7 @@ class PGDialectAsync_psycopg(PGDialect_psycopg):
supports_statement_cache = True
@classmethod
- def dbapi(cls):
+ def import_dbapi(cls):
import psycopg
from psycopg.pq import ExecStatus
diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
index a08c5e5b0..dddce5a62 100644
--- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py
+++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
@@ -612,7 +612,7 @@ class PGDialect_psycopg2(_PGDialect_common_psycopg):
)
@classmethod
- def dbapi(cls):
+ def import_dbapi(cls):
import psycopg2
return psycopg2
diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py b/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py
index 5a4dcb2e6..0943613a2 100644
--- a/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py
+++ b/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py
@@ -44,7 +44,7 @@ class PGDialect_psycopg2cffi(PGDialect_psycopg2):
)
@classmethod
- def dbapi(cls):
+ def import_dbapi(cls):
return __import__("psycopg2cffi")
@util.memoized_property
diff --git a/lib/sqlalchemy/dialects/sqlite/aiosqlite.py b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py
index e88ab1a0f..dd0499975 100644
--- a/lib/sqlalchemy/dialects/sqlite/aiosqlite.py
+++ b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py
@@ -308,7 +308,7 @@ class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite):
execution_ctx_cls = SQLiteExecutionContext_aiosqlite
@classmethod
- def dbapi(cls):
+ def import_dbapi(cls):
return AsyncAdapt_aiosqlite_dbapi(
__import__("aiosqlite"), __import__("sqlite3")
)
diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py
index 28f795298..b67eed974 100644
--- a/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py
+++ b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py
@@ -105,7 +105,7 @@ class SQLiteDialect_pysqlcipher(SQLiteDialect_pysqlite):
pragmas = ("kdf_iter", "cipher", "cipher_page_size", "cipher_use_hmac")
@classmethod
- def dbapi(cls):
+ def import_dbapi(cls):
try:
import sqlcipher3 as sqlcipher
except ImportError:
diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py
index 8476e6834..2aa7149a6 100644
--- a/lib/sqlalchemy/dialects/sqlite/pysqlite.py
+++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py
@@ -465,7 +465,7 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
driver = "pysqlite"
@classmethod
- def dbapi(cls):
+ def import_dbapi(cls):
from sqlite3 import dbapi2 as sqlite
return sqlite
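
Every dialect hunk above performs the same rename: the ``dbapi()``
classmethod becomes ``import_dbapi()``, freeing the ``dbapi`` name to serve
as the instance attribute holding the imported module (as typed in the
pyodbc connector above). A minimal sketch of a hypothetical third-party
dialect adopting the new hook::

    from sqlalchemy.dialects.sqlite.pysqlite import SQLiteDialect_pysqlite

    class MyDialect(SQLiteDialect_pysqlite):
        @classmethod
        def import_dbapi(cls):  # was: def dbapi(cls)
            from sqlite3 import dbapi2
            return dbapi2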
diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py
index c6bc4b6aa..32f3f2ecc 100644
--- a/lib/sqlalchemy/engine/__init__.py
+++ b/lib/sqlalchemy/engine/__init__.py
@@ -35,6 +35,7 @@ from .cursor import ResultProxy as ResultProxy
from .interfaces import AdaptedConnection as AdaptedConnection
from .interfaces import BindTyping as BindTyping
from .interfaces import Compiled as Compiled
+from .interfaces import ConnectArgsType as ConnectArgsType
from .interfaces import CreateEnginePlugin as CreateEnginePlugin
from .interfaces import Dialect as Dialect
from .interfaces import ExceptionContext as ExceptionContext
diff --git a/lib/sqlalchemy/engine/_py_processors.py b/lib/sqlalchemy/engine/_py_processors.py
index e3024471a..27cb9e939 100644
--- a/lib/sqlalchemy/engine/_py_processors.py
+++ b/lib/sqlalchemy/engine/_py_processors.py
@@ -16,16 +16,30 @@ They all share one common characteristic: None is passed through unchanged.
from __future__ import annotations
import datetime
+from decimal import Decimal
import re
+import typing
+from typing import Any
+from typing import Callable
+from typing import Optional
+from typing import Type
+from typing import TypeVar
+from typing import Union
+
+_DT = TypeVar(
+ "_DT", bound=Union[datetime.datetime, datetime.time, datetime.date]
+)
-def str_to_datetime_processor_factory(regexp, type_):
+def str_to_datetime_processor_factory(
+ regexp: typing.Pattern[str], type_: Callable[..., _DT]
+) -> Callable[[Optional[str]], Optional[_DT]]:
rmatch = regexp.match
# Even on python2.6 datetime.strptime is both slower than this code
# and it does not support microseconds.
has_named_groups = bool(regexp.groupindex)
- def process(value):
+ def process(value: Optional[str]) -> Optional[_DT]:
if value is None:
return None
else:
@@ -59,10 +73,12 @@ def str_to_datetime_processor_factory(regexp, type_):
return process
-def to_decimal_processor_factory(target_class, scale):
+def to_decimal_processor_factory(
+ target_class: Type[Decimal], scale: int
+) -> Callable[[Optional[float]], Optional[Decimal]]:
fstring = "%%.%df" % scale
- def process(value):
+ def process(value: Optional[float]) -> Optional[Decimal]:
if value is None:
return None
else:
@@ -71,21 +87,21 @@ def to_decimal_processor_factory(target_class, scale):
return process
-def to_float(value):
+def to_float(value: Optional[Union[int, float]]) -> Optional[float]:
if value is None:
return None
else:
return float(value)
-def to_str(value):
+def to_str(value: Optional[Any]) -> Optional[str]:
if value is None:
return None
else:
return str(value)
-def int_to_boolean(value):
+def int_to_boolean(value: Optional[int]) -> Optional[bool]:
if value is None:
return None
else:
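
A minimal sketch of the contract the new annotations spell out, using
``to_decimal_processor_factory`` as re-exported through
``sqlalchemy.engine.processors``; per the module docstring, ``None`` always
passes through unchanged::

    from decimal import Decimal
    from sqlalchemy.engine import processors

    to_dec = processors.to_decimal_processor_factory(Decimal, 2)
    assert to_dec(1.5) == Decimal("1.50")  # value routed through "%.2f"
    assert to_dec(None) is None            # None short-circuits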
diff --git a/lib/sqlalchemy/engine/_py_row.py b/lib/sqlalchemy/engine/_py_row.py
index a6d5b79d5..7cbac552f 100644
--- a/lib/sqlalchemy/engine/_py_row.py
+++ b/lib/sqlalchemy/engine/_py_row.py
@@ -1,26 +1,59 @@
from __future__ import annotations
+import enum
import operator
+import typing
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Iterator
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Type
+from typing import Union
+
+if typing.TYPE_CHECKING:
+ from .result import _KeyMapType
+ from .result import _KeyType
+ from .result import _ProcessorsType
+ from .result import _RawRowType
+ from .result import _TupleGetterType
+ from .result import ResultMetaData
MD_INDEX = 0 # integer index in cursor.description
-KEY_INTEGER_ONLY = 0
-"""__getitem__ only allows integer values and slices, raises TypeError
- otherwise"""
-KEY_OBJECTS_ONLY = 1
-"""__getitem__ only allows string/object values, raises TypeError otherwise"""
+class _KeyStyle(enum.Enum):
+ KEY_INTEGER_ONLY = 0
+ """__getitem__ only allows integer values and slices, raises TypeError
+ otherwise"""
-sqlalchemy_engine_row = None
+ KEY_OBJECTS_ONLY = 1
+ """__getitem__ only allows string/object values, raises TypeError
+ otherwise"""
+
+
+KEY_INTEGER_ONLY, KEY_OBJECTS_ONLY = list(_KeyStyle)
class BaseRow:
- Row = None
__slots__ = ("_parent", "_data", "_keymap", "_key_style")
- def __init__(self, parent, processors, keymap, key_style, data):
+ _parent: ResultMetaData
+ _data: _RawRowType
+ _keymap: _KeyMapType
+ _key_style: _KeyStyle
+
+ def __init__(
+ self,
+ parent: ResultMetaData,
+ processors: Optional[_ProcessorsType],
+ keymap: _KeyMapType,
+ key_style: _KeyStyle,
+ data: _RawRowType,
+ ):
"""Row objects are constructed by CursorResult objects."""
-
object.__setattr__(self, "_parent", parent)
if processors:
@@ -41,68 +74,45 @@ class BaseRow:
object.__setattr__(self, "_key_style", key_style)
- def __reduce__(self):
+ def __reduce__(self) -> Tuple[Callable[..., BaseRow], Tuple[Any, ...]]:
return (
rowproxy_reconstructor,
(self.__class__, self.__getstate__()),
)
- def __getstate__(self):
+ def __getstate__(self) -> Dict[str, Any]:
return {
"_parent": self._parent,
"_data": self._data,
"_key_style": self._key_style,
}
- def __setstate__(self, state):
+ def __setstate__(self, state: Dict[str, Any]) -> None:
parent = state["_parent"]
object.__setattr__(self, "_parent", parent)
object.__setattr__(self, "_data", state["_data"])
object.__setattr__(self, "_keymap", parent._keymap)
object.__setattr__(self, "_key_style", state["_key_style"])
- def _filter_on_values(self, filters):
- global sqlalchemy_engine_row
- if sqlalchemy_engine_row is None:
- from sqlalchemy.engine.row import Row as sqlalchemy_engine_row
-
- return sqlalchemy_engine_row(
- self._parent,
- filters,
- self._keymap,
- self._key_style,
- self._data,
- )
-
- def _values_impl(self):
+ def _values_impl(self) -> List[Any]:
return list(self)
- def __iter__(self):
+ def __iter__(self) -> Iterator[Any]:
return iter(self._data)
- def __len__(self):
+ def __len__(self) -> int:
return len(self._data)
- def __hash__(self):
+ def __hash__(self) -> int:
return hash(self._data)
- def _get_by_int_impl(self, key):
+ def _get_by_int_impl(self, key: Union[int, slice]) -> Any:
return self._data[key]
- def _get_by_key_impl(self, key):
- # keep two isinstance since it's noticeably faster in the int case
- if isinstance(key, int) or isinstance(key, slice):
- return self._data[key]
-
- self._parent._raise_for_nonint(key)
-
- # The original 1.4 plan was that Row would not allow row["str"]
- # access, however as the C extensions were inadvertently allowing
- # this coupled with the fact that orm Session sets future=True,
- # this allows a softer upgrade path. see #6218
- __getitem__ = _get_by_key_impl
+ if not typing.TYPE_CHECKING:
+ __getitem__ = _get_by_int_impl
- def _get_by_key_impl_mapping(self, key):
+ def _get_by_key_impl_mapping(self, key: _KeyType) -> Any:
try:
rec = self._keymap[key]
except KeyError as ke:
@@ -116,7 +126,7 @@ class BaseRow:
return self._data[mdindex]
- def __getattr__(self, name):
+ def __getattr__(self, name: str) -> Any:
try:
return self._get_by_key_impl_mapping(name)
except KeyError as e:
@@ -125,13 +135,15 @@ class BaseRow:
# This reconstructor is necessary so that pickles with the Cy extension or
# without use the same Binary format.
-def rowproxy_reconstructor(cls, state):
+def rowproxy_reconstructor(
+ cls: Type[BaseRow], state: Dict[str, Any]
+) -> BaseRow:
obj = cls.__new__(cls)
obj.__setstate__(state)
return obj
-def tuplegetter(*indexes):
+def tuplegetter(*indexes: int) -> _TupleGetterType:
it = operator.itemgetter(*indexes)
if len(indexes) > 1:
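
The hunk is truncated above, but the reason ``tuplegetter`` branches on
``len(indexes)`` is visible in ``operator.itemgetter`` itself: with several
indexes it already returns a tuple, while a single index yields the bare
element, so the one-index case needs extra handling to satisfy
``_TupleGetterType``. A quick stdlib illustration::

    import operator

    operator.itemgetter(1, 3)(("a", "b", "c", "d"))  # -> ('b', 'd')
    operator.itemgetter(1)(("a", "b", "c", "d"))     # -> 'b', not ('b',)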
diff --git a/lib/sqlalchemy/engine/_py_util.py b/lib/sqlalchemy/engine/_py_util.py
index ff03a4761..538c075a2 100644
--- a/lib/sqlalchemy/engine/_py_util.py
+++ b/lib/sqlalchemy/engine/_py_util.py
@@ -1,21 +1,32 @@
from __future__ import annotations
-from collections import abc as collections_abc
+import typing
+from typing import Any
+from typing import Mapping
+from typing import Optional
+from typing import Tuple
from .. import exc
-_no_tuple = ()
+if typing.TYPE_CHECKING:
+ from .interfaces import _CoreAnyExecuteParams
+ from .interfaces import _CoreMultiExecuteParams
+ from .interfaces import _DBAPIAnyExecuteParams
+ from .interfaces import _DBAPIMultiExecuteParams
-def _distill_params_20(params):
+_no_tuple: Tuple[Any, ...] = ()
+
+
+def _distill_params_20(
+ params: Optional[_CoreAnyExecuteParams],
+) -> _CoreMultiExecuteParams:
if params is None:
return _no_tuple
# Assume list is more likely than tuple
elif isinstance(params, list) or isinstance(params, tuple):
# collections_abc.MutableSequence): # avoid abc.__instancecheck__
- if params and not isinstance(
- params[0], (tuple, collections_abc.Mapping)
- ):
+ if params and not isinstance(params[0], (tuple, Mapping)):
raise exc.ArgumentError(
"List argument must consist only of tuples or dictionaries"
)
@@ -25,21 +36,21 @@ def _distill_params_20(params):
# only do immutabledict or abc.__instancecheck__ for Mapping after
# we've checked for plain dictionaries and would otherwise raise
params,
- collections_abc.Mapping,
+ Mapping,
):
return [params]
else:
raise exc.ArgumentError("mapping or list expected for parameters")
-def _distill_raw_params(params):
+def _distill_raw_params(
+ params: Optional[_DBAPIAnyExecuteParams],
+) -> _DBAPIMultiExecuteParams:
if params is None:
return _no_tuple
elif isinstance(params, list):
# collections_abc.MutableSequence): # avoid abc.__instancecheck__
- if params and not isinstance(
- params[0], (tuple, collections_abc.Mapping)
- ):
+ if params and not isinstance(params[0], (tuple, Mapping)):
raise exc.ArgumentError(
"List argument must consist only of tuples or dictionaries"
)
@@ -49,8 +60,9 @@ def _distill_raw_params(params):
# only do abc.__instancecheck__ for Mapping after we've checked
# for plain dictionaries and would otherwise raise
params,
- collections_abc.Mapping,
+ Mapping,
):
- return [params]
+ # cast("Union[List[Mapping[str, Any]], Tuple[Any, ...]]", [params])
+ return [params] # type: ignore
else:
raise exc.ArgumentError("mapping or sequence expected for parameters")
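
A sketch of the distillation contract these annotations describe: every
accepted form is normalized to the "multi" list-of-parameter-sets shape
(these are private helpers, shown here only for illustration)::

    _distill_params_20(None)                  # () -- the shared empty tuple
    _distill_params_20({"x": 1})              # [{"x": 1}] -- single mapping, wrapped
    _distill_params_20([{"x": 1}, {"x": 2}])  # "executemany" list passes through
    _distill_params_20([1, 2])                # raises ArgumentError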
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 8c99f6309..5ce531338 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -10,13 +10,24 @@ import contextlib
import sys
import typing
from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import Iterator
+from typing import List
from typing import Mapping
+from typing import MutableMapping
+from typing import NoReturn
from typing import Optional
+from typing import Tuple
+from typing import Type
from typing import Union
from .interfaces import BindTyping
from .interfaces import ConnectionEventsTarget
+from .interfaces import DBAPICursor
from .interfaces import ExceptionContext
+from .interfaces import ExecutionContext
from .util import _distill_params_20
from .util import _distill_raw_params
from .util import TransactionalContext
@@ -26,22 +37,48 @@ from .. import log
from .. import util
from ..sql import compiler
from ..sql import util as sql_util
-from ..sql._typing import _ExecuteOptions
-from ..sql._typing import _ExecuteParams
+
+_CompiledCacheType = MutableMapping[Any, Any]
if typing.TYPE_CHECKING:
+ from . import Result
+ from . import ScalarResult
+ from .interfaces import _AnyExecuteParams
+ from .interfaces import _AnyMultiExecuteParams
+ from .interfaces import _AnySingleExecuteParams
+ from .interfaces import _CoreAnyExecuteParams
+ from .interfaces import _CoreMultiExecuteParams
+ from .interfaces import _CoreSingleExecuteParams
+ from .interfaces import _DBAPIAnyExecuteParams
+ from .interfaces import _DBAPIMultiExecuteParams
+ from .interfaces import _DBAPISingleExecuteParams
+ from .interfaces import _ExecuteOptions
+ from .interfaces import _ExecuteOptionsParameter
+ from .interfaces import _SchemaTranslateMapType
from .interfaces import Dialect
from .reflection import Inspector # noqa
from .url import URL
+ from ..event import dispatcher
+ from ..log import _EchoFlagType
+ from ..pool import _ConnectionFairy
from ..pool import Pool
from ..pool import PoolProxiedConnection
+ from ..sql import Executable
+ from ..sql.base import SchemaVisitor
+ from ..sql.compiler import Compiled
+ from ..sql.ddl import DDLElement
+ from ..sql.ddl import SchemaDropper
+ from ..sql.ddl import SchemaGenerator
+ from ..sql.functions import FunctionElement
+ from ..sql.schema import ColumnDefault
+ from ..sql.schema import HasSchemaAttr
"""Defines :class:`_engine.Connection` and :class:`_engine.Engine`.
"""
-_EMPTY_EXECUTION_OPTS = util.immutabledict()
-NO_OPTIONS = util.immutabledict()
+_EMPTY_EXECUTION_OPTS: _ExecuteOptions = util.immutabledict()
+NO_OPTIONS: Mapping[str, Any] = util.immutabledict()
class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
@@ -69,23 +106,32 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
"""
+ dispatch: dispatcher[ConnectionEventsTarget]
+
_sqla_logger_namespace = "sqlalchemy.engine.Connection"
# used by sqlalchemy.engine.util.TransactionalContext
- _trans_context_manager = None
+ _trans_context_manager: Optional[TransactionalContext] = None
# legacy as of 2.0, should be eventually deprecated and
# removed. was used in the "pre_ping" recipe that's been in the docs
# a long time
should_close_with_result = False
+ _dbapi_connection: Optional[PoolProxiedConnection]
+
+ _execution_options: _ExecuteOptions
+
+ _transaction: Optional[RootTransaction]
+ _nested_transaction: Optional[NestedTransaction]
+
def __init__(
self,
- engine,
- connection=None,
- _has_events=None,
- _allow_revalidate=True,
- _allow_autobegin=True,
+ engine: Engine,
+ connection: Optional[PoolProxiedConnection] = None,
+ _has_events: Optional[bool] = None,
+ _allow_revalidate: bool = True,
+ _allow_autobegin: bool = True,
):
"""Construct a new Connection."""
self.engine = engine
@@ -125,14 +171,14 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
self.dispatch.engine_connect(self)
@util.memoized_property
- def _message_formatter(self):
+ def _message_formatter(self) -> Any:
if "logging_token" in self._execution_options:
token = self._execution_options["logging_token"]
return lambda msg: "[%s] %s" % (token, msg)
else:
return None
- def _log_info(self, message, *arg, **kw):
+ def _log_info(self, message: str, *arg: Any, **kw: Any) -> None:
fmt = self._message_formatter
if fmt:
@@ -143,7 +189,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
self.engine.logger.info(message, *arg, **kw)
- def _log_debug(self, message, *arg, **kw):
+ def _log_debug(self, message: str, *arg: Any, **kw: Any) -> None:
fmt = self._message_formatter
if fmt:
@@ -155,19 +201,19 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
self.engine.logger.debug(message, *arg, **kw)
@property
- def _schema_translate_map(self):
+ def _schema_translate_map(self) -> Optional[_SchemaTranslateMapType]:
return self._execution_options.get("schema_translate_map", None)
- def schema_for_object(self, obj):
+ def schema_for_object(self, obj: HasSchemaAttr) -> Optional[str]:
"""Return the schema name for the given schema item taking into
account current schema translate map.
"""
name = obj.schema
- schema_translate_map = self._execution_options.get(
- "schema_translate_map", None
- )
+ schema_translate_map: Optional[
+ Mapping[Optional[str], str]
+ ] = self._execution_options.get("schema_translate_map", None)
if (
schema_translate_map
@@ -178,13 +224,13 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
else:
return name
- def __enter__(self):
+ def __enter__(self) -> Connection:
return self
- def __exit__(self, type_, value, traceback):
+ def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
self.close()
- def execution_options(self, **opt):
+ def execution_options(self, **opt: Any) -> Connection:
r"""Set non-SQL options for the connection which take effect
during execution.
@@ -346,13 +392,13 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
ORM-specific execution options
""" # noqa
- self._execution_options = self._execution_options.union(opt)
if self._has_events or self.engine._has_events:
self.dispatch.set_connection_execution_options(self, opt)
+ self._execution_options = self._execution_options.union(opt)
self.dialect.set_connection_execution_options(self, opt)
return self
- def get_execution_options(self):
+ def get_execution_options(self) -> _ExecuteOptions:
"""Get the non-SQL options which will take effect during execution.
.. versionadded:: 1.3
@@ -364,14 +410,27 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
return self._execution_options
@property
- def closed(self):
+ def _still_open_and_dbapi_connection_is_valid(self) -> bool:
+ pool_proxied_connection = self._dbapi_connection
+ return (
+ pool_proxied_connection is not None
+ and pool_proxied_connection.is_valid
+ )
+
+ @property
+ def closed(self) -> bool:
"""Return True if this connection is closed."""
return self._dbapi_connection is None and not self.__can_reconnect
@property
- def invalidated(self):
- """Return True if this connection was invalidated."""
+ def invalidated(self) -> bool:
+ """Return True if this connection was invalidated.
+
+ This does not indicate whether or not the connection was
+ invalidated at the pool level, however
+
+ """
# prior to 1.4, "invalid" was stored as a state independent of
# "closed", meaning an invalidated connection could be "closed",
@@ -382,10 +441,11 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
# "closed" does not need to be "invalid". So the state is now
# represented by the two facts alone.
- return self._dbapi_connection is None and not self.closed
+ pool_proxied_connection = self._dbapi_connection
+ return pool_proxied_connection is None and self.__can_reconnect
@property
- def connection(self) -> "PoolProxiedConnection":
+ def connection(self) -> PoolProxiedConnection:
"""The underlying DB-API connection managed by this Connection.
This is a SQLAlchemy connection-pool proxied connection
@@ -410,7 +470,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
else:
return self._dbapi_connection
- def get_isolation_level(self):
+ def get_isolation_level(self) -> str:
"""Return the current isolation level assigned to this
:class:`_engine.Connection`.
@@ -442,15 +502,15 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
- set per :class:`_engine.Connection` isolation level
"""
+ dbapi_connection = self.connection.dbapi_connection
+ assert dbapi_connection is not None
try:
- return self.dialect.get_isolation_level(
- self.connection.dbapi_connection
- )
+ return self.dialect.get_isolation_level(dbapi_connection)
except BaseException as e:
self._handle_dbapi_exception(e, None, None, None, None)
@property
- def default_isolation_level(self):
+ def default_isolation_level(self) -> str:
"""The default isolation level assigned to this
:class:`_engine.Connection`.
@@ -482,7 +542,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
"""
return self.dialect.default_isolation_level
- def _invalid_transaction(self):
+ def _invalid_transaction(self) -> NoReturn:
raise exc.PendingRollbackError(
"Can't reconnect until invalid %stransaction is rolled "
"back. Please rollback() fully before proceeding"
@@ -490,7 +550,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
code="8s2b",
)
- def _revalidate_connection(self):
+ def _revalidate_connection(self) -> PoolProxiedConnection:
if self.__can_reconnect and self.invalidated:
if self._transaction is not None:
self._invalid_transaction()
@@ -499,13 +559,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
raise exc.ResourceClosedError("This Connection is closed")
@property
- def _still_open_and_dbapi_connection_is_valid(self):
- return self._dbapi_connection is not None and getattr(
- self._dbapi_connection, "is_valid", False
- )
-
- @property
- def info(self):
+ def info(self) -> Dict[str, Any]:
"""Info dictionary associated with the underlying DBAPI connection
referred to by this :class:`_engine.Connection`, allowing user-defined
data to be associated with the connection.
@@ -518,7 +572,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
return self.connection.info
- def invalidate(self, exception=None):
+ def invalidate(self, exception: Optional[BaseException] = None) -> None:
"""Invalidate the underlying DBAPI connection associated with
this :class:`_engine.Connection`.
@@ -567,14 +621,18 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
if self.invalidated:
return
if self.closed:
raise exc.ResourceClosedError("This Connection is closed")
if self._still_open_and_dbapi_connection_is_valid:
- self._dbapi_connection.invalidate(exception)
+ pool_proxied_connection = self._dbapi_connection
+ assert pool_proxied_connection is not None
+ pool_proxied_connection.invalidate(exception)
+
self._dbapi_connection = None
- def detach(self):
+ def detach(self) -> None:
"""Detach the underlying DB-API connection from its connection pool.
E.g.::
@@ -600,13 +658,21 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
"""
- self._dbapi_connection.detach()
+ if self.closed:
+ raise exc.ResourceClosedError("This Connection is closed")
- def _autobegin(self):
- if self._allow_autobegin:
+ pool_proxied_connection = self._dbapi_connection
+ if pool_proxied_connection is None:
+ raise exc.InvalidRequestError(
+ "Can't detach an invalidated Connection"
+ )
+ pool_proxied_connection.detach()
+
+ def _autobegin(self) -> None:
+ if self._allow_autobegin and not self.__in_begin:
self.begin()
- def begin(self):
+ def begin(self) -> RootTransaction:
"""Begin a transaction prior to autobegin occurring.
E.g.::
@@ -671,14 +737,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
:class:`_engine.Engine`
"""
- if self.__in_begin:
- # for dialects that emit SQL within the process of
- # dialect.do_begin() or dialect.do_begin_twophase(), this
- # flag prevents "autobegin" from being emitted within that
- # process, while allowing self._transaction to remain at None
- # until it's complete.
- return
- elif self._transaction is None:
+ if self._transaction is None:
self._transaction = RootTransaction(self)
return self._transaction
else:
@@ -689,7 +748,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
"is called first."
)
- def begin_nested(self):
+ def begin_nested(self) -> NestedTransaction:
"""Begin a nested transaction (i.e. SAVEPOINT) and return a transaction
handle that controls the scope of the SAVEPOINT.
@@ -765,7 +824,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
return NestedTransaction(self)
- def begin_twophase(self, xid=None):
+ def begin_twophase(self, xid: Optional[Any] = None) -> TwoPhaseTransaction:
"""Begin a two-phase or XA transaction and return a transaction
handle.
@@ -794,7 +853,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
xid = self.engine.dialect.create_xid()
return TwoPhaseTransaction(self, xid)
- def commit(self):
+ def commit(self) -> None:
"""Commit the transaction that is currently in progress.
This method commits the current transaction if one has been started.
@@ -819,7 +878,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
if self._transaction:
self._transaction.commit()
- def rollback(self):
+ def rollback(self) -> None:
"""Roll back the transaction that is currently in progress.
This method rolls back the current transaction if one has been started.
@@ -845,33 +904,33 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
if self._transaction:
self._transaction.rollback()
- def recover_twophase(self):
+ def recover_twophase(self) -> List[Any]:
return self.engine.dialect.do_recover_twophase(self)
- def rollback_prepared(self, xid, recover=False):
+ def rollback_prepared(self, xid: Any, recover: bool = False) -> None:
self.engine.dialect.do_rollback_twophase(self, xid, recover=recover)
- def commit_prepared(self, xid, recover=False):
+ def commit_prepared(self, xid: Any, recover: bool = False) -> None:
self.engine.dialect.do_commit_twophase(self, xid, recover=recover)
- def in_transaction(self):
+ def in_transaction(self) -> bool:
"""Return True if a transaction is in progress."""
return self._transaction is not None and self._transaction.is_active
- def in_nested_transaction(self):
+ def in_nested_transaction(self) -> bool:
"""Return True if a transaction is in progress."""
return (
self._nested_transaction is not None
and self._nested_transaction.is_active
)
- def _is_autocommit(self):
- return (
+ def _is_autocommit_isolation(self) -> bool:
+ return bool(
self._execution_options.get("isolation_level", None)
== "AUTOCOMMIT"
)
- def get_transaction(self):
+ def get_transaction(self) -> Optional[RootTransaction]:
"""Return the current root transaction in progress, if any.
.. versionadded:: 1.4
@@ -880,7 +939,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
return self._transaction
- def get_nested_transaction(self):
+ def get_nested_transaction(self) -> Optional[NestedTransaction]:
"""Return the current nested transaction in progress, if any.
.. versionadded:: 1.4
@@ -888,7 +947,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
"""
return self._nested_transaction
- def _begin_impl(self, transaction):
+ def _begin_impl(self, transaction: RootTransaction) -> None:
if self._echo:
self._log_info("BEGIN (implicit)")
@@ -904,13 +963,13 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
finally:
self.__in_begin = False
- def _rollback_impl(self):
+ def _rollback_impl(self) -> None:
if self._has_events or self.engine._has_events:
self.dispatch.rollback(self)
if self._still_open_and_dbapi_connection_is_valid:
if self._echo:
- if self._is_autocommit():
+ if self._is_autocommit_isolation():
self._log_info(
"ROLLBACK using DBAPI connection.rollback(), "
"DBAPI should ignore due to autocommit mode"
@@ -922,13 +981,13 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
except BaseException as e:
self._handle_dbapi_exception(e, None, None, None, None)
- def _commit_impl(self):
+ def _commit_impl(self) -> None:
if self._has_events or self.engine._has_events:
self.dispatch.commit(self)
if self._echo:
- if self._is_autocommit():
+ if self._is_autocommit_isolation():
self._log_info(
"COMMIT using DBAPI connection.commit(), "
"DBAPI should ignore due to autocommit mode"
@@ -940,58 +999,54 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
except BaseException as e:
self._handle_dbapi_exception(e, None, None, None, None)
- def _savepoint_impl(self, name=None):
+ def _savepoint_impl(self, name: Optional[str] = None) -> str:
if self._has_events or self.engine._has_events:
self.dispatch.savepoint(self, name)
if name is None:
self.__savepoint_seq += 1
name = "sa_savepoint_%s" % self.__savepoint_seq
- if self._still_open_and_dbapi_connection_is_valid:
- self.engine.dialect.do_savepoint(self, name)
- return name
+ self.engine.dialect.do_savepoint(self, name)
+ return name
- def _rollback_to_savepoint_impl(self, name):
+ def _rollback_to_savepoint_impl(self, name: str) -> None:
if self._has_events or self.engine._has_events:
self.dispatch.rollback_savepoint(self, name, None)
if self._still_open_and_dbapi_connection_is_valid:
self.engine.dialect.do_rollback_to_savepoint(self, name)
- def _release_savepoint_impl(self, name):
+ def _release_savepoint_impl(self, name: str) -> None:
if self._has_events or self.engine._has_events:
self.dispatch.release_savepoint(self, name, None)
- if self._still_open_and_dbapi_connection_is_valid:
- self.engine.dialect.do_release_savepoint(self, name)
+ self.engine.dialect.do_release_savepoint(self, name)
- def _begin_twophase_impl(self, transaction):
+ def _begin_twophase_impl(self, transaction: TwoPhaseTransaction) -> None:
if self._echo:
self._log_info("BEGIN TWOPHASE (implicit)")
if self._has_events or self.engine._has_events:
self.dispatch.begin_twophase(self, transaction.xid)
- if self._still_open_and_dbapi_connection_is_valid:
- self.__in_begin = True
- try:
- self.engine.dialect.do_begin_twophase(self, transaction.xid)
- except BaseException as e:
- self._handle_dbapi_exception(e, None, None, None, None)
- finally:
- self.__in_begin = False
+ self.__in_begin = True
+ try:
+ self.engine.dialect.do_begin_twophase(self, transaction.xid)
+ except BaseException as e:
+ self._handle_dbapi_exception(e, None, None, None, None)
+ finally:
+ self.__in_begin = False
- def _prepare_twophase_impl(self, xid):
+ def _prepare_twophase_impl(self, xid: Any) -> None:
if self._has_events or self.engine._has_events:
self.dispatch.prepare_twophase(self, xid)
- if self._still_open_and_dbapi_connection_is_valid:
- assert isinstance(self._transaction, TwoPhaseTransaction)
- try:
- self.engine.dialect.do_prepare_twophase(self, xid)
- except BaseException as e:
- self._handle_dbapi_exception(e, None, None, None, None)
+ assert isinstance(self._transaction, TwoPhaseTransaction)
+ try:
+ self.engine.dialect.do_prepare_twophase(self, xid)
+ except BaseException as e:
+ self._handle_dbapi_exception(e, None, None, None, None)
- def _rollback_twophase_impl(self, xid, is_prepared):
+ def _rollback_twophase_impl(self, xid: Any, is_prepared: bool) -> None:
if self._has_events or self.engine._has_events:
self.dispatch.rollback_twophase(self, xid, is_prepared)
@@ -1004,18 +1059,17 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
except BaseException as e:
self._handle_dbapi_exception(e, None, None, None, None)
- def _commit_twophase_impl(self, xid, is_prepared):
+ def _commit_twophase_impl(self, xid: Any, is_prepared: bool) -> None:
if self._has_events or self.engine._has_events:
self.dispatch.commit_twophase(self, xid, is_prepared)
- if self._still_open_and_dbapi_connection_is_valid:
- assert isinstance(self._transaction, TwoPhaseTransaction)
- try:
- self.engine.dialect.do_commit_twophase(self, xid, is_prepared)
- except BaseException as e:
- self._handle_dbapi_exception(e, None, None, None, None)
+ assert isinstance(self._transaction, TwoPhaseTransaction)
+ try:
+ self.engine.dialect.do_commit_twophase(self, xid, is_prepared)
+ except BaseException as e:
+ self._handle_dbapi_exception(e, None, None, None, None)
- def close(self):
+ def close(self) -> None:
"""Close this :class:`_engine.Connection`.
This results in a release of the underlying database
@@ -1050,7 +1104,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
# as we just closed the transaction, close the connection
# pool connection without doing an additional reset
if skip_reset:
- conn._close_no_reset()
+ cast("_ConnectionFairy", conn)._close_no_reset()
else:
conn.close()
@@ -1061,7 +1115,12 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
self._dbapi_connection = None
self.__can_reconnect = False
- def scalar(self, statement, parameters=None, execution_options=None):
+ def scalar(
+ self,
+ statement: Executable,
+ parameters: Optional[_CoreSingleExecuteParams] = None,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> Any:
r"""Executes a SQL statement construct and returns a scalar object.
This method is shorthand for invoking the
@@ -1074,7 +1133,12 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
"""
return self.execute(statement, parameters, execution_options).scalar()
- def scalars(self, statement, parameters=None, execution_options=None):
+ def scalars(
+ self,
+ statement: Executable,
+ parameters: Optional[_CoreSingleExecuteParams] = None,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> ScalarResult:
"""Executes and returns a scalar result set, which yields scalar values
from the first column of each row.
@@ -1093,10 +1157,10 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
def execute(
self,
- statement,
- parameters: Optional[_ExecuteParams] = None,
- execution_options: Optional[_ExecuteOptions] = None,
- ):
+ statement: Executable,
+ parameters: Optional[_CoreAnyExecuteParams] = None,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> Result:
r"""Executes a SQL statement construct and returns a
:class:`_engine.Result`.
@@ -1140,7 +1204,12 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
execution_options or NO_OPTIONS,
)
- def _execute_function(self, func, distilled_parameters, execution_options):
+ def _execute_function(
+ self,
+ func: FunctionElement[Any],
+ distilled_parameters: _CoreMultiExecuteParams,
+ execution_options: _ExecuteOptions,
+ ) -> Result:
"""Execute a sql.FunctionElement object."""
return self._execute_clauseelement(
@@ -1148,14 +1217,20 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
)
def _execute_default(
- self, default, distilled_parameters, execution_options
- ):
+ self,
+ default: ColumnDefault,
+ distilled_parameters: _CoreMultiExecuteParams,
+ execution_options: _ExecuteOptions,
+ ) -> Any:
"""Execute a schema.ColumnDefault object."""
execution_options = self._execution_options.merge_with(
execution_options
)
+ event_multiparams: Optional[_CoreMultiExecuteParams]
+ event_params: Optional[_CoreAnyExecuteParams]
+
# note for event handlers, the "distilled parameters" which is always
# a list of dicts is broken out into separate "multiparams" and
# "params" collections, which allows the handler to distinguish
@@ -1169,6 +1244,8 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
) = self._invoke_before_exec_event(
default, distilled_parameters, execution_options
)
+ else:
+ event_multiparams = event_params = None
try:
conn = self._dbapi_connection
@@ -1198,13 +1275,21 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
return ret
- def _execute_ddl(self, ddl, distilled_parameters, execution_options):
+ def _execute_ddl(
+ self,
+ ddl: DDLElement,
+ distilled_parameters: _CoreMultiExecuteParams,
+ execution_options: _ExecuteOptions,
+ ) -> Result:
"""Execute a schema.DDL object."""
execution_options = ddl._execution_options.merge_with(
self._execution_options, execution_options
)
+ event_multiparams: Optional[_CoreMultiExecuteParams]
+ event_params: Optional[_CoreSingleExecuteParams]
+
if self._has_events or self.engine._has_events:
(
ddl,
@@ -1214,6 +1299,8 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
) = self._invoke_before_exec_event(
ddl, distilled_parameters, execution_options
)
+ else:
+ event_multiparams = event_params = None
exec_opts = self._execution_options.merge_with(execution_options)
schema_translate_map = exec_opts.get("schema_translate_map", None)
@@ -1243,8 +1330,19 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
return ret
def _invoke_before_exec_event(
- self, elem, distilled_params, execution_options
- ):
+ self,
+ elem: Any,
+ distilled_params: _CoreMultiExecuteParams,
+ execution_options: _ExecuteOptions,
+ ) -> Tuple[
+ Any,
+ _CoreMultiExecuteParams,
+ _CoreMultiExecuteParams,
+ _CoreSingleExecuteParams,
+ ]:
+
+ event_multiparams: _CoreMultiExecuteParams
+ event_params: _CoreSingleExecuteParams
if len(distilled_params) == 1:
event_multiparams, event_params = [], distilled_params[0]
@@ -1275,8 +1373,11 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
return elem, distilled_params, event_multiparams, event_params
def _execute_clauseelement(
- self, elem, distilled_parameters, execution_options
- ):
+ self,
+ elem: Executable,
+ distilled_parameters: _CoreMultiExecuteParams,
+ execution_options: _ExecuteOptions,
+ ) -> Result:
"""Execute a sql.ClauseElement object."""
execution_options = elem._execution_options.merge_with(
@@ -1309,7 +1410,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
"schema_translate_map", None
)
- compiled_cache = execution_options.get(
+ compiled_cache: _CompiledCacheType = execution_options.get(
"compiled_cache", self.engine._compiled_cache
)
@@ -1346,10 +1447,10 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
def _execute_compiled(
self,
- compiled,
- distilled_parameters,
- execution_options=_EMPTY_EXECUTION_OPTS,
- ):
+ compiled: Compiled,
+ distilled_parameters: _CoreMultiExecuteParams,
+ execution_options: _ExecuteOptionsParameter = _EMPTY_EXECUTION_OPTS,
+ ) -> Result:
"""Execute a sql.Compiled object.
TODO: why do we have this? likely deprecate or remove
@@ -1395,8 +1496,11 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
return ret
def exec_driver_sql(
- self, statement, parameters=None, execution_options=None
- ):
+ self,
+ statement: str,
+ parameters: Optional[_DBAPIAnyExecuteParams] = None,
+ execution_options: Optional[_ExecuteOptions] = None,
+ ) -> Result:
r"""Executes a SQL statement construct and returns a
:class:`_engine.CursorResult`.
@@ -1456,7 +1560,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
dialect,
dialect.execution_ctx_cls._init_statement,
statement,
- distilled_parameters,
+ None,
execution_options,
statement,
distilled_parameters,
@@ -1466,14 +1570,14 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
def _execute_context(
self,
- dialect,
- constructor,
- statement,
- parameters,
- execution_options,
- *args,
- **kw,
- ):
+ dialect: Dialect,
+ constructor: Callable[..., ExecutionContext],
+ statement: Union[str, Compiled],
+ parameters: Optional[_AnyMultiExecuteParams],
+ execution_options: _ExecuteOptions,
+ *args: Any,
+ **kw: Any,
+ ) -> Result:
"""Create an :class:`.ExecutionContext` and execute, returning
a :class:`_engine.CursorResult`."""
@@ -1491,7 +1595,6 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
self._handle_dbapi_exception(
e, str(statement), parameters, None, None
)
- return # not reached
if (
self._transaction
@@ -1514,29 +1617,33 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
if dialect.bind_typing is BindTyping.SETINPUTSIZES:
context._set_input_sizes()
- cursor, statement, parameters = (
+ cursor, str_statement, parameters = (
context.cursor,
context.statement,
context.parameters,
)
+ effective_parameters: Optional[_AnyExecuteParams]
+
if not context.executemany:
- parameters = parameters[0]
+ effective_parameters = parameters[0]
+ else:
+ effective_parameters = parameters
if self._has_events or self.engine._has_events:
for fn in self.dispatch.before_cursor_execute:
- statement, parameters = fn(
+ str_statement, effective_parameters = fn(
self,
cursor,
- statement,
- parameters,
+ str_statement,
+ effective_parameters,
context,
context.executemany,
)
if self._echo:
- self._log_info(statement)
+ self._log_info(str_statement)
stats = context._get_cache_stats()
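The ``str_statement`` / ``effective_parameters`` renames above are the values handed to the ``before_cursor_execute`` hook, which may replace them when the listener is registered with ``retval=True``; a sketch of such a listener::

    from sqlalchemy import event

    @event.listens_for(engine, "before_cursor_execute", retval=True)
    def add_sql_comment(conn, cursor, statement, parameters, context, executemany):
        # prepend a comment to the emitted SQL; parameters pass through unchanged
        return "/* traced */ " + statement, parameters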
@@ -1545,7 +1652,9 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
"[%s] %r",
stats,
sql_util._repr_params(
- parameters, batches=10, ismulti=context.executemany
+ effective_parameters,
+ batches=10,
+ ismulti=context.executemany,
),
)
else:
@@ -1554,45 +1663,61 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
% (stats,)
)
- evt_handled = False
+ evt_handled: bool = False
try:
if context.executemany:
+ effective_parameters = cast(
+ "_CoreMultiExecuteParams", effective_parameters
+ )
if self.dialect._has_events:
for fn in self.dialect.dispatch.do_executemany:
- if fn(cursor, statement, parameters, context):
+ if fn(
+ cursor,
+ str_statement,
+ effective_parameters,
+ context,
+ ):
evt_handled = True
break
if not evt_handled:
self.dialect.do_executemany(
- cursor, statement, parameters, context
+ cursor, str_statement, effective_parameters, context
)
- elif not parameters and context.no_parameters:
+ elif not effective_parameters and context.no_parameters:
if self.dialect._has_events:
for fn in self.dialect.dispatch.do_execute_no_params:
- if fn(cursor, statement, context):
+ if fn(cursor, str_statement, context):
evt_handled = True
break
if not evt_handled:
self.dialect.do_execute_no_params(
- cursor, statement, context
+ cursor, str_statement, context
)
else:
+ effective_parameters = cast(
+ "_CoreSingleExecuteParams", effective_parameters
+ )
if self.dialect._has_events:
for fn in self.dialect.dispatch.do_execute:
- if fn(cursor, statement, parameters, context):
+ if fn(
+ cursor,
+ str_statement,
+ effective_parameters,
+ context,
+ ):
evt_handled = True
break
if not evt_handled:
self.dialect.do_execute(
- cursor, statement, parameters, context
+ cursor, str_statement, effective_parameters, context
)
if self._has_events or self.engine._has_events:
self.dispatch.after_cursor_execute(
self,
cursor,
- statement,
- parameters,
+ str_statement,
+ effective_parameters,
context,
context.executemany,
)
@@ -1603,12 +1728,18 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
except BaseException as e:
self._handle_dbapi_exception(
- e, statement, parameters, cursor, context
+ e, str_statement, effective_parameters, cursor, context
)
return result
- def _cursor_execute(self, cursor, statement, parameters, context=None):
+ def _cursor_execute(
+ self,
+ cursor: DBAPICursor,
+ statement: str,
+ parameters: _DBAPISingleExecuteParams,
+ context: Optional[ExecutionContext] = None,
+ ) -> None:
"""Execute a statement + params on the given cursor.
Adds appropriate logging and exception handling.
@@ -1648,7 +1779,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
self, cursor, statement, parameters, context, False
)
- def _safe_close_cursor(self, cursor):
+ def _safe_close_cursor(self, cursor: DBAPICursor) -> None:
"""Close the given cursor, catching exceptions
and turning into log warnings.
@@ -1665,8 +1796,13 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
_is_disconnect = False
def _handle_dbapi_exception(
- self, e, statement, parameters, cursor, context
- ):
+ self,
+ e: BaseException,
+ statement: Optional[str],
+ parameters: Optional[_AnyExecuteParams],
+ cursor: Optional[DBAPICursor],
+ context: Optional[ExecutionContext],
+ ) -> NoReturn:
exc_info = sys.exc_info()
is_exit_exception = util.is_exit_exception(e)
@@ -1708,7 +1844,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
sqlalchemy_exception = exc.DBAPIError.instance(
statement,
parameters,
- e,
+ cast(Exception, e),
self.dialect.dbapi.Error,
hide_parameters=self.engine.hide_parameters,
connection_invalidated=self._is_disconnect,
@@ -1784,8 +1920,10 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
if newraise:
raise newraise.with_traceback(exc_info[2]) from e
elif should_wrap:
+ assert sqlalchemy_exception is not None
raise sqlalchemy_exception.with_traceback(exc_info[2]) from e
else:
+ assert exc_info[1] is not None
raise exc_info[1].with_traceback(exc_info[2])
finally:
del self._reentrant_error
@@ -1793,15 +1931,20 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
del self._is_disconnect
if not self.invalidated:
dbapi_conn_wrapper = self._dbapi_connection
+ assert dbapi_conn_wrapper is not None
if invalidate_pool_on_disconnect:
self.engine.pool._invalidate(dbapi_conn_wrapper, e)
self.invalidate(e)
@classmethod
- def _handle_dbapi_exception_noconnection(cls, e, dialect, engine):
+ def _handle_dbapi_exception_noconnection(
+ cls, e: BaseException, dialect: Dialect, engine: Engine
+ ) -> NoReturn:
exc_info = sys.exc_info()
- is_disconnect = dialect.is_disconnect(e, None, None)
+ is_disconnect = isinstance(
+ e, dialect.dbapi.Error
+ ) and dialect.is_disconnect(e, None, None)
should_wrap = isinstance(e, dialect.dbapi.Error)
@@ -1809,7 +1952,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
sqlalchemy_exception = exc.DBAPIError.instance(
None,
None,
- e,
+ cast(Exception, e),
dialect.dbapi.Error,
hide_parameters=engine.hide_parameters,
connection_invalidated=is_disconnect,
@@ -1852,11 +1995,18 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
if newraise:
raise newraise.with_traceback(exc_info[2]) from e
elif should_wrap:
+ assert sqlalchemy_exception is not None
raise sqlalchemy_exception.with_traceback(exc_info[2]) from e
else:
+ assert exc_info[1] is not None
raise exc_info[1].with_traceback(exc_info[2])
- def _run_ddl_visitor(self, visitorcallable, element, **kwargs):
+ def _run_ddl_visitor(
+ self,
+ visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]],
+ element: DDLElement,
+ **kwargs: Any,
+ ) -> None:
"""run a DDL visitor.
This method is only here so that the MockConnection can change the
@@ -1871,16 +2021,16 @@ class ExceptionContextImpl(ExceptionContext):
def __init__(
self,
- exception,
- sqlalchemy_exception,
- engine,
- connection,
- cursor,
- statement,
- parameters,
- context,
- is_disconnect,
- invalidate_pool_on_disconnect,
+ exception: BaseException,
+ sqlalchemy_exception: Optional[exc.StatementError],
+ engine: Optional[Engine],
+ connection: Optional[Connection],
+ cursor: Optional[DBAPICursor],
+ statement: Optional[str],
+ parameters: Optional[_DBAPIAnyExecuteParams],
+ context: Optional[ExecutionContext],
+ is_disconnect: bool,
+ invalidate_pool_on_disconnect: bool,
):
self.engine = engine
self.connection = connection
@@ -1932,33 +2082,35 @@ class Transaction(TransactionalContext):
__slots__ = ()
- _is_root = False
+ _is_root: bool = False
+ is_active: bool
+ connection: Connection
- def __init__(self, connection):
+ def __init__(self, connection: Connection):
raise NotImplementedError()
@property
- def _deactivated_from_connection(self):
+ def _deactivated_from_connection(self) -> bool:
"""True if this transaction is totally deactivated from the connection
and therefore can no longer affect its state.
"""
raise NotImplementedError()
- def _do_close(self):
+ def _do_close(self) -> None:
raise NotImplementedError()
- def _do_rollback(self):
+ def _do_rollback(self) -> None:
raise NotImplementedError()
- def _do_commit(self):
+ def _do_commit(self) -> None:
raise NotImplementedError()
@property
- def is_valid(self):
+ def is_valid(self) -> bool:
return self.is_active and not self.connection.invalidated
- def close(self):
+ def close(self) -> None:
"""Close this :class:`.Transaction`.
If this transaction is the base transaction in a begin/commit
@@ -1974,7 +2126,7 @@ class Transaction(TransactionalContext):
finally:
assert not self.is_active
- def rollback(self):
+ def rollback(self) -> None:
"""Roll back this :class:`.Transaction`.
The implementation of this may vary based on the type of transaction in
@@ -1996,7 +2148,7 @@ class Transaction(TransactionalContext):
finally:
assert not self.is_active
- def commit(self):
+ def commit(self) -> None:
"""Commit this :class:`.Transaction`.
The implementation of this may vary based on the type of transaction in
@@ -2017,16 +2169,16 @@ class Transaction(TransactionalContext):
finally:
assert not self.is_active
- def _get_subject(self):
+ def _get_subject(self) -> Connection:
return self.connection
- def _transaction_is_active(self):
+ def _transaction_is_active(self) -> bool:
return self.is_active
- def _transaction_is_closed(self):
+ def _transaction_is_closed(self) -> bool:
return not self._deactivated_from_connection
- def _rollback_can_be_called(self):
+ def _rollback_can_be_called(self) -> bool:
# for RootTransaction / NestedTransaction, it's safe to call
# rollback() even if the transaction is deactive and no warnings
# will be emitted. tested in
@@ -2060,7 +2212,7 @@ class RootTransaction(Transaction):
__slots__ = ("connection", "is_active")
- def __init__(self, connection):
+ def __init__(self, connection: Connection):
assert connection._transaction is None
if connection._trans_context_manager:
TransactionalContext._trans_ctx_check(connection)
@@ -2070,7 +2222,7 @@ class RootTransaction(Transaction):
self.is_active = True
- def _deactivate_from_connection(self):
+ def _deactivate_from_connection(self) -> None:
if self.is_active:
assert self.connection._transaction is self
self.is_active = False
@@ -2079,19 +2231,19 @@ class RootTransaction(Transaction):
util.warn("transaction already deassociated from connection")
@property
- def _deactivated_from_connection(self):
+ def _deactivated_from_connection(self) -> bool:
return self.connection._transaction is not self
- def _connection_begin_impl(self):
+ def _connection_begin_impl(self) -> None:
self.connection._begin_impl(self)
- def _connection_rollback_impl(self):
+ def _connection_rollback_impl(self) -> None:
self.connection._rollback_impl()
- def _connection_commit_impl(self):
+ def _connection_commit_impl(self) -> None:
self.connection._commit_impl()
- def _close_impl(self, try_deactivate=False):
+ def _close_impl(self, try_deactivate: bool = False) -> None:
try:
if self.is_active:
self._connection_rollback_impl()
@@ -2107,13 +2259,13 @@ class RootTransaction(Transaction):
assert not self.is_active
assert self.connection._transaction is not self
- def _do_close(self):
+ def _do_close(self) -> None:
self._close_impl()
- def _do_rollback(self):
+ def _do_rollback(self) -> None:
self._close_impl(try_deactivate=True)
- def _do_commit(self):
+ def _do_commit(self) -> None:
if self.is_active:
assert self.connection._transaction is self
@@ -2176,7 +2328,9 @@ class NestedTransaction(Transaction):
__slots__ = ("connection", "is_active", "_savepoint", "_previous_nested")
- def __init__(self, connection):
+ _savepoint: str
+
+ def __init__(self, connection: Connection):
assert connection._transaction is not None
if connection._trans_context_manager:
TransactionalContext._trans_ctx_check(connection)
@@ -2186,7 +2340,7 @@ class NestedTransaction(Transaction):
self._previous_nested = connection._nested_transaction
connection._nested_transaction = self
- def _deactivate_from_connection(self, warn=True):
+ def _deactivate_from_connection(self, warn: bool = True) -> None:
if self.connection._nested_transaction is self:
self.connection._nested_transaction = self._previous_nested
elif warn:
@@ -2195,10 +2349,10 @@ class NestedTransaction(Transaction):
)
@property
- def _deactivated_from_connection(self):
+ def _deactivated_from_connection(self) -> bool:
return self.connection._nested_transaction is not self
- def _cancel(self):
+ def _cancel(self) -> None:
# called by RootTransaction when the outer transaction is
# committed, rolled back, or closed to cancel all savepoints
# without any action being taken
@@ -2207,9 +2361,15 @@ class NestedTransaction(Transaction):
if self._previous_nested:
self._previous_nested._cancel()
- def _close_impl(self, deactivate_from_connection, warn_already_deactive):
+ def _close_impl(
+ self, deactivate_from_connection: bool, warn_already_deactive: bool
+ ) -> None:
try:
- if self.is_active and self.connection._transaction.is_active:
+ if (
+ self.is_active
+ and self.connection._transaction
+ and self.connection._transaction.is_active
+ ):
self.connection._rollback_to_savepoint_impl(self._savepoint)
finally:
self.is_active = False
@@ -2221,13 +2381,13 @@ class NestedTransaction(Transaction):
if deactivate_from_connection:
assert self.connection._nested_transaction is not self
- def _do_close(self):
+ def _do_close(self) -> None:
self._close_impl(True, False)
- def _do_rollback(self):
+ def _do_rollback(self) -> None:
self._close_impl(True, True)
- def _do_commit(self):
+ def _do_commit(self) -> None:
if self.is_active:
try:
self.connection._release_savepoint_impl(self._savepoint)
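``NestedTransaction`` wraps a database SAVEPOINT; ``_do_commit()`` above maps to RELEASE SAVEPOINT and ``_do_rollback()`` to ROLLBACK TO SAVEPOINT. From the public API the flow is roughly as follows, where ``stmt`` is a placeholder statement::

    with engine.connect() as conn:
        with conn.begin():                  # RootTransaction
            nested = conn.begin_nested()    # emits SAVEPOINT
            try:
                conn.execute(stmt)
                nested.commit()             # RELEASE SAVEPOINT
            except Exception:
                nested.rollback()           # ROLLBACK TO SAVEPOINT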
@@ -2261,12 +2421,14 @@ class TwoPhaseTransaction(RootTransaction):
__slots__ = ("xid", "_is_prepared")
- def __init__(self, connection, xid):
+ xid: Any
+
+ def __init__(self, connection: Connection, xid: Any):
self._is_prepared = False
self.xid = xid
super(TwoPhaseTransaction, self).__init__(connection)
- def prepare(self):
+ def prepare(self) -> None:
"""Prepare this :class:`.TwoPhaseTransaction`.
After a PREPARE, the transaction can be committed.
@@ -2277,13 +2439,13 @@ class TwoPhaseTransaction(RootTransaction):
self.connection._prepare_twophase_impl(self.xid)
self._is_prepared = True
- def _connection_begin_impl(self):
+ def _connection_begin_impl(self) -> None:
self.connection._begin_twophase_impl(self)
- def _connection_rollback_impl(self):
+ def _connection_rollback_impl(self) -> None:
self.connection._rollback_twophase_impl(self.xid, self._is_prepared)
- def _connection_commit_impl(self):
+ def _connection_commit_impl(self) -> None:
self.connection._commit_twophase_impl(self.xid, self._is_prepared)
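The ``_connection_*_impl`` hooks above route ``TwoPhaseTransaction`` through the two-phase variants of the ``Connection`` internals. On a backend that supports two-phase commit (PostgreSQL, for example), the public flow is roughly the following, with ``stmt`` again a placeholder::

    with engine.connect() as conn:
        xact = conn.begin_twophase("my_xid")   # TwoPhaseTransaction
        conn.execute(stmt)
        xact.prepare()      # PREPARE; the commit may happen much later
        xact.commit()       # COMMIT PREPARED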
@@ -2310,17 +2472,23 @@ class Engine(
"""
- _execution_options = _EMPTY_EXECUTION_OPTS
- _has_events = False
- _connection_cls = Connection
- _sqla_logger_namespace = "sqlalchemy.engine.Engine"
- _is_future = False
+ dispatch: dispatcher[ConnectionEventsTarget]
- _schema_translate_map = None
+ _compiled_cache: Optional[_CompiledCacheType]
+
+ _execution_options: _ExecuteOptions = _EMPTY_EXECUTION_OPTS
+ _has_events: bool = False
+ _connection_cls: Type[Connection] = Connection
+ _sqla_logger_namespace: str = "sqlalchemy.engine.Engine"
+ _is_future: bool = False
+
+ _schema_translate_map: Optional[_SchemaTranslateMapType] = None
+ _option_cls: Type[OptionEngine]
dialect: Dialect
pool: Pool
url: URL
+ hide_parameters: bool
def __init__(
self,
@@ -2328,7 +2496,7 @@ class Engine(
dialect: Dialect,
url: URL,
logging_name: Optional[str] = None,
- echo: Union[None, str, bool] = None,
+ echo: Optional[_EchoFlagType] = None,
query_cache_size: int = 500,
execution_options: Optional[Mapping[str, Any]] = None,
hide_parameters: bool = False,
@@ -2350,7 +2518,7 @@ class Engine(
if execution_options:
self.update_execution_options(**execution_options)
- def _lru_size_alert(self, cache):
+ def _lru_size_alert(self, cache: util.LRUCache[Any, Any]) -> None:
if self._should_log_info:
self.logger.info(
"Compiled cache size pruning from %d items to %d. "
@@ -2360,10 +2528,10 @@ class Engine(
)
@property
- def engine(self):
+ def engine(self) -> Engine:
return self
- def clear_compiled_cache(self):
+ def clear_compiled_cache(self) -> None:
"""Clear the compiled cache associated with the dialect.
This applies **only** to the built-in cache that is established
@@ -2377,7 +2545,7 @@ class Engine(
if self._compiled_cache:
self._compiled_cache.clear()
- def update_execution_options(self, **opt):
+ def update_execution_options(self, **opt: Any) -> None:
r"""Update the default execution_options dictionary
of this :class:`_engine.Engine`.
@@ -2394,11 +2562,11 @@ class Engine(
:meth:`_engine.Engine.execution_options`
"""
- self._execution_options = self._execution_options.union(opt)
self.dispatch.set_engine_execution_options(self, opt)
+ self._execution_options = self._execution_options.union(opt)
self.dialect.set_engine_execution_options(self, opt)
- def execution_options(self, **opt):
+ def execution_options(self, **opt: Any) -> OptionEngine:
"""Return a new :class:`_engine.Engine` that will provide
:class:`_engine.Connection` objects with the given execution options.
@@ -2478,7 +2646,7 @@ class Engine(
""" # noqa E501
return self._option_cls(self, opt)
- def get_execution_options(self):
+ def get_execution_options(self) -> _ExecuteOptions:
"""Get the non-SQL options which will take effect during execution.
.. versionadded:: 1.3
@@ -2490,14 +2658,14 @@ class Engine(
return self._execution_options
@property
- def name(self):
+ def name(self) -> str:
"""String name of the :class:`~sqlalchemy.engine.interfaces.Dialect`
in use by this :class:`Engine`."""
return self.dialect.name
@property
- def driver(self):
+ def driver(self) -> str:
"""Driver name of the :class:`~sqlalchemy.engine.interfaces.Dialect`
in use by this :class:`Engine`."""
@@ -2505,10 +2673,10 @@ class Engine(
echo = log.echo_property()
- def __repr__(self):
+ def __repr__(self) -> str:
return "Engine(%r)" % (self.url,)
- def dispose(self):
+ def dispose(self) -> None:
"""Dispose of the connection pool used by this
:class:`_engine.Engine`.
@@ -2538,7 +2706,9 @@ class Engine(
self.dispatch.engine_disposed(self)
@contextlib.contextmanager
- def _optional_conn_ctx_manager(self, connection=None):
+ def _optional_conn_ctx_manager(
+ self, connection: Optional[Connection] = None
+ ) -> Iterator[Connection]:
if connection is None:
with self.connect() as conn:
yield conn
@@ -2546,7 +2716,7 @@ class Engine(
yield connection
@contextlib.contextmanager
- def begin(self):
+ def begin(self) -> Iterator[Connection]:
"""Return a context manager delivering a :class:`_engine.Connection`
with a :class:`.Transaction` established.
@@ -2576,11 +2746,16 @@ class Engine(
with conn.begin():
yield conn
- def _run_ddl_visitor(self, visitorcallable, element, **kwargs):
+ def _run_ddl_visitor(
+ self,
+ visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]],
+ element: DDLElement,
+ **kwargs: Any,
+ ) -> None:
with self.begin() as conn:
conn._run_ddl_visitor(visitorcallable, element, **kwargs)
- def connect(self):
+ def connect(self) -> Connection:
"""Return a new :class:`_engine.Connection` object.
The :class:`_engine.Connection` acts as a Python context manager, so
@@ -2605,7 +2780,7 @@ class Engine(
return self._connection_cls(self)
- def raw_connection(self):
+ def raw_connection(self) -> PoolProxiedConnection:
"""Return a "raw" DBAPI connection from the connection pool.
The returned object is a proxied version of the DBAPI
@@ -2630,10 +2805,20 @@ class Engine(
return self.pool.connect()
-class OptionEngineMixin:
+class OptionEngineMixin(log.Identified):
_sa_propagate_class_events = False
- def __init__(self, proxied, execution_options):
+ dispatch: dispatcher[ConnectionEventsTarget]
+ _compiled_cache: Optional[_CompiledCacheType]
+ dialect: Dialect
+ pool: Pool
+ url: URL
+ hide_parameters: bool
+ echo: log.echo_property
+
+ def __init__(
+ self, proxied: Engine, execution_options: _ExecuteOptionsParameter
+ ):
self._proxied = proxied
self.url = proxied.url
self.dialect = proxied.dialect
@@ -2660,27 +2845,34 @@ class OptionEngineMixin:
self._execution_options = proxied._execution_options
self.update_execution_options(**execution_options)
- def _get_pool(self):
- return self._proxied.pool
+ def update_execution_options(self, **opt: Any) -> None:
+ raise NotImplementedError()
- def _set_pool(self, pool):
- self._proxied.pool = pool
+ if not typing.TYPE_CHECKING:
+ # https://github.com/python/typing/discussions/1095
- pool = property(_get_pool, _set_pool)
+ @property
+ def pool(self) -> Pool:
+ return self._proxied.pool
- def _get_has_events(self):
- return self._proxied._has_events or self.__dict__.get(
- "_has_events", False
- )
+ @pool.setter
+ def pool(self, pool: Pool) -> None:
+ self._proxied.pool = pool
- def _set_has_events(self, value):
- self.__dict__["_has_events"] = value
+ @property
+ def _has_events(self) -> bool:
+ return self._proxied._has_events or self.__dict__.get(
+ "_has_events", False
+ )
- _has_events = property(_get_has_events, _set_has_events)
+ @_has_events.setter
+ def _has_events(self, value: bool) -> None:
+ self.__dict__["_has_events"] = value
class OptionEngine(OptionEngineMixin, Engine):
- pass
+ def update_execution_options(self, **opt: Any) -> None:
+ Engine.update_execution_options(self, **opt)
Engine._option_cls = OptionEngine
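For context on the ``OptionEngineMixin`` rework above: ``Engine.execution_options()`` returns one of these proxy engines, which shares the parent's pool and dialect while layering new options on top::

    # a sibling Engine with driver-level autocommit; the parent is unchanged
    autocommit_engine = engine.execution_options(isolation_level="AUTOCOMMIT")

    with autocommit_engine.connect() as conn:
        conn.execute(stmt)   # takes effect immediately, no COMMIT needed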
diff --git a/lib/sqlalchemy/engine/characteristics.py b/lib/sqlalchemy/engine/characteristics.py
index c3674c931..c0feb000b 100644
--- a/lib/sqlalchemy/engine/characteristics.py
+++ b/lib/sqlalchemy/engine/characteristics.py
@@ -1,6 +1,13 @@
from __future__ import annotations
import abc
+import typing
+from typing import Any
+from typing import ClassVar
+
+if typing.TYPE_CHECKING:
+ from .interfaces import DBAPIConnection
+ from .interfaces import Dialect
class ConnectionCharacteristic(abc.ABC):
@@ -25,18 +32,24 @@ class ConnectionCharacteristic(abc.ABC):
__slots__ = ()
- transactional = False
+ transactional: ClassVar[bool] = False
@abc.abstractmethod
- def reset_characteristic(self, dialect, dbapi_conn):
+ def reset_characteristic(
+ self, dialect: Dialect, dbapi_conn: DBAPIConnection
+ ) -> None:
"""Reset the characteristic on the connection to its default value."""
@abc.abstractmethod
- def set_characteristic(self, dialect, dbapi_conn, value):
+ def set_characteristic(
+ self, dialect: Dialect, dbapi_conn: DBAPIConnection, value: Any
+ ) -> None:
"""set characteristic on the connection to a given value."""
@abc.abstractmethod
- def get_characteristic(self, dialect, dbapi_conn):
+ def get_characteristic(
+ self, dialect: Dialect, dbapi_conn: DBAPIConnection
+ ) -> Any:
"""Given a DBAPI connection, get the current value of the
characteristic.
@@ -44,13 +57,19 @@ class ConnectionCharacteristic(abc.ABC):
class IsolationLevelCharacteristic(ConnectionCharacteristic):
- transactional = True
+ transactional: ClassVar[bool] = True
- def reset_characteristic(self, dialect, dbapi_conn):
+ def reset_characteristic(
+ self, dialect: Dialect, dbapi_conn: DBAPIConnection
+ ) -> None:
dialect.reset_isolation_level(dbapi_conn)
- def set_characteristic(self, dialect, dbapi_conn, value):
+ def set_characteristic(
+ self, dialect: Dialect, dbapi_conn: DBAPIConnection, value: Any
+ ) -> None:
dialect._assert_and_set_isolation_level(dbapi_conn, value)
- def get_characteristic(self, dialect, dbapi_conn):
+ def get_characteristic(
+ self, dialect: Dialect, dbapi_conn: DBAPIConnection
+ ) -> Any:
return dialect.get_isolation_level(dbapi_conn)
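The newly typed ABC admits third-party characteristics as well; a hypothetical sketch managing a session time zone (the class name is illustrative and the SQL shown is PostgreSQL-style, neither is part of SQLAlchemy)::

    class TimeZoneCharacteristic(ConnectionCharacteristic):
        transactional: ClassVar[bool] = False

        def reset_characteristic(self, dialect, dbapi_conn):
            self.set_characteristic(dialect, dbapi_conn, "UTC")

        def set_characteristic(self, dialect, dbapi_conn, value):
            cursor = dbapi_conn.cursor()
            cursor.execute("SET TIME ZONE '%s'" % value)
            cursor.close()

        def get_characteristic(self, dialect, dbapi_conn):
            cursor = dbapi_conn.cursor()
            cursor.execute("SHOW TIME ZONE")
            value = cursor.fetchone()[0]
            cursor.close()
            return value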
diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py
index ac3d6a2d8..cb5219396 100644
--- a/lib/sqlalchemy/engine/create.py
+++ b/lib/sqlalchemy/engine/create.py
@@ -7,7 +7,12 @@
from __future__ import annotations
+import inspect
+import typing
from typing import Any
+from typing import cast
+from typing import Dict
+from typing import Optional
from typing import Union
from . import base
@@ -21,6 +26,9 @@ from ..pool import _AdhocProxiedConnection
from ..pool import ConnectionPoolEntry
from ..sql import compiler
+if typing.TYPE_CHECKING:
+ from .base import Engine
+
@util.deprecated_params(
strategy=(
@@ -46,7 +54,7 @@ from ..sql import compiler
"is deprecated and will be removed in a future release. ",
),
)
-def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> "base.Engine":
+def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> Engine:
"""Create a new :class:`_engine.Engine` instance.
The standard calling form is to send the :ref:`URL <database_urls>` as the
@@ -452,7 +460,8 @@ def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> "base.Engine":
if "strategy" in kwargs:
strat = kwargs.pop("strategy")
if strat == "mock":
- return create_mock_engine(url, **kwargs)
+ # this case is deprecated
+ return create_mock_engine(url, **kwargs) # type: ignore
else:
raise exc.ArgumentError("unknown strategy: %r" % strat)
@@ -472,14 +481,14 @@ def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> "base.Engine":
if kwargs.pop("_coerce_config", False):
- def pop_kwarg(key, default=None):
+ def pop_kwarg(key: str, default: Optional[Any] = None) -> Any:
value = kwargs.pop(key, default)
if key in dialect_cls.engine_config_types:
value = dialect_cls.engine_config_types[key](value)
return value
else:
- pop_kwarg = kwargs.pop
+ pop_kwarg = kwargs.pop # type: ignore
dialect_args = {}
# consume dialect arguments from kwargs
@@ -490,10 +499,29 @@ def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> "base.Engine":
dbapi = kwargs.pop("module", None)
if dbapi is None:
dbapi_args = {}
- for k in util.get_func_kwargs(dialect_cls.dbapi):
+
+ if "import_dbapi" in dialect_cls.__dict__:
+ dbapi_meth = dialect_cls.import_dbapi
+
+ elif hasattr(dialect_cls, "dbapi") and inspect.ismethod(
+ dialect_cls.dbapi
+ ):
+ util.warn_deprecated(
+ "The dbapi() classmethod on dialect classes has been "
+ "renamed to import_dbapi(). Implement an import_dbapi() "
+ f"classmethod directly on class {dialect_cls} to remove this "
+ "warning; the old .dbapi() classmethod may be maintained for "
+ "backwards compatibility.",
+ "2.0",
+ )
+ dbapi_meth = dialect_cls.dbapi
+ else:
+ dbapi_meth = dialect_cls.import_dbapi
+
+ for k in util.get_func_kwargs(dbapi_meth):
if k in kwargs:
dbapi_args[k] = pop_kwarg(k)
- dbapi = dialect_cls.dbapi(**dbapi_args)
+ dbapi = dbapi_meth(**dbapi_args)
dialect_args["dbapi"] = dbapi
@@ -509,18 +537,23 @@ def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> "base.Engine":
dialect = dialect_cls(**dialect_args)
# assemble connection arguments
- (cargs, cparams) = dialect.create_connect_args(u)
+ (cargs_tup, cparams) = dialect.create_connect_args(u)
cparams.update(pop_kwarg("connect_args", {}))
- cargs = list(cargs) # allow mutability
+ cargs = list(cargs_tup) # allow mutability
# look for existing pool or create
pool = pop_kwarg("pool", None)
if pool is None:
- def connect(connection_record=None):
+ def connect(
+ connection_record: Optional[ConnectionPoolEntry] = None,
+ ) -> DBAPIConnection:
if dialect._has_events:
for fn in dialect.dispatch.do_connect:
- connection = fn(dialect, connection_record, cargs, cparams)
+ connection = cast(
+ DBAPIConnection,
+ fn(dialect, connection_record, cargs, cparams),
+ )
if connection is not None:
return connection
return dialect.connect(*cargs, **cparams)
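The ``connect()`` closure above consults the ``do_connect`` dispatch before falling through to ``dialect.connect()``; a listener can mutate the connect arguments in place or return a connection itself (``fetch_rotating_token`` below is a hypothetical helper)::

    from sqlalchemy import event

    @event.listens_for(engine, "do_connect")
    def provide_token(dialect, conn_rec, cargs, cparams):
        # mutate cparams in place; returning None lets the dialect's
        # own connect() run with the updated arguments
        cparams["password"] = fetch_rotating_token()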
@@ -596,7 +629,11 @@ def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> "base.Engine":
do_on_connect = dialect.on_connect_url(u)
if do_on_connect:
- def on_connect(dbapi_connection, connection_record):
+ def on_connect(
+ dbapi_connection: DBAPIConnection,
+ connection_record: ConnectionPoolEntry,
+ ) -> None:
+ assert do_on_connect is not None
do_on_connect(dbapi_connection)
event.listen(pool, "connect", on_connect)
@@ -608,7 +645,7 @@ def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> "base.Engine":
def first_connect(
dbapi_connection: DBAPIConnection,
connection_record: ConnectionPoolEntry,
- ):
+ ) -> None:
c = base.Connection(
engine,
connection=_AdhocProxiedConnection(
@@ -654,7 +691,9 @@ def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> "base.Engine":
return engine
-def engine_from_config(configuration, prefix="sqlalchemy.", **kwargs):
+def engine_from_config(
+ configuration: Dict[str, Any], prefix: str = "sqlalchemy.", **kwargs: Any
+) -> Engine:
"""Create a new Engine instance using a configuration dictionary.
The dictionary is typically produced from a config file.
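The string values are coerced through ``dialect_cls.engine_config_types`` (see the ``engine_config_types`` mapping earlier in this patch); for example::

    configuration = {
        "sqlalchemy.url": "postgresql://scott:tiger@localhost/test",
        "sqlalchemy.echo": "true",          # coerced by bool_or_str("debug")
        "sqlalchemy.pool_recycle": "3600",  # coerced by asint
    }
    engine = engine_from_config(configuration, prefix="sqlalchemy.")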
diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py
index 2b077056f..78805bac1 100644
--- a/lib/sqlalchemy/engine/cursor.py
+++ b/lib/sqlalchemy/engine/cursor.py
@@ -13,6 +13,17 @@ from __future__ import annotations
import collections
import functools
+import typing
+from typing import Any
+from typing import cast
+from typing import ClassVar
+from typing import Dict
+from typing import Iterator
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Type
from .result import Result
from .result import ResultMetaData
@@ -30,19 +41,43 @@ from ..sql.compiler import RM_OBJECTS
from ..sql.compiler import RM_RENDERED_NAME
from ..sql.compiler import RM_TYPE
from ..util import compat
+from ..util.typing import Literal
_UNPICKLED = util.symbol("unpickled")
+if typing.TYPE_CHECKING:
+ from .interfaces import _DBAPICursorDescription
+ from .interfaces import ExecutionContext
+ from .result import _KeyIndexType
+ from .result import _KeyMapRecType
+ from .result import _KeyMapType
+ from .result import _KeyType
+ from .result import _ProcessorsType
+ from .result import _ProcessorType
+
# metadata entry tuple indexes.
# using raw tuple is faster than namedtuple.
-MD_INDEX = 0 # integer index in cursor.description
-MD_RESULT_MAP_INDEX = 1 # integer index in compiled._result_columns
-MD_OBJECTS = 2 # other string keys and ColumnElement obj that can match
-MD_LOOKUP_KEY = 3 # string key we usually expect for key-based lookup
-MD_RENDERED_NAME = 4 # name that is usually in cursor.description
-MD_PROCESSOR = 5 # callable to process a result value into a row
-MD_UNTRANSLATED = 6 # raw name from cursor.description
+MD_INDEX: Literal[0] = 0 # integer index in cursor.description
+MD_RESULT_MAP_INDEX: Literal[
+ 1
+] = 1 # integer index in compiled._result_columns
+MD_OBJECTS: Literal[
+ 2
+] = 2 # other string keys and ColumnElement obj that can match
+MD_LOOKUP_KEY: Literal[
+ 3
+] = 3 # string key we usually expect for key-based lookup
+MD_RENDERED_NAME: Literal[4] = 4 # name that is usually in cursor.description
+MD_PROCESSOR: Literal[5] = 5 # callable to process a result value into a row
+MD_UNTRANSLATED: Literal[6] = 6 # raw name from cursor.description
+
+
+_CursorKeyMapRecType = Tuple[
+ int, int, List[Any], str, str, Optional["_ProcessorType"], str
+]
+
+_CursorKeyMapType = Dict["_KeyType", _CursorKeyMapRecType]
class CursorResultMetaData(ResultMetaData):
@@ -61,22 +96,30 @@ class CursorResultMetaData(ResultMetaData):
# if a need arises.
)
- returns_rows = True
+ _keymap: _CursorKeyMapType
+ _processors: _ProcessorsType
+ _keymap_by_result_column_idx: Optional[Dict[int, _KeyMapRecType]]
+ _unpickled: bool
+ _safe_for_cache: bool
+
+ returns_rows: ClassVar[bool] = True
- def _has_key(self, key):
+ def _has_key(self, key: Any) -> bool:
return key in self._keymap
- def _for_freeze(self):
+ def _for_freeze(self) -> ResultMetaData:
return SimpleResultMetaData(
self._keys,
extra=[self._keymap[key][MD_OBJECTS] for key in self._keys],
)
- def _reduce(self, keys):
- recs = list(self._metadata_for_keys(keys))
+ def _reduce(self, keys: Sequence[_KeyIndexType]) -> ResultMetaData:
+ recs = cast(
+ "List[_CursorKeyMapRecType]", list(self._metadata_for_keys(keys))
+ )
indexes = [rec[MD_INDEX] for rec in recs]
- new_keys = [rec[MD_LOOKUP_KEY] for rec in recs]
+ new_keys: List[str] = [rec[MD_LOOKUP_KEY] for rec in recs]
if self._translated_indexes:
indexes = [self._translated_indexes[idx] for idx in indexes]
@@ -104,7 +147,7 @@ class CursorResultMetaData(ResultMetaData):
return new_metadata
- def _adapt_to_context(self, context):
+ def _adapt_to_context(self, context: ExecutionContext) -> ResultMetaData:
"""When using a cached Compiled construct that has a _result_map,
for a new statement that used the cached Compiled, we need to ensure
the keymap has the Column objects from our new statement as keys.
@@ -112,8 +155,7 @@ class CursorResultMetaData(ResultMetaData):
as matched to those of the cached statement.
"""
-
- if not context.compiled._result_columns:
+ if not context.compiled or not context.compiled._result_columns:
return self
compiled_statement = context.compiled.statement
@@ -122,6 +164,8 @@ class CursorResultMetaData(ResultMetaData):
if compiled_statement is invoked_statement:
return self
+ assert invoked_statement is not None
+
# this is the most common path for Core statements when
# caching is used. In ORM use, this codepath is not really used
# as the _result_disable_adapt_to_context execution option is
@@ -162,7 +206,9 @@ class CursorResultMetaData(ResultMetaData):
md._safe_for_cache = self._safe_for_cache
return md
- def __init__(self, parent, cursor_description):
+ def __init__(
+ self, parent: CursorResult, cursor_description: _DBAPICursorDescription
+ ):
context = parent.context
self._tuplefilter = None
self._translated_indexes = None
@@ -229,7 +275,7 @@ class CursorResultMetaData(ResultMetaData):
# new in 1.4: get the complete set of all possible keys,
# strings, objects, whatever, that are dupes across two
# different records, first.
- index_by_key = {}
+ index_by_key: Dict[Any, Any] = {}
dupes = set()
for metadata_entry in raw:
for key in (metadata_entry[MD_RENDERED_NAME],) + (
@@ -626,7 +672,7 @@ class CursorResultMetaData(ResultMetaData):
"result set column descriptions" % rec[MD_LOOKUP_KEY]
)
- def _index_for_key(self, key, raiseerr=True):
+ def _index_for_key(self, key: Any, raiseerr: bool = True) -> Optional[int]:
# TODO: can consider pre-loading ints and negative ints
# into _keymap - also no coverage here
if isinstance(key, int):
@@ -653,7 +699,9 @@ class CursorResultMetaData(ResultMetaData):
# ensure it raises
CursorResultMetaData._key_fallback(self, ke.args[0], ke)
- def _metadata_for_keys(self, keys):
+ def _metadata_for_keys(
+ self, keys: Sequence[Any]
+ ) -> Iterator[_CursorKeyMapRecType]:
for key in keys:
if int in key.__class__.__mro__:
key = self._keys[key]
@@ -707,7 +755,7 @@ class ResultFetchStrategy:
__slots__ = ()
- alternate_cursor_description = None
+ alternate_cursor_description: Optional[_DBAPICursorDescription] = None
def soft_close(self, result, dbapi_cursor):
raise NotImplementedError()
@@ -1099,10 +1147,9 @@ _NO_RESULT_METADATA = _NoResultMetaData()
class BaseCursorResult:
"""Base class for database result objects."""
- out_parameters = None
- _metadata = None
- _soft_closed = False
- closed = False
+ _metadata: ResultMetaData
+ _soft_closed: bool = False
+ closed: bool = False
def __init__(self, context, cursor_strategy, cursor_description):
self.context = context
@@ -1134,7 +1181,7 @@ class BaseCursorResult:
keymap = metadata._keymap
processors = metadata._processors
- process_row = self._process_row
+ process_row = Row
key_style = process_row._default_key_style
_make_row = functools.partial(
process_row, metadata, processors, keymap, key_style
@@ -1644,7 +1691,7 @@ class CursorResult(BaseCursorResult, Result):
"""
- _cursor_metadata = CursorResultMetaData
+ _cursor_metadata: Type[ResultMetaData] = CursorResultMetaData
_cursor_strategy_cls = CursorFetchStrategy
_no_result_metadata = _NO_RESULT_METADATA
_is_cursor = True
@@ -1719,7 +1766,9 @@ class BufferedRowResultProxy(ResultProxy):
"""
- _cursor_strategy_cls = BufferedRowCursorFetchStrategy
+ _cursor_strategy_cls: Type[
+ CursorFetchStrategy
+ ] = BufferedRowCursorFetchStrategy
class FullyBufferedResultProxy(ResultProxy):
@@ -1744,5 +1793,3 @@ class BufferedColumnResultProxy(ResultProxy):
and this class does not change behavior in any way.
"""
-
- _process_row = BufferedColumnRow
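The ``Literal[...]`` annotations on the ``MD_*`` constants above let a type checker retain per-position element types when indexing the heterogeneous keymap record tuples; the same pattern in isolation::

    from typing import List, Literal, Tuple

    IDX_NAME: Literal[0] = 0
    IDX_SCORES: Literal[1] = 1

    Rec = Tuple[str, List[int]]
    rec: Rec = ("alice", [1, 2, 3])

    # mypy sees rec[IDX_NAME] as str and rec[IDX_SCORES] as List[int];
    # an index typed as plain int would yield the union of both instead
    name = rec[IDX_NAME]
    scores = rec[IDX_SCORES]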
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index a4dbf2361..0e0c76389 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -19,12 +19,30 @@ import functools
import random
import re
from time import perf_counter
+import typing
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import List
+from typing import Mapping
+from typing import MutableMapping
+from typing import MutableSequence
+from typing import Optional
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import Type
import weakref
from . import characteristics
from . import cursor as _cursor
from . import interfaces
from .base import Connection
+from .interfaces import CacheStats
+from .interfaces import DBAPICursor
+from .interfaces import Dialect
+from .interfaces import ExecutionContext
from .. import event
from .. import exc
from .. import pool
@@ -32,25 +50,49 @@ from .. import types as sqltypes
from .. import util
from ..sql import compiler
from ..sql import expression
+from ..sql.compiler import DDLCompiler
+from ..sql.compiler import SQLCompiler
from ..sql.elements import quoted_name
+if typing.TYPE_CHECKING:
+ from .interfaces import _AnyMultiExecuteParams
+ from .interfaces import _CoreMultiExecuteParams
+ from .interfaces import _CoreSingleExecuteParams
+ from .interfaces import _DBAPIAnyExecuteParams
+ from .interfaces import _DBAPIMultiExecuteParams
+ from .interfaces import _DBAPISingleExecuteParams
+ from .interfaces import _ExecuteOptions
+ from .result import _ProcessorType
+ from .row import Row
+ from .url import URL
+ from ..event import _ListenerFnType
+ from ..pool import Pool
+ from ..pool import PoolProxiedConnection
+ from ..sql import Executable
+ from ..sql.compiler import Compiled
+ from ..sql.compiler import ResultColumnsEntry
+ from ..sql.schema import Column
+ from ..sql.type_api import TypeEngine
+
# When we're handed literal SQL, ensure it's a SELECT query
SERVER_SIDE_CURSOR_RE = re.compile(r"\s*SELECT", re.I | re.UNICODE)
-CACHE_HIT = util.symbol("CACHE_HIT")
-CACHE_MISS = util.symbol("CACHE_MISS")
-CACHING_DISABLED = util.symbol("CACHING_DISABLED")
-NO_CACHE_KEY = util.symbol("NO_CACHE_KEY")
-NO_DIALECT_SUPPORT = util.symbol("NO_DIALECT_SUPPORT")
+(
+ CACHE_HIT,
+ CACHE_MISS,
+ CACHING_DISABLED,
+ NO_CACHE_KEY,
+ NO_DIALECT_SUPPORT,
+) = list(CacheStats)
-class DefaultDialect(interfaces.Dialect):
+class DefaultDialect(Dialect):
"""Default implementation of Dialect"""
statement_compiler = compiler.SQLCompiler
ddl_compiler = compiler.DDLCompiler
- type_compiler = compiler.GenericTypeCompiler
+ type_compiler = compiler.GenericTypeCompiler # type: ignore
preparer = compiler.IdentifierPreparer
supports_alter = True
supports_comments = False
@@ -61,8 +103,8 @@ class DefaultDialect(interfaces.Dialect):
bind_typing = interfaces.BindTyping.NONE
- include_set_input_sizes = None
- exclude_set_input_sizes = None
+ include_set_input_sizes: Optional[Set[Any]] = None
+ exclude_set_input_sizes: Optional[Set[Any]] = None
# the first value we'd get for an autoincrement
# column.
@@ -70,7 +112,7 @@ class DefaultDialect(interfaces.Dialect):
# most DBAPIs happy with this for execute().
# not cx_oracle.
- execute_sequence_format = tuple
+ execute_sequence_format = tuple # type: ignore
supports_schemas = True
supports_views = True
@@ -97,16 +139,16 @@ class DefaultDialect(interfaces.Dialect):
{"isolation_level": characteristics.IsolationLevelCharacteristic()}
)
- engine_config_types = util.immutabledict(
- [
- ("pool_timeout", util.asint),
- ("echo", util.bool_or_str("debug")),
- ("echo_pool", util.bool_or_str("debug")),
- ("pool_recycle", util.asint),
- ("pool_size", util.asint),
- ("max_overflow", util.asint),
- ("future", util.asbool),
- ]
+ engine_config_types: Mapping[str, Any] = util.immutabledict(
+ {
+ "pool_timeout": util.asint,
+ "echo": util.bool_or_str("debug"),
+ "echo_pool": util.bool_or_str("debug"),
+ "pool_recycle": util.asint,
+ "pool_size": util.asint,
+ "max_overflow": util.asint,
+ "future": util.asbool,
+ }
)
# if the NUMERIC type
@@ -119,19 +161,21 @@ class DefaultDialect(interfaces.Dialect):
# length at which to truncate
# any identifier.
max_identifier_length = 9999
- _user_defined_max_identifier_length = None
+ _user_defined_max_identifier_length: Optional[int] = None
- isolation_level = None
+ isolation_level: Optional[str] = None
# sub-categories of max_identifier_length.
# currently these accommodate for MySQL which allows alias names
# of 255 but DDL names only of 64.
- max_index_name_length = None
- max_constraint_name_length = None
+ max_index_name_length: Optional[int] = None
+ max_constraint_name_length: Optional[int] = None
supports_sane_rowcount = True
supports_sane_multi_rowcount = True
- colspecs = {}
+ colspecs: MutableMapping[
+ Type["TypeEngine[Any]"], Type["TypeEngine[Any]"]
+ ] = {}
default_paramstyle = "named"
supports_default_values = False
@@ -160,43 +204,6 @@ class DefaultDialect(interfaces.Dialect):
default_schema_name = None
- construct_arguments = None
- """Optional set of argument specifiers for various SQLAlchemy
- constructs, typically schema items.
-
- To implement, establish as a series of tuples, as in::
-
- construct_arguments = [
- (schema.Index, {
- "using": False,
- "where": None,
- "ops": None
- })
- ]
-
- If the above construct is established on the PostgreSQL dialect,
- the :class:`.Index` construct will now accept the keyword arguments
- ``postgresql_using``, ``postgresql_where``, and ``postgresql_ops``.
- Any other argument specified to the constructor of :class:`.Index`
- which is prefixed with ``postgresql_`` will raise :class:`.ArgumentError`.
-
- A dialect which does not include a ``construct_arguments`` member will
- not participate in the argument validation system. For such a dialect,
- any argument name is accepted by all participating constructs, within
- the namespace of arguments prefixed with that dialect name. The rationale
- here is so that third-party dialects that haven't yet implemented this
- feature continue to function in the old way.
-
- .. versionadded:: 0.9.2
-
- .. seealso::
-
- :class:`.DialectKWArgs` - implementing base class which consumes
- :attr:`.DefaultDialect.construct_arguments`
-
-
- """
-
# indicates symbol names are
# UPPERCASEd if they are case insensitive
# within the database.
@@ -204,17 +211,6 @@ class DefaultDialect(interfaces.Dialect):
# and denormalize_name() must be provided.
requires_name_normalize = False
- reflection_options = ()
-
- dbapi_exception_translation_map = util.immutabledict()
- """mapping used in the extremely unusual case that a DBAPI's
- published exceptions don't actually have the __name__ that they
- are linked towards.
-
- .. versionadded:: 1.0.5
-
- """
-
is_async = False
CACHE_HIT = CACHE_HIT
@@ -363,10 +359,10 @@ class DefaultDialect(interfaces.Dialect):
return self.supports_sane_rowcount
@classmethod
- def get_pool_class(cls, url):
+ def get_pool_class(cls, url: URL) -> Type[Pool]:
return getattr(cls, "poolclass", pool.QueuePool)
- def get_dialect_pool_class(self, url):
+ def get_dialect_pool_class(self, url: URL) -> Type[Pool]:
return self.get_pool_class(url)
@classmethod
@@ -377,7 +373,7 @@ class DefaultDialect(interfaces.Dialect):
except ImportError:
pass
- def _builtin_onconnect(self):
+ def _builtin_onconnect(self) -> Optional[_ListenerFnType]:
if self._on_connect_isolation_level is not None:
def builtin_connect(dbapi_conn, conn_rec):
@@ -734,7 +730,7 @@ class StrCompileDialect(DefaultDialect):
statement_compiler = compiler.StrSQLCompiler
ddl_compiler = compiler.DDLCompiler
- type_compiler = compiler.StrSQLTypeCompiler
+ type_compiler = compiler.StrSQLTypeCompiler # type: ignore
preparer = compiler.IdentifierPreparer
supports_statement_cache = True
@@ -758,24 +754,26 @@ class StrCompileDialect(DefaultDialect):
}
-class DefaultExecutionContext(interfaces.ExecutionContext):
+class DefaultExecutionContext(ExecutionContext):
isinsert = False
isupdate = False
isdelete = False
is_crud = False
is_text = False
isddl = False
+
executemany = False
- compiled = None
- statement = None
- result_column_struct = None
- returned_default_rows = None
- execution_options = util.immutabledict()
+ compiled: Optional[Compiled] = None
+ result_column_struct: Optional[
+ Tuple[List[ResultColumnsEntry], bool, bool, bool]
+ ] = None
+ returned_default_rows: Optional[List[Row]] = None
+
+ execution_options: _ExecuteOptions = util.EMPTY_DICT
cursor_fetch_strategy = _cursor._DEFAULT_FETCH
- cache_stats = None
- invoked_statement = None
+ invoked_statement: Optional[Executable] = None
_is_implicit_returning = False
_is_explicit_returning = False
@@ -786,21 +784,37 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
# a hook for SQLite's translation of
# result column names
# NOTE: pyhive is using this hook, can't remove it :(
- _translate_colname = None
+ _translate_colname: Optional[Callable[[str], str]] = None
+
+ _expanded_parameters: Mapping[str, List[str]] = util.immutabledict()
+ """used by set_input_sizes().
+
+ This collection comes from ``ExpandedState.parameter_expansion``.
- _expanded_parameters = util.immutabledict()
+ """
cache_hit = NO_CACHE_KEY
+ root_connection: Connection
+ _dbapi_connection: PoolProxiedConnection
+ dialect: Dialect
+ unicode_statement: str
+ cursor: DBAPICursor
+ compiled_parameters: _CoreMultiExecuteParams
+ parameters: _DBAPIMultiExecuteParams
+ extracted_parameters: _CoreSingleExecuteParams
+
+ _empty_dict_params = cast("Mapping[str, Any]", util.EMPTY_DICT)
+
@classmethod
def _init_ddl(
cls,
- dialect,
- connection,
- dbapi_connection,
- execution_options,
- compiled_ddl,
- ):
+ dialect: Dialect,
+ connection: Connection,
+ dbapi_connection: PoolProxiedConnection,
+ execution_options: _ExecuteOptions,
+ compiled_ddl: DDLCompiler,
+ ) -> ExecutionContext:
"""Initialize execution context for a DDLElement construct."""
self = cls.__new__(cls)
@@ -832,23 +846,23 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
if dialect.positional:
self.parameters = [dialect.execute_sequence_format()]
else:
- self.parameters = [{}]
+ self.parameters = [self._empty_dict_params]
return self
@classmethod
def _init_compiled(
cls,
- dialect,
- connection,
- dbapi_connection,
- execution_options,
- compiled,
- parameters,
- invoked_statement,
- extracted_parameters,
- cache_hit=CACHING_DISABLED,
- ):
+ dialect: Dialect,
+ connection: Connection,
+ dbapi_connection: PoolProxiedConnection,
+ execution_options: _ExecuteOptions,
+ compiled: SQLCompiler,
+ parameters: _CoreMultiExecuteParams,
+ invoked_statement: Executable,
+ extracted_parameters: _CoreSingleExecuteParams,
+ cache_hit: CacheStats = CacheStats.CACHING_DISABLED,
+ ) -> ExecutionContext:
"""Initialize execution context for a Compiled construct."""
self = cls.__new__(cls)
@@ -868,6 +882,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
compiled._textual_ordered_columns,
compiled._loose_column_name_matching,
)
+
self.isinsert = compiled.isinsert
self.isupdate = compiled.isupdate
self.isdelete = compiled.isdelete
@@ -910,6 +925,10 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
processors = compiled._bind_processors
+ flattened_processors: Mapping[
+ str, _ProcessorType
+ ] = processors # type: ignore[assignment]
+
if compiled.literal_execute_params or compiled.post_compile_params:
if self.executemany:
raise exc.InvalidRequestError(
@@ -924,14 +943,15 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
# re-assign self.unicode_statement
self.unicode_statement = expanded_state.statement
- # used by set_input_sizes() which is needed for Oracle
self._expanded_parameters = expanded_state.parameter_expansion
- processors = dict(processors)
- processors.update(expanded_state.processors)
+ flattened_processors = dict(processors) # type: ignore
+ flattened_processors.update(expanded_state.processors)
positiontup = expanded_state.positiontup
elif compiled.positional:
positiontup = self.compiled.positiontup
+ else:
+ positiontup = None
if compiled.schema_translate_map:
schema_translate_map = self.execution_options.get(
@@ -949,42 +969,49 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
# Convert the dictionary of bind parameter values
# into a dict or list to be sent to the DBAPI's
# execute() or executemany() method.
- parameters = []
+
if compiled.positional:
+ core_positional_parameters: MutableSequence[Sequence[Any]] = []
+ assert positiontup is not None
for compiled_params in self.compiled_parameters:
- param = [
- processors[key](compiled_params[key])
- if key in processors
+ l_param: List[Any] = [
+ flattened_processors[key](compiled_params[key])
+ if key in flattened_processors
else compiled_params[key]
for key in positiontup
]
- parameters.append(dialect.execute_sequence_format(param))
+ core_positional_parameters.append(
+ dialect.execute_sequence_format(l_param)
+ )
+
+ self.parameters = core_positional_parameters
else:
+ core_dict_parameters: MutableSequence[Dict[str, Any]] = []
for compiled_params in self.compiled_parameters:
- param = {
- key: processors[key](compiled_params[key])
- if key in processors
+ d_param: Dict[str, Any] = {
+ key: flattened_processors[key](compiled_params[key])
+ if key in flattened_processors
else compiled_params[key]
for key in compiled_params
}
- parameters.append(param)
+ core_dict_parameters.append(d_param)
- self.parameters = dialect.execute_sequence_format(parameters)
+ self.parameters = core_dict_parameters
return self
@classmethod
def _init_statement(
cls,
- dialect,
- connection,
- dbapi_connection,
- execution_options,
- statement,
- parameters,
- ):
+ dialect: Dialect,
+ connection: Connection,
+ dbapi_connection: PoolProxiedConnection,
+ execution_options: _ExecuteOptions,
+ statement: str,
+ parameters: _DBAPIMultiExecuteParams,
+ ) -> ExecutionContext:
"""Initialize execution context for a string SQL statement."""
self = cls.__new__(cls)
@@ -999,7 +1026,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
if self.dialect.positional:
self.parameters = [dialect.execute_sequence_format()]
else:
- self.parameters = [{}]
+ self.parameters = [self._empty_dict_params]
elif isinstance(parameters[0], dialect.execute_sequence_format):
self.parameters = parameters
elif isinstance(parameters[0], dict):
@@ -1018,8 +1045,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
@classmethod
def _init_default(
- cls, dialect, connection, dbapi_connection, execution_options
- ):
+ cls,
+ dialect: Dialect,
+ connection: Connection,
+ dbapi_connection: PoolProxiedConnection,
+ execution_options: _ExecuteOptions,
+ ) -> ExecutionContext:
"""Initialize execution context for a ColumnDefault construct."""
self = cls.__new__(cls)
@@ -1032,7 +1063,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
self.cursor = self.create_cursor()
return self
- def _get_cache_stats(self):
+ def _get_cache_stats(self) -> str:
if self.compiled is None:
return "raw sql"
@@ -1040,19 +1071,22 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
ch = self.cache_hit
+ gen_time = self.compiled._gen_time
+ assert gen_time is not None
+
if ch is NO_CACHE_KEY:
- return "no key %.5fs" % (now - self.compiled._gen_time,)
+ return "no key %.5fs" % (now - gen_time,)
elif ch is CACHE_HIT:
- return "cached since %.4gs ago" % (now - self.compiled._gen_time,)
+ return "cached since %.4gs ago" % (now - gen_time,)
elif ch is CACHE_MISS:
- return "generated in %.5fs" % (now - self.compiled._gen_time,)
+ return "generated in %.5fs" % (now - gen_time,)
elif ch is CACHING_DISABLED:
- return "caching disabled %.5fs" % (now - self.compiled._gen_time,)
+ return "caching disabled %.5fs" % (now - gen_time,)
elif ch is NO_DIALECT_SUPPORT:
return "dialect %s+%s does not support caching %.5fs" % (
self.dialect.name,
self.dialect.driver,
- now - self.compiled._gen_time,
+ now - gen_time,
)
else:
return "unknown"
@@ -1073,11 +1107,13 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
return self.root_connection.engine
@util.memoized_property
- def postfetch_cols(self):
+ def postfetch_cols(self) -> Optional[Sequence[Column[Any]]]: # type: ignore[override] # mypy#4125 # noqa E501
+ assert isinstance(self.compiled, SQLCompiler)
return self.compiled.postfetch
@util.memoized_property
- def prefetch_cols(self):
+ def prefetch_cols(self) -> Optional[Sequence[Column[Any]]]: # type: ignore[override] # mypy#4125 # noqa E501
+ assert isinstance(self.compiled, SQLCompiler)
if self.isinsert:
return self.compiled.insert_prefetch
elif self.isupdate:
@@ -1086,8 +1122,9 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
return ()
@util.memoized_property
- def returning_cols(self):
- self.compiled.returning
+ def returning_cols(self) -> Optional[Sequence[Column[Any]]]:
+ assert isinstance(self.compiled, SQLCompiler)
+ return self.compiled.returning
@util.memoized_property
def no_parameters(self):
@@ -1564,7 +1601,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
str(compiled), type_, parameters=parameters
)
- current_parameters = None
+ current_parameters: Optional[_CoreSingleExecuteParams] = None
"""A dictionary of parameters applied to the current row.
This attribute is only available in the context of a user-defined default
diff --git a/lib/sqlalchemy/engine/events.py b/lib/sqlalchemy/engine/events.py
index ab462bbe1..0cbf56a6d 100644
--- a/lib/sqlalchemy/engine/events.py
+++ b/lib/sqlalchemy/engine/events.py
@@ -8,14 +8,41 @@
from __future__ import annotations
+import typing
+from typing import Any
+from typing import Dict
+from typing import Optional
+from typing import Tuple
+from typing import Type
+from typing import Union
+
from .base import Engine
from .interfaces import ConnectionEventsTarget
+from .interfaces import DBAPIConnection
+from .interfaces import DBAPICursor
from .interfaces import Dialect
from .. import event
from .. import exc
-
-
-class ConnectionEvents(event.Events):
+from ..util.typing import Literal
+
+if typing.TYPE_CHECKING:
+ from .base import Connection
+ from .interfaces import _CoreAnyExecuteParams
+ from .interfaces import _CoreMultiExecuteParams
+ from .interfaces import _CoreSingleExecuteParams
+ from .interfaces import _DBAPIAnyExecuteParams
+ from .interfaces import _DBAPIMultiExecuteParams
+ from .interfaces import _DBAPISingleExecuteParams
+ from .interfaces import _ExecuteOptions
+ from .interfaces import ExceptionContext
+ from .interfaces import ExecutionContext
+ from .result import Result
+ from ..pool import ConnectionPoolEntry
+ from ..sql import Executable
+ from ..sql.elements import BindParameter
+
+
+class ConnectionEvents(event.Events[ConnectionEventsTarget]):
"""Available events for
:class:`_engine.Connection` and :class:`_engine.Engine`.
@@ -96,7 +123,12 @@ class ConnectionEvents(event.Events):
_dispatch_target = ConnectionEventsTarget
@classmethod
- def _listen(cls, event_key, retval=False):
+ def _listen( # type: ignore[override]
+ cls,
+ event_key: event._EventKey[ConnectionEventsTarget],
+ retval: bool = False,
+ **kw: Any,
+ ) -> None:
target, identifier, fn = (
event_key.dispatch_target,
event_key.identifier,
@@ -109,7 +141,7 @@ class ConnectionEvents(event.Events):
if identifier == "before_execute":
orig_fn = fn
- def wrap_before_execute(
+ def wrap_before_execute( # type: ignore
conn, clauseelement, multiparams, params, execution_options
):
orig_fn(
@@ -125,7 +157,7 @@ class ConnectionEvents(event.Events):
elif identifier == "before_cursor_execute":
orig_fn = fn
- def wrap_before_cursor_execute(
+ def wrap_before_cursor_execute( # type: ignore
conn, cursor, statement, parameters, context, executemany
):
orig_fn(
@@ -163,8 +195,15 @@ class ConnectionEvents(event.Events):
),
)
def before_execute(
- self, conn, clauseelement, multiparams, params, execution_options
- ):
+ self,
+ conn: Connection,
+ clauseelement: Executable,
+ multiparams: _CoreMultiExecuteParams,
+ params: _CoreSingleExecuteParams,
+ execution_options: _ExecuteOptions,
+ ) -> Optional[
+ Tuple[Executable, _CoreMultiExecuteParams, _CoreSingleExecuteParams]
+ ]:
"""Intercept high level execute() events, receiving uncompiled
SQL constructs and other objects prior to rendering into SQL.
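For illustration (not part of the patch), a listener registered with
``retval=True`` may return a replacement statement and parameters; this
sketch assumes an :class:`_engine.Engine` named ``engine``::

    from sqlalchemy import event

    @event.listens_for(engine, "before_execute", retval=True)
    def intercept(conn, clauseelement, multiparams, params, execution_options):
        # pass everything through unchanged; a real hook could swap in
        # a rewritten clauseelement here
        return clauseelement, multiparams, params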
@@ -214,13 +253,13 @@ class ConnectionEvents(event.Events):
)
def after_execute(
self,
- conn,
- clauseelement,
- multiparams,
- params,
- execution_options,
- result,
- ):
+ conn: Connection,
+ clauseelement: Executable,
+ multiparams: _CoreMultiExecuteParams,
+ params: _CoreSingleExecuteParams,
+ execution_options: _ExecuteOptions,
+ result: Result,
+ ) -> None:
"""Intercept high level execute() events after execute.
@@ -244,8 +283,14 @@ class ConnectionEvents(event.Events):
"""
def before_cursor_execute(
- self, conn, cursor, statement, parameters, context, executemany
- ):
+ self,
+ conn: Connection,
+ cursor: DBAPICursor,
+ statement: str,
+ parameters: _DBAPIAnyExecuteParams,
+ context: Optional[ExecutionContext],
+ executemany: bool,
+ ) -> Optional[Tuple[str, _DBAPIAnyExecuteParams]]:
"""Intercept low-level cursor execute() events before execution,
receiving the string SQL statement and DBAPI-specific parameter list to
be invoked against a cursor.
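As a sketch (again assuming an ``engine`` in scope), the ``retval=True``
form lets the hook rewrite the string statement and parameters that are
ultimately sent to the DBAPI::

    from sqlalchemy import event

    @event.listens_for(engine, "before_cursor_execute", retval=True)
    def comment_sql(conn, cursor, statement, parameters, context, executemany):
        # prepend a SQL comment; the returned pair replaces the originals
        return "-- traced\n" + statement, parameters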
@@ -286,8 +331,14 @@ class ConnectionEvents(event.Events):
"""
def after_cursor_execute(
- self, conn, cursor, statement, parameters, context, executemany
- ):
+ self,
+ conn: Connection,
+ cursor: DBAPICursor,
+ statement: str,
+ parameters: _DBAPIAnyExecuteParams,
+ context: Optional[ExecutionContext],
+ executemany: bool,
+ ) -> None:
"""Intercept low-level cursor execute() events after execution.
:param conn: :class:`_engine.Connection` object
@@ -305,7 +356,9 @@ class ConnectionEvents(event.Events):
"""
- def handle_error(self, exception_context):
+ def handle_error(
+ self, exception_context: ExceptionContext
+ ) -> Optional[BaseException]:
r"""Intercept all exceptions processed by the
:class:`_engine.Connection`.
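A minimal sketch, assuming an ``engine`` and an application-defined
``MyDeadlockError`` (hypothetical); returning an exception from the hook
replaces the one that would otherwise be raised::

    from sqlalchemy import event

    @event.listens_for(engine, "handle_error")
    def translate_error(exception_context):
        if "deadlock" in str(exception_context.original_exception):
            return MyDeadlockError("retryable failure")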
@@ -439,7 +492,7 @@ class ConnectionEvents(event.Events):
@event._legacy_signature(
"2.0", ["conn", "branch"], converter=lambda conn: (conn, False)
)
- def engine_connect(self, conn):
+ def engine_connect(self, conn: Connection) -> None:
"""Intercept the creation of a new :class:`_engine.Connection`.
This event is called typically as the direct result of calling
@@ -475,7 +528,9 @@ class ConnectionEvents(event.Events):
"""
- def set_connection_execution_options(self, conn, opts):
+ def set_connection_execution_options(
+ self, conn: Connection, opts: Dict[str, Any]
+ ) -> None:
"""Intercept when the :meth:`_engine.Connection.execution_options`
method is called.
@@ -494,8 +549,12 @@ class ConnectionEvents(event.Events):
:param opts: dictionary of options that were passed to the
:meth:`_engine.Connection.execution_options` method.
+ This dictionary may be modified in place to affect the ultimate
+ options which take effect.
+
+ .. versionadded:: 2.0 the ``opts`` dictionary may be modified
+ in place.
- .. versionadded:: 0.9.0
.. seealso::
@@ -507,7 +566,9 @@ class ConnectionEvents(event.Events):
"""
- def set_engine_execution_options(self, engine, opts):
+ def set_engine_execution_options(
+ self, engine: Engine, opts: Dict[str, Any]
+ ) -> None:
"""Intercept when the :meth:`_engine.Engine.execution_options`
method is called.
@@ -526,8 +587,11 @@ class ConnectionEvents(event.Events):
:param opts: dictionary of options that were passed to the
 :meth:`_engine.Engine.execution_options` method.
+ This dictionary may be modified in place to affect the ultimate
+ options which take effect.
- .. versionadded:: 0.9.0
+ .. versionadded:: 2.0 the ``opts`` dictionary may be modified
+ in place.
.. seealso::
@@ -539,7 +603,7 @@ class ConnectionEvents(event.Events):
"""
- def engine_disposed(self, engine):
+ def engine_disposed(self, engine: Engine) -> None:
"""Intercept when the :meth:`_engine.Engine.dispose` method is called.
The :meth:`_engine.Engine.dispose` method instructs the engine to
@@ -559,14 +623,14 @@ class ConnectionEvents(event.Events):
"""
- def begin(self, conn):
+ def begin(self, conn: Connection) -> None:
"""Intercept begin() events.
:param conn: :class:`_engine.Connection` object
"""
- def rollback(self, conn):
+ def rollback(self, conn: Connection) -> None:
"""Intercept rollback() events, as initiated by a
:class:`.Transaction`.
@@ -584,7 +648,7 @@ class ConnectionEvents(event.Events):
"""
- def commit(self, conn):
+ def commit(self, conn: Connection) -> None:
"""Intercept commit() events, as initiated by a
:class:`.Transaction`.
@@ -596,7 +660,7 @@ class ConnectionEvents(event.Events):
:param conn: :class:`_engine.Connection` object
"""
- def savepoint(self, conn, name):
+ def savepoint(self, conn: Connection, name: str) -> None:
"""Intercept savepoint() events.
:param conn: :class:`_engine.Connection` object
@@ -604,7 +668,9 @@ class ConnectionEvents(event.Events):
"""
- def rollback_savepoint(self, conn, name, context):
+ def rollback_savepoint(
+ self, conn: Connection, name: str, context: None
+ ) -> None:
"""Intercept rollback_savepoint() events.
:param conn: :class:`_engine.Connection` object
@@ -614,7 +680,9 @@ class ConnectionEvents(event.Events):
"""
# TODO: deprecate "context"
- def release_savepoint(self, conn, name, context):
+ def release_savepoint(
+ self, conn: Connection, name: str, context: None
+ ) -> None:
"""Intercept release_savepoint() events.
:param conn: :class:`_engine.Connection` object
@@ -624,7 +692,7 @@ class ConnectionEvents(event.Events):
"""
# TODO: deprecate "context"
- def begin_twophase(self, conn, xid):
+ def begin_twophase(self, conn: Connection, xid: Any) -> None:
"""Intercept begin_twophase() events.
:param conn: :class:`_engine.Connection` object
@@ -632,14 +700,16 @@ class ConnectionEvents(event.Events):
"""
- def prepare_twophase(self, conn, xid):
+ def prepare_twophase(self, conn: Connection, xid: Any) -> None:
"""Intercept prepare_twophase() events.
:param conn: :class:`_engine.Connection` object
:param xid: two-phase XID identifier
"""
- def rollback_twophase(self, conn, xid, is_prepared):
+ def rollback_twophase(
+ self, conn: Connection, xid: Any, is_prepared: bool
+ ) -> None:
"""Intercept rollback_twophase() events.
:param conn: :class:`_engine.Connection` object
@@ -649,7 +719,9 @@ class ConnectionEvents(event.Events):
"""
- def commit_twophase(self, conn, xid, is_prepared):
+ def commit_twophase(
+ self, conn: Connection, xid: Any, is_prepared: bool
+ ) -> None:
"""Intercept commit_twophase() events.
:param conn: :class:`_engine.Connection` object
@@ -660,7 +732,7 @@ class ConnectionEvents(event.Events):
"""
-class DialectEvents(event.Events):
+class DialectEvents(event.Events[Dialect]):
"""event interface for execution-replacement functions.
These events allow direct instrumentation and replacement
@@ -694,14 +766,20 @@ class DialectEvents(event.Events):
_dispatch_target = Dialect
@classmethod
- def _listen(cls, event_key, retval=False):
+ def _listen( # type: ignore
+ cls,
+ event_key: event._EventKey[Dialect],
+ retval: bool = False,
+ ) -> None:
target = event_key.dispatch_target
target._has_events = True
event_key.base_listen()
@classmethod
- def _accept_with(cls, target):
+ def _accept_with(
+ cls, target: Union[Engine, Type[Engine], Dialect, Type[Dialect]]
+ ) -> Union[Dialect, Type[Dialect]]:
if isinstance(target, type):
if issubclass(target, Engine):
return Dialect
@@ -712,7 +790,13 @@ class DialectEvents(event.Events):
else:
return target
- def do_connect(self, dialect, conn_rec, cargs, cparams):
+ def do_connect(
+ self,
+ dialect: Dialect,
+ conn_rec: ConnectionPoolEntry,
+ cargs: Tuple[Any, ...],
+ cparams: Dict[str, Any],
+ ) -> Optional[DBAPIConnection]:
"""Receive connection arguments before a connection is made.
This event is useful in that it allows the handler to manipulate the
@@ -745,7 +829,13 @@ class DialectEvents(event.Events):
"""
- def do_executemany(self, cursor, statement, parameters, context):
+ def do_executemany(
+ self,
+ cursor: DBAPICursor,
+ statement: str,
+ parameters: _DBAPIMultiExecuteParams,
+ context: ExecutionContext,
+ ) -> Optional[Literal[True]]:
"""Receive a cursor to have executemany() called.
Return the value True to halt further events from invoking,
@@ -754,7 +844,9 @@ class DialectEvents(event.Events):
"""
- def do_execute_no_params(self, cursor, statement, context):
+ def do_execute_no_params(
+ self, cursor: DBAPICursor, statement: str, context: ExecutionContext
+ ) -> Optional[Literal[True]]:
"""Receive a cursor to have execute() with no parameters called.
Return the value True to halt further events from invoking,
@@ -763,7 +855,13 @@ class DialectEvents(event.Events):
"""
- def do_execute(self, cursor, statement, parameters, context):
+ def do_execute(
+ self,
+ cursor: DBAPICursor,
+ statement: str,
+ parameters: _DBAPISingleExecuteParams,
+ context: ExecutionContext,
+ ) -> Optional[Literal[True]]:
"""Receive a cursor to have execute() called.
Return the value True to halt further events from invoking,
@@ -773,8 +871,13 @@ class DialectEvents(event.Events):
"""
def do_setinputsizes(
- self, inputsizes, cursor, statement, parameters, context
- ):
+ self,
+ inputsizes: Dict[BindParameter[Any], Any],
+ cursor: DBAPICursor,
+ statement: str,
+ parameters: _DBAPIAnyExecuteParams,
+ context: ExecutionContext,
+ ) -> None:
"""Receive the setinputsizes dictionary for possible modification.
This event is emitted in the case where the dialect makes use of the
diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py
index 860c1faf9..545dd0ddc 100644
--- a/lib/sqlalchemy/engine/interfaces.py
+++ b/lib/sqlalchemy/engine/interfaces.py
@@ -10,21 +10,31 @@
from __future__ import annotations
from enum import Enum
+from types import ModuleType
from typing import Any
+from typing import Awaitable
from typing import Callable
from typing import Dict
from typing import List
from typing import Mapping
+from typing import MutableMapping
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
+from .. import util
+from ..event import EventTarget
+from ..pool import Pool
from ..pool import PoolProxiedConnection
+from ..sql.compiler import Compiled as Compiled
-from ..sql.compiler import Compiled  # noqa
+from ..sql.compiler import TypeCompiler as TypeCompiler
-from ..sql.compiler import TypeCompiler  # noqa
+from ..util import immutabledict
from ..util.concurrency import await_only
from ..util.typing import _TypeToInstance
from ..util.typing import NotRequired
@@ -34,12 +44,33 @@ from ..util.typing import TypedDict
if TYPE_CHECKING:
from .base import Connection
from .base import Engine
+ from .result import Result
from .url import URL
+ from ..event import _ListenerFnType
+ from ..event import dispatcher
+ from ..exc import StatementError
+ from ..sql import Executable
from ..sql.compiler import DDLCompiler
from ..sql.compiler import IdentifierPreparer
+ from ..sql.compiler import Linting
from ..sql.compiler import SQLCompiler
+ from ..sql.elements import ClauseElement
+ from ..sql.schema import Column
+ from ..sql.schema import ColumnDefault
from ..sql.type_api import TypeEngine
+ConnectArgsType = Tuple[Tuple[str], MutableMapping[str, Any]]
+
+_T = TypeVar("_T", bound="Any")
+
+
+class CacheStats(Enum):
+ CACHE_HIT = 0
+ CACHE_MISS = 1
+ CACHING_DISABLED = 2
+ NO_CACHE_KEY = 3
+ NO_DIALECT_SUPPORT = 4
+
class DBAPIConnection(Protocol):
"""protocol representing a :pep:`249` database connection.
@@ -65,6 +96,8 @@ class DBAPIConnection(Protocol):
def rollback(self) -> None:
...
+ autocommit: bool
+
class DBAPIType(Protocol):
"""protocol representing a :pep:`249` database type.
@@ -128,14 +161,14 @@ class DBAPICursor(Protocol):
def execute(
self,
operation: Any,
- parameters: Optional[Union[Sequence[Any], Mapping[str, Any]]],
+ parameters: Optional[_DBAPISingleExecuteParams],
) -> Any:
...
def executemany(
self,
operation: Any,
- parameters: Sequence[Union[Sequence[Any], Mapping[str, Any]]],
+ parameters: _DBAPIMultiExecuteParams,
) -> Any:
...
@@ -161,6 +194,34 @@ class DBAPICursor(Protocol):
...
+_CoreSingleExecuteParams = Mapping[str, Any]
+_CoreMultiExecuteParams = Sequence[_CoreSingleExecuteParams]
+_CoreAnyExecuteParams = Union[
+ _CoreMultiExecuteParams, _CoreSingleExecuteParams
+]
+
+_DBAPISingleExecuteParams = Union[Sequence[Any], _CoreSingleExecuteParams]
+
+_DBAPIMultiExecuteParams = Union[
+ Sequence[Sequence[Any]], _CoreMultiExecuteParams
+]
+_DBAPIAnyExecuteParams = Union[
+ _DBAPIMultiExecuteParams, _DBAPISingleExecuteParams
+]
+_DBAPICursorDescription = Tuple[str, Any, Any, Any, Any, Any, Any]
+
+_AnySingleExecuteParams = _DBAPISingleExecuteParams
+_AnyMultiExecuteParams = _DBAPIMultiExecuteParams
+_AnyExecuteParams = _DBAPIAnyExecuteParams
+
+
+_ExecuteOptions = immutabledict[str, Any]
+_ExecuteOptionsParameter = Mapping[str, Any]
+_SchemaTranslateMapType = Mapping[str, str]
+
+_ImmutableExecuteOptions = immutabledict[str, Any]
+
+
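By way of illustration only, values conforming to these aliases look like::

    core_single: _CoreSingleExecuteParams = {"x": 5}
    core_multi: _CoreMultiExecuteParams = [{"x": 5}, {"x": 6}]

    # DBAPI-level parameters additionally allow positional sequences
    dbapi_single: _DBAPISingleExecuteParams = (5, "five")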
class ReflectedIdentity(TypedDict):
"""represent the reflected IDENTITY structure of a column, corresponding
to the :class:`_schema.Identity` construct.
@@ -237,7 +298,7 @@ class ReflectedColumn(TypedDict):
name: str
"""column name"""
- type: "TypeEngine"
+ type: TypeEngine[Any]
"""column type represented as a :class:`.TypeEngine` instance."""
nullable: bool
@@ -465,7 +526,10 @@ class BindTyping(Enum):
"""
-class Dialect:
+VersionInfoType = Tuple[Union[int, str], ...]
+
+
+class Dialect(EventTarget):
"""Define the behavior of a specific database and DB-API combination.
Any aspect of metadata definition, SQL query generation,
@@ -481,6 +545,8 @@ class Dialect:
"""
+ dispatch: dispatcher[Dialect]
+
name: str
"""identifying name for the dialect from a DBAPI-neutral point of view
(i.e. 'sqlite')
@@ -489,6 +555,29 @@ class Dialect:
driver: str
"""identifying name for the dialect's DBAPI"""
+ dbapi: ModuleType
+ """A reference to the DBAPI module object itself.
+
+ SQLAlchemy dialects import DBAPI modules using the classmethod
+ :meth:`.Dialect.import_dbapi`. The rationale is so that any dialect
+ module can be imported and used to generate SQL statements without the
+ need for the actual DBAPI driver to be installed. Only when an
+ :class:`.Engine` is constructed using :func:`.create_engine` does the
+ DBAPI get imported; at that point, the creation process will assign
+ the DBAPI module to this attribute.
+
+ Dialects should therefore implement :meth:`.Dialect.import_dbapi`
+ which will import the necessary module and return it, and then refer
+ to ``self.dbapi`` in dialect code in order to refer to the DBAPI module
+ contents.
+
+ .. versionchanged:: 2.0 The :attr:`.Dialect.dbapi` attribute is exclusively
+ used as the per-:class:`.Dialect`-instance reference to the DBAPI
+ module. The previous not-fully-documented ``.Dialect.dbapi()``
+ classmethod is deprecated and replaced by :meth:`.Dialect.import_dbapi`.
+
+ """
+
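A hypothetical third-party dialect would follow the pattern described above
roughly as follows (``mydriver`` is an assumed pep-249 module)::

    from sqlalchemy.engine.default import DefaultDialect

    class MyDialect(DefaultDialect):
        @classmethod
        def import_dbapi(cls):
            import mydriver
            return mydriver

        def is_disconnect(self, e, connection, cursor):
            # at runtime, self.dbapi is the module returned above
            return isinstance(e, self.dbapi.OperationalError)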
positional: bool
"""True if the paramstyle for this Dialect is positional."""
@@ -497,21 +586,23 @@ class Dialect:
paramstyles).
"""
- statement_compiler: Type["SQLCompiler"]
+ compiler_linting: Linting
+
+ statement_compiler: Type[SQLCompiler]
"""a :class:`.Compiled` class used to compile SQL statements"""
- ddl_compiler: Type["DDLCompiler"]
+ ddl_compiler: Type[DDLCompiler]
"""a :class:`.Compiled` class used to compile DDL statements"""
- type_compiler: _TypeToInstance["TypeCompiler"]
+ type_compiler: _TypeToInstance[TypeCompiler]
"""a :class:`.Compiled` class used to compile SQL type objects"""
- preparer: Type["IdentifierPreparer"]
+ preparer: Type[IdentifierPreparer]
"""a :class:`.IdentifierPreparer` class used to
quote identifiers.
"""
- identifier_preparer: "IdentifierPreparer"
+ identifier_preparer: IdentifierPreparer
"""This element will refer to an instance of :class:`.IdentifierPreparer`
once a :class:`.DefaultDialect` has been constructed.
@@ -531,10 +622,15 @@ class Dialect:
"""
+ default_isolation_level: str
+ """the isolation that is implicitly present on new connections"""
+
execution_ctx_cls: Type["ExecutionContext"]
"""a :class:`.ExecutionContext` class used to handle statement execution"""
- execute_sequence_format: Union[Type[Tuple[Any, ...]], Type[List[Any]]]
+ execute_sequence_format: Union[
+ Type[Tuple[Any, ...]], Type[List[Any]]
+ ]
"""either the 'tuple' or 'list' type, depending on what cursor.execute()
accepts for the second argument (they vary)."""
@@ -579,7 +675,7 @@ class Dialect:
"""
- colspecs: Dict[Type["TypeEngine[Any]"], Type["TypeEngine[Any]"]]
+ colspecs: MutableMapping[Type["TypeEngine[Any]"], Type["TypeEngine[Any]"]]
"""A dictionary of TypeEngine classes from sqlalchemy.types mapped
to subclasses that are specific to the dialect class. This
dictionary is class-level only and is not accessed from the
@@ -610,7 +706,55 @@ class Dialect:
constraint when that type is used.
"""
- dbapi_exception_translation_map: Dict[str, str]
+ construct_arguments: Optional[
+ List[Tuple[Type[ClauseElement], Mapping[str, Any]]]
+ ] = None
+ """Optional set of argument specifiers for various SQLAlchemy
+ constructs, typically schema items.
+
+ To implement, establish as a series of tuples, as in::
+
+ construct_arguments = [
+ (schema.Index, {
+ "using": False,
+ "where": None,
+ "ops": None
+ })
+ ]
+
+ If the above construct is established on the PostgreSQL dialect,
+ the :class:`.Index` construct will now accept the keyword arguments
+ ``postgresql_using``, ``postgresql_where``, and ``postgresql_ops``.
+ Any other argument specified to the constructor of :class:`.Index`
+ which is prefixed with ``postgresql_`` will raise :class:`.ArgumentError`.
+
+ A dialect which does not include a ``construct_arguments`` member will
+ not participate in the argument validation system. For such a dialect,
+ any argument name is accepted by all participating constructs, within
+ the namespace of arguments prefixed with that dialect name. The rationale
+ here is so that third-party dialects that haven't yet implemented this
+ feature continue to function in the old way.
+
+ .. versionadded:: 0.9.2
+
+ .. seealso::
+
+ :class:`.DialectKWArgs` - implementing base class which consumes
+ :attr:`.DefaultDialect.construct_arguments`
+
+
+ """
+
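With the PostgreSQL specifiers above in place, usage validates as in this
sketch::

    from sqlalchemy import Column, Index, Integer, MetaData, Table

    t = Table("t", MetaData(), Column("x", Integer))

    Index("ix_t_x", t.c.x, postgresql_using="hash")  # declared; accepted
    # an undeclared name such as postgresql_bogus=... raises ArgumentError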
+ reflection_options: Sequence[str] = ()
+ """Sequence of string names indicating keyword arguments that can be
+ established on a :class:`.Table` object which will be passed as
+ "reflection options" when using :paramref:`.Table.autoload_with`.
+
+ A current example is "oracle_resolve_synonyms" in the Oracle dialect.
+
+ """
+
+ dbapi_exception_translation_map: Mapping[str, str] = util.EMPTY_DICT
"""A dictionary of names that will contain as values the names of
pep-249 exceptions ("IntegrityError", "OperationalError", etc)
keyed to alternate class names, to support the case where a
@@ -660,9 +804,16 @@ class Dialect:
is_async: bool
"""Whether or not this dialect is intended for asyncio use."""
- def create_connect_args(
- self, url: "URL"
- ) -> Tuple[Tuple[str], Mapping[str, Any]]:
+ engine_config_types: Mapping[str, Any]
+ """a mapping of string keys that can be in an engine config linked to
+ type conversion functions.
+
+ """
+
+ def _builtin_onconnect(self) -> Optional[_ListenerFnType]:
+ raise NotImplementedError()
+
+ def create_connect_args(self, url: "URL") -> ConnectArgsType:
"""Build DB-API compatible connection arguments.
Given a :class:`.URL` object, returns a tuple
@@ -696,7 +847,25 @@ class Dialect:
raise NotImplementedError()
@classmethod
- def type_descriptor(cls, typeobj: "TypeEngine") -> "TypeEngine":
+ def import_dbapi(cls) -> ModuleType:
+ """Import the DBAPI module that is used by this dialect.
+
+ The Python module object returned here will be assigned as an
+ instance variable to a constructed dialect under the name
+ ``.dbapi``.
+
+ .. versionchanged:: 2.0 The :meth:`.Dialect.import_dbapi` class
+ method is renamed from the previous method ``.Dialect.dbapi()``,
+ which would be replaced at dialect instantiation time by the
+ DBAPI module itself, thus using the same name in two different ways.
+ If a ``.Dialect.dbapi()`` classmethod is present on a third-party
+ dialect, it will be used and a deprecation warning will be emitted.
+
+ """
+ raise NotImplementedError()
+
+ @classmethod
+ def type_descriptor(cls, typeobj: "TypeEngine[_T]") -> "TypeEngine[_T]":
"""Transform a generic type to a dialect-specific type.
Dialect classes will usually use the
@@ -735,7 +904,7 @@ class Dialect:
connection: "Connection",
table_name: str,
schema: Optional[str] = None,
- **kw,
+ **kw: Any,
) -> List[ReflectedColumn]:
"""Return information about columns in ``table_name``.
@@ -908,11 +1077,12 @@ class Dialect:
table_name: str,
schema: Optional[str] = None,
**kw: Any,
- ) -> Dict[str, Any]:
+ ) -> Optional[Dict[str, Any]]:
r"""Return the "options" for the table identified by ``table_name``
as a dictionary.
"""
+ return None
def get_table_comment(
self,
@@ -1115,7 +1285,7 @@ class Dialect:
def do_set_input_sizes(
self,
cursor: DBAPICursor,
- list_of_tuples: List[Tuple[str, Any, "TypeEngine"]],
+ list_of_tuples: List[Tuple[str, Any, TypeEngine[Any]]],
context: "ExecutionContext",
) -> Any:
"""invoke the cursor.setinputsizes() method with appropriate arguments
@@ -1242,7 +1412,7 @@ class Dialect:
raise NotImplementedError()
- def do_recover_twophase(self, connection: "Connection") -> None:
+ def do_recover_twophase(self, connection: "Connection") -> List[Any]:
"""Recover list of uncommitted prepared two phase transaction
identifiers on the given connection.
@@ -1256,7 +1426,7 @@ class Dialect:
self,
cursor: DBAPICursor,
statement: str,
- parameters: List[Union[Dict[str, Any], Tuple[Any]]],
+ parameters: _DBAPIMultiExecuteParams,
context: Optional["ExecutionContext"] = None,
) -> None:
"""Provide an implementation of ``cursor.executemany(statement,
@@ -1268,9 +1438,9 @@ class Dialect:
self,
cursor: DBAPICursor,
statement: str,
- parameters: Union[Mapping[str, Any], Tuple[Any]],
- context: Optional["ExecutionContext"] = None,
- ):
+ parameters: Optional[_DBAPISingleExecuteParams],
+ context: Optional[ExecutionContext] = None,
+ ) -> None:
"""Provide an implementation of ``cursor.execute(statement,
parameters)``."""
@@ -1281,7 +1451,7 @@ class Dialect:
cursor: DBAPICursor,
statement: str,
context: Optional["ExecutionContext"] = None,
- ):
+ ) -> None:
"""Provide an implementation of ``cursor.execute(statement)``.
The parameter collection should not be sent.
@@ -1294,14 +1464,14 @@ class Dialect:
self,
e: Exception,
connection: Optional[PoolProxiedConnection],
- cursor: DBAPICursor,
+ cursor: Optional[DBAPICursor],
) -> bool:
"""Return True if the given DB-API error indicates an invalid
connection"""
raise NotImplementedError()
- def connect(self, *cargs: Any, **cparams: Any) -> Any:
+ def connect(self, *cargs: Any, **cparams: Any) -> DBAPIConnection:
r"""Establish a connection using this dialect's DBAPI.
The default implementation of this method is::
@@ -1333,6 +1503,7 @@ class Dialect:
:meth:`.Dialect.on_connect`
"""
+ raise NotImplementedError()
def on_connect_url(self, url: "URL") -> Optional[Callable[[Any], Any]]:
"""return a callable which sets up a newly created DBAPI connection.
@@ -1542,7 +1713,7 @@ class Dialect:
raise NotImplementedError()
- def get_default_isolation_level(self, dbapi_conn: Any) -> str:
+ def get_default_isolation_level(self, dbapi_conn: DBAPIConnection) -> str:
"""Given a DBAPI connection, return its isolation level, or
a default isolation level if one cannot be retrieved.
@@ -1562,7 +1733,9 @@ class Dialect:
"""
raise NotImplementedError()
- def get_isolation_level_values(self, dbapi_conn: Any) -> List[str]:
+ def get_isolation_level_values(
+ self, dbapi_conn: DBAPIConnection
+ ) -> List[str]:
"""return a sequence of string isolation level names that are accepted
by this dialect.
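A dialect override might look like this sketch (the level names are
examples only; each dialect returns the names its backend actually
accepts)::

    def get_isolation_level_values(self, dbapi_conn):
        return [
            "READ UNCOMMITTED",
            "READ COMMITTED",
            "REPEATABLE READ",
            "SERIALIZABLE",
        ]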
@@ -1604,8 +1777,13 @@ class Dialect:
"""
raise NotImplementedError()
+ def _assert_and_set_isolation_level(
+ self, dbapi_conn: DBAPIConnection, level: str
+ ) -> None:
+ raise NotImplementedError()
+
@classmethod
- def get_dialect_cls(cls, url: "URL") -> Type:
+ def get_dialect_cls(cls, url: URL) -> Type[Dialect]:
"""Given a URL, return the :class:`.Dialect` that will be used.
This is a hook that allows an external plugin to provide functionality
@@ -1621,7 +1799,7 @@ class Dialect:
return cls
@classmethod
- def get_async_dialect_cls(cls, url: "URL") -> None:
+ def get_async_dialect_cls(cls, url: URL) -> Type[Dialect]:
"""Given a URL, return the :class:`.Dialect` that will be used by
an async engine.
@@ -1702,6 +1880,39 @@ class Dialect:
"""
raise NotImplementedError()
+ def set_engine_execution_options(
+ self, engine: Engine, opt: _ExecuteOptionsParameter
+ ) -> None:
+ """Establish execution options for a given engine.
+
+ This is implemented by :class:`.DefaultDialect` to establish
+ event hooks for new :class:`.Connection` instances created
+ by the given :class:`.Engine` which will then invoke the
+ :meth:`.Dialect.set_connection_execution_options` method for that
+ connection.
+
+ """
+ raise NotImplementedError()
+
+ def set_connection_execution_options(
+ self, connection: Connection, opt: _ExecuteOptionsParameter
+ ) -> None:
+ """Establish execution options for a given connection.
+
+ This is implemented by :class:`.DefaultDialect` in order to implement
+ the :paramref:`_engine.Connection.execution_options.isolation_level`
+ execution option. Dialects can intercept various execution options
+ which may need to modify state on a particular DBAPI connection.
+
+ .. versionadded:: 1.4
+
+ """
+ raise NotImplementedError()
+
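The user-facing side of this hook is the documented ``isolation_level``
execution option; a sketch assuming an ``engine``::

    from sqlalchemy import text

    with engine.connect() as conn:
        conn = conn.execution_options(isolation_level="SERIALIZABLE")
        conn.execute(text("SELECT 1"))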
+ def get_dialect_pool_class(self, url: URL) -> Type[Pool]:
+ """return a Pool class to use for a given URL"""
+ raise NotImplementedError()
+
class CreateEnginePlugin:
"""A set of hooks intended to augment the construction of an
@@ -1878,7 +2089,7 @@ class CreateEnginePlugin:
""" # noqa: E501
- def __init__(self, url, kwargs):
+ def __init__(self, url: URL, kwargs: Dict[str, Any]):
"""Construct a new :class:`.CreateEnginePlugin`.
The plugin object is instantiated individually for each call
@@ -1905,7 +2116,7 @@ class CreateEnginePlugin:
"""
self.url = url
- def update_url(self, url):
+ def update_url(self, url: URL) -> URL:
"""Update the :class:`_engine.URL`.
A new :class:`_engine.URL` should be returned. This method is
@@ -1920,14 +2131,19 @@ class CreateEnginePlugin:
.. versionadded:: 1.4
"""
+ raise NotImplementedError()
- def handle_dialect_kwargs(self, dialect_cls, dialect_args):
+ def handle_dialect_kwargs(
+ self, dialect_cls: Type[Dialect], dialect_args: Dict[str, Any]
+ ) -> None:
"""parse and modify dialect kwargs"""
- def handle_pool_kwargs(self, pool_cls, pool_args):
+ def handle_pool_kwargs(
+ self, pool_cls: Type[Pool], pool_args: Dict[str, Any]
+ ) -> None:
"""parse and modify pool kwargs"""
- def engine_created(self, engine):
+ def engine_created(self, engine: Engine) -> None:
"""Receive the :class:`_engine.Engine`
object when it is fully constructed.
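A minimal plugin sketch (the ``log_name`` query parameter is hypothetical;
real plugins are registered under the ``sqlalchemy.plugins`` entrypoint
group)::

    from sqlalchemy.engine import CreateEnginePlugin

    class LoggingPlugin(CreateEnginePlugin):
        def __init__(self, url, kwargs):
            self.log_name = url.query.get("log_name", "default")

        def update_url(self, url):
            # consume the plugin-specific query parameter
            return url.difference_update_query(["log_name"])

        def engine_created(self, engine):
            print("engine ready:", self.log_name)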
@@ -1941,56 +2157,137 @@ class ExecutionContext:
"""A messenger object for a Dialect that corresponds to a single
execution.
- ExecutionContext should have these data members:
+ """
- connection
- Connection object which can be freely used by default value
+ connection: Connection
+ """Connection object which can be freely used by default value
generators to execute SQL. This Connection should reference the
same underlying connection/transactional resources of
- root_connection.
+ root_connection."""
- root_connection
- Connection object which is the source of this ExecutionContext.
+ root_connection: Connection
+ """Connection object which is the source of this ExecutionContext."""
- dialect
- dialect which created this ExecutionContext.
+ dialect: Dialect
+ """dialect which created this ExecutionContext."""
- cursor
- DB-API cursor procured from the connection,
+ cursor: DBAPICursor
+ """DB-API cursor procured from the connection"""
- compiled
- if passed to constructor, sqlalchemy.engine.base.Compiled object
- being executed,
+ compiled: Optional[Compiled]
+ """if passed to constructor, sqlalchemy.engine.base.Compiled object
+ being executed"""
- statement
- string version of the statement to be executed. Is either
+ statement: str
+ """string version of the statement to be executed. Is either
passed to the constructor, or must be created from the
- sql.Compiled object by the time pre_exec() has completed.
+ sql.Compiled object by the time pre_exec() has completed."""
- parameters
- bind parameters passed to the execute() method. For compiled
- statements, this is a dictionary or list of dictionaries. For
- textual statements, it should be in a format suitable for the
- dialect's paramstyle (i.e. dict or list of dicts for non
- positional, list or list of lists/tuples for positional).
+ invoked_statement: Optional[Executable]
+ """The Executable statement object that was given in the first place.
- isinsert
- True if the statement is an INSERT.
+ This should be structurally equivalent to compiled.statement, but not
+ necessarily the same object; in a caching scenario, the compiled form
+ will have been extracted from the cache.
- isupdate
- True if the statement is an UPDATE.
+ """
- prefetch_cols
- a list of Column objects for which a client-side default
- was fired off. Applies to inserts and updates.
+ parameters: _AnyMultiExecuteParams
+ """bind parameters passed to the execute() or exec_driver_sql() methods.
+
+ These are always stored as a list of parameter entries. A single-element
+ list corresponds to a ``cursor.execute()`` call and a multiple-element
+ list corresponds to ``cursor.executemany()``.
- postfetch_cols
- a list of Column objects for which a server-side default or
- inline SQL expression value was fired off. Applies to inserts
- and updates.
"""
- def create_cursor(self):
+ no_parameters: bool
+ """True if the execution style does not use parameters"""
+
+ isinsert: bool
+ """True if the statement is an INSERT."""
+
+ isupdate: bool
+ """True if the statement is an UPDATE."""
+
+ executemany: bool
+ """True if the parameters have determined this to be an executemany"""
+
+ prefetch_cols: Optional[Sequence[Column[Any]]]
+ """a list of Column objects for which a client-side default
+ was fired off. Applies to inserts and updates."""
+
+ postfetch_cols: Optional[Sequence[Column[Any]]]
+ """a list of Column objects for which a server-side default or
+ inline SQL expression value was fired off. Applies to inserts
+ and updates."""
+
+ @classmethod
+ def _init_ddl(
+ cls,
+ dialect: Dialect,
+ connection: Connection,
+ dbapi_connection: PoolProxiedConnection,
+ execution_options: _ExecuteOptions,
+ compiled_ddl: DDLCompiler,
+ ) -> ExecutionContext:
+ raise NotImplementedError()
+
+ @classmethod
+ def _init_compiled(
+ cls,
+ dialect: Dialect,
+ connection: Connection,
+ dbapi_connection: PoolProxiedConnection,
+ execution_options: _ExecuteOptions,
+ compiled: SQLCompiler,
+ parameters: _CoreMultiExecuteParams,
+ invoked_statement: Executable,
+ extracted_parameters: _CoreSingleExecuteParams,
+ cache_hit: CacheStats = CacheStats.CACHING_DISABLED,
+ ) -> ExecutionContext:
+ raise NotImplementedError()
+
+ @classmethod
+ def _init_statement(
+ cls,
+ dialect: Dialect,
+ connection: Connection,
+ dbapi_connection: PoolProxiedConnection,
+ execution_options: _ExecuteOptions,
+ statement: str,
+ parameters: _DBAPIMultiExecuteParams,
+ ) -> ExecutionContext:
+ raise NotImplementedError()
+
+ @classmethod
+ def _init_default(
+ cls,
+ dialect: Dialect,
+ connection: Connection,
+ dbapi_connection: PoolProxiedConnection,
+ execution_options: _ExecuteOptions,
+ ) -> ExecutionContext:
+ raise NotImplementedError()
+
+ def _exec_default(
+ self,
+ column: Optional[Column[Any]],
+ default: ColumnDefault,
+ type_: Optional[TypeEngine[Any]],
+ ) -> Any:
+ raise NotImplementedError()
+
+ def _set_input_sizes(self) -> None:
+ raise NotImplementedError()
+
+ def _get_cache_stats(self) -> str:
+ raise NotImplementedError()
+
+ def _setup_result_proxy(self) -> Result:
+ raise NotImplementedError()
+
+ def create_cursor(self) -> DBAPICursor:
"""Return a new cursor generated from this ExecutionContext's
connection.
@@ -2001,7 +2298,7 @@ class ExecutionContext:
raise NotImplementedError()
- def pre_exec(self):
+ def pre_exec(self) -> None:
"""Called before an execution of a compiled statement.
If a compiled statement was passed to this ExecutionContext,
@@ -2011,7 +2308,9 @@ class ExecutionContext:
raise NotImplementedError()
- def get_out_parameter_values(self, out_param_names):
+ def get_out_parameter_values(
+ self, out_param_names: Sequence[str]
+ ) -> Sequence[Any]:
"""Return a sequence of OUT parameter values from a cursor.
For dialects that support OUT parameters, this method will be called
@@ -2045,7 +2344,7 @@ class ExecutionContext:
"""
raise NotImplementedError()
- def post_exec(self):
+ def post_exec(self) -> None:
"""Called after the execution of a compiled statement.
If a compiled statement was passed to this ExecutionContext,
@@ -2055,20 +2354,20 @@ class ExecutionContext:
raise NotImplementedError()
- def handle_dbapi_exception(self, e):
+ def handle_dbapi_exception(self, e: BaseException) -> None:
"""Receive a DBAPI exception which occurred upon execute, result
fetch, etc."""
raise NotImplementedError()
- def lastrow_has_defaults(self):
+ def lastrow_has_defaults(self) -> bool:
"""Return True if the last INSERT or UPDATE row contained
inlined or database-side defaults.
"""
raise NotImplementedError()
- def get_rowcount(self):
+ def get_rowcount(self) -> Optional[int]:
"""Return the DBAPI ``cursor.rowcount`` value, or in some
cases an interpreted value.
@@ -2079,7 +2378,7 @@ class ExecutionContext:
raise NotImplementedError()
-class ConnectionEventsTarget:
+class ConnectionEventsTarget(EventTarget):
"""An object which can accept events from :class:`.ConnectionEvents`.
Includes :class:`_engine.Connection` and :class:`_engine.Engine`.
@@ -2088,6 +2387,11 @@ class ConnectionEventsTarget:
"""
+ dispatch: dispatcher[ConnectionEventsTarget]
+
+
+Connectable = ConnectionEventsTarget
+
class ExceptionContext:
"""Encapsulate information about an error condition in progress.
@@ -2101,7 +2405,7 @@ class ExceptionContext:
"""
- connection = None
+ connection: Optional[Connection]
"""The :class:`_engine.Connection` in use during the exception.
This member is present, except in the case of a failure when
@@ -2114,7 +2418,7 @@ class ExceptionContext:
"""
- engine = None
+ engine: Optional[Engine]
"""The :class:`_engine.Engine` in use during the exception.
This member should always be present, even in the case of a failure
@@ -2124,35 +2428,35 @@ class ExceptionContext:
"""
- cursor = None
+ cursor: Optional[DBAPICursor]
"""The DBAPI cursor object.
May be None.
"""
- statement = None
+ statement: Optional[str]
"""String SQL statement that was emitted directly to the DBAPI.
May be None.
"""
- parameters = None
+ parameters: Optional[_DBAPIAnyExecuteParams]
"""Parameter collection that was emitted directly to the DBAPI.
May be None.
"""
- original_exception = None
+ original_exception: BaseException
"""The exception object which was caught.
This member is always present.
"""
- sqlalchemy_exception = None
+ sqlalchemy_exception: Optional[StatementError]
"""The :class:`sqlalchemy.exc.StatementError` which wraps the original,
and will be raised if exception handling is not circumvented by the event.
@@ -2162,7 +2466,7 @@ class ExceptionContext:
"""
- chained_exception = None
+ chained_exception: Optional[BaseException]
"""The exception that was returned by the previous handler in the
exception chain, if any.
@@ -2173,7 +2477,7 @@ class ExceptionContext:
"""
- execution_context = None
+ execution_context: Optional[ExecutionContext]
"""The :class:`.ExecutionContext` corresponding to the execution
operation in progress.
@@ -2193,7 +2497,7 @@ class ExceptionContext:
"""
- is_disconnect = None
+ is_disconnect: bool
"""Represent whether the exception as occurred represents a "disconnect"
condition.
@@ -2218,7 +2522,7 @@ class ExceptionContext:
"""
- invalidate_pool_on_disconnect = True
+ invalidate_pool_on_disconnect: bool
"""Represent whether all connections in the pool should be invalidated
when a "disconnect" condition is in effect.
@@ -2250,12 +2554,14 @@ class AdaptedConnection:
__slots__ = ("_connection",)
+ _connection: Any
+
@property
- def driver_connection(self):
+ def driver_connection(self) -> Any:
"""The connection object as returned by the driver after a connect."""
return self._connection
- def run_async(self, fn):
+ def run_async(self, fn: Callable[[Any], Awaitable[_T]]) -> _T:
"""Run the awaitable returned by the given function, which is passed
the raw asyncio driver connection.
@@ -2284,5 +2590,5 @@ class AdaptedConnection:
"""
return await_only(fn(self._connection))
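For example (a sketch based on the documented asyncpg recipe; ``engine`` is
an assumed async engine), awaitable driver methods can be reached from a
synchronous pool event hook::

    import json
    from sqlalchemy import event

    @event.listens_for(engine.sync_engine, "connect")
    def register_json_codec(dbapi_connection, connection_record):
        # dbapi_connection here is an AdaptedConnection
        dbapi_connection.run_async(
            lambda conn: conn.set_type_codec(
                "json", encoder=json.dumps, decoder=json.loads,
                schema="pg_catalog",
            )
        )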
- def __repr__(self):
+ def __repr__(self) -> str:
return "<AdaptedConnection %s>" % self._connection
diff --git a/lib/sqlalchemy/engine/mock.py b/lib/sqlalchemy/engine/mock.py
index 76e77a3f3..a0ba96603 100644
--- a/lib/sqlalchemy/engine/mock.py
+++ b/lib/sqlalchemy/engine/mock.py
@@ -8,40 +8,69 @@
from __future__ import annotations
from operator import attrgetter
+import typing
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Optional
+from typing import Type
+from typing import Union
from . import url as _url
from .. import util
+if typing.TYPE_CHECKING:
+ from .base import Connection
+ from .base import Engine
+ from .interfaces import _CoreAnyExecuteParams
+ from .interfaces import _ExecuteOptionsParameter
+ from .interfaces import Dialect
+ from .url import URL
+ from ..sql.base import Executable
+ from ..sql.ddl import DDLElement
+ from ..sql.ddl import SchemaDropper
+ from ..sql.ddl import SchemaGenerator
+ from ..sql.schema import HasSchemaAttr
+
+
class MockConnection:
- def __init__(self, dialect, execute):
+ def __init__(self, dialect: Dialect, execute: Callable[..., Any]):
self._dialect = dialect
- self.execute = execute
+ self._execute_impl = execute
- engine = property(lambda s: s)
- dialect = property(attrgetter("_dialect"))
- name = property(lambda s: s._dialect.name)
+ engine: Engine = cast(Any, property(lambda s: s))
+ dialect: Dialect = cast(Any, property(attrgetter("_dialect")))
+ name: str = cast(Any, property(lambda s: s._dialect.name))
- def connect(self, **kwargs):
+ def connect(self, **kwargs: Any) -> MockConnection:
return self
- def schema_for_object(self, obj):
+ def schema_for_object(self, obj: HasSchemaAttr) -> Optional[str]:
return obj.schema
- def execution_options(self, **kw):
+ def execution_options(self, **kw: Any) -> MockConnection:
return self
def _run_ddl_visitor(
- self, visitorcallable, element, connection=None, **kwargs
- ):
+ self,
+ visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]],
+ element: DDLElement,
+ **kwargs: Any,
+ ) -> None:
kwargs["checkfirst"] = False
visitorcallable(self.dialect, self, **kwargs).traverse_single(element)
- def execute(self, object_, *multiparams, **params):
- raise NotImplementedError()
+ def execute(
+ self,
+ obj: Executable,
+ parameters: Optional[_CoreAnyExecuteParams] = None,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> Any:
+ return self._execute_impl(obj, parameters)
-def create_mock_engine(url, executor, **kw):
+def create_mock_engine(url: URL, executor: Any, **kw: Any) -> MockConnection:
"""Create a "mock" engine used for echoing DDL.
This is a utility function used for debugging or storing the output of DDL
@@ -96,6 +125,6 @@ def create_mock_engine(url, executor, **kw):
dialect_args[k] = kw.pop(k)
# create dialect
- dialect = dialect_cls(**dialect_args)
+ dialect = dialect_cls(**dialect_args) # type: ignore
return MockConnection(dialect, executor)
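Typical usage of the mock engine, per the documented DDL-dumping recipe
(``metadata`` is an assumed :class:`.MetaData`)::

    from sqlalchemy import create_mock_engine

    def dump(sql, *multiparams, **params):
        print(sql.compile(dialect=engine.dialect))

    engine = create_mock_engine("postgresql+psycopg2://", dump)
    metadata.create_all(engine, checkfirst=False)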
diff --git a/lib/sqlalchemy/engine/processors.py b/lib/sqlalchemy/engine/processors.py
index 398c1fa36..7a6a57c03 100644
--- a/lib/sqlalchemy/engine/processors.py
+++ b/lib/sqlalchemy/engine/processors.py
@@ -14,9 +14,20 @@ They all share one common characteristic: None is passed through unchanged.
"""
from __future__ import annotations
+import typing
+
from ._py_processors import str_to_datetime_processor_factory # noqa
+from ..util._has_cy import HAS_CYEXTENSION
-try:
+if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
+ from ._py_processors import int_to_boolean # noqa
+ from ._py_processors import str_to_date # noqa
+ from ._py_processors import str_to_datetime # noqa
+ from ._py_processors import str_to_time # noqa
+ from ._py_processors import to_decimal_processor_factory # noqa
+ from ._py_processors import to_float # noqa
+ from ._py_processors import to_str # noqa
+else:
from sqlalchemy.cyextension.processors import (
DecimalResultProcessor,
) # noqa
@@ -34,12 +45,3 @@ try:
# Decimal('5.00000') whereas the C implementation will
# return Decimal('5'). These are equivalent of course.
return DecimalResultProcessor(target_class, "%%.%df" % scale).process
-
-except ImportError:
- from ._py_processors import int_to_boolean # noqa
- from ._py_processors import str_to_date # noqa
- from ._py_processors import str_to_datetime # noqa
- from ._py_processors import str_to_time # noqa
- from ._py_processors import to_decimal_processor_factory # noqa
- from ._py_processors import to_float # noqa
- from ._py_processors import to_str # noqa
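The pure-Python variants imported above all follow the module's shared
contract; roughly (illustrative sketch)::

    def to_float(value):
        # None is passed through unchanged
        return float(value) if value is not None else None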
diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py
index 3ba1ae519..0951d5770 100644
--- a/lib/sqlalchemy/engine/result.py
+++ b/lib/sqlalchemy/engine/result.py
@@ -9,13 +9,26 @@
from __future__ import annotations
-import collections.abc as collections_abc
+from enum import Enum
import functools
import itertools
import operator
import typing
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Iterator
+from typing import List
+from typing import NoReturn
+from typing import Optional
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import TypeVar
+from typing import Union
from .row import Row
+from .row import RowMapping
from .. import exc
from .. import util
from ..sql.base import _generative
@@ -25,9 +38,42 @@ from ..util._has_cy import HAS_CYEXTENSION
if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
- from ._py_row import tuplegetter
+ from ._py_row import tuplegetter as tuplegetter
else:
- from sqlalchemy.cyextension.resultproxy import tuplegetter
+ from sqlalchemy.cyextension.resultproxy import tuplegetter as tuplegetter
+
+if typing.TYPE_CHECKING:
+ from .row import RowMapping
+ from ..sql.schema import Column
+
+_KeyType = Union[str, "Column[Any]"]
+_KeyIndexType = Union[str, "Column[Any]", int]
+
+# is overridden in cursor using _CursorKeyMapRecType
+_KeyMapRecType = Any
+
+_KeyMapType = Dict[_KeyType, _KeyMapRecType]
+
+
+_RowData = Union[Row, RowMapping, Any]
+"""A generic form of "row" that accommodates for the different kinds of
+"rows" that different result objects return, including row, row mapping, and
+scalar values"""
+
+_RawRowType = Tuple[Any, ...]
+"""represents the kind of row we get from a DBAPI cursor"""
+
+_InterimRowType = Union[Row, RowMapping, Any, _RawRowType]
+"""a catchall "anything" kind of return type that can be applied
+across all the result types
+
+"""
+
+_ProcessorType = Callable[[Any], Any]
+_ProcessorsType = Sequence[Optional[_ProcessorType]]
+_TupleGetterType = Callable[[Sequence[Any]], Tuple[Any, ...]]
+_UniqueFilterType = Callable[[Any], Any]
+_UniqueFilterStateType = Tuple[Set[Any], Optional[_UniqueFilterType]]
class ResultMetaData:
@@ -35,40 +81,58 @@ class ResultMetaData:
__slots__ = ()
- _tuplefilter = None
- _translated_indexes = None
- _unique_filters = None
+ _tuplefilter: Optional[_TupleGetterType] = None
+ _translated_indexes: Optional[Sequence[int]] = None
+ _unique_filters: Optional[Sequence[Callable[[Any], Any]]] = None
+ _keymap: _KeyMapType
+ _keys: Sequence[str]
+ _processors: Optional[_ProcessorsType]
@property
- def keys(self):
+ def keys(self) -> RMKeyView:
return RMKeyView(self)
- def _has_key(self, key):
+ def _has_key(self, key: object) -> bool:
raise NotImplementedError()
- def _for_freeze(self):
+ def _for_freeze(self) -> ResultMetaData:
raise NotImplementedError()
- def _key_fallback(self, key, err, raiseerr=True):
+ def _key_fallback(
+ self, key: _KeyType, err: Exception, raiseerr: bool = True
+ ) -> NoReturn:
assert raiseerr
raise KeyError(key) from err
- def _raise_for_nonint(self, key):
- raise TypeError(
- "TypeError: tuple indices must be integers or slices, not %s"
- % type(key).__name__
+ def _raise_for_ambiguous_column_name(
+ self, rec: _KeyMapRecType
+ ) -> NoReturn:
+ raise NotImplementedError(
+ "ambiguous column name logic is implemented for "
+ "CursorResultMetaData"
)
- def _index_for_key(self, keys, raiseerr):
+ def _index_for_key(
+ self, key: _KeyIndexType, raiseerr: bool
+ ) -> Optional[int]:
raise NotImplementedError()
- def _metadata_for_keys(self, key):
+ def _indexes_for_keys(
+ self, keys: Sequence[_KeyIndexType]
+ ) -> Sequence[int]:
raise NotImplementedError()
- def _reduce(self, keys):
+ def _metadata_for_keys(
+ self, keys: Sequence[_KeyIndexType]
+ ) -> Iterator[_KeyMapRecType]:
raise NotImplementedError()
- def _getter(self, key, raiseerr=True):
+ def _reduce(self, keys: Sequence[_KeyIndexType]) -> ResultMetaData:
+ raise NotImplementedError()
+
+ def _getter(
+ self, key: Any, raiseerr: bool = True
+ ) -> Optional[Callable[[Sequence[_RowData]], _RowData]]:
index = self._index_for_key(key, raiseerr)
@@ -77,28 +141,33 @@ class ResultMetaData:
else:
return None
- def _row_as_tuple_getter(self, keys):
+ def _row_as_tuple_getter(
+ self, keys: Sequence[_KeyIndexType]
+ ) -> _TupleGetterType:
indexes = self._indexes_for_keys(keys)
return tuplegetter(*indexes)
-class RMKeyView(collections_abc.KeysView):
+class RMKeyView(typing.KeysView[Any]):
__slots__ = ("_parent", "_keys")
- def __init__(self, parent):
+ _parent: ResultMetaData
+ _keys: Sequence[str]
+
+ def __init__(self, parent: ResultMetaData):
self._parent = parent
self._keys = [k for k in parent._keys if k is not None]
- def __len__(self):
+ def __len__(self) -> int:
return len(self._keys)
- def __repr__(self):
+ def __repr__(self) -> str:
return "{0.__class__.__name__}({0._keys!r})".format(self)
- def __iter__(self):
+ def __iter__(self) -> Iterator[str]:
return iter(self._keys)
- def __contains__(self, item):
+ def __contains__(self, item: Any) -> bool:
if isinstance(item, int):
return False
@@ -106,10 +175,10 @@ class RMKeyView(collections_abc.KeysView):
# which also don't seem to be tested in test_resultset right now
return self._parent._has_key(item)
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return list(other) == list(self)
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> bool:
return list(other) != list(self)
@@ -125,20 +194,21 @@ class SimpleResultMetaData(ResultMetaData):
"_unique_filters",
)
+ _keys: Sequence[str]
+
def __init__(
self,
- keys,
- extra=None,
- _processors=None,
- _tuplefilter=None,
- _translated_indexes=None,
- _unique_filters=None,
+ keys: Sequence[str],
+ extra: Optional[Sequence[Any]] = None,
+ _processors: Optional[_ProcessorsType] = None,
+ _tuplefilter: Optional[_TupleGetterType] = None,
+ _translated_indexes: Optional[Sequence[int]] = None,
+ _unique_filters: Optional[Sequence[Callable[[Any], Any]]] = None,
):
self._keys = list(keys)
self._tuplefilter = _tuplefilter
self._translated_indexes = _translated_indexes
self._unique_filters = _unique_filters
-
if extra:
recs_names = [
(
@@ -157,10 +227,10 @@ class SimpleResultMetaData(ResultMetaData):
self._processors = _processors
- def _has_key(self, key):
+ def _has_key(self, key: object) -> bool:
return key in self._keymap
- def _for_freeze(self):
+ def _for_freeze(self) -> ResultMetaData:
unique_filters = self._unique_filters
if unique_filters and self._tuplefilter:
unique_filters = self._tuplefilter(unique_filters)
@@ -173,28 +243,28 @@ class SimpleResultMetaData(ResultMetaData):
_unique_filters=unique_filters,
)
- def __getstate__(self):
+ def __getstate__(self) -> Dict[str, Any]:
return {
"_keys": self._keys,
"_translated_indexes": self._translated_indexes,
}
- def __setstate__(self, state):
+ def __setstate__(self, state: Dict[str, Any]) -> None:
if state["_translated_indexes"]:
_translated_indexes = state["_translated_indexes"]
_tuplefilter = tuplegetter(*_translated_indexes)
else:
_translated_indexes = _tuplefilter = None
- self.__init__(
+ self.__init__( # type: ignore
state["_keys"],
_translated_indexes=_translated_indexes,
_tuplefilter=_tuplefilter,
)
- def _contains(self, value, row):
+ def _contains(self, value: Any, row: Row) -> bool:
return value in row._data
- def _index_for_key(self, key, raiseerr=True):
+ def _index_for_key(self, key: Any, raiseerr: bool = True) -> int:
if int in key.__class__.__mro__:
key = self._keys[key]
try:
@@ -202,12 +272,14 @@ class SimpleResultMetaData(ResultMetaData):
except KeyError as ke:
rec = self._key_fallback(key, ke, raiseerr)
- return rec[0]
+ return rec[0] # type: ignore[no-any-return]
- def _indexes_for_keys(self, keys):
+ def _indexes_for_keys(self, keys: Sequence[Any]) -> Sequence[int]:
return [self._keymap[key][0] for key in keys]
- def _metadata_for_keys(self, keys):
+ def _metadata_for_keys(
+ self, keys: Sequence[Any]
+ ) -> Iterator[_KeyMapRecType]:
for key in keys:
if int in key.__class__.__mro__:
key = self._keys[key]
@@ -219,7 +291,7 @@ class SimpleResultMetaData(ResultMetaData):
yield rec
- def _reduce(self, keys):
+ def _reduce(self, keys: Sequence[Any]) -> ResultMetaData:
try:
metadata_for_keys = [
self._keymap[
@@ -230,7 +302,10 @@ class SimpleResultMetaData(ResultMetaData):
except KeyError as ke:
self._key_fallback(ke.args[0], ke, True)
- indexes, new_keys, extra = zip(*metadata_for_keys)
+ indexes: Sequence[int]
+ new_keys: Sequence[str]
+ extra: Sequence[Any]
+ indexes, new_keys, extra = zip(*metadata_for_keys) # type: ignore
if self._translated_indexes:
indexes = [self._translated_indexes[idx] for idx in indexes]
@@ -249,7 +324,9 @@ class SimpleResultMetaData(ResultMetaData):
return new_metadata
-def result_tuple(fields, extra=None):
+def result_tuple(
+ fields: Sequence[str], extra: Optional[Any] = None
+) -> Callable[[_RawRowType], Row]:
parent = SimpleResultMetaData(fields, extra)
return functools.partial(
Row, parent, parent._processors, parent._keymap, Row._default_key_style
@@ -259,31 +336,58 @@ def result_tuple(fields, extra=None):
# a symbol that indicates to internal Result methods that
# "no row is returned". We can't use None for those cases where a scalar
# filter is applied to rows.
-_NO_ROW = util.symbol("NO_ROW")
+class _NoRow(Enum):
+ _NO_ROW = 0
-SelfResultInternal = typing.TypeVar(
- "SelfResultInternal", bound="ResultInternal"
-)
+
+_NO_ROW = _NoRow._NO_ROW
+
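The ``Enum`` form keeps the identity-comparison semantics of the old
``util.symbol`` while remaining visible to type checkers; consuming code
follows this sketch (``fetch_impl`` is hypothetical)::

    row = fetch_impl()  # returns _NO_ROW when the result is exhausted
    if row is _NO_ROW:
        return None     # translate the sentinel at the public boundary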
+SelfResultInternal = TypeVar("SelfResultInternal", bound="ResultInternal")
class ResultInternal(InPlaceGenerative):
- _real_result = None
- _generate_rows = True
- _unique_filter_state = None
- _post_creational_filter = None
+ _real_result: Optional[Result] = None
+ _generate_rows: bool = True
+ _row_logging_fn: Optional[Callable[[Any], Any]]
+
+ _unique_filter_state: Optional[_UniqueFilterStateType] = None
+ _post_creational_filter: Optional[Callable[[Any], Any]] = None
_is_cursor = False
+ _metadata: ResultMetaData
+
+ _source_supports_scalars: bool
+
+ def _fetchiter_impl(self) -> Iterator[_InterimRowType]:
+ raise NotImplementedError()
+
+ def _fetchone_impl(
+ self, hard_close: bool = False
+ ) -> Optional[_InterimRowType]:
+ raise NotImplementedError()
+
+ def _fetchmany_impl(
+ self, size: Optional[int] = None
+ ) -> List[_InterimRowType]:
+ raise NotImplementedError()
+
+ def _fetchall_impl(self) -> List[_InterimRowType]:
+ raise NotImplementedError()
+
+ def _soft_close(self, hard: bool = False) -> None:
+ raise NotImplementedError()
+
@HasMemoized.memoized_attribute
- def _row_getter(self):
+ def _row_getter(self) -> Optional[Callable[..., _RowData]]:
real_result = self._real_result if self._real_result else self
if real_result._source_supports_scalars:
if not self._generate_rows:
return None
else:
- _proc = real_result._process_row
+ _proc = Row
- def process_row(
+ def process_row( # type: ignore
metadata, processors, keymap, key_style, scalar_obj
):
return _proc(
@@ -291,9 +395,9 @@ class ResultInternal(InPlaceGenerative):
)
else:
- process_row = real_result._process_row
+ process_row = Row # type: ignore
- key_style = real_result._process_row._default_key_style
+ key_style = Row._default_key_style
metadata = self._metadata
keymap = metadata._keymap
@@ -304,19 +408,19 @@ class ResultInternal(InPlaceGenerative):
if processors:
processors = tf(processors)
- _make_row_orig = functools.partial(
+ _make_row_orig: Callable[..., Any] = functools.partial(
process_row, metadata, processors, keymap, key_style
)
- def make_row(row):
- return _make_row_orig(tf(row))
+ def make_row(row: _InterimRowType) -> _InterimRowType:
+ return _make_row_orig(tf(row)) # type: ignore
else:
- make_row = functools.partial(
+ make_row = functools.partial( # type: ignore
process_row, metadata, processors, keymap, key_style
)
- fns = ()
+ fns: Tuple[Any, ...] = ()
if real_result._row_logging_fn:
fns = (real_result._row_logging_fn,)
@@ -326,16 +430,16 @@ class ResultInternal(InPlaceGenerative):
if fns:
_make_row = make_row
- def make_row(row):
- row = _make_row(row)
+ def make_row(row: _InterimRowType) -> _InterimRowType:
+ interim_row = _make_row(row)
for fn in fns:
- row = fn(row)
- return row
+ interim_row = fn(interim_row)
+ return interim_row
return make_row
@HasMemoized.memoized_attribute
- def _iterator_getter(self):
+ def _iterator_getter(self) -> Callable[..., Iterator[_RowData]]:
make_row = self._row_getter
@@ -344,9 +448,9 @@ class ResultInternal(InPlaceGenerative):
if self._unique_filter_state:
uniques, strategy = self._unique_strategy
- def iterrows(self):
- for row in self._fetchiter_impl():
- obj = make_row(row) if make_row else row
+ def iterrows(self: Result) -> Iterator[_RowData]:
+ for raw_row in self._fetchiter_impl():
+ obj = make_row(raw_row) if make_row else raw_row
hashed = strategy(obj) if strategy else obj
if hashed in uniques:
continue
@@ -357,27 +461,29 @@ class ResultInternal(InPlaceGenerative):
else:
- def iterrows(self):
+ def iterrows(self: Result) -> Iterator[_RowData]:
for row in self._fetchiter_impl():
- row = make_row(row) if make_row else row
+ row = make_row(row) if make_row else row # type: ignore
if post_creational_filter:
row = post_creational_filter(row)
yield row
return iterrows
- def _raw_all_rows(self):
+ def _raw_all_rows(self) -> List[_RowData]:
make_row = self._row_getter
+ assert make_row is not None
rows = self._fetchall_impl()
return [make_row(row) for row in rows]
- def _allrows(self):
+ def _allrows(self) -> List[_RowData]:
post_creational_filter = self._post_creational_filter
make_row = self._row_getter
rows = self._fetchall_impl()
+ made_rows: List[_InterimRowType]
if make_row:
made_rows = [make_row(row) for row in rows]
else:
@@ -386,7 +492,7 @@ class ResultInternal(InPlaceGenerative):
if self._unique_filter_state:
uniques, strategy = self._unique_strategy
- rows = [
+ interim_rows = [
made_row
for made_row, sig_row in [
(
@@ -395,17 +501,19 @@ class ResultInternal(InPlaceGenerative):
)
for made_row in made_rows
]
- if sig_row not in uniques and not uniques.add(sig_row)
+ if sig_row not in uniques and not uniques.add(sig_row) # type: ignore # noqa: E501
]
else:
- rows = made_rows
+ interim_rows = made_rows
if post_creational_filter:
- rows = [post_creational_filter(row) for row in rows]
- return rows
+ interim_rows = [
+ post_creational_filter(row) for row in interim_rows
+ ]
+ return interim_rows
@HasMemoized.memoized_attribute
- def _onerow_getter(self):
+ def _onerow_getter(self) -> Callable[..., Union[_NoRow, _RowData]]:
make_row = self._row_getter
post_creational_filter = self._post_creational_filter
@@ -413,7 +521,7 @@ class ResultInternal(InPlaceGenerative):
if self._unique_filter_state:
uniques, strategy = self._unique_strategy
- def onerow(self):
+ def onerow(self: Result) -> Union[_NoRow, _RowData]:
_onerow = self._fetchone_impl
while True:
row = _onerow()
@@ -432,20 +540,22 @@ class ResultInternal(InPlaceGenerative):
else:
- def onerow(self):
+ def onerow(self: Result) -> Union[_NoRow, _RowData]:
row = self._fetchone_impl()
if row is None:
return _NO_ROW
else:
- row = make_row(row) if make_row else row
+ interim_row: _InterimRowType = (
+ make_row(row) if make_row else row
+ )
if post_creational_filter:
- row = post_creational_filter(row)
- return row
+ interim_row = post_creational_filter(interim_row)
+ return interim_row
return onerow
@HasMemoized.memoized_attribute
- def _manyrow_getter(self):
+ def _manyrow_getter(self) -> Callable[..., List[_RowData]]:
make_row = self._row_getter
post_creational_filter = self._post_creational_filter
@@ -453,7 +563,12 @@ class ResultInternal(InPlaceGenerative):
if self._unique_filter_state:
uniques, strategy = self._unique_strategy
- def filterrows(make_row, rows, strategy, uniques):
+ def filterrows(
+ make_row: Optional[Callable[..., _RowData]],
+ rows: List[Any],
+ strategy: Optional[Callable[[Sequence[Any]], Any]],
+ uniques: Set[Any],
+ ) -> List[Row]:
if make_row:
rows = [make_row(row) for row in rows]
@@ -466,11 +581,11 @@ class ResultInternal(InPlaceGenerative):
return [
made_row
for made_row, sig_row in made_rows
- if sig_row not in uniques and not uniques.add(sig_row)
+ if sig_row not in uniques and not uniques.add(sig_row) # type: ignore # noqa: E501
]
- def manyrows(self, num):
- collect = []
+ def manyrows(self: Result, num: Optional[int]) -> List[_RowData]:
+ collect: List[_RowData] = []
_manyrows = self._fetchmany_impl
@@ -488,6 +603,7 @@ class ResultInternal(InPlaceGenerative):
else:
rows = _manyrows(num)
num = len(rows)
+ assert make_row is not None
collect.extend(
filterrows(make_row, rows, strategy, uniques)
)
@@ -495,6 +611,8 @@ class ResultInternal(InPlaceGenerative):
else:
num_required = num
+ assert num is not None
+
while num_required:
rows = _manyrows(num_required)
if not rows:
@@ -511,14 +629,14 @@ class ResultInternal(InPlaceGenerative):
else:
- def manyrows(self, num):
+ def manyrows(self: Result, num: Optional[int]) -> List[_RowData]:
if num is None:
real_result = (
self._real_result if self._real_result else self
)
num = real_result._yield_per
- rows = self._fetchmany_impl(num)
+ rows: List[_InterimRowType] = self._fetchmany_impl(num)
if make_row:
rows = [make_row(row) for row in rows]
if post_creational_filter:
@@ -529,13 +647,13 @@ class ResultInternal(InPlaceGenerative):
def _only_one_row(
self,
- raise_for_second_row,
- raise_for_none,
- scalar,
- ):
+ raise_for_second_row: bool,
+ raise_for_none: bool,
+ scalar: bool,
+ ) -> Optional[_RowData]:
onerow = self._fetchone_impl
- row = onerow(hard_close=True)
+ row: _InterimRowType = onerow(hard_close=True)
if row is None:
if raise_for_none:
raise exc.NoResultFound(
@@ -565,7 +683,7 @@ class ResultInternal(InPlaceGenerative):
existing_row_hash = strategy(row) if strategy else row
while True:
- next_row = onerow(hard_close=True)
+ next_row: Any = onerow(hard_close=True)
if next_row is None:
next_row = _NO_ROW
break
@@ -574,6 +692,7 @@ class ResultInternal(InPlaceGenerative):
next_row = make_row(next_row) if make_row else next_row
if strategy:
+ assert next_row is not _NO_ROW
if existing_row_hash == strategy(next_row):
continue
elif row == next_row:
@@ -608,14 +727,14 @@ class ResultInternal(InPlaceGenerative):
row = post_creational_filter(row)
if scalar and make_row:
- return row[0]
+ return row[0] # type: ignore
else:
return row
- def _iter_impl(self):
+ def _iter_impl(self) -> Iterator[_RowData]:
return self._iterator_getter(self)
- def _next_impl(self):
+ def _next_impl(self) -> _RowData:
row = self._onerow_getter(self)
if row is _NO_ROW:
raise StopIteration()
@@ -624,11 +743,14 @@ class ResultInternal(InPlaceGenerative):
@_generative
def _column_slices(
- self: SelfResultInternal, indexes
+ self: SelfResultInternal, indexes: Sequence[_KeyIndexType]
) -> SelfResultInternal:
real_result = self._real_result if self._real_result else self
- if real_result._source_supports_scalars and len(indexes) == 1:
+ if (
+ real_result._source_supports_scalars # type: ignore[attr-defined] # noqa: E501
+ and len(indexes) == 1
+ ):
self._generate_rows = False
else:
self._generate_rows = True
@@ -637,7 +759,8 @@ class ResultInternal(InPlaceGenerative):
return self
@HasMemoized.memoized_attribute
- def _unique_strategy(self):
+ def _unique_strategy(self) -> _UniqueFilterStateType:
+ assert self._unique_filter_state is not None
uniques, strategy = self._unique_filter_state
real_result = (
@@ -660,8 +783,10 @@ class ResultInternal(InPlaceGenerative):
class _WithKeys:
+ _metadata: ResultMetaData
+
# used mainly to share documentation on the keys method.
- def keys(self):
+ def keys(self) -> RMKeyView:
"""Return an iterable view which yields the string keys that would
be represented by each :class:`.Row`.
@@ -681,7 +806,7 @@ class _WithKeys:
return self._metadata.keys
-SelfResult = typing.TypeVar("SelfResult", bound="Result")
+SelfResult = TypeVar("SelfResult", bound="Result")
class Result(_WithKeys, ResultInternal):
@@ -709,23 +834,18 @@ class Result(_WithKeys, ResultInternal):
"""
- _process_row = Row
-
- _row_logging_fn = None
+ _row_logging_fn: Optional[Callable[[Row], Row]] = None
- _source_supports_scalars = False
+ _source_supports_scalars: bool = False
- _yield_per = None
+ _yield_per: Optional[int] = None
- _attributes = util.immutabledict()
+ _attributes: util.immutabledict[Any, Any] = util.immutabledict()
- def __init__(self, cursor_metadata):
+ def __init__(self, cursor_metadata: ResultMetaData):
self._metadata = cursor_metadata
- def _soft_close(self, hard=False):
- raise NotImplementedError()
-
- def close(self):
+ def close(self) -> None:
"""close this :class:`_result.Result`.
The behavior of this method is implementation specific, and is
@@ -748,7 +868,7 @@ class Result(_WithKeys, ResultInternal):
self._soft_close(hard=True)
@_generative
- def yield_per(self: SelfResult, num) -> SelfResult:
+ def yield_per(self: SelfResult, num: int) -> SelfResult:
"""Configure the row-fetching strategy to fetch num rows at a time.
This impacts the underlying behavior of the result when iterating over
@@ -785,7 +905,9 @@ class Result(_WithKeys, ResultInternal):
return self
@_generative
- def unique(self: SelfResult, strategy=None) -> SelfResult:
+ def unique(
+ self: SelfResult, strategy: Optional[_UniqueFilterType] = None
+ ) -> SelfResult:
"""Apply unique filtering to the objects returned by this
:class:`_engine.Result`.
@@ -826,7 +948,7 @@ class Result(_WithKeys, ResultInternal):
return self
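[editor's note: for illustration, a hedged sketch of `unique()` against an in-memory `IteratorResult` (the same class annotated later in this diff); with no `strategy`, the rows themselves are hashed:

    from sqlalchemy.engine.result import IteratorResult, SimpleResultMetaData

    result = IteratorResult(SimpleResultMetaData(["x"]),
                            iter([(1,), (1,), (2,)]))
    # Duplicate rows are filtered out as they are fetched.
    print(result.unique().all())  # [(1,), (2,)]
]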
def columns(
- self: SelfResultInternal, *col_expressions
+ self: SelfResultInternal, *col_expressions: _KeyIndexType
) -> SelfResultInternal:
r"""Establish the columns that should be returned in each row.
@@ -865,7 +987,7 @@ class Result(_WithKeys, ResultInternal):
"""
return self._column_slices(col_expressions)
- def scalars(self, index=0) -> "ScalarResult":
+ def scalars(self, index: _KeyIndexType = 0) -> ScalarResult:
"""Return a :class:`_result.ScalarResult` filtering object which
will return single elements rather than :class:`_row.Row` objects.
@@ -890,7 +1012,9 @@ class Result(_WithKeys, ResultInternal):
"""
return ScalarResult(self, index)
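[editor's note: `index` accepts the same key types as `Row` indexing, hence `_KeyIndexType`. A sketch selecting the second column:

    from sqlalchemy.engine.result import IteratorResult, SimpleResultMetaData

    result = IteratorResult(SimpleResultMetaData(["id", "name"]),
                            iter([(1, "ed"), (2, "wendy")]))
    # ScalarResult yields the chosen column value instead of Row objects.
    print(result.scalars(index=1).all())  # ['ed', 'wendy']
]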
- def _getter(self, key, raiseerr=True):
+ def _getter(
+ self, key: _KeyIndexType, raiseerr: bool = True
+ ) -> Optional[Callable[[Sequence[Any]], _RowData]]:
"""return a callable that will retrieve the given key from a
:class:`.Row`.
@@ -901,7 +1025,7 @@ class Result(_WithKeys, ResultInternal):
)
return self._metadata._getter(key, raiseerr)
- def _tuple_getter(self, keys):
+ def _tuple_getter(self, keys: Sequence[_KeyIndexType]) -> _TupleGetterType:
"""return a callable that will retrieve the given keys from a
:class:`.Row`.
@@ -912,7 +1036,7 @@ class Result(_WithKeys, ResultInternal):
)
return self._metadata._row_as_tuple_getter(keys)
- def mappings(self) -> "MappingResult":
+ def mappings(self) -> MappingResult:
"""Apply a mappings filter to returned rows, returning an instance of
:class:`_result.MappingResult`.
@@ -928,7 +1052,7 @@ class Result(_WithKeys, ResultInternal):
return MappingResult(self)
- def _raw_row_iterator(self):
+ def _raw_row_iterator(self) -> Iterator[_RowData]:
"""Return a safe iterator that yields raw row data.
This is used by the :meth:`._engine.Result.merge` method
@@ -937,25 +1061,13 @@ class Result(_WithKeys, ResultInternal):
"""
raise NotImplementedError()
- def _fetchiter_impl(self):
- raise NotImplementedError()
-
- def _fetchone_impl(self, hard_close=False):
- raise NotImplementedError()
-
- def _fetchall_impl(self):
- raise NotImplementedError()
-
- def _fetchmany_impl(self, size=None):
- raise NotImplementedError()
-
- def __iter__(self):
+ def __iter__(self) -> Iterator[_RowData]:
return self._iter_impl()
- def __next__(self):
+ def __next__(self) -> _RowData:
return self._next_impl()
- def partitions(self, size=None):
+ def partitions(self, size: Optional[int] = None) -> Iterator[List[Row]]:
"""Iterate through sub-lists of rows of the size given.
Each list will be of the size given, excluding the last list to
@@ -989,16 +1101,16 @@ class Result(_WithKeys, ResultInternal):
while True:
partition = getter(self, size)
if partition:
- yield partition
+ yield partition # type: ignore
else:
break
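[editor's note: a sketch of the annotated behavior — each partition is a `List[Row]`, and only the final one may be shorter than `size`:

    from sqlalchemy.engine.result import IteratorResult, SimpleResultMetaData

    result = IteratorResult(SimpleResultMetaData(["n"]),
                            iter([(i,) for i in range(5)]))
    for partition in result.partitions(size=2):
        print(partition)
    # [(0,), (1,)]
    # [(2,), (3,)]
    # [(4,)]
]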
- def fetchall(self):
+ def fetchall(self) -> List[Row]:
"""A synonym for the :meth:`_engine.Result.all` method."""
- return self._allrows()
+ return self._allrows() # type: ignore[return-value]
- def fetchone(self):
+ def fetchone(self) -> Optional[Row]:
"""Fetch one row.
When all rows are exhausted, returns None.
@@ -1018,9 +1130,9 @@ class Result(_WithKeys, ResultInternal):
if row is _NO_ROW:
return None
else:
- return row
+ return row # type: ignore[return-value]
- def fetchmany(self, size=None):
+ def fetchmany(self, size: Optional[int] = None) -> List[Row]:
"""Fetch many rows.
When all rows are exhausted, returns an empty list.
@@ -1035,9 +1147,9 @@ class Result(_WithKeys, ResultInternal):
"""
- return self._manyrow_getter(self, size)
+ return self._manyrow_getter(self, size) # type: ignore[return-value]
- def all(self):
+ def all(self) -> List[Row]:
"""Return all rows in a list.
Closes the result set after invocation. Subsequent invocations
@@ -1049,9 +1161,9 @@ class Result(_WithKeys, ResultInternal):
"""
- return self._allrows()
+ return self._allrows() # type: ignore[return-value]
- def first(self):
+ def first(self) -> Optional[Row]:
"""Fetch the first row or None if no row is present.
Closes the result set and discards remaining rows.
@@ -1083,11 +1195,11 @@ class Result(_WithKeys, ResultInternal):
"""
- return self._only_one_row(
+ return self._only_one_row( # type: ignore[return-value]
raise_for_second_row=False, raise_for_none=False, scalar=False
)
- def one_or_none(self):
+ def one_or_none(self) -> Optional[Row]:
"""Return at most one result or raise an exception.
Returns ``None`` if the result has no rows.
@@ -1107,11 +1219,11 @@ class Result(_WithKeys, ResultInternal):
:meth:`_result.Result.one`
"""
- return self._only_one_row(
+ return self._only_one_row( # type: ignore[return-value]
raise_for_second_row=True, raise_for_none=False, scalar=False
)
- def scalar_one(self):
+ def scalar_one(self) -> Any:
"""Return exactly one scalar result or raise an exception.
This is equivalent to calling :meth:`.Result.scalars` and then
@@ -1128,7 +1240,7 @@ class Result(_WithKeys, ResultInternal):
raise_for_second_row=True, raise_for_none=True, scalar=True
)
- def scalar_one_or_none(self):
+ def scalar_one_or_none(self) -> Optional[Any]:
"""Return exactly one or no scalar result.
This is equivalent to calling :meth:`.Result.scalars` and then
@@ -1145,7 +1257,7 @@ class Result(_WithKeys, ResultInternal):
raise_for_second_row=True, raise_for_none=False, scalar=True
)
- def one(self):
+ def one(self) -> Row:
"""Return exactly one row or raise an exception.
Raises :class:`.NoResultFound` if the result returns no
@@ -1172,11 +1284,11 @@ class Result(_WithKeys, ResultInternal):
:meth:`_result.Result.scalar_one`
"""
- return self._only_one_row(
+ return self._only_one_row( # type: ignore[return-value]
raise_for_second_row=True, raise_for_none=True, scalar=False
)
- def scalar(self):
+ def scalar(self) -> Any:
"""Fetch the first column of the first row, and close the result set.
Returns None if there are no rows to fetch.
@@ -1194,7 +1306,7 @@ class Result(_WithKeys, ResultInternal):
raise_for_second_row=False, raise_for_none=False, scalar=True
)
- def freeze(self):
+ def freeze(self) -> FrozenResult:
"""Return a callable object that will produce copies of this
:class:`.Result` when invoked.
@@ -1217,7 +1329,7 @@ class Result(_WithKeys, ResultInternal):
return FrozenResult(self)
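[editor's note: a sketch of the `freeze()` round trip — the `FrozenResult` is the callable, and each call replays the captured rows through a fresh `IteratorResult`:

    from sqlalchemy.engine.result import IteratorResult, SimpleResultMetaData

    result = IteratorResult(SimpleResultMetaData(["x"]), iter([(1,), (2,)]))
    frozen = result.freeze()
    print(frozen().all())  # [(1,), (2,)]
    print(frozen().all())  # [(1,), (2,)] -- replayable any number of times
]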
- def merge(self, *others):
+ def merge(self, *others: Result) -> MergedResult:
"""Merge this :class:`.Result` with other compatible result
objects.
@@ -1240,28 +1352,37 @@ class FilterResult(ResultInternal):
"""
- _post_creational_filter = None
+ _post_creational_filter: Optional[Callable[[Any], Any]] = None
- def _soft_close(self, hard=False):
+ _real_result: Result
+
+ def _soft_close(self, hard: bool = False) -> None:
self._real_result._soft_close(hard=hard)
@property
- def _attributes(self):
+ def _attributes(self) -> Dict[Any, Any]:
return self._real_result._attributes
- def _fetchiter_impl(self):
+ def _fetchiter_impl(self) -> Iterator[_InterimRowType]:
return self._real_result._fetchiter_impl()
- def _fetchone_impl(self, hard_close=False):
+ def _fetchone_impl(
+ self, hard_close: bool = False
+ ) -> Optional[_InterimRowType]:
return self._real_result._fetchone_impl(hard_close=hard_close)
- def _fetchall_impl(self):
+ def _fetchall_impl(self) -> List[_InterimRowType]:
return self._real_result._fetchall_impl()
- def _fetchmany_impl(self, size=None):
+ def _fetchmany_impl(
+ self, size: Optional[int] = None
+ ) -> List[_InterimRowType]:
return self._real_result._fetchmany_impl(size=size)
+SelfScalarResult = TypeVar("SelfScalarResult", bound="ScalarResult")
+
+
class ScalarResult(FilterResult):
"""A wrapper for a :class:`_result.Result` that returns scalar values
rather than :class:`_row.Row` values.
@@ -1280,7 +1401,9 @@ class ScalarResult(FilterResult):
_generate_rows = False
- def __init__(self, real_result, index):
+ _post_creational_filter: Optional[Callable[[Any], Any]]
+
+ def __init__(self, real_result: Result, index: _KeyIndexType):
self._real_result = real_result
if real_result._source_supports_scalars:
@@ -1292,7 +1415,9 @@ class ScalarResult(FilterResult):
self._unique_filter_state = real_result._unique_filter_state
- def unique(self, strategy=None):
+ def unique(
+ self: SelfScalarResult, strategy: Optional[_UniqueFilterType] = None
+ ) -> SelfScalarResult:
"""Apply unique filtering to the objects returned by this
:class:`_engine.ScalarResult`.
@@ -1302,7 +1427,7 @@ class ScalarResult(FilterResult):
self._unique_filter_state = (set(), strategy)
return self
- def partitions(self, size=None):
+ def partitions(self, size: Optional[int] = None) -> Iterator[List[Any]]:
"""Iterate through sub-lists of elements of the size given.
Equivalent to :meth:`_result.Result.partitions` except that
@@ -1320,12 +1445,12 @@ class ScalarResult(FilterResult):
else:
break
- def fetchall(self):
+ def fetchall(self) -> List[Any]:
"""A synonym for the :meth:`_engine.ScalarResult.all` method."""
return self._allrows()
- def fetchmany(self, size=None):
+ def fetchmany(self, size: Optional[int] = None) -> List[Any]:
"""Fetch many objects.
Equivalent to :meth:`_result.Result.fetchmany` except that
@@ -1335,7 +1460,7 @@ class ScalarResult(FilterResult):
"""
return self._manyrow_getter(self, size)
- def all(self):
+ def all(self) -> List[Any]:
"""Return all scalar values in a list.
Equivalent to :meth:`_result.Result.all` except that
@@ -1345,13 +1470,13 @@ class ScalarResult(FilterResult):
"""
return self._allrows()
- def __iter__(self):
+ def __iter__(self) -> Iterator[Any]:
return self._iter_impl()
- def __next__(self):
+ def __next__(self) -> Any:
return self._next_impl()
- def first(self):
+ def first(self) -> Optional[Any]:
"""Fetch the first object or None if no object is present.
Equivalent to :meth:`_result.Result.first` except that
@@ -1364,7 +1489,7 @@ class ScalarResult(FilterResult):
raise_for_second_row=False, raise_for_none=False, scalar=False
)
- def one_or_none(self):
+ def one_or_none(self) -> Optional[Any]:
"""Return at most one object or raise an exception.
Equivalent to :meth:`_result.Result.one_or_none` except that
@@ -1376,7 +1501,7 @@ class ScalarResult(FilterResult):
raise_for_second_row=True, raise_for_none=False, scalar=False
)
- def one(self):
+ def one(self) -> Any:
"""Return exactly one object or raise an exception.
Equivalent to :meth:`_result.Result.one` except that
@@ -1389,6 +1514,9 @@ class ScalarResult(FilterResult):
)
+SelfMappingResult = TypeVar("SelfMappingResult", bound="MappingResult")
+
+
class MappingResult(_WithKeys, FilterResult):
"""A wrapper for a :class:`_engine.Result` that returns dictionary values
rather than :class:`_engine.Row` values.
@@ -1402,14 +1530,16 @@ class MappingResult(_WithKeys, FilterResult):
_post_creational_filter = operator.attrgetter("_mapping")
- def __init__(self, result):
+ def __init__(self, result: Result):
self._real_result = result
self._unique_filter_state = result._unique_filter_state
self._metadata = result._metadata
if result._source_supports_scalars:
self._metadata = self._metadata._reduce([0])
- def unique(self, strategy=None):
+ def unique(
+ self: SelfMappingResult, strategy: Optional[_UniqueFilterType] = None
+ ) -> SelfMappingResult:
"""Apply unique filtering to the objects returned by this
:class:`_engine.MappingResult`.
@@ -1419,11 +1549,15 @@ class MappingResult(_WithKeys, FilterResult):
self._unique_filter_state = (set(), strategy)
return self
- def columns(self, *col_expressions):
+ def columns(
+ self: SelfMappingResult, *col_expressions: _KeyIndexType
+ ) -> SelfMappingResult:
r"""Establish the columns that should be returned in each row."""
return self._column_slices(col_expressions)
- def partitions(self, size=None):
+ def partitions(
+ self, size: Optional[int] = None
+ ) -> Iterator[List[RowMapping]]:
"""Iterate through sub-lists of elements of the size given.
Equivalent to :meth:`_result.Result.partitions` except that
@@ -1437,16 +1571,16 @@ class MappingResult(_WithKeys, FilterResult):
while True:
partition = getter(self, size)
if partition:
- yield partition
+ yield partition # type: ignore
else:
break
- def fetchall(self):
+ def fetchall(self) -> List[RowMapping]:
"""A synonym for the :meth:`_engine.MappingResult.all` method."""
- return self._allrows()
+ return self._allrows() # type: ignore[return-value]
- def fetchone(self):
+ def fetchone(self) -> Optional[RowMapping]:
"""Fetch one object.
Equivalent to :meth:`_result.Result.fetchone` except that
@@ -1459,9 +1593,9 @@ class MappingResult(_WithKeys, FilterResult):
if row is _NO_ROW:
return None
else:
- return row
+ return row # type: ignore[return-value]
- def fetchmany(self, size=None):
+ def fetchmany(self, size: Optional[int] = None) -> List[RowMapping]:
"""Fetch many objects.
Equivalent to :meth:`_result.Result.fetchmany` except that
@@ -1470,9 +1604,9 @@ class MappingResult(_WithKeys, FilterResult):
"""
- return self._manyrow_getter(self, size)
+ return self._manyrow_getter(self, size) # type: ignore[return-value]
- def all(self):
+ def all(self) -> List[RowMapping]:
"""Return all scalar values in a list.
Equivalent to :meth:`_result.Result.all` except that
@@ -1481,15 +1615,15 @@ class MappingResult(_WithKeys, FilterResult):
"""
- return self._allrows()
+ return self._allrows() # type: ignore[return-value]
- def __iter__(self):
- return self._iter_impl()
+ def __iter__(self) -> Iterator[RowMapping]:
+ return self._iter_impl() # type: ignore[return-value]
- def __next__(self):
- return self._next_impl()
+ def __next__(self) -> RowMapping:
+ return self._next_impl() # type: ignore[return-value]
- def first(self):
+ def first(self) -> Optional[RowMapping]:
"""Fetch the first object or None if no object is present.
Equivalent to :meth:`_result.Result.first` except that
@@ -1498,11 +1632,11 @@ class MappingResult(_WithKeys, FilterResult):
"""
- return self._only_one_row(
+ return self._only_one_row( # type: ignore[return-value]
raise_for_second_row=False, raise_for_none=False, scalar=False
)
- def one_or_none(self):
+ def one_or_none(self) -> Optional[RowMapping]:
"""Return at most one object or raise an exception.
Equivalent to :meth:`_result.Result.one_or_none` except that
@@ -1510,11 +1644,11 @@ class MappingResult(_WithKeys, FilterResult):
are returned.
"""
- return self._only_one_row(
+ return self._only_one_row( # type: ignore[return-value]
raise_for_second_row=True, raise_for_none=False, scalar=False
)
- def one(self):
+ def one(self) -> RowMapping:
"""Return exactly one object or raise an exception.
Equivalent to :meth:`_result.Result.one` except that
@@ -1522,7 +1656,7 @@ class MappingResult(_WithKeys, FilterResult):
are returned.
"""
- return self._only_one_row(
+ return self._only_one_row( # type: ignore[return-value]
raise_for_second_row=True, raise_for_none=True, scalar=False
)
@@ -1566,7 +1700,9 @@ class FrozenResult:
"""
- def __init__(self, result):
+ data: Sequence[Any]
+
+ def __init__(self, result: Result):
self.metadata = result._metadata._for_freeze()
self._source_supports_scalars = result._source_supports_scalars
self._attributes = result._attributes
@@ -1576,13 +1712,13 @@ class FrozenResult:
else:
self.data = result.fetchall()
- def rewrite_rows(self):
+ def rewrite_rows(self) -> List[List[Any]]:
if self._source_supports_scalars:
return [[elem] for elem in self.data]
else:
return [list(row) for row in self.data]
- def with_new_rows(self, tuple_data):
+ def with_new_rows(self, tuple_data: Sequence[Row]) -> FrozenResult:
fr = FrozenResult.__new__(FrozenResult)
fr.metadata = self.metadata
fr._attributes = self._attributes
@@ -1594,7 +1730,7 @@ class FrozenResult:
fr.data = tuple_data
return fr
- def __call__(self):
+ def __call__(self) -> Result:
result = IteratorResult(self.metadata, iter(self.data))
result._attributes = self._attributes
result._source_supports_scalars = self._source_supports_scalars
@@ -1603,7 +1739,7 @@ class FrozenResult:
class IteratorResult(Result):
"""A :class:`.Result` that gets data from a Python iterator of
- :class:`.Row` objects.
+ :class:`.Row` objects or similar row-like data.
.. versionadded:: 1.4
@@ -1613,17 +1749,17 @@ class IteratorResult(Result):
def __init__(
self,
- cursor_metadata,
- iterator,
- raw=None,
- _source_supports_scalars=False,
+ cursor_metadata: ResultMetaData,
+ iterator: Iterator[_RowData],
+ raw: Optional[Any] = None,
+ _source_supports_scalars: bool = False,
):
self._metadata = cursor_metadata
self.iterator = iterator
self.raw = raw
self._source_supports_scalars = _source_supports_scalars
- def _soft_close(self, hard=False, **kw):
+ def _soft_close(self, hard: bool = False, **kw: Any) -> None:
if hard:
self._hard_closed = True
if self.raw is not None:
@@ -1631,18 +1767,18 @@ class IteratorResult(Result):
self.iterator = iter([])
self._reset_memoizations()
- def _raise_hard_closed(self):
+ def _raise_hard_closed(self) -> NoReturn:
raise exc.ResourceClosedError("This result object is closed.")
- def _raw_row_iterator(self):
+ def _raw_row_iterator(self) -> Iterator[_RowData]:
return self.iterator
- def _fetchiter_impl(self):
+ def _fetchiter_impl(self) -> Iterator[_RowData]:
if self._hard_closed:
self._raise_hard_closed()
return self.iterator
- def _fetchone_impl(self, hard_close=False):
+ def _fetchone_impl(self, hard_close: bool = False) -> Optional[_RowData]:
if self._hard_closed:
self._raise_hard_closed()
@@ -1653,27 +1789,26 @@ class IteratorResult(Result):
else:
return row
- def _fetchall_impl(self):
+ def _fetchall_impl(self) -> List[_RowData]:
if self._hard_closed:
self._raise_hard_closed()
-
try:
return list(self.iterator)
finally:
self._soft_close()
- def _fetchmany_impl(self, size=None):
+ def _fetchmany_impl(self, size: Optional[int] = None) -> List[_RowData]:
if self._hard_closed:
self._raise_hard_closed()
return list(itertools.islice(self.iterator, 0, size))
-def null_result():
+def null_result() -> IteratorResult:
return IteratorResult(SimpleResultMetaData([]), iter([]))
-SelfChunkedIteratorResult = typing.TypeVar(
+SelfChunkedIteratorResult = TypeVar(
"SelfChunkedIteratorResult", bound="ChunkedIteratorResult"
)
@@ -1695,11 +1830,11 @@ class ChunkedIteratorResult(IteratorResult):
def __init__(
self,
- cursor_metadata,
- chunks,
- source_supports_scalars=False,
- raw=None,
- dynamic_yield_per=False,
+ cursor_metadata: ResultMetaData,
+ chunks: Callable[[Optional[int]], Iterator[List[_InterimRowType]]],
+ source_supports_scalars: bool = False,
+ raw: Optional[Any] = None,
+ dynamic_yield_per: bool = False,
):
self._metadata = cursor_metadata
self.chunks = chunks
@@ -1710,7 +1845,7 @@ class ChunkedIteratorResult(IteratorResult):
@_generative
def yield_per(
- self: SelfChunkedIteratorResult, num
+ self: SelfChunkedIteratorResult, num: int
) -> SelfChunkedIteratorResult:
# TODO: this throws away the iterator which may be holding
# onto a chunk. the yield_per cannot be changed once any
@@ -1722,11 +1857,13 @@ class ChunkedIteratorResult(IteratorResult):
self.iterator = itertools.chain.from_iterable(self.chunks(num))
return self
- def _soft_close(self, **kw):
- super(ChunkedIteratorResult, self)._soft_close(**kw)
- self.chunks = lambda size: []
+ def _soft_close(self, hard: bool = False, **kw: Any) -> None:
+ super(ChunkedIteratorResult, self)._soft_close(hard=hard, **kw)
+ self.chunks = lambda size: [] # type: ignore
- def _fetchmany_impl(self, size=None):
+ def _fetchmany_impl(
+ self, size: Optional[int] = None
+ ) -> List[_InterimRowType]:
if self.dynamic_yield_per:
self.iterator = itertools.chain.from_iterable(self.chunks(size))
return super(ChunkedIteratorResult, self)._fetchmany_impl(size=size)
@@ -1744,7 +1881,9 @@ class MergedResult(IteratorResult):
closed = False
- def __init__(self, cursor_metadata, results):
+ def __init__(
+ self, cursor_metadata: ResultMetaData, results: Sequence[Result]
+ ):
self._results = results
super(MergedResult, self).__init__(
cursor_metadata,
@@ -1763,7 +1902,7 @@ class MergedResult(IteratorResult):
*[r._attributes for r in results]
)
- def _soft_close(self, hard=False, **kw):
+ def _soft_close(self, hard: bool = False, **kw: Any) -> None:
for r in self._results:
r._soft_close(hard=hard, **kw)
if hard:
diff --git a/lib/sqlalchemy/engine/row.py b/lib/sqlalchemy/engine/row.py
index 29b2f338b..ff63199d4 100644
--- a/lib/sqlalchemy/engine/row.py
+++ b/lib/sqlalchemy/engine/row.py
@@ -9,24 +9,41 @@
from __future__ import annotations
+from abc import ABC
import collections.abc as collections_abc
import operator
import typing
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Iterator
+from typing import List
+from typing import Mapping
+from typing import NoReturn
+from typing import Optional
+from typing import overload
+from typing import Sequence
+from typing import Tuple
+from typing import Union
from ..sql import util as sql_util
from ..util._has_cy import HAS_CYEXTENSION
if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
- from ._py_row import BaseRow
+ from ._py_row import BaseRow as BaseRow
from ._py_row import KEY_INTEGER_ONLY
from ._py_row import KEY_OBJECTS_ONLY
else:
- from sqlalchemy.cyextension.resultproxy import BaseRow
+ from sqlalchemy.cyextension.resultproxy import BaseRow as BaseRow
from sqlalchemy.cyextension.resultproxy import KEY_INTEGER_ONLY
from sqlalchemy.cyextension.resultproxy import KEY_OBJECTS_ONLY
+if typing.TYPE_CHECKING:
+ from .result import _KeyType
+ from .result import RMKeyView
-class Row(BaseRow, collections_abc.Sequence):
+
+class Row(BaseRow, typing.Sequence[Any]):
"""Represent a single result row.
The :class:`.Row` object represents a row of a database result. It is
@@ -58,14 +75,14 @@ class Row(BaseRow, collections_abc.Sequence):
_default_key_style = KEY_INTEGER_ONLY
- def __setattr__(self, name, value):
+ def __setattr__(self, name: str, value: Any) -> NoReturn:
raise AttributeError("can't set attribute")
- def __delattr__(self, name):
+ def __delattr__(self, name: str) -> NoReturn:
raise AttributeError("can't delete attribute")
@property
- def _mapping(self):
+ def _mapping(self) -> RowMapping:
"""Return a :class:`.RowMapping` for this :class:`.Row`.
This object provides a consistent Python mapping (i.e. dictionary)
@@ -87,31 +104,44 @@ class Row(BaseRow, collections_abc.Sequence):
self._data,
)
- def _special_name_accessor(name):
- """Handle ambiguous names such as "count" and "index" """
+ def _filter_on_values(
+ self, filters: Optional[Sequence[Optional[Callable[[Any], Any]]]]
+ ) -> Row:
+ return Row(
+ self._parent,
+ filters,
+ self._keymap,
+ self._key_style,
+ self._data,
+ )
+
+ if not typing.TYPE_CHECKING:
+
+ def _special_name_accessor(name: str) -> Any:
+ """Handle ambiguous names such as "count" and "index" """
- @property
- def go(self):
- if self._parent._has_key(name):
- return self.__getattr__(name)
- else:
+ @property
+ def go(self: Row) -> Any:
+ if self._parent._has_key(name):
+ return self.__getattr__(name)
+ else:
- def meth(*arg, **kw):
- return getattr(collections_abc.Sequence, name)(
- self, *arg, **kw
- )
+ def meth(*arg: Any, **kw: Any) -> Any:
+ return getattr(collections_abc.Sequence, name)(
+ self, *arg, **kw
+ )
- return meth
+ return meth
- return go
+ return go
- count = _special_name_accessor("count")
- index = _special_name_accessor("index")
+ count = _special_name_accessor("count")
+ index = _special_name_accessor("index")
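[editor's note: the observable behavior preserved here — and now hidden from the type checker, which instead sees the inherited `Sequence` methods — is that a result column literally named "count" or "index" wins over the `Sequence` method of the same name. A sketch:

    from sqlalchemy.engine.result import result_tuple

    row = result_tuple(["id", "count"])((7, 42))
    print(row.count)      # 42 -- the column value

    row2 = result_tuple(["id", "total"])((7, 7))
    print(row2.count(7))  # 2 -- falls back to Sequence.count
]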
- def __contains__(self, key):
+ def __contains__(self, key: Any) -> bool:
return key in self._data
- def _op(self, other, op):
+ def _op(self, other: Any, op: Callable[[Any, Any], bool]) -> bool:
return (
op(tuple(self), tuple(other))
if isinstance(other, Row)
@@ -120,29 +150,44 @@ class Row(BaseRow, collections_abc.Sequence):
__hash__ = BaseRow.__hash__
- def __lt__(self, other):
+ if typing.TYPE_CHECKING:
+
+ @overload
+ def __getitem__(self, index: int) -> Any:
+ ...
+
+ @overload
+ def __getitem__(self, index: slice) -> Sequence[Any]:
+ ...
+
+ def __getitem__(
+ self, index: Union[int, slice]
+ ) -> Union[Any, Sequence[Any]]:
+ ...
+
+ def __lt__(self, other: Any) -> bool:
return self._op(other, operator.lt)
- def __le__(self, other):
+ def __le__(self, other: Any) -> bool:
return self._op(other, operator.le)
- def __ge__(self, other):
+ def __ge__(self, other: Any) -> bool:
return self._op(other, operator.ge)
- def __gt__(self, other):
+ def __gt__(self, other: Any) -> bool:
return self._op(other, operator.gt)
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return self._op(other, operator.eq)
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> bool:
return self._op(other, operator.ne)
- def __repr__(self):
+ def __repr__(self) -> str:
return repr(sql_util._repr_row(self))
@property
- def _fields(self):
+ def _fields(self) -> Tuple[str, ...]:
"""Return a tuple of string keys as represented by this
:class:`.Row`.
@@ -162,7 +207,7 @@ class Row(BaseRow, collections_abc.Sequence):
"""
return tuple([k for k in self._parent.keys if k is not None])
- def _asdict(self):
+ def _asdict(self) -> Dict[str, Any]:
"""Return a new dict which maps field names to their corresponding
values.
@@ -179,49 +224,51 @@ class Row(BaseRow, collections_abc.Sequence):
"""
return dict(self._mapping)
- def _replace(self):
- raise NotImplementedError()
-
- @property
- def _field_defaults(self):
- raise NotImplementedError()
-
BaseRowProxy = BaseRow
RowProxy = Row
-class ROMappingView(
- collections_abc.KeysView,
- collections_abc.ValuesView,
- collections_abc.ItemsView,
-):
- __slots__ = ("_items",)
+class ROMappingView(ABC):
+ __slots__ = ()
+
+ _items: Sequence[Any]
+ _mapping: Mapping[str, Any]
- def __init__(self, mapping, items):
+ def __init__(self, mapping: Mapping[str, Any], items: Sequence[Any]):
self._mapping = mapping
self._items = items
- def __len__(self):
+ def __len__(self) -> int:
return len(self._items)
- def __repr__(self):
+ def __repr__(self) -> str:
return "{0.__class__.__name__}({0._mapping!r})".format(self)
- def __iter__(self):
+ def __iter__(self) -> Iterator[Any]:
return iter(self._items)
- def __contains__(self, item):
+ def __contains__(self, item: Any) -> bool:
return item in self._items
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return list(other) == list(self)
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> bool:
return list(other) != list(self)
-class RowMapping(BaseRow, collections_abc.Mapping):
+class ROMappingKeysValuesView(
+ ROMappingView, typing.KeysView[str], typing.ValuesView[Any]
+):
+ __slots__ = ("_items",)
+
+
+class ROMappingItemsView(ROMappingView, typing.ItemsView[str, Any]):
+ __slots__ = ("_items",)
+
+
+class RowMapping(BaseRow, typing.Mapping[str, Any]):
"""A ``Mapping`` that maps column names and objects to :class:`.Row` values.
The :class:`.RowMapping` is available from a :class:`.Row` via the
@@ -251,31 +298,39 @@ class RowMapping(BaseRow, collections_abc.Mapping):
_default_key_style = KEY_OBJECTS_ONLY
- __getitem__ = BaseRow._get_by_key_impl_mapping
+ if typing.TYPE_CHECKING:
- def _values_impl(self):
+ def __getitem__(self, key: _KeyType) -> Any:
+ ...
+
+ else:
+ __getitem__ = BaseRow._get_by_key_impl_mapping
+
+ def _values_impl(self) -> List[Any]:
return list(self._data)
- def __iter__(self):
+ def __iter__(self) -> Iterator[str]:
return (k for k in self._parent.keys if k is not None)
- def __len__(self):
+ def __len__(self) -> int:
return len(self._data)
- def __contains__(self, key):
+ def __contains__(self, key: object) -> bool:
return self._parent._has_key(key)
- def __repr__(self):
+ def __repr__(self) -> str:
return repr(dict(self))
- def items(self):
+ def items(self) -> ROMappingItemsView:
"""Return a view of key/value tuples for the elements in the
underlying :class:`.Row`.
"""
- return ROMappingView(self, [(key, self[key]) for key in self.keys()])
+ return ROMappingItemsView(
+ self, [(key, self[key]) for key in self.keys()]
+ )
- def keys(self):
+ def keys(self) -> RMKeyView:
"""Return a view of 'keys' for string column names represented
by the underlying :class:`.Row`.
@@ -283,9 +338,9 @@ class RowMapping(BaseRow, collections_abc.Mapping):
return self._parent.keys
- def values(self):
+ def values(self) -> ROMappingKeysValuesView:
"""Return a view of values for the values represented in the
underlying :class:`.Row`.
"""
- return ROMappingView(self, self._values_impl())
+ return ROMappingKeysValuesView(self, self._values_impl())
diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py
index a55233397..306989e0b 100644
--- a/lib/sqlalchemy/engine/url.py
+++ b/lib/sqlalchemy/engine/url.py
@@ -18,10 +18,18 @@ from __future__ import annotations
import collections.abc as collections_abc
import re
+from typing import Any
+from typing import cast
from typing import Dict
+from typing import Iterable
+from typing import List
+from typing import Mapping
from typing import NamedTuple
from typing import Optional
+from typing import overload
+from typing import Sequence
from typing import Tuple
+from typing import Type
from typing import Union
from urllib.parse import parse_qsl
from urllib.parse import quote_plus
@@ -86,19 +94,19 @@ class URL(NamedTuple):
host: Optional[str]
port: Optional[int]
database: Optional[str]
- query: Dict[str, Union[str, Tuple[str]]]
+ query: util.immutabledict[str, Union[Tuple[str, ...], str]]
@classmethod
def create(
cls,
- drivername,
- username=None,
- password=None,
- host=None,
- port=None,
- database=None,
- query=util.EMPTY_DICT,
- ):
+ drivername: str,
+ username: Optional[str] = None,
+ password: Optional[str] = None,
+ host: Optional[str] = None,
+ port: Optional[int] = None,
+ database: Optional[str] = None,
+ query: Mapping[str, Union[Sequence[str], str]] = util.EMPTY_DICT,
+ ) -> URL:
"""Create a new :class:`_engine.URL` object.
:param drivername: the name of the database backend. This name will
@@ -146,7 +154,7 @@ class URL(NamedTuple):
)
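[editor's note: the `query` field is now typed as an `immutabledict` with tuple values, matching what `_str_dict()` produces. A sketch, assuming the multi-value normalization shown below in `_str_dict`:

    from sqlalchemy.engine.url import URL

    url = URL.create(
        "postgresql+psycopg2",
        username="scott",
        password="tiger",
        host="localhost",
        database="test",
        query={"host": ["h1", "h2"]},  # sequence values are stored as tuples
    )
    print(url.query["host"])                  # ('h1', 'h2')
    print(url.set(database="prod").database)  # prod -- set() returns a copy
]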
@classmethod
- def _assert_port(cls, port):
+ def _assert_port(cls, port: Optional[int]) -> Optional[int]:
if port is None:
return None
try:
@@ -155,24 +163,48 @@ class URL(NamedTuple):
raise TypeError("Port argument must be an integer or None")
@classmethod
- def _assert_str(cls, v, paramname):
+ def _assert_str(cls, v: str, paramname: str) -> str:
if not isinstance(v, str):
raise TypeError("%s must be a string" % paramname)
return v
@classmethod
- def _assert_none_str(cls, v, paramname):
+ def _assert_none_str(
+ cls, v: Optional[str], paramname: str
+ ) -> Optional[str]:
if v is None:
return v
return cls._assert_str(v, paramname)
@classmethod
- def _str_dict(cls, dict_):
+ def _str_dict(
+ cls,
+ dict_: Optional[
+ Union[
+ Sequence[Tuple[str, Union[Sequence[str], str]]],
+ Mapping[str, Union[Sequence[str], str]],
+ ]
+ ],
+ ) -> util.immutabledict[str, Union[Tuple[str, ...], str]]:
if dict_ is None:
return util.EMPTY_DICT
- def _assert_value(val):
+ @overload
+ def _assert_value(
+ val: str,
+ ) -> str:
+ ...
+
+ @overload
+ def _assert_value(
+ val: Sequence[str],
+ ) -> Union[str, Tuple[str, ...]]:
+ ...
+
+ def _assert_value(
+ val: Union[str, Sequence[str]],
+ ) -> Union[str, Tuple[str, ...]]:
if isinstance(val, str):
return val
elif isinstance(val, collections_abc.Sequence):
@@ -183,11 +215,12 @@ class URL(NamedTuple):
"sequences of strings"
)
- def _assert_str(v):
+ def _assert_str(v: str) -> str:
if not isinstance(v, str):
raise TypeError("Query dictionary keys must be strings")
return v
+ dict_items: Iterable[Tuple[str, Union[Sequence[str], str]]]
if isinstance(dict_, collections_abc.Sequence):
dict_items = dict_
else:
@@ -204,14 +237,14 @@ class URL(NamedTuple):
def set(
self,
- drivername=None,
- username=None,
- password=None,
- host=None,
- port=None,
- database=None,
- query=None,
- ):
+ drivername: Optional[str] = None,
+ username: Optional[str] = None,
+ password: Optional[str] = None,
+ host: Optional[str] = None,
+ port: Optional[int] = None,
+ database: Optional[str] = None,
+ query: Optional[Mapping[str, Union[Sequence[str], str]]] = None,
+ ) -> URL:
"""return a new :class:`_engine.URL` object with modifications.
Values are used if they are non-None. To set a value to ``None``
@@ -237,7 +270,7 @@ class URL(NamedTuple):
"""
- kw = {}
+ kw: Dict[str, Any] = {}
if drivername is not None:
kw["drivername"] = drivername
if username is not None:
@@ -255,7 +288,7 @@ class URL(NamedTuple):
return self._assert_replace(**kw)
- def _assert_replace(self, **kw):
+ def _assert_replace(self, **kw: Any) -> URL:
"""argument checks before calling _replace()"""
if "drivername" in kw:
@@ -270,7 +303,9 @@ class URL(NamedTuple):
return self._replace(**kw)
- def update_query_string(self, query_string, append=False):
+ def update_query_string(
+ self, query_string: str, append: bool = False
+ ) -> URL:
"""Return a new :class:`_engine.URL` object with the :attr:`_engine.URL.query`
parameter dictionary updated by the given query string.
@@ -301,7 +336,11 @@ class URL(NamedTuple):
""" # noqa: E501
return self.update_query_pairs(parse_qsl(query_string), append=append)
- def update_query_pairs(self, key_value_pairs, append=False):
+ def update_query_pairs(
+ self,
+ key_value_pairs: Iterable[Tuple[str, Union[str, List[str]]]],
+ append: bool = False,
+ ) -> URL:
"""Return a new :class:`_engine.URL` object with the
:attr:`_engine.URL.query`
parameter dictionary updated by the given sequence of key/value pairs
@@ -335,23 +374,27 @@ class URL(NamedTuple):
""" # noqa: E501
existing_query = self.query
- new_keys = {}
+ new_keys: Dict[str, Union[str, List[str]]] = {}
for key, value in key_value_pairs:
if key in new_keys:
new_keys[key] = util.to_list(new_keys[key])
- new_keys[key].append(value)
+ cast("List[str]", new_keys[key]).append(cast(str, value))
else:
- new_keys[key] = value
+ new_keys[key] = (
+ list(value) if isinstance(value, (list, tuple)) else value
+ )
+ new_query: Mapping[str, Union[str, Sequence[str]]]
if append:
new_query = {}
for k in new_keys:
if k in existing_query:
- new_query[k] = util.to_list(
- existing_query[k]
- ) + util.to_list(new_keys[k])
+ new_query[k] = tuple(
+ util.to_list(existing_query[k])
+ + util.to_list(new_keys[k])
+ )
else:
new_query[k] = new_keys[k]
@@ -362,10 +405,19 @@ class URL(NamedTuple):
}
)
else:
- new_query = self.query.union(new_keys)
+ new_query = self.query.union(
+ {
+ k: tuple(v) if isinstance(v, list) else v
+ for k, v in new_keys.items()
+ }
+ )
return self.set(query=new_query)
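[editor's note: the change in the `append` branch converts merged values to tuples so the resulting `query` mapping stays hashable and matches the new field annotation. A sketch:

    from sqlalchemy.engine import make_url

    url = make_url("postgresql://localhost/test?host=h1")
    url2 = url.update_query_pairs([("host", "h2")], append=True)
    print(url2.query["host"])  # ('h1', 'h2')
]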
- def update_query_dict(self, query_parameters, append=False):
+ def update_query_dict(
+ self,
+ query_parameters: Mapping[str, Union[str, List[str]]],
+ append: bool = False,
+ ) -> URL:
"""Return a new :class:`_engine.URL` object with the
:attr:`_engine.URL.query` parameter dictionary updated by the given
dictionary.
@@ -410,7 +462,7 @@ class URL(NamedTuple):
""" # noqa: E501
return self.update_query_pairs(query_parameters.items(), append=append)
- def difference_update_query(self, names):
+ def difference_update_query(self, names: Iterable[str]) -> URL:
"""
Remove the given names from the :attr:`_engine.URL.query` dictionary,
returning the new :class:`_engine.URL`.
@@ -459,7 +511,7 @@ class URL(NamedTuple):
)
@util.memoized_property
- def normalized_query(self):
+ def normalized_query(self) -> Mapping[str, Sequence[str]]:
"""Return the :attr:`_engine.URL.query` dictionary with values normalized
into sequences.
@@ -494,7 +546,7 @@ class URL(NamedTuple):
"be removed in a future release. Please use the "
":meth:`_engine.URL.render_as_string` method.",
)
- def __to_string__(self, hide_password=True):
+ def __to_string__(self, hide_password: bool = True) -> str:
"""Render this :class:`_engine.URL` object as a string.
:param hide_password: Defaults to True. The password is not shown
@@ -503,7 +555,7 @@ class URL(NamedTuple):
"""
return self.render_as_string(hide_password=hide_password)
- def render_as_string(self, hide_password=True):
+ def render_as_string(self, hide_password: bool = True) -> str:
"""Render this :class:`_engine.URL` object as a string.
This method is used when the ``__str__()`` or ``__repr__()``
@@ -542,13 +594,13 @@ class URL(NamedTuple):
)
return s
- def __str__(self):
+ def __str__(self) -> str:
return self.render_as_string(hide_password=False)
- def __repr__(self):
+ def __repr__(self) -> str:
return self.render_as_string()
- def __copy__(self):
+ def __copy__(self) -> URL:
return self.__class__.create(
self.drivername,
self.username,
@@ -561,13 +613,13 @@ class URL(NamedTuple):
self.query,
)
- def __deepcopy__(self, memo):
+ def __deepcopy__(self, memo: Any) -> URL:
return self.__copy__()
- def __hash__(self):
+ def __hash__(self) -> int:
return hash(str(self))
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return (
isinstance(other, URL)
and self.drivername == other.drivername
@@ -579,10 +631,10 @@ class URL(NamedTuple):
and self.port == other.port
)
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> bool:
return not self == other
- def get_backend_name(self):
+ def get_backend_name(self) -> str:
"""Return the backend name.
This is the name that corresponds to the database backend in
@@ -595,7 +647,7 @@ class URL(NamedTuple):
else:
return self.drivername.split("+")[0]
- def get_driver_name(self):
+ def get_driver_name(self) -> str:
"""Return the backend name.
This is the name that corresponds to the DBAPI driver in
@@ -613,7 +665,9 @@ class URL(NamedTuple):
else:
return self.drivername.split("+")[1]
- def _instantiate_plugins(self, kwargs):
+ def _instantiate_plugins(
+ self, kwargs: Mapping[str, Any]
+ ) -> Tuple[URL, List[Any], Dict[str, Any]]:
plugin_names = util.to_list(self.query.get("plugin", ()))
plugin_names += kwargs.get("plugins", [])
@@ -635,7 +689,7 @@ class URL(NamedTuple):
return u, loaded_plugins, kwargs
- def _get_entrypoint(self):
+ def _get_entrypoint(self) -> Type[Dialect]:
"""Return the "entry point" dialect class.
This is normally the dialect itself except in the case when the
@@ -657,9 +711,9 @@ class URL(NamedTuple):
):
return cls.dialect
else:
- return cls
+ return cast("Type[Dialect]", cls)
- def get_dialect(self, _is_async=False):
+ def get_dialect(self, _is_async: bool = False) -> Type[Dialect]:
"""Return the SQLAlchemy :class:`_engine.Dialect` class corresponding
to this URL's driver name.
@@ -671,7 +725,9 @@ class URL(NamedTuple):
dialect_cls = entrypoint.get_dialect_cls(self)
return dialect_cls
- def translate_connect_args(self, names=None, **kw):
+ def translate_connect_args(
+ self, names: Optional[List[str]] = None, **kw: Any
+ ) -> Dict[str, Any]:
r"""Translate url attributes into a dictionary of connection arguments.
Returns attributes of this url (`host`, `database`, `username`,
@@ -711,11 +767,12 @@ class URL(NamedTuple):
return translated
-def make_url(name_or_url):
- """Given a string or unicode instance, produce a new URL instance.
+def make_url(name_or_url: Union[str, URL]) -> URL:
+ """Given a string, produce a new URL instance.
The given string is parsed according to the RFC 1738 spec. If an
existing URL object is passed, just returns the object.
+
"""
if isinstance(name_or_url, str):
@@ -724,7 +781,7 @@ def make_url(name_or_url):
return name_or_url
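[editor's note: a sketch of the now fully typed round trip:

    from sqlalchemy.engine import make_url

    url = make_url("postgresql+psycopg2://scott:tiger@localhost/test")
    print(url.get_backend_name())  # postgresql
    print(url.get_driver_name())   # psycopg2

    # An existing URL passes through unchanged.
    assert make_url(url) is url
]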
-def _parse_rfc1738_args(name):
+def _parse_rfc1738_args(name: str) -> URL:
pattern = re.compile(
r"""
(?P<name>[\w\+]+)://
@@ -748,13 +805,14 @@ def _parse_rfc1738_args(name):
m = pattern.match(name)
if m is not None:
components = m.groupdict()
+ query: Optional[Dict[str, Union[str, List[str]]]]
if components["query"] is not None:
query = {}
for key, value in parse_qsl(components["query"]):
if key in query:
query[key] = util.to_list(query[key])
- query[key].append(value)
+ cast("List[str]", query[key]).append(value)
else:
query[key] = value
else:
@@ -775,7 +833,7 @@ def _parse_rfc1738_args(name):
if components["port"]:
components["port"] = int(components["port"])
- return URL.create(name, **components)
+ return URL.create(name, **components) # type: ignore
else:
raise exc.ArgumentError(
@@ -783,18 +841,8 @@ def _parse_rfc1738_args(name):
)
-def _rfc_1738_quote(text):
+def _rfc_1738_quote(text: str) -> str:
return re.sub(r"[:@/]", lambda m: "%%%X" % ord(m.group(0)), text)
_rfc_1738_unquote = unquote
-
-
-def _parse_keyvalue_args(name):
- m = re.match(r"(\w+)://(.*)", name)
- if m is not None:
- (name, args) = m.group(1, 2)
- opts = dict(parse_qsl(args))
- return URL(name, *opts)
- else:
- return None
diff --git a/lib/sqlalchemy/engine/util.py b/lib/sqlalchemy/engine/util.py
index f9ee65bef..213485cc9 100644
--- a/lib/sqlalchemy/engine/util.py
+++ b/lib/sqlalchemy/engine/util.py
@@ -7,18 +7,30 @@
from __future__ import annotations
+import typing
+from typing import Any
+from typing import Callable
+from typing import TypeVar
+
from .. import exc
from .. import util
+from ..util._has_cy import HAS_CYEXTENSION
+
+if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
+ from ._py_util import _distill_params_20 as _distill_params_20
+ from ._py_util import _distill_raw_params as _distill_raw_params
+else:
+ from sqlalchemy.cyextension.util import (
+ _distill_params_20 as _distill_params_20,
+ )
+ from sqlalchemy.cyextension.util import (
+ _distill_raw_params as _distill_raw_params,
+ )
-try:
- from sqlalchemy.cyextension.util import _distill_params_20 # noqa
- from sqlalchemy.cyextension.util import _distill_raw_params # noqa
-except ImportError:
- from ._py_util import _distill_params_20 # noqa
- from ._py_util import _distill_raw_params # noqa
+_C = TypeVar("_C", bound=Callable[[], Any])
-def connection_memoize(key):
+def connection_memoize(key: str) -> Callable[[_C], _C]:
"""Decorator, memoize a function in a connection.info stash.
Only applicable to functions which take no arguments other than a
@@ -26,7 +38,7 @@ def connection_memoize(key):
"""
@util.decorator
- def decorated(fn, self, connection):
+ def decorated(fn, self, connection): # type: ignore
connection = connection.connect()
try:
return connection.info[key]
@@ -34,7 +46,7 @@ def connection_memoize(key):
connection.info[key] = val = fn(self, connection)
return val
- return decorated
+ return decorated # type: ignore[return-value]
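[editor's note: a hedged usage sketch — `MyDialect` and the query are hypothetical; the decorated method computes once per connection, then serves the value from `connection.info[key]`:

    from sqlalchemy.engine.util import connection_memoize

    class MyDialect:
        @connection_memoize("_server_version")
        def server_version(self, connection):
            # Runs once; later calls read connection.info["_server_version"].
            return connection.exec_driver_sql("SELECT 1").scalar()
]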
class TransactionalContext:
@@ -47,13 +59,13 @@ class TransactionalContext:
__slots__ = ("_outer_trans_ctx", "_trans_subject", "__weakref__")
- def _transaction_is_active(self):
+ def _transaction_is_active(self) -> bool:
raise NotImplementedError()
- def _transaction_is_closed(self):
+ def _transaction_is_closed(self) -> bool:
raise NotImplementedError()
- def _rollback_can_be_called(self):
+ def _rollback_can_be_called(self) -> bool:
"""indicates the object is in a state that is known to be acceptable
for rollback() to be called.
@@ -70,11 +82,20 @@ class TransactionalContext:
"""
raise NotImplementedError()
- def _get_subject(self):
+ def _get_subject(self) -> Any:
+ raise NotImplementedError()
+
+ def commit(self) -> None:
+ raise NotImplementedError()
+
+ def rollback(self) -> None:
+ raise NotImplementedError()
+
+ def close(self) -> None:
raise NotImplementedError()
@classmethod
- def _trans_ctx_check(cls, subject):
+ def _trans_ctx_check(cls, subject: Any) -> None:
trans_context = subject._trans_context_manager
if trans_context:
if not trans_context._transaction_is_active():
@@ -84,7 +105,7 @@ class TransactionalContext:
"before emitting further commands."
)
- def __enter__(self):
+ def __enter__(self) -> TransactionalContext:
subject = self._get_subject()
# none for outer transaction, may be non-None for nested
@@ -96,7 +117,7 @@ class TransactionalContext:
subject._trans_context_manager = self
return self
- def __exit__(self, type_, value, traceback):
+ def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
subject = getattr(self, "_trans_subject", None)
# simplistically we could assume that
@@ -119,6 +140,7 @@ class TransactionalContext:
self.rollback()
finally:
if not out_of_band_exit:
+ assert subject is not None
subject._trans_context_manager = self._outer_trans_ctx
self._trans_subject = self._outer_trans_ctx = None
else:
@@ -131,5 +153,6 @@ class TransactionalContext:
self.rollback()
finally:
if not out_of_band_exit:
+ assert subject is not None
subject._trans_context_manager = self._outer_trans_ctx
self._trans_subject = self._outer_trans_ctx = None
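[editor's note: the added `assert subject is not None` guards reflect that `_trans_subject` is only set in `__enter__`. For orientation, the context-manager protocol these hooks back is the ordinary one (illustrative, using an in-memory SQLite engine):

    from sqlalchemy import create_engine, text

    engine = create_engine("sqlite://")
    with engine.connect() as conn:
        with conn.begin():  # a TransactionalContext subclass
            conn.execute(text("CREATE TABLE t (x INTEGER)"))
            conn.execute(text("INSERT INTO t VALUES (1)"))
        # committed on clean exit; rolled back if the block raised
]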
diff --git a/lib/sqlalchemy/event/__init__.py b/lib/sqlalchemy/event/__init__.py
index 0dfb39e1a..e1c949681 100644
--- a/lib/sqlalchemy/event/__init__.py
+++ b/lib/sqlalchemy/event/__init__.py
@@ -15,6 +15,7 @@ from .api import NO_RETVAL as NO_RETVAL
from .api import remove as remove
from .attr import RefCollection as RefCollection
from .base import _Dispatch as _Dispatch
+from .base import _DispatchCommon as _DispatchCommon
from .base import dispatcher as dispatcher
from .base import Events as Events
from .legacy import _legacy_signature as _legacy_signature
diff --git a/lib/sqlalchemy/event/attr.py b/lib/sqlalchemy/event/attr.py
index 9692894fe..afae8a59a 100644
--- a/lib/sqlalchemy/event/attr.py
+++ b/lib/sqlalchemy/event/attr.py
@@ -605,14 +605,14 @@ class _ListenerCollection(_CompoundListener[_ET]):
class _JoinedListener(_CompoundListener[_ET]):
__slots__ = "parent_dispatch", "name", "local", "parent_listeners"
- parent_dispatch: _Dispatch[_ET]
+ parent_dispatch: _DispatchCommon[_ET]
name: str
local: _InstanceLevelDispatch[_ET]
parent_listeners: Collection[_ListenerFnType]
def __init__(
self,
- parent_dispatch: _Dispatch[_ET],
+ parent_dispatch: _DispatchCommon[_ET],
name: str,
local: _EmptyListener[_ET],
):
diff --git a/lib/sqlalchemy/event/base.py b/lib/sqlalchemy/event/base.py
index ef3ff9dab..4174b1dbe 100644
--- a/lib/sqlalchemy/event/base.py
+++ b/lib/sqlalchemy/event/base.py
@@ -75,6 +75,18 @@ class _UnpickleDispatch:
class _DispatchCommon(Generic[_ET]):
__slots__ = ()
+ _instance_cls: Optional[Type[_ET]]
+
+ def _join(self, other: _DispatchCommon[_ET]) -> _JoinedDispatcher[_ET]:
+ raise NotImplementedError()
+
+ def __getattr__(self, name: str) -> _InstanceLevelDispatch[_ET]:
+ raise NotImplementedError()
+
+ @property
+ def _events(self) -> Type[_HasEventsDispatch[_ET]]:
+ raise NotImplementedError()
+
class _Dispatch(_DispatchCommon[_ET]):
"""Mirror the event listening definitions of an Events class with
@@ -169,7 +181,7 @@ class _Dispatch(_DispatchCommon[_ET]):
instance_cls = instance.__class__
return self._for_class(instance_cls)
- def _join(self, other: _Dispatch[_ET]) -> _JoinedDispatcher[_ET]:
+ def _join(self, other: _DispatchCommon[_ET]) -> _JoinedDispatcher[_ET]:
"""Create a 'join' of this :class:`._Dispatch` and another.
This new dispatcher will dispatch events to both
@@ -372,11 +384,13 @@ class _JoinedDispatcher(_DispatchCommon[_ET]):
__slots__ = "local", "parent", "_instance_cls"
- local: _Dispatch[_ET]
- parent: _Dispatch[_ET]
+ local: _DispatchCommon[_ET]
+ parent: _DispatchCommon[_ET]
_instance_cls: Optional[Type[_ET]]
- def __init__(self, local: _Dispatch[_ET], parent: _Dispatch[_ET]):
+ def __init__(
+ self, local: _DispatchCommon[_ET], parent: _DispatchCommon[_ET]
+ ):
self.local = local
self.parent = parent
self._instance_cls = self.local._instance_cls
@@ -416,7 +430,7 @@ class dispatcher(Generic[_ET]):
...
@overload
- def __get__(self, obj: Any, cls: Type[Any]) -> _Dispatch[_ET]:
+ def __get__(self, obj: Any, cls: Type[Any]) -> _DispatchCommon[_ET]:
...
def __get__(self, obj: Any, cls: Type[Any]) -> Any:
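
The point of introducing `_DispatchCommon` is that `_join` can now accept either a plain `_Dispatch` or an already-joined `_JoinedDispatcher`, so joins compose. A minimal sketch of that pattern, using hypothetical names rather than SQLAlchemy's real classes:

```python
from __future__ import annotations

from typing import Generic, List, TypeVar

_ET = TypeVar("_ET")


class DispatchCommon(Generic[_ET]):
    """Common base so plain and joined dispatchers share one type."""

    def _join(self, other: DispatchCommon[_ET]) -> JoinedDispatcher[_ET]:
        raise NotImplementedError()


class Dispatch(DispatchCommon[_ET]):
    def __init__(self) -> None:
        self.listeners: List[str] = []

    def _join(self, other: DispatchCommon[_ET]) -> JoinedDispatcher[_ET]:
        return JoinedDispatcher(self, other)


class JoinedDispatcher(DispatchCommon[_ET]):
    def __init__(
        self, local: DispatchCommon[_ET], parent: DispatchCommon[_ET]
    ) -> None:
        self.local = local
        self.parent = parent

    def _join(self, other: DispatchCommon[_ET]) -> JoinedDispatcher[_ET]:
        return JoinedDispatcher(self, other)


# joins now nest: a joined dispatcher can itself be joined again,
# which the narrower _Dispatch-only signature used to reject
d = Dispatch[int]()._join(Dispatch[int]())._join(Dispatch[int]())
```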
diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py
index 1383e024a..cc78e0971 100644
--- a/lib/sqlalchemy/exc.py
+++ b/lib/sqlalchemy/exc.py
@@ -27,8 +27,11 @@ from .util import _preloaded
from .util import compat
if typing.TYPE_CHECKING:
+ from .engine.interfaces import _AnyExecuteParams
+ from .engine.interfaces import _CoreAnyExecuteParams
+ from .engine.interfaces import _CoreMultiExecuteParams
+ from .engine.interfaces import _DBAPIAnyExecuteParams
from .engine.interfaces import Dialect
- from .sql._typing import _ExecuteParams
from .sql.compiler import Compiled
from .sql.elements import ClauseElement
@@ -446,7 +449,7 @@ class StatementError(SQLAlchemyError):
statement: Optional[str] = None
"""The string SQL statement being invoked when this exception occurred."""
- params: Optional["_ExecuteParams"] = None
+ params: Optional[_AnyExecuteParams] = None
"""The parameter list being used when this exception occurred."""
orig: Optional[BaseException] = None
@@ -457,11 +460,13 @@ class StatementError(SQLAlchemyError):
ismulti: Optional[bool] = None
"""multi parameter passed to repr_params(). None is meaningful."""
+ connection_invalidated: bool = False
+
def __init__(
self,
message: str,
statement: Optional[str],
- params: Optional["_ExecuteParams"],
+ params: Optional[_AnyExecuteParams],
orig: Optional[BaseException],
hide_parameters: bool = False,
code: Optional[str] = None,
@@ -553,8 +558,8 @@ class DBAPIError(StatementError):
@classmethod
def instance(
cls,
- statement: str,
- params: "_ExecuteParams",
+ statement: Optional[str],
+ params: Optional[_AnyExecuteParams],
orig: DontWrapMixin,
dbapi_base_err: Type[Exception],
hide_parameters: bool = False,
@@ -568,8 +573,8 @@ class DBAPIError(StatementError):
@classmethod
def instance(
cls,
- statement: str,
- params: "_ExecuteParams",
+ statement: Optional[str],
+ params: Optional[_AnyExecuteParams],
orig: Exception,
dbapi_base_err: Type[Exception],
hide_parameters: bool = False,
@@ -583,8 +588,8 @@ class DBAPIError(StatementError):
@classmethod
def instance(
cls,
- statement: str,
- params: "_ExecuteParams",
+ statement: Optional[str],
+ params: Optional[_AnyExecuteParams],
orig: BaseException,
dbapi_base_err: Type[Exception],
hide_parameters: bool = False,
@@ -597,8 +602,8 @@ class DBAPIError(StatementError):
@classmethod
def instance(
cls,
- statement: str,
- params: "_ExecuteParams",
+ statement: Optional[str],
+ params: Optional[_AnyExecuteParams],
orig: Union[BaseException, DontWrapMixin],
dbapi_base_err: Type[Exception],
hide_parameters: bool = False,
@@ -684,8 +689,8 @@ class DBAPIError(StatementError):
def __init__(
self,
- statement: str,
- params: "_ExecuteParams",
+ statement: Optional[str],
+ params: Optional[_AnyExecuteParams],
orig: BaseException,
hide_parameters: bool = False,
connection_invalidated: bool = False,
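
The stacked `@overload` signatures on `instance()` let narrower `orig` exception types map to narrower return types, while the widened `Optional` parameters admit the `None` statement/params cases that actually occur. A minimal sketch of the same overload pattern with hypothetical classes, not SQLAlchemy's:

```python
from __future__ import annotations

from typing import Any, Mapping, Optional, Sequence, Union, overload

_Params = Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]


class DemoError(Exception):
    def __init__(
        self, statement: Optional[str], params: Optional[_Params]
    ) -> None:
        self.statement = statement
        self.params = params


class DemoTimeoutError(DemoError):
    pass


@overload
def wrap(orig: TimeoutError, statement: Optional[str]) -> DemoTimeoutError:
    ...


@overload
def wrap(orig: BaseException, statement: Optional[str]) -> DemoError:
    ...


def wrap(orig: BaseException, statement: Optional[str]) -> DemoError:
    # a narrower input type yields a narrower return type for the checker
    if isinstance(orig, TimeoutError):
        return DemoTimeoutError(statement, None)
    return DemoError(statement, None)


err = wrap(TimeoutError("t"), "SELECT 1")
assert isinstance(err, DemoTimeoutError)
```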
diff --git a/lib/sqlalchemy/log.py b/lib/sqlalchemy/log.py
index 8da45ed0d..9a23d89d3 100644
--- a/lib/sqlalchemy/log.py
+++ b/lib/sqlalchemy/log.py
@@ -255,19 +255,19 @@ class echo_property:
@overload
def __get__(
- self, instance: "Literal[None]", owner: "echo_property"
- ) -> "echo_property":
+ self, instance: Literal[None], owner: Type[Identified]
+ ) -> echo_property:
...
@overload
def __get__(
- self, instance: Identified, owner: "echo_property"
+ self, instance: Identified, owner: Type[Identified]
) -> _EchoFlagType:
...
def __get__(
- self, instance: Optional[Identified], owner: "echo_property"
- ) -> Union["echo_property", _EchoFlagType]:
+ self, instance: Optional[Identified], owner: Type[Identified]
+ ) -> Union[echo_property, _EchoFlagType]:
if instance is None:
return self
else:
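
The corrected signatures type `owner` as the owning class (`Type[Identified]`) rather than the descriptor itself, which is what the descriptor protocol actually passes. A small sketch of the same overloaded-descriptor shape, under assumed names:

```python
from __future__ import annotations

from typing import Optional, Type, Union, overload


class Identified:
    echo: Union[None, bool, str] = None


class echo_flag:
    """Returns itself on class access, the flag value on instances."""

    @overload
    def __get__(self, instance: None, owner: Type[Identified]) -> echo_flag:
        ...

    @overload
    def __get__(
        self, instance: Identified, owner: Type[Identified]
    ) -> Union[None, bool, str]:
        ...

    def __get__(
        self, instance: Optional[Identified], owner: Type[Identified]
    ) -> Union[echo_flag, None, bool, str]:
        if instance is None:
            return self
        return instance.echo


class Logged(Identified):
    flag = echo_flag()


obj = Logged()
obj.echo = True
assert obj.flag is True                    # instance access -> the value
assert isinstance(Logged.flag, echo_flag)  # class access -> descriptor
```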
diff --git a/lib/sqlalchemy/pool/__init__.py b/lib/sqlalchemy/pool/__init__.py
index 2c52a7065..1fc77243a 100644
--- a/lib/sqlalchemy/pool/__init__.py
+++ b/lib/sqlalchemy/pool/__init__.py
@@ -18,8 +18,8 @@ SQLAlchemy connection pool.
"""
from . import events
-from .base import _AdhocProxiedConnection
-from .base import _ConnectionFairy
+from .base import _AdhocProxiedConnection as _AdhocProxiedConnection
+from .base import _ConnectionFairy as _ConnectionFairy
from .base import _ConnectionRecord
from .base import _finalize_fairy
from .base import ConnectionPoolEntry as ConnectionPoolEntry
diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py
index 18d268182..c1008de5f 100644
--- a/lib/sqlalchemy/pool/base.py
+++ b/lib/sqlalchemy/pool/base.py
@@ -41,6 +41,7 @@ if TYPE_CHECKING:
from ..engine.interfaces import DBAPICursor
from ..engine.interfaces import Dialect
from ..event import _Dispatch
+ from ..event import _DispatchCommon
from ..event import _ListenerFnType
from ..event import dispatcher
@@ -132,7 +133,7 @@ class Pool(log.Identified, event.EventTarget):
events: Optional[List[Tuple[_ListenerFnType, str]]] = None,
dialect: Optional[Union[_ConnDialect, Dialect]] = None,
pre_ping: bool = False,
- _dispatch: Optional[_Dispatch[Pool]] = None,
+ _dispatch: Optional[_DispatchCommon[Pool]] = None,
):
"""
Construct a Pool.
@@ -443,78 +444,72 @@ class ManagesConnection:
"""
- @property
- def driver_connection(self) -> Optional[Any]:
- """The "driver level" connection object as used by the Python
- DBAPI or database driver.
+ driver_connection: Optional[Any]
+ """The "driver level" connection object as used by the Python
+ DBAPI or database driver.
- For traditional :pep:`249` DBAPI implementations, this object will
- be the same object as that of
- :attr:`.ManagesConnection.dbapi_connection`. For an asyncio database
- driver, this will be the ultimate "connection" object used by that
- driver, such as the ``asyncpg.Connection`` object which will not have
- standard pep-249 methods.
+ For traditional :pep:`249` DBAPI implementations, this object will
+ be the same object as that of
+ :attr:`.ManagesConnection.dbapi_connection`. For an asyncio database
+ driver, this will be the ultimate "connection" object used by that
+ driver, such as the ``asyncpg.Connection`` object which will not have
+ standard pep-249 methods.
- .. versionadded:: 1.4.24
+ .. versionadded:: 1.4.24
- .. seealso::
+ .. seealso::
- :attr:`.ManagesConnection.dbapi_connection`
+ :attr:`.ManagesConnection.dbapi_connection`
- :ref:`faq_dbapi_connection`
+ :ref:`faq_dbapi_connection`
- """
- raise NotImplementedError()
+ """
- @util.dynamic_property
- def info(self) -> Dict[str, Any]:
- """Info dictionary associated with the underlying DBAPI connection
- referred to by this :class:`.ManagesConnection` instance, allowing
- user-defined data to be associated with the connection.
+ info: Dict[str, Any]
+ """Info dictionary associated with the underlying DBAPI connection
+ referred to by this :class:`.ManagesConnection` instance, allowing
+ user-defined data to be associated with the connection.
- The data in this dictionary is persistent for the lifespan
- of the DBAPI connection itself, including across pool checkins
- and checkouts. When the connection is invalidated
- and replaced with a new one, this dictionary is cleared.
+ The data in this dictionary is persistent for the lifespan
+ of the DBAPI connection itself, including across pool checkins
+ and checkouts. When the connection is invalidated
+ and replaced with a new one, this dictionary is cleared.
- For a :class:`.PoolProxiedConnection` instance that's not associated
- with a :class:`.ConnectionPoolEntry`, such as if it were detached, the
- attribute returns a dictionary that is local to that
- :class:`.ConnectionPoolEntry`. Therefore the
- :attr:`.ManagesConnection.info` attribute will always provide a Python
- dictionary.
+ For a :class:`.PoolProxiedConnection` instance that's not associated
+ with a :class:`.ConnectionPoolEntry`, such as if it were detached, the
+ attribute returns a dictionary that is local to that
+ :class:`.ConnectionPoolEntry`. Therefore the
+ :attr:`.ManagesConnection.info` attribute will always provide a Python
+ dictionary.
- .. seealso::
+ .. seealso::
- :attr:`.ManagesConnection.record_info`
+ :attr:`.ManagesConnection.record_info`
- """
- raise NotImplementedError()
+ """
- @util.dynamic_property
- def record_info(self) -> Optional[Dict[str, Any]]:
- """Persistent info dictionary associated with this
- :class:`.ManagesConnection`.
+ record_info: Optional[Dict[str, Any]]
+ """Persistent info dictionary associated with this
+ :class:`.ManagesConnection`.
- Unlike the :attr:`.ManagesConnection.info` dictionary, the lifespan
- of this dictionary is that of the :class:`.ConnectionPoolEntry`
- which owns it; therefore this dictionary will persist across
- reconnects and connection invalidation for a particular entry
- in the connection pool.
+ Unlike the :attr:`.ManagesConnection.info` dictionary, the lifespan
+ of this dictionary is that of the :class:`.ConnectionPoolEntry`
+ which owns it; therefore this dictionary will persist across
+ reconnects and connection invalidation for a particular entry
+ in the connection pool.
- For a :class:`.PoolProxiedConnection` instance that's not associated
- with a :class:`.ConnectionPoolEntry`, such as if it were detached, the
- attribute returns None. Contrast to the :attr:`.ManagesConnection.info`
- dictionary which is never None.
+ For a :class:`.PoolProxiedConnection` instance that's not associated
+ with a :class:`.ConnectionPoolEntry`, such as if it were detached, the
+ attribute returns None. Contrast to the :attr:`.ManagesConnection.info`
+ dictionary which is never None.
- .. seealso::
+ .. seealso::
- :attr:`.ManagesConnection.info`
+ :attr:`.ManagesConnection.info`
- """
- raise NotImplementedError()
+ """
def invalidate(
self, e: Optional[BaseException] = None, soft: bool = False
@@ -618,7 +613,7 @@ class _ConnectionRecord(ConnectionPoolEntry):
dbapi_connection: Optional[DBAPIConnection]
@property
- def driver_connection(self) -> Optional[Any]:
+ def driver_connection(self) -> Optional[Any]: # type: ignore[override] # mypy#4125 # noqa E501
if self.dbapi_connection is None:
return None
else:
@@ -637,11 +632,11 @@ class _ConnectionRecord(ConnectionPoolEntry):
_soft_invalidate_time: float = 0
@util.memoized_property
- def info(self) -> Dict[str, Any]:
+ def info(self) -> Dict[str, Any]: # type: ignore[override] # mypy#4125
return {}
@util.memoized_property
- def record_info(self) -> Optional[Dict[str, Any]]:
+ def record_info(self) -> Optional[Dict[str, Any]]: # type: ignore[override] # mypy#4125 # noqa E501
return {}
@classmethod
@@ -1048,7 +1043,7 @@ class _AdhocProxiedConnection(PoolProxiedConnection):
"""
- __slots__ = ("dbapi_connection", "_connection_record")
+ __slots__ = ("dbapi_connection", "_connection_record", "_is_valid")
dbapi_connection: DBAPIConnection
_connection_record: ConnectionPoolEntry
@@ -1060,9 +1055,10 @@ class _AdhocProxiedConnection(PoolProxiedConnection):
):
self.dbapi_connection = dbapi_connection
self._connection_record = connection_record
+ self._is_valid = True
@property
- def driver_connection(self) -> Any:
+ def driver_connection(self) -> Any: # type: ignore[override] # mypy#4125
return self._connection_record.driver_connection
@property
@@ -1071,10 +1067,21 @@ class _AdhocProxiedConnection(PoolProxiedConnection):
@property
def is_valid(self) -> bool:
- raise AttributeError("is_valid not implemented by this proxy")
+ """Implement is_valid state attribute.
+
+ for the adhoc proxied connection it's assumed the connection is valid
+ as there is no "invalidate" routine.
+
+ """
+ return self._is_valid
- @util.dynamic_property
- def record_info(self) -> Optional[Dict[str, Any]]:
+ def invalidate(
+ self, e: Optional[BaseException] = None, soft: bool = False
+ ) -> None:
+ self._is_valid = False
+
+ @property
+ def record_info(self) -> Optional[Dict[str, Any]]: # type: ignore[override] # mypy#4125 # noqa E501
return self._connection_record.record_info
def cursor(self, *args: Any, **kwargs: Any) -> DBAPICursor:
@@ -1140,7 +1147,7 @@ class _ConnectionFairy(PoolProxiedConnection):
_connection_record: Optional[_ConnectionRecord]
@property
- def driver_connection(self) -> Optional[Any]:
+ def driver_connection(self) -> Optional[Any]: # type: ignore[override] # mypy#4125 # noqa E501
if self._connection_record is None:
return None
return self._connection_record.driver_connection
@@ -1305,17 +1312,17 @@ class _ConnectionFairy(PoolProxiedConnection):
@property
def is_detached(self) -> bool:
- return self._connection_record is not None
+ return self._connection_record is None
@util.memoized_property
- def info(self) -> Dict[str, Any]:
+ def info(self) -> Dict[str, Any]: # type: ignore[override] # mypy#4125
if self._connection_record is None:
return {}
else:
return self._connection_record.info
- @util.dynamic_property
- def record_info(self) -> Optional[Dict[str, Any]]:
+ @property
+ def record_info(self) -> Optional[Dict[str, Any]]: # type: ignore[override] # mypy#4125 # noqa E501
if self._connection_record is None:
return None
else:
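
With `util.dynamic_property` removed, the `ManagesConnection` interface declares `info` and friends as plain annotated attributes; concrete classes may then satisfy them with an attribute, a `@property`, or a `@util.memoized_property` alike (the `# type: ignore[override] # mypy#4125` comments work around mypy's attribute-vs-property override check). A sketch of the pattern under assumed names:

```python
from __future__ import annotations

from typing import Any, Dict, Optional


class ManagesConnectionDemo:
    """Interface sketch: annotated attributes instead of property stubs."""

    info: Dict[str, Any]
    record_info: Optional[Dict[str, Any]]


class ConcreteRecord(ManagesConnectionDemo):
    def __init__(self) -> None:
        # a plain attribute satisfies the annotated interface
        self.info = {}
        self.record_info = {}


class ConcreteProxy(ManagesConnectionDemo):
    @property  # type: ignore[override]  # mypy#4125, as in the diff
    def info(self) -> Dict[str, Any]:
        return {}

    @property  # type: ignore[override]  # mypy#4125
    def record_info(self) -> Optional[Dict[str, Any]]:
        return None


assert ConcreteRecord().info == {}
assert ConcreteProxy().record_info is None
```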
diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py
index 7d8b9ee5c..69e4645fa 100644
--- a/lib/sqlalchemy/sql/_typing.py
+++ b/lib/sqlalchemy/sql/_typing.py
@@ -1,20 +1,11 @@
from __future__ import annotations
-from typing import Any
-from typing import Mapping
-from typing import Sequence
from typing import Type
from typing import Union
from . import roles
from ..inspection import Inspectable
-from ..util import immutabledict
-_SingleExecuteParams = Mapping[str, Any]
-_MultiExecuteParams = Sequence[_SingleExecuteParams]
-_ExecuteParams = Union[_SingleExecuteParams, _MultiExecuteParams]
-_ExecuteOptions = Mapping[str, Any]
-_ImmutableExecuteOptions = immutabledict[str, Any]
_ColumnsClauseElement = Union[
roles.ColumnsClauseRole, Type, Inspectable[roles.HasClauseElement]
]
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 3936ed9c6..a94590da1 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -19,11 +19,12 @@ from itertools import zip_longest
import operator
import re
import typing
+from typing import Optional
+from typing import Sequence
from typing import TypeVar
from . import roles
from . import visitors
-from ._typing import _ImmutableExecuteOptions
from .cache_key import HasCacheKey # noqa
from .cache_key import MemoizedHasCacheKey # noqa
from .traversals import HasCopyInternals # noqa
@@ -32,7 +33,7 @@ from .visitors import ExtendedInternalTraversal
from .visitors import InternalTraversal
from .. import exc
from .. import util
-from ..util import HasMemoized
+from ..util import HasMemoized as HasMemoized
from ..util import hybridmethod
from ..util import typing as compat_typing
from ..util._has_cy import HAS_CYEXTENSION
@@ -42,6 +43,16 @@ if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
else:
from sqlalchemy.cyextension.util import prefix_anon_map # noqa
+if typing.TYPE_CHECKING:
+ from ..engine import Connection
+ from ..engine import Result
+ from ..engine.interfaces import _CoreMultiExecuteParams
+ from ..engine.interfaces import _ExecuteOptions
+ from ..engine.interfaces import _ExecuteOptionsParameter
+ from ..engine.interfaces import _ImmutableExecuteOptions
+ from ..engine.interfaces import CacheStats
+
+
coercions = None
elements = None
type_api = None
@@ -856,6 +867,32 @@ class Executable(roles.StatementRole, Generative):
is_delete = False
is_dml = False
+ if typing.TYPE_CHECKING:
+
+ def _compile_w_cache(
+ self,
+ dialect: Dialect,
+ compiled_cache: Optional[_CompiledCacheType] = None,
+ column_keys: Optional[Sequence[str]] = None,
+ for_executemany: bool = False,
+ schema_translate_map: Optional[_SchemaTranslateMapType] = None,
+ **kw: Any,
+ ) -> Tuple[Compiled, _SingleExecuteParams, CacheStats]:
+ ...
+
+ def _execute_on_connection(
+ self,
+ connection: Connection,
+ distilled_params: _CoreMultiExecuteParams,
+ execution_options: _ExecuteOptionsParameter,
+ _force: bool = False,
+ ) -> Result:
+ ...
+
+ @property
+ def _all_selected_columns(self):
+ raise NotImplementedError()
+
@property
def _effective_plugin_target(self):
return self.__visit_name__
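
The `if typing.TYPE_CHECKING:` block above declares signatures for methods that are attached to `Executable` elsewhere at runtime: the checker sees them, the interpreter never defines them. A minimal sketch of the idiom:

```python
from __future__ import annotations

import typing
from typing import Any


class Executable:
    """Sketch: signatures the checker sees, the interpreter never defines."""

    if typing.TYPE_CHECKING:

        def _execute_on_connection(
            self, connection: Any, distilled_params: Any
        ) -> Any:
            ...


# at runtime the stub simply does not exist; concrete implementations
# are attached elsewhere (mixins, subclasses, module wiring)
assert not hasattr(Executable, "_execute_on_connection")
```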
diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py
index 49f1899d5..ff659b77d 100644
--- a/lib/sqlalchemy/sql/cache_key.py
+++ b/lib/sqlalchemy/sql/cache_key.py
@@ -7,10 +7,12 @@
from __future__ import annotations
-from collections import namedtuple
import enum
from itertools import zip_longest
+import typing
+from typing import Any
from typing import Callable
+from typing import NamedTuple
from typing import Union
from .visitors import anon_map
@@ -22,6 +24,10 @@ from ..util import HasMemoized
from ..util.typing import Literal
+if typing.TYPE_CHECKING:
+ from .elements import BindParameter
+
+
class CacheConst(enum.Enum):
NO_CACHE = 0
@@ -345,7 +351,7 @@ class MemoizedHasCacheKey(HasCacheKey, HasMemoized):
return HasCacheKey._generate_cache_key(self)
-class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])):
+class CacheKey(NamedTuple):
"""The key used to identify a SQL statement construct in the
SQL compilation cache.
@@ -355,6 +361,9 @@ class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])):
"""
+ key: Tuple[Any, ...]
+ bindparams: Sequence[BindParameter]
+
def __hash__(self):
"""CacheKey itself is not hashable - hash the .key portion"""
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index d0f114d6c..712d31462 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -27,6 +27,7 @@ from __future__ import annotations
import collections
import collections.abc as collections_abc
import contextlib
+from enum import IntEnum
import itertools
import operator
import re
@@ -35,9 +36,13 @@ import typing
from typing import Any
from typing import Dict
from typing import List
+from typing import Mapping
from typing import MutableMapping
+from typing import NamedTuple
from typing import Optional
+from typing import Sequence
from typing import Tuple
+from typing import Union
from . import base
from . import coercions
@@ -51,12 +56,17 @@ from . import sqltypes
from .base import NO_ARG
from .base import prefix_anon_map
from .elements import quoted_name
+from .schema import Column
+from .type_api import TypeEngine
from .. import exc
from .. import util
+from ..util.typing import Literal
if typing.TYPE_CHECKING:
from .selectable import CTE
from .selectable import FromClause
+ from ..engine.interfaces import _CoreSingleExecuteParams
+ from ..engine.result import _ProcessorType
_FromHintsType = Dict["FromClause", str]
@@ -271,42 +281,71 @@ COMPOUND_KEYWORDS = {
}
-RM_RENDERED_NAME = 0
-RM_NAME = 1
-RM_OBJECTS = 2
-RM_TYPE = 3
+class ResultColumnsEntry(NamedTuple):
+ """Tracks a column expression that is expected to be represented
+ in the result rows for this statement.
+ This normally refers to the columns clause of a SELECT statement
+ but may also refer to a RETURNING clause, as well as to dialect-specific
+ emulations.
-ExpandedState = collections.namedtuple(
- "ExpandedState",
- [
- "statement",
- "additional_parameters",
- "processors",
- "positiontup",
- "parameter_expansion",
- ],
-)
+ """
+ keyname: str
+ """string name that's expected in cursor.description"""
-NO_LINTING = util.symbol("NO_LINTING", "Disable all linting.", canonical=0)
+ name: str
+ """column name, may be labeled"""
-COLLECT_CARTESIAN_PRODUCTS = util.symbol(
- "COLLECT_CARTESIAN_PRODUCTS",
- "Collect data on FROMs and cartesian products and gather "
- "into 'self.from_linter'",
- canonical=1,
-)
+ objects: List[Any]
+ """list of objects that should be able to locate this column
+ in a RowMapping. This is typically string names and aliases
+ as well as Column objects.
-WARN_LINTING = util.symbol(
- "WARN_LINTING", "Emit warnings for linters that find problems", canonical=2
-)
+ """
+
+ type: TypeEngine[Any]
+ """Datatype to be associated with this column. This is where
+ the "result processing" logic directly links the compiled statement
+ to the rows that come back from the cursor.
+
+ """
+
+
+# Integer indexes into ResultColumnsEntry, used by cursor.py;
+# some profiling showed integer access faster than named tuple field access.
+RM_RENDERED_NAME: Literal[0] = 0
+RM_NAME: Literal[1] = 1
+RM_OBJECTS: Literal[2] = 2
+RM_TYPE: Literal[3] = 3
+
+
+class ExpandedState(NamedTuple):
+ statement: str
+ additional_parameters: _CoreSingleExecuteParams
+ processors: Mapping[str, _ProcessorType]
+ positiontup: Optional[Sequence[str]]
+ parameter_expansion: Mapping[str, List[str]]
+
+
+class Linting(IntEnum):
+ NO_LINTING = 0
+ "Disable all linting."
+
+ COLLECT_CARTESIAN_PRODUCTS = 1
+ """Collect data on FROMs and cartesian products and gather into
+ 'self.from_linter'"""
+
+ WARN_LINTING = 2
+ "Emit warnings for linters that find problems"
-FROM_LINTING = util.symbol(
- "FROM_LINTING",
- "Warn for cartesian products; "
- "combines COLLECT_CARTESIAN_PRODUCTS and WARN_LINTING",
- canonical=COLLECT_CARTESIAN_PRODUCTS | WARN_LINTING,
+ FROM_LINTING = COLLECT_CARTESIAN_PRODUCTS | WARN_LINTING
+ """Warn for cartesian products; combines COLLECT_CARTESIAN_PRODUCTS
+ and WARN_LINTING"""
+
+
+NO_LINTING, COLLECT_CARTESIAN_PRODUCTS, WARN_LINTING, FROM_LINTING = tuple(
+ Linting
)
@@ -389,7 +428,7 @@ class Compiled:
_cached_metadata = None
- _result_columns = None
+ _result_columns: Optional[List[ResultColumnsEntry]] = None
schema_translate_map = None
@@ -418,7 +457,8 @@ class Compiled:
"""
cache_key = None
- _gen_time = None
+
+ _gen_time: float
def __init__(
self,
@@ -573,15 +613,43 @@ class SQLCompiler(Compiled):
extract_map = EXTRACT_MAP
+ _result_columns: List[ResultColumnsEntry]
+
compound_keywords = COMPOUND_KEYWORDS
- isdelete = isinsert = isupdate = False
+ isdelete: bool = False
+ isinsert: bool = False
+ isupdate: bool = False
"""class-level defaults which can be set at the instance
level to define if this Compiled instance represents
INSERT/UPDATE/DELETE
"""
- isplaintext = False
+ postfetch: Optional[List[Column[Any]]]
+ """list of columns that can be post-fetched after INSERT or UPDATE to
+ receive server-updated values"""
+
+ insert_prefetch: Optional[List[Column[Any]]]
+ """list of columns for which default values should be evaluated before
+ an INSERT takes place"""
+
+ update_prefetch: Optional[List[Column[Any]]]
+ """list of columns for which onupdate default values should be evaluated
+ before an UPDATE takes place"""
+
+ returning: Optional[List[Column[Any]]]
+ """list of columns that will be delivered to cursor.description or
+ dialect equivalent via the RETURNING clause on an INSERT, UPDATE, or DELETE
+
+ """
+
+ isplaintext: bool = False
+
+ result_columns: List[ResultColumnsEntry]
+ """relates label names in the final SQL to a tuple of local
+ column/label name, ColumnElement object (if any) and
+ TypeEngine. CursorResult uses this for type processing and
+ column targeting"""
returning = None
"""holds the "returning" collection of columns if
@@ -589,18 +657,18 @@ class SQLCompiler(Compiled):
either implicitly or explicitly
"""
- returning_precedes_values = False
+ returning_precedes_values: bool = False
"""set to True classwide to generate RETURNING
clauses before the VALUES or WHERE clause (i.e. MSSQL)
"""
- render_table_with_column_in_update_from = False
+ render_table_with_column_in_update_from: bool = False
"""set to True classwide to indicate the SET clause
in a multi-table UPDATE statement should qualify
columns with the table name (i.e. MySQL only)
"""
- ansi_bind_rules = False
+ ansi_bind_rules: bool = False
"""SQL 92 doesn't allow bind parameters to be used
in the columns clause of a SELECT, nor does it allow
ambiguous expressions like "? = ?". A compiler
@@ -608,33 +676,33 @@ class SQLCompiler(Compiled):
driver/DB enforces this
"""
- _textual_ordered_columns = False
+ _textual_ordered_columns: bool = False
"""tell the result object that the column names as rendered are important,
but they are also "ordered" vs. what is in the compiled object here.
"""
- _ordered_columns = True
+ _ordered_columns: bool = True
"""
if False, means we can't be sure the list of entries
in _result_columns is actually the rendered order. Usually
True unless using an unordered TextualSelect.
"""
- _loose_column_name_matching = False
+ _loose_column_name_matching: bool = False
"""tell the result object that the SQL statement is textual, wants to match
up to Column objects, and may be using the ._tq_label in the SELECT rather
than the base name.
"""
- _numeric_binds = False
+ _numeric_binds: bool = False
"""
True if paramstyle is "numeric". This paramstyle is trickier than
all the others.
"""
- _render_postcompile = False
+ _render_postcompile: bool = False
"""
whether to render out POSTCOMPILE params during the compile phase.
@@ -684,7 +752,7 @@ class SQLCompiler(Compiled):
"""
- positiontup = None
+ positiontup: Optional[Sequence[str]] = None
"""for a compiled construct that uses a positional paramstyle, will be
a sequence of strings, indicating the names of bound parameters in order.
@@ -699,7 +767,7 @@ class SQLCompiler(Compiled):
"""
- inline = False
+ inline: bool = False
def __init__(
self,
@@ -760,10 +828,6 @@ class SQLCompiler(Compiled):
# stack which keeps track of nested SELECT statements
self.stack = []
- # relates label names in the final SQL to a tuple of local
- # column/label name, ColumnElement object (if any) and
- # TypeEngine. CursorResult uses this for type processing and
- # column targeting
self._result_columns = []
# true if the paramstyle is positional
@@ -910,7 +974,9 @@ class SQLCompiler(Compiled):
)
@util.memoized_property
- def _bind_processors(self):
+ def _bind_processors(
+ self,
+ ) -> MutableMapping[str, Union[_ProcessorType, Sequence[_ProcessorType]]]:
return dict(
(key, value)
for key, value in (
@@ -1098,8 +1164,10 @@ class SQLCompiler(Compiled):
return self.construct_params(_check=False)
def _process_parameters_for_postcompile(
- self, parameters=None, _populate_self=False
- ):
+ self,
+ parameters: Optional[_CoreSingleExecuteParams] = None,
+ _populate_self: bool = False,
+ ) -> ExpandedState:
"""handle special post compile parameters.
These include:
@@ -3070,7 +3138,13 @@ class SQLCompiler(Compiled):
def get_render_as_alias_suffix(self, alias_name_text):
return " AS " + alias_name_text
- def _add_to_result_map(self, keyname, name, objects, type_):
+ def _add_to_result_map(
+ self,
+ keyname: str,
+ name: str,
+ objects: List[Any],
+ type_: TypeEngine[Any],
+ ) -> None:
if keyname is None or keyname == "*":
self._ordered_columns = False
self._textual_ordered_columns = True
@@ -3080,7 +3154,9 @@ class SQLCompiler(Compiled):
"from a tuple() object. If this is an ORM query, "
"consider using the Bundle object."
)
- self._result_columns.append((keyname, name, objects, type_))
+ self._result_columns.append(
+ ResultColumnsEntry(keyname, name, objects, type_)
+ )
def _label_returning_column(self, stmt, column, column_clause_args=None):
"""Render a column with necessary labels inside of a RETURNING clause.
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 0c532a135..ac5dc46db 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -61,6 +61,11 @@ if typing.TYPE_CHECKING:
from .selectable import Select
from .sqltypes import Boolean # noqa
from .type_api import TypeEngine
+ from ..engine import Compiled
+ from ..engine import Connection
+ from ..engine import Dialect
+ from ..engine import Engine
+
_NUMERIC = Union[complex, "Decimal"]
@@ -145,7 +150,12 @@ class CompilerElement(Visitable):
@util.preload_module("sqlalchemy.engine.default")
@util.preload_module("sqlalchemy.engine.url")
- def compile(self, bind=None, dialect=None, **kw):
+ def compile(
+ self,
+ bind: Optional[Union[Engine, Connection]] = None,
+ dialect: Optional[Dialect] = None,
+ **kw: Any,
+ ) -> Compiled:
"""Compile this SQL expression.
The return value is a :class:`~.Compiled` object.
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index 528691795..fdae4d7b0 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -174,7 +174,13 @@ class SchemaItem(SchemaEventTarget, visitors.Visitable):
_use_schema_map = True
-class Table(DialectKWArgs, SchemaItem, TableClause):
+class HasSchemaAttr(SchemaItem):
+ """schema item that includes a top-level schema name"""
+
+ schema: Optional[str]
+
+
+class Table(DialectKWArgs, HasSchemaAttr, TableClause):
r"""Represent a table in a database.
e.g.::
@@ -2850,7 +2856,7 @@ class IdentityOptions:
self.order = order
-class Sequence(IdentityOptions, DefaultGenerator):
+class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator):
"""Represents a named database sequence.
The :class:`.Sequence` object represents the name and configurational
@@ -4330,7 +4336,7 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem):
DEFAULT_NAMING_CONVENTION = util.immutabledict({"ix": "ix_%(column_0_label)s"})
-class MetaData(SchemaItem):
+class MetaData(HasSchemaAttr):
"""A collection of :class:`_schema.Table`
objects and their associated schema
constructs.
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index e3e358cdb..e0248adf0 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -21,9 +21,6 @@ from . import coercions
from . import operators
from . import roles
from . import visitors
-from ._typing import _ExecuteParams
-from ._typing import _MultiExecuteParams
-from ._typing import _SingleExecuteParams
from .annotation import _deep_annotate # noqa
from .annotation import _deep_deannotate # noqa
from .annotation import _shallow_annotate # noqa
@@ -54,6 +51,10 @@ from .. import exc
from .. import util
if typing.TYPE_CHECKING:
+ from ..engine.interfaces import _AnyExecuteParams
+ from ..engine.interfaces import _AnyMultiExecuteParams
+ from ..engine.interfaces import _AnySingleExecuteParams
+ from ..engine.interfaces import _CoreSingleExecuteParams
from ..engine.row import Row
@@ -550,12 +551,12 @@ class _repr_params(_repr_base):
def __init__(
self,
- params: _ExecuteParams,
+ params: Optional[_AnyExecuteParams],
batches: int,
max_chars: int = 300,
ismulti: Optional[bool] = None,
):
- self.params: _ExecuteParams = params
+ self.params = params
self.ismulti = ismulti
self.batches = batches
self.max_chars = max_chars
@@ -575,7 +576,10 @@ class _repr_params(_repr_base):
return self.trunc(self.params)
if self.ismulti:
- multi_params = cast(_MultiExecuteParams, self.params)
+ multi_params = cast(
+ "_AnyMultiExecuteParams",
+ self.params,
+ )
if len(self.params) > self.batches:
msg = (
@@ -595,10 +599,18 @@ class _repr_params(_repr_base):
return self._repr_multi(multi_params, typ)
else:
return self._repr_params(
- cast(_SingleExecuteParams, self.params), typ
+ cast(
+ "_AnySingleExecuteParams",
+ self.params,
+ ),
+ typ,
)
- def _repr_multi(self, multi_params: _MultiExecuteParams, typ) -> str:
+ def _repr_multi(
+ self,
+ multi_params: _AnyMultiExecuteParams,
+ typ,
+ ) -> str:
if multi_params:
if isinstance(multi_params[0], list):
elem_type = self._LIST
@@ -622,13 +634,19 @@ class _repr_params(_repr_base):
else:
return "(%s)" % elements
- def _repr_params(self, params: _SingleExecuteParams, typ: int) -> str:
+ def _repr_params(
+ self,
+ params: Optional[_AnySingleExecuteParams],
+ typ: int,
+ ) -> str:
trunc = self.trunc
if typ is self._DICT:
return "{%s}" % (
", ".join(
"%r: %s" % (key, trunc(value))
- for key, value in params.items()
+ for key, value in cast(
+ "_CoreSingleExecuteParams", params
+ ).items()
)
)
elif typ is self._TUPLE:
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 523426d09..111ecd32e 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -28,6 +28,8 @@ from __future__ import annotations
from collections import deque
import itertools
import operator
+import typing
+from typing import Any
from typing import List
from typing import Tuple
@@ -35,12 +37,13 @@ from .. import exc
from .. import util
from ..util import langhelpers
from ..util import symbol
+from ..util._has_cy import HAS_CYEXTENSION
from ..util.langhelpers import _symbol
-try:
- from sqlalchemy.cyextension.util import cache_anon_map as anon_map # noqa
-except ImportError:
+if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
from ._py_util import cache_anon_map as anon_map # noqa
+else:
+ from sqlalchemy.cyextension.util import cache_anon_map as anon_map # noqa
__all__ = [
"iterate",
@@ -554,7 +557,7 @@ class ExternalTraversal:
__traverse_options__ = {}
- def traverse_single(self, obj, **kw):
+ def traverse_single(self, obj: Visitable, **kw: Any) -> Any:
for v in self.visitor_iterator:
meth = getattr(v, "visit_%s" % obj.__visit_name__, None)
if meth:
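
The try/except import of the cython `anon_map` becomes a statically analyzable conditional: under `TYPE_CHECKING` the checker always follows the pure-Python branch, while at runtime the compiled version is preferred when present. A hedged sketch of the idiom with stand-in module names (nothing here is SQLAlchemy's real layout):

```python
import typing

try:  # how a HAS_CYEXTENSION-style flag is typically derived
    import _nonexistent_c_ext  # noqa: F401  (stand-in; fails here)
    HAS_CYEXTENSION = True
except ImportError:
    HAS_CYEXTENSION = False

if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
    from json import loads as fast_loads  # pure-Python stand-in
else:
    from _nonexistent_c_ext import loads as fast_loads  # noqa

assert fast_loads("[1, 2]") == [1, 2]
```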
diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py
index a41420504..e5cf9b92e 100644
--- a/lib/sqlalchemy/util/__init__.py
+++ b/lib/sqlalchemy/util/__init__.py
@@ -94,7 +94,6 @@ from .langhelpers import decode_slice as decode_slice
from .langhelpers import decorator as decorator
from .langhelpers import dictlike_iteritems as dictlike_iteritems
from .langhelpers import duck_type_collection as duck_type_collection
-from .langhelpers import dynamic_property as dynamic_property
from .langhelpers import ellipses_string as ellipses_string
from .langhelpers import EnsureKWArg as EnsureKWArg
from .langhelpers import format_argspec_init as format_argspec_init
diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py
index 84735316d..e0b53b445 100644
--- a/lib/sqlalchemy/util/_collections.py
+++ b/lib/sqlalchemy/util/_collections.py
@@ -513,7 +513,12 @@ class LRUCache(typing.MutableMapping[_KT, _VT]):
threshold: float
size_alert: Optional[Callable[["LRUCache[_KT, _VT]"], None]]
- def __init__(self, capacity=100, threshold=0.5, size_alert=None):
+ def __init__(
+ self,
+ capacity: int = 100,
+ threshold: float = 0.5,
+ size_alert: Optional[Callable[..., None]] = None,
+ ):
self.capacity = capacity
self.threshold = threshold
self.size_alert = size_alert
diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py
index ee54180ac..771e974e9 100644
--- a/lib/sqlalchemy/util/_py_collections.py
+++ b/lib/sqlalchemy/util/_py_collections.py
@@ -15,9 +15,11 @@ from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import List
+from typing import Mapping
from typing import NoReturn
from typing import Optional
from typing import Set
+from typing import Tuple
from typing import TypeVar
from typing import Union
@@ -65,13 +67,15 @@ class immutabledict(ImmutableDictBase[_KT, _VT]):
dict.__init__(new, *args)
return new
- def __init__(self, *args):
+ def __init__(self, *args: Union[Mapping[_KT, _VT], Tuple[_KT, _VT]]):
pass
def __reduce__(self):
return immutabledict, (dict(self),)
- def union(self, __d=None):
+ def union(
+ self, __d: Optional[Mapping[_KT, _VT]] = None
+ ) -> immutabledict[_KT, _VT]:
if not __d:
return self
@@ -80,7 +84,9 @@ class immutabledict(ImmutableDictBase[_KT, _VT]):
dict.update(new, __d)
return new
- def _union_w_kw(self, __d=None, **kw):
+ def _union_w_kw(
+ self, __d: Optional[Mapping[_KT, _VT]] = None, **kw: _VT
+ ) -> immutabledict[_KT, _VT]:
# not sure if C version works correctly w/ this yet
if not __d and not kw:
return self
@@ -92,7 +98,9 @@ class immutabledict(ImmutableDictBase[_KT, _VT]):
dict.update(new, kw) # type: ignore
return new
- def merge_with(self, *dicts):
+ def merge_with(
+ self, *dicts: Optional[Mapping[_KT, _VT]]
+ ) -> immutabledict[_KT, _VT]:
new = None
for d in dicts:
if d:
@@ -105,7 +113,7 @@ class immutabledict(ImmutableDictBase[_KT, _VT]):
return new
- def __repr__(self):
+ def __repr__(self) -> str:
return "immutabledict(%s)" % dict.__repr__(self)
diff --git a/lib/sqlalchemy/util/deprecations.py b/lib/sqlalchemy/util/deprecations.py
index 7e1d3213a..a8e58a8bf 100644
--- a/lib/sqlalchemy/util/deprecations.py
+++ b/lib/sqlalchemy/util/deprecations.py
@@ -28,7 +28,6 @@ from . import compat
from .langhelpers import _hash_limit_string
from .langhelpers import _warnings_warn
from .langhelpers import decorator
-from .langhelpers import dynamic_property
from .langhelpers import inject_docstring_text
from .langhelpers import inject_param_text
from .. import exc
@@ -103,7 +102,7 @@ def deprecated_property(
add_deprecation_to_docstring: bool = True,
warning: Optional[Type[exc.SADeprecationWarning]] = None,
enable_warnings: bool = True,
-) -> Callable[[Callable[..., _T]], dynamic_property[_T]]:
+) -> Callable[[Callable[..., Any]], property]:
"""the @deprecated decorator with a @property.
E.g.::
@@ -131,8 +130,8 @@ def deprecated_property(
"""
- def decorate(fn: Callable[..., _T]) -> dynamic_property[_T]:
- return dynamic_property(
+ def decorate(fn: Callable[..., Any]) -> property:
+ return property(
deprecated(
version,
message=message,
diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py
index 43f9d5c73..1e79fd547 100644
--- a/lib/sqlalchemy/util/langhelpers.py
+++ b/lib/sqlalchemy/util/langhelpers.py
@@ -317,15 +317,17 @@ _P = compat_typing.ParamSpec("_P")
class PluginLoader:
- def __init__(self, group, auto_fn=None):
+ def __init__(
+ self, group: str, auto_fn: Optional[Callable[..., Any]] = None
+ ):
self.group = group
- self.impls = {}
+ self.impls: Dict[str, Any] = {}
self.auto_fn = auto_fn
def clear(self):
self.impls.clear()
- def load(self, name):
+ def load(self, name: str) -> Any:
if name in self.impls:
return self.impls[name]()
@@ -344,7 +346,7 @@ class PluginLoader:
"Can't load plugin: %s:%s" % (self.group, name)
)
- def register(self, name, modulepath, objname):
+ def register(self, name: str, modulepath: str, objname: str) -> None:
def load():
mod = __import__(modulepath)
for token in modulepath.split(".")[1:]:
@@ -444,7 +446,7 @@ def get_cls_kwargs(
return _set
-def get_func_kwargs(func):
+def get_func_kwargs(func: Callable[..., Any]) -> List[str]:
"""Return the set of legal kwargs for the given `func`.
Uses getargspec so is safe to call for methods, functions,
@@ -1125,22 +1127,13 @@ def as_interface(obj, cls=None, methods=None, required=None):
)
-Selfdynamic_property = TypeVar(
- "Selfdynamic_property", bound="dynamic_property[Any]"
-)
-
Selfmemoized_property = TypeVar(
"Selfmemoized_property", bound="memoized_property[Any]"
)
-class dynamic_property(Generic[_T]):
- """A read-only @property that is evaluated each time.
-
- This is mostly the same as @property except we can type it
- alongside memoized_property
-
- """
+class memoized_property(Generic[_T]):
+ """A read-only @property that is only evaluated once."""
fget: Callable[..., _T]
__doc__: Optional[str]
@@ -1153,27 +1146,6 @@ class dynamic_property(Generic[_T]):
@overload
def __get__(
- self: Selfdynamic_property, obj: None, cls: Any
- ) -> Selfdynamic_property:
- ...
-
- @overload
- def __get__(self, obj: Any, cls: Any) -> _T:
- ...
-
- def __get__(
- self: Selfdynamic_property, obj: Any, cls: Any
- ) -> Union[Selfdynamic_property, _T]:
- if obj is None:
- return self
- return self.fget(obj) # type: ignore[no-any-return]
-
-
-class memoized_property(dynamic_property[_T]):
- """A read-only @property that is only evaluated once."""
-
- @overload
- def __get__(
self: Selfmemoized_property, obj: None, cls: Any
) -> Selfmemoized_property:
...
@@ -1231,24 +1203,27 @@ def memoized_instancemethod(fn):
class HasMemoized:
- """A class that maintains the names of memoized elements in a
+ """A mixin class that maintains the names of memoized elements in a
collection for easy cache clearing, generative, etc.
"""
- __slots__ = ()
+ if not typing.TYPE_CHECKING:
+ # support classes that want to have __slots__ with an explicit
+ # slot for __dict__. Not sure if that requires base __slots__ here.
+ __slots__ = ()
_memoized_keys: FrozenSet[str] = frozenset()
- def _reset_memoizations(self):
+ def _reset_memoizations(self) -> None:
for elem in self._memoized_keys:
self.__dict__.pop(elem, None)
- def _assert_no_memoizations(self):
+ def _assert_no_memoizations(self) -> None:
for elem in self._memoized_keys:
assert elem not in self.__dict__
- def _set_memoized_attribute(self, key, value):
+ def _set_memoized_attribute(self, key: str, value: Any) -> None:
self.__dict__[key] = value
self._memoized_keys |= {key}
@@ -1342,7 +1317,7 @@ class MemoizedSlots:
# from paste.deploy.converters
-def asbool(obj):
+def asbool(obj: Any) -> bool:
if isinstance(obj, str):
obj = obj.strip().lower()
if obj in ["true", "yes", "on", "y", "t", "1"]:
@@ -1354,13 +1329,13 @@ def asbool(obj):
return bool(obj)
-def bool_or_str(*text):
+def bool_or_str(*text: str) -> Callable[[str], Union[str, bool]]:
"""Return a callable that will evaluate a string as
boolean, or one of a set of "alternate" string values.
"""
- def bool_or_value(obj):
+ def bool_or_value(obj: str) -> Union[str, bool]:
if obj in text:
return obj
else:
@@ -1369,7 +1344,7 @@ def bool_or_str(*text):
return bool_or_value
-def asint(value):
+def asint(value: Any) -> Optional[int]:
"""Coerce to integer."""
if value is None:
@@ -1377,7 +1352,13 @@ def asint(value):
return int(value)
-def coerce_kw_type(kw, key, type_, flexi_bool=True, dest=None):
+def coerce_kw_type(
+ kw: Dict[str, Any],
+ key: str,
+ type_: Type[Any],
+ flexi_bool: bool = True,
+ dest: Optional[Dict[str, Any]] = None,
+) -> None:
r"""If 'key' is present in dict 'kw', coerce its value to type 'type\_' if
necessary. If 'flexi_bool' is True, the string '0' is considered false
when coercing to boolean.
@@ -1397,7 +1378,7 @@ def coerce_kw_type(kw, key, type_, flexi_bool=True, dest=None):
dest[key] = type_(kw[key])
-def constructor_key(obj, cls):
+def constructor_key(obj: Any, cls: Type[Any]) -> Tuple[Any, ...]:
"""Produce a tuple structure that is cacheable using the __dict__ of
obj to retrieve values
@@ -1408,7 +1389,7 @@ def constructor_key(obj, cls):
)
-def constructor_copy(obj, cls, *args, **kw):
+def constructor_copy(obj: _T, cls: Type[_T], *args: Any, **kw: Any) -> _T:
"""Instantiate cls using the __dict__ of obj as constructor arguments.
Uses inspect to match the named arguments of ``cls``.
@@ -1422,7 +1403,7 @@ def constructor_copy(obj, cls, *args, **kw):
return cls(*args, **kw)
-def counter():
+def counter() -> Callable[[], int]:
"""Return a threadsafe counter function."""
lock = threading.Lock()
@@ -1436,47 +1417,51 @@ def counter():
return _next
-def duck_type_collection(specimen, default=None):
+def duck_type_collection(
+ specimen: Union[object, Type[Any]], default: Optional[Type[Any]] = None
+) -> Type[Any]:
"""Given an instance or class, guess if it is or is acting as one of
the basic collection types: list, set and dict. If the __emulates__
property is present, return that preferentially.
"""
+ if typing.TYPE_CHECKING:
+ return object
+ else:
+ if hasattr(specimen, "__emulates__"):
+ # canonicalize set vs sets.Set to a standard: the builtin set
+ if specimen.__emulates__ is not None and issubclass(
+ specimen.__emulates__, set
+ ):
+ return set
+ else:
+ return specimen.__emulates__
- if hasattr(specimen, "__emulates__"):
- # canonicalize set vs sets.Set to a standard: the builtin set
- if specimen.__emulates__ is not None and issubclass(
- specimen.__emulates__, set
- ):
+ isa = isinstance(specimen, type) and issubclass or isinstance
+ if isa(specimen, list):
+ return list
+ elif isa(specimen, set):
return set
+ elif isa(specimen, dict):
+ return dict
+
+ if hasattr(specimen, "append"):
+ return list
+ elif hasattr(specimen, "add"):
+ return set
+ elif hasattr(specimen, "set"):
+ return dict
else:
- return specimen.__emulates__
-
- isa = isinstance(specimen, type) and issubclass or isinstance
- if isa(specimen, list):
- return list
- elif isa(specimen, set):
- return set
- elif isa(specimen, dict):
- return dict
-
- if hasattr(specimen, "append"):
- return list
- elif hasattr(specimen, "add"):
- return set
- elif hasattr(specimen, "set"):
- return dict
- else:
- return default
+ return default
-def assert_arg_type(arg, argtype, name):
+def assert_arg_type(arg: Any, argtype: Type[Any], name: str) -> Any:
if isinstance(arg, argtype):
return arg
else:
if isinstance(argtype, tuple):
raise exc.ArgumentError(
"Argument '%s' is expected to be one of type %s, got '%s'"
- % (name, " or ".join("'%s'" % a for a in argtype), type(arg))
+ % (name, " or ".join("'%s'" % a for a in argtype), type(arg)) # type: ignore # noqa E501
)
else:
raise exc.ArgumentError(
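
With `dynamic_property` folded away, `memoized_property` stands alone as a `Generic[_T]` descriptor whose `__get__` overloads return the descriptor on class access and `_T` on instance access. A minimal memoizing descriptor in the same shape (a sketch, not the full implementation):

```python
from __future__ import annotations

from typing import Any, Callable, Generic, TypeVar, Union, overload

_T = TypeVar("_T")


class memoized_prop(Generic[_T]):
    """Minimal memoizing descriptor in the shape the diff annotates."""

    def __init__(self, fget: Callable[..., _T]) -> None:
        self.fget = fget
        self.__name__ = fget.__name__

    @overload
    def __get__(self, obj: None, cls: Any) -> memoized_prop[_T]:
        ...

    @overload
    def __get__(self, obj: Any, cls: Any) -> _T:
        ...

    def __get__(self, obj: Any, cls: Any) -> Union[memoized_prop[_T], _T]:
        if obj is None:
            return self
        # compute once, then cache in the instance __dict__; a non-data
        # descriptor is never consulted again once the name is shadowed
        obj.__dict__[self.__name__] = result = self.fget(obj)
        return result


class Thing:
    @memoized_prop
    def expensive(self) -> int:
        print("computing")
        return 42


t = Thing()
t.expensive  # prints "computing"
t.expensive  # cached; no print
```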
diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py
index ddda420db..ad9c8e531 100644
--- a/lib/sqlalchemy/util/typing.py
+++ b/lib/sqlalchemy/util/typing.py
@@ -13,7 +13,7 @@ from typing import Type
from typing import TypeVar
from typing import Union
-from typing_extensions import NotRequired # noqa
+from typing_extensions import NotRequired as NotRequired # noqa
from . import compat
diff --git a/pyproject.toml b/pyproject.toml
index e79c7292d..b2754b193 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,6 +40,14 @@ markers = [
[tool.pyright]
include = [
+ "lib/sqlalchemy/engine/base.py",
+ "lib/sqlalchemy/engine/events.py",
+ "lib/sqlalchemy/engine/interfaces.py",
+ "lib/sqlalchemy/engine/_py_row.py",
+ "lib/sqlalchemy/engine/result.py",
+ "lib/sqlalchemy/engine/row.py",
+ "lib/sqlalchemy/engine/util.py",
+ "lib/sqlalchemy/engine/url.py",
"lib/sqlalchemy/pool/",
"lib/sqlalchemy/event/",
"lib/sqlalchemy/events.py",
@@ -79,9 +87,19 @@ strict = true
# the whole library 100% strictly typed, so we have to tune this based on
# the type of module or package we are dealing with
+[[tool.mypy.overrides]]
+# ad-hoc ignores
+module = [
+ "sqlalchemy.engine.reflection", # interim, should be strict
+]
+
+ignore_errors = true
+
# strict checking
[[tool.mypy.overrides]]
module = [
+ "sqlalchemy.connectors.*",
+ "sqlalchemy.engine.*",
"sqlalchemy.pool.*",
"sqlalchemy.event.*",
"sqlalchemy.events",
@@ -95,11 +113,16 @@ strict = true
# partial checking, internals can be untyped
[[tool.mypy.overrides]]
-module="sqlalchemy.util.*"
+
+module = [
+ "sqlalchemy.util.*",
+ "sqlalchemy.engine.cursor",
+ "sqlalchemy.engine.default",
+]
+
+
ignore_errors = false
-# util is for internal use so we can get by without everything
-# being typed
allow_untyped_defs = true
check_untyped_defs = false
allow_untyped_calls = true
diff --git a/test/base/test_result.py b/test/base/test_result.py
index 8818ccb14..7a696d352 100644
--- a/test/base/test_result.py
+++ b/test/base/test_result.py
@@ -1,7 +1,6 @@
from sqlalchemy import exc
from sqlalchemy import testing
from sqlalchemy.engine import result
-from sqlalchemy.engine.row import Row
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import eq_
@@ -728,27 +727,6 @@ class ResultTest(fixtures.TestBase):
# still slices
eq_(m1.fetchone(), {"b": 1, "c": 2})
- def test_alt_row_fetch(self):
- class AppleRow(Row):
- def apple(self):
- return "apple"
-
- result = self._fixture(alt_row=AppleRow)
-
- row = result.all()[0]
- eq_(row.apple(), "apple")
-
- def test_alt_row_transform(self):
- class AppleRow(Row):
- def apple(self):
- return "apple"
-
- result = self._fixture(alt_row=AppleRow)
-
- row = result.columns("c", "a").all()[2]
- eq_(row.apple(), "apple")
- eq_(row, (2, 1))
-
def test_scalar_none_iterate(self):
result = self._fixture(
data=[
diff --git a/test/dialect/mssql/test_engine.py b/test/dialect/mssql/test_engine.py
index b40981a99..d54a37ceb 100644
--- a/test/dialect/mssql/test_engine.py
+++ b/test/dialect/mssql/test_engine.py
@@ -34,19 +34,19 @@ class ParseConnectTest(fixtures.TestBase):
dialect = pyodbc.dialect()
u = url.make_url("mssql+pyodbc://mydsn")
connection = dialect.create_connect_args(u)
- eq_([["dsn=mydsn;Trusted_Connection=Yes"], {}], connection)
+ eq_((("dsn=mydsn;Trusted_Connection=Yes",), {}), connection)
def test_pyodbc_connect_old_style_dsn_trusted(self):
dialect = pyodbc.dialect()
u = url.make_url("mssql+pyodbc:///?dsn=mydsn")
connection = dialect.create_connect_args(u)
- eq_([["dsn=mydsn;Trusted_Connection=Yes"], {}], connection)
+ eq_((("dsn=mydsn;Trusted_Connection=Yes",), {}), connection)
def test_pyodbc_connect_dsn_non_trusted(self):
dialect = pyodbc.dialect()
u = url.make_url("mssql+pyodbc://username:password@mydsn")
connection = dialect.create_connect_args(u)
- eq_([["dsn=mydsn;UID=username;PWD=password"], {}], connection)
+ eq_((("dsn=mydsn;UID=username;PWD=password",), {}), connection)
def test_pyodbc_connect_dsn_extra(self):
dialect = pyodbc.dialect()
@@ -66,13 +66,13 @@ class ParseConnectTest(fixtures.TestBase):
)
connection = dialect.create_connect_args(u)
eq_(
- [
- [
+ (
+ (
"DRIVER={SQL Server};Server=hostspec;Database=database;UI"
- "D=username;PWD=password"
- ],
+ "D=username;PWD=password",
+ ),
{},
- ],
+ ),
connection,
)
@@ -99,13 +99,13 @@ class ParseConnectTest(fixtures.TestBase):
)
eq_(
- [
- [
+ (
+ (
"Server=hostspec;Database=database;UI"
- "D=username;PWD=password"
- ],
+ "D=username;PWD=password",
+ ),
{},
- ],
+ ),
connection,
)
@@ -117,13 +117,13 @@ class ParseConnectTest(fixtures.TestBase):
)
connection = dialect.create_connect_args(u)
eq_(
- [
- [
+ (
+ (
"DRIVER={SQL Server};Server=hostspec,12345;Database=datab"
- "ase;UID=username;PWD=password"
- ],
+ "ase;UID=username;PWD=password",
+ ),
{},
- ],
+ ),
connection,
)
@@ -135,13 +135,13 @@ class ParseConnectTest(fixtures.TestBase):
)
connection = dialect.create_connect_args(u)
eq_(
- [
- [
+ (
+ (
"DRIVER={SQL Server};Server=hostspec;Database=database;UI"
- "D=username;PWD=password;port=12345"
- ],
+ "D=username;PWD=password;port=12345",
+ ),
{},
- ],
+ ),
connection,
)
@@ -193,13 +193,13 @@ class ParseConnectTest(fixtures.TestBase):
)
connection = dialect.create_connect_args(u)
eq_(
- [
- [
+ (
+ (
"DRIVER={SQL Server};Server=hostspec;Database=database;UI"
- "D=username;PWD=password"
- ],
+ "D=username;PWD=password",
+ ),
{},
- ],
+ ),
connection,
)
@@ -211,7 +211,7 @@ class ParseConnectTest(fixtures.TestBase):
)
connection = dialect.create_connect_args(u)
eq_(
- [["dsn=mydsn;Database=database;UID=username;PWD=password"], {}],
+ (("dsn=mydsn;Database=database;UID=username;PWD=password",), {}),
connection,
)
@@ -225,13 +225,13 @@ class ParseConnectTest(fixtures.TestBase):
)
connection = dialect.create_connect_args(u)
eq_(
- [
- [
+ (
+ (
"DRIVER={SQL Server};Server=hostspec;Database=database;UI"
- "D=username;PWD=password"
- ],
+ "D=username;PWD=password",
+ ),
{},
- ],
+ ),
connection,
)
@@ -248,14 +248,14 @@ class ParseConnectTest(fixtures.TestBase):
dialect = pyodbc.dialect()
connection = dialect.create_connect_args(u)
eq_(
- [
- [
+ (
+ (
"DRIVER={foob};Server=somehost%3BPORT%3D50001;"
"Database=somedb%3BPORT%3D50001;UID={someuser;PORT=50001};"
- "PWD={some{strange}}pw;PORT=50001}"
- ],
+ "PWD={some{strange}}pw;PORT=50001}",
+ ),
{},
- ],
+ ),
connection,
)
@@ -265,7 +265,7 @@ class ParseConnectTest(fixtures.TestBase):
u = url.make_url("mssql+pymssql://scott:tiger@somehost/test")
connection = dialect.create_connect_args(u)
eq_(
- [
+ (
[],
{
"host": "somehost",
@@ -273,14 +273,14 @@ class ParseConnectTest(fixtures.TestBase):
"user": "scott",
"database": "test",
},
- ],
+ ),
connection,
)
u = url.make_url("mssql+pymssql://scott:tiger@somehost:5000/test")
connection = dialect.create_connect_args(u)
eq_(
- [
+ (
[],
{
"host": "somehost:5000",
@@ -288,7 +288,7 @@ class ParseConnectTest(fixtures.TestBase):
"user": "scott",
"database": "test",
},
- ],
+ ),
connection,
)
@@ -584,7 +584,9 @@ class VersionDetectionTest(fixtures.TestBase):
)
)
),
- connection=Mock(getinfo=Mock(return_value=vers)),
+ connection=Mock(
+ dbapi_connection=Mock(getinfo=Mock(return_value=vers)),
+ ),
)
eq_(dialect._get_server_version_info(conn), expected)
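
The assertions above change from lists to tuples because `create_connect_args()` now returns an `(args, kwargs)` tuple, mirroring how the pair is ultimately applied to the DBAPI's `connect()`. A short consuming sketch, grounded in the first test above:

```python
from sqlalchemy.dialects.mssql import pyodbc
from sqlalchemy.engine import url

dialect = pyodbc.dialect()
u = url.make_url("mssql+pyodbc://mydsn")
cargs, cparams = dialect.create_connect_args(u)
assert cargs == ("dsn=mydsn;Trusted_Connection=Yes",)
# ultimately applied as: dbapi.connect(*cargs, **cparams)
```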
diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py
index 613fc80a5..5a92ae6fe 100644
--- a/test/dialect/postgresql/test_dialect.py
+++ b/test/dialect/postgresql/test_dialect.py
@@ -368,11 +368,11 @@ class ExecuteManyMode:
mock.call(
mock.ANY,
stmt,
- (
+ [
{"x": "x1", "y": "y1"},
{"x": "x2", "y": "y2"},
{"x": "x3", "y": "y3"},
- ),
+ ],
**expected_kwargs,
)
],
@@ -417,11 +417,11 @@ class ExecuteManyMode:
mock.call(
mock.ANY,
stmt,
- (
+ [
{"x": "x1", "y": "y1"},
{"x": "x2", "y": "y2"},
{"x": "x3", "y": "y3"},
- ),
+ ],
**expected_kwargs,
)
],
@@ -470,11 +470,11 @@ class ExecuteManyMode:
mock.call(
mock.ANY,
stmt,
- (
+ [
{"x": "x1", "y": "y1"},
{"x": "x2", "y": "y2"},
{"x": "x3", "y": "y3"},
- ),
+ ],
**expected_kwargs,
)
],
@@ -524,10 +524,10 @@ class ExecuteManyMode:
mock.call(
mock.ANY,
stmt,
- (
+ [
{"xval": "x1", "yval": "y5"},
{"xval": "x3", "yval": "y6"},
- ),
+ ],
**expected_kwargs,
)
],
@@ -714,11 +714,11 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest):
mock.call(
mock.ANY,
"INSERT INTO data (id, x, y, z) VALUES %s",
- (
+ [
{"id": 1, "y": "y1", "z": 1},
{"id": 2, "y": "y2", "z": 2},
{"id": 3, "y": "y3", "z": 3},
- ),
+ ],
template="(%(id)s, (SELECT 5 \nFROM data), %(y)s, %(z)s)",
fetch=False,
page_size=connection.dialect.executemany_values_page_size,
diff --git a/test/engine/test_deprecations.py b/test/engine/test_deprecations.py
index 8dc1f0f48..e1c610701 100644
--- a/test/engine/test_deprecations.py
+++ b/test/engine/test_deprecations.py
@@ -87,6 +87,93 @@ class ConnectionlessDeprecationTest(fixtures.TestBase):
class CreateEngineTest(fixtures.TestBase):
+ @testing.requires.sqlite
+ def test_dbapi_clsmethod_renamed(self):
+ """The dbapi() class method is renamed to import_dbapi(),
+ so that the .dbapi attribute can be exclusively an instance
+ attribute.
+
+ """
+
+ from sqlalchemy.dialects.sqlite import pysqlite
+ from sqlalchemy.dialects import registry
+
+ canary = mock.Mock()
+
+ class MyDialect(pysqlite.SQLiteDialect_pysqlite):
+ @classmethod
+ def dbapi(cls):
+ canary()
+ return __import__("sqlite3")
+
+ tokens = __name__.split(".")
+
+ global dialect
+ dialect = MyDialect
+
+ registry.register(
+ "mockdialect1.sqlite", ".".join(tokens[0:-1]), tokens[-1]
+ )
+
+ with expect_deprecated(
+ r"The dbapi\(\) classmethod on dialect classes has "
+ r"been renamed to import_dbapi\(\). Implement an "
+ r"import_dbapi\(\) classmethod directly on class "
+ r".*MyDialect.* to remove this warning; the old "
+ r".dbapi\(\) classmethod may be maintained for backwards "
+ r"compatibility."
+ ):
+ e = create_engine("mockdialect1+sqlite://")
+
+ eq_(canary.mock_calls, [mock.call()])
+ sqlite3 = __import__("sqlite3")
+ is_(e.dialect.dbapi, sqlite3)
+
+ @testing.requires.sqlite
+ def test_no_warning_for_dual_dbapi_clsmethod(self):
+ """The dbapi() class method is renamed to import_dbapi(),
+ so that the .dbapi attribute can be exclusively an instance
+ attribute.
+
+ Dialect classes will likely have both a dbapi() classmethod
+ and an import_dbapi() classmethod to maintain
+ cross-compatibility. Make sure these updated classes don't get a
+ warning and that the new method is used.
+
+ """
+
+ from sqlalchemy.dialects.sqlite import pysqlite
+ from sqlalchemy.dialects import registry
+
+ canary = mock.Mock()
+
+ class MyDialect(pysqlite.SQLiteDialect_pysqlite):
+ @classmethod
+ def dbapi(cls):
+ canary.dbapi()
+ return __import__("sqlite3")
+
+ @classmethod
+ def import_dbapi(cls):
+ canary.import_dbapi()
+ return __import__("sqlite3")
+
+ tokens = __name__.split(".")
+
+ global dialect
+ dialect = MyDialect
+
+ registry.register(
+ "mockdialect2.sqlite", ".".join(tokens[0:-1]), tokens[-1]
+ )
+
+ # no warning
+ e = create_engine("mockdialect2+sqlite://")
+
+ eq_(canary.mock_calls, [mock.call.import_dbapi()])
+ sqlite3 = __import__("sqlite3")
+ is_(e.dialect.dbapi, sqlite3)
+
def test_strategy_keyword_mock(self):
def executor(x, y):
pass
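
As a compact summary of what these two tests pin down, a dialect that opts into the new hook looks like this (based directly on the test code above):

```python
from sqlalchemy.dialects.sqlite import pysqlite


class MyDialect(pysqlite.SQLiteDialect_pysqlite):
    """Implements the renamed hook; no deprecation warning is emitted."""

    @classmethod
    def import_dbapi(cls):
        # replaces the legacy dbapi() classmethod, freeing .dbapi to be
        # a plain instance attribute
        return __import__("sqlite3")
```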
diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py
index 59bc4863f..dbd957703 100644
--- a/test/engine/test_execute.py
+++ b/test/engine/test_execute.py
@@ -2285,6 +2285,40 @@ class EngineEventsTest(fixtures.TestBase):
[call(c2, {"c1": "opt_c1"}), call(c4, {"c3": "opt_c3"})],
)
+ def test_execution_options_modify_inplace(self):
+ engine = engines.testing_engine()
+
+ @event.listens_for(engine, "set_engine_execution_options")
+ def engine_tracker(conn, opt):
+ opt["engine_tracked"] = True
+
+ @event.listens_for(engine, "set_connection_execution_options")
+ def conn_tracker(conn, opt):
+ opt["conn_tracked"] = True
+
+ with mock.patch.object(
+ engine.dialect, "set_connection_execution_options"
+ ) as conn_opt, mock.patch.object(
+ engine.dialect, "set_engine_execution_options"
+ ) as engine_opt:
+ e2 = engine.execution_options(e1="opt_e1")
+ c1 = engine.connect()
+ c2 = c1.execution_options(c1="opt_c1")
+
+ is_not(e2, engine)
+ is_(c1, c2)
+
+ eq_(e2._execution_options, {"e1": "opt_e1", "engine_tracked": True})
+ eq_(c2._execution_options, {"c1": "opt_c1", "conn_tracked": True})
+ eq_(
+ engine_opt.mock_calls,
+ [mock.call(e2, {"e1": "opt_e1", "engine_tracked": True})],
+ )
+ eq_(
+ conn_opt.mock_calls,
+ [mock.call(c1, {"c1": "opt_c1", "conn_tracked": True})],
+ )
+
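+ # A hedged usage sketch of the hook verified above ("my_engine" is
+ # hypothetical): both set_engine_execution_options and
+ # set_connection_execution_options listeners receive the actual
+ # options dict, so in-place mutation takes effect.
+ #
+ #     @event.listens_for(my_engine, "set_connection_execution_options")
+ #     def _defaults(conn, opts):
+ #         opts.setdefault("isolation_level", "AUTOCOMMIT")
+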
@testing.requires.sequences
@testing.provide_metadata
def test_cursor_execute(self):
diff --git a/test/engine/test_parseconnect.py b/test/engine/test_parseconnect.py
index 4c378dda1..23be61aaf 100644
--- a/test/engine/test_parseconnect.py
+++ b/test/engine/test_parseconnect.py
@@ -1111,7 +1111,7 @@ class TestGetDialect(fixtures.TestBase):
class MockDialect(DefaultDialect):
@classmethod
- def dbapi(cls, **kw):
+ def import_dbapi(cls, **kw):
return MockDBAPI()
diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py
index c1f0639bb..b8d9a3618 100644
--- a/test/engine/test_reconnect.py
+++ b/test/engine/test_reconnect.py
@@ -1023,6 +1023,23 @@ class RealReconnectTest(fixtures.TestBase):
eq_(conn.execute(select(1)).scalar(), 1)
assert not conn.invalidated
+ def test_detach_invalidated(self):
+ with self.engine.connect() as conn:
+ conn.invalidate()
+ with expect_raises_message(
+ exc.InvalidRequestError,
+ "Can't detach an invalidated Connection",
+ ):
+ conn.detach()
+
+ def test_detach_closed(self):
+ with self.engine.connect() as conn:
+ pass
+ with expect_raises_message(
+ exc.ResourceClosedError, "This Connection is closed"
+ ):
+ conn.detach()
+
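+ # Hedged sketch of the supported path: detach() requires a live,
+ # non-invalidated connection.
+ #
+ #     with self.engine.connect() as conn:
+ #         conn.detach()  # DBAPI connection leaves the pool
+ #         ...            # on close() it is discarded, not returned
+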
@testing.requires.independent_connections
def test_multiple_invalidate(self):
c1 = self.engine.connect()
@@ -1078,8 +1095,23 @@ class RealReconnectTest(fixtures.TestBase):
conn.begin()
trans2 = conn.begin_nested()
conn.invalidate()
+
+ # this passes silently, as rollback() is often invoked
+ # within error-handling schemes
trans2.rollback()
+ # still invalid though
+ with expect_raises(exc.PendingRollbackError):
+ conn.begin_nested()
+
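+ # Hedged sketch: the connection becomes usable again only once the
+ # failed transaction is rolled back explicitly.
+ #
+ #     conn.rollback()      # clears the pending-rollback state
+ #     conn.begin_nested()  # permitted again
+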
+ def test_no_begin_on_invalid(self):
+ with self.engine.connect() as conn:
+ conn.begin()
+ conn.invalidate()
+
+ with expect_raises(exc.PendingRollbackError):
+ conn.commit()
+
def test_invalidate_twice(self):
with self.engine.connect() as conn:
conn.invalidate()
diff --git a/test/profiles.txt b/test/profiles.txt
index 67f155eba..750b57780 100644
--- a/test/profiles.txt
+++ b/test/profiles.txt
@@ -98,48 +98,48 @@ test.aaa_profiling.test_misc.EnumTest.test_create_enum_from_pep_435_w_expensive_
# TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_w_annotation
-test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_w_annotation x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 50235
-test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_w_annotation x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 62055
+test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_w_annotation x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 50735
+test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_w_annotation x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 61045
# TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_wo_annotation
-test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_wo_annotation x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 49335
-test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_wo_annotation x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 61155
+test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_wo_annotation x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 49435
+test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_wo_annotation x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 59745
# TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_w_annotations
-test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 53235
-test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 63155
+test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 53335
+test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 61745
# TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_wo_annotations
-test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 52335
-test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 62255
+test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 52435
+test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 60845
# TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle
-test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 45435
-test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 49255
+test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 45035
+test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 48345
# TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_w_annotations
-test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 48235
-test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 57555
+test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 48335
+test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 56145
# TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_wo_annotations
-test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 47335
-test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 56655
+test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 47435
+test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 55245
# TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_w_annotations
-test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 34205
-test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 37905
+test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 33805
+test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 37005
# TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_wo_annotations
-test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 33305
-test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 37005
+test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 32905
+test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 36105
# TEST: test.aaa_profiling.test_orm.AttributeOverheadTest.test_attribute_set
@@ -163,13 +163,13 @@ test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_unbound_branching
# TEST: test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline
-test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 15227
-test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 34246
+test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 15313
+test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 26332
# TEST: test.aaa_profiling.test_orm.DeferOptionsTest.test_defer_many_cols
-test.aaa_profiling.test_orm.DeferOptionsTest.test_defer_many_cols x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 21393
-test.aaa_profiling.test_orm.DeferOptionsTest.test_defer_many_cols x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 28412
+test.aaa_profiling.test_orm.DeferOptionsTest.test_defer_many_cols x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 21377
+test.aaa_profiling.test_orm.DeferOptionsTest.test_defer_many_cols x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 26396
# TEST: test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased
@@ -188,23 +188,23 @@ test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_plain x86_64_linux_cpy
# TEST: test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d
-test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 98439
-test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 103939
+test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 98506
+test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 104006
# TEST: test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased
-test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 96819
-test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 102304
+test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 96844
+test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 102344
# TEST: test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query
-test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 520593
-test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 522453
+test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 520615
+test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 522475
# TEST: test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results
-test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 431505
-test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 464305
+test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 431905
+test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 450605
# TEST: test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_identity
@@ -213,18 +213,18 @@ test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_
# TEST: test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity
-test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 106782
-test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 115789
+test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 106870
+test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 115127
# TEST: test.aaa_profiling.test_orm.MergeBackrefsTest.test_merge_pending_with_all_pks
-test.aaa_profiling.test_orm.MergeBackrefsTest.test_merge_pending_with_all_pks x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 19931
-test.aaa_profiling.test_orm.MergeBackrefsTest.test_merge_pending_with_all_pks x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 21463
+test.aaa_profiling.test_orm.MergeBackrefsTest.test_merge_pending_with_all_pks x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 20030
+test.aaa_profiling.test_orm.MergeBackrefsTest.test_merge_pending_with_all_pks x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 21434
# TEST: test.aaa_profiling.test_orm.MergeTest.test_merge_load
-test.aaa_profiling.test_orm.MergeTest.test_merge_load x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 1362
-test.aaa_profiling.test_orm.MergeTest.test_merge_load x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 1458
+test.aaa_profiling.test_orm.MergeTest.test_merge_load x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 1366
+test.aaa_profiling.test_orm.MergeTest.test_merge_load x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 1455
# TEST: test.aaa_profiling.test_orm.MergeTest.test_merge_no_load
@@ -233,18 +233,18 @@ test.aaa_profiling.test_orm.MergeTest.test_merge_no_load x86_64_linux_cpython_3.
# TEST: test.aaa_profiling.test_orm.QueryTest.test_query_cols
-test.aaa_profiling.test_orm.QueryTest.test_query_cols x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 6109
-test.aaa_profiling.test_orm.QueryTest.test_query_cols x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 7329
+test.aaa_profiling.test_orm.QueryTest.test_query_cols x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 6167
+test.aaa_profiling.test_orm.QueryTest.test_query_cols x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 6987
# TEST: test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results
-test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 258605
-test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 281805
+test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 259205
+test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 278405
# TEST: test.aaa_profiling.test_orm.SessionTest.test_expire_lots
-test.aaa_profiling.test_orm.SessionTest.test_expire_lots x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 1276
-test.aaa_profiling.test_orm.SessionTest.test_expire_lots x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 1268
+test.aaa_profiling.test_orm.SessionTest.test_expire_lots x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 1252
+test.aaa_profiling.test_orm.SessionTest.test_expire_lots x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 1260
# TEST: test.aaa_profiling.test_pool.QueuePoolTest.test_first_connect
diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py
index e5b1a0a26..ff70fc184 100644
--- a/test/sql/test_resultset.py
+++ b/test/sql/test_resultset.py
@@ -1612,15 +1612,15 @@ class CursorResultTest(fixtures.TablesTest):
eq_(dict(row._mapping), {"a": "av", "b": "bv", "count": "cv"})
with assertions.expect_raises_message(
TypeError,
- "TypeError: tuple indices must be integers or slices, not str",
+ "tuple indices must be integers or slices, not str",
):
eq_(row["a"], "av")
with assertions.expect_raises_message(
TypeError,
- "TypeError: tuple indices must be integers or slices, not str",
+ "tuple indices must be integers or slices, not str",
):
eq_(row["count"], "cv")
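+ # (hedged sketch) string access now goes through ._mapping:
+ #     eq_(row._mapping["count"], "cv")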
@@ -3197,8 +3197,7 @@ class GenerativeResultTest(fixtures.TablesTest):
all_ = result.columns(*columns).all()
eq_(all_, expected)
- # ensure Row / LegacyRow comes out with .columns
- assert type(all_[0]) is result._process_row
+ assert type(all_[0]) is Row
def test_columns_twice(self, connection):
users = self.tables.users
@@ -3216,8 +3215,7 @@ class GenerativeResultTest(fixtures.TablesTest):
)
eq_(all_, [("jack", 1)])
- # ensure Row / LegacyRow comes out with .columns
- assert type(all_[0]) is result._process_row
+ assert type(all_[0]) is Row
def test_columns_plus_getter(self, connection):
users = self.tables.users