rashmi committed on
Commit
2f297cd
·
1 Parent(s): 91007bf

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -0
app.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code source: Gaël Varoquaux
2
+ # Modified for documentation by Jaques Grobler
3
+ # License: BSD 3 clause
4
+
5
import functools

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from datasets import load_dataset
from sklearn import cluster, datasets
from sklearn.feature_extraction.image import grid_to_graph
13
+
14
# Render with the Agg backend; the figure is only ever handed to gr.Plot.
plt.switch_backend("agg")


# Theme from - https://huggingface.co/spaces/trl-lib/stack-llama/blob/main/app.py
# Font preference order: the Google-hosted font first, then generic fallbacks.
_theme_fonts = [
    gr.themes.GoogleFont("Open Sans"),
    "ui-sans-serif",
    "system-ui",
    "sans-serif",
]
theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size=gr.themes.sizes.radius_sm,
    font=_theme_fonts,
)
30
+
31
+
32
@functools.lru_cache(maxsize=1)
def _load_digit_images():
    """Load the digits dataset once and return ``(features, images)``.

    ``features`` is the first 64 columns of the ``train`` split as a NumPy
    array and ``images`` is the same data reshaped to ``(-1, 8, 8)``.  The
    result is cached so that repeated button clicks do not reload and
    re-convert the dataset; callers must treat both arrays as read-only.
    """
    dataset = load_dataset("sklearn-docs/digits", header=None)
    frame = dataset["train"].to_pandas()
    features = frame.iloc[:, :64].to_numpy()
    images = features.reshape(-1, 8, 8)
    return features, images


def do_submit(n_clusters):
    """Fit feature agglomeration on the digits data and plot the outcome.

    Args:
        n_clusters: Number of feature clusters to build; coerced with
            ``int()`` before being passed to ``FeatureAgglomeration``.

    Returns:
        The ``matplotlib.pyplot`` module, whose figure 1 contains four
        original digits (top row), the same digits after
        transform/inverse_transform (middle row) and the per-pixel cluster
        labels (bottom row).
    """
    features, images = _load_digit_images()
    image_shape = images[0].shape

    # Pixel-adjacency graph of one image: only neighbouring pixels may merge.
    connectivity = grid_to_graph(*image_shape)

    agglo = cluster.FeatureAgglomeration(
        connectivity=connectivity, n_clusters=int(n_clusters)
    )
    agglo.fit(features)
    reduced = agglo.transform(features)

    # Map the reduced features back to pixel space to visualise the merge.
    restored = agglo.inverse_transform(reduced)
    images_restored = np.reshape(restored, images.shape)

    # Figure number 1 is reused and cleared on every call.
    plt.figure(1, figsize=(4, 3.5))
    plt.clf()
    plt.subplots_adjust(left=0.01, right=0.99, bottom=0.01, top=0.91)

    rows = (("Original data", images), ("Agglomerated data", images_restored))
    for row, (row_title, row_images) in enumerate(rows):
        for col in range(4):
            plt.subplot(3, 4, 4 * row + col + 1)
            # Both rows share the same upper colour limit (vmax=16).
            plt.imshow(
                row_images[col], cmap=plt.cm.gray, vmax=16, interpolation="nearest"
            )
            plt.xticks(())
            plt.yticks(())
            if col == 1:
                plt.title(row_title)

    # Bottom row: which cluster each pixel (feature) was assigned to.
    plt.subplot(3, 4, 10)
    plt.imshow(
        np.reshape(agglo.labels_, image_shape),
        interpolation="nearest",
        cmap=plt.cm.nipy_spectral,
    )
    plt.xticks(())
    plt.yticks(())
    plt.title("Labels")
    return plt
79
+
80
+
81
title = "Feature Agglomeration"

# Page layout: heading and description, a cluster-count slider, and a button
# that runs do_submit and shows the resulting figure.
with gr.Blocks(title=title, theme=theme) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(
        "<b>These images show how similar features are merged together using feature agglomeration.</b>"
    )
    gr.Markdown(
        "[Scikit-learn Example](https://scikit-learn.org/stable/auto_examples/cluster/plot_digits_agglomeration.html)"
    )

    # Adjacent string literals (rather than backslash continuations inside one
    # literal) keep the rendered text independent of source indentation and
    # guarantee a single space at every join.
    gr.Markdown(
        "The FeatureAgglomeration uses [agglomerative clustering](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AgglomerativeClustering.html#sklearn.cluster.AgglomerativeClustering) "
        "to group together features that look very similar, thus decreasing the number of features. "
        "It is a dimensionality reduction "
        "tool, see [Unsupervised dimensionality reduction](https://scikit-learn.org/stable/modules/unsupervised_reduction.html#data-reduction)."
    )
    n_clusters = gr.Slider(
        minimum=10,
        maximum=50,
        label="Number of clusters",
        info="Number of clusters for FeatureAgglomeration",
        step=1,
        value=32,
    )

    sub_btn = gr.Button("Submit")
    plt_out = gr.Plot()

    # The slider value is the only input; the returned plot fills plt_out.
    sub_btn.click(fn=do_submit, inputs=[n_clusters], outputs=[plt_out])


if __name__ == "__main__":
    demo.launch()