Ryukijano committed on
Commit
b480846
·
verified ·
1 Parent(s): 431e45c

Update custom_pipeline.py

Browse files
Files changed (1) hide show
  1. custom_pipeline.py +53 -22
custom_pipeline.py CHANGED
@@ -130,29 +130,60 @@ class FluxWithCFGPipeline(FluxPipeline):
130
 
131
  # Handle guidance
132
  guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float16).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
- # 6. Denoising loop
135
- for i, t in enumerate(timesteps):
136
- if self.interrupt:
137
- continue
138
-
139
- timestep = t.expand(latents.shape[0]).to(latents.dtype)
140
-
141
- noise_pred = self.transformer(
142
- hidden_states=latents,
143
- timestep=timestep / 1000,
144
- guidance=guidance,
145
- pooled_projections=pooled_prompt_embeds,
146
- encoder_hidden_states=prompt_embeds,
147
- txt_ids=text_ids,
148
- img_ids=latent_image_ids,
149
- joint_attention_kwargs=self.joint_attention_kwargs,
150
- return_dict=False,
151
- )[0]
152
-
153
- # Yield intermediate result
154
- latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
155
- torch.cuda.empty_cache()
156
 
157
  # Final image
158
  return self._decode_latents_to_image(latents, height, width, output_type)
 
130
 
131
  # Handle guidance
132
  guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float16).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
133
+
134
+ # static method that can be jitted
135
+ @staticmethod
136
+ @torch.jit.script
137
+ def _denoising_loop_static(latents, timesteps, pooled_prompt_embeds, prompt_embeds, text_ids, latent_image_ids, guidance, joint_attention_kwargs, transformer, scheduler):
138
+ for i, t in enumerate(timesteps):
139
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
140
+
141
+ noise_pred = transformer(
142
+ hidden_states=latents,
143
+ timestep=timestep / 1000,
144
+ guidance=guidance,
145
+ pooled_projections=pooled_prompt_embeds,
146
+ encoder_hidden_states=prompt_embeds,
147
+ txt_ids=text_ids,
148
+ img_ids=latent_image_ids,
149
+ joint_attention_kwargs=joint_attention_kwargs,
150
+ return_dict=False,
151
+ )[0]
152
+
153
+ latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
154
+ torch.cuda.empty_cache()
155
+ return latents
156
+
157
+ # Wrap the core denoising loop with torch.cuda.make_graphed_callables (CUDA graphs)
158
+ self._denoising_loop = torch.cuda.make_graphed_callables(
159
+ _denoising_loop_static,
160
+ (
161
+ latents.clone(), # Example inputs for warmup
162
+ timesteps.clone(),
163
+ pooled_prompt_embeds.clone(),
164
+ prompt_embeds.clone(),
165
+ text_ids.clone(),
166
+ latent_image_ids.clone(),
167
+ guidance.clone(),
168
+ self._joint_attention_kwargs,
169
+ self.transformer,
170
+ self.scheduler
171
+ )
172
+ )
173
 
174
+ # Run the wrapped denoising loop on the actual inputs
175
+ latents = self._denoising_loop(
176
+ latents,
177
+ timesteps,
178
+ pooled_prompt_embeds,
179
+ prompt_embeds,
180
+ text_ids,
181
+ latent_image_ids,
182
+ guidance,
183
+ self._joint_attention_kwargs,
184
+ self.transformer,
185
+ self.scheduler
186
+ )
 
 
 
 
 
 
 
 
 
187
 
188
  # Final image
189
  return self._decode_latents_to_image(latents, height, width, output_type)