|
"""Gradio helpers for caching, downloading etc.""" |
|
|
|
import concurrent.futures |
|
import contextlib |
|
import datetime |
|
import functools |
|
import logging |
|
import os |
|
import shutil |
|
import threading |
|
import time |
|
|
|
import huggingface_hub |
|
import numpy as np |
|
import psutil |
|
|
|
|
|
def should_mock():
    """Tells whether model loading should be faked.

    Returns:
      `True` only when the `MOCK_MODEL` environment variable is exactly
      `'yes'`; unset or any other value gives `False`.
    """
    return os.getenv('MOCK_MODEL', '') == 'yes'
|
|
|
|
|
@contextlib.contextmanager
def timed(name, start_message=False):
    """Times the enclosed block and logs `Timed {name}: {secs:.1f} secs` (INFO).

    Args:
      name: Label used in the log message(s).
      start_message: If true, additionally logs `Timing {name}...` before the
        block runs.

    Yields:
      A dict whose `'secs'` entry is `None` while the block runs and is set to
      the elapsed seconds (measured with `time.monotonic()`) once the block
      exits -- also when it exits by raising.
    """
    t0 = time.monotonic()
    # Seed exactly the key that is filled in on exit, so callers can always
    # index `['secs']` without a KeyError.
    timing = {'secs': None}
    try:
        if start_message:
            logging.info('Timing %s...', name)
        yield timing
    finally:
        timing['secs'] = time.monotonic() - t0
        logging.info('Timed %s: %.1f secs', name, timing['secs'])
|
|
|
|
|
def synced(f):
    """Decorator that serializes every call to `f` through one shared lock.

    Each call logs, at INFO level, how long it waited to acquire the lock.
    """
    guard = threading.Lock()

    @functools.wraps(f)
    def serialized(*args, **kw):
        requested_at = time.monotonic()
        with guard:
            waited = time.monotonic() - requested_at
            logging.info('synced wait: %.1f secs', waited)
            return f(*args, **kw)

    return serialized
|
|
|
|
|
# Names whose warmup call completed without raising; filled in by the warmup
# step of `_do_download()` and counted by `downloads_status()`.
_warmed_up = set()
# Optional callable invoked as `_warmup_function(name)` after a download;
# assigned through `set_warmup_function()`. A falsy value disables warmups.
_warmup_function = None
|
|
|
|
|
def set_warmup_function(warmup_function):
    """Registers the callable run as `warmup_function(name)` after downloads.

    Any previously registered function is replaced. Passing a falsy value
    (e.g. `None`) turns warmups off for subsequently finished downloads.
    """
    global _warmup_function  # Module-level setting read by the downloader.
    _warmup_function = warmup_function
|
|
|
|
|
# Guards insertions into and removals from `_scheduled`.
_lock = threading.Lock()
# Pending downloads: name -> (repo, filenames, revision); the downloader
# thread always takes the first (oldest inserted) entry.
_scheduled = {}
# Accumulated seconds spent downloading (failed downloads are not added).
_download_secs = 0
# Accumulated seconds spent in warmup calls.
_warmup_secs = 0
# NOTE(review): not referenced anywhere in the visible code.
_loading_secs = 0
# Finished downloads: name -> stored download result (`None` when mocked).
_done = {}
# Failed downloads: name -> string form of the raised exception.
_failed = {}
|
|
|
|
|
def _do_download():
    """Background loop that processes downloads queued by `register_download()`.

    Runs forever and is meant to be the target of a daemon thread. Each
    iteration takes the oldest entry of `_scheduled` and either sleeps 10 s
    (when `should_mock()`) or downloads every file from the Hugging Face Hub.
    The list of local paths is stored in `_done[name]`; on error, the message
    goes to `_failed[name]` instead. After a successful download, the
    registered warmup function (if any) is submitted to a single-worker
    executor. Finally, the entry is removed from `_scheduled`.
    """
    global _download_secs
    # One worker: warmups run one at a time, in download order, and never
    # block the download loop itself.
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

    def warmup(name):
        global _warmup_secs
        with timed(f'warming up {name}', True) as t:
            try:
                _warmup_function(name)
                _warmed_up.add(name)
            except Exception:  # Best effort: a bad warmup must not kill the loop.
                logging.exception('Could not warmup "%s"!', name)
        _warmup_secs += t['secs']

    while True:
        # Peek under the lock: `register_download()` inserts concurrently and
        # a mutation would invalidate a live dict iterator.
        with _lock:
            item = next(iter(_scheduled.items()), None)
        if item is None:
            time.sleep(1)
            continue

        name, (repo, filenames, revision) = item
        logging.info('Downloading "%s" %s/%s/%s...', name, repo, filenames, revision)
        with timed(f'downloading {name}', True) as t:
            if should_mock():
                logging.warning('Mocking loading')
                time.sleep(10.)
                _done[name] = None
            else:
                try:
                    # A list, not a lazy generator: the downloads must run here,
                    # inside the `try`, and consumers need reusable paths.
                    _done[name] = [
                        huggingface_hub.hf_hub_download(
                            repo_id=repo, filename=filename, revision=revision)
                        for filename in filenames
                    ]
                except Exception as e:
                    logging.exception('Could not download "%s" from hub!', name)
                    _failed[name] = str(e)
                    with _lock:
                        _scheduled.pop(name)
                    continue

        if _warmup_function:
            executor.submit(warmup, name)

        _download_secs += t['secs']
        with _lock:
            _scheduled.pop(name)
