diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2020-08-14 00:08:29 +0000 |
|---|---|---|
| committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2020-08-14 00:08:29 +0000 |
| commit | 9ccf4a2f659d77c7f1aacaa4ff2b3b16889fea5e (patch) | |
| tree | c30594326d1fb4b70b239c0865f1b7e627e63cc3 /lib/sqlalchemy/testing | |
| parent | b1c3c40e54bafe686065827d5b8d92c35bc50648 (diff) | |
| parent | 5fb0138a3220161703e6ab1087319a669d14e7f4 (diff) | |
| download | sqlalchemy-9ccf4a2f659d77c7f1aacaa4ff2b3b16889fea5e.tar.gz | |
Merge "Implement rudimentary asyncio support w/ asyncpg"
Diffstat (limited to 'lib/sqlalchemy/testing')
| -rw-r--r-- | lib/sqlalchemy/testing/__init__.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/assertions.py | 30 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/asyncio.py | 14 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/config.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/fixtures.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/plugin/plugin_base.py | 8 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/plugin/pytestplugin.py | 55 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/requirements.py | 6 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/suite/test_results.py | 81 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/suite/test_types.py | 42 |
10 files changed, 207 insertions, 38 deletions
diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index 79b7f9eb3..9b1164874 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -12,7 +12,6 @@ from .assertions import assert_raises # noqa from .assertions import assert_raises_context_ok # noqa from .assertions import assert_raises_message # noqa from .assertions import assert_raises_message_context_ok # noqa -from .assertions import assert_raises_return # noqa from .assertions import AssertsCompiledSQL # noqa from .assertions import AssertsExecutionResults # noqa from .assertions import ComparesTables # noqa @@ -23,6 +22,8 @@ from .assertions import eq_ignore_whitespace # noqa from .assertions import eq_regex # noqa from .assertions import expect_deprecated # noqa from .assertions import expect_deprecated_20 # noqa +from .assertions import expect_raises # noqa +from .assertions import expect_raises_message # noqa from .assertions import expect_warnings # noqa from .assertions import in_ # noqa from .assertions import is_ # noqa @@ -35,6 +36,7 @@ from .assertions import ne_ # noqa from .assertions import not_in_ # noqa from .assertions import startswith_ # noqa from .assertions import uses_deprecated # noqa +from .config import async_test # noqa from .config import combinations # noqa from .config import db # noqa from .config import fixture # noqa diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index ecc6a4ab8..fe74be823 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -298,10 +298,6 @@ def assert_raises_context_ok(except_cls, callable_, *args, **kw): return _assert_raises(except_cls, callable_, args, kw,) -def assert_raises_return(except_cls, callable_, *args, **kw): - return _assert_raises(except_cls, callable_, args, kw, check_context=True) - - def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): return _assert_raises( except_cls, callable_, args, kwargs, msg=msg, check_context=True @@ -317,14 +313,26 @@ def assert_raises_message_context_ok( def _assert_raises( except_cls, callable_, args, kwargs, msg=None, check_context=False ): - ret_err = None + + with _expect_raises(except_cls, msg, check_context) as ec: + callable_(*args, **kwargs) + return ec.error + + +class _ErrorContainer(object): + error = None + + +@contextlib.contextmanager +def _expect_raises(except_cls, msg=None, check_context=False): + ec = _ErrorContainer() if check_context: are_we_already_in_a_traceback = sys.exc_info()[0] try: - callable_(*args, **kwargs) + yield ec success = False except except_cls as err: - ret_err = err + ec.error = err success = True if msg is not None: assert re.search( @@ -337,7 +345,13 @@ def _assert_raises( # assert outside the block so it works for AssertionError too ! assert success, "Callable did not raise an exception" - return ret_err + +def expect_raises(except_cls): + return _expect_raises(except_cls, check_context=True) + + +def expect_raises_message(except_cls, msg): + return _expect_raises(except_cls, msg=msg, check_context=True) class AssertsCompiledSQL(object): diff --git a/lib/sqlalchemy/testing/asyncio.py b/lib/sqlalchemy/testing/asyncio.py new file mode 100644 index 000000000..2e274de16 --- /dev/null +++ b/lib/sqlalchemy/testing/asyncio.py @@ -0,0 +1,14 @@ +from .assertions import assert_raises as _assert_raises +from .assertions import assert_raises_message as _assert_raises_message +from ..util import await_fallback as await_ +from ..util import greenlet_spawn + + +async def assert_raises_async(except_cls, msg, coroutine): + await greenlet_spawn(_assert_raises, except_cls, await_, coroutine) + + +async def assert_raises_message_async(except_cls, msg, coroutine): + await greenlet_spawn( + _assert_raises_message, except_cls, msg, await_, coroutine + ) diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index e97821d72..8c232f319 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -178,3 +178,7 @@ class Config(object): def skip_test(msg): raise _fixture_functions.skip_test_exception(msg) + + +def async_test(fn): + return _fixture_functions.async_test(fn) diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index 1583147d4..85d3374de 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -61,6 +61,7 @@ class TestBase(object): @config.fixture() def connection(self): eng = getattr(self, "bind", config.db) + conn = eng.connect() trans = conn.begin() try: diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py index b31a4ff3e..49ff0f975 100644 --- a/lib/sqlalchemy/testing/plugin/plugin_base.py +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -48,7 +48,6 @@ testing = None util = None file_config = None - logging = None include_tags = set() exclude_tags = set() @@ -193,6 +192,12 @@ def setup_options(make_option): default=False, help="Unconditionally write/update profiling data.", ) + make_option( + "--dump-pyannotate", + type=str, + dest="dump_pyannotate", + help="Run pyannotate and dump json info to given file", + ) def configure_follower(follower_ident): @@ -378,7 +383,6 @@ def _engine_uri(options, file_config): cfg = provision.setup_config( db_url, options, file_config, provision.FOLLOWER_IDENT ) - if not config._current: cfg.set_as_current(cfg, testing) diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index 015598952..3df239afa 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -26,6 +26,11 @@ else: from typing import Sequence try: + import asyncio +except ImportError: + pass + +try: import xdist # noqa has_xdist = True @@ -101,6 +106,24 @@ def pytest_configure(config): plugin_base.set_fixture_functions(PytestFixtureFunctions) + if config.option.dump_pyannotate: + global DUMP_PYANNOTATE + DUMP_PYANNOTATE = True + + +DUMP_PYANNOTATE = False + + +@pytest.fixture(autouse=True) +def collect_types_fixture(): + if DUMP_PYANNOTATE: + from pyannotate_runtime import collect_types + + collect_types.start() + yield + if DUMP_PYANNOTATE: + collect_types.stop() + def pytest_sessionstart(session): plugin_base.post_begin() @@ -109,6 +132,31 @@ def pytest_sessionstart(session): def pytest_sessionfinish(session): plugin_base.final_process_cleanup() + if session.config.option.dump_pyannotate: + from pyannotate_runtime import collect_types + + collect_types.dump_stats(session.config.option.dump_pyannotate) + + +def pytest_collection_finish(session): + if session.config.option.dump_pyannotate: + from pyannotate_runtime import collect_types + + lib_sqlalchemy = os.path.abspath("lib/sqlalchemy") + + def _filter(filename): + filename = os.path.normpath(os.path.abspath(filename)) + if "lib/sqlalchemy" not in os.path.commonpath( + [filename, lib_sqlalchemy] + ): + return None + if "testing" in filename: + return None + + return filename + + collect_types.init_types_collection(filter_filename=_filter) + if has_xdist: import uuid @@ -518,3 +566,10 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): def get_current_test_name(self): return os.environ.get("PYTEST_CURRENT_TEST") + + def async_test(self, fn): + @_pytest_fn_decorator + def decorate(fn, *args, **kwargs): + asyncio.get_event_loop().run_until_complete(fn(*args, **kwargs)) + + return decorate(fn) diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index 25998c07b..36d0ce4c6 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -1194,6 +1194,12 @@ class SuiteRequirements(Requirements): return False @property + def async_dialect(self): + """dialect makes use of await_() to invoke operations on the DBAPI.""" + + return exclusions.closed() + + @property def computed_columns(self): "Supports computed columns" return exclusions.closed() diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py index 2eb986c74..e6f6068c8 100644 --- a/lib/sqlalchemy/testing/suite/test_results.py +++ b/lib/sqlalchemy/testing/suite/test_results.py @@ -238,6 +238,8 @@ class ServerSideCursorsTest( elif self.engine.dialect.driver == "mysqldb": sscursor = __import__("MySQLdb.cursors").cursors.SSCursor return isinstance(cursor, sscursor) + elif self.engine.dialect.driver == "asyncpg": + return cursor.server_side else: return False @@ -331,29 +333,74 @@ class ServerSideCursorsTest( result.close() @testing.provide_metadata - def test_roundtrip(self): + def test_roundtrip_fetchall(self): md = self.metadata - self._fixture(True) + engine = self._fixture(True) test_table = Table( "test_table", md, Column("id", Integer, primary_key=True), Column("data", String(50)), ) - test_table.create(checkfirst=True) - test_table.insert().execute(data="data1") - test_table.insert().execute(data="data2") - eq_( - test_table.select().order_by(test_table.c.id).execute().fetchall(), - [(1, "data1"), (2, "data2")], - ) - test_table.update().where(test_table.c.id == 2).values( - data=test_table.c.data + " updated" - ).execute() - eq_( - test_table.select().order_by(test_table.c.id).execute().fetchall(), - [(1, "data1"), (2, "data2 updated")], + + with engine.connect() as connection: + test_table.create(connection, checkfirst=True) + connection.execute(test_table.insert(), dict(data="data1")) + connection.execute(test_table.insert(), dict(data="data2")) + eq_( + connection.execute( + test_table.select().order_by(test_table.c.id) + ).fetchall(), + [(1, "data1"), (2, "data2")], + ) + connection.execute( + test_table.update() + .where(test_table.c.id == 2) + .values(data=test_table.c.data + " updated") + ) + eq_( + connection.execute( + test_table.select().order_by(test_table.c.id) + ).fetchall(), + [(1, "data1"), (2, "data2 updated")], + ) + connection.execute(test_table.delete()) + eq_( + connection.scalar( + select([func.count("*")]).select_from(test_table) + ), + 0, + ) + + @testing.provide_metadata + def test_roundtrip_fetchmany(self): + md = self.metadata + + engine = self._fixture(True) + test_table = Table( + "test_table", + md, + Column("id", Integer, primary_key=True), + Column("data", String(50)), ) - test_table.delete().execute() - eq_(select([func.count("*")]).select_from(test_table).scalar(), 0) + + with engine.connect() as connection: + test_table.create(connection, checkfirst=True) + connection.execute( + test_table.insert(), + [dict(data="data%d" % i) for i in range(1, 20)], + ) + + result = connection.execute( + test_table.select().order_by(test_table.c.id) + ) + + eq_( + result.fetchmany(5), [(i, "data%d" % i) for i in range(1, 6)], + ) + eq_( + result.fetchmany(10), + [(i, "data%d" % i) for i in range(6, 16)], + ) + eq_(result.fetchall(), [(i, "data%d" % i) for i in range(16, 20)]) diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index 48144f885..5e6ac1eab 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -35,6 +35,7 @@ from ... import Text from ... import Time from ... import TIMESTAMP from ... import type_coerce +from ... import TypeDecorator from ... import Unicode from ... import UnicodeText from ... import util @@ -282,6 +283,9 @@ class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase): @classmethod def define_tables(cls, metadata): + class Decorated(TypeDecorator): + impl = cls.datatype + Table( "date_table", metadata, @@ -289,6 +293,7 @@ class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase): "id", Integer, primary_key=True, test_needs_autoincrement=True ), Column("date_data", cls.datatype), + Column("decorated_date_data", Decorated), ) def test_round_trip(self, connection): @@ -302,6 +307,21 @@ class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase): eq_(row, (compare,)) assert isinstance(row[0], type(compare)) + def test_round_trip_decorated(self, connection): + date_table = self.tables.date_table + + connection.execute( + date_table.insert(), {"decorated_date_data": self.data} + ) + + row = connection.execute( + select(date_table.c.decorated_date_data) + ).first() + + compare = self.compare or self.data + eq_(row, (compare,)) + assert isinstance(row[0], type(compare)) + def test_null(self, connection): date_table = self.tables.date_table @@ -526,6 +546,7 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): Float(precision=8, asdecimal=True), [15.7563, decimal.Decimal("15.7563"), None], [decimal.Decimal("15.7563"), None], + filter_=lambda n: n is not None and round(n, 4) or None, ) def test_float_as_float(self): @@ -777,6 +798,7 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): # ("json", {"foo": "bar"}), id_="sa", )(fn) + return fn @_index_fixtures @@ -1139,7 +1161,15 @@ class JSONStringCastIndexTest(_LiteralRoundTripFixture, fixtures.TablesTest): and_(name == "r6", cast(col["b"], String) == '"some value"'), "r6" ) - def test_crit_against_string_coerce_type(self): + def test_crit_against_int_basic(self): + name = self.tables.data_table.c.name + col = self.tables.data_table.c["data"] + + self._test_index_criteria( + and_(name == "r6", cast(col["a"], String) == "5"), "r6" + ) + + def _dont_test_crit_against_string_coerce_type(self): name = self.tables.data_table.c.name col = self.tables.data_table.c["data"] @@ -1152,15 +1182,7 @@ class JSONStringCastIndexTest(_LiteralRoundTripFixture, fixtures.TablesTest): test_literal=False, ) - def test_crit_against_int_basic(self): - name = self.tables.data_table.c.name - col = self.tables.data_table.c["data"] - - self._test_index_criteria( - and_(name == "r6", cast(col["a"], String) == "5"), "r6" - ) - - def test_crit_against_int_coerce_type(self): + def _dont_test_crit_against_int_coerce_type(self): name = self.tables.data_table.c.name col = self.tables.data_table.c["data"] |
