Sasidhar commited on
Commit
5cce1aa
·
1 Parent(s): 4202285

Create new file

Browse files
Files changed (1) hide show
  1. app.py +81 -0
app.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import pipeline
3
+ import plotly.express as px
4
+ import pandas as pd
5
+
6
+
7
+ st.set_page_config(layout="wide")
8
+
9
+ @st.cache(allow_output_mutation = True)
10
+ def get_classifier_model():
11
+ return pipeline("zero-shot-classification", model="models/bart-large-mnli")
12
+ #return pipeline("zero-shot-classification",model="sentence-transformers/paraphrase-MiniLM-L6-v2")
13
+
14
+
15
+ #st.sidebar.image("Suncorp-Bank-logo.png",width=255)
16
+
17
+ st.image("Suncorp-Bank-logo.png",width=255)
18
+
19
+ st.title("Detecting Barriers from Conversations")
20
+ st.markdown("***")
21
+
22
+ text = st.text_area(label="Enter text to classify")
23
+ st.markdown("***")
24
+
25
+ col1, col2, col3 = st.columns((1,1,1))
26
+ col1.header("Select Sentiments")
27
+ sentiments = col1.multiselect("",["Happy","Sad","Anxious","Depressed","Empathetic"],["Happy","Sad","Anxious","Depressed","Empathetic"])
28
+ col2.header("Select Entities")
29
+ entities = col2.multiselect("",["Employee","Doctor","Family","Friends"],
30
+ ["Employee","Doctor","Family","Friends"])
31
+
32
+
33
+ col3.header("Select Reasons")
34
+
35
+ reasons = col3.multiselect("",["Bullying","Alchohol","Abuse","Domestic_Violence",'Chronic_Pain','Driving','Hobbies','Treatment'],
36
+ ["Bullying","Alchohol","Abuse","Domestic_Violence",'Chronic_Pain','Driving','Hobbies','Treatment'])
37
+
38
+ is_multi_class = st.checkbox("Can have more than one classes",value=True)
39
+
40
+ st.markdown("***")
41
+
42
+ classify_button_clicked = st.button("Classify")
43
+
44
+ def get_classification(candidate_labels):
45
+ classification_output = classifier(sequence_to_classify, candidate_labels, multi_class=is_multi_class)
46
+ data = {'Class': classification_output['labels'], 'Scores': classification_output['scores']}
47
+ df = pd.DataFrame(data)
48
+ df = df.sort_values(by='Scores', ascending=False)
49
+ fig = px.bar(df, x='Scores', y='Class', orientation='h', width=800, height=800)
50
+ fig.update_layout(
51
+ yaxis=dict(
52
+ autorange='reversed'
53
+ )
54
+ )
55
+ return fig
56
+
57
+ if classify_button_clicked:
58
+ if text:
59
+ st.markdown("***")
60
+ with st.spinner(" Please wait while the text is being classified.."):
61
+ classifier = get_classifier_model()
62
+ sequence_to_classify = text
63
+ # candidate_labels = sentiments + entities + reasons
64
+
65
+ if sentiments:
66
+ #print(classification_output)
67
+ fig = get_classification(sentiments)
68
+ # col5, col6= st.columns((1, 1))
69
+ col1.write(fig)
70
+
71
+ if entities:
72
+ #print(classification_output)
73
+ fig = get_classification(entities)
74
+ # col7, col8= st.columns((1, 1))
75
+ col2.write(fig)
76
+
77
+ if reasons:
78
+ #print(classification_output)
79
+ fig = get_classification(reasons)
80
+ # col7, col8= st.columns((1, 1))
81
+ col3.write(fig)