Update pipeline.py
Browse files- pipeline.py +4 -7
pipeline.py
CHANGED
@@ -691,12 +691,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
|
691 |
shape = shape
|
692 |
# shape = (1,) + shape[1:]
|
693 |
# ignore init latents for batch model
|
694 |
-
latents = [
|
695 |
-
torch.randn(
|
696 |
-
shape, generator=generator[i], device=rand_device, dtype=dtype
|
697 |
-
)
|
698 |
-
for i in range(batch_size)
|
699 |
-
]
|
700 |
latents = torch.cat(latents, dim=0).to(device)
|
701 |
else:
|
702 |
if init_latents is not None:
|
@@ -943,7 +938,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
|
943 |
# y_velocity=y_velocity,
|
944 |
# scale_velocity=scale_velocity,
|
945 |
# )
|
946 |
-
|
947 |
init_image,
|
948 |
init_image_strength,
|
949 |
init_noise_correlation,
|
@@ -956,6 +951,8 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
|
956 |
device,
|
957 |
generator,
|
958 |
)
|
|
|
|
|
959 |
|
960 |
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
961 |
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
|
|
691 |
shape = shape
|
692 |
# shape = (1,) + shape[1:]
|
693 |
# ignore init latents for batch model
|
694 |
+
latents = [torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)for i in range(batch_size)]
|
|
|
|
|
|
|
|
|
|
|
695 |
latents = torch.cat(latents, dim=0).to(device)
|
696 |
else:
|
697 |
if init_latents is not None:
|
|
|
938 |
# y_velocity=y_velocity,
|
939 |
# scale_velocity=scale_velocity,
|
940 |
# )
|
941 |
+
latents, init_latents = self.prepare_correlated_latents(
|
942 |
init_image,
|
943 |
init_image_strength,
|
944 |
init_noise_correlation,
|
|
|
951 |
device,
|
952 |
generator,
|
953 |
)
|
954 |
+
print(type(latents), hasattr(latents, 'shape') and latents.shape)
|
955 |
+
|
956 |
|
957 |
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
958 |
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|