RM / app.py
lihongze8's picture
Update app.py
acddaf1 verified
raw
history blame
4.16 kB
import json
import os
import subprocess
import sys
import traceback
# Set up the environment.
def setup_environment(
    repo_url="https://github.com/SkyworkAI/skywork-o1-prm-inference.git",
    repo_dir="skywork-o1-prm-inference",
):
    """Make the Skywork PRM inference helper package importable.

    If ``repo_dir`` (resolved against the current working directory) does not
    exist, ``repo_url`` is cloned into it with ``git``. The absolute path of the
    checkout is then appended to ``sys.path`` (once), so that modules inside it
    such as ``model_utils`` can be imported afterwards.

    Args:
        repo_url: Git URL to clone when ``repo_dir`` is missing.
        repo_dir: Local checkout directory.

    Returns:
        The absolute path of the checkout.

    Raises:
        subprocess.CalledProcessError: If ``git clone`` exits with a non-zero status.
    """
    if not os.path.exists(repo_dir):
        print("Cloning repository...")
        # Clone into an explicit target so the result always matches repo_dir.
        subprocess.run(["git", "clone", repo_url, repo_dir], check=True)
    repo_path = os.path.abspath(repo_dir)
    # Guard against duplicate sys.path entries if this is called more than once.
    if repo_path not in sys.path:
        sys.path.append(repo_path)
        print(f"Added {repo_path} to Python path")
    return repo_path
# Must run before the model_utils imports below: it clones the Skywork
# inference repository when missing and appends it to sys.path.
setup_environment()
import gradio as gr
from transformers import AutoTokenizer
from model_utils.prm_model import PRM_MODEL
from model_utils.io_utils import prepare_input, prepare_batch_input_for_model, derive_step_rewards
import torch

# Skywork PRM checkpoint identifier, loaded via from_pretrained for both the
# tokenizer and the model.
model_id = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

# Load the weights, place them on the CPU, and switch the model to eval mode
# (presumably a torch module, so this disables dropout-style training behavior).
model = PRM_MODEL.from_pretrained(model_id)
model = model.to("cpu")
model = model.eval()
def _rewards_to_list(values):
    """Convert one sample's step rewards into a plain Python list for JSON encoding.

    Accepts a ``torch.Tensor`` (detached and moved to the CPU first), any object
    exposing ``tolist()`` (e.g. a NumPy array), or any other iterable.
    """
    if isinstance(values, torch.Tensor):
        return values.detach().cpu().tolist()
    if hasattr(values, "tolist"):
        # Duck-typed so NumPy arrays are handled without importing numpy here.
        return values.tolist()
    return list(values)


def evaluate(problem, response):
    """Score the reasoning steps of ``response`` to ``problem`` with the loaded PRM.

    ``prepare_input`` is called with ``step_token="\\n"``, so (presumably) each
    line of ``response`` is treated as one step; see ``prepare_input`` to confirm.

    Args:
        problem: Problem statement text.
        response: Candidate solution text.

    Returns:
        A JSON string. On success it is a list of step rewards for this single
        example. On any failure it is ``{"error": "<message>"}``.
    """
    try:
        # Process the input data. processed_data[0] holds the input ids and
        # processed_data[2] the reward flags; index 1 is not needed here.
        processed_data = prepare_input(problem, response, tokenizer=tokenizer, step_token="\n")
        # Wrap the single example as a batch of size 1.
        input_ids = [processed_data[0]]
        reward_flags = [processed_data[2]]

        # Pad/collate the batch into model inputs.
        input_ids, attention_mask, reward_flags = prepare_batch_input_for_model(
            input_ids,
            reward_flags,
            tokenizer.pad_token_id,
        )

        # The model was moved to the CPU at load time; keep the inputs there too.
        input_ids = input_ids.to("cpu")
        attention_mask = attention_mask.to("cpu")
        if isinstance(reward_flags, torch.Tensor):
            reward_flags = reward_flags.to("cpu")

        # Model inference; no gradients are needed.
        with torch.no_grad():
            _, _, rewards = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_probs=True,
            )

        # Compute the step rewards; the batch has one example, so take entry 0.
        step_rewards = derive_step_rewards(rewards, reward_flags)
        return json.dumps(_rewards_to_list(step_rewards[0]))
    except Exception as e:
        # UI boundary: report the failure to the client instead of crashing the
        # handler, but keep the traceback in the server log for debugging.
        traceback.print_exc()
        return json.dumps({"error": str(e)})
# Build the Gradio UI: two text inputs (problem, response) and a JSON output
# that displays whatever `evaluate` returns.
_example_problem = "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?"
_example_response = "To determine how much money Janet makes every day at the farmers' market, we need to follow these steps:\n1. Calculate the total number of eggs laid by the ducks per day.\n Janet's ducks lay 16 eggs per day.\n2. Determine the number of eggs Janet uses each day.\n - She eats 3 eggs for breakfast every morning.\n - She bakes muffins for her friends every day with 4 eggs.\n So, the total number of eggs used per day is:\n 3 + 4 = 7 eggs\n3. Calculate the number of eggs Janet sells at the farmers' market each day.\n Subtract the number of eggs used from the total number of eggs laid:\n 16 - 7 = 9 eggs\n4. Determine how much money Janet makes from selling the eggs.\n She sells each egg for $2, so the total amount of money she makes is:\n 9 ×2 = 18 dollars\nTherefore, the amount of money Janet makes every day at the farmers' market is $18."

_problem_input = gr.Textbox(label="Problem", lines=4)
_response_input = gr.Textbox(label="Response", lines=8)
_rewards_output = gr.JSON()

iface = gr.Interface(
    fn=evaluate,
    inputs=[_problem_input, _response_input],
    outputs=_rewards_output,
    title="Problem Response Evaluation",
    description="Enter a problem and its response to get step-wise rewards",
    examples=[[_example_problem, _example_response]],
    cache_examples=False,  # disable example caching
)

# Start the server, binding to all network interfaces (0.0.0.0).
iface.launch(server_name="0.0.0.0")