|
import os |
|
import math |
|
import cv2 |
|
import numpy as np |
|
import torch |
|
from dust3r.utils.vo_eval import load_traj, eval_metrics, plot_trajectory, save_trajectory_tum_format, process_directory, calculate_averages |
|
import croco.utils.misc as misc |
|
import torch.distributed as dist |
|
from tqdm import tqdm |
|
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode |
|
from dust3r.utils.image import load_images, rgb, enlarge_seg_masks |
|
from dust3r.image_pairs import make_pairs |
|
from dust3r.inference import inference |
|
|
|
import dust3r.eval_metadata |
|
from dust3r.eval_metadata import dataset_metadata |
|
|
|
|
|
def eval_pose_estimation(args, model, device, save_dir=None):
    """Run pose-estimation evaluation for the dataset selected by ``args``.

    Resolves the dataset entry from ``dataset_metadata`` (falling back to
    the 'sintel' entry when ``args.eval_dataset`` is unknown) and delegates
    all work to :func:`eval_pose_estimation_dist`.

    Returns:
        (ate_mean, rpe_trans_mean, rpe_rot_mean, outfile_list, bug) as
        produced by ``eval_pose_estimation_dist``.
    """
    meta = dataset_metadata.get(args.eval_dataset, dataset_metadata['sintel'])
    return eval_pose_estimation_dist(
        args, model, device,
        save_dir=save_dir,
        img_path=meta['img_path'],
        mask_path=meta['mask_path'],
    )
