import torch
from torch.utils.cpp_extension import load

# JIT-compile the C++/CUDA extension on import; build artifacts are
# cached in build_directory so subsequent imports reuse them.
cd = load(name="build",
          sources=["pyTorchChamferDistance/chamfer_distance/chamfer_distance.cpp",
                   "pyTorchChamferDistance/chamfer_distance/chamfer_distance.cu"],
          build_directory="pyTorchChamferDistance/build")

class ChamferDistanceFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, xyz1, xyz2):
        # xyz1: (batch, n, 3) point cloud, xyz2: (batch, m, 3) point cloud.
        batchsize, n, _ = xyz1.size()
        _, m, _ = xyz2.size()
        xyz1 = xyz1.contiguous()
        xyz2 = xyz2.contiguous()

        # dist1[b, i] holds the squared distance from xyz1[b, i] to its
        # nearest neighbour in xyz2[b], and idx1[b, i] that neighbour's
        # index; dist2/idx2 are the symmetric quantities for xyz2.
        dist1 = torch.zeros(batchsize, n)
        dist2 = torch.zeros(batchsize, m)
        idx1 = torch.zeros(batchsize, n, dtype=torch.int)
        idx2 = torch.zeros(batchsize, m, dtype=torch.int)

        if not xyz1.is_cuda:
            cd.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
        else:
            # The kernels write into pre-allocated output buffers, so the
            # buffers must live on the same device as the inputs.
            dist1 = dist1.cuda()
            dist2 = dist2.cuda()
            idx1 = idx1.cuda()
            idx2 = idx2.cuda()
            cd.forward_cuda(xyz1, xyz2, dist1, dist2, idx1, idx2)

        ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
        # The integer nearest-neighbour indices carry no gradient.
        ctx.mark_non_differentiable(idx1, idx2)

        return dist1, dist2, idx1, idx2

    @staticmethod
    def backward(ctx, graddist1, graddist2, *args):
        # *args absorbs the unused gradients of the idx1/idx2 outputs.
        xyz1, xyz2, idx1, idx2 = ctx.saved_tensors

        graddist1 = graddist1.contiguous()
        graddist2 = graddist2.contiguous()

        gradxyz1 = torch.zeros(xyz1.size())
        gradxyz2 = torch.zeros(xyz2.size())

        if not graddist1.is_cuda:
            cd.backward(xyz1, xyz2, gradxyz1, gradxyz2,
                        graddist1, graddist2, idx1, idx2)
        else:
            gradxyz1 = gradxyz1.cuda()
            gradxyz2 = gradxyz2.cuda()
            cd.backward_cuda(xyz1, xyz2, gradxyz1, gradxyz2,
                             graddist1, graddist2, idx1, idx2)

        # One gradient per forward input (xyz1, xyz2).
        return gradxyz1, gradxyz2

class ChamferDistance(torch.nn.Module):
    """Thin nn.Module wrapper around ChamferDistanceFunction."""
    def forward(self, xyz1, xyz2):
        return ChamferDistanceFunction.apply(xyz1, xyz2)
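
# Minimal usage sketch, assuming (batch, num_points, 3) float32 inputs;
# the two directed mean distances sum to a symmetric Chamfer loss. The
# shapes and the reduction here are illustrative choices, not prescribed
# by the extension itself.
if __name__ == "__main__":
    chamfer_dist = ChamferDistance()
    p1 = torch.rand(2, 1024, 3)
    p2 = torch.rand(2, 512, 3)
    dist1, dist2, idx1, idx2 = chamfer_dist(p1, p2)
    loss = torch.mean(dist1) + torch.mean(dist2)
    print(loss.item())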