Spaces:
Build error
Build error
#!/usr/bin/env python3 | |
import sys | |
import warnings | |
from time import time | |
from typing import cast, Iterable, Sized, TextIO | |
try: | |
from tqdm import tqdm | |
except ImportError: | |
tqdm = None | |
class DisableErrorIOWrapper(object):
    """Proxy around a TextIO object that ignores certain write/flush errors.

    Modeled on tqdm's write-error handling
    (https://github.com/tqdm/tqdm/blob/bcce20f771a16cb8e4ac5cc5b2307374a2c0e535/tqdm/utils.py#L131):
    ``write`` and ``flush`` swallow ``OSError`` with errno 5 (EIO) and any
    ``ValueError`` whose message mentions "closed", returning None instead.
    Every other attribute is forwarded unchanged to the wrapped stream.
    """

    def __init__(self, wrapped: TextIO):
        """
        Args:
            wrapped: the stream whose write errors should be ignored.
        """
        self._wrapped = wrapped

    def __getattr__(self, name):
        # __getattr__ only runs for attributes missing on the wrapper, so all
        # stream attributes except write/flush are delegated. Guard
        # `_wrapped` itself: on an instance created without __init__ (e.g. by
        # copy/pickle) looking it up here would otherwise recurse forever.
        if name == "_wrapped":
            raise AttributeError(name)
        return getattr(self._wrapped, name)

    @staticmethod
    def _wrapped_run(func, *args, **kwargs):
        """Return ``func(*args, **kwargs)``, or None if it hit an ignorable error."""
        try:
            return func(*args, **kwargs)
        except OSError as e:
            if e.errno != 5:  # 5 == EIO (input/output error)
                raise
        except ValueError as e:
            # e.g. "I/O operation on closed file."
            if "closed" not in str(e):
                raise
        return None

    def write(self, *args, **kwargs):
        return self._wrapped_run(self._wrapped.write, *args, **kwargs)

    def flush(self, *args, **kwargs):
        return self._wrapped_run(self._wrapped.flush, *args, **kwargs)
class SimpleProgress:
    """Minimal stand-in for ``tqdm``, used when tqdm is unavailable.

    Like tqdm, it writes to stderr by default. Progress is shown on a single
    line that is redrawn after a carriage return, e.g. ``desc: 60% 3/5`` when
    the total is known, or ``desc: .....`` (one dot per item) otherwise.
    Supports iteration, manual ``update()``/``close()``, and use as a context
    manager.
    """

    def __init__(
        self,
        iterable: Iterable = None,
        desc: str = None,
        total: int = None,
        file: TextIO = None,
        mininterval: float = 0.5,
    ):
        """
        Args:
            iterable: items to iterate over; may be None when progress is
                driven manually through ``update()``.
            desc: optional label printed before the progress.
            total: expected number of items; defaults to ``len(iterable)``
                when the iterable supports ``len``.
            file: output stream; defaults to ``sys.stderr``.
            mininterval: minimum number of seconds between redraws performed
                by ``update()``.
        """
        self.cur = 0
        self.iterable = iterable
        self.total = total
        if total is None and hasattr(iterable, "__len__"):
            self.total = len(cast(Sized, iterable))
        self.desc = desc
        # Wrap the stream so errors on a closed/broken stream don't abort the caller.
        self.file = cast(
            TextIO, DisableErrorIOWrapper(file if file is not None else sys.stderr)
        )
        self.mininterval = mininterval
        self.last_print_t = 0.0
        self.closed = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # never suppress exceptions from the with-body

    def __iter__(self):
        """Yield items from the iterable, advancing the progress after each one.

        The bar is closed however iteration ends: exhaustion, an exception, or
        the generator being closed (e.g. after an early ``break``).
        """
        if self.closed or self.iterable is None:
            return
        self._refresh()
        try:
            for it in self.iterable:
                yield it
                self.update()
        finally:
            self.close()

    def _refresh(self):
        """Redraw the progress line in place (no trailing newline)."""
        progress_str = self.desc + ": " if self.desc else ""
        if self.total:
            # e.g., progress: 60% 3/5
            progress_str += f"{100 * self.cur // self.total}% {self.cur}/{self.total}"
        else:
            # e.g., progress: .....
            progress_str += "." * self.cur
        print("\r" + progress_str, end="", file=self.file)

    def update(self, amount: int = 1):
        """Advance by ``amount``; redraw if ``mininterval`` seconds have passed."""
        if self.closed:
            return
        self.cur += amount
        cur_t = time()
        if cur_t - self.last_print_t >= self.mininterval:
            self._refresh()
            self.last_print_t = cur_t

    def close(self):
        """Draw the final state and end the line; later calls are no-ops."""
        if not self.closed:
            self._refresh()
            print(file=self.file)  # end with new line
            self.closed = True
def progress(
    iterable: Iterable = None,
    desc: str = None,
    total: int = None,
    use_tqdm=True,
    file: TextIO = None,
    mininterval: float = 0.5,
    **kwargs,
):
    """Wrap ``iterable`` in a progress indicator.

    Uses ``tqdm`` when it was importable and ``use_tqdm`` is true. Otherwise
    falls back to ``SimpleProgress``, and warns if tqdm was requested but is
    missing.

    Args:
        iterable: iterable to wrap; may be None for manual ``update()`` use.
        desc: optional label printed before the progress.
        total: expected number of items; the fallback infers it from
            ``len(iterable)`` when omitted.
        use_tqdm: set to False to force the simple fallback.
        file: output stream (the fallback defaults to ``sys.stderr``).
        mininterval: minimum number of seconds between display refreshes.
        **kwargs: extra options forwarded to ``tqdm`` only; the fallback
            ignores them.

    Returns:
        A ``tqdm`` instance, or a ``SimpleProgress`` instance.
    """
    # Use tqdm if possible; otherwise fall back to a simple progress print.
    if tqdm and use_tqdm:
        return tqdm(
            iterable,
            desc=desc,
            total=total,
            file=file,
            mininterval=mininterval,
            **kwargs,
        )
    if use_tqdm and not tqdm:
        # stacklevel=2 attributes the warning to the caller of progress().
        warnings.warn(
            "Tried to show progress with tqdm "
            "but tqdm is not installed. "
            "Fall back to simply print out the progress.",
            stacklevel=2,
        )
    return SimpleProgress(
        iterable, desc=desc, total=total, file=file, mininterval=mininterval
    )