import os

import cv2
import numpy as np
import streamlit as st
import torch
import wget
from PIL import Image
from streamlit_drawable_canvas import st_canvas

from isegm.inference import clicker as ck
from isegm.inference import utils
from isegm.inference.predictors import get_predictor


# A predictor wraps a torch model: it must be shared across reruns rather than
# pickled and copied on every access, so it is cached as a resource.
@st.cache_resource
def load_model(model_path, device):
    """Load an interactive-segmentation checkpoint and wrap it in a predictor.

    Args:
        model_path: Path of the checkpoint file handed to
            ``utils.load_is_model``.
        device: ``torch.device`` the model is loaded onto and the predictor
            runs on.

    Returns:
        The predictor built by ``get_predictor`` with ``brs_mode="NoBRS"``.
    """
    model = utils.load_is_model(model_path, device, cpu_dist_maps=True)
    predictor_params = {"brs_mode": "NoBRS"}
    predictor = get_predictor(model, device=device, **predictor_params)
    return predictor


# Objects in the global scope
url_prefix = "https://huggingface.co/curt-park/interactive-segmentation/resolve/main"
models = {"RITM": "ritm_coco_lvis_h18_itermask.pth"}
clicker = ck.Clicker()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Stroke colors double as the marker for positive / negative clicks.
pos_color, neg_color = "#3498DB", "#C70039"
canvas_height, canvas_width = 600, 600
# Offsets added to canvas click coordinates before they are scaled to the image.
# NOTE(review): presumably compensates for the canvas point-marker offset - confirm.
err_x, err_y = 5.5, 1.0
predictor = None
image = None

# Items in the sidebar.
model = st.sidebar.selectbox("Select a Model:", tuple(models.keys()))
threshold = st.sidebar.slider("Threshold: ", 0.0, 1.0, 0.5)
marking_type = st.sidebar.radio("Marking Type:", ("positive", "negative"))
image_path = st.sidebar.file_uploader("Background Image:", type=["png", "jpg", "jpeg"])

# Objects for prediction.
with st.spinner("Wait for downloading a model..."):
    # The checkpoint is fetched only once; later runs reuse the local file.
    if not os.path.exists(models[model]):
        _ = wget.download(f"{url_prefix}/{models[model]}")

with st.spinner("Wait for loading a model..."):
    predictor = load_model(models[model], device)

# Create a canvas component.
# Load the uploaded background image, if any.
if image_path:
    image = Image.open(image_path).convert("RGB")

st.title("Canvas:")
canvas_result = st_canvas(
    fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
    stroke_width=3,
    # The stroke color records whether a click is positive or negative.
    stroke_color=pos_color if marking_type == "positive" else neg_color,
    background_color="#eee",
    background_image=image,
    update_streamlit=True,
    drawing_mode="point",
    point_display_radius=3,
    key="canvas",
    width=canvas_width,
    height=canvas_height,
)

# Check the user inputs and execute predictions.
st.title("Prediction:")
if (
    image is not None
    and canvas_result.json_data
    and canvas_result.json_data["objects"]
):
    objects = canvas_result.json_data["objects"]
    image_width, image_height = image.size
    # The image is displayed at the canvas size, so canvas coordinates must
    # be rescaled to address pixels of the original image.
    ratio_h, ratio_w = image_height / canvas_height, image_width / canvas_width

    for obj in objects:
        x = (obj["left"] + err_x) * ratio_w
        y = (obj["top"] + err_y) * ratio_h
        # Clamp to the last valid pixel index (size - 1), not to the size.
        x = min(image_width - 1, max(0, x))
        y = min(image_height - 1, max(0, y))
        is_positive = obj["stroke"] == pos_color
        # Click coordinates are passed as (row, column), i.e. (y, x).
        clicker.add_click(ck.Click(is_positive=is_positive, coords=(y, x)))

    # Run prediction.
    predictor.set_input_image(np.array(image))
    init_mask = torch.zeros((1, 1, image_height, image_width), device=device)
    with st.spinner("Wait for prediction..."):
        pred = predictor.get_prediction(clicker, prev_mask=init_mask)

    # cv2.resize expects dsize as (width, height).
    pred = cv2.resize(
        pred,
        dsize=(canvas_width, canvas_height),
        interpolation=cv2.INTER_CUBIC,
    )
    # Binarize the prediction with the user-selected threshold.
    pred = np.where(pred > threshold, 1.0, 0)

    # Show the prediction result.
    st.image(pred, caption="")