Spaces:
Configuration error
Configuration error
File size: 5,912 Bytes
35e2575 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Main Retriever class
# --------------------------------------------------------
import os
import argparse
import numpy as np
import torch
from mast3r.model import AsymmetricMASt3R
from mast3r.retrieval.model import RetrievalModel, extract_local_features
try:
    import faiss
    # when loading the checkpoint, asmk will try to instantiate FaissGpuL2Index;
    # probe for GPU support first (presumably StandardGpuResources is absent on
    # CPU-only faiss builds, hence the AttributeError handled below)
    faiss.StandardGpuResources()
except AttributeError:
    import asmk.index

    class FaissCpuL2Index(asmk.index.FaissL2Index):
        """CPU stand-in for ``asmk.index.FaissGpuL2Index`` backed by ``faiss.IndexFlatL2``."""

        def __init__(self, gpu_id):
            # gpu_id is stored but never used for indexing; presumably kept so the
            # constructor accepts the same arguments as the GPU class it replaces
            super().__init__()
            self.gpu_id = gpu_id

        def _faiss_index_flat(self, dim):
            """Return initialized faiss.IndexFlatL2"""
            return faiss.IndexFlatL2(dim)

    # monkey-patch: any later request for the GPU index class gets the CPU one
    asmk.index.FaissGpuL2Index = FaissCpuL2Index