|
|
|
|
|
def register_download(name, repo, filenames, revision='main'):
    """Queues `filenames` from HF `repo` for download by the background thread.

    Args:
      name: Key under which the result later appears in `get_paths()`.
        Registering a name that is still pending is a no-op.
      repo: Passed as `repo_id` to `huggingface_hub.hf_hub_download`.
      filenames: Iterable of file names inside `repo`. A single string is also
        accepted and treated as a one-element list.
      revision: Passed as `revision` to `huggingface_hub.hf_hub_download`.
    """
    if isinstance(filenames, str):
        # Iterating a bare string would request each character as a file.
        filenames = [filenames]
    with _lock:
        if name not in _scheduled:
            _scheduled[name] = (repo, filenames, revision)
|
|
|
|
|
def _hms(secs): |
|
"""Formats `secs=3700` to `"01:01:40"`.""" |
|
secs = int(secs) |
|
h = secs // 3600 |
|
m = (secs - h * 3600) // 60 |
|
s = secs % 60 |
|
return (f'{h}:' if h else '') + f'{m:02}:{s:02}' |
|
|
|
|
|
def downloads_status():
    """Builds a one-line, comma-separated summary of the download queue.

    Reports finished downloads and their total time, and pending downloads
    with a remaining-time estimate extrapolated from the mean time per
    finished download. When a warmup function is registered, it also reports
    warmups. Failures are reported when present.
    """
    done_suffix = ''
    remaining_suffix = ''
    if _done:
        done_suffix = f' in {_hms(_download_secs)}'
        per_download = _download_secs / len(_done)
        remaining_suffix = f' {_hms(per_download * len(_scheduled))}'

    pieces = [f'Downloaded {len(_done)}{done_suffix}']
    if _scheduled:
        pieces.append(f'{len(_scheduled)}{remaining_suffix} remaining')
    if _warmup_function:
        pieces.append(f'warmed up {len(_warmed_up)} in {_hms(_warmup_secs)}')
    if _failed:
        pieces.append(f'{len(_failed)} failed')
    return ', '.join(pieces)
|
|
|
|
|
def get_paths():
    """Returns a snapshot dict mapping each finished download `name` to its result.

    The value is whatever the downloader stored for that name (the downloaded
    paths, or `None` for mocked downloads). Pending and failed names are
    absent. The returned dict is a copy and is safe to mutate.
    """
    return _done.copy()
|
|
|
|
|
# Start the background downloader at import time. As a daemon thread it does
# not prevent the interpreter from exiting.
_download_thread = threading.Thread(target=_do_download, daemon=True)
_download_thread.start()
|
|
|
|
|
# (estimated_secs, actual_secs) per memory-cache load. Seeded with one (10, 10)
# pair so the averages/ratios in `get_memory_cache()` are defined before the
# first load; `get_status()` skips this seed when summing load time.
_estimated_real = [(10, 10)]
# In-memory cache: key -> value, ordered from least to most recently used
# (hits are re-inserted at the end, eviction pops from the front).
_memory_cache = {}
|
|
|
|
|
def get_with_progress(getter, secs, progress, step=0.1):
    """Calls `getter()` and returns its result, animating `progress` meanwhile.

    Args:
      getter: Zero-argument callable producing the value; it runs on a worker
        thread when `progress` is given.
      secs: Estimated run time of `getter()`; the bar gets
        `ceil(secs / step)` ticks.
      progress: Object with a `tqdm(iterable, desc=...)` method (presumably a
        Gradio progress tracker), or `None` to simply call `getter()`.
      step: Seconds slept per tick while `getter()` is still running. Once it
        has finished, the remaining ticks are consumed without sleeping.

    Returns:
      Whatever `getter()` returns. Exceptions raised by `getter()` propagate.
    """
    if progress is None:
        return getter()
    with concurrent.futures.ThreadPoolExecutor() as pool:
        pending = pool.submit(getter)
        ticks = list(range(int(np.ceil(secs / step))))
        for _ in progress.tqdm(ticks, desc='read'):
            if pending.done():
                continue
            time.sleep(step)
        return pending.result()
