Update app.py
Browse files
app.py
CHANGED
@@ -5,7 +5,7 @@ import torch
|
|
5 |
import gradio as gr
|
6 |
from transformers import AutoTokenizer
|
7 |
|
8 |
-
#
|
9 |
def setup_environment():
|
10 |
if not os.path.exists("skywork-o1-prm-inference"):
|
11 |
print("Cloning repository...")
|
@@ -17,56 +17,88 @@ def setup_environment():
|
|
17 |
|
18 |
setup_environment()
|
19 |
|
20 |
-
#
|
21 |
from model_utils.prm_model import PRM_MODEL
|
22 |
from model_utils.io_utils import prepare_input, prepare_batch_input_for_model, derive_step_rewards
|
23 |
|
24 |
-
# 模型ID
|
25 |
model_id = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B"
|
26 |
-
|
27 |
-
# 初始化模型和tokenizer
|
28 |
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
29 |
model = PRM_MODEL.from_pretrained(model_id).to("cpu").eval()
|
30 |
|
|
|
31 |
def compute_rewards(problem, response):
|
32 |
# 准备数据
|
33 |
data = {
|
34 |
"problem": problem,
|
35 |
"response": response
|
36 |
}
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
input_ids, steps, reward_flags = zip(*processed_data)
|
39 |
|
40 |
# 准备 batch 输入
|
41 |
-
input_ids, attention_mask, reward_flags = prepare_batch_input_for_model(
|
|
|
|
|
|
|
|
|
42 |
input_ids = input_ids.to("cpu")
|
43 |
attention_mask = attention_mask.to("cpu")
|
44 |
if isinstance(reward_flags, torch.Tensor):
|
45 |
reward_flags = reward_flags.to("cpu")
|
46 |
|
|
|
47 |
with torch.no_grad():
|
48 |
-
_, _, rewards = model(
|
|
|
|
|
|
|
|
|
49 |
|
50 |
# 计算 step_rewards
|
51 |
step_rewards = derive_step_rewards(rewards, reward_flags)
|
52 |
return step_rewards[0]
|
53 |
|
54 |
-
#
|
55 |
with gr.Blocks() as demo:
|
56 |
-
gr.Markdown("
|
57 |
-
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
output = gr.JSON(label="Step Rewards")
|
60 |
|
61 |
submit_btn = gr.Button("Compute Rewards")
|
62 |
-
submit_btn.click(
|
|
|
|
|
|
|
|
|
63 |
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
server_name="0.0.0.0",
|
68 |
-
server_port=7860,
|
69 |
-
share=True,# 允许生成公共URL
|
70 |
-
enable_queue=True, # 启用队列
|
71 |
-
api_name="/predict" # 指定API端点名称
|
72 |
-
)
|
|
|
5 |
import gradio as gr
|
6 |
from transformers import AutoTokenizer
|
7 |
|
8 |
+
# 1. 若本地(或该Space)没有 skywork-o1-prm-inference 目录,则clone下来
|
9 |
def setup_environment():
|
10 |
if not os.path.exists("skywork-o1-prm-inference"):
|
11 |
print("Cloning repository...")
|
|
|
17 |
|
18 |
setup_environment()
|
19 |
|
20 |
+
# 2. 在环境准备完成后,再导入项目内的模块
|
21 |
from model_utils.prm_model import PRM_MODEL
|
22 |
from model_utils.io_utils import prepare_input, prepare_batch_input_for_model, derive_step_rewards
|
23 |
|
24 |
+
# 3. 模型ID:可根据你的需求替换
|
25 |
model_id = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B"
|
|
|
|
|
26 |
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
27 |
model = PRM_MODEL.from_pretrained(model_id).to("cpu").eval()
|
28 |
|
29 |
+
# 4. 定义推理函数
|
30 |
def compute_rewards(problem, response):
|
31 |
# 准备数据
|
32 |
data = {
|
33 |
"problem": problem,
|
34 |
"response": response
|
35 |
}
|
36 |
+
# 进行格式化
|
37 |
+
processed_data = [
|
38 |
+
prepare_input(
|
39 |
+
data["problem"],
|
40 |
+
data["response"],
|
41 |
+
tokenizer=tokenizer,
|
42 |
+
step_token="\n"
|
43 |
+
)
|
44 |
+
]
|
45 |
input_ids, steps, reward_flags = zip(*processed_data)
|
46 |
|
47 |
# 准备 batch 输入
|
48 |
+
input_ids, attention_mask, reward_flags = prepare_batch_input_for_model(
|
49 |
+
input_ids,
|
50 |
+
reward_flags,
|
51 |
+
tokenizer.pad_token_id
|
52 |
+
)
|
53 |
input_ids = input_ids.to("cpu")
|
54 |
attention_mask = attention_mask.to("cpu")
|
55 |
if isinstance(reward_flags, torch.Tensor):
|
56 |
reward_flags = reward_flags.to("cpu")
|
57 |
|
58 |
+
# 前向传播
|
59 |
with torch.no_grad():
|
60 |
+
_, _, rewards = model(
|
61 |
+
input_ids=input_ids,
|
62 |
+
attention_mask=attention_mask,
|
63 |
+
return_probs=True
|
64 |
+
)
|
65 |
|
66 |
# 计算 step_rewards
|
67 |
step_rewards = derive_step_rewards(rewards, reward_flags)
|
68 |
return step_rewards[0]
|
69 |
|
70 |
+
# 5. Gradio 可视化界面
|
71 |
with gr.Blocks() as demo:
|
72 |
+
gr.Markdown("## PRM Reward Calculation\n"
|
73 |
+
"在这里输入 problem 和 response,即可计算对每个 step 的 reward。")
|
74 |
+
|
75 |
+
problem_input = gr.Textbox(
|
76 |
+
label="Problem",
|
77 |
+
lines=5,
|
78 |
+
value=(
|
79 |
+
"Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning "
|
80 |
+
"and bakes muffins for her friends every day with four. She sells the remainder "
|
81 |
+
"at the farmers' market daily for $2 per fresh duck egg. How much in dollars "
|
82 |
+
"does she make every day at the farmers' market?"
|
83 |
+
)
|
84 |
+
)
|
85 |
+
response_input = gr.Textbox(
|
86 |
+
label="Response",
|
87 |
+
lines=10,
|
88 |
+
value=(
|
89 |
+
"To determine how much money Janet makes every day at the farmers' market, "
|
90 |
+
"we need to follow these steps:\n1. ..."
|
91 |
+
)
|
92 |
+
)
|
93 |
output = gr.JSON(label="Step Rewards")
|
94 |
|
95 |
submit_btn = gr.Button("Compute Rewards")
|
96 |
+
submit_btn.click(
|
97 |
+
fn=compute_rewards,
|
98 |
+
inputs=[problem_input, response_input],
|
99 |
+
outputs=output
|
100 |
+
)
|
101 |
|
102 |
+
# 6. 启动 Gradio 服务
|
103 |
+
# 注意:在 Hugging Face Spaces 环境中,推荐使用 server_name="0.0.0.0",端口常用 7860/7861。
|
104 |
+
demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
|
|
|
|
|
|
|
|
|