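# Qwen-Image-Edit-Relight / optimization.py
# Last updated by linoyts (HF Staff): "Update optimization.py", commit 8cb8007 (verified); 2.1 kB.
"""Ahead-of-time (AoT) compilation of the pipeline transformer for Hugging Face Spaces.

Defines the ``torch.export`` dynamic-shape spec for the transformer's
sequence-length axes, the Inductor options passed to ``spaces.aoti_compile``,
and ``optimize_pipeline_``, which fuses the Qwen-Image-Lightning LoRA into the
pipeline, exports and compiles its transformer, and installs the compiled
module in place via ``spaces.aoti_apply``.
"""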
from typing import Any
from typing import Callable
from typing import ParamSpec
from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
import spaces
import torch
from torch.utils._pytree import tree_map
P = ParamSpec('P')
TRANSFORMER_IMAGE_SEQ_LENGTH_DIM = torch.export.Dim('image_seq_length')
TRANSFORMER_TEXT_SEQ_LENGTH_DIM = torch.export.Dim('text_seq_length')
TRANSFORMER_DYNAMIC_SHAPES = {
'hidden_states': {
1: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM,
},
'encoder_hidden_states': {
1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
},
'encoder_hidden_states_mask': {
1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
},
'image_rotary_emb': ({
0: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM,
}, {
0: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
}),
}
INDUCTOR_CONFIGS = {
'conv_1x1_as_mm': True,
'epilogue_fusion': False,
'coordinate_descent_tuning': True,
'coordinate_descent_check_all_directions': True,
'max_autotune': True,
'triton.cudagraphs': True,
}
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
    """Fuse the Lightning LoRA into ``pipeline`` and replace its transformer with an AoT-compiled one.

    Works in place on ``pipeline`` and returns ``None``.

    Inside a ``spaces.GPU``-decorated worker this:
      1. loads, fuses and unloads the ``lightx2v/Qwen-Image-Lightning``
         8-step LoRA (V1.1 weights);
      2. runs ``pipeline(*args, **kwargs)`` once under ``spaces.aoti_capture``
         to record the exact args/kwargs passed to ``pipeline.transformer``;
      3. exports the transformer with ``torch.export.export`` on those example
         inputs, marking the sequence-length axes from
         ``TRANSFORMER_DYNAMIC_SHAPES`` as dynamic;
      4. compiles the export with ``spaces.aoti_compile`` using
         ``INDUCTOR_CONFIGS``.
    The compiled artifact is then installed on ``pipeline.transformer`` with
    ``spaces.aoti_apply``.

    Args:
        pipeline: the pipeline to optimize. It must expose ``transformer``,
            ``load_lora_weights``, ``fuse_lora`` and ``unload_lora_weights``
            (presumably a diffusers pipeline — verify against caller).
        *args, **kwargs: an example invocation forwarded verbatim to
            ``pipeline``; the transformer inputs it produces become the
            export's example inputs.
    """

    # NOTE(review): 1500 is presumably the GPU time budget in seconds for
    # spaces.GPU — confirm against the `spaces` docs.
    @spaces.GPU(duration=1500)
    def compile_transformer():
        pipeline.load_lora_weights(
            "lightx2v/Qwen-Image-Lightning",
            weight_name="Qwen-Image-Lightning-8steps-V1.1.safetensors",
        )
        # Merge the LoRA deltas into the base weights, then drop the adapter
        # so the exported graph contains only the fused weights.
        pipeline.fuse_lora()
        pipeline.unload_lora_weights()

        # Record one real transformer call; `call.args` / `call.kwargs` hold
        # the inputs the pipeline actually passed.
        with spaces.aoti_capture(pipeline.transformer) as call:
            pipeline(*args, **kwargs)

        # Every captured kwarg leaf starts out static (None). Dynamic specs are
        # then applied only to inputs that were actually captured with a value:
        # torch.export rejects a dynamic_shapes dict whose keys or structure do
        # not match the inputs, and an absent/None input has no axis to make
        # dynamic anyway.
        dynamic_shapes = tree_map(lambda _leaf: None, call.kwargs)
        dynamic_shapes |= {
            name: spec
            for name, spec in TRANSFORMER_DYNAMIC_SHAPES.items()
            if call.kwargs.get(name) is not None
        }

        # NOTE: FP8 quantization via torchao
        # (quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig()))
        # is currently disabled at this point; the module-level torchao imports
        # are presumably kept so it can be re-enabled before export.

        exported = torch.export.export(
            mod=pipeline.transformer,
            args=call.args,
            kwargs=call.kwargs,
            dynamic_shapes=dynamic_shapes,
        )
        return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)

    spaces.aoti_apply(compile_transformer(), pipeline.transformer)