Update app.py
Browse files
app.py
CHANGED
@@ -1,84 +1,65 @@
|
|
1 |
import os
|
2 |
import subprocess
|
3 |
-
import
|
4 |
import torch
|
|
|
5 |
from transformers import AutoTokenizer
|
6 |
|
7 |
-
#
|
8 |
-
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
-
|
12 |
-
if not os.path.exists(REPO_DIR):
|
13 |
-
subprocess.run(["git", "clone", REPO_URL], check=True)
|
14 |
|
15 |
-
#
|
16 |
from skywork-o1-prm-inference.model_utils.prm_model import PRM_MODEL
|
17 |
-
from skywork-o1-prm-inference.model_utils.io_utils import
|
18 |
-
prepare_input,
|
19 |
-
prepare_batch_input_for_model,
|
20 |
-
derive_step_rewards
|
21 |
-
)
|
22 |
|
23 |
-
#
|
24 |
-
|
25 |
|
26 |
# 初始化模型和tokenizer
|
27 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
28 |
-
model = PRM_MODEL.from_pretrained(
|
29 |
|
30 |
-
def
|
31 |
-
#
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
)
|
38 |
-
input_ids, steps, reward_flags = processed_input
|
39 |
|
40 |
-
#
|
41 |
-
input_ids, attention_mask, reward_flags = prepare_batch_input_for_model(
|
42 |
-
[input_ids],
|
43 |
-
[reward_flags],
|
44 |
-
tokenizer.pad_token_id
|
45 |
-
)
|
46 |
input_ids = input_ids.to("cpu")
|
47 |
attention_mask = attention_mask.to("cpu")
|
48 |
if isinstance(reward_flags, torch.Tensor):
|
49 |
reward_flags = reward_flags.to("cpu")
|
50 |
|
51 |
-
# 模型推理
|
52 |
with torch.no_grad():
|
53 |
-
_, _, rewards = model(
|
54 |
-
input_ids=input_ids,
|
55 |
-
attention_mask=attention_mask,
|
56 |
-
return_probs=True
|
57 |
-
)
|
58 |
|
59 |
-
#
|
60 |
step_rewards = derive_step_rewards(rewards, reward_flags)
|
61 |
return step_rewards[0]
|
62 |
|
63 |
-
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
66 |
|
67 |
-
|
68 |
-
|
69 |
-
fn=inference_interface,
|
70 |
-
inputs=[
|
71 |
-
gr.Textbox(lines=4, label="Problem (题目)"),
|
72 |
-
gr.Textbox(lines=6, label="Response (回答)"),
|
73 |
-
],
|
74 |
-
outputs="json",
|
75 |
-
title="Skywork-o1-prm-inference Demo",
|
76 |
-
description=(
|
77 |
-
"输入题目和回答,点击提交查看其每个 step 对应的 reward,"
|
78 |
-
"这些 reward 值用于度量回答每一步的质量。"
|
79 |
-
),
|
80 |
-
)
|
81 |
|
82 |
if __name__ == "__main__":
|
83 |
-
|
84 |
-
|
|
|
1 |
import os
|
2 |
import subprocess
|
3 |
+
import sys
|
4 |
import torch
|
5 |
+
import gradio as gr
|
6 |
from transformers import AutoTokenizer
|
7 |
|
8 |
+
# 设置环境:若本地无skywork-o1-prm-inference目录,则git clone下来,并将其加入sys.path
|
9 |
+
def setup_environment():
|
10 |
+
if not os.path.exists("skywork-o1-prm-inference"):
|
11 |
+
print("Cloning repository...")
|
12 |
+
subprocess.run(["git", "clone", "https://github.com/SkyworkAI/skywork-o1-prm-inference.git"], check=True)
|
13 |
+
repo_path = os.path.abspath("skywork-o1-prm-inference")
|
14 |
+
if repo_path not in sys.path:
|
15 |
+
sys.path.append(repo_path)
|
16 |
+
print(f"Added {repo_path} to Python path")
|
17 |
|
18 |
+
setup_environment()
|
|
|
|
|
19 |
|
20 |
+
# 在环境设置后再导入项目内的模块
|
21 |
from skywork-o1-prm-inference.model_utils.prm_model import PRM_MODEL
|
22 |
+
from skywork-o1-prm-inference.model_utils.io_utils import prepare_input, prepare_batch_input_for_model, derive_step_rewards
|
|
|
|
|
|
|
|
|
23 |
|
24 |
+
# 模型ID
|
25 |
+
model_id = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B"
|
26 |
|
27 |
# 初始化模型和tokenizer
|
28 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
29 |
+
model = PRM_MODEL.from_pretrained(model_id).to("cpu").eval()
|
30 |
|
31 |
+
def compute_rewards(problem, response):
|
32 |
+
# 准备数据
|
33 |
+
data = {
|
34 |
+
"problem": problem,
|
35 |
+
"response": response
|
36 |
+
}
|
37 |
+
processed_data = [prepare_input(data["problem"], data["response"], tokenizer=tokenizer, step_token="\n")]
|
38 |
+
input_ids, steps, reward_flags = zip(*processed_data)
|
|
|
39 |
|
40 |
+
# 准备 batch 输入
|
41 |
+
input_ids, attention_mask, reward_flags = prepare_batch_input_for_model(input_ids, reward_flags, tokenizer.pad_token_id)
|
|
|
|
|
|
|
|
|
42 |
input_ids = input_ids.to("cpu")
|
43 |
attention_mask = attention_mask.to("cpu")
|
44 |
if isinstance(reward_flags, torch.Tensor):
|
45 |
reward_flags = reward_flags.to("cpu")
|
46 |
|
|
|
47 |
with torch.no_grad():
|
48 |
+
_, _, rewards = model(input_ids=input_ids, attention_mask=attention_mask, return_probs=True)
|
|
|
|
|
|
|
|
|
49 |
|
50 |
+
# 计算 step_rewards
|
51 |
step_rewards = derive_step_rewards(rewards, reward_flags)
|
52 |
return step_rewards[0]
|
53 |
|
54 |
+
# 搭建 Gradio 界面
|
55 |
+
with gr.Blocks() as demo:
|
56 |
+
gr.Markdown("# PRM Reward Calculation")
|
57 |
+
problem_input = gr.Textbox(label="Problem", lines=5, value="Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?")
|
58 |
+
response_input = gr.Textbox(label="Response", lines=10, value="To determine how much money Janet makes every day at the farmers' market, we need to follow these steps:\n1. ...")
|
59 |
+
output = gr.JSON(label="Step Rewards")
|
60 |
|
61 |
+
submit_btn = gr.Button("Compute Rewards")
|
62 |
+
submit_btn.click(fn=compute_rewards, inputs=[problem_input, response_input], outputs=output)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
|
64 |
if __name__ == "__main__":
|
65 |
+
demo.launch(server_name="0.0.0.0", server_port=7860)
|
|