"""Gradio demo that scores each reasoning step of a response with the Skywork o1 PRM.

The inference helpers (``PRM_MODEL`` and the input-preparation utilities) come
from the SkyworkAI/skywork-o1-prm-inference repository, which is cloned into the
current working directory on first run and appended to ``sys.path``.
"""

import os
import subprocess
import sys

import torch
import gradio as gr
from transformers import AutoTokenizer

# Local checkout of the inference repository (path relative to the current
# working directory) and the URL it is cloned from.
REPO_DIR = "skywork-o1-prm-inference"
REPO_URL = "https://github.com/SkyworkAI/skywork-o1-prm-inference.git"

# Device for the model and every input tensor; defined once so the model and
# its inputs cannot drift onto different devices.
DEVICE = "cpu"


def setup_environment():
    """Make the inference repository available and importable.

    Clones ``REPO_URL`` if ``REPO_DIR`` does not exist yet, then appends the
    absolute path of ``REPO_DIR`` to ``sys.path`` if it is not already there.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` exits with a non-zero status.
    """
    if not os.path.exists(REPO_DIR):
        print("Cloning repository...")
        subprocess.run(["git", "clone", REPO_URL], check=True)
    repo_path = os.path.abspath(REPO_DIR)
    if repo_path not in sys.path:
        sys.path.append(repo_path)
        print(f"Added {repo_path} to Python path")


setup_environment()

# These modules live in the cloned repository, so they can only be imported
# after setup_environment() has put it on sys.path.
from model_utils.prm_model import PRM_MODEL  # noqa: E402
from model_utils.io_utils import (  # noqa: E402
    prepare_input,
    prepare_batch_input_for_model,
    derive_step_rewards,
)

# Hugging Face Hub model ID
model_id = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B"

# Load the tokenizer and model once at import time; the model is switched to eval mode.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = PRM_MODEL.from_pretrained(model_id).to(DEVICE).eval()


def compute_rewards(problem, response):
    """Score ``response`` to ``problem`` with the PRM and return its step rewards.

    Builds a batch of one example, runs the model without gradient tracking,
    and converts the model's reward output into per-step rewards with
    ``derive_step_rewards``.

    Args:
        problem: The problem statement text.
        response: The candidate solution text to be scored.

    Returns:
        The first (and only) entry of ``derive_step_rewards``'s result, i.e. the
        step rewards for this single example. Presumably a list of floats;
        it must be JSON-serializable because it is shown in a ``gr.JSON``
        component — TODO confirm against io_utils.
    """
    # step_token="\n" is passed to prepare_input, presumably so each line of
    # the response is treated as one reasoning step.
    processed_data = [
        prepare_input(problem, response, tokenizer=tokenizer, step_token="\n")
    ]
    # Transpose the per-example tuples into parallel per-field tuples;
    # the step texts themselves are not needed here.
    input_ids, _steps, reward_flags = zip(*processed_data)

    # Pad into batch tensors (pad_token_id is used for padding).
    input_ids, attention_mask, reward_flags = prepare_batch_input_for_model(
        input_ids, reward_flags, tokenizer.pad_token_id
    )
    input_ids = input_ids.to(DEVICE)
    attention_mask = attention_mask.to(DEVICE)
    # reward_flags is only moved when the helper returned it as a tensor.
    if isinstance(reward_flags, torch.Tensor):
        reward_flags = reward_flags.to(DEVICE)

    with torch.no_grad():
        # Only the third output (the rewards) is used.
        _, _, rewards = model(
            input_ids=input_ids, attention_mask=attention_mask, return_probs=True
        )

    step_rewards = derive_step_rewards(rewards, reward_flags)
    # Batch of one: return the rewards of the single example.
    return step_rewards[0]


# Default contents of the input boxes.
DEFAULT_PROBLEM = (
    "Janet's ducks lay 16 eggs per day. She eats three for breakfast every "
    "morning and bakes muffins for her friends every day with four. She sells "
    "the remainder at the farmers' market daily for $2 per fresh duck egg. "
    "How much in dollars does she make every day at the farmers' market?"
)
DEFAULT_RESPONSE = "To determine how much money Janet makes every day at the farmers' market, we need to follow these steps:\n1. ..."

# Build the Gradio interface.
with gr.Blocks() as demo:
    gr.Markdown("# PRM Reward Calculation")
    problem_input = gr.Textbox(label="Problem", lines=5, value=DEFAULT_PROBLEM)
    response_input = gr.Textbox(label="Response", lines=10, value=DEFAULT_RESPONSE)
    output = gr.JSON(label="Step Rewards")
    submit_btn = gr.Button("Compute Rewards")
    submit_btn.click(
        fn=compute_rewards,
        inputs=[problem_input, response_input],
        outputs=output,
    )

if __name__ == "__main__":
    # Listen on all network interfaces, port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)