# NOTE: Hugging Face Spaces status banner ("Spaces: Sleeping") accidentally
# captured during export; it is not part of the program.
import datetime | |
import json | |
import logging | |
import os | |
import random | |
import threading | |
import time | |
import uuid | |
from functools import partial | |
import gradio as gr | |
import huggingface_hub | |
import requests | |
# Configure root logging once and grab a module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# API configuration — both values come from the environment.
API_URL = os.getenv("API_URL")  # base URL of the fact-checking service
API_KEY = os.getenv("API_KEY")  # sent as X-API-Key on every request

# HF auth: fall back to True so huggingface_hub uses the locally cached
# token when the TOKEN env var is unset.
auth_token = os.environ.get("TOKEN") or True
hf_repo = "Iker/Feedback_FactCheking_Mar4"

# Ensure the private feedback dataset repo exists (no-op if it already does).
huggingface_hub.create_repo(
    repo_id=hf_repo,
    repo_type="dataset",
    token=auth_token,
    exist_ok=True,
    private=True,
)

# Headers attached to every request against the fact-checking API.
headers = {"X-API-Key": API_KEY, "Content-Type": "application/json"}
def update_models():
    """Randomly pick two distinct model configurations for the arena.

    Returns:
        list[str]: two distinct config names sampled from the available pool.
    """
    available = ["pro", "pro-gpt", "pro-rag", "pro-rag-gpt"]
    # random.sample guarantees the two picks are distinct.
    models = random.sample(available, 2)
    # Use the logging machinery instead of a bare print so the message goes
    # through the same handlers as the rest of the app.
    logging.getLogger(__name__).info("Models updated: %s", models)
    return models
def submit_fact_check(article_topic, config, language, location):
    """Queue a fact-checking job on the remote API and return its job id."""
    body = {
        "article_topic": article_topic,
        "config": config,
        "language": language,
        "location": location,
    }
    resp = requests.post(f"{API_URL}/fact-check", json=body, headers=headers)
    resp.raise_for_status()  # surface HTTP errors to the caller
    return resp.json()["job_id"]
def get_fact_check_result(job_id):
    """Fetch the current status/result payload for a submitted fact-check job."""
    url = f"{API_URL}/result/{job_id}"
    resp = requests.get(url, headers=headers)
    resp.raise_for_status()  # surface HTTP errors to the caller
    return resp.json()
def fact_checking(article_topic, config, poll_interval=2, timeout=None):
    """Run a fact-check end to end: submit the job, then poll until it finishes.

    Args:
        article_topic: Claim/topic to verify.
        config: Model configuration name to use.
        poll_interval: Seconds to wait between status polls (default 2,
            matching the original behavior).
        timeout: Optional overall limit in seconds. None (the default) polls
            forever, matching the original behavior.

    Returns:
        The result payload dict on success, or None on failure, submission
        error, or timeout.
    """
    language = "es"
    location = "es"
    logger.info(f"Submitting fact-checking request for article: {article_topic}")
    try:
        job_id = submit_fact_check(article_topic, config, language, location)
    except requests.exceptions.RequestException as e:
        # Explicitly return None (the original fell off the end implicitly).
        logger.error(f"An error occurred while submitting the request: {e}")
        return None

    logger.info(f"Fact-checking job submitted. Job ID: {job_id}")
    start_time = time.time()
    while True:
        elapsed_time = time.time() - start_time
        if timeout is not None and elapsed_time > timeout:
            logger.error(f"Fact-checking timed out after {elapsed_time:.2f} seconds")
            return None
        try:
            result = get_fact_check_result(job_id)
        except requests.exceptions.RequestException as e:
            # Transient polling error: log and retry after a short wait.
            logger.error(f"Error while polling for results: {e}")
            time.sleep(poll_interval)
            continue
        if result["status"] == "completed":
            logger.info("Fact-checking completed:")
            logger.info(f"Response object: {result}")
            logger.info(
                f"Result: {json.dumps(result['result'], indent=4, ensure_ascii=False)}"
            )
            return result["result"]
        if result["status"] == "failed":
            logger.error("Fact-checking failed:")
            logger.error(f"Response object: {result}")
            logger.error(f"Error message: {result['error']}")
            return None
        logger.info(
            f"Fact-checking in progress. Elapsed time: {elapsed_time:.2f} seconds"
        )
        time.sleep(poll_interval)
def format_response(response):
    """Render a fact-check result dict as a single Markdown string.

    Layout: title as H2, main claim as H3, the answer, then each related
    question in bold followed by its answer.
    """
    meta = response["metadata"]
    qa_lines = "\n".join(
        f"**{question}**\n{answer}"
        for question, answer in response["related_questions"].items()
    )
    return (
        f"## {meta['title']}\n\n"
        f"### {meta['main_claim']}\n\n"
        f"{response['answer']}\n\n"
        f"{qa_lines}"
    )
def do_both_fact_checking(msg):
    """Fact-check `msg` with two randomly chosen configs, run concurrently.

    Returns:
        tuple: ("", history_a, history_b, models) — an empty string to clear
        the textbox, one-turn chat histories for each chatbot, and the pair
        of model names used.
    """
    models = update_models()
    results = [None, None]

    def run(slot, config):
        # Worker: store this model's result in its shared-list slot.
        results[slot] = fact_checking(msg, config=config)

    workers = [
        threading.Thread(target=run, args=(i, cfg)) for i, cfg in enumerate(models)
    ]
    for worker in workers:
        worker.start()
    # Block until both fact-checks have finished.
    for worker in workers:
        worker.join()

    formatted = [format_response(r) if r else "Error" for r in results]
    return ("", [(msg, formatted[0])], [(msg, formatted[1])], models)
