# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Zhenyu Li
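
"""Baseline single-branch pretraining model.

Wraps a ZoeDepth-style network as either a coarse (whole-image,
low-resolution) branch or a fine (high-resolution patch) branch, and
provides the regular/random tiling utilities used to re-assemble
patch-wise predictions into a full-resolution depth map at inference time.
"""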

import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmengine import print_log

from estimator.registry import MODELS
from estimator.models import build_model
from estimator.models.utils import get_activation, generatemask, RunningAverageMap

from zoedepth.models.zoedepth import ZoeDepth
from zoedepth.models.base_models.midas import Resize as ResizeZoe
from depth_anything.transform import Resize as ResizeDA

# Register with the estimator MODELS registry so configs can build this model by name.
@MODELS.register_module()
class BaselinePretrain(nn.Module):
    def __init__(self,
                 coarse_branch,
                 fine_branch,
                 sigloss,
                 min_depth,
                 max_depth,
                 image_raw_shape=(2160, 3840),
                 patch_process_shape=(384, 512),
                 patch_split_num=(4, 4),
                 target='coarse',
                 coarse_branch_zoe=None):  # unused in this baseline
        """Baseline pretraining wrapper around a ZoeDepth-style network.

        Depending on ``target``, only the coarse branch (whole image at low
        resolution) or the fine branch (high-resolution patches) is built.
        """
        super().__init__()

        self.patch_process_shape = patch_process_shape
        self.tile_cfg = self.prepare_tile_cfg(image_raw_shape, patch_split_num)

        self.min_depth = min_depth
        self.max_depth = max_depth

        self.coarse_branch_cfg = coarse_branch
        self.fine_branch_cfg = fine_branch

        if target == 'coarse':
            if self.coarse_branch_cfg.type == 'ZoeDepth':
                self.coarse_branch = ZoeDepth.build(**coarse_branch)
                print_log("Current zoedepth.core.prep.resizer is {}".format(type(self.coarse_branch.core.prep.resizer)), logger='current')
                # MiDaS-style backbone expects input sides to be multiples of 32.
                self.resizer = ResizeZoe(patch_process_shape[1], patch_process_shape[0], keep_aspect_ratio=False, ensure_multiple_of=32, resize_method="minimal")
            elif self.coarse_branch_cfg.type == 'DA-ZoeDepth':
                self.coarse_branch = ZoeDepth.build(**coarse_branch)
                print_log("Current zoedepth.core.prep.resizer is {}".format(type(self.coarse_branch.core.prep.resizer)), logger='current')
                # Depth-Anything uses a ViT with 14x14 patches, hence multiples of 14.
                self.resizer = ResizeDA(patch_process_shape[1], patch_process_shape[0], keep_aspect_ratio=False, ensure_multiple_of=14, resize_method="minimal")
        elif target == 'fine':
            if self.fine_branch_cfg.type == 'ZoeDepth':
                self.fine_branch = ZoeDepth.build(**fine_branch)
                print_log("Current zoedepth.core.prep.resizer is {}".format(type(self.fine_branch.core.prep.resizer)), logger='current')
                self.resizer = ResizeZoe(patch_process_shape[1], patch_process_shape[0], keep_aspect_ratio=False, ensure_multiple_of=32, resize_method="minimal")
            elif self.fine_branch_cfg.type == 'DA-ZoeDepth':
                self.fine_branch = ZoeDepth.build(**fine_branch)
                print_log("Current zoedepth.core.prep.resizer is {}".format(type(self.fine_branch.core.prep.resizer)), logger='current')
                self.resizer = ResizeDA(patch_process_shape[1], patch_process_shape[0], keep_aspect_ratio=False, ensure_multiple_of=14, resize_method="minimal")

        self.sigloss = build_model(sigloss)
        self.target = target

    def prepare_tile_cfg(self, image_raw_shape, patch_split_num):
        """Precompute the tiling geometry used by patch-wise inference."""
        patch_reensemble_shape = (self.patch_process_shape[0] * patch_split_num[0], self.patch_process_shape[1] * patch_split_num[1])
        patch_raw_shape = (image_raw_shape[0] // patch_split_num[0], image_raw_shape[1] // patch_split_num[1])

        raw_h_split_point = [int(patch_raw_shape[0] * i) for i in range(patch_split_num[0])]
        raw_w_split_point = [int(patch_raw_shape[1] * i) for i in range(patch_split_num[1])]

        tile_cfg = {
            'patch_split_num': patch_split_num,
            'patch_reensemble_shape': patch_reensemble_shape,
            'patch_raw_shape': patch_raw_shape,
            'image_raw_shape': image_raw_shape,
            'raw_h_split_point': raw_h_split_point,
            'raw_w_split_point': raw_w_split_point}
        return tile_cfg
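
    # Worked example with the constructor defaults above:
    # image_raw_shape=(2160, 3840), patch_split_num=(4, 4) and
    # patch_process_shape=(384, 512) yield
    #   patch_raw_shape        = (2160 // 4, 3840 // 4) = (540, 960)
    #   patch_reensemble_shape = (384 * 4, 512 * 4)     = (1536, 2048)
    #   raw_h_split_point      = [0, 540, 1080, 1620]
    #   raw_w_split_point      = [0, 960, 1920, 2880]
    # i.e. each 540x960 raw crop is processed at 384x512 and re-assembled on a
    # 1536x2048 canvas.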

    def load_dict(self, state_dict):
        if hasattr(self, 'coarse_branch') and not hasattr(self, 'fine_branch'):
            return self.coarse_branch.load_state_dict(state_dict, strict=True)
        elif hasattr(self, 'fine_branch') and not hasattr(self, 'coarse_branch'):
            return self.fine_branch.load_state_dict(state_dict, strict=True)
        else:
            raise NotImplementedError('Loading the coarse and fine branches together is not supported')

    def get_save_dict(self):
        model_state_dict = {}
        if hasattr(self, 'coarse_branch') and not hasattr(self, 'fine_branch'):
            model_state_dict.update(self.coarse_branch.state_dict())
        elif hasattr(self, 'fine_branch') and not hasattr(self, 'coarse_branch'):
            model_state_dict.update(self.fine_branch.state_dict())
        else:
            raise NotImplementedError('Training the coarse and fine branches together is not supported')
        return model_state_dict
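
    # Typical checkpoint round-trip (hypothetical paths, for illustration only):
    #   torch.save(model.get_save_dict(), 'coarse_branch.pth')
    #   model.load_dict(torch.load('coarse_branch.pth', map_location='cpu'))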

    def infer_forward(self, imgs_crop):
        # Subclasses that consume coarse-branch context override this hook with
        # extra arguments (see the tile_temp branches of the tiling methods below).
        output_dict = self.fine_branch(imgs_crop)
        return output_dict['metric_depth']

    def random_tile(
        self,
        image_hr,
        tile_temp=None,
        blur_mask=None,
        avg_depth_map=None,
        tile_cfg=None,
        process_num=4,):

        ## setting
        height, width = tile_cfg['patch_raw_shape'][0], tile_cfg['patch_raw_shape'][1]
        # Sample one random (h, w) origin per patch, process_num patches in total.
        h_start_list = [random.randint(0, tile_cfg['image_raw_shape'][0] - height - 1) for _ in range(process_num)]
        w_start_list = [random.randint(0, tile_cfg['image_raw_shape'][1] - width - 1) for _ in range(process_num)]

        ## prepare data
        imgs_crop = []
        bboxs = []
        for h_start, w_start in zip(h_start_list, w_start_list):
            crop_image = image_hr[:, h_start: h_start+height, w_start: w_start+width]
            crop_image_resized = self.resizer(crop_image.unsqueeze(dim=0)).squeeze(dim=0)  # resize to patch_process_shape
            bbox = torch.tensor([w_start, h_start, w_start+width, h_start+height])
            imgs_crop.append(crop_image_resized)
            bboxs.append(bbox)

        imgs_crop = torch.stack(imgs_crop, dim=0)
        bboxs = torch.stack(bboxs, dim=0)
        imgs_crop = imgs_crop.to(image_hr.device)
        bboxs = bboxs.to(image_hr.device).int()

        # Map the raw-pixel boxes into patch_process_shape coordinates and prepend
        # the per-box batch index expected by roi_align-style operators.
        bboxs_feat_factor = torch.tensor([
            1 / tile_cfg['image_raw_shape'][1] * self.patch_process_shape[1],
            1 / tile_cfg['image_raw_shape'][0] * self.patch_process_shape[0],
            1 / tile_cfg['image_raw_shape'][1] * self.patch_process_shape[1],
            1 / tile_cfg['image_raw_shape'][0] * self.patch_process_shape[0]], device=bboxs.device).unsqueeze(dim=0)
        bboxs_feat = bboxs * bboxs_feat_factor
        inds = torch.arange(bboxs.shape[0]).to(bboxs.device).unsqueeze(dim=-1)
        bboxs_feat = torch.cat((inds, bboxs_feat), dim=-1)

        if tile_temp is not None:
            coarse_postprocess_dict = self.coarse_postprocess_test(bboxs=bboxs, bboxs_feat=bboxs_feat, **tile_temp)

        prediction_list = []
        if tile_temp is not None:
            coarse_temp_dict = {}
            for k, v in coarse_postprocess_dict.items():
                if k == 'coarse_feats_roi':
                    coarse_temp_dict[k] = [f for f in v]
                else:
                    coarse_temp_dict[k] = v
            bbox_feat_forward = bboxs_feat
            bbox_feat_forward[:, 0] = 0  # all patches come from the same image
            prediction = self.infer_forward(imgs_crop, bbox_feat_forward, tile_temp, coarse_temp_dict)
        else:
            prediction = self.infer_forward(imgs_crop)
        prediction_list.append(prediction)

        predictions = torch.cat(prediction_list, dim=0)
        predictions = F.interpolate(predictions, tile_cfg['patch_raw_shape'])  # back to the raw patch size

        patch_select_idx = 0
        for h_start, w_start in zip(h_start_list, w_start_list):
            temp_depth = predictions[patch_select_idx].squeeze(dim=0)  # (1, h, w) -> (h, w)
            count_map = torch.zeros(tile_cfg['image_raw_shape'], device=temp_depth.device)
            pred_depth = torch.zeros(tile_cfg['image_raw_shape'], device=temp_depth.device)
            count_map[h_start: h_start+tile_cfg['patch_raw_shape'][0], w_start: w_start+tile_cfg['patch_raw_shape'][1]] = blur_mask
            pred_depth[h_start: h_start+tile_cfg['patch_raw_shape'][0], w_start: w_start+tile_cfg['patch_raw_shape'][1]] = temp_depth * blur_mask
            avg_depth_map.update(pred_depth, count_map)
            patch_select_idx += 1

        return avg_depth_map
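
    # Judging by its use above, the re-assembly is a masked running average:
    # every patch writes blur-mask-weighted depth into its bbox and the map keeps
    #   average_map = sum_patches(blur_mask * depth) / sum_patches(blur_mask),
    # so the Gaussian-like mask down-weights patch borders and hides seams.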

    def regular_tile(
        self,
        offset,
        offset_process,
        image_hr,
        init_flag=False,
        tile_temp=None,
        blur_mask=None,
        avg_depth_map=None,
        tile_cfg=None,
        process_num=4,):

        ## setting
        height, width = tile_cfg['patch_raw_shape'][0], tile_cfg['patch_raw_shape'][1]
        offset_h, offset_w = offset[0], offset[1]
        assert offset_w >= 0 and offset_h >= 0

        # Regular grid of patch origins in raw-image coordinates.
        tile_num_h = (tile_cfg['image_raw_shape'][0] - offset_h) // height
        tile_num_w = (tile_cfg['image_raw_shape'][1] - offset_w) // width
        h_start_list = [height * h + offset_h for h in range(tile_num_h)]
        w_start_list = [width * w + offset_w for w in range(tile_num_w)]

        # Matching grid on the patch_reensemble_shape canvas.
        height_process, width_process = self.patch_process_shape[0], self.patch_process_shape[1]
        offset_h_process, offset_w_process = offset_process[0], offset_process[1]
        assert offset_h_process >= 0 and offset_w_process >= 0
        tile_num_h_process = (tile_cfg['patch_reensemble_shape'][0] - offset_h_process) // height_process
        tile_num_w_process = (tile_cfg['patch_reensemble_shape'][1] - offset_w_process) // width_process
        h_start_list_process = [height_process * h + offset_h_process for h in range(tile_num_h_process)]
        w_start_list_process = [width_process * w + offset_w_process for w in range(tile_num_w_process)]

        ## prepare data
        imgs_crop = []
        bboxs = []
        for h_start in h_start_list:
            for w_start in w_start_list:
                crop_image = image_hr[:, h_start: h_start+height, w_start: w_start+width]
                crop_image_resized = self.resizer(crop_image.unsqueeze(dim=0)).squeeze(dim=0)  # resize to patch_process_shape
                bbox = torch.tensor([w_start, h_start, w_start+width, h_start+height])
                imgs_crop.append(crop_image_resized)
                bboxs.append(bbox)

        imgs_crop = torch.stack(imgs_crop, dim=0)
        bboxs = torch.stack(bboxs, dim=0)
        imgs_crop = imgs_crop.to(image_hr.device)
        bboxs = bboxs.to(image_hr.device).int()
        bboxs = bboxs.squeeze()  # HACK: during inference, 1, 16, 4 -> 16, 4
        if len(bboxs.shape) == 1:
            bboxs = bboxs.unsqueeze(dim=0)

        bboxs_feat_factor = torch.tensor([
            1 / tile_cfg['image_raw_shape'][1] * self.patch_process_shape[1],
            1 / tile_cfg['image_raw_shape'][0] * self.patch_process_shape[0],
            1 / tile_cfg['image_raw_shape'][1] * self.patch_process_shape[1],
            1 / tile_cfg['image_raw_shape'][0] * self.patch_process_shape[0]], device=bboxs.device).unsqueeze(dim=0)
        bboxs_feat = bboxs * bboxs_feat_factor
        inds = torch.arange(bboxs.shape[0]).to(bboxs.device).unsqueeze(dim=-1)
        bboxs_feat = torch.cat((inds, bboxs_feat), dim=-1)

        # post_process
        if tile_temp is not None:
            coarse_postprocess_dict = self.coarse_postprocess_test(bboxs=bboxs, bboxs_feat=bboxs_feat, **tile_temp)

        count_map = torch.zeros(tile_cfg['patch_reensemble_shape'], device=image_hr.device)
        pred_depth = torch.zeros(tile_cfg['patch_reensemble_shape'], device=image_hr.device)

        # Run the patches through the network in mini-batches of process_num.
        prediction_list = []
        split_rebatch_image = torch.split(imgs_crop, process_num, dim=0)
        for idx, rebatch_image in enumerate(split_rebatch_image):
            if tile_temp is not None:
                coarse_temp_dict = {}
                for k, v in coarse_postprocess_dict.items():
                    if k == 'coarse_feats_roi':
                        coarse_temp_dict[k] = [f[idx*process_num:(idx+1)*process_num, :, :, :] for f in v]
                    else:
                        coarse_temp_dict[k] = v[idx*process_num:(idx+1)*process_num, :, :, :]
                bbox_feat_forward = bboxs_feat[idx*process_num:(idx+1)*process_num, :]
                bbox_feat_forward[:, 0] = 0  # all patches come from the same image
                prediction = self.infer_forward(rebatch_image, bbox_feat_forward, tile_temp, coarse_temp_dict)
            else:
                prediction = self.infer_forward(rebatch_image)
            prediction_list.append(prediction)
        predictions = torch.cat(prediction_list, dim=0)

        patch_select_idx = 0
        for h_start in h_start_list_process:
            for w_start in w_start_list_process:
                temp_depth = predictions[patch_select_idx].squeeze(dim=0)  # (1, h, w) -> (h, w)
                if init_flag:
                    # First pass: accumulate every patch on one shared canvas;
                    # the running average map is built from it below.
                    count_map[h_start: h_start+self.patch_process_shape[0], w_start: w_start+self.patch_process_shape[1]] = blur_mask
                    pred_depth[h_start: h_start+self.patch_process_shape[0], w_start: w_start+self.patch_process_shape[1]] = temp_depth * blur_mask
                else:
                    # Later (shifted) passes: each patch contributes its own sparse
                    # canvas to the existing running average.
                    count_map = torch.zeros(tile_cfg['patch_reensemble_shape'], device=temp_depth.device)
                    pred_depth = torch.zeros(tile_cfg['patch_reensemble_shape'], device=temp_depth.device)
                    count_map[h_start: h_start+self.patch_process_shape[0], w_start: w_start+self.patch_process_shape[1]] = blur_mask
                    pred_depth[h_start: h_start+self.patch_process_shape[0], w_start: w_start+self.patch_process_shape[1]] = temp_depth * blur_mask
                    avg_depth_map.update(pred_depth, count_map)
                patch_select_idx += 1

        if init_flag:
            avg_depth_map = RunningAverageMap(pred_depth, count_map)
        return avg_depth_map
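
    # Inference modes consumed by forward() below (the cai_mode strings):
    #   'm1' - a single aligned pass over the regular patch grid;
    #   'm2' - 'm1' plus three half-patch-shifted passes, so each pixel is
    #          covered by up to four overlapping patches;
    #   'rN' - the 'm2' passes followed by N extra rounds of randomly placed
    #          patches (e.g. 'r32'), blended at raw resolution.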

    def forward(
        self,
        mode,
        image_lr,
        image_hr,
        depth_gt,
        crop_depths=None,
        crops_image_hr=None,
        bboxs=None,
        tile_cfg=None,
        cai_mode='m1',
        process_num=4,
        **kwargs):

        if mode == 'train':
            loss_dict = {}
            if self.target == 'coarse':
                model_output_dict = self.coarse_branch(image_lr)
                depth_prediction = model_output_dict['metric_depth']
                loss_dict['coarse_loss'] = self.sigloss(depth_prediction, depth_gt, self.min_depth, self.max_depth)
                loss_dict['total_loss'] = loss_dict['coarse_loss']
                return loss_dict, {'rgb': image_lr, 'depth_pred': depth_prediction, 'depth_gt': depth_gt}
            elif self.target == 'fine':
                model_output_dict = self.fine_branch(crops_image_hr)
                depth_prediction = model_output_dict['metric_depth']
                loss_dict['fine_loss'] = self.sigloss(depth_prediction, crop_depths, self.min_depth, self.max_depth)
                loss_dict['total_loss'] = loss_dict['fine_loss']
                # Log the crops the fine branch actually saw, not the low-res image.
                return loss_dict, {'rgb': crops_image_hr, 'depth_pred': depth_prediction, 'depth_gt': crop_depths}
            else:
                raise NotImplementedError
        else:
            if self.target == 'coarse':
                model_output_dict = self.coarse_branch(image_lr)
                depth_prediction = model_output_dict['metric_depth']
            elif self.target == 'fine':
                if tile_cfg is None:
                    tile_cfg = self.tile_cfg
                else:
                    tile_cfg = self.prepare_tile_cfg(tile_cfg['image_raw_shape'], tile_cfg['patch_split_num'])

                assert image_hr.shape[0] == 1  # patch-wise inference processes one image at a time
                blur_mask = generatemask((self.patch_process_shape[0], self.patch_process_shape[1])) + 1e-3
                blur_mask = torch.tensor(blur_mask, device=image_hr.device)

                # m1: one aligned pass over the regular patch grid.
                avg_depth_map = self.regular_tile(
                    offset=[0, 0],
                    offset_process=[0, 0],
                    image_hr=image_hr[0],
                    init_flag=True,
                    tile_temp=None,
                    blur_mask=blur_mask,
                    tile_cfg=tile_cfg,
                    process_num=process_num)

                if cai_mode == 'm2' or cai_mode[0] == 'r':
                    # m2: three extra passes shifted by half a patch in w, in h, and
                    # in both, so patch seams fall in the interior of other patches.
                    avg_depth_map = self.regular_tile(
                        offset=[0, tile_cfg['patch_raw_shape'][1]//2],
                        offset_process=[0, self.patch_process_shape[1]//2],
                        image_hr=image_hr[0], init_flag=False, tile_temp=None, blur_mask=blur_mask, avg_depth_map=avg_depth_map, tile_cfg=tile_cfg, process_num=process_num)
                    avg_depth_map = self.regular_tile(
                        offset=[tile_cfg['patch_raw_shape'][0]//2, 0],
                        offset_process=[self.patch_process_shape[0]//2, 0],
                        image_hr=image_hr[0], init_flag=False, tile_temp=None, blur_mask=blur_mask, avg_depth_map=avg_depth_map, tile_cfg=tile_cfg, process_num=process_num)
                    avg_depth_map = self.regular_tile(
                        offset=[tile_cfg['patch_raw_shape'][0]//2, tile_cfg['patch_raw_shape'][1]//2],
                        offset_process=[self.patch_process_shape[0]//2, self.patch_process_shape[1]//2],
                        image_hr=image_hr[0], init_flag=False, tile_temp=None, blur_mask=blur_mask, avg_depth_map=avg_depth_map, tile_cfg=tile_cfg, process_num=process_num)

                if cai_mode[0] == 'r':
                    # rN: resize the accumulators to raw resolution, then run N rounds
                    # of randomly placed patches blended with a raw-resolution blur mask.
                    blur_mask = generatemask((tile_cfg['patch_raw_shape'][0], tile_cfg['patch_raw_shape'][1])) + 1e-3
                    blur_mask = torch.tensor(blur_mask, device=image_hr.device)
                    avg_depth_map.resize(tile_cfg['image_raw_shape'])
                    patch_num = int(cai_mode[1:])
                    for i in range(patch_num):
                        avg_depth_map = self.random_tile(
                            image_hr=image_hr[0], tile_temp=None, blur_mask=blur_mask, avg_depth_map=avg_depth_map, tile_cfg=tile_cfg, process_num=process_num)

                depth = avg_depth_map.average_map
                depth = depth.unsqueeze(dim=0).unsqueeze(dim=0)
                return depth, {}
            else:
                raise NotImplementedError

            return depth_prediction, {'rgb': image_lr, 'depth_pred': depth_prediction, 'depth_gt': depth_gt}
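
# For reference: a minimal sketch of the running-average container assumed by the
# tiling code above. The real implementation is estimator.models.utils.RunningAverageMap;
# this hypothetical stand-in only mirrors the interface used in this file
# (constructor, update(), resize(), and the .average_map attribute).
class _RunningAverageMapSketch:

    def __init__(self, pred_depth, count_map):
        # Keep mask-weighted depth and mask weights as separate accumulators; the
        # +1e-3 added to every blur mask keeps the counts strictly positive.
        self.sum_map = pred_depth.clone()
        self.count_map = count_map.clone()
        self.average_map = self.sum_map / self.count_map

    def update(self, pred_depth, count_map):
        self.sum_map += pred_depth
        self.count_map += count_map
        self.average_map = self.sum_map / self.count_map

    def resize(self, resolution):
        # Resize both accumulators so later updates can blend raw-resolution patches.
        self.sum_map = F.interpolate(self.sum_map[None, None], resolution, mode='bilinear', align_corners=True)[0, 0]
        self.count_map = F.interpolate(self.count_map[None, None], resolution, mode='bilinear', align_corners=True)[0, 0]
        self.average_map = self.sum_map / self.count_map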