|
import gradio as gr |
|
import pandas as pd |
|
import os |
|
import time |
|
from threading import Thread |
|
from arena import PromptArena |
|
|
|
# Display names of the two prompts competing in a match.
LABEL_A, LABEL_B = "Proposition A", "Proposition B"
|
|
|
|
|
class PromptArenaApp:
    """
    Wrap a PromptArena and expose it through a Gradio interface.

    The UI lets a user start a match between two prompts, vote for one of
    them, and see the arena's progress statistics and ranking table.
    """

    def __init__(self, arena: PromptArena) -> None:
        """
        Store the arena driven by the interface.

        Args:
            arena: The PromptArena used to select matches, record results
                and compute progress / rankings.
        """
        self.arena: PromptArena = arena

    @staticmethod
    def _button_states(voting: bool) -> tuple:
        """
        Build updates for the (vote A, vote B, new match) buttons.

        While a match awaits a vote only the two vote buttons are enabled;
        otherwise only the "new match" button is.
        """
        return (
            gr.update(interactive=voting),
            gr.update(interactive=voting),
            gr.update(interactive=not voting),
        )

    def select_and_display_match(self):
        """
        Select a new match and display its two prompts.

        Returns:
            A 6-tuple matching the outputs wired in ``create_ui``:
            - the text of the first prompt (or an error message),
            - the text of the second prompt,
            - a state dict with ``prompt_a_id`` and ``prompt_b_id``
              (empty on error),
            - updates for the vote A, vote B and "new match" buttons.
        """
        try:
            prompt_a_id, prompt_b_id = self.arena.select_match()
            prompt_a_text = self.arena.prompts.get(prompt_a_id, "")
            prompt_b_text = self.arena.prompts.get(prompt_b_id, "")
            state = {"prompt_a_id": prompt_a_id, "prompt_b_id": prompt_b_id}
            return (
                prompt_a_text,
                prompt_b_text,
                state,
                *self._button_states(voting=True),
            )
        except Exception as e:
            # Gradio requires exactly one value per output component; the
            # previous 4-value return made the error handler itself fail.
            return (
                f"Erreur lors de la sélection d'un match: {str(e)}",
                "",
                {},
                *self._button_states(voting=False),
            )

    def _record_winner(self, state: dict[str, str], a_wins: bool):
        """
        Record the outcome of the current match and refresh the displays.

        Args:
            state: Dict produced by ``select_and_display_match`` holding
                ``prompt_a_id`` and ``prompt_b_id``.
            a_wins: True if prompt A won, False if prompt B won.

        Returns:
            A 6-tuple: confirmation (or error) message, progress text,
            rankings DataFrame, and updates for the vote A, vote B and
            "new match" buttons.
        """
        label = LABEL_A if a_wins else LABEL_B
        try:
            prompt_a_id = state["prompt_a_id"]
            prompt_b_id = state["prompt_b_id"]
            # record_result takes the winner first, then the loser.
            if a_wins:
                self.arena.record_result(prompt_a_id, prompt_b_id)
            else:
                self.arena.record_result(prompt_b_id, prompt_a_id)
            progress_info = self.get_progress_info()
            rankings_table = self.get_rankings_table()
        except Exception as e:
            # Same arity as the success path so Gradio can render the error;
            # votes are disabled and a new match can be started.
            return (
                f"Erreur lors de l'enregistrement du résultat: {str(e)}",
                "",
                pd.DataFrame(),
                *self._button_states(voting=False),
            )
        return (
            f"Vous avez choisi : {label}",
            progress_info,
            rankings_table,
            *self._button_states(voting=False),
        )

    def record_winner_a(self, state: dict[str, str]):
        """Record a win for prompt A (see ``_record_winner``)."""
        return self._record_winner(state, a_wins=True)

    def record_winner_b(self, state: dict[str, str]):
        """Record a win for prompt B (see ``_record_winner``)."""
        return self._record_winner(state, a_wins=False)

    def get_progress_info(self) -> str:
        """
        Build a human-readable summary of the tournament's progress.

        Returns:
            str: Multi-line message with the number of prompts, matches
            played, progress percentage, estimated remaining matches and
            average uncertainty (sigma), or an error message.
        """
        # `is None` so an arena object that happens to be falsy (e.g. one
        # defining __len__) is not reported as missing.
        if self.arena is None:
            return "Aucune arène initialisée. Veuillez d'abord charger des prompts."

        try:
            progress = self.arena.get_progress()
            lines = [
                f"Prompts: {progress['total_prompts']}",
                f"Matchs joués: {progress['total_matches']}",
                f"Progression: {progress['progress']:.2f}%",
                f"Matchs restants estimés: {progress['estimated_remaining_matches']}",
                f"Incertitude moyenne (σ): {progress['avg_sigma']:.4f}",
            ]
            return "\n".join(lines)
        except Exception as e:
            return f"Erreur lors de la récupération de la progression: {str(e)}"

    def get_rankings_table(self) -> pd.DataFrame:
        """
        Build the prompt ranking as a table.

        Returns:
            pd.DataFrame: Columns "Rang", "ID" and "Score" (empty when there
            are no rankings yet), or a single "Erreur" column on failure.
        """
        if self.arena is None:
            return pd.DataFrame([{"Erreur": "Aucune arène initialisée"}])

        column_names = {"rank": "Rang", "prompt_id": "ID", "score": "Score"}
        try:
            df = pd.DataFrame(self.arena.get_rankings())
            # An empty ranking has no columns; selecting them would raise
            # KeyError and be shown as an error instead of an empty table.
            if df.empty:
                return pd.DataFrame(columns=list(column_names.values()))
            df = df[list(column_names)]
            return df.rename(columns=column_names)
        except Exception as e:
            return pd.DataFrame([{"Erreur": str(e)}])

    def create_ui(self) -> gr.Blocks:
        """
        Build the Gradio user interface.

        Returns:
            gr.Blocks: The configured Gradio application.
        """
        with gr.Blocks(title="Prompt Arena", theme=gr.themes.Ocean()) as app:
            gr.Markdown('<h1 style="text-align:center;">🥊 Prompt Arena 🥊</h1>')

            with gr.Row():
                select_btn = gr.Button("Lancer un nouveau match", variant="primary")

            with gr.Row():
                proposition_a = gr.Textbox(label=LABEL_A, interactive=False)
                proposition_b = gr.Textbox(label=LABEL_B, interactive=False)

            with gr.Row():
                vote_a_btn = gr.Button("Choisir " + LABEL_A, interactive=False)
                vote_b_btn = gr.Button("Choisir " + LABEL_B, interactive=False)

            # The first positional argument of gr.Textbox is its value, so
            # the caption must be passed as `label=`.
            result = gr.Textbox(label="Résultat", interactive=False)

            # Components are laid out side by side by creating them inside a
            # Row context; gr.Row does not accept a list of components.
            with gr.Row():
                progress_info = gr.Textbox(
                    label="Progression du concours", interactive=False
                )
                rankings_table = gr.DataFrame(label="Classement des prompts")

            state = gr.State()

            button_outputs = [vote_a_btn, vote_b_btn, select_btn]
            vote_outputs = [result, progress_info, rankings_table, *button_outputs]

            select_btn.click(
                self.select_and_display_match,
                inputs=[],
                outputs=[proposition_a, proposition_b, state, *button_outputs],
            )
            vote_a_btn.click(
                self.record_winner_a,
                inputs=[state],
                outputs=vote_outputs,
            )
            vote_b_btn.click(
                self.record_winner_b,
                inputs=[state],
                outputs=vote_outputs,
            )

        return app
|
|
|
|
|
|
|
def load_prompts(csv_path: str = "prompts.csv") -> list[str]:
    """
    Read prompts from the first column of a header-less CSV file.

    Cells are read as text, so a prompt such as ``42`` stays the string
    ``"42"``. Rows whose first cell is empty are skipped rather than turned
    into ``nan`` prompts.

    Args:
        csv_path: Path to the CSV file.

    Returns:
        list[str]: The prompt texts, in file order.
    """
    first_column = pd.read_csv(csv_path, header=None, dtype=str).iloc[:, 0]
    return first_column.dropna().tolist()


if __name__ == "__main__":
    prompts = load_prompts("prompts.csv")
    arena = PromptArena(prompts=prompts)
    app_instance = PromptArenaApp(arena=arena)
    app = app_instance.create_ui()
    app.launch()
|
|