shawnljw commited on
Commit
5a41ea2
·
1 Parent(s): 2db959e

add settings to app

Browse files
Files changed (1) hide show
  1. app.py +22 -4
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  import numpy as np
3
  import cv2
4
  from tqdm import trange
 
5
 
6
  class KMeansClustering():
7
  def __init__(self, n_clusters=8, max_iter=300):
@@ -63,8 +64,11 @@ def segment_image(image, model: KMeansClustering):
63
  labels = model.predict(image)
64
  return labels.reshape(w,b), model
65
 
66
- def generate_outputs(image):
67
- model = KMeansClustering(n_clusters=24, max_iter=10)
 
 
 
68
  label_map, model = segment_image(image, model)
69
 
70
  clustered_image = model.cluster_centers_[label_map]
@@ -87,10 +91,24 @@ with gr.Blocks() as demo:
87
  image = gr.Image()
88
  submit = gr.Button('Generate')
89
  with gr.Column():
90
- output = gr.Gallery()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  submit.click(
92
  generate_outputs,
93
- inputs=[image],
94
  outputs=[output]
95
  )
96
 
 
2
  import numpy as np
3
  import cv2
4
  from tqdm import trange
5
+ from sklearn.cluster import KMeans
6
 
7
  class KMeansClustering():
8
  def __init__(self, n_clusters=8, max_iter=300):
 
64
  labels = model.predict(image)
65
  return labels.reshape(w,b), model
66
 
67
+ def generate_outputs(image, implementation, num_colours):
68
+ if implementation == 'custom':
69
+ model = KMeansClustering(n_clusters=num_colours, max_iter=10)
70
+ elif implementation == 'sk-learn':
71
+ model = KMeans(n_clusters=num_colours, n_init='auto')
72
  label_map, model = segment_image(image, model)
73
 
74
  clustered_image = model.cluster_centers_[label_map]
 
91
  image = gr.Image()
92
  submit = gr.Button('Generate')
93
  with gr.Column():
94
+ num_colours = gr.Slider(
95
+ minimum=1,
96
+ maximum=40,
97
+ value=24,
98
+ step=1,
99
+ label='Number of colours'
100
+ )
101
+ implementation = gr.Dropdown(
102
+ choices=['sk-learn','custom'],
103
+ value='sk-learn',
104
+ label='Implementation'
105
+ )
106
+ with gr.Row():
107
+ output = gr.Gallery(preview=True)
108
+
109
  submit.click(
110
  generate_outputs,
111
+ inputs=[image, implementation, num_colours],
112
  outputs=[output]
113
  )
114