|
|
|
|
|
|
|
|
|
|
|
def eval_pose_estimation_dist(args, model, device, img_path, save_dir=None, mask_path=None):
    """Evaluate camera-pose estimation over a dataset, sharded across ranks.

    Each rank processes a contiguous slice of the sequence list: it builds
    image pairs, runs pairwise ``inference``, globally aligns the result,
    saves per-sequence predictions (poses, focals, intrinsics, depth/conf
    maps, masks) under ``save_dir/<seq>``, and appends per-sequence metrics
    to a rank-local log file. After a barrier, per-rank results are merged
    and dataset averages are computed from the files written to ``save_dir``.

    Args:
        args: parsed CLI namespace (eval_dataset, scene_graph_type, n_iter,
            pose_eval_stride, flow-loss settings, ...).
        model: network handed to ``inference``.
        device: torch device used for inference and alignment.
        img_path: dataset image root (from dataset metadata).
        save_dir: output directory; defaults to ``args.output_dir``.
        mask_path: optional dynamic-mask root (from dataset metadata).

    Returns:
        (avg_ate, avg_rpe_trans, avg_rpe_rot, outfile_list_combined, bug)
        where ``bug`` is True if any rank hit a sequence without a
        ground-truth trajectory.
    """
    metadata = dataset_metadata.get(args.eval_dataset, dataset_metadata['sintel'])
    anno_path = metadata.get('anno_path', None)

    silent = args.silent
    seq_list = args.seq_list
    if seq_list is None:
        if metadata.get('full_seq', False):
            args.full_seq = True
        else:
            seq_list = metadata.get('seq_list', [])
        if args.full_seq:
            # Use every sub-directory of the image root as a sequence.
            seq_list = os.listdir(img_path)
            seq_list = [seq for seq in seq_list if os.path.isdir(os.path.join(img_path, seq))]
        seq_list = sorted(seq_list)

    if save_dir is None:
        save_dir = args.output_dir

    # Determine this process's shard of the sequence list.
    if misc.is_dist_avail_and_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1

    total_seqs = len(seq_list)
    seqs_per_proc = (total_seqs + world_size - 1) // world_size  # ceil division

    start_idx = rank * seqs_per_proc
    end_idx = min(start_idx + seqs_per_proc, total_seqs)

    seq_list = seq_list[start_idx:end_idx]

    ate_list = []
    rpe_trans_list = []
    rpe_rot_list = []
    outfile_list = []
    load_img_size = 512

    # Rank-local log; merged into a single file by rank 0 at the end.
    error_log_path = f'{save_dir}/_error_log_{rank}.txt'
    bug = False

    for seq in tqdm(seq_list):
        try:
            dir_path = metadata['dir_path_func'](img_path, seq)

            # Dataset-specific predicate to skip already-processed sequences.
            skip_condition = metadata.get('skip_condition', None)
            if skip_condition is not None and skip_condition(save_dir, seq):
                continue

            mask_path_seq_func = metadata.get('mask_path_seq_func', lambda mask_path, seq: None)
            mask_path_seq = mask_path_seq_func(mask_path, seq)

            filelist = [os.path.join(dir_path, name) for name in os.listdir(dir_path)]
            filelist.sort()
            if args.evaluate_davis:
                filelist = filelist[:50]
            filelist = filelist[::args.pose_eval_stride]
            # Clamp the scene-graph window size to the frames actually available.
            max_winsize = max(1, math.ceil((len(filelist) - 1) / 2))
            scene_graph_type = args.scene_graph_type
            if int(scene_graph_type.split('-')[1]) > max_winsize:
                scene_graph_type = f'{args.scene_graph_type.split("-")[0]}-{max_winsize}'
                if len(scene_graph_type.split("-")) > 2:
                    scene_graph_type += f'-{args.scene_graph_type.split("-")[2]}'
            imgs = load_images(
                filelist, size=load_img_size, verbose=False,
                dynamic_mask_root=mask_path_seq, crop=not args.no_crop
            )
            print(f'Processing {seq} with {len(imgs)} images')
            if args.eval_dataset == 'davis' and len(imgs) > 95:
                # Shrink the window ('-5' -> '-4') on long DAVIS sequences to fit memory.
                scene_graph_type = scene_graph_type.replace('5', '4')
            pairs = make_pairs(
                imgs, scene_graph=scene_graph_type, prefilter=None, symmetrize=True
            )

            output = inference(pairs, model, device, batch_size=1, verbose=not silent)

            torch.cuda.empty_cache()

            with torch.enable_grad():
                if len(imgs) > 2:
                    # Multi-view: optimize a global point-cloud alignment.
                    mode = GlobalAlignerMode.PointCloudOptimizer
                    scene = global_aligner(
                        output, device=device, mode=mode, verbose=not silent,
                        shared_focal=not args.not_shared_focal and not args.use_gt_focal,
                        flow_loss_weight=args.flow_loss_weight, flow_loss_fn=args.flow_loss_fn,
                        depth_regularize_weight=args.depth_regularize_weight,
                        num_total_iter=args.n_iter, temporal_smoothing_weight=args.temporal_smoothing_weight,
                        motion_mask_thre=args.motion_mask_thre,
                        flow_loss_start_epoch=args.flow_loss_start_epoch, flow_loss_thre=args.flow_loss_thre, translation_weight=args.translation_weight,
                        sintel_ckpt=args.eval_dataset == 'sintel', use_gt_mask=args.use_gt_mask, use_pred_mask=args.use_pred_mask, sam2_mask_refine=args.sam2_mask_refine,
                        empty_cache=len(filelist) > 72, pxl_thre=args.pxl_thresh, batchify=not args.not_batchify
                    )
                    if args.use_gt_focal:
                        focal_path = os.path.join(
                            img_path.replace('final', 'camdata_left'), seq, 'focal.txt'
                        )
                        focals = np.loadtxt(focal_path)
                        focals = focals[::args.pose_eval_stride]
                        original_img_size = cv2.imread(filelist[0]).shape[:2]
                        resized_img_size = tuple(imgs[0]['img'].shape[-2:])
                        # Rescale the ground-truth focals to the resized resolution.
                        focals = focals * max(
                            (resized_img_size[0] / original_img_size[0]),
                            (resized_img_size[1] / original_img_size[1])
                        )
                        scene.preset_focal(focals, requires_grad=False)
                    lr = 0.01
                    loss = scene.compute_global_alignment(
                        init='mst', niter=args.n_iter, schedule=args.pose_schedule, lr=lr,
                    )
                else:
                    # Two views only: closed-form pairwise viewer, no optimization.
                    mode = GlobalAlignerMode.PairViewer
                    scene = global_aligner(output, device=device, mode=mode, verbose=not silent)

            if args.save_pose_qualitative:
                # NOTE(review): get_3D_model_from_scene is not imported in this
                # chunk — it must be defined/imported elsewhere in the module.
                outfile = get_3D_model_from_scene(
                    outdir=save_dir, silent=silent, scene=scene, min_conf_thr=2, as_pointcloud=True, mask_sky=False,
                    clean_depth=True, transparent_cams=False, cam_size=0.01, save_name=seq
                )
            else:
                outfile = None
            pred_traj = scene.get_tum_poses()

            os.makedirs(f'{save_dir}/{seq}', exist_ok=True)
            scene.clean_pointcloud()
            scene.save_tum_poses(f'{save_dir}/{seq}/pred_traj.txt')
            scene.save_focals(f'{save_dir}/{seq}/pred_focal.txt')
            scene.save_intrinsics(f'{save_dir}/{seq}/pred_intrinsics.txt')
            scene.save_depth_maps(f'{save_dir}/{seq}')
            scene.save_dynamic_masks(f'{save_dir}/{seq}')
            scene.save_dyna_maps(f'{save_dir}/{seq}')
            scene.save_conf_maps(f'{save_dir}/{seq}')
            scene.save_init_conf_maps(f'{save_dir}/{seq}')
            scene.save_rgb_imgs(f'{save_dir}/{seq}')
            enlarge_seg_masks(f'{save_dir}/{seq}', kernel_size=5 if args.use_gt_mask else 3)

            gt_traj_file = metadata['gt_traj_func'](img_path, anno_path, seq)
            traj_format = metadata.get('traj_format', None)

            if args.eval_dataset == 'sintel':
                gt_traj = load_traj(gt_traj_file=gt_traj_file, stride=args.pose_eval_stride)
            elif traj_format is not None:
                gt_traj = load_traj(gt_traj_file=gt_traj_file, traj_format=traj_format)
            else:
                gt_traj = None

            if gt_traj is not None:
                ate, rpe_trans, rpe_rot = eval_metrics(
                    pred_traj, gt_traj, seq=seq, filename=f'{save_dir}/{seq}_eval_metric.txt'
                )
                plot_trajectory(
                    pred_traj, gt_traj, title=seq, filename=f'{save_dir}/{seq}.png'
                )
            else:
                # No ground truth: record zeros and flag the run as buggy.
                ate, rpe_trans, rpe_rot = 0, 0, 0
                outfile = None
                bug = True

            ate_list.append(ate)
            rpe_trans_list.append(rpe_trans)
            rpe_rot_list.append(rpe_rot)
            outfile_list.append(outfile)

            # Append this sequence's metrics to the rank-local log.
            with open(error_log_path, 'a') as f:
                f.write(f'{args.eval_dataset}-{seq: <16} | ATE: {ate:.5f}, RPE trans: {rpe_trans:.5f}, RPE rot: {rpe_rot:.5f}\n')
                f.write(f'{ate:.5f}\n')
                f.write(f'{rpe_trans:.5f}\n')
                f.write(f'{rpe_rot:.5f}\n')

        except Exception as e:
            if 'out of memory' in str(e):
                # Best-effort recovery: free CUDA memory and skip the sequence.
                torch.cuda.empty_cache()
                with open(error_log_path, 'a') as f:
                    f.write(f'OOM error in sequence {seq}, skipping this sequence.\n')
                print(f'OOM error in sequence {seq}, skipping...')
            elif 'Degenerate covariance rank' in str(e) or 'Eigenvalues did not converge' in str(e):
                # Trajectory-alignment numerical failures (from evo): skip.
                with open(error_log_path, 'a') as f:
                    f.write(f'Exception in sequence {seq}: {str(e)}\n')
                print(f'Traj evaluation error in sequence {seq}, skipping.')
            else:
                raise e

    # Wait for every rank to finish its shard before aggregating.
    if misc.is_dist_avail_and_initialized():
        torch.distributed.barrier()

    bug_tensor = torch.tensor(int(bug), device=device)
    if misc.is_dist_avail_and_initialized():
        # FIX: propagate the failure flag across ranks. Previously the local
        # tensor was converted straight back, so a missing-GT flag on one
        # rank was invisible to the others.
        dist.all_reduce(bug_tensor, op=dist.ReduceOp.MAX)
    bug = bool(bug_tensor.item())

    # FIX: actually gather the per-rank outfile lists. The original filled
    # outfile_list_all with None placeholders and never gathered, so every
    # qualitative outfile was silently dropped from the combined result.
    outfile_list_all = [None for _ in range(world_size)]
    if misc.is_dist_avail_and_initialized():
        dist.all_gather_object(outfile_list_all, outfile_list)
    else:
        outfile_list_all = [outfile_list]

    outfile_list_combined = []
    for sublist in outfile_list_all:
        if sublist is not None:
            outfile_list_combined.extend(sublist)

    # Dataset averages are recomputed from the metric files under save_dir,
    # so they cover all ranks' sequences.
    results = process_directory(save_dir)
    avg_ate, avg_rpe_trans, avg_rpe_rot = calculate_averages(results)

    if rank == 0:
        # Merge all rank-local logs into one file and append the averages.
        with open(f'{save_dir}/_error_log.txt', 'a') as f:
            for i in range(world_size):
                sub_log = f'{save_dir}/_error_log_{i}.txt'
                # FIX: a rank that processed no sequences (world_size >
                # total_seqs) writes no log file — skip instead of crashing.
                if not os.path.exists(sub_log):
                    continue
                with open(sub_log, 'r') as f_sub:
                    f.write(f_sub.read())
            f.write(f'Average ATE: {avg_ate:.5f}, Average RPE trans: {avg_rpe_trans:.5f}, Average RPE rot: {avg_rpe_rot:.5f}\n')

    return avg_ate, avg_rpe_trans, avg_rpe_rot, outfile_list_combined, bug
