# evaluation/metrics.py

import re
from typing import List, Dict, Tuple

def extract_final_answer(prompt_name: str, text: str) -> str:
    """Extract the model's final answer from a generation."""
    if prompt_name == 'self_consistency':
        return text
    
    # Manually scan for the full content of \boxed{}
    start_tag = r"\boxed{"
    start_idx = text.find(start_tag)
    
    if start_idx != -1:
        start_idx += len(start_tag)  # move to the start of the boxed content
        brace_level = 1
        end_idx = start_idx
        
        while end_idx < len(text) and brace_level > 0:
            if text[end_idx] == "{":
                brace_level += 1
            elif text[end_idx] == "}":
                brace_level -= 1
            end_idx += 1
        
        if brace_level == 0:  # matching closing brace found
            inner_content = text[start_idx:end_idx - 1]  # drop the trailing closing brace
            return inner_content.strip()
    
    # If no \boxed{} was found, fall back to extracting the last number
    numbers = re.findall(r"[-+]?\d*\.?\d+", text)
    return numbers[-1] if numbers else text.strip()
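
# Illustrative behavior of the extractor above (hypothetical inputs, not drawn
# from any evaluation set): the brace counter keeps nested braces intact, and
# plain text falls back to the last number found.
#   extract_final_answer("zero_shot", r"So \boxed{\frac{1}{2}}")  -> r"\frac{1}{2}"
#   extract_final_answer("zero_shot", "The answer is 42.")        -> "42"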

def extract_ground_truth(sample: Dict, dataset_name: str) -> str:
    """Extract the reference answer for a sample from the given dataset."""
    if dataset_name.lower() == "gsm8k_test":
        answer = sample["answer"]
        match = re.search(r"####\s*([^\n]+)", answer)
        return match.group(1).strip() if match else answer.strip()
    elif dataset_name.lower() == "aime_2024":
        return sample["Answer"]
    else:
        return sample["answer"].strip()

def compute_avg_response_length(predictions: List[str], mode: str = "char") -> float:
    """Average prediction length in characters ("char") or whitespace-split words."""
    if not predictions:
        return 0.0
    lengths = [len(p) if mode == "char" else len(p.split()) for p in predictions]
    return sum(lengths) / len(lengths)
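
# For example (illustrative values):
#   compute_avg_response_length(["abc", "de"])                -> 2.5  (characters)
#   compute_avg_response_length(["two words", "one"], "word") -> 1.5  (words)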

def evaluate_predictions(
    predictions: List[str],
    dataset: List[Dict],
    dataset_name: str,
    prompt_name: str,
    field_name: str,
    mode: str = "char",
) -> Tuple[Dict, List[Dict]]:
    """Score predictions against a dataset and return (metrics, per-sample results)."""
    total = len(predictions)
    correct = 0
    per_sample_results = []

    for i, (pred, sample) in enumerate(zip(predictions, dataset)):
        gt_ans = extract_ground_truth(sample, dataset_name)
        pred_ans = extract_final_answer(prompt_name, pred)

        is_correct = int(pred_ans == gt_ans)  # exact string match after extraction
        correct += is_correct

        per_sample_results.append({
            "qid": i,
            "question": sample[field_name],
            "correct_answer": gt_ans,
            "final_answer": pred_ans,
            "correct": is_correct,
            "prediction": pred
        })

    accuracy = correct / total if total else 0.0  # guard against an empty prediction list
    avg_len = compute_avg_response_length(predictions, mode)

    metrics = {
        "accuracy": round(accuracy * 100, 2),
        "avg_response_length": round(avg_len, 2)
    }

    return metrics, per_sample_results
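
# Minimal smoke test on a synthetic GSM8K-style sample (illustrative data only,
# not part of any real evaluation set).
if __name__ == "__main__":
    demo_dataset = [{"question": "What is 2 + 2?", "answer": "2 + 2 = 4\n#### 4"}]
    demo_predictions = [r"The sum is \boxed{4}."]
    demo_metrics, demo_results = evaluate_predictions(
        demo_predictions, demo_dataset, "gsm8k_test", "zero_shot", "question"
    )
    print(demo_metrics)                     # {'accuracy': 100.0, 'avg_response_length': 21.0}
    print(demo_results[0]["final_answer"])  # 4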