Spaces:
Sleeping
Sleeping
import argparse | |
import collections.abc as collections | |
from pathlib import Path | |
from typing import Optional | |
import h5py | |
import numpy as np | |
import torch | |
from . import logger | |
from .utils.io import list_h5_names | |
from .utils.parsers import parse_image_lists | |
from .utils.read_write_model import read_images_binary | |
def parse_names(prefix, names, names_all):
    """Select which image names to use.

    Exactly one strategy is applied, in this order of priority:

    1. ``prefix`` is given: keep the entries of ``names_all`` that start
       with it (``names`` is ignored).
    2. ``names`` is given: use it, after reading it with
       ``parse_image_lists`` when it is a path.
    3. Neither is given: return ``names_all`` as is.

    Args:
        prefix: A string or an iterable of strings; an entry is kept when it
            starts with any of them. ``None`` disables prefix filtering.
        names: A path (``str`` or ``Path``) handed to ``parse_image_lists``,
            or an iterable of names. Only used when ``prefix`` is ``None``.
        names_all: All available names.

    Returns:
        The selected names. With a prefix or an iterable this is a new list;
        with no filter it is ``names_all`` itself.

    Raises:
        ValueError: If no name matches ``prefix``, or if ``names`` is neither
            a path nor an iterable.
    """
    if prefix is not None:
        if not isinstance(prefix, str):
            # `str.startswith` accepts a str or a tuple of str, not a list.
            prefix = tuple(prefix)
        names = [n for n in names_all if n.startswith(prefix)]
        if not names:
            raise ValueError(
                f"Could not find any image with the prefix `{prefix}`."
            )
    elif names is not None:
        # The path check must come first: a str is itself an Iterable.
        if isinstance(names, (str, Path)):
            names = parse_image_lists(names)
        elif isinstance(names, collections.Iterable):
            names = list(names)
        else:
            raise ValueError(
                f"Unknown type of image list: {names}. "
                "Provide either a list or a path to a list file."
            )
    else:
        names = names_all
    return names
def get_descriptors(names, path, name2idx=None, key="global_descriptor"):
    """Read one descriptor per name from HDF5 storage and stack them.

    Args:
        names: Names of the entries to read; each is used as a top-level key
            in the HDF5 file.
        path: A single HDF5 file path when ``name2idx`` is ``None``;
            otherwise an indexable collection of file paths.
        name2idx: Optional mapping from a name to the index in ``path`` of
            the file that stores it.
        key: Name of the dataset read under each entry.

    Returns:
        A float tensor stacking the descriptors along a new first axis, in
        the order of ``names``.
    """
    arrays = []
    if name2idx is not None:
        # Names may live in different files: open the owning file per name.
        for name in names:
            h5_path = str(path[name2idx[name]])
            with h5py.File(h5_path, "r", libver="latest") as h5_file:
                arrays.append(h5_file[name][key].__array__())
    else:
        # Single file: open it once and read every requested entry.
        with h5py.File(str(path), "r", libver="latest") as h5_file:
            for name in names:
                arrays.append(h5_file[name][key].__array__())
    stacked = np.stack(arrays, 0)
    return torch.from_numpy(stacked).float()
def pairs_from_score_matrix(
    scores: torch.Tensor,
    invalid: np.ndarray,
    num_select: int,
    min_score: Optional[float] = None,
):
    """Select, for every row, the highest-scoring valid columns.

    Args:
        scores: 2D score matrix (torch tensor or numpy array); higher is
            better. It is overwritten in place: invalid entries become
            ``-inf``.
        invalid: Boolean numpy array of the same shape as ``scores``; ``True``
            marks entries that must never be selected. It may be updated in
            place when ``min_score`` is given.
        num_select: Maximum number of columns to select per row. Values
            larger than the number of columns are clamped.
        min_score: If given, entries scoring strictly below it are treated
            as invalid too.

    Returns:
        A list of ``(row, column)`` index pairs, grouped by row and ordered
        by decreasing score within each row. Rows with fewer than
        ``num_select`` valid entries contribute fewer pairs.
    """
    assert scores.shape == invalid.shape
    if isinstance(scores, np.ndarray):
        scores = torch.from_numpy(scores)
    invalid = torch.from_numpy(invalid).to(scores.device)
    if min_score is not None:
        invalid |= scores < min_score
    scores.masked_fill_(invalid, float("-inf"))

    # torch.topk raises if k exceeds the size of the reduced dimension, which
    # happens whenever there are fewer candidates than requested matches.
    k = min(num_select, scores.shape[1])
    topk = torch.topk(scores, k, dim=1)
    indices = topk.indices.cpu().numpy()
    # Masked entries surface as -inf values; drop them from the selection.
    valid = topk.values.isfinite().cpu().numpy()

    return [(i, indices[i, j]) for i, j in zip(*np.where(valid))]
