smoothieAI committed
Commit a88026f · verified · 1 Parent(s): 1e17894

Update pipeline.py

Files changed (1)
  1. pipeline.py +17 -0
pipeline.py CHANGED
@@ -975,6 +975,23 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
     # select the relevent context from the latents
     current_context_latents = latents[:, :, current_context_start : current_context_start + context_size, :, :]

+
+    # for context_group in range(num_context_groups):
+    #     # Calculate the current start index, considering overlap
+    #     current_context_start = 0 if context_group == 0 else context_group * (context_size - overlap)
+
+    #     # Calculate the end index and adjust if it exceeds num_frames
+    #     current_context_end = (current_context_start + context_size) % num_frames
+
+    #     # Select the relevant context from the latents with wrap-around handling
+    #     current_context_latents = torch.cat([
+    #         latents[:, :, current_context_start:min(current_context_start + context_size, num_frames), :, :],
+    #         latents[:, :, :max(current_context_end - num_frames, 0), :, :]
+    #     ], dim=2) if current_context_start + context_size > num_frames else latents[:, :, current_context_start:current_context_start + context_size, :, :]
+
+
+
+
     # expand the latents if we are doing classifier free guidance
     latent_model_input = torch.cat([current_context_latents] * 2) if do_classifier_free_guidance else current_context_latents
     latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
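The block added by this commit is left commented out, but it sketches wrap-around context selection: each context group of context_size frames starts (context_size - overlap) frames after the previous one, and a group that runs past the last frame wraps back to the start of the latent sequence. Below is a minimal, self-contained sketch of that slicing on a dummy tensor; it is not part of the commit. The helper name select_context_latents and the num_context_groups computation are illustrative assumptions, and the wrap slice uses the pre-modulo end index (start + context_size - num_frames), since the commented expression takes the modulo first and its max(current_context_end - num_frames, 0) term appears to always select zero frames.

import torch

def select_context_latents(latents: torch.Tensor, start: int, context_size: int) -> torch.Tensor:
    # Hypothetical helper, not part of pipeline.py: slice `context_size` frames
    # along dim 2 (the frame axis), wrapping past the last frame if needed.
    num_frames = latents.shape[2]
    end = start + context_size
    if end <= num_frames:
        # Plain contiguous slice, no wrap needed.
        return latents[:, :, start:end, :, :]
    # Wrap around: tail frames followed by the leading frames.
    return torch.cat(
        [latents[:, :, start:num_frames, :, :], latents[:, :, : end - num_frames, :, :]],
        dim=2,
    )

# Toy usage: batch 1, 4 latent channels, 16 frames, 8x8 latent grid.
latents = torch.randn(1, 4, 16, 8, 8)
context_size, overlap = 6, 2
num_frames = latents.shape[2]
# Assumed group count: one group per stride of (context_size - overlap) frames.
num_context_groups = -(-num_frames // (context_size - overlap))
for context_group in range(num_context_groups):
    current_context_start = 0 if context_group == 0 else context_group * (context_size - overlap)
    ctx = select_context_latents(latents, current_context_start, context_size)
    assert ctx.shape[2] == context_size  # every group, including the wrapped one, has full length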