|
|
|
|
|
def _get_array_sizes(tree): |
|
return [getattr(x, 'nbytes', 0) for x in jax.tree_leaves(tree)] |
|
|
|
|
|
def get_memory_cache(
    key, getter, max_cache_size_bytes, progress=None, estimated_secs=None
):
    """Returns the value for `key`, loading it with `getter()` on a cache miss.

    - **Hit:** `key` is moved to the most-recently-used end of `_memory_cache`.
    - **Miss:** `getter()` runs through `get_with_progress()`. The duration
      estimate is `estimated_secs`, or the mean past estimate when it is
      `None`, rescaled by the ratio of summed actual to summed estimated load
      times. The (estimate, actual) pair is then appended to `_estimated_real`.
    - **Falsy `max_cache_size_bytes`:** the value is returned without being
      cached.
    - **Otherwise:** the value is inserted, and least-recently-used entries are
      evicted until the summed leaf `nbytes` fits the budget. The freshly
      inserted entry is never evicted.
    """
    if key in _memory_cache:
        # Re-inserting moves the key to the end of the dict's insertion order.
        hit = _memory_cache.pop(key)
        _memory_cache[key] = hit
        return hit

    estimates, actuals = zip(*_estimated_real)
    if estimated_secs is None:
        estimated_secs = sum(estimates) / len(estimates)
    with timed(f'loading {key}') as timing:
        estimated_secs *= sum(actuals) / sum(estimates)
        value = get_with_progress(getter, estimated_secs, progress)
    _estimated_real.append((estimated_secs, timing['secs']))

    if not max_cache_size_bytes:
        return value

    _memory_cache[key] = value
    total = sum(_get_array_sizes(list(_memory_cache.values())))
    logging.info('New memory cache size=%.1f MB', total/1e6)

    # Walk entries oldest-first; `key` was inserted last, so it stops the loop.
    for old_key in list(_memory_cache):
        if total <= max_cache_size_bytes or old_key == key:
            break
        freed = sum(_get_array_sizes(_memory_cache[old_key]))
        logging.info('Removing %s from memory cache (%.1f MB)', old_key, freed/1e6)
        del _memory_cache[old_key]
        total -= freed

    return value
|
|
|
|
|
def get_memory_cache_info():
    """Returns `(entry_count, total_nbytes)` describing `_memory_cache`."""
    total_bytes = sum(_get_array_sizes(_memory_cache))
    return len(_memory_cache), total_bytes
|
|
|
|
|
def get_system_info():
    """Summarizes RAM usage and disk usage of the filesystem holding `.`.

    Both totals (but not the used amounts) are divided by the integer
    `HOST_COLOCATION` environment variable (default 1). That variable is
    presumably the number of jobs sharing this host -- TODO confirm.
    """
    share = int(os.environ.get('HOST_COLOCATION', '1'))
    mem = psutil.virtual_memory()
    disk = shutil.disk_usage('.')
    giga = 1e9
    ram_part = f'RAM {mem.used / giga:.1f}/{mem.total / share / giga:.1f}G, '
    disk_part = f'disk {disk.used / giga:.1f}/{disk.total / share / giga:.1f}G'
    return ram_part + disk_part
|
|
|
|
|
def get_status(include_system_info=True):
    """Returns a one-line status with timestamp, model stats and system info.

    Args:
      include_system_info: Whether to append the `get_system_info()` summary.
    """
    n_cached, cached_bytes = get_memory_cache_info()
    # Index 0 of `_estimated_real` is the seed pair, not an actual load.
    load_time = _hms(sum(actual for _, actual in _estimated_real[1:]))
    stamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    pieces = [
        'Timestamp: ',
        stamp,
        ' – Model stats: ',
        downloads_status(),
        ', ',
        f'memory-cached {n_cached} ({cached_bytes/1e9:.1f}G) in {load_time}',
    ]
    if include_system_info:
        pieces.append(' – System: ' + get_system_info())
    return ''.join(pieces)
|
|