lihongze8 committed on
Commit
f71f486
·
verified ·
1 Parent(s): c442fd5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -12
app.py CHANGED
@@ -7,11 +7,9 @@ import json
7
  def setup_environment():
8
  if not os.path.exists("skywork-o1-prm-inference"):
9
  print("Cloning repository...")
10
- subprocess.run(["git", "clone", "https://github.com/SkyworkAI/skywork-o1-prm-inference.git"], check=True)
11
- repo_path = os.path.abspath("skywork-o1-prm-inference")
12
  if repo_path not in sys.path:
13
- sys.path.append(repo_path)
14
- print(f"Added {repo_path} to Python path")
15
 
16
  setup_environment()
17
 
@@ -27,26 +25,28 @@ model = PRM_MODEL.from_pretrained(model_id).to("cpu").eval()
27
 
28
  def evaluate(problem, response):
29
  try:
 
30
  processed_data = prepare_input(problem, response, tokenizer=tokenizer, step_token="\n")
31
- input_ids, steps, reward_flags = [processed_data]input_ids, attention_mask, reward_flags = prepare_batch_input_for_model(
32
- input_ids,
 
33
  reward_flags,
34
  tokenizer.pad_token_id
35
- )
36
-
37
  input_ids = input_ids.to("cpu")
38
  attention_mask = attention_mask.to("cpu")
39
  if isinstance(reward_flags, torch.Tensor):
40
  reward_flags = reward_flags.to("cpu")
41
 
 
42
  with torch.no_grad():
43
  _, _, rewards = model(
44
  input_ids=input_ids,
45
  attention_mask=attention_mask,
46
  return_probs=True
47
  )
48
- step_rewards = derive_step_rewards(rewards, reward_flags)
49
- #确保返回的是有效的JSON字符串
50
  return json.dumps(step_rewards[0].tolist())
51
  except Exception as e:
52
  return json.dumps({"error": str(e)})
@@ -63,11 +63,11 @@ iface = gr.Interface(
63
  description="Enter a problem and its response to get step-wise rewards",
64
  examples=[
65
  [
66
- "Janet'sducks lay 16 eggs per day...",
67
  "To determine how much money Janet makes..."
68
  ]
69
  ],
70
- cache_examples=False# 禁用示例缓存
71
  )
72
 
73
  # 启动接口
 
7
  def setup_environment():
8
  if not os.path.exists("skywork-o1-prm-inference"):
9
  print("Cloning repository...")
10
+ subprocess.run(["git", "clone", "https://github.com/SkyworkAI/skywork-o1-prm-inference.git"], check=True)repo_path = os.path.abspath("skywork-o1-prm-inference")
 
11
  if repo_path not in sys.path:
12
+ sys.path.append(repo_path)print(f"Added {repo_path} to Python path")
 
13
 
14
  setup_environment()
15
 
 
25
 
26
  def evaluate(problem, response):
27
  try:
28
+ # 处理输入数据
29
  processed_data = prepare_input(problem, response, tokenizer=tokenizer, step_token="\n")
30
+ input_ids, steps, reward_flags = [processed_data]# 准备批处理输入
31
+ input_ids, attention_mask, reward_flags = prepare_batch_input_for_model(
32
+ input_ids,
33
  reward_flags,
34
  tokenizer.pad_token_id
35
+ )# 确保在CPU上
 
36
  input_ids = input_ids.to("cpu")
37
  attention_mask = attention_mask.to("cpu")
38
  if isinstance(reward_flags, torch.Tensor):
39
  reward_flags = reward_flags.to("cpu")
40
 
41
+ # 模型推理
42
  with torch.no_grad():
43
  _, _, rewards = model(
44
  input_ids=input_ids,
45
  attention_mask=attention_mask,
46
  return_probs=True
47
  )
48
+ # 计算步骤奖励
49
+ step_rewards = derive_step_rewards(rewards, reward_flags)# 确保返回的是有效的JSON字符串
50
  return json.dumps(step_rewards[0].tolist())
51
  except Exception as e:
52
  return json.dumps({"error": str(e)})
 
63
  description="Enter a problem and its response to get step-wise rewards",
64
  examples=[
65
  [
66
+ "Janet's ducks lay 16 eggs per day...",
67
  "To determine how much money Janet makes..."
68
  ]
69
  ],
70
+ cache_examples=False # 禁用示例缓存
71
  )
72
 
73
  # 启动接口