import os
from typing import Optional

import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torch import Tensor, nn

from .app_utils import count_parameters
from .device import device
from .dpt.models import DPTDepthModel


class BaseDepthModel:
    """Abstract interface for monocular depth estimation models."""

    def __init__(self, image_size: int) -> None:
        self.image_size = image_size
        self.model: Optional[nn.Module] = None

    def forward(self, image: Tensor) -> Tensor:
        """Perform forward inference for a single image.

        Input image of shape [c, h, w]; returns a tensor of shape [c, h, w].
        """
        raise NotImplementedError()

    def batch_forward(self, images: Tensor) -> Tensor:
        """Perform forward inference for a batch of images.

        Input images of shape [b, c, h, w]; returns a tensor of shape [b, c, h, w].
        """
        raise NotImplementedError()

    def get_number_of_parameters(self) -> int:
        return count_parameters(self.model)


class DPTDepth(BaseDepthModel):
    """DPT-based depth estimator using pretrained Omnidata weights."""

    def __init__(self, image_size: int) -> None:
        super().__init__(image_size)
        print("Initializing DPTDepth")
        weights_fname = "omnidata_rgb2depth_dpt_hybrid.pth"
        weights_path = os.path.join("weights", weights_fname)
        # Download the pretrained weights from the Hugging Face Hub on first use.
        if not os.path.isfile(weights_path):
            from huggingface_hub import hf_hub_download

            hf_hub_download(
                repo_id="RGBD-SOD/S-MultiMAE",
                filename=weights_fname,
                local_dir="weights",
            )
        omnidata_ckpt = torch.load(
            weights_path,
            map_location="cpu",
        )

        model = DPTDepthModel()
        model.load_state_dict(omnidata_ckpt)
        self.model: DPTDepthModel = model.to(device).eval()

        # Resize to the model's expected input resolution and normalize
        # pixel values from [0, 1] to [-1, 1].
        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    (self.image_size, self.image_size),
                    interpolation=TF.InterpolationMode.BICUBIC,
                ),
                transforms.Normalize(
                    (0.5, 0.5, 0.5),
                    (0.5, 0.5, 0.5),
                ),
            ]
        )

    def forward(self, image: Tensor) -> Tensor:
        # Add a batch dimension, preprocess, run the model, then drop the batch dimension.
        depth_model_input = self.transform(image.unsqueeze(0))
        return self.model.forward(depth_model_input.to(device)).squeeze(0)

    def batch_forward(self, images: Tensor) -> Tensor:
        images = TF.resize(
            images,
            (self.image_size, self.image_size),
            interpolation=TF.InterpolationMode.BICUBIC,
        )
        # Normalize from [0, 1] to [-1, 1], matching self.transform above.
        depth_model_input = (images - 0.5) / 0.5
        return self.model(depth_model_input.to(device))
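

# --- Usage sketch (illustrative) ---
# A minimal sketch of running DPTDepth on a dummy image. The 384x384 image
# size and the random input tensor are assumptions for illustration, not
# part of the original module. Because this file uses relative imports,
# run it as a module (e.g. `python -m <package>.<this_module>`).
if __name__ == "__main__":
    model = DPTDepth(image_size=384)
    rgb = torch.rand(3, 480, 640)  # dummy RGB image in [0, 1], shape [c, h, w]
    with torch.no_grad():  # inference only, no gradients needed
        depth = model.forward(rgb)
    print("Depth map shape:", tuple(depth.shape))
    print("Number of parameters:", model.get_number_of_parameters())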