File size: 5,447 Bytes
7e0431e
 
 
 
e5a5157
 
 
7e0431e
 
 
 
 
 
e5a5157
 
 
 
 
7e0431e
e5a5157
7e0431e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3eea6be
7e0431e
 
 
e5a5157
 
 
 
 
 
 
7e0431e
 
 
 
 
 
 
 
 
e5a5157
 
 
 
 
 
 
 
 
7e0431e
e5a5157
 
7e0431e
 
 
 
e5a5157
 
 
7e0431e
 
e5a5157
 
 
 
 
 
 
 
 
7e0431e
 
e5a5157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e0431e
 
e5a5157
 
 
 
 
 
7e0431e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import streamlit as st
from streamlit_mic_recorder import mic_recorder
from transformers import pipeline
import torch
from transformers import BertTokenizer, BertForSequenceClassification, AutoModelForSequenceClassification, AutoTokenizer
import numpy as np
import pandas as pd

def callback():
    """Play back the most recent microphone recording, if there is one.

    Registered as the ``callback`` of ``mic_recorder(key='my_recorder', ...)``
    in ``main``. Reads ``st.session_state.my_recorder_output`` and, when it is
    truthy, renders its ``'bytes'`` entry with ``st.audio``.
    """
    recording = st.session_state.my_recorder_output
    if not recording:
        return
    st.audio(recording['bytes'])

@st.cache_resource
def _load_asr_pipeline():
    """Build the Whisper (openai/whisper-medium) speech-recognition pipeline once.

    ``st.cache_resource`` keeps the returned pipeline across Streamlit reruns
    and sessions. Without it, the large model would be reloaded on every
    Submit, because Streamlit re-executes the whole script each time.
    """
    return pipeline("automatic-speech-recognition", model="openai/whisper-medium")


def transcribe_and_translate(upload):
    """Transcribe an audio clip and translate it with Whisper.

    Args:
        upload: Audio input accepted by the Hugging Face ASR pipeline, e.g. the
            raw bytes of a recording or uploaded file.

    Returns:
        tuple: ``(transcript, translation)``.
        - ``transcript`` is the text from Whisper's ``transcribe`` task.
        - ``translation`` is the text from its ``translate`` task (Whisper
          translates into English).
    """
    pipe = _load_asr_pipeline()
    # The same cached pipeline serves both tasks; only the generation task differs.
    transcribe_result = pipe(upload, generate_kwargs={'task': 'transcribe'})
    translate_result = pipe(upload, generate_kwargs={'task': 'translate'})
    return transcribe_result['text'], translate_result['text']

def encode_depracated(docs, tokenizer):
    """Batch-encode texts into fixed-length PyTorch tensors.

    Not referenced by the visible code; the name suggests it was superseded.

    Args:
        docs: List of texts to encode.
        tokenizer: Tokenizer exposing ``batch_encode_plus``.

    Returns:
        tuple: ``(input_ids, attention_mask)`` from the tokenizer output.
        Every sequence is padded or truncated to 128 tokens, with special
        tokens added.
    """
    encoding = tokenizer.batch_encode_plus(
        docs,
        add_special_tokens=True,
        max_length=128,
        padding='max_length',
        return_attention_mask=True,
        truncation=True,
        return_tensors='pt',
    )
    return encoding['input_ids'], encoding['attention_mask']


def load_model():
    """Load bert-base-uncased and apply a locally stored fine-tuned state dict.

    NOTE(review): a second ``load_model`` is defined later in this module and
    replaces this one when the module is executed. This definition is
    therefore never called under this name.

    Returns:
        tuple: ``(model, tokenizer)``.
        - ``model`` is a ``BertForSequenceClassification`` with 8 labels,
          carrying the weights from ``./bert-itserviceclassification``.
        - ``tokenizer`` is the lower-casing bert-base-uncased tokenizer.
    """
    checkpoint_file = "./bert-itserviceclassification"
    base_lm = "bert-base-uncased"
    tokenizer = BertTokenizer.from_pretrained(base_lm, do_lower_case=True)
    model = BertForSequenceClassification.from_pretrained(
        base_lm,
        num_labels=8,
        output_attentions=False,
        output_hidden_states=False,
    )
    # torch.load unpickles the file: only point this at a trusted checkpoint.
    state_dict = torch.load(checkpoint_file, map_location='cpu')
    model.load_state_dict(state_dict)
    return model, tokenizer


