Spaces:
Runtime error
Runtime error
import streamlit as st | |
import transformers | |
import matplotlib.pyplot as plt | |
# Streamlit re-executes this whole script on every widget interaction, so an
# uncached get_pipe() would rebuild the model and tokenizer on every rerun.
# Prefer st.cache_resource (Streamlit >= 1.18); on older releases fall back to
# the legacy st.cache, with allow_output_mutation=True so it does not try to
# hash the large pipeline object to detect mutation.
if hasattr(st, "cache_resource"):
    _cache_pipeline = st.cache_resource
else:
    _cache_pipeline = st.cache(allow_output_mutation=True)


@_cache_pipeline
def get_pipe(model_name="joeddav/distilbert-base-uncased-go-emotions-student"):
    """Build a text-classification pipeline that reports a score for every label.

    The result is cached across Streamlit reruns (keyed on ``model_name``), so
    the model and tokenizer are loaded once rather than on every interaction.

    Args:
        model_name: Hugging Face model id (or local path) to load. Defaults to
            the model id this app has always used.

    Returns:
        A ``transformers`` text-classification pipeline built with
        ``return_all_scores=True`` and ``truncation=True``.
    """
    model = transformers.AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    # return_all_scores=True yields scores for all labels, in the nested format
    # the caller unwraps with pipe(text)[0]. Newer transformers releases warn
    # that it is deprecated in favour of top_k=None; verify the caller's
    # indexing still matches the output shape before switching.
    # truncation=True is forwarded to the tokenizer so over-long inputs are
    # truncated to the model's maximum length.
    return transformers.pipeline(
        'text-classification',
        model=model,
        tokenizer=tokenizer,
        return_all_scores=True,
        truncation=True,
    )
def sort_predictions(predictions):
    """Return a new list of prediction dicts ordered by descending 'score'.

    The input is left untouched. Entries with equal scores keep their original
    relative order, because Python's sort stays stable even with reverse=True.
    """
    ranked = list(predictions)
    ranked.sort(key=lambda entry: entry['score'], reverse=True)
    return ranked
# Page chrome: browser-tab title, heading and usage hint. Streamlit lays out
# elements in call order, and set_page_config must come before any other st
# call, so this sequence is deliberate.
PAGE_TITLE = "Emotion Prediction"
st.set_page_config(page_title=PAGE_TITLE)
st.title(PAGE_TITLE)
st.write("Type text into the text box and then press 'Predict' to get the predicted emotion.")

# Input widgets: `text` holds the current text-area contents (pre-filled with
# an example sentence); `submit` reflects whether the button was clicked on
# this run.
default_text = "I really love using HuggingFace Spaces!"
text = st.text_area(label='Enter text here:', value=default_text)
submit = st.button(label='Predict')
# Obtain the classification pipeline, showing a spinner while get_pipe() runs.
with st.spinner("Loading model..."):
    pipe = get_pipe()

# The former condition `(submit and non_empty) or non_empty` reduces to
# `non_empty`: the prediction renders whenever the text box holds
# non-whitespace text (including the default text on first load). Clicking
# 'Predict' matters only because it triggers a script rerun.
if text.strip():
    # pipe(text)[0] selects the per-label score dicts for this single input;
    # order them from most to least likely.
    scores = sort_predictions(pipe(text)[0])

    fig, ax = plt.subplots()
    ax.bar(
        x=range(len(scores)),
        height=[p['score'] for p in scores],
        tick_label=[p['label'] for p in scores],
    )
    # Rotate only the x-axis category labels; rotating both axes would also
    # turn the y-axis score labels sideways.
    ax.tick_params(axis='x', rotation=90)
    ax.set_ylim(0, 1)

    st.header('Prediction:')
    st.pyplot(fig)
    # Each rerun creates a new figure and pyplot keeps a reference to every
    # open figure, so close it once Streamlit has rendered it.
    plt.close(fig)

    st.header('Raw values:')
    st.json({p['label']: p['score'] for p in scores})