|
import os |
|
import matplotlib |
|
import matplotlib.pyplot as plt |
|
import copy |
|
from evo.core.trajectory import PosePath3D, PoseTrajectory3D |
|
from evo.main_ape import ape |
|
from evo.tools import plot |
|
from evo.core import sync |
|
from evo.tools import file_interface |
|
from evo.core import metrics |
|
import evo |
|
import torch |
|
import numpy as np |
|
from scipy.spatial.transform import Slerp |
|
from scipy.spatial.transform import Rotation as R |
|
import scipy.interpolate as si |
|
|
|
|
|
def interp_poses(c2ws, N_views):
    """Resample a sequence of poses to ``N_views`` poses.

    Rotations are spherically interpolated (SLERP) and translations are
    linearly interpolated. Both use the same normalised timeline: the input
    poses are treated as evenly spaced on ``[0, 1]`` and the outputs are
    evenly spaced on the same interval, so the first/last output pose
    coincide with the first/last input pose.

    :param c2ws: torch tensor indexed as ``c2ws[i, :3, :4]`` with the rotation
        in the first three columns and the translation in the fourth;
        presumably camera-to-world matrices of shape (N, 3, 4) or (N, 4, 4)
        -- confirm against callers.
    :param N_views: number of poses to generate.
    :return: (N_views, 4, 4) torch tensor of interpolated poses.
    """
    N_inputs = c2ws.shape[0]
    # (N, 3, 1) -> (1, 3, N): (batch, channels, length) layout required by
    # F.interpolate for 1-D linear interpolation.
    trans = c2ws[:, :3, 3:].permute(2, 1, 0)
    # scipy works on NumPy data; detach/cpu so tensors that track gradients
    # or live on an accelerator can be converted.
    rots = R.from_matrix(c2ws[:, :3, :3].detach().cpu().numpy())
    slerp = Slerp(np.linspace(0, 1, N_inputs), rots)
    interp_rots = torch.tensor(
        slerp(np.linspace(0, 1, N_views)).as_matrix().astype(np.float32),
        device=c2ws.device)
    # align_corners=True samples the translations at exactly the same
    # normalised times as the SLERP above. The default (False) uses
    # pixel-centre sampling, which holds the translation constant near both
    # ends while the rotation keeps changing.
    interp_trans = torch.nn.functional.interpolate(
        trans, size=N_views, mode='linear',
        align_corners=True).permute(2, 1, 0)
    render_poses = torch.cat([interp_rots, interp_trans], dim=2)
    return convert3x4_4x4(render_poses)
|
|
|
|
|
def interp_poses_bspline(c2ws, N_novel_imgs, input_times, degree):
    """Generate ``N_novel_imgs`` poses along a smoothed version of a path.

    Translations are sampled from an open B-spline whose control vertices are
    the input translations; rotations are SLERP-interpolated over
    ``input_times``.

    :param c2ws: poses indexed as ``c2ws[i, :3, :4]`` (rotation in the first
        three columns, translation in the fourth); presumably (N, 3, 4) or
        (N, 4, 4) camera-to-world matrices -- confirm against callers.
    :param N_novel_imgs: number of poses to generate.
    :param input_times: time stamp of every input pose, passed to ``Slerp``.
    :param degree: requested B-spline degree.
    :return: (N_novel_imgs, 4, 4) float32 torch tensor.
    """
    # Positions: sample the B-spline, then add a trailing axis -> (M, 3, 1).
    positions = scipy_bspline(
        c2ws[:, :3, 3], n=N_novel_imgs, degree=degree, periodic=False)
    target_trans = torch.tensor(positions.astype(np.float32)).unsqueeze(2)

    # Orientations: SLERP evaluated at evenly spaced times spanning the
    # first to the last input time stamp.
    rot_interpolator = Slerp(input_times, R.from_matrix(c2ws[:, :3, :3]))
    query_times = np.linspace(input_times[0], input_times[-1], N_novel_imgs)
    rot_mats = rot_interpolator(query_times).as_matrix()
    target_rots = torch.tensor(rot_mats.astype(np.float32))

    poses_3x4 = torch.cat([target_rots, target_trans], dim=2)
    return convert3x4_4x4(poses_3x4)
