File size: 2,546 Bytes
35e2575
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Building the graph based on retrieval results.
# --------------------------------------------------------
import numpy as np


def farthest_point_sampling(dist, N=None, dist_thresh=None, rng=None):
    """Farthest point sampling.

    Starting from a random point, greedily add the not-yet-selected point whose
    distance to the selected set (minimum over the selected points) is largest.

    Args:
        dist: NxN distance matrix.
        N: Maximum number of points to sample. Defaults to all points and is
           capped at the number of points, so no index is sampled twice.
        dist_thresh: Distance threshold. Point sampling terminates once the
                     maximum distance to the selected set is below this threshold.
        rng: Optional random generator (anything with a ``choice`` method, e.g.
             ``np.random.default_rng(seed)``) used to pick the starting point.
             Defaults to the global ``np.random`` state.

    Returns:
        indices: Indices of the sampled points, in selection order.
        distances: For each sampled point, its distance to the previously
                   selected set at the time it was picked (0 for the first one).
    """
    assert N is not None or dist_thresh is not None, "Either N or dist_thresh must be provided."

    num_points = dist.shape[0]
    # FPS cannot return more distinct points than exist.
    N = num_points if N is None else min(N, num_points)

    sampler = np.random if rng is None else rng
    first = int(sampler.choice(num_points))
    indices = [first]
    distances = [0]

    # Running minimum distance from every point to the selected set. Updating it
    # incrementally avoids re-reducing over all selected rows at every step.
    min_dist = np.array(dist[first], dtype=float)
    # Mask selected points so argmax can never re-pick them, even when several
    # points coincide (zero distance) or the diagonal is not exactly zero.
    min_dist[first] = -np.inf
    for _ in range(1, N):
        bst = int(min_dist.argmax())
        bst_dist = min_dist[bst]
        if dist_thresh is not None and bst_dist < dist_thresh:
            break
        indices.append(bst)
        distances.append(bst_dist)
        min_dist = np.minimum(min_dist, dist[bst])
        min_dist[bst] = -np.inf
    return np.array(indices), np.array(distances)


def make_pairs_fps(sim_mat, Na=20, tokK=1, dist_thresh=None):
    """Build an image-pair graph from a retrieval similarity matrix.

    Distances are taken as ``1 - sim_mat``, so higher similarity means closer.
    The graph is the union of three edge sets:
      1. a complete graph over key images chosen by farthest point sampling,
      2. an edge from every non-key image to its nearest key image,
      3. edges from every image to its ``tokK`` nearest other images.

    Args:
        sim_mat: NxN similarity matrix.
        Na: Number of key images requested from farthest point sampling.
            ``Na == 0`` skips steps 1 and 2.
        tokK: Number of nearest neighbours connected to each image in step 3
              (the image itself is not counted). ``tokK <= 0`` skips step 3.
        dist_thresh: Optional distance threshold forwarded to
                     ``farthest_point_sampling``.

    Returns:
        pairs: List of unique ``(i, j)`` index pairs with ``i < j``.
        keyimgs_idx: Indices of the key images (an empty int array when ``Na == 0``).
    """
    dist_mat = 1 - sim_mat
    num_imgs = dist_mat.shape[0]

    pairs = set()

    def add_pair(a, b):
        # Store each undirected edge once as (smaller, larger) and skip
        # self-loops, so the same edge found by different steps is not duplicated.
        a, b = int(a), int(b)
        if a != b:
            pairs.add((min(a, b), max(a, b)))

    keyimgs_idx = np.array([], dtype=int)
    if Na != 0:
        keyimgs_idx, _ = farthest_point_sampling(dist_mat, N=Na, dist_thresh=dist_thresh)

        # 1. Complete graph between key images.
        for pos, idx_a in enumerate(keyimgs_idx):
            for idx_b in keyimgs_idx[pos + 1:]:
                add_pair(idx_a, idx_b)

        # 2. Connect every non-key image to its nearest (smallest-distance) key image.
        key_set = {int(k) for k in keyimgs_idx}
        nearest_key = keyimgs_idx[dist_mat[:, keyimgs_idx].argmin(axis=1)]
        for i in range(num_imgs):
            if i not in key_set:
                add_pair(i, nearest_key[i])

    # 3. Add some local connections (k-NN) for each view. The image itself is
    #    removed from its own ranking; otherwise it would occupy one of the
    #    tokK slots whenever its self-distance is the smallest.
    if tokK > 0:
        for i in range(num_imgs):
            order = dist_mat[i].argsort()
            for j in order[order != i][:tokK]:
                add_pair(i, j)

    return list(pairs), keyimgs_idx