def main(
    descriptors,
    output,
    num_matched,
    query_prefix=None,
    query_list=None,
    db_prefix=None,
    db_list=None,
    db_model=None,
    db_descriptors=None,
):
    """Find the most similar database images for each query and save the pairs.

    Similarity is the dot product between global descriptors. Pairs whose
    query and database names are identical are excluded, as are pairs with a
    negative similarity.

    Args:
        descriptors: Path to the HDF5 file holding the query descriptors.
        output: Path of the text file to write; one ``query db`` pair per line.
        num_matched: Maximum number of database images paired with a query.
        query_prefix: Optional prefix(es) selecting the query names.
        query_list: Optional list (or path to a list file) of query names.
        db_prefix: Optional prefix(es) selecting the database names.
        db_list: Optional list (or path to a list file) of database names.
        db_model: Optional directory containing an ``images.bin`` file; when
            given, the database names are taken from it and ``db_prefix`` /
            ``db_list`` are ignored.
        db_descriptors: Optional path, or list of paths, to the HDF5 file(s)
            holding the database descriptors. Defaults to ``descriptors``.

    Raises:
        ValueError: If no database image or no query image is selected.
    """
    logger.info("Extracting image pairs from a retrieval database.")

    # We handle multiple reference feature files.
    # We only assume that names are unique among them and map names to files.
    if db_descriptors is None:
        db_descriptors = descriptors
    if isinstance(db_descriptors, (Path, str)):
        db_descriptors = [db_descriptors]
    name2db = {
        n: i for i, p in enumerate(db_descriptors) for n in list_h5_names(p)
    }
    db_names_h5 = list(name2db.keys())
    query_names_h5 = list_h5_names(descriptors)

    if db_model:
        # Wrap in Path so that a plain string is accepted as well.
        images = read_images_binary(Path(db_model) / "images.bin")
        db_names = [i.name for i in images.values()]
    else:
        db_names = parse_names(db_prefix, db_list, db_names_h5)
    if len(db_names) == 0:
        raise ValueError("Could not find any database image.")
    query_names = parse_names(query_prefix, query_list, query_names_h5)
    if len(query_names) == 0:
        # Fail early with a clear message instead of a stacking error later.
        raise ValueError("Could not find any query image.")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    db_desc = get_descriptors(db_names, db_descriptors, name2db)
    query_desc = get_descriptors(query_names, descriptors)
    sim = torch.einsum("id,jd->ij", query_desc.to(device), db_desc.to(device))

    # Avoid self-matching
    is_self = np.array(query_names)[:, None] == np.array(db_names)[None]
    pairs = pairs_from_score_matrix(sim, is_self, num_matched, min_score=0)
    pairs = [(query_names[i], db_names[j]) for i, j in pairs]

    logger.info(f"Found {len(pairs)} pairs.")
    with open(output, "w") as f:
        f.write("\n".join(" ".join([i, j]) for i, j in pairs))
if __name__ == "__main__":
    # Command-line entry point: each flag maps one-to-one onto a keyword
    # argument of `main`. Registration order is kept so that the generated
    # help text lists the options in the same sequence.
    cli_options = (
        ("--descriptors", {"type": Path, "required": True}),
        ("--output", {"type": Path, "required": True}),
        ("--num_matched", {"type": int, "required": True}),
        ("--query_prefix", {"type": str, "nargs": "+"}),
        ("--query_list", {"type": Path}),
        ("--db_prefix", {"type": str, "nargs": "+"}),
        ("--db_list", {"type": Path}),
        ("--db_model", {"type": Path}),
        ("--db_descriptors", {"type": Path}),
    )
    cli = argparse.ArgumentParser()
    for flag, options in cli_options:
        cli.add_argument(flag, **options)
    parsed = cli.parse_args()
    main(**vars(parsed))