diff options
Diffstat (limited to 'lib')
| -rw-r--r-- | lib/sqlalchemy/dialects/mysql/aiomysql.py | 7 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/sqlite/__init__.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/sqlite/aiosqlite.py | 331 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/sqlite/provision.py | 43 | ||||
| -rw-r--r-- | lib/sqlalchemy/pool/base.py | 7 | ||||
| -rw-r--r-- | lib/sqlalchemy/pool/impl.py | 35 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/engines.py | 14 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/fixtures.py | 7 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/requirements.py | 10 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/suite/test_dialect.py | 14 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/suite/test_results.py | 13 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/concurrency.py | 2 |
12 files changed, 456 insertions, 31 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/aiomysql.py b/lib/sqlalchemy/dialects/mysql/aiomysql.py index cab6df499..c8c7c0f97 100644 --- a/lib/sqlalchemy/dialects/mysql/aiomysql.py +++ b/lib/sqlalchemy/dialects/mysql/aiomysql.py @@ -82,6 +82,13 @@ class AsyncAdapt_aiomysql_cursor: return self._cursor.lastrowid def close(self): + # note we aren't actually closing the cursor here, + # we are just letting GC do it. to allow this to be async + # we would need the Result to change how it does "Safe close cursor". + # MySQL "cursors" don't actually have state to be "closed" besides + # exhausting rows, which we already have done for sync cursor. + # another option would be to emulate aiosqlite dialect and assign + # cursor only if we are doing server side cursor operation. self._rows[:] = [] def execute(self, operation, parameters=None): diff --git a/lib/sqlalchemy/dialects/sqlite/__init__.py b/lib/sqlalchemy/dialects/sqlite/__init__.py index d12203cbd..8b24a19fd 100644 --- a/lib/sqlalchemy/dialects/sqlite/__init__.py +++ b/lib/sqlalchemy/dialects/sqlite/__init__.py @@ -26,6 +26,10 @@ from .base import TIMESTAMP from .base import VARCHAR from .dml import Insert from .dml import insert +from ...util import compat + +if compat.py3k: + from . import aiosqlite # noqa # default dialect base.dialect = dialect = pysqlite.dialect diff --git a/lib/sqlalchemy/dialects/sqlite/aiosqlite.py b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py new file mode 100644 index 000000000..e4b7d1d52 --- /dev/null +++ b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py @@ -0,0 +1,331 @@ +# sqlite/aiosqlite.py +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +r""" + +.. dialect:: sqlite+aiosqlite + :name: aiosqlite + :dbapi: aiosqlite + :connectstring: sqlite+aiosqlite:///file_path + :url: https://pypi.org/project/aiosqlite/ + +The aiosqlite dialect provides support for the SQLAlchemy asyncio interface +running on top of pysqlite. + +aiosqlite is a wrapper around pysqlite that uses a background thread for +each connection. It does not actually use non-blocking IO, as SQLite +databases are not socket-based. However it does provide a working asyncio +interface that's useful for testing and prototyping purposes. + +Using a special asyncio mediation layer, the aiosqlite dialect is usable +as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>` +extension package. + +This dialect should normally be used only with the +:func:`_asyncio.create_async_engine` engine creation function:: + + from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("sqlite+aiosqlite:///filename") + +The URL passes through all arguments to the ``pysqlite`` driver, so all +connection arguments are the same as they are for that of :ref:`pysqlite`. + + +""" # noqa + +from .base import SQLiteExecutionContext +from .pysqlite import SQLiteDialect_pysqlite +from ... import pool +from ... import util +from ...util.concurrency import await_fallback +from ...util.concurrency import await_only + + +class AsyncAdapt_aiosqlite_cursor: + __slots__ = ( + "_adapt_connection", + "_connection", + "description", + "await_", + "_rows", + "arraysize", + "rowcount", + "lastrowid", + ) + + server_side = False + + def __init__(self, adapt_connection): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self.await_ = adapt_connection.await_ + self.arraysize = 1 + self.rowcount = -1 + self.description = None + self._rows = [] + + def close(self): + self._rows[:] = [] + + def execute(self, operation, parameters=None): + try: + _cursor = self.await_(self._connection.cursor()) + + if parameters is None: + self.await_(_cursor.execute(operation)) + else: + self.await_(_cursor.execute(operation, parameters)) + + if _cursor.description: + self.description = _cursor.description + self.lastrowid = self.rowcount = -1 + + if not self.server_side: + self._rows = self.await_(_cursor.fetchall()) + else: + self.description = None + self.lastrowid = _cursor.lastrowid + self.rowcount = _cursor.rowcount + + if not self.server_side: + self.await_(_cursor.close()) + else: + self._cursor = _cursor + except Exception as error: + self._adapt_connection._handle_exception(error) + + def executemany(self, operation, seq_of_parameters): + try: + _cursor = self.await_(self._connection.cursor()) + self.await_(_cursor.executemany(operation, seq_of_parameters)) + self.description = None + self.lastrowid = _cursor.lastrowid + self.rowcount = _cursor.rowcount + self.await_(_cursor.close()) + except Exception as error: + self._adapt_connection._handle_exception(error) + + def setinputsizes(self, *inputsizes): + pass + + def __iter__(self): + while self._rows: + yield self._rows.pop(0) + + def fetchone(self): + if self._rows: + return self._rows.pop(0) + else: + return None + + def fetchmany(self, size=None): + if size is None: + size = self.arraysize + + retval = self._rows[0:size] + self._rows[:] = self._rows[size:] + return retval + + def fetchall(self): + retval = self._rows[:] + self._rows[:] = [] + return retval + + +class AsyncAdapt_aiosqlite_ss_cursor(AsyncAdapt_aiosqlite_cursor): + __slots__ = "_cursor" + + server_side = True + + def __init__(self, *arg, **kw): + super().__init__(*arg, **kw) + self._cursor = None + + def close(self): + if self._cursor is not None: + self.await_(self._cursor.close()) + self._cursor = None + + def fetchone(self): + return self.await_(self._cursor.fetchone()) + + def fetchmany(self, size=None): + if size is None: + size = self.arraysize + return self.await_(self._cursor.fetchmany(size=size)) + + def fetchall(self): + return self.await_(self._cursor.fetchall()) + + +class AsyncAdapt_aiosqlite_connection: + await_ = staticmethod(await_only) + __slots__ = ("dbapi", "_connection") + + def __init__(self, dbapi, connection): + self.dbapi = dbapi + self._connection = connection + + @property + def isolation_level(self): + return self._connection.isolation_level + + @isolation_level.setter + def isolation_level(self, value): + try: + self._connection.isolation_level = value + except Exception as error: + self._handle_exception(error) + + def create_function(self, *args, **kw): + try: + self.await_(self._connection.create_function(*args, **kw)) + except Exception as error: + self._handle_exception(error) + + def cursor(self, server_side=False): + if server_side: + return AsyncAdapt_aiosqlite_ss_cursor(self) + else: + return AsyncAdapt_aiosqlite_cursor(self) + + def execute(self, *args, **kw): + return self.await_(self._connection.execute(*args, **kw)) + + def rollback(self): + try: + self.await_(self._connection.rollback()) + except Exception as error: + self._handle_exception(error) + + def commit(self): + try: + self.await_(self._connection.commit()) + except Exception as error: + self._handle_exception(error) + + def close(self): + # print(">close", self) + try: + self.await_(self._connection.close()) + except Exception as error: + self._handle_exception(error) + + def _handle_exception(self, error): + if ( + isinstance(error, ValueError) + and error.args[0] == "no active connection" + ): + util.raise_( + self.dbapi.sqlite.OperationalError("no active connection"), + from_=error, + ) + else: + raise error + + +class AsyncAdaptFallback_aiosqlite_connection(AsyncAdapt_aiosqlite_connection): + __slots__ = () + + await_ = staticmethod(await_fallback) + + +class AsyncAdapt_aiosqlite_dbapi: + def __init__(self, aiosqlite, sqlite): + self.aiosqlite = aiosqlite + self.sqlite = sqlite + self.paramstyle = "qmark" + self._init_dbapi_attributes() + + def _init_dbapi_attributes(self): + for name in ( + "DatabaseError", + "Error", + "IntegrityError", + "NotSupportedError", + "OperationalError", + "ProgrammingError", + "sqlite_version", + "sqlite_version_info", + ): + setattr(self, name, getattr(self.aiosqlite, name)) + + for name in ("PARSE_COLNAMES", "PARSE_DECLTYPES"): + setattr(self, name, getattr(self.sqlite, name)) + + for name in ("Binary",): + setattr(self, name, getattr(self.sqlite, name)) + + def connect(self, *arg, **kw): + async_fallback = kw.pop("async_fallback", False) + + # Q. WHY do we need this? + # A. Because there is no way to set connection.isolation_level + # otherwise + # Q. BUT HOW do you know it is SAFE ????? + # A. The only operation that isn't safe is the isolation level set + # operation which aiosqlite appears to have let slip through even + # though pysqlite appears to do check_same_thread for this. + # All execute operations etc. should be safe because they all + # go through the single executor thread. + + kw["check_same_thread"] = False + + connection = self.aiosqlite.connect(*arg, **kw) + + # it's a Thread. you'll thank us later + connection.daemon = True + + if util.asbool(async_fallback): + return AsyncAdaptFallback_aiosqlite_connection( + self, + await_fallback(connection), + ) + else: + return AsyncAdapt_aiosqlite_connection( + self, + await_only(connection), + ) + + +class SQLiteExecutionContext_aiosqlite(SQLiteExecutionContext): + def create_server_side_cursor(self): + return self._dbapi_connection.cursor(server_side=True) + + +class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite): + driver = "aiosqlite" + + is_async = True + + supports_server_side_cursors = True + + execution_ctx_cls = SQLiteExecutionContext_aiosqlite + + @classmethod + def dbapi(cls): + return AsyncAdapt_aiosqlite_dbapi( + __import__("aiosqlite"), __import__("sqlite3") + ) + + @classmethod + def get_pool_class(cls, url): + if cls._is_url_file_db(url): + return pool.NullPool + else: + return pool.StaticPool + + def is_disconnect(self, e, connection, cursor): + if isinstance( + e, self.dbapi.OperationalError + ) and "no active connection" in str(e): + return True + + return super().is_disconnect(e, connection, cursor) + + +dialect = SQLiteDialect_aiosqlite diff --git a/lib/sqlalchemy/dialects/sqlite/provision.py b/lib/sqlalchemy/dialects/sqlite/provision.py index a481be27e..d0d12695d 100644 --- a/lib/sqlalchemy/dialects/sqlite/provision.py +++ b/lib/sqlalchemy/dialects/sqlite/provision.py @@ -11,13 +11,22 @@ from ...testing.provision import stop_test_class_outside_fixtures from ...testing.provision import temp_table_keyword_args +# likely needs a generate_driver_url() def here for the --dbdriver part to +# work + +_drivernames = set() + + @follower_url_from_main.for_db("sqlite") def _sqlite_follower_url_from_main(url, ident): url = sa_url.make_url(url) if not url.database or url.database == ":memory:": return url else: - return sa_url.make_url("sqlite:///%s.db" % ident) + _drivernames.add(url.get_driver_name()) + return sa_url.make_url( + "sqlite+%s:///%s.db" % (url.get_driver_name(), ident) + ) @post_configure_engine.for_db("sqlite") @@ -35,12 +44,13 @@ def _sqlite_post_configure_engine(url, engine, follower_ident): # expected to be already present, so for now it just stays # in a given checkout directory. dbapi_connection.execute( - 'ATTACH DATABASE "test_schema.db" AS test_schema' + 'ATTACH DATABASE "%s_test_schema.db" AS test_schema' + % (engine.driver,) ) else: dbapi_connection.execute( - 'ATTACH DATABASE "%s_test_schema.db" AS test_schema' - % follower_ident + 'ATTACH DATABASE "%s_%s_test_schema.db" AS test_schema' + % (follower_ident, engine.driver) ) @@ -51,7 +61,10 @@ def _sqlite_create_db(cfg, eng, ident): @drop_db.for_db("sqlite") def _sqlite_drop_db(cfg, eng, ident): - for path in ["%s.db" % ident, "%s_test_schema.db" % ident]: + for path in [ + "%s.db" % ident, + "%s_%s_test_schema.db" % (ident, eng.driver), + ]: if os.path.exists(path): log.info("deleting SQLite database file: %s" % path) os.remove(path) @@ -71,9 +84,9 @@ def stop_test_class_outside_fixtures(config, db, cls): # some sqlite file tests are not cleaning up well yet, so do this # just to make things simple for now - for file in files: - if file: - os.remove(file) + for file_ in files: + if file_ and os.path.exists(file_): + os.remove(file_) @temp_table_keyword_args.for_db("sqlite") @@ -89,7 +102,19 @@ def _reap_sqlite_dbs(url, idents): for ident in idents: # we don't have a config so we can't call _sqlite_drop_db due to the # decorator - for path in ["%s.db" % ident, "%s_test_schema.db" % ident]: + for path in ( + [ + "%s.db" % ident, + ] + + [ + "%s_test_schema.db" % (drivername,) + for drivername in _drivernames + ] + + [ + "%s_%s_test_schema.db" % (ident, drivername) + for drivername in _drivernames + ] + ): if os.path.exists(path): log.info("deleting SQLite database file: %s" % path) os.remove(path) diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py index 6ec489604..d14316fdb 100644 --- a/lib/sqlalchemy/pool/base.py +++ b/lib/sqlalchemy/pool/base.py @@ -574,6 +574,13 @@ class _ConnectionRecord(object): self.__connect() return self.connection + def _is_hard_or_soft_invalidated(self): + return ( + self.connection is None + or self.__pool._invalidate_time > self.starttime + or (self._soft_invalidate_time > self.starttime) + ) + def __close(self): self.finalize_callback.clear() if self.__pool.dispatch.close: diff --git a/lib/sqlalchemy/pool/impl.py b/lib/sqlalchemy/pool/impl.py index 08371a31a..730293273 100644 --- a/lib/sqlalchemy/pool/impl.py +++ b/lib/sqlalchemy/pool/impl.py @@ -395,15 +395,11 @@ class StaticPool(Pool): """A Pool of exactly one connection, used for all requests. Reconnect-related functions such as ``recycle`` and connection - invalidation (which is also used to support auto-reconnect) are not - currently supported by this Pool implementation but may be implemented - in a future release. + invalidation (which is also used to support auto-reconnect) are only + partially supported right now and may not yield good results. - """ - @util.memoized_property - def _conn(self): - return self._creator() + """ @util.memoized_property def connection(self): @@ -413,9 +409,12 @@ class StaticPool(Pool): return "StaticPool" def dispose(self): - if "_conn" in self.__dict__: - self._conn.close() - self._conn = None + if ( + "connection" in self.__dict__ + and self.connection.connection is not None + ): + self.connection.close() + del self.__dict__["connection"] def recreate(self): self.logger.info("Pool recreating") @@ -430,14 +429,26 @@ class StaticPool(Pool): dialect=self._dialect, ) + def _transfer_from(self, other_static_pool): + # used by the test suite to make a new engine / pool without + # losing the state of an existing SQLite :memory: connection + self._invoke_creator = ( + lambda crec: other_static_pool.connection.connection + ) + def _create_connection(self): - return self._conn + raise NotImplementedError() def _do_return_conn(self, conn): pass def _do_get(self): - return self.connection + rec = self.connection + if rec._is_hard_or_soft_invalidated(): + del self.__dict__["connection"] + rec = self.connection + + return rec class AssertionPool(Pool): diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index a313c298a..3faf96857 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -266,7 +266,13 @@ def reconnecting_engine(url=None, options=None): return engine -def testing_engine(url=None, options=None, future=None, asyncio=False): +def testing_engine( + url=None, + options=None, + future=None, + asyncio=False, + transfer_staticpool=False, +): """Produce an engine configured by --options with optional overrides.""" if asyncio: @@ -300,6 +306,12 @@ def testing_engine(url=None, options=None, future=None, asyncio=False): engine = create_engine(url, **options) + if transfer_staticpool: + from sqlalchemy.pool import StaticPool + + if config.db is not None and isinstance(config.db.pool, StaticPool): + engine.pool._transfer_from(config.db.pool) + if scope == "global": if asyncio: engine.sync_engine._has_events = True diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index f47277b4a..c3eb1b363 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -51,6 +51,13 @@ class TestBase(object): assert val, msg @config.fixture() + def connection_no_trans(self): + eng = getattr(self, "bind", None) or config.db + + with eng.connect() as conn: + yield conn + + @config.fixture() def connection(self): global _connection_fixture_connection diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index de2b8f12c..208ba0091 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -18,6 +18,7 @@ to provide specific inclusion/exclusions. import platform import sys +from sqlalchemy.pool.impl import QueuePool from . import exclusions from .. import util @@ -117,6 +118,15 @@ class SuiteRequirements(Requirements): ) @property + def queue_pool(self): + """target database is using QueuePool""" + + def go(config): + return isinstance(config.db.pool, QueuePool) + + return exclusions.only_if(go) + + @property def self_referential_foreign_keys(self): """Target database must support self-referential foreign keys.""" diff --git a/lib/sqlalchemy/testing/suite/test_dialect.py b/lib/sqlalchemy/testing/suite/test_dialect.py index a236b1076..c2c17d0dd 100644 --- a/lib/sqlalchemy/testing/suite/test_dialect.py +++ b/lib/sqlalchemy/testing/suite/test_dialect.py @@ -180,8 +180,8 @@ class AutocommitIsolationTest(fixtures.TablesTest): with conn.begin(): conn.execute(self.tables.some_table.delete()) - def test_autocommit_on(self): - conn = config.db.connect() + def test_autocommit_on(self, connection_no_trans): + conn = connection_no_trans c2 = conn.execution_options(isolation_level="AUTOCOMMIT") self._test_conn_autocommits(c2, True) @@ -189,12 +189,14 @@ class AutocommitIsolationTest(fixtures.TablesTest): self._test_conn_autocommits(conn, False) - def test_autocommit_off(self): - conn = config.db.connect() + def test_autocommit_off(self, connection_no_trans): + conn = connection_no_trans self._test_conn_autocommits(conn, False) - def test_turn_autocommit_off_via_default_iso_level(self): - conn = config.db.connect() + def test_turn_autocommit_off_via_default_iso_level( + self, connection_no_trans + ): + conn = connection_no_trans conn = conn.execution_options(isolation_level="AUTOCOMMIT") self._test_conn_autocommits(conn, True) diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py index e8dd6cf2c..6c2880ad4 100644 --- a/lib/sqlalchemy/testing/suite/test_results.py +++ b/lib/sqlalchemy/testing/suite/test_results.py @@ -227,6 +227,8 @@ class ServerSideCursorsTest( __backend__ = True def _is_server_side(self, cursor): + # TODO: this is a huge issue as it prevents these tests from being + # usable by third party dialects. if self.engine.dialect.driver == "psycopg2": return bool(cursor.name) elif self.engine.dialect.driver == "pymysql": @@ -239,7 +241,7 @@ class ServerSideCursorsTest( return isinstance(cursor, sscursor) elif self.engine.dialect.driver == "mariadbconnector": return not cursor.buffered - elif self.engine.dialect.driver == "asyncpg": + elif self.engine.dialect.driver in ("asyncpg", "aiosqlite"): return cursor.server_side else: return False @@ -279,7 +281,14 @@ class ServerSideCursorsTest( False, ), ("for_update_expr", True, select(1).with_for_update(), True), - ("for_update_string", True, "SELECT 1 FOR UPDATE", True), + # TODO: need a real requirement for this, or dont use this test + ( + "for_update_string", + True, + "SELECT 1 FOR UPDATE", + True, + testing.skip_if("sqlite"), + ), ("text_no_ss", False, text("select 42"), False), ( "text_ss_option", diff --git a/lib/sqlalchemy/util/concurrency.py b/lib/sqlalchemy/util/concurrency.py index c44efba62..e26f305d9 100644 --- a/lib/sqlalchemy/util/concurrency.py +++ b/lib/sqlalchemy/util/concurrency.py @@ -32,7 +32,7 @@ if not have_greenlet: ) def await_only(thing): # noqa F811 - return thing + _not_implemented() def await_fallback(thing): # noqa F81 return thing |
