AetherV1 / aether /utils /visualize_utils.py
Wenzheng Chang
aetherv1 init
19da45c
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Adapted from https://github.com/facebookresearch/vggt/blob/main/visual_util.py
import matplotlib
import numpy as np
import trimesh
from scipy.spatial.transform import Rotation
from aether.utils.postprocess_utils import depth_edge
def predictions_to_glb(
    predictions,
    filter_by_frames="all",
    show_cam=True,
    max_depth=100.0,
    rtol=0.03,
    frame_rel_idx: float = 0.0,
) -> trimesh.Scene:
    """
    Converts predictions to a 3D scene that can be exported as a GLB file.

    Args:
        predictions (dict): Dictionary containing model predictions with keys:
            - world_points: 3D point coordinates (S, H, W, 3)
            - images: Input images, (S, H, W, 3) or (S, 3, H, W). Values are
              multiplied by 255 and cast to uint8, so presumably in [0, 1].
            - depths: Depths (S, H, W)
            - camera_poses: Per-frame poses (S, 4, 4), used as the
              camera-to-world transform of each camera marker.
        filter_by_frames (str): "all" (any letter case) keeps every frame.
            Otherwise the text before the first ":" is parsed as a frame index
            (e.g. "3: frame_003.png" selects frame 3). Specs that cannot be
            parsed fall back to showing all frames. (default: "all")
        show_cam (bool): Include camera visualization (default: True)
        max_depth (float): Points whose depth is >= max_depth are dropped
            (default: 100.0)
        rtol (float): Relative tolerance for depth edge detection
            (default: 0.03)
        frame_rel_idx (float): Value looked up in the "gist_rainbow" colormap
            to color every camera marker (default: 0.0)

    Returns:
        trimesh.Scene: Processed 3D scene containing point cloud and cameras

    Raises:
        ValueError: If ``predictions`` is not a dictionary.
        KeyError: If one of the required prediction keys is missing.
    """
    if not isinstance(predictions, dict):
        raise ValueError("predictions must be a dictionary")

    selected_frame_idx = None
    filter_spec = str(filter_by_frames)
    if filter_spec.lower() != "all":
        try:
            # Extract the index part before the colon, e.g. "3: name" -> 3.
            selected_frame_idx = int(filter_spec.split(":")[0])
        except ValueError:
            # Deliberate best effort: an unparseable spec shows all frames.
            selected_frame_idx = None

    pred_world_points = predictions["world_points"]
    images = predictions["images"]
    depths = predictions["depths"]
    camera_poses = predictions["camera_poses"]

    if selected_frame_idx is not None:
        # Slice every per-frame array, keeping a leading frame axis of size 1.
        # The depths must be sliced too: otherwise the depth mask below has
        # S*H*W entries while the selected points only have H*W.
        pred_world_points = pred_world_points[selected_frame_idx][None]
        images = images[selected_frame_idx][None]
        depths = depths[selected_frame_idx][None]
        camera_poses = camera_poses[selected_frame_idx][None]

    vertices_3d = pred_world_points.reshape(-1, 3)

    # Handle different image formats: transpose NCHW to NHWC when needed.
    if images.ndim == 4 and images.shape[1] == 3:
        colors_rgb = np.transpose(images, (0, 2, 3, 1))
    else:  # Assume already in NHWC format
        colors_rgb = images
    colors_rgb = (colors_rgb.reshape(-1, 3) * 255).astype(np.uint8)

    # Keep points that are closer than max_depth and not on a depth edge.
    valid = depths < max_depth
    not_edge = ~depth_edge(depths, rtol=rtol, mask=valid)
    valid = (valid & not_edge).reshape(-1)
    vertices_3d = vertices_3d[valid]
    colors_rgb = colors_rgb[valid]

    if np.asarray(vertices_3d).size == 0:
        # Nothing survived filtering: emit a single placeholder point.
        vertices_3d = np.array([[1, 0, 0]])
        colors_rgb = np.array([[255, 255, 255]])
        scene_scale = 1.0
    else:
        # The diagonal of the 5th-95th percentile bounding box is a
        # scene size that ignores outliers. It sizes the camera markers.
        lower_percentile = np.percentile(vertices_3d, 5, axis=0)
        upper_percentile = np.percentile(vertices_3d, 95, axis=0)
        scene_scale = float(np.linalg.norm(upper_percentile - lower_percentile))
        if not np.isfinite(scene_scale) or scene_scale <= 0:
            # Degenerate extent (e.g. all points coincide) would yield
            # zero-sized camera meshes; fall back to unit scale.
            scene_scale = 1.0

    scene_3d = trimesh.Scene()
    scene_3d.add_geometry(
        trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb)
    )

    if show_cam:
        # All cameras share one color, so look it up once.
        colormap = matplotlib.colormaps.get_cmap("gist_rainbow")
        rgba_color = colormap(frame_rel_idx)
        camera_color = tuple(int(255 * x) for x in rgba_color[:3])
        for camera_to_world in camera_poses:
            integrate_camera_into_scene(
                scene_3d, camera_to_world, camera_color, scene_scale
            )

    return scene_3d
