# distilgpt2-quotes / generator1.py
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM #AutoModelWithLMHead is deprecated; the causal-LM class is the current equivalent
from transformers import Trainer, TrainingArguments
from transformers import DataCollatorForLanguageModeling
from datasets import load_dataset
#Load the quotes dataset and hold out 10% of it for evaluation
data = load_dataset("json", data_files = "./authors_all_CUT.json")
data = data["train"].train_test_split(test_size = 0.10)
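#Quick sanity check (not in the uploaded script): print the DatasetDict to
#confirm the train/test split sizes before tokenizing.
print(data)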
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
tokenizer.pad_token = tokenizer.eos_token #DistilGPT2 ships without a padding token, so reuse the EOS token for padding
#Leave padding to the data collator; truncate anything longer than the model's context window
def tokenize_datasets(data_set):
    return tokenizer(data_set["text"], padding = False, truncation = True)
BATCH_SIZE = 8
data = data.map(tokenize_datasets, batched = True, batch_size = BATCH_SIZE) #Tokenize all text; batch_size here only controls how many rows map() processes at once, not the training batch
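#Another sanity check (an added sketch, not part of the original script):
#decode the first few token ids of one example back to text to confirm the
#"text" field was tokenized as expected.
print(tokenizer.decode(data["train"][0]["input_ids"][:20]))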
FOLDER_NAME = "./distilgpt2_quotes.TRANS"
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
#model = AutoModelForCausalLM.from_pretrained(FOLDER_NAME) #Uncomment to resume from a previously fine-tuned checkpoint
collator = DataCollatorForLanguageModeling(tokenizer, mlm = False) #mlm = False -> causal LM: the collator pads each batch and copies input_ids to labels
EPOCHS = 5
training_args = TrainingArguments(
    FOLDER_NAME,
    overwrite_output_dir = True,
    num_train_epochs = EPOCHS,
    per_device_train_batch_size = BATCH_SIZE,
    per_device_eval_batch_size = BATCH_SIZE,
    evaluation_strategy = "steps", #Needed for eval_steps to take effect; the default ("no") skips evaluation entirely
    eval_steps = 400,
    save_steps = 800,
)
trainer = Trainer(
    model,
    args = training_args,
    data_collator = collator,
    train_dataset = data["train"],
    eval_dataset = data["test"],
)
trainer.train()
trainer.save_model() #Writes the final model weights and config to FOLDER_NAME (the tokenizer is not saved here)
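#A minimal generation sketch (an addition, not part of the uploaded script):
#reload the fine-tuned weights and sample a quote. The in-scope tokenizer is
#reused because the script above saves only the model, not the tokenizer.
from transformers import pipeline

generator = pipeline("text-generation", model = FOLDER_NAME, tokenizer = tokenizer)
print(generator("Life is", max_new_tokens = 30, do_sample = True)[0]["generated_text"])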