Iker committed on
Commit
551c0f3
·
verified ·
1 Parent(s): c82a4d0

Upload compare_gradio.py

Browse files
Files changed (1) hide show
  1. compare_gradio.py +305 -0
compare_gradio.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# compare_gradio.py -- side-by-side "arena" demo: fact-checks a claim with two
# randomly chosen model configurations and collects human preference feedback.
import json
import logging
import os
import threading
import time
import uuid
import datetime
import gradio as gr
import huggingface_hub
import requests
import random
from functools import partial

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# API configuration -- both values come from the environment; they are None when
# unset, in which case requests fail at call time rather than at import time.
API_URL = os.getenv("API_URL")
API_KEY = os.getenv("API_KEY")

# `True` tells huggingface_hub to fall back to the locally saved token when the
# TOKEN env var is unset.
auth_token = os.environ.get("TOKEN") or True
# Private dataset repo where feedback JSON files are uploaded.
# NOTE(review): "FactCheking" looks like a typo, but it is the actual repo id --
# do not "fix" it without migrating the repo.
hf_repo = "Iker/Feedback_FactCheking"
# Idempotent: exist_ok=True makes repeated startups safe.
huggingface_hub.create_repo(
    repo_id=hf_repo,
    repo_type="dataset",
    token=auth_token,
    exist_ok=True,
    private=True,
)


# Shared headers for every fact-checking API request.
headers = {"X-API-Key": API_KEY, "Content-Type": "application/json"}
36
def update_models(pool=("pro", "pro2", "turbo", "turbo2"), k=2):
    """Pick ``k`` distinct model configurations at random.

    Args:
        pool: Candidate configuration names (defaults to the four deployed
            configs, previously hard-coded inside the function).
        k: Number of distinct configurations to sample.

    Returns:
        list[str]: ``k`` distinct configuration names in random order.
    """
    models = random.sample(pool, k)
    print(f"Models updated: {models}")
    return models
40
+
41
+
42
+ # Function to submit a fact-checking request
43
# Function to submit a fact-checking request
def submit_fact_check(article_topic, config, language, location, timeout=30):
    """Submit a fact-checking job to the backend API.

    Args:
        article_topic: Claim or topic to verify.
        config: Model configuration name to run.
        language: Language code for the answer (e.g. "es").
        location: Country code used by the backend (e.g. "es").
        timeout: Seconds before the HTTP request is aborted. Previously the
            request had no timeout and could block the caller forever.

    Returns:
        The job id string assigned by the API.

    Raises:
        requests.exceptions.RequestException: On network errors, timeouts, or
            non-2xx responses (via raise_for_status).
    """
    endpoint = f"{API_URL}/fact-check"
    payload = {
        "article_topic": article_topic,
        "config": config,
        "language": language,
        "location": location,
    }

    response = requests.post(endpoint, json=payload, headers=headers, timeout=timeout)
    response.raise_for_status()  # Raise an exception for HTTP errors
    return response.json()["job_id"]
55
+
56
+
57
+ # Function to get the result of a fact-checking job
58
# Function to get the result of a fact-checking job
def get_fact_check_result(job_id, timeout=30):
    """Fetch the current status/result of a fact-checking job.

    Args:
        job_id: Identifier returned by ``submit_fact_check``.
        timeout: Seconds before the HTTP request is aborted. Previously the
            request had no timeout and could block the polling loop forever.

    Returns:
        Parsed JSON dict; callers read a "status" key ("completed"/"failed"/
        other) and, on completion, a "result" payload.

    Raises:
        requests.exceptions.RequestException: On network errors, timeouts, or
            non-2xx responses (via raise_for_status).
    """
    endpoint = f"{API_URL}/result/{job_id}"

    response = requests.get(endpoint, headers=headers, timeout=timeout)
    response.raise_for_status()  # Raise an exception for HTTP errors
    return response.json()
