import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
try:
    from pytorch3d.loss import chamfer_distance  # needed by chamfer_distance_loss
except ImportError:  # keep the module importable when pytorch3d is not installed
    chamfer_distance = None

class AdabinsLoss(nn.Module):
    """
    Losses employed in Adabins.
    """
    def __init__(self, depth_normalize, variance_focus=0.85, loss_weight=1, out_channel=100, data_type=['stereo', 'lidar'], w_ce=False, w_chamfer=False, **kwargs):
        super(AdabinsLoss, self).__init__()
        self.variance_focus = variance_focus
        self.loss_weight = loss_weight
        self.data_type = data_type
        #self.bins_num = out_channel
        #self.cel = nn.CrossEntropyLoss(ignore_index=self.bins_num + 1)
        self.depth_min = depth_normalize[0]
        self.depth_max = depth_normalize[1]
        self.w_ce = w_ce
        self.eps = 1e-6
    
    def silog_loss(self, prediction, target, mask):
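        # Scale-invariant log (SILog) loss: sqrt(E[d^2] - lambda * E[d]^2) with
        # d = log(pred) - log(gt) over valid pixels and lambda = variance_focus.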
        d = torch.log(prediction[mask]) - torch.log(target[mask])
        d_square_mean = torch.sum(d ** 2) / (d.numel() + self.eps)
        d_mean = torch.sum(d) / (d.numel() + self.eps)
        loss = torch.sqrt(d_square_mean - self.variance_focus * (d_mean ** 2))
        return loss
    
    def chamfer_distance_loss(self, bins, target_depth_maps, mask):
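        # Encourage the predicted bin centers to follow the distribution of valid
        # ground-truth depths (the bi-directional chamfer term from AdaBins).
        if chamfer_distance is None:  # import is guarded at the top of the file
            raise RuntimeError('pytorch3d is required for chamfer_distance_loss')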
        bin_centers = 0.5 * (bins[:, 1:] + bins[:, :-1])
        n, p = bin_centers.shape
        input_points = bin_centers.view(n, p, 1)  # .shape = n, p, 1
        # n, c, h, w = target_depth_maps.shape

        # keep only the valid ground-truth depths per sample; the resulting lists
        # are ragged, so pad them and pass the true lengths to chamfer_distance
        target_points = [t[m] for t, m in zip(target_depth_maps, mask)]
        target_lengths = torch.tensor([len(t) for t in target_points], dtype=torch.long, device=bins.device)
        target_points = pad_sequence(target_points, batch_first=True).unsqueeze(2)  # .shape = n, T, 1

        loss, _ = chamfer_distance(x=input_points, y=target_points, y_lengths=target_lengths)
        return loss
    
    # def depth_to_bins(self, depth, mask, depth_edges, size_limite=(512, 960)):
    #     """
    #     Discretize depth into depth bins. Predefined bins edges are provided.
    #     Mark invalid padding area as bins_num + 1
    #     Args:
    #         @depth: 1-channel depth, [B, 1, h, w]
    #     return: depth bins [B, C, h, w]
    #     """ 
    #     def _depth_to_bins_block_(depth, mask, depth_edges):
    #         bins_id = torch.sum(depth_edges[:, None, None, None, :] < torch.abs(depth)[:, :, :, :, None], dim=-1)
    #         bins_id = bins_id - 1
    #         invalid_mask = ~mask
    #         mask_lower = (depth <= self.depth_min) 
    #         mask_higher = (depth >= self.depth_max)
            
    #         bins_id[mask_lower] = 0
    #         bins_id[mask_higher] = self.bins_num - 1
    #         bins_id[bins_id == self.bins_num] = self.bins_num - 1

    #         bins_id[invalid_mask] = self.bins_num + 1
    #         return bins_id
    #     # _, _, H, W = depth.shape
    #     # bins = mask.clone().long()
    #     # h_blocks = np.ceil(H / size_limite[0]).astype(np.int)
    #     # w_blocks = np.ceil(W/ size_limite[1]).astype(np.int)
    #     # for i in range(h_blocks):
    #     #     for j in range(w_blocks):
    #     #         h_start = i*size_limite[0]
    #     #         h_end_proposal = (i + 1) * size_limite[0]
    #     #         h_end = h_end_proposal if h_end_proposal < H else H
    #     #         w_start = j*size_limite[1]
    #     #         w_end_proposal = (j + 1) * size_limite[1]
    #     #         w_end = w_end_proposal if w_end_proposal < W else W
    #     #         bins_ij = _depth_to_bins_block_(
    #     #             depth[:, :, h_start:h_end, w_start:w_end], 
    #     #             mask[:, :, h_start:h_end, w_start:w_end],
    #     #             depth_edges
    #     #             )
    #     #         bins[:, :, h_start:h_end, w_start:w_end] = bins_ij        
    #     bins = _depth_to_bins_block_(depth, mask, depth_edges)
    #     return bins
    
    # def ce_loss(self, pred_logit, target, mask, bins_edges):
    #     target_depth_bins = self.depth_to_bins(target, mask, bins_edges)
    #     loss = self.cel(pred_logit, target_depth_bins.squeeze().long())
    #     return loss


    def forward(self, prediction, target, bins_edges, mask=None, **kwargs):
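        # Only the SILog term is active; bins_edges is kept in the signature for
        # the disabled chamfer and cross-entropy terms below.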
        silog_loss = self.silog_loss(prediction=prediction, target=target, mask=mask)
        #cf_loss = self.chamfer_distance_loss(bins=bins_edges, target_depth_maps=target, mask=mask)
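        # weight the SILog term by 10, matching the alpha = 10 scaling used in AdaBins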
        loss = silog_loss * 10  # + 0.1 * cf_loss
        # if self.w_ce:
        #     loss = loss + self.ce_loss(kwargs['pred_logit'], target, mask, bins_edges)
        if torch.isnan(loss) or torch.isinf(loss):
            raise RuntimeError(f'Adabins loss error, {loss}')
        return loss * self.loss_weight
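

# A minimal smoke-test sketch (not part of the original file): the tensor shapes
# and the depth range below are assumptions inferred from the signatures above.
if __name__ == "__main__":
    criterion = AdabinsLoss(depth_normalize=(1e-3, 80.0), out_channel=100)
    pred = torch.rand(2, 1, 64, 64) * 79.0 + 1.0   # predicted depth, [B, 1, h, w]
    gt = torch.rand(2, 1, 64, 64) * 79.0 + 1.0     # ground-truth depth, same shape
    mask = gt > 1e-3                               # boolean mask of valid pixels
    edges = torch.linspace(1e-3, 80.0, 101).repeat(2, 1)  # [B, bins + 1] bin edges
    print(criterion(pred, gt, edges, mask=mask))   # scalar loss tensor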