summaryrefslogtreecommitdiff
path: root/passlib/utils/pbkdf2.py
blob: 916b295786a3235b4340f3272bfb69019097f85f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
"""passlib.pbkdf2 - PBKDF2 support

this module is getting increasingly poorly named.
maybe rename to "kdf" since it's getting more key derivation functions added.
"""
#=================================================================================
#imports
#=================================================================================
#core
import hashlib
import logging; log = logging.getLogger(__name__)
import re
from struct import pack
from warnings import warn
#site
try:
    from M2Crypto import EVP as _EVP
except ImportError:
    _EVP = None
#pkg
from passlib.exc import PasslibRuntimeWarning, ExpectedTypeError
from passlib.utils import join_bytes, to_native_str, bytes_to_int, int_to_bytes, join_byte_values
from passlib.utils.compat import b, bytes, BytesIO, irange, callable, int_types
#local
__all__ = [
    "get_prf",
    "pbkdf1",
    "pbkdf2",
]

#=============================================================================
# hash helpers
#=============================================================================

# known hash names
# _nhn_formats maps each supported naming convention to the column of
# _nhn_hash_names which holds the name in that convention.
_nhn_formats = dict(hashlib=0, iana=1)
_nhn_hash_names = [
    # (hashlib/ssl name, iana name or standin, ... other known aliases)

    # hashes with official IANA-assigned names
    # (as of 2012-03 - http://www.iana.org/assignments/hash-function-text-names)
    ("md2", "md2"),
    ("md5", "md5"),
    ("sha1", "sha-1"),
    ("sha224", "sha-224", "sha2-224"),
    ("sha256", "sha-256", "sha2-256"),
    ("sha384", "sha-384", "sha2-384"),
    ("sha512", "sha-512", "sha2-512"),

    # hashlib/ssl-supported hashes without official IANA names,
    # hopefully compatible stand-ins have been chosen.
    ("md4", "md4"),
    ("sha", "sha-0", "sha0"),
    ("ripemd", "ripemd"),
    ("ripemd160", "ripemd-160"),
]

# cache for norm_hash_name()
# maps original (un-normalized) input name -> matching row of names
_nhn_cache = {}

def norm_hash_name(name, format="hashlib"):
    """Normalize hash function name

    :arg name:
        Original hash function name.

        This name can be a Python :mod:`~hashlib` digest name,
        a SCRAM mechanism name, IANA assigned hash name, etc.
        Case is ignored; underscores, spaces, and slashes are converted
        to hyphens. Names which aren't native strings are decoded as UTF-8.

    :param format:
        Naming convention to normalize to.
        Possible values are:

        * ``"hashlib"`` (the default) - normalizes name to be compatible
          with Python's :mod:`!hashlib`.

        * ``"iana"`` - normalizes name to IANA-assigned hash function name
          (or the stand-in chosen by this module, for hashes which IANA
          hasn't assigned a name for).

        For names which can't be matched against the table of known hashes,
        a :exc:`PasslibRuntimeWarning` is issued, and a heuristic is used
        to give a "best guess" (for either format).

    :raises ValueError:
        if ``format`` is not one of the values listed above.

    :returns:
        Hash name, returned as native :class:`!str`.
    """
    # resolve which column of the name table the caller wants
    try:
        idx = _nhn_formats[format]
    except KeyError:
        raise ValueError("unknown format: %r" % (format,))

    # check cache (keyed by the raw input, so repeat lookups skip all parsing)
    try:
        return _nhn_cache[name][idx]
    except KeyError:
        pass
    orig = name

    # normalize input
    if not isinstance(name, str):
        name = to_native_str(name, 'utf-8', 'hash name')
    name = re.sub("[_ /]", "-", name.strip().lower())
    if name.startswith("scram-"):
        # strip SCRAM mechanism decoration, e.g. "scram-sha-1-plus" -> "sha-1"
        name = name[6:]
        if name.endswith("-plus"):
            name = name[:-5]

    # look through standard names and known aliases
    def check_table(name):
        for row in _nhn_hash_names:
            if name in row:
                _nhn_cache[orig] = row
                return row[idx]
    result = check_table(name)
    if result:
        return result

    # try to clean name up, and recheck table
    # NOTE: pattern must be a raw string, "\d" is not a valid string escape.
    m = re.match(r"^(?P<name>[a-z]+)-?(?P<rev>\d)?-?(?P<size>\d{3,4})?$", name)
    if m:
        # rebuild as "<name><rev>-<size>", e.g. "sha-2-256" -> "sha2-256"
        name, rev, size = m.group("name", "rev", "size")
        if rev:
            name += rev
        if size:
            name += "-" + size
        result = check_table(name)
        if result:
            return result

    # else we've done what we can
    warn("norm_hash_name(): unknown hash: %r" % (orig,), PasslibRuntimeWarning)
    # best guess: hashlib-style name has no hyphens, iana-style keeps them.
    # result is cached, so the warning is only issued once per input.
    name2 = name.replace("-", "")
    row = _nhn_cache[orig] = (name2, name)
    return row[idx]