|
|
|
|
|
def poses_avg(poses):
    """Build one "average" pose from a stack of poses.

    The result is positioned at the mean of the translation columns and
    oriented from the summed third (z) and second (up) axis columns; the last
    column of the first pose is appended unchanged.

    :param poses: NumPy array indexed as ``poses[i, :3, j]``; the last column
        is presumably the LLFF-style height/width/focal column -- confirm
        against callers.
    :return: array made of the 3x4 view matrix with that last column
        concatenated along axis 1.
    """
    translations = poses[:, :3, 3]
    forward_sum = poses[:, :3, 2].sum(0)
    up_sum = poses[:, :3, 1].sum(0)

    mean_frame = viewmatrix(
        normalize(forward_sum), up_sum, translations.mean(0))
    trailing_col = poses[0, :3, -1:]
    return np.concatenate([mean_frame, trailing_col], 1)
|
|
|
|
|
def normalize(v):
    """Return ``v`` divided by its norm (``np.linalg.norm`` defaults).

    There is no zero-length guard: a NumPy array of zeros divides by zero
    and yields NaN entries.
    """
    length = np.linalg.norm(v)
    return v / length
|
|
|
|
|
def viewmatrix(z, up, pos):
    """Assemble a pose matrix from a viewing direction, an up hint and a position.

    :param z: direction used for the third axis (normalised here).
    :param up: approximate up vector; only used to derive the first axis.
    :param pos: position, stored as the last column.
    :return: array whose columns are ``[x_axis, y_axis, z_axis, pos]``
        (3x4 for 3-vectors), with the axes re-orthogonalised by cross products.
    """
    forward = normalize(z)
    # up x forward gives the first axis; forward x first axis then gives an
    # up axis that is orthogonal to both.
    right = normalize(np.cross(up, forward))
    true_up = normalize(np.cross(forward, right))
    return np.stack([right, true_up, forward, pos], 1)
|
|
|
|
|
def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
    """Generate ``N`` poses on a spiral around a reference pose.

    :param c2w: reference pose; ``c2w[:3, :4]`` is used as the transform and
        ``c2w[:, 4:5]`` (the fifth column) is appended to every output pose.
    :param up: up vector handed to ``viewmatrix``.
    :param rads: per-axis radii; a trailing 1 is appended so it can scale a
        homogeneous 4-vector.
    :param focal: distance along -z (in the reference frame) of the point all
        generated poses are oriented relative to.
    :param zdelta: unused by this function.
    :param zrate: frequency multiplier of the z oscillation.
    :param rots: number of full turns.
    :param N: number of poses.
    :return: list of ``N`` arrays (view matrix plus the appended column).
    """
    scale = np.array(list(rads) + [1.])
    extra_col = c2w[:, 4:5]
    frame = c2w[:3, :4]
    # Fixed point in front of the reference pose; identical for every theta.
    look_target = np.dot(frame, np.array([0, 0, -focal, 1.]))
    # N+1 samples with the last one dropped: the end point would duplicate
    # the start of a closed loop.
    thetas = np.linspace(0., 2. * np.pi * rots, N+1)[:-1]

    def pose_at(theta):
        offset = np.array([0.2*np.cos(theta), -0.2*np.sin(theta),
                           -np.sin(theta*zrate) * 0.1, 1.])
        c = np.dot(frame, offset * scale)
        z = normalize(c - look_target)
        return np.concatenate([viewmatrix(z, up, c), extra_col], 1)

    return [pose_at(theta) for theta in thetas]
