# UniK3D-demo: unik3d/utils/geometric.py
# Luigi Piccinelli
# init demo
# 1ea89dd
from typing import Tuple
import torch
from torch.nn import functional as F
def generate_rays(
    camera_intrinsics: torch.Tensor, image_shape: Tuple[int, int], noisy: bool = False
):
    """Compute unit-norm camera rays and their spherical angles for every pixel.

    Rays are cast through pixel centres (integer pixel index + 0.5). Pixels are
    enumerated row-major, so ray ``k`` belongs to pixel ``(k // W, k % W)``.

    Args:
        camera_intrinsics: Batched intrinsics of shape (B, 3, 3); each matrix
            is inverted with ``torch.inverse``.
        image_shape: ``(H, W)`` of the image.
        noisy: If True, each column x-coordinate and each row y-coordinate is
            shifted by its own uniform offset in [-0.5, 0.5). The offset is
            shared along the whole column / row.

    Returns:
        ray_directions: (B, H*W, 3) unit vectors in ``camera_intrinsics.dtype``.
        angles: (B, H*W, 2) holding ``theta = atan2(x, z)`` and
            ``phi = acos(y)`` for each ray, in the same dtype.
    """
    height, width = image_shape
    dtype = camera_intrinsics.dtype
    device = camera_intrinsics.device
    # Do the geometry in at least float32: inversion, normalization and acos
    # are numerically fragile in half precision. Results are cast back at the end.
    compute_dtype = torch.promote_types(dtype, torch.float32)

    xs = torch.linspace(0, width - 1, width, device=device, dtype=compute_dtype)
    ys = torch.linspace(0, height - 1, height, device=device, dtype=compute_dtype)
    if noisy:
        xs = xs + (torch.rand_like(xs) - 0.5)
        ys = ys + (torch.rand_like(ys) - 0.5)

    grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")  # (H, W) each
    # Move from integer pixel indices to pixel centres.
    pixel_coords = torch.stack([grid_x, grid_y], dim=-1) + 0.5  # (H, W, 2)
    homogeneous_coords = torch.cat(
        [pixel_coords, torch.ones_like(pixel_coords[..., :1])], dim=-1
    )  # (H, W, 3)

    intrinsics_inv = torch.inverse(camera_intrinsics.to(compute_dtype))  # (B, 3, 3)
    # (B, 3, 3) @ (3, H*W) -> (B, 3, H*W)
    ray_directions = torch.matmul(
        intrinsics_inv, homogeneous_coords.permute(2, 0, 1).flatten(1)
    )
    ray_directions = F.normalize(ray_directions, dim=1)  # (B, 3, H*W)
    ray_directions = ray_directions.permute(0, 2, 1)  # (B, H*W, 3)

    theta = torch.atan2(ray_directions[..., 0], ray_directions[..., 2])
    # Clamp so that rounding can never push |y| past 1 and make acos return NaN.
    phi = torch.acos(ray_directions[..., 1].clamp(-1.0, 1.0))
    angles = torch.stack([theta, phi], dim=-1)
    return ray_directions.to(dtype), angles.to(dtype)
@torch.jit.script
def spherical_zbuffer_to_euclidean(spherical_tensor: torch.Tensor) -> torch.Tensor:
    """Convert ``(theta, phi, z)`` triplets to Cartesian ``(x, y, z)``.

    Convention (same as the angles produced by ``generate_rays``):
        x = r * sin(phi) * sin(theta)
        y = r * cos(phi)
        z = r * sin(phi) * cos(theta)
    so ``theta`` is the azimuth in the x-z plane and ``phi`` is the angle
    measured from the +y axis. The third channel is the z coordinate itself
    (z-buffer depth), not the range r.

    Eliminating r gives:
        x = z * tan(theta)
        y = z / tan(phi) / cos(theta)

    Args:
        spherical_tensor: Tensor whose last dimension is ``(theta, phi, z)``.

    Returns:
        Tensor of the same shape whose last dimension is ``(x, y, z)``.
    """
    azimuth = spherical_tensor[..., 0]
    polar = spherical_tensor[..., 1]
    depth = spherical_tensor[..., 2]
    x = depth * torch.tan(azimuth)
    y = depth / torch.tan(polar) / torch.cos(azimuth)
    return torch.stack((x, y, depth), dim=-1)
