# Spaces: Sleeping  (page-status text captured along with the source; not part of the program)
import streamlit as st | |
from streamlit_mic_recorder import mic_recorder | |
from transformers import pipeline | |
import torch | |
from transformers import BertTokenizer, BertForSequenceClassification | |
def callback():
    """Play back the latest microphone recording, if one is stored.

    Reads the ``my_recorder_output`` entry from Streamlit's session state and,
    when it is truthy, passes its ``'bytes'`` payload to ``st.audio``.
    Does nothing when the entry is empty.
    """
    recording = st.session_state.my_recorder_output
    if not recording:
        return
    st.audio(recording['bytes'])
@st.cache_resource
def _load_asr_pipeline():
    """Build the Whisper speech-recognition pipeline once and reuse it.

    Streamlit re-executes the script on every interaction, so constructing
    the pipeline inside ``transcribe`` reloaded the ``openai/whisper-large``
    weights on every call. ``st.cache_resource`` keeps a single shared
    instance alive across reruns.
    """
    return pipeline("automatic-speech-recognition", model="openai/whisper-large")


def transcribe(upload):
    """Transcribe recorded speech to text with Whisper.

    Args:
        upload: Audio input handed unchanged to the speech-recognition
            pipeline (``main`` passes the recorder's ``"bytes"`` value).

    Returns:
        The ``'text'`` field of the pipeline result.
    """
    asr = _load_asr_pipeline()
    result = asr(upload, generate_kwargs={'task': 'transcribe'})
    text = result['text']
    # Echo the transcript to stdout, as the original implementation did.
    print(text)
    return text
def encode(docs, tokenizer):
    """Tokenize a list of texts into model inputs.

    The tokenizer is called directly rather than through the deprecated
    ``batch_encode_plus`` method; the arguments are unchanged: special tokens
    added, padding to ``max_length``, truncation enabled, ``max_length=128``,
    and PyTorch tensors requested via ``return_tensors='pt'``.

    Args:
        docs: List of texts to encode.
        tokenizer: A callable tokenizer accepting a list of texts plus the
            keyword arguments above and returning a mapping with
            ``'input_ids'`` and ``'attention_mask'`` entries.

    Returns:
        Tuple ``(input_ids, attention_masks)`` taken from the tokenizer output.
    """
    encoded = tokenizer(
        docs,
        add_special_tokens=True,
        max_length=128,
        padding='max_length',
        return_attention_mask=True,
        truncation=True,
        return_tensors='pt',
    )
    return encoded['input_ids'], encoded['attention_mask']
@st.cache_resource
def load_model():
    """Load the fine-tuned BERT classifier and its tokenizer.

    Builds a ``bert-base-uncased`` sequence-classification model with 8
    labels, then overwrites its weights with the state dict stored at
    ``./bert-itserviceclassification`` (mapped onto the CPU).

    Cached with ``st.cache_resource``: Streamlit re-executes the script on
    every interaction, so without the cache the weights were read from disk
    on every button click. All callers now share one model/tokenizer pair.

    Returns:
        Tuple ``(model, tokenizer)``.
    """
    custom_model_path = "./bert-itserviceclassification"
    pretrained_lm = "bert-base-uncased"

    tokenizer = BertTokenizer.from_pretrained(pretrained_lm, do_lower_case=True)
    model = BertForSequenceClassification.from_pretrained(
        pretrained_lm,
        num_labels=8,
        output_attentions=False,
        output_hidden_states=False,
    )
    # NOTE(review): torch.load unpickles the checkpoint file; only load
    # checkpoints from a trusted source.
    state_dict = torch.load(custom_model_path, map_location='cpu')
    model.load_state_dict(state_dict)
    # The model is only used for inference; make evaluation mode explicit.
    model.eval()
    return model, tokenizer
def predict(text, model, tokenizer):
    """Classify a single text into an IT-service category.

    The text is encoded with ``encode``, fed through ``model`` with gradient
    tracking disabled, and the index of the largest logit is mapped to a
    category name.

    Args:
        text: The text to classify.
        model: Model called with ``input_ids`` and ``attention_mask`` whose
            output exposes ``.logits``.
        tokenizer: Tokenizer forwarded to ``encode``.

    Returns:
        The category name, or ``None`` if the predicted index is not in the
        label table.
    """
    category_names = (
        'Hardware',
        'Access',
        'Miscellaneous',
        'HR Support',
        'Purchase',
        'Administrative rights',
        'Storage',
        'Internal Project',
    )
    id_to_label = dict(enumerate(category_names))

    with torch.no_grad():
        input_ids, att_mask = encode([text], tokenizer)
        output = model(input_ids=input_ids, attention_mask=att_mask)

    class_id = output.logits.argmax().item()
    return id_to_label.get(class_id)
def main():
    """Render the app: record audio, transcribe it, and classify the text.

    A microphone recorder and a 'start classification' button are shown in
    the sidebar. When the button is pressed the recording is transcribed
    with ``transcribe`` and the transcript is classified with ``predict``;
    both results are written to the page.

    If the button is pressed before anything has been recorded, the recorder
    value is ``None``; a warning is shown instead of crashing on
    ``audio["bytes"]``.
    """
    st.set_page_config(layout="wide", page_title="IT Service NLP Classification",)

    # NOTE(review): the original indentation was lost; this assumes only the
    # recorder and the button belong to the sidebar - confirm.
    with st.sidebar:
        audio = mic_recorder(key='my_recorder', callback=callback)
        button = st.button('start classification')

    if not button:
        return

    if not audio:
        # Nothing recorded yet: subscripting None would raise TypeError.
        st.warning('Please record some audio before starting classification.')
        return

    st.write('Loading')
    text = transcribe(upload=audio["bytes"])
    st.write('Speech-to-text Result:')
    st.write(f'{text}')

    model, tokenizer = load_model()
    prediction = predict(text=text, model=model, tokenizer=tokenizer)
    st.write('Classifcation Result:')
    st.write(f'{prediction}')
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()