haritsahm committed on
Commit
f665217
·
1 Parent(s): fb345ee

Add new features to segment everything

Browse files
Files changed (2) hide show
  1. app.py +86 -24
  2. utils/utils.py +17 -8
app.py CHANGED
@@ -10,10 +10,10 @@ from PIL import Image
10
  from streamlit_drawable_canvas import st_canvas
11
  from utils import utils
12
 
13
- SAM_MODEL = utils.get_model('vit_b')
14
 
15
 
16
- def box_process(model, show_mask, radius_width):
17
  bg_image = st.session_state['image']
18
  width, height = bg_image.size[:2]
19
  container_width = 700
@@ -44,6 +44,7 @@ def box_process(model, show_mask, radius_width):
44
  st.session_state.rerun_once = True
45
 
46
  st.session_state.display_result = True
 
47
  if st.session_state.rerun_once:
48
  st.experimental_rerun()
49
  else:
@@ -64,11 +65,12 @@ def box_process(model, show_mask, radius_width):
64
  input_box.append([x,y,x+w,y+h])
65
 
66
  masks = []
67
- if model:
68
- masks = utils.model_predict_masks_box(model, center_point, center_label, input_box)
69
 
70
  if len(masks) == 0:
71
- return bg_image
 
72
 
73
  bg_image = np.asarray(bg_image)
74
  color = np.concatenate([random.choice(utils.get_color()), np.array([0.6])], axis=0)
@@ -84,7 +86,7 @@ def box_process(model, show_mask, radius_width):
84
  return np.asarray(bg_image)
85
 
86
 
87
- def click_process(model, show_mask, radius_width):
88
 
89
  bg_image = st.session_state['image']
90
  width, height = bg_image.size[:2]
@@ -114,6 +116,7 @@ def click_process(model, show_mask, radius_width):
114
  st.session_state.rerun_once = True
115
 
116
  st.session_state.display_result = True
 
117
  if st.session_state.rerun_once:
118
  st.experimental_rerun()
119
  else:
@@ -135,11 +138,12 @@ def click_process(model, show_mask, radius_width):
135
  input_labels.append(0)
136
 
137
  masks = []
138
- if model:
139
- masks = utils.model_predict_masks_click(model, input_points, input_labels)
140
 
141
  if len(masks) == 0:
142
- return bg_image
 
143
 
144
  bg_image = np.asarray(bg_image)
145
  color = np.concatenate([random.choice(utils.get_color()), np.array([0.6])], axis=0)
@@ -147,6 +151,7 @@ def click_process(model, show_mask, radius_width):
147
  im_masked = Image.fromarray(im_masked).convert('RGBA')
148
  result_image = Image.alpha_composite(Image.fromarray(bg_image).convert('RGBA'),im_masked).convert("RGB")
149
  result_image = result_image.resize(scaled_hw)
 
150
  return result_image
151
  else:
152
  return np.asarray(bg_image)
@@ -154,16 +159,65 @@ def click_process(model, show_mask, radius_width):
154
  return np.asarray(bg_image)
155
 
156
 
157
- def image_preprocess_callback(model):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  if 'uploaded_image' not in st.session_state:
159
  return
160
  if st.session_state.uploaded_image is not None:
161
  with st.spinner(text="Uploading image..."):
162
  image = Image.open(st.session_state.uploaded_image).convert("RGB")
163
- if model:
164
  np_image = np.asanyarray(image)
165
  with st.spinner(text="Extracing embeddings.."):
166
- model.set_image(np_image)
167
  st.session_state.image = image
168
  else:
169
  with st.spinner(text="Cleaning up!"):
@@ -173,8 +227,8 @@ def image_preprocess_callback(model):
173
  st.session_state.image = None
174
  if 'result_image' in st.session_state:
175
  del st.session_state['result_image']
176
- if model:
177
- model.reset_image()
178
 
179
  def main():
180
  with open('index.html', encoding='utf-8') as f:
@@ -202,21 +256,29 @@ def main():
202
  st.write("Upload Image")
203
  st.file_uploader(label='Upload image',type=['png','jpg','tif'], key='uploaded_image', on_change=image_preprocess_callback, args=(SAM_MODEL,), label_visibility="hidden")
204
 
 
205
  canvas_input, canvas_output = st.columns(2)
 
