RamAnanth1 committed on
Commit
484451a
·
1 Parent(s): a79cd47

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -2
app.py CHANGED
@@ -2,8 +2,12 @@ import gradio as gr
2
  import numpy as np
3
  import matplotlib.pyplot as plt
4
  from sklearn.model_selection import train_test_split
5
- from sklearn.datasets import load_breast_cancer
6
- from sklearn.tree import DecisionTreeClassifier
 
 
 
 
7
 
8
  theme = gr.themes.Monochrome(
9
  primary_hue="indigo",
@@ -16,6 +20,71 @@ description = f"""
16
  This demo can be used to evaluate the ability of k-means initializations strategies to make the algorithm convergence robust
17
  """
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  with gr.Blocks(theme=theme) as demo:
20
  gr.Markdown('''
21
  <div>
@@ -23,5 +92,9 @@ with gr.Blocks(theme=theme) as demo:
23
  </div>
24
  ''')
25
  gr.Markdown(description)
 
 
 
 
26
 
27
  demo.launch()
 
2
  import numpy as np
3
  import matplotlib.pyplot as plt
4
  from sklearn.model_selection import train_test_split
5
+
6
+ import matplotlib.cm as cm
7
+ from sklearn.utils import shuffle
8
+ from sklearn.utils import check_random_state
9
+ from sklearn.cluster import MiniBatchKMeans
10
+ from sklearn.cluster import KMeans
11
 
12
  theme = gr.themes.Monochrome(
13
  primary_hue="indigo",
 
20
  This demo can be used to evaluate the ability of k-means initializations strategies to make the algorithm convergence robust
21
  """
22
 
23
# TODO: Make the below parameters user passable

# Datasets generation parameters: cluster centers sit on a
# grid_size x grid_size integer lattice, so the number of clusters is the
# square of the grid size.
grid_size = 3
n_clusters = grid_size**2
# Number of points generated around each center, and the standard deviation
# of the normal noise added to each center.
n_samples_per_center = 100
scale = 0.1

# k-means models can do several random inits so as to be able to trade
# CPU time for convergence robustness; these are the n_init values compared.
n_init_range = np.array([1, 5, 10, 15, 20])

# Module-level generator with a fixed seed. quant_evaluation seeds each run
# with its run index instead of using this object.
random_state = np.random.RandomState(0)
35
+
36
def make_data(random_state, n_samples_per_center, grid_size, scale):
    """Generate shuffled 2-D blobs centered on an integer grid.

    Centers are every ``[i, j]`` with ``i`` and ``j`` in ``range(grid_size)``.
    A single block of normal noise is drawn and added to every center, so
    each blob is the same point pattern translated to its center.

    Args:
        random_state: Seed or generator; anything accepted by
            ``sklearn.utils.check_random_state``.
        n_samples_per_center: Number of points generated around each center.
        grid_size: Side length of the grid of centers.
        scale: Standard deviation of the normal noise.

    Returns:
        The result of ``sklearn.utils.shuffle(X, y, ...)``: the stacked
        points and the index of the center each point was generated from,
        shuffled together.
    """
    rng = check_random_state(random_state)
    grid_axis = range(grid_size)
    centers = np.array([[row, col] for row in grid_axis for col in grid_axis])
    n_clusters_true, n_features = centers.shape

    # Drawn once and reused for every center (see docstring).
    noise = rng.normal(scale=scale, size=(n_samples_per_center, n_features))

    blobs = []
    labels = []
    for label, center in enumerate(centers):
        blobs.append(center + noise)
        labels.append([label] * n_samples_per_center)

    X = np.concatenate(blobs)
    y = np.concatenate(labels)
    # The same generator drives the shuffle, after the noise draw.
    return shuffle(X, y, random_state=rng)
48
+
49
def quant_evaluation(n_runs):
    """Plot mean inertia against ``n_init`` for several k-means init setups.

    Four estimator/init combinations are compared (``KMeans`` and
    ``MiniBatchKMeans``, each with ``"k-means++"`` and ``"random"`` init).
    For every combination and every run, a dataset is generated with
    ``make_data`` seeded by the run index, the estimator is fitted once per
    value in the module-level ``n_init_range``, and its ``inertia_`` is
    recorded. Each combination is drawn as an error-bar curve: mean inertia
    across runs, with the standard deviation across runs as the error bar.

    Reads the module-level ``n_init_range``, ``n_samples_per_center``,
    ``grid_size``, ``scale`` and ``n_clusters``.

    Args:
        n_runs: Number of evaluation runs per combination. Converted with
            ``int()`` so that whole-valued floats are accepted.

    Returns:
        The ``matplotlib.pyplot`` module, whose current figure is the newly
        drawn plot.

    Raises:
        ValueError: If ``n_runs`` is smaller than 1.
    """
    # UI callers such as a slider may deliver the value as a float, while
    # ``range`` and the array shape below require a real integer.
    n_runs = int(n_runs)
    if n_runs < 1:
        # With zero runs the mean/std below would be taken over an empty
        # axis and the plot would contain only NaNs.
        raise ValueError(f"n_runs must be at least 1, got {n_runs}")

    # NOTE(review): every call opens a new pyplot-managed figure and nothing
    # in this function closes it - confirm the caller releases figures.
    plt.figure()
    plots = []
    legends = []

    # (estimator class, init strategy, extra constructor kwargs, line format)
    cases = [
        (KMeans, "k-means++", {}, "^-"),
        (KMeans, "random", {}, "o-"),
        (MiniBatchKMeans, "k-means++", {"max_no_improvement": 3}, "x-"),
        (
            MiniBatchKMeans,
            "random",
            {"max_no_improvement": 3, "init_size": 500},
            "d-",
        ),
    ]

    # ``fmt`` rather than ``format`` so the builtin is not shadowed.
    for factory, init, params, fmt in cases:
        label = f"{factory.__name__} with {init} init"
        print(f"Evaluation of {label}")
        # Rows follow n_init_range, columns follow the run index.
        inertia = np.empty((len(n_init_range), n_runs))

        for run_id in range(n_runs):
            # The run index seeds both the dataset and the estimator.
            X, y = make_data(run_id, n_samples_per_center, grid_size, scale)
            for i, n_init in enumerate(n_init_range):
                km = factory(
                    n_clusters=n_clusters,
                    init=init,
                    random_state=run_id,
                    n_init=n_init,
                    **params,
                ).fit(X)
                inertia[i, run_id] = km.inertia_

        p = plt.errorbar(
            n_init_range, inertia.mean(axis=1), inertia.std(axis=1), fmt=fmt
        )
        # The first element of the errorbar container is the data line,
        # which is what the legend should reference.
        plots.append(p[0])
        legends.append(label)

    plt.xlabel("n_init")
    plt.ylabel("inertia")
    plt.legend(plots, legends)
    plt.title(f"Mean inertia for various k-means init across {n_runs} runs")
    return plt
87
+
88
  with gr.Blocks(theme=theme) as demo:
89
  gr.Markdown('''
90
  <div>
 
92
  </div>
93
  ''')
94
  gr.Markdown(description)
95
+ n_runs = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Number of Evaluation Runs")
96
+ run_button = gr.Button('Evaluate')
97
+ plot_inertia = gr.Plot()
98
+ run_button.click(fn=quant_evaluation, inputs=[n_runs], outputs=plot_inertia)
99
 
100
  demo.launch()