import os
import json
import glob
from collections import defaultdict
import pandas as pd
import gradio as gr
from content import *
from css import *
import glob

# Benchmark identifiers; result keys in the eval files look like
# "<benchmark>_<lang>" and the benchmark part must be one of these.
ARC, HELLASWAG, MMLU, TRUTHFULQA = "arc", "hellaswag", "mmlu", "truthfulqa"
BENCHMARKS = [ARC, HELLASWAG, MMLU, TRUTHFULQA]

# Score key read for each benchmark; METRICS[i] belongs to BENCHMARKS[i]
# (acc_norm for ARC / HellaSwag / MMLU, mc2 for TruthfulQA).
METRICS = 3 * ["acc_norm"] + ["mc2"]

# Language code -> human-readable language name shown in the table.
LANG_NAME = dict(
    ar="Arabic",
    bn="Bengali",
    ca="Catalan",
    da="Danish",
    de="German",
    es="Spanish",
    eu="Basque",
    fr="French",
    gu="Gujarati",
    hi="Hindi",
    hr="Croatian",
    hu="Hungarian",
    hy="Armenian",
    id="Indonesian",
    it="Italian",
    kn="Kannada",
    ml="Malayalam",
    mr="Marathi",
    ne="Nepali",
    nl="Dutch",
    pt="Portuguese",
    ro="Romanian",
    ru="Russian",
    sk="Slovak",
    sr="Serbian",
    sv="Swedish",
    ta="Tamil",
    te="Telugu",
    uk="Ukrainian",
    vi="Vietnamese",
    zh="Chinese",
)

# Supported language codes, in the same order as LANG_NAME (single source of truth).
LANGS = list(LANG_NAME)


