# PatchFusion: estimator/models/patchfusion.py
# Commit 1f418ff ("refactor") by Zhyever; file size 21.7 kB.
# MIT License
# Copyright (c) 2022 Intelligent Systems Lab Org
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# File author: Zhenyu Li
import itertools
import math
import copy
import torch
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from mmengine import print_log
from mmengine.config import ConfigDict
from torchvision.ops import roi_align as torch_roi_align
from huggingface_hub import PyTorchModelHubMixin
from transformers import PretrainedConfig
from estimator.registry import MODELS
from estimator.models import build_model
from estimator.models.baseline_pretrain import BaselinePretrain
from estimator.models.utils import generatemask
from zoedepth.models.zoedepth import ZoeDepth
from zoedepth.models.layers.attractor import AttractorLayer, AttractorLayerUnnormed
from zoedepth.models.layers.dist_layers import ConditionalLogBinomial
from zoedepth.models.layers.localbins_layers import (Projector, SeedBinRegressor, SeedBinRegressorUnnormed)
from zoedepth.models.base_models.midas import Resize as ResizeZoe
from depth_anything.transform import Resize as ResizeDA
@MODELS.register_module()
class PatchFusion(BaselinePretrain, PyTorchModelHubMixin):
def __init__(self, config):
    """Build PatchFusion: frozen coarse/fine depth branches plus a trainable fusion head.

    The coarse branch is run on the low-resolution whole image and the fine
    branch on high-resolution crops (see ``forward``). Their per-level
    features are merged by ``fusion_conv_list`` / ``guided_fusion`` and decoded
    into metric depth by a ZoeDepth-style bin head (seed bin regressor,
    projectors, attractors and a conditional log-binomial layer).

    Args:
        config: either an mmengine ``ConfigDict`` (local construction; the
            branch checkpoints ``config.pretrain_model[0]`` (coarse) and
            ``config.pretrain_model[1]`` (fine) are loaded), or a plain mapping
            (used when loading PatchFusion from the Hugging Face hub; branch
            checkpoints are not loaded here and the branches'
            ``pretrained_resource`` is cleared).

    Raises:
        NotImplementedError: if a branch ``type`` is neither ``'ZoeDepth'``
            nor ``'DA-ZoeDepth'``.
        ValueError: if ``coarse_branch.bin_centers_type`` is not one of
            ``'normed'``, ``'softplus'``, ``'hybrid1'``, ``'hybrid2'``.
    """
    # Only nn.Module is initialised; BaselinePretrain.__init__ is skipped on
    # purpose because every sub-module is constructed explicitly below.
    nn.Module.__init__(self)

    if isinstance(config, ConfigDict):
        # convert a ConfigDict to a PretrainedConfig for hf saving
        config = PretrainedConfig.from_dict(config.to_dict())
        config.load_branch = True
    else:
        # used when loading patchfusion from hf model space
        config = PretrainedConfig.from_dict(ConfigDict(**config).to_dict())
        config.load_branch = False
        config.coarse_branch.pretrained_resource = None
        config.fine_branch.pretrained_resource = None

    self.config = config
    self.min_depth = config.min_depth
    self.max_depth = config.max_depth
    self.patch_process_shape = config.patch_process_shape
    self.tile_cfg = self.prepare_tile_cfg(config.image_raw_shape, config.patch_split_num)
    self.coarse_branch_cfg = config.coarse_branch

    supported_branch_types = ('ZoeDepth', 'DA-ZoeDepth')

    # Both supported types are built by ZoeDepth.build; they differ only in the
    # input resizer (DA-ZoeDepth needs multiples of 14, presumably the patch
    # size of its Depth-Anything ViT backbone).
    coarse_type = config.coarse_branch.type
    if coarse_type not in supported_branch_types:
        raise NotImplementedError(
            "Unsupported coarse_branch type {!r}; expected one of {}".format(coarse_type, supported_branch_types))
    self.coarse_branch = ZoeDepth.build(**config.coarse_branch)
    if coarse_type == 'ZoeDepth':
        self.resizer = ResizeZoe(config.patch_process_shape[1], config.patch_process_shape[0], keep_aspect_ratio=False, ensure_multiple_of=32, resize_method="minimal")
    else:
        self.resizer = ResizeDA(config.patch_process_shape[1], config.patch_process_shape[0], keep_aspect_ratio=False, ensure_multiple_of=14, resize_method="minimal")

    fine_type = config.fine_branch.type
    if fine_type not in supported_branch_types:
        raise NotImplementedError(
            "Unsupported fine_branch type {!r}; expected one of {}".format(fine_type, supported_branch_types))
    self.fine_branch = ZoeDepth.build(**config.fine_branch)

    if config.load_branch:
        branch_ckpts = (
            ('coarse_branch', self.coarse_branch, config.pretrain_model[0]),
            ('fine_branch', self.fine_branch, config.pretrain_model[1]),
        )
        for branch_name, branch, ckpt_path in branch_ckpts:
            print_log("Loading {} from {}".format(branch_name, ckpt_path), logger='current')
            # NOTE(security): torch.load unpickles the file; only load trusted
            # checkpoints (consider weights_only=True if they hold plain tensors).
            branch_state = torch.load(ckpt_path, map_location='cpu')['model_state_dict']
            print_log(branch.load_state_dict(branch_state, strict=True), logger='current')

    # freeze all branch parameters (get_save_dict also leaves them out of checkpoints)
    self.coarse_branch.requires_grad_(False)
    self.fine_branch.requires_grad_(False)

    self.sigloss = build_model(config.sigloss)

    # One fusion conv per feature level returned by coarse_forward / fine_forward;
    # each maps the concatenated [coarse_roi, fine] features back to one level's width.
    # Assumes the first five levels all have `btlnck_features` channels and the
    # last (midas_final_feat) level has N_MIDAS_OUT channels.
    N_MIDAS_OUT = 32
    btlnck_features = self.fine_branch.core.output_channels[0]
    self.fusion_conv_list = nn.ModuleList(
        [nn.Conv2d(btlnck_features * 2, btlnck_features, 3, 1, 1) for _ in range(5)]
        + [nn.Conv2d(N_MIDAS_OUT * 2, N_MIDAS_OUT, 3, 1, 1)])
    self.guided_fusion = build_model(config.guided_fusion)

    # NOTE: a decoder head (ZoeDepth-style metric bins)
    bin_layer_choices = {
        "normed": (SeedBinRegressor, AttractorLayer),
        "softplus": (SeedBinRegressorUnnormed, AttractorLayerUnnormed),  # default
        "hybrid1": (SeedBinRegressor, AttractorLayerUnnormed),
        "hybrid2": (SeedBinRegressorUnnormed, AttractorLayer),
    }
    bin_centers_type = self.coarse_branch_cfg.bin_centers_type
    if bin_centers_type not in bin_layer_choices:
        raise ValueError(
            "bin_centers_type should be one of 'normed', 'softplus', 'hybrid1', 'hybrid2'")
    SeedBinRegressorLayer, Attractor = bin_layer_choices[bin_centers_type]

    num_out_features = self.fine_branch.core.output_channels[1:]  # all of them are the same
    self.seed_bin_regressor = SeedBinRegressorLayer(
        btlnck_features, n_bins=self.coarse_branch_cfg.n_bins, min_depth=config.min_depth, max_depth=config.max_depth)
    self.seed_projector = Projector(btlnck_features, self.coarse_branch_cfg.bin_embedding_dim)
    self.projectors = nn.ModuleList([
        Projector(num_out, self.coarse_branch_cfg.bin_embedding_dim)
        for num_out in num_out_features
    ])
    # "1000, 2, inv, mean": presumably the usual attractor_alpha, attractor_gamma,
    # attractor_type and attractor_kind values -- TODO confirm against the configs.
    self.attractors = nn.ModuleList([
        Attractor(self.coarse_branch_cfg.bin_embedding_dim, self.coarse_branch_cfg.n_bins, n_attractors=self.coarse_branch_cfg.n_attractors[i], min_depth=config.min_depth, max_depth=config.max_depth,
                  alpha=self.coarse_branch_cfg.attractor_alpha, gamma=self.coarse_branch_cfg.attractor_gamma, kind=self.coarse_branch_cfg.attractor_kind, attractor_type=self.coarse_branch_cfg.attractor_type)
        for i in range(len(num_out_features))
    ])

    last_in = N_MIDAS_OUT + 1  # +1 for the relative-depth conditioning channel
    # use log binomial instead of softmax
    self.conditional_log_binomial = ConditionalLogBinomial(
        last_in, self.coarse_branch_cfg.bin_embedding_dim, n_classes=self.coarse_branch_cfg.n_bins, min_temp=self.coarse_branch_cfg.min_temp, max_temp=self.coarse_branch_cfg.max_temp)

    # NOTE: consistency training is disabled; fusion_forward only reads
    # consistency_target / consistency_projs when this flag is True.
    self.consistency_training = False
