Bert / app.py
ketanchaudhary88's picture
Update app.py
d64149e verified
raw
history blame
5.13 kB
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from datasets import Dataset
import torch
from sklearn.metrics import accuracy_score
# Read the customer-service transcript export into a DataFrame.
df = pd.read_csv("dishTV_customer_service_with_address_and_rules_accurate_v2.csv")

# Discard every row that has a missing value in any of the listed columns.
df = df.dropna(
    subset=[
        'Agent Utterance',
        'Customer Utterance',
        'Category',
        'Rule Followed',
        'Question Asked',
        'Question Answered',
    ]
)

# One text field per row: the agent turn, a single space, then the customer turn.
df['Conversation'] = (df['Agent Utterance'] + " ") + df['Customer Utterance']

# Integer id for each known category name, numbered in listing order from 0.
category_mapping = {
    label: index
    for index, label in enumerate(
        ('Greeting', 'Addressing Issue', 'Feedback', 'Resolution', 'Address')
    )
}

# Swap the category names for their integer ids.
df['Category'] = df['Category'].map(category_mapping)
# Rule validation: check whether the agent followed each rule and whether the
# customer answered the questions the agent asked.
def _utterance_words(text):
    """Return the set of lowercase alphanumeric words contained in *text*."""
    # Turn every non-alphanumeric character into a space so punctuation glued
    # to a word ("Hi," / "no.") does not prevent a whole-word match.
    normalized = ''.join(ch if ch.isalnum() else ' ' for ch in text.lower())
    return set(normalized.split())


def validate_rules(row):
    """Check one conversation row against the keyword-based compliance rules.

    Args:
        row: Mapping-like object (for example a DataFrame row) whose
            'Agent Utterance' and 'Customer Utterance' entries are strings.

    Returns:
        A tuple ``(compliant, missed)``. ``compliant`` is 1 when nothing was
        missed and 0 otherwise; ``missed`` lists the missed agent rules
        followed by the missed customer answers (empty when compliant).
    """
    agent_text = row['Agent Utterance'].lower()
    customer_text = row['Customer Utterance'].lower()
    agent_words = _utterance_words(agent_text)
    customer_words = _utterance_words(customer_text)

    missed_rules = []
    missed_answers = []

    # Rule checks for the agent.
    # The short greeting keywords are matched as whole words: a substring test
    # would accept "hi" inside unrelated words such as "this" or "which".
    if not agent_words & {'hello', 'hi'}:
        missed_rules.append('Greeting')
    # The longer keywords stay substring matches so inflected forms such as
    # "fixed" or "addresses" still count.
    if 'address' not in agent_text:
        missed_rules.append('Address')
    if 'feedback' not in agent_text:
        missed_rules.append('Feedback')
    if 'resolved' not in agent_text and 'fix' not in agent_text:
        missed_rules.append('Resolution')

    # Check whether the customer answered the questions the agent raised.
    if 'address' in agent_text and 'address' not in customer_text:
        missed_answers.append('Customer Address Answer')
    # "yes"/"no" must be whole words: as substrings, "no" would be found in
    # "not", "know" or "now" and hide a missing answer.
    if 'feedback' in agent_text and not customer_words & {'yes', 'no'}:
        missed_answers.append('Customer Feedback Answer')

    # Compliant only when no rule and no answer was missed.
    if not missed_rules and not missed_answers:
        return 1, []  # Compliant
    return 0, missed_rules + missed_answers  # Non-Compliant
# Run the rule checker on every row; its two return values become two new columns.
df[['Compliant', 'Missed Rules/Answers']] = df.apply(
    lambda record: pd.Series(validate_rules(record)),
    axis=1,
)

# Hold out 20% of the conversations for validation; the compliance flag is the target.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    df['Conversation'].tolist(),
    df['Compliant'].tolist(),
    test_size=0.2,
)

