Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
from dataclasses import dataclass | |
from enum import Enum | |
from air_benchmark.tasks.tasks import BenchmarkTable | |
def get_safe_name(name: str) -> str:
    """Normalize *name* into a lowercase, underscore-separated identifier.

    Hyphens are first converted to underscores. Every remaining character
    that is neither alphanumeric (per ``str.isalnum``, which also accepts
    non-ASCII letters and digits) nor an underscore is dropped. Each kept
    character is lower-cased.

    Note: the result is NOT an RFC 1123 hostname label. It may contain
    underscores, which RFC 1123 labels do not allow, and it is not
    length-limited.

    >>> get_safe_name("Arxiv-EN")
    'arxiv_en'
    """
    name = name.replace('-', '_')
    return ''.join(
        character.lower()
        for character in name
        if character.isalnum() or character == '_'
    )
# Retrieval metric keys, grouped by metric family. Each family appears at
# the cutoffs k = 1, 3, 5, 10, 100, 1000, in that order.
METRIC_LIST = [
    # nDCG@k
    "ndcg_at_1", "ndcg_at_3", "ndcg_at_5",
    "ndcg_at_10", "ndcg_at_100", "ndcg_at_1000",
    # MAP@k
    "map_at_1", "map_at_3", "map_at_5",
    "map_at_10", "map_at_100", "map_at_1000",
    # Recall@k
    "recall_at_1", "recall_at_3", "recall_at_5",
    "recall_at_10", "recall_at_100", "recall_at_1000",
    # Precision@k
    "precision_at_1", "precision_at_3", "precision_at_5",
    "precision_at_10", "precision_at_100", "precision_at_1000",
    # MRR@k
    "mrr_at_1", "mrr_at_3", "mrr_at_5",
    "mrr_at_10", "mrr_at_100", "mrr_at_1000",
]
@dataclass
class Benchmark:
    """One leaderboard benchmark entry.

    This must be a dataclass: the module constructs instances positionally as
    ``Benchmark(name, metric, col_name, domain, lang, task)``, so the field
    order below is part of the interface.
    """
    # Safe benchmark name: get_safe_name("<domain>_<lang>") for QA entries,
    # get_safe_name("<domain>_<lang>_<dataset>") for long-doc entries.
    # Presumably also the task key in the results JSON -- TODO confirm.
    name: str
    # Intended as a metric key such as "ndcg_at_1" (presumably the metric key
    # in the results JSON -- TODO confirm). NOTE(review): the builders in this
    # module store the last dataset name (QA) or the last METRIC_LIST entry
    # (long-doc) here.
    metric: str
    # Column name; set equal to `name` by the builders in this module.
    # Presumably the name displayed in the leaderboard.
    col_name: str
    domain: str
    lang: str
    # Task type; the builders use "qa" or "long-doc".
    task: str
# Build the QA and long-doc benchmark registries from the 'AIR-Bench_24.04'
# table. Keys are safe benchmark names; values are Benchmark records.
qa_benchmark_dict = {}
long_doc_benchmark_dict = {}
for task, domain_dict in BenchmarkTable['AIR-Bench_24.04'].items():
    for domain, lang_dict in domain_dict.items():
        for lang, dataset_list in lang_dict.items():
            if task == "qa":
                # One QA benchmark per (domain, lang). list() also accepts a
                # dict-valued dataset_list (iterating its keys).
                datasets = list(dataset_list)
                if not datasets:
                    # No datasets -> no entry.
                    continue
                benchmark_name = get_safe_name(f"{domain}_{lang}")
                col_name = benchmark_name
                # NOTE(review): `metric` is set to the LAST item of
                # dataset_list (a dataset name, not a metric key). That is
                # exactly what the former overwrite-per-item loop produced;
                # it is kept for compatibility -- confirm the intended value.
                qa_benchmark_dict[benchmark_name] = Benchmark(
                    benchmark_name, datasets[-1], col_name, domain, lang, task)
            elif task == "long-doc":
                # One long-doc benchmark per (domain, lang, dataset).
                for dataset in dataset_list:
                    benchmark_name = get_safe_name(f"{domain}_{lang}_{dataset}")
                    col_name = benchmark_name
                    # NOTE(review): equivalent to the former loop over
                    # METRIC_LIST, which rebuilt this entry once per metric
                    # and kept only the last one.
                    long_doc_benchmark_dict[benchmark_name] = Benchmark(
                        benchmark_name, METRIC_LIST[-1], col_name, domain, lang, task)
# Enum views over the registries: member name = safe benchmark name,
# member value = its Benchmark record.
BenchmarksQA = Enum('BenchmarksQA', qa_benchmark_dict)
BenchmarksLongDoc = Enum('BenchmarksLongDoc', long_doc_benchmark_dict)

# Benchmark column names, in registry insertion order.
BENCHMARK_COLS_QA = [c.col_name for c in qa_benchmark_dict.values()]
BENCHMARK_COLS_LONG_DOC = [c.col_name for c in long_doc_benchmark_dict.values()]

# Distinct domains / languages per task. dict.fromkeys de-duplicates while
# keeping first-seen order. The previous list(frozenset(...)) produced an
# order that depended on string hash randomization, so it could differ
# between interpreter runs.
DOMAIN_COLS_QA = list(dict.fromkeys(c.domain for c in qa_benchmark_dict.values()))
LANG_COLS_QA = list(dict.fromkeys(c.lang for c in qa_benchmark_dict.values()))
DOMAIN_COLS_LONG_DOC = list(dict.fromkeys(c.domain for c in long_doc_benchmark_dict.values()))
LANG_COLS_LONG_DOC = list(dict.fromkeys(c.lang for c in long_doc_benchmark_dict.values()))

# Default metric key per task; both are entries of METRIC_LIST.
DEFAULT_METRIC_QA = "ndcg_at_10"
DEFAULT_METRIC_LONG_DOC = "recall_at_10"