"""Gradio helpers for caching, downloading etc."""

import concurrent.futures
import contextlib
import datetime
import functools
import logging
import os
import shutil
import threading
import time

import huggingface_hub
import jax
import numpy as np
import psutil


def should_mock():
  """Tells whether model loading should be mocked (`MOCK_MODEL=yes` in env)."""
  mock_flag = os.getenv('MOCK_MODEL')
  return mock_flag == 'yes'


@contextlib.contextmanager
def timed(name, start_message=False):
  """Context manager logging how long its body took ("Timed {name}: ...").

  Args:
    name: Label used in the log messages.
    start_message: If true, also logs "Timing {name}..." before the body runs.

  Yields:
    A dict that receives the elapsed seconds under key `'secs'` once the body
    has finished, whether it completed normally or raised.
  """
  started = time.monotonic()
  timing = {'dt': None}
  try:
    if start_message:
      logging.info('Timing %s...', name)
    yield timing
  finally:
    elapsed = time.monotonic() - started
    timing['secs'] = elapsed
    logging.info('Timed %s: %.1f secs', name, elapsed)


def synced(f):
  """Decorator serializing all calls to `f` through one `threading.Lock()`.

  Each call logs at INFO level how long it waited to acquire the lock.
  """
  mutex = threading.Lock()

  @functools.wraps(f)
  def locked_call(*args, **kw):
    requested = time.monotonic()
    with mutex:
      waited = time.monotonic() - requested
      logging.info('synced wait: %.1f secs', waited)
      result = f(*args, **kw)
    return result

  return locked_call


# Names for which the warmup function returned without raising.
_warmed_up = set()
# Optional callable invoked as `_warmup_function(name)` by the download thread;
# installed through `set_warmup_function()`.
_warmup_function = None


def set_warmup_function(warmup_function):
  """Sets the callable run as `warmup_function(name)` after each download."""
  global _warmup_function
  _warmup_function = warmup_function


# Held while inserting into / popping from `_scheduled`.
_lock = threading.Lock()
# Pending downloads: name -> (repo, filename, revision).
_scheduled = {}
# Accumulated seconds spent in downloads that did not fail.
_download_secs = 0
# Accumulated seconds spent in warmup calls.
_warmup_secs = 0
# NOTE(review): not referenced anywhere else in this file - possibly unused.
_loading_secs = 0
# Finished downloads: name -> local path (`None` when mocked).
_done = {}
# Failed downloads: name -> error message.
_failed = {}


def _do_download():
  """Downloads scheduled files one by one; meant to run in a background thread.

  Loops forever: takes the first entry of `_scheduled`, fetches it from the HF
  hub (or sleeps 10 secs when `should_mock()`), stores the local path in `_done`
  (or the error message in `_failed`), optionally queues a warmup, and finally
  removes the entry from `_scheduled`. An entry stays in `_scheduled` while it
  is being downloaded.
  """
  global _download_secs
  # Single worker: warmups run one at a time, off the download thread.
  executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

  def warmup(name):
    """Runs the registered warmup function for `name`, logging any failure."""
    global _warmup_secs
    with timed(f'warming up {name}', True) as t:
      try:
        _warmup_function(name)
        _warmed_up.add(name)
      except Exception:  # pylint: disable=broad-exception-caught
        logging.exception('Could not warmup "%s"!', name)
    _warmup_secs += t['secs']

  while True:
    # Peek under the lock: `register_download()` inserts concurrently, and a
    # dict changing size between `iter()` and `next()` raises RuntimeError,
    # which would terminate this thread and stall all further downloads.
    with _lock:
      item = next(iter(_scheduled.items()), None)
    if item is None:
      time.sleep(1)
      continue

    name, (repo, filename, revision) = item
    logging.info('Downloading "%s" %s/%s/%s...', name, repo, filename, revision)
    with timed(f'downloading {name}', True) as t:
      if should_mock():
        logging.warning('Mocking loading')
        time.sleep(10.)
        _done[name] = None
      else:
        try:
          _done[name] = huggingface_hub.hf_hub_download(
              repo_id=repo, filename=filename, revision=revision)
        except Exception as e:  # pylint: disable=broad-exception-caught
          logging.exception('Could not download "%s" from hub!', name)
          _failed[name] = str(e)
          with _lock:
            _scheduled.pop(name)
          continue  # Failed downloads are neither warmed up nor timed.

    if _warmup_function:
      executor.submit(warmup, name)

    _download_secs += t['secs']
    with _lock:
      _scheduled.pop(name)


def register_download(name, repo, filename, revision='main'):
  """Schedules `filename` from HF `repo` for download in the background thread.

  A `name` that is already scheduled keeps its existing entry.
  """
  spec = (repo, filename, revision)
  with _lock:
    _scheduled.setdefault(name, spec)


def _hms(secs):
  """Formats `secs=3700` to `"01:01:40"`."""
  secs = int(secs)
  h = secs // 3600
  m = (secs - h * 3600) // 60
  s = secs % 60
  return (f'{h}:' if h else '') + f'{m:02}:{s:02}'


def downloads_status():
  """Summarizes downloads: done, remaining (with ETA), warmed up, failed."""
  n_done, n_pending = len(_done), len(_scheduled)
  if _done:
    took = f' in {_hms(_download_secs)}'
    # ETA extrapolates the mean time per finished download.
    eta = f' {_hms(_download_secs/n_done*n_pending)}'
  else:
    took = eta = ''

  parts = [f'Downloaded {n_done}{took}']
  if _scheduled:
    parts.append(f'{n_pending}{eta} remaining')
  if _warmup_function:
    parts.append(f'warmed up {len(_warmed_up)} in {_hms(_warmup_secs)}')
  if _failed:
    parts.append(f'{len(_failed)} failed')
  return ', '.join(parts)


