# haritsahm — commit 13c0669 "Add run files" (4.77 kB)
import streamlit as st
st.set_page_config(layout="wide")
import random
import numpy as np
import pandas as pd
from PIL import Image
from streamlit_drawable_canvas import st_canvas
from utils import utils
# Single model instance shared by the upload callback (which calls
# model.set_image / model.reset_image on it) and by click_process (which
# predicts masks with it). 'vit_b' presumably selects the SAM ViT-B backbone.
# NOTE(review): Streamlit re-runs the whole script on every interaction; unless
# utils.get_model caches its result (e.g. via st.cache_resource), this line
# reloads the model each run and drops the embedding set by the callback —
# confirm in utils.
SAM_MODEL = utils.get_model('vit_b')
def _canvas_points_to_prompts(json_data, scale, positive_fills):
    """Convert the objects drawn on the canvas into click prompts for the model.

    Args:
        json_data: ``st_canvas`` result JSON; only its ``"objects"`` list is read.
        scale: Factor by which the image was scaled to draw the canvas.
        positive_fills: Fill colours whose points are labelled 1 (foreground);
            every other point is labelled 0.

    Returns:
        ``(points, labels)`` where ``points`` is a list of ``[x, y]`` integer
        coordinates in original-image pixels and ``labels`` has one 0/1 entry
        per point.
    """
    objects = pd.json_normalize(json_data["objects"])
    points, labels = [], []
    for _, obj in objects.iterrows():
        # Centre of the object's left/top/width/height box on the canvas ...
        cx = int(obj["left"] + obj["width"] / 2)
        cy = int(obj["top"] + obj["height"] / 2)
        # ... mapped back to original-image pixel coordinates.
        points.append([int(cx / scale), int(cy / scale)])
        labels.append(1 if obj["fill"] in positive_fills else 0)
    return points, labels


def click_process(model, show_mask, radius_width):
    """Render the click canvas and return the image to display as the result.

    The uploaded image (``st.session_state['image']``) is the background of a
    point-drawing canvas scaled to a fixed width of 700 px. Drawn points are
    mapped back to original-image coordinates and passed to
    ``utils.model_predict_masks_click``; the returned masks are overlaid on the
    image in a randomly chosen colour.

    Args:
        model: Predictor handed to ``utils.model_predict_masks_click``; when
            falsy no prediction is run.
        show_mask: When False, the untouched background image is returned.
        radius_width: Display radius of the drawn points, in canvas pixels.

    Returns:
        The masked image as a PIL image resized to the canvas size, or the
        unmodified background (PIL image or NumPy array) when there is no mask
        to overlay.
    """
    # Fill colour of the points drawn on the canvas. Points carrying this fill
    # (or green, kept for compatibility) are foreground clicks. Previously only
    # green counted as foreground, which the canvas never draws, so every click
    # was labelled background.
    point_fill = "rgba(255, 255, 0, 0.8)"
    positive_fills = (point_fill, "rgba(0, 255, 0, 0.8)")

    bg_image = st.session_state['image']
    width, height = bg_image.size
    container_width = 700
    scale = container_width / width
    scaled_hw = (container_width, int(height * scale))
    if 'result_image' not in st.session_state:
        st.session_state.result_image = bg_image.resize(scaled_hw)

    canvas_result = st_canvas(
        fill_color=point_fill,
        background_image=bg_image,
        drawing_mode='point',
        width=scaled_hw[0],
        height=scaled_hw[1],  # integer height, same size the result is resized to
        point_display_radius=radius_width,
        stroke_width=2,
        update_streamlit=True,
        key="point",
    )

    # Hiding the mask just means showing the plain image. The old code called
    # st.experimental_rerun() here, which re-ran the script with the box still
    # unchecked and so looped forever.
    if not show_mask or canvas_result.json_data is None:
        return np.asarray(bg_image)

    input_points, input_labels = _canvas_points_to_prompts(
        canvas_result.json_data, scale, positive_fills)
    # Nothing to predict without a model or without any clicks.
    if not model or not input_points:
        return bg_image

    masks = utils.model_predict_masks_click(model, input_points, input_labels)
    # len() rather than truthiness: masks may be an array, whose truth value is ambiguous.
    if len(masks) == 0:
        return bg_image

    # Randomly picked palette colour with an alpha of 0.6 appended.
    color = np.concatenate([random.choice(utils.get_color()), np.array([0.6])], axis=0)
    overlay = Image.fromarray(utils.show_click(masks, color)).convert('RGBA')
    result_image = Image.alpha_composite(bg_image.convert('RGBA'), overlay).convert("RGB")
    return result_image.resize(scaled_hw)
