Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
class SoftWeight(nn.Module):
    """
    Transfer n-channel discrete depth bins to a depth map.

    The logits are softmax-normalised over the channel axis and used as
    weights for the per-channel bin values. The weighted sum is treated as a
    log10 depth and exponentiated with base 10.

    Args:
        depth_bins_border: one value per channel/bin (sequence or tensor).
            Registered as a non-persistent buffer, so it moves with the module
            on ``.to()`` / ``.cuda()`` but is not written to the state dict.
    """

    def __init__(self, depth_bins_border):
        super().__init__()
        if isinstance(depth_bins_border, torch.Tensor):
            # Copy so the buffer never aliases the caller's tensor (and avoid
            # the warning torch.tensor() emits when handed a tensor).
            bins = depth_bins_border.detach().clone()
        else:
            bins = torch.tensor(depth_bins_border)
        self.register_buffer("depth_bins_border", bins, persistent=False)

    def forward(self, pred_logit):
        """
        Args:
            pred_logit: n-channel output of the network, [b, c, h, w].
                Tensors are used as-is; any other array-like is converted to
                a float32 tensor on the device of the bin buffer.
        Return:
            depth: 1-channel depth, [b, 1, h, w]
            confidence: maximum logit over the channel axis, [b, 1, h, w]
        """
        if not isinstance(pred_logit, torch.Tensor):
            # Follow the buffer's device instead of assuming CUDA, so the
            # module also works on CPU-only hosts and after .to(device).
            pred_logit = torch.tensor(
                pred_logit,
                dtype=torch.float32,
                device=self.depth_bins_border.device,
            )

        pred_score = nn.functional.softmax(pred_logit, dim=1)
        # Channels last so the [c]-shaped bin buffer broadcasts over them.
        pred_score_ch = pred_score.permute(0, 2, 3, 1)  # [b, h, w, c]
        pred_score_weight = pred_score_ch * self.depth_bins_border
        depth_log = torch.sum(
            pred_score_weight, dim=3, dtype=torch.float32, keepdim=True
        )
        depth = 10 ** depth_log
        depth = depth.permute(0, 3, 1, 2)  # [b, 1, h, w]
        # Confidence is taken from the raw logits, not the softmax scores.
        confidence, _ = torch.max(pred_logit, dim=1, keepdim=True)
        return depth, confidence
def soft_weight(pred_logit, depth_bins_border):
    """
    Transfer n-channel discrete depth bins to depth map.

    The logits are softmax-normalised over the channel axis and used as
    weights for the per-channel bin values. The weighted sum is treated as a
    log10 depth and exponentiated with base 10.

    Args:
        pred_logit: n-channel output of the network, [b, c, h, w]. Tensors
            are used as-is; any other array-like is converted to float32.
        depth_bins_border: one value per channel/bin, [c]. Tensors are used
            as-is; any other array-like is converted to float32 on the device
            of ``pred_logit``.
    Return:
        depth: 1-channel depth, [b, 1, h, w]
        confidence: maximum logit over the channel axis, [b, 1, h, w]
    """
    if not isinstance(pred_logit, torch.Tensor):
        if isinstance(depth_bins_border, torch.Tensor):
            # Keep both operands on the same device.
            device = depth_bins_border.device
        else:
            # Prefer CUDA as before, but do not crash on CPU-only hosts.
            device = "cuda" if torch.cuda.is_available() else "cpu"
        pred_logit = torch.tensor(pred_logit, dtype=torch.float32, device=device)
    if not isinstance(depth_bins_border, torch.Tensor):
        depth_bins_border = torch.tensor(
            depth_bins_border, dtype=torch.float32, device=pred_logit.device
        )

    pred_score = nn.functional.softmax(pred_logit, dim=1)
    # Channels last so the [c]-shaped bin tensor broadcasts over them.
    pred_score_ch = pred_score.permute(0, 2, 3, 1)  # [b, h, w, c]
    depth_bins_ch = pred_score_ch * depth_bins_border
    depth = torch.sum(depth_bins_ch, dim=3, dtype=torch.float32, keepdim=True)
    depth = 10 ** depth
    depth = depth.permute(0, 3, 1, 2)  # [b, 1, h, w]
    # Confidence is taken from the raw logits, not the softmax scores.
    confidence, _ = torch.max(pred_logit, dim=1, keepdim=True)
    return depth, confidence
if __name__ == '__main__':
    import numpy as np

    # Build evenly spaced bin values in log10 space between depth_min and
    # depth_max (each value sits half an interval into its bin), then
    # instantiate the module with them.
    depth_max = 100
    depth_min = 0.5
    num_bins = 200

    log_depth_min = np.log10(depth_min)
    depth_bin_interval = (np.log10(depth_max) - log_depth_min) / num_bins
    depth_bins_border = [
        log_depth_min + depth_bin_interval * (idx + 0.5)
        for idx in range(num_bins)
    ]

    sw = SoftWeight(depth_bins_border)