|
""" |
|
General helpers required for `tqdm.std`. |
|
""" |
|
import os |
|
import re |
|
import sys |
|
from functools import partial, partialmethod, wraps |
|
from inspect import signature |
|
|
|
from unicodedata import east_asian_width |
|
from warnings import warn |
|
from weakref import proxy |
|
|
|
# Aliases for built-ins (presumably kept for backward compatibility with
# older callers -- TODO confirm before removing).
_range = range
_unich = chr
_unicode = str
_basestring = str

# Platform detection based on `sys.platform` prefixes.
CUR_OS = sys.platform
IS_WIN = CUR_OS.startswith(('win32', 'cygwin'))
IS_NIX = CUR_OS.startswith(('aix', 'linux', 'darwin', 'freebsd'))

# Matches ANSI CSI escape sequences such as "\x1b[31m" or "\x1b[A".
RE_ANSI = re.compile(r"\x1b\[[;\d]*[A-Za-z]")

# `colorama` is only attempted on Windows-like platforms; everywhere else
# (or when it is not installed) the name is bound to None.
colorama = None
if IS_WIN:
    try:
        import colorama
    except ImportError:
        pass
    else:
        try:
            colorama.init(strip=False)
        except TypeError:
            # `init()` rejected the `strip` keyword; call it bare instead.
            colorama.init()
|
|
|
|
|
def envwrap(prefix, types=None, is_method=False):
    """
    Override parameter defaults via `os.environ[prefix + param_name]`.
    Maps UPPER_CASE env vars to lower_case param names.
    camelCase isn't supported (because Windows ignores case).

    Precedence (highest first):

    - call (`foo(a=3)`)
    - environ (`FOO_A=2`)
    - signature (`def foo(a=1)`)

    The environment is read once, when `envwrap(...)` itself is called.

    Parameters
    ----------
    prefix  : str
        Env var prefix, e.g. "FOO_"
    types  : dict, optional
        Fallback mappings `{'param_name': type, ...}` if types cannot be
        inferred from function signature.
        Consider using `types=collections.defaultdict(lambda: ast.literal_eval)`.
    is_method  : bool, optional
        Whether to use `functools.partialmethod` (default: False, i.e. use
        `functools.partial`).

    Returns
    -------
    out  : callable
        Decorator which binds the matching (type-converted) environment
        values to the decorated function as keyword defaults.

    Examples
    --------
    ```
    $ cat foo.py
    from tqdm.utils import envwrap
    @envwrap("FOO_")
    def test(a=1, b=2, c=3):
        print(f"received: a={a}, b={b}, c={c}")

    $ FOO_A=42 FOO_C=1337 python -c 'import foo; foo.test(c=99)'
    received: a=42, b=2, c=99
    ```
    """
    if types is None:
        types = {}
    i = len(prefix)
    # snapshot of `{param_name: raw_string}` for every env var with our prefix
    env_overrides = {k[i:].lower(): v for k, v in os.environ.items() if k.startswith(prefix)}
    part = partialmethod if is_method else partial

    def wrap(func):
        params = signature(func).parameters
        # ignore env vars which don't correspond to a parameter of `func`
        overrides = {k: v for k, v in env_overrides.items() if k in params}
        # only values are reassigned below, so iterating the dict is safe
        for k in overrides:
            param = params[k]
            if param.annotation is not param.empty:
                # try each member of a `Union`-like annotation in turn and
                # keep the first conversion that succeeds
                for typ in getattr(param.annotation, '__args__', (param.annotation,)):
                    try:
                        overrides[k] = typ(overrides[k])
                    except Exception:
                        pass
                    else:
                        break
            elif param.default is not None and param.default is not param.empty:
                # infer the type from the default value. A parameter without
                # a default must be excluded: `param.empty` is a class, so
                # `type(param.empty)(value)` would evaluate to `type(value)`
                # (i.e. the `str` class) rather than a converted value.
                overrides[k] = type(param.default)(overrides[k])
            else:
                try:
                    overrides[k] = types[k](overrides[k])
                except KeyError:
                    # no type information available: keep the raw string
                    pass
        return part(func, **overrides)
    return wrap
