# Spaces: Runtime error
# (The line above appears to be a status banner captured from the hosting page; it is not part of the script.)
import inspect
import re

import pandas as pd
import torch
from datasets import Dataset
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
# Load the CSV data
df = pd.read_csv("dishTV_customer_service_with_address_and_rules_accurate_v2.csv")

# Rows missing any of these fields are discarded before labelling.
required_columns = [
    'Agent Utterance',
    'Customer Utterance',
    'Category',
    'Rule Followed',
    'Question Asked',
    'Question Answered',
]
df = df.dropna(subset=required_columns)

# One text per row: the agent turn followed by the customer turn.
df['Conversation'] = df['Agent Utterance'] + " " + df['Customer Utterance']

# Category name -> integer id, numbered in listing order.
category_names = ['Greeting', 'Addressing Issue', 'Feedback', 'Resolution', 'Address']
category_mapping = {name: index for index, name in enumerate(category_names)}

# Replace category names with their ids (names absent from the mapping become NaN).
df['Category'] = df['Category'].map(category_mapping)
# Rule validation: check which scripted steps the agent covered and whether the
# customer answered the questions the agent asked.
def validate_rules(row):
    """Label one conversation row as compliant or non-compliant.

    Args:
        row: Mapping-like object (e.g. a DataFrame row or a dict) with string
            values under 'Agent Utterance' and 'Customer Utterance'.

    Returns:
        A ``(flag, missed)`` tuple. ``flag`` is 1 and ``missed`` is ``[]`` when
        every agent rule and customer answer check passes; otherwise ``flag``
        is 0 and ``missed`` lists the missed agent rules followed by the
        missed customer answers.
    """
    agent_text = row['Agent Utterance'].lower()
    customer_text = row['Customer Utterance'].lower()

    # Short cues ("hi", "yes", "no") are matched as whole words: as substrings
    # they would also fire inside "this", "yesterday" or "know".
    agent_words = set(re.findall(r"[a-z]+", agent_text))
    customer_words = set(re.findall(r"[a-z]+", customer_text))

    asked_address = 'address' in agent_text
    asked_feedback = 'feedback' in agent_text

    # Rule checks for the agent
    missed_rules = []
    if not agent_words & {'hello', 'hi'}:
        missed_rules.append('Greeting')
    if not asked_address:
        missed_rules.append('Address')
    if not asked_feedback:
        missed_rules.append('Feedback')
    # Longer cues stay substring matches so inflected forms ("fixed") still count.
    if 'resolved' not in agent_text and 'fix' not in agent_text:
        missed_rules.append('Resolution')

    # Customer answers are only expected for questions the agent actually raised.
    missed_answers = []
    if asked_address and 'address' not in customer_text:
        missed_answers.append('Customer Address Answer')
    if asked_feedback and not customer_words & {'yes', 'no'}:
        missed_answers.append('Customer Feedback Answer')

    missed = missed_rules + missed_answers
    if missed:
        return 0, missed  # Non-Compliant
    return 1, []  # Compliant
# Run the rule validation on every row; each (flag, missed) result is spread
# over the two new columns.
validation_results = df.apply(lambda row: pd.Series(validate_rules(row)), axis=1)
df[['Compliant', 'Missed Rules/Answers']] = validation_results

# Hold out 20% of the conversations for validation.
conversation_texts = df['Conversation'].tolist()
compliance_labels = df['Compliant'].tolist()
train_texts, val_texts, train_labels, val_labels = train_test_split(
    conversation_texts,
    compliance_labels,
    test_size=0.2,
)

# Load pre-trained BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Tokenize the input data
def tokenize_function(examples):
    """Tokenize *examples* with the module-level tokenizer.

    Output is padded to ``max_length`` and truncated, with a limit of 512.
    This helper is not called anywhere in the visible code.
    """
    encoding_options = {
        "padding": "max_length",
        "truncation": True,
        "max_length": 512,
    }
    return tokenizer(examples, **encoding_options)
# Encode both splits with the same settings (truncate at 512 tokens, pad within the split).
train_encodings, val_encodings = (
    tokenizer(texts, truncation=True, padding=True, max_length=512)
    for texts in (train_texts, val_texts)
)

# Wrap each split's token ids, attention masks and labels in a Dataset object.
train_dataset, val_dataset = (
    Dataset.from_dict({
        'input_ids': encodings['input_ids'],
        'attention_mask': encodings['attention_mask'],
        'labels': labels,
    })
    for encodings, labels in (
        (train_encodings, train_labels),
        (val_encodings, val_labels),
    )
)

# Load pre-trained BERT with a two-way head (Compliant vs Non-Compliant).
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=2,
)
# Define compute_metrics function for evaluation
def compute_metrics(p):
    """Return the accuracy of the predicted classes against the true labels.

    Args:
        p: Pair of ``(predictions, labels)`` as handed over by ``Trainer``;
            ``predictions`` holds one row of class scores per example.

    Returns:
        Dict with a single key, ``'accuracy'``.
    """
    predictions, labels = p
    # Trainer passes NumPy arrays, which torch.argmax rejects; as_tensor
    # accepts both arrays and tensors.
    predicted_classes = torch.as_tensor(predictions).argmax(dim=-1).tolist()
    return {'accuracy': accuracy_score(labels, predicted_classes)}
# Define training arguments for the Trainer.
# transformers renamed `evaluation_strategy` to `eval_strategy`, and newer
# releases reject the old spelling, so use whichever keyword the installed
# version accepts.
_training_arg_names = inspect.signature(TrainingArguments.__init__).parameters
_eval_strategy_key = (
    'eval_strategy' if 'eval_strategy' in _training_arg_names else 'evaluation_strategy'
)
training_args = TrainingArguments(
    output_dir='./results',
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir='./logs',
    **{_eval_strategy_key: 'epoch'},  # evaluate at the end of every epoch
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

# Train the model
trainer.train()

# Evaluate the model
eval_results = trainer.evaluate()
print(f"Evaluation results: {eval_results}")

# Save the trained model and its tokenizer side by side so the directory can
# be reloaded on its own.
model.save_pretrained('./dishTV_bert_model')
tokenizer.save_pretrained('./dishTV_bert_model')
# Testing the model with an example
def predict(text):
    """Classify a single conversation text with the module-level model.

    Args:
        text: The conversation text to classify.

    Returns:
        The index of the highest-scoring class as a Python int
        (the example below treats 1 as "Compliant").
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    # Training may have moved the model to an accelerator; the inputs must be
    # on the same device or the forward pass fails.
    inputs = inputs.to(model.device)
    model.eval()  # disable dropout so predictions are deterministic
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model(**inputs)
    prediction = torch.argmax(outputs.logits, dim=-1)
    return prediction.item()
# Example test: classify one sample message and print the label.
test_text = "Hello! I need help with my DishTV subscription."
prediction = predict(test_text)
if prediction == 1:
    predicted_compliance = "Compliant"
else:
    predicted_compliance = "Non-Compliant"
print(f"Predicted Compliance: {predicted_compliance}")