|
import streamlit as st |
|
|
|
# Page-wide layout setting. Streamlit requires set_page_config to be the first
# st.* command the script executes, which is presumably why it sits directly
# after the streamlit import and before the remaining imports.
st.set_page_config(layout="wide")
|
|
|
import random |
|
|
|
import numpy as np |
|
import pandas as pd |
|
from PIL import Image |
|
from streamlit_drawable_canvas import st_canvas |
|
from utils import utils |
|
|
|
# Checkpoint variant requested from utils.get_model. The returned object is
# later used through .set_image()/.reset_image() and passed to
# utils.model_predict_masks_click.
# NOTE(review): Streamlit re-executes this script on every interaction, so this
# call runs on every rerun unless utils.get_model caches its result — confirm.
_SAM_MODEL_TYPE = 'vit_b'
SAM_MODEL = utils.get_model(_SAM_MODEL_TYPE)
|
|
|
|
|
def _canvas_objects_to_prompts(objects, scale, positive_fills):
    """Turn drawable-canvas point objects into click prompts in original-image pixels.

    Args:
        objects: the ``objects`` list from ``st_canvas(...).json_data``; each
            entry is read for ``left``, ``top``, ``width``, ``height`` and ``fill``.
        scale: canvas pixels per original-image pixel (canvas width / image width).
        positive_fills: fill colours whose points are labelled 1 (foreground);
            every other point is labelled 0.

    Returns:
        A ``(points, labels)`` pair: ``points`` is a list of ``[x, y]`` ints in
        original-image coordinates and ``labels`` a parallel list of 0/1 ints.
    """
    points = []
    labels = []
    for obj in objects:
        # left/top/width/height describe the drawn point's bounding box on the
        # canvas; the click location is the box centre.
        centre_x = obj["left"] + obj["width"] / 2
        centre_y = obj["top"] + obj["height"] / 2
        # Map back to original-image pixels, truncating once at the end.
        points.append([int(centre_x / scale), int(centre_y / scale)])
        labels.append(1 if obj.get("fill") in positive_fills else 0)
    return points, labels


def click_process(model, show_mask, radius_width, container_width=700):
    """Render the click canvas for the uploaded image and return the preview image.

    The PIL image stored in ``st.session_state['image']`` is shown on a drawable
    canvas scaled to ``container_width`` pixels wide. Each clicked point is mapped
    back to original-image coordinates and passed to
    ``utils.model_predict_masks_click``. The predicted masks are blended over the
    image using a colour picked at random from ``utils.get_color()``.

    Args:
        model: predictor handed to ``utils.model_predict_masks_click``; when falsy,
            no prediction is made.
        show_mask: when False, the plain scaled image is returned and no
            prediction runs.
        radius_width: display radius of the clicked points on the canvas.
        container_width: canvas width in pixels; the image is scaled to fit it.

    Returns:
        PIL.Image.Image: RGB image resized to the canvas size, with the mask
        overlay applied when masks were predicted.
    """
    # Colour used to draw clicks. It must be one of the positive fills;
    # otherwise every click would be sent to the model as a background (0) point.
    point_fill = "rgba(255, 255, 0, 0.8)"
    positive_fills = {point_fill, "rgba(0, 255, 0, 0.8)"}

    bg_image = st.session_state['image']
    width, height = bg_image.size
    scale = container_width / width
    scaled_hw = (container_width, int(height * scale))
    plain_preview = bg_image.resize(scaled_hw)

    if 'result_image' not in st.session_state:
        st.session_state.result_image = plain_preview

    canvas_result = st_canvas(
        fill_color=point_fill,
        background_image=bg_image,
        drawing_mode='point',
        width=container_width,
        height=scaled_hw[1],  # integer pixel height matching the scaled preview
        point_display_radius=radius_width,
        stroke_width=2,
        update_streamlit=True,
        key="point",
    )

    # Mask hidden: just show the plain image. Calling st.experimental_rerun()
    # here would loop forever, because the rerun sees the same unchecked box.
    if not show_mask or canvas_result.json_data is None:
        return plain_preview

    points, labels = _canvas_objects_to_prompts(
        canvas_result.json_data.get("objects", []), scale, positive_fills
    )
    # No clicks yet (or no model): there is nothing to prompt the model with.
    if not points or not model:
        return plain_preview

    masks = utils.model_predict_masks_click(model, points, labels)
    if len(masks) == 0:
        return plain_preview

    # A random utils.get_color() entry with 0.6 appended — presumably RGB plus
    # alpha for the overlay (TODO confirm against utils.show_click).
    color = np.concatenate([random.choice(utils.get_color()), np.array([0.6])], axis=0)
    overlay = Image.fromarray(utils.show_click(masks, color)).convert('RGBA')
    blended = Image.alpha_composite(bg_image.convert('RGBA'), overlay).convert("RGB")
    return blended.resize(scaled_hw)
