from collections import deque
from dataclasses import dataclass
from enum import Enum

from air_benchmark.tasks.tasks import BenchmarkTable

from src.envs import METRIC_LIST


def get_safe_name(name: str):
    """Normalise *name* into a lowercase, underscore-separated key.

    Hyphens become underscores. Every character that is neither an
    underscore nor alphanumeric is dropped; alphanumeric means
    ``str.isalnum``, so non-ASCII letters and digits are kept. Each
    remaining character is then lowercased on its own.

    Note: the result is not a valid RFC 1123 host label, because those
    forbid underscores.

    >>> get_safe_name("Wiki-EN (v2)")
    'wiki_env2'
    """
    kept = (ch for ch in name.replace('-', '_') if ch == '_' or ch.isalnum())
    return ''.join(map(str.lower, kept))


@dataclass
class Benchmark:
    """Description of one leaderboard column.

    Instances are the member values of the ``BenchmarksQA`` and
    ``BenchmarksLongDoc`` enums built at module level.
    """

    # Safe column key. For QA it is get_safe_name("<domain>_<lang>"); for
    # long-doc it is get_safe_name("<domain>_<lang>_<dataset>"). It never
    # contains the metric. Presumably also the task key in the results
    # JSON files -- confirm.
    name: str
    # Metric key, e.g. "ndcg_at_1" (described as the metric key in the
    # results JSON files).
    # NOTE(review): the module-level builder stores the last item it
    # iterates here. For QA that item comes from the dataset list, not a
    # metric list -- confirm intent.
    metric: str
    # Column header shown in the leaderboard. The builder currently sets
    # it equal to ``name``.
    col_name: str
    domain: str  # domain key taken from BenchmarkTable
    lang: str  # language key taken from BenchmarkTable
    task: str  # task key taken from BenchmarkTable: "qa" or "long-doc"


def _build_benchmark_dicts(version: str = 'AIR-Bench_24.04'):
    """Build the QA and long-doc ``Benchmark`` lookup tables for *version*.

    Walks ``BenchmarkTable[version]`` (task -> domain -> language -> items)
    and creates one ``Benchmark`` per leaderboard column:

    * ``"qa"``: one column per (domain, language), keyed by
      ``get_safe_name(f"{domain}_{lang}")``;
    * ``"long-doc"``: one column per (domain, language, dataset), keyed by
      ``get_safe_name(f"{domain}_{lang}_{dataset}")``.

    Other tasks are ignored. ``col_name`` is always equal to the key.

    Each column holds a single ``Benchmark``, so the *last* candidate metric
    is the one stored. For QA that is the last item of the table's dataset
    list; for long-doc it is the last entry of ``METRIC_LIST``. If that
    source is empty, no entry is created.

    Returns:
        ``(qa_dict, long_doc_dict)``, each mapping the safe column key to its
        ``Benchmark`` in table iteration order.
    """
    qa_dict = {}
    long_doc_dict = {}
    # deque(maxlen=1) keeps only the final element of an iterable. This
    # reproduces the last-wins result without building a throwaway
    # Benchmark per candidate. The value is the same for every long-doc
    # column, so it is computed once here.
    long_doc_metric = deque(METRIC_LIST, maxlen=1)
    for task, domain_dict in BenchmarkTable[version].items():
        for domain, lang_dict in domain_dict.items():
            for lang, dataset_list in lang_dict.items():
                if task == "qa":
                    key = get_safe_name(f"{domain}_{lang}")
                    # NOTE(review): these items come from the dataset list,
                    # so ``metric`` receives a dataset identifier rather than
                    # a metric key. Kept for compatibility -- confirm intent.
                    last_item = deque(dataset_list, maxlen=1)
                    if last_item:
                        qa_dict[key] = Benchmark(key, last_item[0], key, domain, lang, task)
                elif task == "long-doc":
                    if not long_doc_metric:
                        continue
                    for dataset in dataset_list:
                        key = get_safe_name(f"{domain}_{lang}_{dataset}")
                        long_doc_dict[key] = Benchmark(key, long_doc_metric[0], key, domain, lang, task)
    return qa_dict, long_doc_dict


# Built inside a function so the loop variables do not leak into the module
# namespace.
qa_benchmark_dict, long_doc_benchmark_dict = _build_benchmark_dicts()

BenchmarksQA = Enum('BenchmarksQA', qa_benchmark_dict)
BenchmarksLongDoc = Enum('BenchmarksLongDoc', long_doc_benchmark_dict)