Arena / app.py
FredOru's picture
(feat) db csv et admin panels
1bc7a0c
raw
history blame
9.22 kB
import hmac
import os

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objs as go

from arena import Arena
# User-facing names of the two sides of a match.
LABEL_A, LABEL_B = "Proposition A", "Proposition B"

# Single arena instance shared by every callback of the app; its estimates
# are initialised once, at import time.
ARENA = Arena()
ARENA.init_estimates()
def select_and_display_match():
    """Draw a new pair of prompts from the arena and display them side by side.

    Returns:
        A 6-tuple matching the callback's outputs: (proposition A text,
        proposition B text, match state, vote A button update, vote B button
        update, new-match button update). On success the state holds both
        prompt ids, the vote buttons are enabled and the new-match button is
        disabled until a vote is cast.
    """
    try:
        prompt_a, prompt_b = ARENA.select_match()
        state = {"prompt_a_id": prompt_a["id"], "prompt_b_id": prompt_b["id"]}
        text_a, text_b = prompt_a["text"], prompt_b["text"]
    except Exception as e:
        # Gradio requires exactly one value per output component (6 here).
        # No match is in progress: clear the state, keep voting disabled and
        # leave the new-match button enabled so the user can retry.
        return (
            f"Erreur lors de la sélection d'un match: {str(e)}",
            "",
            None,
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=True),
        )
    return (
        text_a,
        text_b,
        state,
        gr.update(interactive=True),
        gr.update(interactive=True),
        gr.update(interactive=False),
    )
def _record_vote(state, winner_key, loser_key, label):
    """Record one vote in the arena and build the outputs of the vote callbacks.

    Shared implementation of ``record_winner_a`` and ``record_winner_b``, which
    only differ in which prompt of the current match wins.

    Args:
        state: match state produced by ``select_and_display_match``; must hold
            the keys ``"prompt_a_id"`` and ``"prompt_b_id"``.
        winner_key: key in *state* holding the winning prompt's id.
        loser_key: key in *state* holding the losing prompt's id.
        label: user-facing label of the chosen proposition.

    Returns:
        A 6-tuple matching the callbacks' outputs: (result message, progress
        text, rankings table, vote A button update, vote B button update,
        new-match button update).
    """
    try:
        winner_id = state[winner_key]
        loser_id = state[loser_key]
        ARENA.record_result(winner_id, loser_id)
        progress_info = ARENA.get_progress()
        rankings_table = ARENA.get_rankings()
    except Exception as e:
        # Gradio requires exactly one value per output component (6 here).
        # gr.update() with no arguments leaves a component unchanged, so the
        # current progress and rankings stay visible instead of being wiped.
        # The vote buttons are left as they are, and the new-match button is
        # enabled so the user is never stuck.
        return (
            f"Erreur lors de l'enregistrement du résultat: {str(e)}",
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(interactive=True),
        )
    # Vote recorded: lock voting until a new match is started.
    return (
        f"Vous avez choisi : {label}",
        progress_info,
        rankings_table,
        gr.update(interactive=False),
        gr.update(interactive=False),
        gr.update(interactive=True),
    )


def record_winner_a(state):
    """Record proposition A as the winner of the current match."""
    return _record_vote(state, "prompt_a_id", "prompt_b_id", LABEL_A)


def record_winner_b(state):
    """Record proposition B as the winner of the current match."""
    return _record_vote(state, "prompt_b_id", "prompt_a_id", LABEL_B)
def update_table(table_name, df, request: gr.Request = None):
    """Overwrite the stored table *table_name* with the DataFrame edited in the admin panel.

    Storage is delegated to ``ARENA.replace``; presumably one CSV file per
    table (TODO confirm in ``arena``).

    Args:
        table_name: name of the table to replace (e.g. ``"prompts"``).
        df: the edited DataFrame coming from the admin panel.
        request: injected by Gradio thanks to the ``gr.Request`` annotation.
            When present, only the ``admin`` user may write. Direct Python
            calls (no request) keep the previous unrestricted behaviour.

    Returns:
        None (the event has no outputs).
    """
    # The Admin tab is only hidden client-side (see admin_visible); hiding a
    # component does not disable its event handlers, so enforce the check on
    # the server. Non-admin writes are ignored rather than raising, so no
    # error popup can reach regular contestants.
    if request is not None and request.username != "admin":
        return None
    ARENA.replace(table_name, df)
    return None
