# Author: curt-park — commit 98ddf8e ("Comment most of lines for debugging"), 3.03 kB.
import streamlit as st
import torch
# import numpy as np
import cv2
import wget
import os
from PIL import Image
from streamlit_drawable_canvas import st_canvas
from isegm.inference import clicker as ck
from isegm.inference import utils
from isegm.inference.predictors import get_predictor
# Model Path
prefix = "https://huggingface.co/curt-park/interactive-segmentation/resolve/main"
# Display name -> checkpoint file name (downloaded from `prefix` into the CWD).
models = {
    "RITM": "ritm_coco_lvis_h18_itermask.pth",
}

# Items in the sidebar.
# The selected key gets its own name so that `model` (below) only ever refers
# to the loaded network, never to the string chosen in the selectbox.
model_name = st.sidebar.selectbox("Select a Model:", tuple(models.keys()))
threshold = st.sidebar.slider("Threshold: ", 0.0, 1.0, 0.5)
marking_type = st.sidebar.radio("Marking Type:", ("positive", "negative"))
image_path = st.sidebar.file_uploader("Background Image:", type=["png", "jpg", "jpeg"])

# Objects for prediction.
clicker = ck.Clicker()
device = torch.device("cpu")
model_file = models[model_name]

# Fetch the checkpoint only when it is not already present locally.
if not os.path.exists(model_file):
    with st.spinner("Wait for downloading a model..."):
        try:
            # out=model_file: wget may otherwise pick the saved name from the
            # response headers, which would defeat the os.path.exists check.
            # bar=None: skip wget's console progress bar (server-log noise);
            # the spinner is the user-facing progress indicator.
            wget.download(f"{prefix}/{model_file}", out=model_file, bar=None)
        except OSError as err:  # urllib's URLError/HTTPError derive from OSError
            st.error(f"Failed to download {model_file}: {err}")
            st.stop()

# NOTE(review): Streamlit re-executes the whole script on each widget
# interaction, so this load runs on every rerun; consider a resource cache
# if the pinned Streamlit version provides one.
with st.spinner("Wait for loading a model..."):
    model = utils.load_is_model(model_file, device, cpu_dist_maps=True)
    predictor_params = {"brs_mode": "NoBRS"}
    predictor = get_predictor(model, device=device, **predictor_params)
# Create a canvas component.
#image = None
#if image_path:
# image = Image.open(image_path)
#canvas_height, canvas_width = 600, 600
#pos_color, neg_color = "#3498DB", "#C70039"
#st.title("Canvas:")
#canvas_result = st_canvas(
# fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity
# stroke_width=3,
# stroke_color=pos_color if marking_type == "positive" else neg_color,
# background_color="#eee",
# background_image=image,
# update_streamlit=True,
# drawing_mode="point",
# point_display_radius=3,
# key="canvas",
# width=canvas_width,
# height=canvas_height,
#)
# Check the user inputs and execute predictions.
#st.title("Prediction:")
#if canvas_result.json_data and canvas_result.json_data["objects"] and image:
# objects = canvas_result.json_data["objects"]
# image_width, image_height = image.size
# ratio_h, ratio_w = image_height / canvas_height, image_width / canvas_width
#
# err_x, err_y = 5.5, 1.0
# pos_clicks, neg_clicks = [], []
# for click in objects:
# x, y = (click["left"] + err_x) * ratio_w, (click["top"] + err_y) * ratio_h
# x, y = min(image_width, max(0, x)), min(image_height, max(0, y))
#
# is_positive = click["stroke"] == pos_color
# click = ck.Click(is_positive=is_positive, coords=(y, x))
# clicker.add_click(click)
#
# # prediction.
# pred = None
# predictor.set_input_image(np.array(image))
# with st.spinner("Wait for prediction..."):
# pred = predictor.get_prediction(clicker, prev_mask=None)
# pred = cv2.resize(pred, dsize=(canvas_height, canvas_width), interpolation=cv2.INTER_CUBIC)
# pred = np.where(pred > threshold, 1.0, 0)
# st.image(pred, caption="")