cmpatino committed on
Commit
bad0412
·
1 Parent(s): 0bf6843

Add quantile as parameter in the UI

Browse files
Files changed (1) hide show
  1. app.py +16 -11
app.py CHANGED
@@ -5,12 +5,12 @@ from sklearn.cluster import MeanShift, estimate_bandwidth
5
  from sklearn.datasets import make_blobs
6
 
7
 
8
- def get_clusters_plot(n_blobs, cluster_std):
9
  X, _, centers = make_blobs(
10
  n_samples=10000, cluster_std=cluster_std, centers=n_blobs, return_centers=True
11
  )
12
 
13
- bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500)
14
 
15
  ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
16
  ms.fit(X)
@@ -20,25 +20,22 @@ def get_clusters_plot(n_blobs, cluster_std):
20
  labels_unique = np.unique(labels)
21
  n_clusters_ = len(labels_unique)
22
 
23
- colors = ["#dede00", "#377eb8", "#f781bf"]
24
- markers = ["x", "o", "^"]
25
-
26
  fig = plt.figure()
27
 
28
- for k, col in zip(range(n_clusters_), colors):
29
  my_members = labels == k
30
  cluster_center = cluster_centers[k]
31
- plt.plot(X[my_members, 0], X[my_members, 1], markers[k], color=col)
32
  plt.plot(
33
  cluster_center[0],
34
  cluster_center[1],
35
- markers[k],
36
- markerfacecolor=col,
37
  markeredgecolor="k",
38
  markersize=14,
39
  )
40
 
41
- return fig
 
42
 
43
 
44
  demo = gr.Interface(
@@ -47,6 +44,14 @@ demo = gr.Interface(
47
  gr.Slider(
48
  minimum=2, maximum=10, label="Number of clusters in data", step=1, value=3
49
  ),
 
 
 
 
 
 
 
 
50
  gr.Slider(
51
  minimum=0.1,
52
  maximum=1,
@@ -55,7 +60,7 @@ demo = gr.Interface(
55
  value=0.6,
56
  ),
57
  ],
58
- gr.Plot(),
59
  allow_flagging="never",
60
  )
61
 
 
5
  from sklearn.datasets import make_blobs
6
 
7
 
8
+ def get_clusters_plot(n_blobs, quantile, cluster_std):
9
  X, _, centers = make_blobs(
10
  n_samples=10000, cluster_std=cluster_std, centers=n_blobs, return_centers=True
11
  )
12
 
13
+ bandwidth = estimate_bandwidth(X, quantile=quantile, n_samples=500)
14
 
15
  ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
16
  ms.fit(X)
 
20
  labels_unique = np.unique(labels)
21
  n_clusters_ = len(labels_unique)
22
 
 
 
 
23
  fig = plt.figure()
24
 
25
+ for k in range(n_clusters_):
26
  my_members = labels == k
27
  cluster_center = cluster_centers[k]
28
+ plt.scatter(X[my_members, 0], X[my_members, 1])
29
  plt.plot(
30
  cluster_center[0],
31
  cluster_center[1],
32
+ "x",
 
33
  markeredgecolor="k",
34
  markersize=14,
35
  )
36
 
37
+ message = f"## True Clusters: {len(centers)} | Detected Clusters: {n_clusters_}"
38
+ return fig, message
39
 
40
 
41
  demo = gr.Interface(
 
44
  gr.Slider(
45
  minimum=2, maximum=10, label="Number of clusters in data", step=1, value=3
46
  ),
47
+ gr.Slider(
48
+ minimum=0,
49
+ maximum=1,
50
+ step=0.05,
51
+ value=0.2,
52
+ label="Quantile",
53
+ info="Used to determine clustering's bandwidth.",
54
+ ),
55
  gr.Slider(
56
  minimum=0.1,
57
  maximum=1,
 
60
  value=0.6,
61
  ),
62
  ],
63
+ [gr.Plot(), gr.Markdown()],
64
  allow_flagging="never",
65
  )
66