# python3.7
"""Misc utility functions."""

import os
import hashlib

from torch.hub import download_url_to_file

# Public names exported by `from <this module> import *`.
__all__ = [
    'REPO_NAME',
    'Infix',
    'print_and_execute',
    'check_file_ext',
    'IMAGE_EXTENSIONS',
    'VIDEO_EXTENSIONS',
    'MEDIA_EXTENSIONS',
    'parse_file_format',
    'set_cache_dir',
    'get_cache_dir',
    'download_url',
]

# Name of the repository (project); `get_cache_dir()` also uses it as the name
# of the default cache folder under `~/.cache/`.
REPO_NAME = 'Hammer'


class Infix(object):
    """Helper class to create custom infix operators.

    When using it, make sure to put the operator between `<<` and `>>`.
    `<< INFIX_OP_NAME >>` should be considered as a whole operator.

    Since `<<` and `>>` share the same precedence and are evaluated from left
    to right, chained operations are applied from left to right as well.

    NOTE: `a << op` only reaches this class when `a` itself does not handle
    `<<` with an `Infix` operand (i.e., `a.__lshift__` is undefined or returns
    `NotImplemented`), which holds for e.g. `int`, `float`, and `str`. The left
    operand is stored on the instance until the matching `>>`, so one instance
    should not be shared by concurrently running threads.

    Examples:

    # Use `Infix` to create infix operators directly.
    add = Infix(lambda a, b: a + b)
    1 << add >> 2  # gives 3
    1 << add >> 2 << add >> 3  # gives 6

    # Use `Infix` as a decorator.
    @Infix
    def mul(a, b):
        return a * b
    2 << mul >> 4  # gives 8
    2 << mul >> 3 << mul >> 7  # gives 42
    """

    def __init__(self, function):
        self.function = function  # Binary callable: `function(left, right)`.
        self.left_value = None  # Left operand captured by `<<`, if any.

    def __rlshift__(self, left_value):  # override `<<` before `Infix` instance
        # Make sure `<<` is not applied twice without a `>>` in between.
        assert self.left_value is None, (
            'Left operand already set; `<<` must be followed by `>>`.')
        self.left_value = left_value
        return self

    def __rshift__(self, right_value):  # override `>>` after `Infix` instance
        left_value = self.left_value
        # Reset BEFORE calling `function`: if it raises, a stale left operand
        # would otherwise remain and make every later `<<` fail its assertion.
        self.left_value = None
        return self.function(left_value, right_value)


def print_and_execute(cmd):
    """Prints and executes a system command.

    The command is first echoed to stdout and then run with `os.system()`,
    i.e., through the system shell, so shell syntax (pipes, redirection, `&&`)
    is available. Nothing is escaped: only pass trusted command strings.

    Args:
        cmd: Command to be executed, as a single string.

    Returns:
        The exit status reported by `os.system()` (on POSIX this is the encoded
        wait status); `0` means the command succeeded.
    """
    # Flush so the echoed command is written before the child process's own
    # output, even when stdout is block-buffered (redirected to a pipe/file).
    print(cmd, flush=True)
    return os.system(cmd)


def check_file_ext(filename, *ext_list):
    """Tells whether `filename` carries one of the given extension(s).

    Each extension may be given with or without its leading dot (`'jpg'` and
    `'.jpg'` are equivalent), and the comparison ignores case. Only the last
    suffix of the base name is considered, so `'a.tar.gz'` matches `'gz'` but
    not `'tar.gz'`.

    NOTE: If `ext_list` is empty, this function will always return `False`.

    Args:
        filename: Filename (or path) to check.
        *ext_list: Candidate extensions, passed as separate arguments.

    Returns:
        `True` if the extension of `filename` is one of `ext_list`, otherwise
        `False`.
    """
    if not ext_list:
        return False
    wanted = set()
    for candidate in ext_list:
        dotted = candidate if candidate.startswith('.') else '.' + candidate
        wanted.add(dotted.lower())
    _, suffix = os.path.splitext(os.path.basename(filename))
    return suffix.lower() in wanted


# File extensions regarding images (GIFs are intentionally left out).
# All entries are lower-case and start with a dot.
IMAGE_EXTENSIONS = (
    '.bmp',
    '.ppm',
    '.pgm',
    '.jpeg',
    '.jpg',
    '.jpe',
    '.jp2',
    '.png',
    '.webp',
    '.tiff',
    '.tif',
)
# File extensions regarding videos (lower-case, dot-prefixed).
VIDEO_EXTENSIONS = (
    '.avi',
    '.mkv',
    '.mp4',
    '.m4v',
    '.mov',
    '.webm',
    '.flv',
    '.rmvb',
    '.rm',
    '.3gp',
)
# File extensions regarding media: GIFs, then images, then videos.
MEDIA_EXTENSIONS = ('.gif',) + IMAGE_EXTENSIONS + VIDEO_EXTENSIONS


