import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F

from unik3d.utils.geometric import dilate, downsample, erode

from .utils import FNS, masked_mean, masked_quantile

class LocalNormal(nn.Module):
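    """Local surface-normal consistency loss over 3D point maps.

    Normals are estimated for both prediction and target via finite-difference
    cross products; the loss is the masked mean of their cosine dissimilarity
    (with a small margin), after trimming the hardest `quantile` fraction of
    pixels per sample.
    """
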
    def __init__(
        self,
        weight: float,
        output_fn: str = "sqrt",
        min_samples: int = 4,
        quantile: float = 0.2,
        eps: float = 1e-5,
    ):
        super().__init__()
        self.name: str = self.__class__.__name__
        self.weight = weight
        self.output_fn = FNS[output_fn]
        self.min_samples = min_samples
        self.eps = eps
        # 3x3 averaging kernel; only used by the commented-out least-squares
        # variant of get_surface_normal below
        self.patch_weight = torch.ones(1, 1, 3, 3, device="cuda")
        self.quantile = quantile

    def bilateral_filter(self, rgb, surf, mask, patch_size=(9, 9)):
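        """Joint bilateral filter over a surface (normal) map.

        Each pixel's output is a weighted average of its `patch_size`
        neighbourhood, with Gaussian weights on colour, image-space, and
        surface distance; `mask` zeroes out invalid neighbours. Not called
        by `forward` in this module.
        """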
        B, _, H, W = rgb.shape
        sigma_surf = 0.4
        sigma_color = 0.3
        sigma_loc = 0.3 * max(H, W)
        grid_y, grid_x = torch.meshgrid(
            torch.arange(H), torch.arange(W), indexing="ij"
        )
        grid = torch.stack([grid_x, grid_y], dim=0).to(rgb.device)
        grid = grid.unsqueeze(0).repeat(B, 1, 1, 1)
        pad_h, pad_w = patch_size[0] // 2, patch_size[1] // 2
        rgbd = torch.cat([rgb, grid.float(), surf], dim=1)
        # gather neighbourhoods in (B, H*W, C, H_p*W_p) format;
        # F.pad takes (left, right, top, bottom) over the last two dims
        rgbd_neigh = F.pad(rgbd, (pad_w, pad_w, pad_h, pad_h), mode="constant")
        rgbd_neigh = F.unfold(rgbd_neigh, kernel_size=patch_size)
        rgbd_neigh = rgbd_neigh.permute(0, 2, 1).reshape(
            B, H * W, 8, -1
        )  # B N 8 H_p*W_p
        mask_neigh = F.pad(mask.float(), (pad_w, pad_w, pad_h, pad_h), mode="constant")
        mask_neigh = F.unfold(mask_neigh, kernel_size=patch_size)
        mask_neigh = mask_neigh.permute(0, 2, 1).reshape(B, H * W, -1)
        rgbd = rgbd.permute(0, 2, 3, 1).reshape(B, H * W, 8, 1)  # B H*W 8 1
        rgb_neigh = rgbd_neigh[:, :, :3, :]
        grid_neigh = rgbd_neigh[:, :, 3:5, :]
        surf_neigh = rgbd_neigh[:, :, 5:, :]
        rgb = rgbd[:, :, :3, :]
        grid = rgbd[:, :, 3:5, :]
        surf = rgbd[:, :, 5:, :]
        # squared distances between the centre pixel and each neighbour
        rgb_dist = torch.norm(rgb - rgb_neigh, dim=-2, p=2) ** 2
        grid_dist = torch.norm(grid - grid_neigh, dim=-2, p=2) ** 2
        surf_dist = torch.norm(surf - surf_neigh, dim=-2, p=2) ** 2
        rgb_sim = torch.exp(-rgb_dist / 2 / sigma_color**2)
        grid_sim = torch.exp(-grid_dist / 2 / sigma_loc**2)
        surf_sim = torch.exp(-surf_dist / 2 / sigma_surf**2)
        weight = mask_neigh * rgb_sim * grid_sim * surf_sim  # B H*W H_p*W_p
        weight = weight / weight.sum(dim=-1, keepdim=True).clamp(min=1e-5)
        z = (surf_neigh * weight.unsqueeze(-2)).sum(dim=-1)
        return z.reshape(B, H, W, 3).permute(0, 3, 1, 2)

    def get_surface_normal(self, xyz: torch.Tensor, mask: torch.Tensor):
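        """Estimate per-pixel unit normals from a 3D point map.

        Builds several finite-difference triangles per pixel by rolling the
        point map, takes their cross-product normals, and averages the unit
        normals over triangles whose three vertices all lie inside `mask`
        and whose edges are non-degenerate.
        """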
        P0 = xyz
        mask = mask.float()
        normals, masks_valid_triangle = [], []
        combinations = list(itertools.combinations_with_replacement([-2, -1, 1, 2], 2))
        combinations += [c[::-1] for c in combinations]
        # combinations = [(1, 1), (-1, -1), (1, -1), (-1, 1)]
        for shift_0, shift_1 in set(combinations):
            P1 = torch.roll(xyz, shifts=(0, shift_0), dims=(-1, -2))
            P2 = torch.roll(xyz, shifts=(shift_1, 0), dims=(-1, -2))
            # swap so the cross product keeps a consistent orientation
            # across shift sign combinations
            if (shift_0 > 0) ^ (shift_1 > 0):
                P1, P2 = P2, P1
            vec1, vec2 = P1 - P0, P2 - P0
            normal = torch.cross(vec1, vec2, dim=1)
            vec1_norm = torch.norm(vec1, dim=1, keepdim=True).clip(min=1e-8)
            vec2_norm = torch.norm(vec2, dim=1, keepdim=True).clip(min=1e-8)
            normal_norm = torch.norm(normal, dim=1, keepdim=True).clip(min=1e-8)
            normals.append(normal / normal_norm)
            is_valid = (
                torch.roll(mask, shifts=(0, shift_0), dims=(-1, -2))
                + torch.roll(mask, shifts=(shift_1, 0), dims=(-1, -2))
                + mask
                == 3
            )
            is_valid = (
                (normal_norm > 1e-6)
                & (vec1_norm > 1e-6)
                & (vec2_norm > 1e-6)
                & is_valid
            )
            masks_valid_triangle.append(is_valid)
        normals = torch.stack(normals, dim=-1)
        mask_valid_triangle = torch.stack(masks_valid_triangle, dim=-1).float()
        mask_valid = mask_valid_triangle.sum(dim=-1)
        normals = (normals * mask_valid_triangle).sum(dim=-1) / mask_valid.clamp(
            min=1.0
        )
        normals_norm = torch.norm(normals, dim=1, keepdim=True).clip(min=1e-8)
        normals = normals / normals_norm
        mask_valid = (
            (mask_valid > 0.001)
            & (~normals.sum(dim=1, keepdim=True).isnan())
            & (normals_norm > 1e-6)
        )
        return normals, mask_valid  # B 3 H W, B 1 H W

    # Alternative implementation: least-squares plane fit over a local patch
    # (kept for reference, currently unused).
    # def get_surface_normal(self, xyz: torch.Tensor, mask: torch.Tensor):
    #     x, y, z = torch.unbind(xyz, dim=1)  # B 3 H W
    #     x = x.unsqueeze(1)  # B 1 H W
    #     y = y.unsqueeze(1)
    #     z = z.unsqueeze(1)
    #     mask_float = mask.float()
    #     paddings = [self.patch_weight.shape[-2] // 2, self.patch_weight.shape[-1] // 2]
    #     num_samples = F.conv2d(mask_float, weight=self.patch_weight, padding=paddings).clamp(min=1.0)  # B 1 H W
    #     mask_invalid = num_samples < self.min_samples
    #     xx = x * x
    #     yy = y * y
    #     zz = z * z
    #     xy = x * y
    #     xz = x * z
    #     yz = y * z
    #     xx_patch = F.conv2d(xx * mask_float, weight=self.patch_weight, padding=paddings) / num_samples
    #     yy_patch = F.conv2d(yy * mask_float, weight=self.patch_weight, padding=paddings) / num_samples
    #     zz_patch = F.conv2d(zz * mask_float, weight=self.patch_weight, padding=paddings) / num_samples
    #     xy_patch = F.conv2d(xy * mask_float, weight=self.patch_weight, padding=paddings) / num_samples
    #     xz_patch = F.conv2d(xz * mask_float, weight=self.patch_weight, padding=paddings) / num_samples
    #     yz_patch = F.conv2d(yz * mask_float, weight=self.patch_weight, padding=paddings) / num_samples
    #     x_patch = F.conv2d(x * mask_float, weight=self.patch_weight, padding=paddings) / num_samples
    #     y_patch = F.conv2d(y * mask_float, weight=self.patch_weight, padding=paddings) / num_samples
    #     z_patch = F.conv2d(z * mask_float, weight=self.patch_weight, padding=paddings) / num_samples
    #     ATA = torch.stack([xx_patch, xy_patch, xz_patch, xy_patch, yy_patch, yz_patch, xz_patch, yz_patch, zz_patch], dim=-1).squeeze(1)  # B H W 9
    #     ATA = torch.reshape(ATA, (ATA.shape[0], ATA.shape[1], ATA.shape[2], 3, 3))  # B H W 3 3
    #     eps_identity = torch.eye(3, device=ATA.device, dtype=ATA.dtype).unsqueeze(0)  # 1 3 3
    #     ATA = ATA + 1e-6 * eps_identity
    #     AT1 = torch.stack([x_patch, y_patch, z_patch], dim=-1).squeeze(1).unsqueeze(-1)  # B H W 3 1
    #     det = torch.linalg.det(ATA)
    #     mask_invalid_inverse = det.abs() < 1e-12
    #     mask_invalid = mask_invalid.squeeze(1) | mask_invalid_inverse
    #     AT1[mask_invalid, :, :] = 0
    #     ATA[mask_invalid, :, :] = eps_identity
    #     ATA_inv = torch.linalg.inv(ATA)
    #     normals = (ATA_inv @ AT1).squeeze(dim=-1)  # B H W 3
    #     normals = normals / torch.norm(normals, dim=-1, keepdim=True).clip(min=1e-8)
    #     mask_invalid = mask_invalid | (torch.sum(normals, dim=-1) == 0.0)
    #     # flip normals: if a * x + b * y + c * z < 0 -> change sign of normals
    #     mean_patch_xyz = AT1.squeeze(-1)
    #     orient_mask = torch.sum(normals * mean_patch_xyz, dim=-1).unsqueeze(-1) > 0
    #     normals = (2 * orient_mask.to(ATA.dtype) - 1) * normals
    #     return normals.permute(0, 3, 1, 2), ~mask_invalid.unsqueeze(1)  # B 3 H W, B H W

    def forward(self, input: torch.Tensor, target: torch.Tensor, mask, valid):
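        """Compute the per-sample normal-consistency loss.

        `input` and `target` are B 3 H W point maps, `mask` flags valid
        target pixels, and `valid` flags usable samples in the batch.
        Returns a tensor of shape (B,).
        """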
        if not valid.any():
            return 0.0 * input.mean(dim=(1, 2, 3))
        input = input.float()
        target = target.float()
        mask = erode(mask, kernel_size=3)
        target_normal, mask_target = self.get_surface_normal(target[valid], mask[valid])
        input_normal, mask_input = self.get_surface_normal(
            input[valid], torch.ones_like(mask[valid])
        )
        gt_similarity = F.cosine_similarity(input_normal, target_normal, dim=1)  # B H W
        mask_target = (
            mask_target.squeeze(1) & (gt_similarity < 0.999) & (gt_similarity > -0.999)
        )
        error = F.relu((1 - gt_similarity) / 2 - 0.01)
        error_full = torch.ones_like(mask.squeeze(1).float())
        error_full[valid] = error
        mask_full = torch.ones_like(mask.squeeze(1))
        mask_full[valid] = mask_target
        error_qtl = error_full.detach()
        # drop the hardest `quantile` fraction of pixels per sample
        mask_full = mask_full & (
            error_qtl
            < masked_quantile(
                error_qtl, mask_full, dims=[1, 2], q=1 - self.quantile
            ).view(-1, 1, 1)
        )
        loss = masked_mean(error_full, mask=mask_full, dim=(-2, -1)).squeeze(
            dim=(-2, -1)
        )  # B
        loss = self.output_fn(loss)
        return loss

    def von_mises(self, input, target, mask, kappa):
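        """Von Mises-style NLL between normal maps, up to kappa-dependent
        normalisation constants; near-(anti)parallel pixels are masked out
        for numerical stability.
        """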
        score = torch.cosine_similarity(input, target, dim=1).unsqueeze(1)
        mask_cosine = torch.logical_and(
            mask, torch.logical_and(score.detach() < 0.999, score.detach() > -0.999)
        )
        nll = masked_mean(
            kappa * (1 - score), mask=mask_cosine, dim=(-1, -2, -3)
        ).squeeze()
        return nll

    @classmethod
    def build(cls, config):
        obj = cls(
            weight=config["weight"],
            output_fn=config["output_fn"],
            quantile=config.get("quantile", 0.2),
        )
        return obj
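
# Minimal usage sketch (hypothetical shapes and config values; the actual
# training pipeline wires this loss up through its own config system):
#
#   loss_fn = LocalNormal.build({"weight": 1.0, "output_fn": "sqrt"})
#   pred_xyz = torch.randn(2, 3, 64, 64)               # predicted point map
#   gt_xyz = torch.randn(2, 3, 64, 64)                 # ground-truth point map
#   mask = torch.ones(2, 1, 64, 64, dtype=torch.bool)  # valid-pixel mask
#   valid = torch.tensor([True, True])                 # per-sample validity
#   loss = loss_fn(pred_xyz, gt_xyz, mask, valid)      # shape (2,)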