smoothieAI committed on
Commit 00f8442 · verified · 1 Parent(s): 24453ad

Update pipeline.py

Files changed (1)
  1. pipeline.py +30 -27
pipeline.py CHANGED
@@ -1005,7 +1005,27 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin
         added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
 
         # divide the initial latents into context groups
-        num_context_groups = num_frames // (context_size - overlap)
+        def context_scheduler(context_size, overlap, num_frames, num_timesteps):
+            num_context_groups = (num_frames // (context_size - overlap)) + 1
+            context_indexes = []
+            for t in range(num_timesteps):
+                context_groups = []
+                for context_group_index in range(num_context_groups):
+                    context_group = []
+                    for i in range(context_size):
+                        # calculate the frame index
+                        frame_index = ((t + 1) * context_group_index * (context_size - overlap)) + i
+                        # wrap around at the end
+                        if frame_index >= num_frames:
+                            frame_index = frame_index % num_frames
+                        context_group.append(frame_index)
+                    context_groups.append(context_group)
+                context_indexes.append(context_groups)
+            return context_indexes
+
+        context_indexes = context_scheduler(context_size, overlap, num_frames, len(timesteps))
+
+        print(f"Context indexes: {context_indexes}")
 
         # Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
@@ -1013,24 +1033,16 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin
         for i, t in enumerate(timesteps):
             noise_pred_uncond_sum = torch.zeros_like(latents).to(device).to(dtype=torch.float16)
             noise_pred_text_sum = torch.zeros_like(latents).to(device).to(dtype=torch.float16)
-            # latent_sum = torch.zeros_like(latents).to(device).to(dtype=torch.float16)
             latent_counter = torch.zeros(num_frames).to(device).to(dtype=torch.float16)
 
             # for each context group, separately denoise the current timestep
-            for context_group in range(num_context_groups):
+            for context_group in range(len(context_indexes[i])):
                 # calculate the current indexes, considering overlap
-                if context_group == 0: current_context_start = 0
-                else: current_context_start = context_group * (context_size - overlap)
+                current_context_indexes = context_indexes[i][context_group]
 
                 # select the relevant context from the latents
-                current_context_latents = latents[:, :, current_context_start : current_context_start + context_size, :, :]
-
-                wrap_count = max(current_context_start + context_size - num_frames, 0)
-
-                # if context_start + context_size > num_frames: append the remaining frames from the start of the latents
-                if wrap_count > 0:
-                    current_context_latents = torch.cat([current_context_latents, latents[:, :, :wrap_count, :, :]], dim=2)
+                current_context_latents = latents[:, :, current_context_indexes, :, :]
 
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([current_context_latents] * 2) if do_classifier_free_guidance else current_context_latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -1047,18 +1059,12 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin
                 # sum the noise predictions for the unconditional and text conditioned noise
                 if do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    if wrap_count > 0:
-                        # add the ending frames from noise_pred_uncond to the start of the noise_pred_uncond_sum
-                        noise_pred_uncond_sum[:, :, 0:wrap_count, :, :] += noise_pred_uncond[:, :, -wrap_count:, :, :]
-                        noise_pred_text_sum[:, :, 0:wrap_count, :, :] += noise_pred_text[:, :, -wrap_count:, :, :]
-                        # increase the counter for the ending frames
-                        latent_counter[0:wrap_count] += 1
-                        # remove the ending frames from noise_pred_uncond
-                        noise_pred_uncond = noise_pred_uncond[:, :, :-wrap_count, :, :]
-                        noise_pred_text = noise_pred_text[:, :, :-wrap_count, :, :]
-                    noise_pred_uncond_sum[:, :, current_context_start : current_context_start + context_size, :, :] += noise_pred_uncond
-                    noise_pred_text_sum[:, :, current_context_start : current_context_start + context_size, :, :] += noise_pred_text
-                    latent_counter[current_context_start : current_context_start + context_size] += 1
+                    # add the noise predictions for this context group to the running sums
+                    noise_pred_uncond_sum[:, :, current_context_indexes, :, :] += noise_pred_uncond
+                    noise_pred_text_sum[:, :, current_context_indexes, :, :] += noise_pred_text
+                    # increase the counter for the frames in this context group
+                    latent_counter[current_context_indexes] += 1
 
                 # set the step index to the current batch
                 self.scheduler._step_index = i
@@ -1078,9 +1084,6 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin
             progress_bar.update()
            if callback is not None and i % callback_steps == 0:
                 callback(i, t, None)
-
-            # offset latent images by step places, wrapping around the last frames to the start
-            latents = torch.cat([latents[:, :, -step:, :, :], latents[:, :, :-step, :, :]], dim=2)
 
         if output_type == "latent":
            return AnimateDiffPipelineOutput(frames=latents)
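The new context_scheduler computes, for every timestep, a full set of frame-index windows: each window start advances by (t + 1) * context_group_index * (context_size - overlap), and any index past the end of the clip wraps back to the start. A small standalone run makes the pattern visible; the values num_frames=8, context_size=4, overlap=2 below are chosen only for illustration and are not pipeline defaults:

def context_scheduler(context_size, overlap, num_frames, num_timesteps):
    num_context_groups = (num_frames // (context_size - overlap)) + 1
    context_indexes = []
    for t in range(num_timesteps):
        context_groups = []
        for context_group_index in range(num_context_groups):
            # the window start shifts further at each timestep via the (t + 1) factor
            start = (t + 1) * context_group_index * (context_size - overlap)
            context_groups.append([(start + i) % num_frames for i in range(context_size)])
        context_indexes.append(context_groups)
    return context_indexes

windows = context_scheduler(context_size=4, overlap=2, num_frames=8, num_timesteps=2)
print(windows[0])  # [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 0, 1], [0, 1, 2, 3]]
print(windows[1])  # [[0, 1, 2, 3], [4, 5, 6, 7], [0, 1, 2, 3], [4, 5, 6, 7], [0, 1, 2, 3]]

Once a window start wraps past num_frames, later groups revisit frames that are already covered (the trailing [0, 1, 2, 3] at the first timestep), which is why the denoising loop accumulates into running sums and a per-frame counter instead of writing each window back directly.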
 
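Because each window is a plain Python list, latents[:, :, current_context_indexes, :, :] gathers the possibly wrapped, non-contiguous frames in one advanced-indexing step, and the same list scatters the per-window predictions back into the running sums. Below is a minimal sketch of that gather/scatter-and-average pattern, using toy shapes in place of the pipeline's [batch, channels, frames, height, width] tensors; the actual guidance combination happens further down in pipeline.py, outside this diff:

import torch

num_frames = 8
pred_sum = torch.zeros(1, 4, num_frames, 8, 8)     # running sum per frame
counter = torch.zeros(num_frames)                  # how many windows hit each frame

window = [6, 7, 0, 1]                              # a wrapped context group
noise_pred = torch.randn(1, 4, len(window), 8, 8)  # stand-in for a UNet prediction

pred_sum[:, :, window, :, :] += noise_pred         # scatter back by frame index
counter[window] += 1

# frames covered by several windows get averaged; reshape the counter so it
# broadcasts over the frame dimension, clamping so uncovered frames avoid 0-division
avg = pred_sum / counter.clamp(min=1).view(1, 1, -1, 1, 1)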
 
 
 
 
 
 
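The commit also deletes the manual frame rotation that ran at the end of every denoising step; with the (t + 1) factor inside context_scheduler shifting the windows each timestep, rotating the latents themselves appears to be no longer needed. For reference, the removed torch.cat idiom is a ring shift, equivalent to torch.roll along the frame dimension; the step value here is illustrative only:

import torch

x = torch.arange(6).reshape(1, 1, 6, 1, 1)  # toy latents with 6 frames
step = 2
# removed idiom: move the last `step` frames to the front
manual = torch.cat([x[:, :, -step:, :, :], x[:, :, :-step, :, :]], dim=2)
rolled = torch.roll(x, shifts=step, dims=2)
assert torch.equal(manual, rolled)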