File size: 4,217 Bytes
b409f35
9fe41c7
b409f35
 
9fe41c7
b409f35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9fe41c7
b409f35
9fe41c7
 
 
b409f35
9fe41c7
b409f35
 
 
 
 
 
 
9fe41c7
b409f35
 
 
 
 
 
 
 
9fe41c7
b409f35
 
9fe41c7
 
b409f35
 
 
9fe41c7
 
 
 
b409f35
 
 
 
 
 
5724fbf
b409f35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9fe41c7
b409f35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9fe41c7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from sklearn.model_selection import train_test_split
import torch
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments

# Load the tokenizer that matches the multilingual BERT checkpoint used for
# the classification model further down in this script.
tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")

# Example dataset in Hindi. Each record pairs a customer message with the
# agent's reply and carries a textual label: "compliant" or "non-compliant".
dataset = [
    {
        "customer_input": "मैंने गलत सामान प्राप्त किया है, क्या मुझे रिफंड मिल सकता है?",
        "agent_response": "मुझे खेद है कि आपको परेशानी हो रही है। कृपया अपना ऑर्डर नंबर प्रदान करें ताकि मैं आपकी सहायता कर सकूं।",
        "label": "compliant",
    },
    {
        "customer_input": "मेरा ऑर्डर देरी से आ रहा है, मुझे क्या करना चाहिए?",
        "agent_response": "कृपया धैर्य रखें, हम आपकी समस्या को जल्द हल करेंगे।",
        "label": "non-compliant",
    },
    # Add more examples as needed
]

# Split dataset into training (80%) and evaluation (20%) sets. The seed is
# fixed so the same records land in each split on every run, which keeps the
# evaluation results comparable between runs.
train_data, eval_data = train_test_split(dataset, test_size=0.2, random_state=42)

# Tokenizer function that also keeps the label in the dataset
def tokenize_function(example):
    """Encode one dialogue record as a sentence pair and attach its label id.

    The customer message and the agent reply are tokenized together, padded
    or truncated to exactly 512 tokens. The textual label is stored under
    ``'label'`` as 1 for ``'compliant'`` and 0 for any other value.
    """
    encoding = tokenizer(
        example['customer_input'],
        example['agent_response'],
        max_length=512,
        truncation=True,
        padding='max_length',
    )

    # Anything that is not exactly 'compliant' falls into class 0.
    if example['label'] == 'compliant':
        label_id = 1
    else:
        label_id = 0

    encoding['label'] = label_id
    return encoding

# Encode every record of both splits; each list is replaced by its tokenized form
train_data = list(map(tokenize_function, train_data))
eval_data = list(map(tokenize_function, eval_data))

# Dataset class
class DialogueDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves tokenized dialogue records as tensors.

    Each element of ``data`` is a mapping with ``'input_ids'``,
    ``'attention_mask'`` and ``'label'`` entries and, optionally,
    ``'token_type_ids'``.
    """

    def __init__(self, data):
        """Store the encoded records and collect their labels.

        Args:
            data: Sequence of encoded records (see the class docstring).
        """
        self.data = data
        self.labels = [item['label'] for item in data]

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return record ``idx`` as a dict of tensors.

        The result always has ``input_ids``, ``attention_mask`` and
        ``labels``. ``token_type_ids`` is included when the record provides
        it, so a sentence-pair model receives the same segment information
        during training that the tokenizer supplies at inference time.
        """
        item = self.data[idx]
        sample = {
            "input_ids": torch.tensor(item['input_ids']),
            "attention_mask": torch.tensor(item['attention_mask']),
        }
        # Segment ids tell the two halves of the pair apart; dropping them
        # would make every token look like it belongs to the first segment.
        if 'token_type_ids' in item:
            sample["token_type_ids"] = torch.tensor(item['token_type_ids'])
        sample["labels"] = torch.tensor(item['label'])
        return sample


# Create PyTorch datasets from the tokenized train / evaluation records
train_dataset = DialogueDataset(train_data)
eval_dataset = DialogueDataset(eval_data)


# Load multilingual BERT model for sequence classification.
# num_labels=2 matches the 0 / 1 label ids produced by tokenize_function.
model = BertForSequenceClassification.from_pretrained("bert-base-multilingual-cased", num_labels=2)

# Training arguments: outputs go to ./results and logs to ./logs; batches of 8
# per device for both training and evaluation, 2 epochs, weight decay 0.01.
# NOTE(review): newer transformers releases appear to rename
# `evaluation_strategy` to `eval_strategy` — confirm against the installed version.
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",  # Evaluate every epoch
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=2,
    weight_decay=0.01,
    logging_dir='./logs',
)

# Trainer wiring the model, the arguments and both datasets together.
# No compute_metrics callback is passed here.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# Fine-tune the model (must run before the final evaluation below)
trainer.train()

# Evaluate the fine-tuned model on eval_dataset and print the returned results
eval_results = trainer.evaluate()
print("Evaluation Results:", eval_results)


def check_compliance(customer_input, agent_response):
    """Classify an agent reply as compliant or not for a customer message.

    Args:
        customer_input: Customer message text.
        agent_response: Agent reply text to be judged.

    Returns:
        ``"Compliant"`` when the model predicts class 1, otherwise
        ``"Non-Compliant"``.
    """
    inputs = tokenizer(customer_input, agent_response, return_tensors="pt", padding=True, truncation=True, max_length=512)
    # Training may have moved the model to an accelerator; the encoded inputs
    # must live on the same device or the forward pass raises a device error.
    inputs = inputs.to(model.device)

    # Make sure dropout is disabled so the prediction is deterministic.
    model.eval()
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class = torch.argmax(logits, dim=-1).item()

    return "Compliant" if predicted_class == 1 else "Non-Compliant"

# Smoke-test the fine-tuned classifier on a Hindi exchange that is not in the dataset
test_customer_input = "मेरे पास अकाउंट एक्सेस नहीं हो रहा है। क्या आप मेरी मदद कर सकते हैं?"
test_agent_response = "मुझे खेद है। कृपया अपना उपयोगकर्ता नाम साझा करें, ताकि मैं आपकी सहायता कर सकूं।"
result = check_compliance(
    customer_input=test_customer_input,
    agent_response=test_agent_response,
)
print(result)