|
|
|
|
|
def scipy_bspline(cv, n=100, degree=3, periodic=False):
    """ Calculate n samples on a bspline

        cv :      Array of control vertices
        n  :      Number of samples to return
        degree:   Curve degree
        periodic: True - Curve is closed
    """
    cv = np.asarray(cv)
    count = cv.shape[0]

    if not periodic:
        # Open curve: the degree cannot exceed count-1, and the knot vector
        # is clamped at both ends.
        degree = np.clip(degree, 1, count-1)
        kv = np.clip(np.arange(count+degree+1)-degree, 0, count-degree)
    else:
        # Closed curve: uniform knots, with the control vertices repeated
        # (and rotated by one) so the curve wraps around.
        kv = np.arange(-degree, count+degree+1)
        factor, fraction = divmod(count+degree+1, count)
        tiled = np.concatenate((cv,) * factor + (cv[:fraction],))
        cv = np.roll(tiled, -1, axis=0)
        degree = np.clip(degree, 1, degree)

    # Upper end of the parameter range: count for closed curves,
    # count - degree for open ones.
    max_param = count - (degree * (1-periodic))
    params = np.linspace(0, max_param, n)
    return si.BSpline(kv, cv, degree)(params)
|
|
|
|
|
def generate_spiral_nerf(learned_poses, bds, N_novel_views, hwf):
    """Create a spiral of novel-view poses around the average learned pose.

    :param learned_poses: torch tensor of poses indexed as
        ``learned_poses[i, :3, :4]``.
    :param bds: depth bounds; only ``bds.min()`` and ``bds.max()`` are used.
    :param N_novel_views: number of poses to generate.
    :param hwf: array concatenated to the 3x4 poses along the last axis
        (presumably an (N, 3, 1) height/width/focal column -- confirm against
        callers).
    :return: (N_novel_views, 4, 4) float32 torch tensor.
    """
    num_poses = len(learned_poses)
    poses_np = learned_poses[:, :3, :4].detach().cpu().numpy()
    learned_poses_ = np.concatenate((poses_np, hwf[:num_poses]), axis=-1)

    c2w = poses_avg(learned_poses_)
    print('recentered', c2w.shape)

    up = normalize(learned_poses_[:, :3, 1].sum(0))

    # Focus distance: weighted harmonic mean of the near and far depths.
    close_depth = bds.min()*.9
    inf_depth = bds.max()*5.
    dt = .75
    focal = 1./(((1.-dt)/close_depth + dt/inf_depth))

    zdelta = close_depth * .2
    # Spiral radii: 90th percentile of the absolute translation per axis.
    rads = np.percentile(np.abs(learned_poses_[:, :3, 3]), 90, 0)

    spiral = render_path_spiral(
        c2w, up, rads, focal, zdelta, zrate=.5, rots=2, N=N_novel_views)
    c2ws = torch.tensor(np.stack(spiral).astype(np.float32))
    # Drop the appended extra column before padding to 4x4.
    return convert3x4_4x4(c2ws[:, :3, :4])
|
|
|
|
|
def convert3x4_4x4(input):
    """
    Append the homogeneous row ``[0, 0, 0, 1]`` to 3x4 pose matrices.

    :param input:  (N, 3, 4) or (3, 4) torch or np
    :return:       (N, 4, 4) or (4, 4) torch or np
    """
    batched = len(input.shape) == 3

    if torch.is_tensor(input):
        if not batched:
            bottom = torch.tensor(
                [[0, 0, 0, 1]], dtype=input.dtype, device=input.device)
            return torch.cat([input, bottom], dim=0)
        # Batched: add a zero row per matrix, then set its last entry to 1.
        zero_row = torch.zeros_like(input[:, 0:1])
        output = torch.cat([input, zero_row], dim=1)
        output[:, 3, 3] = 1.0
        return output

    if batched:
        zero_row = np.zeros_like(input[:, 0:1])
        output = np.concatenate([input, zero_row], axis=1)
        output[:, 3, 3] = 1.0
        return output

    bottom = np.array([[0, 0, 0, 1]], dtype=input.dtype)
    output = np.concatenate([input, bottom], axis=0)
    output[3, 3] = 1.0
    return output
