# Hosting-page metadata captured with this file (not part of the program):
# uploaded by aki-0421, commit "F: add" (a3a3ae4, unverified), 5.21 kB.
# -*- coding: utf-8 -*-
# based on https://github.com/isl-org/MiDaS
import cv2
import torch
import torch.nn as nn
from torchvision.transforms import Compose
from .dpt_depth import DPTDepthModel
from .midas_net import MidasNet
from .midas_net_custom import MidasNet_small
from .transforms import NormalizeImage, PrepareForNet, Resize
# ISL_PATHS = {
# "dpt_large": "dpt_large-midas-2f21e586.pt",
# "dpt_hybrid": "dpt_hybrid-midas-501f0c75.pt",
# "midas_v21": "",
# "midas_v21_small": "",
# }
# remote_model_path =
# "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt"
def disabled_train(self, mode=True):
    """No-op stand-in for a model's ``train`` method.

    Assign this over ``model.train`` to freeze the current train/eval state:
    the requested ``mode`` is ignored and the first argument is handed back
    unchanged.
    """
    return self
def load_midas_transform(model_type):
    """Build the input preprocessing pipeline for a MiDaS model variant.

    Only the transform is created; no network is instantiated. The settings
    follow https://github.com/isl-org/MiDaS/blob/master/run.py.

    Args:
        model_type: One of 'dpt_large', 'dpt_hybrid', 'midas_v21' or
            'midas_v21_small'.

    Returns:
        A torchvision ``Compose`` applying ``Resize``, ``NormalizeImage`` and
        ``PrepareForNet`` in that order.

    Raises:
        AssertionError: If ``model_type`` is not a supported name.
    """
    # Per-model settings: (net_w, net_h, resize_mode, mean, std). The table is
    # rebuilt on every call so each NormalizeImage gets its own fresh lists.
    settings = {
        'dpt_large': (384, 384, 'minimal',
                      [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        'dpt_hybrid': (384, 384, 'minimal',
                       [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        'midas_v21': (384, 384, 'upper_bound',
                      [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        'midas_v21_small': (256, 256, 'upper_bound',
                            [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    }

    try:
        net_w, net_h, resize_mode, mean, std = settings[model_type]
    except (KeyError, TypeError):
        # Raised explicitly rather than via `assert`, so the check is not
        # stripped under `python -O`. AssertionError is kept because that is
        # what this function has always raised for an unknown model type.
        raise AssertionError(
            f"model_type '{model_type}' not implemented, use one of: "
            f"{', '.join(settings)}"
        ) from None

    normalization = NormalizeImage(mean=mean, std=std)
    return Compose([
        Resize(
            net_w,
            net_h,
            resize_target=None,
            keep_aspect_ratio=True,
            ensure_multiple_of=32,
            resize_method=resize_mode,
            image_interpolation_method=cv2.INTER_CUBIC,
        ),
        normalization,
        PrepareForNet(),
    ])
def load_model(model_type, model_path):
    """Instantiate a MiDaS network and its matching input transform.

    The network settings follow
    https://github.com/isl-org/MiDaS/blob/master/run.py.

    Args:
        model_type: One of 'dpt_large', 'dpt_hybrid', 'midas_v21' or
            'midas_v21_small'.
        model_path: Checkpoint location, forwarded unchanged to the network
            constructor.

    Returns:
        A ``(model, transform)`` tuple, where ``model`` is the result of
        calling ``.eval()`` on the network and ``transform`` is the
        preprocessing pipeline for ``model_type``.

    Raises:
        AssertionError: If ``model_type`` is not a supported name.
    """
    if model_type == 'dpt_large':  # DPT-Large
        model = DPTDepthModel(
            path=model_path,
            backbone='vitl16_384',
            non_negative=True,
        )
    elif model_type == 'dpt_hybrid':  # DPT-Hybrid
        model = DPTDepthModel(
            path=model_path,
            backbone='vitb_rn50_384',
            non_negative=True,
        )
    elif model_type == 'midas_v21':
        model = MidasNet(model_path, non_negative=True)
    elif model_type == 'midas_v21_small':
        model = MidasNet_small(
            model_path,
            features=64,
            backbone='efficientnet_lite3',
            exportable=True,
            non_negative=True,
            blocks={'expand': True},
        )
    else:
        # Raised explicitly rather than via `assert False`, so the check is
        # not stripped under `python -O`. AssertionError is kept because that
        # is what this function has always raised for an unknown model type;
        # the message now travels with the exception instead of being printed.
        raise AssertionError(
            f"model_type '{model_type}' not implemented, use one of: "
            "dpt_large, dpt_hybrid, midas_v21, midas_v21_small"
        )

    # The preprocessing settings are the same ones load_midas_transform uses,
    # so build the transform there instead of keeping a second copy here.
    transform = load_midas_transform(model_type)
    return model.eval(), transform
class MiDaSInference(nn.Module):
    """Inference-only wrapper around a MiDaS network.

    The wrapped network is obtained from ``load_model`` and its ``train``
    attribute is replaced with ``disabled_train``. ``forward`` runs the
    network under ``torch.no_grad()``.
    """

    # Not referenced inside this class; kept as a public attribute.
    MODEL_TYPES_TORCH_HUB = ['DPT_Large', 'DPT_Hybrid', 'MiDaS_small']
    # Model type names accepted by __init__.
    MODEL_TYPES_ISL = [
        'dpt_large',
        'dpt_hybrid',
        'midas_v21',
        'midas_v21_small',
    ]

    def __init__(self, model_type, model_path):
        """Load the network for ``model_type`` from ``model_path``.

        Args:
            model_type: One of ``MODEL_TYPES_ISL``.
            model_path: Checkpoint location, forwarded to ``load_model``.

        Raises:
            AssertionError: If ``model_type`` is not in ``MODEL_TYPES_ISL``.
        """
        super().__init__()
        if model_type not in self.MODEL_TYPES_ISL:
            # Raised explicitly rather than via `assert`, so the check is not
            # stripped under `python -O`. The exception type is unchanged.
            raise AssertionError(
                f"model_type '{model_type}' not supported, use one of: "
                f"{', '.join(self.MODEL_TYPES_ISL)}"
            )
        # The transform returned alongside the model is not needed here.
        model, _ = load_model(model_type, model_path)
        self.model = model
        # Override `train` on the wrapped network so later `.train(mode)`
        # calls no longer switch its mode. NOTE: the function is stored on the
        # instance, so it is not bound; the first positional argument of a
        # call is what it receives as `self` (and returns).
        self.model.train = disabled_train

    def forward(self, x):
        """Run the wrapped network on ``x`` with gradient tracking disabled."""
        with torch.no_grad():
            return self.model(x)