# Streamlit front-end for click-based interactive segmentation.
#
# The script builds the sidebar controls, makes sure the selected model
# checkpoint is present in the working directory (downloading it if it is
# missing), and builds a predictor from it.  The drawable-canvas and
# prediction stages are currently disabled and kept below as comments.

import os

import streamlit as st
import torch
# import numpy as np
import cv2
import wget
from PIL import Image
from streamlit_drawable_canvas import st_canvas

from isegm.inference import clicker as ck
from isegm.inference import utils
from isegm.inference.predictors import get_predictor

# Model Path
prefix = "https://huggingface.co/curt-park/interactive-segmentation/resolve/main"
models = {
    "RITM": "ritm_coco_lvis_h18_itermask.pth",
}

# Items in the sidebar.
# `model_name` is the key chosen by the user; the name `model` is reserved for
# the loaded network so the two meanings never share one variable.
model_name = st.sidebar.selectbox("Select a Model:", tuple(models.keys()))
threshold = st.sidebar.slider("Threshold: ", 0.0, 1.0, 0.5)
marking_type = st.sidebar.radio("Marking Type:", ("positive", "negative"))
image_path = st.sidebar.file_uploader("Background Image:", type=["png", "jpg", "jpeg"])

# Objects for prediction.
clicker = ck.Clicker()
device = torch.device("cpu")
checkpoint = models[model_name]

with st.spinner("Wait for downloading a model..."):
    if not os.path.exists(checkpoint):
        # Pin the output name so the file is saved under exactly the path
        # that the existence check above (and the loader below) looks for.
        _ = wget.download(f"{prefix}/{checkpoint}", out=checkpoint)

with st.spinner("Wait for loading a model..."):
    model = utils.load_is_model(checkpoint, device, cpu_dist_maps=True)
    predictor_params = {"brs_mode": "NoBRS"}
    predictor = get_predictor(model, device=device, **predictor_params)

# ---------------------------------------------------------------------------
# Disabled: canvas + prediction stages, preserved as comments.
# NOTE: this code uses `np`, so re-enabling it also requires restoring the
# commented-out numpy import at the top of the file.
# ---------------------------------------------------------------------------

# Create a canvas component.
# image = None
# if image_path:
#     image = Image.open(image_path)
# canvas_height, canvas_width = 600, 600
# pos_color, neg_color = "#3498DB", "#C70039"
#
# st.title("Canvas:")
# canvas_result = st_canvas(
#     fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
#     stroke_width=3,
#     stroke_color=pos_color if marking_type == "positive" else neg_color,
#     background_color="#eee",
#     background_image=image,
#     update_streamlit=True,
#     drawing_mode="point",
#     point_display_radius=3,
#     key="canvas",
#     width=canvas_width,
#     height=canvas_height,
# )

# Check the user inputs and execute predictions.
# st.title("Prediction:")
# if canvas_result.json_data and canvas_result.json_data["objects"] and image:
#     objects = canvas_result.json_data["objects"]
#     image_width, image_height = image.size
#     ratio_h, ratio_w = image_height / canvas_height, image_width / canvas_width
#
#     err_x, err_y = 5.5, 1.0
#     pos_clicks, neg_clicks = [], []
#     for click in objects:
#         x, y = (click["left"] + err_x) * ratio_w, (click["top"] + err_y) * ratio_h
#         x, y = min(image_width, max(0, x)), min(image_height, max(0, y))
#
#         is_positive = click["stroke"] == pos_color
#         click = ck.Click(is_positive=is_positive, coords=(y, x))
#         clicker.add_click(click)
#
#     # prediction.
#     pred = None
#     predictor.set_input_image(np.array(image))
#     with st.spinner("Wait for prediction..."):
#         pred = predictor.get_prediction(clicker, prev_mask=None)
#     pred = cv2.resize(pred, dsize=(canvas_height, canvas_width), interpolation=cv2.INTER_CUBIC)
#     pred = np.where(pred > threshold, 1.0, 0)
#     st.image(pred, caption="")