File size: 2,514 Bytes
0f9b91a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import os
import time
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, TextDataset, DataCollatorForLanguageModeling, Trainer, TrainingArguments
class GptHumorTrainer:
    """Fine-tune a GPT-2 language model whose weights live beside this file.

    The tokenizer is loaded from the "distilgpt2" pretrained name and the
    model from the "SaveState" directory next to this source file. ``train``
    writes the updated weights back to that same directory after every epoch.
    """

    def __init__(self, silent=False) -> None:
        """Load the tokenizer and the previously saved model.

        Args:
            silent: When true, do not print how long loading took.
        """
        start_time = time.perf_counter()
        self.tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
        self.model = GPT2LMHeadModel.from_pretrained(self.local_file_path("SaveState"))
        # The model starts in evaluation mode until train() is called.
        self.model.eval()
        if not silent:
            print(f"Model Loading Took {time.perf_counter()-start_time} Seconds")

    def local_file_path(self, path):
        """Return ``path`` joined onto the directory containing this source file.

        Args:
            path: File or directory name relative to this file's directory.

        Returns:
            The absolute path as a string.
        """
        return os.path.join(os.path.dirname(os.path.abspath(__file__)), path)

    def train(self, train_file, epochs=3) -> None:
        """Train the model on a text file, saving it after every epoch.

        Args:
            train_file: Path of the text file passed to ``TextDataset``.
            epochs: Number of single-epoch training rounds to run. The model
                is written to the "SaveState" directory after each round.
        """
        device = torch.device("cpu")
        self.model.to(device)

        # Prepare the dataset
        train_dataset = TextDataset(
            tokenizer=self.tokenizer,
            file_path=train_file,
            block_size=128,
        )

        # mlm=False selects the collator's non-masked language modeling mode.
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,
        )

        # Loop-invariant: every epoch saves to the same directory.
        save_dir = self.local_file_path("SaveState")

        for epoch in range(epochs):
            training_args = TrainingArguments(
                output_dir=f"./results/epoch_{epoch+1}",  # Separate output directory per epoch
                overwrite_output_dir=True,  # Overwrite the content of the output directory
                # Exactly one pass per loop iteration, so `epochs` is the true
                # total. This was 3, which silently trained 3 * epochs passes.
                num_train_epochs=1,
                per_device_train_batch_size=3,  # Batch size for training
                # NOTE(review): a non-positive save_steps looks intended to
                # suppress step-based checkpoints - confirm against the
                # installed transformers version. The per-epoch save is the
                # explicit save_pretrained() call below.
                save_steps=-1,
                save_total_limit=None,  # No limit on the total amount of checkpoints
                prediction_loss_only=True,  # Focus on the prediction loss only
            )

            trainer = Trainer(
                model=self.model,
                args=training_args,
                data_collator=data_collator,
                train_dataset=train_dataset,
            )

            # Train the model for one epoch
            trainer.train()

            # Save the model after each epoch so progress survives an interruption.
            self.model.save_pretrained(save_dir)
if __name__ == "__main__":
    # Script entry point: load the saved model, then fine-tune it for five
    # epochs on the training text stored next to this file.
    humor_trainer = GptHumorTrainer()
    # Replace "TrainData.txt" with the path to your training file.
    training_data_path = humor_trainer.local_file_path("TrainData.txt")
    humor_trainer.train(training_data_path, epochs=5)
|