owo
Browse files
wfx.py
CHANGED
@@ -21,6 +21,7 @@ def parse_args():
|
|
21 |
args.add_argument('--quantize-unet', action='store_true', default=False)
|
22 |
args.add_argument('--model', type=str, required=True)
|
23 |
args.add_argument('--custom-pipeline', type=str, default=None)
|
|
|
24 |
return args.parse_args()
|
25 |
|
26 |
def quantize_unet(m):
|
@@ -91,8 +92,14 @@ class WFX():
|
|
91 |
if args.quantize_unet:
|
92 |
self.T2IPipeline.unet = quantize_unet(self.T2IPipeline.unet)
|
93 |
|
94 |
-
logger.info('compiling
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
self.warmup()
|
98 |
|
|
|
21 |
args.add_argument('--quantize-unet', action='store_true', default=False)
|
22 |
args.add_argument('--model', type=str, required=True)
|
23 |
args.add_argument('--custom-pipeline', type=str, default=None)
|
24 |
+
args.add_argument('--compile-mode', default='sfast', type=str, choices=['sfast', 'torch', 'no-compile'])
|
25 |
return args.parse_args()
|
26 |
|
27 |
def quantize_unet(m):
|
|
|
92 |
if args.quantize_unet:
|
93 |
self.T2IPipeline.unet = quantize_unet(self.T2IPipeline.unet)
|
94 |
|
95 |
+
logger.info(f'compiling pipeline in {args.compile_mode} mode...')
|
96 |
+
if args.compile_mode == 'sfast':
|
97 |
+
self.T2IPipeline = compile(self.T2IPipeline, self.compiler_config)
|
98 |
+
elif args.compile_mode == 'torch':
|
99 |
+
logger.info('compiling unet...')
|
100 |
+
self.T2IPipeline.unet = torch.compile(self.T2IPipeline.unet, mode='max-autotune')
|
101 |
+
logger.info('compiling vae...')
|
102 |
+
self.T2IPipeline.vae = torch.compile(self.T2IPipeline.vae, mode='max-autotune')
|
103 |
|
104 |
self.warmup()
|
105 |
|