Spaces:
Running
Running
import importlib
import importlib.util
import os

import torch
from torch import nn
from torch.autograd import Function
# Use the pre-built ``chamfer_2D`` extension if it is importable; otherwise
# JIT-compile it from the C++/CUDA sources located next to this file.
# ``importlib.util.find_spec`` replaces ``importlib.find_loader``, which has been
# deprecated since Python 3.4 and is no longer available in recent releases.
chamfer_found = importlib.util.find_spec("chamfer_2D") is not None
if not chamfer_found:
    ## Cool trick from https://github.com/chrdiller
    print("Jitting Chamfer 2D")
    cur_path = os.path.dirname(os.path.abspath(__file__))
    # Every occurrence of "chamfer2D" in the path is swapped for "tmp" so build
    # artifacts land outside the source directory; if the path contains no
    # "chamfer2D", the build directory is this file's own directory.
    build_path = cur_path.replace('chamfer2D', 'tmp')
    os.makedirs(build_path, exist_ok=True)
    from torch.utils.cpp_extension import load
    chamfer_2D = load(
        name="chamfer_2D",
        sources=[
            # os.path.join keeps this portable (the old '/'-split broke on Windows).
            os.path.join(cur_path, "chamfer_cuda.cpp"),
            os.path.join(cur_path, "chamfer2D.cu"),
        ],
        build_directory=build_path,
    )
    print("Loaded JIT 2D CUDA chamfer distance")
else:
    import chamfer_2D
    print("Loaded compiled 2D CUDA chamfer distance")
# Chamfer's distance module @thibaultgroueix
# GPU tensors only
class chamfer_2DFunction(Function):
    """Autograd function delegating forward/backward to the ``chamfer_2D`` extension.

    ``forward`` returns ``(dist1, dist2, idx1, idx2)``: two float tensors of
    shape ``(batch, n)`` / ``(batch, m)`` and two int32 tensors of the same
    shapes, all filled in by ``chamfer_2D.forward``. Presumably these are the
    per-point nearest-neighbour distances and indices in the other cloud --
    confirm against the CUDA kernel.
    """

    @staticmethod
    def forward(ctx, xyz1, xyz2):
        """Run the extension's forward pass.

        Args:
            ctx: autograd context; the inputs and index tensors are saved on it.
            xyz1: 3-D tensor ``(batch, n, 2)``.
            xyz2: 3-D tensor ``(batch, m, 2)``; assumed to share ``xyz1``'s
                batch size and device -- TODO confirm the kernel requires this.

        Returns:
            ``(dist1, dist2, idx1, idx2)`` allocated on ``xyz1.device``.

        Raises:
            AssertionError: if either input's last dimension is not 2.
        """
        batchsize, n, dim = xyz1.size()
        assert dim == 2, "Wrong last dimension for the chamfer distance 's input! Check with .size()"
        _, m, dim = xyz2.size()
        assert dim == 2, "Wrong last dimension for the chamfer distance 's input! Check with .size()"
        device = xyz1.device

        # Allocate outputs directly on the target device instead of creating
        # them on the CPU and copying them over.
        dist1 = torch.zeros(batchsize, n, device=device)
        dist2 = torch.zeros(batchsize, m, device=device)
        idx1 = torch.zeros(batchsize, n, dtype=torch.int32, device=device)
        idx2 = torch.zeros(batchsize, m, dtype=torch.int32, device=device)

        # Make the inputs' GPU current so the extension launches on it.
        torch.cuda.set_device(device)
        # The extension writes its results into the preallocated tensors in place.
        chamfer_2D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
        ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
        return dist1, dist2, idx1, idx2

    @staticmethod
    def backward(ctx, graddist1, graddist2, gradidx1, gradidx2):
        """Back-propagate distance gradients to both point sets.

        ``gradidx1``/``gradidx2`` are ignored: the index outputs are integer
        tensors and carry no gradient.

        Returns:
            ``(gradxyz1, gradxyz2)`` shaped like ``xyz1``/``xyz2``, on
            ``graddist1``'s device, filled in place by ``chamfer_2D.backward``.
        """
        xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
        # The extension is handed these tensors directly, so ensure a
        # contiguous layout.
        graddist1 = graddist1.contiguous()
        graddist2 = graddist2.contiguous()
        device = graddist1.device

        gradxyz1 = torch.zeros(xyz1.size(), device=device)
        gradxyz2 = torch.zeros(xyz2.size(), device=device)

        chamfer_2D.backward(
            xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
        )
        return gradxyz1, gradxyz2
class chamfer_2DDist(nn.Module):
    """``nn.Module`` front-end for the 2-D chamfer autograd function.

    Holds no parameters. Calling it makes both inputs contiguous and returns
    whatever ``chamfer_2DFunction.apply`` returns for them.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input1, input2):
        """Compute the chamfer outputs for ``input1`` against ``input2``."""
        # Contiguous copies are passed because the backing extension receives
        # the tensors directly -- presumably it assumes a dense layout.
        first = input1.contiguous()
        second = input2.contiguous()
        return chamfer_2DFunction.apply(first, second)