|
|
|
|
|
|
|
|
|
|
|
|
|
from itertools import product |
|
|
|
import torch |
|
from fvcore.common.benchmark import benchmark |
|
from pytorch3d.ops.interp_face_attrs import ( |
|
interpolate_face_attributes, |
|
interpolate_face_attributes_python, |
|
) |
|
|
|
|
|
def _generate_data(N, S, K, F, D, device, requires_grad=False):
    """Create random inputs for the face-attribute interpolation benchmarks.

    Args:
        N: Size of the leading dimension of the per-pixel tensors.
        S: Size of each of the two middle (square) dimensions.
        K: Size of the fourth dimension of the per-pixel tensors.
        F: Leading dimension of ``face_attrs`` and exclusive upper bound of
            the sampled ``pix_to_face`` indices.
        D: Size of the trailing attribute dimension.
        device: Torch device on which every tensor is allocated.
        requires_grad: Whether ``barycentric_coords`` and ``face_attrs``
            should track gradients.

    Returns:
        A 4-tuple ``(pix_to_face, barycentric_coords, face_attrs,
        grad_pix_attrs)`` with shapes ``(N, S, S, K)``, ``(N, S, S, K, 3)``,
        ``(F, 3, D)`` and ``(N, S, S, K, D)`` respectively.
    """
    pixel_shape = (N, S, S, K)

    # Integer indices drawn from [-10, F).
    # NOTE(review): negative entries presumably stand for pixels not covered
    # by any face -- confirm against the interpolation ops.
    pix_to_face = torch.randint(low=-10, high=F, size=pixel_shape, device=device)

    # The draws below stay in this order so that a fixed seed reproduces the
    # same data.
    barycentric_coords = torch.randn(
        pixel_shape + (3,), device=device, requires_grad=requires_grad
    )
    face_attrs = torch.randn(
        (F, 3, D), device=device, requires_grad=requires_grad
    )
    # Upstream gradient for the backward benchmark; never tracks gradients.
    grad_pix_attrs = torch.randn(pixel_shape + (D,), device=device)

    return pix_to_face, barycentric_coords, face_attrs, grad_pix_attrs
|
|
|
|
|
def _bm_forward(N, S, F, K, D, impl):
    """Build a zero-argument callable that runs one forward interpolation.

    Args:
        N, S, F, K, D: Problem sizes forwarded to ``_generate_data``.
        impl: ``"cuda"`` selects ``interpolate_face_attributes``;
            ``"python"`` selects ``interpolate_face_attributes_python``.

    Returns:
        A callable taking no arguments that performs the forward pass on
        pre-generated CUDA tensors and returns its output.

    Raises:
        ValueError: If ``impl`` is neither ``"cuda"`` nor ``"python"``.
    """
    # Resolve the implementation first so a bad name fails immediately,
    # instead of surfacing later as an unbound variable inside the timed call.
    if impl == "cuda":
        fun = interpolate_face_attributes
    elif impl == "python":
        fun = interpolate_face_attributes_python
    else:
        raise ValueError(
            'Unknown impl {!r}; expected "cuda" or "python".'.format(impl)
        )

    # Fixed seed so that every implementation is timed on the same data.
    torch.manual_seed(0)
    device = torch.device("cuda")
    pix_to_face, barycentric_coords, face_attrs, _ = _generate_data(
        N, S, K, F, D, device, requires_grad=False
    )
    # Make sure data generation has finished before any timing starts.
    torch.cuda.synchronize()

    def run():
        out = fun(pix_to_face, barycentric_coords, face_attrs)
        # CUDA work is launched asynchronously; wait for it so the measured
        # time covers the computation and not just the kernel launch.
        torch.cuda.synchronize()
        return out

    return run
|
|
|
|
|
def _bm_forward_backward(N, S, F, K, D, impl):
    """Build a zero-argument callable that runs forward plus backward.

    Args:
        N, S, F, K, D: Problem sizes forwarded to ``_generate_data``.
        impl: ``"cuda"`` selects ``interpolate_face_attributes``;
            ``"python"`` selects ``interpolate_face_attributes_python``.

    Returns:
        A callable taking no arguments that performs the forward pass on
        pre-generated CUDA tensors and backpropagates a fixed random
        upstream gradient through the result.

    Raises:
        ValueError: If ``impl`` is neither ``"cuda"`` nor ``"python"``.
    """
    # Resolve the implementation first so a bad name fails immediately,
    # instead of surfacing later as an unbound variable inside the timed call.
    if impl == "cuda":
        fun = interpolate_face_attributes
    elif impl == "python":
        fun = interpolate_face_attributes_python
    else:
        raise ValueError(
            'Unknown impl {!r}; expected "cuda" or "python".'.format(impl)
        )

    # Fixed seed so that every implementation is timed on the same data.
    torch.manual_seed(0)
    device = torch.device("cuda")
    data = _generate_data(N, S, K, F, D, device, requires_grad=True)
    args, grad = data[:3], data[3]
    # Make sure data generation has finished before any timing starts.
    torch.cuda.synchronize()

    def run():
        out = fun(*args)
        out.backward(gradient=grad)
        # CUDA work is launched asynchronously; wait for it so the measured
        # time covers both passes and not just the kernel launches.
        torch.cuda.synchronize()

    return run
|
|
|
|
|
def bm_interpolate_face_attribues() -> None:
    """Benchmark both interpolation implementations over a grid of sizes.

    Returns immediately when CUDA is unavailable. Otherwise runs the forward
    and the forward+backward benchmarks for every combination of the sizes
    below and both implementations, with three warm-up iterations each.
    """
    if not torch.cuda.is_available():
        return

    # One entry per keyword argument of the benchmark builders; the order of
    # the keys matches the order of the value lists passed to product().
    keys = ("N", "S", "K", "F", "D", "impl")
    grid = product(
        [1, 4],  # N
        [128],  # S
        [1, 10, 40],  # K
        [5000],  # F
        [1, 3, 16],  # D
        ["python", "cuda"],  # impl
    )
    kwargs_list = [dict(zip(keys, combo)) for combo in grid]

    benchmark(_bm_forward, "FORWARD", kwargs_list, warmup_iters=3)
    benchmark(
        _bm_forward_backward, "FORWARD+BACKWARD", kwargs_list, warmup_iters=3
    )
|
|
|
|
|
# Run the full benchmark suite when this file is executed as a script.
if __name__ == "__main__":
    bm_interpolate_face_attribues()
|
|