summaryrefslogtreecommitdiff
path: root/kazoo/tests/util.py
blob: 6351468310ca752aa6762ade8c3a962ba72c598b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
##############################################################################
#
# Copyright Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL).  A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE.
#
##############################################################################

import logging
import os
import time

CI = os.environ.get("CI", False)
# ZOOKEEPER_VERSION is only consulted when running under CI.  Outside CI this
# keeps the (falsy) value of CI itself; under CI it is None when unset.
CI_ZK_VERSION = os.environ.get("ZOOKEEPER_VERSION", None) if CI else CI
if CI_ZK_VERSION:
    # e.g. "3.5.0-alpha" -> "3.5.0" -> (3, 5, 0): drop any pre-release
    # suffix, then turn the dotted release into a tuple of ints so it can
    # be compared against version tuples.
    CI_ZK_VERSION = tuple(
        int(part) for part in CI_ZK_VERSION.partition("-")[0].split(".")
    )


class Handler(logging.Handler):
    """Logging handler that captures records sent to a set of named loggers.

    Every record handed to the handler is appended to ``self.records``.
    ``install`` attaches the handler to each logger named at construction
    time and sets that logger's level to ``self.level``. ``uninstall``
    detaches it again and restores the levels that ``install`` saved.
    ``self.level`` is also this handler's own threshold, because that is the
    standard ``logging.Handler`` attribute.
    """

    def __init__(self, *names, **kw):
        logging.Handler.__init__(self)
        self.names = names
        self.records = []
        # ``kw`` is forwarded to setLoggerLevel, so only ``level`` is accepted.
        self.setLoggerLevel(**kw)

    def setLoggerLevel(self, level=1):
        """Choose the level to apply on ``install`` (default 1).

        This also discards any logger levels remembered by an earlier
        ``install``.
        """
        self.oldlevels = {}
        self.level = level

    def emit(self, record):
        """Store ``record`` for later inspection."""
        self.records.append(record)

    def clear(self):
        """Forget all captured records.

        The list is emptied in place, so existing references to it see the
        change.
        """
        self.records[:] = []

    def install(self):
        """Attach to every named logger, saving and overriding its level."""
        for logger_name in self.names:
            target = logging.getLogger(logger_name)
            self.oldlevels[logger_name] = target.level
            target.setLevel(self.level)
            target.addHandler(self)

    def uninstall(self):
        """Detach from every named logger and restore its saved level.

        Raises KeyError for a logger whose level was not saved by
        ``install``.
        """
        for logger_name in self.names:
            target = logging.getLogger(logger_name)
            target.setLevel(self.oldlevels[logger_name])
            target.removeHandler(self)

    def __str__(self):
        # One entry per record: "<logger> <LEVEL>" on the first line, then
        # the message with blank lines dropped. Only the message's first line
        # gets the two-space indent.
        entries = []
        for record in self.records:
            kept = [
                text for text in record.getMessage().split("\n") if text.strip()
            ]
            entries.append(
                "%s %s\n  %s" % (record.name, record.levelname, "\n".join(kept))
            )
        return "\n".join(entries)


class InstalledHandler(Handler):
    """A ``Handler`` that calls ``install()`` as part of construction.

    It accepts the same arguments as ``Handler``. Call ``uninstall()`` when
    finished to detach it and restore the loggers' previous levels.
    """

    def __init__(self, *names, **kw):
        super(InstalledHandler, self).__init__(*names, **kw)
        self.install()


class Wait(object):
    """Poll a condition until it holds or a deadline passes.

    Calling an instance as ``instance(func)`` calls the zero-argument ``func``
    once immediately. If that fails, it keeps calling ``func``, sleeping
    ``wait`` between attempts, until ``func`` returns a true value. If the
    clock moves more than ``timeout`` past the moment polling started,
    ``TimeOutWaitingFor`` is raised.

    Calling ``instance(...)`` without ``func`` returns a decorator that runs
    the wait on the decorated function. Because the wait returns None, the
    decorated name ends up bound to ``None``.
    """

    class TimeOutWaitingFor(Exception):
        "A test condition timed out"

    # Class-wide defaults, used unless the constructor or the call overrides
    # them. With the default clock (time.time / time.sleep) they are seconds.
    timeout = 9
    wait = 0.01

    def __init__(
        self,
        timeout=None,
        wait=None,
        exception=None,
        getnow=(lambda: time.time),
        getsleep=(lambda: time.sleep),
    ):
        """Configure per-instance defaults.

        :param timeout: overrides the class-level ``timeout`` when not None.
        :param wait: overrides the class-level poll interval when not None.
        :param exception: exception class raised on timeout instead of
            ``TimeOutWaitingFor``.
        :param getnow: zero-argument factory returning the clock function.
            It is invoked lazily inside ``__call__``, presumably so tests can
            substitute a fake clock.
        :param getsleep: zero-argument factory returning the sleep function.
            It is also invoked lazily inside ``__call__``.
        """
        if timeout is not None:
            self.timeout = timeout

        if wait is not None:
            self.wait = wait

        if exception is not None:
            self.TimeOutWaitingFor = exception

        self.getnow = getnow
        self.getsleep = getsleep

    def __call__(self, func=None, timeout=None, wait=None, message=None):
        """Wait for ``func()`` to return a true value.

        :param func: zero-argument predicate. When None, a decorator applying
            this wait (with the other arguments) is returned instead.
        :param timeout: per-call override of the instance timeout.
        :param wait: per-call override of the poll interval (coerced to
            float).
        :param message: text for the timeout exception. It defaults to the
            predicate's docstring, then its ``__name__``, then its ``repr``.
        :raises: ``self.TimeOutWaitingFor`` once the deadline has passed
            without success.
        """
        if func is None:
            return lambda func: self(func, timeout, wait, message)

        # Fast path: no clock or sleep needed if the condition already holds.
        if func():
            return

        now = self.getnow()
        sleep = self.getsleep()
        if timeout is None:
            timeout = self.timeout
        if wait is None:
            wait = self.wait
        wait = float(wait)

        deadline = now() + timeout
        while True:
            sleep(wait)
            if func():
                return
            if now() > deadline:
                # Use getattr: a callable object whose class has no docstring
                # has neither a truthy __doc__ nor a __name__. Reading
                # __name__ directly would raise AttributeError and hide the
                # timeout.
                raise self.TimeOutWaitingFor(
                    message
                    or getattr(func, "__doc__", None)
                    or getattr(func, "__name__", None)
                    or repr(func)
                )


# Shared module-level instance that uses the class defaults (timeout 9,
# poll interval 0.01); use it as ``wait(condition)``.
wait = Wait()