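# Streamlit app for IT service ticket classification: an audio request
# (recorded in the browser or uploaded) is transcribed and translated to
# English with a Whisper ASR pipeline, classified into one of eight IT
# service categories by a fine-tuned BERT model, and the per-class
# probabilities are plotted with Altair.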
import streamlit as st
from streamlit_mic_recorder import mic_recorder
from transformers import pipeline, BertTokenizer, BertForSequenceClassification
import torch
import numpy as np
import pandas as pd
import time
import altair as alt
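
# Callback for streamlit_mic_recorder: when a recording finishes, the widget
# writes its result into st.session_state under '<key>_output'; here the
# captured bytes are immediately played back with st.audio.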
def callback():
    if st.session_state.my_recorder_output:
        audio_bytes = st.session_state.my_recorder_output['bytes']
        st.audio(audio_bytes)
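
# Whisper is loaded via the "automatic-speech-recognition" pipeline; running it
# with generate_kwargs={'task': 'translate'} returns an English transcript even
# for non-English audio, so the downstream classifier always sees English text.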
def load_speech_to_text_model(model="openai/whisper-base"):
    pipe = pipeline("automatic-speech-recognition", model=model)
    return pipe


def translate(inputs, model="openai/whisper-base"):
    pipe = load_speech_to_text_model(model=model)
    translate_result = pipe(inputs, generate_kwargs={'task': 'translate'})
    return translate_result['text']
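
# The classifier is a BERT model fine-tuned on IT service tickets
# (kkngan/bert-base-uncased-it-service-classification on the Hugging Face Hub)
# with 8 output labels.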
def load_classification_model():
    PRETRAINED_LM = "kkngan/bert-base-uncased-it-service-classification"
    # model = AutoModelForSequenceClassification.from_pretrained(PRETRAINED_LM, num_labels=8)
    # tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_LM)
    tokenizer = BertTokenizer.from_pretrained(PRETRAINED_LM, do_lower_case=True)
    model = BertForSequenceClassification.from_pretrained(PRETRAINED_LM,
                                                          num_labels=8)
    return model, tokenizer
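
# Tokenise the text, run a single forward pass, take the argmax of the logits
# as the predicted class, and softmax the logits to get per-class probabilities.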
def predict(text, model, tokenizer):
    lookup_key = {0: 'Hardware',
                  1: 'Access',
                  2: 'Miscellaneous',
                  3: 'HR Support',
                  4: 'Purchase',
                  5: 'Administrative rights',
                  6: 'Storage',
                  7: 'Internal Project'}
    inputs = tokenizer(text,
                       padding=True,
                       truncation=True,
                       return_tensors='pt')
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model(**inputs)
    predicted_class_id = outputs.logits.argmax().item()
    predicted_label = lookup_key.get(predicted_class_id)
    probability = torch.nn.functional.softmax(outputs.logits, dim=-1).cpu().numpy()
    return predicted_label, predicted_class_id, probability
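
# Show the (translated) text and the predicted class, then plot the probability
# of every service category as a horizontal Altair bar chart with text labels.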
def display_result(translate_text, prediction, predicted_class_id, probability):
    category = ('Hardware',
                'Access',
                'Miscellaneous',
                'HR Support',
                'Purchase',
                'Administrative rights',
                'Storage',
                'Internal Project')
    # Show translated text and prediction
    st.markdown('<font color="purple"><b>Text:</b></font>', unsafe_allow_html=True)
    st.write(f'{translate_text}')
    st.write('\n')
    st.write('\n')
    st.markdown('<font color="green"><b>Predicted Class:</b></font>', unsafe_allow_html=True)
    st.write(f'{prediction}')
    st.write('\n')
    st.write('\n')
    # Show probability of each service category
    probability = np.array(probability[0])
    df = pd.DataFrame({'Category': category, 'Probability (%)': probability * 100})
    df['Probability (%)'] = df['Probability (%)'].apply(lambda x: round(x, 2))
    base = alt.Chart(df).encode(
        x='Probability (%)',
        y=alt.Y('Category').sort('-x'),
        tooltip=['Category', alt.Tooltip('Probability (%)', format=",.2f")],
        text='Probability (%)'
    ).properties(title="Probability of each Service Category")
    chart = base.mark_bar() + base.mark_text(align='left', dx=2)
    st.altair_chart(chart, use_container_width=True)
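
# App entry point: the sidebar holds the model and input-method selectors, the
# main page collects the input (microphone recording, audio file, or typed
# transcript), and the Submit button triggers transcription, classification,
# and display of the results.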
def main():
    # define parameters
    image_path = 'front_page_image.jpg'
    model_options = ["openai/whisper-base", "openai/whisper-large-v3"]
    input_options = ["Start a recording", "Upload an audio file", "Enter a transcript"]
    # st.cache_resource.clear()
    st.set_page_config(layout="wide", page_title="NLP IT Service Classification", page_icon="🤖")
    st.markdown('<b>🤖 Welcome to the IT Service Classification Assistant! 🤖</b>', unsafe_allow_html=True)
    st.write('\n')
    st.write('\n')
    with st.sidebar:
        st.image(image_path, use_column_width=True)
        speech_to_text_model = st.selectbox("Select a speech-to-text model", model_options)
        options = st.selectbox("Select an input method", input_options)
    # Start a recording
    if options == input_options[0]:
        audio = mic_recorder(key='my_recorder', callback=callback)
    # Upload an audio file
    elif options == input_options[1]:
        audio = st.file_uploader("Please upload an audio file", type=["wav", "mp3"])
    # Enter a transcript
    else:
        text = st.text_area("Please input the transcript (only English is supported)")
    button = st.button('Submit')
    if button:
        with st.spinner(text="Loading... It may take a while if you are running the app for the first time."):
            start_time = time.time()
            # get inputs
            if options == input_options[0]:
                translate_text = translate(inputs=audio["bytes"], model=speech_to_text_model)
            elif options == input_options[1]:
                translate_text = translate(inputs=audio.getvalue(), model=speech_to_text_model)
            else:
                translate_text = text
            model, tokenizer = load_classification_model()
            prediction, predicted_class_id, probability = predict(text=translate_text, model=model, tokenizer=tokenizer)
            end_time = time.time()
            display_result(translate_text, prediction, predicted_class_id, probability)
            st.write('\n')
            st.write('\n')
            st.markdown(f'*It took {(end_time - start_time):.2f} sec to process the input.*', unsafe_allow_html=True)
if __name__ == '__main__':
    main()
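
# To run the app locally (assuming this script is saved as app.py and that
# streamlit, streamlit-mic-recorder, transformers, torch, numpy, pandas and
# altair are installed):
#
#     streamlit run app.py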