smoothieAI committed on
Commit
d51def2
·
verified ·
1 Parent(s): c05212b

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +32 -11
pipeline.py CHANGED
@@ -812,7 +812,19 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
812
  latents = latents.to(device)
813
  return latents, init_latents
814
 
815
-
 
 
 
 
 
 
 
 
 
 
 
 
816
  @torch.no_grad()
817
  # @replace_example_docstring(EXAMPLE_DOC_STRING)
818
  def __call__(
@@ -1005,16 +1017,25 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
1005
 
1006
  if self.controlnet != None:
1007
  if isinstance(controlnet, ControlNetModel):
1008
- conditioning_frames = self.prepare_image(
1009
- image=conditioning_frames,
1010
- width=width,
1011
- height=height,
1012
- batch_size=batch_size * num_videos_per_prompt * num_frames,
1013
- num_images_per_prompt=num_videos_per_prompt,
1014
- device=device,
1015
- dtype=controlnet.dtype,
1016
- do_classifier_free_guidance=self.do_classifier_free_guidance,
1017
- guess_mode=guess_mode,
 
 
 
 
 
 
 
 
 
1018
  )
1019
  elif isinstance(controlnet, MultiControlNetModel):
1020
  cond_prepared_frames = []
 
812
  latents = latents.to(device)
813
  return latents, init_latents
814
 
815
def prepare_control_latents(self, batch_size, control_frames, num_channels_latents, num_frames, height, width, dtype, device):
    """Allocate an all-zero latents tensor for ControlNet conditioning frames.

    Note: ``batch_size`` and ``control_frames`` are currently UNUSED — the
    returned tensor is a zero placeholder and does not encode the input
    frames. ``control_frames`` fixes the ``contorl_frames`` typo in the
    original signature; the visible call site passes arguments positionally,
    so the rename is safe for it.

    NOTE(review): the visible call site passes only 7 positional arguments
    (starting with ``num_frames`` in the ``batch_size`` slot) to this
    8-parameter method, which would raise a TypeError — confirm and align
    the caller.

    Args:
        batch_size: unused placeholder.
        control_frames: unused placeholder for the input control frames.
        num_channels_latents: channel dimension of the returned latents.
        num_frames: number of video frames (leading dimension).
        height: frame height in pixels; divided by ``self.vae_scale_factor``.
        width: frame width in pixels; divided by ``self.vae_scale_factor``.
        dtype: dtype of the returned tensor.
        device: device of the returned tensor.

    Returns:
        ``torch.Tensor`` of zeros with shape
        ``(num_frames, num_channels_latents, height // self.vae_scale_factor,
        width // self.vae_scale_factor)``.
    """
    shape = (
        num_frames,
        num_channels_latents,
        height // self.vae_scale_factor,
        width // self.vae_scale_factor,
    )

    # Placeholder: zeros rather than VAE-encoded control frames — the
    # comment in the original ("convert input control image array to
    # latents") describes intent, not current behavior.
    latents = torch.zeros(shape, dtype=dtype, device=device)

    return latents
827
+
828
  @torch.no_grad()
829
  # @replace_example_docstring(EXAMPLE_DOC_STRING)
830
  def __call__(
 
1017
 
1018
  if self.controlnet != None:
1019
  if isinstance(controlnet, ControlNetModel):
1020
+ # conditioning_frames = self.prepare_image(
1021
+ # image=conditioning_frames,
1022
+ # width=width,
1023
+ # height=height,
1024
+ # batch_size=batch_size * num_videos_per_prompt * num_frames,
1025
+ # num_images_per_prompt=num_videos_per_prompt,
1026
+ # device=device,
1027
+ # dtype=controlnet.dtype,
1028
+ # do_classifier_free_guidance=self.do_classifier_free_guidance,
1029
+ # guess_mode=guess_mode,
1030
+ # )
1031
+ conditioning_frames = self.prepare_control_latents(
1032
+ num_frames,
1033
+ conditioning_frames,
1034
+ num_channels_latents,
1035
+ height,
1036
+ width,
1037
+ prompt_embeds.dtype,
1038
+ device,
1039
  )
1040
  elif isinstance(controlnet, MultiControlNetModel):
1041
  cond_prepared_frames = []