# evaluation/metrics.py
import re
from typing import List, Dict, Tuple
def extract_final_answer(prompt_name: str, text: str) -> str:
    """Extract the model's final answer from generated text."""
    if prompt_name == 'self_consistency':
        return text

    # Manually match the full content of \boxed{}, tracking nested braces.
    start_tag = r"\boxed{"
    start_idx = text.find(start_tag)
    if start_idx != -1:
        start_idx += len(start_tag)  # move to the start of the boxed content
        brace_level = 1
        end_idx = start_idx
        while end_idx < len(text) and brace_level > 0:
            if text[end_idx] == "{":
                brace_level += 1
            elif text[end_idx] == "}":
                brace_level -= 1
            end_idx += 1
        if brace_level == 0:  # found the matching closing brace
            inner_content = text[start_idx:end_idx - 1]  # drop the trailing "}"
            return inner_content.strip()

    # No \boxed{} found (or unbalanced braces): fall back to the last number in the text.
    numbers = re.findall(r"[-+]?\d*\.?\d+", text)
    return numbers[-1] if numbers else text.strip()
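# Illustrative example (assumed call, not from a test run; "cot" stands in for any
# prompt name other than 'self_consistency'): nested braces inside \boxed{} are
# preserved by the brace-counting loop above, e.g.
#   extract_final_answer("cot", r"... the answer is \boxed{\frac{1}{2}}")
# returns r"\frac{1}{2}", while text with no \boxed{} falls back to the last number
# it contains.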
def extract_ground_truth(sample: Dict, dataset_name: str) -> str:
    """Extract the gold answer string from a dataset sample."""
    if dataset_name.lower() == "gsm8k_test":
        # GSM8K answers end with "#### <final answer>".
        answer = sample["answer"]
        match = re.search(r"####\s*([^\n]+)", answer)
        return match.group(1).strip() if match else answer.strip()
    elif dataset_name.lower() == "aime_2024":
        # The answer field may be stored as a number, so cast to str for comparison.
        return str(sample["Answer"]).strip()
    else:
        return sample["answer"].strip()
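# Illustrative example (assumed GSM8K-style sample): the gold answer follows "####", e.g.
#   extract_ground_truth({"answer": "... 18 * 4 = 72\n#### 72"}, "gsm8k_test")
# returns "72".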
def compute_avg_response_length(predictions: List[str], mode: str = "char") -> float:
    """Average response length in characters ("char") or whitespace-split tokens otherwise."""
    if not predictions:
        return 0.0
    lengths = [len(p) if mode == "char" else len(p.split()) for p in predictions]
    return sum(lengths) / len(lengths)
def evaluate_predictions(
    predictions: List[str],
    dataset: List[Dict],
    dataset_name: str,
    prompt_name: str,
    field_name: str,
    mode: str = "char",
) -> Tuple[Dict, List[Dict]]:
    """Score predictions against the dataset and return (metrics, per-sample results)."""
    total = len(predictions)
    correct = 0
    per_sample_results = []
    for i, (pred, sample) in enumerate(zip(predictions, dataset)):
        gt_ans = extract_ground_truth(sample, dataset_name)
        pred_ans = extract_final_answer(prompt_name, pred)
        is_correct = int(pred_ans == gt_ans)  # exact string match
        correct += is_correct
        per_sample_results.append({
            "qid": i,
            "question": sample[field_name],
            "correct_answer": gt_ans,
            "final_answer": pred_ans,
            "correct": is_correct,
            "prediction": pred,
        })
    accuracy = correct / total if total else 0.0
    avg_len = compute_avg_response_length(predictions, mode)
    metrics = {
        "accuracy": round(accuracy * 100, 2),
        "avg_response_length": round(avg_len, 2),
    }
    return metrics, per_sample_results
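

# Minimal usage sketch (illustrative only; the sample, prediction, and the "cot"
# prompt name below are assumptions for demonstration, not real evaluation data).
if __name__ == "__main__":
    demo_dataset = [
        {"question": "What is 2 + 3?", "answer": "2 + 3 = 5\n#### 5"},
    ]
    demo_predictions = ["2 + 3 = 5, so the answer is \\boxed{5}."]
    metrics, per_sample = evaluate_predictions(
        demo_predictions,
        demo_dataset,
        dataset_name="gsm8k_test",
        prompt_name="cot",
        field_name="question",
        mode="char",
    )
    print(metrics)                         # accuracy should be 100.0 for this toy example
    print(per_sample[0]["final_answer"])   # "5"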