# Hugging Face Space status at capture time: Runtime error
import pandas as pd | |
from datasets import load_dataset | |
from sklearn.utils import resample | |
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, DataCollatorForSeq2Seq | |
from torch.utils.data import Dataset | |
import gradio as gr | |
# Step 1: Pull the customer-support dataset from the Hugging Face Hub.
dataset = load_dataset("bitext/Bitext-customer-support-llm-chatbot-training-dataset")

# Step 2: Shuffle the train split with a fixed seed and keep the first 20% of it.
train_split = dataset["train"]
subset_size = int(len(train_split) * 0.2)
sampled_data = train_split.shuffle(seed=42).select(range(subset_size))

# Materialise the subset as a DataFrame and keep only the prompt/reply text columns.
sampled_data_df = pd.DataFrame(sampled_data)
df_limited = sampled_data_df[['instruction', 'response']]
# Step 3: Handle class imbalance using oversampling.
#
# The class label is the dataset's `intent` column, not the free-text `response`:
# responses are (nearly) unique strings, so treating the most common response as
# the "majority class" and resampling everything else to that size would shrink
# the training set to a handful of rows instead of balancing it.
label_column = 'intent'

if label_column in sampled_data_df.columns and not sampled_data_df.empty:
    class_labels = sampled_data_df[label_column]
    class_counts = class_labels.value_counts()
    majority_label = class_counts.idxmax()
    majority_size = int(class_counts.max())

    # df_limited shares its index with sampled_data_df, so the boolean masks align.
    df_majority = df_limited[class_labels == majority_label]
    df_minority = df_limited[class_labels != majority_label]

    # Keep every original minority row and add sampled-with-replacement copies
    # until each smaller class reaches the size of the largest one.
    extra_rows = [
        resample(
            group,
            replace=True,
            n_samples=majority_size - len(group),
            random_state=42,
        )
        for _, group in df_minority.groupby(class_labels.loc[df_minority.index])
        if len(group) < majority_size
    ]
    df_minority_upsampled = pd.concat([df_minority, *extra_rows])
else:
    # No class label available (or no rows): there is nothing to balance.
    df_majority = df_limited
    df_minority = df_limited.iloc[0:0]
    df_minority_upsampled = df_minority

df_balanced = pd.concat([df_majority, df_minority_upsampled])
# Step 4: Load the pre-trained DialoGPT tokenizer and model from the same checkpoint.
model_name = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Padding is required for batching; reuse the EOS token when no pad token is defined.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Step 5: Preprocess the data for training | |
def preprocess_data_for_training(df, max_length=512):
    """Tokenize instruction/response pairs for causal-LM fine-tuning.

    Each row is encoded as a single sequence ``instruction<eos>response<eos>``
    so the model learns to continue a prompt with its reply. Labels are a copy
    of ``input_ids``; they are NOT shifted here because the causal-LM head
    shifts labels internally when computing the loss.

    Args:
        df: DataFrame with text columns ``instruction`` and ``response``.
        max_length: Maximum number of tokens per encoded sequence; longer
            sequences are truncated.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` and ``labels`` tensors,
        all of the same shape.

    Raises:
        ValueError: If ``df`` contains no rows.
    """
    if len(df) == 0:
        raise ValueError("preprocess_data_for_training received an empty DataFrame")

    eos = tokenizer.eos_token
    dialogues = [
        f"{instruction}{eos}{response}{eos}"
        for instruction, response in zip(df['instruction'], df['response'])
    ]
    encoded = tokenizer(
        dialogues,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    input_ids = encoded['input_ids']
    attention_mask = encoded['attention_mask']

    # Mask padding via the attention mask rather than pad_token_id: the pad
    # token may be the EOS token, and real EOS separators must stay in the loss.
    labels = input_ids.clone()
    labels[attention_mask == 0] = -100  # -100 is ignored by the loss

    return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels}


preprocessed_data = preprocess_data_for_training(df_balanced)
# Step 6: Create a custom dataset class for fine-tuning | |
class ChatbotDataset(Dataset):
    """Map-style dataset over pre-tokenized chatbot training examples.

    Args:
        inputs: Mapping with ``'input_ids'`` and ``'attention_mask'`` entries,
            each indexable by example position.
        targets: Mapping holding the training labels. The ``'labels'`` entry is
            used when present; otherwise ``'input_ids'`` is used as the labels.

    Raises:
        ValueError: If the number of label rows differs from the number of
            input rows.
    """

    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets
        # Prefer explicit labels so precomputed (e.g. padding-masked) labels
        # are not silently replaced by the raw token ids.
        label_key = 'labels' if 'labels' in targets else 'input_ids'
        self._labels = targets[label_key]
        if len(self._labels) != len(inputs['input_ids']):
            raise ValueError(
                "inputs and targets must contain the same number of examples"
            )

    def __len__(self):
        """Return the number of examples."""
        return len(self.inputs['input_ids'])

    def __getitem__(self, idx):
        """Return the model inputs and labels for the example at ``idx``."""
        return {
            'input_ids': self.inputs['input_ids'][idx],
            'attention_mask': self.inputs['attention_mask'][idx],
            'labels': self._labels[idx],
        }
# The same pre-tokenized mapping supplies both the model inputs and the labels.
train_dataset = ChatbotDataset(preprocessed_data, preprocessed_data)

# Step 7: Configure the fine-tuning run.
training_args = TrainingArguments(
    # Schedule
    num_train_epochs=3,
    per_device_train_batch_size=4,
    # Checkpointing
    output_dir='./results',
    save_steps=10_000,
    save_total_limit=2,
    # Logging
    logging_dir='./logs',
    logging_steps=500,
)
# Step 8: Build the batch collator and the Trainer that drives fine-tuning.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
trainer = Trainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    data_collator=data_collator,
)

# Step 9: Run fine-tuning.
trainer.train()

# Persist the fine-tuned weights and the tokenizer side by side.
for artifact in (model, tokenizer):
    artifact.save_pretrained("./trained_model")
# Optional: Test the chatbot after training | |
def generate_response(input_text):
    """Generate the chatbot's reply to a single user message.

    Args:
        input_text: The user's message. ``None``, empty, or whitespace-only
            input yields an empty reply without calling the model.

    Returns:
        The generated reply text only; the prompt is not echoed back.
    """
    # The live UI can submit an empty box; generating from zero tokens fails.
    if input_text is None or not input_text.strip():
        return ""

    # DialoGPT separates dialogue turns with the EOS token. Move the tensors to
    # the model's device, since training may have placed the model on a GPU.
    encoded = tokenizer(
        input_text + tokenizer.eos_token, return_tensors="pt"
    ).to(model.device)
    prompt_length = encoded['input_ids'].shape[-1]

    outputs = model.generate(
        encoded['input_ids'],
        attention_mask=encoded['attention_mask'],
        # Budget the reply itself; max_length would also count the prompt tokens.
        max_new_tokens=50,
        pad_token_id=tokenizer.eos_token_id,
    )

    # generate() returns prompt + continuation; keep only the continuation.
    reply_ids = outputs[0][prompt_length:]
    return tokenizer.decode(reply_ids, skip_special_tokens=True)
# Gradio Interface | |
def chatbot_interface(input_text):
    """Gradio callback: pass the user's text to the chatbot and return its reply."""
    reply = generate_response(input_text)
    return reply


# Plain text-in / text-out UI; live=True re-runs the callback whenever the input changes.
iface = gr.Interface(
    fn=chatbot_interface,
    inputs="text",
    outputs="text",
    live=True,
)
iface.launch()