haritsahm
committed on
Commit
·
f665217
1
Parent(s):
fb345ee
Add new features to segment everything
Browse files
- app.py +86 -24
- utils/utils.py +17 -8
app.py
CHANGED
@@ -10,10 +10,10 @@ from PIL import Image
|
|
10 |
from streamlit_drawable_canvas import st_canvas
|
11 |
from utils import utils
|
12 |
|
13 |
-
|
14 |
|
15 |
|
16 |
-
def
|
17 |
bg_image = st.session_state['image']
|
18 |
width, height = bg_image.size[:2]
|
19 |
container_width = 700
|
@@ -44,6 +44,7 @@ def box_process(model, show_mask, radius_width):
|
|
44 |
st.session_state.rerun_once = True
|
45 |
|
46 |
st.session_state.display_result = True
|
|
|
47 |
if st.session_state.rerun_once:
|
48 |
st.experimental_rerun()
|
49 |
else:
|
@@ -64,11 +65,12 @@ def box_process(model, show_mask, radius_width):
|
|
64 |
input_box.append([x,y,x+w,y+h])
|
65 |
|
66 |
masks = []
|
67 |
-
if
|
68 |
-
masks = utils.model_predict_masks_box(
|
69 |
|
70 |
if len(masks) == 0:
|
71 |
-
|
|
|
72 |
|
73 |
bg_image = np.asarray(bg_image)
|
74 |
color = np.concatenate([random.choice(utils.get_color()), np.array([0.6])], axis=0)
|
@@ -84,7 +86,7 @@ def box_process(model, show_mask, radius_width):
|
|
84 |
return np.asarray(bg_image)
|
85 |
|
86 |
|
87 |
-
def
|
88 |
|
89 |
bg_image = st.session_state['image']
|
90 |
width, height = bg_image.size[:2]
|
@@ -114,6 +116,7 @@ def click_process(model, show_mask, radius_width):
|
|
114 |
st.session_state.rerun_once = True
|
115 |
|
116 |
st.session_state.display_result = True
|
|
|
117 |
if st.session_state.rerun_once:
|
118 |
st.experimental_rerun()
|
119 |
else:
|
@@ -135,11 +138,12 @@ def click_process(model, show_mask, radius_width):
|
|
135 |
input_labels.append(0)
|
136 |
|
137 |
masks = []
|
138 |
-
if
|
139 |
-
masks = utils.model_predict_masks_click(
|
140 |
|
141 |
if len(masks) == 0:
|
142 |
-
|
|
|
143 |
|
144 |
bg_image = np.asarray(bg_image)
|
145 |
color = np.concatenate([random.choice(utils.get_color()), np.array([0.6])], axis=0)
|
@@ -147,6 +151,7 @@ def click_process(model, show_mask, radius_width):
|
|
147 |
im_masked = Image.fromarray(im_masked).convert('RGBA')
|
148 |
result_image = Image.alpha_composite(Image.fromarray(bg_image).convert('RGBA'),im_masked).convert("RGB")
|
149 |
result_image = result_image.resize(scaled_hw)
|
|
|
150 |
return result_image
|
151 |
else:
|
152 |
return np.asarray(bg_image)
|
@@ -154,16 +159,65 @@ def click_process(model, show_mask, radius_width):
|
|
154 |
return np.asarray(bg_image)
|
155 |
|
156 |
|
157 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
if 'uploaded_image' not in st.session_state:
|
159 |
return
|
160 |
if st.session_state.uploaded_image is not None:
|
161 |
with st.spinner(text="Uploading image..."):
|
162 |
image = Image.open(st.session_state.uploaded_image).convert("RGB")
|
163 |
-
if
|
164 |
np_image = np.asanyarray(image)
|
165 |
with st.spinner(text="Extracing embeddings.."):
|
166 |
-
|
167 |
st.session_state.image = image
|
168 |
else:
|
169 |
with st.spinner(text="Cleaning up!"):
|
@@ -173,8 +227,8 @@ def image_preprocess_callback(model):
|
|
173 |
st.session_state.image = None
|
174 |
if 'result_image' in st.session_state:
|
175 |
del st.session_state['result_image']
|
176 |
-
if
|
177 |
-
|
178 |
|
179 |
def main():
|
180 |
with open('index.html', encoding='utf-8') as f:
|
@@ -202,21 +256,29 @@ def main():
|
|
202 |
st.write("Upload Image")
|
203 |
st.file_uploader(label='Upload image',type=['png','jpg','tif'], key='uploaded_image', on_change=image_preprocess_callback, args=(SAM_MODEL,), label_visibility="hidden")
|
204 |
|
|
|
205 |
canvas_input, canvas_output = st.columns(2)
|
|
|
206 |
if 'image' in st.session_state:
|
207 |
-
result_image = None
|
208 |
with canvas_input:
|
209 |
st.write("Select Interest Area/Objects")
|
210 |
if st.session_state.image is not None:
|
211 |
-
|
212 |
-
|
213 |
-
result_image =
|
214 |
-
|
215 |
-
result_image =
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
|
221 |
else:
|
222 |
st.cache_data.clear()
|
|
|
10 |
from streamlit_drawable_canvas import st_canvas
|
11 |
from utils import utils
|
12 |
|
13 |
+
PREDICTOR_MODEL, AUTOMASK_MODEL = utils.get_model('vit_b')
|
14 |
|
15 |
|
16 |
+
def process_box(predictor_model, show_mask, radius_width):
|
17 |
bg_image = st.session_state['image']
|
18 |
width, height = bg_image.size[:2]
|
19 |
container_width = 700
|
|
|
44 |
st.session_state.rerun_once = True
|
45 |
|
46 |
st.session_state.display_result = True
|
47 |
+
st.warning("Mask view is disabled", icon="❗")
|
48 |
if st.session_state.rerun_once:
|
49 |
st.experimental_rerun()
|
50 |
else:
|
|
|
65 |
input_box.append([x,y,x+w,y+h])
|
66 |
|
67 |
masks = []
|
68 |
+
if predictor_model:
|
69 |
+
masks = utils.model_predict_masks_box(predictor_model, center_point, center_label, input_box)
|
70 |
|
71 |
if len(masks) == 0:
|
72 |
+
st.warning("No Masks Found", icon="❗")
|
73 |
+
return np.asarray(bg_image)
|
74 |
|
75 |
bg_image = np.asarray(bg_image)
|
76 |
color = np.concatenate([random.choice(utils.get_color()), np.array([0.6])], axis=0)
|
|
|
86 |
return np.asarray(bg_image)
|
87 |
|
88 |
|
89 |
+
def process_click(predictor_model, show_mask, radius_width):
|
90 |
|
91 |
bg_image = st.session_state['image']
|
92 |
width, height = bg_image.size[:2]
|
|
|
116 |
st.session_state.rerun_once = True
|
117 |
|
118 |
st.session_state.display_result = True
|
119 |
+
st.warning("Mask view is disabled", icon="❗")
|
120 |
if st.session_state.rerun_once:
|
121 |
st.experimental_rerun()
|
122 |
else:
|
|
|
138 |
input_labels.append(0)
|
139 |
|
140 |
masks = []
|
141 |
+
if predictor_model:
|
142 |
+
masks = utils.model_predict_masks_click(predictor_model, input_points, input_labels)
|
143 |
|
144 |
if len(masks) == 0:
|
145 |
+
st.warning("No Masks Found", icon="❗")
|
146 |
+
return np.asarray(bg_image)
|
147 |
|
148 |
bg_image = np.asarray(bg_image)
|
149 |
color = np.concatenate([random.choice(utils.get_color()), np.array([0.6])], axis=0)
|
|
|
151 |
im_masked = Image.fromarray(im_masked).convert('RGBA')
|
152 |
result_image = Image.alpha_composite(Image.fromarray(bg_image).convert('RGBA'),im_masked).convert("RGB")
|
153 |
result_image = result_image.resize(scaled_hw)
|
154 |
+
st.session_state.display_result = True
|
155 |
return result_image
|
156 |
else:
|
157 |
return np.asarray(bg_image)
|
|
|
159 |
return np.asarray(bg_image)
|
160 |
|
161 |
|
162 |
+
def process_everything(automask_model, show_mask, radius_width):
|
163 |
+
bg_image = st.session_state['image']
|
164 |
+
width, height = bg_image.size[:2]
|
165 |
+
container_width = 700
|
166 |
+
scale = container_width/width
|
167 |
+
scaled_hw = (container_width, int(height * scale))
|
168 |
+
|
169 |
+
if 'result_image' not in st.session_state:
|
170 |
+
st.session_state.result_image = bg_image.resize(scaled_hw)
|
171 |
+
|
172 |
+
dummy_canvas = st_canvas(
|
173 |
+
fill_color="rgba(255, 255, 0, 0.8)",
|
174 |
+
background_image = bg_image,
|
175 |
+
drawing_mode='freedraw',
|
176 |
+
width = container_width,
|
177 |
+
height = height * scale,
|
178 |
+
point_display_radius = radius_width,
|
179 |
+
stroke_width=2,
|
180 |
+
update_streamlit=False,
|
181 |
+
key="everything",)
|
182 |
+
|
183 |
+
if not show_mask:
|
184 |
+
if 'rerun_once' in st.session_state:
|
185 |
+
if st.session_state.rerun_once:
|
186 |
+
st.session_state.rerun_once = False
|
187 |
+
else:
|
188 |
+
st.session_state.rerun_once = True
|
189 |
+
|
190 |
+
st.session_state.display_result = True
|
191 |
+
st.warning("Mask view is disabled", icon="❗")
|
192 |
+
if st.session_state.rerun_once:
|
193 |
+
st.experimental_rerun()
|
194 |
+
else:
|
195 |
+
return np.asarray(bg_image)
|
196 |
+
|
197 |
+
if automask_model:
|
198 |
+
bg_image = np.asarray(bg_image)
|
199 |
+
masks = utils.model_predict_masks_everything(automask_model, bg_image)
|
200 |
+
im_masked = utils.show_everything(masks)
|
201 |
+
im_masked = Image.fromarray(im_masked).convert('RGBA')
|
202 |
+
result_image = Image.alpha_composite(Image.fromarray(bg_image).convert('RGBA'),im_masked).convert("RGB")
|
203 |
+
result_image = result_image.resize(scaled_hw)
|
204 |
+
st.session_state.display_result = True
|
205 |
+
return result_image
|
206 |
+
|
207 |
+
else:
|
208 |
+
return np.asarray(bg_image)
|
209 |
+
|
210 |
+
|
211 |
+
def image_preprocess_callback(predictor_model, option):
|
212 |
if 'uploaded_image' not in st.session_state:
|
213 |
return
|
214 |
if st.session_state.uploaded_image is not None:
|
215 |
with st.spinner(text="Uploading image..."):
|
216 |
image = Image.open(st.session_state.uploaded_image).convert("RGB")
|
217 |
+
if predictor_model and option != 'Everything':
|
218 |
np_image = np.asanyarray(image)
|
219 |
with st.spinner(text="Extracing embeddings.."):
|
220 |
+
predictor_model.set_image(np_image)
|
221 |
st.session_state.image = image
|
222 |
else:
|
223 |
with st.spinner(text="Cleaning up!"):
|
|
|
227 |
st.session_state.image = None
|
228 |
if 'result_image' in st.session_state:
|
229 |
del st.session_state['result_image']
|
230 |
+
if predictor_model:
|
231 |
+
predictor_model.reset_image()
|
232 |
|
233 |
def main():
|
234 |
with open('index.html', encoding='utf-8') as f:
|
|
|
256 |
st.write("Upload Image")
|
257 |
st.file_uploader(label='Upload image',type=['png','jpg','tif'], key='uploaded_image', on_change=image_preprocess_callback, args=(SAM_MODEL,), label_visibility="hidden")
|
258 |
|
259 |
+
result_image = None
|
260 |
canvas_input, canvas_output = st.columns(2)
|
261 |
+
|
262 |
if 'image' in st.session_state:
|
|
|
263 |
with canvas_input:
|
264 |
st.write("Select Interest Area/Objects")
|
265 |
if st.session_state.image is not None:
|
266 |
+
with st.spinner(text="Computing masks"):
|
267 |
+
if option == 'Click':
|
268 |
+
result_image = process_click(PREDICTOR_MODEL, show_mask, radius_width)
|
269 |
+
elif option == 'Box':
|
270 |
+
result_image = process_box(PREDICTOR_MODEL, show_mask, radius_width)
|
271 |
+
else:
|
272 |
+
result_image = process_everything(AUTOMASK_MODEL, show_mask, radius_width)
|
273 |
+
if 'display_result' in st.session_state:
|
274 |
+
if st.session_state.display_result:
|
275 |
+
with canvas_output:
|
276 |
+
if result_image is not None:
|
277 |
+
st.write("Result")
|
278 |
+
st.image(result_image)
|
279 |
+
else:
|
280 |
+
st.warning("No result found, please set input prompt", icon="⚠️")
|
281 |
+
st.success('Process completed!', icon="✅")
|
282 |
|
283 |
else:
|
284 |
st.cache_data.clear()
|
utils/utils.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
|
2 |
-
from segment_anything import SamPredictor, sam_model_registry
|
3 |
import torch
|
4 |
import numpy as np
|
5 |
from distinctipy import distinctipy
|
@@ -19,20 +19,19 @@ def get_model(model):
|
|
19 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
20 |
build_sam = sam_model_registry[model]
|
21 |
model = build_sam(checkpoint=get_checkpoint_path(model)).to(device)
|
22 |
-
predictor = SamPredictor(model)
|
23 |
if torch.cuda.is_available():
|
24 |
torch.cuda.empty_cache()
|
25 |
-
|
|
|
|
|
26 |
|
27 |
|
28 |
-
@st.cache_data
|
29 |
def show_everything(sorted_anns):
|
30 |
-
if len(sorted_anns) == 0:
|
31 |
-
return
|
32 |
#sorted_anns = sorted(anns, key=(lambda x: x['stability_score']), reverse=True)
|
33 |
h, w = sorted_anns[0]['segmentation'].shape[-2:]
|
34 |
#sorted_anns = sorted_anns[:int(len(sorted_anns) * stability_score/100)]
|
35 |
if sorted_anns == []:
|
|
|
36 |
return np.zeros((h,w,4)).astype(np.uint8)
|
37 |
mask = np.zeros((h,w,4))
|
38 |
for ann in sorted_anns:
|
@@ -40,13 +39,13 @@ def show_everything(sorted_anns):
|
|
40 |
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
41 |
mask += m.reshape(h,w,1) * color.reshape(1, 1, -1)
|
42 |
mask = mask * 255
|
43 |
-
st.success('Process completed!', icon="✅")
|
44 |
return mask.astype(np.uint8)
|
45 |
|
46 |
|
47 |
def show_click(masks, colors):
|
48 |
h, w = masks[0].shape[-2:]
|
49 |
masks_total = np.zeros((h,w,4)).astype(np.uint8)
|
|
|
50 |
for mask, color in zip(masks, colors):
|
51 |
if np.array_equal(mask,np.array([])):continue
|
52 |
masks = np.zeros((h,w,4)).astype(np.uint8)
|
@@ -54,7 +53,7 @@ def show_click(masks, colors):
|
|
54 |
masks = masks.astype(bool).astype(np.uint8)
|
55 |
masks = masks * 255 * color.reshape(1, 1, -1)
|
56 |
masks_total += masks.astype(np.uint8)
|
57 |
-
|
58 |
return masks_total
|
59 |
|
60 |
def model_predict_masks_click(model,input_points,input_labels):
|
@@ -66,6 +65,7 @@ def model_predict_masks_click(model,input_points,input_labels):
|
|
66 |
point_labels=input_labels,
|
67 |
multimask_output=False,
|
68 |
)
|
|
|
69 |
if torch.cuda.is_available():
|
70 |
torch.cuda.empty_cache()
|
71 |
|
@@ -93,3 +93,12 @@ def model_predict_masks_box(model,center_point,center_label,input_box):
|
|
93 |
torch.cuda.empty_cache()
|
94 |
|
95 |
return masks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
|
2 |
+
from segment_anything import SamPredictor, SamAutomaticMaskGenerator, sam_model_registry
|
3 |
import torch
|
4 |
import numpy as np
|
5 |
from distinctipy import distinctipy
|
|
|
19 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
20 |
build_sam = sam_model_registry[model]
|
21 |
model = build_sam(checkpoint=get_checkpoint_path(model)).to(device)
|
|
|
22 |
if torch.cuda.is_available():
|
23 |
torch.cuda.empty_cache()
|
24 |
+
predictor = SamPredictor(model)
|
25 |
+
mask_generator = SamAutomaticMaskGenerator(model)
|
26 |
+
return predictor, mask_generator
|
27 |
|
28 |
|
|
|
29 |
def show_everything(sorted_anns):
|
|
|
|
|
30 |
#sorted_anns = sorted(anns, key=(lambda x: x['stability_score']), reverse=True)
|
31 |
h, w = sorted_anns[0]['segmentation'].shape[-2:]
|
32 |
#sorted_anns = sorted_anns[:int(len(sorted_anns) * stability_score/100)]
|
33 |
if sorted_anns == []:
|
34 |
+
st.warning("No Masks Found", icon="❗")
|
35 |
return np.zeros((h,w,4)).astype(np.uint8)
|
36 |
mask = np.zeros((h,w,4))
|
37 |
for ann in sorted_anns:
|
|
|
39 |
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
40 |
mask += m.reshape(h,w,1) * color.reshape(1, 1, -1)
|
41 |
mask = mask * 255
|
|
|
42 |
return mask.astype(np.uint8)
|
43 |
|
44 |
|
45 |
def show_click(masks, colors):
|
46 |
h, w = masks[0].shape[-2:]
|
47 |
masks_total = np.zeros((h,w,4)).astype(np.uint8)
|
48 |
+
|
49 |
for mask, color in zip(masks, colors):
|
50 |
if np.array_equal(mask,np.array([])):continue
|
51 |
masks = np.zeros((h,w,4)).astype(np.uint8)
|
|
|
53 |
masks = masks.astype(bool).astype(np.uint8)
|
54 |
masks = masks * 255 * color.reshape(1, 1, -1)
|
55 |
masks_total += masks.astype(np.uint8)
|
56 |
+
|
57 |
return masks_total
|
58 |
|
59 |
def model_predict_masks_click(model,input_points,input_labels):
|
|
|
65 |
point_labels=input_labels,
|
66 |
multimask_output=False,
|
67 |
)
|
68 |
+
|
69 |
if torch.cuda.is_available():
|
70 |
torch.cuda.empty_cache()
|
71 |
|
|
|
93 |
torch.cuda.empty_cache()
|
94 |
|
95 |
return masks
|
96 |
+
|
97 |
+
|
98 |
+
def model_predict_masks_everything(mask_generator, image):
    """Generate masks for a whole image with an automatic mask generator.

    Args:
        mask_generator: Object exposing a ``generate(image)`` method.
            Presumably the ``SamAutomaticMaskGenerator`` returned by
            ``get_model`` -- confirm against the call site.
        image: Image passed unchanged to ``mask_generator.generate``.

    Returns:
        Whatever ``mask_generator.generate(image)`` returns.
    """
    masks = mask_generator.generate(image)

    # Free cached CUDA allocations once inference is done; only attempted
    # when CUDA is actually available.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return masks
|