# AetherV1: aether/utils/postprocess_utils.py
from __future__ import annotations
from typing import Optional
import matplotlib
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from plyfile import PlyData, PlyElement
def signed_log1p_inverse(x):
    """
    Computes the inverse of signed_log1p: given y = sign(x) * log1p(|x|),
    recovers x = sign(y) * (exp(|y|) - 1).
    Args:
        x (torch.Tensor or numpy.ndarray): Input tensor (output of signed_log1p).
    Returns:
        torch.Tensor or numpy.ndarray: Original tensor x.
    """
if isinstance(x, torch.Tensor):
return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)
elif isinstance(x, np.ndarray):
return np.sign(x) * (np.exp(np.abs(x)) - 1)
else:
raise TypeError("Input must be a torch.Tensor or numpy.ndarray")
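# A quick sanity check (the forward transform is assumed from the name; it is
# not defined in this file): y = sign(x) * log1p(|x|), so
#   signed_log1p_inverse(torch.tensor([1.0]))  ->  tensor([1.7183])  (= e - 1)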
def colorize_depth(depth, cmap="Spectral"):
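    # Normalize depth to [0, 1] using only valid (> 0) pixels, inverted so that
    # nearer points map to larger values, then apply a matplotlib colormap.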
min_d, max_d = (depth[depth > 0]).min(), (depth[depth > 0]).max()
depth = (max_d - depth) / (max_d - min_d)
cm = matplotlib.colormaps[cmap]
depth = depth.clip(0, 1)
depth = cm(depth, bytes=False)[..., 0:3]
return depth
def save_ply(pointmap, image, output_file, downsample=20, mask=None):
_, h, w, _ = pointmap.shape
image = image[:, :h, :w]
pointmap = pointmap[:, :h, :w]
    points = pointmap.reshape(-1, 3)  # (T*H*W, 3)
    colors = image.reshape(-1, 3)  # (T*H*W, 3)
if mask is not None:
points = points[mask.reshape(-1)]
colors = colors[mask.reshape(-1)]
indices = np.random.choice(
colors.shape[0], int(colors.shape[0] / downsample), replace=False
)
points = points[indices]
colors = colors[indices]
vertices = []
for p, c in zip(points, colors):
vertex = (p[0], p[1], p[2], int(c[0]), int(c[1]), int(c[2]))
vertices.append(vertex)
vertex_dtype = np.dtype(
[
("x", "f4"),
("y", "f4"),
("z", "f4"),
("red", "u1"),
("green", "u1"),
("blue", "u1"),
]
)
vertex_array = np.array(vertices, dtype=vertex_dtype)
ply_element = PlyElement.describe(vertex_array, "vertex")
PlyData([ply_element], text=True).write(output_file)
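# Usage sketch (shapes assumed from the reshape logic above; the path is
# illustrative): pointmap (T, H, W, 3) float, image (T, H, W, 3) in [0, 255],
# downsample=20 keeps roughly 1/20 of the points.
#   save_ply(pointmap, image, "scene.ply", downsample=20)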
def fov_to_focal(fovx, fovy, h, w):
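    # NOTE: fovx/fovy are half-angles, matching raymap_to_poses below, which
    # recovers them as arctan(size / (2 * focal)).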
focal_x = w * 0.5 / np.tan(fovx)
focal_y = h * 0.5 / np.tan(fovy)
focal = (focal_x + focal_y) / 2
return focal
def get_rays(pose, h, w, focal=None, fovx=None, fovy=None):
pose = torch.from_numpy(pose).float()
x, y = torch.meshgrid(
torch.arange(w),
torch.arange(h),
indexing="xy",
)
x = x.flatten().unsqueeze(0).repeat(pose.shape[0], 1)
y = y.flatten().unsqueeze(0).repeat(pose.shape[0], 1)
cx = w * 0.5
cy = h * 0.5
intrinsics, focal = get_intrinsics(pose.shape[0], h, w, fovx, fovy, focal)
    focal = torch.as_tensor(focal).float()  # accepts numpy arrays and scalars alike
camera_dirs = F.pad(
torch.stack(
[
(x - cx + 0.5) / focal.unsqueeze(-1),
(y - cy + 0.5) / focal.unsqueeze(-1),
],
dim=-1,
),
(0, 1),
value=1.0,
) # [t, hw, 3]
pose = pose.to(dtype=camera_dirs.dtype)
rays_d = camera_dirs @ pose[:, :3, :3].transpose(1, 2) # [t, hw, 3]
    rays_o = pose[:, :3, 3].unsqueeze(1).expand_as(rays_d)  # [t, hw, 3]
rays_o = rays_o.view(pose.shape[0], h, w, 3)
rays_d = rays_d.view(pose.shape[0], h, w, 3)
return rays_o.float().numpy(), rays_d.float().numpy(), intrinsics
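# get_rays treats `pose` as camera-to-world (origins from the translation
# column, directions rotated by the rotation block) and returns world-space
# ray origins/directions of shape (T, H, W, 3) plus (T, 3, 3) intrinsics.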
def get_intrinsics(batch_size, h, w, fovx=None, fovy=None, focal=None):
if focal is None:
focal_x = w * 0.5 / np.tan(fovx)
focal_y = h * 0.5 / np.tan(fovy)
focal = (focal_x + focal_y) / 2
cx = w * 0.5
cy = h * 0.5
intrinsics = np.zeros((batch_size, 3, 3))
intrinsics[:, 0, 0] = focal
intrinsics[:, 1, 1] = focal
intrinsics[:, 0, 2] = cx
intrinsics[:, 1, 2] = cy
intrinsics[:, 2, 2] = 1.0
return intrinsics, focal
def save_pointmap(
rgb,
disparity,
raymap,
save_file,
vae_downsample_scale=8,
camera_pose=None,
ray_o_scale_inv=1.0,
max_depth=1e2,
save_full_pcd_videos=False,
smooth_camera=False,
smooth_method="kalman", # or simple
**kwargs,
):
"""
Args:
rgb (numpy.ndarray): Shape of (t, h, w, 3), range [0, 1]
disparity (numpy.ndarray): Shape of (t, h, w), range [0, 1]
raymap (numpy.ndarray): Shape of (t, 6, h // 8, w // 8)
        ray_o_scale_inv (float, optional): A `ray_o` scale constant. Defaults to 1.0.
"""
rgb = np.clip(rgb, 0, 1) * 255
pointmap_dict = postprocess_pointmap(
disparity,
raymap,
vae_downsample_scale,
camera_pose,
ray_o_scale_inv=ray_o_scale_inv,
smooth_camera=smooth_camera,
smooth_method=smooth_method,
**kwargs,
)
save_ply(
pointmap_dict["pointmap"],
rgb,
save_file,
mask=(pointmap_dict["depth"] < max_depth),
)
if save_full_pcd_videos:
pcd_dict = {
"points": pointmap_dict["pointmap"],
"colors": rgb,
"intrinsics": pointmap_dict["intrinsics"],
"poses": pointmap_dict["camera_pose"],
"depths": pointmap_dict["depth"],
}
np.save(save_file.replace(".ply", "_pcd.npy"), pcd_dict)
return pointmap_dict
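# Usage sketch (shapes from the docstring above; the .ply path is illustrative):
#   save_pointmap(rgb, disparity, raymap, "scene.ply", save_full_pcd_videos=True)
# also writes "scene_pcd.npy" with points/colors/intrinsics/poses/depths.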
def raymap_to_poses(
raymap, camera_pose=None, ray_o_scale_inv=1.0, return_intrinsics=True
):
ts = raymap.shape[0]
if (not return_intrinsics) and (camera_pose is not None):
return camera_pose, None, None
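    # NOTE: the signed-log1p decoding below modifies `raymap` in place; pass a
    # copy if the caller needs the encoded values afterwards.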
raymap[:, 3:] = signed_log1p_inverse(raymap[:, 3:])
# Extract ray origins and directions
ray_o = (
rearrange(raymap[:, 3:], "t c h w -> t h w c") * ray_o_scale_inv
) # [T, H, W, C]
ray_d = rearrange(raymap[:, :3], "t c h w -> t h w c") # [T, H, W, C]
# Compute orientation and directions
orient = ray_o.reshape(ts, -1, 3).mean(axis=1) # T, 3
image_orient = (ray_o + ray_d).reshape(ts, -1, 3).mean(axis=1) # T, 3
Focal = np.linalg.norm(image_orient - orient, axis=-1) # T,
Z_Dir = image_orient - orient # T, 3
# Compute the width (W) and field of view (FoV_x)
W_Left = ray_d[:, :, :1, :].reshape(ts, -1, 3).mean(axis=1)
W_Right = ray_d[:, :, -1:, :].reshape(ts, -1, 3).mean(axis=1)
W = W_Right - W_Left
W_real = (
np.linalg.norm(np.cross(W, Z_Dir), axis=-1)
/ (raymap.shape[-1] - 1)
* raymap.shape[-1]
)
Fov_x = np.arctan(W_real / (2 * Focal))
# Compute the height (H) and field of view (FoV_y)
H_Up = ray_d[:, :1, :, :].reshape(ts, -1, 3).mean(axis=1)
H_Down = ray_d[:, -1:, :, :].reshape(ts, -1, 3).mean(axis=1)
H = H_Up - H_Down
H_real = (
np.linalg.norm(np.cross(H, Z_Dir), axis=-1)
/ (raymap.shape[-2] - 1)
* raymap.shape[-2]
)
Fov_y = np.arctan(H_real / (2 * Focal))
# Compute X, Y, and Z directions for the camera
X_Dir = W_Right - W_Left
Y_Dir = np.cross(Z_Dir, X_Dir)
X_Dir = np.cross(Y_Dir, Z_Dir)
X_Dir /= np.linalg.norm(X_Dir, axis=-1, keepdims=True)
Y_Dir /= np.linalg.norm(Y_Dir, axis=-1, keepdims=True)
Z_Dir /= np.linalg.norm(Z_Dir, axis=-1, keepdims=True)
# Create the camera-to-world (camera_pose) transformation matrix
if camera_pose is None:
camera_pose = np.zeros((ts, 4, 4))
camera_pose[:, :3, 0] = X_Dir
camera_pose[:, :3, 1] = Y_Dir
camera_pose[:, :3, 2] = Z_Dir
camera_pose[:, :3, 3] = orient
camera_pose[:, 3, 3] = 1.0
return camera_pose, Fov_x, Fov_y
def postprocess_pointmap(
disparity,
raymap,
vae_downsample_scale=8,
camera_pose=None,
focal=None,
ray_o_scale_inv=1.0,
smooth_camera=False,
smooth_method="simple",
**kwargs,
):
"""
Args:
disparity (numpy.ndarray): Shape of (t, h, w), range [0, 1]
raymap (numpy.ndarray): Shape of (t, 6, h // 8, w // 8)
        ray_o_scale_inv (float, optional): A `ray_o` scale constant. Defaults to 1.0.
"""
depth = np.clip(1.0 / np.clip(disparity, 1e-3, 1), 0, 1e8)
camera_pose, fov_x, fov_y = raymap_to_poses(
raymap,
camera_pose=camera_pose,
ray_o_scale_inv=ray_o_scale_inv,
        return_intrinsics=(focal is None),  # FOVs are only needed when focal must be derived
)
if focal is None:
focal = fov_to_focal(
fov_x,
fov_y,
int(raymap.shape[2] * vae_downsample_scale),
int(raymap.shape[3] * vae_downsample_scale),
)
if smooth_camera:
# Check if sequence is static
is_static, trans_diff, rot_diff = detect_static_sequence(camera_pose)
if is_static:
print(
f"Detected static/near-static sequence (trans_diff={trans_diff:.6f}, rot_diff={rot_diff:.6f})"
)
# Apply stronger smoothing for static sequences
camera_pose = adaptive_pose_smoothing(camera_pose, trans_diff, rot_diff)
else:
if smooth_method == "simple":
camera_pose = smooth_poses(
camera_pose, window_size=5, method="gaussian"
)
elif smooth_method == "kalman":
camera_pose = smooth_trajectory(camera_pose, window_size=5)
ray_o, ray_d, intrinsics = get_rays(
camera_pose,
int(raymap.shape[2] * vae_downsample_scale),
int(raymap.shape[3] * vae_downsample_scale),
focal,
)
pointmap = depth[..., None] * ray_d + ray_o
return {
"pointmap": pointmap,
"camera_pose": camera_pose,
"intrinsics": intrinsics,
"ray_o": ray_o,
"ray_d": ray_d,
"depth": depth,
}
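# Pipeline sketch (a summary of the function above, not extra behavior):
#   disparity -> depth -> poses/FOVs from the raymap -> per-pixel rays ->
#   pointmap = depth * ray_d + ray_o  (world-space points per frame).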
def detect_static_sequence(poses, threshold=0.01):
"""Detect if the camera sequence is static based on pose differences."""
translations = poses[:, :3, 3]
rotations = poses[:, :3, :3]
# Compute translation differences
trans_diff = np.linalg.norm(translations[1:] - translations[:-1], axis=1).mean()
# Compute rotation differences (using matrix frobenius norm)
rot_diff = np.linalg.norm(rotations[1:] - rotations[:-1], axis=(1, 2)).mean()
return trans_diff < threshold and rot_diff < threshold, trans_diff, rot_diff
def adaptive_pose_smoothing(poses, trans_diff, rot_diff, base_window=5):
"""Apply adaptive smoothing based on motion magnitude."""
# Increase window size for low motion sequences
motion_magnitude = trans_diff + rot_diff
    adaptive_window = min(
        41, max(base_window, int(base_window * (0.1 / max(motion_magnitude, 1e-6))))
    )
    # smooth_poses asserts an odd window, so round even values up
    if adaptive_window % 2 == 0:
        adaptive_window += 1
# Apply stronger smoothing for low motion
poses_smooth = smooth_poses(poses, window_size=adaptive_window, method="gaussian")
return poses_smooth
def get_pixel(H, W):
# get 2D pixels (u, v) for image_a in cam_a pixel space
u_a, v_a = np.meshgrid(np.arange(W), np.arange(H))
# u_a = np.flip(u_a, axis=1)
# v_a = np.flip(v_a, axis=0)
pixels_a = np.stack(
[u_a.flatten() + 0.5, v_a.flatten() + 0.5, np.ones_like(u_a.flatten())], axis=0
)
return pixels_a
def project(depth, intrinsic, pose):
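    # Unproject a depth map to world-space points: X = pose[:3, :4] @ [K^-1 @ u * d; 1],
    # where u are the homogeneous pixel centers from get_pixel.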
H, W = depth.shape
pixel = get_pixel(H, W).astype(np.float32)
points = (np.linalg.inv(intrinsic) @ pixel) * depth.reshape(-1)
points = pose[:3, :4] @ np.concatenate(
[points, np.ones((1, points.shape[1]))], axis=0
)
points = points.T.reshape(H, W, 3)
return points
def depth_edge(
    depth: torch.Tensor,
    atol: Optional[float] = None,
    rtol: Optional[float] = None,
    kernel_size: int = 3,
    mask: Optional[torch.Tensor] = None,
) -> torch.BoolTensor:
    """
    Compute the edge mask of a depth map. An edge is a pixel whose neighborhood
    (within `kernel_size`) spans a large difference in depth.
    Args:
        depth (torch.Tensor): shape (..., height, width), linear depth map
        atol (float, optional): absolute tolerance on the local depth range
        rtol (float, optional): relative tolerance on the local depth range
        kernel_size (int): size of the square neighborhood
        mask (torch.Tensor, optional): shape (..., height, width), valid-pixel mask
    Returns:
        edge (torch.Tensor): shape (..., height, width) of dtype torch.bool
    """
is_numpy = isinstance(depth, np.ndarray)
if is_numpy:
depth = torch.from_numpy(depth)
if isinstance(mask, np.ndarray):
mask = torch.from_numpy(mask)
shape = depth.shape
depth = depth.reshape(-1, 1, *shape[-2:])
if mask is not None:
mask = mask.reshape(-1, 1, *shape[-2:])
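    # max_pool(d) + max_pool(-d) equals the local (max - min) depth range within
    # each kernel_size window; masked-out pixels are sent to -inf so they never win.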
if mask is None:
diff = F.max_pool2d(
depth, kernel_size, stride=1, padding=kernel_size // 2
) + F.max_pool2d(-depth, kernel_size, stride=1, padding=kernel_size // 2)
else:
diff = F.max_pool2d(
torch.where(mask, depth, -torch.inf),
kernel_size,
stride=1,
padding=kernel_size // 2,
) + F.max_pool2d(
torch.where(mask, -depth, -torch.inf),
kernel_size,
stride=1,
padding=kernel_size // 2,
)
edge = torch.zeros_like(depth, dtype=torch.bool)
if atol is not None:
edge |= diff > atol
if rtol is not None:
edge |= (diff / depth).nan_to_num_() > rtol
edge = edge.reshape(*shape)
if is_numpy:
return edge.numpy()
return edge
@torch.jit.script
def align_rigid(
    p: torch.Tensor,  # (batch, n, 3) source points
    q: torch.Tensor,  # (batch, n, 3) target points
    weights: torch.Tensor,  # (batch, n) per-point weights
):
"""Compute a rigid transformation that, when applied to p, minimizes the weighted
squared distance between transformed points in p and points in q. See "Least-Squares
Rigid Motion Using SVD" by Olga Sorkine-Hornung and Michael Rabinovich for more
details (https://igl.ethz.ch/projects/ARAP/svd_rot.pdf).
"""
device = p.device
dtype = p.dtype
batch, _, _ = p.shape
# 1. Compute the centroids of both point sets.
weights_normalized = weights / (weights.sum(dim=-1, keepdim=True) + 1e-8)
p_centroid = (weights_normalized[..., None] * p).sum(dim=-2)
q_centroid = (weights_normalized[..., None] * q).sum(dim=-2)
# 2. Compute the centered vectors.
p_centered = p - p_centroid[..., None, :]
q_centered = q - q_centroid[..., None, :]
# 3. Compute the 3x3 covariance matrix.
covariance = (q_centered * weights[..., None]).transpose(-1, -2) @ p_centered
# 4. Compute the singular value decomposition and then the rotation.
u, _, vt = torch.linalg.svd(covariance)
s = torch.eye(3, dtype=dtype, device=device)
s = s.expand((batch, 3, 3)).contiguous()
s[..., 2, 2] = (u.det() * vt.det()).sign()
rotation = u @ s @ vt
# 5. Compute the optimal scale
scale = (
(torch.einsum("b i j, b k j -> b k i", rotation, p_centered) * q_centered).sum(
-1
)
* weights
).sum(-1) / ((p_centered**2).sum(-1) * weights).sum(-1)
# scale = (torch.einsum("b i j, b k j -> b k i", rotation, p_centered) * q_centered).sum([-1, -2]) / (p_centered**2).sum([-1, -2])
# 6. Compute the optimal translation.
translation = q_centroid - torch.einsum(
"b i j, b j -> b i", rotation, p_centroid * scale[:, None]
)
return rotation, translation, scale
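# Usage sketch (an assumption from the math above, not a documented API):
#   R, t, s = align_rigid(p, q, w)   # p, q: (B, N, 3); w: (B, N)
#   p_aligned = s[:, None, None] * torch.einsum("bij,bnj->bni", R, p) + t[:, None]
# minimizes the weighted squared distance between p_aligned and q.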
def align_camera_extrinsics(
cameras_src: torch.Tensor, # Bx3x4 tensor representing [R | t]
cameras_tgt: torch.Tensor, # Bx3x4 tensor representing [R | t]
estimate_scale: bool = True,
eps: float = 1e-9,
):
"""
Align the source camera extrinsics to the target camera extrinsics.
NOTE Assume OPENCV convention
Args:
cameras_src (torch.Tensor): Bx3x4 tensor representing [R | t] for source cameras.
cameras_tgt (torch.Tensor): Bx3x4 tensor representing [R | t] for target cameras.
estimate_scale (bool, optional): Whether to estimate the scale factor. Default is True.
eps (float, optional): Small value to avoid division by zero. Default is 1e-9.
Returns:
align_t_R (torch.Tensor): 1x3x3 rotation matrix for alignment.
align_t_T (torch.Tensor): 1x3 translation vector for alignment.
align_t_s (float): Scaling factor for alignment.
"""
R_src = cameras_src[:, :, :3] # Extracting the rotation matrices from [R | t]
R_tgt = cameras_tgt[:, :, :3] # Extracting the rotation matrices from [R | t]
RRcov = torch.bmm(R_tgt.transpose(2, 1), R_src).mean(0)
    # torch.svd is deprecated; torch.linalg.svd returns V^H, so transpose back
    U, _, Vh = torch.linalg.svd(RRcov)
    align_t_R = Vh.t() @ U.t()
T_src = cameras_src[:, :, 3] # Extracting the translation vectors from [R | t]
T_tgt = cameras_tgt[:, :, 3] # Extracting the translation vectors from [R | t]
A = torch.bmm(T_src[:, None], R_src)[:, 0]
B = torch.bmm(T_tgt[:, None], R_src)[:, 0]
Amu = A.mean(0, keepdim=True)
Bmu = B.mean(0, keepdim=True)
if estimate_scale and A.shape[0] > 1:
# get the scaling component by matching covariances
# of centered A and centered B
Ac = A - Amu
Bc = B - Bmu
align_t_s = (Ac * Bc).mean() / (Ac**2).mean().clamp(eps)
else:
# set the scale to identity
align_t_s = 1.0
# get the translation as the difference between the means of A and B
align_t_T = Bmu - align_t_s * Amu
align_t_R = align_t_R[None]
return align_t_R, align_t_T, align_t_s
def apply_transformation(
cameras_src: torch.Tensor, # Bx3x4 tensor representing [R | t]
align_t_R: torch.Tensor, # 1x3x3 rotation matrix
align_t_T: torch.Tensor, # 1x3 translation vector
align_t_s: float, # Scaling factor
return_extri: bool = True,
):
"""
Align and transform the source cameras using the provided rotation, translation, and scaling factors.
NOTE Assume OPENCV convention
Args:
cameras_src (torch.Tensor): Bx3x4 tensor representing [R | t] for source cameras.
align_t_R (torch.Tensor): 1x3x3 rotation matrix for alignment.
align_t_T (torch.Tensor): 1x3 translation vector for alignment.
align_t_s (float): Scaling factor for alignment.
    Returns:
        If return_extri is True (default): Bx3x4 tensor of aligned [R | t] extrinsics.
        Otherwise: (aligned_R, aligned_T) with shapes Bx3x3 and Bx3.
"""
R_src = cameras_src[:, :, :3]
T_src = cameras_src[:, :, 3]
aligned_R = torch.bmm(R_src, align_t_R.expand(R_src.shape[0], 3, 3))
# Apply the translation alignment to the source translations
align_t_T_expanded = align_t_T[..., None].repeat(R_src.shape[0], 1, 1)
transformed_T = torch.bmm(R_src, align_t_T_expanded)[..., 0]
aligned_T = transformed_T + T_src * align_t_s
if return_extri:
extri = torch.cat([aligned_R, aligned_T.unsqueeze(-1)], dim=-1)
return extri
return aligned_R, aligned_T
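# Typical pairing (a sketch; OpenCV-convention [R | t] extrinsics assumed, as
# noted in the docstrings above):
#   R, T, s = align_camera_extrinsics(src_extri, tgt_extri)
#   src_aligned = apply_transformation(src_extri, R, T, s)  # Bx3x4, comparable to tgt_extri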
def slerp(q1, q2, t):
"""Spherical Linear Interpolation between quaternions.
Args:
q1: (4,) first quaternion
q2: (4,) second quaternion
t: float between 0 and 1
Returns:
(4,) interpolated quaternion
"""
# Compute the cosine of the angle between the two vectors
dot = np.sum(q1 * q2)
# If the dot product is negative, slerp won't take the shorter path
# Fix by negating one of the input quaternions
if dot < 0.0:
q2 = -q2
dot = -dot
# Threshold for using linear interpolation instead of spherical
DOT_THRESHOLD = 0.9995
if dot > DOT_THRESHOLD:
# If the inputs are too close for comfort, linearly interpolate
# and normalize the result
result = q1 + t * (q2 - q1)
return result / np.linalg.norm(result)
# Compute the angle between the quaternions
theta_0 = np.arccos(dot)
sin_theta_0 = np.sin(theta_0)
# Compute interpolation factors
theta = theta_0 * t
sin_theta = np.sin(theta)
s0 = np.cos(theta) - dot * sin_theta / sin_theta_0
s1 = sin_theta / sin_theta_0
return (s0 * q1) + (s1 * q2)
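# Endpoint check: slerp(q1, q2, 0.0) == q1 and slerp(q1, q2, 1.0) == q2 (up to
# the sign flip applied above to take the shorter arc).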
def interpolate_poses(pose1, pose2, weight):
"""Interpolate between two camera poses with weight.
Args:
pose1: (4, 4) first camera pose
pose2: (4, 4) second camera pose
weight: float between 0 and 1, weight for pose1 (1-weight for pose2)
Returns:
(4, 4) interpolated pose
"""
from scipy.spatial.transform import Rotation as R
# Extract rotations and translations
R1 = R.from_matrix(pose1[:3, :3])
R2 = R.from_matrix(pose2[:3, :3])
t1 = pose1[:3, 3]
t2 = pose2[:3, 3]
# Get quaternions
q1 = R1.as_quat()
q2 = R2.as_quat()
# Interpolate rotation using our slerp implementation
q_interp = slerp(q1, q2, 1 - weight) # 1-weight because weight is for pose1
R_interp = R.from_quat(q_interp)
# Linear interpolation for translation
t_interp = weight * t1 + (1 - weight) * t2
# Construct interpolated pose
pose_interp = np.eye(4)
pose_interp[:3, :3] = R_interp.as_matrix()
pose_interp[:3, 3] = t_interp
return pose_interp
def smooth_poses(poses, window_size=5, method="gaussian"):
"""Smooth camera poses temporally.
Args:
poses: (N, 4, 4) camera poses
window_size: int, must be odd number
method: str, 'gaussian' or 'savgol' or 'ma'
Returns:
(N, 4, 4) smoothed poses
"""
from scipy.ndimage import gaussian_filter1d
from scipy.signal import savgol_filter
from scipy.spatial.transform import Rotation as R
assert window_size % 2 == 1, "window_size must be odd"
N = poses.shape[0]
smoothed = np.zeros_like(poses)
# Extract translations and quaternions
translations = poses[:, :3, 3]
rotations = R.from_matrix(poses[:, :3, :3])
quats = rotations.as_quat() # (N, 4)
# Ensure consistent quaternion signs to prevent interpolation artifacts
for i in range(1, N):
if np.dot(quats[i], quats[i - 1]) < 0:
quats[i] = -quats[i]
# Smooth translations
if method == "gaussian":
sigma = window_size / 6.0 # approximately 99.7% of the weight within the window
smoothed_trans = gaussian_filter1d(translations, sigma, axis=0, mode="nearest")
smoothed_quats = gaussian_filter1d(quats, sigma, axis=0, mode="nearest")
elif method == "savgol":
# Savitzky-Golay filter: polynomial fitting
poly_order = min(window_size - 1, 3)
smoothed_trans = savgol_filter(
translations, window_size, poly_order, axis=0, mode="nearest"
)
smoothed_quats = savgol_filter(
quats, window_size, poly_order, axis=0, mode="nearest"
)
elif method == "ma":
# Simple moving average
kernel = np.ones(window_size) / window_size
smoothed_trans = np.array(
[np.convolve(translations[:, i], kernel, mode="same") for i in range(3)]
).T
        smoothed_quats = np.array(
            [np.convolve(quats[:, i], kernel, mode="same") for i in range(4)]
        ).T
    else:
        raise ValueError(f"Unknown smoothing method: {method}")
# Normalize quaternions
smoothed_quats /= np.linalg.norm(smoothed_quats, axis=1, keepdims=True)
# Reconstruct poses
smoothed_rots = R.from_quat(smoothed_quats).as_matrix()
for i in range(N):
smoothed[i] = np.eye(4)
smoothed[i, :3, :3] = smoothed_rots[i]
smoothed[i, :3, 3] = smoothed_trans[i]
return smoothed
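# Usage sketch: smoothed = smooth_poses(poses, window_size=5, method="gaussian"),
# with poses of shape (N, 4, 4). Quaternion signs are aligned beforehand so the
# element-wise filters never average across the q/-q double cover.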
def smooth_trajectory(poses, window_size=5):
"""Smooth camera trajectory using Kalman filter.
Args:
poses: (N, 4, 4) camera poses
window_size: int, window size for initial smoothing
Returns:
(N, 4, 4) smoothed poses
"""
from filterpy.kalman import KalmanFilter
from scipy.spatial.transform import Rotation as R
N = poses.shape[0]
# Initialize Kalman filter for position and velocity
kf = KalmanFilter(dim_x=6, dim_z=3) # 3D position and velocity
dt = 1.0 # assume uniform time steps
# State transition matrix
kf.F = np.array(
[
[1, 0, 0, dt, 0, 0],
[0, 1, 0, 0, dt, 0],
[0, 0, 1, 0, 0, dt],
[0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 1],
]
)
# Measurement matrix
kf.H = np.array([[1, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0]])
# Measurement noise
kf.R *= 0.1
# Process noise
kf.Q *= 0.1
# Initial state uncertainty
kf.P *= 1.0
# Extract translations and rotations
translations = poses[:, :3, 3]
rotations = R.from_matrix(poses[:, :3, :3])
quats = rotations.as_quat()
# First pass: simple smoothing for initial estimates
smoothed = smooth_poses(poses, window_size, method="gaussian")
smooth_trans = smoothed[:, :3, 3]
# Second pass: Kalman filter for trajectory
filtered_trans = np.zeros_like(translations)
kf.x = np.zeros(6)
kf.x[:3] = smooth_trans[0]
filtered_trans[0] = smooth_trans[0]
# Forward pass
for i in range(1, N):
kf.predict()
kf.update(smooth_trans[i])
filtered_trans[i] = kf.x[:3]
    # Smooth rotations with a Gaussian-weighted quaternion average over a local window
window_half = window_size // 2
smoothed_quats = np.zeros_like(quats)
for i in range(N):
start_idx = max(0, i - window_half)
end_idx = min(N, i + window_half + 1)
weights = np.exp(
-0.5 * ((np.arange(start_idx, end_idx) - i) / (window_half / 2)) ** 2
)
weights /= weights.sum()
# Weighted average of nearby quaternions
avg_quat = np.zeros(4)
for j, w in zip(range(start_idx, end_idx), weights):
if np.dot(quats[j], quats[i]) < 0:
avg_quat += w * -quats[j]
else:
avg_quat += w * quats[j]
smoothed_quats[i] = avg_quat / np.linalg.norm(avg_quat)
# Reconstruct final smoothed poses
final_smoothed = np.zeros_like(poses)
smoothed_rots = R.from_quat(smoothed_quats).as_matrix()
for i in range(N):
final_smoothed[i] = np.eye(4)
final_smoothed[i, :3, :3] = smoothed_rots[i]
final_smoothed[i, :3, 3] = filtered_trans[i]
return final_smoothed
def compute_scale(prediction, target, mask):
if isinstance(prediction, np.ndarray):
prediction = torch.from_numpy(prediction).float()
if isinstance(target, np.ndarray):
target = torch.from_numpy(target).float()
if isinstance(mask, np.ndarray):
mask = torch.from_numpy(mask).bool()
numerator = torch.sum(mask * prediction * target, (1, 2))
denominator = torch.sum(mask * prediction * prediction, (1, 2))
scale = torch.zeros_like(numerator)
valid = (denominator != 0).nonzero()
scale[valid] = numerator[valid] / denominator[valid]
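    # NOTE: the sums above reduce (B, H, W) inputs to a (B,) least-squares scale
    # (argmin_s of sum(mask * (s * prediction - target)^2)); .item() assumes a
    # single sequence in the batch.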
return scale.item()