summaryrefslogtreecommitdiff
path: root/src/tox/execute/stream.py
blob: 980c97ffdebad5174d3dc425361ea280ac564728 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from __future__ import annotations

from contextlib import contextmanager
from threading import Event, Lock, Timer
from types import TracebackType
from typing import IO, Iterator

from colorama import Fore


class SyncWrite:
    """
    Make sure data collected is synced in-memory and to the target stream on every newline and time period.

    Used to propagate executed commands output to the standard output/error streams visible to the user.

    Everything passed to :meth:`handler` is appended to an in-memory buffer. When a target stream is configured,
    content up to the last newline is written to the target immediately. A trailing partial line is written once
    :attr:`REFRESH_RATE` seconds pass without new content, or when the context manager exits.
    """

    #: seconds of inactivity after which still-unwritten content is written to the target
    REFRESH_RATE = 0.1

    def __init__(self, name: str, target: IO[bytes] | None, color: str | None = None) -> None:
        """
        :param name: identifier shown in the repr and used to name the flush timer thread
        :param target: stream to mirror the content to; ``None`` means content is only collected in memory
        :param color: text written (UTF-8 encoded) to the target before each chunk, followed by ``Fore.RESET``;
                      ``None`` disables coloring
        """
        self._content = bytearray()  # all content received so far
        self._target: IO[bytes] | None = target
        self._target_enabled: bool = target is not None
        self._keep_printing: Event = Event()  # not read by any method of this class
        self._content_lock: Lock = Lock()  # guards reads/appends of _content
        self._lock: Lock = Lock()  # serializes writes to the target and updates of _at
        self._at: int = 0  # offset in _content up to which data was already written to the target
        self._color: str | None = color
        self.name = name
        # pending flush timer; None until the first _start so _cancel is safe to call before that
        self.timer: Timer | None = None

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(name={self.name!r}, target={self._target!r}, color={self._color!r})"

    def __enter__(self) -> SyncWrite:
        """Start the periodic flush timer (only when a target stream is configured)."""
        if self._target_enabled:
            self._start()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,  # noqa: U100
        exc_val: BaseException | None,  # noqa: U100
        exc_tb: TracebackType | None,  # noqa: U100
    ) -> None:
        """Stop the flush timer and write any remaining content (including a trailing partial line)."""
        if self._target_enabled:
            self._cancel()
            # read the size under the content lock, consistent with _trigger_timer
            with self._content_lock:
                at = len(self._content)
            self._write(at)

    def handler(self, content: bytes) -> None:
        """
        A callback called whenever content is written.

        Appends ``content`` to the in-memory buffer. With a target configured, it writes everything up to the last
        newline seen so far, then restarts the inactivity timer that will flush the remainder.

        :param content: the newly produced bytes
        """
        with self._content_lock:
            self._content.extend(content)
            if self._target_enabled is False:
                return
            at = content.rfind(b"\n")
            if at != -1:  # pragma: no branch
                # translate the position within ``content`` to an absolute offset just past that newline
                at = len(self._content) - len(content) + at + 1
        self._cancel()
        try:
            if at != -1:
                self._write(at)
        finally:
            # always re-arm the timer, even if writing to the target failed
            self._start()

    def _start(self) -> None:
        """Schedule a flush of all buffered content after REFRESH_RATE seconds."""
        timer = Timer(self.REFRESH_RATE, self._trigger_timer)
        timer.name = f"{self.name}-sync-timer"
        self.timer = timer
        timer.start()

    def _cancel(self) -> None:
        """Cancel the pending flush timer, if one was ever started."""
        if self.timer is not None:
            self.timer.cancel()

    def _trigger_timer(self) -> None:
        """Timer callback: write everything buffered so far."""
        with self._content_lock:
            at = len(self._content)
        self._write(at)

    def _write(self, at: int) -> None:
        """
        Write ``_content[_at:at]`` to the target (wrapped in the color, if any) and flush it.

        The write offset advances even when writing fails, so the same data is not written twice.

        :param at: absolute offset in the buffer to write up to; nothing happens if it is not past ``_at``
        """
        assert self._target is not None  # every caller is guarded by _target_enabled
        with self._lock:
            if at > self._at:  # pragma: no branch
                try:
                    with self.colored():
                        self._target.write(self._content[self._at : at])
                    self._target.flush()
                finally:
                    self._at = at

    @contextmanager
    def colored(self) -> Iterator[None]:
        """Write the color prefix to the target on entry and ``Fore.RESET`` on exit (no-op without color/target)."""
        if self._color is None or self._target is None:
            yield
        else:
            self._target.write(self._color.encode("utf-8"))
            try:
                yield
            finally:
                self._target.write(Fore.RESET.encode("utf-8"))

    @property
    def text(self) -> str:
        """All content collected so far, decoded as UTF-8 (strict: invalid bytes raise ``UnicodeDecodeError``)."""
        with self._content_lock:
            return self._content.decode("utf-8")

    @property
    def content(self) -> bytearray:
        """The live in-memory buffer (not a copy); later handler calls keep appending to it."""
        with self._content_lock:
            return self._content