def parse_file_format(path):
    """Parses the file format of a given path.

    This function basically parses the file format according to its extension.
    It will also return `dir` if the given path is a directory.

    Parsable file formats:

    - zip: with `.zip` extension.
    - tar: with `.tar` / `.tgz` / `.tar.gz` extension.
    - lmdb: a folder ending with `lmdb`.
    - txt: with `.txt` / `.text` extension, OR an existing file without
        extension (e.g. LICENSE).
    - json: with `.json` extension.
    - jpg: with `.jpeg` / `.jpg` / `.jpe` extension.
    - png: with `.png` extension.

    Extension matching is case-insensitive. A path is treated as a directory if
    it exists as one, or if it ends with `/` (even when it does not exist).

    Args:
        path: The path to the file to parse format from, either as a string or
            as an `os.PathLike` object (e.g. `pathlib.Path`).

    Returns:
        A lower-case string, indicating the file format, or `None` if the format
            cannot be successfully parsed.
    """
    # Accept `pathlib.Path` and other path-like objects, not only strings.
    path = os.fspath(path)
    # Handle directory.
    if os.path.isdir(path) or path.endswith('/'):
        if path.rstrip('/').lower().endswith('lmdb'):
            return 'lmdb'
        return 'dir'
    # Handle file. An existing file without any extension is assumed to be text.
    if os.path.isfile(path) and os.path.splitext(path)[1] == '':
        return 'txt'
    path = path.lower()
    # `os.path.splitext()` only sees the last suffix (`.gz`), so the compound
    # extension has to be checked explicitly.
    if path.endswith('.tar.gz'):
        return 'tar'
    ext_to_format = {
        '.zip': 'zip',
        '.tar': 'tar',
        '.tgz': 'tar',
        '.txt': 'txt',
        '.text': 'txt',
        '.json': 'json',
        '.jpeg': 'jpg',
        '.jpg': 'jpg',
        '.jpe': 'jpg',
        '.png': 'png',
    }
    # Unparsable extensions yield `None`.
    return ext_to_format.get(os.path.splitext(path)[1])


# Cache directory configured through `set_cache_dir()`; `None` means "use the
# default location" (resolved lazily by `get_cache_dir()`).
_cache_dir = None


def set_cache_dir(directory=None):
    """Sets the global cache directory.

    The cache directory can be used to save some files that will be shared
    across jobs. The default cache directory is set as `~/.cache/${REPO_NAME}/`.
    This function can be used to redirect the cache directory. Or, users can use
    `None` to reset the cache directory back to default.

    Args:
        directory: The target directory used to cache files, as a string or an
            `os.PathLike` object (stored as a string). If set as `None`, the
            cache directory will be reset back to default. (default: None)

    Raises:
        TypeError: If `directory` is neither `None`, a string, nor a path-like
            object resolving to a string.
    """
    global _cache_dir  # pylint: disable=global-statement
    if isinstance(directory, os.PathLike):
        directory = os.fspath(directory)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if directory is not None and not isinstance(directory, str):
        raise TypeError(f'Invalid directory! Expected a string, a path-like '
                        f'object or `None`, but got `{type(directory)}`.')
    _cache_dir = directory


def get_cache_dir():
    """Gets the global cache directory.

    Unless redirected with `set_cache_dir()`, this is `~/.cache/${REPO_NAME}/`,
    where `~` is expanded to the current user's home directory.

    Returns:
        A string, representing the global cache directory.
    """
    if _cache_dir is not None:
        return _cache_dir
    default_root = os.path.join(os.path.expanduser('~'), '.cache')
    return os.path.join(default_root, REPO_NAME)


def download_url(url, path=None, filename=None, sha256=None):
    """Downloads file from URL.

    This function downloads a file from given URL, and executes Hash check if
    needed. If the target file already exists, the download is skipped and only
    the (optional) hash check is executed on the existing file.

    Args:
        url: The URL to download file from.
        path: Path (directory) to save the downloaded file. If set as `None`,
            the cache directory will be used. Please see `get_cache_dir()` for
            more details. (default: None)
        filename: The name to save the file. If set as `None`, this name will be
            automatically parsed from the given URL, i.e., its last `/`-separated
            component. (default: None)
        sha256: The expected sha256 hex digest (case-insensitive) of the
            downloaded file. If set as `None`, the hash check will be skipped.
            Otherwise, this function will check whether the sha256 of the
            downloaded file matches this field.

    Returns:
        A two-element tuple, where the first term is the full path of the
            downloaded file, and the second term indicate the hash check result.
            `True` means hash check passes, `False` means hash check fails,
            while `None` means no hash check is executed.

    Raises:
        ValueError: If the filename is empty, e.g., it is not given and `url`
            ends with `/`, so no file name can be parsed from it.
    """
    # Handle file path.
    if path is None:
        path = get_cache_dir()
    if filename is None:
        filename = os.path.basename(url)
    if not filename:
        # Otherwise `save_path` would point at the directory itself.
        raise ValueError(f'Cannot parse a filename from URL `{url}`! '
                         f'Please specify `filename` explicitly.')
    save_path = os.path.join(path, filename)
    # Download file if needed.
    if not os.path.exists(save_path):
        print(f'Downloading URL `{url}` to path `{save_path}` ...')
        if path:  # `os.makedirs('')` fails; an empty path means the cwd.
            os.makedirs(path, exist_ok=True)
        download_url_to_file(url, save_path, hash_prefix=None, progress=True)
    # Check hash if needed.
    check_result = None
    if sha256 is not None:
        hasher = hashlib.sha256()
        # Hash in 1 MiB chunks so large files (e.g., model weights) are never
        # loaded into memory all at once.
        with open(save_path, 'rb') as f:
            for chunk in iter(lambda: f.read(1 << 20), b''):
                hasher.update(chunk)
        # `hexdigest()` is lower-case; accept upper-case expectations as well.
        check_result = (hasher.hexdigest() == sha256.lower())

    return save_path, check_result