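# wfx - load a diffusers text-to-image pipeline, optionally quantize and
# compile it (stable-fast or torch.compile), then run timed warmup passes.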
import argparse
import logging
from typing import Optional

import torch
import packaging.version as pv

from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
from sfast.compilers.diffusion_pipeline_compiler import (compile, CompilationConfig)

logging.basicConfig(level=logging.INFO, format='%(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('wfx')
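
# PyTorch 1.12 switched TF32 matmuls off by default; re-enabling them speeds
# up fp32 matmuls on Ampere and newer GPUs at slightly reduced precision.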
if pv.parse(torch.__version__) >= pv.parse('1.12.0'):
    torch.backends.cuda.matmul.allow_tf32 = True
    logger.info('torch >= 1.12.0 detected, enabling tf32')

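
# Example invocation (script name and model id are illustrative):
#   python wfx.py --model stabilityai/sd-turbo --compile-mode sfast --enable-cuda-graph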
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--disable-xformers', action='store_true', default=False)
    parser.add_argument('--disable-triton', action='store_true', default=False)
    parser.add_argument('--quantize-unet', action='store_true', default=False)
    parser.add_argument('--model', type=str, required=True)
    parser.add_argument('--custom-pipeline', type=str, default=None)
    parser.add_argument('--compile-mode', default='sfast', type=str, choices=['sfast', 'torch', 'no-compile'])
    parser.add_argument('--enable-cuda-graph', action='store_true', default=False)
    parser.add_argument('--disable-prefer-lowp-gemm', action='store_true', default=False)
    return parser.parse_args()

def quantize_unet(m):
    from diffusers.utils import USE_PEFT_BACKEND
    assert USE_PEFT_BACKEND, 'dynamic unet quantization requires the PEFT backend'

    logger.info('PEFT backend detected, quantizing unet...')
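
    # Dynamic qint8 quantization: Linear weights are converted ahead of time,
    # activations are quantized on the fly; inplace=True mutates `m` directly.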
    m = torch.quantization.quantize_dynamic(
        m, {torch.nn.Linear},
        dtype=torch.qint8,
        inplace=True
    )

    logger.info('unet successfully quantized')
    return m

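
# Holds the shared stable-fast CompilationConfig and the loaded pipelines.
# I2IPipeline is declared for future image-to-image use but is never loaded here.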
class WFX():
    compiler_config: CompilationConfig.Default = CompilationConfig.Default()
    T2IPipeline: Optional[AutoPipelineForText2Image] = None
    I2IPipeline: Optional[AutoPipelineForImage2Image] = None

    def __init__(self) -> None:
        # parse once and keep the result so load() does not re-parse argv
        self.args = parse_args()
        self._check_optimization(self.args)

    def _check_optimization(self, args) -> None:
        logger.info(f'torch version: {torch.__version__}')
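
        # xformers and triton are optional accelerators: enable each one only
        # if its import succeeds, otherwise warn and fall back.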
        if not args.disable_xformers:
            try:
                import xformers
                self.compiler_config.enable_xformers = True
                logger.info(f'xformers version: {xformers.__version__}')
            except ImportError:
                logger.warning('xformers not found, disabling xformers')

        if not args.disable_triton:
            try:
                import triton
                self.compiler_config.enable_triton = True
                logger.info(f'triton version: {triton.__version__}')
            except ImportError:
                logger.warning('triton not found, disabling triton')

        self.compiler_config.enable_cuda_graph = args.enable_cuda_graph

        if args.disable_prefer_lowp_gemm:
            self.compiler_config.prefer_lowp_gemm = False
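
        # log the final compiler config so every run is reproducible from its log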
        for key in self.compiler_config.__dict__:
            logger.info(f'cc - {key}: {self.compiler_config.__dict__[key]}')

    def load(self) -> None:
        args = self.args
        extra_kwargs = {
            'torch_dtype': torch.float16,
            'use_safetensors': True,
            'requires_safety_checker': False,
        }
        if args.custom_pipeline is not None:
            logger.info(f'loading custom pipeline from "{args.custom_pipeline}"')
            extra_kwargs['custom_pipeline'] = args.custom_pipeline

        self.T2IPipeline = AutoPipelineForText2Image.from_pretrained(args.model, **extra_kwargs)
        self.T2IPipeline.safety_checker = None
        self.T2IPipeline.to(torch.device('cuda:0'))
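
        # quantize before compiling so the compiled graph sees the quantized modules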
        if args.quantize_unet:
            self.T2IPipeline.unet = quantize_unet(self.T2IPipeline.unet)

        logger.info(f'compiling pipeline in {args.compile_mode} mode...')
        if args.compile_mode == 'sfast':
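            # stable-fast compiles the whole pipeline object in a single call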
            self.T2IPipeline = compile(self.T2IPipeline, self.compiler_config)
        elif args.compile_mode == 'torch':
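            # torch.compile works per-module, so compile the two heaviest parts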
            logger.info('compiling unet...')
            self.T2IPipeline.unet = torch.compile(self.T2IPipeline.unet, mode='max-autotune')
            logger.info('compiling vae...')
            self.T2IPipeline.vae = torch.compile(self.T2IPipeline.vae, mode='max-autotune')

        self.warmup()

    def warmup(self) -> None:
        warmed_total = 5

        warmup_kwargs = dict(
            prompt='a photo of a cat',
            height=768,
            width=512,
            num_inference_steps=30,
            generator=torch.Generator(device='cuda:0').manual_seed(0),
        )
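
        # the first passes trigger JIT compilation and CUDA graph capture, so
        # they run much slower than steady state; time each pass with CUDA events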
        logger.info('warming up T2I pipeline...')
        for warmed in range(1, warmed_total + 1):
            begin = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)

            begin.record()
            self.T2IPipeline(**warmup_kwargs)
            end.record()

            # elapsed_time() is only valid after both events have completed
            torch.cuda.synchronize()
            elapsed_time = begin.elapsed_time(end)

            logger.info(f'warmed {warmed}/{warmed_total} - {elapsed_time:.2f}ms')


if __name__ == '__main__':
    wfx = WFX()
    wfx.load()