# TODO: get_hash() func which wraps norm_hash_name(), hashlib.<attr>, and hashlib.new

#=================================================================================
#general prf lookup
#=================================================================================
# single NUL byte; used by the pure-python hmac code in _get_hmac_prf()
# to pad keys out to the digest's block size.
_BNULL = b('\x00')
# reference digest (20 bytes) which M2Crypto's EVP.hmac(b('x'), b('y')) is
# expected to return; _get_hmac_prf() compares against it as a self-test
# before trusting M2Crypto (presumably this is the HMAC-SHA1 value).
_XY_DIGEST = b(',\x1cb\xe0H\xa5\x82M\xfb>\xd6\x98\xef\x8e\xf9oQ\x85\xa3i')

# translation tables built from every byte value (0..255) XORed with 0x5C and
# 0x36 respectively; the pure-python hmac code runs the padded key through
# _trans_36 for the inner hash, and through _trans_5C for the outer hash.
_trans_5C = join_byte_values((x ^ 0x5C) for x in irange(256))
_trans_36 = join_byte_values((x ^ 0x36) for x in irange(256))

def _get_hmac_prf(digest):
    """helper to return HMAC prf for specific digest

    :arg digest:
        name of the hash function to use, as understood by :mod:`hashlib`
        (and by M2Crypto, when it's installed).

    :raises ValueError:
        if no implementation of the digest can be located.

    :returns:
        tuple of :samp:`({prf}, {digest_size})`, where ``prf(key, msg)``
        returns the raw HMAC digest, and ``digest_size`` is the size
        of that digest in bytes.
    """
    def tag_wrapper(prf):
        # give the generated function a descriptive name & docstring
        prf.__name__ = "hmac_" + digest
        prf.__doc__ = ("hmac_%s(key, msg) -> digest;"
                       " generated by passlib.utils.pbkdf2.get_prf()" %
                       digest)

    if _EVP and digest == "sha1":
        # use m2crypto function directly for sha1, since that's its default digest
        try:
            result = _EVP.hmac(b('x'),b('y'))
        except ValueError: #pragma: no cover
            pass
        else:
            if result == _XY_DIGEST:
                return _EVP.hmac, 20
        # don't expect to ever get here, but will fall back to pure-python if we do.
        warn("M2Crypto.EVP.HMAC() returned unexpected result " # pragma: no cover -- sanity check
             "during Passlib self-test!", PasslibRuntimeWarning)
    elif _EVP:
        # use m2crypto if it's present and supports requested digest
        try:
            result = _EVP.hmac(b('x'), b('y'), digest)
        except ValueError:
            pass
        else:
            #it does. so use M2Crypto's hmac & digest code
            hmac_const = _EVP.hmac
            def prf(key, msg):
                return hmac_const(key, msg, digest)
            digest_size = len(result)
            tag_wrapper(prf)
            return prf, digest_size

    #fall back to hashlib-based implementation
    digest_const, block_size, digest_size = _resolve_hashlib_digest(digest)
    assert block_size >= 16, "unacceptably low block size"
    def prf(key, msg):
        # simplified version of stdlib's hmac module
        if len(key) > block_size:
            # oversized keys are replaced by their digest
            key = digest_const(key).digest()
        # pad key with NULs up to the block size
        key += _BNULL * (block_size - len(key))
        tmp = digest_const(key.translate(_trans_36) + msg).digest()
        return digest_const(key.translate(_trans_5C) + tmp).digest()
    tag_wrapper(prf)
    return prf, digest_size

def _resolve_hashlib_digest(digest):
    """helper for _get_hmac_prf(): locate hashlib constructor for digest

    :arg digest: name of hash function.

    :raises ValueError: if hashlib doesn't provide the named digest.

    :returns:
        tuple of :samp:`({const}, {block_size}, {digest_size})`, where
        ``const(data)`` returns a new hash object primed with ``data``.
    """
    # prefer the named constructor (e.g. hashlib.sha256) when there is one
    const = getattr(hashlib, digest, None)
    if const is not None:
        try:
            probe = const()
            return const, probe.block_size, probe.digest_size
        except (TypeError, AttributeError):
            # hashlib has an attribute by this name, but it isn't a hash
            # constructor (e.g. "new") -- don't mistake it for one.
            pass

    # otherwise ask hashlib.new(), which also covers digests that have no
    # named constructor (this mirrors the lookup pbkdf1() performs).
    try:
        probe = hashlib.new(digest)
    except (TypeError, ValueError):
        raise ValueError("unknown hash algorithm: %r" % (digest,))
    new_hash = hashlib.new
    def const(data):
        return new_hash(digest, data)
    return const, probe.block_size, probe.digest_size

