import db import numpy as np import plotly.graph_objects as go def plot_estimates_distribution(): """Affiche une gaussienne par prompt (Plotly) + lignes verticales pointillées sur les moyennes.""" estimates = db.load("estimates") prompts = db.load("prompts") if estimates.empty or prompts.empty: fig = go.Figure() fig.add_annotation( text="Aucune estimation disponible", x=0.5, y=0.5, showarrow=False ) return fig x = np.linspace( estimates["mu"].min() - 3 * estimates["sigma"].max(), estimates["mu"].max() + 3 * estimates["sigma"].max(), 500, ) fig = go.Figure() shapes = [] # Une gaussienne par prompt for _, row in estimates.iterrows(): mu = row["mu"] sigma = row["sigma"] prompt_id = row["prompt_id"] if "prompt_id" in row else row["id"] # Chercher le nom du prompt name = str(prompt_id) if "name" in prompts.columns: match = prompts[prompts["id"] == prompt_id] if not match.empty: name = match.iloc[0]["name"] y = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-0.5 * ((x - mu) / sigma) ** 2) fig.add_trace( go.Scatter( x=x, y=y, mode="lines", name=f"{name}", hovertemplate=f"{name}
Score (mu): {mu:.2f}
Sigma: {sigma:.2f}", ) ) # Ajout de la ligne verticale pointillée à mu (en gris) shapes.append( dict( type="line", x0=mu, x1=mu, y0=0, y1=max(y), line=dict( color="gray", width=2, dash="dot", ), xref="x", yref="y", ) ) fig.update_layout( title="Distribution gaussienne de chaque prompt", xaxis_title="Score (mu)", yaxis_title="Densité", template="plotly_white", shapes=shapes, ) return fig