# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# File author: Zhenyu Li

import itertools

import math
import copy
import torch
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from mmengine import print_log
from mmengine.config import ConfigDict
from torchvision.ops import roi_align as torch_roi_align
from huggingface_hub import PyTorchModelHubMixin
from transformers import PretrainedConfig

from estimator.registry import MODELS
from estimator.models import build_model
from estimator.models.baseline_pretrain import BaselinePretrain
from estimator.models.utils import generatemask

from zoedepth.models.zoedepth import ZoeDepth
from zoedepth.models.layers.attractor import AttractorLayer, AttractorLayerUnnormed
from zoedepth.models.layers.dist_layers import ConditionalLogBinomial
from zoedepth.models.layers.localbins_layers import (Projector, SeedBinRegressor, SeedBinRegressorUnnormed)
from zoedepth.models.base_models.midas import Resize as ResizeZoe
from depth_anything.transform import Resize as ResizeDA


# Keys read from a branch's ``temp_features`` dict, in the order the fusion
# layers consume them.
# x_d0 1/128, x_blocks_feat_0 1/64, x_blocks_feat_1 1/32, x_blocks_feat_2 1/16,
# x_blocks_feat_3 1/8, midas_final_feat 1/4 [based on 384x4, 512x4]
_BRANCH_FEATURE_KEYS = (
    'x_d0',
    'x_blocks_feat_0',
    'x_blocks_feat_1',
    'x_blocks_feat_2',
    'x_blocks_feat_3',
    'midas_final_feat',
)

# bin_centers_type -> (seed bin regressor class, attractor class)
_BIN_HEAD_LAYERS = {
    "normed": (SeedBinRegressor, AttractorLayer),
    "softplus": (SeedBinRegressorUnnormed, AttractorLayerUnnormed),  # default
    "hybrid1": (SeedBinRegressor, AttractorLayerUnnormed),
    "hybrid2": (SeedBinRegressorUnnormed, AttractorLayer),
}


def _run_frozen_branch(branch, image):
    """Run ``branch`` on ``image`` without gradients and collect its outputs.

    The branch is switched to eval mode first if it is currently in training
    mode (this mode change persists after the call).

    Returns:
        tuple: ``(metric_depth, features)`` where ``features`` is the list of
        ``temp_features`` entries selected by ``_BRANCH_FEATURE_KEYS``.
    """
    with torch.no_grad():
        if branch.training:
            branch.eval()

        output_dict = branch(image, return_final_centers=True)
        temp_features = output_dict['temp_features']
        prediction = output_dict['metric_depth']
        features = [temp_features[key] for key in _BRANCH_FEATURE_KEYS]  # bs, c, h, w
        return prediction, features


def _roi_align_same_size(tensor, rois, patch_height):
    """RoI-align ``tensor`` so every ROI is resampled to ``tensor``'s own H x W.

    ``rois`` are expressed in patch-process pixel coordinates, so the spatial
    scale handed to ``roi_align`` is ``tensor_height / patch_height``.
    """
    _, _, height, width = tensor.shape
    return torch_roi_align(
        tensor, rois, (height, width), height / patch_height, aligned=True)


