smoothieAI committed
Commit b04efd0 · verified · 1 parent: e3b7e33

Update pipeline.py

Files changed (1)
pipeline.py +56 -10
pipeline.py CHANGED
@@ -825,6 +825,37 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
 
         return latents
 
+    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
+    def prepare_control_frames(
+        self,
+        image,
+        width,
+        height,
+        batch_size,
+        num_images_per_prompt,
+        device,
+        dtype,
+        do_classifier_free_guidance=False,
+        guess_mode=False,
+    ):
+        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+        image_batch_size = image.shape[0]
+
+        if image_batch_size == 1:
+            repeat_by = batch_size
+        else:
+            # image batch size is the same as prompt batch size
+            repeat_by = num_images_per_prompt
+
+        image = image.repeat_interleave(repeat_by, dim=0)
+
+        image = image.to(device=device, dtype=dtype)
+
+        if do_classifier_free_guidance and not guess_mode:
+            image = torch.cat([image] * 2)
+
+        return image
+
     @torch.no_grad()
     # @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -1028,19 +1059,34 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                 # do_classifier_free_guidance=self.do_classifier_free_guidance,
                 # guess_mode=guess_mode,
                 # )
-                conditioning_frames = self.prepare_control_latents(
-                    num_frames,
-                    conditioning_frames,
-                    num_channels_latents,
-                    height,
-                    width,
-                    prompt_embeds.dtype,
-                    device,
+                conditioning_frames = self.prepare_control_frames(
+                    image=frame_,
+                    width=width,
+                    height=height,
+                    batch_size=batch_size * num_videos_per_prompt * num_frames,
+                    num_images_per_prompt=num_videos_per_prompt,
+                    device=device,
+                    dtype=controlnet.dtype,
+                    do_classifier_free_guidance=self.do_classifier_free_guidance,
+                    guess_mode=guess_mode,
                 )
+
             elif isinstance(controlnet, MultiControlNetModel):
                 cond_prepared_frames = []
                 for frame_ in conditioning_frames:
-                    prepared_frame = self.prepare_image(
+                    # prepared_frame = self.prepare_image(
+                    #     image=frame_,
+                    #     width=width,
+                    #     height=height,
+                    #     batch_size=batch_size * num_videos_per_prompt * num_frames,
+                    #     num_images_per_prompt=num_videos_per_prompt,
+                    #     device=device,
+                    #     dtype=controlnet.dtype,
+                    #     do_classifier_free_guidance=self.do_classifier_free_guidance,
+                    #     guess_mode=guess_mode,
+                    # )
+
+                    prepared_frame = self.prepare_control_frames(
                         image=frame_,
                         width=width,
                         height=height,
@@ -1051,7 +1097,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                         do_classifier_free_guidance=self.do_classifier_free_guidance,
                         guess_mode=guess_mode,
                     )
-
+
                     cond_prepared_frames.append(prepared_frame)
 
             conditioning_frames = cond_prepared_frames
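
For context, the sketch below reproduces the batch handling that the new prepare_control_frames helper applies, as a self-contained script. It is an illustration, not the pipeline's actual method: the real code first runs the frames through self.control_image_processor.preprocess(image, height=height, width=width), which is replaced here by a pre-shaped dummy tensor, and the name prepare_control_frames_sketch and the example sizes are hypothetical. Note that the call sites in the diff pass batch_size=batch_size * num_videos_per_prompt * num_frames, so a single conditioning frame is tiled across every latent frame, while a full stack of frames is only repeated per prompt.

import torch


def prepare_control_frames_sketch(
    image: torch.Tensor,          # preprocessed frames, shape (F, C, H, W)
    batch_size: int,              # batch * num_videos_per_prompt * num_frames
    num_images_per_prompt: int,
    device: torch.device,
    dtype: torch.dtype,
    do_classifier_free_guidance: bool = False,
    guess_mode: bool = False,
) -> torch.Tensor:
    # The real helper preprocesses PIL/array input at this point; the
    # sketch assumes `image` is already a float tensor of stacked frames.
    image = image.to(dtype=torch.float32)
    image_batch_size = image.shape[0]

    # A single conditioning frame is tiled out to the full batch;
    # a full stack is repeated once per image per prompt.
    repeat_by = batch_size if image_batch_size == 1 else num_images_per_prompt
    image = image.repeat_interleave(repeat_by, dim=0)

    image = image.to(device=device, dtype=dtype)

    # Classifier-free guidance runs an unconditional and a conditional
    # pass, so the control frames are duplicated along the batch axis
    # (in guess mode the ControlNet only sees the conditional half).
    if do_classifier_free_guidance and not guess_mode:
        image = torch.cat([image] * 2)

    return image


# Shape check with hypothetical sizes: 16 frames, one video per prompt.
frames = torch.rand(16, 3, 64, 64)
out = prepare_control_frames_sketch(
    frames,
    batch_size=1 * 1 * 16,
    num_images_per_prompt=1,
    device=torch.device("cpu"),
    dtype=torch.float16,
    do_classifier_free_guidance=True,
)
print(out.shape)  # torch.Size([32, 3, 64, 64]); doubled for CFG

With 16 input frames the batch is left at 16 by repeat_interleave (repeat_by is 1) and then doubled to 32 by the CFG duplication, which matches what the ControlNet expects to receive per denoising step.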