import torch
import torch.nn as nn
import numpy as np

class HDNRandomLoss(nn.Module):
    """
    Hieratical depth normalization loss. Replace the original hieratical depth ranges with randomly sampled ranges.
    loss = MAE((d-median(d)/s - (d'-median(d'))/s'), s = mean(d- median(d))
    """
    def __init__(self,
                 loss_weight=1,
                 random_num=32,
                 data_type=['sfm', 'stereo', 'lidar', 'denselidar', 'denselidar_nometric', 'denselidar_syn'],
                 norm_dataset=['Taskonomy', 'Matterport3D', 'Replica', 'Hypersim'],
                 disable_dataset=['MapillaryPSD'],
                 **kwargs):
        super(HDNRandomLoss, self).__init__()
        self.loss_weight = loss_weight
        self.random_num = random_num
        self.eps = 1e-6
        self.data_type = data_type
        self.disable_dataset = disable_dataset
    
    def get_random_masks_for_batch(self, depth_gt: torch.Tensor, mask_valid: torch.Tensor) -> torch.Tensor:
        valid_values = depth_gt[mask_valid]
        max_d = valid_values.max().item() if valid_values.numel() > 0 else 0.0
        min_d = valid_values.min().item() if valid_values.numel() > 0 else 0.0

        # Sample `random_num` sub-ranges of the valid depth interval. Draw the
        # range endpoints as fractions in [0, 1] first and map them to absolute
        # depths afterwards, so the lower bound passed to the second uniform
        # draw stays below its upper bound.
        sample_min_f = np.random.uniform(0, 0.75, self.random_num)
        sample_max_f = np.random.uniform(sample_min_f + 0.1, 1 - self.eps, self.random_num)
        sample_min_d = sample_min_f * (max_d - min_d) + min_d
        sample_max_d = sample_max_f * (max_d - min_d) + min_d

        mask_new = [(depth_gt >= sample_min_d[i]) & (depth_gt < sample_max_d[i] + 1e-30) & mask_valid
                    for i in range(self.random_num)]
        mask_new = torch.stack(mask_new, dim=0).to(depth_gt.device)  # [N, 1, H, W]
        return mask_new

    def ssi_mae(self, prediction, target, mask_valid):
        B, C, H, W = target.shape
        # NaN-masked copies so that nanmedian ignores invalid pixels; the
        # prediction copy is detached so no gradient flows through the median.
        prediction_nan = prediction.clone().detach()
        target_nan = target.clone()
        prediction_nan[~mask_valid] = float('nan')
        target_nan[~mask_valid] = float('nan')

        valid_pixs = mask_valid.reshape((B, C, -1)).sum(dim=2, keepdim=True) + self.eps
        valid_pixs = valid_pixs[:, :, :, None]  # [B, C, 1, 1]

        # Scale-shift-invariant transform of the ground truth:
        # (d - median(d)) / mean(|d - median(d)|) over the valid pixels.
        gt_median = target_nan.reshape((B, C, -1)).nanmedian(2, keepdim=True)[0].unsqueeze(-1)  # [B, C, 1, 1]
        gt_median[torch.isnan(gt_median)] = 0
        gt_diff = (torch.abs(target - gt_median) * mask_valid).reshape((B, C, -1))
        gt_s = gt_diff.sum(dim=2)[:, :, None, None] / valid_pixs
        gt_trans = (target - gt_median) / (gt_s + self.eps)

        # Same transform for the prediction.
        pred_median = prediction_nan.reshape((B, C, -1)).nanmedian(2, keepdim=True)[0].unsqueeze(-1)  # [B, C, 1, 1]
        pred_median[torch.isnan(pred_median)] = 0
        pred_diff = (torch.abs(prediction - pred_median) * mask_valid).reshape((B, C, -1))
        pred_s = pred_diff.sum(dim=2)[:, :, None, None] / valid_pixs
        pred_trans = (prediction - pred_median) / (pred_s + self.eps)

        # Unnormalized L1 distance; the caller divides by the total valid-pixel count.
        loss_sum = torch.sum(torch.abs(gt_trans - pred_trans) * mask_valid)
        return loss_sum

    def forward(self, prediction, target, mask=None, **kwargs):
        """
        Calculate the loss. Requires a `dataset` entry in kwargs with one
        dataset name per batch element; samples from datasets listed in
        `disable_dataset` contribute zero loss.
        """
        B, C, H, W = target.shape

        loss = 0.0
        valid_pix = 0.0

        device = target.device

        batches_dataset = kwargs['dataset']
        self.batch_valid = torch.tensor(
            [1 if batch_dataset not in self.disable_dataset else 0
             for batch_dataset in batches_dataset],
            device=device)[:, None, None, None]

        # Evaluate the random-range masks in chunks of `batch_limit` to bound memory.
        batch_limit = 4
        loops = int(np.ceil(self.random_num / batch_limit))
        for i in range(B):
            mask_i = mask[i, ...]  # [1, H, W]

            if self.batch_valid[i, ...] < 0.5:
                # Disabled sample: add a zero term that still touches the
                # prediction so the computation graph stays connected.
                loss += 0 * torch.sum(prediction[i, ...])
                valid_pix += 0 * torch.sum(mask_i)
                continue

            pred_i = prediction[i, ...].unsqueeze(0).repeat(batch_limit, 1, 1, 1)
            target_i = target[i, ...].unsqueeze(0).repeat(batch_limit, 1, 1, 1)
            mask_random_drange = self.get_random_masks_for_batch(target[i, ...], mask_i)  # [N, 1, H, W]
            for j in range(loops):
                mask_random_loopi = mask_random_drange[j * batch_limit:(j + 1) * batch_limit, ...]
                loss += self.ssi_mae(
                    prediction=pred_i[:mask_random_loopi.shape[0], ...],
                    target=target_i[:mask_random_loopi.shape[0], ...],
                    mask_valid=mask_random_loopi)
                valid_pix += torch.sum(mask_random_loopi)

        loss = loss / (valid_pix + self.eps)
        if torch.isnan(loss) or torch.isinf(loss):
            loss = 0 * torch.sum(prediction)
            print(f'HDNL NaN error, {loss}, valid pix: {valid_pix}')
        return loss * self.loss_weight
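
def _demo_ssi_normalization():
    """Toy illustration added for exposition (not part of the original loss):
    the scale-shift-invariant transform used in `ssi_mae`,
        (d - median(d)) / mean(|d - median(d)|),
    maps any positive affine copy a * d + b of a depth map to the same
    normalized values, which is what lets the loss compare depths across
    datasets with different metric scales."""
    d = torch.tensor([1.0, 2.0, 4.0, 8.0])
    for scale, shift in [(1.0, 0.0), (3.0, 1.0)]:
        x = scale * d + shift
        med = x.median()
        s = (x - med).abs().mean()
        print((x - med) / s)  # prints the same values for both (scale, shift) pairs
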
    
if __name__ == '__main__':
    ssil = HDNRandomLoss()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    pred = torch.rand((2, 1, 256, 256), device=device)
    # Degenerate ground truth with no positive depths, so the valid mask is
    # empty and the zero-loss fallback path is exercised.
    gt = -torch.rand((2, 1, 256, 256), device=device)
    gt[:, :, 100:256, 0:100] = -1
    mask = gt > 0
    # forward() requires one dataset name per batch element via kwargs.
    out = ssil(pred, gt, mask, dataset=['Taskonomy', 'Taskonomy'])
    print(out)
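
    # Hedged sanity check (illustrative): because ssi_mae normalizes out scale
    # and shift, a prediction that is a positive affine transform of the
    # ground truth should yield a loss close to zero.
    gt_pos = torch.rand((2, 1, 256, 256), device=pred.device) + 0.5
    pred_affine = 3.0 * gt_pos + 1.0
    mask_pos = torch.ones_like(gt_pos, dtype=torch.bool)
    out_affine = ssil(pred_affine, gt_pos, mask_pos, dataset=['Taskonomy', 'Taskonomy'])
    print(f'affine-invariance check (expect ~0): {out_affine}')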