haritsahm commited on
Commit
fb345ee
·
1 Parent(s): f0c912f

Add box selection feature

Browse files
Files changed (2) hide show
  1. app.py +91 -11
  2. utils/utils.py +23 -0
app.py CHANGED
@@ -13,6 +13,77 @@ from utils import utils
13
  SAM_MODEL = utils.get_model('vit_b')
14
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  def click_process(model, show_mask, radius_width):
17
 
18
  bg_image = st.session_state['image']
@@ -24,7 +95,7 @@ def click_process(model, show_mask, radius_width):
24
  if 'result_image' not in st.session_state:
25
  st.session_state.result_image = bg_image.resize(scaled_hw)
26
 
27
- canvas_result = st_canvas(
28
  fill_color="rgba(255, 255, 0, 0.8)",
29
  background_image = bg_image,
30
  drawing_mode='point',
@@ -35,13 +106,21 @@ def click_process(model, show_mask, radius_width):
35
  update_streamlit=True,
36
  key="point",)
37
 
38
- # ! Warn: Can cause infinite loop or high cpu usage
39
  if not show_mask:
40
- print("rerun no mask")
41
- st.experimental_rerun()
 
 
 
 
 
 
 
 
 
42
 
43
- elif canvas_result.json_data is not None:
44
- df = pd.json_normalize(canvas_result.json_data["objects"])
45
  input_points = []
46
  input_labels = []
47
 
@@ -88,6 +167,8 @@ def image_preprocess_callback(model):
88
  st.session_state.image = image
89
  else:
90
  with st.spinner(text="Cleaning up!"):
 
 
91
  if 'image' in st.session_state:
92
  st.session_state.image = None
93
  if 'result_image' in st.session_state:
@@ -130,16 +211,15 @@ def main():
130
  if option == 'Click':
131
  with st.spinner(text="Computing masks"):
132
  result_image = click_process(SAM_MODEL, show_mask, radius_width)
 
 
133
  with canvas_output:
134
  if result_image is not None:
135
  st.write("Result")
136
  st.image(result_image)
137
 
138
- # else:
139
- # print(f'embedding is empty - {option} - {show_mask} - {radius_width}')
140
- # if 'image' in st.session_state:
141
- # if st.session_state.image is None:
142
- # st.session_state.clear()
143
 
144
 
145
  if __name__ == '__main__':
 
13
  SAM_MODEL = utils.get_model('vit_b')
14
 
15
 
16
+ def box_process(model, show_mask, radius_width):
17
+ bg_image = st.session_state['image']
18
+ width, height = bg_image.size[:2]
19
+ container_width = 700
20
+ scale = container_width/width
21
+ scaled_hw = (container_width, int(height * scale))
22
+
23
+ if 'result_image' not in st.session_state:
24
+ st.session_state.result_image = bg_image.resize(scaled_hw)
25
+
26
+ box_canvas = st_canvas(
27
+ fill_color="rgba(255, 255, 0, 0)",
28
+ background_image = bg_image,
29
+ drawing_mode='rect',
30
+ stroke_color = "rgba(0, 255, 0, 0.6)",
31
+ stroke_width = radius_width,
32
+ width = container_width,
33
+ height = height * scale,
34
+ point_display_radius = 12,
35
+ update_streamlit=True,
36
+ key="box"
37
+ )
38
+
39
+ if not show_mask:
40
+ if 'rerun_once' in st.session_state:
41
+ if st.session_state.rerun_once:
42
+ st.session_state.rerun_once = False
43
+ else:
44
+ st.session_state.rerun_once = True
45
+
46
+ st.session_state.display_result = True
47
+ if st.session_state.rerun_once:
48
+ st.experimental_rerun()
49
+ else:
50
+ return np.asarray(bg_image)
51
+
52
+ elif box_canvas.json_data is not None:
53
+ df = pd.json_normalize(box_canvas.json_data["objects"])
54
+ center_point,center_label,input_box = [],[],[]
55
+ center_point, center_label, input_box = [], [], []
56
+ for _, row in df.iterrows():
57
+ x, y, w,h = row["left"], row["top"], row["width"], row["height"]
58
+ x = int(x/scale)
59
+ y = int(y/scale)
60
+ w = int(w/scale)
61
+ h = int(h/scale)
62
+ center_point.append([x+w/2,y+h/2])
63
+ center_label.append([1])
64
+ input_box.append([x,y,x+w,y+h])
65
+
66
+ masks = []
67
+ if model:
68
+ masks = utils.model_predict_masks_box(model, center_point, center_label, input_box)
69
+
70
+ if len(masks) == 0:
71
+ return bg_image
72
+
73
+ bg_image = np.asarray(bg_image)
74
+ color = np.concatenate([random.choice(utils.get_color()), np.array([0.6])], axis=0)
75
+ im_masked = utils.show_click(masks,color)
76
+ im_masked = Image.fromarray(im_masked).convert('RGBA')
77
+ result_image = Image.alpha_composite(Image.fromarray(bg_image).convert('RGBA'),im_masked).convert("RGB")
78
+ result_image = result_image.resize(scaled_hw)
79
+ st.session_state.display_result = True
80
+ return result_image
81
+ else:
82
+ return np.asarray(bg_image)
83
+
84
+ return np.asarray(bg_image)
85
+
86
+
87
  def click_process(model, show_mask, radius_width):
