kkngan committed · verified
Commit e5a5157 · 1 Parent(s): 9eecc55

Update app.py

Files changed (1)
  1. app.py +64 -22
app.py CHANGED
@@ -2,20 +2,22 @@ import streamlit as st
 from streamlit_mic_recorder import mic_recorder
 from transformers import pipeline
 import torch
-from transformers import BertTokenizer, BertForSequenceClassification
+from transformers import BertTokenizer, BertForSequenceClassification, AutoModelForSequenceClassification, AutoTokenizer
+import numpy as np
+import pandas as pd

 def callback():
     if st.session_state.my_recorder_output:
         audio_bytes = st.session_state.my_recorder_output['bytes']
         st.audio(audio_bytes)

-def transcribe(upload):
-    pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large")
-    result = pipe(upload, generate_kwargs={'task': 'transcribe'})
-    print(result['text'])
-    return result['text']
+def transcribe_and_translate(upload):
+    pipe = pipeline("automatic-speech-recognition", model="openai/whisper-medium")
+    transcribe_result = pipe(upload, generate_kwargs={'task': 'transcribe'})
+    translate_result = pipe(upload, generate_kwargs={'task': 'translate'})
+    return transcribe_result['text'], translate_result['text']

-def encode(docs, tokenizer):
+def encode_depracated(docs, tokenizer):
     '''
     This function takes list of texts and returns input_ids and attention_mask of texts
     '''
@@ -38,6 +40,13 @@ def load_model():
     return model, tokenizer


+def load_model():
+    PRETRAINED_LM = "kkngan/bert-base-uncased-it-service-classification"
+    model = AutoModelForSequenceClassification.from_pretrained(PRETRAINED_LM, num_labels=8)
+    tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_LM)
+    return model, tokenizer
+
+
 def predict(text, model, tokenizer):
     lookup_key ={0: 'Hardware',
                  1: 'Access',
@@ -47,31 +56,64 @@ def predict(text, model, tokenizer):
                  5: 'Administrative rights',
                  6: 'Storage',
                  7: 'Internal Project'}
-    with torch.no_grad():
-        input_ids, att_mask = encode([text], tokenizer)
-        logits = model(input_ids = input_ids, attention_mask=att_mask).logits
-        predicted_class_id = logits.argmax().item()
+    # with torch.no_grad():
+    #     input_ids, att_mask = encode([text], tokenizer)
+    #     logits = model(input_ids = input_ids, attention_mask=att_mask).logits
+    inputs = tokenizer(text,
+                       padding = True,
+                       truncation = True,
+                       return_tensors='pt')
+    outputs = model(**inputs)
+    predicted_class_id = outputs.logits.argmax().item()
     predicted_label = lookup_key.get(predicted_class_id)
-    return predicted_label
+    probability = torch.nn.functional.softmax(outputs.logits, dim=-1).cpu().detach().numpy()
+    return predicted_label, probability


 def main():

-    st.set_page_config(layout="wide", page_title="IT Service NLP Classification",)
+    st.set_page_config(layout="wide", page_title="NLP IT Service Classification", page_icon="🤖",)
+    st.markdown('<b>🤖 Welcome to IT Service Classification Assistant!!! 🤖</b>', unsafe_allow_html=True)
+    st.write(f'\n')

     with st.sidebar:
-        audio = mic_recorder(key='my_recorder', callback=callback)
-        button = st.button('start classification')
+        st.image('front_page_image.jpg', use_column_width=True)
+        options = st.selectbox("Please select an input method", ["Start a recording", "Upload an audio", "Enter a transcript"])
+        if options == "Start a recording":
+            audio = mic_recorder(key='my_recorder', callback=callback)
+        elif options == "Upload an audio":
+            audio = st.file_uploader("Please upload an audio")
+        else:
+            text = st.text_area("Please input the transcript (Only support English)")
+        button = st.button('Submit')

     if button:
-        st.write('Loading')
-        text = transcribe(upload=audio["bytes"])
-        st.write(f'Speech-to-text Result:')
-        st.write(f'{text}')
-        model, tokenizer = load_model()
-        prediction = predict(text=text, model=model, tokenizer=tokenizer)
-        st.write(f'Classifcation Result:')
+        with st.spinner(text="Loading... It may take longer for initialisation."):
+            model, tokenizer = load_model()
+            if options == "Start a recording":
+                transcibe_text, translate_text = transcribe_and_translate(upload=audio["bytes"])
+                prediction, probability = predict(text=translate_text, model=model, tokenizer=tokenizer)
+            elif options == "Upload an audio":
+                transcibe_text, translate_text = transcribe_and_translate(upload=audio.getvalue())
+                prediction, probability = predict(text=translate_text, model=model, tokenizer=tokenizer)
+            else:
+                transcibe_text = text
+                prediction, probability = predict(text=text, model=model, tokenizer=tokenizer)
+            st.markdown('<font color="blue"><b>Transcript:</b></font>', unsafe_allow_html=True)
+            st.write(f'{transcibe_text}')
+            st.write(f'\n')
+            if options != "Enter a transcript":
+                st.markdown('<font color="red"><b>Translation:</b></font>', unsafe_allow_html=True)
+                st.write(f'{translate_text}')
+                st.write(f'\n')
+            st.markdown('<font color="green"><b>Predicted Class:</b></font>', unsafe_allow_html=True)
             st.write(f'{prediction}')

+            # Convert probability to a bar chart
+            st.write(f'\n')
+            objects = ('Hardware', 'Access', 'Miscellaneous', 'HR Support', 'Purchase', 'Administrative rights', 'Storage', 'Internal Project')
+            df = pd.DataFrame({'Categories': objects, 'Probability': probability[0]})
+            st.bar_chart(data=df, x='Categories', y='Probability')
+
 if __name__ == '__main__':
     main()
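
For reviewers who want to exercise the new speech path outside Streamlit, here is a minimal sketch of the Whisper calls behind transcribe_and_translate(). The file name "sample.wav" is a placeholder; the pipeline also accepts the raw audio bytes the app gets from mic_recorder or st.file_uploader (ffmpeg is needed to decode them).

# Minimal sketch (not part of the commit): the Whisper pipeline used by
# transcribe_and_translate(). "sample.wav" is a placeholder path.
from transformers import pipeline

pipe = pipeline("automatic-speech-recognition", model="openai/whisper-medium")

with open("sample.wav", "rb") as f:
    audio_bytes = f.read()

# Whisper supports two generation tasks: 'transcribe' (keep the source language)
# and 'translate' (always into English), matching the two calls in the diff.
transcript = pipe(audio_bytes, generate_kwargs={"task": "transcribe"})["text"]
translation = pipe(audio_bytes, generate_kwargs={"task": "translate"})["text"]
print(transcript)
print(translation)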
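
Similarly, the new load_model()/predict() pair reduces to a standard sequence-classification call. A self-contained sketch follows; the checkpoint and label order come from the diff above, while the example ticket text is invented for illustration.

# Minimal sketch (not part of the commit): the classification path in isolation.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

PRETRAINED_LM = "kkngan/bert-base-uncased-it-service-classification"
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_LM)
model = AutoModelForSequenceClassification.from_pretrained(PRETRAINED_LM, num_labels=8)

labels = ('Hardware', 'Access', 'Miscellaneous', 'HR Support', 'Purchase',
          'Administrative rights', 'Storage', 'Internal Project')

text = "My laptop screen keeps flickering and I need a replacement."  # made-up example
inputs = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
with torch.no_grad():
    logits = model(**inputs).logits

probability = torch.nn.functional.softmax(logits, dim=-1)[0]
print(labels[logits.argmax().item()])                        # predicted class
print({l: float(p) for l, p in zip(labels, probability)})    # per-class probabilities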