Ryukijano commited on
Commit
9565796
·
verified ·
1 Parent(s): 3567671

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +4 -21
  2. custom_pipeline.py +0 -192
  3. requirements.txt +9 -9
app.py CHANGED
@@ -7,7 +7,6 @@ import time
7
  from diffusers import DiffusionPipeline, AutoencoderTiny
8
  from diffusers.models.attention_processor import AttnProcessor2_0
9
  from custom_pipeline import FluxWithCFGPipeline
10
- from huggingface_hub import login
11
 
12
  torch.backends.cuda.matmul.allow_tf32 = True
13
 
@@ -19,33 +18,17 @@ DEFAULT_HEIGHT = 1024
19
  DEFAULT_INFERENCE_STEPS = 1
20
 
21
  # Device and model setup
22
- dtype = torch.bfloat16
23
  pipe = FluxWithCFGPipeline.from_pretrained(
24
- "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype, use_safetensors=True
25
  )
26
- pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype, use_safetensors=True, variant="fp16")
27
  pipe.to("cuda")
28
  pipe.load_lora_weights('hugovntr/flux-schnell-realism', weight_name='schnell-realism_v2.3.safetensors', adapter_name="better")
29
  pipe.set_adapters(["better"], adapter_weights=[1.0])
30
  pipe.fuse_lora(adapter_name=["better"], lora_scale=1.0)
31
  pipe.unload_lora_weights()
32
 
33
- # Enable xformers
34
- pipe.enable_xformers_memory_efficient_attention()
35
-
36
- # Compile the model (Optional, needs further testing for stability)
37
- # pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=True)
38
-
39
- # Capture CUDA Graph (Warm-up)
40
- static_inputs = {
41
- "prompt": "warmup",
42
- "width": DEFAULT_WIDTH,
43
- "height": DEFAULT_HEIGHT,
44
- "num_inference_steps": DEFAULT_INFERENCE_STEPS,
45
- "generator": torch.Generator().manual_seed(0),
46
- }
47
-
48
- pipe.capture_cuda_graph(**static_inputs)
49
  torch.cuda.empty_cache()
50
 
51
  # Inference function
@@ -180,4 +163,4 @@ with gr.Blocks() as demo:
180
  )
181
 
182
  # Launch the app
183
- demo.launch()
 
7
  from diffusers import DiffusionPipeline, AutoencoderTiny
8
  from diffusers.models.attention_processor import AttnProcessor2_0
9
  from custom_pipeline import FluxWithCFGPipeline
 
10
 
11
  torch.backends.cuda.matmul.allow_tf32 = True
12
 
 
18
  DEFAULT_INFERENCE_STEPS = 1
19
 
20
  # Device and model setup
21
+ dtype = torch.float16
22
  pipe = FluxWithCFGPipeline.from_pretrained(
23
+ "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
24
  )
25
+ pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
26
  pipe.to("cuda")
27
  pipe.load_lora_weights('hugovntr/flux-schnell-realism', weight_name='schnell-realism_v2.3.safetensors', adapter_name="better")
28
  pipe.set_adapters(["better"], adapter_weights=[1.0])
29
  pipe.fuse_lora(adapter_name=["better"], lora_scale=1.0)
30
  pipe.unload_lora_weights()
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  torch.cuda.empty_cache()
33
 
34
  # Inference function
 
163
  )
164
 
165
  # Launch the app
166
+ demo.launch()
custom_pipeline.py CHANGED
@@ -3,7 +3,6 @@ import numpy as np
3
  from diffusers import FluxPipeline, FlowMatchEulerDiscreteScheduler
4
  from typing import Any, Dict, List, Optional, Union
5
  from PIL import Image
6
- from collections import OrderedDict
7
 
8
  # Constants for shift calculation
9
  BASE_SEQ_LEN = 256
@@ -48,169 +47,6 @@ class FluxWithCFGPipeline(FluxPipeline):
48
  Extends the FluxPipeline to yield intermediate images during the denoising process
49
  with progressively increasing resolution for faster generation.
