Ryukijano commited on
Commit
5b89af0
·
verified ·
1 Parent(s): e462ef2

Update custom_pipeline.py

Browse files
Files changed (1) hide show
  1. custom_pipeline.py +23 -54
custom_pipeline.py CHANGED
@@ -130,60 +130,29 @@ class FluxWithCFGPipeline(FluxPipeline):
130
 
131
  # Handle guidance
132
  guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float16).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
133
-
134
- # static method that can be jitted
135
- @staticmethod
136
- @torch.jit.script
137
- def _denoising_loop_static(latents, timesteps, pooled_prompt_embeds, prompt_embeds, text_ids, latent_image_ids, guidance, joint_attention_kwargs, transformer, scheduler):
138
- for i, t in enumerate(timesteps):
139
- timestep = t.expand(latents.shape[0]).to(latents.dtype)
140
-
141
- noise_pred = transformer(
142
- hidden_states=latents,
143
- timestep=timestep / 1000,
144
- guidance=guidance,
145
- pooled_projections=pooled_prompt_embeds,
146
- encoder_hidden_states=prompt_embeds,
147
- txt_ids=text_ids,
148
- img_ids=latent_image_ids,
149
- joint_attention_kwargs=joint_attention_kwargs,
150
- return_dict=False,
151
- )[0]
152
-
153
- latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
154
- torch.cuda.empty_cache()
155
- return latents
156
-
157
- # Make the core denoising loop a static method
158
- self._denoising_loop = torch.cuda.make_graphed_callables(
159
- _denoising_loop_static,
160
- (
161
- latents.clone(), # Example inputs for warmup
162
- timesteps.clone(),
163
- pooled_prompt_embeds.clone(),
164
- prompt_embeds.clone(),
165
- text_ids.clone(),
166
- latent_image_ids.clone(),
167
- guidance.clone(),
168
- self._joint_attention_kwargs,
169
- self.transformer,
170
- self.scheduler
171
- )
172
- )
173
 
174
- # Call the static method now
175
- latents = self._denoising_loop(
176
- latents,
177
- timesteps,
178
- pooled_prompt_embeds,
179
- prompt_embeds,
180
- text_ids,
181
- latent_image_ids,
182
- guidance,
183
- self._joint_attention_kwargs,
184
- self.transformer,
185
- self.scheduler
186
- )
 
 
 
 
 
 
 
 
 
187
 
188
  # Final image
189
  return self._decode_latents_to_image(latents, height, width, output_type)
@@ -196,4 +165,4 @@ class FluxWithCFGPipeline(FluxPipeline):
196
  latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
197
  latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
198
  image = vae.decode(latents, return_dict=False)[0]
199
- return self.image_processor.postprocess(image, output_type=output_type)[0]
 
130
 
131
  # Handle guidance
132
  guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float16).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
+ # 6. Denoising loop
135
+ for i, t in enumerate(timesteps):
136
+ if self.interrupt:
137
+ continue
138
+
139
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
140
+
141
+ noise_pred = self.transformer(
142
+ hidden_states=latents,
143
+ timestep=timestep / 1000,
144
+ guidance=guidance,
145
+ pooled_projections=pooled_prompt_embeds,
146
+ encoder_hidden_states=prompt_embeds,
147
+ txt_ids=text_ids,
148
+ img_ids=latent_image_ids,
149
+ joint_attention_kwargs=self.joint_attention_kwargs,
150
+ return_dict=False,
151
+ )[0]
152
+
153
+ # Yield intermediate result
154
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
155
+ torch.cuda.empty_cache()
156
 
157
  # Final image
158
  return self._decode_latents_to_image(latents, height, width, output_type)
 
165
  latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
166
  latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
167
  image = vae.decode(latents, return_dict=False)[0]
168
+ return self.image_processor.postprocess(image, output_type=output_type)[0]