@MODELS.register_module()
class PatchFusion(BaselinePretrain, PyTorchModelHubMixin):
    """Patch-wise depth model fusing a frozen coarse and a frozen fine branch.

    Both branches are built with ``ZoeDepth.build`` and have all parameters
    frozen. Their multi-level features are merged by per-level fusion convs and
    the ``guided_fusion`` network, then decoded by a ZoeDepth-style bin head
    (seed bin regressor, projectors, attractors, conditional log-binomial).

    ``prepare_tile_cfg``, ``regular_tile`` and ``random_tile`` are not defined
    in this class; presumably they are inherited from ``BaselinePretrain``.
    """

    def __init__(
        self,
        config,):
        """Build the branches, the fusion layers and the metric-bins head.

        Args:
            config: Either an mmengine ``ConfigDict`` (pretrained branch
                checkpoints listed in ``config.pretrain_model`` are loaded) or
                a plain mapping, as used when loading from the HF model space
                (branch checkpoints are not loaded and each branch's
                ``pretrained_resource`` is cleared). Fields read here:
                ``min_depth``, ``max_depth``, ``patch_process_shape``,
                ``image_raw_shape``, ``patch_split_num``, ``coarse_branch``,
                ``fine_branch``, ``pretrain_model``, ``sigloss`` and
                ``guided_fusion``.

        Raises:
            NotImplementedError: If a branch ``type`` is neither ``'ZoeDepth'``
                nor ``'DA-ZoeDepth'``.
            ValueError: If ``coarse_branch.bin_centers_type`` is unsupported.
        """
        nn.Module.__init__(self)

        if isinstance(config, ConfigDict):
            # convert a ConfigDict to a PretrainedConfig for hf saving
            config = PretrainedConfig.from_dict(config.to_dict())
            config.load_branch = True
        else:
            # used when loading patchfusion from hf model space
            config = PretrainedConfig.from_dict(ConfigDict(**config).to_dict())
            config.load_branch = False
            config.coarse_branch.pretrained_resource = None
            config.fine_branch.pretrained_resource = None

        self.config = config

        self.min_depth = config.min_depth
        self.max_depth = config.max_depth

        self.patch_process_shape = config.patch_process_shape
        self.tile_cfg = self.prepare_tile_cfg(config.image_raw_shape, config.patch_split_num)

        self.coarse_branch_cfg = config.coarse_branch
        coarse_type = config.coarse_branch.type
        if coarse_type == 'ZoeDepth':
            self.coarse_branch = ZoeDepth.build(**config.coarse_branch)
            self.resizer = ResizeZoe(
                config.patch_process_shape[1],
                config.patch_process_shape[0],
                keep_aspect_ratio=False,
                ensure_multiple_of=32,
                resize_method="minimal")
        elif coarse_type == 'DA-ZoeDepth':
            self.coarse_branch = ZoeDepth.build(**config.coarse_branch)
            self.resizer = ResizeDA(
                config.patch_process_shape[1],
                config.patch_process_shape[0],
                keep_aspect_ratio=False,
                ensure_multiple_of=14,
                resize_method="minimal")
        else:
            raise NotImplementedError(
                "Unsupported coarse_branch type: {!r}".format(coarse_type))

        fine_type = config.fine_branch.type
        if fine_type in ('ZoeDepth', 'DA-ZoeDepth'):
            self.fine_branch = ZoeDepth.build(**config.fine_branch)
        else:
            raise NotImplementedError(
                "Unsupported fine_branch type: {!r}".format(fine_type))

        if config.load_branch:
            # NOTE(review): torch.load unpickles the checkpoint file; only
            # point ``pretrain_model`` at checkpoints from a trusted source.
            print_log("Loading coarse_branch from {}".format(config.pretrain_model[0]), logger='current')
            coarse_ckpt = torch.load(config.pretrain_model[0], map_location='cpu')
            print_log(
                self.coarse_branch.load_state_dict(coarse_ckpt['model_state_dict'], strict=True),
                logger='current')  # coarse ckp
            print_log("Loading fine_branch from {}".format(config.pretrain_model[1]), logger='current')
            fine_ckpt = torch.load(config.pretrain_model[1], map_location='cpu')
            print_log(
                self.fine_branch.load_state_dict(fine_ckpt['model_state_dict'], strict=True),
                logger='current')

        # freeze all these parameters
        for branch in (self.coarse_branch, self.fine_branch):
            for param in branch.parameters():
                param.requires_grad = False

        self.sigloss = build_model(config.sigloss)

        N_MIDAS_OUT = 32
        btlnck_features = self.fine_branch.core.output_channels[0]

        # One fusion conv per feature level; only the last level
        # (midas_final_feat) uses N_MIDAS_OUT channels.
        num_levels = len(_BRANCH_FEATURE_KEYS)
        self.fusion_conv_list = nn.ModuleList()
        for level in range(num_levels):
            channels = N_MIDAS_OUT if level == num_levels - 1 else btlnck_features
            self.fusion_conv_list.append(nn.Conv2d(channels * 2, channels, 3, 1, 1))

        self.guided_fusion = build_model(config.guided_fusion)

        # NOTE: a decoder head
        branch_cfg = self.coarse_branch_cfg
        try:
            SeedBinRegressorLayer, Attractor = _BIN_HEAD_LAYERS[branch_cfg.bin_centers_type]
        except (KeyError, TypeError):
            raise ValueError(
                "bin_centers_type should be one of 'normed', 'softplus', 'hybrid1', 'hybrid2'") from None

        num_out_features = self.fine_branch.core.output_channels[1:]
        # all of them are the same

        self.seed_bin_regressor = SeedBinRegressorLayer(
            btlnck_features,
            n_bins=branch_cfg.n_bins,
            min_depth=config.min_depth,
            max_depth=config.max_depth)
        self.seed_projector = Projector(btlnck_features, branch_cfg.bin_embedding_dim)
        self.projectors = nn.ModuleList([
            Projector(num_out, branch_cfg.bin_embedding_dim)
            for num_out in num_out_features
        ])
        # 1000, 2, inv, mean
        self.attractors = nn.ModuleList([
            Attractor(
                branch_cfg.bin_embedding_dim,
                branch_cfg.n_bins,
                n_attractors=branch_cfg.n_attractors[i],
                min_depth=config.min_depth,
                max_depth=config.max_depth,
                alpha=branch_cfg.attractor_alpha,
                gamma=branch_cfg.attractor_gamma,
                kind=branch_cfg.attractor_kind,
                attractor_type=branch_cfg.attractor_type)
            for i in range(len(num_out_features))
        ])

        last_in = N_MIDAS_OUT + 1  # +1 for relative depth

        # use log binomial instead of softmax
        self.conditional_log_binomial = ConditionalLogBinomial(
            last_in,
            branch_cfg.bin_embedding_dim,
            n_classes=branch_cfg.n_bins,
            min_temp=branch_cfg.min_temp,
            max_temp=branch_cfg.max_temp)

        # NOTE: consistency training
        # ``consistency_target`` / ``consistency_projs`` are only read by
        # fusion_forward when this flag is True; they are not created here.
        self.consistency_training = False

    def load_dict(self, dict):
        """Load ``dict`` as a state dict with ``strict=False``.

        Non-strict loading allows checkpoints produced by ``get_save_dict``,
        which omit the coarse/fine branch weights.
        """
        return self.load_state_dict(dict, strict=False)

    def get_save_dict(self):
        """Return the state dict without the coarse/fine branch entries.

        Any key containing ``'coarse_branch'`` or ``'fine_branch'`` is dropped.
        """
        return {
            key: value
            for key, value in self.state_dict().items()
            if 'coarse_branch' not in key and 'fine_branch' not in key
        }

    def coarse_forward(self, image_lr):
        """Run the frozen coarse branch on ``image_lr`` under ``no_grad``.

        Returns:
            tuple: ``(coarse_prediction, coarse_features)``; the prediction is
            the branch's ``'metric_depth'`` output and the features are its six
            ``temp_features`` levels, ``x_d0`` first and ``midas_final_feat``
            last.
        """
        return _run_frozen_branch(self.coarse_branch, image_lr)

    def fine_forward(self, image_hr_crop):
        """Run the frozen fine branch on ``image_hr_crop`` under ``no_grad``.

        Returns:
            tuple: ``(fine_prediction, fine_features)`` with the same layout as
            the ``coarse_forward`` return value.
        """
        return _run_frozen_branch(self.fine_branch, image_hr_crop)

    def coarse_postprocess_train(self, coarse_prediction, coarse_features, bboxs, bboxs_feat):
        """Crop the coarse prediction and features to each training patch.

        Every tensor is RoI-aligned with ``bboxs_feat`` and resampled back to
        its own spatial size.

        Args:
            coarse_prediction: Coarse depth map from ``coarse_forward``.
            coarse_features: Coarse feature list from ``coarse_forward``.
            bboxs: Unused here; kept for interface compatibility.
            bboxs_feat: ROIs passed to torchvision ``roi_align`` (built in
                ``forward`` as ``[batch_index, 4 scaled box coords]`` rows).

        Returns:
            tuple: ``(coarse_prediction_roi, coarse_features_patch_area)``.
        """
        patch_height = self.patch_process_shape[0]
        coarse_features_patch_area = [
            _roi_align_same_size(feat, bboxs_feat, patch_height)
            for feat in coarse_features
        ]
        coarse_prediction_roi = _roi_align_same_size(coarse_prediction, bboxs_feat, patch_height)
        return coarse_prediction_roi, coarse_features_patch_area

    def coarse_postprocess_test(self, coarse_prediction, coarse_features, bboxs, bboxs_feat):
        """Crop the coarse prediction and features to each inference patch.

        Unlike the training variant, every tensor is first repeated
        ``bboxs_feat.shape[0]`` times along the batch dimension before being
        RoI-aligned (presumably so each ROI's batch index addresses its own
        copy of the single coarse result; confirm against the caller).

        Args:
            coarse_prediction: Coarse depth map from ``coarse_forward``.
            coarse_features: Coarse feature list from ``coarse_forward``.
            bboxs: Unused here; kept for interface compatibility.
            bboxs_feat: ROIs passed to torchvision ``roi_align``.

        Returns:
            dict: ``{'coarse_depth_roi': ..., 'coarse_feats_roi': [...]}``,
            matching the keyword arguments of ``fusion_forward``.
        """
        patch_num = bboxs_feat.shape[0]
        patch_height = self.patch_process_shape[0]

        coarse_features_patch_area = [
            _roi_align_same_size(feat.repeat(patch_num, 1, 1, 1), bboxs_feat, patch_height)
            for feat in coarse_features
        ]

        coarse_prediction = coarse_prediction.repeat(patch_num, 1, 1, 1)
        coarse_prediction_roi = _roi_align_same_size(coarse_prediction, bboxs_feat, patch_height)

        return {
            'coarse_depth_roi': coarse_prediction_roi,
            'coarse_feats_roi': coarse_features_patch_area}

    def fusion_forward(self, fine_depth_pred, crop_input, coarse_model_midas_enc_feats, fine_model_midas_enc_feats, bbox_feat, coarse_depth_roi=None, coarse_feats_roi=None):
        """Fuse coarse and fine information for a batch of patches.

        Args:
            fine_depth_pred: Fine-branch depth prediction for the patches.
            crop_input: The patch images fed to the fine branch.
            coarse_model_midas_enc_feats: Whole-image coarse feature list.
            fine_model_midas_enc_feats: Fine-branch feature list.
            bbox_feat: ROIs forwarded to ``guided_fusion`` as ``bbox``.
            coarse_depth_roi: Coarse depth cropped to the patches. Required
                despite the ``None`` default.
            coarse_feats_roi: Coarse features cropped to the patches. Required
                despite the ``None`` default.

        Returns:
            tuple: ``(out, proj_feat_list)`` where ``out`` is the depth map
            obtained as the probability-weighted sum of bin centers and
            ``proj_feat_list`` is empty unless consistency training is active.

        Raises:
            TypeError: If ``coarse_depth_roi`` or ``coarse_feats_roi`` is
                ``None``.
        """
        if coarse_depth_roi is None or coarse_feats_roi is None:
            # Same exception type the original raised later from zip/torch.cat,
            # but with an actionable message.
            raise TypeError(
                "fusion_forward requires both coarse_depth_roi and coarse_feats_roi")

        feat_cat_list = []
        feat_plus_list = []
        for level, (_, f_c_roi, f_f) in enumerate(zip(coarse_model_midas_enc_feats, coarse_feats_roi, fine_model_midas_enc_feats)):
            feat_cat_list.append(self.fusion_conv_list[level](torch.cat([f_c_roi, f_f], dim=1)))
            feat_plus_list.append(f_c_roi + f_f)

        input_tensor = torch.cat([coarse_depth_roi, fine_depth_pred, crop_input], dim=1)

        # HACK: hack for depth-anything
        # if self.coarse_branch_cfg.type == 'DA-ZoeDepth':
        #     input_tensor = F.interpolate(input_tensor, size=(448, 592), mode='bilinear', align_corners=True)

        output = self.guided_fusion(
            input_tensor=input_tensor,
            guide_plus=feat_plus_list,
            guide_cat=feat_cat_list,
            bbox=bbox_feat,
            fine_feat_crop=fine_model_midas_enc_feats,
            coarse_feat_whole=coarse_model_midas_enc_feats,
            coarse_feat_crop=coarse_feats_roi,
            coarse_feat_whole_hack=None)[::-1]  # low -> high

        x = output[0]
        x_blocks = output[1:]

        proj_feat_list = []
        if self.consistency_training and self.consistency_target == 'unet_feat':
            proj_feat_list = [
                self.consistency_projs[idx](feat) for idx, feat in enumerate(output)
            ]

        # NOTE: below is ZoeDepth implementation
        last = x_blocks[-1]  # have already been fused in x_blocks
        bs, _, h, w = last.shape
        rel_cond = torch.zeros((bs, 1, h, w), device=last.device)
        _, seed_b_centers = self.seed_bin_regressor(x)

        if self.coarse_branch_cfg.bin_centers_type in ('normed', 'hybrid2'):
            # these attractor variants consume centers normalized to [0, 1]
            b_prev = (seed_b_centers - self.min_depth) / (self.max_depth - self.min_depth)
        else:
            b_prev = seed_b_centers

        prev_b_embedding = self.seed_projector(x)

        # unroll this loop for better performance
        for projector, attractor, block_feat in zip(self.projectors, self.attractors, x_blocks):
            b_embedding = projector(block_feat)
            b, b_centers = attractor(
                b_embedding, b_prev, prev_b_embedding, interpolate=True)
            b_prev = b.clone()
            prev_b_embedding = b_embedding.clone()

        if self.consistency_training and self.consistency_target == 'final_feat':
            proj_feat_list = [
                self.consistency_projs[0](b_centers),
                self.consistency_projs[1](last),
                self.consistency_projs[2](b_embedding),
            ]

        rel_cond = nn.functional.interpolate(
            rel_cond, size=last.shape[2:], mode='bilinear', align_corners=True)
        last = torch.cat([last, rel_cond], dim=1)  # + self.coarse_depth_proj(whole_depth_roi_pred) + self.fine_depth_proj(fine_depth_pred)
        b_embedding = nn.functional.interpolate(
            b_embedding, last.shape[-2:], mode='bilinear', align_corners=True)

        # till here, we have features (attached with a relative depth prediction) and embeddings
        # post process
        # final_pred = out * self.blur_mask + whole_depth_roi_pred * (1-self.blur_mask)
        # out = F.interpolate(out, (540, 960), mode='bilinear', align_corners=True)
        x = self.conditional_log_binomial(last, b_embedding)
        b_centers = nn.functional.interpolate(
            b_centers, x.shape[-2:], mode='bilinear', align_corners=True)
        out = torch.sum(x * b_centers, dim=1, keepdim=True)
        return out, proj_feat_list

    def infer_forward(self, imgs_crop, bbox_feat_forward, tile_temp, coarse_temp_dict):
        """Predict depth for a batch of patches at inference time.

        Args:
            imgs_crop: Patch images fed to the fine branch.
            bbox_feat_forward: ROIs forwarded to ``fusion_forward``.
            tile_temp: Dict holding the whole-image ``'coarse_features'``.
            coarse_temp_dict: Keyword arguments for ``fusion_forward``
                (the dict returned by ``coarse_postprocess_test`` has the
                matching keys).

        Returns:
            The fused depth prediction; consistency features are discarded.
        """
        fine_prediction, fine_features = self.fine_forward(imgs_crop)
        depth_prediction, _ = self.fusion_forward(
            fine_prediction,
            imgs_crop,
            tile_temp['coarse_features'],
            fine_features,
            bbox_feat_forward,
            **coarse_temp_dict)
        return depth_prediction

    def forward(
        self,
        mode,
        image_lr,
        image_hr,
        depth_gt=None,
        crops_image_hr=None,
        crop_depths=None,
        bboxs=None,
        tile_cfg=None,
        cai_mode='m1',
        process_num=4):
        """Training step (``mode == 'train'``) or tiled inference (otherwise).

        Args:
            mode: ``'train'`` selects the training path; any other value runs
                tiled inference.
            image_lr: Input of the coarse branch.
            image_hr: High-resolution image; inference requires batch size 1.
            depth_gt: Ground truth, only passed through in the inference
                return value.
            crops_image_hr: Training patches fed to the fine branch.
            crop_depths: Ground-truth depth of the training patches.
            bboxs: Training patch boxes in raw-image coordinates, one row per
                patch (assumed x1, y1, x2, y2 given the w/h/w/h scaling below;
                TODO confirm).
            tile_cfg: Optional dict with ``'image_raw_shape'`` and
                ``'patch_split_num'`` overriding the configured tiling.
            cai_mode: ``'m1'`` runs one regular tiling pass; ``'m2'`` adds three
                half-patch-shifted passes; ``'r<N>'`` additionally runs
                ``N // process_num`` random tiling rounds.
            process_num: Forwarded to the tiling helpers and used to derive the
                number of random rounds.

        Returns:
            tuple: In training, ``(loss_dict, {'rgb', 'depth_pred',
            'depth_gt'})``; in inference, ``(depth, {'rgb', 'depth_pred',
            'depth_gt'})`` where ``depth`` is the averaged map with two leading
            singleton dimensions added.
        """
        if mode == 'train':
            # Scale raw-image box coordinates to patch-process coordinates.
            raw_shape = self.tile_cfg['image_raw_shape']
            scale_w = 1 / raw_shape[1] * self.patch_process_shape[1]
            scale_h = 1 / raw_shape[0] * self.patch_process_shape[0]
            bboxs_feat_factor = torch.tensor(
                [scale_w, scale_h, scale_w, scale_h],
                device=bboxs.device).unsqueeze(dim=0)
            bboxs_feat = bboxs * bboxs_feat_factor
            # roi_align expects the batch index as the first ROI column.
            inds = torch.arange(bboxs.shape[0], device=bboxs.device).unsqueeze(dim=-1)
            bboxs_feat = torch.cat((inds, bboxs_feat), dim=-1)

            coarse_prediction, coarse_features = self.coarse_forward(image_lr)
            fine_prediction, fine_features = self.fine_forward(crops_image_hr)
            coarse_prediction_roi, coarse_features_patch_area = self.coarse_postprocess_train(
                coarse_prediction, coarse_features, bboxs, bboxs_feat)

            depth_prediction, _ = self.fusion_forward(
                fine_prediction,
                crops_image_hr,
                coarse_features,
                fine_features,
                bboxs_feat,
                coarse_depth_roi=coarse_prediction_roi,
                coarse_feats_roi=coarse_features_patch_area,)

            sig_loss = self.sigloss(depth_prediction, crop_depths, self.min_depth, self.max_depth)
            loss_dict = {'sig_loss': sig_loss, 'total_loss': sig_loss}

            return loss_dict, {'rgb': crops_image_hr, 'depth_pred': depth_prediction, 'depth_gt': crop_depths}

        if tile_cfg is None:
            tile_cfg = self.tile_cfg
        else:
            tile_cfg = self.prepare_tile_cfg(tile_cfg['image_raw_shape'], tile_cfg['patch_split_num'])

        assert image_hr.shape[0] == 1

        coarse_prediction, coarse_features = self.coarse_forward(image_lr)

        tile_temp = {
            'coarse_prediction': coarse_prediction,
            'coarse_features': coarse_features,}

        blur_mask = generatemask((self.patch_process_shape[0], self.patch_process_shape[1])) + 1e-3
        blur_mask = torch.tensor(blur_mask, device=image_hr.device)
        avg_depth_map = self.regular_tile(
            offset=[0, 0],
            offset_process=[0, 0],
            image_hr=image_hr[0],
            init_flag=True,
            tile_temp=tile_temp,
            blur_mask=blur_mask,
            tile_cfg=tile_cfg,
            process_num=process_num)

        if cai_mode == 'm2' or cai_mode[0] == 'r':
            # Three more regular passes shifted by half a patch: horizontally,
            # vertically, then both.
            half_raw_h = tile_cfg['patch_raw_shape'][0] // 2
            half_raw_w = tile_cfg['patch_raw_shape'][1] // 2
            half_proc_h = self.patch_process_shape[0] // 2
            half_proc_w = self.patch_process_shape[1] // 2
            shifted_passes = (
                ([0, half_raw_w], [0, half_proc_w]),
                ([half_raw_h, 0], [half_proc_h, 0]),
                ([half_raw_h, half_raw_w], [half_proc_h, half_proc_w]),
            )
            for offset, offset_process in shifted_passes:
                avg_depth_map = self.regular_tile(
                    offset=offset,
                    offset_process=offset_process,
                    image_hr=image_hr[0],
                    init_flag=False,
                    tile_temp=tile_temp,
                    blur_mask=blur_mask,
                    avg_depth_map=avg_depth_map,
                    tile_cfg=tile_cfg,
                    process_num=process_num)

        if cai_mode[0] == 'r':
            # Random tiling uses a blur mask at raw patch resolution.
            blur_mask = generatemask((tile_cfg['patch_raw_shape'][0], tile_cfg['patch_raw_shape'][1])) + 1e-3
            blur_mask = torch.tensor(blur_mask, device=image_hr.device)
            avg_depth_map.resize(tile_cfg['image_raw_shape'])
            patch_num = int(cai_mode[1:]) // process_num
            for _ in range(patch_num):
                avg_depth_map = self.random_tile(
                    image_hr=image_hr[0],
                    tile_temp=tile_temp,
                    blur_mask=blur_mask,
                    avg_depth_map=avg_depth_map,
                    tile_cfg=tile_cfg,
                    process_num=process_num)

        depth = avg_depth_map.average_map
        depth = depth.unsqueeze(dim=0).unsqueeze(dim=0)

        return depth, {'rgb': image_lr, 'depth_pred': depth, 'depth_gt': depth_gt}