Ryukijano committed
Commit 431e45c · verified · 1 parent: c8f2370

Upload 2 files

Files changed (2):
  1. app.py +9 -17
  2. custom_pipeline.py +22 -50
app.py CHANGED
@@ -8,10 +8,7 @@ from diffusers import DiffusionPipeline, AutoencoderTiny
 from diffusers.models.attention_processor import AttnProcessor2_0
 from custom_pipeline import FluxWithCFGPipeline
 
-# Enable TF32 and set Tensor Core precision
 torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.allow_tf32 = True
-torch.set_float32_matmul_precision('high')
 
 # Constants
 MAX_SEED = np.iinfo(np.int32).max
@@ -32,10 +29,6 @@ pipe.set_adapters(["better"], adapter_weights=[1.0])
 pipe.fuse_lora(adapter_name=["better"], lora_scale=1.0)
 pipe.unload_lora_weights()
 
-# Memory optimizations (optional, uncomment if needed)
-# pipe.enable_model_cpu_offload()
-# pipe.enable_sequential_cpu_offload()
-
 torch.cuda.empty_cache()
 
 # Inference function
@@ -47,15 +40,14 @@ def generate_image(prompt, seed=24, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT,
 
     start_time = time.time()
 
-    with torch.autocast(device_type="cuda", dtype=torch.float16):
-        # Only generate the last image in the sequence
-        img = pipe.generate_images(
-            prompt=prompt,
-            width=width,
-            height=height,
-            num_inference_steps=num_inference_steps,
-            generator=generator
-        )
+    # Only generate the last image in the sequence
+    img = pipe.generate_images(
+        prompt=prompt,
+        width=width,
+        height=height,
+        num_inference_steps=num_inference_steps,
+        generator=generator
+    )
     latency = f"Latency: {(time.time()-start_time):.2f} seconds"
     return img, seed, latency
 
@@ -171,4 +163,4 @@ with gr.Blocks() as demo:
     )
 
 # Launch the app
-demo.launch()
+demo.launch()
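
Net effect in app.py: the torch.autocast wrapper around generation is removed (likely redundant, since the pipeline already computes in float16, as the float16 guidance tensor in custom_pipeline.py suggests), and two of the three float32 precision knobs are dropped while TF32 matmuls stay enabled. The removed knobs are standard PyTorch settings; a minimal sketch of restoring them, assuming an Ampere-or-newer GPU:

    import torch

    torch.backends.cuda.matmul.allow_tf32 = True  # still set in app.py
    torch.backends.cudnn.allow_tf32 = True        # removed: TF32 for cuDNN convolutions
    torch.set_float32_matmul_precision("high")    # removed: global matmul precision hint

One caveat on the retained timing code: time.time() around asynchronous CUDA work can under-report latency unless torch.cuda.synchronize() runs before reading the clock; here the figure is likely accurate only because decoding the final latents to an image copies to the CPU, which synchronizes.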
custom_pipeline.py CHANGED
@@ -130,57 +130,29 @@ class FluxWithCFGPipeline(FluxPipeline):
 
         # Handle guidance
         guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float16).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
-
-        # static method that can be jitted
-        @staticmethod
-        @torch.jit.script
-        def _denoising_loop_static(latents, timesteps, pooled_prompt_embeds, prompt_embeds, text_ids, latent_image_ids, guidance, transformer, scheduler):
-            for i, t in enumerate(timesteps):
-                timestep = t.expand(latents.shape[0]).to(latents.dtype)
-
-                noise_pred = transformer(
-                    hidden_states=latents,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    return_dict=False,
-                )[0]
-
-                latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-                torch.cuda.empty_cache()
-            return latents
-
-        # Make the core denoising loop a static method
-        self._denoising_loop = torch.cuda.make_graphed_callables(
-            _denoising_loop_static,
-            (
-                latents.clone(),  # Example inputs for warmup
-                timesteps.clone(),
-                pooled_prompt_embeds.clone(),
-                prompt_embeds.clone(),
-                text_ids.clone(),
-                latent_image_ids.clone(),
-                guidance.clone(),
-                self.transformer,
-                self.scheduler
-            )
-        )
 
-        # Call the static method now
-        latents = self._denoising_loop(
-            latents,
-            timesteps,
-            pooled_prompt_embeds,
-            prompt_embeds,
-            text_ids,
-            latent_image_ids,
-            guidance,
-            self.transformer,
-            self.scheduler
-        )
+        # 6. Denoising loop
+        for i, t in enumerate(timesteps):
+            if self.interrupt:
+                continue
+
+            timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+            noise_pred = self.transformer(
+                hidden_states=latents,
+                timestep=timestep / 1000,
+                guidance=guidance,
+                pooled_projections=pooled_prompt_embeds,
+                encoder_hidden_states=prompt_embeds,
+                txt_ids=text_ids,
+                img_ids=latent_image_ids,
+                joint_attention_kwargs=self.joint_attention_kwargs,
+                return_dict=False,
+            )[0]
+
+            # Yield intermediate result
+            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+            torch.cuda.empty_cache()
 
         # Final image
         return self._decode_latents_to_image(latents, height, width, output_type)
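
Why the revert in custom_pipeline.py was needed: the deleted block could not run as written. torch.jit.script cannot compile a function that takes nn.Module or scheduler objects as parameters, @staticmethod does nothing useful on a function defined inside a method body, and torch.cuda.make_graphed_callables expects tensor-only sample arguments and cannot capture a loop whose iteration count depends on Python data. If graph capture is still desired later, the conventional pattern from the PyTorch CUDA Graphs documentation is to capture one fixed-shape forward pass and replay it each step while the loop itself stays eager. A minimal, self-contained sketch with a toy module standing in for the Flux transformer (hypothetical shapes; requires a CUDA GPU):

    import torch

    model = torch.nn.Linear(64, 64, device="cuda", dtype=torch.float16).eval()
    static_in = torch.randn(8, 64, device="cuda", dtype=torch.float16)

    # Warm up on a side stream before capture, per the PyTorch docs.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s), torch.no_grad():
        for _ in range(3):
            model(static_in)
    torch.cuda.current_stream().wait_stream(s)

    # Capture a single fixed-shape forward pass into a CUDA graph.
    graph = torch.cuda.CUDAGraph()
    with torch.no_grad(), torch.cuda.graph(graph):
        static_out = model(static_in)

    # Per denoising step: refill the static input buffer, then replay.
    for _ in range(4):
        static_in.copy_(torch.randn_like(static_in))
        graph.replay()  # static_out now holds this step's output

This only pays off when tensor shapes stay constant across steps, which does hold here since width, height, and batch size are fixed for the duration of a call. Note also that the torch.cuda.empty_cache() kept inside the eager loop is expensive (releasing cached blocks synchronizes the device) and would defeat graph replay; it reads as a memory-pressure workaround rather than an optimization.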