# MIT License
#
# Copyright (c) 2022 Intelligent Systems Lab Org
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Zhenyu Li
import itertools
import math
import copy
import random

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from mmengine import print_log
from mmengine.config import ConfigDict
from torchvision.ops import roi_align as torch_roi_align
from huggingface_hub import PyTorchModelHubMixin
from transformers import PretrainedConfig

from estimator.registry import MODELS
from estimator.models import build_model
from estimator.models.baseline_pretrain import BaselinePretrain
from estimator.models.utils import generatemask

from zoedepth.models.zoedepth import ZoeDepth
from zoedepth.models.layers.attractor import AttractorLayer, AttractorLayerUnnormed
from zoedepth.models.layers.dist_layers import ConditionalLogBinomial
from zoedepth.models.layers.localbins_layers import (Projector, SeedBinRegressor,
                                                     SeedBinRegressorUnnormed)
from zoedepth.models.base_models.midas import Resize as ResizeZoe
from depth_anything.transform import Resize as ResizeDA

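# Usage sketch (assumptions flagged): PatchFusion inherits PyTorchModelHubMixin,
# so it can in principle be restored straight from the Hugging Face Hub. The
# repo id, tensor shapes, and mode string below are illustrative placeholders,
# not pinned by this file:
#
#   model = PatchFusion.from_pretrained("<hub-repo-id>").eval().cuda()
#   with torch.no_grad():
#       # image_lr: whole image resized to patch_process_shape, (1, 3, h, w)
#       # image_hr: full-resolution image, batch size 1 at test time
#       depth, log_dict = model('infer', image_lr, image_hr, cai_mode='m1')
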
class PatchFusion(BaselinePretrain, PyTorchModelHubMixin):
    def __init__(self, config):
        """PatchFusion model: fuses a frozen coarse (whole-image) branch and a
        frozen fine (patch) branch with a trainable guided-fusion decoder.
        """
        nn.Module.__init__(self)

        if isinstance(config, ConfigDict):
            # convert a ConfigDict to a PretrainedConfig for hf saving
            config = PretrainedConfig.from_dict(config.to_dict())
            config.load_branch = True
        else:
            # used when loading patchfusion from hf model space
            config = PretrainedConfig.from_dict(ConfigDict(**config).to_dict())
            config.load_branch = False
            config.coarse_branch.pretrained_resource = None
            config.fine_branch.pretrained_resource = None

        self.config = config
        self.min_depth = config.min_depth
        self.max_depth = config.max_depth
        self.patch_process_shape = config.patch_process_shape
        self.tile_cfg = self.prepare_tile_cfg(config.image_raw_shape, config.patch_split_num)

        self.coarse_branch_cfg = config.coarse_branch
        if config.coarse_branch.type == 'ZoeDepth':
            self.coarse_branch = ZoeDepth.build(**config.coarse_branch)
            self.resizer = ResizeZoe(config.patch_process_shape[1], config.patch_process_shape[0],
                                     keep_aspect_ratio=False, ensure_multiple_of=32, resize_method="minimal")
        elif config.coarse_branch.type == 'DA-ZoeDepth':
            self.coarse_branch = ZoeDepth.build(**config.coarse_branch)
            self.resizer = ResizeDA(config.patch_process_shape[1], config.patch_process_shape[0],
                                    keep_aspect_ratio=False, ensure_multiple_of=14, resize_method="minimal")
        else:
            raise NotImplementedError

        if config.fine_branch.type in ('ZoeDepth', 'DA-ZoeDepth'):
            # both branch types build the same ZoeDepth architecture
            self.fine_branch = ZoeDepth.build(**config.fine_branch)
        else:
            raise NotImplementedError

        if config.load_branch:
            print_log("Loading coarse_branch from {}".format(config.pretrain_model[0]), logger='current')
            print_log(self.coarse_branch.load_state_dict(
                torch.load(config.pretrain_model[0], map_location='cpu')['model_state_dict'], strict=True), logger='current')
            print_log("Loading fine_branch from {}".format(config.pretrain_model[1]), logger='current')
            print_log(self.fine_branch.load_state_dict(
                torch.load(config.pretrain_model[1], map_location='cpu')['model_state_dict'], strict=True), logger='current')

        # freeze both pretrained branches; only the fusion modules are trained
        for param in self.coarse_branch.parameters():
            param.requires_grad = False
        for param in self.fine_branch.parameters():
            param.requires_grad = False

        self.sigloss = build_model(config.sigloss)

        N_MIDAS_OUT = 32
        btlnck_features = self.fine_branch.core.output_channels[0]

        # one fusion conv per feature level; the last level (index 5) is the
        # 32-channel MiDaS output feature, the others use the bottleneck width
        self.fusion_conv_list = nn.ModuleList()
        for i in range(6):
            if i == 5:
                layer = nn.Conv2d(N_MIDAS_OUT * 2, N_MIDAS_OUT, 3, 1, 1)
            else:
                layer = nn.Conv2d(btlnck_features * 2, btlnck_features, 3, 1, 1)
            self.fusion_conv_list.append(layer)

        self.guided_fusion = build_model(config.guided_fusion)

        # NOTE: a decoder head
        if self.coarse_branch_cfg.bin_centers_type == "normed":
            SeedBinRegressorLayer = SeedBinRegressor
            Attractor = AttractorLayer
        elif self.coarse_branch_cfg.bin_centers_type == "softplus":  # default
            SeedBinRegressorLayer = SeedBinRegressorUnnormed
            Attractor = AttractorLayerUnnormed
        elif self.coarse_branch_cfg.bin_centers_type == "hybrid1":
            SeedBinRegressorLayer = SeedBinRegressor
            Attractor = AttractorLayerUnnormed
        elif self.coarse_branch_cfg.bin_centers_type == "hybrid2":
            SeedBinRegressorLayer = SeedBinRegressorUnnormed
            Attractor = AttractorLayer
        else:
            raise ValueError(
                "bin_centers_type should be one of 'normed', 'softplus', 'hybrid1', 'hybrid2'")

        num_out_features = self.fine_branch.core.output_channels[1:]  # all levels share the same width

        self.seed_bin_regressor = SeedBinRegressorLayer(
            btlnck_features, n_bins=self.coarse_branch_cfg.n_bins,
            min_depth=config.min_depth, max_depth=config.max_depth)
        self.seed_projector = Projector(btlnck_features, self.coarse_branch_cfg.bin_embedding_dim)
        self.projectors = nn.ModuleList([
            Projector(num_out, self.coarse_branch_cfg.bin_embedding_dim)
            for num_out in num_out_features
        ])
        # typical defaults: alpha=1000, gamma=2, attractor_type='inv', kind='mean'
        self.attractors = nn.ModuleList([
            Attractor(self.coarse_branch_cfg.bin_embedding_dim, self.coarse_branch_cfg.n_bins,
                      n_attractors=self.coarse_branch_cfg.n_attractors[i],
                      min_depth=config.min_depth, max_depth=config.max_depth,
                      alpha=self.coarse_branch_cfg.attractor_alpha,
                      gamma=self.coarse_branch_cfg.attractor_gamma,
                      kind=self.coarse_branch_cfg.attractor_kind,
                      attractor_type=self.coarse_branch_cfg.attractor_type)
            for i in range(len(num_out_features))
        ])

        last_in = N_MIDAS_OUT + 1  # +1 for the relative-depth channel
        # use a conditional log-binomial head instead of a plain softmax over bins
        self.conditional_log_binomial = ConditionalLogBinomial(
            last_in, self.coarse_branch_cfg.bin_embedding_dim,
            n_classes=self.coarse_branch_cfg.n_bins,
            min_temp=self.coarse_branch_cfg.min_temp, max_temp=self.coarse_branch_cfg.max_temp)
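
        # Decoder sketch (ZoeDepth-style adaptive bins): seed_bin_regressor
        # proposes initial bin centers from the bottleneck feature, projectors
        # embed each decoder level, attractors iteratively refine the bin
        # centers level by level, and conditional_log_binomial yields per-pixel
        # probabilities over the final bins; the depth estimate is then the
        # probability-weighted sum of bin centers (see fusion_forward).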

        # NOTE: consistency training
        self.consistency_training = False

    def load_dict(self, state_dict):
        # strict=False: the frozen branch weights are not part of the saved dict
        return self.load_state_dict(state_dict, strict=False)

    def get_save_dict(self):
        # save only the fusion modules; the coarse/fine branches are frozen
        # pretrained checkpoints restored separately
        save_state_dict = {}
        for k, v in self.state_dict().items():
            if 'coarse_branch' in k or 'fine_branch' in k:
                continue
            save_state_dict[k] = v
        return save_state_dict

    def coarse_forward(self, image_lr):
        with torch.no_grad():
            if self.coarse_branch.training:
                self.coarse_branch.eval()

            deep_model_output_dict = self.coarse_branch(image_lr, return_final_centers=True)
            # scales: x_d0 1/128, x_blocks_feat_0 1/64, x_blocks_feat_1 1/32,
            # x_blocks_feat_2 1/16, x_blocks_feat_3 1/8, midas_final_feat 1/4
            # [based on 384x4, 512x4]
            deep_features = deep_model_output_dict['temp_features']
            coarse_prediction = deep_model_output_dict['metric_depth']

            coarse_features = [
                deep_features['x_d0'],
                deep_features['x_blocks_feat_0'],
                deep_features['x_blocks_feat_1'],
                deep_features['x_blocks_feat_2'],
                deep_features['x_blocks_feat_3'],
                deep_features['midas_final_feat']]  # each: (bs, c, h, w)

            return coarse_prediction, coarse_features

    def fine_forward(self, image_hr_crop):
        with torch.no_grad():
            if self.fine_branch.training:
                self.fine_branch.eval()

            deep_model_output_dict = self.fine_branch(image_hr_crop, return_final_centers=True)
            # same feature pyramid as coarse_forward
            deep_features = deep_model_output_dict['temp_features']
            fine_prediction = deep_model_output_dict['metric_depth']

            fine_features = [
                deep_features['x_d0'],
                deep_features['x_blocks_feat_0'],
                deep_features['x_blocks_feat_1'],
                deep_features['x_blocks_feat_2'],
                deep_features['x_blocks_feat_3'],
                deep_features['midas_final_feat']]  # each: (bs, c, h, w)

            return fine_prediction, fine_features

    def coarse_postprocess_train(self, coarse_prediction, coarse_features, bboxs, bboxs_feat):
        coarse_features_patch_area = []
        for idx, feat in enumerate(coarse_features):
            bs, _, h, w = feat.shape
            cur_lvl_feat = torch_roi_align(feat, bboxs_feat, (h, w), h / self.patch_process_shape[0], aligned=True)
            coarse_features_patch_area.append(cur_lvl_feat)

        coarse_prediction_roi = torch_roi_align(
            coarse_prediction, bboxs_feat, coarse_prediction.shape[-2:],
            coarse_prediction.shape[-2] / self.patch_process_shape[0], aligned=True)

        return coarse_prediction_roi, coarse_features_patch_area
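
    # torchvision's roi_align takes boxes as a (K, 5) tensor of
    # [batch_index, x1, y1, x2, y2]; spatial_scale (here h / patch_process_shape[0])
    # maps patch-process coordinates onto each feature level. A minimal
    # self-contained sketch, with made-up shapes:
    #
    #   feat = torch.rand(1, 8, 48, 64)                  # (bs, c, h, w)
    #   boxes = torch.tensor([[0., 16., 8., 48., 40.]])  # one box on batch 0
    #   roi = torch_roi_align(feat, boxes, (48, 64), spatial_scale=1.0, aligned=True)
    #   # roi: (1, 8, 48, 64) -- the crop resampled to the full feature size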

    def coarse_postprocess_test(self, coarse_prediction, coarse_features, bboxs, bboxs_feat):
        patch_num = bboxs_feat.shape[0]

        coarse_features_patch_area = []
        for idx, feat in enumerate(coarse_features):
            bs, _, h, w = feat.shape
            # replicate the single whole-image feature for every patch box
            feat_extend = feat.repeat(patch_num, 1, 1, 1)
            cur_lvl_feat = torch_roi_align(feat_extend, bboxs_feat, (h, w), h / self.patch_process_shape[0], aligned=True)
            coarse_features_patch_area.append(cur_lvl_feat)

        coarse_prediction = coarse_prediction.repeat(patch_num, 1, 1, 1)
        coarse_prediction_roi = torch_roi_align(
            coarse_prediction, bboxs_feat, coarse_prediction.shape[-2:],
            coarse_prediction.shape[-2] / self.patch_process_shape[0], aligned=True)

        return_dict = {
            'coarse_depth_roi': coarse_prediction_roi,
            'coarse_feats_roi': coarse_features_patch_area}
        return return_dict

    def fusion_forward(self, fine_depth_pred, crop_input, coarse_model_midas_enc_feats,
                       fine_model_midas_enc_feats, bbox_feat, coarse_depth_roi=None, coarse_feats_roi=None):
        feat_cat_list = []
        feat_plus_list = []
        # fuse each coarse ROI feature with the matching fine-patch feature,
        # both by a learned concat-conv and by plain addition
        for l_i, (f_ca, f_c_roi, f_f) in enumerate(zip(coarse_model_midas_enc_feats, coarse_feats_roi, fine_model_midas_enc_feats)):
            feat_cat = self.fusion_conv_list[l_i](torch.cat([f_c_roi, f_f], dim=1))
            feat_plus = f_c_roi + f_f
            feat_cat_list.append(feat_cat)
            feat_plus_list.append(feat_plus)

        input_tensor = torch.cat([coarse_depth_roi, fine_depth_pred, crop_input], dim=1)
        # HACK: hack for depth-anything
        # if self.coarse_branch_cfg.type == 'DA-ZoeDepth':
        #     input_tensor = F.interpolate(input_tensor, size=(448, 592), mode='bilinear', align_corners=True)

        output = self.guided_fusion(
            input_tensor=input_tensor,
            guide_plus=feat_plus_list,
            guide_cat=feat_cat_list,
            bbox=bbox_feat,
            fine_feat_crop=fine_model_midas_enc_feats,
            coarse_feat_whole=coarse_model_midas_enc_feats,
            coarse_feat_crop=coarse_feats_roi,
            coarse_feat_whole_hack=None)[::-1]  # reversed: low -> high resolution

        x_blocks = output
        x = x_blocks[0]
        x_blocks = x_blocks[1:]

        proj_feat_list = []
        if self.consistency_training:
            # consistency_target / consistency_projs are only set when
            # consistency training is configured elsewhere
            if self.consistency_target == 'unet_feat':
                proj_feat_list = []
                for idx, feat in enumerate(output):
                    proj_feat = self.consistency_projs[idx](feat)
                    proj_feat_list.append(proj_feat)

        # NOTE: below is the ZoeDepth decoder head
        last = x_blocks[-1]  # already fused in x_blocks
        bs, c, h, w = last.shape
        # placeholder for the relative-depth channel expected by the head
        rel_cond = torch.zeros((bs, 1, h, w), device=last.device)

        _, seed_b_centers = self.seed_bin_regressor(x)
        if self.coarse_branch_cfg.bin_centers_type in ('normed', 'hybrid2'):
            b_prev = (seed_b_centers - self.min_depth) / (self.max_depth - self.min_depth)
        else:
            b_prev = seed_b_centers
        prev_b_embedding = self.seed_projector(x)

        # unroll this loop for better performance
        for idx, (projector, attractor, x) in enumerate(zip(self.projectors, self.attractors, x_blocks)):
            b_embedding = projector(x)
            b, b_centers = attractor(
                b_embedding, b_prev, prev_b_embedding, interpolate=True)
            b_prev = b.clone()
            prev_b_embedding = b_embedding.clone()

        if self.consistency_training:
            if self.consistency_target == 'final_feat':
                proj_feat_1 = self.consistency_projs[0](b_centers)
                proj_feat_2 = self.consistency_projs[1](last)
                proj_feat_3 = self.consistency_projs[2](b_embedding)
                proj_feat_list = [proj_feat_1, proj_feat_2, proj_feat_3]

        rel_cond = nn.functional.interpolate(
            rel_cond, size=last.shape[2:], mode='bilinear', align_corners=True)
        last = torch.cat([last, rel_cond], dim=1)  # + self.coarse_depth_proj(whole_depth_roi_pred) + self.fine_depth_proj(fine_depth_pred)

        b_embedding = nn.functional.interpolate(
            b_embedding, last.shape[-2:], mode='bilinear', align_corners=True)

        # at this point we have decoder features (with the relative-depth
        # placeholder attached) and bin embeddings; turn them into a
        # per-pixel distribution over depth bins
        # final_pred = out * self.blur_mask + whole_depth_roi_pred * (1 - self.blur_mask)
        # out = F.interpolate(out, (540, 960), mode='bilinear', align_corners=True)
        x = self.conditional_log_binomial(last, b_embedding)
        b_centers = nn.functional.interpolate(
            b_centers, x.shape[-2:], mode='bilinear', align_corners=True)
        out = torch.sum(x * b_centers, dim=1, keepdim=True)

        return out, proj_feat_list
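
    # Readout sketch: the head above is the standard adaptive-bins expectation.
    # With x the (bs, n_bins, h, w) per-pixel probabilities and b_centers the
    # matching bin depths, the prediction is
    #
    #   out = torch.sum(x * b_centers, dim=1, keepdim=True)
    #
    # i.e. the expected value of the predicted distribution over depth bins.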

    def infer_forward(self, imgs_crop, bbox_feat_forward, tile_temp, coarse_temp_dict):
        fine_prediction, fine_features = self.fine_forward(imgs_crop)

        depth_prediction, consistency_target = self.fusion_forward(
            fine_prediction,
            imgs_crop,
            tile_temp['coarse_features'],
            fine_features,
            bbox_feat_forward,
            **coarse_temp_dict)

        return depth_prediction

    def forward(
            self,
            mode,
            image_lr,
            image_hr,
            depth_gt=None,
            crops_image_hr=None,
            crop_depths=None,
            bboxs=None,
            tile_cfg=None,
            cai_mode='m1',
            process_num=4):

        if mode == 'train':
            # rescale raw-image bbox coordinates into patch_process_shape space
            bboxs_feat_factor = torch.tensor([
                1 / self.tile_cfg['image_raw_shape'][1] * self.patch_process_shape[1],
                1 / self.tile_cfg['image_raw_shape'][0] * self.patch_process_shape[0],
                1 / self.tile_cfg['image_raw_shape'][1] * self.patch_process_shape[1],
                1 / self.tile_cfg['image_raw_shape'][0] * self.patch_process_shape[0]], device=bboxs.device).unsqueeze(dim=0)
            bboxs_feat = bboxs * bboxs_feat_factor
            inds = torch.arange(bboxs.shape[0]).to(bboxs.device).unsqueeze(dim=-1)
            bboxs_feat = torch.cat((inds, bboxs_feat), dim=-1)
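            # bboxs_feat rows are [batch_index, x1, y1, x2, y2] in
            # patch_process_shape coordinates, the box format torchvision's
            # roi_align expects. E.g. (assuming the (h, w) ordering implied by
            # the factors above) a 2160x3840 raw image with a 384x512 process
            # shape scales x coordinates by 512/3840 and y by 384/2160.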

            coarse_prediction, coarse_features = self.coarse_forward(image_lr)
            fine_prediction, fine_features = self.fine_forward(crops_image_hr)

            coarse_prediction_roi, coarse_features_patch_area = self.coarse_postprocess_train(
                coarse_prediction, coarse_features, bboxs, bboxs_feat)

            depth_prediction, consistency_target = self.fusion_forward(
                fine_prediction,
                crops_image_hr,
                coarse_features,
                fine_features,
                bboxs_feat,
                coarse_depth_roi=coarse_prediction_roi,
                coarse_feats_roi=coarse_features_patch_area)

            loss_dict = {}
            loss_dict['sig_loss'] = self.sigloss(depth_prediction, crop_depths, self.min_depth, self.max_depth)
            loss_dict['total_loss'] = loss_dict['sig_loss']

            return loss_dict, {'rgb': crops_image_hr, 'depth_pred': depth_prediction, 'depth_gt': crop_depths}

        else:
            if tile_cfg is None:
                tile_cfg = self.tile_cfg
            else:
                tile_cfg = self.prepare_tile_cfg(tile_cfg['image_raw_shape'], tile_cfg['patch_split_num'])

            assert image_hr.shape[0] == 1  # inference supports batch size 1 only

            coarse_prediction, coarse_features = self.coarse_forward(image_lr)
            tile_temp = {
                'coarse_prediction': coarse_prediction,
                'coarse_features': coarse_features}

            blur_mask = generatemask((self.patch_process_shape[0], self.patch_process_shape[1])) + 1e-3
            blur_mask = torch.tensor(blur_mask, device=image_hr.device)

            # m1: one regular pass over the patch grid
            avg_depth_map = self.regular_tile(
                offset=[0, 0],
                offset_process=[0, 0],
                image_hr=image_hr[0],
                init_flag=True,
                tile_temp=tile_temp,
                blur_mask=blur_mask,
                tile_cfg=tile_cfg,
                process_num=process_num)

            if cai_mode == 'm2' or cai_mode[0] == 'r':
                # m2 (and rN) add three extra passes shifted by half a patch
                avg_depth_map = self.regular_tile(
                    offset=[0, tile_cfg['patch_raw_shape'][1] // 2],
                    offset_process=[0, self.patch_process_shape[1] // 2],
                    image_hr=image_hr[0], init_flag=False, tile_temp=tile_temp, blur_mask=blur_mask,
                    avg_depth_map=avg_depth_map, tile_cfg=tile_cfg, process_num=process_num)
                avg_depth_map = self.regular_tile(
                    offset=[tile_cfg['patch_raw_shape'][0] // 2, 0],
                    offset_process=[self.patch_process_shape[0] // 2, 0],
                    image_hr=image_hr[0], init_flag=False, tile_temp=tile_temp, blur_mask=blur_mask,
                    avg_depth_map=avg_depth_map, tile_cfg=tile_cfg, process_num=process_num)
                avg_depth_map = self.regular_tile(
                    offset=[tile_cfg['patch_raw_shape'][0] // 2, tile_cfg['patch_raw_shape'][1] // 2],
                    offset_process=[self.patch_process_shape[0] // 2, self.patch_process_shape[1] // 2],
                    image_hr=image_hr[0], init_flag=False, tile_temp=tile_temp, blur_mask=blur_mask,
                    avg_depth_map=avg_depth_map, tile_cfg=tile_cfg, process_num=process_num)

            if cai_mode[0] == 'r':
                # rN: additionally average N randomly placed patches,
                # processed process_num at a time
                blur_mask = generatemask((tile_cfg['patch_raw_shape'][0], tile_cfg['patch_raw_shape'][1])) + 1e-3
                blur_mask = torch.tensor(blur_mask, device=image_hr.device)
                avg_depth_map.resize(tile_cfg['image_raw_shape'])
                patch_num = int(cai_mode[1:]) // process_num
                for i in range(patch_num):
                    avg_depth_map = self.random_tile(
                        image_hr=image_hr[0], tile_temp=tile_temp, blur_mask=blur_mask,
                        avg_depth_map=avg_depth_map, tile_cfg=tile_cfg, process_num=process_num)

            depth = avg_depth_map.average_map
            depth = depth.unsqueeze(dim=0).unsqueeze(dim=0)

            return depth, {'rgb': image_lr, 'depth_pred': depth, 'depth_gt': depth_gt}
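
# cai_mode quick reference (inferred from the branches above): 'm1' runs one
# regular pass over the patch grid, 'm2' adds three passes shifted by half a
# patch, and 'rN' (e.g. 'r128') further averages N randomly placed patches,
# process_num at a time. A hypothetical call:
#
#   depth, _ = model('infer', image_lr, image_hr, cai_mode='r128', process_num=4)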