# Spaces: Running on Zero
"""
Author: Luigi Piccinelli
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
"""
from copy import deepcopy

import numpy as np
import torch
import torch.nn.functional as F

from .coordinate import coords_grid
from .misc import recursive_apply, squeeze_list
def invert_pinhole(K):
    """Return the closed-form inverse of a pinhole intrinsics matrix.

    Only ``K[..., 0, 0]`` (fx), ``K[..., 1, 1]`` (fy), ``K[..., 0, 2]`` (cx)
    and ``K[..., 1, 2]`` (cy) are read; every other entry of ``K`` (for
    example a skew term) is ignored, and the last row of the result is
    ``[0, 0, 1]``.

    Args:
        K: tensor whose last two dimensions are 3x3; any leading batch
            dimensions are preserved.

    Returns:
        Tensor with the same shape, dtype and device as ``K``.
    """
    fx = K[..., 0, 0]
    fy = K[..., 1, 1]
    cx = K[..., 0, 2]
    cy = K[..., 1, 2]

    K_inv = torch.zeros_like(K)
    K_inv[..., 0, 0] = 1.0 / fx
    K_inv[..., 1, 1] = 1.0 / fy
    K_inv[..., 0, 2] = -cx / fx
    K_inv[..., 1, 2] = -cy / fy
    K_inv[..., 2, 2] = 1.0
    return K_inv
def unproject_pinhole(depth, K):
    """Back-project a depth map to a camera-space point map with pinhole ``K``.

    Args:
        depth: tensor unpacked as ``(B, _, H, W)``; it multiplies the
            per-pixel ray, so its channel dimension must broadcast against 3.
        K: intrinsics accepted by ``invert_pinhole`` and matmul-compatible
            with a ``(B, 3, H*W)`` tensor.

    Returns:
        Tensor of shape ``(B, 3, H, W)``: ``K^-1 @ [u, v, 1]`` scaled by depth.
    """
    b, _, h, w = depth.shape
    K_inv = invert_pinhole(K)
    # Homogeneous pixel grid, [B, 3, H, W] per the original inline comment.
    grid = coords_grid(b, h, w, homogeneous=True, device=depth.device)
    grid_flat = grid.reshape(b, -1, h * w)  # [B, 3, H*W]
    cam_coords = K_inv @ grid_flat
    pcd = cam_coords.reshape(b, -1, h, w) * depth
    return pcd
def project_pinhole(pcd, K):
    """Project a camera-space point map to pixel coordinates with pinhole ``K``.

    Args:
        pcd: tensor of shape ``(B, 3, H, W)`` holding x, y, z per pixel.
        K: intrinsics matmul-compatible with a ``(B, 3, H*W)`` tensor.

    Returns:
        Tensor of shape ``(B, 2, H, W)`` with the projected (u, v).
    """
    b, _, h, w = pcd.shape
    pcd_flat = pcd.reshape(b, 3, -1)  # [B, 3, H*W]
    cam_coords = K @ pcd_flat
    # The divisor is clamped to 0.01, so points at or behind the camera plane
    # yield large finite coordinates instead of inf/NaN.
    pcd_proj = cam_coords[:, :2] / cam_coords[:, 2:].clamp(min=0.01)
    return pcd_proj.reshape(b, 2, h, w)
