Update pipeline.py
Browse files- pipeline.py +203 -6
pipeline.py
CHANGED
@@ -583,6 +583,185 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
|
583 |
|
584 |
return latents
|
585 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
586 |
@torch.no_grad()
|
587 |
# @replace_example_docstring(EXAMPLE_DOC_STRING)
|
588 |
def __call__(
|
@@ -614,6 +793,9 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
|
614 |
x_velocity: Optional[float] = 0,
|
615 |
y_velocity: Optional[float] = 0,
|
616 |
scale_velocity: Optional[float] = 0,
|
|
|
|
|
|
|
617 |
):
|
618 |
r"""
|
619 |
The call function to the pipeline for generation.
|
@@ -753,8 +935,26 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
|
753 |
# generator,
|
754 |
# latents,
|
755 |
# )
|
756 |
-
latents = self.prepare_motion_latents(
|
757 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
758 |
num_channels_latents,
|
759 |
num_frames,
|
760 |
height,
|
@@ -762,10 +962,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
|
762 |
prompt_embeds.dtype,
|
763 |
device,
|
764 |
generator,
|
765 |
-
latents,
|
766 |
-
x_velocity=x_velocity,
|
767 |
-
y_velocity=y_velocity,
|
768 |
-
scale_velocity=scale_velocity,
|
769 |
)
|
770 |
|
771 |
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
|
|
583 |
|
584 |
return latents
|
585 |
|
586 |
+
def generate_correlated_noise(self, latents, init_noise_correlation):
    """Blend a latent tensor with fresh Gaussian noise.

    Produces ``p * latents + sqrt(1 - p^2) * N(0, 1)`` over the flattened
    tensor, reshaped back to the input's shape.  The weights keep unit
    variance when the input is unit-variance noise.  The input tensor is
    cloned first and never modified.
    """
    rho = init_noise_correlation
    source = latents.clone()
    flat = source.flatten()
    fresh = torch.randn_like(flat)
    # Variance-preserving mix: rho^2 + (sqrt(1 - rho^2))^2 == 1.
    mixed = flat * rho + math.sqrt(1 - rho**2) * fresh
    return mixed.reshape(source.shape)
|
594 |
+
|
595 |
+
def generate_correlated_latents(self, latents, init_noise_correlation):
    """Correlate each video frame with its predecessor.

    Sweeps the frame axis (dim 2) front to back, replacing frame ``i``
    with the variance-preserving mix
    ``p/sqrt(1+p^2) * frame_{i-1} + sqrt(1/(1+p^2)) * frame_i``.
    Because the sweep reads the already-updated previous frame, the
    correlation propagates down the whole clip.  Works on a clone; the
    input tensor is untouched.
    """
    rho = init_noise_correlation
    out = latents.clone()
    # Mixing weights are loop-invariant; compute them once.
    w_prev = rho / math.sqrt((1 + rho**2))
    w_cur = math.sqrt(1 / (1 + rho**2))
    for frame in range(1, out.shape[2]):
        prev_flat = torch.flatten(out[:, :, frame - 1])
        cur_flat = torch.flatten(out[:, :, frame])
        blended = prev_flat * w_prev + cur_flat * w_cur
        out[:, :, frame] = blended.reshape(out[:, :, frame].shape)
    return out
|
611 |
+
|
612 |
+
def generate_correlated_latents_legacy(self, latents, init_noise_correlation):
    """Legacy frame-to-frame correlation sweep.

    Same front-to-back sweep as ``generate_correlated_latents`` but with
    the older AR(1)-style rule
    ``frame_i = p * frame_{i-1} + sqrt(1 - p^2) * frame_i``, which is
    variance-preserving only when frames are independent unit-variance
    noise.  Kept for backward compatibility.  Operates on a clone.
    """
    rho = init_noise_correlation
    result = latents.clone()
    decay = math.sqrt(1 - rho**2)
    for idx in range(1, result.shape[2]):
        previous = torch.flatten(result[:, :, idx - 1])
        current = torch.flatten(result[:, :, idx])
        blended = previous * rho + current * decay
        result[:, :, idx] = blended.reshape(result[:, :, idx].shape)
    return result
|
628 |
+
|
629 |
+
def generate_mixed_noise(self, noise, init_noise_correlation):
    """Blend one shared Gaussian frame into every (batch, frame) slice.

    Draws a single noise frame shaped like ``noise[0, :, 0]`` and mixes
    it into each slice with variance-preserving weights:
    ``sqrt(p^2/(1+p^2)) * shared + sqrt(1/(1+p^2)) * slice``.
    NOTE: mutates ``noise`` in place and returns the same tensor.
    """
    rho = init_noise_correlation
    common = torch.randn_like(noise[0, :, 0])
    common_flat = torch.flatten(common)
    # Weights do not depend on the loop indices; hoist them.
    w_shared = math.sqrt(rho**2 / (1 + rho**2))
    w_local = math.sqrt(1 / (1 + rho**2))
    for b in range(noise.shape[0]):
        for f in range(noise.shape[2]):
            local_flat = torch.flatten(noise[b, :, f])
            blended = common_flat * w_shared + local_flat * w_local
            noise[b, :, f] = blended.reshape(noise[b, :, f].shape)
    return noise
|
643 |
+
|
644 |
+
def prepare_correlated_latents(
    self,
    init_image,          # path/file of an optional seed image; opened via PIL below
    init_image_strength, # in [0, 1]; scales how far into the timestep schedule noise is added
    init_noise_correlation,
    batch_size,
    num_channels_latents,
    video_length,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
):
    """Build the initial latent tensor for correlated-noise video generation.

    Optionally encodes ``init_image`` through the VAE and repeats it across
    all ``video_length`` frames, then noises it with frame-correlated noise
    via ``self.generate_correlated_latents``.  Returns ``(latents,
    init_latents)`` where ``init_latents`` is ``None`` when no seed image
    was given.

    NOTE(review): the image path hard-codes ``"cuda"`` and ``bfloat16`` and
    assumes 4 latent channels with an 8x VAE downsample (``height // 8``);
    it will not run on CPU/MPS as written — confirm against deployment.
    """
    # Standard 5-D video-latent shape: (B, C, F, H/vsf, W/vsf).
    shape = (
        batch_size,
        num_channels_latents,
        video_length,
        height // self.vae_scale_factor,
        width // self.vae_scale_factor,
    )

    if init_image is not None:
        # Load -> [0, 1] tensor, keep RGB only (drops any alpha channel).
        start_image = (
            (
                torchvision.transforms.functional.pil_to_tensor(
                    PIL.Image.open(init_image).resize((width, height))
                )
                / 255
            )[:3, :, :]
            .to("cuda")
            .to(torch.bfloat16)
            .unsqueeze(0)
        )
        # VAE-encode in [-1, 1]; 0.18215 is the SD VAE latent scaling factor.
        start_image = (
            self.vae.encode(start_image.mul(2).sub(1))
            .latent_dist.sample()
            .view(1, 4, height // 8, width // 8)
            * 0.18215
        )
        # Tile the single encoded frame across the whole clip.
        init_latents = start_image.unsqueeze(2).repeat(1, 1, video_length, 1, 1)
    else:
        init_latents = None

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )
    if latents is None:
        # torch.randn with a generator is unsupported on MPS; sample on CPU there.
        rand_device = "cpu" if device.type == "mps" else device
        if isinstance(generator, list):
            shape = shape
            # shape = (1,) + shape[1:]
            # ignore init latents for batch model
            # Per-generator sampling: init_latents is intentionally ignored here.
            latents = [
                torch.randn(
                    shape, generator=generator[i], device=rand_device, dtype=dtype
                )
                for i in range(batch_size)
            ]
            latents = torch.cat(latents, dim=0).to(device)
        else:
            if init_latents is not None:
                # How many scheduler steps of noise the seed image receives.
                offset = int(
                    init_image_strength * (len(self.scheduler.timesteps) - 1)
                )
                noise = torch.randn_like(init_latents)
                noise = self.generate_correlated_latents(
                    noise, init_noise_correlation
                )

                # Eric - some black magic here
                # We should be only adding the noise at timestep[offset], but I noticed that
                # we get more motion and cooler motion if we add the noise at timestep[offset - 1]
                # or offset - 2. However, this breaks the fewer timesteps there are, so let's interpolate
                timesteps = self.scheduler.timesteps
                average_timestep = None
                if offset == 0:
                    average_timestep = timesteps[0]
                elif offset == 1:
                    # Linear interpolation between the two adjacent timesteps.
                    average_timestep = (
                        timesteps[offset - 1] * (1 - init_image_strength)
                        + timesteps[offset] * init_image_strength
                    )
                else:
                    average_timestep = timesteps[offset - 1]

                latents = self.scheduler.add_noise(
                    init_latents, noise, average_timestep.long()
                )

                # Second, uncorrelated noise injection near the end of the
                # schedule (timesteps[-2]) — deliberate, see note above.
                latents = self.scheduler.add_noise(
                    latents, torch.randn_like(init_latents), timesteps[-2]
                )
            else:
                # Pure-noise start: sample then frame-correlate.
                latents = torch.randn(
                    shape, generator=generator, device=rand_device, dtype=dtype
                ).to(device)
                latents = self.generate_correlated_latents(
                    latents, init_noise_correlation
                )
    else:
        if latents.shape != shape:
            raise ValueError(
                f"Unexpected latents shape, got {latents.shape}, expected {shape}"
            )
        latents = latents.to(device)

    # scale the initial noise by the standard deviation required by the scheduler
    if init_latents is None:
        latents = latents * self.scheduler.init_noise_sigma
    elif self.unet.trained_initial_frames and init_latents is not None:
        # we only want to use this as the first frame
        init_latents[:, :, 1:] = torch.zeros_like(init_latents[:, :, 1:])

    latents = latents.to(device)
    return latents, init_latents
|
763 |
+
|
764 |
+
|
765 |
@torch.no_grad()
|
766 |
# @replace_example_docstring(EXAMPLE_DOC_STRING)
|
767 |
def __call__(
|
|
|
793 |
x_velocity: Optional[float] = 0,
|
794 |
y_velocity: Optional[float] = 0,
|
795 |
scale_velocity: Optional[float] = 0,
|
796 |
+
init_image: Optional[str] = None,
|
797 |
+
init_image_strength: Optional[float] = 1.0,
|
798 |
+
init_noise_correlation: Optional[float] = 0.0,
|
799 |
):
|
800 |
r"""
|
801 |
The call function to the pipeline for generation.
|
|
|
935 |
# generator,
|
936 |
# latents,
|
937 |
# )
|
938 |
+
# latents = self.prepare_motion_latents(
|
939 |
+
# batch_size * num_videos_per_prompt,
|
940 |
+
# num_channels_latents,
|
941 |
+
# num_frames,
|
942 |
+
# height,
|
943 |
+
# width,
|
944 |
+
# prompt_embeds.dtype,
|
945 |
+
# device,
|
946 |
+
# generator,
|
947 |
+
# latents,
|
948 |
+
# x_velocity=x_velocity,
|
949 |
+
# y_velocity=y_velocity,
|
950 |
+
# scale_velocity=scale_velocity,
|
951 |
+
# )
|
952 |
+
latents = self.prepare_correlated_latents(
|
953 |
+
self,
|
954 |
+
init_image,
|
955 |
+
init_image_strength,
|
956 |
+
init_noise_correlation,
|
957 |
+
batch_size,
|
958 |
num_channels_latents,
|
959 |
num_frames,
|
960 |
height,
|
|
|
962 |
prompt_embeds.dtype,
|
963 |
device,
|
964 |
generator,
|
965 |
+
latents=None,
|
|
|
|
|
|
|
966 |
)
|
967 |
|
968 |
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|