summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing
diff options
context:
space:
mode:
authorSławek Ehlert <slafs@op.pl>2015-01-27 22:04:38 +0100
committerSławek Ehlert <slafs@op.pl>2015-01-27 22:04:38 +0100
commit57b2bd5dcba6140b511c898c0f682234f13d5c51 (patch)
treea0899b2a35d27e177001b163054c3c9a8f7f1c06 /lib/sqlalchemy/testing
parent6a1f16d09958e549502a0991890d64964c71b357 (diff)
parent8aaa8dd6bdfb85fa481efa3115b9080d935d344c (diff)
downloadsqlalchemy-pr/152.tar.gz
Merge branch 'master' into oracle-servicename-optionpr/152
Diffstat (limited to 'lib/sqlalchemy/testing')
-rw-r--r--lib/sqlalchemy/testing/__init__.py3
-rw-r--r--lib/sqlalchemy/testing/assertions.py27
-rw-r--r--lib/sqlalchemy/testing/assertsql.py530
-rw-r--r--lib/sqlalchemy/testing/engines.py6
-rw-r--r--lib/sqlalchemy/testing/exclusions.py2
-rw-r--r--lib/sqlalchemy/testing/fixtures.py5
-rw-r--r--lib/sqlalchemy/testing/plugin/plugin_base.py15
-rw-r--r--lib/sqlalchemy/testing/plugin/pytestplugin.py3
-rw-r--r--lib/sqlalchemy/testing/profiling.py6
-rw-r--r--lib/sqlalchemy/testing/requirements.py5
-rw-r--r--lib/sqlalchemy/testing/suite/test_reflection.py6
-rw-r--r--lib/sqlalchemy/testing/util.py59
12 files changed, 364 insertions, 303 deletions
diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py
index 1f37b4b45..2375a13a9 100644
--- a/lib/sqlalchemy/testing/__init__.py
+++ b/lib/sqlalchemy/testing/__init__.py
@@ -23,7 +23,8 @@ from .assertions import emits_warning, emits_warning_on, uses_deprecated, \
assert_raises_message, AssertsCompiledSQL, ComparesTables, \
AssertsExecutionResults, expect_deprecated, expect_warnings
-from .util import run_as_contextmanager, rowset, fail, provide_metadata, adict
+from .util import run_as_contextmanager, rowset, fail, \
+ provide_metadata, adict, force_drop_names
crashes = skip
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index bf7c27a89..635f6c539 100644
--- a/lib/sqlalchemy/testing/assertions.py
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -229,6 +229,7 @@ class AssertsCompiledSQL(object):
def assert_compile(self, clause, result, params=None,
checkparams=None, dialect=None,
checkpositional=None,
+ check_prefetch=None,
use_default_dialect=False,
allow_dialect_select=False,
literal_binds=False):
@@ -289,6 +290,8 @@ class AssertsCompiledSQL(object):
if checkpositional is not None:
p = c.construct_params(params)
eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
+ if check_prefetch is not None:
+ eq_(c.prefetch, check_prefetch)
class ComparesTables(object):
@@ -405,29 +408,27 @@ class AssertsExecutionResults(object):
cls.__name__, repr(expected_item)))
return True
+ def sql_execution_asserter(self, db=None):
+ if db is None:
+ from . import db as db
+
+ return assertsql.assert_engine(db)
+
def assert_sql_execution(self, db, callable_, *rules):
- assertsql.asserter.add_rules(rules)
- try:
+ with self.sql_execution_asserter(db) as asserter:
callable_()
- assertsql.asserter.statement_complete()
- finally:
- assertsql.asserter.clear_rules()
+ asserter.assert_(*rules)
- def assert_sql(self, db, callable_, list_, with_sequences=None):
- if (with_sequences is not None and
- config.db.dialect.supports_sequences):
- rules = with_sequences
- else:
- rules = list_
+ def assert_sql(self, db, callable_, rules):
newrules = []
for rule in rules:
if isinstance(rule, dict):
newrule = assertsql.AllOf(*[
- assertsql.ExactSQL(k, v) for k, v in rule.items()
+ assertsql.CompiledSQL(k, v) for k, v in rule.items()
])
else:
- newrule = assertsql.ExactSQL(*rule)
+ newrule = assertsql.CompiledSQL(*rule)
newrules.append(newrule)
self.assert_sql_execution(db, callable_, *newrules)
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
index bcc999fe3..5c746e8f1 100644
--- a/lib/sqlalchemy/testing/assertsql.py
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -8,84 +8,141 @@
from ..engine.default import DefaultDialect
from .. import util
import re
+import collections
+import contextlib
+from .. import event
+from sqlalchemy.schema import _DDLCompiles
+from sqlalchemy.engine.util import _distill_params
class AssertRule(object):
- def process_execute(self, clauseelement, *multiparams, **params):
- pass
+ is_consumed = False
+ errormessage = None
+ consume_statement = True
- def process_cursor_execute(self, statement, parameters, context,
- executemany):
+ def process_statement(self, execute_observed):
pass
- def is_consumed(self):
- """Return True if this rule has been consumed, False if not.
-
- Should raise an AssertionError if this rule's condition has
- definitely failed.
-
- """
-
- raise NotImplementedError()
+ def no_more_statements(self):
+ assert False, 'All statements are complete, but pending '\
+ 'assertion rules remain'
- def rule_passed(self):
- """Return True if the last test of this rule passed, False if
- failed, None if no test was applied."""
- raise NotImplementedError()
-
- def consume_final(self):
- """Return True if this rule has been consumed.
-
- Should raise an AssertionError if this rule's condition has not
- been consumed or has failed.
+class SQLMatchRule(AssertRule):
+ pass
- """
- if self._result is None:
- assert False, 'Rule has not been consumed'
- return self.is_consumed()
+class CursorSQL(SQLMatchRule):
+ consume_statement = False
+ def __init__(self, statement, params=None):
+ self.statement = statement
+ self.params = params
-class SQLMatchRule(AssertRule):
- def __init__(self):
- self._result = None
- self._errmsg = ""
+ def process_statement(self, execute_observed):
+ stmt = execute_observed.statements[0]
+ if self.statement != stmt.statement or (
+ self.params is not None and self.params != stmt.parameters):
+ self.errormessage = \
+ "Testing for exact SQL %s parameters %s received %s %s" % (
+ self.statement, self.params,
+ stmt.statement, stmt.parameters
+ )
+ else:
+ execute_observed.statements.pop(0)
+ self.is_consumed = True
+ if not execute_observed.statements:
+ self.consume_statement = True
- def rule_passed(self):
- return self._result
- def is_consumed(self):
- if self._result is None:
- return False
+class CompiledSQL(SQLMatchRule):
- assert self._result, self._errmsg
+ def __init__(self, statement, params=None):
+ self.statement = statement
+ self.params = params
- return True
+ def _compare_sql(self, execute_observed, received_statement):
+ stmt = re.sub(r'[\n\t]', '', self.statement)
+ return received_statement == stmt
+ def _compile_dialect(self, execute_observed):
+ return DefaultDialect()
-class ExactSQL(SQLMatchRule):
+ def _received_statement(self, execute_observed):
+ """reconstruct the statement and params in terms
+ of a target dialect, which for CompiledSQL is just DefaultDialect."""
- def __init__(self, sql, params=None):
- SQLMatchRule.__init__(self)
- self.sql = sql
- self.params = params
+ context = execute_observed.context
+ compare_dialect = self._compile_dialect(execute_observed)
+ if isinstance(context.compiled.statement, _DDLCompiles):
+ compiled = \
+ context.compiled.statement.compile(dialect=compare_dialect)
+ else:
+ compiled = (
+ context.compiled.statement.compile(
+ dialect=compare_dialect,
+ column_keys=context.compiled.column_keys,
+ inline=context.compiled.inline)
+ )
+ _received_statement = re.sub(r'[\n\t]', '', str(compiled))
+ parameters = execute_observed.parameters
- def process_cursor_execute(self, statement, parameters, context,
- executemany):
- if not context:
- return
- _received_statement = \
- _process_engine_statement(context.unicode_statement,
- context)
- _received_parameters = context.compiled_parameters
+ if not parameters:
+ _received_parameters = [compiled.construct_params()]
+ else:
+ _received_parameters = [
+ compiled.construct_params(m) for m in parameters]
+
+ return _received_statement, _received_parameters
+
+ def process_statement(self, execute_observed):
+ context = execute_observed.context
+
+ _received_statement, _received_parameters = \
+ self._received_statement(execute_observed)
+ params = self._all_params(context)
+
+ equivalent = self._compare_sql(execute_observed, _received_statement)
+
+ if equivalent:
+ if params is not None:
+ all_params = list(params)
+ all_received = list(_received_parameters)
+ while all_params and all_received:
+ param = dict(all_params.pop(0))
+
+ for idx, received in enumerate(list(all_received)):
+ # do a positive compare only
+ for param_key in param:
+ # a key in param did not match current
+ # 'received'
+ if param_key not in received or \
+ received[param_key] != param[param_key]:
+ break
+ else:
+ # all keys in param matched 'received';
+ # onto next param
+ del all_received[idx]
+ break
+ else:
+ # param did not match any entry
+ # in all_received
+ equivalent = False
+ break
+ if all_params or all_received:
+ equivalent = False
- # TODO: remove this step once all unit tests are migrated, as
- # ExactSQL should really be *exact* SQL
+ if equivalent:
+ self.is_consumed = True
+ self.errormessage = None
+ else:
+ self.errormessage = self._failure_message(params) % {
+ 'received_statement': _received_statement,
+ 'received_parameters': _received_parameters
+ }
- sql = _process_assertion_statement(self.sql, context)
- equivalent = _received_statement == sql
+ def _all_params(self, context):
if self.params:
if util.callable(self.params):
params = self.params(context)
@@ -93,127 +150,77 @@ class ExactSQL(SQLMatchRule):
params = self.params
if not isinstance(params, list):
params = [params]
- equivalent = equivalent and params \
- == context.compiled_parameters
+ return params
else:
- params = {}
- self._result = equivalent
- if not self._result:
- self._errmsg = (
- 'Testing for exact statement %r exact params %r, '
- 'received %r with params %r' %
- (sql, params, _received_statement, _received_parameters))
-
+ return None
+
+ def _failure_message(self, expected_params):
+ return (
+ 'Testing for compiled statement %r partial params %r, '
+ 'received %%(received_statement)r with params '
+ '%%(received_parameters)r' % (
+ self.statement, expected_params
+ )
+ )
-class RegexSQL(SQLMatchRule):
+class RegexSQL(CompiledSQL):
def __init__(self, regex, params=None):
SQLMatchRule.__init__(self)
self.regex = re.compile(regex)
self.orig_regex = regex
self.params = params
- def process_cursor_execute(self, statement, parameters, context,
- executemany):
- if not context:
- return
- _received_statement = \
- _process_engine_statement(context.unicode_statement,
- context)
- _received_parameters = context.compiled_parameters
- equivalent = bool(self.regex.match(_received_statement))
- if self.params:
- if util.callable(self.params):
- params = self.params(context)
- else:
- params = self.params
- if not isinstance(params, list):
- params = [params]
-
- # do a positive compare only
-
- for param, received in zip(params, _received_parameters):
- for k, v in param.items():
- if k not in received or received[k] != v:
- equivalent = False
- break
- else:
- params = {}
- self._result = equivalent
- if not self._result:
- self._errmsg = \
- 'Testing for regex %r partial params %r, received %r '\
- 'with params %r' % (self.orig_regex, params,
- _received_statement,
- _received_parameters)
-
+ def _failure_message(self, expected_params):
+ return (
+ 'Testing for compiled statement ~%r partial params %r, '
+ 'received %%(received_statement)r with params '
+ '%%(received_parameters)r' % (
+ self.orig_regex, expected_params
+ )
+ )
-class CompiledSQL(SQLMatchRule):
+ def _compare_sql(self, execute_observed, received_statement):
+ return bool(self.regex.match(received_statement))
- def __init__(self, statement, params=None):
- SQLMatchRule.__init__(self)
- self.statement = statement
- self.params = params
- def process_cursor_execute(self, statement, parameters, context,
- executemany):
- if not context:
- return
- from sqlalchemy.schema import _DDLCompiles
- _received_parameters = list(context.compiled_parameters)
-
- # recompile from the context, using the default dialect
+class DialectSQL(CompiledSQL):
+ def _compile_dialect(self, execute_observed):
+ return execute_observed.context.dialect
- if isinstance(context.compiled.statement, _DDLCompiles):
- compiled = \
- context.compiled.statement.compile(dialect=DefaultDialect())
+ def _received_statement(self, execute_observed):
+ received_stmt, received_params = super(DialectSQL, self).\
+ _received_statement(execute_observed)
+ for real_stmt in execute_observed.statements:
+ if real_stmt.statement == received_stmt:
+ break
else:
- compiled = (
- context.compiled.statement.compile(
- dialect=DefaultDialect(),
- column_keys=context.compiled.column_keys)
- )
- _received_statement = re.sub(r'[\n\t]', '', str(compiled))
- equivalent = self.statement == _received_statement
- if self.params:
- if util.callable(self.params):
- params = self.params(context)
- else:
- params = self.params
- if not isinstance(params, list):
- params = [params]
- else:
- params = list(params)
- all_params = list(params)
- all_received = list(_received_parameters)
- while params:
- param = dict(params.pop(0))
- for k, v in context.compiled.params.items():
- param.setdefault(k, v)
- if param not in _received_parameters:
- equivalent = False
- break
- else:
- _received_parameters.remove(param)
- if _received_parameters:
- equivalent = False
+ raise AssertionError(
+ "Can't locate compiled statement %r in list of "
+ "statements actually invoked" % received_stmt)
+ return received_stmt, execute_observed.context.compiled_parameters
+
+ def _compare_sql(self, execute_observed, received_statement):
+ stmt = re.sub(r'[\n\t]', '', self.statement)
+
+ # convert our comparison statement to have the
+ # paramstyle of the received
+ paramstyle = execute_observed.context.dialect.paramstyle
+ if paramstyle == 'pyformat':
+ stmt = re.sub(
+ r':([\w_]+)', r"%(\1)s", stmt)
else:
- params = {}
- all_params = {}
- all_received = []
- self._result = equivalent
- if not self._result:
- print('Testing for compiled statement %r partial params '
- '%r, received %r with params %r' %
- (self.statement, all_params,
- _received_statement, all_received))
- self._errmsg = (
- 'Testing for compiled statement %r partial params %r, '
- 'received %r with params %r' %
- (self.statement, all_params,
- _received_statement, all_received))
-
- # print self._errmsg
+ # positional params
+ repl = None
+ if paramstyle == 'qmark':
+ repl = "?"
+ elif paramstyle == 'format':
+ repl = r"%s"
+ elif paramstyle == 'numeric':
+ repl = None
+ stmt = re.sub(r':([\w_]+)', repl, stmt)
+
+ return received_statement == stmt
class CountStatements(AssertRule):
@@ -222,21 +229,13 @@ class CountStatements(AssertRule):
self.count = count
self._statement_count = 0
- def process_execute(self, clauseelement, *multiparams, **params):
+ def process_statement(self, execute_observed):
self._statement_count += 1
- def process_cursor_execute(self, statement, parameters, context,
- executemany):
- pass
-
- def is_consumed(self):
- return False
-
- def consume_final(self):
- assert self.count == self._statement_count, \
- 'desired statement count %d does not match %d' \
- % (self.count, self._statement_count)
- return True
+ def no_more_statements(self):
+ if self.count != self._statement_count:
+ assert False, 'desired statement count %d does not match %d' \
+ % (self.count, self._statement_count)
class AllOf(AssertRule):
@@ -244,116 +243,113 @@ class AllOf(AssertRule):
def __init__(self, *rules):
self.rules = set(rules)
- def process_execute(self, clauseelement, *multiparams, **params):
- for rule in self.rules:
- rule.process_execute(clauseelement, *multiparams, **params)
-
- def process_cursor_execute(self, statement, parameters, context,
- executemany):
- for rule in self.rules:
- rule.process_cursor_execute(statement, parameters, context,
- executemany)
-
- def is_consumed(self):
- if not self.rules:
- return True
+ def process_statement(self, execute_observed):
for rule in list(self.rules):
- if rule.rule_passed(): # a rule passed, move on
- self.rules.remove(rule)
- return len(self.rules) == 0
- return False
+ rule.errormessage = None
+ rule.process_statement(execute_observed)
+ if rule.is_consumed:
+ self.rules.discard(rule)
+ if not self.rules:
+ self.is_consumed = True
+ break
+ elif not rule.errormessage:
+ # rule is not done yet
+ self.errormessage = None
+ break
+ else:
+ self.errormessage = list(self.rules)[0].errormessage
- def rule_passed(self):
- return self.is_consumed()
- def consume_final(self):
- return len(self.rules) == 0
+class Or(AllOf):
+ def process_statement(self, execute_observed):
+ for rule in self.rules:
+ rule.process_statement(execute_observed)
+ if rule.is_consumed:
+ self.is_consumed = True
+ break
+ else:
+ self.errormessage = list(self.rules)[0].errormessage
-class Or(AllOf):
- def __init__(self, *rules):
- self.rules = set(rules)
- self._consume_final = False
- def is_consumed(self):
- if not self.rules:
- return True
- for rule in list(self.rules):
- if rule.rule_passed(): # a rule passed
- self._consume_final = True
- return True
- return False
+class SQLExecuteObserved(object):
+ def __init__(self, context, clauseelement, multiparams, params):
+ self.context = context
+ self.clauseelement = clauseelement
+ self.parameters = _distill_params(multiparams, params)
+ self.statements = []
- def consume_final(self):
- assert self._consume_final, "Unsatisified rules remain"
+class SQLCursorExecuteObserved(
+ collections.namedtuple(
+ "SQLCursorExecuteObserved",
+ ["statement", "parameters", "context", "executemany"])
+):
+ pass
-def _process_engine_statement(query, context):
- if util.jython:
- # oracle+zxjdbc passes a PyStatement when returning into
+class SQLAsserter(object):
+ def __init__(self):
+ self.accumulated = []
- query = str(query)
- if context.engine.name == 'mssql' \
- and query.endswith('; select scope_identity()'):
- query = query[:-25]
- query = re.sub(r'\n', '', query)
- return query
+ def _close(self):
+ self._final = self.accumulated
+ del self.accumulated
+ def assert_(self, *rules):
+ rules = list(rules)
+ observed = list(self._final)
-def _process_assertion_statement(query, context):
- paramstyle = context.dialect.paramstyle
- if paramstyle == 'named':
- pass
- elif paramstyle == 'pyformat':
- query = re.sub(r':([\w_]+)', r"%(\1)s", query)
- else:
- # positional params
- repl = None
- if paramstyle == 'qmark':
- repl = "?"
- elif paramstyle == 'format':
- repl = r"%s"
- elif paramstyle == 'numeric':
- repl = None
- query = re.sub(r':([\w_]+)', repl, query)
+ while observed and rules:
+ rule = rules[0]
+ rule.process_statement(observed[0])
+ if rule.is_consumed:
+ rules.pop(0)
+ elif rule.errormessage:
+ assert False, rule.errormessage
- return query
+ if rule.consume_statement:
+ observed.pop(0)
+ if not observed and rules:
+ rules[0].no_more_statements()
+ elif not rules and observed:
+ assert False, "Additional SQL statements remain"
-class SQLAssert(object):
- rules = None
+@contextlib.contextmanager
+def assert_engine(engine):
+ asserter = SQLAsserter()
- def add_rules(self, rules):
- self.rules = list(rules)
+ orig = []
- def statement_complete(self):
- for rule in self.rules:
- if not rule.consume_final():
- assert False, \
- 'All statements are complete, but pending '\
- 'assertion rules remain'
-
- def clear_rules(self):
- del self.rules
-
- def execute(self, conn, clauseelement, multiparams, params, result):
- if self.rules is not None:
- if not self.rules:
- assert False, \
- 'All rules have been exhausted, but further '\
- 'statements remain'
- rule = self.rules[0]
- rule.process_execute(clauseelement, *multiparams, **params)
- if rule.is_consumed():
- self.rules.pop(0)
-
- def cursor_execute(self, conn, cursor, statement, parameters,
- context, executemany):
- if self.rules:
- rule = self.rules[0]
- rule.process_cursor_execute(statement, parameters, context,
- executemany)
+ @event.listens_for(engine, "before_execute")
+ def connection_execute(conn, clauseelement, multiparams, params):
+ # grab the original statement + params before any cursor
+ # execution
+ orig[:] = clauseelement, multiparams, params
-asserter = SQLAssert()
+ @event.listens_for(engine, "after_cursor_execute")
+ def cursor_execute(conn, cursor, statement, parameters,
+ context, executemany):
+ if not context:
+ return
+ # then grab real cursor statements and associate them all
+ # around a single context
+ if asserter.accumulated and \
+ asserter.accumulated[-1].context is context:
+ obs = asserter.accumulated[-1]
+ else:
+ obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2])
+ asserter.accumulated.append(obs)
+ obs.statements.append(
+ SQLCursorExecuteObserved(
+ statement, parameters, context, executemany)
+ )
+
+ try:
+ yield asserter
+ finally:
+ event.remove(engine, "after_cursor_execute", cursor_execute)
+ event.remove(engine, "before_execute", connection_execute)
+ asserter._close()
diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py
index 0f6f59401..444a79b70 100644
--- a/lib/sqlalchemy/testing/engines.py
+++ b/lib/sqlalchemy/testing/engines.py
@@ -204,7 +204,6 @@ def testing_engine(url=None, options=None):
"""Produce an engine configured by --options with optional overrides."""
from sqlalchemy import create_engine
- from .assertsql import asserter
if not options:
use_reaper = True
@@ -216,11 +215,12 @@ def testing_engine(url=None, options=None):
options = config.db_opts
engine = create_engine(url, **options)
+ engine._has_events = True # enable event blocks, helps with
+ # profiling
+
if isinstance(engine.pool, pool.QueuePool):
engine.pool._timeout = 0
engine.pool._max_overflow = 0
- event.listen(engine, 'after_execute', asserter.execute)
- event.listen(engine, 'after_cursor_execute', asserter.cursor_execute)
if use_reaper:
event.listen(engine.pool, 'connect', testing_reaper.connect)
event.listen(engine.pool, 'checkout', testing_reaper.checkout)
diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py
index f94724608..0aff43ae1 100644
--- a/lib/sqlalchemy/testing/exclusions.py
+++ b/lib/sqlalchemy/testing/exclusions.py
@@ -425,7 +425,7 @@ def skip(db, reason=None):
def only_on(dbs, reason=None):
return only_if(
- OrPredicate([SpecPredicate(db) for db in util.to_list(dbs)])
+ OrPredicate([Predicate.as_predicate(db) for db in util.to_list(dbs)])
)
diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py
index d86049da7..48d4d9c9b 100644
--- a/lib/sqlalchemy/testing/fixtures.py
+++ b/lib/sqlalchemy/testing/fixtures.py
@@ -192,9 +192,8 @@ class TablesTest(TestBase):
def sql_count_(self, count, fn):
self.assert_sql_count(self.bind, fn, count)
- def sql_eq_(self, callable_, statements, with_sequences=None):
- self.assert_sql(self.bind,
- callable_, statements, with_sequences)
+ def sql_eq_(self, callable_, statements):
+ self.assert_sql(self.bind, callable_, statements)
@classmethod
def _load_fixtures(cls):
diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py
index 614a12133..b0188aa5a 100644
--- a/lib/sqlalchemy/testing/plugin/plugin_base.py
+++ b/lib/sqlalchemy/testing/plugin/plugin_base.py
@@ -294,7 +294,7 @@ def _setup_requirements(argument):
@post
def _prep_testing_database(options, file_config):
- from sqlalchemy.testing import config
+ from sqlalchemy.testing import config, util
from sqlalchemy.testing.exclusions import against
from sqlalchemy import schema, inspect
@@ -325,19 +325,10 @@ def _prep_testing_database(options, file_config):
schema="test_schema")
))
- for tname in reversed(inspector.get_table_names(
- order_by="foreign_key")):
- e.execute(schema.DropTable(
- schema.Table(tname, schema.MetaData())
- ))
+ util.drop_all_tables(e, inspector)
if config.requirements.schemas.enabled_for_config(cfg):
- for tname in reversed(inspector.get_table_names(
- order_by="foreign_key", schema="test_schema")):
- e.execute(schema.DropTable(
- schema.Table(tname, schema.MetaData(),
- schema="test_schema")
- ))
+ util.drop_all_tables(e, inspector, schema=cfg.test_schema)
if against(cfg, "postgresql"):
from sqlalchemy.dialects import postgresql
diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py
index 4bbc8ed9a..fbab4966c 100644
--- a/lib/sqlalchemy/testing/plugin/pytestplugin.py
+++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py
@@ -84,7 +84,8 @@ def pytest_collection_modifyitems(session, config, items):
rebuilt_items = collections.defaultdict(list)
items[:] = [
item for item in
- items if isinstance(item.parent, pytest.Instance)]
+ items if isinstance(item.parent, pytest.Instance)
+ and not item.parent.parent.name.startswith("_")]
test_classes = set(item.parent for item in items)
for test_class in test_classes:
for sub_cls in plugin_base.generate_sub_tests(
diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py
index 671bbe32d..57308925e 100644
--- a/lib/sqlalchemy/testing/profiling.py
+++ b/lib/sqlalchemy/testing/profiling.py
@@ -226,6 +226,7 @@ def count_functions(variance=0.05):
callcount = stats.total_calls
expected = _profile_stats.result(callcount)
+
if expected is None:
expected_count = None
else:
@@ -249,10 +250,11 @@ def count_functions(variance=0.05):
else:
raise AssertionError(
"Adjusted function call count %s not within %s%% "
- "of expected %s. Rerun with --write-profiles to "
+ "of expected %s, platform %s. Rerun with "
+ "--write-profiles to "
"regenerate this callcount."
% (
callcount, (variance * 100),
- expected_count))
+ expected_count, _profile_stats.platform_key))
diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py
index da3e3128a..5744431cb 100644
--- a/lib/sqlalchemy/testing/requirements.py
+++ b/lib/sqlalchemy/testing/requirements.py
@@ -323,6 +323,11 @@ class SuiteRequirements(Requirements):
return exclusions.closed()
@property
+ def temporary_tables(self):
+ """target database supports temporary tables"""
+ return exclusions.open()
+
+ @property
def temporary_views(self):
"""target database supports temporary views"""
return exclusions.closed()
diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py
index 08b858b47..3edbdeb8c 100644
--- a/lib/sqlalchemy/testing/suite/test_reflection.py
+++ b/lib/sqlalchemy/testing/suite/test_reflection.py
@@ -128,6 +128,10 @@ class ComponentReflectionTest(fixtures.TablesTest):
DDL("create temporary view user_tmp_v as "
"select * from user_tmp")
)
+ event.listen(
+ user_tmp, "before_drop",
+ DDL("drop view user_tmp_v")
+ )
@classmethod
def define_index(cls, metadata, users):
@@ -511,6 +515,8 @@ class ComponentReflectionTest(fixtures.TablesTest):
def test_get_temp_table_indexes(self):
insp = inspect(self.metadata.bind)
indexes = insp.get_indexes('user_tmp')
+ for ind in indexes:
+ ind.pop('dialect_options', None)
eq_(
# TODO: we need to add better filtering for indexes/uq constraints
# that are doubled up
diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py
index 7b3f721a6..8230f923a 100644
--- a/lib/sqlalchemy/testing/util.py
+++ b/lib/sqlalchemy/testing/util.py
@@ -147,6 +147,10 @@ def run_as_contextmanager(ctx, fn, *arg, **kw):
simulating the behavior of 'with' to support older
Python versions.
+ This is not necessary anymore as we have placed 2.6
+ as minimum Python version, however some tests are still using
+ this structure.
+
"""
obj = ctx.__enter__()
@@ -194,6 +198,25 @@ def provide_metadata(fn, *args, **kw):
self.metadata = prev_meta
+def force_drop_names(*names):
+ """Force the given table names to be dropped after test complete,
+ isolating for foreign key cycles
+
+ """
+ from . import config
+ from sqlalchemy import inspect
+
+ @decorator
+ def go(fn, *args, **kw):
+
+ try:
+ return fn(*args, **kw)
+ finally:
+ drop_all_tables(
+ config.db, inspect(config.db), include_names=names)
+ return go
+
+
class adict(dict):
"""Dict keys available as attributes. Shadows."""
@@ -207,3 +230,39 @@ class adict(dict):
return tuple([self[key] for key in keys])
get_all = __call__
+
+
+def drop_all_tables(engine, inspector, schema=None, include_names=None):
+ from sqlalchemy import Column, Table, Integer, MetaData, \
+ ForeignKeyConstraint
+ from sqlalchemy.schema import DropTable, DropConstraint
+
+ if include_names is not None:
+ include_names = set(include_names)
+
+ with engine.connect() as conn:
+ for tname, fkcs in reversed(
+ inspector.get_sorted_table_and_fkc_names(schema=schema)):
+ if tname:
+ if include_names is not None and tname not in include_names:
+ continue
+ conn.execute(DropTable(
+ Table(tname, MetaData())
+ ))
+ elif fkcs:
+ if not engine.dialect.supports_alter:
+ continue
+ for tname, fkc in fkcs:
+ if include_names is not None and \
+ tname not in include_names:
+ continue
+ tb = Table(
+ tname, MetaData(),
+ Column('x', Integer),
+ Column('y', Integer),
+ schema=schema
+ )
+ conn.execute(DropConstraint(
+ ForeignKeyConstraint(
+ [tb.c.x], [tb.c.y], name=fkc)
+ ))