import json
import math
import os
import sys
from itertools import islice

import numpy as np
import torch
import tritonclient.grpc as client_util
from datasets import load_dataset
from huggingface_hub import snapshot_download
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from tritonclient.utils import np_to_triton_dtype

import trlx
from trlx.data.default_configs import (
    ModelConfig,
    OptimizerConfig,
    PPOConfig,
    SchedulerConfig,
    TokenizerConfig,
    TrainConfig,
    TRLConfig,
)

# Base PPO configuration (GPT-J policy); CONFIG_NAME below applies size-specific overrides.
default_config = TRLConfig(
    train=TrainConfig(
        seq_length=1024,
        epochs=10000,
        total_steps=10000,
        batch_size=4,
        checkpoint_interval=10000,
        eval_interval=500,
        pipeline="PromptPipeline",
        trainer="AcceleratePPOTrainer",
        checkpoint_dir="checkpoints/ppo_hh",
    ),
    model=ModelConfig(model_path="EleutherAI/gpt-j-6B", num_layers_unfrozen=2),
    tokenizer=TokenizerConfig(tokenizer_path="EleutherAI/gpt-j-6B", truncation_side="left"),
    optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=8e-6, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)),
    scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=8e-6)),
    method=PPOConfig(
        name="PPOConfig",
        num_rollouts=64,
        chunk_size=16,
        ppo_epochs=4,
        init_kl_coef=0.05,
        target=6,
        horizon=10000,
        gamma=1,
        lam=0.95,
        cliprange=0.2,
        cliprange_value=0.2,
        vf_coef=1,
        scale_reward="running",
        ref_mean=None,
        ref_std=None,
        cliprange_reward=10,
        gen_kwargs=dict(
            max_new_tokens=128,
            top_k=0,
            top_p=1.0,
            do_sample=True,
        ),
    ),
)


# Select size-specific overrides via the CONFIG_NAME environment variable.
config_name = os.environ.get("CONFIG_NAME")
if config_name == "125M":
    default_config.train.batch_size = 32
    default_config.train.total_steps = 1500
    default_config.train.checkpoint_dir = "checkpoints/ppo_hh_125M"
    default_config.model.model_path = "Dahoas/pythia-125M-static-sft"
    default_config.tokenizer.tokenizer_path = "EleutherAI/gpt-neox-20b"
    default_config.method.num_rollouts = 128
elif config_name == "1B":
    default_config.train.batch_size = 8
    default_config.train.total_steps = 2500
    default_config.optimizer.kwargs["lr"] = 6e-6
    default_config.scheduler.kwargs["eta_min"] = 6e-6
    default_config.train.checkpoint_dir = "checkpoints/ppo_hh_1B"
    default_config.model.model_path = "Dahoas/pythia-1B-static-sft"
    default_config.tokenizer.tokenizer_path = "EleutherAI/gpt-neox-20b"
    default_config.method.chunk_size = 16
elif config_name == "6B":
    default_config.train.batch_size = 4
    default_config.train.seq_length = 512
    default_config.train.total_steps = 6000
    default_config.train.checkpoint_dir = "checkpoints/ppo_hh_6B"
    default_config.model.model_path = "Dahoas/pythia-6B-static-sft"
    default_config.tokenizer.tokenizer_path = "EleutherAI/gpt-neox-20b"
    default_config.method.chunk_size = 16
elif config_name == "20B":
    default_config.train.seq_length = 512
    default_config.train.batch_size = 1
    default_config.train.total_steps = 8000
    default_config.optimizer.kwargs["lr"] = 1e-6
    default_config.scheduler.kwargs["eta_min"] = 1e-6
    default_config.train.checkpoint_dir = "checkpoints/ppo_hh_20B"
    default_config.model.model_path = "EleutherAI/gpt-neox-20b"
    default_config.tokenizer.tokenizer_path = "EleutherAI/gpt-neox-20b"
    default_config.method.num_rollouts = 16
    default_config.method.chunk_size = 4
    default_config.method.ppo_epochs = 2


def prepare_tensor(name: str, input):
    # Wrap a numpy array as a Triton InferInput tensor.
    t = client_util.InferInput(name, input.shape, np_to_triton_dtype(input.dtype))
    t.set_data_from_numpy(input)
    return t