88
 
89
  bg_image = st.session_state['image']
 
95
  if 'result_image' not in st.session_state:
96
  st.session_state.result_image = bg_image.resize(scaled_hw)
97
 
98
+ click_canvas = st_canvas(
99
  fill_color="rgba(255, 255, 0, 0.8)",
100
  background_image = bg_image,
101
  drawing_mode='point',
 
106
  update_streamlit=True,
107
  key="point",)
108
 
 
109
  if not show_mask:
110
+ if 'rerun_once' in st.session_state:
111
+ if st.session_state.rerun_once:
112
+ st.session_state.rerun_once = False
113
+ else:
114
+ st.session_state.rerun_once = True
115
+
116
+ st.session_state.display_result = True
117
+ if st.session_state.rerun_once:
118
+ st.experimental_rerun()
119
+ else:
120
+ return np.asarray(bg_image)
121
 
122
+ elif click_canvas.json_data is not None:
123
+ df = pd.json_normalize(click_canvas.json_data["objects"])
124
  input_points = []
125
  input_labels = []
126
 
 
167
  st.session_state.image = image
168
  else:
169
  with st.spinner(text="Cleaning up!"):
170
+ if 'display_result' in st.session_state:
171
+ st.session_state.display_result = False
172
  if 'image' in st.session_state:
173
  st.session_state.image = None
174
  if 'result_image' in st.session_state:
 
211
  if option == 'Click':
212
  with st.spinner(text="Computing masks"):
213
  result_image = click_process(SAM_MODEL, show_mask, radius_width)
214
+ elif option == 'Box':
215
+ result_image = box_process(SAM_MODEL, show_mask, radius_width)
216
  with canvas_output:
217
  if result_image is not None:
218
  st.write("Result")
219
  st.image(result_image)
220
 
221
+ else:
222
+ st.cache_data.clear()
 
 
 
223
 
224
 
225
  if __name__ == '__main__':
utils/utils.py CHANGED
@@ -70,3 +70,26 @@ def model_predict_masks_click(model,input_points,input_labels):
70
  torch.cuda.empty_cache()
71
 
72
  return masks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  torch.cuda.empty_cache()
71
 
72
  return masks
73
+
74
+ def model_predict_masks_box(model,center_point,center_label,input_box):
75
+ masks = np.array([])
76
+ for i in range(len(center_label)):
77
+ if center_point[i] == []:continue
78
+ center_point_1 = np.array([center_point[i]])
79
+ center_label_1 = np.array(center_label[i])
80
+ input_box_1 = np.array(input_box[i])
81
+ mask, _, _ = model.predict(
82
+ point_coords=center_point_1,
83
+ point_labels=center_label_1,
84
+ box=input_box_1,
85
+ multimask_output=False,
86
+ )
87
+ try:
88
+ masks = masks + mask
89
+ except:
90
+ masks = mask
91
+
92
+ if torch.cuda.is_available():
93
+ torch.cuda.empty_cache()
94
+
95
+ return masks