ketanchaudhary88 commited on
Commit
b409f35
·
verified ·
1 Parent(s): b0b0870

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -0
app.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from sklearn.model_selection import train_test_split
3
+ from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
4
+
5
+ # Load multilingual BERT tokenizer and model
6
+ tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")
7
+ model = BertForSequenceClassification.from_pretrained("bert-base-multilingual-cased", num_labels=2)
8
+
9
+ # Example dataset in Hindi
10
+ dataset = [
11
+ {"customer_input": "मैंने गलत सामान प्राप्त किया है, क्या मुझे रिफंड मिल सकता है?",
12
+ "agent_response": "मुझे खेद है कि आपको परेशानी हो रही है। कृपया अपना ऑर्डर नंबर प्रदान करें ताकि मैं आपकी सहायता कर सकूं।",
13
+ "label": "compliant"},
14
+ {"customer_input": "मेरा ऑर्डर देरी से आ रहा है, मुझे क्या करना चाहिए?",
15
+ "agent_response": "कृपया धैर्य रखें, हम आपकी समस्या को जल्द हल करेंगे।",
16
+ "label": "non-compliant"},
17
+ # Add more examples as needed
18
+ ]
19
+
20
+ # Split dataset into training and evaluation sets
21
+ train_data, eval_data = train_test_split(dataset, test_size=0.2)
22
+
23
+ # Tokenization
24
+ def tokenize_function(example):
25
+ return tokenizer(example['customer_input'], example['agent_response'], padding='max_length', truncation=True, max_length=512)
26
+
27
+ train_data = [tokenize_function(x) for x in train_data]
28
+ eval_data = [tokenize_function(x) for x in eval_data]
29
+
30
+ # Dataset class
31
+ class DialogueDataset(torch.utils.data.Dataset):
32
+ def __init__(self, data):
33
+ self.data = data
34
+ self.labels = [1 if item["label"] == "compliant" else 0 for item in data]
35
+
36
+ def __len__(self):
37
+ return len(self.data)
38
+
39
+ def __getitem__(self, idx):
40
+ item = self.data[idx]
41
+ input_ids = torch.tensor(item['input_ids'])
42
+ attention_mask = torch.tensor(item['attention_mask'])
43
+ label = torch.tensor(self.labels[idx])
44
+ return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": label}
45
+
46
+ train_dataset = DialogueDataset(train_data)
47
+ eval_dataset = DialogueDataset(eval_data)
48
+
49
+ # Training arguments
50
+ training_args = TrainingArguments(
51
+ output_dir="./results",
52
+ evaluation_strategy="epoch", # Evaluate every epoch
53
+ per_device_train_batch_size=8,
54
+ per_device_eval_batch_size=8,
55
+ num_train_epochs=2,
56
+ weight_decay=0.01,
57
+ logging_dir='./logs',
58
+ )
59
+
60
+ # Trainer
61
+ trainer = Trainer(
62
+ model=model,
63
+ args=training_args,
64
+ train_dataset=train_dataset,
65
+ eval_dataset=eval_dataset,
66
+ )
67
+
68
+ # Fine-tune the model
69
+ trainer.train()
70
+
71
+ # Evaluate the model
72
+ eval_results = trainer.evaluate()
73
+ print("Evaluation Results:", eval_results)
74
+
75
+ # Inference function
76
+ def check_compliance(customer_input, agent_response):
77
+ inputs = tokenizer(customer_input, agent_response, return_tensors="pt", padding=True, truncation=True, max_length=512)
78
+ with torch.no_grad():
79
+ outputs = model(**inputs)
80
+ logits = outputs.logits
81
+ predicted_class = torch.argmax(logits, dim=-1).item()
82
+
83
+ if predicted_class == 1:
84
+ return "Compliant"
85
+ else:
86
+ return "Non-Compliant"
87
+
88
+ # Test the model with new data
89
+ test_customer_input = "मेरे पास अकाउंट एक्सेस नहीं हो रहा है। क्या आप मेरी मदद कर सकते हैं?"
90
+ test_agent_response = "मुझे खेद है। कृपया अपना उपयोगकर्ता नाम साझा करें, ताकि मैं आपकी सहायता कर सकूं।"
91
+ result = check_compliance(test_customer_input, test_agent_response)
92
+ print(result)