# Fine-tune GPT-2 on a Normans Wikipedia text dump and serve the
# fine-tuned model as a simple Gradio chatbot.
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
import torch
import os
import gradio as gr
# Base model: the small "gpt2" checkpoint from the Hugging Face hub.
model_name = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# Directory where training checkpoints and the final model are written.
output_dir = "./normans_fine-tuned"
os.makedirs(output_dir, exist_ok=True)

# TextDataset reads and tokenizes the training file itself, so the
# previous manual file read + tokenizer.encode() pass was dead code and
# has been removed.
# NOTE(review): TextDataset is deprecated in recent transformers
# releases in favor of the `datasets` library — confirm the installed
# version still ships it.
dataset = TextDataset(
    tokenizer=tokenizer,
    file_path="normans_wikipedia.txt",
    block_size=512,  # max tokens per training example
)
# Causal-LM collation: mlm=False disables masked-language-model masking,
# giving plain next-token-prediction labels.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Training configuration: 10 epochs at batch size 1, checkpoint every
# 10k steps (keeping at most 2 checkpoints), log every 100 steps, and
# report to no external experiment trackers.
train_config = dict(
    output_dir=output_dir,
    overwrite_output_dir=True,
    num_train_epochs=10,
    per_device_train_batch_size=1,
    save_steps=10_000,
    save_total_limit=2,
    logging_dir=output_dir,
    logging_steps=100,
    report_to=[],
)
training_args = TrainingArguments(**train_config)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)
# Run fine-tuning. A Ctrl-C aborts training early but deliberately falls
# through to the save below, so partial progress is still persisted.
try:
    trainer.train()
except KeyboardInterrupt:
    print("Training interrupted by user.")
# Persist the (possibly partially) fine-tuned weights and tokenizer files.
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
# Reload the saved model from disk; this is the model used for inference.
fine_tuned_model = GPT2LMHeadModel.from_pretrained(output_dir)
def generate_response(user_input):
    """Generate a chatbot reply from the fine-tuned GPT-2 model.

    Args:
        user_input: Raw text entered by the user.

    Returns:
        The decoded model continuation, prefixed with "Chatbot: ".
    """
    user_input_ids = tokenizer.encode(user_input, return_tensors="pt")
    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        generated_output = fine_tuned_model.generate(
            user_input_ids,
            max_length=100,
            num_beams=5,
            no_repeat_ngram_size=2,
            # Fix: top_k/top_p/temperature are silently ignored unless
            # sampling is enabled — do_sample=True makes them take effect.
            do_sample=True,
            top_k=50,
            top_p=0.90,
            temperature=0.9,
            # GPT-2 defines no pad token; using EOS suppresses the
            # "Setting pad_token_id" warning from generate().
            pad_token_id=tokenizer.eos_token_id,
        )
    chatbot_response = tokenizer.decode(
        generated_output[0], skip_special_tokens=True)
    return "Chatbot: " + chatbot_response
# Minimal Gradio UI: one text box in, one text box out; live=True
# re-runs generation as the user types.
iface = gr.Interface(
    fn=generate_response,
    inputs="text",
    outputs="text",
    live=True,
)
# Fix: removed the stray trailing "|" that made this line a SyntaxError.
iface.launch()