# drawdata-sklearn / pages/02-fancy-with-solara.py
# Last change: "upgrade to use solara 1.29" (commit b20b499, by maartenbreddels)
from typing import List
from traitlets import Dict
import solara
import solara.lab
import matplotlib.pyplot as plt
# needed for solara up to version 1.28
plt.switch_backend("module://matplotlib_inline.backend_inline")
# title = "solara"
from sklearn.linear_model import LogisticRegression
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.tree import DecisionTreeClassifier
import matplotlib.pylab as plt
import numpy as np
import pandas as pd
from drawdata import ScatterWidget
# Points drawn in the ScatterWidget, one dict per point (Page builds a DataFrame
# from them and reads its "x", "y" and "color" columns). The element type is the
# builtin ``dict``: the previously used ``traitlets.Dict`` is a trait type, not a
# mapping annotation. Keeping the points in a reactive variable preserves the
# drawing across hot reloads.
drawdata: solara.Reactive[List[dict]] = solara.reactive([])
# we keep the active tab in a reactive var so the state does not get lost when we change
# the orientation of the page (vertical or horizontal)
tab = solara.reactive(0)
@solara.component
def ClassifierDraw(classifier, X, y, response_method="predict_proba", figsize=(8, 8)):
    """Render a fitted classifier's decision boundary with the training points on top.

    Args:
        classifier: a fitted scikit-learn classifier.
        X: 2-D feature array; column 0 is plotted on the x axis, column 1 on the y axis.
        y: per-row labels, also used as the scatter colours (``scatter(c=y)``).
        response_method: forwarded to ``DecisionBoundaryDisplay.from_estimator``.
        figsize: matplotlib figure size in inches.
    """
    fig = plt.figure(figsize=figsize)
    # Hand scikit-learn an Axes on *our* figure: when ``ax`` is omitted it creates a
    # brand-new figure of its own, which would leave ``fig`` blank.
    ax = fig.add_subplot(111)
    disp = DecisionBoundaryDisplay.from_estimator(
        classifier,
        X,
        ax=ax,
        response_method=response_method,
        xlabel="x",
        ylabel="y",
        alpha=0.5,
    )
    disp.ax_.scatter(X[:, 0], X[:, 1], c=y, edgecolor="k")
    # Use explicit handles rather than pyplot's implicit "current" axes/figure,
    # which is global state shared by every render.
    disp.ax_.set_title(type(classifier).__name__)
    # Drop pyplot's reference to this figure; solara renders it from ``fig`` directly.
    plt.close(fig)
    solara.FigureMatplotlib(fig)
@solara.component
def DecisionTreeClassifierDraw(df):
    """Decision-tree panel: criterion/splitter toggles above the fitted decision boundary.

    ``df`` must provide ``x`` and ``y`` feature columns and a ``color`` target column.
    """
    criterion = solara.use_reactive("gini")
    splitter = solara.use_reactive("best")

    with solara.Row():
        solara.ToggleButtonsSingle(value=criterion, values=["gini", "entropy", "log_loss"])
        solara.ToggleButtonsSingle(value=splitter, values=["best", "random"])

    features = df[["x", "y"]].values
    labels = df["color"]
    model = DecisionTreeClassifier(criterion=criterion.value, splitter=splitter.value)
    model.fit(features, labels)

    # Class probabilities are only drawn for binary problems; with more classes the
    # boundary is drawn from hard predictions instead.
    n_classes = len(np.unique(df["color"]))
    method = "predict_proba" if n_classes == 2 else "predict"
    ClassifierDraw(model, features, labels, method)