def admin_visible(request: gr.Request):
    """Return a visibility update that shows the target only to the ``admin`` user."""
    return gr.update(visible=(request.username == "admin"))
def welcome_user(request: gr.Request):
    """Return the name of the logged-in user (shown in the page header)."""
    current_user = request.username
    return current_user
def _annotated_empty_figure(message):
    """Return an empty Plotly figure whose only content is *message*."""
    fig = go.Figure()
    fig.add_annotation(text=message, x=0.5, y=0.5, showarrow=False)
    return fig


def plot_estimates_distribution():
    """Plot one Gaussian per prompt (Plotly), with a dotted vertical line at each mean.

    Reads the ``estimates`` table (columns ``mu``, ``sigma`` and ``prompt_id``,
    or ``id`` when ``prompt_id`` is absent) and the ``prompts`` table (column
    ``id`` and, optionally, ``name``, used as the trace label; the prompt id is
    used otherwise).

    Returns:
        A ``go.Figure`` with one density curve per plottable estimate, or a
        figure holding only an annotation when there is nothing to plot.
    """
    estimates = ARENA.load("estimates")
    prompts = ARENA.load("prompts")
    if estimates.empty or prompts.empty:
        return _annotated_empty_figure("Aucune estimation disponible")

    # A Gaussian needs a finite mu and a finite, strictly positive sigma.
    # Other rows (e.g. a value cleared or zeroed in the admin table) would
    # produce inf/NaN curves and a NaN x-range, so they are skipped.
    mus = estimates["mu"]
    sigmas = estimates["sigma"]
    valid = estimates[np.isfinite(mus) & np.isfinite(sigmas) & (sigmas > 0)]
    if valid.empty:
        return _annotated_empty_figure("Aucune estimation disponible")

    # Column holding the prompt id in the estimates table.
    id_col = "prompt_id" if "prompt_id" in valid.columns else "id"
    # id -> display name, built once instead of filtering `prompts` per row.
    # drop_duplicates keeps the first row per id, like the former `.iloc[0]`.
    names = {}
    if "name" in prompts.columns:
        names = prompts.drop_duplicates("id").set_index("id")["name"].to_dict()

    sigma_max = valid["sigma"].max()
    # Grid covering every mean +/- 3 of the widest sigma.
    x = np.linspace(
        valid["mu"].min() - 3 * sigma_max,
        valid["mu"].max() + 3 * sigma_max,
        500,
    )
    sqrt_two_pi = np.sqrt(2 * np.pi)

    fig = go.Figure()
    shapes = []
    # One Gaussian per prompt
    for prompt_id, mu, sigma in zip(valid[id_col], valid["mu"], valid["sigma"]):
        name = names.get(prompt_id, str(prompt_id))
        y = np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * sqrt_two_pi)
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="lines",
                name=f"{name}",
                hovertemplate=f"<b>{name}</b><br>Score (mu): {mu:.2f}<br>Sigma: {sigma:.2f}<extra></extra>",
            )
        )
        # Dotted grey vertical line at mu, from 0 up to the curve's peak.
        shapes.append(
            dict(
                type="line",
                x0=mu,
                x1=mu,
                y0=0,
                y1=y.max(),
                line=dict(
                    color="gray",
                    width=2,
                    dash="dot",
                ),
                xref="x",
                yref="y",
            )
        )
    fig.update_layout(
        title="Distribution gaussienne de chaque prompt",
        xaxis_title="Score (mu)",
        yaxis_title="Densité",
        template="plotly_white",
        shapes=shapes,
    )
    return fig