def collect_results(pattern='evals/*/*.json'):
    """Scan result JSON files and collect benchmark scores per (model, language).

    A file is used only if it contains a ``results`` mapping and a ``config``
    dict whose ``model_args`` string (comma-separated ``key=value`` pairs)
    holds exactly one ``pretrained=...`` entry; other files are skipped.
    Result keys are expected to look like ``"<benchmark>_<lang>"``.

    Args:
        pattern: glob pattern selecting the result files. Defaults to the
            ``evals/*/*.json`` layout.

    Returns:
        A ``(performance_dict, pretrained_models)`` tuple:
        ``performance_dict`` maps ``(model_name, lang)`` to
        ``{benchmark: score}``, where score is the benchmark's metric (see
        METRICS) multiplied by 100 and rounded to one decimal;
        ``pretrained_models`` is the set of model names seen.

    Raises:
        ValueError: if a result key names a benchmark not in BENCHMARKS.
    """
    performance_dict = defaultdict(dict)
    pretrained_models = set()
    # Sort so that, if two files report the same (model, lang, benchmark),
    # which one wins does not depend on filesystem listing order.
    for file in sorted(glob.glob(pattern)):
        with open(file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        if 'results' not in data or 'config' not in data:
            continue
        results = data['results']
        config = data['config']
        if 'model_args' not in config:
            continue

        model_args = config['model_args'].split(',')
        pretrained = [arg for arg in model_args if arg.startswith('pretrained=')]
        if len(pretrained) != 1:
            continue
        # maxsplit=1 keeps any '=' that belongs to the value itself; then keep
        # only the last path component, e.g. "org/model" -> "model".
        model_name = pretrained[0].split('=', 1)[1].split('/')[-1]
        pretrained_models.add(model_name)

        for lang_task, perfs in results.items():
            # Benchmark names contain no underscore, so everything after the
            # first '_' is the language code. A key without '_' yields lang=''.
            task, _, lang = lang_task.partition('_')
            if task not in BENCHMARKS:
                # Explicit raise (not assert) so the check survives `python -O`.
                raise ValueError(
                    f'Unknown benchmark {task!r} in result key {lang_task!r} of {file}'
                )
            if not lang:
                continue
            metric = METRICS[BENCHMARKS.index(task)]
            performance_dict[(model_name, lang)][task] = round(perfs[metric] * 100, 1)
    return performance_dict, pretrained_models


def get_leaderboard_df(performance_dict, pretrained_models):
    """Build the leaderboard DataFrame from per-(model, language) scores.

    Args:
        performance_dict: mapping ``(model_name, lang_code) -> {benchmark: score}``
            as returned by ``collect_results``.
        pretrained_models: set of model names; accepted for interface
            compatibility but not used here.

    Returns:
        A DataFrame with columns ``COLS`` holding one row per (model, language)
        pair that has a score for all four benchmarks. The average is the
        mean of the four scores rounded to one decimal. Rows are sorted by
        ``LANG_COL`` and then ``AVERAGE_COL``, both descending. Language codes
        missing from ``LANG_NAME`` are displayed by their code instead of
        raising ``KeyError``.
    """
    benchmarks = (ARC, HELLASWAG, MMLU, TRUTHFULQA)  # same order as the score columns in COLS
    rows = []
    for (pretrained, lang), perfs in performance_dict.items():
        # Only rank pairs evaluated on every benchmark. Test presence rather
        # than a zero product so that a genuine 0.0 score is not dropped.
        if any(task not in perfs for task in benchmarks):
            continue
        scores = [perfs[task] for task in benchmarks]
        lang_name = LANG_NAME.get(lang, lang)
        avg = round(sum(scores) / len(scores), 1)
        notes = ' '.join([pretrained, lang_name])  # free-text field used by the search box
        rows.append([pretrained, lang_name, lang, avg, *scores, notes])

    df = pd.DataFrame.from_records(rows, columns=COLS)
    df = df.sort_values(by=[LANG_COL, AVERAGE_COL], ascending=False)
    return df[COLS]


def search_table(df, query):
    """Return the rows of ``df`` whose ``NOTES_COL`` text contains ``query``.

    Matching is a case-insensitive *literal* substring test. ``regex=False``
    keeps characters such as ``(``, ``[``, ``+`` or ``*`` typed into the search
    bar from being parsed as a regular expression, which would otherwise raise
    on incomplete patterns. ``na=False`` treats missing notes as non-matching
    so the mask stays boolean. An empty query matches every row.
    """
    mask = df[NOTES_COL].str.contains(query, case=False, regex=False, na=False)
    return df[mask]



# Display names of the leaderboard columns.
MODEL_COL = "Model"
LANG_COL = "Language"
CODE_COL = "Code"
AVERAGE_COL = "Average"
ARC_COL = "ARC (25-shot)"
# Previously carried an invisible U+FE0F (emoji variation selector) after ")",
# a copy-paste artifact that leaked into the displayed header.
HELLASWAG_COL = "HellaSwag (0-shot)"
MMLU_COL = "MMLU (25-shot)"
TRUTHFULQA_COL = "TruthfulQA (0-shot)"
NOTES_COL = "Notes"  # For search only

# Column order of the table; get_leaderboard_df builds rows positionally in
# this order, and TYPES[i] is the Gradio datatype of COLS[i].
COLS = [MODEL_COL, LANG_COL, CODE_COL, AVERAGE_COL, ARC_COL, HELLASWAG_COL, MMLU_COL, TRUTHFULQA_COL, NOTES_COL]
TYPES = ["str", "str", "str", "number", "number", "number", "number", "number", "str"]

# Load every result file once at startup. collect_results returns
# (performance_dict, pretrained_models), which is unpacked into get_leaderboard_df.
args = collect_results()
original_df = get_leaderboard_df(*args)

# UI definition. Inside a Blocks context, components are laid out in the order
# they are created, so statement order below is the page layout.
# NOTE(review): gr.Box and Dataframe(max_rows=...) are Gradio 3.x APIs that were
# removed in Gradio 4 — confirm the pinned gradio version before upgrading.
demo = gr.Blocks(css=CUSTOM_CSS)
with demo:
    gr.HTML(TITLE)
    gr.Markdown(INTRO_TEXT, elem_classes="markdown-text")
    gr.Markdown(HOW_TO, elem_classes="markdown-text")

    with gr.Box():
        search_bar = gr.Textbox(
            placeholder="Search models and languages...", show_label=False, elem_id="search-bar"
        )

        # Visible leaderboard; its contents are replaced by each search result.
        leaderboard_table = gr.components.Dataframe(
            value=original_df,
            headers=COLS,
            datatype=TYPES,
            max_rows=5,
            elem_id="leaderboard-table",
        )

        # Hidden, never-filtered copy of the full table. Searches always filter this
        # copy rather than the visible table, so shortening the query (e.g. backspace)
        # brings back rows that a longer query had filtered out.
        hidden_leaderboard_table_for_search = gr.components.Dataframe(
            value=original_df, headers=COLS, datatype=TYPES, max_rows=5, visible=False
        )

        # Re-run the filter whenever the search text changes:
        # search_table(full_table, query) -> visible table.
        search_bar.change(
            search_table,
            [hidden_leaderboard_table_for_search, search_bar],
            leaderboard_table,
        )

    gr.Markdown(CREDIT, elem_classes="markdown-text")
    gr.Markdown(CITATION, elem_classes="markdown-text")

# Starts the web server as soon as this module runs (there is no __main__ guard).
demo.launch()