#!/usr/bin/env python from setuptools import find_packages, setup import os import subprocess import sys import time import torch from torch.utils.cpp_extension import (BuildExtension, CppExtension, CUDAExtension) version_file = 'basicsr/version.py' def readme(): return '' # with open('README.md', encoding='utf-8') as f: # content = f.read() # return content def get_git_hash(): def _minimal_ext_cmd(cmd): # construct minimal environment env = {} for k in ['SYSTEMROOT', 'PATH', 'HOME']: v = os.environ.get(k) if v is not None: env[k] = v # LANGUAGE is used on win32 env['LANGUAGE'] = 'C' env['LANG'] = 'C' env['LC_ALL'] = 'C' out = subprocess.Popen( cmd, stdout=subprocess.PIPE, env=env).communicate()[0] return out try: out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) sha = out.strip().decode('ascii') except OSError: sha = 'unknown' return sha def get_hash(): if os.path.exists('.git'): sha = get_git_hash()[:7] elif os.path.exists(version_file): try: from basicsr.version import __version__ sha = __version__.split('+')[-1] except ImportError: raise ImportError('Unable to get git version') else: sha = 'unknown' return sha def write_version_py(): content = """# GENERATED VERSION FILE # TIME: {} __version__ = '{}' short_version = '{}' version_info = ({}) """ sha = get_hash() with open('VERSION', 'r') as f: SHORT_VERSION = f.read().strip() VERSION_INFO = ', '.join( [x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')]) VERSION = SHORT_VERSION + '+' + sha version_file_str = content.format(time.asctime(), VERSION, SHORT_VERSION, VERSION_INFO) with open(version_file, 'w') as f: f.write(version_file_str) def get_version(): with open(version_file, 'r') as f: exec(compile(f.read(), version_file, 'exec')) return locals()['__version__'] def make_cuda_ext(name, module, sources, sources_cuda=None): if sources_cuda is None: sources_cuda = [] define_macros = [] extra_compile_args = {'cxx': []} if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1': define_macros += [('WITH_CUDA', None)] extension = CUDAExtension extra_compile_args['nvcc'] = [ '-D__CUDA_NO_HALF_OPERATORS__', '-D__CUDA_NO_HALF_CONVERSIONS__', '-D__CUDA_NO_HALF2_OPERATORS__', ] sources += sources_cuda else: print(f'Compiling {name} without CUDA') extension = CppExtension return extension( name=f'{module}.{name}', sources=[os.path.join(*module.split('.'), p) for p in sources], define_macros=define_macros, extra_compile_args=extra_compile_args) def get_requirements(filename='requirements.txt'): return [] here = os.path.dirname(os.path.realpath(__file__)) with open(os.path.join(here, filename), 'r') as f: requires = [line.replace('\n', '') for line in f.readlines()] return requires if __name__ == '__main__': if '--no_cuda_ext' in sys.argv: ext_modules = [] sys.argv.remove('--no_cuda_ext') else: ext_modules = [ make_cuda_ext( name='deform_conv_ext', module='basicsr.models.ops.dcn', sources=['src/deform_conv_ext.cpp'], sources_cuda=[ 'src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu' ]), make_cuda_ext( name='fused_act_ext', module='basicsr.models.ops.fused_act', sources=['src/fused_bias_act.cpp'], sources_cuda=['src/fused_bias_act_kernel.cu']), make_cuda_ext( name='upfirdn2d_ext', module='basicsr.models.ops.upfirdn2d', sources=['src/upfirdn2d.cpp'], sources_cuda=['src/upfirdn2d_kernel.cu']), ] write_version_py() setup( name='basicsr', version=get_version(), description='Open Source Image and Video Super-Resolution Toolbox', long_description=readme(), author='Xintao Wang', author_email='xintao.wang@outlook.com', keywords='computer vision, restoration, super resolution', url='https://github.com/xinntao/BasicSR', packages=find_packages( exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')), classifiers=[ 'Development Status :: 4 - Beta', 'License :: OSI Approved :: Apache Software License', 'Operating System :: OS Independent', 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', ], license='Apache License 2.0', setup_requires=['cython', 'numpy'], install_requires=get_requirements(), ext_modules=ext_modules, cmdclass={'build_ext': BuildExtension}, zip_safe=False)