import os
import sys
import subprocess
import torch
from typing import Union
from pydantic import BaseModel
from fastapi import FastAPI
import uvicorn
# Gradio
import gradio as gr
# Transformers
from transformers import AutoTokenizer
##############################################################################
# 1) If skywork-o1-prm-inference is not present locally (or on the Space), clone it and add it to sys.path
##############################################################################
def setup_environment():
    if not os.path.exists("skywork-o1-prm-inference"):
        print("Cloning repository skywork-o1-prm-inference...")
        subprocess.run(
            ["git", "clone", "https://github.com/SkyworkAI/skywork-o1-prm-inference.git"],
            check=True
        )
    repo_path = os.path.abspath("skywork-o1-prm-inference")
    if repo_path not in sys.path:
        sys.path.append(repo_path)
        print(f"Added {repo_path} to Python path")

setup_environment()
##############################################################################
# 2) Import modules from the Skywork repository
##############################################################################
from model_utils.prm_model import PRM_MODEL
from model_utils.io_utils import prepare_input, prepare_batch_input_for_model, derive_step_rewards
##############################################################################
# 3) Load the model and related resources
##############################################################################
# You can swap in a different model_id as needed
model_id = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B"
print(f"Loading tokenizer from {model_id} ...")
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
print(f"Loading model from {model_id} ...")
model = PRM_MODEL.from_pretrained(model_id).to("cpu").eval()
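# Hedged note: the model is pinned to CPU above, presumably to match CPU-only
# Space hardware. On a GPU-backed instance you could instead select the device
# dynamically, e.g. (a sketch, not used below; compute_rewards would then also
# need to move its tensors to the same device):
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model = PRM_MODEL.from_pretrained(model_id).to(device).eval()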
##############################################################################
# 4) Define the inference function
##############################################################################
def compute_rewards(problem: str, response: str):
    """Core inference function: feed problem + response into the model and return step_rewards."""
    data = {
        "problem": problem,
        "response": response
    }
    processed_data = [
        prepare_input(
            data["problem"],
            data["response"],
            tokenizer=tokenizer,
            step_token="\n"
        )
    ]
    input_ids, steps, reward_flags = zip(*processed_data)
    # Prepare the batch input
    input_ids, attention_mask, reward_flags = prepare_batch_input_for_model(
        input_ids,
        reward_flags,
        tokenizer.pad_token_id
    )
    input_ids = input_ids.to("cpu")
    attention_mask = attention_mask.to("cpu")
    if isinstance(reward_flags, torch.Tensor):
        reward_flags = reward_flags.to("cpu")
    # Forward pass
    with torch.no_grad():
        _, _, rewards = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_probs=True
        )
    # Derive the per-step rewards
    step_rewards = derive_step_rewards(rewards, reward_flags)
    return step_rewards[0]
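# Minimal usage sketch (hypothetical inputs; the output length depends on how
# many "\n"-separated steps the response contains):
#   rewards = compute_rewards(
#       "What is 2 + 2?",
#       "Step 1: add the two numbers.\nStep 2: 2 + 2 = 4."
#   )
#   print(rewards)  # e.g. a list like [0.91, 0.87], one float per step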
##############################################################################
# 5) Set up the FastAPI app: expose the /api/predict endpoint
##############################################################################
app = FastAPI(title="PRM Inference Server", description="FastAPI + Gradio App")
# 5.1 Define the API input and output schemas
class InferenceData(BaseModel):
    problem: str
    response: str

class InferenceOutput(BaseModel):
    step_rewards: list
@app.post("/api/predict", response_model=InferenceOutput)
def predict(data: InferenceData):
    """
    POST to /api/predict with the JSON body
    { "problem": "...", "response": "..." } to get the step_rewards.
    """
    rewards = compute_rewards(data.problem, data.response)
    return InferenceOutput(step_rewards=rewards)
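# Example call against a local instance (assuming the default port 7860 set in
# the __main__ block below):
#   curl -X POST http://localhost:7860/api/predict \
#     -H "Content-Type: application/json" \
#     -d '{"problem": "What is 2 + 2?", "response": "Step 1: 2 + 2 = 4."}'
# Response shape: {"step_rewards": [ ... one float per step ... ]}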
##############################################################################
# 6) Build the Gradio UI and mount it on the FastAPI app
##############################################################################
def build_gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown(
            "## PRM Reward Calculation\n\n"
            "Enter a `Problem` and a `Response`, then click the button below to get the step_rewards."
        )
        problem_input = gr.Textbox(
            label="Problem",
            lines=5,
            value=(
                "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning "
                "and bakes muffins for her friends every day with four. She sells the remainder "
                "at the farmers' market daily for $2 per fresh duck egg. How much in dollars "
                "does she make every day at the farmers' market?"
            )
        )
        response_input = gr.Textbox(
            label="Response",
            lines=10,
            value=(
                "To determine how much money Janet makes every day at the farmers' market, "
                "we need to follow these steps:\n1. ..."
            )
        )
        output = gr.JSON(label="Step Rewards")
        submit_btn = gr.Button("Compute Rewards")
        submit_btn.click(
            fn=compute_rewards,
            inputs=[problem_input, response_input],
            outputs=output
        )
    return demo

demo = build_gradio_interface()
# 6.1 Mount the Gradio app on FastAPI at the root path ("/")
app = gr.mount_gradio_app(app, demo, path="/")
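# With path="/", the Gradio UI is served at the root URL, while the /api/predict
# route registered above still matches first, so both interfaces share one server.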
##############################################################################
# 7) Launch with uvicorn from main
##############################################################################
# Note: on Hugging Face Spaces, running `python app.py` is usually enough to start listening.
# If the Space auto-detects and executes `app.py`, the `if __name__ == "__main__"` branch
# may not run, so it is enough that the `app` variable defined above exists.
# The optional main below is provided for local debugging.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)