import glob
import json
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np

# NOTE: the lmeh (lm-eval-harness) eval data must be cloned / pulled into evals/eval_results/ before these readers run.
# One row per benchmark, in leaderboard column order:
# (benchmark key matched in result file names,
#  metric key read from each entry of a result file's "results" mapping,
#  leaderboard column header).
# The public constants below are derived from this table so they stay aligned.
_BENCHMARK_SPECS = [
    ("arc_challenge", "acc_norm", "ARC (25-shot) ⬆️"),
    ("hellaswag", "acc_norm", "HellaSwag (10-shot) ⬆️"),
    ("hendrycks", "acc_norm", "MMLU (5-shot) ⬆️"),
    ("truthfulqa_mc", "mc2", "TruthfulQA (0-shot) ⬆️"),
]
METRICS = [metric for _, metric, _ in _BENCHMARK_SPECS]
BENCHMARKS = [bench for bench, _, _ in _BENCHMARK_SPECS]
BENCH_TO_NAME = {bench: column for bench, _, column in _BENCHMARK_SPECS}


# Inline CSS shared by every model link in the leaderboard table.
_LINK_STYLE = "color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;"

# LLaMA checkpoints link to Meta's announcement post instead of a Hub page.
_LLAMA_MODELS = frozenset(
    {
        "huggingface/llama-7b",
        "huggingface/llama-13b",
        "huggingface/llama-30b",
        "huggingface/llama-65b",
    }
)
_LLAMA_BLOG_URL = "https://ai.facebook.com/blog/large-language-model-llama-meta-ai/"

# Models whose link target and display label differ from the default
# "https://huggingface.co/<model_name>" / "<model_name>": name -> (link, label).
_SPECIAL_MODEL_LINKS = {
    "HuggingFaceH4/stable-vicuna-13b-2904": (
        "https://huggingface.co/CarperAI/stable-vicuna-13b-delta",
        "stable-vicuna-13b",
    ),
    # This checkpoint is LLaMA-7B instruction-tuned on Alpaca data, and the
    # linked CRFM post describes Alpaca 7B, so the label is "alpaca-7b".
    "HuggingFaceH4/llama-7b-ift-alpaca": (
        "https://crfm.stanford.edu/2023/03/13/alpaca.html",
        "alpaca-7b",
    ),
}


def _model_hyperlink(link, text):
    """Return an anchor tag opening *link* in a new tab, showing *text*."""
    return f'<a target="_blank" href="{link}" style="{_LINK_STYLE}">{text}</a>'


def make_clickable_model(model_name):
    """Return an HTML link for *model_name* for the leaderboard "Model" column.

    Args:
        model_name: ``"org/model"`` (or bare ``"model"``) identifier.

    Returns:
        An ``<a>`` tag. LLaMA checkpoints link to Meta's announcement and are
        labelled without the ``huggingface/`` prefix. A few special-cased
        models get a custom target and label. Everything else links to
        ``https://huggingface.co/<model_name>`` with the full name as label.
    """
    if model_name in _LLAMA_MODELS:
        return _model_hyperlink(_LLAMA_BLOG_URL, model_name.split("/")[1])

    special = _SPECIAL_MODEL_LINKS.get(model_name)
    if special is not None:
        link, label = special
        return _model_hyperlink(link, label)

    return _model_hyperlink("https://huggingface.co/" + model_name, model_name)


@dataclass
class EvalResult:
    """Scores of one model checkpoint, aggregated across benchmark result files.

    One instance is created per result file by ``parse_eval_result``; instances
    sharing an ``eval_name`` are merged in ``get_eval_results``; ``to_dict``
    renders the merged result as one leaderboard row.
    """

    eval_name: str  # "[org_]model_revision_precision" key used to merge files
    org: Optional[str]  # None for models without an org (e.g. gpt2)
    model: str
    revision: str
    is_8bit: bool
    results: dict  # benchmark key (see BENCHMARKS) -> score in percent

    def to_dict(self):
        """Return this result as a leaderboard row dict.

        The row holds the metadata columns, an ``"Average ⬆️"`` column and one
        column per entry of ``BENCH_TO_NAME``. Benchmarks without a score map
        to ``None``. The average always divides by ``len(BENCHMARKS)``, so a
        missing benchmark counts as 0 in it.

        ``self.results`` is not modified, so the method may be called
        repeatedly.
        """
        if self.org is not None:
            base_model = f"{self.org}/{self.model}"
        else:
            base_model = f"{self.model}"

        # Skip None so a result that already carries placeholders still sums.
        scores = [v for v in self.results.values() if v is not None]

        data_dict = {
            "eval_name": self.eval_name,
            "8bit": self.is_8bit,
            "Model": make_clickable_model(base_model),
            "model_name_for_query": base_model,
            "Revision": self.revision,
            "Average ⬆️": round(sum(scores) / len(BENCHMARKS), 1),
        }

        # .get() instead of writing None placeholders back into self.results:
        # that mutation made a second to_dict() call fail inside sum().
        for benchmark, column in BENCH_TO_NAME.items():
            data_dict[column] = self.results.get(benchmark)

        return data_dict


