smoothieAI committed on
Commit
3f6b453
·
verified ·
1 Parent(s): a037fba

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +27 -17
pipeline.py CHANGED
@@ -1011,7 +1011,9 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
1011
  num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1012
  with self.progress_bar(total=len(timesteps)) as progress_bar:
1013
  for i, t in enumerate(timesteps):
1014
- latent_sum = torch.zeros_like(latents).to(device).to(dtype=torch.float16)
 
 
1015
  latent_counter = torch.zeros(num_frames).to(device).to(dtype=torch.float16)
1016
 
1017
  # foreach context group separately denoise the current timestep
@@ -1045,37 +1047,45 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
1045
  # perform guidance
1046
  if do_classifier_free_guidance:
1047
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1048
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1049
-
1050
 
1051
  # set the step index to the current batch
1052
- # self.scheduler._step_index = i
1053
-
1054
- # compute the previous noisy sample x_t -> x_t-1
1055
- current_context_latents = self.scheduler.step(noise_pred, t, current_context_latents, **extra_step_kwargs).prev_sample
1056
 
1057
  # if context_start + context_size > num_frames: remove the appended frames from the end of the current_context_latents
1058
- if wrap_count > 0:
1059
- # add the ending frames from current_context_latents to the start of the latent_sum
1060
- latent_sum[:, :, 0:wrap_count, :, :] += current_context_latents[:, :, -wrap_count:, :, :]
1061
- # increase the counter for the ending frames
1062
- latent_counter[0:wrap_count] += 1
1063
- # remove the ending frames from current_context_latents
1064
- current_context_latents = current_context_latents[:, :, :-wrap_count, :, :]
1065
 
1066
  #add the context current_context_latents back to the latent sum starting from the current context start
1067
- latent_sum[:, :, current_context_start : current_context_start + context_size, :, :] += current_context_latents
 
1068
  # add one to the counter for each timestep in the context
1069
  latent_counter[current_context_start : current_context_start + context_size] += 1
1070
 
 
 
 
 
 
 
 
 
 
 
1071
  # call the callback, if provided
1072
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1073
  progress_bar.update()
1074
  if callback is not None and i % callback_steps == 0:
1075
  callback(i, t, None)
1076
 
1077
- latent_counter = latent_counter.reshape(1, 1, num_frames, 1, 1)
1078
- latents = latent_sum / latent_counter
1079
 
1080
  # shuffle rotate latent images by step places, wrapping around the last 2 to the start
1081
  latents = torch.cat([latents[:, :, -step:, :, :], latents[:, :, :-step, :, :]], dim=2)
 
1011
  num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1012
  with self.progress_bar(total=len(timesteps)) as progress_bar:
1013
  for i, t in enumerate(timesteps):
1014
+ noise_pred_uncond_sum = torch.zeros_like(latents).to(device).to(dtype=torch.float16)
1015
+ noise_pred_text_sum = torch.zeros_like(latents).to(device).to(dtype=torch.float16)
1016
+ # latent_sum = torch.zeros_like(latents).to(device).to(dtype=torch.float16)
1017
  latent_counter = torch.zeros(num_frames).to(device).to(dtype=torch.float16)
1018
 
1019
  # foreach context group separately denoise the current timestep
 
1047
  # perform guidance
1048
  if do_classifier_free_guidance:
1049
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1050
+ noise_pred_uncond_sum += noise_pred_uncond
1051
+ noise_pred_text_sum += noise_pred_text
1052
 
1053
  # set the step index to the current batch
1054
+ self.scheduler._step_index = i
 
 
 
1055
 
1056
  # if context_start + context_size > num_frames: remove the appended frames from the end of the current_context_latents
1057
+ # if wrap_count > 0:
1058
+ # # add the ending frames from current_context_latents to the start of the latent_sum
1059
+ # latent_sum[:, :, 0:wrap_count, :, :] += current_context_latents[:, :, -wrap_count:, :, :]
1060
+ # # increase the counter for the ending frames
1061
+ # latent_counter[0:wrap_count] += 1
1062
+ # # remove the ending frames from current_context_latents
1063
+ # current_context_latents = current_context_latents[:, :, :-wrap_count, :, :]
1064
 
1065
  #add the context current_context_latents back to the latent sum starting from the current context start
1066
+ # latent_sum[:, :, current_context_start : current_context_start + context_size, :, :] += current_context_latents
1067
+
1068
  # add one to the counter for each timestep in the context
1069
  latent_counter[current_context_start : current_context_start + context_size] += 1
1070
 
1071
+ # perform guidance
1072
+ if do_classifier_free_guidance:
1073
+ latent_counter = latent_counter.reshape(1, 1, num_frames, 1, 1)
1074
+ noise_pred_uncond = noise_pred_uncond_sum / latent_counter
1075
+ noise_pred_text = noise_pred_text_sum / latent_counter
1076
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1077
+
1078
+ # compute the previous noisy sample x_t -> x_t-1
1079
+ current_context_latents = self.scheduler.step(noise_pred, t, current_context_latents, **extra_step_kwargs).prev_sample
1080
+
1081
  # call the callback, if provided
1082
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1083
  progress_bar.update()
1084
  if callback is not None and i % callback_steps == 0:
1085
  callback(i, t, None)
1086
 
1087
+
1088
+ # latents = latent_sum / latent_counter
1089
 
1090
  # shuffle rotate latent images by step places, wrapping around the last 2 to the start
1091
  latents = torch.cat([latents[:, :, -step:, :, :], latents[:, :, :-step, :, :]], dim=2)