Chidam Gopal commited on
Commit
94de7c5
1 Parent(s): 8e4c79a

intent classifier app

Browse files
Files changed (3) hide show
  1. app.py +42 -0
  2. infer_intent.py +64 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import streamlit.components.v1 as components
3
+ from infer_intent import IntentClassifier
4
+ import matplotlib.pyplot as plt
5
+
6
+ st.set_page_config(layout="wide")
7
+ st.title("Intent classifier")
8
+
9
+ @st.cache_resource
10
+ def get_intent_classifier():
11
+ cls = IntentClassifier()
12
+ return cls
13
+
14
+ cls = get_intent_classifier()
15
+ query = st.text_input("Enter a query", value="What is the weather today")
16
+ pred_result, proba_result = cls.find_intent(query)
17
+
18
+ st.markdown(f"prediction = :green[{pred_result}]")
19
+ keys = list(proba_result.keys())
20
+ values = list(proba_result.values())
21
+
22
+ # Creating the bar plot
23
+ fig, ax = plt.subplots()
24
+ ax.barh(keys, values)
25
+
26
+ # Adding labels and title
27
+ ax.set_xlabel('Intent')
28
+ ax.set_ylabel('Values')
29
+ ax.set_title('Intents probability score')
30
+
31
+ col1, col2 = st.columns([2,4])
32
+
33
+ with col1:
34
+ st.pyplot(fig)
35
+
36
+ with col2:
37
+ exp = st.expander("Explore training data")
38
+ with exp:
39
+ html_file = "reports/web_search_intents.html"
40
+ with open(html_file, 'r', encoding='utf-8') as f:
41
+ plotly_html = f.read()
42
+ components.html(plotly_html, height=900, width=900)
infer_intent.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
2
+ import torch
3
+
4
+
5
+ class IntentClassifier:
6
+ def __init__(self):
7
+ self.id2label = {0: 'information_intent',
8
+ 1: 'yelp_intent',
9
+ 2: 'navigation_intent',
10
+ 3: 'travel_intent',
11
+ 4: 'purchase_intent',
12
+ 5: 'weather_intent',
13
+ 6: 'translation_intent',
14
+ 7: 'unknown'}
15
+ self.label2id = {label:id for id,label in self.id2label.items()}
16
+
17
+ self.tokenizer = AutoTokenizer.from_pretrained("chidamnat2002/intent_classifier")
18
+ self.intent_model = AutoModelForSequenceClassification.from_pretrained('chidamnat2002/intent_classifier',
19
+ num_labels=8,
20
+ torch_dtype=torch.bfloat16,
21
+ id2label=self.id2label,
22
+ label2id=self.label2id)
23
+
24
+ def find_intent(self, sequence, verbose=False):
25
+ inputs = self.tokenizer(sequence,
26
+ return_tensors="pt", # ONNX requires inputs in NumPy format
27
+ padding="max_length", # Pad to max length
28
+ truncation=True, # Truncate if the text is too long
29
+ max_length=64)
30
+
31
+ self.intent_model.eval()
32
+ with torch.no_grad():
33
+ outputs = self.intent_model(**inputs)
34
+ logits = outputs.logits
35
+ prediction = torch.argmax(logits, dim=1).item()
36
+ probabilities = torch.softmax(logits, dim=1)
37
+ rounded_probabilities = torch.round(probabilities, decimals=3)
38
+
39
+ pred_result = self.id2label[prediction]
40
+ proba_result = dict(zip(self.label2id.keys(), rounded_probabilities.tolist()[0]))
41
+ if verbose:
42
+ print(sequence + " -> " + pred_result)
43
+ print(proba_result, "\n")
44
+ return pred_result, proba_result
45
+
46
+
47
+ def main():
48
+ text_list = [
49
+ 'floor repair cost',
50
+ 'pet store near me',
51
+ 'who is the us president',
52
+ 'italian food',
53
+ 'sandwiches for lunch',
54
+ "cheese burger cost",
55
+ "What is the weather today",
56
+ "what is the capital of usa",
57
+ "cruise trip to carribean",
58
+ ]
59
+ cls = IntentClassifier()
60
+ for sequence in text_list:
61
+ cls.find_intent(sequence)
62
+
63
+ if __name__ == '__main__':
64
+ main()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ transformers==4.45.1
2
+ torch==2.4.1
3
+ streamlit==1.38.0
4
+ matplotlib==3.9.2