|
|
|
|
|
|
|
|
|
|
|
from enum import Enum
|
|
from functools import partial
|
|
from typing import Optional, Tuple, Union
|
|
|
|
import torch
|
|
|
|
from .backbones import _make_dinov2_model
|
|
from .depth import BNHead, DepthEncoderDecoder, DPTHead
|
|
from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, CenterPadding
|
|
|
|
|
|
class Weights(Enum):
    """Names of the pretrained depth-head weight sets selectable in this module.

    Each member's value is the same string as its name. The depther factories
    in this file look members up by name when given a string and lower-case
    the value when composing the checkpoint file name.
    """

    NYU = "NYU"
    KITTI = "KITTI"
|
|
|
|
|
|
def _get_depth_range(pretrained: bool, weights: Weights = Weights.NYU) -> Tuple[float, float]:
    """Return the ``(min_depth, max_depth)`` pair used to configure a depth head.

    Args:
        pretrained: Whether pretrained head weights will be loaded. When this
            is false, ``weights`` is not consulted.
        weights: The pretrained weight set in use.

    Returns:
        ``(0.001, 80.0)`` when ``pretrained`` is true and ``weights`` is
        ``Weights.KITTI``; ``(0.001, 10.0)`` in every other case (not
        pretrained, ``Weights.NYU``, or any other value).
    """
    default_range = (0.001, 10.0)
    kitti_range = (0.001, 80.0)

    # KITTI is the only configuration that departs from the default range.
    wants_kitti_range = pretrained and weights == Weights.KITTI
    return kitti_range if wants_kitti_range else default_range
|
|
|
|
|
|
def _make_dinov2_linear_depth_head(
|
|
*,
|
|
embed_dim: int,
|
|
layers: int,
|
|
min_depth: float,
|
|
max_depth: float,
|
|
**kwargs,
|
|
):
|
|
if layers not in (1, 4):
|
|
raise AssertionError(f"Unsupported number of layers: {layers}")
|
|
|
|
if layers == 1:
|
|
in_index = [0]
|
|
else:
|
|
assert layers == 4
|
|
in_index = [0, 1, 2, 3]
|
|
|
|
return BNHead(
|
|
classify=True,
|
|
n_bins=256,
|
|
bins_strategy="UD",
|
|
norm_strategy="linear",
|
|
upsample=4,
|
|
in_channels=[embed_dim] * len(in_index),
|
|
in_index=in_index,
|
|
input_transform="resize_concat",
|
|
channels=embed_dim * len(in_index) * 2,
|
|
align_corners=False,
|
|
min_depth=0.001,
|
|
max_depth=80,
|
|
loss_decode=(),
|
|
)
|
|
|
|
|
|
def _make_dinov2_linear_depther(
    *,
    arch_name: str = "vit_large",
    layers: int = 4,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.NYU,
    depth_range: Optional[Tuple[float, float]] = None,
    **kwargs,
):
    """Assemble a DINOv2 backbone plus linear depth head into a depth model.

    Args:
        arch_name: Backbone architecture key; one of ``"vit_small"``,
            ``"vit_base"``, ``"vit_large"`` or ``"vit_giant2"``.
        layers: Number of backbone layers fed to the head; must be 1 or 4.
        pretrained: Passed to the backbone factory; when true, the head
            checkpoint is also downloaded and loaded.
        weights: Pretrained weight set, as a ``Weights`` member or its name.
        depth_range: Explicit ``(min_depth, max_depth)``. When ``None`` it is
            derived from ``_get_depth_range(pretrained, weights)``.
        **kwargs: Forwarded to ``_make_dinov2_model``.

    Returns:
        A ``DepthEncoderDecoder`` whose backbone ``forward`` is replaced by a
        partial of ``get_intermediate_layers`` selecting the chosen layers.

    Raises:
        AssertionError: If ``layers`` is unsupported or ``weights`` is a
            string that does not name a ``Weights`` member.
        KeyError: If ``arch_name`` is not one of the known architectures.
    """
    if layers not in (1, 4):
        raise AssertionError(f"Unsupported number of layers: {layers}")
    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    if depth_range is None:
        depth_range = _get_depth_range(pretrained, weights)
    min_depth, max_depth = depth_range

    backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)

    embed_dim = backbone.embed_dim
    patch_size = backbone.patch_size
    model_name = _make_dinov2_model_name(arch_name, patch_size)
    linear_depth_head = _make_dinov2_linear_depth_head(
        embed_dim=embed_dim,
        layers=layers,
        min_depth=min_depth,
        max_depth=max_depth,
    )

    if layers == 4:
        # Four layer indices per architecture.
        out_index = {
            "vit_small": [2, 5, 8, 11],
            "vit_base": [2, 5, 8, 11],
            "vit_large": [4, 11, 17, 23],
            "vit_giant2": [9, 19, 29, 39],
        }[arch_name]
    else:
        # layers == 1 (validated above): use only the final layer.
        layer_count = {
            "vit_small": 12,
            "vit_base": 12,
            "vit_large": 24,
            "vit_giant2": 40,
        }[arch_name]
        out_index = [layer_count - 1]

    model = DepthEncoderDecoder(backbone=backbone, decode_head=linear_depth_head)
    # Calling the backbone now yields the selected intermediate layers
    # instead of its regular forward output.
    model.backbone.forward = partial(
        backbone.get_intermediate_layers,
        n=out_index,
        reshape=True,
        return_class_token=True,
        norm=False,
    )
    # Apply CenterPadding(patch_size) to the first positional input before
    # the backbone runs.
    model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(patch_size)(x[0]))

    if pretrained:
        # Single-layer checkpoints carry no layer-count suffix in their name.
        layers_str = str(layers) if layers == 4 else ""
        weights_str = weights.value.lower()
        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_linear{layers_str}_head.pth"
        checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        # Accept both wrapped ({"state_dict": ...}) and bare state-dict
        # checkpoints. Previously the bare form left `state_dict` unbound.
        state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
        # strict=False: do not fail on keys that differ between checkpoint and model.
        model.load_state_dict(state_dict, strict=False)

    return model
