Update pipeline.py
pipeline.py CHANGED (+67 -68)
@@ -1009,77 +1009,76 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin
# Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=len(timesteps)) as progress_bar:
    for i, t in enumerate(timesteps):
        latent_sum = torch.zeros_like(latents).to(device).to(dtype=torch.float16)
        latent_counter = torch.zeros(num_frames).to(device).to(dtype=torch.float16)

        # denoise the current timestep for each context group separately
        for context_group in range(num_context_groups):
            # calculate the current window start, accounting for overlap
            if context_group == 0:
                current_context_start = 0
            else:
                current_context_start = context_group * (context_size - overlap)

            # select the relevant context frames from the latents
            current_context_latents = latents[:, :, current_context_start : current_context_start + context_size, :, :]

            wrap_count = max(current_context_start + context_size - num_frames, 0)

            # if context_start + context_size > num_frames: append the remaining frames from the start of the latents
            if wrap_count > 0:
                current_context_latents = torch.cat([current_context_latents, latents[:, :, :wrap_count, :, :]], dim=2)

            # expand the latents if we are doing classifier-free guidance
            latent_model_input = torch.cat([current_context_latents] * 2) if do_classifier_free_guidance else current_context_latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
            ).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # pin the scheduler's step index to the current timestep, since
            # step() is called once per context group within the same timestep
            self.scheduler._step_index = i

            # compute the previous noisy sample x_t -> x_t-1
            current_context_latents = self.scheduler.step(noise_pred, t, current_context_latents, **extra_step_kwargs).prev_sample

            # if context_start + context_size > num_frames: remove the appended frames from the end of current_context_latents
            if wrap_count > 0:
                # add the ending frames from current_context_latents to the start of latent_sum
                latent_sum[:, :, 0:wrap_count, :, :] += current_context_latents[:, :, -wrap_count:, :, :]
                # increase the counter for the ending frames
                latent_counter[0:wrap_count] += 1
                # remove the ending frames from current_context_latents
                current_context_latents = current_context_latents[:, :, :-wrap_count, :, :]

            # add current_context_latents back into the latent sum, starting from the current context start
            latent_sum[:, :, current_context_start : current_context_start + context_size, :, :] += current_context_latents
            # add one to the counter for each frame in the context
            latent_counter[current_context_start : current_context_start + context_size] += 1

        # call the callback, if provided
        if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
            progress_bar.update()
            if callback is not None and i % callback_steps == 0:
                callback(i, t, None)

        # average the accumulated latents by how many windows covered each frame
        latent_counter = latent_counter.reshape(1, 1, num_frames, 1, 1)
        latents = latent_sum / latent_counter

        # rotate the latent frames by `step` places, wrapping the last `step` frames around to the start
        latents = torch.cat([latents[:, :, -step:, :, :], latents[:, :, :-step, :, :]], dim=2)
print("Done denoising")
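A note on the windowing in this hunk: each group's start index advances by context_size - overlap. A minimal standalone sketch of the index math, using hypothetical values for num_frames, context_size, and overlap; the diff does not show how num_context_groups is computed, so the ceil below is an assumption:

import math

# Hypothetical values for illustration only; the real pipeline takes these as arguments.
num_frames = 16
context_size = 6
overlap = 2
stride = context_size - overlap  # each window advances by 4 frames

# Assumption: enough windows to cover every frame at least once.
num_context_groups = math.ceil(num_frames / stride)

for context_group in range(num_context_groups):
    current_context_start = 0 if context_group == 0 else context_group * stride
    end = current_context_start + context_size
    wrap_count = max(end - num_frames, 0)
    print(f"group {context_group}: frames {current_context_start}..{min(end, num_frames) - 1}, wrap {wrap_count}")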
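The wrap-around branch (torch.cat of the window with the head of latents) selects the same frames as indexing modulo num_frames; a quick equivalence check on a toy tensor:

import torch

num_frames, context_size = 16, 6
# (batch, channels, frames, height, width) with toy sizes
latents = torch.randn(1, 4, num_frames, 8, 8)

current_context_start = 12  # the wrapping window from the sketch above
window = latents[:, :, current_context_start : current_context_start + context_size]
wrap_count = max(current_context_start + context_size - num_frames, 0)
if wrap_count > 0:
    window = torch.cat([window, latents[:, :, :wrap_count]], dim=2)

# the same frames, selected with indices taken modulo num_frames
idx = torch.arange(current_context_start, current_context_start + context_size) % num_frames
assert torch.equal(window, latents[:, :, idx])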
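Because overlapping windows are denoised independently and summed into latent_sum, dividing by latent_counter averages them. The sketch below checks, under the same assumed sizes as above, that every frame is covered at least once, so the division never hits zero:

import math
import torch

num_frames, context_size, overlap = 16, 6, 2
stride = context_size - overlap
num_context_groups = math.ceil(num_frames / stride)  # same assumption as above

latent_counter = torch.zeros(num_frames)
for g in range(num_context_groups):
    start = 0 if g == 0 else g * stride
    wrap = max(start + context_size - num_frames, 0)
    if wrap > 0:
        latent_counter[:wrap] += 1
    latent_counter[start : start + context_size] += 1  # slice clamps at num_frames

assert (latent_counter > 0).all()
print(latent_counter)  # frames inside an overlap have counts > 1 and get averaged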
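Finally, the "shuffle rotate" concatenation at the end of the loop is a circular shift along the frame axis, which torch.roll expresses directly; `step = 2` below is hypothetical, since `step` is defined outside this hunk:

import torch

step = 2  # hypothetical value; `step` comes from elsewhere in the pipeline
latents = torch.randn(1, 4, 16, 8, 8)

rotated_cat = torch.cat([latents[:, :, -step:, :, :], latents[:, :, :-step, :, :]], dim=2)
rotated_roll = torch.roll(latents, shifts=step, dims=2)
assert torch.equal(rotated_cat, rotated_roll)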