# app.py (5.45 kB) -- from kkngan's Hugging Face Space.
# Commit e5a5157 (verified): "Update app.py".
import streamlit as st
from streamlit_mic_recorder import mic_recorder
from transformers import pipeline
import torch
from transformers import BertTokenizer, BertForSequenceClassification, AutoModelForSequenceClassification, AutoTokenizer
import numpy as np
import pandas as pd
def callback():
    """Play back the latest microphone recording, if one is stored.

    Reads ``st.session_state.my_recorder_output``; when it is truthy, its
    ``'bytes'`` entry is handed to ``st.audio``. Otherwise nothing happens.
    """
    recording = st.session_state.my_recorder_output
    if not recording:
        return
    st.audio(recording['bytes'])
def transcribe_and_translate(upload):
    """Run Whisper over *upload* twice: first to transcribe, then to translate.

    Args:
        upload: Audio input, passed unchanged to the speech-recognition
            pipeline.

    Returns:
        A ``(transcript, translation)`` tuple built from the ``'text'`` field
        of each pipeline result.
    """
    # The pipeline is constructed on every call, exactly as before.
    asr = pipeline("automatic-speech-recognition", model="openai/whisper-medium")
    results = [
        asr(upload, generate_kwargs={'task': task})
        for task in ('transcribe', 'translate')
    ]
    transcript, translation = (result['text'] for result in results)
    return transcript, translation
def encode_depracated(docs, tokenizer):
    """Encode a list of texts with ``tokenizer.batch_encode_plus``.

    Args:
        docs: List of texts to encode.
        tokenizer: Object exposing a ``batch_encode_plus`` method.

    Returns:
        ``(input_ids, attention_masks)`` read from the ``'input_ids'`` and
        ``'attention_mask'`` entries of the encoder output.
    """
    encode_options = dict(
        add_special_tokens=True,
        max_length=128,
        padding='max_length',
        return_attention_mask=True,
        truncation=True,
        return_tensors='pt',
    )
    batch = tokenizer.batch_encode_plus(docs, **encode_options)
    return batch['input_ids'], batch['attention_mask']
def load_model():
    """Build a BERT classifier from a locally stored fine-tuned state dict.

    NOTE(review): a second ``load_model`` is defined further down in this
    file. Python keeps the later definition, so callers never reach this one.

    Returns:
        ``(model, tokenizer)``: a ``bert-base-uncased`` sequence classifier
        with 8 labels whose weights are overwritten from
        ``./bert-itserviceclassification`` (loaded onto the CPU), and the
        matching lower-casing tokenizer.
    """
    base_checkpoint = "bert-base-uncased"
    weights_path = "./bert-itserviceclassification"

    tokenizer = BertTokenizer.from_pretrained(base_checkpoint, do_lower_case=True)
    model = BertForSequenceClassification.from_pretrained(
        base_checkpoint,
        num_labels=8,
        output_attentions=False,
        output_hidden_states=False,
    )

    state_dict = torch.load(weights_path, map_location='cpu')
    model.load_state_dict(state_dict)
    return model, tokenizer
def load_model():
    """Load the fine-tuned IT-service classifier and its tokenizer.

    Returns:
        ``(model, tokenizer)``, both created with ``from_pretrained`` from the
        ``kkngan/bert-base-uncased-it-service-classification`` checkpoint; the
        model is requested with 8 labels.
    """
    checkpoint = "kkngan/bert-base-uncased-it-service-classification"
    classifier = AutoModelForSequenceClassification.from_pretrained(
        checkpoint,
        num_labels=8,
    )
    return classifier, AutoTokenizer.from_pretrained(checkpoint)
