import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F
from unik3d.utils.geometric import dilate, downsample, erode
from .utils import FNS, masked_mean, masked_quantile
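

# LocalNormal: surface-normal consistency loss. Normals are estimated from
# predicted and ground-truth 3D point maps via local cross products, compared
# with cosine similarity, and reduced to one value per sample after
# per-sample quantile trimming of the hardest pixels.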
class LocalNormal(nn.Module):
def __init__(
self,
weight: float,
output_fn: str = "sqrt",
min_samples: int = 4,
quantile: float = 0.2,
eps: float = 1e-5,
):
super(LocalNormal, self).__init__()
self.name: str = self.__class__.__name__
self.weight = weight
self.output_fn = FNS[output_fn]
self.min_samples = min_samples
self.eps = eps
        # Register as a buffer so the kernel follows the module's device
        # instead of being hard-coded to CUDA at construction time.
        self.register_buffer("patch_weight", torch.ones(1, 1, 3, 3), persistent=False)
self.quantile = quantile
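
    # Joint bilateral smoothing of a per-pixel surface field: each pixel is a
    # weighted average over a local patch, with weights combining color,
    # image-position, and surface similarity, restricted to valid neighbours.
    # Note: this helper is not used by forward().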
def bilateral_filter(self, rgb, surf, mask, patch_size=(9, 9)):
B, _, H, W = rgb.shape
sigma_surf = 0.4
sigma_color = 0.3
sigma_loc = 0.3 * max(H, W)
        grid_y, grid_x = torch.meshgrid(
            torch.arange(H), torch.arange(W), indexing="ij"
        )
grid = torch.stack([grid_x, grid_y], dim=0).to(rgb.device)
grid = grid.unsqueeze(0).repeat(B, 1, 1, 1)
        paddings = [patch_size[0] // 2, patch_size[1] // 2]
        # F.pad expects (left, right, top, bottom), i.e. width padding first.
        pad = [paddings[1], paddings[1], paddings[0], paddings[0]]
rgbd = torch.cat([rgb, grid.float(), surf], dim=1)
# format to B,H*W,C,H_p*W_p format
        rgbd_neigh = F.pad(rgbd, pad, mode="constant")
rgbd_neigh = F.unfold(rgbd_neigh, kernel_size=patch_size)
rgbd_neigh = rgbd_neigh.permute(0, 2, 1).reshape(
B, H * W, 8, -1
) # B N 8 H_p*W_p
        mask_neigh = F.pad(mask.float(), pad, mode="constant")
mask_neigh = F.unfold(mask_neigh, kernel_size=patch_size)
mask_neigh = mask_neigh.permute(0, 2, 1).reshape(B, H * W, -1)
rgbd = rgbd.permute(0, 2, 3, 1).reshape(B, H * W, 8, 1) # B H*W 8 1
rgb_neigh = rgbd_neigh[:, :, :3, :]
grid_neigh = rgbd_neigh[:, :, 3:5, :]
surf_neigh = rgbd_neigh[:, :, 5:, :]
rgb = rgbd[:, :, :3, :]
grid = rgbd[:, :, 3:5, :]
surf = rgbd[:, :, 5:, :]
# calc distance
rgb_dist = torch.norm(rgb - rgb_neigh, dim=-2, p=2) ** 2
grid_dist = torch.norm(grid - grid_neigh, dim=-2, p=2) ** 2
surf_dist = torch.norm(surf - surf_neigh, dim=-2, p=2) ** 2
rgb_sim = torch.exp(-rgb_dist / 2 / sigma_color**2)
grid_sim = torch.exp(-grid_dist / 2 / sigma_loc**2)
surf_sim = torch.exp(-surf_dist / 2 / sigma_surf**2)
weight = mask_neigh * rgb_sim * grid_sim * surf_sim # B H*W H_p*W_p
weight = weight / weight.sum(dim=-1, keepdim=True).clamp(min=1e-5)
z = (surf_neigh * weight.unsqueeze(-2)).sum(dim=-1)
return z.reshape(B, H, W, 3).permute(0, 3, 1, 2)
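
    # Estimate per-pixel surface normals from a point map (B x 3 x H x W).
    # For several pixel offsets, two edge vectors are formed by rolling the
    # point map vertically and horizontally; their cross product gives a
    # triangle normal. Unit normals are averaged over valid triangles,
    # re-normalised, and returned together with a validity mask.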
def get_surface_normal(self, xyz: torch.Tensor, mask: torch.Tensor):
P0 = xyz
mask = mask.float()
normals, masks_valid_triangle = [], []
combinations = list(itertools.combinations_with_replacement([-2, -1, 1, 2], 2))
combinations += [c[::-1] for c in combinations]
# combinations = [(1, 1), (-1, -1), (1, -1), (-1, 1)]
for shift_0, shift_1 in set(combinations):
P1 = torch.roll(xyz, shifts=(0, shift_0), dims=(-1, -2))
P2 = torch.roll(xyz, shifts=(shift_1, 0), dims=(-1, -2))
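            # Swap the edge vectors when the shifts have opposite signs so the
            # cross product keeps a consistent orientation across offsets.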
if (shift_0 > 0) ^ (shift_1 > 0):
P1, P2 = P2, P1
vec1, vec2 = P1 - P0, P2 - P0
normal = torch.cross(vec1, vec2, dim=1)
vec1_norm = torch.norm(vec1, dim=1, keepdim=True).clip(min=1e-8)
vec2_norm = torch.norm(vec2, dim=1, keepdim=True).clip(min=1e-8)
normal_norm = torch.norm(normal, dim=1, keepdim=True).clip(min=1e-8)
normals.append(normal / normal_norm)
is_valid = (
torch.roll(mask, shifts=(0, shift_0), dims=(-1, -2))
+ torch.roll(mask, shifts=(shift_1, 0), dims=(-1, -2))
+ mask
== 3
)
is_valid = (
(normal_norm > 1e-6)
& (vec1_norm > 1e-6)
& (vec2_norm > 1e-6)
& is_valid
)
masks_valid_triangle.append(is_valid)
normals = torch.stack(normals, dim=-1)
mask_valid_triangle = torch.stack(masks_valid_triangle, dim=-1).float()
mask_valid = mask_valid_triangle.sum(dim=-1)
normals = (normals * mask_valid_triangle).sum(dim=-1) / mask_valid.clamp(
min=1.0
)
normals_norm = torch.norm(normals, dim=1, keepdim=True).clip(min=1e-8)
normals = normals / normals_norm
mask_valid = (
(mask_valid > 0.001)
& (~normals.sum(dim=1, keepdim=True).isnan())
& (normals_norm > 1e-6)
)
return normals, mask_valid # B 3 H W, B 1 H W
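
    # Alternative estimator kept for reference: per-pixel least-squares plane
    # fit over a local patch (solves A^T A n = A^T 1 with an explicit inverse).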
# def get_surface_normal(self, xyz: torch.Tensor, mask: torch.Tensor):
# x, y, z = torch.unbind(xyz, dim=1) # B 3 H W
# x = x.unsqueeze(1) # B 1 H W
# y = y.unsqueeze(1)
# z = z.unsqueeze(1)
# mask_float = mask.float()
# paddings = [self.patch_weight.shape[-2] // 2, self.patch_weight.shape[-1] // 2]
# num_samples = F.conv2d(mask_float, weight=self.patch_weight, padding=paddings).clamp(min=1.0) # B 1 H W
# mask_invalid = num_samples < self.min_samples
# xx = x * x
# yy = y * y
# zz = z * z
# xy = x * y
# xz = x * z
# yz = y * z
# xx_patch = F.conv2d(xx * mask_float, weight=self.patch_weight, padding=paddings) / num_samples
# yy_patch = F.conv2d(yy * mask_float, weight=self.patch_weight, padding=paddings) / num_samples
# zz_patch = F.conv2d(zz * mask_float, weight=self.patch_weight, padding=paddings) / num_samples
# xy_patch = F.conv2d(xy * mask_float, weight=self.patch_weight, padding=paddings) / num_samples
# xz_patch = F.conv2d(xz * mask_float, weight=self.patch_weight, padding=paddings) / num_samples
# yz_patch = F.conv2d(yz * mask_float, weight=self.patch_weight, padding=paddings) / num_samples
# x_patch = F.conv2d(x * mask_float, weight=self.patch_weight, padding=paddings) / num_samples
# y_patch = F.conv2d(y * mask_float, weight=self.patch_weight, padding=paddings) / num_samples
# z_patch = F.conv2d(z * mask_float, weight=self.patch_weight, padding=paddings) / num_samples
# ATA = torch.stack([xx_patch, xy_patch, xz_patch, xy_patch, yy_patch, yz_patch, xz_patch, yz_patch, zz_patch], dim=-1).squeeze(1) # B H W 9
# ATA = torch.reshape(ATA, (ATA.shape[0], ATA.shape[1], ATA.shape[2], 3, 3)) # B H W 3 3
# eps_identity = torch.eye(3, device=ATA.device, dtype=ATA.dtype).unsqueeze(0) # 1 3 3
# ATA = ATA + 1e-6 * eps_identity
# AT1 = torch.stack([x_patch, y_patch, z_patch], dim=-1).squeeze(1).unsqueeze(-1) # B H W 3 1
# det = torch.linalg.det(ATA)
# mask_invalid_inverse = det.abs() < 1e-12
# mask_invalid = mask_invalid.squeeze(1) | mask_invalid_inverse
# AT1[mask_invalid, :, :] = 0
# ATA[mask_invalid, :, :] = eps_identity
# ATA_inv = torch.linalg.inv(ATA)
# normals = (ATA_inv @ AT1).squeeze(dim=-1) # B H W 3
# normals = normals / torch.norm(normals, dim=-1, keepdim=True).clip(min=1e-8)
# mask_invalid = mask_invalid | (torch.sum(normals, dim=-1) == 0.0)
# # flip normals, based if a * x + b * y + c * z < 0 -> change sign of normals
# mean_patch_xyz = AT1.squeeze(-1)
# orient_mask = torch.sum(normals * mean_patch_xyz, dim=-1).unsqueeze(-1) > 0
# normals = (2 * orient_mask.to(ATA.dtype) - 1) * normals
# return normals.permute(0, 3, 1, 2), ~mask_invalid.unsqueeze(1) # B 3 H W, B H W
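
    # Loss entry point. `input` and `target` are B x 3 x H x W point maps,
    # `mask` is a B x 1 x H x W validity mask, and `valid` is a per-sample
    # boolean selecting which batch elements to evaluate. Returns one loss
    # value per sample.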
@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def forward(self, input: torch.Tensor, target: torch.Tensor, mask, valid):
if not valid.any():
return 0.0 * input.mean(dim=(1, 2, 3))
input = input.float()
target = target.float()
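        # Shrink the valid mask so that pixels near mask boundaries (whose
        # rolled neighbours may be invalid) do not contribute to the normals.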
mask = erode(mask, kernel_size=3)
target_normal, mask_target = self.get_surface_normal(target[valid], mask[valid])
input_normal, mask_input = self.get_surface_normal(
input[valid], torch.ones_like(mask[valid])
)
gt_similarity = F.cosine_similarity(input_normal, target_normal, dim=1) # B H W
mask_target = (
mask_target.squeeze(1) & (gt_similarity < 0.999) & (gt_similarity > -0.999)
)
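        # Angular error mapped to [0, 1]; differences below a 0.01 tolerance
        # are ignored.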
error = F.relu((1 - gt_similarity) / 2 - 0.01)
error_full = torch.ones_like(mask.squeeze(1).float())
error_full[valid] = error
mask_full = torch.ones_like(mask.squeeze(1))
mask_full[valid] = mask_target
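        # Robust per-sample trimming: keep only errors below the
        # (1 - quantile) quantile, i.e. drop the hardest fraction of pixels.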
error_qtl = error_full.detach()
mask_full = mask_full & (
error_qtl
< masked_quantile(
error_qtl, mask_full, dims=[1, 2], q=1 - self.quantile
).view(-1, 1, 1)
)
loss = masked_mean(error_full, mask=mask_full, dim=(-2, -1)).squeeze(
dim=(-2, -1)
) # B
loss = self.output_fn(loss)
return loss
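
    # Von Mises-style penalty kappa * (1 - cos(n_pred, n_gt)) averaged over
    # valid pixels with near-degenerate similarities excluded. Not used by
    # forward().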
def von_mises(self, input, target, mask, kappa):
score = torch.cosine_similarity(input, target, dim=1).unsqueeze(1)
mask_cosine = torch.logical_and(
mask, torch.logical_and(score.detach() < 0.999, score.detach() > -0.999)
)
nll = masked_mean(
kappa * (1 - score), mask=mask_cosine, dim=(-1, -2, -3)
).squeeze()
return nll
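
    # Factory: construct the loss from a config dict (weight, output_fn, and
    # an optional quantile defaulting to 0.2).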
@classmethod
def build(cls, config):
obj = cls(
weight=config["weight"],
output_fn=config["output_fn"],
quantile=config.get("quantile", 0.2),
)
return obj
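

# Minimal usage sketch (illustrative only; the config keys mirror build() and
# the tensor shapes follow forward(), everything else here is assumed):
#
#   criterion = LocalNormal.build(
#       {"weight": 1.0, "output_fn": "sqrt", "quantile": 0.2}
#   )
#   # points_pred / points_gt: B x 3 x H x W, mask: B x 1 x H x W (bool),
#   # valid: B (bool). Returns one loss value per sample.
#   loss = criterion(points_pred, points_gt, mask, valid)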