Spaces:
Sleeping
Sleeping
import streamlit as st | |
from streamlit_mic_recorder import mic_recorder | |
from transformers import pipeline | |
import torch | |
from transformers import BertTokenizer, BertForSequenceClassification, AutoModelForSequenceClassification, AutoTokenizer | |
import numpy as np | |
import pandas as pd | |
def callback():
    """Play back the most recent microphone recording, if any.

    Registered as the ``mic_recorder`` callback. It reads the recorder's
    output from ``st.session_state.my_recorder_output``. When that value is
    truthy, its ``'bytes'`` entry is rendered with ``st.audio``.
    """
    recording = st.session_state.my_recorder_output
    if not recording:
        return
    st.audio(recording['bytes'])
@st.cache_resource(show_spinner=False)
def _load_asr_pipeline(model_name):
    """Build the speech-recognition pipeline for *model_name* once and reuse it across reruns."""
    return pipeline("automatic-speech-recognition", model=model_name)


def transcribe_and_translate(upload, model_name="openai/whisper-medium"):
    """Transcribe an audio clip and translate it to English with Whisper.

    Args:
        upload: Audio accepted by the Hugging Face ASR pipeline. The app's
            ``main`` passes the raw bytes of a recording or uploaded file.
        model_name: Hugging Face model id of the Whisper checkpoint. The
            default matches the model this app always used.

    Returns:
        A ``(transcript, translation)`` tuple of strings. *transcript* is the
        text produced with ``task='transcribe'``. *translation* is the text
        produced with ``task='translate'`` (English output).
    """
    # Loading the model dominates the cost of this call. Reuse one cached
    # pipeline instead of rebuilding it on every Submit.
    pipe = _load_asr_pipeline(model_name)
    transcribe_result = pipe(upload, generate_kwargs={'task': 'transcribe'})
    translate_result = pipe(upload, generate_kwargs={'task': 'translate'})
    return transcribe_result['text'], translate_result['text']
def encode_depracated(docs, tokenizer):
    """Tokenize a batch of texts into model inputs (legacy helper).

    Each text gets special tokens added and is truncated or padded to exactly
    128 tokens. PyTorch tensors are requested.

    Args:
        docs: List of texts to encode.
        tokenizer: Tokenizer object exposing ``batch_encode_plus``.

    Returns:
        ``(input_ids, attention_masks)``, taken from the ``'input_ids'`` and
        ``'attention_mask'`` entries of the encoding.
    """
    encoding_options = dict(
        add_special_tokens=True,
        max_length=128,
        padding='max_length',
        return_attention_mask=True,
        truncation=True,
        return_tensors='pt',
    )
    encoded = tokenizer.batch_encode_plus(docs, **encoding_options)
    return encoded['input_ids'], encoded['attention_mask']
def load_model():
    """Build the 8-class BERT classifier from a locally saved state dict.

    NOTE(review): a second ``def load_model()`` later in this file rebinds
    the name. Callers therefore never reach this version; it appears to be an
    earlier, local-checkpoint variant.

    Returns:
        ``(model, tokenizer)``. The model is ``bert-base-uncased`` with an
        8-label classification head, and its weights are replaced by those
        stored at ``./bert-itserviceclassification``.
    """
    base_checkpoint = "bert-base-uncased"
    weights_file = "./bert-itserviceclassification"
    tokenizer = BertTokenizer.from_pretrained(base_checkpoint, do_lower_case=True)
    model = BertForSequenceClassification.from_pretrained(
        base_checkpoint,
        num_labels=8,
        output_attentions=False,
        output_hidden_states=False,
    )
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    state_dict = torch.load(weights_file, map_location ='cpu')
    model.load_state_dict(state_dict)
    return model, tokenizer
@st.cache_resource(show_spinner=False)
def _load_classifier(pretrained_lm, num_labels):
    """Load the sequence classifier and its tokenizer once and reuse them across reruns."""
    model = AutoModelForSequenceClassification.from_pretrained(pretrained_lm, num_labels=num_labels)
    tokenizer = AutoTokenizer.from_pretrained(pretrained_lm)
    return model, tokenizer


def load_model():
    """Return the fine-tuned IT-service classifier and its tokenizer.

    The model is the Hugging Face checkpoint
    ``kkngan/bert-base-uncased-it-service-classification`` with 8 output labels.

    Returns:
        ``(model, tokenizer)``. The pair is cached, so repeated calls (one per
        Submit in ``main``) return the same objects instead of reloading them.
    """
    PRETRAINED_LM = "kkngan/bert-base-uncased-it-service-classification"
    return _load_classifier(PRETRAINED_LM, 8)
