# Bert / app.py
# Last commit 5724fbf by ketanchaudhary88: "Update app.py" (file size 4.22 kB).
import inspect

import torch
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
# Load multilingual BERT tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")

# Example dataset in Hindi. Each record has a customer_input, an agent_response
# and a label that is either "compliant" or "non-compliant".
dataset = [
    {"customer_input": "मैंने गलत सामान प्राप्त किया है, क्या मुझे रिफंड मिल सकता है?",
     "agent_response": "मुझे खेद है कि आपको परेशानी हो रही है। कृपया अपना ऑर्डर नंबर प्रदान करें ताकि मैं आपकी सहायता कर सकूं।",
     "label": "compliant"},
    {"customer_input": "मेरा ऑर्डर देरी से आ रहा है, मुझे क्या करना चाहिए?",
     "agent_response": "कृपया धैर्य रखें, हम आपकी समस्या को जल्द हल करेंगे।",
     "label": "non-compliant"},
    # Add more examples as needed
]

# Split dataset into training and evaluation sets. A fixed random_state makes the
# split, and therefore the evaluation results, reproducible between runs.
# NOTE: with only the two sample rows above, test_size=0.2 rounds up to one
# evaluation example, which leaves a single training example.
train_data, eval_data = train_test_split(dataset, test_size=0.2, random_state=42)
# Mapping from the dataset's string labels to class ids (1 = compliant).
LABEL2ID = {"non-compliant": 0, "compliant": 1}


def tokenize_function(example):
    """Tokenize one dialogue example as a BERT sentence pair and attach its label.

    Args:
        example: Mapping with ``customer_input``, ``agent_response`` and
            ``label`` (``"compliant"`` or ``"non-compliant"``; surrounding
            whitespace and letter case are ignored).

    Returns:
        The tokenizer encoding of (customer_input, agent_response), padded and
        truncated to 512 tokens, with an extra integer ``label`` key: 1 for
        compliant, 0 for non-compliant.

    Raises:
        ValueError: If ``label`` is not a known label string. Previously such
            labels were silently encoded as non-compliant.
    """
    label_key = str(example['label']).strip().lower()
    if label_key not in LABEL2ID:
        raise ValueError(
            f"Unknown label {example['label']!r}; expected one of {sorted(LABEL2ID)}"
        )
    tokenized_example = tokenizer(
        example['customer_input'],
        example['agent_response'],
        padding='max_length',
        truncation=True,
        max_length=512,
    )
    tokenized_example['label'] = LABEL2ID[label_key]
    return tokenized_example
# Tokenize both splits up front. Each element becomes a tokenizer encoding
# carrying an integer "label" key.
train_data = list(map(tokenize_function, train_data))
eval_data = list(map(tokenize_function, eval_data))
# Dataset class
class DialogueDataset(torch.utils.data.Dataset):
    """Map-style PyTorch dataset over pre-tokenized dialogue pairs.

    Each element of ``data`` is expected to be a mapping holding the tokenizer
    outputs (``input_ids``, ``attention_mask`` and, optionally,
    ``token_type_ids``) plus an integer ``label``. Items are returned as tensors
    under the keys a Hugging Face ``Trainer`` passes straight to the model.
    """

    def __init__(self, data):
        self.data = data
        # Public list of the integer labels, in dataset order.
        self.labels = [item['label'] for item in data]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        features = {
            "input_ids": torch.tensor(item['input_ids']),
            "attention_mask": torch.tensor(item['attention_mask']),
        }
        # For sentence-pair input, token_type_ids mark which tokens come from
        # the first vs. the second text. Dropping them trains the model with
        # all-zero segment ids, while inference that passes the tokenizer
        # output directly supplies the real ones. Forward them whenever present.
        if 'token_type_ids' in item:
            features["token_type_ids"] = torch.tensor(item['token_type_ids'])
        features["labels"] = torch.tensor(item['label'])
        return features
# Create PyTorch datasets
train_dataset = DialogueDataset(train_data)
eval_dataset = DialogueDataset(eval_data)

# Load multilingual BERT model for sequence classification with a two-class head
# (matching the 0/1 labels produced by tokenize_function).
model = BertForSequenceClassification.from_pretrained("bert-base-multilingual-cased", num_labels=2)

# `evaluation_strategy` was deprecated in favour of `eval_strategy` and removed in
# later transformers releases, so passing the old name raises TypeError there.
# Use whichever keyword the installed TrainingArguments accepts.
_eval_strategy_kwarg = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments).parameters
    else "evaluation_strategy"
)

# Training arguments
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=2,
    weight_decay=0.01,
    logging_dir='./logs',
    **{_eval_strategy_kwarg: "epoch"},  # Evaluate every epoch
)

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# Fine-tune the model
trainer.train()

# Evaluate the model
eval_results = trainer.evaluate()
print("Evaluation Results:", eval_results)
def check_compliance(customer_input, agent_response):
    """Classify one (customer_input, agent_response) pair with the fine-tuned model.

    Args:
        customer_input: The customer's message (first text of the BERT pair).
        agent_response: The agent's reply (second text of the BERT pair).

    Returns:
        "Compliant" if the model predicts class 1, otherwise "Non-Compliant".
        This matches the training encoding of 1 = compliant, 0 = non-compliant.
    """
    inputs = tokenizer(
        customer_input,
        agent_response,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    # The Trainer may have moved the model to an accelerator. Tokenizer tensors
    # are created on the CPU, so move them to the model's device to avoid a
    # device-mismatch error in the forward pass.
    inputs = inputs.to(model.device)
    # Disable dropout explicitly rather than relying on an earlier evaluate() call.
    model.eval()
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class = torch.argmax(logits, dim=-1).item()
    return "Compliant" if predicted_class == 1 else "Non-Compliant"
# Smoke-test the fine-tuned classifier on a dialogue pair not seen in training.
test_customer_input, test_agent_response = (
    "मेरे पास अकाउंट एक्सेस नहीं हो रहा है। क्या आप मेरी मदद कर सकते हैं?",
    "मुझे खेद है। कृपया अपना उपयोगकर्ता नाम साझा करें, ताकि मैं आपकी सहायता कर सकूं।",
)
result = check_compliance(
    customer_input=test_customer_input,
    agent_response=test_agent_response,
)
print(result)