# Author: haritsahm
# Last commit: ef9685e — "minor fix to variable name" (file size at that commit: 10.9 kB)
import streamlit as st
# Streamlit requires set_page_config to run before any other st.* command,
# which is why it is called here, between the imports.
st.set_page_config(layout="wide")
import random
import numpy as np
import pandas as pd
from PIL import Image
from streamlit_drawable_canvas import st_canvas
from utils import utils
# Load the MedSAM ViT-B checkpoint. The first model serves the 'Click'/'Box'
# prompt modes; the second serves the 'Everything' mode (see main()).
# NOTE(review): Streamlit re-executes this module on every interaction. This
# assumes utils.get_model caches its result (e.g. st.cache_resource) so the
# same predictor object -- and the embeddings set on it in the upload
# callback -- survives reruns. Confirm in utils.
PREDICTOR_MODEL, AUTOMASK_MODEL = utils.get_model(checkpoint='checkpoint/medsam_vit_b.pth')
# Width (px) of every drawing canvas; uploaded images are scaled to fit it.
_CANVAS_WIDTH = 700
# Fill colour used for click prompts drawn on the point canvas.
_POINT_FILL = "rgba(255, 255, 0, 0.8)"
# Point fill colours that are treated as foreground (label 1) prompts.
_FOREGROUND_FILLS = ("rgba(0, 255, 0, 0.8)", _POINT_FILL)


def _canvas_geometry(bg_image):
    """Return ``(scale, (width, height))`` that fits *bg_image* to the canvas.

    ``scale`` converts original-image pixels to canvas pixels. The canvas
    height is truncated to an int.
    """
    width, height = bg_image.size[:2]
    scale = _CANVAS_WIDTH / width
    return scale, (_CANVAS_WIDTH, int(height * scale))


def _seed_result_image(bg_image, scaled_wh):
    """Store a canvas-sized copy of *bg_image* as ``result_image`` if absent."""
    if 'result_image' not in st.session_state:
        st.session_state.result_image = bg_image.resize(scaled_wh)


def _ensure_embeddings(predictor_model, bg_image):
    """Compute the predictor's image embeddings if they are not set yet.

    This matters when the image was uploaded in 'Everything' mode (the upload
    callback then resets the predictor) and the user later switches to a
    prompt mode.
    """
    if not predictor_model:
        return
    # The previous test, ``not predictor_model.set_image``, checked a bound
    # method. That is always truthy, so embeddings were never computed here.
    # NOTE(review): ``is_image_set`` is the flag exposed by segment_anything's
    # SamPredictor; this assumes utils.get_model returns one. If the attribute
    # is missing, the old "never recompute" behaviour is kept.
    if getattr(predictor_model, 'is_image_set', True):
        return
    with st.spinner(text="Extracting embeddings.."):
        predictor_model.set_image(np.asanyarray(bg_image))


def _rerun():
    """Rerun the script via ``st.rerun``, or ``st.experimental_rerun`` on older Streamlit."""
    rerun = getattr(st, 'rerun', None) or st.experimental_rerun
    rerun()


def _show_unmasked(bg_image):
    """Handle the 'Show mask' checkbox being off.

    Shows a warning and returns the plain image as an array. The first time
    this happens in a session, ``display_result`` is switched on and the
    script is rerun once. The rerun halts the current run, so this call does
    not return in that case.
    """
    # NOTE(review): the original file lost its indentation. This nesting (the
    # ``else`` belongs to the membership test) is the only one under which
    # ``rerun_once`` is always initialised before being read below. Confirm.
    if 'rerun_once' in st.session_state:
        if st.session_state.rerun_once:
            st.session_state.rerun_once = False
    else:
        st.session_state.rerun_once = True
        st.session_state.display_result = True
    st.warning("Mask view is disabled", icon="❗")
    if st.session_state.rerun_once:
        _rerun()
    return np.asarray(bg_image)


def _canvas_objects(canvas):
    """Return the canvas's drawn objects as a DataFrame, or None if nothing is drawn."""
    if canvas.json_data is None:
        return None
    objects = canvas.json_data["objects"]
    if not objects:
        # Nothing drawn yet: skip the model call with empty prompt lists.
        return None
    return pd.json_normalize(objects)


def _composite(np_image, im_masked, scaled_wh):
    """Alpha-composite the overlay *im_masked* onto *np_image*, resized to the canvas."""
    base = Image.fromarray(np_image).convert('RGBA')
    overlay = Image.fromarray(im_masked).convert('RGBA')
    return Image.alpha_composite(base, overlay).convert("RGB").resize(scaled_wh)


