diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2009-06-10 21:18:24 +0000 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2009-06-10 21:18:24 +0000 |
| commit | 45cec095b4904ba71425d2fe18c143982dd08f43 (patch) | |
| tree | af5e540fdcbf1cb2a3337157d69d4b40be010fa8 /test/testlib | |
| parent | 698a3c1ac665e7cd2ef8d5ad3ebf51b7fe6661f4 (diff) | |
| download | sqlalchemy-45cec095b4904ba71425d2fe18c143982dd08f43.tar.gz | |
- unit tests have been migrated from unittest to nose.
See README.unittests for information on how to run
the tests. [ticket:970]
Diffstat (limited to 'test/testlib')
| -rw-r--r-- | test/testlib/__init__.py | 38 | ||||
| -rw-r--r-- | test/testlib/assertsql.py | 283 | ||||
| -rw-r--r-- | test/testlib/compat.py | 19 | ||||
| -rw-r--r-- | test/testlib/config.py | 344 | ||||
| -rw-r--r-- | test/testlib/coverage.py | 1098 | ||||
| -rw-r--r-- | test/testlib/engines.py | 245 | ||||
| -rw-r--r-- | test/testlib/orm.py | 117 | ||||
| -rw-r--r-- | test/testlib/profiling.py | 207 | ||||
| -rw-r--r-- | test/testlib/requires.py | 127 | ||||
| -rw-r--r-- | test/testlib/sa_unittest.py | 787 | ||||
| -rw-r--r-- | test/testlib/schema.py | 79 | ||||
| -rw-r--r-- | test/testlib/testing.py | 919 |
12 files changed, 0 insertions, 4263 deletions
diff --git a/test/testlib/__init__.py b/test/testlib/__init__.py deleted file mode 100644 index 5b8075ddb..000000000 --- a/test/testlib/__init__.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Enhance unittest and instrument SQLAlchemy classes for testing. - -Load after sqlalchemy imports to use instrumented stand-ins like Table. -""" - -import sys -import testlib.config -from testlib.schema import Table, Column -import testlib.testing as testing -from testlib.testing import \ - AssertsCompiledSQL, \ - AssertsExecutionResults, \ - ComparesTables, \ - TestBase, \ - rowset -from testlib.orm import mapper -import testlib.profiling as profiling -import testlib.engines as engines -import testlib.requires as requires -from testlib.compat import _function_named - - -__all__ = ('testing', - 'mapper', - 'Table', 'Column', - 'rowset', - 'TestBase', 'AssertsExecutionResults', - 'AssertsCompiledSQL', 'ComparesTables', - 'profiling', 'engines', - '_function_named') - - -testing.requires = requires - -sys.modules['testlib.sa'] = sa = testing.CompositeModule( - 'testlib.sa', 'sqlalchemy', 'testlib.schema', orm=testing.CompositeModule( - 'testlib.sa.orm', 'sqlalchemy.orm', 'testlib.orm')) -sys.modules['testlib.sa.orm'] = sa.orm diff --git a/test/testlib/assertsql.py b/test/testlib/assertsql.py deleted file mode 100644 index dc2c6d40f..000000000 --- a/test/testlib/assertsql.py +++ /dev/null @@ -1,283 +0,0 @@ - -from sqlalchemy.interfaces import ConnectionProxy -from sqlalchemy.engine.default import DefaultDialect -from sqlalchemy.engine.base import Connection -from sqlalchemy import util -import testing -import re - -class AssertRule(object): - def process_execute(self, clauseelement, *multiparams, **params): - pass - - def process_cursor_execute(self, statement, parameters, context, executemany): - pass - - def is_consumed(self): - """Return True if this rule has been consumed, False if not. - - Should raise an AssertionError if this rule's condition has definitely failed. - - """ - raise NotImplementedError() - - def rule_passed(self): - """Return True if the last test of this rule passed, False if failed, None if no test was applied.""" - - raise NotImplementedError() - - def consume_final(self): - """Return True if this rule has been consumed. - - Should raise an AssertionError if this rule's condition has not been consumed or has failed. - - """ - - if self._result is None: - assert False, "Rule has not been consumed" - - return self.is_consumed() - -class SQLMatchRule(AssertRule): - def __init__(self): - self._result = None - self._errmsg = "" - - def rule_passed(self): - return self._result - - def is_consumed(self): - if self._result is None: - return False - - assert self._result, self._errmsg - - return True - -class ExactSQL(SQLMatchRule): - def __init__(self, sql, params=None): - SQLMatchRule.__init__(self) - self.sql = sql - self.params = params - - def process_cursor_execute(self, statement, parameters, context, executemany): - if not context: - return - - _received_statement = _process_engine_statement(statement, context) - _received_parameters = context.compiled_parameters - - # TODO: remove this step once all unit tests - # are migrated, as ExactSQL should really be *exact* SQL - sql = _process_assertion_statement(self.sql, context) - - equivalent = _received_statement == sql - if self.params: - if util.callable(self.params): - params = self.params(context) - else: - params = self.params - - if not isinstance(params, list): - params = [params] - equivalent = equivalent and params == context.compiled_parameters - else: - params = {} - - - self._result = equivalent - if not self._result: - self._errmsg = "Testing for exact statement %r exact params %r, " \ - "received %r with params %r" % (sql, params, _received_statement, _received_parameters) - - -class RegexSQL(SQLMatchRule): - def __init__(self, regex, params=None): - SQLMatchRule.__init__(self) - self.regex = re.compile(regex) - self.orig_regex = regex - self.params = params - - def process_cursor_execute(self, statement, parameters, context, executemany): - if not context: - return - - _received_statement = _process_engine_statement(statement, context) - _received_parameters = context.compiled_parameters - - equivalent = bool(self.regex.match(_received_statement)) - if self.params: - if util.callable(self.params): - params = self.params(context) - else: - params = self.params - - if not isinstance(params, list): - params = [params] - - # do a positive compare only - for param, received in zip(params, _received_parameters): - for k, v in param.iteritems(): - if k not in received or received[k] != v: - equivalent = False - break - else: - params = {} - - self._result = equivalent - if not self._result: - self._errmsg = "Testing for regex %r partial params %r, "\ - "received %r with params %r" % (self.orig_regex, params, _received_statement, _received_parameters) - -class CompiledSQL(SQLMatchRule): - def __init__(self, statement, params): - SQLMatchRule.__init__(self) - self.statement = statement - self.params = params - - def process_cursor_execute(self, statement, parameters, context, executemany): - if not context: - return - - _received_parameters = context.compiled_parameters - - # recompile from the context, using the default dialect - compiled = context.compiled.statement.\ - compile(dialect=DefaultDialect(), column_keys=context.compiled.column_keys) - - _received_statement = re.sub(r'\n', '', str(compiled)) - - equivalent = self.statement == _received_statement - if self.params: - if util.callable(self.params): - params = self.params(context) - else: - params = self.params - - if not isinstance(params, list): - params = [params] - - # do a positive compare only - for param, received in zip(params, _received_parameters): - for k, v in param.iteritems(): - if k not in received or received[k] != v: - equivalent = False - break - else: - params = {} - - self._result = equivalent - if not self._result: - self._errmsg = "Testing for compiled statement %r partial params %r, " \ - "received %r with params %r" % (self.statement, params, _received_statement, _received_parameters) - - -class CountStatements(AssertRule): - def __init__(self, count): - self.count = count - self._statement_count = 0 - - def process_execute(self, clauseelement, *multiparams, **params): - self._statement_count += 1 - - def process_cursor_execute(self, statement, parameters, context, executemany): - pass - - def is_consumed(self): - return False - - def consume_final(self): - assert self.count == self._statement_count, "desired statement count %d does not match %d" % (self.count, self._statement_count) - return True - -class AllOf(AssertRule): - def __init__(self, *rules): - self.rules = set(rules) - - def process_execute(self, clauseelement, *multiparams, **params): - for rule in self.rules: - rule.process_execute(clauseelement, *multiparams, **params) - - def process_cursor_execute(self, statement, parameters, context, executemany): - for rule in self.rules: - rule.process_cursor_execute(statement, parameters, context, executemany) - - def is_consumed(self): - if not self.rules: - return True - - for rule in list(self.rules): - if rule.rule_passed(): # a rule passed, move on - self.rules.remove(rule) - return len(self.rules) == 0 - - assert False, "No assertion rules were satisfied for statement" - - def consume_final(self): - return len(self.rules) == 0 - -def _process_engine_statement(query, context): - if context.engine.name == 'mssql' and query.endswith('; select scope_identity()'): - query = query[:-25] - - query = re.sub(r'\n', '', query) - - return query - -def _process_assertion_statement(query, context): - paramstyle = context.dialect.paramstyle - if paramstyle == 'named': - pass - elif paramstyle =='pyformat': - query = re.sub(r':([\w_]+)', r"%(\1)s", query) - else: - # positional params - repl = None - if paramstyle=='qmark': - repl = "?" - elif paramstyle=='format': - repl = r"%s" - elif paramstyle=='numeric': - repl = None - query = re.sub(r':([\w_]+)', repl, query) - - return query - -class SQLAssert(ConnectionProxy): - rules = None - - def add_rules(self, rules): - self.rules = list(rules) - - def statement_complete(self): - for rule in self.rules: - if not rule.consume_final(): - assert False, "All statements are complete, but pending assertion rules remain" - - def clear_rules(self): - del self.rules - - def execute(self, conn, execute, clauseelement, *multiparams, **params): - result = execute(clauseelement, *multiparams, **params) - - if self.rules is not None: - if not self.rules: - assert False, "All rules have been exhausted, but further statements remain" - rule = self.rules[0] - rule.process_execute(clauseelement, *multiparams, **params) - if rule.is_consumed(): - self.rules.pop(0) - - return result - - def cursor_execute(self, execute, cursor, statement, parameters, context, executemany): - result = execute(cursor, statement, parameters, context) - - if self.rules: - rule = self.rules[0] - rule.process_cursor_execute(statement, parameters, context, executemany) - - return result - -asserter = SQLAssert() - diff --git a/test/testlib/compat.py b/test/testlib/compat.py deleted file mode 100644 index 73eb2d651..000000000 --- a/test/testlib/compat.py +++ /dev/null @@ -1,19 +0,0 @@ -import types -import __builtin__ - -__all__ = '_function_named', 'callable' - - -def _function_named(fn, newname): - try: - fn.__name__ = newname - except: - fn = types.FunctionType(fn.func_code, fn.func_globals, newname, - fn.func_defaults, fn.func_closure) - return fn - -try: - callable = __builtin__.callable -except NameError: - def callable(fn): return hasattr(fn, '__call__') - diff --git a/test/testlib/config.py b/test/testlib/config.py deleted file mode 100644 index cef4c6e1d..000000000 --- a/test/testlib/config.py +++ /dev/null @@ -1,344 +0,0 @@ -import optparse, os, sys, re, ConfigParser, StringIO, time, warnings -logging, require = None, None - - -__all__ = 'parser', 'configure', 'options', - -db = None -db_label, db_url, db_opts = None, None, {} - -options = None -file_config = None -coverage_enabled = False - -base_config = """ -[db] -sqlite=sqlite:///:memory: -sqlite_file=sqlite:///querytest.db -postgres=postgres://scott:tiger@127.0.0.1:5432/test -mysql=mysql://scott:tiger@127.0.0.1:3306/test -oracle=oracle://scott:tiger@127.0.0.1:1521 -oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0 -mssql=mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test -firebird=firebird://sysdba:masterkey@localhost//tmp/test.fdb -maxdb=maxdb://MONA:RED@/maxdb1 -""" - -parser = optparse.OptionParser(usage = "usage: %prog [options] [tests...]") - -def configure(): - global options, config - global getopts_options, file_config - - file_config = ConfigParser.ConfigParser() - file_config.readfp(StringIO.StringIO(base_config)) - file_config.read(['test.cfg', os.path.expanduser('~/.satest.cfg')]) - - # Opt parsing can fire immediate actions, like logging and coverage - (options, args) = parser.parse_args() - sys.argv[1:] = args - - # Lazy setup of other options (post coverage) - for fn in post_configure: - fn(options, file_config) - - return options, file_config - -def configure_defaults(): - global options, config - global getopts_options, file_config - global db - - file_config = ConfigParser.ConfigParser() - file_config.readfp(StringIO.StringIO(base_config)) - file_config.read(['test.cfg', os.path.expanduser('~/.satest.cfg')]) - (options, args) = parser.parse_args([]) - - # make error messages raised by decorators that depend on a default - # database clearer. - class _engine_bomb(object): - def __getattr__(self, key): - raise RuntimeError('No default engine available, testlib ' - 'was configured with defaults only.') - - db = _engine_bomb() - import testlib.testing - testlib.testing.db = db - - return options, file_config - -def _log(option, opt_str, value, parser): - global logging - if not logging: - import logging - logging.basicConfig() - - if opt_str.endswith('-info'): - logging.getLogger(value).setLevel(logging.INFO) - elif opt_str.endswith('-debug'): - logging.getLogger(value).setLevel(logging.DEBUG) - -def _start_cumulative_coverage(option, opt_str, value, parser): - _start_coverage(option, opt_str, value, parser, erase=False) - -def _start_coverage(option, opt_str, value, parser, erase=True): - import sys, atexit, coverage - true_out = sys.stdout - - global coverage_enabled - coverage_enabled = True - - def _iter_covered_files(mod, recursive=True): - - if recursive: - ff = os.walk - else: - ff = os.listdir - - for rec in ff(os.path.dirname(mod.__file__)): - for x in rec[2]: - if x.endswith('.py'): - yield os.path.join(rec[0], x) - - def _stop(): - coverage.stop() - true_out.write("\nPreparing coverage report...\n") - - from sqlalchemy import sql, orm, engine, \ - ext, databases, log - - import sqlalchemy - - for modset in [ - _iter_covered_files(sqlalchemy, recursive=False), - _iter_covered_files(databases), - _iter_covered_files(engine), - _iter_covered_files(ext), - _iter_covered_files(orm), - ]: - coverage.report(list(modset), - show_missing=False, ignore_errors=False, - file=true_out) - atexit.register(_stop) - if erase: - coverage.erase() - coverage.start() - -def _list_dbs(*args): - print "Available --db options (use --dburi to override)" - for macro in sorted(file_config.options('db')): - print "%20s\t%s" % (macro, file_config.get('db', macro)) - sys.exit(0) - -def _server_side_cursors(options, opt_str, value, parser): - db_opts['server_side_cursors'] = True - -def _engine_strategy(options, opt_str, value, parser): - if value: - db_opts['strategy'] = value - -opt = parser.add_option -opt("--verbose", action="store_true", dest="verbose", - help="enable stdout echoing/printing") -opt("--quiet", action="store_true", dest="quiet", help="suppress output") -opt("--log-info", action="callback", type="string", callback=_log, - help="turn on info logging for <LOG> (multiple OK)") -opt("--log-debug", action="callback", type="string", callback=_log, - help="turn on debug logging for <LOG> (multiple OK)") -opt("--require", action="append", dest="require", default=[], - help="require a particular driver or module version (multiple OK)") -opt("--db", action="store", dest="db", default="sqlite", - help="Use prefab database uri") -opt('--dbs', action='callback', callback=_list_dbs, - help="List available prefab dbs") -opt("--dburi", action="store", dest="dburi", - help="Database uri (overrides --db)") -opt("--dropfirst", action="store_true", dest="dropfirst", - help="Drop all tables in the target database first (use with caution on Oracle, MS-SQL)") -opt("--mockpool", action="store_true", dest="mockpool", - help="Use mock pool (asserts only one connection used)") -opt("--enginestrategy", action="callback", type="string", - callback=_engine_strategy, - help="Engine strategy (plain or threadlocal, defaults to plain)") -opt("--reversetop", action="store_true", dest="reversetop", default=False, - help="Reverse the collection ordering for topological sorts (helps " - "reveal dependency issues)") -opt("--unhashable", action="store_true", dest="unhashable", default=False, - help="Disallow SQLAlchemy from performing a hash() on mapped test objects.") -opt("--noncomparable", action="store_true", dest="noncomparable", default=False, - help="Disallow SQLAlchemy from performing == on mapped test objects.") -opt("--truthless", action="store_true", dest="truthless", default=False, - help="Disallow SQLAlchemy from truth-evaluating mapped test objects.") -opt("--serverside", action="callback", callback=_server_side_cursors, - help="Turn on server side cursors for PG") -opt("--mysql-engine", action="store", dest="mysql_engine", default=None, - help="Use the specified MySQL storage engine for all tables, default is " - "a db-default/InnoDB combo.") -opt("--table-option", action="append", dest="tableopts", default=[], - help="Add a dialect-specific table option, key=value") -opt("--coverage", action="callback", callback=_start_coverage, - help="Dump a full coverage report after running tests") -opt("--cumulative-coverage", action="callback", callback=_start_cumulative_coverage, - help="Like --coverage, but accumlate coverage into the current DB") -opt("--profile", action="append", dest="profile_targets", default=[], - help="Enable a named profile target (multiple OK.)") -opt("--profile-sort", action="store", dest="profile_sort", default=None, - help="Sort profile stats with this comma-separated sort order") -opt("--profile-limit", type="int", action="store", dest="profile_limit", - default=None, - help="Limit function count in profile stats") - -class _ordered_map(object): - def __init__(self): - self._keys = list() - self._data = dict() - - def __setitem__(self, key, value): - if key not in self._keys: - self._keys.append(key) - self._data[key] = value - - def __iter__(self): - for key in self._keys: - yield self._data[key] - -# at one point in refactoring, modules were injecting into the config -# process. this could probably just become a list now. -post_configure = _ordered_map() - -def _engine_uri(options, file_config): - global db_label, db_url - db_label = 'sqlite' - if options.dburi: - db_url = options.dburi - db_label = db_url[:db_url.index(':')] - elif options.db: - db_label = options.db - db_url = None - - if db_url is None: - if db_label not in file_config.options('db'): - raise RuntimeError( - "Unknown engine. Specify --dbs for known engines.") - db_url = file_config.get('db', db_label) -post_configure['engine_uri'] = _engine_uri - -def _require(options, file_config): - if not(options.require or - (file_config.has_section('require') and - file_config.items('require'))): - return - - try: - import pkg_resources - except ImportError: - raise RuntimeError("setuptools is required for version requirements") - - cmdline = [] - for requirement in options.require: - pkg_resources.require(requirement) - cmdline.append(re.split('\s*(<!>=)', requirement, 1)[0]) - - if file_config.has_section('require'): - for label, requirement in file_config.items('require'): - if not label == db_label or label.startswith('%s.' % db_label): - continue - seen = [c for c in cmdline if requirement.startswith(c)] - if seen: - continue - pkg_resources.require(requirement) -post_configure['require'] = _require - -def _engine_pool(options, file_config): - if options.mockpool: - from sqlalchemy import pool - db_opts['poolclass'] = pool.AssertionPool -post_configure['engine_pool'] = _engine_pool - -def _create_testing_engine(options, file_config): - from testlib import engines, testing - global db - db = engines.testing_engine(db_url, db_opts) - testing.db = db -post_configure['create_engine'] = _create_testing_engine - -def _prep_testing_database(options, file_config): - from testlib import engines - from sqlalchemy import schema - - try: - # also create alt schemas etc. here? - if options.dropfirst: - e = engines.utf8_engine() - existing = e.table_names() - if existing: - if not options.quiet: - print "Dropping existing tables in database: " + db_url - try: - print "Tables: %s" % ', '.join(existing) - except: - pass - print "Abort within 5 seconds..." - time.sleep(5) - md = schema.MetaData(e, reflect=True) - md.drop_all() - e.dispose() - except (KeyboardInterrupt, SystemExit): - raise - except Exception, e: - if not options.quiet: - warnings.warn(RuntimeWarning( - "Error checking for existing tables in testing " - "database: %s" % e)) -post_configure['prep_db'] = _prep_testing_database - -def _set_table_options(options, file_config): - import testlib.schema - - table_options = testlib.schema.table_options - for spec in options.tableopts: - key, value = spec.split('=') - table_options[key] = value - - if options.mysql_engine: - table_options['mysql_engine'] = options.mysql_engine -post_configure['table_options'] = _set_table_options - -def _reverse_topological(options, file_config): - if options.reversetop: - from sqlalchemy.orm import unitofwork - from sqlalchemy import topological - class RevQueueDepSort(topological.QueueDependencySorter): - def __init__(self, tuples, allitems): - self.tuples = list(tuples) - self.allitems = list(allitems) - self.tuples.reverse() - self.allitems.reverse() - topological.QueueDependencySorter = RevQueueDepSort - unitofwork.DependencySorter = RevQueueDepSort -post_configure['topological'] = _reverse_topological - -def _set_profile_targets(options, file_config): - from testlib import profiling - - profile_config = profiling.profile_config - - for target in options.profile_targets: - profile_config['targets'].add(target) - - if options.profile_sort: - profile_config['sort'] = options.profile_sort.split(',') - - if options.profile_limit: - profile_config['limit'] = options.profile_limit - - if options.quiet: - profile_config['report'] = False - - # magic "all" target - if 'all' in profiling.all_targets: - targets = profile_config['targets'] - if 'all' in targets and len(targets) != 1: - targets.clear() - targets.add('all') -post_configure['profile_targets'] = _set_profile_targets diff --git a/test/testlib/coverage.py b/test/testlib/coverage.py deleted file mode 100644 index fc0f2c236..000000000 --- a/test/testlib/coverage.py +++ /dev/null @@ -1,1098 +0,0 @@ -#!/usr/bin/python -# -# Perforce Defect Tracking Integration Project -# <http://www.ravenbrook.com/project/p4dti/> -# -# COVERAGE.PY -- COVERAGE TESTING -# -# Gareth Rees, Ravenbrook Limited, 2001-12-04 -# Ned Batchelder, 2004-12-12 -# http://nedbatchelder.com/code/modules/coverage.html -# -# -# 1. INTRODUCTION -# -# This module provides coverage testing for Python code. -# -# The intended readership is all Python developers. -# -# This document is not confidential. -# -# See [GDR 2001-12-04a] for the command-line interface, programmatic -# interface and limitations. See [GDR 2001-12-04b] for requirements and -# design. - -r"""\ -Usage: - -coverage.py -x [-p] MODULE.py [ARG1 ARG2 ...] - Execute module, passing the given command-line arguments, collecting - coverage data. With the -p option, write to a temporary file containing - the machine name and process ID. - -coverage.py -e - Erase collected coverage data. - -coverage.py -c - Collect data from multiple coverage files (as created by -p option above) - and store it into a single file representing the union of the coverage. - -coverage.py -r [-m] [-o dir1,dir2,...] FILE1 FILE2 ... - Report on the statement coverage for the given files. With the -m - option, show line numbers of the statements that weren't executed. - -coverage.py -a [-d dir] [-o dir1,dir2,...] FILE1 FILE2 ... - Make annotated copies of the given files, marking statements that - are executed with > and statements that are missed with !. With - the -d option, make the copies in that directory. Without the -d - option, make each copy in the same directory as the original. - --o dir,dir2,... - Omit reporting or annotating files when their filename path starts with - a directory listed in the omit list. - e.g. python coverage.py -i -r -o c:\python23,lib\enthought\traits - -Coverage data is saved in the file .coverage by default. Set the -COVERAGE_FILE environment variable to save it somewhere else.""" - -__version__ = "2.75.20070722" # see detailed history at the end of this file. - -import compiler -import compiler.visitor -import glob -import os -import re -import string -import symbol -import sys -import threading -import token -import types -from socket import gethostname - - -# 2. IMPLEMENTATION -# -# This uses the "singleton" pattern. -# -# The word "morf" means a module object (from which the source file can -# be deduced by suitable manipulation of the __file__ attribute) or a -# filename. -# -# When we generate a coverage report we have to canonicalize every -# filename in the coverage dictionary just in case it refers to the -# module we are reporting on. It seems a shame to throw away this -# information so the data in the coverage dictionary is transferred to -# the 'cexecuted' dictionary under the canonical filenames. -# -# The coverage dictionary is called "c" and the trace function "t". The -# reason for these short names is that Python looks up variables by name -# at runtime and so execution time depends on the length of variables! -# In the bottleneck of this application it's appropriate to abbreviate -# names to increase speed. - -class StatementFindingAstVisitor(compiler.visitor.ASTVisitor): - """ A visitor for a parsed Abstract Syntax Tree which finds executable - statements. - """ - def __init__(self, statements, excluded, suite_spots): - compiler.visitor.ASTVisitor.__init__(self) - self.statements = statements - self.excluded = excluded - self.suite_spots = suite_spots - self.excluding_suite = 0 - - def doRecursive(self, node): - for n in node.getChildNodes(): - self.dispatch(n) - - visitStmt = visitModule = doRecursive - - def doCode(self, node): - if hasattr(node, 'decorators') and node.decorators: - self.dispatch(node.decorators) - self.recordAndDispatch(node.code) - else: - self.doSuite(node, node.code) - - visitFunction = visitClass = doCode - - def getFirstLine(self, node): - # Find the first line in the tree node. - lineno = node.lineno - for n in node.getChildNodes(): - f = self.getFirstLine(n) - if lineno and f: - lineno = min(lineno, f) - else: - lineno = lineno or f - return lineno - - def getLastLine(self, node): - # Find the first line in the tree node. - lineno = node.lineno - for n in node.getChildNodes(): - lineno = max(lineno, self.getLastLine(n)) - return lineno - - def doStatement(self, node): - self.recordLine(self.getFirstLine(node)) - - visitAssert = visitAssign = visitAssTuple = visitPrint = \ - visitPrintnl = visitRaise = visitSubscript = visitDecorators = \ - doStatement - - def visitPass(self, node): - # Pass statements have weird interactions with docstrings. If this - # pass statement is part of one of those pairs, claim that the statement - # is on the later of the two lines. - l = node.lineno - if l: - lines = self.suite_spots.get(l, [l,l]) - self.statements[lines[1]] = 1 - - def visitDiscard(self, node): - # Discard nodes are statements that execute an expression, but then - # discard the results. This includes function calls, so we can't - # ignore them all. But if the expression is a constant, the statement - # won't be "executed", so don't count it now. - if node.expr.__class__.__name__ != 'Const': - self.doStatement(node) - - def recordNodeLine(self, node): - # Stmt nodes often have None, but shouldn't claim the first line of - # their children (because the first child might be an ignorable line - # like "global a"). - if node.__class__.__name__ != 'Stmt': - return self.recordLine(self.getFirstLine(node)) - else: - return 0 - - def recordLine(self, lineno): - # Returns a bool, whether the line is included or excluded. - if lineno: - # Multi-line tests introducing suites have to get charged to their - # keyword. - if lineno in self.suite_spots: - lineno = self.suite_spots[lineno][0] - # If we're inside an excluded suite, record that this line was - # excluded. - if self.excluding_suite: - self.excluded[lineno] = 1 - return 0 - # If this line is excluded, or suite_spots maps this line to - # another line that is exlcuded, then we're excluded. - elif self.excluded.has_key(lineno) or \ - self.suite_spots.has_key(lineno) and \ - self.excluded.has_key(self.suite_spots[lineno][1]): - return 0 - # Otherwise, this is an executable line. - else: - self.statements[lineno] = 1 - return 1 - return 0 - - default = recordNodeLine - - def recordAndDispatch(self, node): - self.recordNodeLine(node) - self.dispatch(node) - - def doSuite(self, intro, body, exclude=0): - exsuite = self.excluding_suite - if exclude or (intro and not self.recordNodeLine(intro)): - self.excluding_suite = 1 - self.recordAndDispatch(body) - self.excluding_suite = exsuite - - def doPlainWordSuite(self, prevsuite, suite): - # Finding the exclude lines for else's is tricky, because they aren't - # present in the compiler parse tree. Look at the previous suite, - # and find its last line. If any line between there and the else's - # first line are excluded, then we exclude the else. - lastprev = self.getLastLine(prevsuite) - firstelse = self.getFirstLine(suite) - for l in range(lastprev+1, firstelse): - if self.suite_spots.has_key(l): - self.doSuite(None, suite, exclude=self.excluded.has_key(l)) - break - else: - self.doSuite(None, suite) - - def doElse(self, prevsuite, node): - if node.else_: - self.doPlainWordSuite(prevsuite, node.else_) - - def visitFor(self, node): - self.doSuite(node, node.body) - self.doElse(node.body, node) - - visitWhile = visitFor - - def visitIf(self, node): - # The first test has to be handled separately from the rest. - # The first test is credited to the line with the "if", but the others - # are credited to the line with the test for the elif. - self.doSuite(node, node.tests[0][1]) - for t, n in node.tests[1:]: - self.doSuite(t, n) - self.doElse(node.tests[-1][1], node) - - def visitTryExcept(self, node): - self.doSuite(node, node.body) - for i in range(len(node.handlers)): - a, b, h = node.handlers[i] - if not a: - # It's a plain "except:". Find the previous suite. - if i > 0: - prev = node.handlers[i-1][2] - else: - prev = node.body - self.doPlainWordSuite(prev, h) - else: - self.doSuite(a, h) - self.doElse(node.handlers[-1][2], node) - - def visitTryFinally(self, node): - self.doSuite(node, node.body) - self.doPlainWordSuite(node.body, node.final) - - def visitGlobal(self, node): - # "global" statements don't execute like others (they don't call the - # trace function), so don't record their line numbers. - pass - -the_coverage = None - -class CoverageException(Exception): pass - -class coverage: - # Name of the cache file (unless environment variable is set). - cache_default = ".coverage" - - # Environment variable naming the cache file. - cache_env = "COVERAGE_FILE" - - # A dictionary with an entry for (Python source file name, line number - # in that file) if that line has been executed. - c = {} - - # A map from canonical Python source file name to a dictionary in - # which there's an entry for each line number that has been - # executed. - cexecuted = {} - - # Cache of results of calling the analysis2() method, so that you can - # specify both -r and -a without doing double work. - analysis_cache = {} - - # Cache of results of calling the canonical_filename() method, to - # avoid duplicating work. - canonical_filename_cache = {} - - def __init__(self): - global the_coverage - if the_coverage: - raise CoverageException, "Only one coverage object allowed." - self.usecache = 1 - self.cache = None - self.parallel_mode = False - self.exclude_re = '' - self.nesting = 0 - self.cstack = [] - self.xstack = [] - self.relative_dir = os.path.normcase(os.path.abspath(os.curdir)+os.sep) - self.exclude('# *pragma[: ]*[nN][oO] *[cC][oO][vV][eE][rR]') - - # t(f, x, y). This method is passed to sys.settrace as a trace function. - # See [van Rossum 2001-07-20b, 9.2] for an explanation of sys.settrace and - # the arguments and return value of the trace function. - # See [van Rossum 2001-07-20a, 3.2] for a description of frame and code - # objects. - - def t(self, f, w, unused): #pragma: no cover - if w == 'line': - #print "Executing %s @ %d" % (f.f_code.co_filename, f.f_lineno) - self.c[(f.f_code.co_filename, f.f_lineno)] = 1 - for c in self.cstack: - c[(f.f_code.co_filename, f.f_lineno)] = 1 - return self.t - - def help(self, error=None): #pragma: no cover - if error: - print error - print - print __doc__ - sys.exit(1) - - def command_line(self, argv, help_fn=None): - import getopt - help_fn = help_fn or self.help - settings = {} - optmap = { - '-a': 'annotate', - '-c': 'collect', - '-d:': 'directory=', - '-e': 'erase', - '-h': 'help', - '-i': 'ignore-errors', - '-m': 'show-missing', - '-p': 'parallel-mode', - '-r': 'report', - '-x': 'execute', - '-o:': 'omit=', - } - short_opts = string.join(map(lambda o: o[1:], optmap.keys()), '') - long_opts = optmap.values() - options, args = getopt.getopt(argv, short_opts, long_opts) - for o, a in options: - if optmap.has_key(o): - settings[optmap[o]] = 1 - elif optmap.has_key(o + ':'): - settings[optmap[o + ':']] = a - elif o[2:] in long_opts: - settings[o[2:]] = 1 - elif o[2:] + '=' in long_opts: - settings[o[2:]+'='] = a - else: #pragma: no cover - pass # Can't get here, because getopt won't return anything unknown. - - if settings.get('help'): - help_fn() - - for i in ['erase', 'execute']: - for j in ['annotate', 'report', 'collect']: - if settings.get(i) and settings.get(j): - help_fn("You can't specify the '%s' and '%s' " - "options at the same time." % (i, j)) - - args_needed = (settings.get('execute') - or settings.get('annotate') - or settings.get('report')) - action = (settings.get('erase') - or settings.get('collect') - or args_needed) - if not action: - help_fn("You must specify at least one of -e, -x, -c, -r, or -a.") - if not args_needed and args: - help_fn("Unexpected arguments: %s" % " ".join(args)) - - self.parallel_mode = settings.get('parallel-mode') - self.get_ready() - - if settings.get('erase'): - self.erase() - if settings.get('execute'): - if not args: - help_fn("Nothing to do.") - sys.argv = args - self.start() - import __main__ - sys.path[0] = os.path.dirname(sys.argv[0]) - execfile(sys.argv[0], __main__.__dict__) - if settings.get('collect'): - self.collect() - if not args: - args = self.cexecuted.keys() - - ignore_errors = settings.get('ignore-errors') - show_missing = settings.get('show-missing') - directory = settings.get('directory=') - - omit = settings.get('omit=') - if omit is not None: - omit = omit.split(',') - else: - omit = [] - - if settings.get('report'): - self.report(args, show_missing, ignore_errors, omit_prefixes=omit) - if settings.get('annotate'): - self.annotate(args, directory, ignore_errors, omit_prefixes=omit) - - def use_cache(self, usecache, cache_file=None): - self.usecache = usecache - if cache_file and not self.cache: - self.cache_default = cache_file - - def get_ready(self, parallel_mode=False): - if self.usecache and not self.cache: - self.cache = os.environ.get(self.cache_env, self.cache_default) - if self.parallel_mode: - self.cache += "." + gethostname() + "." + str(os.getpid()) - self.restore() - self.analysis_cache = {} - - def start(self, parallel_mode=False): - self.get_ready() - if self.nesting == 0: #pragma: no cover - sys.settrace(self.t) - if hasattr(threading, 'settrace'): - threading.settrace(self.t) - self.nesting += 1 - - def stop(self): - self.nesting -= 1 - if self.nesting == 0: #pragma: no cover - sys.settrace(None) - if hasattr(threading, 'settrace'): - threading.settrace(None) - - def erase(self): - self.get_ready() - self.c = {} - self.analysis_cache = {} - self.cexecuted = {} - if self.cache and os.path.exists(self.cache): - os.remove(self.cache) - - def exclude(self, re): - if self.exclude_re: - self.exclude_re += "|" - self.exclude_re += "(" + re + ")" - - def begin_recursive(self): - self.cstack.append(self.c) - self.xstack.append(self.exclude_re) - - def end_recursive(self): - self.c = self.cstack.pop() - self.exclude_re = self.xstack.pop() - - # save(). Save coverage data to the coverage cache. - - def save(self): - if self.usecache and self.cache: - self.canonicalize_filenames() - cache = open(self.cache, 'wb') - import marshal - marshal.dump(self.cexecuted, cache) - cache.close() - - # restore(). Restore coverage data from the coverage cache (if it exists). - - def restore(self): - self.c = {} - self.cexecuted = {} - assert self.usecache - if os.path.exists(self.cache): - self.cexecuted = self.restore_file(self.cache) - - def restore_file(self, file_name): - try: - cache = open(file_name, 'rb') - import marshal - cexecuted = marshal.load(cache) - cache.close() - if isinstance(cexecuted, types.DictType): - return cexecuted - else: - return {} - except: - return {} - - # collect(). Collect data in multiple files produced by parallel mode - - def collect(self): - cache_dir, local = os.path.split(self.cache) - for f in os.listdir(cache_dir or '.'): - if not f.startswith(local): - continue - - full_path = os.path.join(cache_dir, f) - cexecuted = self.restore_file(full_path) - self.merge_data(cexecuted) - - def merge_data(self, new_data): - for file_name, file_data in new_data.items(): - if self.cexecuted.has_key(file_name): - self.merge_file_data(self.cexecuted[file_name], file_data) - else: - self.cexecuted[file_name] = file_data - - def merge_file_data(self, cache_data, new_data): - for line_number in new_data.keys(): - if not cache_data.has_key(line_number): - cache_data[line_number] = new_data[line_number] - - # canonical_filename(filename). Return a canonical filename for the - # file (that is, an absolute path with no redundant components and - # normalized case). See [GDR 2001-12-04b, 3.3]. - - def canonical_filename(self, filename): - if not self.canonical_filename_cache.has_key(filename): - f = filename - if os.path.isabs(f) and not os.path.exists(f): - f = os.path.basename(f) - if not os.path.isabs(f): - for path in [os.curdir] + sys.path: - g = os.path.join(path, f) - if os.path.exists(g): - f = g - break - cf = os.path.normcase(os.path.abspath(f)) - self.canonical_filename_cache[filename] = cf - return self.canonical_filename_cache[filename] - - # canonicalize_filenames(). Copy results from "c" to "cexecuted", - # canonicalizing filenames on the way. Clear the "c" map. - - def canonicalize_filenames(self): - for filename, lineno in self.c.keys(): - if filename == '<string>': - # Can't do anything useful with exec'd strings, so skip them. - continue - f = self.canonical_filename(filename) - if not self.cexecuted.has_key(f): - self.cexecuted[f] = {} - self.cexecuted[f][lineno] = 1 - self.c = {} - - # morf_filename(morf). Return the filename for a module or file. - - def morf_filename(self, morf): - if isinstance(morf, types.ModuleType): - if not hasattr(morf, '__file__'): - raise CoverageException, "Module has no __file__ attribute." - f = morf.__file__ - else: - f = morf - return self.canonical_filename(f) - - # analyze_morf(morf). Analyze the module or filename passed as - # the argument. If the source code can't be found, raise an error. - # Otherwise, return a tuple of (1) the canonical filename of the - # source code for the module, (2) a list of lines of statements - # in the source code, (3) a list of lines of excluded statements, - # and (4), a map of line numbers to multi-line line number ranges, for - # statements that cross lines. - - def analyze_morf(self, morf): - if self.analysis_cache.has_key(morf): - return self.analysis_cache[morf] - filename = self.morf_filename(morf) - ext = os.path.splitext(filename)[1] - if ext == '.pyc': - if not os.path.exists(filename[0:-1]): - raise CoverageException, ("No source for compiled code '%s'." - % filename) - filename = filename[0:-1] - elif ext != '.py': - raise CoverageException, "File '%s' not Python source." % filename - source = open(filename, 'r') - lines, excluded_lines, line_map = self.find_executable_statements( - source.read(), exclude=self.exclude_re - ) - source.close() - result = filename, lines, excluded_lines, line_map - self.analysis_cache[morf] = result - return result - - def first_line_of_tree(self, tree): - while True: - if len(tree) == 3 and type(tree[2]) == type(1): - return tree[2] - tree = tree[1] - - def last_line_of_tree(self, tree): - while True: - if len(tree) == 3 and type(tree[2]) == type(1): - return tree[2] - tree = tree[-1] - - def find_docstring_pass_pair(self, tree, spots): - for i in range(1, len(tree)): - if self.is_string_constant(tree[i]) and self.is_pass_stmt(tree[i+1]): - first_line = self.first_line_of_tree(tree[i]) - last_line = self.last_line_of_tree(tree[i+1]) - self.record_multiline(spots, first_line, last_line) - - def is_string_constant(self, tree): - try: - return tree[0] == symbol.stmt and tree[1][1][1][0] == symbol.expr_stmt - except: - return False - - def is_pass_stmt(self, tree): - try: - return tree[0] == symbol.stmt and tree[1][1][1][0] == symbol.pass_stmt - except: - return False - - def record_multiline(self, spots, i, j): - for l in range(i, j+1): - spots[l] = (i, j) - - def get_suite_spots(self, tree, spots): - """ Analyze a parse tree to find suite introducers which span a number - of lines. - """ - for i in range(1, len(tree)): - if type(tree[i]) == type(()): - if tree[i][0] == symbol.suite: - # Found a suite, look back for the colon and keyword. - lineno_colon = lineno_word = None - for j in range(i-1, 0, -1): - if tree[j][0] == token.COLON: - # Colons are never executed themselves: we want the - # line number of the last token before the colon. - lineno_colon = self.last_line_of_tree(tree[j-1]) - elif tree[j][0] == token.NAME: - if tree[j][1] == 'elif': - # Find the line number of the first non-terminal - # after the keyword. - t = tree[j+1] - while t and token.ISNONTERMINAL(t[0]): - t = t[1] - if t: - lineno_word = t[2] - else: - lineno_word = tree[j][2] - break - elif tree[j][0] == symbol.except_clause: - # "except" clauses look like: - # ('except_clause', ('NAME', 'except', lineno), ...) - if tree[j][1][0] == token.NAME: - lineno_word = tree[j][1][2] - break - if lineno_colon and lineno_word: - # Found colon and keyword, mark all the lines - # between the two with the two line numbers. - self.record_multiline(spots, lineno_word, lineno_colon) - - # "pass" statements are tricky: different versions of Python - # treat them differently, especially in the common case of a - # function with a doc string and a single pass statement. - self.find_docstring_pass_pair(tree[i], spots) - - elif tree[i][0] == symbol.simple_stmt: - first_line = self.first_line_of_tree(tree[i]) - last_line = self.last_line_of_tree(tree[i]) - if first_line != last_line: - self.record_multiline(spots, first_line, last_line) - self.get_suite_spots(tree[i], spots) - - def find_executable_statements(self, text, exclude=None): - # Find lines which match an exclusion pattern. - excluded = {} - suite_spots = {} - if exclude: - reExclude = re.compile(exclude) - lines = text.split('\n') - for i in range(len(lines)): - if reExclude.search(lines[i]): - excluded[i+1] = 1 - - # Parse the code and analyze the parse tree to find out which statements - # are multiline, and where suites begin and end. - import parser - tree = parser.suite(text+'\n\n').totuple(1) - self.get_suite_spots(tree, suite_spots) - #print "Suite spots:", suite_spots - - # Use the compiler module to parse the text and find the executable - # statements. We add newlines to be impervious to final partial lines. - statements = {} - ast = compiler.parse(text+'\n\n') - visitor = StatementFindingAstVisitor(statements, excluded, suite_spots) - compiler.walk(ast, visitor, walker=visitor) - - lines = statements.keys() - lines.sort() - excluded_lines = excluded.keys() - excluded_lines.sort() - return lines, excluded_lines, suite_spots - - # format_lines(statements, lines). Format a list of line numbers - # for printing by coalescing groups of lines as long as the lines - # represent consecutive statements. This will coalesce even if - # there are gaps between statements, so if statements = - # [1,2,3,4,5,10,11,12,13,14] and lines = [1,2,5,10,11,13,14] then - # format_lines will return "1-2, 5-11, 13-14". - - def format_lines(self, statements, lines): - pairs = [] - i = 0 - j = 0 - start = None - pairs = [] - while i < len(statements) and j < len(lines): - if statements[i] == lines[j]: - if start == None: - start = lines[j] - end = lines[j] - j = j + 1 - elif start: - pairs.append((start, end)) - start = None - i = i + 1 - if start: - pairs.append((start, end)) - def stringify(pair): - start, end = pair - if start == end: - return "%d" % start - else: - return "%d-%d" % (start, end) - ret = string.join(map(stringify, pairs), ", ") - return ret - - # Backward compatibility with version 1. - def analysis(self, morf): - f, s, _, m, mf = self.analysis2(morf) - return f, s, m, mf - - def analysis2(self, morf): - filename, statements, excluded, line_map = self.analyze_morf(morf) - self.canonicalize_filenames() - if not self.cexecuted.has_key(filename): - self.cexecuted[filename] = {} - missing = [] - for line in statements: - lines = line_map.get(line, [line, line]) - for l in range(lines[0], lines[1]+1): - if self.cexecuted[filename].has_key(l): - break - else: - missing.append(line) - return (filename, statements, excluded, missing, - self.format_lines(statements, missing)) - - def relative_filename(self, filename): - """ Convert filename to relative filename from self.relative_dir. - """ - return filename.replace(self.relative_dir, "") - - def morf_name(self, morf): - """ Return the name of morf as used in report. - """ - if isinstance(morf, types.ModuleType): - return morf.__name__ - else: - return self.relative_filename(os.path.splitext(morf)[0]) - - def filter_by_prefix(self, morfs, omit_prefixes): - """ Return list of morfs where the morf name does not begin - with any one of the omit_prefixes. - """ - filtered_morfs = [] - for morf in morfs: - for prefix in omit_prefixes: - if self.morf_name(morf).startswith(prefix): - break - else: - filtered_morfs.append(morf) - - return filtered_morfs - - def morf_name_compare(self, x, y): - return cmp(self.morf_name(x), self.morf_name(y)) - - def report(self, morfs, show_missing=1, ignore_errors=0, file=None, omit_prefixes=[]): - if not isinstance(morfs, types.ListType): - morfs = [morfs] - # On windows, the shell doesn't expand wildcards. Do it here. - globbed = [] - for morf in morfs: - if isinstance(morf, basestring): - globbed.extend(glob.glob(morf)) - else: - globbed.append(morf) - morfs = globbed - - morfs = self.filter_by_prefix(morfs, omit_prefixes) - morfs.sort(self.morf_name_compare) - - max_name = max([5,] + map(len, map(self.morf_name, morfs))) - fmt_name = "%%- %ds " % max_name - fmt_err = fmt_name + "%s: %s" - header = fmt_name % "Name" + " Stmts Exec Cover" - fmt_coverage = fmt_name + "% 6d % 6d % 5d%%" - if show_missing: - header = header + " Missing" - fmt_coverage = fmt_coverage + " %s" - if not file: - file = sys.stdout - print >>file, header - print >>file, "-" * len(header) - total_statements = 0 - total_executed = 0 - for morf in morfs: - name = self.morf_name(morf) - try: - _, statements, _, missing, readable = self.analysis2(morf) - n = len(statements) - m = n - len(missing) - if n > 0: - pc = 100.0 * m / n - else: - pc = 100.0 - args = (name, n, m, pc) - if show_missing: - args = args + (readable,) - print >>file, fmt_coverage % args - total_statements = total_statements + n - total_executed = total_executed + m - except KeyboardInterrupt: #pragma: no cover - raise - except: - if not ignore_errors: - typ, msg = sys.exc_info()[0:2] - print >>file, fmt_err % (name, typ, msg) - if len(morfs) > 1: - print >>file, "-" * len(header) - if total_statements > 0: - pc = 100.0 * total_executed / total_statements - else: - pc = 100.0 - args = ("TOTAL", total_statements, total_executed, pc) - if show_missing: - args = args + ("",) - print >>file, fmt_coverage % args - - # annotate(morfs, ignore_errors). - - blank_re = re.compile(r"\s*(#|$)") - else_re = re.compile(r"\s*else\s*:\s*(#|$)") - - def annotate(self, morfs, directory=None, ignore_errors=0, omit_prefixes=[]): - morfs = self.filter_by_prefix(morfs, omit_prefixes) - for morf in morfs: - try: - filename, statements, excluded, missing, _ = self.analysis2(morf) - self.annotate_file(filename, statements, excluded, missing, directory) - except KeyboardInterrupt: - raise - except: - if not ignore_errors: - raise - - def annotate_file(self, filename, statements, excluded, missing, directory=None): - source = open(filename, 'r') - if directory: - dest_file = os.path.join(directory, - os.path.basename(filename) - + ',cover') - else: - dest_file = filename + ',cover' - dest = open(dest_file, 'w') - lineno = 0 - i = 0 - j = 0 - covered = 1 - while 1: - line = source.readline() - if line == '': - break - lineno = lineno + 1 - while i < len(statements) and statements[i] < lineno: - i = i + 1 - while j < len(missing) and missing[j] < lineno: - j = j + 1 - if i < len(statements) and statements[i] == lineno: - covered = j >= len(missing) or missing[j] > lineno - if self.blank_re.match(line): - dest.write(' ') - elif self.else_re.match(line): - # Special logic for lines containing only 'else:'. - # See [GDR 2001-12-04b, 3.2]. - if i >= len(statements) and j >= len(missing): - dest.write('! ') - elif i >= len(statements) or j >= len(missing): - dest.write('> ') - elif statements[i] == missing[j]: - dest.write('! ') - else: - dest.write('> ') - elif lineno in excluded: - dest.write('- ') - elif covered: - dest.write('> ') - else: - dest.write('! ') - dest.write(line) - source.close() - dest.close() - -# Singleton object. -the_coverage = coverage() - -# Module functions call methods in the singleton object. -def use_cache(*args, **kw): - return the_coverage.use_cache(*args, **kw) - -def start(*args, **kw): - return the_coverage.start(*args, **kw) - -def stop(*args, **kw): - return the_coverage.stop(*args, **kw) - -def erase(*args, **kw): - return the_coverage.erase(*args, **kw) - -def begin_recursive(*args, **kw): - return the_coverage.begin_recursive(*args, **kw) - -def end_recursive(*args, **kw): - return the_coverage.end_recursive(*args, **kw) - -def exclude(*args, **kw): - return the_coverage.exclude(*args, **kw) - -def analysis(*args, **kw): - return the_coverage.analysis(*args, **kw) - -def analysis2(*args, **kw): - return the_coverage.analysis2(*args, **kw) - -def report(*args, **kw): - return the_coverage.report(*args, **kw) - -def annotate(*args, **kw): - return the_coverage.annotate(*args, **kw) - -def annotate_file(*args, **kw): - return the_coverage.annotate_file(*args, **kw) - -# Save coverage data when Python exits. (The atexit module wasn't -# introduced until Python 2.0, so use sys.exitfunc when it's not -# available.) -try: - import atexit - atexit.register(the_coverage.save) -except ImportError: - sys.exitfunc = the_coverage.save - -# Command-line interface. -if __name__ == '__main__': - the_coverage.command_line(sys.argv[1:]) - - -# A. REFERENCES -# -# [GDR 2001-12-04a] "Statement coverage for Python"; Gareth Rees; -# Ravenbrook Limited; 2001-12-04; -# <http://www.nedbatchelder.com/code/modules/rees-coverage.html>. -# -# [GDR 2001-12-04b] "Statement coverage for Python: design and -# analysis"; Gareth Rees; Ravenbrook Limited; 2001-12-04; -# <http://www.nedbatchelder.com/code/modules/rees-design.html>. -# -# [van Rossum 2001-07-20a] "Python Reference Manual (releae 2.1.1)"; -# Guide van Rossum; 2001-07-20; -# <http://www.python.org/doc/2.1.1/ref/ref.html>. -# -# [van Rossum 2001-07-20b] "Python Library Reference"; Guido van Rossum; -# 2001-07-20; <http://www.python.org/doc/2.1.1/lib/lib.html>. -# -# -# B. DOCUMENT HISTORY -# -# 2001-12-04 GDR Created. -# -# 2001-12-06 GDR Added command-line interface and source code -# annotation. -# -# 2001-12-09 GDR Moved design and interface to separate documents. -# -# 2001-12-10 GDR Open cache file as binary on Windows. Allow -# simultaneous -e and -x, or -a and -r. -# -# 2001-12-12 GDR Added command-line help. Cache analysis so that it -# only needs to be done once when you specify -a and -r. -# -# 2001-12-13 GDR Improved speed while recording. Portable between -# Python 1.5.2 and 2.1.1. -# -# 2002-01-03 GDR Module-level functions work correctly. -# -# 2002-01-07 GDR Update sys.path when running a file with the -x option, -# so that it matches the value the program would get if it were run on -# its own. -# -# 2004-12-12 NMB Significant code changes. -# - Finding executable statements has been rewritten so that docstrings and -# other quirks of Python execution aren't mistakenly identified as missing -# lines. -# - Lines can be excluded from consideration, even entire suites of lines. -# - The filesystem cache of covered lines can be disabled programmatically. -# - Modernized the code. -# -# 2004-12-14 NMB Minor tweaks. Return 'analysis' to its original behavior -# and add 'analysis2'. Add a global for 'annotate', and factor it, adding -# 'annotate_file'. -# -# 2004-12-31 NMB Allow for keyword arguments in the module global functions. -# Thanks, Allen. -# -# 2005-12-02 NMB Call threading.settrace so that all threads are measured. -# Thanks Martin Fuzzey. Add a file argument to report so that reports can be -# captured to a different destination. -# -# 2005-12-03 NMB coverage.py can now measure itself. -# -# 2005-12-04 NMB Adapted Greg Rogers' patch for using relative filenames, -# and sorting and omitting files to report on. -# -# 2006-07-23 NMB Applied Joseph Tate's patch for function decorators. -# -# 2006-08-21 NMB Applied Sigve Tjora and Mark van der Wal's fixes for argument -# handling. -# -# 2006-08-22 NMB Applied Geoff Bache's parallel mode patch. -# -# 2006-08-23 NMB Refactorings to improve testability. Fixes to command-line -# logic for parallel mode and collect. -# -# 2006-08-25 NMB "#pragma: nocover" is excluded by default. -# -# 2006-09-10 NMB Properly ignore docstrings and other constant expressions that -# appear in the middle of a function, a problem reported by Tim Leslie. -# Minor changes to avoid lint warnings. -# -# 2006-09-17 NMB coverage.erase() shouldn't clobber the exclude regex. -# Change how parallel mode is invoked, and fix erase() so that it erases the -# cache when called programmatically. -# -# 2007-07-21 NMB In reports, ignore code executed from strings, since we can't -# do anything useful with it anyway. -# Better file handling on Linux, thanks Guillaume Chazarain. -# Better shell support on Windows, thanks Noel O'Boyle. -# Python 2.2 support maintained, thanks Catherine Proulx. -# -# 2007-07-22 NMB Python 2.5 now fully supported. The method of dealing with -# multi-line statements is now less sensitive to the exact line that Python -# reports during execution. Pass statements are handled specially so that their -# disappearance during execution won't throw off the measurement. - -# C. COPYRIGHT AND LICENCE -# -# Copyright 2001 Gareth Rees. All rights reserved. -# Copyright 2004-2007 Ned Batchelder. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# -# 1. Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the -# distribution. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# HOLDERS AND CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, -# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, -# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS -# OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND -# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR -# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE -# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH -# DAMAGE. -# -# $Id: coverage.py 67 2007-07-21 19:51:07Z nedbat $ diff --git a/test/testlib/engines.py b/test/testlib/engines.py deleted file mode 100644 index 4068f43d0..000000000 --- a/test/testlib/engines.py +++ /dev/null @@ -1,245 +0,0 @@ -import sys, types, weakref -from collections import deque -from testlib import config -from testlib.compat import _function_named, callable - -class ConnectionKiller(object): - def __init__(self): - self.proxy_refs = weakref.WeakKeyDictionary() - - def checkout(self, dbapi_con, con_record, con_proxy): - self.proxy_refs[con_proxy] = True - - def _apply_all(self, methods): - for rec in self.proxy_refs: - if rec is not None and rec.is_valid: - try: - for name in methods: - if callable(name): - name(rec) - else: - getattr(rec, name)() - except (SystemExit, KeyboardInterrupt): - raise - except Exception, e: - # fixme - sys.stderr.write("\n" + str(e) + "\n") - - def rollback_all(self): - self._apply_all(('rollback',)) - - def close_all(self): - self._apply_all(('rollback', 'close')) - - def assert_all_closed(self): - for rec in self.proxy_refs: - if rec.is_valid: - assert False - -testing_reaper = ConnectionKiller() - -def assert_conns_closed(fn): - def decorated(*args, **kw): - try: - fn(*args, **kw) - finally: - testing_reaper.assert_all_closed() - return _function_named(decorated, fn.__name__) - -def rollback_open_connections(fn): - """Decorator that rolls back all open connections after fn execution.""" - - def decorated(*args, **kw): - try: - fn(*args, **kw) - finally: - testing_reaper.rollback_all() - return _function_named(decorated, fn.__name__) - -def close_open_connections(fn): - """Decorator that closes all connections after fn execution.""" - - def decorated(*args, **kw): - try: - fn(*args, **kw) - finally: - testing_reaper.close_all() - return _function_named(decorated, fn.__name__) - -def all_dialects(): - import sqlalchemy.databases as d - for name in d.__all__: - mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name) - yield mod.dialect() - -class ReconnectFixture(object): - def __init__(self, dbapi): - self.dbapi = dbapi - self.connections = [] - - def __getattr__(self, key): - return getattr(self.dbapi, key) - - def connect(self, *args, **kwargs): - conn = self.dbapi.connect(*args, **kwargs) - self.connections.append(conn) - return conn - - def shutdown(self): - for c in list(self.connections): - c.close() - self.connections = [] - -def reconnecting_engine(url=None, options=None): - url = url or config.db_url - dbapi = config.db.dialect.dbapi - if not options: - options = {} - options['module'] = ReconnectFixture(dbapi) - engine = testing_engine(url, options) - engine.test_shutdown = engine.dialect.dbapi.shutdown - return engine - -def testing_engine(url=None, options=None): - """Produce an engine configured by --options with optional overrides.""" - - from sqlalchemy import create_engine - from testlib.assertsql import asserter - - url = url or config.db_url - options = options or config.db_opts - - options.setdefault('proxy', asserter) - - listeners = options.setdefault('listeners', []) - listeners.append(testing_reaper) - - engine = create_engine(url, **options) - - return engine - -def utf8_engine(url=None, options=None): - """Hook for dialects or drivers that don't handle utf8 by default.""" - - from sqlalchemy.engine import url as engine_url - - if config.db.name == 'mysql': - dbapi_ver = config.db.dialect.dbapi.version_info - if (dbapi_ver < (1, 2, 1) or - dbapi_ver in ((1, 2, 1, 'gamma', 1), (1, 2, 1, 'gamma', 2), - (1, 2, 1, 'gamma', 3), (1, 2, 1, 'gamma', 5))): - raise RuntimeError('Character set support unavailable with this ' - 'driver version: %s' % repr(dbapi_ver)) - else: - url = url or config.db_url - url = engine_url.make_url(url) - url.query['charset'] = 'utf8' - url.query['use_unicode'] = '0' - url = str(url) - - return testing_engine(url, options) - -def mock_engine(db=None): - """Provides a mocking engine based on the current testing.db.""" - - from sqlalchemy import create_engine - - dbi = db or config.db - buffer = [] - def executor(sql, *a, **kw): - buffer.append(sql) - engine = create_engine(dbi.name + '://', - strategy='mock', executor=executor) - assert not hasattr(engine, 'mock') - engine.mock = buffer - return engine - -class ReplayableSession(object): - """A simple record/playback tool. - - This is *not* a mock testing class. It only records a session for later - playback and makes no assertions on call consistency whatsoever. It's - unlikely to be suitable for anything other than DB-API recording. - - """ - - Callable = object() - NoAttribute = object() - Natives = set([getattr(types, t) - for t in dir(types) if not t.startswith('_')]). \ - difference([getattr(types, t) - for t in ('FunctionType', 'BuiltinFunctionType', - 'MethodType', 'BuiltinMethodType', - 'LambdaType', 'UnboundMethodType',)]) - def __init__(self): - self.buffer = deque() - - def recorder(self, base): - return self.Recorder(self.buffer, base) - - def player(self): - return self.Player(self.buffer) - - class Recorder(object): - def __init__(self, buffer, subject): - self._buffer = buffer - self._subject = subject - - def __call__(self, *args, **kw): - subject, buffer = [object.__getattribute__(self, x) - for x in ('_subject', '_buffer')] - - result = subject(*args, **kw) - if type(result) not in ReplayableSession.Natives: - buffer.append(ReplayableSession.Callable) - return type(self)(buffer, result) - else: - buffer.append(result) - return result - - def __getattribute__(self, key): - try: - return object.__getattribute__(self, key) - except AttributeError: - pass - - subject, buffer = [object.__getattribute__(self, x) - for x in ('_subject', '_buffer')] - try: - result = type(subject).__getattribute__(subject, key) - except AttributeError: - buffer.append(ReplayableSession.NoAttribute) - raise - else: - if type(result) not in ReplayableSession.Natives: - buffer.append(ReplayableSession.Callable) - return type(self)(buffer, result) - else: - buffer.append(result) - return result - - class Player(object): - def __init__(self, buffer): - self._buffer = buffer - - def __call__(self, *args, **kw): - buffer = object.__getattribute__(self, '_buffer') - result = buffer.popleft() - if result is ReplayableSession.Callable: - return self - else: - return result - - def __getattribute__(self, key): - try: - return object.__getattribute__(self, key) - except AttributeError: - pass - buffer = object.__getattribute__(self, '_buffer') - result = buffer.popleft() - if result is ReplayableSession.Callable: - return self - elif result is ReplayableSession.NoAttribute: - raise AttributeError(key) - else: - return result diff --git a/test/testlib/orm.py b/test/testlib/orm.py deleted file mode 100644 index 22d624601..000000000 --- a/test/testlib/orm.py +++ /dev/null @@ -1,117 +0,0 @@ -import inspect, re -from testlib import config, testing - -sa = None -orm = None - -__all__ = 'mapper', - - -_whitespace = re.compile(r'^(\s+)') - -def _find_pragma(lines, current): - m = _whitespace.match(lines[current]) - basis = m and m.group() or '' - - for line in reversed(lines[0:current]): - if 'testlib.pragma' in line: - return line - m = _whitespace.match(line) - indent = m and m.group() or '' - - # simplistic detection: - - # >> # testlib.pragma foo - # >> center_line() - if indent == basis: - break - # >> # testlib.pragma foo - # >> if fleem: - # >> center_line() - if line.endswith(':'): - break - return None - -def _make_blocker(method_name, fallback): - """Creates tripwired variant of a method, raising when called. - - To excempt an invocation from blockage, there are two options. - - 1) add a pragma in a comment:: - - # testlib.pragma exempt:methodname - offending_line() - - 2) add a magic cookie to the function's namespace:: - __sa_baremethodname_exempt__ = True - ... - offending_line() - another_offending_lines() - - The second is useful for testing and development. - """ - - if method_name.startswith('__') and method_name.endswith('__'): - frame_marker = '__sa_%s_exempt__' % method_name[2:-2] - else: - frame_marker = '__sa_%s_exempt__' % method_name - pragma_marker = 'exempt:' + method_name - - def method(self, *args, **kw): - frame_r = None - try: - frame = inspect.stack()[1][0] - frame_r = inspect.getframeinfo(frame, 9) - - module = frame.f_globals.get('__name__', '') - - type_ = type(self) - - pragma = _find_pragma(*frame_r[3:5]) - - exempt = ( - (not module.startswith('sqlalchemy')) or - (pragma and pragma_marker in pragma) or - (frame_marker in frame.f_locals) or - ('self' in frame.f_locals and - getattr(frame.f_locals['self'], frame_marker, False))) - - if exempt: - supermeth = getattr(super(type_, self), method_name, None) - if (supermeth is None or - getattr(supermeth, 'im_func', None) is method): - return fallback(self, *args, **kw) - else: - return supermeth(*args, **kw) - else: - raise AssertionError( - "%s.%s called in %s, line %s in %s" % ( - type_.__name__, method_name, module, frame_r[1], frame_r[2])) - finally: - del frame - method.__name__ = method_name - return method - -def mapper(type_, *args, **kw): - global orm - if orm is None: - from sqlalchemy import orm - - forbidden = [ - ('__hash__', 'unhashable', lambda s: id(s)), - ('__eq__', 'noncomparable', lambda s, o: s is o), - ('__ne__', 'noncomparable', lambda s, o: s is not o), - ('__cmp__', 'noncomparable', lambda s, o: object.__cmp__(s, o)), - ('__le__', 'noncomparable', lambda s, o: object.__le__(s, o)), - ('__lt__', 'noncomparable', lambda s, o: object.__lt__(s, o)), - ('__ge__', 'noncomparable', lambda s, o: object.__ge__(s, o)), - ('__gt__', 'noncomparable', lambda s, o: object.__gt__(s, o)), - ('__nonzero__', 'truthless', lambda s: 1), ] - - if isinstance(type_, type) and type_.__bases__ == (object,): - for method_name, option, fallback in forbidden: - if (getattr(config.options, option, False) and - method_name not in type_.__dict__): - setattr(type_, method_name, _make_blocker(method_name, fallback)) - - return orm.mapper(type_, *args, **kw) diff --git a/test/testlib/profiling.py b/test/testlib/profiling.py deleted file mode 100644 index 89db33011..000000000 --- a/test/testlib/profiling.py +++ /dev/null @@ -1,207 +0,0 @@ -"""Profiling support for unit and performance tests.""" - -import os, sys -from testlib.compat import _function_named -import testlib.config - -__all__ = 'profiled', 'function_call_count', 'conditional_call_count' - -all_targets = set() -profile_config = { 'targets': set(), - 'report': True, - 'sort': ('time', 'calls'), - 'limit': None } -profiler = None - -def profiled(target=None, **target_opts): - """Optional function profiling. - - @profiled('label') - or - @profiled('label', report=True, sort=('calls',), limit=20) - - Enables profiling for a function when 'label' is targetted for - profiling. Report options can be supplied, and override the global - configuration and command-line options. - """ - - # manual or automatic namespacing by module would remove conflict issues - if target is None: - target = 'anonymous_target' - elif target in all_targets: - print "Warning: redefining profile target '%s'" % target - all_targets.add(target) - - filename = "%s.prof" % target - - def decorator(fn): - def profiled(*args, **kw): - if (target not in profile_config['targets'] and - not target_opts.get('always', None)): - return fn(*args, **kw) - - elapsed, load_stats, result = _profile( - filename, fn, *args, **kw) - - if not testlib.config.options.quiet: - print "Profiled target '%s', wall time: %.2f seconds" % ( - target, elapsed) - - report = target_opts.get('report', profile_config['report']) - if report and testlib.config.options.verbose: - sort_ = target_opts.get('sort', profile_config['sort']) - limit = target_opts.get('limit', profile_config['limit']) - print "Profile report for target '%s' (%s)" % ( - target, filename) - - stats = load_stats() - stats.sort_stats(*sort_) - if limit: - stats.print_stats(limit) - else: - stats.print_stats() - #stats.print_callers() - os.unlink(filename) - return result - return _function_named(profiled, fn.__name__) - return decorator - -def function_call_count(count=None, versions={}, variance=0.05): - """Assert a target for a test case's function call count. - - count - Optional, general target function call count. - - versions - Optional, a dictionary of Python version strings to counts, - for example:: - - { '2.5.1': 110, - '2.5': 100, - '2.4': 150 } - - The best match for the current running python will be used. - If none match, 'count' will be used as the fallback. - - variance - An +/- deviation percentage, defaults to 5%. - """ - - # this could easily dump the profile report if --verbose is in effect - - version_info = list(sys.version_info) - py_version = '.'.join([str(v) for v in sys.version_info]) - - while version_info: - version = '.'.join([str(v) for v in version_info]) - if version in versions: - count = versions[version] - break - version_info.pop() - - if count is None: - return lambda fn: fn - - def decorator(fn): - def counted(*args, **kw): - try: - filename = "%s.prof" % fn.__name__ - - elapsed, stat_loader, result = _profile( - filename, fn, *args, **kw) - - stats = stat_loader() - calls = stats.total_calls - - if testlib.config.options.verbose: - stats.sort_stats('calls', 'cumulative') - stats.print_stats() - #stats.print_callers() - deviance = int(count * variance) - if (calls < (count - deviance) or - calls > (count + deviance)): - raise AssertionError( - "Function call count %s not within %s%% " - "of expected %s. (Python version %s)" % ( - calls, (variance * 100), count, py_version)) - - return result - finally: - if os.path.exists(filename): - os.unlink(filename) - return _function_named(counted, fn.__name__) - return decorator - -def conditional_call_count(discriminator, categories): - """Apply a function call count conditionally at runtime. - - Takes two arguments, a callable that returns a key value, and a dict - mapping key values to a tuple of arguments to function_call_count. - - The callable is not evaluated until the decorated function is actually - invoked. If the `discriminator` returns a key not present in the - `categories` dictionary, no call count assertion is applied. - - Useful for integration tests, where running a named test in isolation may - have a function count penalty not seen in the full suite, due to lazy - initialization in the DB-API, SA, etc. - """ - - def decorator(fn): - def at_runtime(*args, **kw): - criteria = categories.get(discriminator(), None) - if criteria is None: - return fn(*args, **kw) - - rewrapped = function_call_count(*criteria)(fn) - return rewrapped(*args, **kw) - return _function_named(at_runtime, fn.__name__) - return decorator - - -def _profile(filename, fn, *args, **kw): - global profiler - if not profiler: - profiler = 'hotshot' - if sys.version_info > (2, 5): - try: - import cProfile - profiler = 'cProfile' - except ImportError: - pass - - if profiler == 'cProfile': - return _profile_cProfile(filename, fn, *args, **kw) - else: - return _profile_hotshot(filename, fn, *args, **kw) - -def _profile_cProfile(filename, fn, *args, **kw): - import cProfile, gc, pstats, time - - load_stats = lambda: pstats.Stats(filename) - gc.collect() - - began = time.time() - cProfile.runctx('result = fn(*args, **kw)', globals(), locals(), - filename=filename) - ended = time.time() - - return ended - began, load_stats, locals()['result'] - -def _profile_hotshot(filename, fn, *args, **kw): - import gc, hotshot, hotshot.stats, time - load_stats = lambda: hotshot.stats.load(filename) - - gc.collect() - prof = hotshot.Profile(filename) - began = time.time() - prof.start() - try: - result = fn(*args, **kw) - finally: - prof.stop() - ended = time.time() - prof.close() - - return ended - began, load_stats, result - diff --git a/test/testlib/requires.py b/test/testlib/requires.py deleted file mode 100644 index b20929a83..000000000 --- a/test/testlib/requires.py +++ /dev/null @@ -1,127 +0,0 @@ -"""Global database feature support policy. - -Provides decorators to mark tests requiring specific feature support from the -target database. - -""" - -from testlib.testing import \ - _block_unconditionally as no_support, \ - _chain_decorators_on, \ - exclude, \ - emits_warning_on - - -def deferrable_constraints(fn): - """Target database must support derferable constraints.""" - return _chain_decorators_on( - fn, - no_support('firebird', 'not supported by database'), - no_support('mysql', 'not supported by database'), - no_support('mssql', 'not supported by database'), - ) - -def foreign_keys(fn): - """Target database must support foreign keys.""" - return _chain_decorators_on( - fn, - no_support('sqlite', 'not supported by database'), - ) - -def identity(fn): - """Target database must support GENERATED AS IDENTITY or a facsimile. - - Includes GENERATED AS IDENTITY, AUTOINCREMENT, AUTO_INCREMENT, or other - column DDL feature that fills in a DB-generated identifier at INSERT-time - without requiring pre-execution of a SEQUENCE or other artifact. - - """ - return _chain_decorators_on( - fn, - no_support('firebird', 'not supported by database'), - no_support('oracle', 'not supported by database'), - no_support('postgres', 'not supported by database'), - no_support('sybase', 'not supported by database'), - ) - -def independent_connections(fn): - """Target must support simultaneous, independent database connections.""" - - # This is also true of some configurations of UnixODBC and probably win32 - # ODBC as well. - return _chain_decorators_on( - fn, - no_support('sqlite', 'no driver support') - ) - -def row_triggers(fn): - """Target must support standard statement-running EACH ROW triggers.""" - return _chain_decorators_on( - fn, - # no access to same table - no_support('mysql', 'requires SUPER priv'), - exclude('mysql', '<', (5, 0, 10), 'not supported by database'), - no_support('postgres', 'not supported by database: no statements'), - ) - -def savepoints(fn): - """Target database must support savepoints.""" - return _chain_decorators_on( - fn, - emits_warning_on('mssql', 'Savepoint support in mssql is experimental and may lead to data loss.'), - no_support('access', 'not supported by database'), - no_support('sqlite', 'not supported by database'), - no_support('sybase', 'FIXME: guessing, needs confirmation'), - exclude('mysql', '<', (5, 0, 3), 'not supported by database'), - ) - -def sequences(fn): - """Target database must support SEQUENCEs.""" - return _chain_decorators_on( - fn, - no_support('access', 'no SEQUENCE support'), - no_support('mssql', 'no SEQUENCE support'), - no_support('mysql', 'no SEQUENCE support'), - no_support('sqlite', 'no SEQUENCE support'), - no_support('sybase', 'no SEQUENCE support'), - ) - -def subqueries(fn): - """Target database must support subqueries.""" - return _chain_decorators_on( - fn, - exclude('mysql', '<', (4, 1, 1), 'no subquery support'), - ) - -def two_phase_transactions(fn): - """Target database must support two-phase transactions.""" - return _chain_decorators_on( - fn, - no_support('access', 'not supported by database'), - no_support('firebird', 'no SA implementation'), - no_support('maxdb', 'not supported by database'), - no_support('mssql', 'FIXME: guessing, needs confirmation'), - no_support('oracle', 'no SA implementation'), - no_support('sqlite', 'not supported by database'), - no_support('sybase', 'FIXME: guessing, needs confirmation'), - exclude('mysql', '<', (5, 0, 3), 'not supported by database'), - ) - -def unicode_connections(fn): - """Target driver must support some encoding of Unicode across the wire.""" - # TODO: expand to exclude MySQLdb versions w/ broken unicode - return _chain_decorators_on( - fn, - exclude('mysql', '<', (4, 1, 1), 'no unicode connection support'), - ) - -def unicode_ddl(fn): - """Target driver must support some encoding of Unicode across the wire.""" - # TODO: expand to exclude MySQLdb versions w/ broken unicode - return _chain_decorators_on( - fn, - no_support('maxdb', 'database support flakey'), - no_support('oracle', 'FIXME: no support in database?'), - no_support('sybase', 'FIXME: guessing, needs confirmation'), - exclude('mysql', '<', (4, 1, 1), 'no unicode connection support'), - ) diff --git a/test/testlib/sa_unittest.py b/test/testlib/sa_unittest.py deleted file mode 100644 index 7eb2c0727..000000000 --- a/test/testlib/sa_unittest.py +++ /dev/null @@ -1,787 +0,0 @@ -#!/usr/bin/env python -''' -unittest.py from Python 2.5. - -SQLAlchemy extends unittest internals to provide setUpAll()/tearDownAll() -so we include a fixed version here to insulate from changes. 2.6 and -3.0's unittest is incompatible with our changes. - -Approaches to removing this dependency are: - -* find a unittest-supported method of grouping UnitTest classes within -a setUpAll()/tearDownAll() pair, such that all tests within a single -UnitTest class are executed within a single execution of setUpAll()/ -tearDownAll(). It may be possible to create nested TestSuite objects -to accomplish this but it's not clear. -* migrate to a different system such as nose. - -Copyright (c) 1999-2003 Steve Purcell -This module is free software, and you may redistribute it and/or modify -it under the same terms as Python itself, so long as this copyright message -and disclaimer are retained in their original form. - -IN NO EVENT SHALL THE AUTHOR BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, -SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OF -THIS CODE, EVEN IF THE AUTHOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH -DAMAGE. - -THE AUTHOR SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A -PARTICULAR PURPOSE. THE CODE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, -AND THERE IS NO OBLIGATION WHATSOEVER TO PROVIDE MAINTENANCE, -SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. -''' - -__author__ = "Steve Purcell" -__email__ = "stephen_purcell at yahoo dot com" -__version__ = "#Revision: 1.63 $"[11:-2] - -import time -import sys -import traceback -import os -import types -from testlib.compat import callable - -############################################################################## -# Exported classes and functions -############################################################################## -__all__ = ['TestResult', 'TestCase', 'TestSuite', 'TextTestRunner', - 'TestLoader', 'FunctionTestCase', 'main', 'defaultTestLoader'] - -# Expose obsolete functions for backwards compatibility -__all__.extend(['getTestCaseNames', 'makeSuite', 'findTestCases']) - - -############################################################################## -# Test framework core -############################################################################## - -# All classes defined herein are 'new-style' classes, allowing use of 'super()' -__metaclass__ = type - -def _strclass(cls): - return "%s.%s" % (cls.__module__, cls.__name__) - -__unittest = 1 - -class TestResult: - """Holder for test result information. - - Test results are automatically managed by the TestCase and TestSuite - classes, and do not need to be explicitly manipulated by writers of tests. - - Each instance holds the total number of tests run, and collections of - failures and errors that occurred among those test runs. The collections - contain tuples of (testcase, exceptioninfo), where exceptioninfo is the - formatted traceback of the error that occurred. - """ - def __init__(self): - self.failures = [] - self.errors = [] - self.testsRun = 0 - self.shouldStop = 0 - - def startTest(self, test): - "Called when the given test is about to be run" - self.testsRun = self.testsRun + 1 - - def stopTest(self, test): - "Called when the given test has been run" - pass - - def addError(self, test, err): - """Called when an error has occurred. 'err' is a tuple of values as - returned by sys.exc_info(). - """ - self.errors.append((test, self._exc_info_to_string(err, test))) - - def addFailure(self, test, err): - """Called when an error has occurred. 'err' is a tuple of values as - returned by sys.exc_info().""" - self.failures.append((test, self._exc_info_to_string(err, test))) - - def addSuccess(self, test): - "Called when a test has completed successfully" - pass - - def wasSuccessful(self): - "Tells whether or not this result was a success" - return len(self.failures) == len(self.errors) == 0 - - def stop(self): - "Indicates that the tests should be aborted" - self.shouldStop = True - - def _exc_info_to_string(self, err, test): - """Converts a sys.exc_info()-style tuple of values into a string.""" - exctype, value, tb = err - # Skip test runner traceback levels - while tb and self._is_relevant_tb_level(tb): - tb = tb.tb_next - if exctype is test.failureException: - # Skip assert*() traceback levels - length = self._count_relevant_tb_levels(tb) - return ''.join(traceback.format_exception(exctype, value, tb, length)) - return ''.join(traceback.format_exception(exctype, value, tb)) - - def _is_relevant_tb_level(self, tb): - return tb.tb_frame.f_globals.has_key('__unittest') - - def _count_relevant_tb_levels(self, tb): - length = 0 - while tb and not self._is_relevant_tb_level(tb): - length += 1 - tb = tb.tb_next - return length - - def __repr__(self): - return "<%s run=%i errors=%i failures=%i>" % \ - (_strclass(self.__class__), self.testsRun, len(self.errors), - len(self.failures)) - -class TestCase: - """A class whose instances are single test cases. - - By default, the test code itself should be placed in a method named - 'runTest'. - - If the fixture may be used for many test cases, create as - many test methods as are needed. When instantiating such a TestCase - subclass, specify in the constructor arguments the name of the test method - that the instance is to execute. - - Test authors should subclass TestCase for their own tests. Construction - and deconstruction of the test's environment ('fixture') can be - implemented by overriding the 'setUp' and 'tearDown' methods respectively. - - If it is necessary to override the __init__ method, the base class - __init__ method must always be called. It is important that subclasses - should not change the signature of their __init__ method, since instances - of the classes are instantiated automatically by parts of the framework - in order to be run. - """ - - # This attribute determines which exception will be raised when - # the instance's assertion methods fail; test methods raising this - # exception will be deemed to have 'failed' rather than 'errored' - - failureException = AssertionError - - def __init__(self, methodName='runTest'): - """Create an instance of the class that will use the named test - method when executed. Raises a ValueError if the instance does - not have a method with the specified name. - """ - try: - self._testMethodName = methodName - testMethod = getattr(self, methodName) - self._testMethodDoc = testMethod.__doc__ - except AttributeError: - raise ValueError, "no such test method in %s: %s" % \ - (self.__class__, methodName) - - def setUp(self): - "Hook method for setting up the test fixture before exercising it." - pass - - def tearDown(self): - "Hook method for deconstructing the test fixture after testing it." - pass - - def countTestCases(self): - return 1 - - def defaultTestResult(self): - return TestResult() - - def shortDescription(self): - """Returns a one-line description of the test, or None if no - description has been provided. - - The default implementation of this method returns the first line of - the specified test method's docstring. - """ - doc = self._testMethodDoc - return doc and doc.split("\n")[0].strip() or None - - def id(self): - return "%s.%s" % (_strclass(self.__class__), self._testMethodName) - - def __str__(self): - return "%s (%s)" % (self._testMethodName, _strclass(self.__class__)) - - def __repr__(self): - return "<%s testMethod=%s>" % \ - (_strclass(self.__class__), self._testMethodName) - - def run(self, result=None): - if result is None: result = self.defaultTestResult() - result.startTest(self) - testMethod = getattr(self, self._testMethodName) - try: - try: - self.setUp() - except KeyboardInterrupt: - raise - except: - result.addError(self, self._exc_info()) - return - - ok = False - try: - testMethod() - ok = True - except self.failureException: - result.addFailure(self, self._exc_info()) - except KeyboardInterrupt: - raise - except: - result.addError(self, self._exc_info()) - - try: - self.tearDown() - except KeyboardInterrupt: - raise - except: - result.addError(self, self._exc_info()) - ok = False - if ok: result.addSuccess(self) - finally: - result.stopTest(self) - - def __call__(self, *args, **kwds): - return self.run(*args, **kwds) - - def debug(self): - """Run the test without collecting errors in a TestResult""" - self.setUp() - getattr(self, self._testMethodName)() - self.tearDown() - - def _exc_info(self): - """Return a version of sys.exc_info() with the traceback frame - minimised; usually the top level of the traceback frame is not - needed. - """ - exctype, excvalue, tb = sys.exc_info() - if sys.platform[:4] == 'java': ## tracebacks look different in Jython - return (exctype, excvalue, tb) - return (exctype, excvalue, tb) - - def fail(self, msg=None): - """Fail immediately, with the given message.""" - raise self.failureException, msg - - def failIf(self, expr, msg=None): - "Fail the test if the expression is true." - if expr: raise self.failureException, msg - - def failUnless(self, expr, msg=None): - """Fail the test unless the expression is true.""" - if not expr: raise self.failureException, msg - - def failUnlessRaises(self, excClass, callableObj, *args, **kwargs): - """Fail unless an exception of class excClass is thrown - by callableObj when invoked with arguments args and keyword - arguments kwargs. If a different type of exception is - thrown, it will not be caught, and the test case will be - deemed to have suffered an error, exactly as for an - unexpected exception. - """ - try: - callableObj(*args, **kwargs) - except excClass: - return - else: - if hasattr(excClass,'__name__'): excName = excClass.__name__ - else: excName = str(excClass) - raise self.failureException, "%s not raised" % excName - - def failUnlessEqual(self, first, second, msg=None): - """Fail if the two objects are unequal as determined by the '==' - operator. - """ - if not first == second: - raise self.failureException, \ - (msg or '%r != %r' % (first, second)) - - def failIfEqual(self, first, second, msg=None): - """Fail if the two objects are equal as determined by the '==' - operator. - """ - if first == second: - raise self.failureException, \ - (msg or '%r == %r' % (first, second)) - - def failUnlessAlmostEqual(self, first, second, places=7, msg=None): - """Fail if the two objects are unequal as determined by their - difference rounded to the given number of decimal places - (default 7) and comparing to zero. - - Note that decimal places (from zero) are usually not the same - as significant digits (measured from the most signficant digit). - """ - if round(second-first, places) != 0: - raise self.failureException, \ - (msg or '%r != %r within %r places' % (first, second, places)) - - def failIfAlmostEqual(self, first, second, places=7, msg=None): - """Fail if the two objects are equal as determined by their - difference rounded to the given number of decimal places - (default 7) and comparing to zero. - - Note that decimal places (from zero) are usually not the same - as significant digits (measured from the most signficant digit). - """ - if round(second-first, places) == 0: - raise self.failureException, \ - (msg or '%r == %r within %r places' % (first, second, places)) - - # Synonyms for assertion methods - - assertEqual = assertEquals = failUnlessEqual - - assertNotEqual = assertNotEquals = failIfEqual - - assertAlmostEqual = assertAlmostEquals = failUnlessAlmostEqual - - assertNotAlmostEqual = assertNotAlmostEquals = failIfAlmostEqual - - assertRaises = failUnlessRaises - - assert_ = assertTrue = failUnless - - assertFalse = failIf - - - -class TestSuite: - """A test suite is a composite test consisting of a number of TestCases. - - For use, create an instance of TestSuite, then add test case instances. - When all tests have been added, the suite can be passed to a test - runner, such as TextTestRunner. It will run the individual test cases - in the order in which they were added, aggregating the results. When - subclassing, do not forget to call the base class constructor. - """ - def __init__(self, tests=()): - self._tests = [] - self.addTests(tests) - - def __repr__(self): - return "<%s tests=%s>" % (_strclass(self.__class__), self._tests) - - __str__ = __repr__ - - def __iter__(self): - return iter(self._tests) - - def countTestCases(self): - cases = 0 - for test in self._tests: - cases += test.countTestCases() - return cases - - def addTest(self, test): - # sanity checks - if not callable(test): - raise TypeError("the test to add must be callable") - if (isinstance(test, (type, types.ClassType)) and - issubclass(test, (TestCase, TestSuite))): - raise TypeError("TestCases and TestSuites must be instantiated " - "before passing them to addTest()") - self._tests.append(test) - - def addTests(self, tests): - if isinstance(tests, basestring): - raise TypeError("tests must be an iterable of tests, not a string") - for test in tests: - self.addTest(test) - - def run(self, result): - for test in self._tests: - if result.shouldStop: - break - test(result) - return result - - def __call__(self, *args, **kwds): - return self.run(*args, **kwds) - - def debug(self): - """Run the tests without collecting errors in a TestResult""" - for test in self._tests: test.debug() - - -class FunctionTestCase(TestCase): - """A test case that wraps a test function. - - This is useful for slipping pre-existing test functions into the - PyUnit framework. Optionally, set-up and tidy-up functions can be - supplied. As with TestCase, the tidy-up ('tearDown') function will - always be called if the set-up ('setUp') function ran successfully. - """ - - def __init__(self, testFunc, setUp=None, tearDown=None, - description=None): - TestCase.__init__(self) - self.__setUpFunc = setUp - self.__tearDownFunc = tearDown - self.__testFunc = testFunc - self.__description = description - - def setUp(self): - if self.__setUpFunc is not None: - self.__setUpFunc() - - def tearDown(self): - if self.__tearDownFunc is not None: - self.__tearDownFunc() - - def runTest(self): - self.__testFunc() - - def id(self): - return self.__testFunc.__name__ - - def __str__(self): - return "%s (%s)" % (_strclass(self.__class__), self.__testFunc.__name__) - - def __repr__(self): - return "<%s testFunc=%s>" % (_strclass(self.__class__), self.__testFunc) - - def shortDescription(self): - if self.__description is not None: return self.__description - doc = self.__testFunc.__doc__ - return doc and doc.split("\n")[0].strip() or None - - - -############################################################################## -# Locating and loading tests -############################################################################## - -class TestLoader: - """This class is responsible for loading tests according to various - criteria and returning them wrapped in a Test - """ - testMethodPrefix = 'test' - suiteClass = TestSuite - - def loadTestsFromTestCase(self, testCaseClass): - """Return a suite of all tests cases contained in testCaseClass""" - if issubclass(testCaseClass, TestSuite): - raise TypeError("Test cases should not be derived from TestSuite. Maybe you meant to derive from TestCase?") - testCaseNames = self.getTestCaseNames(testCaseClass) - if not testCaseNames and hasattr(testCaseClass, 'runTest'): - testCaseNames = ['runTest'] - return self.suiteClass(map(testCaseClass, testCaseNames)) - - def loadTestsFromModule(self, module): - """Return a suite of all tests cases contained in the given module""" - tests = [] - for name in dir(module): - obj = getattr(module, name) - if (isinstance(obj, (type, types.ClassType)) and - issubclass(obj, TestCase)): - tests.append(self.loadTestsFromTestCase(obj)) - return self.suiteClass(tests) - - def loadTestsFromName(self, name, module=None): - """Return a suite of all tests cases given a string specifier. - - The name may resolve either to a module, a test case class, a - test method within a test case class, or a callable object which - returns a TestCase or TestSuite instance. - - The method optionally resolves the names relative to a given module. - """ - parts = name.split('.') - if module is None: - parts_copy = parts[:] - while parts_copy: - try: - module = __import__('.'.join(parts_copy)) - break - except ImportError: - del parts_copy[-1] - if not parts_copy: raise - parts = parts[1:] - obj = module - for part in parts: - parent, obj = obj, getattr(obj, part) - - if type(obj) == types.ModuleType: - return self.loadTestsFromModule(obj) - elif (isinstance(obj, (type, types.ClassType)) and - issubclass(obj, TestCase)): - return self.loadTestsFromTestCase(obj) - elif type(obj) == types.UnboundMethodType: - return parent(obj.__name__) - elif isinstance(obj, TestSuite): - return obj - elif callable(obj): - test = obj() - if not isinstance(test, (TestCase, TestSuite)): - raise ValueError, \ - "calling %s returned %s, not a test" % (obj,test) - return test - else: - raise ValueError, "don't know how to make test from: %s" % obj - - def loadTestsFromNames(self, names, module=None): - """Return a suite of all tests cases found using the given sequence - of string specifiers. See 'loadTestsFromName()'. - """ - suites = [self.loadTestsFromName(name, module) for name in names] - return self.suiteClass(suites) - - def getTestCaseNames(self, testCaseClass): - """Return a sorted sequence of method names found within testCaseClass - """ - - def isTestMethod(attrname, testCaseClass=testCaseClass, prefix=self.testMethodPrefix): - return attrname.startswith(prefix) and callable(getattr(testCaseClass, attrname)) - testFnNames = filter(isTestMethod, dir(testCaseClass)) - for baseclass in testCaseClass.__bases__: - for testFnName in self.getTestCaseNames(baseclass): - if testFnName not in testFnNames: # handle overridden methods - testFnNames.append(testFnName) - testFnNames.sort() - return testFnNames - - - -defaultTestLoader = TestLoader() - - -############################################################################## -# Patches for old functions: these functions should be considered obsolete -############################################################################## - -def _makeLoader(prefix, sortUsing, suiteClass=None): - loader = TestLoader() - loader.testMethodPrefix = prefix - if suiteClass: loader.suiteClass = suiteClass - return loader - -def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp): - return _makeLoader(prefix, sortUsing).getTestCaseNames(testCaseClass) - -def makeSuite(testCaseClass, prefix='test', sortUsing=cmp, suiteClass=TestSuite): - return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromTestCase(testCaseClass) - -def findTestCases(module, prefix='test', sortUsing=cmp, suiteClass=TestSuite): - return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromModule(module) - - -############################################################################## -# Text UI -############################################################################## - -class _WritelnDecorator: - """Used to decorate file-like objects with a handy 'writeln' method""" - def __init__(self,stream): - self.stream = stream - - def __getattr__(self, attr): - return getattr(self.stream,attr) - - def writeln(self, arg=None): - if arg: self.write(arg) - self.write('\n') # text-mode streams translate to \r\n if needed - - -class _TextTestResult(TestResult): - """A test result class that can print formatted text results to a stream. - - Used by TextTestRunner. - """ - separator1 = '=' * 70 - separator2 = '-' * 70 - - def __init__(self, stream, descriptions, verbosity): - TestResult.__init__(self) - self.stream = stream - self.showAll = verbosity > 1 - self.dots = verbosity == 1 - self.descriptions = descriptions - - def getDescription(self, test): - if self.descriptions: - return test.shortDescription() or str(test) - else: - return str(test) - - def startTest(self, test): - TestResult.startTest(self, test) - if self.showAll: - self.stream.write(self.getDescription(test)) - self.stream.write(" ... ") - - def addSuccess(self, test): - TestResult.addSuccess(self, test) - if self.showAll: - self.stream.writeln("ok") - elif self.dots: - self.stream.write('.') - - def addError(self, test, err): - TestResult.addError(self, test, err) - if self.showAll: - self.stream.writeln("ERROR") - elif self.dots: - self.stream.write('E') - - def addFailure(self, test, err): - TestResult.addFailure(self, test, err) - if self.showAll: - self.stream.writeln("FAIL") - elif self.dots: - self.stream.write('F') - - def printErrors(self): - if self.dots or self.showAll: - self.stream.writeln() - self.printErrorList('ERROR', self.errors) - self.printErrorList('FAIL', self.failures) - - def printErrorList(self, flavour, errors): - for test, err in errors: - self.stream.writeln(self.separator1) - self.stream.writeln("%s: %s" % (flavour,self.getDescription(test))) - self.stream.writeln(self.separator2) - self.stream.writeln("%s" % err) - - -class TextTestRunner: - """A test runner class that displays results in textual form. - - It prints out the names of tests as they are run, errors as they - occur, and a summary of the results at the end of the test run. - """ - def __init__(self, stream=sys.stderr, descriptions=1, verbosity=1): - self.stream = _WritelnDecorator(stream) - self.descriptions = descriptions - self.verbosity = verbosity - - def _makeResult(self): - return _TextTestResult(self.stream, self.descriptions, self.verbosity) - - def run(self, test): - "Run the given test case or test suite." - result = self._makeResult() - startTime = time.time() - test(result) - stopTime = time.time() - timeTaken = stopTime - startTime - result.printErrors() - self.stream.writeln(result.separator2) - run = result.testsRun - self.stream.writeln("Ran %d test%s in %.3fs" % - (run, run != 1 and "s" or "", timeTaken)) - self.stream.writeln() - if not result.wasSuccessful(): - self.stream.write("FAILED (") - failed, errored = map(len, (result.failures, result.errors)) - if failed: - self.stream.write("failures=%d" % failed) - if errored: - if failed: self.stream.write(", ") - self.stream.write("errors=%d" % errored) - self.stream.writeln(")") - else: - self.stream.writeln("OK") - return result - - - -############################################################################## -# Facilities for running tests from the command line -############################################################################## - -class TestProgram: - """A command-line program that runs a set of tests; this is primarily - for making test modules conveniently executable. - """ - USAGE = """\ -Usage: %(progName)s [options] [test] [...] - -Options: - -h, --help Show this message - -v, --verbose Verbose output - -q, --quiet Minimal output - -Examples: - %(progName)s - run default set of tests - %(progName)s MyTestSuite - run suite 'MyTestSuite' - %(progName)s MyTestCase.testSomething - run MyTestCase.testSomething - %(progName)s MyTestCase - run all 'test*' test methods - in MyTestCase -""" - def __init__(self, module='__main__', defaultTest=None, - argv=None, testRunner=None, testLoader=defaultTestLoader): - if type(module) == type(''): - self.module = __import__(module) - for part in module.split('.')[1:]: - self.module = getattr(self.module, part) - else: - self.module = module - if argv is None: - argv = sys.argv - self.verbosity = 1 - self.defaultTest = defaultTest - self.testRunner = testRunner - self.testLoader = testLoader - self.progName = os.path.basename(argv[0]) - self.parseArgs(argv) - self.runTests() - - def usageExit(self, msg=None): - if msg: print msg - print self.USAGE % self.__dict__ - sys.exit(2) - - def parseArgs(self, argv): - import getopt - try: - options, args = getopt.getopt(argv[1:], 'hHvq', - ['help','verbose','quiet']) - for opt, value in options: - if opt in ('-h','-H','--help'): - self.usageExit() - if opt in ('-q','--quiet'): - self.verbosity = 0 - if opt in ('-v','--verbose'): - self.verbosity = 2 - if len(args) == 0 and self.defaultTest is None: - self.test = self.testLoader.loadTestsFromModule(self.module) - return - if len(args) > 0: - self.testNames = args - else: - self.testNames = (self.defaultTest,) - self.createTests() - except getopt.error, msg: - self.usageExit(msg) - - def createTests(self): - self.test = self.testLoader.loadTestsFromNames(self.testNames, - self.module) - - def runTests(self): - if self.testRunner is None: - self.testRunner = TextTestRunner(verbosity=self.verbosity) - result = self.testRunner.run(self.test) - sys.exit(not result.wasSuccessful()) - -main = TestProgram - - -############################################################################## -# Executing this module from the command line -############################################################################## - -if __name__ == "__main__": - main(module=None) diff --git a/test/testlib/schema.py b/test/testlib/schema.py deleted file mode 100644 index 7009fd65d..000000000 --- a/test/testlib/schema.py +++ /dev/null @@ -1,79 +0,0 @@ -from testlib import testing - -schema = None - -__all__ = 'Table', 'Column', - -table_options = {} - -def Table(*args, **kw): - """A schema.Table wrapper/hook for dialect-specific tweaks.""" - - global schema - if schema is None: - from sqlalchemy import schema - - test_opts = dict([(k,kw.pop(k)) for k in kw.keys() - if k.startswith('test_')]) - - kw.update(table_options) - - if testing.against('mysql'): - if 'mysql_engine' not in kw and 'mysql_type' not in kw: - if 'test_needs_fk' in test_opts or 'test_needs_acid' in test_opts: - kw['mysql_engine'] = 'InnoDB' - - # Apply some default cascading rules for self-referential foreign keys. - # MySQL InnoDB has some issues around seleting self-refs too. - if testing.against('firebird'): - table_name = args[0] - unpack = (testing.config.db.dialect. - identifier_preparer.unformat_identifiers) - - # Only going after ForeignKeys in Columns. May need to - # expand to ForeignKeyConstraint too. - fks = [fk - for col in args if isinstance(col, schema.Column) - for fk in col.args if isinstance(fk, schema.ForeignKey)] - - for fk in fks: - # root around in raw spec - ref = fk._colspec - if isinstance(ref, schema.Column): - name = ref.table.name - else: - # take just the table name: on FB there cannot be - # a schema, so the first element is always the - # table name, possibly followed by the field name - name = unpack(ref)[0] - if name == table_name: - if fk.ondelete is None: - fk.ondelete = 'CASCADE' - if fk.onupdate is None: - fk.onupdate = 'CASCADE' - - if testing.against('firebird', 'oracle'): - pk_seqs = [col for col in args - if (isinstance(col, schema.Column) - and col.primary_key - and getattr(col, '_needs_autoincrement', False))] - for c in pk_seqs: - c.args.append(schema.Sequence(args[0] + '_' + c.name + '_seq', optional=True)) - return schema.Table(*args, **kw) - - -def Column(*args, **kw): - """A schema.Column wrapper/hook for dialect-specific tweaks.""" - - global schema - if schema is None: - from sqlalchemy import schema - - test_opts = dict([(k,kw.pop(k)) for k in kw.keys() - if k.startswith('test_')]) - - c = schema.Column(*args, **kw) - if testing.against('firebird', 'oracle'): - if 'test_needs_autoincrement' in test_opts: - c._needs_autoincrement = True - return c diff --git a/test/testlib/testing.py b/test/testlib/testing.py deleted file mode 100644 index 408dda79f..000000000 --- a/test/testlib/testing.py +++ /dev/null @@ -1,919 +0,0 @@ -"""TestCase and TestSuite artifacts and testing decorators.""" - -# monkeypatches unittest.TestLoader.suiteClass at import time - -import itertools -import operator -import re -import sys -import types -from testlib import sa_unittest as unittest -import warnings -from cStringIO import StringIO - -import testlib.config as config -from testlib.compat import _function_named, callable - -# Delayed imports -MetaData = None -Session = None -clear_mappers = None -sa_exc = None -schema = None -sqltypes = None -util = None - - -_ops = { '<': operator.lt, - '>': operator.gt, - '==': operator.eq, - '!=': operator.ne, - '<=': operator.le, - '>=': operator.ge, - 'in': operator.contains, - 'between': lambda val, pair: val >= pair[0] and val <= pair[1], - } - -# sugar ('testing.db'); set here by config() at runtime -db = None - -# more sugar, installed by __init__ -requires = None - -def fails_if(callable_): - """Mark a test as expected to fail if callable_ returns True. - - If the callable returns false, the test is run and reported as normal. - However if the callable returns true, the test is expected to fail and the - unit test logic is inverted: if the test fails, a success is reported. If - the test succeeds, a failure is reported. - """ - - docstring = getattr(callable_, '__doc__', None) or callable_.__name__ - description = docstring.split('\n')[0] - - def decorate(fn): - fn_name = fn.__name__ - def maybe(*args, **kw): - if not callable_(): - return fn(*args, **kw) - else: - try: - fn(*args, **kw) - except Exception, ex: - print ("'%s' failed as expected (condition: %s): %s " % ( - fn_name, description, str(ex))) - return True - else: - raise AssertionError( - "Unexpected success for '%s' (condition: %s)" % - (fn_name, description)) - return _function_named(maybe, fn_name) - return decorate - - -def future(fn): - """Mark a test as expected to unconditionally fail. - - Takes no arguments, omit parens when using as a decorator. - """ - - fn_name = fn.__name__ - def decorated(*args, **kw): - try: - fn(*args, **kw) - except Exception, ex: - print ("Future test '%s' failed as expected: %s " % ( - fn_name, str(ex))) - return True - else: - raise AssertionError( - "Unexpected success for future test '%s'" % fn_name) - return _function_named(decorated, fn_name) - -def fails_on(dbs, reason): - """Mark a test as expected to fail on the specified database - implementation. - - Unlike ``crashes``, tests marked as ``fails_on`` will be run - for the named databases. The test is expected to fail and the unit test - logic is inverted: if the test fails, a success is reported. If the test - succeeds, a failure is reported. - """ - - def decorate(fn): - fn_name = fn.__name__ - def maybe(*args, **kw): - if config.db.name != dbs: - return fn(*args, **kw) - else: - try: - fn(*args, **kw) - except Exception, ex: - print ("'%s' failed as expected on DB implementation " - "'%s': %s" % ( - fn_name, config.db.name, reason)) - return True - else: - raise AssertionError( - "Unexpected success for '%s' on DB implementation '%s'" % - (fn_name, config.db.name)) - return _function_named(maybe, fn_name) - return decorate - -def fails_on_everything_except(*dbs): - """Mark a test as expected to fail on most database implementations. - - Like ``fails_on``, except failure is the expected outcome on all - databases except those listed. - """ - - def decorate(fn): - fn_name = fn.__name__ - def maybe(*args, **kw): - if config.db.name in dbs: - return fn(*args, **kw) - else: - try: - fn(*args, **kw) - except Exception, ex: - print ("'%s' failed as expected on DB implementation " - "'%s': %s" % ( - fn_name, config.db.name, str(ex))) - return True - else: - raise AssertionError( - "Unexpected success for '%s' on DB implementation '%s'" % - (fn_name, config.db.name)) - return _function_named(maybe, fn_name) - return decorate - -def crashes(db, reason): - """Mark a test as unsupported by a database implementation. - - ``crashes`` tests will be skipped unconditionally. Use for feature tests - that cause deadlocks or other fatal problems. - - """ - carp = _should_carp_about_exclusion(reason) - def decorate(fn): - fn_name = fn.__name__ - def maybe(*args, **kw): - if config.db.name == db: - msg = "'%s' unsupported on DB implementation '%s': %s" % ( - fn_name, config.db.name, reason) - print msg - if carp: - print >> sys.stderr, msg - return True - else: - return fn(*args, **kw) - return _function_named(maybe, fn_name) - return decorate - -def _block_unconditionally(db, reason): - """Mark a test as unsupported by a database implementation. - - Will never run the test against any version of the given database, ever, - no matter what. Use when your assumptions are infallible; past, present - and future. - - """ - carp = _should_carp_about_exclusion(reason) - def decorate(fn): - fn_name = fn.__name__ - def maybe(*args, **kw): - if config.db.name == db: - msg = "'%s' unsupported on DB implementation '%s': %s" % ( - fn_name, config.db.name, reason) - print msg - if carp: - print >> sys.stderr, msg - return True - else: - return fn(*args, **kw) - return _function_named(maybe, fn_name) - return decorate - - -def exclude(db, op, spec, reason): - """Mark a test as unsupported by specific database server versions. - - Stackable, both with other excludes and other decorators. Examples:: - - # Not supported by mydb versions less than 1, 0 - @exclude('mydb', '<', (1,0)) - # Other operators work too - @exclude('bigdb', '==', (9,0,9)) - @exclude('yikesdb', 'in', ((0, 3, 'alpha2'), (0, 3, 'alpha3'))) - - """ - carp = _should_carp_about_exclusion(reason) - def decorate(fn): - fn_name = fn.__name__ - def maybe(*args, **kw): - if _is_excluded(db, op, spec): - msg = "'%s' unsupported on DB %s version '%s': %s" % ( - fn_name, config.db.name, _server_version(), reason) - print msg - if carp: - print >> sys.stderr, msg - return True - else: - return fn(*args, **kw) - return _function_named(maybe, fn_name) - return decorate - -def _should_carp_about_exclusion(reason): - """Guard against forgotten exclusions.""" - assert reason - for _ in ('todo', 'fixme', 'xxx'): - if _ in reason.lower(): - return True - else: - if len(reason) < 4: - return True - -def _is_excluded(db, op, spec): - """Return True if the configured db matches an exclusion specification. - - db: - A dialect name - op: - An operator or stringified operator, such as '==' - spec: - A value that will be compared to the dialect's server_version_info - using the supplied operator. - - Examples:: - # Not supported by mydb versions less than 1, 0 - _is_excluded('mydb', '<', (1,0)) - # Other operators work too - _is_excluded('bigdb', '==', (9,0,9)) - _is_excluded('yikesdb', 'in', ((0, 3, 'alpha2'), (0, 3, 'alpha3'))) - """ - - if config.db.name != db: - return False - - version = _server_version() - - oper = hasattr(op, '__call__') and op or _ops[op] - return oper(version, spec) - -def _server_version(bind=None): - """Return a server_version_info tuple.""" - - if bind is None: - bind = config.db - return bind.dialect.server_version_info(bind.contextual_connect()) - -def skip_if(predicate, reason=None): - """Skip a test if predicate is true.""" - reason = reason or predicate.__name__ - def decorate(fn): - fn_name = fn.__name__ - def maybe(*args, **kw): - if predicate(): - msg = "'%s' skipped on DB %s version '%s': %s" % ( - fn_name, config.db.name, _server_version(), reason) - print msg - return True - else: - return fn(*args, **kw) - return _function_named(maybe, fn_name) - return decorate - -def emits_warning(*messages): - """Mark a test as emitting a warning. - - With no arguments, squelches all SAWarning failures. Or pass one or more - strings; these will be matched to the root of the warning description by - warnings.filterwarnings(). - """ - - # TODO: it would be nice to assert that a named warning was - # emitted. should work with some monkeypatching of warnings, - # and may work on non-CPython if they keep to the spirit of - # warnings.showwarning's docstring. - # - update: jython looks ok, it uses cpython's module - def decorate(fn): - def safe(*args, **kw): - global sa_exc - if sa_exc is None: - import sqlalchemy.exc as sa_exc - - # todo: should probably be strict about this, too - filters = [dict(action='ignore', - category=sa_exc.SAPendingDeprecationWarning)] - if not messages: - filters.append(dict(action='ignore', - category=sa_exc.SAWarning)) - else: - filters.extend(dict(action='ignore', - message=message, - category=sa_exc.SAWarning) - for message in messages) - for f in filters: - warnings.filterwarnings(**f) - try: - return fn(*args, **kw) - finally: - resetwarnings() - return _function_named(safe, fn.__name__) - return decorate - -def emits_warning_on(db, *warnings): - """Mark a test as emitting a warning on a specific dialect. - - With no arguments, squelches all SAWarning failures. Or pass one or more - strings; these will be matched to the root of the warning description by - warnings.filterwarnings(). - """ - def decorate(fn): - def maybe(*args, **kw): - if isinstance(db, basestring): - if config.db.name != db: - return fn(*args, **kw) - else: - wrapped = emits_warning(*warnings)(fn) - return wrapped(*args, **kw) - else: - if not _is_excluded(*db): - return fn(*args, **kw) - else: - wrapped = emits_warning(*warnings)(fn) - return wrapped(*args, **kw) - return _function_named(maybe, fn.__name__) - return decorate - -def uses_deprecated(*messages): - """Mark a test as immune from fatal deprecation warnings. - - With no arguments, squelches all SADeprecationWarning failures. - Or pass one or more strings; these will be matched to the root - of the warning description by warnings.filterwarnings(). - - As a special case, you may pass a function name prefixed with // - and it will be re-written as needed to match the standard warning - verbiage emitted by the sqlalchemy.util.deprecated decorator. - """ - - def decorate(fn): - def safe(*args, **kw): - global sa_exc - if sa_exc is None: - import sqlalchemy.exc as sa_exc - - # todo: should probably be strict about this, too - filters = [dict(action='ignore', - category=sa_exc.SAPendingDeprecationWarning)] - if not messages: - filters.append(dict(action='ignore', - category=sa_exc.SADeprecationWarning)) - else: - filters.extend( - [dict(action='ignore', - message=message, - category=sa_exc.SADeprecationWarning) - for message in - [ (m.startswith('//') and - ('Call to deprecated function ' + m[2:]) or m) - for m in messages] ]) - - for f in filters: - warnings.filterwarnings(**f) - try: - return fn(*args, **kw) - finally: - resetwarnings() - return _function_named(safe, fn.__name__) - return decorate - -def resetwarnings(): - """Reset warning behavior to testing defaults.""" - - global sa_exc - if sa_exc is None: - import sqlalchemy.exc as sa_exc - - warnings.filterwarnings('ignore', - category=sa_exc.SAPendingDeprecationWarning) - warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning) - warnings.filterwarnings('error', category=sa_exc.SAWarning) - -# warnings.simplefilter('error') - - if sys.version_info < (2, 4): - warnings.filterwarnings('ignore', category=FutureWarning) - - -def against(*queries): - """Boolean predicate, compares to testing database configuration. - - Given one or more dialect names, returns True if one is the configured - database engine. - - Also supports comparison to database version when provided with one or - more 3-tuples of dialect name, operator, and version specification:: - - testing.against('mysql', 'postgres') - testing.against(('mysql', '>=', (5, 0, 0)) - """ - - for query in queries: - if isinstance(query, basestring): - if config.db.name == query: - return True - else: - name, op, spec = query - if config.db.name != name: - continue - - have = config.db.dialect.server_version_info( - config.db.contextual_connect()) - - oper = hasattr(op, '__call__') and op or _ops[op] - if oper(have, spec): - return True - return False - -def _chain_decorators_on(fn, *decorators): - """Apply a series of decorators to fn, returning a decorated function.""" - for decorator in reversed(decorators): - fn = decorator(fn) - return fn - -def rowset(results): - """Converts the results of sql execution into a plain set of column tuples. - - Useful for asserting the results of an unordered query. - """ - - return set([tuple(row) for row in results]) - - -def eq_(a, b, msg=None): - """Assert a == b, with repr messaging on failure.""" - assert a == b, msg or "%r != %r" % (a, b) - -def ne_(a, b, msg=None): - """Assert a != b, with repr messaging on failure.""" - assert a != b, msg or "%r == %r" % (a, b) - -def is_(a, b, msg=None): - """Assert a is b, with repr messaging on failure.""" - assert a is b, msg or "%r is not %r" % (a, b) - -def is_not_(a, b, msg=None): - """Assert a is not b, with repr messaging on failure.""" - assert a is not b, msg or "%r is %r" % (a, b) - -def startswith_(a, fragment, msg=None): - """Assert a.startswith(fragment), with repr messaging on failure.""" - assert a.startswith(fragment), msg or "%r does not start with %r" % ( - a, fragment) - - -def fixture(table, columns, *rows): - """Insert data into table after creation.""" - def onload(event, schema_item, connection): - insert = table.insert() - column_names = [col.key for col in columns] - connection.execute(insert, [dict(zip(column_names, column_values)) - for column_values in rows]) - table.append_ddl_listener('after-create', onload) - -def _import_by_name(name): - submodule = name.split('.')[-1] - return __import__(name, globals(), locals(), [submodule]) - -class CompositeModule(types.ModuleType): - """Merged attribute access for multiple modules.""" - - # break the habit - __all__ = () - - def __init__(self, name, *modules, **overrides): - """Construct a new lazy composite of modules. - - Modules may be string names or module-like instances. Individual - attribute overrides may be specified as keyword arguments for - convenience. - - The constructed module will resolve attribute access in reverse order: - overrides, then each member of reversed(modules). Modules specified - by name will be loaded lazily when encountered in attribute - resolution. - - """ - types.ModuleType.__init__(self, name) - self.__modules = list(reversed(modules)) - for key, value in overrides.iteritems(): - setattr(self, key, value) - - def __getattr__(self, key): - for idx, mod in enumerate(self.__modules): - if isinstance(mod, basestring): - self.__modules[idx] = mod = _import_by_name(mod) - if hasattr(mod, key): - return getattr(mod, key) - raise AttributeError(key) - - -def resolve_artifact_names(fn): - """Decorator, augment function globals with tables and classes. - - Swaps out the function's globals at execution time. The 'global' statement - will not work as expected inside a decorated function. - - """ - # This could be automatically applied to framework and test_ methods in - # the MappedTest-derived test suites but... *some* explicitness for this - # magic is probably good. Especially as 'global' won't work- these - # rebound functions aren't regular Python.. - # - # Also: it's lame that CPython accepts a dict-subclass for globals, but - # only calls dict methods. That would allow 'global' to pass through to - # the func_globals. - def resolved(*args, **kwargs): - self = args[0] - context = dict(fn.func_globals) - for source in self._artifact_registries: - context.update(getattr(self, source)) - # jython bug #1034 - rebound = types.FunctionType( - fn.func_code, context, fn.func_name, fn.func_defaults, - fn.func_closure) - return rebound(*args, **kwargs) - return _function_named(resolved, fn.func_name) - -class adict(dict): - """Dict keys available as attributes. Shadows.""" - def __getattribute__(self, key): - try: - return self[key] - except KeyError: - return dict.__getattribute__(self, key) - - def get_all(self, *keys): - return tuple([self[key] for key in keys]) - - -class TestBase(unittest.TestCase): - # A sequence of database names to always run, regardless of the - # constraints below. - __whitelist__ = () - - # A sequence of requirement names matching testing.requires decorators - __requires__ = () - - # A sequence of dialect names to exclude from the test class. - __unsupported_on__ = () - - # If present, test class is only runnable for the *single* specified - # dialect. If you need multiple, use __unsupported_on__ and invert. - __only_on__ = None - - # A sequence of no-arg callables. If any are True, the entire testcase is - # skipped. - __skip_if__ = None - - - _artifact_registries = () - - _sa_first_test = False - _sa_last_test = False - - def __init__(self, *args, **params): - unittest.TestCase.__init__(self, *args, **params) - - def setUpAll(self): - pass - - def tearDownAll(self): - pass - - def shortDescription(self): - """overridden to not return docstrings""" - return None - - def assertRaisesMessage(self, except_cls, msg, callable_, *args, **kwargs): - try: - callable_(*args, **kwargs) - assert False, "Callable did not raise an exception" - except except_cls, e: - assert re.search(msg, str(e)), "%r !~ %s" % (msg, e) - - if not hasattr(unittest.TestCase, 'assertTrue'): - assertTrue = unittest.TestCase.failUnless - if not hasattr(unittest.TestCase, 'assertFalse'): - assertFalse = unittest.TestCase.failIf - -class AssertsCompiledSQL(object): - def assert_compile(self, clause, result, params=None, checkparams=None, dialect=None): - if dialect is None: - dialect = getattr(self, '__dialect__', None) - - if params is None: - keys = None - else: - keys = params.keys() - - c = clause.compile(column_keys=keys, dialect=dialect) - - print "\nSQL String:\n" + str(c) + repr(c.params) - - cc = re.sub(r'\n', '', str(c)) - - self.assertEquals(cc, result, "%r != %r on dialect %r" % (cc, result, dialect)) - - if checkparams is not None: - self.assertEquals(c.construct_params(params), checkparams) - -class ComparesTables(object): - def assert_tables_equal(self, table, reflected_table): - global sqltypes, schema - if sqltypes is None: - import sqlalchemy.types as sqltypes - if schema is None: - import sqlalchemy.schema as schema - base_mro = sqltypes.TypeEngine.__mro__ - assert len(table.c) == len(reflected_table.c) - for c, reflected_c in zip(table.c, reflected_table.c): - self.assertEquals(c.name, reflected_c.name) - assert reflected_c is reflected_table.c[c.name] - self.assertEquals(c.primary_key, reflected_c.primary_key) - self.assertEquals(c.nullable, reflected_c.nullable) - assert len( - set(type(reflected_c.type).__mro__).difference(base_mro).intersection( - set(type(c.type).__mro__).difference(base_mro) - ) - ) > 0, "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type) - - if isinstance(c.type, sqltypes.String): - self.assertEquals(c.type.length, reflected_c.type.length) - - self.assertEquals(set([f.column.name for f in c.foreign_keys]), set([f.column.name for f in reflected_c.foreign_keys])) - if c.default: - assert isinstance(reflected_c.server_default, - schema.FetchedValue) - elif against(('mysql', '<', (5, 0))): - # ignore reflection of bogus db-generated DefaultClause() - pass - elif not c.primary_key or not against('postgres'): - print repr(c) - assert reflected_c.default is None, reflected_c.default - - assert len(table.primary_key) == len(reflected_table.primary_key) - for c in table.primary_key: - assert reflected_table.primary_key.columns[c.name] - - -class AssertsExecutionResults(object): - def assert_result(self, result, class_, *objects): - result = list(result) - print repr(result) - self.assert_list(result, class_, objects) - - def assert_list(self, result, class_, list): - self.assert_(len(result) == len(list), - "result list is not the same size as test list, " + - "for class " + class_.__name__) - for i in range(0, len(list)): - self.assert_row(class_, result[i], list[i]) - - def assert_row(self, class_, rowobj, desc): - self.assert_(rowobj.__class__ is class_, - "item class is not " + repr(class_)) - for key, value in desc.iteritems(): - if isinstance(value, tuple): - if isinstance(value[1], list): - self.assert_list(getattr(rowobj, key), value[0], value[1]) - else: - self.assert_row(value[0], getattr(rowobj, key), value[1]) - else: - self.assert_(getattr(rowobj, key) == value, - "attribute %s value %s does not match %s" % ( - key, getattr(rowobj, key), value)) - - def assert_unordered_result(self, result, cls, *expected): - """As assert_result, but the order of objects is not considered. - - The algorithm is very expensive but not a big deal for the small - numbers of rows that the test suite manipulates. - """ - - global util - if util is None: - from sqlalchemy import util - - class frozendict(dict): - def __hash__(self): - return id(self) - - found = util.IdentitySet(result) - expected = set([frozendict(e) for e in expected]) - - for wrong in itertools.ifilterfalse(lambda o: type(o) == cls, found): - self.fail('Unexpected type "%s", expected "%s"' % ( - type(wrong).__name__, cls.__name__)) - - if len(found) != len(expected): - self.fail('Unexpected object count "%s", expected "%s"' % ( - len(found), len(expected))) - - NOVALUE = object() - def _compare_item(obj, spec): - for key, value in spec.iteritems(): - if isinstance(value, tuple): - try: - self.assert_unordered_result( - getattr(obj, key), value[0], *value[1]) - except AssertionError: - return False - else: - if getattr(obj, key, NOVALUE) != value: - return False - return True - - for expected_item in expected: - for found_item in found: - if _compare_item(found_item, expected_item): - found.remove(found_item) - break - else: - self.fail( - "Expected %s instance with attributes %s not found." % ( - cls.__name__, repr(expected_item))) - return True - - def assert_sql_execution(self, db, callable_, *rules): - from testlib import assertsql - assertsql.asserter.add_rules(rules) - try: - callable_() - assertsql.asserter.statement_complete() - finally: - assertsql.asserter.clear_rules() - - def assert_sql(self, db, callable_, list_, with_sequences=None): - from testlib import assertsql - - if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgres'): - rules = with_sequences - else: - rules = list_ - - newrules = [] - for rule in rules: - if isinstance(rule, dict): - newrule = assertsql.AllOf(*[ - assertsql.ExactSQL(k, v) for k, v in rule.iteritems() - ]) - else: - newrule = assertsql.ExactSQL(*rule) - newrules.append(newrule) - - self.assert_sql_execution(db, callable_, *newrules) - - def assert_sql_count(self, db, callable_, count): - from testlib import assertsql - self.assert_sql_execution(db, callable_, assertsql.CountStatements(count)) - - - -class TTestSuite(unittest.TestSuite): - """A TestSuite with once per TestCase setUpAll() and tearDownAll()""" - - def __init__(self, tests=()): - if len(tests) > 0 and isinstance(tests[0], TestBase): - self._initTest = tests[0] - else: - self._initTest = None - - for t in tests: - if isinstance(t, TestBase): - t._sa_first_test = True - break - for t in reversed(tests): - if isinstance(t, TestBase): - t._sa_last_test = True - break - unittest.TestSuite.__init__(self, tests) - - def run(self, result): - init = getattr(self, '_initTest', None) - if init is not None: - if (hasattr(init, '__whitelist__') and - config.db.name in init.__whitelist__): - pass - else: - if self.__should_skip_for(init): - return True - try: - resetwarnings() - init.setUpAll() - except: - # skip tests if global setup fails - ex = self.__exc_info() - for test in self._tests: - result.addError(test, ex) - return False - try: - resetwarnings() - for test in self._tests: - if result.shouldStop: - break - test(result) - return result - finally: - try: - resetwarnings() - if init is not None: - init.tearDownAll() - except: - result.addError(init, self.__exc_info()) - pass - - def __should_skip_for(self, cls): - if hasattr(cls, '__requires__'): - global requires - if requires is None: - from testing import requires - def test_suite(): return 'ok' - for requirement in cls.__requires__: - check = getattr(requires, requirement) - if check(test_suite)() != 'ok': - # The requirement will perform messaging. - return True - if (hasattr(cls, '__unsupported_on__') and - config.db.name in cls.__unsupported_on__): - print "'%s' unsupported on DB implementation '%s'" % ( - cls.__class__.__name__, config.db.name) - return True - if (getattr(cls, '__only_on__', None) not in (None,config.db.name)): - print "'%s' unsupported on DB implementation '%s'" % ( - cls.__class__.__name__, config.db.name) - return True - if (getattr(cls, '__skip_if__', False)): - for c in getattr(cls, '__skip_if__'): - if c(): - print "'%s' skipped by %s" % ( - cls.__class__.__name__, c.__name__) - return True - for rule in getattr(cls, '__excluded_on__', ()): - if _is_excluded(*rule): - print "'%s' unsupported on DB %s version %s" % ( - cls.__class__.__name__, config.db.name, - _server_version()) - return True - return False - - - def __exc_info(self): - """Return a version of sys.exc_info() with the traceback frame - minimised; usually the top level of the traceback frame is not - needed. - ripped off out of unittest module since its double __ - """ - exctype, excvalue, tb = sys.exc_info() - if sys.platform[:4] == 'java': ## tracebacks look different in Jython - return (exctype, excvalue, tb) - return (exctype, excvalue, tb) - -# monkeypatch -unittest.TestLoader.suiteClass = TTestSuite - - -class DevNullWriter(object): - def write(self, msg): - pass - def flush(self): - pass - -def runTests(suite): - verbose = config.options.verbose - quiet = config.options.quiet - orig_stdout = sys.stdout - - try: - if not verbose or quiet: - sys.stdout = DevNullWriter() - runner = unittest.TextTestRunner(verbosity = quiet and 1 or 2) - return runner.run(suite) - finally: - if not verbose or quiet: - sys.stdout = orig_stdout - -def main(suite=None): - if not suite: - if sys.argv[1:]: - suite =unittest.TestLoader().loadTestsFromNames( - sys.argv[1:], __import__('__main__')) - else: - suite = unittest.TestLoader().loadTestsFromModule( - __import__('__main__')) - - result = runTests(suite) - sys.exit(not result.wasSuccessful()) |
