#!/usr/bin/env python3
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# make pairs using mast3r scene_graph, including retrieval
# --------------------------------------------------------
import argparse
import torch
import os
import os.path as path
import PIL
from PIL import Image
import pathlib
from kapture.io.csv import table_to_file

from mast3r.model import AsymmetricMASt3R
from mast3r.retrieval.processor import Retriever
from mast3r.image_pairs import make_pairs


def get_argparser():
    """Build the command-line parser for this pair-making script.

    Options:
        --dir             image directory (searched recursively).
        --scene_graph     scene-graph spec forwarded to ``make_pairs``; a spec
                          containing ``retrieval`` additionally requires
                          ``--retrieval_model`` and one of ``--weights`` /
                          ``--model_name``.
        --output          path of the text file the pairs are written to.
        --weights         local path of the MASt3R backbone weights.
        --model_name      backbone name, loaded as ``naver/<model_name>``
                          (mutually exclusive with ``--weights``).
        --retrieval_model retrieval model handed to ``Retriever``.
        --device          torch device the backbone is moved to.
    """
    parser = argparse.ArgumentParser(description='make image pairs with mast3r scene_graph, including retrieval')
    parser.add_argument('--dir', required=True, help='image dir')
    parser.add_argument('--scene_graph', default='retrieval-20-1-10-1')
    parser.add_argument('--output', required=True, help='txt file')

    parser_weights = parser.add_mutually_exclusive_group(required=False)
    parser_weights.add_argument("--weights", type=str, help="path to the model weights", default=None)
    parser_weights.add_argument("--model_name", type=str, help="name of the model weights",
                                choices=["MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"])

    parser.add_argument('--retrieval_model', default=None, type=str, help="retrieval_model to be loaded")
    # The entry point moves the backbone to `args.device`; without this option
    # every retrieval run crashed with AttributeError.
    parser.add_argument('--device', type=str,
                        default='cuda' if torch.cuda.is_available() else 'cpu',
                        help="torch device used for the MASt3R backbone")
    return parser


def get_image_list(images_path):
    """Return the files under ``images_path`` that PIL can open as images.

    The directory is walked recursively. Each candidate is opened with
    ``Image.open``; files it rejects are reported on stdout and skipped.

    Args:
        images_path: root directory to search.

    Returns:
        Sorted list of image paths, relative to ``images_path``.
    """
    file_list = sorted(
        path.relpath(path.join(dirpath, filename), images_path)
        for dirpath, _dirs, filenames in os.walk(images_path)
        for filename in filenames
    )

    image_list = []
    for filename in file_list:
        try:
            # Image.open is lazy: it identifies the format without decoding
            # pixel data, which is all we need to validate the file.
            with Image.open(path.join(images_path, filename)):
                pass
        except (OSError, PIL.UnidentifiedImageError):
            # Not a valid image: skip it.
            print(f'Skipping invalid image file {filename}')
        else:
            image_list.append(filename)
    return image_list


def main(dir, scene_graph, output, backbone=None, retrieval_model=None):
    """Compute image pairs for the images in ``dir`` and write them to ``output``.

    Args:
        dir: image directory, scanned with ``get_image_list``.
        scene_graph: scene-graph spec forwarded to ``make_pairs``. When it
            contains ``retrieval``, a similarity matrix is computed with
            ``Retriever`` and passed to ``make_pairs`` as ``sim_mat``.
        output: path of the text file to write; its parent directory is
            created if missing.
        backbone: MASt3R model handed to ``Retriever`` (retrieval graphs only).
        retrieval_model: retrieval model handed to ``Retriever`` (retrieval
            graphs only).

    Side effects:
        Sets the process-wide umask to 0o002 (not restored) before creating
        the output directory and file.
    """
    imgs = get_image_list(dir)

    sim_matrix = None
    if 'retrieval' in scene_graph:
        retriever = Retriever(retrieval_model, backbone=backbone)
        imgs_fp = [path.join(dir, filename) for filename in imgs]
        with torch.no_grad():
            sim_matrix = retriever(imgs_fp)

        # Release the retriever and cached GPU memory before pairing.
        del retriever
        torch.cuda.empty_cache()

    pairs = make_pairs(imgs, scene_graph, prefilter=None, symmetrize=True, sim_mat=sim_matrix)

    # Keep created directories/files group-writable.
    os.umask(0o002)
    p = pathlib.Path(output)
    os.makedirs(str(p.parent.resolve()), exist_ok=True)
    with open(output, 'w') as fid:
        # NOTE(review): rows of `pairs` are presumably (query, map, score),
        # matching this header; confirm against make_pairs.
        table_to_file(fid, pairs, header='# query_image, map_image, score')


if __name__ == '__main__':
    parser = get_argparser()
    args = parser.parse_args()

    if "retrieval" in args.scene_graph:
        # Retrieval graphs need a retrieval model and a MASt3R backbone.
        # Use parser.error rather than assert: asserts are stripped under
        # `python -O` and give the user no usage message.
        if args.retrieval_model is None:
            parser.error('--retrieval_model is required when --scene_graph uses retrieval')
        if args.weights is None and args.model_name is None:
            parser.error('--weights or --model_name is required when --scene_graph uses retrieval')

        if args.weights is not None:
            weights_path = args.weights
        else:
            weights_path = "naver/" + args.model_name
        backbone = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device)
        retrieval_model = args.retrieval_model
    else:
        backbone = None
        retrieval_model = None

    main(args.dir, args.scene_graph, args.output, backbone, retrieval_model)