def load_dict(self, dict):
    """Load a (possibly partial) state dict non-strictly.

    ``strict=False`` tolerates missing/unexpected keys, presumably because
    saved checkpoints omit the frozen branch weights (see ``get_save_dict``).

    Returns:
        Whatever ``self.load_state_dict`` returns.
    """
    # `dict` shadows the builtin but is kept so keyword callers keep working.
    load_result = self.load_state_dict(dict, strict=False)
    return load_result
def get_save_dict(self):
    """Return ``self.state_dict()`` without the frozen branch entries.

    Any key that contains ``'coarse_branch'`` or ``'fine_branch'`` as a
    substring is dropped; remaining entries keep their original order and are
    returned in a plain ``dict``.
    """
    excluded_markers = ('coarse_branch', 'fine_branch')
    return {
        key: value
        for key, value in self.state_dict().items()
        if not any(marker in key for marker in excluded_markers)
    }
def coarse_forward(self, image_lr):
    """Run the frozen coarse branch on the low-resolution image.

    The branch is forced into eval mode (if it is training) and run under
    ``torch.no_grad()``.

    Returns:
        (prediction, features): the branch's ``'metric_depth'`` output and the
        six intermediate maps x_d0, x_blocks_feat_0..3 and midas_final_feat, in
        that order (about 1/128 ... 1/4 resolution per the original notes).
    """
    level_keys = (
        'x_d0',
        'x_blocks_feat_0',
        'x_blocks_feat_1',
        'x_blocks_feat_2',
        'x_blocks_feat_3',
        'midas_final_feat',
    )
    branch = self.coarse_branch
    with torch.no_grad():
        # keep the frozen branch in eval mode even when the wrapper is training
        if branch.training:
            branch.eval()
        branch_out = branch(image_lr, return_final_centers=True)
        temp_features = branch_out['temp_features']
        prediction = branch_out['metric_depth']
        features = [temp_features[key] for key in level_keys]  # each bs, c, h, w
    return prediction, features
