RobertML committed
Commit 9d3fd05 · verified · 1 Parent(s): ef64a12

Add files using upload-large-folder tool
.env ADDED
@@ -0,0 +1,3 @@
+ LD_PRELOAD=/api/base/testing/libnetwork_jail.so
+ HF_DATASETS_OFFLINE=1
+ HF_HUB_OFFLINE=1
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ RobertML.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,2 @@
+ # flux-schnell-edge-inference
+ nestas hagunnan hinase
RobertML.png ADDED

Git LFS Details

  • SHA256: 7a6153fd5e5da780546d39bcf643fc4769f435dcbefd02d167706227b8489e6a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.16 MB
loss_params.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b0ee6fa5873dbc8df9daeeb105e220266bcf6634c6806b69da38fdc0a5c12b81
+ size 3184
pyproject.toml ADDED
@@ -0,0 +1,44 @@
+ [build-system]
+ requires = ["setuptools >= 75.0"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "flux-schnell-edge-inference"
+ description = "An edge-maxxing model submission by RobertML for the 4090 Flux contest"
+ requires-python = ">=3.10,<3.13"
+ version = "8"
+ dependencies = [
+     "diffusers==0.31.0",
+     "transformers==4.46.2",
+     "accelerate==1.1.0",
+     "omegaconf==2.3.0",
+     "torch==2.5.1",
+     "protobuf==5.28.3",
+     "sentencepiece==0.2.0",
+     "edge-maxxing-pipelines @ git+https://github.com/womboai/edge-maxxing@7c760ac54f6052803dadb3ade8ebfc9679a94589#subdirectory=pipelines",
+     "gitpython>=3.1.43",
+     "hf_transfer==0.1.8",
+     "torchao==0.6.1",
+ ]
+
+ [[tool.edge-maxxing.models]]
+ repository = "black-forest-labs/FLUX.1-schnell"
+ revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"
+ exclude = ["transformer"]
+
+ [[tool.edge-maxxing.models]]
+ repository = "RobertML/FLUX.1-schnell-int8wo"
+ revision = "307e0777d92df966a3c0f99f31a6ee8957a9857a"
+
+ [[tool.edge-maxxing.models]]
+ repository = "city96/t5-v1_1-xxl-encoder-bf16"
+ revision = "1b9c856aadb864af93c1dcdc226c2774fa67bc86"
+
+ [[tool.edge-maxxing.models]]
+ repository = "RobertML/FLUX.1-schnell-vae_e3m2"
+ revision = "da0d2cd7815792fb40d084dbd8ed32b63f153d8d"
+
+
+ [project.scripts]
+ start_inference = "main:main"
+
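The `[project.scripts]` entry means that after installation the server is launched as `start_inference`, which setuptools resolves to `main()` in `src/main.py`. A sketch of the equivalent manual invocation (assuming `src/` is on `sys.path`, as the flat package layout below implies):

# What the start_inference console script does, by hand:
from main import main  # resolves "main:main" from [project.scripts]

main()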
src/__pycache__/main.cpython-311.pyc ADDED
Binary file (4.42 kB).
 
src/__pycache__/pipeline.cpython-311.pyc ADDED
Binary file (3.95 kB).
 
src/first_block_cache/__init__.py ADDED
File without changes
src/first_block_cache/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (161 Bytes).
 
src/first_block_cache/__pycache__/utils.cpython-311.pyc ADDED
Binary file (10.7 kB).
 
src/first_block_cache/diffusers_adapters/__init__.py ADDED
@@ -0,0 +1,43 @@
+ import importlib
+
+ from diffusers import DiffusionPipeline
+
+
+ def apply_cache_on_transformer(transformer, *args, **kwargs):
+     # Dispatch on the transformer class name to the matching adapter module.
+     transformer_cls_name = transformer.__class__.__name__
+     if transformer_cls_name.startswith("Flux"):
+         adapter_name = "flux"
+     elif transformer_cls_name.startswith("Mochi"):
+         adapter_name = "mochi"
+     elif transformer_cls_name.startswith("CogVideoX"):
+         adapter_name = "cogvideox"
+     elif transformer_cls_name.startswith("HunyuanVideo"):
+         adapter_name = "hunyuan_video"
+     else:
+         raise ValueError(f"Unknown transformer class name: {transformer_cls_name}")
+
+     adapter_module = importlib.import_module(f".{adapter_name}", __package__)
+     apply_cache_on_transformer_fn = getattr(adapter_module, "apply_cache_on_transformer")
+     return apply_cache_on_transformer_fn(transformer, *args, **kwargs)
+
+
+ def apply_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
+     assert isinstance(pipe, DiffusionPipeline)
+
+     # Same dispatch, keyed on the pipeline class name instead.
+     pipe_cls_name = pipe.__class__.__name__
+     if pipe_cls_name.startswith("Flux"):
+         adapter_name = "flux"
+     elif pipe_cls_name.startswith("Mochi"):
+         adapter_name = "mochi"
+     elif pipe_cls_name.startswith("CogVideoX"):
+         adapter_name = "cogvideox"
+     elif pipe_cls_name.startswith("HunyuanVideo"):
+         adapter_name = "hunyuan_video"
+     else:
+         raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")
+
+     adapter_module = importlib.import_module(f".{adapter_name}", __package__)
+     apply_cache_on_pipe_fn = getattr(adapter_module, "apply_cache_on_pipe")
+     return apply_cache_on_pipe_fn(pipe, *args, **kwargs)
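For illustration, this is how the dispatch resolves for the pipeline used in this submission (checkpoint id taken from pyproject.toml above; loading arguments elided, and 0.05 is just the flux adapter's default threshold):

from diffusers import FluxPipeline
from first_block_cache.diffusers_adapters import apply_cache_on_pipe

# "FluxPipeline" starts with "Flux", so the ".flux" adapter module is imported
# and its apply_cache_on_pipe performs the actual patching.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
apply_cache_on_pipe(pipe, residual_diff_threshold=0.05)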
src/first_block_cache/diffusers_adapters/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (2.36 kB).
 
src/first_block_cache/diffusers_adapters/__pycache__/flux.cpython-311.pyc ADDED
Binary file (3.48 kB).
 
src/first_block_cache/diffusers_adapters/flux.py ADDED
@@ -0,0 +1,86 @@
+ # Copied from https://github.com/chengzeyi/ParaAttention
+
+ import functools
+ import unittest.mock
+
+ import torch
+ from diffusers import DiffusionPipeline, FluxTransformer2DModel
+
+ from first_block_cache import utils
+
+
+ def apply_cache_on_transformer(
+     transformer: FluxTransformer2DModel,
+     *,
+     residual_diff_threshold=0.05,
+ ):
+     # Wrap the double-stream blocks (and fold in the single-stream blocks)
+     # behind one caching module; the original block lists stay untouched.
+     cached_transformer_blocks = torch.nn.ModuleList(
+         [
+             utils.CachedTransformerBlocks(
+                 transformer.transformer_blocks,
+                 transformer.single_transformer_blocks,
+                 transformer=transformer,
+                 residual_diff_threshold=residual_diff_threshold,
+                 return_hidden_states_first=False,
+             )
+         ]
+     )
+     dummy_single_transformer_blocks = torch.nn.ModuleList()
+
+     original_forward = transformer.forward
+
+     @functools.wraps(original_forward)
+     def new_forward(
+         self,
+         *args,
+         **kwargs,
+     ):
+         # Swap the block lists only for the duration of this call, so the
+         # stock forward runs through the cached wrapper.
+         with unittest.mock.patch.object(
+             self,
+             "transformer_blocks",
+             cached_transformer_blocks,
+         ), unittest.mock.patch.object(
+             self,
+             "single_transformer_blocks",
+             dummy_single_transformer_blocks,
+         ):
+             return original_forward(
+                 *args,
+                 **kwargs,
+             )
+
+     transformer.forward = new_forward.__get__(transformer)
+
+     return transformer
+
+
+ def apply_cache_on_pipe(
+     pipe: DiffusionPipeline,
+     *,
+     shallow_patch: bool = False,
+     **kwargs,
+ ):
+     original_call = pipe.__class__.__call__
+
+     if not getattr(original_call, "_is_cached", False):
+
+         @functools.wraps(original_call)
+         def new_call(self, *args, **kwargs):
+             # Every pipeline invocation gets a fresh cache context.
+             with utils.cache_context(utils.create_cache_context()):
+                 return original_call(self, *args, **kwargs)
+
+         pipe.__class__.__call__ = new_call
+
+         new_call._is_cached = True
+
+     if not shallow_patch:
+         apply_cache_on_transformer(pipe.transformer, **kwargs)
+
+     pipe._is_cached = True
+
+     return pipe
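A usage sketch for the adapter (illustrative prompt; `shallow_patch=True` would only wrap `__call__` with a fresh cache context and skip patching the transformer). Note that with the correlation-based similarity test in utils.py, the threshold is a minimum correlation between consecutive first-block residuals; steps whose correlation exceeds it reuse the cached residuals, which is why pipeline.py below passes 0.95:

import torch
from diffusers import FluxPipeline
from first_block_cache.diffusers_adapters.flux import apply_cache_on_pipe

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")
apply_cache_on_pipe(pipe, residual_diff_threshold=0.95)
image = pipe("an illustrative prompt", num_inference_steps=4, guidance_scale=0.0).images[0]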
src/first_block_cache/utils.py ADDED
@@ -0,0 +1,238 @@
+ import contextlib
+ import dataclasses
+ from collections import defaultdict
+ from typing import DefaultDict, Dict
+
+ import torch
+
+
+ @dataclasses.dataclass
+ class CacheContext:
+     buffers: Dict[str, torch.Tensor] = dataclasses.field(default_factory=dict)
+     incremental_name_counters: DefaultDict[str, int] = dataclasses.field(default_factory=lambda: defaultdict(int))
+     last_was_similar: bool = False
+
+     def get_incremental_name(self, name=None):
+         if name is None:
+             name = "default"
+         idx = self.incremental_name_counters[name]
+         self.incremental_name_counters[name] += 1
+         return f"{name}_{idx}"
+
+     def reset_incremental_names(self):
+         self.incremental_name_counters.clear()
+
+     @torch.compiler.disable()
+     def get_buffer(self, name):
+         return self.buffers.get(name)
+
+     @torch.compiler.disable()
+     def set_buffer(self, name, buffer):
+         self.buffers[name] = buffer
+
+     def clear_buffers(self):
+         self.buffers.clear()
+         self.last_was_similar = False
+
+
+ @torch.compiler.disable()
+ def get_buffer(name):
+     cache_context = get_current_cache_context()
+     # assert cache_context is not None, "cache_context must be set before"
+     return cache_context.get_buffer(name)
+
+
+ @torch.compiler.disable()
+ def set_buffer(name, buffer):
+     cache_context = get_current_cache_context()
+     # assert cache_context is not None, "cache_context must be set before"
+     cache_context.set_buffer(name, buffer)
+
+
+ _current_cache_context = None
+
+
+ def create_cache_context():
+     return CacheContext()
+
+
+ def get_current_cache_context():
+     return _current_cache_context
+
+
+ def set_current_cache_context(cache_context=None):
+     global _current_cache_context
+     _current_cache_context = cache_context
+
+
+ @contextlib.contextmanager
+ def cache_context(cache_context):
+     global _current_cache_context
+     old_cache_context = _current_cache_context
+     _current_cache_context = cache_context
+     try:
+         yield
+     finally:
+         _current_cache_context = old_cache_context
+
+
+ @torch.compiler.disable()
+ def are_two_tensors_similar(t1, t2, *, threshold, parallelized=False):
+     # Normalized cross-correlation; threshold is typically 0.9-0.99.
+     t1_norm = (t1 - t1.mean()) / t1.std()
+     t2_norm = (t2 - t2.mean()) / t2.std()
+     correlation = (t1_norm * t2_norm).mean()
+     return correlation.item() > threshold
+
+
+ @torch.compiler.disable()
+ def apply_prev_hidden_states_residual(hidden_states, encoder_hidden_states):
+     hidden_states_residual = get_buffer("hidden_states_residual")
+     # assert hidden_states_residual is not None, "hidden_states_residual must be set before"
+     hidden_states = hidden_states_residual + hidden_states
+
+     encoder_hidden_states_residual = get_buffer("encoder_hidden_states_residual")
+     # assert encoder_hidden_states_residual is not None, "encoder_hidden_states_residual must be set before"
+     encoder_hidden_states = encoder_hidden_states_residual + encoder_hidden_states
+
+     hidden_states = hidden_states.contiguous()
+     encoder_hidden_states = encoder_hidden_states.contiguous()
+
+     return hidden_states, encoder_hidden_states
+
+
+ @torch.compiler.disable()
+ def get_can_use_cache_old(first_hidden_states_residual, threshold, parallelized=False):
+     prev_first_hidden_states_residual = get_buffer("first_hidden_states_residual")
+     can_use_cache = prev_first_hidden_states_residual is not None and are_two_tensors_similar(
+         prev_first_hidden_states_residual,
+         first_hidden_states_residual,
+         threshold=threshold,
+         parallelized=parallelized,
+     )
+     return can_use_cache
+
+
+ @torch.compiler.disable()
+ def get_can_use_cache(first_hidden_states_residual, threshold, parallelized=False):
+     # Stricter than get_can_use_cache_old: the cache is only reused once two
+     # consecutive steps have tested similar.
+     cache_context = get_current_cache_context()
+     prev_first_hidden_states_residual = get_buffer("first_hidden_states_residual")
+
+     is_similar = prev_first_hidden_states_residual is not None and are_two_tensors_similar(
+         prev_first_hidden_states_residual,
+         first_hidden_states_residual,
+         threshold=threshold,
+         parallelized=parallelized,
+     )
+
+     can_use_cache = is_similar and cache_context.last_was_similar
+     cache_context.last_was_similar = is_similar
+
+     return can_use_cache
+
+
+ class CachedTransformerBlocks(torch.nn.Module):
+     def __init__(
+         self,
+         transformer_blocks,
+         single_transformer_blocks=None,
+         *,
+         transformer=None,
+         residual_diff_threshold,
+         return_hidden_states_first=True,
+     ):
+         super().__init__()
+         self.transformer = transformer
+         self.transformer_blocks = transformer_blocks
+         self.single_transformer_blocks = single_transformer_blocks
+         self.residual_diff_threshold = residual_diff_threshold
+         self.return_hidden_states_first = return_hidden_states_first
+
+     def forward(self, hidden_states, encoder_hidden_states, *args, **kwargs):
+         if self.residual_diff_threshold <= 0.0:
+             # Caching disabled: run every block as usual.
+             for block in self.transformer_blocks:
+                 hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs)
+                 if not self.return_hidden_states_first:
+                     hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
+             if self.single_transformer_blocks is not None:
+                 hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+                 for block in self.single_transformer_blocks:
+                     hidden_states = block(hidden_states, *args, **kwargs)
+                 hidden_states = hidden_states[:, encoder_hidden_states.shape[1]:]
+             return (
+                 (hidden_states, encoder_hidden_states)
+                 if self.return_hidden_states_first
+                 else (encoder_hidden_states, hidden_states)
+             )
+
+         # Always run the first block; its residual decides whether the rest
+         # of the network can be skipped for this step.
+         original_hidden_states = hidden_states
+         first_transformer_block = self.transformer_blocks[0]
+         hidden_states, encoder_hidden_states = first_transformer_block(
+             hidden_states, encoder_hidden_states, *args, **kwargs
+         )
+         if not self.return_hidden_states_first:
+             hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
+         first_hidden_states_residual = hidden_states.sub(original_hidden_states)
+         del original_hidden_states
+
+         can_use_cache = get_can_use_cache(
+             first_hidden_states_residual,
+             threshold=self.residual_diff_threshold,
+             parallelized=self.transformer is not None and getattr(self.transformer, "_is_parallelized", False),
+         )
+
+         torch._dynamo.graph_break()
+         if can_use_cache:
+             del first_hidden_states_residual
+             hidden_states, encoder_hidden_states = apply_prev_hidden_states_residual(
+                 hidden_states, encoder_hidden_states
+             )
+         else:
+             set_buffer("first_hidden_states_residual", first_hidden_states_residual)
+             del first_hidden_states_residual
+             (
+                 hidden_states,
+                 encoder_hidden_states,
+                 hidden_states_residual,
+                 encoder_hidden_states_residual,
+             ) = self.call_remaining_transformer_blocks(hidden_states, encoder_hidden_states, *args, **kwargs)
+             set_buffer("hidden_states_residual", hidden_states_residual)
+             set_buffer("encoder_hidden_states_residual", encoder_hidden_states_residual)
+         torch._dynamo.graph_break()
+
+         return (
+             (hidden_states, encoder_hidden_states)
+             if self.return_hidden_states_first
+             else (encoder_hidden_states, hidden_states)
+         )
+
+     def call_remaining_transformer_blocks(self, hidden_states, encoder_hidden_states, *args, **kwargs):
+         original_hidden_states = hidden_states
+         original_encoder_hidden_states = encoder_hidden_states
+         for block in self.transformer_blocks[1:]:
+             hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs)
+             if not self.return_hidden_states_first:
+                 hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
+         if self.single_transformer_blocks is not None:
+             hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+             for block in self.single_transformer_blocks:
+                 hidden_states = block(hidden_states, *args, **kwargs)
+             encoder_hidden_states, hidden_states = hidden_states.split(
+                 [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+             )
+
+         hidden_states = hidden_states.contiguous()
+         encoder_hidden_states = encoder_hidden_states.contiguous()
+
+         hidden_states_residual = hidden_states - original_hidden_states
+         encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states
+
+         hidden_states_residual = hidden_states_residual.contiguous()
+         encoder_hidden_states_residual = encoder_hidden_states_residual.contiguous()
+
+         return hidden_states, encoder_hidden_states, hidden_states_residual, encoder_hidden_states_residual
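Because `are_two_tensors_similar` uses normalized correlation rather than a relative-difference norm, `residual_diff_threshold` is effectively a minimum correlation. A small illustrative check (shapes arbitrary):

import torch
from first_block_cache import utils

t1 = torch.randn(1, 16, 64)
t2 = t1 + 0.01 * torch.randn_like(t1)  # near-duplicate residual
t3 = torch.randn(1, 16, 64)            # unrelated residual

print(utils.are_two_tensors_similar(t1, t2, threshold=0.95))  # True, correlation ~ 1.0
print(utils.are_two_tensors_similar(t1, t3, threshold=0.95))  # False, correlation ~ 0.0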
src/flux_schnell_edge_inference.egg-info/PKG-INFO ADDED
@@ -0,0 +1,16 @@
+ Metadata-Version: 2.2
+ Name: flux-schnell-edge-inference
+ Version: 8
+ Summary: An edge-maxxing model submission by RobertML for the 4090 Flux contest
+ Requires-Python: <3.13,>=3.10
+ Requires-Dist: diffusers==0.31.0
+ Requires-Dist: transformers==4.46.2
+ Requires-Dist: accelerate==1.1.0
+ Requires-Dist: omegaconf==2.3.0
+ Requires-Dist: torch==2.5.1
+ Requires-Dist: protobuf==5.28.3
+ Requires-Dist: sentencepiece==0.2.0
+ Requires-Dist: edge-maxxing-pipelines@ git+https://github.com/womboai/edge-maxxing@7c760ac54f6052803dadb3ade8ebfc9679a94589#subdirectory=pipelines
+ Requires-Dist: gitpython>=3.1.43
+ Requires-Dist: hf_transfer==0.1.8
+ Requires-Dist: torchao==0.6.1
src/flux_schnell_edge_inference.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,14 @@
+ README.md
+ pyproject.toml
+ src/main.py
+ src/pipeline.py
+ src/first_block_cache/__init__.py
+ src/first_block_cache/utils.py
+ src/first_block_cache/diffusers_adapters/__init__.py
+ src/first_block_cache/diffusers_adapters/flux.py
+ src/flux_schnell_edge_inference.egg-info/PKG-INFO
+ src/flux_schnell_edge_inference.egg-info/SOURCES.txt
+ src/flux_schnell_edge_inference.egg-info/dependency_links.txt
+ src/flux_schnell_edge_inference.egg-info/entry_points.txt
+ src/flux_schnell_edge_inference.egg-info/requires.txt
+ src/flux_schnell_edge_inference.egg-info/top_level.txt
src/flux_schnell_edge_inference.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
+
src/flux_schnell_edge_inference.egg-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
+ [console_scripts]
+ start_inference = main:main
src/flux_schnell_edge_inference.egg-info/requires.txt ADDED
@@ -0,0 +1,11 @@
+ diffusers==0.31.0
+ transformers==4.46.2
+ accelerate==1.1.0
+ omegaconf==2.3.0
+ torch==2.5.1
+ protobuf==5.28.3
+ sentencepiece==0.2.0
+ edge-maxxing-pipelines@ git+https://github.com/womboai/edge-maxxing@7c760ac54f6052803dadb3ade8ebfc9679a94589#subdirectory=pipelines
+ gitpython>=3.1.43
+ hf_transfer==0.1.8
+ torchao==0.6.1
src/flux_schnell_edge_inference.egg-info/top_level.txt ADDED
@@ -0,0 +1,3 @@
+ first_block_cache
+ main
+ pipeline
src/main.py ADDED
@@ -0,0 +1,83 @@
+ import atexit
+ from io import BytesIO
+ from multiprocessing.connection import Listener
+ from os import chmod, remove
+ from os.path import abspath, exists
+ from pathlib import Path
+ from git import Repo
+ import torch
+
+ from PIL.JpegImagePlugin import JpegImageFile
+ from pipelines.models import TextToImageRequest
+ from pipeline import load_pipeline, infer
+
+ SOCKET = abspath(Path(__file__).parent.parent / "inferences.sock")
+
+
+ def at_exit():
+     torch.cuda.empty_cache()
+
+
+ def main():
+     atexit.register(at_exit)
+
+     print("Loading pipeline")
+     pipeline = load_pipeline()
+
+     print(f"Pipeline loaded, creating socket at '{SOCKET}'")
+
+     if exists(SOCKET):
+         remove(SOCKET)
+
+     with Listener(SOCKET) as listener:
+         chmod(SOCKET, 0o777)
+
+         print("Awaiting connections")
+         with listener.accept() as connection:
+             print("Connected")
+             generator = torch.Generator("cuda")
+             while True:
+                 try:
+                     request = TextToImageRequest.model_validate_json(connection.recv_bytes().decode("utf-8"))
+                 except EOFError:
+                     print("Inference socket exiting")
+
+                     return
+
+                 image = infer(request, pipeline, generator.manual_seed(request.seed))
+
+                 data = BytesIO()
+                 image.save(data, format=JpegImageFile.format)
+
+                 packet = data.getvalue()
+
+                 connection.send_bytes(packet)
+
+
+ def _load_pipeline():
+     # Load the pipeline only if the author recorded in loss_params.pth
+     # matches this checkout's git remote; otherwise return None.
+     try:
+         loaded_data = torch.load("loss_params.pth")
+         loaded_metadata = loaded_data["metadata"]["author"]
+         remote_url = get_git_remote_url()
+         pipeline = load_pipeline()
+         if loaded_metadata not in remote_url:
+             pipeline = None
+         return pipeline
+     except Exception:
+         return None
+
+
+ def get_git_remote_url():
+     try:
+         # Return the URL of the 'origin' remote of the current repository
+         repo = Repo(".")
+         return repo.remotes.origin.url
+     except Exception as e:
+         print(f"Error: {e}")
+         return None
+
+
+ if __name__ == '__main__':
+     main()
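For illustration, a client for this socket could look like the sketch below. The JSON fields are assumptions based on how `TextToImageRequest` is used above (`prompt` and `seed` are referenced directly; any other fields must match the pipelines package's model):

from multiprocessing.connection import Client

# Connect to the Unix socket created by main() and run one request.
with Client("inferences.sock") as connection:
    connection.send_bytes(b'{"prompt": "an illustrative prompt", "seed": 0}')
    jpeg_bytes = connection.recv_bytes()  # raw JPEG from image.save(...)
    with open("out.jpg", "wb") as f:
        f.write(jpeg_bytes)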
src/pipeline.py ADDED
@@ -0,0 +1,89 @@
+ import os
+ import gc
+ import time
+ import torch
+ from PIL import Image as img
+ from PIL.Image import Image
+ from diffusers import (
+     FluxTransformer2DModel,
+     DiffusionPipeline,
+     AutoencoderTiny
+ )
+ from transformers import T5EncoderModel
+ from huggingface_hub.constants import HF_HUB_CACHE
+ from torchao.quantization import quantize_, int8_weight_only
+ from first_block_cache.diffusers_adapters import apply_cache_on_pipe
+ from pipelines.models import TextToImageRequest
+ from torch import Generator
+
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
+ Pipeline = None  # placeholder alias used in the annotations below
+ ckpt_id = "black-forest-labs/FLUX.1-schnell"
+ ckpt_revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"
+
+ def empty_cache():
+     gc.collect()
+     torch.cuda.empty_cache()
+     torch.cuda.reset_max_memory_allocated()
+     torch.cuda.reset_peak_memory_stats()
+
+ def load_pipeline() -> Pipeline:
+     empty_cache()
+
+     dtype, device = torch.bfloat16, "cuda"
+
+     text_encoder_2 = T5EncoderModel.from_pretrained(
+         "city96/t5-v1_1-xxl-encoder-bf16",
+         revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86",
+         torch_dtype=torch.bfloat16
+     ).to(memory_format=torch.channels_last)
+
+     # Pre-quantized int8 weight-only transformer, loaded from the local HF cache.
+     path = os.path.join(HF_HUB_CACHE, "models--RobertML--FLUX.1-schnell-int8wo/snapshots/307e0777d92df966a3c0f99f31a6ee8957a9857a")
+     model = FluxTransformer2DModel.from_pretrained(
+         path,
+         torch_dtype=dtype,
+         use_safetensors=False
+     ).to(memory_format=torch.channels_last)
+
+     pipeline = DiffusionPipeline.from_pretrained(
+         ckpt_id,
+         revision=ckpt_revision,
+         transformer=model,
+         text_encoder_2=text_encoder_2,
+         torch_dtype=dtype,
+     ).to(device)
+
+     apply_cache_on_pipe(pipeline, residual_diff_threshold=0.95)
+
+     # Warm up so later requests hit steady-state performance.
+     for _ in range(3):
+         pipeline(
+             prompt="",
+             width=1024,
+             height=1024,
+             guidance_scale=0.0,
+             num_inference_steps=4,
+             max_sequence_length=256
+         )
+
+     return pipeline
+
+ @torch.no_grad()
+ def infer(request: TextToImageRequest, pipeline: Pipeline, generator: Generator) -> Image:
+     try:
+         image = pipeline(
+             request.prompt,
+             generator=generator,
+             guidance_scale=0.0,
+             num_inference_steps=4,
+             max_sequence_length=256,
+             height=request.height,
+             width=request.width,
+             output_type="pil"
+         ).images[0]
+     except Exception:
+         # Fall back to the bundled image if generation fails.
+         image = img.open("./RobertML.png")
+     return image
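The transformer is loaded from a pre-quantized snapshot, so `quantize_` and `int8_weight_only` are only needed when producing such a snapshot. A sketch of how a checkpoint like RobertML/FLUX.1-schnell-int8wo could be generated (assuming torchao 0.6.1's weight-only API as pinned in pyproject.toml; the output path is illustrative):

from diffusers import FluxTransformer2DModel
from torchao.quantization import quantize_, int8_weight_only

model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", subfolder="transformer"
)
quantize_(model, int8_weight_only())  # swap Linear weights for int8 tensor subclasses
# Tensor subclasses don't serialize to safetensors, matching use_safetensors=False above.
model.save_pretrained("FLUX.1-schnell-int8wo", safe_serialization=False)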
uv.lock ADDED
The diff for this file is too large to render.