|
|
|
|
|
class FormatReplace(object):
    """
    Object whose formatted representation is always a fixed string,
    regardless of the format spec. Also counts how often it was formatted.

    >>> a = FormatReplace('something')
    >>> f"{a:5d}"
    'something'
    """

    def __init__(self, replace=''):
        self.format_called = 0
        self.replace = replace

    def __format__(self, _):
        # the format spec is deliberately ignored
        self.format_called = self.format_called + 1
        return self.replace
|
|
|
|
|
class Comparable(object):
    """
    Mixin providing the six rich comparisons.

    Assumes child has self._comparable attr/@property; `==` and `<` compare
    that value, and the remaining operators are derived from those two.
    """

    # -- primitives: delegate to `_comparable` --
    def __eq__(self, other):
        return self._comparable == other._comparable

    def __lt__(self, other):
        return self._comparable < other._comparable

    # -- derived operators --
    def __ne__(self, other):
        return not self == other

    def __le__(self, other):
        return (self < other) or (self == other)

    def __ge__(self, other):
        return not self < other

    def __gt__(self, other):
        return not self <= other
|
|
|
|
|
class ObjectWrapper(object):
    """
    Thin transparent proxy around another object.

    Attribute reads which are not found on the wrapper itself, and all
    regular attribute writes, are forwarded to the wrapped object (stored
    as `_wrapped`). Use `wrapper_getattr`/`wrapper_setattr` to access
    attributes of the wrapper itself.
    """

    def __getattr__(self, name):
        # Only invoked when normal lookup fails. If `_wrapped` itself is
        # missing (e.g. instance created via `__new__` without `__init__`),
        # `self._wrapped` below would re-enter this method forever.
        if name == '_wrapped':
            raise AttributeError(name)
        return getattr(self._wrapped, name)

    def __setattr__(self, name, value):
        return setattr(self._wrapped, name, value)

    def wrapper_getattr(self, name):
        """Actual `self.getattr` rather than self._wrapped.getattr"""
        try:
            # `object` defines `__getattribute__` (there is no
            # `object.__getattr__`): look on the wrapper itself first
            return object.__getattribute__(self, name)
        except AttributeError:
            # not on the wrapper: fall back to the forwarding lookup
            return getattr(self, name)

    def wrapper_setattr(self, name, value):
        """Actual `self.setattr` rather than self._wrapped.setattr"""
        return object.__setattr__(self, name, value)

    def __init__(self, wrapped):
        """
        Thin wrapper around a given object

        Parameters
        ----------
        wrapped  : object
            Object to which attribute access is forwarded.
        """
        # must bypass the forwarding `__setattr__`
        self.wrapper_setattr('_wrapped', wrapped)
|
|
|
|
|
class SimpleTextIOWrapper(ObjectWrapper):
    """
    Wrapper which changes only `.write()`: the given value is encoded with
    the configured encoding before being handed to the wrapped object's
    own `.write()` method. Everything else is forwarded untouched.
    """

    def __init__(self, wrapped, encoding):
        super().__init__(wrapped)
        # stored on the wrapper itself, not on the wrapped object
        self.wrapper_setattr('encoding', encoding)

    def write(self, s):
        """
        Encode `s` and pass to the wrapped object's `.write()` method.
        """
        emit = self._wrapped.write
        return emit(s.encode(self.wrapper_getattr('encoding')))

    def __eq__(self, other):
        # compare the underlying objects; `other` may or may not be a wrapper
        mine = self._wrapped
        theirs = getattr(other, '_wrapped', other)
        return mine == theirs
