CoralLeiCN committed on
Commit
ccf895d
·
1 Parent(s): 646e8e8

Add scoring functions for evaluating model answers against ground truth

Browse files
Files changed (1) hide show
  1. agent/score.py +105 -0
agent/score.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://huggingface.co/spaces/gaia-benchmark/leaderboard/blob/main/scorer.py
2
+ import json
3
+ import re
4
+ import string
5
+ import warnings
6
+
7
+ import numpy as np
8
+
9
+
10
def normalize_number_str(number_str: str) -> float:
    """Convert a textual number to ``float``, ignoring ``$``, ``%`` and ``,``.

    Parameters:
    - number_str: the text to convert. Non-string values are converted with
      ``str()`` first, so an answer already given as an int/float still works.
    Returns:
    - float, the parsed value, or ``float("inf")`` when the cleaned text is not
      a valid number (a diagnostic line is printed in that case).
    """
    # Coerce up front: a numeric (non-str) answer has no .replace() and would
    # otherwise raise AttributeError instead of being scored.
    cleaned = str(number_str)

    # Strip common unit symbols and thousands separators so float() can parse.
    for char in ("$", "%", ","):
        cleaned = cleaned.replace(char, "")

    try:
        return float(cleaned)
    except ValueError:
        print(f"String {cleaned} cannot be normalized to number str.")
        # Sentinel for "not a number"; callers compare it against the parsed
        # ground truth, which normally is a finite value.
        return float("inf")
20
+
21
+
22
def split_string(
    s: str,
    char_list: "list[str] | None" = None,
) -> list[str]:
    """Split ``s`` on any of the separator characters in ``char_list``.

    Parameters:
    - s: str, the string to split
    - char_list: separator characters; every character of every entry is
      treated as a single-character separator. Defaults to ``","`` and ``";"``.
    Returns:
    - list[str], the pieces of ``s`` between separators (pieces are not
      stripped; adjacent separators produce empty strings). If no separator
      characters are given, the result is ``[s]``.
    """
    # Use None as the default instead of a list literal so no mutable object
    # is shared between calls.
    if char_list is None:
        char_list = [",", ";"]

    # Escape each separator: characters such as "^", "-", "]" or "\" have a
    # special meaning inside a regex character class and would otherwise
    # change (or break) the pattern.
    escaped = "".join(re.escape(char) for char in char_list)
    if not escaped:
        # "[]" is not a valid pattern; with nothing to split on, s is one piece.
        return [s]

    return re.split(f"[{escaped}]", s)
28
+
29
+
30
def question_scorer(
    model_answer: str,
    ground_truth: str,
) -> bool:
    """Score a model answer against the ground truth.

    The comparison mode is chosen from the ground truth:
    - if it parses as a float, the answer is normalized with
      ``normalize_number_str`` and compared numerically;
    - otherwise, if it contains ``","`` or ``";"``, both strings are split with
      ``split_string`` and compared element by element (numerically where the
      ground-truth element parses as a float, else as normalized strings with
      punctuation kept);
    - otherwise both are compared as normalized strings (punctuation removed).

    Parameters:
    - model_answer: the answer to score. ``None`` is scored as the text
      ``"None"``; any other non-string value is scored by its ``str()`` text.
    - ground_truth: str, the reference answer
    Returns:
    - bool, True when the answer matches the ground truth
    """

    def is_float(element: object) -> bool:
        # TypeError covers non-string, non-numeric values (e.g. None), which
        # float() rejects with TypeError rather than ValueError.
        try:
            float(element)
        except (TypeError, ValueError):
            return False
        return True

    # Work on text from here on: the list and string branches call string
    # helpers that would raise on a non-string answer.
    model_answer = "None" if model_answer is None else str(model_answer)

    # if gt is a number
    if is_float(ground_truth):
        print(f"Evaluating {model_answer} as a number.")
        return normalize_number_str(model_answer) == float(ground_truth)

    # if gt is a list
    if any(char in ground_truth for char in (",", ";")):
        print(f"Evaluating {model_answer} as a comma separated list.")

        gt_elems = split_string(ground_truth)
        ma_elems = split_string(model_answer)

        # Lists of different lengths can never match.
        if len(gt_elems) != len(ma_elems):
            warnings.warn(
                "Answer lists have different lengths, returning False.", UserWarning
            )
            return False

        # Compare each element as float or str. Every pair is evaluated (no
        # short-circuit) so each unparseable numeric element is still reported.
        comparisons = []
        for ma_elem, gt_elem in zip(ma_elems, gt_elems):
            if is_float(gt_elem):
                comparisons.append(normalize_number_str(ma_elem) == float(gt_elem))
            else:
                # Punctuation is kept here since list elements can include it.
                comparisons.append(
                    normalize_str(ma_elem, remove_punct=False)
                    == normalize_str(gt_elem, remove_punct=False)
                )
        return all(comparisons)

    # if gt is a str
    print(f"Evaluating {model_answer} as a string.")
    return normalize_str(model_answer) == normalize_str(ground_truth)
83
+
84
+
85
def normalize_str(input_str: str, remove_punct: bool = True) -> str:
    """
    Normalize a string by:
    - Removing all white spaces
    - Optionally removing punctuation (if remove_punct is True)
    - Converting to lowercase
    Parameters:
    - input_str: str, the string to normalize (non-string values are converted
      with str() first)
    - remove_punct: bool, whether to remove punctuation (default: True)
    Returns:
    - str, the normalized string
    """
    # Coerce first: re.sub raises TypeError on non-string input such as an int.
    # Removing all white space is required e.g. for "seagull" vs. "sea gull".
    collapsed = re.sub(r"\s", "", str(input_str)).lower()

    if not remove_punct:
        return collapsed

    # Only the ASCII characters in string.punctuation are stripped.
    return collapsed.translate(str.maketrans("", "", string.punctuation))