Spaces:
Runtime error
Runtime error
File size: 3,298 Bytes
e82cf8b 2cdd41c 96d7d21 2cdd41c e82cf8b 2cdd41c 7d80b1e 2cdd41c 7d80b1e 2cdd41c 7d80b1e 2cdd41c 96d7d21 2535e18 ad0c87f 96d7d21 2cdd41c 96d7d21 ad0c87f 96d7d21 ad0c87f 96d7d21 ad0c87f 96d7d21 ad0c87f 96d7d21 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import streamlit as st
import torch
import numpy as np
import cv2
import wget
import os
from PIL import Image
from streamlit_drawable_canvas import st_canvas
from isegm.inference import clicker as ck
from isegm.inference import utils
from isegm.inference.predictors import get_predictor
@st.cache_resource
def load_model(model_path, device):
    """Load a segmentation checkpoint and wrap it in a predictor.

    The result is cached with ``st.cache_resource`` rather than
    ``st.cache_data``: the latter serializes the return value and hands out
    a fresh copy on every call, which is the wrong contract for a model
    object. ``st.cache_resource`` keeps a single live instance instead.

    NOTE(review): the cached predictor is therefore shared between script
    reruns (and presumably between sessions); callers that mutate it, e.g.
    via ``set_input_image``, should be aware of that -- confirm this is
    acceptable for the deployment.

    Args:
        model_path: Path of the checkpoint file passed to
            ``utils.load_is_model``.
        device: Device the model is loaded on and the predictor runs on.

    Returns:
        The predictor built by ``get_predictor`` with ``brs_mode="NoBRS"``.
    """
    model = utils.load_is_model(model_path, device, cpu_dist_maps=True)
    predictor_params = {"brs_mode": "NoBRS"}
    return get_predictor(model, device=device, **predictor_params)
# Module-level configuration and state shared by the sections below.

# Where checkpoints are fetched from, and the checkpoint file per model name.
url_prefix = "https://huggingface.co/curt-park/interactive-segmentation/resolve/main"
models = {
    "RITM": "ritm_coco_lvis_h18_itermask.pth",
}

# Click container and compute device (CUDA when available, otherwise CPU).
clicker = ck.Clicker()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Stroke colours used to tell positive markers from negative ones.
pos_color = "#3498DB"
neg_color = "#C70039"

# Canvas size in pixels.
canvas_height = 600
canvas_width = 600

# Offsets added to each click's canvas position before it is scaled to image
# coordinates (presumably compensating for the drawn marker's anchor -- verify).
err_x = 5.5
err_y = 1.0

# Filled in later, once a model is loaded and an image is uploaded.
predictor = None
image = None
# Sidebar controls: model choice, mask threshold, click polarity, input image.
sidebar = st.sidebar
model = sidebar.selectbox("Select a Model:", tuple(models))
threshold = sidebar.slider("Threshold: ", min_value=0.0, max_value=1.0, value=0.5)
marking_type = sidebar.radio("Marking Type:", ("positive", "negative"))
image_path = sidebar.file_uploader(
    "Background Image:",
    type=["png", "jpg", "jpeg"],
)
# Make sure the selected checkpoint is on disk, then build the predictor.
checkpoint_file = models[model]

with st.spinner("Wait for downloading a model..."):
    checkpoint_missing = not os.path.exists(checkpoint_file)
    if checkpoint_missing:
        # Fetched into the working directory under the checkpoint's own name.
        wget.download(f"{url_prefix}/{checkpoint_file}")

with st.spinner("Wait for loading a model..."):
    predictor = load_model(checkpoint_file, device)
# Load the uploaded image (if any) and draw the clickable canvas over it.
if image_path:
    uploaded = Image.open(image_path)
    image = uploaded.convert("RGB")

st.title("Canvas:")

# New markers take the colour that matches the selected marking type; the
# colour is later used to recover each click's polarity.
if marking_type == "positive":
    marker_color = pos_color
else:
    marker_color = neg_color

canvas_options = dict(
    fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
    stroke_width=3,
    stroke_color=marker_color,
    background_color="#eee",
    background_image=image,
    update_streamlit=True,
    drawing_mode="point",
    point_display_radius=3,
    key="canvas",
    width=canvas_width,
    height=canvas_height,
)
canvas_result = st_canvas(**canvas_options)
# Check the user inputs and execute predictions.
st.title("Prediction:")
if canvas_result.json_data and canvas_result.json_data["objects"] and image is not None:
    objects = canvas_result.json_data["objects"]
    image_width, image_height = image.size
    # Scale factors from canvas pixels to image pixels.
    ratio_h, ratio_w = image_height / canvas_height, image_width / canvas_width

    for marker in objects:
        x = (marker["left"] + err_x) * ratio_w
        y = (marker["top"] + err_y) * ratio_h
        # Clamp to the last valid pixel index; width/height themselves are one
        # past the end of the image.
        x = min(image_width - 1, max(0, x))
        y = min(image_height - 1, max(0, y))
        # Polarity is recovered from the colour the marker was drawn with.
        is_positive = marker["stroke"] == pos_color
        # Coordinates are passed in (y, x) order.
        clicker.add_click(ck.Click(is_positive=is_positive, coords=(y, x)))

    # Run prediction.
    predictor.set_input_image(np.array(image))
    init_mask = torch.zeros((1, 1, image_height, image_width), device=device)
    with st.spinner("Wait for prediction..."):
        pred = predictor.get_prediction(clicker, prev_mask=init_mask)

    # cv2.resize expects dsize as (width, height).
    pred = cv2.resize(
        pred,
        dsize=(canvas_width, canvas_height),
        interpolation=cv2.INTER_CUBIC,
    )
    # Binarize at the threshold chosen in the sidebar.
    pred = np.where(pred > threshold, 1.0, 0)

    # Show the prediction result.
    st.image(pred, caption="")
|