NavyaNayer committed on
Commit
0258268
·
verified ·
1 Parent(s): a366240

Delete predict_intent.py

Browse files
Files changed (1) hide show
  1. predict_intent.py +0 -48
predict_intent.py DELETED
@@ -1,48 +0,0 @@
1
- import torch
2
- from transformers import BertTokenizer, BertForSequenceClassification
3
- from datasets import load_dataset
4
- from collections import Counter
5
-
6
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# The intent label names are read from the training split's "intent" feature
# so that index i of `label_names` is the name of class id i.
dataset = load_dataset("clinc_oos", "plus")
label_names = dataset["train"].features["intent"].names

# Sanity-check output: the label count (expected: 151) and the first ten names.
num_labels = len(label_names)
print(f"Total labels: {num_labels}")
print("Sample labels:", label_names[:10])

# Build a BERT classifier with one output per label, then overwrite its
# weights with the fine-tuned checkpoint stored next to this script.
# NOTE(review): torch.load unpickles the checkpoint file, so only load
# "intent_classifier.pth" from a trusted source.
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=num_labels,
)
model.load_state_dict(
    torch.load("intent_classifier.pth", map_location=device)
)
model.to(device)
model.eval()  # inference mode

# Tokenizer for the same base checkpoint name as the model above.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
26
-
27
def predict_intent(sentence):
    """Classify the intent of *sentence* with the module-level model.

    The sentence is tokenized (padded/truncated to 128 tokens), moved to the
    module-level ``device``, and passed through ``model`` with gradient
    tracking disabled. The predicted class is the argmax of the logits
    along dimension 1.

    Args:
        sentence: Text handed to the tokenizer (presumably a single string;
            confirm against callers).

    Returns:
        A ``(predicted_class, label)`` tuple. ``predicted_class`` is the
        first element of the argmax result converted to a NumPy array, and
        ``label`` is ``label_names[predicted_class]``. If the index is not
        smaller than ``len(label_names)``, a warning is printed and the
        label is ``"Unknown Label"``.
    """
    encoded = tokenizer(
        sentence,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=128,
    )

    # Every tensor must live on the same device as the model.
    batch = {}
    for name, tensor in encoded.items():
        batch[name] = tensor.to(device)

    with torch.no_grad():
        logits = model(**batch).logits
        best = torch.argmax(logits, dim=1)
        predicted_class = best.cpu().numpy()[0]

    # Normal case first: the index maps onto a known label name.
    if predicted_class < len(label_names):
        return predicted_class, label_names[predicted_class]

    print(f"Warning: Predicted class {predicted_class} is out of range!")
    return predicted_class, "Unknown Label"
40
-
41
# Demo: classify one hard-coded sentence and print the predicted class index
# together with its label name.
sentence = "I need to attend a meeting but so tired but important"
predicted_intent, predicted_label_name = predict_intent(sentence)
print(
    f"Predicted intent for '{sentence}': {predicted_intent} ({predicted_label_name})"
)
45
-
46
- # # Fix: Count labels correctly from dataset["train"]
47
- # label_counts = Counter([label_names[label] for label in dataset["train"]["intent"]])
48
- # print("Label distribution:", label_counts) # Print the count of every label