Update app.py
Browse files
app.py
CHANGED
@@ -7,11 +7,9 @@ import json
|
|
7 |
def setup_environment():
    """Make the Skywork PRM inference code importable.

    Clones the ``skywork-o1-prm-inference`` repository into the current
    working directory when no entry of that name exists there, then appends
    the checkout's absolute path to ``sys.path`` (once).

    Raises:
        subprocess.CalledProcessError: if ``git clone`` exits non-zero
            (``check=True``).
    """
    if not os.path.exists("skywork-o1-prm-inference"):
        print("Cloning repository...")
        subprocess.run(
            ["git", "clone", "https://github.com/SkyworkAI/skywork-o1-prm-inference.git"],
            check=True,
        )
    # Computed outside the clone guard so an already-present checkout is
    # still registered on the import path.
    repo_path = os.path.abspath("skywork-o1-prm-inference")
    if repo_path not in sys.path:
        sys.path.append(repo_path)
        print(f"Added {repo_path} to Python path")
|
15 |
|
16 |
setup_environment()
|
17 |
|
@@ -27,26 +25,28 @@ model = PRM_MODEL.from_pretrained(model_id).to("cpu").eval()
|
|
27 |
|
28 |
def evaluate(problem, response):
|
29 |
try:
|
|
|
30 |
processed_data = prepare_input(problem, response, tokenizer=tokenizer, step_token="\n")
|
31 |
-
input_ids, steps, reward_flags = [processed_data]
|
32 |
-
|
|
|
33 |
reward_flags,
|
34 |
tokenizer.pad_token_id
|
35 |
-
)
|
36 |
-
|
37 |
input_ids = input_ids.to("cpu")
|
38 |
attention_mask = attention_mask.to("cpu")
|
39 |
if isinstance(reward_flags, torch.Tensor):
|
40 |
reward_flags = reward_flags.to("cpu")
|
41 |
|
|
|
42 |
with torch.no_grad():
|
43 |
_, _, rewards = model(
|
44 |
input_ids=input_ids,
|
45 |
attention_mask=attention_mask,
|
46 |
return_probs=True
|
47 |
)
|
48 |
-
|
49 |
-
|
50 |
return json.dumps(step_rewards[0].tolist())
|
51 |
except Exception as e:
|
52 |
return json.dumps({"error": str(e)})
|
@@ -63,11 +63,11 @@ iface = gr.Interface(
|
|
63 |
description="Enter a problem and its response to get step-wise rewards",
|
64 |
examples=[
|
65 |
[
|
66 |
-
"Janet'
|
67 |
"To determine how much money Janet makes..."
|
68 |
]
|
69 |
],
|
70 |
-
cache_examples=False# 禁用示例缓存
|
71 |
)
|
72 |
|
73 |
# Launch the interface
|
|
|
7 |
def setup_environment():
    """Ensure the Skywork PRM inference repository is present and importable.

    If ``skywork-o1-prm-inference`` does not exist in the current working
    directory it is cloned from GitHub. The absolute path of the checkout is
    then appended to ``sys.path`` unless it is already listed, so repeated
    calls do not add duplicates.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` exits non-zero
            (``check=True``).
    """
    if not os.path.exists("skywork-o1-prm-inference"):
        print("Cloning repository...")
        subprocess.run(
            ["git", "clone", "https://github.com/SkyworkAI/skywork-o1-prm-inference.git"],
            check=True,
        )

    # Resolve the path whether or not a clone just happened; an existing
    # checkout must be put on the import path as well.
    repo_path = os.path.abspath("skywork-o1-prm-inference")
    if repo_path not in sys.path:
        sys.path.append(repo_path)
        print(f"Added {repo_path} to Python path")
|
|
|
13 |
|
14 |
setup_environment()
|
15 |
|
|
|
25 |
|
26 |
def evaluate(problem, response):
    """Return step-wise PRM rewards for one problem/response pair as JSON.

    Args:
        problem: Problem statement text.
        response: Candidate solution text; it is split into steps on the
            newline step token passed to ``prepare_input``.

    Returns:
        A JSON string. On success it is the list produced by
        ``step_rewards[0].tolist()``; if any step raises, it is
        ``{"error": "<exception message>"}`` instead of propagating.
    """
    try:
        # Tokenize the input pair.
        processed_data = prepare_input(problem, response, tokenizer=tokenizer, step_token="\n")

        # Transpose the one-sample batch into per-field tuples. Unpacking the
        # bare list ``[processed_data]`` into three names always raised
        # ValueError.
        # NOTE(review): assumes prepare_input returns a 3-tuple
        # (input_ids, steps, reward_flags) for a single sample - confirm
        # against the upstream inference repo.
        input_ids, _steps, reward_flags = zip(*[processed_data])

        # Build the batched model inputs.
        input_ids, attention_mask, reward_flags = prepare_batch_input_for_model(
            input_ids,
            reward_flags,
            tokenizer.pad_token_id,
        )

        # Keep everything on the CPU.
        input_ids = input_ids.to("cpu")
        attention_mask = attention_mask.to("cpu")
        if isinstance(reward_flags, torch.Tensor):
            reward_flags = reward_flags.to("cpu")

        # Model inference without gradient tracking.
        with torch.no_grad():
            _, _, rewards = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_probs=True,
            )

        # Derive the per-step rewards; index 0 is the only sample in the batch.
        step_rewards = derive_step_rewards(rewards, reward_flags)

        # Always hand back a valid JSON string.
        return json.dumps(step_rewards[0].tolist())
    except Exception as e:
        # Deliberate catch-all: failures are reported to the caller as a JSON
        # error payload rather than raised.
        return json.dumps({"error": str(e)})
|
|
|
63 |
description="Enter a problem and its response to get step-wise rewards",
|
64 |
examples=[
|
65 |
[
|
66 |
+
"Janet's ducks lay 16 eggs per day...",
|
67 |
"To determine how much money Janet makes..."
|
68 |
]
|
69 |
],
|
70 |
+
cache_examples=False # 禁用示例缓存
|
71 |
)
|
72 |
|
73 |
# Launch the interface
|