64
+
65
+
66
def fact_checking(article_topic, config, poll_interval=2):
    """Submit a fact-checking job and block until it finishes.

    Language/location are hard-coded to Spanish/Spain, which is what this demo
    targets. The function polls the API every ``poll_interval`` seconds.

    NOTE(review): there is no overall deadline -- if the job never reaches
    "completed" or "failed" this loop runs forever. Confirm whether a maximum
    wait should be added.

    Args:
        article_topic: Claim or topic to verify.
        config: Model configuration name (e.g. "pro", "turbo").
        poll_interval: Seconds between result polls (previously hard-coded to 2).

    Returns:
        The job's "result" payload dict on success, or None if the job failed
        or the submission itself errored.
    """
    language = "es"
    location = "es"
    logger.info(f"Submitting fact-checking request for article: {article_topic}")
    try:
        job_id = submit_fact_check(article_topic, config, language, location)
        logger.info(f"Fact-checking job submitted. Job ID: {job_id}")

        # Poll for results
        start_time = time.time()
        while True:
            try:
                result = get_fact_check_result(job_id)
                if result["status"] == "completed":
                    logger.info("Fact-checking completed:")
                    logger.info(f"Response object: {result}")
                    logger.info(
                        f"Result: {json.dumps(result['result'], indent=4, ensure_ascii=False)}"
                    )
                    return result["result"]
                elif result["status"] == "failed":
                    logger.error("Fact-checking failed:")
                    logger.error(f"Response object: {result}")
                    logger.error(f"Error message: {result['error']}")
                    return None
                else:
                    elapsed_time = time.time() - start_time
                    logger.info(
                        f"Fact-checking in progress. Elapsed time: {elapsed_time:.2f} seconds"
                    )
                    time.sleep(poll_interval)  # Wait before checking again
            except requests.exceptions.RequestException as e:
                # Transient polling error: log and retry instead of aborting the job.
                logger.error(f"Error while polling for results: {e}")
                time.sleep(poll_interval)  # Wait before retrying

    except requests.exceptions.RequestException as e:
        logger.error(f"An error occurred while submitting the request: {e}")
        # Explicit (was implicit): submission failed, no result available.
        return None
103
+
104
+
105
def format_response(response):
    """Render a fact-checking API result dict as Markdown.

    Expects keys: ``metadata.title``, ``metadata.main_claim``, ``answer``, and
    ``related_questions`` (a question -> answer mapping).

    Args:
        response: The "result" payload returned by ``fact_checking``.

    Returns:
        str: Markdown with the title, the main claim, the answer, and one
        bolded question/answer pair per related question.
    """
    title = response["metadata"]["title"]
    main_claim = response["metadata"]["main_claim"]

    fc = response["answer"]

    rq = response["related_questions"]

    rq_block = [f"**{q}**\n{a}" for q, a in rq.items()]

    # Join OUTSIDE the f-string: a backslash inside an f-string expression
    # (the original `{'\n'.join(...)}`) is a SyntaxError before Python 3.12.
    related = "\n".join(rq_block)
    return f"## {title}\n\n### {main_claim}\n\n{fc}\n\n{related}"
118
+
119
+
120
def do_both_fact_checking(msg):
    """Fact-check ``msg`` with two randomly selected configurations in parallel.

    Returns a 4-tuple for the Gradio outputs: the cleared textbox value, the
    chat history for model A, the chat history for model B, and the
    [model_a, model_b] pair that was used.
    """
    models = update_models()

    # One result slot per model; each worker thread writes only its own slot.
    results = [None] * 2

    def _worker(slot, model_name):
        results[slot] = fact_checking(msg, config=model_name)

    workers = [
        threading.Thread(target=_worker, args=(slot, name))
        for slot, name in enumerate(models)
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    # Render each raw API result as Markdown; a falsy result stays None.
    rendered = [format_response(res) if res else None for res in results]
    history_a = [(msg, rendered[0])]
    history_b = [(msg, rendered[1])]
    return ("", history_a, history_b, models)
150
+
151
+
152
def save_history(
    models,
    history_0,
    history_1,
    max_new_tokens=None,
    temperature=None,
    top_p=None,
    repetition_penalty=None,
    winner=None,
):
    """Persist one arena comparison (message, both answers, verdict) and upload it.

    Writes a uniquely named JSON file under ./data and pushes it to the private
    Hugging Face feedback dataset repo, then shows a toast to the user.

    Args:
        models: [model_a, model_b] configuration names used for this round.
        history_0: Left chatbot history -- [(message, answer_a)].
        history_1: Right chatbot history -- [(message, answer_b)].
        max_new_tokens, temperature, top_p, repetition_penalty: Optional
            generation hyperparameters, recorded as metadata only.
        winner: Verdict label ("model_a", "model_b", "tie", "tie (both bad)").
    """
    path = os.path.join("data", f"history_{uuid.uuid4()}.json")
    os.makedirs("data", exist_ok=True)
    data = {
        "timestamp": datetime.datetime.now().isoformat(),
        "model_a": models[0],
        "model_b": models[1],
        "hyperparameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
        },
        "message": history_0[0][0],
        "fc_a": history_0[0][1],
        "fc_b": history_1[0][1],
        "winner": winner,
    }

    # encoding="utf-8" is required: ensure_ascii=False emits raw non-ASCII
    # characters, which would raise UnicodeEncodeError under a non-UTF-8
    # platform default encoding (e.g. Windows cp1252).
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=4)

    huggingface_hub.upload_file(
        repo_id=hf_repo,
        repo_type="dataset",
        token=os.environ.get("TOKEN") or True,
        path_in_repo=path,
        path_or_fileobj=path,
    )

    gr.Info("Feedback sent successfully! Thank you for your help.")
