Spaces:
Running
Running
Update toolkit/models/diffusion_feature_extraction.py
Browse files
toolkit/models/diffusion_feature_extraction.py
CHANGED
@@ -255,30 +255,30 @@ class DiffusionFeatureExtractor3(nn.Module):
|
|
255 |
dtype = torch.bfloat16
|
256 |
device = self.vae.device
|
257 |
|
258 |
-
# first we step the scheduler from current timestep to the very end for a full denoise
|
259 |
-
# bs = noise_pred.shape[0]
|
260 |
-
# noise_pred_chunks = torch.chunk(noise_pred, bs)
|
261 |
-
# timestep_chunks = torch.chunk(timesteps, bs)
|
262 |
-
# noisy_latent_chunks = torch.chunk(noisy_latents, bs)
|
263 |
-
# stepped_chunks = []
|
264 |
-
# for idx in range(bs):
|
265 |
-
# model_output = noise_pred_chunks[idx]
|
266 |
-
# timestep = timestep_chunks[idx]
|
267 |
-
# scheduler._step_index = None
|
268 |
-
# scheduler._init_step_index(timestep)
|
269 |
-
# sample = noisy_latent_chunks[idx].to(torch.float32)
|
270 |
-
|
271 |
-
# sigma = scheduler.sigmas[scheduler.step_index]
|
272 |
-
# sigma_next = scheduler.sigmas[-1] # use last sigma for final step
|
273 |
-
# prev_sample = sample + (sigma_next - sigma) * model_output
|
274 |
-
# stepped_chunks.append(prev_sample)
|
275 |
-
|
276 |
-
# stepped_latents = torch.cat(stepped_chunks, dim=0)
|
277 |
|
278 |
if model is not None and hasattr(model, 'get_stepped_pred'):
|
279 |
stepped_latents = model.get_stepped_pred(noise_pred, noise)
|
280 |
else:
|
281 |
-
stepped_latents = noise - noise_pred
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
282 |
|
283 |
latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)
|
284 |
|
@@ -374,4 +374,4 @@ def load_dfe(model_path, vae=None) -> DiffusionFeatureExtractor:
|
|
374 |
|
375 |
dfe.load_state_dict(state_dict)
|
376 |
dfe.eval()
|
377 |
-
return dfe
|
|
|
255 |
dtype = torch.bfloat16
|
256 |
device = self.vae.device
|
257 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
258 |
|
259 |
if model is not None and hasattr(model, 'get_stepped_pred'):
|
260 |
stepped_latents = model.get_stepped_pred(noise_pred, noise)
|
261 |
else:
|
262 |
+
# stepped_latents = noise - noise_pred
|
263 |
+
# first we step the scheduler from current timestep to the very end for a full denoise
|
264 |
+
bs = noise_pred.shape[0]
|
265 |
+
noise_pred_chunks = torch.chunk(noise_pred, bs)
|
266 |
+
timestep_chunks = torch.chunk(timesteps, bs)
|
267 |
+
noisy_latent_chunks = torch.chunk(noisy_latents, bs)
|
268 |
+
stepped_chunks = []
|
269 |
+
for idx in range(bs):
|
270 |
+
model_output = noise_pred_chunks[idx]
|
271 |
+
timestep = timestep_chunks[idx]
|
272 |
+
scheduler._step_index = None
|
273 |
+
scheduler._init_step_index(timestep)
|
274 |
+
sample = noisy_latent_chunks[idx].to(torch.float32)
|
275 |
+
|
276 |
+
sigma = scheduler.sigmas[scheduler.step_index]
|
277 |
+
sigma_next = scheduler.sigmas[-1] # use last sigma for final step
|
278 |
+
prev_sample = sample + (sigma_next - sigma) * model_output
|
279 |
+
stepped_chunks.append(prev_sample)
|
280 |
+
|
281 |
+
stepped_latents = torch.cat(stepped_chunks, dim=0)
|
282 |
|
283 |
latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)
|
284 |
|
|
|
374 |
|
375 |
dfe.load_state_dict(state_dict)
|
376 |
dfe.eval()
|
377 |
+
return dfe
|