S-MultiMAE / streamlit_apps /app_utils /image_inference.py
thinh-researcher's picture
- Update docs
c4e40c4
import time
import numpy as np
import streamlit as st
from PIL import Image
from s_multimae.da.base_da import BaseDataAugmentation
from .base_model import BaseRGBDModel
from .depth_model import BaseDepthModel
from .model import base_inference
def image_inference(
    depth_model: BaseDepthModel,
    sod_model: BaseRGBDModel,
    da: BaseDataAugmentation,
    color: np.ndarray,
) -> None:
    """Render the single-image page: upload inputs, run inference, show results.

    The page shows two upload widgets side by side (an RGB image and an
    optional depth image), an optional "number of sets" input, and a
    predict button. On click, ``base_inference`` is called and its
    outputs are drawn below.

    The depth image is kept in ``st.session_state.depth`` so it survives
    Streamlit reruns. It is reset to ``None`` whenever either uploader
    changes, so a removed or replaced file never leaves a stale depth map
    behind.

    Args:
        depth_model: Forwarded unchanged to ``base_inference``.
        sod_model: Forwarded unchanged to ``base_inference``; its
            ``cfg.ground_truth_version`` decides whether the
            number-of-sets input is offered (only when it equals 6).
        da: Forwarded unchanged to ``base_inference``.
        color: Forwarded unchanged to ``base_inference``.
    """
    if "depth" not in st.session_state:
        st.session_state.depth = None

    col1, col2 = st.columns(2)
    # PIL image decoded from the RGB upload; stays None until a file is given.
    image = None

    def file_uploader_on_change():
        # Drop any cached depth (uploaded or previously predicted); the
        # rerun that follows re-populates it from the depth uploader if a
        # file is still present there.
        st.session_state.depth = None

    with col1:
        img_file_buffer = st.file_uploader(
            "Upload an RGB image",
            key="img_file_buffer",
            type=["png", "jpg", "jpeg"],
            on_change=file_uploader_on_change,
        )
        if img_file_buffer is not None:
            image = Image.open(img_file_buffer).convert("RGB")
            st.image(image, caption="RGB")

    with col2:
        depth_file_buffer = st.file_uploader(
            "Upload a depth image (Optional)",
            key="depth_file_buffer",
            type=["png", "jpg", "jpeg"],
            # Without this reset, removing the uploaded depth file would
            # leave the old depth in session state, still displayed and
            # still passed to the model.
            on_change=file_uploader_on_change,
        )
        if depth_file_buffer is not None:
            st.session_state.depth = Image.open(depth_file_buffer).convert("L")
        if st.session_state.depth is not None:
            st.image(st.session_state.depth, caption="Depth")

    if sod_model.cfg.ground_truth_version == 6:
        num_sets_of_salient_objects = st.number_input(
            "Number of sets of salient objects", value=1, min_value=1, max_value=10
        )
    else:
        num_sets_of_salient_objects = 1

    # The button is disabled until an RGB image exists, so `image` is set
    # whenever the branch below runs.
    is_predict = st.button(
        "Predict Salient Objects",
        key="predict_salient_objects",
        disabled=img_file_buffer is None,
    )
    if is_predict:
        with st.spinner(
            "Processing... (It usually takes about 30s - 1 minute per a set of salient objects)"
        ):
            start_time = time.time()
            pred_depth, pred_sods, pred_sms = base_inference(
                depth_model,
                sod_model,
                da,
                image,
                st.session_state.depth,
                color,
                num_sets_of_salient_objects,
            )
            if st.session_state.depth is None:
                # No depth was supplied: keep the predicted one for later
                # reruns and show it in the depth column.
                st.session_state.depth = Image.fromarray(pred_depth).convert("L")
                col2.image(st.session_state.depth, "Pseudo-depth")

            if num_sets_of_salient_objects == 1:
                st.warning(
                    "HINT: To view a wider variety of sets of salient objects, try to increase the number of sets the model can produce."
                )
            elif num_sets_of_salient_objects > 1:
                st.warning(
                    "NOTE: As single-GT accounts for 77.61% of training samples, the model may not consistently yield different sets. The best approach is to gradually increase the number of sets of salient objects until you achieve the desired result."
                )

            st.info(f"Inference time: {time.time() - start_time:.4f} seconds")

            # One column per predicted set; each shows the thresholded
            # objects above the corresponding saliency map.
            sod_cols = st.columns(len(pred_sods))
            for i, (pred_sod, pred_sm) in enumerate(zip(pred_sods, pred_sms)):
                with sod_cols[i]:
                    st.image(pred_sod, "Salient Objects (Otsu threshold)")
                    st.image(pred_sm, "Salient Map")