# Page residue from the Hugging Face file viewer (not part of the program):
# rashmi's picture · Update app.py · 3768708 · raw · history blame · 3.44 kB
# Code source: Gaël Varoquaux
# Modified for documentation by Jaques Grobler
# License: BSD 3 clause
import functools

import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from sklearn import datasets, cluster
from sklearn.feature_extraction.image import grid_to_graph
from datasets import load_dataset
# Select the non-interactive Agg backend for matplotlib.
plt.switch_backend("agg")

# Theme from - https://huggingface.co/spaces/trl-lib/stack-llama/blob/main/app.py
theme = gr.themes.Monochrome(
    # Colour palette.
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    # Corner radius preset.
    radius_size=gr.themes.sizes.radius_sm,
    # Preferred font first, then generic fallbacks in order.
    font=[
        gr.themes.GoogleFont("Open Sans"),
        "ui-sans-serif", "system-ui", "sans-serif",
    ],
)
@functools.lru_cache(maxsize=1)
def _load_digit_features():
    """Load the digits dataset once and return its 64 pixel columns.

    The result is cached so that repeated slider events do not reload and
    re-convert the dataset on every call.

    Returns:
        numpy.ndarray: the first 64 columns of the "train" split. The array is
        shared between calls, so callers must not modify it in place.
    """
    dataset = load_dataset("sklearn-docs/digits", header=None)
    # Convert the dataset to pandas so columns can be sliced by position.
    frame = dataset["train"].to_pandas()
    # Only columns 0..63 are used as features; column 64 is not needed here.
    return frame.iloc[:, :64].to_numpy()


def do_submit(n_clusters):
    """Plot original digits next to their feature-agglomerated versions.

    Fits ``FeatureAgglomeration`` on the digit features with a grid
    connectivity, maps the data to the reduced space and back, and draws the
    first four images before and after, plus the per-pixel cluster labels.

    Args:
        n_clusters: requested number of feature clusters; coerced with
            ``int()`` before being passed to the estimator.

    Returns:
        matplotlib.figure.Figure: a new figure created for this call. No
        global pyplot figure is used, so overlapping calls cannot draw into
        each other's output.
    """
    X = _load_digit_features()
    images = X.reshape(-1, 8, 8)

    # Restrict merging to pixels that are neighbours on the 8x8 image grid.
    connectivity = grid_to_graph(*images[0].shape)
    agglo = cluster.FeatureAgglomeration(
        connectivity=connectivity, n_clusters=int(n_clusters)
    )
    agglo.fit(X)

    # Reduce, then map back to pixel space for display.
    X_reduced = agglo.transform(X)
    X_restored = agglo.inverse_transform(X_reduced)
    images_restored = np.reshape(X_restored, images.shape)

    # A standalone Figure (not plt.figure(1)) keeps each request isolated.
    fig = Figure(figsize=(4, 3.5))
    fig.subplots_adjust(left=0.01, right=0.99, bottom=0.01, top=0.91)

    for i in range(4):
        # Row 1: original images.
        ax = fig.add_subplot(3, 4, i + 1)
        ax.imshow(images[i], cmap=plt.cm.gray, vmax=16, interpolation="nearest")
        ax.set_xticks(())
        ax.set_yticks(())
        if i == 1:
            ax.set_title("Original data")

        # Row 2: images rebuilt from the agglomerated features.
        ax = fig.add_subplot(3, 4, 4 + i + 1)
        ax.imshow(
            images_restored[i], cmap=plt.cm.gray, vmax=16, interpolation="nearest"
        )
        if i == 1:
            ax.set_title("Agglomerated data")
        ax.set_xticks(())
        ax.set_yticks(())

    # Row 3: cluster label of every pixel, shown as an image.
    ax = fig.add_subplot(3, 4, 10)
    ax.imshow(
        np.reshape(agglo.labels_, images[0].shape),
        interpolation="nearest",
        cmap=plt.cm.nipy_spectral,
    )
    ax.set_xticks(())
    ax.set_yticks(())
    ax.set_title("Labels")
    return fig
title = "Feature Agglomeration"

# Build the page: heading and explanatory text, a cluster-count slider, and
# the plot produced by do_submit.
with gr.Blocks(title=title, theme=theme) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(
        "<b>These images show how similar features are merged together using feature agglomeration.</b>"
    )
    gr.Markdown(
        "[Scikit-learn Example](https://scikit-learn.org/stable/auto_examples/cluster/plot_digits_agglomeration.html)"
    )
    # Adjacent literals with explicit separating spaces, so the sentence does
    # not run straight on from the closing parenthesis of the first link.
    gr.Markdown(
        "The FeatureAgglomeration uses [agglomerative clustering](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AgglomerativeClustering.html#sklearn.cluster.AgglomerativeClustering) "
        "to group together features that look very similar, thus decreasing the number of features. It is a dimensionality reduction "
        "tool, see [Unsupervised dimensionality reduction](https://scikit-learn.org/stable/modules/unsupervised_reduction.html#data-reduction)."
    )
    with gr.Row():
        n_clusters = gr.Slider(
            minimum=10,
            maximum=50,
            label="Number of clusters",
            info="Number of clusters for FeatureAgglomeration",
            step=1,
            value=32,
        )
    # NOTE(review): the original indentation was lost; confirm whether the
    # plot was meant to sit inside the Row next to the slider.
    plt_out = gr.Plot()

    # Redraw whenever the slider value changes.
    n_clusters.change(do_submit, n_clusters, plt_out)
    # Also draw once on page load so the plot is not empty before the first
    # slider move.
    demo.load(do_submit, n_clusters, plt_out)

if __name__ == "__main__":
    demo.launch()