def predict(text, model, tokenizer):
    """Classify one piece of text into an IT-service category.

    Args:
        text: The text to classify.
        model: Callable invoked as ``model(**inputs)``; its result must expose
            a ``logits`` tensor whose last axis indexes the class ids.
        tokenizer: Callable invoked as
            ``tokenizer(text, padding=True, truncation=True, return_tensors='pt')``;
            its result is unpacked as keyword arguments to ``model``.

    Returns:
        ``(predicted_label, probability)``. ``predicted_label`` is the
        category name of the highest-scoring class in the first row of the
        logits (``None`` if that id is not in the lookup table).
        ``probability`` is a NumPy array of softmax scores with the same shape
        as the logits.
    """
    lookup_key = {
        0: 'Hardware',
        1: 'Access',
        2: 'Miscellaneous',
        3: 'HR Support',
        4: 'Purchase',
        5: 'Administrative rights',
        6: 'Storage',
        7: 'Internal Project',
    }

    inputs = tokenizer(text, padding=True, truncation=True, return_tensors='pt')

    # Pure inference: do not let autograd record the forward pass.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Reduce over the class axis and take the first row, instead of taking the
    # argmax of the flattened tensor (which is only correct for a batch of one).
    predicted_class_id = logits.argmax(dim=-1).reshape(-1)[0].item()
    predicted_label = lookup_key.get(predicted_class_id)

    probability = torch.nn.functional.softmax(logits, dim=-1).cpu().detach().numpy()
    return predicted_label, probability
def main():
    """Render the Streamlit page: collect speech or text, classify it, show results.

    The sidebar offers three input methods (microphone recording, audio
    upload, typed transcript). After "Submit" the classifier is loaded; audio
    input is transcribed and translated first, and the translation is what
    gets classified. The page then shows the transcript, the translation
    (audio input only), the predicted class and a bar chart of the per-class
    probabilities.
    """
    st.set_page_config(layout="wide", page_title="NLP IT Service Classification", page_icon="🤖",)
    st.markdown('<b>🤖 Welcome to IT Service Classification Assistant!!! 🤖</b>', unsafe_allow_html=True)
    st.write('\n')

    audio = None
    text = ""
    # NOTE(review): the original indentation was lost, so the exact extent of
    # the sidebar block is reconstructed here -- confirm that every input
    # widget is meant to live in the sidebar.
    with st.sidebar:
        st.image('front_page_image.jpg', use_column_width=True)
        options = st.selectbox(
            "Pick select an input method",
            ["Start a recording", "Upload an audio", "Enter a transcript"],
        )
        if options == "Start a recording":
            audio = mic_recorder(key='my_recorder', callback=callback)
        elif options == "Upload an audio":
            audio = st.file_uploader("Please upload an audio")
        else:
            text = st.text_area("Please input the transcript (Only support English)")
        button = st.button('Submit')

    if not button:
        return

    uses_audio = options != "Enter a transcript"
    if uses_audio and audio is None:
        # Nothing has been recorded or uploaded yet; indexing or calling into
        # `audio` below would raise instead of telling the user what to do.
        st.warning("Please record or upload an audio before submitting.")
        return

    with st.spinner(text="Loading... It may take longer for initialisation."):
        model, tokenizer = load_model()
        if uses_audio:
            if options == "Start a recording":
                audio_bytes = audio["bytes"]
            else:
                # Call getvalue() to obtain the uploaded bytes; passing the
                # bound method itself gives the ASR pipeline nothing to decode.
                audio_bytes = audio.getvalue()
            transcribe_text, translate_text = transcribe_and_translate(upload=audio_bytes)
            # The classifier is English-only, so it is fed the translation.
            classifier_input = translate_text
        else:
            transcribe_text = text
            classifier_input = text
        prediction, probability = predict(text=classifier_input, model=model, tokenizer=tokenizer)

    st.markdown('<font color="blue"><b>Transcript:</b></font>', unsafe_allow_html=True)
    st.write(f'{transcribe_text}')
    st.write('\n')

    if uses_audio:
        st.markdown('<font color="red"><b>Translation:</b></font>', unsafe_allow_html=True)
        st.write(f'{translate_text}')
        st.write('\n')

    st.markdown('<font color="green"><b>Predicted Class:</b></font>', unsafe_allow_html=True)
    st.write(f'{prediction}')

    # Show the per-class probabilities as a bar chart.
    st.write('\n')
    objects = ('Hardware', 'Access', 'Miscellaneous', 'HR Support', 'Purchase',
               'Administrative rights', 'Storage', 'Internal Project')
    df = pd.DataFrame({'Categories': objects, 'Probability': probability[0]})
    st.bar_chart(data=df, x='Categories', y='Probability')
# Start the app only when this file is executed as the main script.
if __name__ == "__main__":
    main()