#cache mapping prf name/func -> (func, digest_size)
#populated by get_prf(), which consults it before doing any other work.
_prf_cache = {}

def _clear_prf_cache():
    """helper for unit tests; empties the cache of prfs resolved by get_prf()"""
    _prf_cache.clear()

def get_prf(name):
    """resolve a pseudo-random family (prf) from its name.

    :arg name:
        either the name of a recognized prf, or a prf callable.

        the only names currently recognized have the format
        :samp:`hmac-{digest}` (or :samp:`hmac_{digest}`),
        where :samp:`{digest}` names a hash function,
        such as ``md5``, ``sha256``, etc.

        a callable must have the signature
        ``prf(secret, message) -> digest``;
        it is handed back unchanged as the first element of the result.

    :raises ValueError: if the name is not known
    :raises TypeError: if the name is not a callable or string

    :returns:
        a tuple of :samp:`({func}, {digest_size})`.

        * :samp:`{func}` implements the requested prf,
          and has the signature ``func(secret, message) -> digest``.

        * :samp:`{digest_size}` is an integer giving
          the number of bytes which :samp:`{func}` returns.

    usage example::

        >>> from passlib.utils.pbkdf2 import get_prf
        >>> hmac_sha256, dsize = get_prf("hmac-sha256")
        >>> hmac_sha256
        <function hmac_sha256 at 0x1e37c80>
        >>> dsize
        32
        >>> digest = hmac_sha256('password', 'message')

    results are cached, so each name or callable is only resolved once.
    the fastest available implementation is preferred:
    if M2Crypto is present, and supports the specified prf,
    :func:`M2Crypto.EVP.hmac` will be used behind the scenes.
    """
    # repeat lookups are answered straight from the cache
    try:
        return _prf_cache[name]
    except KeyError:
        pass

    if isinstance(name, str):
        # only "hmac-<digest>" / "hmac_<digest>" names are understood
        if not name.startswith(("hmac-", "hmac_")):
            raise ValueError("unknown prf algorithm: %r" % (name,))
        record = _get_hmac_prf(name[5:])
    elif callable(name):
        # caller supplied their own prf; invoke it once on dummy input
        # in order to learn the size of its output.
        sample = name(b('x'),b('y'))
        record = (name, len(sample))
    else:
        raise ExpectedTypeError(name, "str or callable", "prf name")

    _prf_cache[name] = record
    return record

#=================================================================================
#pbkdf1 support
#=================================================================================
def pbkdf1(secret, salt, rounds, keylen=None, hash="sha1"):
    """pkcs#5 password-based key derivation v1.5

    :arg secret: passphrase the key is derived from (bytes)
    :arg salt: salt to mix into the key (bytes)
    :param rounds: number of times the hash is applied (integer, at least 1)
    :arg keylen:
        number of bytes to return.
        ``None`` (the default) means the digest's native size.
    :param hash:
        name of the hash function to use;
        anything :mod:`hashlib` recognizes is accepted.

    :returns:
        raw bytes of generated key

    .. note::

        This algorithm has been deprecated, new code should use PBKDF2.
        Among other limitations, ``keylen`` cannot be larger
        than the digest size of the specified hash.

    """
    # secret & salt must both be bytes
    for value, label in ((secret, "secret"), (salt, "salt")):
        if not isinstance(value, bytes):
            raise ExpectedTypeError(value, "bytes", label)

    # rounds must be a positive integer
    if not isinstance(rounds, int_types):
        raise ExpectedTypeError(rounds, "int", "rounds")
    elif rounds < 1:
        raise ValueError("rounds must be at least 1")

    # locate hash constructor
    try:
        new_hash = getattr(hashlib, hash)
    except AttributeError:
        # not a named hashlib constructor, so go through hashlib.new(),
        # which may know about additional (ssl-provided) hashes.
        # NOTE: for unknown hashes new() throws ValueError, which is exactly
        #       what should be reported; so there's no explicit check here,
        #       the error simply surfaces on first use, below.
        # TODO: use builtin md4 class if hashlib doesn't have it.
        new_hash = lambda msg: hashlib.new(hash, msg)

    # first round hashes secret+salt; this also reveals the digest size
    digest = new_hash(secret + salt).digest()
    native_size = len(digest)

    # keylen may not exceed what a single digest provides
    if keylen is None:
        keylen = native_size
    else:
        if not isinstance(keylen, int_types):
            raise ExpectedTypeError(keylen, "int or None", "keylen")
        if keylen < 0:
            raise ValueError("keylen must be at least 0")
        if keylen > native_size:
            raise ValueError("keylength too large for digest: %r > %r" %
                             (keylen, native_size))

    # remaining rounds each re-hash the previous digest
    remaining = rounds - 1
    for _ in irange(remaining):
        digest = new_hash(digest).digest()
    return digest[:keylen]

