|
import streamlit as st |
|
|
|
# Page-wide configuration. Kept directly after the streamlit import because
# Streamlit expects set_page_config to be the first Streamlit command executed.
st.set_page_config(
    layout="wide",
)
|
|
|
import random |
|
|
|
import numpy as np |
|
import pandas as pd |
|
from PIL import Image |
|
from streamlit_drawable_canvas import st_canvas |
|
from utils import utils |
|
|
|
# Module-level model handles used by main() and the upload callback.
# utils.get_model returns a pair: the first element is used for click/box
# prompts, the second for "Everything" mode. Whether get_model caches the
# loaded checkpoint between script runs is not visible from this file.
PREDICTOR_MODEL, AUTOMASK_MODEL = utils.get_model(
    checkpoint='checkpoint/medsam_vit_b.pth',
)
|
|
|
|
|
def process_box(predictor_model, show_mask, radius_width):
    """Render the box-prompt canvas and return the image for the result pane.

    The image stored in ``st.session_state['image']`` is used as the background
    of a drawable canvas in rectangle mode. Every rectangle on the canvas is
    mapped back to original-image coordinates and passed to
    ``utils.model_predict_masks_box``; the returned masks are alpha-blended
    over the image.

    Args:
        predictor_model: Prompt-based predictor. ``set_image`` is called on it
            when it reports (through an ``is_image_set`` attribute, as
            segment-anything's ``SamPredictor`` exposes) that no image has
            been embedded yet. A falsy value disables prediction.
        show_mask: When false, no prediction is made and the plain image is
            returned.
        radius_width: Stroke width of the rectangles drawn on the canvas.

    Returns:
        A ``PIL.Image.Image`` resized to the canvas size when masks were
        produced; otherwise the unscaled image as a ``numpy.ndarray``.
    """
    bg_image = st.session_state['image']
    width, height = bg_image.size[:2]
    container_width = 700
    # Canvas coordinates are original-image coordinates multiplied by `scale`.
    scale = container_width / width
    scaled_wh = (container_width, int(height * scale))

    # The previous check, `not predictor_model.set_image`, tested a bound
    # method, which is always truthy, so a missing embedding was never
    # detected. Test the predictor's state flag instead. Defaulting to True
    # keeps the old "skip" behaviour for predictors without the flag and for
    # a predictor of None.
    if not getattr(predictor_model, 'is_image_set', True):
        np_image = np.asanyarray(bg_image)
        with st.spinner(text="Extracing embeddings.."):
            predictor_model.set_image(np_image)

    if 'result_image' not in st.session_state:
        st.session_state.result_image = bg_image.resize(scaled_wh)

    box_canvas = st_canvas(
        fill_color="rgba(255, 255, 0, 0)",
        background_image=bg_image,
        drawing_mode='rect',
        stroke_color="rgba(0, 255, 0, 0.6)",
        stroke_width=radius_width,
        width=container_width,
        height=height * scale,
        point_display_radius=12,
        update_streamlit=True,
        key="box",
    )

    if not show_mask:
        # `rerun_once` is created as True the first time the mask view is
        # disabled, which triggers a single extra rerun below; on that rerun
        # it is flipped to False and stays False afterwards.
        if 'rerun_once' not in st.session_state:
            st.session_state.rerun_once = True
        elif st.session_state.rerun_once:
            st.session_state.rerun_once = False

        st.session_state.display_result = True
        st.warning("Mask view is disabled", icon="❗")
        if st.session_state.rerun_once:
            # Stops this script run and schedules a new one.
            st.experimental_rerun()
        return np.asarray(bg_image)

    if box_canvas.json_data is None:
        # The canvas has not reported any state yet.
        return np.asarray(bg_image)

    df = pd.json_normalize(box_canvas.json_data["objects"])
    center_point, center_label, input_box = [], [], []
    for _, row in df.iterrows():
        # Convert the rectangle from canvas pixels to image pixels.
        x = int(row["left"] / scale)
        y = int(row["top"] / scale)
        w = int(row["width"] / scale)
        h = int(row["height"] / scale)
        center_point.append([x + w / 2, y + h / 2])
        center_label.append([1])
        input_box.append([x, y, x + w, y + h])

    masks = []
    if predictor_model:
        masks = utils.model_predict_masks_box(predictor_model, center_point, center_label, input_box)

    # Use len() rather than truthiness: the helper's return type is not visible
    # here and may be an array, whose truth value would be ambiguous.
    if len(masks) == 0:
        st.warning("No Masks Found", icon="❗")
        return np.asarray(bg_image)

    np_bg = np.asarray(bg_image)
    # One random colour from the palette with a fourth component of 0.6 appended.
    color = np.concatenate([random.choice(utils.get_color()), np.array([0.6])], axis=0)
    im_masked = utils.show_click(masks, color)
    im_masked = Image.fromarray(im_masked).convert('RGBA')
    result_image = Image.alpha_composite(Image.fromarray(np_bg).convert('RGBA'), im_masked).convert("RGB")
    result_image = result_image.resize(scaled_wh)
    st.session_state.display_result = True
    return result_image
