Spaces:
Runtime error
Runtime error
import streamlit as st | |
from transformers import pipeline | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
# We'll be using Torch this time around | |
import torch | |
import torch.nn.functional as F | |
def label_dictionary(model_name):
    """Return a function that maps a model's raw output label to a display label.

    Only ``cardiffnlp/twitter-roberta-base-sentiment`` needs translation: its
    generic ``LABEL_0`` / ``LABEL_1`` / ``LABEL_2`` ids are mapped to
    Negative / Neutral / Positive. Labels from every other model are returned
    unchanged.

    Args:
        model_name: Hugging Face model id whose labels will be parsed.

    Returns:
        A one-argument callable ``label -> str``.
    """
    if model_name == "cardiffnlp/twitter-roberta-base-sentiment":
        mapping = {
            "LABEL_0": "Negative",
            "LABEL_1": "Neutral",
            "LABEL_2": "Positive",
        }

        def twitter_roberta(label):
            # Unknown labels pass through instead of being silently reported
            # as "Neutral" (e.g. if the model config ever emits readable names).
            return mapping.get(label, label)

        return twitter_roberta
    return lambda x: x
def load_model(model_name):
    """Load everything needed to classify text with ``model_name``.

    Args:
        model_name: Hugging Face model id to load.

    Returns:
        A 4-tuple ``(model, tokenizer, classifier, parser)``: the sequence
        classification model, its tokenizer, a ``sentiment-analysis`` pipeline
        built from those two, and the label parser from ``label_dictionary``.
    """
    seq_model = AutoModelForSequenceClassification.from_pretrained(model_name)
    seq_tokenizer = AutoTokenizer.from_pretrained(model_name)
    return (
        seq_model,
        seq_tokenizer,
        pipeline(task="sentiment-analysis", model=seq_model, tokenizer=seq_tokenizer),
        label_dictionary(model_name),
    )
# Per-session state, populated on the first script run of a session:
#   model_name   - id of the selected model (also the selectbox widget key)
#   model        - the loaded sequence-classification model
#   tokenizer    - its tokenizer
#   classifier   - the sentiment-analysis pipeline built from the two above
#   label_parser - maps the pipeline's raw labels to display labels
# The presence of "model" marks the session as already initialised.
if "model" not in st.session_state:
    st.session_state.model_name = "cardiffnlp/twitter-roberta-base-sentiment"
    (
        st.session_state.model,
        st.session_state.tokenizer,
        st.session_state.classifier,
        st.session_state.label_parser,
    ) = load_model(st.session_state.model_name)
def model_change():
    """Reload the model objects for the model currently selected in session state.

    Used as the selectbox's ``on_change`` callback. It reads the model id from
    ``st.session_state.model_name`` and overwrites ``model``, ``tokenizer``,
    ``classifier`` and ``label_parser`` in session state.
    """
    (
        st.session_state.model,
        st.session_state.tokenizer,
        st.session_state.classifier,
        st.session_state.label_parser,
    ) = load_model(st.session_state.model_name)
# Page title and author subtitle.
st.title("CSGY-6613 Sentiment Analysis")
st.markdown("### Ryan Kim (rk2546)")
st.markdown("")  # empty markdown element, presumably used as vertical spacing

# Models offered in the picker; the first entry is the initially selected one.
_MODEL_CHOICES = (
    "cardiffnlp/twitter-roberta-base-sentiment",
    "finiteautomata/beto-sentiment-analysis",
    "bhadresh-savani/distilbert-base-uncased-emotion",
    "siebert/sentiment-roberta-large-english",
)

# The selection is stored under the "model_name" session-state key; changing
# it calls model_change() to reload the model objects.
model_option = st.selectbox(
    "What sentiment analysis model do you want to use?",
    _MODEL_CHOICES,
    key="model_name",
    on_change=model_change,
)
# Example text shown as the textarea placeholder and classified when the
# user submits without typing anything.
placeholder = "@AmericanAir just landed - 3hours Late Flight - and now we need to wait TWENTY MORE MINUTES for a gate! I have patience but none for incompetence."

form = st.form(key='sentiment-analysis-form')
text_input = form.text_area("Enter some text for sentiment analysis! If you just want to test it out without entering anything, just press the \"Submit\" button and the model will look at the placeholder.", placeholder=placeholder)
submit = form.form_submit_button('Submit')

if submit:
    # Empty, whitespace-only (or None) input falls back to the placeholder.
    to_eval = (text_input or "").strip() or placeholder

    st.write("You entered:")
    # Quote every line: prefixing only the first line lets a blank line or a
    # line starting with '#', '-', ... escape the blockquote and render as
    # ordinary Markdown.
    st.markdown("\n".join("> " + line for line in to_eval.splitlines()))
    st.write("Using the NLP model:")
    st.markdown("> {}".format(st.session_state.model_name))

    # truncation=True cuts over-long inputs to the tokenizer's maximum length
    # instead of letting the model's forward pass fail on them.
    result = st.session_state.classifier(to_eval, truncation=True)
    label = st.session_state.label_parser(result[0]['label'])
    score = result[0]['score']

    st.markdown("#### Result:")
    st.markdown("**{}**: {}".format(label, score))
    st.write("")
    st.write("")