# Source: app.py from a Hugging Face Space by kkngan.
# Commit 43930eb (verified): "Update app.py".
import streamlit as st
from streamlit_mic_recorder import mic_recorder
from transformers import pipeline
import torch
from transformers import BertTokenizer, BertForSequenceClassification
import numpy as np
import pandas as pd
import time
import altair as alt
def callback():
    """Play back the latest microphone recording, if one exists.

    Reads ``st.session_state.my_recorder_output`` and, when it is truthy,
    renders its ``'bytes'`` entry through ``st.audio``. Does nothing when
    the stored value is falsy.
    """
    recording = st.session_state.my_recorder_output
    if not recording:
        return
    st.audio(recording['bytes'])
@st.cache_resource
def load_text_to_speech_model(model="openai/whisper-base"):
    """Create the automatic-speech-recognition pipeline for ``model``.

    Despite the function name, the pipeline task is
    ``"automatic-speech-recognition"`` (audio in, text out). The function is
    decorated with ``st.cache_resource`` so Streamlit can reuse the loaded
    pipeline instead of rebuilding it on every call.

    Args:
        model: Model identifier handed to ``transformers.pipeline``.

    Returns:
        The pipeline object returned by ``transformers.pipeline``.
    """
    return pipeline(task="automatic-speech-recognition", model=model)
def translate(inputs, model="openai/whisper-base"):
    """Run the speech-recognition pipeline in ``translate`` mode.

    Args:
        inputs: Audio payload forwarded unchanged to the pipeline.
        model: Model identifier used to obtain the (cached) pipeline.

    Returns:
        The ``'text'`` entry of the pipeline result.
    """
    generation_options = {'task': 'translate'}
    recognizer = load_text_to_speech_model(model=model)
    outcome = recognizer(inputs, generate_kwargs=generation_options)
    return outcome['text']