|
|
|
|
|
def process_click(predictor_model, show_mask, radius_width):
    """Render the point-prompt canvas and return the image for the result pane.

    The image stored in ``st.session_state['image']`` is used as the background
    of a drawable canvas in point mode. The centre of every point on the
    canvas is mapped back to original-image coordinates and passed, together
    with a 0/1 label, to ``utils.model_predict_masks_click``; the returned
    masks are alpha-blended over the image.

    Args:
        predictor_model: Prompt-based predictor. ``set_image`` is called on it
            when it reports (through an ``is_image_set`` attribute, as
            segment-anything's ``SamPredictor`` exposes) that no image has
            been embedded yet. A falsy value disables prediction.
        show_mask: When false, no prediction is made and the plain image is
            returned.
        radius_width: Display radius of the points drawn on the canvas.

    Returns:
        A ``PIL.Image.Image`` resized to the canvas size when masks were
        produced; otherwise the unscaled image as a ``numpy.ndarray``.
    """
    bg_image = st.session_state['image']
    width, height = bg_image.size[:2]
    container_width = 700
    # Canvas coordinates are original-image coordinates multiplied by `scale`.
    scale = container_width / width
    scaled_wh = (container_width, int(height * scale))

    # The previous check, `not predictor_model.set_image`, tested a bound
    # method, which is always truthy, so a missing embedding was never
    # detected. Test the predictor's state flag instead. Defaulting to True
    # keeps the old "skip" behaviour for predictors without the flag and for
    # a predictor of None.
    if not getattr(predictor_model, 'is_image_set', True):
        np_image = np.asanyarray(bg_image)
        with st.spinner(text="Extracing embeddings.."):
            predictor_model.set_image(np_image)

    if 'result_image' not in st.session_state:
        st.session_state.result_image = bg_image.resize(scaled_wh)

    click_canvas = st_canvas(
        fill_color="rgba(255, 255, 0, 0.8)",
        background_image=bg_image,
        drawing_mode='point',
        width=container_width,
        height=height * scale,
        point_display_radius=radius_width,
        stroke_width=2,
        update_streamlit=True,
        key="point",
    )

    if not show_mask:
        # `rerun_once` is created as True the first time the mask view is
        # disabled, which triggers a single extra rerun below; on that rerun
        # it is flipped to False and stays False afterwards.
        if 'rerun_once' not in st.session_state:
            st.session_state.rerun_once = True
        elif st.session_state.rerun_once:
            st.session_state.rerun_once = False

        st.session_state.display_result = True
        st.warning("Mask view is disabled", icon="❗")
        if st.session_state.rerun_once:
            # Stops this script run and schedules a new one.
            st.experimental_rerun()
        return np.asarray(bg_image)

    if click_canvas.json_data is None:
        # The canvas has not reported any state yet.
        return np.asarray(bg_image)

    df = pd.json_normalize(click_canvas.json_data["objects"])
    input_points = []
    input_labels = []
    for _, row in df.iterrows():
        # Centre of the drawn object in canvas pixels, then in image pixels.
        x = int(row["left"] + row["width"] / 2)
        y = int(row["top"] + row["height"] / 2)
        input_points.append([int(x / scale), int(y / scale)])
        # NOTE(review): the canvas above is created with fill
        # "rgba(255, 255, 0, 0.8)", so with the code visible here this
        # comparison looks like it can never match and every click is
        # labelled 0. Confirm which fill colour is meant to mark a positive
        # click.
        input_labels.append(1 if row['fill'] == "rgba(0, 255, 0, 0.8)" else 0)

    masks = []
    if predictor_model:
        masks = utils.model_predict_masks_click(predictor_model, input_points, input_labels)

    # Use len() rather than truthiness: the helper's return type is not visible
    # here and may be an array, whose truth value would be ambiguous.
    if len(masks) == 0:
        st.warning("No Masks Found", icon="❗")
        return np.asarray(bg_image)

    np_bg = np.asarray(bg_image)
    # One random colour from the palette with a fourth component of 0.6 appended.
    color = np.concatenate([random.choice(utils.get_color()), np.array([0.6])], axis=0)
    im_masked = utils.show_click(masks, color)
    im_masked = Image.fromarray(im_masked).convert('RGBA')
    result_image = Image.alpha_composite(Image.fromarray(np_bg).convert('RGBA'), im_masked).convert("RGB")
    result_image = result_image.resize(scaled_wh)
    st.session_state.display_result = True
    return result_image
