# Source: chatlawv1 / trlx / examples / summarize_rlhf / sft / summarize_dataset.py
# Uploaded by teachyourselfcoding ("Upload 245 files"), commit fa6856c.
import json
import pandas as pd
import torch
from datasets import load_dataset
from torch.utils.data import Dataset
def get_dataset_from_jsonl(jsonl_file, return_summary=True):
    """Load Reddit TL;DR records from a JSON-lines file and render them as prompts.

    Each non-blank line must be a JSON object with ``subreddit``, ``title``,
    ``post`` and ``summary`` keys. Every record is rendered as::

        SUBREDDIT: r/<subreddit>
        TITLE: <title>
        POST: <post>
        TL;DR: <summary>

    Args:
        jsonl_file: Path to the JSON-lines file (decoded as UTF-8).
        return_summary: If True, each summary is appended after ``"TL;DR: "``.
            If False, each prompt ends in ``"TL;DR: "`` and the summaries are
            returned separately.

    Returns:
        A ``list[str]`` of prompt+summary texts when ``return_summary`` is
        True; otherwise a ``(prompts, summaries)`` tuple of parallel lists.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks one of the required keys.
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the locale's default
    # encoding. Blank lines (e.g. a trailing newline) would make json.loads
    # raise, so they are skipped.
    with open(jsonl_file, "r", encoding="utf-8") as f:
        dataset = [json.loads(line) for line in f if line.strip()]

    post_list = []
    summary_list = []
    for d in dataset:
        prompt = f"SUBREDDIT: r/{d['subreddit']}\nTITLE: {d['title']}\nPOST: {d['post']}\nTL;DR: "
        if return_summary:
            post_list.append(f"{prompt}{d['summary']}")
        else:
            post_list.append(prompt)
            summary_list.append(d["summary"])

    if not return_summary:
        return post_list, summary_list
    return post_list
class TLDRDataset(Dataset):
    """Supervised fine-tuning dataset of concatenated prompt + label strings.

    Args:
        train_path: Dataset name or path handed to ``load_dataset``.
        tokenizer: Callable invoked as ``tokenizer(text, truncation=True,
            max_length=..., padding="max_length")`` that returns a mapping
            with ``"input_ids"`` and ``"attention_mask"``.
        split: Split passed to ``load_dataset``. Any split whose name
            contains ``"valid"`` is cut down to its first 2000 examples.
        max_length: Length every example is truncated/padded to.
    """

    def __init__(self, train_path, tokenizer, split, max_length=550):
        rows = load_dataset(train_path, split=split)
        # Each record is expected to carry "prompt" and "label" string fields.
        texts = [row["prompt"] + row["label"] for row in rows]
        self.post_list = texts[0:2000] if "valid" in split else texts
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Never populated by this class; retained so existing attribute access keeps working.
        self.input_ids = []
        self.attn_masks = []

    def __len__(self):
        return len(self.post_list)

    def __getitem__(self, idx):
        """Tokenize example ``idx``; ``labels`` is the very same tensor as ``input_ids``."""
        enc = self.tokenizer(
            self.post_list[idx],
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
        )
        ids = torch.tensor(enc["input_ids"])
        # NOTE(review): padding positions are not masked out of `labels`
        # (e.g. with -100); confirm that is intended for the loss.
        return {
            "input_ids": ids,
            "attention_mask": torch.tensor(enc["attention_mask"]),
            "labels": ids,
        }
class ComparisonDataset(Dataset):
    """Pairwise-preference dataset of (preferred, rejected) summary texts.

    Each non-blank line of ``comparison_path`` is a JSON object with an
    ``info`` dict (``subreddit``, ``title``, ``post``), a ``summaries`` list
    whose first two entries are ``{"text": ...}`` dicts, and a ``choice``
    index naming the preferred summary. The preferred summary is always
    stored in ``summaries_0`` and the other one in ``summaries_1``. Both are
    rendered as full ``SUBREDDIT/TITLE/POST/TL;DR`` texts.

    Args:
        comparison_path: Path to the JSON-lines comparison file (read as UTF-8).
        tokenizer: Callable invoked as ``tokenizer([text0, text1],
            truncation=True, max_length=..., padding="max_length")`` that
            returns a mapping with ``"input_ids"`` and ``"attention_mask"``.
        max_length: Length each text is truncated/padded to.
    """

    def __init__(self, comparison_path, tokenizer, max_length=550):
        # JSON text is UTF-8; don't depend on the locale's default encoding.
        # Blank lines (e.g. a trailing newline) are skipped instead of crashing json.loads.
        with open(comparison_path, "r", encoding="utf-8") as f:
            dataset = [json.loads(line) for line in f if line.strip()]
        self.tokenizer = tokenizer
        self.post_list = []
        self.summaries_0 = []
        self.summaries_1 = []
        self.labels = []
        self.max_length = max_length

        def make_text(post, summarize):
            return f"SUBREDDIT: r/{post['subreddit']}\nTITLE: {post['title']}\nPOST: {post['post']}\nTL;DR: {summarize}"

        for sample in dataset:
            info = sample["info"]
            summaries = sample["summaries"]
            self.post_list.append(info["post"])
            # `choice` marks the preferred summary in the raw record. Reorder so
            # the preferred one always lands in summaries_0; any non-zero
            # choice is treated as index 1.
            chosen, rejected = (0, 1) if sample["choice"] == 0 else (1, 0)
            self.summaries_0.append(make_text(info, summaries[chosen]["text"]))
            self.summaries_1.append(make_text(info, summaries[rejected]["text"]))
            # Always 0 because of the reordering above: position 0 is preferred.
            self.labels.append(0)

    def __len__(self):
        return len(self.post_list)

    def __getitem__(self, idx):
        """Tokenize pair ``idx``; row 0 of each tensor is preferred, row 1 rejected."""
        summ0 = self.summaries_0[idx]
        summ1 = self.summaries_1[idx]
        encodings_dict = self.tokenizer(
            [summ0, summ1],
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
        )
        input_ids = torch.tensor(encodings_dict["input_ids"])
        attention_mask = torch.tensor(encodings_dict["attention_mask"])
        return {"input_ids": input_ids, "attention_mask": attention_mask}
class AllSummDataset(Dataset):
    """Causal-LM fine-tuning dataset built from a parquet file of text/summary pairs.

    Each row (columns ``text`` and ``summary``) becomes the string
    ``"Summarize: <text>. TL;DR: <summary>"``.

    Args:
        train_path: Path to a parquet file readable by ``pd.read_parquet``.
        tokenizer: Callable invoked as ``tokenizer(text, truncation=True,
            max_length=..., padding="max_length")`` that returns a mapping
            with ``"input_ids"`` and ``"attention_mask"``.
        split: Split name. Exactly ``"valid"`` triggers random subsampling.
        max_length: Length every example is truncated/padded to.
        valid_size: Number of rows drawn for the ``"valid"`` split. It is
            capped at the number of available rows.
        seed: ``random_state`` for the validation subsample. A fixed seed
            gives every run (and every distributed process) the same
            validation rows. Pass ``None`` for a fresh sample each time.
    """

    def __init__(self, train_path, tokenizer, split, max_length=1024, valid_size=5000, seed=0):
        df = pd.read_parquet(train_path)
        if split == "valid":
            # An unseeded sample differs between runs and between ranks,
            # making validation metrics incomparable. Requesting more rows
            # than exist would also raise ValueError, so n is capped.
            df = df.sample(n=min(valid_size, len(df)), random_state=seed)
        # Zip the two columns directly instead of iterrows(), which builds a
        # Series per row.
        self.summarizes = [
            f"Summarize: {text}. TL;DR: {summary}"
            for text, summary in zip(df["text"], df["summary"])
        ]
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Never populated by this class; retained so existing attribute access keeps working.
        self.input_ids = []
        self.attn_masks = []

    def __len__(self):
        return len(self.summarizes)

    def __getitem__(self, idx):
        """Tokenize example ``idx``; ``labels`` is the very same tensor as ``input_ids``."""
        txt = self.summarizes[idx]
        encodings_dict = self.tokenizer(txt, truncation=True, max_length=self.max_length, padding="max_length")
        input_ids = torch.tensor(encodings_dict["input_ids"])
        attn_masks = torch.tensor(encodings_dict["attention_mask"])
        return {
            "input_ids": input_ids,
            "attention_mask": attn_masks,
            "labels": input_ids,
        }