smoothieAI committed · Commit 0e91da2 · verified · 1 Parent(s): d73178b

Update pipeline.py

Files changed (1)
  1. pipeline.py +11 -19
pipeline.py CHANGED
@@ -982,23 +982,6 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                     print(f"Appending {max(current_context_start + context_size - num_frames, 0)} frames from the start of the latents")
                     current_context_latents = torch.cat([current_context_latents, latents[:, :, :max(current_context_start + context_size - num_frames, 0), :, :]], dim=2)
 
-
-                    # for context_group in range(num_context_groups):
-                    #     # Calculate the current start index, considering overlap
-                    #     current_context_start = 0 if context_group == 0 else context_group * (context_size - overlap)
-
-                    #     # Calculate the end index and adjust if it exceeds num_frames
-                    #     current_context_end = (current_context_start + context_size) % num_frames
-
-                    #     # Select the relevant context from the latents with wrap-around handling
-                    #     current_context_latents = torch.cat([
-                    #         latents[:, :, current_context_start:min(current_context_start + context_size, num_frames), :, :],
-                    #         latents[:, :, :max(current_context_end - num_frames, 0), :, :]
-                    #     ], dim=2) if current_context_start + context_size > num_frames else latents[:, :, current_context_start:current_context_start + context_size, :, :]
-
-
-
-
                     # expand the latents if we are doing classifier free guidance
                     latent_model_input = torch.cat([current_context_latents] * 2) if do_classifier_free_guidance else current_context_latents
                     latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -1019,9 +1002,18 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
 
                     # compute the previous noisy sample x_t -> x_t-1
                     current_context_latents = self.scheduler.step(noise_pred, t, current_context_latents, **extra_step_kwargs).prev_sample
-
+
+                    # if context_start + context_size > num_frames: remove the appended frames from the end of the current_context_latents
+                    if current_context_start + context_size > num_frames:
+                        # add the ending frames from current_context_latents to the start of the latent_sum
+                        latent_sum[:, :, -max(current_context_start + context_size - num_frames, 0):, :, :] += current_context_latents[:, :, -max(current_context_start + context_size - num_frames, 0):, :, :]
+                        # increase the counter for the ending frames
+                        latent_counter[-max(current_context_start + context_size - num_frames, 0):] += 1
+                        # remove the ending frames from current_context_latents
+                        current_context_latents = current_context_latents[:, :, :-max(current_context_start + context_size - num_frames, 0), :, :]
+
                     #add the context current_context_latents back to the latent sum starting from the current context start
-                    latent_sum[:, :, current_context_start : current_context_start + context_size, :, :] += current_context_latents
+                    latent_sum[:, :, current_context_start : current_context_start + context_size, :, :] += current_context_latents
                     # add one to the counter for each timestep in the context
                     latent_counter[current_context_start : current_context_start + context_size] += 1
 
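
For readers following the change: the pipeline denoises a long clip in overlapping context windows, a window that runs past the last frame wraps around and borrows frames from the start of the clip (so the result can loop), each window's output is accumulated into latent_sum, and latent_counter tracks how many windows covered each frame so the final latents can be averaged per frame. Below is a minimal, self-contained sketch of that scheme. The name average_overlapping_contexts, the denoise_fn stand-in (replacing the real UNet noise prediction and self.scheduler.step call), and the (batch, channels, frames, height, width) layout are illustrative assumptions, not the pipeline's API; following the diff's own comment, the frames borrowed from the start of the clip are accumulated back into the start of latent_sum.

import torch

def average_overlapping_contexts(latents, context_size, overlap, denoise_fn):
    # latents: (batch, channels, frames, height, width) -- assumed layout
    num_frames = latents.shape[2]
    stride = context_size - overlap
    num_context_groups = -(-num_frames // stride)  # ceil division

    latent_sum = torch.zeros_like(latents)
    latent_counter = torch.zeros(num_frames, device=latents.device)

    for context_group in range(num_context_groups):
        current_context_start = context_group * stride
        window = latents[:, :, current_context_start:current_context_start + context_size]

        # wrap-around: pad the window with frames borrowed from the start of the clip
        wrapped = max(current_context_start + context_size - num_frames, 0)
        if wrapped > 0:
            window = torch.cat([window, latents[:, :, :wrapped]], dim=2)

        window = denoise_fn(window)  # hypothetical stand-in for the UNet + scheduler step

        if wrapped > 0:
            # route the borrowed frames back to the start of the clip, then drop them
            latent_sum[:, :, :wrapped] += window[:, :, -wrapped:]
            latent_counter[:wrapped] += 1
            window = window[:, :, :-wrapped]

        # accumulate the in-range frames; slicing past num_frames clips safely
        latent_sum[:, :, current_context_start:current_context_start + context_size] += window
        latent_counter[current_context_start:current_context_start + context_size] += 1

    # per-frame average over every window that covered the frame
    return latent_sum / latent_counter.view(1, 1, -1, 1, 1)

# toy usage: a 24-frame clip, 16-frame windows, 4-frame overlap, identity "denoiser"
latents = torch.randn(1, 4, 24, 8, 8)
out = average_overlapping_contexts(latents, context_size=16, overlap=4, denoise_fn=lambda w: w)
assert out.shape == latents.shape

With an identity denoise_fn the output equals the input, which makes the bookkeeping easy to sanity-check: every frame is divided by exactly the number of windows that covered it, so no frame is over- or under-weighted.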