|
import hmac
import os

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objs as go

from arena import Arena
|
|
|
# Display labels for the two competing propositions (textboxes, buttons, results).
LABEL_A, LABEL_B = "Proposition A", "Proposition B"

# Single Arena instance shared by every callback in this module.
# init_estimates() is called once, at import time, before the UI is built.
ARENA = Arena()
ARENA.init_estimates()
|
|
|
|
|
def select_and_display_match():
    """Pick a new pair of prompts and display them side by side.

    Returns:
        A 6-tuple matching the outputs wired to ``new_match_btn.click``:
        (text of prompt A, text of prompt B, session state,
        update for vote-A button, update for vote-B button,
        update for the new-match button).

        On success the state holds ``{"prompt_a_id", "prompt_b_id"}``, both
        vote buttons are enabled and the new-match button is disabled until a
        vote is cast.  If anything fails, the error message is shown in the
        first textbox, the state is cleared, the vote buttons are disabled and
        the new-match button stays enabled so the user can retry.
    """
    try:
        prompt_a, prompt_b = ARENA.select_match()
        state = {"prompt_a_id": prompt_a["id"], "prompt_b_id": prompt_b["id"]}
        text_a, text_b = prompt_a["text"], prompt_b["text"]
    except Exception as e:
        # The event has six outputs, so the error path must also return six
        # values; returning fewer makes Gradio fail on the error itself.
        return (
            f"Erreur lors de la sélection d'un match: {str(e)}",
            "",
            None,
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=True),
        )

    return (
        text_a,
        text_b,
        state,
        gr.update(interactive=True),
        gr.update(interactive=True),
        gr.update(interactive=False),
    )
|
|
|
|
|
def _record_vote(state, winner_key, loser_key, label):
    """Record one vote and build the six UI outputs shared by both vote buttons.

    Args:
        state: session dict holding ``prompt_a_id`` / ``prompt_b_id``
            (``None`` if no match was started).
        winner_key: state key of the chosen prompt's id.
        loser_key: state key of the other prompt's id.
        label: proposition label shown in the confirmation message.

    Returns:
        (result message, progress text, rankings table,
        vote-A button update, vote-B button update, new-match button update).
    """
    try:
        winner_id = state[winner_key]
        loser_id = state[loser_key]
        # record_result is called with the chosen prompt's id first.
        ARENA.record_result(winner_id, loser_id)
        progress = ARENA.get_progress()
        rankings = ARENA.get_rankings()
    except Exception as e:
        # Six outputs are wired to the vote buttons, so the error path must
        # return six values.  Progress and rankings are left untouched instead
        # of being blanked.  The vote buttons are disabled so a partially
        # recorded vote cannot be submitted twice, and "new match" is
        # re-enabled so the UI never gets stuck.
        return (
            f"Erreur lors de l'enregistrement du résultat: {str(e)}",
            gr.update(),
            gr.update(),
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=True),
        )

    return (
        f"Vous avez choisi : {label}",
        progress,
        rankings,
        gr.update(interactive=False),
        gr.update(interactive=False),
        gr.update(interactive=True),
    )


def record_winner_a(state):
    """Record a vote for Proposition A (prompt A beats prompt B)."""
    return _record_vote(state, "prompt_a_id", "prompt_b_id", LABEL_A)


def record_winner_b(state):
    """Record a vote for Proposition B (prompt B beats prompt A)."""
    return _record_vote(state, "prompt_b_id", "prompt_a_id", LABEL_B)
|
|
|
|
|
def update_table(table_name, df, request: gr.Request = None):
    """Overwrite the named arena table with a DataFrame edited in the Admin tab.

    Args:
        table_name: name of the table passed to ``ARENA.replace``
            (e.g. "prompts", "estimates", "votes").
        df: the edited DataFrame.
        request: injected by Gradio when called from an event; used to check
            that the caller is the admin.

    Raises:
        gr.Error: if the request comes from a user other than "admin".
    """
    # Hiding the Admin tab is purely client-side: the event endpoints remain
    # callable by any logged-in user, so authorization must be enforced here.
    # When called directly from Python (request is None), no check applies.
    if request is not None and request.username != "admin":
        raise gr.Error("Action réservée à l'administrateur.")
    ARENA.replace(table_name, df)
