summaryrefslogtreecommitdiff
path: root/test/testlib
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2009-06-10 21:18:24 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2009-06-10 21:18:24 +0000
commit45cec095b4904ba71425d2fe18c143982dd08f43 (patch)
treeaf5e540fdcbf1cb2a3337157d69d4b40be010fa8 /test/testlib
parent698a3c1ac665e7cd2ef8d5ad3ebf51b7fe6661f4 (diff)
downloadsqlalchemy-45cec095b4904ba71425d2fe18c143982dd08f43.tar.gz
- unit tests have been migrated from unittest to nose.
See README.unittests for information on how to run the tests. [ticket:970]
Diffstat (limited to 'test/testlib')
-rw-r--r--test/testlib/__init__.py38
-rw-r--r--test/testlib/assertsql.py283
-rw-r--r--test/testlib/compat.py19
-rw-r--r--test/testlib/config.py344
-rw-r--r--test/testlib/coverage.py1098
-rw-r--r--test/testlib/engines.py245
-rw-r--r--test/testlib/orm.py117
-rw-r--r--test/testlib/profiling.py207
-rw-r--r--test/testlib/requires.py127
-rw-r--r--test/testlib/sa_unittest.py787
-rw-r--r--test/testlib/schema.py79
-rw-r--r--test/testlib/testing.py919
12 files changed, 0 insertions, 4263 deletions
diff --git a/test/testlib/__init__.py b/test/testlib/__init__.py
deleted file mode 100644
index 5b8075ddb..000000000
--- a/test/testlib/__init__.py
+++ /dev/null
@@ -1,38 +0,0 @@
-"""Enhance unittest and instrument SQLAlchemy classes for testing.
-
-Load after sqlalchemy imports to use instrumented stand-ins like Table.
-"""
-
-import sys
-import testlib.config
-from testlib.schema import Table, Column
-import testlib.testing as testing
-from testlib.testing import \
- AssertsCompiledSQL, \
- AssertsExecutionResults, \
- ComparesTables, \
- TestBase, \
- rowset
-from testlib.orm import mapper
-import testlib.profiling as profiling
-import testlib.engines as engines
-import testlib.requires as requires
-from testlib.compat import _function_named
-
-
-__all__ = ('testing',
- 'mapper',
- 'Table', 'Column',
- 'rowset',
- 'TestBase', 'AssertsExecutionResults',
- 'AssertsCompiledSQL', 'ComparesTables',
- 'profiling', 'engines',
- '_function_named')
-
-
-testing.requires = requires
-
-sys.modules['testlib.sa'] = sa = testing.CompositeModule(
- 'testlib.sa', 'sqlalchemy', 'testlib.schema', orm=testing.CompositeModule(
- 'testlib.sa.orm', 'sqlalchemy.orm', 'testlib.orm'))
-sys.modules['testlib.sa.orm'] = sa.orm
diff --git a/test/testlib/assertsql.py b/test/testlib/assertsql.py
deleted file mode 100644
index dc2c6d40f..000000000
--- a/test/testlib/assertsql.py
+++ /dev/null
@@ -1,283 +0,0 @@
-
-from sqlalchemy.interfaces import ConnectionProxy
-from sqlalchemy.engine.default import DefaultDialect
-from sqlalchemy.engine.base import Connection
-from sqlalchemy import util
-import testing
-import re
-
-class AssertRule(object):
- def process_execute(self, clauseelement, *multiparams, **params):
- pass
-
- def process_cursor_execute(self, statement, parameters, context, executemany):
- pass
-
- def is_consumed(self):
- """Return True if this rule has been consumed, False if not.
-
- Should raise an AssertionError if this rule's condition has definitely failed.
-
- """
- raise NotImplementedError()
-
- def rule_passed(self):
- """Return True if the last test of this rule passed, False if failed, None if no test was applied."""
-
- raise NotImplementedError()
-
- def consume_final(self):
- """Return True if this rule has been consumed.
-
- Should raise an AssertionError if this rule's condition has not been consumed or has failed.
-
- """
-
- if self._result is None:
- assert False, "Rule has not been consumed"
-
- return self.is_consumed()
-
-class SQLMatchRule(AssertRule):
- def __init__(self):
- self._result = None
- self._errmsg = ""
-
- def rule_passed(self):
- return self._result
-
- def is_consumed(self):
- if self._result is None:
- return False
-
- assert self._result, self._errmsg
-
- return True
-
-class ExactSQL(SQLMatchRule):
- def __init__(self, sql, params=None):
- SQLMatchRule.__init__(self)
- self.sql = sql
- self.params = params
-
- def process_cursor_execute(self, statement, parameters, context, executemany):
- if not context:
- return
-
- _received_statement = _process_engine_statement(statement, context)
- _received_parameters = context.compiled_parameters
-
- # TODO: remove this step once all unit tests
- # are migrated, as ExactSQL should really be *exact* SQL
- sql = _process_assertion_statement(self.sql, context)
-
- equivalent = _received_statement == sql
- if self.params:
- if util.callable(self.params):
- params = self.params(context)
- else:
- params = self.params
-
- if not isinstance(params, list):
- params = [params]
- equivalent = equivalent and params == context.compiled_parameters
- else:
- params = {}
-
-
- self._result = equivalent
- if not self._result:
- self._errmsg = "Testing for exact statement %r exact params %r, " \
- "received %r with params %r" % (sql, params, _received_statement, _received_parameters)
-
-
-class RegexSQL(SQLMatchRule):
- def __init__(self, regex, params=None):
- SQLMatchRule.__init__(self)
- self.regex = re.compile(regex)
- self.orig_regex = regex
- self.params = params
-
- def process_cursor_execute(self, statement, parameters, context, executemany):
- if not context:
- return
-
- _received_statement = _process_engine_statement(statement, context)
- _received_parameters = context.compiled_parameters
-
- equivalent = bool(self.regex.match(_received_statement))
- if self.params:
- if util.callable(self.params):
- params = self.params(context)
- else:
- params = self.params
-
- if not isinstance(params, list):
- params = [params]
-
- # do a positive compare only
- for param, received in zip(params, _received_parameters):
- for k, v in param.iteritems():
- if k not in received or received[k] != v:
- equivalent = False
- break
- else:
- params = {}
-
- self._result = equivalent
- if not self._result:
- self._errmsg = "Testing for regex %r partial params %r, "\
- "received %r with params %r" % (self.orig_regex, params, _received_statement, _received_parameters)
-
-class CompiledSQL(SQLMatchRule):
- def __init__(self, statement, params):
- SQLMatchRule.__init__(self)
- self.statement = statement
- self.params = params
-
- def process_cursor_execute(self, statement, parameters, context, executemany):
- if not context:
- return
-
- _received_parameters = context.compiled_parameters
-
- # recompile from the context, using the default dialect
- compiled = context.compiled.statement.\
- compile(dialect=DefaultDialect(), column_keys=context.compiled.column_keys)
-
- _received_statement = re.sub(r'\n', '', str(compiled))
-
- equivalent = self.statement == _received_statement
- if self.params:
- if util.callable(self.params):
- params = self.params(context)
- else:
- params = self.params
-
- if not isinstance(params, list):
- params = [params]
-
- # do a positive compare only
- for param, received in zip(params, _received_parameters):
- for k, v in param.iteritems():
- if k not in received or received[k] != v:
- equivalent = False
- break
- else:
- params = {}
-
- self._result = equivalent
- if not self._result:
- self._errmsg = "Testing for compiled statement %r partial params %r, " \
- "received %r with params %r" % (self.statement, params, _received_statement, _received_parameters)
-
-
-class CountStatements(AssertRule):
- def __init__(self, count):
- self.count = count
- self._statement_count = 0
-
- def process_execute(self, clauseelement, *multiparams, **params):
- self._statement_count += 1
-
- def process_cursor_execute(self, statement, parameters, context, executemany):
- pass
-
- def is_consumed(self):
- return False
-
- def consume_final(self):
- assert self.count == self._statement_count, "desired statement count %d does not match %d" % (self.count, self._statement_count)
- return True
-
-class AllOf(AssertRule):
- def __init__(self, *rules):
- self.rules = set(rules)
-
- def process_execute(self, clauseelement, *multiparams, **params):
- for rule in self.rules:
- rule.process_execute(clauseelement, *multiparams, **params)
-
- def process_cursor_execute(self, statement, parameters, context, executemany):
- for rule in self.rules:
- rule.process_cursor_execute(statement, parameters, context, executemany)
-
- def is_consumed(self):
- if not self.rules:
- return True
-
- for rule in list(self.rules):
- if rule.rule_passed(): # a rule passed, move on
- self.rules.remove(rule)
- return len(self.rules) == 0
-
- assert False, "No assertion rules were satisfied for statement"
-
- def consume_final(self):
- return len(self.rules) == 0
-
-def _process_engine_statement(query, context):
- if context.engine.name == 'mssql' and query.endswith('; select scope_identity()'):
- query = query[:-25]
-
- query = re.sub(r'\n', '', query)
-
- return query
-
-def _process_assertion_statement(query, context):
- paramstyle = context.dialect.paramstyle
- if paramstyle == 'named':
- pass
- elif paramstyle =='pyformat':
- query = re.sub(r':([\w_]+)', r"%(\1)s", query)
- else:
- # positional params
- repl = None
- if paramstyle=='qmark':
- repl = "?"
- elif paramstyle=='format':
- repl = r"%s"
- elif paramstyle=='numeric':
- repl = None
- query = re.sub(r':([\w_]+)', repl, query)
-
- return query
-
-class SQLAssert(ConnectionProxy):
- rules = None
-
- def add_rules(self, rules):
- self.rules = list(rules)
-
- def statement_complete(self):
- for rule in self.rules:
- if not rule.consume_final():
- assert False, "All statements are complete, but pending assertion rules remain"
-
- def clear_rules(self):
- del self.rules
-
- def execute(self, conn, execute, clauseelement, *multiparams, **params):
- result = execute(clauseelement, *multiparams, **params)
-
- if self.rules is not None:
- if not self.rules:
- assert False, "All rules have been exhausted, but further statements remain"
- rule = self.rules[0]
- rule.process_execute(clauseelement, *multiparams, **params)
- if rule.is_consumed():
- self.rules.pop(0)
-
- return result
-
- def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
- result = execute(cursor, statement, parameters, context)
-
- if self.rules:
- rule = self.rules[0]
- rule.process_cursor_execute(statement, parameters, context, executemany)
-
- return result
-
-asserter = SQLAssert()
-
diff --git a/test/testlib/compat.py b/test/testlib/compat.py
deleted file mode 100644
index 73eb2d651..000000000
--- a/test/testlib/compat.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import types
-import __builtin__
-
-__all__ = '_function_named', 'callable'
-
-
-def _function_named(fn, newname):
- try:
- fn.__name__ = newname
- except:
- fn = types.FunctionType(fn.func_code, fn.func_globals, newname,
- fn.func_defaults, fn.func_closure)
- return fn
-
-try:
- callable = __builtin__.callable
-except NameError:
- def callable(fn): return hasattr(fn, '__call__')
-
diff --git a/test/testlib/config.py b/test/testlib/config.py
deleted file mode 100644
index cef4c6e1d..000000000
--- a/test/testlib/config.py
+++ /dev/null
@@ -1,344 +0,0 @@
-import optparse, os, sys, re, ConfigParser, StringIO, time, warnings
-logging, require = None, None
-
-
-__all__ = 'parser', 'configure', 'options',
-
-db = None
-db_label, db_url, db_opts = None, None, {}
-
-options = None
-file_config = None
-coverage_enabled = False
-
-base_config = """
-[db]
-sqlite=sqlite:///:memory:
-sqlite_file=sqlite:///querytest.db
-postgres=postgres://scott:tiger@127.0.0.1:5432/test
-mysql=mysql://scott:tiger@127.0.0.1:3306/test
-oracle=oracle://scott:tiger@127.0.0.1:1521
-oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
-mssql=mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test
-firebird=firebird://sysdba:masterkey@localhost//tmp/test.fdb
-maxdb=maxdb://MONA:RED@/maxdb1
-"""
-
-parser = optparse.OptionParser(usage = "usage: %prog [options] [tests...]")
-
-def configure():
- global options, config
- global getopts_options, file_config
-
- file_config = ConfigParser.ConfigParser()
- file_config.readfp(StringIO.StringIO(base_config))
- file_config.read(['test.cfg', os.path.expanduser('~/.satest.cfg')])
-
- # Opt parsing can fire immediate actions, like logging and coverage
- (options, args) = parser.parse_args()
- sys.argv[1:] = args
-
- # Lazy setup of other options (post coverage)
- for fn in post_configure:
- fn(options, file_config)
-
- return options, file_config
-
-def configure_defaults():
- global options, config
- global getopts_options, file_config
- global db
-
- file_config = ConfigParser.ConfigParser()
- file_config.readfp(StringIO.StringIO(base_config))
- file_config.read(['test.cfg', os.path.expanduser('~/.satest.cfg')])
- (options, args) = parser.parse_args([])
-
- # make error messages raised by decorators that depend on a default
- # database clearer.
- class _engine_bomb(object):
- def __getattr__(self, key):
- raise RuntimeError('No default engine available, testlib '
- 'was configured with defaults only.')
-
- db = _engine_bomb()
- import testlib.testing
- testlib.testing.db = db
-
- return options, file_config
-
-def _log(option, opt_str, value, parser):
- global logging
- if not logging:
- import logging
- logging.basicConfig()
-
- if opt_str.endswith('-info'):
- logging.getLogger(value).setLevel(logging.INFO)
- elif opt_str.endswith('-debug'):
- logging.getLogger(value).setLevel(logging.DEBUG)
-
-def _start_cumulative_coverage(option, opt_str, value, parser):
- _start_coverage(option, opt_str, value, parser, erase=False)
-
-def _start_coverage(option, opt_str, value, parser, erase=True):
- import sys, atexit, coverage
- true_out = sys.stdout
-
- global coverage_enabled
- coverage_enabled = True
-
- def _iter_covered_files(mod, recursive=True):
-
- if recursive:
- ff = os.walk
- else:
- ff = os.listdir
-
- for rec in ff(os.path.dirname(mod.__file__)):
- for x in rec[2]:
- if x.endswith('.py'):
- yield os.path.join(rec[0], x)
-
- def _stop():
- coverage.stop()
- true_out.write("\nPreparing coverage report...\n")
-
- from sqlalchemy import sql, orm, engine, \
- ext, databases, log
-
- import sqlalchemy
-
- for modset in [
- _iter_covered_files(sqlalchemy, recursive=False),
- _iter_covered_files(databases),
- _iter_covered_files(engine),
- _iter_covered_files(ext),
- _iter_covered_files(orm),
- ]:
- coverage.report(list(modset),
- show_missing=False, ignore_errors=False,
- file=true_out)
- atexit.register(_stop)
- if erase:
- coverage.erase()
- coverage.start()
-
-def _list_dbs(*args):
- print "Available --db options (use --dburi to override)"
- for macro in sorted(file_config.options('db')):
- print "%20s\t%s" % (macro, file_config.get('db', macro))
- sys.exit(0)
-
-def _server_side_cursors(options, opt_str, value, parser):
- db_opts['server_side_cursors'] = True
-
-def _engine_strategy(options, opt_str, value, parser):
- if value:
- db_opts['strategy'] = value
-
-opt = parser.add_option
-opt("--verbose", action="store_true", dest="verbose",
- help="enable stdout echoing/printing")
-opt("--quiet", action="store_true", dest="quiet", help="suppress output")
-opt("--log-info", action="callback", type="string", callback=_log,
- help="turn on info logging for <LOG> (multiple OK)")
-opt("--log-debug", action="callback", type="string", callback=_log,
- help="turn on debug logging for <LOG> (multiple OK)")
-opt("--require", action="append", dest="require", default=[],
- help="require a particular driver or module version (multiple OK)")
-opt("--db", action="store", dest="db", default="sqlite",
- help="Use prefab database uri")
-opt('--dbs', action='callback', callback=_list_dbs,
- help="List available prefab dbs")
-opt("--dburi", action="store", dest="dburi",
- help="Database uri (overrides --db)")
-opt("--dropfirst", action="store_true", dest="dropfirst",
- help="Drop all tables in the target database first (use with caution on Oracle, MS-SQL)")
-opt("--mockpool", action="store_true", dest="mockpool",
- help="Use mock pool (asserts only one connection used)")
-opt("--enginestrategy", action="callback", type="string",
- callback=_engine_strategy,
- help="Engine strategy (plain or threadlocal, defaults to plain)")
-opt("--reversetop", action="store_true", dest="reversetop", default=False,
- help="Reverse the collection ordering for topological sorts (helps "
- "reveal dependency issues)")
-opt("--unhashable", action="store_true", dest="unhashable", default=False,
- help="Disallow SQLAlchemy from performing a hash() on mapped test objects.")
-opt("--noncomparable", action="store_true", dest="noncomparable", default=False,
- help="Disallow SQLAlchemy from performing == on mapped test objects.")
-opt("--truthless", action="store_true", dest="truthless", default=False,
- help="Disallow SQLAlchemy from truth-evaluating mapped test objects.")
-opt("--serverside", action="callback", callback=_server_side_cursors,
- help="Turn on server side cursors for PG")
-opt("--mysql-engine", action="store", dest="mysql_engine", default=None,
- help="Use the specified MySQL storage engine for all tables, default is "
- "a db-default/InnoDB combo.")
-opt("--table-option", action="append", dest="tableopts", default=[],
- help="Add a dialect-specific table option, key=value")
-opt("--coverage", action="callback", callback=_start_coverage,
- help="Dump a full coverage report after running tests")
-opt("--cumulative-coverage", action="callback", callback=_start_cumulative_coverage,
- help="Like --coverage, but accumlate coverage into the current DB")
-opt("--profile", action="append", dest="profile_targets", default=[],
- help="Enable a named profile target (multiple OK.)")
-opt("--profile-sort", action="store", dest="profile_sort", default=None,
- help="Sort profile stats with this comma-separated sort order")
-opt("--profile-limit", type="int", action="store", dest="profile_limit",
- default=None,
- help="Limit function count in profile stats")
-
-class _ordered_map(object):
- def __init__(self):
- self._keys = list()
- self._data = dict()
-
- def __setitem__(self, key, value):
- if key not in self._keys:
- self._keys.append(key)
- self._data[key] = value
-
- def __iter__(self):
- for key in self._keys:
- yield self._data[key]
-
-# at one point in refactoring, modules were injecting into the config
-# process. this could probably just become a list now.
-post_configure = _ordered_map()
-
-def _engine_uri(options, file_config):
- global db_label, db_url
- db_label = 'sqlite'
- if options.dburi:
- db_url = options.dburi
- db_label = db_url[:db_url.index(':')]
- elif options.db:
- db_label = options.db
- db_url = None
-
- if db_url is None:
- if db_label not in file_config.options('db'):
- raise RuntimeError(
- "Unknown engine. Specify --dbs for known engines.")
- db_url = file_config.get('db', db_label)
-post_configure['engine_uri'] = _engine_uri
-
-def _require(options, file_config):
- if not(options.require or
- (file_config.has_section('require') and
- file_config.items('require'))):
- return
-
- try:
- import pkg_resources
- except ImportError:
- raise RuntimeError("setuptools is required for version requirements")
-
- cmdline = []
- for requirement in options.require:
- pkg_resources.require(requirement)
- cmdline.append(re.split('\s*(<!>=)', requirement, 1)[0])
-
- if file_config.has_section('require'):
- for label, requirement in file_config.items('require'):
- if not label == db_label or label.startswith('%s.' % db_label):
- continue
- seen = [c for c in cmdline if requirement.startswith(c)]
- if seen:
- continue
- pkg_resources.require(requirement)
-post_configure['require'] = _require
-
-def _engine_pool(options, file_config):
- if options.mockpool:
- from sqlalchemy import pool
- db_opts['poolclass'] = pool.AssertionPool
-post_configure['engine_pool'] = _engine_pool
-
-def _create_testing_engine(options, file_config):
- from testlib import engines, testing
- global db
- db = engines.testing_engine(db_url, db_opts)
- testing.db = db
-post_configure['create_engine'] = _create_testing_engine
-
-def _prep_testing_database(options, file_config):
- from testlib import engines
- from sqlalchemy import schema
-
- try:
- # also create alt schemas etc. here?
- if options.dropfirst:
- e = engines.utf8_engine()
- existing = e.table_names()
- if existing:
- if not options.quiet:
- print "Dropping existing tables in database: " + db_url
- try:
- print "Tables: %s" % ', '.join(existing)
- except:
- pass
- print "Abort within 5 seconds..."
- time.sleep(5)
- md = schema.MetaData(e, reflect=True)
- md.drop_all()
- e.dispose()
- except (KeyboardInterrupt, SystemExit):
- raise
- except Exception, e:
- if not options.quiet:
- warnings.warn(RuntimeWarning(
- "Error checking for existing tables in testing "
- "database: %s" % e))
-post_configure['prep_db'] = _prep_testing_database
-
-def _set_table_options(options, file_config):
- import testlib.schema
-
- table_options = testlib.schema.table_options
- for spec in options.tableopts:
- key, value = spec.split('=')
- table_options[key] = value
-
- if options.mysql_engine:
- table_options['mysql_engine'] = options.mysql_engine
-post_configure['table_options'] = _set_table_options
-
-def _reverse_topological(options, file_config):
- if options.reversetop:
- from sqlalchemy.orm import unitofwork
- from sqlalchemy import topological
- class RevQueueDepSort(topological.QueueDependencySorter):
- def __init__(self, tuples, allitems):
- self.tuples = list(tuples)
- self.allitems = list(allitems)
- self.tuples.reverse()
- self.allitems.reverse()
- topological.QueueDependencySorter = RevQueueDepSort
- unitofwork.DependencySorter = RevQueueDepSort
-post_configure['topological'] = _reverse_topological
-
-def _set_profile_targets(options, file_config):
- from testlib import profiling
-
- profile_config = profiling.profile_config
-
- for target in options.profile_targets:
- profile_config['targets'].add(target)
-
- if options.profile_sort:
- profile_config['sort'] = options.profile_sort.split(',')
-
- if options.profile_limit:
- profile_config['limit'] = options.profile_limit
-
- if options.quiet:
- profile_config['report'] = False
-
- # magic "all" target
- if 'all' in profiling.all_targets:
- targets = profile_config['targets']
- if 'all' in targets and len(targets) != 1:
- targets.clear()
- targets.add('all')
-post_configure['profile_targets'] = _set_profile_targets
diff --git a/test/testlib/coverage.py b/test/testlib/coverage.py
deleted file mode 100644
index fc0f2c236..000000000
--- a/test/testlib/coverage.py
+++ /dev/null
@@ -1,1098 +0,0 @@
-#!/usr/bin/python
-#
-# Perforce Defect Tracking Integration Project
-# <http://www.ravenbrook.com/project/p4dti/>
-#
-# COVERAGE.PY -- COVERAGE TESTING
-#
-# Gareth Rees, Ravenbrook Limited, 2001-12-04
-# Ned Batchelder, 2004-12-12
-# http://nedbatchelder.com/code/modules/coverage.html
-#
-#
-# 1. INTRODUCTION
-#
-# This module provides coverage testing for Python code.
-#
-# The intended readership is all Python developers.
-#
-# This document is not confidential.
-#
-# See [GDR 2001-12-04a] for the command-line interface, programmatic
-# interface and limitations. See [GDR 2001-12-04b] for requirements and
-# design.
-
-r"""\
-Usage:
-
-coverage.py -x [-p] MODULE.py [ARG1 ARG2 ...]
- Execute module, passing the given command-line arguments, collecting
- coverage data. With the -p option, write to a temporary file containing
- the machine name and process ID.
-
-coverage.py -e
- Erase collected coverage data.
-
-coverage.py -c
- Collect data from multiple coverage files (as created by -p option above)
- and store it into a single file representing the union of the coverage.
-
-coverage.py -r [-m] [-o dir1,dir2,...] FILE1 FILE2 ...
- Report on the statement coverage for the given files. With the -m
- option, show line numbers of the statements that weren't executed.
-
-coverage.py -a [-d dir] [-o dir1,dir2,...] FILE1 FILE2 ...
- Make annotated copies of the given files, marking statements that
- are executed with > and statements that are missed with !. With
- the -d option, make the copies in that directory. Without the -d
- option, make each copy in the same directory as the original.
-
--o dir,dir2,...
- Omit reporting or annotating files when their filename path starts with
- a directory listed in the omit list.
- e.g. python coverage.py -i -r -o c:\python23,lib\enthought\traits
-
-Coverage data is saved in the file .coverage by default. Set the
-COVERAGE_FILE environment variable to save it somewhere else."""
-
-__version__ = "2.75.20070722" # see detailed history at the end of this file.
-
-import compiler
-import compiler.visitor
-import glob
-import os
-import re
-import string
-import symbol
-import sys
-import threading
-import token
-import types
-from socket import gethostname
-
-
-# 2. IMPLEMENTATION
-#
-# This uses the "singleton" pattern.
-#
-# The word "morf" means a module object (from which the source file can
-# be deduced by suitable manipulation of the __file__ attribute) or a
-# filename.
-#
-# When we generate a coverage report we have to canonicalize every
-# filename in the coverage dictionary just in case it refers to the
-# module we are reporting on. It seems a shame to throw away this
-# information so the data in the coverage dictionary is transferred to
-# the 'cexecuted' dictionary under the canonical filenames.
-#
-# The coverage dictionary is called "c" and the trace function "t". The
-# reason for these short names is that Python looks up variables by name
-# at runtime and so execution time depends on the length of variables!
-# In the bottleneck of this application it's appropriate to abbreviate
-# names to increase speed.
-
-class StatementFindingAstVisitor(compiler.visitor.ASTVisitor):
- """ A visitor for a parsed Abstract Syntax Tree which finds executable
- statements.
- """
- def __init__(self, statements, excluded, suite_spots):
- compiler.visitor.ASTVisitor.__init__(self)
- self.statements = statements
- self.excluded = excluded
- self.suite_spots = suite_spots
- self.excluding_suite = 0
-
- def doRecursive(self, node):
- for n in node.getChildNodes():
- self.dispatch(n)
-
- visitStmt = visitModule = doRecursive
-
- def doCode(self, node):
- if hasattr(node, 'decorators') and node.decorators:
- self.dispatch(node.decorators)
- self.recordAndDispatch(node.code)
- else:
- self.doSuite(node, node.code)
-
- visitFunction = visitClass = doCode
-
- def getFirstLine(self, node):
- # Find the first line in the tree node.
- lineno = node.lineno
- for n in node.getChildNodes():
- f = self.getFirstLine(n)
- if lineno and f:
- lineno = min(lineno, f)
- else:
- lineno = lineno or f
- return lineno
-
- def getLastLine(self, node):
- # Find the first line in the tree node.
- lineno = node.lineno
- for n in node.getChildNodes():
- lineno = max(lineno, self.getLastLine(n))
- return lineno
-
- def doStatement(self, node):
- self.recordLine(self.getFirstLine(node))
-
- visitAssert = visitAssign = visitAssTuple = visitPrint = \
- visitPrintnl = visitRaise = visitSubscript = visitDecorators = \
- doStatement
-
- def visitPass(self, node):
- # Pass statements have weird interactions with docstrings. If this
- # pass statement is part of one of those pairs, claim that the statement
- # is on the later of the two lines.
- l = node.lineno
- if l:
- lines = self.suite_spots.get(l, [l,l])
- self.statements[lines[1]] = 1
-
- def visitDiscard(self, node):
- # Discard nodes are statements that execute an expression, but then
- # discard the results. This includes function calls, so we can't
- # ignore them all. But if the expression is a constant, the statement
- # won't be "executed", so don't count it now.
- if node.expr.__class__.__name__ != 'Const':
- self.doStatement(node)
-
- def recordNodeLine(self, node):
- # Stmt nodes often have None, but shouldn't claim the first line of
- # their children (because the first child might be an ignorable line
- # like "global a").
- if node.__class__.__name__ != 'Stmt':
- return self.recordLine(self.getFirstLine(node))
- else:
- return 0
-
- def recordLine(self, lineno):
- # Returns a bool, whether the line is included or excluded.
- if lineno:
- # Multi-line tests introducing suites have to get charged to their
- # keyword.
- if lineno in self.suite_spots:
- lineno = self.suite_spots[lineno][0]
- # If we're inside an excluded suite, record that this line was
- # excluded.
- if self.excluding_suite:
- self.excluded[lineno] = 1
- return 0
- # If this line is excluded, or suite_spots maps this line to
- # another line that is exlcuded, then we're excluded.
- elif self.excluded.has_key(lineno) or \
- self.suite_spots.has_key(lineno) and \
- self.excluded.has_key(self.suite_spots[lineno][1]):
- return 0
- # Otherwise, this is an executable line.
- else:
- self.statements[lineno] = 1
- return 1
- return 0
-
- default = recordNodeLine
-
- def recordAndDispatch(self, node):
- self.recordNodeLine(node)
- self.dispatch(node)
-
- def doSuite(self, intro, body, exclude=0):
- exsuite = self.excluding_suite
- if exclude or (intro and not self.recordNodeLine(intro)):
- self.excluding_suite = 1
- self.recordAndDispatch(body)
- self.excluding_suite = exsuite
-
- def doPlainWordSuite(self, prevsuite, suite):
- # Finding the exclude lines for else's is tricky, because they aren't
- # present in the compiler parse tree. Look at the previous suite,
- # and find its last line. If any line between there and the else's
- # first line are excluded, then we exclude the else.
- lastprev = self.getLastLine(prevsuite)
- firstelse = self.getFirstLine(suite)
- for l in range(lastprev+1, firstelse):
- if self.suite_spots.has_key(l):
- self.doSuite(None, suite, exclude=self.excluded.has_key(l))
- break
- else:
- self.doSuite(None, suite)
-
- def doElse(self, prevsuite, node):
- if node.else_:
- self.doPlainWordSuite(prevsuite, node.else_)
-
- def visitFor(self, node):
- self.doSuite(node, node.body)
- self.doElse(node.body, node)
-
- visitWhile = visitFor
-
- def visitIf(self, node):
- # The first test has to be handled separately from the rest.
- # The first test is credited to the line with the "if", but the others
- # are credited to the line with the test for the elif.
- self.doSuite(node, node.tests[0][1])
- for t, n in node.tests[1:]:
- self.doSuite(t, n)
- self.doElse(node.tests[-1][1], node)
-
- def visitTryExcept(self, node):
- self.doSuite(node, node.body)
- for i in range(len(node.handlers)):
- a, b, h = node.handlers[i]
- if not a:
- # It's a plain "except:". Find the previous suite.
- if i > 0:
- prev = node.handlers[i-1][2]
- else:
- prev = node.body
- self.doPlainWordSuite(prev, h)
- else:
- self.doSuite(a, h)
- self.doElse(node.handlers[-1][2], node)
-
- def visitTryFinally(self, node):
- self.doSuite(node, node.body)
- self.doPlainWordSuite(node.body, node.final)
-
- def visitGlobal(self, node):
- # "global" statements don't execute like others (they don't call the
- # trace function), so don't record their line numbers.
- pass
-
-the_coverage = None
-
-class CoverageException(Exception): pass
-
-class coverage:
- # Name of the cache file (unless environment variable is set).
- cache_default = ".coverage"
-
- # Environment variable naming the cache file.
- cache_env = "COVERAGE_FILE"
-
- # A dictionary with an entry for (Python source file name, line number
- # in that file) if that line has been executed.
- c = {}
-
- # A map from canonical Python source file name to a dictionary in
- # which there's an entry for each line number that has been
- # executed.
- cexecuted = {}
-
- # Cache of results of calling the analysis2() method, so that you can
- # specify both -r and -a without doing double work.
- analysis_cache = {}
-
- # Cache of results of calling the canonical_filename() method, to
- # avoid duplicating work.
- canonical_filename_cache = {}
-
- def __init__(self):
- global the_coverage
- if the_coverage:
- raise CoverageException, "Only one coverage object allowed."
- self.usecache = 1
- self.cache = None
- self.parallel_mode = False
- self.exclude_re = ''
- self.nesting = 0
- self.cstack = []
- self.xstack = []
- self.relative_dir = os.path.normcase(os.path.abspath(os.curdir)+os.sep)
- self.exclude('# *pragma[: ]*[nN][oO] *[cC][oO][vV][eE][rR]')
-
- # t(f, x, y). This method is passed to sys.settrace as a trace function.
- # See [van Rossum 2001-07-20b, 9.2] for an explanation of sys.settrace and
- # the arguments and return value of the trace function.
- # See [van Rossum 2001-07-20a, 3.2] for a description of frame and code
- # objects.
-
- def t(self, f, w, unused): #pragma: no cover
- if w == 'line':
- #print "Executing %s @ %d" % (f.f_code.co_filename, f.f_lineno)
- self.c[(f.f_code.co_filename, f.f_lineno)] = 1
- for c in self.cstack:
- c[(f.f_code.co_filename, f.f_lineno)] = 1
- return self.t
-
- def help(self, error=None): #pragma: no cover
- if error:
- print error
- print
- print __doc__
- sys.exit(1)
-
- def command_line(self, argv, help_fn=None):
- import getopt
- help_fn = help_fn or self.help
- settings = {}
- optmap = {
- '-a': 'annotate',
- '-c': 'collect',
- '-d:': 'directory=',
- '-e': 'erase',
- '-h': 'help',
- '-i': 'ignore-errors',
- '-m': 'show-missing',
- '-p': 'parallel-mode',
- '-r': 'report',
- '-x': 'execute',
- '-o:': 'omit=',
- }
- short_opts = string.join(map(lambda o: o[1:], optmap.keys()), '')
- long_opts = optmap.values()
- options, args = getopt.getopt(argv, short_opts, long_opts)
- for o, a in options:
- if optmap.has_key(o):
- settings[optmap[o]] = 1
- elif optmap.has_key(o + ':'):
- settings[optmap[o + ':']] = a
- elif o[2:] in long_opts:
- settings[o[2:]] = 1
- elif o[2:] + '=' in long_opts:
- settings[o[2:]+'='] = a
- else: #pragma: no cover
- pass # Can't get here, because getopt won't return anything unknown.
-
- if settings.get('help'):
- help_fn()
-
- for i in ['erase', 'execute']:
- for j in ['annotate', 'report', 'collect']:
- if settings.get(i) and settings.get(j):
- help_fn("You can't specify the '%s' and '%s' "
- "options at the same time." % (i, j))
-
- args_needed = (settings.get('execute')
- or settings.get('annotate')
- or settings.get('report'))
- action = (settings.get('erase')
- or settings.get('collect')
- or args_needed)
- if not action:
- help_fn("You must specify at least one of -e, -x, -c, -r, or -a.")
- if not args_needed and args:
- help_fn("Unexpected arguments: %s" % " ".join(args))
-
- self.parallel_mode = settings.get('parallel-mode')
- self.get_ready()
-
- if settings.get('erase'):
- self.erase()
- if settings.get('execute'):
- if not args:
- help_fn("Nothing to do.")
- sys.argv = args
- self.start()
- import __main__
- sys.path[0] = os.path.dirname(sys.argv[0])
- execfile(sys.argv[0], __main__.__dict__)
- if settings.get('collect'):
- self.collect()
- if not args:
- args = self.cexecuted.keys()
-
- ignore_errors = settings.get('ignore-errors')
- show_missing = settings.get('show-missing')
- directory = settings.get('directory=')
-
- omit = settings.get('omit=')
- if omit is not None:
- omit = omit.split(',')
- else:
- omit = []
-
- if settings.get('report'):
- self.report(args, show_missing, ignore_errors, omit_prefixes=omit)
- if settings.get('annotate'):
- self.annotate(args, directory, ignore_errors, omit_prefixes=omit)
-
- def use_cache(self, usecache, cache_file=None):
- self.usecache = usecache
- if cache_file and not self.cache:
- self.cache_default = cache_file
-
- def get_ready(self, parallel_mode=False):
- if self.usecache and not self.cache:
- self.cache = os.environ.get(self.cache_env, self.cache_default)
- if self.parallel_mode:
- self.cache += "." + gethostname() + "." + str(os.getpid())
- self.restore()
- self.analysis_cache = {}
-
- def start(self, parallel_mode=False):
- self.get_ready()
- if self.nesting == 0: #pragma: no cover
- sys.settrace(self.t)
- if hasattr(threading, 'settrace'):
- threading.settrace(self.t)
- self.nesting += 1
-
- def stop(self):
- self.nesting -= 1
- if self.nesting == 0: #pragma: no cover
- sys.settrace(None)
- if hasattr(threading, 'settrace'):
- threading.settrace(None)
-
- def erase(self):
- self.get_ready()
- self.c = {}
- self.analysis_cache = {}
- self.cexecuted = {}
- if self.cache and os.path.exists(self.cache):
- os.remove(self.cache)
-
- def exclude(self, re):
- if self.exclude_re:
- self.exclude_re += "|"
- self.exclude_re += "(" + re + ")"
-
- def begin_recursive(self):
- self.cstack.append(self.c)
- self.xstack.append(self.exclude_re)
-
- def end_recursive(self):
- self.c = self.cstack.pop()
- self.exclude_re = self.xstack.pop()
-
- # save(). Save coverage data to the coverage cache.
-
- def save(self):
- if self.usecache and self.cache:
- self.canonicalize_filenames()
- cache = open(self.cache, 'wb')
- import marshal
- marshal.dump(self.cexecuted, cache)
- cache.close()
-
- # restore(). Restore coverage data from the coverage cache (if it exists).
-
- def restore(self):
- self.c = {}
- self.cexecuted = {}
- assert self.usecache
- if os.path.exists(self.cache):
- self.cexecuted = self.restore_file(self.cache)
-
- def restore_file(self, file_name):
- try:
- cache = open(file_name, 'rb')
- import marshal
- cexecuted = marshal.load(cache)
- cache.close()
- if isinstance(cexecuted, types.DictType):
- return cexecuted
- else:
- return {}
- except:
- return {}
-
- # collect(). Collect data in multiple files produced by parallel mode
-
- def collect(self):
- cache_dir, local = os.path.split(self.cache)
- for f in os.listdir(cache_dir or '.'):
- if not f.startswith(local):
- continue
-
- full_path = os.path.join(cache_dir, f)
- cexecuted = self.restore_file(full_path)
- self.merge_data(cexecuted)
-
- def merge_data(self, new_data):
- for file_name, file_data in new_data.items():
- if self.cexecuted.has_key(file_name):
- self.merge_file_data(self.cexecuted[file_name], file_data)
- else:
- self.cexecuted[file_name] = file_data
-
- def merge_file_data(self, cache_data, new_data):
- for line_number in new_data.keys():
- if not cache_data.has_key(line_number):
- cache_data[line_number] = new_data[line_number]
-
- # canonical_filename(filename). Return a canonical filename for the
- # file (that is, an absolute path with no redundant components and
- # normalized case). See [GDR 2001-12-04b, 3.3].
-
- def canonical_filename(self, filename):
- if not self.canonical_filename_cache.has_key(filename):
- f = filename
- if os.path.isabs(f) and not os.path.exists(f):
- f = os.path.basename(f)
- if not os.path.isabs(f):
- for path in [os.curdir] + sys.path:
- g = os.path.join(path, f)
- if os.path.exists(g):
- f = g
- break
- cf = os.path.normcase(os.path.abspath(f))
- self.canonical_filename_cache[filename] = cf
- return self.canonical_filename_cache[filename]
-
- # canonicalize_filenames(). Copy results from "c" to "cexecuted",
- # canonicalizing filenames on the way. Clear the "c" map.
-
- def canonicalize_filenames(self):
- for filename, lineno in self.c.keys():
- if filename == '<string>':
- # Can't do anything useful with exec'd strings, so skip them.
- continue
- f = self.canonical_filename(filename)
- if not self.cexecuted.has_key(f):
- self.cexecuted[f] = {}
- self.cexecuted[f][lineno] = 1
- self.c = {}
-
- # morf_filename(morf). Return the filename for a module or file.
-
- def morf_filename(self, morf):
- if isinstance(morf, types.ModuleType):
- if not hasattr(morf, '__file__'):
- raise CoverageException, "Module has no __file__ attribute."
- f = morf.__file__
- else:
- f = morf
- return self.canonical_filename(f)
-
- # analyze_morf(morf). Analyze the module or filename passed as
- # the argument. If the source code can't be found, raise an error.
- # Otherwise, return a tuple of (1) the canonical filename of the
- # source code for the module, (2) a list of lines of statements
- # in the source code, (3) a list of lines of excluded statements,
- # and (4), a map of line numbers to multi-line line number ranges, for
- # statements that cross lines.
-
- def analyze_morf(self, morf):
- if self.analysis_cache.has_key(morf):
- return self.analysis_cache[morf]
- filename = self.morf_filename(morf)
- ext = os.path.splitext(filename)[1]
- if ext == '.pyc':
- if not os.path.exists(filename[0:-1]):
- raise CoverageException, ("No source for compiled code '%s'."
- % filename)
- filename = filename[0:-1]
- elif ext != '.py':
- raise CoverageException, "File '%s' not Python source." % filename
- source = open(filename, 'r')
- lines, excluded_lines, line_map = self.find_executable_statements(
- source.read(), exclude=self.exclude_re
- )
- source.close()
- result = filename, lines, excluded_lines, line_map
- self.analysis_cache[morf] = result
- return result
-
- def first_line_of_tree(self, tree):
- while True:
- if len(tree) == 3 and type(tree[2]) == type(1):
- return tree[2]
- tree = tree[1]
-
- def last_line_of_tree(self, tree):
- while True:
- if len(tree) == 3 and type(tree[2]) == type(1):
- return tree[2]
- tree = tree[-1]
-
- def find_docstring_pass_pair(self, tree, spots):
- for i in range(1, len(tree)):
- if self.is_string_constant(tree[i]) and self.is_pass_stmt(tree[i+1]):
- first_line = self.first_line_of_tree(tree[i])
- last_line = self.last_line_of_tree(tree[i+1])
- self.record_multiline(spots, first_line, last_line)
-
- def is_string_constant(self, tree):
- try:
- return tree[0] == symbol.stmt and tree[1][1][1][0] == symbol.expr_stmt
- except:
- return False
-
- def is_pass_stmt(self, tree):
- try:
- return tree[0] == symbol.stmt and tree[1][1][1][0] == symbol.pass_stmt
- except:
- return False
-
- def record_multiline(self, spots, i, j):
- for l in range(i, j+1):
- spots[l] = (i, j)
-
- def get_suite_spots(self, tree, spots):
- """ Analyze a parse tree to find suite introducers which span a number
- of lines.
- """
- for i in range(1, len(tree)):
- if type(tree[i]) == type(()):
- if tree[i][0] == symbol.suite:
- # Found a suite, look back for the colon and keyword.
- lineno_colon = lineno_word = None
- for j in range(i-1, 0, -1):
- if tree[j][0] == token.COLON:
- # Colons are never executed themselves: we want the
- # line number of the last token before the colon.
- lineno_colon = self.last_line_of_tree(tree[j-1])
- elif tree[j][0] == token.NAME:
- if tree[j][1] == 'elif':
- # Find the line number of the first non-terminal
- # after the keyword.
- t = tree[j+1]
- while t and token.ISNONTERMINAL(t[0]):
- t = t[1]
- if t:
- lineno_word = t[2]
- else:
- lineno_word = tree[j][2]
- break
- elif tree[j][0] == symbol.except_clause:
- # "except" clauses look like:
- # ('except_clause', ('NAME', 'except', lineno), ...)
- if tree[j][1][0] == token.NAME:
- lineno_word = tree[j][1][2]
- break
- if lineno_colon and lineno_word:
- # Found colon and keyword, mark all the lines
- # between the two with the two line numbers.
- self.record_multiline(spots, lineno_word, lineno_colon)
-
- # "pass" statements are tricky: different versions of Python
- # treat them differently, especially in the common case of a
- # function with a doc string and a single pass statement.
- self.find_docstring_pass_pair(tree[i], spots)
-
- elif tree[i][0] == symbol.simple_stmt:
- first_line = self.first_line_of_tree(tree[i])
- last_line = self.last_line_of_tree(tree[i])
- if first_line != last_line:
- self.record_multiline(spots, first_line, last_line)
- self.get_suite_spots(tree[i], spots)
-
- def find_executable_statements(self, text, exclude=None):
- # Find lines which match an exclusion pattern.
- excluded = {}
- suite_spots = {}
- if exclude:
- reExclude = re.compile(exclude)
- lines = text.split('\n')
- for i in range(len(lines)):
- if reExclude.search(lines[i]):
- excluded[i+1] = 1
-
- # Parse the code and analyze the parse tree to find out which statements
- # are multiline, and where suites begin and end.
- import parser
- tree = parser.suite(text+'\n\n').totuple(1)
- self.get_suite_spots(tree, suite_spots)
- #print "Suite spots:", suite_spots
-
- # Use the compiler module to parse the text and find the executable
- # statements. We add newlines to be impervious to final partial lines.
- statements = {}
- ast = compiler.parse(text+'\n\n')
- visitor = StatementFindingAstVisitor(statements, excluded, suite_spots)
- compiler.walk(ast, visitor, walker=visitor)
-
- lines = statements.keys()
- lines.sort()
- excluded_lines = excluded.keys()
- excluded_lines.sort()
- return lines, excluded_lines, suite_spots
-
- # format_lines(statements, lines). Format a list of line numbers
- # for printing by coalescing groups of lines as long as the lines
- # represent consecutive statements. This will coalesce even if
- # there are gaps between statements, so if statements =
- # [1,2,3,4,5,10,11,12,13,14] and lines = [1,2,5,10,11,13,14] then
- # format_lines will return "1-2, 5-11, 13-14".
-
- def format_lines(self, statements, lines):
- pairs = []
- i = 0
- j = 0
- start = None
- pairs = []
- while i < len(statements) and j < len(lines):
- if statements[i] == lines[j]:
- if start == None:
- start = lines[j]
- end = lines[j]
- j = j + 1
- elif start:
- pairs.append((start, end))
- start = None
- i = i + 1
- if start:
- pairs.append((start, end))
- def stringify(pair):
- start, end = pair
- if start == end:
- return "%d" % start
- else:
- return "%d-%d" % (start, end)
- ret = string.join(map(stringify, pairs), ", ")
- return ret
-
- # Backward compatibility with version 1.
- def analysis(self, morf):
- f, s, _, m, mf = self.analysis2(morf)
- return f, s, m, mf
-
- def analysis2(self, morf):
- filename, statements, excluded, line_map = self.analyze_morf(morf)
- self.canonicalize_filenames()
- if not self.cexecuted.has_key(filename):
- self.cexecuted[filename] = {}
- missing = []
- for line in statements:
- lines = line_map.get(line, [line, line])
- for l in range(lines[0], lines[1]+1):
- if self.cexecuted[filename].has_key(l):
- break
- else:
- missing.append(line)
- return (filename, statements, excluded, missing,
- self.format_lines(statements, missing))
-
- def relative_filename(self, filename):
- """ Convert filename to relative filename from self.relative_dir.
- """
- return filename.replace(self.relative_dir, "")
-
- def morf_name(self, morf):
- """ Return the name of morf as used in report.
- """
- if isinstance(morf, types.ModuleType):
- return morf.__name__
- else:
- return self.relative_filename(os.path.splitext(morf)[0])
-
- def filter_by_prefix(self, morfs, omit_prefixes):
- """ Return list of morfs where the morf name does not begin
- with any one of the omit_prefixes.
- """
- filtered_morfs = []
- for morf in morfs:
- for prefix in omit_prefixes:
- if self.morf_name(morf).startswith(prefix):
- break
- else:
- filtered_morfs.append(morf)
-
- return filtered_morfs
-
- def morf_name_compare(self, x, y):
- return cmp(self.morf_name(x), self.morf_name(y))
-
- def report(self, morfs, show_missing=1, ignore_errors=0, file=None, omit_prefixes=[]):
- if not isinstance(morfs, types.ListType):
- morfs = [morfs]
- # On windows, the shell doesn't expand wildcards. Do it here.
- globbed = []
- for morf in morfs:
- if isinstance(morf, basestring):
- globbed.extend(glob.glob(morf))
- else:
- globbed.append(morf)
- morfs = globbed
-
- morfs = self.filter_by_prefix(morfs, omit_prefixes)
- morfs.sort(self.morf_name_compare)
-
- max_name = max([5,] + map(len, map(self.morf_name, morfs)))
- fmt_name = "%%- %ds " % max_name
- fmt_err = fmt_name + "%s: %s"
- header = fmt_name % "Name" + " Stmts Exec Cover"
- fmt_coverage = fmt_name + "% 6d % 6d % 5d%%"
- if show_missing:
- header = header + " Missing"
- fmt_coverage = fmt_coverage + " %s"
- if not file:
- file = sys.stdout
- print >>file, header
- print >>file, "-" * len(header)
- total_statements = 0
- total_executed = 0
- for morf in morfs:
- name = self.morf_name(morf)
- try:
- _, statements, _, missing, readable = self.analysis2(morf)
- n = len(statements)
- m = n - len(missing)
- if n > 0:
- pc = 100.0 * m / n
- else:
- pc = 100.0
- args = (name, n, m, pc)
- if show_missing:
- args = args + (readable,)
- print >>file, fmt_coverage % args
- total_statements = total_statements + n
- total_executed = total_executed + m
- except KeyboardInterrupt: #pragma: no cover
- raise
- except:
- if not ignore_errors:
- typ, msg = sys.exc_info()[0:2]
- print >>file, fmt_err % (name, typ, msg)
- if len(morfs) > 1:
- print >>file, "-" * len(header)
- if total_statements > 0:
- pc = 100.0 * total_executed / total_statements
- else:
- pc = 100.0
- args = ("TOTAL", total_statements, total_executed, pc)
- if show_missing:
- args = args + ("",)
- print >>file, fmt_coverage % args
-
- # annotate(morfs, ignore_errors).
-
- blank_re = re.compile(r"\s*(#|$)")
- else_re = re.compile(r"\s*else\s*:\s*(#|$)")
-
- def annotate(self, morfs, directory=None, ignore_errors=0, omit_prefixes=[]):
- morfs = self.filter_by_prefix(morfs, omit_prefixes)
- for morf in morfs:
- try:
- filename, statements, excluded, missing, _ = self.analysis2(morf)
- self.annotate_file(filename, statements, excluded, missing, directory)
- except KeyboardInterrupt:
- raise
- except:
- if not ignore_errors:
- raise
-
- def annotate_file(self, filename, statements, excluded, missing, directory=None):
- source = open(filename, 'r')
- if directory:
- dest_file = os.path.join(directory,
- os.path.basename(filename)
- + ',cover')
- else:
- dest_file = filename + ',cover'
- dest = open(dest_file, 'w')
- lineno = 0
- i = 0
- j = 0
- covered = 1
- while 1:
- line = source.readline()
- if line == '':
- break
- lineno = lineno + 1
- while i < len(statements) and statements[i] < lineno:
- i = i + 1
- while j < len(missing) and missing[j] < lineno:
- j = j + 1
- if i < len(statements) and statements[i] == lineno:
- covered = j >= len(missing) or missing[j] > lineno
- if self.blank_re.match(line):
- dest.write(' ')
- elif self.else_re.match(line):
- # Special logic for lines containing only 'else:'.
- # See [GDR 2001-12-04b, 3.2].
- if i >= len(statements) and j >= len(missing):
- dest.write('! ')
- elif i >= len(statements) or j >= len(missing):
- dest.write('> ')
- elif statements[i] == missing[j]:
- dest.write('! ')
- else:
- dest.write('> ')
- elif lineno in excluded:
- dest.write('- ')
- elif covered:
- dest.write('> ')
- else:
- dest.write('! ')
- dest.write(line)
- source.close()
- dest.close()
-
-# Singleton object.
-the_coverage = coverage()
-
-# Module functions call methods in the singleton object.
-def use_cache(*args, **kw):
- return the_coverage.use_cache(*args, **kw)
-
-def start(*args, **kw):
- return the_coverage.start(*args, **kw)
-
-def stop(*args, **kw):
- return the_coverage.stop(*args, **kw)
-
-def erase(*args, **kw):
- return the_coverage.erase(*args, **kw)
-
-def begin_recursive(*args, **kw):
- return the_coverage.begin_recursive(*args, **kw)
-
-def end_recursive(*args, **kw):
- return the_coverage.end_recursive(*args, **kw)
-
-def exclude(*args, **kw):
- return the_coverage.exclude(*args, **kw)
-
-def analysis(*args, **kw):
- return the_coverage.analysis(*args, **kw)
-
-def analysis2(*args, **kw):
- return the_coverage.analysis2(*args, **kw)
-
-def report(*args, **kw):
- return the_coverage.report(*args, **kw)
-
-def annotate(*args, **kw):
- return the_coverage.annotate(*args, **kw)
-
-def annotate_file(*args, **kw):
- return the_coverage.annotate_file(*args, **kw)
-
-# Save coverage data when Python exits. (The atexit module wasn't
-# introduced until Python 2.0, so use sys.exitfunc when it's not
-# available.)
-try:
- import atexit
- atexit.register(the_coverage.save)
-except ImportError:
- sys.exitfunc = the_coverage.save
-
-# Command-line interface.
-if __name__ == '__main__':
- the_coverage.command_line(sys.argv[1:])
-
-
-# A. REFERENCES
-#
-# [GDR 2001-12-04a] "Statement coverage for Python"; Gareth Rees;
-# Ravenbrook Limited; 2001-12-04;
-# <http://www.nedbatchelder.com/code/modules/rees-coverage.html>.
-#
-# [GDR 2001-12-04b] "Statement coverage for Python: design and
-# analysis"; Gareth Rees; Ravenbrook Limited; 2001-12-04;
-# <http://www.nedbatchelder.com/code/modules/rees-design.html>.
-#
-# [van Rossum 2001-07-20a] "Python Reference Manual (releae 2.1.1)";
-# Guide van Rossum; 2001-07-20;
-# <http://www.python.org/doc/2.1.1/ref/ref.html>.
-#
-# [van Rossum 2001-07-20b] "Python Library Reference"; Guido van Rossum;
-# 2001-07-20; <http://www.python.org/doc/2.1.1/lib/lib.html>.
-#
-#
-# B. DOCUMENT HISTORY
-#
-# 2001-12-04 GDR Created.
-#
-# 2001-12-06 GDR Added command-line interface and source code
-# annotation.
-#
-# 2001-12-09 GDR Moved design and interface to separate documents.
-#
-# 2001-12-10 GDR Open cache file as binary on Windows. Allow
-# simultaneous -e and -x, or -a and -r.
-#
-# 2001-12-12 GDR Added command-line help. Cache analysis so that it
-# only needs to be done once when you specify -a and -r.
-#
-# 2001-12-13 GDR Improved speed while recording. Portable between
-# Python 1.5.2 and 2.1.1.
-#
-# 2002-01-03 GDR Module-level functions work correctly.
-#
-# 2002-01-07 GDR Update sys.path when running a file with the -x option,
-# so that it matches the value the program would get if it were run on
-# its own.
-#
-# 2004-12-12 NMB Significant code changes.
-# - Finding executable statements has been rewritten so that docstrings and
-# other quirks of Python execution aren't mistakenly identified as missing
-# lines.
-# - Lines can be excluded from consideration, even entire suites of lines.
-# - The filesystem cache of covered lines can be disabled programmatically.
-# - Modernized the code.
-#
-# 2004-12-14 NMB Minor tweaks. Return 'analysis' to its original behavior
-# and add 'analysis2'. Add a global for 'annotate', and factor it, adding
-# 'annotate_file'.
-#
-# 2004-12-31 NMB Allow for keyword arguments in the module global functions.
-# Thanks, Allen.
-#
-# 2005-12-02 NMB Call threading.settrace so that all threads are measured.
-# Thanks Martin Fuzzey. Add a file argument to report so that reports can be
-# captured to a different destination.
-#
-# 2005-12-03 NMB coverage.py can now measure itself.
-#
-# 2005-12-04 NMB Adapted Greg Rogers' patch for using relative filenames,
-# and sorting and omitting files to report on.
-#
-# 2006-07-23 NMB Applied Joseph Tate's patch for function decorators.
-#
-# 2006-08-21 NMB Applied Sigve Tjora and Mark van der Wal's fixes for argument
-# handling.
-#
-# 2006-08-22 NMB Applied Geoff Bache's parallel mode patch.
-#
-# 2006-08-23 NMB Refactorings to improve testability. Fixes to command-line
-# logic for parallel mode and collect.
-#
-# 2006-08-25 NMB "#pragma: nocover" is excluded by default.
-#
-# 2006-09-10 NMB Properly ignore docstrings and other constant expressions that
-# appear in the middle of a function, a problem reported by Tim Leslie.
-# Minor changes to avoid lint warnings.
-#
-# 2006-09-17 NMB coverage.erase() shouldn't clobber the exclude regex.
-# Change how parallel mode is invoked, and fix erase() so that it erases the
-# cache when called programmatically.
-#
-# 2007-07-21 NMB In reports, ignore code executed from strings, since we can't
-# do anything useful with it anyway.
-# Better file handling on Linux, thanks Guillaume Chazarain.
-# Better shell support on Windows, thanks Noel O'Boyle.
-# Python 2.2 support maintained, thanks Catherine Proulx.
-#
-# 2007-07-22 NMB Python 2.5 now fully supported. The method of dealing with
-# multi-line statements is now less sensitive to the exact line that Python
-# reports during execution. Pass statements are handled specially so that their
-# disappearance during execution won't throw off the measurement.
-
-# C. COPYRIGHT AND LICENCE
-#
-# Copyright 2001 Gareth Rees. All rights reserved.
-# Copyright 2004-2007 Ned Batchelder. All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# 1. Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-#
-# 2. Redistributions in binary form must reproduce the above copyright
-# notice, this list of conditions and the following disclaimer in the
-# documentation and/or other materials provided with the
-# distribution.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# HOLDERS AND CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
-# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
-# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
-# OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
-# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
-# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
-# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
-# DAMAGE.
-#
-# $Id: coverage.py 67 2007-07-21 19:51:07Z nedbat $
diff --git a/test/testlib/engines.py b/test/testlib/engines.py
deleted file mode 100644
index 4068f43d0..000000000
--- a/test/testlib/engines.py
+++ /dev/null
@@ -1,245 +0,0 @@
-import sys, types, weakref
-from collections import deque
-from testlib import config
-from testlib.compat import _function_named, callable
-
-class ConnectionKiller(object):
- def __init__(self):
- self.proxy_refs = weakref.WeakKeyDictionary()
-
- def checkout(self, dbapi_con, con_record, con_proxy):
- self.proxy_refs[con_proxy] = True
-
- def _apply_all(self, methods):
- for rec in self.proxy_refs:
- if rec is not None and rec.is_valid:
- try:
- for name in methods:
- if callable(name):
- name(rec)
- else:
- getattr(rec, name)()
- except (SystemExit, KeyboardInterrupt):
- raise
- except Exception, e:
- # fixme
- sys.stderr.write("\n" + str(e) + "\n")
-
- def rollback_all(self):
- self._apply_all(('rollback',))
-
- def close_all(self):
- self._apply_all(('rollback', 'close'))
-
- def assert_all_closed(self):
- for rec in self.proxy_refs:
- if rec.is_valid:
- assert False
-
-testing_reaper = ConnectionKiller()
-
-def assert_conns_closed(fn):
- def decorated(*args, **kw):
- try:
- fn(*args, **kw)
- finally:
- testing_reaper.assert_all_closed()
- return _function_named(decorated, fn.__name__)
-
-def rollback_open_connections(fn):
- """Decorator that rolls back all open connections after fn execution."""
-
- def decorated(*args, **kw):
- try:
- fn(*args, **kw)
- finally:
- testing_reaper.rollback_all()
- return _function_named(decorated, fn.__name__)
-
-def close_open_connections(fn):
- """Decorator that closes all connections after fn execution."""
-
- def decorated(*args, **kw):
- try:
- fn(*args, **kw)
- finally:
- testing_reaper.close_all()
- return _function_named(decorated, fn.__name__)
-
-def all_dialects():
- import sqlalchemy.databases as d
- for name in d.__all__:
- mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name)
- yield mod.dialect()
-
-class ReconnectFixture(object):
- def __init__(self, dbapi):
- self.dbapi = dbapi
- self.connections = []
-
- def __getattr__(self, key):
- return getattr(self.dbapi, key)
-
- def connect(self, *args, **kwargs):
- conn = self.dbapi.connect(*args, **kwargs)
- self.connections.append(conn)
- return conn
-
- def shutdown(self):
- for c in list(self.connections):
- c.close()
- self.connections = []
-
-def reconnecting_engine(url=None, options=None):
- url = url or config.db_url
- dbapi = config.db.dialect.dbapi
- if not options:
- options = {}
- options['module'] = ReconnectFixture(dbapi)
- engine = testing_engine(url, options)
- engine.test_shutdown = engine.dialect.dbapi.shutdown
- return engine
-
-def testing_engine(url=None, options=None):
- """Produce an engine configured by --options with optional overrides."""
-
- from sqlalchemy import create_engine
- from testlib.assertsql import asserter
-
- url = url or config.db_url
- options = options or config.db_opts
-
- options.setdefault('proxy', asserter)
-
- listeners = options.setdefault('listeners', [])
- listeners.append(testing_reaper)
-
- engine = create_engine(url, **options)
-
- return engine
-
-def utf8_engine(url=None, options=None):
- """Hook for dialects or drivers that don't handle utf8 by default."""
-
- from sqlalchemy.engine import url as engine_url
-
- if config.db.name == 'mysql':
- dbapi_ver = config.db.dialect.dbapi.version_info
- if (dbapi_ver < (1, 2, 1) or
- dbapi_ver in ((1, 2, 1, 'gamma', 1), (1, 2, 1, 'gamma', 2),
- (1, 2, 1, 'gamma', 3), (1, 2, 1, 'gamma', 5))):
- raise RuntimeError('Character set support unavailable with this '
- 'driver version: %s' % repr(dbapi_ver))
- else:
- url = url or config.db_url
- url = engine_url.make_url(url)
- url.query['charset'] = 'utf8'
- url.query['use_unicode'] = '0'
- url = str(url)
-
- return testing_engine(url, options)
-
-def mock_engine(db=None):
- """Provides a mocking engine based on the current testing.db."""
-
- from sqlalchemy import create_engine
-
- dbi = db or config.db
- buffer = []
- def executor(sql, *a, **kw):
- buffer.append(sql)
- engine = create_engine(dbi.name + '://',
- strategy='mock', executor=executor)
- assert not hasattr(engine, 'mock')
- engine.mock = buffer
- return engine
-
-class ReplayableSession(object):
- """A simple record/playback tool.
-
- This is *not* a mock testing class. It only records a session for later
- playback and makes no assertions on call consistency whatsoever. It's
- unlikely to be suitable for anything other than DB-API recording.
-
- """
-
- Callable = object()
- NoAttribute = object()
- Natives = set([getattr(types, t)
- for t in dir(types) if not t.startswith('_')]). \
- difference([getattr(types, t)
- for t in ('FunctionType', 'BuiltinFunctionType',
- 'MethodType', 'BuiltinMethodType',
- 'LambdaType', 'UnboundMethodType',)])
- def __init__(self):
- self.buffer = deque()
-
- def recorder(self, base):
- return self.Recorder(self.buffer, base)
-
- def player(self):
- return self.Player(self.buffer)
-
- class Recorder(object):
- def __init__(self, buffer, subject):
- self._buffer = buffer
- self._subject = subject
-
- def __call__(self, *args, **kw):
- subject, buffer = [object.__getattribute__(self, x)
- for x in ('_subject', '_buffer')]
-
- result = subject(*args, **kw)
- if type(result) not in ReplayableSession.Natives:
- buffer.append(ReplayableSession.Callable)
- return type(self)(buffer, result)
- else:
- buffer.append(result)
- return result
-
- def __getattribute__(self, key):
- try:
- return object.__getattribute__(self, key)
- except AttributeError:
- pass
-
- subject, buffer = [object.__getattribute__(self, x)
- for x in ('_subject', '_buffer')]
- try:
- result = type(subject).__getattribute__(subject, key)
- except AttributeError:
- buffer.append(ReplayableSession.NoAttribute)
- raise
- else:
- if type(result) not in ReplayableSession.Natives:
- buffer.append(ReplayableSession.Callable)
- return type(self)(buffer, result)
- else:
- buffer.append(result)
- return result
-
- class Player(object):
- def __init__(self, buffer):
- self._buffer = buffer
-
- def __call__(self, *args, **kw):
- buffer = object.__getattribute__(self, '_buffer')
- result = buffer.popleft()
- if result is ReplayableSession.Callable:
- return self
- else:
- return result
-
- def __getattribute__(self, key):
- try:
- return object.__getattribute__(self, key)
- except AttributeError:
- pass
- buffer = object.__getattribute__(self, '_buffer')
- result = buffer.popleft()
- if result is ReplayableSession.Callable:
- return self
- elif result is ReplayableSession.NoAttribute:
- raise AttributeError(key)
- else:
- return result
diff --git a/test/testlib/orm.py b/test/testlib/orm.py
deleted file mode 100644
index 22d624601..000000000
--- a/test/testlib/orm.py
+++ /dev/null
@@ -1,117 +0,0 @@
-import inspect, re
-from testlib import config, testing
-
-sa = None
-orm = None
-
-__all__ = 'mapper',
-
-
-_whitespace = re.compile(r'^(\s+)')
-
-def _find_pragma(lines, current):
- m = _whitespace.match(lines[current])
- basis = m and m.group() or ''
-
- for line in reversed(lines[0:current]):
- if 'testlib.pragma' in line:
- return line
- m = _whitespace.match(line)
- indent = m and m.group() or ''
-
- # simplistic detection:
-
- # >> # testlib.pragma foo
- # >> center_line()
- if indent == basis:
- break
- # >> # testlib.pragma foo
- # >> if fleem:
- # >> center_line()
- if line.endswith(':'):
- break
- return None
-
-def _make_blocker(method_name, fallback):
- """Creates tripwired variant of a method, raising when called.
-
- To excempt an invocation from blockage, there are two options.
-
- 1) add a pragma in a comment::
-
- # testlib.pragma exempt:methodname
- offending_line()
-
- 2) add a magic cookie to the function's namespace::
- __sa_baremethodname_exempt__ = True
- ...
- offending_line()
- another_offending_lines()
-
- The second is useful for testing and development.
- """
-
- if method_name.startswith('__') and method_name.endswith('__'):
- frame_marker = '__sa_%s_exempt__' % method_name[2:-2]
- else:
- frame_marker = '__sa_%s_exempt__' % method_name
- pragma_marker = 'exempt:' + method_name
-
- def method(self, *args, **kw):
- frame_r = None
- try:
- frame = inspect.stack()[1][0]
- frame_r = inspect.getframeinfo(frame, 9)
-
- module = frame.f_globals.get('__name__', '')
-
- type_ = type(self)
-
- pragma = _find_pragma(*frame_r[3:5])
-
- exempt = (
- (not module.startswith('sqlalchemy')) or
- (pragma and pragma_marker in pragma) or
- (frame_marker in frame.f_locals) or
- ('self' in frame.f_locals and
- getattr(frame.f_locals['self'], frame_marker, False)))
-
- if exempt:
- supermeth = getattr(super(type_, self), method_name, None)
- if (supermeth is None or
- getattr(supermeth, 'im_func', None) is method):
- return fallback(self, *args, **kw)
- else:
- return supermeth(*args, **kw)
- else:
- raise AssertionError(
- "%s.%s called in %s, line %s in %s" % (
- type_.__name__, method_name, module, frame_r[1], frame_r[2]))
- finally:
- del frame
- method.__name__ = method_name
- return method
-
-def mapper(type_, *args, **kw):
- global orm
- if orm is None:
- from sqlalchemy import orm
-
- forbidden = [
- ('__hash__', 'unhashable', lambda s: id(s)),
- ('__eq__', 'noncomparable', lambda s, o: s is o),
- ('__ne__', 'noncomparable', lambda s, o: s is not o),
- ('__cmp__', 'noncomparable', lambda s, o: object.__cmp__(s, o)),
- ('__le__', 'noncomparable', lambda s, o: object.__le__(s, o)),
- ('__lt__', 'noncomparable', lambda s, o: object.__lt__(s, o)),
- ('__ge__', 'noncomparable', lambda s, o: object.__ge__(s, o)),
- ('__gt__', 'noncomparable', lambda s, o: object.__gt__(s, o)),
- ('__nonzero__', 'truthless', lambda s: 1), ]
-
- if isinstance(type_, type) and type_.__bases__ == (object,):
- for method_name, option, fallback in forbidden:
- if (getattr(config.options, option, False) and
- method_name not in type_.__dict__):
- setattr(type_, method_name, _make_blocker(method_name, fallback))
-
- return orm.mapper(type_, *args, **kw)
diff --git a/test/testlib/profiling.py b/test/testlib/profiling.py
deleted file mode 100644
index 89db33011..000000000
--- a/test/testlib/profiling.py
+++ /dev/null
@@ -1,207 +0,0 @@
-"""Profiling support for unit and performance tests."""
-
-import os, sys
-from testlib.compat import _function_named
-import testlib.config
-
-__all__ = 'profiled', 'function_call_count', 'conditional_call_count'
-
-all_targets = set()
-profile_config = { 'targets': set(),
- 'report': True,
- 'sort': ('time', 'calls'),
- 'limit': None }
-profiler = None
-
-def profiled(target=None, **target_opts):
- """Optional function profiling.
-
- @profiled('label')
- or
- @profiled('label', report=True, sort=('calls',), limit=20)
-
- Enables profiling for a function when 'label' is targetted for
- profiling. Report options can be supplied, and override the global
- configuration and command-line options.
- """
-
- # manual or automatic namespacing by module would remove conflict issues
- if target is None:
- target = 'anonymous_target'
- elif target in all_targets:
- print "Warning: redefining profile target '%s'" % target
- all_targets.add(target)
-
- filename = "%s.prof" % target
-
- def decorator(fn):
- def profiled(*args, **kw):
- if (target not in profile_config['targets'] and
- not target_opts.get('always', None)):
- return fn(*args, **kw)
-
- elapsed, load_stats, result = _profile(
- filename, fn, *args, **kw)
-
- if not testlib.config.options.quiet:
- print "Profiled target '%s', wall time: %.2f seconds" % (
- target, elapsed)
-
- report = target_opts.get('report', profile_config['report'])
- if report and testlib.config.options.verbose:
- sort_ = target_opts.get('sort', profile_config['sort'])
- limit = target_opts.get('limit', profile_config['limit'])
- print "Profile report for target '%s' (%s)" % (
- target, filename)
-
- stats = load_stats()
- stats.sort_stats(*sort_)
- if limit:
- stats.print_stats(limit)
- else:
- stats.print_stats()
- #stats.print_callers()
- os.unlink(filename)
- return result
- return _function_named(profiled, fn.__name__)
- return decorator
-
-def function_call_count(count=None, versions={}, variance=0.05):
- """Assert a target for a test case's function call count.
-
- count
- Optional, general target function call count.
-
- versions
- Optional, a dictionary of Python version strings to counts,
- for example::
-
- { '2.5.1': 110,
- '2.5': 100,
- '2.4': 150 }
-
- The best match for the current running python will be used.
- If none match, 'count' will be used as the fallback.
-
- variance
- An +/- deviation percentage, defaults to 5%.
- """
-
- # this could easily dump the profile report if --verbose is in effect
-
- version_info = list(sys.version_info)
- py_version = '.'.join([str(v) for v in sys.version_info])
-
- while version_info:
- version = '.'.join([str(v) for v in version_info])
- if version in versions:
- count = versions[version]
- break
- version_info.pop()
-
- if count is None:
- return lambda fn: fn
-
- def decorator(fn):
- def counted(*args, **kw):
- try:
- filename = "%s.prof" % fn.__name__
-
- elapsed, stat_loader, result = _profile(
- filename, fn, *args, **kw)
-
- stats = stat_loader()
- calls = stats.total_calls
-
- if testlib.config.options.verbose:
- stats.sort_stats('calls', 'cumulative')
- stats.print_stats()
- #stats.print_callers()
- deviance = int(count * variance)
- if (calls < (count - deviance) or
- calls > (count + deviance)):
- raise AssertionError(
- "Function call count %s not within %s%% "
- "of expected %s. (Python version %s)" % (
- calls, (variance * 100), count, py_version))
-
- return result
- finally:
- if os.path.exists(filename):
- os.unlink(filename)
- return _function_named(counted, fn.__name__)
- return decorator
-
-def conditional_call_count(discriminator, categories):
- """Apply a function call count conditionally at runtime.
-
- Takes two arguments, a callable that returns a key value, and a dict
- mapping key values to a tuple of arguments to function_call_count.
-
- The callable is not evaluated until the decorated function is actually
- invoked. If the `discriminator` returns a key not present in the
- `categories` dictionary, no call count assertion is applied.
-
- Useful for integration tests, where running a named test in isolation may
- have a function count penalty not seen in the full suite, due to lazy
- initialization in the DB-API, SA, etc.
- """
-
- def decorator(fn):
- def at_runtime(*args, **kw):
- criteria = categories.get(discriminator(), None)
- if criteria is None:
- return fn(*args, **kw)
-
- rewrapped = function_call_count(*criteria)(fn)
- return rewrapped(*args, **kw)
- return _function_named(at_runtime, fn.__name__)
- return decorator
-
-
-def _profile(filename, fn, *args, **kw):
- global profiler
- if not profiler:
- profiler = 'hotshot'
- if sys.version_info > (2, 5):
- try:
- import cProfile
- profiler = 'cProfile'
- except ImportError:
- pass
-
- if profiler == 'cProfile':
- return _profile_cProfile(filename, fn, *args, **kw)
- else:
- return _profile_hotshot(filename, fn, *args, **kw)
-
-def _profile_cProfile(filename, fn, *args, **kw):
- import cProfile, gc, pstats, time
-
- load_stats = lambda: pstats.Stats(filename)
- gc.collect()
-
- began = time.time()
- cProfile.runctx('result = fn(*args, **kw)', globals(), locals(),
- filename=filename)
- ended = time.time()
-
- return ended - began, load_stats, locals()['result']
-
-def _profile_hotshot(filename, fn, *args, **kw):
- import gc, hotshot, hotshot.stats, time
- load_stats = lambda: hotshot.stats.load(filename)
-
- gc.collect()
- prof = hotshot.Profile(filename)
- began = time.time()
- prof.start()
- try:
- result = fn(*args, **kw)
- finally:
- prof.stop()
- ended = time.time()
- prof.close()
-
- return ended - began, load_stats, result
-
diff --git a/test/testlib/requires.py b/test/testlib/requires.py
deleted file mode 100644
index b20929a83..000000000
--- a/test/testlib/requires.py
+++ /dev/null
@@ -1,127 +0,0 @@
-"""Global database feature support policy.
-
-Provides decorators to mark tests requiring specific feature support from the
-target database.
-
-"""
-
-from testlib.testing import \
- _block_unconditionally as no_support, \
- _chain_decorators_on, \
- exclude, \
- emits_warning_on
-
-
-def deferrable_constraints(fn):
- """Target database must support derferable constraints."""
- return _chain_decorators_on(
- fn,
- no_support('firebird', 'not supported by database'),
- no_support('mysql', 'not supported by database'),
- no_support('mssql', 'not supported by database'),
- )
-
-def foreign_keys(fn):
- """Target database must support foreign keys."""
- return _chain_decorators_on(
- fn,
- no_support('sqlite', 'not supported by database'),
- )
-
-def identity(fn):
- """Target database must support GENERATED AS IDENTITY or a facsimile.
-
- Includes GENERATED AS IDENTITY, AUTOINCREMENT, AUTO_INCREMENT, or other
- column DDL feature that fills in a DB-generated identifier at INSERT-time
- without requiring pre-execution of a SEQUENCE or other artifact.
-
- """
- return _chain_decorators_on(
- fn,
- no_support('firebird', 'not supported by database'),
- no_support('oracle', 'not supported by database'),
- no_support('postgres', 'not supported by database'),
- no_support('sybase', 'not supported by database'),
- )
-
-def independent_connections(fn):
- """Target must support simultaneous, independent database connections."""
-
- # This is also true of some configurations of UnixODBC and probably win32
- # ODBC as well.
- return _chain_decorators_on(
- fn,
- no_support('sqlite', 'no driver support')
- )
-
-def row_triggers(fn):
- """Target must support standard statement-running EACH ROW triggers."""
- return _chain_decorators_on(
- fn,
- # no access to same table
- no_support('mysql', 'requires SUPER priv'),
- exclude('mysql', '<', (5, 0, 10), 'not supported by database'),
- no_support('postgres', 'not supported by database: no statements'),
- )
-
-def savepoints(fn):
- """Target database must support savepoints."""
- return _chain_decorators_on(
- fn,
- emits_warning_on('mssql', 'Savepoint support in mssql is experimental and may lead to data loss.'),
- no_support('access', 'not supported by database'),
- no_support('sqlite', 'not supported by database'),
- no_support('sybase', 'FIXME: guessing, needs confirmation'),
- exclude('mysql', '<', (5, 0, 3), 'not supported by database'),
- )
-
-def sequences(fn):
- """Target database must support SEQUENCEs."""
- return _chain_decorators_on(
- fn,
- no_support('access', 'no SEQUENCE support'),
- no_support('mssql', 'no SEQUENCE support'),
- no_support('mysql', 'no SEQUENCE support'),
- no_support('sqlite', 'no SEQUENCE support'),
- no_support('sybase', 'no SEQUENCE support'),
- )
-
-def subqueries(fn):
- """Target database must support subqueries."""
- return _chain_decorators_on(
- fn,
- exclude('mysql', '<', (4, 1, 1), 'no subquery support'),
- )
-
-def two_phase_transactions(fn):
- """Target database must support two-phase transactions."""
- return _chain_decorators_on(
- fn,
- no_support('access', 'not supported by database'),
- no_support('firebird', 'no SA implementation'),
- no_support('maxdb', 'not supported by database'),
- no_support('mssql', 'FIXME: guessing, needs confirmation'),
- no_support('oracle', 'no SA implementation'),
- no_support('sqlite', 'not supported by database'),
- no_support('sybase', 'FIXME: guessing, needs confirmation'),
- exclude('mysql', '<', (5, 0, 3), 'not supported by database'),
- )
-
-def unicode_connections(fn):
- """Target driver must support some encoding of Unicode across the wire."""
- # TODO: expand to exclude MySQLdb versions w/ broken unicode
- return _chain_decorators_on(
- fn,
- exclude('mysql', '<', (4, 1, 1), 'no unicode connection support'),
- )
-
-def unicode_ddl(fn):
- """Target driver must support some encoding of Unicode across the wire."""
- # TODO: expand to exclude MySQLdb versions w/ broken unicode
- return _chain_decorators_on(
- fn,
- no_support('maxdb', 'database support flakey'),
- no_support('oracle', 'FIXME: no support in database?'),
- no_support('sybase', 'FIXME: guessing, needs confirmation'),
- exclude('mysql', '<', (4, 1, 1), 'no unicode connection support'),
- )
diff --git a/test/testlib/sa_unittest.py b/test/testlib/sa_unittest.py
deleted file mode 100644
index 7eb2c0727..000000000
--- a/test/testlib/sa_unittest.py
+++ /dev/null
@@ -1,787 +0,0 @@
-#!/usr/bin/env python
-'''
-unittest.py from Python 2.5.
-
-SQLAlchemy extends unittest internals to provide setUpAll()/tearDownAll()
-so we include a fixed version here to insulate from changes. 2.6 and
-3.0's unittest is incompatible with our changes.
-
-Approaches to removing this dependency are:
-
-* find a unittest-supported method of grouping UnitTest classes within
-a setUpAll()/tearDownAll() pair, such that all tests within a single
-UnitTest class are executed within a single execution of setUpAll()/
-tearDownAll(). It may be possible to create nested TestSuite objects
-to accomplish this but it's not clear.
-* migrate to a different system such as nose.
-
-Copyright (c) 1999-2003 Steve Purcell
-This module is free software, and you may redistribute it and/or modify
-it under the same terms as Python itself, so long as this copyright message
-and disclaimer are retained in their original form.
-
-IN NO EVENT SHALL THE AUTHOR BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT,
-SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OF
-THIS CODE, EVEN IF THE AUTHOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH
-DAMAGE.
-
-THE AUTHOR SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT
-LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
-PARTICULAR PURPOSE. THE CODE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS,
-AND THERE IS NO OBLIGATION WHATSOEVER TO PROVIDE MAINTENANCE,
-SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
-'''
-
-__author__ = "Steve Purcell"
-__email__ = "stephen_purcell at yahoo dot com"
-__version__ = "#Revision: 1.63 $"[11:-2]
-
-import time
-import sys
-import traceback
-import os
-import types
-from testlib.compat import callable
-
-##############################################################################
-# Exported classes and functions
-##############################################################################
-__all__ = ['TestResult', 'TestCase', 'TestSuite', 'TextTestRunner',
- 'TestLoader', 'FunctionTestCase', 'main', 'defaultTestLoader']
-
-# Expose obsolete functions for backwards compatibility
-__all__.extend(['getTestCaseNames', 'makeSuite', 'findTestCases'])
-
-
-##############################################################################
-# Test framework core
-##############################################################################
-
-# All classes defined herein are 'new-style' classes, allowing use of 'super()'
-__metaclass__ = type
-
-def _strclass(cls):
- return "%s.%s" % (cls.__module__, cls.__name__)
-
-__unittest = 1
-
-class TestResult:
- """Holder for test result information.
-
- Test results are automatically managed by the TestCase and TestSuite
- classes, and do not need to be explicitly manipulated by writers of tests.
-
- Each instance holds the total number of tests run, and collections of
- failures and errors that occurred among those test runs. The collections
- contain tuples of (testcase, exceptioninfo), where exceptioninfo is the
- formatted traceback of the error that occurred.
- """
- def __init__(self):
- self.failures = []
- self.errors = []
- self.testsRun = 0
- self.shouldStop = 0
-
- def startTest(self, test):
- "Called when the given test is about to be run"
- self.testsRun = self.testsRun + 1
-
- def stopTest(self, test):
- "Called when the given test has been run"
- pass
-
- def addError(self, test, err):
- """Called when an error has occurred. 'err' is a tuple of values as
- returned by sys.exc_info().
- """
- self.errors.append((test, self._exc_info_to_string(err, test)))
-
- def addFailure(self, test, err):
- """Called when an error has occurred. 'err' is a tuple of values as
- returned by sys.exc_info()."""
- self.failures.append((test, self._exc_info_to_string(err, test)))
-
- def addSuccess(self, test):
- "Called when a test has completed successfully"
- pass
-
- def wasSuccessful(self):
- "Tells whether or not this result was a success"
- return len(self.failures) == len(self.errors) == 0
-
- def stop(self):
- "Indicates that the tests should be aborted"
- self.shouldStop = True
-
- def _exc_info_to_string(self, err, test):
- """Converts a sys.exc_info()-style tuple of values into a string."""
- exctype, value, tb = err
- # Skip test runner traceback levels
- while tb and self._is_relevant_tb_level(tb):
- tb = tb.tb_next
- if exctype is test.failureException:
- # Skip assert*() traceback levels
- length = self._count_relevant_tb_levels(tb)
- return ''.join(traceback.format_exception(exctype, value, tb, length))
- return ''.join(traceback.format_exception(exctype, value, tb))
-
- def _is_relevant_tb_level(self, tb):
- return tb.tb_frame.f_globals.has_key('__unittest')
-
- def _count_relevant_tb_levels(self, tb):
- length = 0
- while tb and not self._is_relevant_tb_level(tb):
- length += 1
- tb = tb.tb_next
- return length
-
- def __repr__(self):
- return "<%s run=%i errors=%i failures=%i>" % \
- (_strclass(self.__class__), self.testsRun, len(self.errors),
- len(self.failures))
-
-class TestCase:
- """A class whose instances are single test cases.
-
- By default, the test code itself should be placed in a method named
- 'runTest'.
-
- If the fixture may be used for many test cases, create as
- many test methods as are needed. When instantiating such a TestCase
- subclass, specify in the constructor arguments the name of the test method
- that the instance is to execute.
-
- Test authors should subclass TestCase for their own tests. Construction
- and deconstruction of the test's environment ('fixture') can be
- implemented by overriding the 'setUp' and 'tearDown' methods respectively.
-
- If it is necessary to override the __init__ method, the base class
- __init__ method must always be called. It is important that subclasses
- should not change the signature of their __init__ method, since instances
- of the classes are instantiated automatically by parts of the framework
- in order to be run.
- """
-
- # This attribute determines which exception will be raised when
- # the instance's assertion methods fail; test methods raising this
- # exception will be deemed to have 'failed' rather than 'errored'
-
- failureException = AssertionError
-
- def __init__(self, methodName='runTest'):
- """Create an instance of the class that will use the named test
- method when executed. Raises a ValueError if the instance does
- not have a method with the specified name.
- """
- try:
- self._testMethodName = methodName
- testMethod = getattr(self, methodName)
- self._testMethodDoc = testMethod.__doc__
- except AttributeError:
- raise ValueError, "no such test method in %s: %s" % \
- (self.__class__, methodName)
-
- def setUp(self):
- "Hook method for setting up the test fixture before exercising it."
- pass
-
- def tearDown(self):
- "Hook method for deconstructing the test fixture after testing it."
- pass
-
- def countTestCases(self):
- return 1
-
- def defaultTestResult(self):
- return TestResult()
-
- def shortDescription(self):
- """Returns a one-line description of the test, or None if no
- description has been provided.
-
- The default implementation of this method returns the first line of
- the specified test method's docstring.
- """
- doc = self._testMethodDoc
- return doc and doc.split("\n")[0].strip() or None
-
- def id(self):
- return "%s.%s" % (_strclass(self.__class__), self._testMethodName)
-
- def __str__(self):
- return "%s (%s)" % (self._testMethodName, _strclass(self.__class__))
-
- def __repr__(self):
- return "<%s testMethod=%s>" % \
- (_strclass(self.__class__), self._testMethodName)
-
- def run(self, result=None):
- if result is None: result = self.defaultTestResult()
- result.startTest(self)
- testMethod = getattr(self, self._testMethodName)
- try:
- try:
- self.setUp()
- except KeyboardInterrupt:
- raise
- except:
- result.addError(self, self._exc_info())
- return
-
- ok = False
- try:
- testMethod()
- ok = True
- except self.failureException:
- result.addFailure(self, self._exc_info())
- except KeyboardInterrupt:
- raise
- except:
- result.addError(self, self._exc_info())
-
- try:
- self.tearDown()
- except KeyboardInterrupt:
- raise
- except:
- result.addError(self, self._exc_info())
- ok = False
- if ok: result.addSuccess(self)
- finally:
- result.stopTest(self)
-
- def __call__(self, *args, **kwds):
- return self.run(*args, **kwds)
-
- def debug(self):
- """Run the test without collecting errors in a TestResult"""
- self.setUp()
- getattr(self, self._testMethodName)()
- self.tearDown()
-
- def _exc_info(self):
- """Return a version of sys.exc_info() with the traceback frame
- minimised; usually the top level of the traceback frame is not
- needed.
- """
- exctype, excvalue, tb = sys.exc_info()
- if sys.platform[:4] == 'java': ## tracebacks look different in Jython
- return (exctype, excvalue, tb)
- return (exctype, excvalue, tb)
-
- def fail(self, msg=None):
- """Fail immediately, with the given message."""
- raise self.failureException, msg
-
- def failIf(self, expr, msg=None):
- "Fail the test if the expression is true."
- if expr: raise self.failureException, msg
-
- def failUnless(self, expr, msg=None):
- """Fail the test unless the expression is true."""
- if not expr: raise self.failureException, msg
-
- def failUnlessRaises(self, excClass, callableObj, *args, **kwargs):
- """Fail unless an exception of class excClass is thrown
- by callableObj when invoked with arguments args and keyword
- arguments kwargs. If a different type of exception is
- thrown, it will not be caught, and the test case will be
- deemed to have suffered an error, exactly as for an
- unexpected exception.
- """
- try:
- callableObj(*args, **kwargs)
- except excClass:
- return
- else:
- if hasattr(excClass,'__name__'): excName = excClass.__name__
- else: excName = str(excClass)
- raise self.failureException, "%s not raised" % excName
-
- def failUnlessEqual(self, first, second, msg=None):
- """Fail if the two objects are unequal as determined by the '=='
- operator.
- """
- if not first == second:
- raise self.failureException, \
- (msg or '%r != %r' % (first, second))
-
- def failIfEqual(self, first, second, msg=None):
- """Fail if the two objects are equal as determined by the '=='
- operator.
- """
- if first == second:
- raise self.failureException, \
- (msg or '%r == %r' % (first, second))
-
- def failUnlessAlmostEqual(self, first, second, places=7, msg=None):
- """Fail if the two objects are unequal as determined by their
- difference rounded to the given number of decimal places
- (default 7) and comparing to zero.
-
- Note that decimal places (from zero) are usually not the same
- as significant digits (measured from the most signficant digit).
- """
- if round(second-first, places) != 0:
- raise self.failureException, \
- (msg or '%r != %r within %r places' % (first, second, places))
-
- def failIfAlmostEqual(self, first, second, places=7, msg=None):
- """Fail if the two objects are equal as determined by their
- difference rounded to the given number of decimal places
- (default 7) and comparing to zero.
-
- Note that decimal places (from zero) are usually not the same
- as significant digits (measured from the most signficant digit).
- """
- if round(second-first, places) == 0:
- raise self.failureException, \
- (msg or '%r == %r within %r places' % (first, second, places))
-
- # Synonyms for assertion methods
-
- assertEqual = assertEquals = failUnlessEqual
-
- assertNotEqual = assertNotEquals = failIfEqual
-
- assertAlmostEqual = assertAlmostEquals = failUnlessAlmostEqual
-
- assertNotAlmostEqual = assertNotAlmostEquals = failIfAlmostEqual
-
- assertRaises = failUnlessRaises
-
- assert_ = assertTrue = failUnless
-
- assertFalse = failIf
-
-
-
-class TestSuite:
- """A test suite is a composite test consisting of a number of TestCases.
-
- For use, create an instance of TestSuite, then add test case instances.
- When all tests have been added, the suite can be passed to a test
- runner, such as TextTestRunner. It will run the individual test cases
- in the order in which they were added, aggregating the results. When
- subclassing, do not forget to call the base class constructor.
- """
- def __init__(self, tests=()):
- self._tests = []
- self.addTests(tests)
-
- def __repr__(self):
- return "<%s tests=%s>" % (_strclass(self.__class__), self._tests)
-
- __str__ = __repr__
-
- def __iter__(self):
- return iter(self._tests)
-
- def countTestCases(self):
- cases = 0
- for test in self._tests:
- cases += test.countTestCases()
- return cases
-
- def addTest(self, test):
- # sanity checks
- if not callable(test):
- raise TypeError("the test to add must be callable")
- if (isinstance(test, (type, types.ClassType)) and
- issubclass(test, (TestCase, TestSuite))):
- raise TypeError("TestCases and TestSuites must be instantiated "
- "before passing them to addTest()")
- self._tests.append(test)
-
- def addTests(self, tests):
- if isinstance(tests, basestring):
- raise TypeError("tests must be an iterable of tests, not a string")
- for test in tests:
- self.addTest(test)
-
- def run(self, result):
- for test in self._tests:
- if result.shouldStop:
- break
- test(result)
- return result
-
- def __call__(self, *args, **kwds):
- return self.run(*args, **kwds)
-
- def debug(self):
- """Run the tests without collecting errors in a TestResult"""
- for test in self._tests: test.debug()
-
-
-class FunctionTestCase(TestCase):
- """A test case that wraps a test function.
-
- This is useful for slipping pre-existing test functions into the
- PyUnit framework. Optionally, set-up and tidy-up functions can be
- supplied. As with TestCase, the tidy-up ('tearDown') function will
- always be called if the set-up ('setUp') function ran successfully.
- """
-
- def __init__(self, testFunc, setUp=None, tearDown=None,
- description=None):
- TestCase.__init__(self)
- self.__setUpFunc = setUp
- self.__tearDownFunc = tearDown
- self.__testFunc = testFunc
- self.__description = description
-
- def setUp(self):
- if self.__setUpFunc is not None:
- self.__setUpFunc()
-
- def tearDown(self):
- if self.__tearDownFunc is not None:
- self.__tearDownFunc()
-
- def runTest(self):
- self.__testFunc()
-
- def id(self):
- return self.__testFunc.__name__
-
- def __str__(self):
- return "%s (%s)" % (_strclass(self.__class__), self.__testFunc.__name__)
-
- def __repr__(self):
- return "<%s testFunc=%s>" % (_strclass(self.__class__), self.__testFunc)
-
- def shortDescription(self):
- if self.__description is not None: return self.__description
- doc = self.__testFunc.__doc__
- return doc and doc.split("\n")[0].strip() or None
-
-
-
-##############################################################################
-# Locating and loading tests
-##############################################################################
-
-class TestLoader:
- """This class is responsible for loading tests according to various
- criteria and returning them wrapped in a Test
- """
- testMethodPrefix = 'test'
- suiteClass = TestSuite
-
- def loadTestsFromTestCase(self, testCaseClass):
- """Return a suite of all tests cases contained in testCaseClass"""
- if issubclass(testCaseClass, TestSuite):
- raise TypeError("Test cases should not be derived from TestSuite. Maybe you meant to derive from TestCase?")
- testCaseNames = self.getTestCaseNames(testCaseClass)
- if not testCaseNames and hasattr(testCaseClass, 'runTest'):
- testCaseNames = ['runTest']
- return self.suiteClass(map(testCaseClass, testCaseNames))
-
- def loadTestsFromModule(self, module):
- """Return a suite of all tests cases contained in the given module"""
- tests = []
- for name in dir(module):
- obj = getattr(module, name)
- if (isinstance(obj, (type, types.ClassType)) and
- issubclass(obj, TestCase)):
- tests.append(self.loadTestsFromTestCase(obj))
- return self.suiteClass(tests)
-
- def loadTestsFromName(self, name, module=None):
- """Return a suite of all tests cases given a string specifier.
-
- The name may resolve either to a module, a test case class, a
- test method within a test case class, or a callable object which
- returns a TestCase or TestSuite instance.
-
- The method optionally resolves the names relative to a given module.
- """
- parts = name.split('.')
- if module is None:
- parts_copy = parts[:]
- while parts_copy:
- try:
- module = __import__('.'.join(parts_copy))
- break
- except ImportError:
- del parts_copy[-1]
- if not parts_copy: raise
- parts = parts[1:]
- obj = module
- for part in parts:
- parent, obj = obj, getattr(obj, part)
-
- if type(obj) == types.ModuleType:
- return self.loadTestsFromModule(obj)
- elif (isinstance(obj, (type, types.ClassType)) and
- issubclass(obj, TestCase)):
- return self.loadTestsFromTestCase(obj)
- elif type(obj) == types.UnboundMethodType:
- return parent(obj.__name__)
- elif isinstance(obj, TestSuite):
- return obj
- elif callable(obj):
- test = obj()
- if not isinstance(test, (TestCase, TestSuite)):
- raise ValueError, \
- "calling %s returned %s, not a test" % (obj,test)
- return test
- else:
- raise ValueError, "don't know how to make test from: %s" % obj
-
- def loadTestsFromNames(self, names, module=None):
- """Return a suite of all tests cases found using the given sequence
- of string specifiers. See 'loadTestsFromName()'.
- """
- suites = [self.loadTestsFromName(name, module) for name in names]
- return self.suiteClass(suites)
-
- def getTestCaseNames(self, testCaseClass):
- """Return a sorted sequence of method names found within testCaseClass
- """
-
- def isTestMethod(attrname, testCaseClass=testCaseClass, prefix=self.testMethodPrefix):
- return attrname.startswith(prefix) and callable(getattr(testCaseClass, attrname))
- testFnNames = filter(isTestMethod, dir(testCaseClass))
- for baseclass in testCaseClass.__bases__:
- for testFnName in self.getTestCaseNames(baseclass):
- if testFnName not in testFnNames: # handle overridden methods
- testFnNames.append(testFnName)
- testFnNames.sort()
- return testFnNames
-
-
-
-defaultTestLoader = TestLoader()
-
-
-##############################################################################
-# Patches for old functions: these functions should be considered obsolete
-##############################################################################
-
-def _makeLoader(prefix, sortUsing, suiteClass=None):
- loader = TestLoader()
- loader.testMethodPrefix = prefix
- if suiteClass: loader.suiteClass = suiteClass
- return loader
-
-def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp):
- return _makeLoader(prefix, sortUsing).getTestCaseNames(testCaseClass)
-
-def makeSuite(testCaseClass, prefix='test', sortUsing=cmp, suiteClass=TestSuite):
- return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromTestCase(testCaseClass)
-
-def findTestCases(module, prefix='test', sortUsing=cmp, suiteClass=TestSuite):
- return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromModule(module)
-
-
-##############################################################################
-# Text UI
-##############################################################################
-
-class _WritelnDecorator:
- """Used to decorate file-like objects with a handy 'writeln' method"""
- def __init__(self,stream):
- self.stream = stream
-
- def __getattr__(self, attr):
- return getattr(self.stream,attr)
-
- def writeln(self, arg=None):
- if arg: self.write(arg)
- self.write('\n') # text-mode streams translate to \r\n if needed
-
-
-class _TextTestResult(TestResult):
- """A test result class that can print formatted text results to a stream.
-
- Used by TextTestRunner.
- """
- separator1 = '=' * 70
- separator2 = '-' * 70
-
- def __init__(self, stream, descriptions, verbosity):
- TestResult.__init__(self)
- self.stream = stream
- self.showAll = verbosity > 1
- self.dots = verbosity == 1
- self.descriptions = descriptions
-
- def getDescription(self, test):
- if self.descriptions:
- return test.shortDescription() or str(test)
- else:
- return str(test)
-
- def startTest(self, test):
- TestResult.startTest(self, test)
- if self.showAll:
- self.stream.write(self.getDescription(test))
- self.stream.write(" ... ")
-
- def addSuccess(self, test):
- TestResult.addSuccess(self, test)
- if self.showAll:
- self.stream.writeln("ok")
- elif self.dots:
- self.stream.write('.')
-
- def addError(self, test, err):
- TestResult.addError(self, test, err)
- if self.showAll:
- self.stream.writeln("ERROR")
- elif self.dots:
- self.stream.write('E')
-
- def addFailure(self, test, err):
- TestResult.addFailure(self, test, err)
- if self.showAll:
- self.stream.writeln("FAIL")
- elif self.dots:
- self.stream.write('F')
-
- def printErrors(self):
- if self.dots or self.showAll:
- self.stream.writeln()
- self.printErrorList('ERROR', self.errors)
- self.printErrorList('FAIL', self.failures)
-
- def printErrorList(self, flavour, errors):
- for test, err in errors:
- self.stream.writeln(self.separator1)
- self.stream.writeln("%s: %s" % (flavour,self.getDescription(test)))
- self.stream.writeln(self.separator2)
- self.stream.writeln("%s" % err)
-
-
-class TextTestRunner:
- """A test runner class that displays results in textual form.
-
- It prints out the names of tests as they are run, errors as they
- occur, and a summary of the results at the end of the test run.
- """
- def __init__(self, stream=sys.stderr, descriptions=1, verbosity=1):
- self.stream = _WritelnDecorator(stream)
- self.descriptions = descriptions
- self.verbosity = verbosity
-
- def _makeResult(self):
- return _TextTestResult(self.stream, self.descriptions, self.verbosity)
-
- def run(self, test):
- "Run the given test case or test suite."
- result = self._makeResult()
- startTime = time.time()
- test(result)
- stopTime = time.time()
- timeTaken = stopTime - startTime
- result.printErrors()
- self.stream.writeln(result.separator2)
- run = result.testsRun
- self.stream.writeln("Ran %d test%s in %.3fs" %
- (run, run != 1 and "s" or "", timeTaken))
- self.stream.writeln()
- if not result.wasSuccessful():
- self.stream.write("FAILED (")
- failed, errored = map(len, (result.failures, result.errors))
- if failed:
- self.stream.write("failures=%d" % failed)
- if errored:
- if failed: self.stream.write(", ")
- self.stream.write("errors=%d" % errored)
- self.stream.writeln(")")
- else:
- self.stream.writeln("OK")
- return result
-
-
-
-##############################################################################
-# Facilities for running tests from the command line
-##############################################################################
-
-class TestProgram:
- """A command-line program that runs a set of tests; this is primarily
- for making test modules conveniently executable.
- """
- USAGE = """\
-Usage: %(progName)s [options] [test] [...]
-
-Options:
- -h, --help Show this message
- -v, --verbose Verbose output
- -q, --quiet Minimal output
-
-Examples:
- %(progName)s - run default set of tests
- %(progName)s MyTestSuite - run suite 'MyTestSuite'
- %(progName)s MyTestCase.testSomething - run MyTestCase.testSomething
- %(progName)s MyTestCase - run all 'test*' test methods
- in MyTestCase
-"""
- def __init__(self, module='__main__', defaultTest=None,
- argv=None, testRunner=None, testLoader=defaultTestLoader):
- if type(module) == type(''):
- self.module = __import__(module)
- for part in module.split('.')[1:]:
- self.module = getattr(self.module, part)
- else:
- self.module = module
- if argv is None:
- argv = sys.argv
- self.verbosity = 1
- self.defaultTest = defaultTest
- self.testRunner = testRunner
- self.testLoader = testLoader
- self.progName = os.path.basename(argv[0])
- self.parseArgs(argv)
- self.runTests()
-
- def usageExit(self, msg=None):
- if msg: print msg
- print self.USAGE % self.__dict__
- sys.exit(2)
-
- def parseArgs(self, argv):
- import getopt
- try:
- options, args = getopt.getopt(argv[1:], 'hHvq',
- ['help','verbose','quiet'])
- for opt, value in options:
- if opt in ('-h','-H','--help'):
- self.usageExit()
- if opt in ('-q','--quiet'):
- self.verbosity = 0
- if opt in ('-v','--verbose'):
- self.verbosity = 2
- if len(args) == 0 and self.defaultTest is None:
- self.test = self.testLoader.loadTestsFromModule(self.module)
- return
- if len(args) > 0:
- self.testNames = args
- else:
- self.testNames = (self.defaultTest,)
- self.createTests()
- except getopt.error, msg:
- self.usageExit(msg)
-
- def createTests(self):
- self.test = self.testLoader.loadTestsFromNames(self.testNames,
- self.module)
-
- def runTests(self):
- if self.testRunner is None:
- self.testRunner = TextTestRunner(verbosity=self.verbosity)
- result = self.testRunner.run(self.test)
- sys.exit(not result.wasSuccessful())
-
-main = TestProgram
-
-
-##############################################################################
-# Executing this module from the command line
-##############################################################################
-
-if __name__ == "__main__":
- main(module=None)
diff --git a/test/testlib/schema.py b/test/testlib/schema.py
deleted file mode 100644
index 7009fd65d..000000000
--- a/test/testlib/schema.py
+++ /dev/null
@@ -1,79 +0,0 @@
-from testlib import testing
-
-schema = None
-
-__all__ = 'Table', 'Column',
-
-table_options = {}
-
-def Table(*args, **kw):
- """A schema.Table wrapper/hook for dialect-specific tweaks."""
-
- global schema
- if schema is None:
- from sqlalchemy import schema
-
- test_opts = dict([(k,kw.pop(k)) for k in kw.keys()
- if k.startswith('test_')])
-
- kw.update(table_options)
-
- if testing.against('mysql'):
- if 'mysql_engine' not in kw and 'mysql_type' not in kw:
- if 'test_needs_fk' in test_opts or 'test_needs_acid' in test_opts:
- kw['mysql_engine'] = 'InnoDB'
-
- # Apply some default cascading rules for self-referential foreign keys.
- # MySQL InnoDB has some issues around seleting self-refs too.
- if testing.against('firebird'):
- table_name = args[0]
- unpack = (testing.config.db.dialect.
- identifier_preparer.unformat_identifiers)
-
- # Only going after ForeignKeys in Columns. May need to
- # expand to ForeignKeyConstraint too.
- fks = [fk
- for col in args if isinstance(col, schema.Column)
- for fk in col.args if isinstance(fk, schema.ForeignKey)]
-
- for fk in fks:
- # root around in raw spec
- ref = fk._colspec
- if isinstance(ref, schema.Column):
- name = ref.table.name
- else:
- # take just the table name: on FB there cannot be
- # a schema, so the first element is always the
- # table name, possibly followed by the field name
- name = unpack(ref)[0]
- if name == table_name:
- if fk.ondelete is None:
- fk.ondelete = 'CASCADE'
- if fk.onupdate is None:
- fk.onupdate = 'CASCADE'
-
- if testing.against('firebird', 'oracle'):
- pk_seqs = [col for col in args
- if (isinstance(col, schema.Column)
- and col.primary_key
- and getattr(col, '_needs_autoincrement', False))]
- for c in pk_seqs:
- c.args.append(schema.Sequence(args[0] + '_' + c.name + '_seq', optional=True))
- return schema.Table(*args, **kw)
-
-
-def Column(*args, **kw):
- """A schema.Column wrapper/hook for dialect-specific tweaks."""
-
- global schema
- if schema is None:
- from sqlalchemy import schema
-
- test_opts = dict([(k,kw.pop(k)) for k in kw.keys()
- if k.startswith('test_')])
-
- c = schema.Column(*args, **kw)
- if testing.against('firebird', 'oracle'):
- if 'test_needs_autoincrement' in test_opts:
- c._needs_autoincrement = True
- return c
diff --git a/test/testlib/testing.py b/test/testlib/testing.py
deleted file mode 100644
index 408dda79f..000000000
--- a/test/testlib/testing.py
+++ /dev/null
@@ -1,919 +0,0 @@
-"""TestCase and TestSuite artifacts and testing decorators."""
-
-# monkeypatches unittest.TestLoader.suiteClass at import time
-
-import itertools
-import operator
-import re
-import sys
-import types
-from testlib import sa_unittest as unittest
-import warnings
-from cStringIO import StringIO
-
-import testlib.config as config
-from testlib.compat import _function_named, callable
-
-# Delayed imports
-MetaData = None
-Session = None
-clear_mappers = None
-sa_exc = None
-schema = None
-sqltypes = None
-util = None
-
-
-_ops = { '<': operator.lt,
- '>': operator.gt,
- '==': operator.eq,
- '!=': operator.ne,
- '<=': operator.le,
- '>=': operator.ge,
- 'in': operator.contains,
- 'between': lambda val, pair: val >= pair[0] and val <= pair[1],
- }
-
-# sugar ('testing.db'); set here by config() at runtime
-db = None
-
-# more sugar, installed by __init__
-requires = None
-
-def fails_if(callable_):
- """Mark a test as expected to fail if callable_ returns True.
-
- If the callable returns false, the test is run and reported as normal.
- However if the callable returns true, the test is expected to fail and the
- unit test logic is inverted: if the test fails, a success is reported. If
- the test succeeds, a failure is reported.
- """
-
- docstring = getattr(callable_, '__doc__', None) or callable_.__name__
- description = docstring.split('\n')[0]
-
- def decorate(fn):
- fn_name = fn.__name__
- def maybe(*args, **kw):
- if not callable_():
- return fn(*args, **kw)
- else:
- try:
- fn(*args, **kw)
- except Exception, ex:
- print ("'%s' failed as expected (condition: %s): %s " % (
- fn_name, description, str(ex)))
- return True
- else:
- raise AssertionError(
- "Unexpected success for '%s' (condition: %s)" %
- (fn_name, description))
- return _function_named(maybe, fn_name)
- return decorate
-
-
-def future(fn):
- """Mark a test as expected to unconditionally fail.
-
- Takes no arguments, omit parens when using as a decorator.
- """
-
- fn_name = fn.__name__
- def decorated(*args, **kw):
- try:
- fn(*args, **kw)
- except Exception, ex:
- print ("Future test '%s' failed as expected: %s " % (
- fn_name, str(ex)))
- return True
- else:
- raise AssertionError(
- "Unexpected success for future test '%s'" % fn_name)
- return _function_named(decorated, fn_name)
-
-def fails_on(dbs, reason):
- """Mark a test as expected to fail on the specified database
- implementation.
-
- Unlike ``crashes``, tests marked as ``fails_on`` will be run
- for the named databases. The test is expected to fail and the unit test
- logic is inverted: if the test fails, a success is reported. If the test
- succeeds, a failure is reported.
- """
-
- def decorate(fn):
- fn_name = fn.__name__
- def maybe(*args, **kw):
- if config.db.name != dbs:
- return fn(*args, **kw)
- else:
- try:
- fn(*args, **kw)
- except Exception, ex:
- print ("'%s' failed as expected on DB implementation "
- "'%s': %s" % (
- fn_name, config.db.name, reason))
- return True
- else:
- raise AssertionError(
- "Unexpected success for '%s' on DB implementation '%s'" %
- (fn_name, config.db.name))
- return _function_named(maybe, fn_name)
- return decorate
-
-def fails_on_everything_except(*dbs):
- """Mark a test as expected to fail on most database implementations.
-
- Like ``fails_on``, except failure is the expected outcome on all
- databases except those listed.
- """
-
- def decorate(fn):
- fn_name = fn.__name__
- def maybe(*args, **kw):
- if config.db.name in dbs:
- return fn(*args, **kw)
- else:
- try:
- fn(*args, **kw)
- except Exception, ex:
- print ("'%s' failed as expected on DB implementation "
- "'%s': %s" % (
- fn_name, config.db.name, str(ex)))
- return True
- else:
- raise AssertionError(
- "Unexpected success for '%s' on DB implementation '%s'" %
- (fn_name, config.db.name))
- return _function_named(maybe, fn_name)
- return decorate
-
-def crashes(db, reason):
- """Mark a test as unsupported by a database implementation.
-
- ``crashes`` tests will be skipped unconditionally. Use for feature tests
- that cause deadlocks or other fatal problems.
-
- """
- carp = _should_carp_about_exclusion(reason)
- def decorate(fn):
- fn_name = fn.__name__
- def maybe(*args, **kw):
- if config.db.name == db:
- msg = "'%s' unsupported on DB implementation '%s': %s" % (
- fn_name, config.db.name, reason)
- print msg
- if carp:
- print >> sys.stderr, msg
- return True
- else:
- return fn(*args, **kw)
- return _function_named(maybe, fn_name)
- return decorate
-
-def _block_unconditionally(db, reason):
- """Mark a test as unsupported by a database implementation.
-
- Will never run the test against any version of the given database, ever,
- no matter what. Use when your assumptions are infallible; past, present
- and future.
-
- """
- carp = _should_carp_about_exclusion(reason)
- def decorate(fn):
- fn_name = fn.__name__
- def maybe(*args, **kw):
- if config.db.name == db:
- msg = "'%s' unsupported on DB implementation '%s': %s" % (
- fn_name, config.db.name, reason)
- print msg
- if carp:
- print >> sys.stderr, msg
- return True
- else:
- return fn(*args, **kw)
- return _function_named(maybe, fn_name)
- return decorate
-
-
-def exclude(db, op, spec, reason):
- """Mark a test as unsupported by specific database server versions.
-
- Stackable, both with other excludes and other decorators. Examples::
-
- # Not supported by mydb versions less than 1, 0
- @exclude('mydb', '<', (1,0))
- # Other operators work too
- @exclude('bigdb', '==', (9,0,9))
- @exclude('yikesdb', 'in', ((0, 3, 'alpha2'), (0, 3, 'alpha3')))
-
- """
- carp = _should_carp_about_exclusion(reason)
- def decorate(fn):
- fn_name = fn.__name__
- def maybe(*args, **kw):
- if _is_excluded(db, op, spec):
- msg = "'%s' unsupported on DB %s version '%s': %s" % (
- fn_name, config.db.name, _server_version(), reason)
- print msg
- if carp:
- print >> sys.stderr, msg
- return True
- else:
- return fn(*args, **kw)
- return _function_named(maybe, fn_name)
- return decorate
-
-def _should_carp_about_exclusion(reason):
- """Guard against forgotten exclusions."""
- assert reason
- for _ in ('todo', 'fixme', 'xxx'):
- if _ in reason.lower():
- return True
- else:
- if len(reason) < 4:
- return True
-
-def _is_excluded(db, op, spec):
- """Return True if the configured db matches an exclusion specification.
-
- db:
- A dialect name
- op:
- An operator or stringified operator, such as '=='
- spec:
- A value that will be compared to the dialect's server_version_info
- using the supplied operator.
-
- Examples::
- # Not supported by mydb versions less than 1, 0
- _is_excluded('mydb', '<', (1,0))
- # Other operators work too
- _is_excluded('bigdb', '==', (9,0,9))
- _is_excluded('yikesdb', 'in', ((0, 3, 'alpha2'), (0, 3, 'alpha3')))
- """
-
- if config.db.name != db:
- return False
-
- version = _server_version()
-
- oper = hasattr(op, '__call__') and op or _ops[op]
- return oper(version, spec)
-
-def _server_version(bind=None):
- """Return a server_version_info tuple."""
-
- if bind is None:
- bind = config.db
- return bind.dialect.server_version_info(bind.contextual_connect())
-
-def skip_if(predicate, reason=None):
- """Skip a test if predicate is true."""
- reason = reason or predicate.__name__
- def decorate(fn):
- fn_name = fn.__name__
- def maybe(*args, **kw):
- if predicate():
- msg = "'%s' skipped on DB %s version '%s': %s" % (
- fn_name, config.db.name, _server_version(), reason)
- print msg
- return True
- else:
- return fn(*args, **kw)
- return _function_named(maybe, fn_name)
- return decorate
-
-def emits_warning(*messages):
- """Mark a test as emitting a warning.
-
- With no arguments, squelches all SAWarning failures. Or pass one or more
- strings; these will be matched to the root of the warning description by
- warnings.filterwarnings().
- """
-
- # TODO: it would be nice to assert that a named warning was
- # emitted. should work with some monkeypatching of warnings,
- # and may work on non-CPython if they keep to the spirit of
- # warnings.showwarning's docstring.
- # - update: jython looks ok, it uses cpython's module
- def decorate(fn):
- def safe(*args, **kw):
- global sa_exc
- if sa_exc is None:
- import sqlalchemy.exc as sa_exc
-
- # todo: should probably be strict about this, too
- filters = [dict(action='ignore',
- category=sa_exc.SAPendingDeprecationWarning)]
- if not messages:
- filters.append(dict(action='ignore',
- category=sa_exc.SAWarning))
- else:
- filters.extend(dict(action='ignore',
- message=message,
- category=sa_exc.SAWarning)
- for message in messages)
- for f in filters:
- warnings.filterwarnings(**f)
- try:
- return fn(*args, **kw)
- finally:
- resetwarnings()
- return _function_named(safe, fn.__name__)
- return decorate
-
-def emits_warning_on(db, *warnings):
- """Mark a test as emitting a warning on a specific dialect.
-
- With no arguments, squelches all SAWarning failures. Or pass one or more
- strings; these will be matched to the root of the warning description by
- warnings.filterwarnings().
- """
- def decorate(fn):
- def maybe(*args, **kw):
- if isinstance(db, basestring):
- if config.db.name != db:
- return fn(*args, **kw)
- else:
- wrapped = emits_warning(*warnings)(fn)
- return wrapped(*args, **kw)
- else:
- if not _is_excluded(*db):
- return fn(*args, **kw)
- else:
- wrapped = emits_warning(*warnings)(fn)
- return wrapped(*args, **kw)
- return _function_named(maybe, fn.__name__)
- return decorate
-
-def uses_deprecated(*messages):
- """Mark a test as immune from fatal deprecation warnings.
-
- With no arguments, squelches all SADeprecationWarning failures.
- Or pass one or more strings; these will be matched to the root
- of the warning description by warnings.filterwarnings().
-
- As a special case, you may pass a function name prefixed with //
- and it will be re-written as needed to match the standard warning
- verbiage emitted by the sqlalchemy.util.deprecated decorator.
- """
-
- def decorate(fn):
- def safe(*args, **kw):
- global sa_exc
- if sa_exc is None:
- import sqlalchemy.exc as sa_exc
-
- # todo: should probably be strict about this, too
- filters = [dict(action='ignore',
- category=sa_exc.SAPendingDeprecationWarning)]
- if not messages:
- filters.append(dict(action='ignore',
- category=sa_exc.SADeprecationWarning))
- else:
- filters.extend(
- [dict(action='ignore',
- message=message,
- category=sa_exc.SADeprecationWarning)
- for message in
- [ (m.startswith('//') and
- ('Call to deprecated function ' + m[2:]) or m)
- for m in messages] ])
-
- for f in filters:
- warnings.filterwarnings(**f)
- try:
- return fn(*args, **kw)
- finally:
- resetwarnings()
- return _function_named(safe, fn.__name__)
- return decorate
-
-def resetwarnings():
- """Reset warning behavior to testing defaults."""
-
- global sa_exc
- if sa_exc is None:
- import sqlalchemy.exc as sa_exc
-
- warnings.filterwarnings('ignore',
- category=sa_exc.SAPendingDeprecationWarning)
- warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning)
- warnings.filterwarnings('error', category=sa_exc.SAWarning)
-
-# warnings.simplefilter('error')
-
- if sys.version_info < (2, 4):
- warnings.filterwarnings('ignore', category=FutureWarning)
-
-
-def against(*queries):
- """Boolean predicate, compares to testing database configuration.
-
- Given one or more dialect names, returns True if one is the configured
- database engine.
-
- Also supports comparison to database version when provided with one or
- more 3-tuples of dialect name, operator, and version specification::
-
- testing.against('mysql', 'postgres')
- testing.against(('mysql', '>=', (5, 0, 0))
- """
-
- for query in queries:
- if isinstance(query, basestring):
- if config.db.name == query:
- return True
- else:
- name, op, spec = query
- if config.db.name != name:
- continue
-
- have = config.db.dialect.server_version_info(
- config.db.contextual_connect())
-
- oper = hasattr(op, '__call__') and op or _ops[op]
- if oper(have, spec):
- return True
- return False
-
-def _chain_decorators_on(fn, *decorators):
- """Apply a series of decorators to fn, returning a decorated function."""
- for decorator in reversed(decorators):
- fn = decorator(fn)
- return fn
-
-def rowset(results):
- """Converts the results of sql execution into a plain set of column tuples.
-
- Useful for asserting the results of an unordered query.
- """
-
- return set([tuple(row) for row in results])
-
-
-def eq_(a, b, msg=None):
- """Assert a == b, with repr messaging on failure."""
- assert a == b, msg or "%r != %r" % (a, b)
-
-def ne_(a, b, msg=None):
- """Assert a != b, with repr messaging on failure."""
- assert a != b, msg or "%r == %r" % (a, b)
-
-def is_(a, b, msg=None):
- """Assert a is b, with repr messaging on failure."""
- assert a is b, msg or "%r is not %r" % (a, b)
-
-def is_not_(a, b, msg=None):
- """Assert a is not b, with repr messaging on failure."""
- assert a is not b, msg or "%r is %r" % (a, b)
-
-def startswith_(a, fragment, msg=None):
- """Assert a.startswith(fragment), with repr messaging on failure."""
- assert a.startswith(fragment), msg or "%r does not start with %r" % (
- a, fragment)
-
-
-def fixture(table, columns, *rows):
- """Insert data into table after creation."""
- def onload(event, schema_item, connection):
- insert = table.insert()
- column_names = [col.key for col in columns]
- connection.execute(insert, [dict(zip(column_names, column_values))
- for column_values in rows])
- table.append_ddl_listener('after-create', onload)
-
-def _import_by_name(name):
- submodule = name.split('.')[-1]
- return __import__(name, globals(), locals(), [submodule])
-
-class CompositeModule(types.ModuleType):
- """Merged attribute access for multiple modules."""
-
- # break the habit
- __all__ = ()
-
- def __init__(self, name, *modules, **overrides):
- """Construct a new lazy composite of modules.
-
- Modules may be string names or module-like instances. Individual
- attribute overrides may be specified as keyword arguments for
- convenience.
-
- The constructed module will resolve attribute access in reverse order:
- overrides, then each member of reversed(modules). Modules specified
- by name will be loaded lazily when encountered in attribute
- resolution.
-
- """
- types.ModuleType.__init__(self, name)
- self.__modules = list(reversed(modules))
- for key, value in overrides.iteritems():
- setattr(self, key, value)
-
- def __getattr__(self, key):
- for idx, mod in enumerate(self.__modules):
- if isinstance(mod, basestring):
- self.__modules[idx] = mod = _import_by_name(mod)
- if hasattr(mod, key):
- return getattr(mod, key)
- raise AttributeError(key)
-
-
-def resolve_artifact_names(fn):
- """Decorator, augment function globals with tables and classes.
-
- Swaps out the function's globals at execution time. The 'global' statement
- will not work as expected inside a decorated function.
-
- """
- # This could be automatically applied to framework and test_ methods in
- # the MappedTest-derived test suites but... *some* explicitness for this
- # magic is probably good. Especially as 'global' won't work- these
- # rebound functions aren't regular Python..
- #
- # Also: it's lame that CPython accepts a dict-subclass for globals, but
- # only calls dict methods. That would allow 'global' to pass through to
- # the func_globals.
- def resolved(*args, **kwargs):
- self = args[0]
- context = dict(fn.func_globals)
- for source in self._artifact_registries:
- context.update(getattr(self, source))
- # jython bug #1034
- rebound = types.FunctionType(
- fn.func_code, context, fn.func_name, fn.func_defaults,
- fn.func_closure)
- return rebound(*args, **kwargs)
- return _function_named(resolved, fn.func_name)
-
-class adict(dict):
- """Dict keys available as attributes. Shadows."""
- def __getattribute__(self, key):
- try:
- return self[key]
- except KeyError:
- return dict.__getattribute__(self, key)
-
- def get_all(self, *keys):
- return tuple([self[key] for key in keys])
-
-
-class TestBase(unittest.TestCase):
- # A sequence of database names to always run, regardless of the
- # constraints below.
- __whitelist__ = ()
-
- # A sequence of requirement names matching testing.requires decorators
- __requires__ = ()
-
- # A sequence of dialect names to exclude from the test class.
- __unsupported_on__ = ()
-
- # If present, test class is only runnable for the *single* specified
- # dialect. If you need multiple, use __unsupported_on__ and invert.
- __only_on__ = None
-
- # A sequence of no-arg callables. If any are True, the entire testcase is
- # skipped.
- __skip_if__ = None
-
-
- _artifact_registries = ()
-
- _sa_first_test = False
- _sa_last_test = False
-
- def __init__(self, *args, **params):
- unittest.TestCase.__init__(self, *args, **params)
-
- def setUpAll(self):
- pass
-
- def tearDownAll(self):
- pass
-
- def shortDescription(self):
- """overridden to not return docstrings"""
- return None
-
- def assertRaisesMessage(self, except_cls, msg, callable_, *args, **kwargs):
- try:
- callable_(*args, **kwargs)
- assert False, "Callable did not raise an exception"
- except except_cls, e:
- assert re.search(msg, str(e)), "%r !~ %s" % (msg, e)
-
- if not hasattr(unittest.TestCase, 'assertTrue'):
- assertTrue = unittest.TestCase.failUnless
- if not hasattr(unittest.TestCase, 'assertFalse'):
- assertFalse = unittest.TestCase.failIf
-
-class AssertsCompiledSQL(object):
- def assert_compile(self, clause, result, params=None, checkparams=None, dialect=None):
- if dialect is None:
- dialect = getattr(self, '__dialect__', None)
-
- if params is None:
- keys = None
- else:
- keys = params.keys()
-
- c = clause.compile(column_keys=keys, dialect=dialect)
-
- print "\nSQL String:\n" + str(c) + repr(c.params)
-
- cc = re.sub(r'\n', '', str(c))
-
- self.assertEquals(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
-
- if checkparams is not None:
- self.assertEquals(c.construct_params(params), checkparams)
-
-class ComparesTables(object):
- def assert_tables_equal(self, table, reflected_table):
- global sqltypes, schema
- if sqltypes is None:
- import sqlalchemy.types as sqltypes
- if schema is None:
- import sqlalchemy.schema as schema
- base_mro = sqltypes.TypeEngine.__mro__
- assert len(table.c) == len(reflected_table.c)
- for c, reflected_c in zip(table.c, reflected_table.c):
- self.assertEquals(c.name, reflected_c.name)
- assert reflected_c is reflected_table.c[c.name]
- self.assertEquals(c.primary_key, reflected_c.primary_key)
- self.assertEquals(c.nullable, reflected_c.nullable)
- assert len(
- set(type(reflected_c.type).__mro__).difference(base_mro).intersection(
- set(type(c.type).__mro__).difference(base_mro)
- )
- ) > 0, "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type)
-
- if isinstance(c.type, sqltypes.String):
- self.assertEquals(c.type.length, reflected_c.type.length)
-
- self.assertEquals(set([f.column.name for f in c.foreign_keys]), set([f.column.name for f in reflected_c.foreign_keys]))
- if c.default:
- assert isinstance(reflected_c.server_default,
- schema.FetchedValue)
- elif against(('mysql', '<', (5, 0))):
- # ignore reflection of bogus db-generated DefaultClause()
- pass
- elif not c.primary_key or not against('postgres'):
- print repr(c)
- assert reflected_c.default is None, reflected_c.default
-
- assert len(table.primary_key) == len(reflected_table.primary_key)
- for c in table.primary_key:
- assert reflected_table.primary_key.columns[c.name]
-
-
-class AssertsExecutionResults(object):
- def assert_result(self, result, class_, *objects):
- result = list(result)
- print repr(result)
- self.assert_list(result, class_, objects)
-
- def assert_list(self, result, class_, list):
- self.assert_(len(result) == len(list),
- "result list is not the same size as test list, " +
- "for class " + class_.__name__)
- for i in range(0, len(list)):
- self.assert_row(class_, result[i], list[i])
-
- def assert_row(self, class_, rowobj, desc):
- self.assert_(rowobj.__class__ is class_,
- "item class is not " + repr(class_))
- for key, value in desc.iteritems():
- if isinstance(value, tuple):
- if isinstance(value[1], list):
- self.assert_list(getattr(rowobj, key), value[0], value[1])
- else:
- self.assert_row(value[0], getattr(rowobj, key), value[1])
- else:
- self.assert_(getattr(rowobj, key) == value,
- "attribute %s value %s does not match %s" % (
- key, getattr(rowobj, key), value))
-
- def assert_unordered_result(self, result, cls, *expected):
- """As assert_result, but the order of objects is not considered.
-
- The algorithm is very expensive but not a big deal for the small
- numbers of rows that the test suite manipulates.
- """
-
- global util
- if util is None:
- from sqlalchemy import util
-
- class frozendict(dict):
- def __hash__(self):
- return id(self)
-
- found = util.IdentitySet(result)
- expected = set([frozendict(e) for e in expected])
-
- for wrong in itertools.ifilterfalse(lambda o: type(o) == cls, found):
- self.fail('Unexpected type "%s", expected "%s"' % (
- type(wrong).__name__, cls.__name__))
-
- if len(found) != len(expected):
- self.fail('Unexpected object count "%s", expected "%s"' % (
- len(found), len(expected)))
-
- NOVALUE = object()
- def _compare_item(obj, spec):
- for key, value in spec.iteritems():
- if isinstance(value, tuple):
- try:
- self.assert_unordered_result(
- getattr(obj, key), value[0], *value[1])
- except AssertionError:
- return False
- else:
- if getattr(obj, key, NOVALUE) != value:
- return False
- return True
-
- for expected_item in expected:
- for found_item in found:
- if _compare_item(found_item, expected_item):
- found.remove(found_item)
- break
- else:
- self.fail(
- "Expected %s instance with attributes %s not found." % (
- cls.__name__, repr(expected_item)))
- return True
-
- def assert_sql_execution(self, db, callable_, *rules):
- from testlib import assertsql
- assertsql.asserter.add_rules(rules)
- try:
- callable_()
- assertsql.asserter.statement_complete()
- finally:
- assertsql.asserter.clear_rules()
-
- def assert_sql(self, db, callable_, list_, with_sequences=None):
- from testlib import assertsql
-
- if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgres'):
- rules = with_sequences
- else:
- rules = list_
-
- newrules = []
- for rule in rules:
- if isinstance(rule, dict):
- newrule = assertsql.AllOf(*[
- assertsql.ExactSQL(k, v) for k, v in rule.iteritems()
- ])
- else:
- newrule = assertsql.ExactSQL(*rule)
- newrules.append(newrule)
-
- self.assert_sql_execution(db, callable_, *newrules)
-
- def assert_sql_count(self, db, callable_, count):
- from testlib import assertsql
- self.assert_sql_execution(db, callable_, assertsql.CountStatements(count))
-
-
-
class TTestSuite(unittest.TestSuite):
    """A TestSuite with once per TestCase setUpAll() and tearDownAll()

    Installed as unittest.TestLoader.suiteClass (see monkeypatch at the
    bottom of this module), so every loaded suite gets class-level
    setup/teardown plus the skip logic in __should_skip_for().
    """

    def __init__(self, tests=()):
        # Remember the first TestBase instance; its class provides the
        # setUpAll/tearDownAll hooks used in run() below.
        if len(tests) > 0 and isinstance(tests[0], TestBase):
            self._initTest = tests[0]
        else:
            self._initTest = None

        # Flag the first and last TestBase members so individual tests
        # can detect suite boundaries.
        for t in tests:
            if isinstance(t, TestBase):
                t._sa_first_test = True
                break
        for t in reversed(tests):
            if isinstance(t, TestBase):
                t._sa_last_test = True
                break
        unittest.TestSuite.__init__(self, tests)

    def run(self, result):
        # NOTE(review): the return values below are inconsistent (True on
        # skip, False on setup failure, the result object otherwise).
        # TextTestRunner ignores suite.run()'s return value, so this is
        # harmless, but callers should not rely on it.
        init = getattr(self, '_initTest', None)
        if init is not None:
            # A __whitelist__ hit bypasses all skip checks for this DB.
            if (hasattr(init, '__whitelist__') and
                config.db.name in init.__whitelist__):
                pass
            else:
                if self.__should_skip_for(init):
                    return True
            try:
                resetwarnings()
                init.setUpAll()
            except:
                # skip tests if global setup fails
                ex = self.__exc_info()
                for test in self._tests:
                    result.addError(test, ex)
                return False
        try:
            resetwarnings()
            for test in self._tests:
                if result.shouldStop:
                    break
                test(result)
            return result
        finally:
            # tearDownAll runs even when the loop above raised or stopped;
            # a teardown failure is recorded against the init test.
            try:
                resetwarnings()
                if init is not None:
                    init.tearDownAll()
            except:
                result.addError(init, self.__exc_info())
                pass

    def __should_skip_for(self, cls):
        # Returns True (and prints a reason) when the test class should
        # not run against the currently configured database.
        if hasattr(cls, '__requires__'):
            # Lazily bind the module-level `requires` on first use.
            global requires
            if requires is None:
                from testing import requires
            def test_suite(): return 'ok'
            for requirement in cls.__requires__:
                # Each requirement decorator wraps the probe callable;
                # anything other than 'ok' means it vetoed the run.
                check = getattr(requires, requirement)
                if check(test_suite)() != 'ok':
                    # The requirement will perform messaging.
                    return True
        if (hasattr(cls, '__unsupported_on__') and
            config.db.name in cls.__unsupported_on__):
            print "'%s' unsupported on DB implementation '%s'" % (
                cls.__class__.__name__, config.db.name)
            return True
        if (getattr(cls, '__only_on__', None) not in (None,config.db.name)):
            print "'%s' unsupported on DB implementation '%s'" % (
                cls.__class__.__name__, config.db.name)
            return True
        if (getattr(cls, '__skip_if__', False)):
            # __skip_if__ holds zero-argument predicates; any truthy
            # result skips the class.
            for c in getattr(cls, '__skip_if__'):
                if c():
                    print "'%s' skipped by %s" % (
                        cls.__class__.__name__, c.__name__)
                    return True
        for rule in getattr(cls, '__excluded_on__', ()):
            # Version-specific exclusions, e.g. ('mysql', '<', (5, 0)).
            if _is_excluded(*rule):
                print "'%s' unsupported on DB %s version %s" % (
                    cls.__class__.__name__, config.db.name,
                    _server_version())
                return True
        return False


    def __exc_info(self):
        """Return a version of sys.exc_info() with the traceback frame
        minimised; usually the top level of the traceback frame is not
        needed.
        ripped off out of unittest module since its double __
        """
        exctype, excvalue, tb = sys.exc_info()
        # NOTE(review): both branches below return the identical tuple;
        # the Jython special case is vestigial and the minimisation the
        # docstring promises no longer happens.
        if sys.platform[:4] == 'java': ## tracebacks look different in Jython
            return (exctype, excvalue, tb)
        return (exctype, excvalue, tb)

# monkeypatch
unittest.TestLoader.suiteClass = TTestSuite
-
-
class DevNullWriter(object):
    """A minimal file-like object that silently discards all output.

    Substituted for sys.stdout while tests run in non-verbose mode.
    """

    def write(self, msg):
        # Intentionally a no-op: the message is dropped.
        pass

    def flush(self):
        # Nothing is ever buffered, so there is nothing to flush.
        pass
-
def runTests(suite):
    """Run ``suite`` with a TextTestRunner, honoring verbose/quiet options.

    Unless full verbosity was requested (or quiet was set), stdout is
    swapped for a DevNullWriter during the run and restored afterwards.
    Returns the unittest result object.
    """
    verbose = config.options.verbose
    quiet = config.options.quiet
    suppress_output = not verbose or quiet
    saved_stdout = sys.stdout

    try:
        if suppress_output:
            sys.stdout = DevNullWriter()
        # quiet -> verbosity 1 (dots), otherwise 2 (per-test lines)
        runner = unittest.TextTestRunner(verbosity=(quiet and 1 or 2))
        return runner.run(suite)
    finally:
        if suppress_output:
            sys.stdout = saved_stdout
-
def main(suite=None):
    """Command-line entry point: build a test suite and run it.

    When ``suite`` is not supplied, test names given on the command
    line are loaded from the __main__ module; with no arguments the
    entire __main__ module is scanned.  Exits the process with a
    non-zero status if any test failed.
    """
    if not suite:
        loader = unittest.TestLoader()
        main_module = __import__('__main__')
        names = sys.argv[1:]
        if names:
            suite = loader.loadTestsFromNames(names, main_module)
        else:
            suite = loader.loadTestsFromModule(main_module)

    result = runTests(suite)
    sys.exit(not result.wasSuccessful())