Spaces:
Sleeping
Sleeping
File size: 1,304 Bytes
e45d058 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
import os
import subprocess
from packaging.version import parse, Version
import torch
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
def get_cuda_bare_metal_version(cuda_dir):
    """Run ``nvcc -V`` from *cuda_dir* and return its raw output and release version.

    Args:
        cuda_dir: Root of the CUDA toolkit (typically ``CUDA_HOME``). ``nvcc`` is
            expected at ``<cuda_dir>/bin/nvcc``. A ``str`` or path-like object.

    Returns:
        A ``(raw_output, bare_metal_version)`` tuple:
        - ``raw_output`` is the full text printed by ``nvcc -V``.
        - ``bare_metal_version`` is the token following ``release``, with its
          trailing comma stripped, parsed by ``packaging.version.parse``.

    Raises:
        RuntimeError: if *cuda_dir* is ``None`` (no CUDA toolkit was located),
            or if the ``nvcc -V`` output contains no ``release <version>`` token.
        OSError / subprocess.CalledProcessError: propagated unchanged when nvcc
            cannot be executed or exits non-zero.
    """
    if cuda_dir is None:
        # torch.utils.cpp_extension leaves CUDA_HOME as None when it cannot find
        # a toolkit. Report that directly instead of letting `None + str`
        # surface as an unrelated-looking TypeError.
        raise RuntimeError(
            "CUDA toolkit not found (CUDA_HOME is None); set CUDA_HOME to the "
            "CUDA installation root before building this extension."
        )

    nvcc = os.path.join(cuda_dir, "bin", "nvcc")
    raw_output = subprocess.check_output([nvcc, "-V"], universal_newlines=True)
    output = raw_output.split()

    # nvcc prints a line such as "Cuda compilation tools, release 11.8, V11.8.89".
    # The token after "release" carries a trailing comma that must be dropped.
    try:
        release_token = output[output.index("release") + 1]
    except (ValueError, IndexError):
        raise RuntimeError(
            f"could not find the CUDA release version in `{nvcc} -V` output:\n"
            f"{raw_output}"
        ) from None

    bare_metal_version = parse(release_token.split(",")[0])
    return raw_output, bare_metal_version
def append_nvcc_threads(nvcc_extra_args):
    """Return *nvcc_extra_args*, plus ``--threads N`` if the toolkit's nvcc supports it.

    The CUDA version is read from the nvcc found under ``CUDA_HOME``. Only CUDA
    11.2 and newer get the flag, since older nvcc releases do not accept
    ``--threads``.

    ``N`` comes from the ``NVCC_THREADS`` environment variable. It falls back to
    ``"4"`` when that variable is unset or empty.

    Args:
        nvcc_extra_args: List of nvcc command-line arguments.

    Returns:
        The original list when the toolkit is older than 11.2. Otherwise, a new
        list with ``"--threads", N`` appended; the input list is not mutated.
    """
    toolkit_version = get_cuda_bare_metal_version(CUDA_HOME)[1]
    if toolkit_version < Version("11.2"):
        return nvcc_extra_args

    thread_count = os.getenv("NVCC_THREADS") or "4"
    return nvcc_extra_args + ["--threads", thread_count]
# Single CUDA extension built from one C++ binding source and one CUDA source.
# Both the host compiler and nvcc use -O3; nvcc may additionally get
# --threads via append_nvcc_threads.
fused_dense_extension = CUDAExtension(
    name="fused_dense_lib",
    sources=["fused_dense.cpp", "fused_dense_cuda.cu"],
    extra_compile_args={
        "cxx": ["-O3"],
        "nvcc": append_nvcc_threads(["-O3"]),
    },
)

setup(
    name="fused_dense_lib",
    ext_modules=[fused_dense_extension],
    # torch's BuildExtension knows how to drive nvcc for the .cu source.
    cmdclass={"build_ext": BuildExtension},
)
|