rahul7star commited on
Commit
cdc4fd7
·
verified ·
1 Parent(s): d446991

Update toolkit/models/diffusion_feature_extraction.py

Browse files
toolkit/models/diffusion_feature_extraction.py CHANGED
@@ -255,30 +255,30 @@ class DiffusionFeatureExtractor3(nn.Module):
255
  dtype = torch.bfloat16
256
  device = self.vae.device
257
 
258
- # first we step the scheduler from current timestep to the very end for a full denoise
259
- # bs = noise_pred.shape[0]
260
- # noise_pred_chunks = torch.chunk(noise_pred, bs)
261
- # timestep_chunks = torch.chunk(timesteps, bs)
262
- # noisy_latent_chunks = torch.chunk(noisy_latents, bs)
263
- # stepped_chunks = []
264
- # for idx in range(bs):
265
- # model_output = noise_pred_chunks[idx]
266
- # timestep = timestep_chunks[idx]
267
- # scheduler._step_index = None
268
- # scheduler._init_step_index(timestep)
269
- # sample = noisy_latent_chunks[idx].to(torch.float32)
270
-
271
- # sigma = scheduler.sigmas[scheduler.step_index]
272
- # sigma_next = scheduler.sigmas[-1] # use last sigma for final step
273
- # prev_sample = sample + (sigma_next - sigma) * model_output
274
- # stepped_chunks.append(prev_sample)
275
-
276
- # stepped_latents = torch.cat(stepped_chunks, dim=0)
277
 
278
  if model is not None and hasattr(model, 'get_stepped_pred'):
279
  stepped_latents = model.get_stepped_pred(noise_pred, noise)
280
  else:
281
- stepped_latents = noise - noise_pred
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
 
283
  latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)
284
 
@@ -374,4 +374,4 @@ def load_dfe(model_path, vae=None) -> DiffusionFeatureExtractor:
374
 
375
  dfe.load_state_dict(state_dict)
376
  dfe.eval()
377
- return dfe
 
255
  dtype = torch.bfloat16
256
  device = self.vae.device
257
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
 
259
  if model is not None and hasattr(model, 'get_stepped_pred'):
260
  stepped_latents = model.get_stepped_pred(noise_pred, noise)
261
  else:
262
+ # stepped_latents = noise - noise_pred
263
+ # first we step the scheduler from current timestep to the very end for a full denoise
264
+ bs = noise_pred.shape[0]
265
+ noise_pred_chunks = torch.chunk(noise_pred, bs)
266
+ timestep_chunks = torch.chunk(timesteps, bs)
267
+ noisy_latent_chunks = torch.chunk(noisy_latents, bs)
268
+ stepped_chunks = []
269
+ for idx in range(bs):
270
+ model_output = noise_pred_chunks[idx]
271
+ timestep = timestep_chunks[idx]
272
+ scheduler._step_index = None
273
+ scheduler._init_step_index(timestep)
274
+ sample = noisy_latent_chunks[idx].to(torch.float32)
275
+
276
+ sigma = scheduler.sigmas[scheduler.step_index]
277
+ sigma_next = scheduler.sigmas[-1] # use last sigma for final step
278
+ prev_sample = sample + (sigma_next - sigma) * model_output
279
+ stepped_chunks.append(prev_sample)
280
+
281
+ stepped_latents = torch.cat(stepped_chunks, dim=0)
282
 
283
  latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)
284
 
 
374
 
375
  dfe.load_state_dict(state_dict)
376
  dfe.eval()
377
+ return dfe