diff options
Diffstat (limited to 'passlib/tests/backports.py')
-rw-r--r-- | passlib/tests/backports.py | 329 |
1 files changed, 329 insertions, 0 deletions
diff --git a/passlib/tests/backports.py b/passlib/tests/backports.py new file mode 100644 index 0000000..bde41cb --- /dev/null +++ b/passlib/tests/backports.py @@ -0,0 +1,329 @@ +"""backports of needed unittest2 features""" +#========================================================= +#imports +#========================================================= +from __future__ import with_statement +#core +import logging; log = logging.getLogger(__name__) +import re +import sys +##from warnings import warn +#site +#pkg +from passlib.utils.compat import base_string_types +#local +__all__ = [ + "TestCase", + "skip", "skipIf", "skipUnless" + "catch_warnings", +] + +#========================================================= +# import latest unittest module available +#========================================================= +try: + import unittest2 as unittest + ut_version = 2 +except ImportError: + import unittest + if sys.version_info < (2,7) or (3,0) <= sys.version_info < (3,2): + # older versions of python will need to install the unittest2 + # backport (named unittest2_3k for 3.0/3.1) + ##warn("please install unittest2 for python %d.%d, it will be required " + ## "as of passlib 1.x" % sys.version_info[:2]) + ut_version = 1 + else: + ut_version = 2 + +#========================================================= +# backport SkipTest support using nose +#========================================================= +if ut_version < 2: + # used to provide replacement SkipTest() error + from nose.plugins.skip import SkipTest + + # hack up something to simulate skip() decorator + import functools + def skip(reason): + def decorator(test_item): + if isinstance(test_item, type) and issubclass(test_item, unittest.TestCase): + class skip_wrapper(test_item): + def setUp(self): + raise SkipTest(reason) + else: + @functools.wraps(test_item) + def skip_wrapper(*args, **kwargs): + raise SkipTest(reason) + return skip_wrapper + return decorator + + def skipIf(condition, reason): + if condition: + return skip(reason) + else: + return lambda item: item + + def skipUnless(condition, reason): + if condition: + return lambda item: item + else: + return skip(reason) + +else: + skip = unittest.skip + skipIf = unittest.skipIf + skipUnless = unittest.skipUnless + +#========================================================= +# custom test harness +#========================================================= +class TestCase(unittest.TestCase): + """backports a number of unittest2 features in TestCase""" + #==================================================================== + # backport some methods from unittest2 + #==================================================================== + if ut_version < 2: + + #---------------------------------------------------------------- + # simplistic backport of addCleanup() framework + #---------------------------------------------------------------- + _cleanups = None + + def addCleanup(self, function, *args, **kwds): + queue = self._cleanups + if queue is None: + queue = self._cleanups = [] + queue.append((function, args, kwds)) + + def doCleanups(self): + queue = self._cleanups + while queue: + func, args, kwds = queue.pop() + func(*args, **kwds) + + def tearDown(self): + self.doCleanups() + unittest.TestCase.tearDown(self) + + #---------------------------------------------------------------- + # backport skipTest (requires nose to work) + #---------------------------------------------------------------- + def skipTest(self, reason): + raise SkipTest(reason) + + #---------------------------------------------------------------- + # backport various assert tests added in unittest2 + #---------------------------------------------------------------- + def assertIs(self, real, correct, msg=None): + if real is not correct: + std = "got %r, expected would be %r" % (real, correct) + msg = self._formatMessage(msg, std) + raise self.failureException(msg) + + def assertIsNot(self, real, correct, msg=None): + if real is correct: + std = "got %r, expected would not be %r" % (real, correct) + msg = self._formatMessage(msg, std) + raise self.failureException(msg) + + def assertIsInstance(self, obj, klass, msg=None): + if not isinstance(obj, klass): + std = "got %r, expected instance of %r" % (obj, klass) + msg = self._formatMessage(msg, std) + raise self.failureException(msg) + + def assertAlmostEqual(self, first, second, places=None, msg=None, delta=None): + """Fail if the two objects are unequal as determined by their + difference rounded to the given number of decimal places + (default 7) and comparing to zero, or by comparing that the + between the two objects is more than the given delta. + + Note that decimal places (from zero) are usually not the same + as significant digits (measured from the most signficant digit). + + If the two objects compare equal then they will automatically + compare almost equal. + """ + if first == second: + # shortcut + return + if delta is not None and places is not None: + raise TypeError("specify delta or places not both") + + if delta is not None: + if abs(first - second) <= delta: + return + + standardMsg = '%s != %s within %s delta' % (repr(first), + repr(second), + repr(delta)) + else: + if places is None: + places = 7 + + if round(abs(second-first), places) == 0: + return + + standardMsg = '%s != %s within %r places' % (repr(first), + repr(second), + places) + msg = self._formatMessage(msg, standardMsg) + raise self.failureException(msg) + + def assertLess(self, left, right, msg=None): + if left >= right: + std = "%r not less than %r" % (left, right) + raise self.failureException(self._formatMessage(msg, std)) + + def assertGreater(self, left, right, msg=None): + if left <= right: + std = "%r not greater than %r" % (left, right) + raise self.failureException(self._formatMessage(msg, std)) + + def assertGreaterEqual(self, left, right, msg=None): + if left < right: + std = "%r less than %r" % (left, right) + raise self.failureException(self._formatMessage(msg, std)) + + def assertIn(self, elem, container, msg=None): + if elem not in container: + std = "%r not found in %r" % (elem, container) + raise self.failureException(self._formatMessage(msg, std)) + + def assertNotIn(self, elem, container, msg=None): + if elem in container: + std = "%r unexpectedly in %r" % (elem, container) + raise self.failureException(self._formatMessage(msg, std)) + + #---------------------------------------------------------------- + # override some unittest1 methods to support _formatMessage + #---------------------------------------------------------------- + def assertEqual(self, real, correct, msg=None): + if real != correct: + std = "got %r, expected would equal %r" % (real, correct) + msg = self._formatMessage(msg, std) + raise self.failureException(msg) + + def assertNotEqual(self, real, correct, msg=None): + if real == correct: + std = "got %r, expected would not equal %r" % (real, correct) + msg = self._formatMessage(msg, std) + raise self.failureException(msg) + + #---------------------------------------------------------------- + # backport assertRegex() alias from 3.2 to 2.7/3.1 + #---------------------------------------------------------------- + if not hasattr(unittest.TestCase, "assertRegex"): + if hasattr(unittest.TestCase, "assertRegexpMatches"): + # was present in 2.7/3.1 under name assertRegexpMatches + assertRegex = unittest.TestCase.assertRegexpMatches + else: + # 3.0 and <= 2.6 didn't have this method at all + def assertRegex(self, text, expected_regex, msg=None): + """Fail the test unless the text matches the regular expression.""" + if isinstance(expected_regex, base_string_types): + assert expected_regex, "expected_regex must not be empty." + expected_regex = re.compile(expected_regex) + if not expected_regex.search(text): + msg = msg or "Regex didn't match: " + std = '%r not found in %r' % (msg, expected_regex.pattern, text) + raise self.failureException(self._formatMessage(msg, std)) + + #============================================================ + #eoc + #============================================================ + +#============================================================================= +# backport catch_warnings +#============================================================================= +try: + from warnings import catch_warnings +except ImportError: + # catch_warnings wasn't added until py26. + # this adds backported copy from py26's stdlib + # so we can use it under py25. + + class WarningMessage(object): + + """Holds the result of a single showwarning() call.""" + + _WARNING_DETAILS = ("message", "category", "filename", "lineno", "file", + "line") + + def __init__(self, message, category, filename, lineno, file=None, + line=None): + local_values = locals() + for attr in self._WARNING_DETAILS: + setattr(self, attr, local_values[attr]) + self._category_name = category.__name__ if category else None + + def __str__(self): + return ("{message : %r, category : %r, filename : %r, lineno : %s, " + "line : %r}" % (self.message, self._category_name, + self.filename, self.lineno, self.line)) + + + class catch_warnings(object): + + """A context manager that copies and restores the warnings filter upon + exiting the context. + + The 'record' argument specifies whether warnings should be captured by a + custom implementation of warnings.showwarning() and be appended to a list + returned by the context manager. Otherwise None is returned by the context + manager. The objects appended to the list are arguments whose attributes + mirror the arguments to showwarning(). + + The 'module' argument is to specify an alternative module to the module + named 'warnings' and imported under that name. This argument is only useful + when testing the warnings module itself. + + """ + + def __init__(self, record=False, module=None): + """Specify whether to record warnings and if an alternative module + should be used other than sys.modules['warnings']. + + For compatibility with Python 3.0, please consider all arguments to be + keyword-only. + + """ + self._record = record + self._module = sys.modules['warnings'] if module is None else module + self._entered = False + + def __repr__(self): + args = [] + if self._record: + args.append("record=True") + if self._module is not sys.modules['warnings']: + args.append("module=%r" % self._module) + name = type(self).__name__ + return "%s(%s)" % (name, ", ".join(args)) + + def __enter__(self): + if self._entered: + raise RuntimeError("Cannot enter %r twice" % self) + self._entered = True + self._filters = self._module.filters + self._module.filters = self._filters[:] + self._showwarning = self._module.showwarning + if self._record: + log = [] + def showwarning(*args, **kwargs): +# self._showwarning(*args, **kwargs) + log.append(WarningMessage(*args, **kwargs)) + self._module.showwarning = showwarning + return log + else: + return None + + def __exit__(self, *exc_info): + if not self._entered: + raise RuntimeError("Cannot exit %r without entering first" % self) + self._module.filters = self._filters + self._module.showwarning = self._showwarning + +#============================================================================= +# eof +#============================================================================= |