# Author: cmpatino — commit bad0412: "Add quantile as parameter in the UI" (1.78 kB)
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs
def get_clusters_plot(n_blobs, quantile, cluster_std):
    """Cluster synthetic blobs with MeanShift and plot the result.

    Draws 10,000 points from ``n_blobs`` Gaussian blobs, estimates the
    MeanShift bandwidth from a 500-point subsample at the given ``quantile``,
    fits MeanShift with bin seeding, and scatters each detected cluster with
    its center marked by a black ``x``.

    Args:
        n_blobs: Number of true blob centers to generate. Coerced to ``int``
            because ``make_blobs`` needs an integer count and UI widgets may
            deliver numbers as floats.
        quantile: Quantile passed to ``estimate_bandwidth`` (between 0 and 1).
        cluster_std: Standard deviation of each generated blob.

    Returns:
        A ``(figure, message)`` tuple: a Matplotlib ``Figure`` and a Markdown
        string comparing the true and detected cluster counts. If the
        estimated bandwidth is not positive (e.g. ``quantile`` is 0), the
        figure is left empty and the message says so instead of raising.
    """
    X, _, centers = make_blobs(
        n_samples=10000,
        cluster_std=cluster_std,
        centers=int(n_blobs),
        return_centers=True,
    )

    # Build the figure without pyplot: pyplot keeps a global registry of open
    # figures (never closed here, so they would accumulate across requests)
    # and a shared "current axes" that concurrent requests would all draw on.
    fig = Figure()
    ax = fig.subplots()

    bandwidth = estimate_bandwidth(X, quantile=quantile, n_samples=500)
    if bandwidth <= 0:
        # Very small quantiles (e.g. 0) make estimate_bandwidth return 0,
        # and MeanShift cannot run with a non-positive bandwidth.
        message = (
            f"## True Clusters: {len(centers)} | Detected Clusters: n/a "
            "(estimated bandwidth is zero; increase the quantile)"
        )
        return fig, message

    ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
    ms.fit(X)
    labels = ms.labels_
    cluster_centers = ms.cluster_centers_
    n_clusters_ = len(np.unique(labels))

    for k in range(n_clusters_):
        my_members = labels == k
        cluster_center = cluster_centers[k]
        ax.scatter(X[my_members, 0], X[my_members, 1])
        ax.plot(
            cluster_center[0],
            cluster_center[1],
            "x",
            markeredgecolor="k",
            markersize=14,
        )

    message = f"## True Clusters: {len(centers)} | Detected Clusters: {n_clusters_}"
    return fig, message
# Gradio UI: the three sliders are passed, in order, as the positional
# arguments (n_blobs, quantile, cluster_std) of get_clusters_plot.
demo = gr.Interface(
    get_clusters_plot,
    [
        gr.Slider(
            minimum=2, maximum=10, label="Number of clusters in data", step=1, value=3
        ),
        # Start at one step (0.05) instead of 0: a quantile of 0 makes
        # estimate_bandwidth return a zero bandwidth, which MeanShift cannot
        # use, so that setting could never produce a clustering.
        gr.Slider(
            minimum=0.05,
            maximum=1,
            step=0.05,
            value=0.2,
            label="Quantile",
            info="Used to determine clustering's bandwidth.",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1,
            label="Cluster standard deviation",
            step=0.1,
            value=0.6,
        ),
    ],
    # Outputs match get_clusters_plot's return: (figure, Markdown message).
    [gr.Plot(), gr.Markdown()],
    allow_flagging="never",
)
def main():
    """Start the Gradio server for the clustering demo."""
    demo.launch()


# Launch only when run as a script, not when this module is imported.
if __name__ == "__main__":
    main()