djstrong committed
Commit b9262b0 · 1 Parent(s): b274663

Add calc_avg.py for average score calculation and refactor task retrieval in about.py

Files changed (2)
  1. calc_avg.py +103 -0
  2. src/about.py +9 -4
calc_avg.py ADDED
@@ -0,0 +1,103 @@
+ import glob
+ import json
+ import argparse
+ import sys
+ from dataclasses import dataclass
+ from enum import Enum
+ import csv
+
+ # Local copy of the Task dataclass; the Tasks enum actually used is imported from src.about below.
+ @dataclass(frozen=True)
+ class Task:
+     benchmark: str
+     metric: str
+     col_name: str
+     type: str
+     baseline: float = 0.0
+
+ from src.about import Tasks, get_tasks
+
+ g_tasks, mc_tasks, rag_tasks, all_tasks = get_tasks()
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description='Calculate average scores from JSON result files')
+     parser.add_argument('json', type=str, help='Path to a JSON file with scores, or a directory to search recursively')
+     parser.add_argument('--header', action='store_true', help='Print header')
+     parser.add_argument('-d', '--delimiter', type=str, default=',', help='Delimiter for CSV output')
+     args = parser.parse_args()
+
+     # A single results file, or every results*.json under the given directory.
+     if args.json.endswith('.json'):
+         paths = [args.json]
+     else:
+         paths = glob.glob(args.json + '/**/results*.json', recursive=True)
+
+     print(paths)
+
+     # Collect per-benchmark scores; non-perplexity, non-eqbench metrics are rescaled to 0-100.
+     results = {}
+     for path in paths:
+         print(path)
+         data = json.load(open(path))
+
+         for task in Tasks:
+             try:
+                 results[task.value.benchmark] = data['results'][task.value.benchmark][task.value.metric]
+                 if 'perplexity' not in task.value.metric and 'eqbench' not in task.value.metric:
+                     results[task.value.benchmark] *= 100
+             except KeyError:
+                 print(f'No data for {task.value.benchmark}', file=sys.stderr)
+
+     print(results)
+     all_tasks_wo_polqa = [task for task in all_tasks if 'polqa' not in task]
+
+     # Baselines are stored as fractions in src.about; convert them to the 0-100 scale used above.
+     baselines = {task.value.benchmark: task.value.baseline * 100 for task in Tasks}
+     print(baselines)
+
+     # Old average: plain mean over all tasks except polqa, with no baseline correction.
+     average_old = sum([v for task, v in results.items() if v is not None and task in all_tasks_wo_polqa]) / len(
+         all_tasks_wo_polqa)
+
+     # New averages: each score is rescaled against its baseline, so baseline -> 0 and a perfect score -> 100.
+     average = sum(
+         [(results.get(task, 0) - baselines.get(task, 0)) / (100 - baselines.get(task, 0)) * 100 for task in
+          all_tasks]) / len(all_tasks)
+
+     for task in all_tasks:
+         print(task, results.get(task, 0), baselines.get(task, 0))
+
+     average_g = sum(
+         [(results.get(task, 0) - baselines.get(task, 0)) / (100 - baselines.get(task, 0)) * 100 for task in
+          g_tasks]) / len(g_tasks)
+     average_mc = sum(
+         [(results.get(task, 0) - baselines.get(task, 0)) / (100 - baselines.get(task, 0)) * 100 for task in
+          mc_tasks]) / len(mc_tasks)
+     average_rag = sum(
+         [(results.get(task, 0) - baselines.get(task, 0)) / (100 - baselines.get(task, 0)) * 100 for task in
+          rag_tasks]) / len(rag_tasks)
+
+     row = [args.json, None, average, average_old, average_g, average_mc, average_rag]
+     for task in Tasks:
+         row.append(results.get(task.value.benchmark, None))
+
+     # Print the header row; its columns match the row built above.
+     if args.header:
+         csv.writer(sys.stdout, delimiter=args.delimiter).writerow(['file', 'name', 'average', 'average_old', 'average_g', 'average_mc', 'average_rag'] + [task.value.benchmark for task in Tasks])
+     csv.writer(sys.stdout, delimiter=args.delimiter).writerow(row)
+
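Each average above rescales a task score against that task's baseline, so a model at the baseline maps to 0 and a perfect score maps to 100. A minimal, self-contained sketch of that normalization; the task name and numbers here are hypothetical, for illustration only:

# Baseline normalization as used in calc_avg.py (0-100 scale).
def normalized_average(scores, baselines, tasks):
    # (score - baseline) / (100 - baseline) * 100, averaged over the task list
    rescaled = [(scores.get(t, 0) - baselines.get(t, 0)) / (100 - baselines.get(t, 0)) * 100 for t in tasks]
    return sum(rescaled) / len(tasks)

# Hypothetical 4-choice task: baseline 25.0, raw score 62.5 -> normalized 50.0
print(normalized_average({'example_mc_task': 62.5}, {'example_mc_task': 25.0}, ['example_mc_task']))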
src/about.py CHANGED
@@ -46,10 +46,15 @@ class Tasks(Enum):
      task30 = Task("polish_pes", "exact_match,score-first", "pes", "other", 0.2)
 
 
- g_tasks = [task.value.benchmark for task in Tasks if task.value.type == "generate_until"]
- mc_tasks = [task.value.benchmark for task in Tasks if task.value.type == "multiple_choice"]
- rag_tasks = ['polish_polqa_reranking_multiple_choice', 'polish_polqa_open_book', 'polish_poquad_open_book']
- all_tasks = g_tasks + mc_tasks
+ def get_tasks():
+     g_tasks = [task.value.benchmark for task in Tasks if task.value.type == "generate_until"]
+     mc_tasks = [task.value.benchmark for task in Tasks if task.value.type == "multiple_choice"]
+     rag_tasks = ['polish_polqa_reranking_multiple_choice', 'polish_polqa_open_book', 'polish_poquad_open_book']
+     all_tasks = g_tasks + mc_tasks
+     return g_tasks, mc_tasks, rag_tasks, all_tasks
+
+ g_tasks, mc_tasks, rag_tasks, all_tasks = get_tasks()
+
 
  NUM_FEWSHOT = 0 # Change with your few shot
  # ---------------------------------------------------
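
The refactor lets other scripts reuse the task groupings through get_tasks() instead of relying on module-level globals; calc_avg.py above does exactly that. A minimal usage sketch, assuming src/ is importable from the repository root:

from src.about import Tasks, get_tasks

# Benchmark names grouped by task type, plus the combined list used for averaging.
g_tasks, mc_tasks, rag_tasks, all_tasks = get_tasks()
print(f'{len(all_tasks)} scored tasks, {len(rag_tasks)} RAG tasks')

# Per-benchmark baselines on the 0-100 scale, as calc_avg.py builds them.
baselines = {task.value.benchmark: task.value.baseline * 100 for task in Tasks}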