# Spaces: Running on Zero
import os | |
import matplotlib | |
import matplotlib.pyplot as plt | |
import copy | |
from evo.core.trajectory import PosePath3D, PoseTrajectory3D | |
from evo.main_ape import ape | |
from evo.tools import plot | |
from evo.core import sync | |
from evo.tools import file_interface | |
from evo.core import metrics | |
import evo | |
import torch | |
import numpy as np | |
from scipy.spatial.transform import Slerp | |
from scipy.spatial.transform import Rotation as R | |
import scipy.interpolate as si | |
def interp_poses(c2ws, N_views):
    """Interpolate ``N_views`` camera poses along the input pose sequence.

    Rotations are interpolated with SLERP using evenly spaced key times in
    [0, 1]; translations are resampled with 1-D linear interpolation.

    :param c2ws: torch tensor of poses; rows ``:3`` of each pose are read, and
        column 3 is taken as the translation (assumes (N, 3 or 4, 4) -- TODO
        confirm against callers). May live on any device and may require grad.
    :param N_views: number of poses to generate
    :return: (N_views, 4, 4) torch tensor on the same device as ``c2ws``
    """
    N_inputs = c2ws.shape[0]
    # (N, 3, 1) -> (1, 3, N): interpolate() resamples the last dimension.
    trans = c2ws[:, :3, 3:].permute(2, 1, 0)
    # SciPy needs a plain CPU array; detach()/cpu() make this work for tensors
    # that require grad or live on an accelerator.
    rot_mats = c2ws[:, :3, :3].detach().cpu().numpy()
    slerp = Slerp(np.linspace(0, 1, N_inputs), R.from_matrix(rot_mats))
    interp_rots = torch.tensor(
        slerp(np.linspace(0, 1, N_views)).as_matrix().astype(np.float32))
    # Keep rotations on the same device as the translations so cat() works.
    interp_rots = interp_rots.to(trans.device)
    # NOTE(review): interpolate() is called with its default align_corners
    # setting, so the translation sample positions may not coincide exactly
    # with the linspace(0, 1) times used for the rotations -- confirm whether
    # align_corners=True is wanted here.
    interp_trans = torch.nn.functional.interpolate(
        trans, size=N_views, mode='linear').permute(2, 1, 0)
    render_poses = torch.cat([interp_rots, interp_trans], dim=2)
    return convert3x4_4x4(render_poses)
def interp_poses_bspline(c2ws, N_novel_imgs, input_times, degree):
    """Generate ``N_novel_imgs`` poses: B-spline translations, SLERP rotations.

    Translations are sampled from a non-periodic B-spline whose control
    vertices are the input camera positions. Rotations are interpolated with
    SLERP keyed on ``input_times`` and evaluated at evenly spaced times between
    the first and last input time.

    :param c2ws: poses as a torch tensor or numpy array; rows ``:3`` are read,
        columns ``:3`` as rotation and column 3 as translation
    :param N_novel_imgs: number of poses to generate
    :param input_times: key times passed to ``Slerp``, one per input pose
    :param degree: requested B-spline degree (clipped inside ``scipy_bspline``)
    :return: (N_novel_imgs, 4, 4) float32 torch tensor (created on CPU)
    """
    # NumPy/SciPy cannot consume tensors that require grad or live on an
    # accelerator, so convert once up front.
    if torch.is_tensor(c2ws):
        poses_np = c2ws.detach().cpu().numpy()
    else:
        poses_np = np.asarray(c2ws)

    spline_points = scipy_bspline(
        poses_np[:, :3, 3], n=N_novel_imgs, degree=degree, periodic=False)
    # (N_novel_imgs, 3) -> (N_novel_imgs, 3, 1) so it can be the 4th column.
    target_trans = torch.tensor(spline_points.astype(np.float32)).unsqueeze(2)

    slerp = Slerp(input_times, R.from_matrix(poses_np[:, :3, :3]))
    target_times = np.linspace(input_times[0], input_times[-1], N_novel_imgs)
    target_rots = torch.tensor(
        slerp(target_times).as_matrix().astype(np.float32))

    target_poses = torch.cat([target_rots, target_trans], dim=2)
    return convert3x4_4x4(target_poses)
