Spaces:
Running
on
Zero
Running
on
Zero
Compiled transformer
#12
by
cbensimon
HF Staff
- opened
- app.py +8 -0
- optimization.py +60 -0
- optimization_utils.py +96 -0
app.py
CHANGED
@@ -1,3 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
import numpy as np
|
3 |
import spaces
|
@@ -8,9 +13,12 @@ from PIL import Image
|
|
8 |
from diffusers import FluxKontextPipeline
|
9 |
from diffusers.utils import load_image
|
10 |
|
|
|
|
|
11 |
MAX_SEED = np.iinfo(np.int32).max
|
12 |
|
13 |
pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
|
|
|
14 |
|
15 |
@spaces.GPU
|
16 |
def infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, steps=28, progress=gr.Progress(track_tqdm=True)):
|
|
|
1 |
+
# PyTorch 2.8 (temporary hack)
|
2 |
+
import os
|
3 |
+
os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')
|
4 |
+
|
5 |
+
# Actual demo code
|
6 |
import gradio as gr
|
7 |
import numpy as np
|
8 |
import spaces
|
|
|
13 |
from diffusers import FluxKontextPipeline
|
14 |
from diffusers.utils import load_image
|
15 |
|
16 |
+
from optimization import optimize_pipeline_
|
17 |
+
|
18 |
MAX_SEED = np.iinfo(np.int32).max
|
19 |
|
20 |
pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
|
21 |
+
optimize_pipeline_(pipe, image=Image.new("RGB", (512, 512)), prompt='prompt')
|
22 |
|
23 |
@spaces.GPU
|
24 |
def infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, steps=28, progress=gr.Progress(track_tqdm=True)):
|
optimization.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
|
4 |
+
from typing import Any
|
5 |
+
from typing import Callable
|
6 |
+
from typing import ParamSpec
|
7 |
+
|
8 |
+
import spaces
|
9 |
+
import torch
|
10 |
+
from torch.utils._pytree import tree_map_only
|
11 |
+
|
12 |
+
from optimization_utils import capture_component_call
|
13 |
+
from optimization_utils import aoti_compile
|
14 |
+
|
15 |
+
|
16 |
+
P = ParamSpec('P')
|
17 |
+
|
18 |
+
|
19 |
+
TRANSFORMER_HIDDEN_DIM = torch.export.Dim('hidden', min=4096, max=8212)
|
20 |
+
|
21 |
+
TRANSFORMER_DYNAMIC_SHAPES = {
|
22 |
+
'hidden_states': {1: TRANSFORMER_HIDDEN_DIM},
|
23 |
+
'img_ids': {0: TRANSFORMER_HIDDEN_DIM},
|
24 |
+
}
|
25 |
+
|
26 |
+
INDUCTOR_CONFIGS = {
|
27 |
+
'conv_1x1_as_mm': True,
|
28 |
+
'epilogue_fusion': False,
|
29 |
+
'coordinate_descent_tuning': True,
|
30 |
+
'coordinate_descent_check_all_directions': True,
|
31 |
+
'max_autotune': True,
|
32 |
+
'triton.cudagraphs': True,
|
33 |
+
}
|
34 |
+
|
35 |
+
|
36 |
+
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
|
37 |
+
|
38 |
+
@spaces.GPU(duration=1500)
|
39 |
+
def compile_transformer():
|
40 |
+
|
41 |
+
with capture_component_call(pipeline, 'transformer') as call:
|
42 |
+
pipeline(*args, **kwargs)
|
43 |
+
|
44 |
+
dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
|
45 |
+
dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
|
46 |
+
|
47 |
+
pipeline.transformer.fuse_qkv_projections()
|
48 |
+
|
49 |
+
exported = torch.export.export(
|
50 |
+
mod=pipeline.transformer,
|
51 |
+
args=call.args,
|
52 |
+
kwargs=call.kwargs,
|
53 |
+
dynamic_shapes=dynamic_shapes,
|
54 |
+
)
|
55 |
+
|
56 |
+
return aoti_compile(exported, INDUCTOR_CONFIGS)
|
57 |
+
|
58 |
+
transformer_config = pipeline.transformer.config
|
59 |
+
pipeline.transformer = compile_transformer()
|
60 |
+
pipeline.transformer.config = transformer_config # pyright: ignore[reportAttributeAccessIssue]
|
optimization_utils.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
import contextlib
|
4 |
+
from contextvars import ContextVar
|
5 |
+
from io import BytesIO
|
6 |
+
from typing import Any
|
7 |
+
from typing import cast
|
8 |
+
from unittest.mock import patch
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from torch._inductor.package.package import package_aoti
|
12 |
+
from torch.export.pt2_archive._package import AOTICompiledModel
|
13 |
+
from torch.export.pt2_archive._package_weights import TensorProperties
|
14 |
+
from torch.export.pt2_archive._package_weights import Weights
|
15 |
+
|
16 |
+
|
17 |
+
INDUCTOR_CONFIGS_OVERRIDES = {
|
18 |
+
'aot_inductor.package_constants_in_so': False,
|
19 |
+
'aot_inductor.package_constants_on_disk': True,
|
20 |
+
'aot_inductor.package': True,
|
21 |
+
}
|
22 |
+
|
23 |
+
|
24 |
+
class ZeroGPUCompiledModel:
|
25 |
+
def __init__(self, archive_file: torch.types.FileLike, weights: Weights, cuda: bool = False):
|
26 |
+
self.archive_file = archive_file
|
27 |
+
self.weights = weights
|
28 |
+
if cuda:
|
29 |
+
self.weights_to_cuda_()
|
30 |
+
self.compiled_model: ContextVar[AOTICompiledModel | None] = ContextVar('compiled_model', default=None)
|
31 |
+
def weights_to_cuda_(self):
|
32 |
+
for name in self.weights:
|
33 |
+
tensor, properties = self.weights.get_weight(name)
|
34 |
+
self.weights[name] = (tensor.to('cuda'), properties)
|
35 |
+
def __call__(self, *args, **kwargs):
|
36 |
+
if (compiled_model := self.compiled_model.get()) is None:
|
37 |
+
constants_map = {name: value[0] for name, value in self.weights.items()}
|
38 |
+
compiled_model = cast(AOTICompiledModel, torch._inductor.aoti_load_package(self.archive_file))
|
39 |
+
compiled_model.load_constants(constants_map, check_full_update=True, user_managed=True)
|
40 |
+
self.compiled_model.set(compiled_model)
|
41 |
+
return compiled_model(*args, **kwargs)
|
42 |
+
def __reduce__(self):
|
43 |
+
weight_dict: dict[str, tuple[torch.Tensor, TensorProperties]] = {}
|
44 |
+
for name in self.weights:
|
45 |
+
tensor, properties = self.weights.get_weight(name)
|
46 |
+
tensor_ = torch.empty_like(tensor, device='cpu').pin_memory()
|
47 |
+
weight_dict[name] = (tensor_.copy_(tensor).detach().share_memory_(), properties)
|
48 |
+
return ZeroGPUCompiledModel, (self.archive_file, Weights(weight_dict), True)
|
49 |
+
|
50 |
+
|
51 |
+
def aoti_compile(
|
52 |
+
exported_program: torch.export.ExportedProgram,
|
53 |
+
inductor_configs: dict[str, Any] | None = None,
|
54 |
+
):
|
55 |
+
inductor_configs = (inductor_configs or {}) | INDUCTOR_CONFIGS_OVERRIDES
|
56 |
+
gm = cast(torch.fx.GraphModule, exported_program.module())
|
57 |
+
assert exported_program.example_inputs is not None
|
58 |
+
args, kwargs = exported_program.example_inputs
|
59 |
+
artifacts = torch._inductor.aot_compile(gm, args, kwargs, options=inductor_configs)
|
60 |
+
archive_file = BytesIO()
|
61 |
+
files: list[str | Weights] = [file for file in artifacts if isinstance(file, str)]
|
62 |
+
package_aoti(archive_file, files)
|
63 |
+
weights, = (artifact for artifact in artifacts if isinstance(artifact, Weights))
|
64 |
+
return ZeroGPUCompiledModel(archive_file, weights)
|
65 |
+
|
66 |
+
|
67 |
+
@contextlib.contextmanager
|
68 |
+
def capture_component_call(
|
69 |
+
pipeline: Any,
|
70 |
+
component_name: str,
|
71 |
+
component_method='forward',
|
72 |
+
):
|
73 |
+
|
74 |
+
class CapturedCallException(Exception):
|
75 |
+
def __init__(self, *args, **kwargs):
|
76 |
+
super().__init__()
|
77 |
+
self.args = args
|
78 |
+
self.kwargs = kwargs
|
79 |
+
|
80 |
+
class CapturedCall:
|
81 |
+
def __init__(self):
|
82 |
+
self.args: tuple[Any, ...] = ()
|
83 |
+
self.kwargs: dict[str, Any] = {}
|
84 |
+
|
85 |
+
component = getattr(pipeline, component_name)
|
86 |
+
captured_call = CapturedCall()
|
87 |
+
|
88 |
+
def capture_call(*args, **kwargs):
|
89 |
+
raise CapturedCallException(*args, **kwargs)
|
90 |
+
|
91 |
+
with patch.object(component, component_method, new=capture_call):
|
92 |
+
try:
|
93 |
+
yield captured_call
|
94 |
+
except CapturedCallException as e:
|
95 |
+
captured_call.args = e.args
|
96 |
+
captured_call.kwargs = e.kwargs
|