File size: 2,442 Bytes
4f6b78d
 
 
 
 
 
 
 
 
 
 
 
60fd7ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f6b78d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60fd7ba
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# --------------------------------------------------------
# training and evaluation executable for DUSt3R
# --------------------------------------------------------
from dust3r.training import get_args_parser, train, load_model
from dust3r.pose_eval import eval_pose_estimation, pose_estimation_custom
from dust3r.depth_eval import eval_mono_depth_estimation
import croco.utils.misc as misc  # noqa
import torch
import torch.backends.cudnn as cudnn
import numpy as np
import os

def main(pretrained, dir_path, output_dir, use_pred_mask, n_iter, *, argv=None):
    """Run custom pose estimation with a DUSt3R model.

    All remaining settings (seed, cudnn benchmark flag, model options, ...)
    come from the project's argument parser; the explicit arguments below
    overwrite the corresponding fields of the parsed namespace.

    Args:
        pretrained: value stored in ``args.pretrained`` before the model is
            loaded (presumably a checkpoint path/name -- confirm in
            ``load_model``).
        dir_path: value stored in ``args.dir_path``.
        output_dir: value stored in ``args.output_dir``; the directory is
            created if missing and passed as ``save_dir`` to
            ``pose_estimation_custom``.
        use_pred_mask: value stored in ``args.use_pred_mask``.
        n_iter: value stored in ``args.n_iter``.
        argv: optional list of command-line tokens for the argument parser.
            ``None`` (default) parses ``sys.argv[1:]`` exactly as before;
            pass ``[]`` to use only the parser defaults when calling this
            function from other code.
    """
    args = get_args_parser().parse_args(argv)
    args.pretrained = pretrained
    args.dir_path = dir_path
    args.output_dir = output_dir
    args.use_pred_mask = use_pred_mask
    args.n_iter = n_iter

    misc.init_distributed_mode(args)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Fix the seed; offset by the process rank so each distributed worker
    # gets a distinct but reproducible random stream.
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = args.cudnn_benchmark

    model, _ = load_model(args, device)
    os.makedirs(args.output_dir, exist_ok=True)

    pose_estimation_custom(args, model, device, save_dir=args.output_dir)



if __name__ == '__main__':
    # Script entry point: "eval*" modes run one evaluation and exit;
    # any other mode is handed to train().
    import sys  # only needed by the script entry point

    parser = get_args_parser()
    args = parser.parse_args()
    if args.mode.startswith('eval'):
        # Reject unknown eval modes before doing any expensive setup
        # (distributed init, model loading); previously such a typo loaded
        # the model and then exited with status 0 having done nothing.
        if args.mode not in ('eval_pose', 'eval_pose_custom', 'eval_depth'):
            parser.error(f"unknown eval mode: {args.mode!r}")

        misc.init_distributed_mode(args)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Fix the seed; offset by the process rank so each distributed
        # worker gets a distinct but reproducible random stream.
        seed = args.seed + misc.get_rank()
        torch.manual_seed(seed)
        np.random.seed(seed)
        cudnn.benchmark = args.cudnn_benchmark
        model, _ = load_model(args, device)
        os.makedirs(args.output_dir, exist_ok=True)

        if args.mode == 'eval_pose':
            # The last two returned values (output file list, error flag)
            # are not used here.
            ate_mean, rpe_trans_mean, rpe_rot_mean, _, _ = eval_pose_estimation(
                args, model, device, save_dir=args.output_dir)
            print(f'ATE mean: {ate_mean}, RPE trans mean: {rpe_trans_mean}, RPE rot mean: {rpe_rot_mean}')
        elif args.mode == 'eval_pose_custom':
            pose_estimation_custom(args, model, device, save_dir=args.output_dir)
        elif args.mode == 'eval_depth':
            eval_mono_depth_estimation(args, model, device)

        sys.exit(0)
    train(args)