/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <torch/extension.h>
#include <cmath>
#include <queue>
#include <tuple>

// Forward CPU kernel: for each point in p1[n, :lengths1[n]], find its K
// nearest neighbors in p2[n, :lengths2[n]]. Returns (idxs, dists), each of
// shape (N, P1, K); dists holds L1 distances for norm == 1 and squared L2
// distances for norm == 2, sorted by increasing distance.
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
    const at::Tensor& p1,
    const at::Tensor& p2,
    const at::Tensor& lengths1,
    const at::Tensor& lengths2,
    const int norm,
    const int K) {
  const int N = p1.size(0);
  const int P1 = p1.size(1);
  const int D = p1.size(2);

  auto long_opts = lengths1.options().dtype(torch::kInt64);
  torch::Tensor idxs = torch::full({N, P1, K}, 0, long_opts);
  torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options());

  auto p1_a = p1.accessor<float, 3>();
  auto p2_a = p2.accessor<float, 3>();
  auto lengths1_a = lengths1.accessor<int64_t, 1>();
  auto lengths2_a = lengths2.accessor<int64_t, 1>();
  auto idxs_a = idxs.accessor<int64_t, 3>();
  auto dists_a = dists.accessor<float, 3>();

  for (int n = 0; n < N; ++n) {
    const int64_t length1 = lengths1_a[n];
    const int64_t length2 = lengths2_a[n];
    for (int64_t i1 = 0; i1 < length1; ++i1) {
      // Max-heap of (distance, index) tuples holding the K best candidates
      // seen so far; its top is the largest of the current best distances.
      std::priority_queue<std::tuple<float, int>> q;
      for (int64_t i2 = 0; i2 < length2; ++i2) {
        float dist = 0;
        for (int d = 0; d < D; ++d) {
          float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
          if (norm == 1) {
            dist += std::abs(diff);
          } else { // norm is 2 (default)
            dist += diff * diff;
          }
        }
        int size = static_cast<int>(q.size());
        if (size < K || dist < std::get<0>(q.top())) {
          q.emplace(dist, i2);
          if (size >= K) {
            q.pop();
          }
        }
      }
      // Drain the heap largest-first; q.size() after each pop gives the
      // output slot, so results end up sorted by increasing distance.
      while (!q.empty()) {
        auto t = q.top();
        q.pop();
        const int k = q.size();
        dists_a[n][i1][k] = std::get<0>(t);
        idxs_a[n][i1][k] = std::get<1>(t);
      }
    }
  }
  return std::make_tuple(idxs, dists);
}
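
// Usage sketch (illustrative only; these tensors and values are not part of
// this file). With N = 1 batch, P1 = 4 query points, P2 = 10 reference
// points in D = 3 dimensions:
//
//   auto p1 = torch::rand({1, 4, 3});
//   auto p2 = torch::rand({1, 10, 3});
//   auto lengths1 = torch::full({1}, 4, torch::kInt64);
//   auto lengths2 = torch::full({1}, 10, torch::kInt64);
//   auto [idxs, dists] = KNearestNeighborIdxCpu(
//       p1, p2, lengths1, lengths2, /*norm=*/2, /*K=*/3);
//
// idxs and dists then have shape (1, 4, 3), with the K neighbors of each
// query point listed in order of increasing (squared) distance.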

// ------------------------------------------------------------- //
//                   Backward Operators                          //
// ------------------------------------------------------------- //

// Backward CPU kernel: scatter the upstream gradients grad_dists (shape
// (N, P1, K)) back onto grad_p1 and grad_p2 using the neighbor indices
// produced by the forward pass. Padded indices marked -1 are skipped.
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
    const at::Tensor& p1,
    const at::Tensor& p2,
    const at::Tensor& lengths1,
    const at::Tensor& lengths2,
    const at::Tensor& idxs,
    const int norm,
    const at::Tensor& grad_dists) {
  const int N = p1.size(0);
  const int P1 = p1.size(1);
  const int D = p1.size(2);
  const int P2 = p2.size(1);
  const int K = idxs.size(2);

  torch::Tensor grad_p1 = torch::full({N, P1, D}, 0, p1.options());
  torch::Tensor grad_p2 = torch::full({N, P2, D}, 0, p2.options());

  auto p1_a = p1.accessor<float, 3>();
  auto p2_a = p2.accessor<float, 3>();
  auto lengths1_a = lengths1.accessor<int64_t, 1>();
  auto lengths2_a = lengths2.accessor<int64_t, 1>();
  auto idxs_a = idxs.accessor<int64_t, 3>();
  auto grad_dists_a = grad_dists.accessor<float, 3>();
  auto grad_p1_a = grad_p1.accessor<float, 3>();
  auto grad_p2_a = grad_p2.accessor<float, 3>();

  for (int n = 0; n < N; ++n) {
    const int64_t length1 = lengths1_a[n];
    // Only the first min(length2, K) neighbor slots can hold valid indices.
    int64_t length2 = lengths2_a[n];
    length2 = (length2 < K) ? length2 : K;
    for (int64_t i1 = 0; i1 < length1; ++i1) {
      for (int64_t k = 0; k < length2; ++k) {
        const int64_t i2 = idxs_a[n][i1][k];
        // If the index is the pad value of -1 then ignore it
        if (i2 == -1) {
          continue;
        }
        for (int64_t d = 0; d < D; ++d) {
          float diff = 0.0;
          if (norm == 1) {
            float sign = (p1_a[n][i1][d] > p2_a[n][i2][d]) ? 1.0 : -1.0;
            diff = grad_dists_a[n][i1][k] * sign;
          } else { // norm is 2 (default)
            diff = 2.0f * grad_dists_a[n][i1][k] *
                (p1_a[n][i1][d] - p2_a[n][i2][d]);
          }
          grad_p1_a[n][i1][d] += diff;
          grad_p2_a[n][i2][d] += -1.0f * diff;
        }
      }
    }
  }
  return std::make_tuple(grad_p1, grad_p2);
}
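
// ------------------------------------------------------------- //
//              Python Binding (illustrative sketch)              //
// ------------------------------------------------------------- //
//
// In the full project these kernels are dispatched and registered elsewhere;
// the operator names below are assumptions shown only to illustrate how the
// functions could be exposed with the standard PyTorch extension machinery:
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("knn_points_idx_cpu", &KNearestNeighborIdxCpu,
//           "K nearest neighbors: indices and distances (CPU forward)");
//     m.def("knn_points_backward_cpu", &KNearestNeighborBackwardCpu,
//           "K nearest neighbors: gradients w.r.t. p1 and p2 (CPU backward)");
//   }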