Johannes
commited on
Commit
·
f589fdc
1
Parent(s):
2f3a29d
add working code
Browse files- README.md +2 -2
- app.py +73 -0
- requirements.txt +3 -0
README.md
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
colorFrom: blue
|
5 |
colorTo: red
|
6 |
sdk: gradio
|
|
|
1 |
---
|
2 |
+
title: sklearn Spectral Clustering
|
3 |
+
emoji: 🔴🔵🔴
|
4 |
colorFrom: blue
|
5 |
colorTo: red
|
6 |
sdk: gradio
|
app.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from sklearn.feature_extraction import image
|
3 |
+
from sklearn.cluster import spectral_clustering
|
4 |
+
import matplotlib
|
5 |
+
matplotlib.use('Agg')
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import gradio as gr
|
8 |
+
from scipy.cluster.vq import kmeans
|
9 |
+
|
10 |
+
|
11 |
+
def get_coordinates_from_mask(mask_in, number_of_centroids):
|
12 |
+
x_y = np.where(mask_in != [255, 255, 255])[:2]
|
13 |
+
x_y = np.column_stack((x_y[0], x_y[1]))
|
14 |
+
x_y = np.float32(x_y)
|
15 |
+
centroids,_ = kmeans(x_y,number_of_centroids)
|
16 |
+
centroids = np.int64(centroids)
|
17 |
+
|
18 |
+
return centroids
|
19 |
+
|
20 |
+
|
21 |
+
def infer(input_image: np.ndarray, number_of_circles: int, radius: int):
|
22 |
+
centroids = get_coordinates_from_mask(input_image, number_of_circles)
|
23 |
+
|
24 |
+
img = np.zeros((input_image.shape[1], input_image.shape[0]))
|
25 |
+
|
26 |
+
x, y = np.indices((input_image.shape[1], input_image.shape[0]))
|
27 |
+
|
28 |
+
for centroid in centroids:
|
29 |
+
circle = (x - centroid[0]) ** 2 + (y - centroid[1]) ** 2 < radius**2
|
30 |
+
img += circle
|
31 |
+
|
32 |
+
mask = img.astype(bool)
|
33 |
+
|
34 |
+
img = img.astype(float)
|
35 |
+
img += 1 + 0.2 * np.random.randn(*img.shape)
|
36 |
+
|
37 |
+
|
38 |
+
graph = image.img_to_graph(img, mask=mask)
|
39 |
+
graph.data = np.exp(-graph.data / graph.data.std())
|
40 |
+
|
41 |
+
labels = spectral_clustering(graph, n_clusters=len(centroids), eigen_solver="arpack")
|
42 |
+
label_im = np.full(mask.shape, -1.0)
|
43 |
+
label_im[mask] = labels
|
44 |
+
|
45 |
+
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
|
46 |
+
axs[0].matshow(img)
|
47 |
+
axs[1].matshow(label_im)
|
48 |
+
|
49 |
+
return fig
|
50 |
+
|
51 |
+
|
52 |
+
article = """<center>
|
53 |
+
Demo by <a href='https://huggingface.co/johko' target='_blank'>Johannes (johko) Kolbe</a>"""
|
54 |
+
|
55 |
+
|
56 |
+
description = """<p style="text-align: center;">This is an interactive demo for the <a href="https://scikit-learn.org/stable/auto_examples/cluster/plot_segmentation_toy.html#sphx-glr-auto-examples-cluster-plot-segmentation-toy-py">Spectral clustering for image segmentation tutorial</a> from scikit-learn.
|
57 |
+
</br></br>The demo lets you mark places in the input where the centers of circles should be. The circles should then be segmented from one another using Spectral Image Clustering.
|
58 |
+
</br>The circles should ideally be close together(connected), to let the algorithm work correctly.
|
59 |
+
</br>As the demo uses k-means to determine the centroids of the circles exactly, you also need to specify the number of circles you want to get.
|
60 |
+
</br></br><b>What is Spectral Image clustering?</b> From the tutorial:
|
61 |
+
</br><i>"The Spectral clustering approach solves the problem know as ‘normalized graph cuts’: the image is seen as a graph of connected voxels, and the spectral clustering algorithm amounts to choosing graph cuts defining regions while minimizing the ratio of the gradient along the cut, and the volume of the region."</i> .</p>"""
|
62 |
+
|
63 |
+
|
64 |
+
gr.Interface(
|
65 |
+
title="Spectral Clustering with scikit-learn",
|
66 |
+
description=description,
|
67 |
+
article=artsicle,
|
68 |
+
fn=infer,
|
69 |
+
inputs=[gr.Image(source="canvas", tool="sketch", label="Input Image", shape=[100, 100]),
|
70 |
+
gr.Number(label="Number of circles to draw", value=4, precision=0),
|
71 |
+
gr.Slider(label="Circle Radius", minimum=5, maximum=25, value=15, step=1)],
|
72 |
+
outputs=[gr.Plot(label="Original Image Histogram")]
|
73 |
+
).launch()
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
matplotlib==3.6.3
|
2 |
+
scikit-learn==1.2.1
|
3 |
+
scipy
|