echung682 commited on
Commit
c280082
·
verified ·
1 Parent(s): 89ee5f0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +150 -0
app.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import (
3
+ AutoModelForSequenceClassification,
4
+ pipeline
5
+ )
6
+ from datasets import load_dataset
7
+ import json
8
+ import os
9
+ import subprocess
10
+
11
+
12
+ #importing the model
13
+ model_ckpt = "echung682/finetuned-emotion-ai-model"
14
+ model = AutoModelForSequenceClassification.from_pretrained(model_ckpt)
15
+ pipe = pipeline(model=model_ckpt)
16
+
17
+ #importing the dataset (a whole bunch of text)
18
+ emotion_dataset = load_dataset("echung682/emotion-analysis-tweets")
19
+
20
+ #in order to keep the data persistent on HuggingFace repo
21
+ def save_to_repo():
22
+ # Add & commit the latest flagged.csv file to the Hugging Face Space repo
23
+ os.system("git pull origin main") # Pull latest changes (to avoid conflicts)
24
+ os.system("git add feedback_data/flagged.csv")
25
+ os.system('git commit -m "Update flagged data"')
26
+ os.system("git push origin main") # Push updated file to the repo
27
+
28
+
29
+ '''
30
+ in order to keep track of what the last prompt was that was given human feedback
31
+ '''
32
+ def load_state():
33
+ try:
34
+ with open("state.json", "r") as f:
35
+ return json.load(f).get("count", 0)
36
+ except FileNotFoundError:
37
+ return 0 #if the file doesn't have count variable in it, then it will return 0, which is good - that's the first index
38
+
39
+ # Save state to file
40
+ def save_state(count):
41
+ with open("state.json", "w") as f:
42
+ json.dump({"count": count}, f)
43
+
44
+ def increment():
45
+ count = load_state()
46
+ count += 1
47
+ save_state(count)
48
+ return count
49
+
50
+ def save_state_to_repo():
51
+ os.system("git pull origin main") # Pull latest changes (to avoid conflicts)
52
+ os.system("git add state.json")
53
+ os.system('git commit -m "Update state"')
54
+ os.system("git push origin main") # Push updated file to the repo
55
+
56
+
57
+ '''
58
+ keeping track of the prompt, options, and chosen option
59
+ then increasing the index number (so it doesn't ask everyone to look at the same ones)
60
+ writes the new data into the Gradio file
61
+ pushes the new data and the index number into their respective files to keep track across multiple users
62
+ '''
63
+ def updateDataset(prompt, option1, option2, flagged_option):
64
+ # This function is called when a user clicks a flagging button.
65
+ if flagged_option == option1:
66
+ chosen = option1
67
+ rejected = option2
68
+ elif flagged_option == option2:
69
+ chosen = option2
70
+ rejected = option1
71
+ else: # Handle unexpected cases (shouldn't happen with radio buttons)
72
+ chosen = ""
73
+ rejected = ""
74
+
75
+ index = increment()
76
+
77
+ with open("feedback_data/flagged.csv", "a") as f:
78
+ f.write(f"{prompt},{chosen},{rejected}\n")
79
+
80
+ # Push the updated file to the repo
81
+ save_to_repo() #all of the inputs and outputs for the Gradio interface, that will save to the feedback_data file (and then pushed to HuggingFace repo)
82
+ save_state_to_repo()
83
+
84
+ return prompt, chosen, rejected, "Submitted! Please answer another...", index
85
+
86
+
87
+ '''
88
+ finding the correct prompt based on the global index
89
+ extracting the top two scoring emotions
90
+ returning these
91
+ '''
92
+ def emotion_analysis_data_collection():
93
+ index = load_state()
94
+ result = pipe(emotion_dataset["train"]["text"][index], top_k = None)
95
+ score_list = [] #empty list to hold the scores
96
+ emotion_list = [] #empty list to hold the emotions
97
+
98
+ for emotion in result:
99
+ emotion_list.append(emotion["label"]) #extracting the emotions from the results
100
+ score_list.append(emotion["score"]) #extracing the scores from the results
101
+
102
+ emotion_dict = {}
103
+ for index, value in enumerate(emotion_list):
104
+ emotion_dict[value] = score_list[index]
105
+
106
+ dictKeys_list = list(emotion_dict.keys())
107
+ emotion_highestScore = dictKeys_list[0]
108
+ emotion_secondHighestScore = dictKeys_list[1]
109
+
110
+ #print(emotion_highestScore)
111
+ #print(emotion_secondHighestScore)
112
+ #print(" ")
113
+
114
+ return emotion_dataset["train"]["text"][index], emotion_highestScore, emotion_secondHighestScore
115
+
116
+
117
+
118
+
119
+ '''
120
+ designing the gradio interface
121
+ has the two options and a Radio object that will keep track of the chosen emotion
122
+ '''
123
+ with gr.Blocks() as survey:
124
+ gr.Markdown(
125
+ """
126
+ # Please choose the emotion that best describes the prompt
127
+ """
128
+ )
129
+
130
+ tweet, emotion_highestScore, emotion_secondHighestScore = emotion_analysis_data_collection() #calls the function that figures out what the prompt and two highest scoring emotions are
131
+ sentence = gr.Textbox(tweet, label="Prompt:", interactive=False)
132
+
133
+ #print(emotion_highestScore)
134
+ #print(emotion_secondHighestScore)
135
+
136
+ #testOutput = gr.Textbox()
137
+
138
+ with gr.Row():
139
+ emotion1 = gr.Textbox(emotion_highestScore, label="Emotion Choice 1:", interactive=False)
140
+ emotion2 = gr.Textbox(emotion_secondHighestScore, label="Emotion Choice 2", interactive=False)
141
+
142
+ options = gr.Radio([emotion_highestScore, emotion_secondHighestScore], label="Choose one:")
143
+
144
+ submit_btn = gr.Button("Submit Choice")
145
+
146
+ submit_btn.click(fn=updateDataset,
147
+ inputs=[sentence, emotion1, emotion2, options],
148
+ outputs=[gr.Textbox(label="Prompt"), gr.Textbox(label="Chosen Response"), gr.Textbox(label="Rejected Response"), gr.Textbox(label="Confirmation Message"), gr.Textbox(label="Prompt Number")],
149
+ )
150
+ survey.launch()