File size: 5,449 Bytes
1ea89dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once
#include <torch/extension.h>
#include <tuple>
#include "utils/pytorch3d_cutils.h"

// Compute indices of K nearest neighbors in pointcloud p2 to points
// in pointcloud p1.
//
// Args:
//    p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each
//        containing P1 points of dimension D.
//    p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
//        containing P2 points of dimension D.
//    lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
//    lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
//    norm: int specifying the norm for the distance (1 for L1, 2 for L2)
//    K: int giving the number of nearest points to return.
//    version: Integer telling which implementation to use.
//
// Returns:
//    p1_neighbor_idx: LongTensor of shape (N, P1, K), where
//        p1_neighbor_idx[n, i, k] = j means that the kth nearest
//        neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
//        It is padded with zeros so that it can be used easily in a later
//        gather() operation.
//
//    p1_neighbor_dists: FloatTensor of shape (N, P1, K) containing the squared
//        distance from each point p1[n, p, :] to its K neighbors
//        p2[n, p1_neighbor_idx[n, p, k], :].

// CPU implementation of the K-nearest-neighbor search described above.
// Only declared in this header; the definition is not visible here.
// Note that, unlike KNearestNeighborIdxCuda, it takes no `version` argument.
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
    const at::Tensor& p1,
    const at::Tensor& p2,
    const at::Tensor& lengths1,
    const at::Tensor& lengths2,
    const int norm,
    const int K);

// CUDA implementation of the K-nearest-neighbor search described above.
// Only declared in this header; the definition is not visible here.
// It is declared unconditionally, but KNearestNeighborIdx only calls it when
// WITH_CUDA is defined. `version` selects which implementation to use (see
// the Args list above).
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
    const at::Tensor& p1,
    const at::Tensor& p2,
    const at::Tensor& lengths1,
    const at::Tensor& lengths2,
    const int norm,
    const int K,
    const int version);

// Implementation which is exposed.
//
// Dispatches on the device of the inputs: if either p1 or p2 is a CUDA
// tensor the CUDA implementation is used, otherwise the CPU implementation.
// See the Args/Returns description above for the meaning of each argument
// and of the returned (p1_neighbor_idx, p1_neighbor_dists) tuple.
//
// `version` is forwarded only to the CUDA implementation; the CPU path
// ignores it.
//
// Declared `inline` because the definition lives in a header: a non-inline
// definition would be emitted in every translation unit that includes this
// file and fail at link time with multiple-definition errors.
inline std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
    const at::Tensor& p1,
    const at::Tensor& p2,
    const at::Tensor& lengths1,
    const at::Tensor& lengths2,
    const int norm,
    const int K,
    const int version) {
  if (p1.is_cuda() || p2.is_cuda()) {
#ifdef WITH_CUDA
    // CHECK_CUDA comes from utils/pytorch3d_cutils.h; applying it to both
    // tensors presumably rejects a mixed CPU/CUDA pair -- confirm against
    // that header.
    CHECK_CUDA(p1);
    CHECK_CUDA(p2);
    return KNearestNeighborIdxCuda(
        p1, p2, lengths1, lengths2, norm, K, version);
#else
    // A CUDA tensor was passed to a build without WITH_CUDA.
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
  return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, norm, K);
}

// Compute gradients with respect to p1 and p2
//
// Args:
//    p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each
//        containing P1 points of dimension D.
//    p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
//        containing P2 points of dimension D.
//    lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
//    lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
//    idxs: LongTensor of shape (N, P1, K), where
//        idxs[n, i, k] = j means that the kth nearest
//        neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
//        It is padded with zeros so that it can be used easily in a later
//        gather() operation. This is computed from the forward pass
//        (the p1_neighbor_idx output of KNearestNeighborIdx).
//    norm: int specifying the norm for the distance (1 for L1, 2 for L2)
//    grad_dists: FloatTensor of shape (N, P1, K) which contains the input
//        gradients.
//
// Returns:
//    grad_p1: FloatTensor of shape (N, P1, D) containing the output gradients
//        wrt p1.
//    grad_p2: FloatTensor of shape (N, P2, D) containing the output gradients
//        wrt p2.

// CPU implementation of the KNN backward pass described above.
// Only declared in this header; the definition is not visible here.
// Returns the (grad_p1, grad_p2) tuple documented in the Returns list above.
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
    const at::Tensor& p1,
    const at::Tensor& p2,
    const at::Tensor& lengths1,
    const at::Tensor& lengths2,
    const at::Tensor& idxs,
    const int norm,
    const at::Tensor& grad_dists);

// CUDA implementation of the KNN backward pass described above.
// Only declared in this header; the definition is not visible here.
// It is declared unconditionally, but KNearestNeighborBackward only calls it
// when WITH_CUDA is defined. Takes the same arguments as the CPU variant.
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
    const at::Tensor& p1,
    const at::Tensor& p2,
    const at::Tensor& lengths1,
    const at::Tensor& lengths2,
    const at::Tensor& idxs,
    const int norm,
    const at::Tensor& grad_dists);

// Implementation which is exposed.
//
// Dispatches on the device of the inputs: if either p1 or p2 is a CUDA
// tensor the CUDA backward implementation is used, otherwise the CPU one.
// All arguments are forwarded unchanged; see the Args/Returns description
// above for their meaning and for the returned (grad_p1, grad_p2) tuple.
//
// Declared `inline` because the definition lives in a header: a non-inline
// definition would be emitted in every translation unit that includes this
// file and fail at link time with multiple-definition errors.
inline std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackward(
    const at::Tensor& p1,
    const at::Tensor& p2,
    const at::Tensor& lengths1,
    const at::Tensor& lengths2,
    const at::Tensor& idxs,
    const int norm,
    const at::Tensor& grad_dists) {
  if (p1.is_cuda() || p2.is_cuda()) {
#ifdef WITH_CUDA
    // CHECK_CUDA comes from utils/pytorch3d_cutils.h; applying it to both
    // tensors presumably rejects a mixed CPU/CUDA pair -- confirm against
    // that header.
    CHECK_CUDA(p1);
    CHECK_CUDA(p2);
    return KNearestNeighborBackwardCuda(
        p1, p2, lengths1, lengths2, idxs, norm, grad_dists);
#else
    // A CUDA tensor was passed to a build without WITH_CUDA.
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
  return KNearestNeighborBackwardCpu(
      p1, p2, lengths1, lengths2, idxs, norm, grad_dists);
}

// Utility to check whether a KNN version can be used.
//
// Only declared in this header; the definition (and therefore the exact
// rules relating version, D and K) is not visible here.
//
// Args:
//    version: Integer in the range 0 <= version <= 3 indicating one of our
//        KNN implementations.
//    D: Number of dimensions for the input and query point clouds
//    K: Number of neighbors to be found
//
// Returns:
//    Whether the indicated KNN version can be used.
bool KnnCheckVersion(int version, const int64_t D, const int64_t K);