def parse_eval_result(json_filepath: str) -> Tuple[str, Optional["EvalResult"]]:
    """Parse one benchmark result file into ``(result_key, EvalResult)``.

    The path is split on ``/`` and read from the end as
    ``.../[<org>/]<model>/<revision>/<precision>/<file>.json``. When the path
    has exactly 7 components the ``<org>`` directory is taken to be absent
    (models such as gpt2 that have no org).

    Args:
        json_filepath: path to a JSON file with a top-level ``"results"``
            mapping whose entries each hold the benchmark's metric key.

    Returns:
        ``(result_key, eval_result)``. ``result_key`` is
        ``"[org_]model_revision_precision"``. ``eval_result`` holds a single
        benchmark score: the mean of that benchmark's metric over all
        ``"results"`` entries, times 100 and rounded to 1 decimal. It is
        ``None`` when the file name matches none of ``BENCHMARKS``.
    """
    with open(json_filepath, encoding="utf-8") as fp:
        data = json.load(fp)

    path_split = json_filepath.split("/")
    file_name = path_split[-1]
    model = path_split[-4]
    revision = path_split[-3]
    precision = path_split[-2]
    is_8bit = precision == "8bit"
    if len(path_split) == 7:
        # handles gpt2 type models that don't have an org
        org = None
        result_key = f"{model}_{revision}_{precision}"
    else:
        org = path_split[-5]
        result_key = f"{org}_{model}_{revision}_{precision}"

    for benchmark, metric in zip(BENCHMARKS, METRICS):
        # Match on the file name only: the directories identify the model, so
        # an org/model name containing a benchmark name (e.g. "hellaswag")
        # must not decide which benchmark this file holds.
        if benchmark in file_name:
            accs = np.array([v[metric] for v in data["results"].values()])
            mean_acc = round(np.mean(accs) * 100.0, 1)
            eval_result = EvalResult(
                result_key, org, model, revision, is_8bit, {benchmark: mean_acc}
            )
            return result_key, eval_result

    return result_key, None


def get_eval_results(is_public) -> List[EvalResult]:
    """Load all eval result files and merge them into one EvalResult per model run.

    Args:
        is_public: when true, only public 16-bit results are read. Otherwise
            private results and public 8-bit results are included too.

    Returns:
        One ``EvalResult`` per ``result_key``, in first-seen order. Its
        ``results`` dict is the union of the per-benchmark scores from every
        file with that key. Files whose name matches no known benchmark are
        skipped.
    """
    json_filepaths = glob.glob(
        "evals/eval_results/public/**/16bit/*.json", recursive=True
    )
    if not is_public:
        json_filepaths += glob.glob(
            "evals/eval_results/private/**/*.json", recursive=True
        )
        json_filepaths += glob.glob(
            "evals/eval_results/public/**/8bit/*.json", recursive=True
        )  # include the 8bit evals of public models

    eval_results: Dict[str, EvalResult] = {}
    for json_filepath in json_filepaths:
        result_key, eval_result = parse_eval_result(json_filepath)
        if eval_result is None:
            # Not a recognised benchmark file; merging it would crash on
            # `.results` here or later in to_dict().
            continue
        if result_key in eval_results:
            eval_results[result_key].results.update(eval_result.results)
        else:
            eval_results[result_key] = eval_result

    return list(eval_results.values())


def get_eval_results_dicts(is_public=True) -> List[Dict]:
    """Return every merged eval result rendered as a leaderboard row dict.

    Args:
        is_public: forwarded to ``get_eval_results``; defaults to public
            results only.
    """
    return [result.to_dict() for result in get_eval_results(is_public)]