|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from itertools import product |
|
from typing import Any, Callable |
|
|
|
import torch |
|
from common_testing import get_random_cuda_device |
|
from fvcore.common.benchmark import benchmark |
|
from pytorch3d.common.workaround import symeig3x3 |
|
from tests.test_symeig3x3 import TestSymEig3x3 |
|
|
|
|
|
# Limit PyTorch to a single intra-op CPU thread for the whole benchmark run,
# presumably so CPU timings do not depend on the host's core count.
torch.set_num_threads(1)

# Device string used for every CUDA benchmark in this file. It is chosen by the
# project helper `get_random_cuda_device` (defined outside this file), so it
# may be any visible GPU, not necessarily the current/default one.
CUDA_DEVICE = get_random_cuda_device()
|
|
|
|
|
def create_traced_func(func, device, batch_size):
    """Return a TorchScript-traced version of ``func``.

    The trace is recorded by running ``func`` once on an example batch of
    ``batch_size`` matrices produced by ``TestSymEig3x3.create_random_sym3x3``
    on ``device``.

    Args:
        func: callable taking a single tensor argument.
        device: device on which the example input is created.
        batch_size: number of matrices in the example input.

    Returns:
        The object returned by ``torch.jit.trace``.
    """
    example_input = TestSymEig3x3.create_random_sym3x3(device, batch_size)
    return torch.jit.trace(func, (example_input,))
|
|
|
|
|
# Registry of benchmarked solvers: name -> callable taking one input batch.
# Names containing "torch" are PyTorch's own solvers, used for comparison.
# Names ending in "_no_vecs" compute eigenvalues only.
FUNC_NAME_TO_FUNC = {
    "sym3x3eig": (lambda inputs: symeig3x3(inputs, eigenvectors=True)),
}

# The traced variant is built eagerly: tracing runs the function once on a
# CUDA tensor. Only create it when a GPU is present. Otherwise merely
# importing this module would fail on CPU-only machines, even though
# bm_symeig3x3 supports running without CUDA.
if torch.cuda.is_available():
    FUNC_NAME_TO_FUNC["sym3x3eig_traced_cuda"] = create_traced_func(
        (lambda inputs: symeig3x3(inputs, eigenvectors=True)), CUDA_DEVICE, 1024
    )

FUNC_NAME_TO_FUNC.update(
    {
        # NOTE: torch.symeig is deprecated in favour of torch.linalg.eigh and
        # is unavailable in recent PyTorch releases. The lambda only fails
        # when the benchmark actually invokes it.
        "torch_symeig": (lambda inputs: torch.symeig(inputs, eigenvectors=True)),
        "torch_linalg_eigh": (lambda inputs: torch.linalg.eigh(inputs)),
        "torch_pca_lowrank": (
            lambda inputs: torch.pca_lowrank(inputs, center=False, niter=1)
        ),
        "sym3x3eig_no_vecs": (
            lambda inputs: symeig3x3(inputs, eigenvectors=False)
        ),
        "torch_symeig_no_vecs": (
            lambda inputs: torch.symeig(inputs, eigenvectors=False)
        ),
        "torch_linalg_eigvalsh_no_vecs": (
            lambda inputs: torch.linalg.eigvalsh(inputs)
        ),
    }
)
|
|
|
|
|
def test_symeig3x3(func_name, batch_size, device) -> Callable[[], Any]:
    """Build a zero-argument closure that runs one solver for benchmarking.

    Args:
        func_name: key into ``FUNC_NAME_TO_FUNC`` selecting the solver.
        batch_size: number of matrices in the random input batch.
        device: device the input batch is created on, e.g. ``"cpu"`` or
            ``"cuda:1"``.

    Returns:
        A callable that applies the solver once to a pre-built input. On CUDA
        devices it also blocks until that device's kernels finish, so
        wall-clock timing covers the GPU work.
    """
    func = FUNC_NAME_TO_FUNC[func_name]
    inputs = TestSymEig3x3.create_random_sym3x3(device, batch_size)

    target = torch.device(device)
    on_cuda = target.type == "cuda"
    # Sync only for CUDA targets: torch.cuda.synchronize() raises on CPU-only
    # builds. Pass the device explicitly, because without an argument it only
    # waits on the *current* device, which may differ from a randomly chosen
    # "cuda:N".
    if on_cuda:
        torch.cuda.synchronize(target)

    # Named to avoid shadowing the imported `symeig3x3` solver.
    def run_once():
        func(inputs)
        if on_cuda:
            torch.cuda.synchronize(target)

    return run_once
|
|
|
|
|
def bm_symeig3x3() -> None:
    """Run the SYMEIG3X3 benchmark over every solver, batch size and device.

    A (solver, batch size, device) combination is skipped when:
      * the solver name contains "cuda" but the device is not a CUDA device;
      * the solver name contains "torch" and the batch exceeds 8192 matrices;
      * the device is "cpu" and the batch exceeds 8192 matrices.
    The large-batch limits presumably keep slow runs from dominating runtime.
    """
    small_batch_limit = 8192
    batch_sizes = [16, 128, 1024, 8192, 65536, 1048576]
    devices = ["cpu"] + ([CUDA_DEVICE] if torch.cuda.is_available() else [])

    def wanted(name, size, dev):
        # CUDA-only solvers cannot run on non-CUDA devices.
        if "cuda" in name and not dev.startswith("cuda"):
            return False
        too_big = size > small_batch_limit
        return not (too_big and ("torch" in name or dev == "cpu"))

    kwargs_list = [
        {"func_name": name, "batch_size": size, "device": dev}
        for name, size, dev in product(FUNC_NAME_TO_FUNC.keys(), batch_sizes, devices)
        if wanted(name, size, dev)
    ]

    benchmark(test_symeig3x3, "SYMEIG3X3", kwargs_list, warmup_iters=3)
|
|
|
|
|
if __name__ == "__main__":
    # Run the full benchmark sweep when this file is executed as a script.
    bm_symeig3x3()
|
|