File size: 2,486 Bytes
38733e2
 
 
c442fd5
38733e2
c442fd5
38733e2
 
 
 
 
 
 
 
 
 
 
7f087b8
 
 
 
 
 
 
 
 
 
 
38733e2
 
c442fd5
38733e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c442fd5
 
38733e2
c442fd5
7f087b8
 
 
 
 
38733e2
 
7f087b8
c442fd5
7f087b8
38733e2
 
 
c442fd5
 
38733e2
c442fd5
 
7f087b8
 
 
c442fd5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import subprocess
import sys
import json

# Environment setup
def setup_environment():
    """Clone the skywork-o1-prm-inference repo if missing and add it to sys.path."""
    repo_dir = "skywork-o1-prm-inference"
    if not os.path.exists(repo_dir):
        print("Cloning repository...")
        subprocess.run(
            ["git", "clone", "https://github.com/SkyworkAI/skywork-o1-prm-inference.git"],
            check=True,
        )
    repo_path = os.path.abspath(repo_dir)
    # Avoid duplicate entries on repeated invocation.
    if repo_path not in sys.path:
        sys.path.append(repo_path)
        print(f"Added {repo_path} to Python path")

setup_environment()

import gradio as gr
from transformers import AutoTokenizer
from model_utils.prm_model import PRM_MODEL
from model_utils.io_utils import prepare_input, prepare_batch_input_for_model, derive_step_rewards
import torch

# Hugging Face model id for the Skywork process-reward model (PRM).
model_id = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B"
# trust_remote_code is required because the repo ships custom tokenizer code.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# CPU inference in eval mode; gradients are never needed at serving time.
model = PRM_MODEL.from_pretrained(model_id).to("cpu").eval()

def evaluate(problem, response):
    """Score each step of `response` against `problem` with the PRM.

    Returns a JSON string: either a list of per-step reward floats, or
    {"error": ...} on failure (Gradio renders either as JSON output).
    """
    try:
        processed_data = prepare_input(problem, response, tokenizer=tokenizer, step_token="\n")
        # BUG FIX: the original fused two statements into one invalid line.
        # prepare_input returns one (input_ids, steps, reward_flags) tuple;
        # zip(*[...]) unzips the single-sample batch into the parallel
        # sequences that prepare_batch_input_for_model expects.
        input_ids, steps, reward_flags = zip(*[processed_data])
        input_ids, attention_mask, reward_flags = prepare_batch_input_for_model(
            input_ids, 
            reward_flags, 
            tokenizer.pad_token_id
        )
        
        input_ids = input_ids.to("cpu")
        attention_mask = attention_mask.to("cpu")
        if isinstance(reward_flags, torch.Tensor):
            reward_flags = reward_flags.to("cpu")
        
        with torch.no_grad():
            _, _, rewards = model(
                input_ids=input_ids, 
                attention_mask=attention_mask, 
                return_probs=True
            )
        step_rewards = derive_step_rewards(rewards, reward_flags)
        # Ensure the return value is a valid JSON string; index 0 because the
        # batch holds exactly one sample. NOTE(review): assumes
        # step_rewards[0] exposes .tolist() (tensor/ndarray) — confirm against
        # derive_step_rewards, and drop the call if it returns plain lists.
        return json.dumps(step_rewards[0].tolist())
    except Exception as e:
        # Surface the failure to the UI instead of crashing the server.
        return json.dumps({"error": str(e)})

# Build the Gradio interface: two text inputs -> JSON step-wise rewards.
iface = gr.Interface(
    fn=evaluate,
    inputs=[
        gr.Textbox(label="Problem", lines=4),
        gr.Textbox(label="Response", lines=8)
    ],
    outputs=gr.JSON(),
    title="Problem Response Evaluation",
    description="Enter a problem and its response to get step-wise rewards",
    examples=[
        [
            # BUG FIX: restored the missing space in "Janet's ducks".
            "Janet's ducks lay 16 eggs per day...",
            "To determine how much money Janet makes..."
        ]
    ],
    cache_examples=False,  # disable example caching; examples are illustrative only
)

# Launch the app.
# Bind to all interfaces so it is reachable from outside a container/VM.
iface.launch(server_name="0.0.0.0")