ketanchaudhary88 committed (verified)
Commit: 9fe41c7
Parent(s): 9e338f3

Update app.py

Files changed (1): app.py (+18 -10)
app.py CHANGED
@@ -1,10 +1,9 @@
-import torch
 from sklearn.model_selection import train_test_split
+import torch
 from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
 
-# Load multilingual BERT tokenizer and model
+# Load multilingual BERT tokenizer
 tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")
-model = BertForSequenceClassification.from_pretrained("bert-base-multilingual-cased", num_labels=2)
 
 # Example dataset in Hindi
 dataset = [
@@ -20,10 +19,13 @@ dataset = [
 # Split dataset into training and evaluation sets
 train_data, eval_data = train_test_split(dataset, test_size=0.2)
 
-# Tokenization
+# Tokenizer function that also keeps the label in the dataset
 def tokenize_function(example):
-    return tokenizer(example['customer_input'], example['agent_response'], padding='max_length', truncation=True, max_length=512)
+    tokenized_example = tokenizer(example['customer_input'], example['agent_response'], padding='max_length', truncation=True, max_length=512)
+    tokenized_example['label'] = 1 if example['label'] == 'compliant' else 0  # Convert 'compliant' to 1 and 'non-compliant' to 0
+    return tokenized_example
 
+# Apply tokenization to the entire dataset
 train_data = [tokenize_function(x) for x in train_data]
 eval_data = [tokenize_function(x) for x in eval_data]
 
@@ -31,7 +33,7 @@ eval_data = [tokenize_function(x) for x in eval_data]
 class DialogueDataset(torch.utils.data.Dataset):
     def __init__(self, data):
         self.data = data
-        self.labels = [1 if item["label"] == "compliant" else 0 for item in data]
+        self.labels = [item['label'] for item in data]
 
     def __len__(self):
         return len(self.data)
@@ -40,19 +42,25 @@ class DialogueDataset(torch.utils.data.Dataset):
         item = self.data[idx]
         input_ids = torch.tensor(item['input_ids'])
         attention_mask = torch.tensor(item['attention_mask'])
-        label = torch.tensor(self.labels[idx])
+        label = torch.tensor(item['label'])
         return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": label}
 
+
+# Create PyTorch datasets
 train_dataset = DialogueDataset(train_data)
 eval_dataset = DialogueDataset(eval_data)
 
+
+# Load multilingual BERT model for sequence classification
+model = BertForSequenceClassification.from_pretrained("bert-base-multilingual-cased", num_labels=2)
+
 # Training arguments
 training_args = TrainingArguments(
     output_dir="./results",
     evaluation_strategy="epoch",  # Evaluate every epoch
     per_device_train_batch_size=8,
     per_device_eval_batch_size=8,
-    num_train_epochs=2,
+    num_train_epochs=3,
     weight_decay=0.01,
     logging_dir='./logs',
 )
@@ -72,7 +80,7 @@ trainer.train()
 eval_results = trainer.evaluate()
 print("Evaluation Results:", eval_results)
 
-# Inference function
+
 def check_compliance(customer_input, agent_response):
     inputs = tokenizer(customer_input, agent_response, return_tensors="pt", padding=True, truncation=True, max_length=512)
     with torch.no_grad():
@@ -89,4 +97,4 @@ def check_compliance(customer_input, agent_response):
 test_customer_input = "मेरे पास अकाउंट एक्सेस नहीं हो रहा है। क्या आप मेरी मदद कर सकते हैं?"
 test_agent_response = "मुझे खेद है। कृपया अपना उपयोगकर्ता नाम साझा करें, ताकि मैं आपकी सहायता कर सकूं।"
 result = check_compliance(test_customer_input, test_agent_response)
-print(result)
+print(result)
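
Note: the diff does not show how the Hindi dataset entries or the Trainer are defined elsewhere in app.py. The sketch below is a minimal, hypothetical illustration only: it assumes each record carries the keys that tokenize_function reads ('customer_input', 'agent_response', and 'label' with values 'compliant'/'non-compliant') and the standard transformers Trainer wiring behind the trainer.train() call mentioned in the hunk context. The placeholder strings and the Trainer construction are assumptions, not the repository's actual code.

# Hypothetical records, only to illustrate the structure tokenize_function expects;
# the real dataset in app.py contains Hindi customer/agent pairs.
dataset = [
    {"customer_input": "...", "agent_response": "...", "label": "compliant"},
    {"customer_input": "...", "agent_response": "...", "label": "non-compliant"},
]

# Presumed Trainer wiring behind the trainer.train() call referenced in the diff.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()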