@torch.jit.script
def spherical_to_euclidean(spherical_tensor: torch.Tensor) -> torch.Tensor:
    """Convert ``(theta, phi, r)`` triplets to Cartesian ``(x, y, z)``.

    Uses the "physics" convention:
        x = r * sin(theta) * cos(phi)
        y = r * sin(theta) * sin(phi)
        z = r * cos(theta)
    i.e. ``theta`` is measured from the +z axis and ``phi`` is the azimuth in
    the x-y plane.

    NOTE(review): this is *not* the inverse of ``euclidean_to_spherical`` in
    this module, which returns ``theta = atan2(x, z)`` and ``phi = acos(y / r)``.
    Confirm which convention each caller expects.

    Args:
        spherical_tensor: Tensor whose last dimension is ``(theta, phi, r)``.

    Returns:
        Tensor of the same shape whose last dimension is ``(x, y, z)``.
    """
    polar = spherical_tensor[..., 0]
    azimuth = spherical_tensor[..., 1]
    radius = spherical_tensor[..., 2]
    # Length of the projection onto the x-y plane.
    planar = radius * torch.sin(polar)
    x = planar * torch.cos(azimuth)
    y = planar * torch.sin(azimuth)
    z = radius * torch.cos(polar)
    return torch.stack((x, y, z), dim=-1)
@torch.jit.script
def euclidean_to_spherical(spherical_tensor: torch.Tensor) -> torch.Tensor:
    """Convert Cartesian ``(x, y, z)`` points to ``(theta, phi, r)``.

    NOTE: despite its name, ``spherical_tensor`` is the *Cartesian* input. The
    parameter name is kept for backward compatibility with keyword callers.

    Convention:
        r     = sqrt(x^2 + y^2 + z^2)
        theta = atan2(x, z)   (azimuth in the x-z plane)
        phi   = acos(y / r)   (angle from the +y axis)

    The origin (r == 0) maps to ``(0, pi/2, 0)`` instead of NaN.

    Args:
        spherical_tensor: Tensor whose last dimension is ``(x, y, z)``.

    Returns:
        Tensor of the same shape whose last dimension is ``(theta, phi, r)``.
    """
    x = spherical_tensor[..., 0]
    y = spherical_tensor[..., 1]
    z = spherical_tensor[..., 2]
    r = torch.sqrt(x**2 + y**2 + z**2)
    # atan2 is invariant to positive scaling, so dividing by r is unnecessary;
    # doing so produced NaN (0 / 0) at the origin.
    theta = torch.atan2(x, z)
    # Guard the division so a zero-length vector does not yield NaN.
    phi = torch.acos(y / r.clamp(min=1e-12))
    return torch.stack((theta, phi, r), dim=-1)
@torch.jit.script
def euclidean_to_spherical_zbuffer(euclidean_tensor: torch.Tensor) -> torch.Tensor:
    """Map ``(x, y, z)`` to ``(pitch, yaw, z)``.

    pitch = asin(y), yaw = atan2(x, last channel), and the third output is the
    input's channel 2 passed through unchanged.

    NOTE(review): ``asin(y)`` only makes sense for a unit-norm direction
    (|y| > 1 gives NaN), while channel 2 is carried through as a "z-buffer
    depth". Confirm whether callers pass unit rays or metric points.
    NOTE(review): the output order ``(pitch, yaw, z)`` differs from the
    ``(theta, phi, z)`` layout read by ``spherical_zbuffer_to_euclidean``.

    Args:
        euclidean_tensor: Tensor whose last dimension holds ``(x, y, z)``.

    Returns:
        Tensor with last dimension ``(pitch, yaw, z)``.
    """
    x = euclidean_tensor[..., 0]
    y = euclidean_tensor[..., 1]
    last = euclidean_tensor[..., -1]
    depth = euclidean_tensor[..., 2]
    yaw = torch.atan2(x, last)
    pitch = torch.asin(y)
    return torch.stack((pitch, yaw, depth), dim=-1)
