# teachyourselfcoding's picture
# Upload 245 files
# fa6856c
import random
import numpy as np
import torch
from datasets import load_dataset
from reward_model import GPTRewardModel
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
def set_seed(seed_val=42):
    """Seed every random number generator used by this script.

    The same value is applied, in order, to Python's ``random`` module,
    NumPy's global generator, PyTorch's default generator, and PyTorch's
    generators for all CUDA devices.

    Args:
        seed_val: Seed passed to each generator. Defaults to 42.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed_val)
def create_comparison_dataset(path="CarperAI/openai_summarize_comparisons", split="train"):
    """Build chosen/rejected text pairs from a summary-comparison dataset.

    The requested ``split`` of the dataset at ``path`` is loaded with
    ``load_dataset``. For the ``"test"`` split only the first 5000 rows are
    kept. Each row is read through its ``"prompt"``, ``"chosen"`` and
    ``"rejected"`` fields.

    A row is dropped when the two summaries are identical, or when either
    summary has fewer than five whitespace-separated words.

    Args:
        path: Dataset identifier handed to ``load_dataset``.
        split: Name of the split to load.

    Returns:
        A list of dicts with keys ``"chosen"`` and ``"rejected"``; each value
        is the prompt and the corresponding summary joined by a newline.
    """
    min_words = 5

    records = load_dataset(path, split=split)
    if split == "test":
        records = records.select(range(5000))

    pairs = []
    for row in tqdm(records):
        prompt = row["prompt"]
        chosen_summary = row["chosen"]
        rejected_summary = row["rejected"]

        # Identical summaries carry no preference signal.
        if chosen_summary == rejected_summary:
            continue

        # Very short summaries (fewer than five words) are filtered out.
        too_short = any(
            len(summary.split()) < min_words
            for summary in (chosen_summary, rejected_summary)
        )
        if too_short:
            continue

        pairs.append(
            {
                "chosen": "\n".join((prompt, chosen_summary)),
                "rejected": "\n".join((prompt, rejected_summary)),
            }
        )
    return pairs
class PairwiseDataset(Dataset):
    """Tokenized (chosen, rejected) text pairs for reward-model comparison.

    Every text is wrapped in ``<|startoftext|>`` / ``<|endoftext|>`` markers
    and tokenized with truncation and padding to ``max_length``. A pair whose
    chosen and rejected token ids come out identical after tokenization is
    skipped, so the dataset can be shorter than the input list.
    """

    def __init__(self, pairs, tokenizer, max_length):
        """Tokenize all pairs up front and keep the ones that differ.

        Args:
            pairs: Iterable of mappings with ``"chosen"`` and ``"rejected"``
                text entries.
            tokenizer: Callable invoked as ``tokenizer(text, truncation=True,
                max_length=..., padding="max_length", return_tensors="pt")``
                whose result exposes ``"input_ids"`` and ``"attention_mask"``
                tensors.
            max_length: Length every sequence is truncated/padded to.
        """
        self.chosen_input_ids = []
        self.chosen_attn_masks = []
        self.rejected_input_ids = []
        self.rejected_attn_masks = []

        def encode(text):
            # Same tokenizer settings for both sides of a pair.
            return tokenizer(
                "<|startoftext|>" + text + "<|endoftext|>",
                truncation=True,
                max_length=max_length,
                padding="max_length",
                return_tensors="pt",
            )

        for pair in pairs:
            chosen_text = pair["chosen"]
            rejected_text = pair["rejected"]
            chosen_enc = encode(chosen_text)
            rejected_enc = encode(rejected_text)

            chosen_ids = chosen_enc["input_ids"]
            rejected_ids = rejected_enc["input_ids"]
            # Texts that differ only beyond the truncation point tokenize to
            # the same ids; such pairs are dropped.
            if torch.eq(chosen_ids, rejected_ids).all().item():
                continue

            self.chosen_input_ids.append(chosen_ids)
            self.chosen_attn_masks.append(chosen_enc["attention_mask"])
            self.rejected_input_ids.append(rejected_ids)
            self.rejected_attn_masks.append(rejected_enc["attention_mask"])

    def __len__(self):
        """Return the number of retained pairs."""
        return len(self.chosen_input_ids)

    def __getitem__(self, idx):
        """Return ``(chosen_ids, chosen_mask, rejected_ids, rejected_mask)``."""
        columns = (
            self.chosen_input_ids,
            self.chosen_attn_masks,
            self.rejected_input_ids,
            self.rejected_attn_masks,
        )
        return tuple(column[idx] for column in columns)
class DataCollatorReward:
    """Collate pairwise samples into one batch with chosen rows first."""

    def __call__(self, data):
        """Stack a list of 4-tuples into a single batch dict.

        Args:
            data: Sequence of ``(chosen_ids, chosen_mask, rejected_ids,
                rejected_mask)`` tuples of tensors.

        Returns:
            Dict with ``"input_ids"`` and ``"attention_mask"`` (all chosen
            entries concatenated first, then all rejected entries) and
            ``"labels"`` (0 for each chosen entry, 1 for each rejected entry).
        """
        n_pairs = len(data)

        def gather(chosen_slot, rejected_slot):
            # Chosen tensors for every sample, followed by rejected tensors.
            chosen_part = [sample[chosen_slot] for sample in data]
            rejected_part = [sample[rejected_slot] for sample in data]
            return torch.cat(chosen_part + rejected_part)

        return {
            "input_ids": gather(0, 2),
            "attention_mask": gather(1, 3),
            "labels": torch.tensor([0] * n_pairs + [1] * n_pairs),
        }
if __name__ == "__main__":
    # Evaluation entry point: load a trained reward model checkpoint, score
    # the chosen and rejected summaries of the comparison test split, and
    # report how often the chosen summary receives the higher score.
    from torch.utils.data import DataLoader

    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
    # Reuse the end-of-sequence token for padding.
    tokenizer.pad_token = tokenizer.eos_token
    # NOTE(review): PAD_ID is not referenced anywhere else in this script;
    # kept in case it is inspected interactively.
    PAD_ID = tokenizer(tokenizer.pad_token)["input_ids"][0]

    model = GPTRewardModel("CarperAI/openai_summarize_tldr_sft")
    model.load_state_dict(torch.load("rm_checkpoint/pytorch_model.bin"))

    max_length = 550
    val_pairs = create_comparison_dataset("CarperAI/openai_summarize_comparisons", "test")
    dev_dataset = PairwiseDataset(val_pairs, tokenizer, max_length=max_length)
    dev_dataloader = DataLoader(
        dev_dataset,
        shuffle=False,
        batch_size=6,
        collate_fn=DataCollatorReward(),
    )

    model.cuda()
    model.eval()
    model.half()

    correct = 0
    chosen_list = []
    reject_list = []
    with torch.no_grad():
        for batch in tqdm(dev_dataloader, total=len(dev_dataloader)):
            batch = {name: tensor.cuda() for name, tensor in batch.items()}
            # The model output is read through the "chosen_end_scores" and
            # "rejected_end_scores" keys (defined by GPTRewardModel, which is
            # not shown in this file).
            outputs = model(**batch)
            chosen_scores = outputs["chosen_end_scores"]
            rejected_scores = outputs["rejected_end_scores"]
            # Reduce on the tensor and convert to a Python int so `correct`
            # stays a plain number rather than becoming a CUDA tensor (the
            # builtin sum() over a tensor returns a tensor).
            correct += (chosen_scores > rejected_scores).sum().item()
            chosen_list.append(chosen_scores.cpu())
            reject_list.append(rejected_scores.cpu())

    # Fraction of retained pairs where the chosen summary scored higher.
    print("Total accuracy: ", correct / len(dev_dataset))