# chatlawv1/trlx/examples/summarize_rlhf/sft/train_gptj_summarize.py
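"""Supervised fine-tuning (SFT) of GPT-J-6B on the CarperAI TL;DR summarization
dataset, the supervised stage of the RLHF summarization pipeline.

The script is configured for DeepSpeed (see the deepspeed argument below), so it
is typically launched with something like:

    deepspeed train_gptj_summarize.py
"""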
import random

import evaluate
import numpy as np
import torch
from summarize_dataset import TLDRDataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    default_data_collator,
)
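

# Seed Python, NumPy, and PyTorch (CPU and CUDA) RNGs for reproducible runs.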
def set_seed(seed_val=42):
    random.seed(seed_val)
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)


if __name__ == "__main__":
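    # Training hyperparameters for the supervised stage.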
    output_dir = "gptj-supervised-summarize-checkpoint"
    train_batch_size = 16
    gradient_accumulation_steps = 1
    learning_rate = 1e-5
    eval_batch_size = 1
    eval_steps = 500
    max_input_length = 550
    save_steps = 1000
    num_train_epochs = 5
    set_seed(42)  # seed all RNGs (Python, NumPy, PyTorch), not just random
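
    # Load GPT-J 6B; use_cache=False because gradient checkpointing is enabled
    # below and the two are incompatible.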
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
    model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", use_cache=False)
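    # GPT-J ships without a pad token, so reuse the EOS token for padding on
    # both the tokenizer and the model config.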
    tokenizer.pad_token = tokenizer.eos_token
    model.resize_token_embeddings(len(tokenizer))
    tokenizer.pad_token_id = tokenizer.eos_token_id
    model.config.end_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = model.config.eos_token_id

    # Set up the datasets
    data_path = "CarperAI/openai_summarize_tldr"
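    # TLDRDataset lives in summarize_dataset.py alongside this script; it yields
    # tokenized post/summary examples truncated to max_input_length tokens.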
    train_dataset = TLDRDataset(
        data_path,
        tokenizer,
        "train",
        max_length=max_input_length,
    )
    dev_dataset = TLDRDataset(
        data_path,
        tokenizer,
        "valid",
        max_length=max_input_length,
    )

    # Set up the metric
    rouge = evaluate.load("rouge")

    def compute_metrics(eval_preds):
        labels_ids = eval_preds.label_ids
        pred_ids = eval_preds.predictions
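        # Decode token ids back to text; ROUGE scores strings, not ids.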
        pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
        result = rouge.compute(predictions=pred_str, references=label_str)
        return result

    # Reduce each logit vector to its argmax token id so evaluation does not
    # accumulate full vocabulary-sized logit tensors in memory.
    def preprocess_logits_for_metrics(logits, labels):
        if isinstance(logits, tuple):
            logits = logits[0]
        return logits.argmax(dim=-1)

    # Prepare the trainer and start training
    training_args = TrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="steps",
        eval_accumulation_steps=1,
        learning_rate=learning_rate,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=eval_batch_size,
        gradient_checkpointing=True,
        half_precision_backend="auto",  # expects a string ("auto"/"amp"/"apex"), not a bool
        fp16=True,
        adam_beta1=0.9,
        adam_beta2=0.95,
        gradient_accumulation_steps=gradient_accumulation_steps,
        num_train_epochs=num_train_epochs,
        warmup_steps=100,
        eval_steps=eval_steps,
        save_steps=save_steps,
        load_best_model_at_end=True,
        logging_steps=50,
        deepspeed="./ds_config_gptj.json",
    )
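    # default_data_collator simply stacks the already-tokenized, fixed-length
    # features into tensors; no dynamic padding happens here.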
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=dev_dataset,
        compute_metrics=compute_metrics,
        data_collator=default_data_collator,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )
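    # Train, then write the final model (the best checkpoint, given
    # load_best_model_at_end=True) to output_dir.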
    trainer.train()
    trainer.save_model(output_dir)