summaryrefslogtreecommitdiff
path: root/kazoo/tests/util.py
blob: dbe3b48a0db21dcf0dc204b4aa5c946119a87633 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
##############################################################################
#
# Copyright Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL).  A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE.
#
##############################################################################

import logging
import os
import time

# ZOOKEEPER_VERSION is consulted only when CI is set (and non-empty); it is
# parsed into a tuple of ints, e.g. '3.5.4-beta' -> (3, 5, 4).
CI = os.getenv('CI', False)
CI_ZK_VERSION = os.getenv('ZOOKEEPER_VERSION') if CI else CI
if CI_ZK_VERSION:
    # Anything after the first '-' (a pre-release tag such as "-alpha")
    # is dropped before the dotted numbers are converted.
    CI_ZK_VERSION = tuple(
        int(part) for part in CI_ZK_VERSION.partition('-')[0].split('.'))


class Handler(logging.Handler):
    """Logging handler that keeps every record it receives in memory.

    The handler is attached to the loggers named in ``names`` by
    :meth:`install` and detached again by :meth:`uninstall`.  Captured
    records accumulate in ``self.records``; ``str(handler)`` renders each one
    as a ``"<logger name> <LEVEL>"`` header followed by the non-blank lines of
    its message.
    """

    def __init__(self, *names, **kw):
        """Create a handler for the loggers called ``names``.

        Keyword arguments are forwarded to :meth:`setLoggerLevel`.
        """
        logging.Handler.__init__(self)
        self.names = names
        # Captured logging.LogRecord objects, in the order they were emitted.
        self.records = []
        self.setLoggerLevel(**kw)

    def setLoggerLevel(self, level=1):
        """Set the level applied to the target loggers by :meth:`install`.

        This is also the handler's own threshold (``logging.Handler.level``).
        Any previously saved logger levels are forgotten.
        """
        self.level = level
        # Logger name -> the level that logger had before install() changed it.
        self.oldlevels = {}

    def emit(self, record):
        """Store ``record`` instead of writing it anywhere."""
        self.records.append(record)

    def clear(self):
        """Drop all captured records, emptying the existing list in place."""
        del self.records[:]

    def install(self):
        """Attach to every target logger and set it to ``self.level``."""
        for name in self.names:
            logger = logging.getLogger(name)
            # Save the logger's level only when we are not already attached
            # (or have nothing saved).  On a repeated install() the logger's
            # current level is the one *we* set, and saving it would make
            # uninstall() "restore" our level instead of the original one.
            if self not in logger.handlers or name not in self.oldlevels:
                self.oldlevels[name] = logger.level
            logger.setLevel(self.level)
            logger.addHandler(self)

    def uninstall(self):
        """Detach from every target logger and restore its saved level.

        Raises KeyError for a logger whose level was never saved (for example
        when install() was not called).
        """
        for name in self.names:
            logger = logging.getLogger(name)
            logger.setLevel(self.oldlevels[name])
            logger.removeHandler(self)

    def __str__(self):
        """Render the captured records, dropping blank message lines."""
        rendered = []
        for record in self.records:
            lines = [line for line in record.getMessage().split('\n')
                     if line.strip()]
            rendered.append("%s %s\n  %s" % (record.name, record.levelname,
                                            '\n'.join(lines)))
        return '\n'.join(rendered)


class InstalledHandler(Handler):
    """A :class:`Handler` that attaches itself to its loggers on creation."""

    def __init__(self, *names, **kw):
        super(InstalledHandler, self).__init__(*names, **kw)
        # No separate install() call is needed by users of this class.
        self.install()


class Wait(object):
    """Poll a condition until it becomes true or a timeout expires.

    Instances are callable: ``waiter(func)`` returns as soon as ``func()`` is
    truthy and raises ``self.TimeOutWaitingFor`` once the deadline passes.
    Called without a function, it returns a decorator, so it can be used as
    ``@waiter`` or ``@waiter(timeout=..., message=...)``.
    """

    class TimeOutWaitingFor(Exception):
        "A test condition timed out"

    # Class-wide defaults: seconds before giving up, seconds between polls.
    timeout = 9
    wait = .01

    def __init__(self, timeout=None, wait=None, exception=None,
                 getnow=(lambda: time.time), getsleep=(lambda: time.sleep)):
        """Configure the waiter.

        :param timeout: seconds to keep polling (default: class ``timeout``).
        :param wait: seconds to sleep between polls (default: class ``wait``).
        :param exception: exception class raised on timeout instead of
            :class:`TimeOutWaitingFor`, for this instance only.
        :param getnow: zero-argument factory returning the clock function.
        :param getsleep: zero-argument factory returning the sleep function.
            Both factories are called each time polling starts, which lets
            tests substitute fake clocks.
        """
        if timeout is not None:
            self.timeout = timeout

        if wait is not None:
            self.wait = wait

        if exception is not None:
            self.TimeOutWaitingFor = exception

        self.getnow = getnow
        self.getsleep = getsleep

    def __call__(self, func=None, timeout=None, wait=None, message=None):
        """Wait until ``func()`` returns a truthy value.

        :param func: zero-argument condition; if omitted, a decorator that
            applies the remaining arguments is returned instead.
        :param timeout: per-call override of ``self.timeout``.
        :param wait: per-call override of ``self.wait``.
        :param message: text for the timeout exception; defaults to the
            function's docstring, then its ``__name__``, then its ``repr``.
        :raises: ``self.TimeOutWaitingFor`` if the deadline passes.
        """
        if func is None:
            return lambda func: self(func, timeout, wait, message)

        # Fast path: no clock or sleep needed if already satisfied.
        if func():
            return

        now = self.getnow()
        sleep = self.getsleep()
        if timeout is None:
            timeout = self.timeout
        if wait is None:
            wait = self.wait
        wait = float(wait)

        deadline = now() + timeout
        while True:
            sleep(wait)
            if func():
                return
            if now() > deadline:
                # Callable objects may have no __name__ and a None __doc__;
                # fall back to repr() rather than raising AttributeError,
                # which would mask the timeout.
                raise self.TimeOutWaitingFor(
                    message or
                    getattr(func, '__doc__', None) or
                    getattr(func, '__name__', None) or
                    repr(func)
                )


# Shared default waiter: ``wait(condition)`` or ``@wait``.
wait = Wait()