from typing import List, Optional, Tuple, Union

import cv2
import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from skimage.filters import threshold_otsu
from torch import Tensor, nn

from s_multimae.da.base_da import BaseDataAugmentation
from s_multimae.model_pl import ModelPL
from s_multimae.visualizer import apply_vis_to_image

from .app_utils import get_size, normalize
from .base_model import BaseRGBDModel
from .depth_model import BaseDepthModel

# Environment: inference only, so gradients are disabled globally
torch.set_grad_enabled(False)

from .device import device

print(f"device: {device}")


def post_processing_depth(depth: np.ndarray) -> np.ndarray:
    """Normalize a depth map to [0, 255] and render it with a color map."""
    depth = (normalize(depth) * 255).astype(np.uint8)
    return cv2.applyColorMap(depth, cv2.COLORMAP_OCEAN)


def base_inference(
    depth_model: BaseDepthModel,
    sod_model: BaseRGBDModel,
    da: BaseDataAugmentation,
    raw_image: Union[Image.Image, np.ndarray],
    raw_depth: Optional[Union[Image.Image, np.ndarray]] = None,
    color: Optional[np.ndarray] = None,
    num_sets_of_salient_objects: int = 1,
) -> Tuple[np.ndarray, List[np.ndarray], List[np.ndarray]]:
    """Run inference on a pair of RGB and depth images.

    If no depth image is provided, depth_model predicts one from the RGB image.

    Returns:
        The colorized depth map, the images with salient-object overlays,
        and the raw saliency maps rescaled to [0, 1].
    """
    origin_size = get_size(raw_image)

    # Predict a depth map when none is supplied
    image = TF.to_tensor(raw_image)
    origin_shape = image.shape
    if raw_depth is None:
        depth: Tensor = depth_model.forward(image)
    else:
        depth = TF.to_tensor(raw_depth)

    # Preprocessing
    image, depth = da.forward(
        raw_image, depth.cpu().detach().squeeze(0).numpy(), is_transform=False
    )

    # Inference
    sms = sod_model.inference(image, depth, origin_shape, num_sets_of_salient_objects)

    # Postprocessing: binarize each saliency map with Otsu's threshold,
    # then overlay the mask on the original image
    sods = []
    for sm in sms:
        binary_mask = np.array(sm)
        t = threshold_otsu(binary_mask)
        binary_mask[binary_mask < t] = 0.0
        binary_mask[binary_mask >= t] = 1.0
        sod = apply_vis_to_image(np.array(raw_image), binary_mask, color)
        sods.append(sod)

    # Resize the depth map back to the original resolution and colorize it
    depth = depth.permute(1, 2, 0).detach().cpu().numpy()
    depth = cv2.resize(depth, origin_size)
    depth = post_processing_depth(depth)

    return depth, sods, [e / 255.0 for e in sms]


def transform_images(inputs: List[Image.Image], transform: nn.Module) -> Tensor:
    """Apply `transform` to each image and stack the results into one batch."""
    if len(inputs) == 1:
        return transform(inputs[0]).unsqueeze(0)
    return torch.cat([transform(img).unsqueeze(0) for img in inputs])
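

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of `transform_images` with a standard torchvision
# pipeline. The transform choice, image sizes, and file name below are
# arbitrary assumptions for demonstration; note that `transforms.Compose`
# is not an nn.Module but satisfies the callable interface that
# `transform_images` actually relies on.
if __name__ == "__main__":
    from torchvision import transforms

    preprocess = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ]
    )
    # Two dummy RGB images of different sizes are resized and stacked
    # into a single (2, 3, 224, 224) batch tensor
    images = [Image.new("RGB", (320, 240)), Image.new("RGB", (640, 480))]
    batch = transform_images(images, preprocess)
    print(batch.shape)  # torch.Size([2, 3, 224, 224])

    # Hypothetical end-to-end call of `base_inference`. The concrete model
    # and augmentation classes below are placeholders, since this module
    # only imports their base classes:
    #
    #   depth_model = MyDepthModel(...)   # subclass of BaseDepthModel
    #   sod_model = MySODModel(...)       # subclass of BaseRGBDModel
    #   da = MyDataAugmentation(...)      # subclass of BaseDataAugmentation
    #   rgb = Image.open("example.jpg").convert("RGB")
    #   depth_vis, overlays, saliency_maps = base_inference(
    #       depth_model, sod_model, da, rgb
    #   )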