@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def unproject_points(
    depth: torch.Tensor, camera_intrinsics: torch.Tensor
) -> torch.Tensor:
    """
    Unprojects a batch of depth maps to 3D point clouds using camera intrinsics.

    Each pixel at integer index (u, v) (no half-pixel offset) is lifted to
    ``K^-1 @ (u, v, 1)`` and then scaled by its depth value.

    Args:
        depth (torch.Tensor): Batch of depth maps of shape (B, 1, H, W).
        camera_intrinsics (torch.Tensor): Camera intrinsic matrix of shape (B, 3, 3).
    Returns:
        torch.Tensor: Batch of 3D point clouds of shape (B, 3, H, W).
    """
    batch_size, _, height, width = depth.shape
    device = depth.device
    # Pixel grid as float32 so it can be multiplied with the float32 inverse.
    y_coords, x_coords = torch.meshgrid(
        torch.arange(height, device=device, dtype=torch.float32),
        torch.arange(width, device=device, dtype=torch.float32),
        indexing="ij",
    )
    # Homogeneous (u, v, 1), flattened row-major: (3, H*W)
    pixel_coords_homogeneous = torch.stack(
        (x_coords, y_coords, torch.ones_like(x_coords)), dim=0
    ).flatten(1)
    # Invert the full matrix. The principal point (and any skew) must be
    # inverted as well: x = (u - cx) / fx, not u / fx + cx.
    camera_intrinsics_inv = torch.inverse(camera_intrinsics.float())
    # [B, 3, 3] @ [3, H*W] -> [B, 3, H*W]
    unprojected_points = camera_intrinsics_inv @ pixel_coords_homogeneous
    unprojected_points = unprojected_points.view(
        batch_size, 3, height, width
    )  # (B, 3, H, W)
    return unprojected_points * depth  # (B, 3, H, W)
@torch.jit.script
def project_points(
    points_3d: torch.Tensor,
    intrinsic_matrix: torch.Tensor,
    image_shape: Tuple[int, int],
) -> torch.Tensor:
    """Splat 3D points into per-pixel mean depth maps.

    Points are projected with ``(u, v, w) = (x, y, z) @ K^T``, divided by w and
    rounded to the nearest pixel. Points that land inside the image have their
    z value averaged per pixel; pixels that receive no point stay 0.

    NOTE(review): points behind the camera (z < 0) are not filtered out and can
    land inside the image after the perspective divide. Confirm callers only
    pass points in front of the camera.

    Args:
        points_3d: (B, N, 3) points in camera coordinates.
        intrinsic_matrix: (B, 3, 3) intrinsics.
        image_shape: ``(H, W)`` of the output maps.

    Returns:
        (B, 1, H, W) mean depth maps in ``points_3d.dtype``.
    """
    batch_size = points_3d.shape[0]
    height = image_shape[0]
    width = image_shape[1]

    # Project onto the image plane: (u v w) = (x y z) @ K^T
    points_2d = torch.matmul(points_3d, intrinsic_matrix.transpose(1, 2))
    # Perspective divide: (u / w, v / w)
    points_2d = points_2d[..., :2] / points_2d[..., 2:]
    # Snap to pixels. Keep it a float tensor (no int cast) so autograd is not broken.
    points_2d = points_2d.round()

    # Points must fall inside the image; NaNs compare False and are dropped too.
    valid_mask = (
        (points_2d[..., 0] >= 0)
        & (points_2d[..., 0] < width)
        & (points_2d[..., 1] >= 0)
        & (points_2d[..., 1] < height)
    )
    # Row-major flat pixel index (only meaningful where valid_mask is True).
    flat_indices = (points_2d[..., 0] + points_2d[..., 1] * width).long()

    # Accumulators share the points' dtype: scatter_add_ requires src and self
    # to have the same dtype, so a hard-coded float32 buffer broke float64/half.
    depth_maps = torch.zeros(
        [batch_size, height, width], device=points_3d.device, dtype=points_3d.dtype
    )
    counts = torch.zeros_like(depth_maps)

    for i in range(batch_size):
        valid_indices = flat_indices[i, valid_mask[i]]
        valid_depths = points_3d[i, valid_mask[i], 2]
        depth_maps[i].view(-1).scatter_add_(0, valid_indices, valid_depths)
        counts[i].view(-1).scatter_add_(0, valid_indices, torch.ones_like(valid_depths))

    # Mean depth per pixel; clamp avoids 0 / 0 for empty pixels.
    mean_depth_maps = depth_maps / counts.clamp(min=1.0)
    return mean_depth_maps.reshape(-1, 1, height, width)  # (B, 1, H, W)
