import os
from typing import List

import torch
from datasets import load_dataset
from reward_model.reward_model import GPTRewardModel
from tqdm import tqdm
from transformers import AutoTokenizer

import trlx
from trlx.data.configs import (
    ModelConfig,
    OptimizerConfig,
    SchedulerConfig,
    TokenizerConfig,
    TrainConfig,
    TRLConfig,
)
from trlx.models.modeling_ppo import PPOConfig

# Download the trained reward-model checkpoint if it is not already present locally
REWARD_CHECKPOINT_PATH = "reward_model/rm_checkpoint/pytorch_model.bin"
if not os.path.exists(REWARD_CHECKPOINT_PATH):
    os.makedirs("reward_model/rm_checkpoint", exist_ok=True)
    os.system(
        f"wget -O {REWARD_CHECKPOINT_PATH} \
        https://huggingface.co/CarperAI/openai_summarize_tldr_rm_checkpoint/resolve/main/pytorch_model.bin"
    )
SFT_MODEL_PATH = "CarperAI/openai_summarize_tldr_sft"

config = TRLConfig(
    train=TrainConfig(
        seq_length=550,
        epochs=50,
        total_steps=100000,
        batch_size=4,
        checkpoint_interval=10000,
        eval_interval=200,
        pipeline="PromptPipeline",
        trainer="AcceleratePPOTrainer",
    ),
    model=ModelConfig(
        model_path="CarperAI/openai_summarize_tldr_sft",
        num_layers_unfrozen=8,
    ),
    tokenizer=TokenizerConfig(
        tokenizer_path="gpt2",
        truncation_side="right",
    ),
    optimizer=OptimizerConfig(
        name="adamw",
        kwargs={
            "lr": 5.0e-6,
            "betas": [0.9, 0.999],
            "eps": 1.0e-8,
            "weight_decay": 0.01,
        },
    ),
    scheduler=SchedulerConfig(
        name="cosine_annealing",
        kwargs={
            "T_max": 100000,
            "eta_min": 5.0e-6,
        },
    ),
    method=PPOConfig(
        name="PPOConfig",
        num_rollouts=128,
        chunk_size=16,
        ppo_epochs=4,
        init_kl_coef=0.1,
        target=6,
        horizon=10000,
        gamma=1,
        lam=0.95,
        cliprange=0.2,
        cliprange_value=0.2,
        vf_coef=0.2,
        scale_reward=None,
        ref_mean=None,
        ref_std=None,
        cliprange_reward=10,
        gen_kwargs={
            "max_new_tokens": 50,
        },
    ),
)


if __name__ == "__main__":
    # Load the pre-trained reward model
    rw_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
    rw_tokenizer.pad_token = rw_tokenizer.eos_token
    rw_model = GPTRewardModel(SFT_MODEL_PATH)
    rw_model.load_state_dict(torch.load(REWARD_CHECKPOINT_PATH), strict=False)
    rw_model.half()
    rw_model.eval()
    rw_device = torch.device("cuda:{}".format(1))  # set reward model device
    rw_model.to(rw_device)

    def get_scores(samples: List[str]):
        """Score full post + summary strings with the reward model in small sub-batches."""
        scores_list = []
        batch_size = 2
        for i in range(0, len(samples), batch_size):
            sub_samples = samples[i : i + batch_size]
            sub_samples = ["<|startoftext|>" + chosen + "<|endoftext|>" for chosen in sub_samples]
            encodings_dict = rw_tokenizer(
                sub_samples,
                truncation=True,
                max_length=config.train.seq_length,
                padding="max_length",
                return_tensors="pt",
            )
            input_ids = encodings_dict["input_ids"].to(rw_device)
            attn_masks = encodings_dict["attention_mask"].to(rw_device)
            # The reward model's forward pass expects chosen/rejected pairs, so duplicate
            # the batch and keep only the chosen-end scores.
            input_ids = input_ids.repeat(2, 1)
            attn_masks = attn_masks.repeat(2, 1)
            with torch.no_grad():
                sub_scores = rw_model(input_ids=input_ids, attention_mask=attn_masks)
            scores_list.append(sub_scores["chosen_end_scores"])
        scores = torch.cat(scores_list, dim=0)
        return scores

    def get_prompt_dataset(prompts, max_length):
        """
        Re-tokenize, truncate, and decode each prompt with the policy tokenizer so the
        prompt strings used here match the prompts produced by the trlX pipeline,
        keeping the prompt/summary dictionary lookups consistent.
        """
        formatted_prompts = []
        for i in tqdm(range(len(prompts))):
            tmp = tokenizer.decode(
                tokenizer(
                    prompts[i].split("TL;DR:")[0],
                    truncation=True,
                    max_length=max_length - 5,  # leave headroom so "TL;DR:" doesn't get truncated
                    add_special_tokens=False,
                )["input_ids"],
                skip_special_tokens=True,
            ).strip()
            tmp = tmp + "\nTL;DR:"
            tmp = tokenizer.decode(
                tokenizer(tmp, truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"],
                skip_special_tokens=True,
            ).strip()
            formatted_prompts.append(tmp)
        return formatted_prompts
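
    # reward_fn (below) scores each generated sample and the reference (human-written)
    # summary for the same post, then returns their difference, so rewards are measured
    # relative to the human baseline rather than on the reward model's absolute scale.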
    def reward_fn(samples: List[str], **kwargs):
        original_samples = [text.split("TL;DR:")[0] + "TL;DR: " for text in samples]
        original_samples = [text + post_summary_dict[text.strip()] for text in original_samples]
        original_scores = get_scores(original_samples)
        scores = get_scores(samples)
        norms_scores = scores - original_scores
        return norms_scores

    # Tokenizer for the policy model; prompts are left-padded for generation
    tokenizer = AutoTokenizer.from_pretrained(config.tokenizer.tokenizer_path)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    # Reserve room in the sequence for the generated summary tokens
    max_length_input = config.train.seq_length - config.method.gen_kwargs["max_new_tokens"]

    dataset = load_dataset("CarperAI/openai_summarize_tldr")

    # Store data into prompt and label pairs
    train_set = [(sample["prompt"], sample["label"]) for sample in dataset["train"]]
    val_set = [(sample["prompt"], sample["label"]) for sample in dataset["valid"]]

    # Split the pairs into posts and reference summaries
    train_posts, train_summaries = zip(*train_set)
    val_posts, val_summaries = zip(*val_set)

    # Map each formatted prompt to its reference summary from the dataset
    post_summary_dict = {}
    train_prompts = get_prompt_dataset(train_posts, max_length_input)
    for i in range(len(train_prompts)):
        post_summary_dict[train_prompts[i]] = train_summaries[i]
    val_prompts = get_prompt_dataset(val_posts, max_length_input)
    for i in range(len(val_prompts)):
        post_summary_dict[val_prompts[i]] = val_summaries[i]

    trainer = trlx.train(
        reward_fn=reward_fn,
        prompts=train_prompts,
        eval_prompts=val_prompts[0:1000],  # use only 1,000 validation prompts to keep evaluation fast during training
        config=config,
    )
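
    # Note (assumption, not stated in the script itself): the reward model is pinned to
    # cuda:1 above while the PPO policy is placed by Accelerate on the remaining
    # device(s), so this setup expects a machine with at least two GPUs and is typically
    # started with `accelerate launch` rather than plain `python`.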