File size: 3,418 Bytes
5cce1aa
 
 
 
 
 
 
 
 
 
c551049
 
5cce1aa
 
 
 
c93a5e6
5cce1aa
4aa24cc
5cce1aa
 
4aa24cc
c6abafe
5cce1aa
 
 
4ec43da
5cce1aa
4aa24cc
0b86248
115c5c8
0b86248
0018425
4ec43da
 
 
 
bec8546
 
 
0018425
4ec43da
 
 
5cce1aa
 
 
 
4aa24cc
 
0018425
4ec43da
 
 
5cce1aa
 
 
 
 
 
 
 
 
 
 
 
a07631a
5cce1aa
 
 
 
 
 
 
c6abafe
 
5cce1aa
 
 
 
 
 
 
 
 
 
 
 
cedf170
 
5cce1aa
 
 
 
 
 
4aa24cc
5cce1aa
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import streamlit as st
from transformers import pipeline
import plotly.express as px
import pandas as pd


st.set_page_config(layout="wide")

# `st.cache` is deprecated; `st.cache_resource` is Streamlit's cache for shared,
# unhashable objects such as ML models. Fall back to the legacy decorator (with
# allow_output_mutation so the pipeline object is not hashed/copied) when running
# on a Streamlit release that predates `st.cache_resource`.
_cache_model = (
    st.cache_resource
    if hasattr(st, "cache_resource")
    else st.cache(allow_output_mutation=True)
)


@_cache_model
def get_classifier_model():
    """Return the zero-shot classification pipeline, cached across reruns.

    The pipeline is built with the ``typeform/distilbert-base-uncased-mnli``
    checkpoint. Other checkpoints that were tried here and can be swapped in:
    ``facebook/bart-large-mnli`` and
    ``sentence-transformers/paraphrase-MiniLM-L6-v2``.

    Returns:
        A transformers ``zero-shot-classification`` pipeline.
    """
    return pipeline("zero-shot-classification", model="typeform/distilbert-base-uncased-mnli")

#st.sidebar.image("Suncorp-Bank-logo.png",width=255)

#st.image("Suncorp-Bank-logo.png",width=255)

def _split_labels(raw):
    """Split a comma-separated string into trimmed, non-empty labels.

    For example ``" Angry, Excited ,,"`` becomes ``["Angry", "Excited"]``:
    surrounding whitespace is stripped and empty entries are dropped, so they
    never reach the classifier as bogus candidate labels.
    """
    return [label.strip() for label in raw.split(",") if label.strip()]


def _with_extra_labels(selected, raw_extra):
    """Return *selected* followed by any new labels typed into *raw_extra*.

    Order is preserved and a label already present is not repeated, so a label
    that is both selected and typed is scored (and plotted) only once.
    """
    return list(dict.fromkeys([*selected, *_split_labels(raw_extra)]))


st.title("Review Analyzer")
st.markdown("***")

text = st.text_area(label="Paste/Type the review here..")

st.markdown("***")

col1, col2, col3 = st.columns((1, 1, 1))

# --- Sentiment labels (column 1) ---------------------------------------------
col1.header("Select Sentiments")
sentiments = col1.multiselect("", ["Happy", "Sad", "Neutral"], ["Happy", "Sad", "Neutral"])
# Spacer lines; presumably to align this column's text box with the other
# columns (NOTE(review): confirm intended layout).
col1.markdown(" \n")
col1.markdown(" \n")
additional_sentiments = col1.text_input("Enter comma separated sentiments.")
sentiments = _with_extra_labels(sentiments, additional_sentiments)

# --- Topic labels (column 2) -------------------------------------------------
col2.header("Select Topics")
entities = col2.multiselect("", ["Bank Account", "Credit Card", "Home Loan", "Motor Loan"],
                            ["Bank Account", "Credit Card", "Home Loan", "Motor Loan"])
additional_entities = col2.text_input("Enter comma separated entities.")
entities = _with_extra_labels(entities, additional_entities)

# --- Reason labels (column 3) ------------------------------------------------
col3.header("Select Reasons")
reasons = col3.multiselect("", ["Poor Service", "No Empathy", "Abuse"],
                           ["Poor Service", "No Empathy", "Abuse"])
additional_reasons = col3.text_input("Enter comma separated reasons.")
reasons = _with_extra_labels(reasons, additional_reasons)

# Forwarded to the zero-shot pipeline: when True, each label is scored
# independently instead of the scores competing with each other.
is_multi_class = st.checkbox("Can have more than one classes", value=True)

st.markdown("***")

classify_button_clicked = st.button("Classify")

def get_classification(candidate_labels):
    """Score the current review against *candidate_labels* and chart the result.

    This reads three module-level names rather than taking them as arguments:
    ``classifier`` (the zero-shot pipeline), ``sequence_to_classify`` (the review
    text) and ``is_multi_class`` (the checkbox value). The "Classify" handler
    sets the first two before calling this function.

    Args:
        candidate_labels: List of label strings passed to the zero-shot pipeline.

    Returns:
        A horizontal Plotly bar chart with one bar per label, highest score on top.
    """
    # transformers renamed `multi_class` to `multi_label`; the old keyword only
    # survives as a deprecated alias that logs a warning.
    classification_output = classifier(
        sequence_to_classify, candidate_labels, multi_label=is_multi_class
    )
    df = pd.DataFrame(
        {"Class": classification_output["labels"], "Scores": classification_output["scores"]}
    ).sort_values(by="Scores", ascending=False)
    fig = px.bar(df, x="Scores", y="Class", orientation="h", width=400, height=500)
    # Categorical y-axes list the first row at the bottom; reversing puts the
    # highest-scoring label at the top of the chart.
    fig.update_layout(yaxis=dict(autorange="reversed"))
    return fig



# Run classification when the button is pressed. `classifier` and
# `sequence_to_classify` are assigned at module level on purpose:
# get_classification() reads them as globals.
if classify_button_clicked:
    # Whitespace-only input has nothing to classify, so treat it like empty input.
    if text.strip():
        st.markdown("***")
        with st.spinner("  Please wait while the text is being classified.."):
            classifier = get_classifier_model()
            sequence_to_classify = text

            # (target column, candidate labels, add a spacer line before the chart)
            # The spacer is only in column 1; presumably it keeps the charts
            # aligned with the extra spacing above column 1's text box
            # (NOTE(review): confirm intended layout).
            panels = (
                (col1, sentiments, True),
                (col2, entities, False),
                (col3, reasons, False),
            )
            for column, labels, add_spacer in panels:
                if not labels:
                    continue  # nothing selected for this category
                fig = get_classification(labels)
                if add_spacer:
                    column.markdown(" \n")
                column.write(fig)
    else:
        st.warning("Please paste or type a review before clicking Classify.")