smoothieAI committed
Commit adb42e1 · verified · 1 Parent(s): 5aaa10b

Update pipeline.py

Files changed (1)
  1. pipeline.py +67 -68
pipeline.py CHANGED
@@ -1009,77 +1009,76 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
 
         # Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-        if False:
-            with self.progress_bar(total=len(timesteps)) as progress_bar:
-                for i, t in enumerate(timesteps):
-                    latent_sum = torch.zeros_like(latents).to(device).to(dtype=torch.float16)
-                    latent_counter = torch.zeros(num_frames).to(device).to(dtype=torch.float16)
-
-                    # for each context group, separately denoise the current timestep
-                    for context_group in range(num_context_groups):
-                        # calculate the current indices, considering overlap
-                        if context_group == 0: current_context_start = 0
-                        else: current_context_start = context_group * (context_size - overlap)
-
-                        # select the relevant context from the latents
-                        current_context_latents = latents[:, :, current_context_start : current_context_start + context_size, :, :]
-
-                        wrap_count = max(current_context_start + context_size - num_frames, 0)
-
-                        # if current_context_start + context_size > num_frames: append the remaining frames from the start of the latents
-                        if wrap_count > 0:
-                            current_context_latents = torch.cat([current_context_latents, latents[:, :, :wrap_count, :, :]], dim=2)
-
-                        # expand the latents if we are doing classifier-free guidance
-                        latent_model_input = torch.cat([current_context_latents] * 2) if do_classifier_free_guidance else current_context_latents
-                        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-                        # predict the noise residual
-                        noise_pred = self.unet(
-                            latent_model_input,
-                            t,
-                            encoder_hidden_states=prompt_embeds,
-                            cross_attention_kwargs=cross_attention_kwargs,
-                            added_cond_kwargs=added_cond_kwargs,
-                        ).sample
-
-                        # perform guidance
-                        if do_classifier_free_guidance:
-                            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                        # set the step index to the current batch
-                        self.scheduler._step_index = i
-
-                        # compute the previous noisy sample x_t -> x_t-1
-                        current_context_latents = self.scheduler.step(noise_pred, t, current_context_latents, **extra_step_kwargs).prev_sample
-
-                        # if current_context_start + context_size > num_frames: remove the appended frames from the end of current_context_latents
-                        if wrap_count > 0:
-                            # add the ending frames from current_context_latents to the start of latent_sum
-                            latent_sum[:, :, 0:wrap_count, :, :] += current_context_latents[:, :, -wrap_count:, :, :]
-                            # increase the counter for the ending frames
-                            latent_counter[0:wrap_count] += 1
-                            # remove the ending frames from current_context_latents
-                            current_context_latents = current_context_latents[:, :, :-wrap_count, :, :]
-
-                        # add current_context_latents back to the latent sum, starting from the current context start
-                        latent_sum[:, :, current_context_start : current_context_start + context_size, :, :] += current_context_latents
-                        # add one to the counter for each frame in the context
-                        latent_counter[current_context_start : current_context_start + context_size] += 1
-
-                    # call the callback, if provided
-                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                        progress_bar.update()
-                        if callback is not None and i % callback_steps == 0:
-                            callback(i, t, None)
-
-                    latent_counter = latent_counter.reshape(1, 1, num_frames, 1, 1)
-                    latents = latent_sum / latent_counter
-
-                    # rotate the latent frames by `step` places, wrapping the trailing frames around to the start
-                    latents = torch.cat([latents[:, :, -step:, :, :], latents[:, :, :-step, :, :]], dim=2)
+        with self.progress_bar(total=len(timesteps)) as progress_bar:
+            for i, t in enumerate(timesteps):
+                latent_sum = torch.zeros_like(latents).to(device).to(dtype=torch.float16)
+                latent_counter = torch.zeros(num_frames).to(device).to(dtype=torch.float16)
+
+                # for each context group, separately denoise the current timestep
+                for context_group in range(num_context_groups):
+                    # calculate the current indices, considering overlap
+                    if context_group == 0: current_context_start = 0
+                    else: current_context_start = context_group * (context_size - overlap)
+
+                    # select the relevant context from the latents
+                    current_context_latents = latents[:, :, current_context_start : current_context_start + context_size, :, :]
+
+                    wrap_count = max(current_context_start + context_size - num_frames, 0)
+
+                    # if current_context_start + context_size > num_frames: append the remaining frames from the start of the latents
+                    if wrap_count > 0:
+                        current_context_latents = torch.cat([current_context_latents, latents[:, :, :wrap_count, :, :]], dim=2)
+
+                    # expand the latents if we are doing classifier-free guidance
+                    latent_model_input = torch.cat([current_context_latents] * 2) if do_classifier_free_guidance else current_context_latents
+                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                    # predict the noise residual
+                    noise_pred = self.unet(
+                        latent_model_input,
+                        t,
+                        encoder_hidden_states=prompt_embeds,
+                        cross_attention_kwargs=cross_attention_kwargs,
+                        added_cond_kwargs=added_cond_kwargs,
+                    ).sample
+
+                    # perform guidance
+                    if do_classifier_free_guidance:
+                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                    # set the step index to the current batch
+                    self.scheduler._step_index = i
+
+                    # compute the previous noisy sample x_t -> x_t-1
+                    current_context_latents = self.scheduler.step(noise_pred, t, current_context_latents, **extra_step_kwargs).prev_sample
+
+                    # if current_context_start + context_size > num_frames: remove the appended frames from the end of current_context_latents
+                    if wrap_count > 0:
+                        # add the ending frames from current_context_latents to the start of latent_sum
+                        latent_sum[:, :, 0:wrap_count, :, :] += current_context_latents[:, :, -wrap_count:, :, :]
+                        # increase the counter for the ending frames
+                        latent_counter[0:wrap_count] += 1
+                        # remove the ending frames from current_context_latents
+                        current_context_latents = current_context_latents[:, :, :-wrap_count, :, :]
+
+                    # add current_context_latents back to the latent sum, starting from the current context start
+                    latent_sum[:, :, current_context_start : current_context_start + context_size, :, :] += current_context_latents
+                    # add one to the counter for each frame in the context
+                    latent_counter[current_context_start : current_context_start + context_size] += 1
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        callback(i, t, None)
+
+                latent_counter = latent_counter.reshape(1, 1, num_frames, 1, 1)
+                latents = latent_sum / latent_counter
+
+                # rotate the latent frames by `step` places, wrapping the trailing frames around to the start
+                latents = torch.cat([latents[:, :, -step:, :, :], latents[:, :, :-step, :, :]], dim=2)
 
         print("Done denoising")
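
The change above deletes the `if False:` guard and dedents the body one level, re-enabling the context-windowed denoising loop; the loop logic itself is unchanged. For reference, here is a minimal, self-contained sketch of what one timestep of that loop computes. It is not the pipeline's code: `denoise_window` is a hypothetical stand-in for the UNet prediction plus scheduler step, and `denoise_timestep` and its parameters are illustrative names.

import torch

def denoise_window(window: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for one UNet noise prediction + scheduler step.
    return window * 0.5

def denoise_timestep(latents: torch.Tensor, context_size: int, overlap: int) -> torch.Tensor:
    # latents: (batch, channels, frames, height, width); assumes 0 <= overlap < context_size
    num_frames = latents.shape[2]
    stride = context_size - overlap
    num_context_groups = -(-num_frames // stride)  # ceil division

    latent_sum = torch.zeros_like(latents)
    latent_counter = torch.zeros(num_frames)

    for group in range(num_context_groups):
        start = group * stride
        window = latents[:, :, start : start + context_size]

        # If the window runs past the last frame, wrap it around to the
        # start of the clip so every window sees a full context_size frames.
        wrap_count = max(start + context_size - num_frames, 0)
        if wrap_count > 0:
            window = torch.cat([window, latents[:, :, :wrap_count]], dim=2)

        window = denoise_window(window)

        # Accumulate the wrapped frames at the head of the clip, then trim them.
        if wrap_count > 0:
            latent_sum[:, :, :wrap_count] += window[:, :, -wrap_count:]
            latent_counter[:wrap_count] += 1
            window = window[:, :, :-wrap_count]

        # Accumulate the (possibly trimmed) window at its own position.
        latent_sum[:, :, start : start + window.shape[2]] += window
        latent_counter[start : start + window.shape[2]] += 1

    # Average the overlapping predictions frame by frame.
    return latent_sum / latent_counter.reshape(1, 1, num_frames, 1, 1)

latents = torch.randn(1, 4, 16, 8, 8)
print(denoise_timestep(latents, context_size=8, overlap=4).shape)  # torch.Size([1, 4, 16, 8, 8])

The per-frame division by the counter is what blends overlapping windows into one consistent latent sequence; it plays the same role as `latents = latent_sum / latent_counter` in the loop above.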
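
The loop ends each timestep by rotating the latent frames by `step` places, presumably so that window boundaries fall on different frames from one timestep to the next. A tiny demonstration of that rotation (again a sketch; `step` is assumed to be a small positive shift, and `torch.roll` is an equivalent built-in):

import torch

def rotate_frames(latents: torch.Tensor, step: int) -> torch.Tensor:
    # Move the last `step` frames to the front, shifting the rest right.
    return torch.cat([latents[:, :, -step:], latents[:, :, :-step]], dim=2)

frames = torch.arange(6).reshape(1, 1, 6, 1, 1)
print(rotate_frames(frames, 2).flatten().tolist())              # [4, 5, 0, 1, 2, 3]
print(torch.roll(frames, shifts=2, dims=2).flatten().tolist())  # [4, 5, 0, 1, 2, 3]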