diff options
-rw-r--r-- | docs/lib/passlib.utils.compat.rst | 18 | ||||
-rw-r--r-- | passlib/apache.py | 4 | ||||
-rw-r--r-- | passlib/context.py | 7 | ||||
-rw-r--r-- | passlib/handlers/cisco.py | 10 | ||||
-rw-r--r-- | passlib/handlers/des_crypt.py | 4 | ||||
-rw-r--r-- | passlib/handlers/mysql.py | 4 | ||||
-rw-r--r-- | passlib/handlers/sha2_crypt.py | 4 | ||||
-rw-r--r-- | passlib/handlers/sun_md5_crypt.py | 6 | ||||
-rw-r--r-- | passlib/tests/genconfig.py | 2 | ||||
-rw-r--r-- | passlib/tests/test_apache.py | 4 | ||||
-rw-r--r-- | passlib/tests/test_context.py | 26 | ||||
-rw-r--r-- | passlib/tests/test_ext_django.py | 8 | ||||
-rw-r--r-- | passlib/tests/test_handlers.py | 4 | ||||
-rw-r--r-- | passlib/tests/test_registry.py | 8 | ||||
-rw-r--r-- | passlib/tests/test_utils.py | 20 | ||||
-rw-r--r-- | passlib/tests/test_utils_handlers.py | 10 | ||||
-rw-r--r-- | passlib/tests/utils.py | 431 | ||||
-rw-r--r-- | passlib/utils/__init__.py | 44 | ||||
-rw-r--r-- | passlib/utils/compat.py | 178 | ||||
-rw-r--r-- | passlib/utils/des.py | 22 | ||||
-rw-r--r-- | passlib/utils/handlers.py | 6 |
21 files changed, 448 insertions, 372 deletions
diff --git a/docs/lib/passlib.utils.compat.rst b/docs/lib/passlib.utils.compat.rst index 5add85d..2536b0a 100644 --- a/docs/lib/passlib.utils.compat.rst +++ b/docs/lib/passlib.utils.compat.rst @@ -17,27 +17,27 @@ Unicode Helpers .. autofunction:: uascii_to_str .. autofunction:: str_to_uascii -.. function:: ujoin +.. function:: join_unicode Join a sequence of unicode strings, e.g. - ``ujoin([u"a",u"b",u"c"]) -> u"abc"``. + ``join_unicode([u"a",u"b",u"c"]) -> u"abc"``. Bytes Helpers ============= .. autofunction:: bascii_to_str .. autofunction:: str_to_bascii -.. function:: bjoin +.. function:: join_bytes Join a sequence of byte strings, e.g. - ``bjoin([b"a",b"b",b"c"]) -> b"abc"``. + ``join_bytes([b"a",b"b",b"c"]) -> b"abc"``. -.. function:: bjoin_ints +.. function:: join_byte_values Join a sequence of integers into a byte string, - e.g. ``bjoin_ints([97,98,99]) -> b"abc"``. + e.g. ``join_byte_values([97,98,99]) -> b"abc"``. -.. function:: bjoin_elems +.. function:: join_byte_elems Join a sequence of byte elements into a byte string. @@ -49,8 +49,8 @@ Bytes Helpers This function will join a sequence of the appropriate type for the given python version -- under Python 2, this is an alias - for :func:`bjoin`, under Python 3 this is an alias for :func:`bjoin_ints`. + for :func:`join_bytes`, under Python 3 this is an alias for :func:`join_byte_values`. -.. function:: belem_ord +.. function:: byte_elem_value Function to convert byte element to integer (a no-op under PY3) diff --git a/passlib/apache.py b/passlib/apache.py index 63fa39b..05f4b68 100644 --- a/passlib/apache.py +++ b/passlib/apache.py @@ -12,7 +12,7 @@ import sys #libs from passlib.context import CryptContext from passlib.utils import consteq, render_bytes -from passlib.utils.compat import b, bytes, bjoin, lmap, str_to_bascii, u, unicode +from passlib.utils.compat import b, bytes, join_bytes, lmap, str_to_bascii, u, unicode #pkg #local __all__ = [ @@ -148,7 +148,7 @@ class _CommonFile(object): def to_string(self): "export whole database as a byte string" - return bjoin(self._iter_lines()) + return join_bytes(self._iter_lines()) #subclass: _render_line(entry) -> line diff --git a/passlib/context.py b/passlib/context.py index e1b5de5..8522a69 100644 --- a/passlib/context.py +++ b/passlib/context.py @@ -27,7 +27,7 @@ from passlib.utils import is_crypt_handler, rng, saslprep, tick, to_bytes, \ to_unicode from passlib.utils.compat import bytes, is_mapping, iteritems, num_types, \ PY3, PY_MIN_32, unicode, SafeConfigParser, \ - StringIO, BytesIO + NativeStringIO, BytesIO #pkg #local __all__ = [ @@ -145,11 +145,10 @@ class CryptPolicy(object): """ if PY3: source = to_unicode(source, encoding, errname="source") - return cls._from_stream(StringIO(source), section, "<???>") else: source = to_bytes(source, "utf-8", source_encoding=encoding, errname="source") - return cls._from_stream(BytesIO(source), section, "<???>") + return cls._from_stream(NativeStringIO(source), section, "<???>") @classmethod def _from_stream(cls, stream, section, filename=None): @@ -698,7 +697,7 @@ class CryptPolicy(object): def to_string(self, section="passlib", encoding=None): "render to INI string; inverse of from_string() constructor" - buf = StringIO() if PY3 else BytesIO() + buf = NativeStringIO() self.to_file(buf, section) out = buf.getvalue() if not PY3: diff --git a/passlib/handlers/cisco.py b/passlib/handlers/cisco.py index 23e79b4..102d049 100644 --- a/passlib/handlers/cisco.py +++ b/passlib/handlers/cisco.py @@ -11,8 +11,8 @@ from warnings import warn #libs #pkg from passlib.utils import h64, to_bytes -from passlib.utils.compat import b, bascii_to_str, unicode, u, bjoin_ints, \ - bjoin_elems, belem_ord, biter_ints, uascii_to_str, str_to_uascii +from passlib.utils.compat import b, bascii_to_str, unicode, u, join_byte_values, \ + join_byte_elems, byte_elem_value, iter_byte_values, uascii_to_str, str_to_uascii import passlib.utils.handlers as uh #local __all__ = [ @@ -75,7 +75,7 @@ class cisco_pix(uh.HasUserContext, uh.StaticHandler): hash = md5(secret).digest() # drop every 4th byte - hash = bjoin_elems(c for i,c in enumerate(hash) if i & 3 < 3) + hash = join_byte_elems(c for i,c in enumerate(hash) if i & 3 < 3) # encode using Hash64 return h64.encode_bytes(hash).decode("ascii") @@ -185,9 +185,9 @@ class cisco_type7(uh.GenericHandler): "xor static key against data - encrypts & decrypts" key = cls._key key_size = len(key) - return bjoin_ints( + return join_byte_values( value ^ ord(key[(salt + idx) % key_size]) - for idx, value in enumerate(biter_ints(data)) + for idx, value in enumerate(iter_byte_values(data)) ) #========================================================= diff --git a/passlib/handlers/des_crypt.py b/passlib/handlers/des_crypt.py index 5e3cb4c..f663a39 100644 --- a/passlib/handlers/des_crypt.py +++ b/passlib/handlers/des_crypt.py @@ -59,7 +59,7 @@ from warnings import warn #site #libs from passlib.utils import classproperty, h64, h64big, safe_crypt, test_crypt -from passlib.utils.compat import b, bytes, belem_ord, u, uascii_to_str, unicode +from passlib.utils.compat import b, bytes, byte_elem_value, u, uascii_to_str, unicode from passlib.utils.des import mdes_encrypt_int_block import passlib.utils.handlers as uh #pkg @@ -77,7 +77,7 @@ __all__ = [ def _crypt_secret_to_key(secret): "crypt helper which converts lower 7 bits of first 8 chars of secret -> 56-bit des key, padded to 64 bits" return sum( - (belem_ord(c) & 0x7f) << (57-8*i) + (byte_elem_value(c) & 0x7f) << (57-8*i) for i, c in enumerate(secret[:8]) ) diff --git a/passlib/handlers/mysql.py b/passlib/handlers/mysql.py index cea160e..7bbaeb2 100644 --- a/passlib/handlers/mysql.py +++ b/passlib/handlers/mysql.py @@ -32,7 +32,7 @@ from warnings import warn #pkg from passlib.utils import to_native_str, to_bytes from passlib.utils.compat import b, bascii_to_str, bytes, unicode, u, \ - belem_ord, str_to_uascii + byte_elem_value, str_to_uascii import passlib.utils.handlers as uh #local __all__ = [ @@ -78,7 +78,7 @@ class mysql323(uh.StaticHandler): for c in secret: if c in WHITE: continue - tmp = belem_ord(c) + tmp = byte_elem_value(c) nr1 ^= ((((nr1 & 63)+add)*tmp) + (nr1 << 8)) & MASK_32 nr2 = (nr2+((nr2 << 8) ^ nr1)) & MASK_32 add = (add+tmp) & MASK_32 diff --git a/passlib/handlers/sha2_crypt.py b/passlib/handlers/sha2_crypt.py index bffa5c1..56ed086 100644 --- a/passlib/handlers/sha2_crypt.py +++ b/passlib/handlers/sha2_crypt.py @@ -10,7 +10,7 @@ from warnings import warn #site #libs from passlib.utils import classproperty, h64, safe_crypt, test_crypt -from passlib.utils.compat import b, bytes, belem_ord, irange, u, \ +from passlib.utils.compat import b, bytes, byte_elem_value, irange, u, \ uascii_to_str, unicode import passlib.utils.handlers as uh #pkg @@ -97,7 +97,7 @@ def _raw_sha_crypt(secret, salt, rounds, hash): dp = extend(tmp.digest(), secret) #calc DS - hash of salt, extended to size of salt - tmp = hash(salt * (16+belem_ord(a[0]))) + tmp = hash(salt * (16+byte_elem_value(a[0]))) ds = extend(tmp.digest(), salt) # diff --git a/passlib/handlers/sun_md5_crypt.py b/passlib/handlers/sun_md5_crypt.py index 7d689e7..184ad1a 100644 --- a/passlib/handlers/sun_md5_crypt.py +++ b/passlib/handlers/sun_md5_crypt.py @@ -18,7 +18,7 @@ from warnings import warn #site #libs from passlib.utils import h64 -from passlib.utils.compat import b, bytes, belem_ord, trange, u, \ +from passlib.utils.compat import b, bytes, byte_elem_value, irange, u, \ uascii_to_str, unicode, str_to_bascii import passlib.utils.handlers as uh #pkg @@ -73,7 +73,7 @@ MAGIC_HAMLET = b( ) #NOTE: these sequences are pre-calculated iteration ranges used by X & Y loops w/in rounds function below -xr = trange(7) +xr = irange(7) _XY_ROUNDS = [ tuple((i,i,i+3) for i in xr), #xrounds 0 tuple((i,i+1,i+4) for i in xr), #xrounds 1 @@ -116,7 +116,7 @@ def raw_sun_md5_crypt(secret, rounds, salt): round = 0 while round < real_rounds: #convert last result byte string to list of byte-ints for easy access - rval = [ belem_ord(c) for c in result ].__getitem__ + rval = [ byte_elem_value(c) for c in result ].__getitem__ #build up X bit by bit x = 0 diff --git a/passlib/tests/genconfig.py b/passlib/tests/genconfig.py index dc87d02..83f016b 100644 --- a/passlib/tests/genconfig.py +++ b/passlib/tests/genconfig.py @@ -77,7 +77,7 @@ class HashTimer(object): # self.samples = samples self.cache = {} - self.srange = trange(samples) + self.srange = irange(samples) def time_encrypt(self, rounds): "check how long encryption for a given number of rounds will take" diff --git a/passlib/tests/test_apache.py b/passlib/tests/test_apache.py index 052c42e..d3b4ab8 100644 --- a/passlib/tests/test_apache.py +++ b/passlib/tests/test_apache.py @@ -30,7 +30,7 @@ def backdate_file_mtime(path, offset=10): #========================================================= class HtpasswdFileTest(TestCase): "test HtpasswdFile class" - case_prefix = "HtpasswdFile" + descriptionPrefix = "HtpasswdFile" sample_01 = b('user2:2CHkkwa2AtqGs\nuser3:{SHA}3ipNV1GrBtxPmHFC21fCbVCSXIo=\nuser4:pass4\nuser1:$apr1$t4tc7jTh$GPIWVUo8sQKJlUdV8V5vu0\n') sample_02 = b('user3:{SHA}3ipNV1GrBtxPmHFC21fCbVCSXIo=\nuser4:pass4\n') @@ -205,7 +205,7 @@ class HtpasswdFileTest(TestCase): #========================================================= class HtdigestFileTest(TestCase): "test HtdigestFile class" - case_prefix = "HtdigestFile" + descriptionPrefix = "HtdigestFile" sample_01 = b('user2:realm:549d2a5f4659ab39a80dac99e159ab19\nuser3:realm:a500bb8c02f6a9170ae46af10c898744\nuser4:realm:ab7b5d5f28ccc7666315f508c7358519\nuser1:realm:2a6cf53e7d8f8cf39d946dc880b14128\n') sample_02 = b('user3:realm:a500bb8c02f6a9170ae46af10c898744\nuser4:realm:ab7b5d5f28ccc7666315f508c7358519\n') diff --git a/passlib/tests/test_context.py b/passlib/tests/test_context.py index 46cd3f5..f9edafc 100644 --- a/passlib/tests/test_context.py +++ b/passlib/tests/test_context.py @@ -22,7 +22,7 @@ from passlib.exc import PasslibConfigWarning from passlib.utils import tick, to_bytes, to_unicode from passlib.utils.compat import irange, u import passlib.utils.handlers as uh -from passlib.tests.utils import TestCase, mktemp, catch_all_warnings, \ +from passlib.tests.utils import TestCase, mktemp, catch_warnings, \ gae_env, set_file from passlib.registry import register_crypt_handler_path, has_crypt_handler, \ _unload_handler_name as unload_handler_name @@ -37,7 +37,7 @@ class CryptPolicyTest(TestCase): #TODO: need to test user categories w/in all this - case_prefix = "CryptPolicy" + descriptionPrefix = "CryptPolicy" #========================================================= #sample crypt policies used for testing @@ -472,11 +472,7 @@ admin__context__deprecated = des_crypt, bsdi_crypt def test_15_min_verify_time(self): "test get_min_verify_time() method" # silence deprecation warnings for min verify time - with catch_all_warnings(): - warnings.filterwarnings("ignore", category=DeprecationWarning) - self._test_15() - - def _test_15(self): + warnings.filterwarnings("ignore", category=DeprecationWarning) pa = CryptPolicy() self.assertEqual(pa.get_min_verify_time(), 0) @@ -526,7 +522,7 @@ admin__context__deprecated = des_crypt, bsdi_crypt #========================================================= class CryptContextTest(TestCase): "test CryptContext object's behavior" - case_prefix = "CryptContext" + descriptionPrefix = "CryptContext" #========================================================= #constructor @@ -632,7 +628,7 @@ class CryptContextTest(TestCase): ) # min rounds - with catch_all_warnings(record=True) as wlog: + with catch_warnings(record=True) as wlog: # set below handler min c2 = cc.replace(all__min_rounds=500, all__max_rounds=None, @@ -663,7 +659,7 @@ class CryptContextTest(TestCase): self.consumeWarningList(wlog) # max rounds - with catch_all_warnings(record=True) as wlog: + with catch_warnings(record=True) as wlog: # set above handler max c2 = cc.replace(all__max_rounds=int(1e9)+500, all__min_rounds=None, all__default_rounds=int(1e9)+500) @@ -824,7 +820,7 @@ class CryptContextTest(TestCase): # which is much cheaper, and shares the same codebase. # min rounds - with catch_all_warnings(record=True) as wlog: + with catch_warnings(record=True) as wlog: self.assertEqual( cc.encrypt("password", rounds=1999, salt="nacl"), '$5$rounds=2000$nacl$9/lTZ5nrfPuz8vphznnmHuDGFuvjSNvOEDsGmGfsS97', @@ -950,7 +946,7 @@ class CryptContextTest(TestCase): return to_unicode(secret + 'x') # silence deprecation warnings for min verify time - with catch_all_warnings(record=True) as wlog: + with catch_warnings(record=True) as wlog: cc = CryptContext([TimedHash], min_verify_time=min_verify_time) self.consumeWarningList(wlog, DeprecationWarning) @@ -977,7 +973,7 @@ class CryptContextTest(TestCase): #ensure taking longer emits a warning. TimedHash.delay = max_delay - with catch_all_warnings(record=True) as wlog: + with catch_warnings(record=True) as wlog: elapsed, result = timecall(cc.verify, "blob", "stubx") self.assertFalse(result) self.assertAlmostEqual(elapsed, max_delay, delta=delta) @@ -1024,7 +1020,7 @@ class CryptContextTest(TestCase): GOOD1 = "$2a$12$oaQbBqq8JnSM1NHRPQGXOOm4GCUMqp7meTnkft4zgSnrbhoKdDV0C" ctx = CryptContext(["bcrypt"]) - with catch_all_warnings(record=True) as wlog: + with catch_warnings(record=True) as wlog: self.assertTrue(ctx.hash_needs_update(BAD1)) self.assertFalse(ctx.hash_needs_update(GOOD1)) @@ -1101,7 +1097,7 @@ class dummy_2(uh.StaticHandler): name = "dummy_2" class LazyCryptContextTest(TestCase): - case_prefix = "LazyCryptContext" + descriptionPrefix = "LazyCryptContext" def setUp(self): unload_handler_name("dummy_2") diff --git a/passlib/tests/test_ext_django.py b/passlib/tests/test_ext_django.py index 4627d30..13085fc 100644 --- a/passlib/tests/test_ext_django.py +++ b/passlib/tests/test_ext_django.py @@ -13,7 +13,7 @@ from passlib.context import CryptContext, CryptPolicy from passlib.apps import django_context from passlib.ext.django import utils from passlib.hash import sha256_crypt -from passlib.tests.utils import TestCase, unittest, ut_version, catch_all_warnings +from passlib.tests.utils import TestCase, unittest, ut_version, catch_warnings import passlib.tests.test_handlers as th from passlib.utils.compat import iteritems, get_method_function, unicode from passlib.registry import get_crypt_handler @@ -130,7 +130,7 @@ def get_cc_rounds(**kwds): class PatchTest(TestCase): "test passlib.ext.django.utils:set_django_password_context" - case_prefix = "passlib.ext.django utils" + descriptionPrefix = "passlib.ext.django utils" def assert_unpatched(self): "helper to ensure django hasn't been patched" @@ -223,7 +223,7 @@ class PatchTest(TestCase): def dummy(): pass - with catch_all_warnings(record=True) as wlog: + with catch_warnings(record=True) as wlog: #patch to use stock django context utils.set_django_password_context(django_context) self.assert_patched(context=django_context) @@ -422,7 +422,7 @@ if has_django0: class PluginTest(TestCase): "test django plugin via settings" - case_prefix = "passlib.ext.django plugin" + descriptionPrefix = "passlib.ext.django plugin" def setUp(self): #remove django patch diff --git a/passlib/tests/test_handlers.py b/passlib/tests/test_handlers.py index 580cc08..4547e56 100644 --- a/passlib/tests/test_handlers.py +++ b/passlib/tests/test_handlers.py @@ -12,7 +12,7 @@ import warnings from passlib import hash from passlib.utils.compat import irange from passlib.tests.utils import TestCase, HandlerCase, create_backend_case, \ - enable_option, b, catch_all_warnings, UserHandlerMixin, randintgauss + enable_option, b, catch_warnings, UserHandlerMixin, randintgauss from passlib.utils.compat import u #module @@ -262,7 +262,7 @@ class _bcrypt_test(HandlerCase): check_padding(bcrypt.encrypt("bob", rounds=bcrypt.min_rounds)) # some things that will raise warnings - with catch_all_warnings(record=True) as wlog: + with catch_warnings(record=True) as wlog: # # test genconfig() corrects invalid salts & issues warning. # diff --git a/passlib/tests/test_registry.py b/passlib/tests/test_registry.py index c6fa9f6..1990919 100644 --- a/passlib/tests/test_registry.py +++ b/passlib/tests/test_registry.py @@ -16,7 +16,7 @@ from passlib import hash, registry from passlib.registry import register_crypt_handler, register_crypt_handler_path, \ get_crypt_handler, list_crypt_handlers, _unload_handler_name as unload_handler_name import passlib.utils.handlers as uh -from passlib.tests.utils import TestCase, mktemp, catch_all_warnings +from passlib.tests.utils import TestCase, mktemp, catch_warnings #module log = getLogger(__name__) @@ -40,7 +40,7 @@ dummy_x = 1 #========================================================= class RegistryTest(TestCase): - case_prefix = "passlib registry" + descriptionPrefix = "passlib registry" def tearDown(self): for name in ("dummy_0", "dummy_1", "dummy_x", "dummy_bad"): @@ -112,7 +112,7 @@ class RegistryTest(TestCase): #TODO: check lazy load which calls register_crypt_handler (warning should be issued) sys.modules.pop("passlib.tests._test_bad_register", None) register_crypt_handler_path("dummy_bad", "passlib.tests._test_bad_register") - with catch_all_warnings(): + with catch_warnings(): warnings.filterwarnings("ignore", "xxxxxxxxxx", DeprecationWarning) h = get_crypt_handler("dummy_bad") from passlib.tests import _test_bad_register as tbr @@ -158,7 +158,7 @@ class RegistryTest(TestCase): register_crypt_handler(dummy_1) self.assertIs(get_crypt_handler("dummy_1"), dummy_1) - with catch_all_warnings(): + with catch_warnings(): warnings.filterwarnings("ignore", "handler names should be lower-case, and use underscores instead of hyphens:.*", UserWarning) self.assertIs(get_crypt_handler("DUMMY-1"), dummy_1) diff --git a/passlib/tests/test_utils.py b/passlib/tests/test_utils.py index 1baebd7..2d9d147 100644 --- a/passlib/tests/test_utils.py +++ b/passlib/tests/test_utils.py @@ -12,8 +12,8 @@ import warnings #pkg #module from passlib.utils.compat import b, bytes, bascii_to_str, irange, PY2, PY3, u, \ - unicode, bjoin -from passlib.tests.utils import TestCase, Params as ak, enable_option, catch_all_warnings + unicode, join_bytes +from passlib.tests.utils import TestCase, Params as ak, enable_option, catch_warnings def hb(source): return unhexlify(b(source)) @@ -546,7 +546,7 @@ class _Base64Test(TestCase): # helper to generate bytemap-specific strings def m(self, *offsets): "generate byte string from offsets" - return bjoin(self.engine.bytemap[o:o+1] for o in offsets) + return join_bytes(self.engine.bytemap[o:o+1] for o in offsets) #========================================================= # test encode_bytes @@ -842,7 +842,7 @@ from passlib.utils import h64, h64big class H64_Test(_Base64Test): "test H64 codec functions" engine = h64 - case_prefix = "h64 codec" + descriptionPrefix = "h64 codec" encoded_data = [ #test lengths 0..6 to ensure tail is encoded properly @@ -867,7 +867,7 @@ class H64_Test(_Base64Test): class H64Big_Test(_Base64Test): "test H64Big codec functions" engine = h64big - case_prefix = "h64big codec" + descriptionPrefix = "h64big codec" encoded_data = [ #test lengths 0..6 to ensure tail is encoded properly @@ -960,12 +960,12 @@ has_ssl_md4 = (md4_mod.md4 is not md4_mod._builtin_md4) if has_ssl_md4: class MD4_SSL_Test(_MD4_Test): - case_prefix = "MD4 (SSL version)" + descriptionPrefix = "MD4 (SSL version)" hash = staticmethod(md4_mod.md4) if not has_ssl_md4 or enable_option("cover"): class MD4_Builtin_Test(_MD4_Test): - case_prefix = "MD4 (builtin version)" + descriptionPrefix = "MD4 (builtin version)" hash = md4_mod._builtin_md4 #========================================================= @@ -1010,7 +1010,7 @@ class CryptoTest(TestCase): self.assertRaises(TypeError, norm_hash_name, None) # test selected results - with catch_all_warnings(): + with catch_warnings(): warnings.filterwarnings("ignore", '.*unknown hash') for row in chain(_nhn_hash_names, self.ndn_values): for idx, format in enumerate(self.ndn_formats): @@ -1227,12 +1227,12 @@ has_m2crypto = (pbkdf2._EVP is not None) if has_m2crypto: class Pbkdf2_M2Crypto_Test(_Pbkdf2BackendTest): - case_prefix = "pbkdf2 (m2crypto backend)" + descriptionPrefix = "pbkdf2 (m2crypto backend)" enable_m2crypto = True if not has_m2crypto or enable_option("cover"): class Pbkdf2_Builtin_Test(_Pbkdf2BackendTest): - case_prefix = "pbkdf2 (builtin backend)" + descriptionPrefix = "pbkdf2 (builtin backend)" enable_m2crypto = False #========================================================= diff --git a/passlib/tests/test_utils_handlers.py b/passlib/tests/test_utils_handlers.py index 7079917..4788798 100644 --- a/passlib/tests/test_utils_handlers.py +++ b/passlib/tests/test_utils_handlers.py @@ -18,7 +18,7 @@ from passlib.utils import getrandstr, JYTHON, rng, to_unicode from passlib.utils.compat import b, bytes, bascii_to_str, str_to_uascii, \ uascii_to_str, unicode, PY_MAX_25 import passlib.utils.handlers as uh -from passlib.tests.utils import HandlerCase, TestCase, catch_all_warnings +from passlib.tests.utils import HandlerCase, TestCase, catch_warnings from passlib.utils.compat import u #module log = getLogger(__name__) @@ -190,7 +190,7 @@ class SkeletonTest(TestCase): self.assertIn(norm_salt(use_defaults=True), salts3) # check explicit salts - with catch_all_warnings(record=True) as wlog: + with catch_warnings(record=True) as wlog: # check too-small salts self.assertRaises(ValueError, norm_salt, salt='') @@ -211,7 +211,7 @@ class SkeletonTest(TestCase): self.consumeWarningList(wlog, PasslibHashWarning) #check generated salts - with catch_all_warnings(record=True) as wlog: + with catch_warnings(record=True) as wlog: # check too-small salt size self.assertRaises(ValueError, gen_salt, 0) @@ -233,7 +233,7 @@ class SkeletonTest(TestCase): # test with max_salt_size=None del d1.max_salt_size - with catch_all_warnings(record=True) as wlog: + with catch_warnings(record=True) as wlog: self.assertEqual(len(gen_salt(None)), 3) self.assertEqual(len(gen_salt(5)), 5) self.consumeWarningList(wlog) @@ -259,7 +259,7 @@ class SkeletonTest(TestCase): self.assertEqual(norm_rounds(use_defaults=True), 2) # check explicit rounds - with catch_all_warnings(record=True) as wlog: + with catch_warnings(record=True) as wlog: # too small self.assertRaises(ValueError, norm_rounds, rounds=0) self.consumeWarningList(wlog) diff --git a/passlib/tests/utils.py b/passlib/tests/utils.py index 2f36d3e..8e2883b 100644 --- a/passlib/tests/utils.py +++ b/passlib/tests/utils.py @@ -22,7 +22,10 @@ except ImportError: if PY27 or PY_MIN_32: ut_version = 2 else: - # XXX: issue warning and deprecate support sometime? + # older versions of python will need to install the unittest2 + # backport (named unittest2_3k for 3.0/3.1) + warn("please install unittest2 for python %d.%d, it will be required " + "as of passlib 1.7" % sys.version_info[:2]) ut_version = 1 import warnings @@ -136,99 +139,88 @@ def get_file(path): class TestCase(unittest.TestCase): """passlib-specific test case class - this class mainly overriddes many of the common assert methods - so to give a default message which includes the values - as well as the class-specific case_prefix string. - this latter bit makes the output of various test cases - easier to distinguish from eachother. + this class adds a number of features to the standard TestCase... + * common prefix for all test descriptions + * resets warnings filter & registry for every test + * tweaks to message formatting + * __msg__ kwd added to assertRaises() + * backport of a bunch of unittest2 features + * suite of methods for matching against warnings """ + #==================================================================== + # add various custom features + #==================================================================== - #============================================================= - #make it ease for test cases to add common prefix to all descs - #============================================================= - #: string or method returning string - prepended to all tests in TestCase - case_prefix = None + #---------------------------------------------------------------- + # make it easy for test cases to add common prefix to shortDescription + #---------------------------------------------------------------- - #: flag to disable feature - longDescription = True + # string prepended to all tests in TestCase + descriptionPrefix = None def shortDescription(self): - "wrap shortDescription() method to prepend case_prefix" + "wrap shortDescription() method to prepend descriptionPrefix" desc = super(TestCase, self).shortDescription() - if desc is None: - #would still like to add prefix, but munges things up. - return None - prefix = self.case_prefix - if prefix and self.longDescription: - if callable(prefix): - prefix = prefix() - desc = "%s: %s" % (prefix, desc) + prefix = self.descriptionPrefix + if prefix: + desc = "%s: %s" % (prefix, desc or str(self)) return desc - #============================================================ - #hack to set UT2 private skip attrs to mirror nose's __test__ attr - #============================================================ - if ut_version >= 2: + #---------------------------------------------------------------- + # hack things so nose and ut2 both skip subclasses who have + # "__unittest_skip=True" set, or whose names start with "_" + #---------------------------------------------------------------- + @classproperty + def __unittest_skip__(cls): + # NOTE: this attr is technically a unittest2 internal detail. + name = cls.__name__ + return name.startswith("_") or \ + getattr(cls, "_%s__unittest_skip" % name, False) - @classproperty - def __unittest_skip__(cls): - # make this mirror nose's '__test__' attr - return not getattr(cls, "__test__", True) + # make this mirror nose's '__test__' attr + return not getattr(cls, "__test__", True) @classproperty def __test__(cls): - # nose uses to this to skip tests. overridding this to - # skip classes with '__<cls>_unittest_skip' set - that way - # we can omit specific classes without affecting subclasses. - name = cls.__name__ - if name.startswith("_"): - return False - if getattr(cls, "_%s__unittest_skip" % name, False): - return False - return True + # make nose just proxy __unittest_skip__ + return not cls.__unittest_skip__ + # flag to skip *this* class __unittest_skip = True - #============================================================ - # tweak msg formatting for some assert methods - #============================================================ - longMessage = True #override python default (False) - - def _formatMessage(self, msg, std): - "override UT2's _formatMessage - only use longMessage if msg ends with ':'" - if not msg: - return std - if not self.longMessage or not msg.endswith(":"): - return msg.rstrip(":") - return '%s %s' % (msg, std) + #---------------------------------------------------------------- + # reset warning filters & registry before each test + #---------------------------------------------------------------- - #============================================================ - #override some unittest1 methods to support _formatMessage - #============================================================ - if ut_version < 2: + # flag to enable this feature + resetWarningState = True - def assertEqual(self, real, correct, msg=None): - if real != correct: - std = "got %r, expected would equal %r" % (real, correct) - msg = self._formatMessage(msg, std) - raise self.failureException(msg) - - def assertNotEqual(self, real, correct, msg=None): - if real == correct: - std = "got %r, expected would not equal %r" % (real, correct) - msg = self._formatMessage(msg, std) - raise self.failureException(msg) + def setUp(self): + unittest.TestCase.setUp(self) + if self.resetWarningState: + ctx = reset_warnings() + ctx.__enter__() + self.addCleanup(ctx.__exit__) + + #---------------------------------------------------------------- + # tweak message formatting so longMessage mode is only enabled + # if msg ends with ":", and turn on longMessage by default. + #---------------------------------------------------------------- + longMessage = True - assertEquals = assertEqual - assertNotEquals = assertNotEqual + def _formatMessage(self, msg, std): + if self.longMessage and msg and msg.rstrip().endswith(":"): + return '%s %s' % (msg.rstrip(), std) + else: + return msg or std - #NOTE: overriding this even under UT2. - #FIXME: this doesn't support the fancy context manager UT2 provides. - def assertRaises(self, _exc_type, _callable, *args, **kwds): - #NOTE: overriding this for format ability, - # but ALSO adding "__msg__" kwd so we can set custom msg + #---------------------------------------------------------------- + # override assertRaises() to support '__msg__' keyword + #---------------------------------------------------------------- + def assertRaises(self, _exc_type, _callable=None, *args, **kwds): msg = kwds.pop("__msg__", None) if _callable is None: + # FIXME: this ignores 'msg' return super(TestCase, self).assertRaises(_exc_type, None, *args, **kwds) try: @@ -239,11 +231,46 @@ class TestCase(unittest.TestCase): _exc_type) raise self.failureException(self._formatMessage(msg, std)) - #=============================================================== - #backport some methods from unittest2 - #=============================================================== + #---------------------------------------------------------------- + # null out a bunch of deprecated aliases so I stop using them + #---------------------------------------------------------------- + assertEquals = assertNotEquals = assertRegexpMatches = None + + #==================================================================== + # backport some methods from unittest2 + #==================================================================== if ut_version < 2: + #---------------------------------------------------------------- + # simplistic backport of addCleanup() framework + #---------------------------------------------------------------- + _cleanups = None + + def addCleanup(self, function, *args, **kwds): + queue = self._cleanups + if queue is None: + queue = self._cleanups = [] + queue.append((function, args, kwds)) + + def doCleanups(self): + queue = self._cleanups + while queue: + func, args, kwds = queue.pop() + func(*args, **kwds) + + def tearDown(self): + self.doCleanups() + unittest.TestCase.tearDown(self) + + #---------------------------------------------------------------- + # backport skipTest (requires nose to work) + #---------------------------------------------------------------- + def skipTest(self, reason): + raise SkipTest(reason) + + #---------------------------------------------------------------- + # backport various assert tests added in unittest2 + #---------------------------------------------------------------- def assertIs(self, real, correct, msg=None): if real is not correct: std = "got %r, expected would be %r" % (real, correct) @@ -262,9 +289,6 @@ class TestCase(unittest.TestCase): msg = self._formatMessage(msg, std) raise self.failureException(msg) - def skipTest(self, reason): - raise SkipTest(reason) - def assertAlmostEqual(self, first, second, places=None, msg=None, delta=None): """Fail if the two objects are unequal as determined by their difference rounded to the given number of decimal places @@ -303,45 +327,68 @@ class TestCase(unittest.TestCase): msg = self._formatMessage(msg, standardMsg) raise self.failureException(msg) + def assertLess(self, left, right, msg=None): + if left >= right: + std = "%r not less than %r" % (left, right) + raise self.failureException(self._formatMessage(msg, std)) + + def assertGreaterEqual(self, left, right, msg=None): + if left < right: + std = "%r less than %r" % (left, right) + raise self.failureException(self._formatMessage(msg, std)) + + def assertIn(self, elem, container, msg=None): + if elem not in container: + std = "%r not found in %r" % (elem, container) + raise self.failureException(self._formatMessage(msg, std)) + + def assertNotIn(self, elem, container, msg=None): + if elem in container: + std = "%r unexpectedly in %r" % (elem, container) + raise self.failureException(self._formatMessage(msg, std)) + + #---------------------------------------------------------------- + # override some unittest1 methods to support _formatMessage + #---------------------------------------------------------------- + def assertEqual(self, real, correct, msg=None): + if real != correct: + std = "got %r, expected would equal %r" % (real, correct) + msg = self._formatMessage(msg, std) + raise self.failureException(msg) + + def assertNotEqual(self, real, correct, msg=None): + if real == correct: + std = "got %r, expected would not equal %r" % (real, correct) + msg = self._formatMessage(msg, std) + raise self.failureException(msg) + + #---------------------------------------------------------------- + # backport assertRegex() alias from 3.2 to 2.7/3.1 + #---------------------------------------------------------------- if not hasattr(unittest.TestCase, "assertRegex"): - # assertRegexpMatches() added in 2.7/UT2 and 3.1, renamed to - # assertRegex() in 3.2; this code ensures assertRegex() is defined. if hasattr(unittest.TestCase, "assertRegexpMatches"): + # was present in 2.7/3.1 under name assertRegexpMatches assertRegex = unittest.TestCase.assertRegexpMatches else: + # 3.0 and <= 2.6 didn't have this method at all def assertRegex(self, text, expected_regex, msg=None): """Fail the test unless the text matches the regular expression.""" if isinstance(expected_regex, sb_types): assert expected_regex, "expected_regex must not be empty." expected_regex = re.compile(expected_regex) if not expected_regex.search(text): - msg = msg or "Regex didn't match" - msg = '%s: %r not found in %r' % (msg, expected_regex.pattern, text) - raise self.failureException(msg) + msg = msg or "Regex didn't match: " + std = '%r not found in %r' % (msg, expected_regex.pattern, text) + raise self.failureException(self._formatMessage(msg, std)) #============================================================ - #add some custom methods + # custom methods for matching warnings #============================================================ - def assertFunctionResults(self, func, cases): - """helper for running through function calls. - - func should be the function to call. - cases should be list of Param instances, - where first position argument is expected return value, - and remaining args and kwds are passed to function. - """ - for elem in cases: - elem = Params.norm(elem) - correct = elem.args[0] - result = func(*elem.args[1:], **elem.kwds) - msg = "error for case %r:" % (elem.render(1),) - self.assertEqual(result, correct, msg) - def assertWarning(self, warning, message_re=None, message=None, category=None, - ##filename=None, filename_re=None, - ##lineno=None, + filename_re=None, filename=None, + lineno=None, msg=None, ): "check if WarningMessage instance (as returned by catch_warnings) matches parameters" @@ -355,7 +402,7 @@ class TestCase(unittest.TestCase): # no original WarningMessage, passed raw Warning wmsg = None - #tests that can use a warning instance or WarningMessage object + # tests that can use a warning instance or WarningMessage object if message: self.assertEqual(str(warning), message, msg) if message_re: @@ -363,29 +410,29 @@ class TestCase(unittest.TestCase): if category: self.assertIsInstance(warning, category, msg) - #commented out until needed... - ###tests that require a WarningMessage object - ##if filename or filename_re: - ## if not wmsg: - ## raise TypeError("can't read filename from warning object") - ## real = wmsg.filename - ## if real.endswith(".pyc") or real.endswith(".pyo"): - ## #FIXME: should use a stdlib call to resolve this back - ## # to original module's path - ## real = real[:-1] - ## if filename: - ## self.assertEqual(real, filename, msg) - ## if filename_re: - ## self.assertRegex(real, filename_re, msg) - ##if lineno: - ## if not wmsg: - ## raise TypeError("can't read lineno from warning object") - ## self.assertEqual(wmsg.lineno, lineno, msg) + # tests that require a WarningMessage object + if filename or filename_re: + if not wmsg: + raise TypeError("matching on filename requires a " + "WarningMessage instance") + real = wmsg.filename + if real.endswith(".pyc") or real.endswith(".pyo"): + # FIXME: should use a stdlib call to resolve this back + # to module's original filename. + real = real[:-1] + if filename: + self.assertEqual(real, filename, msg) + if filename_re: + self.assertRegex(real, filename_re, msg) + if lineno: + if not wmsg: + raise TypeError("matching on lineno requires a " + "WarningMessage instance") + self.assertEqual(wmsg.lineno, lineno, msg) def assertWarningList(self, wlist, desc=None, msg=None): """check that warning list (e.g. from catch_warnings) matches pattern""" - # TODO: make this display better diff of *which* warnings did not match, - # and make use of _formatWarning below. + # TODO: make this display better diff of *which* warnings did not match if not isinstance(desc, (list,tuple)): desc = [] if desc is None else [desc] for idx, entry in enumerate(desc): @@ -407,6 +454,11 @@ class TestCase(unittest.TestCase): (len(desc), len(wlist), self._formatWarningList(wlist), desc) raise self.failureException(self._formatMessage(msg, std)) + def consumeWarningList(self, wlist, *args, **kwds): + """assertWarningList() variant that clears list afterwards""" + self.assertWarningList(wlist, *args, **kwds) + del wlist[:] + def _formatWarning(self, entry): tail = "" if hasattr(entry, "message"): @@ -422,10 +474,23 @@ class TestCase(unittest.TestCase): def _formatWarningList(self, wlist): return "[%s]" % ", ".join(self._formatWarning(entry) for entry in wlist) - def consumeWarningList(self, wlist, *args, **kwds): - """assertWarningList() variant that clears list afterwards""" - self.assertWarningList(wlist, *args, **kwds) - del wlist[:] + #============================================================ + # misc custom methods + #============================================================ + def assertFunctionResults(self, func, cases): + """helper for running through function calls. + + func should be the function to call. + cases should be list of Param instances, + where first position argument is expected return value, + and remaining args and kwds are passed to function. + """ + for elem in cases: + elem = Params.norm(elem) + correct = elem.args[0] + result = func(*elem.args[1:], **elem.kwds) + msg = "error for case %r:" % (elem.render(1),) + self.assertEqual(result, correct, msg) #============================================================ #eoc @@ -601,10 +666,8 @@ class HandlerCase(TestCase): #========================================================= __unittest_skip = True - #optional prefix to prepend to name of test method as it's called, - #useful when multiple handler test classes being run. - #default behavior should be sufficient - def case_prefix(self): + @property + def descriptionPrefix(self): handler = self.handler name = handler.name if hasattr(handler, "get_backend"): @@ -627,11 +690,7 @@ class HandlerCase(TestCase): # setup / cleanup #========================================================= def setUp(self): - # backup warning filter state; set to display all warnings during tests; - # and restore filter state after test. - ctx = catch_all_warnings() - ctx.__enter__() - self._restore_warnings = ctx.__exit__ + TestCase.setUp(self) # if needed, select specific backend for duration of test handler = self.handler @@ -641,7 +700,7 @@ class HandlerCase(TestCase): raise RuntimeError("handler doesn't support multiple backends") if backend == "os_crypt" and not handler.has_backend("os_crypt"): self._patch_safe_crypt() - self._orig_backend = handler.get_backend() + self.addCleanup(handler.set_backend, handler.get_backend()) handler.set_backend(backend) def _patch_safe_crypt(self): @@ -660,23 +719,10 @@ class HandlerCase(TestCase): hash = handler.genhash(secret, hash) assert isinstance(hash, str) return hash - self._orig_crypt = mod._crypt + self.addCleanup(setattr, mod, "_crypt", mod._crypt) mod._crypt = crypt_stub self.using_patched_crypt = True - def tearDown(self): - # unpatch safe_crypt() - if self._orig_crypt: - import passlib.utils as mod - mod._crypt = self._orig_crypt - - # restore original backend - if self._orig_backend: - self.handler.set_backend(self._orig_backend) - - # restore warning filters - self._restore_warnings() - #========================================================= # basic tests #========================================================= @@ -1468,11 +1514,8 @@ class HandlerCase(TestCase): "config=%r hash=%r" % (name, other, secret, kwds, hash)) count +=1 - name = self.case_prefix - if not isinstance(name, str): - name = name() log.debug("fuzz test: %r checked %d passwords against %d verifiers (%s)", - name, count, len(verifiers), + self.descriptionPrefix, count, len(verifiers), ", ".join(vname(v) for v in verifiers)) def get_fuzz_verifiers(self): @@ -1745,7 +1788,7 @@ def create_backend_case(base_class, backend, module=None): "%s_%s" % (backend, handler.name), (base_class,), dict( - case_prefix = "%s (%s backend)" % (handler.name, backend), + descriptionPrefix = "%s (%s backend)" % (handler.name, backend), backend = backend, __module__= module or base_class.__module__, ) @@ -1802,9 +1845,11 @@ def mktemp(*args, **kwds): os.close(fd) return path -#========================================================= -#make sure catch_warnings() is available -#========================================================= +#============================================================================= +# warnings helpers +#============================================================================= + +# make sure catch_warnings() is available try: from warnings import catch_warnings except ImportError: @@ -1893,33 +1938,59 @@ except ImportError: self._module.filters = self._filters self._module.showwarning = self._showwarning -class catch_all_warnings(catch_warnings): - "catch_warnings() wrapper which clears filter" - def __init__(self, reset=".*", **kwds): - super(catch_all_warnings, self).__init__(**kwds) - self._reset_pat = reset +class reset_warnings(catch_warnings): + "catch_warnings() wrapper which clears warning registry & filters" + def __init__(self, reset_filter="always", reset_registry=".*", **kwds): + super(reset_warnings, self).__init__(**kwds) + self._reset_filter = reset_filter + self._reset_registry = re.compile(reset_registry) if reset_registry else None def __enter__(self): # let parent class archive filter state - ret = super(catch_all_warnings, self).__enter__() + ret = super(reset_warnings, self).__enter__() # reset the filter to list everything - warnings.resetwarnings() - warnings.simplefilter("always") + if self._reset_filter: + warnings.resetwarnings() + warnings.simplefilter(self._reset_filter) - # wipe the __warningregistry__ off the map, so warnings - # reliably get reported per-test. - # XXX: *could* restore state - pattern = self._reset_pat + # archive and clear the __warningregistry__ key for all modules + # that match the 'reset' pattern. + pattern = self._reset_registry if pattern: - import sys - key = "__warningregistry__" - for mod in sys.modules.values(): - if hasattr(mod, key) and re.match(pattern, mod.__name__): - getattr(mod, key).clear() - + orig = self._orig_registry = {} + for name, mod in sys.modules.items(): + if pattern.match(name): + reg = getattr(mod, "__warningregistry__", None) + if reg: + orig[name] = reg.copy() + reg.clear() return ret -#========================================================= -#EOF -#========================================================= + def __exit__(self, *exc_info): + # restore warning registry for all modules + pattern = self._reset_registry + if pattern: + # restore archived registry data + orig = self._orig_registry + for name, content in iteritems(orig): + mod = sys.modules.get(name) + if mod is None: + continue + reg = getattr(mod, "__warningregistry__", None) + if reg is None: + setattr(mod, "__warningregistry__", content) + else: + reg.clear() + reg.update(content) + # clear all registry entries that we didn't archive + for name, mod in sys.modules.items(): + if pattern.match(name) and name not in orig: + reg = getattr(mod, "__warningregistry__", None) + if reg: + reg.clear() + super(reset_warnings, self).__exit__(*exc_info) + +#============================================================================= +# eof +#============================================================================= diff --git a/passlib/utils/__init__.py b/passlib/utils/__init__.py index b763898..6d26d91 100644 --- a/passlib/utils/__init__.py +++ b/passlib/utils/__init__.py @@ -17,9 +17,9 @@ import unicodedata from warnings import warn #site #pkg -from passlib.utils.compat import _add_doc, b, bytes, bjoin, bjoin_ints, \ - bjoin_elems, exc_err, irange, imap, PY3, u, \ - ujoin, unicode, belem_ord +from passlib.utils.compat import add_doc, b, bytes, join_bytes, join_byte_values, \ + join_byte_elems, exc_err, irange, imap, PY3, u, \ + join_unicode, unicode, byte_elem_value #local __all__ = [ # constants @@ -353,7 +353,7 @@ def saslprep(source, errname="value"): # - strip 'commonly mapped to nothing' chars (stringprep B.1) in_table_c12 = stringprep.in_table_c12 in_table_b1 = stringprep.in_table_b1 - data = ujoin( + data = join_unicode( _USPACE if in_table_c12(c) else c for c in source if not in_table_b1(c) @@ -435,8 +435,8 @@ if PY3: return bytes(l ^ r for l, r in zip(left, right)) else: def xor_bytes(left, right): - return bjoin(chr(ord(l) ^ ord(r)) for l, r in zip(left, right)) -_add_doc(xor_bytes, "perform bitwise-xor of two byte strings") + return join_bytes(chr(ord(l) ^ ord(r)) for l, r in zip(left, right)) +add_doc(xor_bytes, "perform bitwise-xor of two byte strings") def render_bytes(source, *args): """helper for using formatting operator with bytes. @@ -463,17 +463,17 @@ def render_bytes(source, *args): @deprecated_function(deprecated="1.6", removed="1.8") def bytes_to_int(value): "decode string of bytes as single big-endian integer" - from passlib.utils.compat import belem_ord + from passlib.utils.compat import byte_elem_value out = 0 for v in value: - out = (out<<8) | belem_ord(v) + out = (out<<8) | byte_elem_value(v) return out @deprecated_function(deprecated="1.6", removed="1.8") def int_to_bytes(value, count): "encodes integer into single big-endian byte string" assert value < (1<<(8*count)), "value too large for %d bytes: %d" % (count, value) - return bjoin_ints( + return join_byte_values( ((value>>s) & 0xff) for s in irange(8*count-8,-8,-8) ) @@ -586,7 +586,7 @@ else: raise TypeError("%s must be unicode or bytes, not %s" % (errname, type(source))) -_add_doc(to_native_str, +add_doc(to_native_str, """take in unicode or bytes, return native string. python 2: encodes unicode using specified encoding, leaves bytes alone. @@ -732,7 +732,7 @@ class Base64Engine(object): else: next_value = (ord(elem) for elem in source).next gen = self._encode_bytes(next_value, chunks, tail) - out = bjoin_elems(imap(self._encode64, gen)) + out = join_byte_elems(imap(self._encode64, gen)) ##if tail: ## padding = self.padding ## if padding: @@ -841,7 +841,7 @@ class Base64Engine(object): else: next_value = imap(self._decode64, source).next try: - return bjoin_ints(self._decode_bytes(next_value, chunks, tail)) + return join_byte_values(self._decode_bytes(next_value, chunks, tail)) except KeyError: err = exc_err() raise ValueError("invalid character: %r" % (err.args[0],)) @@ -1003,19 +1003,19 @@ class Base64Engine(object): "encode byte string, first transposing source using offset list" if not isinstance(source, bytes): raise TypeError("source must be bytes, not %s" % (type(source),)) - tmp = bjoin_elems(source[off] for off in offsets) + tmp = join_byte_elems(source[off] for off in offsets) return self.encode_bytes(tmp) def decode_transposed_bytes(self, source, offsets): "decode byte string, then reverse transposition described by offset list" # NOTE: if transposition does not use all bytes of source, - # the original can't be recovered... and bjoin_elems() will throw + # the original can't be recovered... and join_byte_elems() will throw # an error because 1+ values in <buf> will be None. tmp = self.decode_bytes(source) buf = [None] * len(offsets) for off, char in zip(offsets, tmp): buf[off] = char - return bjoin_elems(buf) + return join_byte_elems(buf) #============================================================= # integer decoding helpers - mainly used by des_crypt family @@ -1137,7 +1137,7 @@ class Base64Engine(object): else: itr = irange(0, bits, 6) # padding is msb, so no change needed. - return bjoin_elems(imap(self._encode64, + return join_byte_elems(imap(self._encode64, ((value>>off) & 0x3f for off in itr))) #--------------------------------------------- @@ -1160,7 +1160,7 @@ class Base64Engine(object): raw = [value & 0x3f, (value>>6) & 0x3f] if self.big: raw = reversed(raw) - return bjoin_elems(imap(self._encode64, raw)) + return join_byte_elems(imap(self._encode64, raw)) def encode_int24(self, value): "encodes 24-bit integer -> 4 char string" @@ -1170,7 +1170,7 @@ class Base64Engine(object): (value>>12) & 0x3f, (value>>18) & 0x3f] if self.big: raw = reversed(raw) - return bjoin_elems(imap(self._encode64, raw)) + return join_byte_elems(imap(self._encode64, raw)) def encode_int64(self, value): """encode 64-bit integer -> 11 char hash64 string @@ -1319,7 +1319,7 @@ else: return None return result -_add_doc(safe_crypt, """wrapper around stdlib's crypt. +add_doc(safe_crypt, """wrapper around stdlib's crypt. This is a wrapper around stdlib's :func:`!crypt.crypt`, which attempts to provide uniform behavior across Python 2 and 3. @@ -1467,7 +1467,7 @@ def getrandbytes(rng, count): yield value & 0xff value >>= 3 i += 1 - return bjoin_ints(helper()) + return join_byte_values(helper()) def getrandstr(rng, charset, count): """return string containing *count* number of chars/bytes, whose elements are drawn from specified charset, using specified rng""" @@ -1494,9 +1494,9 @@ def getrandstr(rng, charset, count): i += 1 if isinstance(charset, unicode): - return ujoin(helper()) + return join_unicode(helper()) else: - return bjoin_elems(helper()) + return join_byte_elems(helper()) _52charset = '2346789ABCDEFGHJKMNPQRTUVWXYZabcdefghjkmnpqrstuvwxyz' diff --git a/passlib/utils/compat.py b/passlib/utils/compat.py index 6bf8c50..0715f28 100644 --- a/passlib/utils/compat.py +++ b/passlib/utils/compat.py @@ -18,8 +18,7 @@ if PY3: else: import __builtin__ as builtins - -def _add_doc(obj, doc): +def add_doc(obj, doc): """add docstring to an object""" obj.__doc__ = doc @@ -31,7 +30,7 @@ __all__ = [ 'PY2', 'PY3', 'PY_MAX_25', 'PY27', 'PY_MIN_32', # io - 'BytesIO', 'StringIO', 'SafeConfigParser', + 'BytesIO', 'StringIO', 'NativeStringIO', 'SafeConfigParser', 'print_', # type detection @@ -45,57 +44,24 @@ __all__ = [ 'unicode', 'bytes', 'sb_types', 'uascii_to_str', 'bascii_to_str', 'str_to_uascii', 'str_to_bascii', - 'ujoin', 'bjoin', 'bjoin_ints', 'bjoin_elems', 'belem_ord', + 'join_unicode', 'join_bytes', + 'join_byte_values', 'join_byte_elems', + 'byte_elem_value', + 'iter_byte_values', # iteration helpers - 'irange', 'trange', #'lrange', + 'irange', #'lrange', 'imap', 'lmap', 'iteritems', 'itervalues', + 'next', # introspection - 'exc_err', 'get_method_function', '_add_doc', + 'exc_err', 'get_method_function', 'add_doc', ] -#============================================================================= -# lazy-loaded aliases (see LazyOverlayModule at bottom) -#============================================================================= -if PY3: - _lazy_attrs = dict( - BytesIO="io.BytesIO", - StringIO="io.StringIO", - SafeConfigParser="configparser.SafeConfigParser", - ) - if PY_MIN_32: - # py32 renamed this, removing old ConfigParser - _lazy_attrs["SafeConfigParser"] = "configparser.ConfigParser" -else: - _lazy_attrs = dict( - BytesIO="cStringIO.StringIO", - StringIO="StringIO.StringIO", - SafeConfigParser="ConfigParser.SafeConfigParser", - ) - -#============================================================================= -# typing -#============================================================================= -def is_mapping(obj): - # non-exhaustive check, enough to distinguish from lists, etc - return hasattr(obj, "items") - -if (3,0) <= sys.version_info < (3,2): - # callable isn't dead, it's just resting - from collections import Callable - def callable(obj): - return isinstance(obj, Callable) -else: - callable = builtins.callable - -if PY3: - int_types = (int,) - num_types = (int, float) -else: - int_types = (int, long) - num_types = (int, long, float) +# begin accumulating mapping of lazy-loaded attrs, +# 'merged' into module at bottom +_lazy_attrs = dict() #============================================================================= # unicode & bytes types @@ -103,7 +69,6 @@ else: if PY3: unicode = str bytes = builtins.bytes -# string_types = (unicode,) def u(s): assert isinstance(s, str) @@ -116,7 +81,6 @@ if PY3: else: unicode = builtins.unicode bytes = str if PY_MAX_25 else builtins.bytes -# string_types = (unicode, bytes) def u(s): assert isinstance(s, str) @@ -132,10 +96,10 @@ sb_types = (unicode, bytes) # unicode & bytes helpers #============================================================================= # function to join list of unicode strings -ujoin = u('').join +join_unicode = u('').join # function to join list of byte strings -bjoin = b('').join +join_bytes = b('').join if PY3: def uascii_to_str(s): @@ -154,12 +118,13 @@ if PY3: assert isinstance(s, str) return s.encode("ascii") - bjoin_ints = bjoin_elems = bytes + join_byte_values = join_byte_elems = bytes - def belem_ord(elem): + def byte_elem_value(elem): + assert isinstance(elem, int) return elem - def biter_ints(s): + def iter_byte_values(s): assert isinstance(s, bytes) return s @@ -180,43 +145,53 @@ else: assert isinstance(s, str) return s - def bjoin_ints(values): - return bjoin(chr(v) for v in values) + def join_byte_values(values): + return join_bytes(chr(v) for v in values) - bjoin_elems = bjoin + join_byte_elems = join_bytes - belem_ord = ord + byte_elem_value = ord - def biter_ints(s): + def iter_byte_values(s): assert isinstance(s, bytes) return (ord(c) for c in s) -_add_doc(uascii_to_str, "helper to convert ascii unicode -> native str") -_add_doc(bascii_to_str, "helper to convert ascii bytes -> native str") -_add_doc(str_to_uascii, "helper to convert ascii native str -> unicode") -_add_doc(str_to_bascii, "helper to convert ascii native str -> bytes") +add_doc(uascii_to_str, "helper to convert ascii unicode -> native str") +add_doc(bascii_to_str, "helper to convert ascii bytes -> native str") +add_doc(str_to_uascii, "helper to convert ascii native str -> unicode") +add_doc(str_to_bascii, "helper to convert ascii native str -> bytes") -# bjoin_ints -- function to convert list of ordinal integers to byte string. +# join_byte_values -- function to convert list of ordinal integers to byte string. -# bjoin_elems -- function to convert list of byte elements to byte string; +# join_byte_elems -- function to convert list of byte elements to byte string; # i.e. what's returned by ``b('a')[0]``... # this is b('a') under PY2, but 97 under PY3. -# belem_ord -- function to convert byte element to integer -- a noop under PY3 +# byte_elem_value -- function to convert byte element to integer -- a noop under PY3 -_add_doc(biter_ints, "helper to iterate over byte values in byte string") +add_doc(iter_byte_values, "helper to iterate over byte values in byte string") + +#============================================================================= +# numeric +#============================================================================= +if PY3: + int_types = (int,) + num_types = (int, float) +else: + int_types = (int, long) + num_types = (int, long, float) #============================================================================= # iteration helpers # -# irange - range iterator -# trange - immutable range sequence (list under py2, range object under py3) -# lrange - range list +# irange - range iterable / view (xrange under py2, range under py3) +# lrange - range list (range under py2, list(range()) under py3) # +# imap - map to iterator # lmap - map to list #============================================================================= if PY3: - irange = trange = range + irange = range ##def lrange(*a,**k): ## return list(range(*a,**k)) @@ -224,34 +199,51 @@ if PY3: return list(map(*a,**k)) imap = map + def iteritems(d): + return d.items() + def itervalues(d): + return d.values() else: irange = xrange - trange = range ##lrange = range lmap = map from itertools import imap -if PY3: - def iteritems(d): - return d.items() - def itervalues(d): - return d.values() -else: def iteritems(d): return d.iteritems() def itervalues(d): return d.itervalues() if PY_MAX_25: - def next(itr): + _undef = object() + def next(itr, default=_undef): "compat wrapper for next()" - # NOTE: omits support for 'default' arg - return itr.next() + if default is _undef: + return itr.next() + try: + return itr.next() + except StopIteration: + return default else: next = builtins.next #============================================================================= +# typing +#============================================================================= +def is_mapping(obj): + # non-exhaustive check, enough to distinguish from lists, etc + return hasattr(obj, "items") + +if (3,0) <= sys.version_info < (3,2): + # callable isn't dead, it's just resting + from collections import Callable + def callable(obj): + return isinstance(obj, Callable) +else: + callable = builtins.callable + +#============================================================================= # introspection #============================================================================= def exc_err(): @@ -269,8 +261,26 @@ else: # input/output #============================================================================= if PY3: + _lazy_attrs = dict( + BytesIO="io.BytesIO", + UnicodeIO="io.StringIO", + NativeStringIO="io.StringIO", + SafeConfigParser="configparser.SafeConfigParser", + ) + if sys.version_info >= (3,2): + # py32 renamed this, removing old ConfigParser + _lazy_attrs["SafeConfigParser"] = "configparser.ConfigParser" + print_ = getattr(builtins, "print") + else: + _lazy_attrs = dict( + BytesIO="cStringIO.StringIO", + UnicodeIO="StringIO.StringIO", + NativeStringIO="cStringIO.StringIO", + SafeConfigParser="ConfigParser.SafeConfigParser", + ) + def print_(*args, **kwds): """The new-style print function.""" # extract kwd args @@ -318,13 +328,13 @@ else: #============================================================================= from types import ModuleType -def import_object(source): +def _import_object(source): "helper to import object from module; accept format `path.to.object`" modname, modattr = source.rsplit(".",1) mod = __import__(modname, fromlist=[modattr], level=0) return getattr(mod, modattr) -class LazyOverlayModule(ModuleType): +class _LazyOverlayModule(ModuleType): """proxy module which overlays original module, and lazily imports specified attributes. @@ -359,7 +369,7 @@ class LazyOverlayModule(ModuleType): if callable(source): value = source() else: - value = import_object(source) + value = _import_object(source) setattr(self, attr, value) self.__log.debug("loaded lazy attr %r: %r", attr, value) return value @@ -382,7 +392,7 @@ class LazyOverlayModule(ModuleType): return list(attrs) # replace this module with overlay that will lazily import attributes. -LazyOverlayModule.replace_module(__name__, _lazy_attrs) +_LazyOverlayModule.replace_module(__name__, _lazy_attrs) #============================================================================= # eof diff --git a/passlib/utils/des.py b/passlib/utils/des.py index 67c6d93..4172a2e 100644 --- a/passlib/utils/des.py +++ b/passlib/utils/des.py @@ -45,7 +45,7 @@ which has some nice notes on how this all works - # core import struct # pkg -from passlib.utils.compat import bytes, bjoin_ints, belem_ord, irange, trange +from passlib.utils.compat import bytes, join_byte_values, byte_elem_value, irange, irange # local __all__ = [ "expand_des_key", @@ -56,15 +56,15 @@ __all__ = [ #========================================================= #precalculated iteration ranges & constants #========================================================= -R8 = trange(8) -RR8 = trange(7, -1, -1) -RR4 = trange(3, -1, -1) -RR12_1 = trange(11, 1, -1) -RR9_1 = trange(9,-1,-1) +R8 = irange(8) +RR8 = irange(7, -1, -1) +RR4 = irange(3, -1, -1) +RR12_1 = irange(11, 1, -1) +RR9_1 = irange(9,-1,-1) -RR6_S2 = trange(6, -1, -2) -RR14_S2 = trange(14, -1, -2) -R16_S2 = trange(0, 16, 2) +RR6_S2 = irange(6, -1, -2) +RR14_S2 = irange(14, -1, -2) +R16_S2 = irange(0, 16, 2) INT_24_MAX = 0xffffff INT_64_MAX = 0xffffffff @@ -588,7 +588,7 @@ def expand_des_key(key): def iter_bits(source): for c in source: - v = belem_ord(c) + v = byte_elem_value(c) for i in irange(7,-1,-1): yield (v>>i) & 1 @@ -601,7 +601,7 @@ def expand_des_key(key): out = (out<<1) + p p = 1 - return bjoin_ints( + return join_byte_values( ((out>>s) & 0xFF) for s in irange(8*7,-8,-8) ) diff --git a/passlib/utils/handlers.py b/passlib/utils/handlers.py index 911bba2..2d283c5 100644 --- a/passlib/utils/handlers.py +++ b/passlib/utils/handlers.py @@ -19,8 +19,8 @@ from passlib.registry import get_crypt_handler from passlib.utils import classproperty, consteq, getrandstr, getrandbytes,\ BASE64_CHARS, HASH64_CHARS, rng, to_native_str, \ is_crypt_handler, deprecated_function, to_unicode -from passlib.utils.compat import b, bjoin_ints, bytes, irange, u, \ - uascii_to_str, ujoin, unicode, str_to_uascii +from passlib.utils.compat import b, join_byte_values, bytes, irange, u, \ + uascii_to_str, join_unicode, unicode, str_to_uascii # local __all__ = [ # helpers for implementing MCF handlers @@ -57,7 +57,7 @@ LOWER_HEX_CHARS = u("0123456789abcdef") #: special byte string containing all possible byte values # XXX: treated as singleton by some of the code for efficiency. -ALL_BYTE_VALUES = bjoin_ints(irange(256)) +ALL_BYTE_VALUES = join_byte_values(irange(256)) # deprecated aliases - will be removed after passlib 1.8 H64_CHARS = HASH64_CHARS |