computerscience-person's picture
Showcase Decision Tree Classifier using Iris dataset.
8fa82fa
raw
history blame
7.53 kB
import marimo
# Version string stored with the notebook — presumably the marimo release
# that generated this file; confirm before editing it by hand.
__generated_with = "0.10.16"
# Application object that every `@app.cell` function below registers with;
# it is run at the bottom of the file when executed as a script.
app = marimo.App()
@app.cell
def _():
    # Shared third-party imports. Returning them lets the cells below
    # receive each name by listing it as a parameter.
    import marimo as mo
    import polars as pl
    from sklearn import datasets
    from sklearn.metrics import (
        accuracy_score,
        classification_report,
        confusion_matrix,
    )
    from sklearn.model_selection import train_test_split
    from sklearn.tree import DecisionTreeClassifier

    # Notebook title; kept as the final expression so marimo shows it as
    # this cell's output.
    mo.md("# Iris Dataset Showcase")
    return (
        DecisionTreeClassifier,
        accuracy_score,
        classification_report,
        confusion_matrix,
        datasets,
        mo,
        pl,
        train_test_split,
    )
@app.cell(hide_code=True)
def _(datasets):
    # Load the Iris bunch once, then expose its `data` and `target`
    # attributes under the conventional X / y names.
    iris = datasets.load_iris()
    X, y = iris.data, iris.target
    return X, iris, y
@app.cell(hide_code=True)
def _(X, train_test_split, y):
    # Hold out 20% of the samples for evaluation. The fixed seed keeps the
    # partition the same on every run.
    X_train, X_test, y_train, y_test = train_test_split(
        X,
        y,
        test_size=0.2,
        random_state=42,
    )
    return X_test, X_train, y_test, y_train
@app.cell
def _(DecisionTreeClassifier, X_train, mo, y_train):
    # Seed the estimator with the same value used for the train/test split.
    # scikit-learn randomly permutes candidate features at each split, so an
    # unseeded tree can differ between runs, which would make the metrics
    # and the rendered tree non-reproducible.
    classifier = DecisionTreeClassifier(random_state=42)
    classifier.fit(X_train, y_train)

    # Section heading. It has no placeholders, so a plain string is used
    # instead of an f-string. The body stays flush-left so the Markdown text
    # itself is unchanged.
    mo.md("""
## Decision Tree Classifier
""")
    return (classifier,)
@app.cell
def _(X_test, classifier):
    # Predict a class label for every held-out sample.
    y_pred = classifier.predict(X=X_test)
    return (y_pred,)
@app.cell
def _(
    accuracy_score,
    classification_report,
    confusion_matrix,
    mo,
    y_pred,
    y_test,
):
    # Score the held-out predictions three ways: overall accuracy, the
    # confusion matrix, and the textual classification report.
    accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)
    conf_matrix = confusion_matrix(y_true=y_test, y_pred=y_pred)
    class_report = classification_report(y_true=y_test, y_pred=y_pred)

    # Render the three results as Markdown. The template body is kept
    # flush-left so the emitted text is exactly what the template spells out.
    mo.md(f"""
Accuracy: {accuracy}
Confusion Matrix:
```
{conf_matrix}
```
Classification Report:
```
{class_report}
```
""")
    return accuracy, class_report, conf_matrix
@app.cell
def _(X_test, pl, y_pred, y_test):
    import matplotlib.pyplot as plt
    import seaborn as sns

    # Plotting frame: the first two columns of the test features next to
    # the predicted and actual labels. The underscore prefix keeps the
    # mapping private to this cell.
    _columns = {
        "sepal length (cm)": X_test[:, 0],
        "sepal width (cm)": X_test[:, 1],
        "Predicted": y_pred,
        "Actual": y_test,
    }
    df = pl.DataFrame(_columns)
    return df, plt, sns
@app.cell
def _(df, mo, plt, sns):
    # Scatter of the two plotted features: colour encodes the predicted
    # label and marker shape encodes the actual label.
    plt.figure(figsize=(10, 6))
    _axes = sns.scatterplot(
        data=df,
        x="sepal length (cm)",
        y="sepal width (cm)",
        hue="Predicted",
        style="Actual",
        palette="Set1",
        markers=["o", "s", "D"],
    )

    # Decorate through the Axes that seaborn drew on (the current axes).
    _axes.set_title("Iris Dataset: Sepal Length vs Sepal Width")
    _axes.set_xlabel("Sepal Length (cm)")
    _axes.set_ylabel("Sepal Width (cm)")
    _axes.legend(title="Class")

    # Heading stacked above the finished figure.
    mo.vstack([mo.md("## Iris Dataset"), plt.gcf()])
    return
