strexp / captum /_utils /progress.py
markytools's picture
added strexp
d61b9c7
#!/usr/bin/env python3
import sys
import warnings
from time import time
from typing import cast, Iterable, Sized, TextIO
try:
from tqdm import tqdm
except ImportError:
tqdm = None
class DisableErrorIOWrapper(object):
def __init__(self, wrapped: TextIO):
"""
The wrapper around a TextIO object to ignore write errors like tqdm
https://github.com/tqdm/tqdm/blob/bcce20f771a16cb8e4ac5cc5b2307374a2c0e535/tqdm/utils.py#L131
"""
self._wrapped = wrapped
def __getattr__(self, name):
return getattr(self._wrapped, name)
@staticmethod
def _wrapped_run(func, *args, **kwargs):
try:
return func(*args, **kwargs)
except OSError as e:
if e.errno != 5:
raise
except ValueError as e:
if "closed" not in str(e):
raise
def write(self, *args, **kwargs):
return self._wrapped_run(self._wrapped.write, *args, **kwargs)
def flush(self, *args, **kwargs):
return self._wrapped_run(self._wrapped.flush, *args, **kwargs)
class SimpleProgress:
def __init__(
self,
iterable: Iterable = None,
desc: str = None,
total: int = None,
file: TextIO = None,
mininterval: float = 0.5,
):
"""
Simple progress output used when tqdm is unavailable.
Same as tqdm, output to stderr channel
"""
self.cur = 0
self.iterable = iterable
self.total = total
if total is None and hasattr(iterable, "__len__"):
self.total = len(cast(Sized, iterable))
self.desc = desc
file = DisableErrorIOWrapper(file if file else sys.stderr)
cast(TextIO, file)
self.file = file
self.mininterval = mininterval
self.last_print_t = 0.0
self.closed = False
def __iter__(self):
if self.closed or not self.iterable:
return
self._refresh()
for it in self.iterable:
yield it
self.update()
self.close()
def _refresh(self):
progress_str = self.desc + ": " if self.desc else ""
if self.total:
# e.g., progress: 60% 3/5
progress_str += f"{100 * self.cur // self.total}% {self.cur}/{self.total}"
else:
# e.g., progress: .....
progress_str += "." * self.cur
print("\r" + progress_str, end="", file=self.file)
def update(self, amount: int = 1):
if self.closed:
return
self.cur += amount
cur_t = time()
if cur_t - self.last_print_t >= self.mininterval:
self._refresh()
self.last_print_t = cur_t
def close(self):
if not self.closed:
self._refresh()
print(file=self.file) # end with new line
self.closed = True
def progress(
iterable: Iterable = None,
desc: str = None,
total: int = None,
use_tqdm=True,
file: TextIO = None,
mininterval: float = 0.5,
**kwargs,
):
# Try to use tqdm is possible. Fall back to simple progress print
if tqdm and use_tqdm:
return tqdm(
iterable,
desc=desc,
total=total,
file=file,
mininterval=mininterval,
**kwargs,
)
else:
if not tqdm and use_tqdm:
warnings.warn(
"Tried to show progress with tqdm "
"but tqdm is not installed. "
"Fall back to simply print out the progress."
)
return SimpleProgress(
iterable, desc=desc, total=total, file=file, mininterval=mininterval
)