File size: 3,731 Bytes
d61b9c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#!/usr/bin/env python3

import errno
import sys
import warnings
from time import time
from typing import cast, Iterable, Optional, Sized, TextIO

try:
    from tqdm import tqdm
except ImportError:
    tqdm = None


class DisableErrorIOWrapper(object):
    """
    Wrap a TextIO object so that write/flush errors from a vanished or closed
    stream are ignored, like tqdm does:
    https://github.com/tqdm/tqdm/blob/bcce20f771a16cb8e4ac5cc5b2307374a2c0e535/tqdm/utils.py#L131

    Only two failure modes are swallowed by ``write``/``flush``:

    * ``OSError`` whose errno is ``EIO``;
    * ``ValueError`` whose message contains "closed".

    Every other exception propagates. Any other attribute access is delegated
    to the wrapped object.
    """

    def __init__(self, wrapped: TextIO):
        self._wrapped = wrapped

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If ``_wrapped`` is
        # itself missing (an instance built via ``__new__``, as copy/pickle
        # do), delegating would re-enter here forever; fail cleanly instead.
        if name == "_wrapped":
            raise AttributeError(name)
        return getattr(self._wrapped, name)

    @staticmethod
    def _wrapped_run(func, *args, **kwargs):
        """Call ``func``; return its result, or None if the error is ignorable."""
        try:
            return func(*args, **kwargs)
        except OSError as e:
            if e.errno != errno.EIO:
                raise
        except ValueError as e:
            if "closed" not in str(e):
                raise
        return None

    def write(self, *args, **kwargs):
        return self._wrapped_run(self._wrapped.write, *args, **kwargs)

    def flush(self, *args, **kwargs):
        return self._wrapped_run(self._wrapped.flush, *args, **kwargs)


class SimpleProgress:
    """
    Simple progress output used when tqdm is unavailable.

    Like tqdm it writes to stderr by default and redraws a single line using a
    carriage return. It can wrap an iterable, be driven by manual ``update()``
    calls, or be used as a context manager (``with SimpleProgress(...) as p``).
    """

    def __init__(
        self,
        iterable: Optional[Iterable] = None,
        desc: Optional[str] = None,
        total: Optional[int] = None,
        file: Optional[TextIO] = None,
        mininterval: float = 0.5,
    ):
        """
        Args:
            iterable: items to iterate over; None when driving via ``update()``.
            desc: optional prefix, printed as ``"<desc>: "``.
            total: expected item count; taken from ``len(iterable)`` when
                omitted and the iterable defines ``__len__``.
            file: output stream; ``sys.stderr`` when None. Writes go through
                DisableErrorIOWrapper, so EIO / closed-stream errors are ignored.
            mininterval: minimum seconds between redraws triggered by ``update``.
        """
        self.cur = 0

        self.iterable = iterable
        self.total = total
        if total is None and hasattr(iterable, "__len__"):
            self.total = len(cast(Sized, iterable))

        self.desc = desc

        # cast() only informs the type checker, so it must wrap the value that
        # is stored; a bare ``cast(...)`` statement has no effect.
        self.file = cast(
            TextIO, DisableErrorIOWrapper(file if file is not None else sys.stderr)
        )

        self.mininterval = mininterval
        # 0.0 makes the first update() redraw immediately.
        self.last_print_t = 0.0
        self.closed = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Finish the line even if the body raised; never suppress exceptions.
        self.close()

    def __iter__(self):
        # Compare with None instead of relying on truthiness: objects such as
        # numpy arrays raise on bool().
        if self.closed or self.iterable is None:
            return
        self._refresh()
        try:
            for it in self.iterable:
                yield it
                self.update()
        finally:
            # Runs on normal exhaustion and when the consumer stops early
            # (break, generator close, or an exception in the loop body).
            self.close()

    def _refresh(self):
        """Redraw the progress line in place (leading carriage return)."""
        progress_str = self.desc + ": " if self.desc else ""
        if self.total:
            # e.g., progress: 60% 3/5
            progress_str += f"{100 * self.cur // self.total}% {self.cur}/{self.total}"
        else:
            # Unknown (or zero) total, e.g., progress: .....
            progress_str += "." * self.cur

        print("\r" + progress_str, end="", file=self.file)

    def update(self, amount: int = 1):
        """Advance by ``amount``; redraw at most once per ``mininterval``."""
        if self.closed:
            return
        self.cur += amount

        cur_t = time()
        if cur_t - self.last_print_t >= self.mininterval:
            self._refresh()
            self.last_print_t = cur_t

    def close(self):
        """Draw the final state and end the line; later calls are no-ops."""
        if not self.closed:
            self._refresh()
            print(file=self.file)  # end with new line
            self.closed = True


def progress(
    iterable: Optional[Iterable] = None,
    desc: Optional[str] = None,
    total: Optional[int] = None,
    use_tqdm=True,
    file: Optional[TextIO] = None,
    mininterval: float = 0.5,
    **kwargs,
):
    """
    Return a progress indicator: tqdm when available, else SimpleProgress.

    Args:
        iterable: items to iterate over, or None for manual ``update()`` calls.
        desc: optional description prefix.
        total: expected number of items.
        use_tqdm: pass False to use the built-in SimpleProgress even when tqdm
            is installed.
        file: output stream; None means stderr for both implementations.
        mininterval: minimum seconds between display refreshes.
        **kwargs: forwarded to tqdm only; ignored by the SimpleProgress fallback.

    Returns:
        A ``tqdm`` instance or a ``SimpleProgress`` instance.
    """
    # Use tqdm if possible; otherwise fall back to a simple progress print.
    if tqdm and use_tqdm:
        return tqdm(
            iterable,
            desc=desc,
            total=total,
            file=file,
            mininterval=mininterval,
            **kwargs,
        )
    if use_tqdm:
        # Reaching here with use_tqdm set means tqdm failed to import.
        # stacklevel=2 attributes the warning to the caller of progress().
        warnings.warn(
            "Tried to show progress with tqdm "
            "but tqdm is not installed. "
            "Fall back to simply print out the progress.",
            stacklevel=2,
        )
    return SimpleProgress(
        iterable, desc=desc, total=total, file=file, mininterval=mininterval
    )