|
import os |
|
import subprocess |
|
import sys |
|
import torch |
|
import gradio as gr |
|
from transformers import AutoTokenizer |
|
|
|
|
|
def setup_environment(
    repo_url: str = "https://github.com/SkyworkAI/skywork-o1-prm-inference.git",
    repo_dir: str = "skywork-o1-prm-inference",
) -> str:
    """Make the PRM inference repository importable.

    Clones ``repo_url`` into ``repo_dir`` when that path does not exist yet,
    then appends the absolute path of ``repo_dir`` to ``sys.path`` unless it
    is already listed there.

    Args:
        repo_url: Git URL to clone from.
        repo_dir: Checkout location. A relative path is resolved against the
            current working directory.

    Returns:
        The absolute path of the checkout directory.

    Raises:
        subprocess.CalledProcessError: If ``git clone`` exits with a non-zero
            status.
    """
    if not os.path.exists(repo_dir):
        print("Cloning repository...")
        # Name the target directory explicitly so the existence check above,
        # the clone, and the sys.path entry below all refer to the same path.
        subprocess.run(["git", "clone", repo_url, repo_dir], check=True)

    repo_path = os.path.abspath(repo_dir)
    if repo_path not in sys.path:
        sys.path.append(repo_path)
        print(f"Added {repo_path} to Python path")
    return repo_path
|
|
|
# Run the repository setup first: the model_utils imports below are expected
# to come from the checkout that setup_environment() appends to sys.path.
setup_environment()

from model_utils.prm_model import PRM_MODEL  # noqa: E402
from model_utils.io_utils import (  # noqa: E402
    prepare_input,
    prepare_batch_input_for_model,
    derive_step_rewards,
)

# Hugging Face Hub identifier shared by the tokenizer and the model.
model_id = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B"

# NOTE(review): trust_remote_code=True permits code shipped with the model
# repository to run while the tokenizer is loaded.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

# Load the PRM once at import time, place it on the CPU and call .eval().
model = PRM_MODEL.from_pretrained(model_id)
model = model.to("cpu").eval()
|
|
|
def compute_rewards(problem, response):
    """Run the PRM on one problem/response pair and return its step rewards.

    The pair is encoded with ``prepare_input`` (using a newline as the step
    token), wrapped into a batch of one, passed through the module-level
    ``model`` with gradients disabled, and reduced with
    ``derive_step_rewards``.

    Args:
        problem: Problem statement text.
        response: Candidate solution text.

    Returns:
        The first (and only) entry of the ``derive_step_rewards`` result for
        the single-example batch.
    """
    # Encode the single example, then turn each returned field into a
    # one-element tuple so the batching helper receives batch-shaped inputs.
    encoded = prepare_input(problem, response, tokenizer=tokenizer, step_token="\n")
    batch_ids, _steps, batch_flags = zip(encoded)

    input_ids, attention_mask, reward_flags = prepare_batch_input_for_model(
        batch_ids, batch_flags, tokenizer.pad_token_id
    )

    device = "cpu"
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    # The flags are only moved when the helper returned them as a tensor.
    if isinstance(reward_flags, torch.Tensor):
        reward_flags = reward_flags.to(device)

    with torch.no_grad():
        _, _, rewards = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_probs=True,
        )

    return derive_step_rewards(rewards, reward_flags)[0]
|
|
|
|
|
# Example texts pre-filled into the two input boxes.
_DEFAULT_PROBLEM = (
    "Janet's ducks lay 16 eggs per day. "
    "She eats three for breakfast every morning and bakes muffins for her friends every day with four. "
    "She sells the remainder at the farmers' market daily for $2 per fresh duck egg. "
    "How much in dollars does she make every day at the farmers' market?"
)
_DEFAULT_RESPONSE = (
    "To determine how much money Janet makes every day at the farmers' market, "
    "we need to follow these steps:\n1. ..."
)

# UI: two text inputs, a JSON output panel, and a button wired to compute_rewards.
with gr.Blocks() as demo:
    gr.Markdown("# PRM Reward Calculation")
    problem_input = gr.Textbox(label="Problem", lines=5, value=_DEFAULT_PROBLEM)
    response_input = gr.Textbox(label="Response", lines=10, value=_DEFAULT_RESPONSE)
    output = gr.JSON(label="Step Rewards")

    submit_btn = gr.Button("Compute Rewards")
    submit_btn.click(
        fn=compute_rewards,
        inputs=[problem_input, response_input],
        outputs=output,
    )

if __name__ == "__main__":
    # Serve on all network interfaces, port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)
|
|