"""
Author: Luigi Piccinelli

Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
"""

from typing import Tuple

import torch
from torch.nn import functional as F


def generate_rays(
    camera_intrinsics: torch.Tensor, image_shape: Tuple[int, int], noisy: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    batch_size, device, dtype = (
        camera_intrinsics.shape[0],
        camera_intrinsics.device,
        camera_intrinsics.dtype,
    )
    height, width = image_shape

    # Pixel grid at pixel centers; optionally jittered by up to half a pixel.
    pixel_coords_x = torch.linspace(0, width - 1, width, device=device, dtype=dtype)
    pixel_coords_y = torch.linspace(0, height - 1, height, device=device, dtype=dtype)
    if noisy:
        pixel_coords_x += torch.rand_like(pixel_coords_x) - 0.5
        pixel_coords_y += torch.rand_like(pixel_coords_y) - 0.5
    pixel_coords = torch.stack(
        [pixel_coords_x.repeat(height, 1), pixel_coords_y.repeat(width, 1).t()], dim=2
    )  # (H, W, 2)
    pixel_coords = pixel_coords + 0.5

    # Backproject homogeneous pixel coordinates through K^-1 and normalize
    # to unit-length ray directions.
    intrinsics_inv = torch.inverse(camera_intrinsics.float()).to(dtype)  # (B, 3, 3)
    homogeneous_coords = torch.cat(
        [pixel_coords, torch.ones_like(pixel_coords[:, :, :1])], dim=2
    )  # (H, W, 3)
    ray_directions = torch.matmul(
        intrinsics_inv, homogeneous_coords.permute(2, 0, 1).flatten(1)
    )  # (B, 3, H*W)
    ray_directions = F.normalize(ray_directions, dim=1)
    ray_directions = ray_directions.permute(0, 2, 1)  # (B, H*W, 3)

    # Azimuth around the y-axis and polar angle measured from the +y axis.
    theta = torch.atan2(ray_directions[..., 0], ray_directions[..., -1])
    phi = torch.acos(ray_directions[..., 1])

    angles = torch.stack([theta, phi], dim=-1)  # (B, H*W, 2)
    return ray_directions, angles
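

# A minimal usage sketch (illustrative values, not fixed by this module): a
# single 640x480 pinhole camera with focal length 500 and principal point at
# the image center.
#
#   K = torch.tensor([[[500.0, 0.0, 320.0],
#                      [0.0, 500.0, 240.0],
#                      [0.0, 0.0, 1.0]]])
#   rays, angles = generate_rays(K, (480, 640))
#   # rays: (1, 480*640, 3) unit directions; angles: (1, 480*640, 2)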


@torch.jit.script
def spherical_zbuffer_to_euclidean(spherical_tensor: torch.Tensor) -> torch.Tensor:
    theta = spherical_tensor[..., 0]
    phi = spherical_tensor[..., 1]
    z = spherical_tensor[..., 2]

    # With x = r * sin(phi) * sin(theta)
    #      y = r * cos(phi)
    #      z = r * sin(phi) * cos(theta)
    # eliminating r = z / (sin(phi) * cos(theta)) gives:
    x = z * torch.tan(theta)
    y = z / torch.tan(phi) / torch.cos(theta)

    euclidean_tensor = torch.stack((x, y, z), dim=-1)
    return euclidean_tensor
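

# Sketch: the z component is passed through unchanged, so for any
# (theta, phi, z) triple the z-buffer depth is preserved.
#
#   sph = torch.tensor([[0.3, 1.2, 1.5]])  # (theta, phi, z)
#   xyz = spherical_zbuffer_to_euclidean(sph)
#   # xyz[..., 2] == sph[..., 2]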


@torch.jit.script
def spherical_to_euclidean(spherical_tensor: torch.Tensor) -> torch.Tensor:
    theta = spherical_tensor[..., 0]  # azimuth
    phi = spherical_tensor[..., 1]  # polar angle from the +y axis
    r = spherical_tensor[..., 2]  # radius

    x = r * torch.sin(phi) * torch.sin(theta)
    y = r * torch.cos(phi)
    z = r * torch.cos(theta) * torch.sin(phi)

    euclidean_tensor = torch.stack((x, y, z), dim=-1)
    return euclidean_tensor


@torch.jit.script
def euclidean_to_spherical(euclidean_tensor: torch.Tensor) -> torch.Tensor:
    x = euclidean_tensor[..., 0]
    y = euclidean_tensor[..., 1]
    z = euclidean_tensor[..., 2]

    r = torch.sqrt(x**2 + y**2 + z**2)
    theta = torch.atan2(x / r, z / r)
    phi = torch.acos(y / r)

    spherical_tensor = torch.stack((theta, phi, r), dim=-1)
    return spherical_tensor
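

# Round-trip sketch: euclidean_to_spherical inverts spherical_to_euclidean
# (illustrative values).
#
#   xyz = torch.tensor([[0.2, -0.1, 0.9]])
#   sph = euclidean_to_spherical(xyz)  # (theta, phi, r)
#   # spherical_to_euclidean(sph) recovers xyz up to float precision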


@torch.jit.script
def euclidean_to_spherical_zbuffer(euclidean_tensor: torch.Tensor) -> torch.Tensor:
    # Assumes unit-norm directions: asin is only valid for |y| <= 1. Note the
    # output layout is (pitch, yaw, z), which differs from the (theta, phi, z)
    # layout consumed by spherical_zbuffer_to_euclidean.
    pitch = torch.asin(euclidean_tensor[..., 1])
    yaw = torch.atan2(euclidean_tensor[..., 0], euclidean_tensor[..., -1])
    z = euclidean_tensor[..., 2]
    spherical_tensor = torch.stack((pitch, yaw, z), dim=-1)
    return spherical_tensor
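

# Sketch: normalize first so the asin above is well defined (assumption:
# inputs are ray directions, not arbitrary points).
#
#   d = F.normalize(torch.tensor([[0.2, -0.1, 0.9]]), dim=-1)
#   pyz = euclidean_to_spherical_zbuffer(d)  # (pitch, yaw, z)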


@torch.jit.script
def unproject_points(
    depth: torch.Tensor, camera_intrinsics: torch.Tensor
) -> torch.Tensor:
    """
    Unprojects a batch of depth maps to 3D point clouds using camera intrinsics.

    Args:
        depth (torch.Tensor): Batch of depth maps of shape (B, 1, H, W).
        camera_intrinsics (torch.Tensor): Camera intrinsic matrix of shape (B, 3, 3).

    Returns:
        torch.Tensor: Batch of 3D point clouds of shape (B, 3, H, W).
    """
    batch_size, _, height, width = depth.shape
    device = depth.device

    # Create the pixel grid
    y_coords, x_coords = torch.meshgrid(
        torch.arange(height, device=device),
        torch.arange(width, device=device),
        indexing="ij",
    )
    pixel_coords = torch.stack((x_coords, y_coords), dim=-1)  # (H, W, 2)

    # Homogeneous pixel coordinates, flattened to (3, H*W)
    pixel_coords_homogeneous = torch.cat(
        (pixel_coords, torch.ones((height, width, 1), device=device)), dim=-1
    )  # (H, W, 3)
    pixel_coords_homogeneous = pixel_coords_homogeneous.permute(2, 0, 1).flatten(1)

    # Backproject through K^-1, then scale each ray by its depth
    unprojected_points = torch.matmul(
        torch.inverse(camera_intrinsics), pixel_coords_homogeneous
    )  # (B, 3, H*W)
    unprojected_points = unprojected_points.view(batch_size, 3, height, width)
    unprojected_points = unprojected_points * depth  # (B, 3, H, W)
    return unprojected_points
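

# Sketch: with identity intrinsics the z channel of the unprojected points
# equals the input depth (illustrative check).
#
#   K = torch.eye(3).unsqueeze(0)
#   depth = torch.ones(1, 1, 4, 4)
#   pts = unproject_points(depth, K)  # (1, 3, 4, 4); pts[:, 2] == depth[:, 0]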


@torch.jit.script
def project_points(
    points_3d: torch.Tensor,
    intrinsic_matrix: torch.Tensor,
    image_shape: Tuple[int, int],
) -> torch.Tensor:
    # Project onto the image plane: (B, N, 3) x (B, 3, 3)^T -> (B, N, 3)
    points_2d = torch.matmul(points_3d, intrinsic_matrix.transpose(1, 2))

    # Perspective division, then snap to the nearest pixel
    points_2d = points_2d[..., :2] / points_2d[..., 2:]
    points_2d = points_2d.round()

    # Keep only points that land inside the image bounds
    valid_mask = (
        (points_2d[..., 0] >= 0)
        & (points_2d[..., 0] < image_shape[1])
        & (points_2d[..., 1] >= 0)
        & (points_2d[..., 1] < image_shape[0])
    )

    # Flatten (x, y) pixel coordinates into linear indices for scatter_add_
    flat_points_2d = points_2d[..., 0] + points_2d[..., 1] * image_shape[1]
    flat_indices = flat_points_2d.long()

    depth_maps = torch.zeros(
        [points_3d.shape[0], *image_shape], device=points_3d.device
    )
    counts = torch.zeros([points_3d.shape[0], *image_shape], device=points_3d.device)

    # Accumulate depths and hit counts per pixel, then average: pixels hit by
    # several points store their mean depth, untouched pixels stay zero
    for i in range(points_3d.shape[0]):
        valid_indices = flat_indices[i, valid_mask[i]]
        depth_maps[i].view(-1).scatter_add_(
            0, valid_indices, points_3d[i, valid_mask[i], 2]
        )
        counts[i].view(-1).scatter_add_(
            0, valid_indices, torch.ones_like(points_3d[i, valid_mask[i], 2])
        )

    mean_depth_maps = depth_maps / counts.clamp(min=1.0)
    return mean_depth_maps.reshape(-1, 1, *image_shape)
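

# Round-trip sketch (illustrative values): unproject a depth map, reshape to
# (B, N, 3), and project it back; each pixel recovers its depth.
#
#   depth = torch.rand(1, 1, 8, 8) + 0.5
#   K = torch.tensor([[[4.0, 0.0, 4.0], [0.0, 4.0, 4.0], [0.0, 0.0, 1.0]]])
#   pts = unproject_points(depth, K).flatten(2).permute(0, 2, 1)  # (1, 64, 3)
#   reproj = project_points(pts, K, (8, 8))  # (1, 1, 8, 8), ~equal to depth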


@torch.jit.script
def downsample(data: torch.Tensor, downsample_factor: int = 2):
    N, _, H, W = data.shape

    # Rearrange each (factor x factor) block into the last dimension
    data = data.view(
        N,
        H // downsample_factor,
        downsample_factor,
        W // downsample_factor,
        downsample_factor,
        1,
    )
    data = data.permute(0, 1, 3, 5, 2, 4).contiguous()
    data = data.view(-1, downsample_factor * downsample_factor)

    # Min-pool per block while ignoring zeros (treated as missing depth):
    # zeros are swapped for a large sentinel before the min, and blocks with
    # no valid value are mapped back to zero afterwards
    data_tmp = torch.where(data == 0.0, 1e5 * torch.ones_like(data), data)
    data = torch.min(data_tmp, dim=-1).values
    data = data.view(N, 1, H // downsample_factor, W // downsample_factor)
    data = torch.where(data > 1000, torch.zeros_like(data), data)
    return data
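

# Sketch: 2x2 block min-pooling on a single-channel depth map, with zeros
# treated as invalid rather than as the minimum.
#
#   d = torch.tensor([[[[0.0, 2.0], [3.0, 4.0]]]])  # (1, 1, 2, 2)
#   downsample(d)  # -> tensor([[[[2.0]]]]); the zero is skipped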


@torch.jit.script
def flat_interpolate(
    flat_tensor: torch.Tensor,
    old: Tuple[int, int],
    new: Tuple[int, int],
    antialias: bool = True,
    mode: str = "bilinear",
) -> torch.Tensor:
    if old[0] == new[0] and old[1] == new[1]:
        return flat_tensor

    # (B, old_h*old_w, C) -> (B, C, old_h, old_w) for spatial interpolation
    tensor = flat_tensor.view(flat_tensor.shape[0], old[0], old[1], -1).permute(
        0, 3, 1, 2
    )
    tensor_interp = F.interpolate(
        tensor,
        size=(new[0], new[1]),
        mode=mode,
        align_corners=False,
        antialias=antialias,
    )
    # Back to the flat layout: (B, new_h*new_w, C)
    flat_tensor_interp = tensor_interp.view(
        flat_tensor.shape[0], -1, new[0] * new[1]
    ).permute(0, 2, 1)
    return flat_tensor_interp.contiguous()
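

# Sketch: resample a flattened token grid, e.g. ViT patch tokens
# (illustrative shapes).
#
#   tokens = torch.randn(2, 16 * 16, 64)
#   resized = flat_interpolate(tokens, old=(16, 16), new=(8, 8))
#   # resized: (2, 8*8, 64)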