Spaces:
Configuration error
Configuration error
# Copyright (C) 2024-present Naver Corporation. All rights reserved. | |
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). | |
# | |
# -------------------------------------------------------- | |
# Main Retriever class | |
# -------------------------------------------------------- | |
import os | |
import argparse | |
import numpy as np | |
import torch | |
from mast3r.model import AsymmetricMASt3R | |
from mast3r.retrieval.model import RetrievalModel, extract_local_features | |
try: | |
import faiss | |
faiss.StandardGpuResources() # when loading the checkpoint, it will try to instanciate FaissGpuL2Index | |
except AttributeError as e: | |
import asmk.index | |
class FaissCpuL2Index(asmk.index.FaissL2Index): | |
def __init__(self, gpu_id): | |
super().__init__() | |
self.gpu_id = gpu_id | |
def _faiss_index_flat(self, dim): | |
"""Return initialized faiss.IndexFlatL2""" | |
return faiss.IndexFlatL2(dim) | |
asmk.index.FaissGpuL2Index = FaissCpuL2Index | |
from asmk import asmk_method # noqa | |
def get_args_parser(): | |
parser = argparse.ArgumentParser('Retrieval scores from a set of retrieval', add_help=False, allow_abbrev=False) | |
parser.add_argument('--model', type=str, required=True, | |
help="shortname of a retrieval model or path to the corresponding .pth") | |
parser.add_argument('--input', type=str, required=True, | |
help="directory containing images or a file containing a list of image paths") | |
parser.add_argument('--outfile', type=str, required=True, help="numpy file where to store the matrix score") | |
return parser | |
def get_impaths(imlistfile): | |
with open(imlistfile, 'r') as fid: | |
impaths = [f for f in imlistfile.read().splitlines() if not f.startswith('#') | |
and len(f) > 0] # ignore comments and empty lines | |
return impaths | |
def get_impaths_from_imdir(imdir, extensions=['png', 'jpg', 'PNG', 'JPG']): | |
assert os.path.isdir(imdir) | |
impaths = [os.path.join(imdir, f) for f in sorted(os.listdir(imdir)) if any(f.endswith(ext) for ext in extensions)] | |
return impaths | |
def get_impaths_from_imdir_or_imlistfile(input_imdir_or_imlistfile): | |
if os.path.isfile(input_imdir_or_imlistfile): | |
return get_impaths(input_imdir_or_imlistfile) | |
else: | |
return get_impaths_from_imdir(input_imdir_or_imlistfile) | |
class Retriever(object): | |
def __init__(self, modelname, backbone=None, device='cuda'): | |
# load the model | |
assert os.path.isfile(modelname), modelname | |
print(f'Loading retrieval model from {modelname}') | |
ckpt = torch.load(modelname, 'cpu') # TODO from pretrained to download it automatically | |
ckpt_args = ckpt['args'] | |
if backbone is None: | |
backbone = AsymmetricMASt3R.from_pretrained(ckpt_args.pretrained) | |
self.model = RetrievalModel( | |
backbone, freeze_backbone=ckpt_args.freeze_backbone, prewhiten=ckpt_args.prewhiten, | |
hdims=list(map(int, ckpt_args.hdims.split('_'))) if len(ckpt_args.hdims) > 0 else "", | |
residual=getattr(ckpt_args, 'residual', False), postwhiten=ckpt_args.postwhiten, | |
featweights=ckpt_args.featweights, nfeat=ckpt_args.nfeat | |
).to(device) | |
self.device = device | |
msg = self.model.load_state_dict(ckpt['model'], strict=False) | |
assert all(k.startswith('backbone') for k in msg.missing_keys) | |
assert len(msg.unexpected_keys) == 0 | |
self.imsize = ckpt_args.imsize | |
# load the asmk codebook | |
dname, bname = os.path.split(modelname) # TODO they should both be in the same file ? | |
bname_splits = bname.split('_') | |
cache_codebook_fname = os.path.join(dname, '_'.join(bname_splits[:-1]) + '_codebook.pkl') | |
assert os.path.isfile(cache_codebook_fname), cache_codebook_fname | |
asmk_params = {'index': {'gpu_id': 0}, 'train_codebook': {'codebook': {'size': '64k'}}, | |
'build_ivf': {'kernel': {'binary': True}, 'ivf': {'use_idf': False}, | |
'quantize': {'multiple_assignment': 1}, 'aggregate': {}}, | |
'query_ivf': {'quantize': {'multiple_assignment': 5}, 'aggregate': {}, | |
'search': {'topk': None}, | |
'similarity': {'similarity_threshold': 0.0, 'alpha': 3.0}}} | |
asmk_params['train_codebook']['codebook']['size'] = ckpt_args.nclusters | |
self.asmk = asmk_method.ASMKMethod.initialize_untrained(asmk_params) | |
self.asmk = self.asmk.train_codebook(None, cache_path=cache_codebook_fname) | |
def __call__(self, input_imdir_or_imlistfile, outfile=None): | |
# get impaths | |
if isinstance(input_imdir_or_imlistfile, str): | |
impaths = get_impaths_from_imdir_or_imlistfile(input_imdir_or_imlistfile) | |
else: | |
impaths = input_imdir_or_imlistfile # we're assuming a list has been passed | |
print(f'Found {len(impaths)} images') | |
# build the database | |
feat, ids = extract_local_features(self.model, impaths, self.imsize, tocpu=True, device=self.device) | |
feat = feat.cpu().numpy() | |
ids = ids.cpu().numpy() | |
asmk_dataset = self.asmk.build_ivf(feat, ids) | |
# we actually retrieve the same set of images | |
metadata, query_ids, ranks, ranked_scores = asmk_dataset.query_ivf(feat, ids) | |
# well ... scores are actually reordered according to ranks ... | |
# so we redo it the other way around... | |
scores = np.empty_like(ranked_scores) | |
scores[np.arange(ranked_scores.shape[0])[:, None], ranks] = ranked_scores | |
# save | |
if outfile is not None: | |
if os.path.isdir(os.path.dirname(outfile)): | |
os.makedirs(os.path.dirname(outfile), exist_ok=True) | |
np.save(outfile, scores) | |
print(f'Scores matrix saved in {outfile}') | |
return scores | |