smoothieAI committed (verified)
Commit 5aaa10b · 1 Parent(s): a6f5535

Update pipeline.py

Files changed (1)
  1. pipeline.py +68 -67
pipeline.py CHANGED
@@ -1009,76 +1009,77 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
 
         # Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-        with self.progress_bar(total=len(timesteps)) as progress_bar:
-            for i, t in enumerate(timesteps):
-                latent_sum = torch.zeros_like(latents).to(device).to(dtype=torch.float16)
-                latent_counter = torch.zeros(num_frames).to(device).to(dtype=torch.float16)
-
-                # for each context group, separately denoise the current timestep
-                for context_group in range(num_context_groups):
-                    # calculate the current start index, accounting for overlap
-                    if context_group == 0:
-                        current_context_start = 0
-                    else:
-                        current_context_start = context_group * (context_size - overlap)
-
-                    # select the relevant context from the latents
-                    current_context_latents = latents[:, :, current_context_start : current_context_start + context_size, :, :]
-
-                    wrap_count = max(current_context_start + context_size - num_frames, 0)
-
-                    # if context_start + context_size > num_frames: append the remaining frames from the start of the latents
-                    if wrap_count > 0:
-                        current_context_latents = torch.cat([current_context_latents, latents[:, :, :wrap_count, :, :]], dim=2)
-
-                    # expand the latents if we are doing classifier-free guidance
-                    latent_model_input = torch.cat([current_context_latents] * 2) if do_classifier_free_guidance else current_context_latents
-                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-                    # predict the noise residual
-                    noise_pred = self.unet(
-                        latent_model_input,
-                        t,
-                        encoder_hidden_states=prompt_embeds,
-                        cross_attention_kwargs=cross_attention_kwargs,
-                        added_cond_kwargs=added_cond_kwargs,
-                    ).sample
-
-                    # perform guidance
-                    if do_classifier_free_guidance:
-                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                    # set the step index to the current batch
-                    self.scheduler._step_index = i
-
-                    # compute the previous noisy sample x_t -> x_t-1
-                    current_context_latents = self.scheduler.step(noise_pred, t, current_context_latents, **extra_step_kwargs).prev_sample
-
-                    # if the context wrapped past num_frames: move the appended frames back to the start
-                    if wrap_count > 0:
-                        # add the ending frames from current_context_latents to the start of the latent_sum
-                        latent_sum[:, :, 0:wrap_count, :, :] += current_context_latents[:, :, -wrap_count:, :, :]
-                        # increase the counter for the ending frames
-                        latent_counter[0:wrap_count] += 1
-                        # remove the ending frames from current_context_latents
-                        current_context_latents = current_context_latents[:, :, :-wrap_count, :, :]
-
-                    # add current_context_latents back into the latent sum, starting from the current context start
-                    latent_sum[:, :, current_context_start : current_context_start + context_size, :, :] += current_context_latents
-                    # add one to the counter for each frame in the context
-                    latent_counter[current_context_start : current_context_start + context_size] += 1
-
-                # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
-                    if callback is not None and i % callback_steps == 0:
-                        callback(i, t, None)
-
-                latent_counter = latent_counter.reshape(1, 1, num_frames, 1, 1)
-                latents = latent_sum / latent_counter
-
-                # rotate the latent frames by `step` places, wrapping the last `step` frames around to the start
-                latents = torch.cat([latents[:, :, -step:, :, :], latents[:, :, :-step, :, :]], dim=2)
+        if False:
+            with self.progress_bar(total=len(timesteps)) as progress_bar:
+                for i, t in enumerate(timesteps):
+                    latent_sum = torch.zeros_like(latents).to(device).to(dtype=torch.float16)
+                    latent_counter = torch.zeros(num_frames).to(device).to(dtype=torch.float16)
+
+                    # for each context group, separately denoise the current timestep
+                    for context_group in range(num_context_groups):
+                        # calculate the current start index, accounting for overlap
+                        if context_group == 0:
+                            current_context_start = 0
+                        else:
+                            current_context_start = context_group * (context_size - overlap)
+
+                        # select the relevant context from the latents
+                        current_context_latents = latents[:, :, current_context_start : current_context_start + context_size, :, :]
+
+                        wrap_count = max(current_context_start + context_size - num_frames, 0)
+
+                        # if context_start + context_size > num_frames: append the remaining frames from the start of the latents
+                        if wrap_count > 0:
+                            current_context_latents = torch.cat([current_context_latents, latents[:, :, :wrap_count, :, :]], dim=2)
+
+                        # expand the latents if we are doing classifier-free guidance
+                        latent_model_input = torch.cat([current_context_latents] * 2) if do_classifier_free_guidance else current_context_latents
+                        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                        # predict the noise residual
+                        noise_pred = self.unet(
+                            latent_model_input,
+                            t,
+                            encoder_hidden_states=prompt_embeds,
+                            cross_attention_kwargs=cross_attention_kwargs,
+                            added_cond_kwargs=added_cond_kwargs,
+                        ).sample
+
+                        # perform guidance
+                        if do_classifier_free_guidance:
+                            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                        # set the step index to the current batch
+                        self.scheduler._step_index = i
+
+                        # compute the previous noisy sample x_t -> x_t-1
+                        current_context_latents = self.scheduler.step(noise_pred, t, current_context_latents, **extra_step_kwargs).prev_sample
+
+                        # if the context wrapped past num_frames: move the appended frames back to the start
+                        if wrap_count > 0:
+                            # add the ending frames from current_context_latents to the start of the latent_sum
+                            latent_sum[:, :, 0:wrap_count, :, :] += current_context_latents[:, :, -wrap_count:, :, :]
+                            # increase the counter for the ending frames
+                            latent_counter[0:wrap_count] += 1
+                            # remove the ending frames from current_context_latents
+                            current_context_latents = current_context_latents[:, :, :-wrap_count, :, :]
+
+                        # add current_context_latents back into the latent sum, starting from the current context start
+                        latent_sum[:, :, current_context_start : current_context_start + context_size, :, :] += current_context_latents
+                        # add one to the counter for each frame in the context
+                        latent_counter[current_context_start : current_context_start + context_size] += 1
+
+                    # call the callback, if provided
+                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                        progress_bar.update()
+                        if callback is not None and i % callback_steps == 0:
+                            callback(i, t, None)
+
+                    latent_counter = latent_counter.reshape(1, 1, num_frames, 1, 1)
+                    latents = latent_sum / latent_counter
+
+                    # rotate the latent frames by `step` places, wrapping the last `step` frames around to the start
+                    latents = torch.cat([latents[:, :, -step:, :, :], latents[:, :, :-step, :, :]], dim=2)
 
         print("Done denoising")
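
The loop this commit disables implements sliding-window ("context group") denoising: each timestep is denoised in overlapping windows of frames, windows that run past the end of the clip wrap around to the start, per-frame predictions are averaged by a counter, and the frames are then rotated by `step` places between timesteps. Below is a minimal, self-contained sketch of that bookkeeping with an identity function standing in for the UNet/scheduler step; the helper name `average_context_windows`, the ceil-divide used for the number of windows, and the example sizes are illustrative assumptions, not part of the pipeline.

import torch

def average_context_windows(latents: torch.Tensor, context_size: int, overlap: int) -> torch.Tensor:
    """One timestep of windowed denoising over frames (dim 2 of a B,C,F,H,W tensor)."""
    num_frames = latents.shape[2]
    stride = context_size - overlap
    num_context_groups = -(-num_frames // stride)  # assumed ceil-divide; the pipeline computes this elsewhere

    latent_sum = torch.zeros_like(latents)
    latent_counter = torch.zeros(num_frames, dtype=latents.dtype, device=latents.device)

    for group in range(num_context_groups):
        start = group * stride
        window = latents[:, :, start : start + context_size]

        # a window that runs past the last frame wraps around to frame 0
        wrap_count = max(start + context_size - num_frames, 0)
        if wrap_count > 0:
            window = torch.cat([window, latents[:, :, :wrap_count]], dim=2)

        denoised = window  # identity stand-in for the self.unet(...) / self.scheduler.step(...) pair

        if wrap_count > 0:
            # route the wrapped tail back to the first frames, then trim it
            latent_sum[:, :, :wrap_count] += denoised[:, :, -wrap_count:]
            latent_counter[:wrap_count] += 1
            denoised = denoised[:, :, :-wrap_count]

        # slicing clamps at num_frames, so the slice width matches `denoised`
        latent_sum[:, :, start : start + context_size] += denoised
        latent_counter[start : start + context_size] += 1

    # average every frame by the number of windows that covered it
    return latent_sum / latent_counter.reshape(1, 1, num_frames, 1, 1)

latents = torch.randn(1, 4, 16, 8, 8)
out = average_context_windows(latents, context_size=8, overlap=2)
assert torch.allclose(out, latents)  # identity "denoiser" leaves frames unchanged

# per-timestep rotation, as in the loop above (the step value here is illustrative)
step = 2
out = torch.cat([out[:, :, -step:], out[:, :, :-step]], dim=2)

Because the counter is built from the same windows that feed the sum, every frame is divided by exactly the number of windows that covered it, so overlapping and wrapped frames blend instead of doubling up; the rotation then changes which frames sit at window boundaries on the next timestep, which together with the wrap-around appears intended to keep the looped clip seamless.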