# https://scikit-learn.org/stable/auto_examples/cluster/plot_birch_vs_minibatchkmeans.html
from itertools import cycle
from time import time
import gradio as gr
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
from joblib import cpu_count
from sklearn.cluster import Birch, MiniBatchKMeans
from sklearn.datasets import make_blobs
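
# Use matplotlib's non-interactive "agg" backend so figures can be rendered
# inside the Gradio callback without needing a GUI display.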
plt.switch_backend("agg")
def do_submit(n_samples, birch_threshold, birch_n_clusters):
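    """Fit BIRCH (with and without the global clustering step) and MiniBatchKMeans
    on a make_blobs dataset and return a comparison figure plus a timing summary.
    """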
    n_samples = int(n_samples)
    birch_threshold = float(birch_threshold)
    # A slider value of 0 means "skip the global clustering step" (n_clusters=None).
    birch_n_clusters = int(birch_n_clusters) or None
    result = ""
    # Generate centers for the blobs so that they form a 10 x 10 grid.
    xx = np.linspace(-22, 22, 10)
    yy = np.linspace(-22, 22, 10)
    xx, yy = np.meshgrid(xx, yy)
    n_centers = np.hstack((np.ravel(xx)[:, np.newaxis], np.ravel(yy)[:, np.newaxis]))

    # Generate blobs to do a comparison between MiniBatchKMeans and BIRCH.
    X, y = make_blobs(n_samples=n_samples, centers=n_centers, random_state=0)

    # Use all colors that matplotlib provides by default.
    colors_ = cycle(colors.cnames.keys())

    fig = plt.figure(figsize=(12, 4))
    fig.subplots_adjust(left=0.04, right=0.98, bottom=0.1, top=0.9)

    # Compute clustering with BIRCH, with and without the final clustering step,
    # and plot.
    birch_models = [
        Birch(threshold=birch_threshold, n_clusters=None),
        Birch(threshold=birch_threshold, n_clusters=birch_n_clusters),
    ]
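
    # BIRCH first compresses the data into a tree of subclusters (their centers
    # are exposed as subcluster_centers_); the optional global step then clusters
    # those subclusters down to n_clusters final clusters.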
    final_step = ["without global clustering", "with global clustering"]

    for ind, (birch_model, info) in enumerate(zip(birch_models, final_step)):
        t = time()
        birch_model.fit(X)
        result += (
            "BIRCH %s as the final step took %0.2f seconds" % (info, (time() - t))
            + "\n"
        )

        # Plot result
        labels = birch_model.labels_
        centroids = birch_model.subcluster_centers_
        n_clusters = np.unique(labels).size
        result = result + "n_clusters : %d" % n_clusters + "\n"

        ax = fig.add_subplot(1, 3, ind + 1)
        for this_centroid, k, col in zip(centroids, range(n_clusters), colors_):
            mask = labels == k
            ax.scatter(
                X[mask, 0], X[mask, 1], c="w", edgecolor=col, marker=".", alpha=0.5
            )
            if birch_model.n_clusters is None:
                ax.scatter(this_centroid[0], this_centroid[1], marker="+", c="k", s=25)
        ax.set_ylim([-25, 25])
        ax.set_xlim([-25, 25])
        ax.set_autoscaley_on(False)
        ax.set_title("BIRCH %s" % info)
    # Compute clustering with MiniBatchKMeans.
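    # scikit-learn recommends a batch_size of at least 256 * number of cores so
    # that MiniBatchKMeans can parallelize work across all cores.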
    mbk = MiniBatchKMeans(
        init="k-means++",
        n_clusters=100,
        batch_size=256 * cpu_count(),
        n_init=10,
        max_no_improvement=10,
        verbose=0,
        random_state=0,
    )
    t0 = time()
    mbk.fit(X)
    t_mini_batch = time() - t0
    result += "Time taken to run MiniBatchKMeans %0.2f seconds" % t_mini_batch + "\n"
    mbk_means_labels_unique = np.unique(mbk.labels_)

    ax = fig.add_subplot(1, 3, 3)
    for this_centroid, k, col in zip(mbk.cluster_centers_, range(n_clusters), colors_):
        mask = mbk.labels_ == k
        ax.scatter(X[mask, 0], X[mask, 1], marker=".", c="w", edgecolor=col, alpha=0.5)
        ax.scatter(this_centroid[0], this_centroid[1], marker="+", c="k", s=25)
    ax.set_xlim([-25, 25])
    ax.set_ylim([-25, 25])
    ax.set_title("MiniBatchKMeans")
    ax.set_autoscaley_on(False)
    return fig, result
# Theme from - https://huggingface.co/spaces/trl-lib/stack-llama/blob/main/app.py
theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size=gr.themes.sizes.radius_sm,
    font=[
        gr.themes.GoogleFont("Open Sans"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
)
title = "Compare BIRCH and MiniBatchKMeans"
with gr.Blocks(title=title, theme=theme) as demo:
gr.Markdown(f"## {title}")
gr.Markdown(
"[Scikit-learn Example](https://scikit-learn.org/stable/auto_examples/cluster/plot_birch_vs_minibatchkmeans.html)"
)
    gr.Markdown(
        "This example compares the timing of BIRCH (with and without the global "
        "clustering step) and MiniBatchKMeans on a synthetic dataset with 25,000 "
        "samples and 2 features generated using make_blobs.\n"
        "Both MiniBatchKMeans and BIRCH are very scalable algorithms and can run "
        "efficiently on hundreds of thousands or even millions of data points. "
        "The dataset size is limited here to keep the demo responsive, but you can "
        "rerun it with a larger value for n_samples using the slider below.\n\n"
        "If n_clusters is set to None (a slider value of 0), the data is reduced "
        "from 25,000 samples to a set of 158 clusters. This can be viewed as a "
        "preprocessing step before the final (global) clustering step that further "
        "reduces these 158 clusters to 100 clusters."
    )
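
    # Illustration only (not executed here): pressing Submit with the default
    # slider values is equivalent to calling
    #   fig, summary = do_submit(25000, 1.7, 100)
    # where `summary` reports the subcluster count (about 158) found without the
    # global step and the 100 clusters found with it.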
    n_samples = gr.Slider(
        minimum=20000,
        maximum=80000,
        label="Number of samples",
        step=500,
        value=25000,
    )
    birch_threshold = gr.Slider(
        minimum=0.5,
        maximum=2.0,
        label="Birch Threshold",
        step=0.1,
        value=1.7,
    )
    birch_n_clusters = gr.Slider(
        minimum=0,
        maximum=100,
        label="Birch number of clusters (0 = no global clustering)",
        step=1,
        value=100,
    )
    plt_out = gr.Plot()
    output = gr.Textbox(label="Output", multiline=True)
    sub_btn = gr.Button("Submit")
    sub_btn.click(
        fn=do_submit,
        inputs=[n_samples, birch_threshold, birch_n_clusters],
        outputs=[plt_out, output],
    )
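
# Launch the Gradio app when this file is executed directly;
# `demo.launch(share=True)` would additionally create a temporary public link.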
if __name__ == "__main__":
    demo.launch()