@solara.component
def LogisticRegressionDraw(df):
    """Logistic-regression panel: penalty/solver toggles above the fitted decision boundary.

    ``df`` must provide ``x`` and ``y`` feature columns and a ``color`` target column.
    Invalid penalty/solver combinations (scikit-learn raises ValueError) are shown
    with ``solara.Error`` instead of a plot.
    """
    penalty = solara.use_reactive("l2")
    solver = solara.use_reactive("lbfgs")
    l1_ratio = solara.use_reactive(0.5)
    use_elasticnet = penalty.value == "elasticnet"
    with solara.Row():
        solara.ToggleButtonsSingle(value=penalty, values=["l1", "l2", "elasticnet", "none"])
        solara.ToggleButtonsSingle(value=solver, values=["newton-cg", "lbfgs", "liblinear", "sag", "saga"])
        if use_elasticnet:
            solara.FloatSlider("l1_ratio", value=l1_ratio, min=0, max=1, step=0.1)
    X = df[["x", "y"]].values
    y = df["color"]
    try:
        classifier = LogisticRegression(
            # "No penalty" is spelled ``None`` in scikit-learn; the string "none" was
            # deprecated in 1.2 and removed in 1.4, so map the UI label here.
            penalty=None if penalty.value == "none" else penalty.value,
            solver=solver.value,
            # l1_ratio only applies to elastic-net; passing it for other penalties
            # just triggers a scikit-learn warning that it is ignored.
            l1_ratio=l1_ratio.value if use_elasticnet else None,
        ).fit(X, y)
    except ValueError as e:
        solara.Error(str(e))
    else:
        # Probabilities only for binary problems; multi-class uses hard predictions.
        ClassifierDraw(classifier, X, y, "predict_proba" if len(np.unique(df["color"])) == 2 else "predict")
@solara.component
def Page():
    """Render the demo page: the drawing canvas plus "classifier" and "table view" tabs.

    Points drawn in the ScatterWidget are stored in the module-level ``drawdata``
    reactive variable, and every render rebuilds a DataFrame (``df``) from them.
    ``df`` is None while nothing has been drawn. An app-bar button flips the layout
    between a column (vertical) and a row (horizontal).
    """
    # True: canvas and tabs stacked in a Column; False: side by side in a Row.
    vertical = solara.use_reactive(True)
    solara.AppBarTitle("Draw Data with Solara demo")
    # Rebuilt from the reactive data on every render (see the note below the widget).
    df = pd.DataFrame(drawdata.value) if drawdata.value else None
    with solara.AppBar():
        # TODO: doesn't work, ScatterWidget does not update when data is updated (read only?)
        # solara.Button(icon_name="mdi-delete", on_click=lambda: drawdata.set([]), icon=True)
        solara.lab.ThemeToggle(enable_auto=False)
        # demo how solara can dynamically change the layout: this button toggles `vertical`
        solara.Button(icon_name="mdi-align-vertical-top" if vertical.value else "mdi-align-horizontal-left", on_click=lambda: vertical.set(not vertical.value), icon=True)
    # Match the matplotlib style to the effective light/dark theme.
    # NOTE(review): plt.style.use changes global rcParams, so it affects every figure
    # created afterwards in this process, not just this page's — confirm that's intended.
    dark_background = solara.lab.use_dark_effective()
    plt.style.use('dark_background' if dark_background else 'default')
    with solara.Column() if vertical.value else solara.Row():
        # with solara, we don't create the widget itself but an element that describes it,
        # and instead of observe() we use on_<trait> callbacks.
        # Note: because the data is stored in the reactive var (drawdata), the drawing
        # survives a hot reload.
        ScatterWidget.element(data=drawdata.value, on_data=drawdata.set)
        # downside of using elements and components: we cannot call methods on the widget,
        # so we re-create the dataframe ourselves (`df` above).
        # value=tab keeps the selected tab in module-level state, so switching the
        # layout orientation does not reset it.
        with solara.lab.Tabs(value=tab):
            with solara.lab.Tab("classifier"):
                with solara.Column(classes=["py-4"]):  # some nice y padding
                    # A decision boundary needs at least two classes (colours).
                    if df is not None and (df["color"].nunique() > 1):
                        with solara.Column(style={"max-height": "500px", "padding-top": "0px"}):
                            with solara.lab.Tabs():
                                with solara.lab.Tab("DecisionTreeClassifier"):
                                    DecisionTreeClassifierDraw(df)
                                with solara.lab.Tab("LogisticRegressionDraw"):
                                    LogisticRegressionDraw(df)
                    else:
                        with solara.Column(style={"justify-content": "center"}) if not vertical.value else solara.Row():
                            solara.Info("Choose at least two colors to draw a decision boundary.")
            with solara.lab.Tab("table view"):
                with solara.Column(classes=["py-4"]):  # some nice y padding
                    if df is not None:
                        # `data` is given as a callable that serialises `df` to CSV.
                        with solara.FileDownload(data=lambda: df.to_csv(), filename="drawdata.csv"):
                            solara.Button("download as csv", icon_name="mdi-download", outlined=True, color="primary")
                        solara.DataFrame(df)
# In a notebook: evaluating Page() as the last expression of the cell displays the app.
# NOTE(review): when served as a solara page this module-level call is presumably
# not needed — confirm before relying on it.
Page()