diff options
author | Eli Collins <elic@assurancetechnologies.com> | 2012-03-12 21:33:28 -0400 |
---|---|---|
committer | Eli Collins <elic@assurancetechnologies.com> | 2012-03-12 21:33:28 -0400 |
commit | 3913a59ad033462e6a389544ffcdf8055db7ad9c (patch) | |
tree | 63dda089e61a9d8ef4b468a323df8c2ec2ad6c70 /passlib/tests | |
parent | b970d6ee145122005f1e6808466900a94e00dfcc (diff) | |
download | passlib-3913a59ad033462e6a389544ffcdf8055db7ad9c.tar.gz |
updated test support & py3 compat code from an external library
passlib.tests
-------------
* deprecated support for unittest 1... accumulated too many backports,
planning to require unittest2 in next release.
* case_prefix renamed to shortDescription
* test case now archives & clears warning registry state in addition
to warning filter state
passlib.utils.compat
--------------------
* a bunch of the bytes-related functions were renamed for clarity
* NativeStringIO alias added
* trange alias merged into irange
Diffstat (limited to 'passlib/tests')
-rw-r--r-- | passlib/tests/genconfig.py | 2 | ||||
-rw-r--r-- | passlib/tests/test_apache.py | 4 | ||||
-rw-r--r-- | passlib/tests/test_context.py | 26 | ||||
-rw-r--r-- | passlib/tests/test_ext_django.py | 8 | ||||
-rw-r--r-- | passlib/tests/test_handlers.py | 4 | ||||
-rw-r--r-- | passlib/tests/test_registry.py | 8 | ||||
-rw-r--r-- | passlib/tests/test_utils.py | 20 | ||||
-rw-r--r-- | passlib/tests/test_utils_handlers.py | 10 | ||||
-rw-r--r-- | passlib/tests/utils.py | 431 |
9 files changed, 290 insertions, 223 deletions
diff --git a/passlib/tests/genconfig.py b/passlib/tests/genconfig.py index dc87d02..83f016b 100644 --- a/passlib/tests/genconfig.py +++ b/passlib/tests/genconfig.py @@ -77,7 +77,7 @@ class HashTimer(object): # self.samples = samples self.cache = {} - self.srange = trange(samples) + self.srange = irange(samples) def time_encrypt(self, rounds): "check how long encryption for a given number of rounds will take" diff --git a/passlib/tests/test_apache.py b/passlib/tests/test_apache.py index 052c42e..d3b4ab8 100644 --- a/passlib/tests/test_apache.py +++ b/passlib/tests/test_apache.py @@ -30,7 +30,7 @@ def backdate_file_mtime(path, offset=10): #========================================================= class HtpasswdFileTest(TestCase): "test HtpasswdFile class" - case_prefix = "HtpasswdFile" + descriptionPrefix = "HtpasswdFile" sample_01 = b('user2:2CHkkwa2AtqGs\nuser3:{SHA}3ipNV1GrBtxPmHFC21fCbVCSXIo=\nuser4:pass4\nuser1:$apr1$t4tc7jTh$GPIWVUo8sQKJlUdV8V5vu0\n') sample_02 = b('user3:{SHA}3ipNV1GrBtxPmHFC21fCbVCSXIo=\nuser4:pass4\n') @@ -205,7 +205,7 @@ class HtpasswdFileTest(TestCase): #========================================================= class HtdigestFileTest(TestCase): "test HtdigestFile class" - case_prefix = "HtdigestFile" + descriptionPrefix = "HtdigestFile" sample_01 = b('user2:realm:549d2a5f4659ab39a80dac99e159ab19\nuser3:realm:a500bb8c02f6a9170ae46af10c898744\nuser4:realm:ab7b5d5f28ccc7666315f508c7358519\nuser1:realm:2a6cf53e7d8f8cf39d946dc880b14128\n') sample_02 = b('user3:realm:a500bb8c02f6a9170ae46af10c898744\nuser4:realm:ab7b5d5f28ccc7666315f508c7358519\n') diff --git a/passlib/tests/test_context.py b/passlib/tests/test_context.py index 46cd3f5..f9edafc 100644 --- a/passlib/tests/test_context.py +++ b/passlib/tests/test_context.py @@ -22,7 +22,7 @@ from passlib.exc import PasslibConfigWarning from passlib.utils import tick, to_bytes, to_unicode from passlib.utils.compat import irange, u import passlib.utils.handlers as uh -from passlib.tests.utils import TestCase, mktemp, catch_all_warnings, \ +from passlib.tests.utils import TestCase, mktemp, catch_warnings, \ gae_env, set_file from passlib.registry import register_crypt_handler_path, has_crypt_handler, \ _unload_handler_name as unload_handler_name @@ -37,7 +37,7 @@ class CryptPolicyTest(TestCase): #TODO: need to test user categories w/in all this - case_prefix = "CryptPolicy" + descriptionPrefix = "CryptPolicy" #========================================================= #sample crypt policies used for testing @@ -472,11 +472,7 @@ admin__context__deprecated = des_crypt, bsdi_crypt def test_15_min_verify_time(self): "test get_min_verify_time() method" # silence deprecation warnings for min verify time - with catch_all_warnings(): - warnings.filterwarnings("ignore", category=DeprecationWarning) - self._test_15() - - def _test_15(self): + warnings.filterwarnings("ignore", category=DeprecationWarning) pa = CryptPolicy() self.assertEqual(pa.get_min_verify_time(), 0) @@ -526,7 +522,7 @@ admin__context__deprecated = des_crypt, bsdi_crypt #========================================================= class CryptContextTest(TestCase): "test CryptContext object's behavior" - case_prefix = "CryptContext" + descriptionPrefix = "CryptContext" #========================================================= #constructor @@ -632,7 +628,7 @@ class CryptContextTest(TestCase): ) # min rounds - with catch_all_warnings(record=True) as wlog: + with catch_warnings(record=True) as wlog: # set below handler min c2 = cc.replace(all__min_rounds=500, all__max_rounds=None, @@ -663,7 +659,7 @@ class CryptContextTest(TestCase): self.consumeWarningList(wlog) # max rounds - with catch_all_warnings(record=True) as wlog: + with catch_warnings(record=True) as wlog: # set above handler max c2 = cc.replace(all__max_rounds=int(1e9)+500, all__min_rounds=None, all__default_rounds=int(1e9)+500) @@ -824,7 +820,7 @@ class CryptContextTest(TestCase): # which is much cheaper, and shares the same codebase. # min rounds - with catch_all_warnings(record=True) as wlog: + with catch_warnings(record=True) as wlog: self.assertEqual( cc.encrypt("password", rounds=1999, salt="nacl"), '$5$rounds=2000$nacl$9/lTZ5nrfPuz8vphznnmHuDGFuvjSNvOEDsGmGfsS97', @@ -950,7 +946,7 @@ class CryptContextTest(TestCase): return to_unicode(secret + 'x') # silence deprecation warnings for min verify time - with catch_all_warnings(record=True) as wlog: + with catch_warnings(record=True) as wlog: cc = CryptContext([TimedHash], min_verify_time=min_verify_time) self.consumeWarningList(wlog, DeprecationWarning) @@ -977,7 +973,7 @@ class CryptContextTest(TestCase): #ensure taking longer emits a warning. TimedHash.delay = max_delay - with catch_all_warnings(record=True) as wlog: + with catch_warnings(record=True) as wlog: elapsed, result = timecall(cc.verify, "blob", "stubx") self.assertFalse(result) self.assertAlmostEqual(elapsed, max_delay, delta=delta) @@ -1024,7 +1020,7 @@ class CryptContextTest(TestCase): GOOD1 = "$2a$12$oaQbBqq8JnSM1NHRPQGXOOm4GCUMqp7meTnkft4zgSnrbhoKdDV0C" ctx = CryptContext(["bcrypt"]) - with catch_all_warnings(record=True) as wlog: + with catch_warnings(record=True) as wlog: self.assertTrue(ctx.hash_needs_update(BAD1)) self.assertFalse(ctx.hash_needs_update(GOOD1)) @@ -1101,7 +1097,7 @@ class dummy_2(uh.StaticHandler): name = "dummy_2" class LazyCryptContextTest(TestCase): - case_prefix = "LazyCryptContext" + descriptionPrefix = "LazyCryptContext" def setUp(self): unload_handler_name("dummy_2") diff --git a/passlib/tests/test_ext_django.py b/passlib/tests/test_ext_django.py index 4627d30..13085fc 100644 --- a/passlib/tests/test_ext_django.py +++ b/passlib/tests/test_ext_django.py @@ -13,7 +13,7 @@ from passlib.context import CryptContext, CryptPolicy from passlib.apps import django_context from passlib.ext.django import utils from passlib.hash import sha256_crypt -from passlib.tests.utils import TestCase, unittest, ut_version, catch_all_warnings +from passlib.tests.utils import TestCase, unittest, ut_version, catch_warnings import passlib.tests.test_handlers as th from passlib.utils.compat import iteritems, get_method_function, unicode from passlib.registry import get_crypt_handler @@ -130,7 +130,7 @@ def get_cc_rounds(**kwds): class PatchTest(TestCase): "test passlib.ext.django.utils:set_django_password_context" - case_prefix = "passlib.ext.django utils" + descriptionPrefix = "passlib.ext.django utils" def assert_unpatched(self): "helper to ensure django hasn't been patched" @@ -223,7 +223,7 @@ class PatchTest(TestCase): def dummy(): pass - with catch_all_warnings(record=True) as wlog: + with catch_warnings(record=True) as wlog: #patch to use stock django context utils.set_django_password_context(django_context) self.assert_patched(context=django_context) @@ -422,7 +422,7 @@ if has_django0: class PluginTest(TestCase): "test django plugin via settings" - case_prefix = "passlib.ext.django plugin" + descriptionPrefix = "passlib.ext.django plugin" def setUp(self): #remove django patch diff --git a/passlib/tests/test_handlers.py b/passlib/tests/test_handlers.py index 580cc08..4547e56 100644 --- a/passlib/tests/test_handlers.py +++ b/passlib/tests/test_handlers.py @@ -12,7 +12,7 @@ import warnings from passlib import hash from passlib.utils.compat import irange from passlib.tests.utils import TestCase, HandlerCase, create_backend_case, \ - enable_option, b, catch_all_warnings, UserHandlerMixin, randintgauss + enable_option, b, catch_warnings, UserHandlerMixin, randintgauss from passlib.utils.compat import u #module @@ -262,7 +262,7 @@ class _bcrypt_test(HandlerCase): check_padding(bcrypt.encrypt("bob", rounds=bcrypt.min_rounds)) # some things that will raise warnings - with catch_all_warnings(record=True) as wlog: + with catch_warnings(record=True) as wlog: # # test genconfig() corrects invalid salts & issues warning. # diff --git a/passlib/tests/test_registry.py b/passlib/tests/test_registry.py index c6fa9f6..1990919 100644 --- a/passlib/tests/test_registry.py +++ b/passlib/tests/test_registry.py @@ -16,7 +16,7 @@ from passlib import hash, registry from passlib.registry import register_crypt_handler, register_crypt_handler_path, \ get_crypt_handler, list_crypt_handlers, _unload_handler_name as unload_handler_name import passlib.utils.handlers as uh -from passlib.tests.utils import TestCase, mktemp, catch_all_warnings +from passlib.tests.utils import TestCase, mktemp, catch_warnings #module log = getLogger(__name__) @@ -40,7 +40,7 @@ dummy_x = 1 #========================================================= class RegistryTest(TestCase): - case_prefix = "passlib registry" + descriptionPrefix = "passlib registry" def tearDown(self): for name in ("dummy_0", "dummy_1", "dummy_x", "dummy_bad"): @@ -112,7 +112,7 @@ class RegistryTest(TestCase): #TODO: check lazy load which calls register_crypt_handler (warning should be issued) sys.modules.pop("passlib.tests._test_bad_register", None) register_crypt_handler_path("dummy_bad", "passlib.tests._test_bad_register") - with catch_all_warnings(): + with catch_warnings(): warnings.filterwarnings("ignore", "xxxxxxxxxx", DeprecationWarning) h = get_crypt_handler("dummy_bad") from passlib.tests import _test_bad_register as tbr @@ -158,7 +158,7 @@ class RegistryTest(TestCase): register_crypt_handler(dummy_1) self.assertIs(get_crypt_handler("dummy_1"), dummy_1) - with catch_all_warnings(): + with catch_warnings(): warnings.filterwarnings("ignore", "handler names should be lower-case, and use underscores instead of hyphens:.*", UserWarning) self.assertIs(get_crypt_handler("DUMMY-1"), dummy_1) diff --git a/passlib/tests/test_utils.py b/passlib/tests/test_utils.py index 1baebd7..2d9d147 100644 --- a/passlib/tests/test_utils.py +++ b/passlib/tests/test_utils.py @@ -12,8 +12,8 @@ import warnings #pkg #module from passlib.utils.compat import b, bytes, bascii_to_str, irange, PY2, PY3, u, \ - unicode, bjoin -from passlib.tests.utils import TestCase, Params as ak, enable_option, catch_all_warnings + unicode, join_bytes +from passlib.tests.utils import TestCase, Params as ak, enable_option, catch_warnings def hb(source): return unhexlify(b(source)) @@ -546,7 +546,7 @@ class _Base64Test(TestCase): # helper to generate bytemap-specific strings def m(self, *offsets): "generate byte string from offsets" - return bjoin(self.engine.bytemap[o:o+1] for o in offsets) + return join_bytes(self.engine.bytemap[o:o+1] for o in offsets) #========================================================= # test encode_bytes @@ -842,7 +842,7 @@ from passlib.utils import h64, h64big class H64_Test(_Base64Test): "test H64 codec functions" engine = h64 - case_prefix = "h64 codec" + descriptionPrefix = "h64 codec" encoded_data = [ #test lengths 0..6 to ensure tail is encoded properly @@ -867,7 +867,7 @@ class H64_Test(_Base64Test): class H64Big_Test(_Base64Test): "test H64Big codec functions" engine = h64big - case_prefix = "h64big codec" + descriptionPrefix = "h64big codec" encoded_data = [ #test lengths 0..6 to ensure tail is encoded properly @@ -960,12 +960,12 @@ has_ssl_md4 = (md4_mod.md4 is not md4_mod._builtin_md4) if has_ssl_md4: class MD4_SSL_Test(_MD4_Test): - case_prefix = "MD4 (SSL version)" + descriptionPrefix = "MD4 (SSL version)" hash = staticmethod(md4_mod.md4) if not has_ssl_md4 or enable_option("cover"): class MD4_Builtin_Test(_MD4_Test): - case_prefix = "MD4 (builtin version)" + descriptionPrefix = "MD4 (builtin version)" hash = md4_mod._builtin_md4 #========================================================= @@ -1010,7 +1010,7 @@ class CryptoTest(TestCase): self.assertRaises(TypeError, norm_hash_name, None) # test selected results - with catch_all_warnings(): + with catch_warnings(): warnings.filterwarnings("ignore", '.*unknown hash') for row in chain(_nhn_hash_names, self.ndn_values): for idx, format in enumerate(self.ndn_formats): @@ -1227,12 +1227,12 @@ has_m2crypto = (pbkdf2._EVP is not None) if has_m2crypto: class Pbkdf2_M2Crypto_Test(_Pbkdf2BackendTest): - case_prefix = "pbkdf2 (m2crypto backend)" + descriptionPrefix = "pbkdf2 (m2crypto backend)" enable_m2crypto = True if not has_m2crypto or enable_option("cover"): class Pbkdf2_Builtin_Test(_Pbkdf2BackendTest): - case_prefix = "pbkdf2 (builtin backend)" + descriptionPrefix = "pbkdf2 (builtin backend)" enable_m2crypto = False #========================================================= diff --git a/passlib/tests/test_utils_handlers.py b/passlib/tests/test_utils_handlers.py index 7079917..4788798 100644 --- a/passlib/tests/test_utils_handlers.py +++ b/passlib/tests/test_utils_handlers.py @@ -18,7 +18,7 @@ from passlib.utils import getrandstr, JYTHON, rng, to_unicode from passlib.utils.compat import b, bytes, bascii_to_str, str_to_uascii, \ uascii_to_str, unicode, PY_MAX_25 import passlib.utils.handlers as uh -from passlib.tests.utils import HandlerCase, TestCase, catch_all_warnings +from passlib.tests.utils import HandlerCase, TestCase, catch_warnings from passlib.utils.compat import u #module log = getLogger(__name__) @@ -190,7 +190,7 @@ class SkeletonTest(TestCase): self.assertIn(norm_salt(use_defaults=True), salts3) # check explicit salts - with catch_all_warnings(record=True) as wlog: + with catch_warnings(record=True) as wlog: # check too-small salts self.assertRaises(ValueError, norm_salt, salt='') @@ -211,7 +211,7 @@ class SkeletonTest(TestCase): self.consumeWarningList(wlog, PasslibHashWarning) #check generated salts - with catch_all_warnings(record=True) as wlog: + with catch_warnings(record=True) as wlog: # check too-small salt size self.assertRaises(ValueError, gen_salt, 0) @@ -233,7 +233,7 @@ class SkeletonTest(TestCase): # test with max_salt_size=None del d1.max_salt_size - with catch_all_warnings(record=True) as wlog: + with catch_warnings(record=True) as wlog: self.assertEqual(len(gen_salt(None)), 3) self.assertEqual(len(gen_salt(5)), 5) self.consumeWarningList(wlog) @@ -259,7 +259,7 @@ class SkeletonTest(TestCase): self.assertEqual(norm_rounds(use_defaults=True), 2) # check explicit rounds - with catch_all_warnings(record=True) as wlog: + with catch_warnings(record=True) as wlog: # too small self.assertRaises(ValueError, norm_rounds, rounds=0) self.consumeWarningList(wlog) diff --git a/passlib/tests/utils.py b/passlib/tests/utils.py index 2f36d3e..8e2883b 100644 --- a/passlib/tests/utils.py +++ b/passlib/tests/utils.py @@ -22,7 +22,10 @@ except ImportError: if PY27 or PY_MIN_32: ut_version = 2 else: - # XXX: issue warning and deprecate support sometime? + # older versions of python will need to install the unittest2 + # backport (named unittest2_3k for 3.0/3.1) + warn("please install unittest2 for python %d.%d, it will be required " + "as of passlib 1.7" % sys.version_info[:2]) ut_version = 1 import warnings @@ -136,99 +139,88 @@ def get_file(path): class TestCase(unittest.TestCase): """passlib-specific test case class - this class mainly overriddes many of the common assert methods - so to give a default message which includes the values - as well as the class-specific case_prefix string. - this latter bit makes the output of various test cases - easier to distinguish from eachother. + this class adds a number of features to the standard TestCase... + * common prefix for all test descriptions + * resets warnings filter & registry for every test + * tweaks to message formatting + * __msg__ kwd added to assertRaises() + * backport of a bunch of unittest2 features + * suite of methods for matching against warnings """ + #==================================================================== + # add various custom features + #==================================================================== - #============================================================= - #make it ease for test cases to add common prefix to all descs - #============================================================= - #: string or method returning string - prepended to all tests in TestCase - case_prefix = None + #---------------------------------------------------------------- + # make it easy for test cases to add common prefix to shortDescription + #---------------------------------------------------------------- - #: flag to disable feature - longDescription = True + # string prepended to all tests in TestCase + descriptionPrefix = None def shortDescription(self): - "wrap shortDescription() method to prepend case_prefix" + "wrap shortDescription() method to prepend descriptionPrefix" desc = super(TestCase, self).shortDescription() - if desc is None: - #would still like to add prefix, but munges things up. - return None - prefix = self.case_prefix - if prefix and self.longDescription: - if callable(prefix): - prefix = prefix() - desc = "%s: %s" % (prefix, desc) + prefix = self.descriptionPrefix + if prefix: + desc = "%s: %s" % (prefix, desc or str(self)) return desc - #============================================================ - #hack to set UT2 private skip attrs to mirror nose's __test__ attr - #============================================================ - if ut_version >= 2: + #---------------------------------------------------------------- + # hack things so nose and ut2 both skip subclasses who have + # "__unittest_skip=True" set, or whose names start with "_" + #---------------------------------------------------------------- + @classproperty + def __unittest_skip__(cls): + # NOTE: this attr is technically a unittest2 internal detail. + name = cls.__name__ + return name.startswith("_") or \ + getattr(cls, "_%s__unittest_skip" % name, False) - @classproperty - def __unittest_skip__(cls): - # make this mirror nose's '__test__' attr - return not getattr(cls, "__test__", True) + # make this mirror nose's '__test__' attr + return not getattr(cls, "__test__", True) @classproperty def __test__(cls): - # nose uses to this to skip tests. overridding this to - # skip classes with '__<cls>_unittest_skip' set - that way - # we can omit specific classes without affecting subclasses. - name = cls.__name__ - if name.startswith("_"): - return False - if getattr(cls, "_%s__unittest_skip" % name, False): - return False - return True + # make nose just proxy __unittest_skip__ + return not cls.__unittest_skip__ + # flag to skip *this* class __unittest_skip = True - #============================================================ - # tweak msg formatting for some assert methods - #============================================================ - longMessage = True #override python default (False) - - def _formatMessage(self, msg, std): - "override UT2's _formatMessage - only use longMessage if msg ends with ':'" - if not msg: - return std - if not self.longMessage or not msg.endswith(":"): - return msg.rstrip(":") - return '%s %s' % (msg, std) + #---------------------------------------------------------------- + # reset warning filters & registry before each test + #---------------------------------------------------------------- - #============================================================ - #override some unittest1 methods to support _formatMessage - #============================================================ - if ut_version < 2: + # flag to enable this feature + resetWarningState = True - def assertEqual(self, real, correct, msg=None): - if real != correct: - std = "got %r, expected would equal %r" % (real, correct) - msg = self._formatMessage(msg, std) - raise self.failureException(msg) - - def assertNotEqual(self, real, correct, msg=None): - if real == correct: - std = "got %r, expected would not equal %r" % (real, correct) - msg = self._formatMessage(msg, std) - raise self.failureException(msg) + def setUp(self): + unittest.TestCase.setUp(self) + if self.resetWarningState: + ctx = reset_warnings() + ctx.__enter__() + self.addCleanup(ctx.__exit__) + + #---------------------------------------------------------------- + # tweak message formatting so longMessage mode is only enabled + # if msg ends with ":", and turn on longMessage by default. + #---------------------------------------------------------------- + longMessage = True - assertEquals = assertEqual - assertNotEquals = assertNotEqual + def _formatMessage(self, msg, std): + if self.longMessage and msg and msg.rstrip().endswith(":"): + return '%s %s' % (msg.rstrip(), std) + else: + return msg or std - #NOTE: overriding this even under UT2. - #FIXME: this doesn't support the fancy context manager UT2 provides. - def assertRaises(self, _exc_type, _callable, *args, **kwds): - #NOTE: overriding this for format ability, - # but ALSO adding "__msg__" kwd so we can set custom msg + #---------------------------------------------------------------- + # override assertRaises() to support '__msg__' keyword + #---------------------------------------------------------------- + def assertRaises(self, _exc_type, _callable=None, *args, **kwds): msg = kwds.pop("__msg__", None) if _callable is None: + # FIXME: this ignores 'msg' return super(TestCase, self).assertRaises(_exc_type, None, *args, **kwds) try: @@ -239,11 +231,46 @@ class TestCase(unittest.TestCase): _exc_type) raise self.failureException(self._formatMessage(msg, std)) - #=============================================================== - #backport some methods from unittest2 - #=============================================================== + #---------------------------------------------------------------- + # null out a bunch of deprecated aliases so I stop using them + #---------------------------------------------------------------- + assertEquals = assertNotEquals = assertRegexpMatches = None + + #==================================================================== + # backport some methods from unittest2 + #==================================================================== if ut_version < 2: + #---------------------------------------------------------------- + # simplistic backport of addCleanup() framework + #---------------------------------------------------------------- + _cleanups = None + + def addCleanup(self, function, *args, **kwds): + queue = self._cleanups + if queue is None: + queue = self._cleanups = [] + queue.append((function, args, kwds)) + + def doCleanups(self): + queue = self._cleanups + while queue: + func, args, kwds = queue.pop() + func(*args, **kwds) + + def tearDown(self): + self.doCleanups() + unittest.TestCase.tearDown(self) + + #---------------------------------------------------------------- + # backport skipTest (requires nose to work) + #---------------------------------------------------------------- + def skipTest(self, reason): + raise SkipTest(reason) + + #---------------------------------------------------------------- + # backport various assert tests added in unittest2 + #---------------------------------------------------------------- def assertIs(self, real, correct, msg=None): if real is not correct: std = "got %r, expected would be %r" % (real, correct) @@ -262,9 +289,6 @@ class TestCase(unittest.TestCase): msg = self._formatMessage(msg, std) raise self.failureException(msg) - def skipTest(self, reason): - raise SkipTest(reason) - def assertAlmostEqual(self, first, second, places=None, msg=None, delta=None): """Fail if the two objects are unequal as determined by their difference rounded to the given number of decimal places @@ -303,45 +327,68 @@ class TestCase(unittest.TestCase): msg = self._formatMessage(msg, standardMsg) raise self.failureException(msg) + def assertLess(self, left, right, msg=None): + if left >= right: + std = "%r not less than %r" % (left, right) + raise self.failureException(self._formatMessage(msg, std)) + + def assertGreaterEqual(self, left, right, msg=None): + if left < right: + std = "%r less than %r" % (left, right) + raise self.failureException(self._formatMessage(msg, std)) + + def assertIn(self, elem, container, msg=None): + if elem not in container: + std = "%r not found in %r" % (elem, container) + raise self.failureException(self._formatMessage(msg, std)) + + def assertNotIn(self, elem, container, msg=None): + if elem in container: + std = "%r unexpectedly in %r" % (elem, container) + raise self.failureException(self._formatMessage(msg, std)) + + #---------------------------------------------------------------- + # override some unittest1 methods to support _formatMessage + #---------------------------------------------------------------- + def assertEqual(self, real, correct, msg=None): + if real != correct: + std = "got %r, expected would equal %r" % (real, correct) + msg = self._formatMessage(msg, std) + raise self.failureException(msg) + + def assertNotEqual(self, real, correct, msg=None): + if real == correct: + std = "got %r, expected would not equal %r" % (real, correct) + msg = self._formatMessage(msg, std) + raise self.failureException(msg) + + #---------------------------------------------------------------- + # backport assertRegex() alias from 3.2 to 2.7/3.1 + #---------------------------------------------------------------- if not hasattr(unittest.TestCase, "assertRegex"): - # assertRegexpMatches() added in 2.7/UT2 and 3.1, renamed to - # assertRegex() in 3.2; this code ensures assertRegex() is defined. if hasattr(unittest.TestCase, "assertRegexpMatches"): + # was present in 2.7/3.1 under name assertRegexpMatches assertRegex = unittest.TestCase.assertRegexpMatches else: + # 3.0 and <= 2.6 didn't have this method at all def assertRegex(self, text, expected_regex, msg=None): """Fail the test unless the text matches the regular expression.""" if isinstance(expected_regex, sb_types): assert expected_regex, "expected_regex must not be empty." expected_regex = re.compile(expected_regex) if not expected_regex.search(text): - msg = msg or "Regex didn't match" - msg = '%s: %r not found in %r' % (msg, expected_regex.pattern, text) - raise self.failureException(msg) + msg = msg or "Regex didn't match: " + std = '%r not found in %r' % (msg, expected_regex.pattern, text) + raise self.failureException(self._formatMessage(msg, std)) #============================================================ - #add some custom methods + # custom methods for matching warnings #============================================================ - def assertFunctionResults(self, func, cases): - """helper for running through function calls. - - func should be the function to call. - cases should be list of Param instances, - where first position argument is expected return value, - and remaining args and kwds are passed to function. - """ - for elem in cases: - elem = Params.norm(elem) - correct = elem.args[0] - result = func(*elem.args[1:], **elem.kwds) - msg = "error for case %r:" % (elem.render(1),) - self.assertEqual(result, correct, msg) - def assertWarning(self, warning, message_re=None, message=None, category=None, - ##filename=None, filename_re=None, - ##lineno=None, + filename_re=None, filename=None, + lineno=None, msg=None, ): "check if WarningMessage instance (as returned by catch_warnings) matches parameters" @@ -355,7 +402,7 @@ class TestCase(unittest.TestCase): # no original WarningMessage, passed raw Warning wmsg = None - #tests that can use a warning instance or WarningMessage object + # tests that can use a warning instance or WarningMessage object if message: self.assertEqual(str(warning), message, msg) if message_re: @@ -363,29 +410,29 @@ class TestCase(unittest.TestCase): if category: self.assertIsInstance(warning, category, msg) - #commented out until needed... - ###tests that require a WarningMessage object - ##if filename or filename_re: - ## if not wmsg: - ## raise TypeError("can't read filename from warning object") - ## real = wmsg.filename - ## if real.endswith(".pyc") or real.endswith(".pyo"): - ## #FIXME: should use a stdlib call to resolve this back - ## # to original module's path - ## real = real[:-1] - ## if filename: - ## self.assertEqual(real, filename, msg) - ## if filename_re: - ## self.assertRegex(real, filename_re, msg) - ##if lineno: - ## if not wmsg: - ## raise TypeError("can't read lineno from warning object") - ## self.assertEqual(wmsg.lineno, lineno, msg) + # tests that require a WarningMessage object + if filename or filename_re: + if not wmsg: + raise TypeError("matching on filename requires a " + "WarningMessage instance") + real = wmsg.filename + if real.endswith(".pyc") or real.endswith(".pyo"): + # FIXME: should use a stdlib call to resolve this back + # to module's original filename. + real = real[:-1] + if filename: + self.assertEqual(real, filename, msg) + if filename_re: + self.assertRegex(real, filename_re, msg) + if lineno: + if not wmsg: + raise TypeError("matching on lineno requires a " + "WarningMessage instance") + self.assertEqual(wmsg.lineno, lineno, msg) def assertWarningList(self, wlist, desc=None, msg=None): """check that warning list (e.g. from catch_warnings) matches pattern""" - # TODO: make this display better diff of *which* warnings did not match, - # and make use of _formatWarning below. + # TODO: make this display better diff of *which* warnings did not match if not isinstance(desc, (list,tuple)): desc = [] if desc is None else [desc] for idx, entry in enumerate(desc): @@ -407,6 +454,11 @@ class TestCase(unittest.TestCase): (len(desc), len(wlist), self._formatWarningList(wlist), desc) raise self.failureException(self._formatMessage(msg, std)) + def consumeWarningList(self, wlist, *args, **kwds): + """assertWarningList() variant that clears list afterwards""" + self.assertWarningList(wlist, *args, **kwds) + del wlist[:] + def _formatWarning(self, entry): tail = "" if hasattr(entry, "message"): @@ -422,10 +474,23 @@ class TestCase(unittest.TestCase): def _formatWarningList(self, wlist): return "[%s]" % ", ".join(self._formatWarning(entry) for entry in wlist) - def consumeWarningList(self, wlist, *args, **kwds): - """assertWarningList() variant that clears list afterwards""" - self.assertWarningList(wlist, *args, **kwds) - del wlist[:] + #============================================================ + # misc custom methods + #============================================================ + def assertFunctionResults(self, func, cases): + """helper for running through function calls. + + func should be the function to call. + cases should be list of Param instances, + where first position argument is expected return value, + and remaining args and kwds are passed to function. + """ + for elem in cases: + elem = Params.norm(elem) + correct = elem.args[0] + result = func(*elem.args[1:], **elem.kwds) + msg = "error for case %r:" % (elem.render(1),) + self.assertEqual(result, correct, msg) #============================================================ #eoc @@ -601,10 +666,8 @@ class HandlerCase(TestCase): #========================================================= __unittest_skip = True - #optional prefix to prepend to name of test method as it's called, - #useful when multiple handler test classes being run. - #default behavior should be sufficient - def case_prefix(self): + @property + def descriptionPrefix(self): handler = self.handler name = handler.name if hasattr(handler, "get_backend"): @@ -627,11 +690,7 @@ class HandlerCase(TestCase): # setup / cleanup #========================================================= def setUp(self): - # backup warning filter state; set to display all warnings during tests; - # and restore filter state after test. - ctx = catch_all_warnings() - ctx.__enter__() - self._restore_warnings = ctx.__exit__ + TestCase.setUp(self) # if needed, select specific backend for duration of test handler = self.handler @@ -641,7 +700,7 @@ class HandlerCase(TestCase): raise RuntimeError("handler doesn't support multiple backends") if backend == "os_crypt" and not handler.has_backend("os_crypt"): self._patch_safe_crypt() - self._orig_backend = handler.get_backend() + self.addCleanup(handler.set_backend, handler.get_backend()) handler.set_backend(backend) def _patch_safe_crypt(self): @@ -660,23 +719,10 @@ class HandlerCase(TestCase): hash = handler.genhash(secret, hash) assert isinstance(hash, str) return hash - self._orig_crypt = mod._crypt + self.addCleanup(setattr, mod, "_crypt", mod._crypt) mod._crypt = crypt_stub self.using_patched_crypt = True - def tearDown(self): - # unpatch safe_crypt() - if self._orig_crypt: - import passlib.utils as mod - mod._crypt = self._orig_crypt - - # restore original backend - if self._orig_backend: - self.handler.set_backend(self._orig_backend) - - # restore warning filters - self._restore_warnings() - #========================================================= # basic tests #========================================================= @@ -1468,11 +1514,8 @@ class HandlerCase(TestCase): "config=%r hash=%r" % (name, other, secret, kwds, hash)) count +=1 - name = self.case_prefix - if not isinstance(name, str): - name = name() log.debug("fuzz test: %r checked %d passwords against %d verifiers (%s)", - name, count, len(verifiers), + self.descriptionPrefix, count, len(verifiers), ", ".join(vname(v) for v in verifiers)) def get_fuzz_verifiers(self): @@ -1745,7 +1788,7 @@ def create_backend_case(base_class, backend, module=None): "%s_%s" % (backend, handler.name), (base_class,), dict( - case_prefix = "%s (%s backend)" % (handler.name, backend), + descriptionPrefix = "%s (%s backend)" % (handler.name, backend), backend = backend, __module__= module or base_class.__module__, ) @@ -1802,9 +1845,11 @@ def mktemp(*args, **kwds): os.close(fd) return path -#========================================================= -#make sure catch_warnings() is available -#========================================================= +#============================================================================= +# warnings helpers +#============================================================================= + +# make sure catch_warnings() is available try: from warnings import catch_warnings except ImportError: @@ -1893,33 +1938,59 @@ except ImportError: self._module.filters = self._filters self._module.showwarning = self._showwarning -class catch_all_warnings(catch_warnings): - "catch_warnings() wrapper which clears filter" - def __init__(self, reset=".*", **kwds): - super(catch_all_warnings, self).__init__(**kwds) - self._reset_pat = reset +class reset_warnings(catch_warnings): + "catch_warnings() wrapper which clears warning registry & filters" + def __init__(self, reset_filter="always", reset_registry=".*", **kwds): + super(reset_warnings, self).__init__(**kwds) + self._reset_filter = reset_filter + self._reset_registry = re.compile(reset_registry) if reset_registry else None def __enter__(self): # let parent class archive filter state - ret = super(catch_all_warnings, self).__enter__() + ret = super(reset_warnings, self).__enter__() # reset the filter to list everything - warnings.resetwarnings() - warnings.simplefilter("always") + if self._reset_filter: + warnings.resetwarnings() + warnings.simplefilter(self._reset_filter) - # wipe the __warningregistry__ off the map, so warnings - # reliably get reported per-test. - # XXX: *could* restore state - pattern = self._reset_pat + # archive and clear the __warningregistry__ key for all modules + # that match the 'reset' pattern. + pattern = self._reset_registry if pattern: - import sys - key = "__warningregistry__" - for mod in sys.modules.values(): - if hasattr(mod, key) and re.match(pattern, mod.__name__): - getattr(mod, key).clear() - + orig = self._orig_registry = {} + for name, mod in sys.modules.items(): + if pattern.match(name): + reg = getattr(mod, "__warningregistry__", None) + if reg: + orig[name] = reg.copy() + reg.clear() return ret -#========================================================= -#EOF -#========================================================= + def __exit__(self, *exc_info): + # restore warning registry for all modules + pattern = self._reset_registry + if pattern: + # restore archived registry data + orig = self._orig_registry + for name, content in iteritems(orig): + mod = sys.modules.get(name) + if mod is None: + continue + reg = getattr(mod, "__warningregistry__", None) + if reg is None: + setattr(mod, "__warningregistry__", content) + else: + reg.clear() + reg.update(content) + # clear all registry entries that we didn't archive + for name, mod in sys.modules.items(): + if pattern.match(name) and name not in orig: + reg = getattr(mod, "__warningregistry__", None) + if reg: + reg.clear() + super(reset_warnings, self).__exit__(*exc_info) + +#============================================================================= +# eof +#============================================================================= |