# imported after the optional patch above so asmk_method sees the patched asmk.index
from asmk import asmk_method  # noqa
def get_args_parser():
    """Build the command-line parser of the retrieval-scores script.

    The parser is created with ``add_help=False`` and ``allow_abbrev=False``
    and declares three required string options: ``--model``, ``--input`` and
    ``--outfile``.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    cli = argparse.ArgumentParser(
        'Retrieval scores from a set of retrieval',
        add_help=False,
        allow_abbrev=False,
    )
    required_options = (
        ('--model', "shortname of a retrieval model or path to the corresponding .pth"),
        ('--input', "directory containing images or a file containing a list of image paths"),
        ('--outfile', "numpy file where to store the matrix score"),
    )
    for flag, description in required_options:
        cli.add_argument(flag, type=str, required=True, help=description)
    return cli
def get_impaths(imlistfile):
    """Read image paths from a text file containing one path per line.

    Lines starting with ``#`` (comments) and empty lines are ignored.

    Args:
        imlistfile: path of the text file listing the images.

    Returns:
        list of image paths (str), in file order.
    """
    with open(imlistfile, 'r') as fid:
        # read from the open handle: ``imlistfile`` is the path string, not a file object
        lines = fid.read().splitlines()
    return [line for line in lines if line and not line.startswith('#')]
def get_impaths_from_imdir(imdir, extensions=('png', 'jpg', 'PNG', 'JPG')):
    """List the image files of a directory (non-recursive), sorted by file name.

    Args:
        imdir: directory to scan.
        extensions: iterable of filename suffixes to keep (case-sensitive match
            against the end of the file name).

    Returns:
        list of ``os.path.join(imdir, name)`` paths for the matching entries.

    Raises:
        AssertionError: if ``imdir`` is not a directory.
    """
    assert os.path.isdir(imdir), imdir
    # str.endswith accepts a tuple of suffixes; an empty tuple matches nothing
    suffixes = tuple(extensions)
    return [os.path.join(imdir, fname) for fname in sorted(os.listdir(imdir)) if fname.endswith(suffixes)]
def get_impaths_from_imdir_or_imlistfile(input_imdir_or_imlistfile):
    """Resolve a list of image paths from either an image-list file or a directory.

    An existing regular file is parsed with ``get_impaths``; any other path is
    handed to ``get_impaths_from_imdir`` (which asserts it is a directory).
    """
    is_list_file = os.path.isfile(input_imdir_or_imlistfile)
    collect = get_impaths if is_list_file else get_impaths_from_imdir
    return collect(input_imdir_or_imlistfile)
class Retriever(object):
    """Compute all-pairs retrieval scores for a set of images with an ASMK model.

    The retrieval head is loaded from a ``.pth`` checkpoint; the ASMK codebook
    is loaded from a ``*_codebook.pkl`` file located next to that checkpoint.
    """

    def __init__(self, modelname, backbone=None, device='cuda'):
        """Load the retrieval model and its ASMK codebook.

        Args:
            modelname: path of the retrieval checkpoint. The codebook file must sit
                in the same directory, named after the checkpoint basename with its
                last ``_``-separated token replaced by ``codebook.pkl``.
            backbone: optional pre-built backbone; when None it is built with
                ``AsymmetricMASt3R.from_pretrained(ckpt_args.pretrained)``.
            device: torch device the retrieval model is moved to.

        Raises:
            AssertionError: if the checkpoint or codebook file is missing, if the
                checkpoint contains unexpected keys, or if it lacks non-backbone keys.
        """
        # load the model
        assert os.path.isfile(modelname), modelname
        print(f'Loading retrieval model from {modelname}')
        # NOTE(review): torch.load unpickles arbitrary objects (ckpt['args'] is used via
        # attribute access, not as a tensor) -- only load trusted checkpoint files.
        ckpt = torch.load(modelname, 'cpu')  # TODO from pretrained to download it automatically
        ckpt_args = ckpt['args']
        if backbone is None:
            backbone = AsymmetricMASt3R.from_pretrained(ckpt_args.pretrained)
        # hdims is stored as '_'-separated integers; an empty string means no hidden dims
        hdims = list(map(int, ckpt_args.hdims.split('_'))) if len(ckpt_args.hdims) > 0 else ""
        self.model = RetrievalModel(
            backbone, freeze_backbone=ckpt_args.freeze_backbone, prewhiten=ckpt_args.prewhiten,
            hdims=hdims,
            residual=getattr(ckpt_args, 'residual', False), postwhiten=ckpt_args.postwhiten,
            featweights=ckpt_args.featweights, nfeat=ckpt_args.nfeat
        ).to(device)
        self.device = device
        # strict=False tolerates missing backbone weights only; any other mismatch fails
        msg = self.model.load_state_dict(ckpt['model'], strict=False)
        assert all(k.startswith('backbone') for k in msg.missing_keys)
        assert len(msg.unexpected_keys) == 0
        self.imsize = ckpt_args.imsize

        # load the asmk codebook
        dname, bname = os.path.split(modelname)  # TODO they should both be in the same file ?
        bname_splits = bname.split('_')
        cache_codebook_fname = os.path.join(dname, '_'.join(bname_splits[:-1]) + '_codebook.pkl')
        assert os.path.isfile(cache_codebook_fname), cache_codebook_fname
        asmk_params = {'index': {'gpu_id': 0}, 'train_codebook': {'codebook': {'size': '64k'}},
                       'build_ivf': {'kernel': {'binary': True}, 'ivf': {'use_idf': False},
                                     'quantize': {'multiple_assignment': 1}, 'aggregate': {}},
                       'query_ivf': {'quantize': {'multiple_assignment': 5}, 'aggregate': {},
                                     'search': {'topk': None},
                                     'similarity': {'similarity_threshold': 0.0, 'alpha': 3.0}}}
        # the codebook size comes from the checkpoint, overriding the '64k' placeholder
        asmk_params['train_codebook']['codebook']['size'] = ckpt_args.nclusters
        self.asmk = asmk_method.ASMKMethod.initialize_untrained(asmk_params)
        # no training descriptors are passed (None); presumably asmk loads the codebook
        # from cache_path, which was asserted to exist above -- confirm against asmk
        self.asmk = self.asmk.train_codebook(None, cache_path=cache_codebook_fname)

    def __call__(self, input_imdir_or_imlistfile, outfile=None):
        """Score every image of the set against every image of the same set.

        Args:
            input_imdir_or_imlistfile: image directory or image-list file (str), or
                an already-built list of image paths.
            outfile: optional path where the score matrix is written with
                ``np.save``; missing parent directories are created.

        Returns:
            numpy array ``scores`` with ``scores[i, j]`` the ASMK score of database
            entry ``j`` for query ``i`` (the same images serve as queries and database).
        """
        # get impaths
        if isinstance(input_imdir_or_imlistfile, str):
            impaths = get_impaths_from_imdir_or_imlistfile(input_imdir_or_imlistfile)
        else:
            impaths = input_imdir_or_imlistfile  # we're assuming a list has been passed
        print(f'Found {len(impaths)} images')

        # build the database
        feat, ids = extract_local_features(self.model, impaths, self.imsize, tocpu=True, device=self.device)
        feat = feat.cpu().numpy()
        ids = ids.cpu().numpy()
        asmk_dataset = self.asmk.build_ivf(feat, ids)

        # we actually retrieve the same set of images
        _metadata, _query_ids, ranks, ranked_scores = asmk_dataset.query_ivf(feat, ids)

        # ranked_scores[i, r] is the score of database entry ranks[i, r]; scatter it back
        # to scores[i, ranks[i, r]] so columns follow database order instead of rank order.
        # NOTE(review): assumes each row of `ranks` covers every database entry (topk=None
        # in __init__), so np.empty_like leaves no cell uninitialized.
        scores = np.empty_like(ranked_scores)
        scores[np.arange(ranked_scores.shape[0])[:, None], ranks] = ranked_scores

        # save
        if outfile is not None:
            outdir = os.path.dirname(outfile)
            # create missing parent directories; dirname is '' for a bare filename
            # (saved in the current directory), and os.makedirs('') would fail
            if outdir:
                os.makedirs(outdir, exist_ok=True)
            np.save(outfile, scores)
            print(f'Scores matrix saved in {outfile}')
        return scores
|