Arena / app.py
FredOru's picture
first working draft
330f0b8
raw
history blame contribute delete
7.69 kB
import gradio as gr
import pandas as pd
import os
import time
from threading import Thread
from arena import PromptArena
LABEL_A = "Proposition A"
LABEL_B = "Proposition B"
class PromptArenaApp:
"""
Classe pour encapsuler l'arène et gérer l'interface Gradio.
"""
def __init__(self, arena: PromptArena) -> None:
"""
Initialise l'application et charge les prompts depuis le fichier CSV.
"""
self.arena: PromptArena = arena
def select_and_display_match(self):
"""
Sélectionne un match et l'affiche.
Returns:
Tuple contenant:
- Le texte du premier prompt
- Le texte du second prompt
- Un dictionnaire d'état contenant les IDs des prompts
"""
try:
prompt_a_id, prompt_b_id = self.arena.select_match()
prompt_a_text = self.arena.prompts.get(prompt_a_id, "")
prompt_b_text = self.arena.prompts.get(prompt_b_id, "")
state = {"prompt_a_id": prompt_a_id, "prompt_b_id": prompt_b_id}
return (
prompt_a_text,
prompt_b_text,
state,
gr.update(interactive=True), # button A
gr.update(interactive=True), # button B
gr.update(interactive=False), # match button
)
except Exception as e:
return f"Erreur lors de la sélection d'un match: {str(e)}", "", "", {}
def record_winner_a(self, state: dict[str, str]):
try:
prompt_a_id = state["prompt_a_id"]
prompt_b_id = state["prompt_b_id"]
self.arena.record_result(
prompt_a_id, prompt_b_id
) # Mettre à jour la progression et le classement
progress_info = self.get_progress_info()
rankings_table = self.get_rankings_table()
return (
f"Vous avez choisi : {LABEL_A}",
progress_info,
rankings_table,
gr.update(interactive=False), # button A
gr.update(interactive=False), # button B
gr.update(interactive=True), # match button
)
except Exception as e:
return (
f"Erreur lors de l'enregistrement du résultat: {str(e)}",
"",
pd.DataFrame(),
)
def record_winner_b(self, state: dict[str, str]):
try:
prompt_a_id = state["prompt_a_id"]
prompt_b_id = state["prompt_b_id"]
self.arena.record_result(
prompt_b_id, prompt_a_id
) # Mettre à jour la progression et le classement
progress_info = self.get_progress_info()
rankings_table = self.get_rankings_table()
return (
f"Vous avez choisi : {LABEL_B}",
progress_info,
rankings_table,
gr.update(interactive=False), # button A
gr.update(interactive=False), # button B
gr.update(interactive=True), # match button
)
except Exception as e:
return (
f"Erreur lors de l'enregistrement du résultat: {str(e)}",
"",
pd.DataFrame(),
)
def get_progress_info(self) -> str:
"""
Obtient les informations sur la progression du tournoi.
Returns:
str: Message formaté contenant les statistiques de progression
"""
if not self.arena:
return "Aucune arène initialisée. Veuillez d'abord charger des prompts."
try:
progress = self.arena.get_progress()
info = f"Prompts: {progress['total_prompts']}\n"
info += f"Matchs joués: {progress['total_matches']}\n"
info += f"Progression: {progress['progress']:.2f}%\n"
info += (
f"Matchs restants estimés: {progress['estimated_remaining_matches']}\n"
)
info += f"Incertitude moyenne (σ): {progress['avg_sigma']:.4f}"
return info
except Exception as e:
return f"Erreur lors de la récupération de la progression: {str(e)}"
def get_rankings_table(self) -> pd.DataFrame:
"""
Obtient le classement des prompts sous forme de tableau.
Returns:
pd.DataFrame: Tableau de classement des prompts
"""
if not self.arena:
return pd.DataFrame([{"Erreur": "Aucune arène initialisée"}])
try:
rankings = self.arena.get_rankings()
df = pd.DataFrame(rankings)
df = df[["rank", "prompt_id", "score"]]
df = df.rename(
columns={
"rank": "Rang",
"prompt_id": "ID",
"score": "Score",
}
)
return df
except Exception as e:
return pd.DataFrame([{"Erreur": str(e)}])
def create_ui(self) -> gr.Blocks:
"""
Crée l'interface utilisateur Gradio.
Returns:
gr.Blocks: L'application Gradio configurée
"""
with gr.Blocks(title="Prompt Arena", theme=gr.themes.Ocean()) as app:
gr.Markdown('<h1 style="text-align:center;">🥊 Prompt Arena 🥊</h1>')
with gr.Row():
select_btn = gr.Button("Lancer un nouveau match", variant="primary")
with gr.Row():
proposition_a = gr.Textbox(label=LABEL_A, interactive=False)
proposition_b = gr.Textbox(label=LABEL_B, interactive=False)
with gr.Row():
vote_a_btn = gr.Button("Choisir " + LABEL_A, interactive=False)
vote_b_btn = gr.Button("Choisir " + LABEL_B, interactive=False)
result = gr.Textbox("Résultat", interactive=False)
progress_info = gr.Textbox(
label="Progression du concours", interactive=False
)
rankings_table = gr.DataFrame(label="Classement des prompts")
state = gr.State() # contient les IDs des prompts du match en cours
select_btn.click(
self.select_and_display_match,
inputs=[],
outputs=[
proposition_a,
proposition_b,
state,
vote_a_btn,
vote_b_btn,
select_btn,
],
)
vote_a_btn.click(
self.record_winner_a,
inputs=[state],
outputs=[
result,
progress_info,
rankings_table,
vote_a_btn,
vote_b_btn,
select_btn,
],
)
vote_b_btn.click(
self.record_winner_b,
inputs=[state],
outputs=[
result,
progress_info,
rankings_table,
vote_a_btn,
vote_b_btn,
select_btn,
],
)
gr.Row([progress_info, rankings_table])
return app
# Exemple d'utilisation
if __name__ == "__main__":
# load the prompts from the CSV file
prompts = pd.read_csv("prompts.csv", header=None).iloc[:, 0].tolist()
arena = PromptArena(prompts=prompts)
app_instance = PromptArenaApp(arena=arena)
app = app_instance.create_ui()
app.launch()