# run_all.py
import os
import json
from concurrent.futures import ThreadPoolExecutor, as_completed

import torch
from tqdm import tqdm

from models.qwen2_5_math import QwenMathModel
from models.deepseek_qwen import DeepSeekQwenModel
from prompts.cot import apply_cot_answer
from prompts.self_refine import apply_self_refine_answer
from prompts.self_consistency import apply_self_consistency_answer
from prompts.zero_shot import apply_zero_shot_answer
from evaluation.metrics import evaluate_predictions

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
current_dir = os.path.dirname(__file__)
def load_local_dataset(path, max_samples=None):
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return data[:max_samples] if max_samples else data
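# Each benchmark stores the question text under a different key; default to "question".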
def get_field_name(dataset):
    return {
        "GSM8K_test": "question",
        "MATH-500": "problem",
        "AIME_2024": "Problem"
    }.get(dataset, "question")
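# Persist per-sample predictions and aggregate metrics for one
# (dataset, model, prompt) combination as a single JSON file under results/.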
def save_results(predictions, metrics, dataset, model, prompt_name):
    data_file_path = f"results/{dataset}_{model}_{prompt_name}.json"
    output_path = os.path.join(current_dir, data_file_path)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump({
            "predictions": predictions,
            "metrics": metrics
        }, f, ensure_ascii=False, indent=2)
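# Run every (dataset, model, prompt strategy) combination, dispatching samples to a
# thread pool, then evaluate and save the results for each combination.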
def run_all(max_samples, max_workers=4):
    os.makedirs(os.path.join(current_dir, "results/logs"), exist_ok=True)
    datasets = ["GSM8K_test", "MATH-500", "AIME_2024"]
    models = {
        # "qwen2.5": QwenMathModel,
        "deepseek": DeepSeekQwenModel
    }
    # Each prompt function takes (question, model, dataset_name, model_name, qid).
    prompts = {
        "qwen2.5": {
            "cot": apply_cot_answer,
            "self_refine": apply_self_refine_answer,
            "self_consistency": lambda q, model, dataset_name, model_name, qid: apply_self_consistency_answer(
                q, model, dataset_name, model_name, qid, n_iter=5
            )
        },
        "deepseek": {
            "zero_shot": apply_zero_shot_answer
        }
    }
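    # Sweep: load each dataset once, then run every model/prompt pair over it.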
    for dataset in datasets:
        data_path = os.path.join(current_dir, f"data/{dataset}.json")
        data = load_local_dataset(data_path, max_samples)
        field_name = get_field_name(dataset)
        for model_name, model_cls in models.items():
            model = model_cls()
            model_prompts = prompts.get(model_name, {})
            for prompt_name, prompt_fn in model_prompts.items():
                print(f"\nRunning: {dataset} | {model_name} | {prompt_name}")
                predictions_dict = {}
                with ThreadPoolExecutor(max_workers=max_workers) as executor:
                    future_to_qid = {
                        executor.submit(prompt_fn, sample[field_name], model, dataset, model_name, qid): qid
                        for qid, sample in enumerate(data)
                    }
                    for future in tqdm(as_completed(future_to_qid), total=len(future_to_qid)):
                        qid = future_to_qid[future]
                        try:
                            predictions_dict[qid] = future.result()
                        except Exception as e:
                            print(f"[Error] Sample {qid}: {e}")
                            predictions_dict[qid] = "Error"
                # Restore the original sample order before scoring.
                predictions = [predictions_dict[qid] for qid in range(len(data))]
                metrics, results = evaluate_predictions(predictions, data, dataset, prompt_name, field_name)
                print("Result:", metrics)
                save_results(results, metrics, dataset, model_name, prompt_name)
if __name__ == "__main__":
    run_all(max_samples=None)
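    # Hypothetical quick smoke test: limit each dataset to its first 20 samples
    # and use 4 worker threads (example values, not part of the original run).
    # run_all(max_samples=20, max_workers=4)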