File size: 5,524 Bytes
0449aa7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b20b499
 
 
0449aa7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b20b499
0449aa7
 
b20b499
 
 
 
0449aa7
 
 
 
 
 
 
 
b20b499
0449aa7
b20b499
 
 
 
 
 
 
 
 
 
 
0449aa7
b20b499
 
 
 
 
0449aa7
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from typing import List
from traitlets import Dict
import solara
import solara.lab
import matplotlib.pyplot as plt

# Route pyplot through the matplotlib-inline backend.
# Needed for solara up to version 1.28; NOTE(review): presumably newer solara
# versions handle this themselves — confirm before removing.
plt.switch_backend("module://matplotlib_inline.backend_inline")
# title = "solara"

from sklearn.linear_model import LogisticRegression
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.tree import DecisionTreeClassifier

import matplotlib.pylab as plt
import numpy as np
import pandas as pd

from drawdata import ScatterWidget

# Points drawn in the ScatterWidget, one plain dict per point (Page turns this
# list into a DataFrame with "x", "y" and "color" columns). Stored in a
# module-level reactive var so the drawing is kept on hot reload.
drawdata: solara.Reactive[List[dict]] = solara.reactive([])
# We keep the active tab in a reactive var so the state does not get lost when we
# change the orientation of the page (vertical or horizontal).
tab = solara.reactive(0)

@solara.component
def ClassifierDraw(classifier, X, y, response_method="predict_proba", figsize=(8, 8)):
    """Render a fitted classifier's decision boundary together with the training points.

    Args:
        classifier: A fitted scikit-learn classifier.
        X: 2-D feature array; columns 0 and 1 are plotted as x and y.
        y: Labels, also used as the scatter colors (``c=``).
        response_method: Forwarded to ``DecisionBoundaryDisplay.from_estimator``,
            e.g. "predict_proba" or "predict".
        figsize: Matplotlib figure size.
    """
    fig = plt.figure(figsize=figsize)
    # Hand our own axes to from_estimator: without ``ax`` it creates a new figure
    # and axes of its own, which would leave ``fig`` empty (a blank plot).
    ax = fig.add_subplot(111)
    DecisionBoundaryDisplay.from_estimator(
        classifier,
        X,
        ax=ax,
        response_method=response_method,
        xlabel="x",
        ylabel="y",
        alpha=0.5,
    )
    ax.scatter(X[:, 0], X[:, 1], c=y, edgecolor="k")
    # Title the axes explicitly rather than through pyplot's implicit "current axes".
    ax.set_title(type(classifier).__name__)
    # Close exactly this figure so pyplot does not keep a reference to it across
    # re-renders; solara renders the Figure object directly.
    plt.close(fig)
    solara.FigureMatplotlib(fig)


@solara.component
def DecisionTreeClassifierDraw(df):
    """Fit a DecisionTreeClassifier on the drawn points and plot its decision boundary.

    Two toggle groups select the tree's ``criterion`` and ``splitter``; the tree
    is refit on every render with the current selection.
    """
    criterion = solara.use_reactive("gini")
    splitter = solara.use_reactive("best")
    with solara.Row():
        solara.ToggleButtonsSingle(value=criterion, values=["gini", "entropy", "log_loss"])
        solara.ToggleButtonsSingle(value=splitter, values=["best", "random"])

    features = df[["x", "y"]].values
    labels = df["color"]

    tree = DecisionTreeClassifier(
        criterion=criterion.value,
        splitter=splitter.value,
    )
    tree.fit(features, labels)

    # DecisionBoundaryDisplay rejects predict_proba for multiclass estimators in
    # many scikit-learn versions, so only use it for exactly two classes.
    if len(np.unique(df["color"])) != 2:
        method = "predict"
    else:
        method = "predict_proba"
    ClassifierDraw(tree, features, labels, method)


