# linxianzhong0128's picture
# Upload folder using huggingface_hub
# 7088d16 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from itertools import product
from typing import Any, Callable
import torch
from common_testing import get_random_cuda_device
from fvcore.common.benchmark import benchmark
from pytorch3d.common.workaround import symeig3x3
from tests.test_symeig3x3 import TestSymEig3x3
# Restrict PyTorch to a single intra-op CPU thread, presumably so CPU timings
# are comparable across runs/machines — confirm intent.
torch.set_num_threads(1)
# CUDA device used both for tracing the CUDA-only variant and for the GPU
# benchmark runs; selection is delegated to the project helper
# get_random_cuda_device (its logic is not visible in this file).
CUDA_DEVICE = get_random_cuda_device()
def create_traced_func(func, device, batch_size):
    """Return a TorchScript-traced version of ``func``.

    The trace is recorded by running ``func`` once on an example input built by
    ``TestSymEig3x3.create_random_sym3x3(device, batch_size)``.

    Args:
        func: Callable taking a single tensor argument.
        device: Device on which the example input is created.
        batch_size: Batch size of the example input.

    Returns:
        The object produced by ``torch.jit.trace``.
    """
    example_batch = TestSymEig3x3.create_random_sym3x3(device, batch_size)
    return torch.jit.trace(func, (example_batch,))
def _symeig3x3_with_vecs(inputs):
    """Run PyTorch3D's ``symeig3x3`` returning both eigenvalues and eigenvectors."""
    return symeig3x3(inputs, eigenvectors=True)


# Maps a benchmark name to a callable that takes one batch of inputs.
# Insertion order matters: bm_symeig3x3 iterates these keys (outermost loop of
# its product), so it determines the order in which results are reported.
FUNC_NAME_TO_FUNC = {"sym3x3eig": _symeig3x3_with_vecs}

# Tracing executes the solver immediately on an example input created on
# CUDA_DEVICE, i.e. at import time. Without a GPU that fails and makes the whole
# module unimportable, so only register the traced variant when CUDA is
# available. (bm_symeig3x3 never schedules "cuda" entries on CPU anyway.)
if torch.cuda.is_available():
    FUNC_NAME_TO_FUNC["sym3x3eig_traced_cuda"] = create_traced_func(
        _symeig3x3_with_vecs, CUDA_DEVICE, 1024
    )

FUNC_NAME_TO_FUNC.update(
    {
        # NOTE(review): torch.symeig is deprecated in favour of
        # torch.linalg.eigh and is unavailable in recent PyTorch releases;
        # these entries fail at call time there — confirm the target version.
        "torch_symeig": (lambda inputs: torch.symeig(inputs, eigenvectors=True)),
        "torch_linalg_eigh": (lambda inputs: torch.linalg.eigh(inputs)),
        "torch_pca_lowrank": (
            lambda inputs: torch.pca_lowrank(inputs, center=False, niter=1)
        ),
        "sym3x3eig_no_vecs": (lambda inputs: symeig3x3(inputs, eigenvectors=False)),
        "torch_symeig_no_vecs": (
            lambda inputs: torch.symeig(inputs, eigenvectors=False)
        ),
        "torch_linalg_eigvalsh_no_vecs": (
            lambda inputs: torch.linalg.eigvalsh(inputs)
        ),
    }
)
def test_symeig3x3(func_name, batch_size, device) -> Callable[[], Any]:
    """Build a zero-argument closure that runs one solver on fixed random input.

    Args:
        func_name: Key into ``FUNC_NAME_TO_FUNC`` selecting the solver.
        batch_size: Batch size passed to ``TestSymEig3x3.create_random_sym3x3``.
        device: Device (string such as ``"cpu"`` / ``"cuda:0"``, or a
            ``torch.device``) the inputs are created on.

    Returns:
        A callable that invokes the solver once. For CUDA devices it then
        blocks until that device is idle, so the measured time covers kernel
        execution rather than just the asynchronous launch.
    """
    func = FUNC_NAME_TO_FUNC[func_name]
    inputs = TestSymEig3x3.create_random_sym3x3(device, batch_size)

    torch_device = torch.device(device)
    # Only CUDA work is asynchronous. Calling torch.cuda.synchronize() on a
    # machine without CUDA raises, so CPU runs skip synchronization entirely.
    needs_sync = torch_device.type == "cuda"

    def wait_for_device() -> None:
        if needs_sync:
            # Synchronize the device holding the inputs explicitly; the
            # argument-less form only waits on the *current* CUDA device,
            # which can differ from the one the inputs were placed on.
            torch.cuda.synchronize(torch_device)

    # Make sure input generation has finished before timing starts.
    wait_for_device()

    def run_solver() -> None:
        func(inputs)
        wait_for_device()

    return run_solver
def bm_symeig3x3() -> None:
    """Benchmark every entry of ``FUNC_NAME_TO_FUNC`` over batch sizes and devices.

    The sweep always includes the CPU and, when CUDA is available, also
    ``CUDA_DEVICE``. The following combinations are skipped:

    * CUDA-only implementations (name contains ``"cuda"``) on non-CUDA devices;
    * torch built-ins (name contains ``"torch"``) above the small-batch limit,
      because they are slow on large batches;
    * anything on the CPU above the small-batch limit.
    """
    small_batch_limit = 8192
    batch_sizes = [16, 128, 1024, small_batch_limit, 65536, 1048576]
    devices = ["cpu", CUDA_DEVICE] if torch.cuda.is_available() else ["cpu"]

    def should_run(func_name, batch_size, device):
        # CUDA-only implementations only make sense on a GPU.
        if "cuda" in func_name and not device.startswith("cuda"):
            return False
        # Large batches are reserved for the fast non-torch GPU paths.
        if batch_size > small_batch_limit and (
            "torch" in func_name or device == "cpu"
        ):
            return False
        return True

    kwargs_list = [
        {"func_name": name, "batch_size": size, "device": dev}
        for name, size, dev in product(FUNC_NAME_TO_FUNC, batch_sizes, devices)
        if should_run(name, size, dev)
    ]

    benchmark(test_symeig3x3, "SYMEIG3X3", kwargs_list, warmup_iters=3)
if __name__ == "__main__":
    # When executed as a script, run the full benchmark sweep.
    bm_symeig3x3()