Spaces:
Running
on
L4
Running
on
L4
import torch | |
import torch.nn as nn | |
import numpy as np | |
import torch.nn.functional as F | |
# compute loss
class compute_loss(nn.Module):
    """Surface-normal loss, optionally with a per-pixel concentration (kappa).

    args.loss_fn can be one of following:
        - L1          - L1 loss (no uncertainty)
        - L2          - L2 loss (no uncertainty)
        - AL          - Angular loss (no uncertainty)
        - NLL_vMF     - NLL of vonMF distribution
        - NLL_ours    - NLL of Angular vonMF distribution
        - UG_NLL_vMF  - NLL of vonMF distribution (+ pixel-wise MLP + uncertainty-guided sampling)
        - UG_NLL_ours - NLL of Angular vonMF distribution (+ pixel-wise MLP + uncertainty-guided sampling)

    Calling the module dispatches to ``forward_R`` (one dense prediction) or
    ``forward_UG`` (a list of dense and/or point-sampled predictions),
    depending on ``args.loss_fn``.
    """

    # Loss types evaluated on a single dense prediction by ``forward_R``.
    _DENSE_LOSSES = ('L1', 'L2', 'AL', 'NLL_vMF', 'NLL_ours')
    # Loss types evaluated on a list of predictions by ``forward_UG``.
    _UG_LOSSES = ('UG_NLL_vMF', 'UG_NLL_ours')
    # Cosine cut-off: acos has an unbounded derivative at +-1, so pixels whose
    # prediction is (anti-)parallel to the ground truth are excluded.
    _DOT_EPS = 0.999

    def __init__(self, args):
        """Select the forward implementation from ``args.loss_fn``.

        Raises:
            ValueError: if ``args.loss_fn`` is not a supported loss name
                (a subclass of ``Exception``, so existing handlers still match).
        """
        super().__init__()
        self.loss_type = args.loss_fn
        if self.loss_type in self._DENSE_LOSSES:
            self.loss_fn = self.forward_R
        elif self.loss_type in self._UG_LOSSES:
            self.loss_fn = self.forward_UG
        else:
            raise ValueError('invalid loss type')

    def forward(self, *args):
        """Dispatch to ``forward_R`` or ``forward_UG`` (chosen in ``__init__``)."""
        return self.loss_fn(*args)

    # ------------------------------------------------------------------ helpers

    @classmethod
    def _valid_mask(cls, dot, gt_mask):
        """Return a bool mask of labelled pixels whose cosine is strictly inside (-eps, eps).

        Assumes ``gt_mask`` is boolean (or 0/1) and has the same shape as ``dot``;
        ``forward_R`` already relies on a boolean mask when it indexes with it.
        """
        dot = dot.detach()
        return gt_mask.bool() & (dot < cls._DOT_EPS) & (dot > -cls._DOT_EPS)

    @staticmethod
    def _masked_mean(values):
        """Mean of ``values``; an empty selection yields 0 (still in the graph) instead of NaN."""
        if values.numel() == 0:
            return values.sum()
        return torch.mean(values)

    @staticmethod
    def _nll_vmf(dot, kappa):
        """Per-pixel NLL of the von Mises-Fisher distribution.

        -log(kappa) - kappa * (dot - 1) + log(1 - exp(-2 * kappa)); the last
        term uses ``expm1`` to avoid cancellation when kappa is small.
        """
        return (-torch.log(kappa)
                - kappa * (dot - 1)
                + torch.log(-torch.expm1(-2 * kappa)))

    @staticmethod
    def _nll_angular_vmf(dot, kappa):
        """Per-pixel NLL of the angular vMF distribution.

        -log(kappa^2 + 1) + kappa * acos(dot) + log(1 + exp(-kappa * pi)).
        """
        return (-torch.log(torch.square(kappa) + 1)
                + kappa * torch.acos(dot)
                + torch.log1p(torch.exp(-kappa * np.pi)))

    def _nll_loss(self, pred_norm, kappa_map, target, gt_mask, use_vmf):
        """Mean NLL over the valid pixels/points of one prediction.

        ``pred_norm``/``target`` hold the normal on dim 1; ``kappa_map`` and
        ``gt_mask`` have the shape of their cosine similarity (dim 1 removed).
        """
        dot = torch.cosine_similarity(pred_norm, target, dim=1)
        valid = self._valid_mask(dot, gt_mask)
        dot, kappa = dot[valid], kappa_map[valid]
        nll_fn = self._nll_vmf if use_vmf else self._nll_angular_vmf
        return self._masked_mean(nll_fn(dot, kappa))

    # ----------------------------------------------------------------- forwards

    def forward_R(self, norm_out, gt_norm, gt_norm_mask):
        """Loss for a single dense prediction.

        Args:
            norm_out: prediction indexed as (B, C, H, W); channels 0:3 are the
                normal and, for the NLL losses, channel 3 is kappa.
            gt_norm: ground-truth normals, presumably (B, 3, H, W).
            gt_norm_mask: boolean mask of labelled pixels; used directly as an
                index into the (B, 1, H, W) per-pixel L1/L2 error.

        Returns:
            Scalar loss tensor (0 when no pixel is valid).

        Raises:
            ValueError: if the configured loss type is not a dense one.
        """
        pred_norm = norm_out[:, 0:3, :, :]
        if self.loss_type == 'L1':
            err = torch.sum(torch.abs(gt_norm - pred_norm), dim=1, keepdim=True)
            return self._masked_mean(err[gt_norm_mask])
        if self.loss_type == 'L2':
            err = torch.sum(torch.square(gt_norm - pred_norm), dim=1, keepdim=True)
            return self._masked_mean(err[gt_norm_mask])
        if self.loss_type == 'AL':
            dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
            valid = self._valid_mask(dot, gt_norm_mask[:, 0, :, :])
            return self._masked_mean(torch.acos(dot[valid]))
        if self.loss_type in ('NLL_vMF', 'NLL_ours'):
            return self._nll_loss(pred_norm, norm_out[:, 3, :, :], gt_norm,
                                  gt_norm_mask[:, 0, :, :],
                                  use_vmf=(self.loss_type == 'NLL_vMF'))
        raise ValueError('invalid loss type')

    def forward_UG(self, pred_list, coord_list, gt_norm, gt_norm_mask):
        """Sum of NLL losses over a list of predictions.

        Each ``coord`` of ``None`` marks a dense prediction that is bilinearly
        upsampled to the ground-truth resolution. Otherwise ``coord`` is
        (B, 1, N, 2) sample locations and ``pred`` is (B, 4, N); the ground
        truth is fetched at those locations with nearest-neighbour sampling.

        Returns:
            Sum of the per-prediction mean NLLs (``0.0`` for an empty list).

        Raises:
            ValueError: if the configured loss type is not a UG one.
        """
        if self.loss_type not in self._UG_LOSSES:
            raise ValueError('invalid loss type')
        use_vmf = self.loss_type == 'UG_NLL_vMF'

        loss = 0.0
        for pred, coord in zip(pred_list, coord_list):
            if coord is None:
                pred = F.interpolate(pred, size=[gt_norm.size(2), gt_norm.size(3)],
                                     mode='bilinear', align_corners=True)
                pred_norm, kappa_map = pred[:, 0:3, :, :], pred[:, 3, :, :]
                target, target_mask = gt_norm, gt_norm_mask[:, 0, :, :]
            else:
                # grid_sample returns (B, C, 1, N); drop the singleton row.
                target = F.grid_sample(gt_norm, coord, mode='nearest',
                                       align_corners=True)[:, :, 0, :]  # (B, 3, N)
                target_mask = F.grid_sample(gt_norm_mask.float(), coord, mode='nearest',
                                            align_corners=True)[:, 0, 0, :] > 0.5  # (B, N)
                pred_norm, kappa_map = pred[:, 0:3, :], pred[:, 3, :]
            loss = loss + self._nll_loss(pred_norm, kappa_map, target, target_mask, use_vmf)
        return loss