Update pipeline.py
pipeline.py CHANGED: +13 -7
@@ -977,10 +977,16 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
 
                 # select the relevant context from the latents
                 current_context_latents = latents[:, :, current_context_start : current_context_start + context_size, :, :]
+
+                wrap_count = max(current_context_start + context_size - num_frames, 0)
+
                 # if context_start + context_size > num_frames: append the remaining frames from the start of the latents
-                if current_context_start + context_size > num_frames:
-                    print(f"Appending {current_context_start + context_size - num_frames} frames from the start of the latents")
-                    current_context_latents = torch.cat([current_context_latents, latents[:, :, :current_context_start + context_size - num_frames, :, :]], dim=2)
+                if wrap_count > 0:
+                    print(f"Appending {wrap_count} frames from the start of the latents")
+                    current_context_latents = torch.cat([current_context_latents, latents[:, :, :wrap_count, :, :]], dim=2)
+
+                # print number of frames in the context
+                print(f"Number of frames in the context: {current_context_latents.shape[2]}")
 
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([current_context_latents] * 2) if do_classifier_free_guidance else current_context_latents
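Net effect of the first hunk: the context window of context_size frames may run past the end of the video, in which case it wraps around and borrows frames from the start, and the overshoot is now computed once as wrap_count instead of being repeated inline. Below is a minimal, self-contained sketch of that selection logic using dummy tensors; everything except the names taken from the diff (latents, num_frames, context_size, current_context_start, wrap_count, current_context_latents) is an illustrative assumption, not the pipeline's actual values.

import torch

# illustrative shapes only; the real pipeline's latents come from the scheduler
batch, channels, num_frames, height, width = 1, 4, 16, 8, 8
context_size = 6
current_context_start = 13  # window runs 3 frames past the last frame

latents = torch.randn(batch, channels, num_frames, height, width)

# select the context window; slicing past the end simply truncates
current_context_latents = latents[:, :, current_context_start : current_context_start + context_size, :, :]

# how far the window overshoots the end of the video (0 when it fits)
wrap_count = max(current_context_start + context_size - num_frames, 0)

if wrap_count > 0:
    # wrap around: append the first wrap_count frames from the start
    current_context_latents = torch.cat([current_context_latents, latents[:, :, :wrap_count, :, :]], dim=2)

assert current_context_latents.shape[2] == context_size  # the window is always full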
@@ -1004,13 +1010,13 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                 current_context_latents = self.scheduler.step(noise_pred, t, current_context_latents, **extra_step_kwargs).prev_sample
 
                 # if context_start + context_size > num_frames: remove the appended frames from the end of the current_context_latents
-                if current_context_start + context_size > num_frames:
+                if wrap_count > 0:
                     # add the ending frames from current_context_latents to the start of the latent_sum
-                    latent_sum[:, :, 0:current_context_start + context_size - num_frames, :, :] += current_context_latents[:, :, -(current_context_start + context_size - num_frames):, :, :]
+                    latent_sum[:, :, 0:wrap_count, :, :] += current_context_latents[:, :, -wrap_count:, :, :]
                     # increase the counter for the ending frames
-                    latent_counter[0:current_context_start + context_size - num_frames] += 1
+                    latent_counter[0:wrap_count] += 1
                     # remove the ending frames from current_context_latents
-                    current_context_latents = current_context_latents[:, :, :-(current_context_start + context_size - num_frames), :, :]
+                    current_context_latents = current_context_latents[:, :, :-wrap_count, :, :]
 
                 # add the current_context_latents back to the latent sum, starting from the current context start
                 latent_sum[:, :, current_context_start : current_context_start + context_size, :, :] += current_context_latents
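The second hunk folds the denoised window back into per-frame accumulators: the wrapped frames at the end of the window are credited to frames 0..wrap_count-1 of latent_sum, latent_counter tracks how many windows touched each frame, and the trimmed remainder is added at the original offset. The sketch below shows how those accumulators combine overlapping windows; the window stride and the final division into an average are assumptions about the surrounding loop, which this diff does not show.

import torch

# illustrative shapes and stride; only latent_sum / latent_counter / wrap_count
# and the slice arithmetic are taken from the diff
batch, channels, num_frames, height, width = 1, 4, 16, 8, 8
context_size = 6

latent_sum = torch.zeros(batch, channels, num_frames, height, width)
latent_counter = torch.zeros(num_frames)

for current_context_start in range(0, num_frames, context_size // 2):
    wrap_count = max(current_context_start + context_size - num_frames, 0)

    # stand-in for the window denoised by the scheduler step
    current_context_latents = torch.randn(batch, channels, context_size, height, width)

    if wrap_count > 0:
        # the appended frames at the end of the window belong at the start of the video
        latent_sum[:, :, 0:wrap_count, :, :] += current_context_latents[:, :, -wrap_count:, :, :]
        latent_counter[0:wrap_count] += 1
        # drop them so the remainder lines up with the original slice again
        current_context_latents = current_context_latents[:, :, :-wrap_count, :, :]

    # slices past num_frames truncate, matching the trimmed window length
    latent_sum[:, :, current_context_start : current_context_start + context_size, :, :] += current_context_latents
    latent_counter[current_context_start : current_context_start + context_size] += 1

# assumed final step: average the overlapping contributions per frame
latents = latent_sum / latent_counter.view(1, 1, -1, 1, 1)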