import re
from typing import List, Dict, Tuple


def extract_final_answer(prompt_name: str, text: str) -> str:
    """Extract the final answer from a model's generated text."""
    # Self-consistency predictions are assumed to already hold the final answer; return as-is.
    if prompt_name == 'self_consistency':
        return text

    # Prefer the content of the first \boxed{...} expression, tracking nested braces.
    start_tag = r"\boxed{"
    start_idx = text.find(start_tag)

    if start_idx != -1:
        start_idx += len(start_tag)
        brace_level = 1
        end_idx = start_idx

        while end_idx < len(text) and brace_level > 0:
            if text[end_idx] == "{":
                brace_level += 1
            elif text[end_idx] == "}":
                brace_level -= 1
            end_idx += 1

        if brace_level == 0:
            inner_content = text[start_idx:end_idx - 1]
            return inner_content.strip()

    # Otherwise fall back to the last number in the text, or the stripped text itself.
    numbers = re.findall(r"[-+]?\d*\.?\d+", text)
    return numbers[-1] if numbers else text.strip()
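
# Illustrative behavior (toy inputs, not from a real run; "zero_shot_cot" is an arbitrary
# prompt name chosen only because it is not "self_consistency"):
#   extract_final_answer("zero_shot_cot", r"The area is \boxed{\frac{3}{4}}.")  -> r"\frac{3}{4}"
#   extract_final_answer("zero_shot_cot", "The answer is therefore 12")         -> "12"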


def extract_ground_truth(sample: Dict, dataset_name: str) -> str:
    """Extract the reference answer for one sample, handling dataset-specific formats."""
    if dataset_name.lower() == "gsm8k_test":
        # GSM8K solutions end with a line of the form "#### <final answer>".
        answer = sample["answer"]
        match = re.search(r"####\s*([^\n]+)", answer)
        return match.group(1).strip() if match else answer.strip()
    elif dataset_name.lower() == "aime_2024":
        # AIME answers may be stored as integers; normalize to a string for comparison.
        return str(sample["Answer"]).strip()
    else:
        return sample["answer"].strip()


def compute_avg_response_length(predictions: List[str], mode: str = "char") -> float:
    """Average prediction length, in characters ("char") or whitespace-split words otherwise."""
    lengths = [len(p) if mode == "char" else len(p.split()) for p in predictions]
    return sum(lengths) / len(lengths)
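
# For example, compute_avg_response_length(["abc", "de"]) == 2.5 in character mode;
# any other mode value falls through to whitespace-split word counts.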


def evaluate_predictions(predictions: List[str], dataset: List[Dict], dataset_name: str,
                         prompt_name: str, field_name: str,
                         mode: str = "char") -> Tuple[Dict, List[Dict]]:
    """Score predictions by exact string match against ground truth and collect per-sample records."""
    total = len(predictions)
    correct = 0
    per_sample_results = []

    for i, (pred, sample) in enumerate(zip(predictions, dataset)):
        gt_ans = extract_ground_truth(sample, dataset_name)
        pred_ans = extract_final_answer(prompt_name, pred)

        is_correct = int(pred_ans == gt_ans)
        correct += is_correct

        per_sample_results.append({
            "qid": i,
            "question": sample[field_name],
            "correct_answer": gt_ans,
            "final_answer": pred_ans,
            "correct": is_correct,
            "prediction": pred,
        })

    accuracy = correct / total
    avg_len = compute_avg_response_length(predictions, mode)

    metrics = {
        "accuracy": round(accuracy * 100, 2),          # percentage, two decimals
        "avg_response_length": round(avg_len, 2),
    }

    return metrics, per_sample_results
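

if __name__ == "__main__":
    # Minimal smoke test. The samples and predictions below are toy values invented for
    # illustration (GSM8K-style "#### <answer>" ground truths); "cot" is an arbitrary
    # prompt name, chosen only because it is not "self_consistency".
    toy_dataset = [
        {"question": "What is 3 + 4?", "answer": "3 + 4 = 7\n#### 7"},
        {"question": "What is 10 - 2?", "answer": "10 - 2 = 8\n#### 8"},
    ]
    toy_predictions = [
        r"Adding the numbers gives \boxed{7}.",
        "Subtracting gives 9.",  # deliberately wrong; falls back to the last number
    ]
    metrics, per_sample = evaluate_predictions(
        toy_predictions, toy_dataset, "gsm8k_test", "cot", "question"
    )
    print(metrics)          # expected accuracy: 50.0
    print(per_sample[0])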