Spaces:
Build error
Build error
File size: 3,731 Bytes
d61b9c7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
#!/usr/bin/env python3
import sys
import warnings
from time import time
from typing import cast, Iterable, Sized, TextIO
try:
from tqdm import tqdm
except ImportError:
tqdm = None
class DisableErrorIOWrapper(object):
    """
    Proxy around a TextIO object that ignores selected write/flush errors.

    Modelled on the equivalent helper in tqdm:
    https://github.com/tqdm/tqdm/blob/bcce20f771a16cb8e4ac5cc5b2307374a2c0e535/tqdm/utils.py#L131

    ``write`` and ``flush`` swallow an ``OSError`` whose errno is 5 and a
    ``ValueError`` whose message contains "closed"; in those cases they
    return ``None``. Every other attribute is forwarded to the wrapped object.
    """

    def __init__(self, wrapped: TextIO):
        """Remember the stream that all calls are delegated to."""
        self._wrapped = wrapped

    def __getattr__(self, name):
        # Only reached for attributes not defined on the wrapper itself.
        return getattr(self._wrapped, name)

    @staticmethod
    def _wrapped_run(func, *args, **kwargs):
        """Call ``func`` and return its result, suppressing the ignorable errors."""
        try:
            outcome = func(*args, **kwargs)
        except OSError as err:
            # errno 5 is the only OSError that is tolerated.
            if err.errno == 5:
                return None
            raise
        except ValueError as err:
            # Tolerate only errors whose message mentions "closed".
            if "closed" in str(err):
                return None
            raise
        return outcome

    def write(self, *args, **kwargs):
        """Write to the wrapped stream, ignoring the tolerated errors."""
        target = self._wrapped.write
        return self._wrapped_run(target, *args, **kwargs)

    def flush(self, *args, **kwargs):
        """Flush the wrapped stream, ignoring the tolerated errors."""
        target = self._wrapped.flush
        return self._wrapped_run(target, *args, **kwargs)
class SimpleProgress:
    """
    Simple progress output used when tqdm is unavailable.

    Like tqdm, output goes to stderr unless another ``file`` is given.
    The object can be iterated (wrapping ``iterable``), driven manually
    through ``update()`` / ``close()``, or used as a context manager that
    closes itself on exit.
    """

    def __init__(
        self,
        iterable: Iterable = None,
        desc: str = None,
        total: int = None,
        file: TextIO = None,
        mininterval: float = 0.5,
    ):
        """
        Args:
            iterable: Items to wrap; may be None when progress is reported
                manually through ``update()``.
            desc: Optional prefix printed before the progress figures.
            total: Expected number of steps. When omitted and ``iterable``
                has ``__len__``, its length is used.
            file: Stream to print to; defaults to ``sys.stderr``.
            mininterval: Minimum number of seconds between two redraws
                triggered by ``update()``.
        """
        self.cur = 0
        self.iterable = iterable
        self.total = total
        if total is None and hasattr(iterable, "__len__"):
            self.total = len(cast(Sized, iterable))
        self.desc = desc
        # Wrap the stream so that a broken or closed stream does not abort
        # the caller's loop just because progress could not be printed.
        self.file = cast(
            TextIO, DisableErrorIOWrapper(sys.stderr if file is None else file)
        )
        self.mininterval = mininterval
        self.last_print_t = 0.0
        self.closed = False

    def __iter__(self):
        """Yield the items of ``iterable``, advancing the progress by one each."""
        # Compare against None instead of testing truthiness: some iterables
        # (e.g. array types) raise on bool(), and an empty iterable should
        # still be rendered and closed like any other.
        if self.closed or self.iterable is None:
            return
        self._refresh()
        try:
            for item in self.iterable:
                yield item
                self.update()
        finally:
            # Also runs when the consumer stops early (break / generator
            # close), so the progress line is always terminated.
            self.close()

    def __enter__(self):
        """Return self so the object can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the progress output; exceptions are not suppressed."""
        self.close()

    def _refresh(self):
        """Redraw the current progress on one line (prefixed by a carriage return)."""
        progress_str = self.desc + ": " if self.desc else ""
        if self.total:
            # e.g., progress: 60% 3/5
            progress_str += f"{100 * self.cur // self.total}% {self.cur}/{self.total}"
        else:
            # Unknown (or zero) total: one dot per completed step,
            # e.g., progress: .....
            progress_str += "." * self.cur
        print("\r" + progress_str, end="", file=self.file)

    def update(self, amount: int = 1):
        """Advance the counter by ``amount`` and redraw if ``mininterval`` has elapsed."""
        if self.closed:
            return
        self.cur += amount
        cur_t = time()
        if cur_t - self.last_print_t >= self.mininterval:
            self._refresh()
            self.last_print_t = cur_t

    def close(self):
        """Print the final state followed by a newline; further calls do nothing."""
        if self.closed:
            return
        self._refresh()
        print(file=self.file)  # end with new line
        self.closed = True
def progress(
    iterable: Iterable = None,
    desc: str = None,
    total: int = None,
    use_tqdm=True,
    file: TextIO = None,
    mininterval: float = 0.5,
    **kwargs,
):
    """
    Build a progress indicator, using tqdm if possible.

    When ``use_tqdm`` is true and tqdm was imported, a ``tqdm`` object is
    returned and any extra ``kwargs`` are forwarded to it. Otherwise a
    ``SimpleProgress`` is returned (extra ``kwargs`` are not passed to it);
    a warning is issued first if tqdm was requested but is not installed.
    """
    if use_tqdm:
        if tqdm:
            return tqdm(
                iterable,
                desc=desc,
                total=total,
                file=file,
                mininterval=mininterval,
                **kwargs,
            )
        # tqdm was requested but could not be imported.
        warnings.warn(
            "Tried to show progress with tqdm "
            "but tqdm is not installed. "
            "Fall back to simply print out the progress."
        )
    # Fall back to the simple progress print.
    return SimpleProgress(
        iterable, desc=desc, total=total, file=file, mininterval=mininterval
    )
|