# File size: 1,566 Bytes
# 4b532c0

from typing import Tuple

import torch
from torch.autograd import Function

from ..utils import ext_loader

# Load the compiled extension module ``_ext``; the list names the op this
# file needs from it (``three_nn_forward``, called in ThreeNN.forward).
ext_module = ext_loader.load_ext('_ext', ['three_nn_forward'])


class ThreeNN(Function):
    """Find the top-3 nearest neighbors of the target set from the source set.

    Please refer to `Paper of PointNet++ <https://arxiv.org/abs/1706.02413>`_
    for more details.
    """

    @staticmethod
    def forward(ctx, target: torch.Tensor,
                source: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Find, for every point in ``target``, its 3 nearest points in
        ``source``.

        Args:
            target (Tensor): shape (B, N, 3), points set that needs to
                find the nearest neighbors.
            source (Tensor): shape (B, M, 3), points set that is used
                to find the nearest neighbors of points in target set.

        Returns:
            Tuple[Tensor, Tensor]:

            - dist (Tensor): shape (B, N, 3), L2 distance of each point in
              target set to its 3 nearest neighbors in source set.
            - idx (Tensor): shape (B, N, 3), int32 indices (into the M
              dimension of ``source``) of those neighbors. This output is
              non-differentiable.

        Raises:
            ValueError: if ``target`` or ``source`` is not 3-D with a last
                dimension of 3, or if their batch sizes differ.
        """
        # Validate shapes up front: the extension kernel receives raw sizes
        # (b, n, m), so mismatched inputs must not reach it.
        if target.dim() != 3 or target.size(2) != 3:
            raise ValueError('target must have shape (B, N, 3), got '
                             f'{tuple(target.shape)}')
        if source.dim() != 3 or source.size(2) != 3:
            raise ValueError('source must have shape (B, M, 3), got '
                             f'{tuple(source.shape)}')
        if source.size(0) != target.size(0):
            raise ValueError('target and source must have the same batch '
                             f'size, got {target.size(0)} and '
                             f'{source.size(0)}')

        target = target.contiguous()
        source = source.contiguous()

        B, N, _ = target.size()
        m = source.size(1)
        # Allocate outputs on target's device (the legacy torch.cuda.*Tensor
        # constructors always used the *current* CUDA device, which breaks
        # when inputs live on another GPU). dist2's dtype follows target,
        # presumably matching the kernel's element type; idx stays int32 as
        # before. As in the original, the buffers are uninitialized and are
        # expected to be filled by three_nn_forward.
        dist2 = target.new_empty(B, N, 3)
        idx = target.new_empty(B, N, 3, dtype=torch.int32)

        ext_module.three_nn_forward(target, source, dist2, idx, b=B, n=N, m=m)
        # mark_non_differentiable is skipped on the parrots backend.
        if torch.__version__ != 'parrots':
            ctx.mark_non_differentiable(idx)

        # The kernel output is named dist2 and is square-rooted here to
        # produce the returned distances.
        return torch.sqrt(dist2), idx

    @staticmethod
    def backward(ctx, a=None, b=None):
        """No gradient flows to ``target`` or ``source``."""
        return None, None


# Functional entry point: ``three_nn(target, source)`` runs ThreeNN.forward
# and returns its ``(dist, idx)`` tuple.
three_nn = ThreeNN.apply