import logging
import os
import re
import sys

import torch
import torch.nn as nn

from .base import BaseModel
from .schema import DINOConfiguration
from .dinov2.eval.depth.ops.wrappers import resize
from .dinov2.hub.backbones import dinov2_vitb14_reg

# Make this module's directory importable so the vendored dinov2 package can
# be found on sys.path at runtime.
module_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(module_dir)

logger = logging.getLogger(__name__)


class FeatureExtractor(BaseModel):
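    """DINOv2-based dense feature extractor.

    Resizes inputs to a multiple of the 14-pixel patch size, extracts patch
    tokens from a DINOv2 ViT-B/14 backbone, projects them to
    ``conf.output_dim`` channels, and returns the feature map together with
    the camera rescaled to the feature resolution.
    """

    # ImageNet normalization statistics.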
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    def build_encoder(self, conf: DINOConfiguration):
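        """Instantiate the DINOv2 backbone and the 1x1 projection head."""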
        BACKBONE_SIZE = "small"
        backbone_archs = {
            "small": "vits14",
            "base": "vitb14",  # this one
            "large": "vitl14",
            "giant": "vitg14",
        }
        backbone_arch = backbone_archs[BACKBONE_SIZE]
        self.crop_size = int(re.search(r"\d+", backbone_arch).group())
        backbone_name = f"dinov2_{backbone_arch}"

        self.backbone_model = dinov2_vitb14_reg(
            pretrained=conf.pretrained, drop_path_rate=0.1)

        if conf.frozen:
            # Freeze the patch embedding and the first 10 of the 12
            # transformer blocks, leaving only the last two trainable.
            for param in self.backbone_model.patch_embed.parameters():
                param.requires_grad = False

            for i in range(10):
                for param in self.backbone_model.blocks[i].parameters():
                    param.requires_grad = False
                # Disable stochastic depth in frozen blocks so their
                # features stay deterministic.
                self.backbone_model.blocks[i].drop_path1 = nn.Identity()
                self.backbone_model.blocks[i].drop_path2 = nn.Identity()

        # 1x1 convolution projecting the 768-channel ViT-B patch tokens to
        # the configured output dimension.
        self.feat_projection = nn.Conv2d(
            768, conf.output_dim, kernel_size=1)

        return self.backbone_model

    def _init(self, conf: DINOConfiguration):
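        """Register preprocessing buffers and build the encoder."""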
        # ImageNet mean/std as non-persistent buffers: they move with the
        # module's device but are not written to checkpoints.
        self.register_buffer("mean_", torch.tensor(
            self.mean), persistent=False)
        self.register_buffer("std_", torch.tensor(self.std), persistent=False)

        self.build_encoder(conf)

    def _forward(self, data):
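        """Compute a projected feature map and a camera scaled to match."""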
        _, _, h, w = data["image"].shape

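        # Resize so the spatial dimensions are exact multiples of the
        # patch size.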
        h_num_patches = h // self.crop_size
        w_num_patches = w // self.crop_size

        h_dino = h_num_patches * self.crop_size
        w_dino = w_num_patches * self.crop_size

        image = resize(data["image"], (h_dino, w_dino))

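        # Normalize with the ImageNet statistics (broadcast over the batch
        # and spatial dimensions).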
        image = (image - self.mean_[:, None, None]) / self.std_[:, None, None]

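        # Extract patch tokens, reshape them into a channels-first feature
        # map, and project to the output dimension.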
        output = self.backbone_model.forward_features(
            image)['x_norm_patchtokens']
        output = output.reshape(-1, h_num_patches,
                                w_num_patches, output.shape[-1])
        output = output.permute(0, 3, 1, 2)  # channel first
        output = self.feat_projection(output)

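        # Rescale the camera from input-image resolution to the feature
        # map's resolution.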
        camera = data['camera'].to(data["image"].device, non_blocking=True)
        camera = camera.scale(output.shape[-1] / data["image"].shape[-1])

        return output, camera