# NOTE(review): the six lines below were web-UI scrape residue (uploader name,
# commit message, commit hash, "raw" / "history blame" buttons, file size) that
# made this file unparseable as Python; preserved here as comments.
# xiaoyuxi
# Cleaned history, reset to current state
# c8d9d42
# raw
# history blame
# 3.46 kB
import argparse
import cv2
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Compose
from models.monoD.depth_anything.dpt import DPT_DINOv2
from models.monoD.depth_anything.util.transform import (
Resize, NormalizeImage, PrepareForNet
)
def build(config):
    """
    Build a DPT_DINOv2 depth model and load its pretrained weights.

    Args:
        config: object with attributes
            - encoder: one of 'vits', 'vitb', 'vitl'
            - load_from: path to the pretrained checkpoint (a state_dict)
            - localhub: forwarded to DPT_DINOv2 (hub loading mode)

    Returns:
        The model on the default CUDA device, in eval mode.

    Raises:
        AssertionError: if config.encoder is not a supported encoder name.
    """
    args = config
    # Per-encoder architecture hyperparameters; everything else about the
    # construction is identical, so a table beats a three-way if/elif.
    arch = {
        'vits': dict(features=64, out_channels=[48, 96, 192, 384]),
        'vitb': dict(features=128, out_channels=[96, 192, 384, 768]),
        'vitl': dict(features=256, out_channels=[256, 512, 1024, 1024]),
    }
    assert args.encoder in arch
    depth_anything = DPT_DINOv2(encoder=args.encoder,
                                localhub=args.localhub,
                                **arch[args.encoder]).cuda()
    # Load on CPU first (avoids a GPU memory spike), then strict key matching
    # so a mismatched checkpoint fails loudly instead of silently.
    depth_anything.load_state_dict(torch.load(args.load_from,
                                              map_location='cpu'), strict=True)
    total_params = sum(param.numel() for param in depth_anything.parameters())
    print('Total parameters: {:.2f}M'.format(total_params / 1e6))
    depth_anything.eval()
    return depth_anything
class DepthAnything(nn.Module):
    """Wrapper around a DPT_DINOv2 depth model adding the inference-time
    preprocessing pipeline (resize + ImageNet normalization)."""

    def __init__(self, args):
        super(DepthAnything, self).__init__()
        # build the chosen model (see build(): encoder / load_from / localhub)
        self.dpAny = build(args)

    def infer(self, rgbs):
        """
        Infer a depth map for a batch of RGB frames.

        Args:
            rgbs: B x 3 x H x W CUDA tensor; values are normalized below with
                ImageNet mean/std, so they are assumed to be in [0, 1]
                — TODO confirm range against callers.

        Returns:
            B x 1 x H x W tensor at the input resolution. This is the raw
            network output, presumably disparity / inverse depth — confirm
            against the DPT model's head; no clamping is applied here.

        Asserts:
            the input is a 4-D CUDA tensor.
        """
        assert (rgbs.is_cuda)&(len(rgbs.shape) == 4)
        T, C, H, W = rgbs.shape
        # prepare the input
        Resizer = Resize(
            width=518,
            height=518,
            resize_target=False,
            keep_aspect_ratio=True,
            ensure_multiple_of=14,
            resize_method='lower_bound',
            image_interpolation_method=cv2.INTER_CUBIC,
        )
        # NOTE: step 1 Resize — compute the network input size (short side
        # >= 518, aspect ratio kept, dims a multiple of 14).
        # BUGFIX: Resize.get_size(width, height) expects (W, H); the original
        # passed shape[2] (=H) as width and shape[3] (=W) as height, which
        # distorted the aspect ratio for non-square inputs.
        width, height = Resizer.get_size(W, H)
        rgbs = F.interpolate(
            rgbs, (int(height), int(width)), mode='bicubic', align_corners=False
        )
        # NOTE: step 2 NormalizeImage — ImageNet statistics, broadcast over
        # batch and spatial dims.
        mean_ = torch.tensor([0.485, 0.456, 0.406],
                             device=rgbs.device).view(1, 3, 1, 1)
        std_ = torch.tensor([0.229, 0.224, 0.225],
                            device=rgbs.device).view(1, 3, 1, 1)
        rgbs = (rgbs - mean_)/std_
        # NOTE: step 3 PrepareForNet — the tensor is already CHW float, so no
        # further conversion is needed; run the network.
        disp = self.dpAny(rgbs)
        # Upsample the (B x h x w) prediction back to the caller's resolution.
        disp = F.interpolate(
            disp[:,None], (H, W),
            mode='bilinear', align_corners=False
        )
        # (The original comment claimed a 100x near/far clamp here, but none
        # was ever applied; the output is returned unclamped.)
        depth_map = disp
        return depth_map