# compare_gradio.py — Fact-Checking Arena: side-by-side comparison Gradio demo.
import datetime
import json
import logging
import os
import random
import threading
import time
import uuid
from functools import partial
import gradio as gr
import huggingface_hub
import requests
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# API configuration
# Backend fact-checking service endpoint and credentials (set via env vars).
API_URL = os.getenv("API_URL")
API_KEY = os.getenv("API_KEY")
# Hugging Face auth: use the TOKEN env var if set; `True` falls back to the
# locally cached `huggingface-cli login` token.
auth_token = os.environ.get("TOKEN") or True
# Private dataset repo where user feedback (votes) is stored.
hf_repo = "Iker/Feedback_FactCheking_Mar4"
# Ensure the feedback repo exists at startup (no-op if it was already created).
huggingface_hub.create_repo(
    repo_id=hf_repo,
    repo_type="dataset",
    token=auth_token,
    exist_ok=True,
    private=True,
)
# Every backend API request authenticates via the X-API-Key header.
headers = {"X-API-Key": API_KEY, "Content-Type": "application/json"}
def update_models():
    """Randomly pick two distinct backend configurations to compare.

    Returns:
        A list of two distinct model configuration names, in random order.
    """
    candidates = ["pro", "pro-gpt", "pro-rag", "pro-rag-gpt"]
    # random.sample guarantees the two picks are distinct.
    models = random.sample(candidates, 2)
    print(f"Models updated: {models}")
    return models
# Function to submit a fact-checking request
def submit_fact_check(article_topic, config, language, location):
    """Submit a new fact-checking job to the backend API.

    Args:
        article_topic: Claim/topic text to verify.
        config: Backend model configuration name (e.g. "pro-rag").
        language: ISO language code for the answer (e.g. "es").
        location: Location/market code used by the backend (e.g. "es").

    Returns:
        The job id string assigned by the backend.

    Raises:
        requests.exceptions.RequestException: On HTTP errors or timeouts.
    """
    endpoint = f"{API_URL}/fact-check"
    payload = {
        "article_topic": article_topic,
        "config": config,
        "language": language,
        "location": location,
    }
    # A request without a timeout can hang forever; bound it explicitly.
    response = requests.post(endpoint, json=payload, headers=headers, timeout=30)
    response.raise_for_status()  # Raise an exception for HTTP errors
    return response.json()["job_id"]
# Function to get the result of a fact-checking job
def get_fact_check_result(job_id):
    """Fetch the current status/result of a fact-checking job.

    Args:
        job_id: Job id returned by `submit_fact_check`.

    Returns:
        The decoded JSON response (includes at least a "status" field).

    Raises:
        requests.exceptions.RequestException: On HTTP errors or timeouts.
    """
    endpoint = f"{API_URL}/result/{job_id}"
    # A request without a timeout can hang forever; bound it explicitly.
    response = requests.get(endpoint, headers=headers, timeout=30)
    response.raise_for_status()  # Raise an exception for HTTP errors
    return response.json()
def fact_checking(article_topic, config):
    """Submit a fact-check job and block until it reaches a terminal state.

    Polls the result endpoint every 2 seconds until the job reports
    "completed" or "failed". Polling errors are logged and retried
    indefinitely (the Gradio caller shows "Error" only on a final failure).

    Args:
        article_topic: Claim/topic text to verify.
        config: Backend model configuration name.

    Returns:
        The job's "result" payload on success, or None on failure.
    """
    language = "es"
    location = "es"
    logger.info(f"Submitting fact-checking request for article: {article_topic}")
    try:
        job_id = submit_fact_check(article_topic, config, language, location)
    except requests.exceptions.RequestException as e:
        logger.error(f"An error occurred while submitting the request: {e}")
        return None  # explicit, instead of falling off the end of the function
    logger.info(f"Fact-checking job submitted. Job ID: {job_id}")

    # Poll for results until the job completes or fails.
    start_time = time.time()
    while True:
        try:
            result = get_fact_check_result(job_id)
        except requests.exceptions.RequestException as e:
            logger.error(f"Error while polling for results: {e}")
            time.sleep(2)  # Wait before retrying
            continue

        if result["status"] == "completed":
            logger.info("Fact-checking completed:")
            logger.info(f"Response object: {result}")
            logger.info(
                f"Result: {json.dumps(result['result'], indent=4, ensure_ascii=False)}"
            )
            return result["result"]
        elif result["status"] == "failed":
            logger.error("Fact-checking failed:")
            logger.error(f"Response object: {result}")
            # .get() so a missing "error" key cannot raise here while logging.
            logger.error(f"Error message: {result.get('error')}")
            return None
        else:
            elapsed_time = time.time() - start_time
            logger.info(
                f"Fact-checking in progress. Elapsed time: {elapsed_time:.2f} seconds"
            )
            time.sleep(2)  # Wait for 2 seconds before checking again
def format_response(response):
    """Render a fact-check API result as a single Markdown string.

    Args:
        response: Result payload with "metadata" (title, main_claim),
            "answer", and "related_questions" (question -> answer mapping).

    Returns:
        Markdown: title heading, main claim subheading, the answer, and one
        bolded question/answer pair per related question.
    """
    meta = response["metadata"]
    related = "\n".join(
        f"**{question}**\n{answer}"
        for question, answer in response["related_questions"].items()
    )
    return (
        f"## {meta['title']}\n\n"
        f"### {meta['main_claim']}\n\n"
        f"{response['answer']}\n\n"
        f"{related}"
    )
