Spaces:
AIR-Bench
/
Running on CPU Upgrade

nan committed on
Commit
a50e211
·
1 Parent(s): 0af261c

test: add unit tests for benchmarks

Browse files
Files changed (2) hide show
  1. src/benchmarks.py +41 -34
  2. tests/src/test_benchmarks.py +43 -9
src/benchmarks.py CHANGED
@@ -10,7 +10,7 @@ from src.models import TaskType, get_safe_name
10
  @dataclass
11
  class Benchmark:
12
  name: str # [domain]_[language]_[metric], task_key in the json file,
13
- metric: str # ndcg_at_1 ,metric_key in the json file
14
  col_name: str # [domain]_[language], name to display in the leaderboard
15
  domain: str
16
  lang: str
@@ -18,54 +18,61 @@ class Benchmark:
18
 
19
 
20
  # create a function return an enum class containing all the benchmarks
21
- def get_benchmarks_enum(benchmark_version: str, task_type: TaskType):
22
  benchmark_dict = {}
23
- if task_type == TaskType.qa:
24
- for task, domain_dict in BenchmarkTable[benchmark_version].items():
25
- if task != task_type.value:
26
- continue
27
- for domain, lang_dict in domain_dict.items():
28
- for lang, dataset_list in lang_dict.items():
29
- benchmark_name = get_safe_name(f"{domain}_{lang}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  col_name = benchmark_name
31
- for metric in dataset_list:
32
- if "test" not in dataset_list[metric]["splits"]:
33
- continue
34
  benchmark_dict[benchmark_name] = Benchmark(
35
  benchmark_name, metric, col_name, domain, lang, task
36
  )
37
- elif task_type == TaskType.long_doc:
38
- for task, domain_dict in BenchmarkTable[benchmark_version].items():
39
- if task != task_type.value:
40
- continue
41
- for domain, lang_dict in domain_dict.items():
42
- for lang, dataset_list in lang_dict.items():
43
- for dataset in dataset_list:
44
- benchmark_name = f"{domain}_{lang}_{dataset}"
45
- benchmark_name = get_safe_name(benchmark_name)
46
- col_name = benchmark_name
47
- if "test" not in dataset_list[dataset]["splits"]:
48
- continue
49
- for metric in METRIC_LIST:
50
- benchmark_dict[benchmark_name] = Benchmark(
51
- benchmark_name, metric, col_name, domain, lang, task
52
- )
53
  return benchmark_dict
54
 
55
 
56
  _qa_benchmark_dict = {}
57
  for version in BENCHMARK_VERSION_LIST:
58
  safe_version_name = get_safe_name(version)
59
- _qa_benchmark_dict[safe_version_name] = Enum(
60
- f"QABenchmarks_{safe_version_name}", get_benchmarks_enum(version, TaskType.qa)
61
- )
 
 
62
 
63
  _doc_benchmark_dict = {}
64
  for version in BENCHMARK_VERSION_LIST:
65
  safe_version_name = get_safe_name(version)
66
- _doc_benchmark_dict[safe_version_name] = Enum(
67
- f"LongDocBenchmarks_{safe_version_name}", get_benchmarks_enum(version, TaskType.long_doc)
68
- )
 
 
69
 
70
 
71
  QABenchmarks = Enum("QABenchmarks", _qa_benchmark_dict)
 
10
  @dataclass
11
  class Benchmark:
12
  name: str # [domain]_[language]_[metric], task_key in the json file,
13
+ metric: str # metric_key in the json file
14
  col_name: str # [domain]_[language], name to display in the leaderboard
15
  domain: str
16
  lang: str
 
18
 
19
 
20
  # create a function return an enum class containing all the benchmarks
21
+ def get_qa_benchmarks_dict(version: str):
22
  benchmark_dict = {}
23
+ for task, domain_dict in BenchmarkTable[version].items():
24
+ if task != TaskType.qa.value:
25
+ continue
26
+ for domain, lang_dict in domain_dict.items():
27
+ for lang, dataset_list in lang_dict.items():
28
+ benchmark_name = get_safe_name(f"{domain}_{lang}")
29
+ col_name = benchmark_name
30
+ for metric in dataset_list:
31
+ if "test" not in dataset_list[metric]["splits"]:
32
+ continue
33
+ benchmark_dict[benchmark_name] = Benchmark(
34
+ benchmark_name, metric, col_name, domain, lang, task
35
+ )
36
+ return benchmark_dict
37
+
38
+
39
+ def get_doc_benchmarks_dict(version: str):
40
+ benchmark_dict = {}
41
+ for task, domain_dict in BenchmarkTable[version].items():
42
+ if task != TaskType.long_doc.value:
43
+ continue
44
+ for domain, lang_dict in domain_dict.items():
45
+ for lang, dataset_list in lang_dict.items():
46
+ for dataset in dataset_list:
47
+ benchmark_name = f"{domain}_{lang}_{dataset}"
48
+ benchmark_name = get_safe_name(benchmark_name)
49
  col_name = benchmark_name
50
+ if "test" not in dataset_list[dataset]["splits"]:
51
+ continue
52
+ for metric in METRIC_LIST:
53
  benchmark_dict[benchmark_name] = Benchmark(
54
  benchmark_name, metric, col_name, domain, lang, task
55
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  return benchmark_dict
57
 
58
 
59
  _qa_benchmark_dict = {}
60
  for version in BENCHMARK_VERSION_LIST:
61
  safe_version_name = get_safe_name(version)
62
+ _qa_benchmark_dict[safe_version_name] = \
63
+ Enum(
64
+ f"QABenchmarks_{safe_version_name}",
65
+ get_qa_benchmarks_dict(version)
66
+ )
67
 
68
  _doc_benchmark_dict = {}
69
  for version in BENCHMARK_VERSION_LIST:
70
  safe_version_name = get_safe_name(version)
71
+ _doc_benchmark_dict[safe_version_name] = \
72
+ Enum(
73
+ f"LongDocBenchmarks_{safe_version_name}",
74
+ get_doc_benchmarks_dict(version)
75
+ )
76
 
77
 
78
  QABenchmarks = Enum("QABenchmarks", _qa_benchmark_dict)
tests/src/test_benchmarks.py CHANGED
@@ -1,15 +1,49 @@
 
 
1
  from src.benchmarks import LongDocBenchmarks, QABenchmarks
 
 
2
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
- def test_qabenchmarks():
 
 
 
 
 
 
 
 
 
 
5
  for benchmark_list in list(QABenchmarks):
6
- print(benchmark_list.name)
7
- for b in list(benchmark_list.value):
8
- print(b)
9
- qa_benchmarks = QABenchmarks["2404"]
10
- l = list(frozenset([c.value.domain for c in list(qa_benchmarks.value)]))
11
- print(l)
12
 
13
 
14
- def test_longdocbenchmarks():
15
- print(list(LongDocBenchmarks))
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+
3
  from src.benchmarks import LongDocBenchmarks, QABenchmarks
4
+ from src.envs import BENCHMARK_VERSION_LIST
5
+
6
 
7
+ # Ref: https://github.com/AIR-Bench/AIR-Bench/blob/4b27b8a8f2047a963805fcf6fb9d74be51ec440c/docs/available_tasks.md
8
+ # 24.05
9
+ # | Task | dev | test |
10
+ # | ---- | --- | ---- |
11
+ # | Long-Doc | 4 | 11 |
12
+ # | QA | 54 | 53 |
13
+ #
14
+ # 24.04
15
+ # | Task | test |
16
+ # | ---- | ---- |
17
+ # | Long-Doc | 15 |
18
+ # | QA | 13 |
19
 
20
+ @pytest.mark.parametrize(
21
+ "num_datasets_dict",
22
+ [
23
+ {
24
+ "air_bench_2404": 13,
25
+ "air_bench_2405": 53
26
+ }
27
+ ]
28
+ )
29
+ def test_qa_benchmarks(num_datasets_dict):
30
+ assert len(QABenchmarks) == len(BENCHMARK_VERSION_LIST)
31
  for benchmark_list in list(QABenchmarks):
32
+ version_slug = benchmark_list.name
33
+ assert num_datasets_dict[version_slug] == len(benchmark_list.value)
 
 
 
 
34
 
35
 
36
+ @pytest.mark.parametrize(
37
+ "num_datasets_dict",
38
+ [
39
+ {
40
+ "air_bench_2404": 15,
41
+ "air_bench_2405": 11
42
+ }
43
+ ]
44
+ )
45
+ def test_doc_benchmarks(num_datasets_dict):
46
+ assert len(LongDocBenchmarks) == len(BENCHMARK_VERSION_LIST)
47
+ for benchmark_list in list(LongDocBenchmarks):
48
+ version_slug = benchmark_list.name
49
+ assert num_datasets_dict[version_slug] == len(benchmark_list.value)