def poses_avg(poses):
    """Build an average pose from a stack of poses.

    The position is the mean of column 3, the viewing axis is the normalized
    sum of column 2, and the up hint is the (un-normalized) sum of column 1.
    The last column of the first pose (named ``hwf`` in the original code,
    presumably height/width/focal -- verify against callers) is appended.

    :param poses: numpy array indexed as ``poses[i, row, col]``
    :return: the view matrix with that extra column concatenated on axis 1
    """
    extra_col = poses[0, :3, -1:]
    mean_position = poses[:, :3, 3].mean(0)
    forward = normalize(poses[:, :3, 2].sum(0))
    up_sum = poses[:, :3, 1].sum(0)
    avg_view = viewmatrix(forward, up_sum, mean_position)
    return np.concatenate([avg_view, extra_col], 1)
def normalize(v):
    """Return ``v`` divided by its norm (``np.linalg.norm(v)``)."""
    length = np.linalg.norm(v)
    return v / length
def viewmatrix(z, up, pos):
    """Assemble a matrix whose columns are an orthonormal frame plus ``pos``.

    :param z: viewing-axis vector (normalized here)
    :param up: approximate up vector, used only to derive the first axis
    :param pos: position vector, stored as the last column
    :return: array stacking ``[axis0, axis1, axis2, pos]`` along axis 1
    """
    axis2 = normalize(z)
    # First axis is perpendicular to both the up hint and the viewing axis.
    axis0 = normalize(np.cross(up, axis2))
    # Recompute the second axis so the three axes are mutually orthogonal.
    axis1 = normalize(np.cross(axis2, axis0))
    return np.stack([axis0, axis1, axis2, pos], axis=1)
def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
    """Generate ``N`` poses on a spiral around the pose ``c2w``.

    :param c2w: pose array; ``c2w[:3, :4]`` is used as the transform and
        ``c2w[:, 4:5]`` is appended unchanged to every generated pose
    :param up: up hint forwarded to ``viewmatrix``
    :param rads: per-axis radii; a trailing 1.0 is appended for the
        homogeneous coordinate
    :param focal: distance along -z (in the ``c2w`` frame) of the focus point
    :param zdelta: unused by this function
    :param zrate: frequency multiplier for the z oscillation
    :param rots: number of full rotations
    :param N: number of poses to generate
    :return: list of ``N`` arrays, each a view matrix with the extra column
    """
    radii = np.array(list(rads) + [1.])
    extra_col = c2w[:, 4:5]
    transform = c2w[:3, :4]
    # Point every camera is oriented relative to; independent of theta.
    focus_point = np.dot(transform, np.array([0, 0, -focal, 1.]))

    def pose_at(theta):
        # The 0.2 / 0.2 / 0.1 factors are hard-coded spiral extents.
        offset = np.array(
            [0.2*np.cos(theta), -0.2*np.sin(theta), -np.sin(theta*zrate) * 0.1, 1.])
        c = np.dot(transform, offset * radii)
        z = normalize(c - focus_point)
        return np.concatenate([viewmatrix(z, up, c), extra_col], 1)

    # N+1 samples with the last dropped, so the end angle is not duplicated.
    angles = np.linspace(0., 2. * np.pi * rots, N+1)[:-1]
    return [pose_at(theta) for theta in angles]
