|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class UpSampleBN(nn.Module):
    """Decoder upsampling block (BatchNorm variant).

    Bilinearly upsamples `x` to `concat_with`'s spatial size, concatenates
    the two along the channel axis, then applies two
    Conv3x3 -> BatchNorm -> LeakyReLU stages.
    """

    def __init__(self, skip_input, output_features):
        super(UpSampleBN, self).__init__()
        # Two identical conv stages; only the first changes the channel count.
        stages = []
        in_ch = skip_input
        for _ in range(2):
            stages += [
                nn.Conv2d(in_ch, output_features, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(output_features),
                nn.LeakyReLU(),
            ]
            in_ch = output_features
        self._net = nn.Sequential(*stages)

    def forward(self, x, concat_with):
        # Match the skip connection's spatial resolution before concatenating.
        target_hw = [concat_with.size(2), concat_with.size(3)]
        upsampled = F.interpolate(x, size=target_hw, mode='bilinear', align_corners=True)
        fused = torch.cat([upsampled, concat_with], dim=1)
        return self._net(fused)
|
|
|
|
|
|
|
|
class UpSampleGN(nn.Module):
    """Decoder upsampling block (GroupNorm variant).

    Bilinearly upsamples `x` to `concat_with`'s spatial size, concatenates
    along channels, then applies two stages of the module-local
    weight-standardized `Conv2d` -> GroupNorm(8 groups) -> LeakyReLU.
    NOTE(review): nn.GroupNorm(8, C) requires `output_features` to be
    divisible by 8 — confirm callers satisfy this.
    """

    def __init__(self, skip_input, output_features):
        super(UpSampleGN, self).__init__()
        # Two identical conv stages; only the first changes the channel count.
        stages = []
        in_ch = skip_input
        for _ in range(2):
            stages += [
                Conv2d(in_ch, output_features, kernel_size=3, stride=1, padding=1),
                nn.GroupNorm(8, output_features),
                nn.LeakyReLU(),
            ]
            in_ch = output_features
        self._net = nn.Sequential(*stages)

    def forward(self, x, concat_with):
        # Match the skip connection's spatial resolution before concatenating.
        target_hw = [concat_with.size(2), concat_with.size(3)]
        upsampled = F.interpolate(x, size=target_hw, mode='bilinear', align_corners=True)
        fused = torch.cat([upsampled, concat_with], dim=1)
        return self._net(fused)
|
|
|
|
|
|
|
|
class Conv2d(nn.Conv2d):
    """Conv2d with Weight Standardization.

    On every forward pass the kernel is re-centered to zero mean and
    re-scaled to (approximately) unit standard deviation, computed per
    output channel, before the convolution is applied. Parameters match
    `nn.Conv2d`.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
                                     padding, dilation, groups, bias)

    def forward(self, x):
        w = self.weight
        # Per-output-channel mean over (in_channels, kH, kW).
        mu = w.mean(dim=(1, 2, 3), keepdim=True)
        centered = w - mu
        # Unbiased std over each output channel's flattened kernel; the
        # epsilon guards against division by zero for constant kernels.
        sigma = centered.view(centered.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
        standardized = centered / sigma.expand_as(centered)
        return F.conv2d(x, standardized, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
|
|
|
|
|
|
|
|
def norm_normalize(norm_out, min_kappa=0.01):
    """L2-normalize the predicted surface normal and activate the kappa channel.

    Args:
        norm_out: (B, 4, H, W) tensor whose channels are (nx, ny, nz, raw kappa).
        min_kappa: strictly-positive floor added to the activated concentration
            (default 0.01, matching the previously hard-coded constant).

    Returns:
        (B, 4, H, W) tensor with (nx, ny, nz) unit-normalized and
        kappa = elu(raw) + 1 + min_kappa, which is always > min_kappa.
    """
    norm_x, norm_y, norm_z, kappa = torch.split(norm_out, 1, dim=1)
    # Epsilon keeps the division finite for an (unlikely) all-zero normal.
    norm = torch.sqrt(norm_x ** 2.0 + norm_y ** 2.0 + norm_z ** 2.0) + 1e-10
    # elu(x) + 1 maps R -> (0, inf); min_kappa enforces the positive floor.
    kappa = F.elu(kappa) + 1.0 + min_kappa
    final_out = torch.cat([norm_x / norm, norm_y / norm, norm_z / norm, kappa], dim=1)
    return final_out
|
|
|
|
|
|
|
|
@torch.no_grad()
def sample_points(init_normal, gt_norm_mask, sampling_ratio, beta):
    """Sample N = int(sampling_ratio * H * W) pixel locations per image.

    A fraction `beta` of the N points is taken by importance — pixels with
    the lowest predicted kappa (channel 3), i.e. highest uncertainty — and
    the remainder is drawn uniformly at random without replacement from the
    remaining pixels (coverage).

    Args:
        init_normal: (B, 4, H, W) prediction; channel 3 is the kappa
            (confidence) map, so -kappa acts as the uncertainty score.
        gt_norm_mask: optional (B, 1, H', W') validity mask; pixels whose
            nearest-neighbor-resized mask is < 0.5 get uncertainty -1e4 so
            they are never importance-sampled.
        sampling_ratio: fraction of the H*W pixels to sample.
        beta: fraction of the N samples taken by importance (0..1).

    Returns:
        point_coords: (B, 1, N, 2) grid_sample-style coords in [-1, 1],
            ordered (x=col, y=row).
        rows_int, cols_int: (B, N) integer pixel coordinates of the samples.
    """
    device = init_normal.device
    B, _, H, W = init_normal.shape
    N = int(sampling_ratio * H * W)

    # Higher uncertainty == lower predicted kappa (channel 3).
    uncertainty_map = -1 * init_normal[:, 3, :, :]

    # Push ground-truth-invalid pixels to the bottom of the ranking.
    if gt_norm_mask is not None:
        gt_invalid_mask = F.interpolate(gt_norm_mask.float(), size=[H, W], mode='nearest')
        gt_invalid_mask = gt_invalid_mask[:, 0, :, :] < 0.5
        uncertainty_map[gt_invalid_mask] = -1e4

    # Flattened pixel indices, most uncertain first.
    _, idx = uncertainty_map.view(B, -1).sort(1, descending=True)

    n_importance = int(beta * N)
    if n_importance > 0:
        importance = idx[:, :n_importance]
        remaining = idx[:, n_importance:]
        num_coverage = N - n_importance
        if num_coverage <= 0:
            # beta >= 1: pure importance sampling.
            samples = importance
        else:
            coverage = _random_coverage(remaining, num_coverage)
            samples = torch.cat((importance, coverage), dim=1)
    else:
        # beta * N rounds to 0: pure coverage sampling over all pixels.
        samples = _random_coverage(idx, N)

    # Convert flat indices to integer (row, col) and to [-1, 1] float coords
    # (the (x=col, y=row) convention expected by F.grid_sample).
    rows_int = samples // W
    rows_float = (rows_int / float(H - 1)) * 2.0 - 1.0
    cols_int = samples % W
    cols_float = (cols_int / float(W - 1)) * 2.0 - 1.0

    point_coords = torch.zeros(B, 1, N, 2)
    point_coords[:, 0, :, 0] = cols_float
    point_coords[:, 0, :, 1] = rows_float
    point_coords = point_coords.to(device)
    return point_coords, rows_int, cols_int


def _random_coverage(remaining, num_coverage):
    """For each batch item, draw `num_coverage` indices uniformly without
    replacement from `remaining` (shape (B, M))."""
    coverage_list = []
    for i in range(remaining.size(0)):
        idx_c = torch.randperm(remaining.size(1))
        coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1))
    return torch.cat(coverage_list, dim=0)