summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2020-08-14 00:08:29 +0000
committerGerrit Code Review <gerrit@bbpush.zzzcomputing.com>2020-08-14 00:08:29 +0000
commit9ccf4a2f659d77c7f1aacaa4ff2b3b16889fea5e (patch)
treec30594326d1fb4b70b239c0865f1b7e627e63cc3 /lib/sqlalchemy/testing
parentb1c3c40e54bafe686065827d5b8d92c35bc50648 (diff)
parent5fb0138a3220161703e6ab1087319a669d14e7f4 (diff)
downloadsqlalchemy-9ccf4a2f659d77c7f1aacaa4ff2b3b16889fea5e.tar.gz
Merge "Implement rudimentary asyncio support w/ asyncpg"
Diffstat (limited to 'lib/sqlalchemy/testing')
-rw-r--r--lib/sqlalchemy/testing/__init__.py4
-rw-r--r--lib/sqlalchemy/testing/assertions.py30
-rw-r--r--lib/sqlalchemy/testing/asyncio.py14
-rw-r--r--lib/sqlalchemy/testing/config.py4
-rw-r--r--lib/sqlalchemy/testing/fixtures.py1
-rw-r--r--lib/sqlalchemy/testing/plugin/plugin_base.py8
-rw-r--r--lib/sqlalchemy/testing/plugin/pytestplugin.py55
-rw-r--r--lib/sqlalchemy/testing/requirements.py6
-rw-r--r--lib/sqlalchemy/testing/suite/test_results.py81
-rw-r--r--lib/sqlalchemy/testing/suite/test_types.py42
10 files changed, 207 insertions, 38 deletions
diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py
index 79b7f9eb3..9b1164874 100644
--- a/lib/sqlalchemy/testing/__init__.py
+++ b/lib/sqlalchemy/testing/__init__.py
@@ -12,7 +12,6 @@ from .assertions import assert_raises # noqa
from .assertions import assert_raises_context_ok # noqa
from .assertions import assert_raises_message # noqa
from .assertions import assert_raises_message_context_ok # noqa
-from .assertions import assert_raises_return # noqa
from .assertions import AssertsCompiledSQL # noqa
from .assertions import AssertsExecutionResults # noqa
from .assertions import ComparesTables # noqa
@@ -23,6 +22,8 @@ from .assertions import eq_ignore_whitespace # noqa
from .assertions import eq_regex # noqa
from .assertions import expect_deprecated # noqa
from .assertions import expect_deprecated_20 # noqa
+from .assertions import expect_raises # noqa
+from .assertions import expect_raises_message # noqa
from .assertions import expect_warnings # noqa
from .assertions import in_ # noqa
from .assertions import is_ # noqa
@@ -35,6 +36,7 @@ from .assertions import ne_ # noqa
from .assertions import not_in_ # noqa
from .assertions import startswith_ # noqa
from .assertions import uses_deprecated # noqa
+from .config import async_test # noqa
from .config import combinations # noqa
from .config import db # noqa
from .config import fixture # noqa
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index ecc6a4ab8..fe74be823 100644
--- a/lib/sqlalchemy/testing/assertions.py
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -298,10 +298,6 @@ def assert_raises_context_ok(except_cls, callable_, *args, **kw):
return _assert_raises(except_cls, callable_, args, kw,)
-def assert_raises_return(except_cls, callable_, *args, **kw):
- return _assert_raises(except_cls, callable_, args, kw, check_context=True)
-
-
def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
return _assert_raises(
except_cls, callable_, args, kwargs, msg=msg, check_context=True
@@ -317,14 +313,26 @@ def assert_raises_message_context_ok(
def _assert_raises(
except_cls, callable_, args, kwargs, msg=None, check_context=False
):
- ret_err = None
+
+ with _expect_raises(except_cls, msg, check_context) as ec:
+ callable_(*args, **kwargs)
+ return ec.error
+
+
+class _ErrorContainer(object):
+ error = None
+
+
+@contextlib.contextmanager
+def _expect_raises(except_cls, msg=None, check_context=False):
+ ec = _ErrorContainer()
if check_context:
are_we_already_in_a_traceback = sys.exc_info()[0]
try:
- callable_(*args, **kwargs)
+ yield ec
success = False
except except_cls as err:
- ret_err = err
+ ec.error = err
success = True
if msg is not None:
assert re.search(
@@ -337,7 +345,13 @@ def _assert_raises(
# assert outside the block so it works for AssertionError too !
assert success, "Callable did not raise an exception"
- return ret_err
+
+def expect_raises(except_cls):
+ return _expect_raises(except_cls, check_context=True)
+
+
+def expect_raises_message(except_cls, msg):
+ return _expect_raises(except_cls, msg=msg, check_context=True)
class AssertsCompiledSQL(object):
diff --git a/lib/sqlalchemy/testing/asyncio.py b/lib/sqlalchemy/testing/asyncio.py
new file mode 100644
index 000000000..2e274de16
--- /dev/null
+++ b/lib/sqlalchemy/testing/asyncio.py
@@ -0,0 +1,14 @@
+from .assertions import assert_raises as _assert_raises
+from .assertions import assert_raises_message as _assert_raises_message
+from ..util import await_fallback as await_
+from ..util import greenlet_spawn
+
+
+async def assert_raises_async(except_cls, msg, coroutine):
+ await greenlet_spawn(_assert_raises, except_cls, await_, coroutine)
+
+
+async def assert_raises_message_async(except_cls, msg, coroutine):
+ await greenlet_spawn(
+ _assert_raises_message, except_cls, msg, await_, coroutine
+ )
diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py
index e97821d72..8c232f319 100644
--- a/lib/sqlalchemy/testing/config.py
+++ b/lib/sqlalchemy/testing/config.py
@@ -178,3 +178,7 @@ class Config(object):
def skip_test(msg):
raise _fixture_functions.skip_test_exception(msg)
+
+
+def async_test(fn):
+ return _fixture_functions.async_test(fn)
diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py
index 1583147d4..85d3374de 100644
--- a/lib/sqlalchemy/testing/fixtures.py
+++ b/lib/sqlalchemy/testing/fixtures.py
@@ -61,6 +61,7 @@ class TestBase(object):
@config.fixture()
def connection(self):
eng = getattr(self, "bind", config.db)
+
conn = eng.connect()
trans = conn.begin()
try:
diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py
index b31a4ff3e..49ff0f975 100644
--- a/lib/sqlalchemy/testing/plugin/plugin_base.py
+++ b/lib/sqlalchemy/testing/plugin/plugin_base.py
@@ -48,7 +48,6 @@ testing = None
util = None
file_config = None
-
logging = None
include_tags = set()
exclude_tags = set()
@@ -193,6 +192,12 @@ def setup_options(make_option):
default=False,
help="Unconditionally write/update profiling data.",
)
+ make_option(
+ "--dump-pyannotate",
+ type=str,
+ dest="dump_pyannotate",
+ help="Run pyannotate and dump json info to given file",
+ )
def configure_follower(follower_ident):
@@ -378,7 +383,6 @@ def _engine_uri(options, file_config):
cfg = provision.setup_config(
db_url, options, file_config, provision.FOLLOWER_IDENT
)
-
if not config._current:
cfg.set_as_current(cfg, testing)
diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py
index 015598952..3df239afa 100644
--- a/lib/sqlalchemy/testing/plugin/pytestplugin.py
+++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py
@@ -26,6 +26,11 @@ else:
from typing import Sequence
try:
+ import asyncio
+except ImportError:
+ pass
+
+try:
import xdist # noqa
has_xdist = True
@@ -101,6 +106,24 @@ def pytest_configure(config):
plugin_base.set_fixture_functions(PytestFixtureFunctions)
+ if config.option.dump_pyannotate:
+ global DUMP_PYANNOTATE
+ DUMP_PYANNOTATE = True
+
+
+DUMP_PYANNOTATE = False
+
+
+@pytest.fixture(autouse=True)
+def collect_types_fixture():
+ if DUMP_PYANNOTATE:
+ from pyannotate_runtime import collect_types
+
+ collect_types.start()
+ yield
+ if DUMP_PYANNOTATE:
+ collect_types.stop()
+
def pytest_sessionstart(session):
plugin_base.post_begin()
@@ -109,6 +132,31 @@ def pytest_sessionstart(session):
def pytest_sessionfinish(session):
plugin_base.final_process_cleanup()
+ if session.config.option.dump_pyannotate:
+ from pyannotate_runtime import collect_types
+
+ collect_types.dump_stats(session.config.option.dump_pyannotate)
+
+
+def pytest_collection_finish(session):
+ if session.config.option.dump_pyannotate:
+ from pyannotate_runtime import collect_types
+
+ lib_sqlalchemy = os.path.abspath("lib/sqlalchemy")
+
+ def _filter(filename):
+ filename = os.path.normpath(os.path.abspath(filename))
+ if "lib/sqlalchemy" not in os.path.commonpath(
+ [filename, lib_sqlalchemy]
+ ):
+ return None
+ if "testing" in filename:
+ return None
+
+ return filename
+
+ collect_types.init_types_collection(filter_filename=_filter)
+
if has_xdist:
import uuid
@@ -518,3 +566,10 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
def get_current_test_name(self):
return os.environ.get("PYTEST_CURRENT_TEST")
+
+ def async_test(self, fn):
+ @_pytest_fn_decorator
+ def decorate(fn, *args, **kwargs):
+ asyncio.get_event_loop().run_until_complete(fn(*args, **kwargs))
+
+ return decorate(fn)
diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py
index 25998c07b..36d0ce4c6 100644
--- a/lib/sqlalchemy/testing/requirements.py
+++ b/lib/sqlalchemy/testing/requirements.py
@@ -1194,6 +1194,12 @@ class SuiteRequirements(Requirements):
return False
@property
+ def async_dialect(self):
+ """dialect makes use of await_() to invoke operations on the DBAPI."""
+
+ return exclusions.closed()
+
+ @property
def computed_columns(self):
"Supports computed columns"
return exclusions.closed()
diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py
index 2eb986c74..e6f6068c8 100644
--- a/lib/sqlalchemy/testing/suite/test_results.py
+++ b/lib/sqlalchemy/testing/suite/test_results.py
@@ -238,6 +238,8 @@ class ServerSideCursorsTest(
elif self.engine.dialect.driver == "mysqldb":
sscursor = __import__("MySQLdb.cursors").cursors.SSCursor
return isinstance(cursor, sscursor)
+ elif self.engine.dialect.driver == "asyncpg":
+ return cursor.server_side
else:
return False
@@ -331,29 +333,74 @@ class ServerSideCursorsTest(
result.close()
@testing.provide_metadata
- def test_roundtrip(self):
+ def test_roundtrip_fetchall(self):
md = self.metadata
- self._fixture(True)
+ engine = self._fixture(True)
test_table = Table(
"test_table",
md,
Column("id", Integer, primary_key=True),
Column("data", String(50)),
)
- test_table.create(checkfirst=True)
- test_table.insert().execute(data="data1")
- test_table.insert().execute(data="data2")
- eq_(
- test_table.select().order_by(test_table.c.id).execute().fetchall(),
- [(1, "data1"), (2, "data2")],
- )
- test_table.update().where(test_table.c.id == 2).values(
- data=test_table.c.data + " updated"
- ).execute()
- eq_(
- test_table.select().order_by(test_table.c.id).execute().fetchall(),
- [(1, "data1"), (2, "data2 updated")],
+
+ with engine.connect() as connection:
+ test_table.create(connection, checkfirst=True)
+ connection.execute(test_table.insert(), dict(data="data1"))
+ connection.execute(test_table.insert(), dict(data="data2"))
+ eq_(
+ connection.execute(
+ test_table.select().order_by(test_table.c.id)
+ ).fetchall(),
+ [(1, "data1"), (2, "data2")],
+ )
+ connection.execute(
+ test_table.update()
+ .where(test_table.c.id == 2)
+ .values(data=test_table.c.data + " updated")
+ )
+ eq_(
+ connection.execute(
+ test_table.select().order_by(test_table.c.id)
+ ).fetchall(),
+ [(1, "data1"), (2, "data2 updated")],
+ )
+ connection.execute(test_table.delete())
+ eq_(
+ connection.scalar(
+ select([func.count("*")]).select_from(test_table)
+ ),
+ 0,
+ )
+
+ @testing.provide_metadata
+ def test_roundtrip_fetchmany(self):
+ md = self.metadata
+
+ engine = self._fixture(True)
+ test_table = Table(
+ "test_table",
+ md,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
)
- test_table.delete().execute()
- eq_(select([func.count("*")]).select_from(test_table).scalar(), 0)
+
+ with engine.connect() as connection:
+ test_table.create(connection, checkfirst=True)
+ connection.execute(
+ test_table.insert(),
+ [dict(data="data%d" % i) for i in range(1, 20)],
+ )
+
+ result = connection.execute(
+ test_table.select().order_by(test_table.c.id)
+ )
+
+ eq_(
+ result.fetchmany(5), [(i, "data%d" % i) for i in range(1, 6)],
+ )
+ eq_(
+ result.fetchmany(10),
+ [(i, "data%d" % i) for i in range(6, 16)],
+ )
+ eq_(result.fetchall(), [(i, "data%d" % i) for i in range(16, 20)])
diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py
index 48144f885..5e6ac1eab 100644
--- a/lib/sqlalchemy/testing/suite/test_types.py
+++ b/lib/sqlalchemy/testing/suite/test_types.py
@@ -35,6 +35,7 @@ from ... import Text
from ... import Time
from ... import TIMESTAMP
from ... import type_coerce
+from ... import TypeDecorator
from ... import Unicode
from ... import UnicodeText
from ... import util
@@ -282,6 +283,9 @@ class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase):
@classmethod
def define_tables(cls, metadata):
+ class Decorated(TypeDecorator):
+ impl = cls.datatype
+
Table(
"date_table",
metadata,
@@ -289,6 +293,7 @@ class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase):
"id", Integer, primary_key=True, test_needs_autoincrement=True
),
Column("date_data", cls.datatype),
+ Column("decorated_date_data", Decorated),
)
def test_round_trip(self, connection):
@@ -302,6 +307,21 @@ class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase):
eq_(row, (compare,))
assert isinstance(row[0], type(compare))
+ def test_round_trip_decorated(self, connection):
+ date_table = self.tables.date_table
+
+ connection.execute(
+ date_table.insert(), {"decorated_date_data": self.data}
+ )
+
+ row = connection.execute(
+ select(date_table.c.decorated_date_data)
+ ).first()
+
+ compare = self.compare or self.data
+ eq_(row, (compare,))
+ assert isinstance(row[0], type(compare))
+
def test_null(self, connection):
date_table = self.tables.date_table
@@ -526,6 +546,7 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
Float(precision=8, asdecimal=True),
[15.7563, decimal.Decimal("15.7563"), None],
[decimal.Decimal("15.7563"), None],
+ filter_=lambda n: n is not None and round(n, 4) or None,
)
def test_float_as_float(self):
@@ -777,6 +798,7 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
# ("json", {"foo": "bar"}),
id_="sa",
)(fn)
+
return fn
@_index_fixtures
@@ -1139,7 +1161,15 @@ class JSONStringCastIndexTest(_LiteralRoundTripFixture, fixtures.TablesTest):
and_(name == "r6", cast(col["b"], String) == '"some value"'), "r6"
)
- def test_crit_against_string_coerce_type(self):
+ def test_crit_against_int_basic(self):
+ name = self.tables.data_table.c.name
+ col = self.tables.data_table.c["data"]
+
+ self._test_index_criteria(
+ and_(name == "r6", cast(col["a"], String) == "5"), "r6"
+ )
+
+ def _dont_test_crit_against_string_coerce_type(self):
name = self.tables.data_table.c.name
col = self.tables.data_table.c["data"]
@@ -1152,15 +1182,7 @@ class JSONStringCastIndexTest(_LiteralRoundTripFixture, fixtures.TablesTest):
test_literal=False,
)
- def test_crit_against_int_basic(self):
- name = self.tables.data_table.c.name
- col = self.tables.data_table.c["data"]
-
- self._test_index_criteria(
- and_(name == "r6", cast(col["a"], String) == "5"), "r6"
- )
-
- def test_crit_against_int_coerce_type(self):
+ def _dont_test_crit_against_int_coerce_type(self):
name = self.tables.data_table.c.name
col = self.tables.data_table.c["data"]