File size: 5,360 Bytes
877e614 9095ff2 877e614 d6b300f fdc9e50 877e614 fdc9e50 877e614 8c4b798 877e614 d6b300f 877e614 753e002 750d549 753e002 877e614 a299603 877e614 753e002 fb27579 7e20e0a 750d549 fb27579 750d549 fb27579 877e614 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import argparse, torch, logging
import packaging.version as pv
from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
from sfast.compilers.diffusion_pipeline_compiler import (compile, CompilationConfig)
# Root logger setup for the whole script; module logger is named 'wfx'.
logging.basicConfig(level=logging.INFO, format='%(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('wfx')

# The TF32 matmul switch exists from torch 1.12.0 onwards; enable it when available.
if pv.parse('1.12.0') <= pv.parse(torch.__version__):
    torch.backends.cuda.matmul.allow_tf32 = True
    # torch.backends.cudnn.allow_tf32 = True  # not sure...
    logger.info('matching torch version, enabling tf32')
def parse_args():
    """Build the CLI for the WFX runner and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the model path, optional custom pipeline,
        compile mode, and the various enable/disable optimization flags.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--disable-xformers', action='store_true', default=False)
    parser.add_argument('--disable-triton', action='store_true', default=False)
    parser.add_argument('--quantize-unet', action='store_true', default=False)
    parser.add_argument('--model', type=str, required=True)
    parser.add_argument('--custom-pipeline', type=str, default=None)
    parser.add_argument(
        '--compile-mode',
        default='sfast',
        type=str,
        choices=['sfast', 'torch', 'no-compile'],
    )
    parser.add_argument('--enable-cuda-graph', action='store_true', default=False)
    parser.add_argument('--disable-prefer-lowp-gemm', action='store_true', default=False)
    return parser.parse_args()
def quantize_unet(m):
    """Dynamically quantize the Linear layers of a UNet module to int8.

    Args:
        m: the UNet module to quantize; modified in place (``inplace=True``).

    Returns:
        The quantized module.

    Raises:
        RuntimeError: if the installed diffusers lacks the PEFT backend,
            which this quantization path requires.
    """
    from diffusers.utils import USE_PEFT_BACKEND
    # `assert` is stripped under `python -O`; raise explicitly so the
    # requirement check cannot be silently disabled.
    if not USE_PEFT_BACKEND:
        raise RuntimeError('diffusers PEFT backend is required to quantize the unet')
    logger.info('PEFT backend detected, quantizing unet...')
    m = torch.quantization.quantize_dynamic(
        m, {torch.nn.Linear},
        dtype=torch.qint8,
        inplace=True,
    )
    logger.info('unet successfully quantized')
    return m
class WFX():
    """Wrapper around a diffusers text-to-image pipeline.

    Parses CLI options, probes optional acceleration libraries
    (xformers/triton) to fill in the sfast compiler config, then loads,
    optionally quantizes, compiles, and warms up the pipeline on cuda:0.
    """
    # Class-level annotations document the attributes. __init__ shadows
    # compiler_config with a per-instance copy so flag mutation cannot
    # leak between instances through the shared class attribute.
    compiler_config: CompilationConfig.Default = CompilationConfig.Default()
    T2IPipeline: AutoPipelineForText2Image = None
    I2IPipeline: AutoPipelineForImage2Image = None  # reserved; never assigned in this file

    def __init__(self) -> None:
        # Parse argv once and keep the namespace; the original code
        # re-parsed sys.argv a second time inside load().
        self.args = parse_args()
        self.compiler_config = CompilationConfig.Default()
        self._check_optimization(self.args)

    def _check_optimization(self, args) -> None:
        """Probe optional libraries and populate ``self.compiler_config``."""
        logger.info(f'torch version: {torch.__version__}')
        if not args.disable_xformers:
            try:
                import xformers
                self.compiler_config.enable_xformers = True
                logger.info(f'xformers version: {xformers.__version__}')
            except ImportError:
                logger.warning('xformers not found, disabling xformers')
        if not args.disable_triton:
            try:
                import triton
                self.compiler_config.enable_triton = True
                logger.info(f'triton version: {triton.__version__}')
            except ImportError:
                logger.warning('triton not found, disabling triton')
        self.compiler_config.enable_cuda_graph = args.enable_cuda_graph
        if args.disable_prefer_lowp_gemm:
            self.compiler_config.prefer_lowp_gemm = False
        # Log the final compiler configuration for debugging.
        for key, value in self.compiler_config.__dict__.items():
            logger.info(f'cc - {key}: {value}')

    def load(self) -> None:
        """Load the pipeline from --model, apply the selected compile mode, warm up."""
        args = self.args
        extra_kwargs = {
            'torch_dtype': torch.float16,
            'use_safetensors': True,
            'requires_safety_checker': False,
        }
        if args.custom_pipeline is not None:
            logger.info(f'loading custom pipeline from "{args.custom_pipeline}"')
            extra_kwargs['custom_pipeline'] = args.custom_pipeline
        self.T2IPipeline = AutoPipelineForText2Image.from_pretrained(args.model, **extra_kwargs)
        self.T2IPipeline.safety_checker = None
        self.T2IPipeline.to(torch.device('cuda:0'))
        if args.quantize_unet:
            self.T2IPipeline.unet = quantize_unet(self.T2IPipeline.unet)
        logger.info(f'compiling pipeline in {args.compile_mode} mode...')
        if args.compile_mode == 'sfast':
            self.T2IPipeline = compile(self.T2IPipeline, self.compiler_config)
        elif args.compile_mode == 'torch':
            logger.info('compiling unet...')
            self.T2IPipeline.unet = torch.compile(self.T2IPipeline.unet, mode='max-autotune')
            logger.info('compiling vae...')
            self.T2IPipeline.vae = torch.compile(self.T2IPipeline.vae, mode='max-autotune')
        # 'no-compile' mode intentionally leaves the pipeline as loaded.
        self.warmup()

    def warmup(self) -> None:
        """Run a few fixed-seed generations so JIT/compilation caches are hot."""
        warmup_total = 5
        warmup_kwargs = dict(
            prompt='a photo of a cat',
            height=768,
            width=512,
            num_inference_steps=30,
            generator=torch.Generator(device='cuda:0').manual_seed(0),
        )
        logger.info(f'warming up T2I pipeline...')
        # Original kept a manual counter plus a dead `if warmed < warmed_total`
        # guard (always true with warmed == 0); the range loop replaces both.
        for warmed in range(1, warmup_total + 1):
            begin = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            begin.record()
            self.T2IPipeline(**warmup_kwargs)
            end.record()
            # CUDA events record asynchronously; synchronize before reading timings.
            torch.cuda.synchronize()
            elapsed_time = begin.elapsed_time(end)
            logger.info(f'warmed {warmed}/{warmup_total} - {elapsed_time:.2f}ms')
if __name__ == '__main__':
    # Script entry point: construct the wrapper (parses argv) and load the pipeline.
    runner = WFX()
    runner.load()