|
|
|
|
|
def process_everything(automask_model, show_mask, radius_width):
    """Segment the whole image without prompts and return the result image.

    The image stored in ``st.session_state['image']`` is shown on a canvas
    (whose drawings are not read back), passed as an array to
    ``utils.model_predict_masks_everything``, and the rendered masks are
    alpha-blended over it.

    Args:
        automask_model: Model handed to the utils helper. A falsy value
            disables prediction.
        show_mask: When false, no prediction is made and the plain image is
            returned.
        radius_width: Forwarded to the canvas as ``point_display_radius``.

    Returns:
        A ``PIL.Image.Image`` resized to the canvas size when prediction ran;
        otherwise the unscaled image as a ``numpy.ndarray``.
    """
    state = st.session_state
    source_image = state['image']
    img_w, img_h = source_image.size[:2]
    canvas_w = 700
    ratio = canvas_w / img_w
    target_size = (canvas_w, int(img_h * ratio))

    if 'result_image' not in state:
        state.result_image = source_image.resize(target_size)

    # Display only: the component's return value is never inspected.
    st_canvas(
        fill_color="rgba(255, 255, 0, 0.8)",
        background_image=source_image,
        drawing_mode='freedraw',
        width=canvas_w,
        height=img_h * ratio,
        point_display_radius=radius_width,
        stroke_width=2,
        update_streamlit=False,
        key="everything",
    )

    if not show_mask:
        # First time through the flag is created as True (one extra rerun);
        # on that rerun it is switched to False and then left alone.
        if 'rerun_once' not in state:
            state.rerun_once = True
        elif state.rerun_once:
            state.rerun_once = False

        state.display_result = True
        st.warning("Mask view is disabled", icon="❗")
        if state.rerun_once:
            st.experimental_rerun()
        else:
            return np.asarray(source_image)

    if not automask_model:
        return np.asarray(source_image)

    pixels = np.asarray(source_image)
    masks = utils.model_predict_masks_everything(automask_model, pixels)
    overlay = Image.fromarray(utils.show_everything(masks)).convert('RGBA')
    base = Image.fromarray(pixels).convert('RGBA')
    blended = Image.alpha_composite(base, overlay).convert("RGB")
    blended = blended.resize(target_size)
    state.display_result = True
    return blended
