Benjamin Bossan committed
Commit 7c9b8f5 · 1 Parent(s): 0415b11

Change layout to show all models at once


Using Blocks and a grid layout to show the predictions of all models at
once, to make it easier to compare them.
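For context, the gist of the new layout is the Blocks/Row/Column grid pattern sketched below. This is a minimal, hypothetical sketch, not the app's code: the MODELS list, the describe() helper, and the textbox outputs are placeholders for illustration; the real app builds one gr.Plot per clustering algorithm, as the diff below shows.

import math
from functools import partial

import gradio as gr

MODELS = ["KMeans", "DBSCAN", "Birch"]  # placeholder names; the app iterates over MODEL_MAPPING
N_COLS = 2


def describe(dataset, model_name):
    # stand-in for the app's cluster() function, which returns a matplotlib figure
    return f"{model_name} on {dataset}"


with gr.Blocks() as demo:
    dataset = gr.Radio(["regular", "circles"], value="regular", label="dataset")
    n_rows = math.ceil(len(MODELS) / N_COLS)
    idx = 0
    for _ in range(n_rows):
        with gr.Row():
            for _ in range(N_COLS):
                if idx >= len(MODELS):
                    break
                with gr.Column():
                    out = gr.Textbox(label=MODELS[idx])
                    # bind the model name now so each output gets its own callback
                    dataset.change(fn=partial(describe, model_name=MODELS[idx]),
                                   inputs=dataset, outputs=out)
                idx += 1

demo.launch()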

Files changed (1):
  app.py +77 -49
app.py CHANGED
@@ -4,6 +4,9 @@ Derived from https://scikit-learn.org/stable/auto_examples/cluster/plot_cluster_
 
 """
 
+import math
+from functools import partial
+
 import gradio as gr
 import matplotlib.pyplot as plt
 import numpy as np
@@ -22,6 +25,12 @@ plt.style.use('seaborn')
 SEED = 0
 MAX_CLUSTERS = 10
 N_SAMPLES = 1000
+N_COLS = 3
+FIGSIZE = 7, 7  # does not affect size in webpage
+COLORS = [
+    'blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan'
+]
+assert len(COLORS) >= MAX_CLUSTERS, "Not enough different colors for all clusters"
 np.random.seed(SEED)
 
 
@@ -88,19 +97,29 @@ DATA_MAPPING = {
     'varied': get_varied,
 }
 
-def get_kmeans(X, n_clusters, **kwargs):
+
+def get_groundtruth_model(X, labels, n_clusters, **kwargs):
+    # dummy model to show true label distribution
+    class Dummy:
+        def __init__(self, y):
+            self.labels_ = labels
+
+    return Dummy(labels)
+
+
+def get_kmeans(X, labels, n_clusters, **kwargs):
     model = KMeans(init="k-means++", n_clusters=n_clusters, n_init=10, random_state=SEED)
     model.set_params(**kwargs)
     return model.fit(X)
 
 
-def get_dbscan(X, n_clusters, **kwargs):
+def get_dbscan(X, labels, n_clusters, **kwargs):
     model = DBSCAN(eps=0.3)
     model.set_params(**kwargs)
     return model.fit(X)
 
 
-def get_agglomerative(X, n_clusters, **kwargs):
+def get_agglomerative(X, labels, n_clusters, **kwargs):
     connectivity = kneighbors_graph(
         X, n_neighbors=n_clusters, include_self=False
     )
@@ -113,14 +132,14 @@ def get_agglomerative(X, n_clusters, **kwargs):
     return model.fit(X)
 
 
-def get_meanshift(X, n_clusters, **kwargs):
-    bandwidth = estimate_bandwidth(X, quantile=0.3)
+def get_meanshift(X, labels, n_clusters, **kwargs):
+    bandwidth = estimate_bandwidth(X, quantile=0.25)
     model = MeanShift(bandwidth=bandwidth, bin_seeding=True)
     model.set_params(**kwargs)
     return model.fit(X)
 
 
-def get_spectral(X, n_clusters, **kwargs):
+def get_spectral(X, labels, n_clusters, **kwargs):
     model = SpectralClustering(
         n_clusters=n_clusters,
         eigen_solver="arpack",
@@ -130,7 +149,7 @@ def get_spectral(X, n_clusters, **kwargs):
     return model.fit(X)
 
 
-def get_optics(X, n_clusters, **kwargs):
+def get_optics(X, labels, n_clusters, **kwargs):
     model = OPTICS(
         min_samples=7,
         xi=0.05,
@@ -140,13 +159,13 @@ def get_optics(X, n_clusters, **kwargs):
     return model.fit(X)
 
 
-def get_birch(X, n_clusters, **kwargs):
+def get_birch(X, labels, n_clusters, **kwargs):
     model = Birch(n_clusters=n_clusters)
     model.set_params(**kwargs)
     return model.fit(X)
 
 
-def get_gaussianmixture(X, n_clusters, **kwargs):
+def get_gaussianmixture(X, labels, n_clusters, **kwargs):
     model = GaussianMixture(
         n_components=n_clusters, covariance_type="full", random_state=SEED,
     )
@@ -155,25 +174,26 @@ def get_gaussianmixture(X, n_clusters, **kwargs):
 
 
 MODEL_MAPPING = {
+    'True labels': get_groundtruth_model,
     'KMeans': get_kmeans,
     'DBSCAN': get_dbscan,
-    'AgglomerativeClustering': get_agglomerative,
     'MeanShift': get_meanshift,
     'SpectralClustering': get_spectral,
     'OPTICS': get_optics,
     'Birch': get_birch,
     'GaussianMixture': get_gaussianmixture,
+    'AgglomerativeClustering': get_agglomerative,
 }
 
 
 def plot_clusters(ax, X, labels):
     set_clusters = set(labels)
     set_clusters.discard(-1)  # -1 signifiies outliers, which we plot separately
-    for label in sorted(set_clusters):
+    for label, color in zip(sorted(set_clusters), COLORS):
         idx = labels == label
         if not sum(idx):
             continue
-        ax.scatter(X[idx, 0], X[idx, 1])
+        ax.scatter(X[idx, 0], X[idx, 1], color=color)
 
     # show outliers (if any)
     idx = labels == -1
@@ -186,26 +206,23 @@ def plot_clusters(ax, X, labels):
     return ax
 
 
-def cluster(clustering_algorithm: str, dataset: str, n_clusters: int):
-    n_clusters = int(n_clusters)
+def cluster(dataset: str, n_clusters: int, clustering_algorithm: str):
+    if isinstance(n_clusters, dict):
+        n_clusters = n_clusters['value']
+    else:
+        n_clusters = int(n_clusters)
+
     X, labels = DATA_MAPPING[dataset](n_clusters)
-    model = MODEL_MAPPING[clustering_algorithm](X, n_clusters=n_clusters)
+    model = MODEL_MAPPING[clustering_algorithm](X, labels, n_clusters=n_clusters)
    if hasattr(model, "labels_"):
         y_pred = model.labels_.astype(int)
     else:
         y_pred = model.predict(X)
 
-    fig, axes = plt.subplots(1, 2, figsize=(16, 8))
-
-    # show true labels in first panel
-    ax = axes[0]
-    plot_clusters(ax, X, labels)
-    ax.set_title("True clusters")
+    fig, ax = plt.subplots(figsize=FIGSIZE)
 
-    # show learned clusters in second panel
-    ax = axes[1]
     plot_clusters(ax, X, y_pred)
-    ax.set_title(clustering_algorithm)
+    ax.set_title(clustering_algorithm, fontsize=16)
 
     return fig
 
@@ -213,31 +230,42 @@ def cluster(clustering_algorithm: str, dataset: str, n_clusters: int):
 title = "Clustering with Scikit-learn"
 description = (
     "This example shows how different clustering algorithms work. Simply pick "
-    "the algorithm and the dataset to see how the clustering algorithms work."
-)
-demo = gr.Interface(
-    fn=cluster,
-    inputs=[
-        gr.Radio(
-            list(MODEL_MAPPING),
-            value="KMeans",
-            label="clustering algorithm"
-        ),
-        gr.Radio(
-            list(DATA_MAPPING),
-            value="regular",
-            label="dataset"
-        ),
-        gr.Slider(
-            minimum=1,
-            maximum=MAX_CLUSTERS,
-            value=4,
-            step=1,
-        )
-    ],
-    title=title,
-    description=description,
-    outputs=gr.Plot(),
+    "the dataset and the number of clusters to see how the clustering algorithms work."
 )
 
+with gr.Blocks(title=title) as demo:
+    gr.HTML(f"<b>{title}</b>")
+    gr.Markdown(description)
+
+    input_models = list(MODEL_MAPPING)
+    input_data = gr.Radio(
+        list(DATA_MAPPING),
+        value="regular",
+        label="dataset"
+    )
+    input_n_clusters = gr.Slider(
+        minimum=1,
+        maximum=MAX_CLUSTERS,
+        value=4,
+        step=1,
+        label='Number of clusters'
+    )
+    n_rows = int(math.ceil(len(input_models) / N_COLS))
+    counter = 0
+    # code below is not very elegant, maybe there is a better way?
+    for i in range(n_rows):
+        with gr.Row():
+            for j in range(N_COLS):
+                with gr.Column():
+                    if counter >= len(input_models):
+                        break
+
+                    input_model = input_models[counter]
+                    plot = gr.Plot(label=input_model)
+                    fn = partial(cluster, clustering_algorithm=input_model)
+                    input_data.change(fn=fn, inputs=[input_data, input_n_clusters], outputs=plot)
+                    input_n_clusters.change(fn=fn, inputs=[input_data, input_n_clusters], outputs=plot)
+                    counter += 1
+
+
 demo.launch()
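One detail worth calling out in the loop above: fn = partial(cluster, clustering_algorithm=input_model) binds the algorithm name at loop time, which is likely why the commit reaches for functools.partial rather than a lambda. A plain lambda would be bitten by Python's late binding, and every plot would end up wired to the last model in the list. A minimal illustration in plain Python (the names f, callbacks_bad, and callbacks_good are made up for this example):

from functools import partial


def f(x, name):
    return f"{name}: {x}"


# late binding: every lambda looks up `name` at call time, after the loop has ended
callbacks_bad = [lambda x: f(x, name) for name in ["a", "b"]]
print([cb(1) for cb in callbacks_bad])   # ['b: 1', 'b: 1']

# partial captures the value of `name` at creation time
callbacks_good = [partial(f, name=name) for name in ["a", "b"]]
print([cb(1) for cb in callbacks_good])  # ['a: 1', 'b: 1']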