# Tokenizer belonging to the pre-trained 'bert-base-uncased' checkpoint.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Tokenization helper built on the module-level tokenizer.
def tokenize_function(examples):
    """Tokenize *examples* and return the tokenizer's output.

    The tokenizer is called with ``padding="max_length"``, truncation
    enabled and ``max_length=512``.
    """
    encoded = tokenizer(
        examples,
        max_length=512,
        truncation=True,
        padding="max_length",
    )
    return encoded
# Encode both splits with truncation and padding enabled, capped at 512 tokens.
train_encodings = tokenizer(
    train_texts,
    max_length=512,
    padding=True,
    truncation=True,
)
val_encodings = tokenizer(
    val_texts,
    max_length=512,
    padding=True,
    truncation=True,
)

# Wrap token ids, attention masks and labels in Dataset objects.
train_dataset = Dataset.from_dict(
    dict(
        input_ids=train_encodings['input_ids'],
        attention_mask=train_encodings['attention_mask'],
        labels=train_labels,
    )
)
val_dataset = Dataset.from_dict(
    dict(
        input_ids=val_encodings['input_ids'],
        attention_mask=val_encodings['attention_mask'],
        labels=val_labels,
    )
)

# Pre-trained BERT with a two-class head (Compliant vs Non-Compliant).
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=2,
)
# Define compute_metrics function for evaluation
def compute_metrics(p):
    """Compute classification accuracy for one evaluation pass.

    Args:
        p: Two-item iterable ``(predictions, labels)``. ``predictions`` holds
            one score per class along its last axis; ``labels`` holds the
            true class ids.

    Returns:
        A dict with the single key ``'accuracy'``.
    """
    predictions, labels = p
    # torch.argmax only accepts tensors, but the Trainer supplies NumPy
    # arrays here; torch.as_tensor accepts arrays and tensors alike.
    predicted_classes = torch.as_tensor(predictions).argmax(dim=-1)
    # Hand scikit-learn a plain CPU array.
    return {'accuracy': accuracy_score(labels, predicted_classes.cpu().numpy())}
# Hyper-parameters and output locations for the fine-tuning run.
training_args = TrainingArguments(
    # Where results and logs are written.
    output_dir='./results',
    logging_dir='./logs',
    # Schedule: three epochs, evaluating after each one.
    num_train_epochs=3,
    evaluation_strategy='epoch',
    # Batch sizes per device.
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # Optimizer settings.
    learning_rate=2e-5,
    weight_decay=0.01,
)

# Bundle the model, both datasets and the metric function into a Trainer.
trainer = Trainer(
    args=training_args,
    model=model,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

# Fine-tune the model.
trainer.train()

# Score the model on the validation split and report the metrics.
eval_results = trainer.evaluate()
print(f"Evaluation results: {eval_results}")

# Store the model and its tokenizer in the same directory.
model.save_pretrained('./dishTV_bert_model')
tokenizer.save_pretrained('./dishTV_bert_model')
# Testing the model with an example
def predict(text):
    """Classify a single text with the module-level model.

    Args:
        text: The conversation text to classify.

    Returns:
        The predicted class id as a Python int (the training labels use
        1 for compliant and 0 for non-compliant).
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    # The Trainer may have moved the model to an accelerator, while the
    # tokenizer output is created on the CPU; keep both on the same device.
    inputs = inputs.to(model.device)
    # Inference only: switch off dropout and skip gradient bookkeeping.
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
    prediction = torch.argmax(outputs.logits, dim=-1)
    return prediction.item()
# Smoke test: classify one sample sentence and print the verdict.
test_text = "Hello! I need help with my DishTV subscription."
prediction = predict(test_text)
# Class id 1 is reported as compliant; any other id as non-compliant.
if prediction == 1:
    predicted_compliance = "Compliant"
else:
    predicted_compliance = "Non-Compliant"
print(f"Predicted Compliance: {predicted_compliance}")