import os
import sys
import subprocess
import torch
from pydantic import BaseModel
from fastapi import FastAPI
import uvicorn

# Gradio-related imports
import gradio as gr

# transformers-related imports
from transformers import AutoTokenizer

##############################################################################
# 1) If the skywork-o1-prm-inference directory is not present locally
#    (or in the Space), clone it and add it to sys.path
##############################################################################
def setup_environment():
    if not os.path.exists("skywork-o1-prm-inference"):
        print("Cloning repository skywork-o1-prm-inference...")
        subprocess.run(
            ["git", "clone", "https://github.com/SkyworkAI/skywork-o1-prm-inference.git"],
            check=True
        )
    repo_path = os.path.abspath("skywork-o1-prm-inference")
    if repo_path not in sys.path:
        sys.path.append(repo_path)
        print(f"Added {repo_path} to Python path")

setup_environment()

##############################################################################
# 2) Import modules from the Skywork repository
##############################################################################
from model_utils.prm_model import PRM_MODEL
from model_utils.io_utils import prepare_input, prepare_batch_input_for_model, derive_step_rewards

##############################################################################
# 3) Load the model and related resources
##############################################################################
# Change model_id here if you need a different checkpoint
model_id = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B"

print(f"Loading tokenizer from {model_id} ...")
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

print(f"Loading model from {model_id} ...")
model = PRM_MODEL.from_pretrained(model_id).to("cpu").eval()
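
# Note: the model is pinned to CPU above. A GPU variant (an untested sketch,
# assuming a CUDA device is available) would look like:
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model = PRM_MODEL.from_pretrained(model_id).to(device).eval()
#
# The .to("cpu") calls in compute_rewards below would then need the same device.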

##############################################################################
# 4) Define the inference function
##############################################################################
def compute_rewards(problem: str, response: str):
    """核心推理函数:将 problem + response 输入模型,输出 step_rewards。"""
    data = {
        "problem": problem,
        "response": response
    }
    processed_data = [
        prepare_input(
            data["problem"],
            data["response"],
            tokenizer=tokenizer,
            step_token="\n"
        )
    ]
    input_ids, steps, reward_flags = zip(*processed_data)

    # Pad and batch the inputs for the model
    input_ids, attention_mask, reward_flags = prepare_batch_input_for_model(
        input_ids,
        reward_flags,
        tokenizer.pad_token_id
    )

    input_ids = input_ids.to("cpu")
    attention_mask = attention_mask.to("cpu")
    if isinstance(reward_flags, torch.Tensor):
        reward_flags = reward_flags.to("cpu")

    # Forward pass (inference only, no gradients)
    with torch.no_grad():
        _, _, rewards = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_probs=True
        )

    # Derive per-step rewards from the token-level rewards
    step_rewards = derive_step_rewards(rewards, reward_flags)
    return step_rewards[0]
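
# Example call (illustrative; steps are split on "\n" via step_token above):
#   compute_rewards("1 + 1 = ?", "Step 1: add the numbers.\n1 + 1 = 2")
#   -> [0.78, 0.91]  # one float per step; these values are made up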

##############################################################################
# 5) Set up the FastAPI app and expose the /api/predict endpoint
##############################################################################
app = FastAPI(title="PRM Inference Server", description="FastAPI + Gradio App")

# 5.1 Request and response schemas for the API
class InferenceData(BaseModel):
    problem: str
    response: str

class InferenceOutput(BaseModel):
    step_rewards: list

@app.post("/api/predict", response_model=InferenceOutput)
def predict(data: InferenceData):
    """
    直接使用HTTP POST /api/predict,
    body JSON: { "problem": "...", "response": "..." } 即可得到 step_rewards。
    """
    rewards = compute_rewards(data.problem, data.response)
    return InferenceOutput(step_rewards=rewards)
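
# Example request (a sketch; assumes the server is running locally on port 7860):
#
#   curl -X POST http://localhost:7860/api/predict \
#        -H "Content-Type: application/json" \
#        -d '{"problem": "1 + 1 = ?", "response": "Step 1: add.\n1 + 1 = 2"}'
#
# Expected response shape: {"step_rewards": [...]} with one float per step.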

##############################################################################
# 6) Build the Gradio interface and mount it on the FastAPI app
##############################################################################
def build_gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown(
            "## PRM Reward Calculation\n\n"
            "Enter a `Problem` and a `Response`, then click the button below "
            "to get the step_rewards."
        )

        problem_input = gr.Textbox(
            label="Problem",
            lines=5,
            value=(
                "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning "
                "and bakes muffins for her friends every day with four. She sells the remainder "
                "at the farmers' market daily for $2 per fresh duck egg. How much in dollars "
                "does she make every day at the farmers' market?"
            )
        )
        response_input = gr.Textbox(
            label="Response",
            lines=10,
            value=(
                "To determine how much money Janet makes every day at the farmers' market, "
                "we need to follow these steps:\n1. ..."
            )
        )
        output = gr.JSON(label="Step Rewards")

        submit_btn = gr.Button("Compute Rewards")
        submit_btn.click(
            fn=compute_rewards,
            inputs=[problem_input, response_input],
            outputs=output
        )

    return demo

demo = build_gradio_interface()

# 6.1 Mount the Gradio app at the root path ("/") of the FastAPI app
app = gr.mount_gradio_app(app, demo, path="/")
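
# With this mounting, the Gradio UI is served at "/" while the JSON API
# defined above stays reachable at /api/predict on the same port.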

##############################################################################
# 7) Launch with uvicorn when run as a script
##############################################################################
# Note: on Hugging Face Spaces, `python app.py` is usually all that is needed
# to start listening. If the Space imports `app.py` instead of executing it
# directly, the `if __name__ == "__main__"` branch will not run, so it is
# enough that the `app` variable defined above exists. The optional main
# below is convenient for local debugging.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
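
# Minimal Python client sketch (an assumption: the server above is running on
# localhost:7860 and the `requests` package is installed):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/api/predict",
#       json={"problem": "1 + 1 = ?", "response": "Step 1: add.\n1 + 1 = 2"},
#   )
#   print(resp.json()["step_rewards"])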