|
|
|
|
|
def admin_visible(request: gr.Request):
    """Return a Gradio update that shows the target only to the "admin" user."""
    return gr.update(visible=(request.username == "admin"))
|
|
|
|
|
def welcome_user(request: gr.Request):
    """Return the logged-in user's name taken from the Gradio request."""
    user_name = request.username
    return user_name
|
|
|
|
|
def _message_figure(message):
    """Return an empty Plotly figure carrying a single text annotation."""
    fig = go.Figure()
    fig.add_annotation(text=message, x=0.5, y=0.5, showarrow=False)
    return fig


def _prompt_names(prompts):
    """Map prompt id -> display name (first row wins); empty if there is no ``name`` column."""
    if "name" not in prompts.columns:
        return {}
    unique = prompts.drop_duplicates(subset="id", keep="first")
    return dict(zip(unique["id"], unique["name"]))


def _gaussian_pdf(x, mu, sigma):
    """Evaluate the normal density N(mu, sigma^2) at every point of ``x`` (sigma > 0)."""
    return 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-0.5 * ((x - mu) / sigma) ** 2)


def plot_estimates_distribution():
    """Plot one Gaussian curve per prompt, with a dotted vertical line at each mean.

    Reads the ``estimates`` table (columns ``mu``, ``sigma`` and ``prompt_id``
    or ``id``) and, if it has a ``name`` column, the ``prompts`` table to label
    each curve (falling back to the prompt id).  Rows whose ``mu``/``sigma``
    are not finite, or whose ``sigma`` is not strictly positive, are skipped:
    no density can be drawn for them, and they used to produce NaN curves.

    Returns:
        A Plotly figure, or a placeholder figure with a message when there is
        nothing to draw.
    """
    estimates = ARENA.load("estimates")
    prompts = ARENA.load("prompts")
    if estimates.empty or prompts.empty:
        return _message_figure("Aucune estimation disponible")

    mu = estimates["mu"].astype(float)
    sigma = estimates["sigma"].astype(float)
    keep = np.isfinite(mu) & np.isfinite(sigma) & (sigma > 0)
    if not keep.any():
        return _message_figure("Aucune estimation disponible")

    id_col = "prompt_id" if "prompt_id" in estimates.columns else "id"
    ids = estimates.loc[keep, id_col]
    mu, sigma = mu[keep], sigma[keep]

    # Built once instead of filtering the prompts table for every estimate row.
    names = _prompt_names(prompts)

    # Common x grid wide enough to show +/- 3 sigma around every mean.
    widest = sigma.max()
    x = np.linspace(mu.min() - 3 * widest, mu.max() + 3 * widest, 500)

    fig = go.Figure()
    shapes = []
    for prompt_id, m, s in zip(ids, mu, sigma):
        name = names.get(prompt_id, str(prompt_id))
        y = _gaussian_pdf(x, m, s)
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="lines",
                name=f"{name}",
                hovertemplate=f"<b>{name}</b><br>Score (mu): {m:.2f}<br>Sigma: {s:.2f}<extra></extra>",
            )
        )
        # Dotted mean marker, running from 0 up to the top of the drawn curve.
        shapes.append(
            dict(
                type="line",
                x0=m,
                x1=m,
                y0=0,
                y1=float(np.max(y)),
                line=dict(color="gray", width=2, dash="dot"),
                xref="x",
                yref="y",
            )
        )

    fig.update_layout(
        title="Distribution gaussienne de chaque prompt",
        xaxis_title="Score (mu)",
        yaxis_title="Densité",
        template="plotly_white",
        shapes=shapes,
    )
    return fig