def image_preprocess_callback(model):
    """``on_change`` handler for the image uploader.

    When a file is present, it is decoded as RGB. If a model is given, the
    image embedding is computed with ``model.set_image``. The image is then
    stored in ``st.session_state.image``. When the file is removed, the stored
    image is set to None, ``result_image`` is dropped, and
    ``model.reset_image()`` is called.

    Args:
        model: Predictor exposing ``set_image`` / ``reset_image``; skipped when falsy.
    """
    if 'uploaded_image' not in st.session_state:
        return

    uploaded = st.session_state.uploaded_image
    if uploaded is not None:
        with st.spinner(text="Uploading image..."):
            try:
                image = Image.open(uploaded).convert("RGB")
            except OSError:
                # Covers PIL.UnidentifiedImageError and truncated/corrupt files;
                # report instead of letting the callback crash the app.
                st.error("Could not read the uploaded file as an image.")
                return
            # A result cached for a previous image no longer applies.
            if 'result_image' in st.session_state:
                del st.session_state['result_image']
            if model:
                with st.spinner(text="Extracting embeddings.."):
                    model.set_image(np.asanyarray(image))
            st.session_state.image = image
    else:
        with st.spinner(text="Cleaning up!"):
            if 'image' in st.session_state:
                st.session_state.image = None
            if 'result_image' in st.session_state:
                del st.session_state['result_image']
            if model:
                model.reset_image()
def main():
    """Build the page: header HTML, controls, uploader, and the input/result columns."""
    with open('index.html', encoding='utf-8') as f:
        html_content = f.read()
    st.markdown(html_content, unsafe_allow_html=True)

    with st.container():
        col1, col2, col3 = st.columns(3)
        with col1:
            # The trailing comma is essential. ('Click') is just the string 'Click',
            # which selectbox would split into the options 'C', 'l', 'i', 'c', 'k'.
            # Then option would never equal 'Click' and no masks would be computed.
            option = st.selectbox('Segmentation mode', ('Click',))
        with col2:
            st.write("Show or Hide Mask")
            show_mask = st.checkbox('Show mask', value=True)
        with col3:
            radius_width = st.slider('Radius/Width for Click/Box', 0, 20, 5, 1)

    with st.container():
        st.write("Upload Image")
        st.file_uploader(
            label='Upload image',
            type=['png', 'jpg', 'tif'],
            key='uploaded_image',
            on_change=image_preprocess_callback,
            args=(SAM_MODEL,),
            label_visibility="hidden",
        )
        result_image = None
        canvas_input, canvas_output = st.columns(2)
        if 'image' in st.session_state:
            with canvas_input:
                st.write("Select Interest Area/Objects")
                # The image is None after the uploaded file has been removed.
                if st.session_state.image is not None and option == 'Click':
                    with st.spinner(text="Computing masks"):
                        result_image = click_process(SAM_MODEL, show_mask, radius_width)
            with canvas_output:
                if result_image is not None:
                    st.write("Result")
                    st.image(result_image)
        else:
            print(f'embedding is empty - {option} - {show_mask} - {radius_width}')
# Build the page only when this file is executed as the main script
# (presumably how `streamlit run` executes it — confirm for your launcher).
if __name__ == '__main__':
    main()