Ryukijano committed on
Commit
c8f2370
·
verified ·
1 Parent(s): f819c94

Update custom_pipeline.py

Browse files
Files changed (1) hide show
  1. custom_pipeline.py +50 -22
custom_pipeline.py CHANGED
@@ -130,29 +130,57 @@ class FluxWithCFGPipeline(FluxPipeline):
130
 
131
  # Handle guidance
132
  guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float16).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
- # 6. Denoising loop
135
- for i, t in enumerate(timesteps):
136
- if self.interrupt:
137
- continue
138
-
139
- timestep = t.expand(latents.shape[0]).to(latents.dtype)
140
-
141
- noise_pred = self.transformer(
142
- hidden_states=latents,
143
- timestep=timestep / 1000,
144
- guidance=guidance,
145
- pooled_projections=pooled_prompt_embeds,
146
- encoder_hidden_states=prompt_embeds,
147
- txt_ids=text_ids,
148
- img_ids=latent_image_ids,
149
- joint_attention_kwargs=self.joint_attention_kwargs,
150
- return_dict=False,
151
- )[0]
152
-
153
- # Yield intermediate result
154
- latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
155
- torch.cuda.empty_cache()
156
 
157
  # Final image
158
  return self._decode_latents_to_image(latents, height, width, output_type)
 
130
 
131
  # Handle guidance
132
  guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float16).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
133
+
134
+ # static method that can be jitted
135
+ @staticmethod
136
+ @torch.jit.script
137
+ def _denoising_loop_static(latents, timesteps, pooled_prompt_embeds, prompt_embeds, text_ids, latent_image_ids, guidance, transformer, scheduler):
138
+ for i, t in enumerate(timesteps):
139
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
140
+
141
+ noise_pred = transformer(
142
+ hidden_states=latents,
143
+ timestep=timestep / 1000,
144
+ guidance=guidance,
145
+ pooled_projections=pooled_prompt_embeds,
146
+ encoder_hidden_states=prompt_embeds,
147
+ txt_ids=text_ids,
148
+ img_ids=latent_image_ids,
149
+ return_dict=False,
150
+ )[0]
151
+
152
+ latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
153
+ torch.cuda.empty_cache()
154
+ return latents
155
+
156
+ # Make the core denoising loop a static method
157
+ self._denoising_loop = torch.cuda.make_graphed_callables(
158
+ _denoising_loop_static,
159
+ (
160
+ latents.clone(), # Example inputs for warmup
161
+ timesteps.clone(),
162
+ pooled_prompt_embeds.clone(),
163
+ prompt_embeds.clone(),
164
+ text_ids.clone(),
165
+ latent_image_ids.clone(),
166
+ guidance.clone(),
167
+ self.transformer,
168
+ self.scheduler
169
+ )
170
+ )
171
 
172
+ # Call the static method now
173
+ latents = self._denoising_loop(
174
+ latents,
175
+ timesteps,
176
+ pooled_prompt_embeds,
177
+ prompt_embeds,
178
+ text_ids,
179
+ latent_image_ids,
180
+ guidance,
181
+ self.transformer,
182
+ self.scheduler
183
+ )
 
 
 
 
 
 
 
 
 
 
184
 
185
  # Final image
186
  return self._decode_latents_to_image(latents, height, width, output_type)