# Spaces: Runtime error
# (Page-status banner captured along with this script; it is not part of the code.)
import random | |
import evaluate | |
import numpy as np | |
import torch | |
from summarize_dataset import TLDRDataset | |
from transformers import ( | |
AutoModelForCausalLM, | |
AutoTokenizer, | |
Trainer, | |
TrainingArguments, | |
default_data_collator, | |
) | |
def set_seed(seed_val=42):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator, PyTorch's
    default generator, and PyTorch's generators for all CUDA devices.

    Args:
        seed_val: Integer seed handed to each generator. Defaults to 42.
    """
    random.seed(seed_val)
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    # Seeds all GPUs, not only the current device; harmless on CPU-only hosts.
    torch.cuda.manual_seed_all(seed_val)
def replace_ignore_index(token_ids, pad_token_id, ignore_index=-100):
    """Return ``token_ids`` as an array with ignored positions made decodable.

    The Hugging Face ``Trainer`` uses -100 as the padding/ignore index when it
    gathers predictions and labels across batches; a tokenizer cannot decode
    that value, so it is swapped for ``pad_token_id`` before decoding.

    Args:
        token_ids: Array-like of integer token ids.
        pad_token_id: Id written wherever ``ignore_index`` is found.
        ignore_index: Sentinel value to replace. Defaults to -100.

    Returns:
        A NumPy array of the same shape as ``token_ids``.
    """
    token_ids = np.asarray(token_ids)
    return np.where(token_ids == ignore_index, pad_token_id, token_ids)


def preprocess_logits_for_metrics(logits, labels):
    """Reduce raw model output to predicted token ids before metrics run.

    Keeping only the argmax (instead of the full logits) is what gets
    accumulated for ``compute_metrics``.

    Args:
        logits: Tensor of scores, or a tuple whose first element is that tensor.
        labels: Unused; accepted because the Trainer passes it to this hook.

    Returns:
        Tensor holding the argmax index over the last dimension of ``logits``.
    """
    if isinstance(logits, tuple):
        logits = logits[0]
    return logits.argmax(dim=-1)


if __name__ == "__main__":
    # Run configuration.
    output_dir = "gptj-supervised-summarize-checkpoint"
    train_batch_size = 16
    gradient_accumulation_steps = 1
    learning_rate = 1e-5
    eval_batch_size = 1
    eval_steps = 500
    max_input_length = 550
    save_steps = 1000
    num_train_epochs = 5

    # Seed Python, NumPy and PyTorch together (previously only ``random`` was
    # seeded and the file's own set_seed helper was never called).
    set_seed(42)

    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
    model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", use_cache=False)

    # The tokenizer gets no dedicated pad token here, so EOS doubles as padding.
    tokenizer.pad_token = tokenizer.eos_token
    model.resize_token_embeddings(len(tokenizer))
    tokenizer.pad_token_id = tokenizer.eos_token_id
    # NOTE(review): ``end_token_id`` is kept exactly as in the original; it may
    # have been meant as ``eos_token_id`` -- confirm whether anything reads it.
    model.config.end_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = model.config.eos_token_id

    # Set up the datasets
    data_path = "CarperAI/openai_summarize_tldr"
    train_dataset = TLDRDataset(
        data_path,
        tokenizer,
        "train",
        max_length=max_input_length,
    )
    dev_dataset = TLDRDataset(
        data_path,
        tokenizer,
        "valid",
        max_length=max_input_length,
    )

    # Set up the metric
    rouge = evaluate.load("rouge")

    def compute_metrics(eval_preds):
        """Decode predicted and reference token ids and score them with ROUGE.

        Args:
            eval_preds: Object exposing ``predictions`` and ``label_ids``.

        Returns:
            Whatever ``rouge.compute`` returns for the decoded strings.
        """
        # -100 entries cannot be decoded; map them to the pad token, which
        # ``skip_special_tokens=True`` then drops from the text.
        pred_ids = replace_ignore_index(eval_preds.predictions, tokenizer.pad_token_id)
        labels_ids = replace_ignore_index(eval_preds.label_ids, tokenizer.pad_token_id)
        pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
        return rouge.compute(predictions=pred_str, references=label_str)

    # Prepare the trainer and start training
    training_args = TrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="steps",
        eval_accumulation_steps=1,
        learning_rate=learning_rate,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=eval_batch_size,
        gradient_checkpointing=True,
        # Documented as a string choice; the original passed ``True``.
        half_precision_backend="auto",
        fp16=True,
        adam_beta1=0.9,
        adam_beta2=0.95,
        gradient_accumulation_steps=gradient_accumulation_steps,
        num_train_epochs=num_train_epochs,
        warmup_steps=100,
        eval_steps=eval_steps,
        save_steps=save_steps,
        load_best_model_at_end=True,
        logging_steps=50,
        deepspeed="./ds_config_gptj.json",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=dev_dataset,
        compute_metrics=compute_metrics,
        data_collator=default_data_collator,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )
    trainer.train()
    trainer.save_model(output_dir)