summaryrefslogtreecommitdiff
path: root/passlib/tests/genconfig.py
blob: dc87d02a0cc1ab5c228e10b3db028f1d3ad7b376 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
"""passlib config generation script

this script is a work in progress to develop a script which generates
rounds configuration parameters suitable for a particular host & deployment requirements.
right now it just consists of a function which experimentally determines
the optimal range of rounds values for a given hash, based on the desired time it should take.
"""
#=========================================================
#imports
#=========================================================
#core
from math import log as logb
import logging
import time
import sys
#site
#pkg
from passlib.utils.compat import iteritems, print_
from passlib.registry import get_crypt_handler
#local
log = logging.getLogger(__name__)
#=========================================================
#
#=========================================================
class HashTimer(object):
    """helper which determines number of rounds required for hash to take desired amount of time.

    usage::

        >>> timer = HashTimer("sha512_crypt")
        >>> timer.find_rounds(.5)

    .. note::
        This function is not very exact, and generates results
        that are only approximately the same each time (w/in about 5% usually).

        Furthermore, to generate useful values, it should
        be run when the system has an average load
        to get an accurate measurement.
    """
    log = logging.getLogger(__name__ + ".HashTimer")

    def __init__(self, name, samples=1):
        #
        #get handler, extract boundary information
        #
        self.handler = handler = get_crypt_handler(name)
        if 'rounds' not in handler.setting_kwds:
            raise ValueError("scheme does not support rounds: %r" % (handler.name,))
        self.min_rounds = getattr(handler, "min_rounds", 2)
        self.max_rounds = getattr(handler, "max_rounds", (1<<32)-1)
        rc = self.rounds_cost = getattr(handler, "rounds_cost", "linear")

        #
        #set up functions that vary based on rounds cost function
        #
        if rc == "linear":
            def get_rps(rounds, delta):
                return rounds/delta
            def guess_rounds(rps, target):
                return int(rps*target+.5)
            erradj = 2
        elif rc == "log2":
            def get_rps(rounds, delta):
                return (2**rounds)/delta
            def guess_rounds(rps, target):
                return int(logb(rps*target,2)+.5)
            erradj = 1.1
        else:
            raise NotImplementedError("unknown rounds cost function: %r" % (rc,))
        self.get_rps = get_rps
        self.guess_rounds = guess_rounds
        self.erradj = erradj

        #
        #init cache
        #
        self.samples = samples
        self.cache = {}
        self.srange = trange(samples)

    def time_encrypt(self, rounds):
        "check how long encryption for a given number of rounds will take"
        cache = self.cache
        if rounds in cache:
            return cache[rounds]
        encrypt = self.handler.encrypt
        srange = self.srange
        cur = time.time
        start = cur()
        for x in srange:
            encrypt("too many secrets", rounds=rounds)
        stop = cur()
        delta = (stop-start)/self.samples
        cache[rounds] = delta
        return delta

    def find_rounds(self, target, over=False, under=False):
        """find optimal rounds range for hash

        :arg target: time hashing a password should take
        :param over: if True, returns minimum rounds taking *at least* target seconds.
        :param under: if True, returns maximum rounds taking *at most* target seconds.

        if neither over / under is set, returns rounds taking
        closest to target seconds.

        :returns:
            returns number of rounds closest
            to taking about ``target`` seconds to hash a password.
        """
        if target <= 0:
            raise ValueError("target must be > 0")

        log = self.log
        name = self.handler.name
        get_rps = self.get_rps
        time_encrypt = self.time_encrypt
        log.info("%s: finding rounds for target time: %f", name, target)

        #
        #check if useful lower & upper bounds already exist in cache
        #
        lower = upper = None
        for rounds, delta in iteritems(self.cache.iteritems):
            if delta < target:
                if lower is None or rounds > lower:
                    lower = rounds
            else:
                if upper is None or rounds < upper:
                    upper = rounds

        #
        #if bounds not found in cache, run open-ended search for starting bounds
        #
        if lower is None:
            lower = max(1,self.min_rounds)
        if upper is None:
            guess_rounds = self.guess_rounds
            max_rounds = self.max_rounds
            target_above = target*self.erradj #NOTE: we aim a little high as hack to deal w/ measuring error
            rounds = lower
            while True:
                delta = time_encrypt(rounds)
                rps = get_rps(rounds, delta)
                log.debug("%s: ranging target: checked %r -> %fs (%f r/s)", name, rounds, delta, rps)
                if delta < target:
                    lower = rounds
                    if rounds == max_rounds:
                        log.warning("%s: target time out of range: hash would require > max_rounds (%d) in order to take %fs", name, max_rounds, target)
                        return rounds
                    rounds = min(max(guess_rounds(rps, target_above), rounds+1), max_rounds)
                else:
                    upper = rounds
                    break

        #
        #perform binary search till we find match
        #
        while lower+1<upper:
            #NOTE: weighting things in favor of upper, since per-call overhead causes curve to not be quite linear.
            next = (upper*5+lower*3)//8
            delta = time_encrypt(next)
            rps = get_rps(next, delta)
            log.debug("%s: finding target: range %r .. %r: checked %r -> %fs (%f r/s)", name, lower, upper, next, delta, rps)
            if delta < target:
                lower = next
            else:
                upper = next

        #
        #now 'lower' is largest value which takes less than target seconds,
        #and 'upper' is smallest value which takes greater than target seconds.
        #so we pick based on over/under flags, or fallback to whichever one is closest
        #
        if over:
            return upper
        elif under:
            return lower
        else:
            if target-cache[lower] < cache[upper]-target:
                return lower
            else:
                return upper

    def find_rounds_range(self, target_high, target_low=None, over=False, under=False):
        "find min/max rounds which will cause scheme to take specified range of times"
        if target_low is None:
            target_low = target_high * .75
        elif target_low > target_high:
            target_high = target_low
        rounds_high = self.find_rounds(target_high, under=not over, over=over)
        rounds_low = self.find_rounds(target_low, over=not under, under=under)
        if rounds_low > rounds_high:
            #NOTE: this happens sometimes w/ rounds_cost=log2...
            #if nothing hits w/in range, rounds_low will be 1+ rounds_high
            #we just return correctly ordered range
            rounds_low, rounds_high = rounds_high, rounds_low
        return rounds_low, rounds_high

    def estimate_rps(self):
        "return estimated rounds per second based on cached results"
        cache = self.cache
        if not cache:
            raise RuntimeError("should not be called until cache populated by find_rounds()")
        get_rps = self.get_rps
        rps = sum(r*get_rps(r,d) for r,d in iteritems(cache))/sum(cache)
        if rps > 1000: #for almost all cases, we'd return integer
            rps = int(rps)
        return rps

#=========================================================
#main
#=========================================================
def main(*args):
    from bps.logs import setup_std_logging
    setup_std_logging(level="debug", dev=True)

    timer = HashTimer("sha256_crypt")
    print_(timer.find_rounds_range(.5))
    print_(timer.estimate_rps())

    #TODO: give script ability to generate timings for a range of schemes, and minimum / maximum times.

if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
#=========================================================
#eof
#=========================================================