Spaces:
Sleeping
Sleeping
from typing import Dict, List, Optional | |
from torch import nn, Tensor | |
from .model.multimae import generate_smultimae_model as generate_smultimae_model_v1 | |
from .configs.base_config import base_cfg | |
class RGBDModel(nn.Module): | |
def __init__(self, cfg: base_cfg): | |
super(RGBDModel, self).__init__() | |
self.inputs = cfg.inputs | |
self.outputs = cfg.outputs | |
self.is_no_depth = cfg.is_inference_with_no_depth | |
if cfg.model_version == 1: | |
self.model, self.opt_params = generate_smultimae_model_v1(cfg) | |
else: | |
raise Exception(f"Unsupported model version {cfg.model_version}") | |
def encode_decode( | |
self, | |
images: Tensor, | |
depths: Optional[Tensor], | |
gt_index_lst: Optional[List[int]] = None, | |
max_gts_lst: Optional[List[int]] = None, | |
) -> Dict[str, Tensor]: | |
"""Encode images with backbone and decode into a semantic segmentation | |
map of the same size as input. | |
Returns: | |
{ | |
"sod": Tensor, | |
"depth": Optional[Tensor], | |
"rgb": Optional[tensor], | |
} | |
""" | |
inputs = {"rgb": images} | |
if "depth" in self.inputs: | |
inputs["depth"] = depths | |
return self.model.forward(inputs, gt_index_lst, max_gts_lst) | |
def forward( | |
self, | |
images: Tensor, | |
depths: Optional[Tensor], | |
gt_index_lst: Optional[List[int]] = None, | |
max_gts_lst: Optional[List[int]] = None, | |
) -> Dict[str, Tensor]: | |
return self.encode_decode(images, depths, gt_index_lst, max_gts_lst) | |
def inference( | |
self, | |
images: Tensor, | |
depths: Optional[Tensor], | |
gt_index_lst: Optional[List[int]] = None, | |
max_gts_lst: Optional[List[int]] = None, | |
) -> Dict[str, Tensor]: | |
return self.encode_decode(images, depths, gt_index_lst, max_gts_lst) | |