Spaces:
Running
Running
from typing import List | |
from traitlets import Dict | |
import solara | |
import solara.lab | |
import matplotlib.pyplot as plt | |
# needed for solara up to version 1.28 | |
plt.switch_backend("module://matplotlib_inline.backend_inline") | |
# title = "solara" | |
from sklearn.linear_model import LogisticRegression | |
from sklearn.inspection import DecisionBoundaryDisplay | |
from sklearn.tree import DecisionTreeClassifier | |
import matplotlib.pylab as plt | |
import numpy as np | |
import pandas as pd | |
from drawdata import ScatterWidget | |
# Points drawn on the ScatterWidget, presumably one dict per point. Page() turns the
# list into a DataFrame and reads its "x", "y" and "color" columns. Keeping it in a
# module-level reactive variable means the drawing survives hot reloads.
# (Annotated with the builtin `dict`: traitlets.Dict is a trait type, not a mapping type.)
drawdata: solara.Reactive[List[dict]] = solara.reactive([])
# we keep the active tab in a reactive var so the state does not get lost when we change
# the orientation of the page (vertical or horizontal)
tab = solara.reactive(0)
@solara.component
def ClassifierDraw(classifier, X, y, response_method="predict_proba", figsize=(8, 8)):
    """Render a fitted classifier's decision boundary with the training points on top.

    Args:
        classifier: fitted scikit-learn classifier; its class name becomes the title.
        X: feature array indexed as ``X[:, 0]`` (plotted on x) and ``X[:, 1]`` (on y).
        y: per-point labels, also passed to ``scatter`` as ``c=``. In this app they are
            the drawn colors, so they must be values matplotlib accepts as colors.
        response_method: forwarded to ``DecisionBoundaryDisplay.from_estimator``.
        figsize: matplotlib figure size in inches.
    """
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    # Pass our own axes explicitly: with ax=None, from_estimator creates a new figure
    # and draws there, which would leave `fig` (the one rendered below) blank.
    disp = DecisionBoundaryDisplay.from_estimator(
        classifier,
        X,
        ax=ax,
        response_method=response_method,
        xlabel="x",
        ylabel="y",
        alpha=0.5,
    )
    disp.ax_.scatter(X[:, 0], X[:, 1], c=y, edgecolor="k")
    # Title the axes directly instead of relying on pyplot's "current axes" global state.
    ax.set_title(type(classifier).__name__)
    # Close *this* figure. A bare plt.close() closes whatever figure is current, which
    # in a multi-user server may belong to another session and would leak ours.
    plt.close(fig)
    solara.FigureMatplotlib(fig)
@solara.component
def DecisionTreeClassifierDraw(df):
    """Fit a DecisionTreeClassifier on the drawn points and show its decision boundary.

    Must be a component: it owns the ``criterion``/``splitter`` hooks. As a plain
    function, those hooks would be registered on the caller, which only calls this
    conditionally, breaking hook order between renders.

    Args:
        df: DataFrame with "x" and "y" feature columns and a "color" label column.
    """
    criterion = solara.use_reactive("gini")
    splitter = solara.use_reactive("best")
    with solara.Row():
        solara.ToggleButtonsSingle(value=criterion, values=["gini", "entropy", "log_loss"])
        solara.ToggleButtonsSingle(value=splitter, values=["best", "random"])
    X = df[["x", "y"]].values
    y = df["color"]
    classifier = DecisionTreeClassifier(criterion=criterion.value, splitter=splitter.value).fit(X, y)
    # Probability shading is only used for two classes; with more, fall back to hard
    # predictions (DecisionBoundaryDisplay has historically rejected predict_proba for
    # multiclass estimators).
    response_method = "predict_proba" if y.nunique() == 2 else "predict"
    ClassifierDraw(classifier, X, y, response_method)
