summaryrefslogtreecommitdiff
path: root/t/unit/conftest.py
blob: 638fa9aa19f8c9a2fa1a7146116ae235f4ffc678 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import atexit
import os
import pytest
import sys

from kombu.exceptions import VersionMismatch


@pytest.fixture(scope='session')
def multiprocessing_workaround(request):
    """Cancel multiprocessing's atexit hook once the test session ends.

    Workaround for a multiprocessing bug where logging is attempted
    after module globals have already been collected at shutdown.

    The earlier implementation filtered ``atexit._exithandlers``, a
    private list that only exists on Python 2.  On Python 3 that lookup
    raised ``AttributeError``, which was swallowed, so the workaround
    silently did nothing.  ``atexit.unregister`` is the supported
    Python 3 way to remove a registered exit handler.

    :param request: pytest fixture request (unused; kept so the fixture
        signature is unchanged).
    """
    yield
    try:
        import multiprocessing.util
        exit_function = multiprocessing.util._exit_function
    except (AttributeError, ImportError):
        # The private hook is unavailable: nothing to cancel.
        return
    # ``atexit.unregister`` silently does nothing if the function was
    # never registered, so this stays best-effort.
    atexit.unregister(exit_function)


@pytest.fixture(autouse=True)
def zzz_reset_memory_transport_state():
    """Clear the in-memory transport's shared state after every test.

    This is a yield-style teardown.  Without the ``pytest.fixture``
    decorator the generator was never invoked by pytest, so the state
    held on ``memory.Transport.state`` was never cleared between tests.
    The ``zzz_`` name prefix is kept unchanged (presumably chosen to
    influence ordering relative to other autouse fixtures — unverified).
    """
    yield
    # Imported lazily so merely loading conftest does not import the
    # transport module.
    from kombu.transport import memory
    memory.Transport.state.clear()


@pytest.fixture(autouse=True)
def test_cases_has_patching(request, patching):
    """Attach the ``patching`` fixture to class-based test instances.

    When the current test runs on a (truthy) class instance, the fixture
    value is stored as ``self.patching`` on it.  Otherwise, e.g. for plain
    function tests where ``request.instance`` is falsy, nothing is done.
    """
    test_instance = request.instance
    if not test_instance:
        return
    test_instance.patching = patching


@pytest.fixture
def hub(request):
    """Install a fresh ``Hub`` as the current event loop for one test.

    Yields the new hub, then restores whatever ``get_event_loop()``
    returned before the test.  The restore is unconditional.  Previously
    it was skipped when the prior value was ``None``, which left the
    per-test hub installed as the global loop after the test finished.

    :param request: pytest fixture request (unused; kept so the fixture
        signature is unchanged).
    """
    from kombu.asynchronous import Hub, get_event_loop, set_event_loop
    previous_loop = get_event_loop()
    hub = Hub()
    set_event_loop(hub)

    yield hub

    # Restore even a ``None`` previous value so this test's hub does not
    # remain the global loop for later tests.
    set_event_loop(previous_loop)


def find_distribution_modules(name=__name__, file=__file__):
    """Yield the dotted names of all packages and modules in a tree.

    The root directory is found by going up from the directory containing
    *file* one level per dot in *name*.  The tree below that root is then
    walked.  Every directory holding an ``__init__.py`` is yielded as a
    package, followed by each of its other ``.py`` files as a module.
    All names are prefixed with the root directory's basename.

    :param name: dotted module name of the caller; its depth decides how
        many directories to climb.
    :param file: path of the caller's source file.
    :returns: generator of dotted package/module names.
    """
    depth = len(name.split('.')) - 1
    dist_root = os.path.abspath(
        os.path.join(os.path.dirname(file), *([os.pardir] * depth)))
    dist_name = os.path.basename(dist_root)

    for dirpath, _dirnames, filenames in os.walk(dist_root):
        relative = dirpath[len(dist_root):]
        # Normalise the native separator as well as '/', so this also
        # works on Windows where os.sep is a backslash.
        package = (dist_name + relative).replace(os.sep, '.')
        package = package.replace('/', '.')
        if '__init__.py' not in filenames:
            continue
        yield package
        for filename in filenames:
            if filename.endswith('.py') and filename != '__init__.py':
                # Strip the '.py' suffix to obtain the module name.
                yield f'{package}.{filename[:-3]}'


def import_all_modules(name=__name__, file=__file__, skip=()):
    """Import every discovered module so coverage accounts for it.

    Module names come from ``find_distribution_modules(name, file)``.
    Each one not listed in *skip* is announced on stdout and then
    imported.  ``ImportError``, ``VersionMismatch`` and ``AttributeError``
    raised during import are deliberately ignored, because this is a
    best-effort pre-import.

    :param name: dotted module name used to locate the tree root.
    :param file: source file used to locate the tree root.
    :param skip: dotted module names that must not be imported.  The
        default is an immutable empty tuple rather than a shared mutable
        list.
    """
    for module in find_distribution_modules(name, file):
        if module in skip:
            continue
        print(f'preimporting {module!r} for coverage...')
        try:
            __import__(module)
        except (ImportError, VersionMismatch, AttributeError):
            pass


def is_in_coverage():
    """Report whether this test run is collecting coverage.

    If the ``COVER_ALL_MODULES`` environment variable holds a non-empty
    value, that value itself is returned.  Otherwise the result is
    ``True`` when some command-line argument contains ``--cov``, and
    ``False`` when none does.
    """
    forced = os.environ.get('COVER_ALL_MODULES')
    if forced:
        return forced
    for argument in sys.argv:
        if '--cov' in argument:
            return True
    return False


@pytest.fixture(scope='session')
def cover_all_modules():
    """Pre-import every module when the run is measuring coverage.

    This lets coverage see all of the project's modules, including ones no
    test imports.  When coverage is not active, the fixture does nothing.
    """
    if not is_in_coverage():
        return
    import_all_modules()