Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import gradio as gr | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from sklearn.cluster import MeanShift, estimate_bandwidth | |
from sklearn.datasets import make_blobs | |
def get_clusters_plot(n_blobs, quantile, cluster_std):
    """Cluster freshly generated blobs with MeanShift and plot the result.

    Args:
        n_blobs: Number of blob centers to generate (passed to
            ``make_blobs`` as ``centers``).
        quantile: Quantile forwarded to ``estimate_bandwidth`` to derive the
            MeanShift bandwidth.
        cluster_std: Standard deviation of each generated blob.

    Returns:
        A ``(figure, message)`` tuple: the matplotlib figure with one scatter
        per estimated cluster and an "x" marker at each estimated center, and
        an HTML paragraph saying whether the estimated cluster count matches
        the number of generated centers.
    """
    # make_blobs only treats `centers` as a count when it is an integer;
    # coerce in case the caller hands over an integral float such as 3.0.
    n_blobs = int(n_blobs)

    X, _, centers = make_blobs(
        n_samples=10000, cluster_std=cluster_std, centers=n_blobs, return_centers=True
    )

    bandwidth = estimate_bandwidth(X, quantile=quantile, n_samples=500)
    ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
    ms.fit(X)

    labels = ms.labels_
    cluster_centers = ms.cluster_centers_
    n_clusters_ = len(np.unique(labels))

    # Draw on explicit axes instead of pyplot's implicit "current figure":
    # that global can be switched by another call running at the same time,
    # which would send this call's artists to the wrong figure.
    fig, ax = plt.subplots()
    for k in range(n_clusters_):
        my_members = labels == k
        cluster_center = cluster_centers[k]
        ax.scatter(X[my_members, 0], X[my_members, 1])
        ax.plot(
            cluster_center[0],
            cluster_center[1],
            "x",
            markeredgecolor="k",
            markersize=14,
        )
    # Unregister the figure from pyplot so repeated calls do not accumulate
    # open figures; the Figure object itself stays usable by the caller.
    plt.close(fig)

    if len(centers) != n_clusters_:
        message = (
            '<p style="text-align: center;">'
            + f"The number of estimated clusters ({n_clusters_})"
            + f" differs from the true number of clusters ({n_blobs})."
            + " Try changing the `Quantile` parameter.</p>"
        )
    else:
        message = (
            '<p style="text-align: center;">'
            + f"The number of estimated clusters ({n_clusters_})"
            + f" matches the true number of clusters ({n_blobs})!</p>"
        )
    return fig, message
# Gradio UI: three sliders feed get_clusters_plot, which returns a plot and an
# HTML status message.
demo = gr.Interface(
    get_clusters_plot,
    [
        gr.Slider(
            minimum=2, maximum=10, label="Number of clusters in data", step=1, value=3
        ),
        gr.Slider(
            # A quantile of 0 makes estimate_bandwidth return a bandwidth of 0,
            # which MeanShift cannot use, so the range starts one step above 0.
            minimum=0.05,
            maximum=1,
            step=0.05,
            value=0.2,
            label="Quantile",
            info="Used to determine clustering's bandwidth.",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1,
            label="Clusters standard deviation",
            step=0.1,
            value=0.6,
        ),
    ],
    [gr.Plot(label="Clusters' Plot"), gr.HTML()],
    allow_flagging="never",
)

# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()