206
  if 'image' in st.session_state:
207
- result_image = None
208
  with canvas_input:
209
  st.write("Select Interest Area/Objects")
210
  if st.session_state.image is not None:
211
- if option == 'Click':
212
- with st.spinner(text="Computing masks"):
213
- result_image = click_process(SAM_MODEL, show_mask, radius_width)
214
- elif option == 'Box':
215
- result_image = box_process(SAM_MODEL, show_mask, radius_width)
216
- with canvas_output:
217
- if result_image is not None:
218
- st.write("Result")
219
- st.image(result_image)
 
 
 
 
 
 
 
220
 
221
  else:
222
  st.cache_data.clear()
 
10
  from streamlit_drawable_canvas import st_canvas
11
  from utils import utils
12
 
13
+ PREDICTOR_MODEL, AUTOMASK_MODEL = utils.get_model('vit_b')
14
 
15
 
16
+ def process_box(predictor_model, show_mask, radius_width):
17
  bg_image = st.session_state['image']
18
  width, height = bg_image.size[:2]
19
  container_width = 700
 
44
  st.session_state.rerun_once = True
45
 
46
  st.session_state.display_result = True
47
+ st.warning("Mask view is disabled", icon="❗")
48
  if st.session_state.rerun_once:
49
  st.experimental_rerun()
50
  else:
 
65
  input_box.append([x,y,x+w,y+h])
66
 
67
  masks = []
68
+ if predictor_model:
69
+ masks = utils.model_predict_masks_box(predictor_model, center_point, center_label, input_box)
70
 
71
  if len(masks) == 0:
72
+ st.warning("No Masks Found", icon="❗")
73
+ return np.asarray(bg_image)
74
 
75
  bg_image = np.asarray(bg_image)
76
  color = np.concatenate([random.choice(utils.get_color()), np.array([0.6])], axis=0)
 
86
  return np.asarray(bg_image)
87
 
88
 
89
+ def process_click(predictor_model, show_mask, radius_width):
90
 
91
  bg_image = st.session_state['image']
92
  width, height = bg_image.size[:2]
 
116
  st.session_state.rerun_once = True
117
 
118
  st.session_state.display_result = True
119
+ st.warning("Mask view is disabled", icon="❗")
120
  if st.session_state.rerun_once:
121
  st.experimental_rerun()
122
  else:
 
138
  input_labels.append(0)
139
 
140
  masks = []
141
+ if predictor_model:
142
+ masks = utils.model_predict_masks_click(predictor_model, input_points, input_labels)
143
 
144
  if len(masks) == 0:
145
+ st.warning("No Masks Found", icon="❗")
146
+ return np.asarray(bg_image)
147
 
148
  bg_image = np.asarray(bg_image)
149
  color = np.concatenate([random.choice(utils.get_color()), np.array([0.6])], axis=0)
 
151
  im_masked = Image.fromarray(im_masked).convert('RGBA')
152
  result_image = Image.alpha_composite(Image.fromarray(bg_image).convert('RGBA'),im_masked).convert("RGB")
153
  result_image = result_image.resize(scaled_hw)
154
+ st.session_state.display_result = True
155
  return result_image
156
  else:
157
  return np.asarray(bg_image)
 
159
  return np.asarray(bg_image)
160
 
161
 
