File size: 3,770 Bytes
3ec88bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
01e6c2b
 
 
 
 
3ec88bd
 
 
01e6c2b
 
 
 
 
3ec88bd
 
264e408
 
 
7e26e61
3ec88bd
 
 
cdcefb0
 
 
 
 
 
 
3ec88bd
cdcefb0
264e408
cdcefb0
 
 
 
 
 
7e26e61
264e408
7e26e61
264e408
 
7e26e61
 
264e408
 
7e26e61
264e408
7e26e61
 
264e408
 
 
7e26e61
 
264e408
 
3ec88bd
 
 
 
7e26e61
3ec88bd
 
 
 
 
7e26e61
3ec88bd
 
 
 
 
 
7e26e61
3ec88bd
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error


def get_plots(min_alpha, max_alpha):
    """Fit Ridge over a log-spaced range of alphas and plot the results.

    A fixed synthetic regression problem is generated on every call:
    ``make_regression`` with 10 samples, 10 features, ``random_state=1`` and
    ``bias=3.5``. It is fitted with :class:`~sklearn.linear_model.Ridge` for
    200 alphas spaced logarithmically between ``10**min_alpha`` and
    ``10**max_alpha``.

    Args:
        min_alpha: Base-10 exponent of the smallest alpha.
        max_alpha: Base-10 exponent of the largest alpha.

    Returns:
        A ``(fig, text)`` tuple:

        * ``fig`` is a matplotlib Figure with two panels, both with alpha on a
          log-scaled x axis. The left panel shows the fitted coefficients. The
          right panel shows the mean squared error between the fitted and the
          true coefficients.
        * ``text`` is a Markdown caption naming the plotted alpha range.
    """
    clf = Ridge()

    # Same seed on every call, so the "true" coefficients ``w`` are fixed.
    X, y, w = make_regression(
        n_samples=10, n_features=10, coef=True, random_state=1, bias=3.5
    )

    alphas = np.logspace(min_alpha, max_alpha, 200)

    coefs = []
    errors = []
    # Train the model with different regularization strengths.
    for a in alphas:
        clf.set_params(alpha=a)
        clf.fit(X, y)
        # Defensive copy: each stored row must not depend on later refits.
        coefs.append(clf.coef_.copy())
        # mean_squared_error takes (y_true, y_pred): true weights first.
        errors.append(mean_squared_error(w, clf.coef_))

    # Display results.
    fig, axes = plt.subplots(1, 2, figsize=(20, 6))

    axes[0].plot(alphas, coefs)
    axes[0].set_xscale("log")
    axes[0].set_xlabel("alpha", fontsize=16)
    axes[0].set_ylabel("weights", fontsize=16)
    axes[0].set_title(
        "Ridge coefficients as a function of the regularization", fontsize=20
    )

    axes[1].plot(alphas, errors)
    axes[1].set_xscale("log")
    axes[1].set_xlabel("alpha", fontsize=16)
    axes[1].set_ylabel("error", fontsize=16)
    axes[1].set_title(
        "Coefficient error as a function of the regularization", fontsize=20
    )
    fig.tight_layout()

    # This function runs on every slider change and page load. Without
    # closing, pyplot's global figure registry keeps every figure alive,
    # causing memory growth and a "too many figures" warning. Closing only
    # detaches the figure from pyplot; the returned object can still be
    # rendered and saved by the caller.
    plt.close(fig)

    plotted_alphas_text = (
        f"**Plotted alphas between 10^({min_alpha}) and 10^({max_alpha})**"
    )
    return fig, plotted_alphas_text


with gr.Blocks() as demo:
    # Introductory text rendered at the top of the page.
    gr.Markdown(
        """
        # Ridge coefficients as a function of the L2 regularization

        This space shows the effect of different alpha values in the coefficients learned by [Ridge regression](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Ridge.html).

        The left plot shows how, as alpha tends to zero, the coefficients tend to the true coefficients. For large alpha values—i.e., strong regularization—the coefficients get smaller and eventually converge to zero.

        The right plot shows the mean squared error between the coefficients found by the model and the true coefficients. Less regularized models retrieve the exact coefficients—i.e., the error equals 0—while stronger regularised models increase the error.

        We generate the dataset using sklearn's [make_regression](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_regression.html) function to ensure we know all the coefficients' true values.

        This space is based on [sklearn’s original demo](https://scikit-learn.org/stable/auto_examples/linear_model/plot_ridge_coeffs.html#sphx-glr-auto-examples-linear-model-plot-ridge-coeffs-py).
        """
    )

    # Control row: the two exponent sliders take most of the width; a
    # narrow column on the right holds the caption returned by get_plots.
    with gr.Row():
        with gr.Column(scale=5):
            with gr.Row():
                min_alpha = gr.Slider(
                    label="Minimum Alpha Exponent",
                    minimum=-10,
                    maximum=-1,
                    step=1,
                    value=-6,
                )
                max_alpha = gr.Slider(
                    label="Maximum Alpha Exponent",
                    minimum=0,
                    maximum=10,
                    step=1,
                    value=6,
                )
        with gr.Column(scale=1):
            plotted_alphas_text = gr.Markdown()

    plots = gr.Plot()

    # Moving either slider, and the initial page load, all recompute the
    # figure and caption from the current slider values. Registration order
    # is: minimum slider, maximum slider, page load.
    for _register in (min_alpha.change, max_alpha.change, demo.load):
        _register(
            get_plots,
            [min_alpha, max_alpha],
            [plots, plotted_alphas_text],
            queue=False,
        )

if __name__ == "__main__":
    demo.launch()