from huggingface_hub import HfApi, Repository
import gradio as gr
import json


def change_tab(query_param):
    """Pick which tab to show based on the page's query parameters.

    Single quotes in ``query_param`` are swapped for double quotes and the
    result is parsed as JSON (presumably the caller passes the string form of
    a dict such as ``"{'tab': 'plot'}"`` -- verify against the caller).

    Returns:
        ``gr.Tabs.update(selected=1)`` when the parsed value is a dict whose
        ``"tab"`` entry equals ``"plot"``; ``gr.Tabs.update(selected=0)`` in
        every other case, including when the input cannot be parsed.
    """
    try:
        parsed = json.loads(query_param.replace("'", '"'))
    except (AttributeError, TypeError, ValueError):
        # Missing, non-string or malformed query parameters must not break
        # the page: fall back to the default tab instead of raising.
        # (json.JSONDecodeError is a subclass of ValueError.)
        parsed = None

    wants_plot = isinstance(parsed, dict) and parsed.get("tab") == "plot"
    return gr.Tabs.update(selected=1 if wants_plot else 0)


def restart_space(LLM_PERF_LEADERBOARD_REPO, OPTIMUM_TOKEN):
    """Ask the Hub to restart the Space identified by the given repo id.

    Args:
        LLM_PERF_LEADERBOARD_REPO: Repo id forwarded as ``repo_id``.
        OPTIMUM_TOKEN: Token forwarded as ``token``.

    Returns:
        None; the result of the underlying API call is discarded.
    """
    hub_client = HfApi()
    hub_client.restart_space(
        repo_id=LLM_PERF_LEADERBOARD_REPO,
        token=OPTIMUM_TOKEN,
    )


def load_dataset_repo(LLM_PERF_DATASET_REPO, OPTIMUM_TOKEN):
    """Clone the dataset repo into ``./llm-perf-dataset`` and pull it.

    Args:
        LLM_PERF_DATASET_REPO: Repo id passed to ``Repository`` as
            ``clone_from``.
        OPTIMUM_TOKEN: Access token; when falsy nothing is loaded.

    Returns:
        The ``Repository`` object after ``git_pull()``, or ``None`` when no
        token was supplied.
    """
    # Without a token there is nothing to load; skip the Hub entirely.
    if not OPTIMUM_TOKEN:
        return None

    print("Loading LLM-Perf-Dataset from Hub...")
    dataset_repo = Repository(
        clone_from=LLM_PERF_DATASET_REPO,
        repo_type="dataset",
        local_dir="./llm-perf-dataset",
        token=OPTIMUM_TOKEN,
    )
    dataset_repo.git_pull()
    return dataset_repo


# Maps raw model-type identifiers to the human-readable labels shown for them.
# process_model_type() looks identifiers up here and returns any identifier
# that is not listed unchanged.
LLM_MODEL_TYPES = {
    "gpt_bigcode": "GPT-BigCode 🌸",
    "RefinedWebModel": "Falcon 🦅",
    "RefinedWeb": "Falcon 🦅",
    "baichuan": "Baichuan 🌊",
    "llama": "LLaMA 🦙",
    "gpt_neox": "GPT-NeoX",
    "gpt_neo": "GPT-Neo",
    "codegen": "CodeGen",
    "chatglm": "ChatGLM",
    "gpt2": "GPT-2",
    "gptj": "GPT-J",
    "xglm": "XGLM",
    "opt": "OPT",
    "mpt": "MPT",
}


def model_hyperlink(link, model_name):
    """Return an HTML anchor that opens ``link`` in a new tab.

    Args:
        link: Value placed in the anchor's ``href`` attribute.
        model_name: Text shown as the anchor's content.

    Note:
        Both values are interpolated as-is; no HTML escaping is applied.
    """
    anchor_style = (
        "color: var(--link-text-color); "
        "text-decoration: underline;"
        "text-decoration-style: dotted;"
    )
    return (
        f'<a target="_blank" href="{link}" style="{anchor_style}">'
        f"{model_name}</a>"
    )


def process_model_name(model_name):
    """Turn a model name into an HTML link to its huggingface.co page.

    The URL is ``https://huggingface.co/`` followed by ``model_name``; the
    name itself is used as the link text.
    """
    return model_hyperlink(f"https://huggingface.co/{model_name}", model_name)


def process_model_type(model_type):
    """Return the display label for ``model_type``.

    Identifiers that have no entry in ``LLM_MODEL_TYPES`` are returned
    unchanged.
    """
    return LLM_MODEL_TYPES.get(model_type, model_type)


def process_weight_class(num):
    """Format a numeric count as a short label such as ``"125M"`` or ``"7B"``.

    The value is scaled to the largest fitting unit (none, K, M, B, T) and
    truncated toward zero, so ``1500`` becomes ``"1K"``.

    Args:
        num: The count to format (int or float).

    Returns:
        The label as a string, or ``None`` when ``num`` cannot be placed in
        any tier (NaN or positive infinity).
    """
    if num < 1000:
        return str(int(num))

    # (exclusive upper bound, divisor, suffix). The last tier is bounded by
    # infinity so every finite value >= 1e12 is labelled in trillions instead
    # of silently yielding None.
    tiers = (
        (1_000_000, 1_000, "K"),
        (1_000_000_000, 1_000_000, "M"),
        (1_000_000_000_000, 1_000_000_000, "B"),
        (float("inf"), 1_000_000_000_000, "T"),
    )
    for upper_bound, divisor, suffix in tiers:
        if num < upper_bound:
            return str(int(num / divisor)) + suffix

    # Only NaN and +inf reach this point: every comparison above is False
    # for them, and int() could not convert them anyway.
    return None