import streamlit as st

# set_page_config must be the first Streamlit call in the script, which is why
# it sits between the imports instead of after them.
st.set_page_config(layout="wide")

import random

import numpy as np
import pandas as pd
from PIL import Image
from streamlit_drawable_canvas import st_canvas

from utils import utils

# NOTE(review): Streamlit re-executes this module on every interaction. This
# assumes utils.get_model caches/reuses the model so the embeddings set in
# image_preprocess_callback survive the rerun — TODO confirm.
SAM_MODEL = utils.get_model('vit_b')

# Width (in px) of the input canvas; the result image is scaled to match.
CANVAS_WIDTH = 700

# Fill colour used to draw clicked points on the canvas. Points carrying this
# fill are sent to the model with label 1, anything else with label 0. Using
# one constant for both drawing and classifying keeps the two in sync.
FOREGROUND_FILL = "rgba(255, 255, 0, 0.8)"


def _parse_click_points(objects, scale):
    """Convert canvas point objects into original-image points and labels.

    Each object's centre is taken from its bounding box (``left``/``top``
    plus half of ``width``/``height``) and mapped from canvas pixels back to
    original-image pixels by dividing by ``scale``.

    Args:
        objects: the ``objects`` list from ``st_canvas(...).json_data``.
        scale: ratio of canvas width to original image width.

    Returns:
        ``(input_points, input_labels)``: a list of ``[x, y]`` ints and a
        list of labels, ``1`` for points drawn with ``FOREGROUND_FILL`` and
        ``0`` otherwise (presumably SAM's foreground/background point labels
        — confirm against ``utils.model_predict_masks_click``).
    """
    input_points = []
    input_labels = []
    df = pd.json_normalize(objects)
    for _, row in df.iterrows():
        x = int(row["left"] + row["width"] / 2)
        y = int(row["top"] + row["height"] / 2)
        input_points.append([int(x / scale), int(y / scale)])
        input_labels.append(1 if row["fill"] == FOREGROUND_FILL else 0)
    return input_points, input_labels


def click_process(model, show_mask, radius_width):
    """Render the click canvas and return the image to display as the result.

    Args:
        model: segmentation model passed to ``utils.model_predict_masks_click``;
            when falsy, no masks are predicted.
        show_mask: when False the canvas is still drawn, but the image is
            returned without any mask overlay.
        radius_width: display radius of clicked points on the canvas.

    Returns:
        A PIL RGB image scaled to ``CANVAS_WIDTH`` wide. It has the predicted
        mask composited on top, or is the plain uploaded image when masks are
        hidden, nothing was clicked, or the model produced no masks.
    """
    bg_image = st.session_state['image']
    width, height = bg_image.size[:2]
    scale = CANVAS_WIDTH / width
    scaled_hw = (CANVAS_WIDTH, int(height * scale))
    plain_image = bg_image.resize(scaled_hw)
    if 'result_image' not in st.session_state:
        st.session_state.result_image = plain_image

    canvas_result = st_canvas(
        fill_color=FOREGROUND_FILL,
        background_image=bg_image,
        drawing_mode='point',
        width=CANVAS_WIDTH,
        height=scaled_hw[1],
        point_display_radius=radius_width,
        stroke_width=2,
        update_streamlit=True,
        key="point",
    )

    # Hiding the mask just shows the plain image. Do not call
    # st.experimental_rerun() here: an unconditional rerun never terminates.
    if not show_mask or canvas_result.json_data is None:
        return plain_image

    input_points, input_labels = _parse_click_points(
        canvas_result.json_data["objects"], scale)
    # Nothing clicked (or no model): there is no prompt to segment.
    if not input_points or not model:
        return plain_image

    masks = utils.model_predict_masks_click(model, input_points, input_labels)
    if len(masks) == 0:
        return plain_image

    # Random palette colour plus a 0.6 alpha channel for the overlay.
    color = np.concatenate(
        [random.choice(utils.get_color()), np.array([0.6])], axis=0)
    mask_layer = Image.fromarray(utils.show_click(masks, color)).convert('RGBA')
    composited = Image.alpha_composite(bg_image.convert('RGBA'), mask_layer)
    return composited.convert("RGB").resize(scaled_hw)


def image_preprocess_callback(model):
    """``on_change`` handler for the uploader: load the image and embeddings.

    When a file is uploaded, it is decoded as RGB and stored in
    ``st.session_state.image``; if ``model`` is truthy, the pixels are also
    passed to ``model.set_image``. When the uploader is cleared,
    ``st.session_state.image`` is set to None and, if ``model`` is truthy,
    ``model.reset_image()`` is called. In both cases any cached
    ``result_image`` is discarded, since it belonged to the previous image.
    """
    if 'uploaded_image' not in st.session_state:
        return
    if 'result_image' in st.session_state:
        del st.session_state['result_image']

    if st.session_state.uploaded_image is not None:
        with st.spinner(text="Uploading image..."):
            image = Image.open(st.session_state.uploaded_image).convert("RGB")
            if model:
                with st.spinner(text="Extracting embeddings.."):
                    model.set_image(np.asarray(image))
            st.session_state.image = image
    else:
        with st.spinner(text="Cleaning up!"):
            if 'image' in st.session_state:
                st.session_state.image = None
            if model:
                model.reset_image()


def main():
    """Build the page: header HTML, controls, uploader, canvas and result."""
    with open('index.html', encoding='utf-8') as f:
        html_content = f.read()
    st.markdown(html_content, unsafe_allow_html=True)

    with st.container():
        col1, col2, col3 = st.columns(3)
        with col1:
            # The trailing comma matters: ('Click') is the plain string
            # 'Click', whose characters would become the options.
            option = st.selectbox('Segmentation mode', ('Click',))
        with col2:
            st.write("Show or Hide Mask")
            show_mask = st.checkbox('Show mask', value=True)
        with col3:
            radius_width = st.slider('Radius/Width for Click/Box', 0, 20, 5, 1)

    with st.container():
        st.write("Upload Image")
        st.file_uploader(
            label='Upload image',
            type=['png', 'jpg', 'tif'],
            key='uploaded_image',
            on_change=image_preprocess_callback,
            args=(SAM_MODEL,),
            label_visibility="hidden",
        )

    result_image = None
    canvas_input, canvas_output = st.columns(2)
    if 'image' in st.session_state:
        with canvas_input:
            st.write("Select Interest Area/Objects")
            if st.session_state.image is not None:
                if option == 'Click':
                    with st.spinner(text="Computing masks"):
                        result_image = click_process(
                            SAM_MODEL, show_mask, radius_width)
            else:
                print(f'embedding is empty - {option} - {show_mask} - {radius_width}')
        with canvas_output:
            if result_image is not None:
                st.write("Result")
                st.image(result_image)


if __name__ == '__main__':
    main()