# Source: curt-park/interactive-segmentation on Hugging Face
# (commit 96d7d21, "Uncomment all lines"; file size 2.98 kB).
import streamlit as st
import torch
import numpy as np
import cv2
import wget
import os
from PIL import Image
from streamlit_drawable_canvas import st_canvas
from isegm.inference import clicker as ck
from isegm.inference import utils
from isegm.inference.predictors import get_predictor
# Model checkpoints are fetched from this Hugging Face repository. Each entry
# in `models` maps the name shown in the sidebar to its checkpoint filename,
# which is appended to `prefix` to form the download URL.
prefix = "https://huggingface.co/curt-park/interactive-segmentation/resolve/main"
models = dict(
    RITM="ritm_coco_lvis_h18_itermask.pth",
)
# Sidebar controls: which model to use, the threshold used to binarize the
# predicted map, the polarity of the next click, and the image to segment.
model = st.sidebar.selectbox("Select a Model:", options=tuple(models))
threshold = st.sidebar.slider(
    "Threshold: ", min_value=0.0, max_value=1.0, value=0.5
)
marking_type = st.sidebar.radio(
    "Marking Type:", options=("positive", "negative")
)
image_path = st.sidebar.file_uploader(
    "Background Image:", type=["png", "jpg", "jpeg"]
)
# Objects for prediction.
# A new Clicker is created on every script run; the clicks currently on the
# canvas are added to it further down in this script.
clicker = ck.Clicker()
device = torch.device("cpu")
# Local filename of the selected checkpoint. Kept separate from `model` so that
# `model` keeps meaning "the selected model name" instead of being rebound to
# the loaded network.
checkpoint_path = models[model]

# Download the checkpoint only if it is not already present locally.
# `out=` pins the destination to exactly the path tested by os.path.exists();
# without it wget derives the filename from the URL/response on its own.
with st.spinner("Wait for downloading a model..."):
    if not os.path.exists(checkpoint_path):
        wget.download(f"{prefix}/{checkpoint_path}", out=checkpoint_path)

# Load the network on the CPU and wrap it in a predictor.
with st.spinner("Wait for loading a model..."):
    is_model = utils.load_is_model(checkpoint_path, device, cpu_dist_maps=True)
    # NOTE(review): "NoBRS" presumably selects plain inference without the
    # backpropagating refinement scheme — confirm against isegm's get_predictor.
    predictor_params = {"brs_mode": "NoBRS"}
    predictor = get_predictor(is_model, device=device, **predictor_params)
# Create a canvas component.
# The uploaded file, if any, is shown as the canvas background.
image = Image.open(image_path) if image_path else None
canvas_height, canvas_width = 600, 600
pos_color, neg_color = "#3498DB", "#C70039"
# Each click is drawn in a colour that encodes its polarity, so the polarity
# can later be recovered from the object's stroke colour.
active_color = neg_color if marking_type != "positive" else pos_color
st.title("Canvas:")
canvas_result = st_canvas(
    key="canvas",
    drawing_mode="point",
    point_display_radius=3,
    stroke_width=3,
    stroke_color=active_color,
    fill_color="rgba(255, 165, 0, 0.3)",  # fixed fill colour, semi-transparent
    background_color="#eee",
    background_image=image,
    update_streamlit=True,
    width=canvas_width,
    height=canvas_height,
)
# Check the user inputs and execute predictions.
st.title("Prediction:")
if (
    canvas_result.json_data
    and canvas_result.json_data["objects"]
    and image is not None
):
    objects = canvas_result.json_data["objects"]
    image_width, image_height = image.size
    # Scale factors from canvas pixels to original-image pixels.
    ratio_h, ratio_w = image_height / canvas_height, image_width / canvas_width
    # Empirical offsets (in canvas pixels) added to each point's reported
    # left/top before scaling; presumably they compensate for the drawn
    # marker's bounding box — TODO confirm.
    err_x, err_y = 5.5, 1.0
    for obj in objects:
        x = (obj["left"] + err_x) * ratio_w
        y = (obj["top"] + err_y) * ratio_h
        # Clamp to valid pixel indices [0, size - 1]; `size` itself is one
        # past the last pixel.
        x = min(image_width - 1, max(0, x))
        y = min(image_height - 1, max(0, y))
        # Polarity is recovered from the stroke colour used when drawing.
        is_positive = obj["stroke"] == pos_color
        clicker.add_click(ck.Click(is_positive=is_positive, coords=(y, x)))

    # Prediction.
    # Uploaded PNGs may be RGBA, palette or greyscale, and np.array() would then
    # not produce an H x W x 3 array; the network presumably expects 3-channel
    # RGB input, so normalise the mode first.
    predictor.set_input_image(np.array(image.convert("RGB")))
    with st.spinner("Wait for prediction..."):
        pred = predictor.get_prediction(clicker, prev_mask=None)
    # cv2.resize expects dsize as (width, height).
    pred = cv2.resize(
        pred, dsize=(canvas_width, canvas_height), interpolation=cv2.INTER_CUBIC
    )
    # Binarize into a 0.0 / 1.0 mask for display.
    pred = np.where(pred > threshold, 1.0, 0)
    st.image(pred, caption="")