import os import torch import torchvision.transforms as transforms import torchvision.transforms.functional as TF from torch import Tensor, nn from .app_utils import count_parameters from .device import device from .dpt.models import DPTDepthModel class BaseDepthModel: def __init__(self, image_size: int) -> None: self.image_size = image_size self.model: nn.Module = None def forward(self, image: Tensor) -> Tensor: """Perform forward inference for an image Input image of shape [c, h, w] Return of shape [c, h, w] """ raise NotImplementedError() def batch_forward(self, images: Tensor) -> Tensor: """Perform forward inference for a batch of images Input images of shape [b, c, h, w] Return of shape [b, c, h, w]""" raise NotImplementedError() def get_number_of_parameters(self) -> int: return count_parameters(self.model) class DPTDepth(BaseDepthModel): def __init__(self, image_size: int) -> None: super().__init__(image_size) print("DPTDepthconstructor") weights_fname = "omnidata_rgb2depth_dpt_hybrid.pth" weights_path = os.path.join("weights", weights_fname) if not os.path.isfile(weights_path): from huggingface_hub import hf_hub_download hf_hub_download( repo_id="RGBD-SOD/S-MultiMAE", filename=weights_fname, local_dir="weights", ) omnidata_ckpt = torch.load( weights_path, map_location="cpu", ) self.model = DPTDepthModel() self.model.load_state_dict(omnidata_ckpt) self.model: DPTDepthModel = self.model.to(device).eval() self.transform = transforms.Compose( [ transforms.Resize( (self.image_size, self.image_size), interpolation=TF.InterpolationMode.BICUBIC, ), transforms.Normalize( (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), ), ] ) def forward(self, image: Tensor) -> Tensor: depth_model_input = self.transform(image.unsqueeze(0)) return self.model.forward(depth_model_input.to(device)).squeeze(0) def batch_forward(self, images: Tensor) -> Tensor: images: Tensor = TF.resize( images, (self.image_size, self.image_size), interpolation=TF.InterpolationMode.BICUBIC, ) depth_model_input = (images - 0.5) / 0.5 return self.model(depth_model_input.to(device))