smoothieAI committed (verified)
Commit b9d8a02 · 1 Parent(s): a3f7e00

Update pipeline.py

Files changed (1):
  pipeline.py (+6, -2)
pipeline.py CHANGED
@@ -839,7 +839,8 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
         guess_mode=False,
     ):
         image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
-        image_batch_size = image.shape[0]
+        # image_batch_size = image.shape[0]
+        image_batch_size = len(image)
         print("prepared control image_batch_size", image_batch_size)
 
         if image_batch_size == 1:
@@ -848,7 +849,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
             # image batch size is the same as prompt batch size
             repeat_by = num_images_per_prompt
 
-        image = image.repeat_interleave(repeat_by, dim=0)
+        # image = image.repeat_interleave(repeat_by, dim=0)
 
         image = image.to(device=device, dtype=dtype)
 
@@ -1285,6 +1286,9 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                 if isinstance(controlnet_cond_scale, list):
                     controlnet_cond_scale = controlnet_cond_scale[0]
                 cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+                print("-----------------------")
+                print("control_model_input.shape", control_model_input.shape)
 
                 control_model_input = torch.transpose(control_model_input, 1, 2)
                 control_model_input = control_model_input.reshape(
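The first hunk swaps image.shape[0] for len(image) when computing the control-image batch size. A minimal sketch of the difference, assuming the motivation is that the preprocessed control input may arrive as a plain Python list of per-frame tensors rather than one batched tensor (that assumption is not stated in the commit):

    import torch

    # Sketch only: `batched` and `frames` are illustrative stand-ins, not values from pipeline.py.
    batched = torch.zeros(4, 3, 64, 64)                  # a single (B, C, H, W) tensor
    frames = [torch.zeros(3, 64, 64) for _ in range(4)]  # a plain list of per-frame tensors

    print(batched.shape[0])  # 4 -- works only on a tensor
    print(len(batched))      # 4 -- len() of a tensor is its first dimension
    print(len(frames))       # 4 -- len() also covers the list case; frames.shape[0] would raise AttributeError

Since len() of a tensor equals its first dimension, the change is behavior-preserving for the tensor case while also tolerating list inputs.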
 
 
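The third hunk only adds a debug print of control_model_input.shape just before the tensor is transposed and reshaped for the ControlNet. A sketch of that kind of frame-folding step, under the assumption that the latents are laid out as (batch, channels, frames, height, width); the concrete sizes and the reshape target are illustrative, since the reshape call is truncated in this hunk:

    import torch

    control_model_input = torch.zeros(2, 4, 16, 64, 64)   # assumed (B, C, F, H, W) layout
    print("control_model_input.shape", control_model_input.shape)

    x = torch.transpose(control_model_input, 1, 2)  # -> (B, F, C, H, W)
    x = x.reshape(-1, *x.shape[2:])                 # fold frames into batch -> (B*F, C, H, W)
    print(x.shape)                                  # torch.Size([32, 4, 64, 64])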