import json
import math
import os
import sys
from itertools import islice
import numpy as np
import torch
import tritonclient.grpc as client_util
from datasets import load_dataset
from huggingface_hub import snapshot_download
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from tritonclient.utils import np_to_triton_dtype
import trlx
from trlx.data.default_configs import (
ModelConfig,
OptimizerConfig,
PPOConfig,
SchedulerConfig,
TokenizerConfig,
TrainConfig,
TRLConfig,
)
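# Default PPO configuration: fine-tune EleutherAI/gpt-j-6B with trlx's
# AcceleratePPOTrainer, unfreezing only the top two transformer layers.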
default_config = TRLConfig(
train=TrainConfig(
seq_length=1024,
epochs=10000,
total_steps=10000,
batch_size=4,
checkpoint_interval=10000,
eval_interval=500,
pipeline="PromptPipeline",
trainer="AcceleratePPOTrainer",
checkpoint_dir="checkpoints/ppo_hh",
),
model=ModelConfig(model_path="EleutherAI/gpt-j-6B", num_layers_unfrozen=2),
tokenizer=TokenizerConfig(tokenizer_path="EleutherAI/gpt-j-6B", truncation_side="left"),
optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=8e-6, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)),
scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=8e-6)),
method=PPOConfig(
name="PPOConfig",
num_rollouts=64,
chunk_size=16,
ppo_epochs=4,
init_kl_coef=0.05,
target=6,
horizon=10000,
gamma=1,
lam=0.95,
cliprange=0.2,
cliprange_value=0.2,
vf_coef=1,
scale_reward="running",
ref_mean=None,
ref_std=None,
cliprange_reward=10,
gen_kwargs=dict(
max_new_tokens=128,
top_k=0,
top_p=1.0,
do_sample=True,
),
),
)
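# The CONFIG_NAME environment variable selects per-size overrides of the base
# policy checkpoint, batch size, learning-rate schedule, and rollout settings.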
config_name = os.environ.get("CONFIG_NAME")
if config_name == "125M":
default_config.train.batch_size = 32
default_config.train.total_steps = 1500
default_config.train.checkpoint_dir = "checkpoints/ppo_hh_125M"
default_config.model.model_path = "Dahoas/pythia-125M-static-sft"
default_config.tokenizer.tokenizer_path = "EleutherAI/gpt-neox-20b"
default_config.method.num_rollouts = 128
elif config_name == "1B":
default_config.train.batch_size = 8
default_config.train.total_steps = 2500
default_config.optimizer.kwargs["lr"] = 6e-6
default_config.scheduler.kwargs["eta_min"] = 6e-6
default_config.train.checkpoint_dir = "checkpoints/ppo_hh_1B"
default_config.model.model_path = "Dahoas/pythia-1B-static-sft"
default_config.tokenizer.tokenizer_path = "EleutherAI/gpt-neox-20b"
default_config.method.chunk_size = 16
elif config_name == "6B":
default_config.train.batch_size = 4
default_config.train.seq_length = 512
default_config.train.total_steps = 6000
default_config.train.checkpoint_dir = "checkpoints/ppo_hh_6B"
default_config.model.model_path = "Dahoas/pythia-6B-static-sft"
default_config.tokenizer.tokenizer_path = "EleutherAI/gpt-neox-20b"
default_config.method.chunk_size = 16
elif config_name == "20B":
default_config.train.seq_length = 512
default_config.train.batch_size = 1
default_config.train.total_steps = 8000
default_config.optimizer.kwargs["lr"] = 1e-6
default_config.scheduler.kwargs["eta_min"] = 1e-6
default_config.train.checkpoint_dir = "checkpoints/ppo_hh_20B"
default_config.model.model_path = "EleutherAI/gpt-neox-20b"
default_config.tokenizer.tokenizer_path = "EleutherAI/gpt-neox-20b"
default_config.method.num_rollouts = 16
default_config.method.chunk_size = 4
default_config.method.ppo_epochs = 2
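# Wrap a numpy array as a Triton InferInput so it can be sent to the inference server.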
def prepare_tensor(name: str, input):
t = client_util.InferInput(name, input.shape, np_to_triton_dtype(input.dtype))
t.set_data_from_numpy(input)
return t
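# Build the reward function. If TRITON_HOST ("<server address>/<model name>") is set,
# rewards are requested from a remote Triton inference server; otherwise, on rank 0
# only, the Dahoas/gptj-rm-static reward model is loaded locally. Other ranks get a
# truthy placeholder, since only rank 0 scores samples.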
def create_reward_fn(): # noqa: C901
reward_tokenizer = AutoTokenizer.from_pretrained("gpt2")
reward_tokenizer.pad_token = reward_tokenizer.eos_token
reward_tokenizer.truncation_side = "left"
triton_host = os.environ.get("TRITON_HOST")
if triton_host:
triton_url, triton_model = triton_host.split("/")
client = client_util.InferenceServerClient(url=triton_url, verbose=False)
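        # Remote path: tokenize each sample (with a trailing EOS token) and query the
        # Triton model in micro-batches of 24, reading one scalar reward per sample.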
def reward_fn(samples, prompts, outputs):
samples = [s + reward_tokenizer.eos_token for s in samples]
            # Enable truncation so max_length takes effect; truncation_side="left" keeps the end of the response.
            input = reward_tokenizer(samples, padding=True, truncation=True, max_length=1024)
mbs = 24
out = []
for i in range(math.ceil(len(samples) / mbs)):
batch_ixs = slice(i * mbs, (i + 1) * mbs)
input_ids = np.array(input.input_ids[batch_ixs], dtype=np.int32)
result = client.infer(triton_model, [prepare_tensor("input_ids", input_ids)])
rewards = result.as_numpy("rewards")
out.extend(rewards)
return out
elif os.environ.get("RANK", "0") == "0":
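        # Local path: a GPT-J style reward model, i.e. the pretrained transformer
        # backbone plus a scalar value head.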
class RewardModel(nn.Module):
def __init__(self, checkpoint_path, eos_token_id):
super().__init__()
model = AutoModelForCausalLM.from_pretrained(checkpoint_path)
self.transformer = model.transformer
self.v_head = nn.Linear(model.config.n_embd, 1, bias=False)
self.eos_token_id = eos_token_id
def forward(self, input_ids):
states = self.transformer(input_ids)[0]
rewards = self.v_head(states).squeeze(-1)
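                # argmax over the boolean EOS mask gives the index of the first EOS token
                # in each row; gather reads the value-head output at that position as the
                # sequence-level reward.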
ends = torch.argmax((input_ids == self.eos_token_id).float(), dim=1).view(-1, 1)
returns = torch.gather(rewards, 1, ends).squeeze(-1)
return returns
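        # Start from the GPT-J base weights, then load the trained reward-model
        # weights from the Dahoas/gptj-rm-static snapshot (first .pt/.bin file found).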
reward_model = RewardModel("EleutherAI/gpt-j-6B", reward_tokenizer.eos_token_id)
directory = snapshot_download("Dahoas/gptj-rm-static", revision="676bfd4d")
for fpath in os.listdir(directory):
if fpath.endswith(".pt") or fpath.endswith(".bin"):
checkpoint = os.path.join(directory, fpath)
break
reward_model.load_state_dict(torch.load(checkpoint))
reward_model.eval()
reward_model.requires_grad_(False)
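        # Run the frozen reward model in fp16 on the last visible GPU, keeping it off
        # the devices used by the policy being trained.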
reward_device = torch.cuda.device_count() - 1
reward_model = reward_model.half().to(reward_device)
reward_batch_size = 48
delta_reward = True
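        # Score a list of strings with the local reward model, reward_batch_size
        # samples at a time, and return a single tensor of scalar rewards.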
def get_reward(samples):
input = reward_tokenizer(
samples,
padding=True,
truncation=True,
max_length=reward_tokenizer.max_len_single_sentence,
return_tensors="pt",
).to(reward_device)
mbs = reward_batch_size
out = []
for i in range(math.ceil(len(samples) / mbs)):
batch_ixs = slice(i * mbs, (i + 1) * mbs)
input_ids = input.input_ids[batch_ixs]
rewards = reward_model(input_ids)
out.extend(rewards)
return torch.hstack(out)
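        # With delta_reward enabled, the PPO reward is the score of the generated sample
        # minus the score of the dataset's "chosen" continuation for the same prompt.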
def reward_fn(samples, prompts, original_output, **kwargs):
samples = [s + reward_tokenizer.eos_token for s in samples]
rewards = get_reward(samples)
if not delta_reward:
return rewards
original_samples = [p + o + reward_tokenizer.eos_token for p, o in zip(prompts, original_output)]
original_rewards = get_reward(original_samples)
return rewards - original_rewards
else:
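        # Non-zero ranks never score samples; a truthy placeholder tells trlx that a
        # reward function is available.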
reward_fn = True
return reward_fn
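# Training entry point: prompts come from the train split of Dahoas/rm-static, each
# paired with the dataset's "chosen" response for delta-reward scoring; the first 280
# test prompts are used for evaluation.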
def main(hparams={}):
config = TRLConfig.update(default_config, hparams)
dataset = load_dataset("Dahoas/rm-static")
prompts = [{"prompt": x["prompt"], "original_output": x["chosen"]} for x in dataset["train"]]
eval_prompts = [{"prompt": x["prompt"], "original_output": x["chosen"]} for x in islice(dataset["test"], 280)]
reward_fn = create_reward_fn()
trlx.train(
prompts=prompts,
eval_prompts=eval_prompts,
reward_fn=reward_fn,
config=config,
stop_sequences=["Human:", "human:", "Assistant:", "assistant:"],
)
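# Optional hyperparameter overrides can be passed on the command line as a JSON string,
# which is merged into the default config via TRLConfig.update.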
if __name__ == "__main__":
hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1])
main(hparams)