def save_history(
    models,
    history_0,
    history_1,
    max_new_tokens=None,
    temperature=None,
    top_p=None,
    repetition_penalty=None,
    winner=None,
):
    """Persist one arena round as JSON locally and upload it to the HF dataset.

    Args:
        models: Pair [model_a, model_b] of config names used for this round.
        history_0: Chat history of chatbot A as [(message, formatted_answer)].
        history_1: Chat history of chatbot B; same shape as history_0.
        max_new_tokens, temperature, top_p, repetition_penalty: Optional
            generation hyperparameters, recorded as-is (all None from the UI).
        winner: Vote label pre-bound via functools.partial by the buttons
            ("model_a", "model_b", "tie", or "tie (both bad)").

    Returns:
        ([], []): empty histories, which Gradio uses to clear both chatbots.
    """
    os.makedirs("data", exist_ok=True)
    path = os.path.join("data", f"history_{uuid.uuid4()}.json")

    data = {
        "timestamp": datetime.datetime.now().isoformat(),
        "model_a": models[0],
        "model_b": models[1],
        "hyperparameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
        },
        "message": history_0[0][0],
        "fc_a": history_0[0][1],
        "fc_b": history_1[0][1],
        "winner": winner,
    }

    # ensure_ascii=False emits raw non-ASCII characters, so the file must be
    # opened as UTF-8 explicitly — the platform default encoding may not be.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, ensure_ascii=False, indent=4, fp=f)

    huggingface_hub.upload_file(
        repo_id=hf_repo,
        repo_type="dataset",
        token=os.environ.get("TOKEN") or True,
        path_in_repo=path,
        path_or_fileobj=path,
    )
    gr.Info("Feedback sent successfully! Thank you for your help.")
    return [], []
# --- Gradio UI -------------------------------------------------------------
# Two chatbots side by side (model A vs model B), one prompt box, and four
# voting buttons that record which side produced the better fact-check.
with gr.Blocks(
    theme="gradio/soft",
    fill_height=True,
    fill_width=True,
    analytics_enabled=False,
    # NOTE(review): "Cheking" typo is user-visible — confirm before changing.
    title="Fact Cheking Demo",
    css=".center-text { text-align: center; } footer {visibility: hidden;} .avatar-container {width: 50px; height: 50px; border: none;}",
) as demo:
    gr.Markdown("# Fact Checking Arena", elem_classes="center-text")
    # Holds the two model config names chosen for the current round.
    models = gr.State([])
    with gr.Row():
        with gr.Column():
            chatbot_a = gr.Chatbot(
                height=800,
                show_copy_all_button=True,
                avatar_images=[
                    None,
                    "https://upload.wikimedia.org/wikipedia/commons/a/ac/Green_tick.svg",
                ],
            )
        with gr.Column():
            chatbot_b = gr.Chatbot(
                show_copy_all_button=True,
                height=800,
                avatar_images=[
                    None,
                    "https://upload.wikimedia.org/wikipedia/commons/a/ac/Green_tick.svg",
                ],
            )
    # Prompt box (Spanish UI): "enter what you want to verify".
    msg = gr.Textbox(
        label="Introduce que quieres verificar",
        placeholder="Los coches electricos contaminan más que los coches de gasolina",
        autofocus=True,
    )
    # Voting buttons: left better / tie / both bad / right better.
    with gr.Row():
        with gr.Column():
            left = gr.Button("👈 Izquierda mejor")
        with gr.Column():
            # NOTE(review): "Iguald" typo is user-visible — confirm before changing.
            tie = gr.Button("🤝 Iguald de buenos")
        with gr.Column():
            fail = gr.Button("👎 Igual de malos")
        with gr.Column():
            right = gr.Button("👉 Derecha mejor")
    # Submitting the textbox runs both models and fills both chat histories.
    msg.submit(
        do_both_fact_checking,
        inputs=[
            msg,
        ],
        outputs=[msg, chatbot_a, chatbot_b, models],
    )
    # Each vote saves the round (winner label pre-bound via partial) and
    # clears both chatbots via save_history's ([], []) return value.
    left.click(
        partial(
            save_history,
            winner="model_a",
        ),
        inputs=[
            models,
            chatbot_a,
            chatbot_b,
        ],
        outputs=[chatbot_a, chatbot_b],
    )
    tie.click(
        partial(
            save_history,
            winner="tie",
        ),
        inputs=[
            models,
            chatbot_a,
            chatbot_b,
        ],
        outputs=[chatbot_a, chatbot_b],
    )
    fail.click(
        partial(
            save_history,
            winner="tie (both bad)",
        ),
        inputs=[
            models,
            chatbot_a,
            chatbot_b,
        ],
        outputs=[chatbot_a, chatbot_b],
    )
    right.click(
        partial(
            save_history,
            winner="model_b",
        ),
        inputs=[
            models,
            chatbot_a,
            chatbot_b,
        ],
        outputs=[chatbot_a, chatbot_b],
    )
    # Pick a fresh model pair each time the page loads.
    demo.load(update_models, inputs=[], outputs=[models])

# Launch without auth; the auth/share options are deliberately disabled.
demo.launch(
    # auth=(os.getenv("GRADIO_USERNAME"), os.getenv("GRADIO_PASSWORD")),share=True
)