Spaces:
Runtime error
Runtime error
File size: 1,540 Bytes
4227e0c e9c923a 33f47d9 4227e0c e9c923a 4227e0c e9c923a 4227e0c e9c923a 4227e0c 33f47d9 b5a8c50 e9c923a 4227e0c e9c923a 4227e0c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
import streamlit as st
import transformers
import matplotlib.pyplot as plt
# Streamlit 1.18 introduced st.cache_resource for shared, non-serializable
# objects such as models, and deprecated st.cache at the same time. Prefer the
# new decorator and fall back to the legacy one only on older Streamlit installs.
if hasattr(st, "cache_resource"):
    _cache_pipeline = st.cache_resource(show_spinner=False)
else:
    _cache_pipeline = st.cache(allow_output_mutation=True, show_spinner=False)


@_cache_pipeline
def get_pipe(model_name="joeddav/distilbert-base-uncased-go-emotions-student"):
    """Build the text-classification pipeline, cached across Streamlit reruns.

    Args:
        model_name: Hugging Face model id (or local path) to load. Defaults to
            the model id this app was originally written for.

    Returns:
        A transformers ``text-classification`` pipeline that returns a score
        for every label and truncates inputs longer than the model accepts.
    """
    model = transformers.AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    return transformers.pipeline(
        'text-classification',
        model=model,
        tokenizer=tokenizer,
        # Kept (rather than switching to top_k=None) because callers index
        # pipe(text)[0]; the newer option may nest its output differently.
        return_all_scores=True,
        truncation=True,
    )
def sort_predictions(predictions):
    """Return the label/score dicts ranked from highest to lowest score.

    A fresh list is returned and *predictions* is left untouched. Entries with
    equal scores keep their original relative order (Python's sort is stable,
    including with ``reverse=True``).
    """
    ranked = list(predictions)
    ranked.sort(key=lambda entry: entry['score'], reverse=True)
    return ranked
st.set_page_config(page_title="Emotion Prediction")
st.title("Emotion Prediction")
st.write("Type text into the text box and then press 'Predict' to get the predicted emotion.")

text = st.text_area('Enter text here:')
# The button is rendered for the user, but prediction is not gated on it: the
# previous condition `(submit and non_empty) or non_empty` reduces to just
# `non_empty`, so any non-blank text is scored on each rerun.
submit = st.button('Predict')

# Load (or fetch the cached) pipeline before checking the input.
with st.spinner("Loading model..."):
    pipe = get_pipe()

if text.strip():
    # pipe(text)[0] is the list of {'label', 'score'} dicts for this single
    # input; rank it from most to least likely emotion.
    ranked = sort_predictions(pipe(text)[0])
    labels = [p['label'] for p in ranked]
    scores = [p['score'] for p in ranked]

    fig, ax = plt.subplots()
    ax.bar(x=list(range(len(ranked))), height=scores, tick_label=labels)
    # Rotate only the category names; tick_params defaults to both axes, which
    # would also turn the y-axis score labels sideways.
    ax.tick_params(axis='x', rotation=90)
    ax.set_ylim(0, 1)
    # Keep the rotated labels from being clipped at the bottom of the figure.
    fig.tight_layout()

    st.header('Prediction:')
    st.pyplot(fig)
    # pyplot keeps every figure registered until it is closed, and Streamlit
    # re-executes this script on each interaction, so release it explicitly.
    plt.close(fig)

    st.header('Raw values:')
    st.json({p['label']: p['score'] for p in ranked})
|