mast3r-sfm / mast3r /retrieval /processor.py
yocabon's picture
add initial version of mast3r sfm and glomap/colmap wrapper
35e2575
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Main Retriever class
# --------------------------------------------------------
import os
import argparse
import numpy as np
import torch
from mast3r.model import AsymmetricMASt3R
from mast3r.retrieval.model import RetrievalModel, extract_local_features
try:
    import faiss
    # Probe for GPU support up front: loading the checkpoint will try to instantiate FaissGpuL2Index.
    faiss.StandardGpuResources()
except AttributeError:
    # This faiss build does not expose StandardGpuResources, so the name asmk uses
    # for its GPU index is rebound to a flat CPU index defined below.
    import asmk.index

    class FaissCpuL2Index(asmk.index.FaissL2Index):
        """CPU replacement installed under the name asmk.index.FaissGpuL2Index."""

        def __init__(self, gpu_id):
            # gpu_id is only stored; presumably it is accepted so that callers can keep
            # constructing the index with the GPU index's arguments -- TODO confirm against asmk.
            super().__init__()
            self.gpu_id = gpu_id

        def _faiss_index_flat(self, dim):
            """Return an initialized faiss.IndexFlatL2 of dimension `dim`."""
            return faiss.IndexFlatL2(dim)

    asmk.index.FaissGpuL2Index = FaissCpuL2Index
