Arena / plot.py
FredOru's picture
v0.2
9565067
raw
history blame
2.16 kB
import db
import numpy as np
import plotly.graph_objects as go
def plot_estimates_distribution():
"""Affiche une gaussienne par prompt (Plotly) + lignes verticales pointillées sur les moyennes."""
estimates = db.load("estimates")
prompts = db.load("prompts")
if estimates.empty or prompts.empty:
fig = go.Figure()
fig.add_annotation(
text="Aucune estimation disponible", x=0.5, y=0.5, showarrow=False
)
return fig
x = np.linspace(
estimates["mu"].min() - 3 * estimates["sigma"].max(),
estimates["mu"].max() + 3 * estimates["sigma"].max(),
500,
)
fig = go.Figure()
shapes = []
# Une gaussienne par prompt
for _, row in estimates.iterrows():
mu = row["mu"]
sigma = row["sigma"]
prompt_id = row["prompt_id"] if "prompt_id" in row else row["id"]
# Chercher le nom du prompt
name = str(prompt_id)
if "name" in prompts.columns:
match = prompts[prompts["id"] == prompt_id]
if not match.empty:
name = match.iloc[0]["name"]
y = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-0.5 * ((x - mu) / sigma) ** 2)
fig.add_trace(
go.Scatter(
x=x,
y=y,
mode="lines",
name=f"{name}",
hovertemplate=f"<b>{name}</b><br>Score (mu): {mu:.2f}<br>Sigma: {sigma:.2f}<extra></extra>",
)
)
# Ajout de la ligne verticale pointillée à mu (en gris)
shapes.append(
dict(
type="line",
x0=mu,
x1=mu,
y0=0,
y1=max(y),
line=dict(
color="gray",
width=2,
dash="dot",
),
xref="x",
yref="y",
)
)
fig.update_layout(
title="Distribution gaussienne de chaque prompt",
xaxis_title="Score (mu)",
yaxis_title="Densité",
template="plotly_white",
shapes=shapes,
)
return fig