Spaces:
Running
on
Zero
Running
on
Zero
File size: 10,264 Bytes
1ea89dd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
from collections import namedtuple
from typing import Union
import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable
_KNN = namedtuple("KNN", "dists idx knn")
class _knn_points(Function):
"""
Torch autograd Function wrapper for KNN C++/CUDA implementations.
"""
@staticmethod
# pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
def forward(
ctx,
p1,
p2,
lengths1,
lengths2,
K,
version,
norm: int = 2,
return_sorted: bool = True,
):
"""
K-Nearest neighbors on point clouds.
Args:
p1: Tensor of shape (N, P1, D) giving a batch of N point clouds, each
containing up to P1 points of dimension D.
p2: Tensor of shape (N, P2, D) giving a batch of N point clouds, each
containing up to P2 points of dimension D.
lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the
length of each pointcloud in p1. Or None to indicate that every cloud has
length P1.
lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
length of each pointcloud in p2. Or None to indicate that every cloud has
length P2.
K: Integer giving the number of nearest neighbors to return.
version: Which KNN implementation to use in the backend. If version=-1,
the correct implementation is selected based on the shapes of the inputs.
norm: (int) indicating the norm. Only supports 1 (for L1) and 2 (for L2).
return_sorted: (bool) whether to return the nearest neighbors sorted in
ascending order of distance.
Returns:
p1_dists: Tensor of shape (N, P1, K) giving the squared distances to
the nearest neighbors. This is padded with zeros both where a cloud in p2
has fewer than K points and where a cloud in p1 has fewer than P1 points.
p1_idx: LongTensor of shape (N, P1, K) giving the indices of the
K nearest neighbors from points in p1 to points in p2.
Concretely, if `p1_idx[n, i, k] = j` then `p2[n, j]` is the k-th nearest
neighbors to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud
in p2 has fewer than K points and where a cloud in p1 has fewer than P1 points.
"""
if not ((norm == 1) or (norm == 2)):
raise ValueError("Support for 1 or 2 norm.")
idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, norm, K, version)
# sort KNN in ascending order if K > 1
if K > 1 and return_sorted:
if lengths2.min() < K:
P1 = p1.shape[1]
mask = lengths2[:, None] <= torch.arange(K, device=dists.device)[None]
# mask has shape [N, K], true where dists irrelevant
mask = mask[:, None].expand(-1, P1, -1)
# mask has shape [N, P1, K], true where dists irrelevant
dists[mask] = float("inf")
dists, sort_idx = dists.sort(dim=2)
dists[mask] = 0
else:
dists, sort_idx = dists.sort(dim=2)
idx = idx.gather(2, sort_idx)
ctx.save_for_backward(p1, p2, lengths1, lengths2, idx)
ctx.mark_non_differentiable(idx)
ctx.norm = norm
return dists, idx
@staticmethod
@once_differentiable
def backward(ctx, grad_dists, grad_idx):
p1, p2, lengths1, lengths2, idx = ctx.saved_tensors
norm = ctx.norm
# TODO(gkioxari) Change cast to floats once we add support for doubles.
if not (grad_dists.dtype == torch.float32):
grad_dists = grad_dists.float()
if not (p1.dtype == torch.float32):
p1 = p1.float()
if not (p2.dtype == torch.float32):
p2 = p2.float()
grad_p1, grad_p2 = _C.knn_points_backward(
p1, p2, lengths1, lengths2, idx, norm, grad_dists
)
return grad_p1, grad_p2, None, None, None, None, None, None
def knn_points(
    p1: torch.Tensor,
    p2: torch.Tensor,
    lengths1: Union[torch.Tensor, None] = None,
    lengths2: Union[torch.Tensor, None] = None,
    norm: int = 2,
    K: int = 1,
    version: int = -1,
    return_nn: bool = False,
    return_sorted: bool = True,
) -> _KNN:
    """
    K-Nearest neighbors on point clouds.

    Args:
        p1: Tensor of shape (N, P1, D) giving a batch of N point clouds, each
            containing up to P1 points of dimension D.
        p2: Tensor of shape (N, P2, D) giving a batch of N point clouds, each
            containing up to P2 points of dimension D.
        lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the
            length of each pointcloud in p1. Or None to indicate that every cloud has
            length P1.
        lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
            length of each pointcloud in p2. Or None to indicate that every cloud has
            length P2.
        norm: Integer indicating the norm of the distance. Supports only 1 for L1, 2 for L2.
        K: Integer giving the number of nearest neighbors to return.
        version: Which KNN implementation to use in the backend. If version=-1,
            the correct implementation is selected based on the shapes of the inputs.
        return_nn: If set to True returns the K nearest neighbors in p2 for each point in p1.
        return_sorted: (bool) whether to return the nearest neighbors sorted in
            ascending order of distance.

    Returns:
        A `_KNN` namedtuple with fields:

        dists: Tensor of shape (N, P1, K) giving the distances to the nearest
            neighbors (squared distances for norm=2; for norm=1 presumably the
            L1 distance -- confirm against the backend). This is padded with
            zeros both where a cloud in p2 has fewer than K points and where a
            cloud in p1 has fewer than P1 points.
        idx: LongTensor of shape (N, P1, K) giving the indices of the
            K nearest neighbors from points in p1 to points in p2.
            Concretely, if `idx[n, i, k] = j` then `p2[n, j]` is the k-th nearest
            neighbor of `p1[n, i]` in `p2[n]`. This is padded with zeros both where
            a cloud in p2 has fewer than K points and where a cloud in p1 has fewer
            than P1 points.
        knn: Tensor of shape (N, P1, K, D) giving the K nearest neighbors in p2 for
            each point in p1. Concretely, `knn[n, i, k]` gives the k-th nearest neighbor
            for `p1[n, i]`. Only computed if `return_nn` is True; otherwise None.

            The nearest neighbors are collected using `knn_gather`

            .. code-block::

                knn = knn_gather(p2, idx, lengths2)

            which is a helper function that allows indexing any tensor of shape
            (N, P2, U) with the indices `idx` returned by `knn_points`. The output
            is a tensor of shape (N, P1, K, U).

    Raises:
        ValueError: if p1 and p2 differ in batch size or point dimension, or
            (from the autograd function) if `norm` is neither 1 nor 2.
    """
    if p1.shape[0] != p2.shape[0]:
        raise ValueError("pts1 and pts2 must have the same batch dimension.")
    if p1.shape[2] != p2.shape[2]:
        raise ValueError("pts1 and pts2 must have the same point dimension.")

    # NOTE(review): contiguous inputs are presumably required by the C++/CUDA
    # backend -- confirm.
    p1 = p1.contiguous()
    p2 = p2.contiguous()

    N = p1.shape[0]
    P1 = p1.shape[1]
    P2 = p2.shape[1]

    # Missing lengths mean "every cloud uses all of its points".
    if lengths1 is None:
        lengths1 = torch.full((N,), P1, dtype=torch.int64, device=p1.device)
    if lengths2 is None:
        lengths2 = torch.full((N,), P2, dtype=torch.int64, device=p1.device)

    p1_dists, p1_idx = _knn_points.apply(
        p1, p2, lengths1, lengths2, K, version, norm, return_sorted
    )

    p2_nn = knn_gather(p2, p1_idx, lengths2) if return_nn else None
    return _KNN(dists=p1_dists, idx=p1_idx, knn=p2_nn)
