Update pipeline.py
Browse files- pipeline.py +9 -21
pipeline.py
CHANGED
@@ -1497,28 +1497,16 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
|
1497 |
|
1498 |
if do_classifier_free_guidance:
|
1499 |
# Split tensor along its first dimension
|
|
|
1500 |
noise_pred_uncond, noise_pred_text = torch.chunk(noise_pred, 2, dim=0)
|
1501 |
-
|
1502 |
-
|
1503 |
-
|
1504 |
-
|
1505 |
-
|
1506 |
-
|
1507 |
-
|
1508 |
-
|
1509 |
-
noise_pred_text_sum[..., current_context_indexes, :, :] += expanded_noise_pred_text
|
1510 |
-
|
1511 |
-
# print devices and shapes for everything
|
1512 |
-
print("noise_pred_uncond_sum", noise_pred_uncond_sum.device, noise_pred_uncond_sum.shape)
|
1513 |
-
print("noise_pred_text_sum", noise_pred_text_sum.device, noise_pred_text_sum.shape)
|
1514 |
-
print("expanded_noise_pred_uncond", expanded_noise_pred_uncond.device, expanded_noise_pred_uncond.shape)
|
1515 |
-
print("expanded_noise_pred_text", expanded_noise_pred_text.device, expanded_noise_pred_text.shape)
|
1516 |
-
print("current_context_latents", current_context_latents.device, current_context_latents.shape)
|
1517 |
-
print("latent_counter", latent_counter.device, latent_counter.shape)
|
1518 |
-
|
1519 |
-
|
1520 |
-
# Batch increment for latent_counter
|
1521 |
-
latent_counter[current_context_indexes] += 1
|
1522 |
|
1523 |
print("guidance time", time.time() - start_guidance_time)
|
1524 |
|
|
|
1497 |
|
1498 |
if do_classifier_free_guidance:
|
1499 |
# Split tensor along its first dimension
|
1500 |
+
time_chunk = time.time()
|
1501 |
noise_pred_uncond, noise_pred_text = torch.chunk(noise_pred, 2, dim=0)
|
1502 |
+
print("chunk time", time.time() - time_chunk)
|
1503 |
+
for i,context_index in range(len(current_context_indexes)):
|
1504 |
+
# Perform batch addition
|
1505 |
+
noise_pred_uncond_sum[..., context_index, :, :] += noise_pred_uncond[ :, :,i, :, :]
|
1506 |
+
noise_pred_text_sum[..., context_index, :, :] += noise_pred_text[ :, :,i, :, :]
|
1507 |
+
|
1508 |
+
# Batch increment for latent_counter
|
1509 |
+
latent_counter[context_index] += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1510 |
|
1511 |
print("guidance time", time.time() - start_guidance_time)
|
1512 |
|