with gr.Blocks(
    title="Prompt Arena",
    # theme=gr.themes.Default.load("theme_schema_miku.json"),
) as demo:
    # Per-session match state: ids of the two prompts currently displayed.
    state = gr.State()

    # Header: logged-in user name (filled by welcome_user) and logout button.
    with gr.Row():
        username = gr.Markdown("")
        gr.Button("Logout", link="/logout", scale=0, min_width=50)
    gr.Markdown(
        '<h1 style="text-align:center;"> Concours du meilleur Prompt Engineer </h1>'
    )
    progress_info = gr.Textbox(
        label="Progression du concours",
        value=ARENA.get_progress(),
        interactive=False,
        lines=2,
    )

    with gr.Tabs() as tabs:
        # Matches tab
        with gr.TabItem("Combats"):
            with gr.Row():
                new_match_btn = gr.Button("Lancer un nouveau match", variant="primary")
            with gr.Row():
                with gr.Column():
                    proposition_a = gr.Textbox(label=LABEL_A, interactive=False)
                    vote_a_btn = gr.Button("Choisir " + LABEL_A, interactive=False)
                with gr.Column():
                    proposition_b = gr.Textbox(label=LABEL_B, interactive=False)
                    vote_b_btn = gr.Button("Choisir " + LABEL_B, interactive=False)
            # "Résultat" must be passed as the label: Textbox's first
            # positional argument is its value.
            result = gr.Textbox(label="Résultat", interactive=False)
            # with gr.TabItem("Classement"):
            # Read-only for contestants: edits made here were never persisted.
            rankings_table = gr.DataFrame(
                label="Classement des prompts",
                value=ARENA.get_rankings(),
                interactive=False,
            )

        # Admin tab (hidden for non-admin users by admin_visible on page load).
        # NOTE(review): these tables are built for every session; hiding the tab
        # presumably does not keep their values from being sent to non-admin
        # clients — confirm whether that is acceptable.
        with gr.TabItem("Admin") as admin_tab:
            with gr.Accordion("Prompts", open=False):
                prompts_table = gr.DataFrame(
                    value=ARENA.load("prompts"),
                    interactive=True,
                )
            with gr.Accordion("Estimates", open=False):
                estimates_table = gr.DataFrame(
                    label="Estimations",
                    value=ARENA.load("estimates"),
                    interactive=True,
                )
            with gr.Accordion("Votes", open=False):
                votes_table = gr.DataFrame(
                    label="Votes",
                    value=ARENA.load("votes"),
                    interactive=True,
                )
            gr.Plot(plot_estimates_distribution, label="Distribution des estimations")

            # Persist admin edits; gr.State passes the constant table name to
            # update_table without rendering a hidden component.
            prompts_table.change(
                update_table,
                inputs=[gr.State("prompts"), prompts_table],
                outputs=None,
            )
            estimates_table.change(
                update_table,
                inputs=[gr.State("estimates"), estimates_table],
                outputs=None,
            )
            votes_table.change(
                update_table,
                inputs=[gr.State("votes"), votes_table],
                outputs=None,
            )

    new_match_btn.click(
        select_and_display_match,
        inputs=[],
        outputs=[
            proposition_a,
            proposition_b,
            state,
            vote_a_btn,
            vote_b_btn,
            new_match_btn,
        ],
    )
    # Vote callbacks: record the winner, refresh progress and rankings, and
    # toggle the buttons back to "start a new match".
    vote_a_btn.click(
        record_winner_a,
        inputs=[state],
        outputs=[
            result,
            progress_info,
            rankings_table,
            vote_a_btn,
            vote_b_btn,
            new_match_btn,
        ],
    )
    vote_b_btn.click(
        record_winner_b,
        inputs=[state],
        outputs=[
            result,
            progress_info,
            rankings_table,
            vote_a_btn,
            vote_b_btn,
            new_match_btn,
        ],
    )
    # On page load: show the Admin tab only to the admin, display the user name.
    demo.load(admin_visible, None, admin_tab)
    demo.load(welcome_user, None, username)
def _constant_time_equals(left, right):
    """Compare two strings in constant time; ``False`` unless both are ``str``.

    Strings are UTF-8 encoded first because ``hmac.compare_digest`` rejects
    ``str`` arguments containing non-ASCII characters.
    """
    if not isinstance(left, str) or not isinstance(right, str):
        return False
    return hmac.compare_digest(left.encode("utf-8"), right.encode("utf-8"))


def arena_auth(username, password):
    """Check login credentials for the Gradio auth form.

    The ``admin`` account's password is read from the ``ARENA_ADMIN_PASSWORD``
    environment variable on each call. It falls back to ``"fred"`` so existing
    deployments keep working; set the variable in production. Every other
    user is accepted when the password equals the username.

    Args:
        username: login entered in the form.
        password: password entered in the form.

    Returns:
        bool: True when the credentials are accepted.
    """
    if username == "admin":
        expected = os.environ.get("ARENA_ADMIN_PASSWORD", "fred")
    else:
        # NOTE(review): anyone can log in under any non-admin name by typing
        # it twice — confirm this is the intended contest policy.
        expected = username
    return _constant_time_equals(password, expected)
# Entry point: serve the app behind the login form checked by arena_auth.
if __name__ == "__main__":
    launch_options = {
        "auth_message": "Connexion à l'arène des prompts",
        "auth": arena_auth,
    }
    # Add "share": True to launch_options to get a public sharing link.
    demo.launch(**launch_options)