194
+
195
+
196
# Build the two-column comparison UI and wire the event handlers.
with gr.Blocks(
    theme="gradio/soft",
    fill_height=True,
    fill_width=True,
    analytics_enabled=False,
    # NOTE(review): "Cheking" is presumably a typo for "Checking" -- this is the
    # user-facing browser-tab title, confirm before renaming.
    title="Fact Cheking Demo",
    # Center headings, hide the Gradio footer, shrink chatbot avatars.
    css=".center-text { text-align: center; } footer {visibility: hidden;} .avatar-container {width: 50px; height: 50px; border: none;}",
) as demo:
    gr.Markdown("# Fact Checking Arena", elem_classes="center-text")
    # Per-session state holding the [model_a, model_b] pair currently being
    # compared; refreshed on page load and on every submission.
    models = gr.State([])
    with gr.Row():
        # Left column: anonymous "model A" transcript.
        with gr.Column():
            chatbot_a = gr.Chatbot(
                height=800,
                show_copy_all_button=True,
                avatar_images=[
                    None,
                    "https://upload.wikimedia.org/wikipedia/commons/a/ac/Green_tick.svg",
                ],
            )

        # Right column: anonymous "model B" transcript.
        with gr.Column():
            chatbot_b = gr.Chatbot(
                show_copy_all_button=True,
                height=800,
                avatar_images=[
                    None,
                    "https://upload.wikimedia.org/wikipedia/commons/a/ac/Green_tick.svg",
                ],
            )

    # Claim input (Spanish UI text).
    msg = gr.Textbox(
        label="Introduce que quieres verificar",
        placeholder="Los coches electricos contaminan más que los coches de gasolina",
        autofocus=True,
    )

    # Verdict buttons.
    # NOTE(review): the arrows/labels look inconsistent with the winners
    # recorded below -- the left-arrow button says "Derecha mejor" (right
    # better) yet records winner="model_a" (the LEFT chatbot), and vice versa;
    # labels also contain typos ("Iguald", "Iquierda"). Confirm the intended
    # mapping with the author before changing these user-facing strings.
    with gr.Row():
        with gr.Column():
            left = gr.Button("👈 Derecha mejor")
        with gr.Column():
            tie = gr.Button("🤝 Iguald de buenos")
        with gr.Column():
            fail = gr.Button("👎 Igual de malos")
        with gr.Column():
            right = gr.Button("👉 Iquierda mejor")

    # Pressing Enter runs both models in parallel, clears the textbox, and
    # fills both chat panes.
    msg.submit(
        do_both_fact_checking,
        inputs=[
            msg,
        ],
        outputs=[msg, chatbot_a, chatbot_b, models],
    )

    # Each verdict button uploads both transcripts plus the chosen winner.
    left.click(
        partial(
            save_history,
            winner="model_a",
        ),
        inputs=[
            models,
            chatbot_a,
            chatbot_b,
        ],
    )

    tie.click(
        partial(
            save_history,
            winner="tie",
        ),
        inputs=[
            models,
            chatbot_a,
            chatbot_b,
        ],
    )

    fail.click(
        partial(
            save_history,
            winner="tie (both bad)",
        ),
        inputs=[
            models,
            chatbot_a,
            chatbot_b,
        ],
    )

    right.click(
        partial(
            save_history,
            winner="model_b",
        ),
        inputs=[
            models,
            chatbot_a,
            chatbot_b,
        ],
    )

    # Pick a fresh random model pair whenever the page loads.
    demo.load(update_models, inputs=[], outputs=[models])

# Basic-auth credentials come from the environment.
# NOTE(review): if GRADIO_USERNAME/GRADIO_PASSWORD are unset this passes
# (None, None) as `auth` -- confirm that is the intended behavior.
demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    auth=(os.getenv("GRADIO_USERNAME"), os.getenv("GRADIO_PASSWORD")),
)
305
+ )