|
|
|
|
|
with gr.Blocks(title="Prompt Arena") as demo:
    # Per-session storage of the ids of the two prompts currently competing.
    state = gr.State()

    with gr.Row():
        username = gr.Markdown("")
        gr.Button("Logout", link="/logout", scale=0, min_width=50)

    gr.Markdown(
        '<h1 style="text-align:center;"> Concours du meilleur Prompt Engineer </h1>'
    )

    # Passing the method itself (not its result) makes Gradio re-evaluate it on
    # every page load; a precomputed value would be frozen at server start.
    progress_info = gr.Textbox(
        label="Progression du concours",
        value=ARENA.get_progress,
        interactive=False,
        lines=2,
    )

    with gr.Tabs():
        with gr.TabItem("Combats"):
            with gr.Row():
                new_match_btn = gr.Button("Lancer un nouveau match", variant="primary")

            with gr.Row():
                with gr.Column():
                    proposition_a = gr.Textbox(label=LABEL_A, interactive=False)
                    vote_a_btn = gr.Button("Choisir " + LABEL_A, interactive=False)
                with gr.Column():
                    proposition_b = gr.Textbox(label=LABEL_B, interactive=False)
                    vote_b_btn = gr.Button("Choisir " + LABEL_B, interactive=False)

            # The first positional argument of gr.Textbox is its value, so
            # "Résultat" must be passed as the label explicitly.
            result = gr.Textbox(label="Résultat", interactive=False)

            # Read-only: edits made here were never persisted anywhere.
            rankings_table = gr.DataFrame(
                label="Classement des prompts",
                value=ARENA.get_rankings,
                interactive=False,
            )

        with gr.TabItem("Admin") as admin_tab:
            # Callables so each page load shows the current tables rather than
            # a startup snapshot (editing a stale snapshot would overwrite
            # newer data when saved).
            with gr.Accordion("Prompts", open=False):
                prompts_table = gr.DataFrame(
                    value=lambda: ARENA.load("prompts"),
                    interactive=True,
                )
            with gr.Accordion("Estimates", open=False):
                estimates_table = gr.DataFrame(
                    label="Estimations",
                    value=lambda: ARENA.load("estimates"),
                    interactive=True,
                )
            with gr.Accordion("Votes", open=False):
                votes_table = gr.DataFrame(
                    label="Votes",
                    value=lambda: ARENA.load("votes"),
                    interactive=True,
                )
            gr.Plot(plot_estimates_distribution, label="Distribution des estimations")

            # `.input` fires only on user edits, so refreshing a table's value
            # never writes it back.  gr.State carries the constant table name.
            for table_name, table in (
                ("prompts", prompts_table),
                ("estimates", estimates_table),
                ("votes", votes_table),
            ):
                table.input(
                    update_table,
                    inputs=[gr.State(table_name), table],
                    outputs=None,
                )

    new_match_btn.click(
        select_and_display_match,
        inputs=[],
        outputs=[
            proposition_a,
            proposition_b,
            state,
            vote_a_btn,
            vote_b_btn,
            new_match_btn,
        ],
    )

    # Both vote buttons update the same six components, in the same order.
    vote_outputs = [
        result,
        progress_info,
        rankings_table,
        vote_a_btn,
        vote_b_btn,
        new_match_btn,
    ]
    vote_a_btn.click(record_winner_a, inputs=[state], outputs=vote_outputs)
    vote_b_btn.click(record_winner_b, inputs=[state], outputs=vote_outputs)

    demo.load(admin_visible, None, admin_tab)
    demo.load(welcome_user, None, username)
|
|
|
|
|
def arena_auth(username, password):
    """Gradio login check for the arena.

    The "admin" account's password comes from the ``ARENA_ADMIN_PASSWORD``
    environment variable. It falls back to the previous hardcoded value so
    existing deployments keep working. Any other user logs in by entering
    their username as the password. Missing or empty credentials are rejected.

    Args:
        username: login name typed by the user.
        password: password typed by the user.

    Returns:
        True if the credentials are accepted, False otherwise.
    """
    if not isinstance(username, str) or not isinstance(password, str) or not username:
        return False
    if username == "admin":
        expected = os.environ.get("ARENA_ADMIN_PASSWORD", "fred")
    else:
        expected = username
    # Constant-time comparison avoids timing side channels.  Encode first:
    # compare_digest rejects non-ASCII str arguments.
    return hmac.compare_digest(password.encode("utf-8"), expected.encode("utf-8"))
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Every page sits behind Gradio's login form; arena_auth decides who gets in.
    demo.launch(auth=arena_auth, auth_message="Connexion à l'arène des prompts")
|
|