# Provenance (from the Hugging Face file view): author digitalWestie,
# commit 75641e2 "fix missing", file size 1.58 kB.
import matplotlib.pyplot as plt
import streamlit as st
import transformers
from transformers import PerceiverTokenizer, PerceiverForMaskedLM
@st.cache(allow_output_mutation=True, show_spinner=False)
def get_pipe():
    """Build (once, via Streamlit's cache) the text-classification pipeline.

    Loads the model and tokenizer from the "deepmind/language-perceiver"
    checkpoint and wraps them in a transformers pipeline configured with
    return_all_scores=True and truncation=True.

    NOTE(review): the model class is PerceiverForMaskedLM (a masked-LM head),
    yet it is handed to a 'text-classification' pipeline. An emotion
    classifier presumably needs a sequence-classification checkpoint —
    confirm which model is intended.
    """
    checkpoint = "deepmind/language-perceiver"
    return transformers.pipeline(
        'text-classification',
        model=PerceiverForMaskedLM.from_pretrained(checkpoint),
        tokenizer=PerceiverTokenizer.from_pretrained(checkpoint),
        return_all_scores=True,
        truncation=True,
    )
def sort_predictions(predictions):
    """Return a new list of prediction dicts ordered by 'score', highest first.

    The input is not modified. Entries with equal scores keep their original
    relative order (Python's sort is stable, including with reverse=True).
    """
    ranked = list(predictions)
    ranked.sort(key=lambda entry: entry['score'], reverse=True)
    return ranked
# Page setup: set_page_config is the first Streamlit call that renders anything.
st.set_page_config(page_title="Emotion Prediction")
st.title("Emotion Prediction")
st.write("Type text into the text box and then press 'Predict' to get the predicted emotion.")

default_text = "I really love using HuggingFace Spaces!"
text = st.text_area('Enter text here:', value=default_text)
# The button is rendered so the user can trigger a Streamlit rerun. Its value
# is not needed below: the original condition
# `(submit and non_blank) or non_blank` reduces to `non_blank`. Predictions
# therefore appear for any non-blank text, including the default on first load.
submit = st.button('Predict')

with st.spinner("Loading model..."):
    pipe = get_pipe()

if text.strip():
    # The code takes element [0] of the pipeline output and expects a list of
    # {'label': ..., 'score': ...} dicts (return_all_scores=True in get_pipe).
    prediction = sort_predictions(pipe(text)[0])

    fig, ax = plt.subplots()
    ax.bar(x=list(range(len(prediction))),
           height=[p['score'] for p in prediction],
           tick_label=[p['label'] for p in prediction])
    ax.tick_params(rotation=90)
    ax.set_ylim(0, 1)
    st.header('Prediction:')
    st.pyplot(fig)
    # st.pyplot has already rendered the figure. Close it so pyplot does not
    # keep one more open figure alive on every rerun of this script.
    plt.close(fig)

    st.header('Raw values:')
    st.json({p['label']: p['score'] for p in prediction})