|
|
|
|
|
|
def dinov2_vits14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """Create a linear depther on the ``vit_small`` architecture.

    Thin wrapper around ``_make_dinov2_linear_depther`` with ``arch_name``
    fixed to ``"vit_small"``; all other arguments, including extra keyword
    arguments, are forwarded unchanged.
    """
    return _make_dinov2_linear_depther(
        arch_name="vit_small",
        layers=layers,
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
|
|
|
|
|
def dinov2_vitb14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """Create a linear depther on the ``vit_base`` architecture.

    Thin wrapper around ``_make_dinov2_linear_depther`` with ``arch_name``
    fixed to ``"vit_base"``; all other arguments, including extra keyword
    arguments, are forwarded unchanged.
    """
    return _make_dinov2_linear_depther(
        arch_name="vit_base",
        layers=layers,
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
|
|
|
|
|
def dinov2_vitl14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """Create a linear depther on the ``vit_large`` architecture.

    Thin wrapper around ``_make_dinov2_linear_depther`` with ``arch_name``
    fixed to ``"vit_large"``; all other arguments, including extra keyword
    arguments, are forwarded unchanged.
    """
    return _make_dinov2_linear_depther(
        arch_name="vit_large",
        layers=layers,
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
|
|
|
|
|
def dinov2_vitg14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """Create a linear depther on the ``vit_giant2`` architecture.

    Thin wrapper around ``_make_dinov2_linear_depther`` with ``arch_name``
    fixed to ``"vit_giant2"`` and ``ffn_layer`` fixed to ``"swiglufused"``;
    all other arguments are forwarded unchanged. Because ``ffn_layer`` is
    passed explicitly, supplying it again through ``**kwargs`` raises
    ``TypeError``.
    """
    return _make_dinov2_linear_depther(
        arch_name="vit_giant2",
        layers=layers,
        ffn_layer="swiglufused",
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
|
|
|
|
|
def _make_dinov2_dpt_depth_head(*, embed_dim: int, min_depth: float, max_depth: float):
    """Build the ``DPTHead`` decode head used by the DPT depthers.

    Args:
        embed_dim: Channel count used for each of the four head inputs and
            as the head's ``embed_dims``.
        min_depth: Lower bound of the depth range, forwarded to ``DPTHead``.
        max_depth: Upper bound of the depth range, forwarded to ``DPTHead``.

    Returns:
        A ``DPTHead`` with four inputs, 256 channels and ``"project"`` readout.
    """
    n_inputs = 4
    # Post-process widths are embed_dim floor-divided by 8, 4, 2 and 1.
    post_process_channels = [embed_dim // divisor for divisor in (8, 4, 2, 1)]

    return DPTHead(
        in_channels=[embed_dim] * n_inputs,
        channels=256,
        embed_dims=embed_dim,
        post_process_channels=post_process_channels,
        readout_type="project",
        min_depth=min_depth,
        max_depth=max_depth,
        loss_decode=(),
    )
|
|
|
|
|
|
def _make_dinov2_dpt_depther(
    *,
    arch_name: str = "vit_large",
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.NYU,
    depth_range: Optional[Tuple[float, float]] = None,
    **kwargs,
):
    """Assemble a DINOv2 backbone plus DPT depth head into a depth model.

    Args:
        arch_name: Backbone architecture key; one of ``"vit_small"``,
            ``"vit_base"``, ``"vit_large"`` or ``"vit_giant2"``.
        pretrained: Passed to the backbone factory; when true, the head
            checkpoint is also downloaded and loaded.
        weights: Pretrained weight set, as a ``Weights`` member or its name.
        depth_range: Explicit ``(min_depth, max_depth)``. When ``None`` it is
            derived from ``_get_depth_range(pretrained, weights)``.
        **kwargs: Forwarded to ``_make_dinov2_model``.

    Returns:
        A ``DepthEncoderDecoder`` whose backbone ``forward`` is replaced by a
        partial of ``get_intermediate_layers`` selecting four layers.

    Raises:
        AssertionError: If ``weights`` is a string that does not name a
            ``Weights`` member.
        KeyError: If ``arch_name`` is not one of the known architectures.
    """
    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    if depth_range is None:
        depth_range = _get_depth_range(pretrained, weights)
    min_depth, max_depth = depth_range

    backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)

    model_name = _make_dinov2_model_name(arch_name, backbone.patch_size)
    dpt_depth_head = _make_dinov2_dpt_depth_head(embed_dim=backbone.embed_dim, min_depth=min_depth, max_depth=max_depth)

    # Four layer indices per architecture.
    out_index = {
        "vit_small": [2, 5, 8, 11],
        "vit_base": [2, 5, 8, 11],
        "vit_large": [4, 11, 17, 23],
        "vit_giant2": [9, 19, 29, 39],
    }[arch_name]

    model = DepthEncoderDecoder(backbone=backbone, decode_head=dpt_depth_head)
    # Calling the backbone now yields the selected intermediate layers
    # instead of its regular forward output.
    model.backbone.forward = partial(
        backbone.get_intermediate_layers,
        n=out_index,
        reshape=True,
        return_class_token=True,
        norm=False,
    )
    # Apply CenterPadding(backbone.patch_size) to the first positional input
    # before the backbone runs.
    model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(backbone.patch_size)(x[0]))

    if pretrained:
        weights_str = weights.value.lower()
        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_dpt_head.pth"
        checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        # Accept both wrapped ({"state_dict": ...}) and bare state-dict
        # checkpoints. Previously the bare form left `state_dict` unbound.
        state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
        # strict=False: do not fail on keys that differ between checkpoint and model.
        model.load_state_dict(state_dict, strict=False)

    return model
|
|
|
|
|
|
def dinov2_vits14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """Create a DPT depther on the ``vit_small`` architecture.

    Thin wrapper around ``_make_dinov2_dpt_depther`` with ``arch_name`` fixed
    to ``"vit_small"``; all other arguments, including extra keyword
    arguments, are forwarded unchanged.
    """
    return _make_dinov2_dpt_depther(
        arch_name="vit_small",
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
|
|
|
|
|
def dinov2_vitb14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """Create a DPT depther on the ``vit_base`` architecture.

    Thin wrapper around ``_make_dinov2_dpt_depther`` with ``arch_name`` fixed
    to ``"vit_base"``; all other arguments, including extra keyword
    arguments, are forwarded unchanged.
    """
    return _make_dinov2_dpt_depther(
        arch_name="vit_base",
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
|
|
|
|
|
def dinov2_vitl14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """Create a DPT depther on the ``vit_large`` architecture.

    Thin wrapper around ``_make_dinov2_dpt_depther`` with ``arch_name`` fixed
    to ``"vit_large"``; all other arguments, including extra keyword
    arguments, are forwarded unchanged.
    """
    return _make_dinov2_dpt_depther(
        arch_name="vit_large",
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
|
|
|
|
|
def dinov2_vitg14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    """Create a DPT depther on the ``vit_giant2`` architecture.

    Thin wrapper around ``_make_dinov2_dpt_depther`` with ``arch_name`` fixed
    to ``"vit_giant2"`` and ``ffn_layer`` fixed to ``"swiglufused"``; all
    other arguments are forwarded unchanged. Because ``ffn_layer`` is passed
    explicitly, supplying it again through ``**kwargs`` raises ``TypeError``.
    """
    return _make_dinov2_dpt_depther(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )
|
|
|