haritsahm
commited on
Commit
·
f665217
1
Parent(s):
fb345ee
Add new features to segment everything
Browse files- app.py +86 -24
- utils/utils.py +17 -8
app.py
CHANGED
|
@@ -10,10 +10,10 @@ from PIL import Image
|
|
| 10 |
from streamlit_drawable_canvas import st_canvas
|
| 11 |
from utils import utils
|
| 12 |
|
| 13 |
-
|
| 14 |
|
| 15 |
|
| 16 |
-
def
|
| 17 |
bg_image = st.session_state['image']
|
| 18 |
width, height = bg_image.size[:2]
|
| 19 |
container_width = 700
|
|
@@ -44,6 +44,7 @@ def box_process(model, show_mask, radius_width):
|
|
| 44 |
st.session_state.rerun_once = True
|
| 45 |
|
| 46 |
st.session_state.display_result = True
|
|
|
|
| 47 |
if st.session_state.rerun_once:
|
| 48 |
st.experimental_rerun()
|
| 49 |
else:
|
|
@@ -64,11 +65,12 @@ def box_process(model, show_mask, radius_width):
|
|
| 64 |
input_box.append([x,y,x+w,y+h])
|
| 65 |
|
| 66 |
masks = []
|
| 67 |
-
if
|
| 68 |
-
masks = utils.model_predict_masks_box(
|
| 69 |
|
| 70 |
if len(masks) == 0:
|
| 71 |
-
|
|
|
|
| 72 |
|
| 73 |
bg_image = np.asarray(bg_image)
|
| 74 |
color = np.concatenate([random.choice(utils.get_color()), np.array([0.6])], axis=0)
|
|
@@ -84,7 +86,7 @@ def box_process(model, show_mask, radius_width):
|
|
| 84 |
return np.asarray(bg_image)
|
| 85 |
|
| 86 |
|
| 87 |
-
def
|
| 88 |
|
| 89 |
bg_image = st.session_state['image']
|
| 90 |
width, height = bg_image.size[:2]
|
|
@@ -114,6 +116,7 @@ def click_process(model, show_mask, radius_width):
|
|
| 114 |
st.session_state.rerun_once = True
|
| 115 |
|
| 116 |
st.session_state.display_result = True
|
|
|
|
| 117 |
if st.session_state.rerun_once:
|
| 118 |
st.experimental_rerun()
|
| 119 |
else:
|
|
@@ -135,11 +138,12 @@ def click_process(model, show_mask, radius_width):
|
|
| 135 |
input_labels.append(0)
|
| 136 |
|
| 137 |
masks = []
|
| 138 |
-
if
|
| 139 |
-
masks = utils.model_predict_masks_click(
|
| 140 |
|
| 141 |
if len(masks) == 0:
|
| 142 |
-
|
|
|
|
| 143 |
|
| 144 |
bg_image = np.asarray(bg_image)
|
| 145 |
color = np.concatenate([random.choice(utils.get_color()), np.array([0.6])], axis=0)
|
|
@@ -147,6 +151,7 @@ def click_process(model, show_mask, radius_width):
|
|
| 147 |
im_masked = Image.fromarray(im_masked).convert('RGBA')
|
| 148 |
result_image = Image.alpha_composite(Image.fromarray(bg_image).convert('RGBA'),im_masked).convert("RGB")
|
| 149 |
result_image = result_image.resize(scaled_hw)
|
|
|
|
| 150 |
return result_image
|
| 151 |
else:
|
| 152 |
return np.asarray(bg_image)
|
|
@@ -154,16 +159,65 @@ def click_process(model, show_mask, radius_width):
|
|
| 154 |
return np.asarray(bg_image)
|
| 155 |
|
| 156 |
|
| 157 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
if 'uploaded_image' not in st.session_state:
|
| 159 |
return
|
| 160 |
if st.session_state.uploaded_image is not None:
|
| 161 |
with st.spinner(text="Uploading image..."):
|
| 162 |
image = Image.open(st.session_state.uploaded_image).convert("RGB")
|
| 163 |
-
if
|
| 164 |
np_image = np.asanyarray(image)
|
| 165 |
with st.spinner(text="Extracing embeddings.."):
|
| 166 |
-
|
| 167 |
st.session_state.image = image
|
| 168 |
else:
|
| 169 |
with st.spinner(text="Cleaning up!"):
|
|
@@ -173,8 +227,8 @@ def image_preprocess_callback(model):
|
|
| 173 |
st.session_state.image = None
|
| 174 |
if 'result_image' in st.session_state:
|
| 175 |
del st.session_state['result_image']
|
| 176 |
-
if
|
| 177 |
-
|
| 178 |
|
| 179 |
def main():
|
| 180 |
with open('index.html', encoding='utf-8') as f:
|
|
@@ -202,21 +256,29 @@ def main():
|
|
| 202 |
st.write("Upload Image")
|
| 203 |
st.file_uploader(label='Upload image',type=['png','jpg','tif'], key='uploaded_image', on_change=image_preprocess_callback, args=(SAM_MODEL,), label_visibility="hidden")
|
| 204 |
|
|
|
|
| 205 |
canvas_input, canvas_output = st.columns(2)
|
|
|
|
| 206 |
if 'image' in st.session_state:
|
| 207 |
-
result_image = None
|
| 208 |
with canvas_input:
|
| 209 |
st.write("Select Interest Area/Objects")
|
| 210 |
if st.session_state.image is not None:
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
result_image =
|
| 214 |
-
|
| 215 |
-
result_image =
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
|
| 221 |
else:
|
| 222 |
st.cache_data.clear()
|
|
|
|
| 10 |
from streamlit_drawable_canvas import st_canvas
|
| 11 |
from utils import utils
|
| 12 |
|
| 13 |
+
PREDICTOR_MODEL, AUTOMASK_MODEL = utils.get_model('vit_b')
|
| 14 |
|
| 15 |
|
| 16 |
+
def process_box(predictor_model, show_mask, radius_width):
|
| 17 |
bg_image = st.session_state['image']
|
| 18 |
width, height = bg_image.size[:2]
|
| 19 |
container_width = 700
|
|
|
|
| 44 |
st.session_state.rerun_once = True
|
| 45 |
|
| 46 |
st.session_state.display_result = True
|
| 47 |
+
st.warning("Mask view is disabled", icon="❗")
|
| 48 |
if st.session_state.rerun_once:
|
| 49 |
st.experimental_rerun()
|
| 50 |
else:
|
|
|
|
| 65 |
input_box.append([x,y,x+w,y+h])
|
| 66 |
|
| 67 |
masks = []
|
| 68 |
+
if predictor_model:
|
| 69 |
+
masks = utils.model_predict_masks_box(predictor_model, center_point, center_label, input_box)
|
| 70 |
|
| 71 |
if len(masks) == 0:
|
| 72 |
+
st.warning("No Masks Found", icon="❗")
|
| 73 |
+
return np.asarray(bg_image)
|
| 74 |
|
| 75 |
bg_image = np.asarray(bg_image)
|
| 76 |
color = np.concatenate([random.choice(utils.get_color()), np.array([0.6])], axis=0)
|
|
|
|
| 86 |
return np.asarray(bg_image)
|
| 87 |
|
| 88 |
|
| 89 |
+
def process_click(predictor_model, show_mask, radius_width):
|
| 90 |
|
| 91 |
bg_image = st.session_state['image']
|
| 92 |
width, height = bg_image.size[:2]
|
|
|
|
| 116 |
st.session_state.rerun_once = True
|
| 117 |
|
| 118 |
st.session_state.display_result = True
|
| 119 |
+
st.warning("Mask view is disabled", icon="❗")
|
| 120 |
if st.session_state.rerun_once:
|
| 121 |
st.experimental_rerun()
|
| 122 |
else:
|
|
|
|
| 138 |
input_labels.append(0)
|
| 139 |
|
| 140 |
masks = []
|
| 141 |
+
if predictor_model:
|
| 142 |
+
masks = utils.model_predict_masks_click(predictor_model, input_points, input_labels)
|
| 143 |
|
| 144 |
if len(masks) == 0:
|
| 145 |
+
st.warning("No Masks Found", icon="❗")
|
| 146 |
+
return np.asarray(bg_image)
|
| 147 |
|
| 148 |
bg_image = np.asarray(bg_image)
|
| 149 |
color = np.concatenate([random.choice(utils.get_color()), np.array([0.6])], axis=0)
|
|
|
|
| 151 |
im_masked = Image.fromarray(im_masked).convert('RGBA')
|
| 152 |
result_image = Image.alpha_composite(Image.fromarray(bg_image).convert('RGBA'),im_masked).convert("RGB")
|
| 153 |
result_image = result_image.resize(scaled_hw)
|
| 154 |
+
st.session_state.display_result = True
|
| 155 |
return result_image
|
| 156 |
else:
|
| 157 |
return np.asarray(bg_image)
|
|
|
|
| 159 |
return np.asarray(bg_image)
|
| 160 |
|
| 161 |
|
| 162 |
+
def process_everything(automask_model, show_mask, radius_width):
|
| 163 |
+
bg_image = st.session_state['image']
|
| 164 |
+
width, height = bg_image.size[:2]
|
| 165 |
+
container_width = 700
|
| 166 |
+
scale = container_width/width
|
| 167 |
+
scaled_hw = (container_width, int(height * scale))
|
| 168 |
+
|
| 169 |
+
if 'result_image' not in st.session_state:
|
| 170 |
+
st.session_state.result_image = bg_image.resize(scaled_hw)
|
| 171 |
+
|
| 172 |
+
dummy_canvas = st_canvas(
|
| 173 |
+
fill_color="rgba(255, 255, 0, 0.8)",
|
| 174 |
+
background_image = bg_image,
|
| 175 |
+
drawing_mode='freedraw',
|
| 176 |
+
width = container_width,
|
| 177 |
+
height = height * scale,
|
| 178 |
+
point_display_radius = radius_width,
|
| 179 |
+
stroke_width=2,
|
| 180 |
+
update_streamlit=False,
|
| 181 |
+
key="everything",)
|
| 182 |
+
|
| 183 |
+
if not show_mask:
|
| 184 |
+
if 'rerun_once' in st.session_state:
|
| 185 |
+
if st.session_state.rerun_once:
|
| 186 |
+
st.session_state.rerun_once = False
|
| 187 |
+
else:
|
| 188 |
+
st.session_state.rerun_once = True
|
| 189 |
+
|
| 190 |
+
st.session_state.display_result = True
|
| 191 |
+
st.warning("Mask view is disabled", icon="❗")
|
| 192 |
+
if st.session_state.rerun_once:
|
| 193 |
+
st.experimental_rerun()
|
| 194 |
+
else:
|
| 195 |
+
return np.asarray(bg_image)
|
| 196 |
+
|
| 197 |
+
if automask_model:
|
| 198 |
+
bg_image = np.asarray(bg_image)
|
| 199 |
+
masks = utils.model_predict_masks_everything(automask_model, bg_image)
|
| 200 |
+
im_masked = utils.show_everything(masks)
|
| 201 |
+
im_masked = Image.fromarray(im_masked).convert('RGBA')
|
| 202 |
+
result_image = Image.alpha_composite(Image.fromarray(bg_image).convert('RGBA'),im_masked).convert("RGB")
|
| 203 |
+
result_image = result_image.resize(scaled_hw)
|
| 204 |
+
st.session_state.display_result = True
|
| 205 |
+
return result_image
|
| 206 |
+
|
| 207 |
+
else:
|
| 208 |
+
return np.asarray(bg_image)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def image_preprocess_callback(predictor_model, option):
|
| 212 |
if 'uploaded_image' not in st.session_state:
|
| 213 |
return
|
| 214 |
if st.session_state.uploaded_image is not None:
|
| 215 |
with st.spinner(text="Uploading image..."):
|
| 216 |
image = Image.open(st.session_state.uploaded_image).convert("RGB")
|
| 217 |
+
if predictor_model and option != 'Everything':
|
| 218 |
np_image = np.asanyarray(image)
|
| 219 |
with st.spinner(text="Extracing embeddings.."):
|
| 220 |
+
predictor_model.set_image(np_image)
|
| 221 |
st.session_state.image = image
|
| 222 |
else:
|
| 223 |
with st.spinner(text="Cleaning up!"):
|
|
|
|
| 227 |
st.session_state.image = None
|
| 228 |
if 'result_image' in st.session_state:
|
| 229 |
del st.session_state['result_image']
|
| 230 |
+
if predictor_model:
|
| 231 |
+
predictor_model.reset_image()
|
| 232 |
|
| 233 |
def main():
|
| 234 |
with open('index.html', encoding='utf-8') as f:
|
|
|
|
| 256 |
st.write("Upload Image")
|
| 257 |
st.file_uploader(label='Upload image',type=['png','jpg','tif'], key='uploaded_image', on_change=image_preprocess_callback, args=(SAM_MODEL,), label_visibility="hidden")
|
| 258 |
|
| 259 |
+
result_image = None
|
| 260 |
canvas_input, canvas_output = st.columns(2)
|
| 261 |
+
|
| 262 |
if 'image' in st.session_state:
|
|
|
|
| 263 |
with canvas_input:
|
| 264 |
st.write("Select Interest Area/Objects")
|
| 265 |
if st.session_state.image is not None:
|
| 266 |
+
with st.spinner(text="Computing masks"):
|
| 267 |
+
if option == 'Click':
|
| 268 |
+
result_image = process_click(PREDICTOR_MODEL, show_mask, radius_width)
|
| 269 |
+
elif option == 'Box':
|
| 270 |
+
result_image = process_box(PREDICTOR_MODEL, show_mask, radius_width)
|
| 271 |
+
else:
|
| 272 |
+
result_image = process_everything(AUTOMASK_MODEL, show_mask, radius_width)
|
| 273 |
+
if 'display_result' in st.session_state:
|
| 274 |
+
if st.session_state.display_result:
|
| 275 |
+
with canvas_output:
|
| 276 |
+
if result_image is not None:
|
| 277 |
+
st.write("Result")
|
| 278 |
+
st.image(result_image)
|
| 279 |
+
else:
|
| 280 |
+
st.warning("No result found, please set input prompt", icon="⚠️")
|
| 281 |
+
st.success('Process completed!', icon="✅")
|
| 282 |
|
| 283 |
else:
|
| 284 |
st.cache_data.clear()
|
utils/utils.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
|
| 2 |
-
from segment_anything import SamPredictor, sam_model_registry
|
| 3 |
import torch
|
| 4 |
import numpy as np
|
| 5 |
from distinctipy import distinctipy
|
|
@@ -19,20 +19,19 @@ def get_model(model):
|
|
| 19 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 20 |
build_sam = sam_model_registry[model]
|
| 21 |
model = build_sam(checkpoint=get_checkpoint_path(model)).to(device)
|
| 22 |
-
predictor = SamPredictor(model)
|
| 23 |
if torch.cuda.is_available():
|
| 24 |
torch.cuda.empty_cache()
|
| 25 |
-
|
|
|
|
|
|
|
| 26 |
|
| 27 |
|
| 28 |
-
@st.cache_data
|
| 29 |
def show_everything(sorted_anns):
|
| 30 |
-
if len(sorted_anns) == 0:
|
| 31 |
-
return
|
| 32 |
#sorted_anns = sorted(anns, key=(lambda x: x['stability_score']), reverse=True)
|
| 33 |
h, w = sorted_anns[0]['segmentation'].shape[-2:]
|
| 34 |
#sorted_anns = sorted_anns[:int(len(sorted_anns) * stability_score/100)]
|
| 35 |
if sorted_anns == []:
|
|
|
|
| 36 |
return np.zeros((h,w,4)).astype(np.uint8)
|
| 37 |
mask = np.zeros((h,w,4))
|
| 38 |
for ann in sorted_anns:
|
|
@@ -40,13 +39,13 @@ def show_everything(sorted_anns):
|
|
| 40 |
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
| 41 |
mask += m.reshape(h,w,1) * color.reshape(1, 1, -1)
|
| 42 |
mask = mask * 255
|
| 43 |
-
st.success('Process completed!', icon="✅")
|
| 44 |
return mask.astype(np.uint8)
|
| 45 |
|
| 46 |
|
| 47 |
def show_click(masks, colors):
|
| 48 |
h, w = masks[0].shape[-2:]
|
| 49 |
masks_total = np.zeros((h,w,4)).astype(np.uint8)
|
|
|
|
| 50 |
for mask, color in zip(masks, colors):
|
| 51 |
if np.array_equal(mask,np.array([])):continue
|
| 52 |
masks = np.zeros((h,w,4)).astype(np.uint8)
|
|
@@ -54,7 +53,7 @@ def show_click(masks, colors):
|
|
| 54 |
masks = masks.astype(bool).astype(np.uint8)
|
| 55 |
masks = masks * 255 * color.reshape(1, 1, -1)
|
| 56 |
masks_total += masks.astype(np.uint8)
|
| 57 |
-
|
| 58 |
return masks_total
|
| 59 |
|
| 60 |
def model_predict_masks_click(model,input_points,input_labels):
|
|
@@ -66,6 +65,7 @@ def model_predict_masks_click(model,input_points,input_labels):
|
|
| 66 |
point_labels=input_labels,
|
| 67 |
multimask_output=False,
|
| 68 |
)
|
|
|
|
| 69 |
if torch.cuda.is_available():
|
| 70 |
torch.cuda.empty_cache()
|
| 71 |
|
|
@@ -93,3 +93,12 @@ def model_predict_masks_box(model,center_point,center_label,input_box):
|
|
| 93 |
torch.cuda.empty_cache()
|
| 94 |
|
| 95 |
return masks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
|
| 2 |
+
from segment_anything import SamPredictor, SamAutomaticMaskGenerator, sam_model_registry
|
| 3 |
import torch
|
| 4 |
import numpy as np
|
| 5 |
from distinctipy import distinctipy
|
|
|
|
| 19 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 20 |
build_sam = sam_model_registry[model]
|
| 21 |
model = build_sam(checkpoint=get_checkpoint_path(model)).to(device)
|
|
|
|
| 22 |
if torch.cuda.is_available():
|
| 23 |
torch.cuda.empty_cache()
|
| 24 |
+
predictor = SamPredictor(model)
|
| 25 |
+
mask_generator = SamAutomaticMaskGenerator(model)
|
| 26 |
+
return predictor, mask_generator
|
| 27 |
|
| 28 |
|
|
|
|
| 29 |
def show_everything(sorted_anns):
|
|
|
|
|
|
|
| 30 |
#sorted_anns = sorted(anns, key=(lambda x: x['stability_score']), reverse=True)
|
| 31 |
h, w = sorted_anns[0]['segmentation'].shape[-2:]
|
| 32 |
#sorted_anns = sorted_anns[:int(len(sorted_anns) * stability_score/100)]
|
| 33 |
if sorted_anns == []:
|
| 34 |
+
st.warning("No Masks Found", icon="❗")
|
| 35 |
return np.zeros((h,w,4)).astype(np.uint8)
|
| 36 |
mask = np.zeros((h,w,4))
|
| 37 |
for ann in sorted_anns:
|
|
|
|
| 39 |
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
| 40 |
mask += m.reshape(h,w,1) * color.reshape(1, 1, -1)
|
| 41 |
mask = mask * 255
|
|
|
|
| 42 |
return mask.astype(np.uint8)
|
| 43 |
|
| 44 |
|
| 45 |
def show_click(masks, colors):
|
| 46 |
h, w = masks[0].shape[-2:]
|
| 47 |
masks_total = np.zeros((h,w,4)).astype(np.uint8)
|
| 48 |
+
|
| 49 |
for mask, color in zip(masks, colors):
|
| 50 |
if np.array_equal(mask,np.array([])):continue
|
| 51 |
masks = np.zeros((h,w,4)).astype(np.uint8)
|
|
|
|
| 53 |
masks = masks.astype(bool).astype(np.uint8)
|
| 54 |
masks = masks * 255 * color.reshape(1, 1, -1)
|
| 55 |
masks_total += masks.astype(np.uint8)
|
| 56 |
+
|
| 57 |
return masks_total
|
| 58 |
|
| 59 |
def model_predict_masks_click(model,input_points,input_labels):
|
|
|
|
| 65 |
point_labels=input_labels,
|
| 66 |
multimask_output=False,
|
| 67 |
)
|
| 68 |
+
|
| 69 |
if torch.cuda.is_available():
|
| 70 |
torch.cuda.empty_cache()
|
| 71 |
|
|
|
|
| 93 |
torch.cuda.empty_cache()
|
| 94 |
|
| 95 |
return masks
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def model_predict_masks_everything(mask_generator, image):
|
| 99 |
+
masks = mask_generator.generate(image)
|
| 100 |
+
|
| 101 |
+
if torch.cuda.is_available():
|
| 102 |
+
torch.cuda.empty_cache()
|
| 103 |
+
|
| 104 |
+
return masks
|