|
|
|
|
|
from itertools import cycle |
|
from time import time |
|
|
|
import gradio as gr |
|
import matplotlib.colors as colors |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
from joblib import cpu_count |
|
from sklearn.cluster import Birch, MiniBatchKMeans |
|
from sklearn.datasets import make_blobs |
|
|
|
# Render with the Agg backend: figures are only handed to Gradio for display,
# so no GUI window toolkit is needed.
plt.switch_backend("agg")
|
|
|
|
|
def do_submit(n_samples, birch_threshold, birch_n_clusters):
    """Fit BIRCH (twice) and MiniBatchKMeans on synthetic blobs and plot them.

    Parameters
    ----------
    n_samples : number
        Total number of points generated with ``make_blobs``; cast to int.
    birch_threshold : number
        ``threshold`` passed to both BIRCH models; cast to float.
    birch_n_clusters : number
        ``n_clusters`` passed to the BIRCH model that runs the final global
        clustering step; cast to int.

    Returns
    -------
    fig : matplotlib.figure.Figure
        A 1x3 figure: BIRCH without global clustering, BIRCH with global
        clustering, and MiniBatchKMeans.
    result : str
        Newline-terminated report lines with the fit timings and the number
        of clusters found by each BIRCH model.

    Raises
    ------
    gr.Error
        If ``birch_threshold`` is not positive or ``birch_n_clusters`` is
        smaller than 1.
    """
    n_samples = int(n_samples)
    birch_threshold = float(birch_threshold)
    birch_n_clusters = int(birch_n_clusters)

    # Reject values scikit-learn would refuse, with a message the UI can show.
    if birch_threshold <= 0:
        raise gr.Error("Birch threshold must be greater than 0.")
    if birch_n_clusters < 1:
        raise gr.Error("Birch number of clusters must be at least 1.")

    report_lines = []

    # Blob centres laid out on a regular 10 x 10 grid (100 centres).
    grid_axis = np.linspace(-22, 22, 10)
    xx, yy = np.meshgrid(grid_axis, grid_axis)
    n_centers = np.hstack((np.ravel(xx)[:, np.newaxis], np.ravel(yy)[:, np.newaxis]))

    # Only the points are needed; the ground-truth blob labels are discarded.
    X, _ = make_blobs(n_samples=n_samples, centers=n_centers, random_state=0)

    # One endless colour iterator shared by all three subplots.
    colors_ = cycle(colors.cnames.keys())

    fig = plt.figure(figsize=(12, 4))
    fig.subplots_adjust(left=0.04, right=0.98, bottom=0.1, top=0.9)

    # Both models use the threshold chosen by the caller; only the second one
    # runs the final global clustering step with the requested cluster count.
    birch_models = [
        Birch(threshold=birch_threshold, n_clusters=None),
        Birch(threshold=birch_threshold, n_clusters=birch_n_clusters),
    ]
    final_step = ["without global clustering", "with global clustering"]

    for ind, (birch_model, info) in enumerate(zip(birch_models, final_step)):
        t = time()
        birch_model.fit(X)
        elapsed = time() - t
        report_lines.append(
            f"BIRCH {info} as the final step took {elapsed:0.2f} seconds"
        )

        labels = birch_model.labels_
        centroids = birch_model.subcluster_centers_
        n_clusters = np.unique(labels).size
        report_lines.append(f"n_clusters : {n_clusters:d}")

        ax = fig.add_subplot(1, 3, ind + 1)
        for this_centroid, k, col in zip(centroids, range(n_clusters), colors_):
            mask = labels == k
            ax.scatter(
                X[mask, 0], X[mask, 1], c="w", edgecolor=col, marker=".", alpha=0.5
            )
            # Sub-cluster centres are only drawn for the model that skips the
            # global clustering step.
            if birch_model.n_clusters is None:
                ax.scatter(this_centroid[0], this_centroid[1], marker="+", c="k", s=25)
        ax.set_ylim([-25, 25])
        ax.set_xlim([-25, 25])
        ax.set_autoscaley_on(False)
        ax.set_title(f"BIRCH {info}")

    # MiniBatchKMeans reference run, fixed at 100 clusters (one per blob centre).
    mbk = MiniBatchKMeans(
        init="k-means++",
        n_clusters=100,
        batch_size=256 * cpu_count(),
        n_init=10,
        max_no_improvement=10,
        verbose=0,
        random_state=0,
    )
    t0 = time()
    mbk.fit(X)
    t_mini_batch = time() - t0
    report_lines.append(
        f"Time taken to run MiniBatchKMeans {t_mini_batch:0.2f} seconds"
    )

    ax = fig.add_subplot(1, 3, 3)
    # Walk MiniBatchKMeans' own centres: the row index of cluster_centers_ is
    # the label value, so this no longer depends on BIRCH's cluster count.
    for k, (this_centroid, col) in enumerate(zip(mbk.cluster_centers_, colors_)):
        mask = mbk.labels_ == k
        ax.scatter(X[mask, 0], X[mask, 1], marker=".", c="w", edgecolor=col, alpha=0.5)
        ax.scatter(this_centroid[0], this_centroid[1], marker="+", c="k", s=25)
    ax.set_xlim([-25, 25])
    ax.set_ylim([-25, 25])
    ax.set_title("MiniBatchKMeans")
    ax.set_autoscaley_on(False)

    # Drop the figure from pyplot's global registry so repeated submissions do
    # not accumulate open figures; the Figure object itself stays usable.
    plt.close(fig)

    result = "".join(f"{line}\n" for line in report_lines)
    return fig, result
