# RM / app.py
# Hugging Face file-viewer header (page residue, not part of the program):
# last commit "Update app.py" (33e0a2c, verified) by lihongze8; file size 2.9 kB.
import os
import subprocess
import sys
import torch
import gradio as gr
from transformers import AutoTokenizer
def setup_environment():
    """Make the skywork-o1-prm-inference checkout available for import.

    If no ``skywork-o1-prm-inference`` entry exists in the current working
    directory, the repository is cloned there with ``git clone``. The
    absolute path of that directory is then appended to ``sys.path`` unless
    it is already listed.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` exits with a non-zero
            status (``check=True``).
    """
    checkout_dir = "skywork-o1-prm-inference"
    clone_url = "https://github.com/SkyworkAI/skywork-o1-prm-inference.git"

    # Clone only when nothing with that name is present yet.
    already_present = os.path.exists(checkout_dir)
    if not already_present:
        print("Cloning repository...")
        subprocess.run(["git", "clone", clone_url], check=True)

    absolute_dir = os.path.abspath(checkout_dir)
    if absolute_dir in sys.path:
        # Already importable; nothing to add and nothing to report.
        return
    sys.path.append(absolute_dir)
    print(f"Added {absolute_dir} to Python path")
# The inference repository has to be on sys.path before its modules are imported.
setup_environment()

# Import the project modules only after the environment is set up.
# setup_environment() appends the repository's root directory to sys.path, so
# the package is imported by its own name, `model_utils`. The checkout
# directory name contains hyphens and cannot appear in a dotted module path
# (writing it there is a SyntaxError).
from model_utils.prm_model import PRM_MODEL  # noqa: E402
from model_utils.io_utils import (  # noqa: E402
    derive_step_rewards,
    prepare_batch_input_for_model,
    prepare_input,
)

# Model ID passed to both from_pretrained() calls below.
model_id = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B"

# Initialise the tokenizer and the model once at import time; the model is
# moved to the CPU and switched to eval mode.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = PRM_MODEL.from_pretrained(model_id).to("cpu").eval()
def compute_rewards(problem, response):
    """Run the PRM model on one problem/response pair and return its step rewards.

    Args:
        problem: Problem statement text.
        response: Response text to score. A newline character is passed to
            ``prepare_input`` as the step token.

    Returns:
        The first element of what ``derive_step_rewards`` returns, i.e. the
        entry for the single pair in this batch of one (presumably one reward
        per step -- confirm against ``derive_step_rewards``).
    """
    # Prepare the sample. The middle value (named `steps` upstream) is unused.
    sample_ids, _steps, sample_flags = prepare_input(
        problem, response, tokenizer=tokenizer, step_token="\n"
    )

    # Prepare the batch input: a batch holding just this one sample.
    batch_ids, batch_mask, batch_flags = prepare_batch_input_for_model(
        (sample_ids,), (sample_flags,), tokenizer.pad_token_id
    )
    batch_ids = batch_ids.to("cpu")
    batch_mask = batch_mask.to("cpu")
    # The flags are moved to the CPU only when they came back as a tensor.
    if isinstance(batch_flags, torch.Tensor):
        batch_flags = batch_flags.to("cpu")

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        _, _, rewards = model(
            input_ids=batch_ids, attention_mask=batch_mask, return_probs=True
        )

    # Compute the step rewards and keep the entry for our single sample.
    per_sample_rewards = derive_step_rewards(rewards, batch_flags)
    return per_sample_rewards[0]
# Example texts pre-filled into the two input boxes.
_DEFAULT_PROBLEM = (
    "Janet's ducks lay 16 eggs per day. "
    "She eats three for breakfast every morning and bakes muffins for her friends every day with four. "
    "She sells the remainder at the farmers' market daily for $2 per fresh duck egg. "
    "How much in dollars does she make every day at the farmers' market?"
)
_DEFAULT_RESPONSE = (
    "To determine how much money Janet makes every day at the farmers' market, "
    "we need to follow these steps:\n"
    "1. ..."
)

# Build the Gradio interface: two text inputs, a JSON output and a button
# that feeds both inputs to compute_rewards.
with gr.Blocks() as demo:
    gr.Markdown("# PRM Reward Calculation")
    problem_input = gr.Textbox(label="Problem", lines=5, value=_DEFAULT_PROBLEM)
    response_input = gr.Textbox(label="Response", lines=10, value=_DEFAULT_RESPONSE)
    output = gr.JSON(label="Step Rewards")
    submit_btn = gr.Button("Compute Rewards")
    submit_btn.click(
        fn=compute_rewards,
        inputs=[problem_input, response_input],
        outputs=output,
    )

# Start the server only when the file is run as a script.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)