# evaluation/metrics.py
import re
from typing import List, Dict, Tuple


def extract_final_answer(prompt_name: str, text: str) -> str:
    """Extract the model's final answer from a raw generation."""
    if prompt_name == "self_consistency":
        # Self-consistency aggregates over full generations elsewhere.
        return text
    # Manually match the full contents of the first \boxed{...},
    # tracking brace depth so nested braces are handled correctly.
    start_tag = r"\boxed{"
    start_idx = text.find(start_tag)
    if start_idx != -1:
        start_idx += len(start_tag)  # position of the inner content
        brace_level = 1
        end_idx = start_idx
        while end_idx < len(text) and brace_level > 0:
            if text[end_idx] == "{":
                brace_level += 1
            elif text[end_idx] == "}":
                brace_level -= 1
            end_idx += 1
        if brace_level == 0:  # found the matching closing brace
            inner_content = text[start_idx:end_idx - 1]  # drop the closing brace
            return inner_content.strip()
    # No \boxed{} found: fall back to the last number in the text.
    numbers = re.findall(r"[-+]?\d*\.?\d+", text)
    return numbers[-1] if numbers else text.strip()


def extract_ground_truth(sample: Dict, dataset_name: str) -> str:
    """Pull the reference answer out of a dataset sample."""
    if dataset_name.lower() == "gsm8k_test":
        # GSM8K answers end with "#### <final answer>".
        answer = sample["answer"]
        match = re.search(r"####\s*([^\n]+)", answer)
        return match.group(1).strip() if match else answer.strip()
    elif dataset_name.lower() == "aime_2024":
        # AIME answers may be stored as integers; normalize to str.
        return str(sample["Answer"]).strip()
    else:
        return sample["answer"].strip()


def compute_avg_response_length(predictions: List[str], mode: str = "char") -> float:
    """Average prediction length in characters ("char") or whitespace tokens."""
    if not predictions:
        return 0.0
    lengths = [len(p) if mode == "char" else len(p.split()) for p in predictions]
    return sum(lengths) / len(lengths)


def evaluate_predictions(
    predictions: List[str],
    dataset: List[Dict],
    dataset_name: str,
    prompt_name: str,
    field_name: str,
    mode: str = "char",
) -> Tuple[Dict, List[Dict]]:
    """Score predictions by exact string match against the extracted ground truth."""
    total = len(predictions)
    correct = 0
    per_sample_results = []
    for i, (pred, sample) in enumerate(zip(predictions, dataset)):
        gt_ans = extract_ground_truth(sample, dataset_name)
        pred_ans = extract_final_answer(prompt_name, pred)
        # Exact string comparison; no numeric normalization is applied.
        is_correct = int(pred_ans == gt_ans)
        correct += is_correct
        per_sample_results.append({
            "qid": i,
            "question": sample[field_name],
            "correct_answer": gt_ans,
            "final_answer": pred_ans,
            "correct": is_correct,
            "prediction": pred,
        })
    accuracy = correct / total if total else 0.0
    avg_len = compute_avg_response_length(predictions, mode)
    metrics = {
        "accuracy": round(accuracy * 100, 2),
        "avg_response_length": round(avg_len, 2),
    }
    return metrics, per_sample_results
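

# --- Minimal usage sketch (illustrative only) --------------------------------
# A quick smoke test showing how the functions fit together. The toy dataset,
# prediction strings, and prompt name "cot" below are hypothetical; real inputs
# come from the evaluation pipeline. Run with: python evaluation/metrics.py
if __name__ == "__main__":
    toy_dataset = [
        {"question": "What is 2 + 3?", "answer": "Add the numbers.\n#### 5"},
        {"question": "What is 4 * 4?", "answer": "Multiply.\n#### 16"},
    ]
    toy_predictions = [
        r"2 + 3 = 5, so the answer is \boxed{5}.",
        "4 * 4 gives 15.",  # wrong, and no \boxed{}: falls back to last number "15"
    ]
    metrics, results = evaluate_predictions(
        predictions=toy_predictions,
        dataset=toy_dataset,
        dataset_name="gsm8k_test",
        prompt_name="cot",  # any name other than "self_consistency"
        field_name="question",
    )
    print(metrics)  # expected: {'accuracy': 50.0, 'avg_response_length': ...}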