def create_reward_fn():  # noqa: C901
    reward_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    reward_tokenizer.pad_token = reward_tokenizer.eos_token
    reward_tokenizer.truncation_side = "left"
    triton_host = os.environ.get("TRITON_HOST")

    # If TRITON_HOST is set, score samples with a remote Triton inference server.
    if triton_host:
        triton_url, triton_model = triton_host.split("/")
        client = client_util.InferenceServerClient(url=triton_url, verbose=False)

        def reward_fn(samples, prompts, outputs):
            samples = [s + reward_tokenizer.eos_token for s in samples]
            input = reward_tokenizer(samples, padding=True, max_length=1024)

            mbs = 24
            out = []
            for i in range(math.ceil(len(samples) / mbs)):
                batch_ixs = slice(i * mbs, (i + 1) * mbs)
                input_ids = np.array(input.input_ids[batch_ixs], dtype=np.int32)

                result = client.infer(triton_model, [prepare_tensor("input_ids", input_ids)])
                rewards = result.as_numpy("rewards")
                out.extend(rewards)

            return out

    # Otherwise, load the reward model locally on the last GPU (rank 0 only).
    elif os.environ.get("RANK", "0") == "0":

        class RewardModel(nn.Module):
            def __init__(self, checkpoint_path, eos_token_id):
                super().__init__()
                model = AutoModelForCausalLM.from_pretrained(checkpoint_path)
                self.transformer = model.transformer
                self.v_head = nn.Linear(model.config.n_embd, 1, bias=False)
                self.eos_token_id = eos_token_id

            def forward(self, input_ids):
                states = self.transformer(input_ids)[0]
                rewards = self.v_head(states).squeeze(-1)
                # Read the reward at the position of the first EOS token in each sequence.
                ends = torch.argmax((input_ids == self.eos_token_id).float(), dim=1).view(-1, 1)
                returns = torch.gather(rewards, 1, ends).squeeze(-1)
                return returns

        reward_model = RewardModel("EleutherAI/gpt-j-6B", reward_tokenizer.eos_token_id)
        directory = snapshot_download("Dahoas/gptj-rm-static", revision="676bfd4d")
        for fpath in os.listdir(directory):
            if fpath.endswith(".pt") or fpath.endswith(".bin"):
                checkpoint = os.path.join(directory, fpath)
                break

        reward_model.load_state_dict(torch.load(checkpoint))
        reward_model.eval()
        reward_model.requires_grad_(False)
        reward_device = torch.cuda.device_count() - 1
        reward_model = reward_model.half().to(reward_device)
        reward_batch_size = 48
        delta_reward = True

        def get_reward(samples):
            input = reward_tokenizer(
                samples,
                padding=True,
                truncation=True,
                max_length=reward_tokenizer.max_len_single_sentence,
                return_tensors="pt",
            ).to(reward_device)

            mbs = reward_batch_size
            out = []
            for i in range(math.ceil(len(samples) / mbs)):
                batch_ixs = slice(i * mbs, (i + 1) * mbs)
                input_ids = input.input_ids[batch_ixs]
                rewards = reward_model(input_ids)
                out.extend(rewards)
            return torch.hstack(out)

        def reward_fn(samples, prompts, original_output, **kwargs):
            samples = [s + reward_tokenizer.eos_token for s in samples]
            rewards = get_reward(samples)

            if not delta_reward:
                return rewards

            # Report rewards relative to the dataset's original "chosen" responses.
            original_samples = [p + o + reward_tokenizer.eos_token for p, o in zip(prompts, original_output)]
            original_rewards = get_reward(original_samples)
            return rewards - original_rewards

    # Other ranks return a placeholder; rewards are computed on rank 0.
    else:
        reward_fn = True

    return reward_fn


def main(hparams={}):
    config = TRLConfig.update(default_config, hparams)

    dataset = load_dataset("Dahoas/rm-static")
    prompts = [{"prompt": x["prompt"], "original_output": x["chosen"]} for x in dataset["train"]]
    eval_prompts = [{"prompt": x["prompt"], "original_output": x["chosen"]} for x in islice(dataset["test"], 280)]
    reward_fn = create_reward_fn()

    trlx.train(
        prompts=prompts,
        eval_prompts=eval_prompts,
        reward_fn=reward_fn,
        config=config,
        stop_sequences=["Human:", "human:", "Assistant:", "assistant:"],
    )


if __name__ == "__main__":
    hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1])
    main(hparams)