File size: 3,528 Bytes
35e2575
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#!/usr/bin/env python3
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# make pairs using mast3r scene_graph, including retrieval
# --------------------------------------------------------
import argparse
import torch
import os
import os.path as path
import PIL
from PIL import Image
import pathlib
from kapture.io.csv import table_to_file

from mast3r.model import AsymmetricMASt3R
from mast3r.retrieval.processor import Retriever
from mast3r.image_pairs import make_pairs


def get_argparser():
    """Build the command-line parser for the pair-making script.

    Options:
        --dir: image directory, scanned recursively for readable images.
        --scene_graph: scene-graph spec forwarded to ``make_pairs``. Specs that
            contain ``retrieval`` additionally need a MASt3R backbone
            (``--weights`` or ``--model_name``) and ``--retrieval_model``.
        --output: text file the pairs are written to.
        --weights / --model_name: mutually exclusive ways to select the
            MASt3R backbone (local checkpoint path vs. hub model name).
        --retrieval_model: retrieval checkpoint handed to ``Retriever``.
        --device: torch device the backbone is moved to.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(
        description='make image pairs using a mast3r scene graph, optionally with retrieval')
    parser.add_argument('--dir', required=True, help='image dir')
    parser.add_argument('--scene_graph', default='retrieval-20-1-10-1')
    parser.add_argument('--output', required=True, help='txt file')

    parser_weights = parser.add_mutually_exclusive_group(required=False)
    parser_weights.add_argument("--weights", type=str, help="path to the model weights", default=None)
    parser_weights.add_argument("--model_name", type=str, help="name of the model weights",
                                choices=["MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"])
    parser.add_argument('--retrieval_model', default=None, type=str, help="retrieval_model to be loaded")
    # The entry point moves the backbone to ``args.device``; without this option
    # that attribute does not exist and retrieval runs fail with AttributeError.
    parser.add_argument('--device', default='cuda', type=str,
                        help="torch device for the backbone (e.g. cuda, cpu)")

    return parser


def get_image_list(images_path):
    """Collect the images PIL can open under ``images_path``.

    Every file below ``images_path`` (walked recursively) is probed with
    ``Image.open``; files PIL cannot open or identify are reported on stdout
    and left out.

    Args:
        images_path: root directory to scan.

    Returns:
        list[str]: accepted file paths, relative to ``images_path``, in sorted order.
    """
    candidates = sorted(
        path.relpath(path.join(root, name), images_path)
        for root, _subdirs, names in os.walk(images_path)
        for name in names
    )

    accepted = []
    for rel_name in candidates:
        try:
            # Image.open is lazy: it identifies the format without decoding
            # pixel data, so it serves as a cheap validity probe.
            with Image.open(path.join(images_path, rel_name)):
                accepted.append(rel_name)
        except (OSError, PIL.UnidentifiedImageError):
            print(f'Skipping invalid image file {rel_name}')
    return accepted


def main(dir, scene_graph, output, backbone=None, retrieval_model=None):
    """Build image pairs for the images under ``dir`` and write them to ``output``.

    When ``scene_graph`` contains ``'retrieval'``, a ``Retriever`` built from
    ``retrieval_model`` and ``backbone`` computes a similarity matrix over the
    images (under ``torch.no_grad``) that is passed to ``make_pairs`` as
    ``sim_mat``; otherwise ``sim_mat`` is ``None``.

    Args:
        dir: image directory, listed with ``get_image_list``.
        scene_graph: scene-graph spec forwarded to ``make_pairs``.
        output: path of the text file to write; missing parent directories are created.
        backbone: model forwarded to ``Retriever`` (only used for retrieval graphs).
        retrieval_model: checkpoint forwarded to ``Retriever`` (only used for retrieval graphs).

    Side effects:
        Sets the process umask to 0o002 before creating directories and the output file.
    """
    image_names = get_image_list(dir)

    similarity = None
    if 'retrieval' in scene_graph:
        retriever = Retriever(retrieval_model, backbone=backbone)
        full_paths = [path.join(dir, name) for name in image_names]
        with torch.no_grad():
            similarity = retriever(full_paths)
        # Drop the retriever and return cached, unused GPU memory before pairing.
        del retriever
        torch.cuda.empty_cache()

    pairs = make_pairs(image_names, scene_graph, prefilter=None, symmetrize=True, sim_mat=similarity)

    # umask 0o002 keeps created directories/files group-writable (process-wide change).
    os.umask(0o002)
    pathlib.Path(output).parent.resolve().mkdir(parents=True, exist_ok=True)

    # NOTE(review): the header advertises a score column; whether each pair row
    # carries one depends on make_pairs' return format — confirm.
    with open(output, 'w') as out_file:
        table_to_file(out_file, pairs, header='# query_image, map_image, score')


if __name__ == '__main__':
    parser = get_argparser()
    args = parser.parse_args()

    backbone = None
    retrieval_model = None
    if "retrieval" in args.scene_graph:
        # Retrieval scene graphs need a retrieval checkpoint plus a MASt3R backbone.
        # Validate with parser.error (exits with a usage message) instead of
        # assert, which is stripped under ``python -O``.
        if args.retrieval_model is None:
            parser.error('--retrieval_model is required when --scene_graph uses retrieval')
        if args.weights is not None:
            weights_path = args.weights
        elif args.model_name is not None:
            weights_path = "naver/" + args.model_name
        else:
            parser.error('--weights or --model_name is required when --scene_graph uses retrieval')
        # getattr keeps this working even if the parser defines no --device option.
        device = getattr(args, 'device', 'cuda')
        backbone = AsymmetricMASt3R.from_pretrained(weights_path).to(device)
        retrieval_model = args.retrieval_model
    main(args.dir, args.scene_graph, args.output, backbone, retrieval_model)