Update pipeline.py
Browse files- pipeline.py +30 -27
pipeline.py
CHANGED
@@ -1005,7 +1005,27 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
|
1005 |
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
|
1006 |
|
1007 |
# divide the initial latents into context groups
|
1008 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1009 |
|
1010 |
# Denoising loop
|
1011 |
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
@@ -1013,24 +1033,16 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
|
1013 |
for i, t in enumerate(timesteps):
|
1014 |
noise_pred_uncond_sum = torch.zeros_like(latents).to(device).to(dtype=torch.float16)
|
1015 |
noise_pred_text_sum = torch.zeros_like(latents).to(device).to(dtype=torch.float16)
|
1016 |
-
# latent_sum = torch.zeros_like(latents).to(device).to(dtype=torch.float16)
|
1017 |
latent_counter = torch.zeros(num_frames).to(device).to(dtype=torch.float16)
|
1018 |
|
1019 |
# foreach context group seperately denoise the current timestep
|
1020 |
-
for context_group in range(
|
1021 |
# calculate to current indexes, considering overlap
|
1022 |
-
|
1023 |
-
else:current_context_start = context_group * (context_size - overlap)
|
1024 |
|
1025 |
# select the relevent context from the latents
|
1026 |
-
current_context_latents = latents[:, :,
|
1027 |
-
|
1028 |
-
wrap_count = max(current_context_start + context_size - num_frames, 0)
|
1029 |
|
1030 |
-
# if context_start + context_size > num_frames: append the remaining frames from the start of the latents
|
1031 |
-
if wrap_count > 0:
|
1032 |
-
current_context_latents = torch.cat([current_context_latents, latents[:, :, :wrap_count, :, :]], dim=2)
|
1033 |
-
|
1034 |
# expand the latents if we are doing classifier free guidance
|
1035 |
latent_model_input = torch.cat([current_context_latents] * 2) if do_classifier_free_guidance else current_context_latents
|
1036 |
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
@@ -1047,18 +1059,12 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
|
1047 |
# sum the noise predictions for the unconditional and text conditioned noise
|
1048 |
if do_classifier_free_guidance:
|
1049 |
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
1050 |
-
|
1051 |
-
|
1052 |
-
|
1053 |
-
|
1054 |
-
|
1055 |
-
|
1056 |
-
# remove the ending frames from noise_pred_uncond
|
1057 |
-
noise_pred_uncond = noise_pred_uncond[:, :, :-wrap_count, :, :]
|
1058 |
-
noise_pred_text = noise_pred_text[:, :, :-wrap_count, :, :]
|
1059 |
-
noise_pred_uncond_sum[:, :, current_context_start : current_context_start + context_size, :, :] += noise_pred_uncond
|
1060 |
-
noise_pred_text_sum[:, :, current_context_start : current_context_start + context_size, :, :] += noise_pred_text
|
1061 |
-
latent_counter[current_context_start : current_context_start + context_size] += 1
|
1062 |
|
1063 |
# set the step index to the current batch
|
1064 |
self.scheduler._step_index = i
|
@@ -1078,9 +1084,6 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
|
1078 |
progress_bar.update()
|
1079 |
if callback is not None and i % callback_steps == 0:
|
1080 |
callback(i, t, None)
|
1081 |
-
|
1082 |
-
# offset latent images by step places, wrapping around the last frames to the start
|
1083 |
-
latents = torch.cat([latents[:, :, -step:, :, :], latents[:, :, :-step, :, :]], dim=2)
|
1084 |
|
1085 |
if output_type == "latent":
|
1086 |
return AnimateDiffPipelineOutput(frames=latents)
|
|
|
1005 |
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
|
1006 |
|
1007 |
# divide the initial latents into context groups
|
1008 |
+
|
1009 |
+
def context_scheduler(context_size, overlap, offset, num_frames, num_timesteps):
|
1010 |
+
num_context_groups = (num_frames // (context_size-overlap))+1
|
1011 |
+
context_indexes = []
|
1012 |
+
for t in range(num_timesteps):
|
1013 |
+
context_groups = []
|
1014 |
+
for context_group_index in range(num_context_groups):
|
1015 |
+
context_group = []
|
1016 |
+
for i in range(context_size):
|
1017 |
+
# calculate the frame index
|
1018 |
+
frame_index = ((t+1) * context_group_index * (context_size-overlap)) + i
|
1019 |
+
# wrap around at the end
|
1020 |
+
if frame_index >= num_frames:frame_index = frame_index % num_frames
|
1021 |
+
context_group.append(frame_index)
|
1022 |
+
context_groups.append(context_groups)
|
1023 |
+
context_indexes.append(context_groups)
|
1024 |
+
return context_indexes
|
1025 |
+
|
1026 |
+
context_indexes = context_scheduler(context_size, overlap, num_frames, len(timesteps))
|
1027 |
+
|
1028 |
+
print(f"Context indexes: {context_indexes}")
|
1029 |
|
1030 |
# Denoising loop
|
1031 |
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
|
|
1033 |
for i, t in enumerate(timesteps):
|
1034 |
noise_pred_uncond_sum = torch.zeros_like(latents).to(device).to(dtype=torch.float16)
|
1035 |
noise_pred_text_sum = torch.zeros_like(latents).to(device).to(dtype=torch.float16)
|
|
|
1036 |
latent_counter = torch.zeros(num_frames).to(device).to(dtype=torch.float16)
|
1037 |
|
1038 |
# foreach context group seperately denoise the current timestep
|
1039 |
+
for context_group in range(len(context_indexes[i])):
|
1040 |
# calculate to current indexes, considering overlap
|
1041 |
+
current_context_indexes = context_indexes[i][context_group]
|
|
|
1042 |
|
1043 |
# select the relevent context from the latents
|
1044 |
+
current_context_latents = latents[:, :, current_context_indexes, :, :]
|
|
|
|
|
1045 |
|
|
|
|
|
|
|
|
|
1046 |
# expand the latents if we are doing classifier free guidance
|
1047 |
latent_model_input = torch.cat([current_context_latents] * 2) if do_classifier_free_guidance else current_context_latents
|
1048 |
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
|
|
1059 |
# sum the noise predictions for the unconditional and text conditioned noise
|
1060 |
if do_classifier_free_guidance:
|
1061 |
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
1062 |
+
|
1063 |
+
# add the ending frames from noise_pred_uncond to the start of the noise_pred_uncond_sum
|
1064 |
+
noise_pred_uncond_sum[:, :,current_context_indexes, :, :] += noise_pred_uncond
|
1065 |
+
noise_pred_text_sum[:, :,current_context_indexes, :, :] += noise_pred_text
|
1066 |
+
#increase the counter for the ending frames
|
1067 |
+
latent_counter[current_context_indexes] += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
1068 |
|
1069 |
# set the step index to the current batch
|
1070 |
self.scheduler._step_index = i
|
|
|
1084 |
progress_bar.update()
|
1085 |
if callback is not None and i % callback_steps == 0:
|
1086 |
callback(i, t, None)
|
|
|
|
|
|
|
1087 |
|
1088 |
if output_type == "latent":
|
1089 |
return AnimateDiffPipelineOutput(frames=latents)
|