class Camera:
    """Base class for batched camera models.

    State kept on the instance:
        params: 2-D tensor ``(B, P)`` (a 1-D input is unsqueezed); the first
            four columns are ``[fx, fy, cx, cy]``, the rest is model-specific.
        K: 3x3 intrinsics; built from the first four params when not given.
        projection_mask / overlap_mask / unprojection_mask: ``None`` until a
            subclass ``project`` / ``unproject`` call fills them in.

    Subclasses must implement ``project`` and ``unproject``.
    """

    def __init__(self, params=None, K=None):
        if params is None:
            # `params=None` is the declared default, so it must not be
            # dereferenced blindly: fall back to the entries of K.
            if K is None:
                raise ValueError("Camera requires either `params` or `K`.")
            params = torch.stack(
                [K[..., 0, 0], K[..., 1, 1], K[..., 0, 2], K[..., 1, 2]], dim=-1
            )
        if params.ndim == 1:
            params = params.unsqueeze(0)

        if K is None:
            K = (
                torch.eye(3, device=params.device, dtype=params.dtype)
                .unsqueeze(0)
                .repeat(params.shape[0], 1, 1)
            )
            K[..., 0, 0] = params[..., 0]
            K[..., 1, 1] = params[..., 1]
            K[..., 0, 2] = params[..., 2]
            K[..., 1, 2] = params[..., 3]

        self.params = params
        self.K = K
        self.overlap_mask = None
        self.projection_mask = None
        self.unprojection_mask = None

    def project(self, xyz):
        """Map 3-D points to pixel coordinates (implemented by subclasses)."""
        raise NotImplementedError

    def unproject(self, uv):
        """Map pixel coordinates to rays (implemented by subclasses)."""
        raise NotImplementedError

    def get_projection_mask(self):
        """Return the mask stored by the last ``project`` call, if any."""
        return self.projection_mask

    def get_overlap_mask(self):
        """Return the overlap mask stored by the last ``project`` call, if any."""
        return self.overlap_mask

    def reconstruct(self, depth):
        """Turn a depth map into a point map using this camera's rays.

        Rays are rescaled so their last component is 1 (clamped at 1e-4)
        and then multiplied by the depth (also clamped at 1e-4); this
        assumes z > 0.
        """
        id_coords = coords_grid(
            1, depth.shape[-2], depth.shape[-1], device=depth.device
        )
        rays = self.unproject(id_coords)
        return rays / rays[:, -1:].clamp(min=1e-4) * depth.clamp(min=1e-4)

    def resize(self, factor):
        """Scale the first two rows of ``K`` and the first four params in place."""
        self.K[..., :2, :] *= factor
        self.params[..., :4] *= factor
        return self

    def to(self, device, non_blocking=False):
        """Move ``params`` and ``K`` to ``device``; returns ``self``."""
        self.params = self.params.to(device, non_blocking=non_blocking)
        self.K = self.K.to(device, non_blocking=non_blocking)
        return self

    def get_rays(self, shapes, noisy=False):
        """Return unit-norm rays for an ``h x w`` grid; ``shapes`` is ``(b, h, w)``.

        The batch entry of ``shapes`` is not used: the pixel grid is built
        with batch size 1.
        """
        _, h, w = shapes
        uv = coords_grid(1, h, w, device=self.K.device, noisy=noisy)
        rays = self.unproject(uv)
        return rays / torch.norm(rays, dim=1, keepdim=True).clamp(min=1e-4)

    def get_pinhole_rays(self, shapes, noisy=False):
        """Return unit-norm rays computed with the pinhole part (``K``) only."""
        b, h, w = shapes
        uv = coords_grid(b, h, w, device=self.K.device, homogeneous=True, noisy=noisy)
        rays = (invert_pinhole(self.K) @ uv.reshape(b, 3, -1)).reshape(b, 3, h, w)
        return rays / torch.norm(rays, dim=1, keepdim=True).clamp(min=1e-4)

    def flip(self, H, W, direction="horizontal"):
        """Mirror the principal point: ``cx -> W - cx`` or ``cy -> H - cy``.

        Any ``direction`` other than "horizontal" / "vertical" leaves both
        values unchanged. ``params`` is rebuilt from its first four columns
        only, so model-specific columns beyond index 3 are dropped.
        """
        new_cx = (
            W - self.params[:, 2] if direction == "horizontal" else self.params[:, 2]
        )
        new_cy = H - self.params[:, 3] if direction == "vertical" else self.params[:, 3]
        self.params = torch.stack(
            [self.params[:, 0], self.params[:, 1], new_cx, new_cy], dim=1
        )
        self.K[..., 0, 2] = new_cx
        self.K[..., 1, 2] = new_cy
        return self

    def clone(self):
        """Return a deep copy of this camera."""
        return deepcopy(self)

    def crop(self, left, top, right=None, bottom=None):
        """Shift the principal point by ``(left, top)``; ``right``/``bottom`` are unused here."""
        self.K[..., 0, 2] -= left
        self.K[..., 1, 2] -= top
        self.params[..., 2] -= left
        self.params[..., 3] -= top
        return self

    # helper function to get how fov changes based on new original size and new size
    def get_new_fov(self, new_shape, original_shape):
        """Return ``(hfov, vfov)`` after rescaling from ``original_shape`` to ``new_shape``.

        Both shapes are indexed as ``(height, width)``.
        """
        new_hfov = 2 * torch.atan(
            self.params[..., 2] / self.params[..., 0] * new_shape[1] / original_shape[1]
        )
        new_vfov = 2 * torch.atan(
            self.params[..., 3] / self.params[..., 1] * new_shape[0] / original_shape[0]
        )
        return new_hfov, new_vfov

    def mask_overlap_projection(self, projected):
        """Return a boolean ``(B, 1, H, W)`` mask derived from the projection flow.

        A pixel is True when the flow sampled a fraction ``gamma`` of the way
        along its own flow vector is longer than ``(1 - gamma)`` times its own
        flow, or when its own flow is shorter than one pixel.
        """
        B, _, H, W = projected.shape
        id_coords = coords_grid(B, H, W, device=projected.device)

        # check for mask where flow would overlap with other part of the image
        # elements coming from the border are then masked out
        flow = projected - id_coords
        gamma = 0.1
        sample_grid = gamma * flow + id_coords  # sample along the flow
        # NOTE(review): dividing by (W - 1) / (H - 1) matches the
        # align_corners=True convention while grid_sample below is called with
        # align_corners=False — confirm against coords_grid's pixel convention.
        sample_grid[:, 0] = sample_grid[:, 0] / (W - 1) * 2 - 1
        sample_grid[:, 1] = sample_grid[:, 1] / (H - 1) * 2 - 1
        sampled_flow = F.grid_sample(
            flow,
            sample_grid.permute(0, 2, 3, 1),
            mode="bilinear",
            align_corners=False,
            padding_mode="border",
        )
        flow_norm = torch.norm(flow, dim=1, keepdim=True)
        sampled_norm = torch.norm(sampled_flow, dim=1, keepdim=True)
        mask = ((1 - gamma) * flow_norm < sampled_norm) | (flow_norm < 1)
        return mask

    def _pad_params(self):
        """Return ``params`` zero-padded along the last dimension to length 16."""
        missing = 16 - self.params.shape[-1]
        if missing <= 0:
            return self.params
        padding = self.params.new_zeros(*self.params.shape[:-1], missing)
        return torch.cat([self.params, padding], dim=-1)

    @staticmethod
    def flatten_cameras(cameras):  # -> list[Camera]:
        """Flatten an iterable of cameras into a flat list of primitive cameras.

        ``BatchCamera`` items are expanded through
        ``BatchCamera.flatten_cameras``; plain lists are spliced in one level
        deep (their elements are not inspected further).
        """
        flattened_cameras = []
        for camera in cameras:
            if isinstance(camera, BatchCamera):
                flattened_cameras.extend(BatchCamera.flatten_cameras(camera.cameras))
            elif isinstance(camera, list):
                flattened_cameras.extend(camera)
            else:
                flattened_cameras.append(camera)
        return flattened_cameras

    @staticmethod
    def _stack_or_cat_cameras(cameras, func, **kwargs):
        """Apply ``func`` (``torch.stack`` / ``torch.cat``) to the cameras' ``K`` and padded params.

        Returns a ``BatchCamera`` that also records each camera's class name
        and the flattened camera list.
        """
        flat_cameras = BatchCamera.flatten_cameras(cameras)
        K_matrices = [camera.K for camera in flat_cameras]
        padded_params = [camera._pad_params() for camera in flat_cameras]

        stacked_K = func(K_matrices, **kwargs)
        stacked_params = func(padded_params, **kwargs)

        # Keep track of the original classes
        original_class = [x.__class__.__name__ for x in flat_cameras]
        return BatchCamera(stacked_params, stacked_K, original_class, flat_cameras)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Route ``torch.cat`` / ``torch.stack`` / ``torch.flatten`` on cameras.

        Any other function yields ``NotImplemented`` so that torch reports the
        unsupported call itself; ``Camera`` has no base class providing a
        ``__torch_function__`` to delegate to.
        """
        if kwargs is None:
            kwargs = {}
        if func is torch.cat or func is torch.stack:
            return Camera._stack_or_cat_cameras(args[0], func, **kwargs)
        if func is torch.flatten:
            return Camera._stack_or_cat_cameras(args[0], torch.cat, **kwargs)
        return NotImplemented

    @property
    def device(self):
        """Device of ``K``; a property because subclasses read ``self.device`` as an attribute."""
        return self.K.device

    # here we assume that cx,cy are more or less H/2 and W/2
    # NOTE(review): hfov / vfov / max_fov are kept as plain methods exactly as
    # in this file; confirm whether callers expect property access.
    def hfov(self):
        """Horizontal field of view, ``2 * atan(cx / fx)``."""
        return 2 * torch.atan(self.params[..., 2] / self.params[..., 0])

    def vfov(self):
        """Vertical field of view, ``2 * atan(cy / fy)``."""
        return 2 * torch.atan(self.params[..., 3] / self.params[..., 1])

    def max_fov(self):
        """Return the (horizontal, vertical) upper bound: 150 degrees each, in radians."""
        return 150.0 / 180.0 * np.pi, 150.0 / 180.0 * np.pi
class Pinhole(Camera):
    """Pinhole camera; ``params = [fx, fy, cx, cy]``."""

    def __init__(self, params=None, K=None):
        assert params is not None or K is not None
        # params = [fx, fy, cx, cy]
        if params is None:
            params = torch.stack(
                [K[..., 0, 0], K[..., 1, 1], K[..., 0, 2], K[..., 1, 2]], dim=-1
            )
        super().__init__(params=params, K=K)

    def project(self, pcd):
        """Project a ``(B, 3, H, W)`` point map to ``(B, 2, H, W)`` pixel coordinates.

        Also stores ``self.projection_mask`` (``(B, 1, H, W)``, True where the
        projection lands inside ``[0, w) x [0, h)``).
        """
        b, _, h, w = pcd.shape
        pcd_flat = pcd.reshape(b, 3, -1)  # [B, 3, H*W]
        cam_coords = self.K @ pcd_flat
        pcd_proj = cam_coords[:, :2] / cam_coords[:, -1:].clamp(min=0.01)
        pcd_proj = pcd_proj.reshape(b, 2, h, w)
        # A point is out of frame when ANY bound is violated. Combining the
        # four tests with `&` can never be true (u < 0 and u >= w exclude each
        # other), which would leave the mask all-True.
        invalid = (
            (pcd_proj[:, 0] < 0)
            | (pcd_proj[:, 0] >= w)
            | (pcd_proj[:, 1] < 0)
            | (pcd_proj[:, 1] >= h)
        )
        self.projection_mask = (~invalid).unsqueeze(1)
        return pcd_proj

    def unproject(self, uv):
        """Back-project ``(B, 2, H, W)`` pixel coordinates to rays with z = 1.

        Also stores ``self.unprojection_mask`` (True where the normalised z
        exceeds 1e-4).
        """
        b, _, h, w = uv.shape
        uv_flat = uv.reshape(b, 2, -1)  # [B, 2, H*W]
        uv_homogeneous = torch.cat(
            [uv_flat, torch.ones(b, 1, h * w, device=uv.device)], dim=1
        )  # [B, 3, H*W]
        K_inv = torch.inverse(self.K.float())
        xyz = K_inv @ uv_homogeneous
        xyz = xyz / xyz[:, -1:].clip(min=1e-4)
        xyz = xyz.reshape(b, 3, h, w)
        self.unprojection_mask = xyz[:, -1:] > 1e-4
        return xyz

    def reconstruct(self, depth):
        """Return the point map: z = 1 rays times ``depth`` (negatives clipped to 0)."""
        b, _, h, w = depth.shape
        uv = coords_grid(b, h, w, device=depth.device)
        xyz = self.unproject(uv) * depth.clip(min=0.0)
        return xyz
class EUCM(Camera):
    """Camera with ``params = [fx, fy, cx, cy, alpha, beta]`` (see inline formulas)."""

    def __init__(self, params):
        # params = [fx, fy, cx, cy, alpha, beta]
        super().__init__(params=params, K=None)

    def _unpack_params(self):
        """Return fx, fy, cx, cy, alpha, beta, each shaped ``(B, 1, 1)``.

        The trailing singleton dims make the per-camera scalars broadcast over
        ``(B, H, W)`` maps along the batch axis rather than the width axis.
        """
        return tuple(p[:, None, None] for p in self.params[..., :6].unbind(dim=1))

    def project(self, xyz):
        """Project a ``(B, 3, H, W)`` point map to ``(B, 2, H, W)`` pixel coordinates.

        Also stores ``self.projection_mask`` (``(B, 1, H, W)``): False where
        the projection leaves ``[0, W] x [0, H]`` or where z < 0.
        """
        H, W = xyz.shape[-2:]
        # Slice the parameter axis (last), not the batch axis.
        fx, fy, cx, cy, alpha, beta = self._unpack_params()
        x, y, z = xyz.unbind(dim=1)
        d = torch.sqrt(beta * (x**2 + y**2) + z**2)
        denom = (alpha * d + (1 - alpha) * z).clip(min=1e-3)
        u = fx * (x / denom) + cx
        v = fy * (y / denom) + cy
        coords = torch.stack([u, v], dim=1)
        invalid = (u < 0) | (u > W) | (v < 0) | (v > H) | (z < 0)
        self.projection_mask = (~invalid).unsqueeze(1)
        return coords

    def unproject(self, uv):
        """Back-project ``(B, 2, H, W)`` pixel coordinates to unit-norm rays ``(B, 3, H, W)``.

        Also stores ``self.unprojection_mask`` (``(B, H, W)``). The returned z
        is clamped to at least 1e-3.
        """
        u, v = uv.unbind(dim=1)
        fx, fy, cx, cy, alpha, beta = self._unpack_params()
        mx = (u - cx) / fx
        my = (v - cy) / fy
        r_square = mx**2 + my**2
        # For alpha < 0.5 the radius bound is effectively unbounded (1e6);
        # otherwise it is 1 / (beta * (2 * alpha - 1)).
        valid_mask = r_square < torch.where(
            alpha < 0.5, 1e6, 1 / (beta * (2 * alpha - 1))
        )
        sqrt_val = 1 - (2 * alpha - 1) * beta * r_square
        mz = (1 - beta * (alpha**2) * r_square) / (
            alpha * torch.sqrt(sqrt_val.clip(min=1e-5)) + (1 - alpha)
        )
        coeff = 1 / torch.sqrt(mx**2 + my**2 + mz**2 + 1e-5)

        x = coeff * mx
        y = coeff * my
        z = coeff * mz
        self.unprojection_mask = valid_mask & (z > 1e-3)

        xnorm = torch.stack((x, y, z.clamp(1e-3)), dim=1)
        return xnorm
class Spherical(Camera):
    """Longitude/latitude camera.

    ``params = [fx, fy, cx, cy, W, H, HFoV/2, VFoV/2]``; the two angles are in
    radians and halved, and fx, fy, cx, cy are dummy values.
    """

    def __init__(self, params):
        # Hfov and Vfov are in radians and halved!
        # params: [fx, fy, cx, cy, W, H, HFoV/2, VFoV/2]
        # fx,fy,cx,cy = dummy values
        super().__init__(params=params, K=None)

    def _size_and_fov(self):
        """Return width, height, hfov, vfov, each shaped ``(B, 1, 1)``.

        The stored half-angles are doubled here. The trailing singleton dims
        make the per-camera values broadcast over ``(B, H, W)`` maps along the
        batch axis.
        """
        width = self.params[..., 4].reshape(-1, 1, 1)
        height = self.params[..., 5].reshape(-1, 1, 1)
        hfov = (2 * self.params[..., 6]).reshape(-1, 1, 1)
        vfov = (2 * self.params[..., 7]).reshape(-1, 1, 1)
        return width, height, hfov, vfov

    def resize(self, factor):
        """Scale ``K`` and the first six params (including W and H) in place."""
        self.K[..., :2, :] *= factor
        self.params[..., :6] *= factor
        return self

    def crop(self, left, top, right, bottom):
        """Crop the image extent and shrink the field of view proportionally."""
        self.K[..., 0, 2] -= left
        self.K[..., 1, 2] -= top
        self.params[..., 2] -= left
        self.params[..., 3] -= top
        W, H = self.params[..., 4], self.params[..., 5]
        # The ratios must be computed before W / H are reduced in place below
        # (W and H are views into params).
        angle_ratio_W = (W - left - right) / W
        angle_ratio_H = (H - top - bottom) / H

        self.params[..., 4] -= left + right
        self.params[..., 5] -= top + bottom

        # rescale hfov and vfov
        self.params[..., 6] *= angle_ratio_W
        self.params[..., 7] *= angle_ratio_H
        return self

    def project(self, xyz):
        """Project ``(B, 3, H, W)`` points to ``(B, 2, H, W)`` pixel coordinates."""
        width, height, hfov, vfov = self._size_and_fov()
        longitude = torch.atan2(xyz[:, 0], xyz[:, 2])
        # Clamp guards asin against |y| / norm overshooting 1 through rounding.
        sin_lat = (xyz[:, 1] / torch.norm(xyz, dim=1).clamp(min=1e-5)).clamp(-1.0, 1.0)
        latitude = torch.asin(sin_lat)

        u = longitude / hfov * (width - 1) + (width - 1) / 2
        v = latitude / vfov * (height - 1) + (height - 1) / 2
        return torch.stack([u, v], dim=1)

    def unproject(self, uv):
        """Back-project ``(B, 2, H, W)`` pixel coordinates to unit-norm rays."""
        u, v = uv.unbind(dim=1)
        width, height, hfov, vfov = self._size_and_fov()
        longitude = (u - (width - 1) / 2) / (width - 1) * hfov
        latitude = (v - (height - 1) / 2) / (height - 1) * vfov

        x = torch.cos(latitude) * torch.sin(longitude)
        z = torch.cos(latitude) * torch.cos(longitude)
        y = torch.sin(latitude)
        unit_sphere = torch.stack([x, y, z], dim=1)
        unit_sphere = unit_sphere / torch.norm(unit_sphere, dim=1, keepdim=True).clip(
            min=1e-5
        )
        return unit_sphere

    def reconstruct(self, depth):
        """Return the point map: unit rays times ``depth`` (a distance along the ray)."""
        id_coords = coords_grid(
            1, depth.shape[-2], depth.shape[-1], device=depth.device
        )
        return self.unproject(id_coords) * depth

    def get_new_fov(self, new_shape, original_shape):
        """Return ``(hfov, vfov)`` scaled linearly by the ``(height, width)`` ratio."""
        new_hfov = 2 * self.params[..., 6] * new_shape[1] / original_shape[1]
        new_vfov = 2 * self.params[..., 7] * new_shape[0] / original_shape[0]
        return new_hfov, new_vfov

    # NOTE(review): kept as plain methods exactly as in this file; confirm
    # whether callers expect property access.
    def hfov(self):
        """Full horizontal field of view in radians."""
        return 2 * self.params[..., 6]

    def vfov(self):
        """Full vertical field of view in radians."""
        return 2 * self.params[..., 7]

    def max_fov(self):
        """Return the (horizontal, vertical) upper bound in radians."""
        return 2 * np.pi, 0.9 * np.pi  # avoid strong distortion on tops
class OPENCV(Camera):
    """Camera with polynomial radial, tangential and thin-prism distortion.

    ``params = [fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, p1, p2, s1, s2, s3, s4]``.
    The rational (division) part ``k4..k6`` must be zero.
    """

    def __init__(self, params):
        super().__init__(params=params, K=None)
        # params: [fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, p1, p2, s1, s2, s3, s4]
        self.use_radial = self.params[..., 4:10].abs().sum() > 1e-6
        assert (
            self.params[..., 7:10].abs().sum() == 0.0
        ), "Do not support poly division model"
        self.use_tangential = self.params[..., 10:12].abs().sum() > 1e-6
        self.use_thin_prism = self.params[..., 12:].abs().sum() > 1e-6

    def project(self, xyz):
        """Project a ``(B, 3, H, W)`` point map to ``(B, 2, H, W)`` pixel coordinates.

        Also stores ``self.projection_mask`` (inside ``[0, W] x [0, H]``) and
        ``self.overlap_mask`` (from ``mask_overlap_projection``).
        """
        eps = 1e-9
        B, _, H, W = xyz.shape
        N = H * W
        xyz = xyz.permute(0, 2, 3, 1).reshape(B, N, 3)

        # Radial correction.
        z = xyz[:, :, 2].reshape(B, N, 1)
        # NOTE(review): sign(0) == 0, so an exact z == 0 is left at 0 here.
        z = torch.where(torch.abs(z) < eps, eps * torch.sign(z), z)
        ab = xyz[:, :, :2] / z
        r = torch.norm(ab, dim=-1, p=2, keepdim=True)
        # Even powers r^2, r^4, r^6 multiply k1..k3 (numerator) and k4..k6
        # (denominator).
        r_pow = torch.cat([torch.pow(r, 2 + i * 2) for i in range(3)], dim=-1)
        distortion_coeffs_num = self.params[:, 4:7].reshape(B, 1, 3)
        distortion_coeffs_den = self.params[:, 7:10].reshape(B, 1, 3)
        radial_num = 1 + torch.sum(r_pow * distortion_coeffs_num, dim=-1, keepdim=True)
        radial_den = 1 + torch.sum(r_pow * distortion_coeffs_den, dim=-1, keepdim=True)
        xr_yr = ab * radial_num / radial_den
        uv_dist = xr_yr

        # Tangential correction.
        p0 = self.params[..., -6].reshape(B, 1)
        p1 = self.params[..., -5].reshape(B, 1)
        xr = xr_yr[:, :, 0].reshape(B, N)
        yr = xr_yr[:, :, 1].reshape(B, N)
        xr_yr_sq = torch.square(xr_yr)
        xr_sq = xr_yr_sq[:, :, 0].reshape(B, N)
        yr_sq = xr_yr_sq[:, :, 1].reshape(B, N)
        rd_sq = xr_sq + yr_sq
        uv_dist_tu = uv_dist[:, :, 0] + (
            (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1
        )
        uv_dist_tv = uv_dist[:, :, 1] + (
            (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0
        )
        uv_dist = torch.stack(
            [uv_dist_tu, uv_dist_tv], dim=-1
        )  # Avoids in-place complaint.

        # Thin Prism correction.
        s0 = self.params[..., -4].reshape(B, 1)
        s1 = self.params[..., -3].reshape(B, 1)
        s2 = self.params[..., -2].reshape(B, 1)
        s3 = self.params[..., -1].reshape(B, 1)
        rd_4 = torch.square(rd_sq)
        uv_dist[:, :, 0] = uv_dist[:, :, 0] + (s0 * rd_sq + s1 * rd_4)
        uv_dist[:, :, 1] = uv_dist[:, :, 1] + (s2 * rd_sq + s3 * rd_4)

        # Finally, apply standard terms: focal length and camera centers.
        # A 15-column layout is read as a single focal length followed by
        # cx, cy (presumably a legacy layout — confirm).
        if self.params.shape[-1] == 15:
            fx_fy = self.params[..., 0].reshape(B, 1, 1)
            cx_cy = self.params[..., 1:3].reshape(B, 1, 2)
        else:
            fx_fy = self.params[..., 0:2].reshape(B, 1, 2)
            cx_cy = self.params[..., 2:4].reshape(B, 1, 2)
        result = uv_dist * fx_fy + cx_cy

        result = result.reshape(B, H, W, 2).permute(0, 3, 1, 2)
        invalid = (
            (result[:, 0] < 0)
            | (result[:, 0] > W)
            | (result[:, 1] < 0)
            | (result[:, 1] > H)
        )
        self.projection_mask = (~invalid).unsqueeze(1)
        self.overlap_mask = self.mask_overlap_projection(result)
        return result

    def unproject(self, uv, max_iters: int = 10):
        """Back-project ``(B, 2, H, W)`` pixel coordinates to rays with z = 1.

        Two iterative stages undo the distortion: a 2-D Newton solve for the
        tangential / thin-prism terms, then a scalar Newton solve with a
        per-pixel trust region for the radial polynomial. Each stage runs at
        most ``max_iters`` iterations and is skipped when its coefficients
        are (near) zero.
        """
        eps = 1e-3
        B, _, H, W = uv.shape
        N = H * W
        uv = uv.permute(0, 2, 3, 1).reshape(B, N, 2)

        if self.params.shape[-1] == 15:
            fx_fy = self.params[..., 0].reshape(B, 1, 1)
            cx_cy = self.params[..., 1:3].reshape(B, 1, 2)
        else:
            fx_fy = self.params[..., 0:2].reshape(B, 1, 2)
            cx_cy = self.params[..., 2:4].reshape(B, 1, 2)
        uv_dist = (uv - cx_cy) / fx_fy

        # Compute xr_yr using Newton's method.
        xr_yr = uv_dist.clone()  # Initial guess.
        max_iters_tanprism = (
            max_iters if self.use_thin_prism or self.use_tangential else 0
        )
        for _ in range(max_iters_tanprism):
            uv_dist_est = xr_yr.clone()
            xr = xr_yr[..., 0].reshape(B, N)
            yr = xr_yr[..., 1].reshape(B, N)
            xr_yr_sq = torch.square(xr_yr)
            xr_sq = xr_yr_sq[..., 0].reshape(B, N)
            yr_sq = xr_yr_sq[..., 1].reshape(B, N)
            rd_sq = xr_sq + yr_sq

            if self.use_tangential:
                # Tangential terms.
                p0 = self.params[..., -6].reshape(B, 1)
                p1 = self.params[..., -5].reshape(B, 1)
                uv_dist_est[..., 0] = uv_dist_est[..., 0] + (
                    (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1
                )
                uv_dist_est[..., 1] = uv_dist_est[..., 1] + (
                    (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0
                )

            if self.use_thin_prism:
                # Thin Prism terms.
                s0 = self.params[..., -4].reshape(B, 1)
                s1 = self.params[..., -3].reshape(B, 1)
                s2 = self.params[..., -2].reshape(B, 1)
                s3 = self.params[..., -1].reshape(B, 1)
                rd_4 = torch.square(rd_sq)
                uv_dist_est[:, :, 0] = uv_dist_est[:, :, 0] + (s0 * rd_sq + s1 * rd_4)
                uv_dist_est[:, :, 1] = uv_dist_est[:, :, 1] + (s2 * rd_sq + s3 * rd_4)

            # Jacobian of uv_dist w.r.t. xr_yr. It starts from the identity:
            # with thin-prism terms only, nothing overwrites the off-diagonal
            # entries, so they must start at 0 (an all-ones start makes the
            # system near-singular for small distortion).
            jac = uv.new_zeros(B, N, 2, 2)
            jac[..., 0, 0] = 1.0
            jac[..., 1, 1] = 1.0

            if self.use_tangential:
                jac[..., 0, 0] = 1.0 + 6.0 * xr * p0 + 2.0 * yr * p1
                offdiag = 2.0 * (xr * p1 + yr * p0)
                jac[..., 0, 1] = offdiag
                jac[..., 1, 0] = offdiag
                jac[..., 1, 1] = 1.0 + 6.0 * yr * p1 + 2.0 * xr * p0

            if self.use_thin_prism:
                temp1 = 2.0 * (s0 + 2.0 * s1 * rd_sq)
                jac[..., 0, 0] = jac[..., 0, 0] + (xr * temp1)
                jac[..., 0, 1] = jac[..., 0, 1] + (yr * temp1)
                temp2 = 2.0 * (s2 + 2.0 * s3 * rd_sq)
                jac[..., 1, 0] = jac[..., 1, 0] + (xr * temp2)
                jac[..., 1, 1] = jac[..., 1, 1] + (yr * temp2)

            # Closed-form 2x2 inverse applied to the residual.
            j00 = jac[..., 0, 0]
            j01 = jac[..., 0, 1]
            j10 = jac[..., 1, 0]
            j11 = jac[..., 1, 1]
            inv_det = 1.0 / ((j00 * j11) - (j01 * j10))
            diff = uv_dist - uv_dist_est
            e = diff[..., 0]
            f = diff[..., 1]
            step = torch.stack(
                [
                    (inv_det * j11) * e + (inv_det * -j01) * f,
                    (inv_det * -j10) * e + (inv_det * j00) * f,
                ],
                dim=-1,
            )

            # Newton step.
            xr_yr = xr_yr + step

        # Compute theta using Newton's method.
        xr_yr_norm = xr_yr.norm(p=2, dim=2).reshape(B, N, 1)
        th = xr_yr_norm.clone()
        max_iters_radial = max_iters if self.use_radial else 0
        device = self.K.device
        # d/dth [th * (1 + k1 th^2 + k2 th^4 + k3 th^6)] has factors 3, 5, 7.
        deriv_coeffs = (
            torch.tensor([2.0 * i + 3 for i in range(3)], device=device)
            .reshape(1, 1, 3)
            .repeat(B, 1, 1)
        )
        radial_params_num = self.params[..., 4:7].reshape(B, 1, 3)

        # Trust region parameters
        delta = torch.full((B, N, 1), 0.1, device=device)  # Initial trust radius
        delta_max = torch.tensor(1.0, device=device)  # Maximum trust radius
        eta = 0.1  # Acceptable reduction threshold

        for _ in range(max_iters_radial):
            th_sq = th * th  # th^2
            # Powers th^2, th^4, th^6. Shape: (B, N, 3)
            theta_powers = torch.cat([th_sq ** (i + 1) for i in range(3)], dim=-1)

            # Radial distortion model applied to th.
            th_radial = 1.0 + torch.sum(
                theta_powers * radial_params_num, dim=-1, keepdim=True
            )
            th_radial = th_radial * th  # Multiply by th at the end

            # Derivative of th_radial w.r.t. th.
            dthd_th = 1.0 + torch.sum(
                deriv_coeffs * radial_params_num * theta_powers, dim=-1, keepdim=True
            )

            # Compute residual
            residual = th_radial - xr_yr_norm  # Shape: (B, N, 1)
            residual_norm = torch.norm(residual, dim=2, keepdim=True)  # For each pixel

            # Check for convergence
            if torch.max(torch.abs(residual)) < eps:
                break

            # Avoid division by (near) zero derivatives.
            safe_dthd_th = dthd_th.clone()
            zero_derivative_mask = dthd_th.abs() < eps
            safe_dthd_th[zero_derivative_mask] = eps

            # Compute Newton's step
            step = -residual / safe_dthd_th

            # Compute predicted reduction
            predicted_reduction = -(residual * step).sum(dim=2, keepdim=True)

            # Adjust step based on trust region
            step_norm = torch.norm(step, dim=2, keepdim=True)
            over_trust_mask = step_norm > delta

            # Scale step if it exceeds trust radius
            step_scaled = step.clone()
            step_scaled[over_trust_mask] = step[over_trust_mask] * (
                delta[over_trust_mask] / step_norm[over_trust_mask]
            )

            # Update theta
            th_new = th + step_scaled

            # Compute new residual
            th_sq_new = th_new * th_new
            theta_powers_new = torch.cat(
                [th_sq_new ** (j + 1) for j in range(3)], dim=-1
            )
            th_radial_new = 1.0 + torch.sum(
                theta_powers_new * radial_params_num, dim=-1, keepdim=True
            )
            th_radial_new = th_radial_new * th_new
            residual_new = th_radial_new - xr_yr_norm
            residual_new_norm = torch.norm(residual_new, dim=2, keepdim=True)

            # Compute actual reduction
            actual_reduction = residual_norm - residual_new_norm

            # Ratio of actual to predicted reduction; 0 / 0 counts as a
            # perfect prediction.
            rho = actual_reduction / predicted_reduction
            rho[(actual_reduction == 0) & (predicted_reduction == 0)] = 1.0

            # Update trust radius delta
            delta_update_mask = rho > 0.5
            delta[delta_update_mask] = torch.min(
                2.0 * delta[delta_update_mask], delta_max
            )

            delta_decrease_mask = rho < 0.2
            delta[delta_decrease_mask] = 0.25 * delta[delta_decrease_mask]

            # Accept or reject the step
            accept_step_mask = rho > eta
            th = torch.where(accept_step_mask, th_new, th)

        # Compute the ray direction using theta and xr_yr.
        close_to_zero = torch.logical_and(th.abs() < eps, xr_yr_norm.abs() < eps)
        ray_dir = torch.where(close_to_zero, xr_yr, th / xr_yr_norm * xr_yr)

        ray = torch.cat([ray_dir, uv.new_ones(B, N, 1)], dim=2)
        ray = ray.reshape(B, H, W, 3).permute(0, 3, 1, 2)
        return ray
class Fisheye624(Camera):
    """Fisheye camera: 6 radial, 2 tangential and 4 thin-prism coefficients.

    ``params = [fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, p1, p2, s1, s2, s3, s4]``.
    """

    def __init__(self, params):
        super().__init__(params=params, K=None)
        # params: [fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, p1, p2, s1, s2, s3, s4]
        self.use_radial = self.params[..., 4:10].abs().sum() > 1e-6
        self.use_tangential = self.params[..., 10:12].abs().sum() > 1e-6
        self.use_thin_prism = self.params[..., 12:].abs().sum() > 1e-6

    def project(self, xyz):
        """Project a ``(B, 3, H, W)`` point map to ``(B, 2, H, W)`` pixel coordinates.

        Also stores ``self.projection_mask`` (inside ``[0, W] x [0, H]``) and
        ``self.overlap_mask`` (from ``mask_overlap_projection``).
        """
        eps = 1e-9
        B, _, H, W = xyz.shape
        N = H * W
        xyz = xyz.permute(0, 2, 3, 1).reshape(B, N, 3)

        # Radial correction.
        z = xyz[:, :, 2].reshape(B, N, 1)
        # NOTE(review): sign(0) == 0, so an exact z == 0 is left at 0 here.
        z = torch.where(torch.abs(z) < eps, eps * torch.sign(z), z)
        ab = xyz[:, :, :2] / z
        r = torch.norm(ab, dim=-1, p=2, keepdim=True)
        th = torch.atan(r)
        # Unit direction ab / r, falling back to ones where r is ~0.
        th_divr = torch.where(r < eps, torch.ones_like(ab), ab / r)
        th_pow = torch.cat(
            [torch.pow(th, 3 + i * 2) for i in range(6)], dim=-1
        )  # Odd powers th^3, th^5, ..., th^13
        distortion_coeffs = self.params[:, 4:10].reshape(B, 1, 6)
        th_k = th + torch.sum(th_pow * distortion_coeffs, dim=-1, keepdim=True)
        xr_yr = th_k * th_divr
        uv_dist = xr_yr

        # Tangential correction.
        p0 = self.params[..., -6].reshape(B, 1)
        p1 = self.params[..., -5].reshape(B, 1)
        xr = xr_yr[:, :, 0].reshape(B, N)
        yr = xr_yr[:, :, 1].reshape(B, N)
        xr_yr_sq = torch.square(xr_yr)
        xr_sq = xr_yr_sq[:, :, 0].reshape(B, N)
        yr_sq = xr_yr_sq[:, :, 1].reshape(B, N)
        rd_sq = xr_sq + yr_sq
        uv_dist_tu = uv_dist[:, :, 0] + (
            (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1
        )
        uv_dist_tv = uv_dist[:, :, 1] + (
            (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0
        )
        uv_dist = torch.stack(
            [uv_dist_tu, uv_dist_tv], dim=-1
        )  # Avoids in-place complaint.

        # Thin Prism correction.
        s0 = self.params[..., -4].reshape(B, 1)
        s1 = self.params[..., -3].reshape(B, 1)
        s2 = self.params[..., -2].reshape(B, 1)
        s3 = self.params[..., -1].reshape(B, 1)
        rd_4 = torch.square(rd_sq)
        uv_dist[:, :, 0] = uv_dist[:, :, 0] + (s0 * rd_sq + s1 * rd_4)
        uv_dist[:, :, 1] = uv_dist[:, :, 1] + (s2 * rd_sq + s3 * rd_4)

        # Finally, apply standard terms: focal length and camera centers.
        # A 15-column layout is read as a single focal length followed by
        # cx, cy (presumably a legacy layout — confirm).
        if self.params.shape[-1] == 15:
            fx_fy = self.params[..., 0].reshape(B, 1, 1)
            cx_cy = self.params[..., 1:3].reshape(B, 1, 2)
        else:
            fx_fy = self.params[..., 0:2].reshape(B, 1, 2)
            cx_cy = self.params[..., 2:4].reshape(B, 1, 2)
        result = uv_dist * fx_fy + cx_cy

        result = result.reshape(B, H, W, 2).permute(0, 3, 1, 2)
        invalid = (
            (result[:, 0] < 0)
            | (result[:, 0] > W)
            | (result[:, 1] < 0)
            | (result[:, 1] > H)
        )
        self.projection_mask = (~invalid).unsqueeze(1)
        self.overlap_mask = self.mask_overlap_projection(result)
        return result

    def unproject(self, uv, max_iters: int = 10):
        """Back-project ``(B, 2, H, W)`` pixel coordinates to rays with z = 1.

        Two iterative stages undo the distortion: a 2-D Newton solve for the
        tangential / thin-prism terms, then a scalar Newton solve with a
        per-pixel trust region for the radial polynomial in theta. Each stage
        runs at most ``max_iters`` iterations and is skipped when its
        coefficients are (near) zero.
        """
        eps = 1e-3
        B, _, H, W = uv.shape
        N = H * W
        uv = uv.permute(0, 2, 3, 1).reshape(B, N, 2)

        if self.params.shape[-1] == 15:
            fx_fy = self.params[..., 0].reshape(B, 1, 1)
            cx_cy = self.params[..., 1:3].reshape(B, 1, 2)
        else:
            fx_fy = self.params[..., 0:2].reshape(B, 1, 2)
            cx_cy = self.params[..., 2:4].reshape(B, 1, 2)
        uv_dist = (uv - cx_cy) / fx_fy

        # Compute xr_yr using Newton's method.
        xr_yr = uv_dist.clone()  # Initial guess.
        max_iters_tanprism = (
            max_iters if self.use_thin_prism or self.use_tangential else 0
        )
        for _ in range(max_iters_tanprism):
            uv_dist_est = xr_yr.clone()
            xr = xr_yr[..., 0].reshape(B, N)
            yr = xr_yr[..., 1].reshape(B, N)
            xr_yr_sq = torch.square(xr_yr)
            xr_sq = xr_yr_sq[..., 0].reshape(B, N)
            yr_sq = xr_yr_sq[..., 1].reshape(B, N)
            rd_sq = xr_sq + yr_sq

            if self.use_tangential:
                # Tangential terms.
                p0 = self.params[..., -6].reshape(B, 1)
                p1 = self.params[..., -5].reshape(B, 1)
                uv_dist_est[..., 0] = uv_dist_est[..., 0] + (
                    (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1
                )
                uv_dist_est[..., 1] = uv_dist_est[..., 1] + (
                    (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0
                )

            if self.use_thin_prism:
                # Thin Prism terms.
                s0 = self.params[..., -4].reshape(B, 1)
                s1 = self.params[..., -3].reshape(B, 1)
                s2 = self.params[..., -2].reshape(B, 1)
                s3 = self.params[..., -1].reshape(B, 1)
                rd_4 = torch.square(rd_sq)
                uv_dist_est[:, :, 0] = uv_dist_est[:, :, 0] + (s0 * rd_sq + s1 * rd_4)
                uv_dist_est[:, :, 1] = uv_dist_est[:, :, 1] + (s2 * rd_sq + s3 * rd_4)

            # Jacobian of uv_dist w.r.t. xr_yr. It starts from the identity:
            # with thin-prism terms only, nothing overwrites the off-diagonal
            # entries, so they must start at 0 (an all-ones start makes the
            # system near-singular for small distortion).
            jac = uv.new_zeros(B, N, 2, 2)
            jac[..., 0, 0] = 1.0
            jac[..., 1, 1] = 1.0

            if self.use_tangential:
                jac[..., 0, 0] = 1.0 + 6.0 * xr * p0 + 2.0 * yr * p1
                offdiag = 2.0 * (xr * p1 + yr * p0)
                jac[..., 0, 1] = offdiag
                jac[..., 1, 0] = offdiag
                jac[..., 1, 1] = 1.0 + 6.0 * yr * p1 + 2.0 * xr * p0

            if self.use_thin_prism:
                temp1 = 2.0 * (s0 + 2.0 * s1 * rd_sq)
                jac[..., 0, 0] = jac[..., 0, 0] + (xr * temp1)
                jac[..., 0, 1] = jac[..., 0, 1] + (yr * temp1)
                temp2 = 2.0 * (s2 + 2.0 * s3 * rd_sq)
                jac[..., 1, 0] = jac[..., 1, 0] + (xr * temp2)
                jac[..., 1, 1] = jac[..., 1, 1] + (yr * temp2)

            # Compute the 2x2 inverse manually here since torch.inverse() is
            # very slow (about a 10x slowdown per the original note).
            j00 = jac[..., 0, 0]
            j01 = jac[..., 0, 1]
            j10 = jac[..., 1, 0]
            j11 = jac[..., 1, 1]
            inv_det = 1.0 / ((j00 * j11) - (j01 * j10))
            diff = uv_dist - uv_dist_est
            e = diff[..., 0]
            f = diff[..., 1]
            step = torch.stack(
                [
                    (inv_det * j11) * e + (inv_det * -j01) * f,
                    (inv_det * -j10) * e + (inv_det * j00) * f,
                ],
                dim=-1,
            )

            # Newton step.
            xr_yr = xr_yr + step

        # Compute theta using Newton's method.
        xr_yr_norm = xr_yr.norm(p=2, dim=2).reshape(B, N, 1)
        th = xr_yr_norm.clone()
        max_iters_radial = max_iters if self.use_radial else 0
        device = self.K.device
        # d/dth [th * (1 + sum_i k_i th^(2i))] has factors 3, 5, ..., 13.
        deriv_coeffs = (
            torch.tensor([2.0 * i + 3 for i in range(6)], device=device)
            .reshape(1, 1, 6)
            .repeat(B, 1, 1)
        )
        radial_params = self.params[..., 4:10].reshape(B, 1, 6)

        # Trust region parameters
        delta = torch.full((B, N, 1), 0.1, device=device)  # Initial trust radius
        delta_max = torch.tensor(1.0, device=device)  # Maximum trust radius
        eta = 0.1  # Acceptable reduction threshold

        for _ in range(max_iters_radial):
            th_sq = th * th
            # Powers th^2, th^4, ..., th^12. Shape: (B, N, 6)
            theta_powers = torch.cat([th_sq ** (i + 1) for i in range(6)], dim=-1)

            # Radial distortion model applied to th.
            th_radial = 1.0 + torch.sum(
                theta_powers * radial_params, dim=-1, keepdim=True
            )
            th_radial = th_radial * th

            # Derivative of th_radial w.r.t. th.
            dthd_th = 1.0 + torch.sum(
                deriv_coeffs * radial_params * theta_powers, dim=-1, keepdim=True
            )

            # Compute residual
            residual = th_radial - xr_yr_norm  # Shape: (B, N, 1)
            residual_norm = torch.norm(residual, dim=2, keepdim=True)

            # Check for convergence
            if torch.max(torch.abs(residual)) < eps:
                break

            # Avoid division by (near) zero derivatives.
            safe_dthd_th = dthd_th.clone()
            zero_derivative_mask = dthd_th.abs() < eps
            safe_dthd_th[zero_derivative_mask] = eps

            # Compute Newton's step
            step = -residual / safe_dthd_th

            # Compute predicted reduction
            predicted_reduction = -(residual * step).sum(dim=2, keepdim=True)

            # Adjust step based on trust region
            step_norm = torch.norm(step, dim=2, keepdim=True)
            over_trust_mask = step_norm > delta

            # Scale step if it exceeds trust radius
            step_scaled = step.clone()
            step_scaled[over_trust_mask] = step[over_trust_mask] * (
                delta[over_trust_mask] / step_norm[over_trust_mask]
            )

            # Update theta
            th_new = th + step_scaled

            # Compute new residual
            th_sq_new = th_new * th_new
            theta_powers_new = torch.cat(
                [th_sq_new ** (j + 1) for j in range(6)], dim=-1
            )
            th_radial_new = 1.0 + torch.sum(
                theta_powers_new * radial_params, dim=-1, keepdim=True
            )
            th_radial_new = th_radial_new * th_new
            residual_new = th_radial_new - xr_yr_norm
            residual_new_norm = torch.norm(residual_new, dim=2, keepdim=True)

            # Compute actual reduction
            actual_reduction = residual_norm - residual_new_norm

            # Ratio of actual to predicted reduction; 0 / 0 counts as a
            # perfect prediction.
            rho = actual_reduction / predicted_reduction
            rho[(actual_reduction == 0) & (predicted_reduction == 0)] = 1.0

            # Update trust radius delta
            delta_update_mask = rho > 0.5
            delta[delta_update_mask] = torch.min(
                2.0 * delta[delta_update_mask], delta_max
            )

            delta_decrease_mask = rho < 0.2
            delta[delta_decrease_mask] = 0.25 * delta[delta_decrease_mask]

            # Accept or reject the step
            accept_step_mask = rho > eta
            th = torch.where(accept_step_mask, th_new, th)

        # Compute the ray direction using theta and xr_yr.
        close_to_zero = torch.logical_and(th.abs() < eps, xr_yr_norm.abs() < eps)
        ray_dir = torch.where(close_to_zero, xr_yr, torch.tan(th) / xr_yr_norm * xr_yr)

        ray = torch.cat([ray_dir, uv.new_ones(B, N, 1)], dim=2)
        ray = ray.reshape(B, H, W, 3).permute(0, 3, 1, 2)
        return ray
class MEI(Camera):
    """MEI camera model with radial and tangential distortion.

    The parameter vector is laid out per camera as
    ``fx, fy, cx, cy, k1, k2, p1, p2, xi`` (columns 0..8).
    """

    def __init__(self, params):
        super().__init__(params=params, K=None)
        # fx fy cx cy k1 k2 p1 p2 xi
        # The Newton inversions in ``unproject`` are skipped when the
        # corresponding coefficients are (numerically) all zero.
        self.use_radial = self.params[..., 4:6].abs().sum() > 1e-6
        self.use_tangential = self.params[..., 6:8].abs().sum() > 1e-6

    def unproject(self, uv, max_iters: int = 20):
        """Map pixel coordinates to rays.

        Args:
            uv: pixel coordinates, shape ``(B, 2, H, W)``.
            max_iters: number of Newton iterations used for each of the
                tangential and the radial inversion (each runs only when its
                coefficients are non-zero).

        Returns:
            Tensor of shape ``(B, 3, H, W)``. Pixels whose ray contains a NaN
            in any channel are set to zero.
        """
        eps = 1e-6
        B, _, H, W = uv.shape
        N = H * W
        uv = uv.permute(0, 2, 3, 1).reshape(B, N, 2)

        k1, k2, p0, p1, xi = self.params[..., 4:9].unbind(dim=1)
        # Give every per-camera coefficient explicit broadcast dimensions so
        # that batches with B > 1 combine correctly with (B, N) and (B, N, 1)
        # tensors (mirrors what ``project`` does).
        k1 = k1.reshape(B, 1, 1)
        k2 = k2.reshape(B, 1, 1)
        p0 = p0.reshape(B, 1)
        p1 = p1.reshape(B, 1)
        xi = xi.reshape(B, 1, 1)

        fx_fy = self.params[..., 0:2].reshape(B, 1, 2)
        cx_cy = self.params[..., 2:4].reshape(B, 1, 2)
        uv_dist = (uv - cx_cy) / fx_fy

        # Compute xr_yr using Newton's method.
        xr_yr = uv_dist.clone()  # Initial guess.
        max_iters_tangential = max_iters if self.use_tangential else 0
        for _ in range(max_iters_tangential):
            uv_dist_est = xr_yr.clone()

            # Tangential terms.
            xr = xr_yr[..., 0]
            yr = xr_yr[..., 1]
            xr_yr_sq = xr_yr**2
            xr_sq = xr_yr_sq[..., 0]
            yr_sq = xr_yr_sq[..., 1]
            rd_sq = xr_sq + yr_sq
            uv_dist_est[..., 0] = uv_dist_est[..., 0] + (
                (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1
            )
            uv_dist_est[..., 1] = uv_dist_est[..., 1] + (
                (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0
            )

            # Compute the derivative of uv_dist w.r.t. xr_yr.
            # Allocated in the default dtype; entries are cast on assignment.
            duv_dist_dxr_yr = torch.ones((B, N, 2, 2), device=uv.device)
            duv_dist_dxr_yr[..., 0, 0] = 1.0 + 6.0 * xr * p0 + 2.0 * yr * p1
            offdiag = 2.0 * (xr * p1 + yr * p0)
            duv_dist_dxr_yr[..., 0, 1] = offdiag
            duv_dist_dxr_yr[..., 1, 0] = offdiag
            duv_dist_dxr_yr[..., 1, 1] = 1.0 + 6.0 * yr * p1 + 2.0 * xr * p0

            # Closed-form inverse of each 2x2 Jacobian.
            mat = duv_dist_dxr_yr.reshape(-1, 2, 2)
            a = mat[:, 0, 0].reshape(-1, 1, 1)
            b = mat[:, 0, 1].reshape(-1, 1, 1)
            c = mat[:, 1, 0].reshape(-1, 1, 1)
            d = mat[:, 1, 1].reshape(-1, 1, 1)
            inv_det = 1.0 / ((a * d) - (b * c))
            top = torch.cat([d, -b], dim=-1)
            bot = torch.cat([-c, a], dim=-1)
            inv = inv_det * torch.cat([top, bot], dim=-2)
            inv = inv.reshape(B, N, 2, 2)

            # Newton step: inverse Jacobian applied to the residual.
            diff = uv_dist - uv_dist_est
            e = diff[..., 0]
            f = diff[..., 1]
            step = torch.stack(
                [
                    inv[..., 0, 0] * e + inv[..., 0, 1] * f,
                    inv[..., 1, 0] * e + inv[..., 1, 1] * f,
                ],
                dim=-1,
            )
            xr_yr = xr_yr + step

        # Compute theta using Newton's method.
        xr_yr_norm = xr_yr.norm(p=2, dim=2).reshape(B, N, 1)
        th = xr_yr_norm.clone()
        max_iters_radial = max_iters if self.use_radial else 0
        for _ in range(max_iters_radial):
            th_radial = 1.0 + k1 * torch.pow(th, 2) + k2 * torch.pow(th, 4)
            dthd_th = 1.0 + 3.0 * k1 * torch.pow(th, 2) + 5.0 * k2 * torch.pow(th, 4)
            th_radial = th_radial * th
            step = (xr_yr_norm - th_radial) / dthd_th
            # handle dthd_th close to 0.
            step = torch.where(
                torch.abs(dthd_th) > eps, step, torch.sign(step) * eps * 10.0
            )
            th = th + step

        # Compute the ray direction using theta and xr_yr.
        close_to_zero = (torch.abs(th) < eps) & (torch.abs(xr_yr_norm) < eps)
        ray_dir = torch.where(close_to_zero, xr_yr, th * xr_yr / xr_yr_norm)

        # Compute the 3D projective ray
        rho2_u = (
            ray_dir.norm(p=2, dim=2, keepdim=True) ** 2
        )  # B N 1 # x_c * x_c + y_c * y_c
        sqrt_term = torch.sqrt(1.0 + (1.0 - xi * xi) * rho2_u)
        P_z = 1.0 - xi * (rho2_u + 1.0) / (xi + sqrt_term)
        # Special case when xi is 1.0 (unit sphere projection ??)
        P_z = torch.where(xi == 1.0, (1.0 - rho2_u) / 2.0, P_z)

        ray = torch.cat([ray_dir, P_z], dim=-1)
        ray = ray.reshape(B, H, W, 3).permute(0, 3, 1, 2)

        # remove nans
        where_nan = ray.isnan().any(dim=1, keepdim=True).repeat(1, 3, 1, 1)
        ray = torch.where(where_nan, torch.zeros_like(ray), ray)
        return ray

    def project(self, xyz):
        """Project 3D points to pixel coordinates.

        Args:
            xyz: points, either flat ``(B, N, 3)`` or image-shaped
                ``(B, 3, H, W)``.

        Returns:
            ``(B, N, 2)`` for flat input, ``(B, 2, H, W)`` for image-shaped
            input. For image-shaped input ``self.projection_mask`` is also
            updated with a ``(B, 1, H, W)`` boolean mask of the pixels that
            project inside ``[0, W] x [0, H]``.
        """
        is_flat = xyz.ndim == 3
        B, N = xyz.shape[:2]
        if not is_flat:
            B, _, H, W = xyz.shape
            N = H * W
            xyz = xyz.permute(0, 2, 3, 1).reshape(B, N, 3)

        # Explicit column range (same as ``unproject``) instead of open-ended
        # slices / negative indices, so the layout does not depend on the
        # total number of parameter columns.
        k1, k2, p0, p1, xi = self.params[..., 4:9].unbind(dim=1)
        k1 = k1.reshape(B, 1, 1)
        k2 = k2.reshape(B, 1, 1)
        p0 = p0.reshape(B, 1)
        p1 = p1.reshape(B, 1)
        xi = xi.reshape(B, 1, 1)
        fx_fy = self.params[..., 0:2].reshape(B, 1, 2)
        cx_cy = self.params[..., 2:4].reshape(B, 1, 2)

        norm = xyz.norm(p=2, dim=-1, keepdim=True)
        ab = xyz[..., :-1] / (xyz[..., -1:] + xi * norm)

        # radial correction
        r = ab.norm(dim=-1, p=2, keepdim=True)
        # ab / r * th * (1 + k1 * (th ** 2) + k2 * (th**4))
        # but here r = th, no spherical distortion
        xr_yr = ab * (1 + k1 * (r**2) + k2 * (r**4))

        # Tangential correction.
        uv_dist = xr_yr
        xr = xr_yr[..., 0].reshape(B, N)
        yr = xr_yr[..., 1].reshape(B, N)
        xr_yr_sq = torch.square(xr_yr)
        xr_sq = xr_yr_sq[:, :, 0].reshape(B, N)
        yr_sq = xr_yr_sq[:, :, 1].reshape(B, N)
        rd_sq = xr_sq + yr_sq
        uv_dist_tu = uv_dist[:, :, 0] + (
            (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1
        )
        uv_dist_tv = uv_dist[:, :, 1] + (
            (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0
        )
        uv_dist = torch.stack(
            [uv_dist_tu, uv_dist_tv], dim=-1
        )  # Avoids in-place complaint.

        result = uv_dist * fx_fy + cx_cy

        if not is_flat:
            result = result.reshape(B, H, W, 2).permute(0, 3, 1, 2)
            # H and W are only known for image-shaped input, so the validity
            # mask is computed in this branch only.
            invalid = (
                (result[:, 0] < 0)
                | (result[:, 0] > W)
                | (result[:, 1] < 0)
                | (result[:, 1] > H)
            )
            self.projection_mask = (~invalid).unsqueeze(1)
            # creates hole in the middle... ??
            # self.overlap_mask = self.mask_overlap_projection(result)

        return result
class BatchCamera(Camera):
    """A batch of (possibly different) cameras behind the ``Camera`` interface.

    ``params`` and ``K`` hold the stacked parameters, ``cameras`` holds the
    individual camera objects (a list, possibly nested after ``reshape``) and
    ``original_class`` the matching class names.
    """

    def __init__(self, params, K, original_class, cameras):
        super().__init__(params, K)
        self.original_class = original_class
        self.cameras = cameras

    # Delegate these methods to original camera
    def project(self, points_3d):
        """Project slice ``i`` of ``points_3d`` with camera ``i``; concatenate on dim 0."""
        return torch.cat(
            [
                camera.project(points_3d[i : i + 1])
                for i, camera in enumerate(self.cameras)
            ]
        )

    def unproject(self, points_2d):
        """Unproject ``points_2d`` with every camera and concatenate on dim 0.

        Every camera receives the full ``points_2d`` tensor; ``self.cameras``
        may be an arbitrarily nested list.
        """

        def recursive_unproject(cameras):
            if isinstance(cameras, list):
                return [recursive_unproject(camera) for camera in cameras]
            else:
                return cameras.unproject(points_2d)

        def flatten_and_cat(nested_list):
            if isinstance(nested_list[0], list):
                return torch.cat(
                    [flatten_and_cat(sublist) for sublist in nested_list], dim=0
                )
            else:
                return torch.cat(nested_list, dim=0)

        unprojected = recursive_unproject(self.cameras)
        return flatten_and_cat(unprojected)

    def crop(self, left, top, right=None, bottom=None):
        """Concatenate the result of ``crop`` applied to every camera."""
        val = torch.cat(
            [camera.crop(left, top, right, bottom) for camera in self.cameras]
        )
        return val

    def resize(self, ratio):
        """Concatenate the result of ``resize(ratio)`` applied to every camera."""
        val = torch.cat([camera.resize(ratio) for camera in self.cameras])
        return val

    def reconstruct(self, depth):
        """Reconstruct slice ``i`` of ``depth`` with camera ``i``; concatenate on dim 0."""
        val = torch.cat(
            [
                camera.reconstruct(depth[i : i + 1])
                for i, camera in enumerate(self.cameras)
            ]
        )
        return val

    def get_projection_mask(self):
        """Concatenate the ``projection_mask`` of every camera."""
        return torch.cat([camera.projection_mask for camera in self.cameras])

    def to(self, device, non_blocking=False):
        """Move the batch and every contained camera to ``device``."""
        self = super().to(device, non_blocking=non_blocking)
        self.cameras = recursive_apply(
            self.cameras, lambda camera: camera.to(device, non_blocking=non_blocking)
        )
        return self

    def reshape(self, *shape):
        """Return a ``BatchCamera`` whose leading dimensions are ``shape``.

        Note that ``self.cameras`` and ``self.original_class`` are re-nested
        in place before the new object is built.
        """
        # Reshape the intrinsic matrix (K) and params
        # we know that the shape of K is (..., 3, 3) and params is (..., 16)
        reshaped_K = self.K.reshape(*shape, 3, 3)
        reshaped_params = self.params.reshape(*shape, self.params.shape[-1])
        self.cameras = np.array(self.cameras, dtype=object).reshape(shape).tolist()
        self.original_class = (
            np.array(self.original_class, dtype=object).reshape(shape).tolist()
        )
        # Create a new BatchCamera with reshaped K and params
        return BatchCamera(
            reshaped_params, reshaped_K, self.original_class, self.cameras
        )

    def get_new_fov(self, new_shape, original_shape):
        """Return the list of ``get_new_fov`` results, one per camera."""
        return [
            camera.get_new_fov(new_shape, original_shape) for camera in self.cameras
        ]

    def squeeze(self, dim):
        """Return a new ``BatchCamera`` with dimension ``dim`` squeezed out."""
        return BatchCamera(
            self.params.squeeze(dim),
            self.K.squeeze(dim),
            squeeze_list(self.original_class, dim=dim),
            squeeze_list(self.cameras, dim=dim),
        )

    def __getitem__(self, idx):
        # If it's an integer index, return a single camera
        if isinstance(idx, int):
            return self.cameras[idx]
        # If it's a slice, return a new BatchCamera with sliced cameras
        elif isinstance(idx, slice):
            return BatchCamera(
                self.params[idx],
                self.K[idx],
                self.original_class[idx],
                self.cameras[idx],
            )
        raise TypeError(f"Invalid index type: {type(idx)}")

    def __setitem__(self, idx, value):
        # Integer index: replace a single camera and its row of params / K.
        if isinstance(idx, int):
            self.cameras[idx] = value
            # The row is zeroed first because ``value`` may have fewer
            # parameter columns than the batch tensor.
            self.params[idx, :] = 0.0
            self.params[idx, : value.params.shape[1]] = value.params[0]
            self.K[idx] = value.K[0]
            self.original_class[idx] = getattr(
                value, "original_class", value.__class__.__name__
            )
        # Slice: overwrite every attribute from another batch.
        elif isinstance(idx, slice):
            # Update each internal attribute using the slice
            self.params[idx] = value.params
            self.K[idx] = value.K
            self.original_class[idx] = value.original_class
            self.cameras[idx] = value.cameras
        else:
            # Same contract as ``__getitem__``: unsupported index types are an
            # error rather than a silent no-op.
            raise TypeError(f"Invalid index type: {type(idx)}")

    def __len__(self):
        return len(self.cameras)

    @classmethod
    def from_camera(cls, camera):
        """Alternate constructor: wrap a single camera into a batch of one."""
        return cls(camera.params, camera.K, [camera.__class__.__name__], [camera])

    # NOTE(review): the members below read like properties (they take no
    # arguments and use attribute-style access on the cameras); they are kept
    # as plain methods here -- confirm against callers.
    def is_perspective(self):
        return recursive_apply(self.cameras, lambda camera: isinstance(camera, Pinhole))

    def is_spherical(self):
        return recursive_apply(
            self.cameras, lambda camera: isinstance(camera, Spherical)
        )

    def is_eucm(self):
        return recursive_apply(self.cameras, lambda camera: isinstance(camera, EUCM))

    def is_fisheye(self):
        return recursive_apply(
            self.cameras, lambda camera: isinstance(camera, Fisheye624)
        )

    def is_pinhole(self):
        return recursive_apply(self.cameras, lambda camera: isinstance(camera, Pinhole))

    def hfov(self):
        return recursive_apply(self.cameras, lambda camera: camera.hfov)

    def vfov(self):
        return recursive_apply(self.cameras, lambda camera: camera.vfov)

    def max_fov(self):
        return recursive_apply(self.cameras, lambda camera: camera.max_fov)
import json | |
import random | |
# sampler helpers | |
from math import log | |
import torch.nn as nn | |
def eucm(boundaries, mult, batch, device, dtype):
    """Draw ``batch`` random ``(alpha, beta)`` pairs.

    ``boundaries[0]`` is the ``(min, max)`` range of the uniformly sampled
    ``alpha``; ``boundaries[1]`` is the ``(mean, std)`` of the Gaussian that is
    passed through a softplus (with ``beta=log(2)``) to obtain ``beta``.
    All four numbers are scaled by ``mult`` first.
    """
    alpha_range, beta_stats = boundaries[0], boundaries[1]
    alpha_lo = alpha_range[0] * mult
    alpha_hi = alpha_range[1] * mult
    beta_mu = beta_stats[0] * mult
    beta_sigma = beta_stats[1] * mult

    # Uniform draw first, Gaussian draw second (keeps the RNG stream order).
    uniform = torch.rand(batch, device=device, dtype=dtype)
    alpha = uniform * (alpha_hi - alpha_lo) + alpha_lo

    gaussian = torch.randn(batch, device=device, dtype=dtype)
    beta = F.softplus(gaussian * beta_sigma + beta_mu, beta=log(2))
    return alpha, beta
def free_fisheye(boundaries, mult, batch, device, dtype):
    """Sample twelve independent coefficient vectors uniformly.

    ``boundaries[i]`` is the ``(min, max)`` range of the i-th coefficient, in
    the order ``k1..k6, p1, p2, s1..s4``; both ends are scaled by ``mult``.
    Returns a 12-tuple of tensors of shape ``(batch,)``.
    """
    num_coeffs = 12  # k1..k6, p1, p2, s1..s4
    # Scale all ranges up front, before any random number is drawn.
    scaled_ranges = [
        (boundaries[i][0] * mult, boundaries[i][1] * mult)
        for i in range(num_coeffs)
    ]
    # Draws happen in coefficient order, one ``torch.rand`` call each.
    return tuple(
        torch.rand(batch, device=device, dtype=dtype) * (hi - lo) + lo
        for lo, hi in scaled_ranges
    )
def mei(boundaries, mult, batch, device, dtype):
    """Sample ``(k1, k2, p1, p2, xi)`` uniformly.

    ``boundaries[i]`` is the ``(min, max)`` range of the i-th coefficient;
    both ends are scaled by ``mult``. Returns a 5-tuple of tensors of shape
    ``(batch,)``.
    """
    num_coeffs = 5  # k1, k2, p1, p2, xi
    # Scale all ranges up front, before any random number is drawn.
    scaled_ranges = [
        (boundaries[i][0] * mult, boundaries[i][1] * mult)
        for i in range(num_coeffs)
    ]
    # Draws happen in coefficient order, one ``torch.rand`` call each.
    return tuple(
        torch.rand(batch, device=device, dtype=dtype) * (hi - lo) + lo
        for lo, hi in scaled_ranges
    )
def consistent_fisheye(boundaries, mult, batch, device, dtype):
    """Sample like ``free_fisheye`` with ``mult`` multiplied by a random sign.

    A single sign (-1 or +1) is drawn per call and applied to ``mult``, so all
    scaled ranges are flipped together.
    """
    flip = random.choice([-1, 1])
    signed_mult = flip * mult
    return free_fisheye(boundaries, signed_mult, batch, device, dtype)
def invert_fisheye(boundaries, mult, batch, device, dtype):
    """Sample twelve coefficient vectors with alternating signs on ``k1..k6``.

    ``boundaries[i]`` is the ``(min, max)`` range of the i-th coefficient, in
    the order ``k1..k6, p1, p2, s1..s4``; both ends are scaled by ``mult``.
    One random sign is drawn per call: ``k1, k3, k5`` are multiplied by it and
    ``k2, k4, k6`` by its negation. The remaining coefficients are left as
    sampled. Returns a 12-tuple of tensors of shape ``(batch,)``.
    """
    num_coeffs = 12
    num_radial = 6  # k1..k6 are the coefficients that receive a sign
    # Scale all ranges up front, before any random number is drawn.
    scaled_ranges = [
        (boundaries[i][0] * mult, boundaries[i][1] * mult)
        for i in range(num_coeffs)
    ]

    sign = random.choice([-1, 1])

    coeffs = []
    for idx, (lo, hi) in enumerate(scaled_ranges):
        sample = torch.rand(batch, device=device, dtype=dtype) * (hi - lo) + lo
        if idx < num_radial:
            # Even positions (k1, k3, k5) take ``sign``, odd ones its opposite.
            factor = sign if idx % 2 == 0 else -1 * sign
            sample = factor * sample
        coeffs.append(sample)
    return tuple(coeffs)
class CameraSampler(nn.Module):
    """Randomly pick a camera model and sample its extra parameters.

    The configuration is read from ``camera_sampler.json`` in the current
    working directory. It provides parallel lists: ``type`` (camera class
    names), ``fn`` (sampling function names), ``boundaries`` (per-camera
    boundary tables) and ``probs`` (selection weights).
    """

    def __init__(self):
        super().__init__()
        with open("camera_sampler.json", "r", encoding="utf-8") as f:
            config = json.load(f)
        self.camera_type = config["type"]
        self.sampling_fn = config["fn"]
        # Stored as frozen parameters so they follow the module across devices.
        self.boundaries = nn.ParameterList(
            [
                nn.Parameter(torch.tensor(x), requires_grad=False)
                for x in config["boundaries"]
            ]
        )
        self.probs = nn.Parameter(torch.tensor(config["probs"]), requires_grad=False)

    def forward(self, fx, fy, cx, cy, mult, ratio, H):
        """Sample one camera type and build it for the whole batch.

        Args:
            fx, fy, cx, cy: 1-D tensors of per-sample intrinsics; they become
                the first four parameter columns.
            mult: scale applied to the sampling boundaries.
            ratio: extra factor applied to ``mult`` when the selected camera
                name contains ``"Fisheye"`` or ``"OPENCV"``.
            H: unused here; kept for interface compatibility.

        Returns:
            An instance of the selected camera class built from the stacked
            parameters.
        """
        # Convert the draw to a plain int before indexing the Python lists
        # and the ParameterList.
        selected_idx = int(torch.multinomial(self.probs, num_samples=1).item())
        device, dtype = fx.device, fx.dtype
        selected_camera = self.camera_type[selected_idx]
        selected_sampling_fn = self.sampling_fn[selected_idx]
        selected_boundaries = self.boundaries[selected_idx]
        if "Fisheye" in selected_camera or "OPENCV" in selected_camera:
            mult = mult * ratio
        # NOTE(security): ``eval`` executes strings taken from
        # camera_sampler.json; that file must come from a trusted source.
        params = eval(selected_sampling_fn)(
            selected_boundaries, mult, len(fx), device, dtype
        )
        params = torch.stack([fx, fy, cx, cy, *params], dim=1)
        camera = eval(selected_camera)(params=params)
        return camera