# Fine-tune bert-base-uncased for multi-label classification of DishTV
# customer-service conversations: predicts Rule Followed, Question Asked,
# and Question Answered for each agent/customer exchange.
import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from sklearn.model_selection import train_test_split
from transformers import BertForSequenceClassification, BertTokenizer, Trainer, TrainingArguments
# ---------------------------------------------------------------------------
# Data loading and preparation
# ---------------------------------------------------------------------------
# Load the customer-service transcript CSV into a DataFrame.
df = pd.read_csv("dishTV_customer_service_with_address_and_rules_accurate_v2.csv")

# Sanity-check the schema before relying on specific column names below.
print(df.columns)
print(df.head())

# Build one text field per example by concatenating both sides of the turn.
# NOTE(review): assumes neither utterance column contains NaN — a missing
# value would propagate NaN into 'Conversation'; verify upstream data.
df['Conversation'] = df['Agent Utterance'] + " " + df['Customer Utterance']

# Encode the three binary targets as 0/1. Any value other than 'Yes'/'No'
# becomes NaN under .map(), which would later poison the float label tensor
# silently — so fail fast if any appear.
label_cols = ['Rule Followed', 'Question Asked', 'Question Answered']
for col in label_cols:
    df[col] = df[col].map({'Yes': 1, 'No': 0})
if df[label_cols].isna().any().any():
    raise ValueError(f"Unexpected label values in {label_cols}; expected only 'Yes'/'No'.")
# ---------------------------------------------------------------------------
# Train/validation split and tokenization
# ---------------------------------------------------------------------------
# 80/20 split; labels come out as an (n, 3) numpy array of 0/1 values.
# NOTE(review): no random_state is set, so the split differs run to run —
# pass random_state=<int> if reproducibility matters.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    df['Conversation'].tolist(),
    df[['Rule Followed', 'Question Asked', 'Question Answered']].values,
    test_size=0.2,
)

# WordPiece tokenizer matching the pretrained checkpoint loaded below.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Pad/truncate every conversation to at most 128 tokens.
train_encodings = tokenizer(train_texts, truncation=True, padding=True, max_length=128)
val_encodings = tokenizer(val_texts, truncation=True, padding=True, max_length=128)
# ---------------------------------------------------------------------------
# Dataset construction and model
# ---------------------------------------------------------------------------
def _to_dataset(encodings, labels):
    """Wrap tokenized encodings plus float multi-hot labels in an HF Dataset."""
    return Dataset.from_dict({
        'input_ids': encodings['input_ids'],
        'attention_mask': encodings['attention_mask'],
        # float32 labels are what BCEWithLogitsLoss (multi-label) expects.
        'labels': torch.tensor(labels, dtype=torch.float32),
    })

train_dataset = _to_dataset(train_encodings, train_labels)
val_dataset = _to_dataset(val_encodings, val_labels)

# Three independent sigmoid outputs, one per flag. problem_type makes the
# multi-label BCE loss explicit instead of relying on HF inferring it from
# the label dtype at forward time.
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=3,
    problem_type="multi_label_classification",
)
# ---------------------------------------------------------------------------
# Training configuration
# ---------------------------------------------------------------------------
# save_steps was dropped: it is ignored when save_strategy='epoch'.
training_args = TrainingArguments(
    output_dir='./results',
    eval_strategy='epoch',            # evaluate at the end of each epoch
    save_strategy='epoch',            # checkpoint at the end of each epoch
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=500,
    # 'accuracy' refers to the key produced by compute_metrics (reported as
    # 'eval_accuracy'); used to pick which checkpoint to reload at the end.
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    do_train=True,
    do_eval=True,
)
# ---------------------------------------------------------------------------
# Trainer and evaluation metric
# ---------------------------------------------------------------------------
def _compute_metrics(eval_pred):
    """Subset accuracy for multi-label outputs.

    An example counts as correct only when all three labels match.
    Logits > 0 correspond to sigmoid(logit) > 0.5, i.e. a positive call per
    label. (The previous argmax was wrong here: it is a single-label
    operation that collapses the three independent outputs into one index
    of shape (batch,), which then broadcasts nonsensically against the
    (batch, 3) label matrix.)
    """
    preds = (eval_pred.predictions > 0).astype(np.float32)
    return {'accuracy': float(np.mean(np.all(preds == eval_pred.label_ids, axis=1)))}

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=_compute_metrics,
)

# Fine-tune, then report metrics on the validation split.
trainer.train()
eval_results = trainer.evaluate()
print(f"Evaluation results: {eval_results}")
# ---------------------------------------------------------------------------
# Inference on a new conversation
# ---------------------------------------------------------------------------
new_conversation = ["Hello! How can I assist you today? I just wanted to check the status of my account."]

# Tokenize, then move the tensors to the model's device — after training the
# model may live on GPU, and leaving the inputs on CPU would raise a
# device-mismatch error at forward time.
test_encodings = tokenizer(new_conversation, truncation=True, padding=True, max_length=512, return_tensors='pt')
test_encodings = {k: v.to(model.device) for k, v in test_encodings.items()}

model.eval()  # disable dropout for deterministic inference
with torch.no_grad():
    outputs = model(**test_encodings)
    # Independent per-label probabilities via sigmoid (multi-label setup).
    predictions = torch.sigmoid(outputs.logits).cpu().numpy()

print(f"Predictions (Rule Followed, Question Asked, Question Answered): {predictions}")

# Threshold at 0.5 per label to get hard Yes/No calls.
predictions_rounded = np.round(predictions)
print(f"Predictions (rounded): {predictions_rounded}")