# NOTE: removed web-scraper residue (file-size line, git commit hashes, and a
# line-number gutter copied from a file-viewer UI) that preceded the actual script.
import os
import subprocess
import sys
import torch
import gradio as gr
from transformers import AutoTokenizer

# Environment setup: clone the skywork-o1-prm-inference repo if it is not present
# locally, then make it importable by appending its path to sys.path.
def setup_environment():
    """Ensure the skywork-o1-prm-inference repository is cloned and on sys.path."""
    repo_dir = "skywork-o1-prm-inference"
    if not os.path.exists(repo_dir):
        print("Cloning repository...")
        # shell=False (list argv) — no shell injection surface; check=True raises on failure.
        subprocess.run(
            ["git", "clone", "https://github.com/SkyworkAI/skywork-o1-prm-inference.git"],
            check=True,
        )
    repo_path = os.path.abspath(repo_dir)
    if repo_path not in sys.path:
        sys.path.append(repo_path)
        print(f"Added {repo_path} to Python path")

setup_environment()

# Import the project's modules only after setup_environment() has extended sys.path.
# FIX: the original `from skywork-o1-prm-inference.model_utils... import ...` is a
# SyntaxError — hyphens are not valid in Python module names. Because the repo root
# itself is appended to sys.path, the packages are importable as `model_utils.*`.
from model_utils.prm_model import PRM_MODEL
from model_utils.io_utils import prepare_input, prepare_batch_input_for_model, derive_step_rewards

# Hugging Face model ID of the Skywork process-reward-model checkpoint.
model_id = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B"

# Initialize tokenizer and model once at import time; CPU, inference (eval) mode.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = PRM_MODEL.from_pretrained(model_id).to("cpu").eval()

def compute_rewards(problem, response):
    """Score each reasoning step of *response* against *problem* with the PRM.

    Args:
        problem: The problem statement as a string.
        response: The step-by-step solution text; steps are separated by "\n".

    Returns:
        A list of per-step reward scores for this single (problem, response) pair.
    """
    # FIX: dropped the pointless intermediate `data` dict and the unused `steps`
    # local from the original — behavior is unchanged.
    processed_data = [prepare_input(problem, response, tokenizer=tokenizer, step_token="\n")]
    # prepare_input yields (input_ids, steps, reward_flags); the step list is not needed here.
    input_ids, _, reward_flags = zip(*processed_data)

    # Pad and collate the (single-element) batch into model-ready tensors.
    input_ids, attention_mask, reward_flags = prepare_batch_input_for_model(
        input_ids, reward_flags, tokenizer.pad_token_id
    )
    input_ids = input_ids.to("cpu")
    attention_mask = attention_mask.to("cpu")
    # reward_flags may come back as a tensor or a plain sequence depending on the
    # io_utils version — presumably; only move it if it is a tensor. TODO confirm.
    if isinstance(reward_flags, torch.Tensor):
        reward_flags = reward_flags.to("cpu")

    # Inference only — no gradients; return_probs=True makes the model emit probabilities.
    with torch.no_grad():
        _, _, rewards = model(input_ids=input_ids, attention_mask=attention_mask, return_probs=True)

    # Collapse token-level rewards into one score per step, then unwrap the batch of one.
    step_rewards = derive_step_rewards(rewards, reward_flags)
    return step_rewards[0]

# Build the Gradio UI: two text inputs (problem, response), a button, and a JSON
# panel that displays the per-step rewards produced by compute_rewards.
with gr.Blocks() as demo:
    gr.Markdown("# PRM Reward Calculation")
    # Inputs are pre-filled with a sample GSM8K-style problem/response for quick demoing.
    problem_input = gr.Textbox(label="Problem", lines=5, value="Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?")
    response_input = gr.Textbox(label="Response", lines=10, value="To determine how much money Janet makes every day at the farmers' market, we need to follow these steps:\n1. ...")
    output = gr.JSON(label="Step Rewards")

    submit_btn = gr.Button("Compute Rewards")
    # Wire the button: run compute_rewards on the two textboxes, render result as JSON.
    submit_btn.click(fn=compute_rewards, inputs=[problem_input, response_input], outputs=output)

# Script entry point: serve on all interfaces, port 7860.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)