def predict(text, model, tokenizer):
    """Classify one IT service request into one of eight categories.

    Args:
        text: The request text to classify. ``main`` passes English text,
            either typed or produced by Whisper translation.
        model: Sequence-classification model whose output exposes
            ``.logits``. The label table below assumes 8 output classes.
        tokenizer: Tokenizer matching *model*. It is called with
            ``return_tensors='pt'``.

    Returns:
        ``(label, probability)``. *label* is the category name of the
        highest-scoring class for the text. *probability* is a NumPy array of
        softmax scores with one row per encoded input (a single row here).
    """
    lookup_key = {0: 'Hardware',
                  1: 'Access',
                  2: 'Miscellaneous',
                  3: 'HR Support',
                  4: 'Purchase',
                  5: 'Administrative rights',
                  6: 'Storage',
                  7: 'Internal Project'}
    inputs = tokenizer(text,
                       padding=True,
                       truncation=True,
                       return_tensors='pt')
    # Inference only: disable autograd so no computation graph is built.
    with torch.no_grad():
        logits = model(**inputs).logits
    probability = torch.nn.functional.softmax(logits, dim=-1).cpu().detach().numpy()
    # One text means one row. Take the best class of that row rather than the
    # argmax of the flattened tensor.
    predicted_class_id = int(logits[0].argmax().item())
    predicted_label = lookup_key.get(predicted_class_id)
    return predicted_label, probability
def _render_results(options, transcribe_text, translate_text, prediction, probability):
    """Render the transcript, the translation (audio inputs only), the predicted class and a probability chart."""
    st.markdown('<font color="blue"><b>Transcript:</b></font>', unsafe_allow_html=True)
    st.write(f'{transcribe_text}')
    st.write(f'\n')
    if options != "Enter a transcript":
        st.markdown('<font color="red"><b>Translation:</b></font>', unsafe_allow_html=True)
        st.write(f'{translate_text}')
        st.write(f'\n')
    st.markdown('<font color="green"><b>Predicted Class:</b></font>', unsafe_allow_html=True)
    st.write(f'{prediction}')
    # Convert probability to bar
    st.write(f'\n')
    # Categories must be listed in the same index order as predict()'s label table.
    objects = ('Hardware', 'Access', 'Miscellaneous', 'HR Support', 'Purchase', 'Administrative rights', 'Storage', 'Internal Project')
    df = pd.DataFrame({'Categories': objects, 'Probability': probability[0]})
    st.bar_chart(data=df, x='Categories', y='Probability')


def main():
    """Run the IT Service Classification Streamlit app.

    The sidebar offers three input methods: a microphone recording, an
    uploaded audio file, or a typed English transcript. On *Submit*, the app
    does the following:
    - Audio is transcribed and translated with ``transcribe_and_translate``.
    - The English text (the translation, or the typed transcript) is
      classified with ``predict``.
    - The transcript, the translation, the predicted class and a
      per-category probability bar chart are displayed.
    """
    st.set_page_config(layout="wide", page_title="NLP IT Service Classification", page_icon="π€",)
    st.markdown('<b>π€ Welcome to IT Service Classification Assistant!!! π€</b>', unsafe_allow_html=True)
    st.write(f'\n')

    # Defaults, so both names exist whichever input method is selected.
    audio = None
    text = ""
    with st.sidebar:
        st.image('front_page_image.jpg', use_column_width=True)
        options = st.selectbox("Pick select an input method", ["Start a recording", "Upload an audio", "Enter a transcript"])
        if options == "Start a recording":
            audio = mic_recorder(key='my_recorder', callback=callback)
        elif options == "Upload an audio":
            audio = st.file_uploader("Please upload an audio")
        else:
            text = st.text_area("Please input the transcript (Only support English)")
        button = st.button('Submit')

    if not button:
        return

    # Refuse to run the models without input. st.file_uploader returns None
    # until a file is uploaded; mic_recorder presumably does the same before
    # the first recording. Previously this crashed on audio["bytes"].
    if options == "Enter a transcript":
        if not text.strip():
            st.warning("Please enter a transcript before submitting.")
            return
    elif audio is None:
        st.warning("Please record or upload an audio clip before submitting.")
        return

    translate_text = None
    with st.spinner(text="Loading... It may take longer for initialisation."):
        model, tokenizer = load_model()
        if options == "Start a recording":
            transcribe_text, translate_text = transcribe_and_translate(upload=audio["bytes"])
            english_text = translate_text
        elif options == "Upload an audio":
            # getvalue() must be called: the ASR pipeline needs the file's
            # bytes, not the bound method.
            transcribe_text, translate_text = transcribe_and_translate(upload=audio.getvalue())
            english_text = translate_text
        else:
            transcribe_text = english_text = text
        prediction, probability = predict(text=english_text, model=model, tokenizer=tokenizer)

    _render_results(options, transcribe_text, translate_text, prediction, probability)
# Launch the app only when this file is executed as a script
# (e.g. via `streamlit run`), not when it is imported.
if __name__ == "__main__":
    main()