@torch.jit.script
def downsample(data: torch.Tensor, downsample_factor: int = 2):
    """Min-pool a single-channel map over non-overlapping windows, ignoring zeros.

    Zeros are treated as missing. Each ``f x f`` window is reduced to its
    smallest non-zero value. A window with no non-zero value becomes 0, and so
    does any pooled value greater than 1000.

    The ``view`` below requires a single channel and H, W divisible by
    ``downsample_factor``.

    Args:
        data: (N, 1, H, W) tensor.
        downsample_factor: Window size f.

    Returns:
        (N, 1, H // f, W // f) tensor.
    """
    factor = downsample_factor
    n, _, h, w = data.shape
    out_h = h // factor
    out_w = w // factor
    # (N, 1, H, W) -> (N, out_h, f, out_w, f, 1) -> (N, out_h, out_w, 1, f, f)
    windows = data.view(n, out_h, factor, out_w, factor, 1).permute(0, 1, 3, 5, 2, 4)
    # One row per window.
    windows = windows.contiguous().view(-1, factor * factor)
    # Large sentinel in place of zeros so they never win the min.
    windows = torch.where(windows == 0.0, 1e5 * torch.ones_like(windows), windows)
    pooled = torch.min(windows, dim=-1).values
    pooled = pooled.view(n, 1, out_h, out_w)
    # All-empty windows come back as the sentinel; everything above 1000 is reset to 0.
    return torch.where(pooled > 1000, torch.zeros_like(pooled), pooled)
@torch.jit.script
def flat_interpolate(
    flat_tensor: torch.Tensor,
    old: Tuple[int, int],
    new: Tuple[int, int],
    antialias: bool = False,
    mode: str = "bilinear",
) -> torch.Tensor:
    """Resize a flattened (B, H*W, C) feature map from ``old`` to ``new`` size.

    The tokens are reshaped to (B, C, H, W), resized with ``F.interpolate`` and
    flattened back. If the sizes already match, the input is returned as-is.

    ``align_corners=False`` is only passed for interpolating modes: PyTorch
    raises if ``align_corners`` is set for "nearest", "nearest-exact" or "area".

    Args:
        flat_tensor: (B, old_h * old_w, C) tensor.
        old: ``(old_h, old_w)`` spatial size of ``flat_tensor``.
        new: ``(new_h, new_w)`` target spatial size.
        antialias: Forwarded to ``F.interpolate``.
        mode: Interpolation mode forwarded to ``F.interpolate``.

    Returns:
        Contiguous (B, new_h * new_w, C) tensor.
    """
    if old[0] == new[0] and old[1] == new[1]:
        return flat_tensor
    batch = flat_tensor.shape[0]
    tensor = flat_tensor.view(batch, old[0], old[1], -1).permute(0, 3, 1, 2)  # b c h w
    if mode == "nearest" or mode == "nearest-exact" or mode == "area":
        tensor_interp = F.interpolate(
            tensor,
            size=(new[0], new[1]),
            mode=mode,
            antialias=antialias,
        )
    else:
        tensor_interp = F.interpolate(
            tensor,
            size=(new[0], new[1]),
            mode=mode,
            align_corners=False,
            antialias=antialias,
        )
    flat_tensor_interp = tensor_interp.view(batch, -1, new[0] * new[1]).permute(
        0, 2, 1
    )  # b (h w) c
    return flat_tensor_interp.contiguous()