def scipy_bspline(cv, n=100, degree=3, periodic=False):
    """Sample ``n`` points on a B-spline defined by control vertices.

    cv       : Array of control vertices (first axis indexes the vertices)
    n        : Number of samples to return
    degree   : Curve degree (clipped to a usable range)
    periodic : True - curve is closed
    """
    cv = np.asarray(cv)
    count = cv.shape[0]

    if periodic:
        # Closed curve: uniform knots, control vertices wrapped around.
        knots = np.arange(-degree, count + degree + 1)
        repeats, remainder = divmod(count + degree + 1, count)
        wrapped = np.concatenate((cv,) * repeats + (cv[:remainder],))
        cv = np.roll(wrapped, -1, axis=0)
        degree = np.clip(degree, 1, degree)
    else:
        # Open curve: clamp the knot vector at both ends.
        degree = np.clip(degree, 1, count - 1)
        knots = np.clip(
            np.arange(count + degree + 1) - degree, 0, count - degree)

    # Upper end of the parameter range: count when periodic, else count-degree.
    max_param = count - (degree * (1 - periodic))
    sample_params = np.linspace(0, max_param, n)
    spline = si.BSpline(knots, cv, degree)
    return spline(sample_params)
def generate_spiral_nerf(learned_poses, bds, N_novel_views, hwf):
    """Generate a spiral of novel-view poses around the average learned pose.

    :param learned_poses: torch tensor of poses; ``[:, :3, :4]`` is used
    :param bds: depth bounds; only ``min()`` and ``max()`` are read
    :param N_novel_views: number of poses on the spiral
    :param hwf: array concatenated to the poses along the last axis
        (presumably one (3, 1) column per pose -- verify against callers)
    :return: (N_novel_views, 4, 4) float32 torch tensor
    """
    n_poses = len(learned_poses)
    poses_3x4 = learned_poses[:, :3, :4].detach().cpu().numpy()
    poses_with_hwf = np.concatenate((poses_3x4, hwf[:n_poses]), axis=-1)

    # Average pose the spiral is centred on.
    c2w = poses_avg(poses_with_hwf)
    print('recentered', c2w.shape)
    up = normalize(poses_with_hwf[:, :3, 1].sum(0))

    # Focus depth: weighted harmonic mix of the padded near and far bounds.
    close_depth = bds.min() * .9
    inf_depth = bds.max() * 5.
    dt = .75
    focal = 1. / ((1. - dt) / close_depth + dt / inf_depth)

    # Spiral radii: 90th percentile of the absolute camera positions per axis.
    zdelta = close_depth * .2
    positions = poses_with_hwf[:, :3, 3]
    rads = np.percentile(np.abs(positions), 90, 0)

    spiral = render_path_spiral(
        c2w, up, rads, focal, zdelta, zrate=.5, rots=2, N=N_novel_views)
    c2ws = torch.tensor(np.stack(spiral).astype(np.float32))
    # Drop the appended extra column before padding to homogeneous form.
    return convert3x4_4x4(c2ws[:, :3, :4])
def convert3x4_4x4(input):
    """
    Append the homogeneous bottom row [0, 0, 0, 1] to 3x4 pose matrices.

    :param input: (N, 3, 4) or (3, 4) torch or np
    :return: (N, 4, 4) or (4, 4) torch or np
    """
    batched = len(input.shape) == 3

    if torch.is_tensor(input):
        if batched:
            zero_row = torch.zeros_like(input[:, 0:1])
            output = torch.cat([input, zero_row], dim=1)  # (N, 4, 4)
            output[:, 3, 3] = 1.0
            return output
        bottom = torch.tensor(
            [[0, 0, 0, 1]], dtype=input.dtype, device=input.device)
        return torch.cat([input, bottom], dim=0)  # (4, 4)

    if batched:
        zero_row = np.zeros_like(input[:, 0:1])
        output = np.concatenate([input, zero_row], axis=1)  # (N, 4, 4)
        output[:, 3, 3] = 1.0
        return output
    bottom = np.array([[0, 0, 0, 1]], dtype=input.dtype)
    output = np.concatenate([input, bottom], axis=0)  # (4, 4)
    output[3, 3] = 1.0
    return output
