Update pipeline.py
Browse files- pipeline.py +6 -2
pipeline.py
CHANGED
@@ -839,7 +839,8 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
|
839 |
guess_mode=False,
|
840 |
):
|
841 |
image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
|
842 |
-
image_batch_size = image.shape[0]
|
|
|
843 |
print("prepared control image_batch_size", image_batch_size)
|
844 |
|
845 |
if image_batch_size == 1:
|
@@ -848,7 +849,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
|
848 |
# image batch size is the same as prompt batch size
|
849 |
repeat_by = num_images_per_prompt
|
850 |
|
851 |
-
image = image.repeat_interleave(repeat_by, dim=0)
|
852 |
|
853 |
image = image.to(device=device, dtype=dtype)
|
854 |
|
@@ -1285,6 +1286,9 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
|
1285 |
if isinstance(controlnet_cond_scale, list):
|
1286 |
controlnet_cond_scale = controlnet_cond_scale[0]
|
1287 |
cond_scale = controlnet_cond_scale * controlnet_keep[i]
|
|
|
|
|
|
|
1288 |
|
1289 |
control_model_input = torch.transpose(control_model_input, 1, 2)
|
1290 |
control_model_input = control_model_input.reshape(
|
|
|
839 |
guess_mode=False,
|
840 |
):
|
841 |
image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
|
842 |
+
# image_batch_size = image.shape[0]
|
843 |
+
image_batch_size = len(image)
|
844 |
print("prepared control image_batch_size", image_batch_size)
|
845 |
|
846 |
if image_batch_size == 1:
|
|
|
849 |
# image batch size is the same as prompt batch size
|
850 |
repeat_by = num_images_per_prompt
|
851 |
|
852 |
+
# image = image.repeat_interleave(repeat_by, dim=0)
|
853 |
|
854 |
image = image.to(device=device, dtype=dtype)
|
855 |
|
|
|
1286 |
if isinstance(controlnet_cond_scale, list):
|
1287 |
controlnet_cond_scale = controlnet_cond_scale[0]
|
1288 |
cond_scale = controlnet_cond_scale * controlnet_keep[i]
|
1289 |
+
|
1290 |
+
print("-----------------------")
|
1291 |
+
print("control_model_input.shape", control_model_input.shape)
|
1292 |
|
1293 |
control_model_input = torch.transpose(control_model_input, 1, 2)
|
1294 |
control_model_input = control_model_input.reshape(
|