|
|
|
|
|
# Module-level side effect: importing this file sets the legend font size
# for every matplotlib figure created afterwards in the process.
plt.rcParams['legend.fontsize'] = 20
|
|
|
|
|
def _render_trajectory_figure(traj_by_label, colors, styles, tick_labels,
                              elev, facecolor=None):
    """Draw the labelled trajectories on a new 3-D figure and return it."""
    fig = plt.figure()
    if facecolor is not None:
        fig.patch.set_facecolor(facecolor)
    ax = fig.add_subplot(111, projection="3d")
    if facecolor is not None:
        ax.set_facecolor(facecolor)
    ax.xaxis.set_tick_params(labelbottom=tick_labels)
    ax.yaxis.set_tick_params(labelleft=tick_labels)
    ax.zaxis.set_tick_params(labelleft=tick_labels)
    for (label, traj), style, color in zip(
            traj_by_label.items(), styles, colors):
        plot.traj(ax, plot.PlotMode.xyz, traj, style, color, label)
    ax.view_init(elev=elev, azim=45)
    plt.tight_layout()
    return fig


def plot_pose(ref_poses, est_poses, output_path, args, vid=False):
    """Plot estimated poses against reference poses and save the figures.

    The estimated trajectory is aligned to the reference with evo's
    ``PosePath3D.align`` (scale correction enabled) before plotting.

    :param ref_poses: iterable of reference poses (passed to ``PosePath3D``
        as ``poses_se3``).
    :param est_poses: iterable of estimated poses, or a dict whose values
        are the poses.
    :param output_path: directory (``str`` or path-like) that receives
        ``pose_vis.png``. With ``vid=True`` the per-frame images go to
        ``pose_vid/`` next to it, i.e. under ``os.path.dirname(output_path)``.
    :param args: unused; kept for call compatibility.
    :param vid: when true, additionally write one image per trajectory
        prefix (``pose_vis_000.png``, ``pose_vis_001.png``, ...).
    :return: None. Every figure is closed once it has been saved.
    """
    ref_poses = list(ref_poses)
    if isinstance(est_poses, dict):
        est_poses = list(est_poses.values())
    else:
        est_poses = list(est_poses)

    traj_ref = PosePath3D(poses_se3=ref_poses)
    traj_est = PosePath3D(poses_se3=est_poses)
    # Align a copy so the unaligned estimate object stays untouched.
    traj_est_aligned = copy.deepcopy(traj_est)
    traj_est_aligned.align(traj_ref, correct_scale=True,
                           correct_only_scale=False)

    if vid:
        frame_dir = os.path.join(os.path.dirname(output_path), 'pose_vid')
        os.makedirs(frame_dir, exist_ok=True)
        for p_idx in range(len(ref_poses)):
            # Frame p_idx shows the first p_idx + 1 poses of each trajectory.
            traj_by_label = {
                "Ours (aligned)": PosePath3D(
                    poses_se3=traj_est_aligned.poses_se3[:p_idx+1]),
                "Ground-truth": PosePath3D(
                    poses_se3=traj_ref.poses_se3[:p_idx+1]),
            }
            fig = _render_trajectory_figure(
                traj_by_label, colors=['r', 'b'], styles=['-', '--'],
                tick_labels=False, elev=10.)
            pose_vis_path = os.path.join(
                frame_dir, 'pose_vis_{:03d}.png'.format(p_idx))
            print(pose_vis_path)
            try:
                fig.savefig(pose_vis_path)
            finally:
                # pyplot keeps every figure alive until it is closed; without
                # this, memory grows with the number of frames.
                plt.close(fig)

    traj_by_label = {
        "Ours (aligned)": traj_est_aligned,
        "COLMAP (GT)": traj_ref,
    }
    fig = _render_trajectory_figure(
        traj_by_label, colors=['#2c9e38', '#d12920'], styles=['-', '--'],
        tick_labels=True, elev=30., facecolor='white')
    # os.path.join accepts both str and path-like objects, whereas the `/`
    # operator only works on pathlib paths.
    pose_vis_path = os.path.join(output_path, 'pose_vis.png')
    try:
        fig.savefig(pose_vis_path)
    finally:
        plt.close(fig)
|
|
|
|
|
|
|
|
|
|
|
|