@solara.component
def LogisticRegressionDraw(df):
    """Fit a LogisticRegression on the drawn points and show its decision boundary.

    Invalid penalty/solver combinations make scikit-learn raise ``ValueError``. The
    message is shown in the UI instead of a plot.

    Args:
        df: DataFrame with "x" and "y" feature columns and a "color" label column.
    """
    penalty = solara.use_reactive("l2")
    solver = solara.use_reactive("lbfgs")
    l1_ratio = solara.use_reactive(0.5)
    with solara.Row():
        solara.ToggleButtonsSingle(value=penalty, values=["l1", "l2", "elasticnet", "none"])
        solara.ToggleButtonsSingle(value=solver, values=["newton-cg", "lbfgs", "liblinear", "sag", "saga"])
    is_elasticnet = penalty.value == "elasticnet"
    if is_elasticnet:
        solara.FloatSlider("l1_ratio", value=l1_ratio, min=0, max=1, step=0.1)
    X = df[["x", "y"]].values
    y = df["color"]
    try:
        classifier = LogisticRegression(
            # The UI label "none" maps to penalty=None: the string "none" was deprecated
            # in scikit-learn 1.2 and removed in 1.4.
            penalty=None if penalty.value == "none" else penalty.value,
            solver=solver.value,
            # l1_ratio only applies to elasticnet; passing it otherwise makes
            # scikit-learn emit a warning.
            l1_ratio=l1_ratio.value if is_elasticnet else None,
        ).fit(X, y)
    except ValueError as e:
        solara.Error(str(e))
    else:
        # Probability shading for two classes, hard predictions otherwise.
        response_method = "predict_proba" if y.nunique() == 2 else "predict"
        ClassifierDraw(classifier, X, y, response_method)
@solara.component
def Page():
    """Top-level page: a ScatterWidget drawing canvas plus result tabs.

    The "classifier" tab fits and plots decision boundaries once the drawing contains
    at least two colors. The "table view" tab lists the drawn points and offers them as
    a CSV download. An app-bar button switches the canvas/tabs layout between stacked
    (vertical) and side by side.
    """
    # True: canvas stacked above the tabs (Column); False: side by side (Row).
    vertical = solara.use_reactive(True)
    solara.AppBarTitle("Draw Data with Solara demo")
    # downside of using elements and components: we cannot call methods on the widget,
    # so we re-create the dataframe ourselves from the reactive data (None until drawn).
    df = pd.DataFrame(drawdata.value) if drawdata.value else None
    with solara.AppBar():
        # TODO: doesn't work, ScatterWidget does not update when data is updated (read only?)
        # solara.Button(icon_name="mdi-delete", on_click=lambda: drawdata.set([]), icon=True)
        solara.lab.ThemeToggle(enable_auto=False)
        # demo how solara can dynamically change the layout
        solara.Button(
            icon_name="mdi-align-vertical-top" if vertical.value else "mdi-align-horizontal-left",
            on_click=lambda: vertical.set(not vertical.value),
            icon=True,
        )
    dark_background = solara.lab.use_dark_effective()
    # Match matplotlib's look to the current theme.
    # NOTE(review): plt.style.use mutates process-global rcParams, so concurrent
    # sessions using different themes can influence each other's figures.
    plt.style.use("dark_background" if dark_background else "default")
    with solara.Column() if vertical.value else solara.Row():
        # with solara, we don't just create the widget, but an element that describes it,
        # and instead of observe, we have on_<trait> callbacks.
        # Note: because the data is stored in the reactive var (drawdata), the drawing
        # survives a hot reload.
        ScatterWidget.element(data=drawdata.value, on_data=drawdata.set)
        with solara.lab.Tabs(value=tab):
            with solara.lab.Tab("classifier"):
                with solara.Column(classes=["py-4"]):  # some nice y padding
                    # A decision boundary needs at least two distinct classes (colors).
                    if df is not None and df["color"].nunique() > 1:
                        with solara.Column(style={"max-height": "500px", "padding-top": "0px"}):
                            with solara.lab.Tabs():
                                with solara.lab.Tab("DecisionTreeClassifier"):
                                    DecisionTreeClassifierDraw(df)
                                with solara.lab.Tab("LogisticRegression"):
                                    LogisticRegressionDraw(df)
                    else:
                        with solara.Column(style={"justify-content": "center"}) if not vertical.value else solara.Row():
                            solara.Info("Choose at least two colors to draw a decision boundary.")
            with solara.lab.Tab("table view"):
                with solara.Column(classes=["py-4"]):  # some nice y padding
                    if df is not None:
                        with solara.FileDownload(data=lambda: df.to_csv(), filename="drawdata.csv"):
                            solara.Button("download as csv", icon_name="mdi-download", outlined=True, color="primary")
                        solara.DataFrame(df)
# In the notebook: evaluating Page() as the cell's final expression displays the app
# (this assumes Page is a solara component, so the call returns a displayable element).
Page()