def fine_forward(self, image_hr_crop):
    """Run the frozen fine branch on high-resolution crops.

    Mirrors ``coarse_forward``: the branch is put into eval mode (if it is
    training) and run under ``torch.no_grad()``.

    Returns:
        (prediction, features): the branch's ``'metric_depth'`` output and the
        six intermediate maps x_d0, x_blocks_feat_0..3 and midas_final_feat, in
        that order.
    """
    level_keys = (
        'x_d0',
        'x_blocks_feat_0',
        'x_blocks_feat_1',
        'x_blocks_feat_2',
        'x_blocks_feat_3',
        'midas_final_feat',
    )
    branch = self.fine_branch
    with torch.no_grad():
        # keep the frozen branch in eval mode even when the wrapper is training
        if branch.training:
            branch.eval()
        branch_out = branch(image_hr_crop, return_final_centers=True)
        temp_features = branch_out['temp_features']
        prediction = branch_out['metric_depth']
        features = [temp_features[key] for key in level_keys]  # each bs, c, h, w
    return prediction, features
def coarse_postprocess_train(self, coarse_prediction, coarse_features, bboxs, bboxs_feat):
    """Resample each training patch's region out of the coarse outputs.

    Args:
        coarse_prediction: coarse depth map batch.
        coarse_features: list of coarse feature maps (see ``coarse_forward``).
        bboxs: unused here; kept for interface compatibility.
        bboxs_feat: ``roi_align`` boxes of shape (K, 5), rows
            ``[batch_idx, x1, y1, x2, y2]`` in ``patch_process_shape`` pixels
            (as built by ``forward``).

    Returns:
        (coarse_prediction_roi, coarse_features_patch_area): every map cropped to
        its box and resampled back to that map's own spatial size.
    """
    # spatial_scale maps process-resolution box coordinates onto each map; only
    # the height ratio is used, so maps are assumed to keep the process aspect ratio.
    process_h = self.patch_process_shape[0]

    def crop_level(feature_map):
        _, _, map_h, map_w = feature_map.shape
        return torch_roi_align(feature_map, bboxs_feat, (map_h, map_w), map_h / process_h, aligned=True)

    coarse_features_patch_area = [crop_level(feat) for feat in coarse_features]
    pred_hw = coarse_prediction.shape[-2:]
    coarse_prediction_roi = torch_roi_align(
        coarse_prediction, bboxs_feat, pred_hw, coarse_prediction.shape[-2] / process_h, aligned=True)
    return coarse_prediction_roi, coarse_features_patch_area
