# Web-page residue captured together with the file (not part of the program):
# RM / app.py
# lihongze8's picture
# Update app.py
# c442fd5 verified
# raw
# history blame
# 2.49 kB
import os
import subprocess
import sys
import json
# Environment setup
def setup_environment():
    """Make the ``skywork-o1-prm-inference`` checkout importable.

    If no path named ``skywork-o1-prm-inference`` exists relative to the
    current working directory, the repository is cloned there with
    ``git clone``. The absolute path of that directory is then appended to
    ``sys.path`` unless it is already listed.

    Raises:
        subprocess.CalledProcessError: If ``git clone`` exits with a
            non-zero status (``check=True``).
    """
    repo_dir = "skywork-o1-prm-inference"
    if not os.path.exists(repo_dir):
        print("Cloning repository...")
        clone_cmd = ["git", "clone", "https://github.com/SkyworkAI/skywork-o1-prm-inference.git"]
        subprocess.run(clone_cmd, check=True)

    absolute_repo = os.path.abspath(repo_dir)
    if absolute_repo in sys.path:
        # Already importable; nothing to add and nothing to report.
        return
    sys.path.append(absolute_repo)
    print(f"Added {absolute_repo} to Python path")
# Runs at import time, before the remaining imports: it appends the cloned
# repository to sys.path, which is presumably where the `model_utils`
# package imported below comes from, so this call must stay above them.
setup_environment()
import gradio as gr
from transformers import AutoTokenizer
from model_utils.prm_model import PRM_MODEL
from model_utils.io_utils import prepare_input, prepare_batch_input_for_model, derive_step_rewards
import torch
# Checkpoint identifier passed to both from_pretrained calls below.
model_id = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# Load the model once at import time, then move it to the CPU and switch it
# to evaluation mode; evaluate() reuses this module-level instance.
model = PRM_MODEL.from_pretrained(model_id)
model = model.to("cpu").eval()
def evaluate(problem, response):
    """Compute step-wise rewards for ``response`` as an answer to ``problem``.

    The pair is encoded with ``prepare_input`` (a newline is passed as the
    step token, so presumably each line of the response counts as one step),
    batched as a single example, run through the module-level ``model``
    without gradient tracking, and reduced to per-step rewards with
    ``derive_step_rewards``.

    Args:
        problem: Problem statement text.
        response: Candidate solution text.

    Returns:
        A JSON string. On success it encodes the rewards of the first (only)
        batch entry; on any failure it encodes ``{"error": <message>}``.
        No exception is propagated to the caller.
    """
    try:
        processed_data = prepare_input(problem, response, tokenizer=tokenizer, step_token="\n")
        # Treat the single example as a batch of one and transpose it into
        # per-field tuples, the layout the batching helper is given below.
        # The per-step text is not needed afterwards.
        input_ids, _steps, reward_flags = zip(*[processed_data])
        input_ids, attention_mask, reward_flags = prepare_batch_input_for_model(
            input_ids,
            reward_flags,
            tokenizer.pad_token_id,
        )
        input_ids = input_ids.to("cpu")
        attention_mask = attention_mask.to("cpu")
        # reward_flags is moved only when it is a tensor.
        if isinstance(reward_flags, torch.Tensor):
            reward_flags = reward_flags.to("cpu")

        with torch.no_grad():
            _, _, rewards = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_probs=True,
            )

        step_rewards = derive_step_rewards(rewards, reward_flags)
        first_rewards = step_rewards[0]
        # NOTE(review): derive_step_rewards may return plain Python lists
        # rather than tensors/arrays; only convert when a converter exists.
        if hasattr(first_rewards, "tolist"):
            first_rewards = first_rewards.tolist()
        # Make sure what is returned is a valid JSON string.
        return json.dumps(first_rewards)
    except Exception as e:
        # UI boundary: report the failure as JSON instead of raising.
        return json.dumps({"error": str(e)})
# Build the Gradio interface: two text inputs feed evaluate(), whose JSON
# string result is shown in a JSON output component.
iface = gr.Interface(
    fn=evaluate,
    inputs=[
        gr.Textbox(label="Problem", lines=4),
        gr.Textbox(label="Response", lines=8),
    ],
    outputs=gr.JSON(),
    title="Problem Response Evaluation",
    description="Enter a problem and its response to get step-wise rewards",
    examples=[
        ["Janet'sducks lay 16 eggs per day...", "To determine how much money Janet makes..."],
    ],
    cache_examples=False,  # disable example caching
)

# Launch the interface, listening on all network interfaces.
iface.launch(server_name="0.0.0.0")