Spaces:
Runtime error
Runtime error
Ryan Kim
commited on
Commit
·
7883937
1
Parent(s):
40deb01
trying to make programmatic way to store sentiment analysis label dictionaries for each type
Browse files- src/main.py +16 -6
src/main.py
CHANGED
@@ -11,12 +11,27 @@ st.markdown("")
|
|
11 |
def load_model(model_name):
|
12 |
return pipeline(model=model_name, task="sentiment-analysis")
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
if "model" not in st.session_state:
|
15 |
st.session_state.model_name = "cardiffnlp/twitter-roberta-base-sentiment"
|
16 |
st.session_state.model = load_model("cardiffnlp/twitter-roberta-base-sentiment")
|
|
|
17 |
|
18 |
def model_change():
|
19 |
st.session_state.model = load_model(st.session_state.model_name)
|
|
|
20 |
|
21 |
model_option = st.selectbox(
|
22 |
"What sentiment analysis model do you want to use?",
|
@@ -44,12 +59,7 @@ if submit:
|
|
44 |
label = result[0]['label']
|
45 |
score = result[0]['score']
|
46 |
|
47 |
-
|
48 |
-
label = "Negative"
|
49 |
-
elif label == "LABEL_2":
|
50 |
-
label = "Positive"
|
51 |
-
elif label == "LABEL_1":
|
52 |
-
label = "Neutral"
|
53 |
|
54 |
st.markdown("#### Result:")
|
55 |
st.markdown("**{}**: {}".format(label,score))
|
|
|
11 |
def load_model(model_name):
|
12 |
return pipeline(model=model_name, task="sentiment-analysis")
|
13 |
|
14 |
+
@st.cache(allow_output_mutation=True)
|
15 |
+
def label_dictionary(model_name):
|
16 |
+
if model_name == "cardiffnlp/twitter-roberta-base-sentiment":
|
17 |
+
def twitter_roberta(label):
|
18 |
+
if label == "LABEL_0":
|
19 |
+
return "Negative"
|
20 |
+
elif label == "LABEL_2":
|
21 |
+
return "Positive"
|
22 |
+
else:
|
23 |
+
return "Neutral"
|
24 |
+
return twitter_roberta
|
25 |
+
return lambda x: x
|
26 |
+
|
27 |
if "model" not in st.session_state:
|
28 |
st.session_state.model_name = "cardiffnlp/twitter-roberta-base-sentiment"
|
29 |
st.session_state.model = load_model("cardiffnlp/twitter-roberta-base-sentiment")
|
30 |
+
st.session_state.label_parser = label_dictionary("cardiffnlp/twitter-roberta-base-sentiment")
|
31 |
|
32 |
def model_change():
|
33 |
st.session_state.model = load_model(st.session_state.model_name)
|
34 |
+
st.session_state.label_parser = label_dictionary(st.session_state.model_name)
|
35 |
|
36 |
model_option = st.selectbox(
|
37 |
"What sentiment analysis model do you want to use?",
|
|
|
59 |
label = result[0]['label']
|
60 |
score = result[0]['score']
|
61 |
|
62 |
+
label = st.session_state.label_parser(label)
|
|
|
|
|
|
|
|
|
|
|
63 |
|
64 |
st.markdown("#### Result:")
|
65 |
st.markdown("**{}**: {}".format(label,score))
|