@st.cache_resource
def load_classification_model():
    """Load the fine-tuned BERT classifier together with its tokenizer.

    Both objects come from the same pretrained checkpoint; the classifier is
    requested with 8 output labels and the tokenizer with lower-casing
    enabled. The function is decorated with ``st.cache_resource``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    checkpoint = "kkngan/bert-base-uncased-it-service-classification"
    classifier = BertForSequenceClassification.from_pretrained(
        checkpoint,
        num_labels=8,
    )
    bert_tokenizer = BertTokenizer.from_pretrained(checkpoint, do_lower_case=True)
    return classifier, bert_tokenizer
def predict(text, model, tokenizer):
    """Classify ``text`` into one of eight IT service categories.

    Args:
        text: Input handed to ``tokenizer`` (padding and truncation enabled,
            PyTorch tensors requested).
        model: Callable invoked as ``model(**inputs)``; its result must expose
            a ``logits`` tensor whose last dimension indexes the classes.
        tokenizer: Callable that turns ``text`` into the keyword arguments
            expected by ``model``.

    Returns:
        A ``(predicted_label, predicted_class_id, probability)`` tuple where
        ``predicted_class_id`` is the arg-max class of the first sequence,
        ``predicted_label`` is its category name (``None`` if the id is not in
        the lookup table) and ``probability`` is a NumPy array holding the
        softmax of the logits over the last dimension.
    """
    lookup_key = {
        0: 'Hardware',
        1: 'Access',
        2: 'Miscellaneous',
        3: 'HR Support',
        4: 'Purchase',
        5: 'Administrative rights',
        6: 'Storage',
        7: 'Internal Project',
    }
    inputs = tokenizer(
        text,
        padding=True,
        truncation=True,
        return_tensors='pt',
    )
    # Inference only: skip autograd bookkeeping so no graph is built or kept.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Arg-max over the first (normally the only) sequence rather than over the
    # flattened tensor, so the id always stays within the class range.
    predicted_class_id = logits[0].argmax().item()
    predicted_label = lookup_key.get(predicted_class_id)
    probability = torch.nn.functional.softmax(logits, dim=-1).cpu().detach().numpy()
    return predicted_label, predicted_class_id, probability
def display_result(translate_text, prediction, predicted_class_id, probability):
    """Render the transcript, the predicted class and a probability chart.

    Args:
        translate_text: Text that was classified; written under a "Text:" heading.
        prediction: Predicted category; written under "Predicted Class:".
        predicted_class_id: Accepted for interface compatibility; not used here.
        probability: Indexable whose first element holds one probability per
            category, in the order of the category tuple below
            (presumably fractions in [0, 1] — they are multiplied by 100).
    """
    def _spacer(count=2):
        # Each st.write('\n') call is used as a blank line in the page.
        for _ in range(count):
            st.write('\n')

    category = (
        'Hardware',
        'Access',
        'Miscellaneous',
        'HR Support',
        'Purchase',
        'Administrative rights',
        'Storage',
        'Internal Project',
    )

    # Transcript followed by the predicted label.
    st.markdown('<font color="purple"><b>Text:</b></font>', unsafe_allow_html=True)
    st.write(f'{translate_text}')
    _spacer()
    st.markdown('<font color="green"><b>Predicted Class:</b></font>', unsafe_allow_html=True)
    st.write(f'{prediction}')
    _spacer()

    # Per-category probabilities as percentages rounded to two decimals.
    first_row = np.array(probability[0])
    df = pd.DataFrame({'Category': category, 'Probability (%)': first_row * 100})
    df['Probability (%)'] = df['Probability (%)'].apply(lambda pct: round(pct, 2))

    # Horizontal bars sorted by descending probability, with value labels.
    encodings = {
        'x': 'Probability (%)',
        'y': alt.Y('Category').sort('-x'),
        'tooltip': ['Category', alt.Tooltip('Probability (%)', format=",.2f")],
        'text': 'Probability (%)',
    }
    base = alt.Chart(df).encode(**encodings).properties(
        title="Probability of each Service Category"
    )
    bars = base.mark_bar()
    value_labels = base.mark_text(align='left', dx=2)
    st.altair_chart(bars + value_labels, use_container_width=True)
def main():
    """Build the Streamlit page and run transcription plus classification.

    The sidebar lets the user pick a speech-to-text model and one of three
    input methods (record, upload, type a transcript). On Submit, audio is
    translated to text, the text is classified, and the result is displayed
    with the processing time.

    NOTE(review): the original indentation of this file was lost. This layout
    assumes the input widgets and the Submit button live in the sidebar and
    that results render in the main area outside the spinner — confirm.
    """
    # define parameters
    image_path = 'front_page_image.jpg'
    model_options = ["openai/whisper-base", "openai/whisper-large-v3"]
    input_options = ["Start a recording", "Upload an audio", "Enter a transcript"]
    record_option, upload_option = input_options[0], input_options[1]

    # Debug aid: call st.cache_resource.clear() here to drop the cached models.
    st.set_page_config(layout="wide", page_title="NLP IT Service Classification", page_icon="🤖",)
    st.markdown('<b>🤖 Welcome to IT Service Classification Assistant!!! 🤖</b>', unsafe_allow_html=True)
    st.write('\n')
    st.write('\n')

    # Bind both inputs up front so every branch below can inspect them safely.
    audio = None
    text = ""

    with st.sidebar:
        st.image(image_path, use_column_width=True)
        text_to_speech_model = st.selectbox("Pick select a speech to text model", model_options)
        options = st.selectbox("Pick select an input method", input_options)

        if options == record_option:
            # start a recording
            audio = mic_recorder(key='my_recorder', callback=callback)
        elif options == upload_option:
            # Upload an audio
            audio = st.file_uploader("Please upload an audio", type=["wav", "mp3"])
        else:
            # Enter a transcript
            text = st.text_area("Please input the transcript (Only support English)")

        button = st.button('Submit')

    if not button:
        return

    # Submit without any input previously crashed (None has no bytes /
    # getvalue) or classified an empty string; ask for input instead.
    if options == record_option:
        if not audio:
            st.warning("Please record some audio before submitting.")
            return
    elif options == upload_option:
        if audio is None:
            st.warning("Please upload an audio file before submitting.")
            return
    elif not text or not text.strip():
        st.warning("Please enter a transcript before submitting.")
        return

    with st.spinner(text="Loading... It may take a while if you are running the app for the first time."):
        start_time = time.time()

        # get inputs
        if options == record_option:
            translate_text = translate(inputs=audio["bytes"], model=text_to_speech_model)
        elif options == upload_option:
            translate_text = translate(inputs=audio.getvalue(), model=text_to_speech_model)
        else:
            translate_text = text

        model, tokenizer = load_classification_model()
        prediction, predicted_class_id, probability = predict(text=translate_text, model=model, tokenizer=tokenizer)
        end_time = time.time()

    display_result(translate_text, prediction, predicted_class_id, probability)
    st.write('\n')
    st.write('\n')
    st.markdown(f'*It took {(end_time-start_time):.2f} sec to process the input.', unsafe_allow_html=True)
# Script entry point: build the app only when this file is executed directly.
if __name__ == "__main__":
    main()