Ryan Kim commited on
Commit
7883937
·
1 Parent(s): 40deb01

trying to make programmatic way to store sentiment analysis label dictionaries for each type

Browse files
Files changed (1) hide show
  1. src/main.py +16 -6
src/main.py CHANGED
@@ -11,12 +11,27 @@ st.markdown("")
11
  def load_model(model_name):
12
  return pipeline(model=model_name, task="sentiment-analysis")
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  if "model" not in st.session_state:
15
  st.session_state.model_name = "cardiffnlp/twitter-roberta-base-sentiment"
16
  st.session_state.model = load_model("cardiffnlp/twitter-roberta-base-sentiment")
 
17
 
18
  def model_change():
19
  st.session_state.model = load_model(st.session_state.model_name)
 
20
 
21
  model_option = st.selectbox(
22
  "What sentiment analysis model do you want to use?",
@@ -44,12 +59,7 @@ if submit:
44
  label = result[0]['label']
45
  score = result[0]['score']
46
 
47
- if label == "LABEL_0":
48
- label = "Negative"
49
- elif label == "LABEL_2":
50
- label = "Positive"
51
- elif label == "LABEL_1":
52
- label = "Neutral"
53
 
54
  st.markdown("#### Result:")
55
  st.markdown("**{}**: {}".format(label,score))
 
11
  def load_model(model_name):
12
  return pipeline(model=model_name, task="sentiment-analysis")
13
 
14
+ @st.cache(allow_output_mutation=True)
15
+ def label_dictionary(model_name):
16
+ if model_name == "cardiffnlp/twitter-roberta-base-sentiment":
17
+ def twitter_roberta(label):
18
+ if label == "LABEL_0":
19
+ return "Negative"
20
+ elif label == "LABEL_2":
21
+ return "Positive"
22
+ else:
23
+ return "Neutral"
24
+ return twitter_roberta
25
+ return lambda x: x
26
+
27
  if "model" not in st.session_state:
28
  st.session_state.model_name = "cardiffnlp/twitter-roberta-base-sentiment"
29
  st.session_state.model = load_model("cardiffnlp/twitter-roberta-base-sentiment")
30
+ st.session_state.label_parser = label_dictionary("cardiffnlp/twitter-roberta-base-sentiment")
31
 
32
  def model_change():
33
  st.session_state.model = load_model(st.session_state.model_name)
34
+ st.session_state.label_parser = label_dictionary(st.session_state.model_name)
35
 
36
  model_option = st.selectbox(
37
  "What sentiment analysis model do you want to use?",
 
59
  label = result[0]['label']
60
  score = result[0]['score']
61
 
62
+ label = st.session_state.label_parser(label)
 
 
 
 
 
63
 
64
  st.markdown("#### Result:")
65
  st.markdown("**{}**: {}".format(label,score))