# # @torch.jit.script
# def displacement_relative_neighbour(gt: torch.Tensor, mask: torch.Tensor = None, kernel_size: int = 7, ndim: int =4):
# pad = kernel_size // 2
# n_neighbours = int(kernel_size**2)
# # when torchscipt will support nested generators in listcomp or usage of range
# # in product(range_, range_), then use listcomp, so far speedup ~5% wrt std python
# if mask is None:
# mask = torch.ones_like(gt).bool()
# lst_gts, lst_masks = [], []
# for i in range(-kernel_size//2 + 1, kernel_size//2 + 1):
# for j in range(-kernel_size//2 + 1, kernel_size//2 + 1):
# if i != 0 or j != 0:
# lst_gts.append(torch.roll(gt, shifts=(i, j), dims=(-2, -1)))
# lst_masks.append(torch.roll(F.pad(mask, (pad,) * 4), shifts=(i, j), dims=(-2, -1)))
# gts = torch.cat(lst_gts, dim=-3)
# masks = torch.cat(lst_masks, dim=-3)
# masks = masks[..., pad:-pad, pad:-pad]
# masks[~mask.repeat(*(1,) * (ndim - 3), n_neighbours-1, 1, 1,)] = False # No displacement known if seed is missing
# log_gts = gts.clamp(min=1e-6).log() - gt.repeat(*(1,) * (ndim - 3), n_neighbours-1, 1, 1).clamp(min=1e-6).log()
# return log_gts, masks
# @torch.jit.script
# def antidisplacement_relative_neighbour(preds: torch.Tensor, kernel_size: int = 7):
# lst_preds, lst_masks = [], []
# cnt = 0
# pad = kernel_size // 2
# mask = F.pad(torch.ones((preds.shape[0], 1, preds.shape[-2], preds.shape[-1]), device=preds.device), (pad,) * 4)
# for i in range(-kernel_size//2 + 1, kernel_size//2 + 1):
# for j in range(-kernel_size//2 + 1, kernel_size//2 + 1):
# if i != 0 or j !=0:
# lst_preds.append(torch.roll(preds[:, cnt], shifts=(-i, -j), dims=(-2, -1)))
# lst_masks.append(torch.roll(mask, shifts=(-i, -j), dims=(-2, -1)))
# cnt += 1
# preds_ensamble = torch.stack(lst_preds, dim=1)
# masks = torch.cat(lst_masks, dim=1)
# masks = masks[..., pad:-pad, pad:-pad]
# return preds_ensamble, masks
# def unproject(uv, fx, fy, cx, cy, xi=0, alpha=0):
# u, v = uv.unbind(dim=1)
# mx = (u - cx) / fx
# my = (v - cy) / fy
# r_square = mx ** 2 + my ** 2
# root = 1 - (2 * alpha - 1) * r_square
# valid_mask = root >= 0
# root[~valid_mask] = 0.0
# mz = (1 - (alpha ** 2) * r_square) / (alpha * torch.sqrt(root) + (1 - alpha))
# coeff = (mz * xi + torch.sqrt(mz ** 2 + (1 - xi ** 2) * r_square)) / (mz ** 2 + r_square)
# x = coeff * mx
# y = coeff * my
# z = coeff * mz - xi
# # z = z.clamp(min=1e-7)
# x_norm = x / z
# y_norm = y / z
# z_norm = z / z
# xnorm = torch.stack(( x_norm, y_norm, z_norm ), dim=1)
# # print("unproj", xnorm.shape, xnorm[:, -1].mean())
# return xnorm, valid_mask.unsqueeze(1).repeat(1, 3, 1, 1)
# def project(point3D, fx, fy, cx, cy, xi=0, alpha=0):
# B, C, H, W = point3D.shape
# x, y, z = point3D.unbind(dim=1)
# z = z.clamp(min=1e-7)
# d_1 = torch.sqrt( x ** 2 + y ** 2 + z ** 2 )
# d_2 = torch.sqrt( x ** 2 + y ** 2 + (xi * d_1 + z) ** 2 )
# div = alpha * d_2 + (1 - alpha) * (xi * d_1 + z)
# Xnorm = fx * x / div + cx
# Ynorm = fy * y / div + cy
# coords = torch.stack([Xnorm, Ynorm], dim=1)
# w1 = torch.where(alpha <= 0.5, alpha / (1 - alpha), (1 - alpha) / alpha)
# w2 = w1 + xi / ((2 * w1 * xi + xi ** 2 + 1) ** 0.5)
# valid_mask = z > - w2 * d_1
# # Return pixel coordinates
# return coords, valid_mask.unsqueeze(1).repeat(1, 2, 1, 1)
@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def unproject(uv, fx, fy, cx, cy, alpha=None, beta=None):
    """Lift pixel coordinates to unit rays with an (alpha, beta) camera model.

    The formulas look like the Enhanced Unified Camera Model (EUCM) unprojection
    (NOTE(review): confirm). With ``alpha = 0`` and ``beta = 1`` (the defaults)
    they reduce to a normalized pinhole ray. All math runs in float32.

    Args:
        uv: Tensor with (u, v) stacked along dim 1, e.g. (B, 2, H, W).
        fx, fy, cx, cy: Intrinsics, broadcastable against ``uv[:, 0]``.
        alpha, beta: Optional distortion parameters with the same broadcasting.
            Defaults are 0 and 1.

    Returns:
        rays: Tensor with (x, y, z) stacked on dim 1. z is clamped to >= 1e-3.
        valid: Boolean mask with a singleton dim 1. It is True where r^2 lies
            inside the model's domain and z > 1e-3.
    """
    u, v = uv.float().unbind(dim=1)
    fx, fy, cx, cy = fx.float(), fy.float(), cx.float(), cy.float()
    if alpha is None:
        alpha = torch.zeros_like(fx)
    else:
        alpha = alpha.float()
    if beta is None:
        beta = torch.ones_like(fx)
    else:
        beta = beta.float()

    mx = (u - cx) / fx
    my = (v - cy) / fy
    r_square = mx**2 + my**2
    # For alpha > 0.5 the square root below goes negative beyond this radius.
    r_limit = torch.where(alpha < 0.5, 1e6, 1 / (beta * (2 * alpha - 1)))
    radicand = 1 - (2 * alpha - 1) * beta * r_square
    denominator = alpha * torch.sqrt(radicand.clip(min=1e-5)) + (1 - alpha)
    mz = (1 - beta * (alpha**2) * r_square) / denominator

    inv_norm = 1 / torch.sqrt(mx**2 + my**2 + mz**2 + 1e-5)
    x = inv_norm * mx
    y = inv_norm * my
    z = inv_norm * mz
    in_domain = (r_square < r_limit) & (z > 1e-3)
    rays = torch.stack((x, y, z.clamp(1e-3)), dim=1)
    return rays, in_domain.unsqueeze(1)
