File size: 2,156 Bytes
9565067 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import db
import numpy as np
import plotly.graph_objects as go
def plot_estimates_distribution():
"""Affiche une gaussienne par prompt (Plotly) + lignes verticales pointillées sur les moyennes."""
estimates = db.load("estimates")
prompts = db.load("prompts")
if estimates.empty or prompts.empty:
fig = go.Figure()
fig.add_annotation(
text="Aucune estimation disponible", x=0.5, y=0.5, showarrow=False
)
return fig
x = np.linspace(
estimates["mu"].min() - 3 * estimates["sigma"].max(),
estimates["mu"].max() + 3 * estimates["sigma"].max(),
500,
)
fig = go.Figure()
shapes = []
# Une gaussienne par prompt
for _, row in estimates.iterrows():
mu = row["mu"]
sigma = row["sigma"]
prompt_id = row["prompt_id"] if "prompt_id" in row else row["id"]
# Chercher le nom du prompt
name = str(prompt_id)
if "name" in prompts.columns:
match = prompts[prompts["id"] == prompt_id]
if not match.empty:
name = match.iloc[0]["name"]
y = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-0.5 * ((x - mu) / sigma) ** 2)
fig.add_trace(
go.Scatter(
x=x,
y=y,
mode="lines",
name=f"{name}",
hovertemplate=f"<b>{name}</b><br>Score (mu): {mu:.2f}<br>Sigma: {sigma:.2f}<extra></extra>",
)
)
# Ajout de la ligne verticale pointillée à mu (en gris)
shapes.append(
dict(
type="line",
x0=mu,
x1=mu,
y0=0,
y1=max(y),
line=dict(
color="gray",
width=2,
dash="dot",
),
xref="x",
yref="y",
)
)
fig.update_layout(
title="Distribution gaussienne de chaque prompt",
xaxis_title="Score (mu)",
yaxis_title="Densité",
template="plotly_white",
shapes=shapes,
)
return fig
|