def coarse_postprocess_test(self, coarse_prediction, coarse_features, bboxs, bboxs_feat):
    """Crop every inference patch's region out of the (single-image) coarse outputs.

    Each coarse map is repeated once per box so that row ``k`` of
    ``bboxs_feat`` can address its own copy through ``roi_align``'s batch-index
    column. This presumes a batch size of 1; the caller in ``forward`` asserts
    that for the high-resolution image.

    Args:
        coarse_prediction: coarse depth map.
        coarse_features: list of coarse feature maps (see ``coarse_forward``).
        bboxs: unused here; kept for interface compatibility.
        bboxs_feat: ``roi_align`` boxes; ``bboxs_feat.shape[0]`` is the patch count.

    Returns:
        dict with ``'coarse_depth_roi'`` and ``'coarse_feats_roi'``; the keys
        match ``fusion_forward``'s keyword arguments.
    """
    n_patches = bboxs_feat.shape[0]
    process_h = self.patch_process_shape[0]

    def crop_level(feature_map):
        _, _, map_h, map_w = feature_map.shape
        tiled = feature_map.repeat(n_patches, 1, 1, 1)
        return torch_roi_align(tiled, bboxs_feat, (map_h, map_w), map_h / process_h, aligned=True)

    roi_features = [crop_level(feat) for feat in coarse_features]
    tiled_prediction = coarse_prediction.repeat(n_patches, 1, 1, 1)
    roi_prediction = torch_roi_align(
        tiled_prediction, bboxs_feat, tiled_prediction.shape[-2:], tiled_prediction.shape[-2] / process_h, aligned=True)
    return {
        'coarse_depth_roi': roi_prediction,
        'coarse_feats_roi': roi_features,
    }
