"""Gradio demo for different clustering techiniques Derived from https://scikit-learn.org/stable/auto_examples/cluster/plot_cluster_comparison.html """ import gradio as gr import matplotlib.pyplot as plt import numpy as np from sklearn.cluster import ( AgglomerativeClustering, Birch, DBSCAN, KMeans, MeanShift, OPTICS, SpectralClustering, estimate_bandwidth ) from sklearn.datasets import make_blobs, make_circles, make_moons from sklearn.mixture import GaussianMixture from sklearn.neighbors import kneighbors_graph from sklearn.preprocessing import StandardScaler plt.style.use('seaborn') SEED = 0 MAX_CLUSTERS = 10 N_SAMPLES = 1000 np.random.seed(SEED) def normalize(X): return StandardScaler().fit_transform(X) def get_regular(n_clusters): # spiral pattern centers = [ [0, 0], [1, 0], [1, 1], [0, 1], [-1, 1], [-1, 0], [-1, -1], [0, -1], [1, -1], [2, -1], ][:n_clusters] assert len(centers) == n_clusters X, labels = make_blobs(n_samples=N_SAMPLES, centers=centers, cluster_std=0.25, random_state=SEED) return normalize(X), labels def get_circles(n_clusters): X, labels = make_circles(n_samples=N_SAMPLES, factor=0.5, noise=0.05, random_state=SEED) return normalize(X), labels def get_moons(n_clusters): X, labels = make_moons(n_samples=N_SAMPLES, noise=0.05, random_state=SEED) return normalize(X), labels def get_noise(n_clusters): X, labels = np.random.rand(N_SAMPLES, 2), np.zeros(N_SAMPLES) return normalize(X), labels def get_anisotropic(n_clusters): X, labels = make_blobs(n_samples=N_SAMPLES, centers=n_clusters, random_state=170) transformation = [[0.6, -0.6], [-0.4, 0.8]] X = np.dot(X, transformation) return X, labels def get_varied(n_clusters): cluster_std = [1.0, 2.5, 0.5, 1.0, 2.5, 0.5, 1.0, 2.5, 0.5, 1.0][:n_clusters] assert len(cluster_std) == n_clusters X, labels = make_blobs( n_samples=N_SAMPLES, centers=n_clusters, cluster_std=cluster_std, random_state=SEED ) return normalize(X), labels DATA_MAPPING = { 'regular': get_regular, 'circles': get_circles, 'moons': get_moons, 'noise': get_noise, 'anisotropic': get_anisotropic, 'varied': get_varied, } def get_kmeans(X, n_clusters, **kwargs): model = KMeans(init="k-means++", n_clusters=n_clusters, n_init=10, random_state=SEED) model.set_params(**kwargs) return model.fit(X) def get_dbscan(X, n_clusters, **kwargs): model = DBSCAN(eps=0.3) model.set_params(**kwargs) return model.fit(X) def get_agglomerative(X, n_clusters, **kwargs): connectivity = kneighbors_graph( X, n_neighbors=n_clusters, include_self=False ) # make connectivity symmetric connectivity = 0.5 * (connectivity + connectivity.T) model = AgglomerativeClustering( n_clusters=n_clusters, linkage="ward", connectivity=connectivity ) model.set_params(**kwargs) return model.fit(X) def get_meanshift(X, n_clusters, **kwargs): bandwidth = estimate_bandwidth(X, quantile=0.3) model = MeanShift(bandwidth=bandwidth, bin_seeding=True) model.set_params(**kwargs) return model.fit(X) def get_spectral(X, n_clusters, **kwargs): model = SpectralClustering( n_clusters=n_clusters, eigen_solver="arpack", affinity="nearest_neighbors", ) model.set_params(**kwargs) return model.fit(X) def get_optics(X, n_clusters, **kwargs): model = OPTICS( min_samples=7, xi=0.05, min_cluster_size=0.1, ) model.set_params(**kwargs) return model.fit(X) def get_birch(X, n_clusters, **kwargs): model = Birch(n_clusters=n_clusters) model.set_params(**kwargs) return model.fit(X) def get_gaussianmixture(X, n_clusters, **kwargs): model = GaussianMixture( n_components=n_clusters, covariance_type="full", random_state=SEED, ) 
model.set_params(**kwargs) return model.fit(X) MODEL_MAPPING = { 'KMeans': get_kmeans, 'DBSCAN': get_dbscan, 'AgglomerativeClustering': get_agglomerative, 'MeanShift': get_meanshift, 'SpectralClustering': get_spectral, 'OPTICS': get_optics, 'Birch': get_birch, 'GaussianMixture': get_gaussianmixture, } def plot_clusters(ax, X, labels): set_clusters = set(labels) set_clusters.discard(-1) # -1 signifiies outliers, which we plot separately for label in sorted(set_clusters): idx = labels == label if not sum(idx): continue ax.scatter(X[idx, 0], X[idx, 1]) # show outliers (if any) idx = labels == -1 if sum(idx): ax.scatter(X[idx, 0], X[idx, 1], c='k', marker='x') ax.grid(None) ax.set_xticks([]) ax.set_yticks([]) return ax def cluster(clustering_algorithm: str, dataset: str, n_clusters: int): n_clusters = int(n_clusters) X, labels = DATA_MAPPING[dataset](n_clusters) model = MODEL_MAPPING[clustering_algorithm](X, n_clusters=n_clusters) if hasattr(model, "labels_"): y_pred = model.labels_.astype(int) else: y_pred = model.predict(X) fig, axes = plt.subplots(1, 2, figsize=(16, 8)) # show true labels in first panel ax = axes[0] plot_clusters(ax, X, labels) ax.set_title("True clusters") # show learned clusters in second panel ax = axes[1] plot_clusters(ax, X, y_pred) ax.set_title(clustering_algorithm) return fig title = "Clustering with Scikit-learn" description = ( "This example shows how different clustering algorithms work. Simply pick " "the algorithm and the dataset to see how the clustering algorithms work." ) demo = gr.Interface( fn=cluster, inputs=[ gr.Radio( list(MODEL_MAPPING), value="KMeans", label="clustering algorithm" ), gr.Radio( list(DATA_MAPPING), value="regular", label="dataset" ), gr.Slider( minimum=1, maximum=MAX_CLUSTERS, value=4, step=1, ) ], title=title, description=description, outputs=gr.Plot(), ) demo.launch()
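
# A minimal sketch of how to sanity-check the pipeline without launching the
# UI: call cluster() directly with any keys from MODEL_MAPPING and
# DATA_MAPPING (the output filename below is just an example):
#
#   fig = cluster("KMeans", "moons", 2)
#   fig.savefig("clusters.png")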