from dataclasses import dataclass
from enum import Enum

from air_benchmark.tasks.tasks import BenchmarkTable

from src.envs import METRIC_LIST


def get_safe_name(name: str) -> str:
    """Normalize *name* into a lowercase, underscore-separated string.

    Hyphens are converted to underscores, every kept character is
    lowercased, and any character that is neither alphanumeric (as judged by
    ``str.isalnum``, so non-ASCII letters and digits are kept) nor ``_`` is
    dropped.

    The result is NOT an RFC 1123 host-name label: RFC 1123 labels use
    hyphens and forbid underscores, whereas this function emits underscores.

    >>> get_safe_name("Law-EN (v2)")
    'law_env2'
    """
    name = name.replace('-', '_')
    return ''.join(
        character.lower()
        for character in name
        if character.isalnum() or character == '_'
    )


@dataclass
class Benchmark:
    # Safe name (see get_safe_name) of "[domain]_[lang]" for qa tasks or
    # "[domain]_[lang]_[dataset]" for long-doc tasks; also the dict key and
    # Enum member name built below.
    name: str
    # For qa: the LAST entry iterated from BenchmarkTable[...][domain][lang];
    # for long-doc: the last entry of METRIC_LIST.
    # NOTE(review): earlier comments described this as the JSON metric key
    # (e.g. "ndcg_at_1"); for qa that only holds if the table's entries are
    # metric names — confirm against BenchmarkTable.
    metric: str
    # Column name displayed in the leaderboard; identical to ``name`` here.
    col_name: str
    domain: str
    lang: str
    # Task type this benchmark belongs to ("qa" or "long-doc" in this module).
    task: str


def get_benchmarks_enum(benchmark_version):
    """Build the QA and long-doc benchmark tables for *benchmark_version*.

    Despite its name, this returns two plain dicts rather than Enum classes;
    the module-level code wraps them into ``QABenchmarks`` and
    ``LongDocBenchmarks``.

    Args:
        benchmark_version: key into ``BenchmarkTable`` (e.g.
            ``'AIR-Bench_24.04'``).

    Returns:
        ``(qa_benchmark_dict, long_doc_benchmark_dict)``, each mapping a safe
        benchmark name to its ``Benchmark``. Tasks other than ``"qa"`` and
        ``"long-doc"`` are ignored. If two entries normalize to the same safe
        name, the later one wins.
    """
    qa_benchmark_dict = {}
    long_doc_benchmark_dict = {}
    # Only one Benchmark is stored per name, so only the final metric in
    # METRIC_LIST is ever recorded for long-doc entries; compute it once.
    long_doc_metrics = list(METRIC_LIST)
    for task, domain_dict in BenchmarkTable[benchmark_version].items():
        for domain, lang_dict in domain_dict.items():
            for lang, dataset_list in lang_dict.items():
                if task == "qa":
                    benchmark_name = get_safe_name(f"{domain}_{lang}")
                    entries = list(dataset_list)
                    # An empty entry list produced no benchmark before either.
                    if not entries:
                        continue
                    # One Benchmark per (domain, lang); the last entry is the
                    # one that survives, matching the original overwrite loop.
                    qa_benchmark_dict[benchmark_name] = Benchmark(
                        name=benchmark_name,
                        metric=entries[-1],
                        col_name=benchmark_name,
                        domain=domain,
                        lang=lang,
                        task=task,
                    )
                elif task == "long-doc":
                    # With no metrics configured, nothing was ever recorded.
                    if not long_doc_metrics:
                        continue
                    for dataset in dataset_list:
                        benchmark_name = get_safe_name(
                            f"{domain}_{lang}_{dataset}")
                        long_doc_benchmark_dict[benchmark_name] = Benchmark(
                            name=benchmark_name,
                            metric=long_doc_metrics[-1],
                            col_name=benchmark_name,
                            domain=domain,
                            lang=lang,
                            task=task,
                        )
    return qa_benchmark_dict, long_doc_benchmark_dict


_qa_benchmark_dict, _long_doc_benchmark_dict = get_benchmarks_enum('AIR-Bench_24.04')

# Enum members are named by the safe benchmark name; each value is the
# corresponding Benchmark instance.
QABenchmarks = Enum('QABenchmarks', _qa_benchmark_dict)
LongDocBenchmarks = Enum('LongDocBenchmarks', _long_doc_benchmark_dict)