smoothieAI commited on
Commit
d41df3c
·
verified ·
1 Parent(s): 3c0b9b4

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +9 -21
pipeline.py CHANGED
@@ -1497,28 +1497,16 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
1497
 
1498
  if do_classifier_free_guidance:
1499
  # Split tensor along its first dimension
 
1500
  noise_pred_uncond, noise_pred_text = torch.chunk(noise_pred, 2, dim=0)
1501
-
1502
- # Reshape or expand noise_pred_uncond and noise_pred_text to match the dimensions of the sum tensors
1503
- # This step depends on the dimensions of your tensors and how they need to align for the operation
1504
- expanded_noise_pred_uncond = noise_pred_uncond.expand_as(noise_pred_uncond_sum[..., current_context_indexes, :, :])
1505
- expanded_noise_pred_text = noise_pred_text.expand_as(noise_pred_text_sum[..., current_context_indexes, :, :])
1506
-
1507
- # Perform batch addition
1508
- noise_pred_uncond_sum[..., current_context_indexes, :, :] += expanded_noise_pred_uncond
1509
- noise_pred_text_sum[..., current_context_indexes, :, :] += expanded_noise_pred_text
1510
-
1511
- # print devices and shapes for everything
1512
- print("noise_pred_uncond_sum", noise_pred_uncond_sum.device, noise_pred_uncond_sum.shape)
1513
- print("noise_pred_text_sum", noise_pred_text_sum.device, noise_pred_text_sum.shape)
1514
- print("expanded_noise_pred_uncond", expanded_noise_pred_uncond.device, expanded_noise_pred_uncond.shape)
1515
- print("expanded_noise_pred_text", expanded_noise_pred_text.device, expanded_noise_pred_text.shape)
1516
- print("current_context_latents", current_context_latents.device, current_context_latents.shape)
1517
- print("latent_counter", latent_counter.device, latent_counter.shape)
1518
-
1519
-
1520
- # Batch increment for latent_counter
1521
- latent_counter[current_context_indexes] += 1
1522
 
1523
  print("guidance time", time.time() - start_guidance_time)
1524
 
 
1497
 
1498
  if do_classifier_free_guidance:
1499
  # Split tensor along its first dimension
1500
+ time_chunk = time.time()
1501
  noise_pred_uncond, noise_pred_text = torch.chunk(noise_pred, 2, dim=0)
1502
+ print("chunk time", time.time() - time_chunk)
1503
+ for i,context_index in range(len(current_context_indexes)):
1504
+ # Perform batch addition
1505
+ noise_pred_uncond_sum[..., context_index, :, :] += noise_pred_uncond[ :, :,i, :, :]
1506
+ noise_pred_text_sum[..., context_index, :, :] += noise_pred_text[ :, :,i, :, :]
1507
+
1508
+ # Batch increment for latent_counter
1509
+ latent_counter[context_index] += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
1510
 
1511
  print("guidance time", time.time() - start_guidance_time)
1512