@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def project(point3D, fx, fy, cx, cy, alpha=None, beta=None):
    """Project 3D points to pixel coordinates with an (alpha, beta) camera model.

    With ``alpha = 0`` and ``beta = 1`` (the defaults) this is a pinhole
    projection. The denominator is clipped to >= 1e-3 to avoid division by zero.

    Args:
        point3D: Tensor with (x, y, z) stacked on dim 1, e.g. (B, 3, H, W).
        fx, fy, cx, cy: Intrinsics, broadcastable against ``point3D[:, 0]``.
        alpha, beta: Optional distortion parameters. Defaults are 0 and 1.

    Returns:
        coords: Tensor with (u, v) stacked on dim 1.
        valid: Boolean mask with a singleton dim 1. It is False where u is
            outside [0, W], v is outside [0, H] (H, W taken from the last two
            dims of ``point3D``), or z < 0.
    """
    H, W = point3D.shape[-2:]
    if alpha is None:
        alpha = torch.zeros_like(fx)
    if beta is None:
        beta = torch.ones_like(fx)
    x, y, z = point3D.unbind(dim=1)
    d = torch.sqrt(beta * (x**2 + y**2) + z**2)
    # Shared denominator for both image axes.
    denom = (alpha * d + (1 - alpha) * z).clip(min=1e-3)
    u = fx * (x / denom) + cx
    v = fy * (y / denom) + cy
    coords = torch.stack([u, v], dim=1)
    # Build the *invalid* mask and negate it, so NaN coordinates stay "valid"
    # exactly as before.
    out_of_bounds = (u < 0) | (u > W) | (v < 0) | (v > H)
    invalid = out_of_bounds | (z < 0)
    return coords, (~invalid).unsqueeze(1)
