"""Gradio utilities.

Note that the optional `progress` parameter can be either the `tqdm` module or
a `gr.Progress` instance (any object with a `.tqdm` attribute works).
"""

import concurrent.futures
import contextlib
import glob
import hashlib
import logging
import os
import tempfile
import time
import urllib.request

import jax
import numpy as np
from tensorflow.io import gfile


@contextlib.contextmanager
def timed(name):
  """Logs how long the managed block took to run.

  Yields a dict that receives the elapsed wall-clock time (from
  `time.monotonic()`) under the key `'secs'` once the block exits, including
  when it exits via an exception.

  Args:
    name: Label used in the log message.
  """
  start = time.monotonic()
  result = {'dt': None}
  try:
    yield result
  finally:
    elapsed = time.monotonic() - start
    result['secs'] = elapsed
    logging.info('Timed %s: %.1f secs', name, elapsed)


def copy_file(
    src, dst, *, progress=None, block_size=1024 * 1024 * 10, overwrite=False
):
  """Copies a file with progress bar.

  The data is first written to `f'{dst}-PARTIAL'` and then renamed to `dst`, so
  an interrupted copy never leaves a truncated `dst` behind. The partial file is
  removed if the copy fails.

  Args:
    src: Source file (readable by `tf.io.gfile`) or http(s) URL.
    dst: Destination file. Path must be writable by `tf.io.gfile`.
    progress: An object with a `.tqdm` attribute, or `None`.
    block_size: Size of individual blocks to be read/written.
    overwrite: If `True`, overwrite `dst` if it exists; otherwise return
      without copying when `dst` already exists.
  """
  if os.path.dirname(dst):
    os.makedirs(os.path.dirname(dst), exist_ok=True)
  if os.path.exists(dst) and not overwrite:
    return

  if src.startswith(('http://', 'https://')):
    opener = urllib.request.urlopen
    request = urllib.request.Request(src, method='HEAD')
    # Only the headers are needed; close the HEAD response right away.
    with urllib.request.urlopen(request) as response:
      content_length = response.headers.get('Content-Length')
    logging.info('content_length %s', content_length)
    # Servers may omit Content-Length; the drain loop below still copies all
    # data, only the progress bar is lost.
    length = int(content_length) if content_length is not None else 0
  else:
    opener = lambda path: gfile.GFile(path, 'rb')
    length = gfile.stat(src).length
  n = -(-length // block_size)  # Integer ceil division (no float rounding).

  if progress is None:
    range_or_trange = range
  else:
    range_or_trange = lambda n: progress.tqdm(list(range(n)), desc='download')

  partial = f'{dst}-PARTIAL'
  try:
    with opener(src) as fin:
      with gfile.GFile(partial, 'wb') as fout:
        for _ in range_or_trange(n):
          fout.write(fin.read(block_size))
        # Copy anything beyond the announced length so the result is never
        # silently truncated (e.g. missing or wrong size information).
        while True:
          block = fin.read(block_size)
          if not block:
            break
          fout.write(block)
  except BaseException:
    # Don't leave a stale partial file behind; re-raise the original error.
    if gfile.exists(partial):
      gfile.remove(partial)
    raise
  # `dst` may already exist (overwrite=True, or a concurrent writer);
  # `gfile.rename` refuses to replace an existing file unless told to.
  gfile.rename(partial, dst, overwrite=True)


# History of (estimated_secs, real_secs) pairs appended by `get_memory_cache()`
# and used there to scale new estimates by sum(real) / sum(estimated). Seeded
# with a single (10, 10) pair, so the initial scale factor is 1.
_estimated_real = [(10, 10)]
# key -> cached value. Dict insertion order doubles as "last accessed" order
# (oldest first): `get_memory_cache()` re-inserts keys on access and evicts
# from the front.
_memory_cache = {}


def get_with_progress(getter, secs, progress, step=0.1):
  """Returns result from `getter` while showing a progress bar.

  `getter` runs on a worker thread while a bar with `ceil(secs / step)` ticks
  advances about every `step` seconds. The bar only paces the wait. Once
  `getter` has finished, the remaining ticks complete without sleeping. If
  `getter` is slower than `secs`, the call keeps blocking after the bar is full.

  Args:
    getter: Zero-argument callable whose result is returned.
    secs: Estimated duration of `getter()`, used to size the progress bar.
    progress: An object with a `.tqdm` attribute, or `None` to skip the
      progress bar and simply call `getter`.
    step: Seconds per progress-bar tick.

  Returns:
    The value returned by `getter()`; exceptions raised by it propagate.
  """
  if progress is None:
    # `get_memory_cache()` defaults `progress` to None; don't crash on it.
    return getter()
  num_steps = int(np.ceil(secs / step))
  with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
    future = executor.submit(getter)
    for _ in progress.tqdm(list(range(num_steps)), desc='read'):
      if not future.done():
        time.sleep(step)
  return future.result()


def _get_array_sizes(tree):
  return [getattr(x, 'nbytes', 0) for x in jax.tree_leaves(tree)]


def get_memory_cache(
    key, getter, max_cache_size_bytes, progress=None, estimated_secs=None
):
  """Keeps cache below specified size by removing elements not last accessed.

  On a hit, the cached value for `key` is returned and marked most recently
  accessed. On a miss, `getter()` is called (with a progress bar sized by a
  calibrated time estimate), the result is cached, and the least recently
  accessed entries are evicted until the total size is at most
  `max_cache_size_bytes`. The entry for `key` itself is never evicted.

  Args:
    key: Hashable cache key.
    getter: Zero-argument callable producing the value on a cache miss.
    max_cache_size_bytes: Size budget, measured as the sum of `nbytes` over all
      leaves of all cached values.
    progress: Forwarded to `get_with_progress()`.
    estimated_secs: Uncalibrated estimate of how long `getter()` takes.
      Defaults to the mean of previously recorded estimates.

  Returns:
    The cached or freshly loaded value.
  """
  if key in _memory_cache:
    _memory_cache[key] = _memory_cache.pop(key)  # updated "last accessed" order
    return _memory_cache[key]

  est, real = zip(*_estimated_real)
  if estimated_secs is None:
    estimated_secs = sum(est) / len(est)
  # Scale the raw estimate by the observed real/estimated ratio so far.
  calibrated_secs = estimated_secs * sum(real) / sum(est)
  with timed(f'loading {key}') as timing:
    _memory_cache[key] = get_with_progress(getter, calibrated_secs, progress)
  # Record the *raw* estimate. The ratio above compares raw estimates with real
  # durations; storing the already-calibrated value would make the correction
  # converge to the square root of the true ratio.
  _estimated_real.append((estimated_secs, timing['secs']))

  sz = sum(_get_array_sizes(list(_memory_cache.values())))
  logging.info('New memory cache size=%.1f MB', sz/1e6)

  while sz > max_cache_size_bytes:
    k, v = next(iter(_memory_cache.items()))  # Least recently accessed.
    if k == key:
      break
    s = sum(_get_array_sizes(v))
    logging.info('Removing %s from memory cache (%.1f MB)', k, s/1e6)
    _memory_cache.pop(k)
    sz -= s

  return _memory_cache[key]


def get_memory_cache_info():
  """Returns number of items and total size in bytes.

  The size is the sum of `nbytes` over all leaves of all cached values.
  """
  # Flatten only the values, as `get_memory_cache()` does. Flattening the dict
  # itself makes JAX sort its keys, which fails for mutually incomparable keys.
  sizes = _get_array_sizes(list(_memory_cache.values()))
  return len(_memory_cache), sum(sizes)


# Root directory for files cached by `get_disk_cache()`; each cached file lives
# in its own `<md5 of source>__<basename>` subdirectory.
CACHE_DIR = os.path.join(tempfile.gettempdir(), 'downloads_cache')


def get_disk_cache(path_or_url, max_cache_size_bytes, progress=None):
  """Keeps cache below specified size by removing elements not last accessed.

  The file is cached at `CACHE_DIR/<md5(path_or_url)>__<basename>/<basename>`.
  After a new file is copied, the files with the oldest access time are
  deleted until the total cache size is at most `max_cache_size_bytes`. The
  newly copied file itself is never deleted.

  Args:
    path_or_url: Source file (readable by `tf.io.gfile`) or http(s) URL.
    max_cache_size_bytes: Size budget for all files in the cache.
    progress: Forwarded to `copy_file()`.

  Returns:
    Local path of the cached copy.
  """
  fname = os.path.basename(path_or_url)
  path_hash = hashlib.md5(path_or_url.encode()).hexdigest() + '__' + fname
  dst = os.path.join(CACHE_DIR, path_hash, fname)
  if os.path.exists(dst):
    # Eviction is ordered by access time, but with `relatime`/`noatime` mounts
    # reading the file may not update it. Refresh atime explicitly (keeping
    # mtime). This is best-effort only.
    with contextlib.suppress(OSError):
      os.utime(dst, (time.time(), os.path.getmtime(dst)))
    return dst

  os.makedirs(os.path.dirname(dst), exist_ok=True)
  with timed(f'copying {path_or_url}'):
    copy_file(path_or_url, dst, progress=progress)

  entries = sorted(
      (os.path.getatime(p), os.path.getsize(p), p)
      for p in glob.glob(os.path.join(CACHE_DIR, '*', '*'))
      if os.path.isfile(p)
  )
  total = sum(size for _, size, _ in entries)
  logging.info('New disk cache size=%.1f MB', total/1e6)

  for _, size, path in entries:  # Oldest access time first.
    if total <= max_cache_size_bytes or path == dst:
      break
    logging.info('Removing %s from disk cache (%.1f MB)', path, size/1e6)
    os.unlink(path)  # Full cache path, not the bare basename.
    # Every entry has its own directory; drop it once it is empty.
    with contextlib.suppress(OSError):
      os.rmdir(os.path.dirname(path))
    total -= size

  return dst


def get_disk_cache_info():
  """Returns number of files and their total size in bytes.

  Counts only regular files inside the per-entry subdirectories of
  `CACHE_DIR`, which is the same set `get_disk_cache()` accounts for.
  """
  sizes = [
      os.path.getsize(p)
      for p in glob.glob(os.path.join(CACHE_DIR, '*', '*'))
      if os.path.isfile(p)
  ]
  return len(sizes), sum(sizes)