Spaces:
Configuration error
Configuration error
File size: 2,546 Bytes
35e2575 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Building the graph based on retrieval results.
# --------------------------------------------------------
import numpy as np
def farthest_point_sampling(dist, N=None, dist_thresh=None):
    """Farthest point sampling.

    Starting from a randomly chosen point, repeatedly add the not-yet-sampled
    point whose distance to its closest already-sampled point is largest.

    Args:
        dist: NxN distance matrix (indexed like a NumPy array).
        N: Maximum number of points to sample. Values larger than the number
            of points are clamped so that no point is returned twice. If
            None, sampling is bounded only by ``dist_thresh``.
        dist_thresh: Distance threshold. Point sampling terminates once the
            maximum distance is below this threshold.

    Returns:
        indices: Indices of the sampled points, in sampling order. The random
            starting point is always included.
        distances: For each sampled point, its distance to the nearest
            previously sampled point at the time it was picked (0 for the
            starting point).
    """
    assert N is not None or dist_thresh is not None, "Either N or dist_thresh must be provided."

    num_points = dist.shape[0]
    # Asking for more points than exist would otherwise re-sample old ones.
    N = num_points if N is None else min(N, num_points)

    first = np.random.choice(num_points)
    indices = [first]
    distances = [0]

    selected = np.zeros(num_points, dtype=bool)
    selected[first] = True
    # nearest[j] = distance from point j to its closest sampled point.
    nearest = dist[first]
    for _ in range(1, N):
        # Exclude sampled points so ties (e.g. coincident points at distance
        # 0) can never select an index that is already in the result.
        candidates = np.where(selected, -np.inf, nearest)
        bst = candidates.argmax()
        bst_dist = nearest[bst]
        if dist_thresh is not None and bst_dist < dist_thresh:
            break
        indices.append(bst)
        distances.append(bst_dist)
        selected[bst] = True
        # Incremental update; equivalent to dist[indices].min(axis=0).
        nearest = np.minimum(nearest, dist[bst])
    return np.array(indices), np.array(distances)
def make_pairs_fps(sim_mat, Na=20, tokK=1, dist_thresh=None):
    """Build a sparse pair graph from a similarity matrix.

    The graph is the union of three edge sets:
      1. a complete graph between key images chosen by farthest point
         sampling on ``1 - sim_mat``;
      2. one edge from every non-key image to its nearest key image;
      3. edges from every image to the entries among its ``tokK`` closest
         columns (self-matches are skipped).

    Args:
        sim_mat: NxN similarity matrix; distances are taken as
            ``1 - sim_mat``.
        Na: Number of key images requested from farthest point sampling.
            0 disables key images (steps 1 and 2 add nothing).
        tokK: Number of closest columns per image considered in step 3.
            Values <= 0 disable step 3.
        dist_thresh: Forwarded to ``farthest_point_sampling``.

    Returns:
        pairs: List of unique ``(i, j)`` index tuples with ``i < j``, in no
            particular order.
        keyimgs_idx: Array of key image indices (empty when ``Na == 0``).
    """
    dist_mat = 1 - sim_mat
    num_imgs = dist_mat.shape[0]

    pairs = set()
    # Integer dtype so the empty default stays a valid index array.
    keyimgs_idx = np.array([], dtype=int)
    if Na != 0:
        keyimgs_idx, _ = farthest_point_sampling(dist_mat, N=Na, dist_thresh=dist_thresh)
    key_list = [int(k) for k in keyimgs_idx]

    # 1. Complete graph between key images
    # Stored as (min, max) like the other steps so reversed duplicates and
    # self-pairs cannot enter the set.
    for pos, idx_i in enumerate(key_list):
        for idx_j in key_list[pos + 1:]:
            if idx_i != idx_j:
                pairs.add((min(idx_i, idx_j), max(idx_i, idx_j)))

    # 2. Connect non-key images to the nearest key image
    if key_list:
        key_set = set(key_list)
        # Smallest distance == nearest key image.
        nearest_key = dist_mat[:, key_list].argmin(axis=1)
        for i in range(num_imgs):
            if i in key_set:
                continue
            k = key_list[int(nearest_key[i])]
            pairs.add((min(i, k), max(i, k)))

    # 3. Add some local connections (k-NN) for each view
    # NOTE(review): an image's own column presumably sorts first (distance
    # 0) and is skipped, so tokK=1 may add no edges -- confirm intended.
    if tokK > 0:
        for i in range(num_imgs):
            for j in dist_mat[i].argsort()[:tokK]:
                j = int(j)
                if i != j:
                    pairs.add((min(i, j), max(i, j)))

    pairs = list(pairs)
    return pairs, keyimgs_idx
|