import os
import sys
import subprocess

import torch
from pydantic import BaseModel
from fastapi import FastAPI
import uvicorn

# Gradio
import gradio as gr

# Transformers
from transformers import AutoTokenizer

##############################################################################
# 1) If the skywork-o1-prm-inference directory is missing locally (or in the
#    Space), clone it and add it to sys.path
##############################################################################
def setup_environment():
    if not os.path.exists("skywork-o1-prm-inference"):
        print("Cloning repository skywork-o1-prm-inference...")
        subprocess.run(
            ["git", "clone", "https://github.com/SkyworkAI/skywork-o1-prm-inference.git"],
            check=True
        )
    repo_path = os.path.abspath("skywork-o1-prm-inference")
    if repo_path not in sys.path:
        sys.path.append(repo_path)
        print(f"Added {repo_path} to Python path")

setup_environment()

##############################################################################
# 2) Import modules from the Skywork project
##############################################################################
from model_utils.prm_model import PRM_MODEL
from model_utils.io_utils import prepare_input, prepare_batch_input_for_model, derive_step_rewards

##############################################################################
# 3) Load the model and related resources
##############################################################################
# Swap in a different model_id as needed.
model_id = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B"
print(f"Loading tokenizer from {model_id} ...")
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

print(f"Loading model from {model_id} ...")
model = PRM_MODEL.from_pretrained(model_id).to("cpu").eval()

##############################################################################
# 4) Define the inference function
##############################################################################
def compute_rewards(problem: str, response: str):
    """Core inference: feed problem + response to the model and return step_rewards."""
    data = {
        "problem": problem,
        "response": response
    }
    processed_data = [
        prepare_input(
            data["problem"],
            data["response"],
            tokenizer=tokenizer,
            step_token="\n"
        )
    ]
    input_ids, steps, reward_flags = zip(*processed_data)

    # Prepare the batch input
    input_ids, attention_mask, reward_flags = prepare_batch_input_for_model(
        input_ids,
        reward_flags,
        tokenizer.pad_token_id
    )
    input_ids = input_ids.to("cpu")
    attention_mask = attention_mask.to("cpu")
    if isinstance(reward_flags, torch.Tensor):
        reward_flags = reward_flags.to("cpu")

    # Forward pass
    with torch.no_grad():
        _, _, rewards = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_probs=True
        )

    # Derive the per-step rewards
    step_rewards = derive_step_rewards(rewards, reward_flags)
    return step_rewards[0]

##############################################################################
# 5) Set up the FastAPI app and expose the /api/predict endpoint
##############################################################################
app = FastAPI(title="PRM Inference Server", description="FastAPI + Gradio App")

# 5.1 Define the API input and output schemas
class InferenceData(BaseModel):
    problem: str
    response: str

class InferenceOutput(BaseModel):
    step_rewards: list

@app.post("/api/predict", response_model=InferenceOutput)
def predict(data: InferenceData):
    """
    POST /api/predict with a JSON body:
    {
        "problem": "...",
        "response": "..."
    }
    to get the step_rewards back.
    """
    rewards = compute_rewards(data.problem, data.response)
    return InferenceOutput(step_rewards=rewards)
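# Example client call for /api/predict -- a minimal sketch, left commented out
# so it does not run at import time. It assumes the server is already running
# locally on port 7860 and that the `requests` package is installed; the
# problem/response strings below are placeholders.
#
#   import requests
#
#   payload = {
#       "problem": "What is 12 * 7?",
#       "response": "Step 1: 12 * 7 = 84.\nStep 2: So the answer is 84.",
#   }
#   resp = requests.post("http://localhost:7860/api/predict", json=payload)
#   resp.raise_for_status()
#   print(resp.json()["step_rewards"])  # one score per "\n"-separated step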
##############################################################################
# 6) Build the Gradio interface and mount it onto the FastAPI app
##############################################################################
def build_gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown(
            "## PRM Reward Calculation\n\n"
            "Enter a `Problem` and a `Response`, then click the button below to get the step_rewards."
        )
        problem_input = gr.Textbox(
            label="Problem",
            lines=5,
            value=(
                "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning "
                "and bakes muffins for her friends every day with four. She sells the remainder "
                "at the farmers' market daily for $2 per fresh duck egg. How much in dollars "
                "does she make every day at the farmers' market?"
            )
        )
        response_input = gr.Textbox(
            label="Response",
            lines=10,
            value=(
                "To determine how much money Janet makes every day at the farmers' market, "
                "we need to follow these steps:\n1. ..."
            )
        )
        output = gr.JSON(label="Step Rewards")
        submit_btn = gr.Button("Compute Rewards")
        submit_btn.click(
            fn=compute_rewards,
            inputs=[problem_input, response_input],
            outputs=output
        )
    return demo

demo = build_gradio_interface()

# 6.1 Mount the Gradio app onto FastAPI at the root path
app = gr.mount_gradio_app(app, demo, path="/")

##############################################################################
# 7) Launch with uvicorn under main
##############################################################################
# Note: on Hugging Face Spaces, running `python app.py` is usually all it takes
# to start listening. If Spaces auto-detects and executes `app.py`, the
# `if __name__ == "__main__"` branch may not run, so what matters is that the
# `app` variable defined above exists. The optional main below is convenient
# for local debugging.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
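# Optional local sanity check -- a sketch, left commented out because the
# uvicorn.run() call above blocks when this file runs as a script. To exercise
# the scoring function without the HTTP layer, run something like the lines
# below in a REPL after importing this module (importing clones the repo and
# loads the model on CPU, so it may be slow). The strings are placeholders.
#
#   rewards = compute_rewards(
#       "If a train travels 60 km in 1.5 hours, what is its average speed?",
#       "Step 1: speed = distance / time.\nStep 2: 60 / 1.5 = 40 km/h.",
#   )
#   print(rewards)  # one reward per "\n"-separated step of the response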