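"""Evaluate a trained GPT-J reward model on the CarperAI
openai_summarize_comparisons test split. Accuracy is the fraction of pairs
where the model scores the human-chosen summary above the rejected one."""
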
import random

import numpy as np
import torch
from datasets import load_dataset
from reward_model import GPTRewardModel
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoTokenizer

def set_seed(seed_val=42):
    # Seed every RNG in play so evaluation runs are reproducible.
    random.seed(seed_val)
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)
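
# Note: set_seed is defined but never called below; the evaluation loop is
# deterministic as written (no sampling, shuffle=False). If a stochastic step
# is added, call it first, e.g. set_seed(42).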

def create_comparison_dataset(path="CarperAI/openai_summarize_comparisons", split="train"):
    dataset = load_dataset(path, split=split)
    if split == "test":
        # Cap the test split at 5k comparisons to keep evaluation fast.
        dataset = dataset.select(range(5000))
    pairs = []
    for sample in tqdm(dataset):
        pair = {}
        prompt = sample["prompt"]
        chosen_summary = sample["chosen"]
        rejected_summary = sample["rejected"]
        # Drop degenerate comparisons: identical or very short summaries.
        if chosen_summary == rejected_summary:
            continue
        if len(chosen_summary.split()) < 5 or len(rejected_summary.split()) < 5:
            continue
        pair["chosen"] = prompt + "\n" + chosen_summary
        pair["rejected"] = prompt + "\n" + rejected_summary
        pairs.append(pair)
    return pairs
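
# Example usage (a hypothetical quick check; downloads the dataset from the Hub):
# pairs = create_comparison_dataset(split="test")
# print(len(pairs), pairs[0]["chosen"][:100])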

class PairwiseDataset(Dataset):
    def __init__(self, pairs, tokenizer, max_length):
        self.chosen_input_ids = []
        self.chosen_attn_masks = []
        self.rejected_input_ids = []
        self.rejected_attn_masks = []
        for pair in pairs:
            chosen, rejected = pair["chosen"], pair["rejected"]
            chosen_encodings_dict = tokenizer(
                "<|startoftext|>" + chosen + "<|endoftext|>",
                truncation=True,
                max_length=max_length,
                padding="max_length",
                return_tensors="pt",
            )
            rejected_encodings_dict = tokenizer(
                "<|startoftext|>" + rejected + "<|endoftext|>",
                truncation=True,
                max_length=max_length,
                padding="max_length",
                return_tensors="pt",
            )
            # Keep only pairs whose tokenizations still differ after truncation;
            # identical inputs carry no signal for the pairwise comparison.
            if not torch.all(torch.eq(chosen_encodings_dict["input_ids"], rejected_encodings_dict["input_ids"])).item():
                self.chosen_input_ids.append(chosen_encodings_dict["input_ids"])
                self.chosen_attn_masks.append(chosen_encodings_dict["attention_mask"])
                self.rejected_input_ids.append(rejected_encodings_dict["input_ids"])
                self.rejected_attn_masks.append(rejected_encodings_dict["attention_mask"])

    def __len__(self):
        return len(self.chosen_input_ids)

    def __getitem__(self, idx):
        return (
            self.chosen_input_ids[idx],
            self.chosen_attn_masks[idx],
            self.rejected_input_ids[idx],
            self.rejected_attn_masks[idx],
        )
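
# Each dataset item is a 4-tuple of (1, max_length) tensors, consumed by the
# collator below as f[0]..f[3]:
#   (chosen_input_ids, chosen_attn_mask, rejected_input_ids, rejected_attn_mask)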

class DataCollatorReward:
    def __call__(self, data):
        # Concatenate all chosen examples first, then all rejected ones, so the
        # reward model can split the batch in half by position.
        batch = {}
        batch["input_ids"] = torch.cat([f[0] for f in data] + [f[2] for f in data])
        batch["attention_mask"] = torch.cat([f[1] for f in data] + [f[3] for f in data])
        batch["labels"] = torch.tensor([0] * len(data) + [1] * len(data))
        return batch
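
# Shape check for the collated batch (illustrative numbers matching the
# settings below: batch_size B = 6, max_length = 550):
#   input_ids:      (2B, 550) = (12, 550)  -- 6 chosen rows, then 6 rejected
#   attention_mask: (12, 550)
#   labels:         (12,)                  -- 0 marks chosen, 1 marks rejected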

if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
    tokenizer.pad_token = tokenizer.eos_token
    PAD_ID = tokenizer(tokenizer.pad_token)["input_ids"][0]  # not used below; kept from the original script

    # Load the SFT-initialized reward model and its trained checkpoint.
    model = GPTRewardModel("CarperAI/openai_summarize_tldr_sft")
    model.load_state_dict(torch.load("rm_checkpoint/pytorch_model.bin"))
    max_length = 550

    val_pairs = create_comparison_dataset("CarperAI/openai_summarize_comparisons", "test")
    dev_dataset = PairwiseDataset(val_pairs, tokenizer, max_length=max_length)
    dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=6, collate_fn=DataCollatorReward())

    model.cuda()
    model.eval()
    model.half()  # fp16 inference so GPT-J fits comfortably on a single GPU

    correct = 0
    chosen_list = []
    reject_list = []
    with torch.no_grad():
        for step, batch in tqdm(enumerate(dev_dataloader), total=len(dev_dataloader)):
            for x in batch:
                batch[x] = batch[x].cuda()
            outputs = model(**batch)
            # Count pairs where the chosen summary gets the higher reward score.
            correct += (outputs["chosen_end_scores"] > outputs["rejected_end_scores"]).sum().item()
            chosen_list.append(outputs["chosen_end_scores"].cpu())
            reject_list.append(outputs["rejected_end_scores"].cpu())
    print("Total accuracy: ", correct / len(dev_dataset))