import os
import sys
import subprocess
import torch
from typing import Union
from pydantic import BaseModel
from fastapi import FastAPI
import uvicorn
# Gradio-related imports
import gradio as gr
# transformers-related imports
from transformers import AutoTokenizer
##############################################################################
# 1) If the skywork-o1-prm-inference directory is missing locally (or on the
#    Space), clone it and add it to sys.path.
##############################################################################
def setup_environment():
    if not os.path.exists("skywork-o1-prm-inference"):
        print("Cloning repository skywork-o1-prm-inference...")
        subprocess.run(
            ["git", "clone", "https://github.com/SkyworkAI/skywork-o1-prm-inference.git"],
            check=True
        )
    repo_path = os.path.abspath("skywork-o1-prm-inference")
    if repo_path not in sys.path:
        sys.path.append(repo_path)
        print(f"Added {repo_path} to Python path")
setup_environment()
##############################################################################
# 2) Import modules from the Skywork repository (these imports only resolve
#    after setup_environment() has put the cloned repo on sys.path).
##############################################################################
from model_utils.prm_model import PRM_MODEL
from model_utils.io_utils import prepare_input, prepare_batch_input_for_model, derive_step_rewards
##############################################################################
# 3) Load the model and related resources
##############################################################################
# Swap in a different model_id here if needed.
model_id = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B"
print(f"Loading tokenizer from {model_id} ...")
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
print(f"Loading model from {model_id} ...")
model = PRM_MODEL.from_pretrained(model_id).to("cpu").eval()
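# Optional device-selection sketch (an assumption: PRM_MODEL supports the usual
# .to(device) move; this app deliberately stays on CPU, which fits the free CPU
# tier on Spaces). If you switch to GPU, also move input_ids/attention_mask to
# the same device inside compute_rewards below:
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model = PRM_MODEL.from_pretrained(model_id).to(device).eval()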
##############################################################################
# 4) Define the inference function
##############################################################################
def compute_rewards(problem: str, response: str):
    """Core inference: feed problem + response into the model and return step_rewards."""
    data = {
        "problem": problem,
        "response": response
    }
    processed_data = [
        prepare_input(
            data["problem"],
            data["response"],
            tokenizer=tokenizer,
            step_token="\n"
        )
    ]
    input_ids, steps, reward_flags = zip(*processed_data)
    # Prepare the batch input for the model
    input_ids, attention_mask, reward_flags = prepare_batch_input_for_model(
        input_ids,
        reward_flags,
        tokenizer.pad_token_id
    )
    input_ids = input_ids.to("cpu")
    attention_mask = attention_mask.to("cpu")
    if isinstance(reward_flags, torch.Tensor):
        reward_flags = reward_flags.to("cpu")
    # Forward pass (no gradients needed for inference)
    with torch.no_grad():
        _, _, rewards = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_probs=True
        )
    # Derive one reward per reasoning step
    step_rewards = derive_step_rewards(rewards, reward_flags)
    return step_rewards[0]
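# Usage sketch (hypothetical inputs; the return value is a list of floats, one
# score per "\n"-delimited step of the response):
#
#   scores = compute_rewards("1 + 1 = ?", "Step 1: add the operands.\nStep 2: 1 + 1 = 2.")
#   print(scores)  # e.g. [0.78, 0.91] -- actual values depend on the model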
##############################################################################
# 5) Set up the FastAPI app: expose the /api/predict endpoint
##############################################################################
app = FastAPI(title="PRM Inference Server", description="FastAPI + Gradio App")
# 5.1 Define the request/response schemas for the API
class InferenceData(BaseModel):
    problem: str
    response: str

class InferenceOutput(BaseModel):
    step_rewards: list
@app.post("/api/predict", response_model=InferenceOutput)
def predict(data: InferenceData):
    """
    POST /api/predict with a JSON body
    { "problem": "...", "response": "..." } to get step_rewards back.
    """
    rewards = compute_rewards(data.problem, data.response)
    return InferenceOutput(step_rewards=rewards)
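# Example request (a sketch; assumes the server is listening locally on the
# port used by the uvicorn call at the bottom of this file):
#
#   curl -X POST http://localhost:7860/api/predict \
#        -H "Content-Type: application/json" \
#        -d '{"problem": "1 + 1 = ?", "response": "Step 1: 1 + 1 = 2."}'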
##############################################################################
# 6) Build the Gradio UI and mount it at the app's root path
##############################################################################
def build_gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown(
            "## PRM Reward Calculation\n\n"
            "Enter a `Problem` and a `Response`, then click the button below "
            "to get the step_rewards."
        )
        problem_input = gr.Textbox(
            label="Problem",
            lines=5,
            value=(
                "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning "
                "and bakes muffins for her friends every day with four. She sells the remainder "
                "at the farmers' market daily for $2 per fresh duck egg. How much in dollars "
                "does she make every day at the farmers' market?"
            )
        )
        response_input = gr.Textbox(
            label="Response",
            lines=10,
            value=(
                "To determine how much money Janet makes every day at the farmers' market, "
                "we need to follow these steps:\n1. ..."
            )
        )
        output = gr.JSON(label="Step Rewards")
        submit_btn = gr.Button("Compute Rewards")
        submit_btn.click(
            fn=compute_rewards,
            inputs=[problem_input, response_input],
            outputs=output
        )
    return demo
demo = build_gradio_interface()
# 6.1 Mount the Gradio UI onto the FastAPI app at the root path
app = gr.mount_gradio_app(app, demo, path="/")
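# Ordering note: the /api/predict route was registered on `app` before the
# Gradio mount, so Starlette matches the API route first; the Gradio UI then
# serves everything else under "/".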
##############################################################################
# 7) Launch with uvicorn from main
##############################################################################
# Note: on Hugging Face Spaces, running `python app.py` is usually enough to
# start listening. If the Space auto-detects and executes `app.py`, the
# `if __name__ == "__main__"` branch may not run, so all that matters is that
# the `app` variable defined above exists. The optional main below is handy
# for local debugging.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)