# NOTE(review): removed web-scrape artifacts (Hugging Face Spaces page chrome,
# file-size line, commit hash, and a line-number gutter) that were not valid
# Python and broke parsing of this module.
from typing import Dict, List, Optional
from torch import nn, Tensor
from .model.multimae import generate_smultimae_model as generate_smultimae_model_v1
from .configs.base_config import base_cfg
class RGBDModel(nn.Module):
    """Wrapper around the sMultiMAE backbone for RGB-D salient object detection.

    Selects a backbone implementation from ``cfg.model_version`` and exposes a
    uniform ``forward`` / ``inference`` API; both routes delegate to
    :meth:`encode_decode`.
    """

    def __init__(self, cfg: "base_cfg"):
        super().__init__()
        self.inputs = cfg.inputs
        self.outputs = cfg.outputs
        # NOTE(review): not read anywhere in this class — presumably consumed
        # by inference-time callers; verify against call sites.
        self.is_no_depth = cfg.is_inference_with_no_depth

        if cfg.model_version == 1:
            self.model, self.opt_params = generate_smultimae_model_v1(cfg)
        else:
            # ValueError is a subclass of Exception, so existing broad
            # handlers keep working while the error type is now meaningful.
            raise ValueError(f"Unsupported model version {cfg.model_version}")

    def encode_decode(
        self,
        images: Tensor,
        depths: Optional[Tensor],
        gt_index_lst: Optional[List[int]] = None,
        max_gts_lst: Optional[List[int]] = None,
    ) -> Dict[str, Tensor]:
        """Encode images with backbone and decode into a semantic segmentation
        map of the same size as input.

        Args:
            images: RGB input batch.
            depths: Optional depth batch; forwarded only when ``"depth"`` is a
                configured input modality.
            gt_index_lst: Optional per-sample ground-truth indices, passed
                through to the backbone unchanged.
            max_gts_lst: Optional per-sample ground-truth counts, passed
                through to the backbone unchanged.

        Returns:
            {
                "sod": Tensor,
                "depth": Optional[Tensor],
                "rgb": Optional[Tensor],
            }
        """
        inputs = {"rgb": images}
        if "depth" in self.inputs:
            inputs["depth"] = depths
        # Invoke the module itself rather than .forward() so that any hooks
        # registered on the backbone are honored.
        return self.model(inputs, gt_index_lst, max_gts_lst)

    def forward(
        self,
        images: Tensor,
        depths: Optional[Tensor],
        gt_index_lst: Optional[List[int]] = None,
        max_gts_lst: Optional[List[int]] = None,
    ) -> Dict[str, Tensor]:
        """Training-time entry point; see :meth:`encode_decode`."""
        return self.encode_decode(images, depths, gt_index_lst, max_gts_lst)

    def inference(
        self,
        images: Tensor,
        depths: Optional[Tensor],
        gt_index_lst: Optional[List[int]] = None,
        max_gts_lst: Optional[List[int]] = None,
    ) -> Dict[str, Tensor]:
        """Inference entry point; currently identical to :meth:`forward`."""
        return self.encode_decode(images, depths, gt_index_lst, max_gts_lst)