Update pipeline.py
Browse files — pipeline.py (+32 −11)
pipeline.py
CHANGED
@@ -812,7 +812,19 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
|
812 |
latents = latents.to(device)
|
813 |
return latents, init_latents
|
814 |
|
815 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
816 |
@torch.no_grad()
|
817 |
# @replace_example_docstring(EXAMPLE_DOC_STRING)
|
818 |
def __call__(
|
@@ -1005,16 +1017,25 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
|
1005 |
|
1006 |
if self.controlnet != None:
|
1007 |
if isinstance(controlnet, ControlNetModel):
|
1008 |
-
conditioning_frames = self.prepare_image(
|
1009 |
-
|
1010 |
-
|
1011 |
-
|
1012 |
-
|
1013 |
-
|
1014 |
-
|
1015 |
-
|
1016 |
-
|
1017 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1018 |
)
|
1019 |
elif isinstance(controlnet, MultiControlNetModel):
|
1020 |
cond_prepared_frames = []
|
|
|
812 |
latents = latents.to(device)
|
813 |
return latents, init_latents
|
814 |
|
815 |
+
def prepare_control_latents(self, batch_size, contorl_frames, num_channels_latents, num_frames, height, width, dtype, device):
    """Allocate a zero-initialized latent tensor for ControlNet conditioning frames.

    Returns a tensor of shape
    ``(num_frames, num_channels_latents, height // self.vae_scale_factor,
    width // self.vae_scale_factor)`` filled with zeros, placed on *device*
    with *dtype*.

    NOTE(review): ``batch_size`` and ``contorl_frames`` (sic — apparent typo
    for "control_frames") are currently unused; despite the comment below,
    the input control frames are not yet encoded into the returned latents.
    Parameter names kept as-is to avoid breaking keyword callers — confirm
    before renaming.
    """
    scale = self.vae_scale_factor
    latent_shape = (num_frames, num_channels_latents, height // scale, width // scale)

    # convert input control image array to latents tensor array
    return torch.zeros(latent_shape, dtype=dtype, device=device)
|
827 |
+
|
828 |
@torch.no_grad()
|
829 |
# @replace_example_docstring(EXAMPLE_DOC_STRING)
|
830 |
def __call__(
|
|
|
1017 |
|
1018 |
if self.controlnet != None:
|
1019 |
if isinstance(controlnet, ControlNetModel):
|
1020 |
+
# conditioning_frames = self.prepare_image(
|
1021 |
+
# image=conditioning_frames,
|
1022 |
+
# width=width,
|
1023 |
+
# height=height,
|
1024 |
+
# batch_size=batch_size * num_videos_per_prompt * num_frames,
|
1025 |
+
# num_images_per_prompt=num_videos_per_prompt,
|
1026 |
+
# device=device,
|
1027 |
+
# dtype=controlnet.dtype,
|
1028 |
+
# do_classifier_free_guidance=self.do_classifier_free_guidance,
|
1029 |
+
# guess_mode=guess_mode,
|
1030 |
+
# )
|
1031 |
+
conditioning_frames = self.prepare_control_latents(
|
1032 |
+
num_frames,
|
1033 |
+
conditioning_frames,
|
1034 |
+
num_channels_latents,
|
1035 |
+
height,
|
1036 |
+
width,
|
1037 |
+
prompt_embeds.dtype,
|
1038 |
+
device,
|
1039 |
)
|
1040 |
elif isinstance(controlnet, MultiControlNetModel):
|
1041 |
cond_prepared_frames = []
|