# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function

try:
    from . import mvpraymarchlib
except ImportError:
    import mvpraymarchlib

def build_accel(primtransfin, algo, fixedorder=False):
    """Build a BVH acceleration structure given primitive centers and sizes.

    Parameters
    ----------
    primtransfin : tuple[tensor, tensor, tensor]
        primitive transform tensors (primpos, primrot, primscale)
    algo : int
        raymarching algorithm
    fixedorder : bool, optional
        True means the bvh builder will not reorder primitives and will
        use a trivial tree structure. Likely to be slow for arbitrary
        configurations of primitives.
    """
    primpos, primrot, primscale = primtransfin

    N = primpos.size(0)
    K = primpos.size(1)
    dev = primpos.device

    # compute and sort morton codes of the primitive centers
    if fixedorder:
        sortedobjid = (torch.arange(N * K, dtype=torch.int32, device=dev) % K).view(N, K)
    else:
        cmax = primpos.max(dim=1, keepdim=True)[0]
        cmin = primpos.min(dim=1, keepdim=True)[0]
        centers_norm = (primpos - cmin) / (cmax - cmin).clamp(min=1e-8)
        mortoncode = torch.empty((N, K), dtype=torch.int32, device=dev)
        mvpraymarchlib.compute_morton(centers_norm, mortoncode, algo)
        sortedcode, sortedobjid_long = torch.sort(mortoncode, dim=-1)
        sortedobjid = sortedobjid_long.int()
    # build the tree topology (node children and parents)
    if fixedorder:
        # trivial binary-heap layout: internal node i has children 2i+1 and
        # 2i+2; leaf slots encode primitive indices as negative values
        nodechildren = torch.cat([
            torch.arange(1, (K - 1) * 2 + 1, dtype=torch.int32, device=dev),
            torch.div(torch.arange(-2, -(K * 2 + 1) - 1, -1, dtype=torch.int32, device=dev), 2, rounding_mode="floor")],
            dim=0).view(1, K + K - 1, 2).repeat(N, 1, 1)
        nodeparent = (
            torch.div(torch.arange(-1, K * 2 - 2, dtype=torch.int32, device=dev), 2, rounding_mode="floor")
            .view(1, -1).repeat(N, 1))
    else:
        nodechildren = torch.empty((N, K + K - 1, 2), dtype=torch.int32, device=dev)
        nodeparent = torch.full((N, K + K - 1), -1, dtype=torch.int32, device=dev)
        mvpraymarchlib.build_tree(sortedcode, nodechildren, nodeparent)

    # compute axis-aligned bounding boxes for every tree node
    nodeaabb = torch.empty((N, K + K - 1, 2, 3), dtype=torch.float32, device=dev)
    mvpraymarchlib.compute_aabb(*primtransfin, sortedobjid, nodechildren, nodeparent, nodeaabb, algo)

    return sortedobjid, nodechildren, nodeaabb
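
# Illustration: for K primitives the fixed-order tree stores K - 1 internal
# nodes followed by K leaves (2K - 1 nodes total). The helper below is a
# hypothetical sketch, not used by the library, that reproduces the heap
# indexing from build_accel in pure PyTorch so the layout can be inspected
# without the CUDA extension.
def _fixedorder_layout_demo(K=4):
    nodechildren = torch.cat([
        torch.arange(1, (K - 1) * 2 + 1, dtype=torch.int32),
        torch.div(torch.arange(-2, -(K * 2 + 1) - 1, -1, dtype=torch.int32), 2, rounding_mode="floor")],
        dim=0).view(K + K - 1, 2)
    # e.g. K=4 gives rows [[1, 2], [3, 4], [5, 6], [-1, -2], ...]: the first
    # K - 1 rows are internal nodes, the remaining rows hold negated leaf codes
    return nodechildren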

class MVPRaymarch(Function):
    """Custom Function for raymarching Mixture of Volumetric Primitives."""
    @staticmethod
    def forward(ctx, raypos, raydir, stepsize, tminmax,
            primpos, primrot, primscale,
            template, warp,
            rayterm, gradmode, options):
        algo = options["algo"]
        usebvh = options["usebvh"]
        sortprims = options["sortprims"]
        randomorder = options["randomorder"]
        maxhitboxes = options["maxhitboxes"]
        synchitboxes = options["synchitboxes"]
        chlast = options["chlast"]
        fadescale = options["fadescale"]
        fadeexp = options["fadeexp"]
        accum = options["accum"]
        termthresh = options["termthresh"]
        griddim = options["griddim"]
        if isinstance(options["blocksize"], tuple):
            blocksizex, blocksizey = options["blocksize"]
        else:
            blocksizex = options["blocksize"]
            blocksizey = 1
        assert raypos.is_contiguous() and raypos.size(3) == 3
        assert raydir.is_contiguous() and raydir.size(3) == 3
        assert tminmax.is_contiguous() and tminmax.size(3) == 2
        assert primpos is None or (primpos.is_contiguous() and primpos.size(2) == 3)
        assert primrot is None or (primrot.is_contiguous() and primrot.size(2) == 3)
        assert primscale is None or (primscale.is_contiguous() and primscale.size(2) == 3)
        if chlast:
            assert template.is_contiguous() and len(template.size()) == 6 and template.size(-1) == 4
            assert warp is None or (warp.is_contiguous() and warp.size(-1) == 3)
        else:
            assert template.is_contiguous() and len(template.size()) == 6 and template.size(2) == 4
            assert warp is None or (warp.is_contiguous() and warp.size(2) == 3)
        primtransfin = (primpos, primrot, primscale)

        # build the bvh acceleration structure
        if usebvh is not False:
            sortedobjid, nodechildren, nodeaabb = build_accel(
                primtransfin, algo, fixedorder=(usebvh == "fixedorder"))
            assert sortedobjid.is_contiguous()
            assert nodechildren.is_contiguous()
            assert nodeaabb.is_contiguous()
            if randomorder:
                sortedobjid = sortedobjid[torch.randperm(len(sortedobjid))]
        else:
            sortedobjid, nodechildren, nodeaabb = None, None, None
        # march through the boxes
        N, H, W = raypos.size(0), raypos.size(1), raypos.size(2)
        rayrgba = torch.empty((N, H, W, 4), device=raypos.device)
        if gradmode:
            raysat = torch.full((N, H, W, 3), -1, dtype=torch.float32, device=raypos.device)
            rayterm = None
        else:
            raysat = None
            rayterm = None

        mvpraymarchlib.raymarch_forward(
            raypos, raydir, stepsize, tminmax,
            sortedobjid, nodechildren, nodeaabb,
            *primtransfin,
            template, warp,
            rayrgba, raysat, rayterm,
            algo, sortprims, maxhitboxes, synchitboxes, chlast,
            fadescale, fadeexp,
            accum, termthresh,
            griddim, blocksizex, blocksizey)

        ctx.save_for_backward(
            raypos, raydir, tminmax,
            sortedobjid, nodechildren, nodeaabb,
            primpos, primrot, primscale,
            template, warp,
            rayrgba, raysat, rayterm)
        ctx.options = options
        ctx.stepsize = stepsize

        return rayrgba
    @staticmethod
    def backward(ctx, grad_rayrgba):
        (raypos, raydir, tminmax,
            sortedobjid, nodechildren, nodeaabb,
            primpos, primrot, primscale,
            template, warp,
            rayrgba, raysat, rayterm) = ctx.saved_tensors
        algo = ctx.options["algo"]
        usebvh = ctx.options["usebvh"]
        sortprims = ctx.options["sortprims"]
        maxhitboxes = ctx.options["maxhitboxes"]
        synchitboxes = ctx.options["synchitboxes"]
        chlast = ctx.options["chlast"]
        fadescale = ctx.options["fadescale"]
        fadeexp = ctx.options["fadeexp"]
        accum = ctx.options["accum"]
        termthresh = ctx.options["termthresh"]
        griddim = ctx.options["griddim"]
        if isinstance(ctx.options["bwdblocksize"], tuple):
            blocksizex, blocksizey = ctx.options["bwdblocksize"]
        else:
            blocksizex = ctx.options["bwdblocksize"]
            blocksizey = 1
        stepsize = ctx.stepsize
        grad_primpos = torch.zeros_like(primpos)
        grad_primrot = torch.zeros_like(primrot)
        grad_primscale = torch.zeros_like(primscale)
        primtransfin = (primpos, grad_primpos, primrot, grad_primrot, primscale, grad_primscale)

        grad_template = torch.zeros_like(template)
        grad_warp = torch.zeros_like(warp) if warp is not None else None

        mvpraymarchlib.raymarch_backward(
            raypos, raydir, stepsize, tminmax,
            sortedobjid, nodechildren, nodeaabb,
            *primtransfin,
            template, grad_template, warp, grad_warp,
            rayrgba, grad_rayrgba.contiguous(), raysat, rayterm,
            algo, sortprims, maxhitboxes, synchitboxes, chlast,
            fadescale, fadeexp,
            accum, termthresh,
            griddim, blocksizex, blocksizey)

        return (None, None, None, None,
                grad_primpos, grad_primrot, grad_primscale,
                grad_template, grad_warp,
                None, None, None)

def mvpraymarch(raypos, raydir, stepsize, tminmax,
        primtransf,
        template, warp,
        rayterm=None,
        algo=0, usebvh="fixedorder",
        sortprims=False, randomorder=False,
        maxhitboxes=512, synchitboxes=True,
        chlast=True, fadescale=8., fadeexp=8.,
        accum=0, termthresh=0.,
        griddim=3, blocksize=(8, 16), bwdblocksize=(8, 16)):
"""Main entry point for raymarching MVP. | |
Parameters: | |
---------- | |
raypos: N x H x W x 3 tensor of ray origins | |
raydir: N x H x W x 3 tensor of ray directions | |
stepsize: raymarching step size | |
tminmax: N x H x W x 2 tensor of raymarching min/max bounds | |
template: N x K x 4 x TD x TH x TW tensor of K RGBA primitives | |
warp: N x K x 3 x TD x TH x TW tensor of K warp fields (optional) | |
primpos: N x K x 3 tensor of primitive centers | |
primrot: N x K x 3 x 3 tensor of primitive orientations | |
primscale: N x K x 3 tensor of primitive inverse dimension lengths | |
algo: algorithm for raymarching (valid values: 0, 1). algo=0 is the fastest. | |
Currently algo=0 has a limit of 512 primitives per ray, so problems can | |
occur if there are many more boxes. all sortprims=True options have | |
this limitation, but you can use (algo=1, sortprims=False, | |
usebvh="fixedorder") which works correctly and has no primitive number | |
limitation (but is slightly slower). | |
usebvh: True to use bvh, "fixedorder" for a simple BVH, False for no bvh | |
sortprims: True to sort overlapping primitives at a sample point. Must | |
be True for gradients to match the PyTorch gradients. Seems unstable | |
if False but also not a big performance bottleneck. | |
chlast: whether template is provided as channels last or not. True tends | |
to be faster. | |
fadescale: Opacity is faded at the borders of the primitives by the equation | |
exp(-fadescale * x ** fadeexp) where x is the normalized coordinates of | |
the primitive. | |
fadeexp: Opacity is faded at the borders of the primitives by the equation | |
exp(-fadescale * x ** fadeexp) where x is the normalized coordinates of | |
the primitive. | |
griddim: CUDA grid dimensionality. | |
blocksize: blocksize of CUDA kernels. Should be 2-element tuple if | |
griddim>1, or integer if griddim==1.""" | |
    if isinstance(primtransf, tuple):
        primpos, primrot, primscale = primtransf
    else:
        # unpack the N x K x 5 x 3 representation
        primpos, primrot, primscale = (
            primtransf[:, :, 0, :].contiguous(),
            primtransf[:, :, 1:4, :].contiguous(),
            primtransf[:, :, 4, :].contiguous())
    primtransfin = (primpos, primrot, primscale)

    out = MVPRaymarch.apply(raypos, raydir, stepsize, tminmax,
            *primtransfin,
            template, warp,
            rayterm, torch.is_grad_enabled(),
            {"algo": algo, "usebvh": usebvh, "sortprims": sortprims, "randomorder": randomorder,
             "maxhitboxes": maxhitboxes, "synchitboxes": synchitboxes,
             "chlast": chlast, "fadescale": fadescale, "fadeexp": fadeexp,
             "accum": accum, "termthresh": termthresh,
             "griddim": griddim, "blocksize": blocksize, "bwdblocksize": bwdblocksize})
    return out
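
# A minimal usage sketch (illustrative only; assumes a CUDA device and the
# compiled mvpraymarchlib extension). Shapes follow the docstring above; the
# helper name and all values are hypothetical.
def _mvpraymarch_demo(N=1, H=64, W=64, K=8, M=16):
    dev = "cuda"
    raypos = torch.zeros(N, H, W, 3, device=dev)
    raydir = F.normalize(torch.randn(N, H, W, 3, device=dev), dim=-1)
    tminmax = torch.stack([torch.zeros(N, H, W, device=dev),
                           torch.full((N, H, W), 4.0, device=dev)], dim=-1).contiguous()
    primpos = torch.randn(N, K, 3, device=dev)
    primrot = torch.eye(3, device=dev).expand(N, K, 3, 3).contiguous()
    primscale = torch.ones(N, K, 3, device=dev)
    # channels-last RGBA template, matching the default chlast=True
    template = torch.randn(N, K, M, M, M, 4, device=dev)
    return mvpraymarch(raypos, raydir, 4.0 / 128, tminmax,
                       (primpos, primrot, primscale), template, None)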

class Rodrigues(nn.Module):
    """Convert axis-angle vectors to rotation matrices via the Rodrigues
    formula: R = cos(theta) I + sin(theta) [u]_x + (1 - cos(theta)) u u^T."""
    def __init__(self):
        super().__init__()

    def forward(self, rvec):
        # 1e-5 regularizes theta so the normalization is stable near zero
        theta = torch.sqrt(1e-5 + torch.sum(rvec ** 2, dim=1))
        rvec = rvec / theta[:, None]
        costh = torch.cos(theta)
        sinth = torch.sin(theta)
        return torch.stack((
            rvec[:, 0] ** 2 + (1. - rvec[:, 0] ** 2) * costh,
            rvec[:, 0] * rvec[:, 1] * (1. - costh) - rvec[:, 2] * sinth,
            rvec[:, 0] * rvec[:, 2] * (1. - costh) + rvec[:, 1] * sinth,
            rvec[:, 0] * rvec[:, 1] * (1. - costh) + rvec[:, 2] * sinth,
            rvec[:, 1] ** 2 + (1. - rvec[:, 1] ** 2) * costh,
            rvec[:, 1] * rvec[:, 2] * (1. - costh) - rvec[:, 0] * sinth,
            rvec[:, 0] * rvec[:, 2] * (1. - costh) - rvec[:, 1] * sinth,
            rvec[:, 1] * rvec[:, 2] * (1. - costh) + rvec[:, 0] * sinth,
            rvec[:, 2] ** 2 + (1. - rvec[:, 2] ** 2) * costh), dim=1).view(-1, 3, 3)
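
# Sanity sketch (hypothetical helper, not used elsewhere): a Rodrigues output
# should be a valid rotation matrix, i.e. R @ R^T = I, up to the 1e-5
# regularizer added to theta.
def _rodrigues_check(B=5):
    R = Rodrigues()(torch.randn(B, 3))
    assert torch.allclose(torch.bmm(R, R.transpose(1, 2)),
                          torch.eye(3).expand(B, 3, 3), atol=1e-4)
    return R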

def gradcheck(usebvh=True, sortprims=True, maxhitboxes=512, synchitboxes=False,
        dowarp=False, chlast=False, fadescale=8., fadeexp=8.,
        accum=0, termthresh=0., algo=0, griddim=2, blocksize=(8, 16), bwdblocksize=(8, 16)):
    N = 2
    H = 65
    W = 65
    k3 = 4
    K = k3 * k3 * k3
    M = 32

    print("=================================================================")
    print("usebvh={}, sortprims={}, maxhb={}, synchb={}, dowarp={}, chlast={}, "
          "fadescale={}, fadeexp={}, accum={}, termthresh={}, algo={}, griddim={}, "
          "blocksize={}, bwdblocksize={}".format(
              usebvh, sortprims, maxhitboxes, synchitboxes, dowarp, chlast,
              fadescale, fadeexp, accum, termthresh, algo, griddim, blocksize,
              bwdblocksize))
    # generate random inputs
    torch.manual_seed(1112)

    coherent_rays = True
    if not coherent_rays:
        _raypos = torch.randn(N, H, W, 3).to("cuda")
        _raydir = torch.randn(N, H, W, 3).to("cuda")
        _raydir /= torch.sqrt(torch.sum(_raydir ** 2, dim=-1, keepdim=True))
    else:
        focal = torch.tensor([[W * 4.0, W * 4.0] for n in range(N)])
        princpt = torch.tensor([[W * 0.5, H * 0.5] for n in range(N)])
        pixely, pixelx = torch.meshgrid(torch.arange(H).float(), torch.arange(W).float(), indexing="ij")
        pixelcoords = torch.stack([pixelx, pixely], dim=-1)[None, :, :, :].repeat(N, 1, 1, 1)
        raydir = (pixelcoords - princpt[:, None, None, :]) / focal[:, None, None, :]
        raydir = torch.cat([raydir, torch.ones_like(raydir[:, :, :, 0:1])], dim=-1)
        raydir = raydir / torch.sqrt(torch.sum(raydir ** 2, dim=-1, keepdim=True))
        _raypos = torch.tensor([-0.0, 0.0, -4.])[None, None, None, :].repeat(N, H, W, 1).to("cuda")
        _raydir = raydir.to("cuda")
        _raydir /= torch.sqrt(torch.sum(_raydir ** 2, dim=-1, keepdim=True))

    max_len = 6.0
    _stepsize = max_len / 15.386928
    _tminmax = max_len * torch.arange(2, dtype=torch.float32)[None, None, None, :].repeat(N, H, W, 1).to("cuda") + \
            torch.rand(N, H, W, 2, device="cuda") * 1.

    _template = torch.randn(N, K, 4, M, M, M, requires_grad=True)
    _template.data[:, :, -1, :, :, :] -= 3.5
    _template = _template.contiguous().detach().clone()
    _template.requires_grad = True

    gridxyz = torch.stack(torch.meshgrid(
        torch.linspace(-1., 1., M // 2),
        torch.linspace(-1., 1., M // 2),
        torch.linspace(-1., 1., M // 2), indexing="ij")[::-1], dim=0).contiguous()
    _warp = (torch.randn(N, K, 3, M // 2, M // 2, M // 2) * 0.01 + gridxyz[None, None, :, :, :, :]).contiguous().detach().clone()
    _warp.requires_grad = True

    _primpos = torch.randn(N, K, 3, requires_grad=True)
    coherent_centers = True
    if coherent_centers:
        ns = k3
        #assert ns*ns*ns==K
        grid3d = torch.stack(torch.meshgrid(
            torch.linspace(-1., 1., ns),
            torch.linspace(-1., 1., ns),
            torch.linspace(-1., 1., K // (ns * ns)), indexing="ij")[::-1], dim=0)[None]
        _primpos = ((
            grid3d.permute((0, 2, 3, 4, 1)).reshape(1, K, 3).expand(N, -1, -1) +
            0.1 * torch.randn(N, K, 3)
        )).contiguous().detach().clone()
        _primpos.requires_grad = True

    scale_ws = 1.
    _primrot = torch.randn(N, K, 3)
    rodrigues = Rodrigues()
    _primrot = rodrigues(_primrot.view(-1, 3)).view(N, K, 3, 3).contiguous().detach().clone()
    _primrot.requires_grad = True
    _primscale = torch.randn(N, K, 3, requires_grad=True)
    _primscale.data *= 0.0
    if dowarp:
        params = [_template, _warp, _primscale, _primrot, _primpos]
        paramnames = ["template", "warp", "primscale", "primrot", "primpos"]
    else:
        params = [_template, _primscale, _primrot, _primpos]
        paramnames = ["template", "primscale", "primrot", "primpos"]

    termthreshorig = termthresh
    ########################### run pytorch version ###########################
    raypos = _raypos
    raydir = _raydir
    stepsize = _stepsize
    tminmax = _tminmax

    template = F.softplus(_template.to("cuda") * 1.5) if algo != 2 else _template.to("cuda") * 1.5
    warp = _warp.to("cuda")
    primpos = _primpos.to("cuda") * 0.3
    primrot = _primrot.to("cuda")
    primscale = scale_ws * torch.exp(0.1 * _primscale.to("cuda"))

    # python raymarching implementation
    rayrgba = torch.zeros((N, H, W, 4)).to("cuda")
    raypos = raypos + raydir * tminmax[:, :, :, 0, None]
    t = tminmax[:, :, :, 0]

    step = 0
    t0 = t.detach().clone()
    raypos0 = raypos.detach().clone()

    torch.cuda.synchronize()
    time0 = time.time()
    while (t < tminmax[:, :, :, 1]).any():
        for k in range(K):
            # transform the sample point into the local frame of primitive k
            y0 = torch.bmm(
                (raypos - primpos[:, k, None, None, :]).view(raypos.size(0), -1, raypos.size(3)),
                primrot[:, k, :, :]).view_as(raypos) * primscale[:, k, None, None, :]
            fade = torch.exp(-fadescale * torch.sum(torch.abs(y0) ** fadeexp, dim=-1, keepdim=True))
            if dowarp:
                y1 = F.grid_sample(
                    warp[:, k, :, :, :, :],
                    y0[:, None, :, :, :], align_corners=True)[:, :, 0, :, :].permute(0, 2, 3, 1)
            else:
                y1 = y0
            sample = F.grid_sample(
                template[:, k, :, :, :, :],
                y1[:, None, :, :, :], align_corners=True)[:, :, 0, :, :].permute(0, 2, 3, 1)

            valid1 = (
                torch.prod(y0[:, :, :, :] >= -1., dim=-1, keepdim=True) *
                torch.prod(y0[:, :, :, :] <= 1., dim=-1, keepdim=True))
            valid = ((t >= tminmax[:, :, :, 0]) & (t < tminmax[:, :, :, 1])).float()[:, :, :, None]

            alpha0 = sample[:, :, :, 3:4]
            rgb = sample[:, :, :, 0:3] * valid * valid1
            alpha = alpha0 * fade * stepsize * valid * valid1

            if accum == 0:
                # additive alpha accumulation, clamped at 1
                newalpha = rayrgba[:, :, :, 3:4] + alpha
                contrib = (newalpha.clamp(max=1.0) - rayrgba[:, :, :, 3:4]) * valid * valid1
                rayrgba = rayrgba + contrib * torch.cat([rgb, torch.ones_like(alpha)], dim=-1)
            else:
                raise NotImplementedError("only accum == 0 is implemented in the reference path")

        step += 1
        t = t0 + stepsize * step
        raypos = raypos0 + raydir * stepsize * step

    print("alpha range:", rayrgba[..., -1].min().item(), rayrgba[..., -1].max().item())
    sample0 = rayrgba

    torch.cuda.synchronize()
    time1 = time.time()

    sample0.backward(torch.ones_like(sample0))

    torch.cuda.synchronize()
    time2 = time.time()

    print("{:<10} {:>10} {:>10} {:>10}".format("", "fwd", "bwd", "total"))
    print("{:<10} {:10.5} {:10.5} {:10.5}".format("pytime", time1 - time0, time2 - time1, time2 - time0))

    grads0 = [p.grad.detach().clone() for p in params]
    for p in params:
        p.grad.detach_()
        p.grad.zero_()
    ############################## run cuda version ###########################
    raypos = _raypos
    raydir = _raydir
    stepsize = _stepsize
    tminmax = _tminmax
    template = F.softplus(_template.to("cuda") * 1.5) if algo != 2 else _template.to("cuda") * 1.5
    warp = _warp.to("cuda")
    if chlast:
        template = template.permute(0, 1, 3, 4, 5, 2).contiguous()
        warp = warp.permute(0, 1, 3, 4, 5, 2).contiguous()
    primpos = _primpos.to("cuda") * 0.3
    primrot = _primrot.to("cuda")
    primscale = scale_ws * torch.exp(0.1 * _primscale.to("cuda"))
    niter = 1
    tf, tb = 0., 0.
    for i in range(niter):
        for p in params:
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()

        # synchronize before reading the clock so kernel time is not misattributed
        torch.cuda.synchronize()
        t0 = time.time()

        sample1 = mvpraymarch(raypos, raydir, stepsize, tminmax,
                (primpos, primrot, primscale),
                template, warp if dowarp else None,
                algo=algo, usebvh=usebvh, sortprims=sortprims,
                maxhitboxes=maxhitboxes, synchitboxes=synchitboxes,
                chlast=chlast, fadescale=fadescale, fadeexp=fadeexp,
                accum=accum, termthresh=termthreshorig,
                griddim=griddim, blocksize=blocksize, bwdblocksize=bwdblocksize)

        torch.cuda.synchronize()
        t1 = time.time()

        sample1.backward(torch.ones_like(sample1), retain_graph=True)

        torch.cuda.synchronize()
        t2 = time.time()

        tf += t1 - t0
        tb += t2 - t1

    print("{:<10} {:10.5} {:10.5} {:10.5}".format("time", tf / niter, tb / niter, (tf + tb) / niter))
    grads1 = [p.grad.detach().clone() for p in params]
    ############################# compare results #############################
    print("-----------------------------------------------------------------")
    print("{:<10} {:>10} {:>10} {:>10} {:>10} {:>10} {:>10} {:>10}".format(
        "", "maxabsdiff", "dp", "||py||", "||cuda||", "index", "py", "cuda"))
    ind = torch.argmax(torch.abs(sample0 - sample1))
    print("{:<10} {:>10.5} {:>10.5} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format(
        "fwd",
        torch.max(torch.abs(sample0 - sample1)).item(),
        (torch.sum(sample0 * sample1) / torch.sqrt(torch.sum(sample0 * sample0) * torch.sum(sample1 * sample1))).item(),
        torch.sqrt(torch.sum(sample0 * sample0)).item(),
        torch.sqrt(torch.sum(sample1 * sample1)).item(),
        ind.item(),
        sample0.view(-1)[ind].item(),
        sample1.view(-1)[ind].item()))
    for p, g0, g1 in zip(paramnames, grads0, grads1):
        ind = torch.argmax(torch.abs(g0 - g1))
        print("{:<10} {:>10.5} {:>10.5} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format(
            p,
            torch.max(torch.abs(g0 - g1)).item(),
            (torch.sum(g0 * g1) / torch.sqrt(torch.sum(g0 * g0) * torch.sum(g1 * g1))).item(),
            torch.sqrt(torch.sum(g0 * g0)).item(),
            torch.sqrt(torch.sum(g1 * g1)).item(),
            ind.item(),
            g0.view(-1)[ind].item(),
            g1.view(-1)[ind].item()))

if __name__ == "__main__":
    gradcheck(usebvh="fixedorder", sortprims=False, maxhitboxes=512, synchitboxes=True,
            dowarp=False, chlast=True, fadescale=6.5, fadeexp=7.5, accum=0, algo=0, griddim=3)
    gradcheck(usebvh="fixedorder", sortprims=False, maxhitboxes=512, synchitboxes=True,
            dowarp=True, chlast=True, fadescale=6.5, fadeexp=7.5, accum=0, algo=1, griddim=3)