Arena / app.py
FredOru's picture
(feat) add prompt
0fab24c
raw
history blame
8.54 kB
import gradio as gr
import pandas as pd
import os
import time
from threading import Thread
from arena import PromptArena
LABEL_A = "Proposition A"
LABEL_B = "Proposition B"
class PromptArenaApp:
"""
Classe pour encapsuler l'arène et gérer l'interface Gradio.
"""
def __init__(self, arena: PromptArena) -> None:
"""
Initialise l'application et charge les prompts depuis le fichier CSV.
"""
self.arena: PromptArena = arena
def select_and_display_match(self):
"""
Sélectionne un match et l'affiche.
Returns:
Tuple contenant:
- Le texte du premier prompt
- Le texte du second prompt
- Un dictionnaire d'état contenant les IDs des prompts
"""
try:
prompt_a_id, prompt_b_id = self.arena.select_match()
prompt_a_text = self.arena.prompts.get(prompt_a_id, "")
prompt_b_text = self.arena.prompts.get(prompt_b_id, "")
state = {"prompt_a_id": prompt_a_id, "prompt_b_id": prompt_b_id}
return (
prompt_a_text,
prompt_b_text,
state,
gr.update(interactive=True), # button A
gr.update(interactive=True), # button B
gr.update(interactive=False), # match button
)
except Exception as e:
return f"Erreur lors de la sélection d'un match: {str(e)}", "", "", {}
def record_winner_a(self, state: dict[str, str]):
try:
prompt_a_id = state["prompt_a_id"]
prompt_b_id = state["prompt_b_id"]
self.arena.record_result(
prompt_a_id, prompt_b_id
) # Mettre à jour la progression et le classement
progress_info = self.get_progress_info()
rankings_table = self.get_rankings_table()
return (
f"Vous avez choisi : {LABEL_A}",
progress_info,
rankings_table,
gr.update(interactive=False), # button A
gr.update(interactive=False), # button B
gr.update(interactive=True), # match button
)
except Exception as e:
return (
f"Erreur lors de l'enregistrement du résultat: {str(e)}",
"",
pd.DataFrame(),
)
def record_winner_b(self, state: dict[str, str]):
try:
prompt_a_id = state["prompt_a_id"]
prompt_b_id = state["prompt_b_id"]
self.arena.record_result(
prompt_b_id, prompt_a_id
) # Mettre à jour la progression et le classement
progress_info = self.get_progress_info()
rankings_table = self.get_rankings_table()
return (
f"Vous avez choisi : {LABEL_B}",
progress_info,
rankings_table,
gr.update(interactive=False), # button A
gr.update(interactive=False), # button B
gr.update(interactive=True), # match button
)
except Exception as e:
return (
f"Erreur lors de l'enregistrement du résultat: {str(e)}",
"",
pd.DataFrame(),
)
def get_progress_info(self) -> str:
"""
Obtient les informations sur la progression du tournoi.
Returns:
str: Message formaté contenant les statistiques de progression
"""
if not self.arena:
return "Aucune arène initialisée. Veuillez d'abord charger des prompts."
try:
progress = self.arena.get_progress()
info = f"Prompts: {progress['total_prompts']}\n"
info += f"Matchs joués: {progress['total_matches']}\n"
info += f"Progression: {progress['progress']:.2f}%\n"
info += (
f"Matchs restants estimés: {progress['estimated_remaining_matches']}\n"
)
info += f"Incertitude moyenne (σ): {progress['avg_sigma']:.4f}"
return info
except Exception as e:
return f"Erreur lors de la récupération de la progression: {str(e)}"
def get_rankings_table(self) -> pd.DataFrame:
"""
Obtient le classement des prompts sous forme de tableau.
Returns:
pd.DataFrame: Tableau de classement des prompts
"""
if not self.arena:
return pd.DataFrame([{"Erreur": "Aucune arène initialisée"}])
try:
rankings = self.arena.get_rankings()
df = pd.DataFrame(rankings)
df = df[["rank", "prompt_id", "score"]]
df = df.rename(
columns={
"rank": "Rang",
"prompt_id": "ID",
"score": "Score",
}
)
return df
except Exception as e:
return pd.DataFrame([{"Erreur": str(e)}])
def create_ui(self) -> gr.Blocks:
"""
Crée l'interface utilisateur Gradio.
Returns:
gr.Blocks: L'application Gradio configurée
"""
with gr.Blocks(title="Prompt Arena", theme=gr.themes.Ocean()) as app:
gr.Markdown('<h1 style="text-align:center;">🥊 Prompt Arena 🥊</h1>')
with gr.Row():
select_btn = gr.Button("Lancer un nouveau match", variant="primary")
with gr.Row():
proposition_a = gr.Textbox(label=LABEL_A, interactive=False)
proposition_b = gr.Textbox(label=LABEL_B, interactive=False)
with gr.Row():
vote_a_btn = gr.Button("Choisir " + LABEL_A, interactive=False)
vote_b_btn = gr.Button("Choisir " + LABEL_B, interactive=False)
result = gr.Textbox("Résultat", interactive=False)
progress_info = gr.Textbox(
label="Progression du concours", interactive=False
)
rankings_table = gr.DataFrame(label="Classement des prompts")
state = gr.State() # contient les IDs des prompts du match en cours
select_btn.click(
self.select_and_display_match,
inputs=[],
outputs=[
proposition_a,
proposition_b,
state,
vote_a_btn,
vote_b_btn,
select_btn,
],
)
vote_a_btn.click(
self.record_winner_a,
inputs=[state],
outputs=[
result,
progress_info,
rankings_table,
vote_a_btn,
vote_b_btn,
select_btn,
],
)
vote_b_btn.click(
self.record_winner_b,
inputs=[state],
outputs=[
result,
progress_info,
rankings_table,
vote_a_btn,
vote_b_btn,
select_btn,
],
)
with gr.Row():
prompt_id_box = gr.Textbox(label="ID du prompt", interactive=True)
prompt_text_box = gr.Textbox(label="Texte du prompt", interactive=True)
save_btn = gr.Button("Ajouter le prompt", variant="secondary")
def save_prompt(prompt_id, prompt_text):
try:
self.arena.add_prompt(prompt_id, prompt_text)
return gr.update(value="Prompt ajouté avec succès !")
except Exception as e:
return gr.update(value=f"Erreur: {str(e)}")
save_result = gr.Textbox(label="Résultat de l'ajout", interactive=False)
save_btn.click(
save_prompt,
inputs=[prompt_id_box, prompt_text_box],
outputs=save_result,
)
gr.Row([progress_info, rankings_table])
return app
# Exemple d'utilisation
if __name__ == "__main__":
# load the prompts from the CSV file
prompts = pd.read_csv("prompts.csv", header=None).iloc[:, 0].tolist()
arena = PromptArena(prompts=prompts)
app_instance = PromptArenaApp(arena=arena)
app = app_instance.create_ui()
app.launch()