"""Trace a reward model with TorchScript and write it into ``model_store/``.

The script builds a :class:`RewardModel` on top of a Hugging Face base model,
loads the reward model's weights from a local directory or a Hub snapshot,
traces the model on a sample input, and saves the trace together with a config
rendered from the ``triton_config.pbtxt`` template located next to this file.
"""
import argparse
import os
from string import Template

import torch
from huggingface_hub import snapshot_download
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer

parser = argparse.ArgumentParser()
parser.add_argument("--base_model", type=str, required=True, help="Path to HF checkpoint with the base model")
parser.add_argument(
    "--checkpoint",
    type=str,
    required=True,
    help="Path to either a local directory or a HF checkpoint with reward model's weights",
)
parser.add_argument("--revision", type=str, required=False, help="Optional branch/commit of the HF checkpoint")
parser.add_argument("--device", type=int, default=0)
args = parser.parse_args()

# Use the last path component as the model name. Normalising first means a
# trailing separator ("path/to/ckpt/") no longer produces an empty name.
model_name = os.path.basename(os.path.normpath(args.checkpoint))
device = torch.device(args.device)


class RewardModel(nn.Module):
    """Base model's transformer trunk followed by a bias-free scalar value head.

    The base model loaded from ``checkpoint_path`` must expose a
    ``.transformer`` submodule and a ``config.n_embd`` attribute.
    """

    def __init__(self, checkpoint_path, eos_token_id):
        """Build the trunk from ``checkpoint_path`` and attach a fresh value head.

        Args:
            checkpoint_path: Passed to ``AutoModelForCausalLM.from_pretrained``.
            eos_token_id: Token id whose position selects the returned reward.
        """
        super().__init__()
        model = AutoModelForCausalLM.from_pretrained(checkpoint_path)
        self.transformer = model.transformer
        self.v_head = nn.Linear(model.config.n_embd, 1, bias=False)
        self.eos_token_id = eos_token_id

    def forward(self, input_ids):
        """Return one reward per row of ``input_ids``.

        Args:
            input_ids: 2-D tensor of token ids (rows are indexed along dim 0,
                token positions along dim 1).

        Returns:
            1-D tensor holding, for each row, the value-head output at the
            first position equal to ``eos_token_id``.
        """
        # First element of the transformer output; presumably the per-token
        # hidden states -- TODO confirm for the base model in use.
        states = self.transformer(input_ids)[0]
        rewards = self.v_head(states).squeeze(-1)
        # argmax over a 0/1 mask yields the first matching position.
        # NOTE(review): for a row with no EOS token the mask is all zeros and
        # argmax returns 0, so the reward at position 0 is used -- confirm
        # that this fallback is intended.
        ends = torch.argmax((input_ids == self.eos_token_id).float(), dim=1).view(-1, 1)
        returns = torch.gather(rewards, 1, ends).squeeze(-1)
        return returns


def _find_weights_file(directory, filenames):
    """Return the path of the first ``.pt``/``.bin`` entry of ``filenames``.

    Raises:
        FileNotFoundError: If no entry has a ``.pt`` or ``.bin`` suffix.
    """
    for fname in filenames:
        if fname.endswith((".pt", ".bin")):
            return os.path.join(directory, fname)
    raise FileNotFoundError(f"no .pt or .bin weights file found in {directory}")


if os.path.isdir(args.checkpoint):
    directory = args.checkpoint
else:
    directory = snapshot_download(args.checkpoint, revision=args.revision)

filenames = os.listdir(directory)
print(f"searching through {filenames} in {directory}")
checkpoint = _find_weights_file(directory, filenames)

tokenizer = AutoTokenizer.from_pretrained(args.base_model)
model = RewardModel(args.base_model, tokenizer.eos_token_id)
# Load onto the CPU first so the checkpoint does not require the device it was
# saved from; the model is moved to the requested device below.
# NOTE(review): torch.load unpickles the file -- only use trusted checkpoints.
model.load_state_dict(torch.load(checkpoint, map_location="cpu"))
model.eval()
model.requires_grad_(False)
model = model.half().to(device)

sample_input = tokenizer("reward model's hash", return_tensors="pt").to(device)
# The label is spelled out so the printed text stays the same as before the
# variable was renamed away from the builtin name `input`.
print(f"model(input.input_ids)={model(sample_input.input_ids)!r}")

traced_script_module = torch.jit.trace(model, sample_input.input_ids)

model_dir = f"model_store/{model_name}"
os.makedirs(f"{model_dir}/1", exist_ok=True)
traced_script_module.save(f"{model_dir}/1/traced-model.pt")

# Render the config template that sits next to this script.
config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "triton_config.pbtxt")
with open(config_path, encoding="utf-8") as f:
    template = Template(f.read())
config = template.substitute({"model_name": model_name})
with open(f"{model_dir}/config.pbtxt", "w", encoding="utf-8") as f:
    f.write(config)