lihongze8 committed on
Commit
822b0c8
·
verified ·
1 Parent(s): d41d753

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -22
app.py CHANGED
@@ -5,7 +5,7 @@ import torch
5
  import gradio as gr
6
  from transformers import AutoTokenizer
7
 
8
- # 设置环境:若本地无skywork-o1-prm-inference目录,则git clone下来,并将其加入sys.path
9
  def setup_environment():
10
  if not os.path.exists("skywork-o1-prm-inference"):
11
  print("Cloning repository...")
@@ -17,56 +17,88 @@ def setup_environment():
17
 
18
  setup_environment()
19
 
20
- # 在环境设置后再导入项目内的模块
21
  from model_utils.prm_model import PRM_MODEL
22
  from model_utils.io_utils import prepare_input, prepare_batch_input_for_model, derive_step_rewards
23
 
24
- # 模型ID
25
  model_id = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B"
26
-
27
- # 初始化模型和tokenizer
28
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
29
  model = PRM_MODEL.from_pretrained(model_id).to("cpu").eval()
30
 
 
31
  def compute_rewards(problem, response):
32
  # 准备数据
33
  data = {
34
  "problem": problem,
35
  "response": response
36
  }
37
- processed_data = [prepare_input(data["problem"], data["response"], tokenizer=tokenizer, step_token="\n")]
 
 
 
 
 
 
 
 
38
  input_ids, steps, reward_flags = zip(*processed_data)
39
 
40
  # 准备 batch 输入
41
- input_ids, attention_mask, reward_flags = prepare_batch_input_for_model(input_ids, reward_flags, tokenizer.pad_token_id)
 
 
 
 
42
  input_ids = input_ids.to("cpu")
43
  attention_mask = attention_mask.to("cpu")
44
  if isinstance(reward_flags, torch.Tensor):
45
  reward_flags = reward_flags.to("cpu")
46
 
 
47
  with torch.no_grad():
48
- _, _, rewards = model(input_ids=input_ids, attention_mask=attention_mask, return_probs=True)
 
 
 
 
49
 
50
  # 计算 step_rewards
51
  step_rewards = derive_step_rewards(rewards, reward_flags)
52
  return step_rewards[0]
53
 
54
- # 搭建 Gradio 界面
55
  with gr.Blocks() as demo:
56
- gr.Markdown("# PRM Reward Calculation")
57
- problem_input = gr.Textbox(label="Problem", lines=5, value="Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?")
58
- response_input = gr.Textbox(label="Response", lines=10, value="To determine how much money Janet makes every day at the farmers' market, we need to follow these steps:\n1. ...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  output = gr.JSON(label="Step Rewards")
60
 
61
  submit_btn = gr.Button("Compute Rewards")
62
- submit_btn.click(fn=compute_rewards, inputs=[problem_input, response_input], outputs=output)
 
 
 
 
63
 
64
- if __name__ == "__main__":
65
- # 修改launch参数,启用API
66
- demo.launch(
67
- server_name="0.0.0.0",
68
- server_port=7860,
69
- share=True,# 允许生成公共URL
70
- enable_queue=True, # 启用队列
71
- api_name="/predict" # 指定API端点名称
72
- )
 
5
  import gradio as gr
6
  from transformers import AutoTokenizer
7
 
8
+ # 1. 若本地(或该Space)没有 skywork-o1-prm-inference 目录,则clone下来
9
  def setup_environment():
10
  if not os.path.exists("skywork-o1-prm-inference"):
11
  print("Cloning repository...")
 
17
 
18
  setup_environment()
19
 
20
+ # 2. 在环境准备完成后,再导入项目内的模块
21
  from model_utils.prm_model import PRM_MODEL
22
  from model_utils.io_utils import prepare_input, prepare_batch_input_for_model, derive_step_rewards
23
 
24
# 3. Model ID: replace with another PRM checkpoint if needed.
model_id = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B"

# NOTE(review): trust_remote_code executes code shipped in the hub repo —
# confirm the Skywork repo is a trusted source before deploying.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# Load the PRM model on CPU in eval mode (this Space runs without a GPU).
model = PRM_MODEL.from_pretrained(model_id).to("cpu").eval()
28
 
29
# 4. Inference function
def compute_rewards(problem, response):
    """Compute per-step PRM rewards for one (problem, response) pair.

    Args:
        problem: Problem statement text.
        response: Model response whose steps (delimited by "\n") are scored.

    Returns:
        The list of per-step reward scores for this single input pair
        (first — and only — element of the batch).
    """
    # Tokenize and mark step boundaries; "\n" separates reasoning steps.
    # (The former intermediate `data` dict was pure indirection — removed.)
    processed_data = [
        prepare_input(problem, response, tokenizer=tokenizer, step_token="\n")
    ]
    # `steps` from prepare_input is not used below, hence the throwaway name.
    input_ids, _steps, reward_flags = zip(*processed_data)

    # Pad the (single-element) batch for the model.
    input_ids, attention_mask, reward_flags = prepare_batch_input_for_model(
        input_ids,
        reward_flags,
        tokenizer.pad_token_id,
    )
    input_ids = input_ids.to("cpu")
    attention_mask = attention_mask.to("cpu")
    if isinstance(reward_flags, torch.Tensor):
        reward_flags = reward_flags.to("cpu")

    # Forward pass; inference only, so no gradient tracking.
    with torch.no_grad():
        _, _, rewards = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_probs=True,
        )

    # Collapse token-level rewards into per-step rewards.
    step_rewards = derive_step_rewards(rewards, reward_flags)
    return step_rewards[0]
69
 
70
# 5. Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## PRM Reward Calculation\n"
                "在这里输入 problem response,即可计算对每个 step reward。")

    problem_input = gr.Textbox(
        label="Problem",
        lines=5,
        value=(
            "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning "
            "and bakes muffins for her friends every day with four. She sells the remainder "
            "at the farmers' market daily for $2 per fresh duck egg. How much in dollars "
            "does she make every day at the farmers' market?"
        )
    )
    response_input = gr.Textbox(
        label="Response",
        lines=10,
        value=(
            "To determine how much money Janet makes every day at the farmers' market, "
            "we need to follow these steps:\n1. ..."
        )
    )
    output = gr.JSON(label="Step Rewards")

    submit_btn = gr.Button("Compute Rewards")
    submit_btn.click(
        fn=compute_rewards,
        inputs=[problem_input, response_input],
        outputs=output
    )

# 6. Launch the Gradio server.
# On Hugging Face Spaces, bind to 0.0.0.0; 7860 is the conventional port.
# Guarded so importing this module (e.g. from tests or tooling) does not
# start a web server as a side effect.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)