Spaces:
AIR-Bench
/
Running on CPU Upgrade

File size: 2,347 Bytes
8b7a945
 
 
649e0fb
8b7a945
4791ac5
8b7a945
 
649e0fb
 
 
 
 
 
 
 
 
8b7a945
 
9c49811
8b7a945
 
9c49811
 
 
8b7a945
a96f80a
3fcf957
 
 
 
 
 
 
 
 
f30cbcc
 
3fcf957
 
 
 
 
 
 
 
 
 
 
 
 
a96f80a
3fcf957
8b7a945
3fcf957
 
270c122
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from dataclasses import dataclass
from enum import Enum

from air_benchmark.tasks.tasks import BenchmarkTable

from src.envs import METRIC_LIST


def get_safe_name(name: str):
    """Normalize *name* to a lowercase identifier-safe string.

    Hyphens become underscores; every remaining character is kept only if
    it is alphanumeric or an underscore, and is lowercased. (Intended as a
    "safe name" for use as an Enum member / JSON key.)
    """
    normalized = name.replace('-', '_')
    safe_chars = []
    for ch in normalized:
        if ch.isalnum() or ch == '_':
            safe_chars.append(ch.lower())
    return ''.join(safe_chars)


@dataclass
class Benchmark:
    """One leaderboard benchmark entry: identifies a (domain, language) slice
    of a task and the metric under which results are reported."""
    name: str  # [domain]_[language] safe name; the task_key in the results json file
    metric: str  # metric_key in the results json file, e.g. "ndcg_at_1"
    col_name: str  # column label shown in the leaderboard (same safe name as `name`)
    domain: str  # benchmark domain, e.g. "wiki", "law"
    lang: str  # language code of the datasets, e.g. "en", "zh"
    task: str  # task type: "qa" or "long-doc"


def get_benchmarks_enum(benchmark_version):
    """Build the benchmark registries for one AIR-Bench version.

    Walks ``BenchmarkTable[benchmark_version]`` (task -> domain -> lang ->
    dataset/metric list) and returns two dicts mapping a safe benchmark name
    to a :class:`Benchmark`:

    - qa benchmarks, keyed by ``[domain]_[lang]``
    - long-doc benchmarks, keyed by ``[domain]_[lang]_[dataset]``

    NOTE: the original code looped over every metric while assigning the
    SAME dict key, so only the last metric ever survived (constructing and
    discarding a Benchmark per metric). We take the last element directly,
    which preserves the observable result while doing the work once.
    """
    qa_benchmark_dict = {}
    long_doc_benchmark_dict = {}
    for task, domain_dict in BenchmarkTable[benchmark_version].items():
        for domain, lang_dict in domain_dict.items():
            for lang, dataset_list in lang_dict.items():
                if task == "qa":
                    # For "qa", the leaf list holds metric names.
                    benchmark_name = get_safe_name(f"{domain}_{lang}")
                    if dataset_list:  # empty list -> no entry, as before
                        qa_benchmark_dict[benchmark_name] = Benchmark(
                            benchmark_name, dataset_list[-1],
                            benchmark_name, domain, lang, task)
                elif task == "long-doc":
                    # For "long-doc", the leaf list holds dataset names and
                    # metrics come from the global METRIC_LIST.
                    for dataset in dataset_list:
                        benchmark_name = get_safe_name(
                            f"{domain}_{lang}_{dataset}")
                        if METRIC_LIST:  # empty list -> no entry, as before
                            long_doc_benchmark_dict[benchmark_name] = Benchmark(
                                benchmark_name, METRIC_LIST[-1],
                                benchmark_name, domain, lang, task)
    return qa_benchmark_dict, long_doc_benchmark_dict

# Materialize the registries for the benchmark version this app serves.
_qa_benchmark_dict, _long_doc_benchmark_dict = get_benchmarks_enum('AIR-Bench_24.04')

# Enum functional API: each member name is a safe benchmark name and its
# value is the corresponding Benchmark instance.
QABenchmarks = Enum('QABenchmarks', _qa_benchmark_dict)
LongDocBenchmarks = Enum('LongDocBenchmarks', _long_doc_benchmark_dict)