|
|
|
|
|
def pose_estimation_custom(args, model, device, save_dir=None):
    """Run the pose-estimation pipeline on a user-supplied image directory.

    Reads frames from ``args.dir_path`` (strided by ``args.pose_eval_stride``),
    runs pairwise inference and global alignment, and saves the predicted
    trajectory, focals, intrinsics, and per-frame maps directly under
    ``save_dir``. No ground-truth comparison is performed and nothing is
    returned.

    Args:
        args: parsed CLI namespace (dir_path, scene_graph_type, n_iter, ...).
        model: network handed to ``inference``.
        device: torch device used for inference and alignment.
        save_dir: output directory for all predictions.
    """
    load_img_size = 512
    dir_path = args.dir_path
    silent = args.silent

    filelist = [os.path.join(dir_path, name) for name in os.listdir(dir_path)]
    filelist.sort()
    filelist = filelist[::args.pose_eval_stride]
    # Clamp the scene-graph window size to the frames actually available.
    max_winsize = max(1, math.ceil((len(filelist)-1)/2))
    scene_graph_type = args.scene_graph_type
    if int(scene_graph_type.split('-')[1]) > max_winsize:
        scene_graph_type = f'{args.scene_graph_type.split("-")[0]}-{max_winsize}'
        if len(scene_graph_type.split("-")) > 2:
            scene_graph_type += f'-{args.scene_graph_type.split("-")[2]}'
    imgs = load_images(
        filelist, size=load_img_size, verbose=False, crop=not args.no_crop
    )
    print(f'Processing {args.dir_path} with {len(imgs)} images')
    if len(imgs) > 95:
        # Shrink the window ('-5' -> '-4') on long sequences to fit memory.
        scene_graph_type = scene_graph_type.replace('5', '4')
    pairs = make_pairs(
        imgs, scene_graph=scene_graph_type, prefilter=None, symmetrize=True
    )

    output = inference(pairs, model, device, batch_size=1, verbose=not silent)

    torch.cuda.empty_cache()

    with torch.enable_grad():
        if len(imgs) > 2:
            # Multi-view: optimize a global point-cloud alignment.
            mode = GlobalAlignerMode.PointCloudOptimizer
            scene = global_aligner(
                output, device=device, mode=mode, verbose=not silent,
                shared_focal=not args.not_shared_focal and not args.use_gt_focal,
                flow_loss_weight=args.flow_loss_weight, flow_loss_fn=args.flow_loss_fn,
                depth_regularize_weight=args.depth_regularize_weight,
                num_total_iter=args.n_iter, temporal_smoothing_weight=args.temporal_smoothing_weight,
                motion_mask_thre=args.motion_mask_thre,
                flow_loss_start_epoch=args.flow_loss_start_epoch, flow_loss_thre=args.flow_loss_thre, translation_weight=args.translation_weight,
                sintel_ckpt=args.eval_dataset == 'sintel', use_gt_mask = args.use_gt_mask, use_pred_mask = args.use_pred_mask, sam2_mask_refine=args.sam2_mask_refine,
                empty_cache=len(filelist) > 72, pxl_thre=args.pxl_thresh, batchify=not args.not_batchify
            )
            if args.use_gt_focal:
                # Load ground-truth focals and rescale them to the resized resolution.
                focal_path = args.focal_path
                focals = np.loadtxt(focal_path)
                focals = focals[::args.pose_eval_stride]
                original_img_size = cv2.imread(filelist[0]).shape[:2]
                resized_img_size = tuple(imgs[0]['img'].shape[-2:])
                focals = focals * max(
                    (resized_img_size[0] / original_img_size[0]),
                    (resized_img_size[1] / original_img_size[1])
                )
                scene.preset_focal(focals, requires_grad=False)
            lr = 0.01
            loss = scene.compute_global_alignment(
                init='mst', niter=args.n_iter, schedule=args.pose_schedule, lr=lr,
            )
        else:
            # Two views only: closed-form pairwise viewer, no optimization.
            mode = GlobalAlignerMode.PairViewer
            scene = global_aligner(output, device=device, mode=mode, verbose=not silent)

    # Persist every prediction directly under save_dir (flat layout).
    os.makedirs(f'{save_dir}', exist_ok=True)
    scene.clean_pointcloud()
    scene.save_tum_poses(f'{save_dir}/pred_traj.txt')
    scene.save_focals(f'{save_dir}/pred_focal.txt')
    scene.save_intrinsics(f'{save_dir}/pred_intrinsics.txt')
    scene.save_depth_maps(f'{save_dir}')
    scene.save_dynamic_masks(f'{save_dir}')
    scene.save_dyna_maps(f'{save_dir}')
    scene.save_conf_maps(f'{save_dir}')
    scene.save_init_conf_maps(f'{save_dir}')
    scene.save_rgb_imgs(f'{save_dir}')