Spaces:
Runtime error
Runtime error
import streamlit as st | |
import torch | |
import numpy as np | |
import cv2 | |
import wget | |
import os | |
from PIL import Image | |
from streamlit_drawable_canvas import st_canvas | |
from isegm.inference import clicker as ck | |
from isegm.inference import utils | |
from isegm.inference.predictors import get_predictor | |
def load_model(model_path: str, device: torch.device):
    """Load a segmentation checkpoint and wrap it in a click-based predictor.

    Args:
        model_path: Path of the checkpoint file handed to
            ``utils.load_is_model``.
        device: Device passed to both the model loader and the predictor.

    Returns:
        The predictor object returned by ``get_predictor``, configured with
        ``brs_mode="NoBRS"``.
    """
    model = utils.load_is_model(model_path, device, cpu_dist_maps=True)
    predictor_params = {"brs_mode": "NoBRS"}
    predictor = get_predictor(model, device=device, **predictor_params)
    return predictor
# Objects in the global scope.
url_prefix = "https://huggingface.co/curt-park/interactive-segmentation/resolve/main"
models = {"RITM": "ritm_coco_lvis_h18_itermask.pth"}
clicker = ck.Clicker()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pos_color, neg_color = "#3498DB", "#C70039"
canvas_height, canvas_width = 600, 600
# Offsets (in canvas pixels) added to each click position before it is
# scaled to image coordinates.
err_x, err_y = 5.5, 1.0
predictor = None
image = None

# Items in the sidebar.
model = st.sidebar.selectbox("Select a Model:", tuple(models))
threshold = st.sidebar.slider("Threshold: ", 0.0, 1.0, 0.5)
marking_type = st.sidebar.radio("Marking Type:", ("positive", "negative"))
image_path = st.sidebar.file_uploader(
    "Background Image:", type=["png", "jpg", "jpeg"]
)

# Objects for prediction.
# The checkpoint is fetched into the working directory only when it is not
# already present there.
with st.spinner("Wait for downloading a model..."):
    if not os.path.exists(models[model]):
        _ = wget.download(f"{url_prefix}/{models[model]}")

# NOTE(review): Streamlit re-executes the script on every widget interaction,
# so this reloads the model each time -- consider caching `load_model`.
with st.spinner("Wait for loading a model..."):
    predictor = load_model(models[model], device)
# Create a canvas component.
# The uploaded file is decoded only when present; otherwise `image` keeps its
# previous value and the canvas is drawn without a background image.
if image_path:
    image = Image.open(image_path).convert("RGB")

st.title("Canvas:")
canvas_result = st_canvas(
    fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
    stroke_width=3,
    # The stroke color doubles as the click label: it is compared against
    # `pos_color` later to tell positive from negative clicks.
    stroke_color=pos_color if marking_type == "positive" else neg_color,
    background_color="#eee",
    background_image=image,
    update_streamlit=True,
    drawing_mode="point",
    point_display_radius=3,
    key="canvas",
    width=canvas_width,
    height=canvas_height,
)
# Check the user inputs and execute predictions.
st.title("Prediction:")
if (
    canvas_result.json_data
    and canvas_result.json_data["objects"]
    and image is not None
):
    objects = canvas_result.json_data["objects"]
    image_width, image_height = image.size
    # Scale factors from canvas coordinates to image pixel coordinates.
    ratio_h, ratio_w = image_height / canvas_height, image_width / canvas_width

    for click in objects:
        x = (click["left"] + err_x) * ratio_w
        y = (click["top"] + err_y) * ratio_h
        # Clamp to the last valid pixel index (size - 1), not the size itself,
        # so a click never lands one past the image border.
        x = min(image_width - 1, max(0, x))
        y = min(image_height - 1, max(0, y))
        is_positive = click["stroke"] == pos_color
        # Click coordinates are passed as (row, column), i.e. (y, x).
        clicker.add_click(ck.Click(is_positive=is_positive, coords=(y, x)))

    # Run prediction.
    predictor.set_input_image(np.array(image))
    init_mask = torch.zeros((1, 1, image_height, image_width), device=device)
    with st.spinner("Wait for prediction..."):
        pred = predictor.get_prediction(clicker, prev_mask=init_mask)
        # cv2.resize expects dsize as (width, height).
        pred = cv2.resize(
            pred,
            dsize=(canvas_width, canvas_height),
            interpolation=cv2.INTER_CUBIC,
        )
        pred = np.where(pred > threshold, 1.0, 0)

    # Show the prediction result.
    st.image(pred, caption="")