# Module-level side effect: sets matplotlib's legend font size to the numeric
# value 20 for every figure created after this module is imported.
plt.rc('legend', fontsize=20)
def _plot_trajectories_3d(fig, traj_by_label, colors, styles, elev,
                          show_tick_labels, facecolor=None):
    """Draw the labelled trajectories on a new 3D subplot of ``fig``."""
    plot_mode = plot.PlotMode.xyz
    if facecolor is not None:
        fig.patch.set_facecolor(facecolor)
    ax = fig.add_subplot(111, projection="3d")
    if facecolor is not None:
        ax.set_facecolor(facecolor)
    ax.xaxis.set_tick_params(labelbottom=show_tick_labels)
    ax.yaxis.set_tick_params(labelleft=show_tick_labels)
    ax.zaxis.set_tick_params(labelleft=show_tick_labels)
    for (label, traj), style, color in zip(
            traj_by_label.items(), styles, colors):
        plot.traj(ax, plot_mode, traj, style, color, label)
    ax.view_init(elev=elev, azim=45)
    fig.tight_layout()
    return ax


def plot_pose(ref_poses, est_poses, output_path, args, vid=False):
    """Plot estimated poses, aligned to the reference poses, and save as PNG.

    The estimate is aligned to the reference with scale correction and the
    full trajectories are saved to ``<output_path>/pose_vis.png``. With
    ``vid=True`` one frame per pose prefix is additionally written to
    ``<dirname(output_path)>/pose_vid/pose_vis_XXX.png``.

    :param ref_poses: iterable of reference poses accepted by ``PosePath3D``
    :param est_poses: iterable of estimated poses, or a dict whose values are
        the poses (taken in the dict's iteration order)
    :param output_path: output directory as ``str`` or path-like object
    :param args: unused; kept for interface compatibility
    :param vid: when True, also write the per-frame images
    """
    ref_poses = list(ref_poses)
    if isinstance(est_poses, dict):
        est_poses = list(est_poses.values())
    else:
        est_poses = list(est_poses)

    traj_ref = PosePath3D(poses_se3=ref_poses)
    traj_est = PosePath3D(poses_se3=est_poses)
    # Align a copy so the unaligned estimate object is left untouched.
    traj_est_aligned = copy.deepcopy(traj_est)
    traj_est_aligned.align(traj_ref, correct_scale=True,
                           correct_only_scale=False)

    styles = ['-', '--']

    if vid:
        frame_dir = os.path.join(os.path.dirname(output_path), 'pose_vid')
        # Create the frame directory once instead of on every iteration.
        os.makedirs(frame_dir, exist_ok=True)
        for p_idx in range(len(ref_poses)):
            fig = plt.figure()
            # Each frame shows the trajectories up to and including p_idx.
            current_est_aligned = PosePath3D(
                poses_se3=traj_est_aligned.poses_se3[:p_idx+1])
            current_ref = PosePath3D(poses_se3=traj_ref.poses_se3[:p_idx+1])
            traj_by_label = {
                "Ours (aligned)": current_est_aligned,
                "Ground-truth": current_ref
            }
            _plot_trajectories_3d(
                fig, traj_by_label, colors=['r', 'b'], styles=styles,
                elev=10., show_tick_labels=False)
            pose_vis_path = os.path.join(
                frame_dir, 'pose_vis_{:03d}.png'.format(p_idx))
            print(pose_vis_path)
            fig.savefig(pose_vis_path)
            # Release the frame's figure; otherwise one open figure per pose
            # accumulates in pyplot's global state.
            plt.close(fig)

    fig = plt.figure()
    traj_by_label = {
        "Ours (aligned)": traj_est_aligned,
        "COLMAP (GT)": traj_ref
    }
    # Pure white figure and subplot backgrounds for the summary image.
    _plot_trajectories_3d(
        fig, traj_by_label, colors=['#2c9e38', '#d12920'], styles=styles,
        elev=30., show_tick_labels=True, facecolor='white')
    # os.path.join accepts both str and path-like directories, matching the
    # video branch above; the summary figure is intentionally left open.
    pose_vis_path = os.path.join(output_path, 'pose_vis.png')
    fig.savefig(pose_vis_path)