Update pipeline.py
Browse files- pipeline.py +27 -17
pipeline.py
CHANGED
@@ -1011,7 +1011,9 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
|
1011 |
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
1012 |
with self.progress_bar(total=len(timesteps)) as progress_bar:
|
1013 |
for i, t in enumerate(timesteps):
|
1014 |
-
|
|
|
|
|
1015 |
latent_counter = torch.zeros(num_frames).to(device).to(dtype=torch.float16)
|
1016 |
|
1017 |
# for each context group separately denoise the current timestep
|
@@ -1045,37 +1047,45 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
|
1045 |
# perform guidance
|
1046 |
if do_classifier_free_guidance:
|
1047 |
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
1048 |
-
|
1049 |
-
|
1050 |
|
1051 |
# set the step index to the current batch
|
1052 |
-
|
1053 |
-
|
1054 |
-
# compute the previous noisy sample x_t -> x_t-1
|
1055 |
-
current_context_latents = self.scheduler.step(noise_pred, t, current_context_latents, **extra_step_kwargs).prev_sample
|
1056 |
|
1057 |
# if context_start + context_size > num_frames: remove the appended frames from the end of the current_context_latents
|
1058 |
-
if wrap_count > 0:
|
1059 |
-
|
1060 |
-
|
1061 |
-
|
1062 |
-
|
1063 |
-
|
1064 |
-
|
1065 |
|
1066 |
# add the context current_context_latents back to the latent sum starting from the current context start
|
1067 |
-
latent_sum[:, :, current_context_start : current_context_start + context_size, :, :] += current_context_latents
|
|
|
1068 |
# add one to the counter for each timestep in the context
|
1069 |
latent_counter[current_context_start : current_context_start + context_size] += 1
|
1070 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1071 |
# call the callback, if provided
|
1072 |
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
1073 |
progress_bar.update()
|
1074 |
if callback is not None and i % callback_steps == 0:
|
1075 |
callback(i, t, None)
|
1076 |
|
1077 |
-
|
1078 |
-
latents = latent_sum / latent_counter
|
1079 |
|
1080 |
# shuffle rotate latent images by step places, wrapping around the last 2 to the start
|
1081 |
latents = torch.cat([latents[:, :, -step:, :, :], latents[:, :, :-step, :, :]], dim=2)
|
|
|
1011 |
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
1012 |
with self.progress_bar(total=len(timesteps)) as progress_bar:
|
1013 |
for i, t in enumerate(timesteps):
|
1014 |
+
noise_pred_uncond_sum = torch.zeros_like(latents).to(device).to(dtype=torch.float16)
|
1015 |
+
noise_pred_text_sum = torch.zeros_like(latents).to(device).to(dtype=torch.float16)
|
1016 |
+
# latent_sum = torch.zeros_like(latents).to(device).to(dtype=torch.float16)
|
1017 |
latent_counter = torch.zeros(num_frames).to(device).to(dtype=torch.float16)
|
1018 |
|
1019 |
# for each context group separately denoise the current timestep
|
|
|
1047 |
# perform guidance
|
1048 |
if do_classifier_free_guidance:
|
1049 |
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
1050 |
+
noise_pred_uncond_sum += noise_pred_uncond
|
1051 |
+
noise_pred_text_sum += noise_pred_text
|
1052 |
|
1053 |
# set the step index to the current batch
|
1054 |
+
self.scheduler._step_index = i
|
|
|
|
|
|
|
1055 |
|
1056 |
# if context_start + context_size > num_frames: remove the appended frames from the end of the current_context_latents
|
1057 |
+
# if wrap_count > 0:
|
1058 |
+
# # add the ending frames from current_context_latents to the start of the latent_sum
|
1059 |
+
# latent_sum[:, :, 0:wrap_count, :, :] += current_context_latents[:, :, -wrap_count:, :, :]
|
1060 |
+
# # increase the counter for the ending frames
|
1061 |
+
# latent_counter[0:wrap_count] += 1
|
1062 |
+
# # remove the ending frames from current_context_latents
|
1063 |
+
# current_context_latents = current_context_latents[:, :, :-wrap_count, :, :]
|
1064 |
|
1065 |
# add the context current_context_latents back to the latent sum starting from the current context start
|
1066 |
+
# latent_sum[:, :, current_context_start : current_context_start + context_size, :, :] += current_context_latents
|
1067 |
+
|
1068 |
# add one to the counter for each timestep in the context
|
1069 |
latent_counter[current_context_start : current_context_start + context_size] += 1
|
1070 |
|
1071 |
+
# perform guidance
|
1072 |
+
if do_classifier_free_guidance:
|
1073 |
+
latent_counter = latent_counter.reshape(1, 1, num_frames, 1, 1)
|
1074 |
+
noise_pred_uncond = noise_pred_uncond_sum / latent_counter
|
1075 |
+
noise_pred_text = noise_pred_text_sum / latent_counter
|
1076 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
1077 |
+
|
1078 |
+
# compute the previous noisy sample x_t -> x_t-1
|
1079 |
+
current_context_latents = self.scheduler.step(noise_pred, t, current_context_latents, **extra_step_kwargs).prev_sample
|
1080 |
+
|
1081 |
# call the callback, if provided
|
1082 |
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
1083 |
progress_bar.update()
|
1084 |
if callback is not None and i % callback_steps == 0:
|
1085 |
callback(i, t, None)
|
1086 |
|
1087 |
+
|
1088 |
+
# latents = latent_sum / latent_counter
|
1089 |
|
1090 |
# shuffle rotate latent images by step places, wrapping around the last 2 to the start
|
1091 |
latents = torch.cat([latents[:, :, -step:, :, :], latents[:, :, :-step, :, :]], dim=2)
|