|
|
|
|
|
|
|
# Monochrome Gradio theme: indigo/blue/slate palette, small corner radius, and
# Open Sans (loaded from Google Fonts) ahead of generic sans-serif fallbacks.
theme = gr.themes.Monochrome(
    font=[gr.themes.GoogleFont("Open Sans"), "ui-sans-serif", "system-ui", "sans-serif"],
    radius_size=gr.themes.sizes.radius_sm,
    neutral_hue="slate",
    secondary_hue="blue",
    primary_hue="indigo",
)
|
|
|
# Page title, reused for the browser tab and the heading inside the page.
title = "Compare BIRCH and MiniBatchKMeans"

# UI layout: description, three parameter sliders, the plot and text outputs,
# and a button that runs do_submit with the slider values.
with gr.Blocks(title=title, theme=theme) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(
        "[Scikit-learn Example](https://scikit-learn.org/stable/auto_examples/cluster/plot_birch_vs_minibatchkmeans.html)"
    )

    gr.Markdown(
        "This example compares the timing of BIRCH (with and without the global clustering step) and "
        "MiniBatchKMeans on a synthetic dataset having 25,000 samples and 2 features generated using make_blobs."
        "\n Both MiniBatchKMeans and BIRCH are very scalable algorithms and could run efficiently on hundreds of thousands or "
        "even millions of datapoints. We chose to limit the dataset size of this example in the interest of keeping our "
        "Continuous Integration resource usage reasonable but the interested reader might enjoy editing this script to "
        "rerun it with a larger value for n_samples."
        "\n\n"
        "If n_clusters is set to None, the data is reduced from 25,000 samples to a set of 158 clusters. This can be viewed as a preprocessing step before the final (global) clustering step that further reduces these 158 clusters to 100 clusters."
    )

    n_samples = gr.Slider(
        minimum=20000,
        maximum=80000,
        label="Number of samples",
        step=500,
        value=25000,
    )
    birch_threshold = gr.Slider(
        minimum=0.5,
        maximum=2.0,
        label="Birch Threshold",
        step=0.1,
        value=1.7,
    )
    # A cluster count of 0 is not a valid integer n_clusters for Birch, so the
    # slider starts at 1.
    birch_n_clusters = gr.Slider(
        minimum=1,
        maximum=100,
        label="Birch number of clusters",
        step=1,
        value=100,
    )

    plt_out = gr.Plot()
    output = gr.Textbox(label="Output", multiline=True)

    sub_btn = gr.Button("Submit")
    sub_btn.click(
        fn=do_submit,
        inputs=[n_samples, birch_threshold, birch_n_clusters],
        outputs=[plt_out, output],
    )
|
|
|
|
|
# Launch the Gradio app only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()
|
|
|
|