def _render_prompt_masks(bg_image, masks, scaled_wh):
    """Overlay prompt-mode *masks* on *bg_image* in a randomly chosen colour.

    If *masks* is empty, shows a warning and returns the plain image as an
    array. Otherwise returns the composited PIL image at canvas size and
    sets ``display_result``.
    """
    if len(masks) == 0:
        st.warning("No Masks Found", icon="❗")
        return np.asarray(bg_image)
    # A colour picked from utils.get_color() with 0.6 appended
    # (presumably the alpha of an RGBA colour -- confirm in utils.show_click).
    color = np.concatenate([random.choice(utils.get_color()), np.array([0.6])], axis=0)
    im_masked = utils.show_click(masks, color)
    result_image = _composite(np.asarray(bg_image), im_masked, scaled_wh)
    st.session_state.display_result = True
    return result_image


def process_box(predictor_model, show_mask, radius_width):
    """Render the box-prompt canvas and segment the drawn rectangles.

    Args:
        predictor_model: Predictor passed to ``utils.model_predict_masks_box``;
            a falsy value disables prediction.
        show_mask: When False, no mask overlay is produced.
        radius_width: Stroke width of the drawn rectangles.

    Returns:
        The composited result (PIL image at canvas size) when masks are found.
        Otherwise, the plain image as a numpy array at its original size.
    """
    bg_image = st.session_state['image']
    scale, scaled_wh = _canvas_geometry(bg_image)
    _ensure_embeddings(predictor_model, bg_image)
    _seed_result_image(bg_image, scaled_wh)
    box_canvas = st_canvas(
        fill_color="rgba(255, 255, 0, 0)",
        background_image=bg_image,
        drawing_mode='rect',
        stroke_color="rgba(0, 255, 0, 0.6)",
        stroke_width=radius_width,
        width=scaled_wh[0],
        height=scaled_wh[1],
        point_display_radius=12,
        update_streamlit=True,
        key="box",
    )
    if not show_mask:
        return _show_unmasked(bg_image)
    df = _canvas_objects(box_canvas)
    if df is None:
        return np.asarray(bg_image)

    center_point, center_label, input_box = [], [], []
    for _, row in df.iterrows():
        # Map the rectangle from canvas pixels back to original-image pixels.
        x = int(row["left"] / scale)
        y = int(row["top"] / scale)
        w = int(row["width"] / scale)
        h = int(row["height"] / scale)
        center_point.append([x + w / 2, y + h / 2])
        center_label.append([1])
        input_box.append([x, y, x + w, y + h])

    masks = []
    if predictor_model:
        masks = utils.model_predict_masks_box(predictor_model, center_point, center_label, input_box)
    return _render_prompt_masks(bg_image, masks, scaled_wh)


def process_click(predictor_model, show_mask, radius_width):
    """Render the point-prompt canvas and segment around the clicked points.

    Args:
        predictor_model: Predictor passed to ``utils.model_predict_masks_click``;
            a falsy value disables prediction.
        show_mask: When False, no mask overlay is produced.
        radius_width: Display radius of the drawn points.

    Returns:
        The composited result (PIL image at canvas size) when masks are found.
        Otherwise, the plain image as a numpy array at its original size.
    """
    bg_image = st.session_state['image']
    scale, scaled_wh = _canvas_geometry(bg_image)
    _ensure_embeddings(predictor_model, bg_image)
    _seed_result_image(bg_image, scaled_wh)
    click_canvas = st_canvas(
        fill_color=_POINT_FILL,
        background_image=bg_image,
        drawing_mode='point',
        width=scaled_wh[0],
        height=scaled_wh[1],
        point_display_radius=radius_width,
        stroke_width=2,
        update_streamlit=True,
        key="point",
    )
    if not show_mask:
        return _show_unmasked(bg_image)
    df = _canvas_objects(click_canvas)
    if df is None:
        return np.asarray(bg_image)

    input_points, input_labels = [], []
    for _, row in df.iterrows():
        # The clicked location is the centre of the marker's bounding box,
        # mapped from canvas pixels back to original-image pixels.
        cx = int(row["left"] + row["width"] / 2)
        cy = int(row["top"] + row["height"] / 2)
        input_points.append([int(cx / scale), int(cy / scale)])
        # Every point is drawn with _POINT_FILL. The old test against green
        # alone therefore labelled every click as background (0).
        input_labels.append(1 if row['fill'] in _FOREGROUND_FILLS else 0)

    masks = []
    if predictor_model:
        masks = utils.model_predict_masks_click(predictor_model, input_points, input_labels)
    return _render_prompt_masks(bg_image, masks, scaled_wh)


def process_everything(automask_model, show_mask, radius_width):
    """Render the image and overlay masks from the automatic ('Everything') model.

    Args:
        automask_model: Model passed to ``utils.model_predict_masks_everything``;
            a falsy value disables prediction.
        show_mask: When False, no mask overlay is produced.
        radius_width: Forwarded to the (display-only) canvas.

    Returns:
        The composited result (PIL image at canvas size), or the plain image as
        a numpy array at its original size.
    """
    bg_image = st.session_state['image']
    _, scaled_wh = _canvas_geometry(bg_image)
    _seed_result_image(bg_image, scaled_wh)
    # Display-only canvas: its drawings are never read back in this mode.
    st_canvas(
        fill_color="rgba(255, 255, 0, 0.8)",
        background_image=bg_image,
        drawing_mode='freedraw',
        width=scaled_wh[0],
        height=scaled_wh[1],
        point_display_radius=radius_width,
        stroke_width=2,
        update_streamlit=False,
        key="everything",
    )
    if not show_mask:
        return _show_unmasked(bg_image)
    if not automask_model:
        return np.asarray(bg_image)
    np_image = np.asarray(bg_image)
    masks = utils.model_predict_masks_everything(automask_model, np_image)
    result_image = _composite(np_image, utils.show_everything(masks), scaled_wh)
    st.session_state.display_result = True
    return result_image
