|
import numpy as np
from scipy.spatial.transform import Rotation as R

import torch
from pytorch3d.transforms import Translate
|
|
|
|
|
def intersect_skew_line_groups(p, r, mask):
    """Intersect a group of skew lines given by points p and directions r,
    returning the least-squares intersection point, the nearest point on
    each line, and the squared distance from each line to the intersection."""
    p_intersect, r = intersect_skew_lines_high_dim(p, r, mask=mask)
    if p_intersect is None:
        return None, None, None, None
    _, p_line_intersect = point_line_distance(
        p, r, p_intersect[..., None, :].expand_as(p)
    )
    intersect_dist_squared = ((p_line_intersect - p_intersect[..., None, :]) ** 2).sum(
        dim=-1
    )
    return p_intersect, p_line_intersect, intersect_dist_squared, r
|
|
|
|
|
def intersect_skew_lines_high_dim(p, r, mask=None):
    """Find the point minimizing the sum of squared distances to a set of
    lines through points p with directions r, via a least-squares solve."""
    dim = p.shape[-1]

    if mask is None:
        mask = torch.ones_like(p[..., 0])
    r = torch.nn.functional.normalize(r, dim=-1)

    # For each line, (I - r r^T) projects onto the line's orthogonal
    # complement; masked-out lines contribute nothing to the normal equations.
    eye = torch.eye(dim, device=p.device, dtype=p.dtype)[None, None]
    I_min_cov = (eye - (r[..., None] * r[..., None, :])) * mask[..., None, None]
    sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3)

    p_intersect = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]

    if torch.any(torch.isnan(p_intersect)):
        print(p_intersect)
        return None, None
    return p_intersect, r
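# Usage sketch (illustrative, not from the original call sites): two skew
# lines, batched as (1, n_lines, 3); the least-squares intersection splits
# the gap between them.
#
#   p = torch.tensor([[[0., 0., 0.], [1., 0., 1.]]])  # points on each line
#   r = torch.tensor([[[1., 0., 0.], [0., 1., 0.]]])  # line directions
#   point, _ = intersect_skew_lines_high_dim(p, r)
#   # point is approximately [[1.0, 0.0, 0.5]]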
|
|
|
|
|
def point_line_distance(p1, r1, p2):
    """Distance from each point p2 to the line through p1 with unit
    direction r1, along with the nearest point on the line."""
    df = p2 - p1
    proj_vector = df - ((df * r1).sum(dim=-1, keepdim=True) * r1)
    line_pt_nearest = p2 - proj_vector
    d = (proj_vector).norm(dim=-1)
    return d, line_pt_nearest
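# Usage sketch: r1 must be unit-length for the projection to be valid.
#
#   p1 = torch.tensor([[0., 0., 0.]])   # point on the line
#   r1 = torch.tensor([[1., 0., 0.]])   # unit direction
#   p2 = torch.tensor([[0., 2., 0.]])   # query point
#   d, nearest = point_line_distance(p1, r1, p2)
#   # d = tensor([2.]), nearest = tensor([[0., 0., 0.]])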
|
|
|
|
|
def compute_optical_axis_intersection(cameras, in_ndc=True):
    """Find the least-squares intersection of the cameras' optical axes."""
    centers = cameras.get_camera_center()
    principal_points = cameras.principal_point

    one_vec = torch.ones((len(cameras), 1), device=centers.device)
    optical_axis = torch.cat((principal_points, one_vec), -1)

    # Unproject each camera's principal point (at depth 1) to world space to
    # obtain a second point on its optical axis.
    pp = cameras.unproject_points(optical_axis, from_ndc=in_ndc, world_coordinates=True)
    pp2 = torch.diagonal(pp, dim1=0, dim2=1).T

    directions = pp2 - centers
    centers = centers.unsqueeze(0).unsqueeze(0)
    directions = directions.unsqueeze(0).unsqueeze(0)

    p_intersect, p_line_intersect, _, r = intersect_skew_line_groups(
        p=centers, r=directions, mask=None
    )

    if p_intersect is None:
        dist = None
    else:
        p_intersect = p_intersect.squeeze().unsqueeze(0)
        dist = (p_intersect - centers).norm(dim=-1)

    return p_intersect, dist, p_line_intersect, pp2, r
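# Usage sketch, assuming `cams` is a batched pytorch3d PerspectiveCameras
# (in NDC convention) for one sequence:
#
#   p, dist, _, _, _ = compute_optical_axis_intersection(cams, in_ndc=True)
#   # p: least-squares intersection of the optical axes, shape (1, 3);
#   # dist: per-camera distance from its center to that point.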
|
|
|
|
|
def normalize_cameras_with_up_axis(cameras, sequence_name, scale=1.0, in_ndc=True):
    """
    Normalizes cameras such that their optical axes intersect at the origin
    and the average camera distance to the origin is fixed (1.85), then
    rotates the scene so the estimated world up-axis aligns with +y.

    Args:
        cameras (List[camera]).
    """
    new_cameras = cameras.clone()
    new_transform = new_cameras.get_world_to_view_transform()

    p_intersect, dist, p_line_intersect, pp, r = compute_optical_axis_intersection(
        cameras,
        in_ndc=in_ndc
    )
    t = Translate(p_intersect)

    scale = dist.squeeze().mean()

    if scale == 0:
        print(cameras.T)
        print(new_transform.get_matrix()[:, 3, :3])
        return -1
    assert scale != 0

    new_transform = t.compose(new_transform)
    new_cameras.R = new_transform.get_matrix()[:, :3, :3]
    new_cameras.T = new_transform.get_matrix()[:, 3, :3] / scale * 1.85

    needs_checking = False

    # Estimate the world up-axis: the cameras' x-axes should lie roughly in
    # the ground plane, so the direction of least variance of the centered
    # x-axes is an estimate of the up direction.
    x_vectors = new_cameras.R.transpose(1, 2)[:, 0, :].clone()
    x_vectors -= x_vectors.mean(dim=0, keepdim=True)
    U, S, Vh = torch.linalg.svd(x_vectors)
    V = Vh.mH

    # The smallest singular value should be well separated from the others.
    if S[0] / S[1] > S[1] / S[2]:
        print('Warning: unexpected singular values in sequence {}: {}'.format(sequence_name, S))
        needs_checking = True

    estimated_world_up = V[:, 2:]

    # Check that all cameras agree on which side of the ground plane is "up".
    for camera_idx in range(len(new_cameras.T)):
        if torch.sign(torch.dot(estimated_world_up[:, 0],
                                new_cameras.R[0].transpose(0, 1)[1, :])) != torch.sign(
                torch.dot(estimated_world_up[:, 0],
                          new_cameras.R[camera_idx].transpose(0, 1)[1, :])):
            print("Some cameras appear to be flipped in sequence {}".format(sequence_name))
            needs_checking = True

    # Orient the estimated up-axis to agree with the first camera's y-axis.
    flip = torch.sign(torch.dot(estimated_world_up[:, 0], new_cameras.R[0].transpose(0, 1)[1, :])) < 0
    if flip:
        estimated_world_up = V[:, 2:] * -1

    # Build an orthonormal target basis and rotate all cameras into it.
    target_coordinate_basis = torch.cat([V[:, :1],
                                         estimated_world_up,
                                         torch.linalg.cross(V[:, :1], estimated_world_up, dim=0)],
                                        dim=1)
    new_cameras.R = torch.matmul(target_coordinate_basis.T, new_cameras.R)
    return new_cameras, p_intersect, p_line_intersect, pp, r, needs_checking
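# Usage sketch, assuming `cams` holds the posed cameras of one sequence:
#
#   out = normalize_cameras_with_up_axis(cams, "seq0001")
#   if out != -1:
#       new_cams, p_intersect, _, _, _, needs_checking = out
#       # needs_checking flags sequences whose up-axis estimate looks unreliable.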
|
|
|
|
|
def dot(x, y):
    if isinstance(x, np.ndarray):
        return np.sum(x * y, -1, keepdims=True)
    else:
        return torch.sum(x * y, -1, keepdim=True)
|
|
|
|
|
def length(x, eps=1e-20):
    if isinstance(x, np.ndarray):
        return np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps))
    else:
        return torch.sqrt(torch.clamp(dot(x, x), min=eps))
|
|
|
|
|
def safe_normalize(x, eps=1e-20):
    return x / length(x, eps)
|
|
|
|
|
def look_at(campos, target, opengl=True):
    """Build a 3x3 camera rotation whose columns are (right, up, forward),
    looking from campos towards target with world up = +y."""
    if not opengl:
        # Camera forward points towards the target (looks along +z).
        forward_vector = safe_normalize(target - campos)
        up_vector = np.array([0, 1, 0], dtype=np.float32)
        right_vector = safe_normalize(np.cross(forward_vector, up_vector))
        up_vector = safe_normalize(np.cross(right_vector, forward_vector))
    else:
        # OpenGL convention: the camera looks along -z, so forward points
        # away from the target.
        forward_vector = safe_normalize(campos - target)
        up_vector = np.array([0, 1, 0], dtype=np.float32)
        right_vector = safe_normalize(np.cross(up_vector, forward_vector))
        up_vector = safe_normalize(np.cross(forward_vector, right_vector))
    R = np.stack([right_vector, up_vector, forward_vector], axis=1)
    return R
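# Usage sketch: the returned matrix has orthonormal columns (right, up, forward).
#
#   Rm = look_at(np.array([0., 0., 2.]), np.zeros(3))
#   np.allclose(Rm @ Rm.T, np.eye(3), atol=1e-6)  # True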
|
|
|
|
|
|
|
def orbit_camera(elevation, azimuth, radius=1, is_degree=True, target=None, opengl=True):
    """Build a 4x4 camera-to-world pose orbiting around target.
    Angles are in degrees when is_degree is True; positive elevation maps
    to negative y under this convention."""
    if is_degree:
        elevation = np.deg2rad(elevation)
        azimuth = np.deg2rad(azimuth)
    x = radius * np.cos(elevation) * np.sin(azimuth)
    y = -radius * np.sin(elevation)
    z = radius * np.cos(elevation) * np.cos(azimuth)
    if target is None:
        target = np.zeros([3], dtype=np.float32)
    campos = np.array([x, y, z]) + target
    T = np.eye(4, dtype=np.float32)
    T[:3, :3] = look_at(campos, target, opengl)
    T[:3, 3] = campos
    return T
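# Usage sketch: elevation 30 degrees, azimuth 45 degrees, radius 2.
#
#   pose = orbit_camera(30, 45, radius=2)
#   # pose[:3, 3] is approximately (1.225, -1.0, 1.225); note the negative y
#   # for positive elevation.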
|
|
|
|
|
def mat2latlon(T):
    """Convert a camera position (3-vector) to (elevation, azimuth, radius),
    with angles in degrees."""
    if not isinstance(T, np.ndarray):
        xyz = T.cpu().detach().numpy()
    else:
        xyz = T.copy()
    r = np.linalg.norm(xyz)
    xyz = xyz / r
    theta = -np.arcsin(xyz[1])
    azi = np.arctan2(xyz[0], xyz[2])
    return np.rad2deg(theta), np.rad2deg(azi), r
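# Usage sketch: a camera position on the +z axis has zero elevation/azimuth.
#
#   mat2latlon(np.array([0., 0., 2.]))  # -> (0.0, 0.0, 2.0)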
|
|
|
|
|
def extract_camera_properties(camera_to_world_matrix):
    """Extract the camera position and forward axis from a 4x4
    camera-to-world matrix."""
    camera_position = camera_to_world_matrix[:3, 3]
    forward = camera_to_world_matrix[:3, 2]
    return camera_position, forward
|
|
|
|
|
def compute_angular_error_batch(rotation1, rotation2):
    """Geodesic angle (in degrees) between two batches of rotation matrices."""
    R_rel = np.einsum("Bij,Bjk ->Bik", rotation1.transpose(0, 2, 1), rotation2)
    t = (np.trace(R_rel, axis1=1, axis2=2) - 1) / 2
    theta = np.arccos(np.clip(t, -1, 1))
    return theta * 180 / np.pi
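# Usage sketch: identity vs. a 90-degree rotation about z gives 90 degrees.
#
#   R1 = np.eye(3)[None]
#   R2 = R.from_euler("z", 90, degrees=True).as_matrix()[None]
#   compute_angular_error_batch(R1, R2)  # -> array([90.])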
|
|
|
|
|
def find_mask_center_and_translate(image, mask):
    """
    Calculate the center of mass of the mask and return the translation
    (delta_x, delta_y) that would move the mask center to the image center.

    Args:
    - image (torch.Tensor): Input image tensor of shape (N, C, H, W)
    - mask (torch.Tensor): Mask tensor of shape (N, 1, H, W)

    Returns:
    - torch.Tensor holding (delta_x, delta_y) in pixels.
    """
    _, _, h, w = image.shape

    # Center of mass of the mask.
    y_coords, x_coords = torch.meshgrid(torch.arange(0, h), torch.arange(0, w), indexing='ij')
    total_mass = mask.sum(dim=[2, 3], keepdim=True)
    x_center = (mask * x_coords.to(image.device)).sum(dim=[2, 3], keepdim=True) / total_mass
    y_center = (mask * y_coords.to(image.device)).sum(dim=[2, 3], keepdim=True) / total_mass

    image_center_x, image_center_y = w // 2, h // 2
    delta_x = x_center.squeeze() - image_center_x
    delta_y = y_center.squeeze() - image_center_y

    return torch.stack([delta_x, delta_y])
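# Usage sketch: a 10x10 square centered at (14.5, 44.5) in a 64x64 image.
#
#   img = torch.zeros(1, 3, 64, 64)
#   msk = torch.zeros(1, 1, 64, 64)
#   msk[..., 40:50, 10:20] = 1
#   find_mask_center_and_translate(img, msk)  # -> tensor([-17.5, 12.5])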
|
|
|
|
|
def create_voxel_grid(length, resolution=64):
    """
    Creates a voxel grid spanning [-length, length] along each axis.
    length: Half-extent of the grid along each axis.
    resolution: The number of divisions along each axis.
    Returns a tensor of shape (resolution, resolution, resolution, 4) holding
    the homogeneous coordinates (x, y, z, 1) of each voxel center.
    """
    x = torch.linspace(-length, length, resolution)
    y = torch.linspace(-length, length, resolution)
    z = torch.linspace(-length, length, resolution)

    xx, yy, zz = torch.meshgrid(x, y, z, indexing='ij')
    voxels = torch.stack([xx, yy, zz, torch.ones_like(xx)], dim=-1)
    return voxels
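# Usage sketch: a 4^3 grid spanning [-1, 1] with homogeneous coordinates.
#
#   grid = create_voxel_grid(1.0, resolution=4)  # shape (4, 4, 4, 4)
#   grid[0, 0, 0]                                # tensor([-1., -1., -1., 1.])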
|
|
|
|
|
def project_voxels_to_image(voxels, camera):
    """
    Projects voxel centers into the camera's image plane.
    voxels: 4D tensor of voxel grid in homogeneous coordinates.
    camera: Camera object exposing world_view_transform and projection_matrix.
    Returns a tensor of projected 2D points in image coordinates.
    """
    voxels_flat = voxels.reshape(-1, 4).t()

    # World -> view -> clip space (matrices are stored row-major, hence the
    # transposes).
    projected_voxels = (
        camera.projection_matrix.transpose(0, 1)
        @ camera.world_view_transform.transpose(0, 1)
        @ voxels_flat
    )

    # Perspective divide, then map NDC [-1, 1] to pixel coordinates
    # (assumes a 256x256 image).
    projected_voxels_2d = (projected_voxels[:2, :] / projected_voxels[3, :]).t()
    projected_voxels_2d = (projected_voxels_2d.reshape(*voxels.shape[:-1], 2) + 1.) * 255 * 0.5

    return projected_voxels_2d
|
|
|
|
|
def carve_voxels(voxel_grid, projected_points, mask):
    """
    Updates the voxel grid based on the comparison with the mask.
    voxel_grid: 3D tensor representing the voxel grid (1 = solid).
    projected_points: Projected 2D points in image coordinates.
    mask: Binary mask image.
    """
    # Clamp projections to valid pixel indices.
    indices_x = torch.clamp(projected_points[..., 0], 0, mask.shape[1] - 1).long()
    indices_y = torch.clamp(projected_points[..., 1], 0, mask.shape[0] - 1).long()

    # Look up the mask value at each projected voxel.
    in_object = mask[indices_y, indices_x]

    # Voxels that project outside the mask are carved away.
    voxel_grid[in_object == 0] = 0
|
|
|
|
|
def sample_points_from_voxel(cameras, masks, length=1, resolution=64, N=5000, inverse=False, device="cuda"):
    """
    Carve a voxel grid with the given cameras and masks, then randomly sample
    N points from the solid region (or the carved-away region if inverse=True).

    Args:
    - cameras: Cameras used to project the voxel grid into each mask.
    - masks (np.ndarray): Binary masks, one per camera.
    - length (float): Half-extent of the voxel grid along each axis.
    - resolution (int): Number of voxels along each axis.
    - N (int): The number of points to sample.
    - inverse (bool): If True, sample from empty voxels instead of solid ones.

    Returns:
    - sampled_points (torch.Tensor): A tensor of shape (N, 3) representing the sampled 3D coordinates.
    """
    voxel_grid = create_voxel_grid(length, resolution).to(device)
    voxel_grid_indicator = torch.ones(resolution, resolution, resolution)

    masks = torch.from_numpy(masks).to(device).squeeze()

    for idx, cam in enumerate(cameras):
        projected_points = project_voxels_to_image(voxel_grid, cam)
        carve_voxels(voxel_grid_indicator, projected_points, masks[idx])

    voxel_grid_indicator = voxel_grid_indicator.reshape(resolution, resolution, resolution)

    if inverse:
        solid_indices = torch.nonzero(voxel_grid_indicator == 0)
    else:
        solid_indices = torch.nonzero(voxel_grid_indicator == 1)

    if N <= solid_indices.size(0):
        # Enough candidates: sample without replacement.
        sampled_indices = solid_indices[torch.randperm(solid_indices.size(0))[:N]]
    else:
        # Not enough candidates: sample with replacement.
        sampled_indices = solid_indices[torch.randint(0, solid_indices.size(0), (N,))]

    # Map voxel indices back to world coordinates in [-length, length].
    sampled_points = sampled_indices.float() / (voxel_grid.size(0) - 1) * 2 * length - length

    return sampled_points
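# Usage sketch, assuming `cams` expose world_view_transform / projection_matrix
# (see project_voxels_to_image) and `masks` is an (n_views, H, W) binary array:
#
#   pts = sample_points_from_voxel(cams, masks, length=0.5, resolution=64, N=5000)
#   # pts: (5000, 3) points inside the visual hull carved by the masks.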
|
|
|
|
|
class OrbitCamera:
    def __init__(self, W, H, r=2, fovy=60, near=0.01, far=100):
        self.W = W
        self.H = H
        self.radius = r  # camera distance from the look-at center
        self.fovy = np.deg2rad(fovy)
        self.near = near
        self.far = far
        self.center = np.array([0, 0, 0], dtype=np.float32)  # look-at target
        self.rot = R.from_matrix(np.eye(3))
        self.up = np.array([0, 1, 0], dtype=np.float32)

    @property
    def fovx(self):
        return 2 * np.arctan(np.tan(self.fovy / 2) * self.W / self.H)

    @property
    def campos(self):
        return self.pose[:3, 3]

    @property
    def pose(self):
        # Camera-to-world: start radius units along +z, apply the orbit
        # rotation, then shift by the look-at center.
        res = np.eye(4, dtype=np.float32)
        res[2, 3] = self.radius

        rot = np.eye(4, dtype=np.float32)
        rot[:3, :3] = self.rot.as_matrix()
        res = rot @ res

        res[:3, 3] -= self.center
        return res

    @property
    def view(self):
        # World-to-camera.
        return np.linalg.inv(self.pose)

    @property
    def perspective(self):
        # Note the negative entry at [1, 1]: this projection flips y.
        y = np.tan(self.fovy / 2)
        aspect = self.W / self.H
        return np.array(
            [
                [1 / (y * aspect), 0, 0, 0],
                [0, -1 / y, 0, 0],
                [
                    0,
                    0,
                    -(self.far + self.near) / (self.far - self.near),
                    -(2 * self.far * self.near) / (self.far - self.near),
                ],
                [0, 0, -1, 0],
            ],
            dtype=np.float32,
        )

    @property
    def intrinsics(self):
        # (fx, fy, cx, cy) in pixels.
        focal = self.H / (2 * np.tan(self.fovy / 2))
        return np.array([focal, focal, self.W // 2, self.H // 2], dtype=np.float32)

    @property
    def mvp(self):
        return self.perspective @ np.linalg.inv(self.pose)

    def orbit(self, dx, dy):
        # Rotate about the world up axis (dx) and the camera side axis (dy).
        side = self.rot.as_matrix()[:3, 0]
        rotvec_x = self.up * np.radians(-0.05 * dx)
        rotvec_y = side * np.radians(-0.05 * dy)
        self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot

    def scale(self, delta):
        self.radius *= 1.1 ** (-delta)

    def pan(self, dx, dy, dz=0):
        # Pan in the camera plane.
        self.center += 0.0005 * self.rot.as_matrix() @ np.array([-dx, -dy, dz])
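
# Usage sketch: a simple interactive orbit around the origin.
#
#   cam = OrbitCamera(800, 600, r=3, fovy=50)
#   cam.orbit(100, 0)   # drag right by 100 pixels
#   cam.scale(1)        # zoom in one step
#   mvp = cam.mvp       # 4x4 perspective @ world-to-camera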