|
import os |
|
import subprocess |
|
import sys |
|
import json |
|
|
|
|
|
def setup_environment():
    """Make the Skywork PRM inference helpers importable.

    If a ``skywork-o1-prm-inference`` entry does not exist in the current
    working directory, the upstream repository is cloned there with
    ``git clone`` (failing loudly on a non-zero exit). The directory's
    absolute path is then appended to ``sys.path`` exactly once.
    """
    repo_dir = "skywork-o1-prm-inference"
    repo_url = "https://github.com/SkyworkAI/skywork-o1-prm-inference.git"

    if not os.path.exists(repo_dir):
        print("Cloning repository...")
        subprocess.run(["git", "clone", repo_url], check=True)

    repo_path = os.path.abspath(repo_dir)
    # Guard clause: nothing to do if a previous call already registered it.
    if repo_path in sys.path:
        return
    sys.path.append(repo_path)
    print(f"Added {repo_path} to Python path")
|
|
|
# Must run before the `model_utils` imports below: it clones the inference
# repo if needed and puts it on sys.path so those imports can resolve.
setup_environment()
|
|
|
import gradio as gr |
|
from transformers import AutoTokenizer |
|
from model_utils.prm_model import PRM_MODEL |
|
from model_utils.io_utils import prepare_input, prepare_batch_input_for_model, derive_step_rewards |
|
import torch |
|
|
|
# Hugging Face Hub id of the Skywork process reward model served by this app.
model_id = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B"

# trust_remote_code=True allows code shipped in the model repo to run;
# model_id is a fixed constant here, not user input.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

# Load the PRM weights, keep them on CPU, and switch to eval mode.
model = PRM_MODEL.from_pretrained(model_id)
model = model.to("cpu")
model = model.eval()
|
|
|
def evaluate(problem, response):
    """Score each reasoning step of ``response`` with the PRM.

    The pair is tokenized by ``prepare_input`` with a newline step token,
    presumably so that each line of ``response`` is one step. It is then
    batched as a single sample, run through the model without gradients,
    and reduced to per-step rewards by ``derive_step_rewards``.

    Args:
        problem: Problem statement text.
        response: Candidate solution text.

    Returns:
        A JSON string. On success it holds the step rewards of the single
        sample. If any stage raises, it is ``{"error": "<message>"}``.
    """
    try:
        # prepare_input yields (input_ids, steps, reward_flags) for one sample;
        # the decoded steps are not needed here.
        input_ids, _steps, reward_flags = prepare_input(
            problem, response, tokenizer=tokenizer, step_token="\n"
        )

        # The batching helper takes a batch of samples: wrap the one sample.
        input_ids, attention_mask, reward_flags = prepare_batch_input_for_model(
            [input_ids],
            [reward_flags],
            tokenizer.pad_token_id,
        )

        input_ids = input_ids.to("cpu")
        attention_mask = attention_mask.to("cpu")
        if isinstance(reward_flags, torch.Tensor):
            reward_flags = reward_flags.to("cpu")

        with torch.no_grad():
            _, _, rewards = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_probs=True,
            )

        step_rewards = derive_step_rewards(rewards, reward_flags)

        first_sample = step_rewards[0]
        # NOTE(review): derive_step_rewards appears to return plain Python
        # lists per sample (no .tolist()); accept tensors/arrays as well.
        if hasattr(first_sample, "tolist"):
            first_sample = first_sample.tolist()
        return json.dumps(first_sample)
    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing.
        return json.dumps({"error": str(e)})
|
|
|
|
|
# Gradio UI: two text inputs (problem, response) fed to ``evaluate``, whose
# output is displayed in a JSON component.
iface = gr.Interface(
    fn=evaluate,
    inputs=[
        gr.Textbox(label="Problem", lines=4),
        gr.Textbox(label="Response", lines=8),
    ],
    outputs=gr.JSON(),
    title="Problem Response Evaluation",
    description="Enter a problem and its response to get step-wise rewards",
    examples=[
        [
            "Janet's ducks lay 16 eggs per day. She eats three for breakfast "
            "every morning and bakes muffins for her friends every day with "
            "four. She sells the remainder at the farmers' market daily for $2 "
            "per fresh duck egg. How much in dollars does she make every day "
            "at the farmers' market?",
            # One reasoning step per line: evaluate() uses "\n" as step token.
            "To determine how much money Janet makes every day at the farmers' "
            "market, we need to follow these steps:\n"
            "1. The ducks lay 16 eggs per day.\n"
            "2. Janet eats 3 eggs for breakfast and uses 4 eggs to bake "
            "muffins, so she uses 3 + 4 = 7 eggs.\n"
            "3. The remaining eggs are 16 - 7 = 9.\n"
            "4. She sells each egg for $2, so she makes 9 * 2 = 18 dollars.\n"
            "So, Janet makes $18 every day at the farmers' market.",
        ]
    ],
    cache_examples=False,
)

# Listen on all network interfaces (0.0.0.0) rather than only localhost.
iface.launch(server_name="0.0.0.0")
|
|