162
+ def process_everything(automask_model, show_mask, radius_width):
163
+ bg_image = st.session_state['image']
164
+ width, height = bg_image.size[:2]
165
+ container_width = 700
166
+ scale = container_width/width
167
+ scaled_hw = (container_width, int(height * scale))
168
+
169
+ if 'result_image' not in st.session_state:
170
+ st.session_state.result_image = bg_image.resize(scaled_hw)
171
+
172
+ dummy_canvas = st_canvas(
173
+ fill_color="rgba(255, 255, 0, 0.8)",
174
+ background_image = bg_image,
175
+ drawing_mode='freedraw',
176
+ width = container_width,
177
+ height = height * scale,
178
+ point_display_radius = radius_width,
179
+ stroke_width=2,
180
+ update_streamlit=False,
181
+ key="everything",)
182
+
183
+ if not show_mask:
184
+ if 'rerun_once' in st.session_state:
185
+ if st.session_state.rerun_once:
186
+ st.session_state.rerun_once = False
187
+ else:
188
+ st.session_state.rerun_once = True
189
+
190
+ st.session_state.display_result = True
191
+ st.warning("Mask view is disabled", icon="❗")
192
+ if st.session_state.rerun_once:
193
+ st.experimental_rerun()
194
+ else:
195
+ return np.asarray(bg_image)
196
+
197
+ if automask_model:
198
+ bg_image = np.asarray(bg_image)
199
+ masks = utils.model_predict_masks_everything(automask_model, bg_image)
200
+ im_masked = utils.show_everything(masks)
201
+ im_masked = Image.fromarray(im_masked).convert('RGBA')
202
+ result_image = Image.alpha_composite(Image.fromarray(bg_image).convert('RGBA'),im_masked).convert("RGB")
203
+ result_image = result_image.resize(scaled_hw)
204
+ st.session_state.display_result = True
205
+ return result_image
206
+
207
+ else:
208
+ return np.asarray(bg_image)
209
+
210
+
211
+ def image_preprocess_callback(predictor_model, option):
212
  if 'uploaded_image' not in st.session_state:
213
  return
214
  if st.session_state.uploaded_image is not None:
215
  with st.spinner(text="Uploading image..."):
216
  image = Image.open(st.session_state.uploaded_image).convert("RGB")
217
+ if predictor_model and option != 'Everything':
218
  np_image = np.asanyarray(image)
219
  with st.spinner(text="Extracing embeddings.."):
220
+ predictor_model.set_image(np_image)
221
  st.session_state.image = image
222
  else:
223
  with st.spinner(text="Cleaning up!"):
 
227
  st.session_state.image = None
228
  if 'result_image' in st.session_state:
229
  del st.session_state['result_image']
230
+ if predictor_model:
231
+ predictor_model.reset_image()
232
 
233
  def main():
234
  with open('index.html', encoding='utf-8') as f:
 
256
  st.write("Upload Image")
257
  st.file_uploader(label='Upload image',type=['png','jpg','tif'], key='uploaded_image', on_change=image_preprocess_callback, args=(SAM_MODEL,), label_visibility="hidden")
258
 
259
+ result_image = None
260
  canvas_input, canvas_output = st.columns(2)
261
+
262
  if 'image' in st.session_state:
 
263
  with canvas_input:
264
  st.write("Select Interest Area/Objects")
265
  if st.session_state.image is not None:
266
+ with st.spinner(text="Computing masks"):
267
+ if option == 'Click':
268
+ result_image = process_click(PREDICTOR_MODEL, show_mask, radius_width)
269
+ elif option == 'Box':
270
+ result_image = process_box(PREDICTOR_MODEL, show_mask, radius_width)
271
+ else:
272
+ result_image = process_everything(AUTOMASK_MODEL, show_mask, radius_width)
273
+ if 'display_result' in st.session_state:
274
+ if st.session_state.display_result:
275
+ with canvas_output:
276
+ if result_image is not None:
277
+ st.write("Result")
278
+ st.image(result_image)
279
+ else:
280
+ st.warning("No result found, please set input prompt", icon="⚠️")
281
+ st.success('Process completed!', icon="✅")
282
 
283
  else:
284
  st.cache_data.clear()
utils/utils.py CHANGED
@@ -1,5 +1,5 @@
1
 
2
- from segment_anything import SamPredictor, sam_model_registry
3
  import torch
4
  import numpy as np
5
  from distinctipy import distinctipy
@@ -19,20 +19,19 @@ def get_model(model):
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
  build_sam = sam_model_registry[model]
21
  model = build_sam(checkpoint=get_checkpoint_path(model)).to(device)
22
- predictor = SamPredictor(model)
23
  if torch.cuda.is_available():
24
  torch.cuda.empty_cache()
25
- return predictor
 
 
26
 
27
 
28
- @st.cache_data
29
  def show_everything(sorted_anns):
30
- if len(sorted_anns) == 0:
31
- return
32
  #sorted_anns = sorted(anns, key=(lambda x: x['stability_score']), reverse=True)
33
  h, w = sorted_anns[0]['segmentation'].shape[-2:]
34
  #sorted_anns = sorted_anns[:int(len(sorted_anns) * stability_score/100)]
35
  if sorted_anns == []:
 
