Spaces:
Sleeping
Sleeping
# Copied from https://github.com/NVIDIA/apex/tree/master/csrc/megatron
# with added support for sequence lengths of 4k and 8k.
import os | |
import subprocess | |
import torch | |
from setuptools import setup | |
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME | |
def get_cuda_bare_metal_version(cuda_dir):
    """Query ``nvcc -V`` and extract the CUDA release's major/minor version.

    Args:
        cuda_dir: Path of the CUDA installation; ``/bin/nvcc`` is appended
            to it to locate the compiler.

    Returns:
        A tuple ``(raw_output, major, minor)`` where ``raw_output`` is the
        full text printed by ``nvcc -V`` and ``major`` / ``minor`` are the
        version components as strings (e.g. ``"11"`` and ``"8"``).

    Raises:
        subprocess.CalledProcessError: If nvcc exits with a non-zero status.
        ValueError: If the output contains no ``release`` token.
    """
    raw_output = subprocess.check_output(
        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
    )
    tokens = raw_output.split()
    # The token following "release" looks like "11.8," -- drop the trailing
    # comma, then split on the dots.
    version_fields = tokens[tokens.index("release") + 1].rstrip(",").split(".")
    bare_metal_major = version_fields[0]
    # Keep the whole minor field: taking only its first character would
    # truncate a multi-digit minor such as "10" to "1".
    bare_metal_minor = version_fields[1]
    return raw_output, bare_metal_major, bare_metal_minor
def append_nvcc_threads(nvcc_extra_args):
    """Add ``--threads N`` to the nvcc arguments when the toolkit is new enough.

    The flag is only added when the nvcc found under ``CUDA_HOME`` reports
    release 11.2 or newer. The thread count is read from the
    ``NVCC_THREADS`` environment variable and falls back to ``"4"`` when
    that variable is unset or empty.

    Args:
        nvcc_extra_args: List of extra command-line arguments for nvcc.

    Returns:
        A new list with ``["--threads", N]`` appended when the flag is
        enabled; otherwise ``nvcc_extra_args`` itself, unchanged.
    """
    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
    # Compare (major, minor) as a pair. Checking the two fields independently
    # would wrongly reject newer releases whose minor is below 2, such as
    # 12.0 or 12.1.
    if (int(bare_metal_major), int(bare_metal_minor)) >= (11, 2):
        nvcc_threads = os.getenv("NVCC_THREADS") or "4"
        return nvcc_extra_args + ["--threads", nvcc_threads]
    return nvcc_extra_args
# One "-gencode" / target pair per architecture the kernels are built for
# (sm_70 and sm_80).
cc_flag = [
    "-gencode",
    "arch=compute_70,code=sm_70",
    "-gencode",
    "arch=compute_80,code=sm_80",
]

# Build the fused softmax extension from its C++ binding and the two CUDA
# kernel sources; nvcc gets the architecture flags above plus, when
# append_nvcc_threads enables it, the --threads option.
setup(
    name='fused_softmax_lib',
    ext_modules=[
        CUDAExtension(
            name='fused_softmax_lib',
            sources=[
                'fused_softmax.cpp',
                'scaled_masked_softmax_cuda.cu',
                'scaled_upper_triang_masked_softmax_cuda.cu',
            ],
            extra_compile_args={
                'cxx': ['-O3'],
                'nvcc': append_nvcc_threads(
                    ['-O3', '--use_fast_math'] + cc_flag
                ),
            },
        ),
    ],
    cmdclass={'build_ext': BuildExtension},
)