File size: 5,912 Bytes
35e2575
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Main Retriever class
# --------------------------------------------------------
import os
import argparse
import numpy as np
import torch

from mast3r.model import AsymmetricMASt3R
from mast3r.retrieval.model import RetrievalModel, extract_local_features

# Probe for GPU support in faiss. If the probe raises AttributeError, replace
# asmk's FaissGpuL2Index with a flat CPU index so the rest of this module still works.
# Note that only AttributeError is handled here: a failing `import faiss`
# (ImportError) is not caught and will propagate.
try:
    import faiss
    faiss.StandardGpuResources()  # when loading the checkpoint, it will try to instantiate FaissGpuL2Index
except AttributeError as e:
    # presumably raised by CPU-only faiss builds that do not expose StandardGpuResources — TODO confirm
    import asmk.index

    class FaissCpuL2Index(asmk.index.FaissL2Index):
        """Drop-in replacement for asmk's FaissGpuL2Index backed by faiss.IndexFlatL2."""

        def __init__(self, gpu_id):
            # gpu_id is accepted (and stored) only to keep the same constructor
            # signature as the class being replaced; it is not used by this class.
            super().__init__()
            self.gpu_id = gpu_id

        def _faiss_index_flat(self, dim):
            """Return initialized faiss.IndexFlatL2"""
            return faiss.IndexFlatL2(dim)

    # Monkey-patch: any later lookup of asmk.index.FaissGpuL2Index gets the CPU class.
    asmk.index.FaissGpuL2Index = FaissCpuL2Index

from asmk import asmk_method  # noqa