def knn_gather(
    x: torch.Tensor, idx: torch.Tensor, lengths: Union[torch.Tensor, None] = None
):
    """
    A helper function for knn that allows indexing a tensor x with the indices `idx`
    returned by `knn_points`.

    For example, if `dists, idx, _ = knn_points(p, x, lengths_p, lengths, K=K)`
    where p is a tensor of shape (N, L, D) and x a tensor of shape (N, M, D),
    then one can compute the K nearest neighbors of p with `p_nn = knn_gather(x, idx, lengths)`.
    It can also be applied for any tensor x of shape (N, M, U) where U != D.

    Args:
        x: Tensor of shape (N, M, U) containing U-dimensional features to
            be gathered.
        idx: LongTensor of shape (N, L, K) giving the indices returned by `knn_points`.
        lengths: LongTensor of shape (N,) of values in the range [0, M], giving the
            length of each example in the batch in x. Or None to indicate that every
            example has length M.

    Returns:
        x_out: Tensor of shape (N, L, K, U) resulting from gathering the elements of x
            with idx, s.t. `x_out[n, l, k] = x[n, idx[n, l, k]]`.
            If `k >= lengths[n]` then `x_out[n, l, k]` is filled with 0.0.

    Raises:
        ValueError: if x and idx have different batch sizes.
    """
    N, M, U = x.shape
    _N, L, K = idx.shape

    if N != _N:
        raise ValueError("x and idx must have same batch dimension.")

    if lengths is None:
        lengths = torch.full((x.shape[0],), M, dtype=torch.int64, device=x.device)

    # idx_expanded has shape [N, L, K, U]
    idx_expanded = idx[:, :, :, None].expand(-1, -1, -1, U)
    # x_out has shape [N, L, K, U], with x_out[n, l, k] = x[n, idx[n, l, k]]
    x_out = x[:, :, None].expand(-1, -1, K, -1).gather(1, idx_expanded)

    # `lengths.min()` raises on an empty tensor, so an empty batch (N == 0)
    # skips masking; there is nothing to mask in that case anyway.
    if lengths.numel() > 0 and bool(lengths.min() < K):
        # mask has shape [N, K], true where slot k is irrelevant because
        # example n of x has fewer than k + 1 points
        mask = lengths[:, None] <= torch.arange(K, device=x.device)[None]
        # Broadcast the [N, K] mask over the L and U dimensions.
        x_out.masked_fill_(mask[:, None, :, None], 0.0)

    return x_out
|