|
|
|
|
|
def image_preprocess_callback(model):
    """File-uploader ``on_change`` callback: load the upload and prime the model.

    On a new upload:
        - the file is decoded to an RGB PIL image;
        - ``model.set_image`` is called with it as a numpy array, when a model
          is given;
        - the image is stored in ``st.session_state.image``.

    When the upload is cleared:
        - the stored image is set to None;
        - ``model.reset_image()`` is called, when a model is given.

    In both cases any cached ``result_image`` is discarded.

    Args:
        model: predictor exposing ``set_image``/``reset_image``; when falsy,
            the model calls are skipped.
    """
    if 'uploaded_image' not in st.session_state:
        return

    uploaded = st.session_state.uploaded_image

    # A cached preview belongs to the previous image (it was resized to that
    # image's dimensions). Drop it whether a new file arrived or the upload was
    # cleared.
    if 'result_image' in st.session_state:
        del st.session_state['result_image']

    if uploaded is None:
        with st.spinner(text="Cleaning up!"):
            if 'image' in st.session_state:
                st.session_state.image = None
            if model:
                model.reset_image()
        return

    with st.spinner(text="Uploading image..."):
        image = Image.open(uploaded).convert("RGB")
        if model:
            with st.spinner(text="Extracting embeddings.."):
                model.set_image(np.asarray(image))
        st.session_state.image = image
|
|
|
def main():
    """Build the page.

    Renders, in order:
        - the header HTML from ``index.html``;
        - the mode / mask / radius controls;
        - the image uploader;
        - the input canvas and result columns.
    """
    with open('index.html', encoding='utf-8') as f:
        html_content = f.read()

    st.markdown(html_content, unsafe_allow_html=True)

    with st.container():
        col1, col2, col3 = st.columns(3)

        with col1:
            # The trailing comma matters: ('Click') is just the string 'Click',
            # not a one-option tuple. Its characters would become the options,
            # so option == 'Click' could never be true.
            option = st.selectbox('Segmentation mode', ('Click',))

        with col2:
            st.write("Show or Hide Mask")
            show_mask = st.checkbox('Show mask', value=True)

        with col3:
            radius_width = st.slider('Radius/Width for Click/Box', 0, 20, 5, 1)

    with st.container():
        st.write("Upload Image")
        st.file_uploader(
            label='Upload image',
            # Accept both common spellings of the JPEG and TIFF extensions.
            type=['png', 'jpg', 'jpeg', 'tif', 'tiff'],
            key='uploaded_image',
            on_change=image_preprocess_callback,
            args=(SAM_MODEL,),
            label_visibility="hidden",
        )

    result_image = None
    canvas_input, canvas_output = st.columns(2)
    image = st.session_state.get('image')

    if 'image' in st.session_state:
        with canvas_input:
            st.write("Select Interest Area/Objects")
            if image is not None and option == 'Click':
                with st.spinner(text="Computing masks"):
                    result_image = click_process(SAM_MODEL, show_mask, radius_width)

        with canvas_output:
            if result_image is not None:
                st.write("Result")
                st.image(result_image)

    # Debug trace whenever no uploaded image is available (never uploaded, or
    # cleared by the upload callback).
    if image is None:
        print(f'embedding is empty - {option} - {show_mask} - {radius_width}')
|
|
|
|
|
|
|
|
|
|
|
# Script entry point. Streamlit's script runner executes the file as
# `__main__`, so this guard also fires under `streamlit run`. That is
# presumably the intended launch path, given the st.* calls above.
if __name__ == '__main__':
    main()