@solara.component
def LogisticRegressionDraw(df):
    """Fit a LogisticRegression on the drawn points and plot its decision boundary.

    Toggle buttons pick ``penalty`` and ``solver``; an ``l1_ratio`` slider is shown
    only for the elastic-net penalty. If scikit-learn rejects the combination
    (raises ``ValueError``), the message is shown with ``solara.Error`` instead of
    a plot.
    """
    penalty = solara.use_reactive("l2")
    solver = solara.use_reactive("lbfgs")
    l1_ratio = solara.use_reactive(0.5)
    with solara.Row():
        solara.ToggleButtonsSingle(value=penalty, values=["l1", "l2", "elasticnet", "none"])
        solara.ToggleButtonsSingle(value=solver, values=["newton-cg", "lbfgs", "liblinear", "sag", "saga"])
    if penalty.value == "elasticnet":
        solara.FloatSlider("l1_ratio", value=l1_ratio, min=0, max=1, step=0.1)
    X = df[["x", "y"]].values
    y = df["color"]

    params = {
        # The string "none" was deprecated in scikit-learn 1.2 and removed in 1.4;
        # ``None`` is the supported way to disable regularization.
        "penalty": None if penalty.value == "none" else penalty.value,
        "solver": solver.value,
    }
    # l1_ratio only applies to the elastic-net penalty; scikit-learn warns if it
    # is passed together with any other penalty.
    if penalty.value == "elasticnet":
        params["l1_ratio"] = l1_ratio.value

    try:
        classifier = LogisticRegression(**params).fit(X, y)
    except ValueError as e:
        # Incompatible penalty/solver pairs raise ValueError (InvalidParameterError
        # subclasses it); surface the message rather than crash the page.
        solara.Error(str(e))
    else:
        response_method = "predict_proba" if len(np.unique(y)) == 2 else "predict"
        ClassifierDraw(classifier, X, y, response_method)


@solara.component
def Page():
    """Top-level app: a drawing canvas alongside classifier plots and a data table.

    The app bar holds a theme toggle and a button that switches the main layout
    between a Column (vertical) and a Row (horizontal). The "classifier" tab shows
    decision-tree and logistic-regression boundaries once at least two colors are
    drawn; the "table view" tab shows the drawn points with a CSV download.
    """
    vertical = solara.use_reactive(True)
    solara.AppBarTitle("Draw Data with Solara demo")
    df = pd.DataFrame(drawdata.value) if drawdata.value else None

    with solara.AppBar():
        # TODO: doesn't work, ScatterWidget does not update when data is updated (read only?)
        # solara.Button(icon_name="mdi-delete", on_click=lambda: drawdata.set([]), icon=True)
        solara.lab.ThemeToggle(enable_auto=False)
        # Demo how solara can dynamically change the layout.
        solara.Button(
            icon_name="mdi-align-vertical-top" if vertical.value else "mdi-align-horizontal-left",
            on_click=lambda: vertical.set(not vertical.value),
            icon=True,
        )

    # Keep the matplotlib style in sync with the effective (light/dark) theme.
    dark_background = solara.lab.use_dark_effective()
    plt.style.use('dark_background' if dark_background else 'default')

    with solara.Column() if vertical.value else solara.Row():
        # With solara we don't create the widget itself but an element describing
        # it, and instead of observe we use on_<trait> callbacks.
        # Storing the data in the reactive var (drawdata) keeps the drawing on hot reload.
        ScatterWidget.element(data=drawdata.value, on_data=drawdata.set)
        # Downside of elements/components: we cannot call methods on the widget,
        # so the dataframe is rebuilt from drawdata ourselves (see df above).
        with solara.lab.Tabs(value=tab):
            with solara.lab.Tab("classifier"):
                with solara.Column(classes=["py-4"]):  # some nice y padding
                    # A decision boundary needs at least two classes (colors).
                    if df is not None and (df["color"].nunique() > 1):
                        with solara.Column(style={"max-height": "500px", "padding-top": "0px"}):
                            with solara.lab.Tabs():
                                with solara.lab.Tab("DecisionTreeClassifier"):
                                    DecisionTreeClassifierDraw(df)
                                # Label with the estimator name, like the tab above.
                                with solara.lab.Tab("LogisticRegression"):
                                    LogisticRegressionDraw(df)
                    else:
                        with solara.Column(style={"justify-content": "center"}) if not vertical.value else solara.Row():
                            solara.Info("Choose at least two colors to draw a decision boundary.")
            with solara.lab.Tab("table view"):
                with solara.Column(classes=["py-4"]):  # some nice y padding
                    if df is not None:
                        # The lambda defers CSV serialization until the download is requested.
                        with solara.FileDownload(data=lambda: df.to_csv(), filename="drawdata.csv"):
                            solara.Button("download as csv", icon_name="mdi-download", outlined=True, color="primary")
                        solara.DataFrame(df)


# In a Jupyter notebook: evaluating the component as the cell's last expression
# displays the app inline.
Page()