from asmk import asmk_method # noqa
def get_args_parser():
    """Build the command-line parser for the retrieval script.

    The parser has no automatic ``-h/--help`` option and does not accept abbreviated
    option names. It declares three required string options: ``--model``, ``--input``
    and ``--outfile``.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    # (flag, help text) pairs; every option is a required string
    option_specs = (
        ('--model', "shortname of a retrieval model or path to the corresponding .pth"),
        ('--input', "directory containing images or a file containing a list of image paths"),
        ('--outfile', "numpy file where to store the matrix score"),
    )
    cli = argparse.ArgumentParser(prog='Retrieval scores from a set of retrieval',
                                  add_help=False, allow_abbrev=False)
    for flag, help_text in option_specs:
        cli.add_argument(flag, type=str, required=True, help=help_text)
    return cli
def get_impaths(imlistfile):
    """Read image paths from a text file containing one path per line.

    Lines starting with '#' (comments) and empty lines are skipped; every other line is
    returned unchanged, in file order.

    Args:
        imlistfile: path to the list file.

    Returns:
        list of str: the image paths listed in the file.
    """
    with open(imlistfile, 'r') as fid:
        # read from the open handle: `imlistfile` is the path string and has no read()
        lines = fid.read().splitlines()
    # ignore comments and empty lines
    return [line for line in lines if line and not line.startswith('#')]
def get_impaths_from_imdir(imdir, extensions=('png', 'jpg', 'PNG', 'JPG')):
    """List the image files found directly inside a directory.

    Args:
        imdir: path to an existing directory.
        extensions: iterable of filename suffixes to keep. A file is kept when its name
            ends with any of them (plain suffix match, no '.' is added).

    Returns:
        list of str: `imdir` joined with each matching filename, sorted by filename.

    Raises:
        AssertionError: if `imdir` is not a directory.
    """
    assert os.path.isdir(imdir), imdir
    # str.endswith accepts a tuple of suffixes, so one call per file replaces the inner loop
    suffixes = tuple(extensions)
    return [os.path.join(imdir, fname) for fname in sorted(os.listdir(imdir)) if fname.endswith(suffixes)]
def get_impaths_from_imdir_or_imlistfile(input_imdir_or_imlistfile):
    """Return image paths from either a list file or an image directory.

    When the argument is an existing regular file it is read with `get_impaths`;
    anything else is handed to `get_impaths_from_imdir`.

    Args:
        input_imdir_or_imlistfile: path to a list file or to a directory of images.

    Returns:
        list of str: the image paths.
    """
    source = input_imdir_or_imlistfile
    reader = get_impaths if os.path.isfile(source) else get_impaths_from_imdir
    return reader(source)
class Retriever(object):
    """Compute an all-pairs retrieval score matrix for a set of images.

    A RetrievalModel is restored from a checkpoint and an ASMK method is set up from a
    cached codebook stored next to it. Calling the instance extracts local features from
    the images, indexes them with ASMK, and queries the index with the very same images.
    """

    def __init__(self, modelname, backbone=None, device='cuda'):
        """Load the retrieval model and the ASMK codebook.

        Args:
            modelname: path to the retrieval checkpoint; it must be an existing file. The
                codebook is looked up in the same directory, under the checkpoint basename
                with its last '_'-separated token replaced by 'codebook.pkl'.
            backbone: backbone given to RetrievalModel. When None, it is created with
                AsymmetricMASt3R.from_pretrained using the `pretrained` value stored in
                the checkpoint arguments.
            device: device the model is moved to and on which features are extracted.

        Raises:
            AssertionError: if the checkpoint or codebook file is missing, if non-backbone
                weights are missing from the checkpoint, or if it holds unexpected keys.
        """
        # load the model
        assert os.path.isfile(modelname), modelname
        print(f'Loading retrieval model from {modelname}')
        # NOTE(review): recent PyTorch releases default torch.load to weights_only=True, which may
        # refuse the pickled `args` object stored in this checkpoint -- confirm against the
        # torch versions this project supports.
        ckpt = torch.load(modelname, 'cpu')  # TODO from pretrained to download it automatically
        ckpt_args = ckpt['args']
        if backbone is None:
            backbone = AsymmetricMASt3R.from_pretrained(ckpt_args.pretrained)
        # hdims is stored as a '_'-separated string of ints; an empty string is passed on as ""
        hdims = list(map(int, ckpt_args.hdims.split('_'))) if len(ckpt_args.hdims) > 0 else ""
        self.model = RetrievalModel(
            backbone, freeze_backbone=ckpt_args.freeze_backbone, prewhiten=ckpt_args.prewhiten,
            hdims=hdims,
            residual=getattr(ckpt_args, 'residual', False), postwhiten=ckpt_args.postwhiten,
            featweights=ckpt_args.featweights, nfeat=ckpt_args.nfeat
        ).to(device)
        self.device = device
        # strict=False: only keys under 'backbone' may be absent from the checkpoint
        msg = self.model.load_state_dict(ckpt['model'], strict=False)
        assert all(k.startswith('backbone') for k in msg.missing_keys)
        assert len(msg.unexpected_keys) == 0
        self.imsize = ckpt_args.imsize

        # load the asmk codebook
        dname, bname = os.path.split(modelname)  # TODO they should both be in the same file ?
        bname_splits = bname.split('_')
        cache_codebook_fname = os.path.join(dname, '_'.join(bname_splits[:-1]) + '_codebook.pkl')
        assert os.path.isfile(cache_codebook_fname), cache_codebook_fname
        asmk_params = {'index': {'gpu_id': 0}, 'train_codebook': {'codebook': {'size': '64k'}},
                       'build_ivf': {'kernel': {'binary': True}, 'ivf': {'use_idf': False},
                                     'quantize': {'multiple_assignment': 1}, 'aggregate': {}},
                       'query_ivf': {'quantize': {'multiple_assignment': 5}, 'aggregate': {},
                                     'search': {'topk': None},
                                     'similarity': {'similarity_threshold': 0.0, 'alpha': 3.0}}}
        # the codebook size recorded in the checkpoint overrides the placeholder above
        asmk_params['train_codebook']['codebook']['size'] = ckpt_args.nclusters
        self.asmk = asmk_method.ASMKMethod.initialize_untrained(asmk_params)
        # no training data is passed (None); the codebook is expected to come from the cache file
        self.asmk = self.asmk.train_codebook(None, cache_path=cache_codebook_fname)

    def __call__(self, input_imdir_or_imlistfile, outfile=None):
        """Score every image against every image of the same set.

        Args:
            input_imdir_or_imlistfile: either a string (image directory, or file listing
                image paths) or an already-built list of image paths.
            outfile: optional path where the score matrix is written with np.save; its
                parent directory is created when missing.

        Returns:
            numpy array with the shape of the ranked scores returned by ASMK, re-indexed so
            that each row's entries sit at the positions given by that row's ranks.
        """
        # get impaths
        if isinstance(input_imdir_or_imlistfile, str):
            impaths = get_impaths_from_imdir_or_imlistfile(input_imdir_or_imlistfile)
        else:
            impaths = input_imdir_or_imlistfile  # we're assuming a list has been passed
        print(f'Found {len(impaths)} images')

        # build the database
        feat, ids = extract_local_features(self.model, impaths, self.imsize, tocpu=True, device=self.device)
        feat = feat.cpu().numpy()
        ids = ids.cpu().numpy()
        asmk_dataset = self.asmk.build_ivf(feat, ids)

        # we actually retrieve the same set of images
        _, _, ranks, ranked_scores = asmk_dataset.query_ivf(feat, ids)

        # ranked_scores[i, j] is the score of entry ranks[i, j] for query i;
        # scatter the values back so that columns follow the original ordering
        scores = np.empty_like(ranked_scores)
        scores[np.arange(ranked_scores.shape[0])[:, None], ranks] = ranked_scores

        # save
        if outfile is not None:
            self._save_scores(scores, outfile)
        return scores

    @staticmethod
    def _save_scores(scores, outfile):
        """Write `scores` to `outfile` with np.save, creating the parent directory if needed."""
        outdir = os.path.dirname(outfile)
        # a bare filename has an empty directory part, which os.makedirs would reject
        if outdir:
            os.makedirs(outdir, exist_ok=True)
        np.save(outfile, scores)
        print(f'Scores matrix saved in {outfile}')