Benjamin Bossan committed
Commit · 7c9b8f5 · Parent(s): 0415b11
Change layout to show all models at once
Using Blocks and grid to show the predictions of all models at once to
make it easier to compare them.
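
For context, the layout change below leans on Gradio's Blocks API: `gr.Row` and `gr.Column` are nested to lay the model plots out as a grid, one cell per model. Here is a minimal sketch of that grid pattern, detached from the commit (the `items` list and the empty placeholder `gr.Plot` cells are illustrative, not the committed code):

import math

import gradio as gr

N_COLS = 3
items = ["KMeans", "DBSCAN", "MeanShift", "OPTICS", "Birch"]  # one grid cell per model

with gr.Blocks() as demo:
    counter = 0
    # lay the cells out row by row, N_COLS columns per row
    for _ in range(math.ceil(len(items) / N_COLS)):
        with gr.Row():
            for _ in range(N_COLS):
                with gr.Column():
                    if counter >= len(items):
                        break
                    gr.Plot(label=items[counter])  # empty placeholder plot
                    counter += 1

demo.launch()

The commit uses the same counter-based loop; when the number of models is not a multiple of N_COLS, the trailing columns of the last row simply stay empty.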
app.py CHANGED

@@ -4,6 +4,9 @@ Derived from https://scikit-learn.org/stable/auto_examples/cluster/plot_cluster_
 
 """
 
+import math
+from functools import partial
+
 import gradio as gr
 import matplotlib.pyplot as plt
 import numpy as np
@@ -22,6 +25,12 @@ plt.style.use('seaborn')
 SEED = 0
 MAX_CLUSTERS = 10
 N_SAMPLES = 1000
+N_COLS = 3
+FIGSIZE = 7, 7  # does not affect size in webpage
+COLORS = [
+    'blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan'
+]
+assert len(COLORS) >= MAX_CLUSTERS, "Not enough different colors for all clusters"
 np.random.seed(SEED)
 
 
@@ -88,19 +97,29 @@ DATA_MAPPING = {
     'varied': get_varied,
 }
 
-def get_kmeans(X, n_clusters, **kwargs):
+
+def get_groundtruth_model(X, labels, n_clusters, **kwargs):
+    # dummy model to show true label distribution
+    class Dummy:
+        def __init__(self, y):
+            self.labels_ = labels
+
+    return Dummy(labels)
+
+
+def get_kmeans(X, labels, n_clusters, **kwargs):
     model = KMeans(init="k-means++", n_clusters=n_clusters, n_init=10, random_state=SEED)
     model.set_params(**kwargs)
     return model.fit(X)
 
 
-def get_dbscan(X, n_clusters, **kwargs):
+def get_dbscan(X, labels, n_clusters, **kwargs):
     model = DBSCAN(eps=0.3)
     model.set_params(**kwargs)
     return model.fit(X)
 
 
-def get_agglomerative(X, n_clusters, **kwargs):
+def get_agglomerative(X, labels, n_clusters, **kwargs):
     connectivity = kneighbors_graph(
         X, n_neighbors=n_clusters, include_self=False
     )
@@ -113,14 +132,14 @@ def get_agglomerative(X, n_clusters, **kwargs):
     return model.fit(X)
 
 
-def get_meanshift(X, n_clusters, **kwargs):
-    bandwidth = estimate_bandwidth(X, quantile=0.
+def get_meanshift(X, labels, n_clusters, **kwargs):
+    bandwidth = estimate_bandwidth(X, quantile=0.25)
     model = MeanShift(bandwidth=bandwidth, bin_seeding=True)
     model.set_params(**kwargs)
     return model.fit(X)
 
 
-def get_spectral(X, n_clusters, **kwargs):
+def get_spectral(X, labels, n_clusters, **kwargs):
     model = SpectralClustering(
         n_clusters=n_clusters,
         eigen_solver="arpack",
@@ -130,7 +149,7 @@ def get_spectral(X, n_clusters, **kwargs):
     return model.fit(X)
 
 
-def get_optics(X, n_clusters, **kwargs):
+def get_optics(X, labels, n_clusters, **kwargs):
     model = OPTICS(
         min_samples=7,
         xi=0.05,
@@ -140,13 +159,13 @@ def get_optics(X, n_clusters, **kwargs):
     return model.fit(X)
 
 
-def get_birch(X, n_clusters, **kwargs):
+def get_birch(X, labels, n_clusters, **kwargs):
     model = Birch(n_clusters=n_clusters)
     model.set_params(**kwargs)
     return model.fit(X)
 
 
-def get_gaussianmixture(X, n_clusters, **kwargs):
+def get_gaussianmixture(X, labels, n_clusters, **kwargs):
     model = GaussianMixture(
         n_components=n_clusters, covariance_type="full", random_state=SEED,
     )
@@ -155,25 +174,26 @@ def get_gaussianmixture(X, n_clusters, **kwargs):
 
 
 MODEL_MAPPING = {
+    'True labels': get_groundtruth_model,
     'KMeans': get_kmeans,
     'DBSCAN': get_dbscan,
-    'AgglomerativeClustering': get_agglomerative,
     'MeanShift': get_meanshift,
     'SpectralClustering': get_spectral,
     'OPTICS': get_optics,
     'Birch': get_birch,
     'GaussianMixture': get_gaussianmixture,
+    'AgglomerativeClustering': get_agglomerative,
 }
 
 
 def plot_clusters(ax, X, labels):
     set_clusters = set(labels)
     set_clusters.discard(-1)  # -1 signifies outliers, which we plot separately
-    for label in sorted(set_clusters):
+    for label, color in zip(sorted(set_clusters), COLORS):
         idx = labels == label
         if not sum(idx):
             continue
-        ax.scatter(X[idx, 0], X[idx, 1])
+        ax.scatter(X[idx, 0], X[idx, 1], color=color)
 
     # show outliers (if any)
     idx = labels == -1
@@ -186,26 +206,23 @@ def plot_clusters(ax, X, labels):
     return ax
 
 
-def cluster(clustering_algorithm: str, dataset: str, n_clusters: int):
-
+def cluster(dataset: str, n_clusters: int, clustering_algorithm: str):
+    if isinstance(n_clusters, dict):
+        n_clusters = n_clusters['value']
+    else:
+        n_clusters = int(n_clusters)
+
     X, labels = DATA_MAPPING[dataset](n_clusters)
-    model = MODEL_MAPPING[clustering_algorithm](X, n_clusters=n_clusters)
+    model = MODEL_MAPPING[clustering_algorithm](X, labels, n_clusters=n_clusters)
     if hasattr(model, "labels_"):
         y_pred = model.labels_.astype(int)
     else:
         y_pred = model.predict(X)
 
-    fig,
-
-    # show true labels in first panel
-    ax = axes[0]
-    plot_clusters(ax, X, labels)
-    ax.set_title("True clusters")
+    fig, ax = plt.subplots(figsize=FIGSIZE)
 
-    # show learned clusters in second panel
-    ax = axes[1]
     plot_clusters(ax, X, y_pred)
-    ax.set_title(clustering_algorithm)
+    ax.set_title(clustering_algorithm, fontsize=16)
 
     return fig
 
@@ -213,31 +230,42 @@ def cluster(clustering_algorithm: str, dataset: str, n_clusters: int):
 title = "Clustering with Scikit-learn"
 description = (
     "This example shows how different clustering algorithms work. Simply pick "
-    "the
-)
-demo = gr.Interface(
-    fn=cluster,
-    inputs=[
-        gr.Radio(
-            list(MODEL_MAPPING),
-            value="KMeans",
-            label="clustering algorithm"
-        ),
-        gr.Radio(
-            list(DATA_MAPPING),
-            value="regular",
-            label="dataset"
-        ),
-        gr.Slider(
-            minimum=1,
-            maximum=MAX_CLUSTERS,
-            value=4,
-            step=1,
-        )
-    ],
-    title=title,
-    description=description,
-    outputs=gr.Plot(),
+    "the dataset and the number of clusters to see how the clustering algorithms work."
 )
 
+with gr.Blocks(title=title) as demo:
+    gr.HTML(f"<b>{title}</b>")
+    gr.Markdown(description)
+
+    input_models = list(MODEL_MAPPING)
+    input_data = gr.Radio(
+        list(DATA_MAPPING),
+        value="regular",
+        label="dataset"
+    )
+    input_n_clusters = gr.Slider(
+        minimum=1,
+        maximum=MAX_CLUSTERS,
+        value=4,
+        step=1,
+        label='Number of clusters'
+    )
+    n_rows = int(math.ceil(len(input_models) / N_COLS))
+    counter = 0
+    # code below is not very elegant, maybe there is a better way?
+    for i in range(n_rows):
+        with gr.Row():
+            for j in range(N_COLS):
+                with gr.Column():
+                    if counter >= len(input_models):
+                        break
+
+                    input_model = input_models[counter]
+                    plot = gr.Plot(label=input_model)
+                    fn = partial(cluster, clustering_algorithm=input_model)
+                    input_data.change(fn=fn, inputs=[input_data, input_n_clusters], outputs=plot)
+                    input_n_clusters.change(fn=fn, inputs=[input_data, input_n_clusters], outputs=plot)
+                    counter += 1
+
+
 demo.launch()
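
Two details of the new wiring are worth spelling out. First, the 'True labels' panel reuses the same code path as the real models because get_groundtruth_model returns an object that merely exposes a labels_ attribute, which is all cluster() inspects. Second, a Gradio event handler only receives the components listed in inputs, so each panel's algorithm is baked into its callback with functools.partial. A stripped-down sketch of that wiring, with a hypothetical render function standing in for cluster():

from functools import partial

import gradio as gr

def render(dataset, n_clusters, algorithm):
    # hypothetical stand-in for cluster(); returns text instead of a matplotlib figure
    return f"{algorithm} on {dataset!r} with {int(n_clusters)} clusters"

with gr.Blocks() as demo:
    data = gr.Radio(["regular", "circles", "moons"], value="regular", label="dataset")
    k = gr.Slider(minimum=1, maximum=10, value=4, step=1, label="Number of clusters")
    for algorithm in ["KMeans", "DBSCAN"]:
        out = gr.Textbox(label=algorithm)
        fn = partial(render, algorithm=algorithm)  # bake this panel's algorithm into its callback
        # every panel listens to the same two inputs, so one change re-renders the whole grid
        data.change(fn=fn, inputs=[data, k], outputs=out)
        k.change(fn=fn, inputs=[data, k], outputs=out)

demo.launch()

Compared with the old gr.Interface version, where a single callback drove a single plot and the algorithm was itself an input, this gives each model its own plot that updates on the same events, which is what makes the side-by-side comparison possible.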