Update pipeline.py
Browse files- pipeline.py +21 -21
pipeline.py
CHANGED
@@ -725,30 +725,30 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
|
725 |
latent_sum = torch.zeros_like(latents).to(device).to(dtype=torch.float16)
|
726 |
latent_counter = torch.zeros(num_frames).to(device).to(dtype=torch.float16)
|
727 |
|
728 |
-
#
|
729 |
-
# for context_group in range(num_context_groups):
|
730 |
-
# # calculate to current indexes, considering overlap
|
731 |
-
# if context_group == 0:current_context_start = 0
|
732 |
-
# else:current_context_start = context_group * (context_size - overlap)
|
733 |
-
|
734 |
-
# # select the relevent context from the latents
|
735 |
-
# current_context_latents = latents[:, :, current_context_start : current_context_start + context_size, :, :]
|
736 |
-
# # if the context extends past the end of the latents, wrap around to the start
|
737 |
-
# if current_context_start + context_size > num_frames:
|
738 |
-
# current_context_latents = torch.cat([current_context_latents, latents[:, :, :current_context_start + context_size - num_frames, :, :]], dim=2)
|
739 |
-
|
740 |
for context_group in range(num_context_groups):
|
741 |
-
#
|
742 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
743 |
|
744 |
-
|
745 |
-
|
746 |
|
747 |
-
|
748 |
-
|
749 |
-
|
750 |
-
|
751 |
-
|
752 |
|
753 |
|
754 |
|
|
|
725 |
latent_sum = torch.zeros_like(latents).to(device).to(dtype=torch.float16)
|
726 |
latent_counter = torch.zeros(num_frames).to(device).to(dtype=torch.float16)
|
727 |
|
728 |
+
# foreach context group seperately denoise the current timestep
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
729 |
for context_group in range(num_context_groups):
|
730 |
+
# calculate to current indexes, considering overlap
|
731 |
+
if context_group == 0:current_context_start = 0
|
732 |
+
else:current_context_start = context_group * (context_size - overlap)
|
733 |
+
|
734 |
+
# select the relevent context from the latents
|
735 |
+
current_context_latents = latents[:, :, current_context_start : current_context_start + context_size, :, :]
|
736 |
+
# if the context extends past the end of the latents, wrap around to the start
|
737 |
+
if current_context_start + context_size > num_frames:
|
738 |
+
current_context_latents = torch.cat([current_context_latents, latents[:, :, :current_context_start + context_size - num_frames, :, :]], dim=2)
|
739 |
+
|
740 |
+
# for context_group in range(num_context_groups):
|
741 |
+
# # Calculate the current start index, considering overlap
|
742 |
+
# current_context_start = 0 if context_group == 0 else context_group * (context_size - overlap)
|
743 |
|
744 |
+
# # Calculate the end index and adjust if it exceeds num_frames
|
745 |
+
# current_context_end = (current_context_start + context_size) % num_frames
|
746 |
|
747 |
+
# # Select the relevant context from the latents with wrap-around handling
|
748 |
+
# current_context_latents = torch.cat([
|
749 |
+
# latents[:, :, current_context_start:min(current_context_start + context_size, num_frames), :, :],
|
750 |
+
# latents[:, :, :max(current_context_end - num_frames, 0), :, :]
|
751 |
+
# ], dim=2) if current_context_start + context_size > num_frames else latents[:, :, current_context_start:current_context_start + context_size, :, :]
|
752 |
|
753 |
|
754 |
|