def fusion_forward(self, fine_depth_pred, crop_input, coarse_model_midas_enc_feats, fine_model_midas_enc_feats, bbox_feat, coarse_depth_roi=None, coarse_feats_roi=None):
    """Fuse RoI-cropped coarse features with fine features and decode metric depth.

    Args:
        fine_depth_pred: fine-branch depth for the crops.
        crop_input: the high-resolution image crops.
        coarse_model_midas_enc_feats: whole-image coarse feature list.
        fine_model_midas_enc_feats: fine feature list for the crops.
        bbox_feat: RoI boxes of the crops, forwarded to ``guided_fusion``.
        coarse_depth_roi: coarse depth cropped to the boxes (required despite
            the ``None`` default).
        coarse_feats_roi: coarse features cropped to the boxes (required despite
            the ``None`` default).

    Returns:
        (out, proj_feat_list): ``out`` is the bin-weighted depth, summed over the
        bin axis with ``keepdim=True`` (so (N, 1, H, W) for 4-D inputs);
        ``proj_feat_list`` holds consistency projections and is empty unless
        consistency training is enabled.
    """
    # Per level: a learned conv fusion of [coarse_roi, fine] and a plain sum.
    feat_cat_list = []
    feat_plus_list = []
    for level, (_, f_c_roi, f_f) in enumerate(zip(coarse_model_midas_enc_feats, coarse_feats_roi, fine_model_midas_enc_feats)):
        feat_cat_list.append(self.fusion_conv_list[level](torch.cat([f_c_roi, f_f], dim=1)))
        feat_plus_list.append(f_c_roi + f_f)

    input_tensor = torch.cat([coarse_depth_roi, fine_depth_pred, crop_input], dim=1)
    output = self.guided_fusion(
        input_tensor=input_tensor,
        guide_plus=feat_plus_list,
        guide_cat=feat_cat_list,
        bbox=bbox_feat,
        fine_feat_crop=fine_model_midas_enc_feats,
        coarse_feat_whole=coarse_model_midas_enc_feats,
        coarse_feat_crop=coarse_feats_roi,
        coarse_feat_whole_hack=None)[::-1]  # low -> high
    x, x_blocks = output[0], output[1:]

    proj_feat_list = []
    if self.consistency_training and self.consistency_target == 'unet_feat':
        proj_feat_list = [self.consistency_projs[idx](feat) for idx, feat in enumerate(output)]

    # NOTE: below follows the ZoeDepth metric-bins head
    last = x_blocks[-1]  # have already been fused in x_blocks
    bs, _, h, w = last.shape
    # Relative-depth conditioning channel (always zeros here). Match last's dtype
    # so the concatenation below does not promote half/bf16 features to float32.
    rel_cond = torch.zeros((bs, 1, h, w), device=last.device, dtype=last.dtype)

    _, seed_b_centers = self.seed_bin_regressor(x)
    if self.coarse_branch_cfg.bin_centers_type in ('normed', 'hybrid2'):
        b_prev = (seed_b_centers - self.min_depth) / (self.max_depth - self.min_depth)
    else:
        b_prev = seed_b_centers
    prev_b_embedding = self.seed_projector(x)

    for projector, attractor, block in zip(self.projectors, self.attractors, x_blocks):
        b_embedding = projector(block)
        b, b_centers = attractor(b_embedding, b_prev, prev_b_embedding, interpolate=True)
        b_prev = b.clone()
        prev_b_embedding = b_embedding.clone()

    if self.consistency_training and self.consistency_target == 'final_feat':
        proj_feat_list = [
            self.consistency_projs[0](b_centers),
            self.consistency_projs[1](last),
            self.consistency_projs[2](b_embedding),
        ]

    # rel_cond was created at last's spatial size, so no resize is needed.
    last = torch.cat([last, rel_cond], dim=1)
    b_embedding = nn.functional.interpolate(
        b_embedding, last.shape[-2:], mode='bilinear', align_corners=True)

    # per-pixel bin probabilities, then expectation over the bin centers
    x = self.conditional_log_binomial(last, b_embedding)
    b_centers = nn.functional.interpolate(
        b_centers, x.shape[-2:], mode='bilinear', align_corners=True)
    out = torch.sum(x * b_centers, dim=1, keepdim=True)
    return out, proj_feat_list
