# Spaces: Sleeping
import time | |
import numpy as np | |
import streamlit as st | |
from PIL import Image | |
from s_multimae.da.base_da import BaseDataAugmentation | |
from .base_model import BaseRGBDModel | |
from .depth_model import BaseDepthModel | |
from .model import base_inference | |
def image_inference(
    depth_model: BaseDepthModel,
    sod_model: BaseRGBDModel,
    da: BaseDataAugmentation,
    color: np.ndarray,
) -> None:
    """Render the Streamlit UI for salient object detection on one image.

    Two upload widgets are shown side by side (an RGB image and an optional
    depth image), followed by an optional "number of sets" input and a
    predict button. When the button is pressed, ``base_inference`` is called
    and the depth image, salient-object images and salient maps it returns
    are displayed.

    The depth image in use is kept in ``st.session_state.depth`` so that it
    survives Streamlit reruns: a pseudo-depth produced by one prediction is
    passed back to ``base_inference`` on later predictions. It is reset to
    ``None`` whenever either uploader changes.

    Args:
        depth_model: Forwarded unchanged to ``base_inference``.
        sod_model: Forwarded to ``base_inference``. The number-of-sets input
            is only offered when ``sod_model.cfg.ground_truth_version == 6``;
            otherwise a single set is requested.
        da: Forwarded unchanged to ``base_inference``.
        color: Forwarded unchanged to ``base_inference`` (presumably the
            color used to render predictions -- TODO confirm against
            ``base_inference``).
    """
    if "depth" not in st.session_state:
        st.session_state.depth = None

    def reset_depth() -> None:
        # A new RGB image invalidates any cached (pseudo-)depth, and clearing
        # or replacing the depth upload must not leave the old one behind.
        st.session_state.depth = None

    col1, col2 = st.columns(2)
    image: "Image.Image | None" = None

    with col1:
        img_file_buffer = st.file_uploader(
            "Upload an RGB image",
            key="img_file_buffer",
            type=["png", "jpg", "jpeg"],
            on_change=reset_depth,
        )
        if img_file_buffer is not None:
            image = Image.open(img_file_buffer).convert("RGB")
            st.image(image, caption="RGB")

    with col2:
        depth_file_buffer = st.file_uploader(
            "Upload a depth image (Optional)",
            key="depth_file_buffer",
            type=["png", "jpg", "jpeg"],
            # Without this, removing the uploaded depth file left the stale
            # depth in session state, still displayed and still used.
            on_change=reset_depth,
        )
        if depth_file_buffer is not None:
            st.session_state.depth = Image.open(depth_file_buffer).convert("L")
        if st.session_state.depth is not None:
            st.image(st.session_state.depth, caption="Depth")

    if sod_model.cfg.ground_truth_version == 6:
        num_sets_of_salient_objects = st.number_input(
            "Number of sets of salient objects", value=1, min_value=1, max_value=10
        )
    else:
        num_sets_of_salient_objects = 1

    # The button is disabled until an RGB image is uploaded, so ``image`` is
    # always set by the time a prediction is requested.
    is_predict = st.button(
        "Predict Salient Objects",
        key="predict_salient_objects",
        disabled=img_file_buffer is None,
    )
    if not is_predict:
        return

    with st.spinner(
        "Processing... (It usually takes about 30s - 1 minute per a set of salient objects)"
    ):
        # Time only the model call, using a monotonic clock, so the reported
        # figure is not inflated by the UI rendering that follows.
        start_time = time.perf_counter()
        pred_depth, pred_sods, pred_sms = base_inference(
            depth_model,
            sod_model,
            da,
            image,
            st.session_state.depth,
            color,
            num_sets_of_salient_objects,
        )
        inference_seconds = time.perf_counter() - start_time

        if st.session_state.depth is None:
            # No depth was supplied: keep the returned depth for later reruns
            # and show it in the depth column.
            st.session_state.depth = Image.fromarray(pred_depth).convert("L")
            col2.image(st.session_state.depth, "Pseudo-depth")

        if num_sets_of_salient_objects == 1:
            st.warning(
                "HINT: To view a wider variety of sets of salient objects, try to increase the number of sets the model can produce."
            )
        elif num_sets_of_salient_objects > 1:
            st.warning(
                "NOTE: As single-GT accounts for 77.61% of training samples, the model may not consistently yield different sets. The best approach is to gradually increase the number of sets of salient objects until you achieve the desired result."
            )

        st.info(f"Inference time: {inference_seconds:.4f} seconds")

        # One column per predicted set: thresholded result above its map.
        sod_cols = st.columns(len(pred_sods))
        for sod_col, pred_sod, pred_sm in zip(sod_cols, pred_sods, pred_sms):
            with sod_col:
                st.image(pred_sod, "Salient Objects (Otsu threshold)")
                st.image(pred_sm, "Salient Map")