File size: 2,156 Bytes
9565067
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import db
import numpy as np
import plotly.graph_objects as go


def plot_estimates_distribution():
    """Affiche une gaussienne par prompt (Plotly) + lignes verticales pointillées sur les moyennes."""
    estimates = db.load("estimates")
    prompts = db.load("prompts")
    if estimates.empty or prompts.empty:
        fig = go.Figure()
        fig.add_annotation(
            text="Aucune estimation disponible", x=0.5, y=0.5, showarrow=False
        )
        return fig
    x = np.linspace(
        estimates["mu"].min() - 3 * estimates["sigma"].max(),
        estimates["mu"].max() + 3 * estimates["sigma"].max(),
        500,
    )
    fig = go.Figure()
    shapes = []
    # Une gaussienne par prompt
    for _, row in estimates.iterrows():
        mu = row["mu"]
        sigma = row["sigma"]
        prompt_id = row["prompt_id"] if "prompt_id" in row else row["id"]
        # Chercher le nom du prompt
        name = str(prompt_id)
        if "name" in prompts.columns:
            match = prompts[prompts["id"] == prompt_id]
            if not match.empty:
                name = match.iloc[0]["name"]
        y = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-0.5 * ((x - mu) / sigma) ** 2)
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="lines",
                name=f"{name}",
                hovertemplate=f"<b>{name}</b><br>Score (mu): {mu:.2f}<br>Sigma: {sigma:.2f}<extra></extra>",
            )
        )
        # Ajout de la ligne verticale pointillée à mu (en gris)
        shapes.append(
            dict(
                type="line",
                x0=mu,
                x1=mu,
                y0=0,
                y1=max(y),
                line=dict(
                    color="gray",
                    width=2,
                    dash="dot",
                ),
                xref="x",
                yref="y",
            )
        )
    fig.update_layout(
        title="Distribution gaussienne de chaque prompt",
        xaxis_title="Score (mu)",
        yaxis_title="Densité",
        template="plotly_white",
        shapes=shapes,
    )
    return fig