50
  """
51
- def __init__(
52
- self,
53
- vae,
54
- text_encoder,
55
- text_encoder_2,
56
- tokenizer,
57
- tokenizer_2,
58
- transformer,
59
- scheduler: FlowMatchEulerDiscreteScheduler,
60
- ):
61
- super().__init__(vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2, transformer, scheduler)
62
- self.cuda_graphs = {}
63
-
64
- def capture_cuda_graph(
65
- self,
66
- prompt: Union[str, List[str]] = None,
67
- prompt_2: Optional[Union[str, List[str]]] = None,
68
- height: Optional[int] = None,
69
- width: Optional[int] = None,
70
- num_inference_steps: int = 4,
71
- guidance_scale: float = 3.5,
72
- num_images_per_prompt: Optional[int] = 1,
73
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
74
- latents: Optional[torch.FloatTensor] = None,
75
- prompt_embeds: Optional[torch.FloatTensor] = None,
76
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
77
- output_type: Optional[str] = "pil",
78
- return_dict: bool = True,
79
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
80
- max_sequence_length: int = 300,
81
- **kwargs,
82
- ):
83
- """
84
- Captures a static CUDA Graph for the generation process given static inputs.
85
- """
86
- # Use a static size for all inputs
87
- static_height = height
88
- static_width = width
89
-
90
- # 1. Check inputs
91
- self.check_inputs(
92
- prompt,
93
- prompt_2,
94
- static_height,
95
- static_width,
96
- prompt_embeds=prompt_embeds,
97
- pooled_prompt_embeds=pooled_prompt_embeds,
98
- max_sequence_length=max_sequence_length,
99
- )
100
-
101
- self._guidance_scale = guidance_scale
102
- self._joint_attention_kwargs = joint_attention_kwargs
103
- self._interrupt = False
104
-
105
- # 2. Define call parameters
106
- batch_size = 1
107
- device = self._execution_device
108
-
109
- # 3. Encode prompt (with static inputs)
110
- lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
111
-
112
- # Use a static prompt for capture
113
- static_prompt = "static prompt" if isinstance(prompt, str) else ["static prompt"]
114
- prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
115
- prompt=static_prompt,
116
- prompt_2=prompt_2,
117
- prompt_embeds=None,
118
- pooled_prompt_embeds=None,
119
- device=device,
120
- num_images_per_prompt=num_images_per_prompt,
121
- max_sequence_length=max_sequence_length,
122
- lora_scale=lora_scale,
123
- )
124
-
125
- # 4. Prepare latent variables (with static inputs)
126
- num_channels_latents = self.transformer.config.in_channels // 4
127
- latents, latent_image_ids = self.prepare_latents(
128
- batch_size * num_images_per_prompt,
129
- num_channels_latents,
130
- static_height,
131
- static_width,
132
- prompt_embeds.dtype,
133
- device,
134
- generator,
135
- None,
136
- )
137
-
138
- # 5. Prepare timesteps (with static inputs)
139
- sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
140
- image_seq_len = latents.shape[1]
141
- mu = calculate_timestep_shift(image_seq_len)
142
- timesteps, num_inference_steps = prepare_timesteps(
143
- self.scheduler,
144
- num_inference_steps,
145
- device,
146
- None,
147
- sigmas,
148
- mu=mu,
149
- )
150
- self._num_timesteps = len(timesteps)
151
-
152
- guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float16).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
153
-
154
- # Capture the graph
155
- torch.cuda.synchronize()
156
- stream = torch.cuda.Stream()
157
- stream.wait_stream(torch.cuda.current_stream())
158
- with torch.cuda.stream(stream):
159
- for i, t in enumerate(timesteps):
160
- timestep = t.expand(latents.shape[0]).to(latents.dtype)
161
- noise_pred = self.transformer(
162
- hidden_states=latents,
163
- timestep=timestep / 1000,
164
- guidance=guidance,
165
- pooled_projections=pooled_prompt_embeds,
166
- encoder_hidden_states=prompt_embeds,
167
- txt_ids=text_ids,
168
- img_ids=latent_image_ids,
169
- joint_attention_kwargs=self.joint_attention_kwargs,
170
- return_dict=False,
171
- )[0]
172
- latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
173
-
174
- torch.cuda.current_stream().wait_stream(stream)
175
- torch.cuda.synchronize()
176
-
177
- # Capture the CUDA graph
178
- graph = torch.cuda.CUDAGraph()
179
- with torch.cuda.graph(graph, stream=stream):
180
- # Create static inputs
181
- static_inputs = OrderedDict()
182
- static_inputs["hidden_states"] = latents.clone()
183
- static_inputs["timestep"] = timesteps[0].expand(latents.shape[0]).to(latents.dtype)
184
- static_inputs["guidance"] = guidance.clone() if guidance is not None else None
185
- static_inputs["pooled_projections"] = pooled_prompt_embeds.clone()
186
- static_inputs["encoder_hidden_states"] = prompt_embeds.clone()
187
- static_inputs["txt_ids"] = text_ids
188
- static_inputs["img_ids"] = latent_image_ids.clone()
189
- static_inputs["joint_attention_kwargs"] = self.joint_attention_kwargs
190
-
191
- # Run the static graph
192
- for i, t in enumerate(timesteps):
193
- timestep = static_inputs["timestep"].clone()
194
- noise_pred = self.transformer(
195
- hidden_states=static_inputs["hidden_states"],
196
- timestep=timestep / 1000,
197
- guidance=static_inputs["guidance"],
198
- pooled_projections=static_inputs["pooled_projections"],
199
- encoder_hidden_states=static_inputs["encoder_hidden_states"],
200
- txt_ids=static_inputs["txt_ids"],
201
- img_ids=static_inputs["img_ids"],
202
- joint_attention_kwargs=static_inputs["joint_attention_kwargs"],
203
- return_dict=False,
204
- )[0]
205
- static_inputs["hidden_states"] = self.scheduler.step(noise_pred, t, static_inputs["hidden_states"], return_dict=False)[0]
206
-
207
- # Decode the latents after the loop
208
- final_latents = static_inputs["hidden_states"]
209
- final_image = self._decode_latents_to_image(final_latents, static_height, static_width, output_type)
210
-
211
- # Store the graph and static inputs in the dictionary
212
- self.cuda_graphs[(static_height, static_width, num_inference_steps)] = (graph, static_inputs, final_image)
213
-
214
  @torch.inference_mode()
215
  def generate_images(
216
  self,
@@ -235,34 +71,6 @@ class FluxWithCFGPipeline(FluxPipeline):
235
  height = height or self.default_sample_size * self.vae_scale_factor
236
  width = width or self.default_sample_size * self.vae_scale_factor
237
 
238
- # 0. Check if a CUDA graph can be used
239
- if (height, width, num_inference_steps) in self.cuda_graphs:
240
- graph, static_inputs, final_image = self.cuda_graphs[(height, width, num_inference_steps)]
241
-
242
- # Update dynamic inputs (like prompt) in static_inputs
243
- lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
244
- prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
245
- prompt=prompt,
246
- prompt_2=prompt_2,
247
- prompt_embeds=prompt_embeds,
248
- pooled_prompt_embeds=pooled_prompt_embeds,
249
- device=self._execution_device,
250
- num_images_per_prompt=num_images_per_prompt,
251
- max_sequence_length=max_sequence_length,
252
- lora_scale=lora_scale,
253
- )
254
-
255
- # Update only the dynamic parts of static_inputs
256
- static_inputs["pooled_projections"].copy_(pooled_prompt_embeds)
257
- static_inputs["encoder_hidden_states"].copy_(prompt_embeds)
258
- static_inputs["txt_ids"] = text_ids
259
-
260
- # Replay the graph
261
- graph.replay()
262
- torch.cuda.empty_cache()
263
-
264
- return final_image
265
-
266
  # 1. Check inputs
267
  self.check_inputs(
268
  prompt,
 
3
  from diffusers import FluxPipeline, FlowMatchEulerDiscreteScheduler
4
  from typing import Any, Dict, List, Optional, Union
5
  from PIL import Image
 
6
 
7
  # Constants for shift calculation
8
  BASE_SEQ_LEN = 256
 
47
  Extends the FluxPipeline to yield intermediate images during the denoising process
48
  with progressively increasing resolution for faster generation.
49
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  @torch.inference_mode()
51
  def generate_images(
52
  self,
 
71
  height = height or self.default_sample_size * self.vae_scale_factor
72
  width = width or self.default_sample_size * self.vae_scale_factor
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  # 1. Check inputs
75
  self.check_inputs(
76
  prompt,
requirements.txt CHANGED
@@ -1,10 +1,10 @@
1
- accelerate
2
- git+https://github.com/huggingface/diffusers.git@main
3
- torch>=2.0
4
- gradio==5.8.0
5
- transformers
6
- xformers
7
- sentencepiece
8
- peft
9
- numpy
10
  pillow
 
1
+ accelerate
2
+ git+https://github.com/huggingface/diffusers.git@main
3
+ torch>=2.0
4
+ gradio==5.8.0
5
+ transformers
6
+ xformers
7
+ sentencepiece
8
+ peft
9
+ numpy
10
  pillow