from dataclasses import dataclass
from enum import Enum

from air_benchmark.tasks.tasks import BenchmarkTable

from src.envs import BENCHMARK_VERSION_LIST, METRIC_LIST
from src.models import TaskType, get_safe_name


@dataclass
class Benchmark:
    name: str  # [domain]_[language]_[metric]; the task_key in the JSON file
    metric: str  # the metric_key in the JSON file
    col_name: str  # [domain]_[language]; name displayed in the leaderboard
    domain: str
    lang: str
    task: str


# Build a dict of QA benchmarks for a given version; used below to construct the per-version Enum classes
def get_qa_benchmarks_dict(version: str):
    benchmark_dict = {}
    for task, domain_dict in BenchmarkTable[version].items():
        if task != TaskType.qa.value:
            continue
        for domain, lang_dict in domain_dict.items():
            for lang, dataset_list in lang_dict.items():
                benchmark_name = get_safe_name(f"{domain}_{lang}")
                col_name = benchmark_name
                for metric in dataset_list:
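                    # skip entries that do not provide a test split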
                    if "test" not in dataset_list[metric]["splits"]:
                        continue
                    benchmark_dict[benchmark_name] = Benchmark(
                        benchmark_name, metric, col_name, domain, lang, task
                    )
    return benchmark_dict


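# Build a dict of long-doc benchmarks for a given version; mirrors the QA builder above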
def get_doc_benchmarks_dict(version: str):
    benchmark_dict = {}
    for task, domain_dict in BenchmarkTable[version].items():
        if task != TaskType.long_doc.value:
            continue
        for domain, lang_dict in domain_dict.items():
            for lang, dataset_list in lang_dict.items():
                for dataset in dataset_list:
                    benchmark_name = f"{domain}_{lang}_{dataset}"
                    benchmark_name = get_safe_name(benchmark_name)
                    col_name = benchmark_name
                    if "test" not in dataset_list[dataset]["splits"]:
                        continue
                    for metric in METRIC_LIST:
                        benchmark_dict[benchmark_name] = Benchmark(
                            benchmark_name, metric, col_name, domain, lang, task
                        )
    return benchmark_dict


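# Build one QA Enum and one long-doc Enum per benchmark version, keyed by the safe version name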
_qa_benchmark_dict = {}
for version in BENCHMARK_VERSION_LIST:
    safe_version_name = get_safe_name(version)
    _qa_benchmark_dict[safe_version_name] = Enum(
        f"QABenchmarks_{safe_version_name}",
        get_qa_benchmarks_dict(version),
    )

_doc_benchmark_dict = {}
for version in BENCHMARK_VERSION_LIST:
    safe_version_name = get_safe_name(version)
    _doc_benchmark_dict[safe_version_name] = Enum(
        f"LongDocBenchmarks_{safe_version_name}",
        get_doc_benchmarks_dict(version),
    )


QABenchmarks = Enum("QABenchmarks", _qa_benchmark_dict)
LongDocBenchmarks = Enum("LongDocBenchmarks", _doc_benchmark_dict)
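

# Usage sketch, not part of the leaderboard code: a minimal example of how the nested
# enums built above could be consumed. The choice of version below is an assumption;
# any entry of BENCHMARK_VERSION_LIST works the same way.
if __name__ == "__main__":
    example_version = BENCHMARK_VERSION_LIST[0]  # hypothetical: pick the first listed version
    version_enum = QABenchmarks[get_safe_name(example_version)].value  # the per-version Enum class
    for member in version_enum:
        bench = member.value  # a Benchmark dataclass instance
        print(bench.col_name, bench.domain, bench.lang, bench.metric)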