diff options
Diffstat (limited to 'lib/sqlalchemy/testing/assertions.py')
-rw-r--r-- | lib/sqlalchemy/testing/assertions.py | 73 |
1 files changed, 48 insertions, 25 deletions
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index 96a8bc023..61649e5e3 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -1,8 +1,14 @@ +# testing/assertions.py +# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + from __future__ import absolute_import from . import util as testutil from sqlalchemy import pool, orm, util -from sqlalchemy.engine import default, create_engine +from sqlalchemy.engine import default, create_engine, url from sqlalchemy import exc as sa_exc from sqlalchemy.util import decorator from sqlalchemy import types as sqltypes, schema @@ -92,30 +98,36 @@ def uses_deprecated(*messages): @decorator def decorate(fn, *args, **kw): - # todo: should probably be strict about this, too - filters = [dict(action='ignore', - category=sa_exc.SAPendingDeprecationWarning)] - if not messages: - filters.append(dict(action='ignore', - category=sa_exc.SADeprecationWarning)) - else: - filters.extend( - [dict(action='ignore', - message=message, - category=sa_exc.SADeprecationWarning) - for message in - [(m.startswith('//') and - ('Call to deprecated function ' + m[2:]) or m) - for m in messages]]) - - for f in filters: - warnings.filterwarnings(**f) - try: + with expect_deprecated(*messages): return fn(*args, **kw) - finally: - resetwarnings() return decorate +@contextlib.contextmanager +def expect_deprecated(*messages): + # todo: should probably be strict about this, too + filters = [dict(action='ignore', + category=sa_exc.SAPendingDeprecationWarning)] + if not messages: + filters.append(dict(action='ignore', + category=sa_exc.SADeprecationWarning)) + else: + filters.extend( + [dict(action='ignore', + message=message, + category=sa_exc.SADeprecationWarning) + for message in + [(m.startswith('//') and + ('Call to deprecated function ' + m[2:]) or m) + for m in messages]]) + + for f in filters: + warnings.filterwarnings(**f) + try: + yield + finally: + resetwarnings() + + def global_cleanup_assertions(): """Check things that have to be finalized at the end of a test suite. @@ -181,7 +193,8 @@ class AssertsCompiledSQL(object): checkparams=None, dialect=None, checkpositional=None, use_default_dialect=False, - allow_dialect_select=False): + allow_dialect_select=False, + literal_binds=False): if use_default_dialect: dialect = default.DefaultDialect() elif allow_dialect_select: @@ -195,26 +208,36 @@ class AssertsCompiledSQL(object): elif dialect == 'default': dialect = default.DefaultDialect() elif isinstance(dialect, util.string_types): - dialect = create_engine("%s://" % dialect).dialect + dialect = url.URL(dialect).get_dialect()() kw = {} + compile_kwargs = {} + if params is not None: kw['column_keys'] = list(params) + if literal_binds: + compile_kwargs['literal_binds'] = True + if isinstance(clause, orm.Query): context = clause._compile_context() context.statement.use_labels = True clause = context.statement + if compile_kwargs: + kw['compile_kwargs'] = compile_kwargs + c = clause.compile(dialect=dialect, **kw) param_str = repr(getattr(c, 'params', {})) if util.py3k: param_str = param_str.encode('utf-8').decode('ascii', 'ignore') + print(("\nSQL String:\n" + util.text_type(c) + param_str).encode('utf-8')) + else: + print("\nSQL String:\n" + util.text_type(c).encode('utf-8') + param_str) - print("\nSQL String:\n" + util.text_type(c) + param_str) cc = re.sub(r'[\n\t]', '', util.text_type(c)) |