|
import db |
|
import numpy as np |
|
import plotly.graph_objects as go |
|
|
|
|
|
def _empty_figure():
    """Return a blank figure carrying a centered "no estimate available" message."""
    fig = go.Figure()
    # xref/yref="paper" positions the text relative to the plotting area, so
    # (0.5, 0.5) is its center; in the default data coordinates the text would
    # land off-center on an empty figure's auto-ranged axes.
    fig.add_annotation(
        text="Aucune estimation disponible",
        x=0.5,
        y=0.5,
        xref="paper",
        yref="paper",
        showarrow=False,
    )
    return fig


def _valid_estimates(estimates):
    """Keep only rows with a finite ``mu`` and a finite, strictly positive ``sigma``.

    Other rows cannot be drawn: a zero/negative/NaN/inf ``sigma`` or a NaN/inf
    ``mu`` yields a NaN/inf density curve and a NaN marker height.
    NOTE(review): assumes ``mu`` and ``sigma`` are numeric columns.
    """
    mu = estimates["mu"]
    sigma = estimates["sigma"]
    mask = np.isfinite(mu) & np.isfinite(sigma) & (sigma > 0)
    return estimates[mask]


def _prompt_names(prompts):
    """Map prompt ids to display names taken from the ``prompts`` table.

    Returns an empty dict when ``prompts`` lacks an ``id`` or ``name`` column.
    When an id appears several times, the first row wins. Rows with a missing
    name are ignored, so the caller falls back to labelling by id.
    """
    if "id" not in prompts.columns or "name" not in prompts.columns:
        return {}
    named = prompts.dropna(subset=["name"]).drop_duplicates(subset="id", keep="first")
    return dict(zip(named["id"], named["name"]))


def _gaussian_pdf(x, mu, sigma):
    """Evaluate the normal density with mean ``mu`` and std-dev ``sigma`` at ``x``."""
    return 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-0.5 * ((x - mu) / sigma) ** 2)


def plot_estimates_distribution():
    """Plot one Gaussian curve per prompt (Plotly), with dotted vertical lines at the means.

    Loads the ``estimates`` and ``prompts`` tables through ``db.load``. Each
    estimate row (columns ``mu``, ``sigma`` and ``prompt_id`` -- or ``id`` when
    ``prompt_id`` is absent) becomes a density curve. The curve is labelled with
    the matching prompt's ``name`` when one exists, otherwise with the id.
    Rows whose ``sigma`` is not a finite positive number, or whose ``mu`` is not
    finite, are skipped because they cannot be drawn.

    Returns:
        A ``plotly.graph_objects.Figure``. When there is nothing drawable, the
        figure holds only a centered "Aucune estimation disponible" message.
    """
    estimates = db.load("estimates")
    prompts = db.load("prompts")
    if estimates.empty or prompts.empty:
        return _empty_figure()

    estimates = _valid_estimates(estimates)
    if estimates.empty:
        return _empty_figure()

    # One shared x grid covering every curve out to 3 of the largest std-devs.
    margin = 3 * estimates["sigma"].max()
    x = np.linspace(estimates["mu"].min() - margin, estimates["mu"].max() + margin, 500)

    names = _prompt_names(prompts)
    id_col = "prompt_id" if "prompt_id" in estimates.columns else "id"

    fig = go.Figure()
    shapes = []
    # Zip the columns instead of using iterrows(): iterrows packs each row into
    # a single-dtype Series, which upcasts integer ids to float (label "3.0").
    for prompt_id, mu, sigma in zip(estimates[id_col], estimates["mu"], estimates["sigma"]):
        name = str(names.get(prompt_id, prompt_id))
        y = _gaussian_pdf(x, mu, sigma)

        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="lines",
                name=name,
                hovertemplate=(
                    f"<b>{name}</b><br>Score (mu): {mu:.2f}<br>Sigma: {sigma:.2f}<extra></extra>"
                ),
            )
        )

        # Dotted vertical marker from the baseline up to the sampled curve's peak.
        shapes.append(
            dict(
                type="line",
                x0=mu,
                x1=mu,
                y0=0,
                y1=float(np.max(y)),
                line=dict(color="gray", width=2, dash="dot"),
                xref="x",
                yref="y",
            )
        )

    fig.update_layout(
        title="Distribution gaussienne de chaque prompt",
        xaxis_title="Score (mu)",
        yaxis_title="Densité",
        template="plotly_white",
        shapes=shapes,
    )
    return fig
|
|