#=================================================================================
#pbkdf2
#=================================================================================
# largest value of the per-block counter used by pbkdf2(), which packs the
# block index into a 32-bit unsigned big-endian integer.
MAX_BLOCKS = 0xffffffff #2**32-1
# largest key (in bytes) obtainable from MAX_BLOCKS blocks of hmac-sha1,
# whose digest is 20 bytes; checked before handing keylen to M2Crypto.
MAX_HMAC_SHA1_KEYLEN = MAX_BLOCKS*20
# NOTE: the pbkdf2 spec does not specify a maximum number of rounds.
#       however, many of the hashes in passlib are currently clamped
#       at the 32-bit limit, just for sanity. once realistic pbkdf2 rounds
#       start approaching 24 bits, this limit will be raised.

def pbkdf2(secret, salt, rounds, keylen=None, prf="hmac-sha1"):
    """pkcs#5 password-based key derivation v2.0

    :arg secret: passphrase to use to generate key (must be bytes)
    :arg salt: salt string to use when generating key (must be bytes)
    :param rounds: number of rounds to use to generate key (integer, at least 1)
    :arg keylen:
        number of bytes to generate.
        if set to ``None``, will use digest size of selected prf.
    :param prf:
        pseudo-random family to use for key strengthening.
        this can be any string or callable accepted by :func:`get_prf`.
        this defaults to ``"hmac-sha1"`` (the only prf explicitly listed in
        the PBKDF2 specification)

    :raises ValueError:
        if ``rounds`` is less than 1, if ``keylen`` is negative,
        or if ``keylen`` would require more than :data:`MAX_BLOCKS`
        blocks of the prf's digest.

    :returns:
        raw bytes of generated key
    """
    # validate secret & salt
    if not isinstance(secret, bytes):
        raise ExpectedTypeError(secret, "bytes", "secret")
    if not isinstance(salt, bytes):
        raise ExpectedTypeError(salt, "bytes", "salt")

    # validate rounds
    if not isinstance(rounds, int_types):
        raise ExpectedTypeError(rounds, "int", "rounds")
    if rounds < 1:
        raise ValueError("rounds must be at least 1")

    # validate keylen
    if keylen is not None:
        if not isinstance(keylen, int_types):
            raise ExpectedTypeError(keylen, "int or None", "keylen")
        elif keylen < 0:
            raise ValueError("keylen must be at least 0")

    # special case for m2crypto + hmac-sha1
    if prf == "hmac-sha1" and _EVP:
        if keylen is None:
            keylen = 20
        # NOTE: doing check here, because M2crypto won't take 'long' instances
        # (which this is when running under 32bit)
        if keylen > MAX_HMAC_SHA1_KEYLEN:
            raise ValueError("key length too long for digest")

        # NOTE: as of 2012-4-4, m2crypto has buffer overflow issue
        # which may cause segfaults if keylen > 32 (EVP_MAX_KEY_LENGTH).
        # therefore we're avoiding m2crypto for large keys until that's fixed.
        # see https://bugzilla.osafoundation.org/show_bug.cgi?id=13052
        if keylen < 32:
            return _EVP.pbkdf2(secret, salt, rounds, keylen)

    # resolve prf
    prf_func, digest_size = get_prf(prf)
    if keylen is None:
        keylen = digest_size

    # figure out how many blocks we'll need (keylen/digest_size, rounded up)
    block_count = (keylen+digest_size-1)//digest_size
    # NOTE: exactly MAX_BLOCKS blocks is permitted: the final block index
    # (2**32-1) still fits the 32-bit counter packed below, and this keeps
    # the limit in agreement with the MAX_HMAC_SHA1_KEYLEN check above.
    if block_count > MAX_BLOCKS:
        raise ValueError("key length too long for digest")

    #build up result from blocks
    def gen():
        for i in irange(block_count):
            # first round is keyed off salt + 1-based big-endian block index
            digest = prf_func(secret, salt + pack(">L", i+1))
            accum = bytes_to_int(digest)
            # each later round feeds the previous digest back into the prf,
            # xoring every round's output into the block
            for _ in irange(rounds-1):
                digest = prf_func(secret, digest)
                accum ^= bytes_to_int(digest)
            yield int_to_bytes(accum, digest_size)
    # last block may overshoot, so trim to requested size
    return join_bytes(gen())[:keylen]

#=================================================================================
#eof
#=================================================================================