36
  return np.zeros((h,w,4)).astype(np.uint8)
37
  mask = np.zeros((h,w,4))
38
  for ann in sorted_anns:
@@ -40,13 +39,13 @@ def show_everything(sorted_anns):
40
  color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
41
  mask += m.reshape(h,w,1) * color.reshape(1, 1, -1)
42
  mask = mask * 255
43
- st.success('Process completed!', icon="✅")
44
  return mask.astype(np.uint8)
45
 
46
 
47
  def show_click(masks, colors):
48
  h, w = masks[0].shape[-2:]
49
  masks_total = np.zeros((h,w,4)).astype(np.uint8)
 
50
  for mask, color in zip(masks, colors):
51
  if np.array_equal(mask,np.array([])):continue
52
  masks = np.zeros((h,w,4)).astype(np.uint8)
@@ -54,7 +53,7 @@ def show_click(masks, colors):
54
  masks = masks.astype(bool).astype(np.uint8)
55
  masks = masks * 255 * color.reshape(1, 1, -1)
56
  masks_total += masks.astype(np.uint8)
57
- st.success('Process completed!', icon="✅")
58
  return masks_total
59
 
60
  def model_predict_masks_click(model,input_points,input_labels):
@@ -66,6 +65,7 @@ def model_predict_masks_click(model,input_points,input_labels):
66
  point_labels=input_labels,
67
  multimask_output=False,
68
  )
 
69
  if torch.cuda.is_available():
70
  torch.cuda.empty_cache()
71
 
@@ -93,3 +93,12 @@ def model_predict_masks_box(model,center_point,center_label,input_box):
93
  torch.cuda.empty_cache()
94
 
95
  return masks
 
 
 
 
 
 
 
 
 
 
1
 
2
+ from segment_anything import SamPredictor, SamAutomaticMaskGenerator, sam_model_registry
3
  import torch
4
  import numpy as np
5
  from distinctipy import distinctipy
 
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
  build_sam = sam_model_registry[model]
21
  model = build_sam(checkpoint=get_checkpoint_path(model)).to(device)
 
22
  if torch.cuda.is_available():
23
  torch.cuda.empty_cache()
24
+ predictor = SamPredictor(model)
25
+ mask_generator = SamAutomaticMaskGenerator(model)
26
+ return predictor, mask_generator
27
 
28
 
 
29
  def show_everything(sorted_anns):
 
 
30
  #sorted_anns = sorted(anns, key=(lambda x: x['stability_score']), reverse=True)
31
  h, w = sorted_anns[0]['segmentation'].shape[-2:]
32
  #sorted_anns = sorted_anns[:int(len(sorted_anns) * stability_score/100)]
33
  if sorted_anns == []:
34
+ st.warning("No Masks Found", icon="❗")
35
  return np.zeros((h,w,4)).astype(np.uint8)
36
  mask = np.zeros((h,w,4))
37
  for ann in sorted_anns:
 
39
  color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
40
  mask += m.reshape(h,w,1) * color.reshape(1, 1, -1)
41
  mask = mask * 255
 
42
  return mask.astype(np.uint8)
43
 
44
 
45
  def show_click(masks, colors):
46
  h, w = masks[0].shape[-2:]
47
  masks_total = np.zeros((h,w,4)).astype(np.uint8)
48
+
49
  for mask, color in zip(masks, colors):
50
  if np.array_equal(mask,np.array([])):continue
51
  masks = np.zeros((h,w,4)).astype(np.uint8)
 
53
  masks = masks.astype(bool).astype(np.uint8)
54
  masks = masks * 255 * color.reshape(1, 1, -1)
55
  masks_total += masks.astype(np.uint8)
56
+
57
  return masks_total
58
 
59
  def model_predict_masks_click(model,input_points,input_labels):
 
65
  point_labels=input_labels,
66
  multimask_output=False,
67
  )
68
+
69
  if torch.cuda.is_available():
70
  torch.cuda.empty_cache()
71
 
 
93
  torch.cuda.empty_cache()
94
 
95
  return masks
96
+
97
+
98
+ def model_predict_masks_everything(mask_generator, image):
99
+ masks = mask_generator.generate(image)
100
+
101
+ if torch.cuda.is_available():
102
+ torch.cuda.empty_cache()
103
+
104
+ return masks