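"""Run prompting-strategy experiments on local math benchmarks.

For each (dataset, model, prompt) combination the questions are answered in
parallel worker threads, scored with evaluation.metrics.evaluate_predictions,
and saved to results/<dataset>_<model>_<prompt>.json.
"""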
import os
import json
from concurrent.futures import ThreadPoolExecutor, as_completed

import torch
from tqdm import tqdm

from models.qwen2_5_math import QwenMathModel
from models.deepseek_qwen import DeepSeekQwenModel
from prompts.cot import apply_cot_answer
from prompts.self_refine import apply_self_refine_answer
from prompts.self_consistency import apply_self_consistency_answer
from prompts.zero_shot import apply_zero_shot_answer
from evaluation.metrics import evaluate_predictions

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

current_dir = os.path.dirname(os.path.abspath(__file__))

def load_local_dataset(path, max_samples=None):
    """Load a JSON dataset from disk, optionally truncated to max_samples items."""
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return data[:max_samples] if max_samples else data

def get_field_name(dataset):
    """Return the JSON key that holds the question text for a given dataset."""
    return {
        "GSM8K_test": "question",
        "MATH-500": "problem",
        "AIME_2024": "Problem"
    }.get(dataset, "question")

def save_results(predictions, metrics, dataset, model, prompt_name):
    """Write per-sample predictions and aggregate metrics to the results directory."""
    data_file_path = f"results/{dataset}_{model}_{prompt_name}.json"
    output_path = os.path.join(current_dir, data_file_path)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump({
            "predictions": predictions,
            "metrics": metrics
        }, f, ensure_ascii=False, indent=2)

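# Each prompt function is invoked as fn(question, model, dataset_name, model_name, qid);
# its return value is recorded as the prediction for that question.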
def run_all(max_samples, max_workers=4):
    """Run every (dataset, model, prompt) combination and save predictions and metrics."""
    os.makedirs(os.path.join(current_dir, "results", "logs"), exist_ok=True)

    datasets = ["GSM8K_test", "MATH-500", "AIME_2024"]
    models = {
        "qwen2.5": QwenMathModel,
        "deepseek": DeepSeekQwenModel
    }
    prompts = {
        "qwen2.5": {
            "cot": apply_cot_answer,
            "self_refine": apply_self_refine_answer,
            "self_consistency": lambda q, model, dataset_name, model_name, qid: apply_self_consistency_answer(
                q, model, dataset_name, model_name, qid, n_iter=5
            )
        },
        "deepseek": {
            "zero_shot": apply_zero_shot_answer
        }
    }

    for dataset in datasets:
        data_path = os.path.join(current_dir, f"data/{dataset}.json")
        data = load_local_dataset(data_path, max_samples)
        field_name = get_field_name(dataset)

        for model_name, model_cls in models.items():
            model = model_cls()

            model_prompts = prompts.get(model_name, {})
            for prompt_name, prompt_fn in model_prompts.items():
                print(f"\nRunning: {dataset} | {model_name} | {prompt_name}")

                # Answer every question concurrently, keyed by question id so the
                # original dataset order can be restored afterwards.
                predictions_dict = {}
                with ThreadPoolExecutor(max_workers=max_workers) as executor:
                    future_to_qid = {
                        executor.submit(prompt_fn, sample[field_name], model, dataset, model_name, qid): qid
                        for qid, sample in enumerate(data)
                    }
                    for future in tqdm(as_completed(future_to_qid), total=len(future_to_qid)):
                        qid = future_to_qid[future]
                        try:
                            predictions_dict[qid] = future.result()
                        except Exception as e:
                            print(f"[Error] Sample {qid}: {e}")
                            predictions_dict[qid] = "Error"

                # Re-order predictions to match the dataset order before scoring.
                predictions = [predictions_dict[qid] for qid in range(len(data))]

                metrics, results = evaluate_predictions(predictions, data, dataset, prompt_name, field_name)
                print("Result:", metrics)
                save_results(results, metrics, dataset, model_name, prompt_name)

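# run_all(max_samples=None) processes every sample; a small value such as
# max_samples=20 gives a quick end-to-end check.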
if __name__ == "__main__":
    run_all(max_samples=None)