def get_args_parser():
    """Build the command-line parser of the retrieval script.

    The parser is created without the automatic ``-h/--help`` option and with
    option abbreviation disabled. It declares three required string options:
    ``--model``, ``--input`` and ``--outfile``.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser('Retrieval scores from a set of retrieval', add_help=False, allow_abbrev=False)
    # (flag, help text) for every option; all of them are required strings
    required_str_options = (
        ('--model', "shortname of a retrieval model or path to the corresponding .pth"),
        ('--input', "directory containing images or a file containing a list of image paths"),
        ('--outfile', "numpy file where to store the matrix score"),
    )
    for flag, description in required_str_options:
        parser.add_argument(flag, type=str, required=True, help=description)
    return parser


def get_impaths(imlistfile):
    """Read a list of image paths from a text file.

    The file is expected to hold one path per line. Empty lines and lines
    starting with ``#`` (comments) are ignored; all other lines are returned
    unchanged, in file order.

    Args:
        imlistfile: path of the text file to read.

    Returns:
        list[str]: the image paths listed in the file.

    Raises:
        OSError: if the file cannot be opened.
    """
    with open(imlistfile, 'r') as fid:
        # read from the open handle: `imlistfile` is the path string, not a file object
        lines = fid.read().splitlines()
    # ignore comments and empty lines
    return [line for line in lines if line and not line.startswith('#')]


def get_impaths_from_imdir(imdir, extensions=('png', 'jpg', 'PNG', 'JPG')):
    """List the image files found directly inside a directory.

    Entries of ``imdir`` are sorted by name and kept when their name ends with
    one of ``extensions``. The directory is not explored recursively.

    Args:
        imdir: path of the directory to scan.
        extensions: iterable of accepted filename suffixes (matched with
            ``str.endswith``, so they are case-sensitive and need no leading dot).

    Returns:
        list[str]: ``os.path.join(imdir, name)`` for every matching entry,
        in sorted name order.

    Raises:
        AssertionError: if ``imdir`` is not an existing directory.
    """
    assert os.path.isdir(imdir), imdir
    # str.endswith accepts a tuple of suffixes: one call instead of a Python-level loop
    suffixes = tuple(extensions)
    return [os.path.join(imdir, name) for name in sorted(os.listdir(imdir)) if name.endswith(suffixes)]


def get_impaths_from_imdir_or_imlistfile(input_imdir_or_imlistfile):
    """Return image paths from either an image-list file or an image directory.

    When the argument is an existing regular file it is parsed with
    ``get_impaths``; otherwise it is handed to ``get_impaths_from_imdir``.

    Args:
        input_imdir_or_imlistfile: path to a list file or to a directory.

    Returns:
        The list of image paths produced by the selected helper.
    """
    # pick the reader once, then apply it to the input
    reader = get_impaths if os.path.isfile(input_imdir_or_imlistfile) else get_impaths_from_imdir
    return reader(input_imdir_or_imlistfile)


class Retriever(object):
    """Compute pairwise retrieval scores for a set of images.

    Combines a ``RetrievalModel`` (used to extract local features) with an ASMK
    index whose codebook is loaded from a file stored next to the checkpoint.
    Calling the instance on a set of images returns a score matrix in which the
    images act both as database and as queries.
    """

    def __init__(self, modelname, backbone=None, device='cuda'):
        """Load the retrieval model and its ASMK codebook.

        Args:
            modelname: path to the retrieval checkpoint file. The codebook is
                looked up in the same directory, under the checkpoint filename
                with its last ``_``-separated token replaced by ``codebook.pkl``
                (e.g. ``foo_bar_x.pth`` -> ``foo_bar_codebook.pkl``).
            backbone: backbone given to ``RetrievalModel``. When ``None`` it is
                loaded with ``AsymmetricMASt3R.from_pretrained`` from the
                ``pretrained`` entry of the checkpoint arguments.
            device: device the model is moved to and on which features are extracted.

        Raises:
            AssertionError: if the checkpoint or the codebook file is missing,
                or if the checkpoint weights do not match the model (unexpected
                keys, or missing keys outside of the backbone).
        """
        # load the model
        assert os.path.isfile(modelname), modelname
        print(f'Loading retrieval model from {modelname}')
        # NOTE(review): the checkpoint stores a non-tensor 'args' object; recent torch versions
        # default to weights_only=True in torch.load, which may reject it — confirm against the
        # torch version in use. Only load checkpoints from trusted sources (pickle-based format).
        ckpt = torch.load(modelname, 'cpu')  # TODO from pretrained to download it automatically
        ckpt_args = ckpt['args']
        if backbone is None:
            backbone = AsymmetricMASt3R.from_pretrained(ckpt_args.pretrained)
        self.model = RetrievalModel(
            backbone, freeze_backbone=ckpt_args.freeze_backbone, prewhiten=ckpt_args.prewhiten,
            # hdims is stored as an '_'-separated string of ints; an empty string is passed through as ""
            hdims=list(map(int, ckpt_args.hdims.split('_'))) if len(ckpt_args.hdims) > 0 else "",
            # 'residual' may be absent from the stored arguments, hence the getattr with a default
            residual=getattr(ckpt_args, 'residual', False), postwhiten=ckpt_args.postwhiten,
            featweights=ckpt_args.featweights, nfeat=ckpt_args.nfeat
        ).to(device)
        self.device = device
        # non-strict loading: only backbone weights are allowed to be missing from the checkpoint
        msg = self.model.load_state_dict(ckpt['model'], strict=False)
        assert all(k.startswith('backbone') for k in msg.missing_keys)
        assert len(msg.unexpected_keys) == 0
        self.imsize = ckpt_args.imsize

        # load the asmk codebook
        dname, bname = os.path.split(modelname)  # TODO they should both be in the same file ?
        bname_splits = bname.split('_')
        cache_codebook_fname = os.path.join(dname, '_'.join(bname_splits[:-1]) + '_codebook.pkl')
        assert os.path.isfile(cache_codebook_fname), cache_codebook_fname
        asmk_params = {'index': {'gpu_id': 0}, 'train_codebook': {'codebook': {'size': '64k'}},
                       'build_ivf': {'kernel': {'binary': True}, 'ivf': {'use_idf': False},
                                     'quantize': {'multiple_assignment': 1}, 'aggregate': {}},
                       'query_ivf': {'quantize': {'multiple_assignment': 5}, 'aggregate': {},
                                     'search': {'topk': None},
                                     'similarity': {'similarity_threshold': 0.0, 'alpha': 3.0}}}
        # the codebook size stored in the checkpoint overrides the '64k' placeholder above
        asmk_params['train_codebook']['codebook']['size'] = ckpt_args.nclusters
        self.asmk = asmk_method.ASMKMethod.initialize_untrained(asmk_params)
        # no training vectors are given: presumably the codebook is read from cache_path — confirm against asmk
        self.asmk = self.asmk.train_codebook(None, cache_path=cache_codebook_fname)

    def __call__(self, input_imdir_or_imlistfile, outfile=None):
        """Compute the retrieval score matrix of a set of images against itself.

        Args:
            input_imdir_or_imlistfile: either a string (image directory or image
                list file, resolved with ``get_impaths_from_imdir_or_imlistfile``)
                or an already-built collection of image paths.
            outfile: optional path where the score matrix is stored with
                ``np.save``. Missing parent directories are created.

        Returns:
            numpy.ndarray: score matrix with the same shape and dtype as the
            ranked scores returned by the ASMK query, but indexed by image
            instead of by rank.
        """
        # get impaths
        if isinstance(input_imdir_or_imlistfile, str):
            impaths = get_impaths_from_imdir_or_imlistfile(input_imdir_or_imlistfile)
        else:
            impaths = input_imdir_or_imlistfile  # we're assuming a list has been passed
        print(f'Found {len(impaths)} images')

        # build the database
        feat, ids = extract_local_features(self.model, impaths, self.imsize, tocpu=True, device=self.device)
        feat = feat.cpu().numpy()
        ids = ids.cpu().numpy()
        asmk_dataset = self.asmk.build_ivf(feat, ids)

        # we actually retrieve the same set of images
        metadata, query_ids, ranks, ranked_scores = asmk_dataset.query_ivf(feat, ids)

        # well ... scores are actually reordered according to ranks ...
        # so we redo it the other way around...
        # (assumes each row of `ranks` is a permutation of the column indices; otherwise some
        # entries of `scores` would stay uninitialized — TODO confirm)
        scores = np.empty_like(ranked_scores)
        scores[np.arange(ranked_scores.shape[0])[:, None], ranks] = ranked_scores

        # save
        if outfile is not None:
            self._save_scores(outfile, scores)
        return scores

    @staticmethod
    def _save_scores(outfile, scores):
        """Store `scores` in `outfile` with np.save, creating missing parent directories."""
        outdir = os.path.dirname(outfile)
        # an empty dirname means "current directory": nothing to create (os.makedirs('') would raise)
        if outdir:
            os.makedirs(outdir, exist_ok=True)
        np.save(outfile, scores)
        print(f'Scores matrix saved in {outfile}')