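"""DINOv2-based dense feature extractor.

Wraps a DINOv2 ViT-B/14 backbone (with register tokens), optionally frozen,
and projects its patch tokens to ``conf.output_dim`` channels with a 1x1
convolution. Inputs are resized to a multiple of the 14-pixel patch size and
normalized with ImageNet statistics; the accompanying camera is rescaled to
the feature-map resolution.
"""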
import logging
import os
import re
import sys

import torch
import torch.nn as nn

from .base import BaseModel
from .schema import DINOConfiguration
from .dinov2.eval.depth.ops.wrappers import resize
from .dinov2.hub.backbones import dinov2_vitb14_reg

# Add this module's directory to the import path so the vendored dinov2
# package can resolve its own imports at runtime.
module_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(module_dir)

logger = logging.getLogger(__name__)
class FeatureExtractor(BaseModel):
    # ImageNet normalization statistics expected by DINOv2.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    def build_encoder(self, conf: DINOConfiguration):
        # dinov2_vitb14_reg below loads the "base" backbone (768-dim tokens).
        BACKBONE_SIZE = "base"
        backbone_archs = {
            "small": "vits14",
            "base": "vitb14",
            "large": "vitl14",
            "giant": "vitg14",
        }
        backbone_arch = backbone_archs[BACKBONE_SIZE]
        # All DINOv2 variants use a 14x14 patch size, parsed from the name.
        self.crop_size = int(re.search(r"\d+", backbone_arch).group())
        self.backbone_model = dinov2_vitb14_reg(
            pretrained=conf.pretrained, drop_path_rate=0.1)
        if conf.frozen:
            # Freeze the patch embedding and the first 10 transformer blocks,
            # and disable stochastic depth in the frozen blocks.
            for param in self.backbone_model.patch_embed.parameters():
                param.requires_grad = False
            for i in range(10):
                for param in self.backbone_model.blocks[i].parameters():
                    param.requires_grad = False
                self.backbone_model.blocks[i].drop_path1 = nn.Identity()
                self.backbone_model.blocks[i].drop_path2 = nn.Identity()
        # Project the 768-dim ViT-B patch tokens to the configured output size.
        self.feat_projection = nn.Conv2d(
            768, conf.output_dim, kernel_size=1)
        return self.backbone_model
    def _init(self, conf: DINOConfiguration):
        # Preprocessing: register the normalization statistics as
        # non-persistent buffers so they follow the model's device without
        # being stored in checkpoints.
        self.register_buffer("mean_", torch.tensor(
            self.mean), persistent=False)
        self.register_buffer("std_", torch.tensor(self.std), persistent=False)
        self.build_encoder(conf)
    def _forward(self, data):
        _, _, h, w = data["image"].shape
        # Resize down to the largest size divisible by the 14-pixel patch
        # size, e.g. a 448x672 image yields a 32x48 grid of patch tokens.
        h_num_patches = h // self.crop_size
        w_num_patches = w // self.crop_size
        h_dino = h_num_patches * self.crop_size
        w_dino = w_num_patches * self.crop_size
        image = resize(data["image"], (h_dino, w_dino))
        image = (image - self.mean_[:, None, None]) / self.std_[:, None, None]
        # (B, N, C) patch tokens, excluding the class and register tokens.
        output = self.backbone_model.forward_features(
            image)["x_norm_patchtokens"]
        output = output.reshape(-1, h_num_patches,
                                w_num_patches, output.shape[-1])
        output = output.permute(0, 3, 1, 2)  # channel first: (B, C, H', W')
        output = self.feat_projection(output)
        # Rescale the camera to the feature-map resolution (width ratio).
        camera = data["camera"].to(data["image"].device, non_blocking=True)
        camera = camera.scale(output.shape[-1] / data["image"].shape[-1])
        return output, camera
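
# Minimal usage sketch (kept as comments; the exact constructor and call
# conventions depend on BaseModel, which is not shown here). This assumes the
# common pattern where the model is built from a DINOConfiguration and called
# on a data dict with an RGB batch in [0, 1] and a camera object exposing
# `.to()` and `.scale()`:
#
#   conf = DINOConfiguration(pretrained=True, frozen=True, output_dim=128)
#   model = FeatureExtractor(conf).eval()
#   data = {"image": torch.rand(1, 3, 448, 448), "camera": camera}
#   features, camera_scaled = model(data)
#   # features: (1, 128, 32, 32) -- 448 / 14 = 32 patches per side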