def image_preprocess_callback(predictor_model, option):
    """Load or clear the working image when the file uploader changes.

    Registered as the uploader's ``on_change`` callback.

    On upload, the file is decoded to RGB and stored in
    ``st.session_state.image``. For every mode except ``'Everything'``, the
    predictor's embeddings are computed immediately; otherwise the predictor
    is reset. When the upload is removed, the image and derived display
    state are cleared.

    Args:
        predictor_model: Prompt predictor; a falsy value skips embedding work.
        option: Segmentation mode selected when the uploader was last rendered.
    """
    if 'uploaded_image' not in st.session_state:
        return
    uploaded = st.session_state.uploaded_image
    if uploaded is not None:
        with st.spinner(text="Uploading image..."):
            try:
                image = Image.open(uploaded).convert("RGB")
            except OSError as err:
                # Pillow's UnidentifiedImageError (and truncated-file errors)
                # subclass OSError. Report unreadable files instead of
                # crashing the callback; the previous image state is kept.
                st.error(f"Could not read the uploaded image: {err}", icon="❗")
                return
            if predictor_model and option != 'Everything':
                with st.spinner(text="Extracting embeddings.."):
                    predictor_model.set_image(np.asanyarray(image))
            elif predictor_model:
                predictor_model.reset_image()
            # Drop the previous image's preview so the processing functions
            # re-seed it from the new image.
            if 'result_image' in st.session_state:
                del st.session_state['result_image']
            st.session_state.image = image
        return

    with st.spinner(text="Cleaning up!"):
        if 'display_result' in st.session_state:
            st.session_state.display_result = False
        if 'image' in st.session_state:
            st.session_state.image = None
        if 'result_image' in st.session_state:
            del st.session_state['result_image']
        if predictor_model:
            predictor_model.reset_image()
def main():
    """Build the MedSAM demo page and run the selected segmentation mode.

    The page is rendered in this order:
    1. The header HTML from ``index.html`` and the architecture figure.
    2. The mode, mask-visibility and radius controls.
    3. The image uploader, whose ``on_change`` callback loads the image.
    4. The input canvas and the result panel, side by side.
    """
    with open('index.html', encoding='utf-8') as f:
        html_content = f.read()
    st.markdown(html_content, unsafe_allow_html=True)
    st.markdown('### Model Architecture')
    st.image('figures/medsam.png', caption="Segment Anything - MedSAM", width=600)
    st.markdown('### Demo')

    with st.container():
        col1, col2, col3 = st.columns(3)
        with col1:
            option = st.selectbox('Segmentation mode', ('Click', 'Box', 'Everything'))
        with col2:
            st.write("Show or Hide Mask")
            show_mask = st.checkbox('Show mask', value=True)
        with col3:
            radius_width = st.slider('Radius/Width for Click/Box', 0, 20, 5, 1)

    with st.container():
        st.write("Upload Image")
        # ``option`` is bound when the uploader is rendered, so the callback
        # receives the mode selected in the run that rendered it.
        st.file_uploader(
            label='Upload image',
            type=['png', 'jpg', 'tif'],
            key='uploaded_image',
            on_change=image_preprocess_callback,
            args=(PREDICTOR_MODEL, option,),
            label_visibility="hidden",
        )

    result_image = None
    canvas_input, canvas_output = st.columns(2)
    # NOTE(review): the original file lost its indentation. The
    # ``st.cache_data.clear()`` fallback is reconstructed as the else-branch
    # of this check; confirm against the upstream source.
    if 'image' not in st.session_state:
        st.cache_data.clear()
        return

    with canvas_input:
        st.write("Select Interest Area/Objects")
        if st.session_state.image is not None:
            with st.spinner(text="Computing masks"):
                if option == 'Click':
                    result_image = process_click(PREDICTOR_MODEL, show_mask, radius_width)
                elif option == 'Box':
                    result_image = process_box(PREDICTOR_MODEL, show_mask, radius_width)
                else:
                    result_image = process_everything(AUTOMASK_MODEL, show_mask, radius_width)

    if not st.session_state.get('display_result', False):
        return
    with canvas_output:
        if result_image is None:
            st.warning("No result found, please set input prompt", icon="⚠️")
            # Do not claim success when nothing was produced.
            return
        st.write("Result")
        st.image(result_image)
    st.success('Process completed!', icon="✅")
# Entry point: render the app when this file is executed as a script.
if __name__ == '__main__':
    main()