summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/dialects/mysql/aiomysql.py7
-rw-r--r--lib/sqlalchemy/dialects/sqlite/__init__.py4
-rw-r--r--lib/sqlalchemy/dialects/sqlite/aiosqlite.py331
-rw-r--r--lib/sqlalchemy/dialects/sqlite/provision.py43
-rw-r--r--lib/sqlalchemy/pool/base.py7
-rw-r--r--lib/sqlalchemy/pool/impl.py35
-rw-r--r--lib/sqlalchemy/testing/engines.py14
-rw-r--r--lib/sqlalchemy/testing/fixtures.py7
-rw-r--r--lib/sqlalchemy/testing/requirements.py10
-rw-r--r--lib/sqlalchemy/testing/suite/test_dialect.py14
-rw-r--r--lib/sqlalchemy/testing/suite/test_results.py13
-rw-r--r--lib/sqlalchemy/util/concurrency.py2
12 files changed, 456 insertions, 31 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/aiomysql.py b/lib/sqlalchemy/dialects/mysql/aiomysql.py
index cab6df499..c8c7c0f97 100644
--- a/lib/sqlalchemy/dialects/mysql/aiomysql.py
+++ b/lib/sqlalchemy/dialects/mysql/aiomysql.py
@@ -82,6 +82,13 @@ class AsyncAdapt_aiomysql_cursor:
return self._cursor.lastrowid
def close(self):
+ # note we aren't actually closing the cursor here,
+ # we are just letting GC do it. to allow this to be async
+ # we would need the Result to change how it does "Safe close cursor".
+ # MySQL "cursors" don't actually have state to be "closed" besides
+ # exhausting rows, which we already have done for sync cursor.
+ # another option would be to emulate aiosqlite dialect and assign
+ # cursor only if we are doing server side cursor operation.
self._rows[:] = []
def execute(self, operation, parameters=None):
diff --git a/lib/sqlalchemy/dialects/sqlite/__init__.py b/lib/sqlalchemy/dialects/sqlite/__init__.py
index d12203cbd..8b24a19fd 100644
--- a/lib/sqlalchemy/dialects/sqlite/__init__.py
+++ b/lib/sqlalchemy/dialects/sqlite/__init__.py
@@ -26,6 +26,10 @@ from .base import TIMESTAMP
from .base import VARCHAR
from .dml import Insert
from .dml import insert
+from ...util import compat
+
+if compat.py3k:
+ from . import aiosqlite # noqa
# default dialect
base.dialect = dialect = pysqlite.dialect
diff --git a/lib/sqlalchemy/dialects/sqlite/aiosqlite.py b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py
new file mode 100644
index 000000000..e4b7d1d52
--- /dev/null
+++ b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py
@@ -0,0 +1,331 @@
+# sqlite/aiosqlite.py
+# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+r"""
+
+.. dialect:: sqlite+aiosqlite
+ :name: aiosqlite
+ :dbapi: aiosqlite
+ :connectstring: sqlite+aiosqlite:///file_path
+ :url: https://pypi.org/project/aiosqlite/
+
+The aiosqlite dialect provides support for the SQLAlchemy asyncio interface
+running on top of pysqlite.
+
+aiosqlite is a wrapper around pysqlite that uses a background thread for
+each connection. It does not actually use non-blocking IO, as SQLite
+databases are not socket-based. However it does provide a working asyncio
+interface that's useful for testing and prototyping purposes.
+
+Using a special asyncio mediation layer, the aiosqlite dialect is usable
+as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
+extension package.
+
+This dialect should normally be used only with the
+:func:`_asyncio.create_async_engine` engine creation function::
+
+ from sqlalchemy.ext.asyncio import create_async_engine
+ engine = create_async_engine("sqlite+aiosqlite:///filename")
+
+The URL passes through all arguments to the ``pysqlite`` driver, so all
+connection arguments are the same as they are for that of :ref:`pysqlite`.
+
+
+""" # noqa
+
+from .base import SQLiteExecutionContext
+from .pysqlite import SQLiteDialect_pysqlite
+from ... import pool
+from ... import util
+from ...util.concurrency import await_fallback
+from ...util.concurrency import await_only
+
+
+class AsyncAdapt_aiosqlite_cursor:
+ __slots__ = (
+ "_adapt_connection",
+ "_connection",
+ "description",
+ "await_",
+ "_rows",
+ "arraysize",
+ "rowcount",
+ "lastrowid",
+ )
+
+ server_side = False
+
+ def __init__(self, adapt_connection):
+ self._adapt_connection = adapt_connection
+ self._connection = adapt_connection._connection
+ self.await_ = adapt_connection.await_
+ self.arraysize = 1
+ self.rowcount = -1
+ self.description = None
+ self._rows = []
+
+ def close(self):
+ self._rows[:] = []
+
+ def execute(self, operation, parameters=None):
+ try:
+ _cursor = self.await_(self._connection.cursor())
+
+ if parameters is None:
+ self.await_(_cursor.execute(operation))
+ else:
+ self.await_(_cursor.execute(operation, parameters))
+
+ if _cursor.description:
+ self.description = _cursor.description
+ self.lastrowid = self.rowcount = -1
+
+ if not self.server_side:
+ self._rows = self.await_(_cursor.fetchall())
+ else:
+ self.description = None
+ self.lastrowid = _cursor.lastrowid
+ self.rowcount = _cursor.rowcount
+
+ if not self.server_side:
+ self.await_(_cursor.close())
+ else:
+ self._cursor = _cursor
+ except Exception as error:
+ self._adapt_connection._handle_exception(error)
+
+ def executemany(self, operation, seq_of_parameters):
+ try:
+ _cursor = self.await_(self._connection.cursor())
+ self.await_(_cursor.executemany(operation, seq_of_parameters))
+ self.description = None
+ self.lastrowid = _cursor.lastrowid
+ self.rowcount = _cursor.rowcount
+ self.await_(_cursor.close())
+ except Exception as error:
+ self._adapt_connection._handle_exception(error)
+
+ def setinputsizes(self, *inputsizes):
+ pass
+
+ def __iter__(self):
+ while self._rows:
+ yield self._rows.pop(0)
+
+ def fetchone(self):
+ if self._rows:
+ return self._rows.pop(0)
+ else:
+ return None
+
+ def fetchmany(self, size=None):
+ if size is None:
+ size = self.arraysize
+
+ retval = self._rows[0:size]
+ self._rows[:] = self._rows[size:]
+ return retval
+
+ def fetchall(self):
+ retval = self._rows[:]
+ self._rows[:] = []
+ return retval
+
+
+class AsyncAdapt_aiosqlite_ss_cursor(AsyncAdapt_aiosqlite_cursor):
+ __slots__ = "_cursor"
+
+ server_side = True
+
+ def __init__(self, *arg, **kw):
+ super().__init__(*arg, **kw)
+ self._cursor = None
+
+ def close(self):
+ if self._cursor is not None:
+ self.await_(self._cursor.close())
+ self._cursor = None
+
+ def fetchone(self):
+ return self.await_(self._cursor.fetchone())
+
+ def fetchmany(self, size=None):
+ if size is None:
+ size = self.arraysize
+ return self.await_(self._cursor.fetchmany(size=size))
+
+ def fetchall(self):
+ return self.await_(self._cursor.fetchall())
+
+
+class AsyncAdapt_aiosqlite_connection:
+ await_ = staticmethod(await_only)
+ __slots__ = ("dbapi", "_connection")
+
+ def __init__(self, dbapi, connection):
+ self.dbapi = dbapi
+ self._connection = connection
+
+ @property
+ def isolation_level(self):
+ return self._connection.isolation_level
+
+ @isolation_level.setter
+ def isolation_level(self, value):
+ try:
+ self._connection.isolation_level = value
+ except Exception as error:
+ self._handle_exception(error)
+
+ def create_function(self, *args, **kw):
+ try:
+ self.await_(self._connection.create_function(*args, **kw))
+ except Exception as error:
+ self._handle_exception(error)
+
+ def cursor(self, server_side=False):
+ if server_side:
+ return AsyncAdapt_aiosqlite_ss_cursor(self)
+ else:
+ return AsyncAdapt_aiosqlite_cursor(self)
+
+ def execute(self, *args, **kw):
+ return self.await_(self._connection.execute(*args, **kw))
+
+ def rollback(self):
+ try:
+ self.await_(self._connection.rollback())
+ except Exception as error:
+ self._handle_exception(error)
+
+ def commit(self):
+ try:
+ self.await_(self._connection.commit())
+ except Exception as error:
+ self._handle_exception(error)
+
+ def close(self):
+ # print(">close", self)
+ try:
+ self.await_(self._connection.close())
+ except Exception as error:
+ self._handle_exception(error)
+
+ def _handle_exception(self, error):
+ if (
+ isinstance(error, ValueError)
+ and error.args[0] == "no active connection"
+ ):
+ util.raise_(
+ self.dbapi.sqlite.OperationalError("no active connection"),
+ from_=error,
+ )
+ else:
+ raise error
+
+
+class AsyncAdaptFallback_aiosqlite_connection(AsyncAdapt_aiosqlite_connection):
+ __slots__ = ()
+
+ await_ = staticmethod(await_fallback)
+
+
+class AsyncAdapt_aiosqlite_dbapi:
+ def __init__(self, aiosqlite, sqlite):
+ self.aiosqlite = aiosqlite
+ self.sqlite = sqlite
+ self.paramstyle = "qmark"
+ self._init_dbapi_attributes()
+
+ def _init_dbapi_attributes(self):
+ for name in (
+ "DatabaseError",
+ "Error",
+ "IntegrityError",
+ "NotSupportedError",
+ "OperationalError",
+ "ProgrammingError",
+ "sqlite_version",
+ "sqlite_version_info",
+ ):
+ setattr(self, name, getattr(self.aiosqlite, name))
+
+ for name in ("PARSE_COLNAMES", "PARSE_DECLTYPES"):
+ setattr(self, name, getattr(self.sqlite, name))
+
+ for name in ("Binary",):
+ setattr(self, name, getattr(self.sqlite, name))
+
+ def connect(self, *arg, **kw):
+ async_fallback = kw.pop("async_fallback", False)
+
+ # Q. WHY do we need this?
+ # A. Because there is no way to set connection.isolation_level
+ # otherwise
+ # Q. BUT HOW do you know it is SAFE ?????
+ # A. The only operation that isn't safe is the isolation level set
+ # operation which aiosqlite appears to have let slip through even
+ # though pysqlite appears to do check_same_thread for this.
+ # All execute operations etc. should be safe because they all
+ # go through the single executor thread.
+
+ kw["check_same_thread"] = False
+
+ connection = self.aiosqlite.connect(*arg, **kw)
+
+ # it's a Thread. you'll thank us later
+ connection.daemon = True
+
+ if util.asbool(async_fallback):
+ return AsyncAdaptFallback_aiosqlite_connection(
+ self,
+ await_fallback(connection),
+ )
+ else:
+ return AsyncAdapt_aiosqlite_connection(
+ self,
+ await_only(connection),
+ )
+
+
+class SQLiteExecutionContext_aiosqlite(SQLiteExecutionContext):
+ def create_server_side_cursor(self):
+ return self._dbapi_connection.cursor(server_side=True)
+
+
+class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite):
+ driver = "aiosqlite"
+
+ is_async = True
+
+ supports_server_side_cursors = True
+
+ execution_ctx_cls = SQLiteExecutionContext_aiosqlite
+
+ @classmethod
+ def dbapi(cls):
+ return AsyncAdapt_aiosqlite_dbapi(
+ __import__("aiosqlite"), __import__("sqlite3")
+ )
+
+ @classmethod
+ def get_pool_class(cls, url):
+ if cls._is_url_file_db(url):
+ return pool.NullPool
+ else:
+ return pool.StaticPool
+
+ def is_disconnect(self, e, connection, cursor):
+ if isinstance(
+ e, self.dbapi.OperationalError
+ ) and "no active connection" in str(e):
+ return True
+
+ return super().is_disconnect(e, connection, cursor)
+
+
+dialect = SQLiteDialect_aiosqlite
diff --git a/lib/sqlalchemy/dialects/sqlite/provision.py b/lib/sqlalchemy/dialects/sqlite/provision.py
index a481be27e..d0d12695d 100644
--- a/lib/sqlalchemy/dialects/sqlite/provision.py
+++ b/lib/sqlalchemy/dialects/sqlite/provision.py
@@ -11,13 +11,22 @@ from ...testing.provision import stop_test_class_outside_fixtures
from ...testing.provision import temp_table_keyword_args
+# likely needs a generate_driver_url() def here for the --dbdriver part to
+# work
+
+_drivernames = set()
+
+
@follower_url_from_main.for_db("sqlite")
def _sqlite_follower_url_from_main(url, ident):
url = sa_url.make_url(url)
if not url.database or url.database == ":memory:":
return url
else:
- return sa_url.make_url("sqlite:///%s.db" % ident)
+ _drivernames.add(url.get_driver_name())
+ return sa_url.make_url(
+ "sqlite+%s:///%s.db" % (url.get_driver_name(), ident)
+ )
@post_configure_engine.for_db("sqlite")
@@ -35,12 +44,13 @@ def _sqlite_post_configure_engine(url, engine, follower_ident):
# expected to be already present, so for now it just stays
# in a given checkout directory.
dbapi_connection.execute(
- 'ATTACH DATABASE "test_schema.db" AS test_schema'
+ 'ATTACH DATABASE "%s_test_schema.db" AS test_schema'
+ % (engine.driver,)
)
else:
dbapi_connection.execute(
- 'ATTACH DATABASE "%s_test_schema.db" AS test_schema'
- % follower_ident
+ 'ATTACH DATABASE "%s_%s_test_schema.db" AS test_schema'
+ % (follower_ident, engine.driver)
)
@@ -51,7 +61,10 @@ def _sqlite_create_db(cfg, eng, ident):
@drop_db.for_db("sqlite")
def _sqlite_drop_db(cfg, eng, ident):
- for path in ["%s.db" % ident, "%s_test_schema.db" % ident]:
+ for path in [
+ "%s.db" % ident,
+ "%s_%s_test_schema.db" % (ident, eng.driver),
+ ]:
if os.path.exists(path):
log.info("deleting SQLite database file: %s" % path)
os.remove(path)
@@ -71,9 +84,9 @@ def stop_test_class_outside_fixtures(config, db, cls):
# some sqlite file tests are not cleaning up well yet, so do this
# just to make things simple for now
- for file in files:
- if file:
- os.remove(file)
+ for file_ in files:
+ if file_ and os.path.exists(file_):
+ os.remove(file_)
@temp_table_keyword_args.for_db("sqlite")
@@ -89,7 +102,19 @@ def _reap_sqlite_dbs(url, idents):
for ident in idents:
# we don't have a config so we can't call _sqlite_drop_db due to the
# decorator
- for path in ["%s.db" % ident, "%s_test_schema.db" % ident]:
+ for path in (
+ [
+ "%s.db" % ident,
+ ]
+ + [
+ "%s_test_schema.db" % (drivername,)
+ for drivername in _drivernames
+ ]
+ + [
+ "%s_%s_test_schema.db" % (ident, drivername)
+ for drivername in _drivernames
+ ]
+ ):
if os.path.exists(path):
log.info("deleting SQLite database file: %s" % path)
os.remove(path)
diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py
index 6ec489604..d14316fdb 100644
--- a/lib/sqlalchemy/pool/base.py
+++ b/lib/sqlalchemy/pool/base.py
@@ -574,6 +574,13 @@ class _ConnectionRecord(object):
self.__connect()
return self.connection
+ def _is_hard_or_soft_invalidated(self):
+ return (
+ self.connection is None
+ or self.__pool._invalidate_time > self.starttime
+ or (self._soft_invalidate_time > self.starttime)
+ )
+
def __close(self):
self.finalize_callback.clear()
if self.__pool.dispatch.close:
diff --git a/lib/sqlalchemy/pool/impl.py b/lib/sqlalchemy/pool/impl.py
index 08371a31a..730293273 100644
--- a/lib/sqlalchemy/pool/impl.py
+++ b/lib/sqlalchemy/pool/impl.py
@@ -395,15 +395,11 @@ class StaticPool(Pool):
"""A Pool of exactly one connection, used for all requests.
Reconnect-related functions such as ``recycle`` and connection
- invalidation (which is also used to support auto-reconnect) are not
- currently supported by this Pool implementation but may be implemented
- in a future release.
+ invalidation (which is also used to support auto-reconnect) are only
+ partially supported right now and may not yield good results.
- """
- @util.memoized_property
- def _conn(self):
- return self._creator()
+ """
@util.memoized_property
def connection(self):
@@ -413,9 +409,12 @@ class StaticPool(Pool):
return "StaticPool"
def dispose(self):
- if "_conn" in self.__dict__:
- self._conn.close()
- self._conn = None
+ if (
+ "connection" in self.__dict__
+ and self.connection.connection is not None
+ ):
+ self.connection.close()
+ del self.__dict__["connection"]
def recreate(self):
self.logger.info("Pool recreating")
@@ -430,14 +429,26 @@ class StaticPool(Pool):
dialect=self._dialect,
)
+ def _transfer_from(self, other_static_pool):
+ # used by the test suite to make a new engine / pool without
+ # losing the state of an existing SQLite :memory: connection
+ self._invoke_creator = (
+ lambda crec: other_static_pool.connection.connection
+ )
+
def _create_connection(self):
- return self._conn
+ raise NotImplementedError()
def _do_return_conn(self, conn):
pass
def _do_get(self):
- return self.connection
+ rec = self.connection
+ if rec._is_hard_or_soft_invalidated():
+ del self.__dict__["connection"]
+ rec = self.connection
+
+ return rec
class AssertionPool(Pool):
diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py
index a313c298a..3faf96857 100644
--- a/lib/sqlalchemy/testing/engines.py
+++ b/lib/sqlalchemy/testing/engines.py
@@ -266,7 +266,13 @@ def reconnecting_engine(url=None, options=None):
return engine
-def testing_engine(url=None, options=None, future=None, asyncio=False):
+def testing_engine(
+ url=None,
+ options=None,
+ future=None,
+ asyncio=False,
+ transfer_staticpool=False,
+):
"""Produce an engine configured by --options with optional overrides."""
if asyncio:
@@ -300,6 +306,12 @@ def testing_engine(url=None, options=None, future=None, asyncio=False):
engine = create_engine(url, **options)
+ if transfer_staticpool:
+ from sqlalchemy.pool import StaticPool
+
+ if config.db is not None and isinstance(config.db.pool, StaticPool):
+ engine.pool._transfer_from(config.db.pool)
+
if scope == "global":
if asyncio:
engine.sync_engine._has_events = True
diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py
index f47277b4a..c3eb1b363 100644
--- a/lib/sqlalchemy/testing/fixtures.py
+++ b/lib/sqlalchemy/testing/fixtures.py
@@ -51,6 +51,13 @@ class TestBase(object):
assert val, msg
@config.fixture()
+ def connection_no_trans(self):
+ eng = getattr(self, "bind", None) or config.db
+
+ with eng.connect() as conn:
+ yield conn
+
+ @config.fixture()
def connection(self):
global _connection_fixture_connection
diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py
index de2b8f12c..208ba0091 100644
--- a/lib/sqlalchemy/testing/requirements.py
+++ b/lib/sqlalchemy/testing/requirements.py
@@ -18,6 +18,7 @@ to provide specific inclusion/exclusions.
import platform
import sys
+from sqlalchemy.pool.impl import QueuePool
from . import exclusions
from .. import util
@@ -117,6 +118,15 @@ class SuiteRequirements(Requirements):
)
@property
+ def queue_pool(self):
+ """target database is using QueuePool"""
+
+ def go(config):
+ return isinstance(config.db.pool, QueuePool)
+
+ return exclusions.only_if(go)
+
+ @property
def self_referential_foreign_keys(self):
"""Target database must support self-referential foreign keys."""
diff --git a/lib/sqlalchemy/testing/suite/test_dialect.py b/lib/sqlalchemy/testing/suite/test_dialect.py
index a236b1076..c2c17d0dd 100644
--- a/lib/sqlalchemy/testing/suite/test_dialect.py
+++ b/lib/sqlalchemy/testing/suite/test_dialect.py
@@ -180,8 +180,8 @@ class AutocommitIsolationTest(fixtures.TablesTest):
with conn.begin():
conn.execute(self.tables.some_table.delete())
- def test_autocommit_on(self):
- conn = config.db.connect()
+ def test_autocommit_on(self, connection_no_trans):
+ conn = connection_no_trans
c2 = conn.execution_options(isolation_level="AUTOCOMMIT")
self._test_conn_autocommits(c2, True)
@@ -189,12 +189,14 @@ class AutocommitIsolationTest(fixtures.TablesTest):
self._test_conn_autocommits(conn, False)
- def test_autocommit_off(self):
- conn = config.db.connect()
+ def test_autocommit_off(self, connection_no_trans):
+ conn = connection_no_trans
self._test_conn_autocommits(conn, False)
- def test_turn_autocommit_off_via_default_iso_level(self):
- conn = config.db.connect()
+ def test_turn_autocommit_off_via_default_iso_level(
+ self, connection_no_trans
+ ):
+ conn = connection_no_trans
conn = conn.execution_options(isolation_level="AUTOCOMMIT")
self._test_conn_autocommits(conn, True)
diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py
index e8dd6cf2c..6c2880ad4 100644
--- a/lib/sqlalchemy/testing/suite/test_results.py
+++ b/lib/sqlalchemy/testing/suite/test_results.py
@@ -227,6 +227,8 @@ class ServerSideCursorsTest(
__backend__ = True
def _is_server_side(self, cursor):
+ # TODO: this is a huge issue as it prevents these tests from being
+ # usable by third party dialects.
if self.engine.dialect.driver == "psycopg2":
return bool(cursor.name)
elif self.engine.dialect.driver == "pymysql":
@@ -239,7 +241,7 @@ class ServerSideCursorsTest(
return isinstance(cursor, sscursor)
elif self.engine.dialect.driver == "mariadbconnector":
return not cursor.buffered
- elif self.engine.dialect.driver == "asyncpg":
+ elif self.engine.dialect.driver in ("asyncpg", "aiosqlite"):
return cursor.server_side
else:
return False
@@ -279,7 +281,14 @@ class ServerSideCursorsTest(
False,
),
("for_update_expr", True, select(1).with_for_update(), True),
- ("for_update_string", True, "SELECT 1 FOR UPDATE", True),
+ # TODO: need a real requirement for this, or dont use this test
+ (
+ "for_update_string",
+ True,
+ "SELECT 1 FOR UPDATE",
+ True,
+ testing.skip_if("sqlite"),
+ ),
("text_no_ss", False, text("select 42"), False),
(
"text_ss_option",
diff --git a/lib/sqlalchemy/util/concurrency.py b/lib/sqlalchemy/util/concurrency.py
index c44efba62..e26f305d9 100644
--- a/lib/sqlalchemy/util/concurrency.py
+++ b/lib/sqlalchemy/util/concurrency.py
@@ -32,7 +32,7 @@ if not have_greenlet:
)
def await_only(thing): # noqa F811
- return thing
+ _not_implemented()
def await_fallback(thing): # noqa F81
return thing