def rays2angles(rays: torch.Tensor) -> torch.Tensor:
    """Convert unit rays (..., 3) to angles (..., 2).

    Returns ``theta = atan2(x, last component)`` and ``phi = acos(y)``. The
    input is assumed to be unit-norm (acos yields NaN for |y| > 1).
    """
    azimuth = torch.atan2(rays[..., 0], rays[..., -1])
    polar = torch.acos(rays[..., 1])
    return torch.stack([azimuth, polar], dim=-1)
@torch.jit.script
def dilate(image, kernel_size: int | tuple[int, int]):
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
device, dtype = image.device, image.dtype
padding = (kernel_size[0] // 2, kernel_size[1] // 2)
kernel = torch.ones((1, 1, *kernel_size), dtype=torch.float32, device=image.device)
dilated_image = F.conv2d(image.float(), kernel, padding=padding, stride=1)
dilated_image = torch.where(
dilated_image > 0,
torch.tensor(1.0, device=device),
torch.tensor(0.0, device=device),
)
return dilated_image.to(dtype)
@torch.jit.script
def erode(image, kernel_size: int | tuple[int, int]):
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
device, dtype = image.device, image.dtype
padding = (kernel_size[0] // 2, kernel_size[1] // 2)
kernel = torch.ones((1, 1, *kernel_size), dtype=torch.float32, device=image.device)
eroded_image = F.conv2d(image.float(), kernel, padding=padding, stride=1)
eroded_image = torch.where(
eroded_image == (kernel_size[0] * kernel_size[1]),
torch.tensor(1.0, device=device),
torch.tensor(0.0, device=device),
)
return eroded_image.to(dtype)
@torch.jit.script
def iou(mask1: torch.Tensor, mask2: torch.Tensor) -> torch.Tensor:
    """Intersection-over-union of two masks.

    Both inputs are cast to bool (any non-zero value counts as set). The union
    is clipped to at least 1, so two empty masks give 0 rather than NaN.

    Returns:
        Scalar float32 tensor.
    """
    a = mask1.to(torch.bool)
    b = mask2.to(torch.bool)
    overlap = (a & b).sum().to(torch.float32)
    combined = (a | b).sum().to(torch.float32)
    return overlap / combined.clip(min=1.0)
if __name__ == "__main__":
    # Smoke test: a 1x1x5x5 all-True mask. Dilation keeps every pixel set.
    # Erosion clears the border, because conv2d pads with zeros.
    kernel_size = 3
    image = torch.ones((1, 1, 5, 5), dtype=torch.bool)
    print("testing dilate and erode, with image:\n", image, image.shape)

    dilated_image = dilate(image, kernel_size)
    print("Dilated Image:\n", dilated_image)

    eroded_image = erode(image, kernel_size)
    print("Eroded Image:\n", eroded_image)