def integrate_camera_into_scene(
    scene: trimesh.Scene,
    transform: np.ndarray,
    face_colors: tuple,
    scene_scale: float,
):
    """
    Add a pyramid-shaped camera marker mesh to ``scene``.

    The marker starts from a 4-sided trimesh cone whose size is proportional
    to ``scene_scale``. Two extra copies of the cone's vertices are stacked
    after the originals: one scaled by 0.95 and one rotated 2 degrees about z.
    ``compute_camera_faces`` then connects the three vertex sets with thin
    triangles.

    Args:
        scene (trimesh.Scene): Scene that receives the camera mesh.
        transform (np.ndarray): Transform applied to the marker. Callers pass
            a camera-to-world pose; presumably 4x4.
        face_colors (tuple): RGB color written into every face of the marker.
        scene_scale (float): Scene size. The cone radius is 2.5% and its
            height is 5% of this value.
    """
    radius = 0.025 * scene_scale
    height = 0.05 * scene_scale

    # Local placement of the cone: rotate 45 degrees about z, then shift by
    # -height along z.
    cone_offset = np.eye(4)
    cone_offset[:3, :3] = Rotation.from_euler("z", 45, degrees=True).as_matrix()
    cone_offset[2, 3] = -height

    # Same multiplication order as before: (transform @ opengl) @ offset.
    full_transform = transform @ get_opengl_conversion_matrix() @ cone_offset

    cone = trimesh.creation.cone(radius, height, sections=4)
    base_vertices = cone.vertices

    tilt = np.eye(4)
    tilt[:3, :3] = Rotation.from_euler("z", 2, degrees=True).as_matrix()

    # Layout: [original | shrunk | tilted], each len(base_vertices) long.
    stacked = np.vstack(
        (
            base_vertices,
            base_vertices * 0.95,
            transform_points(tilt, base_vertices),
        )
    )
    world_vertices = transform_points(full_transform, stacked)

    marker = trimesh.Trimesh(
        vertices=world_vertices, faces=compute_camera_faces(cone)
    )
    marker.visual.face_colors[:, :3] = face_colors
    scene.add_geometry(marker)
def get_opengl_conversion_matrix() -> np.ndarray:
    """
    Return a new 4x4 float matrix diag(1, -1, -1, 1).

    Multiplying by this matrix negates the y and z axes. This is the usual
    change between OpenCV-style and OpenGL-style camera axis conventions.

    Returns:
        numpy.ndarray: A freshly allocated 4x4 OpenGL conversion matrix.
    """
    return np.diag([1.0, -1.0, -1.0, 1.0])
def transform_points(
    transformation: np.ndarray, points: np.ndarray, dim: int = None
) -> np.ndarray:
    """
    Apply a homogeneous transform to 3D points stored as row vectors.

    For each point ``p`` the result is ``R @ p + t``, where ``R`` is the
    upper-left 3x3 block and ``t`` the first three entries of the last column
    of ``transformation``. No perspective division is performed. Assumes
    ``transformation`` is (..., 4, 4) and ``points`` is (..., 3); broadcasting
    follows NumPy matmul rules.

    Args:
        transformation (np.ndarray): Transformation matrix.
        points (np.ndarray): Points to be transformed (array-like accepted).
        dim (int, optional): Number of output coordinates to keep. Defaults to
            the points' last dimension. With ``dim=4`` the extra column is the
            homogeneous coordinate.

    Returns:
        np.ndarray: Transformed points with shape ``points.shape[:-1] + (dim,)``.
    """
    pts = np.asarray(points)
    lead_shape = pts.shape[:-1]
    out_dim = dim if dim else pts.shape[-1]

    # Row-vector form: p @ T^T == (T @ p)^T.
    mat_t = transformation.swapaxes(-1, -2)
    rotated = pts @ mat_t[..., :-1, :]
    homogeneous = rotated + mat_t[..., -1:, :]

    return homogeneous[..., :out_dim].reshape(lead_shape + (out_dim,))
def compute_camera_faces(cone_shape: trimesh.Trimesh) -> np.ndarray:
"""
Computes the faces for the camera mesh.
Args:
cone_shape (trimesh.Trimesh): The shape of the camera cone.
Returns:
np.ndarray: Array of faces for the camera mesh.
"""
# Create pseudo cameras
faces_list = []
num_vertices_cone = len(cone_shape.vertices)
for face in cone_shape.faces:
if 0 in face:
continue
v1, v2, v3 = face
v1_offset, v2_offset, v3_offset = face + num_vertices_cone
v1_offset_2, v2_offset_2, v3_offset_2 = face + 2 * num_vertices_cone
faces_list.extend(
[
(v1, v2, v2_offset),
(v1, v1_offset, v3),
(v3_offset, v2, v3),
(v1, v2, v2_offset_2),
(v1, v1_offset_2, v3),
(v3_offset_2, v2, v3),
]
)
faces_list += [(v3, v2, v1) for v1, v2, v3 in faces_list]
return np.array(faces_list)