@st.cache_resource
def load_model():
    """Load the fine-tuned IT-service BERT classifier and its tokenizer.

    Both come from the ``kkngan/bert-base-uncased-it-service-classification``
    checkpoint, with 8 output labels. The result is cached with
    ``st.cache_resource``. Because Streamlit re-runs the script on every
    Submit, the cache avoids reloading the model each time.

    The cached model is shared across sessions, so callers must use it for
    read-only inference only.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    PRETRAINED_LM = "kkngan/bert-base-uncased-it-service-classification"
    model = AutoModelForSequenceClassification.from_pretrained(PRETRAINED_LM, num_labels=8)
    tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_LM)
    return model, tokenizer


def predict(text, model, tokenizer):
    """Classify an IT-service request text into one of 8 categories.

    Args:
        text: The request text to classify (a single string).
        model: Sequence-classification model whose ``logits`` are presumably
            8 columns in the order of ``lookup_key`` below (``load_model``
            uses ``num_labels=8``).
        tokenizer: Tokenizer matching ``model``, called with ``return_tensors='pt'``.

    Returns:
        tuple: ``(predicted_label, probability)``.
        - ``predicted_label`` is the category name of the top-scoring class
          for the input, or ``None`` if that index is not in ``lookup_key``.
        - ``probability`` is a NumPy array of softmax scores with one row
          per input, i.e. shape ``(1, num_labels)`` for a single string.
    """
    lookup_key = {0: 'Hardware',
                  1: 'Access',
                  2: 'Miscellaneous',
                  3: 'HR Support',
                  4: 'Purchase',
                  5: 'Administrative rights',
                  6: 'Storage',
                  7: 'Internal Project'}
    inputs = tokenizer(text,
                       padding=True,
                       truncation=True,
                       return_tensors='pt')
    # Inference only: disable autograd so no computation graph is built or kept.
    with torch.no_grad():
        logits = model(**inputs).logits
    # argmax over the classes of the (single) input row, not the flattened tensor.
    predicted_class_id = logits[0].argmax().item()
    predicted_label = lookup_key.get(predicted_class_id)
    probability = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()
    return predicted_label, probability


def main():
    """Render the Streamlit UI and classify the submitted IT-service request.

    The sidebar offers three input methods:
    - record audio with ``mic_recorder``;
    - upload an audio file;
    - type a transcript.

    On Submit, audio inputs are transcribed and translated by
    ``transcribe_and_translate``, and the translation is classified. A typed
    transcript is classified as-is. The app then shows the transcript, the
    translation (audio inputs only), the predicted class, and a bar chart of
    per-class probabilities.
    """
    st.set_page_config(layout="wide", page_title="NLP IT Service Classification", page_icon="🤖",)
    st.markdown('<b>🤖 Welcome to IT Service Classification Assistant!!! 🤖</b>', unsafe_allow_html=True)
    st.write(f'\n')

    # Defaults so the validation below works whichever input method is picked.
    audio = None
    text = ""
    with st.sidebar:
        st.image('front_page_image.jpg' , use_column_width=True)
        options = st.selectbox("Pick select an input method", ["Start a recording", "Upload an audio", "Enter a transcript"])
        if options == "Start a recording":
            audio = mic_recorder(key='my_recorder', callback=callback)
        elif options == "Upload an audio":
            audio = st.file_uploader("Please upload an audio")
        else:
            text = st.text_area("Please input the transcript (Only support English)")
        button = st.button('Submit')

    if not button:
        return

    # Validate input before loading any model. Without this, a missing
    # recording/upload crashes on audio["bytes"] / audio.getvalue().
    if options == "Enter a transcript":
        if not text.strip():
            st.warning("Please input a transcript before submitting.")
            return
    elif audio is None:
        st.warning("Please record or upload an audio before submitting.")
        return

    with st.spinner(text="Loading... It may take longer for initialisation."):
        model, tokenizer = load_model()
        if options == "Start a recording":
            transcribe_text, translate_text = transcribe_and_translate(upload=audio["bytes"])
            classify_text = translate_text
        elif options == "Upload an audio":
            # getvalue must be *called*: the bound method itself is not audio data.
            transcribe_text, translate_text = transcribe_and_translate(upload=audio.getvalue())
            classify_text = translate_text
        else:
            transcribe_text = classify_text = text
        prediction, probability = predict(text=classify_text, model=model, tokenizer=tokenizer)

    st.markdown('<font color="blue"><b>Transcript:</b></font>', unsafe_allow_html=True)
    st.write(f'{transcribe_text}')
    st.write(f'\n')
    if options != "Enter a transcript":
        st.markdown('<font color="red"><b>Translation:</b></font>', unsafe_allow_html=True)
        st.write(f'{translate_text}')
        st.write(f'\n')
    st.markdown('<font color="green"><b>Predicted Class:</b></font>', unsafe_allow_html=True)
    st.write(f'{prediction}')

    # Convert probability to bar chart. The category order must match
    # predict's lookup_key ids 0-7.
    st.write(f'\n')
    objects = ('Hardware', 'Access', 'Miscellaneous', 'HR Support', 'Purchase', 'Administrative rights', 'Storage', 'Internal Project')
    df = pd.DataFrame({'Categories': objects, 'Probability': probability[0]})
    st.bar_chart(data=df, x='Categories', y='Probability')

if __name__ == "__main__":
    # Script entry point: build and run the Streamlit app.
    main()