zach
initial commit based on github repo
3ef1661
import torch
import torch.nn as nn
class ScaleInvL1Loss(nn.Module):
    """
    Compute scale-invariant L1 loss.

    Each sample's prediction is rescaled by
    ``median(target) / median(prediction)``, with both medians taken over
    valid (masked-in) elements only. The prediction is detached when its
    median is computed, so the scale factor is a constant w.r.t. the
    prediction's gradient. The loss is the mean absolute error between the
    rescaled prediction and the target over all valid elements in the batch,
    multiplied by ``loss_weight``.
    """

    # Default tags, copied into a fresh list per instance so instances never
    # share (and accidentally mutate) one default list object.
    _DEFAULT_DATA_TYPE = ('sfm', 'denselidar_nometric', 'denselidar_syn')

    def __init__(self, loss_weight=1, data_type=None, **kwargs):
        """
        Args:
            loss_weight: Multiplier applied to the final loss value.
            data_type: List of data-type tags stored as ``self.data_type``.
                Defaults to ``['sfm', 'denselidar_nometric', 'denselidar_syn']``.
                Not read inside this class; presumably consulted by the
                caller to decide when to apply this loss (verify).
            **kwargs: Ignored; accepted so config dicts can be splatted in.
        """
        super().__init__()
        self.loss_weight = loss_weight
        if data_type is None:
            data_type = list(self._DEFAULT_DATA_TYPE)
        self.data_type = data_type
        self.eps = 1e-6

    @staticmethod
    def _masked_median(x, mask):
        """Per-sample median of ``x`` over elements where ``mask`` is True (NaN if none)."""
        x_nan = x.clone()
        x_nan[~mask] = torch.nan
        # reshape (not view): clone() preserves memory format, which may be
        # non-contiguous (e.g. channels_last).
        return torch.nanmedian(x_nan.reshape(x_nan.shape[0], -1), dim=1).values

    def forward(self, prediction, target, mask=None, **kwargs):
        """
        Args:
            prediction: Predicted tensor whose first dim is the batch; assumed
                to have the same shape as ``target``.
            target: Ground-truth tensor, batch-first.
            mask: Validity mask with the same shape as ``target`` (converted
                to bool). If None, every element is treated as valid.
            **kwargs: Ignored.

        Returns:
            Scalar tensor ``loss_weight * masked_mean(|scale * prediction - target|)``.
            If that value is non-finite, a message is printed and
            ``0 * sum(prediction)`` is used instead, which keeps the result
            attached to ``prediction``'s autograd graph.
        """
        if mask is None:
            mask = torch.ones_like(target, dtype=torch.bool)
        else:
            mask = mask.bool()

        median_target = self._masked_median(target, mask)
        median_prediction = self._masked_median(prediction.detach(), mask)
        scale = median_target / median_prediction
        # No valid elements (NaN median) or a zero prediction median (inf/NaN
        # scale): give that sample a scale of 0 instead of poisoning the batch.
        scale[~torch.isfinite(scale)] = 0
        pred_scale = prediction * scale.reshape(-1, *([1] * (prediction.dim() - 1)))

        # Select valid elements instead of multiplying by the mask: NaN/inf at
        # invalid positions would otherwise leak through (nan * 0 == nan) in
        # both the forward value and the gradient.
        diff = torch.abs(pred_scale[mask] - target[mask])
        loss = torch.sum(diff) / (torch.sum(mask) + self.eps)

        if not torch.isfinite(loss):
            # Report the offending value before it is replaced.
            print(f'Scale-invariant L1 NAN error, {loss}')
            loss = 0 * torch.sum(prediction)
        return loss * self.loss_weight