import os
import sys
import subprocess
import torch
from pydantic import BaseModel
from fastapi import FastAPI
import uvicorn
# Gradio
import gradio as gr
# Transformers
from transformers import AutoTokenizer

##############################################################################
# 1) If the skywork-o1-prm-inference directory is missing locally (or on the
#    Space), clone it and add it to sys.path
##############################################################################
def setup_environment():
    if not os.path.exists("skywork-o1-prm-inference"):
        print("Cloning repository skywork-o1-prm-inference...")
        subprocess.run(
            ["git", "clone", "https://github.com/SkyworkAI/skywork-o1-prm-inference.git"],
            check=True
        )
    repo_path = os.path.abspath("skywork-o1-prm-inference")
    if repo_path not in sys.path:
        sys.path.append(repo_path)
        print(f"Added {repo_path} to Python path")

setup_environment()

##############################################################################
# 2) Import modules from the Skywork project
##############################################################################
from model_utils.prm_model import PRM_MODEL
from model_utils.io_utils import prepare_input, prepare_batch_input_for_model, derive_step_rewards

##############################################################################
# 3) Load the model and related resources
##############################################################################
# Swap in a different model_id if needed
model_id = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B"
print(f"Loading tokenizer from {model_id} ...")
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
print(f"Loading model from {model_id} ...")
model = PRM_MODEL.from_pretrained(model_id).to("cpu").eval()
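# Note: the model is pinned to CPU here for a CPU-only Space. On GPU hardware
# you could likely replace the .to("cpu") calls with .to("cuda"), assuming
# PRM_MODEL follows the usual torch device-placement API.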

##############################################################################
# 4) Define the inference function
##############################################################################
def compute_rewards(problem: str, response: str):
    """Core inference: feed problem + response to the model and return step_rewards."""
    data = {
        "problem": problem,
        "response": response
    }
    processed_data = [
        prepare_input(
            data["problem"],
            data["response"],
            tokenizer=tokenizer,
            step_token="\n"
        )
    ]
    input_ids, steps, reward_flags = zip(*processed_data)
    # Prepare the batch input
    input_ids, attention_mask, reward_flags = prepare_batch_input_for_model(
        input_ids,
        reward_flags,
        tokenizer.pad_token_id
    )
    input_ids = input_ids.to("cpu")
    attention_mask = attention_mask.to("cpu")
    if isinstance(reward_flags, torch.Tensor):
        reward_flags = reward_flags.to("cpu")
    # Forward pass
    with torch.no_grad():
        _, _, rewards = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_probs=True
        )
    # Derive the per-step rewards
    step_rewards = derive_step_rewards(rewards, reward_flags)
    return step_rewards[0]
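# Illustrative usage (output values are hypothetical):
#   compute_rewards("What is 2 + 2?",
#                   "Step 1: Add the numbers.\nStep 2: The answer is 4.")
# should return one reward per "\n"-delimited step, e.g. [0.73, 0.91].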

##############################################################################
# 5) Set up the FastAPI app: expose the /api/predict endpoint
##############################################################################
app = FastAPI(title="PRM Inference Server", description="FastAPI + Gradio App")

# 5.1 Define the API request/response schemas
class InferenceData(BaseModel):
    problem: str
    response: str

class InferenceOutput(BaseModel):
    step_rewards: list

@app.post("/api/predict", response_model=InferenceOutput)
def predict(data: InferenceData):
    """
    POST JSON { "problem": "...", "response": "..." } to /api/predict
    to get step_rewards back.
    """
    rewards = compute_rewards(data.problem, data.response)
    return InferenceOutput(step_rewards=rewards)
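# Example request (assuming the server is running locally on port 7860):
#   curl -X POST http://localhost:7860/api/predict \
#        -H "Content-Type: application/json" \
#        -d '{"problem": "What is 2 + 2?", "response": "Step 1: ...\n4"}'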

##############################################################################
# 6) Build the Gradio UI and mount it at the /gradio path
##############################################################################
def build_gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown(
            "## PRM Reward Calculation\n\n"
            "Enter a `Problem` and a `Response`, then click the button below "
            "to get the step_rewards."
        )
        problem_input = gr.Textbox(
            label="Problem",
            lines=5,
            value=(
                "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning "
                "and bakes muffins for her friends every day with four. She sells the remainder "
                "at the farmers' market daily for $2 per fresh duck egg. How much in dollars "
                "does she make every day at the farmers' market?"
            )
        )
        response_input = gr.Textbox(
            label="Response",
            lines=10,
            value=(
                "To determine how much money Janet makes every day at the farmers' market, "
                "we need to follow these steps:\n1. ..."
            )
        )
        output = gr.JSON(label="Step Rewards")
        submit_btn = gr.Button("Compute Rewards")
        submit_btn.click(
            fn=compute_rewards,
            inputs=[problem_input, response_input],
            outputs=output
        )
    return demo

demo = build_gradio_interface()
# 6.1 Mount Gradio onto the FastAPI app at /gradio
app = gr.mount_gradio_app(app, demo, path="/gradio")
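# With the uvicorn launch below, the UI is served at http://localhost:7860/gradio.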

##############################################################################
# 7) Start with uvicorn in main
##############################################################################
# Note: on Hugging Face Spaces, running `python app.py` is usually enough to
# start listening. If the Space auto-detects and executes `app.py`, the
# `if __name__ == "__main__"` branch may not run, so all that matters is that
# the `app` variable defined above exists.
# The optional main below is handy for local debugging.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)