# InstantSplat/utils/utils_poses/vis_pose_utils.py
import copy
import os

import matplotlib.pyplot as plt
import numpy as np
import scipy.interpolate as si
import torch
from evo.core.trajectory import PosePath3D
from evo.tools import plot
from scipy.spatial.transform import Rotation as R
from scipy.spatial.transform import Slerp


def interp_poses(c2ws, N_views):
    """Interpolate a dense, evenly spaced pose sequence between input poses.

    Rotations are interpolated with spherical linear interpolation (Slerp),
    translations with 1D linear interpolation.

    :param c2ws: (N, 3/4, 4) torch tensor of camera-to-world poses
    :param N_views: number of poses to return
    :return: (N_views, 4, 4) torch tensor
    """
    N_inputs = c2ws.shape[0]
    trans = c2ws[:, :3, 3:].permute(2, 1, 0)  # (1, 3, N) for 1D interpolation
    rots = R.from_matrix(c2ws[:, :3, :3].cpu().numpy())
    slerp = Slerp(np.linspace(0, 1, N_inputs), rots)
    interp_rots = torch.tensor(
        slerp(np.linspace(0, 1, N_views)).as_matrix().astype(np.float32))
    interp_trans = torch.nn.functional.interpolate(
        trans, size=N_views, mode='linear').permute(2, 1, 0)  # (N_views, 3, 1)
    render_poses = torch.cat([interp_rots, interp_trans], dim=2)  # (N_views, 3, 4)
    render_poses = convert3x4_4x4(render_poses)
    return render_poses
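
# Usage sketch for interp_poses (illustrative only; the identity poses below
# are made-up inputs, not something the surrounding code constructs):
#   c2ws = torch.eye(4).repeat(5, 1, 1)      # (5, 4, 4) input poses
#   dense = interp_poses(c2ws, N_views=60)   # (60, 4, 4) interpolated path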


def interp_poses_bspline(c2ws, N_novel_imgs, input_times, degree):
    """Interpolate N_novel_imgs poses: translations along a degree-`degree`
    B-spline through the input camera centers, rotations via Slerp over the
    (increasing) `input_times`. Returns an (N_novel_imgs, 4, 4) torch tensor.
    """
    target_trans = torch.tensor(scipy_bspline(
        c2ws[:, :3, 3].cpu().numpy(), n=N_novel_imgs, degree=degree,
        periodic=False).astype(np.float32)).unsqueeze(2)
    rots = R.from_matrix(c2ws[:, :3, :3].cpu().numpy())
    slerp = Slerp(input_times, rots)
    target_times = np.linspace(input_times[0], input_times[-1], N_novel_imgs)
    target_rots = torch.tensor(
        slerp(target_times).as_matrix().astype(np.float32))
    target_poses = torch.cat([target_rots, target_trans], dim=2)
    target_poses = convert3x4_4x4(target_poses)
    return target_poses
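
# Usage sketch for interp_poses_bspline (illustrative; the timestamps here
# simply index the input frames):
#   input_times = np.linspace(0, 1, c2ws.shape[0])
#   path = interp_poses_bspline(c2ws, N_novel_imgs=120,
#                               input_times=input_times, degree=3)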


def poses_avg(poses):
    """Average the (N, 3, 5) [R | t | hwf] poses: mean center, summed view
    direction and up vector. Returns a (3, 5) c2w with the hwf column kept."""
    hwf = poses[0, :3, -1:]
    center = poses[:, :3, 3].mean(0)
    vec2 = normalize(poses[:, :3, 2].sum(0))
    up = poses[:, :3, 1].sum(0)
    c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1)
    return c2w

def normalize(v):
"""Normalize a vector."""
return v / np.linalg.norm(v)


def viewmatrix(z, up, pos):
    """Build a (3, 4) camera-to-world matrix from a view direction `z`, an
    up hint `up`, and a camera position `pos` (Gram-Schmidt style)."""
    vec2 = normalize(z)                         # forward axis
    vec1_avg = up
    vec0 = normalize(np.cross(vec1_avg, vec2))  # right axis
    vec1 = normalize(np.cross(vec2, vec0))      # orthogonalized up axis
    m = np.stack([vec0, vec1, vec2, pos], 1)
    return m


def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
    """Generate N poses on a spiral around the average pose `c2w` (a (3, 5)
    [R | t | hwf] matrix); `zdelta` is unused and kept for API compatibility.
    """
    render_poses = []
    rads = np.array(list(rads) + [1.])
    hwf = c2w[:, 4:5]
    for theta in np.linspace(0., 2. * np.pi * rots, N + 1)[:-1]:
        # Camera center on the spiral, expressed in the average-pose frame.
        c = np.dot(c2w[:3, :4], np.array(
            [0.2 * np.cos(theta), -0.2 * np.sin(theta), -np.sin(theta * zrate) * 0.1, 1.]) * rads)
        # Point the camera at a target `focal` units in front of the average pose.
        z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.])))
        render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1))
    return render_poses


def scipy_bspline(cv, n=100, degree=3, periodic=False):
    """Calculate n samples on a B-spline.

    cv :      array of control vertices
    n :       number of samples to return
    degree :  curve degree
    periodic: True - curve is closed
    """
    cv = np.asarray(cv)
    count = cv.shape[0]
    # Closed curve: wrap the control points and use a uniform knot vector.
    if periodic:
        kv = np.arange(-degree, count + degree + 1)
        factor, fraction = divmod(count + degree + 1, count)
        cv = np.roll(np.concatenate(
            (cv,) * factor + (cv[:fraction],)), -1, axis=0)
        degree = np.clip(degree, 1, degree)
    # Opened curve: clamp the knot vector so the curve hits the endpoints.
    else:
        degree = np.clip(degree, 1, count - 1)
        kv = np.clip(np.arange(count + degree + 1) - degree, 0, count - degree)
    # Evaluate n evenly spaced samples over the valid parameter range.
    max_param = count - (degree * (1 - periodic))
    spl = si.BSpline(kv, cv, degree)
    return spl(np.linspace(0, max_param, n))
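
# Usage sketch for scipy_bspline (illustrative control points):
#   cv = np.array([[0., 0., 0.], [1., 2., 0.], [3., 2., 1.], [4., 0., 1.]])
#   curve = scipy_bspline(cv, n=50, degree=3)   # (50, 3) points on the spline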


def generate_spiral_nerf(learned_poses, bds, N_novel_views, hwf):
    """Generate a NeRF-style spiral render path around the learned poses.

    :param learned_poses: (N, 4, 4) torch tensor of camera-to-world poses
    :param bds: array of near/far scene bounds
    :param N_novel_views: number of render poses to generate
    :param hwf: (N, 3, 1) array of [height, width, focal] columns
    :return: (N_novel_views, 4, 4) torch tensor of render poses
    """
    learned_poses_ = np.concatenate((learned_poses[:, :3, :4].detach(
    ).cpu().numpy(), hwf[:len(learned_poses)]), axis=-1)
    # Get average pose
    c2w = poses_avg(learned_poses_)
    print('recentered', c2w.shape)
    up = normalize(learned_poses_[:, :3, 1].sum(0))
    # Find a reasonable "focus depth" for this dataset
    close_depth, inf_depth = bds.min() * .9, bds.max() * 5.
    dt = .75
    mean_dz = 1. / (((1. - dt) / close_depth + dt / inf_depth))
    focal = mean_dz
    # Get radii for spiral path
    zdelta = close_depth * .2
    tt = learned_poses_[:, :3, 3]
    rads = np.percentile(np.abs(tt), 90, 0)
    c2w_path = c2w
    N_rots = 2
    c2ws = render_path_spiral(
        c2w_path, up, rads, focal, zdelta, zrate=.5, rots=N_rots, N=N_novel_views)
    c2ws = torch.tensor(np.stack(c2ws).astype(np.float32))
    c2ws = c2ws[:, :3, :4]
    c2ws = convert3x4_4x4(c2ws)
    return c2ws
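
# Usage sketch for generate_spiral_nerf (hypothetical values: H = 480, W = 640,
# f = 500 are made-up intrinsics; `learned_poses` and `bds` come from the caller):
#   hwf = np.broadcast_to(np.array([[480.], [640.], [500.]]),
#                         (len(learned_poses), 3, 1))
#   spiral = generate_spiral_nerf(learned_poses, bds, N_novel_views=120, hwf=hwf)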


def convert3x4_4x4(input):
    """
    :param input: (N, 3, 4) or (3, 4) torch or np
    :return: (N, 4, 4) or (4, 4) torch or np
    """
    if torch.is_tensor(input):
        if len(input.shape) == 3:
            output = torch.cat([input, torch.zeros_like(
                input[:, 0:1])], dim=1)  # (N, 4, 4)
            output[:, 3, 3] = 1.0
        else:
            output = torch.cat([input, torch.tensor(
                [[0, 0, 0, 1]], dtype=input.dtype, device=input.device)], dim=0)  # (4, 4)
    else:
        if len(input.shape) == 3:
            output = np.concatenate(
                [input, np.zeros_like(input[:, 0:1])], axis=1)  # (N, 4, 4)
            output[:, 3, 3] = 1.0
        else:
            output = np.concatenate(
                [input, np.array([[0, 0, 0, 1]], dtype=input.dtype)], axis=0)  # (4, 4)
    return output
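
# Usage sketch for convert3x4_4x4 (works for torch and numpy, batched or not):
#   pose34 = np.hstack([np.eye(3), np.zeros((3, 1))])  # (3, 4)
#   pose44 = convert3x4_4x4(pose34)                    # (4, 4) with [0, 0, 0, 1] row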


plt.rc('legend', fontsize=20)  # global legend font size for all pose plots


def plot_pose(ref_poses, est_poses, output_path, args, vid=False):
    """Plot estimated vs. reference camera trajectories with evo.

    Aligns the estimate to the reference (including scale correction), then
    saves a 3D comparison figure; with vid=True it additionally saves one
    cumulative frame per pose under a `pose_vid` directory. `args` is unused
    and kept for interface compatibility.
    """
    ref_poses = list(ref_poses)
    if isinstance(est_poses, dict):
        est_poses = list(est_poses.values())
    else:
        est_poses = list(est_poses)
    traj_ref = PosePath3D(poses_se3=ref_poses)
    traj_est = PosePath3D(poses_se3=est_poses)
    traj_est_aligned = copy.deepcopy(traj_est)
    traj_est_aligned.align(traj_ref, correct_scale=True,
                           correct_only_scale=False)
    if vid:
        # One cumulative frame per pose, so the frames can be stitched into a
        # video of the two trajectories being traced out.
        for p_idx in range(len(ref_poses)):
            fig = plt.figure()
            current_est_aligned = traj_est_aligned.poses_se3[:p_idx + 1]
            current_ref = traj_ref.poses_se3[:p_idx + 1]
            current_est_aligned = PosePath3D(poses_se3=current_est_aligned)
            current_ref = PosePath3D(poses_se3=current_ref)
            traj_by_label = {
                "Ours (aligned)": current_est_aligned,
                "Ground-truth": current_ref
            }
            plot_mode = plot.PlotMode.xyz
            ax = fig.add_subplot(111, projection="3d")
            ax.xaxis.set_tick_params(labelbottom=False)
            ax.yaxis.set_tick_params(labelleft=False)
            ax.zaxis.set_tick_params(labelleft=False)
            colors = ['r', 'b']
            styles = ['-', '--']
            for idx, (label, traj) in enumerate(traj_by_label.items()):
                plot.traj(ax, plot_mode, traj,
                          styles[idx], colors[idx], label)
            ax.view_init(elev=10., azim=45)
            plt.tight_layout()
            os.makedirs(os.path.join(os.path.dirname(
                output_path), 'pose_vid'), exist_ok=True)
            pose_vis_path = os.path.join(os.path.dirname(
                output_path), 'pose_vid', 'pose_vis_{:03d}.png'.format(p_idx))
            print(pose_vis_path)
            fig.savefig(pose_vis_path)
            plt.close(fig)  # avoid accumulating open figures across the loop
    # Full-trajectory comparison plot (the labels below are for this repo's
    # figures; baseline labels such as "NoPe-NeRF / CF-3DGS / NeRFmm (aligned)"
    # were swapped in here when plotting those methods).
    fig = plt.figure()
    fig.patch.set_facecolor('white')  # pure white figure background
    traj_by_label = {
        "Ours (aligned)": traj_est_aligned,
        "COLMAP (GT)": traj_ref
    }
    plot_mode = plot.PlotMode.xyz
    ax = fig.add_subplot(111, projection="3d")
    ax.set_facecolor('white')  # pure white axes background
    ax.xaxis.set_tick_params(labelbottom=True)
    ax.yaxis.set_tick_params(labelleft=True)
    ax.zaxis.set_tick_params(labelleft=True)
    colors = ['#2c9e38', '#d12920']  # green: estimate, red: reference
    styles = ['-', '--']
    for idx, (label, traj) in enumerate(traj_by_label.items()):
        plot.traj(ax, plot_mode, traj,
                  styles[idx], colors[idx], label)
    ax.view_init(elev=30., azim=45)
    plt.tight_layout()
    pose_vis_path = output_path / 'pose_vis.png'
    fig.savefig(pose_vis_path)
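

if __name__ == '__main__':
    # Minimal self-test: a hedged sketch, not part of the original module.
    # It builds a synthetic circular ground-truth trajectory, jitters the
    # camera centers to fake an estimate, and writes pose_vis.png to the
    # current directory (plot_pose expects `output_path` to be a pathlib.Path).
    from pathlib import Path

    rng = np.random.default_rng(0)
    gt_poses, noisy_poses = [], []
    for t in np.linspace(0., 2. * np.pi, 30):
        pose = np.eye(4)
        pose[:3, :3] = R.from_euler('z', t).as_matrix()
        pose[:3, 3] = [np.cos(t), np.sin(t), 0.1 * t]
        gt_poses.append(pose)
        noisy = pose.copy()
        noisy[:3, 3] += rng.normal(scale=0.02, size=3)  # perturb translation only
        noisy_poses.append(noisy)
    plot_pose(gt_poses, noisy_poses, Path('.'), args=None)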