summaryrefslogtreecommitdiff
path: root/passlib/tests
diff options
context:
space:
mode:
authorEli Collins <elic@assurancetechnologies.com>2012-03-12 21:33:28 -0400
committerEli Collins <elic@assurancetechnologies.com>2012-03-12 21:33:28 -0400
commit3913a59ad033462e6a389544ffcdf8055db7ad9c (patch)
tree63dda089e61a9d8ef4b468a323df8c2ec2ad6c70 /passlib/tests
parentb970d6ee145122005f1e6808466900a94e00dfcc (diff)
downloadpasslib-3913a59ad033462e6a389544ffcdf8055db7ad9c.tar.gz
updated test support & py3 compat code from an external library
passlib.tests ------------- * deprecated support for unittest 1... accumulated too many backports, planning to require unittest2 in next release. * case_prefix renamed to shortDescription * test case now archives & clears warning registry state in addition to warning filter state passlib.utils.compat -------------------- * a bunch of the bytes-related functions were renamed for clarity * NativeStringIO alias added * trange alias merged into irange
Diffstat (limited to 'passlib/tests')
-rw-r--r--passlib/tests/genconfig.py2
-rw-r--r--passlib/tests/test_apache.py4
-rw-r--r--passlib/tests/test_context.py26
-rw-r--r--passlib/tests/test_ext_django.py8
-rw-r--r--passlib/tests/test_handlers.py4
-rw-r--r--passlib/tests/test_registry.py8
-rw-r--r--passlib/tests/test_utils.py20
-rw-r--r--passlib/tests/test_utils_handlers.py10
-rw-r--r--passlib/tests/utils.py431
9 files changed, 290 insertions, 223 deletions
diff --git a/passlib/tests/genconfig.py b/passlib/tests/genconfig.py
index dc87d02..83f016b 100644
--- a/passlib/tests/genconfig.py
+++ b/passlib/tests/genconfig.py
@@ -77,7 +77,7 @@ class HashTimer(object):
#
self.samples = samples
self.cache = {}
- self.srange = trange(samples)
+ self.srange = irange(samples)
def time_encrypt(self, rounds):
"check how long encryption for a given number of rounds will take"
diff --git a/passlib/tests/test_apache.py b/passlib/tests/test_apache.py
index 052c42e..d3b4ab8 100644
--- a/passlib/tests/test_apache.py
+++ b/passlib/tests/test_apache.py
@@ -30,7 +30,7 @@ def backdate_file_mtime(path, offset=10):
#=========================================================
class HtpasswdFileTest(TestCase):
"test HtpasswdFile class"
- case_prefix = "HtpasswdFile"
+ descriptionPrefix = "HtpasswdFile"
sample_01 = b('user2:2CHkkwa2AtqGs\nuser3:{SHA}3ipNV1GrBtxPmHFC21fCbVCSXIo=\nuser4:pass4\nuser1:$apr1$t4tc7jTh$GPIWVUo8sQKJlUdV8V5vu0\n')
sample_02 = b('user3:{SHA}3ipNV1GrBtxPmHFC21fCbVCSXIo=\nuser4:pass4\n')
@@ -205,7 +205,7 @@ class HtpasswdFileTest(TestCase):
#=========================================================
class HtdigestFileTest(TestCase):
"test HtdigestFile class"
- case_prefix = "HtdigestFile"
+ descriptionPrefix = "HtdigestFile"
sample_01 = b('user2:realm:549d2a5f4659ab39a80dac99e159ab19\nuser3:realm:a500bb8c02f6a9170ae46af10c898744\nuser4:realm:ab7b5d5f28ccc7666315f508c7358519\nuser1:realm:2a6cf53e7d8f8cf39d946dc880b14128\n')
sample_02 = b('user3:realm:a500bb8c02f6a9170ae46af10c898744\nuser4:realm:ab7b5d5f28ccc7666315f508c7358519\n')
diff --git a/passlib/tests/test_context.py b/passlib/tests/test_context.py
index 46cd3f5..f9edafc 100644
--- a/passlib/tests/test_context.py
+++ b/passlib/tests/test_context.py
@@ -22,7 +22,7 @@ from passlib.exc import PasslibConfigWarning
from passlib.utils import tick, to_bytes, to_unicode
from passlib.utils.compat import irange, u
import passlib.utils.handlers as uh
-from passlib.tests.utils import TestCase, mktemp, catch_all_warnings, \
+from passlib.tests.utils import TestCase, mktemp, catch_warnings, \
gae_env, set_file
from passlib.registry import register_crypt_handler_path, has_crypt_handler, \
_unload_handler_name as unload_handler_name
@@ -37,7 +37,7 @@ class CryptPolicyTest(TestCase):
#TODO: need to test user categories w/in all this
- case_prefix = "CryptPolicy"
+ descriptionPrefix = "CryptPolicy"
#=========================================================
#sample crypt policies used for testing
@@ -472,11 +472,7 @@ admin__context__deprecated = des_crypt, bsdi_crypt
def test_15_min_verify_time(self):
"test get_min_verify_time() method"
# silence deprecation warnings for min verify time
- with catch_all_warnings():
- warnings.filterwarnings("ignore", category=DeprecationWarning)
- self._test_15()
-
- def _test_15(self):
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
pa = CryptPolicy()
self.assertEqual(pa.get_min_verify_time(), 0)
@@ -526,7 +522,7 @@ admin__context__deprecated = des_crypt, bsdi_crypt
#=========================================================
class CryptContextTest(TestCase):
"test CryptContext object's behavior"
- case_prefix = "CryptContext"
+ descriptionPrefix = "CryptContext"
#=========================================================
#constructor
@@ -632,7 +628,7 @@ class CryptContextTest(TestCase):
)
# min rounds
- with catch_all_warnings(record=True) as wlog:
+ with catch_warnings(record=True) as wlog:
# set below handler min
c2 = cc.replace(all__min_rounds=500, all__max_rounds=None,
@@ -663,7 +659,7 @@ class CryptContextTest(TestCase):
self.consumeWarningList(wlog)
# max rounds
- with catch_all_warnings(record=True) as wlog:
+ with catch_warnings(record=True) as wlog:
# set above handler max
c2 = cc.replace(all__max_rounds=int(1e9)+500, all__min_rounds=None,
all__default_rounds=int(1e9)+500)
@@ -824,7 +820,7 @@ class CryptContextTest(TestCase):
# which is much cheaper, and shares the same codebase.
# min rounds
- with catch_all_warnings(record=True) as wlog:
+ with catch_warnings(record=True) as wlog:
self.assertEqual(
cc.encrypt("password", rounds=1999, salt="nacl"),
'$5$rounds=2000$nacl$9/lTZ5nrfPuz8vphznnmHuDGFuvjSNvOEDsGmGfsS97',
@@ -950,7 +946,7 @@ class CryptContextTest(TestCase):
return to_unicode(secret + 'x')
# silence deprecation warnings for min verify time
- with catch_all_warnings(record=True) as wlog:
+ with catch_warnings(record=True) as wlog:
cc = CryptContext([TimedHash], min_verify_time=min_verify_time)
self.consumeWarningList(wlog, DeprecationWarning)
@@ -977,7 +973,7 @@ class CryptContextTest(TestCase):
#ensure taking longer emits a warning.
TimedHash.delay = max_delay
- with catch_all_warnings(record=True) as wlog:
+ with catch_warnings(record=True) as wlog:
elapsed, result = timecall(cc.verify, "blob", "stubx")
self.assertFalse(result)
self.assertAlmostEqual(elapsed, max_delay, delta=delta)
@@ -1024,7 +1020,7 @@ class CryptContextTest(TestCase):
GOOD1 = "$2a$12$oaQbBqq8JnSM1NHRPQGXOOm4GCUMqp7meTnkft4zgSnrbhoKdDV0C"
ctx = CryptContext(["bcrypt"])
- with catch_all_warnings(record=True) as wlog:
+ with catch_warnings(record=True) as wlog:
self.assertTrue(ctx.hash_needs_update(BAD1))
self.assertFalse(ctx.hash_needs_update(GOOD1))
@@ -1101,7 +1097,7 @@ class dummy_2(uh.StaticHandler):
name = "dummy_2"
class LazyCryptContextTest(TestCase):
- case_prefix = "LazyCryptContext"
+ descriptionPrefix = "LazyCryptContext"
def setUp(self):
unload_handler_name("dummy_2")
diff --git a/passlib/tests/test_ext_django.py b/passlib/tests/test_ext_django.py
index 4627d30..13085fc 100644
--- a/passlib/tests/test_ext_django.py
+++ b/passlib/tests/test_ext_django.py
@@ -13,7 +13,7 @@ from passlib.context import CryptContext, CryptPolicy
from passlib.apps import django_context
from passlib.ext.django import utils
from passlib.hash import sha256_crypt
-from passlib.tests.utils import TestCase, unittest, ut_version, catch_all_warnings
+from passlib.tests.utils import TestCase, unittest, ut_version, catch_warnings
import passlib.tests.test_handlers as th
from passlib.utils.compat import iteritems, get_method_function, unicode
from passlib.registry import get_crypt_handler
@@ -130,7 +130,7 @@ def get_cc_rounds(**kwds):
class PatchTest(TestCase):
"test passlib.ext.django.utils:set_django_password_context"
- case_prefix = "passlib.ext.django utils"
+ descriptionPrefix = "passlib.ext.django utils"
def assert_unpatched(self):
"helper to ensure django hasn't been patched"
@@ -223,7 +223,7 @@ class PatchTest(TestCase):
def dummy():
pass
- with catch_all_warnings(record=True) as wlog:
+ with catch_warnings(record=True) as wlog:
#patch to use stock django context
utils.set_django_password_context(django_context)
self.assert_patched(context=django_context)
@@ -422,7 +422,7 @@ if has_django0:
class PluginTest(TestCase):
"test django plugin via settings"
- case_prefix = "passlib.ext.django plugin"
+ descriptionPrefix = "passlib.ext.django plugin"
def setUp(self):
#remove django patch
diff --git a/passlib/tests/test_handlers.py b/passlib/tests/test_handlers.py
index 580cc08..4547e56 100644
--- a/passlib/tests/test_handlers.py
+++ b/passlib/tests/test_handlers.py
@@ -12,7 +12,7 @@ import warnings
from passlib import hash
from passlib.utils.compat import irange
from passlib.tests.utils import TestCase, HandlerCase, create_backend_case, \
- enable_option, b, catch_all_warnings, UserHandlerMixin, randintgauss
+ enable_option, b, catch_warnings, UserHandlerMixin, randintgauss
from passlib.utils.compat import u
#module
@@ -262,7 +262,7 @@ class _bcrypt_test(HandlerCase):
check_padding(bcrypt.encrypt("bob", rounds=bcrypt.min_rounds))
# some things that will raise warnings
- with catch_all_warnings(record=True) as wlog:
+ with catch_warnings(record=True) as wlog:
#
# test genconfig() corrects invalid salts & issues warning.
#
diff --git a/passlib/tests/test_registry.py b/passlib/tests/test_registry.py
index c6fa9f6..1990919 100644
--- a/passlib/tests/test_registry.py
+++ b/passlib/tests/test_registry.py
@@ -16,7 +16,7 @@ from passlib import hash, registry
from passlib.registry import register_crypt_handler, register_crypt_handler_path, \
get_crypt_handler, list_crypt_handlers, _unload_handler_name as unload_handler_name
import passlib.utils.handlers as uh
-from passlib.tests.utils import TestCase, mktemp, catch_all_warnings
+from passlib.tests.utils import TestCase, mktemp, catch_warnings
#module
log = getLogger(__name__)
@@ -40,7 +40,7 @@ dummy_x = 1
#=========================================================
class RegistryTest(TestCase):
- case_prefix = "passlib registry"
+ descriptionPrefix = "passlib registry"
def tearDown(self):
for name in ("dummy_0", "dummy_1", "dummy_x", "dummy_bad"):
@@ -112,7 +112,7 @@ class RegistryTest(TestCase):
#TODO: check lazy load which calls register_crypt_handler (warning should be issued)
sys.modules.pop("passlib.tests._test_bad_register", None)
register_crypt_handler_path("dummy_bad", "passlib.tests._test_bad_register")
- with catch_all_warnings():
+ with catch_warnings():
warnings.filterwarnings("ignore", "xxxxxxxxxx", DeprecationWarning)
h = get_crypt_handler("dummy_bad")
from passlib.tests import _test_bad_register as tbr
@@ -158,7 +158,7 @@ class RegistryTest(TestCase):
register_crypt_handler(dummy_1)
self.assertIs(get_crypt_handler("dummy_1"), dummy_1)
- with catch_all_warnings():
+ with catch_warnings():
warnings.filterwarnings("ignore", "handler names should be lower-case, and use underscores instead of hyphens:.*", UserWarning)
self.assertIs(get_crypt_handler("DUMMY-1"), dummy_1)
diff --git a/passlib/tests/test_utils.py b/passlib/tests/test_utils.py
index 1baebd7..2d9d147 100644
--- a/passlib/tests/test_utils.py
+++ b/passlib/tests/test_utils.py
@@ -12,8 +12,8 @@ import warnings
#pkg
#module
from passlib.utils.compat import b, bytes, bascii_to_str, irange, PY2, PY3, u, \
- unicode, bjoin
-from passlib.tests.utils import TestCase, Params as ak, enable_option, catch_all_warnings
+ unicode, join_bytes
+from passlib.tests.utils import TestCase, Params as ak, enable_option, catch_warnings
def hb(source):
return unhexlify(b(source))
@@ -546,7 +546,7 @@ class _Base64Test(TestCase):
# helper to generate bytemap-specific strings
def m(self, *offsets):
"generate byte string from offsets"
- return bjoin(self.engine.bytemap[o:o+1] for o in offsets)
+ return join_bytes(self.engine.bytemap[o:o+1] for o in offsets)
#=========================================================
# test encode_bytes
@@ -842,7 +842,7 @@ from passlib.utils import h64, h64big
class H64_Test(_Base64Test):
"test H64 codec functions"
engine = h64
- case_prefix = "h64 codec"
+ descriptionPrefix = "h64 codec"
encoded_data = [
#test lengths 0..6 to ensure tail is encoded properly
@@ -867,7 +867,7 @@ class H64_Test(_Base64Test):
class H64Big_Test(_Base64Test):
"test H64Big codec functions"
engine = h64big
- case_prefix = "h64big codec"
+ descriptionPrefix = "h64big codec"
encoded_data = [
#test lengths 0..6 to ensure tail is encoded properly
@@ -960,12 +960,12 @@ has_ssl_md4 = (md4_mod.md4 is not md4_mod._builtin_md4)
if has_ssl_md4:
class MD4_SSL_Test(_MD4_Test):
- case_prefix = "MD4 (SSL version)"
+ descriptionPrefix = "MD4 (SSL version)"
hash = staticmethod(md4_mod.md4)
if not has_ssl_md4 or enable_option("cover"):
class MD4_Builtin_Test(_MD4_Test):
- case_prefix = "MD4 (builtin version)"
+ descriptionPrefix = "MD4 (builtin version)"
hash = md4_mod._builtin_md4
#=========================================================
@@ -1010,7 +1010,7 @@ class CryptoTest(TestCase):
self.assertRaises(TypeError, norm_hash_name, None)
# test selected results
- with catch_all_warnings():
+ with catch_warnings():
warnings.filterwarnings("ignore", '.*unknown hash')
for row in chain(_nhn_hash_names, self.ndn_values):
for idx, format in enumerate(self.ndn_formats):
@@ -1227,12 +1227,12 @@ has_m2crypto = (pbkdf2._EVP is not None)
if has_m2crypto:
class Pbkdf2_M2Crypto_Test(_Pbkdf2BackendTest):
- case_prefix = "pbkdf2 (m2crypto backend)"
+ descriptionPrefix = "pbkdf2 (m2crypto backend)"
enable_m2crypto = True
if not has_m2crypto or enable_option("cover"):
class Pbkdf2_Builtin_Test(_Pbkdf2BackendTest):
- case_prefix = "pbkdf2 (builtin backend)"
+ descriptionPrefix = "pbkdf2 (builtin backend)"
enable_m2crypto = False
#=========================================================
diff --git a/passlib/tests/test_utils_handlers.py b/passlib/tests/test_utils_handlers.py
index 7079917..4788798 100644
--- a/passlib/tests/test_utils_handlers.py
+++ b/passlib/tests/test_utils_handlers.py
@@ -18,7 +18,7 @@ from passlib.utils import getrandstr, JYTHON, rng, to_unicode
from passlib.utils.compat import b, bytes, bascii_to_str, str_to_uascii, \
uascii_to_str, unicode, PY_MAX_25
import passlib.utils.handlers as uh
-from passlib.tests.utils import HandlerCase, TestCase, catch_all_warnings
+from passlib.tests.utils import HandlerCase, TestCase, catch_warnings
from passlib.utils.compat import u
#module
log = getLogger(__name__)
@@ -190,7 +190,7 @@ class SkeletonTest(TestCase):
self.assertIn(norm_salt(use_defaults=True), salts3)
# check explicit salts
- with catch_all_warnings(record=True) as wlog:
+ with catch_warnings(record=True) as wlog:
# check too-small salts
self.assertRaises(ValueError, norm_salt, salt='')
@@ -211,7 +211,7 @@ class SkeletonTest(TestCase):
self.consumeWarningList(wlog, PasslibHashWarning)
#check generated salts
- with catch_all_warnings(record=True) as wlog:
+ with catch_warnings(record=True) as wlog:
# check too-small salt size
self.assertRaises(ValueError, gen_salt, 0)
@@ -233,7 +233,7 @@ class SkeletonTest(TestCase):
# test with max_salt_size=None
del d1.max_salt_size
- with catch_all_warnings(record=True) as wlog:
+ with catch_warnings(record=True) as wlog:
self.assertEqual(len(gen_salt(None)), 3)
self.assertEqual(len(gen_salt(5)), 5)
self.consumeWarningList(wlog)
@@ -259,7 +259,7 @@ class SkeletonTest(TestCase):
self.assertEqual(norm_rounds(use_defaults=True), 2)
# check explicit rounds
- with catch_all_warnings(record=True) as wlog:
+ with catch_warnings(record=True) as wlog:
# too small
self.assertRaises(ValueError, norm_rounds, rounds=0)
self.consumeWarningList(wlog)
diff --git a/passlib/tests/utils.py b/passlib/tests/utils.py
index 2f36d3e..8e2883b 100644
--- a/passlib/tests/utils.py
+++ b/passlib/tests/utils.py
@@ -22,7 +22,10 @@ except ImportError:
if PY27 or PY_MIN_32:
ut_version = 2
else:
- # XXX: issue warning and deprecate support sometime?
+ # older versions of python will need to install the unittest2
+ # backport (named unittest2_3k for 3.0/3.1)
+ warn("please install unittest2 for python %d.%d, it will be required "
+ "as of passlib 1.7" % sys.version_info[:2])
ut_version = 1
import warnings
@@ -136,99 +139,88 @@ def get_file(path):
class TestCase(unittest.TestCase):
"""passlib-specific test case class
- this class mainly overriddes many of the common assert methods
- so to give a default message which includes the values
- as well as the class-specific case_prefix string.
- this latter bit makes the output of various test cases
- easier to distinguish from eachother.
+ this class adds a number of features to the standard TestCase...
+ * common prefix for all test descriptions
+ * resets warnings filter & registry for every test
+ * tweaks to message formatting
+ * __msg__ kwd added to assertRaises()
+ * backport of a bunch of unittest2 features
+ * suite of methods for matching against warnings
"""
+ #====================================================================
+ # add various custom features
+ #====================================================================
- #=============================================================
- #make it ease for test cases to add common prefix to all descs
- #=============================================================
- #: string or method returning string - prepended to all tests in TestCase
- case_prefix = None
+ #----------------------------------------------------------------
+ # make it easy for test cases to add common prefix to shortDescription
+ #----------------------------------------------------------------
- #: flag to disable feature
- longDescription = True
+ # string prepended to all tests in TestCase
+ descriptionPrefix = None
def shortDescription(self):
- "wrap shortDescription() method to prepend case_prefix"
+ "wrap shortDescription() method to prepend descriptionPrefix"
desc = super(TestCase, self).shortDescription()
- if desc is None:
- #would still like to add prefix, but munges things up.
- return None
- prefix = self.case_prefix
- if prefix and self.longDescription:
- if callable(prefix):
- prefix = prefix()
- desc = "%s: %s" % (prefix, desc)
+ prefix = self.descriptionPrefix
+ if prefix:
+ desc = "%s: %s" % (prefix, desc or str(self))
return desc
- #============================================================
- #hack to set UT2 private skip attrs to mirror nose's __test__ attr
- #============================================================
- if ut_version >= 2:
+ #----------------------------------------------------------------
+ # hack things so nose and ut2 both skip subclasses who have
+ # "__unittest_skip=True" set, or whose names start with "_"
+ #----------------------------------------------------------------
+ @classproperty
+ def __unittest_skip__(cls):
+ # NOTE: this attr is technically a unittest2 internal detail.
+ name = cls.__name__
+ return name.startswith("_") or \
+ getattr(cls, "_%s__unittest_skip" % name, False)
- @classproperty
- def __unittest_skip__(cls):
- # make this mirror nose's '__test__' attr
- return not getattr(cls, "__test__", True)
+ # make this mirror nose's '__test__' attr
+ return not getattr(cls, "__test__", True)
@classproperty
def __test__(cls):
- # nose uses to this to skip tests. overridding this to
- # skip classes with '__<cls>_unittest_skip' set - that way
- # we can omit specific classes without affecting subclasses.
- name = cls.__name__
- if name.startswith("_"):
- return False
- if getattr(cls, "_%s__unittest_skip" % name, False):
- return False
- return True
+ # make nose just proxy __unittest_skip__
+ return not cls.__unittest_skip__
+ # flag to skip *this* class
__unittest_skip = True
- #============================================================
- # tweak msg formatting for some assert methods
- #============================================================
- longMessage = True #override python default (False)
-
- def _formatMessage(self, msg, std):
- "override UT2's _formatMessage - only use longMessage if msg ends with ':'"
- if not msg:
- return std
- if not self.longMessage or not msg.endswith(":"):
- return msg.rstrip(":")
- return '%s %s' % (msg, std)
+ #----------------------------------------------------------------
+ # reset warning filters & registry before each test
+ #----------------------------------------------------------------
- #============================================================
- #override some unittest1 methods to support _formatMessage
- #============================================================
- if ut_version < 2:
+ # flag to enable this feature
+ resetWarningState = True
- def assertEqual(self, real, correct, msg=None):
- if real != correct:
- std = "got %r, expected would equal %r" % (real, correct)
- msg = self._formatMessage(msg, std)
- raise self.failureException(msg)
-
- def assertNotEqual(self, real, correct, msg=None):
- if real == correct:
- std = "got %r, expected would not equal %r" % (real, correct)
- msg = self._formatMessage(msg, std)
- raise self.failureException(msg)
+ def setUp(self):
+ unittest.TestCase.setUp(self)
+ if self.resetWarningState:
+ ctx = reset_warnings()
+ ctx.__enter__()
+ self.addCleanup(ctx.__exit__)
+
+ #----------------------------------------------------------------
+ # tweak message formatting so longMessage mode is only enabled
+ # if msg ends with ":", and turn on longMessage by default.
+ #----------------------------------------------------------------
+ longMessage = True
- assertEquals = assertEqual
- assertNotEquals = assertNotEqual
+ def _formatMessage(self, msg, std):
+ if self.longMessage and msg and msg.rstrip().endswith(":"):
+ return '%s %s' % (msg.rstrip(), std)
+ else:
+ return msg or std
- #NOTE: overriding this even under UT2.
- #FIXME: this doesn't support the fancy context manager UT2 provides.
- def assertRaises(self, _exc_type, _callable, *args, **kwds):
- #NOTE: overriding this for format ability,
- # but ALSO adding "__msg__" kwd so we can set custom msg
+ #----------------------------------------------------------------
+ # override assertRaises() to support '__msg__' keyword
+ #----------------------------------------------------------------
+ def assertRaises(self, _exc_type, _callable=None, *args, **kwds):
msg = kwds.pop("__msg__", None)
if _callable is None:
+ # FIXME: this ignores 'msg'
return super(TestCase, self).assertRaises(_exc_type, None,
*args, **kwds)
try:
@@ -239,11 +231,46 @@ class TestCase(unittest.TestCase):
_exc_type)
raise self.failureException(self._formatMessage(msg, std))
- #===============================================================
- #backport some methods from unittest2
- #===============================================================
+ #----------------------------------------------------------------
+ # null out a bunch of deprecated aliases so I stop using them
+ #----------------------------------------------------------------
+ assertEquals = assertNotEquals = assertRegexpMatches = None
+
+ #====================================================================
+ # backport some methods from unittest2
+ #====================================================================
if ut_version < 2:
+ #----------------------------------------------------------------
+ # simplistic backport of addCleanup() framework
+ #----------------------------------------------------------------
+ _cleanups = None
+
+ def addCleanup(self, function, *args, **kwds):
+ queue = self._cleanups
+ if queue is None:
+ queue = self._cleanups = []
+ queue.append((function, args, kwds))
+
+ def doCleanups(self):
+ queue = self._cleanups
+ while queue:
+ func, args, kwds = queue.pop()
+ func(*args, **kwds)
+
+ def tearDown(self):
+ self.doCleanups()
+ unittest.TestCase.tearDown(self)
+
+ #----------------------------------------------------------------
+ # backport skipTest (requires nose to work)
+ #----------------------------------------------------------------
+ def skipTest(self, reason):
+ raise SkipTest(reason)
+
+ #----------------------------------------------------------------
+ # backport various assert tests added in unittest2
+ #----------------------------------------------------------------
def assertIs(self, real, correct, msg=None):
if real is not correct:
std = "got %r, expected would be %r" % (real, correct)
@@ -262,9 +289,6 @@ class TestCase(unittest.TestCase):
msg = self._formatMessage(msg, std)
raise self.failureException(msg)
- def skipTest(self, reason):
- raise SkipTest(reason)
-
def assertAlmostEqual(self, first, second, places=None, msg=None, delta=None):
"""Fail if the two objects are unequal as determined by their
difference rounded to the given number of decimal places
@@ -303,45 +327,68 @@ class TestCase(unittest.TestCase):
msg = self._formatMessage(msg, standardMsg)
raise self.failureException(msg)
+ def assertLess(self, left, right, msg=None):
+ if left >= right:
+ std = "%r not less than %r" % (left, right)
+ raise self.failureException(self._formatMessage(msg, std))
+
+ def assertGreaterEqual(self, left, right, msg=None):
+ if left < right:
+ std = "%r less than %r" % (left, right)
+ raise self.failureException(self._formatMessage(msg, std))
+
+ def assertIn(self, elem, container, msg=None):
+ if elem not in container:
+ std = "%r not found in %r" % (elem, container)
+ raise self.failureException(self._formatMessage(msg, std))
+
+ def assertNotIn(self, elem, container, msg=None):
+ if elem in container:
+ std = "%r unexpectedly in %r" % (elem, container)
+ raise self.failureException(self._formatMessage(msg, std))
+
+ #----------------------------------------------------------------
+ # override some unittest1 methods to support _formatMessage
+ #----------------------------------------------------------------
+ def assertEqual(self, real, correct, msg=None):
+ if real != correct:
+ std = "got %r, expected would equal %r" % (real, correct)
+ msg = self._formatMessage(msg, std)
+ raise self.failureException(msg)
+
+ def assertNotEqual(self, real, correct, msg=None):
+ if real == correct:
+ std = "got %r, expected would not equal %r" % (real, correct)
+ msg = self._formatMessage(msg, std)
+ raise self.failureException(msg)
+
+ #----------------------------------------------------------------
+ # backport assertRegex() alias from 3.2 to 2.7/3.1
+ #----------------------------------------------------------------
if not hasattr(unittest.TestCase, "assertRegex"):
- # assertRegexpMatches() added in 2.7/UT2 and 3.1, renamed to
- # assertRegex() in 3.2; this code ensures assertRegex() is defined.
if hasattr(unittest.TestCase, "assertRegexpMatches"):
+ # was present in 2.7/3.1 under name assertRegexpMatches
assertRegex = unittest.TestCase.assertRegexpMatches
else:
+ # 3.0 and <= 2.6 didn't have this method at all
def assertRegex(self, text, expected_regex, msg=None):
"""Fail the test unless the text matches the regular expression."""
if isinstance(expected_regex, sb_types):
assert expected_regex, "expected_regex must not be empty."
expected_regex = re.compile(expected_regex)
if not expected_regex.search(text):
- msg = msg or "Regex didn't match"
- msg = '%s: %r not found in %r' % (msg, expected_regex.pattern, text)
- raise self.failureException(msg)
+ msg = msg or "Regex didn't match: "
+ std = '%r not found in %r' % (msg, expected_regex.pattern, text)
+ raise self.failureException(self._formatMessage(msg, std))
#============================================================
- #add some custom methods
+ # custom methods for matching warnings
#============================================================
- def assertFunctionResults(self, func, cases):
- """helper for running through function calls.
-
- func should be the function to call.
- cases should be list of Param instances,
- where first position argument is expected return value,
- and remaining args and kwds are passed to function.
- """
- for elem in cases:
- elem = Params.norm(elem)
- correct = elem.args[0]
- result = func(*elem.args[1:], **elem.kwds)
- msg = "error for case %r:" % (elem.render(1),)
- self.assertEqual(result, correct, msg)
-
def assertWarning(self, warning,
message_re=None, message=None,
category=None,
- ##filename=None, filename_re=None,
- ##lineno=None,
+ filename_re=None, filename=None,
+ lineno=None,
msg=None,
):
"check if WarningMessage instance (as returned by catch_warnings) matches parameters"
@@ -355,7 +402,7 @@ class TestCase(unittest.TestCase):
# no original WarningMessage, passed raw Warning
wmsg = None
- #tests that can use a warning instance or WarningMessage object
+ # tests that can use a warning instance or WarningMessage object
if message:
self.assertEqual(str(warning), message, msg)
if message_re:
@@ -363,29 +410,29 @@ class TestCase(unittest.TestCase):
if category:
self.assertIsInstance(warning, category, msg)
- #commented out until needed...
- ###tests that require a WarningMessage object
- ##if filename or filename_re:
- ## if not wmsg:
- ## raise TypeError("can't read filename from warning object")
- ## real = wmsg.filename
- ## if real.endswith(".pyc") or real.endswith(".pyo"):
- ## #FIXME: should use a stdlib call to resolve this back
- ## # to original module's path
- ## real = real[:-1]
- ## if filename:
- ## self.assertEqual(real, filename, msg)
- ## if filename_re:
- ## self.assertRegex(real, filename_re, msg)
- ##if lineno:
- ## if not wmsg:
- ## raise TypeError("can't read lineno from warning object")
- ## self.assertEqual(wmsg.lineno, lineno, msg)
+ # tests that require a WarningMessage object
+ if filename or filename_re:
+ if not wmsg:
+ raise TypeError("matching on filename requires a "
+ "WarningMessage instance")
+ real = wmsg.filename
+ if real.endswith(".pyc") or real.endswith(".pyo"):
+ # FIXME: should use a stdlib call to resolve this back
+ # to module's original filename.
+ real = real[:-1]
+ if filename:
+ self.assertEqual(real, filename, msg)
+ if filename_re:
+ self.assertRegex(real, filename_re, msg)
+ if lineno:
+ if not wmsg:
+ raise TypeError("matching on lineno requires a "
+ "WarningMessage instance")
+ self.assertEqual(wmsg.lineno, lineno, msg)
def assertWarningList(self, wlist, desc=None, msg=None):
"""check that warning list (e.g. from catch_warnings) matches pattern"""
- # TODO: make this display better diff of *which* warnings did not match,
- # and make use of _formatWarning below.
+ # TODO: make this display better diff of *which* warnings did not match
if not isinstance(desc, (list,tuple)):
desc = [] if desc is None else [desc]
for idx, entry in enumerate(desc):
@@ -407,6 +454,11 @@ class TestCase(unittest.TestCase):
(len(desc), len(wlist), self._formatWarningList(wlist), desc)
raise self.failureException(self._formatMessage(msg, std))
+ def consumeWarningList(self, wlist, *args, **kwds):
+ """assertWarningList() variant that clears list afterwards"""
+ self.assertWarningList(wlist, *args, **kwds)
+ del wlist[:]
+
def _formatWarning(self, entry):
tail = ""
if hasattr(entry, "message"):
@@ -422,10 +474,23 @@ class TestCase(unittest.TestCase):
def _formatWarningList(self, wlist):
return "[%s]" % ", ".join(self._formatWarning(entry) for entry in wlist)
- def consumeWarningList(self, wlist, *args, **kwds):
- """assertWarningList() variant that clears list afterwards"""
- self.assertWarningList(wlist, *args, **kwds)
- del wlist[:]
+ #============================================================
+ # misc custom methods
+ #============================================================
+ def assertFunctionResults(self, func, cases):
+ """helper for running through function calls.
+
+ func should be the function to call.
+ cases should be list of Param instances,
+ where first position argument is expected return value,
+ and remaining args and kwds are passed to function.
+ """
+ for elem in cases:
+ elem = Params.norm(elem)
+ correct = elem.args[0]
+ result = func(*elem.args[1:], **elem.kwds)
+ msg = "error for case %r:" % (elem.render(1),)
+ self.assertEqual(result, correct, msg)
#============================================================
#eoc
@@ -601,10 +666,8 @@ class HandlerCase(TestCase):
#=========================================================
__unittest_skip = True
- #optional prefix to prepend to name of test method as it's called,
- #useful when multiple handler test classes being run.
- #default behavior should be sufficient
- def case_prefix(self):
+ @property
+ def descriptionPrefix(self):
handler = self.handler
name = handler.name
if hasattr(handler, "get_backend"):
@@ -627,11 +690,7 @@ class HandlerCase(TestCase):
# setup / cleanup
#=========================================================
def setUp(self):
- # backup warning filter state; set to display all warnings during tests;
- # and restore filter state after test.
- ctx = catch_all_warnings()
- ctx.__enter__()
- self._restore_warnings = ctx.__exit__
+ TestCase.setUp(self)
# if needed, select specific backend for duration of test
handler = self.handler
@@ -641,7 +700,7 @@ class HandlerCase(TestCase):
raise RuntimeError("handler doesn't support multiple backends")
if backend == "os_crypt" and not handler.has_backend("os_crypt"):
self._patch_safe_crypt()
- self._orig_backend = handler.get_backend()
+ self.addCleanup(handler.set_backend, handler.get_backend())
handler.set_backend(backend)
def _patch_safe_crypt(self):
@@ -660,23 +719,10 @@ class HandlerCase(TestCase):
hash = handler.genhash(secret, hash)
assert isinstance(hash, str)
return hash
- self._orig_crypt = mod._crypt
+ self.addCleanup(setattr, mod, "_crypt", mod._crypt)
mod._crypt = crypt_stub
self.using_patched_crypt = True
- def tearDown(self):
- # unpatch safe_crypt()
- if self._orig_crypt:
- import passlib.utils as mod
- mod._crypt = self._orig_crypt
-
- # restore original backend
- if self._orig_backend:
- self.handler.set_backend(self._orig_backend)
-
- # restore warning filters
- self._restore_warnings()
-
#=========================================================
# basic tests
#=========================================================
@@ -1468,11 +1514,8 @@ class HandlerCase(TestCase):
"config=%r hash=%r" % (name, other, secret, kwds, hash))
count +=1
- name = self.case_prefix
- if not isinstance(name, str):
- name = name()
log.debug("fuzz test: %r checked %d passwords against %d verifiers (%s)",
- name, count, len(verifiers),
+ self.descriptionPrefix, count, len(verifiers),
", ".join(vname(v) for v in verifiers))
def get_fuzz_verifiers(self):
@@ -1745,7 +1788,7 @@ def create_backend_case(base_class, backend, module=None):
"%s_%s" % (backend, handler.name),
(base_class,),
dict(
- case_prefix = "%s (%s backend)" % (handler.name, backend),
+ descriptionPrefix = "%s (%s backend)" % (handler.name, backend),
backend = backend,
__module__= module or base_class.__module__,
)
@@ -1802,9 +1845,11 @@ def mktemp(*args, **kwds):
os.close(fd)
return path
-#=========================================================
-#make sure catch_warnings() is available
-#=========================================================
+#=============================================================================
+# warnings helpers
+#=============================================================================
+
+# make sure catch_warnings() is available
try:
from warnings import catch_warnings
except ImportError:
@@ -1893,33 +1938,59 @@ except ImportError:
self._module.filters = self._filters
self._module.showwarning = self._showwarning
-class catch_all_warnings(catch_warnings):
- "catch_warnings() wrapper which clears filter"
- def __init__(self, reset=".*", **kwds):
- super(catch_all_warnings, self).__init__(**kwds)
- self._reset_pat = reset
+class reset_warnings(catch_warnings):
+ "catch_warnings() wrapper which clears warning registry & filters"
+ def __init__(self, reset_filter="always", reset_registry=".*", **kwds):
+ super(reset_warnings, self).__init__(**kwds)
+ self._reset_filter = reset_filter
+ self._reset_registry = re.compile(reset_registry) if reset_registry else None
def __enter__(self):
# let parent class archive filter state
- ret = super(catch_all_warnings, self).__enter__()
+ ret = super(reset_warnings, self).__enter__()
# reset the filter to list everything
- warnings.resetwarnings()
- warnings.simplefilter("always")
+ if self._reset_filter:
+ warnings.resetwarnings()
+ warnings.simplefilter(self._reset_filter)
- # wipe the __warningregistry__ off the map, so warnings
- # reliably get reported per-test.
- # XXX: *could* restore state
- pattern = self._reset_pat
+ # archive and clear the __warningregistry__ key for all modules
+ # that match the 'reset' pattern.
+ pattern = self._reset_registry
if pattern:
- import sys
- key = "__warningregistry__"
- for mod in sys.modules.values():
- if hasattr(mod, key) and re.match(pattern, mod.__name__):
- getattr(mod, key).clear()
-
+ orig = self._orig_registry = {}
+ for name, mod in sys.modules.items():
+ if pattern.match(name):
+ reg = getattr(mod, "__warningregistry__", None)
+ if reg:
+ orig[name] = reg.copy()
+ reg.clear()
return ret
-#=========================================================
-#EOF
-#=========================================================
+ def __exit__(self, *exc_info):
+ # restore warning registry for all modules
+ pattern = self._reset_registry
+ if pattern:
+ # restore archived registry data
+ orig = self._orig_registry
+ for name, content in iteritems(orig):
+ mod = sys.modules.get(name)
+ if mod is None:
+ continue
+ reg = getattr(mod, "__warningregistry__", None)
+ if reg is None:
+ setattr(mod, "__warningregistry__", content)
+ else:
+ reg.clear()
+ reg.update(content)
+ # clear all registry entries that we didn't archive
+ for name, mod in sys.modules.items():
+ if pattern.match(name) and name not in orig:
+ reg = getattr(mod, "__warningregistry__", None)
+ if reg:
+ reg.clear()
+ super(reset_warnings, self).__exit__(*exc_info)
+
+#=============================================================================
+# eof
+#=============================================================================