# run_all.py
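"""Run every (dataset, model, prompt-strategy) combination and save the
per-sample predictions plus aggregate metrics as JSON under results/."""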

import os
import json
from tqdm import tqdm

from models.qwen2_5_math import QwenMathModel
from models.deepseek_qwen import DeepSeekQwenModel
from prompts.cot import apply_cot_answer
from prompts.self_refine import apply_self_refine_answer
from prompts.self_consistency import apply_self_consistency_answer
from prompts.zero_shot import apply_zero_shot_answer
from evaluation.metrics import evaluate_predictions
from concurrent.futures import ThreadPoolExecutor, as_completed

import torch

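# Printed for visibility only; the model wrapper classes are assumed to
# handle their own device placement.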
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

current_dir = os.path.dirname(__file__)

def load_local_dataset(path, max_samples=None):
    """Load a JSON list of samples, optionally truncated to the first max_samples."""
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return data[:max_samples] if max_samples is not None else data


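# Each dataset stores its problem text under a different JSON key.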
def get_field_name(dataset):
    return {
        "GSM8K_test": "question",
        "MATH-500": "problem",
        "AIME_2024": "Problem"
    }.get(dataset, "question")


def save_results(predictions, metrics, dataset, model, prompt_name):
    """Write per-sample predictions and aggregate metrics to results/ as JSON."""
    output_path = os.path.join(current_dir, f"results/{dataset}_{model}_{prompt_name}.json")
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump({
            "predictions": predictions,
            "metrics": metrics
        }, f, ensure_ascii=False, indent=2)


def run_all(max_samples, max_workers=4):
    # Resolve results/ relative to this script so it matches save_results,
    # regardless of the current working directory.
    os.makedirs(os.path.join(current_dir, "results", "logs"), exist_ok=True)

    datasets = ["GSM8K_test", "MATH-500", "AIME_2024"]
    models = {
        # "qwen2.5": QwenMathModel,
        "deepseek": DeepSeekQwenModel
    }
    # Prompting strategies per model; each entry is called as
    # fn(question, model, dataset_name, model_name, qid).
    prompts = {
        "qwen2.5": {
            "cot": apply_cot_answer,
            "self_refine": apply_self_refine_answer,
            # Self-consistency samples n_iter reasoning chains and aggregates the answers.
            "self_consistency": lambda q, model, dataset_name, model_name, qid: apply_self_consistency_answer(
                q, model, dataset_name, model_name, qid, n_iter=5
            ),
        },
        "deepseek": {
            "zero_shot": apply_zero_shot_answer
        }
    }

    for dataset in datasets:
        data_path = os.path.join(current_dir, f"data/{dataset}.json")
        data = load_local_dataset(data_path, max_samples)
        field_name = get_field_name(dataset)

        for model_name, model_cls in models.items():
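            # Instantiating the wrapper loads the model weights; note this
            # happens once per dataset iteration.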
            model = model_cls()

            model_prompts = prompts.get(model_name, {})
            for prompt_name, prompt_fn in model_prompts.items():
                print(f"\nRunning: {dataset} | {model_name} | {prompt_name}")

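                # Threaded fan-out assumes the prompt functions and model calls are
                # thread-safe (e.g. a served endpoint); for a single in-process GPU
                # model, keep max_workers small.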
                predictions_dict = {}
                with ThreadPoolExecutor(max_workers=max_workers) as executor:
                    future_to_qid = {
                        executor.submit(prompt_fn, sample[field_name], model, dataset, model_name, qid): qid
                        for qid, sample in enumerate(data)
                    }
                    for future in tqdm(as_completed(future_to_qid), total=len(future_to_qid)):
                        qid = future_to_qid[future]
                        try:
                            predictions_dict[qid] = future.result()
                        except Exception as e:
                            print(f"[Error] Sample {qid}: {e}")
                            predictions_dict[qid] = "Error"

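                # Futures complete out of order; rebuild the prediction list in dataset order.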
                predictions = [predictions_dict[qid] for qid in range(len(data))]

                metrics, results = evaluate_predictions(predictions, data, dataset, prompt_name, field_name)
                print("Result:", metrics)
                save_results(results, metrics, dataset, model_name, prompt_name)


if __name__ == "__main__":
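    # max_samples=None evaluates the full datasets; pass a small int for a quick smoke test.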
    run_all(max_samples=None)