smoothieAI committed · Commit f0ba1c8 · verified · 1 Parent(s): 0e91da2

Update pipeline.py

Files changed (1)
  1. pipeline.py +13 -7
pipeline.py CHANGED
@@ -977,10 +977,16 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
 
                 # select the relevant context from the latents
                 current_context_latents = latents[:, :, current_context_start : current_context_start + context_size, :, :]
+
+                wrap_count = max(current_context_start + context_size - num_frames, 0)
+
                 # if context_start + context_size > num_frames: append the remaining frames from the start of the latents
-                if current_context_start + context_size > num_frames:
-                    print(f"Appending {max(current_context_start + context_size - num_frames, 0)} frames from the start of the latents")
-                    current_context_latents = torch.cat([current_context_latents, latents[:, :, :max(current_context_start + context_size - num_frames, 0), :, :]], dim=2)
+                if wrap_count > 0:
+                    print(f"Appending {wrap_count} frames from the start of the latents")
+                    current_context_latents = torch.cat([current_context_latents, latents[:, :, :wrap_count, :, :]], dim=2)
+
+                # print number of frames in the context
+                print(f"Number of frames in the context: {current_context_latents.shape[2]}")
 
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([current_context_latents] * 2) if do_classifier_free_guidance else current_context_latents
@@ -1004,13 +1010,13 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                 current_context_latents = self.scheduler.step(noise_pred, t, current_context_latents, **extra_step_kwargs).prev_sample
 
                 # if context_start + context_size > num_frames: remove the appended frames from the end of the current_context_latents
-                if current_context_start + context_size > num_frames:
+                if wrap_count > 0:
                     # add the ending frames from current_context_latents to the start of the latent_sum
-                    latent_sum[:, :, -max(current_context_start + context_size - num_frames, 0):, :, :] += current_context_latents[:, :, -max(current_context_start + context_size - num_frames, 0):, :, :]
+                    latent_sum[:, :, 0:wrap_count, :, :] += current_context_latents[:, :, -wrap_count:, :, :]
                     # increase the counter for the ending frames
-                    latent_counter[-max(current_context_start + context_size - num_frames, 0):] += 1
+                    latent_counter[0:wrap_count] += 1
                     # remove the ending frames from current_context_latents
-                    current_context_latents = current_context_latents[:, :, :-max(current_context_start + context_size - num_frames, 0), :, :]
+                    current_context_latents = current_context_latents[:, :, :-wrap_count, :, :]
 
                 # add the context current_context_latents back to the latent sum starting from the current context start
                 latent_sum[:, :, current_context_start : current_context_start + context_size, :, :] += current_context_latents
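For orientation, the first hunk is a refactor: wrap_count names the number of frames by which the context window overshoots the end of the clip, and the window is padded by wrapping around to the clip's first frames. A minimal sketch of that selection logic, assuming illustrative tensor shapes and a half-window stride (the diff does not show how current_context_start is advanced):

import torch

# Illustrative dimensions: (batch, channels, frames, height, width).
batch, channels, num_frames, height, width = 1, 4, 24, 8, 8
context_size = 16
latents = torch.randn(batch, channels, num_frames, height, width)

# Hypothetical half-window stride over the frame axis.
for current_context_start in range(0, num_frames, context_size // 2):
    # Frames by which the window runs past the end of the clip.
    wrap_count = max(current_context_start + context_size - num_frames, 0)

    # Take the in-range slice, then wrap around to the start of the clip
    # so every window is exactly context_size frames long.
    current_context_latents = latents[:, :, current_context_start : current_context_start + context_size]
    if wrap_count > 0:
        current_context_latents = torch.cat(
            [current_context_latents, latents[:, :, :wrap_count]], dim=2
        )
    assert current_context_latents.shape[2] == context_size

Every window therefore comes out exactly context_size frames long, which is what the added "Number of frames in the context" print verifies.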
 
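The second hunk is the substantive fix. The wrapped frames are borrowed from the start of the clip, but the old code accumulated them into the tail of latent_sum (slice -wrap_count:), i.e. onto the wrong frames; the new slices latent_sum[:, :, 0:wrap_count] and latent_counter[0:wrap_count] write them back where they belong. A self-contained sketch of the overlap-averaging bookkeeping, reusing the illustrative shapes above and adding a stand-in denoising step plus a main-window latent_counter update that this diff does not show:

import torch

batch, channels, num_frames, height, width = 1, 4, 24, 8, 8
context_size = 16
latents = torch.randn(batch, channels, num_frames, height, width)
latent_sum = torch.zeros_like(latents)
latent_counter = torch.zeros(num_frames)

for current_context_start in range(0, num_frames, context_size // 2):
    wrap_count = max(current_context_start + context_size - num_frames, 0)
    current_context_latents = latents[:, :, current_context_start : current_context_start + context_size]
    if wrap_count > 0:
        current_context_latents = torch.cat(
            [current_context_latents, latents[:, :, :wrap_count]], dim=2
        )

    # ... a scheduler/denoising step on current_context_latents goes here ...

    if wrap_count > 0:
        # Wrapped frames belong to the *start* of the clip, so accumulate them
        # at 0:wrap_count (the pre-fix code used -wrap_count:, the clip's tail).
        latent_sum[:, :, :wrap_count] += current_context_latents[:, :, -wrap_count:]
        latent_counter[:wrap_count] += 1
        # Drop the wrapped frames before the main accumulation below.
        current_context_latents = current_context_latents[:, :, :-wrap_count]

    # Accumulate the in-range frames; the slice clips at num_frames, which
    # matches the trimmed window length.
    latent_sum[:, :, current_context_start : current_context_start + context_size] += current_context_latents
    latent_counter[current_context_start : current_context_start + context_size] += 1

# Per-frame average over every window that covered each frame.
averaged = latent_sum / latent_counter.view(1, 1, -1, 1, 1)

With matching slices on latent_sum and latent_counter, each frame is counted exactly as many times as it is written, so the final division is a true per-frame mean.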