kkngan commited on
Commit
6793d9c
·
verified ·
1 Parent(s): df43922

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +156 -0
  2. front_page_image.jpg +0 -0
app.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from streamlit_mic_recorder import mic_recorder
3
+ from transformers import pipeline
4
+ import torch
5
+ from transformers import BertTokenizer, BertForSequenceClassification
6
+ import numpy as np
7
+ import pandas as pd
8
+ import time
9
+ import altair as alt
10
+
11
+
12
def callback():
    """Replay the microphone recording as soon as it is captured."""
    recording = st.session_state.my_recorder_output
    if recording:
        st.audio(recording['bytes'])
16
+
17
+
18
@st.cache_resource
def load_text_to_speech_model(model="openai/whisper-base"):
    """Build and cache an automatic-speech-recognition pipeline.

    Cached per checkpoint name via st.cache_resource so the model is
    only loaded once per Streamlit session.
    """
    return pipeline("automatic-speech-recognition", model=model)
22
+
23
+
24
def translate(inputs, model="openai/whisper-base"):
    """Transcribe audio and translate the transcript to English.

    Args:
        inputs: Raw audio bytes (or any input the ASR pipeline accepts).
        model: Whisper checkpoint name to run.

    Returns:
        The translated transcript text.
    """
    # Bug fix: the original rebuilt the pipeline on every call, defeating
    # the @st.cache_resource loader and paying the full model-load cost
    # each time. Reuse the cached pipeline instead.
    pipe = load_text_to_speech_model(model)
    translate_result = pipe(inputs, generate_kwargs={'task': 'translate'})
    return translate_result['text']
28
+
29
+
30
@st.cache_resource
def load_classification_model():
    """Load and cache the fine-tuned BERT classifier and its tokenizer.

    Returns:
        (model, tokenizer) for the 8-class IT-service classifier.
    """
    checkpoint = "kkngan/bert-base-uncased-it-service-classification"
    tokenizer = BertTokenizer.from_pretrained(checkpoint, do_lower_case=True)
    model = BertForSequenceClassification.from_pretrained(checkpoint, num_labels=8)
    return model, tokenizer
39
+
40
+
41
def predict(text, model, tokenizer):
    """Classify a transcript into one of eight IT-service categories.

    Args:
        text: The (English) transcript to classify.
        model: A BertForSequenceClassification instance.
        tokenizer: The matching BertTokenizer.

    Returns:
        (label, class_id, probability) — the human-readable category name,
        its integer id, and the softmax probabilities as a numpy array of
        shape (1, 8).
    """
    id_to_label = {
        0: 'Hardware',
        1: 'Access',
        2: 'Miscellaneous',
        3: 'HR Support',
        4: 'Purchase',
        5: 'Administrative rights',
        6: 'Storage',
        7: 'Internal Project',
    }
    encoded = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
    outputs = model(**encoded)
    class_id = outputs.logits.argmax().item()
    probability = torch.nn.functional.softmax(outputs.logits, dim=-1).cpu().detach().numpy()
    return id_to_label.get(class_id), class_id, probability
59
+
60
+
61
def display_result(translate_text, prediction, predicted_class_id, probability):
    """Render the transcript, the predicted class, and a probability chart.

    Args:
        translate_text: The (translated) transcript shown to the user.
        prediction: Human-readable predicted category name.
        predicted_class_id: Integer class id (currently unused in the UI).
        probability: Softmax probabilities, shape (1, 8).
    """
    categories = ('Hardware',
                  'Access',
                  'Miscellaneous',
                  'HR Support',
                  'Purchase',
                  'Administrative rights',
                  'Storage',
                  'Internal Project')

    # Transcript and predicted class.
    st.markdown('<font color="purple"><b>Text:</b></font>', unsafe_allow_html=True)
    st.write(f'{translate_text}')
    st.write(f'\n')
    st.write(f'\n')
    st.markdown('<font color="green"><b>Predicted Class:</b></font>', unsafe_allow_html=True)
    st.write(f'{prediction}')
    st.write(f'\n')
    st.write(f'\n')

    # Horizontal bar chart of per-category probability, in percent,
    # rounded to 2 decimal places and sorted by probability.
    scores = np.array(probability[0])
    df = pd.DataFrame({'Category': categories, 'Probability (%)': scores * 100})
    df['Probability (%)'] = df['Probability (%)'].apply(lambda v: round(v, 2))
    base = alt.Chart(df).encode(
        x='Probability (%)',
        y=alt.Y('Category').sort('-x'),
        tooltip=['Category', alt.Tooltip('Probability (%)', format=",.2f")],
        text='Probability (%)'
    ).properties(title="Probability of each Service Category")
    chart = base.mark_bar() + base.mark_text(align='left', dx=2)
    st.altair_chart(chart, use_container_width=True)
95
+
96
+
97
def main():
    """Streamlit entry point: collect input, transcribe/translate, classify, display.

    Flow: pick an ASR model and an input method in the sidebar, supply a
    recording / audio file / transcript, then Submit runs Whisper (if
    needed) followed by the BERT classifier and renders the results.
    """
    # App parameters.
    image_path = 'front_page_image.jpg'
    model_options = ["openai/whisper-base", "openai/whisper-large-v3"]
    input_options = ["Start a recording", "Upload an audio", "Enter a transcript"]

    st.set_page_config(layout="wide", page_title="NLP IT Service Classification", page_icon="🤖",)
    st.markdown('<b>🤖 Welcome to IT Service Classification Assistant!!! 🤖</b>', unsafe_allow_html=True)
    st.write(f'\n')
    st.write(f'\n')

    with st.sidebar:
        st.image(image_path, use_column_width=True)
        text_to_speech_model = st.selectbox("Pick select a speech to text model", model_options)
        options = st.selectbox("Pick select an input method", input_options)

    # One input widget depending on the selected method.
    audio = None
    text = None
    if options == input_options[0]:
        # Start a recording
        audio = mic_recorder(key='my_recorder', callback=callback)
    elif options == input_options[1]:
        # Upload an audio
        audio = st.file_uploader("Please upload an audio", type=["wav", "mp3"])
    else:
        # Enter a transcript
        text = st.text_area("Please input the transcript (Only support English)")

    button = st.button('Submit')

    if button:
        # Bug fix: the original crashed (NoneType subscript / attribute
        # error) when Submit was pressed before any input was provided.
        if options == input_options[2]:
            if not text:
                st.warning("Please enter a transcript before submitting.")
                return
        elif audio is None:
            st.warning("Please provide an audio input before submitting.")
            return

        with st.spinner(text="Loading... It may take a while if you are running the app for the first time."):
            start_time = time.time()

            # Obtain the (English) transcript from the chosen input.
            if options == input_options[0]:
                translate_text = translate(inputs=audio["bytes"], model=text_to_speech_model)
            elif options == input_options[1]:
                translate_text = translate(inputs=audio.getvalue(), model=text_to_speech_model)
            else:
                translate_text = text

            model, tokenizer = load_classification_model()
            prediction, predicted_class_id, probability = predict(text=translate_text, model=model, tokenizer=tokenizer)

            end_time = time.time()

            display_result(translate_text, prediction, predicted_class_id, probability)

            st.write(f'\n')
            st.write(f'\n')
            # Bug fix: the original left the markdown italic marker '*'
            # unclosed, so it rendered as a literal asterisk.
            st.markdown(f'*It took {(end_time-start_time):.2f} sec to process the input.*', unsafe_allow_html=True)
153
+
154
+
155
# Run the Streamlit app when this file is executed as a script.
if __name__ == '__main__':
    main()
front_page_image.jpg ADDED