from dataclasses import dataclass
from enum import Enum

from air_benchmark.tasks.tasks import BenchmarkTable

from src.envs import METRIC_LIST


def get_safe_name(name: str):
    """Normalize a name into a safe key: '-' is mapped to '_', characters are
    lowercased, and anything other than letters, digits, and '_' is dropped."""
    name = name.replace('-', '_')
    return ''.join(
        character.lower()
        for character in name
        if (character.isalnum() or character == '_'))

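# Illustrative examples of the mapping produced by get_safe_name:
#   get_safe_name("AIR-Bench_24.05") -> "air_bench_2405"
#   get_safe_name("law_en")          -> "law_en"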

@dataclass
class Benchmark:
    name: str  # [domain]_[language] (qa) or [domain]_[language]_[dataset] (long-doc); task_key in the json file
    metric: str  # e.g. ndcg_at_1; metric_key in the json file
    col_name: str  # same as name; column header displayed in the leaderboard
    domain: str
    lang: str
    task: str


# Build the dict of Benchmark entries for a benchmark version and task type;
# each dict is wrapped into an Enum class below.
def get_benchmarks_enum(benchmark_version, task_type):
    benchmark_dict = {}
    if task_type == "qa":
        for task, domain_dict in BenchmarkTable[benchmark_version].items():
            if task != task_type:
                continue
            for domain, lang_dict in domain_dict.items():
                for lang, dataset_list in lang_dict.items():
                    benchmark_name = get_safe_name(f"{domain}_{lang}")
                    col_name = benchmark_name
                    for metric in dataset_list:
                        if "test" not in dataset_list[metric]["splits"]:
                            continue
                        benchmark_dict[benchmark_name] = \
                            Benchmark(benchmark_name, metric, col_name, domain, lang, task)
    elif task_type == "long-doc":
        for task, domain_dict in BenchmarkTable[benchmark_version].items():
            if task != task_type:
                continue
            for domain, lang_dict in domain_dict.items():
                for lang, dataset_list in lang_dict.items():
                    for dataset in dataset_list:
                        benchmark_name = f"{domain}_{lang}_{dataset}"
                        benchmark_name = get_safe_name(benchmark_name)
                        col_name = benchmark_name
                        if "test" not in dataset_list[dataset]["splits"]:
                            continue
                        for metric in METRIC_LIST:
                            benchmark_dict[benchmark_name] = \
                                Benchmark(benchmark_name, metric, col_name, domain, lang, task)
    return benchmark_dict

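# Illustrative shape of the dict returned above (the actual keys depend on
# BenchmarkTable; "wiki" / "en" are assumed here purely as an example):
#     {"wiki_en": Benchmark(name="wiki_en", metric=<metric_key>,
#                           col_name="wiki_en", domain="wiki", lang="en",
#                           task="qa"), ...}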

versions = ("AIR-Bench_24.04", "AIR-Bench_24.05")
qa_benchmark_dict = {}
for version in versions:
    safe_version_name = get_safe_name(version)[-4:]  # e.g. "AIR-Bench_24.04" -> "2404"
    qa_benchmark_dict[safe_version_name] = Enum(f"QABenchmarks_{safe_version_name}", get_benchmarks_enum(version, "qa"))

long_doc_benchmark_dict = {}
for version in versions:
    safe_version_name = get_safe_name(version)[-4:]
    long_doc_benchmark_dict[safe_version_name] = Enum(f"LongDocBenchmarks_{safe_version_name}", get_benchmarks_enum(version, "long-doc"))

# _qa_benchmark_dict, = get_benchmarks_enum('AIR-Bench_24.04', "qa")
# _long_doc_benchmark_dict = get_benchmarks_enum('AIR-Bench_24.04', "long-doc")

QABenchmarks = Enum('QABenchmarks', qa_benchmark_dict)
LongDocBenchmarks = Enum('LongDocBenchmarks', long_doc_benchmark_dict)
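
# Illustrative usage sketch (assumes the imports above resolve): each
# top-level Enum maps a four-digit version key to the per-version Enum whose
# members wrap Benchmark instances.
#
#     qa_2405 = QABenchmarks["2405"].value   # Enum of benchmarks for 24.05
#     for member in qa_2405:
#         bench = member.value               # Benchmark dataclass instance
#         print(bench.col_name, bench.metric)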