mast3r-sfm / make_pairs.py
yocabon's picture
add initial version of mast3r sfm and glomap/colmap wrapper
35e2575
#!/usr/bin/env python3
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# make pairs using mast3r scene_graph, including retrieval
# --------------------------------------------------------
import argparse
import torch
import os
import os.path as path
import PIL
from PIL import Image
import pathlib
from kapture.io.csv import table_to_file
from mast3r.model import AsymmetricMASt3R
from mast3r.retrieval.processor import Retriever
from mast3r.image_pairs import make_pairs
def get_argparser():
parser = argparse.ArgumentParser(description='point triangulator with mast3r from kapture data')
parser.add_argument('--dir', required=True, help='image dir')
parser.add_argument('--scene_graph', default='retrieval-20-1-10-1')
parser.add_argument('--output', required=True, help='txt file')
parser_weights = parser.add_mutually_exclusive_group(required=False)
parser_weights.add_argument("--weights", type=str, help="path to the model weights", default=None)
parser_weights.add_argument("--model_name", type=str, help="name of the model weights",
choices=["MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"])
parser.add_argument('--retrieval_model', default=None, type=str, help="retrieval_model to be loaded")
return parser
def get_image_list(images_path):
    """Return the image files found under ``images_path``.

    The directory is walked recursively. Each file is probed with PIL and kept
    only if PIL can open it; other files are reported and skipped.

    Args:
        images_path: root directory to scan.

    Returns:
        list[str]: sorted paths of the valid images, relative to ``images_path``.
    """
    candidates = sorted(
        path.relpath(path.join(dirpath, filename), images_path)
        for dirpath, _dirnames, filenames in os.walk(images_path)
        for filename in filenames
    )
    image_list = []
    for filename in candidates:
        try:
            # Opening is lazy: it is enough to identify the file as an image
            # without decoding pixel data.
            with Image.open(path.join(images_path, filename)):
                pass
        except (OSError, PIL.UnidentifiedImageError):
            # Not a readable image: report it and skip it.
            print(f'Skipping invalid image file {filename}')
            continue
        else:
            image_list.append(filename)
    return image_list
def main(dir, scene_graph, output, backbone=None, retrieval_model=None):
    """Compute image pairs for the images under ``dir`` and write them to ``output``.

    Args:
        dir: root image directory, scanned recursively by get_image_list.
        scene_graph: mast3r scene-graph specification forwarded to make_pairs.
            When it contains 'retrieval', a similarity matrix is first
            computed with a Retriever and passed to make_pairs.
        output: path of the text file to write; missing parent directories
            are created.
        backbone: MASt3R model handed to the Retriever (retrieval mode only).
        retrieval_model: retrieval checkpoint handed to the Retriever;
            required when ``scene_graph`` contains 'retrieval'.

    Raises:
        ValueError: if a retrieval scene graph is requested without a
            ``retrieval_model``.
    """
    use_retrieval = 'retrieval' in scene_graph
    # Fail fast with a clear message instead of an opaque error from inside
    # the Retriever after the image directory has already been scanned.
    if use_retrieval and retrieval_model is None:
        raise ValueError(f"scene_graph '{scene_graph}' requires a retrieval_model")

    imgs = get_image_list(dir)

    sim_matrix = None
    if use_retrieval:
        retriever = Retriever(retrieval_model, backbone=backbone)
        imgs_fp = [path.join(dir, filename) for filename in imgs]
        with torch.no_grad():
            sim_matrix = retriever(imgs_fp)

        # Release the retriever and cached GPU memory before pair generation.
        del retriever
        torch.cuda.empty_cache()

    pairs = make_pairs(imgs, scene_graph, prefilter=None, symmetrize=True, sim_mat=sim_matrix)

    # umask 0o002: files/dirs created from here on are not world-writable.
    # Note this changes the umask for the whole process.
    os.umask(0o002)
    out_path = pathlib.Path(output)
    os.makedirs(str(out_path.parent.resolve()), exist_ok=True)
    with open(output, 'w') as fid:
        table_to_file(fid, pairs, header='# query_image, map_image, score')
if __name__ == '__main__':
    parser = get_argparser()
    args = parser.parse_args()

    if "retrieval" in args.scene_graph:
        # Validate CLI input with parser.error rather than assert: asserts are
        # stripped under `python -O`, while parser.error prints the usage and
        # exits with status 2.
        if args.retrieval_model is None:
            parser.error(f'--retrieval_model is required when --scene_graph is "{args.scene_graph}"')
        if args.weights is not None:
            weights_path = args.weights
        elif args.model_name is not None:
            weights_path = "naver/" + args.model_name
        else:
            parser.error('--weights or --model_name is required for retrieval scene graphs')
        backbone = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device)
        retrieval_model = args.retrieval_model
    else:
        # Non-retrieval scene graphs need neither the backbone nor a retriever.
        backbone = None
        retrieval_model = None

    main(args.dir, args.scene_graph, args.output, backbone, retrieval_model)