|
|
|
|
|
class DisableOnWriteError(ObjectWrapper):
    """
    Disable the given `tqdm_instance` upon `write()` or `flush()` errors.
    """

    @staticmethod
    def disable_on_exception(tqdm_instance, func):
        """
        Quietly set `tqdm_instance.miniters=inf` if `func` raises `errno=5`
        (or a `ValueError` whose message contains 'closed').

        Any other `OSError`/`ValueError` is re-raised. Only a weak proxy to
        `tqdm_instance` is kept.
        """
        weak_bar = proxy(tqdm_instance)

        def _silence():
            try:
                weak_bar.miniters = float('inf')
            except ReferenceError:
                # the instance has already been garbage collected
                pass

        def inner(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except OSError as err:
                if err.errno != 5:
                    raise
                _silence()
            except ValueError as err:
                if 'closed' not in str(err):
                    raise
                _silence()

        return inner

    def __init__(self, wrapped, tqdm_instance):
        super().__init__(wrapped)
        # shadow `write`/`flush` on the wrapper (when the wrapped object
        # provides them) with error-tolerant versions
        for method_name in ('write', 'flush'):
            if hasattr(wrapped, method_name):
                self.wrapper_setattr(
                    method_name,
                    self.disable_on_exception(
                        tqdm_instance, getattr(wrapped, method_name)))

    def __eq__(self, other):
        # compare the underlying objects; `other` may or may not be a wrapper
        mine = self._wrapped
        theirs = getattr(other, '_wrapped', other)
        return mine == theirs
|
|
|
|
|
class CallbackIOWrapper(ObjectWrapper):
    def __init__(self, callback, stream, method="read"):
        """
        Wrap a given `file`-like object's `read()` or `write()` to report
        lengths to the given `callback`

        Parameters
        ----------
        callback  : callable
            Called with `len(data)` after each wrapped read/write.
        stream  : file-like
            Object being wrapped.
        method  : str, optional
            "read" or "write" (default: "read").

        Raises
        ------
        KeyError
            If `method` is neither "read" nor "write".
        """
        super().__init__(stream)
        target = getattr(stream, method)
        if method not in ("write", "read"):
            raise KeyError("Can only wrap read/write methods")

        if method == "write":
            @wraps(target)
            def write(data, *args, **kwargs):
                result = target(data, *args, **kwargs)
                callback(len(data))
                return result

            self.wrapper_setattr('write', write)
        else:
            @wraps(target)
            def read(*args, **kwargs):
                chunk = target(*args, **kwargs)
                callback(len(chunk))
                return chunk

            self.wrapper_setattr('read', read)
|
|
|
|
|
def _is_utf(encoding):
    """
    Return whether `encoding` can represent the unicode block characters
    U+2588 and U+2589 (falling back to a name check if encoding fails for
    any reason other than an unencodable character).
    """
    try:
        u'\u2588\u2589'.encode(encoding)
        return True
    except UnicodeEncodeError:
        return False
    except Exception:
        # e.g. unknown codec name or non-string `encoding`
        pass
    try:
        return encoding.lower().startswith('utf-') or ('U8' == encoding)
    except Exception:
        return False
|
|
|
|
|
def _supports_unicode(fp):
    """
    Return `_is_utf(fp.encoding)`, or False if the attribute lookup raises
    `AttributeError` (e.g. `fp` has no `encoding`).
    """
    try:
        encoding = fp.encoding
        return _is_utf(encoding)
    except AttributeError:
        return False
|
|
|
|
|
def _is_ascii(s):
    """
    For a `str`: True if every code point is <= 255.
    For anything else: the result of `_supports_unicode(s)`.
    """
    if not isinstance(s, str):
        return _supports_unicode(s)
    return all(ord(ch) <= 255 for ch in s)
|
|
|
|
|
def _screen_shape_wrapper():
    """
    Return a function which returns console dimensions (width, height).
    Supported: linux, osx, windows, cygwin.
    """
    # *nix detection takes priority, then Windows-like platforms;
    # anything else falls back to querying `tput`.
    if IS_NIX:
        return _screen_shape_linux
    if IS_WIN:
        return _screen_shape_windows
    return _screen_shape_tput
|
|
|
|
|
def _screen_shape_windows(fp):
    """
    Query the console window size through `kernel32` via `ctypes`.

    Returns `(right - left, bottom - top)` of the reported window rectangle,
    or `(None, None)` if anything fails (including on platforms where
    `ctypes.windll` is unavailable).
    """
    try:
        import struct
        from ctypes import create_string_buffer, windll
        from sys import stdin, stdout

        # standard-handle identifier passed to `GetStdHandle`
        if fp == stdin:
            std_handle_id = -10
        elif fp == stdout:
            std_handle_id = -11
        else:
            std_handle_id = -12

        kernel32 = windll.kernel32
        handle = kernel32.GetStdHandle(std_handle_id)
        info = create_string_buffer(22)
        if kernel32.GetConsoleScreenBufferInfo(handle, info):
            fields = struct.unpack("hhhhHhhhhhh", info.raw)
            # fields 5..8 hold the window rectangle
            left, top, right, bottom = fields[5:9]
            return right - left, bottom - top
    except Exception:
        pass
    return None, None
|
|
|
|
|
def _screen_shape_tput(*_):
    """
    cygwin xterm (windows)

    Ask `tput` for the terminal size. Positional arguments are ignored.

    Returns
    -------
    out  : list or tuple
        `[cols - 1, lines - 1]` on success, `(None, None)` if `tput` is
        unavailable, fails, or prints something which isn't an integer.
    """
    try:
        from subprocess import check_output

        # `check_output` captures what `tput` prints; `check_call` would only
        # give the exit status (always 0 on success), not the dimension.
        return [int(check_output(['tput', i])) - 1
                for i in ('cols', 'lines')]
    except Exception:
        pass
    return None, None
|
|
|
|
|
def _screen_shape_linux(fp):
    """
    Query the terminal size of `fp` with the `TIOCGWINSZ` ioctl.

    Returns `(cols, rows)` from the ioctl when it succeeds. Otherwise falls
    back to `[COLUMNS - 1, LINES - 1]` from the environment, and finally to
    `(None, None)` when those are missing/invalid or the required modules
    cannot be imported.
    """
    try:
        from array import array
        from fcntl import ioctl
        from termios import TIOCGWINSZ
    except ImportError:
        return None, None

    try:
        rows, cols = array('h', ioctl(fp, TIOCGWINSZ, '\0' * 8))[:2]
    except Exception:
        # not a terminal (or ioctl unsupported): try the environment
        try:
            return [int(os.environ[i]) - 1 for i in ("COLUMNS", "LINES")]
        except (KeyError, ValueError):
            return None, None
    return cols, rows
|
|
|
|
|
def _environ_cols_wrapper():
    """
    Return a function which returns console width.
    Supported: linux, osx, windows, cygwin.

    Deprecated: emits a `DeprecationWarning` on every call.
    """
    warn(
        "Use `_screen_shape_wrapper()(file)[0]` instead of"
        " `_environ_cols_wrapper()(file)`",
        DeprecationWarning, stacklevel=2)

    shape_fn = _screen_shape_wrapper()
    if shape_fn:
        @wraps(shape_fn)
        def inner(fp):
            # width is the first element of the (width, height) pair
            return shape_fn(fp)[0]

        return inner
    return None
|
|
|
|
|
def _term_move_up():
    """
    Return the ANSI "cursor up" sequence, or an empty string on `nt`
    when `colorama` is unavailable.
    """
    if os.name == 'nt' and colorama is None:
        return ''
    return '\x1b[A'
|
|
|
|
|
def _text_width(s):
    """
    Return the display width of `str(s)`: characters whose East Asian Width
    is Fullwidth or Wide count as 2 columns, all others as 1.
    """
    columns = 0
    for ch in str(s):
        columns += 2 if east_asian_width(ch) in 'FW' else 1
    return columns
|
|
|
|
|
def disp_len(data):
    """
    Returns the real on-screen length of a string which may contain
    ANSI control codes and wide chars.
    """
    # escape sequences occupy no columns: strip them before measuring
    visible = RE_ANSI.sub('', data)
    return _text_width(visible)
|
|
|
|
|
def disp_trim(data, length):
    """
    Trim a string which may contain ANSI control characters.

    Parameters
    ----------
    data  : str
        Text, possibly containing ANSI escape sequences and wide chars.
    length  : int
        Maximum on-screen width of the result.

    Returns
    -------
    out  : str
        `data` cut down to at most `length` display columns. If `data`
        contained ANSI codes and the trimmed text still does, a reset code
        is appended (unless already present) so styling doesn't leak.
    """
    if len(data) == disp_len(data):
        # plain text (no escapes, no wide chars): simple slicing suffices
        return data[:length]

    ansi_present = bool(RE_ANSI.search(data))
    # Drop one character at a time until it fits. The `data` check guards
    # against looping forever when `length` is negative, since the width of
    # an empty string (0) would otherwise still exceed `length`.
    while data and disp_len(data) > length:
        data = data[:-1]
    if ansi_present and bool(RE_ANSI.search(data)):
        # assume ANSI reset is required
        return data if data.endswith("\033[0m") else data + "\033[0m"
    return data
|
|