bentrevett commited on
Commit
4227e0c
·
1 Parent(s): 711edf8

add initial app

Browse files
Files changed (1) hide show
  1. app.py +43 -0
app.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import transformers
3
+ import matplotlib.pyplot as plt
4
+
5
+ @st.cache(allow_output_mutation=True)
6
+ def get_pipe():
7
+ model_name = "joeddav/distilbert-base-uncased-go-emotions-student"
8
+ model = transformers.AutoModelForSequenceClassification.from_pretrained(model_name)
9
+ tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
10
+ pipe = transformers.pipeline('text-classification', model=model, tokenizer=tokenizer, return_all_scores=True, truncation=True)
11
+ return pipe
12
+
13
+ def sort_predictions(predictions):
14
+ return sorted(predictions, key=lambda x: x['score'], reverse=True)
15
+
16
+ st.set_page_config(page_title="Emotion Prediction")
17
+ st.title("Emotion Prediction")
18
+ st.write("Type text into the text box and then press 'Predict' to get the predicted emotion.")
19
+
20
+ with st.spinner("Loading model..."):
21
+ pipe = get_pipe()
22
+
23
+ text = st.text_area('Enter text here:')
24
+ submit = st.button('Predict')
25
+
26
+ if submit:
27
+
28
+ prediction = pipe(text)[0]
29
+ prediction = sort_predictions(prediction)
30
+
31
+ fig, ax = plt.subplots()
32
+ ax.bar(x=[i for i, _ in enumerate(prediction)],
33
+ height=[p['score'] for p in prediction],
34
+ tick_label=[p['label'] for p in prediction])
35
+ ax.tick_params(rotation=90)
36
+ ax.set_ylim(0, 1)
37
+
38
+ st.header('Prediction:')
39
+ st.pyplot(fig)
40
+
41
+ prediction = dict([(p['label'], p['score']) for p in prediction])
42
+ st.header('Raw values:')
43
+ st.json(prediction)