import sys
import warnings
import os
import re
import ast
from pathlib import Path
from packaging.version import parse, Version
import platform

from setuptools import setup, find_packages
import subprocess

import urllib.request
import urllib.error
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel

import torch
from torch.utils.cpp_extension import (
    BuildExtension,
    CppExtension,
    CUDAExtension,
    CUDA_HOME,
)


with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()


this_dir = os.path.dirname(os.path.abspath(__file__))

PACKAGE_NAME = "causal_conv1d"

BASE_WHEEL_URL = "https://github.com/Dao-AILab/causal-conv1d/releases/download/{tag_name}/{wheel_name}"

FORCE_BUILD = os.getenv("CAUSAL_CONV1D_FORCE_BUILD", "FALSE") == "TRUE"
SKIP_CUDA_BUILD = os.getenv("CAUSAL_CONV1D_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
FORCE_CXX11_ABI = os.getenv("CAUSAL_CONV1D_FORCE_CXX11_ABI", "FALSE") == "TRUE"
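# These switches are read from the environment at build time. For example, a source
# build can be forced instead of downloading a prebuilt wheel with (illustrative invocation):
#   CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install causal-conv1d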


def get_platform():
    """
    Returns the platform name as used in wheel filenames.
    """
    if sys.platform.startswith("linux"):
        return "linux_x86_64"
    elif sys.platform == "darwin":
        mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
        return f"macosx_{mac_version}_x86_64"
    elif sys.platform == "win32":
        return "win_amd64"
    else:
        raise ValueError("Unsupported platform: {}".format(sys.platform))


def get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output(
        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
    )
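    # The relevant line of `nvcc -V` output looks like (illustrative example):
    #   Cuda compilation tools, release 11.8, V11.8.89
    # so the token following "release" is taken and the trailing comma stripped.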
    output = raw_output.split()
    release_idx = output.index("release") + 1
    bare_metal_version = parse(output[release_idx].split(",")[0])
    return raw_output, bare_metal_version


def check_if_cuda_home_none(global_option: str) -> None:
    if CUDA_HOME is not None:
        return
    warnings.warn(
        f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
        "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
        "only images whose names contain 'devel' will provide nvcc."
    )


def append_nvcc_threads(nvcc_extra_args):
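    # nvcc's --threads flag (CUDA 11.2 and newer) parallelizes compilation across the
    # requested GPU architectures.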
    return nvcc_extra_args + ["--threads", "4"]


cmdclass = {}
ext_modules = []

if not SKIP_CUDA_BUILD:
    print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])

    check_if_cuda_home_none("causal_conv1d")

    cc_flag = []
    if CUDA_HOME is not None:
        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
        if bare_metal_version < Version("11.6"):
            raise RuntimeError(
                "causal_conv1d is only supported on CUDA 11.6 and above. "
                "Note: make sure nvcc has a supported version by running nvcc -V."
            )

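    # Target compute capabilities: sm_70 (Volta) and sm_80 (Ampere) always, plus
    # sm_90 (Hopper) when the CUDA toolkit is 11.8 or newer.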
    cc_flag.append("-gencode")
    cc_flag.append("arch=compute_70,code=sm_70")
    cc_flag.append("-gencode")
    cc_flag.append("arch=compute_80,code=sm_80")
    # bare_metal_version is only defined when CUDA_HOME was found above.
    if CUDA_HOME is not None and bare_metal_version >= Version("11.8"):
        cc_flag.append("-gencode")
        cc_flag.append("arch=compute_90,code=sm_90")

    if FORCE_CXX11_ABI:
        torch._C._GLIBCXX_USE_CXX11_ABI = True

    ext_modules.append(
        CUDAExtension(
            name="causal_conv1d_cuda",
            sources=[
                "csrc/causal_conv1d.cpp",
                "csrc/causal_conv1d_fwd.cu",
                "csrc/causal_conv1d_bwd.cu",
                "csrc/causal_conv1d_update.cu",
            ],
            extra_compile_args={
                "cxx": ["-O3"],
                "nvcc": append_nvcc_threads(
                    [
                        "-O3",
                        "-U__CUDA_NO_HALF_OPERATORS__",
                        "-U__CUDA_NO_HALF_CONVERSIONS__",
                        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
                        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
                        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
                        "--expt-relaxed-constexpr",
                        "--expt-extended-lambda",
                        "--use_fast_math",
                        "--ptxas-options=-v",
                        "-lineinfo",
                    ]
                    + cc_flag
                ),
            },
            include_dirs=[this_dir],
        )
    )


def get_package_version():
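    # Reads the version string from causal_conv1d/__init__.py, which is expected to
    # contain a line such as (version number is illustrative):
    #   __version__ = "1.0.0"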
    with open(Path(this_dir) / "causal_conv1d" / "__init__.py", "r") as f:
        version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
    public_version = ast.literal_eval(version_match.group(1))
    local_version = os.environ.get("CAUSAL_CONV1D_LOCAL_VERSION")
    if local_version:
        return f"{public_version}+{local_version}"
    else:
        return str(public_version)


def get_wheel_url():
    torch_cuda_version = parse(torch.version.cuda)
    torch_version_raw = parse(torch.__version__)
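    # Map the local torch CUDA version onto the CUDA versions the wheels are built for:
    # any 11.x maps to the cu118 wheel and everything else to the cu122 wheel
    # (assuming wheels are published only for those two CUDA versions).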
    torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
    python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
    platform_name = get_platform()
    causal_conv1d_version = get_package_version()

    cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
    torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()

    wheel_filename = f"{PACKAGE_NAME}-{causal_conv1d_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
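    # e.g. causal_conv1d-1.1.0+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
    # (the version numbers in this example are illustrative)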
    wheel_url = BASE_WHEEL_URL.format(
        tag_name=f"v{causal_conv1d_version}", wheel_name=wheel_filename
    )
    return wheel_url, wheel_filename


class CachedWheelsCommand(_bdist_wheel):
    """
    The CachedWheelsCommand plugs into the default bdist_wheel command, which is run by pip
    when it cannot find an existing wheel (which is currently the case for all installs). It
    uses the environment parameters to detect whether a compatible pre-built wheel is already
    available and, if so, short-circuits the standard full build pipeline.
    """

    def run(self):
        if FORCE_BUILD:
            return super().run()

        wheel_url, wheel_filename = get_wheel_url()
        print("Guessing wheel URL: ", wheel_url)
        try:
            urllib.request.urlretrieve(wheel_url, wheel_filename)

            if not os.path.exists(self.dist_dir):
                os.makedirs(self.dist_dir)

            impl_tag, abi_tag, plat_tag = self.get_tag()
            archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"

            wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
            print("Raw wheel path", wheel_path)
            os.rename(wheel_filename, wheel_path)
        except urllib.error.HTTPError:
            print("Precompiled wheel not found. Building from source...")
            super().run()


setup(
    name=PACKAGE_NAME,
    version=get_package_version(),
    packages=find_packages(
        exclude=(
            "build",
            "csrc",
            "include",
            "tests",
            "dist",
            "docs",
            "benchmarks",
            "causal_conv1d.egg-info",
        )
    ),
    author="Tri Dao",
    author_email="[email protected]",
    description="Causal depthwise conv1d in CUDA, with a PyTorch interface",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/Dao-AILab/causal-conv1d",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: BSD License",
        "Operating System :: Unix",
    ],
    ext_modules=ext_modules,
    cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension}
    if ext_modules
    else {
        "bdist_wheel": CachedWheelsCommand,
    },
    python_requires=">=3.7",
    install_requires=[
        "torch",
        "packaging",
        "ninja",
    ],
)