Update pipeline.py
Browse files- pipeline.py +56 -10
pipeline.py
CHANGED
@@ -825,6 +825,37 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
|
825 |
|
826 |
return latents
|
827 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
828 |
@torch.no_grad()
|
829 |
# @replace_example_docstring(EXAMPLE_DOC_STRING)
|
830 |
def __call__(
|
@@ -1028,19 +1059,34 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
|
1028 |
# do_classifier_free_guidance=self.do_classifier_free_guidance,
|
1029 |
# guess_mode=guess_mode,
|
1030 |
# )
|
1031 |
-
conditioning_frames = self.
|
1032 |
-
|
1033 |
-
|
1034 |
-
|
1035 |
-
|
1036 |
-
|
1037 |
-
|
1038 |
-
|
|
|
|
|
1039 |
)
|
|
|
1040 |
elif isinstance(controlnet, MultiControlNetModel):
|
1041 |
cond_prepared_frames = []
|
1042 |
for frame_ in conditioning_frames:
|
1043 |
-
prepared_frame = self.prepare_image(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1044 |
image=frame_,
|
1045 |
width=width,
|
1046 |
height=height,
|
@@ -1051,7 +1097,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
|
1051 |
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
1052 |
guess_mode=guess_mode,
|
1053 |
)
|
1054 |
-
|
1055 |
cond_prepared_frames.append(prepared_frame)
|
1056 |
|
1057 |
conditioning_frames = cond_prepared_frames
|
|
|
825 |
|
826 |
return latents
|
827 |
|
828 |
+
# Adapted from diffusers' StableDiffusionControlNetPipeline.prepare_image
def prepare_control_frames(
    self,
    image,
    width,
    height,
    batch_size,
    num_images_per_prompt,
    device,
    dtype,
    do_classifier_free_guidance=False,
    guess_mode=False,
):
    """Preprocess ControlNet conditioning frames and expand them to the batch size.

    The frames are resized/normalized by ``self.control_image_processor``,
    repeated along the batch dimension to match ``batch_size`` (or
    ``num_images_per_prompt`` when several frames are supplied), moved to the
    target ``device``/``dtype``, and duplicated for classifier-free guidance
    when ``guess_mode`` is off. Returns the prepared frame tensor.
    """
    frames = self.control_image_processor.preprocess(image, height=height, width=width)
    frames = frames.to(dtype=torch.float32)

    # A single conditioning frame is tiled across the whole batch; otherwise
    # each supplied frame is repeated once per generated image per prompt.
    if frames.shape[0] == 1:
        repeats = batch_size
    else:
        repeats = num_images_per_prompt

    frames = frames.repeat_interleave(repeats, dim=0)
    frames = frames.to(device=device, dtype=dtype)

    # Under classifier-free guidance the conditioning is consumed by both the
    # unconditional and the conditional forward pass, so double the batch —
    # except in guess mode, where only the conditional branch sees it.
    if do_classifier_free_guidance and not guess_mode:
        frames = torch.cat([frames, frames])

    return frames
|
858 |
+
|
859 |
@torch.no_grad()
|
860 |
# @replace_example_docstring(EXAMPLE_DOC_STRING)
|
861 |
def __call__(
|
|
|
1059 |
# do_classifier_free_guidance=self.do_classifier_free_guidance,
|
1060 |
# guess_mode=guess_mode,
|
1061 |
# )
|
1062 |
+
conditioning_frames = self.prepare_control_frames(
|
1063 |
+
image=frame_,
|
1064 |
+
width=width,
|
1065 |
+
height=height,
|
1066 |
+
batch_size=batch_size * num_videos_per_prompt * num_frames,
|
1067 |
+
num_images_per_prompt=num_videos_per_prompt,
|
1068 |
+
device=device,
|
1069 |
+
dtype=controlnet.dtype,
|
1070 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
1071 |
+
guess_mode=guess_mode,
|
1072 |
)
|
1073 |
+
|
1074 |
elif isinstance(controlnet, MultiControlNetModel):
|
1075 |
cond_prepared_frames = []
|
1076 |
for frame_ in conditioning_frames:
|
1077 |
+
# prepared_frame = self.prepare_image(
|
1078 |
+
# image=frame_,
|
1079 |
+
# width=width,
|
1080 |
+
# height=height,
|
1081 |
+
# batch_size=batch_size * num_videos_per_prompt * num_frames,
|
1082 |
+
# num_images_per_prompt=num_videos_per_prompt,
|
1083 |
+
# device=device,
|
1084 |
+
# dtype=controlnet.dtype,
|
1085 |
+
# do_classifier_free_guidance=self.do_classifier_free_guidance,
|
1086 |
+
# guess_mode=guess_mode,
|
1087 |
+
# )
|
1088 |
+
|
1089 |
+
prepared_frame = self.prepare_control_frames(
|
1090 |
image=frame_,
|
1091 |
width=width,
|
1092 |
height=height,
|
|
|
1097 |
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
1098 |
guess_mode=guess_mode,
|
1099 |
)
|
1100 |
+
|
1101 |
cond_prepared_frames.append(prepared_frame)
|
1102 |
|
1103 |
conditioning_frames = cond_prepared_frames
|