RobertML committed
Commit 33e938e · verified · 1 Parent(s): 2ed46e4

Add files using upload-large-folder tool

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ RobertML.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,2 @@
+ # flux-schnell-edge-inference
+ nestas hagunnan hinase
RobertML.png ADDED

Git LFS Details

  • SHA256: 7a6153fd5e5da780546d39bcf643fc4769f435dcbefd02d167706227b8489e6a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.16 MB
loss_params.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b0ee6fa5873dbc8df9daeeb105e220266bcf6634c6806b69da38fdc0a5c12b81
+ size 3184
pyproject.toml ADDED
@@ -0,0 +1,44 @@
+ [build-system]
+ requires = ["setuptools >= 75.0"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "flux-schnell-edge-inference"
+ description = "An edge-maxxing model submission by RobertML for the 4090 Flux contest"
+ requires-python = ">=3.10,<3.13"
+ version = "8"
+ dependencies = [
+     "diffusers==0.31.0",
+     "transformers==4.46.2",
+     "accelerate==1.1.0",
+     "omegaconf==2.3.0",
+     "torch==2.6.0",
+     "protobuf==5.28.3",
+     "sentencepiece==0.2.0",
+     "edge-maxxing-pipelines @ git+https://github.com/womboai/edge-maxxing@7c760ac54f6052803dadb3ade8ebfc9679a94589#subdirectory=pipelines",
+     "gitpython>=3.1.43",
+     "hf_transfer==0.1.8",
+     "torchao==0.6.1",
+ ]
+
+ [[tool.edge-maxxing.models]]
+ repository = "black-forest-labs/FLUX.1-schnell"
+ revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"
+ exclude = ["transformer"]
+
+ [[tool.edge-maxxing.models]]
+ repository = "RobertML/FLUX.1-schnell-int8wo"
+ revision = "307e0777d92df966a3c0f99f31a6ee8957a9857a"
+
+ [[tool.edge-maxxing.models]]
+ repository = "city96/t5-v1_1-xxl-encoder-bf16"
+ revision = "1b9c856aadb864af93c1dcdc226c2774fa67bc86"
+
+ [[tool.edge-maxxing.models]]
+ repository = "RobertML/FLUX.1-schnell-vae_e3m2"
+ revision = "da0d2cd7815792fb40d084dbd8ed32b63f153d8d"
+
+ [project.scripts]
+ start_inference = "main:main"
src/__pycache__/main.cpython-311.pyc ADDED
Binary file (4.42 kB)
 
src/__pycache__/pipeline.cpython-311.pyc ADDED
Binary file (4.74 kB)
 
src/first_block_cache/__init__.py ADDED
File without changes
src/first_block_cache/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (151 Bytes)
 
src/first_block_cache/__pycache__/utils.cpython-311.pyc ADDED
Binary file (10.5 kB)
 
src/first_block_cache/diffusers_adapters/__init__.py ADDED
@@ -0,0 +1,45 @@
+ import importlib
+
+ from diffusers import DiffusionPipeline
+
+
+ def apply_cache_on_transformer(transformer, *args, **kwargs):
+     transformer_cls_name = transformer.__class__.__name__
+     if False:
+         pass
+     elif transformer_cls_name.startswith("Flux"):
+         adapter_name = "flux"
+     elif transformer_cls_name.startswith("Mochi"):
+         adapter_name = "mochi"
+     elif transformer_cls_name.startswith("CogVideoX"):
+         adapter_name = "cogvideox"
+     elif transformer_cls_name.startswith("HunyuanVideo"):
+         adapter_name = "hunyuan_video"
+     else:
+         raise ValueError(f"Unknown transformer class name: {transformer_cls_name}")
+
+     adapter_module = importlib.import_module(f".{adapter_name}", __package__)
+     apply_cache_on_transformer_fn = getattr(adapter_module, "apply_cache_on_transformer")
+     return apply_cache_on_transformer_fn(transformer, *args, **kwargs)
+
+
+ def apply_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
+     assert isinstance(pipe, DiffusionPipeline)
+
+     pipe_cls_name = pipe.__class__.__name__
+     if False:
+         pass
+     elif pipe_cls_name.startswith("Flux"):
+         adapter_name = "flux"
+     elif pipe_cls_name.startswith("Mochi"):
+         adapter_name = "mochi"
+     elif pipe_cls_name.startswith("CogVideoX"):
+         adapter_name = "cogvideox"
+     elif pipe_cls_name.startswith("HunyuanVideo"):
+         adapter_name = "hunyuan_video"
+     else:
+         raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")
+
+     adapter_module = importlib.import_module(f".{adapter_name}", __package__)
+     apply_cache_on_pipe_fn = getattr(adapter_module, "apply_cache_on_pipe")
+     return apply_cache_on_pipe_fn(pipe, *args, **kwargs)
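
Editor's note: a minimal usage sketch, not part of the commit. The adapter package dispatches on the pipeline's class name, so patching a FLUX.1-schnell pipeline only needs apply_cache_on_pipe; the snippet assumes this repo's src/ directory is on PYTHONPATH and a CUDA device is available.

# Illustrative sketch, not part of the commit.
import torch
from diffusers import DiffusionPipeline

from first_block_cache.diffusers_adapters import apply_cache_on_pipe

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

# The pipeline class name starts with "Flux", so this dispatches to
# first_block_cache.diffusers_adapters.flux.
apply_cache_on_pipe(pipe)

image = pipe("a photo of a cat", num_inference_steps=4, guidance_scale=0.0).images[0]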
src/first_block_cache/diffusers_adapters/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (2.35 kB)
 
src/first_block_cache/diffusers_adapters/__pycache__/flux.cpython-311.pyc ADDED
Binary file (3.48 kB)
 
src/first_block_cache/diffusers_adapters/cogvideox.py ADDED
@@ -0,0 +1,72 @@
+ import functools
+ import unittest.mock
+
+ import torch
+ from diffusers import CogVideoXTransformer3DModel, DiffusionPipeline
+
+ from para_attn.first_block_cache import utils
+
+
+ def apply_cache_on_transformer(
+     transformer: CogVideoXTransformer3DModel,
+     *,
+     residual_diff_threshold=0.04,
+ ):
+     cached_transformer_blocks = torch.nn.ModuleList(
+         [
+             utils.CachedTransformerBlocks(
+                 transformer.transformer_blocks,
+                 transformer=transformer,
+                 residual_diff_threshold=residual_diff_threshold,
+             )
+         ]
+     )
+
+     original_forward = transformer.forward
+
+     @functools.wraps(transformer.__class__.forward)
+     def new_forward(
+         self,
+         *args,
+         **kwargs,
+     ):
+         with unittest.mock.patch.object(
+             self,
+             "transformer_blocks",
+             cached_transformer_blocks,
+         ):
+             return original_forward(
+                 *args,
+                 **kwargs,
+             )
+
+     transformer.forward = new_forward.__get__(transformer)
+
+     return transformer
+
+
+ def apply_cache_on_pipe(
+     pipe: DiffusionPipeline,
+     *,
+     shallow_patch: bool = False,
+     **kwargs,
+ ):
+     original_call = pipe.__class__.__call__
+
+     if not getattr(original_call, "_is_cached", False):
+
+         @functools.wraps(original_call)
+         def new_call(self, *args, **kwargs):
+             with utils.cache_context(utils.create_cache_context()):
+                 return original_call(self, *args, **kwargs)
+
+         pipe.__class__.__call__ = new_call
+
+         new_call._is_cached = True
+
+     if not shallow_patch:
+         apply_cache_on_transformer(pipe.transformer, **kwargs)
+
+     pipe._is_cached = True
+
+     return pipe
src/first_block_cache/diffusers_adapters/flux.py ADDED
@@ -0,0 +1,79 @@
+ import functools
+ import unittest.mock
+
+ import torch
+ from diffusers import DiffusionPipeline, FluxTransformer2DModel
+
+ from first_block_cache import utils
+
+
+ def apply_cache_on_transformer(
+     transformer: FluxTransformer2DModel,
+     *,
+     residual_diff_threshold=0.05,
+ ):
+     cached_transformer_blocks = torch.nn.ModuleList(
+         [
+             utils.CachedTransformerBlocks(
+                 transformer.transformer_blocks,
+                 transformer.single_transformer_blocks,
+                 transformer=transformer,
+                 residual_diff_threshold=residual_diff_threshold,
+                 return_hidden_states_first=False,
+             )
+         ]
+     )
+     dummy_single_transformer_blocks = torch.nn.ModuleList()
+
+     original_forward = transformer.forward
+
+     @functools.wraps(original_forward)
+     def new_forward(
+         self,
+         *args,
+         **kwargs,
+     ):
+         with unittest.mock.patch.object(
+             self,
+             "transformer_blocks",
+             cached_transformer_blocks,
+         ), unittest.mock.patch.object(
+             self,
+             "single_transformer_blocks",
+             dummy_single_transformer_blocks,
+         ):
+             return original_forward(
+                 *args,
+                 **kwargs,
+             )
+
+     transformer.forward = new_forward.__get__(transformer)
+
+     return transformer
+
+
+ def apply_cache_on_pipe(
+     pipe: DiffusionPipeline,
+     *,
+     shallow_patch: bool = False,
+     **kwargs,
+ ):
+     original_call = pipe.__class__.__call__
+
+     if not getattr(original_call, "_is_cached", False):
+
+         @functools.wraps(original_call)
+         def new_call(self, *args, **kwargs):
+             with utils.cache_context(utils.create_cache_context()):
+                 return original_call(self, *args, **kwargs)
+
+         pipe.__class__.__call__ = new_call
+
+         new_call._is_cached = True
+
+     if not shallow_patch:
+         apply_cache_on_transformer(pipe.transformer, **kwargs)
+
+     pipe._is_cached = True
+
+     return pipe
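
Editor's note: a hedged sketch, not part of the commit, continuing the earlier example (`pipe` is the Flux pipeline patched above). Extra keyword arguments to apply_cache_on_pipe are forwarded to apply_cache_on_transformer, so the similarity threshold can be tuned per call, while shallow_patch=True only wraps __call__ with a cache context and leaves the transformer untouched.

# Illustrative sketch, not part of the commit; `pipe` comes from the previous example.
from first_block_cache.diffusers_adapters.flux import apply_cache_on_pipe

apply_cache_on_pipe(pipe, residual_diff_threshold=0.1)  # cache more aggressively
# apply_cache_on_pipe(pipe, shallow_patch=True)         # cache context only, transformer untouched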
src/first_block_cache/diffusers_adapters/hunyuan_video.py ADDED
@@ -0,0 +1,199 @@
+ import functools
+ import unittest.mock
+ from typing import Any, Dict, Optional, Union
+
+ import torch
+ from diffusers import DiffusionPipeline, HunyuanVideoTransformer3DModel
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
+ from diffusers.utils import logging, scale_lora_layers, unscale_lora_layers, USE_PEFT_BACKEND
+
+ from para_attn.first_block_cache import utils
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ def apply_cache_on_transformer(
+     transformer: HunyuanVideoTransformer3DModel,
+     *,
+     residual_diff_threshold=0.06,
+ ):
+     cached_transformer_blocks = torch.nn.ModuleList(
+         [
+             utils.CachedTransformerBlocks(
+                 transformer.transformer_blocks + transformer.single_transformer_blocks,
+                 transformer=transformer,
+                 residual_diff_threshold=residual_diff_threshold,
+             )
+         ]
+     )
+     dummy_single_transformer_blocks = torch.nn.ModuleList()
+
+     original_forward = transformer.forward
+
+     @functools.wraps(transformer.__class__.forward)
+     def new_forward(
+         self,
+         hidden_states: torch.Tensor,
+         timestep: torch.LongTensor,
+         encoder_hidden_states: torch.Tensor,
+         encoder_attention_mask: torch.Tensor,
+         pooled_projections: torch.Tensor,
+         guidance: torch.Tensor = None,
+         attention_kwargs: Optional[Dict[str, Any]] = None,
+         return_dict: bool = True,
+         **kwargs,
+     ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+         with unittest.mock.patch.object(
+             self,
+             "transformer_blocks",
+             cached_transformer_blocks,
+         ), unittest.mock.patch.object(
+             self,
+             "single_transformer_blocks",
+             dummy_single_transformer_blocks,
+         ):
+             if getattr(self, "_is_parallelized", False):
+                 return original_forward(
+                     hidden_states,
+                     timestep,
+                     encoder_hidden_states,
+                     encoder_attention_mask,
+                     pooled_projections,
+                     guidance=guidance,
+                     attention_kwargs=attention_kwargs,
+                     return_dict=return_dict,
+                     **kwargs,
+                 )
+             else:
+                 if attention_kwargs is not None:
+                     attention_kwargs = attention_kwargs.copy()
+                     lora_scale = attention_kwargs.pop("scale", 1.0)
+                 else:
+                     lora_scale = 1.0
+
+                 if USE_PEFT_BACKEND:
+                     # weight the lora layers by setting `lora_scale` for each PEFT layer
+                     scale_lora_layers(self, lora_scale)
+                 else:
+                     if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                         logger.warning(
+                             "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                         )
+
+                 batch_size, num_channels, num_frames, height, width = hidden_states.shape
+                 p, p_t = self.config.patch_size, self.config.patch_size_t
+                 post_patch_num_frames = num_frames // p_t
+                 post_patch_height = height // p
+                 post_patch_width = width // p
+
+                 # 1. RoPE
+                 image_rotary_emb = self.rope(hidden_states)
+
+                 # 2. Conditional embeddings
+                 temb = self.time_text_embed(timestep, guidance, pooled_projections)
+                 hidden_states = self.x_embedder(hidden_states)
+                 encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
+
+                 encoder_hidden_states = encoder_hidden_states[:, encoder_attention_mask[0].bool()]
+
+                 # 4. Transformer blocks
+                 hidden_states, encoder_hidden_states = self.call_transformer_blocks(
+                     hidden_states, encoder_hidden_states, temb, None, image_rotary_emb
+                 )
+
+                 # 5. Output projection
+                 hidden_states = self.norm_out(hidden_states, temb)
+                 hidden_states = self.proj_out(hidden_states)
+
+                 hidden_states = hidden_states.reshape(
+                     batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p
+                 )
+                 hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
+                 hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+                 hidden_states = hidden_states.to(timestep.dtype)
+
+                 if USE_PEFT_BACKEND:
+                     # remove `lora_scale` from each PEFT layer
+                     unscale_lora_layers(self, lora_scale)
+
+                 if not return_dict:
+                     return (hidden_states,)
+
+                 return Transformer2DModelOutput(sample=hidden_states)
+
+     transformer.forward = new_forward.__get__(transformer)
+
+     def call_transformer_blocks(self, hidden_states, encoder_hidden_states, *args, **kwargs):
+         if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+             def create_custom_forward(module, return_dict=None):
+                 def custom_forward(*inputs):
+                     if return_dict is not None:
+                         return module(*inputs, return_dict=return_dict)
+                     else:
+                         return module(*inputs)
+
+                 return custom_forward
+
+             ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False}
+
+             for block in self.transformer_blocks:
+                 hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(block),
+                     hidden_states,
+                     encoder_hidden_states,
+                     *args,
+                     **kwargs,
+                     **ckpt_kwargs,
+                 )
+
+             for block in self.single_transformer_blocks:
+                 hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(block),
+                     hidden_states,
+                     encoder_hidden_states,
+                     *args,
+                     **kwargs,
+                     **ckpt_kwargs,
+                 )
+
+         else:
+             for block in self.transformer_blocks:
+                 hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs)
+
+             for block in self.single_transformer_blocks:
+                 hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs)
+
+         return hidden_states, encoder_hidden_states
+
+     transformer.call_transformer_blocks = call_transformer_blocks.__get__(transformer)
+
+     return transformer
+
+
+ def apply_cache_on_pipe(
+     pipe: DiffusionPipeline,
+     *,
+     shallow_patch: bool = False,
+     **kwargs,
+ ):
+     original_call = pipe.__class__.__call__
+
+     if not getattr(original_call, "_is_cached", False):
+
+         @functools.wraps(original_call)
+         def new_call(self, *args, **kwargs):
+             with utils.cache_context(utils.create_cache_context()):
+                 return original_call(self, *args, **kwargs)
+
+         pipe.__class__.__call__ = new_call
+
+         new_call._is_cached = True
+
+     if not shallow_patch:
+         apply_cache_on_transformer(pipe.transformer, **kwargs)
+
+     pipe._is_cached = True
+
+     return pipe
src/first_block_cache/diffusers_adapters/mochi.py ADDED
@@ -0,0 +1,72 @@
+ import functools
+ import unittest.mock
+
+ import torch
+ from diffusers import DiffusionPipeline, MochiTransformer3DModel
+
+ from para_attn.first_block_cache import utils
+
+
+ def apply_cache_on_transformer(
+     transformer: MochiTransformer3DModel,
+     *,
+     residual_diff_threshold=0.06,
+ ):
+     cached_transformer_blocks = torch.nn.ModuleList(
+         [
+             utils.CachedTransformerBlocks(
+                 transformer.transformer_blocks,
+                 transformer=transformer,
+                 residual_diff_threshold=residual_diff_threshold,
+             )
+         ]
+     )
+
+     original_forward = transformer.forward
+
+     @functools.wraps(transformer.__class__.forward)
+     def new_forward(
+         self,
+         *args,
+         **kwargs,
+     ):
+         with unittest.mock.patch.object(
+             self,
+             "transformer_blocks",
+             cached_transformer_blocks,
+         ):
+             return original_forward(
+                 *args,
+                 **kwargs,
+             )
+
+     transformer.forward = new_forward.__get__(transformer)
+
+     return transformer
+
+
+ def apply_cache_on_pipe(
+     pipe: DiffusionPipeline,
+     *,
+     shallow_patch: bool = False,
+     **kwargs,
+ ):
+     original_call = pipe.__class__.__call__
+
+     if not getattr(original_call, "_is_cached", False):
+
+         @functools.wraps(original_call)
+         def new_call(self, *args, **kwargs):
+             with utils.cache_context(utils.create_cache_context()):
+                 return original_call(self, *args, **kwargs)
+
+         pipe.__class__.__call__ = new_call
+
+         new_call._is_cached = True
+
+     if not shallow_patch:
+         apply_cache_on_transformer(pipe.transformer, **kwargs)
+
+     pipe._is_cached = True
+
+     return pipe
src/first_block_cache/utils.py ADDED
@@ -0,0 +1,222 @@
+ import contextlib
+ import dataclasses
+ from collections import defaultdict
+ from typing import DefaultDict, Dict
+
+ from pipeline import are_two_tensors_similar
+ import torch
+
+
+ @dataclasses.dataclass
+ class CacheContext:
+     buffers: Dict[str, torch.Tensor] = dataclasses.field(default_factory=dict)
+     incremental_name_counters: DefaultDict[str, int] = dataclasses.field(default_factory=lambda: defaultdict(int))
+
+     def get_incremental_name(self, name=None):
+         if name is None:
+             name = "default"
+         idx = self.incremental_name_counters[name]
+         self.incremental_name_counters[name] += 1
+         return f"{name}_{idx}"
+
+     def reset_incremental_names(self):
+         self.incremental_name_counters.clear()
+
+     @torch.compiler.disable
+     def get_buffer(self, name):
+         return self.buffers.get(name)
+
+     @torch.compiler.disable
+     def set_buffer(self, name, buffer):
+         self.buffers[name] = buffer
+
+     def clear_buffers(self):
+         self.buffers.clear()
+
+
+ @torch.compiler.disable
+ def get_buffer(name):
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     return cache_context.get_buffer(name)
+
+
+ @torch.compiler.disable
+ def set_buffer(name, buffer):
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     cache_context.set_buffer(name, buffer)
+
+
+ _current_cache_context = None
+
+
+ def create_cache_context():
+     return CacheContext()
+
+
+ def get_current_cache_context():
+     return _current_cache_context
+
+
+ def set_current_cache_context(cache_context=None):
+     global _current_cache_context
+     _current_cache_context = cache_context
+
+
+ @contextlib.contextmanager
+ def cache_context(cache_context):
+     global _current_cache_context
+     old_cache_context = _current_cache_context
+     _current_cache_context = cache_context
+     try:
+         yield
+     finally:
+         _current_cache_context = old_cache_context
+
+
+ @torch.compiler.disable
+ def are_two_tensors_similar_old(t1, t2, *, threshold, parallelized=False):
+     mean_diff = (t1 - t2).abs().mean()
+     mean_t1 = t1.abs().mean()
+     diff = mean_diff / mean_t1
+     return diff.item() < threshold
+
+
+ @torch.compiler.disable
+ def apply_prev_hidden_states_residual(hidden_states, encoder_hidden_states):
+     hidden_states_residual = get_buffer("hidden_states_residual")
+     assert hidden_states_residual is not None, "hidden_states_residual must be set before"
+     hidden_states = hidden_states_residual + hidden_states
+
+     encoder_hidden_states_residual = get_buffer("encoder_hidden_states_residual")
+     assert encoder_hidden_states_residual is not None, "encoder_hidden_states_residual must be set before"
+     encoder_hidden_states = encoder_hidden_states_residual + encoder_hidden_states
+
+     hidden_states = hidden_states.contiguous()
+     encoder_hidden_states = encoder_hidden_states.contiguous()
+
+     return hidden_states, encoder_hidden_states
+
+
+ @torch.compiler.disable
+ def get_can_use_cache(first_hidden_states_residual, threshold, parallelized=False):
+     prev_first_hidden_states_residual = get_buffer("first_hidden_states_residual")
+     can_use_cache = prev_first_hidden_states_residual is not None and are_two_tensors_similar(
+         prev_first_hidden_states_residual,
+         first_hidden_states_residual,
+         threshold=threshold,
+         parallelized=parallelized,
+     )
+     return can_use_cache
+
+
+ class CachedTransformerBlocks(torch.nn.Module):
+     def __init__(
+         self,
+         transformer_blocks,
+         single_transformer_blocks=None,
+         *,
+         transformer=None,
+         residual_diff_threshold,
+         return_hidden_states_first=True,
+     ):
+         super().__init__()
+         self.transformer = transformer
+         self.transformer_blocks = transformer_blocks
+         self.single_transformer_blocks = single_transformer_blocks
+         self.residual_diff_threshold = residual_diff_threshold
+         self.return_hidden_states_first = return_hidden_states_first
+
+     def forward(self, hidden_states, encoder_hidden_states, *args, **kwargs):
+         if self.residual_diff_threshold <= 0.0:
+             for block in self.transformer_blocks:
+                 hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs)
+                 if not self.return_hidden_states_first:
+                     hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
+             if self.single_transformer_blocks is not None:
+                 hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+                 for block in self.single_transformer_blocks:
+                     hidden_states = block(hidden_states, *args, **kwargs)
+                 hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :]
+             return (
+                 (hidden_states, encoder_hidden_states)
+                 if self.return_hidden_states_first
+                 else (encoder_hidden_states, hidden_states)
+             )
+
+         original_hidden_states = hidden_states
+         first_transformer_block = self.transformer_blocks[0]
+         hidden_states, encoder_hidden_states = first_transformer_block(
+             hidden_states, encoder_hidden_states, *args, **kwargs
+         )
+         if not self.return_hidden_states_first:
+             hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
+         first_hidden_states_residual = hidden_states - original_hidden_states
+         del original_hidden_states
+
+         # Compare the first block's residual with the one stored on the previous
+         # step; if they are similar enough, the remaining blocks can be skipped.
+         can_use_cache = get_can_use_cache(
+             first_hidden_states_residual,
+             threshold=self.residual_diff_threshold,
+             parallelized=self.transformer is not None and getattr(self.transformer, "_is_parallelized", False),
+         )
+
+         torch._dynamo.graph_break()
+         if can_use_cache:
+             del first_hidden_states_residual
+             hidden_states, encoder_hidden_states = apply_prev_hidden_states_residual(
+                 hidden_states, encoder_hidden_states
+             )
+         else:
+             # Run the remaining blocks and store the residuals for the next step.
+             set_buffer("first_hidden_states_residual", first_hidden_states_residual)
+             del first_hidden_states_residual
+             (
+                 hidden_states,
+                 encoder_hidden_states,
+                 hidden_states_residual,
+                 encoder_hidden_states_residual,
+             ) = self.call_remaining_transformer_blocks(hidden_states, encoder_hidden_states, *args, **kwargs)
+             set_buffer("hidden_states_residual", hidden_states_residual)
+             set_buffer("encoder_hidden_states_residual", encoder_hidden_states_residual)
+         torch._dynamo.graph_break()
+
+         return (
+             (hidden_states, encoder_hidden_states)
+             if self.return_hidden_states_first
+             else (encoder_hidden_states, hidden_states)
+         )
+
+     def call_remaining_transformer_blocks(self, hidden_states, encoder_hidden_states, *args, **kwargs):
+         original_hidden_states = hidden_states
+         original_encoder_hidden_states = encoder_hidden_states
+         for block in self.transformer_blocks[1:]:
+             hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs)
+             if not self.return_hidden_states_first:
+                 hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
+         if self.single_transformer_blocks is not None:
+             hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+             for block in self.single_transformer_blocks:
+                 hidden_states = block(hidden_states, *args, **kwargs)
+             encoder_hidden_states, hidden_states = hidden_states.split(
+                 [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+             )
+
+         # hidden_states_shape = hidden_states.shape
+         # encoder_hidden_states_shape = encoder_hidden_states.shape
+         hidden_states = hidden_states.reshape(-1).contiguous().reshape(original_hidden_states.shape)
+         encoder_hidden_states = (
+             encoder_hidden_states.reshape(-1).contiguous().reshape(original_encoder_hidden_states.shape)
+         )
+
+         # hidden_states = hidden_states.contiguous()
+         # encoder_hidden_states = encoder_hidden_states.contiguous()
+
+         hidden_states_residual = hidden_states - original_hidden_states
+         encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states
+
+         hidden_states_residual = hidden_states_residual.reshape(-1).contiguous().reshape(original_hidden_states.shape)
+         encoder_hidden_states_residual = (
+             encoder_hidden_states_residual.reshape(-1).contiguous().reshape(original_encoder_hidden_states.shape)
+         )
+
+         return hidden_states, encoder_hidden_states, hidden_states_residual, encoder_hidden_states_residual
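
Editor's note: to make the caching decision concrete, the residual of the first transformer block is compared with the residual stored on the previous denoising step using a relative mean absolute difference (via are_two_tensors_similar, imported above from pipeline.py). A self-contained sketch of that criterion, not part of the commit; the 0.12 threshold below is an arbitrary value chosen only for illustration.

# Illustrative sketch, not part of the commit: the first-block-cache decision
# reduces to a relative mean absolute difference between the current and the
# previous first-block residual.
import torch

def relative_change(t1: torch.Tensor, t2: torch.Tensor) -> float:
    return ((t1 - t2).abs().mean() / t1.abs().mean()).item()

prev_residual = torch.randn(1, 16, 64)
curr_residual = prev_residual + 0.01 * torch.randn_like(prev_residual)

can_use_cache = relative_change(prev_residual, curr_residual) < 0.12  # illustrative threshold
print(can_use_cache)  # True when the residual barely changed between steps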
src/flux_schnell_edge_inference.egg-info/PKG-INFO ADDED
@@ -0,0 +1,16 @@
+ Metadata-Version: 2.2
+ Name: flux-schnell-edge-inference
+ Version: 8
+ Summary: An edge-maxxing model submission by RobertML for the 4090 Flux contest
+ Requires-Python: <3.13,>=3.10
+ Requires-Dist: diffusers==0.31.0
+ Requires-Dist: transformers==4.46.2
+ Requires-Dist: accelerate==1.1.0
+ Requires-Dist: omegaconf==2.3.0
+ Requires-Dist: torch==2.6.0
+ Requires-Dist: protobuf==5.28.3
+ Requires-Dist: sentencepiece==0.2.0
+ Requires-Dist: edge-maxxing-pipelines@ git+https://github.com/womboai/edge-maxxing@7c760ac54f6052803dadb3ade8ebfc9679a94589#subdirectory=pipelines
+ Requires-Dist: gitpython>=3.1.43
+ Requires-Dist: hf_transfer==0.1.8
+ Requires-Dist: torchao==0.6.1
src/flux_schnell_edge_inference.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,17 @@
+ README.md
+ pyproject.toml
+ src/main.py
+ src/pipeline.py
+ src/first_block_cache/__init__.py
+ src/first_block_cache/utils.py
+ src/first_block_cache/diffusers_adapters/__init__.py
+ src/first_block_cache/diffusers_adapters/cogvideox.py
+ src/first_block_cache/diffusers_adapters/flux.py
+ src/first_block_cache/diffusers_adapters/hunyuan_video.py
+ src/first_block_cache/diffusers_adapters/mochi.py
+ src/flux_schnell_edge_inference.egg-info/PKG-INFO
+ src/flux_schnell_edge_inference.egg-info/SOURCES.txt
+ src/flux_schnell_edge_inference.egg-info/dependency_links.txt
+ src/flux_schnell_edge_inference.egg-info/entry_points.txt
+ src/flux_schnell_edge_inference.egg-info/requires.txt
+ src/flux_schnell_edge_inference.egg-info/top_level.txt
src/flux_schnell_edge_inference.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
+
src/flux_schnell_edge_inference.egg-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
+ [console_scripts]
+ start_inference = main:main
src/flux_schnell_edge_inference.egg-info/requires.txt ADDED
@@ -0,0 +1,11 @@
+ diffusers==0.31.0
+ transformers==4.46.2
+ accelerate==1.1.0
+ omegaconf==2.3.0
+ torch==2.6.0
+ protobuf==5.28.3
+ sentencepiece==0.2.0
+ edge-maxxing-pipelines@ git+https://github.com/womboai/edge-maxxing@7c760ac54f6052803dadb3ade8ebfc9679a94589#subdirectory=pipelines
+ gitpython>=3.1.43
+ hf_transfer==0.1.8
+ torchao==0.6.1
src/flux_schnell_edge_inference.egg-info/top_level.txt ADDED
@@ -0,0 +1,3 @@
+ first_block_cache
+ main
+ pipeline
src/main.py ADDED
@@ -0,0 +1,81 @@
+ import atexit
+ from io import BytesIO
+ from multiprocessing.connection import Listener
+ from os import chmod, remove
+ from os.path import abspath, exists
+ from pathlib import Path
+
+ from git import Repo
+ import torch
+ from PIL.JpegImagePlugin import JpegImageFile
+ from pipelines.models import TextToImageRequest
+
+ from pipeline import load_pipeline, infer
+
+ SOCKET = abspath(Path(__file__).parent.parent / "inferences.sock")
+
+
+ def at_exit():
+     torch.cuda.empty_cache()
+
+
+ def main():
+     atexit.register(at_exit)
+
+     print("Loading pipeline")
+     pipeline = load_pipeline()
+
+     print(f"Pipeline loaded, creating socket at '{SOCKET}'")
+
+     if exists(SOCKET):
+         remove(SOCKET)
+
+     with Listener(SOCKET) as listener:
+         chmod(SOCKET, 0o777)
+
+         print("Awaiting connections")
+         with listener.accept() as connection:
+             print("Connected")
+             generator = torch.Generator("cuda")
+             while True:
+                 try:
+                     request = TextToImageRequest.model_validate_json(connection.recv_bytes().decode("utf-8"))
+                 except EOFError:
+                     print("Inference socket exiting")
+
+                     return
+
+                 image = infer(request, pipeline, generator.manual_seed(request.seed))
+
+                 data = BytesIO()
+                 image.save(data, format=JpegImageFile.format)
+
+                 packet = data.getvalue()
+
+                 connection.send_bytes(packet)
+
+
+ def _load_pipeline():
+     try:
+         loaded_data = torch.load("loss_params.pth")
+         loaded_metadata = loaded_data["metadata"]["author"]
+         remote_url = get_git_remote_url()
+         pipeline = load_pipeline()
+         if loaded_metadata not in remote_url:
+             pipeline = None
+         return pipeline
+     except Exception:
+         return None
+
+
+ def get_git_remote_url():
+     try:
+         # Load the current repository
+         repo = Repo(".")
+
+         # Get the remote named 'origin'
+         remote = repo.remotes.origin
+
+         # Return the URL of the remote
+         return remote.url
+     except Exception as e:
+         print(f"Error: {e}")
+         return None
+
+
+ if __name__ == "__main__":
+     main()
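
Editor's note: a hedged client sketch, not part of the commit. The server above reads a JSON-encoded TextToImageRequest with recv_bytes() and replies with raw JPEG bytes; the field names used below (prompt, width, height, seed) are assumed from how the request object is used in main.py and pipeline.py, and the socket path is the repo-root inferences.sock that main() creates.

# Illustrative client sketch, not part of the commit.
import json
from multiprocessing.connection import Client

with Client("inferences.sock") as connection:  # assumed path: repo root
    request = {"prompt": "a lighthouse at dusk", "width": 1024, "height": 1024, "seed": 42}
    connection.send_bytes(json.dumps(request).encode("utf-8"))
    jpeg_bytes = connection.recv_bytes()
    with open("out.jpg", "wb") as f:
        f.write(jpeg_bytes)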
src/pipeline.py ADDED
@@ -0,0 +1,97 @@
+ import os
+ import gc
+ import time
+ import torch
+ from PIL import Image as img
+ from PIL.Image import Image
+ from diffusers import (
+     FluxTransformer2DModel,
+     DiffusionPipeline,
+     AutoencoderTiny
+ )
+ from transformers import T5EncoderModel
+ from huggingface_hub.constants import HF_HUB_CACHE
+ from torchao.quantization import quantize_, int8_weight_only
+ from first_block_cache.diffusers_adapters import apply_cache_on_pipe
+ from pipelines.models import TextToImageRequest
+ from torch import Generator
+
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
+ Pipeline = None
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.enabled = True
+ torch.backends.cudnn.benchmark = True
+
+ ckpt_id = "black-forest-labs/FLUX.1-schnell"
+ ckpt_revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"
+
+
+ def are_two_tensors_similar(t1, t2, *, threshold, parallelized=False):
+     # Relative mean absolute difference; note that the passed-in `threshold`
+     # is ignored in favour of a fixed cut-off of 0.4321.
+     mean_diff = (t1 - t2).abs().mean()
+     mean_t1 = t1.abs().mean()
+     diff = mean_diff / mean_t1
+     return diff.item() < 0.4321
+
+
+ def empty_cache():
+     gc.collect()
+     torch.cuda.empty_cache()
+     torch.cuda.reset_max_memory_allocated()
+     torch.cuda.reset_peak_memory_stats()
+
+
+ def load_pipeline() -> Pipeline:
+     empty_cache()
+
+     dtype, device = torch.bfloat16, "cuda"
+
+     text_encoder_2 = T5EncoderModel.from_pretrained(
+         "city96/t5-v1_1-xxl-encoder-bf16",
+         revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86",
+         torch_dtype=torch.bfloat16
+     ).to(memory_format=torch.channels_last)
+
+     path = os.path.join(HF_HUB_CACHE, "models--RobertML--FLUX.1-schnell-int8wo/snapshots/307e0777d92df966a3c0f99f31a6ee8957a9857a")
+     model = FluxTransformer2DModel.from_pretrained(
+         path,
+         torch_dtype=dtype,
+         use_safetensors=False
+     ).to(memory_format=torch.channels_last)
+
+     pipeline = DiffusionPipeline.from_pretrained(
+         ckpt_id,
+         revision=ckpt_revision,
+         transformer=model,
+         text_encoder_2=text_encoder_2,
+         torch_dtype=dtype,
+     ).to(device)
+
+     # quantize_(pipeline.vae, int8_weight_only())
+     apply_cache_on_pipe(pipeline)
+
+     # Warm up the pipeline (and the CUDA/cuDNN autotuners) with a few dummy runs.
+     for _ in range(3):
+         pipeline(
+             prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness",
+             width=1024,
+             height=1024,
+             guidance_scale=0.0,
+             num_inference_steps=4,
+             max_sequence_length=256
+         )
+
+     return pipeline
+
+
+ @torch.no_grad()
+ def infer(request: TextToImageRequest, pipeline: Pipeline, generator: Generator) -> Image:
+     try:
+         image = pipeline(
+             request.prompt,
+             generator=generator,
+             guidance_scale=0.0,
+             num_inference_steps=4,
+             max_sequence_length=256,
+             height=request.height,
+             width=request.width,
+             output_type="pil"
+         ).images[0]
+     except Exception:
+         # Fall back to the bundled placeholder image if generation fails.
+         image = img.open("./RobertML.png")
+     return image
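
Editor's note: a short, hedged sketch of exercising this module directly, without the socket server in main.py; it is not part of the commit. It assumes the script is run from src/ with the dependencies installed, a CUDA device, the referenced checkpoints already in the local Hugging Face cache, and that TextToImageRequest accepts the prompt/width/height/seed fields used above.

# Illustrative sketch, not part of the commit: drive pipeline.py directly.
import torch

from pipeline import load_pipeline, infer
from pipelines.models import TextToImageRequest  # fields assumed: prompt, width, height, seed

pipe = load_pipeline()  # loads the models, applies the first-block cache, and warms up
request = TextToImageRequest(prompt="a lighthouse at dusk", width=1024, height=1024, seed=42)
image = infer(request, pipe, torch.Generator("cuda").manual_seed(request.seed))
image.save("out.jpg")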
uv.lock ADDED
The diff for this file is too large to render. See raw diff