import os
import subprocess
import sys

import torch
import uvicorn
import gradio as gr
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer


def setup_environment():
    """Clone the skywork-o1-prm-inference repository if it is missing and add it to sys.path."""
    if not os.path.exists("skywork-o1-prm-inference"):
        print("Cloning repository skywork-o1-prm-inference...")
        subprocess.run(
            ["git", "clone", "https://github.com/SkyworkAI/skywork-o1-prm-inference.git"],
            check=True,
        )
    repo_path = os.path.abspath("skywork-o1-prm-inference")
    if repo_path not in sys.path:
        sys.path.append(repo_path)
        print(f"Added {repo_path} to Python path")


setup_environment()
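
# The model_utils package lives inside the cloned repository, so these imports
# must come after setup_environment() has put the repo on sys.path.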
from model_utils.prm_model import PRM_MODEL
from model_utils.io_utils import prepare_input, prepare_batch_input_for_model, derive_step_rewards


model_id = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B"

print(f"Loading tokenizer from {model_id} ...")
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

print(f"Loading model from {model_id} ...")
model = PRM_MODEL.from_pretrained(model_id).to("cpu").eval()
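
# Note: everything runs on CPU here for portability. If a GPU is available,
# moving the model (and the tensors in compute_rewards) to "cuda" instead of
# "cpu" should be a drop-in speed-up, though that path is untested in this script.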


def compute_rewards(problem: str, response: str):
    """Core inference: feed problem + response into the model and return its step_rewards."""
    data = {
        "problem": problem,
        "response": response,
    }
    # Tokenize the (problem, response) pair; step_token="\n" marks the step
    # boundaries at which rewards are read out.
    processed_data = [
        prepare_input(
            data["problem"],
            data["response"],
            tokenizer=tokenizer,
            step_token="\n",
        )
    ]
    input_ids, steps, reward_flags = zip(*processed_data)

    # Pad the batch and build the attention mask.
    input_ids, attention_mask, reward_flags = prepare_batch_input_for_model(
        input_ids,
        reward_flags,
        tokenizer.pad_token_id,
    )

    input_ids = input_ids.to("cpu")
    attention_mask = attention_mask.to("cpu")
    if isinstance(reward_flags, torch.Tensor):
        reward_flags = reward_flags.to("cpu")

    with torch.no_grad():
        _, _, rewards = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_probs=True,
        )

    step_rewards = derive_step_rewards(rewards, reward_flags)
    return step_rewards[0]
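
# A quick local sanity check (hypothetical inputs, not from the source):
#   rewards = compute_rewards("What is 2 + 2?", "2 + 2 = 4\nSo the answer is 4.")
#   print(rewards)  # expected: one reward float per "\n"-separated step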


app = FastAPI(title="PRM Inference Server", description="FastAPI + Gradio App")


class InferenceData(BaseModel):
    problem: str
    response: str


class InferenceOutput(BaseModel):
    step_rewards: list


@app.post("/api/predict", response_model=InferenceOutput)
def predict(data: InferenceData):
    """
    POST directly to /api/predict with the JSON body
    {"problem": "...", "response": "..."} to obtain the step_rewards.
    """
    rewards = compute_rewards(data.problem, data.response)
    return InferenceOutput(step_rewards=rewards)
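
# Example client call (a sketch; assumes the server is running locally on
# port 7860 and that the requests package is installed):
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/api/predict",
#       json={"problem": "What is 2 + 2?", "response": "2 + 2 = 4"},
#   )
#   print(resp.json()["step_rewards"])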


def build_gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown(
            "## PRM Reward Calculation\n\n"
            "Enter a `Problem` and a `Response`, then click the button below "
            "to get the step_rewards."
        )

        problem_input = gr.Textbox(
            label="Problem",
            lines=5,
            value=(
                "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning "
                "and bakes muffins for her friends every day with four. She sells the remainder "
                "at the farmers' market daily for $2 per fresh duck egg. How much in dollars "
                "does she make every day at the farmers' market?"
            ),
        )
        response_input = gr.Textbox(
            label="Response",
            lines=10,
            value=(
                "To determine how much money Janet makes every day at the farmers' market, "
                "we need to follow these steps:\n1. ..."
            ),
        )
        output = gr.JSON(label="Step Rewards")

        submit_btn = gr.Button("Compute Rewards")
        submit_btn.click(
            fn=compute_rewards,
            inputs=[problem_input, response_input],
            outputs=output,
        )

    return demo


demo = build_gradio_interface()

# Serve the Gradio UI at the root path; the JSON API stays reachable at /api/predict.
app = gr.mount_gradio_app(app, demo, path="/")


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
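
# Alternatively, launch with the uvicorn CLI (assuming this file is saved as
# app.py): uvicorn app:app --host 0.0.0.0 --port 7860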