def infer_forward(self, imgs_crop, bbox_feat_forward, tile_temp, coarse_temp_dict):
    """Predict fused depth for a batch of high-resolution crops at inference time.

    Args:
        imgs_crop: high-resolution crops fed to the fine branch.
        bbox_feat_forward: RoI boxes of the crops, passed through to ``fusion_forward``.
        tile_temp: per-image cache; its ``'coarse_features'`` entry holds the
            whole-image coarse features (as built in ``forward``).
        coarse_temp_dict: extra keyword arguments for ``fusion_forward``
            (the coarse RoI crops, e.g. from ``coarse_postprocess_test``).

    Returns:
        The fused depth prediction; consistency projections are discarded.
    """
    fine_depth, fine_feats = self.fine_forward(imgs_crop)
    fused_depth, _unused_proj_feats = self.fusion_forward(
        fine_depth,
        imgs_crop,
        tile_temp['coarse_features'],
        fine_feats,
        bbox_feat_forward,
        **coarse_temp_dict)
    return fused_depth
def forward(
    self,
    mode,
    image_lr,
    image_hr,
    depth_gt=None,
    crops_image_hr=None,
    crop_depths=None,
    bboxs=None,
    tile_cfg=None,
    cai_mode='m1',
    process_num=4):
    """Training step (``mode == 'train'``) or tiled whole-image inference (any other mode).

    Training: ``bboxs`` (raw-image pixel boxes of the crops in
    ``crops_image_hr``) are rescaled to ``patch_process_shape`` and prefixed
    with their batch index. The coarse RoIs are fused with fine-branch
    predictions, and the result is supervised with ``self.sigloss`` against
    ``crop_depths``.

    Inference: ``image_hr`` must have batch size 1. Patches are tiled
    according to ``cai_mode``:

    * ``'m1'``: one regular tiling pass.
    * ``'m2'``: the 'm1' pass plus three passes shifted by half a patch.
    * ``'r<N>'``: the 'm2' passes, then ``int(N) // process_num`` rounds of
      ``random_tile``.

    Returns:
        Training: ``(loss_dict, vis)`` where ``loss_dict`` has ``'sig_loss'``
        and ``'total_loss'`` (the same value).
        Inference: ``(depth, vis)`` where ``depth`` is the running-average map
        with two leading singleton dimensions added.
        ``vis`` has keys ``'rgb'``, ``'depth_pred'`` and ``'depth_gt'``.
    """
    if mode == 'train':
        raw_h = self.tile_cfg['image_raw_shape'][0]
        raw_w = self.tile_cfg['image_raw_shape'][1]
        proc_h = self.patch_process_shape[0]
        proc_w = self.patch_process_shape[1]
        # Rescale (x1, y1, x2, y2) boxes from raw-image pixels to process-resolution pixels.
        bboxs_feat_factor = torch.tensor([
            1 / raw_w * proc_w,
            1 / raw_h * proc_h,
            1 / raw_w * proc_w,
            1 / raw_h * proc_h], device=bboxs.device).unsqueeze(dim=0)
        bboxs_feat = bboxs * bboxs_feat_factor
        # Prepend each box's batch index (roi_align's (K, 5) box format). Built
        # directly on the target device/dtype rather than as an int64 CPU tensor.
        inds = torch.arange(bboxs.shape[0], device=bboxs.device, dtype=bboxs_feat.dtype).unsqueeze(dim=-1)
        bboxs_feat = torch.cat((inds, bboxs_feat), dim=-1)

        coarse_prediction, coarse_features = self.coarse_forward(image_lr)
        fine_prediction, fine_features = self.fine_forward(crops_image_hr)
        coarse_prediction_roi, coarse_features_patch_area = self.coarse_postprocess_train(
            coarse_prediction, coarse_features, bboxs, bboxs_feat)
        depth_prediction, _ = self.fusion_forward(
            fine_prediction,
            crops_image_hr,
            coarse_features,
            fine_features,
            bboxs_feat,
            coarse_depth_roi=coarse_prediction_roi,
            coarse_feats_roi=coarse_features_patch_area)

        sig_loss = self.sigloss(depth_prediction, crop_depths, self.min_depth, self.max_depth)
        loss_dict = {'sig_loss': sig_loss, 'total_loss': sig_loss}
        return loss_dict, {'rgb': crops_image_hr, 'depth_pred': depth_prediction, 'depth_gt': crop_depths}

    # ---- tiled inference ----
    if tile_cfg is None:
        tile_cfg = self.tile_cfg
    else:
        tile_cfg = self.prepare_tile_cfg(tile_cfg['image_raw_shape'], tile_cfg['patch_split_num'])
    assert image_hr.shape[0] == 1, "tiled inference expects a single image (batch size 1)"

    coarse_prediction, coarse_features = self.coarse_forward(image_lr)
    tile_temp = {
        'coarse_prediction': coarse_prediction,
        'coarse_features': coarse_features,
    }
    image = image_hr[0]
    blur_mask = generatemask((self.patch_process_shape[0], self.patch_process_shape[1])) + 1e-3
    blur_mask = torch.tensor(blur_mask, device=image_hr.device)

    # First regular pass (init_flag=True) creates the running-average depth map.
    avg_depth_map = self.regular_tile(
        offset=[0, 0],
        offset_process=[0, 0],
        image_hr=image,
        init_flag=True,
        tile_temp=tile_temp,
        blur_mask=blur_mask,
        tile_cfg=tile_cfg,
        process_num=process_num)

    if cai_mode == 'm2' or cai_mode[0] == 'r':
        half_raw_h = tile_cfg['patch_raw_shape'][0] // 2
        half_raw_w = tile_cfg['patch_raw_shape'][1] // 2
        half_proc_h = self.patch_process_shape[0] // 2
        half_proc_w = self.patch_process_shape[1] // 2
        # Three more passes shifted by half a patch along the second axis, the
        # first axis, and both (raw offsets paired with process-resolution offsets).
        shifted_offsets = (
            ([0, half_raw_w], [0, half_proc_w]),
            ([half_raw_h, 0], [half_proc_h, 0]),
            ([half_raw_h, half_raw_w], [half_proc_h, half_proc_w]),
        )
        for offset, offset_process in shifted_offsets:
            avg_depth_map = self.regular_tile(
                offset=offset,
                offset_process=offset_process,
                image_hr=image,
                init_flag=False,
                tile_temp=tile_temp,
                blur_mask=blur_mask,
                avg_depth_map=avg_depth_map,
                tile_cfg=tile_cfg,
                process_num=process_num)

    if cai_mode[0] == 'r':
        # Random tiles use a mask sized to the raw patch shape.
        blur_mask = generatemask((tile_cfg['patch_raw_shape'][0], tile_cfg['patch_raw_shape'][1])) + 1e-3
        blur_mask = torch.tensor(blur_mask, device=image_hr.device)
        avg_depth_map.resize(tile_cfg['image_raw_shape'])
        # 'r<N>' -> int(N) // process_num calls, each given process_num
        # (presumably one batch of random patches per call).
        for _ in range(int(cai_mode[1:]) // process_num):
            avg_depth_map = self.random_tile(
                image_hr=image,
                tile_temp=tile_temp,
                blur_mask=blur_mask,
                avg_depth_map=avg_depth_map,
                tile_cfg=tile_cfg,
                process_num=process_num)

    depth = avg_depth_map.average_map.unsqueeze(dim=0).unsqueeze(dim=0)
    return depth, {'rgb': image_lr, 'depth_pred': depth, 'depth_gt': depth_gt}