shawnljw commited on
Commit
52983bc
·
1 Parent(s): 05ccb0c

create web ui

Browse files
Files changed (1) hide show
  1. app.py +98 -0
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import cv2
4
+ from tqdm import trange
5
+
6
+ class KMeansClustering():
7
+ def __init__(self, n_clusters=8, max_iter=300):
8
+ self.n_clusters = n_clusters
9
+ self.max_iter = max_iter
10
+
11
+ def fit(self, X):
12
+ self.inertia_ = float('inf')
13
+
14
+ # random init of clusters
15
+ idx = np.random.choice(range(X.shape[0]), self.n_clusters, replace=False)
16
+ self.cluster_centers_ = X[idx]
17
+
18
+ print(f'Training for {self.max_iter} epochs')
19
+ epochs = trange(self.max_iter)
20
+ for i in epochs:
21
+ distances = X[:, np.newaxis, :] - self.cluster_centers_[np.newaxis, :, :]
22
+ distances = np.linalg.norm(distances, axis=2)
23
+ self.labels_ = np.argmin(distances, axis=1)
24
+ new_inertia = np.sum(np.min(distances, axis=1) ** 2)
25
+
26
+ epochs.set_description(f'Epoch-{i+1}, Inertia-{new_inertia}')
27
+
28
+ if new_inertia < self.inertia_:
29
+ self.inertia_ = new_inertia
30
+ else:
31
+ epochs.close()
32
+ print('Early Stopping. Inertia has converged.')
33
+ break
34
+
35
+ self.cluster_centers_ = np.empty_like(self.cluster_centers_)
36
+ for cluster in range(self.n_clusters):
37
+ in_cluster = (self.labels_ == cluster)
38
+ if np.any(in_cluster):
39
+ self.cluster_centers_[cluster] = np.mean(X[in_cluster], axis=0)
40
+ else:
41
+ # cluster is empty, pick random point as next centroid
42
+ self.cluster_centers_[cluster] = X[np.random.randint(0, X.shape[0])]
43
+
44
+ return self
45
+
46
+ def predict(self, X):
47
+ distances = X[:, np.newaxis, :] - self.cluster_centers_[np.newaxis, :, :]
48
+ distances = np.linalg.norm(distances, axis=2)
49
+ labels = np.argmin(distances, axis=1)
50
+ return labels
51
+
52
+ def fit_predict(self, X):
53
+ return self.fit(X).labels_
54
+
55
+ def segment_image(image, model: KMeansClustering):
56
+ w, b, c = image.shape
57
+ image = image.reshape(w*b, c) / 255
58
+
59
+ idx = np.random.choice(range(image.shape[0]), image.shape[0]//5, replace=False)
60
+ image_subset = image[idx]
61
+ model.fit(image_subset) # fit model on 20% sample of image
62
+
63
+ labels = model.predict(image)
64
+ return labels.reshape(w,b), model
65
+
66
+ def generate_outputs(image):
67
+ model = KMeansClustering(n_clusters=24, max_iter=10)
68
+ label_map, model = segment_image(image, model)
69
+
70
+ clustered_image = model.cluster_centers_[label_map]
71
+ clustered_image = (clustered_image * 255).astype('uint8')
72
+ clustered_image = cv2.medianBlur(clustered_image,5)
73
+ edges = 255 - cv2.Canny(clustered_image, 0, 1)
74
+ edges = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
75
+
76
+ return [(edges, 'Coloring Page'), (clustered_image, 'Filled Picture')]
77
+
78
+ with gr.Blocks() as demo:
79
+ gr.Markdown(
80
+ """
81
+ # image2coloringbook
82
+
83
+ (image2coloringbook)[https://github.com/ShawnLJW/image2coloringbook] is a simple tool that converts an image into a coloring book.
84
+ """)
85
+ with gr.Row():
86
+ with gr.Column():
87
+ image = gr.Image()
88
+ submit = gr.Button('Generate')
89
+ with gr.Column():
90
+ output = gr.Gallery()
91
+ submit.click(
92
+ generate_outputs,
93
+ inputs=[image],
94
+ outputs=[output]
95
+ )
96
+
97
+ if __name__ == '__main__':
98
+ demo.launch()