|
|
|
|
|
def image_preprocess_callback(predictor_model, option):
    """React to a change of the ``uploaded_image`` file-uploader widget.

    With a file present, it is opened as an RGB image and stored in
    ``st.session_state.image``. The predictor gets the image via
    ``set_image`` unless ``option`` is ``'Everything'``, in which case the
    predictor is reset instead. With the upload removed, the session entries
    ``display_result``, ``image`` and ``result_image`` are cleared and the
    predictor is reset.

    Args:
        predictor_model: Predictor exposing ``set_image`` and ``reset_image``;
            a falsy value skips all predictor calls.
        option: Segmentation mode selected when the widget was created.
    """
    state = st.session_state
    if 'uploaded_image' not in state:
        return

    upload = state.uploaded_image

    if upload is None:
        # The user removed the file: drop everything derived from it.
        with st.spinner(text="Cleaning up!"):
            if 'display_result' in state:
                state.display_result = False
            if 'image' in state:
                state.image = None
            if 'result_image' in state:
                del state['result_image']
            if predictor_model:
                predictor_model.reset_image()
        return

    with st.spinner(text="Uploading image..."):
        rgb_image = Image.open(upload).convert("RGB")
        if predictor_model:
            if option != 'Everything':
                pixels = np.asanyarray(rgb_image)
                with st.spinner(text="Extracing embeddings.."):
                    predictor_model.set_image(pixels)
            else:
                # "Everything" mode does not use the prompt predictor.
                predictor_model.reset_image()

        state.image = rgb_image
|
|
|
def main():
    """Lay out the page and run the selected segmentation mode.

    Renders the static header from ``index.html``, the architecture figure,
    the three controls (mode, mask toggle, radius/width), the file uploader,
    and two columns: the input canvas on the left and the result on the
    right. When no image entry exists in the session state, the Streamlit
    data cache is cleared instead.
    """
    with open('index.html', encoding='utf-8') as page:
        header_html = page.read()

    st.markdown(header_html, unsafe_allow_html=True)
    st.markdown('### Model Architecture')
    st.image('figures/medsam.png', caption="Segment Anything - MedSAM", width=600)
    st.markdown('### Demo')

    with st.container():
        mode_col, mask_col, size_col = st.columns(3)

        with mode_col:
            option = st.selectbox('Segmentation mode', ('Click', 'Box', 'Everything'))

        with mask_col:
            st.write("Show or Hide Mask")
            show_mask = st.checkbox('Show mask', value=True)

        with size_col:
            radius_width = st.slider('Radius/Width for Click/Box', 0, 20, 5, 1)

    with st.container():
        st.write("Upload Image")
        st.file_uploader(
            label='Upload image',
            type=['png', 'jpg', 'tif'],
            key='uploaded_image',
            on_change=image_preprocess_callback,
            args=(PREDICTOR_MODEL, option),
            label_visibility="hidden",
        )

    result_image = None
    canvas_input, canvas_output = st.columns(2)

    if 'image' not in st.session_state:
        # Nothing has been uploaded in this session.
        st.cache_data.clear()
        return

    with canvas_input:
        st.write("Select Interest Area/Objects")
        if st.session_state.image is not None:
            with st.spinner(text="Computing masks"):
                # Mode -> (handler, model); any other mode means "Everything".
                handlers = {
                    'Click': (process_click, PREDICTOR_MODEL),
                    'Box': (process_box, PREDICTOR_MODEL),
                }
                handler, model = handlers.get(option, (process_everything, AUTOMASK_MODEL))
                result_image = handler(model, show_mask, radius_width)

    if 'display_result' in st.session_state and st.session_state.display_result:
        with canvas_output:
            if result_image is None:
                st.warning("No result found, please set input prompt", icon="⚠️")
            else:
                st.write("Result")
                st.image(result_image)
        st.success('Process completed!', icon="✅")
|
|
|
|
|
# Build the page only when this file is the script being executed.
if __name__ == "__main__":
    main()