# NOTE: page-scrape residue removed here (a "Spaces: Sleeping" status banner); it is not part of the build script.
import os | |
import subprocess | |
from packaging.version import parse, Version | |
import torch | |
from setuptools import setup | |
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME | |
def get_cuda_bare_metal_version(cuda_dir):
    """Return the raw ``nvcc -V`` output and the CUDA version parsed from it.

    Args:
        cuda_dir: Root directory of the CUDA installation. The compiler is
            invoked as ``<cuda_dir>/bin/nvcc -V``.

    Returns:
        A ``(raw_output, bare_metal_version)`` tuple: the full text printed by
        ``nvcc -V`` and the parsed version taken from the token that follows
        the word ``release`` (with anything after a comma stripped).

    Raises:
        RuntimeError: If *cuda_dir* is ``None`` or empty, i.e. no CUDA
            installation directory was supplied.
        FileNotFoundError: If the ``nvcc`` executable does not exist.
        subprocess.CalledProcessError: If ``nvcc -V`` exits with an error.
        ValueError: If the output contains no ``release`` token.
    """
    # Fail early with an actionable message instead of the opaque
    # "unsupported operand type(s) for +: 'NoneType' and 'str'" that string
    # concatenation would produce when no CUDA install was located.
    if not cuda_dir:
        raise RuntimeError(
            "CUDA installation directory is not set; cannot locate nvcc. "
            "Install the CUDA toolkit or set the CUDA_HOME environment variable."
        )

    nvcc_path = os.path.join(cuda_dir, "bin", "nvcc")
    raw_output = subprocess.check_output([nvcc_path, "-V"], universal_newlines=True)

    # The version is the whitespace-separated token right after "release",
    # e.g. "... release 11.8, V11.8.89" -> "11.8".
    output = raw_output.split()
    release_idx = output.index("release") + 1
    bare_metal_version = parse(output[release_idx].split(",")[0])
    return raw_output, bare_metal_version
def append_nvcc_threads(nvcc_extra_args):
    """Return *nvcc_extra_args*, extended with ``--threads N`` on CUDA >= 11.2.

    The thread count is read from the ``NVCC_THREADS`` environment variable and
    falls back to ``"4"`` when that variable is unset or empty. For CUDA
    versions below 11.2 the argument list is returned unchanged.

    Args:
        nvcc_extra_args: List of extra command-line flags for nvcc.

    Returns:
        The original list (CUDA < 11.2) or a new list with the two extra
        ``--threads`` entries appended; the input list is never mutated.
    """
    cuda_version = get_cuda_bare_metal_version(CUDA_HOME)[1]

    # NOTE(review): presumably older nvcc releases do not accept --threads;
    # confirm against the NVCC documentation.
    if cuda_version < Version("11.2"):
        return nvcc_extra_args

    thread_count = os.getenv("NVCC_THREADS") or "4"
    return nvcc_extra_args + ["--threads", thread_count]
# Package definition: builds the single ``fused_dense_lib`` CUDA extension
# from one C++ source file and one CUDA source file.
setup(
    name="fused_dense_lib",
    ext_modules=[
        CUDAExtension(
            name="fused_dense_lib",
            sources=["fused_dense.cpp", "fused_dense_cuda.cu"],
            # Both the host compiler and nvcc build at -O3; nvcc additionally
            # receives whatever append_nvcc_threads() adds.
            extra_compile_args=dict(
                cxx=["-O3"],
                nvcc=append_nvcc_threads(["-O3"]),
            ),
        ),
    ],
    # Use PyTorch's BuildExtension as the ``build_ext`` command.
    cmdclass=dict(build_ext=BuildExtension),
)