|
import torch |
|
import torch.nn as nn |
|
from mono.utils.comm import get_func |
|
import numpy as np |
|
import torch.nn.functional as F |
|
|
|
|
|
class BaseDepthModel(nn.Module):
    """Top-level wrapper around a configurable monocular-depth pipeline.

    Builds the concrete depth network named by ``cfg.model.type`` and, during
    training, evaluates up to four optional loss groups (decoder, auxiliary,
    pose, GRU-iteration) against the batch labels. Optionally trains at a
    reduced label resolution (``cfg.prediction_downsample``) and upsamples the
    predictions back afterwards.
    """

    def __init__(self, cfg, criterions, **kwards):
        """
        Args:
            cfg: config object exposing ``cfg.model.type``; may also expose
                ``cfg.prediction_downsample`` (int factor) for low-res training.
            criterions: optional mapping with keys ``'decoder_losses'``,
                ``'auxi_losses'``, ``'pose_losses'``, ``'gru_losses'``; a
                missing key (or a falsy mapping) disables that loss branch.
        """
        super(BaseDepthModel, self).__init__()
        model_type = cfg.model.type
        # Resolve the pipeline class by its dotted name and instantiate it.
        self.depth_model = get_func('mono.model.model_pipelines.' + model_type)(cfg)

        def _group(key):
            # Same lookup semantics as the original inline conditionals, kept
            # mapping-agnostic (`in` + `[]`, no reliance on `.get`).
            return criterions[key] if criterions and key in criterions else None

        self.criterions_main = _group('decoder_losses')
        self.criterions_auxi = _group('auxi_losses')
        self.criterions_pose = _group('pose_losses')
        self.criterions_gru = _group('gru_losses')

        try:
            self.downsample = cfg.prediction_downsample
        except (AttributeError, KeyError):
            # Fix: narrowed from a bare `except:` so unrelated errors
            # (e.g. KeyboardInterrupt) are no longer swallowed.
            self.downsample = None

        # nn.Module.__init__ already sets this; kept for explicitness.
        self.training = True

    def forward(self, data):
        """Run the pipeline on a batch dict.

        Returns:
            (prediction, losses_dict, confidence); ``losses_dict`` is empty
            when not training.
        """
        if self.downsample is not None:
            self.label_downsample(self.downsample, data)

        output = self.depth_model(**data)

        losses_dict = {}
        if self.training:
            # Losses read both the network outputs and the batch labels.
            output.update(data)
            losses_dict = self.get_loss(output)

        if self.downsample is not None:
            self.pred_upsample(self.downsample, output)

        return output['prediction'], losses_dict, output['confidence']

    def inference(self, data):
        """Gradient-free forward pass; returns the full output dict."""
        with torch.no_grad():
            output = self.depth_model(**data)
            output.update(data)

        if self.downsample is not None:
            self.pred_upsample(self.downsample, output)

        output['dataset'] = 'wild'
        return output

    def get_loss(self, paras):
        """Collect every enabled loss branch plus their sum ('total_loss')."""
        losses_dict = {}

        if self.training:
            losses_dict.update(self.compute_decoder_loss(paras))
            losses_dict.update(self.compute_auxi_loss(paras))
            losses_dict.update(self.compute_pose_loss(paras))
            losses_dict.update(self.compute_gru_loss(paras))

        # Empty dict sums to int 0 when not training.
        losses_dict['total_loss'] = sum(losses_dict.values())
        return losses_dict

    def compute_gru_loss(self, paras_):
        """Exponentially weighted losses over intermediate GRU iterations.

        The final iteration is skipped here; it is covered by the decoder loss.
        """
        losses_dict = {}
        if self.criterions_gru is None or len(self.criterions_gru) == 0:
            return losses_dict
        # Drop the final-prediction keys so branch_loss only sees
        # per-iteration data.
        paras = {k: v for k, v in paras_.items()
                 if k != 'prediction' and k != 'prediction_normal'}
        n_predictions = len(paras['predictions_list'])
        for i, pre in enumerate(paras['predictions_list']):
            if i == n_predictions - 1:
                break

            pre_normal = paras['normal_out_list'][i] if 'normal_out_list' in paras.keys() else None
            iter_dict = self.branch_loss(
                prediction=pre,
                prediction_normal=pre_normal,
                criterions=self.criterions_gru,
                branch=f'gru_{i}',
                **paras
            )

            # RAFT-style schedule: later iterations get larger weights.
            # Safe when n_predictions == 1: the loop breaks before this point.
            adjusted_loss_gamma = 0.9 ** (15 / (n_predictions - 1))
            i_weight = adjusted_loss_gamma ** (n_predictions - i - 1)
            iter_dict = {k: v * i_weight for k, v in iter_dict.items()}
            losses_dict.update(iter_dict)
        return losses_dict

    def compute_decoder_loss(self, paras):
        """Main decoder losses on the final prediction (contained in paras)."""
        return self.branch_loss(
            criterions=self.criterions_main,
            branch='decode',
            **paras
        )

    def compute_auxi_loss(self, paras):
        """Losses on the auxiliary prediction heads, one branch per head."""
        losses_dict = {}
        # Fix: also guard against a missing (None) group, consistent with the
        # pose/gru branches; `len(None)` raised a TypeError before.
        if self.criterions_auxi is None or len(self.criterions_auxi) == 0:
            return losses_dict
        args = dict(
            target=paras['target'],
            data_type=paras['data_type'],
            sem_mask=paras['sem_mask'],
        )
        for i, auxi_logit in enumerate(paras['auxi_logit_list']):
            auxi_losses_dict = self.branch_loss(
                prediction=paras['auxi_pred'][i],
                criterions=self.criterions_auxi,
                pred_logit=auxi_logit,
                branch=f'auxi_{i}',
                **args
            )
            losses_dict.update(auxi_losses_dict)
        return losses_dict

    def compute_pose_loss(self, paras):
        """Pose-branch losses; each criterion consumes the full output dict."""
        losses_dict = {}
        if self.criterions_pose is None or len(self.criterions_pose) == 0:
            return losses_dict

        for loss_method in self.criterions_pose:
            losses_dict['pose_' + loss_method._get_name()] = loss_method(**paras)
        return losses_dict

    def branch_loss(self, prediction, pred_logit, criterions, branch='decode', **kwargs):
        """Evaluate every criterion of one branch on (prediction, target).

        Each criterion declares the dataset types it applies to via its
        ``data_type`` attribute; samples of other types are masked out.
        Loss keys are prefixed with ``branch + '_'``.
        """
        losses_dict = {}
        args = dict(pred_logit=pred_logit)

        target = kwargs.pop('target')
        args.update(kwargs)

        batches_data_type = np.array(kwargs['data_type'])

        # Valid-depth mask: near-zero depths are treated as missing labels.
        mask = target > 1e-8
        for loss_method in criterions:
            new_mask = self.create_mask_as_loss(loss_method, mask, batches_data_type)
            loss_tmp = loss_method(
                prediction=prediction,
                target=target,
                mask=new_mask,
                **args)
            losses_dict[branch + '_' + loss_method._get_name()] = loss_tmp
        return losses_dict

    def create_mask_as_loss(self, loss_method, mask, batches_data_type):
        """AND the valid-depth mask with a per-sample dataset-type gate."""
        data_type_req = np.array(loss_method.data_type)[:, None]
        # Fix: follow the mask's device instead of hard-coding "cuda", so the
        # model also runs on CPU and on non-default CUDA devices.
        batch_mask = torch.tensor(
            np.any(data_type_req == batches_data_type, axis=0),
            device=mask.device)
        new_mask = mask * batch_mask[:, None, None, None]
        return new_mask

    def label_downsample(self, downsample_factor, data_dict):
        """Shrink GT depth maps in-place by ``downsample_factor`` (nearest)."""
        scale_factor = float(1.0 / downsample_factor)
        data_dict['target'] = F.interpolate(
            data_dict['target'], scale_factor=scale_factor)
        data_dict['stereo_depth'] = F.interpolate(
            data_dict['stereo_depth'], scale_factor=scale_factor)
        return data_dict

    def pred_upsample(self, downsample_factor, data_dict):
        """Grow predictions in-place by ``downsample_factor``; detaches them."""
        scale_factor = float(downsample_factor)
        data_dict['prediction'] = F.interpolate(
            data_dict['prediction'], scale_factor=scale_factor).detach()
        data_dict['confidence'] = F.interpolate(
            data_dict['confidence'], scale_factor=scale_factor).detach()
        return data_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def min_pool2d(tensor, kernel, stride=1):
    """Min-pool ``tensor`` using max_pool2d on the sign-flipped input.

    ``torch.nn.functional`` has no native min-pool; min(x) == -max(-x).
    With stride == 1 and an odd kernel, padding of kernel // 2 preserves
    the spatial size.
    """
    flipped = tensor * -1.0
    pooled = F.max_pool2d(flipped, kernel, padding=kernel // 2, stride=stride)
    return pooled * -1.0