Spaces:
Sleeping
Sleeping
Update custom_pipeline.py
Browse files- custom_pipeline.py +53 -22
custom_pipeline.py
CHANGED
@@ -130,29 +130,60 @@ class FluxWithCFGPipeline(FluxPipeline):
|
|
130 |
|
131 |
# Handle guidance
|
132 |
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float16).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
|
134 |
-
#
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
txt_ids=text_ids,
|
148 |
-
img_ids=latent_image_ids,
|
149 |
-
joint_attention_kwargs=self.joint_attention_kwargs,
|
150 |
-
return_dict=False,
|
151 |
-
)[0]
|
152 |
-
|
153 |
-
# Yield intermediate result
|
154 |
-
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
155 |
-
torch.cuda.empty_cache()
|
156 |
|
157 |
# Final image
|
158 |
return self._decode_latents_to_image(latents, height, width, output_type)
|
|
|
130 |
|
131 |
# Handle guidance
|
132 |
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float16).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
|
133 |
+
|
134 |
+
# static method that can be jitted
|
135 |
+
@staticmethod
|
136 |
+
@torch.jit.script
|
137 |
+
def _denoising_loop_static(latents, timesteps, pooled_prompt_embeds, prompt_embeds, text_ids, latent_image_ids, guidance, joint_attention_kwargs, transformer, scheduler):
|
138 |
+
for i, t in enumerate(timesteps):
|
139 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
140 |
+
|
141 |
+
noise_pred = transformer(
|
142 |
+
hidden_states=latents,
|
143 |
+
timestep=timestep / 1000,
|
144 |
+
guidance=guidance,
|
145 |
+
pooled_projections=pooled_prompt_embeds,
|
146 |
+
encoder_hidden_states=prompt_embeds,
|
147 |
+
txt_ids=text_ids,
|
148 |
+
img_ids=latent_image_ids,
|
149 |
+
joint_attention_kwargs=joint_attention_kwargs,
|
150 |
+
return_dict=False,
|
151 |
+
)[0]
|
152 |
+
|
153 |
+
latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
154 |
+
torch.cuda.empty_cache()
|
155 |
+
return latents
|
156 |
+
|
157 |
+
# Make the core denoising loop a static method
|
158 |
+
self._denoising_loop = torch.cuda.make_graphed_callables(
|
159 |
+
_denoising_loop_static,
|
160 |
+
(
|
161 |
+
latents.clone(), # Example inputs for warmup
|
162 |
+
timesteps.clone(),
|
163 |
+
pooled_prompt_embeds.clone(),
|
164 |
+
prompt_embeds.clone(),
|
165 |
+
text_ids.clone(),
|
166 |
+
latent_image_ids.clone(),
|
167 |
+
guidance.clone(),
|
168 |
+
self._joint_attention_kwargs,
|
169 |
+
self.transformer,
|
170 |
+
self.scheduler
|
171 |
+
)
|
172 |
+
)
|
173 |
|
174 |
+
# Call the static method now
|
175 |
+
latents = self._denoising_loop(
|
176 |
+
latents,
|
177 |
+
timesteps,
|
178 |
+
pooled_prompt_embeds,
|
179 |
+
prompt_embeds,
|
180 |
+
text_ids,
|
181 |
+
latent_image_ids,
|
182 |
+
guidance,
|
183 |
+
self._joint_attention_kwargs,
|
184 |
+
self.transformer,
|
185 |
+
self.scheduler
|
186 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
187 |
|
188 |
# Final image
|
189 |
return self._decode_latents_to_image(latents, height, width, output_type)
|