File size: 3,298 Bytes
e82cf8b
2cdd41c
96d7d21
2cdd41c
 
 
e82cf8b
2cdd41c
 
 
 
 
 
 
7d80b1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2cdd41c
 
 
 
 
 
 
 
 
 
7d80b1e
2cdd41c
 
7d80b1e
2cdd41c
 
96d7d21
2535e18
ad0c87f
96d7d21
 
 
 
 
 
 
 
 
 
 
 
 
 
2cdd41c
 
96d7d21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad0c87f
96d7d21
 
ad0c87f
 
96d7d21
ad0c87f
96d7d21
 
ad0c87f
 
96d7d21
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import streamlit as st
import torch
import numpy as np
import cv2
import wget
import os

from PIL import Image
from streamlit_drawable_canvas import st_canvas

from isegm.inference import clicker as ck
from isegm.inference import utils
from isegm.inference.predictors import get_predictor

@st.cache_resource
def load_model(model_path, device):
    """Load an interactive-segmentation checkpoint and wrap it in a predictor.

    Args:
        model_path: Path of the checkpoint file on local disk.
        device: ``torch.device`` the model is loaded onto and run on.

    Returns:
        The predictor built by ``get_predictor`` with ``brs_mode="NoBRS"``.
    """
    # st.cache_resource rather than st.cache_data: the predictor wraps a torch
    # model, which should be built once and reused as-is instead of being
    # pickled and deserialized into a fresh copy on every cache hit.
    # NOTE(review): the cached predictor is one shared object and the script
    # mutates it (set_input_image), so concurrent sessions may interfere.
    model = utils.load_is_model(model_path, device, cpu_dist_maps=True)
    predictor_params = {"brs_mode": "NoBRS"}
    return get_predictor(model, device=device, **predictor_params)


# ---- Global configuration and state used by the rest of the script ----
url_prefix = "https://huggingface.co/curt-park/interactive-segmentation/resolve/main"
models = {"RITM": "ritm_coco_lvis_h18_itermask.pth"}  # display name -> checkpoint file
clicker = ck.Clicker()  # accumulates the clicks read from the canvas
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stroke colours; the prediction step tells click polarity apart by colour.
pos_color = "#3498DB"
neg_color = "#C70039"

# Canvas size in pixels.
canvas_height = 600
canvas_width = 600

# Offsets added to each reported click position before scaling.
# NOTE(review): presumably compensates the canvas point-rendering offset; confirm.
err_x = 5.5
err_y = 1.0

predictor = None
image = None

# ---- Sidebar controls ----
model = st.sidebar.selectbox("Select a Model:", tuple(models))
threshold = st.sidebar.slider("Threshold: ", min_value=0.0, max_value=1.0, value=0.5)
marking_type = st.sidebar.radio("Marking Type:", ("positive", "negative"))
image_path = st.sidebar.file_uploader("Background Image:", type=["png", "jpg", "jpeg"])

# Objects for prediction.
# Fetch the checkpoint only when it is not already on local disk.
with st.spinner("Wait for downloading a model..."):
    if not os.path.exists(models[model]):
        # Pass `out` explicitly so the file is saved at exactly the path that
        # is checked above and loaded below; without it wget derives the name
        # from the URL / response headers, which need not match models[model].
        _ = wget.download(f"{url_prefix}/{models[model]}", out=models[model])

with st.spinner("Wait for loading a model..."):
    predictor = load_model(models[model], device)

# Background image for the canvas; stays None until a file is uploaded.
if image_path:
    image = Image.open(image_path).convert("RGB")

# The stroke colour of each placed point records whether it is a positive or
# a negative click; the prediction step reads it back from the canvas JSON.
st.title("Canvas:")
stroke_color = neg_color if marking_type != "positive" else pos_color
canvas_options = dict(
    fill_color="rgba(255, 165, 0, 0.3)",  # fixed, semi-transparent fill
    stroke_width=3,
    stroke_color=stroke_color,
    background_color="#eee",
    background_image=image,
    update_streamlit=True,
    drawing_mode="point",
    point_display_radius=3,
    key="canvas",
    width=canvas_width,
    height=canvas_height,
)
canvas_result = st_canvas(**canvas_options)

# Check the user inputs and execute predictions.
st.title("Prediction:")
if canvas_result.json_data and canvas_result.json_data["objects"] and image is not None:
    objects = canvas_result.json_data["objects"]
    image_width, image_height = image.size
    # Scale factors mapping canvas coordinates onto original image pixels.
    ratio_h, ratio_w = image_height / canvas_height, image_width / canvas_width

    for click in objects:
        # err_x / err_y are fixed offsets added before scaling.
        # NOTE(review): presumably they correct the canvas point offset; confirm.
        x = (click["left"] + err_x) * ratio_w
        y = (click["top"] + err_y) * ratio_h
        # Clamp into the valid pixel range [0, size - 1]; the size itself is
        # one past the last pixel.
        x = min(image_width - 1, max(0, x))
        y = min(image_height - 1, max(0, y))

        # Polarity is encoded by the stroke colour chosen in the sidebar.
        is_positive = click["stroke"] == pos_color
        clicker.add_click(ck.Click(is_positive=is_positive, coords=(y, x)))

    # Run prediction starting from an empty previous mask.
    predictor.set_input_image(np.array(image))
    init_mask = torch.zeros((1, 1, image_height, image_width), device=device)

    with st.spinner("Wait for prediction..."):
        pred = predictor.get_prediction(clicker, prev_mask=init_mask)
    # cv2.resize takes dsize as (width, height), not (rows, cols).
    pred = cv2.resize(pred, dsize=(canvas_width, canvas_height), interpolation=cv2.INTER_CUBIC)
    pred = np.where(pred > threshold, 1.0, 0)

    # Show the prediction result.
    st.image(pred, caption="")