Spaces:
Runtime error
Runtime error
# MIT License | |
# Copyright (c) 2022 Intelligent Systems Lab Org | |
# Permission is hereby granted, free of charge, to any person obtaining a copy | |
# of this software and associated documentation files (the "Software"), to deal | |
# in the Software without restriction, including without limitation the rights | |
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
# copies of the Software, and to permit persons to whom the Software is | |
# furnished to do so, subject to the following conditions: | |
# The above copyright notice and this permission notice shall be included in all | |
# copies or substantial portions of the Software. | |
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
# SOFTWARE. | |
# File author: Shariq Farooq Bhat | |
import itertools | |
import math | |
import torch | |
import torch.nn as nn | |
import numpy as np | |
from zoedepth.models.depth_model import DepthModel | |
from zoedepth.models.base_models.midas import MidasCore | |
from zoedepth.models.layers.attractor import AttractorLayer, AttractorLayerUnnormed | |
from zoedepth.models.layers.dist_layers import ConditionalLogBinomial, ConditionalLogBinomialV2 | |
from zoedepth.models.layers.localbins_layers import (Projector, SeedBinRegressor, | |
SeedBinRegressorUnnormed) | |
from zoedepth.models.model_io import load_state_from_resource | |
from torchvision.transforms import Normalize | |
def get_activation(name, bank): | |
# input of forward_hook will be a function of model/inp/oup | |
def hook(module, input, output): | |
bank[name] = output | |
return hook | |
class ZoeDepthCustom(DepthModel): | |
def __init__(self, core, n_bins=64, bin_centers_type="softplus", bin_embedding_dim=128, min_depth=1e-3, max_depth=10, | |
n_attractors=[16, 8, 4, 1], attractor_alpha=300, attractor_gamma=2, attractor_kind='sum', attractor_type='exp', min_temp=5, max_temp=50, train_midas=True, | |
midas_lr_factor=10, encoder_lr_factor=10, pos_enc_lr_factor=10, inverse_midas=False, sr_ratio=1, | |
raw_depth_shape=(2160, 3840), transform_sample_gt_size=(2160, 3840), representation='', | |
fetch_features=True, sample_feat_level=3, use_hr=False, deform=False, wo_bins=False, baseline=False, **kwargs): | |
"""ZoeDepth model. This is the version of ZoeDepth that has a single metric head | |
Args: | |
core (models.base_models.midas.MidasCore): The base midas model that is used for extraction of "relative" features | |
n_bins (int, optional): Number of bin centers. Defaults to 64. | |
bin_centers_type (str, optional): "normed" or "softplus". Activation type used for bin centers. For "normed" bin centers, linear normalization trick is applied. This results in bounded bin centers. | |
For "softplus", softplus activation is used and thus are unbounded. Defaults to "softplus". | |
bin_embedding_dim (int, optional): bin embedding dimension. Defaults to 128. | |
min_depth (float, optional): Lower bound for normed bin centers. Defaults to 1e-3. | |
max_depth (float, optional): Upper bound for normed bin centers. Defaults to 10. | |
n_attractors (List[int], optional): Number of bin attractors at decoder layers. Defaults to [16, 8, 4, 1]. | |
attractor_alpha (int, optional): Proportional attractor strength. Refer to models.layers.attractor for more details. Defaults to 300. | |
attractor_gamma (int, optional): Exponential attractor strength. Refer to models.layers.attractor for more details. Defaults to 2. | |
attractor_kind (str, optional): Attraction aggregation "sum" or "mean". Defaults to 'sum'. | |
attractor_type (str, optional): Type of attractor to use; "inv" (Inverse attractor) or "exp" (Exponential attractor). Defaults to 'exp'. | |
min_temp (int, optional): Lower bound for temperature of output probability distribution. Defaults to 5. | |
max_temp (int, optional): Upper bound for temperature of output probability distribution. Defaults to 50. | |
train_midas (bool, optional): Whether to train "core", the base midas model. Defaults to True. | |
midas_lr_factor (int, optional): Learning rate reduction factor for base midas model except its encoder and positional encodings. Defaults to 10. | |
encoder_lr_factor (int, optional): Learning rate reduction factor for the encoder in midas model. Defaults to 10. | |
pos_enc_lr_factor (int, optional): Learning rate reduction factor for positional encodings in the base midas model. Defaults to 10. | |
sr_ratio: sr ratio during infer | |
raw_depth_shape: raw depth shape during infer. times sr_ratio will be the target resolution. Used to sample points during training | |
transform_sample_gt_size: training depth shape # influenced by crop shape which is not included in this pipeline right now | |
representation: I use it to test the "bilap head" and a discarded idea | |
fetch_features: if fetch feats. Default=True | |
""" | |
super().__init__() | |
self.core = core | |
self.max_depth = max_depth | |
self.min_depth = min_depth | |
self.min_temp = min_temp | |
self.bin_centers_type = bin_centers_type | |
self.midas_lr_factor = midas_lr_factor | |
self.encoder_lr_factor = encoder_lr_factor | |
self.pos_enc_lr_factor = pos_enc_lr_factor | |
self.train_midas = train_midas | |
self.inverse_midas = inverse_midas | |
if self.encoder_lr_factor <= 0: | |
self.core.freeze_encoder( | |
freeze_rel_pos=self.pos_enc_lr_factor <= 0) | |
N_MIDAS_OUT = 32 | |
btlnck_features = self.core.output_channels[0] | |
num_out_features = self.core.output_channels[1:] | |
self.conv2 = nn.Conv2d(btlnck_features, btlnck_features, | |
kernel_size=1, stride=1, padding=0) # btlnck conv | |
if bin_centers_type == "normed": | |
SeedBinRegressorLayer = SeedBinRegressor | |
Attractor = AttractorLayer | |
elif bin_centers_type == "softplus": # default | |
SeedBinRegressorLayer = SeedBinRegressorUnnormed | |
Attractor = AttractorLayerUnnormed | |
elif bin_centers_type == "hybrid1": | |
SeedBinRegressorLayer = SeedBinRegressor | |
Attractor = AttractorLayerUnnormed | |
elif bin_centers_type == "hybrid2": | |
SeedBinRegressorLayer = SeedBinRegressorUnnormed | |
Attractor = AttractorLayer | |
else: | |
raise ValueError( | |
"bin_centers_type should be one of 'normed', 'softplus', 'hybrid1', 'hybrid2'") | |
self.seed_bin_regressor = SeedBinRegressorLayer( | |
btlnck_features, n_bins=n_bins, min_depth=min_depth, max_depth=max_depth) | |
self.seed_projector = Projector(btlnck_features, bin_embedding_dim) | |
self.projectors = nn.ModuleList([ | |
Projector(num_out, bin_embedding_dim) | |
for num_out in num_out_features | |
]) | |
# 1000, 2, inv, mean | |
self.attractors = nn.ModuleList([ | |
Attractor(bin_embedding_dim, n_bins, n_attractors=n_attractors[i], min_depth=min_depth, max_depth=max_depth, | |
alpha=attractor_alpha, gamma=attractor_gamma, kind=attractor_kind, attractor_type=attractor_type) | |
for i in range(len(num_out_features)) | |
]) | |
last_in = N_MIDAS_OUT + 1 # +1 for relative depth | |
# use log binomial instead of softmax | |
self.conditional_log_binomial = ConditionalLogBinomial( | |
last_in, bin_embedding_dim, n_classes=n_bins, min_temp=min_temp, max_temp=max_temp) | |
self.handles = [] | |
self.hook_feats = {} | |
self.set_fetch_features(fetch_features) | |
self.baseline = baseline | |
def init_weight(self): | |
if self.deform: | |
for m in self.mlp_feat_offset.modules(): | |
if isinstance(m, nn.Conv1d): | |
torch.nn.init.constant_(m.weight, 0) | |
if self.representation == 'biLaplacian': | |
for m in self.proj_x_block: | |
torch.nn.init.constant_(m.weight, 0) | |
def forward(self, x, sampled_depth=None, mode='train', return_final_centers=False, denorm=False, return_probs=False, image_raw=None, **kwargs): | |
""" | |
Args: | |
x (torch.Tensor): Input image tensor of shape (B, C, H, W) | |
return_final_centers (bool, optional): Whether to return the final bin centers. Defaults to False. | |
denorm (bool, optional): Whether to denormalize the input image. This reverses ImageNet normalization as midas normalization is different. Defaults to False. | |
return_probs (bool, optional): Whether to return the output probability distribution. Defaults to False. | |
Returns: | |
dict: Dictionary containing the following keys: | |
- rel_depth (torch.Tensor): Relative depth map of shape (B, H, W) | |
- metric_depth (torch.Tensor): Metric depth map of shape (B, 1, H, W) | |
- bin_centers (torch.Tensor): Bin centers of shape (B, n_bins). Present only if return_final_centers is True | |
- probs (torch.Tensor): Output probability distribution of shape (B, n_bins, H, W). Present only if return_probs is True | |
""" | |
b, c, h, w = x.shape | |
# print("input shape ", x.shape) | |
self.orig_input_width = w | |
self.orig_input_height = h | |
rel_depth, out = self.core(x, denorm=denorm, return_rel_depth=True) | |
# print("output shapes", rel_depth.shape, out.shape) | |
outconv_activation = out[0] | |
btlnck = out[1] | |
x_blocks = out[2:] | |
x_d0 = self.conv2(btlnck) | |
self.hook_feats['x_d0'] = x_d0 | |
x = x_d0 | |
last = outconv_activation | |
self.hook_feats['midas_final_feat'] = last | |
if self.inverse_midas: | |
# invert depth followed by normalization | |
rel_depth = 1.0 / (rel_depth + 1e-6) | |
rel_depth = (rel_depth - rel_depth.min()) / \ | |
(rel_depth.max() - rel_depth.min()) | |
# concat rel depth with last. First interpolate rel depth to last size | |
rel_cond = rel_depth.unsqueeze(1) | |
self.hook_feats['rel_depth'] = rel_cond | |
_, seed_b_centers = self.seed_bin_regressor(x) | |
if self.bin_centers_type == 'normed' or self.bin_centers_type == 'hybrid2': | |
b_prev = (seed_b_centers - self.min_depth) / \ | |
(self.max_depth - self.min_depth) | |
else: | |
b_prev = seed_b_centers | |
prev_b_embedding = self.seed_projector(x) | |
# unroll this loop for better performance | |
for idx, (projector, attractor, x) in enumerate(zip(self.projectors, self.attractors, x_blocks)): | |
b_embedding = projector(x) | |
self.hook_feats['x_blocks_feat_{}'.format(idx)] = x | |
# self.hook_feats['attractor_{}_0'.format(idx)] = b_embedding | |
# self.hook_feats['attractor_{}_1'.format(idx)] = b_prev | |
# self.hook_feats['attractor_{}_2'.format(idx)] = prev_b_embedding | |
b, b_centers = attractor( | |
b_embedding, b_prev, prev_b_embedding, interpolate=True) | |
b_prev = b.clone() | |
prev_b_embedding = b_embedding.clone() | |
self.hook_feats['b_centers'] = b_centers | |
if self.baseline: | |
rel_cond = nn.functional.interpolate( | |
rel_cond, size=last.shape[2:], mode='bilinear', align_corners=True) | |
last = torch.cat([last, rel_cond], dim=1) | |
b_embedding = nn.functional.interpolate( | |
b_embedding, last.shape[-2:], mode='bilinear', align_corners=True) | |
# till here, we have features (attached with a relative depth prediction) and embeddings | |
x = self.conditional_log_binomial(last, b_embedding) | |
b_centers = nn.functional.interpolate( | |
b_centers, x.shape[-2:], mode='bilinear', align_corners=True) | |
self.hook_feats['b_centers'] = b_centers | |
out = torch.sum(x * b_centers, dim=1, keepdim=True) | |
output = dict(metric_depth=out) | |
return output | |
# if mode == 'train': | |
# if self.wo_bins and self.representation == 'biLaplacian': | |
# pred_depth_dict = self.implicit_function_train_process(sampled_depth, mode=mode) | |
# output = pred_depth_dict | |
# else: | |
# pred_depth = self.implicit_function_train_process(sampled_depth, mode=mode) | |
# output = dict(metric_depth=pred_depth) | |
# elif mode == 'eval': | |
# pred_depth = self.implicit_function_infer_process(mode=mode) | |
# output = dict(metric_depth=pred_depth) | |
# else: | |
# raise NotImplementedError | |
return output | |
# change param lr later | |
def get_lr_params(self, lr): | |
""" | |
Learning rate configuration for different layers of the model | |
Args: | |
lr (float) : Base learning rate | |
Returns: | |
list : list of parameters to optimize and their learning rates, in the format required by torch optimizers. | |
""" | |
param_conf = [] | |
if self.train_midas: | |
if self.encoder_lr_factor > 0: | |
param_conf.append({'params': self.core.get_enc_params_except_rel_pos( | |
), 'lr': lr / self.encoder_lr_factor}) | |
if self.pos_enc_lr_factor > 0: | |
param_conf.append( | |
{'params': self.core.get_rel_pos_params(), 'lr': lr / self.pos_enc_lr_factor}) | |
midas_params = self.core.core.scratch.parameters() | |
midas_lr_factor = self.midas_lr_factor | |
param_conf.append( | |
{'params': midas_params, 'lr': lr / midas_lr_factor}) | |
remaining_modules = [] | |
for name, child in self.named_children(): | |
if name != 'core': | |
remaining_modules.append(child) | |
remaining_params = itertools.chain( | |
*[child.parameters() for child in remaining_modules]) | |
param_conf.append({'params': remaining_params, 'lr': lr}) | |
return param_conf | |
def build(midas_model_type="DPT_BEiT_L_384", pretrained_resource=None, use_pretrained_midas=False, train_midas=False, freeze_midas_bn=True, **kwargs): | |
core = MidasCore.build(midas_model_type=midas_model_type, use_pretrained_midas=use_pretrained_midas, | |
train_midas=train_midas, fetch_features=True, freeze_bn=freeze_midas_bn, **kwargs) | |
model = ZoeDepthCustom(core, **kwargs) | |
if pretrained_resource: | |
assert isinstance(pretrained_resource, str), "pretrained_resource must be a string" | |
model = load_state_from_resource(model, pretrained_resource) | |
# a = torch.ones((2, 3, 100, 200)) | |
# b = torch.ones((2, 500, 3)) | |
# model(a, b) | |
return model | |
def build_from_config(config): | |
return ZoeDepthCustom.build(**config) | |
def remove_hooks(self): | |
for h in self.handles: | |
h.remove() | |
return self | |
def set_fetch_features(self, fetch_features): | |
self.fetch_features = fetch_features | |
if fetch_features: | |
if len(self.handles) == 0: | |
self.attach_hooks() | |
else: | |
self.remove_hooks() | |
return self | |
def attach_hooks(self): | |
self.handles.append(self.seed_projector.register_forward_hook(get_activation("seed_projector", self.hook_feats))) | |
for idx, proj in enumerate(self.projectors): | |
self.handles.append(proj.register_forward_hook(get_activation("projector_{}".format(idx), self.hook_feats))) | |
for idx, proj in enumerate(self.attractors): | |
self.handles.append(proj.register_forward_hook(get_activation("attractor_{}".format(idx), self.hook_feats))) | |
return self | |