import gradio as gr
import pandas as pd
from arena import Arena
import plotly.graph_objs as go
import numpy as np
LABEL_A = "Proposition A"
LABEL_B = "Proposition B"
ARENA = Arena()
ARENA.init_estimates()
def select_and_display_match():
try:
prompt_a, prompt_b = ARENA.select_match()
state = {"prompt_a_id": prompt_a["id"], "prompt_b_id": prompt_b["id"]}
vote_a_btn_update = gr.update(interactive=True)
vote_b_btn_update = gr.update(interactive=True)
new_match_btn_update = gr.update(interactive=False)
return (
prompt_a["text"],
prompt_b["text"],
state,
vote_a_btn_update,
vote_b_btn_update,
new_match_btn_update,
)
except Exception as e:
return f"Erreur lors de la sélection d'un match: {str(e)}", "", "", {}
def record_winner_a(state):
try:
prompt_a_id = state["prompt_a_id"]
prompt_b_id = state["prompt_b_id"]
ARENA.record_result(prompt_a_id, prompt_b_id)
progress_info = ARENA.get_progress()
rankings_table = ARENA.get_rankings()
vote_a_btn_update = gr.update(interactive=False)
vote_b_btn_update = gr.update(interactive=False)
new_match_btn_update = gr.update(interactive=True)
return (
f"Vous avez choisi : {LABEL_A}",
progress_info,
rankings_table,
vote_a_btn_update,
vote_b_btn_update,
new_match_btn_update,
)
except Exception as e:
return (
f"Erreur lors de l'enregistrement du résultat: {str(e)}",
"",
pd.DataFrame(),
)
def record_winner_b(state):
try:
prompt_a_id = state["prompt_a_id"]
prompt_b_id = state["prompt_b_id"]
ARENA.record_result(prompt_b_id, prompt_a_id)
progress_info = ARENA.get_progress()
rankings_table = ARENA.get_rankings()
vote_a_btn_update = gr.update(interactive=False)
vote_b_btn_update = gr.update(interactive=False)
new_match_btn_update = gr.update(interactive=True)
return (
f"Vous avez choisi : {LABEL_B}",
progress_info,
rankings_table,
vote_a_btn_update,
vote_b_btn_update,
new_match_btn_update,
)
except Exception as e:
return (
f"Erreur lors de l'enregistrement du résultat: {str(e)}",
"",
pd.DataFrame(),
)
def update_table(table_name, df):
"""Met à jour le fichier CSV de la table spécifiée à partir du DataFrame édité."""
ARENA.replace(table_name, df)
return None
def admin_visible(request: gr.Request):
is_admin = request.username == "admin"
return gr.update(visible=is_admin)
def welcome_user(request: gr.Request):
return request.username
def plot_estimates_distribution():
"""Affiche une gaussienne par prompt (Plotly) + lignes verticales pointillées sur les moyennes."""
estimates = ARENA.load("estimates")
prompts = ARENA.load("prompts")
if estimates.empty or prompts.empty:
fig = go.Figure()
fig.add_annotation(
text="Aucune estimation disponible", x=0.5, y=0.5, showarrow=False
)
return fig
x = np.linspace(
estimates["mu"].min() - 3 * estimates["sigma"].max(),
estimates["mu"].max() + 3 * estimates["sigma"].max(),
500,
)
fig = go.Figure()
shapes = []
# Une gaussienne par prompt
for _, row in estimates.iterrows():
mu = row["mu"]
sigma = row["sigma"]
prompt_id = row["prompt_id"] if "prompt_id" in row else row["id"]
# Chercher le nom du prompt
name = str(prompt_id)
if "name" in prompts.columns:
match = prompts[prompts["id"] == prompt_id]
if not match.empty:
name = match.iloc[0]["name"]
y = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-0.5 * ((x - mu) / sigma) ** 2)
fig.add_trace(
go.Scatter(
x=x,
y=y,
mode="lines",
name=f"{name}",
hovertemplate=f"{name}
Score (mu): {mu:.2f}
Sigma: {sigma:.2f}