def get_paths():
  """Returns a copy of the `name` to `path` mapping of finished downloads."""
  return _done.copy()


# Started at import time; as a daemon thread it does not block interpreter exit.
_download_thread = threading.Thread(target=_do_download, daemon=True)
_download_thread.start()


# History of (estimated_secs, measured_secs) per cache load. Seeded with one
# entry so the averages in `get_memory_cache()` are defined on the first load;
# `get_status()` skips that seed when summing measured load time.
_estimated_real = [(10, 10)]
# key -> loaded value; insertion order doubles as "last accessed" order.
_memory_cache = {}


def get_with_progress(getter, secs, progress, step=0.1):
  """Runs `getter()` in a worker thread while advancing a progress bar.

  The bar has `ceil(secs / step)` ticks; a tick sleeps `step` seconds while
  `getter` is still running and is passed through immediately once it is done.
  With `progress=None` the getter is simply called inline.

  Args:
    getter: Zero-argument callable producing the result.
    secs: Estimated duration of `getter()` in seconds.
    progress: Object with a `tqdm(iterable, desc=...)` method, or `None`.
    step: Seconds per progress tick.

  Returns:
    Whatever `getter()` returns.
  """
  if progress is None:
    return getter()
  with concurrent.futures.ThreadPoolExecutor() as pool:
    pending = pool.submit(getter)
    n_ticks = int(np.ceil(secs / step))
    for _ in progress.tqdm(list(range(n_ticks)), desc='read'):
      if pending.done():
        continue
      time.sleep(step)
  # Leaving the `with` block has already waited for the worker to finish.
  return pending.result()


def _get_array_sizes(tree):
  return [getattr(x, 'nbytes', 0) for x in jax.tree_leaves(tree)]


def get_memory_cache(
    key, getter, max_cache_size_bytes, progress=None, estimated_secs=None
):
  """Returns the value for `key`, loading it with `getter()` on a cache miss.

  The cache is kept below `max_cache_size_bytes` by evicting the least recently
  accessed entries; the entry just loaded is never evicted.

  Args:
    key: Cache key.
    getter: Zero-argument callable that loads the value on a miss.
    max_cache_size_bytes: Size limit; a falsy value disables caching.
    progress: Optional progress object forwarded to `get_with_progress()`.
    estimated_secs: Expected load time; defaults to the mean of past estimates.
      Either way it is rescaled by the past measured/estimated ratio.

  Returns:
    The cached or freshly loaded value.
  """
  try:
    hit = _memory_cache.pop(key)
  except KeyError:
    pass
  else:
    _memory_cache[key] = hit  # Re-insert to mark as most recently accessed.
    return hit

  estimates, measured = zip(*_estimated_real)
  if estimated_secs is None:
    estimated_secs = sum(estimates) / len(estimates)
  with timed(f'loading {key}') as timing:
    estimated_secs *= sum(measured) / sum(estimates)
    value = get_with_progress(getter, estimated_secs, progress)
  _estimated_real.append((estimated_secs, timing['secs']))

  if not max_cache_size_bytes:
    return value

  _memory_cache[key] = value
  total = sum(_get_array_sizes(list(_memory_cache.values())))
  logging.info('New memory cache size=%.1f MB', total/1e6)

  # Evict from the front (= least recently accessed) until small enough.
  while total > max_cache_size_bytes:
    oldest_key, oldest_value = next(iter(_memory_cache.items()))
    if oldest_key == key:
      break
    freed = sum(_get_array_sizes(oldest_value))
    logging.info(
        'Removing %s from memory cache (%.1f MB)', oldest_key, freed/1e6)
    del _memory_cache[oldest_key]
    total -= freed

  return value


def get_memory_cache_info():
  """Returns number of items and total size in bytes."""
  # Measure the values only, as `get_memory_cache()` does: flattening the dict
  # itself would make JAX order the entries by cache key, which requires the
  # keys to be mutually comparable.
  sizes = _get_array_sizes(list(_memory_cache.values()))
  return len(_memory_cache), sum(sizes)


def get_system_info():
  """Returns a "RAM used/total, disk used/total" summary in gigabytes.

  Totals are divided by the `HOST_COLOCATION` environment variable (default 1);
  the used amounts are reported undivided. Disk figures refer to the current
  working directory's filesystem.
  """
  share = int(os.environ.get('HOST_COLOCATION', '1'))
  ram = psutil.virtual_memory()
  disk = shutil.disk_usage('.')
  ram_used, ram_total = ram.used / 1e9, ram.total / share / 1e9
  disk_used, disk_total = disk.used / 1e9, disk.total / share / 1e9
  ram_part = f'RAM {ram_used:.1f}/{ram_total:.1f}G'
  disk_part = f'disk {disk_used:.1f}/{disk_total:.1f}G'
  return f'{ram_part}, {disk_part}'


def get_status(include_system_info=True):
  """Returns one status line: timestamp, download and memory-cache stats.

  Args:
    include_system_info: If true, appends the RAM/disk summary from
      `get_system_info()`.
  """
  n_cached, cached_bytes = get_memory_cache_info()
  # Skip the first `_estimated_real` entry: it is the seed, not a real load.
  load_time = _hms(sum(pair[1] for pair in _estimated_real[1:]))
  now = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
  pieces = [
      f'Timestamp: {now} – Model stats: {downloads_status()}, ',
      f'memory-cached {n_cached} ({cached_bytes/1e9:.1f}G) in {load_time}',
  ]
  if include_system_info:
    pieces.append(f' – System: {get_system_info()}')
  return ''.join(pieces)