def do_both_fact_checking(msg):
    """Fact-check `msg` with two randomly chosen configs, concurrently.

    Args:
        msg: User-supplied claim to verify.

    Returns:
        Tuple of (cleared textbox value, left chat history, right chat
        history, the pair of model configs used) for the Gradio outputs.
    """
    selected = update_models()
    outputs = [None, None]

    # Each worker writes its result into its own slot; no locking needed
    # since the slots are disjoint.
    def run(slot, config):
        outputs[slot] = fact_checking(msg, config=config)

    workers = [
        threading.Thread(target=run, args=(slot, config))
        for slot, config in enumerate(selected)
    ]
    for worker in workers:
        worker.start()
    # Block until both fact-checks finish.
    for worker in workers:
        worker.join()

    # A None result means that backend call failed.
    formatted = [format_response(r) if r else "Error" for r in outputs]
    return ("", [(msg, formatted[0])], [(msg, formatted[1])], selected)
def save_history(
    models,
    history_0,
    history_1,
    max_new_tokens=None,
    temperature=None,
    top_p=None,
    repetition_penalty=None,
    winner=None,
):
    """Persist one comparison round locally and upload it to the feedback repo.

    Args:
        models: Pair of model config names [model_a, model_b].
        history_0: Left chatbot history; list of (message, answer) tuples.
        history_1: Right chatbot history; list of (message, answer) tuples.
        max_new_tokens: Optional generation hyperparameter, recorded as-is.
        temperature: Optional generation hyperparameter, recorded as-is.
        top_p: Optional generation hyperparameter, recorded as-is.
        repetition_penalty: Optional generation hyperparameter, recorded as-is.
        winner: Vote label ("model_a", "model_b", "tie", "tie (both bad)").

    Returns:
        Two empty lists, which Gradio uses to clear both chatbots.
    """
    path = f"history_{uuid.uuid4()}.json"
    path = os.path.join("data", path)
    os.makedirs("data", exist_ok=True)
    data = {
        "timestamp": datetime.datetime.now().isoformat(),
        "model_a": models[0],
        "model_b": models[1],
        "hyperparameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
        },
        # Both histories hold a single (message, answer) turn per round.
        "message": history_0[0][0],
        "fc_a": history_0[0][1],
        "fc_b": history_1[0][1],
        "winner": winner,
    }
    # ensure_ascii=False emits raw non-ASCII characters, so the file must be
    # opened with an explicit utf-8 encoding (the platform default may differ).
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, ensure_ascii=False, indent=4, fp=f)
    huggingface_hub.upload_file(
        repo_id=hf_repo,
        repo_type="dataset",
        token=os.environ.get("TOKEN") or True,
        path_in_repo=path,
        path_or_fileobj=path,
    )
    gr.Info("Feedback sent successfully! Thank you for your help.")
    return [], []
# --- Gradio UI wiring ------------------------------------------------------
with gr.Blocks(
    theme="gradio/soft",
    fill_height=True,
    fill_width=True,
    analytics_enabled=False,
    title="Fact Checking Demo",  # typo fix: was "Fact Cheking Demo"
    css=".center-text { text-align: center; } footer {visibility: hidden;} .avatar-container {width: 50px; height: 50px; border: none;}",
) as demo:
    gr.Markdown("# Fact Checking Arena", elem_classes="center-text")
    # Holds the two model configs currently being compared; filled on page
    # load and refreshed on every message submit.
    models = gr.State([])
    with gr.Row():
        with gr.Column():
            chatbot_a = gr.Chatbot(
                height=800,
                show_copy_all_button=True,
                avatar_images=[
                    None,
                    "https://upload.wikimedia.org/wikipedia/commons/a/ac/Green_tick.svg",
                ],
            )
        with gr.Column():
            chatbot_b = gr.Chatbot(
                show_copy_all_button=True,
                height=800,
                avatar_images=[
                    None,
                    "https://upload.wikimedia.org/wikipedia/commons/a/ac/Green_tick.svg",
                ],
            )
    msg = gr.Textbox(
        label="Introduce que quieres verificar",
        placeholder="Los coches electricos contaminan más que los coches de gasolina",
        autofocus=True,
    )
    # Voting buttons: each maps to a `winner` label recorded by save_history.
    with gr.Row():
        with gr.Column():
            left = gr.Button("👈 Izquierda mejor")
        with gr.Column():
            tie = gr.Button("🤝 Igual de buenos")  # typo fix: was "Iguald"
        with gr.Column():
            fail = gr.Button("👎 Igual de malos")
        with gr.Column():
            right = gr.Button("👉 Derecha mejor")
    # Submitting a claim runs both fact-checks and repopulates the model pair.
    msg.submit(
        do_both_fact_checking,
        inputs=[
            msg,
        ],
        outputs=[msg, chatbot_a, chatbot_b, models],
    )
    left.click(
        partial(
            save_history,
            winner="model_a",
        ),
        inputs=[
            models,
            chatbot_a,
            chatbot_b,
        ],
        outputs=[chatbot_a, chatbot_b],
    )
    tie.click(
        partial(
            save_history,
            winner="tie",
        ),
        inputs=[
            models,
            chatbot_a,
            chatbot_b,
        ],
        outputs=[chatbot_a, chatbot_b],
    )
    fail.click(
        partial(
            save_history,
            winner="tie (both bad)",
        ),
        inputs=[
            models,
            chatbot_a,
            chatbot_b,
        ],
        outputs=[chatbot_a, chatbot_b],
    )
    right.click(
        partial(
            save_history,
            winner="model_b",
        ),
        inputs=[
            models,
            chatbot_a,
            chatbot_b,
        ],
        outputs=[chatbot_a, chatbot_b],
    )
    # Pick the initial model pair when the page loads.
    demo.load(update_models, inputs=[], outputs=[models])
demo.launch(
    # auth=(os.getenv("GRADIO_USERNAME"), os.getenv("GRADIO_PASSWORD")),share=True
)