# Spaces:
# Runtime error
# Runtime error
import inspect

import torch
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
# Cased multilingual BERT tokenizer, used both for building the training
# features and for encoding new dialogues at inference time.
tokenizer = BertTokenizer.from_pretrained(
    pretrained_model_name_or_path="bert-base-multilingual-cased",
)
# Toy Hindi dataset. Each record pairs a customer message with an agent reply
# and a string label, either "compliant" or "non-compliant".
dataset = [
    dict(
        customer_input="मैंने गलत सामान प्राप्त किया है, क्या मुझे रिफंड मिल सकता है?",
        agent_response="मुझे खेद है कि आपको परेशानी हो रही है। कृपया अपना ऑर्डर नंबर प्रदान करें ताकि मैं आपकी सहायता कर सकूं।",
        label="compliant",
    ),
    dict(
        customer_input="मेरा ऑर्डर देरी से आ रहा है, मुझे क्या करना चाहिए?",
        agent_response="कृपया धैर्य रखें, हम आपकी समस्या को जल्द हल करेंगे।",
        label="non-compliant",
    ),
    # Add more examples as needed
]
# Split the dataset, holding out 20% of the examples for evaluation.
# A fixed random_state makes the split, and therefore the whole training run,
# reproducible from one execution to the next.
train_data, eval_data = train_test_split(dataset, test_size=0.2, random_state=42)
# Mapping from the dataset's string labels to integer class ids.
# Class 1 is the one reported as "Compliant" at inference time.
LABEL2ID = {"non-compliant": 0, "compliant": 1}


def tokenize_function(example):
    """Tokenize one customer/agent exchange and attach its integer label.

    The customer input and agent response are encoded as a BERT sentence pair,
    padded and truncated to exactly 512 tokens.

    Args:
        example: mapping with "customer_input", "agent_response" and "label" keys.

    Returns:
        The tokenizer's encoding of the pair, with an added integer "label" entry
        (1 for "compliant", 0 for "non-compliant").

    Raises:
        ValueError: if the label is not "compliant" or "non-compliant". Case and
            surrounding whitespace are ignored.
    """
    tokenized_example = tokenizer(
        example['customer_input'],
        example['agent_response'],
        padding='max_length',
        truncation=True,
        max_length=512,
    )
    # Normalise the label first, so that e.g. "Compliant" is not silently
    # treated as the negative class. Reject anything else outright.
    label_name = example['label'].strip().lower()
    try:
        tokenized_example['label'] = LABEL2ID[label_name]
    except KeyError:
        raise ValueError(
            f"Unknown label {example['label']!r}; expected one of {sorted(LABEL2ID)}"
        ) from None
    return tokenized_example
# Tokenize every example up front, so the Dataset wrapper only has to turn
# the stored lists into tensors.
train_data = list(map(tokenize_function, train_data))
eval_data = list(map(tokenize_function, eval_data))
# Dataset class
class DialogueDataset(torch.utils.data.Dataset):
    """Map-style PyTorch dataset over pre-tokenized sentence-pair examples.

    Each element of ``data`` is a mapping with "input_ids", "attention_mask" and
    an integer "label", as produced by tokenize_function. A "token_type_ids"
    entry is used when present.
    """

    def __init__(self, data):
        self.data = data
        # Integer labels of all examples, in dataset order.
        self.labels = [item['label'] for item in data]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of tensors, with the label under "labels"."""
        item = self.data[idx]
        features = {
            "input_ids": torch.tensor(item['input_ids']),
            "attention_mask": torch.tensor(item['attention_mask']),
        }
        # Segment ids tell BERT which tokens belong to the customer input and
        # which to the agent response. Dropping them would train with all-zero
        # segments, while inference passes the tokenizer's real segment ids.
        if 'token_type_ids' in item:
            features["token_type_ids"] = torch.tensor(item['token_type_ids'])
        features["labels"] = torch.tensor(item['label'])
        return features
# Wrap the tokenized splits as PyTorch datasets for the Trainer.
train_dataset, eval_dataset = DialogueDataset(train_data), DialogueDataset(eval_data)
# Multilingual BERT with a two-class sequence-classification head
# (class 0 = non-compliant, class 1 = compliant).
model = BertForSequenceClassification.from_pretrained(
    "bert-base-multilingual-cased",
    num_labels=2,
)
# Training arguments.
# transformers renamed `evaluation_strategy` to `eval_strategy` and later removed
# the old name, so passing `evaluation_strategy` raises a TypeError on recent
# versions. Use whichever keyword the installed TrainingArguments accepts.
_EVAL_STRATEGY_KWARG = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments).parameters
    else "evaluation_strategy"
)
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=2,
    weight_decay=0.01,
    logging_dir='./logs',
    **{_EVAL_STRATEGY_KWARG: "epoch"},  # Evaluate every epoch
)
# Trainer that ties the model, training arguments and both dataset splits together.
trainer = Trainer(
    args=training_args,
    model=model,
    eval_dataset=eval_dataset,
    train_dataset=train_dataset,
)
trainer.train()  # fine-tune on train_dataset

# No compute_metrics was given to the Trainer, so this reports the eval loss
# and runtime statistics rather than accuracy-style metrics.
eval_results = trainer.evaluate()
print("Evaluation Results:", eval_results)
def check_compliance(customer_input, agent_response):
    """Classify an agent's reply to a customer message with the fine-tuned model.

    Args:
        customer_input: The customer's message.
        agent_response: The agent's reply to classify.

    Returns:
        "Compliant" if the model predicts class 1, otherwise "Non-Compliant".
    """
    inputs = tokenizer(
        customer_input,
        agent_response,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    # The tokenizer returns CPU tensors, but Trainer may have moved the model
    # to a GPU. Put the inputs on the same device as the model's weights.
    inputs = inputs.to(model.device)
    # Disable dropout explicitly rather than relying on the mode Trainer left behind.
    model.eval()
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class = torch.argmax(logits, dim=-1).item()
    return "Compliant" if predicted_class == 1 else "Non-Compliant"
# Smoke test: classify a Hindi dialogue that is not in the training data.
test_customer_input = "मेरे पास अकाउंट एक्सेस नहीं हो रहा है। क्या आप मेरी मदद कर सकते हैं?"
test_agent_response = "मुझे खेद है। कृपया अपना उपयोगकर्ता नाम साझा करें, ताकि मैं आपकी सहायता कर सकूं।"

result = check_compliance(
    test_customer_input,
    test_agent_response,
)
print(result)