Benjamin Bossan commited on
Commit
6674a4f
·
1 Parent(s): bb1f0db

Initial commit

Browse files
Files changed (3) hide show
  1. README.md +5 -1
  2. app.py +206 -0
  3. requirements.txt +2 -0
README.md CHANGED
@@ -8,6 +8,10 @@ sdk_version: 3.0.24
8
  app_file: app.py
9
  pinned: false
10
  license: bsd-3-clause
 
 
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
8
  app_file: app.py
9
  pinned: false
10
  license: bsd-3-clause
11
+ tags:
12
+ - sklearn
13
  ---
14
 
15
+ # Clustering with scikit learn
16
+
17
+ Gradio demo based on this [sklearn demo](https://scikit-learn.org/stable/auto_examples/cluster/plot_cluster_comparison.html).
app.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Gradio demo for different clustering techiniques
2
+
3
+ Derived from https://scikit-learn.org/stable/auto_examples/cluster/plot_cluster_comparison.html
4
+
5
+ """
6
+
7
+ import gradio as gr
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ from sklearn.cluster import (
11
+ AgglomerativeClustering, Birch, DBSCAN, KMeans, MeanShift, OPTICS, SpectralClustering, estimate_bandwidth
12
+ )
13
+ from sklearn.datasets import make_blobs, make_circles, make_moons
14
+ from sklearn.mixture import GaussianMixture
15
+ from sklearn.neighbors import kneighbors_graph
16
+ from sklearn.preprocessing import StandardScaler
17
+
18
+
19
+ plt.style.use('seaborn')
20
+
21
+
22
+ SEED = 0
23
+ N_CLUSTERS = 4
24
+ N_SAMPLES = 1000
25
+ np.random.seed(SEED)
26
+
27
+
28
+ def normalize(X):
29
+ return StandardScaler().fit_transform(X)
30
+
31
+
32
+ def get_regular():
33
+ centers = [[1, 1], [1, -1], [-1, 1], [-1, -1]]
34
+ assert len(centers) == N_CLUSTERS
35
+ X, labels = make_blobs(n_samples=N_SAMPLES, centers=centers, cluster_std=0.7, random_state=SEED)
36
+ return normalize(X), labels
37
+
38
+
39
+ def get_circles():
40
+ X, labels = make_circles(n_samples=N_SAMPLES, factor=0.5, noise=0.05, random_state=SEED)
41
+ return normalize(X), labels
42
+
43
+
44
+ def get_moons():
45
+ X, labels = make_moons(n_samples=N_SAMPLES, noise=0.05, random_state=SEED)
46
+ return normalize(X), labels
47
+
48
+
49
+ def get_noise():
50
+ X, labels = np.random.rand(N_SAMPLES, 2), np.zeros(N_SAMPLES)
51
+ return normalize(X), labels
52
+
53
+
54
+ def get_anisotropic():
55
+ X, labels = make_blobs(n_samples=N_SAMPLES, centers=N_CLUSTERS, random_state=170)
56
+ transformation = [[0.6, -0.6], [-0.4, 0.8]]
57
+ X = np.dot(X, transformation)
58
+ return X, labels
59
+
60
+
61
+ def get_varied():
62
+ X, labels = make_blobs(
63
+ n_samples=N_SAMPLES, cluster_std=[1.0, 2.5, 0.5], random_state=SEED
64
+ )
65
+ return normalize(X), labels
66
+
67
+
68
+ DATA_MAPPING = {
69
+ 'regular': get_regular,
70
+ 'circles': get_circles,
71
+ 'moons': get_moons,
72
+ 'noise': get_noise,
73
+ 'anisotropic': get_anisotropic,
74
+ 'varied': get_varied,
75
+ }
76
+
77
+ def get_kmeans(X, **kwargs):
78
+ model = KMeans(init="k-means++", n_clusters=N_CLUSTERS, n_init=10, random_state=SEED)
79
+ model.set_params(**kwargs)
80
+ return model.fit(X)
81
+
82
+
83
+ def get_dbscan(X, **kwargs):
84
+ model = DBSCAN(eps=0.3)
85
+ model.set_params(**kwargs)
86
+ return model.fit(X)
87
+
88
+
89
+ def get_agglomerative(X, **kwargs):
90
+ connectivity = kneighbors_graph(
91
+ X, n_neighbors=N_CLUSTERS, include_self=False
92
+ )
93
+ # make connectivity symmetric
94
+ connectivity = 0.5 * (connectivity + connectivity.T)
95
+ model = AgglomerativeClustering(
96
+ n_clusters=N_CLUSTERS, linkage="ward", connectivity=connectivity
97
+ )
98
+ model.set_params(**kwargs)
99
+ return model.fit(X)
100
+
101
+
102
+ def get_meanshift(X, **kwargs):
103
+ bandwidth = estimate_bandwidth(X, quantile=0.3)
104
+ model = MeanShift(bandwidth=bandwidth, bin_seeding=True)
105
+ model.set_params(**kwargs)
106
+ return model.fit(X)
107
+
108
+
109
+ def get_spectral(X, **kwargs):
110
+ model = SpectralClustering(
111
+ n_clusters=N_CLUSTERS,
112
+ eigen_solver="arpack",
113
+ affinity="nearest_neighbors",
114
+ )
115
+ model.set_params(**kwargs)
116
+ return model.fit(X)
117
+
118
+
119
+ def get_optics(X, **kwargs):
120
+ model = OPTICS(
121
+ min_samples=7,
122
+ xi=0.05,
123
+ min_cluster_size=0.1,
124
+ )
125
+ model.set_params(**kwargs)
126
+ return model.fit(X)
127
+
128
+
129
+ def get_birch(X, **kwargs):
130
+ model = Birch(n_clusters=3)
131
+ model.set_params(**kwargs)
132
+ return model.fit(X)
133
+
134
+
135
+ def get_gaussianmixture(X, **kwargs):
136
+ model = GaussianMixture(
137
+ n_components=N_CLUSTERS, covariance_type="full", random_state=SEED,
138
+ )
139
+ model.set_params(**kwargs)
140
+ return model.fit(X)
141
+
142
+
143
+ MODEL_MAPPING = {
144
+ 'KMeans': get_kmeans,
145
+ 'DBSCAN': get_dbscan,
146
+ 'AgglomerativeClustering': get_agglomerative,
147
+ 'MeanShift': get_meanshift,
148
+ 'SpectralClustering': get_spectral,
149
+ 'OPTICS': get_optics,
150
+ 'Birch': get_birch,
151
+ 'GaussianMixture': get_gaussianmixture,
152
+ }
153
+
154
+
155
+ def plot_clusters(ax, X, labels):
156
+ for label in range(N_CLUSTERS):
157
+ idx = labels == label
158
+ if not sum(idx):
159
+ continue
160
+ ax.scatter(X[idx, 0], X[idx, 1])
161
+
162
+ ax.grid(None)
163
+ ax.set_xticks([])
164
+ ax.set_yticks([])
165
+ return ax
166
+
167
+
168
+ def cluster(clustering_algorithm: str, dataset: str):
169
+ X, labels = DATA_MAPPING[dataset]()
170
+ model = MODEL_MAPPING[clustering_algorithm](X)
171
+ if hasattr(model, "labels_"):
172
+ y_pred = model.labels_.astype(int)
173
+ else:
174
+ y_pred = model.predict(X)
175
+
176
+ fig, axes = plt.subplots(1, 2, figsize=(16, 8))
177
+
178
+ ax = axes[0]
179
+ plot_clusters(ax, X, labels)
180
+ ax.set_title("True clusters")
181
+
182
+ ax = axes[1]
183
+ plot_clusters(ax, X, y_pred)
184
+ ax.set_title(clustering_algorithm)
185
+
186
+ return fig
187
+
188
+
189
+ demo = gr.Interface(
190
+ fn=cluster,
191
+ inputs=[
192
+ gr.Radio(
193
+ list(MODEL_MAPPING),
194
+ value="KMeans",
195
+ label="clustering algorithm"
196
+ ),
197
+ gr.Radio(
198
+ list(DATA_MAPPING),
199
+ value="regular",
200
+ label="dataset"
201
+ ),
202
+ ],
203
+ outputs=gr.Plot(),
204
+ )
205
+
206
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ matplotlib>=3.5.2
2
+ scikit-learn>=1.0.1