@app.cell
def _(conf_matrix, iris, mo, plt, sns):
    # Annotated heatmap of the confusion matrix, with both axes ticked by
    # the dataset's target names.
    plt.figure(figsize=(8, 6))
    _class_labels = iris.target_names
    _axes = sns.heatmap(
        conf_matrix,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=_class_labels,
        yticklabels=_class_labels,
    )
    _axes.set_xlabel("Predicted")
    _axes.set_ylabel("Actual")

    # Heading stacked above the finished figure.
    mo.vstack([mo.md("## Confusion Matrix"), plt.gcf()])
    return
@app.cell(hide_code=True)
def _(iris, pl):
    # Full dataset as a polars frame: one column per feature name plus a
    # "species" column holding the target values.
    iris_df = pl.DataFrame(
        data=iris.data,
        schema=iris.feature_names,
    ).with_columns(pl.Series("species", iris.target))
    return (iris_df,)
@app.cell
def _(iris_df, mo, plt, sns):
    # The frame is converted to pandas before being handed to pairplot.
    _pandas_frame = iris_df.to_pandas()
    sns.pairplot(
        _pandas_frame,
        hue="species",
        palette="Set1",
        markers=["o", "s", "D"],
    )

    # Heading stacked above the current figure.
    mo.vstack([mo.md("## Pair Plot"), plt.gcf()])
    return
@app.cell
def _(classifier, iris, mo, plt):
    from sklearn.tree import plot_tree

    # Draw the fitted tree with filled nodes, labelled with the dataset's
    # feature and target names.
    plt.figure(figsize=(12, 8))
    plot_tree(
        decision_tree=classifier,
        filled=True,
        feature_names=iris.feature_names,
        class_names=iris.target_names,
    )

    # Heading stacked above the finished figure.
    mo.vstack([mo.md("## Classifier Decision Tree Visualization"), plt.gcf()])
    return (plot_tree,)
@app.cell(hide_code=True)
def _():
    # Editor tips keyed by topic; each value is a Markdown snippet.
    # No other cell visible in this file takes `tips` as a parameter.
    # The string bodies stay flush-left so their contents are unchanged.
    tips = {
        "Saving": """
**Saving**
- _Name_ your app using the box at the top of the screen, or
with `Ctrl/Cmd+s`. You can also create a named app at the
command line, e.g., `marimo edit app_name.py`.
- _Save_ by clicking the save icon on the bottom right, or by
inputting `Ctrl/Cmd+s`. By default marimo is configured
to autosave.
""",
        "Running": """
1. _Run a cell_ by clicking the play ( ▷ ) button on the top
right of a cell, or by inputting `Ctrl/Cmd+Enter`.
2. _Run a stale cell_ by clicking the yellow run button on the
right of the cell, or by inputting `Ctrl/Cmd+Enter`. A cell is
stale when its code has been modified but not run.
3. _Run all stale cells_ by clicking the play ( ▷ ) button on
the bottom right of the screen, or input `Ctrl/Cmd+Shift+r`.
""",
        "Console Output": """
Console output (e.g., `print()` statements) is shown below a
cell.
""",
        "Creating, Moving, and Deleting Cells": """
1. _Create_ a new cell above or below a given one by clicking
the plus button to the left of the cell, which appears on
mouse hover.
2. _Move_ a cell up or down by dragging on the handle to the
right of the cell, which appears on mouse hover.
3. _Delete_ a cell by clicking the trash bin icon. Bring it
back by clicking the undo button on the bottom right of the
screen, or with `Ctrl/Cmd+Shift+z`.
""",
        "Disabling Automatic Execution": """
Via the notebook settings (gear icon) or footer panel, you
can disable automatic execution. This is helpful when
working with expensive notebooks or notebooks that have
side-effects like database transactions.
""",
        "Disabling Cells": """
You can disable a cell via the cell context menu.
marimo will never run a disabled cell or any cells that depend on it.
This can help prevent accidental execution of expensive computations
when editing a notebook.
""",
        "Code Folding": """
You can collapse or fold the code in a cell by clicking the arrow
icons in the line number column to the left, or by using keyboard
shortcuts.
Use the command palette (`Ctrl/Cmd+k`) or a keyboard shortcut to
quickly fold or unfold all cells.
""",
        "Code Formatting": """
If you have [ruff](https://github.com/astral-sh/ruff) installed,
you can format a cell with the keyboard shortcut `Ctrl/Cmd+b`.
""",
        "Command Palette": """
Use `Ctrl/Cmd+k` to open the command palette.
""",
        "Keyboard Shortcuts": """
Open the notebook menu (top-right) or input `Ctrl/Cmd+Shift+h` to
view a list of all keyboard shortcuts.
""",
        "Configuration": """
Configure the editor by clicking the gears icon near the top-right
of the screen.
""",
    }
    return (tips,)
# Run the notebook's app only when this file is executed as a script.
if __name__ == "__main__":
    app.run()