import os

import cv2
import numpy as np
import streamlit as st
import torch
import wget
from PIL import Image
from streamlit_drawable_canvas import st_canvas

from isegm.inference import clicker as ck
from isegm.inference import utils
from isegm.inference.predictors import get_predictor

# Model path: checkpoints are fetched from `prefix` by file name.
prefix = "https://huggingface.co/curt-park/interactive-segmentation/resolve/main"
models = {
    "RITM": "ritm_coco_lvis_h18_itermask.pth",
}

# Items in the sidebar.
model_name = st.sidebar.selectbox("Select a Model:", tuple(models))
threshold = st.sidebar.slider("Threshold: ", 0.0, 1.0, 0.5)
marking_type = st.sidebar.radio("Marking Type:", ("positive", "negative"))
image_path = st.sidebar.file_uploader(
    "Background Image:",
    type=["png", "jpg", "jpeg"],
)

# Objects for prediction.
clicker = ck.Clicker()
device = torch.device("cpu")
checkpoint = models[model_name]

# Download the checkpoint into the working directory unless it is already there.
with st.spinner("Wait for downloading a model..."):
    if not os.path.exists(checkpoint):
        wget.download(f"{prefix}/{checkpoint}")

# Build the network and wrap it in a predictor.
with st.spinner("Wait for loading a model..."):
    model = utils.load_is_model(checkpoint, device, cpu_dist_maps=True)
    predictor_params = {"brs_mode": "NoBRS"}
    predictor = get_predictor(model, device=device, **predictor_params)

# Create a canvas component.
image = Image.open(image_path) if image_path else None
canvas_height, canvas_width = 600, 600
pos_color, neg_color = "#3498DB", "#C70039"

# New points are drawn in the colour of the currently selected marking type.
if marking_type == "positive":
    stroke_color = pos_color
else:
    stroke_color = neg_color

st.title("Canvas:")
canvas_result = st_canvas(
    fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
    stroke_width=3,
    stroke_color=stroke_color,
    background_color="#eee",
    background_image=image,
    update_streamlit=True,
    drawing_mode="point",
    point_display_radius=3,
    key="canvas",
    width=canvas_width,
    height=canvas_height,
)

# Check the user inputs and execute predictions.
st.title("Prediction:")

# Points drawn on the canvas so far; empty/None until the user has clicked.
objects = (canvas_result.json_data or {}).get("objects")

if objects and image is not None:
    image_width, image_height = image.size

    # The canvas has a fixed size, so canvas coordinates must be scaled back
    # to the pixel grid of the uploaded image.
    ratio_h, ratio_w = image_height / canvas_height, image_width / canvas_width

    # Constant offsets added to the reported left/top of each point.
    # NOTE(review): presumably these compensate for the marker's bounding box
    # so the click lands on the point centre -- confirm against the canvas
    # component before changing them.
    err_x, err_y = 5.5, 1.0

    for point in objects:
        x = (point["left"] + err_x) * ratio_w
        y = (point["top"] + err_y) * ratio_h
        # Keep the click inside the image: the last valid pixel is size - 1.
        x = min(image_width - 1, max(0, x))
        y = min(image_height - 1, max(0, y))
        # The stroke colour of a point records which marking type drew it.
        is_positive = point["stroke"] == pos_color
        # Click coordinates are passed as (row, column), i.e. (y, x).
        clicker.add_click(ck.Click(is_positive=is_positive, coords=(y, x)))

    # Prediction.
    # Uploaded PNGs may be RGBA, grayscale or palette based; convert so the
    # predictor always receives a 3-channel RGB array.
    predictor.set_input_image(np.array(image.convert("RGB")))
    with st.spinner("Wait for prediction..."):
        pred = predictor.get_prediction(clicker, prev_mask=None)

    # cv2.resize expects dsize as (width, height).
    pred = cv2.resize(
        pred,
        dsize=(canvas_width, canvas_height),
        interpolation=cv2.INTER_CUBIC,
    )
    # Binarise the prediction with the threshold chosen in the sidebar.
    pred = np.where(pred > threshold, 1.0, 0)
    st.image(pred, caption="")