# Realcat
# fix: eloftr
# 63f3cf2
import argparse
import torch
from pathlib import Path
import h5py
import logging
from tqdm import tqdm
import pprint
import localization.matchers as matchers
from localization.base_model import dynamic_load
from colmap_utils.parsers import names_to_pair
# Matcher configurations, keyed by the name accepted by the `--conf` option.
# For each entry:
#   'output' is the tag inserted into the exported match file name;
#   'model'  is the dict handed to the matcher class, whose 'name' field is
#            the identifier looked up in `localization.matchers`.
confs = {
    'gm': {
        'output': 'gm',
        'model': {
            'name': 'gm',
            'weight_path': 'weights/imp_gm.900.pth',
            'sinkhorn_iterations': 20,
        },
    },

    'gml': {
        'output': 'gml',
        'model': {
            'name': 'gml',
            'weight_path': 'weights/imp_gml.920.pth',
            'sinkhorn_iterations': 20,
        },
    },

    'adagml': {
        'output': 'adagml',
        'model': {
            'name': 'adagml',
            'weight_path': 'weights/imp_adagml.80.pth',
            'sinkhorn_iterations': 20,
        },
    },

    'superglue': {
        'output': 'superglue',
        'model': {
            'name': 'superglue',
            'weights': 'outdoor',
            'sinkhorn_iterations': 20,
            'weight_path': 'weights/superglue_outdoor.pth',
        },
    },

    # The only entry without a weight file.
    'NNM': {
        'output': 'NNM',
        'model': {
            'name': 'nearest_neighbor',
            'do_mutual_check': True,
            'distance_threshold': None,
        },
    },
}
def _read_or_write_pairs(pairs, feature_file, exhaustive):
    """Return the list of 'name0 name1' strings describing the pairs to match.

    Args:
        pairs: Path of the pairs text file (one space-separated pair per line).
        feature_file: Open h5py file holding the local features.
        exhaustive: When false, the pairs are read from ``pairs``. When true,
            every pair of images found in ``feature_file`` is generated and
            written to ``pairs``, which must not exist yet.

    Raises:
        AssertionError: If ``pairs`` is missing in read mode, or already
            exists in exhaustive mode.
    """
    if not exhaustive:
        assert pairs.exists(), pairs
        with open(pairs, 'r') as f:
            # Skip blank lines so an empty file or trailing newlines do not
            # break the 'name0 name1' unpacking done by the caller.
            return [line for line in f.read().split('\n') if line.strip()]

    logging.info(f'Writing exhaustive match pairs to {pairs}.')
    assert not pairs.exists(), pairs

    # An image is identified by the name of the group that directly contains
    # its datasets.
    images = set()

    def _collect(_name, obj):
        # Must return None: any other value makes visititems stop early.
        if isinstance(obj, h5py.Dataset):
            images.add(obj.parent.name.strip('/'))

    feature_file.visititems(_collect)

    # Sort so that the generated pairs file is reproducible across runs
    # (set iteration order of strings is not).
    images = sorted(images)
    pair_list = [' '.join((later, earlier))
                 for i, later in enumerate(images)
                 for earlier in images[:i]]
    with open(str(pairs), 'w') as f:
        f.write('\n'.join(pair_list))
    return pair_list


def _load_features(feats, suffix):
    """Read every dataset of one image's feature group into a dict of arrays.

    Keys are the dataset names with ``suffix`` appended. The 'descriptors'
    array is transposed (the original code marks the result as [N D]).
    """
    data = {}
    for k in feats.keys():
        value = feats[k][()]
        if k == 'descriptors':
            value = value.transpose()  # [N D]
        data[k + suffix] = value
    return data


@torch.no_grad()
def main(conf, pairs, features, export_dir, exhaustive=False):
    """Match local features for a list of image pairs and export the matches.

    Features are read from ``<export_dir>/<features>.h5`` and the matches are
    appended to ``<export_dir>/<features>-<conf['output']>-<pairs.stem>.h5``.
    Pairs already present in the match file, or already processed in either
    order during this run, are skipped.

    Args:
        conf: Matcher configuration with an 'output' tag and a 'model' dict;
            ``conf['model']['name']`` selects the matcher class.
        pairs: ``pathlib.Path`` of the pairs file. It is read when
            ``exhaustive`` is false and created when ``exhaustive`` is true.
        features: Base name (without '.h5') of the feature file.
        export_dir: Directory containing the feature file; the match file is
            written there as well.
        exhaustive: If true, match all pairs of images found in the feature
            file and write that list to ``pairs``.

    Returns:
        Path of the exported match file.

    Raises:
        AssertionError: If the feature file is missing, if ``pairs`` is
            missing in non-exhaustive mode, or if ``pairs`` already exists in
            exhaustive mode.
    """
    logging.info('Matching local features with configuration:'
                 f'\n{pprint.pformat(conf)}')

    feature_path = Path(export_dir, features + '.h5')
    assert feature_path.exists(), feature_path

    pairs_name = pairs.stem
    match_name = f'{features}-{conf["output"]}-{pairs_name}'
    match_path = Path(export_dir, match_name + '.h5')
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Context managers guarantee both HDF5 files are closed, even when the
    # matcher raises part-way through the pair list.
    with h5py.File(str(feature_path), 'r') as feature_file:
        pair_list = _read_or_write_pairs(pairs, feature_file, exhaustive)

        Model = dynamic_load(matchers, conf['model']['name'])
        model = Model(conf['model']).eval().to(device)

        with h5py.File(str(match_path), 'a') as match_file:
            matched = set()
            for pair in tqdm(pair_list, smoothing=.1):
                name0, name1 = pair.split(' ')
                pair_key = names_to_pair(name0, name1)

                # Avoid recomputing duplicates to save time. Both orderings
                # are recorded in `matched`, so one lookup covers the
                # reversed pair too.
                if (name0, name1) in matched or pair_key in match_file:
                    continue

                feats0, feats1 = feature_file[name0], feature_file[name1]
                # Each image is read through its own keys (the previous code
                # used the keys of the second image for both).
                data = {**_load_features(feats0, '0'),
                        **_load_features(feats1, '1')}
                data = {k: torch.from_numpy(v)[None].float().to(device)
                        for k, v in data.items()}

                # some matchers might expect an image but only use its size
                data['image0'] = torch.empty(
                    (1, 1,) + tuple(feats0['image_size'])[::-1])
                data['image1'] = torch.empty(
                    (1, 1,) + tuple(feats1['image_size'])[::-1])

                pred = model(data)

                grp = match_file.create_group(pair_key)
                matches = pred['matches0'][0].cpu().short().numpy()
                grp.create_dataset('matches0', data=matches)
                if 'matching_scores0' in pred:
                    scores = pred['matching_scores0'][0].cpu().half().numpy()
                    grp.create_dataset('matching_scores0', data=scores)

                matched |= {(name0, name1), (name1, name0)}

    logging.info('Finished exporting matches.')
    return match_path
if __name__ == '__main__':
    # Command-line entry point: parse the options and run the matcher with
    # the configuration selected by --conf.
    cli = argparse.ArgumentParser()
    cli.add_argument('--export_dir', type=Path, required=True)
    cli.add_argument('--features', type=str, required=True)
    cli.add_argument('--pairs', type=Path, required=True)
    cli.add_argument('--conf', type=str, required=True,
                     choices=list(confs))
    cli.add_argument('--exhaustive', action='store_true')
    opts = cli.parse_args()

    main(
        conf=confs[opts.conf],
        pairs=opts.pairs,
        features=opts.features,
        export_dir=opts.export_dir,
        exhaustive=opts.exhaustive,
    )