curt-park's picture
Cache loaded model
7d80b1e
raw
history blame
3.3 kB
import streamlit as st
import torch
import numpy as np
import cv2
import wget
import os
from PIL import Image
from streamlit_drawable_canvas import st_canvas
from isegm.inference import clicker as ck
from isegm.inference import utils
from isegm.inference.predictors import get_predictor
@st.cache_resource
def load_model(model_path, device):
    """Build the click-based segmentation predictor for a checkpoint.

    The predictor is cached with ``st.cache_resource`` (the Streamlit
    decorator intended for ML models) so script reruns reuse the same
    object instead of serialising and deserialising a copy of it, which is
    what ``st.cache_data`` does with return values.

    NOTE(review): a cached resource is a single object shared by every
    session; callers mutate it via ``set_input_image``, so concurrent
    sessions are not isolated from each other.

    Args:
        model_path: Checkpoint path handed to ``utils.load_is_model``.
        device: Device handed to both the model loader and the predictor.

    Returns:
        The object returned by ``get_predictor`` configured with
        ``brs_mode="NoBRS"``.
    """
    model = utils.load_is_model(model_path, device, cpu_dist_maps=True)
    predictor_params = {"brs_mode": "NoBRS"}
    return get_predictor(model, device=device, **predictor_params)
# Objects in the global scope.

# Base URL the checkpoint files are downloaded from.
url_prefix = "https://huggingface.co/curt-park/interactive-segmentation/resolve/main"
# Model name shown in the sidebar -> checkpoint file name.
models = {"RITM": "ritm_coco_lvis_h18_itermask.pth"}

# Accumulates the clicks that are passed to the predictor.
clicker = ck.Clicker()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stroke colours; a click's stroke colour is later compared against
# pos_color to decide whether it is a positive click.
pos_color = "#3498DB"
neg_color = "#C70039"

# Canvas size in pixels.
canvas_height = 600
canvas_width = 600

# Offsets added to each click's left/top before scaling to image space.
# NOTE(review): presumably compensates for where the canvas anchors a point
# marker -- confirm against streamlit_drawable_canvas.
err_x = 5.5
err_y = 1.0

# Filled in further down once a model is loaded / an image is uploaded.
predictor = None
image = None
# Items in the sidebar.

# Which checkpoint to use; options are the keys of `models`.
model = st.sidebar.selectbox(
    "Select a Model:",
    options=tuple(models),
)
# Cut-off applied to the predicted mask when it is binarised.
threshold = st.sidebar.slider(
    "Threshold: ",
    min_value=0.0,
    max_value=1.0,
    value=0.5,
)
# Whether newly drawn points count as positive or negative clicks.
marking_type = st.sidebar.radio(
    "Marking Type:",
    options=("positive", "negative"),
)
# Uploaded image used as the canvas background and prediction input.
image_path = st.sidebar.file_uploader(
    "Background Image:",
    type=["png", "jpg", "jpeg"],
)
# Objects for prediction.
checkpoint_name = models[model]

# Download the checkpoint only when it is not already present; the check
# uses the bare file name, i.e. it is relative to the working directory.
with st.spinner("Wait for downloading a model..."):
    if not os.path.exists(checkpoint_name):
        _ = wget.download(f"{url_prefix}/{checkpoint_name}")

# Build (or fetch from Streamlit's cache) the predictor for that checkpoint.
with st.spinner("Wait for loading a model..."):
    predictor = load_model(checkpoint_name, device)
# Create a canvas component.
if image_path:
    # Decode the upload and force three RGB channels.
    image = Image.open(image_path).convert("RGB")

st.title("Canvas:")

# Points are drawn in the colour that matches the selected marking type.
if marking_type == "positive":
    marker_color = pos_color
else:
    marker_color = neg_color

canvas_result = st_canvas(
    fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
    stroke_width=3,
    stroke_color=marker_color,
    background_color="#eee",
    background_image=image,  # stays None until an image has been uploaded
    update_streamlit=True,
    drawing_mode="point",
    point_display_radius=3,
    key="canvas",
    width=canvas_width,
    height=canvas_height,
)
# Check the user inputs and execute predictions.
st.title("Prediction:")
json_data = canvas_result.json_data
if json_data and json_data["objects"] and image is not None:
    image_width, image_height = image.size
    # Scale factors from canvas coordinates to image coordinates.
    ratio_w = image_width / canvas_width
    ratio_h = image_height / canvas_height

    for obj in json_data["objects"]:
        x = (obj["left"] + err_x) * ratio_w
        y = (obj["top"] + err_y) * ratio_h
        # Clamp into the valid pixel range [0, size - 1]; `size` itself is
        # one past the last pixel.
        x = min(image_width - 1, max(0, x))
        y = min(image_height - 1, max(0, y))
        # The stroke colour records which marking type the point was drawn with.
        is_positive = obj["stroke"] == pos_color
        # Click coordinates are given in (row, column) order.
        clicker.add_click(ck.Click(is_positive=is_positive, coords=(y, x)))

    # Run prediction.
    predictor.set_input_image(np.array(image))
    init_mask = torch.zeros((1, 1, image_height, image_width), device=device)
    with st.spinner("Wait for prediction..."):
        pred = predictor.get_prediction(clicker, prev_mask=init_mask)
        # cv2.resize takes dsize as (width, height).
        pred = cv2.resize(
            pred,
            dsize=(canvas_width, canvas_height),
            interpolation=cv2.INTER_CUBIC,
        )
        # Binarise the resized prediction with the sidebar threshold.
        pred = np.where(pred > threshold, 1.0, 0)

    # Show the prediction result.
    st.image(pred, caption="")